""" Orchestrates full LoRA fine-tuning: WhisperBackbone + PEFT LoraConfig + WaxalDataLoader + Seq2SeqTrainer Usage: trainer = WhisperLoRATrainer("configs/base_config.yaml", "configs/lora_bambara.yaml") trainer.setup() trainer.train() """ from __future__ import annotations import logging import os from pathlib import Path import torch import yaml from peft import LoraConfig, TaskType, get_peft_model from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments from src.data.augmentation import FieldNoiseAugmenter from src.data.feature_extractor import DataCollatorSpeechSeq2SeqWithPadding from src.data.waxal_loader import WaxalDataLoader from src.engine.whisper_base import WhisperBackbone from src.training.callbacks import AdapterCheckpointCallback, EarlyStoppingOnWER from src.training.metrics import make_compute_metrics logger = logging.getLogger(__name__) class WhisperLoRATrainer: """Fine-tunes a language-specific LoRA adapter on top of Whisper.""" def __init__(self, base_config_path: str, language_config_path: str) -> None: self._base_config_path = base_config_path with open(base_config_path) as f: self.config = yaml.safe_load(f) with open(language_config_path) as f: self.lang_config = yaml.safe_load(f) self._backbone: WhisperBackbone | None = None self._peft_model = None self._processor = None self._train_dataset = None self._eval_dataset = None def setup(self) -> None: """Load backbone, build LoRA config, prepare datasets.""" hf_token = os.getenv("HF_TOKEN") device = "cuda" if torch.cuda.is_available() else "cpu" # 1. Load backbone logger.info("Loading backbone model...") self._backbone = WhisperBackbone(config_path=self._base_config_path) self._backbone.load(device=device, hf_token=hf_token) self._processor = self._backbone.processor # Disable cache for training self._backbone.model.config.use_cache = False # 2. Wrap with LoRA lora_cfg = self.lang_config["lora"] lora_config = LoraConfig( r=lora_cfg["r"], lora_alpha=lora_cfg["lora_alpha"], target_modules=lora_cfg["target_modules"], lora_dropout=lora_cfg["lora_dropout"], bias=lora_cfg["bias"], task_type=TaskType.SEQ_2_SEQ_LM, ) self._peft_model = get_peft_model(self._backbone.model, lora_config) self._peft_model.print_trainable_parameters() # 3. Load data subset = self.lang_config["dataset_subset"] augmenter = FieldNoiseAugmenter(self.config["paths"]["noise_samples"], self.config) loader = WaxalDataLoader(subset, self.config, hf_token=hf_token) logger.info("Loading training data (streaming)...") self._train_dataset = loader.load_split("train", streaming=True) self._train_dataset = self._train_dataset.map( loader.make_preprocess_fn(self._processor, augmenter), remove_columns=self._train_dataset.column_names, ) try: self._eval_dataset = loader.load_split("validation", streaming=False) self._eval_dataset = self._eval_dataset.map( loader.make_preprocess_fn(self._processor, augmenter=None), remove_columns=self._eval_dataset.column_names, ) except Exception: logger.warning("No validation split found — eval will be skipped.") self._eval_dataset = None def build_training_args(self) -> Seq2SeqTrainingArguments: tc = self.config["training"] output_dir = self.lang_config.get("output_dir", tc["output_dir"]) return Seq2SeqTrainingArguments( output_dir=output_dir, per_device_train_batch_size=tc["per_device_train_batch_size"], gradient_accumulation_steps=tc["gradient_accumulation_steps"], warmup_steps=tc["warmup_steps"], max_steps=tc["max_steps"], save_steps=tc["save_steps"], eval_steps=tc["eval_steps"] if self._eval_dataset is not None else None, evaluation_strategy="steps" if self._eval_dataset is not None else "no", learning_rate=tc["learning_rate"], fp16=tc["fp16"] and torch.cuda.is_available(), dataloader_num_workers=tc["dataloader_num_workers"], # 0 on Windows predict_with_generate=True, generation_max_length=128, logging_steps=25, load_best_model_at_end=self._eval_dataset is not None, metric_for_best_model="wer", greater_is_better=False, report_to="none", ) def merge_extra_data( self, feedback_records: list[dict], repeat: int = 3, waxal_cap: int = 500, ) -> None: """ Merge feedback corrections into the training dataset. Materializes up to `waxal_cap` Waxal samples (converts streaming → Dataset), then appends `feedback_records` (each repeated `repeat` times for upsampling) preprocessed into {input_features, labels} format. Call this after setup() and before train(). Args: feedback_records: List of dicts from corrections.jsonl with keys 'audio_file' (path) and 'corrected_text'. repeat: How many times to repeat each feedback sample (3× keeps corrections competitive with Waxal baseline). waxal_cap: Max Waxal samples to materialise (avoids OOM on Colab T4). """ if self._peft_model is None: raise RuntimeError("Call setup() before merge_extra_data().") import librosa import numpy as np from datasets import Dataset, concatenate_datasets logger.info( "Merging %d feedback records (×%d) with Waxal (cap=%d)...", len(feedback_records), repeat, waxal_cap, ) # ── 1. Materialise Waxal streaming dataset ───────────────────────────── waxal_rows: list[dict] = [] for row in self._train_dataset: waxal_rows.append(row) if len(waxal_rows) >= waxal_cap: break waxal_ds = Dataset.from_list(waxal_rows) logger.info("Materialised %d Waxal samples.", len(waxal_ds)) # ── 2. Preprocess feedback records ───────────────────────────────────── def _load_preprocess(rec: dict) -> dict | None: try: audio_np, _ = librosa.load(rec["audio_file"], sr=16000, mono=True) inputs = self._processor.feature_extractor( audio_np, sampling_rate=16000, return_tensors="np" ) labels = self._processor.tokenizer( rec["corrected_text"], return_tensors="np" ).input_ids[0] return { "input_features": inputs.input_features[0], "labels": labels, } except Exception as e: logger.warning("Skipping feedback record %s: %s", rec.get("id", "?"), e) return None fb_rows = [] for rec in feedback_records * repeat: processed = _load_preprocess(rec) if processed is not None: fb_rows.append(processed) if not fb_rows: logger.warning("No feedback records could be processed — using Waxal only.") self._train_dataset = waxal_ds return fb_ds = Dataset.from_list(fb_rows) logger.info("Preprocessed %d feedback rows (after ×%d repeat).", len(fb_ds), repeat) # ── 3. Concatenate and replace train_dataset ─────────────────────────── self._train_dataset = concatenate_datasets([waxal_ds, fb_ds]).shuffle(seed=42) logger.info("Final training dataset: %d samples.", len(self._train_dataset)) def train(self) -> None: if self._peft_model is None: raise RuntimeError("Call setup() before train().") training_args = self.build_training_args() output_dir = self.lang_config.get("output_dir", self.config["training"]["output_dir"]) collator = DataCollatorSpeechSeq2SeqWithPadding( processor=self._processor, decoder_start_token_id=self._backbone.model.config.decoder_start_token_id, ) callbacks = [ AdapterCheckpointCallback(output_dir), EarlyStoppingOnWER(patience=5), ] compute_metrics = make_compute_metrics(self._processor) if self._eval_dataset is not None else None trainer = Seq2SeqTrainer( model=self._peft_model, args=training_args, train_dataset=self._train_dataset, eval_dataset=self._eval_dataset, data_collator=collator, compute_metrics=compute_metrics, callbacks=callbacks, tokenizer=self._processor.feature_extractor, ) logger.info("Starting training for language '%s'...", self.lang_config["language"]) trainer.train() # Save final adapter weights Path(output_dir).mkdir(parents=True, exist_ok=True) self._peft_model.save_pretrained(output_dir) logger.info("Adapter saved to %s", output_dir)