import os

import datasets
from datasets import IterableDataset, disable_caching
from transformers import (DataCollatorForLanguageModeling, GPT2Config,
                          RobertaTokenizerFast, Trainer, TrainingArguments)

from conditional_gpt2_model import ConditionalGPT2LMHeadModel

# avoid writing intermediate copies of the streamed shards to the datasets cache
disable_caching()

# encoder that produced the cached hidden states; its tokenizer is reused below
ENCODER_MODEL_NAME = "entropy/roberta_zinc_480m"
TOKENIZER_MAX_LEN = 256

# each on-disk shard is split into this many contiguous sub-shards when streamed
DATA_SUBSHARDS = 10

# set these paths before running
DATA_DIR = None
TRAINER_SAVE_DIR = None

assert DATA_DIR is not None, "data directory must be specified"
assert TRAINER_SAVE_DIR is not None, "trainer save directory must be specified"
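
# DATA_DIR is expected to hold pre-tokenized datasets saved with
# Dataset.save_to_disk, each containing 'input_ids' and 'encoder_hidden_states'
# columns and a filename containing '.hf' (layout inferred from gen_dataset below).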


def gen_dataset():
    """Yield pre-tokenized examples from every .hf shard in DATA_DIR."""
    data_filenames = sorted([i for i in os.listdir(DATA_DIR) if '.hf' in i])

    for filename in data_filenames:
        dataset = datasets.Dataset.load_from_disk(f'{DATA_DIR}/{filename}')

        # keep only the decoder inputs and the pre-computed encoder states
        keep_cols = ['input_ids', 'encoder_hidden_states']
        dataset = dataset.remove_columns([i for i in dataset.column_names
                                          if i not in keep_cols]).with_format("torch")

        # split each file into contiguous sub-shards and stream them in order
        shards = [dataset.shard(num_shards=DATA_SUBSHARDS, index=index, contiguous=True)
                  for index in range(DATA_SUBSHARDS)]

        for shard in shards:
            for example in shard:
                # add a leading axis so each conditioning vector becomes a
                # length-1 sequence for the decoder's cross-attention
                example['encoder_hidden_states'] = example['encoder_hidden_states'][None, :]
                yield example

# wrap the generator so the Trainer can stream examples lazily
dataset = IterableDataset.from_generator(gen_dataset)
dataset = dataset.with_format("torch")
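
# Optional sanity check (a suggestion, assuming DATA_DIR is populated): pull a
# single example from the stream and confirm the expected shapes before a full
# run, e.g.
#   sample = next(iter(dataset))
#   sample['input_ids'].shape              # (sequence_length,)
#   sample['encoder_hidden_states'].shape  # (1, encoder_hidden_size)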

# the decoder reuses the encoder's tokenizer
tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME,
                                                 model_max_length=TOKENIZER_MAX_LEN)
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
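
# With mlm=False the collator pads each batch and copies input_ids into
# labels (padding positions masked to -100), i.e. a standard causal-LM objective.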

config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=TOKENIZER_MAX_LEN,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    n_layer=6,
    n_head=8,
    add_cross_attention=True,  # attend over the pre-computed encoder hidden states
)

model = ConditionalGPT2LMHeadModel(config)
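
# A rough size check; num_parameters() is a standard PreTrainedModel method.
print(f'decoder parameters: {model.num_parameters():,}')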

args = TrainingArguments(
    output_dir=TRAINER_SAVE_DIR,
    per_device_train_batch_size=192,
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    lr_scheduler_type="cosine",
    warmup_steps=1000,
    weight_decay=0.1,
    num_train_epochs=1,
    max_steps=50000,  # an IterableDataset has no length, so max_steps bounds the run
    logging_steps=25,
    save_steps=200,
    save_total_limit=30,
    fp16=True,
    push_to_hub=False,
)
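
# Effective batch size: 192 sequences x 8 accumulation steps = 1,536 sequences
# per optimizer update on each device.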

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=collator,
    train_dataset=dataset,
)

trainer.train()
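
# Persist the final weights alongside the intermediate checkpoints; the
# 'final_model' subdirectory name here is just a suggestion.
trainer.save_model(f'{TRAINER_SAVE_DIR}/final_model')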