| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import gc |
| | import math |
| | import os |
| | import time |
| | import json |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from tqdm import tqdm |
| |
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Teacher encoders: (Hugging Face model id, short name, tokenizer max_length).
EXPERTS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 512),
]

# Start-up banner; emitted once when the module is loaded.
print(
    "=" * 65,
    "RAPID PROTOTYPE: 2-Expert Consensus + Alignment Bank",
    "=" * 65,
    f" Device: {DEVICE}",
    sep="\n",
)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class MiniStudent(nn.Module):
    """Small pre-norm Transformer encoder that emits one unit-norm vector per sequence.

    Token and learned position embeddings are summed, layer-normalised and
    passed through ``n_layers`` encoder layers. Non-padding positions are
    mean-pooled, projected to ``output_dim`` and L2-normalised.

    Args:
        vocab_size: Size of the token embedding table.
        max_len: Size of the position embedding table (longest accepted sequence).
        d_model: Hidden width of the encoder.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of encoder layers.
        d_ff: Feed-forward width inside each encoder layer.
        output_dim: Width of the returned embedding.
        dropout: Dropout probability for embeddings and encoder layers.
        pad_token_id: Token id treated as padding.
    """

    def __init__(self, vocab_size=30522, max_len=512, d_model=256,
                 n_heads=4, n_layers=4, d_ff=1024, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers, enable_nested_tensor=False)
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(),
            nn.LayerNorm(d_model), nn.Linear(d_model, output_dim))

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids into L2-normalised embeddings.

        Args:
            input_ids: 2-D integer tensor of token ids, (batch, seq_len).
            attention_mask: Optional tensor of the same shape, non-zero on real
                tokens. When omitted, positions equal to ``pad_token_id`` are
                treated as padding.

        Returns:
            Tensor of shape (batch, output_dim) with unit L2 norm per row.

        Raises:
            ValueError: If ``seq_len`` exceeds the position embedding table.
        """
        _, seq_len = input_ids.shape
        max_len = self.pos_emb.num_embeddings
        if seq_len > max_len:
            # Fail early with a readable message instead of an out-of-range
            # embedding lookup (which surfaces as a device-side assert on CUDA).
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}")

        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.emb_drop(self.emb_norm(x))

        # True marks padding positions that attention must ignore.
        if attention_mask is not None:
            kpm = ~attention_mask.bool()
            mask = attention_mask.unsqueeze(-1).float()
        else:
            kpm = input_ids == self.pad_token_id
            mask = (~kpm).unsqueeze(-1).float()

        x = self.encoder(x, src_key_padding_mask=kpm)
        # Masked mean over the sequence; the clamp avoids division by zero.
        pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(pooled), dim=-1)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class AlignmentBank(nn.Module):
    """Geometric interface layer applied to (frozen) student embeddings.

    For every input embedding the bank computes a round-trip consistency
    score through each per-expert matrix and the cosine similarity to a set
    of learned anchors, projects those features to a ``d_bank``-dimensional
    context vector, and concatenates that vector onto the unchanged input.

    Args:
        d_embed: Width of the incoming embeddings.
        n_experts: Number of per-expert rotation matrices / mean vectors.
        n_anchors: Number of learned anchor directions.
        d_bank: Width of the geometric context appended to each embedding.
    """

    def __init__(self, d_embed=768, n_experts=2, n_anchors=128, d_bank=64):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank

        # One learnable (d_embed x d_embed) matrix per expert, identity at start.
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])

        # Per-expert mean vectors; stored and loadable, not read by forward().
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)
        ])

        # Learned anchor directions, initialised as random unit vectors.
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(n_anchors, d_embed), dim=-1))

        # Features: per-expert cosine + per-anchor cosine + per-expert MSE.
        geo_dim = n_experts + n_anchors + n_experts
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 2),
            nn.GELU(),
            nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank),
            nn.LayerNorm(d_bank),
        )

    def init_from_procrustes(self, procrustes_results, expert_names,
                             consensus_embeddings=None):
        """Initialize rotations, means and anchors from consensus artifacts.

        Args:
            procrustes_results: Mapping from expert name to a dict holding
                ``"rotation"``, ``"source_mean"`` and ``"cos_after"``.
            expert_names: Expert names; only the first ``n_experts`` are used.
            consensus_embeddings: Optional (N, d_embed) tensor; evenly spaced
                rows are normalised and written into the first anchors.

        Raises:
            ValueError: If a rotation or mean has the wrong shape.
        """
        device = self.anchors.device
        # Copy values into the existing parameters instead of rebinding
        # ``.data``: rebinding can make a parameter share storage with the
        # caller's tensor, so later training would silently modify
        # ``procrustes_results``.
        with torch.no_grad():
            for i, name in enumerate(expert_names[:self.n_experts]):
                info = procrustes_results[name]
                rotation = info["rotation"].float().to(device)
                mean = info["source_mean"].float().to(device)
                if rotation.shape != self.expert_rotations[i].shape:
                    raise ValueError(
                        f"rotation for {name!r} has shape {tuple(rotation.shape)}, "
                        f"expected {tuple(self.expert_rotations[i].shape)}")
                if mean.shape != self.expert_means[i].shape:
                    raise ValueError(
                        f"source_mean for {name!r} has shape {tuple(mean.shape)}, "
                        f"expected {tuple(self.expert_means[i].shape)}")
                self.expert_rotations[i].copy_(rotation)
                self.expert_means[i].copy_(mean)
                print(f" Expert {i} ({name}): rotation loaded, cos_after={info['cos_after']:.4f}")

            if consensus_embeddings is not None:
                n = min(self.n_anchors, consensus_embeddings.shape[0])
                # Evenly spaced rows across the whole consensus set.
                indices = torch.linspace(0, consensus_embeddings.shape[0] - 1, n).long()
                self.anchors[:n] = F.normalize(
                    consensus_embeddings[indices].float(), dim=-1).to(device)
                print(f" Anchors: {n} initialized from consensus embeddings")

    @staticmethod
    def _sampled_volume_cv(points, n_draws=32, simplex_size=5):
        """Return std/mean of simplex volumes over random row subsets of ``points``.

        Each draw picks ``simplex_size`` rows and evaluates their squared
        simplex volume with the Cayley-Menger determinant.
        """
        n_rows = points.shape[0]
        volumes = []
        for _ in range(n_draws):
            idx = torch.randperm(n_rows, device=points.device)[:simplex_size]
            pts = points[idx].unsqueeze(0)
            diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
            d2 = (diff * diff).sum(-1)
            n_sets, n_vert, _ = d2.shape
            # Squared-distance matrix bordered by a row/column of ones.
            cm = torch.zeros(n_sets, n_vert + 1, n_vert + 1,
                             device=d2.device, dtype=torch.float32)
            cm[:, 0, 1:] = 1
            cm[:, 1:, 0] = 1
            cm[:, 1:, 1:] = d2
            sign = (-1.0) ** n_vert
            fact = math.factorial(n_vert - 1)
            vol_sq = sign / ((2.0 ** (n_vert - 1)) * fact * fact) * torch.linalg.det(cm)
            # relu + epsilon keeps the square root defined for tiny negatives.
            volumes.append(torch.sqrt(F.relu(vol_sq[0]) + 1e-12))
        stacked = torch.stack(volumes)
        return stacked.std() / (stacked.mean() + 1e-8)

    def forward(self, embedding):
        """Annotate embeddings with geometric context.

        Args:
            embedding: (B, d_embed) tensor; expected to be L2-normalised.

        Returns:
            enriched: (B, d_embed + d_bank) tensor, the input concatenated
                with its geometric context.
            aux: dict with loss terms (tensors: ``expert_agreement``,
                ``rotation_ortho``, ``anchor_spread``, ``anchor_entropy``,
                ``bank_cv``) and float diagnostics (``expert_cos_mean``,
                ``expert_cos_std``, ``anchor_max_cos``, ``anchor_mean_cos``).
        """
        batch_size = embedding.shape[0]
        emb = embedding.float()
        # Built once per call rather than once per expert.
        identity = torch.eye(self.d_embed, device=self.anchors.device)

        # Per expert: map in, map back, measure how much the embedding survives.
        cos_terms = []
        mse_terms = []
        ortho_loss = 0.0
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            round_trip = (emb @ R) @ R.T
            cos_terms.append(F.cosine_similarity(emb, round_trip, dim=-1))
            mse_terms.append((emb - round_trip).pow(2).mean(dim=-1))
            # Penalise deviation of R @ R.T from the identity.
            ortho_loss = ortho_loss + (R @ R.T - identity).pow(2).mean()

        expert_cos = torch.stack(cos_terms, dim=-1)
        expert_mse = torch.stack(mse_terms, dim=-1)

        # Cosine of each embedding to every (re-normalised) anchor.
        anchors_n = F.normalize(self.anchors, dim=-1)
        anchor_cos = emb @ anchors_n.T

        geo_input = torch.cat([expert_cos, anchor_cos, expert_mse], dim=-1)
        geo_context = self.geo_proj(geo_input)

        # The original embedding is passed through untouched.
        enriched = torch.cat([embedding, geo_context], dim=-1)

        aux = {}

        # Variance of the consistency score across experts.
        expert_mean = expert_cos.mean(dim=-1, keepdim=True)
        aux["expert_agreement"] = (expert_cos - expert_mean).pow(2).mean()

        aux["rotation_ortho"] = ortho_loss / self.n_experts

        # Mean squared off-diagonal similarity between anchors.
        anchor_sim = anchors_n @ anchors_n.T
        anchor_sim.fill_diagonal_(0)
        aux["anchor_spread"] = anchor_sim.pow(2).mean()

        # Entropy of the temperature-sharpened anchor assignment.
        anchor_probs = F.softmax(anchor_cos * 10, dim=-1)
        entropy = -(anchor_probs * (anchor_probs + 1e-12).log()).sum(-1).mean()
        aux["anchor_entropy"] = entropy

        # Volume statistic of the context vectors; skipped for small batches.
        if batch_size >= 10:
            aux["bank_cv"] = self._sampled_volume_cv(
                F.normalize(geo_context, dim=-1))
        else:
            aux["bank_cv"] = torch.tensor(0.0, device=embedding.device)

        # Plain-float diagnostics for logging only.
        aux["expert_cos_mean"] = expert_cos.mean().item()
        aux["expert_cos_std"] = expert_cos.std().item()
        aux["anchor_max_cos"] = anchor_cos.max(dim=-1).values.mean().item()
        aux["anchor_mean_cos"] = anchor_cos.mean().item()

        return enriched, aux

    def bank_loss(self, aux, cv_target=0.15):
        """Combine the loss terms produced by ``forward`` into one scalar.

        Args:
            aux: The ``aux`` dict returned by ``forward``.
            cv_target: Desired value of ``aux["bank_cv"]``.

        Returns:
            Weighted sum of the five loss terms.
        """
        loss = (1.0 * aux["expert_agreement"] +
                1.0 * aux["rotation_ortho"] +
                0.5 * aux["anchor_spread"] +
                0.1 * aux["anchor_entropy"] +
                0.3 * (aux["bank_cv"] - cv_target).abs())
        return loss
| |
|
| |
|
| | |
| | |
| | |
| |
|
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE loss between two row-aligned batches.

    Row ``i`` of ``a`` is the positive for row ``i`` of ``b``; every other
    row in the batch acts as a negative.

    Args:
        a: (N, D) tensor.
        b: (N, D) tensor.
        temperature: Divisor applied to the cosine similarities.

    Returns:
        (loss, acc): the mean of the a->b and b->a cross-entropies, and the
        fraction of rows of ``a`` whose best match is the paired row (float).
    """
    similarity = F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).T
    logits = similarity / temperature
    targets = torch.arange(logits.shape[0], device=logits.device)

    forward_ce = F.cross_entropy(logits, targets)
    backward_ce = F.cross_entropy(logits.T, targets)
    loss = (forward_ce + backward_ce) / 2

    with torch.no_grad():
        hits = logits.argmax(-1) == targets
        acc = hits.float().mean().item()
    return loss, acc
