"""
Orchestrates full LoRA fine-tuning:
  WhisperBackbone + PEFT LoraConfig + WaxalDataLoader + Seq2SeqTrainer

Usage:
    trainer = WhisperLoRATrainer("configs/base_config.yaml", "configs/lora_bambara.yaml")
    trainer.setup()
    trainer.train()
"""
from __future__ import annotations

import logging
import os
from pathlib import Path

import torch
import yaml
from peft import LoraConfig, TaskType, get_peft_model
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

from src.data.augmentation import FieldNoiseAugmenter
from src.data.feature_extractor import DataCollatorSpeechSeq2SeqWithPadding
from src.data.waxal_loader import WaxalDataLoader
from src.engine.whisper_base import WhisperBackbone
from src.training.callbacks import AdapterCheckpointCallback, EarlyStoppingOnWER
from src.training.metrics import make_compute_metrics

logger = logging.getLogger(__name__)


class WhisperLoRATrainer:
    """Fine-tunes a language-specific LoRA adapter on top of Whisper."""

    def __init__(self, base_config_path: str, language_config_path: str) -> None:
        self._base_config_path = base_config_path
        with open(base_config_path) as f:
            self.config = yaml.safe_load(f)
        with open(language_config_path) as f:
            self.lang_config = yaml.safe_load(f)

        self._backbone: WhisperBackbone | None = None
        self._peft_model = None
        self._processor = None
        self._train_dataset = None
        self._eval_dataset = None

    def setup(self) -> None:
        """Load backbone, build LoRA config, prepare datasets."""
        hf_token = os.getenv("HF_TOKEN")
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # 1. Load backbone
        logger.info("Loading backbone model...")
        self._backbone = WhisperBackbone(config_path=self._base_config_path)
        self._backbone.load(device=device, hf_token=hf_token)
        self._processor = self._backbone.processor

        # Disable cache for training
        self._backbone.model.config.use_cache = False

        # 2. Wrap with LoRA
        lora_cfg = self.lang_config["lora"]
        lora_config = LoraConfig(
            r=lora_cfg["r"],
            lora_alpha=lora_cfg["lora_alpha"],
            target_modules=lora_cfg["target_modules"],
            lora_dropout=lora_cfg["lora_dropout"],
            bias=lora_cfg["bias"],
            task_type=TaskType.SEQ_2_SEQ_LM,
        )
        self._peft_model = get_peft_model(self._backbone.model, lora_config)
        self._peft_model.print_trainable_parameters()

        # 3. Load data
        subset = self.lang_config["dataset_subset"]
        augmenter = FieldNoiseAugmenter(self.config["paths"]["noise_samples"], self.config)
        loader = WaxalDataLoader(subset, self.config, hf_token=hf_token)

        logger.info("Loading training data (streaming)...")
        self._train_dataset = loader.load_split("train", streaming=True)
        self._train_dataset = self._train_dataset.map(
            loader.make_preprocess_fn(self._processor, augmenter),
            remove_columns=self._train_dataset.column_names,
        )

        try:
            self._eval_dataset = loader.load_split("validation", streaming=False)
            self._eval_dataset = self._eval_dataset.map(
                loader.make_preprocess_fn(self._processor, augmenter=None),
                remove_columns=self._eval_dataset.column_names,
            )
        except Exception as exc:
            logger.warning("Could not load validation split (%s); eval will be skipped.", exc)
            self._eval_dataset = None

    def build_training_args(self) -> Seq2SeqTrainingArguments:
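        """Build Seq2SeqTrainingArguments from the base training config.

        Evaluation-related options are enabled only when a validation split
        was successfully loaded in setup().
        """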
        tc = self.config["training"]
        output_dir = self.lang_config.get("output_dir", tc["output_dir"])
        return Seq2SeqTrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=tc["per_device_train_batch_size"],
            gradient_accumulation_steps=tc["gradient_accumulation_steps"],
            warmup_steps=tc["warmup_steps"],
            max_steps=tc["max_steps"],
            save_steps=tc["save_steps"],
            eval_steps=tc["eval_steps"] if self._eval_dataset is not None else None,
            evaluation_strategy="steps" if self._eval_dataset is not None else "no",
            learning_rate=tc["learning_rate"],
            fp16=tc["fp16"] and torch.cuda.is_available(),
            dataloader_num_workers=tc["dataloader_num_workers"],  # 0 on Windows
            predict_with_generate=True,
            generation_max_length=128,
            logging_steps=25,
            load_best_model_at_end=self._eval_dataset is not None,
            metric_for_best_model="wer",
            greater_is_better=False,
            report_to="none",
        )

    def merge_extra_data(
        self,
        feedback_records: list[dict],
        repeat: int = 3,
        waxal_cap: int = 500,
    ) -> None:
        """
        Merge feedback corrections into the training dataset.

        Materialises up to `waxal_cap` Waxal samples (converts streaming → Dataset),
        then appends `feedback_records` (each repeated `repeat` times for upsampling)
        preprocessed into {input_features, labels} format.

        Call this after setup() and before train().

        Args:
            feedback_records: List of dicts from corrections.jsonl with keys
                              'audio_file' (path) and 'corrected_text'.
            repeat:           How many times to repeat each feedback sample
                              (3× keeps corrections competitive with the Waxal baseline).
            waxal_cap:        Max Waxal samples to materialise (avoids OOM on Colab T4).
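
        Example (illustrative paths and text, not real data):
            trainer.setup()
            records = [
                {"audio_file": "data/feedback/rec_0001.wav",
                 "corrected_text": "corrected transcription for this clip"},
            ]
            trainer.merge_extra_data(records, repeat=3, waxal_cap=500)
            trainer.train()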
        """
        if self._peft_model is None:
            raise RuntimeError("Call setup() before merge_extra_data().")

        import librosa
        import numpy as np
        from datasets import Dataset, concatenate_datasets

        logger.info(
            "Merging %d feedback records (×%d) with Waxal (cap=%d)...",
            len(feedback_records), repeat, waxal_cap,
        )

        # ── 1. Materialise Waxal streaming dataset ─────────────────────────────
        waxal_rows: list[dict] = []
        for row in self._train_dataset:
            waxal_rows.append(row)
            if len(waxal_rows) >= waxal_cap:
                break
        waxal_ds = Dataset.from_list(waxal_rows)
        logger.info("Materialised %d Waxal samples.", len(waxal_ds))

        # ── 2. Preprocess feedback records ─────────────────────────────────────
        def _load_preprocess(rec: dict) -> dict | None:
            try:
                audio_np, _ = librosa.load(rec["audio_file"], sr=16000, mono=True)
                inputs = self._processor.feature_extractor(
                    audio_np, sampling_rate=16000, return_tensors="np"
                )
                labels = self._processor.tokenizer(
                    rec["corrected_text"], return_tensors="np"
                ).input_ids[0]
                return {
                    "input_features": inputs.input_features[0],
                    "labels": labels,
                }
            except Exception as e:
                logger.warning("Skipping feedback record %s: %s", rec.get("id", "?"), e)
                return None

        fb_rows = []
        for rec in feedback_records * repeat:
            processed = _load_preprocess(rec)
            if processed is not None:
                fb_rows.append(processed)

        if not fb_rows:
            logger.warning("No feedback records could be processed — using Waxal only.")
            self._train_dataset = waxal_ds
            return

        fb_ds = Dataset.from_list(fb_rows)
        logger.info("Preprocessed %d feedback rows (after ×%d repeat).", len(fb_ds), repeat)

        # ── 3. Concatenate and replace train_dataset ───────────────────────────
        self._train_dataset = concatenate_datasets([waxal_ds, fb_ds]).shuffle(seed=42)
        logger.info("Final training dataset: %d samples.", len(self._train_dataset))

    def train(self) -> None:
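        """Run Seq2SeqTrainer on the prepared datasets and save the final adapter weights."""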
        if self._peft_model is None:
            raise RuntimeError("Call setup() before train().")

        training_args = self.build_training_args()
        output_dir = self.lang_config.get("output_dir", self.config["training"]["output_dir"])

        collator = DataCollatorSpeechSeq2SeqWithPadding(
            processor=self._processor,
            decoder_start_token_id=self._backbone.model.config.decoder_start_token_id,
        )

        callbacks = [
            AdapterCheckpointCallback(output_dir),
            EarlyStoppingOnWER(patience=5),
        ]

        compute_metrics = make_compute_metrics(self._processor) if self._eval_dataset is not None else None

        trainer = Seq2SeqTrainer(
            model=self._peft_model,
            args=training_args,
            train_dataset=self._train_dataset,
            eval_dataset=self._eval_dataset,
            data_collator=collator,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            tokenizer=self._processor.feature_extractor,
        )

        logger.info("Starting training for language '%s'...", self.lang_config["language"])
        trainer.train()

        # Save final adapter weights
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        self._peft_model.save_pretrained(output_dir)
        logger.info("Adapter saved to %s", output_dir)