File size: 4,034 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""tilelli.core.ternary_linear — a Linear layer whose weights are born ternary.

Shadow-weight FP32 + STE ternarization on every forward. Optional flags:

  - per_row=True : one alpha per output row (closes part of the ternary gap on
    layers with non-uniform row magnitudes).
  - hadamard=True : right-multiply W by an orthogonal matrix H before
    ternarizing; forward()/infer() rotate the input by the same H, so in full
    precision y = (xH)(WH)^T = x H H^T W^T = xW^T.
  - lsq=True : alpha is a learnable FP32 scalar (LSQ, Esser et al.) initialised
    to AbsMean(W) and updated by the optimizer; mutually exclusive with per_row.

All flags default off so the existing checkpoints + Tilelli baseline remain
bit-exact.
"""
from __future__ import annotations

import torch
from torch import Tensor, nn

from tilelli.core.hadamard import hadamard_matrix
from tilelli.core.ternary import (
    LearnableScale,
    absmean_scale,
    absmean_scale_per_row,
    deadzone_stats,
    ternarize,
    ternarize_lsq,
    ternarize_per_row,
    ternary_signs,
)


class TernaryLinear(nn.Module):
    """Bias-free linear layer ``y = x @ ternarize(W).T`` with an FP32 shadow weight.

    ``weight`` has shape ``(out_features, in_features)``. When ``quantize`` is
    True, every forward pass ternarizes it (optionally after right-multiplying
    by the orthogonal ``hadamard_H`` buffer), and gradients use STE through the
    ``tilelli.core.ternary`` helpers. When ``quantize`` is False the layer is a
    plain FP matmul.

    Args:
        in_features: size of the input's last dimension.
        out_features: size of the output's last dimension.
        quantize: ternarize the weight in ``forward``/``infer`` when True.
        per_row: one AbsMean scale per output row instead of one per tensor.
        hadamard: rotate both weight and input by ``hadamard_H`` before quantizing.
        lsq: use a learnable scalar scale (``lsq_scale``) instead of AbsMean.

    Raises:
        ValueError: if ``lsq`` and ``per_row`` are both requested.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        quantize: bool = True,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
    ) -> None:
        super().__init__()
        if lsq and per_row:
            raise ValueError("lsq + per_row not supported (would need learnable vector)")
        self.in_features = in_features
        self.out_features = out_features
        self.quantize = quantize
        self.per_row = per_row
        self.hadamard = hadamard
        self.lsq = lsq
        # Gaussian init with std 1/sqrt(in_features).
        w = torch.randn(out_features, in_features) * (1.0 / in_features**0.5)
        self.weight = nn.Parameter(w)
        if hadamard:
            # Buffer (not a Parameter): saved in the state dict, never trained.
            self.register_buffer("hadamard_H", hadamard_matrix(in_features))
        else:
            self.hadamard_H = None  # type: ignore[assignment]
        if lsq:
            # Fall back to 1.0 if the mean is exactly zero, so alpha never starts at 0.
            # NOTE: computed on the unrotated W even when hadamard=True.
            init_alpha = (w.abs().mean().item() or 1.0)
            self.lsq_scale = LearnableScale(initial=init_alpha)
        else:
            self.lsq_scale = None  # type: ignore[assignment]

    def extra_repr(self) -> str:
        """Show the layer configuration in ``repr(module)``, like ``nn.Linear``."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"quantize={self.quantize}, per_row={self.per_row}, "
            f"hadamard={self.hadamard}, lsq={self.lsq}"
        )

    def _rotate_weight(self, w: Tensor) -> Tensor:
        """Return ``w @ hadamard_H`` when the rotation is enabled, else ``w`` unchanged."""
        if self.hadamard:
            return w @ self.hadamard_H
        return w

    def _quantize(self, w: Tensor) -> Tensor:
        """Ternarize an already-rotated weight with the configured scale mode (STE)."""
        if self.lsq:
            return ternarize_lsq(w, self.lsq_scale.value())
        if self.per_row:
            return ternarize_per_row(w)
        return ternarize(w)

    @staticmethod
    def _round_clamp(w: Tensor, alpha: Tensor) -> Tensor:
        """Round ``w / alpha`` and clamp to {-1, 0, +1}, returned as int8 codes."""
        return torch.round(w / alpha).clamp_(-1.0, 1.0).to(torch.int8)

    def _trits_from(self, w: Tensor) -> Tensor:
        """Ternary int8 codes for an already-rotated weight ``w``."""
        if self.lsq:
            return self._round_clamp(w, self.lsq_scale.value())
        if self.per_row:
            return self._round_clamp(w, absmean_scale_per_row(w))
        return ternary_signs(w)

    def _scale_from(self, w: Tensor) -> Tensor:
        """Dequantization scale matching ``_trits_from(w)`` for a rotated weight ``w``."""
        if self.lsq:
            return self.lsq_scale.value()
        if self.per_row:
            return absmean_scale_per_row(w)
        return absmean_scale(w)

    def forward(self, x: Tensor) -> Tensor:
        """Training-time forward pass.

        Args:
            x: input whose last dimension is ``in_features``.

        Returns:
            Tensor whose last dimension is ``out_features``.
        """
        if not self.quantize:
            return x @ self.weight.t()
        w_q = self._quantize(self._rotate_weight(self.weight))
        if self.hadamard:
            # Rotate the input by the same H so that (xH)(WH)^T == xW^T in FP.
            x = x @ self.hadamard_H
        return x @ w_q.t()

    @torch.no_grad()
    def trits(self) -> Tensor:
        """Return the int8 ternary codes of the (rotated) weight."""
        return self._trits_from(self._rotate_weight(self.weight))

    @torch.no_grad()
    def scale(self) -> Tensor:
        """Return the dequantization scale that pairs with ``trits()``."""
        return self._scale_from(self._rotate_weight(self.weight))

    @torch.no_grad()
    def deadzone_stats(self, band: float = 0.1) -> dict[str, float]:
        """Return dead-zone statistics for the tensor that is actually ternarized.

        This uses the rotated weight when ``hadamard=True``, matching
        ``trits()`` and ``scale()``. ``band`` is forwarded unchanged to the
        module-level ``deadzone_stats`` helper.

        NOTE(review): the helper receives no per_row / lsq scale; confirm its
        threshold matches those quantizer modes.
        """
        # Inside a method this name resolves to the module-level helper,
        # not to this method.
        return deadzone_stats(self._rotate_weight(self.weight), band=band)

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Inference path: an int-code matmul followed by a single rescale.

        Args:
            x: input whose last dimension is ``in_features``.

        Returns:
            Tensor whose last dimension is ``out_features``.
        """
        if not self.quantize:
            return x @ self.weight.t()
        # Rotate the weight once; it is reused for both the codes and the scale.
        w = self._rotate_weight(self.weight)
        if self.hadamard:
            x = x @ self.hadamard_H
        trits = self._trits_from(w).to(x.dtype)
        alpha = self._scale_from(w)
        product = x @ trits.t()
        if self.per_row:
            # One scale per output row, broadcast over the last (output) dim.
            return product * alpha.view(-1)
        return alpha * product