# gclm.py — character-level language model built from local and global (FFT) convolutions.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Configuration & Data ---
# Tiny character-level corpus: Hamlet's "To be, or not to be" soliloquy.
data = """To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by opposing end them. To die—to sleep,
No more; and by a sleep to say we end
The heart-ache and the thousand natural shocks
That flesh is heir to: 'tis a consummation
Devoutly to be wish'd. To die, to sleep;
To sleep, perchance to dream—ay, there's the rub:
For in that sleep of death what dreams may come,
When we have shuffled off this mortal coil,
Must give us pause—there's the respect
That makes calamity of so long life.
For who would bear the whips and scorns of time,
Th'oppressor's wrong, the proud man's contumely,
The pangs of dispriz'd love, the law's delay,
The insolence of office, and the spurns
That patient merit of th'unworthy takes,
When he himself might his quietus make
With a bare bodkin? Who would fardels bear,
To grunt and sweat under a weary life,
But that the dread of something after death,
The undiscovere'd country, from whose bourn
No traveller returns, puzzles the will,
And makes us rather bear those ills we have
Than fly to others that we know not of?
Thus conscience doth make cowards of us all,
And thus the native hue of resolution
Is sicklied o'er with the pale cast of thought,
And enterprises of great pith and moment
With this regard their currents turn awry
And lose the name of action."""

# Vocabulary: every distinct character appearing in the corpus.
chars = sorted(list(set(data)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = {i: ch for i, ch in enumerate(chars)}  # integer id -> character
encoded = torch.tensor([stoi[c] for c in data], dtype=torch.long)  # full corpus as token ids

# Hyperparameters based on your architecture
D_MODEL = 256      # embedding / hidden width
N_LAYERS = 4       # number of residual Blocks
MAX_SEQ_LEN = 64   # training context length (also sizes the positional table)
LOCAL_K = 5        # kernel size of the depthwise local convolution
GLOBAL_K = 128     # kernel size of the long-range (FFT) convolution
FFT_SIZE = 256     # FFT length used by the overlap-save convolution
TRAIN_TIME = 60    # wall-clock training budget, in seconds
BATCH_SIZE = 8
# --- Architecture Components ---

class GlobalConv1D(nn.Module):
    """Causal per-channel long convolution computed blockwise via FFT (overlap-save).

    Each channel owns an independent learned kernel of up to ``kernel_size``
    taps. The input is left-padded only, so output at time t depends solely
    on inputs at times <= t. Each FFT block of length ``fft_size`` yields
    ``fft_size - (K - 1)`` valid output samples; the first ``K - 1`` samples
    of every block are discarded because circular wrap-around corrupts them.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # One kernel per channel, small random init so early training is stable.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        """x: (batch, channels, time) -> (batch, channels, time), causal conv."""
        seq_len = x.shape[-1]
        taps = min(self.kernel_size, seq_len)   # never use more taps than timesteps
        halo = taps - 1                          # history each block must carry
        step = self.fft_size - halo              # valid outputs produced per block

        padded = F.pad(x, (halo, 0))             # causal: pad history on the left only
        kern = F.pad(self.kernel[:, :taps], (0, self.fft_size - taps))
        kern_f = torch.fft.rfft(kern, n=self.fft_size).unsqueeze(0)

        pieces = []
        for start in range(0, seq_len, step):
            chunk = padded[..., start:start + self.fft_size]
            short = self.fft_size - chunk.shape[-1]
            if short > 0:  # final block may run past the end — zero-fill it
                chunk = F.pad(chunk, (0, short))
            conv = torch.fft.irfft(torch.fft.rfft(chunk, n=self.fft_size) * kern_f, n=self.fft_size)
            # Keep only the wrap-free region of the circular convolution.
            pieces.append(conv[..., halo:halo + step])
        return torch.cat(pieces, dim=-1)[..., :seq_len]
class LocalConv1D(nn.Module):
    """Causal local mixer: depthwise conv -> ReLU -> pointwise (1x1) conv."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        # Depthwise: each channel is filtered independently over a k-step window.
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        # Pointwise conv then mixes information across channels.
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        """x: (batch, channels, time) -> same shape."""
        # Left-pad by k-1 so position t never sees inputs beyond t (causality).
        causal = F.pad(x, (self.k - 1, 0))
        mixed = F.relu(self.dw(causal))
        return self.pw(mixed)
class Block(nn.Module):
    """Pre-norm residual block: local conv, optional global conv, then MLP."""

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_K)
        if use_global:
            # Only some blocks carry the (more expensive) long-range path.
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_K, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        hidden = d_model * 4
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x):
        """x: (batch, time, d_model) -> same shape, three residual sublayers."""
        # Conv sublayers run channels-first, hence the (B,T,C) <-> (B,C,T) transposes.
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))
class GCLM(nn.Module):
    """Convolutional character LM: embeddings -> N residual Blocks -> tied LM head."""

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        # Even-indexed layers additionally carry the long-range global conv.
        self.layers = nn.ModuleList(
            [Block(D_MODEL, layer_idx % 2 == 0) for layer_idx in range(N_LAYERS)]
        )
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        self.head.weight = self.emb.weight  # Weight Tying

    def forward(self, x):
        """x: (batch, time) token ids -> (batch, time, vocab) logits."""
        positions = torch.arange(x.size(1), device=x.device)
        h = self.emb(x) + self.pos(positions)
        for block in self.layers:
            h = block(h)
        return self.head(self.ln(h))
