"""Muon optimizer for 2D matrices. Reference: Keller Jordan, "Muon: An optimizer for hidden layers in neural networks" https://kellerjordan.github.io/posts/muon/ Algorithm --------- For each 2D parameter W with gradient G: 1. Maintain momentum buffer M_t = beta * M_{t-1} + G_t 2. Optionally apply Nesterov: G' = G_t + beta * M_t (or just M_t without Nesterov) 3. Orthogonalise G' via 5 iterations of Newton-Schulz with the quintic polynomial coefficients (3.4445, -4.7750, 2.0315): X <- 3.4445 * X - 4.7750 * X X^T X + 2.0315 * (X X^T)^2 X after first dividing X by ||X||_F to bring its singular values into [0, ~1.5]. 4. Apply the orthogonalised update: W <- W - lr * adj_factor * O where adj_factor = max(1, fan_out / fan_in)**0.5 to scale shorter-dim params. This optimiser is intended ONLY for parameters with .dim() >= 2. The recommended recipe uses AdamW for embeddings and 1D tensors (norms, biases). The wrapper class `HybridOptimizer` here packages that split. Bit-identical guarantee ----------------------- When the caller selects optimizer="adamw" in Config, the train script never constructs Muon -- it builds a single AdamW over all params. The HybridOptimizer exists only when optimizer="muon"; it is not a sneaky pass-through. This keeps the two paths cleanly separated. """ from __future__ import annotations from typing import Iterable import torch from torch.optim import Optimizer # --------------------------------------------------------------------------- # Newton-Schulz orthogonalisation # --------------------------------------------------------------------------- @torch.no_grad() def newton_schulz_5(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: """Quintic Newton-Schulz, 5 iterations. Returns an approximately-orthogonal matrix with the same shape as G. Operates on the *transposed* shape if rows < cols so that XX^T stays the smaller matrix-multiply (canonical optimisation in the reference impl). """ assert G.dim() >= 2 a, b, c = 3.4445, -4.7750, 2.0315 X = G.float() # do all NS math in fp32 even if param is bf16 if X.size(-2) > X.size(-1): X = X.transpose(-2, -1) transposed = True else: transposed = False # Normalise so ||X||_op <= ~1.5. Frobenius norm is an upper bound on the # spectral norm; dividing by it is safe and the standard choice. X = X / (X.norm() + eps) for _ in range(5): A = X @ X.transpose(-2, -1) B = b * A + c * (A @ A) X = a * X + B @ X if transposed: X = X.transpose(-2, -1) return X.to(G.dtype) # --------------------------------------------------------------------------- # Muon # --------------------------------------------------------------------------- class Muon(Optimizer): def __init__( self, params: Iterable[torch.Tensor], lr: float = 3e-3, momentum: float = 0.95, nesterov: bool = True, weight_decay: float = 0.0, ): defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay) super().__init__(params, defaults) for group in self.param_groups: for p in group["params"]: assert p.dim() >= 2, ( f"Muon expects 2D+ params; got shape {tuple(p.shape)}. " "Wrap embeddings + 1D tensors with AdamW (use HybridOptimizer)." 


# ---------------------------------------------------------------------------
# Muon
# ---------------------------------------------------------------------------

class Muon(Optimizer):
    """Momentum + Newton-Schulz orthogonalised updates for 2D+ parameters."""

    def __init__(
        self,
        params: Iterable[torch.Tensor],
        lr: float = 3e-3,
        momentum: float = 0.95,
        nesterov: bool = True,
        weight_decay: float = 0.0,
    ):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group["params"]:
                assert p.dim() >= 2, (
                    f"Muon expects 2D+ params; got shape {tuple(p.shape)}. "
                    "Wrap embeddings + 1D tensors with AdamW (use HybridOptimizer)."
                )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            # step() runs under no_grad, so re-enable autograd for the closure.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            beta = group["momentum"]
            nesterov = group["nesterov"]
            wd = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]

                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(p)
                buf = state["momentum_buffer"]
                buf.mul_(beta).add_(g)
                update = (g + beta * buf) if nesterov else buf

                # Reshape ND tensors (e.g. conv kernels) into 2D for orthogonalisation.
                # Embeddings are excluded by construction; here we expect Linear weights
                # which are already 2D, but keep the reshape for safety.
                orig_shape = update.shape
                if update.dim() > 2:
                    update = update.reshape(update.shape[0], -1)

                ortho = newton_schulz_5(update)

                # Scale by sqrt(max(1, fan_out/fan_in)) so updates have sane magnitude
                # across rectangular shapes. fan_out = rows, fan_in = cols.
                fan_out, fan_in = ortho.shape[-2], ortho.shape[-1]
                adj = max(1.0, fan_out / fan_in) ** 0.5

                if ortho.shape != orig_shape:
                    ortho = ortho.reshape(orig_shape)

                if wd != 0.0:
                    # Decoupled weight decay: p <- p * (1 - lr * wd).
                    p.add_(p, alpha=-lr * wd)
                p.add_(ortho, alpha=-lr * adj)

        return loss


# ---------------------------------------------------------------------------
# Hybrid Muon + AdamW wrapper
# ---------------------------------------------------------------------------

class HybridOptimizer:
    """Routes 2D+ params to Muon and 1D / embedding params to AdamW.

    Mimics the torch.optim.Optimizer surface enough for our train loop:
    .step(), .zero_grad(set_to_none=True), .param_groups (for LR scheduling).
    """

    def __init__(
        self,
        named_params: Iterable[tuple[str, torch.nn.Parameter]],
        muon_lr: float,
        adamw_lr: float,
        muon_momentum: float = 0.95,
        adamw_betas: tuple[float, float] = (0.9, 0.95),
        weight_decay: float = 0.0,
    ):
        muon_params = []
        adamw_params = []
        for name, p in named_params:
            if not p.requires_grad:
                continue
            # Embeddings have dim() == 2 but should still go to AdamW per the recipe.
            is_embedding = "tok_emb" in name or "engram.slots" in name
            if p.dim() >= 2 and not is_embedding:
                muon_params.append(p)
            else:
                adamw_params.append(p)

        self.muon = Muon(
            muon_params,
            lr=muon_lr,
            momentum=muon_momentum,
            nesterov=True,
            weight_decay=weight_decay,
        )
        self.adamw = torch.optim.AdamW(
            adamw_params,
            lr=adamw_lr,
            betas=adamw_betas,
            weight_decay=weight_decay,
        )
        self.param_groups = self.muon.param_groups + self.adamw.param_groups

    def step(self, closure=None):
        if closure is not None:
            raise NotImplementedError("HybridOptimizer does not support a closure.")
        self.muon.step()
        self.adamw.step()

    def zero_grad(self, set_to_none: bool = True):
        self.muon.zero_grad(set_to_none=set_to_none)
        self.adamw.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        return {"muon": self.muon.state_dict(), "adamw": self.adamw.state_dict()}

    def load_state_dict(self, sd):
        self.muon.load_state_dict(sd["muon"])
        self.adamw.load_state_dict(sd["adamw"])
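

# ---------------------------------------------------------------------------
# Illustrative smoke test (my sketch, not part of the training recipe). The toy
# model, every parameter name other than "tok_emb", and the learning rates below
# are placeholders chosen only to exercise the Muon/AdamW routing; run this file
# directly to try it.
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import torch.nn as nn

    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.tok_emb = nn.Embedding(100, 32)  # 2D, but routed to AdamW by name
            self.fc = nn.Linear(32, 32)           # weight -> Muon, bias -> AdamW

        def forward(self, idx):
            return self.fc(self.tok_emb(idx))

    model = _ToyModel()
    opt = HybridOptimizer(model.named_parameters(), muon_lr=3e-3, adamw_lr=1e-3)

    loss = model(torch.randint(0, 100, (4, 8))).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)

    muon_n = sum(p.numel() for g in opt.muon.param_groups for p in g["params"])
    adamw_n = sum(p.numel() for g in opt.adamw.param_groups for p in g["params"])
    print(f"Muon params: {muon_n}, AdamW params: {adamw_n}, loss: {loss.item():.4f}")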