Tilelli-llm / src /tilelli /core /ternary_linear.py
TilelliLab's picture
Mirror small files (code, paper, results)
f86dc09 verified
Raw
History Blame Contribute Delete
4.03 kB
"""tilelli.core.ternary_linear — a Linear layer whose weights are born ternary.
Shadow-weight FP32 + STE ternarization on every forward. Optional flags:
- per_row=True : one alpha per output row (closes part of the ternary gap on
layers with non-uniform row magnitudes).
- hadamard=True : right-multiply W by an orthogonal matrix before
ternarizing; rotate input by H upstream so y = (xH)(WH)^T = xW^T in FP.
- lsq=True : alpha is a learnable FP32 scalar (Esser et al.) initialised at
AbsMean(W). Optimizer can push it; mutually exclusive with per_row.
All flags default off so the existing checkpoints + Tilelli baseline remain
bit-exact.
"""
from __future__ import annotations
import torch
from torch import Tensor, nn
from tilelli.core.hadamard import hadamard_matrix
from tilelli.core.ternary import (
LearnableScale,
absmean_scale,
absmean_scale_per_row,
deadzone_stats,
ternarize,
ternarize_lsq,
ternarize_per_row,
ternary_signs,
)
class TernaryLinear(nn.Module):
    """Linear layer computing ``y = x @ ternarize(W).T`` from a shadow weight.

    The full-precision "shadow" weight ``self.weight`` (shape
    ``(out_features, in_features)``) is re-ternarized on every forward pass;
    gradients reach it through the straight-through estimator implemented by
    the ``ternarize*`` helpers.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        quantize: When False the layer is a plain bias-free linear map using
            the shadow weight directly (none of the other flags take effect).
        per_row: Use ``ternarize_per_row`` / one scale per output row instead
            of a single scale for the whole matrix.
        hadamard: Right-multiply both the weight and the input by the matrix
            returned by ``hadamard_matrix(in_features)`` before ternarizing
            (assumed orthogonal, so the full-precision product is unchanged).
        lsq: Use a ``LearnableScale`` initialised at the mean absolute value
            of the initial weight. Mutually exclusive with ``per_row``.

    Raises:
        ValueError: If both ``lsq`` and ``per_row`` are requested.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        quantize: bool = True,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
    ) -> None:
        super().__init__()
        if lsq and per_row:
            raise ValueError("lsq + per_row not supported (would need learnable vector)")
        self.in_features = in_features
        self.out_features = out_features
        self.quantize = quantize
        self.per_row = per_row
        self.hadamard = hadamard
        self.lsq = lsq

        # Gaussian init scaled by 1/sqrt(in_features).
        w = torch.randn(out_features, in_features) * (1.0 / in_features**0.5)
        self.weight = nn.Parameter(w)

        if hadamard:
            # Buffer (not a Parameter): it follows .to()/.cuda() but is not trained.
            self.register_buffer("hadamard_H", hadamard_matrix(in_features))
        else:
            self.hadamard_H = None  # type: ignore[assignment]

        if lsq:
            # `or 1.0` guards against an exactly-zero mean absolute value.
            init_alpha = (w.abs().mean().item() or 1.0)
            self.lsq_scale = LearnableScale(initial=init_alpha)
        else:
            self.lsq_scale = None  # type: ignore[assignment]

    def extra_repr(self) -> str:
        """Return the configuration flags shown by ``repr(module)``."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"quantize={self.quantize}, per_row={self.per_row}, "
            f"hadamard={self.hadamard}, lsq={self.lsq}"
        )

    def _rotate_weight(self, w: Tensor) -> Tensor:
        """Return ``w @ H`` when the Hadamard option is on, else ``w`` unchanged."""
        if self.hadamard:
            return w @ self.hadamard_H
        return w

    def _quantize(self, w: Tensor) -> Tensor:
        """Ternarize ``w`` with the helper selected by the lsq / per_row flags."""
        if self.lsq:
            return ternarize_lsq(w, self.lsq_scale.value())
        if self.per_row:
            return ternarize_per_row(w)
        return ternarize(w)

    def _scale_of(self, w: Tensor) -> Tensor:
        """Return the scale (alpha) for an already-rotated weight ``w``."""
        if self.lsq:
            return self.lsq_scale.value()
        if self.per_row:
            return absmean_scale_per_row(w)
        return absmean_scale(w)

    def _trits_of(self, w: Tensor) -> Tensor:
        """Return the ternary codes for an already-rotated weight ``w``."""
        if self.lsq or self.per_row:
            alpha = self._scale_of(w)
            # round-then-clamp maps w/alpha onto {-1, 0, +1}.
            return torch.round(w / alpha).clamp_(-1.0, 1.0).to(torch.int8)
        return ternary_signs(w)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer, ternarizing the shadow weight on the fly.

        With ``quantize=False`` this is ``x @ W.T``. Otherwise the (optionally
        rotated) weight is ternarized and the (optionally rotated) input is
        multiplied by its transpose.
        """
        if not self.quantize:
            return x @ self.weight.t()
        w_rot = self._rotate_weight(self.weight)
        w_q = self._quantize(w_rot)
        if self.hadamard:
            # Rotate the input with the same matrix as the weight.
            x = x @ self.hadamard_H
        return x @ w_q.t()

    @torch.no_grad()
    def trits(self) -> Tensor:
        """Return the ternary codes of the (optionally rotated) weight.

        Runs without gradient tracking. The lsq / per_row paths return int8;
        the default path returns whatever ``ternary_signs`` produces.
        """
        return self._trits_of(self._rotate_weight(self.weight))

    @torch.no_grad()
    def scale(self) -> Tensor:
        """Return the scale alpha paired with :meth:`trits` (no gradient tracking)."""
        return self._scale_of(self._rotate_weight(self.weight))

    @torch.no_grad()
    def deadzone_stats(self, band: float = 0.1) -> dict[str, float]:
        """Return ``deadzone_stats`` for the tensor that actually gets ternarized.

        The statistics are computed on the rotated weight so that, with
        ``hadamard=True``, they describe the same tensor that ``forward``,
        ``trits`` and ``scale`` operate on. With ``hadamard=False`` this is
        the shadow weight itself.
        """
        return deadzone_stats(self._rotate_weight(self.weight), band=band)

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Inference path using explicit ternary codes and scale.

        Computes ``alpha * (x @ trits.T)`` (per-row alpha is broadcast over the
        output dimension). With ``quantize=False`` it is ``x @ W.T``. Runs
        without gradient tracking.
        """
        if not self.quantize:
            return x @ self.weight.t()
        if self.hadamard:
            x = x @ self.hadamard_H
        # Rotate once and derive both codes and scale from the same tensor.
        w = self._rotate_weight(self.weight)
        trits = self._trits_of(w).to(x.dtype)
        alpha = self._scale_of(w)
        product = x @ trits.t()
        if self.per_row:
            # Flatten the per-row scale so it broadcasts over product's last dim.
            return product * alpha.view(-1)
        return alpha * product