| """ | |
| Orchestrates full LoRA fine-tuning: | |
| WhisperBackbone + PEFT LoraConfig + WaxalDataLoader + Seq2SeqTrainer | |
| Usage: | |
| trainer = WhisperLoRATrainer("configs/base_config.yaml", "configs/lora_bambara.yaml") | |
| trainer.setup() | |
| trainer.train() | |
| """ | |
from __future__ import annotations

import logging
import os
from pathlib import Path

import torch
import yaml
from peft import LoraConfig, TaskType, get_peft_model
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

from src.data.augmentation import FieldNoiseAugmenter
from src.data.feature_extractor import DataCollatorSpeechSeq2SeqWithPadding
from src.data.waxal_loader import WaxalDataLoader
from src.engine.whisper_base import WhisperBackbone
from src.training.callbacks import AdapterCheckpointCallback, EarlyStoppingOnWER
from src.training.metrics import make_compute_metrics

logger = logging.getLogger(__name__)


class WhisperLoRATrainer:
    """Fine-tunes a language-specific LoRA adapter on top of Whisper."""

    def __init__(self, base_config_path: str, language_config_path: str) -> None:
        self._base_config_path = base_config_path
        with open(base_config_path) as f:
            self.config = yaml.safe_load(f)
        with open(language_config_path) as f:
            self.lang_config = yaml.safe_load(f)

        self._backbone: WhisperBackbone | None = None
        self._peft_model = None
        self._processor = None
        self._train_dataset = None
        self._eval_dataset = None

    def setup(self) -> None:
        """Load backbone, build LoRA config, prepare datasets."""
        hf_token = os.getenv("HF_TOKEN")
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # 1. Load backbone
        logger.info("Loading backbone model...")
        self._backbone = WhisperBackbone(config_path=self._base_config_path)
        self._backbone.load(device=device, hf_token=hf_token)
        self._processor = self._backbone.processor
        # Disable cache for training
        self._backbone.model.config.use_cache = False

        # 2. Wrap with LoRA
        lora_cfg = self.lang_config["lora"]
        lora_config = LoraConfig(
            r=lora_cfg["r"],
            lora_alpha=lora_cfg["lora_alpha"],
            target_modules=lora_cfg["target_modules"],
            lora_dropout=lora_cfg["lora_dropout"],
            bias=lora_cfg["bias"],
            task_type=TaskType.SEQ_2_SEQ_LM,
        )
        self._peft_model = get_peft_model(self._backbone.model, lora_config)
        self._peft_model.print_trainable_parameters()

        # 3. Load data
        subset = self.lang_config["dataset_subset"]
        augmenter = FieldNoiseAugmenter(self.config["paths"]["noise_samples"], self.config)
        loader = WaxalDataLoader(subset, self.config, hf_token=hf_token)

        logger.info("Loading training data (streaming)...")
        self._train_dataset = loader.load_split("train", streaming=True)
        self._train_dataset = self._train_dataset.map(
            loader.make_preprocess_fn(self._processor, augmenter),
            remove_columns=self._train_dataset.column_names,
        )

        try:
            self._eval_dataset = loader.load_split("validation", streaming=False)
            self._eval_dataset = self._eval_dataset.map(
                loader.make_preprocess_fn(self._processor, augmenter=None),
                remove_columns=self._eval_dataset.column_names,
            )
        except Exception:
            logger.warning("No validation split found — eval will be skipped.")
            self._eval_dataset = None
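
    # Illustrative sketch of the language config consumed in setup() and
    # build_training_args(). The keys mirror the lookups above; the example values
    # (r, alpha, target module names, paths) are placeholders, not project defaults:
    #
    #   language: bambara
    #   dataset_subset: bambara
    #   output_dir: adapters/bambara        # optional; falls back to training.output_dir
    #   lora:
    #     r: 16
    #     lora_alpha: 32
    #     target_modules: ["q_proj", "v_proj"]
    #     lora_dropout: 0.05
    #     bias: "none"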

    def build_training_args(self) -> Seq2SeqTrainingArguments:
        tc = self.config["training"]
        output_dir = self.lang_config.get("output_dir", tc["output_dir"])
        return Seq2SeqTrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=tc["per_device_train_batch_size"],
            gradient_accumulation_steps=tc["gradient_accumulation_steps"],
            warmup_steps=tc["warmup_steps"],
            max_steps=tc["max_steps"],
            save_steps=tc["save_steps"],
            eval_steps=tc["eval_steps"] if self._eval_dataset is not None else None,
            evaluation_strategy="steps" if self._eval_dataset is not None else "no",
            learning_rate=tc["learning_rate"],
            fp16=tc["fp16"] and torch.cuda.is_available(),
            dataloader_num_workers=tc["dataloader_num_workers"],  # 0 on Windows
            predict_with_generate=True,
            generation_max_length=128,
            logging_steps=25,
            load_best_model_at_end=self._eval_dataset is not None,
            metric_for_best_model="wer",
            greater_is_better=False,
            report_to="none",
        )
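
    # Illustrative sketch of the `training:` section of the base config read by
    # build_training_args(). Keys match the lookups above; values are placeholders:
    #
    #   training:
    #     output_dir: outputs/whisper-lora
    #     per_device_train_batch_size: 8
    #     gradient_accumulation_steps: 2
    #     warmup_steps: 50
    #     max_steps: 2000
    #     save_steps: 200
    #     eval_steps: 200
    #     learning_rate: 1.0e-4
    #     fp16: true
    #     dataloader_num_workers: 2   # use 0 on Windows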

    def merge_extra_data(
        self,
        feedback_records: list[dict],
        repeat: int = 3,
        waxal_cap: int = 500,
    ) -> None:
        """
        Merge feedback corrections into the training dataset.

        Materializes up to `waxal_cap` Waxal samples (converts streaming → Dataset),
        then appends `feedback_records` (each repeated `repeat` times for upsampling)
        preprocessed into {input_features, labels} format.

        Call this after setup() and before train().

        Args:
            feedback_records: List of dicts from corrections.jsonl with keys
                'audio_file' (path) and 'corrected_text'.
            repeat: How many times to repeat each feedback sample
                (3× keeps corrections competitive with Waxal baseline).
            waxal_cap: Max Waxal samples to materialise (avoids OOM on Colab T4).
        """
        if self._peft_model is None:
            raise RuntimeError("Call setup() before merge_extra_data().")

        import librosa
        import numpy as np
        from datasets import Dataset, concatenate_datasets

        logger.info(
            "Merging %d feedback records (×%d) with Waxal (cap=%d)...",
            len(feedback_records), repeat, waxal_cap,
        )

        # ── 1. Materialise Waxal streaming dataset ─────────────────────────────
        waxal_rows: list[dict] = []
        for row in self._train_dataset:
            waxal_rows.append(row)
            if len(waxal_rows) >= waxal_cap:
                break
        waxal_ds = Dataset.from_list(waxal_rows)
        logger.info("Materialised %d Waxal samples.", len(waxal_ds))

        # ── 2. Preprocess feedback records ─────────────────────────────────────
        def _load_preprocess(rec: dict) -> dict | None:
            try:
                audio_np, _ = librosa.load(rec["audio_file"], sr=16000, mono=True)
                inputs = self._processor.feature_extractor(
                    audio_np, sampling_rate=16000, return_tensors="np"
                )
                labels = self._processor.tokenizer(
                    rec["corrected_text"], return_tensors="np"
                ).input_ids[0]
                return {
                    "input_features": inputs.input_features[0],
                    "labels": labels,
                }
            except Exception as e:
                logger.warning("Skipping feedback record %s: %s", rec.get("id", "?"), e)
                return None

        fb_rows = []
        for rec in feedback_records * repeat:
            processed = _load_preprocess(rec)
            if processed is not None:
                fb_rows.append(processed)

        if not fb_rows:
            logger.warning("No feedback records could be processed — using Waxal only.")
            self._train_dataset = waxal_ds
            return

        fb_ds = Dataset.from_list(fb_rows)
        logger.info("Preprocessed %d feedback rows (after ×%d repeat).", len(fb_ds), repeat)

        # ── 3. Concatenate and replace train_dataset ───────────────────────────
        self._train_dataset = concatenate_datasets([waxal_ds, fb_ds]).shuffle(seed=42)
        logger.info("Final training dataset: %d samples.", len(self._train_dataset))

    def train(self) -> None:
        if self._peft_model is None:
            raise RuntimeError("Call setup() before train().")

        training_args = self.build_training_args()
        output_dir = self.lang_config.get("output_dir", self.config["training"]["output_dir"])

        collator = DataCollatorSpeechSeq2SeqWithPadding(
            processor=self._processor,
            decoder_start_token_id=self._backbone.model.config.decoder_start_token_id,
        )
        callbacks = [
            AdapterCheckpointCallback(output_dir),
            EarlyStoppingOnWER(patience=5),
        ]
        compute_metrics = (
            make_compute_metrics(self._processor) if self._eval_dataset is not None else None
        )

        trainer = Seq2SeqTrainer(
            model=self._peft_model,
            args=training_args,
            train_dataset=self._train_dataset,
            eval_dataset=self._eval_dataset,
            data_collator=collator,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            tokenizer=self._processor.feature_extractor,
        )

        logger.info("Starting training for language '%s'...", self.lang_config["language"])
        trainer.train()

        # Save final adapter weights
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        self._peft_model.save_pretrained(output_dir)
        logger.info("Adapter saved to %s", output_dir)