| """tilelli.core.ternary — BitNet b1.58 style ternary weights with STE. |
| |
| Every weight in the model lives in {-α, 0, +α} where α is a per-tensor |
| scalar chosen by AbsMean rescaling. The forward pass sees the ternarized |
| version; the backward pass pretends the ternarization is the identity so |
| gradients flow to a FP32 "shadow" weight. That's the straight-through |
| estimator (STE). |
| |
| Why ternary: |
| - CPU inference: no float multiplies. Matmul collapses to add/subtract/skip. |
| - Tiny training: ternary weights are ~10x smaller than FP16. |
| - SDR activations (binary) × ternary weights = pure integer arithmetic |
| in the forward pass at inference. Zero floating point. Popcount + add. |
| - Biology agrees: synapses are roughly excitatory / inhibitory / silent. |
| |
| Recipe (from the BitNet b1.58 paper): |
| 1. alpha = mean(|W|) # AbsMean rescale |
| 2. W_scaled = W / (alpha + eps) |
| 3. W_q = clamp(round(W_scaled), -1, 1) * alpha |
| 4. forward uses W_q, backward uses dW_q/dW = 1 (straight-through) |
| """ |
| from __future__ import annotations |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
|
|
# Lower bound applied to every scale (alpha) before it is used as a divisor,
# so an all-zero weight tensor cannot trigger a division by zero.
EPS: float = 1e-5
|
|
|
|
def absmean_scale(w: Tensor) -> Tensor:
    """Compute the per-tensor AbsMean scale alpha = mean(|w|), floored at EPS.

    The mean is taken over every element of ``w`` and the result is a 0-d
    tensor, so it broadcasts against ``w`` directly.

    The floor matters: for an all-zero ``w`` the raw mean is 0, and dividing
    by it would poison every downstream value. Clamping to ``EPS`` keeps the
    divisor strictly positive.
    """
    mean_magnitude = torch.mean(torch.abs(w))
    return torch.clamp(mean_magnitude, min=EPS)
|
|
|
|
def ternarize(w: Tensor) -> Tensor:
    """Ternarize w to values in {-alpha, 0, +alpha} with a straight-through
    gradient.

    Forward:  returns round(w / alpha).clamp(-1, 1) * alpha, with alpha the
              per-tensor scale from ``absmean_scale(w)``.
    Backward: d(ternarize(w))/dw = 1  (identity — the STE trick)

    The quantized tensor ``w_q`` is computed under ``torch.no_grad()``, so
    neither alpha nor the rounding contributes a gradient and no autograd
    graph is built for them.

    The identity gradient comes from adding ``w - w.detach()``. For finite
    ``w`` that term is exactly zero in the forward pass, so the output is
    bit-for-bit ``w_q`` (the same values ``ternary_values(w)`` produces),
    while autograd sees a path to ``w`` with gradient 1. The alternative
    ``w + (w_q - w).detach()`` idiom re-adds ``w`` in floating point and can
    drift from ``w_q`` by a rounding error, most visibly for clamped outliers
    and in reduced precision.
    """
    with torch.no_grad():
        alpha = absmean_scale(w)
        w_q = torch.round(w / alpha).clamp_(-1.0, 1.0) * alpha
    # Forward value: w_q + 0. Backward: gradient 1 flows to w through (w - w.detach()).
    return w_q + (w - w.detach())
|
|
|
|
def ternary_values(w: Tensor) -> Tensor:
    """Ternarize ``w`` without any straight-through gradient shim.

    Computes round(w / alpha).clamp(-1, 1) * alpha under ``torch.no_grad()``,
    with alpha taken from ``absmean_scale(w)``. The result carries no autograd
    history, which makes it suitable for inspection and for exporting weights
    for inference.
    """
    with torch.no_grad():
        scale = absmean_scale(w)
        trits = torch.clamp(torch.round(w / scale), min=-1.0, max=1.0)
        return trits * scale
|
|
|
|
def ternary_signs(w: Tensor) -> Tensor:
    """Return only the trits in {-1, 0, +1} as an int8 tensor (scale dropped).

    The trits are round(w / alpha).clamp(-1, 1) with alpha from
    ``absmean_scale(w)``, computed under ``torch.no_grad()``.

    Three states need only 2 bits per weight in principle; int8 is used here
    as the simple, obviously-correct storage form, leaving bit-packing for
    later.
    """
    with torch.no_grad():
        scale = absmean_scale(w)
        trits = torch.clamp(torch.round(w / scale), min=-1.0, max=1.0)
        return trits.to(dtype=torch.int8)
|
|
|
|
def absmean_scale_per_row(w: Tensor) -> Tensor:
    """Compute one AbsMean scale per row of ``w``, each floored at EPS.

    Axis 0 is treated as the row axis; all remaining axes are flattened
    together and averaged in absolute value. The result has shape
    (rows, 1, 1, ...) with the same number of dimensions as ``w``, so it
    broadcasts against ``w``.

    Raises ValueError when ``w`` has fewer than 2 dimensions.
    """
    ndim = w.dim()
    if ndim < 2:
        raise ValueError(f"per-row scale needs dim>=2, got shape {tuple(w.shape)}")
    n_rows = w.size(0)
    # Collapse everything except the row axis, then average |w| within each row.
    row_means = torch.mean(torch.abs(w.reshape(n_rows, -1)), dim=1)
    broadcast_shape = [n_rows] + [1] * (ndim - 1)
    return torch.clamp(row_means, min=EPS).view(broadcast_shape)
