# File size: 4,884 Bytes
# f86dc09
"""Vanilla pre-norm Transformer baseline.
A minimal, faithful pre-norm Transformer at the same byte-level tokenizer,
same max sequence length, and same parameter budget as the public
``TilelliLM`` config. Used solely for the param-matched "beat vanilla"
comparison the project's headline claim rests on.
This is the textbook decoder block: multi-head causal attention + GELU FFN
at 4× expansion, both wrapped in pre-norm residuals. No FlashAttention,
no rotary, no mixture-of-experts — anything more would muddy the
comparison. The point is to ask: at the same param count and the same
data, does the heterogeneous-pathway block beat the standard one?
"""
from __future__ import annotations
import math
import torch
from torch import Tensor, nn
from torch.nn import functional as F
class VanillaBlock(nn.Module):
    """One pre-norm Transformer decoder block.

    Standard layout:
        x → LayerNorm → causal MHA → +x
        x → LayerNorm → GELU FFN(4×) → +x

    Args:
        d_model: Width of the residual stream.
        n_heads: Number of attention heads; must divide ``d_model``.
        expand: Multiplier for the feed-forward hidden width.

    Raises:
        ValueError: If ``d_model`` or ``n_heads`` is not positive, or if
            ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expand: int = 4,
    ) -> None:
        super().__init__()
        # Reject non-positive sizes up front. n_heads == 0 would otherwise
        # escape as a ZeroDivisionError from the modulo below, and a negative
        # n_heads that divides d_model would pass the divisibility check and
        # only blow up later inside forward()'s view().
        if d_model <= 0:
            raise ValueError(f"d_model must be positive, got {d_model}")
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model {d_model} not divisible by n_heads {n_heads}"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Attention sub-layer: one fused q/k/v projection plus output
        # projection, all bias-free.
        self.norm1 = nn.LayerNorm(d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)

        # Feed-forward sub-layer: d_model -> expand * d_model -> d_model.
        self.norm2 = nn.LayerNorm(d_model)
        self.ff_up = nn.Linear(d_model, expand * d_model, bias=False)
        self.ff_down = nn.Linear(expand * d_model, d_model, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Apply causal self-attention and the feed-forward layer to ``x``.

        Args:
            x: Tensor unpacked as ``(B, L, D)``; the reshapes below require
                ``D == d_model``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, L, D = x.shape

        # --- attention sub-layer (pre-norm residual) ---
        h = self.norm1(x)
        # Split the fused projection into (q/k/v, head, per-head feature).
        qkv = self.qkv(h).view(B, L, 3, self.n_heads, self.d_head)
        q, k, v = qkv.unbind(dim=2)
        # (B, L, H, d_head) -> (B, H, L, d_head) so matmul batches over heads.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        # True strictly above the diagonal, i.e. where key index > query
        # index; those entries get -inf so softmax assigns them zero weight.
        mask = torch.triu(
            torch.ones(L, L, device=x.device, dtype=torch.bool),
            diagonal=1,
        )
        scores = scores.masked_fill(mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        # Merge heads back: (B, H, L, d_head) -> (B, L, D).
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, L, D)
        x = x + self.proj(out)

        # --- feed-forward sub-layer (pre-norm residual) ---
        h = self.norm2(x)
        return x + self.ff_down(F.gelu(self.ff_up(h)))
class VanillaLM(nn.Module):
    """Byte-level vanilla Transformer LM.

    Mirrors ``TilelliLM`` interface (``forward``, ``loss``, ``generate``,
    ``parameter_count``) so the trainer can swap one for the other.

    Args:
        vocab_size: Number of token ids; also the width of the output logits.
        d_model: Width of the residual stream.
        n_layers: Number of stacked ``VanillaBlock`` layers.
        n_heads: Attention heads per block.
        expand: Feed-forward width multiplier per block.
        max_seq_len: Size of the learned position table; ``forward`` rejects
            longer inputs.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 384,
        n_layers: int = 6,
        n_heads: int = 6,
        expand: int = 4,
        max_seq_len: int = 512,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len

        # Learned token and absolute-position embeddings, summed in forward().
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList(
            [VanillaBlock(d_model, n_heads, expand) for _ in range(n_layers)]
        )
        # Final norm plus an output projection that is a separate parameter
        # from token_emb.
        self.norm_out = nn.LayerNorm(d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, ids: Tensor) -> Tensor:
        """Return next-token logits for every position of ``ids``.

        Args:
            ids: Token ids of shape ``(B, L)`` with ``L <= max_seq_len``.

        Returns:
            Logits of shape ``(B, L, vocab_size)``.

        Raises:
            ValueError: If ``ids`` is not 2-D or is longer than
                ``max_seq_len``.
        """
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        B, L = ids.shape
        if L > self.max_seq_len:
            raise ValueError(
                f"sequence length {L} exceeds max_seq_len {self.max_seq_len}"
            )
        positions = torch.arange(L, device=ids.device)
        # The position embedding is (L, d_model); add a leading axis so it
        # broadcasts over the batch.
        x = self.token_emb(ids) + self.pos_emb(positions)[None, :, :]
        for block in self.blocks:
            x = block(x)
        return self.unembed(self.norm_out(x))

    def loss(self, ids: Tensor, targets: Tensor) -> Tensor:
        """Return the cross-entropy between ``forward(ids)`` and ``targets``.

        Logits and targets are flattened over all leading dimensions before
        being passed to ``F.cross_entropy``.
        """
        logits = self.forward(ids)
        return F.cross_entropy(
            logits.reshape(-1, self.vocab_size), targets.reshape(-1)
        )

    @torch.no_grad()
    def generate(self, ids: Tensor, n_new_tokens: int) -> Tensor:
        """Greedily extend ``ids`` by ``n_new_tokens`` tokens.

        Each step feeds the last ``max_seq_len`` tokens through ``forward``
        and appends the argmax of the final position's logits. The module is
        put in eval mode for the duration and returned to training mode
        afterwards if it was training on entry.

        Args:
            ids: Prompt token ids of shape ``(B, L)``.
            n_new_tokens: Number of tokens to append; values ``<= 0`` return
                ``ids`` unchanged.

        Returns:
            Tensor of shape ``(B, L + n_new_tokens)``.

        Raises:
            ValueError: If tokens are requested and ``ids`` is not 2-D or the
                prompt is empty.
        """
        # Validate before touching the train/eval state. Without these checks
        # a 1-D prompt fails on the slice and an empty prompt fails on the
        # ``[:, -1, :]`` lookup, both as opaque IndexErrors. Validation is
        # skipped when nothing is generated so that case keeps returning the
        # input as-is.
        if n_new_tokens > 0:
            if ids.dim() != 2:
                raise ValueError(
                    f"expected (B, L), got shape {tuple(ids.shape)}"
                )
            if ids.shape[1] == 0:
                raise ValueError(
                    "generate requires a non-empty prompt, "
                    f"got shape {tuple(ids.shape)}"
                )
        was_training = self.training
        self.eval()
        try:
            for _ in range(n_new_tokens):
                # Crop the context to what the position table can address.
                ids_in = ids[:, -self.max_seq_len:]
                logits = self.forward(ids_in)[:, -1, :]
                next_id = logits.argmax(dim=-1, keepdim=True)
                ids = torch.cat([ids, next_id], dim=1)
            return ids
        finally:
            if was_training:
                self.train()

    def parameter_count(self) -> int:
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())
|