| | |
| | """ |
| | Builder for the training args and trainer |
| | """ |
| |
|
| | import abc |
| | import importlib |
| | import importlib.util |
| | import logging |
| | import math |
| | import sys |
| | from abc import abstractmethod |
| | from collections import defaultdict |
| | from dataclasses import dataclass, field |
| | from functools import wraps |
| | from pathlib import Path |
| | from typing import Dict, List, Literal, Optional, Type, Union |
| |
|
| | import torch |
| | import transformers |
| | from datasets import Dataset |
| | from torch.optim.lr_scheduler import OneCycleLR |
| | from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler |
| | from transformers import ( |
| | EarlyStoppingCallback, |
| | PreTrainedModel, |
| | Trainer, |
| | TrainerCallback, |
| | TrainingArguments, |
| | ) |
| | from transformers.trainer_utils import seed_worker |
| | from transformers.utils import is_sagemaker_mp_enabled |
| | from trl import DPOTrainer, ORPOConfig, ORPOTrainer |
| | from trl.trainer.utils import pad_to_length |
| |
|
| | from axolotl.loraplus import create_loraplus_optimizer |
| | from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES |
| | from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler |
| | from axolotl.utils import is_mlflow_available |
| | from axolotl.utils.callbacks import ( |
| | EvalFirstStepCallback, |
| | GPUStatsCallback, |
| | LossWatchDogCallback, |
| | SaveAxolotlConfigtoWandBCallback, |
| | SaveBetterTransformerModelCallback, |
| | SaveModelOnTrainEndCallback, |
| | bench_eval_callback_factory, |
| | causal_lm_bench_eval_callback_factory, |
| | log_prediction_callback_factory, |
| | ) |
| | from axolotl.utils.callbacks.lisa import lisa_callback_factory |
| | from axolotl.utils.collators import ( |
| | BatchSamplerDataCollatorForSeq2Seq, |
| | DataCollatorForSeq2Seq, |
| | MambaDataCollator, |
| | V2BatchSamplerDataCollatorForSeq2Seq, |
| | ) |
| | from axolotl.utils.models import ensure_dtype |
| | from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths |
| | from axolotl.utils.schedulers import ( |
| | get_cosine_schedule_with_min_lr, |
| | get_cosine_schedule_with_quadratic_warmup, |
| | get_cosine_schedule_with_warmup_decay_constant, |
| | ) |
| |
|
# The SageMaker model-parallel package is only imported when transformers
# reports that SageMaker model parallelism is enabled.
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

# Best-effort import: a missing torch._dynamo module is silently ignored.
try:
    import torch._dynamo
except ImportError:
    pass

# Module-level logger used by the trainers and builders in this file.
LOG = logging.getLogger("axolotl.core.trainer_builder")
| |
|
| |
|
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
    """
    Merge ``tag_names`` into the ``tags`` entry of ``kwargs``.

    Args:
        tag_names: a single tag (``str``) or an iterable of tags. It is never
            mutated; a private copy is used internally.
        kwargs: keyword-argument dict to update in place, or ``None``.

    Returns:
        The same ``kwargs`` object (or ``None`` when ``kwargs`` is ``None``).
        Behavior depending on the existing ``tags`` entry:
        - missing: set to a new list holding ``tag_names``;
        - list: extended in place with ``tag_names``;
        - str: replaced by ``tag_names`` followed by that string;
        - anything else: left untouched.
    """
    # Copy so that the caller's list (typically a class-level ``tag_names``
    # attribute shared by all instances) is neither mutated nor aliased into
    # ``kwargs``, which would otherwise make tags accumulate across calls.
    tag_names = [tag_names] if isinstance(tag_names, str) else list(tag_names)

    if kwargs is not None:
        if "tags" not in kwargs:
            kwargs["tags"] = tag_names
        elif isinstance(kwargs["tags"], list):
            kwargs["tags"].extend(tag_names)
        elif isinstance(kwargs["tags"], str):
            kwargs["tags"] = [*tag_names, kwargs["tags"]]

    return kwargs
| |
|
| |
|
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
    """
    Extend the base TrainingArguments for axolotl helpers

    Every field below is an additional option on top of the base
    ``TrainingArguments``; the ``help`` metadata describes each one.
    """

    model_type: Optional[str] = field(
        default=None, metadata={"help": "HF model configuration model_type."}
    )
    lr_quadratic_warmup: bool = field(
        default=False,
        metadata={"help": "Use quadratic warmup for cosine scheduling."},
    )
    pretraining: bool = field(
        default=False,
        metadata={
            "help": "Indicates to trainer whether we are doing continued pretraining."
        },
    )
    sample_packing: bool = field(
        default=False,
        metadata={"help": "Use sample packing for efficient training."},
    )
    multipack_real_batches: bool = field(
        default=False,
        metadata={"help": "Use real batches for efficient training."},
    )
    eval_sample_packing: Optional[bool] = field(
        default=None,
        metadata={"help": "Use sample packing for efficient evals."},
    )
    sample_packing_efficiency: float = field(
        default=1.0,
        metadata={"help": "Sample packing efficiency for calculating batch length."},
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "The maximum sequence length the model can handle"},
    )
    sample_packing_seq_len_multiplier: int = field(
        default=1,
        metadata={"help": "the multiplier for the max len for packed sequences"},
    )
    relora_steps: Optional[int] = field(
        default=None,
        metadata={"help": "how often to reset for ReLoRA"},
    )
    relora_warmup_steps: Optional[int] = field(
        default=None,
        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
    )
    # The help text used to be a verbatim copy of the warmup description.
    relora_anneal_steps: Optional[int] = field(
        default=None,
        metadata={"help": "how many anneal steps to pass to the ReLoRA scheduler"},
    )
    relora_prune_ratio: Optional[float] = field(
        default=0.9,
        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
    )
    bench_split: Optional[str] = field(
        default="eval", metadata={"help": "The benchmark split to run on"}
    )
    bench_dataset: Optional[str] = field(
        default="pharaouk/dharma-1/dharma_1_mini.json",
        metadata={
            "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
        },
    )
    do_bench_eval: Optional[bool] = field(
        default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
    )
    do_causal_lm_eval: Optional[bool] = field(
        default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
    )
    max_bench_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
        },
    )
    bench_source_max_len: int = field(
        default=2048, metadata={"help": "Maximum source sequence length for bench."}
    )
    dataloader_prefetch_factor: Optional[int] = field(
        default=None,
        metadata={"help": "prefetch_factor argument to the dataloader"},
    )
    cosine_min_lr_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
    )
    cosine_constant_lr_ratio: Optional[float] = field(
        default=None,
        metadata={
            "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
        },
    )
    loraplus_lr_ratio: Optional[float] = field(
        default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
    )
    loraplus_lr_embedding: Optional[float] = field(
        default=1e-6,
        metadata={"help": "loraplus learning rate for lora embedding layers."},
    )
    qlora: bool = field(
        default=False,
        metadata={"help": "whether this is a qlora training"},
    )
    orpo_alpha: Optional[float] = field(
        default=None,
        metadata={
            "help": "weight of the odds-ratio term in the ORPO loss; ORPO is enabled when set"
        },
    )
    lisa_n_layers: Optional[int] = field(
        default=None,
        metadata={"help": "the number of active layers in LISA"},
    )
    lisa_step_interval: Optional[int] = field(
        default=None,
        metadata={"help": "how often to switch layers in LISA"},
    )
    lisa_layers_attribute: Optional[str] = field(
        default=None,
        metadata={"help": "path under the model to access the layers"},
    )
    curriculum_sampling: Optional[bool] = field(
        default=None,
        metadata={"help": "whether to use sequential sampling for curriculum learning"},
    )
| |
|
| |
|
class AxolotlTrainer(Trainer):
    """
    Extend the base Trainer for axolotl helpers
    """

    # Class-level default. NOTE(review): presumably replaced per instance by
    # the base Trainer during construction — confirm.
    args = None
    # Tags merged into the ``tags`` kwarg by ``push_to_hub``.
    tag_names = ["axolotl"]
