Tilelli-llm / src /tilelli /core /ternary_conv.py
TilelliLab's picture
Mirror small files (code, paper, results)
f86dc09 verified
Raw
History Blame Contribute Delete
5.56 kB
"""tilelli.core.ternary_conv — depthwise causal 1-D conv with ternary weights.
Depthwise (groups=channels) so input channels per group is 1, making the
Hadamard rotation trivial (identity); we only expose per_row + lsq.
"""
from __future__ import annotations
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from tilelli.core.ternary import (
LearnableScale,
absmean_scale,
absmean_scale_per_row,
ternarize,
ternarize_lsq,
ternarize_per_row,
ternary_signs,
)
class TernaryCausalConv1d(nn.Module):
    """Depthwise causal 1-D conv with ternary weights and an FP32 shadow param.

    The module stores a full-precision ``weight`` of shape
    ``(channels, 1, kernel_size)`` and, when ``quantize`` is true, passes it
    through one of the quantizers imported from ``tilelli.core.ternary``
    before every convolution. Inputs and outputs are laid out as
    ``(B, L, C)``; causality is obtained by left-padding the sequence axis
    with ``kernel_size - 1`` zeros.

    Args:
        channels: Number of channels; also the number of conv groups.
        kernel_size: Length of the causal kernel.
        quantize: If false, the shadow weight is used as-is.
        per_row: Select ``ternarize_per_row`` / ``absmean_scale_per_row``.
        lsq: Select ``ternarize_lsq`` with a ``LearnableScale``.

    Raises:
        ValueError: If ``lsq`` and ``per_row`` are both set, or if
            ``channels`` or ``kernel_size`` is smaller than 1.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 5,
        quantize: bool = True,
        per_row: bool = False,
        lsq: bool = False,
    ) -> None:
        super().__init__()
        if lsq and per_row:
            raise ValueError("lsq + per_row not supported")
        if channels < 1:
            raise ValueError(f"channels must be >= 1, got {channels}")
        # kernel_size == 0 would otherwise die with a bare ZeroDivisionError
        # in the init scale below.
        if kernel_size < 1:
            raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
        self.channels = channels
        self.kernel_size = kernel_size
        self.quantize = quantize
        self.per_row = per_row
        self.lsq = lsq
        # Shadow weight: standard normal scaled by 1 / sqrt(kernel_size).
        w = torch.randn(channels, 1, kernel_size) * (1.0 / kernel_size**0.5)
        self.weight = nn.Parameter(w)
        if lsq:
            # Seed the learnable scale with the mean |w|; fall back to 1.0 if
            # that mean happens to be exactly zero.
            init_alpha = (w.abs().mean().item() or 1.0)
            self.lsq_scale = LearnableScale(initial=init_alpha)
        else:
            self.lsq_scale = None  # type: ignore[assignment]

    def _quantize(self, w: Tensor) -> Tensor:
        """Run ``w`` through the quantizer selected by ``lsq`` / ``per_row``."""
        if self.lsq:
            return ternarize_lsq(w, self.lsq_scale.value())
        if self.per_row:
            return ternarize_per_row(w)
        return ternarize(w)

    def _check_input(self, x: Tensor) -> None:
        """Raise ``ValueError`` unless ``x`` is 3-D with ``channels`` last."""
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}")
        if x.shape[-1] != self.channels:
            raise ValueError(
                f"channel mismatch: module has {self.channels}, input has {x.shape[-1]}"
            )

    def _to_padded_channels_first(self, x: Tensor) -> Tensor:
        """Return ``x`` as ``(B, C, L + k - 1)`` with zeros on the left."""
        return F.pad(x.transpose(1, 2), (self.kernel_size - 1, 0))

    def forward(self, x: Tensor) -> Tensor:
        """Apply the causal depthwise conv to ``x`` of shape ``(B, L, C)``.

        Returns a tensor of the same ``(B, L, C)`` shape.

        Raises:
            ValueError: If ``x`` is not 3-D or its last dim is not ``channels``.
        """
        self._check_input(x)
        x_ = self._to_padded_channels_first(x)
        w = self._quantize(self.weight) if self.quantize else self.weight
        y = F.conv1d(x_, w, groups=self.channels)
        return y.transpose(1, 2)

    @torch.no_grad()
    def trits(self) -> Tensor:
        """Return the weight as ``int8`` values in ``{-1, 0, 1}``.

        In ``lsq`` and ``per_row`` modes this is ``round(weight / alpha)``
        clamped to ``[-1, 1]``; otherwise it is whatever ``ternary_signs``
        returns for the weight.
        """
        if self.lsq:
            alpha = self.lsq_scale.value()
            return torch.round(self.weight / alpha).clamp_(-1.0, 1.0).to(torch.int8)
        if self.per_row:
            alpha = absmean_scale_per_row(self.weight)
            return torch.round(self.weight / alpha).clamp_(-1.0, 1.0).to(torch.int8)
        return ternary_signs(self.weight)

    @torch.no_grad()
    def scale(self) -> Tensor:
        """Return the scale (alpha) that pairs with :meth:`trits`."""
        if self.lsq:
            return self.lsq_scale.value()
        if self.per_row:
            return absmean_scale_per_row(self.weight)
        return absmean_scale(self.weight)

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Gradient-free forward pass built from explicit trits and scale.

        ``x`` is ``(B, L, C)`` and the result has the same shape. With
        ``quantize=False`` the shadow weight is used directly.

        Raises:
            ValueError: If ``x`` is not 3-D or its last dim is not ``channels``
                (same checks as :meth:`forward`).
        """
        self._check_input(x)
        x_ = self._to_padded_channels_first(x)
        if not self.quantize:
            y = F.conv1d(x_, self.weight, groups=self.channels)
            return y.transpose(1, 2)
        trits = self.trits().to(x.dtype)
        alpha = self.scale()
        if self.per_row:
            # One scale per channel, broadcast over batch and sequence.
            y = F.conv1d(x_, trits, groups=self.channels) * alpha.view(1, self.channels, 1)
        else:
            y = alpha * F.conv1d(x_, trits, groups=self.channels)
        return y.transpose(1, 2)

    # -- Incremental-decode helpers (KV-cache equivalent for conv) ---------- #
    # The conv pathway is convolutional, not attention, but it still has a
    # "state" you can cache: the last (kernel_size - 1) inputs. A single new
    # input plus that buffer is sufficient to compute the next 1-token
    # output, identical to running the full conv over the whole prefix.

    def empty_buffer(self, batch_size: int, device, dtype) -> Tensor:
        """Zero-init ``(B, k-1, C)`` buffer matching what the left-pad would produce."""
        return torch.zeros(
            batch_size,
            self.kernel_size - 1,
            self.channels,
            device=device,
            dtype=dtype,
        )

    def warmup_buffer(self, x: Tensor) -> Tensor:
        """Build the buffer from the FULL prompt: keep the last (k-1) inputs.

        ``x`` is ``(B, L, C)``. Returns ``(B, k-1, C)`` ready to feed
        :meth:`forward_incremental`. Prompts shorter than ``k-1`` are
        zero-filled on the left, mirroring the causal padding.
        """
        seq_len = x.size(1)
        k1 = self.kernel_size - 1
        if seq_len >= k1:
            # Use an explicit start index: ``x[:, -k1:]`` would return the
            # whole prompt when k1 == 0, because -0 == 0.
            return x[:, seq_len - k1:, :].contiguous()
        buf = self.empty_buffer(x.size(0), x.device, x.dtype)
        if seq_len > 0:
            buf[:, -seq_len:, :] = x
        return buf

    def forward_incremental(self, x_step: Tensor, buffer: Tensor) -> tuple[Tensor, Tensor]:
        """Step new token(s) through the conv, given the buffered last (k-1) inputs.

        ``x_step`` is normally ``(B, 1, C)``; a chunk ``(B, T, C)`` is also
        accepted. Returns ``(y_step, new_buffer)`` where ``y_step`` has the
        same shape as ``x_step`` and ``new_buffer`` is ``(B, k-1, C)`` ready
        for the next step.
        """
        # Concatenate buffer + new token(s) -> (B, k-1+T, C). A kernel of size
        # k over that length yields exactly T outputs.
        full = torch.cat([buffer, x_step], dim=1)
        x_ = full.transpose(1, 2)  # (B, C, k-1+T)
        w = self._quantize(self.weight) if self.quantize else self.weight
        y = F.conv1d(x_, w, groups=self.channels)  # (B, C, T)
        y_step = y.transpose(1, 2)  # (B, T, C)
        # Keep the most recent k-1 inputs. For a single token this equals
        # dropping the oldest row; the explicit start index also handles k == 1.
        k1 = self.kernel_size - 1
        new_buffer = full[:, full.size(1) - k1:, :].contiguous()
        return y_step, new_buffer