AbstractPhil committed on
Commit
2c64e3d
·
verified ·
1 Parent(s): 3b820cd

Create prototype_5_geodesic_prelim.py

Browse files
Files changed (1) hide show
  1. prototype_5_geodesic_prelim.py +855 -0
prototype_5_geodesic_prelim.py ADDED
@@ -0,0 +1,855 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # RAPID PROTOTYPE v2: Differentiation-Centered Alignment Bank
3
+ #
4
+ # The bank aligns to the DIFFERENTIATION between experts, not to any
5
+ # arbitrary target. The consensus CV, spectral profile, and pairwise
6
+ # statistics measured during alignment become the exact targets.
7
+ #
8
+ # The bank embodies the centerpoint of expert disagreement.
9
+ # ============================================================================
10
+
11
+ import gc
12
+ import math
13
+ import os
14
+ import time
15
+ import json
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from tqdm import tqdm
22
+
23
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# One entry per teacher ("expert") encoder:
#   (Hugging Face model id, short name used as a dict key, tokenizer max_length).
EXPERTS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 512),
]

# Start-up banner, printed once at import time.
for _banner_line in (
    "=" * 65,
    "RAPID PROTOTYPE v2: Differentiation-Centered Bank",
    "=" * 65,
    f" Device: {DEVICE}",
):
    print(_banner_line)
