TilelliLab's picture
Mirror small files (code, paper, results)
f86dc09 verified
Raw
History Blame Contribute Delete
6.36 kB
"""Muon — Momentum-Updated Newton-Schulz orthogonalised optimiser.
Jordan, Bernstein et al. (Oct 2024). Used to train Kimi K2 (1T MoE,
15.5T tokens, zero instabilities) — but Kimi K2 used MuonClip (the
QK-rescaling stability fix) on top. This implementation omits QK-Clip
since at sub-frontier scale plain Muon is empirically stable.
The core idea: SGD's momentum update (m = mu * m + g; W <- W - lr * m)
is fine, except it can leave m anisotropic — concentrated on the top
singular directions. Muon orthogonalises m via a few Newton-Schulz
iterations before applying it, so each step contributes equally across
all singular directions.
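Algorithm (per weight matrix, applied only to weights with ndim >= 2;
tensors with ndim > 2 are flattened to 2D as (size(0), -1) first):
1. m_t = momentum * m_{t-1} + g_t
2. u_t = NewtonSchulz5(m_t) # orthogonalise: u_t ≈ m_t @ (m_t^T m_t)^{-1/2}
(with nesterov=True the Nesterov look-ahead of m_t is orthogonalised instead)
3. W_t = W_{t-1} - lr * sqrt(max(d_in, d_out) / min(d_in, d_out)) * u_t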
For 1D parameters (biases, norm scales) and for embedding / unembedding
matrices Muon is *not* recommended — fall back to AdamW for those. The
convention in the Muon papers is to declare two parameter groups:
hidden 2D-weights -> Muon, everything-else -> AdamW. We follow that here.
Reference: https://kellerjordan.github.io/posts/muon/
"""
from __future__ import annotations
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
@torch.no_grad()
def _newton_schulz5(g: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor:
"""Approximate g @ (g^T g)^{-1/2} via 5 Newton-Schulz iterations.
Constants from the Muon reference implementation; tuned so that the
iteration converges to the correct orthogonalisation in <=5 steps for
typical weight-matrix singular-value distributions.
"""
a, b, c = (3.4445, -4.7750, 2.0315)
x = g.float()
if g.size(-2) > g.size(-1):
# Newton-Schulz expects "tall" matrix; transpose then transpose back.
x = x.transpose(-2, -1)
transposed = True
else:
transposed = False
x = x / (x.norm() + eps) # ||x|| = 1 entering the iteration
for _ in range(steps):
y = x @ x.transpose(-2, -1)
x = a * x + b * y @ x + c * y @ y @ x
if transposed:
x = x.transpose(-2, -1)
return x.to(g.dtype)
class Muon(Optimizer):
    """Muon optimiser for 2D+ parameters; pair with AdamW for 1D params.

    Each step accumulates momentum, orthogonalises the resulting update with
    Newton-Schulz iterations and applies it with a shape-dependent scale.

    Parameters
    ----------
    params : iterable of 2D+ tensors only (or parameter-group dicts).
    lr : float, default 0.02. Must be positive.
    momentum : float, default 0.95. Must lie in [0, 1).
    weight_decay : float, default 0.0. Decoupled decay, applied as
        ``p *= 1 - lr * weight_decay``. Must be non-negative.
    nesterov : bool, default True. Orthogonalise the Nesterov look-ahead
        ``g + momentum * m`` instead of the momentum buffer ``m``.
    ns_steps : int, default 5. Number of Newton-Schulz iterations.

    Raises
    ------
    ValueError
        If a hyper-parameter is out of range or a parameter has fewer than
        two dimensions.
    """

    def __init__(
        self,
        params,
        lr: float = 0.02,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        nesterov: bool = True,
        ns_steps: int = 5,
    ) -> None:
        if lr <= 0.0:
            raise ValueError(f"lr must be positive, got {lr}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"momentum must be in [0, 1), got {momentum}")
        if weight_decay < 0.0:
            raise ValueError(
                f"weight_decay must be non-negative, got {weight_decay}"
            )
        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
            ns_steps=ns_steps,
        )
        super().__init__(params, defaults)
        # Orthogonalisation is only defined for matrices, so reject vectors
        # and scalars up front instead of failing inside step().
        for group in self.param_groups:
            for p in group["params"]:
                if p.dim() < 2:
                    raise ValueError(
                        f"Muon expects 2D+ parameters; got shape {tuple(p.shape)}. "
                        "Pair Muon with AdamW for 1D params (biases, norms)."
                    )

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step.

        Parameters
        ----------
        closure : optional callable that re-evaluates the model and returns
            the loss. It is called with gradients enabled.

        Returns
        -------
        The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            mom = group["momentum"]
            wd = group["weight_decay"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad

                state = self.state[p]
                if "m" not in state:
                    state["m"] = torch.zeros_like(p)
                m = state["m"]
                m.mul_(mom).add_(g)

                # Nesterov look-ahead is g + mom * m (gradient plus the
                # discounted, already-updated buffer), not m + mom * g.
                update = g.add(m, alpha=mom) if nesterov else m

                # Newton-Schulz works on matrices; flatten any 3D+ into 2D.
                if update.dim() > 2:
                    update_2d = update.reshape(update.size(0), -1)
                else:
                    update_2d = update
                u = _newton_schulz5(update_2d, steps=ns_steps)
                u = u.reshape(update.shape)

                # Shape-aware LR scaling: sqrt(max(rows, cols) / min(rows, cols)),
                # taken from the matrix that was actually orthogonalised so
                # that flattened 3D+ weights use their true 2D shape.
                rows, cols = update_2d.shape
                shape_scale = (max(rows, cols) / min(rows, cols)) ** 0.5

                # Decoupled weight decay, applied before the update.
                if wd != 0.0:
                    p.mul_(1 - lr * wd)
                p.add_(u, alpha=-lr * shape_scale)

        return loss
def split_params_for_muon(model: torch.nn.Module
                          ) -> tuple[list[torch.nn.Parameter], list[torch.nn.Parameter]]:
    """Split a model's trainable parameters into (muon_params, adamw_params).

    2D+ weights go to Muon; everything else goes to AdamW: biases, norm
    scales, embeddings and the unembedding / ``lm_head`` matrix.
    A parameter counts as an embedding when either

    * it belongs directly to an ``nn.Embedding`` / ``nn.EmbeddingBag``
      module, whatever that module is called, or
    * its dotted name contains ``"embed"`` (which also covers ``"unembed"``
      and ``"tok_embed"``) or ``"lm_head"``.

    Parameters with ``requires_grad=False`` are omitted from both lists.

    Parameters
    ----------
    model : the module whose parameters should be partitioned.

    Returns
    -------
    ``(muon_params, adamw_params)``, each in ``model.named_parameters()``
    order.
    """
    embedding_types = (torch.nn.Embedding, torch.nn.EmbeddingBag)
    # Identify embedding tables by module type so they are caught even when
    # the attribute name carries no hint (e.g. "wte").
    embedding_param_ids = {
        id(param)
        for module in model.modules()
        if isinstance(module, embedding_types)
        for param in module.parameters(recurse=False)
    }
    name_markers = ("embed", "lm_head")

    muon_params: list[torch.nn.Parameter] = []
    adamw_params: list[torch.nn.Parameter] = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        is_embedding = id(p) in embedding_param_ids or any(
            marker in name for marker in name_markers
        )
        if p.dim() >= 2 and not is_embedding:
            muon_params.append(p)
        else:
            adamw_params.append(p)
    return muon_params, adamw_params