mazesmazes committed
Commit 558cbae · verified · 1 Parent(s): e494c2f

Training in progress - step 1000

asr_config.py CHANGED
@@ -51,6 +51,12 @@ class ASRConfig(transformers.PretrainedConfig):
         downsample_rate: int = 5,  # Granite default
         projector_hidden_dim: Optional[int] = None,
         projector_type: str = "mlp",  # "mlp", "mosa", "moe", "qformer"
+        # Per-time-step Bernoulli zero-mask on encoder output before the
+        # projector (training-only). 0.05–0.15 is the SpecAugment-equivalent
+        # range for frozen-encoder setups; drops whole encoder frames so
+        # the projector learns robustness to missing context. No magnitude
+        # rescaling. 0.0 disables.
+        audio_token_dropout: float = 0.0,
         # MoE-specific configuration
         num_experts: int = 4,  # Number of experts in MoE projectors
         num_experts_per_tok: int = 2,  # Top-k experts per token
@@ -117,6 +123,7 @@ class ASRConfig(transformers.PretrainedConfig):
         self.downsample_rate = downsample_rate
         self.projector_hidden_dim = projector_hidden_dim
         self.projector_type = projector_type
+        self.audio_token_dropout = audio_token_dropout
         # MoE-specific configuration
         self.num_experts = num_experts
         self.num_experts_per_tok = num_experts_per_tok
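
A minimal usage sketch for the new field (illustrative, not part of the commit; assumes asr_config.py is importable and the remaining constructor defaults are acceptable):

    # Hypothetical example: enable encoder-frame dropout at the rate this
    # commit ships in config.json. The default of 0.0 leaves it disabled.
    from asr_config import ASRConfig

    config = ASRConfig(
        projector_type="mlp",
        audio_token_dropout=0.1,
    )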
asr_modeling.py CHANGED
@@ -449,11 +449,35 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         encoder_out = self.audio_tower(input_features=audio_features)
         hidden_states = encoder_out.last_hidden_state
 
+        hidden_states = self._maybe_drop_audio_tokens(hidden_states)
         audio_embeds = self.projector(hidden_states)
 
         token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
         return _gather_audio_embeds(audio_embeds, token_counts)
 
+    def _maybe_drop_audio_tokens(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Per-time-step Bernoulli zero-mask on encoder output (train-only).
+
+        SpecAugment-equivalent for frozen-encoder setups: drops whole frames
+        from the encoder output sequence so the projector learns robustness
+        to missing context. Length-preserving (zeros, not deletions) so
+        audio token counts in the prompt stay consistent. No magnitude
+        rescaling — the projector should not learn to compensate.
+        """
+        p = float(getattr(self.config, "audio_token_dropout", 0.0))
+        if not self.training or p <= 0.0:
+            return hidden_states
+        keep = 1.0 - p
+        mask = torch.bernoulli(
+            torch.full(
+                hidden_states.shape[:-1],
+                keep,
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+        ).unsqueeze(-1)
+        return hidden_states * mask
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
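
For reference, the masking logic above can be exercised in isolation. A standalone sketch (same math as _maybe_drop_audio_tokens, lifted out of the class; the (batch, time, dim) shapes are illustrative assumptions):

    import torch

    def drop_frames(hidden_states: torch.Tensor, p: float) -> torch.Tensor:
        if p <= 0.0:
            return hidden_states
        keep = 1.0 - p
        # One Bernoulli draw per (batch, time) position, broadcast across the
        # feature dim so each encoder frame is kept or zeroed as a whole.
        mask = torch.bernoulli(
            torch.full(hidden_states.shape[:-1], keep,
                       device=hidden_states.device, dtype=hidden_states.dtype)
        ).unsqueeze(-1)
        # Length-preserving: frames are zeroed, not removed, and there is no
        # 1/keep rescaling, matching the method's docstring.
        return hidden_states * mask

    x = torch.randn(2, 100, 768)
    y = drop_frames(x, p=0.1)
    print(y.shape)                                 # torch.Size([2, 100, 768])
    print((y.abs().sum(-1) == 0).float().mean())   # roughly 0.10

Because the mask is applied before the projector, each training step shows the projector a different random subset of encoder frames while sequence length, and therefore the audio token count in the prompt, stays fixed.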
config.json CHANGED
@@ -103,6 +103,7 @@
   },
   "audio_model_id": "zai-org/GLM-ASR-Nano-2512",
   "audio_sample_rate": 16000,
+  "audio_token_dropout": 0.1,
   "auto_map": {
     "AutoConfig": "asr_config.ASRConfig",
     "AutoModel": "asr_modeling.ASRModel",
@@ -340,7 +341,7 @@
   "text_model_id": "Qwen/Qwen3-0.6B",
   "top_k": null,
   "top_p": null,
-  "transformers_version": "5.6.1",
+  "transformers_version": "5.7.0",
   "use_cache": false,
   "use_lora": false,
   "vocab_size": 151670
generation_config.json CHANGED
@@ -12,6 +12,6 @@
   "num_beams": 1,
   "pad_token_id": 151643,
   "repetition_penalty": 1.0,
-  "transformers_version": "5.6.1",
+  "transformers_version": "5.7.0",
   "use_cache": true
 }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:333242aa801151639eea5cc44531b6ffd9678d6438423fcba55f0194979c1ceb
+oid sha256:fa4b92ee0d0748185d09b2af6dea12c49805b253fd9b5350f3ae559299148424
 size 2433494416