| """ |
| Model components optimized for CPU training. |
| |
| Design rationale: |
| - RMSNorm instead of LayerNorm: simpler, faster (no mean computation) |
| - Rotary Position Embeddings (RoPE): no learned position embeddings needed, |
| saves parameters and generalizes better |
| - LoRA-style low-rank linear layers: dramatically reduces parameter count |
| while maintaining expressiveness |
| - All operations use float32 for CPU stability (no mixed precision) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from typing import Optional, Tuple |
|
|
|
|
class RMSNorm(nn.Module):
    """
    Root Mean Square normalization over the last dimension.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``.

    Why: cheaper than LayerNorm because it skips mean-centering and the bias
    term. It is commonly reported to perform on par with LayerNorm in
    transformers.

    Args:
        dim: size of the last dimension; ``weight`` has this shape and starts at ones.
        eps: added to the mean square before the reciprocal square root.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = (mean_square + self.eps).rsqrt()
        return x * inv_rms * self.weight
|
|
|
|
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE).

    Why:
    - No learned parameters (saves memory)
    - Relative position awareness without extra params
    - Extrapolates better to unseen sequence lengths
    - Computationally efficient on CPU (just sin/cos)

    Precomputes cos/sin tables for positions ``0 .. max_seq_len - 1`` and
    grows them on demand in ``forward``.

    Args:
        dim: rotary dimension (the per-head size). Expected to be even: the
            tables have ``2 * ceil(dim / 2)`` columns, which only matches
            ``dim`` (and ``rotate_half``) when dim is even.
        max_seq_len: number of positions cached up front.
        base: frequency base; pair ``i`` uses inverse frequency
            ``base ** (-2i / dim)``.
    """
    def __init__(self, dim: int, max_seq_len: int = 512, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build ``cos_cached``/``sin_cached`` of shape (seq_len, 2 * len(inv_freq))."""
        # Create positions on inv_freq's device/dtype: a rebuild triggered
        # from forward() after .to(device) must not mix CPU and device tensors.
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Registering under an existing buffer name replaces that buffer.
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, growing the cache if needed."""
        if seq_len > self.cos_cached.size(0):
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Swap the two halves of the last dimension, negating the moved-forward half.

    For a last dimension laid out as ``[a, b]`` this returns ``[-b, a]``,
    where ``a`` holds the first ``size // 2`` elements.
    """
    split_at = x.shape[-1] // 2
    lower, upper = x[..., :split_at], x[..., split_at:]
    return torch.cat((upper.neg(), lower), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor,
                         cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate queries and keys by their positions (RoPE).

    ``cos`` and ``sin`` get two leading singleton axes so they broadcast over
    the first two axes of ``q``/``k``. Given how MultiHeadAttention calls
    this, those axes are presumably (batch, heads); confirm for other callers.

    Returns:
        ``(q_rotated, k_rotated)``, each ``t * cos + rotate_half(t) * sin``.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
class LoRALinear(nn.Module):
    """
    Low-rank factorized linear layer: ``y = up(down(x))``.

    Why: Instead of full d_in x d_out matrix, uses two smaller matrices:
    d_in x rank + rank x d_out. For rank=16, d_in=d_out=256:
        Full: 65,536 params
        Low-rank: 256*16 + 16*256 = 8,192 params (8x reduction!)

    Still maintains good expressiveness for the tasks we need.

    Unlike a LoRA *adapter*, this layer is the entire projection: there is no
    frozen base weight for it to be added to. Both factors are therefore
    initialized randomly. Zero-initializing ``up`` (the adapter convention)
    makes the output identically zero. When such layers are chained, e.g. a
    gated MLP or an attention block built from them, every gradient in the
    block can be exactly zero, so it never trains.

    When ``rank >= min(in_features, out_features) // 2`` a plain
    ``nn.Linear`` is used instead and ``use_lora`` is False. Presumably this
    is because the factorization saves little or nothing at that rank.

    Args:
        in_features: input feature size.
        out_features: output feature size.
        rank: inner dimension of the factorization.
        bias: add a learnable bias (placed on ``up`` in the low-rank path).
    """
    def __init__(self, in_features: int, out_features: int, rank: int = 16, bias: bool = False):
        super().__init__()
        self.rank = rank

        if rank >= min(in_features, out_features) // 2:
            self.use_lora = False
            self.linear = nn.Linear(in_features, out_features, bias=bias)
        else:
            self.use_lora = True
            self.down = nn.Linear(in_features, rank, bias=False)
            self.up = nn.Linear(rank, out_features, bias=bias)

            # nn.Linear already uses this init by default; it is kept explicit
            # so the scheme is visible. Do NOT zero `up`: this layer has no
            # base weight, so a zero factor kills both output and gradients.
            nn.init.kaiming_uniform_(self.down.weight, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.up.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``x`` from in_features to out_features."""
        if self.use_lora:
            return self.up(self.down(x))
        return self.linear(x)
|
|
|
|
class GatedMLP(nn.Module):
    """
    SwiGLU-style gated feed-forward block: ``dropout(down(silu(gate(x)) * up(x)))``.

    Why: gated activations are widely reported to beat plain ReLU/GELU MLPs
    in transformers. The SiLU gate acts as a learned, per-feature selector
    on the ``up`` branch.

    All three projections are LoRALinear layers to save parameters.

    Args:
        d_model: input/output width.
        d_ff: hidden width of the gate and up branches.
        rank: rank passed to each LoRALinear.
        dropout: dropout probability applied to the block output.
    """
    def __init__(self, d_model: int, d_ff: int, rank: int = 16, dropout: float = 0.05):
        super().__init__()
        # Keep construction order (gate, up, down): it fixes parameter
        # registration order and the RNG draws used for initialization.
        self.gate_proj = LoRALinear(d_model, d_ff, rank=rank)
        self.up_proj = LoRALinear(d_model, d_ff, rank=rank)
        self.down_proj = LoRALinear(d_ff, d_model, rank=rank)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated_gate = F.silu(self.gate_proj(x))
        hidden = activated_gate * self.up_proj(x)
        projected = self.down_proj(hidden)
        return self.dropout(projected)
|
|
|
|
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention with RoPE and optional Grouped Query Attention.

    Why these choices:
    - Grouped Query Attention (GQA): each K/V head is shared by
      n_heads // n_kv_heads query heads, reducing memory and params while
      maintaining quality. For 8 heads with 4 KV heads: 50% KV param reduction.
    - Pre-computed causal mask: avoids recomputing it each forward pass on CPU.
      A fresh mask is built only when the input is longer than max_seq_len.
    - RoPE applied per-head: correct relative position encoding

    Args:
        d_model: model width; must be divisible by n_heads.
        n_heads: number of query heads.
        rank: rank passed to each LoRALinear projection.
        dropout: dropout probability on the attention weights.
        max_seq_len: size of the pre-computed causal mask and initial RoPE cache.
        n_kv_heads: number of key/value heads (defaults to n_heads); must divide n_heads.
    """
    def __init__(self, d_model: int, n_heads: int, rank: int = 16,
                 dropout: float = 0.05, max_seq_len: int = 512,
                 n_kv_heads: Optional[int] = None):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads or n_heads

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        # Otherwise _repeat_kv cannot reshape the K/V heads up to n_heads.
        assert n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"

        self.head_dim = d_model // n_heads
        self.n_rep = n_heads // self.n_kv_heads

        self.q_proj = LoRALinear(d_model, d_model, rank=rank)
        self.k_proj = LoRALinear(d_model, self.n_kv_heads * self.head_dim, rank=rank)
        self.v_proj = LoRALinear(d_model, self.n_kv_heads * self.head_dim, rank=rank)
        self.o_proj = LoRALinear(d_model, d_model, rank=rank)

        self.dropout = nn.Dropout(dropout)
        self.rope = RotaryEmbedding(self.head_dim, max_seq_len)

        # True above the diagonal = future positions a query may not attend to.
        mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer('causal_mask', mask)

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads to match Q heads for GQA: (B, n_kv, T, D) -> (B, n_heads, T, D)."""
        if self.n_rep == 1:
            return x
        bs, n_kv, seq_len, head_dim = x.shape
        x = x[:, :, None, :, :].expand(bs, n_kv, self.n_rep, seq_len, head_dim)
        return x.reshape(bs, self.n_heads, seq_len, head_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Causal self-attention over ``x`` of shape (B, T, d_model).

        Args:
            x: input of shape (B, T, d_model).
            mask: optional boolean mask where True marks key positions to
                ignore. It is broadcast as ``mask[:, None, None, :]``, so it is
                presumably a (B, T) key-padding mask; confirm with the caller.

        Returns:
            Tensor of shape (B, T, d_model). Query rows whose every key is
            masked receive zero attention output instead of NaN.
        """
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rope(T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        k = self._repeat_kv(k)
        v = self._repeat_kv(v)

        scale = math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) / scale

        if T <= self.causal_mask.size(0):
            causal = self.causal_mask[:T, :T]
        else:
            # Longer than the pre-computed mask (RoPE grows its cache too).
            causal = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        attn = attn.masked_fill(causal.unsqueeze(0).unsqueeze(0), float('-inf'))

        if mask is not None:
            attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))

        # A query with every key masked (e.g. a left-padded position 0) has an
        # all -inf row. Softmax turns it into NaN, and the NaN would then reach
        # every position of the next layer through 0 * NaN in attn @ v.
        fully_masked = torch.isneginf(attn).all(dim=-1, keepdim=True)

        attn = F.softmax(attn, dim=-1)
        attn = attn.masked_fill(fully_masked, 0.0)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)
