| """Muon — Momentum-Updated Newton-Schulz orthogonalised optimiser. |
| |
| Jordan, Bernstein et al. (Oct 2024). Used to train Kimi K2 (1T MoE, |
| 15.5T tokens, zero instabilities) — but Kimi K2 used MuonClip (the |
| QK-rescaling stability fix) on top. This implementation omits QK-Clip |
| since at sub-frontier scale plain Muon is empirically stable. |
| |
The core idea: SGD's momentum update (m = mu * m + g; W <- W - lr * m)
works, but the momentum matrix m tends to be anisotropic — dominated by a
few top singular directions. Muon (approximately) orthogonalises m via a
few Newton-Schulz iterations before applying it, so each step contributes
roughly equally across all singular directions.
| |
Algorithm (per 2D weight matrix; applied only to weights with ndim >= 2):
    1. m_t = momentum * m_{t-1} + g_t
    2. u_t = NewtonSchulz5(m_t)   # approximate orthogonalisation: u_t ≈ m_t (m_t^T m_t)^{-1/2}
    3. W_t = W_{t-1} - lr * sqrt(max(d_in, d_out) / min(d_in, d_out)) * u_t
| |
Muon is *not* recommended for 1D parameters (biases, norm scales) or for
embedding / unembedding matrices — fall back to AdamW for those. The
convention in the Muon reference implementation is to declare two
parameter groups: hidden 2D+ weights -> Muon, everything else -> AdamW.
We follow that here.
| |
| Reference: https://kellerjordan.github.io/posts/muon/ |
| """ |
| from __future__ import annotations |
|
|
| import torch |
| from torch import Tensor |
| from torch.optim.optimizer import Optimizer |
|
|
|
|
| @torch.no_grad() |
| def _newton_schulz5(g: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: |
| """Approximate g @ (g^T g)^{-1/2} via 5 Newton-Schulz iterations. |
| |
| Constants from the Muon reference implementation; tuned so that the |
| iteration converges to the correct orthogonalisation in <=5 steps for |
| typical weight-matrix singular-value distributions. |
| """ |
| a, b, c = (3.4445, -4.7750, 2.0315) |
| x = g.float() |
| if g.size(-2) > g.size(-1): |
| |
| x = x.transpose(-2, -1) |
| transposed = True |
| else: |
| transposed = False |
| x = x / (x.norm() + eps) |
| for _ in range(steps): |
| y = x @ x.transpose(-2, -1) |
| x = a * x + b * y @ x + c * y @ y @ x |
| if transposed: |
| x = x.transpose(-2, -1) |
| return x.to(g.dtype) |
|
|
|
|
class Muon(Optimizer):
    """Muon optimiser for 2D+ parameters; pair with AdamW for 1D params.

    Each step:

    1. Updates the SGD momentum buffer.
    2. Approximately orthogonalises the (optionally Nesterov) momentum update
       with Newton-Schulz.
    3. Scales the result by ``sqrt(max(rows, cols) / min(rows, cols))`` of
       the orthogonalised matrix.
    4. Subtracts ``lr`` times that from the parameter.

    Parameters with more than two dims (e.g. conv kernels) are viewed as
    ``(shape[0], -1)`` for the orthogonalisation and the scale.

    Parameters
    ----------
    params : iterable of 2D+ tensors (or param-group dicts). A parameter with
        ``dim() < 2`` raises ``ValueError``. This also applies to groups added
        later via ``add_param_group``.
    lr : float, default 0.02. Must be > 0. Larger than AdamW because the
        orthogonalised update has (roughly) unit operator-norm, not unit
        element-norm.
    momentum : float, default 0.95. Must be in [0, 1).
    weight_decay : float, default 0.0. Must be >= 0. Decoupled (AdamW-style)
        decay applied as ``p *= 1 - lr * weight_decay``.
    nesterov : bool, default True. If set, ``g + momentum * m`` (the Nesterov
        lookahead used by ``torch.optim.SGD``) is orthogonalised instead of
        ``m``.
    ns_steps : int, default 5. Number of Newton-Schulz iterations.
    """

    def __init__(
        self,
        params,
        lr: float = 0.02,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        nesterov: bool = True,
        ns_steps: int = 5,
    ) -> None:
        if lr <= 0.0:
            raise ValueError(f"lr must be positive, got {lr}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"momentum must be in [0, 1), got {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"weight_decay must be non-negative, got {weight_decay}")
        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
            ns_steps=ns_steps,
        )
        # Optimizer.__init__ routes every group through add_param_group, which
        # performs the 2D+ shape check.
        super().__init__(params, defaults)

    def add_param_group(self, param_group: dict) -> None:
        """Add a parameter group, rejecting parameters with fewer than 2 dims.

        Overridden so the shape check covers groups added after construction
        as well as those passed to ``__init__``. A rejected group is popped
        from ``param_groups`` before the error is raised.

        Raises
        ------
        ValueError
            If any parameter in the group has ``dim() < 2``.
        """
        super().add_param_group(param_group)
        for p in self.param_groups[-1]["params"]:
            if p.dim() < 2:
                self.param_groups.pop()
                raise ValueError(
                    f"Muon expects 2D+ parameters; got shape {tuple(p.shape)}. "
                    "Pair Muon with AdamW for 1D params (biases, norms)."
                )

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimisation step.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss. It is called with
            gradients enabled.

        Returns
        -------
        The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            wd = group["weight_decay"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]

            for p in group["params"]:
                grad = p.grad
                # Empty tensors have nothing to update (and a zero fan).
                if grad is None or p.numel() == 0:
                    continue

                state = self.state[p]
                if "m" not in state:
                    state["m"] = torch.zeros_like(p)
                m = state["m"]
                m.mul_(momentum).add_(grad)
                # Nesterov lookahead: g_t + momentum * m_t, the same
                # formulation as torch.optim.SGD(nesterov=True).
                update = grad.add(m, alpha=momentum) if nesterov else m

                # Newton-Schulz operates on matrices: view >2D weights (e.g.
                # conv kernels) as (shape[0], -1). flatten(1) leaves 2D as-is.
                update_2d = update.flatten(1)
                u = _newton_schulz5(update_2d, steps=ns_steps).reshape(p.shape)

                # The shape scale comes from the matrix actually orthogonalised,
                # so >2D weights use their flattened fan-in, not p.size(-1).
                rows, cols = update_2d.shape
                shape_scale = (max(rows, cols) / min(rows, cols)) ** 0.5

                if wd != 0.0:
                    # Decoupled (AdamW-style) weight decay.
                    p.mul_(1 - lr * wd)
                p.add_(u, alpha=-lr * shape_scale)

        return loss
|
|
|
|
def split_params_for_muon(model: torch.nn.Module
                          ) -> tuple[list[torch.nn.Parameter], list[torch.nn.Parameter]]:
    """Split a model's parameters into (muon_params, adamw_params).

    Follows the convention from the Muon reference implementation:

    - hidden 2D+ weights -> Muon;
    - biases, norm scales, embeddings and the unembedding (lm_head) -> AdamW.

    Embeddings and the unembedding are kept on AdamW because their geometry
    (token-indexed rows, sparse gradients) is poorly suited to
    orthogonalisation.

    A parameter counts as an embedding if either:

    - it belongs to an ``nn.Embedding`` / ``nn.EmbeddingBag`` module, whatever
      its name (e.g. GPT-2 style ``wte``); or
    - its qualified name contains ``"embed"`` (covers ``unembed``,
      ``tok_embed``, ``pos_embed``) or ``"lm_head"``.

    Parameters with ``requires_grad=False`` are excluded from both lists.

    Returns
    -------
    (muon_params, adamw_params), each in ``named_parameters()`` order.
    """
    embedding_param_ids = {
        id(p)
        for module in model.modules()
        if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag))
        for p in module.parameters(recurse=False)
    }
    muon_params: list[torch.nn.Parameter] = []
    adamw_params: list[torch.nn.Parameter] = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        is_embedding = (
            id(p) in embedding_param_ids
            or "embed" in name
            or "lm_head" in name
        )
        if p.dim() >= 2 and not is_embedding:
            muon_params.append(p)
        else:
            adamw_params.append(p)
    return muon_params, adamw_params
|
|