| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from src.models.audio_jepa_module import AudioJEPAModule |
| from src.data.audioset_datamodule import AudioSetDataModule |
|
|
|
|
| |
class MockAudioDataset(Dataset):
    """Synthetic dataset of variable-length mono waveforms for collate tests.

    Each item is a dict with a random ``(1, length)`` waveform, a random
    527-dim target vector (AudioSet class count), a name string, and the
    integer index.
    """

    def __init__(self, lengths):
        # One entry per sample: number of audio samples in that waveform.
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        n_samples = self.lengths[idx]
        sample = {
            "waveform": torch.randn(1, n_samples),
            "target": torch.randn(527),
            "audio_name": f"audio_{idx}",
            "index": idx,
        }
        return sample
|
|
|
|
def test_variable_length():
    """Smoke test: batch variable-length waveforms and run one train/val step.

    Verifies that AudioSetDataModule.collate_fn pads a batch of unequal
    waveforms to the expected common length, then checks that a tiny
    AudioJEPAModule can consume the batch without error.
    """
    sample_lengths = [32000, 48000, 30000, 50000]
    mock_ds = MockAudioDataset(sample_lengths)

    # Bind the padding parameters under test onto the static collate.
    def batch_collate(samples):
        return AudioSetDataModule.collate_fn(samples, hop_length=1250, patch_time_dim=16)

    loader = DataLoader(mock_ds, batch_size=4, collate_fn=batch_collate)

    batch = next(iter(loader))
    waveforms = batch["waveform"]
    print(f"Batch waveforms shape: {waveforms.shape}")

    # NOTE(review): 78750 is the padded length expected for a 50000-sample
    # max with hop_length=1250 and patch_time_dim=16 — derived from the
    # collate implementation; confirm against AudioSetDataModule.collate_fn.
    expected_len = 78750
    if waveforms.shape[-1] == expected_len:
        print("Padding logic verified!")
    else:
        print(
            f"Padding logic mismatch! Expected {expected_len}, got {waveforms.shape[-1]}"
        )

    print("Initializing model...")

    # Tiny model config (2-layer encoder, 1-layer predictor) so the
    # forward/backward smoke test stays fast.
    spectrogram_cfg = {
        "sample_rate": 32000,
        "n_fft": 4096,
        "win_length": 4096,
        "hop_length": 1250,
        "n_mels": 128,
        "f_min": 0.0,
        "f_max": None,
    }
    patch_embed_cfg = {
        "img_size": (128, 256),
        "patch_size": (16, 16),
        "in_chans": 1,
        "embed_dim": 192,
    }
    masking_cfg = {
        "input_size": (128, 256),
        "patch_size": (16, 16),
        "mask_ratio": (0.4, 0.6),
    }
    encoder_cfg = {
        "embed_dim": 192,
        "depth": 2,
        "num_heads": 3,
        "pos_embed_type": "rope",
        "img_size": (128, 256),
        "patch_size": (16, 16),
    }
    predictor_cfg = {
        "embed_dim": 192,
        "depth": 1,
        "num_heads": 3,
        "pos_embed_type": "rope",
        "img_size": (128, 256),
        "patch_size": (16, 16),
    }
    net_config = {
        "spectrogram": spectrogram_cfg,
        "patch_embed": patch_embed_cfg,
        "masking": masking_cfg,
        "encoder": encoder_cfg,
        "predictor": predictor_cfg,
    }

    model = AudioJEPAModule(optimizer=torch.optim.AdamW, net=net_config)

    # Normally set by the training loop's EMA schedule; set manually here.
    model.current_ema_decay = 0.996

    print("Running training_step...")
    loss = model.training_step(batch, 0)
    print(f"Training step loss: {loss}")

    print("Running validation_step...")
    val_loss = model.validation_step(batch, 0)
    print(f"Validation step loss: {val_loss}")

    print("Test passed!")
|
|
|
|
# Run the variable-length collate/model smoke test when executed as a script.
if __name__ == "__main__":
    test_variable_length()
|
|