File size: 4,653 Bytes
76db545
d2183cd
76db545
d2183cd
 
 
 
76db545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2183cd
76db545
 
 
 
d2183cd
 
 
 
 
 
76db545
 
d2183cd
76db545
 
 
 
 
 
 
d2183cd
76db545
 
d2183cd
76db545
 
 
 
d2183cd
 
 
 
 
76db545
d2183cd
 
76db545
 
 
71bb3bc
76db545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
Loads and preprocesses the google/fleurs dataset for Bambara (bam) and Fula (ful).
Uses streaming to avoid downloading the full corpus before training.

google/waxal was removed from the Hub; google/fleurs is the maintained replacement.
Subset mapping: bam → bam_ML (Bambara Mali), ful → ff_SN (Fula/Pular Senegal).
Column names (audio, transcription) are identical.
"""
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Callable, Iterator

import numpy as np
import torch
import torchaudio
from datasets import load_dataset

if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import WhisperProcessor

    from src.data.augmentation import FieldNoiseAugmenter

logger = logging.getLogger(__name__)

# google/fleurs column names (identical to the former google/waxal schema)
AUDIO_COL: str = "audio"
TEXT_COL: str = "transcription"
# Sample rate all audio is resampled to before feature extraction.
TARGET_SR: int = 16_000

# Map our short language codes to google/fleurs subset names
_FLEURS_SUBSET: dict[str, str] = {
    "bam": "bam_ML",  # Bambara — Mali
    "ful": "ff_SN",   # Fula/Pular — Senegal
}


class WaxalDataLoader:
    """Streams the google/fleurs dataset and prepares examples for Whisper training.

    The class keeps its historical name for backward compatibility with callers
    written against the former google/waxal dataset; google/fleurs exposes the
    same (audio, transcription) schema.
    """

    def __init__(
        self,
        subset: str,
        config: dict,
        hf_token: str | None = None,
    ) -> None:
        """Initialize the loader for one language subset.

        Args:
            subset: Short language code — "bam" (Bambara) or "ful" (Fula).
            config: Project configuration mapping (stored as-is; not validated here).
            hf_token: Optional Hugging Face token forwarded to ``load_dataset``.

        Raises:
            ValueError: If ``subset`` is not one of the supported codes.
        """
        if subset not in _FLEURS_SUBSET:
            raise ValueError(f"subset must be 'bam' or 'ful', got '{subset}'")
        self.subset = subset
        self._fleurs_subset = _FLEURS_SUBSET[subset]
        self.config = config
        self.hf_token = hf_token

    def load_split(self, split: str = "train", streaming: bool = True) -> "IterableDataset | Dataset":
        """Return a single split of google/fleurs.

        When streaming, the split is shuffled with a fixed seed and a
        1000-example buffer so iteration order is reproducible across runs.
        """
        logger.info(
            "Loading google/fleurs subset=%s (%s) split=%s streaming=%s",
            self._fleurs_subset, self.subset, split, streaming,
        )
        ds = load_dataset(
            "google/fleurs",
            self._fleurs_subset,
            split=split,
            token=self.hf_token,
            streaming=streaming,
            trust_remote_code=True,
        )
        if streaming:
            ds = ds.shuffle(seed=42, buffer_size=1000)
        return ds

    def get_splits(self, streaming: bool = True) -> dict[str, "IterableDataset | Dataset"]:
        """Return whichever of the train / validation / test splits exist.

        A missing split is skipped with a warning instead of raising, so a
        subset that lacks e.g. a test split remains usable for training.
        """
        splits: dict[str, "IterableDataset | Dataset"] = {}
        for split in ("train", "validation", "test"):
            try:
                splits[split] = self.load_split(split, streaming=streaming)
            except Exception:
                # Best-effort by design, but keep the traceback (exc_info) so
                # real failures (auth, network) are diagnosable from the logs.
                logger.warning(
                    "Split '%s' not available for subset '%s'",
                    split, self.subset, exc_info=True,
                )
        return splits

    def make_preprocess_fn(
        self,
        processor: "WhisperProcessor",
        augmenter: "FieldNoiseAugmenter | None" = None,
    ) -> Callable[[dict], dict]:
        """Return a function that converts a raw fleurs example into model inputs.

        The returned callable expects a mapping with ``audio`` (array +
        sampling_rate) and ``transcription`` keys and yields a dict with
        ``input_features`` (log-mel) and ``labels`` (token ids).
        """

        def preprocess(example: dict) -> dict:
            # asarray avoids a redundant copy when the dataset already
            # yields a float32 ndarray (output is identical to np.array).
            audio_array = np.asarray(example[AUDIO_COL]["array"], dtype=np.float32)
            orig_sr: int = example[AUDIO_COL]["sampling_rate"]

            # Resample to the rate the feature extractor is configured for.
            if orig_sr != TARGET_SR:
                tensor = torch.from_numpy(audio_array).unsqueeze(0)
                tensor = torchaudio.functional.resample(tensor, orig_sr, TARGET_SR)
                audio_array = tensor.squeeze(0).numpy()

            # Apply field noise augmentation if provided and initialized.
            if augmenter is not None and augmenter.is_ready():
                audio_array = augmenter.augment(audio_array, TARGET_SR)

            # Extract log-mel features.
            inputs = processor.feature_extractor(
                audio_array,
                sampling_rate=TARGET_SR,
                return_tensors="np",
            )
            # (n_mels, n_frames) — e.g. (80, 3000); exact mel count depends
            # on the processor's checkpoint config.
            input_features = inputs.input_features[0]

            # Tokenize transcript into label ids.
            text: str = example[TEXT_COL]
            labels = processor.tokenizer(text, return_tensors="np").input_ids[0]

            return {
                "input_features": input_features,
                "labels": labels,
            }

        return preprocess

    def iter_processed(
        self,
        processor: "WhisperProcessor",
        split: str = "train",
        augmenter: "FieldNoiseAugmenter | None" = None,
    ) -> Iterator[dict]:
        """Yield preprocessed examples one at a time (streaming)."""
        ds = self.load_split(split, streaming=True)
        fn = self.make_preprocess_fn(processor, augmenter)
        for example in ds:
            yield fn(example)