import json
from pathlib import Path
from threading import Thread
from typing import Iterator, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    TextIteratorStreamer,
)
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

try:
    from .asr_config import ASRConfig, compute_encoder_output_length
    from .projectors import PROJECTOR_CLASSES
except ImportError:
    from asr_config import ASRConfig, compute_encoder_output_length
    from projectors import PROJECTOR_CLASSES

def _gather_audio_embeds(audio_embeds: torch.Tensor, token_counts: torch.Tensor) -> torch.Tensor:
    """Flatten per-sample audio embeddings into a packed tensor.

    For each row i, takes the first ``token_counts[i]`` rows of
    ``audio_embeds[i]`` and concatenates them. If any token count exceeds
    ``audio_embeds.shape[1]``, the deficit is zero-padded.

    Equivalent to a per-sample slice/cat loop but with O(1) host-device
    syncs per call (one ``max().item()``) instead of one per sample.
    """
    _, max_len, _ = audio_embeds.shape
    needed = int(token_counts.max().item())
    if needed > max_len:
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, needed - max_len))
        max_len = needed
    indices = torch.arange(max_len, device=audio_embeds.device).unsqueeze(0)
    mask = indices < token_counts.unsqueeze(1)
    return audio_embeds[mask]
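# Illustrative example (not executed): for audio_embeds of shape (2, 4, H) and
# token_counts = tensor([2, 3]), the mask keeps rows [0, 1] of sample 0 and rows
# [0, 1, 2] of sample 1, so the result has shape (5, H) with sample 0's frames
# first, then sample 1's.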
class ASRModel(PreTrainedModel, GenerationMixin):
    """Audio-to-text model combining an audio encoder, projector, and language model."""

    config_class = ASRConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True
    _is_loading_from_pretrained: bool = False

    TRANSCRIBE_PROMPT = "Transcribe the speech to text"
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
        """Load model from pretrained, handling device placement correctly."""
        from safetensors.torch import load_file
        from transformers.utils.hub import cached_file

        config = kwargs.pop("config", None)
        if config is None:
            config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Signal __init__ to skip LoRA setup; adapters are attached below once the
        # checkpoint weights have been loaded.
        cls._is_loading_from_pretrained = True

        try:
            model = cls(config, **kwargs)

            # Locate the checkpoint file, honoring optional subfolder/revision.
            subfolder = kwargs.get("subfolder")
            revision = kwargs.get("revision")
            cache_kwargs = {}
            if subfolder:
                cache_kwargs["subfolder"] = subfolder
            if revision:
                cache_kwargs["revision"] = revision

            model_file = cached_file(
                pretrained_model_name_or_path,
                "model.safetensors",
                _raise_exceptions_for_missing_entries=False,
                **cache_kwargs,
            )

            if model_file is not None:
                state_dict = load_file(model_file)
                model.load_state_dict(state_dict, strict=False)

            # Re-attach LoRA when the checkpoint was trained with it.
            if getattr(config, "use_lora", False):
                adapter_config_file = cached_file(
                    pretrained_model_name_or_path,
                    "adapter_config.json",
                    _raise_exceptions_for_missing_entries=False,
                    **cache_kwargs,
                )
                if adapter_config_file is not None:
                    # Saved adapters exist: load them onto the base language model.
                    from peft import PeftModel

                    model.language_model = PeftModel.from_pretrained(
                        model.language_model,
                        pretrained_model_name_or_path,
                        is_trainable=True,
                        **cache_kwargs,
                    )
                else:
                    # No saved adapters: create fresh LoRA modules from the config.
                    from peft import LoraConfig, get_peft_model

                    lora_config = LoraConfig(
                        r=config.lora_rank,
                        lora_alpha=config.lora_alpha,
                        target_modules=config.lora_target_modules,
                        lora_dropout=config.lora_dropout,
                        bias="none",
                        task_type="CAUSAL_LM",
                    )
                    model.language_model = get_peft_model(model.language_model, lora_config)

            return model
        finally:
            cls._is_loading_from_pretrained = False
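    # Illustrative usage (the repo id is a placeholder, not part of this repo):
    #
    #     model = ASRModel.from_pretrained("your-org/your-asr-checkpoint")
    #     model.eval()
    #
    # Checkpoints trained with use_lora=True load the same way; saved adapters are
    # re-attached automatically when adapter_config.json is present.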
    def __init__(self, config: ASRConfig, **kwargs) -> None:
        super().__init__(config)

        self.system_prompt = config.system_prompt
        target_dtype = getattr(torch, config.model_dtype)

        # Frozen audio encoder.
        self.audio_tower = self._load_audio_encoder(config, target_dtype)

        # Language model (frozen unless the config says otherwise).
        self.language_model = self._load_language_model(config, target_dtype)

        # Tokenizer, extended with the <audio> placeholder token.
        self._init_tokenizer(config)

        # Generation defaults, taken from the config.
        self.generation_config = self.language_model.generation_config
        self.generation_config.max_new_tokens = config.max_new_tokens
        self.generation_config.min_new_tokens = config.min_new_tokens
        self.generation_config.num_beams = config.num_beams
        self.generation_config.do_sample = config.do_sample

        # Sampling and length/repetition controls.
        self.generation_config.temperature = config.temperature
        self.generation_config.top_p = config.top_p
        self.generation_config.top_k = config.top_k
        self.generation_config.use_cache = config.use_cache
        self.generation_config.length_penalty = config.length_penalty
        self.generation_config.repetition_penalty = config.repetition_penalty
        self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size

        # Prefer explicit chat end-of-turn tokens as EOS when the tokenizer has them.
        eos_candidates = [
            self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
            self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
        ]
        self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None]
        self.generation_config.pad_token_id = self.tokenizer.pad_token_id

        # Feature extractor matching the audio encoder.
        self.feature_extractor = self._create_feature_extractor(config)

        # Trainable projector mapping encoder frames to LM embedding space.
        self.projector = self._create_projector(config, target_dtype)

        # Stage-2 LoRA; skipped while from_pretrained drives construction so that
        # adapters are attached only after the base weights have been loaded.
        if getattr(config, "use_lora", False) and not getattr(
            self.__class__, "_is_loading_from_pretrained", False
        ):
            self._setup_lora(config)

        # Optionally freeze the projector as well.
        if getattr(config, "freeze_projector", False):
            self.projector.requires_grad_(False)

        # Reuse the LM's no-split modules for device_map sharding.
        self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
    def _create_feature_extractor(self, config: ASRConfig):
        """Create the appropriate feature extractor for the audio encoder."""
        from transformers import AutoFeatureExtractor

        feature_extractor = AutoFeatureExtractor.from_pretrained(config.audio_model_id)

        # Whisper extractors pad/trim to fixed-length windows themselves; for other
        # encoders, leave the features unpadded here.
        if "whisper" not in config.audio_model_id.lower():
            feature_extractor.padding = False
        return feature_extractor
    @classmethod
    def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
        """Load and freeze the audio encoder."""
        encoder_kwargs = {
            "attn_implementation": config.attn_implementation,
            "low_cpu_mem_usage": True,
            "dtype": dtype,
        }

        if "whisper" in config.audio_model_id.lower():
            from transformers import WhisperModel

            full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
            encoder = full_model.encoder
            del full_model
        elif "glm" in config.audio_model_id.lower():
            # GLM audio checkpoints ship a full seq2seq model; keep only the audio tower.
            from transformers import AutoModelForSeq2SeqLM

            full_model = AutoModelForSeq2SeqLM.from_pretrained(
                config.audio_model_id, trust_remote_code=True, **encoder_kwargs
            )
            encoder = full_model.audio_tower
            # Drop the unused components so they can be garbage-collected.
            full_model.language_model = None
            full_model.multi_modal_projector = None
            del full_model
        else:
            encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)

        encoder.requires_grad_(False)
        encoder.eval()
        return encoder
    @classmethod
    def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
        """Load and freeze the language model."""
        decoder_kwargs = {
            "attn_implementation": config.attn_implementation,
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
            "dtype": dtype,
        }

        decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
        decoder.config.use_cache = getattr(config, "use_cache", True)
        if getattr(config, "freeze_language_model", True):
            decoder.requires_grad_(False)
            decoder.train(False)
        return decoder
    def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
        """Create the trainable audio projector."""
        # Auto-detect encoder/LM hidden sizes when they are not given in the config.
        if config.encoder_dim is None:
            enc_cfg = self.audio_tower.config
            config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
                enc_cfg, "d_model", None
            )
            if config.encoder_dim is None:
                raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")

        if config.llm_dim is None:
            dec_cfg = self.language_model.config
            config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
                dec_cfg, "d_model", None
            )
            if config.llm_dim is None:
                raise ValueError("Could not auto-detect llm_dim. Please specify in config.")

        # Instantiate the requested projector type.
        projector_type = getattr(config, "projector_type", "mlp")
        projector_class = PROJECTOR_CLASSES.get(projector_type)
        if projector_class is None:
            raise ValueError(
                f"Unknown projector_type: {projector_type}. "
                f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
            )
        projector = projector_class(config)

        # Match the language model's device and the target dtype.
        device = next(self.language_model.parameters()).device
        return projector.to(device=device, dtype=dtype)
    def _setup_lora(self, config: ASRConfig):
        """Apply LoRA adapters to the language model for Stage 2 fine-tuning."""
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=config.lora_rank,
            lora_alpha=config.lora_alpha,
            target_modules=config.lora_target_modules,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.language_model = get_peft_model(self.language_model, lora_config)
    def _init_tokenizer(self, config: ASRConfig):
        """Initialize tokenizer with audio token."""
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)

        # Prefer a dedicated pad token over reusing EOS, falling back to EOS only
        # when nothing better is available.
        if (
            self.tokenizer.pad_token is None
            or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
        ):
            if "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
                self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
            elif self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

        # Register the <audio> placeholder token if it is not already present.
        existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
        if "<audio>" not in existing_special:
            self.tokenizer.add_special_tokens(
                {"additional_special_tokens": existing_special + ["<audio>"]}
            )

        # Grow the LM embedding matrix to cover any newly added tokens.
        self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=True)

        self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
        self.tokenizer.padding_side = "right"

        # Keep special-token ids consistent across the configs used for generation.
        for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
            if cfg is not None:
                cfg.pad_token_id = self.tokenizer.pad_token_id
                cfg.eos_token_id = self.tokenizer.eos_token_id
                cfg.bos_token_id = self.tokenizer.bos_token_id
    def train(self, mode: bool = True):
        """Set train/eval mode, but keep frozen submodules out of train mode.

        HF Trainer calls `model.train()` at the top of every training step, which
        recursively switches every submodule into train mode, re-enabling dropout
        on modules with `requires_grad_(False)`. The frozen encoder (and the LM
        when `freeze_language_model=True`) should always run deterministically;
        train-mode dropout only adds noise that can't improve a frozen network.
        """
        super().train(mode)
        self.audio_tower.train(False)
        if getattr(self.config, "freeze_language_model", True):
            self.language_model.train(False)
        return self
    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
        """Enable/disable gradient checkpointing for the language model."""
        # Prefer the LM's private hook when it exists; otherwise fall back to the
        # public enable/disable API.
        if hasattr(self.language_model, "_set_gradient_checkpointing"):
            self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
        elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
            self.language_model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
        elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
            self.language_model.gradient_checkpointing_disable()
    def get_input_embeddings(self) -> nn.Module:
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_output_embeddings(value)
    def get_processor(self):
        """Get the processor for this model."""
        try:
            from .asr_processing import ASRProcessor
        except ImportError:
            from asr_processing import ASRProcessor

        return ASRProcessor(
            feature_extractor=self.feature_extractor,
            tokenizer=self.tokenizer,
            projector=self.projector,
            encoder_conv_layers=self.config.encoder_conv_layers,
        )
    def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
        """Save trainable weights: projector, plus the language model when fine-tuned.

        With LoRA attached, the language_model entries are flattened to plain
        (non-PEFT) HF naming so model.safetensors round-trips through
        ASRModel.from_pretrained, which builds a vanilla base LM, overlays
        these weights, and only then re-attaches PEFT. lora_*/adapter weights
        are skipped here; PEFT serializes them separately as
        adapter_model.safetensors via the save_pretrained path below.
        """
        sd = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
        if not getattr(self.config, "freeze_language_model", True):
            lm = self.language_model
            if hasattr(lm, "peft_config"):
                for k, v in lm.state_dict().items():
                    if "lora_" in k:
                        continue
                    if k.startswith("base_model.model."):
                        k = k[len("base_model.model.") :]
                    # PEFT wraps adapted Linear layers; collapse the extra level.
                    k = k.replace(".base_layer.", ".")
                    sd[f"language_model.{k}"] = v
            else:
                sd.update({f"language_model.{k}": v for k, v in lm.state_dict().items()})
        return sd
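    # Illustrative key mapping (hypothetical layer names) when LoRA is attached:
    #   "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight"
    #       -> "language_model.model.layers.0.self_attn.q_proj.weight"
    #   "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
    #       -> skipped here; PEFT saves it with the adapter weights instead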
    def _compute_encoder_output_lengths(
        self,
        audio_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute per-sample encoder output lengths using conv layer formulas."""
        return compute_encoder_output_length(
            audio_attention_mask.sum(dim=-1),
            self.config.encoder_conv_layers,
        )
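    # The mel-frame -> encoder-frame mapping is assumed to be the standard 1-D conv
    # length formula applied once per entry in encoder_conv_layers (the actual
    # implementation lives in asr_config.compute_encoder_output_length):
    #
    #     out_len = (in_len + 2 * padding - kernel_size) // stride + 1
    #
    # e.g. a Whisper-style front end (two kernel-3 convs, strides 1 and 2, padding 1)
    # roughly halves the number of mel frames.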
    def _encode_audio(
        self,
        audio_features: torch.Tensor,
        expected_token_counts: torch.Tensor,
    ) -> torch.Tensor:
        """Encode audio features and return flattened embeddings matching expected_token_counts.

        Args:
            audio_features: Mel spectrogram features (batch, n_mels, mel_len)
            expected_token_counts: Per-sample audio token counts as int64 tensor (batch,).

        Returns:
            Flattened audio embeddings of shape (sum(expected_token_counts), hidden_dim).
        """
        with torch.no_grad():
            encoder_out = self.audio_tower(input_features=audio_features)
            hidden_states = encoder_out.last_hidden_state

        hidden_states = self._maybe_drop_audio_tokens(hidden_states)
        audio_embeds = self.projector(hidden_states)

        token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
        return _gather_audio_embeds(audio_embeds, token_counts)
    def _maybe_drop_audio_tokens(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Per-time-step Bernoulli zero-mask on encoder output (train-only).

        SpecAugment-equivalent for frozen-encoder setups: drops whole frames
        from the encoder output sequence so the projector learns robustness
        to missing context. Length-preserving (zeros, not deletions) so
        audio token counts in the prompt stay consistent. No magnitude
        rescaling, so the projector should not learn to compensate.
        """
        p = float(getattr(self.config, "audio_token_dropout", 0.0))
        if not self.training or p <= 0.0:
            return hidden_states
        keep = 1.0 - p
        mask = torch.bernoulli(
            torch.full(
                hidden_states.shape[:-1],
                keep,
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            )
        ).unsqueeze(-1)
        return hidden_states * mask
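    # Example: with config.audio_token_dropout = 0.1, each encoder frame is zeroed
    # independently with probability 0.1 during training; at eval time (or when the
    # value is 0.0) the hidden states pass through unchanged.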
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_features: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        audio_token_counts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Forward pass for training and inference."""
        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        if input_features is not None and input_ids is not None:
            is_audio_token = input_ids == self.audio_token_id
            if audio_token_counts is None:
                audio_token_counts = is_audio_token.sum(dim=-1)
            else:
                audio_token_counts = audio_token_counts.to(
                    device=input_ids.device, dtype=torch.long
                )

            audio_embeds = self._encode_audio(input_features, audio_token_counts)

            # Splice audio embeddings into the <audio> placeholder positions.
            audio_token_mask = is_audio_token.unsqueeze(-1)
            inputs_embeds = inputs_embeds.masked_scatter(
                audio_token_mask.to(inputs_embeds.device),
                audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
            )

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        # Add any auxiliary projector loss to the language-model loss.
        if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
            aux_loss = self.projector.get_aux_loss()
            if aux_loss is not None and aux_loss.numel() > 0:
                outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)

        return outputs
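    # Shape contract for the splice in forward(): the number of <audio> tokens in
    # input_ids must match audio_embeds.shape[0] (the summed per-sample token
    # counts), because masked_scatter fills the placeholder positions in order.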
    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Prepare inputs for generation, handling audio features for cached decoding."""
        input_features = kwargs.pop("input_features", None)
        cache_position = kwargs.get("cache_position")

        model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)

        # Audio features are only needed on the prefill step; later steps reuse the KV cache.
        if cache_position is not None and cache_position[0] == 0 and input_features is not None:
            model_inputs["input_features"] = input_features

        return model_inputs
    def _get_num_audio_tokens(
        self,
        audio_attention_mask: torch.Tensor,
    ) -> int:
        """Calculate number of audio tokens based on actual audio length.

        Uses attention mask to get real audio length, then computes:
        mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
        """
        encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
        # Use the longest sample in the batch so every prompt has enough placeholders.
        encoder_output_len = int(encoder_lengths.max().item())
        return int(self.projector.get_output_length(encoder_output_len))
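    # Rough illustrative arithmetic (figures are assumptions, not measurements):
    # with a Whisper-style front end (10 ms hop, 2x conv downsampling), 10 s of audio
    # gives ~1000 mel frames -> ~500 encoder frames; a projector that pools 4 frames
    # per output token would then emit ~125 <audio> placeholders.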
    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_features: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        system_prompt: Optional[str] = None,
        **generate_kwargs,
    ) -> torch.Tensor:
        """Generate transcription from audio input.

        Can be called in two ways:
        1. With input_ids containing <audio> tokens (from processor)
        2. With just audio, and we build the prompt internally
        """
        if input_features is None:
            raise ValueError("input_features required for generation")
        if audio_attention_mask is None:
            raise ValueError("audio_attention_mask required for generation")

        device = input_features.device
        batch_size = input_features.shape[0]

        # Encode the audio once; token_counts gives the exact number of projector
        # outputs per sample.
        encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
        token_counts = self.projector.get_output_length(encoder_lengths).to(torch.long)
        audio_embeds = self._encode_audio(input_features, token_counts)

        # Build a chat prompt with <audio> placeholders when the caller did not
        # supply input_ids.
        if input_ids is None:
            num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
            audio_placeholder = "<audio>" * num_audio_tokens

            system_prompt = system_prompt or self.system_prompt

            messages: list[dict[str, str]] = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})

            user_content = audio_placeholder
            if self.TRANSCRIBE_PROMPT:
                user_content += " " + self.TRANSCRIBE_PROMPT
            messages.append({"role": "user", "content": user_content})

            chat_result = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
                enable_thinking=False,
            )
            input_ids = chat_result.input_ids.to(device)

            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
            if input_ids.shape[0] == 1 and batch_size > 1:
                input_ids = input_ids.expand(batch_size, -1)

            attention_mask = torch.ones_like(input_ids)

        # Splice audio embeddings into the <audio> placeholder positions.
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
        inputs_embeds = inputs_embeds.masked_scatter(
            audio_token_mask.to(inputs_embeds.device),
            audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
        )

        # Pass both input_ids and inputs_embeds: the embeddings drive the prefill,
        # while input_ids lets generate() return prompt + continuation so the
        # prompt length can be sliced off below.
        output = self.language_model.generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=self.generation_config,
            **generate_kwargs,
        )

        # Strip the prompt tokens so only the transcription is returned.
        sequences = output if isinstance(output, torch.Tensor) else output.sequences
        input_len = input_ids.shape[1]
        return sequences[:, input_len:]
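    # Illustrative batch transcription (the processor call signature is assumed
    # from ASRProcessor's constructor above, not guaranteed by this file):
    #
    #     processor = model.get_processor()
    #     inputs = processor(audio=waveform, sampling_rate=16_000, return_tensors="pt")
    #     token_ids = model.generate(**inputs)
    #     text = model.tokenizer.batch_decode(token_ids, skip_special_tokens=True)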
    def generate_streaming(
        self,
        input_features: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        system_prompt: Optional[str] = None,
        **generate_kwargs,
    ) -> Iterator[str]:
        """Generate transcription with streaming token output.

        Yields partial transcript strings as tokens are generated.
        Reduces time-to-first-word by streaming tokens as they're decoded.

        Args:
            input_features: Mel spectrogram features (batch, n_mels, mel_len)
            audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
            system_prompt: Optional system prompt override
            **generate_kwargs: Additional generation arguments

        Yields:
            Partial transcript text as each token is generated
        """
        device = input_features.device
        batch_size = input_features.shape[0]

        # Encode the audio and compute per-sample token counts.
        encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
        token_counts = self.projector.get_output_length(encoder_lengths).to(torch.long)
        audio_embeds = self._encode_audio(input_features, token_counts)

        # Build the chat prompt with <audio> placeholders.
        num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
        audio_placeholder = "<audio>" * num_audio_tokens

        system_prompt = system_prompt or self.system_prompt

        messages: list[dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        user_content = audio_placeholder
        if self.TRANSCRIBE_PROMPT:
            user_content += " " + self.TRANSCRIBE_PROMPT
        messages.append({"role": "user", "content": user_content})

        chat_result = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
            enable_thinking=False,
        )
        input_ids = chat_result.input_ids.to(device)

        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if input_ids.shape[0] == 1 and batch_size > 1:
            input_ids = input_ids.expand(batch_size, -1)

        attention_mask = torch.ones_like(input_ids)

        # Splice audio embeddings into the <audio> placeholder positions.
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
        inputs_embeds = inputs_embeds.masked_scatter(
            audio_token_mask.to(inputs_embeds.device),
            audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
        )

        # Stream decoded text as it is produced.
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        gen_kwargs = {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "generation_config": self.generation_config,
            "streamer": streamer,
            **generate_kwargs,
        }

        # Run generation in a background thread so we can yield from the streamer.
        thread = Thread(target=self.language_model.generate, kwargs=gen_kwargs)
        thread.start()

        # Filter out <think>...</think> spans emitted by reasoning-style chat models
        # so only transcript text is yielded.
        in_think_block = False
        buffer = ""

        for text in streamer:
            buffer += text

            # Text before a <think> tag is transcript; everything after it is held back.
            while "<think>" in buffer:
                in_think_block = True
                before_think = buffer.split("<think>")[0]
                if before_think:
                    yield before_think
                buffer = buffer.split("<think>", 1)[-1]

            # Resume yielding once the think block closes.
            while in_think_block and "</think>" in buffer:
                in_think_block = False
                buffer = buffer.split("</think>", 1)[-1]

            if not in_think_block and buffer:
                yield buffer
                buffer = ""

        # Flush whatever remains after the stream ends.
        if buffer and not in_think_block:
            yield buffer

        thread.join()
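    # Illustrative streaming loop (arguments follow the docstring above):
    #
    #     for piece in model.generate_streaming(input_features, audio_attention_mask):
    #         print(piece, end="", flush=True)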
    def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
        """Save model, tokenizer, and processor."""
        import shutil

        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Keep the saved config consistent with the (possibly resized) models.
        self.config.vocab_size = self.language_model.config.vocab_size
        self.config.text_config.vocab_size = self.language_model.config.vocab_size

        if hasattr(self.audio_tower.config, "num_mel_bins"):
            self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins

        # Temporarily detach the tokenizer attribute so the base save_pretrained
        # does not try to handle it; it is saved explicitly below.
        tokenizer = self.tokenizer
        del self.tokenizer

        try:
            super().save_pretrained(save_dir, **kwargs)
        finally:
            self.tokenizer = tokenizer

        self.tokenizer.save_pretrained(save_dir)
        self.feature_extractor.save_pretrained(save_dir)

        # Save LoRA adapters separately (adapter_model.safetensors) when present.
        if hasattr(self.language_model, "peft_config"):
            self.language_model.save_pretrained(save_dir, save_embedding_layers=False)

            # Point adapter_config.json at the hub repo (when known) so that
            # pipeline()/PeftModel can resolve the base model.
            adapter_config_path = save_dir / "adapter_config.json"
            if adapter_config_path.exists():
                with adapter_config_path.open() as f:
                    adapter_config = json.load(f)

                # Best-effort repo id: explicit kwargs first, then the config.
                repo_id = (
                    kwargs.get("repo_id")
                    or kwargs.get("push_to_hub_model_id")
                    or getattr(self.config, "pretrained_model_path", None)
                    or ""
                )
                adapter_config["base_model_name_or_path"] = repo_id

                with adapter_config_path.open("w") as f:
                    json.dump(adapter_config, f, indent=2)

        # Register the custom processor class in preprocessor_config.json.
        config_path = save_dir / "preprocessor_config.json"
        if config_path.exists():
            with config_path.open() as f:
                processor_config = json.load(f)
        else:
            processor_config = {}

        processor_config.update(
            {
                "processor_class": "ASRProcessor",
                "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
            }
        )

        with config_path.open("w") as f:
            json.dump(processor_config, f, indent=2)

        # Ship the source files needed for trust_remote_code loading.
        src_dir = Path(__file__).parent
        for asr_file in src_dir.glob("asr_*.py"):
            shutil.copy(asr_file, save_dir / asr_file.name)
        shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
        shutil.copy(src_dir / "alignment.py", save_dir / "alignment.py")
        shutil.copy(src_dir / "diarization.py", save_dir / "diarization.py")
    def push_to_hub(self, repo_id: str, **kwargs) -> str:
        """Push model to HuggingFace Hub, ensuring adapter_config points to repo.

        IMPORTANT: Sets base_model_name_or_path in adapter_config.json to repo_id
        so that transformers pipeline() can load the model correctly. Without this,
        the pipeline tries to load from "None" which fails.
        """
        # Remember the target repo so save_pretrained writes it into adapter_config.json.
        self.config.pretrained_model_path = repo_id

        return super().push_to_hub(repo_id, **kwargs)


AutoModel.register(ASRConfig, ASRModel)
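# Registering the config/class pair lets AutoModel.from_config(ASRConfig(...)) and
# AutoModel loading of checkpoints whose config resolves to ASRConfig return an
# ASRModel instance within the current process.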