| | from typing import Optional, Union |
| |
|
| | import torch |
| | import transformers |
| | from transformers import ProcessorMixin |
| |
|
| | try: |
| | from .asr_config import ASRConfig |
| | except ImportError: |
| | from asr_config import ASRConfig |
| |
|
| |
|
class ASRProcessor(ProcessorMixin):
    """Processor for Whisper-based ASR models.

    Wraps a feature extractor (audio -> ``input_features``) and a tokenizer
    (chat-formatted prompt -> ``input_ids``). When audio is supplied, the
    user turn of the prompt is ``TRANSCRIBE_PROMPT`` followed by
    ``AUDIO_TOKEN`` repeated as many times as the projector reports via
    ``get_output_length``.
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
    AUDIO_TOKEN = "<audio>"
    TRANSCRIBE_PROMPT = "Transcribe: "

    def __init__(self, feature_extractor, tokenizer, projector=None):
        """Store the components and resolve the id of ``AUDIO_TOKEN``.

        Args:
            feature_extractor: Callable that turns raw audio into model features.
            tokenizer: Tokenizer providing ``convert_tokens_to_ids`` and
                ``apply_chat_template``.
            projector: Object exposing ``get_output_length(encoder_len)``.
                Optional at construction time, but required by ``__call__``
                whenever ``audio`` is passed.
        """
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
        self.projector = projector

    def __call__(
        self,
        audio: Optional[Union[list, "torch.Tensor"]] = None,
        text: Optional[str] = None,
        system_prompt: Optional[str] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict:
        """Process audio and text inputs for inference.

        Args:
            audio: Raw audio waveform(s).
            text: Target transcription (optional, for training - but use
                DataCollator instead). When given it is appended as the
                assistant turn and no generation prompt is added.
            system_prompt: Optional system prompt, prepended as a system turn.
            return_tensors: Return format ("pt" for PyTorch). The audio-length
                computation and the attention mask use torch operations.
            **kwargs: Forwarded to the feature extractor.

        Returns:
            Dict with ``input_ids`` and ``attention_mask``; when ``audio`` is
            given it also contains ``input_features`` and
            ``audio_attention_mask``.

        Raises:
            ValueError: If ``audio`` is given but the processor was created
                without a ``projector``.
        """
        result = {}

        if audio is not None:
            # Fail early with an actionable message instead of an opaque
            # AttributeError on ``None.get_output_length`` further down.
            if self.projector is None:
                raise ValueError(
                    "ASRProcessor needs a `projector` to compute the number of "
                    "audio tokens; pass `projector=...` when constructing the "
                    "processor or set `processor.projector` before passing audio."
                )

            audio_inputs = self.feature_extractor(
                audio,
                sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
                return_attention_mask=True,
                return_tensors=return_tensors,
                **kwargs,
            )
            result["input_features"] = audio_inputs["input_features"]
            result["audio_attention_mask"] = audio_inputs["attention_mask"]

            # Longest unpadded feature length across the batch.
            real_mel_len = audio_inputs["attention_mask"].sum(dim=-1).max().item()
            # NOTE(review): the halving presumably mirrors the encoder's 2x
            # temporal downsampling - confirm against the encoder in use.
            encoder_output_len = real_mel_len // 2
            num_audio_tokens = self.projector.get_output_length(encoder_output_len)
        else:
            num_audio_tokens = 0

        # The user turn is the transcription prompt plus one placeholder per
        # audio token.
        user_content = self.TRANSCRIBE_PROMPT
        if num_audio_tokens > 0:
            user_content += self.AUDIO_TOKEN * num_audio_tokens

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        if text is not None:
            messages.append({"role": "assistant", "content": text})

        # Only ask for a generation prompt when no target text was supplied.
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=(text is None),
            return_tensors=return_tensors,
        )

        # Guarantee a leading batch dimension.
        if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        result["input_ids"] = input_ids
        result["attention_mask"] = torch.ones_like(input_ids)

        return result
| |
|
| |
|
# Make the processor discoverable through the Auto* machinery: mark it for
# auto-class export and map ASRConfig to it in AutoProcessor.
ASRProcessor.register_for_auto_class(auto_class="AutoProcessor")
transformers.AutoProcessor.register(
    config_class=ASRConfig,
    processor_class=ASRProcessor,
)
| |
|