"""
Orchestrates full LoRA fine-tuning:
WhisperBackbone + PEFT LoraConfig + WaxalDataLoader + Seq2SeqTrainer
Usage:
trainer = WhisperLoRATrainer("configs/base_config.yaml", "configs/lora_bambara.yaml")
trainer.setup()
trainer.train()
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
import torch
import yaml
from peft import LoraConfig, TaskType, get_peft_model
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from src.data.augmentation import FieldNoiseAugmenter
from src.data.feature_extractor import DataCollatorSpeechSeq2SeqWithPadding
from src.data.waxal_loader import WaxalDataLoader
from src.engine.whisper_base import WhisperBackbone
from src.training.callbacks import AdapterCheckpointCallback, EarlyStoppingOnWER
from src.training.metrics import make_compute_metrics
logger = logging.getLogger(__name__)
class WhisperLoRATrainer:
"""Fine-tunes a language-specific LoRA adapter on top of Whisper."""
def __init__(self, base_config_path: str, language_config_path: str) -> None:
self._base_config_path = base_config_path
with open(base_config_path) as f:
self.config = yaml.safe_load(f)
with open(language_config_path) as f:
self.lang_config = yaml.safe_load(f)
self._backbone: WhisperBackbone | None = None
self._peft_model = None
self._processor = None
self._train_dataset = None
self._eval_dataset = None
def setup(self) -> None:
"""Load backbone, build LoRA config, prepare datasets."""
hf_token = os.getenv("HF_TOKEN")
device = "cuda" if torch.cuda.is_available() else "cpu"
# 1. Load backbone
logger.info("Loading backbone model...")
self._backbone = WhisperBackbone(config_path=self._base_config_path)
self._backbone.load(device=device, hf_token=hf_token)
self._processor = self._backbone.processor
# Disable cache for training
self._backbone.model.config.use_cache = False
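        # (Key/value caching only matters for autoregressive generation; it is
        # not needed during teacher-forced training.)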
# 2. Wrap with LoRA
lora_cfg = self.lang_config["lora"]
lora_config = LoraConfig(
r=lora_cfg["r"],
lora_alpha=lora_cfg["lora_alpha"],
target_modules=lora_cfg["target_modules"],
lora_dropout=lora_cfg["lora_dropout"],
bias=lora_cfg["bias"],
task_type=TaskType.SEQ_2_SEQ_LM,
)
self._peft_model = get_peft_model(self._backbone.model, lora_config)
self._peft_model.print_trainable_parameters()
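        # Only the injected LoRA matrices are trainable; the Whisper backbone
        # weights stay frozen, which keeps the saved adapter small.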
# 3. Load data
subset = self.lang_config["dataset_subset"]
augmenter = FieldNoiseAugmenter(self.config["paths"]["noise_samples"], self.config)
loader = WaxalDataLoader(subset, self.config, hf_token=hf_token)
logger.info("Loading training data (streaming)...")
self._train_dataset = loader.load_split("train", streaming=True)
self._train_dataset = self._train_dataset.map(
loader.make_preprocess_fn(self._processor, augmenter),
remove_columns=self._train_dataset.column_names,
)
try:
self._eval_dataset = loader.load_split("validation", streaming=False)
self._eval_dataset = self._eval_dataset.map(
loader.make_preprocess_fn(self._processor, augmenter=None),
remove_columns=self._eval_dataset.column_names,
)
        except Exception as exc:
            logger.warning("No validation split found (%s); evaluation will be skipped.", exc)
self._eval_dataset = None
def build_training_args(self) -> Seq2SeqTrainingArguments:
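        """Build Seq2SeqTrainingArguments from the `training` section of the base config.

        Evaluation, checkpoint selection, and WER-based best-model tracking are
        only enabled when a validation split was loaded during setup().
        """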
tc = self.config["training"]
output_dir = self.lang_config.get("output_dir", tc["output_dir"])
return Seq2SeqTrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=tc["per_device_train_batch_size"],
gradient_accumulation_steps=tc["gradient_accumulation_steps"],
warmup_steps=tc["warmup_steps"],
max_steps=tc["max_steps"],
save_steps=tc["save_steps"],
eval_steps=tc["eval_steps"] if self._eval_dataset is not None else None,
evaluation_strategy="steps" if self._eval_dataset is not None else "no",
learning_rate=tc["learning_rate"],
fp16=tc["fp16"] and torch.cuda.is_available(),
dataloader_num_workers=tc["dataloader_num_workers"], # 0 on Windows
predict_with_generate=True,
generation_max_length=128,
logging_steps=25,
load_best_model_at_end=self._eval_dataset is not None,
metric_for_best_model="wer",
greater_is_better=False,
report_to="none",
)
def merge_extra_data(
self,
feedback_records: list[dict],
repeat: int = 3,
waxal_cap: int = 500,
) -> None:
"""
        Merge feedback corrections into the training dataset.

        Materialises up to `waxal_cap` Waxal samples (converting the streaming
        dataset into an in-memory `Dataset`), then appends `feedback_records`
        (each repeated `repeat` times for upsampling), preprocessed into the
        {input_features, labels} format.

        Call this after setup() and before train().

        Args:
            feedback_records: List of dicts from corrections.jsonl with keys
                'audio_file' (path) and 'corrected_text'.
            repeat: How many times to repeat each feedback sample
                (3× keeps corrections competitive with the Waxal baseline).
            waxal_cap: Max Waxal samples to materialise (avoids OOM on a Colab T4).
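
        Example (illustrative; the path and transcript below are placeholders):
            records = [
                {"audio_file": "data/feedback/clip_001.wav",
                 "corrected_text": "aw ni ce"},
            ]
            trainer.merge_extra_data(records, repeat=3, waxal_cap=500)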
"""
if self._peft_model is None:
raise RuntimeError("Call setup() before merge_extra_data().")
import librosa
from datasets import Dataset, concatenate_datasets
logger.info(
"Merging %d feedback records (×%d) with Waxal (cap=%d)...",
len(feedback_records), repeat, waxal_cap,
)
# ── 1. Materialise Waxal streaming dataset ─────────────────────────────
waxal_rows: list[dict] = []
for row in self._train_dataset:
waxal_rows.append(row)
if len(waxal_rows) >= waxal_cap:
break
waxal_ds = Dataset.from_list(waxal_rows)
logger.info("Materialised %d Waxal samples.", len(waxal_ds))
# ── 2. Preprocess feedback records ─────────────────────────────────────
def _load_preprocess(rec: dict) -> dict | None:
try:
audio_np, _ = librosa.load(rec["audio_file"], sr=16000, mono=True)
inputs = self._processor.feature_extractor(
audio_np, sampling_rate=16000, return_tensors="np"
)
labels = self._processor.tokenizer(
rec["corrected_text"], return_tensors="np"
).input_ids[0]
return {
"input_features": inputs.input_features[0],
"labels": labels,
}
except Exception as e:
logger.warning("Skipping feedback record %s: %s", rec.get("id", "?"), e)
return None
fb_rows = []
for rec in feedback_records * repeat:
processed = _load_preprocess(rec)
if processed is not None:
fb_rows.append(processed)
if not fb_rows:
logger.warning("No feedback records could be processed — using Waxal only.")
self._train_dataset = waxal_ds
return
fb_ds = Dataset.from_list(fb_rows)
logger.info("Preprocessed %d feedback rows (after ×%d repeat).", len(fb_ds), repeat)
# ── 3. Concatenate and replace train_dataset ───────────────────────────
self._train_dataset = concatenate_datasets([waxal_ds, fb_ds]).shuffle(seed=42)
logger.info("Final training dataset: %d samples.", len(self._train_dataset))
def train(self) -> None:
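        """Run Seq2SeqTrainer on the prepared datasets and save the final adapter.

        Only the LoRA adapter weights (not the full Whisper model) are written
        to the configured output_dir via `save_pretrained`.
        """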
if self._peft_model is None:
raise RuntimeError("Call setup() before train().")
training_args = self.build_training_args()
        output_dir = training_args.output_dir  # same resolution as build_training_args()
collator = DataCollatorSpeechSeq2SeqWithPadding(
processor=self._processor,
decoder_start_token_id=self._backbone.model.config.decoder_start_token_id,
)
callbacks = [
AdapterCheckpointCallback(output_dir),
EarlyStoppingOnWER(patience=5),
]
compute_metrics = make_compute_metrics(self._processor) if self._eval_dataset is not None else None
trainer = Seq2SeqTrainer(
model=self._peft_model,
args=training_args,
train_dataset=self._train_dataset,
eval_dataset=self._eval_dataset,
data_collator=collator,
compute_metrics=compute_metrics,
callbacks=callbacks,
tokenizer=self._processor.feature_extractor,
)
logger.info("Starting training for language '%s'...", self.lang_config["language"])
trainer.train()
# Save final adapter weights
Path(output_dir).mkdir(parents=True, exist_ok=True)
self._peft_model.save_pretrained(output_dir)
logger.info("Adapter saved to %s", output_dir)