""" Custom HuggingFace Trainer callbacks: - EarlyStoppingOnWER: stops training when WER stops improving - AdapterCheckpointCallback: saves only adapter weights (not full model) per checkpoint """ from __future__ import annotations import logging from pathlib import Path from typing import TYPE_CHECKING from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments if TYPE_CHECKING: pass logger = logging.getLogger(__name__) class EarlyStoppingOnWER(TrainerCallback): """ Stops training if eval WER does not improve by min_delta over `patience` evaluations. """ def __init__(self, patience: int = 5, min_delta: float = 0.001) -> None: self.patience = patience self.min_delta = min_delta self._best_wer: float = float("inf") self._no_improve_count: int = 0 def on_evaluate( self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics: dict, **kwargs, ) -> None: wer = metrics.get("eval_wer") if wer is None: return if wer < self._best_wer - self.min_delta: self._best_wer = wer self._no_improve_count = 0 logger.info("WER improved to %.4f", wer) else: self._no_improve_count += 1 logger.info( "WER %.4f did not improve (best: %.4f). No-improve count: %d/%d", wer, self._best_wer, self._no_improve_count, self.patience, ) if self._no_improve_count >= self.patience: logger.warning("Early stopping triggered after %d evaluations without improvement.", self.patience) control.should_training_stop = True class AdapterCheckpointCallback(TrainerCallback): """ Saves only the LoRA adapter weights on each checkpoint event. Adapter weights are ~50MB vs ~3GB for the full model. """ def __init__(self, adapter_output_dir: str) -> None: self.adapter_output_dir = Path(adapter_output_dir) def on_save( self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model, **kwargs, ) -> None: checkpoint_dir = self.adapter_output_dir / f"checkpoint-{state.global_step}" checkpoint_dir.mkdir(parents=True, exist_ok=True) # model is a PeftModel — save only adapter weights if hasattr(model, "save_pretrained"): model.save_pretrained(str(checkpoint_dir)) logger.info("Adapter checkpoint saved: %s", checkpoint_dir) else: logger.warning("Model does not have save_pretrained — skipping adapter checkpoint.")