| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import sys |
| | import logging |
| | from functools import partial |
| |
|
| | from demo_utils import download_model_folder |
| | import argparse |
| | import subprocess as sp |
| |
|
| |
|
# Absolute directory containing this script; every other path is anchored here
# so the script behaves the same regardless of the caller's working directory.
PROJECT_FOLDER = os.path.dirname(os.path.realpath(__file__))
# Interpreter used to launch prepro.py / LSP_train.py. Reuse the interpreter that
# is running this script (same venv, same installed packages) instead of whatever
# `python` happens to resolve to on PATH; fall back to 'python' if the running
# interpreter's path cannot be determined (sys.executable may be empty/None).
PYTHON_EXE = sys.executable or 'python'
# Destination for downloaded checkpoints and training output.
MODEL_FOLDER = os.path.join(PROJECT_FOLDER, 'models')
# Location of the training corpus (train.tsv) and the dummy-data script.
DATA_FOLDER = os.path.join(PROJECT_FOLDER, 'data')

print(f'PROJECT_FOLDER = {PROJECT_FOLDER}')
| |
|
# Command-line options for the demo driver.
#   --data: which corpus to prepare before training ('dummy', 'small' or 'full').
# `choices` makes argparse reject any other value with a usage error and a
# non-zero exit. The previous `assert` check was stripped under `python -O`,
# which let invalid values through.
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='dummy',
                    choices=['dummy', 'small', 'full'],
                    help='choose from dummy, small and full')
dargs = parser.parse_args()
| |
|
| |
|
# Root logging configuration for the whole script: INFO and above, each record
# prefixed with a timestamp, level and logger name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
# Module-scoped logger used for the progress messages below.
logger = logging.getLogger(name=__name__)
| |
|
| |
|
# Ensure the models directory exists. An existing directory is reused (with a
# notice); otherwise a fresh one is created. `exist_ok` mirrors whether the
# folder was already present when checked.
models_already_present = os.path.exists(MODEL_FOLDER)
if models_already_present:
    print(f'Found existing models folder at {MODEL_FOLDER}, skip creating a new one!')
os.makedirs(MODEL_FOLDER, exist_ok=models_already_present)
| |
|
| | |
| | |
| | |
logger.info('Downloading models...')
# download_model_folder (from demo_utils) takes its destination through a keyword
# named DATA_FOLDER; here that keyword is bound to MODEL_FOLDER. Presumably this
# is the download destination — confirm against demo_utils.
download_model = partial(download_model_folder, DATA_FOLDER=MODEL_FOLDER)
model_spec = {'model_size': 'small', 'dataset': 'multiref', 'from_scratch': False}
# target_folder is used later as --model_name_or_path and as the directory
# holding pytorch_model.bin.
target_folder = download_model(**model_spec)
logger.info('Done!\n')
| |
|
| |
|
| | |
| | |
| | |
logger.info('Downloading and Extracting Data...')
# Choose the command that materializes the requested corpus. Commands are passed
# as argument lists (no shell). Their combined output is captured and only shown
# if the command fails, in which case the script stops with the same exit code
# instead of continuing on missing data.
if dargs.data == 'dummy':
    # prepare4db.sh is executed from inside DATA_FOLDER.
    cmd = ['bash', 'prepare4db.sh']
    cmd_cwd = DATA_FOLDER
    cmd_env = None  # inherit the current environment unchanged
elif dargs.data in ('small', 'full'):
    # Equivalent of `cd reddit_extractor; SIZE=<small|full> make -j 8`. It is
    # anchored at PROJECT_FOLDER rather than the directory the script was
    # launched from, and SIZE is passed through the child's environment.
    cmd = ['make', '-j', '8']
    cmd_cwd = os.path.join(PROJECT_FOLDER, 'reddit_extractor')
    cmd_env = {**os.environ, 'SIZE': dargs.data}
else:
    # Defensive only: argument parsing restricts --data to the values above.
    raise ValueError('you need to implement your own data type, or use either dummy, small, or full')
