| | import transformers |
| | import torch |
| | from typing import Optional, Tuple, Union |
| | from transformers.modeling_outputs import Seq2SeqLMOutput |
| | from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor |
| | from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE |
| |
|
| |
|
class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
    """Whisper variant for audio captioning.

    Extends ``WhisperForConditionalGeneration`` so that both ``forward`` and
    ``generate`` accept an extra ``forced_ac_decoder_ids`` tensor: caption
    token ids that are forced as a decoder prefix (after the standard Whisper
    special-token prefix) during generation.  ``forward`` merely tolerates the
    argument so batches carrying it (e.g. from a Trainer collator) do not
    raise; it is not forwarded to the parent implementation.
    """

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        forced_ac_decoder_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        """Delegate to the parent ``forward``.

        ``forced_ac_decoder_ids`` is accepted but deliberately dropped here —
        it is consumed only by ``generate`` below.  All other arguments are
        passed through unchanged.
        """
        return super().forward(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        forced_ac_decoder_ids: Optional[torch.Tensor] = None,
        generation_config=None,
        logits_processor=None,
        stopping_criteria=None,
        prefix_allowed_tokens_fn=None,
        synced_gpus=False,
        return_timestamps=None,
        task="transcribe",
        language="english",
        **kwargs,
    ):
        """Generate captions with a forced decoder prefix.

        Builds the standard Whisper special-token prefix
        (``<|startoftranscript|>``, language token, task token, optionally
        ``<|notimestamps|>``), appends ``forced_ac_decoder_ids`` (shape
        ``(batch, prefix_len)`` — TODO confirm against callers), and passes
        the combined prefix as ``decoder_input_ids`` to the base
        ``GenerationMixin.generate``.  Mutates ``generation_config`` in place
        (``return_timestamps``, ``language``, ``task``,
        ``forced_decoder_ids``).
        """
        # Fall back to the model's own generation config when none is given.
        if generation_config is None:
            generation_config = self.generation_config

        if return_timestamps is not None:
            # Timestamp generation requires a properly initialized config.
            if not hasattr(generation_config, "no_timestamps_token_id"):
                raise ValueError(
                    "You are trying to return timestamps, but the generation config is not properly set."
                    "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`."
                    "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
                )
                # NOTE(review): "approtiate" typo and missing spaces between the
                # concatenated string fragments above — runtime string, left as-is.

            generation_config.return_timestamps = return_timestamps
        else:
            generation_config.return_timestamps = False

        # Record the requested language/task on the config so the prefix
        # construction below (and any later callers) see them.
        if language is not None:
            generation_config.language = language
        if task is not None:
            generation_config.task = task

        # Build the forced (position, token_id) prefix, mirroring upstream
        # WhisperForConditionalGeneration.generate.
        forced_decoder_ids = []
        if task is not None or language is not None:
            if hasattr(generation_config, "language"):
                if generation_config.language in generation_config.lang_to_id.keys():
                    # Already a language token like "<|en|>".
                    language_token = generation_config.language
                elif generation_config.language in TO_LANGUAGE_CODE.keys():
                    # A language name like "english" — map to its token.
                    language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
                else:
                    raise ValueError(
                        f"Unsupported language: {language}. Language should be one of:"
                        f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}."
                    )
                    # NOTE(review): in this branch the language is NOT in
                    # TO_LANGUAGE_CODE's keys, so the conditional always lists
                    # the values (language codes), never the names — looks like
                    # an upstream-inherited quirk; confirm before changing.
                # Language token is forced at decoder position 1
                # (position 0 is the decoder start token, added below).
                forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
            else:
                # No language attribute on the config: force a None placeholder.
                # NOTE(review): a None token id here would make the
                # torch.tensor(...) call below fail — presumably this path is
                # never taken when forced_ac_decoder_ids are used; verify.
                forced_decoder_ids.append((1, None))

            if hasattr(generation_config, "task"):
                if generation_config.task in TASK_IDS:
                    # Task token ("transcribe"/"translate") at position 2.
                    forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
                else:
                    raise ValueError(
                        f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
                    )
                    # NOTE(review): missing space in "`task`task" — runtime
                    # string, left as-is.
            else:
                # Default to transcription when the config carries no task.
                forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
            # Unless timestamps were requested, force <|notimestamps|> right
            # after the last forced token.
            if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
                idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
                forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

        # Otherwise fall back to forced ids preconfigured on the model/config.
        elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
            forced_decoder_ids = self.config.forced_decoder_ids
        elif (
            hasattr(self.generation_config, "forced_decoder_ids")
            and self.generation_config.forced_decoder_ids is not None
        ):
            forced_decoder_ids = self.generation_config.forced_decoder_ids

        if generation_config.return_timestamps:
            # NOTE(review): this replaces any caller-supplied logits_processor
            # rather than appending to it — confirm that is intended.
            logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]

        decoder_input_ids = None

        if len(forced_decoder_ids) > 0:
            # Sort by position and ensure position 0 holds the decoder start
            # token so the prefix is a contiguous run starting at 0.
            forced_decoder_ids.sort()
            if min(forced_decoder_ids)[0] != 0:
                forced_decoder_ids = [(0, self.config.decoder_start_token_id)] + forced_decoder_ids

            position_indices, decoder_input_ids = zip(*forced_decoder_ids)
            assert tuple(position_indices) == tuple(range(len(position_indices))), "forced_decoder_ids is not a (continuous) prefix, we can't handle that"

            device = self.get_decoder().device

            # Default to an empty (1, 0) caption prefix when none was given.
            if forced_ac_decoder_ids is None:
                forced_ac_decoder_ids = torch.tensor([[]], device=device, dtype=torch.long)

            # Tile the special-token prefix across the batch and append the
            # caption prefix: final shape (batch, fluff_len + caption_len).
            batch_size = forced_ac_decoder_ids.shape[0]
            fluff_len = len(decoder_input_ids)
            decoder_input_ids = torch.tensor(decoder_input_ids, device=device, dtype=torch.long)
            decoder_input_ids = decoder_input_ids.expand((batch_size, fluff_len))
            decoder_input_ids = torch.cat([decoder_input_ids, forced_ac_decoder_ids], dim=1)

        generation_config.forced_decoder_ids = forced_decoder_ids

        # Deliberately skip WhisperForConditionalGeneration/
        # WhisperPreTrainedModel in the MRO and call the generic
        # GenerationMixin generate, feeding our prebuilt decoder_input_ids.
        return super(transformers.WhisperPreTrainedModel, self).generate(
            inputs,
            generation_config,
            logits_processor,
            stopping_criteria,
            prefix_allowed_tokens_fn,
            synced_gpus,
            decoder_input_ids=decoder_input_ids,
            **kwargs,
        )
| |
|