| |
|
def __init__(
    self,
    *_args,
    num_epochs=1,
    bench_data_collator=None,
    eval_data_collator=None,
    **kwargs,
):
    """
    Store the axolotl-specific options, then initialise the base Trainer.

    Args:
        num_epochs: kept on the instance as ``self.num_epochs``.
        bench_data_collator: collate function used by ``get_bench_dataloader``.
        eval_data_collator: collator swapped in by ``get_eval_dataloader``
            when training packs samples but evaluation does not.
        *_args, **kwargs: forwarded unchanged to the base ``Trainer``.
    """
    # These are assigned before the base initialiser runs, as in the
    # original ordering.
    self.num_epochs = num_epochs
    self.bench_data_collator = bench_data_collator
    self.eval_data_collator = eval_data_collator

    super().__init__(*_args, **kwargs)

    # Per-phase ("train"/"eval") lists of metric values, flushed by ``log``.
    self._stored_metrics = defaultdict(lambda: defaultdict(list))
    # Remember the training collator so it can be restored after an eval
    # dataloader temporarily replaces ``self.data_collator``.
    self.train_data_collator = self.data_collator

    if self.args.orpo_alpha:
        # Unreduced loss; ``orpo_compute_custom_loss`` averages it itself.
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
| |
|
def create_optimizer(self):
    """
    Build the optimizer, using the LoRA+ optimizer when configured.

    When ``args.loraplus_lr_ratio`` is ``None`` the base implementation is
    used. Otherwise a LoRA+ optimizer is created (only if no optimizer
    exists yet) and, under SageMaker model parallelism, wrapped in
    ``smp.DistributedOptimizer``.

    Returns:
        The optimizer stored on ``self.optimizer``.
    """
    if self.args.loraplus_lr_ratio is None:
        return super().create_optimizer()

    sagemaker_mp = is_sagemaker_mp_enabled()
    opt_model = self.model_wrapped if sagemaker_mp else self.model

    if self.optimizer is None:
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
            self.args,
            opt_model,
        )
        self.optimizer = create_loraplus_optimizer(
            opt_model,
            optimizer_cls,
            optimizer_kwargs,
            getattr(self.args, "loraplus_lr_ratio", None),
            getattr(self.args, "loraplus_lr_embedding", None),
        )

    if sagemaker_mp:
        self.optimizer = smp.DistributedOptimizer(self.optimizer)

    return self.optimizer
