"""
Custom HuggingFace Trainer callbacks:
- EarlyStoppingOnWER: stops training when WER stops improving
- AdapterCheckpointCallback: saves only adapter weights (not full model) per checkpoint
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
class EarlyStoppingOnWER(TrainerCallback):
    """
    Stop training when the evaluation WER stops improving.

    Training is halted once `patience` consecutive evaluations fail to
    improve the best observed WER by at least `min_delta`.

    Args:
        patience: Number of consecutive non-improving evaluations tolerated
            before stopping. Must be >= 1.
        min_delta: Minimum absolute WER decrease that counts as improvement.
            Must be >= 0.
        metric_name: Key to read from the evaluation metrics dict. Defaults
            to "eval_wer" (HF Trainer prefixes eval metrics with "eval_").
    """

    def __init__(
        self,
        patience: int = 5,
        min_delta: float = 0.001,
        metric_name: str = "eval_wer",
    ) -> None:
        if patience < 1:
            raise ValueError(f"patience must be >= 1, got {patience}")
        if min_delta < 0:
            raise ValueError(f"min_delta must be >= 0, got {min_delta}")
        self.patience = patience
        self.min_delta = min_delta
        self.metric_name = metric_name
        self._best_wer: float = float("inf")
        self._no_improve_count: int = 0

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict | None = None,
        **kwargs,
    ) -> None:
        """Inspect the latest eval metrics and request a stop if stalled.

        The base `TrainerCallback.on_evaluate` declares `metrics=None`, so
        it must be optional here too; evaluations without our metric are
        skipped silently.
        """
        wer = (metrics or {}).get(self.metric_name)
        if wer is None:
            return
        if wer < self._best_wer - self.min_delta:
            self._best_wer = wer
            self._no_improve_count = 0
            logger.info("WER improved to %.4f", wer)
        else:
            self._no_improve_count += 1
            logger.info(
                "WER %.4f did not improve (best: %.4f). No-improve count: %d/%d",
                wer, self._best_wer, self._no_improve_count, self.patience,
            )
            if self._no_improve_count >= self.patience:
                logger.warning(
                    "Early stopping triggered after %d evaluations without improvement.",
                    self.patience,
                )
                # Signal the Trainer's main loop to terminate.
                control.should_training_stop = True
class AdapterCheckpointCallback(TrainerCallback):
    """
    Save only the LoRA adapter weights on each checkpoint event.

    Adapter weights are ~50MB vs ~3GB for the full model, so persisting
    just the adapter keeps per-checkpoint storage small.

    Args:
        adapter_output_dir: Root directory under which one
            `checkpoint-<global_step>` subdirectory is created per save.
    """

    def __init__(self, adapter_output_dir: str) -> None:
        self.adapter_output_dir = Path(adapter_output_dir)

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model=None,
        **kwargs,
    ) -> None:
        """Persist the adapter into a per-step checkpoint directory.

        The Trainer passes `model` as a keyword argument; it is optional
        here (with a kwargs fallback) so a missing model degrades to a
        logged skip instead of a TypeError.
        """
        if model is None:
            model = kwargs.get("model")
        # Check capability BEFORE touching the filesystem, so a skip does
        # not leave behind an empty checkpoint directory.
        if model is None or not hasattr(model, "save_pretrained"):
            logger.warning("Model does not have save_pretrained — skipping adapter checkpoint.")
            return
        checkpoint_dir = self.adapter_output_dir / f"checkpoint-{state.global_step}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # model is expected to be a PeftModel — save_pretrained then writes
        # only the adapter weights, not the full base model.
        model.save_pretrained(str(checkpoint_dir))
        logger.info("Adapter checkpoint saved: %s", checkpoint_dir)
|