File size: 2,928 Bytes
eca55dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import numpy as np
from src.data.audioset_datamodule import AudioSetDataset


# Mock Dataset inheriting from AudioSetDataset to test logic without H5
class MockAudioSetDataset(AudioSetDataset):
    """In-memory stand-in for ``AudioSetDataset`` that needs no HDF5 file.

    Every item's waveform is ``np.arange(length, dtype=float32)``, so each
    sample's value equals its position in the uncropped clip. A test can
    therefore read the crop offset straight off the first returned sample.

    ``__init__`` deliberately does not call ``super().__init__``. It only sets
    the attributes assigned below.

    Args:
        lengths: Sample count of each item, indexed by item position.
        max_length: When not ``None``, waveforms longer than this are randomly
            cropped to exactly ``max_length`` samples; shorter or equal-length
            waveforms are returned unchanged.
    """

    def __init__(self, lengths, max_length=None):
        self.max_length = max_length
        self.lengths = lengths
        # Every item position is treated as valid.
        self.valid_indices = [*range(len(lengths))]
        self.transform = None
        # No real file backs this mock; see the no-op _open_h5 override.
        self.h5_file = None

    def _open_h5(self):
        """No-op override: the mock has no HDF5 file to open."""

    def __getitem__(self, idx):
        """Return a dict with a (possibly cropped) synthetic waveform for ``idx``.

        Keys: ``"waveform"`` (float32 tensor of shape ``(1, n)``),
        ``"target"`` (all-zero tensor of length 527 — presumably the AudioSet
        label count; TODO confirm), ``"audio_name"`` and ``"index"``.
        """
        n_samples = self.lengths[idx]
        samples = np.arange(n_samples, dtype=np.float32)

        # Same random-crop rule the original comment attributes to
        # AudioSetDataset: pick a start uniformly in [0, n - limit].
        limit = self.max_length
        if limit is not None and samples.shape[0] > limit:
            offset = np.random.randint(0, samples.shape[0] - limit + 1)
            samples = samples[offset : offset + limit]

        return {
            "waveform": torch.from_numpy(samples).unsqueeze(0),
            "target": torch.zeros(527),
            "audio_name": f"audio_{idx}",
            "index": idx,
        }


def test_random_cropping():
    """Check that long clips are randomly cropped and short ones left intact.

    For each item, several draws are taken. Every returned waveform must:
    - have shape ``(1, min(length, max_len))``;
    - start at an offset in ``[0, length - crop_len]``;
    - hold a contiguous run ``start, start + 1, ...`` of the mock's ``arange``
      samples.

    Each violation is printed and collected, and the test fails with all
    collected messages at the end. Previously, failures were only printed, so
    pytest always reported a pass.

    Identical crop starts across all draws are only reported as a warning.
    The draws are random, so identical starts are possible (if very
    unlikely), and a hard assertion would make the test flaky.
    """
    max_len = 100
    lengths = [50, 100, 150, 200]
    n_draws = 5

    dataset = MockAudioSetDataset(lengths, max_length=max_len)
    failures = []

    def fail(msg):
        # Print immediately (useful when run as a script) and keep for the assert.
        print(f"FAIL: {msg}")
        failures.append(msg)

    print(f"Testing with max_length={max_len}")

    for i, orig_len in enumerate(lengths):
        expected_len = min(orig_len, max_len)
        # Largest valid crop start; 0 means the clip must not be cropped at all.
        max_start = orig_len - expected_len
        starts = []

        for _ in range(n_draws):
            wave = dataset[i]["waveform"]

            # Check every draw, not just the last one.
            if tuple(wave.shape) != (1, expected_len):
                fail(
                    f"Index {i} (orig {orig_len}) has shape {tuple(wave.shape)}, "
                    f"expected (1, {expected_len})"
                )
                continue

            # Sample values equal their original positions, so the first
            # value is the crop start.
            start = int(wave[0, 0].item())
            starts.append(start)

            if not 0 <= start <= max_start:
                fail(
                    f"Index {i} (orig {orig_len}) crop start {start} "
                    f"outside [0, {max_start}]"
                )

            expected = torch.arange(start, start + expected_len, dtype=wave.dtype)
            if not torch.equal(wave[0], expected):
                fail(
                    f"Index {i} (orig {orig_len}) is not a contiguous crop "
                    f"starting at {start}"
                )

        print(f"Index {i} (orig {orig_len}): Starts = {starts}")

        # With more than a handful of possible starts, identical draws
        # suggest the crop is not random.
        if max_start > 5 and len(set(starts)) == 1:
            print(f"WARNING: Index {i} might not be random? Starts: {starts}")

    print("Test finished.")
    assert not failures, "\n".join(failures)


def _run_as_script():
    """Run the cropping check when this file is executed directly."""
    test_random_cropping()


if __name__ == "__main__":
    _run_as_script()