Spaces:
Running
Running
File size: 4,653 Bytes
"""
Loads and preprocesses the google/fleurs dataset for Bambara (bam) and Fula (ful).
Uses streaming to avoid downloading the full corpus before training.
google/waxal was removed from the Hub; google/fleurs is the maintained replacement.
Subset mapping: bam → bam_ML (Bambara Mali), ful → ff_SN (Fula/Pular Senegal).
Column names (audio, transcription) are identical.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Callable, Iterator
import numpy as np
import torch
import torchaudio
from datasets import load_dataset
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import WhisperProcessor
from src.data.augmentation import FieldNoiseAugmenter
logger = logging.getLogger(__name__)
# google/fleurs column names (identical to the former google/waxal schema)
AUDIO_COL = "audio"
TEXT_COL = "transcription"
TARGET_SR = 16_000
# Map our short language codes to google/fleurs subset names
_FLEURS_SUBSET = {
"bam": "bam_ML", # Bambara — Mali
"ful": "ff_SN", # Fula/Pular — Senegal
}
class WaxalDataLoader:
"""Streams the google/fleurs dataset and prepares examples for Whisper training."""
def __init__(
self,
subset: str,
config: dict,
hf_token: str | None = None,
) -> None:
if subset not in _FLEURS_SUBSET:
raise ValueError(f"subset must be 'bam' or 'ful', got '{subset}'")
self.subset = subset
self._fleurs_subset = _FLEURS_SUBSET[subset]
self.config = config
self.hf_token = hf_token
def load_split(self, split: str = "train", streaming: bool = True) -> "IterableDataset | Dataset":
"""Return a single split of google/fleurs."""
logger.info(
"Loading google/fleurs subset=%s (%s) split=%s streaming=%s",
self._fleurs_subset, self.subset, split, streaming,
)
ds = load_dataset(
"google/fleurs",
self._fleurs_subset,
split=split,
token=self.hf_token,
streaming=streaming,
trust_remote_code=True,
)
if streaming:
ds = ds.shuffle(seed=42, buffer_size=1000)
return ds
def get_splits(self, streaming: bool = True) -> dict[str, "IterableDataset | Dataset"]:
"""Return train / validation / test splits."""
splits = {}
for split in ("train", "validation", "test"):
try:
splits[split] = self.load_split(split, streaming=streaming)
except Exception:
logger.warning("Split '%s' not available for subset '%s'", split, self.subset)
return splits
def make_preprocess_fn(
self,
processor: "WhisperProcessor",
augmenter: "FieldNoiseAugmenter | None" = None,
) -> Callable[[dict], dict]:
"""Return a function that converts a raw Waxal example into model inputs."""
def preprocess(example: dict) -> dict:
# Extract and resample audio
audio_array = np.array(example[AUDIO_COL]["array"], dtype=np.float32)
orig_sr: int = example[AUDIO_COL]["sampling_rate"]
if orig_sr != TARGET_SR:
tensor = torch.from_numpy(audio_array).unsqueeze(0)
tensor = torchaudio.functional.resample(tensor, orig_sr, TARGET_SR)
audio_array = tensor.squeeze(0).numpy()
# Apply field noise augmentation if provided
if augmenter is not None and augmenter.is_ready():
audio_array = augmenter.augment(audio_array, TARGET_SR)
# Extract log-mel features
inputs = processor.feature_extractor(
audio_array,
sampling_rate=TARGET_SR,
return_tensors="np",
)
input_features = inputs.input_features[0] # shape (80, 3000)
# Tokenize transcript
text: str = example[TEXT_COL]
labels = processor.tokenizer(text, return_tensors="np").input_ids[0]
return {
"input_features": input_features,
"labels": labels,
}
return preprocess
def iter_processed(
self,
processor: "WhisperProcessor",
split: str = "train",
augmenter: "FieldNoiseAugmenter | None" = None,
) -> Iterator[dict]:
"""Yield preprocessed examples one at a time (streaming)."""
ds = self.load_split(split, streaming=True)
fn = self.make_preprocess_fn(processor, augmenter)
for example in ds:
yield fn(example)
|