"""tilelli.core.ternary_linear — a Linear layer whose weights are born ternary.

Shadow-weight FP32 + STE ternarization on every forward. Optional flags:

  - per_row=True  : one alpha per output row (closes part of the ternary gap
                    on layers with non-uniform row magnitudes).
  - hadamard=True : right-multiply W by an orthogonal matrix before
                    ternarizing; rotate input by H upstream so
                    y = (xH)(WH)^T = xW^T in FP.
  - lsq=True      : alpha is a learnable FP32 scalar (Esser et al.)
                    initialised at AbsMean(W). Optimizer can push it;
                    mutually exclusive with per_row.

All flags default off so the existing checkpoints + Tilelli baseline remain
bit-exact.
"""

from __future__ import annotations

import torch
from torch import Tensor, nn

from tilelli.core.hadamard import hadamard_matrix
from tilelli.core.ternary import (
    LearnableScale,
    absmean_scale,
    absmean_scale_per_row,
    deadzone_stats,
    ternarize,
    ternarize_lsq,
    ternarize_per_row,
    ternary_signs,
)


class TernaryLinear(nn.Module):
    """Bias-free linear layer computing ``y = x @ ternarize(W).t()``.

    The shadow weight ``W`` is stored as a regular ``nn.Parameter`` of shape
    ``(out_features, in_features)``; it is ternarized on every quantized
    forward pass (gradients use STE, per the ``tilelli.core.ternary``
    helpers).

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        quantize: When False, ``forward`` and ``infer`` use the raw shadow
            weight with no rotation and no ternarization.
        per_row: Ternarize with one scale per output row.
        hadamard: Right-multiply both the weight and the input by the matrix
            returned by ``hadamard_matrix(in_features)`` before ternarizing.
        lsq: Use a learnable scale (``LearnableScale``) instead of a scale
            derived from the weight.

    Raises:
        ValueError: If both ``lsq`` and ``per_row`` are set.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        quantize: bool = True,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
    ) -> None:
        super().__init__()
        if lsq and per_row:
            raise ValueError("lsq + per_row not supported (would need learnable vector)")

        self.in_features = in_features
        self.out_features = out_features
        self.quantize = quantize
        self.per_row = per_row
        self.hadamard = hadamard
        self.lsq = lsq

        # Gaussian init scaled by 1/sqrt(in_features).
        w = torch.randn(out_features, in_features) * (1.0 / in_features**0.5)
        self.weight = nn.Parameter(w)

        if hadamard:
            # Registered as a buffer so it follows the module across
            # .to()/.cuda() calls without becoming a trainable parameter.
            # NOTE(review): hadamard_matrix may restrict the sizes it
            # accepts — confirm against tilelli.core.hadamard.
            self.register_buffer("hadamard_H", hadamard_matrix(in_features))
        else:
            self.hadamard_H = None  # type: ignore[assignment]

        if lsq:
            # `or 1.0` guards against a zero mean, which would otherwise
            # start the learnable scale at 0.
            # NOTE(review): computed from the unrotated W even when
            # hadamard=True.
            init_alpha = w.abs().mean().item() or 1.0
            self.lsq_scale = LearnableScale(initial=init_alpha)
        else:
            self.lsq_scale = None  # type: ignore[assignment]

    def _rotate_weight(self, w: Tensor) -> Tensor:
        """Return ``w @ H`` when the Hadamard flag is on, else ``w`` itself."""
        if self.hadamard:
            return w @ self.hadamard_H
        return w

    def _quantize(self, w: Tensor) -> Tensor:
        """Ternarize ``w`` with the helper matching the configured mode."""
        if self.lsq:
            return ternarize_lsq(w, self.lsq_scale.value())
        if self.per_row:
            return ternarize_per_row(w)
        return ternarize(w)

    def _scale_of(self, w_rot: Tensor) -> Tensor:
        """Return the scale for an already-rotated weight ``w_rot``."""
        if self.lsq:
            return self.lsq_scale.value()
        if self.per_row:
            return absmean_scale_per_row(w_rot)
        return absmean_scale(w_rot)

    def _trits_of(self, w_rot: Tensor, alpha: Tensor | None = None) -> Tensor:
        """Return int8 trits for an already-rotated weight ``w_rot``.

        ``alpha`` lets a caller that already holds the scale avoid computing
        it a second time; it is only consulted in the lsq / per_row modes.
        """
        if self.lsq or self.per_row:
            if alpha is None:
                alpha = self._scale_of(w_rot)
            # torch.round allocates a fresh tensor, so the in-place clamp
            # never touches the shadow weight.
            return torch.round(w_rot / alpha).clamp_(-1.0, 1.0).to(torch.int8)
        return ternary_signs(w_rot)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer, re-ternarizing the shadow weight on this call.

        Args:
            x: Input whose last dimension is ``in_features``.

        Returns:
            ``x @ W.t()`` when ``quantize`` is False; otherwise the product
            of the (optionally rotated) input with the transposed ternarized
            (optionally rotated) weight.
        """
        if not self.quantize:
            return x @ self.weight.t()
        w_q = self._quantize(self._rotate_weight(self.weight))
        if self.hadamard:
            # Rotate the input with the same H that was applied to the weight.
            x = x @ self.hadamard_H
        return x @ w_q.t()

    @torch.no_grad()
    def trits(self) -> Tensor:
        """Return the ternary codes of the (rotated) weight.

        In lsq / per_row mode this is an int8 tensor of
        ``round(W / alpha)`` clamped to [-1, 1]; otherwise it is whatever
        ``ternary_signs`` returns for the rotated weight.
        """
        return self._trits_of(self._rotate_weight(self.weight))

    @torch.no_grad()
    def scale(self) -> Tensor:
        """Return the scale (alpha) that pairs with :meth:`trits`."""
        return self._scale_of(self._rotate_weight(self.weight))

    @torch.no_grad()
    def deadzone_stats(self, band: float = 0.1) -> dict[str, float]:
        """Return the ``deadzone_stats`` helper's report for this layer.

        The statistics are computed on the same tensor that ``trits`` and
        ``scale`` use, i.e. the Hadamard-rotated weight when ``hadamard`` is
        set, so the report describes what is actually ternarized.

        Args:
            band: Forwarded unchanged to the module-level ``deadzone_stats``
                helper.
        """
        return deadzone_stats(self._rotate_weight(self.weight), band=band)

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Gradient-free forward built from explicit trits and scale.

        The weight is rotated once and the result is shared between the trit
        and scale computations.

        Args:
            x: Input whose last dimension is ``in_features``.

        Returns:
            ``x @ W.t()`` when ``quantize`` is False; otherwise
            ``(x @ trits.t())`` multiplied by the scale.
        """
        if not self.quantize:
            return x @ self.weight.t()
        if self.hadamard:
            x = x @ self.hadamard_H

        w_rot = self._rotate_weight(self.weight)
        alpha = self._scale_of(w_rot)
        trits = self._trits_of(w_rot, alpha).to(x.dtype)

        product = x @ trits.t()
        if self.per_row:
            # Flatten the per-row scale so it broadcasts along the last
            # (output-feature) dimension of the product.
            return product * alpha.view(-1)
        return alpha * product