| import torch |
| import numpy as np |
| import torchaudio |
| from src.data.audioset_datamodule import AudioSetDataset |
|
|
|
|
| |
| class MockAudioSetDatasetResample(AudioSetDataset): |
| def __init__(self, source_sr, target_sr, max_length=None): |
| self.source_sample_rate = source_sr |
| self.target_sample_rate = target_sr |
| self.max_length = max_length |
| self.transform = None |
| self.valid_indices = [0] |
| self.h5_file = None |
|
|
| def _open_h5(self): |
| pass |
|
|
| def __getitem__(self, idx): |
| |
| duration = 10 |
| t = np.linspace(0, duration, int(self.source_sample_rate * duration)) |
| waveform = np.sin(2 * np.pi * 440 * t).astype(np.float32) |
|
|
| |
| waveform = torch.from_numpy(waveform) |
|
|
| |
| |
| if self.source_sample_rate != self.target_sample_rate: |
| if self.max_length is not None: |
| crop_len_source = ( |
| int( |
| self.max_length |
| * self.source_sample_rate |
| / self.target_sample_rate |
| ) |
| + 100 |
| ) |
| if len(waveform) > crop_len_source: |
| max_start = len(waveform) - crop_len_source |
| start = np.random.randint(0, max_start + 1) |
| waveform = waveform[start : start + crop_len_source] |
|
|
| resampler = torchaudio.transforms.Resample( |
| self.source_sample_rate, self.target_sample_rate |
| ) |
| waveform = resampler(waveform.unsqueeze(0)).squeeze(0) |
|
|
| if self.max_length is not None and len(waveform) > self.max_length: |
| waveform = waveform[: self.max_length] |
| else: |
| if self.max_length is not None and len(waveform) > self.max_length: |
| max_start = len(waveform) - self.max_length |
| start = np.random.randint(0, max_start + 1) |
| waveform = waveform[start : start + self.max_length] |
| |
|
|
| return waveform |
|
|
|
|
| def test_resampling(): |
| source_sr = 32000 |
| target_sr = 16000 |
| max_len_target = 160000 |
|
|
| dataset = MockAudioSetDatasetResample( |
| source_sr, target_sr, max_length=max_len_target |
| ) |
|
|
| print( |
| f"Testing resampling from {source_sr} to {target_sr} with max_len {max_len_target}" |
| ) |
|
|
| waveform = dataset[0] |
| print(f"Output shape: {waveform.shape}") |
|
|
| if waveform.shape[0] == max_len_target: |
| print("PASS: Output length matches max_length") |
| else: |
| print(f"FAIL: Output length {waveform.shape[0]} != {max_len_target}") |
|
|
| |
| dataset_no_max = MockAudioSetDatasetResample(source_sr, target_sr, max_length=None) |
| waveform_full = dataset_no_max[0] |
| expected_len = 160000 |
| print(f"Output shape (no max): {waveform_full.shape}") |
| if abs(waveform_full.shape[0] - expected_len) < 100: |
| print("PASS: Output length matches expected resampled length") |
| else: |
| print(f"FAIL: Output length {waveform_full.shape[0]} != {expected_len}") |
|
|
|
|
| if __name__ == "__main__": |
| test_resampling() |
|
|