34
+
35
+
36
+ # ══════════════════════════════════════════════════════════════════
37
+ # STUDENT MODEL
38
+ # ══════════════════════════════════════════════════════════════════
39
+
40
class MiniStudent(nn.Module):
    """Small Transformer encoder that maps token ids to one unit-norm vector.

    Pipeline: token + learned positional embeddings -> LayerNorm -> dropout ->
    pre-norm Transformer encoder -> masked mean pooling -> MLP projection ->
    L2 normalisation.

    Args:
        vocab_size: Number of rows in the token embedding table.
        max_len: Number of learned positions; longest sequence accepted.
        d_model: Width of the encoder.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of encoder layers.
        d_ff: Feed-forward width inside each encoder layer.
        output_dim: Size of the returned embedding.
        dropout: Dropout probability for the embeddings and encoder layers.
        pad_token_id: Token id treated as padding.
    """

    def __init__(self, vocab_size=30522, max_len=512, d_model=256,
                 n_heads=4, n_layers=4, d_ff=1024, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers, enable_nested_tensor=False)
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(),
            nn.LayerNorm(d_model), nn.Linear(d_model, output_dim))

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token-id sequences.

        Args:
            input_ids: 2-D integer tensor, one row per sequence.
            attention_mask: Optional tensor shaped like ``input_ids``; non-zero
                marks real tokens. When omitted, positions equal to
                ``pad_token_id`` are treated as padding.

        Returns:
            Tensor with one L2-normalised row of size ``output_dim`` per input row.

        Raises:
            IndexError: If the sequence is longer than the positional table.
        """
        _, seq_len = input_ids.shape
        max_positions = self.pos_emb.num_embeddings
        if seq_len > max_positions:
            # Fail early with a readable message instead of an out-of-range
            # embedding lookup further down.
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len {max_positions}")

        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.emb_drop(self.emb_norm(x))

        # True marks positions the encoder must ignore.
        if attention_mask is not None:
            key_padding_mask = ~attention_mask.bool()
            pool_weights = attention_mask.unsqueeze(-1).float()
        else:
            key_padding_mask = input_ids == self.pad_token_id
            pool_weights = (~key_padding_mask).unsqueeze(-1).float()

        x = self.encoder(x, src_key_padding_mask=key_padding_mask)

        # Mean over real tokens only; the clamp avoids division by zero for a
        # row with no real tokens.
        pooled = (x * pool_weights).sum(1) / pool_weights.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(pooled), dim=-1)
70
+
71
+
72
+ # ══════════════════════════════════════════════════════════════════
73
+ # ALIGNMENT BANK
74
+ # ══════════════════════════════════════════════════════════════════
75
+
76
class AlignmentBank(nn.Module):
    """
    Differentiation-centered geometric interface.

    Aligns to the CENTERPOINT between experts — the consensus itself.
    Stores per-expert rotation matrices (the differentiation structure)
    and learned anchor landmarks (the consensus manifold topology).

    The bank doesn't invent geometry. It mirrors the measured consensus.
    Every loss term pulls toward measured consensus statistics.

    Args:
        d_embed: Size of the incoming embedding.
        n_experts: Number of experts; one rotation / whitener / mean each.
        n_anchors: Number of learned anchor vectors.
        d_bank: Size of the geometric context appended to the embedding.
    """

    def __init__(self, d_embed=768, n_experts=2, n_anchors=512, d_bank=128):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank

        # Per-expert rotation matrices (differentiation structure)
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])

        # Per-expert whiteners (captures variance structure per expert)
        self.expert_whiteners = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])

        # Per-expert means (centering offset per expert)
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)
        ])

        # Anchor bank: consensus landmarks on the hypersphere
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(n_anchors, d_embed), dim=-1))

        # Project: expert_cos (n) + expert_mse (n) + cross (n*(n-1)/2) +
        # disagreement_ratio (1) + norm_ratio (n) + anchor_cos (n_anchors)
        n_cross = n_experts * (n_experts - 1) // 2
        geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 2),
            nn.GELU(),
            nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank),
            nn.LayerNorm(d_bank),
        )

        # Consensus statistics (set during init, used as exact targets).
        # Only target_cv is read by bank_loss in this class; target_mean_cos
        # and target_spectral are stored but not used by any method here.
        self.register_buffer("target_cv", torch.tensor(0.12))
        self.register_buffer("target_mean_cos", torch.tensor(0.0))
        self.register_buffer("target_spectral", torch.zeros(50))
        # Disagreement structure (measured once, preserved forever)
        self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
        self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
        self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))

    def init_from_procrustes(self, procrustes_results, expert_names,
                             consensus_embeddings=None,
                             consensus_stats=None):
        """Initialize from consensus training artifacts.

        Args:
            procrustes_results: Mapping expert name -> dict with "rotation" and
                "cos_after", and optionally "source_whitener" / "source_mean".
            expert_names: Names to load, in expert-slot order; only the first
                ``n_experts`` are used.
            consensus_embeddings: Optional 2-D tensor; evenly spaced rows are
                normalised and copied into the leading anchors.
            consensus_stats: Optional dict with "cv", "mean_cos" and optionally
                "spectral" (only the first 50 values are kept).
        """
        device = self.anchors.device
        for i, name in enumerate(expert_names[:self.n_experts]):
            info = procrustes_results[name]
            self.expert_rotations[i].data = info["rotation"].float().to(device)
            if "source_whitener" in info:
                self.expert_whiteners[i].data = info["source_whitener"].float().to(device)
            if "source_mean" in info:
                self.expert_means[i].data = info["source_mean"].float().to(device)
            print(f" Expert {i} ({name}): rotation + whitener + mean loaded, "
                  f"cos_after={info['cos_after']:.4f}")

        if consensus_embeddings is not None:
            n = min(self.n_anchors, consensus_embeddings.shape[0])
            # Evenly spaced row indices across the whole consensus set.
            indices = torch.linspace(0, consensus_embeddings.shape[0] - 1, n).long()
            self.anchors.data[:n] = F.normalize(
                consensus_embeddings[indices].float(), dim=-1).to(device)
            print(f" Anchors: {n} initialized from consensus embeddings")

        if consensus_stats is not None:
            self.target_cv.fill_(consensus_stats["cv"])
            self.target_mean_cos.fill_(consensus_stats["mean_cos"])
            if "spectral" in consensus_stats:
                s = torch.tensor(consensus_stats["spectral"][:50], dtype=torch.float32)
                self.target_spectral[:len(s)] = s.to(device)
            print(f" Targets: CV={consensus_stats['cv']:.4f}, "
                  f"mean_cos={consensus_stats['mean_cos']:.4f}")

    @staticmethod
    def _simplex_volume_cv(points, n_draws=32, n_vertices=5):
        """Coefficient of variation of random-simplex volumes drawn from rows of `points`.

        Each draw picks `n_vertices` distinct rows and measures the simplex
        volume through the Cayley-Menger determinant of squared distances.
        """
        n_points = points.shape[0]
        vols = []
        for _ in range(n_draws):
            idx = torch.randperm(n_points, device=points.device)[:n_vertices]
            pts = points[idx].unsqueeze(0)
            diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
            d2 = (diff * diff).sum(-1)
            Bv, V, _ = d2.shape
            # Bordered matrix: zero corner, ones on the border, d2 inside.
            cm = torch.zeros(Bv, V + 1, V + 1, device=d2.device, dtype=torch.float32)
            cm[:, 0, 1:] = 1
            cm[:, 1:, 0] = 1
            cm[:, 1:, 1:] = d2
            s = (-1.0) ** V
            f = math.factorial(V - 1)
            v2 = s / ((2.0 ** (V - 1)) * f * f) * torch.linalg.det(cm)
            # relu + epsilon keeps the sqrt and its gradient finite when the
            # determinant comes out non-positive.
            vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
        stacked = torch.stack(vols)
        return stacked.std() / (stacked.mean() + 1e-8)

    def forward(self, embedding):
        """Append geometric context to `embedding` and compute loss terms.

        Args:
            embedding: 2-D tensor with ``d_embed`` columns.

        Returns:
            ``(enriched, aux)``. ``enriched`` is ``embedding`` concatenated with
            the ``d_bank``-wide output of ``geo_proj``. ``aux`` holds the tensor
            loss terms read by ``bank_loss`` plus float diagnostics.
        """
        B = embedding.shape[0]
        emb = embedding.float()

        # ── Per-expert projections (full whitened Procrustes) ──
        # Chain: center → whiten → normalize → rotate
        expert_consistency = []
        expert_recon = []
        expert_projected = []
        expert_norms = []
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]

            whitened = (emb - mu) @ W
            whitened_n = F.normalize(whitened, dim=-1)
            in_expert = whitened_n @ R.T

            # Round-trip through R; exact recovery only while R is orthogonal.
            back = in_expert @ R

            expert_consistency.append(F.cosine_similarity(whitened_n, back, dim=-1))
            expert_recon.append((whitened_n - back).pow(2).mean(dim=-1))
            expert_projected.append(in_expert)
            # Norm before normalisation (captures magnitude structure).
            expert_norms.append(whitened.norm(dim=-1))

        expert_cos = torch.stack(expert_consistency, dim=-1)  # (B, n_experts)
        expert_mse = torch.stack(expert_recon, dim=-1)        # (B, n_experts)

        # ── Cross-expert differentiation ──
        # Cosine between every unordered pair of expert projections.
        cross_cos = []
        for i in range(self.n_experts):
            for j in range(i + 1, self.n_experts):
                cross_cos.append(F.cosine_similarity(
                    expert_projected[i], expert_projected[j], dim=-1))
        if cross_cos:
            cross_features = torch.stack(cross_cos, dim=-1)
        else:
            cross_features = torch.zeros(B, 0, device=emb.device)

        # Per-sample spread of the round-trip cosine across experts, relative
        # to its mean.
        per_sample_agreement = expert_cos.mean(dim=-1)  # (B,)
        if self.n_experts > 1:
            per_sample_disagreement = expert_cos.std(dim=-1)  # (B,)
        else:
            # std over a single expert is NaN; a lone expert has no spread.
            per_sample_disagreement = torch.zeros_like(per_sample_agreement)
        disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)

        expert_norm_features = torch.stack(expert_norms, dim=-1)  # (B, n_experts)
        norm_ratio = expert_norm_features / (
            expert_norm_features.mean(dim=-1, keepdim=True) + 1e-8)

        # ── Anchor distances ──
        anchors_n = F.normalize(self.anchors, dim=-1)
        # Dot product with unit anchors; a true cosine only when the rows of
        # `embedding` are themselves unit-norm.
        anchor_cos = emb @ anchors_n.T  # (B, n_anchors)

        # ── Geometric context ──
        geo_input = torch.cat([
            expert_cos,                        # (B, n_experts)
            expert_mse,                        # (B, n_experts)
            cross_features,                    # (B, n_cross)
            disagreement_ratio.unsqueeze(-1),  # (B, 1)
            norm_ratio,                        # (B, n_experts)
            anchor_cos,                        # (B, n_anchors)
        ], dim=-1)
        geo_context = self.geo_proj(geo_input)

        # The input embedding passes through untouched in the leading columns.
        enriched = torch.cat([embedding, geo_context], dim=-1)

        # ── Losses + Diagnostics ──
        aux = {}

        # 1. Expert agreement: all experts should see embedding equally
        expert_mean = expert_cos.mean(dim=-1, keepdim=True)
        aux["expert_agreement"] = (expert_cos - expert_mean).pow(2).mean()

        # 2. Rotation orthogonality
        ortho_loss = 0.0
        for R in self.expert_rotations:
            identity = torch.eye(self.d_embed, device=R.device)
            ortho_loss += (R @ R.T - identity).pow(2).mean()
        aux["rotation_ortho"] = ortho_loss / self.n_experts

        # 3. Anchor spread: penalise similarity between distinct anchors
        anchor_sim = anchors_n @ anchors_n.T
        anchor_sim.fill_diagonal_(0)
        aux["anchor_spread"] = anchor_sim.pow(2).mean()

        # 4. Anchor sharpness: entropy of the softmax over anchors
        anchor_probs = F.softmax(anchor_cos * 10, dim=-1)
        aux["anchor_entropy"] = -(anchor_probs * (anchor_probs + 1e-12).log()).sum(-1).mean()

        # 5./6. Cross-expert consistency and disagreement preservation
        has_cross = cross_features.shape[1] > 0
        if has_cross:
            aux["cross_expert_var"] = cross_features.var(dim=0).mean()
            batch_cross_mean = cross_features.mean()
            batch_cross_std = cross_features.std()
        else:
            aux["cross_expert_var"] = torch.tensor(0.0, device=emb.device)
            batch_cross_mean = torch.tensor(0.0, device=emb.device)
            batch_cross_std = torch.tensor(0.0, device=emb.device)
        batch_disagree_ratio = disagreement_ratio.mean()
        aux["disagree_preserve"] = (
            (batch_cross_mean - self.target_cross_cos_mean).pow(2) +
            (batch_cross_std - self.target_cross_cos_std).pow(2) +
            (batch_disagree_ratio - self.target_disagreement_ratio).pow(2)
        )

        # 7./8. Simplex-volume CV of the bank context and of the embedding.
        # Skipped for small batches (fewer than 10 rows).
        if B >= 10:
            aux["bank_cv"] = self._simplex_volume_cv(F.normalize(geo_context, dim=-1))
            aux["emb_cv"] = self._simplex_volume_cv(F.normalize(emb, dim=-1))
        else:
            aux["bank_cv"] = torch.tensor(0.0, device=embedding.device)
            aux["emb_cv"] = torch.tensor(0.0, device=embedding.device)

        # Diagnostics (plain floats, not part of the loss)
        aux["expert_cos_mean"] = expert_cos.mean().item()
        aux["expert_cos_std"] = expert_cos.std().item()
        aux["anchor_max_cos"] = anchor_cos.max(dim=-1).values.mean().item()
        aux["anchor_mean_cos"] = anchor_cos.mean().item()
        if has_cross:
            aux["cross_expert_cos"] = cross_features.mean().item()
            aux["cross_expert_cos_std"] = cross_features.std().item()
        aux["disagreement_ratio"] = disagreement_ratio.mean().item()
        if self.n_experts > 1:
            aux["norm_ratio_spread"] = norm_ratio.std(dim=-1).mean().item()
        else:
            # Same single-expert NaN guard as for the disagreement ratio.
            aux["norm_ratio_spread"] = 0.0

        return enriched, aux

    def bank_loss(self, aux):
        """All targets from measured consensus. Preserves disagreement structure.

        Args:
            aux: The dict returned by ``forward``.

        Returns:
            Scalar tensor: fixed-weight sum of the tensor terms in ``aux``,
            with both CV terms measured as distance from ``target_cv``.
        """
        loss = (
            1.0 * aux["expert_agreement"] +
            1.0 * aux["rotation_ortho"] +
            0.5 * aux["anchor_spread"] +
            0.1 * aux["anchor_entropy"] +
            0.3 * aux["cross_expert_var"] +
            0.3 * (aux["bank_cv"] - self.target_cv).abs() +
            0.3 * (aux["emb_cv"] - self.target_cv).abs() +
            0.5 * aux["disagree_preserve"]  # preserve the disagreement distribution
        )
        return loss

    @torch.no_grad()
    def calibrate_disagreement(self, embeddings):
        """
        Measure the initial disagreement structure and store as targets.
        Call ONCE after init, before training.

        Args:
            embeddings: Batch passed through ``forward``; its diagnostics
                become the ``target_cross_cos_*`` and
                ``target_disagreement_ratio`` buffers.
        """
        _, aux = self.forward(embeddings)
        # The cross-expert keys are absent when there are no expert pairs.
        if "cross_expert_cos" in aux:
            self.target_cross_cos_mean.fill_(aux["cross_expert_cos"])
        if "cross_expert_cos_std" in aux:
            self.target_cross_cos_std.fill_(aux["cross_expert_cos_std"])
        self.target_disagreement_ratio.fill_(aux["disagreement_ratio"])
        print(f" Calibrated disagreement:")
        print(f" cross_cos: {self.target_cross_cos_mean.item():.4f} ± {self.target_cross_cos_std.item():.4f}")
        print(f" disagree_ratio: {self.target_disagreement_ratio.item():.6f}")
377
+
378
+
379
+ # ══════════════════════════════════════════════════════════════════
380
+ # GEOMETRY
381
+ # ══════════════════════════════════════════════════════════════════
382
+
383
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE loss between two row-aligned batches.

    Row ``i`` of ``a`` is the positive for row ``i`` of ``b``; every other row
    is a negative. Both inputs are L2-normalised before scoring.

    Args:
        a: 2-D tensor of embeddings.
        b: 2-D tensor with the same number of rows as ``a``.
        temperature: Divisor applied to the cosine similarities.

    Returns:
        ``(loss, acc)``: the mean of the a->b and b->a cross-entropies, and
        the fraction of rows of ``a`` whose best match is their own partner
        (a Python float).
    """
    unit_a = F.normalize(a, dim=-1)
    unit_b = F.normalize(b, dim=-1)
    logits = (unit_a @ unit_b.T) / temperature

    # The matching partner sits on the diagonal.
    targets = torch.arange(logits.shape[0], device=logits.device)
    a_to_b = F.cross_entropy(logits, targets)
    b_to_a = F.cross_entropy(logits.T, targets)
    loss = (a_to_b + b_to_a) / 2

    with torch.no_grad():
        hits = logits.argmax(-1) == targets
        acc = hits.float().mean().item()
    return loss, acc
