| | import collections |
| | import contextlib |
| | import functools |
| | import shutil |
| | import sys |
| | import time |
| | from datetime import timedelta |
| |
|
| | from packaging import version |
| | from accelerate import skip_first_batches, DistributedType, InitProcessGroupKwargs |
| | from transformers import PretrainedConfig |
| | from transformers.trainer import Trainer, TRAINING_ARGS_NAME, TRAINER_STATE_NAME |
| | import torch.distributed as dist |
| | from typing import Optional |
| | import os |
| | import torch |
| | import math |
| |
|
| | from src.data.collator.train_collator import split_vlm_inputs, get_dense_rep, split_and_process_vlm_inputs |
| | from src.model.model_gp import MMEBModel |
| | from src.loss import SimpleContrastiveLoss, DistributedContrastiveLoss |
| | from src.grad_cache.grad_cache import GradCache |
| | from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler |
| |
|
| | from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments |
| | from transformers.trainer_callback import ( |
| | ExportableState, |
| | TrainerState, |
| | ) |
| | from transformers.trainer_utils import ( |
| | TrainOutput, |
| | has_length, |
| | speed_metrics, seed_worker, |
| | ) |
| |
|
| | from transformers.trainer_pt_utils import ( |
| | get_model_param_count, |
| | ) |
| |
|
| | from transformers.trainer import FSDP_MODEL_NAME |
| | from transformers.utils import ( |
| | XLA_FSDPV2_MIN_VERSION, |
| | is_accelerate_available, |
| | is_apex_available, |
| | is_torch_xla_available, |
| | logging, is_sagemaker_mp_enabled, |
| | CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME, |
| | ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME |
| | ) |
| |
|
| | from src.utils import batch_to_device |
| | from src.utils import print_master, print_rank |
| | from src.loss_loc import MaskLoss |
| | from src.model.processor import QWEN2_5_VL_GP |
| |
|
if is_apex_available():
    from apex import amp

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    from torch_xla import __version__ as XLA_VERSION

    # True when the installed torch_xla is at least XLA_FSDPV2_MIN_VERSION.
    # Kept for parity with `transformers.trainer`.
    IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
else:
    # Without torch_xla the flag is always False (and XLA_VERSION is never read).
    IS_XLA_FSDPV2_POST_2_2 = False

logger = logging.get_logger(__name__)
| |
|
def split_mixed_vlm_inputs(model_input: dict, chunk_size: int):
    """Split a single-key GradCache input dict into chunks.

    The first key of ``model_input`` selects the strategy: ``"qry"`` inputs are
    split with ``split_vlm_inputs``; any other key (the trainer passes
    ``{"tgt": ...}`` for targets) goes through ``split_and_process_vlm_inputs``.

    Args:
        model_input: non-empty dict, e.g. ``{"qry": batch}`` or ``{"tgt": batch}``.
        chunk_size: number of examples per chunk, forwarded to the split helper.

    Returns:
        Whatever the selected split helper returns.

    Raises:
        ValueError: if ``model_input`` is empty.
    """
    if not model_input:
        raise ValueError("split_mixed_vlm_inputs expects a non-empty input dict")
    # Only the first key matters; no need to materialise the whole key list.
    key = next(iter(model_input))
    if key == "qry":
        return split_vlm_inputs(model_input, chunk_size)
    return split_and_process_vlm_inputs(model_input, chunk_size)