|
|
|
|
class TransformerBlock(nn.Module):
    """
    One pre-norm transformer layer: attention sub-layer, then gated MLP sub-layer.

    Each sub-layer normalizes its input with RMSNorm and adds its output back
    onto the un-normalized residual stream.

    Why pre-norm: it is generally more stable to train, especially at small
    scale, because the residual path is never normalized and gradients flow
    through it directly.

    Args:
        d_model: model width.
        n_heads: number of attention query heads.
        d_ff: hidden width of the gated MLP.
        rank: rank for every LoRALinear projection.
        dropout: dropout probability used by attention and the MLP.
        max_seq_len: pre-computed causal-mask / RoPE cache length.
        n_kv_heads: number of key/value heads for GQA (None = n_heads).
    """
    def __init__(self, d_model: int, n_heads: int, d_ff: int,
                 rank: int = 16, dropout: float = 0.05,
                 max_seq_len: int = 512, n_kv_heads: Optional[int] = None):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # registration and init RNG consumption are unchanged.
        self.attn_norm = RMSNorm(d_model)
        self.attn = MultiHeadAttention(
            d_model,
            n_heads,
            rank=rank,
            dropout=dropout,
            max_seq_len=max_seq_len,
            n_kv_heads=n_kv_heads,
        )
        self.ff_norm = RMSNorm(d_model)
        self.ff = GatedMLP(d_model, d_ff, rank=rank, dropout=dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        attn_out = self.attn(self.attn_norm(x), mask)
        x = x + attn_out
        ff_out = self.ff(self.ff_norm(x))
        return x + ff_out