AxionLab-official commited on
Commit
3da2ee9
·
verified ·
1 Parent(s): 9dad87e

Create components.py

Browse files
Files changed (1) hide show
  1. components.py +237 -0
components.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model components optimized for CPU training.
3
+
4
+ Design rationale:
5
+ - RMSNorm instead of LayerNorm: simpler, faster (no mean computation)
6
+ - Rotary Position Embeddings (RoPE): no learned position embeddings needed,
7
+ saves parameters and generalizes better
8
+ - LoRA-style low-rank linear layers: dramatically reduces parameter count
9
+ while maintaining expressiveness
10
+ - All operations use float32 for CPU stability (no mixed precision)
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import math
17
+ from typing import Optional, Tuple
18
+
19
+
20
class RMSNorm(nn.Module):
    """
    Root Mean Square normalization (no mean-centering, unlike LayerNorm).

    Each vector is divided by the root of its mean squared value, then
    scaled by a learned per-dimension gain. Skipping the mean computation
    makes this noticeably cheaper than LayerNorm on CPU while performing
    equivalently for transformers in practice.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # eps guards the rsqrt against an all-zero input vector.
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        scaled = x * torch.rsqrt(mean_square + self.eps)
        return scaled * self.weight
35
+
36
+
37
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE).

    Why:
    - No learned parameters (saves memory)
    - Relative position awareness without extra params
    - Extrapolates better to unseen sequence lengths
    - Computationally efficient on CPU (just sin/cos)

    The cos/sin tables are cached as NON-persistent buffers: forward()
    regrows them when a longer sequence arrives, so storing them in the
    state_dict would cause size mismatches when loading a checkpoint saved
    with a different cache length. They are cheap to rebuild on load.
    """
    def __init__(self, dim: int, max_seq_len: int = 512, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        # Pair frequencies 1 / base^(2i/dim) — the standard RoPE schedule.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        # Pre-compute for max_seq_len to avoid recomputation in forward.
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build the cos/sin tables, each of shape [seq_len, dim]."""
        # Follow inv_freq's device/dtype so a runtime rebuild after
        # .to(device) lands on the right device (the original always
        # built on CPU).
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype,
                         device=self.inv_freq.device)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Duplicate so the table covers all `dim` channels, matching
        # rotate_half's half-split layout.
        emb = torch.cat((freqs, freqs), dim=-1)
        # persistent=False: keep regrowable caches out of the state_dict.
        self.register_buffer('cos_cached', emb.cos(), persistent=False)
        self.register_buffer('sin_cached', emb.sin(), persistent=False)

    def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin), each [seq_len, dim], growing the cache if needed."""
        if seq_len > self.cos_cached.size(0):
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
66
+
67
+
68
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dim, negating the (new) first half."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor,
                         cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by position-dependent angles (RoPE).

    q, k have head dims leading the (seq_len, dim) trailing axes;
    cos, sin: [seq_len, dim].
    """
    # Add broadcast axes for batch and heads: [1, 1, seq_len, dim].
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
84
+
85
+
86
class LoRALinear(nn.Module):
    """
    Low-rank factorized linear layer.

    Why: instead of a full in_features x out_features matrix, uses two
    smaller ones: in_features x rank and rank x out_features. For rank=16,
    in=out=256:
        Full: 65,536 params
        LoRA: 256*16 + 16*256 = 8,192 params (8x reduction!)

    When the factorization would not actually save parameters
    (rank >= min(in, out) // 2), falls back to a plain nn.Linear.

    Note: unlike adapter-style LoRA, this layer is used STANDALONE — it is
    the whole projection, not a delta added to a frozen base weight.
    Zero-initializing `up` (the adapter convention) would therefore make
    every output exactly zero at init and starve `down` of gradient on the
    first step. Both factors keep nn.Linear's default Kaiming-uniform init
    instead, so the composed map is non-degenerate from the start.
    """
    def __init__(self, in_features: int, out_features: int, rank: int = 16, bias: bool = False):
        super().__init__()
        self.rank = rank
        if rank >= min(in_features, out_features) // 2:
            # Factorization saves nothing here; use a full linear layer.
            self.use_lora = False
            self.linear = nn.Linear(in_features, out_features, bias=bias)
        else:
            self.use_lora = True
            # nn.Linear already applies kaiming_uniform_(a=sqrt(5)) by
            # default, so no extra init is needed for either factor.
            self.down = nn.Linear(in_features, rank, bias=False)
            self.up = nn.Linear(rank, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_lora:
            return self.up(self.down(x))
        return self.linear(x)
116
+
117
+
118
class GatedMLP(nn.Module):
    """
    SwiGLU-style gated feed-forward network.

    Two parallel projections expand d_model -> d_ff: the SiLU-activated
    "gate" branch multiplicatively selects features from the "up" branch,
    and the product is projected back down to d_model. Gated activations
    consistently outperform plain ReLU/GELU in transformers, especially at
    small scale. All three projections are LoRA-factorized to save
    parameters.
    """
    def __init__(self, d_model: int, d_ff: int, rank: int = 16, dropout: float = 0.05):
        super().__init__()
        self.gate_proj = LoRALinear(d_model, d_ff, rank=rank)
        self.up_proj = LoRALinear(d_model, d_ff, rank=rank)
        self.down_proj = LoRALinear(d_ff, d_model, rank=rank)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
139
+
140
+
141
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention with RoPE and optional Grouped Query Attention.

    Why these choices:
    - Grouped Query Attention (GQA): shares KV heads, reducing memory and
      params while maintaining quality. For 8 heads with 4 KV groups:
      50% KV param reduction.
    - Pre-computed causal mask: avoids recomputing each forward pass on CPU.
      The mask is regrown on demand for sequences longer than max_seq_len,
      matching the self-extending RoPE cache (previously that path raised a
      broadcast error).
    - RoPE applied per-head: correct relative position encoding.

    NOTE(review): if the padding `mask` blanks every key position for some
    query row, the softmax over an all -inf row yields NaN — callers must
    guarantee at least one attendable position per row.
    """
    def __init__(self, d_model: int, n_heads: int, rank: int = 16,
                 dropout: float = 0.05, max_seq_len: int = 512,
                 n_kv_heads: Optional[int] = None):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads or n_heads
        # GQA repeats each KV head an integer number of times.
        assert n_heads % self.n_kv_heads == 0, \
            "n_heads must be a multiple of n_kv_heads for GQA"
        self.head_dim = d_model // n_heads
        self.n_rep = n_heads // self.n_kv_heads  # Q heads per KV head

        self.q_proj = LoRALinear(d_model, d_model, rank=rank)
        self.k_proj = LoRALinear(d_model, self.n_kv_heads * self.head_dim, rank=rank)
        self.v_proj = LoRALinear(d_model, self.n_kv_heads * self.head_dim, rank=rank)
        self.o_proj = LoRALinear(d_model, d_model, rank=rank)

        self.dropout = nn.Dropout(dropout)
        self.rope = RotaryEmbedding(self.head_dim, max_seq_len)

        # Pre-compute causal mask. Non-persistent: it is derived data that
        # may be regrown at runtime, so keeping it out of the state_dict
        # avoids size mismatches when loading checkpoints.
        mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer('causal_mask', mask, persistent=False)

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads along the head dim to match Q heads (GQA)."""
        if self.n_rep == 1:
            return x
        bs, n_kv, seq_len, head_dim = x.shape
        # expand avoids copying; reshape materializes the repeated heads.
        x = x[:, :, None, :, :].expand(bs, n_kv, self.n_rep, seq_len, head_dim)
        return x.reshape(bs, self.n_heads, seq_len, head_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        x: [B, T, d_model]. mask: optional [B, T] bool where True marks
        padding positions to hide. Returns [B, T, d_model].
        """
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE (per head, before GQA expansion).
        cos, sin = self.rope(T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Expand KV for GQA.
        k = self._repeat_kv(k)
        v = self._repeat_kv(v)

        # Scaled dot-product attention scores.
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Grow the causal mask if the sequence outruns the pre-computed one
        # (the RoPE cache already self-extends; previously this raised).
        if T > self.causal_mask.size(0):
            grown = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            self.register_buffer('causal_mask', grown, persistent=False)
        causal = self.causal_mask[:T, :T].unsqueeze(0).unsqueeze(0)
        attn = attn.masked_fill(causal, float('-inf'))

        if mask is not None:
            # [B, T] -> [B, 1, 1, T]: hide padded key positions from all queries.
            attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))

        attn = self.dropout(F.softmax(attn, dim=-1))

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)
216
+
217
+
218
class TransformerBlock(nn.Module):
    """
    One pre-norm transformer block: attention, then a gated MLP, each
    wrapped in RMSNorm and a residual connection.

    Pre-norm keeps the residual path free of normalization, which gives
    cleaner gradient flow and more stable training, especially at small
    scale.
    """
    def __init__(self, d_model: int, n_heads: int, d_ff: int,
                 rank: int = 16, dropout: float = 0.05,
                 max_seq_len: int = 512, n_kv_heads: Optional[int] = None):
        super().__init__()
        self.attn_norm = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, rank, dropout, max_seq_len, n_kv_heads)
        self.ff_norm = RMSNorm(d_model)
        self.ff = GatedMLP(d_model, d_ff, rank, dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        attn_out = self.attn(self.attn_norm(x), mask)
        x = x + attn_out
        ff_out = self.ff(self.ff_norm(x))
        return x + ff_out