| """Muon optimizer for 2D matrices. |
| |
| Reference: Keller Jordan, "Muon: An optimizer for hidden layers in neural networks" |
| https://kellerjordan.github.io/posts/muon/ |
| |
| Algorithm |
| --------- |
| For each 2D parameter W with gradient G: |
| 1. Maintain momentum buffer M_t = beta * M_{t-1} + G_t |
| 2. Optionally apply Nesterov: G' = G_t + beta * M_t (or just M_t without Nesterov) |
| 3. Orthogonalise G' via 5 iterations of Newton-Schulz with the quintic polynomial |
| coefficients (3.4445, -4.7750, 2.0315): |
| X <- 3.4445 * X - 4.7750 * X X^T X + 2.0315 * (X X^T)^2 X |
| after first dividing X by ||X||_F to bring its singular values into [0, ~1.5]. |
| 4. Apply the orthogonalised update: W <- W - lr * adj_factor * O |
| where adj_factor = max(1, fan_out / fan_in)**0.5 to scale shorter-dim params. |
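
In plain tensor ops, one step looks roughly like this (Nesterov and weight
decay omitted; W, G, M, beta, lr are stand-in names, not symbols defined by
this module):

    M = beta * M + G                              # step 1: momentum
    O = newton_schulz_5(M)                        # step 3: orthogonalise
    adj = max(1.0, W.size(0) / W.size(1)) ** 0.5  # step 4: shape factor
    W -= lr * adj * O                             # step 4: apply update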

This optimiser is intended ONLY for parameters with .dim() >= 2. The recommended
recipe uses AdamW for embeddings and 1D tensors (norms, biases). The wrapper
class `HybridOptimizer` here packages that split.
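
A minimal construction sketch (`model` and the learning-rate / weight-decay
values are placeholders, not defaults mandated by this module):

    opt = HybridOptimizer(
        model.named_parameters(),
        muon_lr=3e-3,
        adamw_lr=3e-4,
        weight_decay=0.1,
    )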

Bit-identical guarantee
-----------------------
When the caller selects optimizer="adamw" in Config, the train script never
constructs Muon -- it builds a single AdamW over all params. The HybridOptimizer
exists only when optimizer="muon"; it is not a sneaky pass-through. This keeps
the two paths cleanly separated.
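
Illustratively, that branch has the following shape (the cfg field names are
assumptions about the train script, not verified here):

    if cfg.optimizer == "muon":
        opt = HybridOptimizer(model.named_parameters(),
                              muon_lr=cfg.muon_lr, adamw_lr=cfg.adamw_lr)
    else:  # "adamw": Muon is never constructed
        opt = torch.optim.AdamW(model.parameters(), lr=cfg.adamw_lr)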
| """ |
| from __future__ import annotations |
|
|
| from typing import Iterable |
|
|
| import torch |
| from torch.optim import Optimizer |
|
|
|
|
| |
| |
| |
| @torch.no_grad() |
def newton_schulz_5(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Quintic Newton-Schulz, 5 iterations. Returns an approximately-orthogonal
    matrix with the same shape as G.

    Operates on the *transposed* shape if rows > cols, so that X X^T stays the
    smaller of the two Gram matrices (the canonical optimisation in the
    reference impl).
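
    A quick check of the postcondition (shapes here are illustrative):

        G = torch.randn(512, 128)
        O = newton_schulz_5(G)
        # Columns are near-orthonormal: O^T O is approximately eye(128), with
        # diagonal entries spread around 1 rather than exactly 1.
        err = (O.transpose(-2, -1) @ O - torch.eye(128)).abs().max()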
| """ |
| assert G.dim() >= 2 |
| a, b, c = 3.4445, -4.7750, 2.0315 |
| X = G.float() |
| if X.size(-2) > X.size(-1): |
| X = X.transpose(-2, -1) |
| transposed = True |
| else: |
| transposed = False |
|
|
| |
| |
| X = X / (X.norm() + eps) |
|
|
| for _ in range(5): |
| A = X @ X.transpose(-2, -1) |
| B = b * A + c * (A @ A) |
| X = a * X + B @ X |
|
|
| if transposed: |
| X = X.transpose(-2, -1) |
| return X.to(G.dtype) |
|
|
|
|
| |
| |
| |
| class Muon(Optimizer): |
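    """Momentum SGD whose per-matrix update is orthogonalised by
    newton_schulz_5 before being applied; see the module docstring for the
    algorithm.

    Minimal sketch (matrix_params is an assumed pre-filtered list of 2D weight
    tensors; in this file that filtering lives in HybridOptimizer):

        opt = Muon(matrix_params, lr=3e-3, momentum=0.95, nesterov=True)
    """
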
    def __init__(
        self,
        params: Iterable[torch.Tensor],
        lr: float = 3e-3,
        momentum: float = 0.95,
        nesterov: bool = True,
        weight_decay: float = 0.0,
    ):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group["params"]:
                assert p.dim() >= 2, (
                    f"Muon expects 2D+ params; got shape {tuple(p.shape)}. "
                    "Wrap embeddings + 1D tensors with AdamW (use HybridOptimizer)."
                )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            # Re-enable grad for the closure: step() itself runs under no_grad.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            beta = group["momentum"]
            nesterov = group["nesterov"]
            wd = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(p)
                buf = state["momentum_buffer"]
                buf.mul_(beta).add_(g)  # M_t = beta * M_{t-1} + G_t
                update = g + beta * buf if nesterov else buf

                # Flatten >2D params (e.g. conv filters) to 2D for the
                # orthogonalisation; the original shape is restored below.
                orig_shape = update.shape
                if update.dim() > 2:
                    update = update.reshape(update.shape[0], -1)

                ortho = newton_schulz_5(update)

                # Shape compensation from the module docstring, e.g. adj = 2.0
                # for a 4096x1024 weight.
                fan_out, fan_in = ortho.shape[-2], ortho.shape[-1]
                adj = max(1.0, fan_out / fan_in) ** 0.5

                if ortho.shape != orig_shape:
                    ortho = ortho.reshape(orig_shape)

                # Decoupled weight decay (p *= 1 - lr*wd), then the Muon step.
                if wd != 0.0:
                    p.add_(p, alpha=-lr * wd)
                p.add_(ortho, alpha=-lr * adj)

        return loss


class HybridOptimizer:
| """Routes 2D+ params to Muon and 1D / embedding params to AdamW. |
| |
| Mimics the torch.optim.Optimizer surface enough for our train loop: |
| .step(), .zero_grad(set_to_none=True), .param_groups (for LR scheduling). |
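
    A rough loop sketch (model, batch, and the learning rates are placeholders,
    not values fixed by this module):

        opt = HybridOptimizer(model.named_parameters(), muon_lr=3e-3, adamw_lr=3e-4)
        loss = model(batch)
        loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)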
| """ |
|
|
| def __init__( |
| self, |
| named_params: Iterable[tuple[str, torch.nn.Parameter]], |
| muon_lr: float, |
| adamw_lr: float, |
| muon_momentum: float = 0.95, |
| adamw_betas: tuple[float, float] = (0.9, 0.95), |
| weight_decay: float = 0.0, |
| ): |
| muon_params = [] |
| adamw_params = [] |
| for name, p in named_params: |
| if not p.requires_grad: |
| continue |
| |
| is_embedding = "tok_emb" in name or "engram.slots" in name |
| if p.dim() >= 2 and not is_embedding: |
| muon_params.append(p) |
| else: |
| adamw_params.append(p) |
|
|
| self.muon = Muon( |
| muon_params, |
| lr=muon_lr, |
| momentum=muon_momentum, |
| nesterov=True, |
| weight_decay=weight_decay, |
| ) |
| self.adamw = torch.optim.AdamW( |
| adamw_params, |
| lr=adamw_lr, |
| betas=adamw_betas, |
| weight_decay=weight_decay, |
| ) |
| self.param_groups = self.muon.param_groups + self.adamw.param_groups |
|
|
    def step(self, closure=None):
        if closure is not None:
            raise NotImplementedError("HybridOptimizer does not support a closure.")
        self.muon.step()
        self.adamw.step()

    def zero_grad(self, set_to_none: bool = True):
        self.muon.zero_grad(set_to_none=set_to_none)
        self.adamw.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        return {"muon": self.muon.state_dict(), "adamw": self.adamw.state_dict()}

    def load_state_dict(self, sd):
        self.muon.load_state_dict(sd["muon"])
        self.adamw.load_state_dict(sd["adamw"])