File size: 2,738 Bytes
76db545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Custom HuggingFace Trainer callbacks:
- EarlyStoppingOnWER: stops training when WER stops improving
- AdapterCheckpointCallback: saves only adapter weights (not full model) per checkpoint
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

if TYPE_CHECKING:
    pass

logger = logging.getLogger(__name__)


class EarlyStoppingOnWER(TrainerCallback):
    """
    Stops training if the monitored eval metric does not improve by at least
    ``min_delta`` over ``patience`` consecutive evaluations.

    The monitored key defaults to ``"eval_wer"`` (backward compatible) but can
    be any lower-is-better metric present in the evaluation metrics dict.
    """

    def __init__(
        self,
        patience: int = 5,
        min_delta: float = 0.001,
        metric_name: str = "eval_wer",
    ) -> None:
        """
        Args:
            patience: number of consecutive non-improving evaluations allowed
                before training is stopped.
            min_delta: minimum absolute decrease required to count as an
                improvement.
            metric_name: key looked up in the evaluation metrics dict
                (lower is better).
        """
        self.patience = patience
        self.min_delta = min_delta
        self.metric_name = metric_name
        self._best_wer: float = float("inf")  # best (lowest) value seen so far
        self._no_improve_count: int = 0  # consecutive evals without improvement

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict | None = None,
        **kwargs,
    ) -> None:
        """Inspect the latest eval metrics and set ``control.should_training_stop``
        once ``patience`` consecutive evaluations fail to improve the metric.

        Robustness fix: ``metrics`` is now optional — if the Trainer invokes the
        hook without metrics (or without the monitored key), the check is skipped
        silently instead of raising.
        """
        wer = (metrics or {}).get(self.metric_name)
        if wer is None:
            # Metric not reported this round (e.g. WER computation skipped).
            return

        if wer < self._best_wer - self.min_delta:
            self._best_wer = wer
            self._no_improve_count = 0
            logger.info("WER improved to %.4f", wer)
        else:
            self._no_improve_count += 1
            logger.info(
                "WER %.4f did not improve (best: %.4f). No-improve count: %d/%d",
                wer, self._best_wer, self._no_improve_count, self.patience,
            )
            if self._no_improve_count >= self.patience:
                logger.warning("Early stopping triggered after %d evaluations without improvement.", self.patience)
                control.should_training_stop = True

class AdapterCheckpointCallback(TrainerCallback):
    """
    Persists just the LoRA adapter weights at every Trainer checkpoint event.

    The full base model (~3GB) is never written; adapter weights are ~50MB,
    keeping each checkpoint cheap on disk.
    """

    def __init__(self, adapter_output_dir: str) -> None:
        # Root directory under which per-step checkpoint folders are created.
        self.adapter_output_dir = Path(adapter_output_dir)

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model,
        **kwargs,
    ) -> None:
        """Write the adapter weights for the current global step."""
        ckpt_dir = self.adapter_output_dir / f"checkpoint-{state.global_step}"
        ckpt_dir.mkdir(parents=True, exist_ok=True)

        # model is a PeftModel — save only adapter weights
        if not hasattr(model, "save_pretrained"):
            logger.warning("Model does not have save_pretrained — skipping adapter checkpoint.")
            return

        model.save_pretrained(str(ckpt_dir))
        logger.info("Adapter checkpoint saved: %s", ckpt_dir)