Tilelli-llm / src /tilelli /core /ternary.py
TilelliLab's picture
Mirror small files (code, paper, results)
f86dc09 verified
Raw
History Blame Contribute Delete
6.62 kB
"""tilelli.core.ternary — BitNet b1.58 style ternary weights with STE.
Every weight in the model lives in {-α, 0, +α} where α is a per-tensor
scalar chosen by AbsMean rescaling. The forward pass sees the ternarized
version; the backward pass pretends the ternarization is the identity so
gradients flow to a FP32 "shadow" weight. That's the straight-through
estimator (STE).
Why ternary:
- CPU inference: no float multiplies. Matmul collapses to add/subtract/skip.
- Tiny training: ternary weights are ~10x smaller than FP16.
- SDR activations (binary) × ternary weights = pure integer arithmetic
in the forward pass at inference. Zero floating point. Popcount + add.
- Biology agrees: synapses are roughly excitatory / inhibitory / silent.
Recipe (from the BitNet b1.58 paper):
1. alpha = mean(|W|) # AbsMean rescale
2. W_scaled = W / (alpha + eps) # this file floors alpha at EPS instead of adding eps
3. W_q = clamp(round(W_scaled), -1, 1) * alpha
4. forward uses W_q, backward uses dW_q/dW = 1 (straight-through)
"""
from __future__ import annotations
import torch
from torch import Tensor, nn
EPS = 1e-5
def absmean_scale(w: Tensor) -> Tensor:
    """Compute the per-tensor AbsMean scale ``max(mean(|w|), EPS)``.

    The result is a 0-d tensor, so it broadcasts against ``w`` directly.

    The floor at ``EPS`` is essential: callers divide by this value, and an
    all-zero ``w`` has a raw mean of exactly 0.
    """
    mean_magnitude = torch.mean(torch.abs(w))
    return torch.clamp(mean_magnitude, min=EPS)
def ternarize(w: Tensor) -> Tensor:
    """Quantize ``w`` to ``{-alpha, 0, +alpha}`` with a straight-through gradient.

    Forward value:  ``round(w / alpha).clamp(-1, 1) * alpha`` where ``alpha``
    is the per-tensor AbsMean scale of ``w``.
    Backward:       the gradient w.r.t. ``w`` is the identity (STE).

    The identity gradient comes from returning ``w + (w_q - w).detach()``:
    the detached residual contributes the quantized value in the forward
    pass but nothing to the autograd graph, so the only path back is
    through the leading ``w`` term with gradient 1.
    """
    scale = absmean_scale(w)
    trits = torch.clamp(torch.round(w / scale), -1.0, 1.0)
    quantized = trits * scale
    # Cut the residual out of the graph; only the bare `w` term carries grad.
    residual = (quantized - w).detach()
    return w + residual
def ternary_values(w: Tensor) -> Tensor:
    """Return ``w`` ternarized to ``{-alpha, 0, +alpha}`` with no STE shim.

    Computed entirely under ``torch.no_grad()``, so the result carries no
    autograd history. Intended for inspection and weight export, where
    only the quantized values themselves are needed.
    """
    with torch.no_grad():
        scale = absmean_scale(w)
        trits = torch.clamp(torch.round(w / scale), -1.0, 1.0)
        return trits * scale
def ternary_signs(w: Tensor) -> Tensor:
    """Return only the ``{-1, 0, +1}`` trits of ``w`` as int8, without the scale.

    The trits are ``round(w / alpha)`` clamped to ``[-1, 1]``, with ``alpha``
    the per-tensor AbsMean scale. One int8 per weight is deliberately
    simple; a denser bit-packed encoding is left for later.
    """
    with torch.no_grad():
        normalized = w / absmean_scale(w)
        trits = torch.clamp(torch.round(normalized), -1.0, 1.0)
        return trits.to(torch.int8)
def absmean_scale_per_row(w: Tensor) -> Tensor:
    """Compute one AbsMean scale per row of ``w``, each floored at ``EPS``.

    The first axis of ``w`` is the row axis; all remaining axes are
    averaged together. The result has shape ``(rows, 1, 1, ...)`` with the
    same number of dimensions as ``w``, so it broadcasts against ``w``.

    Raises:
        ValueError: if ``w`` has fewer than two dimensions.
    """
    ndim = w.dim()
    if ndim < 2:
        raise ValueError(f"per-row scale needs dim>=2, got shape {tuple(w.shape)}")
    rows = w.size(0)
    # Collapse everything after the row axis so one mean covers the whole row.
    per_row = w.reshape(rows, -1).abs().mean(dim=1).clamp(min=EPS)
    broadcast_shape = (rows,) + (1,) * (ndim - 1)
    return per_row.view(broadcast_shape)
def ternarize_per_row(w: Tensor) -> Tensor:
    """Ternarize each row of ``w`` with its own AbsMean scale, using an STE.

    Forward value is ``round(w / alpha).clamp(-1, 1) * alpha`` with a
    per-row ``alpha``; the gradient w.r.t. ``w`` is the identity because
    the quantization residual is detached from the graph.
    """
    row_scale = absmean_scale_per_row(w)
    trits = torch.clamp(torch.round(w / row_scale), -1.0, 1.0)
    quantized = trits * row_scale
    residual = (quantized - w).detach()
    return w + residual
