"""HuggingFace-compatible model definition for TinyGPT2.
This file is self-contained so it works when downloaded from the HuggingFace Hub
with `trust_remote_code=True`.
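
Example usage (a sketch, not an official snippet; the repo id below is a
placeholder inferred from the folder name, and a tokenizer is assumed to live
in the same repo):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        "NotShrirang/tinygpt2-it", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained("NotShrirang/tinygpt2-it")
    input_ids = tokenizer("Hello", return_tensors="pt").input_ids
    print(tokenizer.decode(model.generate(input_ids, max_new_tokens=20)[0]))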
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from configuration_tinygpt2 import TinyGPT2HFConfig
# ---------------------------------------------------------------------------
# Layers (self-contained copies so this file works standalone on HF Hub)
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: y = weight * x / sqrt(mean(x_i^2) + eps)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)
def precompute_freqs_cis(dim, seq_len, theta=10000.0):
    """Precompute the complex rotary-embedding table of shape (seq_len, dim // 2)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=torch.float)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)
def apply_rotary_emb(x, freqs_cis):
    # x: (B, T, H, D)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
    x_rotated = x_complex * freqs_cis
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)
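# A minimal shape sketch of the two RoPE helpers above (sizes are illustrative,
# not taken from the shipped config):
#
#     freqs = precompute_freqs_cis(dim=16, seq_len=32)  # complex64, shape (32, 8)
#     q = torch.randn(2, 32, 4, 16)                     # (B, T, H, D)
#     q_rot = apply_rotary_emb(q, freqs)                # (2, 32, 4, 16), dtype of q
#
# Each (even, odd) channel pair is treated as one complex number and rotated by a
# position-dependent angle, so relative position is encoded directly in q @ k^T.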
class GroupedQueryAttention(nn.Module):
    def __init__(self, n_embd, n_head, n_query_groups, dropout=0.1):
        super().__init__()
        assert n_head % n_query_groups == 0
        self.n_head = n_head
        self.n_query_groups = n_query_groups
        self.head_dim = n_embd // n_head
        self.q_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.k_proj = nn.Linear(n_embd, n_query_groups * self.head_dim, bias=False)
        self.v_proj = nn.Linear(n_embd, n_query_groups * self.head_dim, bias=False)
        self.out_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, freqs_cis, is_causal=True, kv_cache=None):
        B, T, C = x.shape
        H, G, D = self.n_head, self.n_query_groups, self.head_dim
        q = self.q_proj(x).view(B, T, H, D)
        k = self.k_proj(x).view(B, T, G, D)
        v = self.v_proj(x).view(B, T, G, D)
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)
        if kv_cache is not None:
            k_past, v_past = kv_cache
            k = torch.cat([k_past, k], dim=1)
            v = torch.cat([v_past, v], dim=1)
        new_kv_cache = (k, v)
        # Expand each KV head to cover its group of query heads.
        k = k[:, :, :, None, :].expand(B, -1, G, H // G, D).reshape(B, -1, H, D)
        v = v[:, :, :, None, :].expand(B, -1, G, H // G, D).reshape(B, -1, H, D)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        # With a KV cache the query attends to the full (past + current) sequence,
        # so the built-in causal mask must be disabled.
        use_causal = is_causal and kv_cache is None
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=use_causal,
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(attn_output), new_kv_cache
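# Worked example of the grouping (hypothetical sizes, not the shipped config):
# with n_embd=512, n_head=8, n_query_groups=2, head_dim = 512 // 8 = 64, so
# q_proj emits 8 query heads (512 dims) while k_proj/v_proj emit only 2 KV heads
# (128 dims each). In forward() every KV head is expanded to serve 8 // 2 = 4
# query heads, so attention still runs with 8 heads but the KV cache stores 2.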
class TinyGPT2Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = RMSNorm(config.n_embd)
        self.attn = GroupedQueryAttention(
            config.n_embd, config.n_head, config.gqa_kv_head, config.dropout
        )
        self.ln2 = RMSNorm(config.n_embd)
        self.ffwd = nn.Sequential(
            nn.Linear(config.n_embd, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.n_embd),
            nn.Dropout(config.dropout),
        )
    def forward(self, x, freqs_cis, is_causal=True, kv_cache=None):
        # Pre-norm residual layout: x + Attn(RMSNorm(x)), then x + FFN(RMSNorm(x)).
        residual = x
        x = self.ln1(x)
        attn_out, new_kv_cache = self.attn(x, freqs_cis, is_causal, kv_cache)
        x = residual + attn_out
        residual = x
        x = self.ln2(x)
        x = residual + self.ffwd(x)
        return x, new_kv_cache
# ---------------------------------------------------------------------------
# HuggingFace PreTrainedModel wrapper
# ---------------------------------------------------------------------------
class TinyGPT2ForCausalLM(PreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "token_embedding.weight"}
    config_class = TinyGPT2HFConfig

    def __init__(self, config: TinyGPT2HFConfig):
        super().__init__(config)
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.blocks = nn.ModuleList(
            [TinyGPT2Block(config) for _ in range(config.n_layer)]
        )
        self.ln_f = RMSNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying
        self.token_embedding.weight = self.lm_head.weight
        # Precompute RoPE frequencies
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.n_embd // config.n_head, config.block_size * 2
            ),
        )
        self.post_init()
    def get_input_embeddings(self):
        return self.token_embedding

    def set_input_embeddings(self, value):
        self.token_embedding = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        labels=None,
        use_cache=False,
        **kwargs,
    ):
        B, T = input_ids.shape
        x = self.token_embedding(input_ids)
        if past_key_values is not None and len(past_key_values) > 0:
            start_pos = past_key_values[0][0].shape[1]  # length of cached keys
            freqs_cis = self.freqs_cis[start_pos : start_pos + T]
        else:
            freqs_cis = self.freqs_cis[:T]
        new_kv_caches = []
        for i, block in enumerate(self.blocks):
            kv_cache = past_key_values[i] if past_key_values else None
            x, new_cache = block(x, freqs_cis, is_causal=True, kv_cache=kv_cache)
            new_kv_caches.append(new_cache)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=self.config.pad_token_id,
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_kv_caches if use_cache else None,
        )
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # With a KV cache only the newest token needs to be fed through the model.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": True,
        }
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Reorder cached keys/values along the batch dimension for beam search.
        return tuple(
            (k.index_select(0, beam_idx), v.index_select(0, beam_idx))
            for k, v in past_key_values
        )
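# ---------------------------------------------------------------------------
# Optional smoke test (a sketch): builds a tiny randomly initialised model and
# exercises the forward pass, the loss computation and the KV cache. The config
# values below are made up and assume TinyGPT2HFConfig accepts these field
# names as keyword arguments; they are not the shipped hyperparameters.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = TinyGPT2HFConfig(
        vocab_size=128,
        n_embd=64,
        n_head=4,
        gqa_kv_head=2,
        n_layer=2,
        hidden_size=256,
        dropout=0.0,
        block_size=64,
        pad_token_id=0,
    )
    model = TinyGPT2ForCausalLM(config).eval()

    ids = torch.randint(1, config.vocab_size, (2, 16))
    out = model(input_ids=ids, labels=ids)
    print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))

    # Greedy decoding with the KV cache: after the prompt, only the newest token
    # is fed back in, together with the cached keys/values from earlier steps.
    with torch.no_grad():
        prompt = ids[:, :4]
        out = model(input_ids=prompt, use_cache=True)
        past = out.past_key_values
        next_ids = out.logits[:, -1:].argmax(dim=-1)
        pieces = [prompt, next_ids]
        for _ in range(7):
            out = model(input_ids=next_ids, past_key_values=past, use_cache=True)
            past = out.past_key_values
            next_ids = out.logits[:, -1:].argmax(dim=-1)
            pieces.append(next_ids)
        print("generated:", tuple(torch.cat(pieces, dim=1).shape))  # (2, 12)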