AbstractPhil committed on
Commit
f4ce5fa
·
verified ·
1 Parent(s): b37f720

Create 11_test_claim_probe.py

Browse files
Files changed (1) hide show
  1. 11_test_claim_probe.py +642 -0
11_test_claim_probe.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ implicit_solver/A0_projective_reprobe.py
3
+ =========================================
4
+
5
+ Test claim 3: G-Cand is actually a 14-axis ℝP² solver, not a 32-point S² solver.
6
+
7
+ Method
8
+ ------
9
+ 1. Load G-Cand (Q-rank09, V=32, D=3) — already trained sphere-solver.
10
+ 2. Collect M tensor as before — 512 samples × 32 rows × 3 dims.
11
+ 3. Identify antipodal pairs in the canonical M_avg arrangement:
12
+ row i and row j form a pair if cos(M_avg[i], M_avg[j]) < -0.9
13
+ 4. Collapse: for each pair, pick canonical representative (the one with
14
+ positive first nonzero coordinate). Yields up to 16 axis representatives.
15
+ 5. Re-run v2 probe metrics under projective geometry:
16
+ - Pairwise angles wrapped to [0, π/2] via θ → min(θ, π - θ)
17
+ - Uniform ℝP² baseline: pairwise angles peak at π/4 (not π/2)
18
+ - Cluster, stability, antipodal-of-antipodal (testing if axes themselves
19
+ have further antipodal structure within ℝP²)
20
+
21
+ Predicted outcomes
22
+ ------------------
23
+ A. CLEAN PROJECTIVE: 14 axes uniformly cover ℝP², pairwise angles peak at
24
+ π/4, no further antipodal collapse.
25
+ → G-Cand is a clean 14-axis ℝP² solver. Sphere-norm was the wrong
26
+ reading. The true geometry is projective.
27
+
28
+ B. STILL DEGENERATE: 14 axes show further structure (clustering, secondary
29
+ antipodal pairs, non-uniform).
30
+ → G-Cand is structured beyond simple ℝP² uniform. Some other geometry
31
+ applies, or the antipodal collapse hypothesis is incomplete.
32
+
33
+ C. ANTI-PROJECTIVE: 14 axes are NOT uniformly distributed on ℝP²; they
34
+ show strong clustering or aligned-direction patterns.
35
+ → The "spindle collapse" was real but isn't ℝP² either. The geometry is
36
+ something more degenerate (line, plane subset, etc.)
37
+
38
+ Cost
39
+ ----
40
+ Same trained checkpoint, different probe math. ~10 seconds.
41
+
42
+ Output
43
+ ------
44
+ /content/implicit_solver_reports/A0_projective_reprobe.json
45
+ /content/implicit_solver_reports/A0_projective_reprobe.png
46
+ """
47
+
48
+ import json
49
+ import math
50
+ from pathlib import Path
51
+
52
+ import numpy as np
53
+ import torch
54
+ import torch.nn.functional as F
55
+ import matplotlib.pyplot as plt
56
+ from mpl_toolkits.mplot3d import Axes3D # noqa
57
+ from sklearn.cluster import KMeans
58
+ from sklearn.metrics import silhouette_score
59
+
60
+
61
# Location of the trained G-Cand (Q-rank09) checkpoint.
CKPT_DIR = Path("/content/phaseQ_reports")
RANK09_CKPT = CKPT_DIR / "Q_rank09_h64_V32_D3_dp0_nx0_adam" / "epoch_1_checkpoint.pt"

# Report directory. It is created at import time so the JSON/PNG writes in
# main() cannot fail on a missing folder.
OUTPUT_DIR = Path("/content/implicit_solver_reports")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_PLOT = OUTPUT_DIR / "A0_projective_reprobe.png"
OUTPUT_JSON = OUTPUT_DIR / "A0_projective_reprobe.json"
68
+
69
+
70
+ # ════════════════════════════════════════════════════════════════════
71
+ # Loading
72
+ # ════════════════════════════════════════════════════════════════════
73
+
74
def load_g_cand():
    """Build the G-Cand (Q-rank09) model and load its trained checkpoint.

    Picks the first phase-Q config whose ``variant`` contains ``'rank09'``,
    constructs a ``PatchSVAE_F_Ablation`` from it, loads the weights stored
    at ``RANK09_CKPT`` onto the CPU and switches the model to eval mode.

    NOTE(review): ``get_phaseQ_configs``, ``build_run_config`` and
    ``PatchSVAE_F_Ablation`` are neither defined nor imported in this file;
    presumably they are provided by earlier notebook cells -- confirm.

    Returns:
        tuple: ``(model, cfg)`` -- the model in eval mode and its run config.

    Raises:
        LookupError: if no config has ``'rank09'`` in its ``variant``.
    """
    cfgs = get_phaseQ_configs()
    cfg_dict = next((c for c in cfgs if 'rank09' in c['variant']), None)
    if cfg_dict is None:
        raise LookupError("no phase-Q config with 'rank09' in its variant")
    cfg = build_run_config(cfg_dict)
    overrides = cfg_dict['overrides']

    model = PatchSVAE_F_Ablation(
        matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
        hidden=cfg.hidden, depth=cfg.depth,
        n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
        max_alpha=overrides.get('max_alpha', cfg.max_alpha),
        alpha_init=cfg.alpha_init,
        activation=overrides.get('activation', 'gelu'),
        row_norm=overrides.get('row_norm', 'sphere'),
        svd_mode=overrides.get('svd', 'fp64'),
        linear_readout=overrides.get('linear_readout', False),
        match_params=overrides.get('match_params', True),
        init_scheme=overrides.get('init', 'orthogonal'),
    )

    # SECURITY: weights_only=False unpickles arbitrary objects. Only load
    # checkpoints you produced yourself or otherwise trust.
    ckpt = torch.load(RANK09_CKPT, map_location='cpu', weights_only=False)
    # The weights may sit under one of several keys, or the checkpoint may
    # itself be the state dict; the first non-empty candidate wins.
    state_dict = (
        ckpt.get('model_state')
        or ckpt.get('model_state_dict')
        or ckpt.get('state_dict')
        or ckpt
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model, cfg
104
+
105
+
106
def collect_per_sample_M(model, cfg, n_batches=8, batch_size=64):
    """Run the model on synthetic inputs and gather the patch-0 ``M`` matrices.

    The model is moved to CUDA when available (otherwise CPU) and evaluated
    without gradients on ``n_batches * batch_size`` samples drawn from
    ``OmegaNoiseDataset`` restricted to ``allowed_types=[0]``.

    NOTE(review): ``OmegaNoiseDataset`` is not defined in this file. The
    module docstring describes the result as 512 samples x 32 rows x 3 dims
    for the default arguments -- confirm against the model's output.

    Args:
        model: module whose forward output exposes ``out['svd']['M']``.
        cfg: run config; only ``cfg.img_size`` is read here.
        n_batches: number of batches to evaluate.
        batch_size: samples per batch.

    Returns:
        numpy.ndarray: per-sample ``M[:, 0]`` slices concatenated along the
        first axis.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    ds = OmegaNoiseDataset(
        size=n_batches * batch_size, img_size=cfg.img_size,
        allowed_types=[0])
    loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False)

    all_M = []
    with torch.no_grad():
        for imgs, _ in loader:
            out = model(imgs.to(device))
            # Keep only index 0 along the second axis of M and move it to
            # the CPU so the results can be concatenated and converted.
            all_M.append(out['svd']['M'][:, 0].cpu())
    return torch.cat(all_M, dim=0).numpy()
