# ============================================================================
# RAPID PROTOTYPE: 2-Expert Consensus + Alignment Bank
#
# Fast iteration cycle:
#   Phase 0: Extract mean-pooled embeddings from 2 experts (BERT, ModernBERT)
#            and build a Procrustes-aligned consensus target
#   Phase 1: Train student on the 2-expert consensus (20K captions, 5 epochs)
#   Phase 2: Freeze student, train alignment bank on its output
#   Phase 3: Verify bank preserves geometry
#   Phase 4: Snap a tiny classifier on bank output, check stability
# ============================================================================

import gc
import math
import os
import time
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# (HF model id, short name, tokenizer max_length) per expert.
EXPERTS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 512),
]

print("=" * 65)
print("RAPID PROTOTYPE: 2-Expert Consensus + Alignment Bank")
print("=" * 65)
print(f" Device: {DEVICE}")


# ══════════════════════════════════════════════════════════════════
# STUDENT MODEL
# ══════════════════════════════════════════════════════════════════

class MiniStudent(nn.Module):
    """Small transformer encoder distilled from the expert consensus.

    Token + learned positional embeddings feed a pre-norm
    ``nn.TransformerEncoder``; token states are mean-pooled over non-padding
    positions, projected to ``output_dim`` and L2-normalized.
    """

    def __init__(self, vocab_size=30522, max_len=512, d_model=256, n_heads=4,
                 n_layers=4, d_ff=1024, output_dim=768, dropout=0.1,
                 pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers, enable_nested_tensor=False)
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(), nn.LayerNorm(d_model),
            nn.Linear(d_model, output_dim))

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids into unit-norm embeddings.

        Args:
            input_ids: (B, L) token ids; L must not exceed ``max_len``.
            attention_mask: optional (B, L) mask, nonzero for real tokens.
                When omitted, padding is inferred from ``pad_token_id``.

        Returns:
            (B, output_dim) L2-normalized embeddings.
        """
        _, seq_len = input_ids.shape
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.emb_drop(self.emb_norm(x))

        # Key-padding mask for the encoder: True marks positions to ignore.
        if attention_mask is not None:
            kpm = ~attention_mask.bool()
        else:
            kpm = input_ids == self.pad_token_id
        x = self.encoder(x, src_key_padding_mask=kpm)

        # Masked mean pool; clamp avoids division by zero for all-pad rows.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
        else:
            mask = (~kpm).unsqueeze(-1).float()
        pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(pooled), dim=-1)


# ══════════════════════════════════════════════════════════════════
# ALIGNMENT BANK
# ══════════════════════════════════════════════════════════════════

class AlignmentBank(nn.Module):
    """
    Geometric interface layer.

    Learns to annotate student embeddings with per-expert alignment context
    and anchor similarities. Trained on frozen student output. The input
    embedding is passed through unchanged; a ``d_bank``-dim geometric context
    vector is appended to it for downstream heads.
    """

    def __init__(self, d_embed=768, n_experts=2, n_anchors=128, d_bank=64):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank

        # Per-expert rotation matrices (initialized to identity; optionally
        # overwritten by init_from_procrustes).
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])
        # Per-expert mean offsets. NOTE(review): loaded by init_from_procrustes
        # but not read anywhere in forward(), so they receive no gradient.
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)
        ])
        # Anchor bank: learned consensus landmarks (unit-norm at init).
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(n_anchors, d_embed), dim=-1))

        # Project geometric features into a compact context vector.
        # Input: n_experts (round-trip cosine) + n_anchors (anchor cosines)
        #        + n_experts (round-trip MSE)
        geo_dim = n_experts + n_anchors + n_experts
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 2),
            nn.GELU(),
            nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank),
            nn.LayerNorm(d_bank),
        )

    def init_from_procrustes(self, procrustes_results, expert_names,
                             consensus_embeddings=None):
        """Initialize rotations/means (and optionally anchors) from alignment artifacts.

        Args:
            procrustes_results: mapping expert name -> dict produced by
                ``procrustes_align`` (uses "rotation", "source_mean", "cos_after").
            expert_names: expert names; the first ``n_experts`` are loaded.
            consensus_embeddings: optional (N, d_embed) tensor; up to
                ``n_anchors`` evenly spaced rows seed the anchors.
        """
        device = self.anchors.device
        for i, name in enumerate(expert_names[:self.n_experts]):
            info = procrustes_results[name]
            self.expert_rotations[i].data = info["rotation"].float().to(device)
            self.expert_means[i].data = info["source_mean"].float().to(device)
            print(f" Expert {i} ({name}): rotation loaded, cos_after={info['cos_after']:.4f}")

        if consensus_embeddings is not None:
            n = min(self.n_anchors, consensus_embeddings.shape[0])
            indices = torch.linspace(0, consensus_embeddings.shape[0] - 1, n).long()
            self.anchors.data[:n] = F.normalize(
                consensus_embeddings[indices].float(), dim=-1).to(device)
            print(f" Anchors: {n} initialized from consensus embeddings")

    def forward(self, embedding):
        """
        Annotate embedding with geometric context.

        Args:
            embedding: (B, d_embed), expected L2-normalized.

        Returns:
            enriched: (B, d_embed + d_bank) = [embedding, geo_context]
            aux: dict with differentiable geometric losses ("expert_agreement",
                 "rotation_ortho", "anchor_spread", "anchor_entropy",
                 "bank_cv") and float diagnostics.
        """
        B = embedding.shape[0]
        emb = embedding.float()

        # Per-expert: rotate into expert space and back, measure recovery.
        expert_consistency = []  # cosine between original and round-tripped
        expert_recon = []        # MSE of round-trip
        for R in self.expert_rotations:
            in_expert = emb @ R             # consensus -> expert space
            round_trip = in_expert @ R.T    # expert space -> consensus
            # NOTE: round_trip == emb @ (R @ R.T). For an orthogonal R this is
            # ~emb, so these features mostly track R's drift from orthogonality.
            expert_consistency.append(F.cosine_similarity(emb, round_trip, dim=-1))
            expert_recon.append((emb - round_trip).pow(2).mean(dim=-1))

        expert_cos = torch.stack(expert_consistency, dim=-1)  # (B, n_experts)
        expert_mse = torch.stack(expert_recon, dim=-1)        # (B, n_experts)

        # Cosine similarity to each (renormalized) anchor.
        anchors_n = F.normalize(self.anchors, dim=-1)
        anchor_cos = emb @ anchors_n.T  # (B, n_anchors)

        geo_input = torch.cat([expert_cos, anchor_cos, expert_mse], dim=-1)
        geo_context = self.geo_proj(geo_input)  # (B, d_bank)

        # Passthrough: the first d_embed columns are the input embedding.
        enriched = torch.cat([emb, geo_context], dim=-1)

        # ── Geometric losses ──
        aux = {}

        # 1. Expert agreement: variance of round-trip cosine across experts.
        expert_mean = expert_cos.mean(dim=-1, keepdim=True)
        aux["expert_agreement"] = (expert_cos - expert_mean).pow(2).mean()

        # 2. Rotation orthogonality: mean squared deviation of R R^T from I.
        eye = torch.eye(self.d_embed, device=emb.device)
        ortho_loss = sum((R @ R.T - eye).pow(2).mean() for R in self.expert_rotations)
        aux["rotation_ortho"] = ortho_loss / self.n_experts

        # 3. Anchor spread: penalize off-diagonal anchor-anchor similarity.
        anchor_sim = anchors_n @ anchors_n.T
        anchor_sim.fill_diagonal_(0)
        aux["anchor_spread"] = anchor_sim.pow(2).mean()

        # 4. Anchor sharpness: entropy of a temperature-0.1 softmax over anchors.
        anchor_probs = F.softmax(anchor_cos * 10, dim=-1)
        entropy = -(anchor_probs * (anchor_probs + 1e-12).log()).sum(-1).mean()
        aux["anchor_entropy"] = entropy

        # 5. Pentachoron-volume CV of the normalized geometric context.
        if B >= 10:
            ctx_n = F.normalize(geo_context, dim=-1)
            vols = sample_simplex_volumes(ctx_n, 32)
            aux["bank_cv"] = vols.std() / (vols.mean() + 1e-8)
        else:
            aux["bank_cv"] = torch.tensor(0.0, device=embedding.device)

        # Summary diagnostics (plain floats).
        aux["expert_cos_mean"] = expert_cos.mean().item()
        aux["expert_cos_std"] = expert_cos.std().item()
        aux["anchor_max_cos"] = anchor_cos.max(dim=-1).values.mean().item()
        aux["anchor_mean_cos"] = anchor_cos.mean().item()

        return enriched, aux

    def bank_loss(self, aux, cv_target=0.15):
        """Weighted sum of the geometric losses in ``aux`` from ``forward``."""
        loss = (1.0 * aux["expert_agreement"]
                + 1.0 * aux["rotation_ortho"]
                + 0.5 * aux["anchor_spread"]
                + 0.1 * aux["anchor_entropy"]
                + 0.3 * (aux["bank_cv"] - cv_target).abs())
        return loss


# ══════════════════════════════════════════════════════════════════
# GEOMETRY
# ══════════════════════════════════════════════════════════════════

def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between row-paired batches ``a`` and ``b``.

    Returns (loss tensor, row-retrieval accuracy a->b as a float).
    """
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = (a @ b.T) / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc


