File size: 1,164 Bytes
eca55dc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | import sys
import os
# Make modules under ./src importable (e.g. the `data` package imported below).
# The path is resolved against the current working directory, so run this
# script from the repository root.
_SRC_DIR = os.path.abspath("src")
sys.path.append(_SRC_DIR)
from data.audioset_datamodule import AudioSetDataModule
def verify_data(
    data_dir="data/AudioSet",
    batch_size=4,
    num_workers=0,
    target_sample_rate=32000,
    clip_seconds=10,
):
    """Smoke-test the AudioSet data pipeline by loading and checking one batch.

    Builds an ``AudioSetDataModule``, calls ``setup()``, prints the train and
    validation dataset sizes, pulls the first batch from the training
    dataloader, prints its fields, and checks the waveform tensor's shape.

    Args:
        data_dir: Dataset root passed to ``AudioSetDataModule``. A relative
            path resolves against the current working directory.
        batch_size: Batch size passed to the data module.
        num_workers: Dataloader worker count passed to the data module.
            Use 0 for debugging.
        target_sample_rate: Target sample rate (Hz) passed to the data module.
        clip_seconds: Expected clip length in seconds. The waveform's time
            dimension must equal ``target_sample_rate * clip_seconds``
            samples (320000 with the defaults).

    Raises:
        RuntimeError: If the training dataloader yields no batches.
        AssertionError: If the waveform is not 3-D ``[B, C, T]``, has a
            channel dimension other than 1, or has the wrong number of samples.
    """
    print("Initializing DataModule...")
    dm = AudioSetDataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        num_workers=num_workers,  # Use 0 for debugging
        target_sample_rate=target_sample_rate,
    )
    dm.setup()

    print(f"Train dataset size: {len(dm.train_dataset)}")
    print(f"Val dataset size: {len(dm.val_dataset)}")

    print("Fetching a batch...")
    loader = dm.train_dataloader()
    batch = next(iter(loader), None)
    if batch is None:
        # A bare next() would raise StopIteration here, which hides the real
        # cause (an empty training split) behind an unhelpful traceback.
        raise RuntimeError(
            "Training dataloader produced no batches; check data_dir."
        )

    waveform = batch["waveform"]
    target = batch["target"]
    audio_name = batch["audio_name"]
    index = batch["index"]

    print(f"Waveform shape: {waveform.shape}")
    print(f"Target shape: {target.shape}")
    print(f"Audio names: {audio_name}")
    print(f"Indices: {index}")

    # Explicit raises instead of `assert` statements: asserts are removed
    # under `python -O`, which would make this verification pass vacuously.
    expected_samples = round(target_sample_rate * clip_seconds)
    if waveform.ndim != 3:
        raise AssertionError("Waveform should be [B, C, T]")
    if waveform.shape[1] != 1:
        raise AssertionError("Channel dim should be 1")
    if waveform.shape[2] != expected_samples:
        raise AssertionError(f"Time dim should be {expected_samples}")
    print("Verification successful!")
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, never on import.
    verify_data()
|