def ternary_values_per_row(w: Tensor) -> Tensor:
    """Return per-row ternarized values of ``w`` with no autograd history.

    Same forward value as the per-row STE variant, but computed under
    ``torch.no_grad()`` and without the straight-through shim.
    """
    with torch.no_grad():
        row_scale = absmean_scale_per_row(w)
        trits = torch.clamp(torch.round(w / row_scale), -1.0, 1.0)
        return trits * row_scale
class LearnableScale(nn.Module):
    """A single learnable scalar scale, read back clamped at ``EPS``.

    The scalar is held as an ``nn.Parameter`` inside an ``nn.Module`` so
    that it (a) shows up in ``.parameters()`` and (b) moves with
    ``.to(device)``. Read it through :meth:`value`, which applies the
    ``EPS`` floor so callers can safely divide by it.

    Args:
        initial: starting value of the scale; must be finite and > 0.

    Raises:
        ValueError: if ``initial`` is zero, negative, NaN or infinite.
    """

    def __init__(self, initial: float = 1.0) -> None:
        super().__init__()
        # Written as a positive range test rather than `initial <= 0`:
        # every comparison with NaN is False, so `<= 0` would let NaN
        # through. An infinite scale is rejected as well, since
        # `trit * inf` turns zero trits into NaN.
        if not (0 < initial < float("inf")):
            raise ValueError(f"initial scale must be finite and > 0, got {initial}")
        self.alpha = nn.Parameter(torch.tensor(float(initial)))

    def value(self) -> Tensor:
        """Return the scale clamped to at least ``EPS``."""
        return self.alpha.clamp(min=EPS)
def ternarize_lsq(w: Tensor, alpha: Tensor) -> Tensor:
    """Ternarize ``w`` with a caller-supplied (learnable) scale ``alpha``.

    Forward:  ``q_int * alpha`` with ``q_int = round(w / alpha).clamp(-1, 1)``.
    Backward: gradient w.r.t. ``w`` is the identity (STE);
              gradient w.r.t. ``alpha`` is ``q_int``.

    ``alpha`` is used as a divisor as-is; no zero guard is applied here.
    """
    # The trits are constants as far as autograd is concerned, so alpha
    # only receives gradient through the explicit product below.
    trits = torch.clamp(torch.round(w / alpha), -1.0, 1.0).detach()
    # Zero in the forward pass, but carries an identity gradient back to w.
    ste_shim = w - w.detach()
    return trits * alpha + ste_shim
@torch.no_grad()
def deadzone_stats(w: Tensor, band: float = 0.1) -> dict[str, float]:
    """Diagnostic for Tequila-style "deadzone trapping" (arXiv 2509.23809).

    A weight is deadzone-trapped when ``|w/alpha|`` sits within ``band`` of
    a rounding boundary at ±0.5: the round-to-trit operation is on a knife-
    edge, and STE noise dominates the true gradient signal. Tequila's
    finding is that a non-trivial fraction of weights live there permanently
    after long training, contributing only noise.

    Returns the breakdown of the trit assignment plus the boundary-band
    occupancy. Use this to verify Tequila applies before considering the
    deadzone-bias fix.

    Args:
        w: weight tensor to analyse; must be non-empty, because every
            fraction is divided by ``w.numel()``.
        band: half-width of the band around ``|w/alpha| = 0.5`` that counts
            as "on the boundary".

    Keys:
        ``alpha``: per-tensor AbsMean scale.
        ``frac_neg / frac_zero / frac_pos``: fraction of weights rounding to
            −1, 0, +1 respectively (sums to 1).
        ``frac_boundary``: fraction with ``||w/alpha| − 0.5| < band`` — the
            deadzone-trap candidates. High values (>5–10%) suggest Tequila's
            bias-repurposing fix could matter.
        ``frac_zero_inner``: fraction with ``|w/alpha| < 0.5 − band``, i.e.
            deeply zero (stable, not on the boundary). Weights that are
            exactly 0 satisfy this and are counted.
    """
    alpha = absmean_scale(w)
    # Normalise once; the magnitude and the rounded trits both derive from it.
    scaled = w / alpha
    r = scaled.abs()
    rounded = torch.round(scaled).clamp_(-1.0, 1.0)
    n = float(w.numel())

    def _fraction(mask: Tensor) -> float:
        # Share of all weights selected by a boolean mask.
        return float(mask.sum().item()) / n

    return {
        "alpha": float(alpha.item()),
        "frac_neg": _fraction(rounded == -1),
        "frac_zero": _fraction(rounded == 0),
        "frac_pos": _fraction(rounded == 1),
        "frac_boundary": _fraction((r - 0.5).abs() < band),
        # No sign test here: an exactly-zero weight is as far inside the
        # zero bin as a weight can be, so it must be counted.
        "frac_zero_inner": _fraction(r < 0.5 - band),
    }