import torch
from torch.utils.data import Dataset, DataLoader
import lightning as L
from typing import Optional, List, Dict, Any, Union
from functools import partial


class MockAudioSetDataset(Dataset):
    """
    Mock AudioSet dataset that yields random-noise waveforms paired with
    random multi-hot targets over the 527 AudioSet classes.
    """

    def __init__(
        self,
        length: int = 100,
        max_length: int = 160000,
        target_sample_rate: int = 16000,
    ):
        self.length = length
        self.max_length = max_length
        # Kept for API parity with a real AudioSet dataset; the mock data is
        # pure noise, so no resampling is performed.
        self.target_sample_rate = target_sample_rate

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, str, int]]:
        # Mono noise clip of shape (1, max_length).
        waveform = torch.randn(1, self.max_length)

        # Multi-hot target: mark up to three random classes as positive.
        # torch.randint may draw duplicate indices, so fewer than three
        # classes can end up set.
        target = torch.zeros(527)
        indices = torch.randint(0, 527, (3,))
        target[indices] = 1.0

        audio_name = f"mock_audio_{idx}"

        return {
            "waveform": waveform,
            "target": target,
            "audio_name": audio_name,
            "index": idx,
        }
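
# A hedged per-item usage sketch; the shapes follow from the constructor
# arguments shown here, not from any real AudioSet file:
#
#     ds = MockAudioSetDataset(length=4, max_length=16000)
#     item = ds[0]
#     item["waveform"].shape  # torch.Size([1, 16000])
#     item["target"].shape    # torch.Size([527])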


class MockAudioSetDataModule(L.LightningDataModule):
    """
    LightningDataModule around MockAudioSetDataset, with a configurable
    pad/truncate collate step.
    """

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        pin_memory: bool = True,
        max_audio_length_sec: float = 10.0,
        target_sample_rate: int = 16000,
        collate_mode: str = "pad",
    ):
        super().__init__()
        self.save_hyperparameters()

        # Mirror the hyperparameters as plain attributes for convenience.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.collate_mode = collate_mode

        # Maximum clip length in samples.
        self.max_audio_length = int(max_audio_length_sec * target_sample_rate)

        self.train_dataset: Optional[MockAudioSetDataset] = None
        self.val_dataset: Optional[MockAudioSetDataset] = None
        self.test_dataset: Optional[MockAudioSetDataset] = None

    def setup(self, stage: Optional[str] = None) -> None:
        if stage == "fit" or stage is None:
            self.train_dataset = MockAudioSetDataset(
                length=1000,
                max_length=self.max_audio_length,
                target_sample_rate=self.hparams.target_sample_rate,
            )
        # Build the validation set for standalone validation runs as well,
        # not only during fit.
        if stage in ("fit", "validate") or stage is None:
            self.val_dataset = MockAudioSetDataset(
                length=100,
                max_length=self.max_audio_length,
                target_sample_rate=self.hparams.target_sample_rate,
            )

        if stage == "test" or stage is None:
            self.test_dataset = MockAudioSetDataset(
                length=50,
                max_length=self.max_audio_length,
                target_sample_rate=self.hparams.target_sample_rate,
            )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.num_workers > 0,
            collate_fn=partial(self.collate_fn, mode=self.collate_mode),
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.num_workers > 0,
            collate_fn=partial(self.collate_fn, mode=self.collate_mode),
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.num_workers > 0,
            collate_fn=partial(self.collate_fn, mode=self.collate_mode),
        )

    @staticmethod
    def collate_fn(batch: List[Dict[str, Any]], mode: str = "pad") -> Dict[str, Any]:
        """
        Collate a list of items into a batch, padding or truncating every
        waveform to a common length.
        """
        waveforms = [item["waveform"] for item in batch]
        targets = torch.stack([item["target"] for item in batch])
        audio_names = [item["audio_name"] for item in batch]
        indices = [item["index"] for item in batch]

        # Common length: the longest clip when padding, the shortest when
        # truncating.
        lengths = [w.shape[-1] for w in waveforms]

        if mode == "pad":
            target_wave_len = max(lengths)
        elif mode == "truncate":
            target_wave_len = min(lengths)
        else:
            raise ValueError(f"Unknown collate mode: {mode}")

        # Bring every waveform to the common length along the time axis.
        processed_waveforms = []
        for w in waveforms:
            current_len = w.shape[-1]
            if current_len < target_wave_len:
                pad_amount = target_wave_len - current_len
                # Zero-pad at the end of the clip.
                w_padded = torch.nn.functional.pad(w, (0, pad_amount))
                processed_waveforms.append(w_padded)
            elif current_len > target_wave_len:
                # Keep the leading target_wave_len samples.
                w_truncated = w[..., :target_wave_len]
                processed_waveforms.append(w_truncated)
            else:
                processed_waveforms.append(w)

        processed_waveforms = torch.stack(processed_waveforms)

        return {
            "waveform": processed_waveforms,
            "target": targets,
            "audio_name": audio_names,
            "index": indices,
        }
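

# A minimal smoke-test sketch, assuming the module is run directly; the shapes
# in the comments follow from the defaults above (16 kHz, 527 classes). The
# variable-length items are built by hand purely to exercise the "pad" path,
# since the mock dataset itself always emits fixed-length clips.
if __name__ == "__main__":
    dm = MockAudioSetDataModule(batch_size=4, max_audio_length_sec=1.0)
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    print(batch["waveform"].shape)  # torch.Size([4, 1, 16000])
    print(batch["target"].shape)    # torch.Size([4, 527])

    # Hand-made items of different lengths, padded up to the longest clip.
    items = [
        {
            "waveform": torch.randn(1, n),
            "target": torch.zeros(527),
            "audio_name": f"demo_{i}",
            "index": i,
        }
        for i, n in enumerate([8000, 12000, 16000])
    ]
    padded = MockAudioSetDataModule.collate_fn(items, mode="pad")
    print(padded["waveform"].shape)  # torch.Size([3, 1, 16000])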