| """Config dataclass for the toy 50M LM. |
| |
| Scaled up from the toy_1m_gemma4_dsv4 baseline. Architectural levers stay the |
| same (alternating SLIDE/GLOBAL Gemma 4 attention, optional Muon, optional |
| 512-slot Engram, full v2 stabilisation), only the shape numbers change. |
| |
| Two architectural variants are flag-gated: |
| |
| attention_pattern: |
| "all_global" -- every layer is full causal attention (baseline). |
| "gemma4" -- alternating SLIDE/GLOBAL across layers; last layer is GLOBAL. |
| |
| optimizer: |
| "adamw" -- AdamW for everything (baseline). |
| "muon" -- Muon for params with .dim() >= 2; AdamW for embeddings + 1D. |
| |
| engram_enabled: optional 512-slot external memory bank with zero-init gate. |
| |
| When attention_pattern == "all_global" and optimizer == "adamw" and engram_enabled |
| is False, training math is bit-identical to a plain causal transformer baseline. |
| |
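
Example (illustrative only; shows the flag combinations, not a training recipe):

>>> baseline = Config(attention_pattern="all_global", optimizer="adamw",
...                   engram_enabled=False)
>>> baseline.attention_kind(0)
'global'
>>> Config().attention_pattern, Config().optimizer
('gemma4', 'muon')
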
Defaults
--------
* vocab_size=8192 (up from 4096): fresh BPE on a larger FineWeb-edu sample.
* dim=512, n_layers=12, n_heads=8, head_dim=64.
* mlp_hidden=2048 (4x dim, SwiGLU).
* max_seq_len=8192 (up from 4096).
* sliding_window=1024 ("larger model" Gemma 4 tier; the 1M model used 512).
* All v2 stabilisers ON: lm_head_logit_cap=30.0, z_loss_weight=1e-4, lr_schedule="wsd".
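
Back-of-envelope size (illustrative; assumes plain multi-head attention with
Q/K/V/O projections, a three-matrix SwiGLU, tied embeddings, and ignores norms
and the optional Engram):

* attention per layer:  4 * dim^2            = 4 * 512 * 512  ~= 1.05M
* SwiGLU MLP per layer: 3 * dim * mlp_hidden = 3 * 512 * 2048 ~= 3.15M
* 12 layers:            12 * 4.19M           ~= 50.3M non-embedding params,
  which is presumably where the "50M" in the name comes from.
* tied embedding:       vocab_size * dim     = 8192 * 512     ~= 4.19M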
| """ |
| from __future__ import annotations |
|
|
from dataclasses import dataclass
from typing import Literal


AttentionPattern = Literal["all_global", "gemma4"]
OptimizerName = Literal["adamw", "muon"]
LRSchedule = Literal["cosine", "wsd"]
|
|
|
|
@dataclass
class Config:
    """Hyper-parameters for the toy 50M LM (see module docstring)."""

    # Model shape.
    vocab_size: int = 8192
    dim: int = 512
    n_layers: int = 12
    n_heads: int = 8
    head_dim: int = 64
    mlp_hidden: int = 2048
    max_seq_len: int = 8192

    # Attention variant: alternating SLIDE/GLOBAL ("gemma4") or all-global.
    attention_pattern: AttentionPattern = "gemma4"
    sliding_window: int = 1024

    # Optional Engram: external memory bank with a zero-init gate.
    engram_enabled: bool = False
    engram_slots: int = 512
    engram_inject_layer: int = 6

    # Optimizer choice and common transformer hyper-parameters.
    optimizer: OptimizerName = "muon"
    rope_base: float = 10000.0
    norm_eps: float = 1e-5
    dropout: float = 0.0
    tie_embeddings: bool = True

    # v2 stabilisers: LM-head logit soft cap and z-loss weight.
    lm_head_logit_cap: float | None = 30.0
    z_loss_weight: float = 1e-4

    # Learning-rate schedule ("cosine" or "wsd") and the WSD decay fraction.
    lr_schedule: LRSchedule = "wsd"
    wsd_decay_frac: float = 0.2

    # Weight initialisation std.
    init_std: float = 0.02
|
|
    def __post_init__(self) -> None:
        assert self.n_heads * self.head_dim == self.dim, (
            f"n_heads*head_dim={self.n_heads * self.head_dim} != dim={self.dim}"
        )
        assert self.attention_pattern in ("all_global", "gemma4")
        assert self.optimizer in ("adamw", "muon")
        assert self.lr_schedule in ("cosine", "wsd")
        assert 0.0 <= self.wsd_decay_frac <= 1.0
        assert self.z_loss_weight >= 0.0
        assert self.lm_head_logit_cap is None or self.lm_head_logit_cap > 0

        # The gemma4 pattern starts with SLIDE at layer 0 and alternates, so an
        # even layer count is what guarantees the final layer is GLOBAL.
        if self.attention_pattern == "gemma4":
            assert self.n_layers % 2 == 0 and self.n_layers >= 2, (
                "gemma4 pattern requires even n_layers >= 2 so the last layer is GLOBAL"
            )
|
|
    def attention_kind(self, layer_idx: int) -> Literal["slide", "global"]:
        """Return whether `layer_idx` is a sliding-window or global-attention layer."""
        if self.attention_pattern == "all_global":
            return "global"
        # gemma4: even layer indices use sliding-window attention, odd indices use
        # global attention, so the last of an even number of layers is GLOBAL.
        return "global" if (layer_idx % 2 == 1) else "slide"