"""Load and preprocess the google/fleurs dataset for Bambara (bam) and Fula (ful).

Uses streaming to avoid downloading the full corpus before training.

google/waxal was removed from the Hub; google/fleurs is the maintained
replacement.
Subset mapping: bam → bam_ML (Bambara Mali), ful → ff_SN (Fula/Pular Senegal).
Column names (audio, transcription) are identical.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Callable, Iterator

import numpy as np
import torch
import torchaudio
from datasets import load_dataset

if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import WhisperProcessor

    from src.data.augmentation import FieldNoiseAugmenter

logger = logging.getLogger(__name__)

# google/fleurs column names (identical to the former google/waxal schema)
AUDIO_COL = "audio"
TEXT_COL = "transcription"
TARGET_SR = 16_000

# Map our short language codes to google/fleurs subset names.
# NOTE(review): the config names published for google/fleurs appear to be
# lowercase (e.g. "ff_sn"), and Bambara may not be among them — confirm both
# values against the dataset card before relying on this mapping.
_FLEURS_SUBSET = {
    "bam": "bam_ML",  # Bambara — Mali
    "ful": "ff_SN",  # Fula/Pular — Senegal
}


class WaxalDataLoader:
    """Streams the google/fleurs dataset and prepares examples for Whisper training."""

    def __init__(
        self,
        subset: str,
        config: dict,
        hf_token: str | None = None,
    ) -> None:
        """Store the loader settings after validating the language code.

        Args:
            subset: Short language code; must be a key of ``_FLEURS_SUBSET``.
            config: Configuration mapping, stored as-is on the instance.
            hf_token: Optional Hugging Face token forwarded to ``load_dataset``.

        Raises:
            ValueError: If ``subset`` is not a known language code.
        """
        if subset not in _FLEURS_SUBSET:
            raise ValueError(f"subset must be 'bam' or 'ful', got '{subset}'")
        self.subset = subset
        self._fleurs_subset = _FLEURS_SUBSET[subset]
        self.config = config
        self.hf_token = hf_token

    def load_split(
        self,
        split: str = "train",
        streaming: bool = True,
    ) -> "IterableDataset | Dataset":
        """Return a single split of google/fleurs.

        Args:
            split: Name of the split to load.
            streaming: Forwarded to ``load_dataset``. When true, the returned
                dataset is additionally shuffled with a fixed seed (42) and a
                1000-example buffer.

        Returns:
            Whatever ``load_dataset`` returns for the requested split,
            shuffled when ``streaming`` is true.
        """
        logger.info(
            "Loading google/fleurs subset=%s (%s) split=%s streaming=%s",
            self._fleurs_subset,
            self.subset,
            split,
            streaming,
        )
        ds = load_dataset(
            "google/fleurs",
            self._fleurs_subset,
            split=split,
            token=self.hf_token,
            streaming=streaming,
            trust_remote_code=True,
        )
        if streaming:
            ds = ds.shuffle(seed=42, buffer_size=1000)
        return ds

    def get_splits(self, streaming: bool = True) -> dict[str, "IterableDataset | Dataset"]:
        """Return train / validation / test splits.

        Loading is best-effort: a split whose load raises is logged and left
        out of the result instead of aborting the whole call.

        Args:
            streaming: Forwarded to :meth:`load_split` for every split.

        Returns:
            Mapping from split name to dataset, containing only the splits
            that loaded successfully (possibly empty).
        """
        splits: dict[str, "IterableDataset | Dataset"] = {}
        for split in ("train", "validation", "test"):
            try:
                splits[split] = self.load_split(split, streaming=streaming)
            except Exception:
                # Deliberately broad so one bad split does not abort the rest,
                # but attach the traceback: auth, network and config-name
                # failures would otherwise look like a missing split.
                logger.warning(
                    "Split '%s' not available for subset '%s'",
                    split,
                    self.subset,
                    exc_info=True,
                )
        return splits

    def make_preprocess_fn(
        self,
        processor: "WhisperProcessor",
        augmenter: "FieldNoiseAugmenter | None" = None,
    ) -> Callable[[dict], dict]:
        """Return a function that converts a raw google/fleurs example into model inputs.

        Args:
            processor: Provides the ``feature_extractor`` and ``tokenizer``
                used on each example.
            augmenter: Optional augmenter; applied only when it is not ``None``
                and its ``is_ready()`` returns true.

        Returns:
            A callable mapping one raw example dict to a dict with the keys
            ``"input_features"`` and ``"labels"``.
        """

        def preprocess(example: dict) -> dict:
            # Extract and resample audio
            audio_array = np.array(example[AUDIO_COL]["array"], dtype=np.float32)
            orig_sr: int = example[AUDIO_COL]["sampling_rate"]

            if orig_sr != TARGET_SR:
                # A leading dimension is added for the resample call and
                # removed again afterwards.
                tensor = torch.from_numpy(audio_array).unsqueeze(0)
                tensor = torchaudio.functional.resample(tensor, orig_sr, TARGET_SR)
                audio_array = tensor.squeeze(0).numpy()

            # Apply field noise augmentation if provided
            if augmenter is not None and augmenter.is_ready():
                audio_array = augmenter.augment(audio_array, TARGET_SR)

            # Extract log-mel features
            inputs = processor.feature_extractor(
                audio_array,
                sampling_rate=TARGET_SR,
                return_tensors="np",
            )
            # First (only) item of the batch; presumably shape (80, 3000) as
            # the original comment stated — depends on the checkpoint, confirm.
            input_features = inputs.input_features[0]

            # Tokenize transcript
            text: str = example[TEXT_COL]
            labels = processor.tokenizer(text, return_tensors="np").input_ids[0]

            return {
                "input_features": input_features,
                "labels": labels,
            }

        return preprocess

    def iter_processed(
        self,
        processor: "WhisperProcessor",
        split: str = "train",
        augmenter: "FieldNoiseAugmenter | None" = None,
    ) -> Iterator[dict]:
        """Yield preprocessed examples one at a time (streaming).

        Args:
            processor: Forwarded to :meth:`make_preprocess_fn`.
            split: Name of the split to stream.
            augmenter: Optional augmenter forwarded to
                :meth:`make_preprocess_fn`.

        Yields:
            One preprocessed example dict per raw example in the split.
        """
        ds = self.load_split(split, streaming=True)
        fn = self.make_preprocess_fn(processor, augmenter)
        for example in ds:
            yield fn(example)