392
+
393
def cayley_menger_vol2(pts):
    """Squared simplex volume from the Cayley-Menger determinant.

    Args:
        pts: 3-D tensor ``(batch, vertices, dim)``; each batch entry is one
            simplex given by its vertices.

    Returns:
        1-D tensor with one squared volume per batch entry. Round-off can make
        the value slightly negative for degenerate simplices.
    """
    pts = pts.float()

    # All pairwise squared distances between vertices.
    delta = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = (delta * delta).sum(-1)
    n_batch, n_vertices, _ = sq_dist.shape

    # Bordered matrix: zero in the corner, ones along the first row and
    # column, squared distances in the remaining block.
    bordered = torch.ones(n_batch, n_vertices + 1, n_vertices + 1,
                          device=sq_dist.device, dtype=torch.float32)
    bordered[:, 0, 0] = 0
    bordered[:, 1:, 1:] = sq_dist

    simplex_dim = n_vertices - 1
    sign = (-1.0) ** n_vertices
    fact = math.factorial(simplex_dim)
    scale = sign / ((2.0 ** simplex_dim) * fact * fact)
    return scale * torch.linalg.det(bordered)
402
+
403
def cv_loss(emb, target=0.12, n_samples=16):
    """Distance between the batch's simplex-volume CV and a target value.

    Draws ``n_samples`` random 5-row subsets of ``emb``, measures each
    subset's simplex volume, and compares the coefficient of variation
    (std / mean) of those volumes with ``target``.

    Args:
        emb: 2-D tensor of embeddings.
        target: Desired coefficient of variation.
        n_samples: Number of random simplices to draw.

    Returns:
        Scalar tensor ``|cv - target|``; a constant zero tensor when the batch
        has fewer than 5 rows.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return torch.tensor(0.0, device=emb.device)

    def _random_simplex_volume():
        pick = torch.randperm(n_rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[pick].unsqueeze(0))
        # relu + epsilon keeps the square root finite for non-positive values.
        return torch.sqrt(F.relu(vol_sq[0]) + 1e-12)

    volumes = torch.stack([_random_simplex_volume() for _ in range(n_samples)])
    cv = volumes.std() / (volumes.mean() + 1e-8)
    return (cv - target).abs()
414
+
415
def cv_metric(emb, n=200):
    """Coefficient of variation of random 5-point simplex volumes, as a float.

    Args:
        emb: 2-D tensor of embeddings.
        n: Number of random simplices to draw.

    Returns:
        ``std / mean`` of the non-degenerate simplex volumes (population std,
        computed with NumPy). Returns 0.0 when there are fewer than 5 rows or
        fewer than 10 usable volumes.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return 0.0

    vols = []
    for _ in range(n):
        pick = torch.randperm(n_rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[pick].unsqueeze(0))[0]
        # Reject degenerate (or NaN) simplices on the squared volume itself:
        # after the epsilon is added the square root is never below 1e-6, so a
        # positivity test on the volume could not filter anything.
        if not vol_sq.item() > 0:
            continue
        vols.append(torch.sqrt(F.relu(vol_sq) + 1e-12).item())

    if len(vols) < 10:
        return 0.0
    a = np.array(vols)
    return float(a.std() / (a.mean() + 1e-8))
