import logging
import math
import os
import random
import re
import shutil
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange

from transformers.data.data_collator import DataCollator, default_data_collator
from transformers.file_utils import is_apex_available, is_torch_tpu_available
from transformers.modeling_utils import PreTrainedModel
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    EvalPrediction,
    PredictionOutput,
    TrainOutput,
    is_wandb_available,
)

from relogic.pretrainkit.training_args import TrainingArguments
from relogic.pretrainkit.trainer_utils import EvalPredictionWithSize, PredictionOutputWithSize

|
if is_apex_available():
    from apex import amp

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

|
try:
    from torch.utils.tensorboard import SummaryWriter

    _has_tensorboard = True
except ImportError:
    try:
        from tensorboardX import SummaryWriter

        _has_tensorboard = True
    except ImportError:
        _has_tensorboard = False


def is_tensorboard_available():
    return _has_tensorboard


if is_wandb_available():
    import wandb


logger = logging.getLogger(__name__)

|
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available

|
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager that makes all processes in distributed training wait for
    the local master to do something first.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()

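# Usage sketch (illustrative; assumes torch.distributed is already initialized
# and `load_and_cache_examples` is a hypothetical, expensive one-time setup):
# non-master ranks block on the first barrier while the local master runs the
# body, then the final barrier releases everyone together.
#
#   with torch_distributed_zero_first(args.local_rank):
#       dataset = load_and_cache_examples(args)
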
|
class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. will not hang
    for synchronization even if varied number of forward passes), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make the index list evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample the contiguous block of indices assigned to this rank
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

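# Worked example (hypothetical sizes): with len(dataset) == 10 and
# num_replicas == 4, num_samples == ceil(10 / 4) == 3 and total_size == 12,
# so the index list wraps around to [0..9, 0, 1]; rank 3 then yields [9, 0, 1],
# and the padded duplicates are truncated again after gathering (see
# `Trainer.distributed_concat` below).
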
|

def get_tpu_sampler(dataset: Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

|
class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for Transformers.
    """

    model: PreTrainedModel
    args: TrainingArguments
    data_collator: DataCollator
    train_dataset: Optional[Dataset]
    eval_dataset: Optional[Dataset]
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    tb_writer: Optional["SummaryWriter"] = None
    optimizers: Optional[Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]] = None
    global_step: Optional[int] = None
    epoch: Optional[float] = None
|
    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Optional[Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]] = None,
    ):
        """
        Trainer is a simple but feature-complete training and eval loop for PyTorch,
        optimized for Transformers.

        Args:
            prediction_loss_only:
                (Optional) in evaluation and prediction, only return the loss.
        """
        self.model = model.to(args.device)
        self.args = args
        self.data_collator = data_collator if data_collator is not None else default_data_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizers = optimizers
        if tb_writer is not None:
            self.tb_writer = tb_writer
        elif is_tensorboard_available() and self.is_world_master():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self._setup_wandb()
        else:
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` and see https://docs.wandb.com/huggingface."
            )
        set_seed(self.args.seed)
        # Create the output directory if needed
        if self.is_world_master():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config so it is saved along
            # with the weights.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                (
                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                    "with a `collate_batch` are deprecated and won't be supported in a future version."
                ),
                FutureWarning,
            )
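    # Minimal usage sketch (names and arguments are illustrative):
    #
    #   trainer = Trainer(
    #       model=model,
    #       args=TrainingArguments(output_dir="out"),  # plus task-specific args
    #       train_dataset=train_dataset,
    #       eval_dataset=eval_dataset,
    #   )
    #   trainer.train()
    #   metrics = trainer.evaluate()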
|
    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        if is_torch_tpu_available():
            train_sampler = get_tpu_sampler(self.train_dataset)
        else:
            train_sampler = (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

        data_loader = DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader
|
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
            )
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(eval_dataset)
        else:
            sampler = SequentialSampler(eval_dataset)

        data_loader = DataLoader(
            eval_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader
|
    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        # We use the same batch_size as for eval.
        if is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
            )
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(test_dataset)
        else:
            sampler = SequentialSampler(test_dataset)

        data_loader = DataLoader(
            test_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader
|
    def get_optimizers(
        self, num_training_steps: int
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well.
        If you want to use something else, you can pass a tuple in the Trainer's init,
        or override this method in a subclass.
        """
        if self.optimizers is not None:
            return self.optimizers
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if "relational_transformer" not in n and not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if "relational_transformer" in n and not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.args.weight_decay,
                "lr": 7e-5,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
        )
        return optimizer, scheduler
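    # Note on the grouping above: biases and LayerNorm weights are exempt from
    # weight decay in the usual Transformers fashion, and parameters whose name
    # contains "relational_transformer" are pinned to a hard-coded learning
    # rate of 7e-5 via the per-group "lr" key (torch optimizers let a group's
    # "lr" override the optimizer-level default, args.learning_rate here).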
|
    def _setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can override this method to customize the setup if needed. Find more
        information at https://docs.wandb.com/huggingface. You can also override
        the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to
                "false" to disable gradient logging or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to
                store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        if self.is_world_master():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                wandb.watch(
                    self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
                )

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get the number of examples from a DataLoader, by accessing its Dataset.
        """
        return len(dataloader.dataset)
|
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to the model if the model to train has been
                instantiated from a local path. If present, we will try reloading
                the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
|
        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to the global_step of the last saved checkpoint from model_path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=not self.is_local_master() or not self.args.logging_tqdm,
        )
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = tqdm(
                    parallel_loader, desc="Iteration", disable=not self.is_local_master() or not self.args.logging_tqdm
                )
            else:
                epoch_iterator = tqdm(
                    train_dataloader, desc="Iteration", disable=not self.is_local_master() or not self.args.logging_tqdm
                )

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(model, inputs, optimizer)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else scheduler.get_lr()[0]
                        )
                        logging_loss = tr_loss

                        self._log(logs)

                    if self.args.eval_steps > 0 and self.global_step % self.args.eval_steps == 0:
                        if self.args.evaluate_during_training:
                            self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always
                        # a reference to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.global_step, tr_loss / self.global_step)
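    # Resuming sketch (path is illustrative): pointing `model_path` at a saved
    # checkpoint directory reloads optimizer/scheduler state, parses the step
    # from the directory name, and skips the already-trained steps:
    #
    #   trainer.train(model_path="out/checkpoint-500")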
|
    def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_master():
                wandb.log(logs, step=self.global_step)
        output = {**logs, **{"step": self.global_step}}
        if iterator is not None:
            iterator.write(output)
        else:
            logger.info(output)
|
    def _training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
    ) -> float:
        model.train()
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        outputs = model(**inputs)
        loss = outputs[0]  # model outputs are always a tuple in transformers (see doc)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()
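    # Accumulation sketch: with gradient_accumulation_steps == 4, each call
    # backpropagates loss / 4 and returns that scaled value, so the gradients
    # summed over four calls approximate one large-batch step, and adding the
    # four return values into `tr_loss` approximates the full-batch loss.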
|
    def is_local_master(self) -> bool:
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        This will be True only in one process, even in distributed mode,
        even when training on multiple machines.
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
|
    def save_model(self, output_dir: Optional[str] = None):
        """
        Saving best-practices: if you use default names for the model,
        you can reload it using from_pretrained().

        Will only save from the world_master process (unless in TPUs).
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_master():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self.model.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        self.model.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        return checkpoints_sorted
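    # Example: "checkpoint-1000" sorts after "checkpoint-500" because the step
    # is parsed as an int above; a plain lexicographic sort of the paths would
    # order them the other way around.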
|
    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)
|
    def evaluate(
        self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute
        metrics, as they are task-dependent.

        Args:
            eval_dataset: (Optional) Pass a dataset if you wish to override
                the one on the instance.
        Returns:
            A dict containing:
                - the eval loss
                - the potential metrics computed from the predictions
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self._prediction_loop(
            eval_dataloader, description="Evaluation", prediction_loss_only=prediction_loss_only
        )

        self._log(output.metrics)

        if self.args.tpu_metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics
|
    def predict(self, test_dataset: Dataset) -> PredictionOutputWithSize:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in evaluate().
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self._prediction_loop(test_dataloader, description="Prediction")
|
    def _prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutputWithSize:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with and without labels.

        NOTE: the stock loop assumes that predictions and labels from different
        batches share the same sequence length, which does not hold for our
        application. To stay general, predictions and labels are flattened and
        their per-example sizes are tracked alongside them.
        """

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyway.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: Optional[torch.Tensor] = None
        preds_size: Optional[torch.Tensor] = None
        label_ids: Optional[torch.Tensor] = None
        label_size: Optional[torch.Tensor] = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if not prediction_loss_only:
                # Flatten the logits and remember each example's length so that
                # batches with different sequence lengths can be concatenated.
                if preds is None:
                    preds = logits.detach()
                    preds_size = preds.new_full(size=preds.size()[:1], fill_value=preds.size(1)).detach()
                    preds = preds.view(-1)
                else:
                    preds_size = torch.cat(
                        (preds_size, logits.new_full(size=logits.size()[:1], fill_value=logits.size(1)).detach()),
                        dim=0,
                    )
                    preds = torch.cat((preds, logits.detach().view(-1)), dim=0)

                if inputs.get("labels") is not None:
                    if label_ids is None:
                        label_ids = inputs["labels"].detach()
                        label_size = label_ids.new_full(size=label_ids.size()[:1], fill_value=label_ids.size(1)).detach()
                        label_ids = label_ids.view(-1)
                    else:
                        label_size = torch.cat(
                            (
                                label_size,
                                inputs["labels"]
                                .new_full(size=inputs["labels"].size()[:1], fill_value=inputs["labels"].size(1))
                                .detach(),
                            ),
                            dim=0,
                        )
                        label_ids = torch.cat((label_ids, inputs["labels"].detach().view(-1)), dim=0)

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds, preds_size = self.distributed_concat_with_size(
                    preds, preds_size, num_total_examples=self.num_examples(dataloader)
                )
            if label_ids is not None:
                label_ids, label_size = self.distributed_concat_with_size(
                    label_ids, label_size, num_total_examples=self.num_examples(dataloader)
                )
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
            preds_size = preds_size.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()
            label_size = label_size.cpu().numpy()
        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(
                EvalPredictionWithSize(
                    predictions=preds, predictions_size=preds_size, label_ids=label_ids, label_size=label_size
                )
            )
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutputWithSize(
            predictions=preds, predictions_size=preds_size, label_ids=label_ids, label_size=label_size, metrics=metrics
        )
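    # Consuming the output (sketch; assumes 2-D logits, i.e. one value per
    # position, so that predictions_size[i] is example i's flattened length):
    #
    #   output = trainer.predict(test_dataset)
    #   offset = 0
    #   for size in output.predictions_size:
    #       example_preds = output.predictions[offset : offset + int(size)]
    #       offset += int(size)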
|
    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        output = concat[:num_total_examples]
        return output

    def distributed_concat_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)
        return concat
|
    def distributed_concat_varsize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        assert self.args.local_rank != -1

        # Gather every rank's length, then pad each tensor to the global maximum
        # so that all_gather (which requires equal shapes) can be used.
        sizes = self.distributed_concat_tensor(tensor.new_full(size=(1,), fill_value=tensor.size(0)))
        max_size = sizes.max().item()

        padded = tensor.new_zeros(max_size)
        padded[: tensor.size(0)] = tensor

        # Gather the padded tensors, then strip the padding off each rank's slice.
        padded_agg = self.distributed_concat_tensor(padded)
        slices = []
        for i, size in enumerate(sizes):
            start_idx = i * max_size
            end_idx = start_idx + size.item()
            slices.append(padded_agg[start_idx:end_idx])
        ret = torch.cat(slices, dim=0)
        return ret
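    # Worked example (hypothetical sizes): with world_size == 2, rank 0 holding
    # a length-3 tensor and rank 1 a length-5 tensor, max_size == 5; each rank
    # pads to 5, all_gather produces a length-10 buffer, and slices [0:3] and
    # [5:10] are concatenated into the true length-8 result.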
|
    def distributed_concat_with_size(
        self, tensor: torch.Tensor, size: torch.Tensor, num_total_examples: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert self.args.local_rank != -1

        # Both the flattened values and the per-example sizes vary in length
        # across ranks, so each is gathered with the variable-size helper.
        # NOTE: unlike `distributed_concat`, this helper does not truncate the
        # extra examples added by SequentialDistributedSampler;
        # `num_total_examples` is currently unused.
        concat_sizes = self.distributed_concat_varsize_tensor(size)
        concat = self.distributed_concat_varsize_tensor(tensor)

        # Sanity check: the per-example sizes must account for every value.
        assert concat_sizes.sum() == concat.size(0)
        return concat, concat_sizes
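    # Usage sketch (taken from `_prediction_loop` above): each rank passes its
    # flattened predictions plus their per-example sizes and receives the
    # world-wide flattened tensor together with the world-wide size vector:
    #
    #   preds, preds_size = self.distributed_concat_with_size(
    #       preds, preds_size, num_total_examples=self.num_examples(dataloader)
    #   )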
|