File size: 6,623 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""tilelli.core.ternary — BitNet b1.58 style ternary weights with STE.

Every weight in the model lives in {-α, 0, +α} where α is a per-tensor
scalar chosen by AbsMean rescaling. The forward pass sees the ternarized
version; the backward pass pretends the ternarization is the identity so
gradients flow to a FP32 "shadow" weight. That's the straight-through
estimator (STE).

Why ternary:
  - CPU inference: no float multiplies. Matmul collapses to add/subtract/skip.
  - Tiny training: ternary weights are ~10x smaller than FP16.
  - SDR activations (binary) × ternary weights = pure integer arithmetic
    in the forward pass at inference. Zero floating point. Popcount + add.
  - Biology agrees: synapses are roughly excitatory / inhibitory / silent.

Recipe (from the BitNet b1.58 paper):
  1. alpha = mean(|W|)           # AbsMean rescale
  2. W_scaled = W / (alpha + eps)   # this module clamps alpha >= EPS instead
  3. W_q = clamp(round(W_scaled), -1, 1) * alpha
  4. forward uses W_q, backward uses dW_q/dW = 1  (straight-through)
"""
from __future__ import annotations

import torch
from torch import Tensor, nn


# Floor used when clamping scales (alpha) in this module, so that dividing a
# weight tensor by its scale never divides by zero.
EPS = 1e-5


def absmean_scale(w: Tensor) -> Tensor:
    """Compute the AbsMean scale of ``w``: the mean of ``|w|`` over all elements.

    The mean is floored at ``EPS`` so callers can safely divide by the
    result even when ``w`` is entirely zero. Because the reduction is taken
    over every element, the result is a 0-d tensor that broadcasts against
    ``w``.
    """
    magnitude = torch.abs(w)
    mean_magnitude = magnitude.mean()
    # Without this floor an all-zero tensor would give alpha == 0 and the
    # subsequent w / alpha would blow up.
    return torch.clamp(mean_magnitude, min=EPS)


def ternarize(w: Tensor) -> Tensor:
    """Quantize ``w`` to {-alpha, 0, +alpha} with a straight-through gradient.

    ``alpha`` is the per-tensor AbsMean scale of ``w``.

    Forward:  the value is ``round(w / alpha).clamp(-1, 1) * alpha``.
    Backward: the gradient with respect to ``w`` is the identity (STE).

    The identity gradient comes from adding a detached residual to ``w``:
    the sum evaluates to the quantized tensor (up to floating-point
    rounding), while autograd only sees the ``w`` term.
    """
    scale = absmean_scale(w)
    trits = (w / scale).round().clamp(-1.0, 1.0)
    quantized = trits * scale
    # Detaching the residual hides the rounding from autograd, so the only
    # differentiable path from the output back to w has gradient 1.
    residual = (quantized - w).detach()
    return w + residual


def ternary_values(w: Tensor) -> Tensor:
    """Return the ternarized weights as a plain tensor, with no STE shim.

    The result is ``round(w / alpha).clamp(-1, 1) * alpha`` using the
    per-tensor AbsMean ``alpha``, computed with autograd disabled. Intended
    for inspection and for exporting weights for inference.
    """
    with torch.no_grad():
        scale = absmean_scale(w)
        trits = torch.clamp(torch.round(w / scale), -1.0, 1.0)
        return trits * scale


def ternary_signs(w: Tensor) -> Tensor:
    """Return only the trits in {-1, 0, +1} as an int8 tensor (scale dropped).

    Three states need at most 2 bits per weight; int8 is used here as a
    simple, unpacked storage form. Computed with autograd disabled.
    """
    with torch.no_grad():
        normalized = w / absmean_scale(w)
        trits = normalized.round().clamp(-1.0, 1.0)
        return trits.to(torch.int8)


def absmean_scale_per_row(w: Tensor) -> Tensor:
    """Compute one AbsMean scale per row of ``w``, each floored at ``EPS``.

    The first axis is treated as the row axis; every remaining axis is
    averaged over. The result has shape ``(rows, 1, 1, ...)`` with the same
    number of dimensions as ``w`` so it broadcasts against ``w``.

    Raises:
        ValueError: if ``w`` has fewer than two dimensions.
    """
    ndim = w.dim()
    if ndim < 2:
        raise ValueError(f"per-row scale needs dim>=2, got shape {tuple(w.shape)}")
    rows = w.size(0)
    # Collapse every non-row axis so a single mean over dim=1 covers them all.
    per_row = w.reshape(rows, -1).abs().mean(dim=1)
    per_row = per_row.clamp(min=EPS)
    # Restore singleton trailing axes for broadcasting against w.
    return per_row.view(rows, *([1] * (ndim - 1)))


def ternarize_per_row(w: Tensor) -> Tensor:
    """Ternarize ``w`` row by row with a straight-through gradient.

    Each row (first axis) is quantized to {-alpha_row, 0, +alpha_row} using
    its own AbsMean scale. The gradient with respect to ``w`` is the
    identity, because the quantization residual is detached.
    """
    row_scale = absmean_scale_per_row(w)
    trits = (w / row_scale).round().clamp(-1.0, 1.0)
    quantized = trits * row_scale
    residual = (quantized - w).detach()
    return w + residual


def ternary_values_per_row(w: Tensor) -> Tensor:
    """Return the per-row ternarized weights as a plain tensor (no STE shim).

    Computed with autograd disabled; each row uses its own AbsMean scale.
    """
    with torch.no_grad():
        row_scale = absmean_scale_per_row(w)
        trits = torch.clamp(torch.round(w / row_scale), -1.0, 1.0)
        return trits * row_scale


class LearnableScale(nn.Module):
    """Holds one learnable scalar scale, read back with a floor of ``EPS``.

    The scalar is stored as an ``nn.Parameter`` named ``alpha`` so that it
    is listed by ``.parameters()`` and follows the module across
    ``.to(device)``. Read it through ``value()``, which applies the floor;
    the raw parameter itself is not modified by the clamp.
    """

    def __init__(self, initial: float = 1.0) -> None:
        """Create the scale.

        Args:
            initial: starting value of the scale; must be strictly positive.

        Raises:
            ValueError: if ``initial`` is zero or negative.
        """
        super().__init__()
        if initial <= 0:
            raise ValueError(f"initial scale must be > 0, got {initial}")
        start = torch.tensor(float(initial))
        self.alpha = nn.Parameter(start)

    def value(self) -> Tensor:
        """Return the scale clamped to be at least ``EPS``."""
        return torch.clamp(self.alpha, min=EPS)


def ternarize_lsq(w: Tensor, alpha: Tensor) -> Tensor:
    """Ternarize ``w`` with a caller-supplied, learnable ``alpha`` (LSQ style,
    Esser et al.), using a straight-through gradient for ``w``.

    Forward:  ``trits * alpha`` where ``trits = round(w / alpha).clamp(-1, 1)``.
    Backward: the gradient with respect to ``w`` is the identity (STE);
              the gradient with respect to ``alpha`` is ``trits``.

    ``alpha`` is divided by as-is; no floor is applied here, so the caller
    must pass a non-zero scale (for example ``LearnableScale.value()``).
    """
    # The trits are treated as constants by autograd: alpha only receives
    # gradient through the explicit multiplication below.
    trits = (w / alpha).round().clamp(-1.0, 1.0).detach()
    # Evaluates to zero in the forward pass, but carries gradient 1 to w.
    ste_shim = w - w.detach()
    return trits * alpha + ste_shim


@torch.no_grad()
def deadzone_stats(w: Tensor, band: float = 0.1) -> dict[str, float]:
    """Report how many weights sit near a ternary rounding boundary.

    Diagnostic for Tequila-style "deadzone trapping" (arXiv 2509.23809).
    With ``alpha`` the per-tensor AbsMean scale, a weight whose
    ``|w/alpha|`` lies within ``band`` of 0.5 is close to the point where
    its trit flips between 0 and ±1. This function measures how large that
    population is, alongside the plain trit breakdown, so you can check
    whether the effect is present before applying a fix for it.

    Args:
        w: weight tensor to analyse.
        band: half-width of the window around ``|w/alpha| == 0.5`` that
            counts as "on the boundary".

    Returns:
        A dict of Python floats with keys:
          ``alpha``: the per-tensor AbsMean scale.
          ``frac_neg`` / ``frac_zero`` / ``frac_pos``: fraction of weights
            whose clamped, rounded value is −1, 0, +1 respectively (these
            sum to 1 for finite inputs).
          ``frac_boundary``: fraction with ``||w/alpha| − 0.5| < band`` —
            the deadzone-trap candidates.
          ``frac_zero_inner``: fraction with ``|w/alpha| < 0.5 − band``
            whose sign is non-zero, i.e. weights well inside the zero bin.
            Weights that are exactly zero are NOT counted here.
    """
    alpha = absmean_scale(w)
    scaled = w / alpha
    magnitude = scaled.abs()
    has_sign = torch.sign(scaled) != 0
    trits = scaled.round().clamp(-1.0, 1.0)
    total = float(w.numel())

    def fraction(mask: Tensor) -> float:
        # Share of all elements of w for which the boolean mask is set.
        return float(mask.sum().item()) / total

    near_boundary = (magnitude - 0.5).abs() < band
    inner_zero = has_sign & (magnitude < 0.5 - band)
    return {
        "alpha": float(alpha.item()),
        "frac_neg": fraction(trits == -1),
        "frac_zero": fraction(trits == 0),
        "frac_pos": fraction(trits == 1),
        "frac_boundary": fraction(near_boundary),
        "frac_zero_inner": fraction(inner_zero),
    }