# MiniAxion1-0.9M / components.py
# Uploaded by AxionLab-official ("Create components.py", commit 3da2ee9, verified)
"""
Model components optimized for CPU training.
Design rationale:
- RMSNorm instead of LayerNorm: simpler, faster (no mean computation)
- Rotary Position Embeddings (RoPE): no learned position embeddings needed,
saves parameters and generalizes better
- LoRA-style low-rank linear layers: dramatically reduces parameter count
while maintaining expressiveness
- All operations use float32 for CPU stability (no mixed precision)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
class RMSNorm(nn.Module):
    """
    Root Mean Square normalization.

    Why: ~30% faster than LayerNorm on CPU since it skips the mean
    (re-centering) computation, and is empirically equivalent for
    transformers. Scales the input by the reciprocal RMS of its last
    dimension, then applies a learned per-channel gain.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # eps keeps the rsqrt finite for all-zero inputs.
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.eps)
        return normalized * self.weight
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE).

    Why:
    - No learned parameters (saves memory)
    - Relative position awareness without extra params
    - Extrapolates better to unseen sequence lengths
    - Computationally efficient on CPU (just sin/cos)

    Caches cos/sin tables up to max_seq_len and transparently regrows the
    cache when a longer sequence is requested.
    """
    def __init__(self, dim: int, max_seq_len: int = 512, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        # Per-pair inverse frequencies: 1 / base^(2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        # Pre-compute for max_seq_len to avoid recomputation each forward.
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build cos/sin caches covering positions [0, seq_len)."""
        # Fix: build on inv_freq's device (not the default CPU) so a cache
        # regrow triggered inside forward() after the module was moved with
        # .to(device)/.cuda() produces tables on the module's device.
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype,
                         device=self.inv_freq.device)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)  # [seq_len, dim/2]
        # Duplicate so the table spans the full head dim (rotate_half layout).
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables, each of shape [seq_len, dim]."""
        if seq_len > self.cos_cached.size(0):
            self._build_cache(seq_len)  # regrow for longer sequences
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dim, negating the (new) first half.

    [a, b] -> [-b, a], which implements the 2D rotation pairing used by RoPE.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor,
                         cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by their position-dependent cos/sin tables.

    cos/sin arrive as [seq_len, dim]; broadcast them across the batch and
    head dimensions, then mix each tensor with its rotated half.
    """
    cos = cos[None, None, :, :]  # -> [1, 1, seq_len, dim]
    sin = sin[None, None, :, :]
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
class LoRALinear(nn.Module):
    """
    Low-rank factorized linear layer.

    Why: instead of a full in_features x out_features matrix, uses two
    smaller ones: in_features x rank and rank x out_features. For rank=16,
    in=out=256: full = 65,536 params vs. factorized = 8,192 params (8x
    reduction!) while keeping enough expressiveness at this model scale.

    Falls back to a plain nn.Linear when the requested rank would not
    actually save parameters.
    """
    def __init__(self, in_features: int, out_features: int, rank: int = 16, bias: bool = False):
        super().__init__()
        self.rank = rank
        # If rank is large relative to the layer, factorization saves
        # nothing — just use a full linear layer.
        if rank >= min(in_features, out_features) // 2:
            self.use_lora = False
            self.linear = nn.Linear(in_features, out_features, bias=bias)
        else:
            self.use_lora = True
            self.down = nn.Linear(in_features, rank, bias=False)
            self.up = nn.Linear(rank, out_features, bias=bias)
            # Fix: the original zero-initialized `up`, which is correct for
            # LoRA *adapters* added to a frozen base weight, but here the
            # factorization IS the whole layer: a zero `up` makes every
            # projection output exactly 0 at init and starves `down` of
            # gradient until `up` drifts away from zero. Use the standard
            # nn.Linear-style kaiming init on both factors instead.
            nn.init.kaiming_uniform_(self.down.weight, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.up.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project x through the low-rank pair (or the full-linear fallback)."""
        if self.use_lora:
            return self.up(self.down(x))
        return self.linear(x)
class GatedMLP(nn.Module):
    """
    SwiGLU-style gated feed-forward block.

    Why: gated activations consistently outperform plain ReLU/GELU MLPs in
    transformers, especially at small scale — the silu-activated gate path
    multiplicatively filters the up path, acting as a learned feature
    selector. All three projections are LoRA-factorized to save parameters.
    """
    def __init__(self, d_model: int, d_ff: int, rank: int = 16, dropout: float = 0.05):
        super().__init__()
        self.gate_proj = LoRALinear(d_model, d_ff, rank=rank)
        self.up_proj = LoRALinear(d_model, d_ff, rank=rank)
        self.down_proj = LoRALinear(d_ff, d_model, rank=rank)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu(gate) acts as a soft mask over the up-projected features.
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention with RoPE and optional Grouped Query Attention.

    Why these choices:
    - Grouped Query Attention (GQA): shares KV heads, reducing memory and params
      while maintaining quality. For 8 heads with 4 KV groups: 50% KV param reduction.
    - Pre-computed causal mask: avoids recomputing each forward pass on CPU
    - RoPE applied per-head: correct relative position encoding
    """
    def __init__(self, d_model: int, n_heads: int, rank: int = 16,
                 dropout: float = 0.05, max_seq_len: int = 512,
                 n_kv_heads: Optional[int] = None):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        # n_kv_heads=None (or 0) falls back to standard MHA (one KV per head).
        self.n_kv_heads = n_kv_heads or n_heads
        self.head_dim = d_model // n_heads
        self.n_rep = n_heads // self.n_kv_heads # repetition factor for GQA
        assert d_model % n_heads == 0
        # Q is full-width; K/V are sized for n_kv_heads (the GQA saving).
        self.q_proj = LoRALinear(d_model, d_model, rank=rank)
        self.k_proj = LoRALinear(d_model, self.n_kv_heads * self.head_dim, rank=rank)
        self.v_proj = LoRALinear(d_model, self.n_kv_heads * self.head_dim, rank=rank)
        self.o_proj = LoRALinear(d_model, d_model, rank=rank)
        self.dropout = nn.Dropout(dropout)
        # RoPE rotates per head, so it is sized by head_dim, not d_model.
        self.rope = RotaryEmbedding(self.head_dim, max_seq_len)
        # Pre-compute causal mask: True strictly above the diagonal marks
        # future positions to hide.
        # NOTE(review): forward() slices this to [:T, :T]; a sequence longer
        # than max_seq_len would yield a too-small mask — confirm callers
        # never exceed max_seq_len.
        mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer('causal_mask', mask)

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads to match Q heads for GQA.

        x: [batch, n_kv_heads, seq_len, head_dim]
        -> [batch, n_heads, seq_len, head_dim]
        """
        if self.n_rep == 1:
            return x
        bs, n_kv, seq_len, head_dim = x.shape
        # expand() is a zero-copy view; reshape() then materializes it.
        x = x[:, :, None, :, :].expand(bs, n_kv, self.n_rep, seq_len, head_dim)
        return x.reshape(bs, self.n_heads, seq_len, head_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Causal self-attention over x.

        x: [batch, seq_len, d_model]
        mask: optional [B, T] padding mask; True marks key positions to hide
              (applied in addition to the causal mask).
        Returns: [batch, seq_len, d_model]
        """
        B, T, C = x.shape
        # Project and split into heads: [B, T, C] -> [B, heads, T, head_dim].
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Apply RoPE to Q and K only (V carries no positional information).
        cos, sin = self.rope(T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        # Expand KV heads to match Q heads for GQA
        k = self._repeat_kv(k)
        v = self._repeat_kv(v)
        # Scaled dot-product attention scores: [B, heads, T, T].
        scale = math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) / scale
        # Apply causal mask (True -> -inf before softmax)
        causal = self.causal_mask[:T, :T].unsqueeze(0).unsqueeze(0)
        attn = attn.masked_fill(causal, float('-inf'))
        if mask is not None:
            # mask shape: [B, T] -> [B, 1, 1, T] so it masks key positions
            attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        attn = F.softmax(attn, dim=-1)
        # Dropout on the attention weights (standard regularization spot).
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        # Merge heads back: [B, heads, T, head_dim] -> [B, T, C].
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)
class TransformerBlock(nn.Module):
    """
    One pre-norm transformer layer: attention then gated MLP, each wrapped
    in a residual connection.

    Why pre-norm: normalizing before each sublayer leaves the residual path
    unimpeded, which gives better gradient flow and more stable training,
    especially at small scale.
    """
    def __init__(self, d_model: int, n_heads: int, d_ff: int,
                 rank: int = 16, dropout: float = 0.05,
                 max_seq_len: int = 512, n_kv_heads: Optional[int] = None):
        super().__init__()
        self.attn_norm = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, rank, dropout, max_seq_len, n_kv_heads)
        self.ff_norm = RMSNorm(d_model)
        self.ff = GatedMLP(d_model, d_ff, rank, dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Residual around attention, then residual around the gated MLP.
        attended = x + self.attn(self.attn_norm(x), mask)
        return attended + self.ff(self.ff_norm(attended))