Feature Extraction
Transformers
Safetensors
English
usad2
automatic-speech-recognition
audio-classification
audio
speech
music
custom_code
Instructions to use MIT-SLS/USAD2-XXLarge-Plus with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use MIT-SLS/USAD2-XXLarge-Plus with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="MIT-SLS/USAD2-XXLarge-Plus", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("MIT-SLS/USAD2-XXLarge-Plus", trust_remote_code=True, dtype="auto")
```

- Notebooks
- Google Colab
- Kaggle
from dataclasses import make_dataclass
from typing import List, Optional, Tuple, Union

import torch
import torchaudio
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torchaudio.compliance.kaldi import fbank

from .usad_modules import ConformerEncoder, lengths_to_padding_mask

# Default maximum number of mel frames encoded in a single pass.
# Features use a 10 ms frame shift (100 frames per second), so 3000 frames
# correspond to 30 seconds of audio.
MAX_MEL_LENGTH = 3000  # 30 seconds
def wav_to_fbank(
    wavs: torch.Tensor,
    mel_dim: int = 128,
    norm_mean: float = -4.268,
    norm_std: float = 4.569,
    wav_lengths: Optional[torch.Tensor] = None,
    sample_rate: int = 16000,
    return_lengths: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Convert a batch of waveforms to normalized fbank features.

    Each waveform is trimmed to its valid length, mean-centered, converted to
    Kaldi-compatible filterbank features with a 10 ms frame shift, and
    normalized as ``(feat - norm_mean) / (2 * norm_std)``. The per-utterance
    features are then zero-padded to a common length.

    Args:
        wavs (torch.Tensor): (B, T_wav) waveform tensor.
        mel_dim (int, optional): mel dimension. Defaults to 128.
        norm_mean (float, optional): mean for normalization. Defaults to -4.268.
        norm_std (float, optional): std for normalization. Defaults to 4.569.
        wav_lengths (torch.Tensor, optional): (B,) valid waveform lengths
            before padding. When omitted, every waveform is treated as fully
            valid (length ``T_wav``).
        sample_rate (int, optional): waveform sample rate. Defaults to 16000.
        return_lengths (bool, optional): return exact fbank lengths.
            Defaults to False.

    Returns:
        torch.Tensor: (B, T_mel, mel_dim) fbank features. If return_lengths is
            True, also returns a (B,) tensor with exact feature lengths before
            padding.

    Raises:
        ValueError: If ``wavs`` is not 2-D, or if ``wav_lengths`` is not a 1-D
            tensor of batch size, contains a non-positive value, or exceeds
            the padded waveform length.
    """
    # ref: https://github.com/cwx-worst-one/EAT/tree/main/feature_extract
    if wavs.dim() != 2:
        raise ValueError("wavs must be a 2-D tensor of shape (batch_size, seq_len).")

    # Features are computed in float32 and cast back to the input's floating
    # dtype (or float32 for integer input) at the end.
    feature_dtype = wavs.dtype if wavs.is_floating_point() else torch.float32
    wavs_float = wavs.to(torch.float32)
    batch_size, max_wav_length = wavs.shape

    if wav_lengths is None:
        lengths = [max_wav_length] * batch_size
    else:
        if wav_lengths.dim() != 1 or wav_lengths.shape[0] != batch_size:
            raise ValueError("wav_lengths must be a 1-D tensor with batch size elements.")
        # Move the lengths to the host once; they are needed as Python ints
        # for slicing anyway, so validate them there as well.
        lengths = wav_lengths.detach().to(dtype=torch.long).cpu().tolist()
        if any(length <= 0 for length in lengths):
            raise ValueError("All wav_lengths values must be positive.")
        if any(length > max_wav_length for length in lengths):
            raise ValueError("wav_lengths cannot exceed the padded waveform length.")

    feats = []
    feat_lengths = []
    for i, wav_length in enumerate(lengths):
        # Trim padding before centering so batched padding cannot affect valid audio.
        wav = wavs_float[i, :wav_length]
        wav = wav - wav.mean(dim=-1, keepdim=True)
        feat = fbank(
            wav.unsqueeze(0),
            htk_compat=True,
            sample_frequency=sample_rate,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=mel_dim,
            dither=0.0,
            frame_shift=10,
        )
        feat = (feat - norm_mean) / (norm_std * 2)
        feats.append(feat.to(dtype=feature_dtype))
        feat_lengths.append(feat.shape[0])

    mels = pad_sequence(feats, batch_first=True, padding_value=0.0)
    if not return_lengths:
        return mels
    mel_lengths = torch.tensor(feat_lengths, dtype=torch.long, device=wavs.device)
    return mels, mel_lengths
class UsadModel(nn.Module):
    """Waveform-to-representation model built around a ``ConformerEncoder``.

    Waveforms are converted to fbank features with ``wav_to_fbank`` and then
    encoded. Inputs whose feature length exceeds ``max_mel_length`` frames are
    encoded chunk by chunk and the chunk outputs are concatenated along time.
    """

    def __init__(self, cfg):
        """Build the encoder from ``cfg``.

        Args:
            cfg: Configuration object. This class reads ``conv_subsample_rate``,
                ``input_dim``, ``encoder_dim`` and ``num_layers`` from it and
                passes the whole object to ``ConformerEncoder``.
        """
        super().__init__()
        self.cfg = cfg
        self.encoder = ConformerEncoder(cfg)
        # Maximum number of mel frames encoded in one pass; see
        # ``set_audio_chunk_size`` to change it.
        self.max_mel_length = MAX_MEL_LENGTH

    # NOTE: the accessors below are properties because the rest of the class
    # reads them as attributes (e.g. ``self.sample_rate``, ``self.device``).

    @property
    def sample_rate(self) -> int:
        """Expected waveform sample rate in Hz."""
        return 16000  # Hz

    @property
    def encoder_frame_rate(self) -> int:
        """Encoder output frame rate in Hz (100 Hz features / subsample rate)."""
        return round(100 / self.cfg.conv_subsample_rate)  # Hz

    @property
    def mel_dim(self) -> int:
        """Number of mel bins fed to the encoder (``cfg.input_dim``)."""
        return self.cfg.input_dim

    @property
    def encoder_dim(self) -> int:
        """Encoder hidden size (``cfg.encoder_dim``)."""
        return self.cfg.encoder_dim

    @property
    def num_layers(self) -> int:
        """Number of encoder layers (``cfg.num_layers``)."""
        return self.cfg.num_layers

    @property
    def device(self) -> torch.device:
        """Device of the first model parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first model parameter."""
        return next(self.parameters()).dtype

    def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
        """Set the maximum chunk size for feature extraction.

        Args:
            seconds (float, optional): Chunk size in seconds. Must be at least
                0.1. Defaults to 30.0.
        """
        assert (
            seconds >= 0.1
        ), f"Chunk size must be at least 0.1s, got {seconds} seconds."
        self.max_mel_length = int(seconds * 100)  # 100 Hz frame rate

    def load_audio(self, audio_path: str, move_to_device: bool = True) -> torch.Tensor:
        """Load audio file and return waveform tensor.

        The audio is resampled to ``sample_rate`` when needed and multi-channel
        audio is averaged down to a single channel.

        Args:
            audio_path (str): Path to the audio file.
            move_to_device (bool, optional): Move the waveform to the model's
                device before returning it. Defaults to True.

        Returns:
            torch.Tensor: Waveform tensor of shape (wav_len,).
        """
        waveform, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        if waveform.shape[0] > 1:
            # If stereo, convert to mono by averaging channels
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0)  # Remove channel dimension if mono
        if move_to_device:
            return waveform.to(self.device)  # Ensure tensor is on the same device
        return waveform

    def load_audio_batch(
        self, audio_paths: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load several audio files into one zero-padded batch.

        Args:
            audio_paths (List[str]): Paths to the audio files.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The padded (B, T_wav) waveform
                batch and the (B,) unpadded waveform lengths, both on the
                model's device.
        """
        # Keep waveforms on the host until they are padded, then move once.
        wav_list = [self.load_audio(path, move_to_device=False) for path in audio_paths]
        wav_lengths = torch.tensor(
            [wav.shape[0] for wav in wav_list], dtype=torch.long, device=self.device
        )
        wavs = pad_sequence(wav_list, batch_first=True).to(self.device)
        return wavs, wav_lengths

    def forward(
        self,
        wavs: torch.Tensor,
        wav_lengths: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        target_layer: Optional[int] = None,
        norm_mean: float = -4.268,
        norm_std: float = 4.569,
    ) -> dict:
        """Encode a batch of waveforms.

        Args:
            wavs (torch.Tensor): (B, T_wav) waveform tensor.
            wav_lengths (torch.Tensor, optional): (B,) lengths of each waveform. Defaults to None.
            padding_mask (torch.Tensor, optional): (B, T_wav) padding mask for the waveforms.
                If wav_lengths is not provided, this is used to infer valid lengths.
            target_layer (int, optional): If specified, only return the output of the target layer. Defaults to None (return all layers).
            norm_mean (float, optional): Mean for normalization. Defaults to -4.268.
            norm_std (float, optional): Std for normalization. Defaults to 4.569.

        Returns:
            dict: A dictionary containing the following keys:
                - "x": (B, T_out, encoder_dim) output of the encoder
                - "x_lengths": (B,) valid output lengths after encoder subsampling
                - "x_padding_mask": (B, T_out) output padding mask, where padding is True
                - "mel": (B, T_mel, mel_dim) input mel features
                - "mel_lengths": (B,) valid mel lengths before encoder subsampling
                - "hidden_states": list of (B, T_out, encoder_dim) hidden states of each layer
                - "ffn": list of (B, T_out, encoder_dim) output of the feed-forward network of each layer
        """
        # Check types
        assert isinstance(wavs, torch.Tensor), "wavs must be a torch.Tensor"
        assert wavs.dim() == 2, "wavs must be of shape (batch_size, seq_len)"
        if wav_lengths is not None:
            assert isinstance(
                wav_lengths, torch.Tensor
            ), "wav_lengths must be a torch.Tensor"
            assert wav_lengths.dim() == 1, "wav_lengths must be of shape (batch_size,)"
            assert (
                wav_lengths.shape[0] == wavs.shape[0]
            ), "wav_lengths must have the same batch size as wavs"
        if padding_mask is not None:
            assert isinstance(
                padding_mask, torch.Tensor
            ), "padding_mask must be a torch.Tensor"
            assert (
                padding_mask.dim() == 2
            ), "padding_mask must be of shape (batch_size, seq_len)"
            assert (
                padding_mask.shape[0] == wavs.shape[0]
            ), "padding_mask must have the same batch size as wavs"
            assert (
                padding_mask.shape[1] == wavs.shape[1]
            ), "padding_mask must have the same seq_len as wavs"
            if wav_lengths is None:
                # Explicit lengths take precedence; the mask is only a fallback.
                wav_lengths = (~padding_mask.to(torch.bool)).sum(dim=1)
        if target_layer is not None:
            assert isinstance(target_layer, int), "target_layer must be an int or None"
            assert (
                1 <= target_layer <= self.cfg.num_layers
            ), f"target_layer must be between 1 and {self.cfg.num_layers}"

        mel, mel_lengths = wav_to_fbank(
            wavs,
            wav_lengths=wav_lengths,
            mel_dim=self.mel_dim,
            norm_mean=norm_mean,
            norm_std=norm_std,
            sample_rate=self.sample_rate,
            return_lengths=True,
        )
        # Match the features to the model's parameter dtype.
        dtype = self.dtype
        if mel.dtype != dtype:
            mel = mel.to(dtype)

        # Number of per-layer outputs collected in the chunked path.
        num_layers = min(
            self.cfg.num_layers,
            target_layer if target_layer is not None else self.cfg.num_layers,
        )

        if mel.shape[1] <= self.max_mel_length:
            # If the mel length is less than or equal to max_mel_length, we can process it in one go
            x, x_len, layer_results = self.encoder(
                inputs=mel,
                input_lengths=mel_lengths,
                return_hidden=True,
                target_layer=target_layer,
            )
            return {
                "x": x,
                "x_lengths": x_len,
                "x_padding_mask": lengths_to_padding_mask(x_len, max_len=x.size(1)),
                "mel": mel,
                "mel_lengths": mel_lengths,
                "hidden_states": layer_results["hidden_states"],
                "ffn": layer_results["ffn_1"],
            }

        # If the mel length is greater than max_mel_length, we need to process it in chunks
        x_chunks = []
        x_length_chunks = []
        hidden_chunks = [[] for _ in range(num_layers)]
        ffn_chunks = [[] for _ in range(num_layers)]
        for start in range(0, mel.shape[1], self.max_mel_length):
            # A trailing remainder shorter than 10 frames is dropped.
            if mel.shape[1] - start < 10:
                break
            chunk = mel[:, start : start + self.max_mel_length]
            # Valid length of every utterance inside this chunk; utterances
            # that ended in an earlier chunk get length 0 here.
            chunk_lengths = torch.clamp(
                mel_lengths - start, min=0, max=self.max_mel_length
            )
            x, x_len, layer_results = self.encoder(
                inputs=chunk,
                input_lengths=chunk_lengths,
                return_hidden=True,
                target_layer=target_layer,
            )
            x_chunks.append(x)
            x_length_chunks.append(x_len)
            for j in range(num_layers):
                hidden_chunks[j].append(layer_results["hidden_states"][j])
                ffn_chunks[j].append(layer_results["ffn_1"][j])

        # Stitch the chunk outputs back together along the time axis; the
        # total valid length is the sum of the per-chunk valid lengths.
        x_all = torch.cat(x_chunks, dim=1)
        x_lengths = torch.stack(x_length_chunks, dim=0).sum(dim=0)
        return {
            "x": x_all,
            "x_lengths": x_lengths,
            "x_padding_mask": lengths_to_padding_mask(
                x_lengths, max_len=x_all.size(1)
            ),
            "mel": mel,
            "mel_lengths": mel_lengths,
            "hidden_states": [torch.cat(chunks, dim=1) for chunks in hidden_chunks],
            "ffn": [torch.cat(chunks, dim=1) for chunks in ffn_chunks],
        }

    @classmethod
    def load_from_fairseq_ckpt(cls, ckpt_path: str):
        """Build a model from a fairseq-style checkpoint.

        The checkpoint is expected to hold the model configuration under
        ``["cfg"]["model"]`` and the weights under ``["model"]``. Only weights
        whose key starts with ``"encoder."`` are loaded.

        Args:
            ckpt_path (str): Path to the checkpoint file.

        Returns:
            The constructed model with the encoder weights loaded.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # only load checkpoints from a trusted source.
        # map_location="cpu" lets checkpoints saved on GPU load on CPU-only
        # hosts; the freshly built model lives on CPU at this point anyway.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        config = checkpoint["cfg"]["model"]
        # Expose the config mapping through attribute access (cfg.num_layers, ...).
        config = make_dataclass("Config", config.keys())(**config)
        model = cls(config)
        state_dict = {
            key: value
            for key, value in checkpoint["model"].items()
            if key.startswith("encoder.")
        }
        model.load_state_dict(state_dict, strict=True)
        return model