"""Config dataclass for the toy 50M LM. Scaled up from the toy_1m_gemma4_dsv4 baseline. Architectural levers stay the same (alternating SLIDE/GLOBAL Gemma 4 attention, optional Muon, optional 512-slot Engram, full v2 stabilisation), only the shape numbers change. Two architectural variants are flag-gated: attention_pattern: "all_global" -- every layer is full causal attention (baseline). "gemma4" -- alternating SLIDE/GLOBAL across layers; last layer is GLOBAL. optimizer: "adamw" -- AdamW for everything (baseline). "muon" -- Muon for params with .dim() >= 2; AdamW for embeddings + 1D. engram_enabled: optional 512-slot external memory bank with zero-init gate. When attention_pattern == "all_global" and optimizer == "adamw" and engram_enabled is False, training math is bit-identical to a plain causal transformer baseline. Defaults -------- * vocab=8192 (up from 4096): fresh BPE on a larger FineWeb-edu sample. * dim=512, n_layers=12, n_heads=8, head_dim=64. * mlp_hidden=2048 (4x dim, SwiGLU). * max_seq_len=8192 (up from 4096). * sliding_window=1024 ("larger model" Gemma 4 tier; 1M used 512). * All v2 stabilisers ON: lm_head_logit_cap=30.0, z_loss_weight=1e-4, lr_schedule="wsd". """ from __future__ import annotations from dataclasses import dataclass from typing import Literal AttentionPattern = Literal["all_global", "gemma4"] OptimizerName = Literal["adamw", "muon"] LRSchedule = Literal["cosine", "wsd"] @dataclass class Config: # ---------- model shape ---------- vocab_size: int = 8192 dim: int = 512 n_layers: int = 12 n_heads: int = 8 head_dim: int = 64 # n_heads * head_dim must equal dim mlp_hidden: int = 2048 max_seq_len: int = 8192 # ---------- gemma4 SWA ---------- attention_pattern: AttentionPattern = "gemma4" sliding_window: int = 1024 # ---------- engram (off by default) ---------- engram_enabled: bool = False engram_slots: int = 512 engram_inject_layer: int = 6 # mid-stack for the 12-layer build # ---------- training ---------- optimizer: OptimizerName = "muon" rope_base: float = 10000.0 norm_eps: float = 1e-5 dropout: float = 0.0 tie_embeddings: bool = True # ---------- CE stabilisation (Gemma-2 logit cap + PaLM z-loss) ---------- # ON by default at 50M scale -- the 1M project added these as a v2 bolt-on # but at 50M with bf16 they're standard practice (DeepSeek V2/3, Gemma 2/3, # PaLM). Bit-identical to the un-stabilised path when both knobs are 0/None. lm_head_logit_cap: float | None = 30.0 z_loss_weight: float = 1e-4 # ---------- LR schedule ---------- # WSD by default at 50M (per Apr 2026 small-LM research; lets the head # decay over the last 20 % of post-warmup, much smoother than cosine). lr_schedule: LRSchedule = "wsd" wsd_decay_frac: float = 0.2 # ---------- bookkeeping ---------- init_std: float = 0.02 def __post_init__(self) -> None: assert self.n_heads * self.head_dim == self.dim, ( f"n_heads*head_dim={self.n_heads * self.head_dim} != dim={self.dim}" ) assert self.attention_pattern in ("all_global", "gemma4") assert self.optimizer in ("adamw", "muon") assert self.lr_schedule in ("cosine", "wsd") assert 0.0 <= self.wsd_decay_frac <= 1.0 assert self.z_loss_weight >= 0.0 assert self.lm_head_logit_cap is None or self.lm_head_logit_cap > 0 # Last layer must be GLOBAL when using gemma4 (canonical invariant). # Concretely: layer i is GLOBAL iff (i % 2 == 1) for i in [0, n_layers). # n_layers must be even, last index n_layers-1 must be odd. 
        if self.attention_pattern == "gemma4":
            assert self.n_layers % 2 == 0 and self.n_layers >= 2, (
                "gemma4 pattern requires even n_layers >= 2 "
                "so the last layer is GLOBAL"
            )

    def attention_kind(self, layer_idx: int) -> Literal["slide", "global"]:
        """Return whether `layer_idx` is a sliding-window or global-attention layer."""
        if self.attention_pattern == "all_global":
            return "global"
        # gemma4: even idx = SLIDE, odd idx = GLOBAL. Last layer (n_layers-1)
        # is odd for any even n_layers, so it is GLOBAL.
        return "global" if (layer_idx % 2 == 1) else "slide"
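
# ---------------------------------------------------------------------------
# Illustrative sketches (not part of the training code). Everything below is
# a hedged example of how the knobs above are typically consumed; function
# names and shapes are assumptions, not pinned down by this config.
# ---------------------------------------------------------------------------


def engram_gated_read(hidden, mem_read, gate):
    """Sketch of the zero-init gate named in the module docstring.

    Assumes `gate` is a scalar torch Parameter initialised to zero, so the
    engram branch is an exact no-op at init and flipping `engram_enabled`
    cannot change step-0 behaviour. The slot-bank read producing `mem_read`
    is not specified here.
    """
    import torch  # local import: keeps this config module torch-free

    return hidden + torch.tanh(gate) * mem_read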
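
def split_param_groups(model):
    """Sketch of the optimizer partition described in the module docstring:
    Muon for params with .dim() >= 2, AdamW for embeddings + 1D params.

    Assumes a PyTorch nn.Module; the name-based "embed" filter is an
    illustrative way to keep (tied) embedding matrices on AdamW even though
    they are 2D.
    """
    muon_params, adamw_params = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.dim() >= 2 and "embed" not in name:
            muon_params.append(p)  # matrices -> Muon
        else:
            adamw_params.append(p)  # embeddings, norms, biases -> AdamW
    return muon_params, adamw_params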
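
def stabilised_ce_loss(logits, targets, cfg: Config):
    """Sketch of the CE path with both v2 stabilisers.

    The formulas are the published ones (Gemma-2 tanh soft-cap, PaLM z-loss);
    the reduction and function name are illustrative choices. With
    lm_head_logit_cap=None and z_loss_weight=0.0 this is plain cross-entropy,
    matching the bit-identical-baseline claim in the docstring.
    """
    import torch
    import torch.nn.functional as F  # local import: keeps the module torch-free

    if cfg.lm_head_logit_cap is not None:
        cap = cfg.lm_head_logit_cap
        logits = cap * torch.tanh(logits / cap)  # squash logits into (-cap, +cap)
    loss = F.cross_entropy(logits.flatten(0, -2), targets.flatten())
    if cfg.z_loss_weight > 0.0:
        # z-loss pulls log(Z) toward 0 so the bf16 softmax stays well-scaled.
        log_z = torch.logsumexp(logits, dim=-1)
        loss = loss + cfg.z_loss_weight * (log_z**2).mean()
    return loss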
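
def wsd_lr_mult(step: int, total_steps: int, warmup_steps: int, cfg: Config) -> float:
    """Sketch of the warmup-stable-decay multiplier implied by lr_schedule.

    Linear warmup and linear decay are assumptions -- WSD variants differ in
    the decay curve, and this config only fixes the decay fraction (the last
    wsd_decay_frac of post-warmup steps).
    """
    if step < warmup_steps:
        return (step + 1) / warmup_steps  # linear warmup to peak
    post_warmup = total_steps - warmup_steps
    decay_start = warmup_steps + int((1.0 - cfg.wsd_decay_frac) * post_warmup)
    if step < decay_start:
        return 1.0  # stable plateau at peak LR
    # Linear decay to zero over the final wsd_decay_frac of post-warmup steps.
    return max(0.0, (total_steps - step) / max(1, total_steps - decay_start))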
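
if __name__ == "__main__":
    # Baseline flag combination from the docstring: with all three levers off
    # (and both stabiliser knobs zeroed) training math matches a plain causal
    # transformer.
    baseline = Config(
        attention_pattern="all_global",
        optimizer="adamw",
        engram_enabled=False,
        lm_head_logit_cap=None,
        z_loss_weight=0.0,
    )
    assert all(
        baseline.attention_kind(i) == "global" for i in range(baseline.n_layers)
    )

    # Full 50M defaults: alternating SLIDE/GLOBAL, last layer GLOBAL.
    cfg = Config()
    layout = [cfg.attention_kind(i) for i in range(cfg.n_layers)]
    print(layout)  # ['slide', 'global', ..., 'slide', 'global']
    assert layout[-1] == "global"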