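| """Tensor-parallel Phi-3 modeling code for SpecForge. |
|
| This module mirrors the Hugging Face ``transformers`` Phi-3 implementation but swaps the dense |
| projections for SpecForge's tensor-parallel layers (``ColumnParallelLinear``, ``RowParallelLinear``, |
| ``VocabParallelEmbedding``, ``ParallelLMHead``) and all-reduces the partial results over the process |
| group returned by ``get_tp_group()``. That tensor-parallel group must therefore be initialized |
| before any of these modules are constructed. |
| """ |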
|
|
|
|
| from typing import Callable, List, Optional, Union |
|
|
| import torch |
| import torch.distributed as dist |
| from torch import nn |
| from transformers import Phi3Config |
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.generation import GenerationMixin |
| from transformers.masking_utils import ( |
| create_causal_mask, |
| create_sliding_window_causal_mask, |
| ) |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| ) |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| from transformers.models.phi3.modeling_phi3 import ( |
| Phi3RMSNorm, |
| Phi3RotaryEmbedding, |
| apply_rotary_pos_emb, |
| eager_attention_forward, |
| ) |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple |
| from transformers.utils.deprecation import deprecate_kwarg |
| from transformers.utils.generic import check_model_inputs |
|
|
| from specforge.distributed import get_tp_group |
| from specforge.layers import ( |
| ColumnParallelLinear, |
| ParallelLMHead, |
| RowParallelLinear, |
| VocabParallelEmbedding, |
| ) |
|
|
|
|
| class Phi3MLP(nn.Module): |
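|     """Phi-3 gated MLP with tensor-parallel projections. |
|
|     ``gate_up_proj`` is column-parallel and ``down_proj`` is row-parallel, so ``forward`` |
|     finishes with an all-reduce of the partial results over the tensor-parallel group. |
|     """ |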
| def __init__(self, config): |
| super().__init__() |
|
|
| self.config = config |
|
|
|         # Tensor-parallel process group; used for the all-reduce after the row-parallel down_proj. |
| self.tp_group = get_tp_group() |
|
|
| self.gate_up_proj = ColumnParallelLinear( |
| config.hidden_size, |
| 2 * config.intermediate_size, |
| bias=False, |
| layout_type="gate_up", |
| ) |
| self.down_proj = RowParallelLinear( |
| config.intermediate_size, config.hidden_size, bias=False |
| ) |
| self.activation_fn = ACT2FN[config.hidden_act] |
|
|
| def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: |
| up_states = self.gate_up_proj(hidden_states) |
|
|
| gate, up_states = up_states.chunk(2, dim=-1) |
| up_states = up_states * self.activation_fn(gate) |
|
|
| down_proj = self.down_proj(up_states) |
|         # down_proj is row-parallel, so each rank holds a partial sum; reduce it across the TP group. |
| dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) |
| return down_proj |
|
|
|
|
| class Phi3Attention(nn.Module): |
|     """Multi-headed attention from the 'Attention Is All You Need' paper, with the fused QKV and |
|     output projections sharded across the tensor-parallel group.""" |
|
|
| def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.head_dim = getattr( |
| config, "head_dim", config.hidden_size // config.num_attention_heads |
| ) |
| self.num_key_value_groups = ( |
| config.num_attention_heads // config.num_key_value_heads |
| ) |
| self.num_key_value_heads = config.num_key_value_heads |
| self.scaling = self.head_dim**-0.5 |
| self.attention_dropout = config.attention_dropout |
| self.is_causal = True |
|
|
|         # Tensor-parallel process group and its size, used to shard the attention heads. |
| self.tp_group = get_tp_group() |
| tp_size = dist.get_world_size(self.tp_group) |
|
|
|         # Attention and KV heads are split evenly across tensor-parallel ranks. |
| self.num_attention_heads_per_rank = config.num_attention_heads // tp_size |
| self.num_key_value_heads_per_rank = config.num_key_value_heads // tp_size |
|
|
|         # Fused QKV projection size for the full model; ColumnParallelLinear shards it across ranks. |
| op_size = config.num_attention_heads * self.head_dim + 2 * ( |
| config.num_key_value_heads * self.head_dim |
| ) |
| self.o_proj = RowParallelLinear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=False |
| ) |
| self.qkv_proj = ColumnParallelLinear( |
| config.hidden_size, op_size, bias=False, layout_type="merged_qkv" |
| ) |
|
|
| @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| attention_mask: Optional[torch.Tensor], |
| past_key_values: Optional[Cache] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| **kwargs: Unpack[FlashAttentionKwargs], |
| ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: |
| input_shape = hidden_states.shape[:-1] |
| hidden_shape = (*input_shape, -1, self.head_dim) |
|
|
| qkv = self.qkv_proj(hidden_states) |
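|         # Split the fused projection into this rank's query, key and value blocks. |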
| query_pos = self.num_attention_heads_per_rank * self.head_dim |
| query_states = qkv[..., :query_pos] |
| key_states = qkv[ |
| ..., |
| query_pos : query_pos + self.num_key_value_heads_per_rank * self.head_dim, |
| ] |
| value_states = qkv[ |
| ..., query_pos + self.num_key_value_heads_per_rank * self.head_dim : |
| ] |
|
|
| query_states = query_states.view(hidden_shape).transpose(1, 2) |
| key_states = key_states.view(hidden_shape).transpose(1, 2) |
| value_states = value_states.view(hidden_shape).transpose(1, 2) |
|
|
| cos, sin = position_embeddings |
| query_states, key_states = apply_rotary_pos_emb( |
| query_states, key_states, cos, sin |
| ) |
|
|
| if past_key_values is not None: |
|             # sin and cos are specific to RoPE models; cache_position needed for the static cache |
| cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
| key_states, value_states = past_key_values.update( |
| key_states, value_states, self.layer_idx, cache_kwargs |
| ) |
|
|
| attention_interface: Callable = eager_attention_forward |
| if self.config._attn_implementation != "eager": |
| attention_interface = ALL_ATTENTION_FUNCTIONS[ |
| self.config._attn_implementation |
| ] |
|
|
| attn_output, attn_weights = attention_interface( |
| self, |
| query_states, |
| key_states, |
| value_states, |
| attention_mask, |
| dropout=0.0 if not self.training else self.attention_dropout, |
| scaling=self.scaling, |
| sliding_window=getattr(self.config, "sliding_window", None), |
| **kwargs, |
| ) |
|
|
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
| attn_output = self.o_proj(attn_output) |
|         # o_proj is row-parallel, so sum the partial outputs across the TP group. |
| dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) |
| return attn_output, attn_weights |
|
|
|
|
| class Phi3DecoderLayer(GradientCheckpointingLayer): |
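|     """Pre-norm decoder layer: RMSNorm -> self-attention -> residual add, then RMSNorm -> MLP -> |
|     residual add, with ``resid_pdrop`` dropout applied to each residual branch.""" |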
| def __init__(self, config: Phi3Config, layer_idx: int): |
| super().__init__() |
| self.hidden_size = config.hidden_size |
| self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) |
| self.mlp = Phi3MLP(config) |
| self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.post_attention_layernorm = Phi3RMSNorm( |
| config.hidden_size, eps=config.rms_norm_eps |
| ) |
| self.config = config |
| self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) |
| self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) |
|
|
| @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| position_embeddings: Optional[ |
| tuple[torch.Tensor, torch.Tensor] |
| ] = None, |
| **kwargs: Unpack[FlashAttentionKwargs], |
|     ) -> torch.FloatTensor: |
| residual = hidden_states |
| hidden_states = self.input_layernorm(hidden_states) |
|
|
| hidden_states, self_attn_weights = self.self_attn( |
| hidden_states=hidden_states, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| **kwargs, |
| ) |
| hidden_states = residual + self.resid_attn_dropout( |
| hidden_states |
| ) |
|
|
| residual = hidden_states |
| hidden_states = self.post_attention_layernorm(hidden_states) |
| hidden_states = self.mlp(hidden_states) |
| hidden_states = residual + self.resid_mlp_dropout( |
| hidden_states |
| ) |
| return hidden_states |
|
|
|
|
| @auto_docstring |
| class Phi3PreTrainedModel(PreTrainedModel): |
| config: Phi3Config |
| base_model_prefix = "model" |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Phi3DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
|
|
| _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = {} |
| _version = "0.0.5" |
|
|
|
|
| @auto_docstring |
| class Phi3Model(Phi3PreTrainedModel): |
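|     """Stack of ``Phi3DecoderLayer`` modules over a vocabulary-parallel embedding. |
|
|     ``forward`` accepts an extra ``layers_to_output_hidden_states`` keyword listing the layer |
|     indices whose outputs should be collected in ``hidden_states``; ``None`` keeps every layer. |
|     For example (assuming the tensor-parallel group is already initialized): |
|
|         out = model(input_ids, layers_to_output_hidden_states=[0, 5]) |
|         # out.hidden_states holds the outputs of layers 0 and 5 only |
|     """ |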
| def __init__(self, config: Phi3Config): |
| super().__init__(config) |
| self.padding_idx = config.pad_token_id |
| self.vocab_size = config.vocab_size |
|
|
| self.embed_tokens = VocabParallelEmbedding( |
| config.vocab_size, config.hidden_size, self.padding_idx |
| ) |
| self.layers = nn.ModuleList( |
| [ |
| Phi3DecoderLayer(config, layer_idx) |
| for layer_idx in range(config.num_hidden_layers) |
| ] |
| ) |
| self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.rotary_emb = Phi3RotaryEmbedding(config=config) |
| self.gradient_checkpointing = False |
|
|
|         # Initialize weights and apply final processing |
| self.post_init() |
|
|
| @check_model_inputs |
| @auto_docstring |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| use_cache: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> BaseModelOutputWithPast: |
| if (input_ids is None) ^ (inputs_embeds is not None): |
| raise ValueError( |
| "You must specify exactly one of input_ids or inputs_embeds" |
| ) |
|
|
|         # Optional filter over which decoder layers' hidden states are collected below; |
|         # ``None`` keeps the output of every layer. |
|         layers_to_output_hidden_states: Optional[List[int]] = kwargs.pop( |
|             "layers_to_output_hidden_states", None |
|         ) |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.embed_tokens(input_ids) |
|
|
| if use_cache and past_key_values is None: |
| past_key_values = DynamicCache() |
|
|
| if cache_position is None: |
| past_seen_tokens = ( |
| past_key_values.get_seq_length() if past_key_values is not None else 0 |
| ) |
| cache_position = torch.arange( |
| past_seen_tokens, |
| past_seen_tokens + inputs_embeds.shape[1], |
| device=inputs_embeds.device, |
| ) |
|
|
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
|
|
| mask_function = ( |
| create_causal_mask |
| if self.config.sliding_window is None |
| else create_sliding_window_causal_mask |
| ) |
| causal_mask = mask_function( |
| config=self.config, |
| input_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| past_key_values=past_key_values, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
|
| all_hidden_states = () |
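|         # Run the decoder stack, collecting hidden states for the requested layers. |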
| for idx, decoder_layer in enumerate(self.layers): |
| hidden_states = decoder_layer( |
| hidden_states, |
| attention_mask=causal_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| **kwargs, |
| ) |
| if ( |
| layers_to_output_hidden_states is None |
| or idx in layers_to_output_hidden_states |
| ): |
| all_hidden_states += (hidden_states,) |
|
|
| hidden_states = self.norm(hidden_states) |
|
|
| return BaseModelOutputWithPast( |
| last_hidden_state=hidden_states, |
| past_key_values=past_key_values if use_cache else None, |
| hidden_states=all_hidden_states, |
| ) |
|
|
|
|
| @auto_docstring |
| class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): |
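|     """Phi-3 causal language model: :class:`Phi3Model` with a vocabulary-parallel LM head. |
|
|     Logits are gathered across the tensor-parallel group (``gather_output=True``) before the |
|     loss is computed, so the returned logits cover the full vocabulary. |
|     """ |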
| _tied_weights_keys = ["lm_head.weight"] |
| _tp_plan = {"lm_head": "colwise_rep"} |
| _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| self.model = Phi3Model(config) |
| self.vocab_size = config.vocab_size |
|
|
|         # Vocabulary-parallel LM head; forward gathers the full-vocab logits with gather_output=True. |
| self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) |
|
|
|         # Initialize weights and apply final processing |
| self.post_init() |
|
|
| @can_return_tuple |
| @auto_docstring |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> CausalLMOutputWithPast: |
| r""" |
| Example: |
| |
| ```python |
| >>> from transformers import AutoTokenizer, Phi3ForCausalLM |
| |
|         >>> model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct") |
|         >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct") |
| |
| >>> prompt = "Hey, are you conscious? Can you talk to me?" |
| >>> inputs = tokenizer(prompt, return_tensors="pt") |
| |
| >>> # Generate |
| >>> generate_ids = model.generate(inputs.input_ids, max_length=30) |
| >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
| "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." |
| ```""" |
| outputs: BaseModelOutputWithPast = self.model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| **kwargs, |
| ) |
|
|
| hidden_states = outputs.last_hidden_state |
|         # Only compute logits for the requested positions; an int keeps the last `logits_to_keep` tokens (0 keeps all). |
| slice_indices = ( |
| slice(-logits_to_keep, None) |
| if isinstance(logits_to_keep, int) |
| else logits_to_keep |
| ) |
| logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) |
|
|
| loss = None |
| if labels is not None: |
| loss = self.loss_function( |
| logits=logits, |
| labels=labels, |
| vocab_size=self.config.vocab_size, |
| **kwargs, |
| ) |
|
|
| return CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
| def prepare_inputs_for_generation( |
| self, |
| input_ids, |
| past_key_values=None, |
| attention_mask=None, |
| inputs_embeds=None, |
| cache_position=None, |
| position_ids=None, |
| use_cache=True, |
| logits_to_keep=None, |
| **kwargs, |
| ): |
|         # Overwritten -- Phi3 can switch from the short to the long RoPE scaling factors during |
|         # generation, which invalidates any cache built with the short factors. |
|
|
|         # Once the input length first exceeds `original_max_position_embeddings`, drop the cache so |
|         # it is rebuilt (with the long factors) during the prefill that follows. |
| if ( |
| past_key_values |
| and self.config.rope_scaling |
| and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 |
| ): |
| past_length = cache_position[0] |
| if past_length <= self.config.original_max_position_embeddings: |
| past_key_values = None |
|
|
| model_inputs = super().prepare_inputs_for_generation( |
| input_ids=input_ids, |
| past_key_values=past_key_values, |
| attention_mask=attention_mask, |
| inputs_embeds=inputs_embeds, |
| cache_position=cache_position, |
| position_ids=position_ids, |
| use_cache=use_cache, |
| logits_to_keep=logits_to_keep, |
| **kwargs, |
| ) |
| return model_inputs |
|
|