122
+
123
+
124
+ # ════════════════════════════════════════════════════════════════════
125
+ # Antipodal pair identification + projective collapse
126
+ # ════════════════════════════════════════════════════════════════════
127
+
128
def identify_antipodal_pairs(M_avg, threshold=-0.9):
    """Greedily match rows of ``M_avg`` into antipodal pairs.

    Two rows are antipodal candidates when the cosine between them is below
    ``threshold``. All candidate pairs are ranked from most to least
    antipodal, and each pair is accepted if neither row has been claimed
    yet. A row whose strongest partner is already taken can therefore still
    pair with its next-best unclaimed partner.

    Args:
        M_avg: array-like of shape ``[V, D]``; rows need not be normalised.
        threshold: cosine cut-off; a pair qualifies when ``cos < threshold``.

    Returns:
        tuple: ``(pairs, unpaired)`` where ``pairs`` is a list of ``(i, j)``
        tuples with ``i < j`` (strongest pairing first) and ``unpaired`` is
        the ascending list of row indices left without a partner.
    """
    M_avg = np.asarray(M_avg, dtype=float)
    norms = np.linalg.norm(M_avg, axis=1, keepdims=True)
    unit = M_avg / np.clip(norms, 1e-12, None)  # zero rows stay zero
    cosines = unit @ unit.T

    V = M_avg.shape[0]
    # The cosine matrix is symmetric, so the strict upper triangle lists
    # every unordered pair exactly once and excludes self-pairs.
    rows, cols = np.triu_indices(V, k=1)
    pair_cos = cosines[rows, cols]
    qualifies = pair_cos < threshold
    candidates = sorted(zip(pair_cos[qualifies].tolist(),
                            rows[qualifies].tolist(),
                            cols[qualifies].tolist()))  # most negative first

    claimed = [False] * V
    pairs = []
    for _cos, i, j in candidates:
        if claimed[i] or claimed[j]:
            continue
        pairs.append((i, j))
        claimed[i] = True
        claimed[j] = True

    unpaired = [i for i in range(V) if not claimed[i]]
    return pairs, unpaired
