# geolip-axis-prototype / prototype_5_geodesic_prelim.py
# Author: AbstractPhil
# Commit 2c64e3d (verified): "Create prototype_5_geodesic_prelim.py"
# ============================================================================
# RAPID PROTOTYPE v2: Differentiation-Centered Alignment Bank
#
# The bank aligns to the DIFFERENTIATION between experts, not to any
# arbitrary target. The consensus CV, spectral profile, and pairwise
# statistics measured during alignment become the exact targets.
#
# The bank embodies the centerpoint of expert disagreement.
# ============================================================================
import gc
import math
import os
import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# Compute device for every phase: prefer the GPU when one is visible.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Expert encoders: (Hugging Face model id, short key used in dicts, tokenizer max_length).
EXPERTS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 512),
]

# Startup banner (printed at import time).
print("\n".join([
    "=" * 65,
    "RAPID PROTOTYPE v2: Differentiation-Centered Bank",
    "=" * 65,
    f" Device: {DEVICE}",
]))
# ══════════════════════════════════════════════════════════════════
# STUDENT MODEL
# ══════════════════════════════════════════════════════════════════
class MiniStudent(nn.Module):
    """Small Transformer encoder mapping token ids to a unit-norm embedding.

    Token embeddings plus learned absolute position embeddings are
    layer-normed, passed through a pre-norm (``norm_first=True``)
    ``nn.TransformerEncoder``, mean-pooled over non-padding positions,
    projected to ``output_dim`` and L2-normalized.
    """

    def __init__(self, vocab_size=30522, max_len=512, d_model=256,
                 n_heads=4, n_layers=4, d_ff=1024, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        # NOTE: submodules are created in a fixed order; parameter init draws
        # from the global RNG, so reordering would change the initial weights.
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)
        block = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            block, num_layers=n_layers, enable_nested_tensor=False)
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, output_dim),
        )

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: ``(batch, seq_len)`` integer token ids; ``seq_len`` must
                not exceed ``max_len``.
            attention_mask: optional ``(batch, seq_len)`` mask, nonzero for real
                tokens. When omitted, positions equal to ``pad_token_id`` are
                treated as padding.

        Returns:
            ``(batch, output_dim)`` L2-normalized embeddings.
        """
        _, seq_len = input_ids.shape
        pos = torch.arange(seq_len, device=input_ids.device)[None, :]
        hidden = self.token_emb(input_ids) + self.pos_emb(pos)
        hidden = self.emb_drop(self.emb_norm(hidden))

        # `padding` is True where attention must be blocked; `keep` weights pooling.
        if attention_mask is None:
            padding = input_ids == self.pad_token_id
            keep = (~padding).unsqueeze(-1).float()
        else:
            padding = ~attention_mask.bool()
            keep = attention_mask.unsqueeze(-1).float()

        hidden = self.encoder(hidden, src_key_padding_mask=padding)
        # Masked mean; clamp avoids division by zero for all-padding rows.
        pooled = (hidden * keep).sum(1) / keep.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(pooled), dim=-1)
# ══════════════════════════════════════════════════════════════════
# ALIGNMENT BANK
# ══════════════════════════════════════════════════════════════════
class AlignmentBank(nn.Module):
    """
    Differentiation-centered geometric interface.

    Aligns to the CENTERPOINT between experts — the consensus itself.
    Stores per-expert rotation matrices (the differentiation structure)
    and learned anchor landmarks (the consensus manifold topology).
    The bank doesn't invent geometry. It mirrors the measured consensus.
    Every loss term pulls toward measured consensus statistics.

    ``forward(embedding)`` returns ``(enriched, aux)``:
      * ``enriched`` is ``embedding`` with a ``d_bank``-dim geometric context
        vector concatenated on the last axis;
      * ``aux`` holds the loss-term tensors used by ``bank_loss`` plus float
        diagnostics (``.item()`` values) for logging.
    """

    def __init__(self, d_embed=768, n_experts=2, n_anchors=512, d_bank=128):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank
        # Per-expert rotation matrices (differentiation structure).
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])
        # Per-expert whiteners (captures variance structure per expert).
        self.expert_whiteners = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])
        # Per-expert means (centering offset per expert).
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)
        ])
        # Anchor bank: consensus landmarks on the hypersphere.
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(n_anchors, d_embed), dim=-1))
        # geo_proj input: expert_cos (n) + expert_mse (n) + cross (n*(n-1)/2)
        #   + disagreement_ratio (1) + norm_ratio (n) + anchor_cos (n_anchors)
        n_cross = n_experts * (n_experts - 1) // 2
        geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 2),
            nn.GELU(),
            nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank),
            nn.LayerNorm(d_bank),
        )
        # Consensus statistics (set by init_from_procrustes, used as targets).
        # NOTE(review): target_mean_cos and target_spectral are stored but not
        # read by bank_loss in this version.
        self.register_buffer("target_cv", torch.tensor(0.12))
        self.register_buffer("target_mean_cos", torch.tensor(0.0))
        self.register_buffer("target_spectral", torch.zeros(50))
        # Disagreement structure (set once by calibrate_disagreement).
        self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
        self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
        self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))

    def init_from_procrustes(self, procrustes_results, expert_names,
                             consensus_embeddings=None,
                             consensus_stats=None):
        """Initialize from consensus training artifacts.

        Args:
            procrustes_results: mapping expert name -> dict with at least
                ``"rotation"`` and ``"cos_after"``; ``"source_whitener"`` and
                ``"source_mean"`` are loaded when present.
            expert_names: names in expert order; only the first
                ``n_experts`` are used.
            consensus_embeddings: optional rows used to seed the anchors
                (evenly spaced picks, L2-normalized).
            consensus_stats: optional dict with ``"cv"``, ``"mean_cos"`` and
                optionally ``"spectral"`` (first 50 values used).
        """
        device = self.anchors.device
        for i, name in enumerate(expert_names[:self.n_experts]):
            info = procrustes_results[name]
            self.expert_rotations[i].data = info["rotation"].float().to(device)
            if "source_whitener" in info:
                self.expert_whiteners[i].data = info["source_whitener"].float().to(device)
            if "source_mean" in info:
                self.expert_means[i].data = info["source_mean"].float().to(device)
            print(f" Expert {i} ({name}): rotation + whitener + mean loaded, "
                  f"cos_after={info['cos_after']:.4f}")
        if consensus_embeddings is not None:
            n = min(self.n_anchors, consensus_embeddings.shape[0])
            # Evenly spaced rows across the whole consensus set.
            indices = torch.linspace(0, consensus_embeddings.shape[0] - 1, n).long()
            self.anchors.data[:n] = F.normalize(
                consensus_embeddings[indices].float(), dim=-1).to(device)
            print(f" Anchors: {n} initialized from consensus embeddings")
        if consensus_stats is not None:
            self.target_cv.fill_(consensus_stats["cv"])
            self.target_mean_cos.fill_(consensus_stats["mean_cos"])
            if "spectral" in consensus_stats:
                s = torch.tensor(consensus_stats["spectral"][:50], dtype=torch.float32)
                self.target_spectral[:len(s)] = s.to(device)
            print(f" Targets: CV={consensus_stats['cv']:.4f}, "
                  f"mean_cos={consensus_stats['mean_cos']:.4f}")

    @staticmethod
    def _simplex_volume_cv(points, n_samples=32, n_vertices=5):
        """Coefficient of variation of volumes of random simplices.

        Draws ``n_samples`` random subsets of ``n_vertices`` rows of
        ``points`` (2-D), computes each simplex volume via the Cayley-Menger
        determinant and returns ``std / mean`` as a differentiable scalar.
        """
        n_points = points.shape[0]
        vols = []
        for _ in range(n_samples):
            idx = torch.randperm(n_points, device=points.device)[:n_vertices]
            pts = points[idx].unsqueeze(0)
            diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
            d2 = (diff * diff).sum(-1)
            Bv, V, _ = d2.shape
            cm = torch.zeros(Bv, V + 1, V + 1, device=d2.device, dtype=torch.float32)
            cm[:, 0, 1:] = 1
            cm[:, 1:, 0] = 1
            cm[:, 1:, 1:] = d2
            sign = (-1.0) ** V
            fact = math.factorial(V - 1)
            v2 = sign / ((2.0 ** (V - 1)) * fact * fact) * torch.linalg.det(cm)
            # relu guards tiny negative round-off before the sqrt.
            vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
        stacked = torch.stack(vols)
        return stacked.std() / (stacked.mean() + 1e-8)

    def forward(self, embedding):
        """Compute the geometric context and loss terms for a batch.

        Args:
            embedding: ``(B, d_embed)`` tensor (cast to float32 internally).

        Returns:
            ``(enriched, aux)``; ``enriched`` is ``(B, d_embed + d_bank)``.
        """
        B = embedding.shape[0]
        emb = embedding.float()

        # ── Per-expert projections (full whitened Procrustes) ──
        # Chain: center → whiten → normalize → rotate
        # The rotation only makes geometric sense in whitened-normalized space.
        expert_consistency = []
        expert_recon = []
        expert_projected = []
        expert_norms = []
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]
            centered = emb - mu
            whitened = centered @ W
            # Norm before normalization (captures magnitude structure).
            expert_norms.append(whitened.norm(dim=-1))  # (B,)
            whitened_n = F.normalize(whitened, dim=-1)
            in_expert = whitened_n @ R.T  # in expert's whitened-normalized space
            # Round-trip: rotate back (R.T is R's inverse only if R is orthogonal).
            back = in_expert @ R
            # Consistency: round-trip recovers whitened_n exactly when R is orthogonal.
            expert_consistency.append(F.cosine_similarity(whitened_n, back, dim=-1))
            expert_recon.append((whitened_n - back).pow(2).mean(dim=-1))
            expert_projected.append(in_expert)
        expert_cos = torch.stack(expert_consistency, dim=-1)  # (B, n_experts)
        expert_mse = torch.stack(expert_recon, dim=-1)  # (B, n_experts)

        # ── Cross-expert differentiation ──
        # Pairwise cosine between every pair of expert projections.
        cross_cos = []
        for i in range(self.n_experts):
            for j in range(i + 1, self.n_experts):
                cross_cos.append(F.cosine_similarity(
                    expert_projected[i], expert_projected[j], dim=-1))
        if cross_cos:
            cross_features = torch.stack(cross_cos, dim=-1)
        else:
            cross_features = torch.zeros(B, 0, device=emb.device)

        # Per-sample disagreement: spread of round-trip cos across experts.
        per_sample_agreement = expert_cos.mean(dim=-1)  # (B,)
        if self.n_experts > 1:
            per_sample_disagreement = expert_cos.std(dim=-1)  # (B,)
        else:
            # One expert cannot disagree with itself; an (unbiased) std over a
            # single element is NaN and would poison every downstream feature.
            per_sample_disagreement = torch.zeros_like(per_sample_agreement)
        disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)  # (B,)

        expert_norm_features = torch.stack(expert_norms, dim=-1)  # (B, n_experts)
        norm_ratio = expert_norm_features / (expert_norm_features.mean(dim=-1, keepdim=True) + 1e-8)

        # ── Anchor similarities ──
        # NOTE(review): emb is not normalized here, so this is a true cosine
        # only for unit-norm inputs (the student's outputs are unit-norm).
        anchors_n = F.normalize(self.anchors, dim=-1)
        anchor_cos = emb @ anchors_n.T  # (B, n_anchors)

        # ── Geometric context ──
        geo_input = torch.cat([
            expert_cos,                         # (B, n_experts)
            expert_mse,                         # (B, n_experts)
            cross_features,                     # (B, n_cross)
            disagreement_ratio.unsqueeze(-1),   # (B, 1)
            norm_ratio,                         # (B, n_experts)
            anchor_cos,                         # (B, n_anchors)
        ], dim=-1)
        geo_context = self.geo_proj(geo_input)
        enriched = torch.cat([embedding, geo_context], dim=-1)

        # ── Losses + Diagnostics ──
        aux = {}
        # 1. Expert agreement: all experts should see the embedding equally.
        expert_mean = expert_cos.mean(dim=-1, keepdim=True)
        aux["expert_agreement"] = (expert_cos - expert_mean).pow(2).mean()
        # 2. Rotation orthogonality.
        eye = torch.eye(self.d_embed, device=self.anchors.device)
        ortho_loss = 0.0
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            ortho_loss += (R @ R.T - eye.to(R.device)).pow(2).mean()
        aux["rotation_ortho"] = ortho_loss / self.n_experts
        # 3. Anchor spread: penalize off-diagonal anchor similarity.
        anchor_sim = anchors_n @ anchors_n.T
        anchor_sim.fill_diagonal_(0)
        aux["anchor_spread"] = anchor_sim.pow(2).mean()
        # 4. Anchor sharpness: entropy of a temperature-0.1 softmax over anchors.
        anchor_probs = F.softmax(anchor_cos * 10, dim=-1)
        aux["anchor_entropy"] = -(anchor_probs * (anchor_probs + 1e-12).log()).sum(-1).mean()
        # 5. Cross-expert differentiation consistency.
        if cross_features.shape[1] > 0:
            aux["cross_expert_var"] = cross_features.var(dim=0).mean()
        else:
            aux["cross_expert_var"] = torch.tensor(0.0, device=emb.device)
        # 6. Disagreement preservation: keep batch statistics at the calibrated targets.
        if cross_features.shape[1] > 0:
            batch_cross_mean = cross_features.mean()
            batch_cross_std = cross_features.std()
        else:
            batch_cross_mean = torch.tensor(0.0, device=emb.device)
            batch_cross_std = torch.tensor(0.0, device=emb.device)
        batch_disagree_ratio = disagreement_ratio.mean()
        aux["disagree_preserve"] = (
            (batch_cross_mean - self.target_cross_cos_mean).pow(2) +
            (batch_cross_std - self.target_cross_cos_std).pow(2) +
            (batch_disagree_ratio - self.target_disagreement_ratio).pow(2)
        )
        # 7. Bank CV (on the normalized geometric context).
        if B >= 10:
            aux["bank_cv"] = self._simplex_volume_cv(F.normalize(geo_context, dim=-1))
        else:
            aux["bank_cv"] = torch.tensor(0.0, device=embedding.device)
        # 8. Emb CV (on the normalized input); computed after bank CV so the
        #    randperm draw order is unchanged.
        if B >= 10:
            aux["emb_cv"] = self._simplex_volume_cv(F.normalize(emb, dim=-1))
        else:
            aux["emb_cv"] = torch.tensor(0.0, device=embedding.device)

        # Diagnostics (Python floats).
        aux["expert_cos_mean"] = expert_cos.mean().item()
        aux["expert_cos_std"] = expert_cos.std().item()
        aux["anchor_max_cos"] = anchor_cos.max(dim=-1).values.mean().item()
        aux["anchor_mean_cos"] = anchor_cos.mean().item()
        if cross_features.shape[1] > 0:
            aux["cross_expert_cos"] = cross_features.mean().item()
            aux["cross_expert_cos_std"] = cross_features.std().item()
        aux["disagreement_ratio"] = disagreement_ratio.mean().item()
        aux["norm_ratio_spread"] = norm_ratio.std(dim=-1).mean().item()
        return enriched, aux

    def bank_loss(self, aux):
        """Weighted sum of the aux loss terms; CV terms pull toward ``target_cv``."""
        loss = (
            1.0 * aux["expert_agreement"] +
            1.0 * aux["rotation_ortho"] +
            0.5 * aux["anchor_spread"] +
            0.1 * aux["anchor_entropy"] +
            0.3 * aux["cross_expert_var"] +
            0.3 * (aux["bank_cv"] - self.target_cv).abs() +
            0.3 * (aux["emb_cv"] - self.target_cv).abs() +
            0.5 * aux["disagree_preserve"]  # preserve the disagreement distribution
        )
        return loss

    @torch.no_grad()
    def calibrate_disagreement(self, embeddings):
        """
        Measure the initial disagreement structure and store it as targets.

        Call ONCE after init, before training. Runs ``forward`` on
        ``embeddings`` and copies the cross-expert cos mean/std (when there
        are at least two experts) and the mean disagreement ratio into the
        ``target_*`` buffers.
        """
        _, aux = self.forward(embeddings)
        if "cross_expert_cos" in aux:
            self.target_cross_cos_mean.fill_(aux["cross_expert_cos"])
        if "cross_expert_cos_std" in aux:
            self.target_cross_cos_std.fill_(aux["cross_expert_cos_std"])
        self.target_disagreement_ratio.fill_(aux["disagreement_ratio"])
        print(f" Calibrated disagreement:")
        print(f" cross_cos: {self.target_cross_cos_mean.item():.4f} ± {self.target_cross_cos_std.item():.4f}")
        print(f" disagree_ratio: {self.target_disagreement_ratio.item():.6f}")
# ══════════════════════════════════════════════════════════════════
# GEOMETRY
# ══════════════════════════════════════════════════════════════════
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between row-paired batches ``a`` and ``b``.

    Row ``i`` of ``a`` is the positive for row ``i`` of ``b``; all other rows
    act as negatives. Both inputs are L2-normalized first.

    Returns:
        ``(loss, acc)``: ``loss`` is the mean of the a->b and b->a
        cross-entropies (a tensor); ``acc`` is the a->b top-1 retrieval
        accuracy as a Python float.
    """
    sims = F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).T
    logits = sims / temperature
    targets = torch.arange(logits.shape[0], device=logits.device)
    forward_ce = F.cross_entropy(logits, targets)
    backward_ce = F.cross_entropy(logits.T, targets)
    with torch.no_grad():
        hits = logits.argmax(-1).eq(targets)
        acc = hits.float().mean().item()
    return (forward_ce + backward_ce) / 2, acc
