# src/training/callbacks.py
"""
Custom HuggingFace Trainer callbacks:
- EarlyStoppingOnWER: stops training when WER stops improving
- AdapterCheckpointCallback: saves only adapter weights (not full model) per checkpoint
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
if TYPE_CHECKING:
    from peft import PeftModel

logger = logging.getLogger(__name__)


class EarlyStoppingOnWER(TrainerCallback):
    """
    Stops training if eval WER does not improve by min_delta over `patience` evaluations.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self._best_wer: float = float("inf")
        self._no_improve_count: int = 0

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict,
        **kwargs,
    ) -> None:
        wer = metrics.get("eval_wer")
        if wer is None:
            return
        if wer < self._best_wer - self.min_delta:
            self._best_wer = wer
            self._no_improve_count = 0
            logger.info("WER improved to %.4f", wer)
        else:
            self._no_improve_count += 1
            logger.info(
                "WER %.4f did not improve (best: %.4f). No-improve count: %d/%d",
                wer, self._best_wer, self._no_improve_count, self.patience,
            )
            if self._no_improve_count >= self.patience:
                logger.warning(
                    "Early stopping triggered after %d evaluations without improvement.",
                    self.patience,
                )
                control.should_training_stop = True
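

# Usage sketch for EarlyStoppingOnWER (illustrative only: the trainer, datasets, and
# compute_metrics below are assumptions, not part of this module; the callback only
# requires that evaluation metrics include an "eval_wer" key):
#
#     trainer = Seq2SeqTrainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_ds,
#         eval_dataset=eval_ds,
#         compute_metrics=compute_metrics,   # returning {"wer": ...} is logged as "eval_wer"
#         callbacks=[EarlyStoppingOnWER(patience=5, min_delta=0.001)],
#     )
#     trainer.train()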


class AdapterCheckpointCallback(TrainerCallback):
    """
    Saves only the LoRA adapter weights on each checkpoint event.
    Adapter weights are ~50MB vs ~3GB for the full model.
    """

    def __init__(self, adapter_output_dir: str) -> None:
        self.adapter_output_dir = Path(adapter_output_dir)

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: "PeftModel",
        **kwargs,
    ) -> None:
        checkpoint_dir = self.adapter_output_dir / f"checkpoint-{state.global_step}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # model is a PeftModel — save_pretrained on a PeftModel writes only the adapter weights
        if hasattr(model, "save_pretrained"):
            model.save_pretrained(str(checkpoint_dir))
            logger.info("Adapter checkpoint saved: %s", checkpoint_dir)
        else:
            logger.warning("Model does not have save_pretrained — skipping adapter checkpoint.")