AbstractPhil commited on
Commit
b37f720
·
verified ·
1 Parent(s): 15c3a5a

Create 10_run_finetune.py

Browse files
Files changed (1) hide show
  1. 10_run_finetune.py +339 -0
10_run_finetune.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ cell_r_runner.py — Phase R: sphere-packing prediction test
3
+
4
+ Trains 3 configs whose (V, D) match natural sphere polytopes:
5
+ D=4, V=16: 8-cell (tesseract) vertices on S³
6
+ D=4, V=8: 16-cell vertices on S³
7
+ D=3, V=20: dodecahedron vertices on S²
8
+
9
+ Hypothesis: each will produce H2-LIKE rows (high stability, low antipodal
10
+ pairs, full rank utilization) because V points uniformly fit S^(D-1) for
11
+ these counts. The G-Class behavior at (V=32, D=3) was geometric frustration
12
+ — natural V's should reproduce H2 sphere-solver character.
13
+
14
+ After training, immediately runs the v3 probe metrics on each model:
15
+ - per-sample sphere-norm
16
+ - row stability across 512 gaussian inputs
17
+ - antipodal pair fraction
18
+ - per-sample silhouette
19
+ - effective rank
20
+ - pairwise angle distribution
21
+
22
+ Outputs:
23
+ /content/phaseR_reports/results_phaseR.json — training results + probes
24
+ /content/phaseR_reports/phaseR_summary.png — H2-LIKE / G-LIKE verdicts
25
+ """
26
+
27
+ import json
28
+ import math
29
+ import time
30
+ import traceback
31
+ from pathlib import Path
32
+
33
+ import numpy as np
34
+ import torch
35
+ import torch.nn.functional as F
36
+ import matplotlib.pyplot as plt
37
+ from sklearn.cluster import KMeans
38
+ from sklearn.metrics import silhouette_score
39
+
40
+
41
# Directory that receives every Phase R artifact. It is created eagerly at
# import time so later writes never have to check for its existence.
OUTPUT_ROOT = Path("/content/phaseR_reports")
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

# Aggregated per-config training reports and probe metrics (JSON).
AGGREGATE_PATH = OUTPUT_ROOT.joinpath("results_phaseR.json")
# Destination path for the verdict summary figure.
SUMMARY_PLOT = OUTPUT_ROOT.joinpath("phaseR_summary.png")
45
+
46
+
47
+ # ════════════════════════════════════════════════════════════════════
48
+ # Geometric probe (compact version of v3)
49
+ # ════════════════════════════════════════════════════════════════════
50
+
51
def collect_M(model, cfg, n_batches=8, batch_size=64):
    """Run the model on noise images and gather the patch-0 ``M`` matrices.

    Builds an ``OmegaNoiseDataset`` holding ``n_batches * batch_size`` samples
    (restricted to ``allowed_types=[0]``), feeds it through ``model`` without
    gradient tracking, and keeps ``out['svd']['M'][:, 0]`` from every batch.

    Args:
        model: Module whose forward pass returns a mapping that exposes
            ``out['svd']['M']``. It is moved to CUDA when available.
        cfg: Run configuration; only ``cfg.img_size`` is read here.
        n_batches: Number of batches to draw.
        batch_size: Number of samples per batch.

    Returns:
        A NumPy array made by concatenating the per-batch slices along axis 0.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = model.to(device)

    total_samples = n_batches * batch_size
    dataset = OmegaNoiseDataset(
        size=total_samples, img_size=cfg.img_size, allowed_types=[0])
    # shuffle=False keeps the sample order fixed from call to call.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        # Index 0 on axis 1 presumably selects the first patch — TODO confirm
        # against the model's output layout.
        collected = [
            model(batch.to(device))['svd']['M'][:, 0].cpu()
            for batch, _ in loader
        ]
    return torch.cat(collected, dim=0).numpy()