| |
|
def cayley_menger_vol2(pts):
    """Squared simplex volume via the Cayley-Menger determinant.

    Args:
        pts: (B, V, D) tensor; each batch entry holds the V vertices of one
            simplex.

    Returns:
        (B,) float32 tensor of squared volumes. Values may be slightly
        negative for degenerate simplices because of round-off.
    """
    pts = pts.float()
    # Pairwise squared distances between the vertices of each simplex.
    sq_dist = (pts.unsqueeze(-2) - pts.unsqueeze(-3)).square().sum(-1)
    n_sets, n_vert, _ = sq_dist.shape

    # Bordered matrix: zero corner, ones on the first row/column, squared
    # distances in the remaining block.
    size = n_vert + 1
    bordered = torch.ones(n_sets, size, size,
                          device=sq_dist.device, dtype=torch.float32)
    bordered[:, 0, 0] = 0
    bordered[:, 1:, 1:] = sq_dist

    sign = (-1.0) ** n_vert
    fact = math.factorial(n_vert - 1)
    scale = sign / ((2.0 ** (n_vert - 1)) * fact * fact)
    return scale * torch.linalg.det(bordered)
| |
|
def cv_loss(emb, target=0.12, n_samples=16):
    """Distance between a sampled volume-CV estimate and ``target``.

    Draws ``n_samples`` random 5-row subsets of ``emb``, takes the volume of
    each 5-point simplex, and compares the coefficient of variation
    (std / mean) of those volumes with ``target``.

    Args:
        emb: (B, D) tensor of embeddings.
        target: Desired coefficient of variation.
        n_samples: Number of random subsets to draw.

    Returns:
        Scalar tensor ``|cv - target|``; a zero tensor when ``B < 5``.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return torch.tensor(0.0, device=emb.device)

    def _one_volume():
        pick = torch.randperm(n_rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[pick].unsqueeze(0))
        # relu + epsilon keeps the square root defined for tiny negatives.
        return torch.sqrt(F.relu(vol_sq[0]) + 1e-12)

    volumes = torch.stack([_one_volume() for _ in range(n_samples)])
    coeff_var = volumes.std() / (volumes.mean() + 1e-8)
    return (coeff_var - target).abs()
| |
|
def cv_metric(emb, n=200):
    """Coefficient of variation of random 5-point simplex volumes, as a float.

    Args:
        emb: (B, D) tensor of embeddings.
        n: Number of random 5-row subsets to evaluate.

    Returns:
        ``std / mean`` of the kept volumes, or ``0.0`` when ``B < 5`` or
        fewer than 10 volumes were kept.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return 0.0

    kept = []
    for _ in range(n):
        pick = torch.randperm(n_rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[pick].unsqueeze(0))
        volume = torch.sqrt(F.relu(vol_sq[0]) + 1e-12).item()
        # The epsilon keeps finite volumes positive, so this comparison only
        # rejects NaN results.
        if volume > 0:
            kept.append(volume)

    if len(kept) < 10:
        return 0.0
    values = np.array(kept)
    return float(values.std() / (values.mean() + 1e-8))
