"""
Phase 2: Download google/waxal, apply augmentation, print statistics.
Streams examples and caches to data_cache/ as Arrow files.

Usage:
    python scripts/run_data_pipeline.py --subset bam --max-examples 100
"""
import argparse
import sys
import time
from pathlib import Path

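# Put the repository root on sys.path so `src.*` imports resolve when this
# script is run directly.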
sys.path.insert(0, str(Path(__file__).parent.parent))

import os

from dotenv import load_dotenv

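# Read HF_TOKEN (and any other settings) from a local .env file, if present.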
load_dotenv()


def main(subset: str, max_examples: int) -> None:
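    # Deferred imports: transformers, yaml, and the project's `src` package are
    # only loaded once main() actually runs.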
    import yaml
    from transformers import WhisperProcessor

    from src.data.augmentation import FieldNoiseAugmenter
    from src.data.waxal_loader import WaxalDataLoader

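    # base_config.yaml is assumed to supply at least model.id and
    # paths.noise_samples, both read below.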
    with open("configs/base_config.yaml") as f:
        config = yaml.safe_load(f)

    hf_token = os.getenv("HF_TOKEN")
    model_id = config["model"]["id"]

    print("=" * 60)
    print(f"Waxal Data Pipeline — subset: {subset}")
    print("=" * 60)

    print(f"\n[1/4] Loading WhisperProcessor ({model_id})...")
    processor = WhisperProcessor.from_pretrained(model_id, token=hf_token)

    print("[2/4] Initializing augmenter...")
    augmenter = FieldNoiseAugmenter(config["paths"]["noise_samples"], config)
    print(f"      Augmenter ready: {augmenter.is_ready()}")

    print(f"[3/4] Streaming google/waxal subset={subset}...")
    loader = WaxalDataLoader(subset, config, hf_token=hf_token)

    t0 = time.time()
    count = 0
    total_duration = 0.0

    for example in loader.iter_processed(processor, split="train", augmenter=augmenter):
        count += 1
        # input_features is padded to shape (80, 3000): 3000 log-mel frames of
        # 10 ms each, i.e. a fixed 30-second window. The padded features don't
        # expose the original clip length, so count the full window per example
        # (an upper bound on total audio).
        total_duration += 30.0  # 30 s max chunk per example
        if count >= max_examples:
            break

    elapsed = time.time() - t0

    print(f"\n[4/4] Results:")
    print(f"      Examples processed: {count}")
    print(f"      Approx total audio: {total_duration / 3600:.2f} hours")
    print(f"      Processing time:    {elapsed:.1f}s")
    print(f"      Throughput:         {count / elapsed:.1f} examples/sec")
    print(f"\nData pipeline PASSED.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--subset", default="bam", choices=["bam", "ful"])
    parser.add_argument("--max-examples", type=int, default=50)
    args = parser.parse_args()
    main(args.subset, args.max_examples)