# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Optional, Union

import torch
import torch.distributed as dist
from torch import nn
from transformers import Phi3Config
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.masking_utils import (
    create_causal_mask,
    create_sliding_window_causal_mask,
)
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.models.phi3.modeling_phi3 import (
    Phi3RMSNorm,
    Phi3RotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs

from specforge.distributed import get_tp_group
from specforge.layers import (
    ColumnParallelLinear,
    ParallelLMHead,
    RowParallelLinear,
    VocabParallelEmbedding,
)


class Phi3MLP(nn.Module):
    """Gated feed-forward block whose projections are tensor-parallel layers.

    A single fused projection produces the gate and up halves, the activated
    gate multiplies the up half, and the down projection output is summed
    across the tensor-parallel (TP) process group.
    """

    def __init__(self, config):
        """Build the fused gate/up projection, the down projection and the activation.

        Args:
            config: Model configuration; reads ``hidden_size``,
                ``intermediate_size`` and ``hidden_act``.
        """
        super().__init__()
        self.config = config

        # Add TP support
        self.tp_group = get_tp_group()
        self.gate_up_proj = ColumnParallelLinear(
            config.hidden_size,
            2 * config.intermediate_size,
            bias=False,
            layout_type="gate_up",
        )
        self.down_proj = RowParallelLinear(
            config.intermediate_size, config.hidden_size, bias=False
        )
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the gated MLP and sum the result over the TP group.

        Args:
            hidden_states: Input activations fed to the fused gate/up projection.

        Returns:
            The down-projected tensor after an in-place sum all-reduce across
            ``self.tp_group``.
        """
        up_states = self.gate_up_proj(hidden_states)

        # The fused projection output is split in half along the last dim:
        # first half is the gate, second half is the "up" branch.
        gate, up_states = up_states.chunk(2, dim=-1)
        up_states = up_states * self.activation_fn(gate)

        down_proj = self.down_proj(up_states)
        # Add all_reduce for TP
        # NOTE(review): presumably each rank holds a partial sum coming out of
        # RowParallelLinear; confirm against specforge.layers.
        dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group)
        return down_proj


class Phi3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Query, key and value are produced by one fused ``qkv_proj``; each rank
    slices its local projection output using per-rank head counts, and the
    output projection is summed across the tensor-parallel (TP) group.
    """

    def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
        """Set up head bookkeeping and the tensor-parallel projections.

        Args:
            config: Model configuration.
            layer_idx: Index of this layer, forwarded to the KV cache on update.

        Raises:
            ValueError: If ``num_attention_heads`` or ``num_key_value_heads``
                is not divisible by the TP world size.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.num_key_value_heads = config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # Add TP support
        self.tp_group = get_tp_group()
        tp_size = dist.get_world_size(self.tp_group)

        # `forward` slices the local QKV output with floor-divided head counts.
        # A remainder would leave the key and value slices covering different
        # numbers of heads, so reject such configurations up front.
        if (
            config.num_attention_heads % tp_size != 0
            or config.num_key_value_heads % tp_size != 0
        ):
            raise ValueError(
                f"num_attention_heads ({config.num_attention_heads}) and "
                f"num_key_value_heads ({config.num_key_value_heads}) must both be "
                f"divisible by the tensor-parallel size ({tp_size})."
            )

        # Adjust head counts for TP
        self.num_attention_heads_per_rank = config.num_attention_heads // tp_size
        self.num_key_value_heads_per_rank = config.num_key_value_heads // tp_size

        # ColumnParallel splits the full QKV output across ranks
        op_size = config.num_attention_heads * self.head_dim + 2 * (
            config.num_key_value_heads * self.head_dim
        )
        self.o_proj = RowParallelLinear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
        )
        self.qkv_proj = ColumnParallelLinear(
            config.hidden_size, op_size, bias=False, layout_type="merged_qkv"
        )

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention on this rank's heads and reduce over the TP group.

        Args:
            hidden_states: Input activations
                (presumably ``(batch, seq_len, hidden_size)`` -- TODO confirm).
            position_embeddings: ``(cos, sin)`` pair applied to query and key
                via ``apply_rotary_pos_emb``.
            attention_mask: Mask handed unchanged to the attention implementation.
            past_key_values: Optional cache; when given it is updated with the
                new key/value states for ``self.layer_idx``.
            cache_position: Forwarded to the cache update.
            **kwargs: Extra arguments forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` has been
            passed through ``o_proj`` and sum-all-reduced across the TP group.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        qkv = self.qkv_proj(hidden_states)

        # Slice the local fused output as [query | key | value] using the
        # per-rank head counts.
        query_pos = self.num_attention_heads_per_rank * self.head_dim
        key_end = query_pos + self.num_key_value_heads_per_rank * self.head_dim
        query_states = qkv[..., :query_pos]
        key_states = qkv[..., query_pos:key_end]
        value_states = qkv[..., key_end:]

        # Split the last dim into (heads, head_dim) and swap dims 1 and 2.
        query_states = query_states.view(hidden_shape).transpose(1, 2)
        key_states = key_states.view(hidden_shape).transpose(1, 2)
        value_states = value_states.view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Use the eager kernel unless the config selects another implementation.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            # Attention dropout is only active in training mode.
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        # Add all_reduce for TP
        dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group)
        return attn_output, attn_weights


class Phi3DecoderLayer(GradientCheckpointingLayer):
    """One decoder layer: normalised self-attention and MLP, each with a residual.

    Both residual branches pass through dropout (``config.resid_pdrop``)
    before being added back.
    """

    def __init__(self, config: Phi3Config, layer_idx: int):
        """Create the attention, MLP, normalisation and dropout sub-modules.

        Args:
            config: Model configuration.
            layer_idx: Index of this layer, handed to the attention module.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Phi3MLP(config)
        self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Phi3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.config = config
        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with residual connections.

        Args:
            hidden_states: Input activations for this layer.
            attention_mask: Mask forwarded to the attention module.
            position_ids: Forwarded to the attention module.
            past_key_values: Optional KV cache forwarded to the attention module.
            use_cache: Forwarded to the attention module.
            cache_position: Forwarded to the attention module.
            position_embeddings: ``(cos, sin)`` pair forwarded to the attention
                module.
            **kwargs: Extra arguments forwarded to the attention module.

        Returns:
            The layer's output hidden states (attention weights are discarded).
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # `position_ids` and `use_cache` are not named parameters of
        # Phi3Attention.forward; they travel on through its **kwargs.
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + self.resid_attn_dropout(
            hidden_states
        )  # main diff with Llama

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.resid_mlp_dropout(
            hidden_states
        )  # main diff with Llama
        return hidden_states


# Shared base class: declares the config type and the capability flags read by
# the Transformers `PreTrainedModel` machinery.
@auto_docstring
class Phi3PreTrainedModel(PreTrainedModel):
    config: Phi3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Phi3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Left empty here: no module outputs are registered for recording.
    _can_record_outputs = {}
    _version = "0.0.5"


# Decoder stack: vocab-parallel embedding -> decoder layers -> final RMSNorm.
@auto_docstring
class Phi3Model(Phi3PreTrainedModel):
    def __init__(self, config: Phi3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                Phi3DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Phi3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Raises when both or neither of input_ids / inputs_embeds are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        # Optional list of decoder-layer indices whose outputs are returned in
        # `hidden_states`. It is popped so it is not forwarded to the layers.
        layers_to_output_hidden_states: Optional[List[int]] = kwargs.pop(
            "layers_to_output_hidden_states", None
        )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            # Default positions continue right after the tokens already cached.
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Pick the mask builder according to whether a sliding window is configured.
        mask_function = (
            create_causal_mask
            if self.config.sliding_window is None
            else create_sliding_window_causal_mask
        )
        causal_mask = mask_function(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # Rotary (cos, sin) are computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = ()
        for idx, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            # With no explicit selection every layer output is kept; otherwise
            # only the requested indices. Outputs are taken before `self.norm`.
            if (
                layers_to_output_hidden_states is None
                or idx in layers_to_output_hidden_states
            ):
                all_hidden_states += (hidden_states,)

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
        )


# Causal-LM wrapper: Phi3Model backbone followed by a tensor-parallel LM head.
@auto_docstring
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Phi3Model(config)
        self.vocab_size = config.vocab_size
        # Use the tensor-parallel LM head (ParallelLMHead) for the output projection
        self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Phi3ForCausalLM

        >>> model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # An int keeps the last `logits_to_keep` positions (0 keeps them all,
        # since slice(-0, None) == slice(0, None)); a tensor is used as an index.
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        # NOTE(review): `gather_output=True` presumably assembles the full-vocab
        # logits on every rank; confirm against specforge.layers.ParallelLMHead.
        logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        """Prepare generation inputs, dropping the cache at the rope switch point.

        When rope scaling is configured and ``input_ids`` is longer than
        ``original_max_position_embeddings`` while the cached length is still
        at or below it, the cache is discarded so it gets recomputed. All
        arguments are then handed to the parent implementation.

        Returns:
            The model inputs produced by the parent
            ``prepare_inputs_for_generation``.
        """
        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
        # process

        # When the first time input length reached long and short factor switching point, enforce re-compute cache
        # It will cause downside of slower at this single token position, however, better than current failure.
        if (
            past_key_values
            and self.config.rope_scaling
            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
        ):
            # `cache_position` defaults to None, so fall back to the cache's own
            # length instead of failing on `None[0]`.
            if cache_position is not None:
                past_length = cache_position[0]
            else:
                past_length = past_key_values.get_seq_length()
            if past_length <= self.config.original_max_position_embeddings:
                past_key_values = None

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return model_inputs