| from typing import Any, Dict, List, Optional, Sequence |
|
|
| import numpy as np |
| import torch |
| import torchaudio |
|
|
|
|
class DatasetResamplerCropper:
    """
    Resample a waveform to a target sampling rate and optionally random-crop it.

    A cache of ``torchaudio`` resamplers (one per source sampling rate) avoids
    re-instantiating the transform on every call.

    Args:
        target_sr (int): Target sampling rate.
        max_length (Optional[int]): Maximum length in samples (at target_sr).
    """

    def __init__(self, target_sr: int, max_length: Optional[int] = None):
        self.target_sr = target_sr
        self.max_length = max_length
        # Cache mapping: source sampling rate -> resampler transform.
        self.resamplers: Dict[int, "torchaudio.transforms.Resample"] = {}

    @staticmethod
    def _random_crop(waveform: torch.Tensor, length: int) -> torch.Tensor:
        """Return a random contiguous crop of `length` samples along the last dim."""
        start = np.random.randint(0, waveform.shape[-1] - length + 1)
        return waveform[..., start : start + length]

    def forward(self, waveform: torch.Tensor, source_sr: int) -> torch.Tensor:
        """
        Args:
            waveform (torch.Tensor): Tensor of shape [T] or [C, T].
            source_sr (int): Source sampling rate.

        Returns:
            torch.Tensor: Processed waveform tensor.
        """
        if source_sr == self.target_sr:
            # Already at the target rate: only a random crop may be needed.
            if self.max_length is not None and waveform.shape[-1] > self.max_length:
                waveform = self._random_crop(waveform, self.max_length)
            return waveform

        # Crop BEFORE resampling (length expressed in source-rate samples) so
        # the expensive resample operates on the shortest possible signal.
        if self.max_length is not None:
            crop_len_source = round(self.max_length * source_sr / self.target_sr)
            if waveform.shape[-1] > crop_len_source:
                waveform = self._random_crop(waveform, crop_len_source)

        resampler = self.resamplers.get(source_sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(source_sr, self.target_sr)
            self.resamplers[source_sr] = resampler
        waveform = resampler(waveform)

        # Rounding in the source-domain crop can leave a few extra samples;
        # trim deterministically to the exact target length.
        if self.max_length is not None and waveform.shape[-1] > self.max_length:
            waveform = waveform[..., : self.max_length]

        return waveform

    def __call__(self, waveform: torch.Tensor, source_sr: int) -> torch.Tensor:
        return self.forward(waveform, source_sr)
|
|
|
|
def collate_audio_batch(
    batch: List[Dict[str, Any]],
    waveform_key: str = "waveform",
    mode: str = "pad",
    stack_waveforms: bool = True,
    pad_value: float = 0.0,
    include_keys: Optional[Sequence[str]] = None,
    exclude_keys: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """
    Generic collate function for audio batches where each sample is a dict
    containing at least `waveform_key` with shape [1, T] or [T].

    Pads or truncates waveforms across the batch, and returns a dict that:
    - always includes waveform_key -> Tensor [B, 1, T']
    - includes other keys aggregated into lists (or stacked if possible)

    Parameters
    ----------
    batch:
        List of sample dictionaries.
    waveform_key:
        Key of waveform in sample dict.
    mode:
        "pad" -> pad shorter waveforms to max length in batch
        "truncate" -> truncate longer waveforms to min length in batch
    stack_waveforms:
        If True, returns waveforms stacked into a single tensor [B, 1, T'].
    pad_value:
        Value used for padding.
    include_keys:
        If provided, only these keys will be included in the output (plus waveform_key).
    exclude_keys:
        If provided, these keys will not be included (except waveform_key is always kept).

    Returns
    -------
    Dict[str, Any]
        Collated batch dict.

    Raises
    ------
    ValueError
        On an empty batch, an unknown `mode`, or a waveform with ndim > 2.
    KeyError
        If an item is missing `waveform_key`.
    TypeError
        If a waveform value is not a torch.Tensor.
    """
    if len(batch) == 0:
        raise ValueError("Empty batch passed to collate_audio_batch")

    # --- Validate waveforms and normalize each to 2-D [C, T]. ---
    waveforms = []
    for item in batch:
        if waveform_key not in item:
            raise KeyError(
                f"Missing key '{waveform_key}' in batch item: {list(item.keys())}"
            )

        w = item[waveform_key]

        if not torch.is_tensor(w):
            raise TypeError(
                f"Expected waveform tensor for key '{waveform_key}', got {type(w)}"
            )

        if w.ndim == 1:
            w = w.unsqueeze(0)
        elif w.ndim != 2:
            raise ValueError(
                f"Expected waveform with shape [T] or [1, T], got {tuple(w.shape)}"
            )

        waveforms.append(w)

    lengths = [w.shape[-1] for w in waveforms]

    if mode == "pad":
        target_len = max(lengths)
    elif mode == "truncate":
        target_len = min(lengths)
    else:
        raise ValueError(f"Unknown mode '{mode}' (expected 'pad' or 'truncate')")

    # --- Right-pad or right-truncate every waveform to target_len. ---
    processed = []
    for w in waveforms:
        cur_len = w.shape[-1]
        if cur_len < target_len:
            w = torch.nn.functional.pad(w, (0, target_len - cur_len), value=pad_value)
        elif cur_len > target_len:
            w = w[..., :target_len]
        processed.append(w)

    waveform_batch = torch.stack(processed, dim=0) if stack_waveforms else processed

    # BUGFIX: gather keys from ALL items, not just batch[0]. Items later in
    # the batch may carry keys the first item lacks; the aggregation loop
    # below already tolerates missing keys via item.get(k, None).
    all_keys = set()
    for item in batch:
        all_keys.update(item.keys())
    all_keys.add(waveform_key)

    if include_keys is not None:
        keys_to_collate = set(include_keys) | {waveform_key}
    else:
        keys_to_collate = all_keys

    if exclude_keys is not None:
        keys_to_collate -= set(exclude_keys)
        # The waveform key is always kept, even if explicitly excluded.
        keys_to_collate.add(waveform_key)

    out: Dict[str, Any] = {waveform_key: waveform_batch}

    for k in keys_to_collate:
        if k == waveform_key:
            continue

        # Missing keys yield None so every output list has length B.
        values = [item.get(k, None) for item in batch]

        # All tensors: try to stack; ragged shapes / mismatched dtypes or
        # devices make torch.stack raise, in which case we deliberately
        # fall back to a plain list (best-effort collation, not a crash).
        if all(torch.is_tensor(v) for v in values):
            try:
                out[k] = torch.stack(values, dim=0)
            except Exception:
                out[k] = values
            continue

        # All plain numbers: pack into a tensor (dtype inferred by torch).
        if all(isinstance(v, (int, float)) for v in values):
            out[k] = torch.tensor(values)
            continue

        # Mixed / non-numeric values stay as a list.
        out[k] = values

    return out
|
|