427
+
428
def measure_consensus_stats(consensus_embs, n_check=2000):
    """Measure geometric statistics of the consensus embeddings.

    Only the first ``n_check`` rows are used.

    Args:
        consensus_embs: 2-D tensor of consensus embeddings.
        n_check: Maximum number of leading rows to analyse.

    Returns:
        Dict with:
          "cv": simplex-volume coefficient of variation (``cv_metric``);
          "mean_cos": mean off-diagonal entry of the Gram matrix — a cosine
              only if the rows are unit-norm (not checked here);
          "spectral": first 50 singular values of the centred rows, each
              divided by the sum of all singular values;
          "eff_dim": (sum of singular values)^2 / sum of squared singular values.
    """
    sample = consensus_embs[:n_check].float()
    n_rows = sample.shape[0]

    # Volume CV is evaluated on the module-level device.
    cv = cv_metric(sample.to(DEVICE))

    # Mean pairwise similarity, self-pairs excluded.
    gram = sample @ sample.T
    off_diagonal = ~torch.eye(n_rows, dtype=torch.bool)
    mean_cos = gram[off_diagonal].mean().item()

    # Singular-value profile of the centred sample.
    singular = torch.linalg.svdvals(sample - sample.mean(0, keepdim=True))
    spectral = (singular / (singular.sum() + 1e-8)).tolist()[:50]
    eff_dim = float((singular.sum() ** 2) / (singular.pow(2).sum() + 1e-12))

    return {
        "cv": cv,
        "mean_cos": mean_cos,
        "spectral": spectral,
        "eff_dim": eff_dim,
    }
