"""Vanilla pre-norm Transformer baseline.

A minimal, faithful pre-norm Transformer that uses the same byte-level
tokenizer, the same maximum sequence length, and the same parameter budget as
the public ``TilelliLM`` config. It exists solely for the param-matched
"beat vanilla" comparison on which the project's headline claim rests.

This is the textbook decoder block: multi-head causal attention plus a GELU
FFN at 4× expansion, each wrapped in a pre-norm residual. No FlashAttention,
no rotary embeddings, no mixture-of-experts — anything more would muddy the
comparison. The question being asked is: at the same parameter count and on
the same data, does the heterogeneous-pathway block beat the standard one?
"""

from __future__ import annotations

import math

import torch
from torch import Tensor, nn
from torch.nn import functional as F


class VanillaBlock(nn.Module):
    """One pre-norm Transformer decoder block.

    Standard layout::

        x → LayerNorm → causal MHA → +x
        x → LayerNorm → GELU FFN(expand×) → +x

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        expand: FFN hidden-width multiplier (hidden = ``expand * d_model``).

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expand: int = 4,
    ) -> None:
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model {d_model} not divisible by n_heads {n_heads}"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # NOTE: construction order is kept stable so that state_dict keys and
        # seeded parameter initialization stay reproducible across versions.
        self.norm1 = nn.LayerNorm(d_model)
        # Fused Q/K/V projection, split into three along the feature axis.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff_up = nn.Linear(d_model, expand * d_model, bias=False)
        self.ff_down = nn.Linear(expand * d_model, d_model, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Apply one causal pre-norm block.

        Args:
            x: Tensor unpacked as ``(B, L, D)``; ``D`` must equal ``d_model``
                for the head reshape below to succeed.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, L, D = x.shape

        # --- attention sub-layer ---------------------------------------
        h = self.norm1(x)
        # Project once, then split into Q/K/V and heads: (B, L, 3, H, Dh).
        qkv = self.qkv(h).view(B, L, 3, self.n_heads, self.d_head)
        q, k, v = qkv.unbind(dim=2)
        # (B, L, H, Dh) -> (B, H, L, Dh) so the matmuls batch over heads.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        # True strictly above the diagonal: query i may not attend to key j > i.
        # The diagonal itself is never masked, so no softmax row is all -inf.
        mask = torch.triu(
            torch.ones(L, L, device=x.device, dtype=torch.bool),
            diagonal=1,
        )
        scores = scores.masked_fill(mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        # Merge heads back: (B, H, L, Dh) -> (B, L, H * Dh) == (B, L, D).
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, L, D)
        x = x + self.proj(out)

        # --- feed-forward sub-layer ------------------------------------
        h = self.norm2(x)
        return x + self.ff_down(F.gelu(self.ff_up(h)))


class VanillaLM(nn.Module):
    """Byte-level vanilla Transformer LM.

    Mirrors the ``TilelliLM`` interface (``forward``, ``loss``, ``generate``,
    ``parameter_count``) so the trainer can swap one for the other.

    Args:
        vocab_size: Number of token ids (256 for raw bytes).
        d_model: Model width.
        n_layers: Number of ``VanillaBlock`` layers.
        n_heads: Attention heads per block.
        expand: FFN hidden-width multiplier.
        max_seq_len: Size of the learned positional table; longer inputs to
            ``forward`` are rejected.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 384,
        n_layers: int = 6,
        n_heads: int = 6,
        expand: int = 4,
        max_seq_len: int = 512,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len

        # NOTE: construction order is kept stable (state_dict keys / seeded init).
        self.token_emb = nn.Embedding(vocab_size, d_model)
        # Learned absolute positions, one row per position < max_seq_len.
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList(
            [VanillaBlock(d_model, n_heads, expand) for _ in range(n_layers)]
        )
        # Final norm is required in a pre-norm stack: the residual stream
        # leaving the last block has not been normalized.
        self.norm_out = nn.LayerNorm(d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, ids: Tensor) -> Tensor:
        """Compute next-token logits for every position.

        Args:
            ids: Integer token ids of shape ``(B, L)`` with ``L <= max_seq_len``.

        Returns:
            Logits of shape ``(B, L, vocab_size)``.

        Raises:
            ValueError: If ``ids`` is not 2-D or ``L`` exceeds ``max_seq_len``.
        """
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        _, L = ids.shape
        if L > self.max_seq_len:
            raise ValueError(
                f"sequence length {L} exceeds max_seq_len {self.max_seq_len}"
            )
        positions = torch.arange(L, device=ids.device)
        # Broadcast the (L, D) positional rows across the batch dimension.
        x = self.token_emb(ids) + self.pos_emb(positions)[None, :, :]
        for block in self.blocks:
            x = block(x)
        return self.unembed(self.norm_out(x))

    def loss(self, ids: Tensor, targets: Tensor) -> Tensor:
        """Mean cross-entropy between ``forward(ids)`` and ``targets``.

        ``targets`` is flattened and must hold one class id per logit row;
        presumably the caller passes ``ids`` shifted by one — TODO confirm.
        """
        logits = self.forward(ids)
        return F.cross_entropy(
            logits.reshape(-1, self.vocab_size), targets.reshape(-1)
        )

    @torch.no_grad()
    def generate(self, ids: Tensor, n_new_tokens: int) -> Tensor:
        """Greedily (argmax) append ``n_new_tokens`` tokens to ``ids``.

        Only the most recent ``max_seq_len`` tokens are fed to the model at
        each step. The module is put in eval mode for the duration and its
        previous training flag is restored afterwards, even on error.

        Args:
            ids: Prompt token ids of shape ``(B, L)`` with ``L >= 1``.
            n_new_tokens: Number of tokens to append; ``<= 0`` returns ``ids``
                unchanged.

        Returns:
            Tensor of shape ``(B, L + max(n_new_tokens, 0))``.

        Raises:
            ValueError: If ``ids`` is not 2-D or the prompt is empty (there is
                no last position to predict from).
        """
        if n_new_tokens <= 0:
            return ids
        # Validate up front: otherwise these inputs fail with an opaque
        # IndexError from the slicing/indexing below.
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        if ids.size(1) == 0:
            raise ValueError("generate requires a non-empty prompt (L >= 1)")

        was_training = self.training
        self.eval()
        try:
            for _ in range(n_new_tokens):
                # Crop to the positional-table size so forward() never rejects.
                ids_in = ids[:, -self.max_seq_len:]
                logits = self.forward(ids_in)[:, -1, :]
                next_id = logits.argmax(dim=-1, keepdim=True)
                ids = torch.cat([ids, next_id], dim=1)
            return ids
        finally:
            if was_training:
                self.train()

    def parameter_count(self) -> int:
        """Total number of elements across all parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())