"""tilelli.core.ternary_ssm — the State pathway of Tilelli.

From ARCHITECTURE.md:

    State path: small Mamba-style SSM. Long-range topic carry. O(n).

Day-0 scope: a **diagonal** state-space model — one independent scalar
recurrence per channel — which is the S4D / HiPPO-diag skeleton that Mamba is
built on. We skip Mamba's data-dependent selection for now; that's a
refinement on top of a working diagonal SSM, not the core idea.

The per-channel recurrence, for channel ``k``:

    h_t[k] = a[k] · h_{t-1}[k] + b[k] · x_t[k]
    y_t[k] = c[k] · h_t[k]

Three learnable per-channel scalars: decay ``a``, input gain ``b``, output
scale ``c`` (stored as ``c_out``). Stability demands |a| < 1; we enforce that
with ``a = tanh(a_raw)``.

Training uses the **convolutional mode** — because the recurrence is linear
and diagonal and starts from a zero state, y_t unrolls to a causal 1-D
convolution with per-channel kernel

    K[k, d] = c[k] · a[k]^d · b[k]        for delay d = 0 … L-1

so a single depthwise ``F.conv1d`` gives us the whole output sequence in one
shot. This is the S4 trick. (A direct conv1d with a length-L kernel costs
O(L²) per channel; an FFT convolution would bring that to O(L log L).)

Inference uses the **recurrent mode** — a simple per-step state update,
O(L · C) sequential per batch element — which is what Tilelli will actually
run on CPU one token at a time (``DiagonalSSM.step``).

A note on ternary weights here:
    The per-channel scalars are only O(C) parameters, vs O(C²) for the
    Linear layers. Ternarizing them saves almost nothing and makes the decay
    dynamics much harder to learn (a useful decay is a fine-grained value
    with |a| < 1, e.g. ≈ 0.9, which ternary {-α, 0, +α} can't cleanly
    express). We keep these few parameters in FP32 and are honest about it:
    the SSM is the one place in Tilelli where a little floating point lives.
    Kernel and state math run in at least the parameters' precision even for
    low-precision inputs. The big consumers — Linear and Conv — remain pure
    ternary.
"""

from __future__ import annotations

import math

import torch
from torch import Tensor, nn
from torch.nn import functional as F