| |
|
| |
|
| | |
| | |
| | |
| |
|
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Inverse matrix square root of a symmetric matrix.

    Args:
        cov: Symmetric (D, D) tensor, e.g. a covariance matrix.
        eps: Lower bound applied to the eigenvalues before inversion.

    Returns:
        (D, D) tensor ``V diag(clamp(w, eps) ** -0.5) V^T`` built from the
        eigendecomposition ``cov = V diag(w) V^T``.
    """
    eigvals, eigvecs = torch.linalg.eigh(cov)
    # Flooring the spectrum keeps the reciprocal root finite.
    inv_root = eigvals.clamp(min=eps).rsqrt()
    return eigvecs @ torch.diag(inv_root) @ eigvecs.T
| |
|
def procrustes_align(source, target, n_align=5000):
    """Fit a whitened orthogonal Procrustes map from ``source`` onto ``target``.

    Both point sets are centred, whitened with their own covariance,
    row-normalised, and the rotation is taken from the SVD of their
    cross-product.

    Args:
        source: (N, D) tensor of embeddings to be aligned.
        target: (M, D) tensor of reference embeddings, row-paired with
            ``source``.
        n_align: Maximum number of leading rows used for the fit.

    Returns:
        dict with ``rotation``, ``source_mean``, ``source_whitener``,
        ``target_unwhitener`` (pseudo-inverse of the target whitener),
        and the mean paired cosines ``cos_before`` / ``cos_after``.
    """
    n_used = min(n_align, source.shape[0], target.shape[0])
    src = source[:n_used].float()
    tgt = target[:n_used].float()

    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean

    cos_before = F.cosine_similarity(src_c, tgt_c, dim=-1).mean().item()

    # Both covariances share the same denominator.
    dof = max(src_c.shape[0] - 1, 1)
    src_whitener = symmetric_inv_sqrt((src_c.T @ src_c) / dof)
    tgt_whitener = symmetric_inv_sqrt((tgt_c.T @ tgt_c) / dof)

    src_w = F.normalize(src_c @ src_whitener, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whitener, dim=-1)

    # The rotation is applied to row vectors as ``x @ rotation.T``.
    left, _, right_t = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    rotation = left @ right_t

    cos_after = F.cosine_similarity(src_w @ rotation.T, tgt_w, dim=-1).mean().item()
    return {
        "rotation": rotation,
        "source_mean": src_mean.squeeze(0),
        "source_whitener": src_whitener,
        "target_unwhitener": torch.linalg.pinv(tgt_whitener),
        "cos_before": cos_before,
        "cos_after": cos_after,
    }
| |
|
def apply_align(emb, a):
    """Apply a fitted alignment to embeddings.

    Args:
        emb: (N, D) tensor of source-space embeddings.
        a: dict providing ``source_mean``, ``source_whitener``, ``rotation``
            and ``target_unwhitener``.

    Returns:
        (N, D) float tensor: centred, whitened, rotated, then un-whitened
        with the target transform. No target mean is added back.
    """
    centered = emb.float() - a["source_mean"]
    whitened = centered @ a["source_whitener"]
    rotated = whitened @ a["rotation"].T
    return rotated @ a["target_unwhitener"]
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _banner(title):
    """Print a section title framed by two 65-character rules."""
    rule = "=" * 65
    print(f"\n{rule}")
    print(title)
    print(rule)


def _load_captions(n_samples, min_chars=50):
    """Stream the caption dataset and keep up to n_samples captions longer than min_chars."""
    from datasets import load_dataset

    ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                      split="train", streaming=True)
    captions = []
    for row in ds:
        cap = row.get("caption_llava", "")
        if isinstance(cap, str) and len(cap) > min_chars:
            captions.append(cap)
        if len(captions) >= n_samples:
            break
    return captions


def _extract_expert_embeddings(captions, batch_size=128):
    """Return {short_name: (N, D) CPU tensor} of mask-mean-pooled outputs for each expert."""
    from transformers import AutoModel, AutoTokenizer

    embeds = {}
    for model_name, short, max_len in EXPERTS:
        print(f"\n Extracting: {short}...")
        model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        chunks = []
        with torch.no_grad():
            for start in tqdm(range(0, len(captions), batch_size), desc=f" {short}"):
                batch = captions[start:start + batch_size]
                inputs = tokenizer(batch, max_length=max_len, padding=True,
                                   truncation=True, return_tensors="pt").to(DEVICE)
                out = model(**inputs)
                m = inputs.attention_mask.unsqueeze(-1).float()
                pooled = (out.last_hidden_state * m).sum(1) / m.sum(1).clamp(min=1)
                chunks.append(pooled.cpu())
        embeds[short] = torch.cat(chunks)
        print(f" Shape: {embeds[short].shape}")
        # Release the expert before loading the next one.
        del model
        gc.collect()
        torch.cuda.empty_cache()
    return embeds


def _build_consensus(embeds, ref="bert"):
    """Align every expert to ``ref`` and return (names, procrustes_results, consensus, consensus_cv)."""
    names = [short for _, short, _ in EXPERTS]
    procrustes_results = {}
    aligned = {}
    for name in names:
        info = procrustes_align(embeds[name], embeds[ref])
        procrustes_results[name] = info
        aligned[name] = apply_align(embeds[name], info)
        print(f" {name:10s}: cos {info['cos_before']:.4f} → {info['cos_after']:.4f}")

    # Consensus target: normalised mean of the aligned expert embeddings.
    consensus = F.normalize(sum(aligned[n] for n in names) / len(names), dim=-1)
    print(f" Consensus: {consensus.shape}")
    for name in names:
        cos = F.cosine_similarity(consensus[:2000], aligned[name][:2000], dim=-1).mean().item()
        print(f" cos(consensus, {name}): {cos:.4f}")

    consensus_cv = cv_metric(consensus[:2000].to(DEVICE))
    print(f" Consensus CV: {consensus_cv:.4f}")
    return names, procrustes_results, consensus, consensus_cv


def _train_student(student, train, val, cv_target, batch_size, epochs=5):
    """Train the student against consensus targets; return the last (v_cos, v_cv)."""
    train_ids, train_mask, train_targets = train
    val_ids, val_mask, val_targets = val
    n_train = train_ids.shape[0]
    optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=0.01)
    v_cos = v_cv = float("nan")

    for epoch in range(epochs):
        student.train()
        perm = torch.randperm(n_train, device=train_ids.device)
        t_loss, t_acc, t_cos, n = 0, 0, 0, 0
        t0 = time.time()

        for i in range(0, n_train, batch_size):
            idx = perm[i:i + batch_size]
            if len(idx) < 8:
                # Skip a tiny trailing batch.
                continue
            emb = student(train_ids[idx], train_mask[idx])
            tgt = train_targets[idx]
            l_nce, acc = infonce(emb, tgt)
            l_mse = F.mse_loss(emb, tgt)
            l_cv = cv_loss(emb, target=cv_target)
            loss = l_nce + l_mse + 0.1 * l_cv
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            with torch.no_grad():
                cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
            t_loss += loss.item()
            t_acc += acc
            t_cos += cos
            n += 1

        elapsed = time.time() - t0
        d = max(n, 1)
        student.eval()
        with torch.no_grad():
            v_emb = student(val_ids, val_mask)
            _, v_acc = infonce(v_emb[:1000], val_targets[:1000])
            v_cos = F.cosine_similarity(v_emb, val_targets, dim=-1).mean().item()
            v_cv = cv_metric(v_emb[:1000])

        print(f" E{epoch+1}: {elapsed:.0f}s loss={t_loss/d:.4f} "
              f"t_acc={t_acc/d:.3f} t_cos={t_cos/d:.3f} "
              f"v_acc={v_acc:.3f} v_cos={v_cos:.3f} v_cv={v_cv:.3f}")

    return v_cos, v_cv


def _train_bank(bank, student_embs, val_student_embs, cv_target,
                epochs=20, batch_size=256):
    """Train the alignment bank on pre-computed (frozen) student embeddings."""
    n_train = student_embs.shape[0]
    bank_opt = torch.optim.AdamW(bank.parameters(), lr=1e-3, weight_decay=0.01)

    for epoch in range(epochs):
        bank.train()
        perm = torch.randperm(n_train, device=student_embs.device)
        total_loss = 0
        stats = {"expert_agreement": 0, "rotation_ortho": 0,
                 "anchor_spread": 0, "bank_cv": 0}
        n = 0
        t0 = time.time()

        for i in range(0, n_train, batch_size):
            idx = perm[i:i + batch_size]
            if len(idx) < 16:
                continue

            _, aux = bank(student_embs[idx])
            loss = bank.bank_loss(aux, cv_target=cv_target)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(bank.parameters(), 1.0)
            bank_opt.step()
            bank_opt.zero_grad(set_to_none=True)

            total_loss += loss.item()
            for k in stats:
                if k in aux:
                    v = aux[k]
                    stats[k] += v.item() if torch.is_tensor(v) else v
            n += 1

        elapsed = time.time() - t0
        d = max(n, 1)

        bank.eval()
        with torch.no_grad():
            _, v_aux = bank(val_student_embs)
            v_loss = bank.bank_loss(v_aux, cv_target=cv_target).item()

        print(f" E{epoch+1:2d}: {elapsed:.0f}s loss={total_loss/d:.4f} "
              f"v_loss={v_loss:.4f} "
              f"expert_agr={stats['expert_agreement']/d:.5f} "
              f"ortho={stats['rotation_ortho']/d:.5f} "
              f"spread={stats['anchor_spread']/d:.5f} "
              f"cv={stats['bank_cv']/d:.4f} "
              f"anchor_max={v_aux['anchor_max_cos']:.3f} "
              f"expert_cos={v_aux['expert_cos_mean']:.3f}±{v_aux['expert_cos_std']:.3f}")


def _verify_geometry(bank, val_student_embs):
    """Report passthrough integrity and statistics of the bank's geometric context."""
    bank.eval()
    with torch.no_grad():
        enriched_val, _ = bank(val_student_embs)
        # Split at the bank's own embedding width instead of a literal 768.
        original = enriched_val[:, :bank.d_embed]
        geo_context = enriched_val[:, bank.d_embed:]

        passthrough_cos = F.cosine_similarity(
            original[:100], val_student_embs[:100], dim=-1).mean().item()

        geo_cv = cv_metric(F.normalize(geo_context[:1000], dim=-1))
        head = geo_context[:1000].float()
        spectrum = torch.linalg.svdvals(head - head.mean(0)).pow(2)
        # Participation ratio of the squared singular values.
        geo_eff_dim = (spectrum.sum() ** 2) / (spectrum.pow(2).sum() + 1e-12)

    print(f" Passthrough integrity: {passthrough_cos:.6f} (should be ~1.000)")
    print(f" Geo context CV: {geo_cv:.4f}")
    print(f" Geo context eff_dim: {geo_eff_dim:.1f}")
    print(f" Geo context shape: {geo_context.shape}")