| |
|
def create_scheduler(
    self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None
):
    """
    Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
    passed as an argument.

    With ``lr_scheduler_type == "cosine"`` one of axolotl's schedulers is
    used: quadratic warmup (``lr_quadratic_warmup``), min-lr with a constant
    tail (``cosine_min_lr_ratio`` and ``cosine_constant_lr_ratio``), or
    min-lr only (``cosine_min_lr_ratio``). Anything else defers to the base
    Trainer. If a scheduler already exists it is kept and a warning is
    logged for each axolotl option that is therefore ignored.

    Args:
        num_training_steps (int): The number of training steps to do.
        optimizer (torch.optim.Optimizer): The training optimizer; falls back
            to ``self.optimizer`` when omitted.

    Returns:
        The learning-rate scheduler in use.
    """
    # Fall back to the trainer's own optimizer, as the sibling trainers in
    # this file do; otherwise the custom schedulers would be handed ``None``.
    optimizer = self.optimizer if optimizer is None else optimizer

    use_cosine_quadratic = (
        self.args.lr_scheduler_type == "cosine"
        and self.args.lr_quadratic_warmup is True
    )

    use_cosine_min_lr = (
        self.args.lr_scheduler_type == "cosine"
        and self.args.cosine_min_lr_ratio is not None
    )

    if self.lr_scheduler is None:
        if use_cosine_quadratic:
            if use_cosine_min_lr:
                LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

            self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(
                optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
            assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
            assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
            self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(
                optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
                min_lr_ratio=self.args.cosine_min_lr_ratio,
                constant_lr_ratio=self.args.cosine_constant_lr_ratio,
            )
        elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
            assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
            self.lr_scheduler = get_cosine_schedule_with_min_lr(
                optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
                min_lr_ratio=self.args.cosine_min_lr_ratio,
            )
        else:
            return super().create_scheduler(num_training_steps, optimizer)
    else:
        # A scheduler was already installed, so the axolotl-specific
        # options cannot take effect; tell the user.
        if use_cosine_quadratic:
            LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")

        if use_cosine_min_lr:
            LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

    return self.lr_scheduler
| |
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
    """
    Pick the sampler for the training set.

    Returns a ``MultipackBatchSampler`` over a random ordering when sample
    packing is on and this is not a pretraining run, a ``SequentialSampler``
    when curriculum sampling is requested, and the base Trainer's sampler
    otherwise.
    """
    args = self.args

    if not args.sample_packing or args.pretraining:
        if args.curriculum_sampling:
            return SequentialSampler(self.train_dataset)
        return super()._get_train_sampler()

    if args.multipack_real_batches:
        batch_size = args.per_device_train_batch_size
        batch_max_len = args.max_seq_length
    else:
        # A single "batch" whose length budget is the whole per-device batch.
        batch_size = 1
        batch_max_len = args.per_device_train_batch_size * args.max_seq_length

    return MultipackBatchSampler(
        RandomSampler(self.train_dataset),
        batch_size=batch_size,
        drop_last=True,
        batch_max_len=batch_max_len,
        lengths=get_dataset_lengths(self.train_dataset),
        packing_efficiency_estimate=args.sample_packing_efficiency,
    )
| |
|
def _get_eval_sampler(
    self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
    """
    Pick the sampler for an evaluation set.

    Returns a ``MultipackBatchSampler`` over a sequential ordering when
    sample packing is on and ``eval_sample_packing`` is not explicitly
    ``False``; otherwise the base Trainer's eval sampler.
    """
    args = self.args

    if not args.sample_packing or args.eval_sample_packing is False:
        return super()._get_eval_sampler(eval_dataset)

    if args.multipack_real_batches:
        batch_size = args.per_device_eval_batch_size
        batch_max_len = args.max_seq_length
    else:
        # A single "batch" whose length budget is the whole per-device batch.
        batch_size = 1
        batch_max_len = args.per_device_eval_batch_size * args.max_seq_length

    return MultipackBatchSampler(
        SequentialSampler(eval_dataset),
        batch_size=batch_size,
        drop_last=True,
        batch_max_len=batch_max_len,
        lengths=get_dataset_lengths(eval_dataset),
        packing_efficiency_estimate=args.sample_packing_efficiency,
    )
| |
|
def get_train_dataloader(self) -> DataLoader:
    """
    Build the training dataloader.

    For packed, non-pretraining runs a dataloader is assembled here around
    the sampler from ``_get_train_sampler`` and prepared by the accelerator;
    every other case defers to the base Trainer.
    """
    if not self.args.sample_packing or self.args.pretraining:
        return super().get_train_dataloader()

    dataset = self.train_dataset
    if "length" in dataset.features.keys():
        # The helper column is dropped before batches are collated.
        dataset = dataset.remove_columns(["length"])

    loader_kwargs = {
        "batch_size": self._train_batch_size,
        "collate_fn": self.data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
    }
    if self.args.dataloader_prefetch_factor:
        loader_kwargs["prefetch_factor"] = self.args.dataloader_prefetch_factor

    sampler = self._get_train_sampler()
    if isinstance(sampler, BatchSampler):
        # A batch sampler already yields whole batches, so ``batch_size``
        # must not be passed alongside it.
        loader_kwargs.pop("batch_size")
        loader_kwargs["batch_sampler"] = sampler
    else:
        loader_kwargs.update(
            sampler=sampler,
            drop_last=self.args.dataloader_drop_last,
            worker_init_fn=seed_worker,
        )

    self.accelerator.even_batches = False
    return self.accelerator.prepare_data_loader(DataLoader(dataset, **loader_kwargs))
| |
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
    """
    Build the evaluation dataloader.

    Three cases:
    - packing is on for training but explicitly off for eval: the base
      dataloader is built with ``self.eval_data_collator`` swapped in;
    - packing is on for eval too: a dataloader is assembled here around the
      sampler from ``_get_eval_sampler`` and prepared by the accelerator;
    - otherwise: the base Trainer's dataloader is returned.

    Args:
        eval_dataset: dataset to evaluate; ``self.eval_dataset`` is used in
            the packed path when this is ``None``.
    """
    if self.args.sample_packing and self.args.eval_sample_packing is False:
        self.data_collator = self.eval_data_collator
        try:
            return super().get_eval_dataloader(eval_dataset)
        finally:
            # Always restore the training collator, even if building the
            # eval dataloader raised.
            self.data_collator = self.train_data_collator

    if self.args.sample_packing and self.args.eval_sample_packing is not False:
        eval_dataset = (
            eval_dataset if eval_dataset is not None else self.eval_dataset
        )

        # The sampler must be created before the "length" column is removed.
        eval_sampler = self._get_eval_sampler(eval_dataset)
        # Guard the removal like ``get_train_dataloader`` does; removing a
        # column that does not exist raises.
        if "length" in eval_dataset.features.keys():
            eval_dataset = eval_dataset.remove_columns(["length"])

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }
        if self.args.dataloader_prefetch_factor:
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        if isinstance(eval_sampler, BatchSampler):
            # A batch sampler already yields whole batches, so
            # ``batch_size`` must not be passed alongside it.
            dataloader_params["batch_sampler"] = eval_sampler
            del dataloader_params["batch_size"]
        else:
            dataloader_params["sampler"] = eval_sampler
            dataloader_params["drop_last"] = self.args.dataloader_drop_last

        self.accelerator.even_batches = False
        return self.accelerator.prepare_data_loader(
            DataLoader(eval_dataset, **dataloader_params)
        )

    return super().get_eval_dataloader(eval_dataset)
| |
|
def _get_bench_sampler(
    self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
    """
    Return a ``SequentialSampler`` over ``bench_dataset`` when
    ``args.world_size`` is at most 1, and ``None`` otherwise.
    """
    single_process = self.args.world_size <= 1
    return SequentialSampler(bench_dataset) if single_process else None
| |
|
def get_bench_dataloader(
    self,
    bench_dataset: Dataset,
) -> DataLoader:
    """
    Build a plain ``DataLoader`` for the benchmark dataset.

    Uses the eval batch size and ``self.bench_data_collator``. For datasets
    that are not iterable-style, the sampler from ``_get_bench_sampler`` and
    the ``dataloader_drop_last`` setting are applied as well.
    """
    args = self.args
    loader_kwargs = dict(
        batch_size=args.eval_batch_size,
        collate_fn=self.bench_data_collator,
        num_workers=args.dataloader_num_workers,
        pin_memory=args.dataloader_pin_memory,
    )

    prefetch_factor = args.dataloader_prefetch_factor
    if prefetch_factor:
        loader_kwargs["prefetch_factor"] = prefetch_factor

    is_iterable_style = isinstance(bench_dataset, torch.utils.data.IterableDataset)
    if not is_iterable_style:
        loader_kwargs["sampler"] = self._get_bench_sampler(bench_dataset)
        loader_kwargs["drop_last"] = args.dataloader_drop_last

    return DataLoader(bench_dataset, **loader_kwargs)
| | |
| |
|
def compute_loss(self, model, inputs, return_outputs=False):
    """
    Compute the training loss.

    Dispatches to ``orpo_compute_loss`` when ``args.orpo_alpha`` is set and
    to the base Trainer's implementation otherwise.
    """
    if not self.args.orpo_alpha:
        return super().compute_loss(model, inputs, return_outputs=return_outputs)
    return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
| |
|
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
    """
    Pad the chosen and rejected tensors to a common length and stack them.

    Args:
        inputs: mapping with ``input_ids``, ``labels``, ``attention_mask``,
            their ``rejected_*`` counterparts and ``prompt_attention_mask``.
        label_pad_token: fill value for the label tensors.
        pad_token: fill value for the input id tensors.
        device: device the returned tensors are moved to.

    Returns:
        A dict with ``input_ids``, ``labels`` and ``attention_mask`` (chosen
        rows followed by rejected rows along dim 0) plus the padded
        ``prompt_attention_mask``.
    """
    max_length = max(
        inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
    )

    # Fill value for every tensor that is padded to ``max_length``.
    fill_values = {
        "input_ids": pad_token,
        "rejected_input_ids": pad_token,
        "labels": label_pad_token,
        "rejected_labels": label_pad_token,
        "attention_mask": 0,
        "rejected_attention_mask": 0,
        "prompt_attention_mask": 0,
    }
    padded = {
        key: pad_to_length(inputs[key], max_length, fill)
        for key, fill in fill_values.items()
    }

    def _stack(chosen_key, rejected_key):
        # Chosen rows first, rejected rows second.
        return torch.cat([padded[chosen_key], padded[rejected_key]], dim=0).to(
            device=device
        )

    return {
        "input_ids": _stack("input_ids", "rejected_input_ids"),
        "labels": _stack("labels", "rejected_labels"),
        "attention_mask": _stack("attention_mask", "rejected_attention_mask"),
        "prompt_attention_mask": padded["prompt_attention_mask"].to(device=device),
    }
| |
|
def orpo_compute_custom_loss(self, logits, labels):
    """
    Per-sequence loss of ``logits`` against next-token ``labels``.

    The logits at position ``i`` are scored against the label at position
    ``i + 1`` using ``self.loss_fct``, and the result is averaged over the
    last dimension. Returns ``0.0`` when ``labels`` is ``None``.
    """
    logits = logits.contiguous()
    if labels is None:
        return 0.0

    # Keep the labels on the same device as the logits.
    targets = labels.to(logits.device)

    # Drop the last prediction and the first label so the two line up.
    predictions = logits[..., :-1, :].contiguous()
    next_tokens = targets[..., 1:].contiguous()

    # The loss function gets the class dimension in position 1.
    token_losses = self.loss_fct(predictions.transpose(2, 1), next_tokens)
    return token_losses.mean(dim=-1)
| |
|
def orpo_compute_logps(
    self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
):
    """
    Mean log-probability of the tokens that follow the prompt.

    A token counts when it is attended to in ``chosen_attention_mask`` but
    not in ``prompt_attention_mask`` (both shifted by one position to line
    up with next-token predictions). The log-probabilities of those tokens
    are summed per row and divided by the number of counted tokens.
    """
    shifted_attention = chosen_attention_mask[:, :-1]

    # Right-pad the shifted prompt mask with zeros to the shifted sequence
    # length so the two masks can be compared elementwise.
    pad_length = shifted_attention.shape[1] - (prompt_attention_mask.shape[1] - 1)
    padded_prompt = torch.nn.functional.pad(
        prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
    )

    # True where the sequence is attended but the prompt is not.
    completion_mask = shifted_attention > padded_prompt

    log_probs = logits[:, :-1, :].log_softmax(-1)
    # Masked-out positions gather index 0; they are zeroed again below.
    token_index = (completion_mask * chosen_inputs[:, 1:]).unsqueeze(2)
    token_logps = torch.gather(log_probs, dim=2, index=token_index).squeeze(2)

    masked_total = torch.mul(token_logps, completion_mask).sum(dim=1)
    return masked_total / completion_mask.sum(dim=1)
| |
|
def orpo_compute_loss(self, model, inputs, return_outputs=False):
    """
    Compute the ORPO loss for a batch of chosen/rejected pairs.

    The chosen and rejected sequences are stacked into one batch and run
    through ``model`` once. The loss is the mean of the chosen-sequence
    loss minus ``args.orpo_alpha`` times the log-sigmoid of the log odds
    ratio between chosen and rejected completions. Four summary metrics are
    recorded through ``store_metrics``.

    Args:
        model: the model to call.
        inputs: batch mapping accepted by ``orpo_concatenate_inputs``.
        return_outputs: when true, also return the chosen-half logits.

    Returns:
        ``loss``, or ``(loss, chosen_logits)`` when ``return_outputs``.
    """
    concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
        inputs,
        label_pad_token=-100,
        pad_token=self.tokenizer.pad_token_id,
        device=self.accelerator.device,
    )

    # Single forward pass over chosen + rejected rows.
    outputs = model(
        **{
            "input_ids": concat_inputs["input_ids"],
            "attention_mask": concat_inputs["attention_mask"],
            "labels": concat_inputs["labels"],
        },
        output_hidden_states=True,
    )

    # First half of the batch is "chosen", second half is "rejected".
    outputs_pos, outputs_neg = outputs.logits.chunk(2)
    chosen_ids, rejected_ids = concat_inputs["input_ids"].chunk(2)
    chosen_mask, rejected_mask = concat_inputs["attention_mask"].chunk(2)

    # Loss on the chosen sequences (the input ids serve as labels here).
    pos_loss = self.orpo_compute_custom_loss(logits=outputs_pos, labels=chosen_ids)

    # Mean log-probability of the completion tokens for each half.
    pos_prob = self.orpo_compute_logps(
        prompt_attention_mask=concat_inputs["prompt_attention_mask"],
        chosen_inputs=chosen_ids,
        chosen_attention_mask=chosen_mask,
        logits=outputs_pos,
    )
    neg_prob = self.orpo_compute_logps(
        prompt_attention_mask=concat_inputs["prompt_attention_mask"],
        chosen_inputs=rejected_ids,
        chosen_attention_mask=rejected_mask,
        logits=outputs_neg,
    )

    # log(1 - exp(x)) is evaluated as log1p(-exp(x)) and log(sigmoid(x)) as
    # logsigmoid(x): same quantities, but without the precision loss and
    # -inf results the naive forms produce at the extremes.
    log_odds = (pos_prob - neg_prob) - (
        torch.log1p(-torch.exp(pos_prob)) - torch.log1p(-torch.exp(neg_prob))
    )
    ratio = torch.nn.functional.logsigmoid(log_odds)

    loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
        dtype=torch.bfloat16
    )

    metrics = {
        "chosen_geometric_mean": torch.mean(pos_prob).cpu().item(),
        "rejected_geometric_mean": torch.mean(neg_prob).cpu().item(),
        "log_odds_ratio": torch.mean(ratio).cpu().item(),
        "log_odds": torch.mean(log_odds).cpu().item(),
    }
    self.store_metrics(metrics, train_eval="train")

    return (loss, outputs_pos) if return_outputs else loss