168
+
169
+
170
def _canonical_sign(vec, eps=1e-6):
    """Return ``vec`` flipped so its first coordinate with |x| > eps is positive."""
    for value in vec:
        if abs(value) > eps:
            return -vec if value < 0 else vec
    return vec  # numerically all-zero: nothing to canonicalise


def collapse_to_axes(M_avg, pairs, unpaired):
    """Collapse antipodal row pairs into one unit axis each.

    For a pair ``(i, j)`` the axis is the normalised difference
    ``unit[i] - unit[j]``; because the rows point in roughly opposite
    directions this averages the two estimates of the shared line. Unpaired
    rows are kept as their own unit vector. Every axis is sign-canonicalised
    so that its first non-negligible coordinate is positive.

    Args:
        M_avg: array-like of shape ``[V, D]``.
        pairs: iterable of ``(i, j)`` row-index pairs.
        unpaired: iterable of row indices without a partner.

    Returns:
        numpy.ndarray: axes of shape ``[len(pairs) + len(unpaired), D]``,
        pairs first, then unpaired rows, each in input order. The shape is
        ``(0, D)`` when both inputs are empty.
    """
    M_avg = np.asarray(M_avg, dtype=float)
    norms = np.linalg.norm(M_avg, axis=1, keepdims=True)
    unit = M_avg / np.clip(norms, 1e-12, None)

    representatives = []
    for i, j in pairs:
        merged = unit[i] - unit[j]
        merged = merged / max(np.linalg.norm(merged), 1e-12)
        representatives.append(_canonical_sign(merged))

    for i in unpaired:
        representatives.append(_canonical_sign(unit[i].copy()))

    # reshape keeps the column count even when there are no axes at all.
    return np.array(representatives, dtype=float).reshape(-1, M_avg.shape[1])
221
+
222
+
223
+ # ════════════════════════════════════════════════════════════════════
224
+ # Projective metrics
225
+ # ════════════════════════════════════════════════════════════════════
226
+
227
def projective_pairwise_angles(axes):
    """Pairwise angles between axes on ℝP^(D-1).

    Each axis is a line through the origin, so ``v`` and ``-v`` are the same
    axis and the angle between two axes is ``min(θ, π - θ)`` in ``[0, π/2]``.

    Args:
        axes: array-like of shape ``[n, D]``. Rows are assumed to be unit
            vectors, since their dot products are used directly as cosines.

    Returns:
        numpy.ndarray: the ``n * (n - 1) / 2`` angles in radians, one per
        unordered pair, in row-major upper-triangle order.
    """
    axes = np.asarray(axes, dtype=float)
    n = axes.shape[0]
    # Clip guards arccos against rounding that pushes |cos| just past 1.
    cosines = np.clip(axes @ axes.T, -1, 1)
    raw_angles = np.arccos(cosines)
    proj_angles = np.minimum(raw_angles, np.pi - raw_angles)
    return proj_angles[np.triu_indices(n, k=1)]
240
+
241
+
242
def uniform_rp_pairwise_angle_baseline(D, n_axes, n_trials=10):
    """Monte-Carlo mean pairwise projective angle for uniformly random axes.

    Each trial draws ``n_axes`` directions uniformly on S^(D-1) (normalised
    Gaussians), identifies antipodes, and takes the mean of the projective
    pairwise angles. The trial means are then averaged. The RNG seed is
    fixed at 0, so the result is deterministic.

    Args:
        D: ambient dimension.
        n_axes: number of axes per trial.
        n_trials: number of independent trials to average.

    Returns:
        float: the baseline mean angle in radians, or NaN when
        ``n_axes < 2`` because no pair exists.
    """
    if n_axes < 2:
        return float('nan')

    rng = np.random.RandomState(0)
    row_idx = np.arange(n_axes)
    means = []
    for _ in range(n_trials):
        x = rng.randn(n_axes, D)
        x = x / np.linalg.norm(x, axis=1, keepdims=True)
        # Canonical representative: flip each row so its first nonzero
        # coordinate is positive. The projective angle is sign-invariant,
        # so this only fixes the representative.
        first_nonzero = (x != 0).argmax(axis=1)
        leading = x[row_idx, first_nonzero]
        x[leading < 0] *= -1
        angles = projective_pairwise_angles(x)
        means.append(angles.mean())
    return float(np.mean(means))
