File size: 12,592 Bytes
94d3a9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
from dataclasses import make_dataclass
from typing import List, Optional, Tuple, Union

import torch
import torchaudio
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torchaudio.compliance.kaldi import fbank

from .usad_modules import ConformerEncoder, lengths_to_padding_mask

# Default cap on the number of mel frames encoded in a single pass. Features
# are extracted with a 10 ms frame shift (100 frames per second), so 3000
# frames correspond to 30 seconds of audio.
MAX_MEL_LENGTH = 3000  # 30 seconds


@torch.no_grad()
def wav_to_fbank(
    wavs: torch.Tensor,
    mel_dim: int = 128,
    norm_mean: float = -4.268,
    norm_std: float = 4.569,
    wav_lengths: Optional[torch.Tensor] = None,
    sample_rate: int = 16000,
    return_lengths: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Extract normalized filterbank features from a padded waveform batch.

    Each waveform is cut to its valid length, mean-centered, passed through
    ``torchaudio.compliance.kaldi.fbank`` and normalized as
    ``(feat - norm_mean) / (2 * norm_std)``. The per-item features are then
    zero-padded to a common length.

    Args:
        wavs (torch.Tensor): (B, T_wav) waveform tensor.
        mel_dim (int, optional): number of mel bins. Defaults to 128.
        norm_mean (float, optional): mean used for normalization. Defaults to -4.268.
        norm_std (float, optional): std used for normalization. Defaults to 4.569.
        wav_lengths (torch.Tensor, optional): (B,) valid waveform lengths before
            padding. When omitted, every row is treated as fully valid.
        sample_rate (int, optional): waveform sample rate. Defaults to 16000.
        return_lengths (bool, optional): also return the per-item feature
            lengths. Defaults to False.

    Returns:
        torch.Tensor: (B, T_mel, mel_dim) features, in the dtype of ``wavs``
        when that is a floating-point dtype and in float32 otherwise. If
        ``return_lengths`` is True, a ``(features, lengths)`` tuple is returned
        where ``lengths`` is a (B,) long tensor of unpadded feature lengths.

    Raises:
        ValueError: if ``wav_lengths`` is not 1-D with B elements, contains a
            non-positive value, or contains a value larger than T_wav.
    """
    # ref: https://github.com/cwx-worst-one/EAT/tree/main/feature_extract
    out_dtype = torch.float32
    if wavs.is_floating_point():
        out_dtype = wavs.dtype
    # Feature extraction always runs in float32; the result is cast back later.
    audio = wavs.to(torch.float32)

    if wav_lengths is not None:
        wav_lengths = wav_lengths.to(device=wavs.device, dtype=torch.long)
        if not (wav_lengths.dim() == 1 and wav_lengths.shape[0] == wavs.shape[0]):
            raise ValueError("wav_lengths must be a 1-D tensor with batch size elements.")
        if (wav_lengths <= 0).any().item():
            raise ValueError("All wav_lengths values must be positive.")
        if (wav_lengths > wavs.shape[1]).any().item():
            raise ValueError("wav_lengths cannot exceed the padded waveform length.")
    else:
        # No lengths supplied: every row is valid over its full padded width.
        wav_lengths = torch.full(
            (wavs.shape[0],),
            wavs.shape[1],
            dtype=torch.long,
            device=wavs.device,
        )

    norm_scale = norm_std * 2
    valid_sizes = wav_lengths.detach().cpu().tolist()

    per_item = []
    for row, n_valid in zip(audio, valid_sizes):
        # Drop the padding first so it cannot bias the mean that is removed.
        segment = row[:n_valid]
        segment = segment - segment.mean(dim=-1, keepdim=True)
        spec = fbank(
            segment.unsqueeze(0),
            htk_compat=True,
            sample_frequency=sample_rate,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=mel_dim,
            dither=0.0,
            frame_shift=10,
        )
        spec = (spec - norm_mean) / norm_scale
        per_item.append(spec.to(dtype=out_dtype))

    mel_lengths = torch.tensor(
        [spec.shape[0] for spec in per_item],
        dtype=torch.long,
        device=wavs.device,
    )
    mels = pad_sequence(per_item, batch_first=True, padding_value=0.0)

    return (mels, mel_lengths) if return_lengths else mels


class UsadModel(nn.Module):
    """Waveform front end plus ``ConformerEncoder``.

    ``forward`` converts waveforms to filterbank features with
    ``wav_to_fbank`` and runs them through the encoder. Inputs whose feature
    length exceeds ``max_mel_length`` frames are encoded chunk by chunk and
    the chunk outputs are concatenated along the time axis.
    """

    def __init__(self, cfg):
        """Build the encoder from ``cfg``.

        Args:
            cfg: model configuration object. This class reads ``input_dim``,
                ``encoder_dim``, ``num_layers`` and ``conv_subsample_rate``
                from it and hands the whole object to ``ConformerEncoder``.
        """
        super().__init__()

        self.cfg = cfg
        self.encoder = ConformerEncoder(cfg)
        # Maximum number of mel frames encoded in one pass; see
        # set_audio_chunk_size() to change it.
        self.max_mel_length = MAX_MEL_LENGTH

    @property
    def sample_rate(self) -> int:
        """Waveform sample rate used by this model, in Hz."""
        return 16000  # Hz

    @property
    def encoder_frame_rate(self) -> int:
        """Encoder output frame rate in Hz (100 Hz features / subsample rate)."""
        return round(100 / self.cfg.conv_subsample_rate)  # Hz

    @property
    def mel_dim(self) -> int:
        """Number of mel bins, taken from ``cfg.input_dim``."""
        return self.cfg.input_dim

    @property
    def encoder_dim(self) -> int:
        """Encoder hidden size, taken from ``cfg.encoder_dim``."""
        return self.cfg.encoder_dim

    @property
    def num_layers(self) -> int:
        """Number of encoder layers, taken from ``cfg.num_layers``."""
        return self.cfg.num_layers

    @property
    def device(self) -> torch.device:
        """Device of the first model parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first model parameter."""
        return next(self.parameters()).dtype

    def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
        """Set the maximum chunk size for feature extraction.

        Args:
            seconds (float, optional): Chunk size in seconds; must be at
                least 0.1. Defaults to 30.0.
        """
        assert (
            seconds >= 0.1
        ), f"Chunk size must be at least 0.1s, got {seconds} seconds."
        # 100 Hz frame rate. Products such as 0.29 * 100 evaluate to
        # 28.999999999999996 in binary floating point, so round that error
        # away before truncating; genuinely fractional frame counts are
        # still floored.
        self.max_mel_length = int(round(seconds * 100, 6))

    def load_audio(self, audio_path: str, move_to_device: bool = True) -> torch.Tensor:
        """Load audio file and return waveform tensor.

        The audio is resampled to ``self.sample_rate`` when needed and
        multi-channel audio is averaged down to one channel.

        Args:
            audio_path (str): Path to the audio file.
            move_to_device (bool, optional): Move the waveform to the model's
                device before returning it. Defaults to True.
        Returns:
            torch.Tensor: Waveform tensor of shape (wav_len,).
        """
        waveform, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        if waveform.shape[0] > 1:
            # If stereo, convert to mono by averaging channels
            waveform = waveform.mean(dim=0, keepdim=True)

        waveform = waveform.squeeze(0)  # Remove channel dimension if mono
        if move_to_device:
            return waveform.to(self.device)  # Ensure tensor is on the same device
        return waveform

    def load_audio_batch(
        self, audio_paths: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load several audio files into one zero-padded batch.

        Args:
            audio_paths (List[str]): Paths of the audio files to load.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(wavs, wav_lengths)`` where
            ``wavs`` is the (B, T_wav) padded batch and ``wav_lengths`` is a
            (B,) long tensor of unpadded lengths, both on the model's device.
        """
        # Keep the individual waveforms off-device and move the padded batch once.
        wav_list = [self.load_audio(path, move_to_device=False) for path in audio_paths]
        wavs = pad_sequence(wav_list, batch_first=True).to(self.device)
        wav_lengths = torch.tensor(
            [wav.shape[0] for wav in wav_list], dtype=torch.long, device=self.device
        )
        return wavs, wav_lengths

    def forward(
        self,
        wavs: torch.Tensor,
        wav_lengths: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        target_layer: Optional[int] = None,
        norm_mean: float = -4.268,
        norm_std: float = 4.569,
    ) -> dict:
        """Encode a batch of waveforms.

        Inputs with more than ``self.max_mel_length`` mel frames are encoded
        in consecutive chunks of that size; a trailing chunk with fewer than
        10 frames is skipped.

        Args:
            wavs (torch.Tensor): (B, T_wav) waveform tensor.
            wav_lengths (torch.Tensor, optional): (B,) lengths of each waveform. Defaults to None.
            padding_mask (torch.Tensor, optional): (B, T_wav) padding mask for the waveforms.
                If wav_lengths is not provided, this is used to infer valid lengths.
            target_layer (int, optional): If specified, only return the output of the target layer. Defaults to None (return all layers).
            norm_mean (float, optional): Mean for normalization. Defaults to -4.268.
            norm_std (float, optional): Std for normalization. Defaults to 4.569.
        Returns:
            dict: A dictionary containing the following keys:
                - "x": (B, T_out, encoder_dim) output of the encoder
                - "x_lengths": (B,) valid output lengths after encoder subsampling
                - "x_padding_mask": (B, T_out) output padding mask, where padding is True
                - "mel": (B, T_mel, mel_dim) input mel features
                - "mel_lengths": (B,) valid mel lengths before encoder subsampling
                - "hidden_states": list of (B, T_out, encoder_dim) hidden states of each layer
                - "ffn": list of (B, T_out, encoder_dim) output of the feed-forward network of each layer
        Raises:
            AssertionError: if an argument has the wrong type, rank or batch
                size, or if target_layer is outside [1, cfg.num_layers].
        """

        # Check types
        assert isinstance(wavs, torch.Tensor), "wavs must be a torch.Tensor"
        assert wavs.dim() == 2, "wavs must be of shape (batch_size, seq_len)"
        if wav_lengths is not None:
            assert isinstance(
                wav_lengths, torch.Tensor
            ), "wav_lengths must be a torch.Tensor"
            assert wav_lengths.dim() == 1, "wav_lengths must be of shape (batch_size,)"
            assert (
                wav_lengths.shape[0] == wavs.shape[0]
            ), "wav_lengths must have the same batch size as wavs"
        if padding_mask is not None:
            assert isinstance(
                padding_mask, torch.Tensor
            ), "padding_mask must be a torch.Tensor"
            assert (
                padding_mask.dim() == 2
            ), "padding_mask must be of shape (batch_size, seq_len)"
            assert (
                padding_mask.shape[0] == wavs.shape[0]
            ), "padding_mask must have the same batch size as wavs"
            assert (
                padding_mask.shape[1] == wavs.shape[1]
            ), "padding_mask must have the same seq_len as wavs"
            if wav_lengths is None:
                # Explicit lengths win; the mask is only a fallback source.
                wav_lengths = (~padding_mask.to(torch.bool)).sum(dim=1)
        if target_layer is not None:
            assert isinstance(target_layer, int), "target_layer must be an int or None"
            assert (
                1 <= target_layer <= self.cfg.num_layers
            ), f"target_layer must be between 1 and {self.cfg.num_layers}"

        mel, mel_lengths = wav_to_fbank(
            wavs,
            wav_lengths=wav_lengths,
            mel_dim=self.mel_dim,
            norm_mean=norm_mean,
            norm_std=norm_std,
            sample_rate=self.sample_rate,
            return_lengths=True,
        )

        # Match the features to the dtype of the model parameters.
        dtype = self.dtype
        if mel.dtype != dtype:
            mel = mel.to(dtype)

        # Number of per-layer outputs collected in the chunked path.
        num_layers = min(
            self.cfg.num_layers,
            target_layer if target_layer is not None else self.cfg.num_layers,
        )

        if mel.shape[1] <= self.max_mel_length:
            # If the mel length is less than or equal to max_mel_length, we can process it in one go
            x, x_len, layer_results = self.encoder(
                inputs=mel,
                input_lengths=mel_lengths,
                return_hidden=True,
                target_layer=target_layer,
            )

            result = {
                "x": x,
                "x_lengths": x_len,
                "x_padding_mask": lengths_to_padding_mask(x_len, max_len=x.size(1)),
                "mel": mel,
                "mel_lengths": mel_lengths,
                "hidden_states": layer_results["hidden_states"],
                "ffn": layer_results["ffn_1"],
            }
            return result

        # If the mel length is greater than max_mel_length, we need to process it in chunks
        result = {
            "x": [],
            "x_lengths": [],
            "mel": mel,
            "mel_lengths": mel_lengths,
            "hidden_states": [[] for _ in range(num_layers)],
            "ffn": [[] for _ in range(num_layers)],
        }
        for i in range(0, mel.shape[1], self.max_mel_length):
            # Skip a trailing remainder shorter than 10 frames.
            if mel.shape[1] - i < 10:
                break

            _mel = mel[:, i : i + self.max_mel_length]
            # Valid frames of each item that fall inside this chunk; items
            # that ended before the chunk get a length of 0.
            _mel_lengths = torch.clamp(
                mel_lengths - i, min=0, max=self.max_mel_length
            )

            x, x_len, layer_results = self.encoder(
                inputs=_mel,
                input_lengths=_mel_lengths,
                return_hidden=True,
                target_layer=target_layer,
            )

            result["x"].append(x)
            result["x_lengths"].append(x_len)
            for j in range(num_layers):
                result["hidden_states"][j].append(layer_results["hidden_states"][j])
                result["ffn"][j].append(layer_results["ffn_1"][j])

        # Stitch the chunk outputs back together along the time axis.
        result["x"] = torch.cat(result["x"], dim=1)
        result["x_lengths"] = torch.stack(result["x_lengths"], dim=0).sum(dim=0)
        result["x_padding_mask"] = lengths_to_padding_mask(
            result["x_lengths"], max_len=result["x"].size(1)
        )
        for j in range(num_layers):
            result["hidden_states"][j] = torch.cat(
                result["hidden_states"][j], dim=1
            )
            result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)

        return result

    @classmethod
    def load_from_fairseq_ckpt(cls, ckpt_path: str):
        """Build a model from a fairseq-style checkpoint file.

        The checkpoint is expected to hold the model config under
        ``["cfg"]["model"]`` and the weights under ``["model"]``. Only weights
        whose key starts with ``"encoder."`` are loaded.

        Args:
            ckpt_path (str): Path to the checkpoint file.
        Returns:
            UsadModel: model with the encoder weights loaded, on CPU.
        """
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects, so only load checkpoints from a trusted source.
        # map_location="cpu" lets checkpoints saved from a GPU load on hosts
        # without one; the model is built on CPU either way.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        config = checkpoint["cfg"]["model"]
        # Expose the config mapping through attribute access, as the model expects.
        config = make_dataclass("Config", config.keys())(**config)
        model = cls(config)
        state_dict = checkpoint["model"]
        # Drop non-encoder weights in place so strict loading succeeds.
        for k in list(state_dict.keys()):
            if not k.startswith("encoder."):
                del state_dict[k]
        model.load_state_dict(state_dict, strict=True)
        return model