import argparse
import logging
from torch.utils.data import IterableDataset
import gzip
import json
from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
import sys
from datetime import datetime
import torch
import random
from shutil import copyfile
import os
import wandb
import re


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="google/t5-v1_1-base")
parser.add_argument("--train_file", required=True)
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--batch_size", default=16, type=int)
parser.add_argument("--max_source_length", default=384, type=int)
parser.add_argument("--max_target_length", default=64, type=int)
parser.add_argument("--name", required=True)
parser.add_argument("--train_size", default=100*1000*1000, type=int)
parser.add_argument("--eval_size", default=10000, type=int)
parser.add_argument("--fp16", default=False, action='store_true')
parser.add_argument("--no_prefix", default=False, action='store_true')
args = parser.parse_args()
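
# Example invocation (script name and data path are hypothetical):
#   python train_script.py --train_file data_config.json --name test-run --fp16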

wandb.init(project="doc2query", name=f"{args.name}-{args.model_name}")

class PairDataset:
    """Streams (target, input) text pairs from a gzipped JSONL file.

    Examples seen during the first pass over the file are cached so the
    dataset can be iterated indefinitely, reshuffling on every later pass.
    """

    def __init__(self, filepath):
        self.filepath = filepath
        self.examples = []

    def __iter__(self):
        with gzip.open(self.filepath, 'rt') as fIn:
            for line in fIn:
                example = self.get_example(json.loads(line))
                if example is not None:
                    self.examples.append(example)
                    yield example

        # File exhausted: keep yielding from the in-memory cache forever
        while True:
            random.shuffle(self.examples)
            for ex in self.examples:
                yield ex

    def get_example(self, raw_example):
        if isinstance(raw_example, dict):
            if 'set' in raw_example:
                example = random.sample(raw_example['set'], 2)
            elif 'query' in raw_example:
                example = [raw_example['query'], random.choice(raw_example['pos'])]
            else:
                raise ValueError("Unknown format: " + str(raw_example))
        else:
            example = [raw_example[0], raw_example[1]]

        return example
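
# Each line of a pair file is JSON; get_example accepts three layouts
# (values below are illustrative, not from a real file):
#   {"set": ["text a", "text b", ...]}       -> two random members of the set
#   {"query": "...", "pos": ["...", ...]}    -> [query, random positive passage]
#   ["target text", "source text"]           -> taken as-is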


class RedditTitleDataset(PairDataset):
    """Pairs a cleaned Reddit post title with the post body."""

    def get_example(self, raw_example):
        return [self.clean_title(raw_example['title']), raw_example['body']]

    def clean_title(self, text):
        text = text.replace("&amp;", "&").strip()

        # Strip leading tags like "[Serious]"
        if text.startswith("["):
            text = re.sub(r"^\[[a-zA-Z0-9]+\]", "", text).strip()

        # Strip trailing tags like "[x.y]"
        if text.endswith("]"):
            text = re.sub(r"\[[a-zA-Z0-9\.]+\]$", "", text).strip()

        # Strip leading subreddit references like "/r/AskReddit: "
        if text.startswith("/r"):
            text = re.sub(r"^/[a-zA-Z0-9/]+[;,: \-]+", "", text).strip()

        return text
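
    # For example (hypothetical title):
    #   clean_title("[Serious] /r/AskReddit: Best advice? [x.y]")
    # drops the leading "[Serious]", the trailing "[x.y]" and the
    # "/r/AskReddit: " prefix, returning "Best advice?".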


class StackExchangeTitleBodyDataset(PairDataset):
    """Pairs are stored directly as [title, body] under the 'texts' key."""

    def get_example(self, raw_example):
        return raw_example['texts']

class MultiDataset(IterableDataset):
    """Samples pairs from several PairDatasets according to configured weights."""

    def __init__(self, train_config_path, num_samples):
        self.num_samples = num_samples

        with open(train_config_path) as fIn:
            train_config = json.load(fIn)

        self.categories = []
        self.files = {}
        self.file2dataset = {}
        self.file2datasetIter = {}

        for prefix in train_config:
            # Repeat each category / file proportionally to its weight so
            # that uniform random.choice gives weighted sampling
            self.categories.extend([prefix] * train_config[prefix]['weight'])
            self.files[prefix] = []

            for filename, weight in train_config[prefix]['files'].items():
                self.files[prefix].extend([filename] * weight)
                dataset = self.open_dataset(filename)
                self.file2dataset[filename] = dataset
                self.file2datasetIter[filename] = iter(dataset)

            random.shuffle(self.files[prefix])

        random.shuffle(self.categories)
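
    # The train config maps a task prefix to a sampling weight plus weighted
    # files, e.g. (hypothetical filenames):
    #   {
    #     "msmarco": {"weight": 2, "files": {"msmarco.jsonl.gz": 1}},
    #     "reddit":  {"weight": 1, "files": {"reddit_title_text.jsonl.gz": 1}}
    #   }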

    def open_dataset(self, filepath):
        # Route to the matching dataset class based on the filename
        if 'reddit_title_text' in filepath:
            dataset = RedditTitleDataset(filepath)
        elif 'stackexchange_archive/jsonl' in filepath:
            dataset = StackExchangeTitleBodyDataset(filepath)
        else:
            dataset = PairDataset(filepath)
        return dataset

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        while True:
            category = random.choice(self.categories)
            filepath = random.choice(self.files[category])
            dataset = self.file2datasetIter[filepath]
            pair = next(dataset)

            if not args.no_prefix:
                # Copy the pair before prepending the task prefix, so that
                # cached examples are not mutated (and re-prefixed) in place
                pair = [pair[0], category + ": " + pair[1].strip()]
            yield pair

    def delete_examples_cache(self):
        # Free the in-memory caches of the underlying datasets
        for dataset in self.file2dataset.values():
            dataset.examples = []
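
# Usage sketch (hypothetical config path):
#   ds = MultiDataset("data_config.json", num_samples=1000)
#   target, source = next(iter(ds))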


def main():
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    save_steps = 5000

    output_dir = 'output/' + args.name + '-' + args.model_name.replace("/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print("Output dir:", output_dir)

    os.makedirs(output_dir, exist_ok=True)

    # Save the data config and this script alongside the model for reproducibility
    copyfile(args.train_file, os.path.join(output_dir, 'data_config.json'))
    train_script_path = os.path.join(output_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        fp16=args.fp16,
        fp16_backend="amp",
        per_device_train_batch_size=args.batch_size,
        evaluation_strategy="steps",
        save_steps=save_steps,
        logging_steps=100,
        eval_steps=save_steps,
        warmup_steps=1000,
        save_total_limit=1,
        num_train_epochs=args.epochs,
        report_to="wandb",
    )
    # Draw the eval set from the (infinite) train stream, then clear the
    # caches so the eval pairs are not re-yielded during training
    train_dataset = MultiDataset(args.train_file, args.train_size)
    train_dataset_iter = iter(train_dataset)
    eval_dataset = [next(train_dataset_iter) for _ in range(args.eval_size)]
    train_dataset.delete_examples_cache()

    # Print a few examples as a sanity check
    for i in range(50):
        print("Target:", eval_dataset[i][0])
        print("Input:", eval_dataset[i][1])
        print("\n\n===================\n\n")

    print("Train dataset len:", len(train_dataset))

    def data_collator(examples):
        targets = [row[0] for row in examples]
        inputs = [row[1] for row in examples]
        label_pad_token_id = -100  # ignored by the cross-entropy loss

        model_inputs = tokenizer(
            inputs,
            max_length=args.max_source_length,
            padding=True,
            truncation=True,
            return_tensors='pt',
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )

        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                targets,
                max_length=args.max_target_length,
                padding=True,
                truncation=True,
                pad_to_multiple_of=8 if training_args.fp16 else None,
            )

        # Replace padding token ids in the labels by -100 so they are
        # skipped when computing the loss
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label]
            for label in labels["input_ids"]
        ]

        model_inputs["labels"] = torch.tensor(labels["input_ids"])
        return model_inputs
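
    # Given a batch like [["how to bake bread", "Bread is baked by ..."]]
    # (hypothetical pair), the collator returns a dict with 'input_ids',
    # 'attention_mask' and 'labels' tensors, where padded label positions
    # hold -100 so the loss ignores them.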

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()