| import torch |
| import numpy as np |
| from src.data.audioset_datamodule import AudioSetDataset |
|
|
|
|
| |
class MockAudioSetDataset(AudioSetDataset):
    """In-memory stand-in for ``AudioSetDataset`` that synthesizes waveforms.

    Each item is a ramp ``0, 1, ..., length - 1`` (float32), so after a crop
    the first sample value equals the crop's start offset. The parent
    constructor is intentionally not called; only the attributes below are
    set (presumably to avoid touching any real HDF5 file -- confirm against
    the parent class).
    """

    def __init__(self, lengths, max_length=None):
        """Store per-item waveform lengths and the optional crop limit.

        Args:
            lengths: Sequence of waveform lengths, one per dataset item.
            max_length: Maximum number of samples returned per item, or
                ``None`` to disable cropping.
        """
        self.lengths = lengths
        self.max_length = max_length
        self.transform = None
        self.h5_file = None
        # Every synthetic item is considered valid.
        self.valid_indices = list(range(len(lengths)))

    def _open_h5(self):
        """Do nothing; the mock has no backing file."""
        return None

    def __getitem__(self, idx):
        """Build the synthetic item at ``idx``, randomly cropped if too long.

        Returns:
            dict with keys ``"waveform"`` (tensor of shape ``(1, n)``),
            ``"target"`` (zero tensor of size 527), ``"audio_name"`` and
            ``"index"``.
        """
        samples = np.arange(self.lengths[idx], dtype=np.float32)

        limit = self.max_length
        if limit is not None and len(samples) > limit:
            # Uniform start offset in [0, len - limit], inclusive.
            offset = np.random.randint(0, len(samples) - limit + 1)
            samples = samples[offset : offset + limit]

        return {
            # Leading axis of size 1 is added in front of the sample axis.
            "waveform": torch.from_numpy(samples).unsqueeze(0),
            # NOTE(review): 527 presumably matches the AudioSet label count.
            "target": torch.zeros(527),
            "audio_name": f"audio_{idx}",
            "index": idx,
        }
|
|
|
|
def test_random_cropping():
    """Check that ``MockAudioSetDataset`` crops long waveforms to ``max_length``.

    For each configured length the item is drawn several times. Every draw
    must have exactly ``min(original_length, max_len)`` samples. The first
    sample value of each draw is recorded as the crop start (the mock's
    waveform is a ramp, so that value is the offset).

    All problems are printed as they are found and collected; after the
    summary line the function raises so that pytest (or a script run)
    actually reports a failure instead of silently passing.

    Raises:
        AssertionError: If any draw has an unexpected length.
    """
    max_len = 100
    lengths = [50, 100, 150, 200]
    draws_per_item = 5

    dataset = MockAudioSetDataset(lengths, max_length=max_len)

    print(f"Testing with max_length={max_len}")

    failures = []

    def _fail(message):
        # Keep the human-readable log and remember the failure for the end.
        print(message)
        failures.append(message)

    for i, orig_len in enumerate(lengths):
        starts = []
        for _ in range(draws_per_item):
            wave = dataset[i]["waveform"]
            got_len = wave.shape[-1]

            if got_len > max_len:
                _fail(
                    f"FAIL: Index {i} (orig {orig_len}) has length {got_len} > {max_len}"
                )

            # Exact-length check on every draw, not only the last one.
            if orig_len > max_len:
                if got_len != max_len:
                    _fail(
                        f"FAIL: Index {i} should be cropped to {max_len}, got {got_len}"
                    )
            elif got_len != orig_len:
                _fail(f"FAIL: Index {i} should be {orig_len}, got {got_len}")

            starts.append(wave[0, 0].item())

        print(f"Index {i} (orig {orig_len}): Starts = {starts}")

        # Identical starts can happen by chance, so this stays a warning
        # rather than a failure to avoid a flaky test.
        if orig_len > max_len + 5 and len(set(starts)) == 1:
            print(f"WARNING: Index {i} might not be random? Starts: {starts}")

    print("Test finished.")

    if failures:
        # Explicit raise (not `assert`) so the check survives `python -O`.
        raise AssertionError("\n".join(failures))
|
|
|
|
# Allow running this file directly as a script, outside of a test runner.
if __name__ == "__main__":
    test_random_cropping()
|
|