File size: 1,164 Bytes
eca55dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import sys
import os

# Make the project's `src/` directory importable. The path is resolved against
# the current working directory, so run this script from the directory that
# contains `src/`. Prepend (rather than append) so the local packages take
# precedence over any same-named module found earlier on the import path, and
# skip the insert if the entry is already present so re-execution does not
# pile up duplicates.
_SRC_DIR = os.path.abspath("src")
if _SRC_DIR not in sys.path:
    sys.path.insert(0, _SRC_DIR)

from data.audioset_datamodule import AudioSetDataModule


def verify_data(
    *,
    data_dir="data/AudioSet",
    batch_size=4,
    num_workers=0,
    target_sample_rate=32000,
    expected_num_samples=320000,
):
    """Smoke-test the AudioSet data pipeline by loading one training batch.

    Builds an ``AudioSetDataModule``, calls ``setup()``, prints the train and
    validation dataset sizes, then pulls the first batch from
    ``train_dataloader()``, prints its contents and checks the shape of its
    ``"waveform"`` entry.

    Args:
        data_dir: Dataset root passed to the DataModule. A relative path is
            resolved against the current working directory.
        batch_size: Batch size passed to the DataModule.
        num_workers: Worker count passed to the DataModule; 0 is the
            debugging-friendly setting.
        target_sample_rate: Sample rate passed to the DataModule.
        expected_num_samples: Required length of the waveform's last (time)
            axis. The default 320000 equals 10 s at the default 32000 Hz.

    Raises:
        RuntimeError: If the training dataloader yields no batches.
        AssertionError: If the waveform is not shaped
            ``[B, 1, expected_num_samples]``.
    """
    print("Initializing DataModule...")
    dm = AudioSetDataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        target_sample_rate=target_sample_rate,
    )
    dm.setup()

    print(f"Train dataset size: {len(dm.train_dataset)}")
    print(f"Val dataset size: {len(dm.val_dataset)}")

    print("Fetching a batch...")
    loader = dm.train_dataloader()
    # A bare next() on an empty loader raises StopIteration with no message;
    # report the real problem instead.
    batch = next(iter(loader), None)
    if batch is None:
        raise RuntimeError("Train dataloader yielded no batches")

    waveform = batch["waveform"]
    target = batch["target"]
    audio_name = batch["audio_name"]
    index = batch["index"]

    print(f"Waveform shape: {waveform.shape}")
    print(f"Target shape: {target.shape}")
    print(f"Audio names: {audio_name}")
    print(f"Indices: {index}")

    # Explicit checks rather than `assert` statements: asserts are stripped
    # under `python -O`, which would let a broken pipeline report success.
    shape = tuple(waveform.shape)
    if waveform.ndim != 3:
        raise AssertionError(f"Waveform should be [B, C, T], got shape {shape}")
    if waveform.shape[1] != 1:
        raise AssertionError(f"Channel dim should be 1, got shape {shape}")
    if waveform.shape[2] != expected_num_samples:
        raise AssertionError(
            f"Time dim should be {expected_num_samples}, got shape {shape}"
        )

    print("Verification successful!")


if __name__ == "__main__":
    # Script entry point: run the one-batch data check only when executed
    # directly, never as a side effect of importing this module.
    verify_data()