| |
|
@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
    """
    Push to the Hub with this trainer's ``tag_names`` merged into the
    ``tags`` keyword argument. All other arguments are forwarded unchanged;
    see `~transformers.Trainer.push_to_hub` for their meaning.
    """
    tagged_kwargs = _sanitize_kwargs_for_tagging(
        tag_names=self.tag_names, kwargs=kwargs
    )
    return super().push_to_hub(*args, **tagged_kwargs)
| |
|
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
    """
    Run the base accelerator setup, then turn on ``limit_all_gathers`` on
    the FSDP plugin when FSDP is enabled and the FSDP config asks for it.

    Returns:
        Whatever the base implementation returned.
    """
    result = super().create_accelerator_and_postprocess()

    wants_limit = (
        self.is_fsdp_enabled
        and "limit_all_gathers" in self.args.fsdp_config
        and self.args.fsdp_config["limit_all_gathers"]
    )
    if wants_limit:
        self.accelerator.state.fsdp_plugin.limit_all_gathers = True

    return result
| |
|
def log(self, logs: Dict[str, float]) -> None:
    """
    Log `logs` on the various objects watching training, including stored metrics.

    The mean of every metric stored for the current phase is added to
    `logs`, and the stored values for that phase are then discarded.

    Args:
        logs (`Dict[str, float]`):
            The values to log.
    """
    # A "loss" key marks a training log; anything else is treated as eval.
    phase = "train" if "loss" in logs else "eval"

    pending = self._stored_metrics[phase]
    logs.update(
        {name: torch.tensor(values).mean().item() for name, values in pending.items()}
    )
    del self._stored_metrics[phase]

    return super().log(logs)
| |
|
def store_metrics(
    self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
    """
    Append each value in ``metrics`` to the list kept for its key under the
    ``train_eval`` phase of ``self._stored_metrics``.
    """
    for name in metrics:
        self._stored_metrics[train_eval][name].append(metrics[name])
| |
|
| |
|
class AxolotlMambaTrainer(AxolotlTrainer):
    """
    Mamba specific trainer to handle loss calculation
    """

    tag_names = ["axolotl", "mamba"]

    def compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,
    ):
        """
        Next-token cross-entropy loss computed from ``input_ids`` alone.

        ``input_ids`` is popped from ``inputs`` and doubles as the labels:
        the logits at position ``i`` are scored against the token at
        position ``i + 1``.

        Args:
            model: the model to call with ``input_ids``.
            inputs: batch mapping; ``input_ids`` is removed from it.
            return_outputs: when true, also return the model outputs.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs``.
        """
        input_ids = inputs.pop("input_ids")
        outputs = model(input_ids)
        lm_logits = outputs.logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
        )

        # Honour the ``return_outputs`` contract instead of always returning
        # a bare loss, which callers expecting a pair cannot unpack.
        return (lm_loss, outputs) if return_outputs else lm_loss
| |
|
| |
|
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
    """
    Trainer variant whose learning-rate schedule is torch's OneCycleLR
    """

    tag_names = ["axolotl", "onecycle"]

    def __init__(self, *args, **kwargs):
        """
        Initialise the parent trainer, then clear ``lr_scheduler``.
        """
        super().__init__(*args, **kwargs)
        self.lr_scheduler = None

    def create_scheduler(
        self,
        num_training_steps: int,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ):
        """
        Create and store a ``OneCycleLR`` scheduler.

        The warmup share of the cycle (``pct_start``) is the configured
        number of warmup steps divided by ``num_training_steps``.

        Args:
            num_training_steps: total number of scheduler steps.
            optimizer: optimizer to schedule; defaults to ``self.optimizer``.

        Returns:
            The scheduler, also kept on ``self.lr_scheduler``.
        """
        if optimizer is None:
            optimizer = self.optimizer

        warmup_fraction = (
            self.args.get_warmup_steps(num_training_steps) / num_training_steps
        )
        self.lr_scheduler = OneCycleLR(
            optimizer,
            max_lr=self.args.learning_rate,
            total_steps=num_training_steps,
            pct_start=warmup_fraction,
            div_factor=6,
        )
        return self.lr_scheduler
| |
|
| |
|
class ReLoRATrainer(AxolotlTrainer):
    """
    Trainer subclass that wraps the parent's scheduler in a ReLoRAScheduler
    """

    tag_names = ["axolotl", "relora"]

    def __init__(self, *args, **kwargs):
        """
        Initialise the parent trainer, then clear ``lr_scheduler``.
        """
        super().__init__(*args, **kwargs)
        self.lr_scheduler = None

    def create_scheduler(
        self,
        num_training_steps: int,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ):
        """
        Create the parent's scheduler and, when ``args.relora_steps`` is
        set, wrap it in a ``ReLoRAScheduler``.

        Args:
            num_training_steps: total number of scheduler steps.
            optimizer: optimizer to schedule; defaults to ``self.optimizer``.

        Returns:
            The scheduler, also kept on ``self.lr_scheduler``.
        """
        if optimizer is None:
            optimizer = self.optimizer
        base_scheduler = super().create_scheduler(num_training_steps, optimizer)

        if not self.args.relora_steps:
            self.lr_scheduler = base_scheduler
            return self.lr_scheduler

        # Unset (or zero) values fall back to 10 warmup steps / 1 anneal step.
        warmup_steps = self.args.relora_warmup_steps or 10
        anneal_steps = self.args.relora_anneal_steps or 1

        # NOTE(review): anneal_steps is passed positionally before
        # warmup_steps — confirm this matches ReLoRAScheduler's parameter
        # order.
        self.lr_scheduler = ReLoRAScheduler(
            optimizer,
            base_scheduler,
            self.args.relora_steps,
            anneal_steps,
            warmup_steps,
        )
        return self.lr_scheduler
