from typing import Optional, Union

import torch

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel

from ...cache_utils import Cache
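

# LlamaModel subclass whose forward defers to the parent implementation and
# then rescales the logits on the returned output (see the comment below).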
class SuperModel(LlamaModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, CausalLMOutputWithPast]:
        out = super().forward(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            cache_position,
        )
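        # Scale the logits in the parent's output in place; 2**4 == 16.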
        out.logits *= 2**4
        return out
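
# If this file is used as a modular `transformers` definition (suggested by the
# relative `...cache_utils` import), the modular converter unrolls the
# `super().forward()` call above into a full standalone forward when it
# generates the final modeling file.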