"""Phase 2: Download google/waxal, apply augmentation, print statistics.

Streams examples and caches to data_cache/ as Arrow files.

Usage:
    python scripts/run_data_pipeline.py --subset bam --max-examples 100
"""
import argparse
import itertools
import os
import sys
import time
from pathlib import Path

# Put the directory two levels above this file on sys.path so that the
# ``src`` package imported inside main() resolves when the script is run
# directly (e.g. ``python scripts/run_data_pipeline.py``).
sys.path.insert(0, str(Path(__file__).parent.parent))

from dotenv import load_dotenv  # noqa: E402  (kept after the sys.path tweak)

# Load variables from a local .env file, if present, so HF_TOKEN can be
# supplied without exporting it in the shell.
load_dotenv()

# Every processed example is counted as one full chunk of this length.
# The original note says input_features has shape (80, 3000), i.e. at most
# 30 seconds of audio — so the reported total is an upper bound.
# TODO: estimate the real duration from the non-padding frames.
_MAX_CHUNK_SECONDS = 30.0


def main(subset: str, max_examples: int) -> None:
    """Stream up to ``max_examples`` processed examples and print statistics.

    Reads ``configs/base_config.yaml`` (relative to the current working
    directory), builds the Whisper processor named by ``config["model"]["id"]``,
    the field-noise augmenter rooted at ``config["paths"]["noise_samples"]``
    and the Waxal loader for ``subset``, then iterates the loader's processed
    ``train`` split and reports the example count, an approximate audio
    duration, the wall-clock time and the throughput.

    Args:
        subset: Waxal subset name passed straight to ``WaxalDataLoader``.
        max_examples: Maximum number of examples to pull from the stream.
            Zero or a negative value pulls nothing.
    """
    # Heavy / project-local imports stay inside main() so they run only after
    # the sys.path adjustment at module level has taken effect.
    import yaml
    from transformers import WhisperProcessor

    from src.data.augmentation import FieldNoiseAugmenter
    from src.data.waxal_loader import WaxalDataLoader

    # Explicit encoding: the locale default is not UTF-8 on every platform.
    with open("configs/base_config.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    hf_token = os.getenv("HF_TOKEN")  # None when the variable is not set
    model_id = config["model"]["id"]

    print("=" * 60)
    print(f"Waxal Data Pipeline — subset: {subset}")
    print("=" * 60)

    print(f"\n[1/4] Loading WhisperProcessor ({model_id})...")
    processor = WhisperProcessor.from_pretrained(model_id, token=hf_token)

    print("[2/4] Initializing augmenter...")
    augmenter = FieldNoiseAugmenter(config["paths"]["noise_samples"], config)
    print(f" Augmenter ready: {augmenter.is_ready()}")

    print(f"[3/4] Streaming google/waxal subset={subset}...")
    loader = WaxalDataLoader(subset, config, hf_token=hf_token)

    # Clamp so that a zero/negative limit consumes nothing; islice() rejects
    # negative stops, and the previous post-increment check always processed
    # at least one example.
    limit = max(max_examples, 0)

    # perf_counter() is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    examples = loader.iter_processed(processor, split="train", augmenter=augmenter)
    # islice() stops after `limit` items without pulling an extra one.
    count = sum(1 for _ in itertools.islice(examples, limit))
    elapsed = time.perf_counter() - t0

    total_duration = count * _MAX_CHUNK_SECONDS

    # Guard the division: an empty or tiny run can finish within the clock's
    # resolution and would otherwise raise ZeroDivisionError.
    if elapsed > 0:
        throughput = f"{count / elapsed:.1f} examples/sec"
    else:
        throughput = "n/a (elapsed time too small to measure)"

    print("\n[4/4] Results:")
    print(f" Examples processed: {count}")
    print(f" Approx total audio: {total_duration / 3600:.2f} hours")
    print(f" Processing time: {elapsed:.1f}s")
    print(f" Throughput: {throughput}")
    # NOTE(review): this is printed even when zero examples were streamed;
    # confirm whether an empty stream should count as a pass.
    print("\nData pipeline PASSED.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--subset", default="bam", choices=["bam", "ful"])
    parser.add_argument("--max-examples", type=int, default=50)
    args = parser.parse_args()
    main(args.subset, args.max_examples)