| | |
| | |
| | ''' |
| | * @Desc: train GPT2 from scratch/ fine tuning. |
| | Modified based on Huggingface GPT-2 implementation |
| | ''' |
| |
|
| | import json |
| | import os |
| | import sys |
| | import argparse |
| | import logging |
| | import time |
| | import tqdm |
| | import datetime |
| | import torch |
| |
|
| | import numpy as np |
| |
|
| | from os.path import join |
| | from torch.distributed import get_rank, get_world_size |
| |
|
| | from lsp_model import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, Adam |
| | from gpt2_training.train_utils import load_model, boolean_string, set_lr, get_eval_list_same_length |
| | from gpt2_training.eval_utils import eval_model_loss |
| |
|
| | from data_loader import BucketingDataLoader, DynamicBatchingLoader, DistributedBucketingDataLoader |
| |
|
| |
|
| | from gpt2_training.distributed import all_reduce_and_rescale_tensors, all_gather_list |
| |
|
| |
|
| | logging.basicConfig( |
| | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', |
| | datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) |
| | logger = logging.getLogger(__name__) |
| |
|
| | INF = 100000000 |
| | CACHE_EMPTY_STEP = 10000 |
| | EVAL_STEP = 100000 |
| |
|
| | |
| | |
| | |
| |
|
| | parser = argparse.ArgumentParser() |
| | parser.add_argument('--model_name_or_path', type=str, |
| | help='pretrained model name or path to local checkpoint') |
| | parser.add_argument("--seed", type=int, default=42) |
| | parser.add_argument("--max_seq_length", type=int, default=128) |
| |
|
| | parser.add_argument("--skip_eval", action='store_true', |
| | help='If true, skip evaluation.') |
| | parser.add_argument("--init_checkpoint", type=str) |
| | parser.add_argument("--train_input_file", type=str) |
| | parser.add_argument("--eval_input_file", type=str) |
| | parser.add_argument("--continue_from", type=int, default=0) |
| |
|
| | parser.add_argument("--train_batch_size", type=int, default=4, |
| | help="batch size now means per GPU per step") |
| | parser.add_argument("--gradient_accumulation_steps", type=int, default=2, |
| | help="to increase effective batch size " |
| | "and reduce synchronization") |
| | parser.add_argument("--eval_batch_size", type=int, default=4) |
| | parser.add_argument("--learning_rate", type=float, default=1e-5) |
| | parser.add_argument("--num_optim_steps", type=int, default=1000000, |
| | help="new API specifies num update steps") |
| | parser.add_argument("--valid_step", type=int, default=10000, |
| | help="how many optim steps between validations") |
| | parser.add_argument("--warmup_proportion", type=float, default=0.1) |
| | parser.add_argument("--warmup_steps", type=int, default=16000) |
| |
|
| | parser.add_argument("--normalize_data", type=boolean_string, default=True) |
| | parser.add_argument("--fp16", type=boolean_string, default=True) |
| | parser.add_argument("--lr_schedule", type=str, |
| | choices=['noam', 'noamwd', 'BERT', 'None'], default='noam') |
| | parser.add_argument("--loss_scale", type=float, default=0) |
| | parser.add_argument("--no_token_id", type=boolean_string, default=True) |
| |
|
| | parser.add_argument("--output_dir", type=str) |
| | parser.add_argument("--log_dir", type=str) |
| | parser.add_argument('--pbar', type=boolean_string, default=True, help='turn on progress bar') |
| |
|
| | |
| | parser.add_argument('--local_rank', type=int, default=-1, |
| | help='for torch.distributed') |
| | parser.add_argument('--config', help='JSON config file') |
| |
|
| |
|
| | |
| | args = parser.parse_args() |
| |
|
| | if args.config is not None: |
| | |
| | opts = json.load(open(args.config)) |
| | for k, v in opts.items(): |
| | if isinstance(v, str): |
| | |
| | if 'PHILLY_JOB_DIRECTORY' in v: |
| | v = v.replace('PHILLY_JOB_DIRECTORY', |
| | os.environ['PHILLY_JOB_DIRECTORY']) |
| | elif 'PHILLY_LOG_DIRECTORY' in v: |
| | v = v.replace('PHILLY_LOG_DIRECTORY', |
| | os.environ['PHILLY_LOG_DIRECTORY']) |
| | setattr(args, k, v) |
| |
|
| | |
| | argv = sys.argv[1:] |
| | overrides, _ = parser.parse_known_args(argv) |
| | for k, v in vars(overrides).items(): |
| | if f'--{k}' in argv: |
| | setattr(args, k, v) |
| | setattr(args, 'local_rank', overrides.local_rank) |
| |
|
| |
|
| | assert args.train_batch_size % args.gradient_accumulation_steps == 0, \ |
| | 'batch size % gradient accumulation steps != 0!' |
| | args.train_batch_size = (args.train_batch_size |
| | // args.gradient_accumulation_steps) |
| | logger.info('train batch size = {}, ' |
| | 'new train batch size (after gradient accumulation) = {}'.format( |
| | args.train_batch_size*args.gradient_accumulation_steps, |
| | args.train_batch_size)) |
| |
|
| |
|
| | if args.local_rank == -1: |
| | logger.info('CUDA available? {}'.format(str(torch.cuda.is_available()))) |
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| | n_gpu = torch.cuda.device_count() |
| | args.device, args.n_gpu = device, n_gpu |
| | else: |
| | |
| | torch.cuda.set_device(args.local_rank) |
| | device = torch.device("cuda", args.local_rank) |
| | |
| | |
| | torch.distributed.init_process_group(backend='nccl') |
| | n_gpu = torch.distributed.get_world_size() |
| | args.device, args.n_gpu = device, 1 |
| | logger.info("device: {} n_gpu: {}, distributed training: {}, " |
| | "16-bits training: {}".format( |
| | device, n_gpu, bool(args.local_rank != -1), args.fp16)) |
| |
|
| | np.random.seed(args.seed) |
| | torch.random.manual_seed(args.seed) |
| | torch.cuda.manual_seed(args.seed) |
| | if n_gpu > 0: |
| | torch.cuda.manual_seed_all(args.seed) |
| |
|
| | timestamp = datetime.datetime.now().strftime('%Y-%m-%d%H%M%S') |
| | output_dir = join(args.output_dir, |
| | 'GPT2.{}.{}.{}gpu.{}'.format(args.learning_rate, |
| | args.train_batch_size, n_gpu, |
| | timestamp)) |
| | log_dir = args.log_dir if args.log_dir is not None and len(args.log_dir) > 0 else output_dir |
| | if args.local_rank == -1 or get_rank() == 0: |
| | os.makedirs(output_dir, exist_ok=True) |
| |
|
| | logger.info('Input Argument Information') |
| | args_dict = vars(args) |
| | for a in args_dict: |
| | logger.info('%-28s %s' % (a, args_dict[a])) |
| |
|
| |
|
| | |
| | |
| | |
| | enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path) |
| |
|
| | config = GPT2Config.from_json_file( |
| | join(args.model_name_or_path, 'config.json')) |
| |
|
| | if args.local_rank == -1: |
| | train_dataloader = BucketingDataLoader(args.train_input_file, |
| | args.train_batch_size, |
| | args.max_seq_length) |
| | else: |
| | train_dataloader = DistributedBucketingDataLoader( |
| | get_rank(), get_world_size(), |
| | args.train_input_file, args.train_batch_size, |
| | args.max_seq_length) |
| |
|
| | eval_dataloader_loss = DynamicBatchingLoader( |
| | args.eval_input_file, enc, args.normalize_data, |
| | args.eval_batch_size, args.max_seq_length) |
| |
|
| | eval_dataloader_gen = get_eval_list_same_length( |
| | args.eval_input_file, enc, args.eval_batch_size, True) |
| |
|
| |
|
| | |
| | |
| | |
| | model = load_model(GPT2LMHeadModel(config), args.init_checkpoint, |
| | args, verbose=True) |
| | if args.local_rank != -1: |
| | |
| | params = [p.data for p in model.parameters()] |
| | all_reduce_and_rescale_tensors( |
| | params, float(torch.distributed.get_world_size())) |
| |
|
| | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) |
| | total_params = sum([np.prod(p.size()) for p in model_parameters]) |
| | logger.info('Number of parameter = {}'.format(total_params)) |
| |
|
| | param_optimizer = list(model.named_parameters()) |
| | no_decay = ['bias', 'ln'] |
| | optimizer_grouped_parameters = [ |
| | {'params': [p for n, p in param_optimizer |
| | if not any(nd in n for nd in no_decay)], |
| | 'weight_decay': 0.01}, |
| | {'params': [p for n, p in param_optimizer |
| | if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} |
| | ] |
| |
|
| | if args.fp16: |
| | logger.info('in fp16, using FusedAdam') |
| | try: |
| | from apex.optimizers import FP16_Optimizer |
| | from apex.optimizers import FusedAdam |
| | except ImportError: |
| | raise ImportError( |
| | "Please install apex from https://www.github.com/nvidia/apex " |
| | "to use distributed and fp16 training.") |
| |
|
| | optimizer = FusedAdam(optimizer_grouped_parameters, |
| | lr=args.learning_rate, |
| | bias_correction=False, |
| | max_grad_norm=1.0) |
| | if args.loss_scale == 0: |
| | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True, |
| | verbose=False) |
| | else: |
| | optimizer = FP16_Optimizer(optimizer, |
| | static_loss_scale=args.loss_scale, |
| | verbose=False) |
| | else: |
| | optimizer = Adam(optimizer_grouped_parameters, args.learning_rate, |
| | max_grad_norm=1.0) |
| |
|
| | |
| | |
| | |
| |
|
| | if args.local_rank == -1 or get_rank() == 0: |
| | train_logger = open(join(log_dir, 'train_log.txt'), 'a+', buffering=1) |
| | eval_logger = open(join(log_dir, 'eval_log.txt'), 'a+', buffering=1) |
| | print('epoch,global_step,step,mean_loss,mean_ppl,n_token_real,' |
| | 'n_token_total,epoch_time', file=train_logger) |
| | print('epoch,global_step,step,eval_loss,eval_ppl', file=eval_logger) |
| |
|
| | global_step = 0 |
| | step = 0 |
| | epoch = 0 |
| |
|
| | if args.continue_from: |
| | global_step = args.continue_from |
| | step = global_step*2 - 1 |
| |
|
| |
|
| | if args.local_rank != -1: |
| | n_gpu = 1 |
| | if args.local_rank == -1 or get_rank() == 0: |
| | if args.pbar: |
| | pbar = tqdm.tqdm(total=args.num_optim_steps, desc=f"training") |
| | else: |
| | pbar = None |
| |
|
| | while True: |
| | model.train() |
| | (tr_loss, tr_ppl, mean_ppl, nb_tr_examples, nb_tr_steps) = 0.0, 0.0, 0.0, 0, 0 |
| | n_token_real, n_token_total = 0, 0 |
| | train_start_time_epoch = time.time() |
| | for batch in train_dataloader: |
| | |
| | seq_len = batch[0].shape[1] |
| | batch = tuple(t.to(device) for t in batch) |
| | input_ids, position_ids, token_ids, label_ids, *_ = batch |
| | if args.no_token_id: |
| | token_ids = None |
| | loss, ppl = model(input_ids, position_ids, token_ids, label_ids) |
| |
|
| | if n_gpu > 1: |
| | loss = loss.mean() |
| | ppl = ppl.mean() |
| | loss = loss / (args.train_batch_size / input_ids.shape[0]) |
| | if args.fp16: |
| | optimizer.backward(loss) |
| | else: |
| | loss.backward() |
| |
|
| | tr_loss += float(loss.item()) * (args.train_batch_size / input_ids.shape[0]) |
| | nb_tr_examples += input_ids.size(0) |
| | nb_tr_steps += 1 |
| | mean_loss = tr_loss / nb_tr_steps |
| | if ppl.item() < INF: |
| | tr_ppl += ppl.item() |
| | else: |
| | tr_ppl += mean_ppl |
| | mean_ppl = tr_ppl / nb_tr_steps |
| |
|
| | n_token_total += input_ids.shape[0] * input_ids.shape[1] |
| | n_token_real += (input_ids != 0).sum().item() |
| |
|
| | |
| | step += 1 |
| | if step % args.gradient_accumulation_steps == 0: |
| | set_lr(optimizer, global_step, |
| | args.lr_schedule, args.learning_rate, |
| | args.warmup_steps, args.warmup_proportion, |
| | config.n_embd, args.num_optim_steps) |
| |
|
| | if args.local_rank != -1: |
| | grads = [p.grad.data for p in model.parameters() |
| | if p.requires_grad and p.grad is not None] |
| | all_reduce_and_rescale_tensors(grads, float(1)) |
| |
|
| | optimizer.step() |
| | optimizer.zero_grad() |
| | global_step += 1 |
| |
|
| | |
| | if args.local_rank != -1: |
| | mean_loss = sum(all_gather_list(mean_loss)) / get_world_size() |
| | mean_ppl = sum(all_gather_list(mean_ppl)) / get_world_size() |
| | n_token_real_all_proc = sum(all_gather_list(n_token_real)) |
| | n_token_total_all_proc = sum(all_gather_list(n_token_total)) |
| | else: |
| | n_token_real_all_proc = n_token_real |
| | n_token_total_all_proc = n_token_total |
| |
|
| | if args.local_rank == -1 or get_rank() == 0: |
| | epoch_time = time.time() - train_start_time_epoch |
| | if pbar is not None: |
| | pbar.set_postfix_str( |
| | f"tok/s: {n_token_real_all_proc//epoch_time//1000}k " |
| | f"ppl: {mean_ppl:.2f} epoch: {epoch}") |
| | pbar.update(1) |
| | print('{},{},{},{},{},{},{},{}'.format( |
| | epoch+1, global_step+1, step+1, mean_loss, mean_ppl, |
| | n_token_real_all_proc, n_token_total_all_proc, epoch_time), |
| | file=train_logger) |
| |
|
| | if global_step % args.valid_step == 0: |
| | if args.local_rank == -1 or get_rank() == 0: |
| | |
| | torch.save( |
| | {k: (v.cpu() if v is not None else None) |
| | for k, v in model.state_dict().items()}, |
| | join(output_dir, |
| | f'GP2-pretrain-step-{global_step}.pkl')) |
| |
|
| | eval_loss, eval_ppl = eval_model_loss( |
| | model, enc, eval_dataloader_loss, epoch, args) |
| | |
| | |
| | |
| | ''' |
| | # probably use beam search only for test set |
| | if False: |
| | gen_response_beam = eval_model_generation( |
| | model, enc, eval_dataloader_gen, epoch, args, |
| | use_beam_search=True, beam_width=3) |
| | ''' |
| | print('{},{},{},{},{}'.format( |
| | epoch+1, global_step+1, step+1, eval_loss, eval_ppl), |
| | file=eval_logger) |
| | logger.info('current learning rate: ' |
| | + str(optimizer.param_groups[0]['lr'])) |
| | model.train() |
| | if global_step >= args.num_optim_steps: |
| | break |
| |
|
| | if (step+1) % CACHE_EMPTY_STEP == 0: |
| | torch.cuda.empty_cache() |
| |
|
| | if global_step >= args.num_optim_steps: |
| | break |
| | epoch += 1 |
| |
|
| |
|
| | if args.local_rank == -1 or get_rank() == 0: |
| | if pbar is not None: |
| | pbar.close() |
| | train_logger.close() |
| | eval_logger.close() |
| |
|