Commit 17e49ef (verified), committed by vectominist · 1 Parent(s): 59a898a

Add USAD2 model

Files changed (8)
  1. README.md +166 -0
  2. __init__.py +0 -0
  3. config.json +42 -0
  4. configuration_usad2.py +72 -0
  5. model.safetensors +3 -0
  6. modeling_usad2.py +59 -0
  7. usad_model.py +325 -0
  8. usad_modules.py +1027 -0
README.md ADDED
@@ -0,0 +1,166 @@
+ ---
+ license: cc-by-nc-sa-4.0
+ pipeline_tag: feature-extraction
+ tags:
+ - automatic-speech-recognition
+ - audio-classification
+ - audio
+ - speech
+ - music
+ library_name: transformers
+ datasets:
+ - openslr/librispeech_asr
+ - facebook/multilingual_librispeech
+ - mozilla-foundation/common_voice_17_0
+ - speechcolab/gigaspeech
+ - facebook/voxpopuli
+ - espnet/mms_ulab_v2
+ - google/fleurs
+ - AISHELL/AISHELL-1
+ - kresnik/zeroth_korean
+ - ylacombe/expresso
+ - agkphysics/AudioSet
+ - 11hu83/vggsound
+ - benjamin-paine/free-music-archive-full
+ - rkstgr/mtg-jamendo
+ language:
+ - en
+ ---
+ # USAD 2.0: Scaling Representation Distillation for Universal Audio Understanding
+
+ **USAD 2.0** is a bidirectional transformer-based universal audio encoder that extracts useful representations across multiple audio domains (speech/sound/music) by distilling from SSL/supervised audio foundation models without labeled data. USAD 2.0 achieves strong or state-of-the-art performance across probing ([HEAR](https://arxiv.org/abs/2203.03022) and [MARBLE](https://arxiv.org/abs/2306.10548)) and LLM-based evaluations ([XARES-LLM](https://arxiv.org/abs/2603.22728)).
+
+ Training data:
+ * Multilingual speech (116k hours)
+ * General audio and sound (21k hours)
+ * Music (13k hours)
+
+
+ [👀 **Read Full Paper**](https://arxiv.org/abs/2606.06444)
+
+ ---
+
+ ## 🗂️ Models
+
+ ### Self-supervised Teachers (WavLM, ATST, MuQ): General-purpose encoders with good probing performance
+
+ | Model | Params | Hidden | Layers | Framerate |
+ |:----------------------------------------------------- | ------:| ------:| ------:| ---------:|
+ | [USAD 2.0 Small](https://hf.co/MIT-SLS/USAD2-Small) | 25M | 384 | 12 | 50Hz |
+ | [USAD 2.0 Base](https://hf.co/MIT-SLS/USAD2-Base) | 97M | 768 | 12 | 50Hz |
+ | [USAD 2.0 Large](https://hf.co/MIT-SLS/USAD2-Large) | 336M | 1024 | 24 | 50Hz |
+ | [USAD 2.0 XLarge](https://hf.co/MIT-SLS/USAD2-XLarge) | 695M | 1280 | 32 | 25Hz |
+
+ ### Supervised Teachers (Whisper & Audio Flamingo 3): State-of-the-art encoders for audio LLM frontend
+ We suggest selecting the best layer with the `target_layer` argument in the forward function to optimize audio LLM performance.
+
+ | Model | Params | Hidden | Layers (Best) | Framerate |
+ |:------------------------------------------------------------- | ------:| ------:| -------------:| ---------:|
+ | [USAD 2.0 Large+](https://hf.co/MIT-SLS/USAD2-Large-Plus) | 336M | 1024 | 24 (20) | 50Hz |
+ | [USAD 2.0 XLarge+](https://hf.co/MIT-SLS/USAD2-XLarge-Plus) | 695M | 1280 | 32 (28) | 25Hz |
+ | [USAD 2.0 XXLarge+](https://hf.co/MIT-SLS/USAD2-XXLarge-Plus) | 1036M | 1280 | 48 (40) | 25Hz |
+
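The "Layers (Best)" column above can be kept in a small lookup so that the suggested `target_layer` travels with the checkpoint name. This is a minimal sketch, assuming the repository IDs implied by the table links; `BEST_LAYER` and `suggested_target_layer` are hypothetical helper names and are not part of the model code.

```python
from typing import Optional

# Values copied from the "Layers (Best)" column of the table above.
BEST_LAYER = {
    "MIT-SLS/USAD2-Large-Plus": 20,
    "MIT-SLS/USAD2-XLarge-Plus": 28,
    "MIT-SLS/USAD2-XXLarge-Plus": 40,
}


def suggested_target_layer(model_id: str) -> Optional[int]:
    """Return the suggested layer, or None to fall back to the last layer."""
    return BEST_LAYER.get(model_id)


print(suggested_target_layer("MIT-SLS/USAD2-XLarge-Plus"))
```

The returned value could then be passed as `target_layer=` in the forward call; `None` keeps the README's default of using the last layer.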
+ ---
+
+ ## ⚙️ Performance
+ - [HEAR](https://arxiv.org/abs/2203.03022): probing-based general audio evaluation covering speech, sound, and music
+ - [MARBLE](https://arxiv.org/abs/2306.10548): probing-based music capability benchmark (instruments and singing voice)
+ - [XARES-LLM](https://github.com/xiaomi-research/xares-llm): frozen audio encoder + LLM with multi-task LoRA fine-tuning
+   - Track A (classification): keyword spotting, speaker/language identification, spoof detection, intent/emotion/sound/genre/instrument classification, and sound event detection
+   - Track B (understanding): English/Mandarin ASR and audio/music captioning
+
+ | Encoder | Params | HEAR | MARBLE | XARES-LLM-A | XARES-LLM-B |
+ | :---------------------- | ------:| --------:| --------:| -----------:| -----------:|
+ | **Single-encoder SOTA** | | | | | |
+ |   Base | ~90M | 80.6 | 74.0 | 0.660 | 0.418 |
+ |   Large | ~300M | 81.8 | **77.0** | 0.691 | 0.454 |
+ |   XLarge | ~600M | 82.6 | 75.1 | 0.782 | 0.457 |
+ | **USAD 2.0** | | | | | |
+ |   Small | 25M | 81.0 | 72.9 | 0.604 | 0.357 |
+ |   Base | 97M | 81.9 | 74.1 | 0.645 | 0.442 |
+ |   Large | 336M | 82.9 | 75.8 | 0.667 | 0.473 |
+ |   XLarge | 695M | 82.5 | 75.7 | 0.708 | 0.485 |
+ | **USAD 2.0+** | | | | | |
+ |   Large+ | 336M | 84.0 | 75.1 | 0.769 | 0.580 |
+ |   XLarge+ | 695M | **84.4** | 75.0 | 0.772 | 0.611 |
+ |   XXLarge+ | 1036M | **84.4** | 75.6 | **0.783** | **0.624** |
+
+ * The above evaluations are based on *frozen* encoders.
+ * We encourage fine-tuning USAD 2.0 models for optimal downstream task performance.
+
+ ---
+
+ ## 🚀 How To Use
+
+ **Installation**
+ ```
+ pip install -U torch torchaudio transformers
+ ```
+
+ **Load Model and Extract Features**
+ ```python
+ import torch
+ from transformers import AutoModel
+
+ # Load pre-trained model
+ model = AutoModel.from_pretrained(
+     "MIT-SLS/USAD2-XXLarge-Plus", trust_remote_code=True
+ ).cuda().eval()
+
+ # Model properties
+ model.sample_rate  # required audio sample rate
+ model.encoder_frame_rate  # frames per second (Hz)
+ model.mel_dim  # mel feature dimension
+ model.encoder_dim  # hidden dimension
+ model.num_layers  # number of encoder layers
+ model.device  # device
+ model.dtype  # dtype
+
+ # Model methods
+ model.set_audio_chunk_size(30.0)  # audio longer than this is processed in chunks (default: 30 s)
+
+ # Load audio and resample to 16kHz
+ wavs, wav_lengths = model.load_audio_batch(["audio1.wav", "audio2.wav"])
+ # wavs: raw waveforms (batch_size, max_wav_len)
+ # wav_lengths: length of each sample (batch_size, )
+ # You can also load waveforms directly with torchaudio.load
+
+ # Extract features
+ with torch.no_grad():
+     results = model(
+         wavs=wavs,
+         wav_lengths=wav_lengths,
+         target_layer=None,  # None for the last layer, or an integer from 1 to model.num_layers
+     )
+
+ # results["x"]: model final output (batch_size, seq_len, encoder_dim)
+ # results["x_lengths"]: valid output lengths after encoder subsampling
+ # results["x_padding_mask"]: output padding mask, where padding is True
+ # results["mel"]: mel fbank (batch_size, mel_len, mel_dim)
+ # results["mel_lengths"]: valid mel lengths before encoder subsampling
+ # results["hidden_states"]: list of (batch_size, seq_len, encoder_dim)
+ # results["ffn"]: list of (batch_size, seq_len, encoder_dim)
+ ```
+
+ * The self-attention mechanism is implemented with [SDPA](https://pytorch.org/blog/out-of-the-box-acceleration/); you may install FlashAttention to improve inference efficiency.
+ * `bfloat16` is preferred for fast inference.
+ * Avoid `float16`, which can be numerically unstable.
+
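Because the padding mask marks padded frames with `True`, any utterance-level pooling over the output `"x"` should skip those frames. The following is a minimal pure-Python sketch of that idea, with nested lists standing in for tensors; `padding_mask_from_lengths` and `masked_mean` are hypothetical helpers written for illustration and are not functions of the model.

```python
from typing import List


def padding_mask_from_lengths(lengths: List[int], max_len: int) -> List[List[bool]]:
    """True marks padded positions, matching the "x_padding_mask" convention."""
    return [[t >= n for t in range(max_len)] for n in lengths]


def masked_mean(frames: List[List[float]], mask: List[bool]) -> List[float]:
    """Average only the frames whose mask entry is False (valid frames)."""
    valid = [f for f, is_pad in zip(frames, mask) if not is_pad]
    if not valid:
        raise ValueError("no valid frames to pool")
    return [sum(col) / len(valid) for col in zip(*valid)]


# Two utterances padded to 3 frames, feature dimension 2.
x = [
    [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]],  # 2 valid frames, 1 padded
    [[2.0, 2.0], [4.0, 4.0], [6.0, 6.0]],  # 3 valid frames
]
mask = padding_mask_from_lengths([2, 3], max_len=3)
pooled = [masked_mean(frames, m) for frames, m in zip(x, mask)]
print(pooled)
```

With real tensors the same effect is usually obtained by zeroing the padded frames and dividing the time-axis sum by the valid lengths.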
+ ---
+
+ ## 📖 Citation
+
+ ```bibtex
+ @inproceedings{chang2026usad2,
+   title={{USAD 2.0}: Scaling Representation Distillation for Universal Audio Understanding},
+   author={Chang, Heng-Jui and Liu, Alexander H. and Bhati, Saurabhchand and Athi, Mrudula and Ratnarajah, Anton and Chhetri, Amit and Glass, James},
+   booktitle={Interspeech},
+   year={2026}
+ }
+ ```
+
+ ---
+
+ ## 🙏 Acknowledgement
+
+ Our implementation is based on the awesome [facebookresearch/fairseq](https://github.com/facebookresearch/fairseq), [cwx-worst-one/EAT](https://github.com/cwx-worst-one/EAT), and [sooftware/conformer](https://github.com/sooftware/conformer) repositories.
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,42 @@
+ {
+   "architectures": [
+     "Usad2Model"
+   ],
+   "attention_dropout_p": 0.0,
+   "attention_type": "mhsa",
+   "auto_map": {
+     "AutoConfig": "configuration_usad2.Usad2Config",
+     "AutoModel": "modeling_usad2.Usad2Model"
+   },
+   "conv_dropout_p": 0.0,
+   "conv_expansion_factor": 2,
+   "conv_kernel_size": 31,
+   "conv_pos": true,
+   "conv_pos_depth": 5,
+   "conv_pos_groups": 16,
+   "conv_pos_width": 95,
+   "conv_subsample_channels": 64,
+   "conv_subsample_rate": 4,
+   "encoder_dim": 1280,
+   "feed_forward_dropout_p": 0.0,
+   "feed_forward_expansion_factor": 4,
+   "half_step_residual": true,
+   "input_dim": 128,
+   "input_dropout_p": 0.0,
+   "layerdrop_p": 0.0,
+   "model_type": "usad2",
+   "num_attention_heads": 20,
+   "num_layers": 48,
+   "patch_size_freq": 16,
+   "patch_size_time": 16,
+   "pre_norm": true,
+   "rms_norm": false,
+   "sample_rate": 16000,
+   "subsample_normalization": true,
+   "torch_dtype": "float32",
+   "transformer_style": true,
+   "transformers_version": "4.49.0",
+   "usad_v2": true,
+   "use_framewise_subsample": true,
+   "use_patchwise_subsample": false
+ }
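A few quantities quoted in the README can be derived from this config. The sketch below is illustrative only: it copies five fields from the JSON above, assumes the 100 Hz mel frame rate implied by the 10 ms `frame_shift` in `usad_model.py`, and assumes the usual multi-head attention split of `encoder_dim` across heads. The variable names are ours, not the model's.

```python
import json

# Subset of the fields shown in config.json above.
config = json.loads("""
{
  "conv_subsample_rate": 4,
  "encoder_dim": 1280,
  "feed_forward_expansion_factor": 4,
  "num_attention_heads": 20,
  "num_layers": 48
}
""")

MEL_FRAME_RATE_HZ = 100  # assumption: 10 ms frame shift for the fbank features

# Same formula as the encoder_frame_rate property in modeling_usad2.py.
frame_rate_hz = round(MEL_FRAME_RATE_HZ / config["conv_subsample_rate"])
# Assumption: standard multi-head attention, heads split encoder_dim evenly.
head_dim = config["encoder_dim"] // config["num_attention_heads"]
# Inner width of the feed-forward module (encoder_dim * expansion factor).
ffn_dim = config["encoder_dim"] * config["feed_forward_expansion_factor"]
# Output frames for one default 30-second chunk.
frames_per_30s = 30 * frame_rate_hz

print(frame_rate_hz, head_dim, ffn_dim, frames_per_30s)
```

The resulting 25 Hz frame rate agrees with the Framerate column listed for the XXLarge+ model in the README.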
configuration_usad2.py ADDED
@@ -0,0 +1,72 @@
+ from transformers import PretrainedConfig
+
+
+ class Usad2Config(PretrainedConfig):
+     model_type = "usad2"
+
+     def __init__(
+         self,
+         input_dim: int = 128,
+         use_framewise_subsample: bool = True,
+         conv_subsample_channels: int = 64,
+         conv_subsample_rate: int = 2,
+         use_patchwise_subsample: bool = False,
+         patch_size_time: int = 16,
+         patch_size_freq: int = 16,
+         subsample_normalization: bool = True,
+         conv_pos: bool = True,
+         conv_pos_depth: int = 5,
+         conv_pos_width: int = 95,
+         conv_pos_groups: int = 16,
+         encoder_dim: int = 384,
+         num_layers: int = 12,
+         attention_type="mhsa",
+         num_attention_heads: int = 8,
+         feed_forward_expansion_factor: int = 4,
+         conv_expansion_factor: int = 2,
+         input_dropout_p: float = 0.0,
+         feed_forward_dropout_p: float = 0.0,
+         attention_dropout_p: float = 0.0,
+         conv_dropout_p: float = 0.0,
+         conv_kernel_size: int = 31,
+         half_step_residual: bool = True,
+         transformer_style: bool = True,
+         layerdrop_p: float = 0.0,
+         usad_v2: bool = True,
+         pre_norm: bool = False,
+         rms_norm: bool = False,
+         sample_rate: int = 16000,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.input_dim = input_dim
+         self.use_framewise_subsample = use_framewise_subsample
+         self.conv_subsample_channels = conv_subsample_channels
+         self.conv_subsample_rate = conv_subsample_rate
+         self.use_patchwise_subsample = use_patchwise_subsample
+         self.patch_size_time = patch_size_time
+         self.patch_size_freq = patch_size_freq
+         self.subsample_normalization = subsample_normalization
+         self.conv_pos = conv_pos
+         self.conv_pos_depth = conv_pos_depth
+         self.conv_pos_width = conv_pos_width
+         self.conv_pos_groups = conv_pos_groups
+         self.encoder_dim = encoder_dim
+         self.num_layers = num_layers
+         self.attention_type = attention_type
+         self.num_attention_heads = num_attention_heads
+         self.feed_forward_expansion_factor = feed_forward_expansion_factor
+         self.conv_expansion_factor = conv_expansion_factor
+         self.input_dropout_p = input_dropout_p
+         self.feed_forward_dropout_p = feed_forward_dropout_p
+         self.attention_dropout_p = attention_dropout_p
+         self.conv_dropout_p = conv_dropout_p
+         self.conv_kernel_size = conv_kernel_size
+         self.half_step_residual = half_step_residual
+         self.transformer_style = transformer_style
+         self.layerdrop_p = layerdrop_p
+         self.usad_v2 = usad_v2
+         self.pre_norm = pre_norm
+         self.rms_norm = rms_norm
+         self.sample_rate = sample_rate
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4241082d91dde7444ce666613dc8c2ab0ad6d7f18267698a20604871a6f6a20b
+ size 4142517064
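The pointer's `size` field gives a quick sanity check on the parameter count. The arithmetic below is a rough estimate that assumes every stored value is a 4-byte `float32` (as `torch_dtype` in `config.json` suggests) and ignores the safetensors header and any non-parameter buffers, so the true count is slightly lower.

```python
SIZE_BYTES = 4_142_517_064  # "size" field of the LFS pointer above
BYTES_PER_PARAM = 4         # float32, per "torch_dtype" in config.json

# Upper-bound estimate: treat the whole file as float32 parameters.
approx_params = SIZE_BYTES / BYTES_PER_PARAM
approx_params_millions = approx_params / 1e6

print(f"about {approx_params_millions:.1f}M parameters")
```

The estimate lands close to the 1036M figure quoted for USAD 2.0 XXLarge+ in the README.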
modeling_usad2.py ADDED
@@ -0,0 +1,59 @@
+ from typing import List, Tuple
+
+ import torch
+ from transformers import PreTrainedModel
+
+ from .configuration_usad2 import Usad2Config
+ from .usad_model import UsadModel
+
+
+ class Usad2Model(PreTrainedModel):
+     config_class = Usad2Config
+     base_model_prefix = "model"
+     main_input_name = "wavs"
+
+     def __init__(self, config: Usad2Config):
+         super().__init__(config)
+         self.model = UsadModel(config)
+
+     def forward(self, *args, **kwargs):
+         return self.model(*args, **kwargs)
+
+     @property
+     def sample_rate(self) -> int:
+         return 16000  # Hz
+
+     @property
+     def encoder_frame_rate(self) -> int:
+         return round(100 / self.config.conv_subsample_rate)  # Hz
+
+     @property
+     def mel_dim(self) -> int:
+         return self.config.input_dim
+
+     @property
+     def encoder_dim(self) -> int:
+         return self.config.encoder_dim
+
+     @property
+     def num_layers(self) -> int:
+         return self.config.num_layers
+
+     @property
+     def device(self) -> torch.device:
+         return next(self.parameters()).device
+
+     @property
+     def dtype(self) -> torch.dtype:
+         return next(self.parameters()).dtype
+
+     def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
+         self.model.set_audio_chunk_size(seconds)
+
+     def load_audio(self, audio_path: str) -> torch.Tensor:
+         return self.model.load_audio(audio_path)
+
+     def load_audio_batch(
+         self, audio_paths: List[str]
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         return self.model.load_audio_batch(audio_paths)
usad_model.py ADDED
@@ -0,0 +1,325 @@
+ from dataclasses import make_dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torchaudio
+ from torch import nn
+ from torch.nn.utils.rnn import pad_sequence
+ from torchaudio.compliance.kaldi import fbank
+
+ from .usad_modules import ConformerEncoder, lengths_to_padding_mask
+
+ MAX_MEL_LENGTH = 3000  # 30 seconds
+
+
+ @torch.no_grad()
+ def wav_to_fbank(
+     wavs: torch.Tensor,
+     mel_dim: int = 128,
+     norm_mean: float = -4.268,
+     norm_std: float = 4.569,
+     wav_lengths: Optional[torch.Tensor] = None,
+     sample_rate: int = 16000,
+     return_lengths: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+     """Convert waveform to fbank features.
+
+     Args:
+         wavs (torch.Tensor): (B, T_wav) waveform tensor.
+         mel_dim (int, optional): mel dimension. Defaults to 128.
+         norm_mean (float, optional): mean for normalization. Defaults to -4.268.
+         norm_std (float, optional): std for normalization. Defaults to 4.569.
+         wav_lengths (torch.Tensor, optional): (B,) valid waveform lengths before padding.
+         sample_rate (int, optional): waveform sample rate. Defaults to 16000.
+         return_lengths (bool, optional): return exact fbank lengths. Defaults to False.
+
+     Returns:
+         torch.Tensor: (B, T_mel, mel_dim) fbank features. If return_lengths is True,
+             also returns a (B,) tensor with exact feature lengths before padding.
+     """
+     # ref: https://github.com/cwx-worst-one/EAT/tree/main/feature_extract
+     feature_dtype = wavs.dtype if wavs.is_floating_point() else torch.float32
+     wavs_float = wavs.to(torch.float32)
+
+     if wav_lengths is None:
+         wav_lengths = torch.full(
+             (wavs.shape[0],),
+             wavs.shape[1],
+             dtype=torch.long,
+             device=wavs.device,
+         )
+     else:
+         wav_lengths = wav_lengths.to(device=wavs.device, dtype=torch.long)
+     if wav_lengths.dim() != 1 or wav_lengths.shape[0] != wavs.shape[0]:
+         raise ValueError("wav_lengths must be a 1-D tensor with batch size elements.")
+     if torch.any(wav_lengths <= 0).item():
+         raise ValueError("All wav_lengths values must be positive.")
+     if torch.any(wav_lengths > wavs.shape[1]).item():
+         raise ValueError("wav_lengths cannot exceed the padded waveform length.")
+
+     feats = []
+     feat_lengths = []
+     for i, wav_length in enumerate(wav_lengths.detach().cpu().tolist()):
+         # Trim padding before centering so batched padding cannot affect valid audio.
+         wav = wavs_float[i, :wav_length]
+         wav = wav - wav.mean(dim=-1, keepdim=True)
+         feat = fbank(
+             wav.unsqueeze(0),
+             htk_compat=True,
+             sample_frequency=sample_rate,
+             use_energy=False,
+             window_type="hanning",
+             num_mel_bins=mel_dim,
+             dither=0.0,
+             frame_shift=10,
+         )
+         feat = (feat - norm_mean) / (norm_std * 2)
+         feats.append(feat.to(dtype=feature_dtype))
+         feat_lengths.append(feat.shape[0])
+
+     mels = pad_sequence(feats, batch_first=True, padding_value=0.0)
+     mel_lengths = torch.tensor(feat_lengths, dtype=torch.long, device=wavs.device)
+
+     if return_lengths:
+         return mels, mel_lengths
+     return mels
+
+
+ class UsadModel(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+
+         self.cfg = cfg
+         self.encoder = ConformerEncoder(cfg)
+         self.max_mel_length = MAX_MEL_LENGTH
+
+     @property
+     def sample_rate(self) -> int:
+         return 16000  # Hz
+
+     @property
+     def encoder_frame_rate(self) -> int:
+         return round(100 / self.cfg.conv_subsample_rate)  # Hz
+
+     @property
+     def mel_dim(self) -> int:
+         return self.cfg.input_dim
+
+     @property
+     def encoder_dim(self) -> int:
+         return self.cfg.encoder_dim
+
+     @property
+     def num_layers(self) -> int:
+         return self.cfg.num_layers
+
+     @property
+     def device(self) -> torch.device:
+         return next(self.parameters()).device
+
+     @property
+     def dtype(self) -> torch.dtype:
+         return next(self.parameters()).dtype
+
+     def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
+         """Set the maximum chunk size for feature extraction.
+         Args:
+             seconds (float, optional): Chunk size in seconds. Defaults to 30.0.
+         """
+         assert (
+             seconds >= 0.1
+         ), f"Chunk size must be at least 0.1s, got {seconds} seconds."
+         self.max_mel_length = int(seconds * 100)  # 100 Hz frame rate
+
+     def load_audio(self, audio_path: str, move_to_device: bool = True) -> torch.Tensor:
+         """Load audio file and return waveform tensor.
+         Args:
+             audio_path (str): Path to the audio file.
+             move_to_device (bool, optional): Move the waveform to the model's device.
+                 Defaults to True.
+         Returns:
+             torch.Tensor: Waveform tensor of shape (wav_len,).
+         """
+
+         waveform, sr = torchaudio.load(audio_path)
+         if sr != self.sample_rate:
+             waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
+         if waveform.shape[0] > 1:
+             # If stereo, convert to mono by averaging channels
+             waveform = waveform.mean(dim=0, keepdim=True)
+
+         waveform = waveform.squeeze(0)  # Remove channel dimension if mono
+         if move_to_device:
+             return waveform.to(self.device)  # Ensure tensor is on the same device
+         return waveform
+
+     def load_audio_batch(
+         self, audio_paths: List[str]
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Load several audio files and zero-pad them into one batch.
+         Returns:
+             wavs (batch_size, max_wav_len) and wav_lengths (batch_size,).
+         """
+         wav_list = []
+         wav_lengths = []
+         for path in audio_paths:
+             wav = self.load_audio(path, move_to_device=False)
+             wav_list.append(wav)
+             wav_lengths.append(wav.shape[0])
+         wavs = pad_sequence(wav_list, batch_first=True).to(self.device)
+         wav_lengths = torch.tensor(wav_lengths, dtype=torch.long, device=self.device)
+         return wavs, wav_lengths
+
+     def forward(
+         self,
+         wavs: torch.Tensor,
+         wav_lengths: Optional[torch.Tensor] = None,
+         padding_mask: Optional[torch.Tensor] = None,
+         target_layer: Optional[int] = None,
+         norm_mean: float = -4.268,
+         norm_std: float = 4.569,
+     ) -> dict:
+         """
+         Args:
+             wavs (torch.Tensor): (B, T_wav) waveform tensor.
+             wav_lengths (torch.Tensor, optional): (B,) lengths of each waveform. Defaults to None.
+             padding_mask (torch.Tensor, optional): (B, T_wav) padding mask for the waveforms.
+                 If wav_lengths is not provided, this is used to infer valid lengths.
+             target_layer (int, optional): If specified, only return the output of the target layer. Defaults to None (return all layers).
+             norm_mean (float, optional): Mean for normalization. Defaults to -4.268.
+             norm_std (float, optional): Std for normalization. Defaults to 4.569.
+         Returns:
+             dict: A dictionary containing the following keys:
+                 - "x": (B, T_out, encoder_dim) output of the encoder
+                 - "x_lengths": (B,) valid output lengths after encoder subsampling
+                 - "x_padding_mask": (B, T_out) output padding mask, where padding is True
+                 - "mel": (B, T_mel, mel_dim) input mel features
+                 - "mel_lengths": (B,) valid mel lengths before encoder subsampling
+                 - "hidden_states": list of (B, T_out, encoder_dim) hidden states of each layer
+                 - "ffn": list of (B, T_out, encoder_dim) output of the feed-forward network of each layer
+         """
+
+         # Check types
+         assert isinstance(wavs, torch.Tensor), "wavs must be a torch.Tensor"
+         assert wavs.dim() == 2, "wavs must be of shape (batch_size, seq_len)"
+         if wav_lengths is not None:
+             assert isinstance(
+                 wav_lengths, torch.Tensor
+             ), "wav_lengths must be a torch.Tensor"
+             assert wav_lengths.dim() == 1, "wav_lengths must be of shape (batch_size,)"
+             assert (
+                 wav_lengths.shape[0] == wavs.shape[0]
+             ), "wav_lengths must have the same batch size as wavs"
+         if padding_mask is not None:
+             assert isinstance(
+                 padding_mask, torch.Tensor
+             ), "padding_mask must be a torch.Tensor"
+             assert (
+                 padding_mask.dim() == 2
+             ), "padding_mask must be of shape (batch_size, seq_len)"
+             assert (
+                 padding_mask.shape[0] == wavs.shape[0]
+             ), "padding_mask must have the same batch size as wavs"
+             assert (
+                 padding_mask.shape[1] == wavs.shape[1]
+             ), "padding_mask must have the same seq_len as wavs"
+             if wav_lengths is None:
+                 wav_lengths = (~padding_mask.to(torch.bool)).sum(dim=1)
+         if target_layer is not None:
+             assert isinstance(target_layer, int), "target_layer must be an int or None"
+             assert (
+                 1 <= target_layer <= self.cfg.num_layers
+             ), f"target_layer must be between 1 and {self.cfg.num_layers}"
+
+         mel, mel_lengths = wav_to_fbank(
+             wavs,
+             wav_lengths=wav_lengths,
+             mel_dim=self.mel_dim,
+             norm_mean=norm_mean,
+             norm_std=norm_std,
+             sample_rate=self.sample_rate,
+             return_lengths=True,
+         )
+
+         dtype = self.dtype
+
+         if mel.dtype != dtype:
+             mel = mel.to(dtype)
+
+         num_layers = min(
+             self.cfg.num_layers,
+             target_layer if target_layer is not None else self.cfg.num_layers,
+         )
+
+         if mel.shape[1] <= self.max_mel_length:
+             # If the mel length is less than or equal to max_mel_length, we can process it in one go
+             x, x_len, layer_results = self.encoder(
+                 inputs=mel,
+                 input_lengths=mel_lengths,
+                 return_hidden=True,
+                 target_layer=target_layer,
+             )
+
+             result = {
+                 "x": x,
+                 "x_lengths": x_len,
+                 "x_padding_mask": lengths_to_padding_mask(x_len, max_len=x.size(1)),
+                 "mel": mel,
+                 "mel_lengths": mel_lengths,
+                 "hidden_states": layer_results["hidden_states"],
+                 "ffn": layer_results["ffn_1"],
+             }
+             return result
+
+         # If the mel length is greater than max_mel_length, we need to process it in chunks
+         result = {
+             "x": [],
+             "x_lengths": [],
+             "mel": mel,
+             "mel_lengths": mel_lengths,
+             "hidden_states": [[] for _ in range(num_layers)],
+             "ffn": [[] for _ in range(num_layers)],
+         }
+         for i in range(0, mel.shape[1], self.max_mel_length):
+             if mel.shape[1] - i < 10:
+                 break
+
+             _mel = mel[:, i : i + self.max_mel_length]
+             _mel_lengths = None
+             if mel_lengths is not None:
+                 _mel_lengths = torch.clamp(
+                     mel_lengths - i, min=0, max=self.max_mel_length
+                 )
+
+             x, x_len, layer_results = self.encoder(
+                 inputs=_mel,
+                 input_lengths=_mel_lengths,
+                 return_hidden=True,
+                 target_layer=target_layer,
+             )
+
+             result["x"].append(x)
+             result["x_lengths"].append(x_len)
+             for j in range(num_layers):
+                 result["hidden_states"][j].append(layer_results["hidden_states"][j])
+                 result["ffn"][j].append(layer_results["ffn_1"][j])
+
+         result["x"] = torch.cat(result["x"], dim=1)
+         result["x_lengths"] = torch.stack(result["x_lengths"], dim=0).sum(dim=0)
+         result["x_padding_mask"] = lengths_to_padding_mask(
+             result["x_lengths"], max_len=result["x"].size(1)
+         )
+         for j in range(num_layers):
+             result["hidden_states"][j] = torch.cat(
+                 result["hidden_states"][j], dim=1
+             )
+             result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)
+
+         return result
+
+     @classmethod
+     def load_from_fairseq_ckpt(cls, ckpt_path: str):
+         checkpoint = torch.load(ckpt_path, weights_only=False)
+         config = checkpoint["cfg"]["model"]
+         config = make_dataclass("Config", config.keys())(**config)
+         model = cls(config)
+         state_dict = checkpoint["model"]
+         for k in list(state_dict.keys()):
+             if not k.startswith("encoder."):
+                 del state_dict[k]
+         model.load_state_dict(state_dict, strict=True)
+         return model
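The chunked path in `forward` can be easier to follow with the tensor work stripped away. The sketch below mirrors only the bookkeeping: the chunk start offsets, the rule that a tail shorter than 10 mel frames is dropped, and the per-chunk clamping of valid lengths. `chunk_plan` is a hypothetical helper written for illustration; it operates on plain integers and does not call the encoder.

```python
from typing import List, Tuple


def chunk_plan(
    mel_lengths: List[int], max_mel_length: int = 3000, min_tail: int = 10
) -> List[Tuple[int, int, List[int]]]:
    """Return (start, end, per-item valid lengths) for each processed chunk."""
    if not mel_lengths:
        return []
    total = max(mel_lengths)  # the padded batch is as long as its longest item
    plan = []
    for start in range(0, total, max_mel_length):
        if total - start < min_tail:
            break  # tail too short to encode, as in the loop above
        end = min(start + max_mel_length, total)
        # Equivalent of torch.clamp(mel_lengths - start, min=0, max=max_mel_length).
        lengths = [min(max(n - start, 0), max_mel_length) for n in mel_lengths]
        plan.append((start, end, lengths))
    return plan


# A 75 s and a 32 s clip at 100 mel frames per second, 30 s chunks.
for start, end, lengths in chunk_plan([7500, 3200]):
    print(start, end, lengths)
```

In the third chunk the shorter clip contributes a valid length of 0, which is why the per-chunk output lengths are summed per item to obtain the final `"x_lengths"`.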
usad_modules.py ADDED
@@ -0,0 +1,1027 @@
+ # Reference: https://github.com/sooftware/conformer
+
+ import contextlib
+ import math
+ from collections import defaultdict
+ from typing import Dict, List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.nn.attention import SDPBackend, sdpa_kernel
+
+
+ def lengths_to_padding_mask(
+     lengths: torch.Tensor, max_len: Optional[int] = None
+ ) -> torch.Tensor:
+     """Create padding mask from lengths.
+
+     Args:
+         lengths: A 1-D tensor of shape (B,).
+         max_len: An integer. It will be automatically set to the max value of lengths
+             if not given.
+
+     Returns:
+         A bool tensor of shape (B, max_len), where padded positions are indicated by True.
+     """
+     batch_size = lengths.size(0)
+     max_len = lengths.max().item() if max_len is None else max_len
+     seq_range = torch.arange(0, max_len, dtype=torch.long, device=lengths.device)
+     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+     lengths_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
+     padding_mask = seq_range_expand >= lengths_expand
+     return padding_mask
+
+
+ class SamePad(nn.Module):
+     def __init__(self, kernel_size, causal=False):
+         super().__init__()
+         if causal:
+             self.remove = kernel_size - 1
+         else:
+             self.remove = 1 if kernel_size % 2 == 0 else 0
+
+     def forward(self, x):
+         if self.remove > 0:
+             x = x[:, :, : -self.remove]
+         return x
+
+
+ class TransposeLast(nn.Module):
+     def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
+         super().__init__()
+         self.deconstruct_idx = deconstruct_idx
+         self.tranpose_dim = tranpose_dim
+
+     def forward(self, x):
+         if self.deconstruct_idx is not None:
+             x = x[self.deconstruct_idx]
+         return x.transpose(self.tranpose_dim, -1)
+
+
+ class Swish(nn.Module):
+     def __init__(self):
+         super(Swish, self).__init__()
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return inputs * inputs.sigmoid()
+
+
+ class GLU(nn.Module):
+     def __init__(self, dim: int) -> None:
+         super(GLU, self).__init__()
+         self.dim = dim
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         outputs, gate = inputs.chunk(2, dim=self.dim)
+         return outputs * gate.sigmoid()
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+ class ResidualConnectionModule(nn.Module):
+     def __init__(
+         self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0
+     ):
+         super(ResidualConnectionModule, self).__init__()
+         self.module = module
+         self.module_factor = module_factor
+         self.input_factor = input_factor
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)
+
+
+ class Linear(nn.Module):
+     def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
+         super(Linear, self).__init__()
+         self.linear = nn.Linear(in_features, out_features, bias=bias)
+         nn.init.xavier_uniform_(self.linear.weight)
+         if bias:
+             nn.init.zeros_(self.linear.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.linear(x)
+
+
+ class View(nn.Module):
+     def __init__(self, shape: tuple, contiguous: bool = False):
+         super(View, self).__init__()
+         self.shape = shape
+         self.contiguous = contiguous
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.contiguous:
+             x = x.contiguous()
+
+         return x.view(*self.shape)
+
+
+ class Transpose(nn.Module):
+     def __init__(self, shape: tuple):
+         super(Transpose, self).__init__()
+         self.shape = shape
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x.transpose(*self.shape)
+
+
+ class FeedForwardModule(nn.Module):
+     def __init__(
+         self,
+         encoder_dim: int = 512,
+         expansion_factor: int = 4,
+         dropout_p: float = 0.1,
+         rms_norm: bool = False,
+     ) -> None:
+         super(FeedForwardModule, self).__init__()
+         self.sequential = nn.Sequential(
+             nn.LayerNorm(encoder_dim) if not rms_norm else RMSNorm(encoder_dim),
+             Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
+             Swish(),
+             nn.Dropout(p=dropout_p),
+             Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
+             nn.Dropout(p=dropout_p),
+         )
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return self.sequential(inputs)
+
+
+ class DepthwiseConv1d(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
169
+ stride: int = 1,
170
+ padding: int = 0,
171
+ bias: bool = False,
172
+ ) -> None:
173
+ super(DepthwiseConv1d, self).__init__()
174
+ assert (
175
+ out_channels % in_channels == 0
176
+ ), "out_channels should be a constant multiple of in_channels"
177
+ self.conv = nn.Conv1d(
178
+ in_channels=in_channels,
179
+ out_channels=out_channels,
180
+ kernel_size=kernel_size,
181
+ groups=in_channels,
182
+ stride=stride,
183
+ padding=padding,
184
+ bias=bias,
185
+ )
186
+
187
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
188
+ return self.conv(inputs)
189
+
190
+
191
+ class PointwiseConv1d(nn.Module):
192
+ def __init__(
193
+ self,
194
+ in_channels: int,
195
+ out_channels: int,
196
+ stride: int = 1,
197
+ padding: int = 0,
198
+ bias: bool = True,
199
+ ) -> None:
200
+ super(PointwiseConv1d, self).__init__()
201
+ self.conv = nn.Conv1d(
202
+ in_channels=in_channels,
203
+ out_channels=out_channels,
204
+ kernel_size=1,
205
+ stride=stride,
206
+ padding=padding,
207
+ bias=bias,
208
+ )
209
+
210
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
211
+ return self.conv(inputs)
212
+
213
+
214
+ class ConformerConvModule(nn.Module):
215
+ def __init__(
216
+ self,
217
+ in_channels: int,
218
+ kernel_size: int = 31,
219
+ expansion_factor: int = 2,
220
+ dropout_p: float = 0.1,
221
+ rms_norm: bool = False,
222
+ ) -> None:
223
+ super(ConformerConvModule, self).__init__()
224
+ assert (
225
+ kernel_size - 1
226
+ ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
227
+ assert expansion_factor == 2, "Currently only expansion_factor=2 is supported"
228
+
229
+ self.sequential = nn.Sequential(
230
+ nn.LayerNorm(in_channels) if not rms_norm else RMSNorm(in_channels),
231
+ Transpose(shape=(1, 2)),
232
+ PointwiseConv1d(
233
+ in_channels,
234
+ in_channels * expansion_factor,
235
+ stride=1,
236
+ padding=0,
237
+ bias=True,
238
+ ),
239
+ GLU(dim=1),
240
+ DepthwiseConv1d(
241
+ in_channels,
242
+ in_channels,
243
+ kernel_size,
244
+ stride=1,
245
+ padding=(kernel_size - 1) // 2,
246
+ ),
247
+ nn.BatchNorm1d(in_channels),
248
+ Swish(),
249
+ PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
250
+ nn.Dropout(p=dropout_p),
251
+ )
252
+
253
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
254
+ return self.sequential(inputs).transpose(1, 2)
255
+
256
+
257
+ class FramewiseConv2dSubampling(nn.Module):
258
+ def __init__(self, out_channels: int, subsample_rate: int = 2) -> None:
259
+ super(FramewiseConv2dSubampling, self).__init__()
260
+ assert subsample_rate in {2, 4}, "subsample_rate should be 2 or 4"
261
+ self.subsample_rate = subsample_rate
262
+ self.cnn = nn.Sequential(
263
+ nn.Conv2d(1, out_channels, kernel_size=3, stride=2),
264
+ nn.ReLU(),
265
+ nn.Conv2d(
266
+ out_channels,
267
+ out_channels,
268
+ kernel_size=3,
269
+ stride=(2 if subsample_rate == 4 else 1, 2),
270
+ padding=(0 if subsample_rate == 4 else 1, 0),
271
+ ),
272
+ nn.ReLU(),
273
+ )
274
+
275
+ def forward(
276
+ self, inputs: torch.Tensor, input_lengths: torch.Tensor
277
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
278
+ # inputs: (B, T, C) -> (B, 1, T, C)
279
+ if self.subsample_rate == 2 and inputs.shape[1] % 2 == 0:
280
+ inputs = F.pad(inputs, (0, 0, 0, 1), "constant", 0)
281
+ if self.subsample_rate == 4 and inputs.shape[1] % 4 < 3:
282
+ inputs = F.pad(inputs, (0, 0, 0, 3 - inputs.shape[1] % 4), "constant", 0)
283
+ outputs = self.cnn(inputs.unsqueeze(1))
284
+ batch_size, channels, subsampled_lengths, sumsampled_dim = outputs.size()
285
+
286
+ outputs = outputs.permute(0, 2, 1, 3)
287
+ outputs = outputs.contiguous().view(
288
+ batch_size, subsampled_lengths, channels * sumsampled_dim
289
+ )
290
+
291
+ if self.subsample_rate == 4:
292
+ output_lengths = input_lengths >> 2
293
+ else:
294
+ output_lengths = input_lengths >> 1
295
+
296
+ return outputs, output_lengths
297
+
298
+ def get_out_dim(self, input_dim: int) -> int:
299
+ # dummy input to get the output dimension
300
+ with torch.no_grad():
301
+ device = next(self.parameters()).device
302
+ inputs = torch.zeros(1, 16, input_dim, device=device)
303
+ input_lengths = torch.tensor([16], device=device)
304
+ outputs, _ = self.forward(inputs, input_lengths)
305
+ return outputs.size(-1)
306
+
307
+
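`FramewiseConv2dSubampling.forward` reports `input_lengths >> 1` or `input_lengths >> 2` as the output length. The sketch below is an arithmetic check, under the assumption that the standard convolution length formula applies, that the padding in `forward` makes the time axis of the two-layer CNN come out to exactly that value. The helper names are hypothetical, and a real `nn.Conv2d` would reject inputs short enough to produce an empty output.

```python
def conv_out_len(length: int, kernel_size: int = 3, stride: int = 2, padding: int = 0) -> int:
    """Output length of one convolution along the time axis."""
    return (length + 2 * padding - kernel_size) // stride + 1


def framewise_time_len(num_frames: int, subsample_rate: int) -> int:
    """Time length after the padding and the two Conv2d layers in forward()."""
    if subsample_rate == 2:
        if num_frames % 2 == 0:
            num_frames += 1                      # pad one frame so the length is odd
        t = conv_out_len(num_frames, stride=2)   # first conv: stride 2, no padding
        return conv_out_len(t, stride=1, padding=1)  # second conv keeps the time length
    if subsample_rate == 4:
        if num_frames % 4 < 3:
            num_frames += 3 - num_frames % 4     # pad so that length % 4 == 3
        t = conv_out_len(num_frames, stride=2)
        return conv_out_len(t, stride=2)         # second conv: stride 2 again
    raise ValueError("subsample_rate should be 2 or 4")


for frames in (99, 100, 101, 102):
    print(frames, framewise_time_len(frames, 2), framewise_time_len(frames, 4))
```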
308
+ class PatchwiseConv2dSubampling(nn.Module):
309
+ def __init__(
310
+ self,
311
+ mel_dim: int,
312
+ out_channels: int,
313
+ patch_size_time: int = 16,
314
+ patch_size_freq: int = 16,
315
+ ) -> None:
316
+ super(PatchwiseConv2dSubampling, self).__init__()
317
+
318
+ self.mel_dim = mel_dim
319
+ self.patch_size_time = patch_size_time
320
+ self.patch_size_freq = patch_size_freq
321
+
322
+ self.proj = nn.Conv2d(
323
+ 1,
324
+ out_channels,
325
+ kernel_size=(patch_size_time, patch_size_freq),
326
+ stride=(patch_size_time, patch_size_freq),
327
+ padding=0,
328
+ )
329
+ self.cnn = nn.Sequential(
330
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
331
+ nn.ReLU(),
332
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
333
+ nn.ReLU(),
334
+ )
335
+
336
+ @property
337
+ def subsample_rate(self) -> int:
338
+ return self.patch_size_time * self.patch_size_freq // self.mel_dim
339
+
340
+ def forward(
341
+ self, inputs: torch.Tensor, input_lengths: torch.Tensor
342
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
343
+ assert (
344
+ inputs.shape[2] == self.mel_dim
345
+ ), "inputs.shape[2] should be equal to mel_dim"
346
+
347
+ # inputs: (B, Time, Freq) -> (B, 1, Time, Freq)
348
+ outputs = self.proj(inputs.unsqueeze(1))
349
+ outputs = self.cnn(outputs)
350
+ # (B, channels, Time // patch_size_time, Freq // patch_size_freq)
351
+ outputs = outputs.flatten(2, 3).transpose(1, 2)
352
+ # (B, (Time // patch_size_time) * (Freq // patch_size_freq), channels)
353
+
354
+ output_lengths = (
355
+ input_lengths
356
+ // self.patch_size_time
357
+ * (self.mel_dim // self.patch_size_freq)
358
+ )
359
+
360
+ return outputs, output_lengths
361
+
362
+
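The token count and the `subsample_rate` property of `PatchwiseConv2dSubampling` follow from simple integer arithmetic. The sketch below reproduces both; `mel_dim=128` is an assumed value chosen for illustration (the actual value comes from `cfg.input_dim`), and the function names are hypothetical.

```python
def patchwise_token_count(num_frames: int, mel_dim: int,
                          patch_size_time: int = 16, patch_size_freq: int = 16) -> int:
    """Tokens produced for one utterance: time patches times frequency patches."""
    return num_frames // patch_size_time * (mel_dim // patch_size_freq)


def patchwise_subsample_rate(mel_dim: int,
                             patch_size_time: int = 16, patch_size_freq: int = 16) -> int:
    """Input frames per output token, as in the subsample_rate property."""
    return patch_size_time * patch_size_freq // mel_dim


mel_dim = 128  # assumption for this example only
print(patchwise_subsample_rate(mel_dim))     # frames per token
print(patchwise_token_count(1024, mel_dim))  # tokens for 1024 frames
print(patchwise_token_count(1000, mel_dim))  # partial trailing patch is dropped
```

Because the frequency patches are flattened into the sequence axis, the effective rate is `patch_size_time` divided by the number of frequency patches, which is what the encoder compares against `conv_subsample_rate`.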
363
+ class RelPositionalEncoding(nn.Module):
364
+ def __init__(self, d_model: int) -> None:
365
+ super(RelPositionalEncoding, self).__init__()
366
+ self.d_model = d_model
367
+ self.pe = None
368
+
369
+ def extend_pe(self, x: torch.Tensor) -> None:
370
+ if self.pe is not None:
371
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
372
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
373
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
374
+ return
375
+
376
+ length = x.size(1)
377
+ pe_positive = torch.zeros(length, self.d_model, device="cpu")
378
+ pe_negative = torch.zeros(length, self.d_model, device="cpu")
379
+ position = torch.arange(0, length, dtype=torch.float32, device="cpu").unsqueeze(
380
+ 1
381
+ )
382
+ div_term = torch.exp(
383
+ torch.arange(0, self.d_model, 2, dtype=torch.float32, device="cpu")
384
+ * -(math.log(10000.0) / self.d_model)
385
+ )
386
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
387
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
388
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
389
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
390
+
391
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
392
+ pe_negative = pe_negative[1:].unsqueeze(0)
393
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
394
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
395
+
396
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
397
+ # x: (B, T, C)
398
+ self.extend_pe(x)
399
+ assert self.pe is not None
400
+ pos_emb = self.pe[
401
+ :,
402
+ self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
403
+ ]
404
+ return pos_emb
405
+
406
+
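The layout of the table built by `extend_pe` is easier to see in a small standalone version. The sketch below rebuilds it with the standard library only, assuming an even `d_model`; rows run from relative position `+(length - 1)` down to `-(length - 1)`, so the center row is relative position 0. The names `sinusoid`, `rel_pos_table`, and `slice_for` are hypothetical.

```python
import math


def sinusoid(position: int, d_model: int) -> list:
    """Interleaved sin/cos encoding of one (possibly negative) position."""
    vec = []
    for i in range(0, d_model, 2):
        angle = position * math.exp(-i * math.log(10000.0) / d_model)
        vec.extend([math.sin(angle), math.cos(angle)])
    return vec


def rel_pos_table(length: int, d_model: int) -> list:
    """Equivalent of cat([flip(pe_positive), pe_negative[1:]]): 2 * length - 1 rows."""
    return [sinusoid(rel, d_model) for rel in range(length - 1, -length, -1)]


def slice_for(table: list, t: int) -> list:
    """Same slice as RelPositionalEncoding.forward for a sequence of t frames."""
    center = len(table) // 2
    return table[center - t + 1 : center + t]


table = rel_pos_table(5, 4)   # cached for up to 5 frames
short = slice_for(table, 3)   # reused for a 3-frame input
print(len(table), len(short))
```

Slicing around the center is what lets one precomputed table serve every shorter input without recomputation.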
407
+ class RelativeMultiHeadAttention(nn.Module):
408
+ def __init__(
409
+ self,
410
+ d_model: int = 512,
411
+ num_heads: int = 16,
412
+ dropout_p: float = 0.1,
413
+ ):
414
+ super(RelativeMultiHeadAttention, self).__init__()
415
+ assert d_model % num_heads == 0, "d_model % num_heads should be zero."
416
+ self.d_model = d_model
417
+ self.d_head = d_model // num_heads
418
+ self.num_heads = num_heads
419
+ self.sqrt_dim = math.sqrt(self.d_head)
420
+
421
+ self.query_proj = Linear(d_model, d_model)
422
+ self.key_proj = Linear(d_model, d_model)
423
+ self.value_proj = Linear(d_model, d_model)
424
+ self.pos_proj = Linear(d_model, d_model, bias=False)
425
+
426
+ self.dropout = nn.Dropout(p=dropout_p)
427
+ self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
428
+ self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
429
+ torch.nn.init.xavier_uniform_(self.u_bias)
430
+ torch.nn.init.xavier_uniform_(self.v_bias)
431
+
432
+ self.out_proj = Linear(d_model, d_model)
433
+
434
+ @staticmethod
435
+ def _relative_shift(pos_score: torch.Tensor) -> torch.Tensor:
436
+ # pos_score: (B, H, T, 2T-1)
437
+ B, H, T, L = pos_score.size()
438
+
439
+ # Pad on the left of the last dimension: (B, H, T, 2T)
440
+ pos_score = F.pad(pos_score, (1, 0))
441
+
442
+ # Reshape to (B, H, 2T, T)
443
+ pos_score = pos_score.view(B, H, L + 1, T)
444
+
445
+ # Slice and reshape back to (B, H, T, 2T-1)
446
+ pos_score = pos_score[:, :, 1:].view(B, H, T, L)
447
+
448
+ # Keep only first T positions => (B, H, T, T)
449
+ return pos_score[:, :, :, : (L // 2 + 1)]
450
+
451
+ def forward(
452
+ self,
453
+ query: torch.Tensor,
454
+ key: torch.Tensor,
455
+ value: torch.Tensor,
456
+ pos_embedding: torch.Tensor,
457
+ padding_mask: Optional[torch.Tensor] = None,
458
+ *,
459
+ need_weights: bool = False,
460
+ use_sdpa: Optional[bool] = None,
461
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
462
+ """
463
+ - If need_weights=True: returns (output, attn) from the explicit softmax path.
464
+ - If need_weights=False: returns (output, None); SDPA is used by default outside training for speed/memory.
465
+ """
466
+ B, Tq, _ = query.size()
467
+ _, Tk, _ = key.size()
468
+
469
+ # Project
470
+ q = self.query_proj(query) # (B, Tq, C)
471
+ k = self.key_proj(key) # (B, Tk, C)
472
+ v = self.value_proj(value) # (B, Tk, C)
473
+
474
+ # Reshape to (B, H, T, Dh)
475
+ q = q.view(B, Tq, self.num_heads, self.d_head).transpose(1, 2) # (B,H,Tq,Dh)
476
+ k = k.view(B, Tk, self.num_heads, self.d_head).transpose(1, 2) # (B,H,Tk,Dh)
477
+ v = v.view(B, Tk, self.num_heads, self.d_head).transpose(1, 2) # (B,H,Tk,Dh)
478
+
479
+ # Positional projection.
480
+ # IMPORTANT: allow pos_embedding to be (1, 2T-1, C) and broadcast across batch.
481
+ # pos_embedding expected length: 2Tq - 1 for self-attn.
482
+ pB = pos_embedding.size(0)
483
+ p = self.pos_proj(pos_embedding) # (pB, L, C)
484
+ p = p.view(pB, -1, self.num_heads, self.d_head).transpose(1, 2) # (pB,H,L,Dh)
485
+
486
+ # Compute position-based bias (scaled) to feed SDPA or add to scores
487
+ # q_pos: (B,H,Tq,Dh), p^T: (pB,H,Dh,L) -> broadcast on pB if pB==1
488
+ q_pos = q + self.v_bias.unsqueeze(0).unsqueeze(2) # (B,H,Tq,Dh)
489
+ pos_score = torch.matmul(q_pos, p.transpose(-2, -1)) # (B,H,Tq,L)
490
+ pos_bias = self._relative_shift(pos_score) # (B,H,Tq,Tq) for self-attn
491
+ pos_bias = pos_bias.mul(1.0 / self.sqrt_dim) # scale matches SDPA scaling
492
+
493
+ if padding_mask is not None:
494
+ # padding_mask: (B, T) -> (B, 1, 1, T) to broadcast with pos_bias (B, H, Tq, Tk)
495
+ # This masks out key positions that are padded across all heads and queries
496
+ if padding_mask.dtype != torch.bool:
497
+ padding_mask = padding_mask.to(torch.bool)
498
+ pos_bias = pos_bias.masked_fill(padding_mask[:, None, None, :], -1e9)
499
+
500
+ if use_sdpa is None:
501
+ use_sdpa = (not self.training) and (not need_weights)
502
+
503
+ # ---- Fast inference path: no attention matrix materialized ----
504
+ if use_sdpa:
505
+ # Content term uses u_bias
506
+ q_content = q + self.u_bias.unsqueeze(0).unsqueeze(2) # (B,H,Tq,Dh)
507
+
508
+ with sdpa_kernel(
509
+ [
510
+ SDPBackend.FLASH_ATTENTION,
511
+ SDPBackend.EFFICIENT_ATTENTION,
512
+ SDPBackend.MATH,
513
+ ]
514
+ ):
515
+ out = F.scaled_dot_product_attention(
516
+ q_content, # (B,H,Tq,Dh)
517
+ k, # (B,H,Tk,Dh)
518
+ v, # (B,H,Tk,Dh)
519
+ attn_mask=pos_bias, # (B,H,Tq,Tk) additive bias
520
+ dropout_p=0.0, # dropout disabled in inference
521
+ is_causal=False,
522
+ ) # (B,H,Tq,Dh)
523
+
524
+ out = out.transpose(1, 2).contiguous().view(B, Tq, self.d_model)
525
+
526
+ return self.out_proj(out), None
527
+
528
+ # ---- Reference path (training, or when attention weights are needed): explicit softmax attention ----
529
+ q_content = q + self.u_bias.unsqueeze(0).unsqueeze(2) # (B,H,Tq,Dh)
530
+ content_score = torch.matmul(q_content, k.transpose(-2, -1)) # (B,H,Tq,Tk)
531
+ content_score = content_score.mul(1.0 / self.sqrt_dim)
532
+
533
+ score = content_score + pos_bias # already scaled
534
+
535
+ attn = F.softmax(score, dim=-1)
536
+ attn = self.dropout(attn)
537
+
538
+ context = torch.matmul(attn, v) # (B,H,Tq,Dh)
539
+ context = context.transpose(1, 2).contiguous().view(B, Tq, self.d_model)
540
+
541
+ return self.out_proj(context), attn
542
+
543
+
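The pad, reshape, and slice sequence in `_relative_shift` is compact but not obvious. The sketch below replays it on nested lists for a single head, labeling every score with its query index and relative position so the result can be inspected. It assumes the column order produced by `RelPositionalEncoding` (column `k` holds relative position `T - 1 - k`); the function and variable names are illustrative.

```python
def relative_shift(pos_score: list) -> list:
    """List-based replay of _relative_shift for one (T, 2T-1) score matrix."""
    t = len(pos_score)
    width = len(pos_score[0])          # 2T - 1
    flat = []
    for row in pos_score:
        flat.extend([0] + row)         # left-pad each row: (T, 2T)
    flat = flat[t:]                    # view as (2T, T) and drop the first row
    rows = [flat[i * width:(i + 1) * width] for i in range(t)]  # back to (T, 2T-1)
    return [row[: width // 2 + 1] for row in rows]              # keep first T columns


T = 4
# Entry (i, r): score of query i against the encoding of relative position r.
pos_score = [[(i, T - 1 - k) for k in range(2 * T - 1)] for i in range(T)]
shifted = relative_shift(pos_score)

for row in shifted:
    print(row)
```

After the shift, entry `[i][j]` carries relative position `i - j`, so the matrix aligns with the `(Tq, Tk)` content scores it is added to.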
544
+ class MultiHeadedSelfAttentionModule(nn.Module):
545
+ def __init__(
546
+ self,
547
+ d_model: int,
548
+ num_heads: int,
549
+ dropout_p: float = 0.1,
550
+ rms_norm: bool = False,
551
+ ):
552
+ super(MultiHeadedSelfAttentionModule, self).__init__()
553
+ self.positional_encoding = RelPositionalEncoding(d_model)
554
+ self.layer_norm = nn.LayerNorm(d_model) if not rms_norm else RMSNorm(d_model)
555
+ self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
556
+ self.dropout = nn.Dropout(p=dropout_p)
557
+
558
+ def forward(
559
+ self,
560
+ inputs: torch.Tensor,
561
+ padding_mask: Optional[torch.Tensor] = None,
562
+ pos_embedding: Optional[torch.Tensor] = None,
563
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
564
+ if pos_embedding is None:
565
+ pos_embedding = self.positional_encoding(inputs)
566
+
567
+ inputs = self.layer_norm(inputs)
568
+ outputs, attn = self.attention(
569
+ inputs,
570
+ inputs,
571
+ inputs,
572
+ pos_embedding=pos_embedding,
573
+ padding_mask=padding_mask,
574
+ )
575
+
576
+ return self.dropout(outputs), attn, pos_embedding
577
+
578
+
579
+ class ConformerBlock(nn.Module):
580
+ def __init__(
581
+ self,
582
+ encoder_dim: int = 512,
583
+ attention_type: str = "mhsa",
584
+ num_attention_heads: int = 8,
585
+ feed_forward_expansion_factor: int = 4,
586
+ conv_expansion_factor: int = 2,
587
+ feed_forward_dropout_p: float = 0.1,
588
+ attention_dropout_p: float = 0.1,
589
+ conv_dropout_p: float = 0.1,
590
+ conv_kernel_size: int = 31,
591
+ half_step_residual: bool = True,
592
+ transformer_style: bool = False,
593
+ usad_v2: bool = False,
594
+ pre_norm: bool = False,
595
+ rms_norm: bool = False,
596
+ ):
597
+ super(ConformerBlock, self).__init__()
598
+
599
+ self.transformer_style = transformer_style
600
+ self.attention_type = attention_type
601
+ self.usad_v2 = usad_v2
602
+ self.pre_norm = pre_norm
603
+
604
+ if half_step_residual and not transformer_style:
605
+ self.feed_forward_residual_factor = 0.5
606
+ else:
607
+ self.feed_forward_residual_factor = 1
608
+
609
+ assert (
610
+ attention_type == "mhsa"
611
+ ), "Only 'mhsa' attention is supported in this implementation."
612
+ attention = MultiHeadedSelfAttentionModule(
613
+ d_model=encoder_dim,
614
+ num_heads=num_attention_heads,
615
+ dropout_p=attention_dropout_p,
616
+ rms_norm=rms_norm,
617
+ )
618
+
619
+ self.ffn_1 = FeedForwardModule(
620
+ encoder_dim=encoder_dim,
621
+ expansion_factor=feed_forward_expansion_factor,
622
+ dropout_p=feed_forward_dropout_p,
623
+ rms_norm=rms_norm,
624
+ )
625
+ self.attention = attention
626
+ if not transformer_style:
627
+ self.conv = ConformerConvModule(
628
+ in_channels=encoder_dim,
629
+ kernel_size=conv_kernel_size,
630
+ expansion_factor=conv_expansion_factor,
631
+ dropout_p=conv_dropout_p,
632
+ rms_norm=rms_norm,
633
+ )
634
+ self.ffn_2 = FeedForwardModule(
635
+ encoder_dim=encoder_dim,
636
+ expansion_factor=feed_forward_expansion_factor,
637
+ dropout_p=feed_forward_dropout_p,
638
+ rms_norm=rms_norm,
639
+ )
640
+ self.layernorm = (
641
+ (nn.LayerNorm(encoder_dim) if not rms_norm else RMSNorm(encoder_dim))
642
+ if not pre_norm
643
+ else nn.Identity()
644
+ )
645
+
646
+ def forward_attention(
647
+ self,
648
+ x: torch.Tensor,
649
+ pos_embedding: Optional[torch.Tensor] = None,
650
+ padding_mask: Optional[torch.Tensor] = None,
651
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
652
+ attn_out, attn, pos_embedding = self.attention(
653
+ x, pos_embedding=pos_embedding, padding_mask=padding_mask
654
+ )
655
+ return attn_out, attn, pos_embedding
656
+
657
+ def forward_legacy(
658
+ self,
659
+ x: torch.Tensor,
660
+ pos_embedding: Optional[torch.Tensor] = None,
661
+ padding_mask: Optional[torch.Tensor] = None,
662
+ ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
663
+ # FFN 1
664
+ ffn_1_out = self.ffn_1(x)
665
+ x = ffn_1_out * self.feed_forward_residual_factor + x
666
+
667
+ # Attention
668
+ attn_out, attn, pos_embedding = self.forward_attention(
669
+ x, pos_embedding, padding_mask
670
+ )
671
+ x = attn_out + x
672
+
673
+ if self.transformer_style:
674
+ x = self.layernorm(x)
675
+ return x, {"ffn_1": ffn_1_out, "attn": attn, "conv": None, "ffn_2": None, "pos_embedding": pos_embedding}
676
+
677
+ # Convolution
678
+ conv_out = self.conv(x)
679
+ x = conv_out + x
680
+
681
+ # FFN 2
682
+ ffn_2_out = self.ffn_2(x)
683
+ x = ffn_2_out * self.feed_forward_residual_factor + x
684
+ x = self.layernorm(x)
685
+
686
+ other = {
687
+ "ffn_1": ffn_1_out,
688
+ "attn": attn,
689
+ "conv": conv_out,
690
+ "ffn_2": ffn_2_out,
691
+ "pos_embedding": pos_embedding,
692
+ }
693
+
694
+ return x, other
695
+
696
+ def forward_transformer(
697
+ self,
698
+ x: torch.Tensor,
699
+ pos_embedding: Optional[torch.Tensor] = None,
700
+ padding_mask: Optional[torch.Tensor] = None,
701
+ ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
702
+ # Attention
703
+ attn_out, attn, pos_embedding = self.forward_attention(
704
+ x, pos_embedding, padding_mask
705
+ )
706
+ x = attn_out + x
707
+
708
+ # FFN
709
+ ffn_out = self.ffn_1(x)
710
+ x = ffn_out * self.feed_forward_residual_factor + x
711
+
712
+ x = self.layernorm(x)
713
+ return x, {
714
+ "ffn_1": ffn_out,
715
+ "attn": attn,
716
+ "conv": None,
717
+ "ffn_2": None,
718
+ "pos_embedding": pos_embedding,
719
+ }
720
+
721
+ def forward_conformer(
722
+ self,
723
+ x: torch.Tensor,
724
+ pos_embedding: Optional[torch.Tensor] = None,
725
+ padding_mask: Optional[torch.Tensor] = None,
726
+ ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
727
+ # FFN 1
728
+ ffn_1_out = self.ffn_1(x)
729
+ x = ffn_1_out * self.feed_forward_residual_factor + x
730
+
731
+ # Attention
732
+ attn_out, attn, pos_embedding = self.forward_attention(
733
+ x, pos_embedding, padding_mask
734
+ )
735
+ x = attn_out + x
736
+
737
+ # Convolution
738
+ conv_out = self.conv(x)
739
+ x = conv_out + x
740
+
741
+ # FFN 2
742
+ ffn_2_out = self.ffn_2(x)
743
+ x = ffn_2_out * self.feed_forward_residual_factor + x
744
+ x = self.layernorm(x)
745
+
746
+ other = {
747
+ "ffn_1": ffn_1_out,
748
+ "attn": attn,
749
+ "conv": conv_out,
750
+ "ffn_2": ffn_2_out,
751
+ "pos_embedding": pos_embedding,
752
+ }
753
+
754
+ return x, other
755
+
756
+ def forward(
757
+ self,
758
+ x: torch.Tensor,
759
+ pos_embedding: Optional[torch.Tensor] = None,
760
+ padding_mask: Optional[torch.Tensor] = None,
761
+ ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
762
+ if not self.usad_v2:
763
+ return self.forward_legacy(x, pos_embedding, padding_mask)
764
+
765
+ if self.transformer_style:
766
+ return self.forward_transformer(x, pos_embedding, padding_mask)
767
+
768
+ return self.forward_conformer(x, pos_embedding, padding_mask)
769
+
770
+
771
+ class ConformerEncoder(nn.Module):
772
+ def __init__(self, cfg):
773
+ super(ConformerEncoder, self).__init__()
774
+
775
+ self.cfg = cfg
776
+ self.framewise_subsample = None
777
+ self.patchwise_subsample = None
778
+ self.framewise_in_proj = None
779
+ self.patchwise_in_proj = None
780
+ assert (
781
+ cfg.use_framewise_subsample or cfg.use_patchwise_subsample
782
+ ), "At least one subsampling method should be used"
783
+ if cfg.use_framewise_subsample:
784
+ self.framewise_subsample = FramewiseConv2dSubampling(
785
+ out_channels=cfg.conv_subsample_channels,
786
+ subsample_rate=cfg.conv_subsample_rate,
787
+ )
788
+ self.framewise_in_proj = nn.Sequential(
789
+ Linear(
790
+ self.framewise_subsample.get_out_dim(cfg.input_dim),
791
+ cfg.encoder_dim,
792
+ ),
793
+ nn.Dropout(p=cfg.input_dropout_p),
794
+ )
795
+ if cfg.use_patchwise_subsample:
796
+ self.patchwise_subsample = PatchwiseConv2dSubampling(
797
+ mel_dim=cfg.input_dim,
798
+ out_channels=cfg.conv_subsample_channels,
799
+ patch_size_time=cfg.patch_size_time,
800
+ patch_size_freq=cfg.patch_size_freq,
801
+ )
802
+ self.patchwise_in_proj = nn.Sequential(
803
+ Linear(
804
+ cfg.conv_subsample_channels,
805
+ cfg.encoder_dim,
806
+ ),
807
+ nn.Dropout(p=cfg.input_dropout_p),
808
+ )
809
+ assert not cfg.use_framewise_subsample or (
810
+ cfg.conv_subsample_rate == self.patchwise_subsample.subsample_rate
811
+ ), (
812
+ f"conv_subsample_rate ({cfg.conv_subsample_rate}) != patchwise_subsample.subsample_rate "
813
+ f"({self.patchwise_subsample.subsample_rate})"
814
+ )
815
+
816
+ self.framewise_norm, self.patchwise_norm = None, None
817
+ if getattr(cfg, "subsample_normalization", False):
818
+ if cfg.use_framewise_subsample:
819
+ self.framewise_norm = (
820
+ nn.LayerNorm(cfg.encoder_dim)
821
+ if not getattr(cfg, "rms_norm", False)
822
+ else RMSNorm(cfg.encoder_dim)
823
+ )
824
+ if cfg.use_patchwise_subsample:
825
+ self.patchwise_norm = (
826
+ nn.LayerNorm(cfg.encoder_dim)
827
+ if not getattr(cfg, "rms_norm", False)
828
+ else RMSNorm(cfg.encoder_dim)
829
+ )
830
+
831
+ self.conv_pos = None
832
+ self.conv_pos_post_ln = None
833
+ if cfg.conv_pos:
834
+ num_pos_layers = cfg.conv_pos_depth
835
+ k = max(3, cfg.conv_pos_width // num_pos_layers)
836
+ self.conv_pos = nn.Sequential(
837
+ TransposeLast(),
838
+ *[
839
+ nn.Sequential(
840
+ nn.Conv1d(
841
+ cfg.encoder_dim,
842
+ cfg.encoder_dim,
843
+ kernel_size=k,
844
+ padding=k // 2,
845
+ groups=cfg.conv_pos_groups,
846
+ ),
847
+ SamePad(k),
848
+ TransposeLast(),
849
+ nn.LayerNorm(cfg.encoder_dim, elementwise_affine=False),
850
+ TransposeLast(),
851
+ nn.GELU(),
852
+ )
853
+ for _ in range(num_pos_layers)
854
+ ],
855
+ TransposeLast(),
856
+ )
857
+ self.conv_pos_post_ln = (
858
+ (
859
+ nn.LayerNorm(cfg.encoder_dim)
860
+ if not getattr(cfg, "rms_norm", False)
861
+ else RMSNorm(cfg.encoder_dim)
862
+ )
863
+ if not getattr(cfg, "pre_norm", False)
864
+ else nn.Identity()
865
+ )
866
+
867
+ self.layers = nn.ModuleList(
868
+ [
869
+ ConformerBlock(
870
+ encoder_dim=cfg.encoder_dim,
871
+ attention_type=cfg.attention_type,
872
+ num_attention_heads=cfg.num_attention_heads,
873
+ feed_forward_expansion_factor=cfg.feed_forward_expansion_factor,
874
+ conv_expansion_factor=cfg.conv_expansion_factor,
875
+ feed_forward_dropout_p=cfg.feed_forward_dropout_p,
876
+ attention_dropout_p=cfg.attention_dropout_p,
877
+ conv_dropout_p=cfg.conv_dropout_p,
878
+ conv_kernel_size=cfg.conv_kernel_size,
879
+ half_step_residual=cfg.half_step_residual,
880
+ transformer_style=getattr(cfg, "transformer_style", False),
881
+ usad_v2=getattr(cfg, "usad_v2", False),
882
+ pre_norm=getattr(cfg, "pre_norm", False),
883
+ rms_norm=getattr(cfg, "rms_norm", False),
884
+ )
885
+ for _ in range(cfg.num_layers)
886
+ ]
887
+ )
888
+ self.layerdrop_p = getattr(cfg, "layerdrop_p", 0.0)
889
+
890
+ if cfg.attention_type == "mhsa" and len(self.layers) > 0:
891
+ # Share positional encoding across layers
892
+ shared_pos = None
893
+ for layer in self.layers:
894
+ if isinstance(layer.attention, MultiHeadedSelfAttentionModule):
895
+ if shared_pos is None:
896
+ shared_pos = layer.attention.positional_encoding
897
+ else:
898
+ layer.attention.positional_encoding = shared_pos
899
+ if shared_pos is not None:
900
+ # precompute positional encodings
901
+ # expecting most mel inputs to be fewer than 2000 frames (20 seconds)
902
+ max_len = 2000 // cfg.conv_subsample_rate
903
+ shared_pos.extend_pe(torch.tensor(0.0).expand(1, max_len))
904
+
905
+ def count_parameters(self) -> int:
906
+ """Count trainable parameters of the encoder"""
907
+ return sum([p.numel() for p in self.parameters() if p.requires_grad])
908
+
909
+ def update_dropout(self, dropout_p: float) -> None:
910
+ """Update dropout probability of encoder"""
911
+ for child in self.modules():
912
+ if isinstance(child, nn.Dropout):
913
+ child.p = dropout_p
914
+
915
+ def forward(
916
+ self,
917
+ inputs: torch.Tensor,
918
+ input_lengths: Optional[torch.Tensor] = None,
919
+ padding_mask: Optional[torch.Tensor] = None,
920
+ *,
921
+ return_hidden: bool = False,
922
+ freeze_input_layers: bool = False,
923
+ target_layer: Optional[int] = None,
924
+ ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, List[torch.Tensor]]]:
925
+ if input_lengths is None:
926
+ input_lengths = torch.full(
927
+ (inputs.size(0),),
928
+ inputs.size(1),
929
+ dtype=torch.long,
930
+ device=inputs.device,
931
+ )
932
+
933
+ with torch.no_grad() if freeze_input_layers else contextlib.ExitStack():
934
+ frame_feat, patch_feat = None, None
935
+ frame_lengths, patch_lengths = None, None
936
+ if self.framewise_subsample is not None:
937
+ assert self.framewise_in_proj is not None
938
+ frame_feat, frame_lengths = self.framewise_subsample(
939
+ inputs, input_lengths
940
+ )
941
+ frame_feat = self.framewise_in_proj(frame_feat)
942
+ if self.framewise_norm is not None:
943
+ frame_feat = self.framewise_norm(frame_feat)
944
+
945
+ if self.patchwise_subsample is not None:
946
+ assert self.patchwise_in_proj is not None
947
+ patch_feat, patch_lengths = self.patchwise_subsample(
948
+ inputs, input_lengths
949
+ )
950
+ patch_feat = self.patchwise_in_proj(patch_feat)
951
+ if self.patchwise_norm is not None:
952
+ patch_feat = self.patchwise_norm(patch_feat)
953
+
954
+ assert frame_feat is not None or patch_feat is not None
955
+ assert frame_lengths is not None or patch_lengths is not None
956
+
957
+ if frame_feat is not None and patch_feat is not None:
958
+ assert frame_lengths is not None and patch_lengths is not None
959
+ min_len = min(frame_feat.size(1), patch_feat.size(1))
960
+ frame_feat = frame_feat[:, :min_len]
961
+ patch_feat = patch_feat[:, :min_len]
962
+
963
+ features = frame_feat + patch_feat
964
+ output_lengths = (
965
+ frame_lengths
966
+ if frame_lengths.max().item() < patch_lengths.max().item()
967
+ else patch_lengths
968
+ )
969
+ elif frame_feat is not None:
970
+ features = frame_feat
971
+ output_lengths = frame_lengths
972
+ else:
973
+ features = patch_feat
974
+ output_lengths = patch_lengths
975
+
976
+ assert features is not None
977
+ assert output_lengths is not None
978
+
979
+ # Positional encoding with convolutional layers
980
+ if self.conv_pos is not None and self.conv_pos_post_ln is not None:
981
+ pos = self.conv_pos(features)
982
+ if not self.training:
983
+ features = features.add_(pos)
984
+ else:
985
+ features = features + pos
986
+ features = self.conv_pos_post_ln(features)
987
+
988
+ # Create padding mask for attention
989
+ if padding_mask is not None:
990
+ # downsample to match features length
991
+ input_len = padding_mask.size(1)
992
+ feat_len = features.size(1)
993
+ factor = input_len / feat_len
994
+ indices = (
995
+ torch.arange(feat_len, device=padding_mask.device) * factor
996
+ ).long()
997
+ padding_mask = padding_mask.index_select(1, indices)
998
+ else:
999
+ # create from output_lengths
1000
+ padding_mask = lengths_to_padding_mask(
1001
+ output_lengths, max_len=features.size(1)
1002
+ )
1003
+
1004
+ layer_results = defaultdict(list)
1005
+ outputs = features
1006
+ other = {}
1007
+ for i, layer in enumerate(self.layers):
1008
+ if (
1009
+ self.training
1010
+ and self.layerdrop_p > 0
1011
+ and torch.rand(1).item() < self.layerdrop_p
1012
+ ):
1013
+ continue
1014
+ outputs, other = layer(
1015
+ outputs,
1016
+ pos_embedding=other.get("pos_embedding"),
1017
+ padding_mask=padding_mask,
1018
+ )
1019
+ if return_hidden:
1020
+ layer_results["hidden_states"].append(outputs)
1021
+ for k, v in other.items():
1022
+ layer_results[k].append(v)
1023
+
1024
+ if target_layer is not None and i + 1 == target_layer:
1025
+ break
1026
+
1027
+ return outputs, output_lengths, layer_results
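When a frame-level `padding_mask` is passed to `ConformerEncoder.forward`, it is downsampled by picking one input position per output frame. The sketch below mirrors that index computation with plain lists; the function name and the example mask are illustrative assumptions.

```python
def downsample_padding_mask(mask: list, feat_len: int) -> list:
    """Pick mask[int(i * factor)] for each of the feat_len output frames."""
    factor = len(mask) / feat_len                         # input frames per feature frame
    indices = [int(i * factor) for i in range(feat_len)]  # same truncation as .long()
    return [mask[i] for i in indices]


# 10 input frames, the last 4 are padding (True marks padded positions).
mask = [False] * 6 + [True] * 4
print(downsample_padding_mask(mask, 5))
```

Each output frame inherits the flag of the first input frame it covers, so a feature frame that only partially overlaps padding is still treated as valid.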