Spaces:
Running
Running
| """ | |
| Custom HuggingFace Trainer callbacks: | |
| - EarlyStoppingOnWER: stops training when WER stops improving | |
| - AdapterCheckpointCallback: saves only adapter weights (not full model) per checkpoint | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING | |
| from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments | |
| if TYPE_CHECKING: | |
| pass | |
| logger = logging.getLogger(__name__) | |
class EarlyStoppingOnWER(TrainerCallback):
    """
    Halt training once the eval word-error-rate plateaus.

    Tracks the best ``eval_wer`` seen so far. After ``patience`` consecutive
    evaluations without an improvement of at least ``min_delta``, flips
    ``control.should_training_stop`` so the Trainer exits its loop.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001) -> None:
        self.patience = patience
        self.min_delta = min_delta
        # Best WER observed so far; starts at +inf so the first metric wins.
        self._best_wer: float = float("inf")
        # Consecutive evaluations without a min_delta improvement.
        self._no_improve_count: int = 0

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict,
        **kwargs,
    ) -> None:
        wer = metrics.get("eval_wer")
        if wer is None:
            # Metric not reported for this evaluation — nothing to decide on.
            return
        if wer < self._best_wer - self.min_delta:
            # New best: record it and reset the plateau counter.
            self._best_wer = wer
            self._no_improve_count = 0
            logger.info("WER improved to %.4f", wer)
            return
        self._no_improve_count += 1
        logger.info(
            "WER %.4f did not improve (best: %.4f). No-improve count: %d/%d",
            wer, self._best_wer, self._no_improve_count, self.patience,
        )
        if self._no_improve_count >= self.patience:
            logger.warning("Early stopping triggered after %d evaluations without improvement.", self.patience)
            control.should_training_stop = True
class AdapterCheckpointCallback(TrainerCallback):
    """
    Persist only the LoRA adapter weights at each checkpoint event.

    Adapter weights are ~50MB vs ~3GB for the full model, so checkpoints
    stay cheap while remaining restorable via PEFT's ``save_pretrained``.
    """

    def __init__(self, adapter_output_dir: str) -> None:
        # Root directory under which per-step checkpoint folders are created.
        self.adapter_output_dir = Path(adapter_output_dir)

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model,
        **kwargs,
    ) -> None:
        target = self.adapter_output_dir / f"checkpoint-{state.global_step}"
        target.mkdir(parents=True, exist_ok=True)
        # model is expected to be a PeftModel; duck-type on save_pretrained
        # so non-PEFT models are skipped instead of crashing the save hook.
        save = getattr(model, "save_pretrained", None)
        if save is None:
            logger.warning("Model does not have save_pretrained — skipping adapter checkpoint.")
            return
        save(str(target))
        logger.info("Adapter checkpoint saved: %s", target)