""" Model components optimized for CPU training. Design rationale: - RMSNorm instead of LayerNorm: simpler, faster (no mean computation) - Rotary Position Embeddings (RoPE): no learned position embeddings needed, saves parameters and generalizes better - LoRA-style low-rank linear layers: dramatically reduces parameter count while maintaining expressiveness - All operations use float32 for CPU stability (no mixed precision) """ import torch import torch.nn as nn import torch.nn.functional as F import math from typing import Optional, Tuple class RMSNorm(nn.Module): """ Root Mean Square normalization. Why: ~30% faster than LayerNorm on CPU since it skips mean computation. Empirically equivalent performance for transformers. """ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return x * norm * self.weight class RotaryEmbedding(nn.Module): """ Rotary Position Embedding (RoPE). Why: - No learned parameters (saves memory) - Relative position awareness without extra params - Extrapolates better to unseen sequence lengths - Computationally efficient on CPU (just sin/cos) """ def __init__(self, dim: int, max_seq_len: int = 512, base: float = 10000.0): super().__init__() self.dim = dim inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) # Pre-compute for max_seq_len to avoid recomputation self._build_cache(max_seq_len) def _build_cache(self, seq_len: int): t = torch.arange(seq_len, dtype=self.inv_freq.dtype) freqs = torch.einsum('i,j->ij', t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer('cos_cached', emb.cos()) self.register_buffer('sin_cached', emb.sin()) def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: if seq_len > self.cos_cached.size(0): self._build_cache(seq_len) return self.cos_cached[:seq_len], self.sin_cached[:seq_len] def rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotate half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Apply rotary embeddings to queries and keys.""" # cos, sin: [seq_len, dim] cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, dim] sin = sin.unsqueeze(0).unsqueeze(0) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class LoRALinear(nn.Module): """ Low-Rank Adaptation linear layer. Why: Instead of full d_in x d_out matrix, uses two smaller matrices: d_in x rank + rank x d_out. For rank=16, d_in=d_out=256: Full: 65,536 params LoRA: 256*16 + 16*256 = 8,192 params (8x reduction!) Still maintains good expressiveness for the tasks we need. """ def __init__(self, in_features: int, out_features: int, rank: int = 16, bias: bool = False): super().__init__() self.rank = rank # If rank is large enough, just use full linear if rank >= min(in_features, out_features) // 2: self.use_lora = False self.linear = nn.Linear(in_features, out_features, bias=bias) else: self.use_lora = True self.down = nn.Linear(in_features, rank, bias=False) self.up = nn.Linear(rank, out_features, bias=bias) # Initialize to approximate identity-like behavior nn.init.kaiming_uniform_(self.down.weight, a=math.sqrt(5)) nn.init.zeros_(self.up.weight) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_lora: return self.up(self.down(x)) return self.linear(x) class GatedMLP(nn.Module): """ SwiGLU-style gated MLP. Why: Gated activation functions consistently outperform standard ReLU/GELU in transformers, especially at small scale. The gate provides a learned "feature selection" mechanism. Uses LoRA projections to save parameters. """ def __init__(self, d_model: int, d_ff: int, rank: int = 16, dropout: float = 0.05): super().__init__() self.gate_proj = LoRALinear(d_model, d_ff, rank=rank) self.up_proj = LoRALinear(d_model, d_ff, rank=rank) self.down_proj = LoRALinear(d_ff, d_model, rank=rank) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: gate = F.silu(self.gate_proj(x)) up = self.up_proj(x) return self.dropout(self.down_proj(gate * up)) class MultiHeadAttention(nn.Module): """ Multi-Head Attention with RoPE and optional Grouped Query Attention. Why these choices: - Grouped Query Attention (GQA): shares KV heads, reducing memory and params while maintaining quality. For 8 heads with 4 KV groups: 50% KV param reduction. - Pre-computed causal mask: avoids recomputing each forward pass on CPU - RoPE applied per-head: correct relative position encoding """ def __init__(self, d_model: int, n_heads: int, rank: int = 16, dropout: float = 0.05, max_seq_len: int = 512, n_kv_heads: Optional[int] = None): super().__init__() self.d_model = d_model self.n_heads = n_heads self.n_kv_heads = n_kv_heads or n_heads self.head_dim = d_model // n_heads self.n_rep = n_heads // self.n_kv_heads # repetition factor for GQA assert d_model % n_heads == 0 self.q_proj = LoRALinear(d_model, d_model, rank=rank) self.k_proj = LoRALinear(d_model, self.n_kv_heads * self.head_dim, rank=rank) self.v_proj = LoRALinear(d_model, self.n_kv_heads * self.head_dim, rank=rank) self.o_proj = LoRALinear(d_model, d_model, rank=rank) self.dropout = nn.Dropout(dropout) self.rope = RotaryEmbedding(self.head_dim, max_seq_len) # Pre-compute causal mask mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool() self.register_buffer('causal_mask', mask) def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor: """Repeat KV heads to match Q heads for GQA.""" if self.n_rep == 1: return x bs, n_kv, seq_len, head_dim = x.shape x = x[:, :, None, :, :].expand(bs, n_kv, self.n_rep, seq_len, head_dim) return x.reshape(bs, self.n_heads, seq_len, head_dim) def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: B, T, C = x.shape q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # Apply RoPE cos, sin = self.rope(T) q, k = apply_rotary_pos_emb(q, k, cos, sin) # Expand KV for GQA k = self._repeat_kv(k) v = self._repeat_kv(v) # Attention scale = math.sqrt(self.head_dim) attn = torch.matmul(q, k.transpose(-2, -1)) / scale # Apply causal mask causal = self.causal_mask[:T, :T].unsqueeze(0).unsqueeze(0) attn = attn.masked_fill(causal, float('-inf')) if mask is not None: # mask shape: [B, T] -> [B, 1, 1, T] attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf')) attn = F.softmax(attn, dim=-1) attn = self.dropout(attn) out = torch.matmul(attn, v) out = out.transpose(1, 2).contiguous().view(B, T, C) return self.o_proj(out) class TransformerBlock(nn.Module): """ Single transformer block with pre-norm architecture. Why pre-norm: More stable training, especially at small scale. Gradient flow is better since residual path is unimpeded. """ def __init__(self, d_model: int, n_heads: int, d_ff: int, rank: int = 16, dropout: float = 0.05, max_seq_len: int = 512, n_kv_heads: Optional[int] = None): super().__init__() self.attn_norm = RMSNorm(d_model) self.attn = MultiHeadAttention(d_model, n_heads, rank, dropout, max_seq_len, n_kv_heads) self.ff_norm = RMSNorm(d_model) self.ff = GatedMLP(d_model, d_ff, rank, dropout) def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: x = x + self.attn(self.attn_norm(x), mask) x = x + self.ff(self.ff_norm(x)) return x