"""
Loads and preprocesses the google/fleurs dataset for Bambara (bam) and Fula (ful).

Uses streaming to avoid downloading the full corpus before training.
google/waxal was removed from the Hub; google/fleurs is the maintained replacement.
Subset mapping: bam → bam_ML (Bambara, Mali), ful → ff_SN (Fula/Pular, Senegal).
Column names (audio, transcription) are identical.
"""
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Callable, Iterator

import numpy as np
import torch
import torchaudio
from datasets import load_dataset

if TYPE_CHECKING:
    # Imported for annotations only — avoids the runtime import cost of
    # transformers and keeps src.data.augmentation out of any import cycle.
    from datasets import Dataset, IterableDataset
    from transformers import WhisperProcessor

    from src.data.augmentation import FieldNoiseAugmenter
logger = logging.getLogger(__name__)

# google/fleurs column names (identical to the former google/waxal schema).
# AUDIO_COL values are dicts with "array" and "sampling_rate" keys (see the
# preprocess fn below, which reads both).
AUDIO_COL: str = "audio"
TEXT_COL: str = "transcription"
# Whisper feature extractors operate on 16 kHz audio; everything is resampled to this.
TARGET_SR: int = 16_000

# Map our short language codes to google/fleurs subset (config) names.
_FLEURS_SUBSET: dict[str, str] = {
    "bam": "bam_ML",  # Bambara — Mali
    "ful": "ff_SN",   # Fula/Pular — Senegal
}
class WaxalDataLoader:
    """Streams the google/fleurs dataset and prepares examples for Whisper training.

    Kept under the historical "Waxal" name so existing callers keep working;
    google/fleurs is the maintained replacement for the removed google/waxal.
    """

    def __init__(
        self,
        subset: str,
        config: dict,
        hf_token: str | None = None,
    ) -> None:
        """Create a loader for one language subset.

        Args:
            subset: Short language code, ``"bam"`` (Bambara) or ``"ful"`` (Fula).
            config: Project configuration dict. Optional keys read here:
                ``shuffle_seed`` (int, default 42) and ``shuffle_buffer_size``
                (int, default 1000), controlling the streaming shuffle.
            hf_token: Hugging Face access token, or ``None`` for anonymous access.

        Raises:
            ValueError: If ``subset`` is not a supported language code.
        """
        if subset not in _FLEURS_SUBSET:
            raise ValueError(f"subset must be 'bam' or 'ful', got '{subset}'")
        self.subset = subset
        self._fleurs_subset = _FLEURS_SUBSET[subset]
        self.config = config
        self.hf_token = hf_token

    def load_split(self, split: str = "train", streaming: bool = True) -> "IterableDataset | Dataset":
        """Return a single split of google/fleurs.

        When streaming, the split is shuffled with a buffered shuffle whose seed
        and buffer size come from ``config`` (previously hard-coded to 42/1000;
        those remain the defaults, so existing behavior is unchanged).
        """
        logger.info(
            "Loading google/fleurs subset=%s (%s) split=%s streaming=%s",
            self._fleurs_subset, self.subset, split, streaming,
        )
        ds = load_dataset(
            "google/fleurs",
            self._fleurs_subset,
            split=split,
            token=self.hf_token,
            streaming=streaming,
            # NOTE(review): trust_remote_code executes dataset-repo code; recent
            # `datasets` releases deprecate it and google/fleurs should load
            # without it — kept for compatibility with older installs, but
            # confirm and drop when possible.
            trust_remote_code=True,
        )
        if streaming:
            seed = int(self.config.get("shuffle_seed", 42))
            buffer_size = int(self.config.get("shuffle_buffer_size", 1000))
            ds = ds.shuffle(seed=seed, buffer_size=buffer_size)
        return ds

    def get_splits(self, streaming: bool = True) -> dict[str, "IterableDataset | Dataset"]:
        """Return whichever of the train/validation/test splits exist.

        Best-effort: a missing split is logged and skipped rather than raised,
        so callers always get a (possibly partial) mapping.
        """
        splits: dict[str, "IterableDataset | Dataset"] = {}
        for split in ("train", "validation", "test"):
            try:
                splits[split] = self.load_split(split, streaming=streaming)
            except Exception as exc:  # deliberate: tolerate absent splits
                # Log the underlying reason — previously discarded, which made
                # auth/network failures indistinguishable from a missing split.
                logger.warning(
                    "Split '%s' not available for subset '%s': %s",
                    split, self.subset, exc,
                )
        return splits

    def make_preprocess_fn(
        self,
        processor: "WhisperProcessor",
        augmenter: "FieldNoiseAugmenter | None" = None,
    ) -> Callable[[dict], dict]:
        """Return a function that converts a raw example into Whisper model inputs.

        The returned callable maps an example with ``AUDIO_COL``/``TEXT_COL``
        fields to ``{"input_features": ndarray, "labels": ndarray}``.

        Args:
            processor: Whisper processor supplying the feature extractor and tokenizer.
            augmenter: Optional noise augmenter, applied only when ``is_ready()``.
        """

        def preprocess(example: dict) -> dict:
            # Extract audio; asarray avoids a copy when it is already float32.
            audio_array = np.asarray(example[AUDIO_COL]["array"], dtype=np.float32)
            orig_sr: int = example[AUDIO_COL]["sampling_rate"]
            # Resample to Whisper's expected rate if needed.
            if orig_sr != TARGET_SR:
                tensor = torch.from_numpy(audio_array).unsqueeze(0)
                tensor = torchaudio.functional.resample(tensor, orig_sr, TARGET_SR)
                audio_array = tensor.squeeze(0).numpy()
            # Apply field noise augmentation if provided and ready.
            if augmenter is not None and augmenter.is_ready():
                audio_array = augmenter.augment(audio_array, TARGET_SR)
            # Extract log-mel features; shape is (n_mels, n_frames) — typically
            # (80, 3000), but 128 mels for large-v3-style extractors.
            inputs = processor.feature_extractor(
                audio_array,
                sampling_rate=TARGET_SR,
                return_tensors="np",
            )
            input_features = inputs.input_features[0]
            # Tokenize the transcript into label ids.
            text: str = example[TEXT_COL]
            labels = processor.tokenizer(text, return_tensors="np").input_ids[0]
            return {
                "input_features": input_features,
                "labels": labels,
            }

        return preprocess

    def iter_processed(
        self,
        processor: "WhisperProcessor",
        split: str = "train",
        augmenter: "FieldNoiseAugmenter | None" = None,
    ) -> Iterator[dict]:
        """Yield preprocessed examples one at a time (streaming)."""
        ds = self.load_split(split, streaming=True)
        fn = self.make_preprocess_fn(processor, augmenter)
        for example in ds:
            yield fn(example)