File size: 3,400 Bytes
eca55dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import numpy as np
import torchaudio
from src.data.audioset_datamodule import AudioSetDataset


# Mock Dataset to test resampling logic
class MockAudioSetDatasetResample(AudioSetDataset):
    """Test double for ``AudioSetDataset`` that synthesizes its audio.

    Every item is a 10-second, 440 Hz sine wave generated at
    ``source_sample_rate`` and then pushed through a copy of the
    resample-and-crop logic, so that logic can be exercised without a
    data file on disk.

    NOTE(review): the resample/crop section in ``__getitem__`` is a manual
    copy of the parent's logic (per the marker comments); it must be kept in
    sync by hand and does not exercise the production code path itself.
    """

    def __init__(self, source_sr, target_sr, max_length=None):
        """Set only the attributes the mocked ``__getitem__`` reads.

        ``super().__init__()`` is deliberately not called; presumably the
        parent constructor needs a real data file — confirm against
        ``AudioSetDataset``.

        Args:
            source_sr: Sample rate (Hz) at which the synthetic audio is generated.
            target_sr: Sample rate (Hz) the returned waveform is resampled to.
            max_length: Maximum number of output samples (at ``target_sr``),
                or ``None`` to disable cropping.
        """
        self.source_sample_rate = source_sr
        self.target_sample_rate = target_sr
        self.max_length = max_length
        self.transform = None
        self.valid_indices = [0]
        self.h5_file = None

    def _open_h5(self):
        """No-op override: the mock never opens a file."""
        pass

    def __getitem__(self, idx):
        """Return a synthetic waveform after resampling and cropping.

        Args:
            idx: Ignored; every index yields the same kind of signal.

        Returns:
            1-D ``torch.Tensor`` (float32) at ``target_sample_rate``, holding
            at most ``max_length`` samples when ``max_length`` is set.
        """
        # Create a 10s sine wave at source SR
        duration = 10
        num_samples = int(self.source_sample_rate * duration)
        # endpoint=False keeps the sample spacing at exactly 1/source_sample_rate;
        # the default (endpoint=True) would stretch it to duration/(num_samples - 1)
        # and shift the tone slightly away from 440 Hz.
        t = np.linspace(0, duration, num_samples, endpoint=False)
        waveform = np.sin(2 * np.pi * 440 * t).astype(np.float32)

        # Mock loading
        waveform = torch.from_numpy(waveform)

        # --- Copy-paste logic from AudioSetDataset.__getitem__ ---
        # Resampling and Cropping Logic
        if self.source_sample_rate != self.target_sample_rate:
            if self.max_length is not None:
                # Number of source-rate samples that maps to max_length
                # target-rate samples, plus a 100-sample margin (presumably so
                # the resampled result is never shorter than max_length —
                # confirm against the production code).
                crop_len_source = (
                    int(
                        self.max_length
                        * self.source_sample_rate
                        / self.target_sample_rate
                    )
                    + 100
                )
                if len(waveform) > crop_len_source:
                    # Random crop before resampling, so only the needed
                    # portion of the signal is resampled.
                    max_start = len(waveform) - crop_len_source
                    start = np.random.randint(0, max_start + 1)
                    waveform = waveform[start : start + crop_len_source]

            resampler = torchaudio.transforms.Resample(
                self.source_sample_rate, self.target_sample_rate
            )
            # The waveform is given a leading dimension for the resampler call
            # and squeezed back to 1-D afterwards.
            waveform = resampler(waveform.unsqueeze(0)).squeeze(0)

            # Trim whatever the margin left over beyond max_length.
            if self.max_length is not None and len(waveform) > self.max_length:
                waveform = waveform[: self.max_length]
        else:
            # Same rate: a plain random crop down to max_length.
            if self.max_length is not None and len(waveform) > self.max_length:
                max_start = len(waveform) - self.max_length
                start = np.random.randint(0, max_start + 1)
                waveform = waveform[start : start + self.max_length]
        # ---------------------------------------------------------

        return waveform


def test_resampling():
    """Check output lengths of the mock dataset when resampling 32 kHz -> 16 kHz.

    Two cases are covered:

    1. ``max_length`` set to 10 s at the target rate: the output must be
       exactly ``max_length`` samples long.
    2. ``max_length=None``: the output must be within 100 samples of the
       expected resampled length.

    PASS/FAIL lines are still printed for script usage; in addition the
    function now raises so a test runner reports failures instead of
    silently passing.

    Raises:
        AssertionError: If either length check fails. The message lists
            every failed check.
    """
    source_sr = 32000
    target_sr = 16000
    max_len_target = 160000  # 10s @ 16kHz

    # Collect failures so both checks always run and print before asserting.
    failures = []

    dataset = MockAudioSetDatasetResample(
        source_sr, target_sr, max_length=max_len_target
    )

    print(
        f"Testing resampling from {source_sr} to {target_sr} with max_len {max_len_target}"
    )

    waveform = dataset[0]
    print(f"Output shape: {waveform.shape}")

    if waveform.shape[0] == max_len_target:
        print("PASS: Output length matches max_length")
    else:
        message = f"FAIL: Output length {waveform.shape[0]} != {max_len_target}"
        print(message)
        failures.append(message)

    # Test without max_length
    dataset_no_max = MockAudioSetDatasetResample(source_sr, target_sr, max_length=None)
    waveform_full = dataset_no_max[0]
    expected_len = 10 * target_sr  # the mock generates 10s of audio
    print(f"Output shape (no max): {waveform_full.shape}")
    if abs(waveform_full.shape[0] - expected_len) < 100:
        print("PASS: Output length matches expected resampled length")
    else:
        message = f"FAIL: Output length {waveform_full.shape[0]} != {expected_len}"
        print(message)
        failures.append(message)

    assert not failures, "; ".join(failures)


# Allow running this check directly as a script, without a test runner.
if __name__ == "__main__":
    test_resampling()