"""
Orchestrates full LoRA fine-tuning:
WhisperBackbone + PEFT LoraConfig + WaxalDataLoader + Seq2SeqTrainer
Usage:
trainer = WhisperLoRATrainer("configs/base_config.yaml", "configs/lora_bambara.yaml")
trainer.setup()
trainer.train()
"""
from __future__ import annotations

import logging
import os
from pathlib import Path

import torch
import yaml
from peft import LoraConfig, TaskType, get_peft_model
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

from src.data.augmentation import FieldNoiseAugmenter
from src.data.feature_extractor import DataCollatorSpeechSeq2SeqWithPadding
from src.data.waxal_loader import WaxalDataLoader
from src.engine.whisper_base import WhisperBackbone
from src.training.callbacks import AdapterCheckpointCallback, EarlyStoppingOnWER
from src.training.metrics import make_compute_metrics

logger = logging.getLogger(__name__)

class WhisperLoRATrainer:
    """Fine-tunes a language-specific LoRA adapter on top of Whisper."""

    def __init__(self, base_config_path: str, language_config_path: str) -> None:
        self._base_config_path = base_config_path
        with open(base_config_path) as f:
            self.config = yaml.safe_load(f)
        with open(language_config_path) as f:
            self.lang_config = yaml.safe_load(f)
        self._backbone: WhisperBackbone | None = None
        self._peft_model = None
        self._processor = None
        self._train_dataset = None
        self._eval_dataset = None

    def setup(self) -> None:
        """Load backbone, build LoRA config, prepare datasets."""
        hf_token = os.getenv("HF_TOKEN")
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # 1. Load backbone
        logger.info("Loading backbone model...")
        self._backbone = WhisperBackbone(config_path=self._base_config_path)
        self._backbone.load(device=device, hf_token=hf_token)
        self._processor = self._backbone.processor
        # Disable cache for training
        self._backbone.model.config.use_cache = False

        # 2. Wrap with LoRA
        lora_cfg = self.lang_config["lora"]
        lora_config = LoraConfig(
            r=lora_cfg["r"],
            lora_alpha=lora_cfg["lora_alpha"],
            target_modules=lora_cfg["target_modules"],
            lora_dropout=lora_cfg["lora_dropout"],
            bias=lora_cfg["bias"],
            task_type=TaskType.SEQ_2_SEQ_LM,
        )
        self._peft_model = get_peft_model(self._backbone.model, lora_config)
        self._peft_model.print_trainable_parameters()
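        # PEFT logs a one-line summary here, roughly:
        #   trainable params: ... || all params: ... || trainable%: ...
        # Only the LoRA matrices on the configured target_modules are trainable;
        # the Whisper backbone itself stays frozen.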

        # 3. Load data
        subset = self.lang_config["dataset_subset"]
        augmenter = FieldNoiseAugmenter(self.config["paths"]["noise_samples"], self.config)
        loader = WaxalDataLoader(subset, self.config, hf_token=hf_token)
        logger.info("Loading training data (streaming)...")
        self._train_dataset = loader.load_split("train", streaming=True)
        self._train_dataset = self._train_dataset.map(
            loader.make_preprocess_fn(self._processor, augmenter),
            remove_columns=self._train_dataset.column_names,
        )
        try:
            self._eval_dataset = loader.load_split("validation", streaming=False)
            self._eval_dataset = self._eval_dataset.map(
                loader.make_preprocess_fn(self._processor, augmenter=None),
                remove_columns=self._eval_dataset.column_names,
            )
        except Exception:
            logger.warning("No validation split found — eval will be skipped.")
            self._eval_dataset = None

    def build_training_args(self) -> Seq2SeqTrainingArguments:
        tc = self.config["training"]
        output_dir = self.lang_config.get("output_dir", tc["output_dir"])
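        # Effective batch size is per_device_train_batch_size × gradient_accumulation_steps
        # (e.g. 8 × 2 = 16 with illustrative values); warmup_steps and max_steps in the
        # config should be chosen with that effective size in mind.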
        return Seq2SeqTrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=tc["per_device_train_batch_size"],
            gradient_accumulation_steps=tc["gradient_accumulation_steps"],
            warmup_steps=tc["warmup_steps"],
            max_steps=tc["max_steps"],
            save_steps=tc["save_steps"],
            eval_steps=tc["eval_steps"] if self._eval_dataset is not None else None,
            evaluation_strategy="steps" if self._eval_dataset is not None else "no",
            learning_rate=tc["learning_rate"],
            fp16=tc["fp16"] and torch.cuda.is_available(),
            dataloader_num_workers=tc["dataloader_num_workers"],  # 0 on Windows
            predict_with_generate=True,
            generation_max_length=128,
            logging_steps=25,
            load_best_model_at_end=self._eval_dataset is not None,
            metric_for_best_model="wer",
            greater_is_better=False,
            report_to="none",
        )

    def merge_extra_data(
        self,
        feedback_records: list[dict],
        repeat: int = 3,
        waxal_cap: int = 500,
    ) -> None:
        """
        Merge feedback corrections into the training dataset.

        Materialises up to `waxal_cap` Waxal samples (converts streaming → Dataset),
        then appends `feedback_records` (each repeated `repeat` times for upsampling)
        preprocessed into {input_features, labels} format.

        Call this after setup() and before train().

        Args:
            feedback_records: List of dicts from corrections.jsonl with keys
                'audio_file' (path) and 'corrected_text'.
            repeat: How many times to repeat each feedback sample
                (3× keeps corrections competitive with the Waxal baseline).
            waxal_cap: Max Waxal samples to materialise (avoids OOM on a Colab T4).
        """
        if self._peft_model is None:
            raise RuntimeError("Call setup() before merge_extra_data().")
        import librosa
        from datasets import Dataset, concatenate_datasets

        logger.info(
            "Merging %d feedback records (×%d) with Waxal (cap=%d)...",
            len(feedback_records), repeat, waxal_cap,
        )

        # ── 1. Materialise Waxal streaming dataset ─────────────────────────────
        waxal_rows: list[dict] = []
        for row in self._train_dataset:
            waxal_rows.append(row)
            if len(waxal_rows) >= waxal_cap:
                break
        waxal_ds = Dataset.from_list(waxal_rows)
        logger.info("Materialised %d Waxal samples.", len(waxal_ds))

        # ── 2. Preprocess feedback records ─────────────────────────────────────
        def _load_preprocess(rec: dict) -> dict | None:
            try:
                audio_np, _ = librosa.load(rec["audio_file"], sr=16000, mono=True)
                inputs = self._processor.feature_extractor(
                    audio_np, sampling_rate=16000, return_tensors="np"
                )
                labels = self._processor.tokenizer(
                    rec["corrected_text"], return_tensors="np"
                ).input_ids[0]
                return {
                    "input_features": inputs.input_features[0],
                    "labels": labels,
                }
            except Exception as e:
                logger.warning("Skipping feedback record %s: %s", rec.get("id", "?"), e)
                return None

        fb_rows = []
        for rec in feedback_records * repeat:
            processed = _load_preprocess(rec)
            if processed is not None:
                fb_rows.append(processed)
        if not fb_rows:
            logger.warning("No feedback records could be processed — using Waxal only.")
            self._train_dataset = waxal_ds
            return
        fb_ds = Dataset.from_list(fb_rows)
        logger.info("Preprocessed %d feedback rows (after ×%d repeat).", len(fb_ds), repeat)

        # ── 3. Concatenate and replace train_dataset ───────────────────────────
        self._train_dataset = concatenate_datasets([waxal_ds, fb_ds]).shuffle(seed=42)
        logger.info("Final training dataset: %d samples.", len(self._train_dataset))

    def train(self) -> None:
        if self._peft_model is None:
            raise RuntimeError("Call setup() before train().")
        training_args = self.build_training_args()
        output_dir = self.lang_config.get("output_dir", self.config["training"]["output_dir"])
        collator = DataCollatorSpeechSeq2SeqWithPadding(
            processor=self._processor,
            decoder_start_token_id=self._backbone.model.config.decoder_start_token_id,
        )
        callbacks = [
            AdapterCheckpointCallback(output_dir),
            EarlyStoppingOnWER(patience=5),
        ]
        compute_metrics = (
            make_compute_metrics(self._processor) if self._eval_dataset is not None else None
        )
        trainer = Seq2SeqTrainer(
            model=self._peft_model,
            args=training_args,
            train_dataset=self._train_dataset,
            eval_dataset=self._eval_dataset,
            data_collator=collator,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            tokenizer=self._processor.feature_extractor,
        )
        logger.info("Starting training for language '%s'...", self.lang_config["language"])
        trainer.train()

        # Save final adapter weights
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        self._peft_model.save_pretrained(output_dir)
        logger.info("Adapter saved to %s", output_dir)