"""tilelli.core.ternary — BitNet b1.58 style ternary weights with STE.
Every weight in the model lives in {-α, 0, +α} where α is a per-tensor
scalar chosen by AbsMean rescaling. The forward pass sees the ternarized
version; the backward pass pretends the ternarization is the identity so
gradients flow to a FP32 "shadow" weight. That's the straight-through
estimator (STE).
Why ternary:
- CPU inference: no float multiplies. Matmul collapses to add/subtract/skip.
- Tiny training: ternary weights are ~10x smaller than FP16.
- SDR activations (binary) × ternary weights = pure integer arithmetic
in the forward pass at inference. Zero floating point. Popcount + add.
- Biology agrees: synapses are roughly excitatory / inhibitory / silent.
Recipe (from the BitNet b1.58 paper):
1. alpha = mean(|W|) # AbsMean rescale
2. W_scaled = W / (alpha + eps)
3. W_q = clamp(round(W_scaled), -1, 1) * alpha
4. forward uses W_q, backward uses dW_q/dW = 1 (straight-through)
"""
from __future__ import annotations
import torch
from torch import Tensor, nn
# Lower bound that the scale helpers in this module clamp alpha to, so that
# dividing a weight tensor by alpha never divides by zero.
EPS = 1e-5
def absmean_scale(w: Tensor) -> Tensor:
    """Compute the per-tensor AbsMean scale ``alpha = mean(|w|)``.

    The mean is floored at ``EPS`` before it is returned. Callers divide by
    this value, and an all-zero ``w`` has a raw mean of exactly zero, so the
    floor is what keeps that division finite.

    The result is a 0-d tensor, so it broadcasts against ``w`` directly.
    """
    mean_magnitude = torch.mean(torch.abs(w))
    return torch.clamp(mean_magnitude, min=EPS)
def ternarize(w: Tensor) -> Tensor:
    """Quantize ``w`` to ``{-alpha, 0, +alpha}`` with a straight-through gradient.

    Forward:  the value is ``clamp(round(w / alpha), -1, 1) * alpha`` where
              ``alpha`` is the per-tensor AbsMean scale of ``w``.
    Backward: the quantization is treated as the identity, so the gradient
              with respect to ``w`` is 1.

    The straight-through behaviour comes from adding a detached correction
    to ``w``: the sum evaluates to the quantized value, while autograd only
    sees the ``w`` term.
    """
    scale = absmean_scale(w)
    trits = torch.clamp(torch.round(w / scale), min=-1.0, max=1.0)
    quantized = trits * scale
    # Detached, so neither the rounding nor the scale contributes gradient.
    ste_offset = (quantized - w).detach()
    return w + ste_offset
def ternary_values(w: Tensor) -> Tensor:
    """Return the ternarized weights as a plain tensor, without the STE shim.

    The values are ``clamp(round(w / alpha), -1, 1) * alpha`` with the
    per-tensor AbsMean ``alpha``. Everything runs under ``torch.no_grad()``,
    so the result carries no autograd history. Intended for inspection and
    for exporting weights at inference time.
    """
    with torch.no_grad():
        scale = absmean_scale(w)
        trits = torch.clamp(torch.round(w / scale), min=-1.0, max=1.0)
        return trits * scale
def ternary_signs(w: Tensor) -> Tensor:
    """Return only the ``{-1, 0, +1}`` trits of ``w`` as an ``int8`` tensor.

    The per-tensor AbsMean scale is used to decide each trit but is not
    applied to the result. Three states need 2 bits per weight at minimum;
    ``int8`` is the simple, unpacked storage form and leaves bit-packing to
    a later stage.
    """
    with torch.no_grad():
        scale = absmean_scale(w)
        trits = torch.clamp(torch.round(w / scale), min=-1.0, max=1.0)
        return trits.to(torch.int8)
def absmean_scale_per_row(w: Tensor) -> Tensor:
    """Compute one AbsMean scale per row of ``w``, floored at ``EPS``.

    The first axis is the row axis; all remaining axes are averaged
    together. The result has shape ``(rows, 1, 1, ...)`` with the same
    number of dimensions as ``w`` so it broadcasts against ``w``.

    Raises ``ValueError`` when ``w`` has fewer than two dimensions.
    """
    ndim = w.dim()
    if ndim < 2:
        raise ValueError(f"per-row scale needs dim>=2, got shape {tuple(w.shape)}")
    rows = w.size(0)
    # Collapse everything after the row axis so one mean covers each row.
    per_row = w.reshape(rows, -1)
    row_scale = torch.clamp(per_row.abs().mean(dim=1), min=EPS)
    broadcast_shape = (rows,) + (1,) * (ndim - 1)
    return row_scale.view(broadcast_shape)
