from dataclasses import make_dataclass
from typing import List, Optional, Tuple, Union

import torch
import torchaudio
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torchaudio.compliance.kaldi import fbank

from .usad_modules import ConformerEncoder, lengths_to_padding_mask

# Default maximum number of mel frames fed to the encoder in one pass.
# Features are extracted with a 10 ms frame shift (100 frames per second).
MAX_MEL_LENGTH = 3000  # 30 seconds


@torch.no_grad()
def wav_to_fbank(
    wavs: torch.Tensor,
    mel_dim: int = 128,
    norm_mean: float = -4.268,
    norm_std: float = 4.569,
    wav_lengths: Optional[torch.Tensor] = None,
    sample_rate: int = 16000,
    return_lengths: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Convert waveform to fbank features.

    Each waveform is trimmed to its valid length, mean-centered, converted to
    Kaldi-style fbank features and normalized as
    ``(feat - norm_mean) / (norm_std * 2)``. Per-item features are then
    zero-padded to a common length.

    Args:
        wavs (torch.Tensor): (B, T_wav) waveform tensor.
        mel_dim (int, optional): mel dimension. Defaults to 128.
        norm_mean (float, optional): mean for normalization. Defaults to -4.268.
        norm_std (float, optional): std for normalization. Defaults to 4.569.
        wav_lengths (torch.Tensor, optional): (B,) valid waveform lengths before
            padding. If None, every waveform is treated as fully valid.
        sample_rate (int, optional): waveform sample rate. Defaults to 16000.
        return_lengths (bool, optional): return exact fbank lengths.
            Defaults to False.

    Returns:
        torch.Tensor: (B, T_mel, mel_dim) fbank features. If return_lengths is
        True, also returns a (B,) tensor with exact feature lengths before
        padding.

    Raises:
        ValueError: If wav_lengths is not a 1-D tensor of batch size, or holds
            a non-positive value or a value larger than the padded length.
    """
    # ref: https://github.com/cwx-worst-one/EAT/tree/main/feature_extract
    # Features are computed in float32 and cast back to the input's floating
    # dtype (float32 for non-floating inputs).
    feature_dtype = wavs.dtype if wavs.is_floating_point() else torch.float32
    wavs_float = wavs.to(torch.float32)

    if wav_lengths is None:
        wav_lengths = torch.full(
            (wavs.shape[0],),
            wavs.shape[1],
            dtype=torch.long,
            device=wavs.device,
        )
    else:
        wav_lengths = wav_lengths.to(device=wavs.device, dtype=torch.long)

    if wav_lengths.dim() != 1 or wav_lengths.shape[0] != wavs.shape[0]:
        raise ValueError("wav_lengths must be a 1-D tensor with batch size elements.")
    if torch.any(wav_lengths <= 0).item():
        raise ValueError("All wav_lengths values must be positive.")
    if torch.any(wav_lengths > wavs.shape[1]).item():
        raise ValueError("wav_lengths cannot exceed the padded waveform length.")

    feats = []
    feat_lengths = []
    for i, wav_length in enumerate(wav_lengths.detach().cpu().tolist()):
        # Trim padding before centering so batched padding cannot affect valid audio.
        wav = wavs_float[i, :wav_length]
        wav = wav - wav.mean(dim=-1, keepdim=True)
        feat = fbank(
            wav.unsqueeze(0),
            htk_compat=True,
            sample_frequency=sample_rate,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=mel_dim,
            dither=0.0,
            frame_shift=10,
        )
        feat = (feat - norm_mean) / (norm_std * 2)
        feats.append(feat.to(dtype=feature_dtype))
        feat_lengths.append(feat.shape[0])

    mels = pad_sequence(feats, batch_first=True, padding_value=0.0)
    mel_lengths = torch.tensor(feat_lengths, dtype=torch.long, device=wavs.device)
    if return_lengths:
        return mels, mel_lengths
    return mels


class UsadModel(nn.Module):
    """Waveform-to-representation model built around a ``ConformerEncoder``.

    Waveforms are converted to fbank features with ``wav_to_fbank`` and passed
    through the encoder. Inputs longer than ``max_mel_length`` frames are
    encoded chunk by chunk and the outputs concatenated along time.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.encoder = ConformerEncoder(cfg)
        # Maximum number of mel frames encoded in a single encoder call.
        self.max_mel_length = MAX_MEL_LENGTH

    @property
    def sample_rate(self) -> int:
        """Expected input sample rate."""
        return 16000  # Hz

    @property
    def encoder_frame_rate(self) -> int:
        """Encoder output frame rate: 100 Hz mel rate over the subsample rate."""
        return round(100 / self.cfg.conv_subsample_rate)  # Hz

    @property
    def mel_dim(self) -> int:
        """Number of mel bins, read from ``cfg.input_dim``."""
        return self.cfg.input_dim

    @property
    def encoder_dim(self) -> int:
        """Encoder dimension, read from ``cfg.encoder_dim``."""
        return self.cfg.encoder_dim

    @property
    def num_layers(self) -> int:
        """Number of encoder layers, read from ``cfg.num_layers``."""
        return self.cfg.num_layers

    @property
    def device(self) -> torch.device:
        """Device of the first model parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first model parameter."""
        return next(self.parameters()).dtype

    def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
        """Set the maximum chunk size for feature extraction.

        Args:
            seconds (float, optional): Chunk size in seconds. Must be at least
                0.1. Defaults to 30.0.
        """
        assert (
            seconds >= 0.1
        ), f"Chunk size must be greater than 0.1s, got {seconds} seconds."
        # Round rather than truncate: a bare int() turns e.g. 0.29 * 100
        # (28.999999999999996) into 28 frames instead of 29.
        self.max_mel_length = int(round(seconds * 100))  # 100 Hz frame rate

    def load_audio(self, audio_path: str, move_to_device: bool = True) -> torch.Tensor:
        """Load audio file and return waveform tensor.

        The audio is resampled to ``self.sample_rate`` if needed and
        multi-channel audio is averaged down to a single channel.

        Args:
            audio_path (str): Path to the audio file.
            move_to_device (bool, optional): Move the waveform to the model's
                device before returning. Defaults to True.

        Returns:
            torch.Tensor: Waveform tensor of shape (wav_len,).
        """
        waveform, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        if waveform.shape[0] > 1:
            # If stereo, convert to mono by averaging channels
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0)  # Remove channel dimension if mono
        if move_to_device:
            return waveform.to(self.device)  # Ensure tensor is on the same device
        return waveform

    def load_audio_batch(
        self, audio_paths: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load several audio files into one zero-padded batch.

        Args:
            audio_paths (List[str]): Paths of the audio files to load.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (B, T_wav) padded waveforms and
            (B,) waveform lengths before padding, both on the model's device.
        """
        wav_list = []
        wav_lengths = []
        for path in audio_paths:
            # Keep each waveform where it was loaded; only the padded batch is
            # moved to the model's device.
            wav = self.load_audio(path, move_to_device=False)
            wav_list.append(wav)
            wav_lengths.append(wav.shape[0])
        wavs = pad_sequence(wav_list, batch_first=True).to(self.device)
        wav_lengths = torch.tensor(wav_lengths, dtype=torch.long, device=self.device)
        return wavs, wav_lengths

    def forward(
        self,
        wavs: torch.Tensor,
        wav_lengths: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        target_layer: Optional[int] = None,
        norm_mean: float = -4.268,
        norm_std: float = 4.569,
    ) -> dict:
        """Extract fbank features from waveforms and encode them.

        Args:
            wavs (torch.Tensor): (B, T_wav) waveform tensor.
            wav_lengths (torch.Tensor, optional): (B,) lengths of each waveform.
                Defaults to None.
            padding_mask (torch.Tensor, optional): (B, T_wav) padding mask for
                the waveforms. If wav_lengths is not provided, this is used to
                infer valid lengths.
            target_layer (int, optional): If specified, only return the output
                of the target layer. Defaults to None (return all layers).
            norm_mean (float, optional): Mean for normalization.
                Defaults to -4.268.
            norm_std (float, optional): Std for normalization.
                Defaults to 4.569.

        Returns:
            dict: A dictionary containing the following keys:
                - "x": (B, T_out, encoder_dim) output of the encoder
                - "x_lengths": (B,) valid output lengths after encoder
                  subsampling
                - "x_padding_mask": (B, T_out) output padding mask, where
                  padding is True
                - "mel": (B, T_mel, mel_dim) input mel features
                - "mel_lengths": (B,) valid mel lengths before encoder
                  subsampling
                - "hidden_states": list of (B, T_out, encoder_dim) hidden
                  states of each layer
                - "ffn": list of (B, T_out, encoder_dim) output of the
                  feed-forward network of each layer

            When the mel features are longer than ``self.max_mel_length``,
            they are encoded in chunks of that size and the per-chunk outputs
            are concatenated along the time dimension.
        """
        # Check types
        assert isinstance(wavs, torch.Tensor), "wavs must be a torch.Tensor"
        assert wavs.dim() == 2, "wavs must be of shape (batch_size, seq_len)"
        if wav_lengths is not None:
            assert isinstance(
                wav_lengths, torch.Tensor
            ), "wav_lengths must be a torch.Tensor"
            assert wav_lengths.dim() == 1, "wav_lengths must be of shape (batch_size,)"
            assert (
                wav_lengths.shape[0] == wavs.shape[0]
            ), "wav_lengths must have the same batch size as wavs"
        if padding_mask is not None:
            assert isinstance(
                padding_mask, torch.Tensor
            ), "padding_mask must be a torch.Tensor"
            assert (
                padding_mask.dim() == 2
            ), "padding_mask must be of shape (batch_size, seq_len)"
            assert (
                padding_mask.shape[0] == wavs.shape[0]
            ), "padding_mask must have the same batch size as wavs"
            assert (
                padding_mask.shape[1] == wavs.shape[1]
            ), "padding_mask must have the same seq_len as wavs"
            if wav_lengths is None:
                # Valid length = number of non-padded positions per row.
                wav_lengths = (~padding_mask.to(torch.bool)).sum(dim=1)
        if target_layer is not None:
            assert isinstance(target_layer, int), "target_layer must be an int or None"
            assert (
                1 <= target_layer <= self.cfg.num_layers
            ), f"target_layer must be between 1 and {self.cfg.num_layers}"

        mel, mel_lengths = wav_to_fbank(
            wavs,
            wav_lengths=wav_lengths,
            mel_dim=self.mel_dim,
            norm_mean=norm_mean,
            norm_std=norm_std,
            sample_rate=self.sample_rate,
            return_lengths=True,
        )

        # Match the features to the model's parameter dtype.
        dtype = self.dtype
        if mel.dtype != dtype:
            mel = mel.to(dtype)

        # Number of per-layer outputs collected in the chunked path.
        num_layers = min(
            self.cfg.num_layers,
            target_layer if target_layer is not None else self.cfg.num_layers,
        )

        if mel.shape[1] <= self.max_mel_length:
            # If the mel length is less than or equal to max_mel_length, we can
            # process it in one go
            x, x_len, layer_results = self.encoder(
                inputs=mel,
                input_lengths=mel_lengths,
                return_hidden=True,
                target_layer=target_layer,
            )
            result = {
                "x": x,
                "x_lengths": x_len,
                "x_padding_mask": lengths_to_padding_mask(x_len, max_len=x.size(1)),
                "mel": mel,
                "mel_lengths": mel_lengths,
                "hidden_states": layer_results["hidden_states"],
                "ffn": layer_results["ffn_1"],
            }
            return result

        # If the mel length is greater than max_mel_length, we need to process
        # it in chunks
        result = {
            "x": [],
            "x_lengths": [],
            "mel": mel,
            "mel_lengths": mel_lengths,
            "hidden_states": [[] for _ in range(num_layers)],
            "ffn": [[] for _ in range(num_layers)],
        }
        for i in range(0, mel.shape[1], self.max_mel_length):
            # A trailing remainder shorter than 10 frames is dropped
            # (presumably too short for the encoder -- TODO confirm).
            if mel.shape[1] - i < 10:
                break
            _mel = mel[:, i : i + self.max_mel_length]
            # Valid length of each item inside this chunk; 0 once an item has
            # already ended before the chunk starts.
            _mel_lengths = torch.clamp(
                mel_lengths - i, min=0, max=self.max_mel_length
            )
            x, x_len, layer_results = self.encoder(
                inputs=_mel,
                input_lengths=_mel_lengths,
                return_hidden=True,
                target_layer=target_layer,
            )
            result["x"].append(x)
            result["x_lengths"].append(x_len)
            for j in range(num_layers):
                result["hidden_states"][j].append(layer_results["hidden_states"][j])
                result["ffn"][j].append(layer_results["ffn_1"][j])

        # Stitch the chunk outputs back together along time; per-item output
        # lengths are the sum of the per-chunk lengths.
        result["x"] = torch.cat(result["x"], dim=1)
        result["x_lengths"] = torch.stack(result["x_lengths"], dim=0).sum(dim=0)
        result["x_padding_mask"] = lengths_to_padding_mask(
            result["x_lengths"], max_len=result["x"].size(1)
        )
        for j in range(num_layers):
            result["hidden_states"][j] = torch.cat(
                result["hidden_states"][j], dim=1
            )
            result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)
        return result

    @classmethod
    def load_from_fairseq_ckpt(cls, ckpt_path: str) -> "UsadModel":
        """Build a model from a fairseq checkpoint file.

        The model config is read from ``checkpoint["cfg"]["model"]`` and only
        the ``encoder.*`` weights of ``checkpoint["model"]`` are loaded.

        Args:
            ckpt_path (str): Path to the checkpoint file.

        Returns:
            UsadModel: Model with the checkpoint's encoder weights loaded.
        """
        # NOTE: weights_only=False unpickles arbitrary Python objects, so only
        # load checkpoints from a trusted source.
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # CPU-only hosts; the freshly built model lives on CPU either way.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        config = checkpoint["cfg"]["model"]
        # Expose the config mapping as attributes (cfg.num_layers, ...).
        config = make_dataclass("Config", config.keys())(**config)
        model = cls(config)
        state_dict = checkpoint["model"]
        # Drop everything that is not an encoder weight. Deleting in place
        # keeps the original mapping object that load_state_dict receives.
        for k in list(state_dict.keys()):
            if not k.startswith("encoder."):
                del state_dict[k]
        model.load_state_dict(state_dict, strict=True)
        return model