"""Toy 1M-param transformer with Gemma 4 alternating SWA + optional engram memory. Design notes ------------ * RMSNorm pre-norm, SwiGLU MLP, tied embedding/output (standard Llama-ish base). * Causal mask is precomputed; sliding-window layers use the same code path with an additional window-restricted mask (purely a mask difference -- no kernel split). * RoPE is applied to Q/K only (standard, no Gemma 4 dual-RoPE). * Engram is an optional 512-slot static memory bank attended-to from one layer's output; injected via a sigmoid gate that is zero-initialised so it's a no-op at training start. Bit-identical to no-engram when `cfg.engram_enabled=False`. """ from __future__ import annotations import math from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from config import Config # --------------------------------------------------------------------------- # RMSNorm # --------------------------------------------------------------------------- class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: # Compute in float32 for stability; cast back to input dtype. dtype = x.dtype xf = x.float() rms = xf.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() return (xf * rms).to(dtype) * self.weight # --------------------------------------------------------------------------- # RoPE # --------------------------------------------------------------------------- def _build_rope_cache(seq_len: int, head_dim: int, base: float, device, dtype) -> tuple[torch.Tensor, torch.Tensor]: assert head_dim % 2 == 0, "head_dim must be even for RoPE" half = head_dim // 2 inv_freq = 1.0 / (base ** (torch.arange(0, half, device=device, dtype=torch.float32) / half)) t = torch.arange(seq_len, device=device, dtype=torch.float32) freqs = torch.einsum("i,j->ij", t, inv_freq) # (T, half) cos = freqs.cos().to(dtype) sin = freqs.sin().to(dtype) return cos, sin # each (T, half) def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: # x: (B, n_h, T, head_dim). cos/sin: (T, head_dim/2). x1, x2 = x.chunk(2, dim=-1) cos_b = cos[None, None, :, :] sin_b = sin[None, None, :, :] rotated_x1 = x1 * cos_b - x2 * sin_b rotated_x2 = x1 * sin_b + x2 * cos_b return torch.cat([rotated_x1, rotated_x2], dim=-1) # --------------------------------------------------------------------------- # Attention # --------------------------------------------------------------------------- class Attention(nn.Module): """MHA with RoPE and configurable causal-or-sliding mask. `kind == 'global'`: full causal attention. `kind == 'slide'` : causal attention restricted to the last `window` tokens. Both code paths use F.scaled_dot_product_attention for speed; the only difference is the additive mask. When kind=='global' we pass `is_causal=True` and skip building an explicit mask. When kind=='slide' we build a banded mask that is bit-identical to the global path with appropriate -inf entries outside the window. """ def __init__(self, cfg: Config, kind: str): super().__init__() assert kind in ("global", "slide") self.cfg = cfg self.kind = kind self.n_heads = cfg.n_heads self.head_dim = cfg.head_dim self.scale = self.head_dim**-0.5 self.W_q = nn.Linear(cfg.dim, cfg.dim, bias=False) self.W_k = nn.Linear(cfg.dim, cfg.dim, bias=False) self.W_v = nn.Linear(cfg.dim, cfg.dim, bias=False) self.W_o = nn.Linear(cfg.dim, cfg.dim, bias=False) def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: B, T, D = x.shape q = self.W_q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B, H, T, Dh) k = self.W_k(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) v = self.W_v(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) q = _apply_rope(q, cos, sin) k = _apply_rope(k, cos, sin) if self.kind == "global": out = F.scaled_dot_product_attention(q, k, v, is_causal=True) else: # Banded causal mask: token t may attend to tokens in [max(0, t-window+1), t]. mask = _sliding_causal_mask(T, self.cfg.sliding_window, x.device, x.dtype) out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False) out = out.transpose(1, 2).contiguous().view(B, T, D) return self.W_o(out) def _sliding_causal_mask(T: int, window: int, device, dtype) -> torch.Tensor: """(T, T) additive mask: 0 inside window+causal, -inf outside. Token i attends to j iff j <= i and (i - j) < window. """ i = torch.arange(T, device=device).unsqueeze(1) # (T,1) j = torch.arange(T, device=device).unsqueeze(0) # (1,T) causal = j <= i in_window = (i - j) < window keep = causal & in_window mask = torch.zeros((T, T), device=device, dtype=dtype) mask = mask.masked_fill(~keep, float("-inf")) # SDPA expects (..., T, T) broadcast over batch/heads. return mask # --------------------------------------------------------------------------- # MLP (SwiGLU) # --------------------------------------------------------------------------- class SwiGLU(nn.Module): def __init__(self, dim: int, hidden: int): super().__init__() self.w_gate = nn.Linear(dim, hidden, bias=False) self.w_up = nn.Linear(dim, hidden, bias=False) self.w_down = nn.Linear(hidden, dim, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)) # --------------------------------------------------------------------------- # Block # --------------------------------------------------------------------------- class Block(nn.Module): def __init__(self, cfg: Config, layer_idx: int): super().__init__() kind = cfg.attention_kind(layer_idx) self.norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps) self.attn = Attention(cfg, kind=kind) self.norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps) self.mlp = SwiGLU(cfg.dim, cfg.mlp_hidden) self.kind = kind def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.norm1(x), cos, sin) x = x + self.mlp(self.norm2(x)) return x # --------------------------------------------------------------------------- # Engram external memory # --------------------------------------------------------------------------- class Engram(nn.Module): """Static memory bank with single-head attention readout + zero-init gate. Bit-identical to no-engram at init (gate sigmoid is zero so injection is 0). Becomes non-trivial only after the gate is trained away from zero. """ def __init__(self, cfg: Config): super().__init__() self.cfg = cfg # Slot rows are normalised by RMSNorm at read time. self.slots = nn.Parameter(torch.randn(cfg.engram_slots, cfg.dim) * cfg.init_std) self.q_proj = nn.Linear(cfg.dim, cfg.dim, bias=False) self.k_proj = nn.Linear(cfg.dim, cfg.dim, bias=False) self.v_proj = nn.Linear(cfg.dim, cfg.dim, bias=False) self.o_proj = nn.Linear(cfg.dim, cfg.dim, bias=False) self.norm = RMSNorm(cfg.dim, eps=cfg.norm_eps) # Zero-init gate scalar -> sigmoid(0) = 0.5? No, we want exact no-op at init. # Use a *raw* gate that we multiply rather than sigmoid; init to 0. self.gate = nn.Parameter(torch.zeros(cfg.dim)) def forward(self, h: torch.Tensor) -> torch.Tensor: # h: (B, T, D). Read from memory. h_n = self.norm(h) q = self.q_proj(h_n) # (B, T, D) k = self.k_proj(self.slots) # (S, D) v = self.v_proj(self.slots) # (S, D) scale = q.shape[-1] ** -0.5 attn = torch.einsum("btd,sd->bts", q, k) * scale w = attn.softmax(dim=-1) retrieved = torch.einsum("bts,sd->btd", w, v) retrieved = self.o_proj(retrieved) # Multiplicative zero-init gate -> exact no-op at init. return h + self.gate * retrieved # --------------------------------------------------------------------------- # ToyLM # --------------------------------------------------------------------------- class ToyLM(nn.Module): def __init__(self, cfg: Config): super().__init__() self.cfg = cfg self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.dim) self.blocks = nn.ModuleList([Block(cfg, i) for i in range(cfg.n_layers)]) self.norm_f = RMSNorm(cfg.dim, eps=cfg.norm_eps) if cfg.engram_enabled: self.engram = Engram(cfg) else: self.engram = None if not cfg.tie_embeddings: self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) else: self.lm_head = None # RoPE cache; rebuilt lazily if the requested seq_len exceeds it. cos, sin = _build_rope_cache(cfg.max_seq_len, cfg.head_dim, cfg.rope_base, device="cpu", dtype=torch.float32) self.register_buffer("rope_cos", cos, persistent=False) self.register_buffer("rope_sin", sin, persistent=False) self._init_weights() def _init_weights(self) -> None: std = self.cfg.init_std for p_name, p in self.named_parameters(): if p.dim() >= 2: nn.init.normal_(p, mean=0.0, std=std) elif p_name.endswith(".weight") and "norm" in p_name.lower(): nn.init.ones_(p) elif p_name == "engram.gate": nn.init.zeros_(p) else: nn.init.zeros_(p) def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, Optional[torch.Tensor]]: B, T = idx.shape assert T <= self.cfg.max_seq_len, f"seq_len {T} > max {self.cfg.max_seq_len}" x = self.tok_emb(idx) cos = self.rope_cos[:T].to(device=x.device, dtype=x.dtype) sin = self.rope_sin[:T].to(device=x.device, dtype=x.dtype) for i, blk in enumerate(self.blocks): x = blk(x, cos, sin) if self.engram is not None and i == self.cfg.engram_inject_layer: x = self.engram(x) x = self.norm_f(x) if self.cfg.tie_embeddings: logits = F.linear(x, self.tok_emb.weight) else: logits = self.lm_head(x) # Gemma-2 logit soft-cap (bf16 stability + bounded softmax input). if self.cfg.lm_head_logit_cap is not None: cap = self.cfg.lm_head_logit_cap logits = cap * torch.tanh(logits / cap) loss = None if targets is not None: loss = F.cross_entropy( logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100, ) # PaLM-style z-loss: penalises log-partition magnitude. Keeps the # softmax denominator from drifting; small weight (~1e-4) costs ~0. # Computed only on non-ignored positions so it composes with masked SFT. if self.cfg.z_loss_weight > 0: lse = torch.logsumexp(logits.float(), dim=-1) # (B, T) if targets is not None: valid = targets.reshape(*lse.shape) != -100 if valid.any(): z = (lse[valid] ** 2).mean() else: z = lse.new_zeros(()) else: z = (lse ** 2).mean() loss = loss + self.cfg.z_loss_weight * z return logits, loss @torch.no_grad() def generate( self, idx: torch.Tensor, max_new_tokens: int = 80, *, temperature: float = 0.8, top_p: float = 0.9, rep_penalty: float = 1.3, stop_token_ids: Optional[set[int]] = None, ) -> torch.Tensor: """Sampling-based decode with top-p + repetition penalty. Defaults are tuned for sub-10M LMs: greedy alone collapses into token-level repetition loops at this scale (entropy stays high but argmax follows a self-amplifying trajectory). T=0.8 + top-p 0.9 + rep_penalty=1.3 reliably breaks the loop without going incoherent. Validated 2026-04-29 on the 12k-step toy 1M checkpoint. Pass `temperature=0.0` to recover greedy (without rep_penalty). """ self.eval() for _ in range(max_new_tokens): logits, _ = self(idx) logits = logits[:, -1].float() # (B, V) if rep_penalty != 1.0: # Per-batch element rep penalty over already-emitted tokens. for b in range(idx.size(0)): seen = torch.unique(idx[b]) pos = logits[b, seen] > 0 logits[b, seen] = torch.where(pos, logits[b, seen] / rep_penalty, logits[b, seen] * rep_penalty) if temperature <= 0.0: nxt = logits.argmax(dim=-1, keepdim=True) else: logits = logits / temperature if top_p < 1.0: sorted_logits, sorted_idx = logits.sort(descending=True) cum = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1) mask = cum > top_p mask[..., 1:] = mask[..., :-1].clone() mask[..., 0] = False logits = logits.scatter(1, sorted_idx, sorted_logits.masked_fill(mask, float('-inf'))) probs = F.softmax(logits, dim=-1) nxt = torch.multinomial(probs, num_samples=1) idx = torch.cat([idx, nxt], dim=1) if stop_token_ids is not None and nxt[0, 0].item() in stop_token_ids: break if idx.size(1) >= self.cfg.max_seq_len: break return idx def num_params_breakdown(self) -> dict[str, int]: emb = sum(p.numel() for p in self.tok_emb.parameters()) attn = 0 mlp = 0 norms = 0 for blk in self.blocks: attn += sum(p.numel() for p in blk.attn.parameters()) mlp += sum(p.numel() for p in blk.mlp.parameters()) norms += sum(p.numel() for p in blk.norm1.parameters()) norms += sum(p.numel() for p in blk.norm2.parameters()) norms += sum(p.numel() for p in self.norm_f.parameters()) engram = sum(p.numel() for p in self.engram.parameters()) if self.engram is not None else 0 head = sum(p.numel() for p in self.lm_head.parameters()) if self.lm_head is not None else 0 total = sum(p.numel() for p in self.parameters()) return { "embedding": emb, "attention": attn, "mlp": mlp, "norms": norms, "engram": engram, "lm_head_extra": head, # 0 when tied "total": total, }