"""Toy 1M-param transformer with Gemma 4 alternating SWA + optional engram memory.
Design notes
------------
* RMSNorm pre-norm, SwiGLU MLP, tied embedding/output (standard Llama-ish base).
* Causal mask is precomputed; sliding-window layers use the same code path with
an additional window-restricted mask (purely a mask difference -- no kernel split).
* RoPE is applied to Q/K only (standard, no Gemma 4 dual-RoPE).
* Engram is an optional 512-slot static memory bank attended-to from one layer's
  output; injected via a zero-initialised multiplicative gate so it is a no-op
  at training start. Bit-identical to no-engram when `cfg.engram_enabled=False`.
"""
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import Config
# ---------------------------------------------------------------------------
# RMSNorm
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Compute in float32 for stability; cast back to input dtype.
dtype = x.dtype
xf = x.float()
rms = xf.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
return (xf * rms).to(dtype) * self.weight
# ---------------------------------------------------------------------------
# RoPE
# ---------------------------------------------------------------------------
def _build_rope_cache(seq_len: int, head_dim: int, base: float, device, dtype) -> tuple[torch.Tensor, torch.Tensor]:
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
half = head_dim // 2
inv_freq = 1.0 / (base ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
t = torch.arange(seq_len, device=device, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, inv_freq) # (T, half)
cos = freqs.cos().to(dtype)
sin = freqs.sin().to(dtype)
return cos, sin # each (T, half)
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
# x: (B, n_h, T, head_dim). cos/sin: (T, head_dim/2).
x1, x2 = x.chunk(2, dim=-1)
cos_b = cos[None, None, :, :]
sin_b = sin[None, None, :, :]
rotated_x1 = x1 * cos_b - x2 * sin_b
rotated_x2 = x1 * sin_b + x2 * cos_b
return torch.cat([rotated_x1, rotated_x2], dim=-1)
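# Note on the rotation convention: this is NeoX-style "rotate half" -- position p
# rotates the pair (x[..., i], x[..., i + half]) by angle p / base**(2*i/head_dim),
# so a single (T, half) cos/sin cache serves both halves of the head dimension.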
# ---------------------------------------------------------------------------
# Attention
# ---------------------------------------------------------------------------
class Attention(nn.Module):
"""MHA with RoPE and configurable causal-or-sliding mask.
`kind == 'global'`: full causal attention.
`kind == 'slide'` : causal attention restricted to the last `window` tokens.
    Both code paths use F.scaled_dot_product_attention; the only difference is
    the mask. When kind == 'global' we pass `is_causal=True` and skip building
    an explicit mask. When kind == 'slide' we build a banded additive mask with
    -inf outside the window; with `window >= T` it reduces to the plain causal
    mask, so the two paths are equivalent.
"""
def __init__(self, cfg: Config, kind: str):
super().__init__()
assert kind in ("global", "slide")
self.cfg = cfg
self.kind = kind
self.n_heads = cfg.n_heads
self.head_dim = cfg.head_dim
        self.scale = self.head_dim**-0.5  # informational only; SDPA applies its own 1/sqrt(d)
self.W_q = nn.Linear(cfg.dim, cfg.dim, bias=False)
self.W_k = nn.Linear(cfg.dim, cfg.dim, bias=False)
self.W_v = nn.Linear(cfg.dim, cfg.dim, bias=False)
self.W_o = nn.Linear(cfg.dim, cfg.dim, bias=False)
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
B, T, D = x.shape
q = self.W_q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B, H, T, Dh)
k = self.W_k(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
v = self.W_v(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
q = _apply_rope(q, cos, sin)
k = _apply_rope(k, cos, sin)
if self.kind == "global":
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
# Banded causal mask: token t may attend to tokens in [max(0, t-window+1), t].
mask = _sliding_causal_mask(T, self.cfg.sliding_window, x.device, x.dtype)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
out = out.transpose(1, 2).contiguous().view(B, T, D)
return self.W_o(out)
def _sliding_causal_mask(T: int, window: int, device, dtype) -> torch.Tensor:
"""(T, T) additive mask: 0 inside window+causal, -inf outside.
Token i attends to j iff j <= i and (i - j) < window.
"""
i = torch.arange(T, device=device).unsqueeze(1) # (T,1)
j = torch.arange(T, device=device).unsqueeze(0) # (1,T)
causal = j <= i
in_window = (i - j) < window
keep = causal & in_window
mask = torch.zeros((T, T), device=device, dtype=dtype)
mask = mask.masked_fill(~keep, float("-inf"))
# SDPA expects (..., T, T) broadcast over batch/heads.
return mask
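# Illustration (window semantics): with T=4 and window=2 the additive mask is
# (0 = keep, "-" = -inf), rows = query index i, columns = key index j:
#   i=0: [0, -, -, -]
#   i=1: [0, 0, -, -]
#   i=2: [-, 0, 0, -]
#   i=3: [-, -, 0, 0]
# i.e. each token attends to itself plus the previous (window - 1) tokens.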
# ---------------------------------------------------------------------------
# MLP (SwiGLU)
# ---------------------------------------------------------------------------
class SwiGLU(nn.Module):
def __init__(self, dim: int, hidden: int):
super().__init__()
self.w_gate = nn.Linear(dim, hidden, bias=False)
self.w_up = nn.Linear(dim, hidden, bias=False)
self.w_down = nn.Linear(hidden, dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
# ---------------------------------------------------------------------------
# Block
# ---------------------------------------------------------------------------
class Block(nn.Module):
def __init__(self, cfg: Config, layer_idx: int):
super().__init__()
kind = cfg.attention_kind(layer_idx)
self.norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
self.attn = Attention(cfg, kind=kind)
self.norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
self.mlp = SwiGLU(cfg.dim, cfg.mlp_hidden)
self.kind = kind
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.norm1(x), cos, sin)
x = x + self.mlp(self.norm2(x))
return x
# ---------------------------------------------------------------------------
# Engram external memory
# ---------------------------------------------------------------------------
class Engram(nn.Module):
"""Static memory bank with single-head attention readout + zero-init gate.
    Bit-identical to no-engram at init: the multiplicative gate starts at zero,
    so the injected term is exactly 0. Becomes non-trivial only after the gate
    is trained away from zero.
"""
def __init__(self, cfg: Config):
super().__init__()
self.cfg = cfg
        # The incoming hidden states are RMSNorm-ed at read time; slot rows are used raw.
self.slots = nn.Parameter(torch.randn(cfg.engram_slots, cfg.dim) * cfg.init_std)
self.q_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
self.k_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
self.v_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
self.o_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
self.norm = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        # A sigmoid gate initialised at 0 would give sigmoid(0) = 0.5, not a no-op.
        # Use a raw per-channel multiplicative gate instead, initialised to 0.
        self.gate = nn.Parameter(torch.zeros(cfg.dim))
def forward(self, h: torch.Tensor) -> torch.Tensor:
# h: (B, T, D). Read from memory.
h_n = self.norm(h)
q = self.q_proj(h_n) # (B, T, D)
k = self.k_proj(self.slots) # (S, D)
v = self.v_proj(self.slots) # (S, D)
scale = q.shape[-1] ** -0.5
attn = torch.einsum("btd,sd->bts", q, k) * scale
w = attn.softmax(dim=-1)
retrieved = torch.einsum("bts,sd->btd", w, v)
retrieved = self.o_proj(retrieved)
# Multiplicative zero-init gate -> exact no-op at init.
return h + self.gate * retrieved
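# Sanity sketch (assumes a freshly constructed, untrained module): with the gate
# at its zero init the readout contributes nothing and the residual is exact:
#     eng = Engram(cfg)
#     h = torch.randn(2, 8, cfg.dim)
#     assert torch.equal(eng(h), h)  # holds until the gate moves off zero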
# ---------------------------------------------------------------------------
# ToyLM
# ---------------------------------------------------------------------------
class ToyLM(nn.Module):
def __init__(self, cfg: Config):
super().__init__()
self.cfg = cfg
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.dim)
self.blocks = nn.ModuleList([Block(cfg, i) for i in range(cfg.n_layers)])
self.norm_f = RMSNorm(cfg.dim, eps=cfg.norm_eps)
if cfg.engram_enabled:
self.engram = Engram(cfg)
else:
self.engram = None
if not cfg.tie_embeddings:
self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
else:
self.lm_head = None
        # RoPE cache precomputed once for cfg.max_seq_len; forward asserts T <= max_seq_len.
cos, sin = _build_rope_cache(cfg.max_seq_len, cfg.head_dim, cfg.rope_base, device="cpu", dtype=torch.float32)
self.register_buffer("rope_cos", cos, persistent=False)
self.register_buffer("rope_sin", sin, persistent=False)
self._init_weights()
def _init_weights(self) -> None:
std = self.cfg.init_std
for p_name, p in self.named_parameters():
if p.dim() >= 2:
nn.init.normal_(p, mean=0.0, std=std)
elif p_name.endswith(".weight") and "norm" in p_name.lower():
nn.init.ones_(p)
elif p_name == "engram.gate":
nn.init.zeros_(p)
else:
nn.init.zeros_(p)
def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
B, T = idx.shape
assert T <= self.cfg.max_seq_len, f"seq_len {T} > max {self.cfg.max_seq_len}"
x = self.tok_emb(idx)
cos = self.rope_cos[:T].to(device=x.device, dtype=x.dtype)
sin = self.rope_sin[:T].to(device=x.device, dtype=x.dtype)
for i, blk in enumerate(self.blocks):
x = blk(x, cos, sin)
if self.engram is not None and i == self.cfg.engram_inject_layer:
x = self.engram(x)
x = self.norm_f(x)
if self.cfg.tie_embeddings:
logits = F.linear(x, self.tok_emb.weight)
else:
logits = self.lm_head(x)
# Gemma-2 logit soft-cap (bf16 stability + bounded softmax input).
if self.cfg.lm_head_logit_cap is not None:
cap = self.cfg.lm_head_logit_cap
logits = cap * torch.tanh(logits / cap)
loss = None
if targets is not None:
loss = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
targets.reshape(-1),
ignore_index=-100,
)
            # PaLM-style z-loss: penalises the squared log-partition (logsumexp)
            # so the softmax denominator doesn't drift; at a small weight (~1e-4)
            # the extra cost is negligible. Restricted to non-ignored positions
            # so it composes with masked SFT.
            if self.cfg.z_loss_weight > 0:
                lse = torch.logsumexp(logits.float(), dim=-1)  # (B, T)
                valid = targets.reshape(*lse.shape) != -100
                z = (lse[valid] ** 2).mean() if valid.any() else lse.new_zeros(())
                loss = loss + self.cfg.z_loss_weight * z
return logits, loss
@torch.no_grad()
def generate(
self,
idx: torch.Tensor,
max_new_tokens: int = 80,
*,
temperature: float = 0.8,
top_p: float = 0.9,
rep_penalty: float = 1.3,
stop_token_ids: Optional[set[int]] = None,
) -> torch.Tensor:
"""Sampling-based decode with top-p + repetition penalty.
Defaults are tuned for sub-10M LMs: greedy alone collapses into
token-level repetition loops at this scale (entropy stays high but
        argmax follows a self-amplifying trajectory). temperature=0.8 + top_p=0.9
        + rep_penalty=1.3 reliably breaks the loop without going incoherent.
        Validated 2026-04-29 on the 12k-step toy 1M checkpoint.

        Pass `temperature=0.0` for greedy argmax decoding; note the repetition
        penalty still applies unless `rep_penalty=1.0`.
"""
self.eval()
for _ in range(max_new_tokens):
logits, _ = self(idx)
logits = logits[:, -1].float() # (B, V)
if rep_penalty != 1.0:
# Per-batch element rep penalty over already-emitted tokens.
for b in range(idx.size(0)):
seen = torch.unique(idx[b])
pos = logits[b, seen] > 0
logits[b, seen] = torch.where(pos,
logits[b, seen] / rep_penalty,
logits[b, seen] * rep_penalty)
if temperature <= 0.0:
nxt = logits.argmax(dim=-1, keepdim=True)
else:
logits = logits / temperature
if top_p < 1.0:
sorted_logits, sorted_idx = logits.sort(descending=True)
cum = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
mask = cum > top_p
mask[..., 1:] = mask[..., :-1].clone()
mask[..., 0] = False
logits = logits.scatter(1, sorted_idx,
sorted_logits.masked_fill(mask, float('-inf')))
probs = F.softmax(logits, dim=-1)
nxt = torch.multinomial(probs, num_samples=1)
idx = torch.cat([idx, nxt], dim=1)
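            # Early-stop check looks only at batch element 0; intended for B == 1
            # decoding (with a batch, the other rows would keep generating anyway).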
if stop_token_ids is not None and nxt[0, 0].item() in stop_token_ids:
break
if idx.size(1) >= self.cfg.max_seq_len:
break
return idx
def num_params_breakdown(self) -> dict[str, int]:
emb = sum(p.numel() for p in self.tok_emb.parameters())
attn = 0
mlp = 0
norms = 0
        for blk in self.blocks:
            attn += sum(p.numel() for p in blk.attn.parameters())
            mlp += sum(p.numel() for p in blk.mlp.parameters())
            norms += sum(p.numel() for p in blk.norm1.parameters())
            norms += sum(p.numel() for p in blk.norm2.parameters())
        # Final norm is counted once, outside the per-block loop.
        norms += sum(p.numel() for p in self.norm_f.parameters())
engram = sum(p.numel() for p in self.engram.parameters()) if self.engram is not None else 0
head = sum(p.numel() for p in self.lm_head.parameters()) if self.lm_head is not None else 0
total = sum(p.numel() for p in self.parameters())
return {
"embedding": emb,
"attention": attn,
"mlp": mlp,
"norms": norms,
"engram": engram,
"lm_head_extra": head, # 0 when tied
"total": total,
}
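# ---------------------------------------------------------------------------
# Smoke test (sketch)
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the training/eval code. It assumes a `Config`
# instance is built elsewhere (constructor and defaults live in config.py and are
# not shown here) and that cfg.max_seq_len >= 16; it exercises the forward pass,
# the loss path, and both sampled and greedy decoding.
def _smoke_test(cfg: Config) -> None:
    torch.manual_seed(0)
    model = ToyLM(cfg)
    idx = torch.randint(0, cfg.vocab_size, (2, 16))
    logits, loss = model(idx, targets=idx.clone())
    assert logits.shape == (2, 16, cfg.vocab_size)
    assert loss is not None and torch.isfinite(loss)
    prompt = idx[:1, :4]
    sampled = model.generate(prompt, max_new_tokens=8)  # default sampling recipe
    greedy = model.generate(prompt, max_new_tokens=8, temperature=0.0, rep_penalty=1.0)
    assert sampled.size(1) <= cfg.max_seq_len and greedy.size(1) <= cfg.max_seq_len
    print(model.num_params_breakdown())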