Training in progress - step 1000
- asr_config.py  +7 -0
- asr_modeling.py  +24 -0
- config.json  +2 -1
- generation_config.json  +1 -1
- model.safetensors  +1 -1
asr_config.py CHANGED

@@ -51,6 +51,12 @@ class ASRConfig(transformers.PretrainedConfig):
         downsample_rate: int = 5,  # Granite default
         projector_hidden_dim: Optional[int] = None,
         projector_type: str = "mlp",  # "mlp", "mosa", "moe", "qformer"
+        # Per-time-step Bernoulli zero-mask on encoder output before the
+        # projector (training-only). 0.05–0.15 is the SpecAugment-equivalent
+        # range for frozen-encoder setups; drops whole encoder frames so
+        # the projector learns robustness to missing context. No magnitude
+        # rescaling. 0.0 disables.
+        audio_token_dropout: float = 0.0,
         # MoE-specific configuration
         num_experts: int = 4,  # Number of experts in MoE projectors
         num_experts_per_tok: int = 2,  # Top-k experts per token
@@ -117,6 +123,7 @@ class ASRConfig(transformers.PretrainedConfig):
         self.downsample_rate = downsample_rate
         self.projector_hidden_dim = projector_hidden_dim
         self.projector_type = projector_type
+        self.audio_token_dropout = audio_token_dropout
         # MoE-specific configuration
         self.num_experts = num_experts
         self.num_experts_per_tok = num_experts_per_tok
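For context, a minimal usage sketch of the new knob (not part of the commit): constructing the config directly with frame dropout enabled. The 0.1 value mirrors what this commit writes into config.json; importing asr_config as a local module is an assumption based on the auto_map entries.

# Hypothetical sketch: enable encoder-frame dropout when building the config.
# Assumes asr_config.py is on the Python path (e.g. in the checkpoint dir).
from asr_config import ASRConfig

config = ASRConfig(
    projector_type="mlp",
    audio_token_dropout=0.1,  # ~10% of encoder frames zeroed per training step
)
assert config.audio_token_dropout == 0.1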
asr_modeling.py CHANGED

@@ -449,11 +449,35 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         encoder_out = self.audio_tower(input_features=audio_features)
         hidden_states = encoder_out.last_hidden_state

+        hidden_states = self._maybe_drop_audio_tokens(hidden_states)
         audio_embeds = self.projector(hidden_states)

         token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
         return _gather_audio_embeds(audio_embeds, token_counts)

+    def _maybe_drop_audio_tokens(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Per-time-step Bernoulli zero-mask on encoder output (train-only).
+
+        SpecAugment-equivalent for frozen-encoder setups: drops whole frames
+        from the encoder output sequence so the projector learns robustness
+        to missing context. Length-preserving (zeros, not deletions) so
+        audio token counts in the prompt stay consistent. No magnitude
+        rescaling — the projector should not learn to compensate.
+        """
+        p = float(getattr(self.config, "audio_token_dropout", 0.0))
+        if not self.training or p <= 0.0:
+            return hidden_states
+        keep = 1.0 - p
+        mask = torch.bernoulli(
+            torch.full(
+                hidden_states.shape[:-1],
+                keep,
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+        ).unsqueeze(-1)
+        return hidden_states * mask
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
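To make the masking semantics concrete, here is a small self-contained sketch (not from the repo) that replicates the same Bernoulli frame mask on a dummy tensor and checks the two properties the docstring promises: roughly p of the frames are zeroed in training mode, and the tensor passes through untouched at inference.

import torch

def drop_frames(hidden_states: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    """Standalone replica of _maybe_drop_audio_tokens, for illustration only."""
    if not training or p <= 0.0:
        return hidden_states
    keep = 1.0 - p
    mask = torch.bernoulli(
        torch.full(hidden_states.shape[:-1], keep,
                   device=hidden_states.device, dtype=hidden_states.dtype)
    ).unsqueeze(-1)
    return hidden_states * mask  # zeroed frames; note: no 1/keep rescaling

torch.manual_seed(0)
x = torch.randn(2, 1000, 256)  # (batch, encoder frames, hidden dim)
y = drop_frames(x, p=0.1, training=True)
zeroed = (y.abs().sum(dim=-1) == 0).float().mean().item()
print(f"fraction of zeroed frames: {zeroed:.3f}")  # ~0.10
assert drop_frames(x, p=0.1, training=False) is x  # inference path is a no-op

Unlike torch.nn.Dropout, surviving activations are not rescaled by 1/keep, so train-time magnitudes run slightly low relative to inference; the commit's docstring flags this as deliberate (the projector should not learn to compensate).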
config.json CHANGED

@@ -103,6 +103,7 @@
   },
   "audio_model_id": "zai-org/GLM-ASR-Nano-2512",
   "audio_sample_rate": 16000,
+  "audio_token_dropout": 0.1,
   "auto_map": {
     "AutoConfig": "asr_config.ASRConfig",
     "AutoModel": "asr_modeling.ASRModel",
@@ -340,7 +341,7 @@
   "text_model_id": "Qwen/Qwen3-0.6B",
   "top_k": null,
   "top_p": null,
-  "transformers_version": "5.
+  "transformers_version": "5.7.0",
   "use_cache": false,
   "use_lora": false,
   "vocab_size": 151670
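Since config.json routes AutoConfig and AutoModel through the custom asr_config/asr_modeling modules, the checkpoint loads via trust_remote_code. A hedged loading sketch (the repo id is a placeholder, not from this commit):

from transformers import AutoConfig, AutoModel

REPO = "your-org/your-asr-checkpoint"  # placeholder repo id

config = AutoConfig.from_pretrained(REPO, trust_remote_code=True)
print(config.audio_token_dropout)  # 0.1, as written by this commit

model = AutoModel.from_pretrained(REPO, trust_remote_code=True)
model.eval()  # eval mode makes _maybe_drop_audio_tokens a no-op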
generation_config.json CHANGED

@@ -12,6 +12,6 @@
   "num_beams": 1,
   "pad_token_id": 151643,
   "repetition_penalty": 1.0,
-  "transformers_version": "5.
+  "transformers_version": "5.7.0",
   "use_cache": true
 }
model.safetensors CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:fa4b92ee0d0748185d09b2af6dea12c49805b253fd9b5350f3ae559299148424
 size 2433494416
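The model.safetensors entry is a Git LFS pointer rather than the weights themselves; the oid is the SHA-256 of the actual ~2.4 GB blob. An illustrative way to check a downloaded file against the new pointer (the file path is an assumption):

import hashlib

EXPECTED = "fa4b92ee0d0748185d09b2af6dea12c49805b253fd9b5350f3ae559299148424"

h = hashlib.sha256()
with open("model.safetensors", "rb") as f:  # path is an assumption
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)
assert h.hexdigest() == EXPECTED, "weights do not match the LFS pointer"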