"""Muon — Momentum-Updated Newton-Schulz orthogonalised optimiser. Jordan, Bernstein et al. (Oct 2024). Used to train Kimi K2 (1T MoE, 15.5T tokens, zero instabilities) — but Kimi K2 used MuonClip (the QK-rescaling stability fix) on top. This implementation omits QK-Clip since at sub-frontier scale plain Muon is empirically stable. The core idea: SGD's momentum update (m = mu * m + g; W <- W - lr * m) is fine, except it can leave m anisotropic — concentrated on the top singular directions. Muon orthogonalises m via a few Newton-Schulz iterations before applying it, so each step contributes equally across all singular directions. Algorithm (per 2D weight matrix, applied only to weights with ndim >= 2): 1. m_t = momentum * m_{t-1} + g_t 2. u_t = NewtonSchulz5(m_t) # orthogonalise: u_t ≈ m_t @ (m_t^T m_t)^{-1/2} 3. W_t = W_{t-1} - lr * sqrt(max(d_in, d_out) / d_min) * u_t For 1D parameters (biases, norm scales, embeddings) Muon is *not* recommended — fall back to AdamW for those. The convention in the Muon papers is to declare two parameter groups: 2D-weights -> Muon, everything-else -> AdamW. We follow that here. Reference: https://kellerjordan.github.io/posts/muon/ """ from __future__ import annotations import torch from torch import Tensor from torch.optim.optimizer import Optimizer @torch.no_grad() def _newton_schulz5(g: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: """Approximate g @ (g^T g)^{-1/2} via 5 Newton-Schulz iterations. Constants from the Muon reference implementation; tuned so that the iteration converges to the correct orthogonalisation in <=5 steps for typical weight-matrix singular-value distributions. """ a, b, c = (3.4445, -4.7750, 2.0315) x = g.float() if g.size(-2) > g.size(-1): # Newton-Schulz expects "tall" matrix; transpose then transpose back. x = x.transpose(-2, -1) transposed = True else: transposed = False x = x / (x.norm() + eps) # ||x|| = 1 entering the iteration for _ in range(steps): y = x @ x.transpose(-2, -1) x = a * x + b * y @ x + c * y @ y @ x if transposed: x = x.transpose(-2, -1) return x.to(g.dtype) class Muon(Optimizer): """Muon optimiser for 2D+ parameters; pair with AdamW for 1D params. Parameters ---------- params : iterable of 2D+ tensors only. lr : float, default 0.02. Larger than AdamW because the orthogonalised update has unit operator-norm, not unit element-norm. momentum : float, default 0.95. weight_decay : float, default 0.0. nesterov : bool, default True. Nesterov-flavoured momentum lookahead. ns_steps : int, default 5. Number of Newton-Schulz iterations. """ def __init__( self, params, lr: float = 0.02, momentum: float = 0.95, weight_decay: float = 0.0, nesterov: bool = True, ns_steps: int = 5, ) -> None: if lr <= 0.0: raise ValueError(f"lr must be positive, got {lr}") if not 0.0 <= momentum < 1.0: raise ValueError(f"momentum must be in [0, 1), got {momentum}") defaults = dict( lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov, ns_steps=ns_steps, ) super().__init__(params, defaults) for group in self.param_groups: for p in group["params"]: if p.dim() < 2: raise ValueError( f"Muon expects 2D+ parameters; got shape {tuple(p.shape)}. " "Pair Muon with AdamW for 1D params (biases, norms)." ) @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: lr = group["lr"] mom = group["momentum"] wd = group["weight_decay"] nesterov = group["nesterov"] ns_steps = group["ns_steps"] for p in group["params"]: if p.grad is None: continue g = p.grad state = self.state[p] if "m" not in state: state["m"] = torch.zeros_like(p) m = state["m"] m.mul_(mom).add_(g) update = m.add(g, alpha=mom) if nesterov else m # Newton-Schulz orthogonalisation; flatten any 3D+ into 2D first. orig_shape = update.shape if update.dim() > 2: update_2d = update.reshape(update.size(0), -1) else: update_2d = update u = _newton_schulz5(update_2d, steps=ns_steps) u = u.reshape(orig_shape) # Shape-aware LR scaling: multiply by sqrt(max(fan_in, fan_out) / d_min). # Keeps the operator-norm step size constant across rectangular shapes. fan_max = max(p.size(0), p.size(-1)) fan_min = min(p.size(0), p.size(-1)) shape_scale = (fan_max / fan_min) ** 0.5 if wd != 0.0: p.mul_(1 - lr * wd) p.add_(u, alpha=-lr * shape_scale) return loss def split_params_for_muon(model: torch.nn.Module ) -> tuple[list[torch.nn.Parameter], list[torch.nn.Parameter]]: """Split a model's parameters into (muon_params, adamw_params). Convention from the Muon paper: 2D+ weights -> Muon; biases, norm scales, embeddings, unembed -> AdamW. We treat embeddings and unembed (lm_head) as AdamW-managed because their geometry (token-shaped, sparse gradients) is poorly suited to orthogonalisation. """ muon_params: list[torch.nn.Parameter] = [] adamw_params: list[torch.nn.Parameter] = [] for name, p in model.named_parameters(): if not p.requires_grad: continue is_embedding = ("embed" in name) or ("unembed" in name) or ("tok_embed" in name) if p.dim() >= 2 and not is_embedding: muon_params.append(p) else: adamw_params.append(p) return muon_params, adamw_params