"""Mock AudioSet data module (src/data/mock_audioset_datamodule.py).

Generates random-noise waveforms so the training pipeline can be exercised
without real AudioSet data.
"""

from functools import partial
from typing import Any, Dict, List, Optional, Union

import lightning as L
import torch
from torch.utils.data import DataLoader, Dataset

class MockAudioSetDataset(Dataset):
    """
    Mock AudioSet dataset that yields random-noise waveforms with random
    multi-hot targets, for smoke-testing the pipeline without real data.
    """

    def __init__(
        self,
        length: int = 100,
        max_length: int = 160000,
        target_sample_rate: int = 16000,
    ):
        self.length = length
        self.max_length = max_length
        self.target_sample_rate = target_sample_rate

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, str, int]]:
        # Generate a random waveform of shape [1, T]. A fixed T = max_length
        # keeps the mock simple; variable lengths (e.g. max_length // 2 up to
        # max_length) would be more realistic but are not needed here.
        waveform = torch.randn(1, self.max_length)
        # Fake multi-hot target over AudioSet's 527 classes: draw three
        # random class indices (possibly repeating) and set them to 1.
        target = torch.zeros(527)
        indices = torch.randint(0, 527, (3,))
        target[indices] = 1.0
        audio_name = f"mock_audio_{idx}"
        return {
            "waveform": waveform,
            "target": target,
            "audio_name": audio_name,
            "index": idx,
        }
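
# A minimal usage sketch (shapes assume the defaults above; purely illustrative):
#
#   ds = MockAudioSetDataset(length=4)
#   item = ds[0]
#   item["waveform"].shape  # torch.Size([1, 160000])
#   item["target"].shape    # torch.Size([527]), with up to 3 classes set to 1.0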


class MockAudioSetDataModule(L.LightningDataModule):
    """
    LightningDataModule wrapping MockAudioSetDataset, with train/val/test
    splits and a pad-or-truncate collate function.
    """

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        pin_memory: bool = True,
        max_audio_length_sec: float = 10.0,
        target_sample_rate: int = 16000,
        collate_mode: str = "pad",
    ):
        super().__init__()
        self.save_hyperparameters()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.collate_mode = collate_mode
        # Maximum waveform length in samples.
        self.max_audio_length = int(max_audio_length_sec * target_sample_rate)
        self.train_dataset: Optional[MockAudioSetDataset] = None
        self.val_dataset: Optional[MockAudioSetDataset] = None
        self.test_dataset: Optional[MockAudioSetDataset] = None

    def setup(self, stage: Optional[str] = None) -> None:
        if stage == "fit" or stage is None:
            self.train_dataset = MockAudioSetDataset(
                length=1000,  # fake dataset size
                max_length=self.max_audio_length,
                target_sample_rate=self.hparams.target_sample_rate,
            )
            self.val_dataset = MockAudioSetDataset(
                length=100,
                max_length=self.max_audio_length,
                target_sample_rate=self.hparams.target_sample_rate,
            )
        # Mirror the fit branch so setup(None) also prepares the test split.
        if stage == "test" or stage is None:
            self.test_dataset = MockAudioSetDataset(
                length=50,
                max_length=self.max_audio_length,
                target_sample_rate=self.hparams.target_sample_rate,
            )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.num_workers > 0,
            collate_fn=partial(self.collate_fn, mode=self.collate_mode),
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.num_workers > 0,
            collate_fn=partial(self.collate_fn, mode=self.collate_mode),
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.num_workers > 0,  # consistent with train/val
            collate_fn=partial(self.collate_fn, mode=self.collate_mode),
        )

    @staticmethod
    def collate_fn(batch: List[Dict[str, Any]], mode: str = "pad") -> Dict[str, Any]:
        """
        Collate a list of items into a batch, equalizing waveform lengths:
        "pad" zero-pads every waveform to the longest in the batch, while
        "truncate" cuts every waveform to the shortest.
        """
        waveforms = [item["waveform"] for item in batch]  # list of [1, T]
        targets = torch.stack([item["target"] for item in batch])
        audio_names = [item["audio_name"] for item in batch]
        indices = [item["index"] for item in batch]

        # Target length: longest waveform for "pad", shortest for "truncate".
        lengths = [w.shape[-1] for w in waveforms]
        if mode == "pad":
            target_wave_len = max(lengths)
        elif mode == "truncate":
            target_wave_len = min(lengths)
        else:
            raise ValueError(f"Unknown collate mode: {mode}")

        # Pad (at the end) or truncate each waveform to the target length.
        processed_waveforms = []
        for w in waveforms:
            current_len = w.shape[-1]
            if current_len < target_wave_len:
                pad_amount = target_wave_len - current_len
                processed_waveforms.append(torch.nn.functional.pad(w, (0, pad_amount)))
            elif current_len > target_wave_len:
                processed_waveforms.append(w[..., :target_wave_len])
            else:
                processed_waveforms.append(w)
        processed_waveforms = torch.stack(processed_waveforms)  # [B, 1, T]

        return {
            "waveform": processed_waveforms,
            "target": targets,
            "audio_name": audio_names,
            "index": indices,
        }
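

# Quick smoke test: a minimal sketch, not part of the training entry point.
# It builds the datamodule with illustrative settings, fetches one padded
# batch, and exercises collate_fn directly with unequal waveform lengths.
if __name__ == "__main__":
    dm = MockAudioSetDataModule(batch_size=4, num_workers=0)
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    print(batch["waveform"].shape)  # torch.Size([4, 1, 160000])
    print(batch["target"].shape)    # torch.Size([4, 527])

    # Unequal lengths: "pad" should zero-pad both waveforms to the longer one.
    items = [
        {"waveform": torch.randn(1, 8), "target": torch.zeros(527), "audio_name": "a", "index": 0},
        {"waveform": torch.randn(1, 5), "target": torch.zeros(527), "audio_name": "b", "index": 1},
    ]
    out = MockAudioSetDataModule.collate_fn(items, mode="pad")
    assert out["waveform"].shape == (2, 1, 8)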