| """tilelli.core.hadamard — orthogonal-rotation utilities for ternary quantization. |
| |
| Quantization-error reduction trick from QuaRot / SpinQuant (2024). Multiplying |
| a weight matrix by an orthogonal matrix H spreads the energy of any single |
| position across all positions, flattening outliers and producing a more |
| Gaussian-like distribution that ternarizes with less rounding error. |
| |
| Sylvester construction works only for n = 2^k. For other sizes we fall |
| back to a fixed-seed random orthogonal matrix (Householder/QR rotations), |
| treated as equivalent in practice for quantization purposes. |
| """ |
| from __future__ import annotations |
|
|
| import functools |
|
|
| import torch |
| from torch import Tensor |
|
|
|
|
| def _is_power_of_two(n: int) -> bool: |
| return n > 0 and (n & (n - 1)) == 0 |
|
|
|
|
def _sylvester_hadamard(n: int) -> Tensor:
    """Build the normalized ``n x n`` Hadamard matrix by Sylvester doubling.

    Starting from ``[[1.0]]``, the matrix is repeatedly expanded as
    ``[[H, H], [H, -H]] / sqrt(2)`` until it has ``n`` rows. Dividing by
    ``sqrt(2)`` at every doubling keeps the rows unit-norm, so the result is
    orthogonal rather than a raw +/-1 matrix.

    The result is always created as float32 on the CPU, regardless of any
    process-wide default dtype or default device.

    Args:
        n: Matrix size; must be a positive power of two.

    Returns:
        An ``(n, n)`` float32 CPU tensor.

    Raises:
        ValueError: If ``n`` is not a positive power of two.
    """
    if not _is_power_of_two(n):
        raise ValueError(f"Sylvester Hadamard requires power-of-2 size, got {n}")
    # Pin dtype and device explicitly: a bare ``torch.tensor([[1.0]])`` follows
    # ``torch.set_default_dtype`` / ``torch.set_default_device``, which would
    # make the output of this function depend on global state.
    h = torch.tensor([[1.0]], dtype=torch.float32, device="cpu")
    while h.size(0) < n:
        upper = torch.cat([h, h], dim=1)
        lower = torch.cat([h, -h], dim=1)
        h = torch.cat([upper, lower], dim=0) / (2.0**0.5)
    return h
|
|
|
|
| def _random_orthogonal(n: int, seed: int = 1234) -> Tensor: |
| g = torch.Generator(device="cpu").manual_seed(seed) |
| a = torch.randn(n, n, generator=g, dtype=torch.float64) |
| q, _r = torch.linalg.qr(a) |
| return q.to(torch.float32) |
|
|
|
|
@functools.lru_cache(maxsize=64)
def hadamard_matrix(n: int, seed: int = 1234) -> Tensor:
    """Return an ``n x n`` orthogonal rotation matrix, memoized per call arguments.

    Power-of-two sizes use the Sylvester Hadamard construction, and ``seed``
    is not consulted for them. Every other size falls back to a random
    orthogonal matrix generated from ``seed``.

    Results are held in an ``lru_cache`` of up to 64 entries, so repeated
    calls with the same arguments return the *same* tensor object. Treat the
    returned tensor as read-only: an in-place modification would be seen by
    every later caller.

    Args:
        n: Matrix size.
        seed: Seed forwarded to the random-orthogonal fallback.

    Returns:
        The cached ``(n, n)`` rotation matrix.
    """
    if not _is_power_of_two(n):
        return _random_orthogonal(n, seed=seed)
    return _sylvester_hadamard(n)
|
|
|
|
def rotate_columns(w: Tensor, h: Tensor | None = None) -> Tensor:
    """Right-multiply ``w`` by a rotation matrix and return ``w @ h``.

    Args:
        w: Tensor to rotate; its last dimension selects the rotation size.
        h: Rotation matrix to apply. When ``None``, ``hadamard_matrix`` is
            called with ``w.size(-1)`` and the result is converted to the
            dtype and device of ``w``.

    Returns:
        The product ``w @ h``.
    """
    width = w.size(-1)
    if h is not None:
        return w @ h
    rotation = hadamard_matrix(width)
    return w @ rotation.to(dtype=w.dtype, device=w.device)
|
|
|
|
def rotate_input(x: Tensor, n: int, h: Tensor | None = None) -> Tensor:
    """Right-multiply the input ``x`` by a rotation matrix and return ``x @ h``.

    Args:
        x: Tensor to rotate.
        n: Size passed to ``hadamard_matrix`` when ``h`` is ``None``; it is
            ignored when ``h`` is supplied. The matrix product requires it to
            match the last dimension of ``x``.
        h: Rotation matrix to apply. When ``None``, ``hadamard_matrix(n)`` is
            used, converted to the dtype and device of ``x``.

    Returns:
        The product ``x @ h``.
    """
    if h is not None:
        return x @ h
    rotation = hadamard_matrix(n)
    return x @ rotation.to(dtype=x.dtype, device=x.device)
|
|
|
|
# Public API of this module; the construction helpers stay private.
__all__ = [
    "hadamard_matrix",
    "rotate_columns",
    "rotate_input",
]
|
|