|
|
import collections |
|
|
import contextlib |
|
|
import functools |
|
|
import shutil |
|
|
import sys |
|
|
import time |
|
|
from datetime import timedelta |
|
|
|
|
|
from packaging import version |
|
|
from accelerate import skip_first_batches, DistributedType, InitProcessGroupKwargs |
|
|
from transformers import PretrainedConfig |
|
|
from transformers.trainer import Trainer, TRAINING_ARGS_NAME, TRAINER_STATE_NAME |
|
|
import torch.distributed as dist |
|
|
from typing import Optional |
|
|
import os |
|
|
import torch |
|
|
import math |
|
|
|
|
|
from src.data.collator.train_collator import split_vlm_inputs, get_dense_rep, split_and_process_vlm_inputs |
|
|
from src.model.model_multi_layer_AOP import MMEBModel |
|
|
from src.loss_multi_layer import SimpleContrastiveLoss, DistributedContrastiveLoss |
|
|
from src.grad_cache.grad_cache import GradCache |
|
|
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler |
|
|
|
|
|
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments |
|
|
from transformers.trainer_callback import ( |
|
|
ExportableState, |
|
|
TrainerState, |
|
|
) |
|
|
from transformers.trainer_utils import ( |
|
|
TrainOutput, |
|
|
has_length, |
|
|
speed_metrics, seed_worker, |
|
|
) |
|
|
|
|
|
from transformers.trainer_pt_utils import ( |
|
|
get_model_param_count, |
|
|
) |
|
|
|
|
|
from transformers.trainer import FSDP_MODEL_NAME |
|
|
from transformers.utils import ( |
|
|
XLA_FSDPV2_MIN_VERSION, |
|
|
is_accelerate_available, |
|
|
is_apex_available, |
|
|
is_torch_xla_available, |
|
|
logging, is_sagemaker_mp_enabled, |
|
|
CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME, |
|
|
ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME |
|
|
) |
|
|
|
|
|
from src.utils import batch_to_device |
|
|
from src.utils import print_master, print_rank |
|
|
|
|
|
if is_apex_available(): |
|
|
from apex import amp |
|
|
|
|
|
if is_torch_xla_available(): |
|
|
import torch_xla.core.xla_model as xm |
|
|
from torch_xla import __version__ as XLA_VERSION |
|
|
|
|
|
    IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
else:
    IS_XLA_FSDPV2_POST_2_2 = False
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class MMEBTrainer(Trainer): |
|
|
def __init__(self, *args, **kwargs): |
|
|
super(MMEBTrainer, self).__init__(*args, **kwargs) |
|
|
self.is_ddp = dist.is_initialized() |
|
|
self.processor = self.processing_class |
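        # World-size scale factor; the GradCache subclass divides its loss by this under DDP.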
|
|
self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1 |
|
|
|
|
|
def get_batch_samples(self, epoch_iterator, num_batches): |
|
|
batch_samples = [] |
|
|
num_items_in_batch = None |
|
|
for _ in range(num_batches): |
|
|
try: |
|
|
batch_samples += [next(epoch_iterator)] |
|
|
except StopIteration: |
|
|
break |
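        # If labels are present, count the tokens that are not -100 (the ignore index) so the loss
        # can be normalized per token across the accumulated micro-batches.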
|
|
if len(batch_samples) > 0 and "labels" in batch_samples[0]: |
|
|
|
|
|
try: |
|
|
num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) |
|
|
except (TypeError, AttributeError): |
|
|
pass |
|
|
if self.args.average_tokens_across_devices and num_items_in_batch is not None: |
|
|
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() |
|
|
if torch.is_tensor(num_items_in_batch): |
|
|
num_items_in_batch = num_items_in_batch.item() |
|
|
return batch_samples, num_items_in_batch |
|
|
|
|
|
def compute_loss(self, model, inputs, *args, **kwargs): |
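        # `inputs` is the (query_inputs, target_inputs) pair produced by the train collator;
        # the model computes the contrastive loss internally and returns it.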
|
|
qry_inputs, tgt_inputs = inputs |
|
|
return model(qry=qry_inputs, tgt=tgt_inputs) |
|
|
|
|
|
def _save(self, output_dir: Optional[str] = None, state_dict=None): |
|
|
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
|
|
|
|
|
if state_dict is None: |
|
|
state_dict = self.model.state_dict() |
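        # MMEBModel stores the backbone under `encoder.`; strip that prefix so the weights can be
        # saved (and later reloaded) directly via the encoder's save_pretrained/from_pretrained.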
|
|
prefix = 'encoder.' |
|
|
assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys()) |
|
|
state_dict = {k[len(prefix):]: v for k, v in state_dict.items()} |
|
|
self.model.encoder.save_pretrained( |
|
|
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors |
|
|
) |
|
|
|
|
|
if self.tokenizer is not None: |
|
|
self.tokenizer.save_pretrained(output_dir) |
|
|
|
|
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) |
|
|
|
|
|
|
|
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: |
|
|
|
|
|
if self.train_dataset is None or not has_length(self.train_dataset): |
|
|
return None |
|
|
return RandomSampler(self.train_dataset) |
|
|
|
|
|
def get_train_dataloader(self) -> DataLoader: |
|
|
""" |
|
|
        Override the base Trainer method and skip self.accelerator.prepare: wrapping the loader in
        DataLoaderDispatcher would (1) raise `RuntimeError: You can't use batches of different size
        with dispatch_batches=True or when using an IterableDataset.` and (2) require every item the
        dataloader yields to be a tensor.
|
|
""" |
|
|
if self.train_dataset is None: |
|
|
raise ValueError("Trainer: training requires a train_dataset.") |
|
|
train_dataset = self.train_dataset |
|
|
data_collator = self.data_collator |
|
|
train_dataset = self._remove_unused_columns(train_dataset, description="training") |
|
|
dataloader_params = { |
|
|
"batch_size": self._train_batch_size, |
|
|
"collate_fn": data_collator, |
|
|
"num_workers": self.args.dataloader_num_workers, |
|
|
"pin_memory": self.args.dataloader_pin_memory, |
|
|
"persistent_workers": self.args.dataloader_persistent_workers, |
|
|
} |
|
|
if not isinstance(train_dataset, torch.utils.data.IterableDataset): |
|
|
dataloader_params["sampler"] = self._get_train_sampler() |
|
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last |
|
|
dataloader_params["worker_init_fn"] = seed_worker |
|
|
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor |
|
|
else: |
|
|
dataloader_params["sampler"] = None |
|
|
dataloader_params["shuffle"] = False |
|
|
dataloader_params["drop_last"] = True |
|
|
dataloader_params["prefetch_factor"] = None |
|
|
return DataLoader(train_dataset, **dataloader_params) |
|
|
|
|
|
def _load_from_checkpoint(self, resume_from_checkpoint, model=None): |
|
|
self.model_args.checkpoint_path = resume_from_checkpoint |
|
|
logger.info(f"Loading checkpoint from {resume_from_checkpoint}") |
|
|
self.model = MMEBModel.load(self.model_args) |
|
|
self.model_wrapped = self.model |
|
|
|
|
|
def _inner_training_loop( |
|
|
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None |
|
|
): |
|
|
self.accelerator.free_memory() |
|
|
self._train_batch_size = batch_size |
|
|
if self.args.auto_find_batch_size: |
|
|
if self.state.train_batch_size != self._train_batch_size: |
|
|
from accelerate.utils import release_memory |
|
|
|
|
|
(self.model_wrapped,) = release_memory(self.model_wrapped) |
|
|
self.model_wrapped = self.model |
|
|
|
|
|
|
|
|
if self.is_deepspeed_enabled: |
|
|
|
|
|
original_bs = self.args.per_device_train_batch_size |
|
|
self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) |
|
|
self.propagate_args_to_deepspeed(True) |
|
|
self.args.per_device_train_batch_size = original_bs |
|
|
self.state.train_batch_size = self._train_batch_size |
|
|
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") |
|
|
|
|
|
train_dataloader = self.get_train_dataloader() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size |
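        # Effective batch size per optimizer step: per-device batch * grad-accumulation steps * world size
        # (e.g. 8 * 4 * 2 GPUs = 64; illustrative numbers only).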
|
|
|
|
|
len_dataloader = None |
|
|
num_train_tokens = None |
|
|
if has_length(train_dataloader): |
|
|
len_dataloader = len(train_dataloader) |
|
|
num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps |
|
|
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) |
|
|
num_examples = self.num_examples(train_dataloader) |
|
|
if args.max_steps > 0: |
|
|
max_steps = args.max_steps |
|
|
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( |
|
|
args.max_steps % num_update_steps_per_epoch > 0 |
|
|
) |
|
|
|
|
|
|
|
|
num_train_samples = args.max_steps * total_train_batch_size |
|
|
if args.include_tokens_per_second: |
|
|
num_train_tokens = ( |
|
|
self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps |
|
|
) |
|
|
else: |
|
|
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) |
|
|
num_train_epochs = math.ceil(args.num_train_epochs) |
|
|
num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs |
|
|
if args.include_tokens_per_second: |
|
|
num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs |
|
|
elif args.max_steps > 0: |
|
|
max_steps = args.max_steps |
|
|
|
|
|
num_train_epochs = sys.maxsize |
|
|
num_update_steps_per_epoch = max_steps |
|
|
num_examples = total_train_batch_size * args.max_steps |
|
|
num_train_samples = args.max_steps * total_train_batch_size |
|
|
if args.include_tokens_per_second: |
|
|
num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps |
|
|
else: |
|
|
raise ValueError( |
|
|
"args.max_steps must be set to a positive value if dataloader does not have a length, was" |
|
|
f" {args.max_steps}" |
|
|
) |
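        # Under SageMaker MP or FSDP the model must be wrapped before the optimizer is built, so
        # optimizer creation is deferred until after accelerator.prepare (delay_optimizer_creation below).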
|
|
|
|
|
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled |
|
|
|
|
|
|
|
|
if self._created_lr_scheduler: |
|
|
self.lr_scheduler = None |
|
|
self._created_lr_scheduler = False |
|
|
|
|
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps) |
|
|
|
|
|
self.state = TrainerState( |
|
|
stateful_callbacks=[ |
|
|
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) |
|
|
] |
|
|
) |
|
|
self.state.is_hyper_param_search = trial is not None |
|
|
self.state.train_batch_size = self._train_batch_size |
|
|
|
|
|
|
|
|
if args.logging_steps is not None: |
|
|
if args.logging_steps < 1: |
|
|
self.state.logging_steps = math.ceil(max_steps * args.logging_steps) |
|
|
else: |
|
|
self.state.logging_steps = args.logging_steps |
|
|
if args.eval_steps is not None: |
|
|
if args.eval_steps < 1: |
|
|
self.state.eval_steps = math.ceil(max_steps * args.eval_steps) |
|
|
else: |
|
|
self.state.eval_steps = args.eval_steps |
|
|
if args.save_steps is not None: |
|
|
if args.save_steps < 1: |
|
|
self.state.save_steps = math.ceil(max_steps * args.save_steps) |
|
|
else: |
|
|
self.state.save_steps = args.save_steps |
|
|
|
|
|
|
|
|
if args.gradient_checkpointing: |
|
|
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) |
|
|
|
|
|
model = self._wrap_model(self.model_wrapped) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use_accelerator_prepare = True if model is self.model else False |
|
|
|
|
|
if delay_optimizer_creation: |
|
|
if use_accelerator_prepare: |
|
|
self._fsdp_qlora_plugin_updates() |
|
|
self.model = self.accelerator.prepare(self.model) |
|
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps) |
|
|
|
|
|
|
|
|
if use_accelerator_prepare: |
|
|
self.model.train() |
|
|
if hasattr(self.lr_scheduler, "step"): |
|
|
if self.use_apex: |
|
|
model = self.accelerator.prepare(self.model) |
|
|
else: |
|
|
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) |
|
|
else: |
|
|
|
|
|
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( |
|
|
self.model, self.optimizer, self.lr_scheduler |
|
|
) |
|
|
elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: |
|
|
|
|
|
self.optimizer = self.accelerator.prepare(self.optimizer) |
|
|
|
|
|
if self.is_fsdp_enabled: |
|
|
self.model = self.model_wrapped = model |
|
|
|
|
|
|
|
|
if model is not self.model: |
|
|
self.model_wrapped = model |
|
|
|
|
|
|
|
|
if self.is_deepspeed_enabled: |
|
|
self.deepspeed = self.model_wrapped |
|
|
|
|
|
|
|
|
self._load_optimizer_and_scheduler(resume_from_checkpoint) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info("***** Running training *****") |
|
|
logger.info(f" Num examples = {num_examples:,}") |
|
|
logger.info(f" Num Epochs = {num_train_epochs:,}") |
|
|
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") |
|
|
if self.args.per_device_train_batch_size != self._train_batch_size: |
|
|
logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") |
|
|
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") |
|
|
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") |
|
|
logger.info(f" Total optimization steps = {max_steps:,}") |
|
|
logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") |
|
|
|
|
|
self.state.epoch = 0 |
|
|
start_time = time.time() |
|
|
epochs_trained = 0 |
|
|
steps_trained_in_current_epoch = 0 |
|
|
steps_trained_progress_bar = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if resume_from_checkpoint is not None and os.path.isfile( |
|
|
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) |
|
|
): |
|
|
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) |
|
|
self.compare_trainer_and_checkpoint_args(self.args, self.state) |
|
|
self._load_callback_state() |
|
|
epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) |
|
|
if not args.ignore_data_skip: |
|
|
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) |
|
|
steps_trained_in_current_epoch *= args.gradient_accumulation_steps |
|
|
else: |
|
|
steps_trained_in_current_epoch = 0 |
|
|
|
|
|
logger.info(" Continuing training from checkpoint, will skip to saved global_step") |
|
|
logger.info(f" Continuing training from epoch {epochs_trained}") |
|
|
logger.info(f" Continuing training from global step {self.state.global_step}") |
|
|
if not args.ignore_data_skip: |
|
|
logger.info( |
|
|
f" Will skip the first {epochs_trained} epochs then the first" |
|
|
f" {steps_trained_in_current_epoch} batches in the first epoch." |
|
|
) |
|
|
|
|
|
|
|
|
self.callback_handler.model = self.model |
|
|
self.callback_handler.optimizer = self.optimizer |
|
|
self.callback_handler.lr_scheduler = self.lr_scheduler |
|
|
self.callback_handler.train_dataloader = train_dataloader |
|
|
|
|
|
|
|
|
self.state.max_steps = max_steps |
|
|
self.state.num_train_epochs = num_train_epochs |
|
|
self.state.is_local_process_zero = self.is_local_process_zero() |
|
|
self.state.is_world_process_zero = self.is_world_process_zero() |
|
|
|
|
|
|
|
|
tr_loss = torch.tensor(0.0).to(args.device) |
|
|
|
|
|
self._total_loss_scalar = 0.0 |
|
|
self._globalstep_last_logged = self.state.global_step |
|
|
model.zero_grad() |
|
|
grad_norm: Optional[float] = None |
|
|
self.control = self.callback_handler.on_train_begin(args, self.state, self.control) |
|
|
|
|
|
if args.eval_on_start: |
|
|
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) |
|
|
|
|
|
total_batched_samples = 0 |
|
|
for epoch in range(epochs_trained, num_train_epochs): |
|
|
epoch_dataloader = train_dataloader |
|
|
if hasattr(epoch_dataloader.dataset, "set_epoch"): |
|
|
|
|
|
epoch_dataloader.dataset.set_epoch(epoch) |
|
|
|
|
|
|
|
|
if args.past_index >= 0: |
|
|
self._past = None |
|
|
|
|
|
steps_in_epoch = ( |
|
|
len(epoch_dataloader) |
|
|
if len_dataloader is not None |
|
|
else args.max_steps * args.gradient_accumulation_steps |
|
|
) |
|
|
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) |
|
|
|
|
|
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: |
|
|
self._load_rng_state(resume_from_checkpoint) |
|
|
|
|
|
rng_to_sync = False |
|
|
steps_skipped = 0 |
|
|
if steps_trained_in_current_epoch > 0: |
|
|
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) |
|
|
steps_skipped = steps_trained_in_current_epoch |
|
|
steps_trained_in_current_epoch = 0 |
|
|
rng_to_sync = True |
|
|
|
|
|
step = -1 |
|
|
epoch_iterator = iter(epoch_dataloader) |
|
|
|
|
|
remainder = num_examples % args.gradient_accumulation_steps |
|
|
num_items_in_batch = None |
|
|
if remainder == 0: |
|
|
remainder = args.gradient_accumulation_steps |
|
|
update_step = -1 |
|
|
total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 |
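            # Each epoch is split into `total_updates` optimizer updates; the last one may consume fewer
            # micro-batches (`remainder`), and get_batch_samples simply stops early on StopIteration.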
|
|
for _ in range(total_updates): |
|
|
update_step += 1 |
|
|
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder |
|
|
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) |
|
|
for i, inputs in enumerate(batch_samples): |
|
|
step += 1 |
|
|
total_batched_samples += 1 |
|
|
|
|
|
dataset_stat = collections.Counter(inputs[0]['global_dataset_name']) |
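                    # Per-micro-batch histogram of source datasets (keyed by 'global_dataset_name'),
                    # kept for debugging dataset mixing; not used elsewhere in this loop.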
|
|
|
|
|
|
|
|
|
|
|
|
|
|
is_last_step_and_steps_less_than_grad_acc = ( |
|
|
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch |
|
|
) |
|
|
do_sync_step = is_last_step_and_steps_less_than_grad_acc or ( |
|
|
total_batched_samples % args.gradient_accumulation_steps == 0 |
|
|
) |
|
|
|
|
|
if not do_sync_step: |
|
|
self.accelerator.gradient_state._set_sync_gradients(False) |
|
|
else: |
|
|
self.accelerator.gradient_state._set_sync_gradients(True) |
|
|
|
|
|
if self.args.include_num_input_tokens_seen: |
|
|
main_input_name = getattr(self.model, "main_input_name", "input_ids") |
|
|
if main_input_name not in inputs: |
|
|
logger.warning( |
|
|
"Tried to track the number of tokens seen, however the current model is " |
|
|
"not configured properly to know what item is the input. To fix this, add " |
|
|
"a `main_input_name` attribute to the model class you are using." |
|
|
) |
|
|
else: |
|
|
input_tokens = inputs[main_input_name].numel() |
|
|
input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) |
|
|
self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item() |
|
|
if rng_to_sync: |
|
|
self._load_rng_state(resume_from_checkpoint) |
|
|
rng_to_sync = False |
|
|
|
|
|
|
|
|
if steps_trained_in_current_epoch > 0: |
|
|
steps_trained_in_current_epoch -= 1 |
|
|
if steps_trained_progress_bar is not None: |
|
|
steps_trained_progress_bar.update(1) |
|
|
if steps_trained_in_current_epoch == 0: |
|
|
self._load_rng_state(resume_from_checkpoint) |
|
|
continue |
|
|
elif steps_trained_progress_bar is not None: |
|
|
steps_trained_progress_bar.close() |
|
|
steps_trained_progress_bar = None |
|
|
|
|
|
if step % args.gradient_accumulation_steps == 0: |
|
|
self.control = self.callback_handler.on_step_begin(args, self.state, self.control) |
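                    # Skip DDP gradient synchronization (no_sync) for every micro-batch except the last
                    # one in the accumulation window; gradients are synced on the final micro-batch.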
|
|
|
|
|
|
|
|
context = ( |
|
|
functools.partial(self.accelerator.no_sync, model=model) |
|
|
if i != len(batch_samples) - 1 |
|
|
else contextlib.nullcontext |
|
|
) |
|
|
with context(): |
|
|
tr_loss_step = self.training_step(model, inputs, num_items_in_batch) |
|
|
|
|
|
if ( |
|
|
args.logging_nan_inf_filter |
|
|
and not is_torch_xla_available() |
|
|
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) |
|
|
): |
|
|
|
|
|
tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) |
|
|
else: |
|
|
if tr_loss.device != tr_loss_step.device: |
|
|
raise ValueError( |
|
|
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" |
|
|
) |
|
|
tr_loss = tr_loss + tr_loss_step |
|
|
|
|
|
self.current_flos += float(self.floating_point_ops(inputs)) |
|
|
|
|
|
if do_sync_step: |
|
|
|
|
|
self.accelerator.gradient_state._set_sync_gradients(True) |
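                        # Gradient clipping: DeepSpeed reports its own global grad norm via
                        # get_global_grad_norm(); otherwise the norm comes from apex or accelerate.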
|
|
|
|
|
|
|
|
if args.max_grad_norm is not None and args.max_grad_norm > 0: |
|
|
|
|
|
|
|
|
if self.use_apex: |
|
|
|
|
|
_grad_norm = torch.nn.utils.clip_grad_norm_( |
|
|
amp.master_params(self.optimizer), |
|
|
args.max_grad_norm, |
|
|
) |
|
|
else: |
|
|
_grad_norm = self.accelerator.clip_grad_norm_( |
|
|
model.parameters(), |
|
|
args.max_grad_norm, |
|
|
) |
|
|
|
|
|
if ( |
|
|
is_accelerate_available() |
|
|
and self.accelerator.distributed_type == DistributedType.DEEPSPEED |
|
|
): |
|
|
grad_norm = model.get_global_grad_norm() |
|
|
|
|
|
if hasattr(grad_norm, "item"): |
|
|
grad_norm = grad_norm.item() |
|
|
else: |
|
|
grad_norm = _grad_norm |
|
|
|
|
|
self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) |
|
|
|
|
|
self.optimizer.step() |
|
|
|
|
|
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) |
|
|
|
|
|
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped |
|
|
if optimizer_was_run: |
|
|
|
|
|
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): |
|
|
self.lr_scheduler.step() |
|
|
|
|
|
model.zero_grad() |
|
|
self.state.global_step += 1 |
|
|
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch |
|
|
self.control = self.callback_handler.on_step_end(args, self.state, self.control) |
|
|
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, time.time()) |
|
|
else: |
|
|
self.control = self.callback_handler.on_substep_end(args, self.state, self.control) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.control.should_epoch_stop or self.control.should_training_stop: |
|
|
if is_torch_xla_available(): |
|
|
xm.mark_step() |
|
|
break |
|
|
|
|
|
if self.control.should_epoch_stop or self.control.should_training_stop: |
|
|
if is_torch_xla_available(): |
|
|
xm.mark_step() |
|
|
break |
|
|
if step < 0: |
|
|
logger.warning( |
|
|
"There seems not to be a single sample in your epoch_iterator, stopping training at step" |
|
|
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" |
|
|
f" num_steps ({max_steps}) higher than the number of available samples." |
|
|
) |
|
|
self.control.should_training_stop = True |
|
|
|
|
|
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) |
|
|
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, time.time()) |
|
|
|
|
|
if self.control.should_training_stop: |
|
|
break |
|
|
|
|
|
if args.past_index and hasattr(self, "_past"): |
|
|
|
|
|
delattr(self, "_past") |
|
|
|
|
|
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") |
|
|
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: |
|
|
|
|
|
if is_torch_xla_available(): |
|
|
xm.rendezvous("load_best_model_at_end") |
|
|
elif args.parallel_mode == ParallelMode.DISTRIBUTED: |
|
|
dist.barrier() |
|
|
|
|
|
self._load_best_model() |
|
|
|
|
|
|
|
|
self._total_loss_scalar += tr_loss.item() |
|
|
effective_global_step = max(self.state.global_step, 0.001) |
|
|
train_loss = self._total_loss_scalar / effective_global_step |
|
|
|
|
|
metrics = speed_metrics( |
|
|
"train", |
|
|
start_time, |
|
|
num_samples=num_train_samples, |
|
|
num_steps=self.state.max_steps, |
|
|
num_tokens=num_train_tokens, |
|
|
) |
|
|
self.store_flos() |
|
|
metrics["total_flos"] = self.state.total_flos |
|
|
metrics["train_loss"] = train_loss |
|
|
|
|
|
self.is_in_train = False |
|
|
|
|
|
self._memory_tracker.stop_and_update_metrics(metrics) |
|
|
|
|
|
self.log(metrics) |
|
|
|
|
|
run_dir = self._get_output_dir(trial) |
|
|
checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) |
|
|
|
|
|
|
|
|
if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: |
|
|
for checkpoint in checkpoints_sorted: |
|
|
if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): |
|
|
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") |
|
|
shutil.rmtree(checkpoint, ignore_errors=True) |
|
|
|
|
|
self.control = self.callback_handler.on_train_end(args, self.state, self.control) |
|
|
|
|
|
|
|
|
self._finish_current_push() |
|
|
|
|
|
|
|
|
|
|
|
if self.neftune_noise_alpha is not None: |
|
|
self._deactivate_neftune(self.model) |
|
|
|
|
|
return TrainOutput(self.state.global_step, train_loss, metrics) |
|
|
|
|
|
|
|
|
class GradCacheLateProcessTrainer(MMEBTrainer): |
|
|
""" |
|
|
Adapted from gradcache repo. |
|
|
""" |
|
|
def __init__(self, *args, **kwargs): |
|
|
        self.max_length = kwargs.pop("max_length", 512)
        self.model_args = kwargs.pop("model_args", None)
|
|
super(GradCacheLateProcessTrainer, self).__init__(*args, **kwargs) |
|
|
self.is_ddp = dist.is_initialized() |
|
|
self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1 |
|
|
loss_fn_cls = DistributedContrastiveLoss if self.is_ddp else SimpleContrastiveLoss |
|
|
loss_fn = loss_fn_cls( |
|
|
temperature=self.model.temperature, |
|
|
alpha=getattr(self.model, "dual_alpha", 0.15), |
|
|
weights=getattr(self.model, "supervise_weights", None) |
|
|
) |
|
|
|
|
|
self.gc = GradCache( |
|
|
models=[self.model, self.model], |
|
|
chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size], |
|
|
loss_fn=loss_fn, |
|
|
split_input_fn=split_and_process_vlm_inputs, |
|
|
get_rep_fn=get_dense_rep, |
|
|
fp16=self.args.fp16, |
|
|
scaler=self.scaler if self.args.fp16 else None |
|
|
) |
|
|
|
|
|
def _maybe_sample_keep_ratios(self, model): |
|
|
""" |
|
|
        At each step, randomly sample keep_ratio_text / keep_ratio_vision within the configured
        [min, max] ranges and write them back to encoder.aop_prune_config.
        Only active when the config has mode == "ratio" and enabled == True. Per-modality ranges take
        precedence over the shared range; if no range is set, the ratios are left unchanged.
        The sampled values are also stored in cfg['_last_sampled_keep_ratio_*'] so the base model's
        monitor can print them.
|
|
""" |
|
|
|
|
enc = model.module.encoder if hasattr(model, "module") and hasattr(model.module, "encoder") else getattr(model, "encoder", None) |
|
|
if enc is None: |
|
|
return None |
|
|
|
|
|
cfg = getattr(enc, "aop_prune_config", None) |
|
|
if not isinstance(cfg, dict) or not cfg.get("enabled", False): |
|
|
return None |
|
|
if str(cfg.get("mode", "ratio")).lower() != "ratio": |
|
|
return None |
|
|
sampling = str(cfg.get("keep_ratio_sampling", "off")).lower() |
|
|
if sampling not in {"step", "batch"}: |
|
|
return None |
|
|
|
|
|
device = getattr(model, "device", torch.device("cpu")) |
|
|
|
|
|
def _sample_from_range(rng, fallback): |
|
|
if rng is None: |
|
|
return fallback |
|
|
try: |
|
|
lo, hi = float(rng[0]), float(rng[1]) |
|
|
                lo = max(0.0, min(1.0, lo))
                hi = max(0.0, min(1.0, hi))
                if hi < lo:
                    lo, hi = hi, lo
|
|
val = lo + (hi - lo) * torch.rand((), device=device).item() |
|
|
return float(val) |
|
|
except Exception: |
|
|
return fallback |
|
|
|
|
|
r_all = cfg.get("keep_ratio_range", None) |
|
|
r_t = cfg.get("keep_ratio_text_range", r_all) |
|
|
r_v = cfg.get("keep_ratio_vision_range", r_all) |
|
|
|
|
|
kr_t = _sample_from_range(r_t, cfg.get("keep_ratio_text", cfg.get("keep_ratio", None))) |
|
|
kr_v = _sample_from_range(r_v, cfg.get("keep_ratio_vision", cfg.get("keep_ratio", None))) |
|
|
|
|
|
if kr_t is not None: |
|
|
cfg["keep_ratio_text"] = float(kr_t) |
|
|
if kr_v is not None: |
|
|
cfg["keep_ratio_vision"] = float(kr_v) |
|
|
|
|
|
|
|
|
cfg["_last_sampled_keep_ratio_text"] = float(kr_t) if kr_t is not None else None |
|
|
cfg["_last_sampled_keep_ratio_vision"] = float(kr_v) if kr_v is not None else None |
|
|
|
|
|
setattr(enc, "aop_prune_config", cfg) |
|
|
|
|
|
|
|
|
if os.getenv("AOP_MONITOR", "0") == "1": |
|
|
step = getattr(self.state, "global_step", None) if hasattr(self, "state") else None |
|
|
tag = f"step={step}" if step is not None else "" |
|
|
if not hasattr(self, "_aop_sample_prints"): self._aop_sample_prints = 0 |
|
|
if self._aop_sample_prints < 5: |
|
|
print(f"[AOP][sample] {tag} keep_ratio_text={cfg.get('keep_ratio_text')} keep_ratio_vision={cfg.get('keep_ratio_vision')}") |
|
|
self._aop_sample_prints += 1 |
|
|
|
|
|
return (kr_t, kr_v) |
|
|
|
|
|
def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor: |
|
|
""" |
|
|
        Single GPU / non-DDP: two-pass stop-gradient chunked training; the two directions are weighted
        by dual_alpha and each chunk's loss is scaled by the chunk-count mean so the numerical scale
        matches the original (un-chunked) implementation.
        Multi-GPU DDP: falls back to GradCache.
|
|
""" |
|
|
model.train() |
|
|
|
|
|
self._maybe_sample_keep_ratios(model) |
|
|
queries_raw, targets_raw = inputs |
|
|
queries_raw = batch_to_device(queries_raw, model.device) |
|
|
targets_raw = batch_to_device(targets_raw, model.device) |
|
|
|
|
|
|
|
|
is_ddp_wrapped = isinstance(model, torch.nn.parallel.DistributedDataParallel) |
|
|
ws = dist.get_world_size() if dist.is_initialized() else 1 |
|
|
use_ddp = (is_ddp_wrapped and ws > 1) |
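        # Multi-GPU path: delegate to GradCache, which chunks both towers, caches representations, and
        # replays the backward pass, synchronizing gradients only on the last chunk (no_sync_except_last=True).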
|
|
if use_ddp: |
|
|
queries, targets = {'qry': queries_raw}, {'tgt': targets_raw} |
|
|
self.gc.models = [model, model] |
|
|
loss = self.gc(queries, targets, no_sync_except_last=True) |
|
|
return loss / self._dist_loss_scale_factor |
|
|
|
|
|
|
|
|
def _batch_size(d): |
|
|
for v in d.values(): |
|
|
if isinstance(v, torch.Tensor): |
|
|
return v.size(0) |
|
|
if isinstance(v, (list, tuple)): |
|
|
return len(v) |
|
|
raise ValueError("Cannot infer batch size from inputs") |
|
|
|
|
|
def _slice_dict(d, s, e): |
|
|
out = {} |
|
|
for k, v in d.items(): |
|
|
if isinstance(v, torch.Tensor): |
|
|
out[k] = v[s:e] |
|
|
                elif isinstance(v, (list, tuple)):
                    out[k] = v[s:e]
|
|
else: |
|
|
out[k] = v |
|
|
return out |
|
|
|
|
|
B = _batch_size(queries_raw) |
|
|
assert B == _batch_size(targets_raw), "Batch mismatch between queries and targets" |
|
|
q_chunk = max(1, int(getattr(self.args, "gc_q_chunk_size", 8))) |
|
|
p_chunk = max(1, int(getattr(self.args, "gc_p_chunk_size", 8))) |
|
|
Nq = (B + q_chunk - 1) // q_chunk |
|
|
Np = (B + p_chunk - 1) // p_chunk |
|
|
|
|
|
temp = float(getattr(self.model, "temperature", 0.02)) |
|
|
alpha = float(getattr(self.model, "dual_alpha", 0.15)) |
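        # Pass 1: encode all targets under no_grad to build a detached representation bank; gradients
        # then flow only through the query chunks below (stop-gradient on the target side).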
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
p_chunks = [] |
|
|
for s in range(0, B, p_chunk): |
|
|
e = min(s + p_chunk, B) |
|
|
out_p = model(tgt=_slice_dict(targets_raw, s, e)) |
|
|
p_chunks.append(out_p["tgt_reps"].detach()) |
|
|
p_bank = torch.cat(p_chunks, dim=0) |
|
|
|
|
|
K = p_bank.size(1) |
|
|
weights = getattr(self.model, "supervise_weights", None) |
|
|
if weights is None or len(weights) != K: |
|
|
weights = [1.0] * K |
|
|
w = torch.tensor(weights, dtype=p_bank.dtype, device=p_bank.device) |
|
|
w = torch.clamp(w, min=0) |
|
|
w = w / max(w.sum().item(), 1e-8) |
|
|
|
|
|
|
|
|
Lq_sum_weighted = 0.0 |
|
|
Lp_sum_weighted = 0.0 |
|
|
|
|
|
|
|
|
for s in range(0, B, q_chunk): |
|
|
e = min(s + q_chunk, B) |
|
|
out_q = model(qry=_slice_dict(queries_raw, s, e)) |
|
|
q_rep = out_q["qry_reps"] |
|
|
b = q_rep.size(0) |
|
|
|
|
|
loss_chunk = 0.0 |
|
|
for k in range(K): |
|
|
logits = torch.matmul(q_rep[:, k, :], p_bank[:, k, :].transpose(0, 1)) / temp |
|
|
tgt = torch.arange(s, e, device=logits.device, dtype=torch.long) |
|
|
loss_k = self.model.cross_entropy(logits, tgt) |
|
|
loss_chunk = loss_chunk + w[k] * loss_k |
|
|
|
|
|
|
|
|
Lq_sum_weighted += float(loss_chunk.detach().item()) * b |
|
|
|
|
|
|
|
|
scale_q = (1.0 - alpha) / float(Nq) |
|
|
self.accelerator.backward((loss_chunk * scale_q) / self._dist_loss_scale_factor) |
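        # Pass 2: the symmetric direction — freeze the query representations and backpropagate through
        # the target chunks, weighted by alpha.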
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
q_chunks = [] |
|
|
for s in range(0, B, q_chunk): |
|
|
e = min(s + q_chunk, B) |
|
|
out_q = model(qry=_slice_dict(queries_raw, s, e)) |
|
|
q_chunks.append(out_q["qry_reps"].detach()) |
|
|
q_bank = torch.cat(q_chunks, dim=0) |
|
|
|
|
|
for s in range(0, B, p_chunk): |
|
|
e = min(s + p_chunk, B) |
|
|
out_p = model(tgt=_slice_dict(targets_raw, s, e)) |
|
|
p_rep = out_p["tgt_reps"] |
|
|
b = p_rep.size(0) |
|
|
|
|
|
loss_chunk = 0.0 |
|
|
for k in range(K): |
|
|
logits = torch.matmul(p_rep[:, k, :], q_bank[:, k, :].transpose(0, 1)) / temp |
|
|
tgt = torch.arange(s, e, device=logits.device, dtype=torch.long) |
|
|
loss_k = self.model.cross_entropy(logits, tgt) |
|
|
loss_chunk = loss_chunk + w[k] * loss_k |
|
|
|
|
|
|
|
|
Lp_sum_weighted += float(loss_chunk.detach().item()) * b |
|
|
|
|
|
|
|
|
scale_p = alpha / float(Np) |
|
|
self.accelerator.backward((loss_chunk * scale_p) / self._dist_loss_scale_factor) |
|
|
|
|
|
|
|
|
Lq_avg = Lq_sum_weighted / float(B) |
|
|
Lp_avg = Lp_sum_weighted / float(B) |
|
|
loss_log_val = (1.0 - alpha) * Lq_avg + alpha * Lp_avg |
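        # Backward has already been run chunk-by-chunk above; return a graph-attached scalar that only
        # carries the weighted average loss for logging (dummy_zero contributes no gradient).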
|
|
|
|
|
ref = next(model.parameters()) |
|
|
loss_log = torch.tensor(loss_log_val, device=ref.device, dtype=ref.dtype) |
|
|
dummy_zero = ref.sum() * 0.0 |
|
|
return dummy_zero + loss_log |
|
|
|
|
|
|
|
|
def _save(self, output_dir: Optional[str] = None, state_dict=None): |
|
|
print_master(f"Saving model to {output_dir}") |
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
|
|
|
if state_dict is None: |
|
|
state_dict = self.model.state_dict() |
|
|
prefix = 'encoder.' |
|
|
assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys()) |
|
|
state_dict = {k[len(prefix):]: v for k, v in state_dict.items()} |
|
|
self.model.encoder.save_pretrained( |
|
|
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors |
|
|
) |
|
|
|
|
|
if self.tokenizer is not None: |
|
|
self.tokenizer.save_pretrained(output_dir) |
|
|
|
|
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) |
|
|
self.model.encoder.config.to_json_file(os.path.join(output_dir, 'config.json')) |
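

# Minimal usage sketch (illustrative only): constructor arguments follow this file's __init__
# signatures plus the standard HF Trainer kwargs; the concrete values below (chunk sizes,
# max_length, collator/processor names) are assumptions, not prescriptions.
#
#   trainer = GradCacheLateProcessTrainer(
#       model=model,                      # an MMEBModel instance (exposes .encoder, .temperature)
#       args=training_args,               # TrainingArguments extended with gc_q_chunk_size / gc_p_chunk_size
#       train_dataset=train_dataset,
#       data_collator=train_collator,     # must yield (query_inputs, target_inputs) pairs
#       processing_class=processor,
#       max_length=512,
#       model_args=model_args,            # used by _load_from_checkpoint / MMEBModel.load
#   )
#   trainer.train()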
|
|
|