Tilelli-llm / src /tilelli /core /tilelli_lite.py
TilelliLab's picture
Mirror small files (code, paper, results)
f86dc09 verified
Raw
History Blame Contribute Delete
16.2 kB
"""tilelli.core.tilelli_lite β€” clean 3-pathway block designed to beat a same-size vanilla baseline.
A prior 6-pathway variant of this architecture (~10.6M params) tied vanilla on
TinyStories byte-LM (mean 0.5737 vs vanilla 0.5707). Internal audit attributed
the tie to fragmentation: parameter budget was spent on pathways the byte-LM
data did not reward (an indexed-knowledge slot, a wide convolution, and a
non-selective state-space path).
Tilelli Lite cuts those underperforming slots and keeps the lessons that DO
show up at 10M scale: heterogeneous pathways with a learned router, and a
ternary-capable forward pass for inference. This module is a sibling to the
larger 5/6-pathway block (kept intact for non-byte-LM workloads); it is not
a drop-in replacement.
3-pathway block:
- Local conv k=5 (n-grams; strictly more efficient than attention here)
- Sparse causal attention with multi-head (8 heads, d_head=48 by default)
- Dense FFN with expand=4 (matches vanilla's FFN ratio)
Other lessons folded in from the prior block's audit:
- Learned positional embedding (recovers the position signal lost by
the previous unembedding-only design)
- Load-balance auxiliary loss properly wired through the router head
"""
from __future__ import annotations
import torch
from torch import Tensor, nn
from tilelli.core.sparse_attention import SparseCausalAttention
from tilelli.core.ternary_conv import TernaryCausalConv1d
from tilelli.core.ternary_linear import TernaryLinear
# Names of the three pathways, listed in the same order as the router columns
# mixed in TilelliLiteBlock.forward (0 = local conv, 1 = sparse attention,
# 2 = dense FFN).
PATHWAY_NAMES_LITE = ("local", "sparse", "dense")
class TernaryFFN_Lite(nn.Module):
    """Two-layer feed-forward pathway: up-project, GELU, down-project.

    The hidden width is ``d_model * expand``; the default ``expand=4`` is the
    ratio the original author chose to match the vanilla baseline's FFN.

    Args:
        d_model: Input and output feature width.
        expand: Multiplier applied to ``d_model`` to get the hidden width.
        quantize: Forwarded unchanged to both ``TernaryLinear`` layers.
    """

    def __init__(self, d_model: int, expand: int = 4, quantize: bool = True) -> None:
        super().__init__()
        hidden_width = d_model * expand
        # Creation order (up, then down) is kept so parameter initialisation
        # consumes the RNG in the same sequence as before.
        self.up = TernaryLinear(d_model, hidden_width, quantize=quantize)
        self.down = TernaryLinear(hidden_width, d_model, quantize=quantize)

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``down(gelu(up(x)))`` and return the result."""
        widened = self.up(x)
        activated = torch.nn.functional.gelu(widened)
        return self.down(activated)
class TilelliLiteBlock(nn.Module):
    """Three-pathway residual block: local conv + sparse attention + dense FFN.

    Every pathway runs on every token. A per-token softmax router produces
    three mixing weights ``r`` and the block returns
    ``x + r[0] * local + r[1] * sparse + r[2] * dense``, where all pathways
    read the LayerNorm-ed input.

    Each ``forward`` call also refreshes two cached values:

    * ``aux_loss`` - mean squared deviation of the batch-averaged pathway
      usage from ``1 / 3``, scaled by ``load_balance_weight``; it penalises
      the router collapsing onto a single pathway.
    * ``_router_entropy`` - per-token entropy of the router distribution,
      kept so that an outer training loop can read it.

    Args:
        d_model: Width of the residual stream.
        n_heads: Number of ``SparseCausalAttention`` modules. ``d_model`` must
            be a multiple of it; ``d_head = d_model // n_heads`` is passed to
            every head.
        kernel_size: Kernel size forwarded to the local causal convolution.
        top_k: ``top_k`` forwarded to every sparse attention head.
        ffn_expand: Hidden-width multiplier of the dense FFN pathway.
        quantize: Forwarded to the ternary conv / linear layers.
        load_balance_weight: Scale applied to the load-balance auxiliary loss.

    Raises:
        ValueError: If ``n_heads`` is not positive or ``d_model`` is not a
            multiple of ``n_heads``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int = 8,
        kernel_size: int = 5,
        top_k: int = 16,
        ffn_expand: int = 4,
        quantize: bool = True,
        load_balance_weight: float = 0.01,
    ) -> None:
        super().__init__()
        # Validate before any arithmetic so n_heads == 0 gives a clear error
        # instead of a ZeroDivisionError.
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_pathways = 3
        self.load_balance_weight = load_balance_weight

        d_head = d_model // n_heads
        self.norm = nn.LayerNorm(d_model)
        self.local = TernaryCausalConv1d(d_model, kernel_size=kernel_size, quantize=quantize)
        # n_heads independent instances of the single-head sparse attention;
        # their outputs are averaged (not concatenated) in _merge_heads.
        self.sparse_heads = nn.ModuleList([
            SparseCausalAttention(d_model, d_head=d_head, top_k=top_k)
            for _ in range(n_heads)
        ])
        self.sparse_proj = TernaryLinear(d_model, d_model, quantize=quantize)
        self.dense = TernaryFFN_Lite(d_model, expand=ffn_expand, quantize=quantize)
        self.router = TernaryLinear(d_model, self.n_pathways, quantize=quantize)
        # Placeholders until the first forward() call fills them in.
        self._aux_loss = torch.tensor(0.0)
        self._router_entropy = None

    # -- Shared helpers ---------------------------------------------------- #

    @staticmethod
    def _entropy(r: Tensor) -> Tensor:
        """Entropy over the last axis of ``r``; the 1e-12 guards log(0)."""
        return -(r * (r + 1e-12).log()).sum(dim=-1)

    @staticmethod
    def _mix(r: Tensor, out_local: Tensor, out_sparse: Tensor, out_dense: Tensor) -> Tensor:
        """Weight the three pathway outputs by router columns 0, 1 and 2."""
        return (
            r[..., 0:1] * out_local
            + r[..., 1:2] * out_sparse
            + r[..., 2:3] * out_dense
        )

    def _merge_heads(self, head_outs: list[Tensor]) -> Tensor:
        """Average the per-head outputs element-wise, then apply sparse_proj.

        All head outputs must share one shape (they are stacked); the
        projection expects ``d_model`` features on the last axis.
        """
        merged = torch.stack(head_outs, dim=0).mean(dim=0)
        return self.sparse_proj(merged)

    def _multi_head_sparse(self, h: Tensor) -> Tensor:
        """Run every sparse head on ``h``, average their outputs, project.

        The heads are averaged rather than concatenated, so the merged
        tensor keeps the shape of a single head's output.
        """
        return self._merge_heads([head(h) for head in self.sparse_heads])

    # -- Full-sequence path ------------------------------------------------ #

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the router-weighted mix of the three pathways.

        Side effects: overwrites the cached ``aux_loss`` and
        ``_router_entropy`` with values computed from this call.
        """
        h = self.norm(x)
        r = torch.softmax(self.router(h), dim=-1)  # last axis: 3 pathway weights
        mixed = self._mix(r, self.local(h), self._multi_head_sparse(h), self.dense(h))

        # Load-balance: mean usage of each pathway (averaged over the first
        # two axes) should approach 1 / n_pathways.
        pathway_use = r.mean(dim=(0, 1))
        target = 1.0 / self.n_pathways
        self._aux_loss = ((pathway_use - target) ** 2).mean() * self.load_balance_weight

        # Cached so an outer training loop can read it for a metacognition
        # aux loss (see scripts/train_router_metacog.py).
        self._router_entropy = self._entropy(r)
        return x + mixed

    @property
    def aux_loss(self) -> Tensor:
        """Load-balance loss from the most recent forward() (0.0 before any)."""
        return self._aux_loss

    @torch.no_grad()
    def router_weights(self, x: Tensor) -> Tensor:
        """Softmax router distribution for ``x`` (normalised internally)."""
        return torch.softmax(self.router(self.norm(x)), dim=-1)

    @torch.no_grad()
    def router_entropy(self, x: Tensor) -> Tensor:
        """Per-token entropy of the router distribution.

        Low entropy means the router is committed to one pathway; high
        entropy means an uncertain mix.
        """
        return self._entropy(self.router_weights(x))

    # -- Incremental-decode helpers ---------------------------------------- #
    # A block "cache" is a dict:
    #   {"conv_buffer":   whatever self.local's buffer helpers return,
    #    "sparse_caches": [one cache object per sparse head]}

    def empty_cache(self, batch_size: int, device, dtype) -> dict:
        """Build an empty cache by delegating to the conv and to each head."""
        return {
            "conv_buffer": self.local.empty_buffer(batch_size, device, dtype),
            "sparse_caches": [
                head.empty_cache(batch_size, device, dtype)
                for head in self.sparse_heads
            ],
        }

    def warmup_cache(self, x: Tensor) -> dict:
        """Build the cache from a full prompt.

        ``x`` must be the SAME tensor that is fed to ``forward`` for this
        block during prompt processing. The pathways only ever see the
        normalised input, so the caches are built from ``self.norm(x)``.
        """
        h = self.norm(x)
        return {
            "conv_buffer": self.local.warmup_buffer(h),
            "sparse_caches": [head.warmup_cache(h) for head in self.sparse_heads],
        }

    def forward_incremental(self, x_step: Tensor, cache: dict) -> tuple[Tensor, dict]:
        """Run one decoding step through the block.

        Args:
            x_step: Input for the new token only.
            cache: Dict with "conv_buffer" and "sparse_caches" as produced by
                ``empty_cache`` / ``warmup_cache`` or a previous step.

        Returns:
            ``(out_step, new_cache)``. ``out_step`` already includes the
            residual ``x_step``, so the caller must not add it again.
            Unlike ``forward``, this does not update ``aux_loss`` or
            ``_router_entropy``.
        """
        h = self.norm(x_step)
        r = torch.softmax(self.router(h), dim=-1)

        # Local conv: consumes the sliding buffer, returns the updated one.
        out_local, new_conv_buf = self.local.forward_incremental(h, cache["conv_buffer"])

        # Sparse heads: each one advances its own cache.
        head_outs = []
        new_sparse_caches = []
        for head, head_cache in zip(self.sparse_heads, cache["sparse_caches"]):
            y_head, head_cache_new = head.forward_incremental(h, head_cache)
            head_outs.append(y_head)
            new_sparse_caches.append(head_cache_new)
        out_sparse = self._merge_heads(head_outs)

        # Dense FFN is stateless, so it needs no cache.
        out_dense = self.dense(h)

        new_cache = {
            "conv_buffer": new_conv_buf,
            "sparse_caches": new_sparse_caches,
        }
        return x_step + self._mix(r, out_local, out_sparse, out_dense), new_cache
class TernaryEmbeddingLite(nn.Module):
    """Token-id lookup into a learned table that is optionally ternarized.

    The table is stored as a full-precision parameter. When ``quantize`` is
    true, every forward call passes it through
    ``tilelli.core.ternary.ternarize`` before indexing (per the original
    author's note: values in {-1, 0, +1} with a per-tensor scale). Otherwise
    the raw parameter is indexed directly.

    Args:
        vocab_size: Number of rows in the table.
        d_model: Width of each embedding vector.
        quantize: Whether to ternarize the table at forward time.
    """

    def __init__(self, vocab_size: int, d_model: int, quantize: bool = True) -> None:
        super().__init__()
        self.vocab_size, self.d_model, self.quantize = vocab_size, d_model, quantize
        # Gaussian init scaled by 1 / sqrt(d_model).
        init_scale = 1.0 / d_model**0.5
        self.weight = nn.Parameter(torch.randn(vocab_size, d_model) * init_scale)

    def forward(self, ids: Tensor) -> Tensor:
        """Return the (possibly ternarized) table rows selected by ``ids``."""
        table = self.weight
        if self.quantize:
            # Imported lazily, exactly as before, so the non-quantized path
            # never touches the ternary module.
            from tilelli.core.ternary import ternarize
            table = ternarize(table)
        return table[ids]
class TilelliLiteLM(nn.Module):
    """Byte-level LM: TilelliLiteBlock stack plus learned positional embedding.

    Pipeline: token embedding + positional embedding -> ``n_layers`` blocks
    -> final LayerNorm -> unembedding to ``vocab_size`` logits.

    Args:
        vocab_size: Number of token ids (256 by default, i.e. raw bytes).
        d_model: Width of the residual stream.
        n_layers: Number of TilelliLiteBlock layers.
        n_heads, kernel_size, top_k, ffn_expand, load_balance_weight:
            Forwarded unchanged to every block.
        max_seq_len: Size of the positional-embedding table; longer
            sequences are rejected with ``ValueError``.
        quantize: Forwarded to the embedding, the blocks and the unembedding.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 384,
        n_layers: int = 8,
        n_heads: int = 8,
        kernel_size: int = 5,
        top_k: int = 16,
        ffn_expand: int = 4,
        max_seq_len: int = 2048,
        quantize: bool = True,
        load_balance_weight: float = 0.01,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.quantize = quantize
        self.embed = TernaryEmbeddingLite(vocab_size, d_model, quantize=quantize)
        # Learned positional embedding: a plain nn.Embedding even in ternary
        # mode, so position information is not quantized away.
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        nn.init.normal_(self.pos_embed.weight, std=0.02)
        self.blocks = nn.ModuleList([
            TilelliLiteBlock(
                d_model=d_model, n_heads=n_heads, kernel_size=kernel_size,
                top_k=top_k, ffn_expand=ffn_expand, quantize=quantize,
                load_balance_weight=load_balance_weight,
            )
            for _ in range(n_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)
        self.unembed = TernaryLinear(d_model, vocab_size, quantize=quantize)

    def _embed_with_pos(self, ids: Tensor) -> Tensor:
        """Embed ``ids`` and add positional embeddings for positions 0..L-1.

        Raises:
            ValueError: If the sequence is longer than ``max_seq_len``.
        """
        seq_len = ids.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(f"sequence length {seq_len} > max_seq_len {self.max_seq_len}")
        positions = torch.arange(seq_len, device=ids.device)
        return self.embed(ids) + self.pos_embed(positions)

    def forward(self, ids: Tensor) -> Tensor:
        """Return next-token logits for every position of ``ids``.

        Raises:
            ValueError: If the sequence is longer than ``max_seq_len``.
        """
        x = self._embed_with_pos(ids)
        for blk in self.blocks:
            x = blk(x)
        return self.unembed(self.final_norm(x))

    def loss(self, ids: Tensor, targets: Tensor | None = None) -> Tensor:
        """Autoregressive cross-entropy plus the blocks' load-balance losses.

        Two calling conventions are supported:

        * ``loss(ids)`` - shift internally: inputs are ``ids[:, :-1]`` and
          targets are ``ids[:, 1:]``.
        * ``loss(ids, targets)`` - the caller supplies aligned targets
          (train.py-style); both are used as given.

        Raises:
            ValueError: If ``targets`` is None and the sequence has fewer
                than two tokens.
        """
        if targets is None:
            if ids.size(1) < 2:
                raise ValueError("loss needs sequence length >= 2")
            inp, tgt = ids[:, :-1], ids[:, 1:]
        else:
            inp, tgt = ids, targets
        logits = self(inp)
        ce = torch.nn.functional.cross_entropy(
            logits.reshape(-1, self.vocab_size),
            tgt.reshape(-1),
        )
        # Each block's aux_loss was refreshed by the forward pass above.
        aux = sum(blk.aux_loss for blk in self.blocks)
        return ce + aux

    @torch.no_grad()
    def router_entropies(self, ids: Tensor) -> Tensor:
        """Per-layer router entropy, stacked along a new leading layer axis.

        Note: this runs every block's forward, which overwrites the blocks'
        cached ``aux_loss`` values.

        Raises:
            ValueError: If the sequence is longer than ``max_seq_len``.
        """
        x = self._embed_with_pos(ids)
        ents = []
        for blk in self.blocks:
            ents.append(blk.router_entropy(x))
            x = blk(x)
        return torch.stack(ents, dim=0)

    # -- Incremental generation with per-layer caches ---------------------- #
    # Each decoding step runs one forward pass over a SINGLE new token, using
    # the caches each block maintains (attention caches and a sliding conv
    # buffer) instead of re-running the whole sequence.
    #
    # Correctness: the original author states this matches the non-cached
    # forward at the final position up to float-ordering noise, verified by
    # tests/test_kv_cache_parity.py.

    @torch.no_grad()
    def warmup_caches(self, ids: Tensor) -> tuple[Tensor, list[dict]]:
        """Run the full prompt and build one cache per layer.

        Returns:
            ``(hidden, caches)``. ``hidden`` is the output of the last block
            for ALL prompt positions, BEFORE ``final_norm``; callers that
            want next-token logits must slice the last position and apply
            ``final_norm`` and ``unembed`` themselves. ``caches[i]`` is the
            cache built by block ``i`` from the input it received.

        Raises:
            ValueError: If the sequence is longer than ``max_seq_len``.
        """
        x = self._embed_with_pos(ids)
        caches = []
        for blk in self.blocks:
            # The cache must be built from the block's INPUT, so it is
            # created before x is replaced by the block's output.
            caches.append(blk.warmup_cache(x))
            x = blk(x)
        return x, caches

    @torch.no_grad()
    def step_with_cache(self, next_id: Tensor, pos_index: int,
                        caches: list[dict]) -> tuple[Tensor, list[dict]]:
        """Run ONE new token at absolute position ``pos_index``.

        Args:
            next_id: Ids of the new token, one per batch row.
            pos_index: Absolute position of that token; must lie in
                ``[0, max_seq_len)``.
            caches: Per-layer caches from ``warmup_caches`` or a previous
                step.

        Returns:
            ``(logits, new_caches)``; ``new_caches`` is a fresh list holding
            the caches returned by each block.

        Raises:
            ValueError: If ``pos_index`` is outside the positional table.
        """
        if not 0 <= pos_index < self.max_seq_len:
            raise ValueError(
                f"pos_index {pos_index} out of range for max_seq_len {self.max_seq_len}"
            )
        position = torch.tensor([pos_index], device=next_id.device)
        x = self.embed(next_id) + self.pos_embed(position)
        new_caches = []
        for blk, layer_cache in zip(self.blocks, caches):
            x, layer_cache_new = blk.forward_incremental(x, layer_cache)
            new_caches.append(layer_cache_new)
        return self.unembed(self.final_norm(x)), new_caches

    @torch.no_grad()
    def generate_with_cache(
        self,
        ids: Tensor,
        n_new_tokens: int,
        stop_ids: tuple[int, ...] = (10, 0),
        return_logits: bool = False,
    ) -> tuple[Tensor, list[int], list[float]]:
        """Greedy-generate up to ``n_new_tokens`` tokens using the caches.

        Only a single sequence is supported: token ids are returned as plain
        Python ints. For batched or non-greedy sampling, drive
        ``step_with_cache`` directly.

        Args:
            ids: Prompt ids with batch size 1 and at least one token.
            n_new_tokens: Maximum number of tokens to generate.
            stop_ids: Generation stops right after emitting any of these ids
                (defaults 10 and 0, i.e. newline and NUL when ids are raw
                bytes). The stop token is included in the output.
            return_logits: Accepted for interface compatibility; currently
                ignored.

        Returns:
            ``(full_ids, generated, confidences)``: the prompt with the
            generated ids appended, the generated ids as ints, and the
            softmax probability of each chosen id.

        Raises:
            ValueError: If tokens are requested for a batch size other than
                1 or for an empty prompt, or if the prompt is longer than
                ``max_seq_len``.
        """
        if n_new_tokens > 0:
            # These inputs could never produce a token; fail with a clear
            # message instead of an opaque scalar-conversion / index error.
            if ids.size(0) != 1:
                raise ValueError(
                    f"generate_with_cache supports batch size 1, got {ids.size(0)}"
                )
            if ids.size(1) == 0:
                raise ValueError("generate_with_cache needs a non-empty prompt")

        was_training = self.training
        self.eval()
        try:
            # Warm the caches on the prompt, then turn the last position's
            # hidden state into the first set of next-token logits.
            hidden, caches = self.warmup_caches(ids)
            logits = self.unembed(self.final_norm(hidden[:, -1:, :]))
            cur_pos = ids.size(1)  # position the next generated token will occupy
            full = ids
            generated: list[int] = []
            confs: list[float] = []
            for step in range(n_new_tokens):
                probs = torch.softmax(logits[:, -1, :], dim=-1)
                next_id = probs.argmax(dim=-1, keepdim=True)
                nid_int = int(next_id)
                confs.append(float(probs.max()))
                generated.append(nid_int)
                full = torch.cat([full, next_id], dim=1)
                if nid_int in stop_ids:
                    break
                if step + 1 == n_new_tokens:
                    # Last requested token: logits for a further position
                    # would never be used, so skip that forward step.
                    break
                if cur_pos + 1 > self.max_seq_len:
                    # The new token would sit past the positional table.
                    break
                logits, caches = self.step_with_cache(next_id, cur_pos, caches)
                cur_pos += 1
            return full, generated, confs
        finally:
            # Restore the caller's training mode even if generation failed.
            if was_training:
                self.train()