263
+
264
+
265
def test_axis_distribution(axes, label):
    """Run all probe metrics on the projective axes and print a report.

    Despite the ``test_`` prefix this is an analysis routine, not a unit
    test. It computes:
      * projective pairwise-angle statistics and their deviation from the
        Monte-Carlo uniform baseline;
      * the fraction of near-parallel pairs and of pairs near π/4;
      * the best KMeans silhouette for k in ``2 .. min(8, n) - 1``;
      * the entropy-based effective rank of the axis matrix;
      * how many axes still have a partner with cosine below -0.9.

    NOTE(review): for directions uniform on S² folded to [0, π/2] the
    pairwise-angle density is sin θ (mean 1 rad, increasing toward π/2), so
    the "uniform peak" at π/4 reported below looks questionable -- confirm
    the intended reference angle. The metric is kept as-is and does not
    enter any verdict.

    Args:
        axes: ``[n, D]`` numpy array of unit axis representatives.
        label: heading printed above the report.

    Returns:
        dict: JSON-serialisable metrics (plain Python numbers, strings,
        lists).

    Raises:
        ValueError: if fewer than two axes are supplied.
    """
    n, D = axes.shape
    if n < 2:
        raise ValueError(f"need at least 2 axes for pairwise metrics, got {n}")

    print(f"\n[{label}]")
    print(f" Axes shape: {axes.shape}")

    # Pairwise angles in projective space
    proj_angles = projective_pairwise_angles(axes)

    print(f" Projective pairwise angles (radians, max possible π/2={math.pi/2:.3f}):")
    print(f" mean: {proj_angles.mean():.3f}")
    print(f" median: {np.median(proj_angles):.3f}")
    print(f" min: {proj_angles.min():.3f}")
    print(f" max: {proj_angles.max():.3f}")

    # Monte-Carlo uniform baseline for ℝP^(D-1)
    uniform_baseline = uniform_rp_pairwise_angle_baseline(D, n)
    deviation = proj_angles.mean() - uniform_baseline
    print(f" Uniform ℝP^{D-1} baseline (n={n}): {uniform_baseline:.3f}")
    print(f" Deviation: {deviation:+.3f} "
          f"({'CLOSE TO UNIFORM' if abs(deviation) < 0.05 else 'NON-UNIFORM'})")

    # Fraction at small angles (axis clustering) and around π/4
    fraction_clustered = (proj_angles < 0.3).mean()
    fraction_perp = ((proj_angles > math.pi/4 - 0.15) &
                     (proj_angles < math.pi/4 + 0.15)).mean()
    print(f" Fraction near-zero (axes parallel): {fraction_clustered:.3f}")
    print(f" Fraction near π/4 (uniform peak): {fraction_perp:.3f}")

    # Cluster analysis on the axes themselves (not on the original M)
    sils = []
    for k in range(2, min(8, n)):
        try:
            km = KMeans(n_clusters=k, n_init=10, random_state=42)
            labels = km.fit_predict(axes)
            if len(set(labels)) >= 2:
                # float(): keep the value a plain Python float for JSON.
                sils.append((k, float(silhouette_score(axes, labels))))
        except Exception as err:
            # Best effort: a failed k is reported and skipped.
            print(f" (clustering with k={k} skipped: {err})")

    if sils:
        best_k, best_sil = max(sils, key=lambda x: x[1])
        print(f" Best cluster k={best_k}, silhouette={best_sil:.3f}")
        cluster_verdict = (
            'STRONG (real clusters)' if best_sil > 0.5 else
            'WEAK (some structure)' if best_sil > 0.3 else
            'NONE (continuous distribution)'
        )
        print(f" Cluster verdict: {cluster_verdict}")
    else:
        best_k, best_sil = None, None
        cluster_verdict = 'N/A'

    # Effective rank: exp of the Shannon entropy of the normalised
    # singular values.
    sv = np.linalg.svd(axes, compute_uv=False)
    sv_norm = sv / sv.sum()
    erank = math.exp(-(sv_norm * np.log(sv_norm + 1e-12)).sum())
    print(f" Effective rank: {erank:.2f} of {D} possible "
          f"({erank/D*100:.0f}% utilization)")

    # SECONDARY antipodal structure: if sign-canonicalised axes still have
    # near-opposite partners, the geometry is more degenerate than
    # ℝP^(D-1).
    cos_axes = axes @ axes.T
    np.fill_diagonal(cos_axes, 1.0)  # exclude self from the minimum
    most_anti = cos_axes.min(axis=1)
    # Each antipodal pair contributes two rows, hence the halving.
    secondary_anti = (most_anti < -0.9).sum() // 2
    print(f" Secondary antipodal pairs (axes paired again): "
          f"{secondary_anti}/{n//2}")

    return {
        'n_axes': int(n),
        'D': int(D),
        'proj_angle_mean': float(proj_angles.mean()),
        'proj_angle_median': float(np.median(proj_angles)),
        'proj_angle_min': float(proj_angles.min()),
        'proj_angle_max': float(proj_angles.max()),
        'uniform_baseline': uniform_baseline,
        'deviation_from_uniform': float(deviation),
        'fraction_clustered': float(fraction_clustered),
        'fraction_near_pi_over_4': float(fraction_perp),
        'best_cluster_k': best_k,
        'best_silhouette': best_sil,
        'cluster_verdict': cluster_verdict,
        'effective_rank': float(erank),
        'utilization': float(erank / D),
        'secondary_antipodal_pairs': int(secondary_anti),
        'proj_angles_subset': proj_angles[:200].tolist(),
    }
