Tilelli-llm / src /tilelli /baselines /vanilla.py
TilelliLab's picture
Mirror small files (code, paper, results)
f86dc09 verified
Raw
History Blame Contribute Delete
4.88 kB
"""Vanilla pre-norm Transformer baseline.
A minimal, faithful pre-norm Transformer at the same byte-level tokenizer,
same max sequence length, and same parameter budget as the public
``TilelliLM`` config. Used solely for the param-matched "beat vanilla"
comparison the project's headline claim rests on.
This is the textbook decoder block: multi-head causal attention + GELU FFN
at 4× expansion, both wrapped in pre-norm residuals. No FlashAttention,
no rotary, no mixture-of-experts — anything more would muddy the
comparison. The point is to ask: at the same param count and the same
data, does the heterogeneous-pathway block beat the standard one?
"""
from __future__ import annotations
import math
import torch
from torch import Tensor, nn
from torch.nn import functional as F
class VanillaBlock(nn.Module):
    """Single pre-norm Transformer decoder layer.

    Two residual sub-layers are applied in sequence, each normalising its
    input before transforming it:

        x = x + proj(causal_self_attention(norm1(x)))
        x = x + ff_down(gelu(ff_up(norm2(x))))

    All linear projections are bias-free.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expand: int = 4,
    ) -> None:
        """Create the attention and feed-forward sub-modules.

        Args:
            d_model: Width of the residual stream.
            n_heads: Number of attention heads; must divide ``d_model``.
            expand: Feed-forward hidden width as a multiple of ``d_model``.

        Raises:
            ValueError: If ``d_model`` is not a multiple of ``n_heads``.
        """
        super().__init__()
        head_dim, leftover = divmod(d_model, n_heads)
        if leftover:
            raise ValueError(
                f"d_model {d_model} not divisible by n_heads {n_heads}"
            )
        hidden = expand * d_model

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = head_dim

        # Attention sub-layer: one fused projection produces q, k and v.
        self.norm1 = nn.LayerNorm(d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)

        # Feed-forward sub-layer.
        self.norm2 = nn.LayerNorm(d_model)
        self.ff_up = nn.Linear(d_model, hidden, bias=False)
        self.ff_down = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to ``x`` of shape (B, L, d_model); shape is preserved."""
        batch, seq_len, width = x.shape

        # Fused projection laid out as (3, B, heads, L, d_head), then split
        # along the leading axis into queries, keys and values.
        fused = self.qkv(self.norm1(x))
        fused = fused.view(batch, seq_len, 3, self.n_heads, self.d_head)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product scores; strictly-future keys are blocked so each
        # position attends only to itself and earlier positions.
        scores = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.d_head)
        pos = torch.arange(seq_len, device=x.device)
        is_future = pos[None, :] > pos[:, None]
        scores = scores.masked_fill(is_future, -math.inf)
        weights = torch.softmax(scores, dim=-1)

        # Merge the heads back into a single (B, L, d_model) tensor.
        mixed = (weights @ values).transpose(1, 2).reshape(batch, seq_len, width)
        x = x + self.proj(mixed)

        # Position-wise GELU MLP on the normalised residual stream.
        ff_hidden = F.gelu(self.ff_up(self.norm2(x)))
        return x + self.ff_down(ff_hidden)
class VanillaLM(nn.Module):
    """Byte-level vanilla Transformer language model.

    Token embeddings plus learned absolute position embeddings feed a stack
    of ``VanillaBlock`` layers, followed by a final LayerNorm and a bias-free
    linear projection to vocabulary logits (separate from the input
    embedding, i.e. not weight-tied).

    Exposes ``forward``, ``loss``, ``generate`` and ``parameter_count``; this
    is intended to mirror the ``TilelliLM`` interface so a trainer can swap
    one model for the other.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 384,
        n_layers: int = 6,
        n_heads: int = 6,
        expand: int = 4,
        max_seq_len: int = 512,
    ) -> None:
        """Build embeddings, the block stack and the output head.

        Args:
            vocab_size: Number of distinct token ids.
            d_model: Width of the residual stream.
            n_layers: Number of ``VanillaBlock`` layers.
            n_heads: Attention heads per block; must divide ``d_model``.
            expand: Feed-forward expansion factor passed to each block.
            max_seq_len: Longest sequence ``forward`` accepts; also the size
                of the learned position-embedding table.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList(
            [VanillaBlock(d_model, n_heads, expand) for _ in range(n_layers)]
        )
        self.norm_out = nn.LayerNorm(d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, ids: Tensor) -> Tensor:
        """Compute next-token logits for every position.

        Args:
            ids: Integer token ids of shape (B, L) with ``L <= max_seq_len``.

        Returns:
            Logits of shape (B, L, vocab_size).

        Raises:
            ValueError: If ``ids`` is not 2-D or is longer than
                ``max_seq_len``.
        """
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        B, L = ids.shape
        if L > self.max_seq_len:
            raise ValueError(
                f"sequence length {L} exceeds max_seq_len {self.max_seq_len}"
            )
        positions = torch.arange(L, device=ids.device)
        # Position embeddings are (L, d_model); broadcast them over the batch.
        x = self.token_emb(ids) + self.pos_emb(positions)[None, :, :]
        for block in self.blocks:
            x = block(x)
        return self.unembed(self.norm_out(x))

    def loss(self, ids: Tensor, targets: Tensor) -> Tensor:
        """Cross-entropy between the model's logits and ``targets``.

        Args:
            ids: Input token ids of shape (B, L).
            targets: Target token ids holding one entry per input position
                (flattened to ``B * L`` values).

        Returns:
            Scalar loss tensor (``F.cross_entropy`` with its default
            reduction).
        """
        logits = self.forward(ids)
        return F.cross_entropy(
            logits.reshape(-1, self.vocab_size), targets.reshape(-1)
        )

    @torch.no_grad()
    def generate(self, ids: Tensor, n_new_tokens: int) -> Tensor:
        """Greedily extend ``ids`` by ``n_new_tokens`` tokens.

        Each step feeds at most the last ``max_seq_len`` tokens to the model
        and appends the arg-max token. The module is switched to eval mode
        during generation and its previous train/eval mode is restored
        afterwards, even if an error occurs.

        Args:
            ids: Prompt token ids of shape (B, L) with ``L >= 1``.
            n_new_tokens: Number of tokens to append; values ``<= 0`` return
                ``ids`` unchanged.

        Returns:
            Tensor of shape (B, L + n_new_tokens): the prompt followed by the
            generated tokens.

        Raises:
            ValueError: If tokens are requested and ``ids`` is not 2-D or the
                prompt is empty.
        """
        if n_new_tokens <= 0:
            return ids
        # Validate before slicing: otherwise a malformed prompt surfaces as
        # an opaque IndexError instead of the ValueError ``forward`` uses.
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        if ids.size(1) == 0:
            raise ValueError(
                "cannot generate from an empty prompt; expected (B, L) with L >= 1"
            )
        was_training = self.training
        self.eval()
        try:
            for _ in range(n_new_tokens):
                # Crop to the model's context window.
                ids_in = ids[:, -self.max_seq_len:]
                logits = self.forward(ids_in)[:, -1, :]
                next_id = logits.argmax(dim=-1, keepdim=True)
                ids = torch.cat([ids, next_id], dim=1)
            return ids
        finally:
            self.train(was_training)

    def parameter_count(self) -> int:
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())