# --- Training Setup ---

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GCLM(vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

print(f"Training on {device} for {TRAIN_TIME} seconds...")
start_time = time.time()
step = 0

# Train for a fixed wall-clock budget rather than a fixed number of steps.
model.train()
while (time.time() - start_time) < TRAIN_TIME:
    # Random batching
    # Sample BATCH_SIZE random windows; targets are the same windows shifted by one.
    ix = torch.randint(0, len(encoded) - MAX_SEQ_LEN, (BATCH_SIZE,))
    x = torch.stack([encoded[i : i + MAX_SEQ_LEN] for i in ix]).to(device)
    y = torch.stack([encoded[i + 1 : i + MAX_SEQ_LEN + 1] for i in ix]).to(device)

    logits = model(x)
    # Flatten (B, T, V) -> (B*T, V) for per-token cross entropy.
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Lightweight progress line, overwritten in place every 10 steps.
    if step % 10 == 0:
        elapsed = time.time() - start_time
        print(f"\rStep {step} | Loss: {loss.item():.4f} | Progress: {min(100, (elapsed/TRAIN_TIME)*100):.1f}%", end="")
    step += 1
# --- Generation ---
# Autoregressive sampling: feed the growing context back in one token at a time.

print("\n\nTraining Complete. Generating:\n" + "-"*30)
model.eval()
prompt = "To be, "
ctx = torch.tensor([[stoi[c] for c in prompt]], dtype=torch.long, device=device)
print(prompt, end="", flush=True)

with torch.no_grad():
    for _ in range(MAX_SEQ_LEN * 2):
        # Crop context to model's MAX_SEQ_LEN
        inp = ctx[:, -MAX_SEQ_LEN:]
        logits = model(inp)
        logits = logits[:, -1, :] / 0.8  # Temperature

        # Simple top-k to keep it clean
        # Mask every logit below the 10th-largest before sampling.
        v, _ = torch.topk(logits, min(10, vocab_size))
        logits[logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(logits, dim=-1)
        next_char_idx = torch.multinomial(probs, num_samples=1)

        # Append the sampled token and echo it immediately.
        ctx = torch.cat((ctx, next_char_idx), dim=1)
        print(itos[next_char_idx.item()], end="", flush=True)
print("\n" + "-"*30)