mazesmazes committed on
Commit c5a1720 · verified · 1 Parent(s): c7fd68d

Training in progress - step 1000

README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
alignment.py ADDED
@@ -0,0 +1,283 @@
+ """Forced alignment for word-level timestamps using Wav2Vec2."""
+
+ import numpy as np
+ import torch
+
+
+ def _get_device() -> str:
+     """Get best available device for non-transformers models."""
+     if torch.cuda.is_available():
+         return "cuda"
+     if torch.backends.mps.is_available():
+         return "mps"
+     return "cpu"
+
+
+ class ForcedAligner:
+     """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
+
+     Uses Viterbi trellis algorithm for optimal alignment path finding.
+     """
+
+     _bundle = None
+     _model = None
+     _labels = None
+     _dictionary = None
+
+     @classmethod
+     def get_instance(cls, device: str = "cuda"):
+         """Get or create the forced alignment model (singleton).
+
+         The model is loaded once; `device` only takes effect on the first call.
+
+         Args:
+             device: Device to run model on ("cuda", "mps", or "cpu")
+
+         Returns:
+             Tuple of (model, labels, dictionary)
+         """
+         if cls._model is None:
+             import torchaudio
+
+             cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
+             cls._model = cls._bundle.get_model().to(device)
+             cls._model.eval()
+             cls._labels = cls._bundle.get_labels()
+             cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
+         return cls._model, cls._labels, cls._dictionary
+
+     @staticmethod
+     def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
+         """Build trellis for forced alignment using the Viterbi (max) recursion.
+
+         The trellis[t, j] represents the log probability of the best path that
+         aligns the first j tokens to the first t frames.
+
+         Args:
+             emission: Log-softmax emission matrix of shape (num_frames, num_classes)
+             tokens: List of target token indices
+             blank_id: Index of the blank/CTC token (default 0)
+
+         Returns:
+             Trellis matrix of shape (num_frames + 1, num_tokens + 1)
+         """
+         num_frames = emission.size(0)
+         num_tokens = len(tokens)
+
+         trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
+         trellis[0, 0] = 0
+
+         for t in range(num_frames):
+             for j in range(num_tokens + 1):
+                 # Stay: emit blank and stay at j tokens
+                 stay = trellis[t, j] + emission[t, blank_id]
+
+                 # Move: emit tokens[j - 1] and advance from j - 1 to j tokens
+                 move = trellis[t, j - 1] + emission[t, tokens[j - 1]] if j > 0 else -float("inf")
+
+                 trellis[t + 1, j] = max(stay, move)  # Viterbi: take best path
+
+         return trellis
+
+     @staticmethod
+     def _backtrack(
+         trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
+     ) -> list[tuple[int, float, float]]:
+         """Backtrack through trellis to find optimal forced monotonic alignment.
+
+         Guarantees:
+         - All tokens are emitted exactly once
+         - Strictly monotonic: each token's frames come after previous token's
+         - No frame skipping or token teleporting
+
+         Returns list of (token_id, start_frame, end_frame) for each token.
+         """
+         num_frames = emission.size(0)
+         num_tokens = len(tokens)
+
+         if num_tokens == 0:
+             return []
+
+         # Find the best ending point (should be at num_tokens)
+         # But verify trellis reached a valid state
+         if trellis[num_frames, num_tokens] == -float("inf"):
+             # Alignment failed - fall back to uniform distribution
+             frames_per_token = num_frames / num_tokens
+             return [
+                 (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
+                 for i in range(num_tokens)
+             ]
+
+         # Backtrack: find where each token transition occurred
+         # token_frames[i] = frames at which token i was emitted
+         token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
+
+         t = num_frames
+         j = num_tokens
+
+         while t > 0 and j > 0:
+             # Check: did we transition from j-1 to j at frame t-1?
+             stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
+             move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
+
+             if move_score >= stay_score:
+                 # Token j-1 was emitted at frame t-1
+                 token_frames[j - 1].insert(0, t - 1)
+                 j -= 1
+             # Always decrement time (monotonic)
+             t -= 1
+
+         # Handle any remaining tokens at the start (edge case)
+         while j > 0:
+             token_frames[j - 1].insert(0, 0)
+             j -= 1
+
+         # Convert to spans
+         token_spans: list[tuple[int, float, float]] = []
+         for token_idx, frames in enumerate(token_frames):
+             if not frames:
+                 # Token never emitted - assign minimal span after previous
+                 if token_spans:
+                     prev_end = token_spans[-1][2]
+                     frames = [int(prev_end)]
+                 else:
+                     frames = [0]
+
+             token_id = tokens[token_idx]
+             start_frame = float(min(frames))
+             end_frame = float(max(frames)) + 1.0
+             token_spans.append((token_id, start_frame, end_frame))
+
+         return token_spans
+
+     # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
+     # Calibrated on librispeech-alignments dataset
+     # Both offsets are subtracted from the raw times in align()
+     START_OFFSET = 0.06  # Subtracted from start times (shifts starts earlier)
+     END_OFFSET = -0.03  # Subtracted from end times; negative, so ends shift later
+
+     @classmethod
+     def align(
+         cls,
+         audio: np.ndarray,
+         text: str,
+         sample_rate: int = 16000,
+         _language: str = "eng",
+         _batch_size: int = 16,
+     ) -> list[dict]:
+         """Align transcript to audio and return word-level timestamps.
+
+         Uses Viterbi trellis algorithm for optimal forced alignment.
+
+         Args:
+             audio: Audio waveform as numpy array (a torch tensor is also accepted)
+             text: Transcript text to align
+             sample_rate: Audio sample rate (default 16000)
+             _language: ISO-639-3 language code (default "eng" for English, unused)
+             _batch_size: Batch size for alignment model (unused)
+
+         Returns:
+             List of dicts with 'word', 'start', 'end' keys
+         """
+         import torchaudio
+
+         device = _get_device()
+         model, _labels, dictionary = cls.get_instance(device)
+         assert cls._bundle is not None and dictionary is not None  # Initialized by get_instance
+
+         # Convert audio to tensor (copy to ensure array is writable)
+         if isinstance(audio, np.ndarray):
+             waveform = torch.from_numpy(audio.copy()).float()
+         else:
+             waveform = audio.clone().float()
+
+         # Ensure 2D (channels, time)
+         if waveform.dim() == 1:
+             waveform = waveform.unsqueeze(0)
+
+         # Resample if needed (wav2vec2 expects 16kHz)
+         if sample_rate != cls._bundle.sample_rate:
+             waveform = torchaudio.functional.resample(
+                 waveform, sample_rate, cls._bundle.sample_rate
+             )
+
+         waveform = waveform.to(device)
+
+         # Get emissions from model
+         with torch.inference_mode():
+             emissions, _ = model(waveform)
+             emissions = torch.log_softmax(emissions, dim=-1)
+
+         emission = emissions[0].cpu()
+
+         # Normalize text: uppercase, keep only valid characters
+         transcript = text.upper()
+
+         # Build tokens from transcript (including word separators)
+         tokens = []
+         for char in transcript:
+             if char in dictionary:
+                 tokens.append(dictionary[char])
+             elif char == " ":
+                 tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
+
+         if not tokens:
+             return []
+
+         # Build Viterbi trellis and backtrack for optimal path
+         trellis = cls._get_trellis(emission, tokens, blank_id=0)
+         alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)
+
+         # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
+         frame_duration = 320 / cls._bundle.sample_rate
+
+         # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
+         start_offset = cls.START_OFFSET
+         end_offset = cls.END_OFFSET
+
+         # Group aligned tokens into words based on pipe separator
+         words = text.split()
+         word_timestamps = []
+         current_word_start = None
+         current_word_end = None
+         word_idx = 0
+         separator_id = dictionary.get("|", dictionary.get(" ", 0))
+
+         for token_id, start_frame, end_frame in alignment_path:
+             if token_id == separator_id:  # Word separator
+                 if (
+                     current_word_start is not None
+                     and current_word_end is not None
+                     and word_idx < len(words)
+                 ):
+                     start_time = max(0.0, current_word_start * frame_duration - start_offset)
+                     end_time = max(0.0, current_word_end * frame_duration - end_offset)
+                     word_timestamps.append(
+                         {
+                             "word": words[word_idx],
+                             "start": start_time,
+                             "end": end_time,
+                         }
+                     )
+                     word_idx += 1
+                 current_word_start = None
+                 current_word_end = None
+             else:
+                 if current_word_start is None:
+                     current_word_start = start_frame
+                 current_word_end = end_frame
+
+         # Don't forget the last word
+         if (
+             current_word_start is not None
+             and current_word_end is not None
+             and word_idx < len(words)
+         ):
+             start_time = max(0.0, current_word_start * frame_duration - start_offset)
+             end_time = max(0.0, current_word_end * frame_duration - end_offset)
+             word_timestamps.append(
+                 {
+                     "word": words[word_idx],
+                     "start": start_time,
+                     "end": end_time,
+                 }
+             )
+
+         return word_timestamps
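
The trellis recursion and backtracking in `alignment.py` can be traced by hand on a tiny input. The sketch below is not part of the commit: it re-implements the same stay/move recursion in plain Python (no torch), and the names `build_trellis`, `backtrack`, `frame_span_to_seconds`, `EMISSION`, and `TOKENS` are hypothetical, chosen only for illustration. The emission probabilities are made up; the 20 ms frame duration and the two offsets are the values used in the diff above.

```python
import math

NEG_INF = -math.inf
FRAME_DURATION = 320 / 16000  # 20 ms per frame, as in align()
START_OFFSET = 0.06           # same constants as ForcedAligner
END_OFFSET = -0.03


def build_trellis(emission, tokens, blank_id=0):
    """trellis[t][j]: best log-prob of emitting the first j tokens in the first t frames."""
    num_frames, num_tokens = len(emission), len(tokens)
    trellis = [[NEG_INF] * (num_tokens + 1) for _ in range(num_frames + 1)]
    trellis[0][0] = 0.0
    for t in range(num_frames):
        for j in range(num_tokens + 1):
            stay = trellis[t][j] + emission[t][blank_id]  # emit blank
            move = trellis[t][j - 1] + emission[t][tokens[j - 1]] if j > 0 else NEG_INF
            trellis[t + 1][j] = max(stay, move)
    return trellis


def backtrack(trellis, emission, tokens, blank_id=0):
    """Return, for each token, the frame index at which it was emitted."""
    t, j = len(emission), len(tokens)
    frames = [None] * len(tokens)
    while t > 0 and j > 0:
        stay = trellis[t - 1][j] + emission[t - 1][blank_id]
        move = trellis[t - 1][j - 1] + emission[t - 1][tokens[j - 1]]
        if move >= stay:
            frames[j - 1] = t - 1
            j -= 1
        t -= 1  # time always moves backwards by one frame
    return frames


def frame_span_to_seconds(start_frame, end_frame):
    """Convert a frame span to seconds, applying the start/end offsets."""
    start = max(0.0, start_frame * FRAME_DURATION - START_OFFSET)
    end = max(0.0, end_frame * FRAME_DURATION - END_OFFSET)
    return start, end


# Toy emissions: 4 frames x 3 classes (0 = blank, 1 = "A", 2 = "B"), as log-probs.
EMISSION = [
    [math.log(p) for p in row]
    for row in [(0.8, 0.1, 0.1), (0.1, 0.8, 0.1), (0.7, 0.2, 0.1), (0.1, 0.1, 0.8)]
]
TOKENS = [1, 2]  # transcript "AB"

if __name__ == "__main__":
    trellis = build_trellis(EMISSION, TOKENS)
    print(backtrack(trellis, EMISSION, TOKENS))  # [1, 3]
```

Because a "stay" step always emits the blank, each token occupies exactly one frame in this formulation, so every span produced by backtracking is one frame (20 ms) wide before the offsets are applied.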
asr_config.py ADDED
@@ -0,0 +1,233 @@
+ from typing import Optional
+
+ import transformers
+
+
+ class ASRConfig(transformers.PretrainedConfig):
+     """Configuration class for the ASR model.
+
+     This config combines settings for:
+     - Audio encoder (GLM-ASR/Whisper)
+     - Text decoder (Qwen)
+     - Projector (MLP, MOSA, MoE, QFormer)
+     - Generation parameters
+     - Training options (SpecAugment, LoRA)
+     """
+
+     model_type = "asr_model"
+     is_composition = True
+
+     def __init__(
+         self,
+         audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
+         text_model_id: str = "Qwen/Qwen3-0.6B",
+         attn_implementation: str = "flash_attention_2",
+         model_dtype: str = "bfloat16",
+         num_beams: Optional[int] = None,
+         system_prompt: str = "You are a helpful assistant.",
+         encoder_dim: Optional[int] = None,
+         llm_dim: Optional[int] = None,
+         # Encoder conv layers: list of (padding, kernel_size, stride) tuples
+         # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
+         encoder_conv_layers: Optional[list] = None,
+         audio_sample_rate: int = 16000,
+         projector_pool_stride: int = 4,
+         downsample_rate: int = 5,  # Granite default
+         projector_hidden_dim: Optional[int] = None,
+         projector_type: str = "mlp",  # "mlp", "mosa", "moe", "qformer"
+         projector_num_layers: int = 2,  # Number of layers in MLP projector
+         projector_init_std: float = 0.02,  # Weight initialization std
+         projector_dropout: float = 0.0,  # Dropout rate for projector layers
+         # MoE-specific configuration
+         num_experts: int = 4,  # Number of experts in MoE projectors
+         num_experts_per_tok: int = 2,  # Top-k experts per token
+         router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
+         # QFormer-specific configuration (Granite defaults)
+         qformer_window_size: int = 15,  # Window size for QFormer processing
+         qformer_hidden_size: Optional[int] = None,  # QFormer hidden size (defaults to encoder_dim)
+         qformer_num_layers: int = 2,  # Number of QFormer transformer layers
+         qformer_num_heads: int = 16,  # Number of attention heads in QFormer
+         qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
+         label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
+         inference_warmup_tokens: int = 10,
+         # SpecAugment settings
+         use_specaugment: bool = False,
+         num_time_masks: int = 2,
+         time_mask_length: int = 10,
+         num_freq_masks: int = 0,
+         freq_mask_length: int = 10,
+         # LoRA configuration (for Stage 2 fine-tuning)
+         use_lora: bool = False,
+         lora_rank: int = 8,  # SALMONN default
+         lora_alpha: int = 32,  # SALMONN default (scaling factor 4.0)
+         lora_dropout: float = 0.0,
+         lora_target_modules: Optional[list] = None,  # Default: attention and MLP projection layers
+         freeze_projector: bool = False,  # True for Stage 2 (LoRA-only training)
+         do_sample: bool = False,
+         temperature: Optional[float] = None,
+         top_p: Optional[float] = None,
+         top_k: Optional[int] = None,
+         max_new_tokens: Optional[int] = None,
+         min_new_tokens: Optional[int] = None,
+         repetition_penalty: Optional[float] = None,
+         length_penalty: Optional[float] = None,
+         no_repeat_ngram_size: Optional[int] = None,
+         use_cache: Optional[bool] = None,
+         **kwargs,
+     ):
+         """Initialize ASR model configuration.
+
+         Args:
+             audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
+             text_model_id: HuggingFace model ID for text decoder (Qwen)
+             attn_implementation: Attention implementation ("flash_attention_2", "sdpa", "eager")
+             model_dtype: Model dtype ("bfloat16", "float16", "float32")
+             projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
+             use_lora: Enable LoRA adapters for Stage 2 fine-tuning
+             use_specaugment: Enable SpecAugment data augmentation
+         """
+         # Set default generation parameters (greedy decoding only)
+         generation_defaults = {
+             "num_beams": 1,
+             "max_new_tokens": 128,
+             "min_new_tokens": 0,
+             "repetition_penalty": 1.0,
+             "length_penalty": 1.0,
+             "no_repeat_ngram_size": 0,  # 0 disables n-gram repetition blocking
+             "use_cache": True,
+         }
+
+         # Apply defaults (config.json values take precedence)
+         kwargs = {**generation_defaults, **kwargs}
+
+         self.audio_model_id = audio_model_id
+         self.text_model_id = text_model_id
+         self.attn_implementation = attn_implementation
+         self.model_dtype = model_dtype
+         self.system_prompt = system_prompt
+         self.encoder_dim = encoder_dim
+         self.llm_dim = llm_dim
+         # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
+         self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
+         self.audio_sample_rate = audio_sample_rate
+         self.projector_init_std = projector_init_std
+         self.projector_pool_stride = projector_pool_stride
+         self.downsample_rate = downsample_rate
+         self.projector_hidden_dim = projector_hidden_dim
+         self.projector_type = projector_type
+         self.projector_num_layers = projector_num_layers
+         self.projector_dropout = projector_dropout
+         # MoE-specific configuration
+         self.num_experts = num_experts
+         self.num_experts_per_tok = num_experts_per_tok
+         self.router_aux_loss_coef = router_aux_loss_coef
+         # QFormer-specific configuration
+         self.qformer_window_size = qformer_window_size
+         self.qformer_hidden_size = qformer_hidden_size
+         self.qformer_num_layers = qformer_num_layers
+         self.qformer_num_heads = qformer_num_heads
+         self.qformer_intermediate_size = qformer_intermediate_size
+         self.label_smoothing = label_smoothing
+         self.inference_warmup_tokens = inference_warmup_tokens
+         # SpecAugment configuration
+         self.use_specaugment = use_specaugment
+         self.num_time_masks = num_time_masks
+         self.time_mask_length = time_mask_length
+         self.num_freq_masks = num_freq_masks
+         self.freq_mask_length = freq_mask_length
+         # LoRA configuration
+         self.use_lora = use_lora
+         self.lora_rank = lora_rank
+         self.lora_alpha = lora_alpha
+         self.lora_dropout = lora_dropout
+         self.lora_target_modules = lora_target_modules or [
+             "q_proj",
+             "k_proj",
+             "v_proj",
+             "o_proj",
+             "gate_proj",
+             "up_proj",
+             "down_proj",
+         ]
+         self.freeze_projector = freeze_projector
+
+         # Generation parameters (use explicit value if provided, else use default)
+         self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
+         self.max_new_tokens = (
+             max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
+         )
+         self.min_new_tokens = (
+             min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
+         )
+         self.repetition_penalty = (
+             repetition_penalty
+             if repetition_penalty is not None
+             else generation_defaults["repetition_penalty"]
+         )
+         self.length_penalty = (
+             length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
+         )
+         self.no_repeat_ngram_size = (
+             no_repeat_ngram_size
+             if no_repeat_ngram_size is not None
+             else generation_defaults["no_repeat_ngram_size"]
+         )
+         self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
+         self.do_sample = do_sample
+         self.temperature = temperature
+         self.top_p = top_p
+         self.top_k = top_k
+
+         if "audio_config" not in kwargs:
+             self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
+             # Override dtype to match model_dtype
+             self.audio_config.dtype = model_dtype
+         else:
+             self.audio_config = kwargs.pop("audio_config")
+
+         if "text_config" not in kwargs:
+             self.text_config = transformers.AutoConfig.from_pretrained(
+                 text_model_id, trust_remote_code=True
+             )
+             # Override dtype to match model_dtype
+             self.text_config.dtype = model_dtype
+         else:
+             self.text_config = kwargs.pop("text_config")
+
+         if isinstance(self.text_config, dict):
+             # Reconstruct config from dict using the model_type stored in the dict
+             model_type = self.text_config["model_type"]
+             config_class = transformers.AutoConfig.for_model(model_type).__class__
+             self.text_config = config_class(**self.text_config)
+
+         if isinstance(self.audio_config, dict):
+             model_type = self.audio_config.get("model_type")
+             if model_type:
+                 config_class = transformers.AutoConfig.for_model(model_type).__class__
+                 self.audio_config = config_class(**self.audio_config)
+
+         super().__init__(**kwargs)
+
+         # Point encoder to audio_config so pipeline uses correct feature extractor
+         # The pipeline looks for config.encoder._name_or_path for feature extractor
+         self.encoder = self.audio_config
+
+         self.auto_map = {
+             "AutoConfig": "asr_config.ASRConfig",
+             "AutoModel": "asr_modeling.ASRModel",
+             "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
+             "AutoProcessor": "asr_processing.ASRProcessor",
+         }
+         self.custom_pipelines = {
+             "automatic-speech-recognition": {
+                 "impl": "asr_pipeline.ASRPipeline",
+                 "pt": ["AutoModelForSpeechSeq2Seq"],
+                 "tf": [],
+                 "type": "audio",
+             }
+         }
+         self.architectures = ["ASRModel"]
+         self.pipeline_tag = "automatic-speech-recognition"
+
+
+ transformers.AutoConfig.register("asr_model", ASRConfig)
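
Three pieces of arithmetic in `asr_config.py` can be checked in isolation: the "explicit value, else default" rule for generation parameters, the LoRA scaling factor quoted in the comments (`lora_alpha / lora_rank`), and the length reduction implied by the default `encoder_conv_layers`. The sketch below is not part of the commit. The helper names `resolve_generation_params` and `conv_output_length` are hypothetical, the conv formula is the standard 1-D convolution length formula with dilation 1, and the 3000-frame input is an assumption (a Whisper-style 30 s window); how the model itself consumes `encoder_conv_layers` is not shown in this diff.

```python
GENERATION_DEFAULTS = {
    "num_beams": 1,
    "max_new_tokens": 128,
    "min_new_tokens": 0,
    "repetition_penalty": 1.0,
    "length_penalty": 1.0,
    "no_repeat_ngram_size": 0,
    "use_cache": True,
}
DEFAULT_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]  # (padding, kernel_size, stride)


def resolve_generation_params(explicit, defaults):
    """Use an explicitly provided value when it is not None, else the default."""
    return {
        name: explicit[name] if explicit.get(name) is not None else default
        for name, default in defaults.items()
    }


def conv_output_length(num_frames, conv_layers):
    """Apply floor((L + 2p - k) / s) + 1 for each (padding, kernel, stride) layer."""
    for padding, kernel_size, stride in conv_layers:
        num_frames = (num_frames + 2 * padding - kernel_size) // stride + 1
    return num_frames


# An explicit num_beams wins; an explicit None falls back to the default.
params = resolve_generation_params(
    {"num_beams": 4, "max_new_tokens": None}, GENERATION_DEFAULTS
)

# LoRA update scaling with the config defaults: alpha / rank.
lora_scaling = 32 / 8

# Assumed input of 3000 frames through the two default conv layers.
encoder_frames = conv_output_length(3000, DEFAULT_CONV_LAYERS)

if __name__ == "__main__":
    print(params["num_beams"], params["max_new_tokens"])  # 4 128
    print(lora_scaling)    # 4.0
    print(encoder_frames)  # 1500
```

The first conv layer (stride 1) preserves the length and the second (stride 2) halves it, which is why 3000 input frames become 1500 encoder frames under these assumptions.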
asr_modeling.py ADDED
@@ -0,0 +1,906 @@
+ import json
+ from pathlib import Path
+ from threading import Thread
+ from typing import Iterator, Optional, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     PreTrainedModel,
+     TextIteratorStreamer,
+ )
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ try:
+     from .asr_config import ASRConfig
+     from .projectors import PROJECTOR_CLASSES
+ except ImportError:
+     from asr_config import ASRConfig  # type: ignore[no-redef]
+     from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]
+
+
+ from torchaudio.transforms import SpecAugment
+
+
+ class ASRModel(PreTrainedModel, GenerationMixin):
+     """Audio-to-text model combining an audio encoder, projector, and language model."""
+
+     config_class = ASRConfig
+     base_model_prefix = "model"
+     main_input_name = "input_features"
+     _supports_flash_attn_2 = True
+     supports_gradient_checkpointing = True
+     _is_loading_from_pretrained: bool = False
+     _pretrained_model_path: Optional[str] = None
+
+     TRANSCRIBE_PROMPT = ""
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
+         """Load model from pretrained, handling device placement correctly."""
+         from safetensors.torch import load_file
+         from transformers.utils.hub import cached_file
+
+         config = kwargs.pop("config", None)
+         if config is None:
+             config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+         # Set flag to avoid device_map="auto" in sub-model loaders
+         cls._is_loading_from_pretrained = True
+         cls._pretrained_model_path = pretrained_model_name_or_path
+
+         try:
+             model = cls(config, **kwargs)
+
+             # Load projector weights from safetensors
+             subfolder = kwargs.get("subfolder")
+             revision = kwargs.get("revision")
+             cache_kwargs = {}
+             if subfolder:
+                 cache_kwargs["subfolder"] = subfolder
+             if revision:
+                 cache_kwargs["revision"] = revision
+
+             model_file = cached_file(
+                 pretrained_model_name_or_path,
+                 "model.safetensors",
+                 _raise_exceptions_for_missing_entries=False,
+                 **cache_kwargs,
+             )
+
+             if model_file is not None:
+                 state_dict = load_file(model_file)
+                 model.load_state_dict(state_dict, strict=False)
+
+             # Load LoRA adapters if use_lora is enabled
+             if getattr(config, "use_lora", False):
+                 # Check for adapter_config.json (required by PEFT to load adapters)
+                 adapter_config_file = cached_file(
+                     pretrained_model_name_or_path,
+                     "adapter_config.json",
+                     _raise_exceptions_for_missing_entries=False,
+                     **cache_kwargs,
+                 )
+                 if adapter_config_file is not None:
+                     # Load saved adapter weights using the original repo_id/path
+                     # PEFT handles Hub downloads and caching internally
+                     from peft import PeftModel
+
+                     model.language_model = PeftModel.from_pretrained(
+                         model.language_model,
+                         pretrained_model_name_or_path,
+                         is_trainable=True,
+                         **cache_kwargs,
+                     )
+                 else:
+                     # No saved adapters - initialize fresh LLM LoRA for training
+                     from peft import LoraConfig, get_peft_model
+
+                     lora_config = LoraConfig(
+                         r=config.lora_rank,
+                         lora_alpha=config.lora_alpha,
+                         target_modules=config.lora_target_modules,
+                         lora_dropout=config.lora_dropout,
+                         bias="none",
+                         task_type="CAUSAL_LM",
+                     )
+                     model.language_model = get_peft_model(model.language_model, lora_config)
+
+             return model
+         finally:
+             cls._is_loading_from_pretrained = False
+             cls._pretrained_model_path = None
118
+
119
+ def __init__(self, config: ASRConfig, **kwargs) -> None:
120
+ super().__init__(config)
121
+
122
+ self.system_prompt = config.system_prompt
123
+ target_dtype = getattr(torch, config.model_dtype)
124
+
125
+ # Audio encoder (frozen)
126
+ self.audio_tower = self._load_audio_encoder(config, target_dtype)
127
+
128
+ # Language model (frozen)
129
+ self.language_model = self._load_language_model(config, target_dtype)
130
+
131
+ # Initialize tokenizer and special tokens
132
+ self._init_tokenizer(config)
133
+
134
+ # Set up generation config with greedy decoding defaults
135
+ self.generation_config = self.language_model.generation_config
136
+ self.generation_config.max_new_tokens = config.max_new_tokens
137
+ self.generation_config.min_new_tokens = config.min_new_tokens
138
+ self.generation_config.num_beams = config.num_beams
139
+ self.generation_config.do_sample = config.do_sample
140
+ # Set sampling params from config (None means use model defaults)
141
+ self.generation_config.temperature = config.temperature
142
+ self.generation_config.top_p = config.top_p
143
+ self.generation_config.top_k = config.top_k
144
+ self.generation_config.use_cache = config.use_cache
145
+ self.generation_config.length_penalty = config.length_penalty
146
+ self.generation_config.repetition_penalty = config.repetition_penalty
147
+ self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
148
+ # Set EOS tokens, filtering out any that don't exist in the tokenizer
149
+ eos_candidates = [
150
+ self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
151
+ self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
152
+ ]
153
+ self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None]
154
+ self.generation_config.pad_token_id = self.tokenizer.pad_token_id
155
+
156
+ # Feature extractor for audio preprocessing
157
+ self.feature_extractor = self._create_feature_extractor(config)
158
+
159
+ # Audio projector (trainable unless freeze_projector is set)
160
+ self.projector = self._create_projector(config, target_dtype)
161
+
162
+ # Setup LoRA if enabled (Stage 2 fine-tuning)
163
+ # Skip if loading from pretrained - from_pretrained will handle adapter loading
164
+ if getattr(config, "use_lora", False) and not getattr(
165
+ self.__class__, "_is_loading_from_pretrained", False
166
+ ):
167
+ self._setup_lora(config)
168
+
169
+ # Freeze projector if specified (for Stage 2 LoRA-only training)
170
+ if getattr(config, "freeze_projector", False):
171
+ self.projector.requires_grad_(False)
172
+
173
+ # SpecAugment for data augmentation during training
174
+ if getattr(config, "use_specaugment", False):
175
+ self.spec_augment = SpecAugment(
176
+ n_time_masks=config.num_time_masks,
177
+ time_mask_param=config.time_mask_length,
178
+ n_freq_masks=config.num_freq_masks,
179
+ freq_mask_param=config.freq_mask_length,
180
+ )
181
+ else:
182
+ self.spec_augment = None
183
+
184
+ # For model parallelism
185
+ self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
186
+
187
+ def _create_feature_extractor(self, config: ASRConfig):
188
+ """Create the appropriate feature extractor for the audio encoder."""
189
+ from transformers import AutoFeatureExtractor
190
+
191
+ feature_extractor = AutoFeatureExtractor.from_pretrained(config.audio_model_id)
192
+ # Whisper's encoder requires a fixed 3000 mel frames (30s) and the
193
+ # feature extractor pads to that by default — leave it alone. Other
194
+ # encoders (e.g. GLM-ASR) accept variable-length input, so we disable
195
+ # padding to avoid wasting compute on silent frames.
196
+ if "whisper" not in config.audio_model_id.lower():
197
+ feature_extractor.padding = False
198
+ return feature_extractor
199
+
200
+ @classmethod
201
+ def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
202
+ """Load and freeze the audio encoder."""
203
+ encoder_kwargs = {
204
+ "attn_implementation": config.attn_implementation,
205
+ "low_cpu_mem_usage": True,
206
+ "dtype": dtype,
207
+ }
208
+
209
+ if "whisper" in config.audio_model_id.lower():
210
+ from transformers import WhisperModel
211
+
212
+ full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
213
+ encoder = full_model.encoder
214
+ del full_model
215
+ elif "glm" in config.audio_model_id.lower():
216
+ # GLM-ASR models use audio_tower as the encoder
217
+ # Requires transformers >= 5.x or installed from source
218
+ from transformers import AutoModelForSeq2SeqLM
219
+
220
+ full_model = AutoModelForSeq2SeqLM.from_pretrained(
221
+ config.audio_model_id, trust_remote_code=True, **encoder_kwargs
222
+ )
223
+ # GLM stores encoder at audio_tower (GlmAsrEncoder)
224
+ encoder = full_model.audio_tower
225
+ # Clear references to free VRAM from the LLM decoder
226
+ full_model.language_model = None
227
+ full_model.multi_modal_projector = None
228
+ del full_model
229
+ else:
230
+ encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
231
+
232
+ encoder.requires_grad_(False)
233
+ encoder.eval()
234
+ return encoder
235
+
236
+ @classmethod
237
+ def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
238
+ """Load and freeze the language model."""
239
+ decoder_kwargs = {
240
+ "attn_implementation": config.attn_implementation,
241
+ "trust_remote_code": True,
242
+ "low_cpu_mem_usage": True,
243
+ "dtype": dtype,
244
+ }
245
+
246
+ decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
247
+ decoder.config.use_cache = getattr(config, "use_cache", True)
248
+ decoder.requires_grad_(False)
249
+ decoder.eval()
250
+ return decoder
251
+
252
+ def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
253
+ """Create the trainable audio projector."""
254
+ # Auto-detect dimensions if not specified
255
+ if config.encoder_dim is None:
256
+ enc_cfg = self.audio_tower.config
257
+ config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
258
+ enc_cfg, "d_model", None
259
+ )
260
+ if config.encoder_dim is None:
261
+ raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
262
+
263
+ if config.llm_dim is None:
264
+ dec_cfg = self.language_model.config
265
+ config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
266
+ dec_cfg, "d_model", None
267
+ )
268
+ if config.llm_dim is None:
269
+ raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
270
+
271
+ # Select projector type based on config
272
+ projector_type = getattr(config, "projector_type", "mlp")
273
+ projector_class = PROJECTOR_CLASSES.get(projector_type)
274
+ if projector_class is None:
275
+ raise ValueError(
276
+ f"Unknown projector_type: {projector_type}. "
277
+ f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
278
+ )
279
+ projector = projector_class(config)
280
+
281
+ # Move projector to same device as language model (important when using quantization)
282
+ device = next(self.language_model.parameters()).device
283
+ return projector.to(device=device, dtype=dtype)
284
+
285
+ def _setup_lora(self, config: ASRConfig):
286
+ """Apply LoRA adapters to the language model for Stage 2 fine-tuning."""
287
+ from peft import LoraConfig, get_peft_model
288
+
289
+ lora_config = LoraConfig(
290
+ r=config.lora_rank,
291
+ lora_alpha=config.lora_alpha,
292
+ target_modules=config.lora_target_modules,
293
+ lora_dropout=config.lora_dropout,
294
+ bias="none",
295
+ task_type="CAUSAL_LM",
296
+ )
297
+ self.language_model = get_peft_model(self.language_model, lora_config)
298
+
299
+ def _init_tokenizer(self, config: ASRConfig):
300
+ """Initialize tokenizer with audio token."""
301
+ self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
302
+
303
+ # Set pad token. Prefer a dedicated pad token if the tokenizer has one
304
+ # (e.g. Llama 3.1's <|finetune_right_pad_id|>); otherwise fall back to
305
+ # eos_token, which is the standard pattern for Llama-style tokenizers
306
+ # (SmolLM2, Llama, etc.) that ship without a separate pad token.
307
+ if (
308
+ self.tokenizer.pad_token is None
309
+ or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
310
+ ):
311
+ if "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
312
+ self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
313
+ elif self.tokenizer.pad_token is None:
314
+ self.tokenizer.pad_token = self.tokenizer.eos_token
315
+
316
+ # Add audio token
317
+ existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
318
+ if "<audio>" not in existing_special:
319
+ self.tokenizer.add_special_tokens(
320
+ {"additional_special_tokens": existing_special + ["<audio>"]}
321
+ )
322
+ self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
323
+
324
+ self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
325
+ self.tokenizer.padding_side = "right"
326
+
327
+ # Sync token IDs to configs
328
+ for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
329
+ if cfg is not None:
330
+ cfg.pad_token_id = self.tokenizer.pad_token_id
331
+ cfg.eos_token_id = self.tokenizer.eos_token_id
332
+ cfg.bos_token_id = self.tokenizer.bos_token_id
333
+
334
+ def _init_weights(self, _module):
335
+ """Weight initialization (projector weights are initialized in MoEAudioProjector)."""
336
+ pass
337
+
338
+ def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
339
+ """Enable/disable gradient checkpointing for the language model."""
340
+ # The LLM still stores activations during forward for backprop to projector
341
+ # Gradient checkpointing trades compute for memory by recomputing activations
342
+ if hasattr(self.language_model, "_set_gradient_checkpointing"):
343
+ self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
344
+ elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
345
+ self.language_model.gradient_checkpointing_enable(
346
+ gradient_checkpointing_kwargs={"use_reentrant": False}
347
+ )
348
+ elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
349
+ self.language_model.gradient_checkpointing_disable()
350
+
351
+ def get_input_embeddings(self) -> nn.Module:
352
+ return self.language_model.get_input_embeddings()
353
+
354
+ def set_input_embeddings(self, value: nn.Module) -> None:
355
+ self.language_model.set_input_embeddings(value)
356
+
357
+ def get_output_embeddings(self) -> nn.Module:
358
+ return self.language_model.get_output_embeddings()
359
+
360
+ def set_output_embeddings(self, value: nn.Module) -> None:
361
+ self.language_model.set_output_embeddings(value)
362
+
363
+ def get_processor(self):
364
+ """Get the processor for this model."""
365
+ try:
366
+ from .asr_processing import ASRProcessor
367
+ except ImportError:
368
+ from asr_processing import ASRProcessor # type: ignore[no-redef]
369
+
370
+ return ASRProcessor(
371
+ feature_extractor=self.feature_extractor,
372
+ tokenizer=self.tokenizer,
373
+ projector=self.projector,
374
+ encoder_conv_layers=self.config.encoder_conv_layers,
375
+ )
376
+
377
+ def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
378
+ """Only save trainable projector weights."""
379
+ return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
380
+
381
+ def _compute_encoder_output_lengths(
382
+ self,
383
+ audio_attention_mask: torch.Tensor,
384
+ ) -> torch.Tensor:
385
+ """Compute per-sample encoder output lengths using conv layer formulas.
386
+
387
+ Args:
388
+ audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
389
+
390
+ Returns:
391
+ Tensor of encoder output lengths per sample (batch,)
392
+ """
393
+ # Get mel frame lengths from attention mask
394
+ lengths = audio_attention_mask.sum(dim=-1)
395
+
396
+ # Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
397
+ for padding, kernel_size, stride in self.config.encoder_conv_layers:
398
+ lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
399
+
400
+ return lengths
401
+
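The length formula in `_compute_encoder_output_lengths` can be checked by hand with plain integers. The sketch below is illustrative only: `conv_output_lengths` and `WHISPER_LIKE_LAYERS` are hypothetical names, and the `(padding, kernel_size, stride)` values are an assumption modelled on a Whisper-style encoder (two kernel-3 convolutions, the second with stride 2), not values read from this repository's config.

```python
def conv_output_lengths(lengths, conv_layers):
    """Apply the 1-D convolution length formula once per layer.

    Each layer is a (padding, kernel_size, stride) tuple, in the same
    order the model iterates over config.encoder_conv_layers.
    """
    out = list(lengths)
    for padding, kernel_size, stride in conv_layers:
        out = [(n + 2 * padding - (kernel_size - 1) - 1) // stride + 1 for n in out]
    return out


# Assumed Whisper-style stack; the real values come from the model config.
WHISPER_LIKE_LAYERS = [(1, 3, 1), (1, 3, 2)]

# 3000 mel frames (30 s), 1500 frames (15 s), and a short 101-frame clip.
print(conv_output_lengths([3000, 1500, 101], WHISPER_LIKE_LAYERS))
```

Under these assumed layers the first convolution preserves length and the second roughly halves it, which is why a 30-second input of 3000 mel frames ends up as 1500 encoder frames.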
402
+ def _encode_audio(
403
+ self,
404
+ audio_features: torch.Tensor,
405
+ audio_attention_mask: torch.Tensor,
406
+ expected_token_counts: torch.Tensor | None = None,
407
+ ) -> torch.Tensor:
408
+ """Encode audio and project to LLM embedding space.
409
+
410
+ Args:
411
+ audio_features: Mel spectrogram features (batch, n_mels, mel_len)
412
+ audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
413
+ expected_token_counts: Expected number of audio tokens per sample from input_ids.
414
+ If provided, output will match these counts exactly (padding/truncating as needed).
415
+
416
+ Returns:
417
+ Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
418
+ """
419
+ with torch.no_grad():
420
+ encoder_out = self.audio_tower(input_features=audio_features)
421
+ hidden_states = encoder_out.last_hidden_state
422
+
423
+ # Project to LLM space
424
+ audio_embeds = self.projector(hidden_states)
425
+
426
+ # Use expected token counts if provided (from input_ids), otherwise compute from audio
427
+ if expected_token_counts is not None:
428
+ token_counts = expected_token_counts
429
+ else:
430
+ # Compute per-sample encoder output lengths using conv formulas
431
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
432
+ token_counts = torch.tensor(
433
+ [
434
+ self.projector.get_output_length(int(length.item()))
435
+ for length in encoder_lengths
436
+ ],
437
+ device=audio_embeds.device,
438
+ )
439
+
440
+ # Extract embeddings matching expected token counts per sample
441
+ batch_size = audio_embeds.shape[0]
442
+ hidden_dim = audio_embeds.shape[2]
443
+
444
+ result_embeds = []
445
+ for i in range(batch_size):
446
+ count = int(token_counts[i].item())
447
+ sample_embeds = audio_embeds[i, :count, :] # Take first 'count' embeddings
448
+ # Pad with zeros if we don't have enough embeddings
449
+ if sample_embeds.shape[0] < count:
450
+ padding = torch.zeros(
451
+ count - sample_embeds.shape[0],
452
+ hidden_dim,
453
+ device=audio_embeds.device,
454
+ dtype=audio_embeds.dtype,
455
+ )
456
+ sample_embeds = torch.cat([sample_embeds, padding], dim=0)
457
+ result_embeds.append(sample_embeds)
458
+
459
+ return torch.cat(result_embeds, dim=0)
460
+
461
+ def forward(
462
+ self,
463
+ input_ids: Optional[torch.Tensor] = None,
464
+ input_features: Optional[torch.Tensor] = None,
465
+ audio_attention_mask: Optional[torch.Tensor] = None,
466
+ attention_mask: Optional[torch.Tensor] = None,
467
+ position_ids: Optional[torch.Tensor] = None,
468
+ past_key_values: Optional[torch.Tensor] = None,
469
+ inputs_embeds: Optional[torch.Tensor] = None,
470
+ labels: Optional[torch.Tensor] = None,
471
+ use_cache: Optional[bool] = None,
472
+ cache_position: Optional[torch.Tensor] = None,
473
+ **kwargs,
474
+ ) -> CausalLMOutputWithPast:
475
+ """Forward pass for training and inference."""
476
+ # Get text embeddings if not provided
477
+ if inputs_embeds is None:
478
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
479
+
480
+ if input_features is not None and input_ids is not None:
481
+ # Apply SpecAugment during training if enabled
482
+ if self.training and self.spec_augment is not None:
483
+ input_features = self.spec_augment(input_features)
484
+
485
+ # Count expected audio tokens from input_ids (ground truth from collator)
486
+ audio_token_counts = (input_ids == self.audio_token_id).sum(dim=-1)
487
+
488
+ # Encode audio -> flattened (total_audio_tokens, hidden_dim)
489
+ audio_embeds = self._encode_audio(
490
+ input_features, audio_attention_mask, audio_token_counts
491
+ )
492
+
493
+ # Replace <audio> token placeholders with audio embeddings using masked_scatter
494
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
495
+ inputs_embeds = inputs_embeds.masked_scatter(
496
+ audio_token_mask.to(inputs_embeds.device),
497
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
498
+ )
499
+
500
+ # Run through language model (let it compute loss if labels provided)
501
+ outputs = self.language_model(
502
+ attention_mask=attention_mask,
503
+ position_ids=position_ids,
504
+ past_key_values=past_key_values,
505
+ inputs_embeds=inputs_embeds,
506
+ labels=labels,
507
+ use_cache=use_cache,
508
+ cache_position=cache_position,
509
+ **kwargs,
510
+ )
511
+
512
+ # Add auxiliary loss from MoE projectors if available
513
+ if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
514
+ aux_loss = self.projector.get_aux_loss()
515
+ if aux_loss is not None and aux_loss.numel() > 0:
516
+ outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)
517
+
518
+ return outputs
519
+
520
+ def prepare_inputs_for_generation(self, *args, **kwargs):
521
+ """Prepare inputs for generation, handling audio features for cached decoding."""
522
+ input_features = kwargs.pop("input_features", None)
523
+ cache_position = kwargs.get("cache_position")
524
+
525
+ model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)
526
+
527
+ # Only pass audio features on the first generation step (cache_position[0] == 0)
528
+ if cache_position is not None and cache_position[0] == 0 and input_features is not None:
529
+ model_inputs["input_features"] = input_features
530
+
531
+ return model_inputs
532
+
533
+ def _get_num_audio_tokens(
534
+ self,
535
+ audio_attention_mask: torch.Tensor,
536
+ ) -> int:
537
+ """Calculate number of audio tokens based on actual audio length.
538
+
539
+ Uses attention mask to get real audio length, then computes:
540
+ mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
541
+ """
542
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
543
+ # Use max length for batch (all samples should have same token count for generation)
544
+ encoder_output_len = int(encoder_lengths.max().item())
545
+ return int(self.projector.get_output_length(encoder_output_len))
546
+
547
+ @torch.no_grad()
548
+ def generate(
549
+ self,
550
+ input_ids: Optional[torch.Tensor] = None,
551
+ input_features: Optional[torch.Tensor] = None,
552
+ audio_attention_mask: Optional[torch.Tensor] = None,
553
+ attention_mask: Optional[torch.Tensor] = None,
554
+ system_prompt: Optional[str] = None,
555
+ **generate_kwargs,
556
+ ) -> torch.Tensor:
557
+ """Generate transcription from audio input.
558
+
559
+ Can be called in two ways:
560
+ 1. With input_ids containing <audio> tokens (from processor)
561
+ 2. With just audio, and we build the prompt internally
562
+ """
563
+ if input_features is None:
564
+ raise ValueError("input_features required for generation")
565
+ if audio_attention_mask is None:
566
+ raise ValueError("audio_attention_mask required for generation")
567
+
568
+ device = input_features.device
569
+ batch_size = input_features.shape[0]
570
+
571
+ # Encode audio -> flattened embeddings
572
+ audio_embeds = self._encode_audio(input_features, audio_attention_mask)
573
+
574
+ # If input_ids not provided, build prompt with correct number of audio tokens
575
+ if input_ids is None:
576
+ num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
577
+ audio_placeholder = "<audio>" * num_audio_tokens
578
+
579
+ system_prompt = system_prompt or self.system_prompt
580
+
581
+ messages: list[dict[str, str]] = []
582
+ if system_prompt:
583
+ messages.append({"role": "system", "content": system_prompt})
584
+ # Audio placeholders, optionally followed by TRANSCRIBE_PROMPT
585
+ user_content = audio_placeholder
586
+ if self.TRANSCRIBE_PROMPT:
587
+ user_content += " " + self.TRANSCRIBE_PROMPT
588
+ messages.append({"role": "user", "content": user_content})
589
+
590
+ chat_result = self.tokenizer.apply_chat_template(
591
+ messages,
592
+ tokenize=True,
593
+ add_generation_prompt=True,
594
+ return_tensors="pt", return_dict=True, # BatchEncoding, so .input_ids exists
595
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
596
+ )
597
+ input_ids = chat_result.input_ids.to(device)
598
+
599
+ if input_ids.dim() == 1:
600
+ input_ids = input_ids.unsqueeze(0)
601
+ if input_ids.shape[0] == 1 and batch_size > 1:
602
+ input_ids = input_ids.expand(batch_size, -1)
603
+
604
+ attention_mask = torch.ones_like(input_ids)
605
+
606
+ # Get text embeddings and replace audio tokens with audio embeddings
607
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
608
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
609
+ inputs_embeds = inputs_embeds.masked_scatter(
610
+ audio_token_mask.to(inputs_embeds.device),
611
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
612
+ )
613
+
614
+ # Generate using language model
615
+ # Pass both input_ids and inputs_embeds so repetition_penalty works correctly
616
+ # (it needs input_ids to track which tokens have been used)
617
+ output = self.language_model.generate(
618
+ input_ids=input_ids,
619
+ inputs_embeds=inputs_embeds,
620
+ attention_mask=attention_mask,
621
+ generation_config=self.generation_config,
622
+ **generate_kwargs,
623
+ )
624
+
625
+ # When using inputs_embeds with input_ids, generate returns full sequence
626
+ # Strip the input tokens to return only generated tokens
627
+ sequences = output if isinstance(output, torch.Tensor) else output.sequences
628
+ input_len = input_ids.shape[1]
629
+ return sequences[:, input_len:]
630
+
631
+ def generate_streaming(
632
+ self,
633
+ input_features: torch.Tensor,
634
+ audio_attention_mask: torch.Tensor,
635
+ system_prompt: Optional[str] = None,
636
+ **generate_kwargs,
637
+ ) -> Iterator[str]:
638
+ """Generate transcription with streaming token output.
639
+
640
+ Yields partial transcript strings as tokens are generated.
641
+ Reduces time-to-first-word by streaming tokens as they're decoded.
642
+
643
+ Args:
644
+ input_features: Mel spectrogram features (batch, n_mels, mel_len)
645
+ audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
646
+ system_prompt: Optional system prompt override
647
+ **generate_kwargs: Additional generation arguments
648
+
649
+ Yields:
650
+ Partial transcript text as each token is generated
651
+ """
652
+ device = input_features.device
653
+ batch_size = input_features.shape[0]
654
+
655
+ # Encode audio -> flattened embeddings
656
+ audio_embeds = self._encode_audio(input_features, audio_attention_mask)
657
+
658
+ # Build prompt with correct number of audio tokens
659
+ num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
660
+ audio_placeholder = "<audio>" * num_audio_tokens
661
+
662
+ system_prompt = system_prompt or self.system_prompt
663
+
664
+ messages: list[dict[str, str]] = []
665
+ if system_prompt:
666
+ messages.append({"role": "system", "content": system_prompt})
667
+ # Audio placeholders, optionally followed by TRANSCRIBE_PROMPT
668
+ user_content = audio_placeholder
669
+ if self.TRANSCRIBE_PROMPT:
670
+ user_content += " " + self.TRANSCRIBE_PROMPT
671
+ messages.append({"role": "user", "content": user_content})
672
+
673
+ chat_result = self.tokenizer.apply_chat_template(
674
+ messages,
675
+ tokenize=True,
676
+ add_generation_prompt=True,
677
+ return_tensors="pt", return_dict=True, # BatchEncoding, so .input_ids exists
678
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
679
+ )
680
+ input_ids = chat_result.input_ids.to(device)
681
+
682
+ if input_ids.dim() == 1:
683
+ input_ids = input_ids.unsqueeze(0)
684
+ if input_ids.shape[0] == 1 and batch_size > 1:
685
+ input_ids = input_ids.expand(batch_size, -1)
686
+
687
+ attention_mask = torch.ones_like(input_ids)
688
+
689
+ # Get text embeddings and replace audio tokens with audio embeddings
690
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
691
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
692
+ inputs_embeds = inputs_embeds.masked_scatter(
693
+ audio_token_mask.to(inputs_embeds.device),
694
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
695
+ )
696
+
697
+ # Setup streamer for token-by-token output
698
+ streamer = TextIteratorStreamer(
699
+ self.tokenizer,
700
+ skip_prompt=True,
701
+ skip_special_tokens=True,
702
+ )
703
+
704
+ # Prepare generation kwargs
705
+ gen_kwargs = {
706
+ "inputs_embeds": inputs_embeds,
707
+ "attention_mask": attention_mask,
708
+ "generation_config": self.generation_config,
709
+ "streamer": streamer,
710
+ **generate_kwargs,
711
+ }
712
+
713
+ # Run generation in background thread
714
+ thread = Thread(target=self.language_model.generate, kwargs=gen_kwargs)
715
+ thread.start()
716
+
717
+ # Yield tokens as they're generated, filtering out <think>...</think> blocks
718
+ # Start assuming no think block - only filter when we see <think>
719
+ in_think_block = False
720
+ buffer = ""
721
+
722
+ for text in streamer:
723
+ buffer += text
724
+
725
+ # Check for think block start (in case model outputs think blocks)
726
+ while "<think>" in buffer:
727
+ in_think_block = True
728
+ # Yield any text before <think>
729
+ before_think = buffer.split("<think>")[0]
730
+ if before_think:
731
+ yield before_think
732
+ buffer = buffer.split("<think>", 1)[-1]
733
+
734
+ # Check for think block end
735
+ while in_think_block and "</think>" in buffer:
736
+ in_think_block = False
737
+ buffer = buffer.split("</think>", 1)[-1]
738
+
739
+ # Yield text if not in think block
740
+ if not in_think_block and buffer:
741
+ yield buffer
742
+ buffer = ""
743
+
744
+ # Yield any remaining buffer
745
+ if buffer and not in_think_block:
746
+ yield buffer
747
+
748
+ thread.join()
749
+
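The `<think>` filtering loop in `generate_streaming` can be exercised without a model by feeding it a list of text chunks. The sketch below is a standalone variation, not the code from this commit: `strip_think_blocks` and `_held_back` are hypothetical helpers, and unlike the loop above they also hold back a trailing partial tag (for example `<th`) so that a tag split across two chunks is still recognised.

```python
from typing import Iterable, Iterator

OPEN_TAG, CLOSE_TAG = "<think>", "</think>"


def _held_back(buffer: str, tag: str) -> int:
    """Length of the longest buffer suffix that is a proper prefix of tag."""
    for size in range(min(len(tag) - 1, len(buffer)), 0, -1):
        if buffer.endswith(tag[:size]):
            return size
    return 0


def strip_think_blocks(chunks: Iterable[str]) -> Iterator[str]:
    """Yield streamed text with <think>...</think> spans removed."""
    buffer = ""
    in_think = False
    for chunk in chunks:
        buffer += chunk
        # Consume every complete tag currently in the buffer.
        while True:
            tag = CLOSE_TAG if in_think else OPEN_TAG
            idx = buffer.find(tag)
            if idx == -1:
                break
            if not in_think and idx > 0:
                yield buffer[:idx]  # text before an opening tag is visible
            buffer = buffer[idx + len(tag):]
            in_think = not in_think
        # Keep only a possible partial tag at the end of the buffer.
        keep = _held_back(buffer, CLOSE_TAG if in_think else OPEN_TAG)
        if in_think:
            buffer = buffer[len(buffer) - keep:]  # hidden text is discarded
        elif len(buffer) > keep:
            yield buffer[:len(buffer) - keep]
            buffer = buffer[len(buffer) - keep:]
    # Flush whatever visible text is left once the stream ends.
    if buffer and not in_think:
        yield buffer


print("".join(strip_think_blocks(["hello <th", "ink>secret</th", "ink> world"])))
```

Holding back a partial tag delays those few characters by one chunk, which is a small latency cost in exchange for never leaking part of a think block to the caller.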
750
+ @torch.no_grad()
751
+ def generate_text_only(
752
+ self,
753
+ messages: list[dict[str, str]],
754
+ max_new_tokens: int = 256,
755
+ **generate_kwargs,
756
+ ) -> str:
757
+ """Generate text using only the LLM (no audio encoding).
758
+
759
+ Used for SIFT-style response generation from metadata prompts.
760
+
761
+ Args:
762
+ messages: List of chat messages [{"role": "user", "content": "..."}]
763
+ max_new_tokens: Maximum tokens to generate
764
+ **generate_kwargs: Additional generation arguments
765
+
766
+ Returns:
767
+ Generated text response
768
+ """
769
+ device = next(self.language_model.parameters()).device
770
+
771
+ # Apply chat template
772
+ input_ids = self.tokenizer.apply_chat_template(
773
+ messages,
774
+ tokenize=True,
775
+ add_generation_prompt=True,
776
+ return_tensors="pt", return_dict=True,
776
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
777
+ )["input_ids"].to(device)
779
+
780
+ if input_ids.dim() == 1:
781
+ input_ids = input_ids.unsqueeze(0)
782
+
783
+ attention_mask = torch.ones_like(input_ids)
784
+
785
+ # Generate using language model directly
786
+ output = self.language_model.generate(
787
+ input_ids=input_ids,
788
+ attention_mask=attention_mask,
789
+ max_new_tokens=max_new_tokens,
790
+ do_sample=False,
791
+ pad_token_id=self.tokenizer.pad_token_id,
792
+ eos_token_id=self.tokenizer.eos_token_id,
793
+ **generate_kwargs,
794
+ )
795
+
796
+ # Decode only the new tokens
797
+ new_tokens = output[0, input_ids.shape[1] :]
798
+ response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
799
+ return response.strip()
800
+
801
+ def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
802
+ """Save model, tokenizer, and processor."""
803
+ import shutil
804
+ from pathlib import Path as PathlibPath
805
+
806
+ save_dir = PathlibPath(save_directory)
807
+ save_dir.mkdir(parents=True, exist_ok=True)
808
+
809
+ # Update config with actual vocab size
810
+ self.config.vocab_size = self.language_model.config.vocab_size
811
+ self.config.text_config.vocab_size = self.language_model.config.vocab_size
812
+
813
+ if hasattr(self.audio_tower.config, "num_mel_bins"):
814
+ self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins
815
+
816
+ # Save model (temporarily remove non-serializable attributes)
817
+ tokenizer = self.tokenizer
818
+ del self.tokenizer
819
+
820
+ try:
821
+ super().save_pretrained(save_dir, **kwargs)
822
+ finally:
823
+ self.tokenizer = tokenizer
824
+
825
+ # Save tokenizer and feature extractor
826
+ self.tokenizer.save_pretrained(save_dir)
827
+ self.feature_extractor.save_pretrained(save_dir)
828
+
829
+ # Save LoRA adapters if present (creates adapter_model.safetensors and adapter_config.json)
830
+ # Don't save embedding layers - the <audio> token embedding is never used
831
+ # (it's replaced with projected audio embeddings before the LLM sees it)
832
+ if hasattr(self.language_model, "peft_config"):
833
+ self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
834
+
835
+ # Clear base_model_name_or_path in adapter_config.json to prevent HF pipeline
836
+ # from redirecting to the base LLM repo (like Qwen) which breaks feature
837
+ # extractor loading for multimodal models. If a repo_id is provided, use that
838
+ # so the model can be loaded directly from the Hub.
839
+ adapter_config_path = save_dir / "adapter_config.json"
840
+ if adapter_config_path.exists():
841
+ with adapter_config_path.open() as f:
842
+ adapter_config = json.load(f)
843
+
844
+ # Use repo_id if available, otherwise clear to prevent redirect.
845
+ # Use empty string instead of None to avoid str(None) -> "None" bug
846
+ # in some transformers/PEFT versions.
847
+ repo_id = (
848
+ kwargs.get("repo_id")
849
+ or kwargs.get("push_to_hub_model_id")
850
+ or getattr(self.config, "pretrained_model_path", None)
851
+ or "" # Use empty string instead of None
852
+ )
853
+ adapter_config["base_model_name_or_path"] = repo_id
854
+
855
+ with adapter_config_path.open("w") as f:
856
+ json.dump(adapter_config, f, indent=2)
857
+
858
+ # Add processor auto_map to preprocessor_config.json
859
+ config_path = save_dir / "preprocessor_config.json"
860
+ if config_path.exists():
861
+ with config_path.open() as f:
862
+ processor_config = json.load(f)
863
+ else:
864
+ processor_config = {}
865
+
866
+ processor_config.update(
867
+ {
868
+ "processor_class": "ASRProcessor",
869
+ "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
870
+ }
871
+ )
872
+
873
+ with config_path.open("w") as f:
874
+ json.dump(processor_config, f, indent=2)
875
+
876
+ # Copy source files for auto-loading
877
+ src_dir = PathlibPath(__file__).parent
878
+ for asr_file in src_dir.glob("asr_*.py"):
879
+ shutil.copy(asr_file, save_dir / asr_file.name)
880
+ # Copy projectors module
881
+ shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
882
+ # Copy alignment module
883
+ shutil.copy(src_dir / "alignment.py", save_dir / "alignment.py")
884
+ # Copy diarization module
885
+ shutil.copy(src_dir / "diarization.py", save_dir / "diarization.py")
886
+
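The adapter-config rewrite inside `save_pretrained` is a plain JSON read, modify, write cycle, and can be tried in isolation. The sketch below is a minimal illustration using only the standard library: `set_adapter_base_model` is a hypothetical helper, and the repository names are placeholders rather than real Hub repos.

```python
import json
import tempfile
from pathlib import Path


def set_adapter_base_model(save_dir, repo_id=None):
    """Rewrite base_model_name_or_path in adapter_config.json if the file exists.

    Falls back to an empty string rather than None, mirroring the choice
    made in save_pretrained above. Returns True when a file was rewritten.
    """
    path = Path(save_dir) / "adapter_config.json"
    if not path.exists():
        return False
    config = json.loads(path.read_text())
    config["base_model_name_or_path"] = repo_id or ""
    path.write_text(json.dumps(config, indent=2))
    return True


# Demonstration on a throwaway directory with a made-up adapter config.
with tempfile.TemporaryDirectory() as tmp:
    cfg_path = Path(tmp) / "adapter_config.json"
    cfg_path.write_text(json.dumps({"base_model_name_or_path": "some-org/base-llm", "r": 8}))
    set_adapter_base_model(tmp, repo_id="example-user/asr-model")
    updated = json.loads(cfg_path.read_text())
    print(updated["base_model_name_or_path"])
```

Only the one key is touched, so the remaining adapter settings (rank, target modules, and so on) are written back unchanged.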
887
+ def push_to_hub(self, repo_id: str, **kwargs) -> str:
888
+ """Push model to HuggingFace Hub, ensuring adapter_config points to repo.
889
+
890
+ IMPORTANT: Sets base_model_name_or_path in adapter_config.json to repo_id
891
+ so that transformers pipeline() can load the model correctly. Without this,
892
+ the pipeline tries to load from "None" which fails.
893
+ """
894
+ # Store repo_id in config so save_pretrained can access it
895
+ self.config.pretrained_model_path = repo_id
896
+ # Call parent's push_to_hub
897
+ return super().push_to_hub(repo_id, **kwargs)
898
+
899
+ def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
900
+ """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
901
+ pass
902
+
903
+
904
+ # Register with transformers Auto classes
905
+ AutoConfig.register("asr_model", ASRConfig)
906
+ AutoModel.register(ASRConfig, ASRModel)
asr_pipeline.py ADDED
@@ -0,0 +1,322 @@
1
+ """ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
2
+
3
+ import re
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import torch
9
+ import transformers
10
+
11
+ try:
12
+ from .alignment import ForcedAligner
13
+ from .asr_modeling import ASRModel
14
+ from .diarization import SpeakerDiarizer
15
+ except ImportError:
16
+ from alignment import ForcedAligner # type: ignore[no-redef]
17
+ from asr_modeling import ASRModel # type: ignore[no-redef]
18
+ from diarization import SpeakerDiarizer # type: ignore[no-redef]
19
+
20
+ # Re-export for backwards compatibility
21
+ __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]
22
+
23
+
24
+ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
25
+ """ASR Pipeline for audio-to-text transcription."""
26
+
27
+ model: ASRModel
28
+
29
+ def __init__(self, model: ASRModel, **kwargs):
30
+ """Initialize ASR pipeline.
31
+
32
+ Args:
33
+ model: ASRModel instance for transcription
34
+ **kwargs: Additional arguments (feature_extractor, tokenizer, device)
35
+ """
36
+ feature_extractor = kwargs.pop("feature_extractor", None)
37
+ tokenizer = kwargs.pop("tokenizer", model.tokenizer)
38
+
39
+ if feature_extractor is None:
40
+ feature_extractor = model.get_processor().feature_extractor
41
+
42
+ super().__init__(
43
+ model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
44
+ )
45
+ self._current_audio = None
46
+
47
+ def _sanitize_parameters(self, **kwargs):
48
+ """Intercept our custom parameters before parent class validates them."""
49
+ # Remove our custom parameters so parent doesn't see them
50
+ kwargs.pop("return_timestamps", None)
51
+ kwargs.pop("return_speakers", None)
52
+ kwargs.pop("num_speakers", None)
53
+ kwargs.pop("min_speakers", None)
54
+ kwargs.pop("max_speakers", None)
55
+ kwargs.pop("hf_token", None)
56
+ kwargs.pop("user_prompt", None)
57
+ kwargs.pop("diarization_backend", None)
58
+
59
+ return super()._sanitize_parameters(**kwargs)
60
+
61
+ def __call__(
62
+ self,
63
+ inputs,
64
+ **kwargs,
65
+ ):
66
+ """Transcribe audio with optional word-level timestamps and speaker diarization.
67
+
68
+ Args:
69
+ inputs: Audio input (file path, dict with array/sampling_rate, etc.)
70
+ return_timestamps: If True, return word-level timestamps using forced alignment
71
+ return_speakers: If True, return speaker labels for each word
72
+ user_prompt: Custom transcription prompt (default: "Transcribe: ")
73
+ num_speakers: Exact number of speakers (if known, for diarization)
74
+ min_speakers: Minimum number of speakers (for diarization)
75
+ max_speakers: Maximum number of speakers (for diarization)
76
+ **kwargs: Additional arguments passed to the pipeline
77
+
78
+ Returns:
79
+ Dict with 'text' key, 'words' key if return_timestamps=True,
80
+ and speaker labels on words if return_speakers=True
81
+ """
82
+ # Extract our params before super().__call__ (which will also call _sanitize_parameters)
83
+ return_timestamps = kwargs.pop("return_timestamps", False)
84
+ return_speakers = kwargs.pop("return_speakers", False)
85
+ user_prompt = kwargs.pop("user_prompt", None)
86
+ diarization_params = {
87
+ "num_speakers": kwargs.pop("num_speakers", None),
88
+ "min_speakers": kwargs.pop("min_speakers", None),
89
+ "max_speakers": kwargs.pop("max_speakers", None),
90
+ }
91
+
92
+ if return_speakers:
93
+ return_timestamps = True
94
+
95
+ # Set custom user prompt if provided
96
+ original_prompt = None
97
+ if user_prompt:
98
+ original_prompt = self.model.TRANSCRIBE_PROMPT
99
+ self.model.TRANSCRIBE_PROMPT = user_prompt
100
+
101
+ # Store audio for timestamp alignment and diarization
102
+ if return_timestamps or return_speakers:
103
+ self._current_audio = self._extract_audio(inputs)
104
+
105
+ # Run standard transcription
106
+ result = super().__call__(inputs, **kwargs)
107
+
108
+ # Add timestamps if requested
109
+ if return_timestamps and self._current_audio is not None:
110
+ text = result.get("text", "")
111
+ if text:
112
+ try:
113
+ words = ForcedAligner.align(
114
+ self._current_audio["array"],
115
+ text,
116
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
117
+ )
118
+ result["words"] = words
119
+ except Exception as e:
120
+ result["words"] = []
121
+ result["timestamp_error"] = str(e)
122
+ else:
123
+ result["words"] = []
124
+
125
+ # Add speaker diarization if requested
126
+ if return_speakers and self._current_audio is not None:
127
+ try:
128
+ # Run diarization
129
+ speaker_segments = SpeakerDiarizer.diarize(
130
+ self._current_audio["array"],
131
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
132
+ **{k: v for k, v in diarization_params.items() if v is not None},
133
+ )
134
+ result["speaker_segments"] = speaker_segments
135
+
136
+ # Assign speakers to words
137
+ if result.get("words"):
138
+ result["words"] = SpeakerDiarizer.assign_speakers_to_words(
139
+ result["words"],
140
+ speaker_segments,
141
+ )
142
+ except Exception as e:
143
+ result["speaker_segments"] = []
144
+ result["diarization_error"] = str(e)
145
+
146
+ # Clean up
147
+ self._current_audio = None
148
+ if original_prompt is not None:
149
+ self.model.TRANSCRIBE_PROMPT = original_prompt
150
+
151
+ return result
152
+
153
+ def _extract_audio(self, inputs) -> dict | None:
154
+ """Extract audio array from various input formats using HF utilities."""
155
+ from transformers.pipelines.audio_utils import ffmpeg_read
156
+
157
+ if isinstance(inputs, dict):
158
+ if "array" in inputs:
159
+ return {
160
+ "array": inputs["array"],
161
+ "sampling_rate": inputs.get("sampling_rate", 16000),
162
+ }
163
+ if "raw" in inputs:
164
+ return {
165
+ "array": inputs["raw"],
166
+ "sampling_rate": inputs.get("sampling_rate", 16000),
167
+ }
168
+ elif isinstance(inputs, str):
169
+ # File path - load audio using ffmpeg (same as HF pipeline)
170
+ with Path(inputs).open("rb") as f:
171
+ audio = ffmpeg_read(f.read(), sampling_rate=16000)
172
+ return {"array": audio, "sampling_rate": 16000}
173
+ elif isinstance(inputs, bytes):
174
+ audio = ffmpeg_read(inputs, sampling_rate=16000)
175
+ return {"array": audio, "sampling_rate": 16000}
176
+ elif isinstance(inputs, np.ndarray):
177
+ return {"array": inputs, "sampling_rate": 16000}
178
+
179
+ return None
180
+
181
+ def preprocess(self, inputs, **preprocess_params):
182
+ """Preprocess audio inputs for the model.
183
+
184
+ Args:
185
+ inputs: Audio input (dict with array, file path, etc.)
186
+ **preprocess_params: Additional preprocessing parameters
187
+
188
+ Yields:
189
+ Model input dicts with input_features and attention_mask
190
+ """
191
+ # Handle dict with "array" key (from datasets)
192
+ if isinstance(inputs, dict) and "array" in inputs:
193
+ inputs = {
194
+ "raw": inputs["array"],
195
+ "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
196
+ }
197
+
198
+ for item in super().preprocess(inputs, **preprocess_params):
199
+ if "is_last" not in item:
200
+ item["is_last"] = True
201
+ yield item
202
+
203
+ def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
204
+ """Run model forward pass to generate transcription.
205
+
206
+ Args:
207
+ model_inputs: Dict with input_features and attention_mask
208
+ **generate_kwargs: Generation parameters
209
+
210
+ Returns:
211
+ Dict with generated token IDs
212
+ """
213
+ # Extract audio features and is_last flag
214
+ is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
215
+
216
+ input_features = model_inputs["input_features"].to(self.model.device)
217
+ audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)
218
+
219
+ generated_ids = self.model.generate(
220
+ input_features=input_features,
221
+ audio_attention_mask=audio_attention_mask,
222
+ **generate_kwargs,
223
+ )
224
+
225
+ return {"tokens": generated_ids, "is_last": is_last}
226
+
227
+ def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
228
+ """Convert model output tokens to text.
229
+
230
+ Args:
231
+ model_outputs: Dict with 'tokens' key containing generated IDs
232
+ **kwargs: Additional postprocessing parameters
233
+
234
+ Returns:
235
+ Dict with 'text' key containing transcription
236
+ """
237
+ # Handle list of outputs (from chunking)
238
+ if isinstance(model_outputs, list):
239
+ model_outputs = model_outputs[0] if model_outputs else {}
240
+
241
+ tokens = model_outputs.get("tokens")
242
+ if tokens is None:
243
+ return super().postprocess(model_outputs, **kwargs)
244
+
245
+ if torch.is_tensor(tokens):
246
+ tokens = tokens.cpu()
247
+ if tokens.dim() > 1:
248
+ tokens = tokens[0]
249
+
250
+ # Filter out eos tokens that the tokenizer doesn't recognize as special
251
+ # (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
252
+ if hasattr(self, "model") and hasattr(self.model, "generation_config"):
253
+ eos_ids = self.model.generation_config.eos_token_id
254
+ if eos_ids is not None:
255
+ eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
256
+ tokens = [t for t in tokens.tolist() if t not in eos_set]
257
+
258
+ text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
259
+ # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
260
+ text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
261
+ # Truncate repetitions at end of text
262
+ text = _truncate_repetitions(text)
263
+ return {"text": text}
264
+
265
+
266
+ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
267
+ """Truncate repeated words/phrases/characters at end of text.
268
+
269
+ Detects patterns like:
270
+ - Repeated words: "the the the the" -> "the"
271
+ - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry"
272
+ - Repeated characters: "444444" -> "4"
273
+
274
+ Args:
275
+ text: Input text to process
276
+ min_repeats: Minimum repetitions to trigger truncation (default 3)
277
+
278
+ Returns:
279
+ Text with trailing repetitions removed
280
+ """
281
+ if not text:
282
+ return text
283
+
284
+ # 1. Truncate repeated characters at end (e.g., "444444" -> "4")
285
+ char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
286
+ text = char_pattern.sub(r"\1", text)
287
+
288
+ # 2. Truncate repeated words at end (e.g., "the the the" -> "the")
289
+ word_pattern = re.compile(
290
+ r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
291
+ )
292
+ while word_pattern.search(text):
293
+ text = word_pattern.sub(r"\1", text)
294
+
295
+ # 3. Truncate repeated phrases (2-20 words) at end
296
+ # e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
297
+ words = text.split()
298
+ if len(words) >= min_repeats * 2:
299
+ # Try phrase lengths from 2 to 20 words
300
+ for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
301
+ # Check if the last phrase_len words repeat
302
+ phrase = " ".join(words[-phrase_len:])
303
+ # Build pattern to match repeated phrases at end
304
+ phrase_escaped = re.escape(phrase)
305
+ phrase_pattern = re.compile(
306
+ r"(^|.*?\s)("
307
+ + phrase_escaped
308
+ + r")(?:\s+"
309
+ + phrase_escaped
310
+ + r"){"
311
+ + str(min_repeats - 1)
312
+ + r",}\s*$",
313
+ re.IGNORECASE,
314
+ )
315
+ match = phrase_pattern.match(text)
316
+ if match:
317
+ # Keep prefix + one instance of the phrase
318
+ text = (match.group(1) + match.group(2)).strip()
319
+ words = text.split()
320
+ break
321
+
322
+ return text
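The trailing-repetition cleanup in `_truncate_repetitions` can be tried in isolation. The following is a minimal sketch, assuming only the character step and the single-word step from the function above (the 2-20 word phrase step is left out). The helper name `collapse_trailing_repeats` is hypothetical and not part of this commit.

```python
import re


def collapse_trailing_repeats(text: str, min_repeats: int = 3) -> str:
    """Collapse a run of repeated characters or words at the end of text.

    Hypothetical helper that mirrors steps 1 and 2 of _truncate_repetitions.
    """
    if not text:
        return text

    # Step 1: one character repeated at least min_repeats times at the very end.
    char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
    text = char_pattern.sub(r"\1", text)

    # Step 2: one whole word repeated at least min_repeats times at the very end.
    word_pattern = re.compile(
        r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
    )
    while word_pattern.search(text):
        text = word_pattern.sub(r"\1", text)
    return text


print(collapse_trailing_repeats("so the the the the"))
print(collapse_trailing_repeats("call 444444"))
print(collapse_trailing_repeats("no no"))
```

Two occurrences are below the default threshold, so "no no" is left alone. The character step is not limited to digits or letters: a trailing ellipsis such as "wait..." also collapses to "wait.", which may or may not be wanted for a given transcript style.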
asr_processing.py ADDED
@@ -0,0 +1,133 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import transformers
5
+ from transformers import ProcessorMixin
6
+
7
+ try:
8
+ from .asr_config import ASRConfig
9
+ except ImportError:
10
+ from asr_config import ASRConfig # type: ignore[no-redef]
11
+
12
+
13
+ class ASRProcessor(ProcessorMixin):
14
+ """Processor for Whisper-based ASR models."""
15
+
16
+ attributes = ["feature_extractor", "tokenizer"]
17
+ feature_extractor_class = "AutoFeatureExtractor"
18
+ tokenizer_class = "AutoTokenizer"
19
+ AUDIO_TOKEN = "<audio>"
20
+ TRANSCRIBE_PROMPT = ""
21
+ # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
22
+ DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
23
+
24
+ def __init__(
25
+ self,
26
+ feature_extractor,
27
+ tokenizer,
28
+ projector=None,
29
+ encoder_conv_layers: Optional[list] = None,
30
+ ):
31
+ """Initialize the ASR processor.
32
+
33
+ Args:
34
+ feature_extractor: Audio feature extractor (WhisperFeatureExtractor)
35
+ tokenizer: Text tokenizer for the language model
36
+ projector: Audio projector module (for computing output lengths)
37
+ encoder_conv_layers: Conv layer specs [(pad, kernel, stride), ...]
38
+ """
39
+ self.feature_extractor = feature_extractor
40
+ self.tokenizer = tokenizer
41
+ self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
42
+ self.projector = projector
43
+ self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
44
+
45
+ def _compute_encoder_output_length(self, mel_length: int) -> int:
46
+ """Compute encoder output length using conv layer formulas."""
47
+ length = mel_length
48
+ for padding, kernel_size, stride in self.encoder_conv_layers:
49
+ length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
50
+ return length
51
+
52
+ def __call__(
53
+ self,
54
+ audio: Optional[Union[list, "torch.Tensor"]] = None,
55
+ text: Optional[str] = None,
56
+ system_prompt: Optional[str] = None,
57
+ return_tensors: str = "pt",
58
+ **kwargs,
59
+ ) -> dict:
60
+ """Process audio and text inputs for inference.
61
+
62
+ Args:
63
+ audio: Raw audio waveform(s)
64
+ text: Optional target transcription, appended as the assistant turn (for training, prefer the DataCollator)
65
+ system_prompt: Optional system prompt
66
+ return_tensors: Return format ("pt" for PyTorch)
67
+
68
+ Returns:
69
+ Dict with input_ids and attention_mask, plus input_features and audio_attention_mask when audio is given
70
+ """
71
+ result = {}
72
+
73
+ # Process audio
74
+ if audio is not None:
75
+ audio_inputs = self.feature_extractor(
76
+ audio,
77
+ sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
78
+ return_attention_mask=True,
79
+ return_tensors=return_tensors,
80
+ **kwargs,
81
+ )
82
+ result["input_features"] = audio_inputs["input_features"]
83
+ result["audio_attention_mask"] = audio_inputs["attention_mask"]
84
+
85
+ # Use actual audio length (from attention mask) for token count
86
+ real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
87
+ encoder_output_len = self._compute_encoder_output_length(real_mel_len)
88
+ num_audio_tokens = self.projector.get_output_length(encoder_output_len)
89
+ else:
90
+ num_audio_tokens = 0
91
+
92
+ # Build prompt with audio token placeholders (instruction-free)
93
+ if num_audio_tokens > 0:
94
+ user_content = self.AUDIO_TOKEN * num_audio_tokens
95
+ if self.TRANSCRIBE_PROMPT:
96
+ user_content += " " + self.TRANSCRIBE_PROMPT
97
+ else:
98
+ user_content = self.TRANSCRIBE_PROMPT or ""
99
+
100
+ messages = []
101
+ if system_prompt:
102
+ messages.append({"role": "system", "content": system_prompt})
103
+ messages.append({"role": "user", "content": user_content})
104
+ if text is not None:
105
+ messages.append({"role": "assistant", "content": text})
106
+
107
+ # Tokenize
108
+ tokenized = self.tokenizer.apply_chat_template(
109
+ messages,
110
+ tokenize=True,
111
+ add_generation_prompt=(text is None),
112
+ return_tensors=return_tensors,
113
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
114
+ )
115
+
116
+ # Handle both tensor and BatchEncoding returns
117
+ if isinstance(tokenized, torch.Tensor):
118
+ input_ids = tokenized
119
+ else:
120
+ # BatchEncoding or dict-like object
121
+ input_ids = tokenized.get("input_ids", tokenized.input_ids)
122
+
123
+ if input_ids.dim() == 1:
124
+ input_ids = input_ids.unsqueeze(0)
125
+
126
+ result["input_ids"] = input_ids
127
+ result["attention_mask"] = torch.ones_like(input_ids)
128
+
129
+ return result
130
+
131
+
132
+ ASRProcessor.register_for_auto_class()
133
+ transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
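The length formula in `_compute_encoder_output_length` is the usual 1-D convolution output-size rule with dilation 1, applied once per layer. The sketch below restates it as a standalone function so the arithmetic can be checked by hand. The function name `conv_output_length` is an assumption for illustration; the layer tuples are the `DEFAULT_ENCODER_CONV_LAYERS` values from the class.

```python
def conv_output_length(length: int, layers: list[tuple[int, int, int]]) -> int:
    """Apply the conv output-size rule for each (padding, kernel_size, stride) layer."""
    for padding, kernel_size, stride in layers:
        length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    return length


# Same values as DEFAULT_ENCODER_CONV_LAYERS: [(pad, kernel, stride), ...]
DEFAULT_LAYERS = [(1, 3, 1), (1, 3, 2)]

# The first layer keeps the length, the second roughly halves it.
print(conv_output_length(3000, DEFAULT_LAYERS))
print(conv_output_length(101, DEFAULT_LAYERS))
```

With 3000 input frames the result is 1500, which agrees with the `max_source_positions` value of 1500 in this commit's `config.json`. Odd lengths round up: 101 frames give 51 positions.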
chat_template.jinja ADDED
@@ -0,0 +1,6 @@
+ {% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
+ You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
+ ' }}{% endif %}{{'<|im_start|>' + message['role'] + '
+ ' + message['content'] + '<|im_end|>' + '
+ '}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
+ ' }}{% endif %}
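To see what this template produces, the following is a pure-Python approximation of its logic, not the Jinja renderer that `apply_chat_template` actually uses. The function name `render_chatml` and the constant `DEFAULT_SYSTEM` are assumptions made for illustration; the system text and the `<|im_start|>` / `<|im_end|>` markers are copied from the template.

```python
DEFAULT_SYSTEM = "You are a helpful AI assistant named SmolLM, trained by Hugging Face"


def render_chatml(messages: list[dict], add_generation_prompt: bool = False) -> str:
    """Approximate the chat template: ChatML turns with a default system turn."""
    parts = []
    # The template injects the default system turn only when the first
    # message is not already a system message.
    if messages and messages[0]["role"] != "system":
        parts.append(f"<|im_start|>system\n{DEFAULT_SYSTEM}<|im_end|>\n")
    for message in messages:
        parts.append(f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n")
    # An open assistant turn tells the model to start generating.
    if add_generation_prompt:
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


print(render_chatml([{"role": "user", "content": "<audio><audio>"}], add_generation_prompt=True))
```

One consequence worth noting: when no system message is passed, the SmolLM default system turn is still added to the prompt.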
config.json ADDED
@@ -0,0 +1,436 @@
1
+ {
2
+ "architectures": [
3
+ "ASRModel"
4
+ ],
5
+ "attn_implementation": "sdpa",
6
+ "audio_config": {
7
+ "_name_or_path": "openai/whisper-base",
8
+ "activation_dropout": 0.0,
9
+ "activation_function": "gelu",
10
+ "apply_spec_augment": false,
11
+ "architectures": [
12
+ "WhisperForConditionalGeneration"
13
+ ],
14
+ "attention_dropout": 0.0,
15
+ "begin_suppress_tokens": [
16
+ 220,
17
+ 50257
18
+ ],
19
+ "bos_token_id": 50257,
20
+ "classifier_proj_size": 256,
21
+ "d_model": 512,
22
+ "decoder_attention_heads": 8,
23
+ "decoder_ffn_dim": 2048,
24
+ "decoder_layerdrop": 0.0,
25
+ "decoder_layers": 6,
26
+ "decoder_start_token_id": 50258,
27
+ "dropout": 0.0,
28
+ "dtype": "bfloat16",
29
+ "encoder_attention_heads": 8,
30
+ "encoder_ffn_dim": 2048,
31
+ "encoder_layerdrop": 0.0,
32
+ "encoder_layers": 6,
33
+ "eos_token_id": 50257,
34
+ "forced_decoder_ids": [
35
+ [
36
+ 1,
37
+ 50259
38
+ ],
39
+ [
40
+ 2,
41
+ 50359
42
+ ],
43
+ [
44
+ 3,
45
+ 50363
46
+ ]
47
+ ],
48
+ "init_std": 0.02,
49
+ "mask_feature_length": 10,
50
+ "mask_feature_min_masks": 0,
51
+ "mask_feature_prob": 0.0,
52
+ "mask_time_length": 10,
53
+ "mask_time_min_masks": 2,
54
+ "mask_time_prob": 0.05,
55
+ "max_source_positions": 1500,
56
+ "max_target_positions": 448,
57
+ "median_filter_width": 7,
58
+ "model_type": "whisper",
59
+ "num_mel_bins": 80,
60
+ "pad_token_id": 50257,
61
+ "scale_embedding": false,
62
+ "suppress_tokens": [
63
+ 1,
64
+ 2,
65
+ 7,
66
+ 8,
67
+ 9,
68
+ 10,
69
+ 14,
70
+ 25,
71
+ 26,
72
+ 27,
73
+ 28,
74
+ 29,
75
+ 31,
76
+ 58,
77
+ 59,
78
+ 60,
79
+ 61,
80
+ 62,
81
+ 63,
82
+ 90,
83
+ 91,
84
+ 92,
85
+ 93,
86
+ 359,
87
+ 503,
88
+ 522,
89
+ 542,
90
+ 873,
91
+ 893,
92
+ 902,
93
+ 918,
94
+ 922,
95
+ 931,
96
+ 1350,
97
+ 1853,
98
+ 1982,
99
+ 2460,
100
+ 2627,
101
+ 3246,
102
+ 3253,
103
+ 3268,
104
+ 3536,
105
+ 3846,
106
+ 3961,
107
+ 4183,
108
+ 4667,
109
+ 6585,
110
+ 6647,
111
+ 7273,
112
+ 9061,
113
+ 9383,
114
+ 10428,
115
+ 10929,
116
+ 11938,
117
+ 12033,
118
+ 12331,
119
+ 12562,
120
+ 13793,
121
+ 14157,
122
+ 14635,
123
+ 15265,
124
+ 15618,
125
+ 16553,
126
+ 16604,
127
+ 18362,
128
+ 18956,
129
+ 20075,
130
+ 21675,
131
+ 22520,
132
+ 26130,
133
+ 26161,
134
+ 26435,
135
+ 28279,
136
+ 29464,
137
+ 31650,
138
+ 32302,
139
+ 32470,
140
+ 36865,
141
+ 42863,
142
+ 47425,
143
+ 49870,
144
+ 50254,
145
+ 50258,
146
+ 50358,
147
+ 50359,
148
+ 50360,
149
+ 50361,
150
+ 50362
151
+ ],
152
+ "tie_word_embeddings": true,
153
+ "use_cache": true,
154
+ "use_weighted_layer_sum": false,
155
+ "vocab_size": 51865
156
+ },
157
+ "audio_model_id": "openai/whisper-base",
158
+ "audio_sample_rate": 16000,
159
+ "auto_map": {
160
+ "AutoConfig": "asr_config.ASRConfig",
161
+ "AutoModel": "asr_modeling.ASRModel",
162
+ "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
163
+ "AutoProcessor": "asr_processing.ASRProcessor"
164
+ },
165
+ "custom_pipelines": {
166
+ "automatic-speech-recognition": {
167
+ "impl": "asr_pipeline.ASRPipeline",
168
+ "pt": [
169
+ "AutoModelForSpeechSeq2Seq"
170
+ ],
171
+ "tf": [],
172
+ "type": "audio"
173
+ }
174
+ },
175
+ "do_sample": false,
176
+ "downsample_rate": 5,
177
+ "dtype": "bfloat16",
178
+ "encoder": {
179
+ "_name_or_path": "openai/whisper-base",
180
+ "activation_dropout": 0.0,
181
+ "activation_function": "gelu",
182
+ "apply_spec_augment": false,
183
+ "architectures": [
184
+ "WhisperForConditionalGeneration"
185
+ ],
186
+ "attention_dropout": 0.0,
187
+ "begin_suppress_tokens": [
188
+ 220,
189
+ 50257
190
+ ],
191
+ "bos_token_id": 50257,
192
+ "classifier_proj_size": 256,
193
+ "d_model": 512,
194
+ "decoder_attention_heads": 8,
195
+ "decoder_ffn_dim": 2048,
196
+ "decoder_layerdrop": 0.0,
197
+ "decoder_layers": 6,
198
+ "decoder_start_token_id": 50258,
199
+ "dropout": 0.0,
200
+ "dtype": "bfloat16",
201
+ "encoder_attention_heads": 8,
202
+ "encoder_ffn_dim": 2048,
203
+ "encoder_layerdrop": 0.0,
204
+ "encoder_layers": 6,
205
+ "eos_token_id": 50257,
206
+ "forced_decoder_ids": [
207
+ [
208
+ 1,
209
+ 50259
210
+ ],
211
+ [
212
+ 2,
213
+ 50359
214
+ ],
215
+ [
216
+ 3,
217
+ 50363
218
+ ]
219
+ ],
220
+ "init_std": 0.02,
221
+ "mask_feature_length": 10,
222
+ "mask_feature_min_masks": 0,
223
+ "mask_feature_prob": 0.0,
224
+ "mask_time_length": 10,
225
+ "mask_time_min_masks": 2,
226
+ "mask_time_prob": 0.05,
227
+ "max_source_positions": 1500,
228
+ "max_target_positions": 448,
229
+ "median_filter_width": 7,
230
+ "model_type": "whisper",
231
+ "num_mel_bins": 80,
232
+ "pad_token_id": 50257,
233
+ "scale_embedding": false,
234
+ "suppress_tokens": [
235
+ 1,
236
+ 2,
237
+ 7,
238
+ 8,
239
+ 9,
240
+ 10,
241
+ 14,
242
+ 25,
243
+ 26,
244
+ 27,
245
+ 28,
246
+ 29,
247
+ 31,
248
+ 58,
249
+ 59,
250
+ 60,
251
+ 61,
252
+ 62,
253
+ 63,
254
+ 90,
255
+ 91,
256
+ 92,
257
+ 93,
258
+ 359,
259
+ 503,
260
+ 522,
261
+ 542,
262
+ 873,
263
+ 893,
264
+ 902,
265
+ 918,
266
+ 922,
267
+ 931,
268
+ 1350,
269
+ 1853,
270
+ 1982,
271
+ 2460,
272
+ 2627,
273
+ 3246,
274
+ 3253,
275
+ 3268,
276
+ 3536,
277
+ 3846,
278
+ 3961,
279
+ 4183,
280
+ 4667,
281
+ 6585,
282
+ 6647,
283
+ 7273,
284
+ 9061,
285
+ 9383,
286
+ 10428,
287
+ 10929,
288
+ 11938,
289
+ 12033,
290
+ 12331,
291
+ 12562,
292
+ 13793,
293
+ 14157,
294
+ 14635,
295
+ 15265,
296
+ 15618,
297
+ 16553,
298
+ 16604,
299
+ 18362,
300
+ 18956,
301
+ 20075,
302
+ 21675,
303
+ 22520,
304
+ 26130,
305
+ 26161,
306
+ 26435,
307
+ 28279,
308
+ 29464,
309
+ 31650,
310
+ 32302,
311
+ 32470,
312
+ 36865,
313
+ 42863,
314
+ 47425,
315
+ 49870,
316
+ 50254,
317
+ 50258,
318
+ 50358,
319
+ 50359,
320
+ 50360,
321
+ 50361,
322
+ 50362
323
+ ],
324
+ "tie_word_embeddings": true,
325
+ "use_cache": true,
326
+ "use_weighted_layer_sum": false,
327
+ "vocab_size": 51865
328
+ },
329
+ "encoder_conv_layers": [
330
+ [
331
+ 1,
332
+ 3,
333
+ 1
334
+ ],
335
+ [
336
+ 1,
337
+ 3,
338
+ 2
339
+ ]
340
+ ],
341
+ "encoder_dim": 512,
342
+ "freeze_projector": false,
343
+ "freq_mask_length": 27,
344
+ "inference_warmup_tokens": 10,
345
+ "label_smoothing": 0.0,
346
+ "length_penalty": 1.0,
347
+ "llm_dim": 960,
348
+ "lora_alpha": 32,
349
+ "lora_dropout": 0.0,
350
+ "lora_rank": 8,
351
+ "lora_target_modules": [
352
+ "q_proj",
353
+ "k_proj",
354
+ "v_proj",
355
+ "o_proj",
356
+ "gate_proj",
357
+ "up_proj",
358
+ "down_proj"
359
+ ],
360
+ "max_new_tokens": 128,
361
+ "min_new_tokens": 0,
362
+ "model_dtype": "bfloat16",
363
+ "model_type": "asr_model",
364
+ "no_repeat_ngram_size": 0,
365
+ "num_beams": 1,
366
+ "num_experts": 4,
367
+ "num_experts_per_tok": 2,
368
+ "num_freq_masks": 2,
369
+ "num_time_masks": 2,
370
+ "pipeline_tag": "automatic-speech-recognition",
371
+ "pretrained_model_path": "mazesmazes/tiny-audio-embedded",
372
+ "projector_dropout": 0.0,
373
+ "projector_hidden_dim": 512,
374
+ "projector_init_std": 0.02,
375
+ "projector_num_layers": 2,
376
+ "projector_pool_stride": 4,
377
+ "projector_type": "mlp",
378
+ "qformer_hidden_size": null,
379
+ "qformer_intermediate_size": null,
380
+ "qformer_num_heads": 16,
381
+ "qformer_num_layers": 2,
382
+ "qformer_window_size": 15,
383
+ "repetition_penalty": 1.0,
384
+ "router_aux_loss_coef": 0.01,
385
+ "system_prompt": "",
386
+ "temperature": null,
387
+ "text_config": {
388
+ "_name_or_path": "HuggingFaceTB/SmolLM2-360M-Instruct",
389
+ "architectures": [
390
+ "LlamaForCausalLM"
391
+ ],
392
+ "attention_bias": false,
393
+ "attention_dropout": 0.0,
394
+ "bos_token_id": 1,
395
+ "dtype": "bfloat16",
396
+ "eos_token_id": 2,
397
+ "head_dim": 64,
398
+ "hidden_act": "silu",
399
+ "hidden_size": 960,
400
+ "initializer_range": 0.02,
401
+ "intermediate_size": 2560,
402
+ "is_llama_config": true,
403
+ "max_position_embeddings": 8192,
404
+ "mlp_bias": false,
405
+ "model_type": "llama",
406
+ "num_attention_heads": 15,
407
+ "num_hidden_layers": 32,
408
+ "num_key_value_heads": 5,
409
+ "pad_token_id": 2,
410
+ "pretraining_tp": 1,
411
+ "rms_norm_eps": 1e-05,
412
+ "rope_interleaved": false,
413
+ "rope_parameters": {
414
+ "rope_theta": 100000,
415
+ "rope_type": "default"
416
+ },
417
+ "tie_word_embeddings": true,
418
+ "transformers.js_config": {
419
+ "kv_cache_dtype": {
420
+ "fp16": "float16",
421
+ "q4f16": "float16"
422
+ }
423
+ },
424
+ "use_cache": true,
425
+ "vocab_size": 49153
426
+ },
427
+ "text_model_id": "HuggingFaceTB/SmolLM2-360M-Instruct",
428
+ "time_mask_length": 100,
429
+ "top_k": null,
430
+ "top_p": null,
431
+ "transformers_version": "5.6.0",
432
+ "use_cache": false,
433
+ "use_lora": false,
434
+ "use_specaugment": true,
435
+ "vocab_size": 49153
436
+ }
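Several top-level fields in this config appear to restate values from the nested sub-configs. The sketch below is a hedged sanity check over an excerpt of the values above. It assumes, from the field names alone, that `encoder_dim` should equal `audio_config.d_model` and that `llm_dim` and `vocab_size` should equal the `text_config` values; the helper `find_dim_mismatches` is hypothetical.

```python
# Excerpt of the config.json values shown above.
config_excerpt = {
    "encoder_dim": 512,
    "llm_dim": 960,
    "vocab_size": 49153,
    "audio_config": {"d_model": 512},
    "text_config": {"hidden_size": 960, "vocab_size": 49153},
}


def find_dim_mismatches(cfg: dict) -> list[str]:
    """Return a description of each top-level field that disagrees with its sub-config."""
    # Assumed pairing of top-level fields with nested fields (based on names only).
    pairs = [
        ("encoder_dim", cfg["encoder_dim"], cfg["audio_config"]["d_model"]),
        ("llm_dim", cfg["llm_dim"], cfg["text_config"]["hidden_size"]),
        ("vocab_size", cfg["vocab_size"], cfg["text_config"]["vocab_size"]),
    ]
    return [f"{name}: {top} != {nested}" for name, top, nested in pairs if top != nested]


print(find_dim_mismatches(config_excerpt))
```

For the values in this commit the list is empty. A check like this could run before saving a config in which the nested sub-configs were edited by hand.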
diarization.py ADDED
@@ -0,0 +1,732 @@
1
+ """Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
2
+
3
+ Spectral clustering implementation adapted from FunASR/3D-Speaker:
4
+ https://github.com/alibaba-damo-academy/FunASR
5
+ MIT License (https://opensource.org/licenses/MIT)
6
+ """
7
+
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import scipy
12
+ import sklearn.metrics.pairwise
13
+ import torch
14
+ from sklearn.cluster._kmeans import k_means
15
+ from sklearn.preprocessing import normalize
16
+
17
+
18
+ def _get_device() -> torch.device:
19
+ """Get best available device for inference."""
20
+ if torch.cuda.is_available():
21
+ return torch.device("cuda")
22
+ if torch.backends.mps.is_available():
23
+ return torch.device("mps")
24
+ return torch.device("cpu")
25
+
26
+
27
+ class SpectralCluster:
28
+ """Spectral clustering using unnormalized Laplacian of affinity matrix.
29
+
30
+ Adapted from FunASR/3D-Speaker and SpeechBrain implementations.
31
+ Uses eigenvalue gap to automatically determine number of speakers.
32
+ """
33
+
34
+ def __init__(self, min_num_spks: int = 1, max_num_spks: int = 15, pval: float = 0.06):
35
+ self.min_num_spks = min_num_spks
36
+ self.max_num_spks = max_num_spks
37
+ self.pval = pval
38
+
39
+ def __call__(self, embeddings: np.ndarray, oracle_num: int | None = None) -> np.ndarray:
40
+ """Run spectral clustering on embeddings.
41
+
42
+ Args:
43
+ embeddings: Speaker embeddings of shape [N, D]
44
+ oracle_num: Optional known number of speakers
45
+
46
+ Returns:
47
+ Cluster labels of shape [N]
48
+ """
49
+ # Similarity matrix computation
50
+ sim_mat = self.get_sim_mat(embeddings)
51
+
52
+ # Refining similarity matrix with pval
53
+ pruned_sim_mat = self.p_pruning(sim_mat)
53
+
54
+ # Symmetrization
55
+ sym_pruned_sim_mat = 0.5 * (pruned_sim_mat + pruned_sim_mat.T)
56
+
57
+ # Laplacian calculation
58
+ laplacian = self.get_laplacian(sym_pruned_sim_mat)
60
+
61
+ # Get Spectral Embeddings
62
+ emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
63
+
64
+ # Perform clustering
65
+ return self.cluster_embs(emb, num_of_spk)
66
+
67
+ def get_sim_mat(self, embeddings: np.ndarray) -> np.ndarray:
68
+ """Compute cosine similarity matrix."""
69
+ return sklearn.metrics.pairwise.cosine_similarity(embeddings, embeddings)
70
+
71
+ def p_pruning(self, affinity: np.ndarray) -> np.ndarray:
72
+ """Zero all but the top-pval fraction of entries in each row of the affinity matrix (modifies it in place)."""
73
+ n = affinity.shape[0]
74
+ pval = max(self.pval, 6.0 / n)
75
+ k_keep = max(1, int(pval * n))
76
+
77
+ # Vectorized: find top-k indices per row and zero out the rest
78
+ top_k_idx = np.argpartition(affinity, -k_keep, axis=1)[:, -k_keep:]
79
+ mask = np.zeros_like(affinity, dtype=bool)
80
+ np.put_along_axis(mask, top_k_idx, True, axis=1)
81
+ affinity[~mask] = 0
82
+ return affinity
83
+
84
+ def get_laplacian(self, sim_mat: np.ndarray) -> np.ndarray:
85
+ """Compute unnormalized Laplacian matrix."""
86
+ from scipy.sparse.csgraph import laplacian
87
+
88
+ np.fill_diagonal(sim_mat, 0)
89
+ return laplacian(sim_mat, normed=False)
90
+
91
+ def get_spec_embs(
92
+ self, laplacian: np.ndarray, k_oracle: int | None = None
93
+ ) -> tuple[np.ndarray, int]:
94
+ """Extract spectral embeddings from Laplacian."""
95
+ lambdas, eig_vecs = scipy.linalg.eigh(laplacian)
96
+
97
+ if k_oracle is not None:
98
+ num_of_spk = k_oracle
99
+ else:
100
+ lambda_gap_list = self.get_eigen_gaps(
101
+ lambdas[self.min_num_spks - 1 : self.max_num_spks + 1]
102
+ )
103
+ num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
104
+
105
+ emb = eig_vecs[:, :num_of_spk]
106
+ return emb, num_of_spk
107
+
108
+ def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
109
+ """Cluster spectral embeddings using k-means."""
110
+ _, labels, _ = k_means(emb, k, n_init=10)
111
+ return labels
112
+
113
+ def get_eigen_gaps(self, eig_vals: np.ndarray) -> np.ndarray:
114
+ """Compute gaps between consecutive eigenvalues."""
115
+ return np.diff(eig_vals)
116
+
117
+
118
+ class SpeakerClusterer:
119
+ """Speaker clustering backend using spectral clustering with speaker merging.
120
+
121
+ Features:
122
+ - Spectral clustering with eigenvalue gap for auto speaker count detection
123
+ - P-pruning for affinity matrix refinement
124
+ - Post-clustering speaker merging by cosine similarity
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ min_num_spks: int = 2,
130
+ max_num_spks: int = 10,
131
+ merge_thr: float = 0.90, # Moderate merging
132
+ ):
133
+ self.min_num_spks = min_num_spks
134
+ self.max_num_spks = max_num_spks
135
+ self.merge_thr = merge_thr
136
+ self._spectral_cluster: SpectralCluster | None = None
137
+
138
+ def _get_spectral_cluster(self) -> SpectralCluster:
139
+ """Lazy-load spectral clusterer."""
140
+ if self._spectral_cluster is None:
141
+ self._spectral_cluster = SpectralCluster(
142
+ min_num_spks=self.min_num_spks,
143
+ max_num_spks=self.max_num_spks,
144
+ )
145
+ return self._spectral_cluster
146
+
147
+ def __call__(self, embeddings: np.ndarray, num_speakers: int | None = None) -> np.ndarray:
148
+ """Cluster speaker embeddings and return labels.
149
+
150
+ Args:
151
+ embeddings: Speaker embeddings of shape [N, D]
152
+ num_speakers: Optional oracle number of speakers
153
+
154
+ Returns:
155
+ Cluster labels of shape [N]
156
+ """
157
+ import warnings
158
+
159
+ if len(embeddings.shape) != 2:
160
+ raise ValueError(f"Expected 2D array, got shape {embeddings.shape}")
161
+
162
+ # Handle edge cases
163
+ if embeddings.shape[0] == 0:
164
+ return np.array([], dtype=int)
165
+ if embeddings.shape[0] == 1:
166
+ return np.array([0], dtype=int)
167
+ if embeddings.shape[0] < 6:
168
+ return np.zeros(embeddings.shape[0], dtype=int)
169
+
170
+ # Replace NaN/inf with zeros, then normalize the embeddings
171
+ embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=0.0, neginf=0.0)
172
+ embeddings = normalize(embeddings)
173
+
174
+ # Run spectral clustering (suppress numerical warnings)
175
+ spectral = self._get_spectral_cluster()
176
+
177
+ # Update min/max for oracle case
178
+ if num_speakers is not None:
179
+ spectral.min_num_spks = num_speakers
180
+ spectral.max_num_spks = num_speakers
181
+
182
+ with warnings.catch_warnings():
183
+ warnings.filterwarnings("ignore", category=RuntimeWarning)
184
+ labels = spectral(embeddings, oracle_num=num_speakers)
185
+
186
+ # Reset min/max
187
+ if num_speakers is not None:
188
+ spectral.min_num_spks = self.min_num_spks
189
+ spectral.max_num_spks = self.max_num_spks
190
+
191
+ # Merge similar speakers if no oracle
192
+ if num_speakers is None:
193
+ labels = self._merge_by_cos(labels, embeddings, self.merge_thr)
194
+
195
+ # Re-index labels sequentially
196
+ _, labels = np.unique(labels, return_inverse=True)
197
+
198
+ return labels
199
+
200
+ def _merge_by_cos(self, labels: np.ndarray, embs: np.ndarray, cos_thr: float) -> np.ndarray:
201
+ """Merge similar speakers by cosine similarity of centroids."""
202
+ from scipy.cluster.hierarchy import fcluster, linkage
203
+ from scipy.spatial.distance import pdist
204
+
205
+ unique_labels = np.unique(labels)
206
+ if len(unique_labels) <= 1:
207
+ return labels
208
+
209
+ # Compute normalized speaker centroids
210
+ centroids = np.array([embs[labels == lbl].mean(0) for lbl in unique_labels])
211
+ centroids = normalize(centroids)
212
+
213
+ # Hierarchical clustering with cosine distance
214
+ distances = pdist(centroids, metric="cosine")
215
+ linkage_matrix = linkage(distances, method="average")
216
+ merged_labels = fcluster(linkage_matrix, t=1.0 - cos_thr, criterion="distance") - 1
217
+
218
+ # Map original labels to merged labels
219
+ label_map = dict(zip(unique_labels, merged_labels))
220
+ return np.array([label_map[lbl] for lbl in labels])
221
+
222
+
223
+ class LocalSpeakerDiarizer:
224
+ """Local speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
225
+
226
+ Pipeline:
227
+ 1. TEN-VAD detects speech segments
228
+ 2. Sliding window (0.75s, 80% overlap) for uniform embedding extraction
229
+ 3. ECAPA-TDNN extracts speaker embeddings per window
230
+ 4. Spectral clustering with eigenvalue gap for auto speaker detection
231
+ 5. Frame-level consensus voting for segment reconstruction
232
+ 6. Post-processing merges short segments to reduce flicker
233
+
234
+ Tunable Parameters (class attributes):
235
+ - WINDOW_SIZE: Embedding extraction window size in seconds
236
+ - STEP_SIZE: Sliding window step size (overlap = WINDOW_SIZE - STEP_SIZE)
237
+ - VAD_THRESHOLD: Speech detection threshold (lower = more sensitive)
238
+ - VAD_MIN_DURATION: Minimum speech segment duration
239
+ - VAD_MAX_GAP: Maximum gap to bridge between segments
240
+ - VAD_PAD_ONSET/OFFSET: Padding added to speech segments
241
+ - VOTING_RATE: Frame resolution for consensus voting
242
+ - MIN_SEGMENT_DURATION: Minimum final segment duration
243
+ - SAME_SPEAKER_GAP: Maximum gap to merge same-speaker segments
244
+ - TAIL_COVERAGE_RATIO: Minimum tail coverage to add extra window
245
+ """
246
+
247
+ _ten_vad_model = None
248
+ _ecapa_model = None
249
+ _device = None
250
+
251
+ # ==================== TUNABLE PARAMETERS ====================
252
+
253
+ # Sliding window for embedding extraction
254
+ WINDOW_SIZE = 0.75 # seconds - shorter window for finer resolution
255
+ STEP_SIZE = 0.15 # seconds (80% overlap for more votes)
256
+ TAIL_COVERAGE_RATIO = 0.1 # Add extra window if tail > this ratio of window
257
+
258
+ # VAD hysteresis parameters
259
+ VAD_THRESHOLD = 0.25 # Balanced threshold
260
+ VAD_MIN_DURATION = 0.05 # Minimum speech segment duration (seconds)
261
+ VAD_MAX_GAP = 0.50 # Bridge gaps shorter than this (seconds)
262
+ VAD_PAD_ONSET = 0.05 # Padding at segment start (seconds)
263
+ VAD_PAD_OFFSET = 0.05 # Padding at segment end (seconds)
264
+
265
+ # Frame-level voting
266
+ VOTING_RATE = 0.01 # 10ms resolution for consensus voting
267
+
268
+ # Post-processing
269
+ MIN_SEGMENT_DURATION = 0.15 # Minimum final segment duration (seconds)
270
+ SHORT_SEGMENT_GAP = 0.1 # Gap threshold for merging short segments
271
+ SAME_SPEAKER_GAP = 0.5 # Gap threshold for merging same-speaker segments
272
+
273
+ # ===========================================================
274
+
275
+ @classmethod
276
+ def _get_ten_vad_model(cls):
277
+ """Lazy-load TEN-VAD model (singleton)."""
278
+ if cls._ten_vad_model is None:
279
+ from ten_vad import TenVad
280
+
281
+ cls._ten_vad_model = TenVad(hop_size=256, threshold=cls.VAD_THRESHOLD)
282
+ return cls._ten_vad_model
283
+
284
+ @classmethod
285
+ def _get_device(cls) -> torch.device:
286
+ """Get the best available device."""
287
+ if cls._device is None:
288
+ cls._device = _get_device()
289
+ return cls._device
290
+
291
+ @classmethod
292
+ def _get_ecapa_model(cls):
293
+ """Lazy-load ECAPA-TDNN speaker embedding model (singleton)."""
294
+ if cls._ecapa_model is None:
295
+ # Suppress torchaudio deprecation warning from SpeechBrain
296
+ with warnings.catch_warnings():
297
+ warnings.filterwarnings("ignore", message="torchaudio._backend")
298
+ from speechbrain.inference.speaker import EncoderClassifier
299
+
300
+ device = cls._get_device()
301
+ cls._ecapa_model = EncoderClassifier.from_hparams(
302
+ source="speechbrain/spkrec-ecapa-voxceleb",
303
+ run_opts={"device": str(device)},
304
+ )
305
+
306
+ return cls._ecapa_model
307
+
308
+ @classmethod
309
+ def diarize(
310
+ cls,
311
+ audio: np.ndarray | str,
312
+ sample_rate: int = 16000,
313
+ num_speakers: int | None = None,
314
+ min_speakers: int = 2,
315
+ max_speakers: int = 10,
316
+ **_kwargs,
317
+ ) -> list[dict]:
318
+ """Run speaker diarization on audio.
319
+
320
+ Args:
321
+ audio: Audio waveform as numpy array or path to audio file
322
+ sample_rate: Audio sample rate (default 16000)
323
+ num_speakers: Exact number of speakers (if known)
324
+ min_speakers: Minimum number of speakers
325
+ max_speakers: Maximum number of speakers
326
+
327
+ Returns:
328
+ List of dicts with 'speaker', 'start', 'end' keys
329
+ """
330
+ # Handle file path input
331
+ if isinstance(audio, str):
332
+ import librosa
333
+
334
+ audio, sample_rate = librosa.load(audio, sr=16000)
335
+
336
+ # Ensure correct sample rate
337
+ if sample_rate != 16000:
338
+ import librosa
339
+
340
+ audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
341
+ sample_rate = 16000
342
+
343
+ audio = audio.astype(np.float32)
344
+ total_duration = len(audio) / sample_rate
345
+
346
+ # Step 1: VAD (returns segments and raw frame-level decisions)
347
+ segments, vad_frames = cls._get_speech_segments(audio, sample_rate)
348
+ if not segments:
349
+ return []
350
+
351
+ # Step 2: Extract embeddings
352
+ embeddings, window_segments = cls._extract_embeddings(audio, segments, sample_rate)
353
+ if len(embeddings) == 0:
354
+ return []
355
+
356
+ # Step 3: Cluster
357
+ clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
358
+ labels = clusterer(embeddings, num_speakers)
359
+
360
+ # Step 4: Post-process with consensus voting (VAD-aware)
361
+ return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
362
+
363
+ @classmethod
364
+ def _get_speech_segments(
365
+ cls, audio_array: np.ndarray, sample_rate: int = 16000
366
+ ) -> tuple[list[dict], list[bool]]:
367
+ """Get speech segments using TEN-VAD.
368
+
369
+ Returns:
370
+ Tuple of (segments list, vad_frames list of per-frame speech decisions)
371
+ """
372
+ vad_model = cls._get_ten_vad_model()
373
+
374
+ # Convert to int16 as required by TEN-VAD
375
+ # Clip to prevent integer overflow
376
+ if audio_array.dtype != np.int16:
377
+ audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
378
+ else:
379
+ audio_int16 = audio_array
380
+
381
+ # Process frame by frame
382
+ hop_size = 256
383
+ frame_duration = hop_size / sample_rate
384
+ speech_frames: list[bool] = []
385
+
386
+ for i in range(0, len(audio_int16) - hop_size + 1, hop_size):
387
+ frame = audio_int16[i : i + hop_size]
388
+ _, is_speech = vad_model.process(frame)
389
+ speech_frames.append(is_speech)
390
+
391
+ # Convert frame-level decisions to segments
392
+ segments = []
393
+ in_speech = False
394
+ start_idx = 0
395
+
396
+ for i, is_speech in enumerate(speech_frames):
397
+ if is_speech and not in_speech:
398
+ start_idx = i
399
+ in_speech = True
400
+ elif not is_speech and in_speech:
401
+ start_time = start_idx * frame_duration
402
+ end_time = i * frame_duration
403
+ segments.append(
404
+ {
405
+ "start": start_time,
406
+ "end": end_time,
407
+ "start_sample": int(start_time * sample_rate),
408
+ "end_sample": int(end_time * sample_rate),
409
+ }
410
+ )
411
+ in_speech = False
412
+
413
+ # Handle trailing speech
414
+ if in_speech:
415
+ start_time = start_idx * frame_duration
416
+ end_time = len(speech_frames) * frame_duration
417
+ segments.append(
418
+ {
419
+ "start": start_time,
420
+ "end": end_time,
421
+ "start_sample": int(start_time * sample_rate),
422
+ "end_sample": int(end_time * sample_rate),
423
+ }
424
+ )
425
+
426
+ return cls._apply_vad_hysteresis(segments, sample_rate), speech_frames
427
+
428
+ @classmethod
429
+ def _apply_vad_hysteresis(cls, segments: list[dict], sample_rate: int = 16000) -> list[dict]:
430
+ """Apply hysteresis-like post-processing to VAD segments."""
431
+ if not segments:
432
+ return segments
433
+
434
+ segments = sorted(segments, key=lambda x: x["start"])
435
+
436
+ # Fill short gaps
437
+ merged = [segments[0].copy()]
438
+ for seg in segments[1:]:
439
+ gap = seg["start"] - merged[-1]["end"]
440
+ if gap <= cls.VAD_MAX_GAP:
441
+ merged[-1]["end"] = seg["end"]
442
+ merged[-1]["end_sample"] = seg["end_sample"]
443
+ else:
444
+ merged.append(seg.copy())
445
+
446
+ # Remove short segments
447
+ filtered = [seg for seg in merged if (seg["end"] - seg["start"]) >= cls.VAD_MIN_DURATION]
448
+
449
+ # Dilate segments (add padding)
450
+ for seg in filtered:
451
+ seg["start"] = max(0.0, seg["start"] - cls.VAD_PAD_ONSET)
452
+ seg["end"] = seg["end"] + cls.VAD_PAD_OFFSET
453
+ seg["start_sample"] = int(seg["start"] * sample_rate)
454
+ seg["end_sample"] = int(seg["end"] * sample_rate)
455
+
456
+ return filtered
457
+
458
+ @classmethod
459
+ def _extract_embeddings(
460
+ cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
461
+ ) -> tuple[np.ndarray, list[dict]]:
462
+ """Extract speaker embeddings using sliding windows."""
463
+ speaker_model = cls._get_ecapa_model()
464
+
465
+ window_samples = int(cls.WINDOW_SIZE * sample_rate)
466
+ step_samples = int(cls.STEP_SIZE * sample_rate)
467
+
468
+ embeddings = []
469
+ window_segments = []
470
+
471
+ with torch.no_grad():
472
+ for seg in segments:
473
+ seg_start = seg["start_sample"]
474
+ seg_end = seg["end_sample"]
475
+ seg_len = seg_end - seg_start
476
+
477
+ # Generate window positions
478
+ if seg_len <= window_samples:
479
+ starts = [seg_start]
480
+ ends = [seg_end]
481
+ else:
482
+ starts = list(range(seg_start, seg_end - window_samples + 1, step_samples))
483
+ ends = [s + window_samples for s in starts]
484
+
485
+ # Cover tail if > TAIL_COVERAGE_RATIO of window remains
486
+ if ends and ends[-1] < seg_end:
487
+ remainder = seg_end - ends[-1]
488
+ if remainder > (window_samples * cls.TAIL_COVERAGE_RATIO):
489
+ starts.append(seg_end - window_samples)
490
+ ends.append(seg_end)
491
+
492
+ for c_start, c_end in zip(starts, ends):
493
+ chunk = audio_array[c_start:c_end]
494
+
495
+ # Pad short chunks with reflection
496
+ if len(chunk) < window_samples:
497
+ pad_width = window_samples - len(chunk)
498
+ chunk = np.pad(chunk, (0, pad_width), mode="reflect")
499
+
500
+ # Extract embedding using SpeechBrain's encode_batch
501
+ chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0)
502
+ embedding = (
503
+ speaker_model.encode_batch(chunk_tensor).squeeze(0).squeeze(0).cpu().numpy()
504
+ )
505
+
506
+ # Validate embedding
507
+ if np.isfinite(embedding).all() and np.linalg.norm(embedding) > 1e-8:
508
+ embeddings.append(embedding)
509
+ window_segments.append(
510
+ {
511
+ "start": c_start / sample_rate,
512
+ "end": c_end / sample_rate,
513
+ }
514
+ )
515
+
516
+ # Normalize all embeddings at once
517
+ if embeddings:
518
+ return normalize(np.array(embeddings)), window_segments
519
+ return np.array([]), []
520
+
521
+ @classmethod
522
+ def _resample_vad(cls, vad_frames: list[bool], num_frames: int) -> np.ndarray:
523
+ """Resample VAD frame decisions to match voting grid resolution.
524
+
525
+ VAD operates at 256 samples / 16000 Hz = 16ms per frame.
526
+ Voting operates at VOTING_RATE (default 10ms) per frame.
527
+ This maps VAD decisions to the finer voting grid.
528
+ """
529
+ if not vad_frames:
530
+ return np.zeros(num_frames, dtype=bool)
531
+
532
+ vad_rate = 256 / 16000 # 16ms per VAD frame
533
+ vad_arr = np.array(vad_frames)
534
+
535
+ # Vectorized: compute VAD frame indices for each voting frame
536
+ voting_times = np.arange(num_frames) * cls.VOTING_RATE
537
+ vad_indices = np.clip((voting_times / vad_rate).astype(int), 0, len(vad_arr) - 1)
538
+ return vad_arr[vad_indices]
539
+
540
+ @classmethod
541
+ def _postprocess_segments(
542
+ cls,
543
+ window_segments: list[dict],
544
+ labels: np.ndarray,
545
+ total_duration: float,
546
+ vad_frames: list[bool],
547
+ ) -> list[dict]:
548
+ """Post-process using frame-level consensus voting with VAD-aware silence."""
549
+ if not window_segments or len(labels) == 0:
550
+ return []
551
+
552
+ # Remap labels to contiguous indices 0..K-1
553
+ unique_labels = np.unique(labels)
554
+ label_map = {old: new for new, old in enumerate(unique_labels)}
555
+ clean_labels = np.array([label_map[lbl] for lbl in labels])
556
+ num_speakers = len(unique_labels)
557
+
558
+ if num_speakers == 0:
559
+ return []
560
+
561
+ # Create voting grid
562
+ num_frames = int(np.ceil(total_duration / cls.VOTING_RATE)) + 1
563
+ votes = np.zeros((num_frames, num_speakers), dtype=np.float32)
564
+
565
+ # Accumulate votes
566
+ for win, label in zip(window_segments, clean_labels):
567
+ start_frame = int(win["start"] / cls.VOTING_RATE)
568
+ end_frame = int(win["end"] / cls.VOTING_RATE)
569
+ end_frame = min(end_frame, num_frames)
570
+ if start_frame < end_frame:
571
+ votes[start_frame:end_frame, label] += 1.0
572
+
573
+ # Determine winner per frame
574
+ frame_speakers = np.argmax(votes, axis=1)
575
+ max_votes = np.max(votes, axis=1)
576
+
577
+ # Resample VAD to voting grid resolution for silence-aware voting
578
+ vad_resampled = cls._resample_vad(vad_frames, num_frames)
579
+
580
+ # Convert frames to segments
581
+ final_segments = []
582
+ current_speaker = -1
583
+ seg_start = 0.0
584
+
585
+ for f in range(num_frames):
586
+ speaker = int(frame_speakers[f])
587
+ score = max_votes[f]
588
+
589
+ # Force silence if VAD says no speech OR no votes
590
+ if score == 0 or not vad_resampled[f]:
591
+ speaker = -1
592
+
593
+ if speaker != current_speaker:
594
+ if current_speaker != -1:
595
+ final_segments.append(
596
+ {
597
+ "speaker": f"SPEAKER_{current_speaker}",
598
+ "start": seg_start,
599
+ "end": f * cls.VOTING_RATE,
600
+ }
601
+ )
602
+ current_speaker = speaker
603
+ seg_start = f * cls.VOTING_RATE
604
+
605
+ # Close last segment
606
+ if current_speaker != -1:
607
+ final_segments.append(
608
+ {
609
+ "speaker": f"SPEAKER_{current_speaker}",
610
+ "start": seg_start,
611
+ "end": num_frames * cls.VOTING_RATE,
612
+ }
613
+ )
614
+
615
+ return cls._merge_short_segments(final_segments)
616
+
617
+ @classmethod
618
+ def _merge_short_segments(cls, segments: list[dict]) -> list[dict]:
619
+ """Merge short segments to reduce flicker."""
620
+ if not segments:
621
+ return []
622
+
623
+ clean: list[dict] = []
624
+ for seg in segments:
625
+ dur = seg["end"] - seg["start"]
626
+ if dur < cls.MIN_SEGMENT_DURATION:
627
+ if (
628
+ clean
629
+ and clean[-1]["speaker"] == seg["speaker"]
630
+ and seg["start"] - clean[-1]["end"] < cls.SHORT_SEGMENT_GAP
631
+ ):
632
+ clean[-1]["end"] = seg["end"]
633
+ continue
634
+
635
+ if (
636
+ clean
637
+ and clean[-1]["speaker"] == seg["speaker"]
638
+ and seg["start"] - clean[-1]["end"] < cls.SAME_SPEAKER_GAP
639
+ ):
640
+ clean[-1]["end"] = seg["end"]
641
+ else:
642
+ clean.append(seg)
643
+
644
+ return clean
645
+
646
+ @classmethod
647
+ def assign_speakers_to_words(
648
+ cls,
649
+ words: list[dict],
650
+ speaker_segments: list[dict],
651
+ ) -> list[dict]:
652
+ """Assign speaker labels to words using each word's midpoint timestamp.
653
+
654
+ Args:
655
+ words: List of word dicts with 'word', 'start', 'end' keys
656
+ speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
657
+
658
+ Returns:
659
+ Words list with 'speaker' key added to each word
660
+ """
661
+ for word in words:
662
+ word_mid = (word["start"] + word["end"]) / 2
663
+
664
+ # Find the speaker segment that contains this word's midpoint
665
+ best_speaker = None
666
+ for seg in speaker_segments:
667
+ if seg["start"] <= word_mid <= seg["end"]:
668
+ best_speaker = seg["speaker"]
669
+ break
670
+
671
+ # If no segment contains the midpoint, fall back to the segment with the nearest midpoint
672
+ if best_speaker is None and speaker_segments:
673
+ min_dist = float("inf")
674
+ for seg in speaker_segments:
675
+ seg_mid = (seg["start"] + seg["end"]) / 2
676
+ dist = abs(word_mid - seg_mid)
677
+ if dist < min_dist:
678
+ min_dist = dist
679
+ best_speaker = seg["speaker"]
680
+
681
+ word["speaker"] = best_speaker
682
+
683
+ return words
684
+
685
+
686
+ class SpeakerDiarizer:
687
+ """Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
688
+
689
+ Example:
690
+ >>> segments = SpeakerDiarizer.diarize(audio_array)
691
+ >>> for seg in segments:
692
+ ... print(f"{seg['speaker']}: {seg['start']:.2f} - {seg['end']:.2f}")
693
+ """
694
+
695
+ @classmethod
696
+ def diarize(
697
+ cls,
698
+ audio: np.ndarray | str,
699
+ sample_rate: int = 16000,
700
+ num_speakers: int | None = None,
701
+ min_speakers: int | None = None,
702
+ max_speakers: int | None = None,
703
+ **_kwargs,
704
+ ) -> list[dict]:
705
+ """Run speaker diarization on audio.
706
+
707
+ Args:
708
+ audio: Audio waveform as numpy array or path to audio file
709
+ sample_rate: Audio sample rate (default 16000)
710
+ num_speakers: Exact number of speakers (if known)
711
+ min_speakers: Minimum number of speakers
712
+ max_speakers: Maximum number of speakers
713
+
714
+ Returns:
715
+ List of dicts with 'speaker', 'start', 'end' keys
716
+ """
717
+ return LocalSpeakerDiarizer.diarize(
718
+ audio,
719
+ sample_rate=sample_rate,
720
+ num_speakers=num_speakers,
721
+ min_speakers=min_speakers or 2,
722
+ max_speakers=max_speakers or 10,
723
+ )
724
+
725
+ @classmethod
726
+ def assign_speakers_to_words(
727
+ cls,
728
+ words: list[dict],
729
+ speaker_segments: list[dict],
730
+ ) -> list[dict]:
731
+ """Assign speaker labels to words using each word's midpoint timestamp."""
732
+ return LocalSpeakerDiarizer.assign_speakers_to_words(words, speaker_segments)
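
The sliding-window layout used by `_extract_embeddings` in the file above can be sketched in isolation. This is a minimal, standard-library-only illustration and not the committed implementation: `window_positions` and `overlap_fraction` are hypothetical helper names, and the constants are assumed to mirror `WINDOW_SIZE = 0.75`, `STEP_SIZE = 0.15`, and `TAIL_COVERAGE_RATIO = 0.1` at 16 kHz.

```python
SAMPLE_RATE = 16000
WINDOW = round(0.75 * SAMPLE_RATE)  # window length in samples
STEP = round(0.15 * SAMPLE_RATE)    # hop between window starts in samples


def overlap_fraction(window: int, step: int) -> float:
    """Fraction of each window shared with the next one."""
    return 1.0 - step / window


def window_positions(seg_start, seg_end, window, step, tail_ratio=0.1):
    """Return (start, end) sample pairs covering one speech segment.

    Hypothetical helper that follows the same rules as the diff:
    a segment no longer than one window yields a single (short) chunk,
    otherwise windows are laid out every `step` samples, and one extra
    end-aligned window is added when the uncovered tail is larger than
    `tail_ratio` of a window.
    """
    if seg_end - seg_start <= window:
        return [(seg_start, seg_end)]

    pairs = [(s, s + window) for s in range(seg_start, seg_end - window + 1, step)]

    remainder = seg_end - pairs[-1][1]
    if remainder > window * tail_ratio:
        pairs.append((seg_end - window, seg_end))
    return pairs


if __name__ == "__main__":
    # A 2-second segment starting at sample 0
    for start, end in window_positions(0, 2 * SAMPLE_RATE, WINDOW, STEP):
        print(f"{start / SAMPLE_RATE:.2f}s - {end / SAMPLE_RATE:.2f}s")
```

Under these assumptions each window shares 80% of its samples with the next one, which is what gives the later frame-level voting several votes per 10 ms frame.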
generation_config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "do_sample": false,
5
+ "eos_token_id": [
6
+ 2,
7
+ 0
8
+ ],
9
+ "length_penalty": 1.0,
10
+ "max_new_tokens": 128,
11
+ "min_new_tokens": 0,
12
+ "no_repeat_ngram_size": 0,
13
+ "num_beams": 1,
14
+ "pad_token_id": 2,
15
+ "repetition_penalty": 1.0,
16
+ "transformers_version": "5.6.0",
17
+ "use_cache": true
18
+ }
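
The stopping-related fields in this config (`do_sample: false`, `num_beams: 1`, a list-valued `eos_token_id`, and `max_new_tokens`) describe greedy decoding that ends at the first end-of-sequence token or at the length cap. The toy loop below is only a sketch of that interaction, not the transformers implementation; `greedy_decode`, `scripted`, and the `next_token` callback (a stand-in for taking the argmax over model logits) are hypothetical names.

```python
def greedy_decode(next_token, prompt_ids, eos_token_ids, max_new_tokens):
    """Generate tokens one at a time until an EOS id or the length cap.

    `next_token` is a hypothetical callback: it receives the token ids
    seen so far and returns the single most likely next id.
    """
    eos = set(eos_token_ids)
    generated = []
    for _ in range(max_new_tokens):
        token = next_token(prompt_ids + generated)
        generated.append(token)
        if token in eos:  # any id in the list ends generation
            break
    return generated


def scripted(tokens):
    """Build a fake `next_token` that replays a fixed token sequence."""
    it = iter(tokens)
    return lambda _context: next(it)


# Values taken from the config above: eos_token_id [2, 0], max_new_tokens 128
stopped_on_eos = greedy_decode(scripted([5, 7, 0, 9]), [1], [2, 0], 128)

# With no EOS ever produced, the cap alone ends generation
stopped_on_cap = greedy_decode(lambda _context: 5, [1], [2, 0], 4)
```

In this sketch the EOS token itself is kept in the returned list; whether it is stripped afterwards is a decoding-time choice and is not determined by this config.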
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c5c68a46ee3dd399313e095fc4f8526d4c25bcf92932d67729c85c9ed386dc6
3
+ size 3081536
preprocessor_config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "chunk_length": 30,
3
+ "dither": 0.0,
4
+ "feature_extractor_type": "WhisperFeatureExtractor",
5
+ "feature_size": 80,
6
+ "hop_length": 160,
7
+ "n_fft": 400,
8
+ "n_samples": 480000,
9
+ "nb_max_frames": 3000,
10
+ "padding_side": "right",
11
+ "padding_value": 0.0,
12
+ "return_attention_mask": false,
13
+ "sampling_rate": 16000,
14
+ "processor_class": "ASRProcessor",
15
+ "auto_map": {
16
+ "AutoProcessor": "asr_processing.ASRProcessor"
17
+ }
18
+ }
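
Several numbers in this preprocessor config are derived from one another, which makes a quick consistency check possible. The sketch below copies the relevant fields into a plain dict (the variable names are illustrative) and assumes the usual Whisper-style relationships: `n_samples = chunk_length * sampling_rate` and `nb_max_frames = n_samples // hop_length`.

```python
# Fields copied from the preprocessor_config.json shown above
config = {
    "chunk_length": 30,      # seconds of audio per chunk
    "sampling_rate": 16000,  # samples per second
    "hop_length": 160,       # samples between consecutive feature frames
    "n_samples": 480000,
    "nb_max_frames": 3000,
}

# 30 s * 16000 samples/s = 480000 samples per chunk
expected_samples = config["chunk_length"] * config["sampling_rate"]

# 480000 samples / 160 samples per hop = 3000 feature frames per chunk
expected_frames = expected_samples // config["hop_length"]

# Each feature frame advances by 160 / 16000 s = 10 ms
frame_ms = 1000 * config["hop_length"] / config["sampling_rate"]

print(expected_samples, expected_frames, frame_ms)
```

If any of the three base values (`chunk_length`, `sampling_rate`, `hop_length`) is edited by hand, the derived fields need to be updated to match.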
projectors.py ADDED
@@ -0,0 +1,505 @@
1
+ """Audio projector modules for bridging encoder and decoder embeddings.
2
+
3
+ This module contains all projector architectures:
4
+ - MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
5
+ - MOSAProjector: MOSA-style dense mixture of experts
6
+ - MoEAudioProjector: Shared expert + sparse routed experts
7
+ - QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
8
+ """
9
+
10
+ import math
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F # noqa: N812
15
+ from transformers import AutoModel, Blip2QFormerConfig
16
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
17
+
18
+ # =============================================================================
19
+ # MLP Projector
20
+ # =============================================================================
21
+
22
+
23
+ class MLPAudioProjector(nn.Module):
24
+ """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""
25
+
26
+ def __init__(self, config):
27
+ """Initialize MLP projector.
28
+
29
+ Args:
30
+ config: ASRConfig with encoder_dim, llm_dim, projector_pool_stride
31
+ """
32
+ super().__init__()
33
+
34
+ encoder_dim = getattr(config, "encoder_dim", 768)
35
+ llm_dim = getattr(config, "llm_dim", 2048)
36
+ self.k = getattr(config, "projector_pool_stride", 4)
37
+
38
+ # Frame stacking: concat k adjacent frames then project
39
+ in_dim = encoder_dim * self.k
40
+ # Hidden dim defaults to llm_dim, can be overridden via config
41
+ hidden_dim = getattr(config, "projector_hidden_dim", None) or llm_dim
42
+ self.linear_1 = nn.Linear(in_dim, hidden_dim, bias=False)
43
+ self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
44
+ self.act = nn.GELU()
45
+ self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
46
+
47
+ def get_output_length(self, input_length: int) -> int:
48
+ """Calculate output sequence length given input length (matches GLM-ASR)."""
49
+ # GLM-ASR formula: (L - merge_factor) // merge_factor + 1
50
+ return (input_length - self.k) // self.k + 1
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ """Project audio features to LLM embedding space.
54
+
55
+ Args:
56
+ x: Audio encoder output of shape [batch, seq_len, encoder_dim]
57
+
58
+ Returns:
59
+ Projected features of shape [batch, (seq_len - k) // k + 1, llm_dim]
60
+ """
61
+ batch, seq, dim = x.shape
62
+ # Truncate to match GLM-ASR: use (seq - k) // k + 1 frames
63
+ # This drops trailing frames that don't fill a complete k-frame window
64
+ out_len = (seq - self.k) // self.k + 1
65
+ x = x[:, : out_len * self.k, :] # Truncate to exact multiple
66
+ x = x.reshape(batch, out_len, dim * self.k)
67
+
68
+ x = self.linear_1(x)
69
+ x = self.norm(x)
70
+ x = self.act(x)
71
+ return self.linear_2(x)
72
+
73
+
74
+ # =============================================================================
75
+ # MoE Projector (MOSA-style)
76
+ # =============================================================================
77
+
78
+
79
+ class SimpleAdapter(nn.Module):
80
+ """Simple 2-layer GELU adapter (from MOSA paper)."""
81
+
82
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
83
+ super().__init__()
84
+ self.fc1 = nn.Linear(input_dim, hidden_dim)
85
+ self.act = nn.GELU()
86
+ self.fc2 = nn.Linear(hidden_dim, output_dim)
87
+
88
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
89
+ return self.fc2(self.act(self.fc1(x)))
90
+
91
+
92
+ class SwiGLU(nn.Module):
93
+ """SwiGLU activation with gated linear units (used in LLaMA, Mistral, etc.)."""
94
+
95
+ def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
96
+ super().__init__()
97
+ self.w1 = nn.Linear(dim, hidden_dim, bias=bias) # Gate
98
+ self.w2 = nn.Linear(dim, hidden_dim, bias=bias) # Value
99
+ self.w3 = nn.Linear(hidden_dim, dim, bias=bias) # Output
100
+
101
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
102
+ return self.w3(F.silu(self.w1(x)) * self.w2(x))
103
+
104
+
105
+ class AsymmetricSwiGLU(nn.Module):
106
+ """SwiGLU that handles different input and output dimensions."""
107
+
108
+ def __init__(
109
+ self, in_features: int, hidden_features: int, out_features: int, bias: bool = False
110
+ ):
111
+ super().__init__()
112
+ self.w1 = nn.Linear(in_features, hidden_features, bias=bias) # Gate
113
+ self.w2 = nn.Linear(in_features, hidden_features, bias=bias) # Value
114
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias) # Output
115
+
116
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
117
+ return self.w3(F.silu(self.w1(x)) * self.w2(x))
118
+
119
+
120
+ class MOSAProjector(nn.Module):
121
+ """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
122
+
123
+ Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
124
+ Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
125
+ Uses Conv1d for downsampling (2 layers, stride 2 each = 4x total).
126
+ """
127
+
128
+ def __init__(self, config):
129
+ """Initialize MOSA projector.
130
+
131
+ Args:
132
+ config: ASRConfig with encoder_dim, llm_dim, num_experts
133
+ """
134
+ super().__init__()
135
+ self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
136
+ self.llm_dim = getattr(config, "llm_dim", None) or 2048
137
+ self.num_experts = getattr(config, "num_experts", None) or 4 # MOSA-Base uses 4
138
+ adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
139
+ router_hidden = getattr(config, "router_hidden_dim", None) or 512
140
+
141
+ # --- 1. Conv1d Downsampler (4x reduction) ---
142
+ # 2 layers of stride-2 convolution
143
+ self.downsampler = nn.Sequential(
144
+ nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=3, stride=2, padding=1),
145
+ nn.GELU(),
146
+ nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
147
+ nn.GELU(),
148
+ )
149
+
150
+ # --- 2. Simple Router (MOSA-Base: 2 layers with ReLU) ---
151
+ # Takes downsampled features (llm_dim) -> 512 -> num_experts
152
+ self.router = nn.Sequential(
153
+ nn.Linear(self.llm_dim, router_hidden),
154
+ nn.ReLU(),
155
+ nn.Linear(router_hidden, self.num_experts),
156
+ )
157
+
158
+ # --- 3. Experts (Simple 2-layer GELU adapters) ---
159
+ # Each expert: llm_dim -> hidden -> llm_dim (much smaller than frame-stacking)
160
+ self.experts = nn.ModuleList(
161
+ [
162
+ SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
163
+ for _ in range(self.num_experts)
164
+ ]
165
+ )
166
+
167
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
168
+ """Project audio features using mixture of experts.
169
+
170
+ Args:
171
+ x: Audio encoder output of shape [batch, seq_len, encoder_dim]
172
+
173
+ Returns:
174
+ Projected features of shape [batch, out_len, llm_dim]
175
+ """
176
+ # --- 1. Conv1d Downsampling ---
177
+ # Permute for Conv1d: [B, S, D] -> [B, D, S]
178
+ x = x.transpose(1, 2)
179
+ x = self.downsampler(x)
180
+ # Permute back: [B, D, S] -> [B, S, D]
181
+ x = x.transpose(1, 2)
182
+
183
+ # --- 2. Routing ---
184
+ routing_weights = F.softmax(self.router(x), dim=-1) # (B, out_len, num_experts)
185
+
186
+ # --- 3. Expert Mixture (Dense Execution) ---
187
+ expert_outputs = torch.stack([expert(x) for expert in self.experts]) # (E, B, out_len, D)
188
+ return torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
189
+
190
+ def get_output_length(self, input_length: int) -> int:
191
+ """Calculate output sequence length after Conv1d downsampling (4x reduction)."""
192
+ # Conv1d with stride 2, kernel 3, padding 1: out = (in + 2*1 - 3) // 2 + 1 = (in - 1) // 2 + 1
193
+ # Applied twice for 4x total reduction
194
+ after_conv1 = (input_length + 2 * 1 - 3) // 2 + 1
195
+ return (after_conv1 + 2 * 1 - 3) // 2 + 1
196
+
197
+
198
+ # =============================================================================
199
+ # MoE Projector (Pure PyTorch with Shared Expert)
200
+ # =============================================================================
201
+
202
+
203
+ class MoEAudioProjector(nn.Module):
204
+ """MoE projector with shared expert (DeepSeek-style), pure PyTorch implementation.
205
+
206
+ By default, uses 4 sparse experts with top-2 routing (both configurable) plus a shared expert that processes all tokens.
207
+ No external dependencies (megablocks removed).
208
+
209
+ Architecture matches main branch: norm → experts(in_dim → hidden → out_dim)
210
+ """
211
+
212
+ def __init__(self, config):
213
+ """Initialize MoE projector.
214
+
215
+ Args:
216
+ config: ASRConfig with encoder_dim, llm_dim, num_experts, num_experts_per_tok
217
+ """
218
+ super().__init__()
219
+
220
+ self.k = getattr(config, "projector_pool_stride", 4)
221
+ self.aux_coef = getattr(config, "router_aux_loss_coef", 0.01)
222
+
223
+ # Stability coefficients
224
+ self.router_z_loss_coef = getattr(
225
+ config, "router_z_loss_coef", 1e-4
226
+ ) # Prevents logit explosion
227
+ self.router_jitter_noise = getattr(
228
+ config, "router_jitter_noise", 0.01
229
+ ) # Prevents expert collapse
230
+
231
+ in_dim = config.encoder_dim * self.k
232
+ out_dim = config.llm_dim
233
+
234
+ # Expert hidden dim (default = output dim)
235
+ hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim
236
+
237
+ # Number of experts and top-k selection
238
+ self.num_experts = getattr(config, "num_experts", 4)
239
+ self.top_k = getattr(config, "num_experts_per_tok", 2)
240
+
241
+ # A. Normalize stacked input (like main branch SharedMoEBlock)
242
+ self.norm = LlamaRMSNorm(in_dim, eps=1e-6)
243
+
244
+ # B. Router (operates on stacked input)
245
+ self.router = nn.Linear(in_dim, self.num_experts, bias=False)
246
+
247
+ # C. Experts: simple 2-layer MLP (same as MLPAudioProjector)
248
+ self.experts = nn.ModuleList(
249
+ [SimpleAdapter(in_dim, hidden_dim, out_dim) for _ in range(self.num_experts)]
250
+ )
251
+
252
+ # D. Shared Expert (same architecture)
253
+ self.shared_expert = SimpleAdapter(in_dim, hidden_dim, out_dim)
254
+
255
+ # E. Initialize weights for stable training
256
+ self._init_weights()
257
+
258
+ self.last_aux_loss = torch.tensor(0.0)
259
+
260
+ def _init_weights(self):
261
+ """Initialize weights for stable training start."""
262
+ with torch.no_grad():
263
+ # Router: small weights -> near-uniform routing probabilities at initialization
264
+ nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
265
+
266
+ # Experts: xavier for fc1, small for fc2 (output)
267
+ for expert in [self.shared_expert, *self.experts]:
268
+ nn.init.xavier_uniform_(expert.fc1.weight)
269
+ nn.init.normal_(expert.fc2.weight, mean=0.0, std=0.01) # Small init
270
+
271
+ def get_output_length(self, input_length: int) -> int:
272
+ """Calculate output sequence length given input length (matches MLP projector)."""
273
+ return (input_length - self.k) // self.k + 1
274
+
275
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
276
+ """Project audio features using shared + sparse MoE.
277
+
278
+ Args:
279
+ x: Audio encoder output of shape [batch, seq_len, encoder_dim]
280
+
281
+ Returns:
282
+ Projected features of shape [batch, out_len, llm_dim]
283
+ """
284
+ # 1. Frame Stacking
285
+ batch, seq, dim = x.shape
286
+ out_len = (seq - self.k) // self.k + 1
287
+ x = x[:, : out_len * self.k, :]
288
+ x = x.reshape(batch, out_len, dim * self.k)
289
+
290
+ # 2. Normalize stacked input (like main branch SharedMoEBlock)
291
+ x = self.norm(x)
292
+ flat_x = x.view(-1, x.size(-1)) # [tokens, in_dim]
293
+
294
+ # 3. Shared Expert (compute first, creates output tensor)
295
+ output = self.shared_expert(flat_x)
296
+
297
+ # 4. Sparse Experts (in-place add to shared output)
298
+ self.last_aux_loss = self._forward_sparse(flat_x, output)
299
+
300
+ return output.view(batch, out_len, -1)
301
+
302
+ def _forward_sparse(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
303
+ """Stability-hardened sparse expert dispatch (in-place add to output).
304
+
305
+ Args:
306
+ x: Flattened input of shape [tokens, dim]
307
+ output: Output tensor to add sparse expert results into (in-place)
308
+
309
+ Returns:
310
+ Auxiliary loss tensor
311
+ """
312
+ # A. Router Logic with Jitter
313
+ logits = self.router(x)
314
+
315
+ if self.training and self.router_jitter_noise > 0:
316
+ # Jitter: multiply by uniform noise (1-eps, 1+eps) to shake decision boundary
317
+ # Prevents router from getting stuck on one expert early in training
318
+ noise = torch.empty_like(logits).uniform_(
319
+ 1.0 - self.router_jitter_noise, 1.0 + self.router_jitter_noise
320
+ )
321
+ logits = logits * noise
322
+
323
+ # Force float32 for softmax (bf16/fp16 exponentials can overflow)
324
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(x)
325
+
326
+ # B. Top-K Selection
327
+ top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)
328
+
329
+ # Normalize weights so they sum to 1.0
330
+ top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)
331
+
332
+ # C. Aux Loss + Z-Loss
333
+ aux_loss = torch.tensor(0.0, device=x.device)
334
+
335
+ if self.training:
336
+ # Load balancing loss (batch-size invariant)
337
+ prob_per_expert = probs.mean(0) # [num_experts]
338
+ target = 1.0 / self.num_experts
339
+ balance_loss = (
340
+ self.aux_coef * ((prob_per_expert - target) ** 2).mean() * self.num_experts
341
+ )
342
+
343
+ # Z-loss: penalty on large logits to prevent softmax saturation
344
+ z_loss = self.router_z_loss_coef * torch.logsumexp(logits, dim=-1).pow(2).mean()
345
+
346
+ aux_loss = balance_loss + z_loss
347
+
348
+ # D. Dispatch Loop (in-place add to output)
349
+ for i, expert in enumerate(self.experts):
350
+ # Create boolean mask for tokens that selected Expert 'i'
351
+ mask = top_k_indices == i
352
+
353
+ if mask.any():
354
+ # token_idx = which tokens, k_idx = rank of this expert within each token's top-k
355
+ token_idx, k_idx = torch.where(mask)
356
+
357
+ # Gather inputs and compute
358
+ expert_input = x[token_idx]
359
+ expert_output = expert(expert_input)
360
+
361
+ # Apply routing weight
362
+ weight = top_k_weights[token_idx, k_idx].unsqueeze(-1)
363
+ weighted_output = (expert_output * weight).type_as(output)
364
+
365
+ # Scatter back in-place; token_idx has no duplicates here because top-k indices are distinct per token
366
+ output.index_add_(0, token_idx, weighted_output)
367
+
368
+ return aux_loss
369
+
370
+ def get_aux_loss(self) -> torch.Tensor:
371
+ """Return the auxiliary loss (load balancing + router z-loss) from the last forward pass."""
372
+ return self.last_aux_loss
373
+
374
+
375
+ # =============================================================================
376
+ # QFormer Projector (Granite-style)
377
+ # =============================================================================
378
+
379
+
380
+ class QFormerAudioProjector(nn.Module):
381
+ """
382
+ BLIP-2 QFormer projector with learnable queries.
383
+
384
+ Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable
385
+ query embeddings to compress and project audio encoder outputs. The audio
386
+ sequence is processed in windows and downsampled via cross-attention.
387
+ """
388
+
389
+ def __init__(self, config):
390
+ """Initialize QFormer projector.
391
+
392
+ Args:
393
+ config: ASRConfig with encoder_dim, llm_dim, qformer_* settings
394
+ """
395
+ super().__init__()
396
+
397
+ encoder_dim = config.encoder_dim
398
+ llm_dim = config.llm_dim
399
+
400
+ # Window and downsampling parameters (Granite defaults: window=15, downsample=5)
401
+ self.window_size = getattr(config, "qformer_window_size", 15)
402
+ self.downsample_rate = getattr(config, "downsample_rate", 5)
403
+ self.num_queries = self.window_size // self.downsample_rate
404
+
405
+ # QFormer hidden size (matches encoder for cross-attention)
406
+ qformer_hidden = getattr(config, "qformer_hidden_size", None) or encoder_dim
407
+ qformer_num_layers = getattr(config, "qformer_num_layers", 2)
408
+ qformer_num_heads = getattr(config, "qformer_num_heads", 16)
409
+ qformer_intermediate = getattr(config, "qformer_intermediate_size", None) or (
410
+ qformer_hidden * 4
411
+ )
412
+
413
+ # Learnable query embeddings (Granite uses std=1.0)
414
+ self.query = nn.Parameter(torch.zeros(1, self.num_queries, qformer_hidden))
415
+ self.query.data.normal_(mean=0.0, std=1.0)
416
+
417
+ # Optional projection if encoder dim != qformer hidden
418
+ if encoder_dim != qformer_hidden:
419
+ self.encoder_proj = nn.Linear(encoder_dim, qformer_hidden, bias=False)
420
+ else:
421
+ self.encoder_proj = None
422
+
423
+ # Configure QFormer to match Granite's exact config
424
+ qformer_config = Blip2QFormerConfig(
425
+ hidden_size=qformer_hidden,
426
+ num_hidden_layers=qformer_num_layers,
427
+ num_attention_heads=qformer_num_heads,
428
+ intermediate_size=qformer_intermediate,
429
+ encoder_hidden_size=qformer_hidden,
430
+ cross_attention_frequency=1,
431
+ # Granite-specific settings
432
+ hidden_act="gelu",
433
+ attention_probs_dropout_prob=0.1,
434
+ hidden_dropout_prob=0.1,
435
+ layer_norm_eps=1e-12,
436
+ initializer_range=0.02,
437
+ )
438
+ self.qformer = AutoModel.from_config(qformer_config)
439
+
440
+ # Final projection to LLM dimension (Granite uses bias=True)
441
+ self.linear = nn.Linear(qformer_hidden, llm_dim)
442
+
443
+ def get_output_length(self, input_length: int) -> int:
444
+ """Calculate output sequence length given input length."""
445
+ # QFormer uses window-based processing with num_queries per window
446
+ nblocks = math.ceil(input_length / self.window_size)
447
+ return nblocks * self.num_queries
448
+
449
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
450
+ """
451
+ Args:
452
+ hidden_states: [batch_size, seq_len, encoder_dim]
453
+
454
+ Returns:
455
+ projected: [batch_size, num_output_tokens, llm_dim]
456
+ """
457
+ batch_size, seq_len, dim = hidden_states.size()
458
+
459
+ # Cast the input to the dtype of the QFormer parameters
460
+ target_dtype = self.query.dtype
461
+ if hidden_states.dtype != target_dtype:
462
+ hidden_states = hidden_states.to(target_dtype)
463
+
464
+ # Optional encoder projection
465
+ if self.encoder_proj is not None:
466
+ hidden_states = self.encoder_proj(hidden_states)
467
+
468
+ # Compute number of windows and pad to fit
469
+ nblocks = math.ceil(seq_len / self.window_size)
470
+ pad = nblocks * self.window_size - seq_len
471
+ if pad > 0:
472
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)
473
+
474
+ # Reshape to process each window: [batch*nblocks, window_size, dim]
475
+ effective_batch = batch_size * nblocks
476
+ hidden_states = hidden_states.reshape(effective_batch, self.window_size, -1)
477
+
478
+ # Expand queries to match batch size
479
+ query_embeds = self.query.expand(effective_batch, -1, -1)
480
+
481
+ # QFormer cross-attention
482
+ query_output = self.qformer(
483
+ query_embeds=query_embeds,
484
+ encoder_hidden_states=hidden_states,
485
+ return_dict=True,
486
+ )
487
+
488
+ # Reshape back: [batch, nblocks * num_queries, hidden]
489
+ output_tokens = nblocks * self.num_queries
490
+ query_proj = query_output.last_hidden_state.view(batch_size, output_tokens, -1)
491
+
492
+ # Project to LLM dimension
493
+ return self.linear(query_proj)
494
+
495
+
496
+ # =============================================================================
497
+ # Projector Registry
498
+ # =============================================================================
499
+
500
+ PROJECTOR_CLASSES = {
501
+ "mlp": MLPAudioProjector,
502
+ "mosa": MOSAProjector,
503
+ "moe": MoEAudioProjector,
504
+ "qformer": QFormerAudioProjector,
505
+ }
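The MOSA, MoE, and QFormer projectors registered above each expose `get_output_length`, and each shortens the audio sequence differently. The sketch below restates those three formulas as standalone functions so they can be compared without instantiating any module. The function names are hypothetical, and the defaults (two stride-2 convolutions, `k=4`, `window_size=15`, `downsample_rate=5`) are assumed from the defaults visible in the diff.

```python
import math


def conv_downsample_length(n: int, layers: int = 2) -> int:
    """Length after `layers` Conv1d(kernel=3, stride=2, padding=1) layers."""
    for _ in range(layers):
        n = (n + 2 * 1 - 3) // 2 + 1
    return n


def frame_stack_length(n: int, k: int = 4) -> int:
    """Length after stacking k consecutive frames; trailing frames are dropped."""
    return (n - k) // k + 1


def qformer_length(n: int, window_size: int = 15, downsample_rate: int = 5) -> int:
    """Length after windowed QFormer pooling; the last window is zero-padded."""
    num_queries = window_size // downsample_rate
    return math.ceil(n / window_size) * num_queries


if __name__ == "__main__":
    # Compare the three reductions on a few input lengths.
    for n in (100, 101, 103):
        print(n, conv_downsample_length(n), frame_stack_length(n), qformer_length(n))
```

The convolutional path rounds up on odd lengths while frame stacking truncates, so the two "4x" projectors can differ by one token for the same input. Any code that pre-allocates audio placeholder tokens would need to use the formula of the projector actually selected.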
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<|im_start|>",
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "<|im_end|>",
7
+ "errors": "replace",
8
+ "extra_special_tokens": [
9
+ "<audio>"
10
+ ],
11
+ "is_local": false,
12
+ "local_files_only": false,
13
+ "model_max_length": 8192,
14
+ "pad_token": "<|im_end|>",
15
+ "tokenizer_class": "GPT2Tokenizer",
16
+ "unk_token": "<|endoftext|>",
17
+ "vocab_size": 49152
18
+ }
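In the config above, `pad_token` and `eos_token` are the same string, which may matter for code that identifies padding by token id rather than by attention mask. The sketch below is a minimal check, using only the standard `json` module, that lists which special-token roles share a token. The helper name `shared_special_tokens` is hypothetical, and the embedded JSON copies only the relevant fields.

```python
import json

# Fields copied from the tokenizer_config.json above (other keys omitted).
CONFIG_TEXT = """
{
  "bos_token": "<|im_start|>",
  "eos_token": "<|im_end|>",
  "pad_token": "<|im_end|>",
  "unk_token": "<|endoftext|>",
  "extra_special_tokens": ["<audio>"],
  "model_max_length": 8192
}
"""

ROLE_KEYS = ("bos_token", "eos_token", "pad_token", "unk_token")


def shared_special_tokens(config: dict) -> dict:
    """Map each token string used by more than one role to the roles using it."""
    roles_by_token = {}
    for role in ROLE_KEYS:
        token = config.get(role)
        if token is not None:
            roles_by_token.setdefault(token, []).append(role)
    return {tok: roles for tok, roles in roles_by_token.items() if len(roles) > 1}


config = json.loads(CONFIG_TEXT)
shared = shared_special_tokens(config)
print(shared)
print("<audio>" in config["extra_special_tokens"])
```

For this config the only shared token is `<|im_end|>`, used for both end-of-sequence and padding, and `<audio>` is registered as an extra special token.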