from transformers import (
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)
from datasets import DatasetDict

from src.MLM.datasets.preprocess_dataset import preprocess_dataset
from src.MLM.training_scripts.utils import get_new_model_name

def train_with_trainer(
    model_checkpoint: str,
    tokenizer: PreTrainedTokenizerBase,
    dataset: DatasetDict,
    model_name: str | None = None,
    data_collator=None,
    num_epochs: int = 3,
) -> Trainer:
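    """Fine-tune a masked-language model with the Hugging Face `Trainer`.

    Loads `model_checkpoint` with an MLM head, preprocesses `dataset` (which
    must contain "train" and "val" splits), and trains for `num_epochs`,
    pushing checkpoints to the Hub and logging to Weights & Biases.
    """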
    model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

    # Derive the output/run name from the checkpoint unless one was supplied.
    model_name = get_new_model_name(model_checkpoint=model_checkpoint, model_name=model_name)

    # Tokenize and prepare the dataset for masked-language modelling.
    dataset = preprocess_dataset(dataset=dataset, tokenizer=tokenizer)

    # Fall back to the standard MLM collator, masking 15% of tokens (the BERT default).
    if data_collator is None:
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

    # `push_to_hub=True` assumes a valid Hugging Face token (`huggingface-cli login`);
    # `report_to="wandb"` likewise assumes `wandb login` has been run.
    training_args = TrainingArguments(
        output_dir=model_name,
        evaluation_strategy="epoch",  # renamed to `eval_strategy` in newer transformers releases
        learning_rate=2e-5,
        weight_decay=0.01,
        push_to_hub=True,
        report_to="wandb",
        run_name=model_name,
        num_train_epochs=num_epochs,
        save_total_limit=1,
        save_strategy="epoch",
    )

    # The Trainer picks the device automatically; log it for sanity.
    print(f"device: {training_args.device}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["val"],
        data_collator=data_collator,
        tokenizer=tokenizer,  # so the tokenizer is saved with checkpoints and pushed to the Hub
    )

    trainer.train()

    return trainer
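
# Example usage (a minimal sketch, not part of this module): the checkpoint,
# dataset, and split names below are assumptions for illustration only.
#
#     from datasets import load_dataset
#     from transformers import AutoTokenizer
#
#     checkpoint = "distilroberta-base"
#     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#     raw = load_dataset("imdb")
#     dataset = DatasetDict({"train": raw["train"], "val": raw["test"]})
#     trainer = train_with_trainer(checkpoint, tokenizer, dataset, num_epochs=3)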