Spaces:
Runtime error
Runtime error
| """ | |
| Tiny Transformer with modern components: | |
| - RoPE (Rotary Position Embeddings) | |
| - RMSNorm | |
| - SwiGLU activation | |
| - Weight tying | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    Each vector along the last dimension is divided by its RMS and then scaled
    by a learned per-channel weight that starts at ones.

    Args:
        dim: Size of the last input dimension (length of ``weight``).
        eps: Constant added to the mean square before the reciprocal square
            root, guarding against division by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute the statistic in at least float32: squaring fp16/bf16
        # activations overflows easily (fp16 max is ~65504), which would turn
        # the mean square into inf, the scale into 0 and the output into zeros.
        # float32/float64 inputs keep their own precision (promote_types).
        x_acc = x.to(torch.promote_types(x.dtype, torch.float32))
        inv_rms = torch.rsqrt(x_acc.pow(2).mean(-1, keepdim=True) + self.eps)
        # Return to the input dtype before applying the learned scale, so the
        # result for float32 inputs is unchanged.
        return (x_acc * inv_rms).type_as(x) * self.weight
class RotaryEmbedding(nn.Module):
    """Build the cosine/sine tables used by rotary position embeddings.

    Args:
        dim: Per-head feature size the frequency table is built for.
        max_seq_len: Stored on the module; ``forward`` does not read it.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, max_seq_len: int = 512, base: int = 10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))
        self.max_seq_len = max_seq_len

    def forward(self, x, seq_len: int):
        """Return ``(cos, sin)``, each shaped ``(seq_len, 2 * len(inv_freq))``.

        ``x`` is only consulted for its device; the tables use the dtype of
        the ``inv_freq`` buffer.
        """
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        # Repeat the angle table so both halves of the feature axis share it.
        angles = torch.cat([angles, angles], dim=-1)
        return angles.cos(), angles.sin()