class DiagonalSSM(nn.Module):
    """Per-channel diagonal state-space model. Input/output shape (B, L, C).

    Parameters are three per-channel vectors:
      - ``a_raw`` : pre-tanh decay; effective a = tanh(a_raw) ∈ (-1, 1)
      - ``b``     : input gain
      - ``c_out`` : output scale

    The state dimension equals the channel count (one scalar state per
    channel). For a wider state per channel, stack multiple DiagonalSSMs or
    move to a non-diagonal variant.

    Entry points (same math, zero initial state unless given):
      - ``forward(x)``   — convolutional mode, differentiable, for training.
      - ``infer(x)``     — recurrent mode over a whole sequence, no grad.
      - ``step(x_t, h)`` — one recurrent step with explicit state, no grad,
                           for token-by-token generation.
    """

    def __init__(self, channels: int) -> None:
        """Create the per-channel parameters.

        Args:
            channels: number of channels C (one scalar state per channel).
        """
        super().__init__()
        self.channels = channels
        # Init decay near 0.9 so early training has long-ish memory.
        # tanh(1.5) ≈ 0.905.
        self.a_raw = nn.Parameter(torch.full((channels,), 1.5))
        self.b = nn.Parameter(torch.randn(channels) * (1.0 / math.sqrt(channels)))
        self.c_out = nn.Parameter(torch.randn(channels) * (1.0 / math.sqrt(channels)))

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #
    def _check_channels(self, C: int) -> None:
        """Raise ValueError if ``C`` differs from the module's channel count."""
        if C != self.channels:
            raise ValueError(f"channel mismatch: module has {self.channels}, input has {C}")

    def _compute_dtype(self, x: Tensor) -> torch.dtype:
        """Dtype for kernel/state math: never below the parameters' precision.

        For fp32/fp64 inputs with fp32 parameters this is just ``x.dtype``;
        for bf16/fp16 inputs it is fp32, so integer exponents and the carried
        state are not rounded to a few mantissa bits.
        """
        return torch.promote_types(self.a_raw.dtype, x.dtype)

    def _coeffs(self, dtype: torch.dtype) -> tuple[Tensor, Tensor, Tensor]:
        """Return the effective ``(a, b, c_out)`` as (C,) tensors in ``dtype``."""
        return (
            torch.tanh(self.a_raw).to(dtype),  # in (-1, 1)
            self.b.to(dtype),
            self.c_out.to(dtype),
        )

    # ------------------------------------------------------------------ #
    # Training forward — convolutional mode
    # ------------------------------------------------------------------ #
    def forward(self, x: Tensor) -> Tensor:
        """Run the SSM over a full sequence in convolutional mode.

        Args:
            x: input of shape (B, L, C) with C == ``self.channels``.

        Returns:
            Tensor of shape (B, L, C) and dtype ``x.dtype``. Output step t
            depends only on inputs 0..t (causal), starting from a zero state.

        Raises:
            ValueError: if ``x`` is not 3-D or its channel count is wrong.
        """
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}")
        B, L, C = x.shape
        self._check_channels(C)
        if L == 0:
            # Nothing to convolve; a (-1, 0) pad below would raise.
            return x.new_zeros((B, 0, C))

        dtype = self._compute_dtype(x)
        a, b, c_out = self._coeffs(dtype)

        # Build the per-channel causal kernel. We want
        #   y_t = sum_{d=0}^{L-1} (c_out * a^d * b) * x_{t-d}
        # F.conv1d is cross-correlation: with left-pad L-1, the LAST kernel
        # element multiplies x_t (delay 0), so the exponents must run from
        # L-1 down to 0 across the kernel's spatial axis. They are built in
        # `dtype` rather than x.dtype: bf16 holds integers exactly only up to
        # 256 and fp16 only up to 2048.
        d = torch.arange(L - 1, -1, -1, device=x.device, dtype=dtype)  # (L,)
        powers = a.unsqueeze(-1) ** d.unsqueeze(0)  # (C, L)
        kernel = (c_out * b).unsqueeze(-1) * powers  # (C, L)
        # conv1d requires input and weight dtypes to match.
        kernel = kernel.to(x.dtype).unsqueeze(1)  # (C, 1, L)

        # Depthwise causal conv: left-pad L-1, groups=C.
        x_ = F.pad(x.transpose(1, 2), (L - 1, 0))  # (B, C, 2L-1)
        y = F.conv1d(x_, kernel, groups=C)  # (B, C, L)
        return y.transpose(1, 2)  # (B, L, C)

    # ------------------------------------------------------------------ #
    # Inference — recurrent mode, O(L·C) sequential
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def step(self, x_t: Tensor, h: Tensor | None = None) -> tuple[Tensor, Tensor]:
        """Advance the recurrence by one token.

        This is the streaming entry point: feed one token, get one output,
        and pass the returned state back in on the next call.

        Args:
            x_t: current input of shape (B, C).
            h: previous state of shape (B, C), or ``None`` to start from a
               zero state.

        Returns:
            ``(y_t, h_t)``: output of shape (B, C) in ``x_t.dtype`` and the new
            state of shape (B, C), kept in the promoted compute dtype.

        Raises:
            ValueError: if ``x_t`` is not 2-D or its channel count is wrong.
        """
        if x_t.dim() != 2:
            raise ValueError(f"expected (B, C), got shape {tuple(x_t.shape)}")
        self._check_channels(x_t.shape[1])
        dtype = self._compute_dtype(x_t)
        a, b, c_out = self._coeffs(dtype)
        if h is None:
            h = torch.zeros(x_t.shape, dtype=dtype, device=x_t.device)
        h = a * h + b * x_t
        return (c_out * h).to(x_t.dtype), h

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Run the SSM over a full sequence in recurrent mode.

        Agrees with ``forward`` up to floating-point rounding. Builds no
        (C, L) kernel and does no O(L²) convolution: one state update per
        time step, with state of shape (B, C) starting from zeros. For
        token-at-a-time generation with state carried across calls, use
        ``step``.

        Args:
            x: input of shape (B, L, C) with C == ``self.channels``.

        Returns:
            Tensor of shape (B, L, C) and dtype ``x.dtype``.

        Raises:
            ValueError: if ``x`` is not 3-D or its channel count is wrong.
        """
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}")
        B, L, C = x.shape
        # Without this check a C=1 input would silently broadcast against
        # the (channels,) coefficients.
        self._check_channels(C)
        if L == 0:
            return x.new_zeros((B, 0, C))

        dtype = self._compute_dtype(x)
        # Coefficients are hoisted out of the loop; `step` recomputes them per call.
        a, b, c_out = self._coeffs(dtype)
        h = torch.zeros(B, C, dtype=dtype, device=x.device)
        ys = []
        for t in range(L):
            h = a * h + b * x[:, t]
            ys.append(c_out * h)
        return torch.stack(ys, dim=1).to(x.dtype)  # (B, L, C)