| |
|
class MMEBTrainer(Trainer):
    """HF ``Trainer`` subclass for query/target embedding training.

    Overrides checkpoint saving/loading, the train dataloader/sampler, loss
    computation and the inner training loop of ``transformers.Trainer``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        distributed = dist.is_initialized()
        self.is_ddp = distributed
        # The base Trainer exposes the processor/tokenizer as `processing_class`.
        self.processor = self.processing_class
        # World size under torch.distributed (1 otherwise); subclasses divide
        # the per-rank loss by it.
        self._dist_loss_scale_factor = dist.get_world_size() if distributed else 1
| |
|
| | def get_batch_samples(self, epoch_iterator, num_batches): |
| | batch_samples = [] |
| | num_items_in_batch = None |
| | for _ in range(num_batches): |
| | try: |
| | batch_samples += [next(epoch_iterator)] |
| | except StopIteration: |
| | break |
| | if len(batch_samples) > 0 and "labels" in batch_samples[0]: |
| | |
| | try: |
| | num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) |
| | except (TypeError, AttributeError): |
| | pass |
| | if self.args.average_tokens_across_devices and num_items_in_batch is not None: |
| | num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() |
| | if torch.is_tensor(num_items_in_batch): |
| | num_items_in_batch = num_items_in_batch.item() |
| | return batch_samples, num_items_in_batch |
| |
|
| | def compute_loss(self, model, inputs, *args, **kwargs): |
| | qry_inputs, tgt_inputs = inputs |
| | return model(qry=qry_inputs, tgt=tgt_inputs) |
| |
|
| | def _save(self, output_dir: Optional[str] = None, state_dict=None): |
| | os.makedirs(output_dir, exist_ok=True) |
| |
|
| | if state_dict is None: |
| | state_dict = self.model.state_dict() |
| | prefix = 'encoder.' |
| | assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys()) |
| | state_dict = {k[len(prefix):]: v for k, v in state_dict.items()} |
| | self.model.encoder.save_pretrained( |
| | output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors |
| | ) |
| |
|
| | if self.tokenizer is not None: |
| | self.tokenizer.save_pretrained(output_dir) |
| |
|
| | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) |
| |
|
| |
|
| | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: |
| | |
| | if self.train_dataset is None or not has_length(self.train_dataset): |
| | return None |
| | return RandomSampler(self.train_dataset) |
| |
|
| | def get_train_dataloader(self) -> DataLoader: |
| | """ |
| | override original trainer's method to disable self.accelerator.prepare since it will wrap DataLoaderDispatcher and lead to |
| | (1) `RuntimeError: You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`.` |
| | (2) all outputs of dataloader must be tensors |
| | """ |
| | if self.train_dataset is None: |
| | raise ValueError("Trainer: training requires a train_dataset.") |
| | train_dataset = self.train_dataset |
| | data_collator = self.data_collator |
| | train_dataset = self._remove_unused_columns(train_dataset, description="training") |
| | dataloader_params = { |
| | "batch_size": self._train_batch_size, |
| | "collate_fn": data_collator, |
| | "num_workers": self.args.dataloader_num_workers, |
| | "pin_memory": self.args.dataloader_pin_memory, |
| | "persistent_workers": self.args.dataloader_persistent_workers, |
| | } |
| | if not isinstance(train_dataset, torch.utils.data.IterableDataset): |
| | dataloader_params["sampler"] = self._get_train_sampler() |
| | dataloader_params["drop_last"] = self.args.dataloader_drop_last |
| | dataloader_params["worker_init_fn"] = seed_worker |
| | dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor |
| | else: |
| | dataloader_params["sampler"] = None |
| | dataloader_params["shuffle"] = False |
| | dataloader_params["drop_last"] = True |
| | dataloader_params["prefetch_factor"] = None |
| | return DataLoader(train_dataset, **dataloader_params) |
| |
|
| | def _load_from_checkpoint(self, resume_from_checkpoint, model=None): |
| | self.model_args.checkpoint_path = resume_from_checkpoint |
| | logger.info(f"Loading checkpoint from {resume_from_checkpoint}") |
| | self.model = MMEBModel.load(self.model_args) |
| | self.model_wrapped = self.model |
| |
|
| | def _inner_training_loop( |
| | self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None |
| | ): |
| | self.accelerator.free_memory() |
| | self._train_batch_size = batch_size |
| | if self.args.auto_find_batch_size: |
| | if self.state.train_batch_size != self._train_batch_size: |
| | from accelerate.utils import release_memory |
| |
|
| | (self.model_wrapped,) = release_memory(self.model_wrapped) |
| | self.model_wrapped = self.model |
| |
|
| | |
| | if self.is_deepspeed_enabled: |
| | |
| | original_bs = self.args.per_device_train_batch_size |
| | self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) |
| | self.propagate_args_to_deepspeed(True) |
| | self.args.per_device_train_batch_size = original_bs |
| | self.state.train_batch_size = self._train_batch_size |
| | logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") |
| | |
| | train_dataloader = self.get_train_dataloader() |
| |
|
| | |
| | |
| | |
| | |
| | total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size |
| |
|
| | len_dataloader = None |
| | num_train_tokens = None |
| | if has_length(train_dataloader): |
| | len_dataloader = len(train_dataloader) |
| | num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps |
| | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) |
| | num_examples = self.num_examples(train_dataloader) |
| | if args.max_steps > 0: |
| | max_steps = args.max_steps |
| | num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( |
| | args.max_steps % num_update_steps_per_epoch > 0 |
| | ) |
| | |
| | |
| | num_train_samples = args.max_steps * total_train_batch_size |
| | if args.include_tokens_per_second: |
| | num_train_tokens = ( |
| | self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps |
| | ) |
| | else: |
| | max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) |
| | num_train_epochs = math.ceil(args.num_train_epochs) |
| | num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs |
| | if args.include_tokens_per_second: |
| | num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs |
| | elif args.max_steps > 0: |
| | max_steps = args.max_steps |
| | |
| | num_train_epochs = sys.maxsize |
| | num_update_steps_per_epoch = max_steps |
| | num_examples = total_train_batch_size * args.max_steps |
| | num_train_samples = args.max_steps * total_train_batch_size |
| | if args.include_tokens_per_second: |
| | num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps |
| | else: |
| | raise ValueError( |
| | "args.max_steps must be set to a positive value if dataloader does not have a length, was" |
| | f" {args.max_steps}" |
| | ) |
| |
|
| | delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled |
| |
|
| | |
| | if self._created_lr_scheduler: |
| | self.lr_scheduler = None |
| | self._created_lr_scheduler = False |
| |
|
| | self.create_optimizer_and_scheduler(num_training_steps=max_steps) |
| |
|
| | self.state = TrainerState( |
| | stateful_callbacks=[ |
| | cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) |
| | ] |
| | ) |
| | self.state.is_hyper_param_search = trial is not None |
| | self.state.train_batch_size = self._train_batch_size |
| |
|
| | |
| | if args.logging_steps is not None: |
| | if args.logging_steps < 1: |
| | self.state.logging_steps = math.ceil(max_steps * args.logging_steps) |
| | else: |
| | self.state.logging_steps = args.logging_steps |
| | if args.eval_steps is not None: |
| | if args.eval_steps < 1: |
| | self.state.eval_steps = math.ceil(max_steps * args.eval_steps) |
| | else: |
| | self.state.eval_steps = args.eval_steps |
| | if args.save_steps is not None: |
| | if args.save_steps < 1: |
| | self.state.save_steps = math.ceil(max_steps * args.save_steps) |
| | else: |
| | self.state.save_steps = args.save_steps |
| |
|
| | |
| | if args.gradient_checkpointing: |
| | self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) |
| |
|
| | model = self._wrap_model(self.model_wrapped) |
| |
|
| | |
| | |
| | |
| | use_accelerator_prepare = True if model is self.model else False |
| |
|
| | if delay_optimizer_creation: |
| | if use_accelerator_prepare: |
| | self._fsdp_qlora_plugin_updates() |
| | self.model = self.accelerator.prepare(self.model) |
| | self.create_optimizer_and_scheduler(num_training_steps=max_steps) |
| |
|
| | |
| | if use_accelerator_prepare: |
| | self.model.train() |
| | if hasattr(self.lr_scheduler, "step"): |
| | if self.use_apex: |
| | model = self.accelerator.prepare(self.model) |
| | else: |
| | model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) |
| | else: |
| | |
| | model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( |
| | self.model, self.optimizer, self.lr_scheduler |
| | ) |
| | elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: |
| | |
| | self.optimizer = self.accelerator.prepare(self.optimizer) |
| |
|
| | if self.is_fsdp_enabled: |
| | self.model = self.model_wrapped = model |
| |
|
| | |
| | if model is not self.model: |
| | self.model_wrapped = model |
| |
|
| | |
| | if self.is_deepspeed_enabled: |
| | self.deepspeed = self.model_wrapped |
| |
|
| | |
| | self._load_optimizer_and_scheduler(resume_from_checkpoint) |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | logger.info("***** Running training *****") |
| | logger.info(f" Num examples = {num_examples:,}") |
| | logger.info(f" Num Epochs = {num_train_epochs:,}") |
| | logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") |
| | if self.args.per_device_train_batch_size != self._train_batch_size: |
| | logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") |
| | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") |
| | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") |
| | logger.info(f" Total optimization steps = {max_steps:,}") |
| | logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") |
| |
|
| | self.state.epoch = 0 |
| | start_time = time.time() |
| | epochs_trained = 0 |
| | steps_trained_in_current_epoch = 0 |
| | steps_trained_progress_bar = None |
| |
|
| | |
| | |
| |
|
| | |
| | if resume_from_checkpoint is not None and os.path.isfile( |
| | os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) |
| | ): |
| | self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) |
| | self.compare_trainer_and_checkpoint_args(self.args, self.state) |
| | self._load_callback_state() |
| | epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) |
| | if not args.ignore_data_skip: |
| | steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) |
| | steps_trained_in_current_epoch *= args.gradient_accumulation_steps |
| | else: |
| | steps_trained_in_current_epoch = 0 |
| |
|
| | logger.info(" Continuing training from checkpoint, will skip to saved global_step") |
| | logger.info(f" Continuing training from epoch {epochs_trained}") |
| | logger.info(f" Continuing training from global step {self.state.global_step}") |
| | if not args.ignore_data_skip: |
| | logger.info( |
| | f" Will skip the first {epochs_trained} epochs then the first" |
| | f" {steps_trained_in_current_epoch} batches in the first epoch." |
| | ) |
| |
|
| | |
| | self.callback_handler.model = self.model |
| | self.callback_handler.optimizer = self.optimizer |
| | self.callback_handler.lr_scheduler = self.lr_scheduler |
| | self.callback_handler.train_dataloader = train_dataloader |
| | |
| | |
| | self.state.max_steps = max_steps |
| | self.state.num_train_epochs = num_train_epochs |
| | self.state.is_local_process_zero = self.is_local_process_zero() |
| | self.state.is_world_process_zero = self.is_world_process_zero() |
| |
|
| | |
| | tr_loss = torch.tensor(0.0).to(args.device) |
| | |
| | self._total_loss_scalar = 0.0 |
| | self._globalstep_last_logged = self.state.global_step |
| | model.zero_grad() |
| | grad_norm: Optional[float] = None |
| | self.control = self.callback_handler.on_train_begin(args, self.state, self.control) |
| |
|
| | if args.eval_on_start: |
| | self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) |
| |
|
| | total_batched_samples = 0 |
| | for epoch in range(epochs_trained, num_train_epochs): |
| | epoch_dataloader = train_dataloader |
| | if hasattr(epoch_dataloader.dataset, "set_epoch"): |
| | |
| | epoch_dataloader.dataset.set_epoch(epoch) |
| |
|
| | |
| | if args.past_index >= 0: |
| | self._past = None |
| |
|
| | steps_in_epoch = ( |
| | len(epoch_dataloader) |
| | if len_dataloader is not None |
| | else args.max_steps * args.gradient_accumulation_steps |
| | ) |
| | self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) |
| |
|
| | if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: |
| | self._load_rng_state(resume_from_checkpoint) |
| |
|
| | rng_to_sync = False |
| | steps_skipped = 0 |
| | if steps_trained_in_current_epoch > 0: |
| | epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) |
| | steps_skipped = steps_trained_in_current_epoch |
| | steps_trained_in_current_epoch = 0 |
| | rng_to_sync = True |
| |
|
| | step = -1 |
| | epoch_iterator = iter(epoch_dataloader) |
| | |
| | remainder = num_examples % args.gradient_accumulation_steps |
| | num_items_in_batch = None |
| | if remainder == 0: |
| | remainder = args.gradient_accumulation_steps |
| | update_step = -1 |
| | total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 |
| | for _ in range(total_updates): |
| | update_step += 1 |
| | num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder |
| | batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) |
| | for i, inputs in enumerate(batch_samples): |
| | step += 1 |
| | total_batched_samples += 1 |
| |
|
| | dataset_stat = collections.Counter(inputs[0]['global_dataset_name']) |
| | |
| | |
| | |
| |
|
| | is_last_step_and_steps_less_than_grad_acc = ( |
| | steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch |
| | ) |
| | do_sync_step = is_last_step_and_steps_less_than_grad_acc or ( |
| | total_batched_samples % args.gradient_accumulation_steps == 0 |
| | ) |
| | |
| | if not do_sync_step: |
| | self.accelerator.gradient_state._set_sync_gradients(False) |
| | else: |
| | self.accelerator.gradient_state._set_sync_gradients(True) |
| |
|
| | if self.args.include_num_input_tokens_seen: |
| | main_input_name = getattr(self.model, "main_input_name", "input_ids") |
| | if main_input_name not in inputs: |
| | logger.warning( |
| | "Tried to track the number of tokens seen, however the current model is " |
| | "not configured properly to know what item is the input. To fix this, add " |
| | "a `main_input_name` attribute to the model class you are using." |
| | ) |
| | else: |
| | input_tokens = inputs[main_input_name].numel() |
| | input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) |
| | self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item() |
| | if rng_to_sync: |
| | self._load_rng_state(resume_from_checkpoint) |
| | rng_to_sync = False |
| |
|
| | |
| | if steps_trained_in_current_epoch > 0: |
| | steps_trained_in_current_epoch -= 1 |
| | if steps_trained_progress_bar is not None: |
| | steps_trained_progress_bar.update(1) |
| | if steps_trained_in_current_epoch == 0: |
| | self._load_rng_state(resume_from_checkpoint) |
| | continue |
| | elif steps_trained_progress_bar is not None: |
| | steps_trained_progress_bar.close() |
| | steps_trained_progress_bar = None |
| |
|
| | if step % args.gradient_accumulation_steps == 0: |
| | self.control = self.callback_handler.on_step_begin(args, self.state, self.control) |
| |
|
| | |
| | context = ( |
| | functools.partial(self.accelerator.no_sync, model=model) |
| | if i != len(batch_samples) - 1 |
| | else contextlib.nullcontext |
| | ) |
| | with context(): |
| | tr_loss_step = self.training_step(model, inputs, num_items_in_batch) |
| |
|
| | if ( |
| | args.logging_nan_inf_filter |
| | and not is_torch_xla_available() |
| | and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) |
| | ): |
| | |
| | tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) |
| | else: |
| | if tr_loss.device != tr_loss_step.device: |
| | raise ValueError( |
| | f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" |
| | ) |
| | tr_loss = tr_loss + tr_loss_step |
| |
|
| | self.current_flos += float(self.floating_point_ops(inputs)) |
| |
|
| | if do_sync_step: |
| | |
| | self.accelerator.gradient_state._set_sync_gradients(True) |
| |
|
| | |
| | if args.max_grad_norm is not None and args.max_grad_norm > 0: |
| | |
| |
|
| | if self.use_apex: |
| | |
| | _grad_norm = torch.nn.utils.clip_grad_norm_( |
| | amp.master_params(self.optimizer), |
| | args.max_grad_norm, |
| | ) |
| | else: |
| | _grad_norm = self.accelerator.clip_grad_norm_( |
| | model.parameters(), |
| | args.max_grad_norm, |
| | ) |
| |
|
| | if ( |
| | is_accelerate_available() |
| | and self.accelerator.distributed_type == DistributedType.DEEPSPEED |
| | ): |
| | grad_norm = model.get_global_grad_norm() |
| | |
| | if hasattr(grad_norm, "item"): |
| | grad_norm = grad_norm.item() |
| | else: |
| | grad_norm = _grad_norm |
| |
|
| | self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) |
| |
|
| | self.optimizer.step() |
| |
|
| | self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) |
| |
|
| | optimizer_was_run = not self.accelerator.optimizer_step_was_skipped |
| | if optimizer_was_run: |
| | |
| | if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): |
| | self.lr_scheduler.step() |
| |
|
| | model.zero_grad() |
| | self.state.global_step += 1 |
| | self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch |
| | self.control = self.callback_handler.on_step_end(args, self.state, self.control) |
| | self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, time.time()) |
| | else: |
| | self.control = self.callback_handler.on_substep_end(args, self.state, self.control) |
| |
|
| | |
| | |
| | |
| | if self.control.should_epoch_stop or self.control.should_training_stop: |
| | if is_torch_xla_available(): |
| | xm.mark_step() |
| | break |
| | |
| | if self.control.should_epoch_stop or self.control.should_training_stop: |
| | if is_torch_xla_available(): |
| | xm.mark_step() |
| | break |
| | if step < 0: |
| | logger.warning( |
| | "There seems not to be a single sample in your epoch_iterator, stopping training at step" |
| | f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" |
| | f" num_steps ({max_steps}) higher than the number of available samples." |
| | ) |
| | self.control.should_training_stop = True |
| |
|
| | self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) |
| | self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, time.time()) |
| |
|
| | if self.control.should_training_stop: |
| | break |
| |
|
| | if args.past_index and hasattr(self, "_past"): |
| | |
| | delattr(self, "_past") |
| |
|
| | logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") |
| | if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: |
| | |
| | if is_torch_xla_available(): |
| | xm.rendezvous("load_best_model_at_end") |
| | elif args.parallel_mode == ParallelMode.DISTRIBUTED: |
| | dist.barrier() |
| |
|
| | self._load_best_model() |
| |
|
| | |
| | self._total_loss_scalar += tr_loss.item() |
| | effective_global_step = max(self.state.global_step, 0.001) |
| | train_loss = self._total_loss_scalar / effective_global_step |
| |
|
| | metrics = speed_metrics( |
| | "train", |
| | start_time, |
| | num_samples=num_train_samples, |
| | num_steps=self.state.max_steps, |
| | num_tokens=num_train_tokens, |
| | ) |
| | self.store_flos() |
| | metrics["total_flos"] = self.state.total_flos |
| | metrics["train_loss"] = train_loss |
| |
|
| | self.is_in_train = False |
| |
|
| | self._memory_tracker.stop_and_update_metrics(metrics) |
| |
|
| | self.log(metrics) |
| |
|
| | run_dir = self._get_output_dir(trial) |
| | checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) |
| |
|
| | |
| | if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: |
| | for checkpoint in checkpoints_sorted: |
| | if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): |
| | logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") |
| | shutil.rmtree(checkpoint, ignore_errors=True) |
| |
|
| | self.control = self.callback_handler.on_train_end(args, self.state, self.control) |
| |
|
| | |
| | self._finish_current_push() |
| |
|
| | |
| | |
| | if self.neftune_noise_alpha is not None: |
| | self._deactivate_neftune(self.model) |
| |
|
| | return TrainOutput(self.state.global_step, train_loss, metrics) |
| |
|
| |
|
class GradCacheLateProcessTrainer(MMEBTrainer):
    """
    Adapted from gradcache repo.

    ``MMEBTrainer`` whose training step computes the query/target contrastive
    loss through ``GradCache`` (``self.gc``) and can add auxiliary mask losses.

    Extra constructor kwargs (removed before ``Trainer.__init__`` sees them):
        max_length: stored as ``self.max_length`` (default 512).
        model_args: stored as ``self.model_args`` (default ``None``); needed to
            reload the model when resuming from a checkpoint.
    """

    def __init__(self, *args, **kwargs):
        # Pop our own kwargs so the HF Trainer constructor never receives them.
        self.max_length = kwargs.pop("max_length", 512)
        self.model_args = kwargs.pop("model_args", None)
        super().__init__(*args, **kwargs)
        # `is_ddp` and `_dist_loss_scale_factor` are already set by MMEBTrainer.__init__.

        loss_fn_cls = DistributedContrastiveLoss if self.is_ddp else SimpleContrastiveLoss
        loss_fn = loss_fn_cls(temperature=self.model.temperature)
        self._mask_loss_fn = MaskLoss(dice_weight=self.args.loc_dice_weight, bce_weight=self.args.loc_bce_weight)

        # GradCache needs an AMP GradScaler under fp16. Recent HF Trainer versions
        # no longer define `self.scaler`; Accelerate keeps it on `accelerator.scaler`.
        scaler = None
        if self.args.fp16:
            scaler = getattr(self, "scaler", None)
            if scaler is None:
                scaler = getattr(self.accelerator, "scaler", None)

        self.gc = GradCache(
            models=[self.model, self.model],
            chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size],
            loss_fn=loss_fn,
            # Dispatches on the input key: "qry" vs. "tgt" are split differently.
            split_input_fn=split_mixed_vlm_inputs,
            get_rep_fn=get_dense_rep,
            fp16=self.args.fp16,
            scaler=scaler,
        )
| |
|
| | def _should_train_retrieval(self) -> bool: |
| | """ |
| | 当有以下任一条件时才训练检索对比分支: |
| | - 未冻结基座 (--image_encoder_freeze False) |
| | - 启用了 LoRA (--lora True) |
| | - 显式要求训练检索 (--train_retrieval True,可选) |
| | """ |
| | has_unfrozen = not getattr(self.args, "image_encoder_freeze", False) |
| | has_lora = getattr(self.args, "lora", False) |
| | explicit = getattr(self.args, "train_retrieval", False) |
| | return has_unfrozen or has_lora or explicit |
| |
|
| |
|
| | def _make_loss_require_grad(self, loss: torch.Tensor, model: torch.nn.Module) -> torch.Tensor: |
| | """ |
| | 若 loss 不带 grad(例如该 batch 没有可训练信号), |
| | 通过加 0 * (任一可训练参数和) 的方式让 loss 拥有 grad_fn,避免 backward 报错。 |
| | """ |
| | if loss.requires_grad: |
| | return loss |
| | |
| | try: |
| | any_trainable = next(p for p in model.parameters() if p.requires_grad) |
| | return loss + 0.0 * any_trainable.sum() |
| | except StopIteration: |
| | |
| | return loss |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor: |
| | model.train() |
| | queries, targets = inputs |
| | queries = batch_to_device(queries, model.device) |
| | targets = batch_to_device(targets, model.device) |
| |
|
| | queries, targets = {'qry': queries}, {'tgt': targets} |
| |
|
| | |
| | loss_retrieval = torch.tensor(0.0, device=model.device) |
| | if self._should_train_retrieval(): |
| | _distributed = self.args.local_rank > -1 |
| | if _distributed: |
| | self.gc.models = [model, model] |
| | loss_retrieval = self.gc(queries, targets, no_sync_except_last=_distributed) |
| | else: |
| | loss_retrieval = model(queries, targets) |
| |
|
| | |
| | loss_loc = torch.tensor(0.0, device=model.device) |
| | loss_le = torch.tensor(0.0, device=model.device) |
| |
|
| | if self.args.loc_weight > 0 and hasattr(self, "gp_loader"): |
| | |
| | if (self.state.global_step % self.args.gp_aux_every) == 0: |
| | try: |
| | gp_batch = next(self.gp_iter) |
| | except StopIteration: |
| | self.gp_iter = iter(self.gp_loader) |
| | gp_batch = next(self.gp_iter) |
| |
|
| | gp_batch = batch_to_device(gp_batch, model.device) |
| |
|
| | out = None |
| | try: |
| | |
| | |
| | |
| | out = self.accelerator.unwrap_model(model).encoder_qry( |
| | **gp_batch, |
| | return_dict=True, |
| | do_selection=True, |
| | delay_selection=True, |
| | use_cache=False, |
| | output_hidden_states=False, |
| | use_ref_masks=False, |
| | mask_head_only=True, |
| | ) |
| | mask_logits_all = out.image_token_mask_logits |
| | ref_masks = gp_batch.get("ref_token_masks", None) |
| | if ref_masks is not None: |
| | mask_logits_last = [m[-1] for m in mask_logits_all] |
| | |
| | flt_logits, flt_gts, trivial_cnt = [], [], 0 |
| | for pred, gt in zip(mask_logits_last, ref_masks): |
| | pv = gt.float().mean().item() |
| | if pv > 0.98 or pv < 0.02: |
| | trivial_cnt += 1 |
| | continue |
| | flt_logits.append(pred) |
| | flt_gts.append(gt) |
| | if len(flt_logits) > 0: |
| | loss_loc = self._mask_loss_fn(flt_logits, flt_gts) * self.args.loc_weight |
| | if self.state.global_step % self.args.logging_steps == 0 and self.is_world_process_zero(): |
| | print(f"[loc/qry(gp)] trivial_gt={trivial_cnt}/{len(mask_logits_last)}") |
| |
|
| | if self.args.le_weight > 0 and getattr(out, "le_loss", None) is not None: |
| | loss_le = out.le_loss * self.args.le_weight |
| | finally: |
| | try: |
| | self.accelerator.unwrap_model(model).encoder_qry.reset_image_tokens_cache() |
| | except Exception: |
| | pass |
| | del out |
| | |
| | |
| | loss = (loss_retrieval + loss_loc + loss_le) / self._dist_loss_scale_factor |
| |
|
| | |
| | if loss.requires_grad: |
| | self.accelerator.backward(loss) |
| |
|
| | |
| | return loss.detach() |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | def _load_from_checkpoint(self, resume_from_checkpoint, model=None): |
| | logger.info(f"Loading checkpoint from {resume_from_checkpoint}") |
| |
|
| | |
| | qry_gp_modules_dir = os.path.join(resume_from_checkpoint, "qry_gp_modules") |
| | config_qry_path = os.path.join(resume_from_checkpoint, "config_qry.json") |
| | config_tgt_path = os.path.join(resume_from_checkpoint, "config_tgt.json") |
| |
|
| | if os.path.isdir(qry_gp_modules_dir) and os.path.exists(config_qry_path) and os.path.exists(config_tgt_path): |
| | |
| | from src.model.model_gp import AsymMMEBModel |
| | import copy |
| |
|
| | |
| | ma = copy.deepcopy(self.model_args) |
| | setattr(ma, "new_modules_dir", qry_gp_modules_dir) |
| |
|
| | |
| | self.model = AsymMMEBModel.load_asym(ma, processor=self.processing_class, is_trainable=True) |
| |
|
| | |
| | self.model.gp_do_selection_qry = getattr(self.args, "gp_do_selection", False) |
| |
|
| | |
| | try: |
| | from peft import PeftModel |
| | if isinstance(self.model.encoder_qry, PeftModel): |
| | self.model.encoder_qry = self.model.encoder_qry.merge_and_unload() |
| | except Exception: |
| | pass |
| |
|
| | self.model_wrapped = self.model |
| | logger.info("[load] AsymMMEBModel restored from checkpoint (qry_gp_modules/config_qry.json/config_tgt.json)") |
| | return |
| |
|
| | |
| | |
| | |
| | self.model_args.checkpoint_path = resume_from_checkpoint |
| | logger.info("Loading fallback single-encoder model via MMEBModel.load(...)") |
| | self.model = MMEBModel.load(self.model_args) |
| | self.model_wrapped = self.model |
| | |
| | |
| | def _save(self, output_dir: Optional[str] = None, state_dict=None): |
| | os.makedirs(output_dir, exist_ok=True) |
| |
|
| | |
| | if hasattr(self.model, "encoder_qry") and hasattr(self.model, "encoder_tgt"): |
| | |
| | if hasattr(self.model.encoder_qry, "save_new_modules"): |
| | |
| | gp_out_dir = os.path.join(output_dir, "qry_gp_modules") |
| | os.makedirs(gp_out_dir, exist_ok=True) |
| | self.model.encoder_qry.save_new_modules(gp_out_dir) |
| | |
| |
|
| | |
| | try: |
| | self.model.encoder_qry.config.to_json_file(os.path.join(output_dir, "config_qry.json")) |
| | except Exception as e: |
| | print_master(f"[warn] save config_qry.json failed: {e}") |
| | try: |
| | self.model.encoder_tgt.config.to_json_file(os.path.join(output_dir, "config_tgt.json")) |
| | except Exception as e: |
| | print_master(f"[warn] save config_tgt.json failed: {e}") |
| |
|
| | |
| | if self.tokenizer is not None: |
| | self.tokenizer.save_pretrained(output_dir) |
| |
|
| | |
| | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) |
| | print_master(f"[save] asym model saved to {output_dir} (qry_gp_modules, config_qry.json, config_tgt.json)") |
| |
|
| | else: |
| | |
| | |
| | if state_dict is None: |
| | state_dict = self.model.state_dict() |
| | prefix = "encoder." |
| | assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys()) |
| | state_dict = {k[len(prefix):]: v for k, v in state_dict.items()} |
| |
|
| | self.model.encoder.save_pretrained( |
| | output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors |
| | ) |
| | if self.tokenizer is not None: |
| | self.tokenizer.save_pretrained(output_dir) |
| | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) |
| | try: |
| | self.model.encoder.config.to_json_file(os.path.join(output_dir, "config.json")) |
| | except Exception: |
| | pass |
| | print_master(f"[save] single-encoder model saved to {output_dir}") |
| |
|