356
+
357
+
358
+ # ════════════════════════════════════════════════════════════════════
359
+ # Plotting
360
+ # ════════════════════════════════════════════════════════════════════
361
+
362
def plot_projective(M_avg, axes, pairs, unpaired, results, output_path):
    """Render the six-panel projective re-probe figure and save it.

    Panels: (1) original rows on S² with antipodal pairs linked, (2) the
    collapsed axes drawn as lines through the origin, (3) histogram of
    projective angles against the uniform baseline, (4) silhouette score
    per k, (5) singular values of the axis matrix, (6) composite verdict
    text plus key metrics.

    NOTE(review): the verdict branches here differ from the ones printed by
    ``main()`` (this function has a "MOSTLY ℝP²" branch and an "UNCLEAR"
    fallback that ``main()`` lacks) -- confirm which wording is intended.

    Args:
        M_avg: ``[V, 3]`` array of averaged rows (3-D scatter is assumed).
        axes: ``[n, 3]`` array of collapsed axes.
        pairs: list of ``(i, j)`` antipodal row pairs.
        unpaired: list of unpaired row indices.
        results: metrics dict as returned by ``test_axis_distribution``.
        output_path: destination of the PNG.
    """
    fig = plt.figure(figsize=(18, 12))

    # Panel 1: Original M_avg on S² with pairings highlighted
    ax1 = fig.add_subplot(2, 3, 1, projection='3d')
    norms = np.linalg.norm(M_avg, axis=1, keepdims=True)
    unit = M_avg / np.clip(norms, 1e-12, None)

    # Wireframe sphere
    u = np.linspace(0, 2*np.pi, 20)
    v = np.linspace(0, np.pi, 20)
    x_s = np.outer(np.cos(u), np.sin(v))
    y_s = np.outer(np.sin(u), np.sin(v))
    z_s = np.outer(np.ones_like(u), np.cos(v))
    ax1.plot_wireframe(x_s, y_s, z_s, alpha=0.1, color='gray')

    # Color paired rows by pair index
    pair_colors = plt.cm.tab20(np.linspace(0, 1, max(len(pairs), 1)))
    for k, (i, j) in enumerate(pairs):
        color = pair_colors[k]
        ax1.scatter(unit[i, 0], unit[i, 1], unit[i, 2],
                    c=[color], s=80, edgecolors='black', linewidths=0.5)
        ax1.scatter(unit[j, 0], unit[j, 1], unit[j, 2],
                    c=[color], s=80, edgecolors='black', linewidths=0.5)
        # Line connecting the antipodal pair
        ax1.plot([unit[i, 0], unit[j, 0]],
                 [unit[i, 1], unit[j, 1]],
                 [unit[i, 2], unit[j, 2]],
                 color=color, alpha=0.3, linewidth=0.8)
    # Unpaired rows in red
    for i in unpaired:
        ax1.scatter(unit[i, 0], unit[i, 1], unit[i, 2],
                    c='red', marker='x', s=100, linewidths=2)
    ax1.set_title(f'Original M_avg on S²\n'
                  f'{len(pairs)} antipodal pairs (colored), '
                  f'{len(unpaired)} unpaired (red ×)')

    # Panel 2: Collapsed axes on upper hemisphere (canonical reps)
    ax2 = fig.add_subplot(2, 3, 2, projection='3d')
    ax2.plot_wireframe(x_s, y_s, z_s, alpha=0.1, color='gray')
    for k, axis_vec in enumerate(axes):
        axis_color = plt.cm.tab20(k % 20)
        ax2.scatter(axis_vec[0], axis_vec[1], axis_vec[2], c=[axis_color],
                    s=120, edgecolors='black', linewidths=0.5)
        # Draw line through origin to show it's an AXIS not a point
        ax2.plot([-axis_vec[0], axis_vec[0]],
                 [-axis_vec[1], axis_vec[1]],
                 [-axis_vec[2], axis_vec[2]],
                 color=axis_color, alpha=0.4, linewidth=1.0)
    ax2.set_title(f'Collapsed axes (n={axes.shape[0]})\n'
                  f'Each line through origin = one axis on ℝP²')

    # Panel 3: Projective angle distribution vs uniform baseline
    ax3 = fig.add_subplot(2, 3, 3)
    proj_angles = results['proj_angles_subset']
    ax3.hist(proj_angles, bins=30, density=True, alpha=0.7,
             color='steelblue', label='G-Cand projective')
    ax3.axvline(results['uniform_baseline'], color='red', linestyle='--',
                label=f"uniform ℝP² baseline ({results['uniform_baseline']:.3f})")
    ax3.axvline(math.pi/4, color='green', linestyle=':',
                label=f'π/4 = {math.pi/4:.3f}')
    ax3.set_xlabel('Projective pairwise angle (radians, max π/2)')
    ax3.set_ylabel('Density')
    ax3.set_title(f'Projective angle distribution\n'
                  f"deviation: {results['deviation_from_uniform']:+.3f}")
    ax3.legend(fontsize=8)

    # Panel 4: Cluster silhouette across k (recomputed; only the best value
    # is stored in `results`)
    ax4 = fig.add_subplot(2, 3, 4)
    if results['best_cluster_k'] is not None:
        ks_sils = []
        for k in range(2, min(8, axes.shape[0])):
            try:
                km = KMeans(n_clusters=k, n_init=10, random_state=42)
                labels = km.fit_predict(axes)
                if len(set(labels)) >= 2:
                    ks_sils.append((k, silhouette_score(axes, labels)))
            except Exception as err:
                # Best effort: a failed k simply leaves a gap in the curve.
                print(f" (clustering with k={k} skipped in plot: {err})")
        if ks_sils:
            ks, sils = zip(*ks_sils)
            ax4.plot(ks, sils, 'o-', color='purple', markersize=8)
            ax4.axhline(0.5, color='red', linestyle='--', alpha=0.5,
                        label='strong cluster')
            ax4.axhline(0.3, color='orange', linestyle='--', alpha=0.5,
                        label='weak cluster')
            # Legend only when labelled artists exist; this avoids a
            # matplotlib warning on an empty panel.
            ax4.legend(fontsize=8)
    ax4.set_xlabel('k (number of clusters)')
    ax4.set_ylabel('silhouette score')
    ax4.set_title(f"Axis clustering\n"
                  f"verdict: {results['cluster_verdict']}")
    ax4.grid(alpha=0.3)

    # Panel 5: Singular values of the axis matrix
    ax5 = fig.add_subplot(2, 3, 5)
    sv = np.linalg.svd(axes, compute_uv=False)
    ax5.bar([f'σ{i+1}' for i in range(len(sv))], sv,
            color=['red', 'orange', 'yellow'][:len(sv)])
    ax5.set_ylabel('Singular value')
    ax5.set_title(f"Singular values of axis matrix\n"
                  f"effective rank: {results['effective_rank']:.2f} "
                  f"of {results['D']}")

    # Panel 6: Composite verdict
    ax6 = fig.add_subplot(2, 3, 6)
    ax6.axis('off')

    # Decide composite verdict
    is_uniform = abs(results['deviation_from_uniform']) < 0.05
    is_clustered = (results['best_silhouette'] or 0) > 0.5
    has_secondary_antipodal = results['secondary_antipodal_pairs'] >= 3
    full_rank = results['utilization'] > 0.95
    n_axes = results['n_axes']  # report the measured count, not a constant

    if is_uniform and not is_clustered and not has_secondary_antipodal and full_rank:
        verdict = "✓ CLEAN ℝP² SOLVER"
        explanation = (
            f"G-Cand was a {n_axes}-axis projective-space solver all along.\n"
            "Sphere-norm was the wrong reading — the true geometry\n"
            "is uniform on ℝP². Claim 3 SUPPORTED."
        )
        color = 'lightgreen'
    elif is_uniform and not is_clustered and full_rank:
        verdict = "✓ MOSTLY ℝP², minor irregularities"
        explanation = (
            "Axes are roughly uniform on ℝP² with some structure.\n"
            "Claim 3 PARTIALLY SUPPORTED — projective interpretation\n"
            "is the right space, but distribution isn't perfectly uniform."
        )
        color = 'palegreen'
    elif is_clustered:
        verdict = "✗ STRUCTURED on ℝP²"
        explanation = (
            "Axes show genuine cluster structure on ℝP².\n"
            "Not uniform, not random — something more specific.\n"
            "May be a polytope on ℝP² (less common) or other geometry."
        )
        color = 'lightyellow'
    elif has_secondary_antipodal:
        verdict = "✗ FURTHER COLLAPSE"
        explanation = (
            "Even after antipodal collapse, axes show NEW antipodal pairs.\n"
            "Geometry is more degenerate than ℝP² — possibly lens space,\n"
            "or D-effective lower than D=3."
        )
        color = 'mistyrose'
    elif not full_rank:
        verdict = "✗ DEGENERATE — sub-rank"
        explanation = (
            "Axes don't span the full D=3 space.\n"
            "Effective rank < 3 means rows live on a 2D plane or 1D line\n"
            "in 3D space. Spindle hypothesis dimension-collapsed."
        )
        color = 'lightcoral'
    else:
        verdict = "? UNCLEAR"
        explanation = (
            "Mixed signals — re-examine the metrics individually.\n"
            "Antipodal hypothesis neither confirmed nor cleanly refuted."
        )
        color = 'lightgray'

    ax6.text(0.5, 0.85, verdict, ha='center', va='top',
             fontsize=20, fontweight='bold',
             bbox=dict(boxstyle='round', facecolor=color, alpha=0.8))
    ax6.text(0.05, 0.55, explanation, ha='left', va='top', fontsize=11,
             wrap=True, family='monospace')

    metrics_summary = (
        f"\n\nKey metrics:\n"
        f" axes: {results['n_axes']}\n"
        f" proj angle mean: {results['proj_angle_mean']:.3f}\n"
        f" uniform baseline: {results['uniform_baseline']:.3f}\n"
        f" deviation: {results['deviation_from_uniform']:+.3f}\n"
        f" best cluster silhouette: {results['best_silhouette'] or 0:.3f}\n"
        f" effective rank: {results['effective_rank']:.2f}/{results['D']}\n"
        f" secondary antipodal: {results['secondary_antipodal_pairs']}"
    )
    ax6.text(0.05, 0.30, metrics_summary, ha='left', va='top',
             fontsize=10, family='monospace')

    plt.tight_layout()
    plt.savefig(output_path, dpi=120, bbox_inches='tight')
    plt.show()
