"""tilelli.core.ternary — BitNet b1.58 style ternary weights with STE. Every weight in the model lives in {-α, 0, +α} where α is a per-tensor scalar chosen by AbsMean rescaling. The forward pass sees the ternarized version; the backward pass pretends the ternarization is the identity so gradients flow to a FP32 "shadow" weight. That's the straight-through estimator (STE). Why ternary: - CPU inference: no float multiplies. Matmul collapses to add/subtract/skip. - Tiny training: ternary weights are ~10x smaller than FP16. - SDR activations (binary) × ternary weights = pure integer arithmetic in the forward pass at inference. Zero floating point. Popcount + add. - Biology agrees: synapses are roughly excitatory / inhibitory / silent. Recipe (from the BitNet b1.58 paper): 1. alpha = mean(|W|) # AbsMean rescale 2. W_scaled = W / (alpha + eps) 3. W_q = clamp(round(W_scaled), -1, 1) * alpha 4. forward uses W_q, backward uses dW_q/dW = 1 (straight-through) """ from __future__ import annotations import torch from torch import Tensor, nn EPS = 1e-5 def absmean_scale(w: Tensor) -> Tensor: """The per-tensor scalar alpha = mean(|W|), clamped away from zero. Returns a 0-d tensor so it broadcasts against w without allocating. The clamp is load-bearing: an all-zero tensor would otherwise produce a division by zero and kill training in one step. """ return w.abs().mean().clamp(min=EPS) def ternarize(w: Tensor) -> Tensor: """Ternarize w to values in {-alpha, 0, +alpha} with a straight-through gradient. Forward: returns round(w / alpha).clamp(-1, 1) * alpha Backward: d(ternarize(w))/dw = 1 (identity — the STE trick) The identity gradient is implemented with the classic ``w + (w_q - w).detach()`` idiom: numerically equal to w_q in the forward pass, but its autograd graph points at w with gradient 1. """ alpha = absmean_scale(w) w_scaled = w / alpha w_q = torch.round(w_scaled).clamp_(-1.0, 1.0) * alpha return w + (w_q - w).detach() def ternary_values(w: Tensor) -> Tensor: """Return the ternarized tensor as a plain (non-STE) tensor. Useful for inspection and inference-time weight export. This is what the CPU inference path will actually store and consume. """ with torch.no_grad(): alpha = absmean_scale(w) return torch.round(w / alpha).clamp_(-1.0, 1.0) * alpha def ternary_signs(w: Tensor) -> Tensor: """Return just the {-1, 0, +1} trits (int8), without the scale. Storage form: 2 bits per weight is the theoretical minimum for three states. We return int8 here for day-0 correctness; bit-pack later once the rest of the stack is working. """ with torch.no_grad(): alpha = absmean_scale(w) return torch.round(w / alpha).clamp_(-1.0, 1.0).to(torch.int8) def absmean_scale_per_row(w: Tensor) -> Tensor: """Per-row alpha: one mean(|.|) per output row, clamped away from zero. First axis is the row axis. Returns shape (rows, 1, 1, ...) so it broadcasts against w. """ if w.dim() < 2: raise ValueError(f"per-row scale needs dim>=2, got shape {tuple(w.shape)}") flat = w.reshape(w.size(0), -1) alpha = flat.abs().mean(dim=1).clamp(min=EPS) view = (w.size(0),) + (1,) * (w.dim() - 1) return alpha.view(view) def ternarize_per_row(w: Tensor) -> Tensor: """Per-row ternary STE: each row of w ternarised with its own alpha.""" alpha = absmean_scale_per_row(w) w_q = torch.round(w / alpha).clamp_(-1.0, 1.0) * alpha return w + (w_q - w).detach() def ternary_values_per_row(w: Tensor) -> Tensor: """Detached per-row ternarised values (no STE shim).""" with torch.no_grad(): alpha = absmean_scale_per_row(w) return torch.round(w / alpha).clamp_(-1.0, 1.0) * alpha class LearnableScale(nn.Module): """A single learnable FP32 scalar, clamped at EPS to avoid div-by-zero. Wraps the scalar in nn.Module so it (a) shows up in .parameters(), (b) moves with .to(device). Use .value() to read the (clamped) scalar. """ def __init__(self, initial: float = 1.0) -> None: super().__init__() if initial <= 0: raise ValueError(f"initial scale must be > 0, got {initial}") self.alpha = nn.Parameter(torch.tensor(float(initial))) def value(self) -> Tensor: return self.alpha.clamp(min=EPS) def ternarize_lsq(w: Tensor, alpha: Tensor) -> Tensor: """STE ternarize using a learnable alpha (Esser et al., LSQ). Forward: q_int * alpha where q_int = round(w/alpha).clamp(-1, 1) Backward: dout/dw = 1 (STE — identity gradient to w shadow) dout/dalpha = q_int """ q_int = torch.round(w / alpha).clamp_(-1.0, 1.0).detach() return q_int * alpha + (w - w.detach()) @torch.no_grad() def deadzone_stats(w: Tensor, band: float = 0.1) -> dict[str, float]: """Diagnostic for Tequila-style "deadzone trapping" (arXiv 2509.23809). A weight is deadzone-trapped when ``|w/alpha|`` sits within ``band`` of a rounding boundary at ±0.5: the round-to-trit operation is on a knife- edge, and STE noise dominates the true gradient signal. Tequila's finding is that a non-trivial fraction of weights live there permanently after long training, contributing only noise. Returns the breakdown of the trit assignment plus the boundary-band occupancy. Use this to verify Tequila applies before considering the deadzone-bias fix. Keys: ``alpha``: per-tensor AbsMean scale. ``frac_neg / frac_zero / frac_pos``: fraction of weights rounding to −1, 0, +1 respectively (sums to 1). ``frac_boundary``: fraction with ``||w/alpha| − 0.5| < band`` — the deadzone-trap candidates. High values (>5–10%) suggest Tequila's bias-repurposing fix could matter. ``frac_zero_inner``: fraction with ``|w/alpha| < 0.5 − band``, i.e. deeply zero (stable, not on the boundary). """ alpha = absmean_scale(w) r = (w / alpha).abs() sgn = torch.sign(w / alpha) rounded = torch.round(w / alpha).clamp_(-1.0, 1.0) n = float(w.numel()) return { "alpha": float(alpha.item()), "frac_neg": float((rounded == -1).sum().item()) / n, "frac_zero": float((rounded == 0).sum().item()) / n, "frac_pos": float((rounded == 1).sum().item()) / n, "frac_boundary": float(((r - 0.5).abs() < band).sum().item()) / n, "frac_zero_inner": float(((sgn != 0) & (r < 0.5 - band)).sum().item()) / n, }