|
|
|
|
def ternarize_per_row(w: Tensor) -> Tensor:
    """Per-row ternary STE: each row of w ternarised with its own alpha.

    Forward:  round(w / alpha).clamp(-1, 1) * alpha, with alpha the per-row
              scale from ``absmean_scale_per_row(w)``.
    Backward: d(out)/dw = 1  (straight-through identity).

    The quantized tensor is computed under ``torch.no_grad()`` so alpha and
    the rounding contribute no gradient. The identity gradient comes from
    adding ``w - w.detach()``, which for finite ``w`` is exactly zero in the
    forward pass: the output is bit-for-bit the quantized tensor (the same
    values ``ternary_values_per_row(w)`` produces) rather than a value that
    has been round-tripped through ``w + (w_q - w)`` in floating point.
    """
    with torch.no_grad():
        alpha = absmean_scale_per_row(w)
        w_q = torch.round(w / alpha).clamp_(-1.0, 1.0) * alpha
    # Forward value: w_q + 0. Backward: gradient 1 flows to w through (w - w.detach()).
    return w_q + (w - w.detach())
|
|
|
|
def ternary_values_per_row(w: Tensor) -> Tensor:
    """Per-row ternarised values with no autograd history (no STE shim).

    Each row is scaled by its own alpha from ``absmean_scale_per_row(w)``,
    rounded, clamped to [-1, 1] and multiplied back by that alpha, all under
    ``torch.no_grad()``.
    """
    with torch.no_grad():
        row_scale = absmean_scale_per_row(w)
        trits = torch.clamp(torch.round(w / row_scale), min=-1.0, max=1.0)
        return trits * row_scale
|
|
|
|
class LearnableScale(nn.Module):
    """Holder for one learnable scalar scale, read back floored at EPS.

    The scalar is stored as an ``nn.Parameter`` on an ``nn.Module`` so that
    it is returned by ``.parameters()`` and follows the module through
    ``.to(device)``. Read it through ``value()``, which applies the EPS
    floor so the result is safe to divide by.
    """

    def __init__(self, initial: float = 1.0) -> None:
        """Create the scale parameter.

        Raises ValueError if ``initial`` is zero or negative.
        """
        super().__init__()
        if initial <= 0:
            raise ValueError(f"initial scale must be > 0, got {initial}")
        start = torch.tensor(float(initial))
        self.alpha = nn.Parameter(start)

    def value(self) -> Tensor:
        """Return the scale clamped to be at least EPS."""
        return torch.clamp(self.alpha, min=EPS)
|
|
|
|
def ternarize_lsq(w: Tensor, alpha: Tensor) -> Tensor:
    """Ternarize ``w`` with a caller-supplied learnable scale (LSQ-style STE).

    Forward:  trits * alpha, where trits = round(w / alpha).clamp(-1, 1)
    Backward: d(out)/dw     = 1      (straight-through identity to w)
              d(out)/dalpha = trits  (trits are treated as constants)

    ``alpha`` is used as a divisor as given; no floor is applied here.
    """
    # Detached: the rounding itself contributes no gradient to w or alpha.
    trits = torch.clamp(torch.round(w / alpha), min=-1.0, max=1.0).detach()
    # Zero in the forward pass; gives autograd a gradient-1 path back to w.
    identity_grad = w - w.detach()
    return trits * alpha + identity_grad
|
|
|
|
@torch.no_grad()
def deadzone_stats(w: Tensor, band: float = 0.1) -> dict[str, float]:
    """Diagnostic for Tequila-style "deadzone trapping" (arXiv 2509.23809).

    A weight is deadzone-trapped when ``|w/alpha|`` sits within ``band`` of
    a rounding boundary at ±0.5: the round-to-trit operation is on a knife-
    edge, and STE noise dominates the true gradient signal. Tequila's
    finding is that a non-trivial fraction of weights live there permanently
    after long training, contributing only noise.

    Returns the breakdown of the trit assignment plus the boundary-band
    occupancy. Use this to verify Tequila applies before considering the
    deadzone-bias fix.

    Keys:
        ``alpha``: per-tensor AbsMean scale (from ``absmean_scale(w)``).
        ``frac_neg / frac_zero / frac_pos``: fraction of weights rounding to
            -1, 0, +1 respectively (sums to 1 for a non-empty tensor).
        ``frac_boundary``: fraction with ``||w/alpha| - 0.5| < band`` — the
            deadzone-trap candidates. High values (>5–10%) suggest Tequila's
            bias-repurposing fix could matter.
        ``frac_zero_inner``: fraction with ``|w/alpha| < 0.5 - band``, i.e.
            deeply zero (stable, not on the boundary). Weights that are
            exactly zero are included.

    For an empty tensor every fraction is reported as 0.0 instead of raising
    ZeroDivisionError.
    """
    alpha = absmean_scale(w)
    scaled = w / alpha  # computed once and shared by every statistic below
    r = scaled.abs()
    rounded = torch.round(scaled).clamp_(-1.0, 1.0)
    # Guard the denominator: with zero elements every count is 0 anyway.
    n = float(max(w.numel(), 1))

    def frac(mask: Tensor) -> float:
        # Fraction of all weights for which the boolean mask is True.
        return float(mask.sum().item()) / n

    return {
        "alpha": float(alpha.item()),
        "frac_neg": frac(rounded == -1),
        "frac_zero": frac(rounded == 0),
        "frac_pos": frac(rounded == 1),
        "frac_boundary": frac((r - 0.5).abs() < band),
        # No sign filter here: exact zeros are the most deeply-zero weights
        # and must be counted, as the documented definition states.
        "frac_zero_inner": frac(r < 0.5 - band),
    }
|
|