67
+
68
+
69
def probe_geometry(all_M):
    """Return all v3 probe metrics in one dict.

    Args:
        all_M: Array-like of shape ``(n_samples, V, D)``: one ``V x D``
            matrix of row vectors per input sample.

    Returns:
        A dict of plain Python scalars with the keys ``sphere_normed``,
        ``row_norm_mean``, ``stability_mean/min/max``,
        ``silhouette_mean/std`` (``None`` when no sample could be scored),
        ``angular_mean``, ``angular_near_pi``, ``angular_near_perp``,
        ``antipodal_frac``, ``antipodal_pairs``, ``antipodal_max_pairs``,
        ``effective_rank``, ``D`` and ``utilization``.

    Raises:
        ValueError: If ``all_M`` is not 3-dimensional or has an empty axis.
    """
    all_M = np.asarray(all_M)
    if all_M.ndim != 3:
        raise ValueError(
            f"all_M must have shape (n_samples, V, D); got ndim={all_M.ndim}")
    if 0 in all_M.shape:
        raise ValueError(
            f"all_M must be non-empty along every axis; got shape {all_M.shape}")
    n_samples, V, D = all_M.shape

    # sphere-norm: rows count as normalised when their lengths sit tightly
    # around 1.
    row_norms = np.linalg.norm(all_M, axis=2)
    row_norm_mean = row_norms.mean()
    sphere_normed = abs(row_norm_mean - 1.0) < 0.05 and row_norms.std() < 0.05

    # row stability: length of each row's mean over samples. The mean matrix
    # is reused below for the antipodal and effective-rank metrics.
    mean_dirs = all_M.mean(axis=0)
    mean_dir_norms = np.linalg.norm(mean_dirs, axis=1)

    # per-sample silhouette (k=5 if V>=10 else k=V//2, never below 2),
    # evaluated on at most the first 20 samples.
    k_test = min(5, max(2, V // 2))
    sils = []
    for sample in all_M[:20]:
        try:
            km = KMeans(n_clusters=k_test, n_init=10, random_state=42)
            labels = km.fit_predict(sample)
            if len(set(labels)) >= 2:
                sils.append(silhouette_score(sample, labels))
        except ValueError:
            # scikit-learn rejects degenerate cases (fewer rows than
            # clusters, too few/many distinct labels) with ValueError; such
            # samples are skipped. Anything else is a real bug and propagates.
            continue
    sils = np.array(sils)

    # angular: pairwise angles inside a fixed-seed random subset of all rows.
    all_rows = all_M.reshape(-1, D)
    norms = np.linalg.norm(all_rows, axis=1, keepdims=True)
    unit_rows = all_rows / np.clip(norms, 1e-12, None)
    n_subset = min(500, unit_rows.shape[0])
    idx = np.random.RandomState(42).choice(
        unit_rows.shape[0], n_subset, replace=False)
    cosines = unit_rows[idx] @ unit_rows[idx].T
    pairwise_angles = np.arccos(
        np.clip(cosines[np.triu_indices(n_subset, k=1)], -1, 1))

    # antipodal: for every mean direction, the most negative cosine to any
    # other mean direction.
    unit_dirs = mean_dirs / np.clip(
        np.linalg.norm(mean_dirs, axis=1, keepdims=True), 1e-12, None)
    cos_mat = unit_dirs @ unit_dirs.T
    # Pin the diagonal to +1 so a row's self-similarity never wins the min.
    np.fill_diagonal(cos_mat, 1.0)
    most_anti = cos_mat.min(axis=1)
    has_antipode = most_anti < -0.9

    # effective rank: exp(entropy) of the normalised singular values of the
    # mean matrix. An all-zero matrix has no spectrum to normalise; report
    # rank 0 instead of letting 0/0 produce NaN.
    sv = np.linalg.svd(mean_dirs, compute_uv=False)
    sv_total = sv.sum()
    if sv_total > 0:
        sv_norm = sv / sv_total
        erank = math.exp(-(sv_norm * np.log(sv_norm + 1e-12)).sum())
    else:
        erank = 0.0

    near_perp = ((pairwise_angles > math.pi/2 - 0.3) &
                 (pairwise_angles < math.pi/2 + 0.3))

    return {
        'sphere_normed': bool(sphere_normed),
        'row_norm_mean': float(row_norm_mean),
        'stability_mean': float(mean_dir_norms.mean()),
        'stability_min': float(mean_dir_norms.min()),
        'stability_max': float(mean_dir_norms.max()),
        'silhouette_mean': float(sils.mean()) if len(sils) else None,
        'silhouette_std': float(sils.std()) if len(sils) else None,
        'angular_mean': float(pairwise_angles.mean()),
        'angular_near_pi': float((pairwise_angles > math.pi - 0.5).mean()),
        'angular_near_perp': float(near_perp.mean()),
        'antipodal_frac': float(has_antipode.mean()),
        'antipodal_pairs': int(has_antipode.sum() // 2),
        'antipodal_max_pairs': int(V // 2),
        'effective_rank': float(erank),
        'D': int(D),
        'utilization': float(erank / D),
    }
136
+
137
+
138
def classify_character(probe, *, stable_min=0.85, unstable_max=0.65,
                       antipodal_low=0.55, antipodal_high=0.80,
                       utilization_min=0.95):
    """H2-LIKE / G-LIKE / DIFFUSE / HYBRID — same logic as v3.

    Args:
        probe: Mapping with the keys ``stability_mean``, ``antipodal_frac``
            and ``utilization``.
        stable_min: ``stability_mean`` must exceed this for H2-LIKE.
        unstable_max: ``stability_mean`` must be below this for G-LIKE and
            DIFFUSE.
        antipodal_low: ``antipodal_frac`` must be below this for H2-LIKE and
            DIFFUSE.
        antipodal_high: ``antipodal_frac`` must exceed this for G-LIKE.
        utilization_min: ``utilization`` must exceed this for H2-LIKE.

    Returns:
        One of ``'H2-LIKE'``, ``'G-LIKE'``, ``'DIFFUSE'`` or ``'HYBRID'``.
        ``'HYBRID'`` is the fallback when no other rule matches, which
        includes NaN metrics because every comparison with NaN is false.

    Raises:
        KeyError: If ``probe`` lacks one of the three required keys.
    """
    stab = probe['stability_mean']
    anti = probe['antipodal_frac']
    util = probe['utilization']

    if stab > stable_min and anti < antipodal_low and util > utilization_min:
        return 'H2-LIKE'
    if stab < unstable_max:
        # Unstable rows split on how many have an antipodal partner.
        if anti > antipodal_high:
            return 'G-LIKE'
        if anti < antipodal_low:
            return 'DIFFUSE'
    return 'HYBRID'
151
+
152
+
153
+ # ════════════════════════════════════════════════════════════════════
154
+ # Build trained model from a Q-style report
155
+ # ════════════════════════════════════════════════════════════════════
156
+
157
def build_model_from_config(ablation_config):
    """Instantiate the model architecture for an ablation config.

    No weights are loaded here; after training they are restored from a
    checkpoint by the caller.

    Args:
        ablation_config: Mapping accepted by ``build_run_config``; it must
            also contain an ``'overrides'`` mapping.

    Returns:
        ``(model, cfg)``: the freshly constructed ``PatchSVAE_F_Ablation``
        and the run configuration produced by ``build_run_config``.
    """
    run_cfg = build_run_config(ablation_config)
    ov = ablation_config['overrides']

    # Structural hyper-parameters come from the run config. The ablation
    # switches come from the overrides, each with its own fallback.
    model_kwargs = {
        'matrix_v': run_cfg.matrix_v,
        'D': run_cfg.D,
        'patch_size': run_cfg.patch_size,
        'hidden': run_cfg.hidden,
        'depth': run_cfg.depth,
        'n_cross_layers': run_cfg.n_cross_layers,
        'n_heads': run_cfg.n_heads,
        'max_alpha': ov.get('max_alpha', run_cfg.max_alpha),
        'alpha_init': run_cfg.alpha_init,
        'activation': ov.get('activation', 'gelu'),
        'row_norm': ov.get('row_norm', 'sphere'),
        'svd_mode': ov.get('svd', 'fp64'),
        'linear_readout': ov.get('linear_readout', False),
        'match_params': ov.get('match_params', True),
        'init_scheme': ov.get('init', 'orthogonal'),
    }
    return PatchSVAE_F_Ablation(**model_kwargs), run_cfg
176
+
177
+
178
def load_trained(ablation_config, output_dir, epoch=1):
    """Load the trained model's weights from its epoch checkpoint.

    Args:
        ablation_config: Config passed on to ``build_model_from_config``.
        output_dir: Directory holding the ``epoch_<n>_checkpoint.pt`` files.
        epoch: Which epoch's checkpoint to load. Defaults to 1, the
            checkpoint this function has always loaded.

    Returns:
        ``(model, cfg)`` with the weights loaded and the model in eval mode.

    Raises:
        FileNotFoundError: If the requested checkpoint file does not exist.
    """
    model, cfg = build_model_from_config(ablation_config)
    ckpt_path = Path(output_dir) / f"epoch_{epoch}_checkpoint.pt"
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"No checkpoint found at {ckpt_path}")

    # NOTE(security): weights_only=False unpickles arbitrary objects. Only
    # load checkpoints produced by this pipeline, never untrusted files.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)

    # The weights may sit under one of several keys. Otherwise the checkpoint
    # object itself is treated as the state dict.
    state_dict = ckpt
    if isinstance(ckpt, dict):
        for key in ('model_state', 'model_state_dict', 'state_dict'):
            candidate = ckpt.get(key)
            if candidate:
                state_dict = candidate
                break

    model.load_state_dict(state_dict)
    model.eval()
    return model, cfg
192
+
193
+
194
+ # ════════════════════════════════════════════════════════════════════
195
+ # Main
196
+ # ════════════════════════════════════════════════════════════════════
197
+
198
def _gaussian_mse(report):
    """Return the test MSE stored for noise type 0, or NaN when it is absent."""
    per_noise = report.get('test_mse_per_noise') or {}
    # The key is an int in a fresh report and a str after a JSON round-trip.
    for key in (0, '0'):
        value = per_noise.get(key)
        if value is not None:
            return float(value)
    return float('nan')


def _probe_trained_config(cfg, config_output_dir):
    """Load the trained checkpoint for ``cfg`` and run the geometry probe.

    Never raises: a failure is returned as ``{'error': ...}`` so one broken
    probe cannot abort the sweep.
    """
    print(" probe: collecting M rows + running v3 metrics...", end=' ', flush=True)
    t1 = time.time()
    try:
        model, run_cfg = load_trained(cfg, config_output_dir)
        all_M = collect_M(model, run_cfg)
        probe = probe_geometry(all_M)
        probe['M_shape'] = list(all_M.shape)
        probe['character'] = classify_character(probe)
    except Exception as e:
        print(f"FAILED: {type(e).__name__}: {str(e)[:80]}")
        return {'error': f'{type(e).__name__}: {str(e)[:300]}'}

    print(f"{time.time()-t1:.0f}s → {probe['character']}")
    print(f" stability={probe['stability_mean']:.3f}, "
          f"antipodal={probe['antipodal_pairs']}/"
          f"{probe['antipodal_max_pairs']}, "
          f"utilization={probe['utilization']*100:.0f}%")
    return probe


def run_sweep_with_probes():
    """Train every Phase R config, probe each trained model, print a verdict.

    For each config from ``get_phaseR_configs()`` this trains via
    ``run_ablation_config``, probes the resulting checkpoint, and rewrites
    ``AGGREGATE_PATH`` with all reports collected so far. A failure in one
    config is recorded in its report and the sweep continues.

    Returns:
        The list of per-config report dicts, in config order.
    """
    configs = get_phaseR_configs()
    print(f"Phase R: {len(configs)} packed-polytope test configs")
    print(f"Output: {OUTPUT_ROOT}\n")

    print("Predicted: each config produces H2-LIKE static rows because")
    print("(V, D) matches a natural sphere polytope vertex count.\n")

    print("Config lineup:")
    for cfg in configs:
        ov = cfg['overrides']
        print(f" {cfg['variant']:<45} V={ov['V']} D={ov['D']}")
    print()

    results = []
    sweep_t0 = time.time()

    for i, cfg in enumerate(configs):
        print(f"[{i+1}/{len(configs)}] {cfg['variant']}")
        config_output_dir = OUTPUT_ROOT / cfg['variant']
        config_output_dir.mkdir(exist_ok=True)

        # ── Train ──
        t0 = time.time()
        try:
            report = run_ablation_config(
                ablation_config=cfg,
                output_dir=str(config_output_dir),
                batch_limit=phase2_batch_limit(cfg),
                num_epochs=cfg.get('num_epochs', 1),
            )
            report['_sweep_status'] = 'ok'
            train_time = time.time() - t0

            # A missing metric must not raise while formatting, or the outer
            # handler would replace a successful training report with an
            # error stub and skip the probe.
            g_mse = _gaussian_mse(report)
            cv = report.get('observed_sphere_cv')
            if cv is None:
                cv = 0.0
            print(f" train: {train_time:.0f}s, "
                  f"G-MSE={g_mse:.5f}, CV={cv:.3f}")

            # ── Probe geometry ──
            report['probe'] = _probe_trained_config(cfg, config_output_dir)

        except Exception as e:
            report = {
                '_sweep_status': f'error: {type(e).__name__}: {str(e)[:300]}',
                '_traceback': traceback.format_exc()[:2000],
                'config': cfg,
                'variant': cfg['variant'],
            }
            print(f" ERROR: {type(e).__name__}: {str(e)[:80]}")

        report['variant'] = cfg['variant']
        report['wallclock_outer_s'] = time.time() - t0
        results.append(report)

        # Rewrite the aggregate after every config so partial results
        # survive an interrupted sweep.
        with open(AGGREGATE_PATH, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=str)
        print()

    # ── Verdict summary ──
    print("=" * 70)
    print("PHASE R RESULTS — sphere-packing hypothesis test")
    print("=" * 70)

    print(f"\n{'Variant':<45} {'G-MSE':>9} {'Char':>10} {'Stab':>6} {'Anti':>10}")
    print("-" * 85)

    n_h2like = 0
    n_glike = 0
    n_probed = 0
    for r in results:
        v = r.get('variant', '?')
        if str(r.get('_sweep_status', 'ok')).startswith('error'):
            # Training itself failed, so there is nothing to classify.
            print(f"{v[:45]:<45} {'N/A':>9} {'TRAIN_ERR':>10}")
            continue
        probe = r.get('probe', {})
        if 'error' in probe:
            print(f"{v[:45]:<45} {'N/A':>9} {'PROBE_ERR':>10}")
            continue
        n_probed += 1
        g_mse = _gaussian_mse(r)
        char = probe.get('character', '?')
        stab = probe.get('stability_mean', 0)
        ap_pairs = probe.get('antipodal_pairs', 0)
        ap_max = probe.get('antipodal_max_pairs', 0)
        anti_txt = f"{ap_pairs}/{ap_max}"
        print(f"{v[:45]:<45} {g_mse:>9.5f} {char:>10} {stab:>6.3f} "
              f"{anti_txt:>10}")
        if char == 'H2-LIKE':
            n_h2like += 1
        elif char == 'G-LIKE':
            n_glike += 1

    print(f"\n H2-LIKE: {n_h2like}/{len(results)}")
    print(f" G-LIKE: {n_glike}/{len(results)}")

    print("\n" + "=" * 70)
    print("INTERPRETATION")
    print("=" * 70)

    # Only successfully probed configs are evidence. Crashed runs neither
    # confirm nor falsify the hypothesis.
    n_failed = len(results) - n_probed
    if n_probed == 0:
        print(" No configs were successfully trained and probed.")
        print(" → Sphere-packing hypothesis test INCONCLUSIVE.")
    elif n_h2like == n_probed:
        print(f" All {n_probed} packed-polytope configs produced H2-LIKE batteries.")
        print(" → Sphere-packing hypothesis CONFIRMED.")
        print(" → G-Class is a SYMPTOM of (V, D) geometric frustration,")
        print(" not a battery family in its own right.")
        print(" → Useful (V, D) pairs follow polytope vertex counts:")
        print(" D=3: 4, 6, 8, 12, 20 (Platonic)")
        print(" D=4: 5, 8, 16, 24, 120, 600 (4D regular polytopes)")
        print(" D≥5: most V's work (high-D sphere-packing flexible)")
    elif n_h2like > 0:
        print(f" Mixed: {n_h2like}/{n_probed} produced H2-LIKE.")
        print(" → Hypothesis partially supported but more nuanced.")
        print(" → Some packed-polytope V's work, others don't.")
    else:
        print(" No H2-LIKE batteries produced.")
        print(" → Sphere-packing hypothesis FALSIFIED.")
        print(" → G-Class behavior has a different cause.")
    if n_failed and n_probed:
        print(f" ({n_failed} config(s) failed and were excluded from the verdict.)")

    total = time.time() - sweep_t0
    print(f"\nTotal time: {total/60:.1f} min")
    print(f"Aggregate: {AGGREGATE_PATH}")

    return results
336
+
337
+
338
if __name__ == '__main__':
    # Script entry point. The reports are bound to a module-level name so
    # they stay available for inspection after the sweep finishes.
    results = run_sweep_with_probes()