#  Copyright (c) Microsoft Corporation.
#  Licensed under the MIT license.
'''
 * @Desc: train GPT-2 from scratch or fine-tune it.
          Modified based on the Hugging Face GPT-2 implementation.
'''

import json
import os
import sys
import argparse
import logging
import time
import tqdm
import datetime
import torch

import numpy as np

from os.path import join
from torch.distributed import get_rank, get_world_size

from lsp_model import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, Adam
from gpt2_training.train_utils import (
    load_model, boolean_string, set_lr, get_eval_list_same_length)
from gpt2_training.eval_utils import eval_model_loss

from data_loader import (
    BucketingDataLoader, DynamicBatchingLoader,
    DistributedBucketingDataLoader)

from gpt2_training.distributed import (
    all_reduce_and_rescale_tensors, all_gather_list)

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
logger = logging.getLogger(__name__)

INF = 100000000
CACHE_EMPTY_STEP = 10000
EVAL_STEP = 100000

#########################################################################
# Prepare Parser
##########################################################################

parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', type=str,
                    help='pretrained model name or path to local checkpoint')
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--max_seq_length", type=int, default=128)

parser.add_argument("--skip_eval", action='store_true',
                    help='If true, skip evaluation.')
parser.add_argument("--init_checkpoint", type=str)
parser.add_argument("--train_input_file", type=str)
parser.add_argument("--eval_input_file", type=str)
parser.add_argument("--continue_from", type=int, default=0)

parser.add_argument("--train_batch_size", type=int, default=4,
                    help="batch size now means per GPU per step")
parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
                    help="to increase effective batch size "
                         "and reduce synchronization")
parser.add_argument("--eval_batch_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--num_optim_steps", type=int, default=1000000,
                    help="new API specifies num update steps")
parser.add_argument("--valid_step", type=int, default=10000,
                    help="how many optim steps between validations")
parser.add_argument("--warmup_proportion", type=float, default=0.1)
parser.add_argument("--warmup_steps", type=int, default=16000)

parser.add_argument("--normalize_data", type=boolean_string, default=True)
parser.add_argument("--fp16", type=boolean_string, default=True)
parser.add_argument("--lr_schedule", type=str,
                    choices=['noam', 'noamwd', 'BERT', 'None'],
                    default='noam')
parser.add_argument("--loss_scale", type=float, default=0)
parser.add_argument("--no_token_id", type=boolean_string, default=True)

parser.add_argument("--output_dir", type=str)
parser.add_argument("--log_dir", type=str)
parser.add_argument('--pbar', type=boolean_string, default=True,
                    help='turn on progress bar')

# distributed
parser.add_argument('--local_rank', type=int, default=-1,
                    help='for torch.distributed')
parser.add_argument('--config', help='JSON config file')

# do normal parsing
args = parser.parse_args()

if args.config is not None:
    # override argparse defaults by config JSON
    with open(args.config) as config_file:
        opts = json.load(config_file)
    for k, v in opts.items():
        if isinstance(v, str):
            # PHILLY ENV special cases
            if 'PHILLY_JOB_DIRECTORY' in v:
                v = v.replace('PHILLY_JOB_DIRECTORY',
                              os.environ['PHILLY_JOB_DIRECTORY'])
            elif 'PHILLY_LOG_DIRECTORY' in v:
                v = v.replace('PHILLY_LOG_DIRECTORY',
                              os.environ['PHILLY_LOG_DIRECTORY'])
        setattr(args, k, v)

    # command line should override config JSON
    argv = sys.argv[1:]
    overrides, _ = parser.parse_known_args(argv)
    for k, v in vars(overrides).items():
        if f'--{k}' in argv:
            setattr(args, k, v)
    setattr(args, 'local_rank', overrides.local_rank)


assert args.train_batch_size % args.gradient_accumulation_steps == 0, \
    'batch size % gradient accumulation steps != 0!'
args.train_batch_size = (args.train_batch_size
                         // args.gradient_accumulation_steps)
logger.info('train batch size = {}, '
            'new train batch size (after gradient accumulation) = {}'.format(
                args.train_batch_size * args.gradient_accumulation_steps,
                args.train_batch_size))
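# Note on batch-size accounting (illustrative, using the defaults above):
# --train_batch_size 4 with --gradient_accumulation_steps 2 yields a per-GPU,
# per-forward-pass batch of 4 // 2 = 2; gradients accumulate over 2 passes,
# so each optimizer step still reflects 4 examples per GPU.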
if args.local_rank == -1:
    logger.info('CUDA available? {}'.format(str(torch.cuda.is_available())))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    args.device, args.n_gpu = device, n_gpu
else:
    # distributed training
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    # Initializes the distributed backend which will take care of
    # synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend='nccl')
    n_gpu = torch.distributed.get_world_size()
    args.device, args.n_gpu = device, 1
logger.info("device: {} n_gpu: {}, distributed training: {}, "
            "16-bits training: {}".format(
                device, n_gpu, bool(args.local_rank != -1), args.fp16))

np.random.seed(args.seed)
torch.random.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(args.seed)

timestamp = datetime.datetime.now().strftime('%Y-%m-%d%H%M%S')
output_dir = join(args.output_dir,
                  'GPT2.{}.{}.{}gpu.{}'.format(args.learning_rate,
                                               args.train_batch_size,
                                               n_gpu, timestamp))
log_dir = (args.log_dir
           if args.log_dir is not None and len(args.log_dir) > 0
           else output_dir)
if args.local_rank == -1 or get_rank() == 0:
    os.makedirs(output_dir, exist_ok=True)

logger.info('Input Argument Information')
args_dict = vars(args)
for a in args_dict:
    logger.info('%-28s  %s' % (a, args_dict[a]))


#########################################################################
# Prepare Data Set
##########################################################################
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)

config = GPT2Config.from_json_file(
    join(args.model_name_or_path, 'config.json'))

if args.local_rank == -1:
    train_dataloader = BucketingDataLoader(args.train_input_file,
                                           args.train_batch_size,
                                           args.max_seq_length)
else:
    train_dataloader = DistributedBucketingDataLoader(
        get_rank(), get_world_size(),
        args.train_input_file, args.train_batch_size,
        args.max_seq_length)

eval_dataloader_loss = DynamicBatchingLoader(
    args.eval_input_file, enc, args.normalize_data,
    args.eval_batch_size, args.max_seq_length)

eval_dataloader_gen = get_eval_list_same_length(
    args.eval_input_file, enc, args.eval_batch_size, True)


#########################################################################
# Prepare Model and Optimizer
##########################################################################
model = load_model(GPT2LMHeadModel(config), args.init_checkpoint,
                   args, verbose=True)
if args.local_rank != -1:
    # when training from scratch, make sure all ranks start from
    # identical weights
    params = [p.data for p in model.parameters()]
    all_reduce_and_rescale_tensors(
        params, float(torch.distributed.get_world_size()))

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
total_params = sum(np.prod(p.size()) for p in model_parameters)
logger.info('Number of parameters = {}'.format(total_params))

param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'ln']   # no decay for bias and LayerNorm (ln)
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}
]

if args.fp16:
    logger.info('in fp16, using FusedAdam')
    try:
        from apex.optimizers import FP16_Optimizer
        from apex.optimizers import FusedAdam
    except ImportError:
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex "
            "to use distributed and fp16 training.")

    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          bias_correction=False,
                          max_grad_norm=1.0)
    if args.loss_scale == 0:
        optimizer = FP16_Optimizer(optimizer,
                                   dynamic_loss_scale=True,
                                   verbose=False)
    else:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   verbose=False)
else:
    optimizer = Adam(optimizer_grouped_parameters, args.learning_rate,
                     max_grad_norm=1.0)
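# A minimal sketch of the 'noam' schedule selected by --lr_schedule.
# Assumption: set_lr in gpt2_training.train_utils follows the standard
# formula from "Attention Is All You Need"; its exact scaling may differ.
# This helper is illustrative only and is never called in this script.
def _noam_lr_sketch(base_lr, step, warmup_steps, d_model):
    # lr rises linearly for warmup_steps updates, then decays as step**-0.5,
    # with an overall 1/sqrt(d_model) scale
    step = max(step, 1)
    return base_lr * d_model ** -0.5 * min(step ** -0.5,
                                           step * warmup_steps ** -1.5)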
#########################################################################
# Training !
##########################################################################

if args.local_rank == -1 or get_rank() == 0:
    train_logger = open(join(log_dir, 'train_log.txt'), 'a+', buffering=1)
    eval_logger = open(join(log_dir, 'eval_log.txt'), 'a+', buffering=1)
    print('epoch,global_step,step,mean_loss,mean_ppl,n_token_real,'
          'n_token_total,epoch_time', file=train_logger)
    print('epoch,global_step,step,eval_loss,eval_ppl', file=eval_logger)

global_step = 0
step = 0
epoch = 0

if args.continue_from:
    global_step = args.continue_from
    # one optimizer step spans gradient_accumulation_steps batches
    step = global_step * args.gradient_accumulation_steps - 1

if args.local_rank != -1:
    n_gpu = 1
if args.local_rank == -1 or get_rank() == 0:
    if args.pbar:
        pbar = tqdm.tqdm(total=args.num_optim_steps, desc="training")
    else:
        pbar = None

while True:
    model.train()
    (tr_loss, tr_ppl, mean_ppl, nb_tr_examples,
     nb_tr_steps) = 0.0, 0.0, 0.0, 0, 0
    n_token_real, n_token_total = 0, 0
    train_start_time_epoch = time.time()
    for batch in train_dataloader:
        # activate new training mode
        seq_len = batch[0].shape[1]
        batch = tuple(t.to(device) for t in batch)
        input_ids, position_ids, token_ids, label_ids, *_ = batch
        if args.no_token_id:
            token_ids = None
        loss, ppl = model(input_ids, position_ids, token_ids, label_ids)

        if n_gpu > 1:
            loss = loss.mean()
            ppl = ppl.mean()
        # normalize the loss so gradients accumulated over several
        # forward passes match one full-batch update
        loss = loss / (args.train_batch_size / input_ids.shape[0])
        if args.fp16:
            optimizer.backward(loss)
        else:
            loss.backward()

        tr_loss += float(loss.item()) * (args.train_batch_size /
                                         input_ids.shape[0])
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1
        mean_loss = tr_loss / nb_tr_steps
        if ppl.item() < INF:
            tr_ppl += ppl.item()
        else:
            tr_ppl += mean_ppl
        mean_ppl = tr_ppl / nb_tr_steps

        n_token_total += input_ids.shape[0] * input_ids.shape[1]
        n_token_real += (input_ids != 0).sum().item()

        # gradient update
        step += 1
        if step % args.gradient_accumulation_steps == 0:
            set_lr(optimizer, global_step,
                   args.lr_schedule, args.learning_rate,
                   args.warmup_steps, args.warmup_proportion,
                   config.n_embd, args.num_optim_steps)
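            # In distributed mode, gradients are all-reduced (summed)
            # across ranks before the optimizer step; the rescale factor
            # of 1 leaves the summed gradients unscaled.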
            if args.local_rank != -1:
                grads = [p.grad.data for p in model.parameters()
                         if p.requires_grad and p.grad is not None]
                all_reduce_and_rescale_tensors(grads, float(1))

            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            # Print log info to file
            if args.local_rank != -1:
                mean_loss = sum(all_gather_list(mean_loss)) / get_world_size()
                mean_ppl = sum(all_gather_list(mean_ppl)) / get_world_size()
                n_token_real_all_proc = sum(all_gather_list(n_token_real))
                n_token_total_all_proc = sum(all_gather_list(n_token_total))
            else:
                n_token_real_all_proc = n_token_real
                n_token_total_all_proc = n_token_total

            if args.local_rank == -1 or get_rank() == 0:
                epoch_time = time.time() - train_start_time_epoch
                if pbar is not None:
                    pbar.set_postfix_str(
                        f"tok/s: {n_token_real_all_proc//epoch_time//1000}k "
                        f"ppl: {mean_ppl:.2f} epoch: {epoch}")
                    pbar.update(1)
                print('{},{},{},{},{},{},{},{}'.format(
                    epoch + 1, global_step + 1, step + 1, mean_loss, mean_ppl,
                    n_token_real_all_proc, n_token_total_all_proc,
                    epoch_time),
                    file=train_logger)

            if global_step % args.valid_step == 0:
                if args.local_rank == -1 or get_rank() == 0:
                    # only the rank 0 process evaluates
                    torch.save(
                        {k: (v.cpu() if v is not None else None)
                         # save to cpu tensors
                         for k, v in model.state_dict().items()},
                        join(output_dir,
                             f'GPT2-pretrain-step-{global_step}.pkl'))

                    eval_loss, eval_ppl = eval_model_loss(
                        model, enc, eval_dataloader_loss, epoch, args)
                    # generation-step evaluation is disabled for now
                    # gen_response = eval_model_generation(
                    #     model, enc, eval_dataloader_gen, epoch, args)

                    '''
                    # probably use beam search only for test set
                    if False:
                        gen_response_beam = eval_model_generation(
                            model, enc, eval_dataloader_gen, epoch, args,
                            use_beam_search=True, beam_width=3)
                    '''
                    print('{},{},{},{},{}'.format(
                        epoch + 1, global_step + 1, step + 1,
                        eval_loss, eval_ppl),
                        file=eval_logger)
                    logger.info('current learning rate: '
                                + str(optimizer.param_groups[0]['lr']))
                    model.train()
            if global_step >= args.num_optim_steps:
                break

        if (step + 1) % CACHE_EMPTY_STEP == 0:
            torch.cuda.empty_cache()

    if global_step >= args.num_optim_steps:
        break
    epoch += 1


if args.local_rank == -1 or get_rank() == 0:
    if pbar is not None:
        pbar.close()
    train_logger.close()
    eval_logger.close()
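# Example invocation (hypothetical script name and paths, for illustration
# only; adjust to your checkpoint, data files, and GPU setup):
#   python LSP_train.py \
#       --model_name_or_path ./models/small \
#       --init_checkpoint ./models/small/pytorch_model.bin \
#       --train_input_file ./data/train.128len.db \
#       --eval_input_file ./data/valid.tsv \
#       --output_dir ./output \
#       --train_batch_size 4 --gradient_accumulation_steps 2 \
#       --num_optim_steps 10000 --valid_step 1000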