def cayley_menger_vol2(pts):
    """Squared simplex volumes from the Cayley-Menger determinant.

    Args:
        pts: ``(batch, V, dim)`` tensor; each batch entry is a simplex with
            ``V`` vertices.

    Returns:
        ``(batch,)`` float32 tensor of squared ``(V-1)``-dimensional volumes.
        Near-degenerate simplices can yield slightly negative values from
        round-off, so callers clamp before taking a square root.
    """
    pts = pts.float()
    deltas = pts[..., :, None, :] - pts[..., None, :, :]
    sq_dist = (deltas * deltas).sum(-1)
    _, n_vert, _ = sq_dist.shape
    # Border the distance matrix with ones and a zero corner:
    # [[0, 1, ..., 1], [1, d2...], ...].
    cm = F.pad(sq_dist, (1, 0, 1, 0), value=1.0)
    cm[:, 0, 0] = 0.0
    sign = (-1.0) ** n_vert
    fact = math.factorial(n_vert - 1)
    coeff = sign / ((2.0 ** (n_vert - 1)) * fact * fact)
    return coeff * torch.linalg.det(cm)
def cv_loss(emb, target=0.12, n_samples=16):
    """Differentiable penalty ``|CV - target|`` on random 5-point simplex volumes.

    CV is ``std / mean`` of the volumes of ``n_samples`` random 5-row
    simplices drawn from ``emb``. Batches with fewer than 5 rows cannot form
    a simplex and return a zero tensor.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return torch.tensor(0.0, device=emb.device)
    volumes = []
    for _ in range(n_samples):
        pick = torch.randperm(n_rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[pick].unsqueeze(0))[0]
        volumes.append(torch.sqrt(F.relu(vol_sq) + 1e-12))
    volumes = torch.stack(volumes)
    cv = volumes.std() / (volumes.mean() + 1e-8)
    return (cv - target).abs()
def cv_metric(emb, n=200):
    """Non-differentiable CV (``std / mean``) of ``n`` random 5-point simplex volumes.

    Returns ``0.0`` when ``emb`` has fewer than 5 rows or fewer than 10
    positive volumes were collected; otherwise a Python float.
    """
    rows = emb.shape[0]
    if rows < 5:
        return 0.0
    kept = []
    for _ in range(n):
        choice = torch.randperm(rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[choice].unsqueeze(0))
        vol = torch.sqrt(F.relu(vol_sq[0]) + 1e-12).item()
        # NOTE(review): vol is always >= 1e-6 because of the epsilon, so this
        # filter never drops degenerate simplices.
        if vol > 0:
            kept.append(vol)
    if len(kept) < 10:
        return 0.0
    arr = np.asarray(kept)
    return float(arr.std() / (arr.mean() + 1e-8))
def measure_consensus_stats(consensus_embs, n_check=2000):
    """Measure geometric statistics of the consensus manifold.

    Uses the first ``n_check`` rows of ``consensus_embs`` (cast to float32).

    Returns:
        dict with
          * ``"cv"``: simplex-volume CV from ``cv_metric`` (computed on ``DEVICE``);
          * ``"mean_cos"``: mean off-diagonal entry of ``embs @ embs.T``
            (a cosine only if rows are unit-norm, as the normalized consensus is);
          * ``"spectral"``: first 50 singular values of the mean-centered
            rows, each divided by the sum of all singular values;
          * ``"eff_dim"``: ``(sum S)^2 / sum S^2`` over those singular values.
    """
    embs = consensus_embs[:n_check].float()
    # CV
    cv = cv_metric(embs.to(DEVICE))
    # Pairwise cosine. The mask lives on the same device as `sim`, so this
    # also works when the embeddings are already on the GPU.
    sim = embs @ embs.T
    off_diag = ~torch.eye(embs.shape[0], dtype=torch.bool, device=embs.device)
    mean_cos = sim[off_diag].mean().item()
    # Spectral profile
    centered = embs - embs.mean(0, keepdim=True)
    S = torch.linalg.svdvals(centered)
    S_norm = (S / (S.sum() + 1e-8))[:50].tolist()
    # Effective dimension
    eff_dim = float((S.sum() ** 2) / (S.pow(2).sum() + 1e-12))
    return {
        "cv": cv,
        "mean_cos": mean_cos,
        "spectral": S_norm,
        "eff_dim": eff_dim,
    }
# ══════════════════════════════════════════════════════════════════
# EXTRACTION + ALIGNMENT
# ══════════════════════════════════════════════════════════════════
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Return ``cov ** -1/2`` for a symmetric PSD matrix via eigendecomposition.

    Eigenvalues are clamped from below at ``eps`` so singular or
    near-singular covariances still give a finite whitening matrix.
    """
    eigvals, eigvecs = torch.linalg.eigh(cov)
    inv_root = eigvals.clamp(min=eps).rsqrt()
    scaled = eigvecs @ torch.diag(inv_root)
    return scaled @ eigvecs.T
