| """tilelli.core.ternary_conv β depthwise causal 1-D conv with ternary weights. |
| |
| Depthwise (groups=channels) so input channels per group is 1, making the |
| Hadamard rotation trivial (identity); we only expose per_row + lsq. |
| """ |
| from __future__ import annotations |
|
|
| import torch |
| from torch import Tensor, nn |
| from torch.nn import functional as F |
|
|
| from tilelli.core.ternary import ( |
| LearnableScale, |
| absmean_scale, |
| absmean_scale_per_row, |
| ternarize, |
| ternarize_lsq, |
| ternarize_per_row, |
| ternary_signs, |
| ) |
|
|
|
|
class TernaryCausalConv1d(nn.Module):
    """Depthwise causal 1-D conv with ternary weights and an FP32 shadow param.

    The full-precision ``weight`` has shape ``(channels, 1, kernel_size)``:
    one filter per channel, applied with ``groups=channels``. When
    ``quantize`` is True:

    * the differentiable paths (``forward``, ``forward_incremental``) convolve
      with ``self._quantize(weight)``;
    * ``infer`` convolves with the integer trits from ``trits()`` and
      multiplies by the scale from ``scale()``.

    All public entry points take activations laid out as ``(B, L, C)``. The
    time axis is left-padded with ``kernel_size - 1`` zeros, so output step
    ``t`` depends only on inputs ``0..t``.

    Args:
        channels: number of channels ``C``; must be >= 1.
        kernel_size: temporal kernel width ``k``; must be >= 1.
        quantize: if False, the FP32 weight is used as-is everywhere.
        per_row: use a per-row (per-channel) absmean scale.
        lsq: use a ``LearnableScale`` as the ternary scale; cannot be
            combined with ``per_row``.

    Raises:
        ValueError: if ``lsq and per_row``, ``channels < 1`` or
            ``kernel_size < 1``.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 5,
        quantize: bool = True,
        per_row: bool = False,
        lsq: bool = False,
    ) -> None:
        super().__init__()
        if lsq and per_row:
            raise ValueError("lsq + per_row not supported")
        if channels < 1:
            raise ValueError(f"channels must be >= 1, got {channels}")
        if kernel_size < 1:
            # kernel_size == 0 would otherwise die in the 1/sqrt(k) init below.
            raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
        self.channels = channels
        self.kernel_size = kernel_size
        self.quantize = quantize
        self.per_row = per_row
        self.lsq = lsq
        w = torch.randn(channels, 1, kernel_size) * (1.0 / kernel_size**0.5)
        self.weight = nn.Parameter(w)
        if lsq:
            # Seed the learnable scale with the absmean of the initial
            # weights; fall back to 1.0 if that mean is exactly zero.
            init_alpha = w.abs().mean().item() or 1.0
            self.lsq_scale = LearnableScale(initial=init_alpha)
        else:
            self.lsq_scale = None

    # ------------------------------------------------------------------ helpers

    def _quantize(self, w: Tensor) -> Tensor:
        """Ternarize ``w`` with the configured scheme (lsq / per-row / global)."""
        if self.lsq:
            return ternarize_lsq(w, self.lsq_scale.value())
        if self.per_row:
            return ternarize_per_row(w)
        return ternarize(w)

    def _check_input(self, x: Tensor) -> None:
        """Raise ValueError unless ``x`` is 3-D with ``self.channels`` last."""
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}")
        if x.shape[-1] != self.channels:
            raise ValueError(
                f"channel mismatch: module has {self.channels}, input has {x.shape[-1]}"
            )

    def _effective_weight(self) -> Tensor:
        """Weight for the differentiable paths: ternarized iff ``quantize``."""
        return self._quantize(self.weight) if self.quantize else self.weight

    def _causal_pad(self, x: Tensor) -> Tensor:
        """(B, L, C) -> (B, C, L + k - 1), zero-padded on the left of time."""
        return F.pad(x.transpose(1, 2), (self.kernel_size - 1, 0))

    # ------------------------------------------------------------ full sequence

    def forward(self, x: Tensor) -> Tensor:
        """Apply the causal depthwise conv to ``x`` of shape (B, L, C).

        Returns a tensor of the same shape.

        Raises:
            ValueError: if ``x`` is not 3-D or its last dim != ``channels``.
        """
        self._check_input(x)
        y = F.conv1d(self._causal_pad(x), self._effective_weight(), groups=self.channels)
        return y.transpose(1, 2)

    @torch.no_grad()
    def trits(self) -> Tensor:
        """Return the ternary weight codes as int8 in {-1, 0, 1}.

        For lsq / per_row the codes are ``round(weight / scale)`` clamped to
        [-1, 1]. Otherwise this delegates to ``ternary_signs``.
        """
        if self.lsq:
            alpha = self.lsq_scale.value()
            return torch.round(self.weight / alpha).clamp_(-1.0, 1.0).to(torch.int8)
        if self.per_row:
            alpha = absmean_scale_per_row(self.weight)
            return torch.round(self.weight / alpha).clamp_(-1.0, 1.0).to(torch.int8)
        return ternary_signs(self.weight)

    @torch.no_grad()
    def scale(self) -> Tensor:
        """Return the scale paired with ``trits()``.

        This is the learnable scale (lsq), the per-row absmean, or the global
        absmean.
        """
        if self.lsq:
            return self.lsq_scale.value()
        if self.per_row:
            return absmean_scale_per_row(self.weight)
        return absmean_scale(self.weight)

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Gradient-free forward for ``x`` of shape (B, L, C).

        When quantized, this convolves with the integer trits and applies the
        scale after the conv.

        Raises:
            ValueError: if ``x`` is not 3-D or its last dim != ``channels``.
        """
        self._check_input(x)
        x_ = self._causal_pad(x)
        if not self.quantize:
            return F.conv1d(x_, self.weight, groups=self.channels).transpose(1, 2)
        trits = self.trits().to(x.dtype)
        alpha = self.scale()
        y = F.conv1d(x_, trits, groups=self.channels)
        if self.per_row:
            # One scale per channel: broadcast over (B, C, L).
            y = y * alpha.view(1, self.channels, 1)
        else:
            y = alpha * y
        return y.transpose(1, 2)

    # ------------------------------------------------------ incremental decoding

    def empty_buffer(self, batch_size: int, device, dtype) -> Tensor:
        """Zero-init buffer matching what the left-pad would produce.

        Returns a (batch_size, kernel_size - 1, channels) tensor.
        """
        return torch.zeros(
            batch_size, self.kernel_size - 1, self.channels, device=device, dtype=dtype
        )

    def warmup_buffer(self, x: Tensor) -> Tensor:
        """Build the buffer from the FULL prompt — keep the last (k-1) inputs.

        ``x`` is (B, L, C). Returns (B, k-1, C) ready to feed
        ``forward_incremental``. Prompts shorter than k-1 are left-padded
        with zeros, exactly as ``forward`` pads.
        """
        k1 = self.kernel_size - 1
        L = x.size(1)
        if L >= k1:
            # Explicit start index: ``x[:, -k1:]`` would return the whole
            # prompt when k1 == 0 (kernel_size == 1).
            return x[:, L - k1:, :].contiguous()
        buf = self.empty_buffer(x.size(0), x.device, x.dtype)
        if L > 0:
            buf[:, -L:, :] = x
        return buf

    def forward_incremental(self, x_step: Tensor, buffer: Tensor) -> tuple[Tensor, Tensor]:
        """Step new token(s) through the conv, given the buffered last (k-1) inputs.

        Args:
            x_step: (B, T, C) new inputs; T is usually 1.
            buffer: (B, k-1, C), as produced by ``empty_buffer``,
                ``warmup_buffer`` or a previous call.

        Returns:
            ``(y_step, new_buffer)``:

            * ``y_step`` is (B, T, C), equal to the matching slice of
              ``forward`` over the whole sequence.
            * ``new_buffer`` is (B, k-1, C), holding the last k-1 inputs, and
              is ready for the next step.

        Raises:
            ValueError: on a mis-shaped ``x_step`` or ``buffer``.
        """
        self._check_input(x_step)
        k1 = self.kernel_size - 1
        if buffer.dim() != 3 or buffer.size(1) != k1 or buffer.size(-1) != self.channels:
            raise ValueError(
                f"expected buffer of shape (B, {k1}, {self.channels}), "
                f"got {tuple(buffer.shape)}"
            )
        full = torch.cat([buffer, x_step], dim=1)
        # No padding needed: the buffer already supplies the k-1 left context.
        y = F.conv1d(full.transpose(1, 2), self._effective_weight(), groups=self.channels)
        # Keep exactly the last k-1 inputs. Using ``full[:, 1:]`` would only
        # be right when T == 1.
        new_buffer = full[:, full.size(1) - k1:, :].contiguous()
        return y.transpose(1, 2), new_buffer
|
|