"""HuggingFace-compatible model definition for TinyGPT2.

This file is meant to be loaded from the HuggingFace Hub with
`trust_remote_code=True`. Apart from the companion `configuration_tinygpt2`
module it imports, it depends only on `torch` and `transformers`.
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| from configuration_tinygpt2 import TinyGPT2HFConfig |
|
|
|
|
| |
| |
| |
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        dim: Size of the last (feature) dimension of the inputs.
        eps: Constant added to the mean square before taking the square root,
            which keeps the division finite for all-zero rows.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at one, so a fresh layer only divides by the RMS.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Divide ``x`` by the RMS of its last dimension, then scale by ``weight``."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denom = (mean_square + self.eps).sqrt()
        normed = x / denom
        return self.weight * normed
|
|
|
|
def precompute_freqs_cis(dim, seq_len, theta=10000.0):
    """Build the table of unit-magnitude complex rotary phases.

    Args:
        dim: Per-head feature size; one frequency is produced for every second
            index in ``range(0, dim, 2)``.
        seq_len: Number of positions (rows) to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor whose entry ``[t, i]`` has magnitude 1 and angle
        ``t / theta ** (2 * i / dim)``.
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len, dtype=torch.float)
    # Outer product gives the rotation angle for every (position, frequency) pair.
    angles = torch.outer(positions, inv_freq)
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)
|
|
|
|
def apply_rotary_emb(x, freqs_cis):
    """Rotate consecutive feature pairs of ``x`` by per-position complex phases.

    Args:
        x: Real tensor whose dim 1 is the sequence axis and whose last dim is
            even; callers in this file pass (batch, seq, heads, head_dim).
        freqs_cis: Complex phase table with one row per position. Only the
            first ``x.shape[1]`` rows are used.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    seq_len = x.shape[1]
    # Pair up the last dimension so each (even, odd) couple becomes one complex number.
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(pairs)
    # Reshape the phases so they broadcast over dims 0 and 2 of the input.
    phases = freqs_cis[:seq_len].view(1, seq_len, 1, -1)
    rotated = torch.view_as_real(as_complex * phases)
    return rotated.flatten(-2).type_as(x)
|
|
|
|
class GroupedQueryAttention(nn.Module):
    """Self-attention in which several query heads share one key/value head.

    Queries use ``n_head`` heads; keys and values are projected to only
    ``n_query_groups`` heads and then repeated ``n_head // n_query_groups``
    times so every query head has a matching key/value head.

    Args:
        n_embd: Model width; must be divisible by ``n_head``.
        n_head: Number of query heads; must be divisible by ``n_query_groups``.
        n_query_groups: Number of key/value heads.
        dropout: Probability used to build ``self.dropout``.
    """

    def __init__(self, n_embd, n_head, n_query_groups, dropout=0.1):
        super().__init__()
        # Fail at construction time instead of inside the first ``view`` call.
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        assert n_head % n_query_groups == 0
        self.n_head = n_head
        self.n_query_groups = n_query_groups
        self.head_dim = n_embd // n_head

        self.q_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.k_proj = nn.Linear(n_embd, n_query_groups * self.head_dim, bias=False)
        self.v_proj = nn.Linear(n_embd, n_query_groups * self.head_dim, bias=False)
        self.out_proj = nn.Linear(n_embd, n_embd, bias=False)
        # NOTE(review): this module is constructed but never applied in
        # ``forward``; it is kept so the module tree stays unchanged.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cis, is_causal=True, kv_cache=None):
        """Attend over ``x`` (and any cached keys/values).

        Args:
            x: Input of shape (batch, seq, n_embd).
            freqs_cis: Rotary phases for the positions of the *new* tokens in
                ``x``; the caller is responsible for offsetting them when a
                cache is supplied.
            is_causal: When true, each new token may only attend to itself and
                earlier positions.
            kv_cache: Optional ``(k, v)`` pair from a previous call, each of
                shape (batch, past_seq, n_query_groups, head_dim).

        Returns:
            ``(output, new_kv_cache)`` where ``output`` has the shape of ``x``
            and ``new_kv_cache`` is the ``(k, v)`` pair including the new
            tokens, stored before the group expansion.
        """
        B, T, C = x.shape
        H, G, D = self.n_head, self.n_query_groups, self.head_dim

        q = self.q_proj(x).view(B, T, H, D)
        k = self.k_proj(x).view(B, T, G, D)
        v = self.v_proj(x).view(B, T, G, D)

        # Rotate only the new keys; cached keys were rotated when first computed.
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        if kv_cache is not None:
            k_past, v_past = kv_cache
            k = torch.cat([k_past, k], dim=1)
            v = torch.cat([v_past, v], dim=1)

        new_kv_cache = (k, v)
        S = k.shape[1]  # total key length: cached + new

        # Repeat each key/value head so it lines up with its group of query heads.
        k = k[:, :, :, None, :].expand(B, -1, G, H // G, D).reshape(B, -1, H, D)
        v = v[:, :, :, None, :].expand(B, -1, G, H // G, D).reshape(B, -1, H, D)

        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        attn_mask = None
        use_causal = False
        if is_causal:
            if S == T:
                # Square case: the built-in causal mask is exactly what we need.
                use_causal = True
            elif T > 1:
                # Several new tokens on top of a cache. The built-in causal
                # mask is top-left aligned for non-square shapes, so build a
                # bottom-right-aligned one: new token i (absolute position
                # S - T + i) may see keys 0 .. S - T + i.
                attn_mask = torch.ones(
                    T, S, dtype=torch.bool, device=x.device
                ).tril(diagonal=S - T)
            # T == 1 with a cache: the single new token may see every key,
            # so no mask is required.

        attn_output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=use_causal
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(attn_output), new_kv_cache
|
|
|
|
class TinyGPT2Block(nn.Module):
    """One transformer layer: normalised attention and feed-forward, each added back residually.

    Args:
        config: Object providing ``n_embd``, ``n_head``, ``gqa_kv_head``,
            ``hidden_size`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        drop_p = config.dropout

        self.ln1 = RMSNorm(width)
        self.attn = GroupedQueryAttention(
            width, config.n_head, config.gqa_kv_head, drop_p
        )
        self.ln2 = RMSNorm(width)
        # Two-layer GELU MLP followed by dropout.
        mlp_layers = [
            nn.Linear(width, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, width),
            nn.Dropout(drop_p),
        ]
        self.ffwd = nn.Sequential(*mlp_layers)

    def forward(self, x, freqs_cis, is_causal=True, kv_cache=None):
        """Run the layer and return ``(hidden_states, new_kv_cache)``.

        ``freqs_cis``, ``is_causal`` and ``kv_cache`` are forwarded unchanged
        to the attention module, whose cache output is returned as-is.
        """
        attn_out, updated_cache = self.attn(
            self.ln1(x), freqs_cis, is_causal, kv_cache
        )
        hidden = x + attn_out
        hidden = hidden + self.ffwd(self.ln2(hidden))
        return hidden, updated_cache
|
|
|
|
| |
| |
| |
|
|
class TinyGPT2ForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal language model assembled from ``TinyGPT2Block`` layers.

    The input embedding and the output projection share one weight matrix.
    Rotary phases are precomputed for ``2 * config.block_size`` positions and
    stored in the ``freqs_cis`` buffer.
    """

    _tied_weights_keys = {"lm_head.weight": "token_embedding.weight"}
    config_class = TinyGPT2HFConfig

    def __init__(self, config: TinyGPT2HFConfig):
        super().__init__(config)
        self.config = config

        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.blocks = nn.ModuleList(
            [TinyGPT2Block(config) for _ in range(config.n_layer)]
        )
        self.ln_f = RMSNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the embedding reuses the output projection's matrix.
        self.token_embedding.weight = self.lm_head.weight

        # Rotary phase table, one row per position.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.n_embd // config.n_head, config.block_size * 2
            ),
        )

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.token_embedding = value

    def get_output_embeddings(self):
        """Return the output projection (``lm_head``)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection (``lm_head``)."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        labels=None,
        use_cache=False,
        **kwargs,
    ):
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids of shape (batch, seq). Required.
            attention_mask: Accepted for API compatibility; not used here.
            past_key_values: Optional per-layer ``(k, v)`` pairs from an
                earlier call. Their dim-1 length is taken as the position
                offset of ``input_ids``.
            labels: Optional target ids aligned with ``input_ids``; they are
                shifted by one inside this method. Positions equal to ``-100``
                or to ``config.pad_token_id`` (when set) are excluded from
                the loss.
            use_cache: When true, the per-layer caches are returned.
            **kwargs: Ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss``, ``logits`` and
            ``past_key_values`` (``None`` unless ``use_cache``).

        Raises:
            ValueError: If ``input_ids`` is missing, or the requested
                positions run past the precomputed rotary table.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        T = input_ids.shape[1]

        x = self.token_embedding(input_ids)

        has_past = past_key_values is not None and len(past_key_values) > 0
        # Cached keys are stored as (batch, seq, groups, head_dim).
        start_pos = past_key_values[0][0].shape[1] if has_past else 0

        max_positions = self.freqs_cis.shape[0]
        if start_pos + T > max_positions:
            raise ValueError(
                f"Sequence end position {start_pos + T} exceeds the "
                f"{max_positions} positions of the rotary table"
            )
        freqs_cis = self.freqs_cis[start_pos : start_pos + T]

        new_kv_caches = []
        for i, block in enumerate(self.blocks):
            kv_cache = past_key_values[i] if past_key_values else None
            x, new_cache = block(x, freqs_cis, is_causal=True, kv_cache=kv_cache)
            new_kv_caches.append(new_cache)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Map padding onto -100 so both pad ids and the conventional -100
            # marker are skipped; this also tolerates pad_token_id being None.
            pad_id = getattr(self.config, "pad_token_id", None)
            if pad_id is not None:
                shift_labels = shift_labels.masked_fill(shift_labels == pad_id, -100)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_kv_caches if use_cache else None,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """Build the ``forward`` kwargs for one generation step.

        When a non-empty cache exists, only the last token is fed to the model.
        """
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": True,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select cache rows along dim 0 according to ``beam_idx`` for every layer."""
        return tuple(
            (k.index_select(0, beam_idx), v.index_select(0, beam_idx))
            for k, v in past_key_values
        )
|
|