def procrustes_align(source, target, n_align=5000):
    """Whitened orthogonal Procrustes alignment of ``source`` onto ``target``.

    Uses the first ``min(n_align, len(source), len(target))`` paired rows.
    Both sides are mean-centered, whitened with their own covariance
    (``symmetric_inv_sqrt``) and row-normalized. Then the orthogonal
    ``rotation`` minimizing ``||src_w @ rotation.T - tgt_w||`` is found by SVD.

    Returns:
        dict with ``rotation``, ``source_mean``, ``source_whitener``,
        ``target_unwhitener`` (pseudo-inverse of the target whitener),
        ``cos_before`` (mean row cosine of the centered inputs) and
        ``cos_after`` (mean row cosine in whitened-normalized space).
    """
    n = min(n_align, source.shape[0], target.shape[0])
    src = source[:n].float()
    tgt = target[:n].float()
    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean
    cos_before = F.cosine_similarity(src_c, tgt_c, dim=-1).mean().item()

    denom = max(src_c.shape[0] - 1, 1)  # unbiased covariance normalizer
    src_white = symmetric_inv_sqrt((src_c.T @ src_c) / denom)
    tgt_white = symmetric_inv_sqrt((tgt_c.T @ tgt_c) / denom)
    src_w = F.normalize(src_c @ src_white, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_white, dim=-1)

    U, _, Vt = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    rotation = U @ Vt
    cos_after = F.cosine_similarity(src_w @ rotation.T, tgt_w, dim=-1).mean().item()

    result = {"rotation": rotation}
    result["source_mean"] = src_mean.squeeze(0)
    result["source_whitener"] = src_white
    result["target_unwhitener"] = torch.linalg.pinv(tgt_white)
    result["cos_before"] = cos_before
    result["cos_after"] = cos_after
    return result
