"""Muon optimizer for 2D matrices.
Reference: Keller Jordan, "Muon: An optimizer for hidden layers in neural networks"
https://kellerjordan.github.io/posts/muon/
Algorithm
---------
For each 2D parameter W with gradient G:
1. Maintain momentum buffer M_t = beta * M_{t-1} + G_t
2. Optionally apply Nesterov: G' = G_t + beta * M_t (or just M_t without Nesterov)
3. Orthogonalise G' via 5 iterations of Newton-Schulz with the quintic polynomial
coefficients (3.4445, -4.7750, 2.0315):
X <- 3.4445 * X - 4.7750 * X X^T X + 2.0315 * (X X^T)^2 X
after first dividing X by ||X||_F to bring its singular values into [0, ~1.5].
4. Apply the orthogonalised update: W <- W - lr * adj_factor * O
where adj_factor = max(1, fan_out / fan_in)**0.5 to scale shorter-dim params.
This optimiser is intended ONLY for parameters with .dim() >= 2. The recommended
recipe uses AdamW for embeddings and 1D tensors (norms, biases). The wrapper
class `HybridOptimizer` here packages that split.
Bit-identical guarantee
-----------------------
When the caller selects optimizer="adamw" in Config, the train script never
constructs Muon -- it builds a single AdamW over all params. The HybridOptimizer
exists only when optimizer="muon"; it is not a sneaky pass-through. This keeps
the two paths cleanly separated.
"""
from __future__ import annotations
from typing import Iterable
import torch
from torch.optim import Optimizer
# ---------------------------------------------------------------------------
# Newton-Schulz orthogonalisation
# ---------------------------------------------------------------------------
@torch.no_grad()
def newton_schulz_5(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Quintic Newton-Schulz, 5 iterations. Returns an approximately-orthogonal
    matrix with the same shape as G.

    Operates on the *transposed* shape if rows > cols so that X @ X^T stays the
    smaller matrix-multiply (the canonical optimisation in the reference impl).
    """
    assert G.dim() >= 2
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.float()  # do all NS math in fp32 even if the param is bf16
    if X.size(-2) > X.size(-1):
        X = X.transpose(-2, -1)
        transposed = True
    else:
        transposed = False
    # Normalise so ||X||_op <= 1. The Frobenius norm is an upper bound on the
    # spectral norm; dividing by it is safe and the standard choice.
    X = X / (X.norm() + eps)
    for _ in range(5):
        A = X @ X.transpose(-2, -1)
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.transpose(-2, -1)
    return X.to(G.dtype)
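
# Sanity-check sketch (illustrative only, not part of the training path; the
# shape below is an arbitrary assumption):
#
#   G = torch.randn(256, 1024)
#   O = newton_schulz_5(G)
#   # O should be approximately orthogonal on its smaller side, i.e.
#   # O @ O.T ~= I_256. Five NS iterations only drive singular values *near*
#   # 1, so expect a small residual:
#   err = (O @ O.T - torch.eye(256)).norm()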
# ---------------------------------------------------------------------------
# Muon
# ---------------------------------------------------------------------------
class Muon(Optimizer):
    def __init__(
        self,
        params: Iterable[torch.Tensor],
        lr: float = 3e-3,
        momentum: float = 0.95,
        nesterov: bool = True,
        weight_decay: float = 0.0,
    ):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group["params"]:
                assert p.dim() >= 2, (
                    f"Muon expects 2D+ params; got shape {tuple(p.shape)}. "
                    "Wrap embeddings + 1D tensors with AdamW (use HybridOptimizer)."
                )

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            lr = group["lr"]
            beta = group["momentum"]
            nesterov = group["nesterov"]
            wd = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(p)
                buf = state["momentum_buffer"]
                buf.mul_(beta).add_(g)
                update = g + beta * buf if nesterov else buf
                # Reshape ND tensors (e.g. conv kernels) into 2D for orthogonalisation.
                # Embeddings are excluded by construction; here we expect Linear weights
                # which are already 2D, but keep the reshape for safety.
                orig_shape = update.shape
                if update.dim() > 2:
                    update = update.reshape(update.shape[0], -1)
                ortho = newton_schulz_5(update)
                # Scale by sqrt(max(1, fan_out/fan_in)) so updates have sane magnitude
                # across rectangular shapes. fan_out = rows, fan_in = cols.
                fan_out, fan_in = ortho.shape[-2], ortho.shape[-1]
                adj = max(1.0, fan_out / fan_in) ** 0.5
                if ortho.shape != orig_shape:
                    ortho = ortho.reshape(orig_shape)
                if wd != 0.0:
                    # Decoupled weight decay: p <- p * (1 - lr * wd).
                    p.add_(p, alpha=-lr * wd)
                p.add_(ortho, alpha=-lr * adj)
        return loss
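
# Worked example of the adjustment factor (shapes are illustrative):
#
#   W = torch.empty(4 * 512, 512)   # MLP up-projection: fan_out / fan_in = 4
#   adj = max(1.0, W.shape[0] / W.shape[1]) ** 0.5   # -> 2.0
#   # A wide matrix gets no boost:
#   W2 = torch.empty(512, 4 * 512)  # fan_out / fan_in = 0.25 -> adj = 1.0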
# ---------------------------------------------------------------------------
# Hybrid Muon + AdamW wrapper
# ---------------------------------------------------------------------------
class HybridOptimizer:
    """Routes 2D+ params to Muon and 1D / embedding params to AdamW.

    Mimics the torch.optim.Optimizer surface enough for our train loop:
    .step(), .zero_grad(set_to_none=True), .param_groups (for LR scheduling).
    """

    def __init__(
        self,
        named_params: Iterable[tuple[str, torch.nn.Parameter]],
        muon_lr: float,
        adamw_lr: float,
        muon_momentum: float = 0.95,
        adamw_betas: tuple[float, float] = (0.9, 0.95),
        weight_decay: float = 0.0,
    ):
        muon_params = []
        adamw_params = []
        for name, p in named_params:
            if not p.requires_grad:
                continue
            # Embeddings have dim() == 2 but should still go to AdamW per the recipe.
            is_embedding = "tok_emb" in name or "engram.slots" in name
            if p.dim() >= 2 and not is_embedding:
                muon_params.append(p)
            else:
                adamw_params.append(p)
        self.muon = Muon(
            muon_params,
            lr=muon_lr,
            momentum=muon_momentum,
            nesterov=True,
            weight_decay=weight_decay,
        )
        self.adamw = torch.optim.AdamW(
            adamw_params,
            lr=adamw_lr,
            betas=adamw_betas,
            weight_decay=weight_decay,
        )
        # Concatenated view of both optimisers' groups so an external LR
        # scheduler can iterate over a single .param_groups list.
        self.param_groups = self.muon.param_groups + self.adamw.param_groups
    def step(self, closure=None):
        if closure is not None:
            raise NotImplementedError("HybridOptimizer does not support a closure.")
        self.muon.step()
        self.adamw.step()

    def zero_grad(self, set_to_none: bool = True):
        self.muon.zero_grad(set_to_none=set_to_none)
        self.adamw.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        return {"muon": self.muon.state_dict(), "adamw": self.adamw.state_dict()}

    def load_state_dict(self, sd):
        self.muon.load_state_dict(sd["muon"])
        self.adamw.load_state_dict(sd["adamw"])
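
# Minimal smoke test, runnable as `python muon.py`. The tiny model and the
# hyperparameter values below are illustrative assumptions, not the Shard
# training configuration.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))
    opt = HybridOptimizer(
        model.named_parameters(),
        muon_lr=0.02,    # 2D Linear weights route to Muon
        adamw_lr=3e-4,   # 1D biases route to AdamW
        weight_decay=0.01,
    )
    x = torch.randn(8, 64)
    loss = model(x).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)
    print("muon params:", sum(p.numel() for g in opt.muon.param_groups for p in g["params"]))
    print("adamw params:", sum(p.numel() for g in opt.adamw.param_groups for p in g["params"]))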