542
+
543
+
544
+ # ════════════════════════════════════════════════════════════════════
545
+ # Main
546
+ # ════════════════════════════════════════════════════════════════════
547
+
548
def main():
    """Run the full projective re-probe and report the verdict on claim 3.

    Steps: load the G-Cand checkpoint, collect per-sample ``M`` matrices and
    average them, identify antipodal row pairs (cos < -0.9), collapse them
    to axes, compute the projective metrics, write the JSON report and the
    PNG figure, then print a headline conclusion.

    NOTE(review): the conclusion branches below are not identical to the
    verdict branches drawn by ``plot_projective`` -- confirm which wording
    is intended.

    Returns:
        dict: the data written to ``OUTPUT_JSON``.
    """
    def _json_default(obj):
        # Keep numpy scalars and arrays numeric in the report; anything
        # else falls back to its string form.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return str(obj)

    print("=" * 70)
    print("Projective re-probe of G-Cand (Q-rank09, V=32, D=3)")
    print("Testing claim 3: trained sphere-solver is actually a ℝP² solver")
    print("=" * 70)

    print("\nLoading G-Cand checkpoint...")
    model, cfg = load_g_cand()
    print(f" V={cfg.matrix_v}, D={cfg.D}, "
          f"params={sum(p.numel() for p in model.parameters()):,}")

    print("\nCollecting M tensor (512 gaussian samples)...")
    all_M = collect_per_sample_M(model, cfg)
    M_avg = all_M.mean(axis=0)  # average over the sample axis
    print(f" M_avg shape: {M_avg.shape}")

    print("\nIdentifying antipodal pairs (cos < -0.9)...")
    pairs, unpaired = identify_antipodal_pairs(M_avg, threshold=-0.9)
    print(f" Found {len(pairs)} antipodal pairs")
    print(f" Unpaired rows: {len(unpaired)}")
    print(f" Total accounted: {2*len(pairs) + len(unpaired)} of {M_avg.shape[0]}")

    print("\nCollapsing to projective axes...")
    axes = collapse_to_axes(M_avg, pairs, unpaired)
    print(f" Axes: {axes.shape[0]} representatives in {axes.shape[1]}-D")

    # ── Run probe metrics under projective interpretation ──
    results = test_axis_distribution(axes, "G-Cand projective axes")

    # ── Save ──
    output_data = {
        'config': {
            'variant': 'Q_rank09_h64_V32_D3_dp0_nx0_adam',
            'V': cfg.matrix_v,
            'D': cfg.D,
        },
        'antipodal_pairs_found': len(pairs),
        'unpaired_rows': len(unpaired),
        'total_axes': axes.shape[0],
        'projective_metrics': results,
        'pairs': [list(p) for p in pairs],
        'unpaired': unpaired,
    }

    with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2, default=_json_default)
    print(f"\nSaved: {OUTPUT_JSON}")

    plot_projective(M_avg, axes, pairs, unpaired, results, OUTPUT_PLOT)
    print(f"Saved: {OUTPUT_PLOT}")

    # ── Headline conclusion ──
    print("\n" + "=" * 70)
    print("CONCLUSION")
    print("=" * 70)

    is_uniform = abs(results['deviation_from_uniform']) < 0.05
    is_clustered = (results['best_silhouette'] or 0) > 0.5
    has_secondary_antipodal = results['secondary_antipodal_pairs'] >= 3
    full_rank = results['utilization'] > 0.95
    n_axes = results['n_axes']  # report the measured count, not a constant

    if is_uniform and not is_clustered and not has_secondary_antipodal and full_rank:
        print("\n✓ CLAIM 3 SUPPORTED:")
        print(f" The {n_axes} axes are uniformly distributed on ℝP² with no")
        print(f" further collapse. G-Cand is a {n_axes}-axis projective solver.")
        print(f" The 'sphere-norm V=32 D=3' was a mislabeling of {n_axes} axes.\n")
        print(" IMPLICATION: For inference, project trained sphere outputs")
        print(" to ℝP^(D-1) and read as axes, not points. The polygonal")
        print(" geometry is implicit in the trained sphere-solver.")
    elif is_clustered:
        print("\n✗ CLAIM 3 PARTIALLY REFUTED:")
        print(" Axes have cluster structure on ℝP² — they are not")
        print(" uniformly distributed. Either the projective space isn't")
        print(" the right reading, or the clusters reveal a finer polytope")
        print(" structure (e.g., axes prefer specific directions on ℝP²).")
    elif has_secondary_antipodal:
        print("\n✗ CLAIM 3 REFUTED — geometry collapses further:")
        print(" Axes show NEW antipodal pairs after the first collapse.")
        print(" G-Cand has more degenerate geometry than ℝP² — possibly")
        print(" effective dimension < 3.")
    elif not full_rank:
        print("\n✗ CLAIM 3 REFUTED — dimension collapse:")
        print(" Effective rank of the axes is below 3. The trained model")
        print(" used less than the full D=3 space.")
    else:
        print("\n? CLAIM 3 PARTIALLY SUPPORTED:")
        print(" Axes are full-rank and don't show secondary collapse,")
        print(" but distribution deviates from uniform ℝP² baseline.")
        print(" Some structure beyond simple uniform projection.")

    return output_data
639
+
640
+
641
# Script entry point; the returned report is kept in `results` so it stays
# inspectable when the file is run interactively.
if __name__ == '__main__':
    results = main()