def apply_align(emb, a):
    """Map ``emb`` through a ``procrustes_align`` result.

    Steps: subtract ``source_mean``, multiply by ``source_whitener``, rotate
    by ``rotation.T``, multiply by ``target_unwhitener``. The target mean is
    not added back, so the output is centered in the target frame.
    Returns a float32 tensor.
    """
    out = emb.float() - a["source_mean"]
    for mat in (a["source_whitener"], a["rotation"].T, a["target_unwhitener"]):
        out = out @ mat
    return out
# ══════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════
def run():
    """End-to-end prototype pipeline.

    Phases:
      0.  Stream captions and extract mean-pooled embeddings from each expert.
      0b. Generalized Procrustes alignment of all experts to their mean;
          build the normalized consensus and measure its statistics.
      1.  Train ``MiniStudent`` to reproduce the consensus embeddings.
      2.  Freeze the student and train an ``AlignmentBank`` on its outputs.
      3.  Print geometric diagnostics of the trained bank.
      4.  Compare a small pair classifier trained with vs. without the
          bank's geometric context.

    Side effects: downloads the dataset and expert models, prints progress,
    and writes ``mini_student.pt`` and ``alignment_bank.pt`` to the working
    directory.

    Raises:
        RuntimeError: if the caption stream yields no more than ``N_VAL``
            usable captions (nothing would be left for training).
    """
    torch.manual_seed(42)
    np.random.seed(42)
    N_SAMPLES = 20000
    N_VAL = 2000  # held-out rows for student / bank validation
    MAX_LEN = 128
    BATCH = 256

    # ── Phase 0: Extract ──
    print(f"\n{'='*65}")
    print("PHASE 0: EXTRACTION")
    print(f"{'='*65}")
    # Deferred imports: only needed when the pipeline actually runs.
    from datasets import load_dataset
    from transformers import AutoModel, AutoTokenizer

    ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                      split="train", streaming=True)
    captions = []
    for row in ds:
        cap = row.get("caption_llava", "")
        if isinstance(cap, str) and len(cap) > 50:
            captions.append(cap)
        if len(captions) >= N_SAMPLES:
            break
    print(f" Captions: {len(captions):,}")
    # The split is derived from what the stream actually produced, so a short
    # stream cannot silently misalign ids, targets and validation rows.
    n_total = len(captions)
    if n_total <= N_VAL:
        raise RuntimeError(
            f"Need more than {N_VAL} captions for the train/val split, got {n_total}")

    embeds = {}
    for model_name, short, max_len in EXPERTS:
        print(f"\n Extracting: {short}...")
        model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        all_emb = []
        with torch.no_grad():
            for i in tqdm(range(0, len(captions), 128), desc=f" {short}"):
                batch = captions[i:i+128]
                inputs = tokenizer(batch, max_length=max_len, padding=True,
                                   truncation=True, return_tensors="pt").to(DEVICE)
                out = model(**inputs)
                # Mean-pool the last hidden state over non-padding tokens.
                m = inputs.attention_mask.unsqueeze(-1).float()
                pooled = (out.last_hidden_state * m).sum(1) / m.sum(1).clamp(min=1)
                all_emb.append(pooled.cpu())
        embeds[short] = torch.cat(all_emb)
        print(f" Shape: {embeds[short].shape}")
        del model
        gc.collect()
        torch.cuda.empty_cache()

    # ── Phase 0b: Align + Consensus + Measure ──
    print(f"\n{'='*65}")
    print("PHASE 0b: GENERALIZED PROCRUSTES ALIGNMENT (no reference bias)")
    print(f"{'='*65}")
    names = [s for _, s, _ in EXPERTS]
    # Generalized Procrustes: iteratively align all experts to their mean.
    # No expert is the reference; the centerpoint emerges.
    GPA_ITERS = 10
    current = {name: embeds[name].float() for name in names}
    for gpa_iter in range(GPA_ITERS):
        mean_shape = sum(current[n] for n in names) / len(names)
        new_current = {}
        total_delta = 0.0
        for name in names:
            info = procrustes_align(current[name], mean_shape)
            new_current[name] = apply_align(current[name], info)
            # How much this iteration moved this expert.
            total_delta += (new_current[name] - current[name]).pow(2).mean().item()
        current = new_current
        if gpa_iter == 0 or (gpa_iter + 1) % 3 == 0 or total_delta < 1e-8:
            print(f" GPA iter {gpa_iter+1}: delta={total_delta:.8f}")
        if total_delta < 1e-8:
            print(f" Converged at iteration {gpa_iter+1}")
            break

    # Final alignment: align each original expert to the converged mean.
    mean_shape = sum(current[n] for n in names) / len(names)
    procrustes_results = {}
    aligned = {}
    for name in names:
        info = procrustes_align(embeds[name], mean_shape)
        procrustes_results[name] = info
        aligned[name] = apply_align(embeds[name], info)
        cos = F.cosine_similarity(
            aligned[name][:2000], mean_shape[:2000], dim=-1).mean().item()
        print(f" {name:10s}: cos_after={info['cos_after']:.4f} cos_to_mean={cos:.4f}")

    # Consensus = normalized centroid of the aligned experts.
    consensus = F.normalize(sum(aligned[n] for n in names) / len(names), dim=-1)
    # Report each expert's cosine to the consensus and check equidistance.
    expert_cos_to_consensus = []
    for name in names:
        cos = F.cosine_similarity(consensus[:2000], aligned[name][:2000], dim=-1).mean().item()
        print(f" cos(consensus, {name}): {cos:.4f}")
        expert_cos_to_consensus.append(cos)
    equidist_range = max(expert_cos_to_consensus) - min(expert_cos_to_consensus)
    print(f" Equidistance range: {equidist_range:.4f} (should be near 0)")

    # Measure the consensus statistics used as targets later.
    print(f"\n Measuring consensus statistics...")
    consensus_stats = measure_consensus_stats(consensus)
    print(f" CV: {consensus_stats['cv']:.4f}")
    print(f" Mean cos: {consensus_stats['mean_cos']:.4f}")
    print(f" Eff dim: {consensus_stats['eff_dim']:.1f}")
    print(f" Spectral: [{', '.join(f'{s:.4f}' for s in consensus_stats['spectral'][:5])}...]")
    # Free the large per-expert tensors (including the GPA intermediates).
    del embeds, aligned, current, mean_shape
    gc.collect()
    torch.cuda.empty_cache()

    # ── Phase 1: Train Student ──
    print(f"\n{'='*65}")
    print("PHASE 1: TRAIN STUDENT")
    print(f"{'='*65}")
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    tokens = tokenizer(captions, max_length=MAX_LEN, padding="max_length",
                       truncation=True, return_tensors="pt")
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    n_train = n_total - N_VAL
    train_ids = input_ids[:n_train].to(DEVICE)
    train_mask = attention_mask[:n_train].to(DEVICE)
    train_targets = consensus[:n_train].to(DEVICE)
    val_ids = input_ids[n_train:].to(DEVICE)
    val_mask = attention_mask[n_train:].to(DEVICE)
    val_targets = consensus[n_train:].to(DEVICE)

    student = MiniStudent(
        vocab_size=tokenizer.vocab_size, max_len=MAX_LEN,
        d_model=256, n_heads=4, n_layers=4, d_ff=1024,
        output_dim=768, dropout=0.1, pad_token_id=tokenizer.pad_token_id
    ).to(DEVICE)
    n_params = sum(p.numel() for p in student.parameters())
    print(f" Student: {n_params:,} params")
    print(f" CV target: {consensus_stats['cv']:.4f}")
    optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=0.01)

    for epoch in range(5):
        student.train()
        perm = torch.randperm(n_train, device=DEVICE)
        t_loss, t_acc, t_cos, n = 0, 0, 0, 0
        t0 = time.time()
        for i in range(0, n_train, BATCH):
            idx = perm[i:i+BATCH]
            if len(idx) < 8:
                continue
            emb = student(train_ids[idx], train_mask[idx])
            tgt = train_targets[idx]
            l_nce, acc = infonce(emb, tgt)
            l_mse = F.mse_loss(emb, tgt)
            l_cv = cv_loss(emb, target=consensus_stats["cv"])
            loss = l_nce + l_mse + 0.1 * l_cv
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            with torch.no_grad():
                cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
            t_loss += loss.item()
            t_acc += acc
            t_cos += cos
            n += 1
        elapsed = time.time() - t0
        d = max(n, 1)
        student.eval()
        with torch.no_grad():
            v_emb = student(val_ids, val_mask)
            _, v_acc = infonce(v_emb[:1000], val_targets[:1000])
            v_cos = F.cosine_similarity(v_emb, val_targets, dim=-1).mean().item()
            v_cv = cv_metric(v_emb[:1000])
        print(f" E{epoch+1}: {elapsed:.0f}s loss={t_loss/d:.4f} "
              f"t_acc={t_acc/d:.3f} t_cos={t_cos/d:.3f} "
              f"v_acc={v_acc:.3f} v_cos={v_cos:.3f} v_cv={v_cv:.3f}")
    torch.save(student.state_dict(), "mini_student.pt")
    print(f"\n Student saved. v_cos={v_cos:.3f}, v_cv={v_cv:.3f}")

    # ── Phase 2: Train Alignment Bank ──
    print(f"\n{'='*65}")
    print("PHASE 2: TRAIN ALIGNMENT BANK (student frozen)")
    print(f"{'='*65}")
    student.eval()
    for p in student.parameters():
        p.requires_grad = False
    print(" Pre-encoding through frozen student...")
    with torch.no_grad():
        all_embs = []
        for i in range(0, n_train, 512):
            j = min(i + 512, n_train)
            all_embs.append(student(train_ids[i:j], train_mask[i:j]))
        student_embs = torch.cat(all_embs)
        val_student_embs = student(val_ids, val_mask)
    print(f" Student embeddings: {student_embs.shape}")

    bank = AlignmentBank(
        d_embed=768, n_experts=len(EXPERTS),
        n_anchors=512, d_bank=128
    ).to(DEVICE)
    bank.init_from_procrustes(procrustes_results, names,
                              consensus[:n_train], consensus_stats)
    bank_params = sum(p.numel() for p in bank.parameters())
    print(f" Bank: {bank_params:,} params")
    print(f" Bank targets: CV={bank.target_cv.item():.4f}, "
          f"mean_cos={bank.target_mean_cos.item():.4f}")
    # Calibrate disagreement from the initial state (before any training).
    bank.calibrate_disagreement(student_embs[:2000])
    bank_opt = torch.optim.AdamW(bank.parameters(), lr=1e-3, weight_decay=0.01)
    BANK_EPOCHS = 20
    BANK_BATCH = 256

    for epoch in range(BANK_EPOCHS):
        bank.train()
        perm = torch.randperm(n_train, device=DEVICE)
        total_loss = 0
        stats = {"expert_agreement": 0, "rotation_ortho": 0,
                 "anchor_spread": 0, "bank_cv": 0, "emb_cv": 0,
                 "cross_expert_var": 0, "disagree_preserve": 0}
        n = 0
        t0 = time.time()
        for i in range(0, n_train, BANK_BATCH):
            idx = perm[i:i+BANK_BATCH]
            if len(idx) < 16:
                continue
            emb = student_embs[idx]
            enriched, aux = bank(emb)
            loss = bank.bank_loss(aux)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(bank.parameters(), 1.0)
            bank_opt.step()
            bank_opt.zero_grad(set_to_none=True)
            total_loss += loss.item()
            for k in stats:
                if k in aux:
                    v = aux[k]
                    stats[k] += v.item() if torch.is_tensor(v) else v
            n += 1
        elapsed = time.time() - t0
        d = max(n, 1)
        bank.eval()
        with torch.no_grad():
            v_enriched, v_aux = bank(val_student_embs)
            v_loss = bank.bank_loss(v_aux).item()
        print(f"\n E{epoch+1:2d}: {elapsed:.0f}s loss={total_loss/d:.4f} v_loss={v_loss:.4f}")
        print(f" Geometry: b_cv={stats['bank_cv']/d:.4f} e_cv={stats['emb_cv']/d:.4f} "
              f"spread={stats['anchor_spread']/d:.5f} a_max={v_aux['anchor_max_cos']:.3f}")
        print(f" Experts: cos={v_aux['expert_cos_mean']:.3f}±{v_aux['expert_cos_std']:.3f} "
              f"agr={stats['expert_agreement']/d:.6f} ortho={stats['rotation_ortho']/d:.6f}")
        print(f" Disagree: x_cos={v_aux.get('cross_expert_cos', 0):.4f}±{v_aux.get('cross_expert_cos_std', 0):.4f} "
              f"ratio={v_aux['disagreement_ratio']:.6f} "
              f"preserve={stats['disagree_preserve']/d:.6f} "
              f"norms={v_aux['norm_ratio_spread']:.4f}")
    torch.save(bank.state_dict(), "alignment_bank.pt")

    # ── Phase 3: Geometric Verification ──
    print(f"\n{'='*65}")
    print("PHASE 3: GEOMETRIC VERIFICATION")
    print(f"{'='*65}")
    bank.eval()
    with torch.no_grad():
        enriched_val, v_aux = bank(val_student_embs)
        # enriched = [input embedding | geometric context]
        original_768 = enriched_val[:, :bank.d_embed]
        geo_context = enriched_val[:, bank.d_embed:]
        passthrough_cos = F.cosine_similarity(
            original_768[:100], val_student_embs[:100], dim=-1).mean().item()
        geo_cv = cv_metric(F.normalize(geo_context[:1000], dim=-1))
        S = torch.linalg.svdvals(
            geo_context[:1000].float() - geo_context[:1000].float().mean(0))
        geo_eff_dim = float((S.sum() ** 2) / (S.pow(2).sum() + 1e-12))
        # Verify consensus stats are preserved.
        emb_cv = cv_metric(val_student_embs[:1000])
    print(f" Passthrough: {passthrough_cos:.6f} (target: 1.000)")
    print(f" Emb CV: {emb_cv:.4f} (consensus: {consensus_stats['cv']:.4f})")
    print(f" Geo context CV: {geo_cv:.4f}")
    print(f" Geo eff_dim: {geo_eff_dim:.1f} / {bank.d_bank}")
    print(f" Expert cos: {v_aux['expert_cos_mean']:.3f} ± {v_aux['expert_cos_std']:.3f}")
    print(f" Anchor max cos: {v_aux['anchor_max_cos']:.3f}")
    print(f" Disagreement:")
    print(f" Cross-expert: {v_aux.get('cross_expert_cos', 0):.4f} ± {v_aux.get('cross_expert_cos_std', 0):.4f}")
    print(f" Ratio: {v_aux['disagreement_ratio']:.6f} (target: {bank.target_disagreement_ratio.item():.6f})")
    print(f" Norm spread: {v_aux['norm_ratio_spread']:.4f}")

    # ── Phase 4: Classifier Stability Test ──
    print(f"\n{'='*65}")
    print("PHASE 4: CLASSIFIER STABILITY TEST")
    print(f"{'='*65}")
    with torch.no_grad():
        embs = val_student_embs[:1000]
        n_probe = embs.shape[0]
        sim = embs @ embs.T
        sim.fill_diagonal_(-1)
        n_pairs = 3000
        idx_a = torch.randint(0, n_probe, (n_pairs,))
        idx_b = torch.randint(0, n_probe, (n_pairs,))
        pair_cos = sim[idx_a, idx_b]
        # Three equal-frequency similarity classes from the pair-cosine tertiles.
        sorted_cos, _ = pair_cos.sort()
        t1 = sorted_cos[n_pairs // 3].item()
        t2 = sorted_cos[2 * n_pairs // 3].item()
        labels = torch.zeros(n_pairs, dtype=torch.long, device=DEVICE)
        labels[pair_cos > t2] = 0
        labels[(pair_cos <= t2) & (pair_cos > t1)] = 1
        labels[pair_cos <= t1] = 2
        enriched_a, _ = bank(embs[idx_a])
        enriched_b, _ = bank(embs[idx_b])

    # Classifier training runs outside no_grad so backward() works.
    for mode in ["with_bank", "without_bank"]:
        if mode == "with_bank":
            features = torch.cat([enriched_a, enriched_b], dim=-1)
        else:
            features = torch.cat([embs[idx_a], embs[idx_b]], dim=-1)
        feat_dim = features.shape[-1]
        clf = nn.Sequential(
            nn.Linear(feat_dim, 256), nn.GELU(), nn.LayerNorm(256),
            nn.Linear(256, 3)
        ).to(DEVICE)
        clf_opt = torch.optim.Adam(clf.parameters(), lr=1e-3)
        n_clf_train = 2400
        train_f = features[:n_clf_train].detach()
        train_l = labels[:n_clf_train]
        val_f = features[n_clf_train:].detach()
        val_l = labels[n_clf_train:]
        for e in range(30):
            clf.train()
            logits = clf(train_f)
            loss = F.cross_entropy(logits, train_l)
            loss.backward()
            clf_opt.step()
            clf_opt.zero_grad()
        clf.eval()
        with torch.no_grad():
            v_acc = (clf(val_f).argmax(-1) == val_l).float().mean().item()
            t_acc = (clf(train_f).argmax(-1) == train_l).float().mean().item()
        print(f" {mode:15s}: train={t_acc:.3f} val={v_acc:.3f} gap={t_acc-v_acc:.3f}")

    print(f"\n{'='*65}")
    print("SUMMARY")
    print(f"{'='*65}")
    print(f" Consensus CV: {consensus_stats['cv']:.4f}")
    print(f" Consensus eff_dim:{consensus_stats['eff_dim']:.1f}")
    print(f" Student v_cos: {v_cos:.3f}")
    print(f" Student v_cv: {v_cv:.3f}")
    print(f" Bank params: {bank_params:,}")
    print(f" Bank geo_eff_dim: {geo_eff_dim:.1f}")
    print(f" Bank geo_cv: {geo_cv:.4f}")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
if __name__ == "__main__":
    # Execute the full pipeline only when run as a script, not when imported.
    run()