# src/data/audio_utils.py

import math
from typing import Any, Dict, List, Optional, Sequence

import numpy as np
import torch
import torchaudio


class DatasetResamplerCropper:
"""
Resample and optionally crop a waveform.
Maintains a cache of resamplers for different source sampling rates to optimize instantiation.
Args:
target_sr (int): Target sampling rate.
max_length (Optional[int]): Maximum length in samples (at target_sr).
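
    Example:
        A minimal sketch (illustrative; assumes a mono 1-second clip at 48 kHz):

        >>> proc = DatasetResamplerCropper(target_sr=16000, max_length=16000)
        >>> wav = torch.randn(1, 48000)
        >>> proc(wav, source_sr=48000).shape
        torch.Size([1, 16000])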
"""
def __init__(self, target_sr: int, max_length: Optional[int] = None):
self.target_sr = target_sr
self.max_length = max_length
self.resamplers: Dict[int, torchaudio.transforms.Resample] = {}
def forward(self, waveform: torch.Tensor, source_sr: int) -> torch.Tensor:
"""
Args:
waveform (torch.Tensor): Tensor of shape [T] or [C, T].
source_sr (int): Source sampling rate.
Returns:
torch.Tensor: Processed waveform tensor.
"""
# Resampling and Cropping Logic
if source_sr != self.target_sr:
# We need to resample.
# Optimization: Crop in source domain first if we have a max_length
if self.max_length is not None:
                # Calculate the source-domain length that yields max_length samples
                # after resampling, with a one-sample buffer against rounding error
                # (plain rounding can undershoot when upsampling); the exact trim
                # below removes any excess.
                crop_len_source = math.ceil(self.max_length * source_sr / self.target_sr) + 1
if waveform.shape[-1] > crop_len_source:
max_start = waveform.shape[-1] - crop_len_source
start = np.random.randint(0, max_start + 1)
waveform = waveform[..., start : start + crop_len_source]
# Resample
if source_sr not in self.resamplers:
self.resamplers[source_sr] = torchaudio.transforms.Resample(
source_sr, self.target_sr
)
resampler = self.resamplers[source_sr]
waveform = resampler(waveform)
            # Enforce max_length exactly (the buffered source crop can overshoot
            # by a sample or two after resampling).
if self.max_length is not None and waveform.shape[-1] > self.max_length:
waveform = waveform[..., : self.max_length]
else:
# No resampling, just standard random crop
if self.max_length is not None and waveform.shape[-1] > self.max_length:
max_start = waveform.shape[-1] - self.max_length
start = np.random.randint(0, max_start + 1)
waveform = waveform[..., start : start + self.max_length]
return waveform

    def __call__(self, waveform: torch.Tensor, source_sr: int) -> torch.Tensor:
        return self.forward(waveform, source_sr)


def collate_audio_batch(
batch: List[Dict[str, Any]],
waveform_key: str = "waveform",
mode: str = "pad", # "pad" or "truncate"
stack_waveforms: bool = True,
pad_value: float = 0.0,
include_keys: Optional[Sequence[str]] = None,
exclude_keys: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
"""
Generic collate function for audio batches where each sample is a dict
containing at least `waveform_key` with shape [1, T] or [T].
    Pads or truncates waveforms across the batch, and returns a dict that:
    - includes waveform_key -> Tensor [B, 1, T'] (or a list of [1, T']
      tensors when stack_waveforms=False)
    - aggregates every other key into a list, or stacks/tensorizes it when possible
Parameters
----------
batch:
List of sample dictionaries.
waveform_key:
Key of waveform in sample dict.
mode:
"pad" -> pad shorter waveforms to max length in batch
"truncate" -> truncate longer waveforms to min length in batch
    stack_waveforms:
        If True, stack waveforms into a single tensor [B, 1, T']; if False,
        return them as a list of [1, T'] tensors.
    pad_value:
        Value used for padding.
    include_keys:
        If provided, only these keys are included in the output (plus waveform_key).
    exclude_keys:
        If provided, these keys are dropped (waveform_key is always kept).
Returns
-------
Dict[str, Any]
Collated batch dict.
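
    Examples
    --------
    A minimal sketch with two variable-length mono clips (illustrative
    values; the extra "label" key is collated best-effort into a tensor):

    >>> batch = [
    ...     {"waveform": torch.randn(1, 100), "label": 0},
    ...     {"waveform": torch.randn(1, 80), "label": 1},
    ... ]
    >>> out = collate_audio_batch(batch, mode="pad")
    >>> out["waveform"].shape
    torch.Size([2, 1, 100])
    >>> out["label"]
    tensor([0, 1])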
"""
if len(batch) == 0:
raise ValueError("Empty batch passed to collate_audio_batch")
# 1) Collect waveforms
waveforms = []
for item in batch:
if waveform_key not in item:
raise KeyError(
f"Missing key '{waveform_key}' in batch item: {list(item.keys())}"
)
w = item[waveform_key]
if not torch.is_tensor(w):
raise TypeError(
f"Expected waveform tensor for key '{waveform_key}', got {type(w)}"
)
# Accept [T] or [1, T]
if w.ndim == 1:
w = w.unsqueeze(0)
elif w.ndim != 2:
raise ValueError(
f"Expected waveform with shape [T] or [1, T], got {tuple(w.shape)}"
)
waveforms.append(w)
lengths = [w.shape[-1] for w in waveforms]
# 2) Determine target length
if mode == "pad":
target_len = max(lengths)
elif mode == "truncate":
target_len = min(lengths)
else:
raise ValueError(f"Unknown mode '{mode}' (expected 'pad' or 'truncate')")
# 3) Pad/truncate each waveform
processed = []
for w in waveforms:
cur_len = w.shape[-1]
if cur_len < target_len:
pad_amount = target_len - cur_len
w2 = torch.nn.functional.pad(w, (0, pad_amount), value=pad_value)
processed.append(w2)
elif cur_len > target_len:
processed.append(w[..., :target_len])
else:
processed.append(w)
if stack_waveforms:
waveform_batch = torch.stack(processed, dim=0) # [B, 1, T']
else:
waveform_batch = processed # list of [1, T']
    # 4) Decide which other keys to include (keys are taken from the first
    # sample; waveform_key is guaranteed present by the check above)
    all_keys = set(batch[0].keys())
if include_keys is not None:
keys_to_collate = set(include_keys) | {waveform_key}
else:
keys_to_collate = set(all_keys)
if exclude_keys is not None:
keys_to_collate -= set(exclude_keys)
keys_to_collate.add(waveform_key) # waveform always kept
# 5) Collate other keys (best effort)
out: Dict[str, Any] = {waveform_key: waveform_batch}
for k in keys_to_collate:
if k == waveform_key:
continue
values = [item.get(k, None) for item in batch]
# If all are tensors of same shape -> stack
if all(torch.is_tensor(v) for v in values):
try:
out[k] = torch.stack(values, dim=0)
continue
except Exception:
# fallback to list if stacking fails
out[k] = values
continue
# If all are numbers (int/float) -> tensor
if all(isinstance(v, (int, float)) for v in values):
out[k] = torch.tensor(values)
continue
# Otherwise -> list
out[k] = values
return out
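

# A minimal end-to-end sketch (illustrative only): preprocess a few synthetic
# clips with DatasetResamplerCropper, then batch them with collate_audio_batch
# via a standard DataLoader. All names mirror the definitions above; the
# 3-second 22050 Hz clips resample to exactly 48000 samples at 16 kHz.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    processor = DatasetResamplerCropper(target_sr=16000, max_length=16000 * 5)
    samples = [
        {"waveform": processor(torch.randn(1, 22050 * 3), source_sr=22050), "id": i}
        for i in range(4)
    ]
    loader = DataLoader(samples, batch_size=2, collate_fn=collate_audio_batch)
    for batch in loader:
        print(batch["waveform"].shape, batch["id"])  # torch.Size([2, 1, 48000]) tensor([...])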