| |
|
| |
|
class AxolotlDPOTrainer(DPOTrainer):
    """
    DPOTrainer with axolotl tagging and a tokenization fix-up
    """

    tag_names = ["axolotl", "dpo"]

    @wraps(DPOTrainer.push_to_hub)
    def push_to_hub(self, *args, **kwargs) -> str:
        """
        Push to the Hub with this trainer's ``tag_names`` merged into the
        ``tags`` keyword argument. All other arguments are forwarded
        unchanged; see `~transformers.Trainer.push_to_hub` for their meaning.
        """
        tagged_kwargs = _sanitize_kwargs_for_tagging(
            tag_names=self.tag_names, kwargs=kwargs
        )
        return super().push_to_hub(*args, **tagged_kwargs)

    def tokenize_row(
        self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
    ) -> Dict:
        """
        Tokenize one row with the parent implementation and, when the
        tokenizer has no BOS token id and the first prompt token id is
        ``None``, drop the first element of every field in the result.
        """
        row = super().tokenize_row(feature, model=model)

        no_bos_token = self.tokenizer.bos_token_id is None
        if no_bos_token and row["prompt_input_ids"][0] is None:
            # Values are replaced in place; the key set does not change.
            for name, values in row.items():
                row[name] = values[1:]

        return row
| |
|
| |
|
class AxolotlORPOTrainer(ORPOTrainer):
    """
    Extend the base ORPOTrainer for axolotl helpers
    """

    # Tag names for this trainer.
    # NOTE(review): unlike AxolotlDPOTrainer, this class defines no
    # push_to_hub override that consumes them — confirm that is intended.
    tag_names = ["axolotl", "orpo"]
| |
|
| |
|
class TrainerBuilderBase(abc.ABC):
    """
    Abstract base for the objects that assemble a trainer from a config,
    a model and a tokenizer
    """

    _train_dataset = None
    _eval_dataset = None
    _model_ref = None
    _peft_config = None

    def __init__(self, cfg, model, tokenizer):
        """
        Keep references to the config, model and tokenizer, and tag the
        model with "axolotl" when it offers ``add_model_tags``.
        """
        self.cfg, self.model, self.tokenizer = cfg, model, tokenizer

        if hasattr(model, "add_model_tags"):
            model.add_model_tags(["axolotl"])

    # --- simple pass-through properties ---------------------------------

    @property
    def model_ref(self):
        return self._model_ref

    @model_ref.setter
    def model_ref(self, model):
        self._model_ref = model

    @property
    def train_dataset(self):
        return self._train_dataset

    @train_dataset.setter
    def train_dataset(self, dataset):
        self._train_dataset = dataset

    @property
    def eval_dataset(self):
        return self._eval_dataset

    @eval_dataset.setter
    def eval_dataset(self, dataset):
        self._eval_dataset = dataset

    @property
    def peft_config(self):
        return self._peft_config

    @peft_config.setter
    def peft_config(self, peft_config):
        self._peft_config = peft_config

    # --- builder interface ----------------------------------------------

    @abstractmethod
    def build(self, total_num_steps):
        pass

    def get_callbacks(self) -> List[TrainerCallback]:
        """
        Callbacks every builder contributes: the config-saving callback for
        Weights & Biases and/or MLflow, depending on what the config enables.
        """
        config_path_callbacks = []

        if self.cfg.use_wandb:
            wandb_callback = SaveAxolotlConfigtoWandBCallback(
                self.cfg.axolotl_config_path
            )
            config_path_callbacks.append(wandb_callback)

        if self.cfg.use_mlflow and is_mlflow_available():
            # Imported here, only once MLflow is known to be available.
            from axolotl.utils.callbacks.mlflow_ import (
                SaveAxolotlConfigtoMlflowCallback,
            )

            mlflow_callback = SaveAxolotlConfigtoMlflowCallback(
                self.cfg.axolotl_config_path
            )
            config_path_callbacks.append(mlflow_callback)

        return config_path_callbacks

    @abstractmethod
    def get_post_trainer_create_callbacks(self, trainer):
        """
        Callbacks added after the trainer is created, usually b/c these need access to the trainer
        """

    # --- extension hooks: each returns its input unchanged by default ----

    def hook_pre_create_training_args(self, training_arguments_kwargs):
        return training_arguments_kwargs

    def hook_post_create_training_args(self, training_arguments):
        return training_arguments

    def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
        return trainer_kwargs, trainer_cls

    def hook_post_create_trainer(self, trainer):
        return trainer
| |
|
| |
|
| | class HFCausalTrainerBuilder(TrainerBuilderBase): |
| | """ |
| | Build the HuggingFace training args/trainer for Causal models |
| | """ |
| |
|
def get_callbacks(self):
    """
    Extend the base callbacks with the ones used for causal-LM training.

    GPU stats, eval-on-first-step and save-on-train-end callbacks are always
    added; the ReLoRA, BetterTransformer and loss-watchdog callbacks depend
    on the config and the model.
    """
    callbacks = super().get_callbacks()
    callbacks.extend([GPUStatsCallback(self.cfg), EvalFirstStepCallback()])

    if self.cfg.relora_steps:
        callbacks.append(ReLoRACallback(self.cfg))

    uses_bettertransformer = (
        hasattr(self.model, "use_bettertransformer")
        and self.model.use_bettertransformer is True
    )
    if uses_bettertransformer:
        callbacks.append(SaveBetterTransformerModelCallback())

    if self.cfg.loss_watchdog_threshold is not None:
        callbacks.append(LossWatchDogCallback(self.cfg))

    # Added last, after every optional callback.
    callbacks.append(SaveModelOnTrainEndCallback())
    return callbacks
| |
|
def get_post_trainer_create_callbacks(self, trainer):
    """
    Build the callbacks that need a reference to the created trainer.

    Depending on the config this adds prediction-logging callbacks for
    Weights & Biases and MLflow, the benchmark and causal-LM evaluation
    callbacks, early stopping, and the LISA callback.

    Args:
        trainer: the trainer the callbacks are created for.

    Returns:
        The list of callbacks, in the order they were added.
    """
    callbacks = []

    # ``eval_table_size`` may be unset; comparing ``None > 0`` would raise,
    # so an unset value is treated as 0 (no prediction table).
    if self.cfg.use_wandb and (self.cfg.eval_table_size or 0) > 0:
        LogPredictionCallback = log_prediction_callback_factory(
            trainer, self.tokenizer, "wandb"
        )
        callbacks.append(LogPredictionCallback(self.cfg))
    if (
        self.cfg.use_mlflow
        and is_mlflow_available()
        and (self.cfg.eval_table_size or 0) > 0
    ):
        LogPredictionCallback = log_prediction_callback_factory(
            trainer, self.tokenizer, "mlflow"
        )
        callbacks.append(LogPredictionCallback(self.cfg))

    if self.cfg.do_bench_eval:
        # The factory's return value is appended as-is (not called with cfg).
        callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
    if self.cfg.do_causal_lm_eval:
        CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
            trainer, self.tokenizer
        )
        callbacks.append(CausalLMBenchEvalCallback(self.cfg))

    if self.cfg.early_stopping_patience:
        early_stop_cb = EarlyStoppingCallback(
            self.cfg.early_stopping_patience,
        )
        callbacks.append(early_stop_cb)

    if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
        # The factory's return value is appended as-is (not called with cfg).
        callbacks.append(lisa_callback_factory(trainer))
    return callbacks