451
+
452
+
453
+ # ══════════════════════════════════════════════════════════════════
454
+ # EXTRACTION + ALIGNMENT
455
+ # ══════════════════════════════════════════════════════════════════
456
+
457
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Inverse matrix square root of a symmetric matrix via eigendecomposition.

    Args:
        cov: 2-D symmetric matrix (e.g. a covariance).
        eps: Floor applied to the eigenvalues before taking ``1/sqrt``, so
            zero or negative eigenvalues do not blow up.

    Returns:
        ``V diag(max(lambda, eps) ** -0.5) V^T`` for eigenpairs ``(lambda, V)``.
    """
    eigvals, eigvecs = torch.linalg.eigh(cov)
    inv_root = eigvals.clamp(min=eps).rsqrt()
    scaled_basis = eigvecs @ torch.diag(inv_root)
    return scaled_basis @ eigvecs.T
461
+
462
def procrustes_align(source, target, n_align=5000):
    """Fit a whitened orthogonal Procrustes map from `source` onto `target`.

    Both sets are centred, whitened with their own covariance, and
    row-normalised. The rotation is then taken from the SVD of the
    cross-product of the two whitened sets.

    Args:
        source: 2-D tensor of embeddings to be aligned.
        target: 2-D tensor, row-aligned with ``source``.
        n_align: Maximum number of leading rows used for the fit.

    Returns:
        Dict with "rotation", "source_mean", "source_whitener",
        "target_unwhitener" (pseudo-inverse of the target whitener), and the
        mean row-wise cosines "cos_before" (centred inputs) and "cos_after"
        (whitened, rotated source vs. whitened target).
    """
    n_rows = min(n_align, source.shape[0], target.shape[0])
    src = source[:n_rows].float()
    tgt = target[:n_rows].float()

    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean
    cos_before = F.cosine_similarity(src_c, tgt_c, dim=-1).mean().item()

    # Sample covariances; the max() guards the single-row case.
    dof = max(src_c.shape[0] - 1, 1)
    src_cov = (src_c.T @ src_c) / dof
    tgt_cov = (tgt_c.T @ tgt_c) / dof
    src_whiten = symmetric_inv_sqrt(src_cov)
    tgt_whiten = symmetric_inv_sqrt(tgt_cov)

    src_w = F.normalize(src_c @ src_whiten, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whiten, dim=-1)

    # Orthogonal map from the SVD of the cross-product matrix.
    left, _, right_t = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    rotation = left @ right_t
    cos_after = F.cosine_similarity(src_w @ rotation.T, tgt_w, dim=-1).mean().item()

    return {
        "rotation": rotation,
        "source_mean": src_mean.squeeze(0),
        "source_whitener": src_whiten,
        "target_unwhitener": torch.linalg.pinv(tgt_whiten),
        "cos_before": cos_before,
        "cos_after": cos_after,
    }
483
+
484
def apply_align(emb, a):
    """Apply a fitted alignment to embeddings.

    Steps: subtract the source mean, whiten, rotate, then un-whiten into the
    target's scale. No row normalisation is applied here.

    Args:
        emb: 2-D tensor of embeddings.
        a: Dict with "source_mean", "source_whitener", "rotation" and
            "target_unwhitener", as returned by ``procrustes_align``.

    Returns:
        The transformed embeddings as a float tensor.
    """
    centered = emb.float() - a["source_mean"]
    whitened = centered @ a["source_whitener"]
    rotated = whitened @ a["rotation"].T
    return rotated @ a["target_unwhitener"]
488
+
489
+
490
+ # ══════════════════════════════════════════════════════════════════
491
+ # MAIN
492
+ # ══════════════════════════════════════════════════════════════════
493
+
494
+ def run():
495
+ torch.manual_seed(42)
496
+ np.random.seed(42)
497
+ N_SAMPLES = 20000
498
+ MAX_LEN = 128
499
+ BATCH = 256
500
+
501
+ # ── Phase 0: Extract ──
502
+ print(f"\n{'='*65}")
503
+ print("PHASE 0: EXTRACTION")
504
+ print(f"{'='*65}")
505
+
506
+ from datasets import load_dataset
507
+ from transformers import AutoModel, AutoTokenizer
508
+
509
+ ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
510
+ split="train", streaming=True)
511
+ captions = []
512
+ for row in ds:
513
+ cap = row.get("caption_llava", "")
514
+ if isinstance(cap, str) and len(cap) > 50:
515
+ captions.append(cap)
516
+ if len(captions) >= N_SAMPLES:
517
+ break
518
+ print(f" Captions: {len(captions):,}")
519
+
520
+ embeds = {}
521
+ for model_name, short, max_len in EXPERTS:
522
+ print(f"\n Extracting: {short}...")
523
+ model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
524
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
525
+ all_emb = []
526
+ with torch.no_grad():
527
+ for i in tqdm(range(0, len(captions), 128), desc=f" {short}"):
528
+ batch = captions[i:i+128]
529
+ inputs = tokenizer(batch, max_length=max_len, padding=True,
530
+ truncation=True, return_tensors="pt").to(DEVICE)
531
+ out = model(**inputs)
532
+ m = inputs.attention_mask.unsqueeze(-1).float()
533
+ pooled = (out.last_hidden_state * m).sum(1) / m.sum(1).clamp(min=1)
534
+ all_emb.append(pooled.cpu())
535
+ embeds[short] = torch.cat(all_emb)
536
+ print(f" Shape: {embeds[short].shape}")
537
+ del model; gc.collect(); torch.cuda.empty_cache()
538
+
539
+ # ── Phase 0b: Align + Consensus + Measure ──
540
+ print(f"\n{'='*65}")
541
+ print("PHASE 0b: GENERALIZED PROCRUSTES ALIGNMENT (no reference bias)")
542
+ print(f"{'='*65}")
543
+
544
+ names = [s for _, s, _ in EXPERTS]
545
+
546
+ # Generalized Procrustes: iteratively align all to their mean
547
+ # No expert is the reference. The centerpoint emerges.
548
+ GPA_ITERS = 10
549
+ current = {name: embeds[name].float() for name in names}
550
+
551
+ for gpa_iter in range(GPA_ITERS):
552
+ # Compute mean shape
553
+ mean_shape = sum(current[n] for n in names) / len(names)
554
+
555
+ # Align each to mean
556
+ new_current = {}
557
+ total_delta = 0.0
558
+ for name in names:
559
+ info = procrustes_align(current[name], mean_shape)
560
+ new_current[name] = apply_align(current[name], info)
561
+ # Measure how much this iteration changed things
562
+ delta = (new_current[name] - current[name]).pow(2).mean().item()
563
+ total_delta += delta
564
+
565
+ current = new_current
566
+ if gpa_iter == 0 or (gpa_iter + 1) % 3 == 0 or total_delta < 1e-8:
567
+ print(f" GPA iter {gpa_iter+1}: delta={total_delta:.8f}")
568
+ if total_delta < 1e-8:
569
+ print(f" Converged at iteration {gpa_iter+1}")
570
+ break
571
+
572
+ # Final alignment: align each expert to the converged mean
573
+ mean_shape = sum(current[n] for n in names) / len(names)
574
+ procrustes_results = {}
575
+ aligned = {}
576
+ for name in names:
577
+ info = procrustes_align(embeds[name], mean_shape)
578
+ procrustes_results[name] = info
579
+ aligned[name] = apply_align(embeds[name], info)
580
+ cos = F.cosine_similarity(
581
+ aligned[name][:2000], mean_shape[:2000], dim=-1).mean().item()
582
+ print(f" {name:10s}: cos_after={info['cos_after']:.4f} cos_to_mean={cos:.4f}")
583
+
584
+ # Consensus = normalized centroid (now equidistant from all experts)
585
+ consensus = F.normalize(sum(aligned[n] for n in names) / len(names), dim=-1)
586
+ for name in names:
587
+ cos = F.cosine_similarity(consensus[:2000], aligned[name][:2000], dim=-1).mean().item()
588
+ print(f" cos(consensus, {name}): {cos:.4f}")
589
+
590
+ # Verify equidistance
591
+ expert_cos_to_consensus = []
592
+ for name in names:
593
+ c = F.cosine_similarity(consensus[:2000], aligned[name][:2000], dim=-1).mean().item()
594
+ expert_cos_to_consensus.append(c)
595
+ equidist_range = max(expert_cos_to_consensus) - min(expert_cos_to_consensus)
596
+ print(f" Equidistance range: {equidist_range:.4f} (should be near 0)")
597
+
598
+ # Measure EXACT consensus statistics
599
+ print(f"\n Measuring consensus statistics...")
600
+ consensus_stats = measure_consensus_stats(consensus)
601
+ print(f" CV: {consensus_stats['cv']:.4f}")
602
+ print(f" Mean cos: {consensus_stats['mean_cos']:.4f}")
603
+ print(f" Eff dim: {consensus_stats['eff_dim']:.1f}")
604
+ print(f" Spectral: [{', '.join(f'{s:.4f}' for s in consensus_stats['spectral'][:5])}...]")
605
+
606
+ del embeds, aligned
607
+ gc.collect(); torch.cuda.empty_cache()
608
+
609
+ # ── Phase 1: Train Student ──
610
+ print(f"\n{'='*65}")
611
+ print("PHASE 1: TRAIN STUDENT")
612
+ print(f"{'='*65}")
613
+
614
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
615
+ tokens = tokenizer(captions, max_length=MAX_LEN, padding="max_length",
616
+ truncation=True, return_tensors="pt")
617
+ input_ids = tokens["input_ids"]
618
+ attention_mask = tokens["attention_mask"]
619
+
620
+ n_train = N_SAMPLES - 2000
621
+ train_ids = input_ids[:n_train].to(DEVICE)
622
+ train_mask = attention_mask[:n_train].to(DEVICE)
623
+ train_targets = consensus[:n_train].to(DEVICE)
624
+ val_ids = input_ids[n_train:].to(DEVICE)
625
+ val_mask = attention_mask[n_train:].to(DEVICE)
626
+ val_targets = consensus[n_train:].to(DEVICE)
627
+
628
+ student = MiniStudent(
629
+ vocab_size=tokenizer.vocab_size, max_len=MAX_LEN,
630
+ d_model=256, n_heads=4, n_layers=4, d_ff=1024,
631
+ output_dim=768, dropout=0.1, pad_token_id=tokenizer.pad_token_id
632
+ ).to(DEVICE)
633
+ n_params = sum(p.numel() for p in student.parameters())
634
+ print(f" Student: {n_params:,} params")
635
+ print(f" CV target: {consensus_stats['cv']:.4f}")
636
+
637
+ optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=0.01)
638
+
639
+ for epoch in range(5):
640
+ student.train()
641
+ perm = torch.randperm(n_train, device=DEVICE)
642
+ t_loss, t_acc, t_cos, n = 0, 0, 0, 0
643
+ t0 = time.time()
644
+ for i in range(0, n_train, BATCH):
645
+ idx = perm[i:i+BATCH]
646
+ if len(idx) < 8: continue
647
+ emb = student(train_ids[idx], train_mask[idx])
648
+ tgt = train_targets[idx]
649
+ l_nce, acc = infonce(emb, tgt)
650
+ l_mse = F.mse_loss(emb, tgt)
651
+ l_cv = cv_loss(emb, target=consensus_stats["cv"])
652
+ loss = l_nce + l_mse + 0.1 * l_cv
653
+ loss.backward()
654
+ torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
655
+ optimizer.step(); optimizer.zero_grad(set_to_none=True)
656
+ with torch.no_grad():
657
+ cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
658
+ t_loss += loss.item(); t_acc += acc; t_cos += cos; n += 1
659
+ elapsed = time.time() - t0; d = max(n, 1)
660
+ student.eval()
661
+ with torch.no_grad():
662
+ v_emb = student(val_ids, val_mask)
663
+ _, v_acc = infonce(v_emb[:1000], val_targets[:1000])
664
+ v_cos = F.cosine_similarity(v_emb, val_targets, dim=-1).mean().item()
665
+ v_cv = cv_metric(v_emb[:1000])
666
+ print(f" E{epoch+1}: {elapsed:.0f}s loss={t_loss/d:.4f} "
667
+ f"t_acc={t_acc/d:.3f} t_cos={t_cos/d:.3f} "
668
+ f"v_acc={v_acc:.3f} v_cos={v_cos:.3f} v_cv={v_cv:.3f}")
669
+
670
+ torch.save(student.state_dict(), "mini_student.pt")
671
+ print(f"\n Student saved. v_cos={v_cos:.3f}, v_cv={v_cv:.3f}")
672
+
673
+ # ── Phase 2: Train Alignment Bank ──
674
+ print(f"\n{'='*65}")
675
+ print("PHASE 2: TRAIN ALIGNMENT BANK (student frozen)")
676
+ print(f"{'='*65}")
677
+
678
+ student.eval()
679
+ for p in student.parameters():
680
+ p.requires_grad = False
681
+
682
+ print(" Pre-encoding through frozen student...")
683
+ with torch.no_grad():
684
+ all_embs = []
685
+ for i in range(0, n_train, 512):
686
+ j = min(i + 512, n_train)
687
+ emb = student(train_ids[i:j], train_mask[i:j])
688
+ all_embs.append(emb)
689
+ student_embs = torch.cat(all_embs)
690
+ val_student_embs = student(val_ids, val_mask)
691
+ print(f" Student embeddings: {student_embs.shape}")
692
+
693
+ bank = AlignmentBank(
694
+ d_embed=768, n_experts=len(EXPERTS),
695
+ n_anchors=512, d_bank=128
696
+ ).to(DEVICE)
697
+
698
+ bank.init_from_procrustes(procrustes_results, names,
699
+ consensus[:n_train], consensus_stats)
700
+ bank_params = sum(p.numel() for p in bank.parameters())
701
+ print(f" Bank: {bank_params:,} params")
702
+ print(f" Bank targets: CV={bank.target_cv.item():.4f}, "
703
+ f"mean_cos={bank.target_mean_cos.item():.4f}")
704
+
705
+ # Calibrate disagreement from initial state (before any training)
706
+ bank.calibrate_disagreement(student_embs[:2000])
707
+
708
+ bank_opt = torch.optim.AdamW(bank.parameters(), lr=1e-3, weight_decay=0.01)
709
+ BANK_EPOCHS = 20
710
+ BANK_BATCH = 256
711
+
712
+ for epoch in range(BANK_EPOCHS):
713
+ bank.train()
714
+ perm = torch.randperm(n_train, device=DEVICE)
715
+ total_loss = 0
716
+ stats = {"expert_agreement": 0, "rotation_ortho": 0,
717
+ "anchor_spread": 0, "bank_cv": 0, "emb_cv": 0,
718
+ "cross_expert_var": 0, "disagree_preserve": 0}
719
+ n = 0
720
+ t0 = time.time()
721
+ for i in range(0, n_train, BANK_BATCH):
722
+ idx = perm[i:i+BANK_BATCH]
723
+ if len(idx) < 16: continue
724
+ emb = student_embs[idx]
725
+ enriched, aux = bank(emb)
726
+ loss = bank.bank_loss(aux)
727
+ loss.backward()
728
+ torch.nn.utils.clip_grad_norm_(bank.parameters(), 1.0)
729
+ bank_opt.step(); bank_opt.zero_grad(set_to_none=True)
730
+ total_loss += loss.item()
731
+ for k in stats:
732
+ if k in aux:
733
+ v = aux[k]
734
+ stats[k] += v.item() if torch.is_tensor(v) else v
735
+ n += 1
736
+ elapsed = time.time() - t0; d = max(n, 1)
737
+
738
+ bank.eval()
739
+ with torch.no_grad():
740
+ v_enriched, v_aux = bank(val_student_embs)
741
+ v_loss = bank.bank_loss(v_aux).item()
742
+
743
+ print(f"\n E{epoch+1:2d}: {elapsed:.0f}s loss={total_loss/d:.4f} v_loss={v_loss:.4f}")
744
+ print(f" Geometry: b_cv={stats['bank_cv']/d:.4f} e_cv={stats['emb_cv']/d:.4f} "
745
+ f"spread={stats['anchor_spread']/d:.5f} a_max={v_aux['anchor_max_cos']:.3f}")
746
+ print(f" Experts: cos={v_aux['expert_cos_mean']:.3f}±{v_aux['expert_cos_std']:.3f} "
747
+ f"agr={stats['expert_agreement']/d:.6f} ortho={stats['rotation_ortho']/d:.6f}")
748
+ print(f" Disagree: x_cos={v_aux.get('cross_expert_cos', 0):.4f}±{v_aux.get('cross_expert_cos_std', 0):.4f} "
749
+ f"ratio={v_aux['disagreement_ratio']:.6f} "
750
+ f"preserve={stats['disagree_preserve']/d:.6f} "
751
+ f"norms={v_aux['norm_ratio_spread']:.4f}")
752
+
753
+ torch.save(bank.state_dict(), "alignment_bank.pt")
754
+
755
+ # ── Phase 3: Geometric Verification ──
756
+ print(f"\n{'='*65}")
757
+ print("PHASE 3: GEOMETRIC VERIFICATION")
758
+ print(f"{'='*65}")
759
+
760
+ bank.eval()
761
+ with torch.no_grad():
762
+ enriched_val, v_aux = bank(val_student_embs)
763
+ original_768 = enriched_val[:, :768]
764
+ geo_context = enriched_val[:, 768:]
765
+
766
+ passthrough_cos = F.cosine_similarity(
767
+ original_768[:100], val_student_embs[:100], dim=-1).mean().item()
768
+ geo_cv = cv_metric(F.normalize(geo_context[:1000], dim=-1))
769
+ S = torch.linalg.svdvals(
770
+ geo_context[:1000].float() - geo_context[:1000].float().mean(0))
771
+ geo_eff_dim = float((S.sum() ** 2) / (S.pow(2).sum() + 1e-12))
772
+
773
+ # Verify consensus stats are preserved
774
+ emb_cv = cv_metric(val_student_embs[:1000])
775
+
776
+ print(f" Passthrough: {passthrough_cos:.6f} (target: 1.000)")
777
+ print(f" Emb CV: {emb_cv:.4f} (consensus: {consensus_stats['cv']:.4f})")
778
+ print(f" Geo context CV: {geo_cv:.4f}")
779
+ print(f" Geo eff_dim: {geo_eff_dim:.1f} / {bank.d_bank}")
780
+ print(f" Expert cos: {v_aux['expert_cos_mean']:.3f} ± {v_aux['expert_cos_std']:.3f}")
781
+ print(f" Anchor max cos: {v_aux['anchor_max_cos']:.3f}")
782
+ print(f" Disagreement:")
783
+ print(f" Cross-expert: {v_aux.get('cross_expert_cos', 0):.4f} ± {v_aux.get('cross_expert_cos_std', 0):.4f}")
784
+ print(f" Ratio: {v_aux['disagreement_ratio']:.6f} (target: {bank.target_disagreement_ratio.item():.6f})")
785
+ print(f" Norm spread: {v_aux['norm_ratio_spread']:.4f}")
786
+
787
+ # ── Phase 4: Classifier Stability Test ──
788
+ print(f"\n{'='*65}")
789
+ print("PHASE 4: CLASSIFIER STABILITY TEST")
790
+ print(f"{'='*65}")
791
+
792
+ with torch.no_grad():
793
+ embs = val_student_embs[:1000]
794
+ sim = embs @ embs.T
795
+ sim.fill_diagonal_(-1)
796
+ n_pairs = 3000
797
+ idx_a = torch.randint(0, 1000, (n_pairs,))
798
+ idx_b = torch.randint(0, 1000, (n_pairs,))
799
+ pair_cos = sim[idx_a, idx_b]
800
+ sorted_cos, _ = pair_cos.sort()
801
+ t1 = sorted_cos[n_pairs // 3].item()
802
+ t2 = sorted_cos[2 * n_pairs // 3].item()
803
+ labels = torch.zeros(n_pairs, dtype=torch.long, device=DEVICE)
804
+ labels[pair_cos > t2] = 0
805
+ labels[(pair_cos <= t2) & (pair_cos > t1)] = 1
806
+ labels[pair_cos <= t1] = 2
807
+ enriched_a, _ = bank(embs[idx_a])
808
+ enriched_b, _ = bank(embs[idx_b])
809
+
810
+ for mode in ["with_bank", "without_bank"]:
811
+ if mode == "with_bank":
812
+ feat_dim = (768 + 128) * 2
813
+ features = torch.cat([enriched_a, enriched_b], dim=-1)
814
+ else:
815
+ feat_dim = 768 * 2
816
+ features = torch.cat([embs[idx_a], embs[idx_b]], dim=-1)
817
+
818
+ clf = nn.Sequential(
819
+ nn.Linear(feat_dim, 256), nn.GELU(), nn.LayerNorm(256),
820
+ nn.Linear(256, 3)
821
+ ).to(DEVICE)
822
+ clf_opt = torch.optim.Adam(clf.parameters(), lr=1e-3)
823
+ n_clf_train = 2400
824
+ train_f = features[:n_clf_train].detach()
825
+ train_l = labels[:n_clf_train]
826
+ val_f = features[n_clf_train:].detach()
827
+ val_l = labels[n_clf_train:]
828
+ for e in range(30):
829
+ clf.train()
830
+ logits = clf(train_f)
831
+ loss = F.cross_entropy(logits, train_l)
832
+ loss.backward(); clf_opt.step(); clf_opt.zero_grad()
833
+ clf.eval()
834
+ with torch.no_grad():
835
+ v_acc = (clf(val_f).argmax(-1) == val_l).float().mean().item()
836
+ t_acc = (clf(train_f).argmax(-1) == train_l).float().mean().item()
837
+ print(f" {mode:15s}: train={t_acc:.3f} val={v_acc:.3f} gap={t_acc-v_acc:.3f}")
838
+
839
+ print(f"\n{'='*65}")
840
+ print("SUMMARY")
841
+ print(f"{'='*65}")
842
+ print(f" Consensus CV: {consensus_stats['cv']:.4f}")
843
+ print(f" Consensus eff_dim:{consensus_stats['eff_dim']:.1f}")
844
+ print(f" Student v_cos: {v_cos:.3f}")
845
+ print(f" Student v_cv: {v_cv:.3f}")
846
+ print(f" Bank params: {bank_params:,}")
847
+ print(f" Bank geo_eff_dim: {geo_eff_dim:.1f}")
848
+ print(f" Bank geo_cv: {geo_cv:.4f}")
849
+ print(f"\n{'='*65}")
850
+ print("DONE")
851
+ print(f"{'='*65}")
852
+
853
+
854
+ if __name__ == "__main__":  # Entry point: call run() only when this file is executed as a script, not when it is imported.
855
+ run()