def _classifier_stability_test(bank, val_student_embs, n_pairs=3000,
                               n_clf_train=2400):
    """Train a small pair classifier with and without bank features and print both gaps."""
    with torch.no_grad():
        embs = val_student_embs[:1000]
        n_pool = embs.shape[0]
        sim = embs @ embs.T
        sim.fill_diagonal_(-1)

        idx_a = torch.randint(0, n_pool, (n_pairs,))
        idx_b = torch.randint(0, n_pool, (n_pairs,))
        pair_cos = sim[idx_a, idx_b]

        # Three similarity buckets split at the terciles of the sampled pairs;
        # class 0 (most similar) is the zero-initialised default.
        sorted_cos, _ = pair_cos.sort()
        t1 = sorted_cos[n_pairs // 3].item()
        t2 = sorted_cos[2 * n_pairs // 3].item()
        labels = torch.zeros(n_pairs, dtype=torch.long, device=pair_cos.device)
        labels[(pair_cos <= t2) & (pair_cos > t1)] = 1
        labels[pair_cos <= t1] = 2

        enriched_a, _ = bank(embs[idx_a])
        enriched_b, _ = bank(embs[idx_b])

    for mode in ["with_bank", "without_bank"]:
        if mode == "with_bank":
            features = torch.cat([enriched_a, enriched_b], dim=-1)
        else:
            features = torch.cat([embs[idx_a], embs[idx_b]], dim=-1)
        feat_dim = features.shape[-1]

        clf = nn.Sequential(
            nn.Linear(feat_dim, 128), nn.GELU(),
            nn.Linear(128, 3)
        ).to(features.device)

        clf_opt = torch.optim.Adam(clf.parameters(), lr=1e-3)
        train_f = features[:n_clf_train].detach()
        train_l = labels[:n_clf_train]
        val_f = features[n_clf_train:].detach()
        val_l = labels[n_clf_train:]

        # Full-batch training for a fixed number of steps.
        for _ in range(20):
            clf.train()
            loss = F.cross_entropy(clf(train_f), train_l)
            loss.backward()
            clf_opt.step()
            clf_opt.zero_grad()

        clf.eval()
        with torch.no_grad():
            val_acc = (clf(val_f).argmax(-1) == val_l).float().mean().item()
            train_acc = (clf(train_f).argmax(-1) == train_l).float().mean().item()

        print(f" {mode:15s}: train_acc={train_acc:.3f} val_acc={val_acc:.3f} "
              f"gap={train_acc-val_acc:.3f}")


def run():
    """Run the full prototype: extraction, alignment, student, bank, checks.

    Phases: (0) stream captions and embed them with every expert,
    (0b) Procrustes-align the experts and build consensus targets,
    (1) train the student and save ``mini_student.pt``,
    (2) train the alignment bank on frozen student output and save
    ``alignment_bank.pt``, (3) verify the bank's geometry, and
    (4) compare a pair classifier with and without bank features.

    Raises:
        RuntimeError: If the stream yields too few captions to leave a
            non-empty training split after the validation hold-out.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    N_SAMPLES = 20000
    MAX_LEN = 128
    BATCH = 256
    N_VAL = 2000

    _banner("PHASE 0: EXTRACTION")
    from transformers import AutoTokenizer

    captions = _load_captions(N_SAMPLES)
    print(f" Captions: {len(captions):,}")
    if len(captions) <= N_VAL:
        raise RuntimeError(
            f"only {len(captions)} captions collected; need more than {N_VAL}")
    embeds = _extract_expert_embeddings(captions)

    _banner("PHASE 0b: PROCRUSTES ALIGNMENT")
    names, procrustes_results, consensus, consensus_cv = _build_consensus(embeds)
    del embeds
    gc.collect()
    torch.cuda.empty_cache()

    _banner("PHASE 1: TRAIN STUDENT (2 experts, 20K captions)")
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    tokens = tokenizer(captions, max_length=MAX_LEN, padding="max_length",
                       truncation=True, return_tensors="pt")
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]

    # Split on the number of captions actually collected, not the requested
    # count, so a short stream cannot push indices out of range.
    n_train = len(captions) - N_VAL
    train_ids = input_ids[:n_train].to(DEVICE)
    train_mask = attention_mask[:n_train].to(DEVICE)
    train_targets = consensus[:n_train].to(DEVICE)
    val_ids = input_ids[n_train:].to(DEVICE)
    val_mask = attention_mask[n_train:].to(DEVICE)
    val_targets = consensus[n_train:].to(DEVICE)

    student = MiniStudent(
        vocab_size=tokenizer.vocab_size, max_len=MAX_LEN,
        d_model=256, n_heads=4, n_layers=4, d_ff=1024,
        output_dim=768, dropout=0.1, pad_token_id=tokenizer.pad_token_id
    ).to(DEVICE)
    n_params = sum(p.numel() for p in student.parameters())
    print(f" Student: {n_params:,} params")

    v_cos, v_cv = _train_student(
        student,
        (train_ids, train_mask, train_targets),
        (val_ids, val_mask, val_targets),
        cv_target=consensus_cv, batch_size=BATCH, epochs=5)

    torch.save(student.state_dict(), "mini_student.pt")
    print(f"\n Student saved. v_cos={v_cos:.3f}, v_cv={v_cv:.3f}")

    _banner("PHASE 2: TRAIN ALIGNMENT BANK (student frozen)")
    student.eval()
    for p in student.parameters():
        p.requires_grad = False

    print(" Pre-encoding through frozen student...")
    with torch.no_grad():
        student_embs = torch.cat([
            student(train_ids[i:i + 512], train_mask[i:i + 512])
            for i in range(0, n_train, 512)
        ])
        val_student_embs = student(val_ids, val_mask)
    print(f" Student embeddings: {student_embs.shape}")

    bank = AlignmentBank(
        d_embed=768, n_experts=len(EXPERTS),
        n_anchors=128, d_bank=64
    ).to(DEVICE)
    bank.init_from_procrustes(procrustes_results, names, consensus[:n_train])
    bank_params = sum(p.numel() for p in bank.parameters())
    print(f" Bank: {bank_params:,} params")

    _train_bank(bank, student_embs, val_student_embs,
                cv_target=consensus_cv + 0.02, epochs=20, batch_size=256)
    torch.save(bank.state_dict(), "alignment_bank.pt")

    _banner("PHASE 3: GEOMETRIC VERIFICATION")
    _verify_geometry(bank, val_student_embs)

    _banner("PHASE 4: CLASSIFIER STABILITY TEST")
    _classifier_stability_test(bank, val_student_embs)

    _banner("DONE")
    print(f"\n Student: mini_student.pt")
    print(f" Bank: alignment_bank.pt")
    print(f" Consensus CV: {consensus_cv:.4f}")
    print(f" Student v_cos: {v_cos:.3f}")
| |
|
| |
|
# Script entry point: execute the whole pipeline only when run directly,
# not when the module is imported.
if __name__ == "__main__":
    run()