File size: 7,075 Bytes
eca55dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
from typing import Any, Dict, List, Optional, Sequence

import numpy as np
import torch
import torchaudio


class DatasetResamplerCropper:
    """
    Resample and optionally crop a waveform.

    Maintains a cache of ``torchaudio.transforms.Resample`` instances keyed by
    source sampling rate, so repeated calls with the same rate do not rebuild
    the (relatively expensive) resampling kernel.

    Args:
        target_sr (int): Target sampling rate in Hz.
        max_length (Optional[int]): Maximum length in samples (at target_sr).
            ``None`` disables cropping.
    """

    def __init__(self, target_sr: int, max_length: Optional[int] = None):
        self.target_sr = target_sr
        self.max_length = max_length
        # Cache: source sampling rate -> resampler instance.
        # NOTE: annotation is a lazy string so constructing this object never
        # evaluates torchaudio names at runtime.
        self.resamplers: "Dict[int, torchaudio.transforms.Resample]" = {}

    def forward(self, waveform: torch.Tensor, source_sr: int) -> torch.Tensor:
        """
        Resample ``waveform`` from ``source_sr`` to ``self.target_sr`` and
        randomly crop it to at most ``self.max_length`` samples (measured at
        the target rate).

        Args:
            waveform (torch.Tensor): Tensor of shape [T] or [C, T].
            source_sr (int): Source sampling rate in Hz.

        Returns:
            torch.Tensor: Processed waveform tensor (same rank as input).
        """
        if source_sr != self.target_sr:
            # Optimization: crop in the source domain first, so we only
            # resample roughly max_length worth of audio instead of the
            # whole clip.
            if self.max_length is not None:
                # Source-domain length that maps to max_length target samples.
                # Use an integer ceiling plus one sample of slack so rounding
                # can never leave us with fewer than max_length samples after
                # resampling (the previous round() could undershoot, e.g.
                # banker's rounding on an exact .5). The excess is trimmed
                # below after resampling.
                crop_len_source = (
                    -(-self.max_length * source_sr // self.target_sr) + 1
                )

                if waveform.shape[-1] > crop_len_source:
                    max_start = waveform.shape[-1] - crop_len_source
                    start = np.random.randint(0, max_start + 1)
                    waveform = waveform[..., start : start + crop_len_source]

            # Build (or reuse) the cached resampler for this source rate.
            if source_sr not in self.resamplers:
                self.resamplers[source_sr] = torchaudio.transforms.Resample(
                    source_sr, self.target_sr
                )
            waveform = self.resamplers[source_sr](waveform)

            # Trim the rounding slack — or crop outright if the clip was
            # short enough to skip the source-domain crop but still resampled
            # to more than max_length samples.
            if self.max_length is not None and waveform.shape[-1] > self.max_length:
                waveform = waveform[..., : self.max_length]

        else:
            # Rates already match: plain random crop at the target rate.
            if self.max_length is not None and waveform.shape[-1] > self.max_length:
                max_start = waveform.shape[-1] - self.max_length
                start = np.random.randint(0, max_start + 1)
                waveform = waveform[..., start : start + self.max_length]

        return waveform

    def __call__(self, waveform: torch.Tensor, source_sr: int) -> torch.Tensor:
        """Alias for :meth:`forward` so the object is directly callable."""
        return self.forward(waveform, source_sr)


def _gather_waveforms(
    batch: List[Dict[str, Any]], waveform_key: str
) -> List[torch.Tensor]:
    """Extract and validate waveforms from `batch`, normalizing [T] to [1, T]."""
    waveforms = []
    for item in batch:
        if waveform_key not in item:
            raise KeyError(
                f"Missing key '{waveform_key}' in batch item: {list(item.keys())}"
            )

        w = item[waveform_key]

        if not torch.is_tensor(w):
            raise TypeError(
                f"Expected waveform tensor for key '{waveform_key}', got {type(w)}"
            )

        # Accept [T] or [1, T]
        if w.ndim == 1:
            w = w.unsqueeze(0)
        elif w.ndim != 2:
            raise ValueError(
                f"Expected waveform with shape [T] or [1, T], got {tuple(w.shape)}"
            )

        waveforms.append(w)
    return waveforms


def _fit_to_length(
    w: torch.Tensor, target_len: int, pad_value: float
) -> torch.Tensor:
    """Right-pad (with `pad_value`) or right-truncate `w` along the last dim."""
    cur_len = w.shape[-1]
    if cur_len < target_len:
        return torch.nn.functional.pad(w, (0, target_len - cur_len), value=pad_value)
    if cur_len > target_len:
        return w[..., :target_len]
    return w


def _select_keys(
    batch: List[Dict[str, Any]],
    waveform_key: str,
    include_keys: Optional[Sequence[str]],
    exclude_keys: Optional[Sequence[str]],
) -> set:
    """Decide which keys to collate; `waveform_key` is always retained."""
    # Keys are taken from the first sample only; items missing a key later
    # fall back to None via item.get() in the caller.
    all_keys = set(batch[0].keys())
    all_keys.add(waveform_key)

    if include_keys is not None:
        keys_to_collate = set(include_keys) | {waveform_key}
    else:
        keys_to_collate = set(all_keys)

    if exclude_keys is not None:
        keys_to_collate -= set(exclude_keys)
        keys_to_collate.add(waveform_key)  # waveform always kept

    return keys_to_collate


def _collate_values(values: List[Any]) -> Any:
    """Best-effort collation: stack tensors, tensorize numbers, else list."""
    # If all are tensors of same shape -> stack
    if all(torch.is_tensor(v) for v in values):
        try:
            return torch.stack(values, dim=0)
        except Exception:
            # fallback to list if stacking fails (e.g. mismatched shapes)
            return values

    # If all are numbers (int/float) -> tensor
    if all(isinstance(v, (int, float)) for v in values):
        return torch.tensor(values)

    # Otherwise -> list
    return values


def collate_audio_batch(
    batch: List[Dict[str, Any]],
    waveform_key: str = "waveform",
    mode: str = "pad",  # "pad" or "truncate"
    stack_waveforms: bool = True,
    pad_value: float = 0.0,
    include_keys: Optional[Sequence[str]] = None,
    exclude_keys: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """
    Generic collate function for audio batches where each sample is a dict
    containing at least `waveform_key` with shape [1, T] or [T].

    Pads or truncates waveforms across the batch, and returns a dict that:
        - always includes waveform_key -> Tensor [B, 1, T'] (or a list of
          [1, T'] tensors when stack_waveforms=False)
        - includes other keys aggregated into lists (or stacked if possible)

    Parameters
    ----------
    batch:
        List of sample dictionaries.
    waveform_key:
        Key of waveform in sample dict.
    mode:
        "pad" -> pad shorter waveforms to max length in batch
        "truncate" -> truncate longer waveforms to min length in batch
    stack_waveforms:
        If True, returns waveforms stacked into a single tensor [B, 1, T'].
    pad_value:
        Value used for padding.
    include_keys:
        If provided, only these keys will be included in the output (plus waveform_key).
    exclude_keys:
        If provided, these keys will not be included (except waveform_key is always kept).

    Returns
    -------
    Dict[str, Any]
        Collated batch dict.

    Raises
    ------
    ValueError
        On an empty batch, an unknown `mode`, or a waveform of bad rank.
    KeyError
        If a sample is missing `waveform_key`.
    TypeError
        If a sample's waveform is not a tensor.
    """
    if len(batch) == 0:
        raise ValueError("Empty batch passed to collate_audio_batch")

    # 1) Collect and validate waveforms
    waveforms = _gather_waveforms(batch, waveform_key)
    lengths = [w.shape[-1] for w in waveforms]

    # 2) Determine target length
    if mode == "pad":
        target_len = max(lengths)
    elif mode == "truncate":
        target_len = min(lengths)
    else:
        raise ValueError(f"Unknown mode '{mode}' (expected 'pad' or 'truncate')")

    # 3) Pad/truncate each waveform to the common length
    processed = [_fit_to_length(w, target_len, pad_value) for w in waveforms]

    if stack_waveforms:
        waveform_batch: Any = torch.stack(processed, dim=0)  # [B, 1, T']
    else:
        waveform_batch = processed  # list of [1, T']

    # 4) Decide which other keys to include
    keys_to_collate = _select_keys(batch, waveform_key, include_keys, exclude_keys)

    # 5) Collate other keys (best effort)
    out: Dict[str, Any] = {waveform_key: waveform_batch}
    for k in keys_to_collate:
        if k == waveform_key:
            continue
        out[k] = _collate_values([item.get(k, None) for item in batch])

    return out