| """Vanilla pre-norm Transformer baseline. |
| |
| A minimal, faithful pre-norm Transformer with the same byte-level tokenizer, |
| same max sequence length, and same parameter budget as the public |
| ``TilelliLM`` config. Used solely for the param-matched "beat vanilla" |
| comparison the project's headline claim rests on. |
| |
| This is the textbook decoder block: multi-head causal attention + GELU FFN |
| at 4× expansion, both wrapped in pre-norm residuals. No FlashAttention, |
| no rotary, no mixture-of-experts — anything more would muddy the |
| comparison. The point is to ask: at the same param count and the same |
| data, does the heterogeneous-pathway block beat the standard one? |
| """ |
| from __future__ import annotations |
|
|
| import math |
|
|
| import torch |
| from torch import Tensor, nn |
| from torch.nn import functional as F |
|
|
|
|
class VanillaBlock(nn.Module):
    """One pre-norm Transformer decoder block.

    Standard layout::

        x → LayerNorm → causal MHA → +x
        x → LayerNorm → GELU FFN(4×) → +x

    Attention is spelled out explicitly (scores → causal mask → softmax)
    instead of going through a fused kernel, keeping the baseline plain.
    Every linear projection is bias-free.

    Args:
        d_model: residual-stream width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        expand: FFN hidden width as a multiple of ``d_model``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expand: int = 4,
    ) -> None:
        super().__init__()
        head_dim, leftover = divmod(d_model, n_heads)
        if leftover:
            raise ValueError(
                f"d_model {d_model} not divisible by n_heads {n_heads}"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = head_dim

        # Attention sub-layer. Submodule creation order is kept fixed: it
        # determines state_dict key order and seeded default initialization.
        self.norm1 = nn.LayerNorm(d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)

        # Feed-forward sub-layer.
        hidden = expand * d_model
        self.norm2 = nn.LayerNorm(d_model)
        self.ff_up = nn.Linear(d_model, hidden, bias=False)
        self.ff_down = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to ``x`` of shape ``(B, L, d_model)``; shape is kept."""
        batch, seq_len, width = x.shape

        # Causal multi-head self-attention, pre-norm residual.
        packed = self.qkv(self.norm1(x)).view(
            batch, seq_len, 3, self.n_heads, self.d_head
        )
        # (B, L, 3, H, Dh) -> (3, B, H, L, Dh), then peel off q / k / v.
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)
        logits = query @ key.transpose(-2, -1) / math.sqrt(self.d_head)
        # True strictly above the diagonal: position i must not see any j > i.
        future = torch.ones(
            seq_len, seq_len, device=x.device, dtype=torch.bool
        ).triu(diagonal=1)
        weights = F.softmax(logits.masked_fill(future, float("-inf")), dim=-1)
        heads_out = weights @ value  # (B, H, L, Dh)
        merged = heads_out.transpose(1, 2).contiguous().view(
            batch, seq_len, width
        )
        x = x + self.proj(merged)

        # Position-wise GELU feed-forward, pre-norm residual.
        return x + self.ff_down(F.gelu(self.ff_up(self.norm2(x))))
|
|
|
|
class VanillaLM(nn.Module):
    """Byte-level vanilla Transformer LM.

    Mirrors ``TilelliLM`` interface (``forward``, ``loss``, ``generate``,
    ``parameter_count``) so the trainer can swap one for the other.

    Layout: token embedding + learned absolute position embedding →
    ``n_layers`` × ``VanillaBlock`` → final LayerNorm → linear unembedding
    to ``vocab_size`` logits (its weight is not shared with ``token_emb``).

    Args:
        vocab_size: number of token ids (256 for raw bytes).
        d_model: residual-stream width.
        n_layers: number of decoder blocks.
        n_heads: attention heads per block; must divide ``d_model``.
        expand: FFN expansion factor passed to each block.
        max_seq_len: longest sequence ``forward`` accepts (size of the
            position-embedding table).
    """

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 384,
        n_layers: int = 6,
        n_heads: int = 6,
        expand: int = 4,
        max_seq_len: int = 512,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len

        # Submodule creation order is kept fixed on purpose: it determines
        # state_dict key order and seeded default initialization.
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList(
            [VanillaBlock(d_model, n_heads, expand) for _ in range(n_layers)]
        )
        self.norm_out = nn.LayerNorm(d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, ids: Tensor) -> Tensor:
        """Return logits of shape ``(B, L, vocab_size)`` for ``ids``.

        Args:
            ids: token ids of shape ``(B, L)`` with ``L <= max_seq_len``.

        Raises:
            ValueError: if ``ids`` is not 2-D or ``L > max_seq_len``.
        """
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        seq_len = ids.shape[1]
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )
        positions = torch.arange(seq_len, device=ids.device)
        # (L, D) position embeddings broadcast over the batch dimension.
        x = self.token_emb(ids) + self.pos_emb(positions)[None, :, :]
        for block in self.blocks:
            x = block(x)
        return self.unembed(self.norm_out(x))

    def loss(self, ids: Tensor, targets: Tensor) -> Tensor:
        """Cross-entropy (default mean reduction) of ``forward(ids)`` vs ``targets``.

        ``targets`` must contain one id per input position, i.e. the same
        number of elements as ``ids``; both are flattened before the loss.
        """
        logits = self.forward(ids)
        return F.cross_entropy(
            logits.reshape(-1, self.vocab_size), targets.reshape(-1)
        )

    @torch.no_grad()
    def generate(self, ids: Tensor, n_new_tokens: int) -> Tensor:
        """Greedily append ``n_new_tokens`` argmax tokens to ``ids``.

        At each step the model sees only the last ``max_seq_len`` tokens.
        The model runs in eval mode and is switched back to training mode
        afterwards if it was training on entry.

        Args:
            ids: prompt of shape ``(B, L)``.
            n_new_tokens: number of tokens to append; ``<= 0`` returns
                ``ids`` unchanged.

        Returns:
            Tensor of shape ``(B, L + n_new_tokens)`` starting with ``ids``.

        Raises:
            ValueError: if tokens are requested for an empty prompt
                (``L == 0``): there is no last position to predict from.
        """
        # Fail early with a clear message instead of an IndexError from
        # ``[:, -1, :]`` on an empty logits tensor.
        if n_new_tokens > 0 and ids.dim() == 2 and ids.size(1) == 0:
            raise ValueError("generate() needs a non-empty prompt, got L=0")
        was_training = self.training
        self.eval()
        try:
            for _ in range(n_new_tokens):
                context = ids[:, -self.max_seq_len:]
                last_logits = self.forward(context)[:, -1, :]
                next_id = last_logits.argmax(dim=-1, keepdim=True)
                ids = torch.cat([ids, next_id], dim=1)
            return ids
        finally:
            if was_training:
                self.train()

    def parameter_count(self) -> int:
        """Total number of parameters, embeddings included."""
        return sum(p.numel() for p in self.parameters())
|
|