"""Config dataclass for the toy 50M LM.

Scaled up from the toy_1m_gemma4_dsv4 baseline. Architectural levers stay the
same (alternating SLIDE/GLOBAL Gemma 4 attention, optional Muon, optional
512-slot Engram, full v2 stabilisation); only the shape numbers change.

Two architectural variants are flag-gated:

  attention_pattern:
    "all_global" -- every layer is full causal attention (baseline).
    "gemma4"     -- alternating SLIDE/GLOBAL across layers; last layer is GLOBAL.

  optimizer:
    "adamw" -- AdamW for everything (baseline).
    "muon"  -- Muon for params with .dim() >= 2; AdamW for embeddings + 1D.

  engram_enabled: optional 512-slot external memory bank with zero-init gate.

When attention_pattern == "all_global" and optimizer == "adamw" and engram_enabled
is False, training math is bit-identical to a plain causal transformer baseline.

Defaults
--------
* vocab=8192 (up from 4096): fresh BPE on a larger FineWeb-edu sample.
* dim=512, n_layers=12, n_heads=8, head_dim=64.
* mlp_hidden=2048 (4x dim, SwiGLU).
* max_seq_len=8192 (up from 4096).
* sliding_window=1024 ("larger model" Gemma 4 tier; 1M used 512).
* All v2 stabilisers ON: lm_head_logit_cap=30.0, z_loss_weight=1e-4, lr_schedule="wsd".
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal


AttentionPattern = Literal["all_global", "gemma4"]
OptimizerName = Literal["adamw", "muon"]
LRSchedule = Literal["cosine", "wsd"]


@dataclass
class Config:
    # ---------- model shape ----------
    vocab_size: int = 8192
    dim: int = 512
    n_layers: int = 12
    n_heads: int = 8
    head_dim: int = 64  # n_heads * head_dim must equal dim
    mlp_hidden: int = 2048
    max_seq_len: int = 8192
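    # Rough parameter count (assuming a 3-matrix SwiGLU MLP and full MHA, norms ignored):
    # 12 * (4*512^2 attn + 3*512*2048 mlp) ~= 50.3M non-embedding parameters, plus a
    # tied 8192*512 ~= 4.2M embedding -- consistent with the "50M" label.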

    # ---------- gemma4 SWA ----------
    attention_pattern: AttentionPattern = "gemma4"
    sliding_window: int = 1024

    # ---------- engram (off by default) ----------
    engram_enabled: bool = False
    engram_slots: int = 512
    engram_inject_layer: int = 6  # mid-stack for the 12-layer build

    # ---------- training ----------
    optimizer: OptimizerName = "muon"
    rope_base: float = 10000.0
    norm_eps: float = 1e-5
    dropout: float = 0.0
    tie_embeddings: bool = True

    # ---------- CE stabilisation (Gemma-2 logit cap + PaLM z-loss) ----------
    # ON by default at 50M scale -- the 1M project added these as a v2 bolt-on
    # but at 50M with bf16 they're standard practice (DeepSeek V2/3, Gemma 2/3,
    # PaLM). Bit-identical to the un-stabilised path when both knobs are 0/None.
    lm_head_logit_cap: float | None = 30.0
    z_loss_weight: float = 1e-4
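    # Sketch of how the two knobs are typically applied (standard Gemma-2 soft-cap
    # and PaLM z-loss forms; the actual call sites live in the model/loss code):
    #   logits = lm_head_logit_cap * tanh(logits / lm_head_logit_cap)
    #   loss   = cross_entropy + z_loss_weight * logsumexp(logits, dim=-1) ** 2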

    # ---------- LR schedule ----------
    # WSD by default at 50M (per Apr 2026 small-LM research; the LR decays only
    # over the last 20% of post-warmup steps, much smoother than cosine).
    lr_schedule: LRSchedule = "wsd"
    wsd_decay_frac: float = 0.2
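    # Worked example (assuming the usual warmup/stable/decay split): with 10_000
    # total steps and 1_000 warmup, the LR holds flat for steps 1_000..8_200 and
    # decays over the final 1_800 steps (20% of the 9_000 post-warmup steps).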

    # ---------- bookkeeping ----------
    init_std: float = 0.02

    def __post_init__(self) -> None:
        assert self.n_heads * self.head_dim == self.dim, (
            f"n_heads*head_dim={self.n_heads * self.head_dim} != dim={self.dim}"
        )
        assert self.attention_pattern in ("all_global", "gemma4")
        assert self.optimizer in ("adamw", "muon")
        assert self.lr_schedule in ("cosine", "wsd")
        assert 0.0 <= self.wsd_decay_frac <= 1.0
        assert self.z_loss_weight >= 0.0
        assert self.lm_head_logit_cap is None or self.lm_head_logit_cap > 0
        # Last layer must be GLOBAL when using gemma4 (canonical invariant).
        # Concretely: layer i is GLOBAL iff (i % 2 == 1) for i in [0, n_layers).
        # n_layers must be even so that the last index, n_layers-1, is odd.
        if self.attention_pattern == "gemma4":
            assert self.n_layers % 2 == 0 and self.n_layers >= 2, (
                "gemma4 pattern requires even n_layers >= 2 so the last layer is GLOBAL"
            )

    def attention_kind(self, layer_idx: int) -> Literal["slide", "global"]:
        """Return whether `layer_idx` is a sliding-window or global-attention layer."""
        if self.attention_pattern == "all_global":
            return "global"
        # gemma4: even idx = SLIDE, odd idx = GLOBAL. Last layer (n_layers-1) is odd
        # for any even n_layers, so it is GLOBAL.
        return "global" if (layer_idx % 2 == 1) else "slide"