def cayley_menger_vol2(pts):
    """Squared volume of simplices via the Cayley–Menger determinant.

    Args:
        pts: (B, V, D) batch of simplices with V vertices each.

    Returns:
        (B,) squared (V-1)-dimensional volumes (may be slightly negative
        from round-off for degenerate simplices).
    """
    pts = pts.float()
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)
    B, V, _ = d2.shape
    cm = torch.zeros(B, V + 1, V + 1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = d2
    s = (-1.0) ** V
    f = math.factorial(V - 1)
    return s / ((2.0 ** (V - 1)) * f * f) * torch.linalg.det(cm)


def sample_simplex_volumes(emb, n_samples, n_vertices=5):
    """Volumes of ``n_samples`` random simplices whose vertices are rows of ``emb``.

    Each simplex uses ``n_vertices`` distinct rows (a pentachoron for the
    default 5); all determinants are evaluated in one batched call. The tiny
    epsilon inside the sqrt keeps gradients finite for degenerate simplices.

    Returns:
        1-D tensor of length ``n_samples``.
    """
    n_rows = emb.shape[0]
    idx = torch.stack([
        torch.randperm(n_rows, device=emb.device)[:n_vertices]
        for _ in range(n_samples)
    ])
    vol2 = cayley_menger_vol2(emb[idx])
    return torch.sqrt(F.relu(vol2) + 1e-12)


def cv_loss(emb, target=0.12, n_samples=16):
    """Differentiable |CV - target| of sampled pentachoron volumes (0 if B < 5)."""
    if emb.shape[0] < 5:
        return torch.tensor(0.0, device=emb.device)
    vols = sample_simplex_volumes(emb, n_samples)
    cv = vols.std() / (vols.mean() + 1e-8)
    return (cv - target).abs()


def cv_metric(emb, n=200):
    """Coefficient of variation of ``n`` sampled pentachoron volumes, as a float.

    Uses the population std (numpy default). Returns 0.0 when fewer than 5
    rows are given or fewer than 10 positive volumes are obtained.
    """
    if emb.shape[0] < 5:
        return 0.0
    with torch.no_grad():
        vols = sample_simplex_volumes(emb, n)
    vols = vols[vols > 0].double().cpu().numpy()
    if vols.size < 10:
        return 0.0
    return float(vols.std() / (vols.mean() + 1e-8))


# ══════════════════════════════════════════════════════════════════
# EXTRACTION + ALIGNMENT
# ══════════════════════════════════════════════════════════════════

def symmetric_inv_sqrt(cov, eps=1e-6):
    """Symmetric inverse square root of a symmetric PSD matrix (eigenvalues clamped at ``eps``)."""
    evals, evecs = torch.linalg.eigh(cov)
    evals = torch.clamp(evals, min=eps)
    return evecs @ torch.diag(evals.rsqrt()) @ evecs.T


def procrustes_align(source, target, n_align=5000):
    """Whitened orthogonal Procrustes alignment from ``source`` to ``target``.

    Both inputs are (N, D) with paired rows; the first ``n_align`` rows are
    used. Each side is centered and whitened, rows are L2-normalized, and
    the rotation R is chosen so that ``Sc_w @ R.T`` ≈ ``Tc_w``.

    Returns:
        dict with "rotation", "source_mean", "source_whitener",
        "target_unwhitener", plus diagnostic "cos_before"/"cos_after".
    """
    N = min(n_align, source.shape[0], target.shape[0])
    S = source[:N].float()
    T = target[:N].float()
    s_mean = S.mean(0, keepdim=True)
    t_mean = T.mean(0, keepdim=True)
    Sc = S - s_mean
    Tc = T - t_mean
    N_s = Sc.shape[0]
    cos_before = F.cosine_similarity(Sc, Tc, dim=-1).mean().item()
    s_cov = (Sc.T @ Sc) / max(N_s - 1, 1)
    t_cov = (Tc.T @ Tc) / max(N_s - 1, 1)
    s_whiten = symmetric_inv_sqrt(s_cov)
    t_whiten = symmetric_inv_sqrt(t_cov)
    Sc_w = F.normalize(Sc @ s_whiten, dim=-1)
    Tc_w = F.normalize(Tc @ t_whiten, dim=-1)
    U, _, Vt = torch.linalg.svd(Tc_w.T @ Sc_w, full_matrices=False)
    R = U @ Vt
    cos_after = F.cosine_similarity(Sc_w @ R.T, Tc_w, dim=-1).mean().item()
    return {
        "rotation": R,
        "source_mean": s_mean.squeeze(0),
        "source_whitener": s_whiten,
        "target_unwhitener": torch.linalg.pinv(t_whiten),
        "cos_before": cos_before,
        "cos_after": cos_after,
    }


def apply_align(emb, a):
    """Map ``emb`` into the (centered) target space using a ``procrustes_align`` result."""
    x = emb.float() - a["source_mean"]
    x = x @ a["source_whitener"]
    x = x @ a["rotation"].T
    x = x @ a["target_unwhitener"]
    return x


def sample_distinct_pairs(n_items, n_pairs, device=None):
    """Sample ``n_pairs`` index pairs (a, b) from ``range(n_items)`` with a != b.

    ``b`` is uniform over the indices other than ``a``.
    """
    if n_items < 2:
        raise ValueError("need at least 2 items to form distinct pairs")
    idx_a = torch.randint(0, n_items, (n_pairs,), device=device)
    offset = torch.randint(1, n_items, (n_pairs,), device=device)
    idx_b = (idx_a + offset) % n_items
    return idx_a, idx_b


# ══════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════

def run():
    """Run all prototype phases end to end.

    Side effects: streams a caption dataset, downloads the expert models,
    writes ``mini_student.pt`` and ``alignment_bank.pt`` to the working
    directory, and prints progress/diagnostics.
    """
    torch.manual_seed(42)
    np.random.seed(42)

    N_SAMPLES = 20000
    N_VAL = 2000
    MAX_LEN = 128
    BATCH = 256

    # ── Phase 0: Extract ──
    print(f"\n{'='*65}")
    print("PHASE 0: EXTRACTION")
    print(f"{'='*65}")

    # Heavy dependencies are imported lazily so the definitions above can be
    # imported without them.
    from datasets import load_dataset
    from tqdm import tqdm
    from transformers import AutoModel, AutoTokenizer

    ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                      split="train", streaming=True)
    captions = []
    for row in ds:
        cap = row.get("caption_llava", "")
        if isinstance(cap, str) and len(cap) > 50:
            captions.append(cap)
        if len(captions) >= N_SAMPLES:
            break
    print(f" Captions: {len(captions):,}")
    if len(captions) <= N_VAL:
        raise RuntimeError(
            f"Only {len(captions)} captions collected; need more than "
            f"{N_VAL} to carve out a validation split.")

    embeds = {}
    for model_name, short, max_len in EXPERTS:
        print(f"\n Extracting: {short}...")
        model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        all_emb = []
        with torch.no_grad():
            for i in tqdm(range(0, len(captions), 128), desc=f" {short}"):
                batch = captions[i:i + 128]
                inputs = tokenizer(batch, max_length=max_len, padding=True,
                                   truncation=True, return_tensors="pt").to(DEVICE)
                out = model(**inputs)
                # Mean-pool last hidden states over non-padding tokens.
                m = inputs.attention_mask.unsqueeze(-1).float()
                pooled = (out.last_hidden_state * m).sum(1) / m.sum(1).clamp(min=1)
                all_emb.append(pooled.cpu())
        embeds[short] = torch.cat(all_emb)
        print(f" Shape: {embeds[short].shape}")
        del model
        gc.collect()
        torch.cuda.empty_cache()

    # ── Phase 0b: Align + Consensus ──
    print(f"\n{'='*65}")
    print("PHASE 0b: PROCRUSTES ALIGNMENT")
    print(f"{'='*65}")

    ref = "bert"
    names = [s for _, s, _ in EXPERTS]
    procrustes_results = {}
    aligned = {}
    for name in names:
        info = procrustes_align(embeds[name], embeds[ref])
        procrustes_results[name] = info
        aligned[name] = apply_align(embeds[name], info)
        print(f" {name:10s}: cos {info['cos_before']:.4f} → {info['cos_after']:.4f}")

    # Consensus target: normalized mean of the aligned expert embeddings.
    consensus = F.normalize(sum(aligned[n] for n in names) / len(names), dim=-1)
    print(f" Consensus: {consensus.shape}")
    for name in names:
        cos = F.cosine_similarity(consensus[:2000], aligned[name][:2000], dim=-1).mean().item()
        print(f" cos(consensus, {name}): {cos:.4f}")

    consensus_cv = cv_metric(consensus[:2000].to(DEVICE))
    print(f" Consensus CV: {consensus_cv:.4f}")

    del embeds, aligned
    gc.collect()
    torch.cuda.empty_cache()

    # ── Phase 1: Train Student ──
    print(f"\n{'='*65}")
    print("PHASE 1: TRAIN STUDENT (2 experts, 20K captions)")
    print(f"{'='*65}")

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    tokens = tokenizer(captions, max_length=MAX_LEN, padding="max_length",
                       truncation=True, return_tensors="pt")
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]

    # Last N_VAL captions are held out; derived from what was actually collected.
    n_train = len(captions) - N_VAL
    train_ids = input_ids[:n_train].to(DEVICE)
    train_mask = attention_mask[:n_train].to(DEVICE)
    train_targets = consensus[:n_train].to(DEVICE)
    val_ids = input_ids[n_train:].to(DEVICE)
    val_mask = attention_mask[n_train:].to(DEVICE)
    val_targets = consensus[n_train:].to(DEVICE)

    student = MiniStudent(
        vocab_size=tokenizer.vocab_size, max_len=MAX_LEN, d_model=256,
        n_heads=4, n_layers=4, d_ff=1024, output_dim=768, dropout=0.1,
        pad_token_id=tokenizer.pad_token_id
    ).to(DEVICE)
    n_params = sum(p.numel() for p in student.parameters())
    print(f" Student: {n_params:,} params")

    optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=0.01)

    for epoch in range(5):
        student.train()
        perm = torch.randperm(n_train, device=DEVICE)
        t_loss, t_acc, t_cos, n = 0, 0, 0, 0
        t0 = time.time()

        for i in range(0, n_train, BATCH):
            idx = perm[i:i + BATCH]
            if len(idx) < 8:
                continue
            emb = student(train_ids[idx], train_mask[idx])
            tgt = train_targets[idx]
            l_nce, acc = infonce(emb, tgt)
            l_mse = F.mse_loss(emb, tgt)
            l_cv = cv_loss(emb, target=consensus_cv)
            loss = l_nce + l_mse + 0.1 * l_cv
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            with torch.no_grad():
                cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
            t_loss += loss.item()
            t_acc += acc
            t_cos += cos
            n += 1

        elapsed = time.time() - t0
        d = max(n, 1)
        student.eval()
        with torch.no_grad():
            v_emb = student(val_ids, val_mask)
            _, v_acc = infonce(v_emb[:1000], val_targets[:1000])
            v_cos = F.cosine_similarity(v_emb, val_targets, dim=-1).mean().item()
            v_cv = cv_metric(v_emb[:1000])
        print(f" E{epoch+1}: {elapsed:.0f}s loss={t_loss/d:.4f} "
              f"t_acc={t_acc/d:.3f} t_cos={t_cos/d:.3f} "
              f"v_acc={v_acc:.3f} v_cos={v_cos:.3f} v_cv={v_cv:.3f}")

    torch.save(student.state_dict(), "mini_student.pt")
    print(f"\n Student saved. v_cos={v_cos:.3f}, v_cv={v_cv:.3f}")

    # ── Phase 2: Train Alignment Bank ──
    print(f"\n{'='*65}")
    print("PHASE 2: TRAIN ALIGNMENT BANK (student frozen)")
    print(f"{'='*65}")

    student.eval()
    for p in student.parameters():
        p.requires_grad = False

    print(" Pre-encoding through frozen student...")
    with torch.no_grad():
        all_embs = []
        for i in range(0, n_train, 512):
            j = min(i + 512, n_train)
            all_embs.append(student(train_ids[i:j], train_mask[i:j]))
        student_embs = torch.cat(all_embs)  # (n_train, 768)
        val_student_embs = student(val_ids, val_mask)
    print(f" Student embeddings: {student_embs.shape}")

    bank = AlignmentBank(
        d_embed=768, n_experts=len(EXPERTS), n_anchors=128, d_bank=64
    ).to(DEVICE)
    bank.init_from_procrustes(procrustes_results, names, consensus[:n_train])
    bank_params = sum(p.numel() for p in bank.parameters())
    print(f" Bank: {bank_params:,} params")

    # NOTE(review): AdamW weight decay also applies to the rotation matrices,
    # pulling them toward zero against the orthogonality loss.
    bank_opt = torch.optim.AdamW(bank.parameters(), lr=1e-3, weight_decay=0.01)
    BANK_EPOCHS = 20
    BANK_BATCH = 256

    for epoch in range(BANK_EPOCHS):
        bank.train()
        perm = torch.randperm(n_train, device=DEVICE)
        total_loss = 0
        stats = {"expert_agreement": 0, "rotation_ortho": 0,
                 "anchor_spread": 0, "bank_cv": 0}
        n = 0
        t0 = time.time()

        for i in range(0, n_train, BANK_BATCH):
            idx = perm[i:i + BANK_BATCH]
            if len(idx) < 16:
                continue
            emb = student_embs[idx]
            enriched, aux = bank(emb)
            loss = bank.bank_loss(aux, cv_target=consensus_cv + 0.02)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(bank.parameters(), 1.0)
            bank_opt.step()
            bank_opt.zero_grad(set_to_none=True)
            total_loss += loss.item()
            for k in stats:
                if k in aux:
                    v = aux[k]
                    stats[k] += v.item() if torch.is_tensor(v) else v
            n += 1

        elapsed = time.time() - t0
        d = max(n, 1)

        bank.eval()
        with torch.no_grad():
            v_enriched, v_aux = bank(val_student_embs)
            v_loss = bank.bank_loss(v_aux, cv_target=consensus_cv + 0.02).item()

        print(f" E{epoch+1:2d}: {elapsed:.0f}s loss={total_loss/d:.4f} "
              f"v_loss={v_loss:.4f} "
              f"expert_agr={stats['expert_agreement']/d:.5f} "
              f"ortho={stats['rotation_ortho']/d:.5f} "
              f"spread={stats['anchor_spread']/d:.5f} "
              f"cv={stats['bank_cv']/d:.4f} "
              f"anchor_max={v_aux['anchor_max_cos']:.3f} "
              f"expert_cos={v_aux['expert_cos_mean']:.3f}±{v_aux['expert_cos_std']:.3f}")

    torch.save(bank.state_dict(), "alignment_bank.pt")

    # ── Phase 3: Verify Geometry ──
    print(f"\n{'='*65}")
    print("PHASE 3: GEOMETRIC VERIFICATION")
    print(f"{'='*65}")

    bank.eval()
    with torch.no_grad():
        enriched_val, _ = bank(val_student_embs)
        original_part = enriched_val[:, :bank.d_embed]  # passthrough embedding
        geo_context = enriched_val[:, bank.d_embed:]    # geometric annotation

        passthrough_cos = F.cosine_similarity(
            original_part[:100], val_student_embs[:100], dim=-1).mean().item()

        geo_cv = cv_metric(F.normalize(geo_context[:1000], dim=-1))
        # Effective dimension as the participation ratio of squared singular values.
        geo_sub = geo_context[:1000].float()
        sv2 = torch.linalg.svdvals(geo_sub - geo_sub.mean(0)).pow(2)
        geo_eff_dim = (sv2.sum() ** 2) / (sv2.pow(2).sum() + 1e-12)

    print(f" Passthrough integrity: {passthrough_cos:.6f} (should be ~1.000)")
    print(f" Geo context CV: {geo_cv:.4f}")
    print(f" Geo context eff_dim: {geo_eff_dim:.1f}")
    print(f" Geo context shape: {geo_context.shape}")

    # ── Phase 4: Quick Classifier Test ──
    print(f"\n{'='*65}")
    print("PHASE 4: CLASSIFIER STABILITY TEST")
    print(f"{'='*65}")

    # Synthetic 3-class task from the student's own similarity structure:
    #   Class 0: top-tercile pair cosine (similar)
    #   Class 1: middle tercile
    #   Class 2: bottom tercile (different)
    n_pairs = 3000
    with torch.no_grad():
        n_items = min(1000, val_student_embs.shape[0])
        embs = val_student_embs[:n_items]

        # Distinct pairs only: a self-pair would trivially have cosine 1.
        idx_a, idx_b = sample_distinct_pairs(n_items, n_pairs, device=embs.device)
        pair_cos = (embs[idx_a] * embs[idx_b]).sum(-1)  # embeddings are unit-norm

        sorted_cos, _ = pair_cos.sort()
        t1 = sorted_cos[n_pairs // 3].item()
        t2 = sorted_cos[2 * n_pairs // 3].item()
        labels = torch.zeros(n_pairs, dtype=torch.long, device=DEVICE)
        labels[pair_cos > t2] = 0                             # similar
        labels[(pair_cos <= t2) & (pair_cos > t1)] = 1        # medium
        labels[pair_cos <= t1] = 2                            # different

        enriched_a, _ = bank(embs[idx_a])
        enriched_b, _ = bank(embs[idx_b])

    # Train a tiny classifier with vs. without the bank's geometric context.
    n_clf_train = int(0.8 * n_pairs)
    for mode in ["with_bank", "without_bank"]:
        if mode == "with_bank":
            features = torch.cat([enriched_a, enriched_b], dim=-1)
        else:
            features = torch.cat([embs[idx_a], embs[idx_b]], dim=-1)
        feat_dim = features.shape[-1]

        clf = nn.Sequential(
            nn.Linear(feat_dim, 128), nn.GELU(), nn.Linear(128, 3)
        ).to(DEVICE)
        clf_opt = torch.optim.Adam(clf.parameters(), lr=1e-3)

        train_f = features[:n_clf_train].detach()
        train_l = labels[:n_clf_train]
        val_f = features[n_clf_train:].detach()
        val_l = labels[n_clf_train:]

        for _ in range(20):
            clf.train()
            logits = clf(train_f)
            loss = F.cross_entropy(logits, train_l)
            loss.backward()
            clf_opt.step()
            clf_opt.zero_grad()

        clf.eval()
        with torch.no_grad():
            val_logits = clf(val_f)
            val_acc = (val_logits.argmax(-1) == val_l).float().mean().item()
            train_logits = clf(train_f)
            train_acc = (train_logits.argmax(-1) == train_l).float().mean().item()

        print(f" {mode:15s}: train_acc={train_acc:.3f} val_acc={val_acc:.3f} "
              f"gap={train_acc-val_acc:.3f}")

    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
    print(f"\n Student: mini_student.pt")
    print(f" Bank: alignment_bank.pt")
    print(f" Consensus CV: {consensus_cv:.4f}")
    print(f" Student v_cos: {v_cos:.3f}")


if __name__ == "__main__":
    run()