ret = sp.run(cmd, stdout=sp.PIPE, stderr=sp.STDOUT, cwd=cmd_cwd, env=cmd_env)
if ret.returncode != 0:
    print(f'error occurred, {ret.stdout.decode(errors="replace")}')
    sys.exit(ret.returncode)
| |
|
logger.info('Preparing Data...')
data_path = os.path.join(DATA_FOLDER, 'train.tsv')
# Maximum sequence length used for preprocessing (and reused for training).
MAX_LEN = 128
# Expected prepro.py output: the corpus path with its extension replaced by
# '.<MAX_LEN>len.db'. The isdir() check assumes prepro.py creates a directory
# at this path — confirm against prepro.py.
data_db = f'{os.path.splitext(data_path)[0]}.{MAX_LEN}len.db'
if os.path.isdir(data_db):
    print(f'{data_db} exists, skip prepro.py')
else:
    # Kept as a list end to end so paths containing spaces stay single arguments.
    cmd = ['prepro.py', '--corpus', data_path, '--max_seq_len', str(MAX_LEN)]
    print(' '.join(cmd))
    ret = sp.run([PYTHON_EXE] + cmd, stdout=sp.PIPE, stderr=sp.STDOUT, cwd=PROJECT_FOLDER)
    if ret.returncode != 0:
        # Decode so the captured output is readable rather than a b'...' repr.
        print(f'error occurred, {ret.stdout.decode(errors="replace")}')
        sys.exit(ret.returncode)
logger.info('Done!\n')
| |
|
| | |
| | |
| | |
import shlex  # only needed here, to print a copy-pasteable command line

logger.info('Generating training CMD!')
logger.info('If there is any problem, please copy (modify) and run command below')
logger.info('#########################################################################')
# Arguments for LSP_train.py. They stay a list all the way to Popen, so paths
# containing spaces (target_folder, data_db, ...) are never split apart.
args = [
    '--model_name_or_path', target_folder,
    '--init_checkpoint', os.path.join(target_folder, 'pytorch_model.bin'),
    '--train_input_file', data_db,
    # Relative path, resolved against PROJECT_FOLDER (the trainer's cwd below).
    # NOTE(review): evaluation always uses the dummy set, whatever --data was.
    '--eval_input_file', './data/dummy_data.tsv',
    '--output_dir', os.path.join(MODEL_FOLDER, 'output_model'),
    '--seed', '42',
    # Same limit the training corpus was preprocessed with.
    '--max_seq_length', str(MAX_LEN),
    '--train_batch_size', '512',
    '--gradient_accumulation_steps', '8',
    '--eval_batch_size', '64',
    '--learning_rate', '1e-5',
    '--num_optim_steps', '10000',
    '--valid_step', '5000',
    '--warmup_steps', '4000',
    '--normalize_data', 'true',
    '--fp16', 'true',
    '--lr_schedule', 'noam',
    '--loss_scale', '0.0',
    '--no_token_id', 'true',
    '--pbar', 'true',
]
train_cmd = [PYTHON_EXE, 'LSP_train.py'] + args
# Shell-quote each part so the printed line can be copied and run verbatim.
print(' '.join(shlex.quote(part) for part in train_cmd))
logger.info('#########################################################################')
# Tee the trainer's combined stdout/stderr to the console and to ./output.log
# (relative to the directory this script was launched from). Popen's context
# manager closes the pipe and reaps the child.
console_encoding = sys.stdout.encoding or 'utf-8'
with open('./output.log', 'wb') as f:
    with sp.Popen(train_cmd, stdout=sp.PIPE, stderr=sp.STDOUT, cwd=PROJECT_FOLDER) as process:
        for line in iter(process.stdout.readline, b''):
            # errors='replace' keeps the tee alive on undecodable bytes; the log
            # file still receives the raw bytes.
            sys.stdout.write(line.decode(console_encoding, errors='replace'))
            f.write(line)
        returncode = process.wait()
if returncode != 0:
    logger.error(f'LSP_train.py exited with code {returncode}; see ./output.log')
    sys.exit(returncode)
logger.info('Done!\n')
| |
|