# Tilelli-llm / src/tilelli/core/tilelli_lm.py
# TilelliLab's picture
# Mirror small files (code, paper, results)
# f86dc09 verified
# Raw
# History Blame Contribute Delete
# 4.69 kB
"""tilelli.core.tilelli_lm — minimal byte-level language model built on
ternary primitives + heterogeneous-pathway blocks.
Stacks TilelliBlock layers on top of a byte embedding and a ternary
unembedding, plus a learned positional embedding.
"""
from __future__ import annotations
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from tilelli.core.ternary_linear import TernaryLinear
from tilelli.core.tilelli_block import TilelliBlock
class TilelliLM(nn.Module):
    """Byte-level Tilelli language model.

    Pipeline (see ``__init__`` / ``forward``): token embedding + learned
    positional embedding -> ``n_layers`` stacked ``TilelliBlock`` layers ->
    final ``LayerNorm`` -> ``TernaryLinear`` unembedding to ``vocab_size``
    logits.

    Args:
        vocab_size: Number of token ids; the default of 256 covers raw bytes.
        d_model: Width of the embeddings / residual stream.
        n_layers: Number of ``TilelliBlock`` layers.
        d_head, top_k, pathways, dense_expand, fp_attention, top_k_routing:
            Forwarded unchanged to every ``TilelliBlock``.
        max_seq_len: Size of the positional-embedding table. Longer inputs
            are rejected by ``forward`` / ``router_entropies`` and cropped
            to the most recent ``max_seq_len`` tokens by ``generate``.
        quantize, per_row, hadamard, lsq: Forwarded to every
            ``TilelliBlock`` and to the ``TernaryLinear`` unembedding.
        n_banks: Forwarded to every ``TilelliBlock``. When greater than 1,
            ``loss`` adds the summed per-block ``aux_loss``.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 128,
        n_layers: int = 4,
        d_head: int = 32,
        top_k: int = 8,
        pathways: int = 5,
        max_seq_len: int = 512,
        quantize: bool = True,
        n_banks: int = 1,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
        dense_expand: int = 2,
        fp_attention: bool = False,
        top_k_routing: int = 0,
    ) -> None:
        super().__init__()
        # Keep the hyper-parameters on the instance for introspection.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.quantize = quantize
        self.n_banks = n_banks
        self.per_row = per_row
        self.hadamard = hadamard
        self.lsq = lsq
        self.dense_expand = dense_expand
        self.fp_attention = fp_attention
        self.top_k_routing = top_k_routing
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList(
            [
                TilelliBlock(
                    d_model=d_model,
                    d_head=d_head,
                    top_k=top_k,
                    pathways=pathways,
                    n_banks=n_banks,
                    quantize=quantize,
                    per_row=per_row,
                    hadamard=hadamard,
                    lsq=lsq,
                    dense_expand=dense_expand,
                    fp_attention=fp_attention,
                    top_k_routing=top_k_routing,
                )
                for _ in range(n_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(d_model)
        self.unembed = TernaryLinear(
            d_model, vocab_size,
            quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq,
        )

    def _embed(self, ids: Tensor) -> Tensor:
        """Validate ``ids`` and return token + positional embeddings.

        Shared by ``forward`` and ``router_entropies`` so both enforce the
        same shape and length checks.

        Raises:
            ValueError: if ``ids`` is not 2-D or longer than ``max_seq_len``.
        """
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        seq_len = ids.size(1)
        if seq_len > self.max_seq_len:
            # Checked explicitly: pos_emb would otherwise fail with an opaque
            # IndexError (or a device-side assert on GPU).
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )
        positions = torch.arange(seq_len, device=ids.device)
        # (L, d_model) -> (1, L, d_model) so it broadcasts over the batch.
        return self.token_emb(ids) + self.pos_emb(positions)[None, :, :]

    def forward(self, ids: Tensor) -> Tensor:
        """Map token ids of shape ``(B, L)`` to logits ``(B, L, vocab_size)``.

        Raises:
            ValueError: if ``ids`` is not 2-D or ``L > max_seq_len``.
        """
        x = self._embed(ids)
        for block in self.blocks:
            x = block(x)
        x = self.norm_out(x)
        return self.unembed(x)

    @property
    def aux_loss(self) -> Tensor:
        """Sum of per-block load-balancing aux losses.

        This is a zero tensor when ``n_banks <= 1`` or when there are no
        blocks. It reads each block's ``aux_loss`` attribute, which is
        presumably refreshed by the block's most recent forward pass
        (``loss`` calls ``forward`` first) — confirm in ``TilelliBlock``.
        """
        if self.n_banks <= 1 or len(self.blocks) == 0:
            # Always return a tensor: sum() over zero blocks would yield int 0.
            return torch.tensor(0.0, device=self.token_emb.weight.device)
        return sum(b.aux_loss for b in self.blocks)

    def loss(self, ids: Tensor, targets: Tensor) -> Tensor:
        """Return cross-entropy loss, plus load-balance aux loss when banking is on.

        ``logits`` and ``targets`` are flattened over batch and sequence, so
        ``targets`` must contain one class index per position of ``ids``
        (presumably the next-token ids; the caller supplies the shift).
        """
        logits = self.forward(ids)
        ce = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        if self.n_banks > 1:
            return ce + self.aux_loss
        return ce

    @torch.no_grad()
    def generate(self, ids: Tensor, n_new_tokens: int) -> Tensor:
        """Greedily (argmax) extend ``ids`` of shape ``(B, L)`` by ``n_new_tokens``.

        Each step conditions only on the last ``max_seq_len`` tokens. The
        module is put in eval mode for the duration and its previous training
        mode is restored afterwards. Returns a tensor of shape
        ``(B, L + n_new_tokens)`` whose prefix is the input ``ids``.

        Raises:
            ValueError: if ``ids`` is not 2-D or is empty (``L == 0``) while
                tokens are requested.
        """
        if n_new_tokens <= 0:
            return ids
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        if ids.size(1) == 0:
            # There is no last position to predict from.
            raise ValueError("generate() needs a non-empty prompt (got L=0)")
        was_training = self.training
        self.eval()
        try:
            for _ in range(n_new_tokens):
                # Crop the context to what the positional table can index.
                ids_in = ids[:, -self.max_seq_len:]
                logits = self.forward(ids_in)[:, -1, :]
                next_id = logits.argmax(dim=-1, keepdim=True)
                ids = torch.cat([ids, next_id], dim=1)
            return ids
        finally:
            if was_training:
                self.train()

    @torch.no_grad()
    def router_entropies(self, ids: Tensor) -> list[Tensor]:
        """Return one ``block.router_entropy`` value per block.

        Each block's entropy is evaluated on that block's input activations.
        The module's train/eval mode is not changed here.

        Raises:
            ValueError: if ``ids`` is not 2-D or ``L > max_seq_len``.
        """
        x = self._embed(ids)
        out = []
        for block in self.blocks:
            out.append(block.router_entropy(x))
            x = block(x)
        return out

    def parameter_count(self) -> int:
        """Total number of elements across all parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())