"""Muon optimizer for 2D matrices.
Reference: Keller Jordan, "Muon: An optimizer for hidden layers in neural networks"
https://kellerjordan.github.io/posts/muon/
Algorithm
---------
For each 2D parameter W with gradient G:
1. Maintain momentum buffer M_t = beta * M_{t-1} + G_t
2. Optionally apply Nesterov: G' = G_t + beta * M_t (or just M_t without Nesterov)
3. Orthogonalise G' via 5 iterations of Newton-Schulz with the quintic polynomial
coefficients (3.4445, -4.7750, 2.0315):
X <- 3.4445 * X - 4.7750 * X X^T X + 2.0315 * (X X^T)^2 X
   after first dividing X by ||X||_F, which bounds its singular values in [0, 1].
4. Apply the orthogonalised update: W <- W - lr * adj_factor * O
   where adj_factor = max(1, fan_out / fan_in)**0.5 rescales updates for tall
   rectangular matrices (fan_out > fan_in).
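For example (the shape is illustrative, not taken from any model in this repo):
a 4096 x 1024 weight has fan_out=4096 and fan_in=1024, so
adj_factor = (4096 / 1024)**0.5 = 2.0, while square and wide weights keep
adj_factor = 1.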
This optimiser is intended ONLY for parameters with .dim() >= 2. The recommended
recipe uses AdamW for embeddings and 1D tensors (norms, biases). The wrapper
class `HybridOptimizer` here packages that split.
Bit-identical guarantee
-----------------------
When the caller selects optimizer="adamw" in Config, the train script never
constructs Muon -- it builds a single AdamW over all params. The HybridOptimizer
exists only when optimizer="muon"; it is not a sneaky pass-through. This keeps
the two paths cleanly separated.
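A sketch of that selection, as described above (`config` and `model` are
placeholder names and the learning rates are assumptions, not values defined in
this module):

    if config.optimizer == "muon":
        opt = HybridOptimizer(model.named_parameters(), muon_lr=3e-3, adamw_lr=3e-4)
    else:  # "adamw": a single AdamW over all params; Muon is never constructed
        opt = torch.optim.AdamW(model.parameters(), lr=3e-4)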
"""
from __future__ import annotations
from typing import Iterable
import torch
from torch.optim import Optimizer
# ---------------------------------------------------------------------------
# Newton-Schulz orthogonalisation
# ---------------------------------------------------------------------------
@torch.no_grad()
def newton_schulz_5(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
"""Quintic Newton-Schulz, 5 iterations. Returns an approximately-orthogonal
matrix with the same shape as G.
    Operates on the *transposed* shape if rows > cols so that X X^T stays the
    smaller matrix multiply (the canonical optimisation in the reference impl).
"""
assert G.dim() >= 2
a, b, c = 3.4445, -4.7750, 2.0315
X = G.float() # do all NS math in fp32 even if param is bf16
if X.size(-2) > X.size(-1):
X = X.transpose(-2, -1)
transposed = True
else:
transposed = False
    # Normalise so the largest singular value is at most 1: the Frobenius norm
    # is an upper bound on the spectral norm, so dividing by it is safe and is
    # the standard choice in the reference implementation.
X = X / (X.norm() + eps)
for _ in range(5):
A = X @ X.transpose(-2, -1)
B = b * A + c * (A @ A)
X = a * X + B @ X
if transposed:
X = X.transpose(-2, -1)
return X.to(G.dtype)
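# A minimal sanity check for newton_schulz_5 (illustrative only; not executed at
# import time, and the shape below is an assumption):
#
#     G = torch.randn(256, 1024)
#     O = newton_schulz_5(G)
#     # O is roughly semi-orthogonal: O @ O^T should be close to the identity,
#     # up to the loose tolerance of the 5-step quintic iteration.
#     print((O @ O.transpose(-2, -1) - torch.eye(256)).abs().max())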
# ---------------------------------------------------------------------------
# Muon
# ---------------------------------------------------------------------------
class Muon(Optimizer):
def __init__(
self,
params: Iterable[torch.Tensor],
lr: float = 3e-3,
momentum: float = 0.95,
nesterov: bool = True,
weight_decay: float = 0.0,
):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)
super().__init__(params, defaults)
for group in self.param_groups:
for p in group["params"]:
assert p.dim() >= 2, (
f"Muon expects 2D+ params; got shape {tuple(p.shape)}. "
"Wrap embeddings + 1D tensors with AdamW (use HybridOptimizer)."
)
@torch.no_grad()
def step(self, closure=None):
loss = closure() if closure is not None else None
for group in self.param_groups:
lr = group["lr"]
beta = group["momentum"]
nesterov = group["nesterov"]
wd = group["weight_decay"]
for p in group["params"]:
if p.grad is None:
continue
g = p.grad
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(p)
buf = state["momentum_buffer"]
buf.mul_(beta).add_(g)
update = g + beta * buf if nesterov else buf
# Reshape ND tensors (e.g. conv kernels) into 2D for orthogonalisation.
# Embeddings are excluded by construction; here we expect Linear weights
# which are already 2D, but keep the reshape for safety.
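                # Illustrative only (the shape is an assumption): a (64, 3, 3, 3)
                # conv kernel would be flattened to (64, 27) here and reshaped
                # back after orthogonalisation.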
orig_shape = update.shape
if update.dim() > 2:
update = update.reshape(update.shape[0], -1)
ortho = newton_schulz_5(update)
# Scale by sqrt(max(1, fan_out/fan_in)) so updates have sane magnitude
# across rectangular shapes. fan_out = rows, fan_in = cols.
fan_out, fan_in = ortho.shape[-2], ortho.shape[-1]
adj = max(1.0, fan_out / fan_in) ** 0.5
if ortho.shape != orig_shape:
ortho = ortho.reshape(orig_shape)
                if wd != 0.0:
                    # Decoupled (AdamW-style) weight decay: p <- p - lr * wd * p.
                    p.add_(p, alpha=-lr * wd)
p.add_(ortho, alpha=-lr * adj)
return loss
# ---------------------------------------------------------------------------
# Hybrid Muon + AdamW wrapper
# ---------------------------------------------------------------------------
class HybridOptimizer:
"""Routes 2D+ params to Muon and 1D / embedding params to AdamW.
Mimics the torch.optim.Optimizer surface enough for our train loop:
.step(), .zero_grad(set_to_none=True), .param_groups (for LR scheduling).
"""
def __init__(
self,
named_params: Iterable[tuple[str, torch.nn.Parameter]],
muon_lr: float,
adamw_lr: float,
muon_momentum: float = 0.95,
adamw_betas: tuple[float, float] = (0.9, 0.95),
weight_decay: float = 0.0,
):
muon_params = []
adamw_params = []
for name, p in named_params:
if not p.requires_grad:
continue
# Embeddings have dim() == 2 but should still go to AdamW per the recipe.
is_embedding = "tok_emb" in name or "engram.slots" in name
if p.dim() >= 2 and not is_embedding:
muon_params.append(p)
else:
adamw_params.append(p)
self.muon = Muon(
muon_params,
lr=muon_lr,
momentum=muon_momentum,
nesterov=True,
weight_decay=weight_decay,
)
self.adamw = torch.optim.AdamW(
adamw_params,
lr=adamw_lr,
betas=adamw_betas,
weight_decay=weight_decay,
)
self.param_groups = self.muon.param_groups + self.adamw.param_groups
def step(self, closure=None):
if closure is not None:
raise NotImplementedError("HybridOptimizer does not support a closure.")
self.muon.step()
self.adamw.step()
def zero_grad(self, set_to_none: bool = True):
self.muon.zero_grad(set_to_none=set_to_none)
self.adamw.zero_grad(set_to_none=set_to_none)
def state_dict(self):
return {"muon": self.muon.state_dict(), "adamw": self.adamw.state_dict()}
def load_state_dict(self, sd):
self.muon.load_state_dict(sd["muon"])
self.adamw.load_state_dict(sd["adamw"])
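# Illustrative wiring into a training step (a minimal sketch: `model`, `loader`,
# and the learning rates are assumptions, not part of this module):
#
#     opt = HybridOptimizer(
#         model.named_parameters(),
#         muon_lr=3e-3,
#         adamw_lr=3e-4,
#         weight_decay=0.0,
#     )
#     for batch in loader:
#         loss = model(batch)
#         loss.backward()
#         opt.step()
#         opt.zero_grad(set_to_none=True)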