File size: 2,838 Bytes
a4d9876
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import io
import json
import random
import shutil
from pathlib import Path

import soundfile as sf
from datasets import Audio, load_dataset
from tqdm import tqdm

# Dataset ID used for the --repo command-line option when none is given.
DEFAULT_REPO = "saleh1312/syncing_data"
# Upper bound on clip length, in seconds (frames / sample rate); clips longer
# than this are skipped during preparation.
MAX_DURATION = 10.0

def _normalize_tone(raw_tone):
    """Return a trimmed, lower-cased tone label, or "neutral" if missing/blank."""
    if raw_tone is None:
        # A present-but-null column value must not become the string "none".
        return "neutral"
    tone = str(raw_tone).strip().lower()
    return tone or "neutral"


def _write_jsonl(path, records):
    """Write *records* to *path* as UTF-8 JSON Lines (one object per line)."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps(rec, ensure_ascii=False) + "\n" for rec in records
        )


def main():
    """Prepare a dataset's audio clips and metadata for training.

    Command-line options:
        --repo  Dataset ID passed to ``load_dataset`` (default: DEFAULT_REPO).
        --out   Output directory (default: "sync_data"); everything is
                written under ``<out>/data``.
        --seed  Seed for the shuffle that precedes the train/dev split.

    Side effects:
        * Deletes ``<out>/data`` if it already exists, then recreates it.
        * Writes each kept clip as a mono 16-bit PCM WAV file to
          ``<out>/data/wavs/sample_<index>.wav``.
        * Writes ``train_raw.jsonl`` (first 95% of the shuffled records) and
          ``dev_raw.jsonl`` (the remainder) to ``<out>/data``.

    Rows without audio bytes and clips longer than MAX_DURATION seconds are
    skipped; both counts are reported at the end.
    """
    parser = argparse.ArgumentParser(description="Prepare data for OmniVoice Training")
    parser.add_argument("--repo", default=DEFAULT_REPO, help="HF Dataset ID")
    parser.add_argument("--out", default="sync_data", help="Output directory")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    out_root = Path(args.out).resolve() / "data"
    if out_root.exists():
        # Start from a clean slate so stale WAVs never mix with new ones.
        print(f"Cleaning up old directory: {out_root}")
        shutil.rmtree(out_root)

    wav_dir = out_root / "wavs"
    wav_dir.mkdir(parents=True)

    print(f"Loading dataset: {args.repo}")
    ds = load_dataset(args.repo, split="train")
    # Keep the raw encoded bytes; decoding is done below with soundfile.
    ds = ds.cast_column("audio", Audio(decode=False))

    processed_records = []
    skipped = 0  # clips longer than MAX_DURATION
    missing_audio = 0  # rows with no audio payload

    print("Processing audio files...")
    for i, row in enumerate(tqdm(ds)):
        # Tolerate a null audio cell as well as an empty/None "bytes" field.
        audio_data = (row["audio"] or {}).get("bytes")
        if not audio_data:
            missing_audio += 1
            continue

        # Load audio to check duration
        with io.BytesIO(audio_data) as f:
            data, sr = sf.read(f)

        # len(data) is the frame count, so this is the duration in seconds.
        duration = len(data) / sr
        if duration > MAX_DURATION:
            skipped += 1
            continue

        # Ensure mono by averaging the channels.
        if data.ndim > 1:
            data = data.mean(axis=1)

        # The ID uses the dataset row index, so IDs stay stable across skips.
        sample_id = f"sample_{i:06d}"
        wav_path = wav_dir / f"{sample_id}.wav"
        sf.write(wav_path, data, sr, subtype="PCM_16")

        tone = _normalize_tone(row.get("tone"))
        processed_records.append({
            "id": sample_id,
            "audio_path": str(wav_path.resolve()),
            "text": row["text"],
            "language_id": "ar",
            "instruct": f"saudi, conversational, {tone}",
        })

    # A dedicated generator yields the same permutation as seeding the
    # module-level RNG, without mutating global random state.
    random.Random(args.seed).shuffle(processed_records)

    split_idx = int(len(processed_records) * 0.95)
    train_path = out_root / "train_raw.jsonl"
    dev_path = out_root / "dev_raw.jsonl"

    # Write input files for the tokenization script
    for out_path, subset in [
        (train_path, processed_records[:split_idx]),
        (dev_path, processed_records[split_idx:]),
    ]:
        _write_jsonl(out_path, subset)
        print(f"Created {out_path} ({len(subset)} samples)")

    print("\nPreparation Complete!")
    print(f"Skipped {skipped} samples (> {MAX_DURATION}s)")
    if missing_audio:
        print(f"Skipped {missing_audio} samples (no audio data)")
    # Report the real location of the training manifest rather than a
    # hard-coded path that ignores --out and the "data" subdirectory.
    print(f"Next: Run the 'extract_audio_tokens.py' script using '{train_path}'")

# Run the preparation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()