def ternarize_per_row(w: Tensor) -> Tensor:
    """Ternarize each row of ``w`` with its own AbsMean scale, using STE.

    Forward:  ``clamp(round(w / alpha_row), -1, 1) * alpha_row``.
    Backward: identity gradient to ``w``, because the quantization
              correction is detached before being added back.
    """
    row_scale = absmean_scale_per_row(w)
    trits = torch.clamp(torch.round(w / row_scale), min=-1.0, max=1.0)
    quantized = trits * row_scale
    ste_offset = (quantized - w).detach()
    return w + ste_offset
def ternary_values_per_row(w: Tensor) -> Tensor:
    """Return per-row ternarized values as a plain tensor (no STE shim).

    Computed under ``torch.no_grad()``, so the result carries no autograd
    history.
    """
    with torch.no_grad():
        row_scale = absmean_scale_per_row(w)
        trits = torch.clamp(torch.round(w / row_scale), min=-1.0, max=1.0)
        return trits * row_scale
class LearnableScale(nn.Module):
    """Hold one learnable scalar scale, read back with a floor of ``EPS``.

    The scalar is stored as an ``nn.Parameter`` on an ``nn.Module``, so it is
    listed by ``.parameters()`` and follows the module through
    ``.to(device)``. Read it through ``value()`` rather than ``alpha``
    directly to get the clamped scalar that is safe to divide by.
    """

    def __init__(self, initial: float = 1.0) -> None:
        """Create the scale parameter.

        Raises ``ValueError`` when ``initial`` is zero or negative.
        """
        super().__init__()
        if initial <= 0:
            raise ValueError(f"initial scale must be > 0, got {initial}")
        start = torch.tensor(float(initial))
        self.alpha = nn.Parameter(start)

    def value(self) -> Tensor:
        """Return ``alpha`` clamped from below at ``EPS``."""
        return torch.clamp(self.alpha, min=EPS)
def ternarize_lsq(w: Tensor, alpha: Tensor) -> Tensor:
    """Ternarize ``w`` with a caller-supplied learnable scale (LSQ, Esser et al.).

    Forward:  ``q_int * alpha`` with ``q_int = clamp(round(w / alpha), -1, 1)``.
    Backward: ``dout/dw = 1`` (straight-through to the shadow weight) and
              ``dout/dalpha = q_int``.

    ``alpha`` is used as given; this function does not clamp it away from
    zero.
    """
    # The trits are constants as far as autograd is concerned.
    with torch.no_grad():
        trits = torch.clamp(torch.round(w / alpha), min=-1.0, max=1.0)
    # Zero in the forward pass, but routes a unit gradient back to w.
    ste_passthrough = w - w.detach()
    return trits * alpha + ste_passthrough
@torch.no_grad()
def deadzone_stats(w: Tensor, band: float = 0.1) -> dict[str, float]:
    """Diagnostic for Tequila-style "deadzone trapping" (arXiv 2509.23809).

    A weight is deadzone-trapped when ``|w/alpha|`` sits within ``band`` of
    a rounding boundary at ±0.5: the round-to-trit operation is on a
    knife-edge there. This function reports how the weights split across
    the three trits and how many sit in that boundary band, so you can
    check whether the effect is present before reaching for a fix.

    ``alpha`` is the per-tensor AbsMean scale of ``w``. The function runs
    under ``torch.no_grad()`` and returns plain Python floats.

    Keys:
        ``alpha``: per-tensor AbsMean scale.
        ``frac_neg / frac_zero / frac_pos``: fraction of weights rounding to
            −1, 0, +1 respectively (they sum to 1 for a non-empty tensor
            without NaNs).
        ``frac_boundary``: fraction with ``||w/alpha| − 0.5| < band`` — the
            deadzone-trap candidates.
        ``frac_zero_inner``: fraction with ``0 < |w/alpha| < 0.5 − band``,
            i.e. well inside the zero bin. Weights that are exactly zero
            are not counted here.

    An empty ``w`` has nothing to classify: every fraction is reported as
    ``0.0`` and ``alpha`` as NaN (the mean of no elements).
    """
    total = w.numel()
    if total == 0:
        # Without this guard the divisions by the element count below would
        # raise ZeroDivisionError.
        return {
            "alpha": float("nan"),
            "frac_neg": 0.0,
            "frac_zero": 0.0,
            "frac_pos": 0.0,
            "frac_boundary": 0.0,
            "frac_zero_inner": 0.0,
        }

    alpha = absmean_scale(w)
    ratio = w / alpha  # computed once and shared by every statistic below
    magnitude = ratio.abs()
    trits = torch.round(ratio).clamp_(-1.0, 1.0)
    n = float(total)

    def _fraction(mask: Tensor) -> float:
        # Share of all weights selected by a boolean mask.
        return float(mask.sum().item()) / n

    return {
        "alpha": float(alpha.item()),
        "frac_neg": _fraction(trits == -1),
        "frac_zero": _fraction(trits == 0),
        "frac_pos": _fraction(trits == 1),
        "frac_boundary": _fraction((magnitude - 0.5).abs() < band),
        # `magnitude > 0` keeps exactly-zero weights out of this count.
        "frac_zero_inner": _fraction((magnitude > 0) & (magnitude < 0.5 - band)),
    }
|