def rotate_half(x):
    """Map the two halves ``(a, b)`` of the last dimension to ``(-b, a)``."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query and key tensors by position-dependent angles.

    ``cos`` and ``sin`` receive two leading singleton axes so they broadcast
    against ``q``/``k``. In this file they are the ``(seq_len, head_dim)``
    tables from ``RotaryEmbedding`` and ``q``/``k`` are
    ``(batch, heads, seq_len, head_dim)`` tensors from ``Attention``.

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def _rotate(t):
        return t * cos_b + rotate_half(t) * sin_b

    return _rotate(q), _rotate(k)
class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Width of the gated hidden layer.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Registration order (w1, w2, w3) fixes parameter order, state_dict
        # layout and the RNG draws consumed by each Linear's initializer.
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Args:
        hidden_size: Model width; split evenly across ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        # Validate up front. Otherwise num_heads == 0 raises ZeroDivisionError,
        # and a non-divisible hidden_size only fails later in forward() with an
        # opaque view() shape error.
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.rotary = RotaryEmbedding(self.head_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(B, T, C)`` with ``C == hidden_size``.
            mask: Optional tensor broadcastable to ``(B, heads, T, T)``.
                Positions where it equals 0 are excluded from attention. A
                row that is entirely 0 would softmax over all ``-inf`` and
                produce NaN.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape
        # Project, then split the feature axis into heads: (B, heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Rotary embeddings are applied to queries and keys only, not values.
        cos, sin = self.rotary(x, T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Scaled dot-product attention.
        scale = 1.0 / math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)

        # Merge heads back into the feature axis before the output projection.
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)
class TransformerBlock(nn.Module):
    """Pre-norm residual block: an attention sub-layer followed by SwiGLU.

    Each sub-layer reads an RMS-normalized copy of the stream, and its output
    is added back onto the un-normalized residual stream.
    """

    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int, dropout: float = 0.0):
        super().__init__()
        # Keep this registration order: it fixes parameter order, state_dict
        # layout and the RNG draws consumed as the submodules initialise.
        self.norm1 = RMSNorm(hidden_size)
        self.attn = Attention(hidden_size, num_heads, dropout)
        self.norm2 = RMSNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size, intermediate_size)

    def forward(self, x, mask=None):
        attn_out = self.attn(self.norm1(x), mask)
        hidden = x + attn_out
        ffn_out = self.ffn(self.norm2(hidden))
        return hidden + ffn_out
class TinyLLM(nn.Module):
    """Decoder-only causal language model built from ``TransformerBlock``s.

    Args:
        vocab_size: Number of token ids (embedding rows / logit columns).
        hidden_size: Model width.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block.
        intermediate_size: Hidden width of each SwiGLU feed-forward layer.
        max_position_embeddings: Longest sequence ``forward`` accepts. It
            sets the size of the precomputed causal mask.
        dropout: Dropout probability on attention weights.
        tie_weights: If True, ``lm_head`` shares its weight tensor with
            ``embed_tokens``.
    """

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 512,
        num_layers: int = 12,
        num_heads: int = 8,
        intermediate_size: int = 1408,
        max_position_embeddings: int = 512,
        dropout: float = 0.0,
        tie_weights: bool = True,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads, intermediate_size, dropout)
            for _ in range(num_layers)
        ])
        self.norm = RMSNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        if tie_weights:
            self.lm_head.weight = self.embed_tokens.weight
        # Lower-triangular matrix of ones/zeros. A 0 at (i, j) means query i
        # may not attend to key j (a future position). It is a buffer, so it
        # follows the module's device and is part of the state_dict.
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(max_position_embeddings, max_position_embeddings))
        )
        self._init_weights()

    def _init_weights(self):
        """Draw every Linear and Embedding weight from N(0, 0.02**2).

        RMSNorm weights are not touched and stay at ones. With tied weights
        the shared tensor is drawn twice: first as ``embed_tokens``, then as
        ``lm_head``, which comes later in ``modules()``. The second draw is
        the one kept.
        """
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None):
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: Integer token ids of shape ``(B, T)``.
            labels: Optional ids of shape ``(B, T)``. Logits at position
                ``t`` are scored against ``labels[:, t + 1]`` (the shift is
                done here). Entries equal to -100 are ignored.

        Returns:
            Dict with ``"loss"`` (scalar tensor, or None when ``labels`` is
            None) and ``"logits"`` of shape ``(B, T, vocab_size)``.

        Raises:
            ValueError: If ``T`` exceeds ``max_position_embeddings``.
        """
        # Unpacking also enforces the 2-D (B, T) shape, as before.
        _, T = input_ids.shape
        max_len = self.causal_mask.size(0)
        if T > max_len:
            # Slicing the mask would silently return a smaller block and fail
            # later with an opaque broadcasting error inside attention.
            raise ValueError(
                f"sequence length {T} exceeds max_position_embeddings ({max_len})"
            )
        x = self.embed_tokens(input_ids)
        mask = self.causal_mask[:T, :T]
        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Predict token t+1 from position t: drop the last logit and the
            # first label.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )
        return {"loss": loss, "logits": logits}

    def count_parameters(self):
        """Return the total number of parameter elements.

        ``parameters()`` yields a shared tensor only once, so a tied
        embedding/``lm_head`` weight is counted a single time.
        """
        return sum(p.numel() for p in self.parameters())
if __name__ == "__main__":
    # Smoke test: build the default model and run one forward pass with a loss.
    model = TinyLLM()
    print(f"Parameters: {model.count_parameters() / 1e6:.2f}M")
    # Sample ids from the model's own vocabulary rather than a hard-coded
    # 32000, so the demo cannot index past the embedding table.
    x = torch.randint(0, model.vocab_size, (2, 128))
    # No backward pass is taken here, so skip building the autograd graph.
    with torch.no_grad():
        out = model(x, labels=x)
    print(f"Loss: {out['loss'].item():.4f}")
    print(f"Logits shape: {out['logits'].shape}")