| | import os |
| | import argparse |
| | import random |
| | import json |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from typing import Union, Optional, Tuple |
| |
|
| | from transformers.modeling_outputs import BaseModelOutputWithPast |
| | from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa |
| | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| | from transformers.processing_utils import Unpack |
| | from transformers import MistralModel, MistralConfig, AutoModel |
| | from transformers.cache_utils import Cache, DynamicCache |
| |
|
| | from .configuration_hinvec import BidirectionalMistralConfig |
| |
|
| | from loguru import logger |
| |
|
| | HINVEC_TYPE = "hinvec" |
| | BIDIR_MISTRAL_TYPE = "bidir_mistral" |
| |
|
class BidirectionalMistralModel(MistralModel):
    """Mistral backbone with bidirectional (non-causal) self-attention.

    Intended for embedding-style use: every attention layer has its causal
    flag disabled so tokens attend to both past and future positions, and
    the attention mask prepared in ``forward`` is a plain padding mask
    rather than a causal one.
    """

    config_class = BidirectionalMistralConfig

    def __init__(self, config: MistralConfig):
        super().__init__(config)
        # Turn off causal masking in every layer so attention is full
        # (bidirectional) over the sequence.
        for layer in self.layers:
            layer.self_attn.is_causal = False
        # NOTE(review): this forces eager attention regardless of the config,
        # so the flash_attention_2 / sdpa branches in forward() are dead
        # unless a caller overrides this attribute afterwards — confirm
        # whether that is intentional.
        self._attn_implementation = "eager"

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the bidirectional encoder.

        Mirrors ``MistralModel.forward`` but prepares a *non-causal* 4D
        attention mask. Exactly one of ``input_ids`` / ``inputs_embeds``
        must be provided.

        Returns:
            ``BaseModelOutputWithPast`` (or a tuple when
            ``return_dict=False``) with the final hidden states, optional
            per-layer hidden states and attentions, and the cache when
            ``use_cache`` is enabled.

        Raises:
            ValueError: if both or neither of ``input_ids`` /
                ``inputs_embeds`` are given, or on right-padded batched
                generation under flash-attention.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if attention_mask is None:
            # BUGFIX: the original read `input_ids.dtype` / `input_ids.device`
            # unconditionally, crashing when only `inputs_embeds` was supplied
            # (a path explicitly allowed above). Use whichever input exists.
            if input_ids is not None:
                mask_dtype, mask_device = input_ids.dtype, input_ids.device
            else:
                mask_dtype, mask_device = torch.long, inputs_embeds.device
            attention_mask = torch.ones(batch_size, seq_length, dtype=mask_dtype, device=mask_device)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Flash attention cannot handle right-padding during cached
        # (batched) generation; fail loudly instead of producing garbage.
        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )

        # Expand the 2D padding mask to the 4D form each backend expects.
        # These helpers build a *non-causal* mask — the core of the
        # bidirectional behavior.
        if self._attn_implementation == "flash_attention_2":
            # Flash attention takes the 2D mask directly; drop it entirely
            # when there is no padding (0s) so the fast path is used.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa" and not output_attentions:
            attention_mask = _prepare_4d_attention_mask_for_sdpa(
                attention_mask, inputs_embeds.dtype
            )
        else:
            attention_mask = _prepare_4d_attention_mask(
                attention_mask, inputs_embeds.dtype,
            )

        hidden_states = inputs_embeds

        # Rotary position embeddings are shared across all layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # Include the final (post-norm) hidden state as the last entry.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        output = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
        return output if return_dict else output.to_tuple()
| |
|
| |
|
| |
|
# Register the custom config/model pair with the Auto* factory so that
# AutoModel.from_pretrained(...) on a checkpoint whose config class is
# BidirectionalMistralConfig instantiates BidirectionalMistralModel.
AutoModel.register(BidirectionalMistralConfig, BidirectionalMistralModel)
# Record the AutoModel mapping on the config class; presumably so that
# save_pretrained / push_to_hub emit the auto_map entry required for
# trust_remote_code loading — verify against the transformers version used.
BidirectionalMistralConfig.register_for_auto_class("AutoModel")