# ============================================================================ # RAPID PROTOTYPE v2: Differentiation-Centered Alignment Bank # # The bank aligns to the DIFFERENTIATION between experts, not to any # arbitrary target. The consensus CV, spectral profile, and pairwise # statistics measured during alignment become the exact targets. # # The bank embodies the centerpoint of expert disagreement. # ============================================================================
import gc
import math
import os
import time
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# (hf_model_name, short_name, tokenizer max_length) per expert encoder.
EXPERTS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 512),
]

print("=" * 65)
print("RAPID PROTOTYPE v2: Differentiation-Centered Bank")
print("=" * 65)
print(f"  Device: {DEVICE}")


# ══════════════════════════════════════════════════════════════════
# STUDENT MODEL
# ══════════════════════════════════════════════════════════════════
class MiniStudent(nn.Module):
    """Small pre-norm transformer encoder producing L2-normalized sentence
    embeddings of size ``output_dim`` via mean pooling over valid tokens.

    Trained (in ``run``) to regress onto the consensus embedding space.
    """

    def __init__(self, vocab_size=30522, max_len=512, d_model=256, n_heads=4,
                 n_layers=4, d_ff=1024, output_dim=768, dropout=0.1,
                 pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)  # learned absolute positions
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers, enable_nested_tensor=False)
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, output_dim))

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` (B, L) into unit-norm embeddings (B, output_dim).

        When ``attention_mask`` is None, padding positions are inferred from
        ``pad_token_id`` instead.
        """
        B, L = input_ids.shape
        positions = torch.arange(L, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.emb_drop(self.emb_norm(x))
        # key_padding_mask: True marks positions the encoder must ignore.
        kpm = ~attention_mask.bool() if attention_mask is not None else (input_ids == self.pad_token_id)
        x = self.encoder(x, src_key_padding_mask=kpm)
        # Masked mean pooling over valid tokens only.
        mask = attention_mask.unsqueeze(-1).float() if attention_mask is not None else (~kpm).unsqueeze(-1).float()
        pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(pooled), dim=-1)


# ══════════════════════════════════════════════════════════════════
# ALIGNMENT BANK
# ══════════════════════════════════════════════════════════════════
class AlignmentBank(nn.Module):
    """ Differentiation-centered geometric interface. Aligns to the CENTERPOINT between experts — the consensus itself. Stores per-expert rotation matrices (the differentiation structure) and learned anchor landmarks (the consensus manifold topology). The bank doesn't invent geometry. It mirrors the measured consensus. Every loss term pulls toward measured consensus statistics. """

    def __init__(self, d_embed=768, n_experts=2, n_anchors=512, d_bank=128):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank
        # Per-expert rotation matrices (differentiation structure)
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])
        # Per-expert whiteners (captures variance structure per expert)
        self.expert_whiteners = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])
        # Per-expert means (centering offset per expert)
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)
        ])
        # Anchor bank: consensus landmarks on the hypersphere
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(n_anchors, d_embed), dim=-1))
        # Project: expert_cos (n) + expert_mse (n) + cross (n*(n-1)/2) +
        # disagreement_ratio (1) + norm_ratio (n) + anchor_cos (n_anchors)
        n_cross = n_experts * (n_experts - 1) // 2
        geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 2),
            nn.GELU(),
            nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank),
            nn.LayerNorm(d_bank),
        )
        # Consensus statistics (set during init, used as exact targets)
        self.register_buffer("target_cv", torch.tensor(0.12))
        self.register_buffer("target_mean_cos", torch.tensor(0.0))
        self.register_buffer("target_spectral", torch.zeros(50))
        # Disagreement structure (measured once, preserved forever)
        self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
        self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
        self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))

    def init_from_procrustes(self, procrustes_results, expert_names,
                             consensus_embeddings=None, consensus_stats=None):
        """Initialize from consensus training artifacts.

        procrustes_results: dict name -> output of ``procrustes_align``.
        consensus_embeddings: optional (N, d_embed) tensor used to seed anchors.
        consensus_stats: optional dict from ``measure_consensus_stats`` whose
        values become the bank's fixed geometric targets.
        """
        device = self.anchors.device
        for i, name in enumerate(expert_names[:self.n_experts]):
            info = procrustes_results[name]
            self.expert_rotations[i].data = info["rotation"].float().to(device)
            if "source_whitener" in info:
                self.expert_whiteners[i].data = info["source_whitener"].float().to(device)
            if "source_mean" in info:
                self.expert_means[i].data = info["source_mean"].float().to(device)
            print(f"  Expert {i} ({name}): rotation + whitener + mean loaded, "
                  f"cos_after={info['cos_after']:.4f}")
        if consensus_embeddings is not None:
            # Seed anchors with an evenly-strided subsample of the consensus.
            n = min(self.n_anchors, consensus_embeddings.shape[0])
            indices = torch.linspace(0, consensus_embeddings.shape[0] - 1, n).long()
            self.anchors.data[:n] = F.normalize(
                consensus_embeddings[indices].float(), dim=-1).to(device)
            print(f"  Anchors: {n} initialized from consensus embeddings")
        if consensus_stats is not None:
            self.target_cv.fill_(consensus_stats["cv"])
            self.target_mean_cos.fill_(consensus_stats["mean_cos"])
            if "spectral" in consensus_stats:
                s = torch.tensor(consensus_stats["spectral"][:50],
                                 dtype=torch.float32)
                self.target_spectral[:len(s)] = s.to(device)
            print(f"  Targets: CV={consensus_stats['cv']:.4f}, "
                  f"mean_cos={consensus_stats['mean_cos']:.4f}")

    def _cv_estimate(self, points, n_samples=32, simplex=5):
        """Monte-Carlo coefficient of variation of random-simplex volumes.

        Normalizes ``points`` (B, d), draws ``n_samples`` random 5-point
        subsets, and returns std/mean of their Cayley–Menger volumes.
        Returns 0 for batches too small to sample from (B < 10).

        Factored out of ``forward``, which previously carried two identical
        copies of this computation (for geo_context and for the raw embedding).
        """
        B = points.shape[0]
        if B < 10:
            return torch.tensor(0.0, device=points.device)
        pts_n = F.normalize(points, dim=-1)
        vols = []
        for _ in range(n_samples):
            idx = torch.randperm(B, device=points.device)[:simplex]
            v2 = cayley_menger_vol2(pts_n[idx].unsqueeze(0))
            vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
        stacked = torch.stack(vols)
        return stacked.std() / (stacked.mean() + 1e-8)

    def forward(self, embedding):
        """Enrich ``embedding`` (B, d_embed) with a geometric context vector.

        Returns (enriched, aux): ``enriched`` is (B, d_embed + d_bank) — the
        original embedding concatenated with the d_bank geo context — and
        ``aux`` holds loss terms (tensors) and diagnostics (floats).
        """
        B = embedding.shape[0]
        emb = embedding.float()
        # ── Per-expert projections (full whitened Procrustes) ──
        # Chain: center → whiten → normalize → rotate
        # This is EXACTLY what was computed during alignment.
        # The rotation only makes geometric sense in whitened-normalized space.
        expert_consistency = []
        expert_recon = []
        expert_projected = []
        expert_norms = []
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]
            # Forward: center → whiten → normalize → rotate
            centered = emb - mu
            whitened = centered @ W
            whitened_n = F.normalize(whitened, dim=-1)
            in_expert = whitened_n @ R.T  # now in expert's whitened-normalized space
            # Round-trip: rotate back (orthogonal, so R.T inverse = R)
            back = in_expert @ R
            # Consistency: round-trip should recover whitened_n exactly
            cos = F.cosine_similarity(whitened_n, back, dim=-1)
            recon = (whitened_n - back).pow(2).mean(dim=-1)
            expert_consistency.append(cos)
            expert_recon.append(recon)
            expert_projected.append(in_expert)
            # Pre-normalization whitened norm (magnitude structure); collected
            # here rather than in a second loop that recomputed centered @ W.
            expert_norms.append(whitened.norm(dim=-1))  # (B,)
        expert_cos = torch.stack(expert_consistency, dim=-1)  # (B, n_experts)
        expert_mse = torch.stack(expert_recon, dim=-1)        # (B, n_experts)
        # ── Cross-expert differentiation ──
        # How each expert's projection relates to every other expert's projection
        # This IS the disagreement structure. Preserve it exactly.
        cross_cos = []
        for i in range(self.n_experts):
            for j in range(i + 1, self.n_experts):
                cc = F.cosine_similarity(
                    expert_projected[i], expert_projected[j], dim=-1)
                cross_cos.append(cc)
        cross_features = torch.stack(cross_cos, dim=-1) if cross_cos else torch.zeros(B, 0, device=emb.device)
        # Per-sample disagreement: how much do experts disagree on THIS embedding?
        # High disagreement = embedding is in contested territory
        # Low disagreement = all experts agree (well-anchored)
        per_sample_agreement = expert_cos.mean(dim=-1)     # (B,) mean round-trip cos
        per_sample_disagreement = expert_cos.std(dim=-1)   # (B,) std across experts
        # Ratio: how much agreement relative to disagreement
        disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)  # (B,)
        expert_norm_features = torch.stack(expert_norms, dim=-1)  # (B, n_experts)
        norm_ratio = expert_norm_features / (expert_norm_features.mean(dim=-1, keepdim=True) + 1e-8)
        # ── Anchor distances ──
        anchors_n = F.normalize(self.anchors, dim=-1)
        anchor_cos = emb @ anchors_n.T  # (B, n_anchors)
        # ── Geometric context ──
        # Full feature set: expert consistency + reconstruction + cross-expert +
        # disagreement ratio + norm ratios + anchor distances
        geo_input = torch.cat([
            expert_cos,                         # (B, n_experts)
            expert_mse,                         # (B, n_experts)
            cross_features,                     # (B, n_cross)
            disagreement_ratio.unsqueeze(-1),   # (B, 1)
            norm_ratio,                         # (B, n_experts)
            anchor_cos,                         # (B, n_anchors)
        ], dim=-1)
        geo_context = self.geo_proj(geo_input)
        enriched = torch.cat([embedding, geo_context], dim=-1)
        # ── Losses + Diagnostics ──
        aux = {}
        # 1. Expert agreement: all experts should see embedding equally
        expert_mean = expert_cos.mean(dim=-1, keepdim=True)
        aux["expert_agreement"] = (expert_cos - expert_mean).pow(2).mean()
        # 2. Rotation orthogonality
        ortho_loss = 0.0
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            RRT = R @ R.T
            ortho_loss += (RRT - torch.eye(self.d_embed, device=R.device)).pow(2).mean()
        aux["rotation_ortho"] = ortho_loss / self.n_experts
        # 3. Anchor spread: penalize anchors collapsing onto one another
        anchor_sim = anchors_n @ anchors_n.T
        anchor_sim.fill_diagonal_(0)
        aux["anchor_spread"] = anchor_sim.pow(2).mean()
        # 4. Anchor sharpness: low entropy = each sample binds to few anchors
        anchor_probs = F.softmax(anchor_cos * 10, dim=-1)
        entropy = -(anchor_probs * (anchor_probs + 1e-12).log()).sum(-1).mean()
        aux["anchor_entropy"] = entropy
        # 5. Cross-expert differentiation consistency
        if cross_features.shape[1] > 0:
            aux["cross_expert_var"] = cross_features.var(dim=0).mean()
        else:
            aux["cross_expert_var"] = torch.tensor(0.0, device=emb.device)
        # 6. Disagreement preservation
        # The distribution of disagreement should stay at the measured target
        batch_cross_mean = cross_features.mean() if cross_features.shape[1] > 0 else torch.tensor(0.0, device=emb.device)
        batch_cross_std = cross_features.std() if cross_features.shape[1] > 0 else torch.tensor(0.0, device=emb.device)
        batch_disagree_ratio = disagreement_ratio.mean()
        aux["disagree_preserve"] = (
            (batch_cross_mean - self.target_cross_cos_mean).pow(2) +
            (batch_cross_std - self.target_cross_cos_std).pow(2) +
            (batch_disagree_ratio - self.target_disagreement_ratio).pow(2)
        )
        # 7. Bank CV  /  8. Emb CV — same sampled-simplex estimator on the
        # geo context and on the raw embedding (shared helper, was duplicated).
        aux["bank_cv"] = self._cv_estimate(geo_context)
        aux["emb_cv"] = self._cv_estimate(emb)
        # Diagnostics (plain floats, never back-propagated)
        aux["expert_cos_mean"] = expert_cos.mean().item()
        aux["expert_cos_std"] = expert_cos.std().item()
        aux["anchor_max_cos"] = anchor_cos.max(dim=-1).values.mean().item()
        aux["anchor_mean_cos"] = anchor_cos.mean().item()
        if cross_features.shape[1] > 0:
            aux["cross_expert_cos"] = cross_features.mean().item()
            aux["cross_expert_cos_std"] = cross_features.std().item()
        aux["disagreement_ratio"] = disagreement_ratio.mean().item()
        aux["norm_ratio_spread"] = norm_ratio.std(dim=-1).mean().item()
        return enriched, aux

    def bank_loss(self, aux):
        """All targets from measured consensus. Preserves disagreement structure."""
        loss = (
            1.0 * aux["expert_agreement"]
            + 1.0 * aux["rotation_ortho"]
            + 0.5 * aux["anchor_spread"]
            + 0.1 * aux["anchor_entropy"]
            + 0.3 * aux["cross_expert_var"]
            + 0.3 * (aux["bank_cv"] - self.target_cv).abs()
            + 0.3 * (aux["emb_cv"] - self.target_cv).abs()
            + 0.5 * aux["disagree_preserve"]  # preserve the disagreement distribution
        )
        return loss

    @torch.no_grad()
    def calibrate_disagreement(self, embeddings):
        """ Measure the initial disagreement structure and store as targets. Call ONCE after init, before training. """
        _, aux = self.forward(embeddings)
        if "cross_expert_cos" in aux:
            self.target_cross_cos_mean.fill_(aux["cross_expert_cos"])
        if "cross_expert_cos_std" in aux:
            self.target_cross_cos_std.fill_(aux["cross_expert_cos_std"])
        self.target_disagreement_ratio.fill_(aux["disagreement_ratio"])
        print(f"  Calibrated disagreement:")
        print(f"    cross_cos: {self.target_cross_cos_mean.item():.4f} ± {self.target_cross_cos_std.item():.4f}")
        print(f"    disagree_ratio: {self.target_disagreement_ratio.item():.6f}")


# ══════════════════════════════════════════════════════════════════
# GEOMETRY
# ══════════════════════════════════════════════════════════════════
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between row-aligned batches; returns (loss, top-1 acc)."""
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = (a @ b.T) / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc


def cayley_menger_vol2(pts):
    """Squared simplex volume via the Cayley–Menger determinant.

    pts: (B, V, d) — V vertices per batch item. Returns (B,) squared volumes
    (may be slightly negative from round-off; callers clamp with relu).
    """
    pts = pts.float()
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)
    B, V, _ = d2.shape
    cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
    s = (-1.0)**V; f = math.factorial(V-1)
    return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)


def cv_loss(emb, target=0.12, n_samples=16):
    """|CV - target| over ``n_samples`` random 5-point simplex volumes (differentiable)."""
    B = emb.shape[0]
    if B < 5:
        return torch.tensor(0.0, device=emb.device)
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(B, device=emb.device)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
    stacked = torch.stack(vols)
    cv = stacked.std() / (stacked.mean() + 1e-8)
    return (cv - target).abs()


def cv_metric(emb, n=200):
    """Non-differentiable CV estimate (float) over ``n`` sampled simplices.

    Zero-volume samples are discarded; returns 0.0 when fewer than 10 remain.
    """
    B = emb.shape[0]
    if B < 5:
        return 0.0
    vols = []
    for _ in range(n):
        idx = torch.randperm(B, device=emb.device)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
        if v > 0:
            vols.append(v)
    if len(vols) < 10:
        return 0.0
    a = np.array(vols)
    return float(a.std() / (a.mean() + 1e-8))


def measure_consensus_stats(consensus_embs, n_check=2000):
    """Measure exact geometric statistics of the consensus manifold."""
    embs = consensus_embs[:n_check].float()
    # CV
    cv = cv_metric(embs.to(DEVICE))
    # Pairwise cosine (off-diagonal only)
    sim = embs @ embs.T
    mask = ~torch.eye(embs.shape[0], dtype=torch.bool)
    pairwise = sim[mask]
    mean_cos = pairwise.mean().item()
    # Spectral profile of the centered embeddings
    centered = embs - embs.mean(0, keepdim=True)
    S = torch.linalg.svdvals(centered)
    S_norm = (S / (S.sum() + 1e-8)).tolist()[:50]
    # Effective dimensionality: (sum s)^2 / sum s^2
    eff_dim = float((S.sum() ** 2) / (S.pow(2).sum() + 1e-12))
    return {
        "cv": cv,
        "mean_cos": mean_cos,
        "spectral": S_norm,
        "eff_dim": eff_dim,
    }


# ══════════════════════════════════════════════════════════════════
# EXTRACTION + ALIGNMENT
# ══════════════════════════════════════════════════════════════════
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Symmetric inverse square root of a covariance matrix via eigh."""
    evals, evecs = torch.linalg.eigh(cov)
    evals = torch.clamp(evals, min=eps)
    return evecs @ torch.diag(evals.rsqrt()) @ evecs.T


def procrustes_align(source, target, n_align=5000):
    """Whitened orthogonal Procrustes: fit a rotation from source to target.

    Both sides are centered, whitened, and row-normalized before the SVD
    solve, so the returned rotation lives in whitened-normalized space.
    Returns the rotation plus the source mean/whitener and the pseudo-inverse
    of the target whitener (for mapping back out of whitened space).
    """
    N = min(n_align, source.shape[0], target.shape[0])
    S = source[:N].float(); T = target[:N].float()
    s_mean = S.mean(0, keepdim=True); t_mean = T.mean(0, keepdim=True)
    Sc = S - s_mean; Tc = T - t_mean; N_s = Sc.shape[0]
    cos_before = F.cosine_similarity(Sc, Tc, dim=-1).mean().item()
    s_cov = (Sc.T @ Sc) / max(N_s - 1, 1)
    t_cov = (Tc.T @ Tc) / max(N_s - 1, 1)
    s_whiten = symmetric_inv_sqrt(s_cov)
    t_whiten = symmetric_inv_sqrt(t_cov)
    Sc_w = F.normalize(Sc @ s_whiten, dim=-1)
    Tc_w = F.normalize(Tc @ t_whiten, dim=-1)
    U, _, Vt = torch.linalg.svd(Tc_w.T @ Sc_w, full_matrices=False)
    R = U @ Vt
    cos_after = F.cosine_similarity(Sc_w @ R.T, Tc_w, dim=-1).mean().item()
    return {
        "rotation": R,
        "source_mean": s_mean.squeeze(0),
        "source_whitener": s_whiten,
        "target_unwhitener": torch.linalg.pinv(t_whiten),
        "cos_before": cos_before,
        "cos_after": cos_after,
    }


def apply_align(emb, a):
    """Apply a ``procrustes_align`` result: center → whiten → rotate → unwhiten."""
    x = emb.float() - a["source_mean"]
    x = x @ a["source_whitener"]; x = x @ a["rotation"].T
    x = x @ a["target_unwhitener"]; return x


# ══════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════
def run():
    """Full prototype pipeline: extract → align → train student → train bank → verify."""
    torch.manual_seed(42)
    np.random.seed(42)
    N_SAMPLES = 20000
    MAX_LEN = 128
    BATCH = 256

    # ── Phase 0: Extract ──
    print(f"\n{'='*65}")
    print("PHASE 0: EXTRACTION")
    print(f"{'='*65}")
    from datasets import load_dataset
    from transformers import AutoModel, AutoTokenizer
    ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                      split="train", streaming=True)
    captions = []
    for row in ds:
        cap = row.get("caption_llava", "")
        if isinstance(cap, str) and len(cap) > 50:
            captions.append(cap)
        if len(captions) >= N_SAMPLES:
            break
    print(f"  Captions: {len(captions):,}")
    embeds = {}
    for model_name, short, max_len in EXPERTS:
        print(f"\n  Extracting: {short}...")
        model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        all_emb = []
        with torch.no_grad():
            for i in tqdm(range(0, len(captions), 128), desc=f"    {short}"):
                batch = captions[i:i+128]
                inputs = tokenizer(batch, max_length=max_len, padding=True,
                                   truncation=True, return_tensors="pt").to(DEVICE)
                out = model(**inputs)
                # Masked mean pooling over the last hidden state.
                m = inputs.attention_mask.unsqueeze(-1).float()
                pooled = (out.last_hidden_state * m).sum(1) / m.sum(1).clamp(min=1)
                all_emb.append(pooled.cpu())
        embeds[short] = torch.cat(all_emb)
        print(f"    Shape: {embeds[short].shape}")
        del model; gc.collect(); torch.cuda.empty_cache()

    # ── Phase 0b: Align + Consensus + Measure ──
    print(f"\n{'='*65}")
    print("PHASE 0b: GENERALIZED PROCRUSTES ALIGNMENT (no reference bias)")
    print(f"{'='*65}")
    names = [s for _, s, _ in EXPERTS]
    # Generalized Procrustes: iteratively align all to their mean
    # No expert is the reference. The centerpoint emerges.
    GPA_ITERS = 10
    current = {name: embeds[name].float() for name in names}
    for gpa_iter in range(GPA_ITERS):
        # Compute mean shape
        mean_shape = sum(current[n] for n in names) / len(names)
        # Align each to mean
        new_current = {}
        total_delta = 0.0
        for name in names:
            info = procrustes_align(current[name], mean_shape)
            new_current[name] = apply_align(current[name], info)
            # Measure how much this iteration changed things
            delta = (new_current[name] - current[name]).pow(2).mean().item()
            total_delta += delta
        current = new_current
        if gpa_iter == 0 or (gpa_iter + 1) % 3 == 0 or total_delta < 1e-8:
            print(f"  GPA iter {gpa_iter+1}: delta={total_delta:.8f}")
        if total_delta < 1e-8:
            print(f"  Converged at iteration {gpa_iter+1}")
            break
    # Final alignment: align each expert to the converged mean
    mean_shape = sum(current[n] for n in names) / len(names)
    procrustes_results = {}
    aligned = {}
    for name in names:
        info = procrustes_align(embeds[name], mean_shape)
        procrustes_results[name] = info
        aligned[name] = apply_align(embeds[name], info)
        cos = F.cosine_similarity(
            aligned[name][:2000], mean_shape[:2000], dim=-1).mean().item()
        print(f"  {name:10s}: cos_after={info['cos_after']:.4f} cos_to_mean={cos:.4f}")
    # Consensus = normalized centroid (now equidistant from all experts)
    consensus = F.normalize(sum(aligned[n] for n in names) / len(names), dim=-1)
    # Compute each expert's cosine to the consensus once; reuse the values for
    # both the per-expert report and the equidistance check (was computed twice).
    expert_cos_to_consensus = []
    for name in names:
        cos = F.cosine_similarity(consensus[:2000], aligned[name][:2000],
                                  dim=-1).mean().item()
        expert_cos_to_consensus.append(cos)
        print(f"  cos(consensus, {name}): {cos:.4f}")
    equidist_range = max(expert_cos_to_consensus) - min(expert_cos_to_consensus)
    print(f"  Equidistance range: {equidist_range:.4f} (should be near 0)")
    # Measure EXACT consensus statistics
    print(f"\n  Measuring consensus statistics...")
    consensus_stats = measure_consensus_stats(consensus)
    print(f"  CV: {consensus_stats['cv']:.4f}")
    print(f"  Mean cos: {consensus_stats['mean_cos']:.4f}")
    print(f"  Eff dim: {consensus_stats['eff_dim']:.1f}")
    print(f"  Spectral: [{', '.join(f'{s:.4f}' for s in consensus_stats['spectral'][:5])}...]")
    del embeds, aligned
    gc.collect(); torch.cuda.empty_cache()

    # ── Phase 1: Train Student ──
    print(f"\n{'='*65}")
    print("PHASE 1: TRAIN STUDENT")
    print(f"{'='*65}")
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    tokens = tokenizer(captions, max_length=MAX_LEN, padding="max_length",
                       truncation=True, return_tensors="pt")
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    n_train = N_SAMPLES - 2000  # last 2000 samples held out for validation
    train_ids = input_ids[:n_train].to(DEVICE)
    train_mask = attention_mask[:n_train].to(DEVICE)
    train_targets = consensus[:n_train].to(DEVICE)
    val_ids = input_ids[n_train:].to(DEVICE)
    val_mask = attention_mask[n_train:].to(DEVICE)
    val_targets = consensus[n_train:].to(DEVICE)
    student = MiniStudent(
        vocab_size=tokenizer.vocab_size, max_len=MAX_LEN, d_model=256,
        n_heads=4, n_layers=4, d_ff=1024, output_dim=768, dropout=0.1,
        pad_token_id=tokenizer.pad_token_id
    ).to(DEVICE)
    n_params = sum(p.numel() for p in student.parameters())
    print(f"  Student: {n_params:,} params")
    print(f"  CV target: {consensus_stats['cv']:.4f}")
    optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=0.01)
    for epoch in range(5):
        student.train()
        perm = torch.randperm(n_train, device=DEVICE)
        t_loss, t_acc, t_cos, n = 0, 0, 0, 0
        t0 = time.time()
        for i in range(0, n_train, BATCH):
            idx = perm[i:i+BATCH]
            if len(idx) < 8:
                continue
            emb = student(train_ids[idx], train_mask[idx])
            tgt = train_targets[idx]
            l_nce, acc = infonce(emb, tgt)
            l_mse = F.mse_loss(emb, tgt)
            l_cv = cv_loss(emb, target=consensus_stats["cv"])
            loss = l_nce + l_mse + 0.1 * l_cv
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            optimizer.step(); optimizer.zero_grad(set_to_none=True)
            with torch.no_grad():
                cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
            t_loss += loss.item(); t_acc += acc; t_cos += cos; n += 1
        elapsed = time.time() - t0; d = max(n, 1)
        student.eval()
        with torch.no_grad():
            v_emb = student(val_ids, val_mask)
            _, v_acc = infonce(v_emb[:1000], val_targets[:1000])
            v_cos = F.cosine_similarity(v_emb, val_targets, dim=-1).mean().item()
            v_cv = cv_metric(v_emb[:1000])
        print(f"  E{epoch+1}: {elapsed:.0f}s loss={t_loss/d:.4f} "
              f"t_acc={t_acc/d:.3f} t_cos={t_cos/d:.3f} "
              f"v_acc={v_acc:.3f} v_cos={v_cos:.3f} v_cv={v_cv:.3f}")
    torch.save(student.state_dict(), "mini_student.pt")
    print(f"\n  Student saved. v_cos={v_cos:.3f}, v_cv={v_cv:.3f}")

    # ── Phase 2: Train Alignment Bank ──
    print(f"\n{'='*65}")
    print("PHASE 2: TRAIN ALIGNMENT BANK (student frozen)")
    print(f"{'='*65}")
    student.eval()
    for p in student.parameters():
        p.requires_grad = False
    print("  Pre-encoding through frozen student...")
    with torch.no_grad():
        all_embs = []
        for i in range(0, n_train, 512):
            j = min(i + 512, n_train)
            emb = student(train_ids[i:j], train_mask[i:j])
            all_embs.append(emb)
        student_embs = torch.cat(all_embs)
        val_student_embs = student(val_ids, val_mask)
    print(f"  Student embeddings: {student_embs.shape}")
    bank = AlignmentBank(
        d_embed=768, n_experts=len(EXPERTS), n_anchors=512, d_bank=128
    ).to(DEVICE)
    bank.init_from_procrustes(procrustes_results, names,
                              consensus[:n_train], consensus_stats)
    bank_params = sum(p.numel() for p in bank.parameters())
    print(f"  Bank: {bank_params:,} params")
    print(f"  Bank targets: CV={bank.target_cv.item():.4f}, "
          f"mean_cos={bank.target_mean_cos.item():.4f}")
    # Calibrate disagreement from initial state (before any training)
    bank.calibrate_disagreement(student_embs[:2000])
    bank_opt = torch.optim.AdamW(bank.parameters(), lr=1e-3, weight_decay=0.01)
    BANK_EPOCHS = 20
    BANK_BATCH = 256
    for epoch in range(BANK_EPOCHS):
        bank.train()
        perm = torch.randperm(n_train, device=DEVICE)
        total_loss = 0
        stats = {"expert_agreement": 0, "rotation_ortho": 0,
                 "anchor_spread": 0, "bank_cv": 0, "emb_cv": 0,
                 "cross_expert_var": 0, "disagree_preserve": 0}
        n = 0
        t0 = time.time()
        for i in range(0, n_train, BANK_BATCH):
            idx = perm[i:i+BANK_BATCH]
            if len(idx) < 16:
                continue
            emb = student_embs[idx]
            enriched, aux = bank(emb)
            loss = bank.bank_loss(aux)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(bank.parameters(), 1.0)
            bank_opt.step(); bank_opt.zero_grad(set_to_none=True)
            total_loss += loss.item()
            for k in stats:
                if k in aux:
                    v = aux[k]
                    stats[k] += v.item() if torch.is_tensor(v) else v
            n += 1
        elapsed = time.time() - t0; d = max(n, 1)
        bank.eval()
        with torch.no_grad():
            v_enriched, v_aux = bank(val_student_embs)
            v_loss = bank.bank_loss(v_aux).item()
        print(f"\n  E{epoch+1:2d}: {elapsed:.0f}s loss={total_loss/d:.4f} v_loss={v_loss:.4f}")
        print(f"    Geometry: b_cv={stats['bank_cv']/d:.4f} e_cv={stats['emb_cv']/d:.4f} "
              f"spread={stats['anchor_spread']/d:.5f} a_max={v_aux['anchor_max_cos']:.3f}")
        print(f"    Experts: cos={v_aux['expert_cos_mean']:.3f}±{v_aux['expert_cos_std']:.3f} "
              f"agr={stats['expert_agreement']/d:.6f} ortho={stats['rotation_ortho']/d:.6f}")
        print(f"    Disagree: x_cos={v_aux.get('cross_expert_cos', 0):.4f}±{v_aux.get('cross_expert_cos_std', 0):.4f} "
              f"ratio={v_aux['disagreement_ratio']:.6f} "
              f"preserve={stats['disagree_preserve']/d:.6f} "
              f"norms={v_aux['norm_ratio_spread']:.4f}")
    torch.save(bank.state_dict(), "alignment_bank.pt")

    # ── Phase 3: Geometric Verification ──
    print(f"\n{'='*65}")
    print("PHASE 3: GEOMETRIC VERIFICATION")
    print(f"{'='*65}")
    bank.eval()
    with torch.no_grad():
        enriched_val, v_aux = bank(val_student_embs)
        original_768 = enriched_val[:, :768]   # pass-through slice of enriched
        geo_context = enriched_val[:, 768:]    # appended d_bank context
        passthrough_cos = F.cosine_similarity(
            original_768[:100], val_student_embs[:100], dim=-1).mean().item()
        geo_cv = cv_metric(F.normalize(geo_context[:1000], dim=-1))
        S = torch.linalg.svdvals(
            geo_context[:1000].float() - geo_context[:1000].float().mean(0))
        geo_eff_dim = float((S.sum() ** 2) / (S.pow(2).sum() + 1e-12))
        # Verify consensus stats are preserved
        emb_cv = cv_metric(val_student_embs[:1000])
    print(f"  Passthrough: {passthrough_cos:.6f} (target: 1.000)")
    print(f"  Emb CV: {emb_cv:.4f} (consensus: {consensus_stats['cv']:.4f})")
    print(f"  Geo context CV: {geo_cv:.4f}")
    print(f"  Geo eff_dim: {geo_eff_dim:.1f} / {bank.d_bank}")
    print(f"  Expert cos: {v_aux['expert_cos_mean']:.3f} ± {v_aux['expert_cos_std']:.3f}")
    print(f"  Anchor max cos: {v_aux['anchor_max_cos']:.3f}")
    print(f"  Disagreement:")
    print(f"    Cross-expert: {v_aux.get('cross_expert_cos', 0):.4f} ± {v_aux.get('cross_expert_cos_std', 0):.4f}")
    print(f"    Ratio: {v_aux['disagreement_ratio']:.6f} (target: {bank.target_disagreement_ratio.item():.6f})")
    print(f"    Norm spread: {v_aux['norm_ratio_spread']:.4f}")

    # ── Phase 4: Classifier Stability Test ──
    print(f"\n{'='*65}")
    print("PHASE 4: CLASSIFIER STABILITY TEST")
    print(f"{'='*65}")
    with torch.no_grad():
        embs = val_student_embs[:1000]
        sim = embs @ embs.T
        sim.fill_diagonal_(-1)
        # Label random pairs by cosine tercile: 0=high, 1=mid, 2=low similarity.
        n_pairs = 3000
        idx_a = torch.randint(0, 1000, (n_pairs,))
        idx_b = torch.randint(0, 1000, (n_pairs,))
        pair_cos = sim[idx_a, idx_b]
        sorted_cos, _ = pair_cos.sort()
        t1 = sorted_cos[n_pairs // 3].item()
        t2 = sorted_cos[2 * n_pairs // 3].item()
        labels = torch.zeros(n_pairs, dtype=torch.long, device=DEVICE)
        labels[pair_cos > t2] = 0
        labels[(pair_cos <= t2) & (pair_cos > t1)] = 1
        labels[pair_cos <= t1] = 2
        enriched_a, _ = bank(embs[idx_a])
        enriched_b, _ = bank(embs[idx_b])
    for mode in ["with_bank", "without_bank"]:
        if mode == "with_bank":
            feat_dim = (768 + 128) * 2
            features = torch.cat([enriched_a, enriched_b], dim=-1)
        else:
            feat_dim = 768 * 2
            features = torch.cat([embs[idx_a], embs[idx_b]], dim=-1)
        clf = nn.Sequential(
            nn.Linear(feat_dim, 256), nn.GELU(), nn.LayerNorm(256),
            nn.Linear(256, 3)
        ).to(DEVICE)
        clf_opt = torch.optim.Adam(clf.parameters(), lr=1e-3)
        n_clf_train = 2400
        train_f = features[:n_clf_train].detach()
        train_l = labels[:n_clf_train]
        val_f = features[n_clf_train:].detach()
        val_l = labels[n_clf_train:]
        for e in range(30):
            clf.train()
            logits = clf(train_f)
            loss = F.cross_entropy(logits, train_l)
            loss.backward(); clf_opt.step(); clf_opt.zero_grad()
        clf.eval()
        with torch.no_grad():
            v_acc = (clf(val_f).argmax(-1) == val_l).float().mean().item()
            t_acc = (clf(train_f).argmax(-1) == train_l).float().mean().item()
        print(f"  {mode:15s}: train={t_acc:.3f} val={v_acc:.3f} gap={t_acc-v_acc:.3f}")

    print(f"\n{'='*65}")
    print("SUMMARY")
    print(f"{'='*65}")
    print(f"  Consensus CV: {consensus_stats['cv']:.4f}")
    print(f"  Consensus eff_dim:{consensus_stats['eff_dim']:.1f}")
    print(f"  Student v_cos: {v_cos:.3f}")
    print(f"  Student v_cv: {v_cv:.3f}")
    print(f"  Bank params: {bank_params:,}")
    print(f"  Bank geo_eff_dim: {geo_eff_dim:.1f}")
    print(f"  Bank geo_cv: {geo_cv:.4f}")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")


if __name__ == "__main__":
    run()