# ground-zero/src/data/waxal_loader.py
"""
Loads and preprocesses the google/fleurs dataset for Bambara (bam) and Fula (ful).
Uses streaming to avoid downloading the full corpus before training.
google/waxal was removed from the Hub; google/fleurs is the maintained replacement.
Subset mapping: bam → bam_ML (Bambara Mali), ful → ff_SN (Fula/Pular Senegal).
Column names (audio, transcription) are identical.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Callable, Iterator
import numpy as np
import torch
import torchaudio
from datasets import load_dataset
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import WhisperProcessor
from src.data.augmentation import FieldNoiseAugmenter
logger = logging.getLogger(__name__)
# google/fleurs column names (identical to the former google/waxal schema)
AUDIO_COL = "audio"
TEXT_COL = "transcription"
TARGET_SR = 16_000
# Map our short language codes to google/fleurs subset names
_FLEURS_SUBSET = {
"bam": "bam_ML", # Bambara — Mali
"ful": "ff_SN", # Fula/Pular — Senegal
}
class WaxalDataLoader:
"""Streams the google/fleurs dataset and prepares examples for Whisper training."""
def __init__(
self,
subset: str,
config: dict,
hf_token: str | None = None,
) -> None:
if subset not in _FLEURS_SUBSET:
raise ValueError(f"subset must be 'bam' or 'ful', got '{subset}'")
self.subset = subset
self._fleurs_subset = _FLEURS_SUBSET[subset]
self.config = config
self.hf_token = hf_token
def load_split(self, split: str = "train", streaming: bool = True) -> "IterableDataset | Dataset":
"""Return a single split of google/fleurs."""
logger.info(
"Loading google/fleurs subset=%s (%s) split=%s streaming=%s",
self._fleurs_subset, self.subset, split, streaming,
)
ds = load_dataset(
"google/fleurs",
self._fleurs_subset,
split=split,
token=self.hf_token,
streaming=streaming,
trust_remote_code=True,
)
if streaming:
ds = ds.shuffle(seed=42, buffer_size=1000)
return ds
def get_splits(self, streaming: bool = True) -> dict[str, "IterableDataset | Dataset"]:
"""Return train / validation / test splits."""
splits = {}
for split in ("train", "validation", "test"):
try:
splits[split] = self.load_split(split, streaming=streaming)
            except Exception as exc:
                logger.warning(
                    "Split '%s' not available for subset '%s': %s",
                    split, self.subset, exc,
                )
return splits
def make_preprocess_fn(
self,
processor: "WhisperProcessor",
augmenter: "FieldNoiseAugmenter | None" = None,
) -> Callable[[dict], dict]:
"""Return a function that converts a raw Waxal example into model inputs."""
def preprocess(example: dict) -> dict:
# Extract and resample audio
audio_array = np.array(example[AUDIO_COL]["array"], dtype=np.float32)
orig_sr: int = example[AUDIO_COL]["sampling_rate"]
if orig_sr != TARGET_SR:
tensor = torch.from_numpy(audio_array).unsqueeze(0)
tensor = torchaudio.functional.resample(tensor, orig_sr, TARGET_SR)
audio_array = tensor.squeeze(0).numpy()
# Apply field noise augmentation if provided
if augmenter is not None and augmenter.is_ready():
audio_array = augmenter.augment(audio_array, TARGET_SR)
# Extract log-mel features
inputs = processor.feature_extractor(
audio_array,
sampling_rate=TARGET_SR,
return_tensors="np",
)
            input_features = inputs.input_features[0]  # (num_mel_bins, 3000); 80 mel bins for most Whisper checkpoints
# Tokenize transcript
text: str = example[TEXT_COL]
labels = processor.tokenizer(text, return_tensors="np").input_ids[0]
return {
"input_features": input_features,
"labels": labels,
}
return preprocess
def iter_processed(
self,
processor: "WhisperProcessor",
split: str = "train",
augmenter: "FieldNoiseAugmenter | None" = None,
) -> Iterator[dict]:
"""Yield preprocessed examples one at a time (streaming)."""
ds = self.load_split(split, streaming=True)
fn = self.make_preprocess_fn(processor, augmenter)
for example in ds:
yield fn(example)
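

# Minimal usage sketch (not part of the module above): wires the loader to a
# WhisperProcessor and peeks at a couple of preprocessed streaming examples.
# The checkpoint name and the empty config dict are illustrative assumptions,
# not values taken from this repo.
if __name__ == "__main__":
    from transformers import WhisperProcessor

    logging.basicConfig(level=logging.INFO)

    # Assumed checkpoint; any Whisper checkpoint with an 80-bin feature extractor works the same way.
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    loader = WaxalDataLoader(subset="bam", config={})  # config contents are repo-specific

    for i, example in enumerate(loader.iter_processed(processor, split="train")):
        # Each example is a dict of numpy arrays ready for a Whisper data collator.
        print(example["input_features"].shape, example["labels"].shape)
        if i >= 1:  # only inspect the first two examples
            break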