Lekr0's picture
Add files using upload-large-folder tool
62dca4c verified
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional, Union
import torch
import torch.distributed as dist
from torch import nn
from transformers import Phi3Config
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.masking_utils import (
create_causal_mask,
create_sliding_window_causal_mask,
)
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.models.phi3.modeling_phi3 import (
Phi3RMSNorm,
Phi3RotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from specforge.distributed import get_tp_group
from specforge.layers import (
ColumnParallelLinear,
ParallelLMHead,
RowParallelLinear,
VocabParallelEmbedding,
)
class Phi3MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
# Add TP support
self.tp_group = get_tp_group()
self.gate_up_proj = ColumnParallelLinear(
config.hidden_size,
2 * config.intermediate_size,
bias=False,
layout_type="gate_up",
)
self.down_proj = RowParallelLinear(
config.intermediate_size, config.hidden_size, bias=False
)
self.activation_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
up_states = self.gate_up_proj(hidden_states)
gate, up_states = up_states.chunk(2, dim=-1)
up_states = up_states * self.activation_fn(gate)
down_proj = self.down_proj(up_states)
# Add all_reduce for TP
dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group)
return down_proj
class Phi3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.num_key_value_heads = config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
# Add TP support
self.tp_group = get_tp_group()
tp_size = dist.get_world_size(self.tp_group)
# Adjust head counts for TP
self.num_attention_heads_per_rank = config.num_attention_heads // tp_size
self.num_key_value_heads_per_rank = config.num_key_value_heads // tp_size
# ColumnParallel splits the full QKV output across ranks
op_size = config.num_attention_heads * self.head_dim + 2 * (
config.num_key_value_heads * self.head_dim
)
self.o_proj = RowParallelLinear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
)
self.qkv_proj = ColumnParallelLinear(
config.hidden_size, op_size, bias=False, layout_type="merged_qkv"
)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
qkv = self.qkv_proj(hidden_states)
query_pos = self.num_attention_heads_per_rank * self.head_dim
query_states = qkv[..., :query_pos]
key_states = qkv[
...,
query_pos : query_pos + self.num_key_value_heads_per_rank * self.head_dim,
]
value_states = qkv[
..., query_pos + self.num_key_value_heads_per_rank * self.head_dim :
]
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=getattr(self.config, "sliding_window", None),
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
# Add all_reduce for TP
dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group)
return attn_output, attn_weights
class Phi3DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Phi3Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx)
self.mlp = Phi3MLP(config)
self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Phi3RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.config = config
self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
tuple[torch.Tensor, torch.Tensor]
] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[
torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + self.resid_attn_dropout(
hidden_states
) # main diff with Llama
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.resid_mlp_dropout(
hidden_states
) # main diff with Llama
return hidden_states
@auto_docstring
class Phi3PreTrainedModel(PreTrainedModel):
config: Phi3Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Phi3DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {}
_version = "0.0.5"
@auto_docstring
class Phi3Model(Phi3PreTrainedModel):
def __init__(self, config: Phi3Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
Phi3DecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Phi3RotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
@check_model_inputs
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
layers_to_output_hidden_states: Optional[List[int]] = kwargs.pop(
"layers_to_output_hidden_states", None
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
mask_function = (
create_causal_mask
if self.config.sliding_window is None
else create_sliding_window_causal_mask
)
causal_mask = mask_function(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
all_hidden_states = ()
for idx, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
if (
layers_to_output_hidden_states is None
or idx in layers_to_output_hidden_states
):
all_hidden_states += (hidden_states,)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
)
@auto_docstring
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
self.model = Phi3Model(config)
self.vocab_size = config.vocab_size
# Use ColumnParallelLinear for lm_head
self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
Example:
```python
>>> from transformers import AutoTokenizer, Phi3ForCausalLM
>>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True)
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
logits_to_keep=None,
**kwargs,
):
# Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
# process
# When the first time input length reached long and short factor switching point, enforce re-compute cache
# It will cause downside of slower at this single token position, however, better than current failure.
if (
past_key_values
and self.config.rope_scaling
and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
):
past_length = cache_position[0]
if past_length <= self.config.original_max_position_embeddings:
past_key_values = None
model_inputs = super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
)
return model_inputs