| |
|
| | def _get_trainer_cls(self): |
| | if self.cfg.lr_scheduler == "one_cycle" and ( |
| | self.cfg.fsdp or self.cfg.adapter == "qlora" |
| | ): |
| | return OneCycleLRSchedulerTrainer |
| | if self.cfg.relora_steps: |
| | return ReLoRATrainer |
| | if self.cfg.model_config_type == "mamba": |
| | return AxolotlMambaTrainer |
| | return AxolotlTrainer |
| |
|
| | def build(self, total_num_steps): |
| | warmup_steps = None |
| | if self.cfg.warmup_steps is not None: |
| | warmup_steps = self.cfg.warmup_steps |
| | elif self.cfg.warmup_ratio is not None: |
| | warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) |
| | else: |
| | warmup_steps = min(int(0.03 * total_num_steps), 100) |
| |
|
| | logging_steps = ( |
| | self.cfg.logging_steps |
| | if self.cfg.logging_steps is not None |
| | else max(min(int(0.005 * total_num_steps), 10), 1) |
| | ) |
| |
|
| | training_arguments_kwargs = {} |
| | if self.cfg.bf16 == "full": |
| | training_arguments_kwargs["bf16_full_eval"] = True |
| | else: |
| | training_arguments_kwargs["bf16"] = self.cfg.bf16 |
| | training_arguments_kwargs["fp16"] = ( |
| | self.cfg.fp16 and not self.cfg.bf16 |
| | ) or False |
| | training_arguments_kwargs["tf32"] = self.cfg.tf32 |
| | training_arguments_kwargs["warmup_steps"] = warmup_steps |
| | training_arguments_kwargs["logging_steps"] = logging_steps |
| |
|
| | if self.cfg.seed: |
| | training_arguments_kwargs["seed"] = self.cfg.seed |
| |
|
| | if self.cfg.gradient_checkpointing: |
| | training_arguments_kwargs[ |
| | "gradient_checkpointing" |
| | ] = self.cfg.gradient_checkpointing |
| | if self.cfg.gradient_checkpointing_kwargs is not None: |
| | training_arguments_kwargs[ |
| | "gradient_checkpointing_kwargs" |
| | ] = self.cfg.gradient_checkpointing_kwargs |
| | if self.cfg.fsdp: |
| | training_arguments_kwargs["fsdp"] = self.cfg.fsdp |
| | if self.cfg.fsdp_config: |
| | training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config) |
| |
|
| | if self.cfg.adapter == "qlora": |
| | training_arguments_kwargs["qlora"] = True |
| |
|
| | |
| | if self.cfg.deepspeed: |
| | training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed |
| |
|
| | if self.cfg.lr_quadratic_warmup is not None: |
| | training_arguments_kwargs[ |
| | "lr_quadratic_warmup" |
| | ] = self.cfg.lr_quadratic_warmup |
| |
|
| | if self.cfg.adam_beta1: |
| | training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1 |
| | if self.cfg.adam_beta2: |
| | training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2 |
| | if self.cfg.adam_epsilon: |
| | training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon |
| | if self.cfg.max_grad_norm: |
| | training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm |
| |
|
| | if self.cfg.hub_model_id: |
| | training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id |
| | training_arguments_kwargs["push_to_hub"] = True |
| | training_arguments_kwargs["hub_private_repo"] = True |
| | training_arguments_kwargs["hub_always_push"] = True |
| |
|
| | if self.cfg.hub_strategy: |
| | training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy |
| |
|
| | if self.cfg.save_safetensors is not None: |
| | training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors |
| |
|
| | if self.cfg.sample_packing_eff_est: |
| | training_arguments_kwargs[ |
| | "sample_packing_efficiency" |
| | ] = self.cfg.sample_packing_eff_est |
| |
|
| | if self.cfg.dataloader_pin_memory is not None: |
| | training_arguments_kwargs[ |
| | "dataloader_pin_memory" |
| | ] = self.cfg.dataloader_pin_memory |
| | if self.cfg.dataloader_num_workers is not None: |
| | training_arguments_kwargs[ |
| | "dataloader_num_workers" |
| | ] = self.cfg.dataloader_num_workers |
| | if self.cfg.dataloader_prefetch_factor is not None: |
| | training_arguments_kwargs[ |
| | "dataloader_prefetch_factor" |
| | ] = self.cfg.dataloader_prefetch_factor |
| | if self.cfg.dataloader_drop_last is not None: |
| | training_arguments_kwargs[ |
| | "dataloader_drop_last" |
| | ] = self.cfg.dataloader_drop_last |
| | elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False: |
| | training_arguments_kwargs["dataloader_drop_last"] = True |
| |
|
| | if self.cfg.remove_unused_columns is not None: |
| | training_arguments_kwargs[ |
| | "remove_unused_columns" |
| | ] = self.cfg.remove_unused_columns |
| |
|
| | if not self.cfg.test_datasets and self.cfg.val_set_size == 0: |
| | |
| | training_arguments_kwargs["evaluation_strategy"] = "no" |
| | elif self.cfg.eval_steps: |
| | training_arguments_kwargs["evaluation_strategy"] = "steps" |
| | training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps |
| | elif self.cfg.evaluation_strategy: |
| | training_arguments_kwargs[ |
| | "evaluation_strategy" |
| | ] = self.cfg.evaluation_strategy |
| | else: |
| | |
| | training_arguments_kwargs["evaluation_strategy"] = "epoch" |
| |
|
| | if self.cfg.save_steps: |
| | training_arguments_kwargs["save_strategy"] = "steps" |
| | training_arguments_kwargs["save_steps"] = self.cfg.save_steps |
| | elif self.cfg.save_strategy: |
| | training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy |
| | else: |
| | |
| | training_arguments_kwargs["save_strategy"] = "epoch" |
| |
|
| | if self.cfg.do_bench_eval: |
| | training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval |
| | if self.cfg.bench_dataset: |
| | training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset |
| | if self.cfg.do_causal_lm_eval: |
| | training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval |
| | if self.cfg.metric_for_best_model: |
| | training_arguments_kwargs[ |
| | "metric_for_best_model" |
| | ] = self.cfg.metric_for_best_model |
| | if self.cfg.greater_is_better: |
| | training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better |
| |
|
| | if self.cfg.torch_compile: |
| | if torch.__version__ < "2.1.0": |
| | LOG.warning("torch>=2.1.0 required for torch_compile to work properly") |
| | elif torch._dynamo: |
| | torch._dynamo.config.suppress_errors = ( |
| | True |
| | ) |
| | training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile |
| | if self.cfg.torch_compile_backend: |
| | training_arguments_kwargs[ |
| | "torch_compile_backend" |
| | ] = self.cfg.torch_compile_backend |
| |
|
| | |
| | if self.cfg.ddp_timeout: |
| | training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout |
| | |
| | if self.cfg.ddp_bucket_cap_mb: |
| | training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb |
| | if self.cfg.ddp_broadcast_buffers is not None: |
| | training_arguments_kwargs[ |
| | "ddp_broadcast_buffers" |
| | ] = self.cfg.ddp_broadcast_buffers |
| |
|
| | |
| | training_arguments_kwargs["max_steps"] = ( |
| | total_num_steps if self.cfg.max_steps else -1 |
| | ) |
| | training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len |
| | training_arguments_kwargs[ |
| | "per_device_train_batch_size" |
| | ] = self.cfg.micro_batch_size |
| | if self.cfg.eval_batch_size: |
| | training_arguments_kwargs[ |
| | "per_device_eval_batch_size" |
| | ] = self.cfg.eval_batch_size |
| | training_arguments_kwargs[ |
| | "gradient_accumulation_steps" |
| | ] = self.cfg.gradient_accumulation_steps |
| | training_arguments_kwargs[ |
| | "eval_accumulation_steps" |
| | ] = self.cfg.gradient_accumulation_steps |
| | training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs |
| | training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate |
| | training_arguments_kwargs["output_dir"] = self.cfg.output_dir |
| | training_arguments_kwargs["save_total_limit"] = ( |
| | self.cfg.save_total_limit if self.cfg.save_total_limit else 4 |
| | ) |
| | training_arguments_kwargs["load_best_model_at_end"] = ( |
| | ( |
| | self.cfg.load_best_model_at_end is not False |
| | or self.cfg.early_stopping_patience |
| | ) |
| | and ( |
| | (not self.cfg.test_datasets and self.cfg.val_set_size > 0) |
| | or (self.cfg.test_datasets and self.cfg.val_set_size == 0) |
| | ) |
| | and self.cfg.save_steps |
| | and self.cfg.eval_steps |
| | and self.cfg.save_steps % self.cfg.eval_steps == 0 |
| | ) or False |
| | training_arguments_kwargs["ddp_find_unused_parameters"] = ( |
| | False if self.cfg.ddp else None |
| | ) |
| | training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length |
| | training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling |
| | report_to = None |
| | if self.cfg.use_wandb: |
| | report_to = "wandb" |
| | if self.cfg.use_mlflow: |
| | report_to = "mlflow" |
| | training_arguments_kwargs["report_to"] = report_to |
| | training_arguments_kwargs["run_name"] = ( |
| | self.cfg.wandb_name if self.cfg.use_wandb else None |
| | ) |
| | training_arguments_kwargs["optim"] = ( |
| | self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" |
| | ) |
| | if self.cfg.optim_args: |
| | if isinstance(self.cfg.optim_args, dict): |
| | optim_args = ",".join( |
| | [f"{key}={value}" for key, value in self.cfg.optim_args.items()] |
| | ) |
| | else: |
| | optim_args = self.cfg.optim_args |
| | training_arguments_kwargs["optim_args"] = optim_args |
| | if self.cfg.optim_target_modules: |
| | training_arguments_kwargs[ |
| | "optim_target_modules" |
| | ] = self.cfg.optim_target_modules |
| | training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio |
| | training_arguments_kwargs[ |
| | "loraplus_lr_embedding" |
| | ] = self.cfg.loraplus_lr_embedding |
| | training_arguments_kwargs["lr_scheduler_type"] = ( |
| | self.cfg.lr_scheduler |
| | if self.cfg.lr_scheduler |
| | and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep") |
| | else "cosine" |
| | ) |
| | training_arguments_kwargs["lr_scheduler_kwargs"] = ( |
| | self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} |
| | ) |
| | training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio |
| | training_arguments_kwargs[ |
| | "cosine_constant_lr_ratio" |
| | ] = self.cfg.cosine_constant_lr_ratio |
| | training_arguments_kwargs["weight_decay"] = ( |
| | self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 |
| | ) |
| | training_arguments_kwargs["sample_packing"] = ( |
| | self.cfg.sample_packing if self.cfg.sample_packing else False |
| | ) |
| | training_arguments_kwargs["multipack_real_batches"] = ( |
| | self.cfg.flash_attention is not True |
| | ) |
| | training_arguments_kwargs["eval_sample_packing"] = ( |
| | self.cfg.sample_packing |
| | if self.cfg.eval_sample_packing is not False |
| | else False |
| | ) |
| | training_arguments_kwargs[ |
| | "sample_packing_seq_len_multiplier" |
| | ] = self.cfg.micro_batch_size |
| | if self.cfg.relora_steps: |
| | training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps |
| | training_arguments_kwargs[ |
| | "relora_warmup_steps" |
| | ] = self.cfg.relora_warmup_steps |
| | if self.cfg.relora_anneal_steps: |
| | training_arguments_kwargs[ |
| | "relora_anneal_steps" |
| | ] = self.cfg.relora_anneal_steps |
| | if self.cfg.relora_prune_ratio: |
| | training_arguments_kwargs[ |
| | "relora_prune_ratio" |
| | ] = self.cfg.relora_prune_ratio |
| |
|
| | if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers: |
| | training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers |
| | training_arguments_kwargs[ |
| | "lisa_step_interval" |
| | ] = self.cfg.lisa_step_interval |
| | training_arguments_kwargs[ |
| | "lisa_layers_attribute" |
| | ] = self.cfg.lisa_layers_attribute |
| |
|
| | training_arguments_kwargs = self.hook_pre_create_training_args( |
| | training_arguments_kwargs |
| | ) |
| | training_arguments_kwargs["model_type"] = self.cfg.model_config_type |
| | training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) |
| |
|
| | if self.cfg.rl == "orpo": |
| | training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha |
| |
|
| | if self.cfg.neftune_noise_alpha is not None: |
| | training_arguments_kwargs[ |
| | "neftune_noise_alpha" |
| | ] = self.cfg.neftune_noise_alpha |
| |
|
| | trainer_kwargs = {} |
| |
|
| | if self.cfg.optimizer == "lion_pytorch": |
| | from lion_pytorch import Lion |
| |
|
| | lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]} |
| | if "weight_decay" in training_arguments_kwargs: |
| | lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"] |
| |
|
| | if ( |
| | "adam_beta1" in training_arguments_kwargs |
| | and "adam_beta2" in training_arguments_kwargs |
| | ): |
| | lion_kwargs["betas"] = ( |
| | training_arguments_kwargs["adam_beta1"], |
| | training_arguments_kwargs["adam_beta2"], |
| | ) |
| |
|
| | trainer_kwargs["optimizers"] = ( |
| | Lion(params=self.model.parameters(), **lion_kwargs), |
| | None, |
| | ) |
| | |
| | training_arguments_kwargs["optim"] = "adamw_hf" |
| |
|
| | if self.cfg.optimizer == "adamw_anyprecision": |
| | if Path(self.cfg.torchdistx_path).exists(): |
| | sys.path.append(self.cfg.torchdistx_path) |
| | importlib.import_module("torchdistx") |
| |
|
| | training_args = ( |
| | AxolotlTrainingArguments( |
| | **training_arguments_kwargs, |
| | ) |
| | ) |
| | training_args = self.hook_post_create_training_args(training_args) |
| |
|
| | data_collator_kwargs = { |
| | "padding": True, |
| | } |
| | if self.cfg.pad_to_sequence_len: |
| | data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil( |
| | self.cfg.sequence_len / 64 |
| | ) |
| | else: |
| | |
| | |
| | data_collator_kwargs["pad_to_multiple_of"] = 64 |
| |
|
| | trainer_cls = self._get_trainer_cls() |
| | trainer_kwargs, trainer_cls = self.hook_pre_create_trainer( |
| | trainer_kwargs, trainer_cls |
| | ) |
| | trainer = trainer_cls( |
| | model=self.model, |
| | train_dataset=self.train_dataset, |
| | eval_dataset=self.eval_dataset, |
| | args=training_args, |
| | tokenizer=self.tokenizer, |
| | data_collator=self.build_collator(training_args, **data_collator_kwargs), |
| | eval_data_collator=self.build_collator( |
| | training_args, is_eval=True, **data_collator_kwargs |
| | ), |
| | bench_data_collator=transformers.DataCollatorForSeq2Seq( |
| | self.tokenizer, |
| | return_tensors="pt", |
| | **data_collator_kwargs, |
| | ), |
| | callbacks=self.get_callbacks(), |
| | num_epochs=self.cfg.num_epochs, |
| | **trainer_kwargs, |
| | ) |
| | trainer = self.hook_post_create_trainer(trainer) |
| | for callback in self.get_post_trainer_create_callbacks(trainer): |
| | trainer.add_callback(callback) |
| |
|
| | if self.cfg.deepspeed and self.cfg.sample_packing: |
| | trainer.accelerator.state.deepspeed_plugin.deepspeed_config[ |
| | "train_micro_batch_size_per_gpu" |
| | ] = self.cfg.micro_batch_size |
| |
|
| | return trainer |
| |
|
| | def build_collator( |
| | self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs |
| | ): |
| | if training_args.pretraining: |
| | return None |
| |
|
| | if self.cfg.model_config_type == "mamba": |
| | return MambaDataCollator(tokenizer=self.tokenizer) |
| |
|
| | use_batch_sampler_collator = False |
| | if is_eval is False and training_args.sample_packing: |
| | use_batch_sampler_collator = True |
| | if is_eval and training_args.eval_sample_packing: |
| | use_batch_sampler_collator = True |
| |
|
| | collator: Type[ |
| | Union[ |
| | V2BatchSamplerDataCollatorForSeq2Seq, |
| | BatchSamplerDataCollatorForSeq2Seq, |
| | DataCollatorForSeq2Seq, |
| | ] |
| | ] |
| | if use_batch_sampler_collator: |
| | if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: |
| | collator = V2BatchSamplerDataCollatorForSeq2Seq |
| | elif ( |
| | self.cfg.model_config_type in ["llama"] |
| | and self.cfg.flash_attention is not True |
| | ): |
| | collator = V2BatchSamplerDataCollatorForSeq2Seq |
| | else: |
| | collator = BatchSamplerDataCollatorForSeq2Seq |
| | else: |
| | collator = DataCollatorForSeq2Seq |
| |
|
| | return collator( |
| | self.tokenizer, |
| | return_tensors="pt", |
| | **kwargs, |
| | ) |
| |
|
| |
|
class HFRLTrainerBuilder(TrainerBuilderBase):
    """
    Trainer factory class for DPO Trainer
    """

    def get_callbacks(self):
        """Return the base callbacks plus the save-on-train-end callback."""
        callbacks = super().get_callbacks()
        callbacks.extend([SaveModelOnTrainEndCallback()])
        return callbacks

    def get_post_trainer_create_callbacks(self, trainer):
        """No trainer-dependent callbacks are registered for RL training."""
        return []

    def build_training_arguments(self, total_num_steps):
        """
        Create the training arguments object for the RL trainers.

        Args:
            total_num_steps: used as `max_steps` when the config does not
                set `max_steps` itself.

        Returns:
            An `ORPOConfig` when `rl` is "orpo", otherwise `TrainingArguments`.
        """
        cfg = self.cfg

        # options copied verbatim whenever the config provides a value
        extra_kwargs = {
            name: getattr(cfg, name)
            for name in [
                "adam_beta1",
                "adam_beta2",
                "adam_epsilon",
                "dataloader_num_workers",
                "dataloader_pin_memory",
            ]
            if hasattr(cfg, name) and getattr(cfg, name) is not None
        }

        if cfg.hub_model_id:
            extra_kwargs.update(
                hub_model_id=cfg.hub_model_id,
                push_to_hub=True,
                hub_private_repo=True,
                hub_always_push=True,
            )
            if cfg.hub_strategy:
                extra_kwargs["hub_strategy"] = cfg.hub_strategy

        if cfg.save_safetensors is not None:
            extra_kwargs["save_safetensors"] = cfg.save_safetensors

        # evaluate by steps only when an eval dataset was provided
        has_eval = bool(self.eval_dataset)
        extra_kwargs["evaluation_strategy"] = "steps" if has_eval else "no"
        if has_eval:
            extra_kwargs["eval_steps"] = cfg.eval_steps

        if cfg.bf16 or cfg.bfloat16:
            extra_kwargs["bf16"] = True

        extra_kwargs["lr_scheduler_type"] = cfg.lr_scheduler or "cosine"
        extra_kwargs["lr_scheduler_kwargs"] = cfg.lr_scheduler_kwargs or {}

        # unlike the library default, unused columns are kept unless configured
        keep_setting = cfg.remove_unused_columns
        extra_kwargs["remove_unused_columns"] = (
            False if keep_setting is None else keep_setting
        )

        for name in (
            "dataloader_pin_memory",
            "dataloader_num_workers",
            "dataloader_prefetch_factor",
        ):
            configured = getattr(cfg, name)
            if configured is not None:
                extra_kwargs[name] = configured

        if cfg.gradient_checkpointing:
            extra_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
            checkpointing_kwargs = cfg.gradient_checkpointing_kwargs
            if checkpointing_kwargs is None:
                # fall back to non-reentrant checkpointing
                checkpointing_kwargs = {"use_reentrant": False}
            extra_kwargs["gradient_checkpointing_kwargs"] = checkpointing_kwargs

        # save_steps takes priority; otherwise the strategy, defaulting to epoch
        if cfg.save_steps:
            extra_kwargs.update(save_strategy="steps", save_steps=cfg.save_steps)
        else:
            extra_kwargs["save_strategy"] = cfg.save_strategy or "epoch"

        if cfg.orpo_alpha:
            # the ORPO alpha is passed through the `beta` argument
            extra_kwargs["beta"] = cfg.orpo_alpha

        is_orpo = cfg.rl == "orpo"
        training_args_cls = ORPOConfig if is_orpo else TrainingArguments
        if is_orpo:
            extra_kwargs["dataset_num_proc"] = cfg.dataset_processes
            extra_kwargs["max_length"] = cfg.sequence_len
            if cfg.max_prompt_len:
                extra_kwargs["max_prompt_length"] = cfg.max_prompt_len

        return training_args_cls(
            per_device_train_batch_size=cfg.micro_batch_size,
            max_steps=cfg.max_steps or total_num_steps,
            gradient_accumulation_steps=cfg.gradient_accumulation_steps,
            learning_rate=cfg.learning_rate,
            output_dir=cfg.output_dir,
            warmup_steps=cfg.warmup_steps,
            logging_first_step=True,
            logging_steps=1,
            optim=cfg.optimizer,
            save_total_limit=cfg.save_total_limit or 5,
            **extra_kwargs,
        )

    def build(self, total_num_steps):
        """
        Create the DPO/IPO/KTO-pair or ORPO trainer described by the config.

        Args:
            total_num_steps: forwarded to `build_training_arguments`.

        Returns:
            The trainer, with the post-create callbacks registered.

        Raises:
            ValueError: if `rl` is not one of the supported methods.
        """
        training_args = self.build_training_arguments(total_num_steps)
        cfg = self.cfg
        rl_method = cfg.rl

        trainer_kwargs = {}
        if rl_method == "ipo":
            trainer_kwargs["loss_type"] = "ipo"
            if cfg.dpo_label_smoothing:
                trainer_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
        elif rl_method == "kto_pair":
            trainer_kwargs["loss_type"] = "kto_pair"

        if self.eval_dataset:
            trainer_kwargs["eval_dataset"] = self.eval_dataset
        if cfg.adapter and self.peft_config:
            trainer_kwargs["peft_config"] = self.peft_config
        if cfg.precompute_ref_log_probs is not None:
            trainer_kwargs["precompute_ref_log_probs"] = cfg.precompute_ref_log_probs

        if rl_method == "orpo":
            trainer_cls = AxolotlORPOTrainer
            positional_args = [self.model]
        elif rl_method in ["dpo", "ipo", "kto_pair"]:
            trainer_cls = AxolotlDPOTrainer
            # the DPO family also takes the reference model
            positional_args = [self.model, self.model_ref]
            trainer_kwargs.update(
                beta=cfg.dpo_beta or 0.1,
                max_length=cfg.sequence_len,
                max_target_length=None,
                max_prompt_length=cfg.sequence_len,
                generate_during_eval=True,
            )
            if rl_method == "dpo":
                trainer_kwargs["dataset_num_proc"] = cfg.dataset_processes
        else:
            raise ValueError(f"Unsupported RL: {self.cfg.rl}")

        dpo_trainer = trainer_cls(
            *positional_args,
            args=training_args,
            train_dataset=self.train_dataset,
            tokenizer=self.tokenizer,
            callbacks=self.get_callbacks(),
            **trainer_kwargs,
        )
        if cfg.fsdp:
            ensure_dtype(dpo_trainer.model, dtype=cfg.torch_dtype)

        dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
        for extra_callback in self.get_post_trainer_create_callbacks(dpo_trainer):
            dpo_trainer.add_callback(extra_callback)

        return dpo_trainer
| |
|
| |
|
class HFPPOTrainerBuilder(TrainerBuilderBase):
    """
    HF Factory class for PPO Trainer
    """

    def get_callbacks(self):
        """Return an empty callback list; the base callbacks are not used."""
        return []

    def get_post_trainer_create_callbacks(self, trainer):
        """Return an empty callback list."""
        return []

    def build(self, total_num_steps):
        """Placeholder: no PPO trainer is built yet, so this returns None."""
        return None
| |
|