"""tilelli.core.ternary_conv — depthwise causal 1-D conv with ternary weights. Depthwise (groups=channels) so input channels per group is 1, making the Hadamard rotation trivial (identity); we only expose per_row + lsq. """ from __future__ import annotations import torch from torch import Tensor, nn from torch.nn import functional as F from tilelli.core.ternary import ( LearnableScale, absmean_scale, absmean_scale_per_row, ternarize, ternarize_lsq, ternarize_per_row, ternary_signs, ) class TernaryCausalConv1d(nn.Module): """Depthwise causal 1-D conv with ternary weights and an FP32 shadow param.""" def __init__( self, channels: int, kernel_size: int = 5, quantize: bool = True, per_row: bool = False, lsq: bool = False, ) -> None: super().__init__() if lsq and per_row: raise ValueError("lsq + per_row not supported") self.channels = channels self.kernel_size = kernel_size self.quantize = quantize self.per_row = per_row self.lsq = lsq w = torch.randn(channels, 1, kernel_size) * (1.0 / kernel_size**0.5) self.weight = nn.Parameter(w) if lsq: init_alpha = (w.abs().mean().item() or 1.0) self.lsq_scale = LearnableScale(initial=init_alpha) else: self.lsq_scale = None # type: ignore[assignment] def _quantize(self, w: Tensor) -> Tensor: if self.lsq: return ternarize_lsq(w, self.lsq_scale.value()) if self.per_row: return ternarize_per_row(w) return ternarize(w) def forward(self, x: Tensor) -> Tensor: if x.dim() != 3: raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}") if x.shape[-1] != self.channels: raise ValueError( f"channel mismatch: module has {self.channels}, input has {x.shape[-1]}" ) x_ = x.transpose(1, 2) x_ = F.pad(x_, (self.kernel_size - 1, 0)) w = self.weight if not self.quantize else self._quantize(self.weight) y = F.conv1d(x_, w, groups=self.channels) return y.transpose(1, 2) @torch.no_grad() def trits(self) -> Tensor: if self.lsq: alpha = self.lsq_scale.value() return torch.round(self.weight / alpha).clamp_(-1.0, 1.0).to(torch.int8) if self.per_row: alpha = absmean_scale_per_row(self.weight) return torch.round(self.weight / alpha).clamp_(-1.0, 1.0).to(torch.int8) return ternary_signs(self.weight) @torch.no_grad() def scale(self) -> Tensor: if self.lsq: return self.lsq_scale.value() if self.per_row: return absmean_scale_per_row(self.weight) return absmean_scale(self.weight) @torch.no_grad() def infer(self, x: Tensor) -> Tensor: x_ = x.transpose(1, 2) x_ = F.pad(x_, (self.kernel_size - 1, 0)) if not self.quantize: y = F.conv1d(x_, self.weight, groups=self.channels) return y.transpose(1, 2) trits = self.trits().to(x.dtype) alpha = self.scale() if self.per_row: y = F.conv1d(x_, trits, groups=self.channels) * alpha.view(1, self.channels, 1) else: y = alpha * F.conv1d(x_, trits, groups=self.channels) return y.transpose(1, 2) # ── Incremental-decode helpers (KV-cache equivalent for conv) ──────── # # The conv pathway is convolutional, not attention, but it still has a # "state" you can cache: the last (kernel_size - 1) inputs. A single new # input plus that buffer is sufficient to compute the next 1-token # output, identical to running the full conv over the whole prefix. def empty_buffer(self, batch_size: int, device, dtype) -> Tensor: """Zero-init buffer matching what the left-pad would produce.""" return torch.zeros(batch_size, self.kernel_size - 1, self.channels, device=device, dtype=dtype) def warmup_buffer(self, x: Tensor) -> Tensor: """Build the buffer from the FULL prompt — keep the last (k-1) inputs. x is (B, L, C). Returns (B, k-1, C) ready to feed forward_incremental.""" L = x.size(1) k1 = self.kernel_size - 1 if L >= k1: return x[:, -k1:, :].contiguous() buf = self.empty_buffer(x.size(0), x.device, x.dtype) if L > 0: buf[:, -L:, :] = x return buf def forward_incremental(self, x_step: Tensor, buffer: Tensor) -> tuple[Tensor, Tensor]: """Step one token through the conv, given the buffered last (k-1) inputs. Returns (y_step, new_buffer) where y_step is (B, 1, C) and new_buffer is (B, k-1, C) ready for the next step. """ # Concatenate buffer + new token → (B, k, C). Conv with kernel size k # over a sequence of length k gives a single output. full = torch.cat([buffer, x_step], dim=1) # (B, k, C) x_ = full.transpose(1, 2) # (B, C, k) if not self.quantize: w = self.weight else: w = self._quantize(self.weight) y = F.conv1d(x_, w, groups=self.channels) # (B, C, 1) y_step = y.transpose(1, 2) # (B, 1, C) new_buffer = full[:, 1:, :].contiguous() # drop oldest return y_step, new_buffer