File size: 4,884 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""Vanilla pre-norm Transformer baseline.

A minimal, faithful pre-norm Transformer with the same byte-level tokenizer,
the same max sequence length, and the same parameter budget as the public
``TilelliLM`` config. Used solely for the param-matched "beat vanilla"
comparison that the project's headline claim rests on.

This is the textbook decoder block: multi-head causal attention + GELU FFN
at 4× expansion, both wrapped in pre-norm residuals. No FlashAttention,
no rotary, no mixture-of-experts — anything more would muddy the
comparison. The point is to ask: at the same param count and the same
data, does the heterogeneous-pathway block beat the standard one?
"""
from __future__ import annotations

import math

import torch
from torch import Tensor, nn
from torch.nn import functional as F


class VanillaBlock(nn.Module):
    """A single pre-norm Transformer decoder layer.

    Two residual sub-layers are applied in sequence, each reading a
    LayerNorm-ed copy of the running activations:

        x = x + proj(causal_multi_head_attention(norm1(x)))
        x = x + ff_down(gelu(ff_up(norm2(x))))

    Args:
        d_model: Width of the residual stream; must be a multiple of
            ``n_heads``.
        n_heads: Number of attention heads.
        expand: Hidden-width multiplier of the feed-forward sub-layer.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model: int, n_heads: int, expand: int = 4) -> None:
        super().__init__()
        if d_model % n_heads:
            raise ValueError(
                f"d_model {d_model} not divisible by n_heads {n_heads}"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        ff_width = expand * d_model

        # Attention sub-layer: one fused projection yields q, k and v.
        self.norm1 = nn.LayerNorm(d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)

        # Feed-forward sub-layer.
        self.norm2 = nn.LayerNorm(d_model)
        self.ff_up = nn.Linear(d_model, ff_width, bias=False)
        self.ff_down = nn.Linear(ff_width, d_model, bias=False)

    def _causal_attention(self, h: Tensor) -> Tensor:
        """Return the projected causal self-attention output for ``h``.

        ``h`` is the already-normalised input of shape (B, L, D); the
        result has the same shape.
        """
        batch, seq_len, width = h.shape
        fused = self.qkv(h).view(batch, seq_len, 3, self.n_heads, self.d_head)
        # (3, B, heads, L, d_head): unpacking along dim 0 gives q, k and v.
        q, k, v = fused.permute(2, 0, 3, 1, 4)

        logits = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
        # True strictly above the diagonal, i.e. at "future" key positions.
        future = torch.ones(
            seq_len, seq_len, device=h.device, dtype=torch.bool
        ).triu(diagonal=1)
        weights = logits.masked_fill(future, float("-inf")).softmax(dim=-1)

        # Merge the heads back into one (B, L, D) tensor.
        merged = (weights @ v).transpose(1, 2).reshape(batch, seq_len, width)
        return self.proj(merged)

    def _feed_forward(self, h: Tensor) -> Tensor:
        """Return the GELU MLP output for the already-normalised ``h``."""
        return self.ff_down(F.gelu(self.ff_up(h)))

    def forward(self, x: Tensor) -> Tensor:
        """Run both residual sub-layers on ``x`` of shape (B, L, D)."""
        x = x + self._causal_attention(self.norm1(x))
        return x + self._feed_forward(self.norm2(x))


class VanillaLM(nn.Module):
    """Byte-level vanilla Transformer language model.

    Mirrors the ``TilelliLM`` interface (``forward``, ``loss``,
    ``generate``, ``parameter_count``) so the trainer can swap one for the
    other.

    The model sums a learned token embedding and a learned absolute
    position embedding, runs the result through ``n_layers``
    ``VanillaBlock`` modules, applies a final LayerNorm and maps to
    vocabulary logits with a bias-free linear layer (not tied to the
    token embedding).

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Width of the residual stream.
        n_layers: Number of stacked ``VanillaBlock`` modules.
        n_heads: Attention heads per block.
        expand: Feed-forward expansion factor per block.
        max_seq_len: Longest sequence ``forward`` accepts; also the size
            of the position-embedding table.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 384,
        n_layers: int = 6,
        n_heads: int = 6,
        expand: int = 4,
        max_seq_len: int = 512,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList(
            [VanillaBlock(d_model, n_heads, expand) for _ in range(n_layers)]
        )
        self.norm_out = nn.LayerNorm(d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, ids: Tensor) -> Tensor:
        """Return next-token logits for every position.

        Args:
            ids: Integer token ids of shape (B, L) with
                ``L <= max_seq_len``.

        Returns:
            Logits of shape (B, L, vocab_size).

        Raises:
            ValueError: If ``ids`` is not 2-D or is longer than
                ``max_seq_len``.
        """
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        B, L = ids.shape
        if L > self.max_seq_len:
            raise ValueError(
                f"sequence length {L} exceeds max_seq_len {self.max_seq_len}"
            )
        positions = torch.arange(L, device=ids.device)
        # The (L, D) position table is broadcast over the batch dimension.
        x = self.token_emb(ids) + self.pos_emb(positions)[None, :, :]
        for block in self.blocks:
            x = block(x)
        return self.unembed(self.norm_out(x))

    def loss(self, ids: Tensor, targets: Tensor) -> Tensor:
        """Return the mean cross-entropy between the logits and ``targets``.

        Args:
            ids: Input token ids of shape (B, L).
            targets: Class indices, one per input position; flattened
                before comparison, so it must hold ``B * L`` elements.

        Returns:
            A scalar tensor (``F.cross_entropy`` with default reduction).
        """
        logits = self.forward(ids)
        return F.cross_entropy(
            logits.reshape(-1, self.vocab_size), targets.reshape(-1)
        )

    @torch.no_grad()
    def generate(self, ids: Tensor, n_new_tokens: int) -> Tensor:
        """Greedily append ``n_new_tokens`` tokens to ``ids``.

        Each step feeds at most the last ``max_seq_len`` tokens through
        the model and appends the argmax of the final position's logits.
        The model is put in eval mode for the duration of the call and
        returned to training mode afterwards if it was training before.

        Args:
            ids: Prompt token ids of shape (B, L) with ``L >= 1``.
            n_new_tokens: Number of tokens to append; values ``<= 0``
                return ``ids`` unchanged.

        Returns:
            Tensor of shape (B, L + n_new_tokens).

        Raises:
            ValueError: If tokens are requested but ``ids`` is not 2-D or
                contains no tokens.
        """
        if n_new_tokens > 0:
            # Fail early with a clear message; otherwise these inputs
            # surface as an opaque IndexError from the slicing below.
            if ids.dim() != 2:
                raise ValueError(
                    f"expected (B, L), got shape {tuple(ids.shape)}"
                )
            if ids.shape[1] == 0:
                raise ValueError(
                    "cannot generate from an empty prompt: "
                    "ids must contain at least one token"
                )
        was_training = self.training
        self.eval()
        try:
            for _ in range(n_new_tokens):
                # Crop the context so it fits the position-embedding table.
                ids_in = ids[:, -self.max_seq_len:]
                logits = self.forward(ids_in)[:, -1, :]
                next_id = logits.argmax(dim=-1, keepdim=True)
                ids = torch.cat([ids, next_id], dim=1)
            return ids
        finally:
            if was_training:
                self.train()

    def parameter_count(self) -> int:
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())