| """Toy 1M-param transformer with Gemma 4 alternating SWA + optional engram memory. |
| |
| Design notes |
| ------------ |
| * RMSNorm pre-norm, SwiGLU MLP, tied embedding/output (standard Llama-ish base). |
| * Causal mask is precomputed; sliding-window layers use the same code path with |
| an additional window-restricted mask (purely a mask difference -- no kernel split). |
| * RoPE is applied to Q/K only (standard, no Gemma 4 dual-RoPE). |
| * Engram is an optional 512-slot static memory bank attended to from one layer's
| output; the readout is injected through a per-channel gate that is zero-initialised,
| so it is a no-op at training start. Bit-identical to no-engram when
| `cfg.engram_enabled=False`.
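|
| Example (illustrative; assumes ``Config()`` constructs with its defaults)::
|
| cfg = Config()
| model = ToyLM(cfg)
| idx = torch.randint(0, cfg.vocab_size, (2, 64))
| logits, loss = model(idx, targets=idx)  # logits: (2, 64, vocab), loss: scalar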
| """ |
| from __future__ import annotations |
|
|
| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from config import Config |
|
|
|
|
| |
| |
| |
| class RMSNorm(nn.Module): |
| def __init__(self, dim: int, eps: float = 1e-5): |
| super().__init__() |
| self.weight = nn.Parameter(torch.ones(dim)) |
| self.eps = eps |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| |
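| # Normalise in fp32 for stability, then cast back to the input dtype
| # before applying the learned gain.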
| dtype = x.dtype |
| xf = x.float() |
| inv_rms = xf.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
| return (xf * inv_rms).to(dtype) * self.weight
|
|
|
|
| |
| |
| |
| def _build_rope_cache(seq_len: int, head_dim: int, base: float, device, dtype) -> tuple[torch.Tensor, torch.Tensor]: |
| assert head_dim % 2 == 0, "head_dim must be even for RoPE" |
| half = head_dim // 2 |
| inv_freq = 1.0 / (base ** (torch.arange(0, half, device=device, dtype=torch.float32) / half)) |
| t = torch.arange(seq_len, device=device, dtype=torch.float32) |
| freqs = torch.einsum("i,j->ij", t, inv_freq) |
| cos = freqs.cos().to(dtype) |
| sin = freqs.sin().to(dtype) |
| return cos, sin |
|
|
|
|
| def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: |
| |
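| # Split-half RoPE (GPT-NeoX convention): channel i is paired with channel
| # i + head_dim//2 and each pair is rotated by the position angle.
| # cos/sin have shape (T, head_dim//2) and broadcast over (B, H, T, half).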
| x1, x2 = x.chunk(2, dim=-1) |
| cos_b = cos[None, None, :, :] |
| sin_b = sin[None, None, :, :] |
| rotated_x1 = x1 * cos_b - x2 * sin_b |
| rotated_x2 = x1 * sin_b + x2 * cos_b |
| return torch.cat([rotated_x1, rotated_x2], dim=-1) |
|
|
|
|
| |
| |
| |
| class Attention(nn.Module): |
| """MHA with RoPE and configurable causal-or-sliding mask. |
| |
| `kind == 'global'`: full causal attention. |
| `kind == 'slide'` : causal attention restricted to the last `window` tokens. |
| |
| Both code paths use F.scaled_dot_product_attention for speed; the only
| difference is the mask. When kind=='global' we pass `is_causal=True` and
| skip building an explicit mask. When kind=='slide' we build a banded additive
| mask (0 inside the causal window, -inf outside), which reproduces the global
| path exactly whenever T <= window.
| """ |
|
|
| def __init__(self, cfg: Config, kind: str): |
| super().__init__() |
| assert kind in ("global", "slide") |
| self.cfg = cfg |
| self.kind = kind |
| self.n_heads = cfg.n_heads |
| self.head_dim = cfg.head_dim |
| self.scale = self.head_dim**-0.5  # unused here: F.scaled_dot_product_attention applies 1/sqrt(d) internally
|
|
| self.W_q = nn.Linear(cfg.dim, cfg.dim, bias=False) |
| self.W_k = nn.Linear(cfg.dim, cfg.dim, bias=False) |
| self.W_v = nn.Linear(cfg.dim, cfg.dim, bias=False) |
| self.W_o = nn.Linear(cfg.dim, cfg.dim, bias=False) |
|
|
| def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: |
| B, T, D = x.shape |
| q = self.W_q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) |
| k = self.W_k(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) |
| v = self.W_v(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) |
|
|
| q = _apply_rope(q, cos, sin) |
| k = _apply_rope(k, cos, sin) |
|
|
| if self.kind == "global": |
| out = F.scaled_dot_product_attention(q, k, v, is_causal=True) |
| else: |
| |
| mask = _sliding_causal_mask(T, self.cfg.sliding_window, x.device, x.dtype) |
| out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False) |
|
|
| out = out.transpose(1, 2).contiguous().view(B, T, D) |
| return self.W_o(out) |
|
|
|
|
| def _sliding_causal_mask(T: int, window: int, device, dtype) -> torch.Tensor: |
| """(T, T) additive mask: 0 inside window+causal, -inf outside. |
| |
| Token i attends to j iff j <= i and (i - j) < window. |
| """ |
| i = torch.arange(T, device=device).unsqueeze(1) |
| j = torch.arange(T, device=device).unsqueeze(0) |
| causal = j <= i |
| in_window = (i - j) < window |
| keep = causal & in_window |
| mask = torch.zeros((T, T), device=device, dtype=dtype) |
| mask = mask.masked_fill(~keep, float("-inf")) |
| |
| return mask |
|
|
|
|
| |
| |
| |
| class SwiGLU(nn.Module): |
| def __init__(self, dim: int, hidden: int): |
| super().__init__() |
| self.w_gate = nn.Linear(dim, hidden, bias=False) |
| self.w_up = nn.Linear(dim, hidden, bias=False) |
| self.w_down = nn.Linear(hidden, dim, bias=False) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)) |
|
|
|
|
| |
| |
| |
| class Block(nn.Module): |
| def __init__(self, cfg: Config, layer_idx: int): |
| super().__init__() |
| kind = cfg.attention_kind(layer_idx) |
| self.norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps) |
| self.attn = Attention(cfg, kind=kind) |
| self.norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps) |
| self.mlp = SwiGLU(cfg.dim, cfg.mlp_hidden) |
| self.kind = kind |
|
|
| def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: |
| x = x + self.attn(self.norm1(x), cos, sin) |
| x = x + self.mlp(self.norm2(x)) |
| return x |
|
|
|
|
| |
| |
| |
| class Engram(nn.Module): |
| """Static memory bank with single-head attention readout + zero-init gate. |
| |
| Bit-identical to no-engram at init: the per-channel gate starts at exactly
| zero, so the injected term is 0. Becomes non-trivial only after training
| moves the gate away from zero.
| """ |
|
|
| def __init__(self, cfg: Config): |
| super().__init__() |
| self.cfg = cfg |
| |
| self.slots = nn.Parameter(torch.randn(cfg.engram_slots, cfg.dim) * cfg.init_std) |
| self.q_proj = nn.Linear(cfg.dim, cfg.dim, bias=False) |
| self.k_proj = nn.Linear(cfg.dim, cfg.dim, bias=False) |
| self.v_proj = nn.Linear(cfg.dim, cfg.dim, bias=False) |
| self.o_proj = nn.Linear(cfg.dim, cfg.dim, bias=False) |
| self.norm = RMSNorm(cfg.dim, eps=cfg.norm_eps) |
| |
| |
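| # Per-channel gate, initialised to exactly zero so the whole module is an
| # identity mapping at the start of training.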
| self.gate = nn.Parameter(torch.zeros(cfg.dim)) |
|
|
| def forward(self, h: torch.Tensor) -> torch.Tensor: |
| |
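| # h: (B, T, D); slots: (S, D). Queries come from the (normed) residual
| # stream, keys/values from the learned slots, so the readout is (B, T, D).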
| h_n = self.norm(h) |
| q = self.q_proj(h_n) |
| k = self.k_proj(self.slots) |
| v = self.v_proj(self.slots) |
| scale = q.shape[-1] ** -0.5 |
| attn = torch.einsum("btd,sd->bts", q, k) * scale |
| w = attn.softmax(dim=-1) |
| retrieved = torch.einsum("bts,sd->btd", w, v) |
| retrieved = self.o_proj(retrieved) |
| |
| return h + self.gate * retrieved |
|
|
|
|
| |
| |
| |
| class ToyLM(nn.Module): |
| def __init__(self, cfg: Config): |
| super().__init__() |
| self.cfg = cfg |
| self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.dim) |
| self.blocks = nn.ModuleList([Block(cfg, i) for i in range(cfg.n_layers)]) |
| self.norm_f = RMSNorm(cfg.dim, eps=cfg.norm_eps) |
|
|
| if cfg.engram_enabled: |
| self.engram = Engram(cfg) |
| else: |
| self.engram = None |
|
|
| if not cfg.tie_embeddings: |
| self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) |
| else: |
| self.lm_head = None |
|
|
| |
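| # RoPE tables are non-persistent buffers: they follow .to(device) but are
| # rebuilt from the config on construction rather than stored in checkpoints.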
| cos, sin = _build_rope_cache(cfg.max_seq_len, cfg.head_dim, cfg.rope_base, device="cpu", dtype=torch.float32) |
| self.register_buffer("rope_cos", cos, persistent=False) |
| self.register_buffer("rope_sin", sin, persistent=False) |
|
|
| self._init_weights() |
|
|
| def _init_weights(self) -> None: |
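| # 2-D parameters (linear weights, embedding, engram slots) get N(0, std);
| # norm gains are reset to 1; every remaining 1-D parameter (including
| # engram.gate) is zeroed. The explicit engram.gate branch duplicates the
| # final else, but keeps the zero-gate invariant visible.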
| std = self.cfg.init_std |
| for p_name, p in self.named_parameters(): |
| if p.dim() >= 2: |
| nn.init.normal_(p, mean=0.0, std=std) |
| elif p_name.endswith(".weight") and "norm" in p_name.lower(): |
| nn.init.ones_(p) |
| elif p_name == "engram.gate": |
| nn.init.zeros_(p) |
| else: |
| nn.init.zeros_(p) |
|
|
| def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
| B, T = idx.shape |
| assert T <= self.cfg.max_seq_len, f"seq_len {T} > max {self.cfg.max_seq_len}" |
| x = self.tok_emb(idx) |
|
|
| cos = self.rope_cos[:T].to(device=x.device, dtype=x.dtype) |
| sin = self.rope_sin[:T].to(device=x.device, dtype=x.dtype) |
|
|
| for i, blk in enumerate(self.blocks): |
| x = blk(x, cos, sin) |
| if self.engram is not None and i == self.cfg.engram_inject_layer: |
| x = self.engram(x) |
|
|
| x = self.norm_f(x) |
|
|
| if self.cfg.tie_embeddings: |
| logits = F.linear(x, self.tok_emb.weight) |
| else: |
| logits = self.lm_head(x) |
|
|
| |
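| # Gemma-style soft cap: cap * tanh(logits / cap) keeps final logits in (-cap, cap).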
| if self.cfg.lm_head_logit_cap is not None: |
| cap = self.cfg.lm_head_logit_cap |
| logits = cap * torch.tanh(logits / cap) |
|
|
| loss = None |
| if targets is not None: |
| loss = F.cross_entropy( |
| logits.reshape(-1, logits.size(-1)), |
| targets.reshape(-1), |
| ignore_index=-100, |
| ) |
| |
| |
| |
| # Optional z-loss regulariser: penalise the squared log-partition over
| # non-ignored positions. Guarded on `targets` so that inference (where
| # `loss` is still None) never tries to add to None.
| if targets is not None and self.cfg.z_loss_weight > 0:
| lse = torch.logsumexp(logits.float(), dim=-1)
| valid = targets != -100
| z = (lse[valid] ** 2).mean() if valid.any() else lse.new_zeros(())
| loss = loss + self.cfg.z_loss_weight * z
| return logits, loss |
|
|
| @torch.no_grad() |
| def generate( |
| self, |
| idx: torch.Tensor, |
| max_new_tokens: int = 80, |
| *, |
| temperature: float = 0.8, |
| top_p: float = 0.9, |
| rep_penalty: float = 1.3, |
| stop_token_ids: Optional[set[int]] = None, |
| ) -> torch.Tensor: |
| """Sampling-based decode with top-p + repetition penalty. |
| |
| Defaults are tuned for sub-10M LMs: greedy alone collapses into |
| token-level repetition loops at this scale (entropy stays high but |
| argmax follows a self-amplifying trajectory). T=0.8 + top-p 0.9 + |
| rep_penalty=1.3 reliably breaks the loop without going incoherent. |
| Validated 2026-04-29 on the 12k-step toy 1M checkpoint. |
| |
| Pass `temperature=0.0` for greedy argmax decoding; set `rep_penalty=1.0`
| as well for pure greedy, since the penalty is applied before sampling.
| """ |
| self.eval() |
| for _ in range(max_new_tokens): |
| logits, _ = self(idx)  # full re-forward each step; no KV cache (fine at toy scale)
| logits = logits[:, -1].float() |
|
|
| if rep_penalty != 1.0: |
| |
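| # Repetition penalty over every token already in the context: positive
| # logits are divided by the penalty, negative logits multiplied by it,
| # so repeats become less likely regardless of sign.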
| for b in range(idx.size(0)): |
| seen = torch.unique(idx[b]) |
| pos = logits[b, seen] > 0 |
| logits[b, seen] = torch.where(pos, |
| logits[b, seen] / rep_penalty, |
| logits[b, seen] * rep_penalty) |
|
|
| if temperature <= 0.0: |
| nxt = logits.argmax(dim=-1, keepdim=True) |
| else: |
| logits = logits / temperature |
| if top_p < 1.0: |
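| # Nucleus (top-p) filter: sort descending, mask tokens once the cumulative
| # probability exceeds top_p, then shift the mask right by one so the token
| # that crosses the threshold (and always the top-1 token) survives.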
| sorted_logits, sorted_idx = logits.sort(descending=True) |
| cum = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1) |
| mask = cum > top_p |
| mask[..., 1:] = mask[..., :-1].clone() |
| mask[..., 0] = False |
| logits = logits.scatter(1, sorted_idx, |
| sorted_logits.masked_fill(mask, float('-inf'))) |
| probs = F.softmax(logits, dim=-1) |
| nxt = torch.multinomial(probs, num_samples=1) |
|
|
| idx = torch.cat([idx, nxt], dim=1) |
|
|
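| # Stop check only inspects batch element 0; this decode path assumes
| # batch size 1.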
| if stop_token_ids is not None and nxt[0, 0].item() in stop_token_ids: |
| break |
| if idx.size(1) >= self.cfg.max_seq_len: |
| break |
|
|
| return idx |
|
|
| def num_params_breakdown(self) -> dict[str, int]: |
| emb = sum(p.numel() for p in self.tok_emb.parameters()) |
| attn = 0 |
| mlp = 0 |
| # Final norm is counted once, up front; the loop then adds the per-block norms.
| norms = sum(p.numel() for p in self.norm_f.parameters())
| for blk in self.blocks:
| attn += sum(p.numel() for p in blk.attn.parameters())
| mlp += sum(p.numel() for p in blk.mlp.parameters())
| norms += sum(p.numel() for p in blk.norm1.parameters())
| norms += sum(p.numel() for p in blk.norm2.parameters())
| engram = sum(p.numel() for p in self.engram.parameters()) if self.engram is not None else 0 |
| head = sum(p.numel() for p in self.lm_head.parameters()) if self.lm_head is not None else 0 |
| total = sum(p.numel() for p in self.parameters()) |
| return { |
| "embedding": emb, |
| "attention": attn, |
| "mlp": mlp, |
| "norms": norms, |
| "engram": engram, |
| "lm_head_extra": head, |
| "total": total, |
| } |
|
|
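|
|
| if __name__ == "__main__":
| # Quick smoke test (illustrative; assumes ``Config()`` constructs with its
| # defaults and that ``cfg.max_seq_len`` is at least 64).
| cfg = Config()
| model = ToyLM(cfg)
| print(model.num_params_breakdown())
| x = torch.randint(0, cfg.vocab_size, (2, 64))
| logits, loss = model(x, targets=x)  # targets unshifted; just exercises the loss path
| print(logits.shape, loss.item())
| print(model.generate(x[:1, :8], max_new_tokens=16).shape)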