| """tilelli.core.ternary_linear — a Linear layer whose weights are born ternary. |
| |
| Shadow-weight FP32 + STE ternarization on every forward. Optional flags: |
| |
| - per_row=True : one alpha per output row (closes part of the ternary gap on |
| layers with non-uniform row magnitudes). |
| - hadamard=True : right-multiply W by an orthogonal matrix before |
| ternarizing; rotate input by H upstream so y = (xH)(WH)^T = xW^T in FP. |
| - lsq=True : alpha is a learnable FP32 scalar (Esser et al.) initialised at |
| AbsMean(W). Optimizer can push it; mutually exclusive with per_row. |
| |
| All flags default off so the existing checkpoints + Tilelli baseline remain |
| bit-exact. |
| """ |
| from __future__ import annotations |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| from tilelli.core.hadamard import hadamard_matrix |
| from tilelli.core.ternary import ( |
| LearnableScale, |
| absmean_scale, |
| absmean_scale_per_row, |
| deadzone_stats, |
| ternarize, |
| ternarize_lsq, |
| ternarize_per_row, |
| ternary_signs, |
| ) |
|
|
|
|
class TernaryLinear(nn.Module):
    """Bias-free linear layer ``y = x @ ternarize(W)^T`` with an FP32 shadow weight.

    The trainable ``weight`` parameter (shape ``(out_features, in_features)``)
    stays in full precision. :meth:`forward` ternarizes it on every call, and
    gradients use the straight-through estimator (STE).

    Args:
        in_features: Size of the input's last dimension.
        out_features: Size of the output's last dimension.
        quantize: When False the layer is a plain FP matmul ``x @ W^T``.
            No rotation and no ternarization are applied.
        per_row: Use one AbsMean scale per output row.
        hadamard: Right-multiply both the weight and the input by the
            ``hadamard_H`` buffer before ternarizing.
        lsq: Use a learnable scalar scale, initialised at AbsMean(W).

    Raises:
        ValueError: If ``lsq`` and ``per_row`` are both requested.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        quantize: bool = True,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
    ) -> None:
        super().__init__()
        if lsq and per_row:
            raise ValueError("lsq + per_row not supported (would need learnable vector)")
        self.in_features = in_features
        self.out_features = out_features
        self.quantize = quantize
        self.per_row = per_row
        self.hadamard = hadamard
        self.lsq = lsq
        w = torch.randn(out_features, in_features) * (1.0 / in_features**0.5)
        self.weight = nn.Parameter(w)
        if hadamard:
            # Registered as a buffer so it follows .to()/.cuda() and is saved
            # in the state_dict.
            self.register_buffer("hadamard_H", hadamard_matrix(in_features))
        else:
            self.hadamard_H = None
        if lsq:
            # Fall back to 1.0 if the initial AbsMean is exactly zero, so the
            # learnable scale never starts at 0.
            init_alpha = (w.abs().mean().item() or 1.0)
            self.lsq_scale = LearnableScale(initial=init_alpha)
        else:
            self.lsq_scale = None

    def extra_repr(self) -> str:
        """Show the constructor configuration when the module is printed."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"quantize={self.quantize}, per_row={self.per_row}, "
            f"hadamard={self.hadamard}, lsq={self.lsq}"
        )

    def _rotate_weight(self, w: Tensor) -> Tensor:
        """Return ``w @ H`` when the Hadamard rotation is enabled, else ``w`` unchanged."""
        if self.hadamard:
            return w @ self.hadamard_H
        return w

    def _quantize(self, w: Tensor) -> Tensor:
        """Ternarize ``w`` with the scale scheme selected at construction."""
        if self.lsq:
            return ternarize_lsq(w, self.lsq_scale.value())
        if self.per_row:
            return ternarize_per_row(w)
        return ternarize(w)

    @staticmethod
    def _round_to_trits(w: Tensor, alpha: Tensor) -> Tensor:
        """Round ``w / alpha`` to the nearest value in {-1, 0, +1} and return int8."""
        return torch.round(w / alpha).clamp_(-1.0, 1.0).to(torch.int8)

    def forward(self, x: Tensor) -> Tensor:
        """Training path: ``x @ ternarize(W)^T``, rotating both W and x when enabled.

        With an orthogonal H, ``(xH)(WH)^T == x W^T`` in full precision, so
        the rotation only changes what gets ternarized.
        """
        if not self.quantize:
            return x @ self.weight.t()
        w_rot = self._rotate_weight(self.weight)
        w_q = self._quantize(w_rot)
        if self.hadamard:
            x = x @ self.hadamard_H
        return x @ w_q.t()

    @torch.no_grad()
    def trits(self) -> Tensor:
        """Return the ternary codes of the (possibly rotated) weight.

        The lsq and per_row paths return int8 values in {-1, 0, +1}. The
        default path delegates to ``ternary_signs``; its exact dtype is not
        visible here (NOTE(review): presumably int8, confirm).
        """
        w = self._rotate_weight(self.weight)
        if self.lsq:
            return self._round_to_trits(w, self.lsq_scale.value())
        if self.per_row:
            return self._round_to_trits(w, absmean_scale_per_row(w))
        return ternary_signs(w)

    @torch.no_grad()
    def scale(self) -> Tensor:
        """Return the dequantization scale that pairs with :meth:`trits`.

        This is the learnable scalar under lsq, per-row AbsMean under
        per_row, and otherwise the per-tensor AbsMean of the
        (possibly rotated) weight.
        """
        if self.lsq:
            # The learnable scale does not depend on W; skip the rotation matmul.
            return self.lsq_scale.value()
        w = self._rotate_weight(self.weight)
        if self.per_row:
            return absmean_scale_per_row(w)
        return absmean_scale(w)

    @torch.no_grad()
    def deadzone_stats(self, band: float = 0.1) -> dict[str, float]:
        """Return dead-zone statistics of the tensor that is actually ternarized.

        With ``hadamard=True`` this is ``W @ H``, not the raw shadow weight.
        NOTE(review): the helper only receives the weight and ``band``.
        Whether its threshold matches the per_row/lsq scales cannot be
        confirmed from this file.
        """
        # Inside a method this name resolves to the module-level helper,
        # not to this method.
        return deadzone_stats(self._rotate_weight(self.weight), band=band)

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Inference path: ``alpha * (x @ trits^T)`` using the packed ternary codes.

        A floating-point input yields an output in the same dtype on every
        scale path.
        """
        if not self.quantize:
            return x @ self.weight.t()
        if self.hadamard:
            x = x @ self.hadamard_H
        trits = self.trits().to(x.dtype)
        alpha = self.scale()
        product = x @ trits.t()
        if product.is_floating_point():
            # A 1-D FP32 per-row alpha would otherwise promote a half/bfloat16
            # product to FP32, while the 0-dim per-tensor alpha would not.
            alpha = alpha.to(product.dtype)
        if self.per_row:
            # One scale per output column of the product (last dim = out_features).
            return product * alpha.reshape(-1)
        return alpha * product
|
|