# geolip-axis-prototype / rapid_prototype_trainer.py
# Author: AbstractPhil
# Commit 654b110 (verified): "Create rapid_prototype_trainer.py"
# ============================================================================
# RAPID PROTOTYPE: 2-Expert Consensus + Alignment Bank
#
# Fast iteration cycle:
# Phase 1: Train student on the 2-expert (BERT + ModernBERT) consensus (20K captions, 5 epochs)
# Phase 2: Freeze student, train alignment bank on its output
# Phase 3: Verify bank preserves geometry
# Phase 4: Snap a tiny classifier on bank output, check stability
# ============================================================================
import gc
import math
import os
import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# Compute device: the CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Teacher experts as (Hugging Face model id, short name, tokenizer max_length).
EXPERTS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 512),
]

# Startup banner (one line per argument).
print("=" * 65,
      "RAPID PROTOTYPE: 2-Expert Consensus + Alignment Bank",
      "=" * 65,
      f" Device: {DEVICE}",
      sep="\n")
# ══════════════════════════════════════════════════════════════════
# STUDENT MODEL
# ══════════════════════════════════════════════════════════════════
class MiniStudent(nn.Module):
    """Small Transformer encoder trained to reproduce consensus embeddings.

    Pipeline: token + learned positional embeddings -> LayerNorm/dropout ->
    pre-norm TransformerEncoder -> attention-masked mean pooling ->
    MLP projection -> L2-normalized ``output_dim`` vector.
    """

    def __init__(self, vocab_size=30522, max_len=512, d_model=256,
                 n_heads=4, n_layers=4, d_ff=1024, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers, enable_nested_tensor=False)
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(),
            nn.LayerNorm(d_model), nn.Linear(d_model, output_dim))

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids into unit-norm embeddings.

        Args:
            input_ids: (B, L) token ids; L must not exceed ``max_len``.
            attention_mask: optional (B, L); nonzero marks real tokens, zero
                marks padding. When omitted, padding is inferred from
                ``input_ids == pad_token_id``.
        Returns:
            (B, output_dim) L2-normalized embeddings.
        Raises:
            ValueError: if L exceeds the positional-embedding table size.
        """
        _, L = input_ids.shape
        max_len = self.pos_emb.num_embeddings
        if L > max_len:
            # Fail loudly instead of an out-of-range embedding lookup
            # (a device-side assert on CUDA).
            raise ValueError(f"sequence length {L} exceeds max_len={max_len}")
        positions = torch.arange(L, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.emb_drop(self.emb_norm(x))
        # Key padding mask: True where the position must be ignored.
        if attention_mask is not None:
            kpm = ~attention_mask.bool()
        else:
            kpm = input_ids == self.pad_token_id
        x = self.encoder(x, src_key_padding_mask=kpm)
        # Mean-pool only over real tokens; clamp guards all-padding rows.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
        else:
            mask = (~kpm).unsqueeze(-1).float()
        pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(pooled), dim=-1)
# ══════════════════════════════════════════════════════════════════
# ALIGNMENT BANK
# ══════════════════════════════════════════════════════════════════
class AlignmentBank(nn.Module):
    """
    Geometric interface layer. Learns to annotate student embeddings
    with per-expert alignment context and anchor distances.
    Trained on frozen student output. Provides geometric memory of
    the expert consensus for downstream heads.

    ``forward`` returns the input embedding unchanged, concatenated with a
    ``d_bank``-dim context vector computed from per-expert rotation
    round-trip statistics and cosine similarity to ``n_anchors`` learned
    anchors. It also returns an ``aux`` dict of loss terms consumed by
    ``bank_loss``.
    """

    # Number of random 5-point simplices sampled per batch for "bank_cv".
    _CV_SAMPLES = 32

    def __init__(self, d_embed=768, n_experts=2, n_anchors=128, d_bank=64):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank
        # Per-expert rotation matrices (identity until init_from_procrustes).
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])
        # Per-expert mean offsets. Loaded by init_from_procrustes but not
        # currently read by forward().
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)
        ])
        # Anchor bank: learned consensus landmarks (unit-norm at init).
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(n_anchors, d_embed), dim=-1))
        # Input to geo_proj: n_experts round-trip cosines + n_anchors anchor
        # cosines + n_experts round-trip MSEs.
        geo_dim = n_experts + n_anchors + n_experts
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 2),
            nn.GELU(),
            nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank),
            nn.LayerNorm(d_bank),
        )

    def init_from_procrustes(self, procrustes_results, expert_names,
                             consensus_embeddings=None):
        """Initialize rotations, means and anchors from consensus artifacts.

        Args:
            procrustes_results: mapping name -> dict holding "rotation" and
                "source_mean" tensors plus a "cos_after" float.
            expert_names: keys into ``procrustes_results``; only the first
                ``n_experts`` are used, in order.
            consensus_embeddings: optional (N, d_embed) tensor. Up to
                ``n_anchors`` evenly spaced rows (re-normalized) seed the
                anchors.
        """
        device = self.anchors.device
        for i, name in enumerate(expert_names[:self.n_experts]):
            info = procrustes_results[name]
            self.expert_rotations[i].data = info["rotation"].float().to(device)
            self.expert_means[i].data = info["source_mean"].float().to(device)
            print(f" Expert {i} ({name}): rotation loaded, cos_after={info['cos_after']:.4f}")
        if consensus_embeddings is not None:
            n = min(self.n_anchors, consensus_embeddings.shape[0])
            # Evenly spaced row indices across the whole consensus set.
            indices = torch.linspace(0, consensus_embeddings.shape[0] - 1, n).long()
            self.anchors.data[:n] = F.normalize(
                consensus_embeddings[indices].float(), dim=-1).to(device)
            print(f" Anchors: {n} initialized from consensus embeddings")

    @staticmethod
    def _cayley_menger_vol2(pts):
        """Squared simplex volume for each point set in pts (B, V, D) -> (B,)."""
        diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
        d2 = (diff * diff).sum(-1)
        n_batch, n_vert, _ = d2.shape
        cm = torch.zeros(n_batch, n_vert + 1, n_vert + 1,
                         device=d2.device, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        sign = (-1.0) ** n_vert
        f = math.factorial(n_vert - 1)
        return sign / ((2.0 ** (n_vert - 1)) * f * f) * torch.linalg.det(cm)

    def forward(self, embedding):
        """
        Annotate embedding with geometric context.

        Args:
            embedding: (B, d_embed), expected L2-normalized.
        Returns:
            enriched: (B, d_embed + d_bank). This is ``embedding`` followed
                by the geometric context.
            aux: dict containing
                - differentiable loss terms: "expert_agreement",
                  "rotation_ortho", "anchor_spread", "anchor_entropy",
                  "bank_cv";
                - float diagnostics: "expert_cos_mean", "expert_cos_std",
                  "anchor_max_cos", "anchor_mean_cos".
        """
        B = embedding.shape[0]
        emb = embedding.float()

        # Per-expert round trip: consensus -> expert space (emb @ R) -> back
        # (@ R.T), scored by cosine and MSE against the input.
        # NOTE(review): for an orthogonal R, R @ R.T == I, so the round trip
        # returns emb (up to rounding): cos == 1 and mse == 0 for every
        # sample. The rotation_ortho term below pushes R toward that regime,
        # so these features carry little per-sample information.
        expert_consistency = []  # cosine between original and round-tripped
        expert_recon = []        # MSE of the round trip
        for R in self.expert_rotations:
            round_trip = (emb @ R) @ R.T
            expert_consistency.append(F.cosine_similarity(emb, round_trip, dim=-1))
            expert_recon.append((emb - round_trip).pow(2).mean(dim=-1))
        expert_cos = torch.stack(expert_consistency, dim=-1)  # (B, n_experts)
        expert_mse = torch.stack(expert_recon, dim=-1)        # (B, n_experts)

        # Cosine similarity to each (re-normalized) anchor.
        anchors_n = F.normalize(self.anchors, dim=-1)
        anchor_cos = emb @ anchors_n.T  # (B, n_anchors)

        geo_input = torch.cat([expert_cos, anchor_cos, expert_mse], dim=-1)
        geo_context = self.geo_proj(geo_input)  # (B, d_bank)

        # Passthrough: the original embedding is kept verbatim in front.
        enriched = torch.cat([embedding, geo_context], dim=-1)

        aux = {}
        # 1. Expert agreement: spread of round-trip cosines across experts.
        expert_mean = expert_cos.mean(dim=-1, keepdim=True)
        aux["expert_agreement"] = (expert_cos - expert_mean).pow(2).mean()

        # 2. Rotation orthogonality: mean squared deviation of R R^T from I.
        eye = torch.eye(self.d_embed, device=anchors_n.device)  # hoisted out of the loop
        ortho_loss = 0.0
        for R in self.expert_rotations:
            ortho_loss = ortho_loss + (R @ R.T - eye).pow(2).mean()
        aux["rotation_ortho"] = ortho_loss / self.n_experts

        # 3. Anchor spread: squared pairwise anchor similarity with the
        #    diagonal zeroed (the zeros still count in the mean).
        anchor_sim = anchors_n @ anchors_n.T
        anchor_sim.fill_diagonal_(0)
        aux["anchor_spread"] = anchor_sim.pow(2).mean()

        # 4. Anchor sharpness: entropy of a temperature-0.1 softmax over anchors.
        anchor_probs = F.softmax(anchor_cos * 10, dim=-1)
        aux["anchor_entropy"] = -(anchor_probs * (anchor_probs + 1e-12).log()).sum(-1).mean()

        # 5. Pentachoron CV of the geometric context: std / mean of the
        #    volumes of random 5-point simplices drawn from the batch.
        if B >= 10:
            ctx_n = F.normalize(geo_context, dim=-1)
            # Same randperm draws as a per-sample loop, evaluated with one
            # batched Cayley-Menger determinant.
            idx = torch.stack([
                torch.randperm(B, device=embedding.device)[:5]
                for _ in range(self._CV_SAMPLES)
            ])  # (_CV_SAMPLES, 5)
            vols = torch.sqrt(F.relu(self._cayley_menger_vol2(ctx_n[idx])) + 1e-12)
            aux["bank_cv"] = vols.std() / (vols.mean() + 1e-8)
        else:
            aux["bank_cv"] = torch.tensor(0.0, device=embedding.device)

        # Summary diagnostics (Python floats).
        aux["expert_cos_mean"] = expert_cos.mean().item()
        aux["expert_cos_std"] = expert_cos.std().item()
        aux["anchor_max_cos"] = anchor_cos.max(dim=-1).values.mean().item()
        aux["anchor_mean_cos"] = anchor_cos.mean().item()
        return enriched, aux

    def bank_loss(self, aux, cv_target=0.15):
        """Weighted sum of the aux loss terms produced by forward().

        The bank_cv term penalizes the absolute gap to ``cv_target``.
        """
        loss = (1.0 * aux["expert_agreement"] +
                1.0 * aux["rotation_ortho"] +
                0.5 * aux["anchor_spread"] +
                0.1 * aux["anchor_entropy"] +
                0.3 * (aux["bank_cv"] - cv_target).abs())
        return loss
# ══════════════════════════════════════════════════════════════════
# GEOMETRY
# ══════════════════════════════════════════════════════════════════
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE loss between row-paired batches ``a`` and ``b``.

    Row i of ``a`` is the positive for row i of ``b``. Every other row in
    the batch is a negative. Both inputs are L2-normalized first.

    Returns:
        (loss, acc): loss is the mean of the a->b and b->a cross-entropies.
        acc is the a->b top-1 retrieval accuracy as a Python float.
    """
    sims = F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).T
    logits = sims / temperature
    targets = torch.arange(logits.shape[0], device=logits.device)
    forward_ce = F.cross_entropy(logits, targets)
    backward_ce = F.cross_entropy(logits.T, targets)
    loss = (forward_ce + backward_ce) / 2
    with torch.no_grad():
        hits = logits.argmax(-1) == targets
        acc = hits.float().mean().item()
    return loss, acc
def cayley_menger_vol2(pts):
    """Squared volume of the simplex spanned by each point set.

    Args:
        pts: (B, V, D) tensor holding B sets of V points in D dimensions.
    Returns:
        (B,) tensor of squared (V-1)-simplex volumes from the Cayley-Menger
        determinant. Near-degenerate sets can come out slightly negative
        due to rounding.
    """
    pts = pts.float()
    delta = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = (delta * delta).sum(-1)  # (B, V, V)
    n_vert = sq_dist.shape[1]
    # Border the distance matrix with a row/column of ones and a zero corner.
    cm = F.pad(sq_dist, (1, 0, 1, 0), value=1.0)
    cm[:, 0, 0] = 0.0
    sign = -1.0 if n_vert % 2 else 1.0
    fact = math.factorial(n_vert - 1)
    coeff = sign / ((2.0 ** (n_vert - 1)) * fact * fact)
    return coeff * torch.linalg.det(cm)
def cv_loss(emb, target=0.12, n_samples=16):
    """Differentiable penalty |CV - target| on random pentachoron volumes.

    Samples ``n_samples`` random 5-row subsets of ``emb`` (B, D) and takes
    the Cayley-Menger volume of each 4-simplex. Returns, as a scalar tensor,
    the absolute gap between the volumes' coefficient of variation
    (std / mean) and ``target``. Returns a constant 0 tensor when B < 5.

    NOTE(review): this uses torch's unbiased std. cv_metric, whose output is
    passed in as ``target`` during student training, uses numpy's
    population std. The two CV estimates are not on exactly the same scale.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return torch.tensor(0.0, device=emb.device)

    def _sample_volume():
        pick = torch.randperm(n_rows, device=emb.device)[:5]
        vol2 = cayley_menger_vol2(emb[pick].unsqueeze(0))[0]
        return torch.sqrt(F.relu(vol2) + 1e-12)

    volumes = torch.stack([_sample_volume() for _ in range(n_samples)])
    cv = volumes.std() / (volumes.mean() + 1e-8)
    return (cv - target).abs()
def cv_metric(emb, n=200):
    """Coefficient of variation of random pentachoron volumes (diagnostic).

    Draws ``n`` random 5-row subsets of ``emb`` (B, D) and computes each
    4-simplex volume via the Cayley-Menger determinant. Returns std / mean
    of those volumes (population std) as a Python float.

    Degenerate samples, whose squared volume is not positive, are skipped.
    Returns 0.0 when ``emb`` has fewer than 5 rows or fewer than 10
    non-degenerate samples remain.
    """
    B = emb.shape[0]
    if B < 5 or n <= 0:
        return 0.0
    # Same randperm draws as a per-sample loop; one batched determinant and
    # a single device->host transfer instead of n separate .item() syncs.
    idx = torch.stack([torch.randperm(B, device=emb.device)[:5] for _ in range(n)])
    with torch.no_grad():
        v2 = cayley_menger_vol2(emb[idx]).cpu()
    # Filter on the squared volume. The +1e-12 epsilon below makes every
    # volume >= 1e-6, so filtering on the volume itself would never drop
    # degenerate samples.
    v2 = v2[v2 > 0]
    if v2.numel() < 10:
        return 0.0
    vols = torch.sqrt(v2 + 1e-12).numpy().astype(np.float64)
    return float(vols.std() / (vols.mean() + 1e-8))
# ══════════════════════════════════════════════════════════════════
# EXTRACTION + ALIGNMENT
# ══════════════════════════════════════════════════════════════════
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Return cov^(-1/2) for a symmetric matrix via eigendecomposition.

    Eigenvalues are floored at ``eps`` before the inverse square root, so
    rank-deficient or slightly indefinite inputs stay finite.
    """
    eigvals, eigvecs = torch.linalg.eigh(cov)
    inv_root = eigvals.clamp(min=eps).rsqrt()
    # V @ diag(s) @ V^T, with the diagonal product written as column scaling.
    return (eigvecs * inv_root) @ eigvecs.T
def procrustes_align(source, target, n_align=5000):
    """Fit a whitened orthogonal Procrustes map from ``source`` to ``target``.

    Uses the first min(n_align, len(source), len(target)) rows. Both sides
    are centered and whitened with the inverse square root of their own
    covariance, then row-normalized. The rotation R (from an SVD) is chosen
    so that ``src_w @ R.T`` best matches ``tgt_w``. The cosine diagnostics
    require source and target rows of the same width.

    Returns:
        dict with "rotation", "source_mean", "source_whitener",
        "target_unwhitener" (pseudo-inverse of the target whitener),
        "cos_before" and "cos_after" (mean row cosines, Python floats).
    """
    n = min(n_align, source.shape[0], target.shape[0])
    src = source[:n].float()
    tgt = target[:n].float()
    src_mean = src.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt.mean(0, keepdim=True)
    cos_before = F.cosine_similarity(src_c, tgt_c, dim=-1).mean().item()

    denom = max(n - 1, 1)
    src_whiten = symmetric_inv_sqrt((src_c.T @ src_c) / denom)
    tgt_whiten = symmetric_inv_sqrt((tgt_c.T @ tgt_c) / denom)
    src_w = F.normalize(src_c @ src_whiten, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whiten, dim=-1)

    u, _, vt = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    rotation = u @ vt
    cos_after = F.cosine_similarity(src_w @ rotation.T, tgt_w, dim=-1).mean().item()

    result = {}
    result["rotation"] = rotation
    result["source_mean"] = src_mean.squeeze(0)
    result["source_whitener"] = src_whiten
    result["target_unwhitener"] = torch.linalg.pinv(tgt_whiten)
    result["cos_before"] = cos_before
    result["cos_after"] = cos_after
    return result
def apply_align(emb, a):
    """Map ``emb`` into the target space using a procrustes_align() result.

    Steps: subtract the source mean, whiten, rotate, then un-whiten with the
    target statistics. The target mean is not added back.
    """
    centered = emb.float() - a["source_mean"]
    whitened = centered @ a["source_whitener"]
    rotated = whitened @ a["rotation"].T
    return rotated @ a["target_unwhitener"]
# ══════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════
def run():
    """Run the full rapid-prototype pipeline end to end.

    Phase 0  : stream up to N_SAMPLES captions (> 50 chars) and mean-pool
               last-hidden-state embeddings from every model in EXPERTS.
    Phase 0b : align each expert to the BERT space (procrustes_align /
               apply_align) and average into an L2-normalized consensus.
    Phase 1  : train MiniStudent on the consensus (InfoNCE + MSE + CV loss);
               writes mini_student.pt.
    Phase 2  : freeze the student and train AlignmentBank on its outputs;
               writes alignment_bank.pt.
    Phase 3  : report passthrough integrity and geo-context statistics.
    Phase 4  : train a tiny 3-class pair classifier with vs. without bank
               features and report train/val accuracy.

    Loads the dataset and models through ``datasets`` / ``transformers``
    and writes both checkpoints to the current working directory.

    Raises:
        RuntimeError: if no more than N_VAL captions could be collected.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    N_SAMPLES = 20000
    N_VAL = 2000      # last N_VAL captions are held out for validation
    MAX_LEN = 128
    BATCH = 256
    EMB_DIM = 768     # expert / consensus / student embedding width
    D_BANK = 64       # width of the bank's geometric context

    # ── Phase 0: Extract ──
    print(f"\n{'='*65}")
    print("PHASE 0: EXTRACTION")
    print(f"{'='*65}")
    from datasets import load_dataset
    from transformers import AutoModel, AutoTokenizer
    ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                      split="train", streaming=True)
    captions = []
    for row in ds:
        cap = row.get("caption_llava", "")
        # Keep only reasonably long string captions.
        if isinstance(cap, str) and len(cap) > 50:
            captions.append(cap)
        if len(captions) >= N_SAMPLES:
            break
    print(f" Captions: {len(captions):,}")
    # Derive the split from what was actually collected; fail before the
    # expensive extraction if there is nothing left to train on.
    n_total = len(captions)
    n_train = n_total - N_VAL
    if n_train <= 0:
        raise RuntimeError(
            f"Need more than {N_VAL} captions for a train/val split, got {n_total}")

    embeds = {}
    for model_name, short, max_len in EXPERTS:
        print(f"\n Extracting: {short}...")
        model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        all_emb = []
        with torch.no_grad():
            for i in tqdm(range(0, len(captions), 128), desc=f" {short}"):
                batch = captions[i:i + 128]
                inputs = tokenizer(batch, max_length=max_len, padding=True,
                                   truncation=True, return_tensors="pt").to(DEVICE)
                out = model(**inputs)
                # Attention-masked mean pooling over the last hidden layer.
                m = inputs.attention_mask.unsqueeze(-1).float()
                pooled = (out.last_hidden_state * m).sum(1) / m.sum(1).clamp(min=1)
                all_emb.append(pooled.cpu())
        embeds[short] = torch.cat(all_emb)
        print(f" Shape: {embeds[short].shape}")
        del model
        gc.collect()
        torch.cuda.empty_cache()

    # ── Phase 0b: Align + Consensus ──
    print(f"\n{'='*65}")
    print("PHASE 0b: PROCRUSTES ALIGNMENT")
    print(f"{'='*65}")
    ref = "bert"
    names = [s for _, s, _ in EXPERTS]
    procrustes_results = {}
    aligned = {}
    for name in names:
        info = procrustes_align(embeds[name], embeds[ref])
        procrustes_results[name] = info
        aligned[name] = apply_align(embeds[name], info)
        print(f" {name:10s}: cos {info['cos_before']:.4f} → {info['cos_after']:.4f}")
    consensus = F.normalize(sum(aligned[n] for n in names) / len(names), dim=-1)
    print(f" Consensus: {consensus.shape}")
    for name in names:
        cos = F.cosine_similarity(consensus[:2000], aligned[name][:2000], dim=-1).mean().item()
        print(f" cos(consensus, {name}): {cos:.4f}")
    consensus_cv = cv_metric(consensus[:2000].to(DEVICE))
    print(f" Consensus CV: {consensus_cv:.4f}")
    del embeds, aligned
    gc.collect()
    torch.cuda.empty_cache()

    # ── Phase 1: Train Student ──
    print(f"\n{'='*65}")
    print(f"PHASE 1: TRAIN STUDENT ({len(names)} experts, {n_total // 1000}K captions)")
    print(f"{'='*65}")
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    tokens = tokenizer(captions, max_length=MAX_LEN, padding="max_length",
                       truncation=True, return_tensors="pt")
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    train_ids = input_ids[:n_train].to(DEVICE)
    train_mask = attention_mask[:n_train].to(DEVICE)
    train_targets = consensus[:n_train].to(DEVICE)
    val_ids = input_ids[n_train:].to(DEVICE)
    val_mask = attention_mask[n_train:].to(DEVICE)
    val_targets = consensus[n_train:].to(DEVICE)
    student = MiniStudent(
        vocab_size=tokenizer.vocab_size, max_len=MAX_LEN,
        d_model=256, n_heads=4, n_layers=4, d_ff=1024,
        output_dim=EMB_DIM, dropout=0.1, pad_token_id=tokenizer.pad_token_id
    ).to(DEVICE)
    n_params = sum(p.numel() for p in student.parameters())
    print(f" Student: {n_params:,} params")
    optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=0.01)
    for epoch in range(5):
        student.train()
        perm = torch.randperm(n_train, device=DEVICE)
        t_loss, t_acc, t_cos, n = 0, 0, 0, 0
        t0 = time.time()
        for i in range(0, n_train, BATCH):
            idx = perm[i:i + BATCH]
            if len(idx) < 8:
                continue  # too few in-batch negatives for InfoNCE
            emb = student(train_ids[idx], train_mask[idx])
            tgt = train_targets[idx]
            l_nce, acc = infonce(emb, tgt)
            l_mse = F.mse_loss(emb, tgt)
            l_cv = cv_loss(emb, target=consensus_cv)
            loss = l_nce + l_mse + 0.1 * l_cv
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            with torch.no_grad():
                cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
            t_loss += loss.item()
            t_acc += acc
            t_cos += cos
            n += 1
        elapsed = time.time() - t0
        d = max(n, 1)
        student.eval()
        with torch.no_grad():
            v_emb = student(val_ids, val_mask)
            _, v_acc = infonce(v_emb[:1000], val_targets[:1000])
            v_cos = F.cosine_similarity(v_emb, val_targets, dim=-1).mean().item()
            v_cv = cv_metric(v_emb[:1000])
        print(f" E{epoch+1}: {elapsed:.0f}s loss={t_loss/d:.4f} "
              f"t_acc={t_acc/d:.3f} t_cos={t_cos/d:.3f} "
              f"v_acc={v_acc:.3f} v_cos={v_cos:.3f} v_cv={v_cv:.3f}")
    # Save student
    torch.save(student.state_dict(), "mini_student.pt")
    print(f"\n Student saved. v_cos={v_cos:.3f}, v_cv={v_cv:.3f}")

    # ── Phase 2: Train Alignment Bank ──
    print(f"\n{'='*65}")
    print("PHASE 2: TRAIN ALIGNMENT BANK (student frozen)")
    print(f"{'='*65}")
    student.eval()
    for p in student.parameters():
        p.requires_grad = False
    # Pre-encode everything once through the frozen student.
    print(" Pre-encoding through frozen student...")
    with torch.no_grad():
        student_embs = torch.cat([
            student(train_ids[i:i + 512], train_mask[i:i + 512])
            for i in range(0, n_train, 512)
        ])  # (n_train, EMB_DIM)
        val_student_embs = student(val_ids, val_mask)
    print(f" Student embeddings: {student_embs.shape}")
    bank = AlignmentBank(
        d_embed=EMB_DIM, n_experts=len(EXPERTS),
        n_anchors=128, d_bank=D_BANK
    ).to(DEVICE)
    bank.init_from_procrustes(procrustes_results, names, consensus[:n_train])
    bank_params = sum(p.numel() for p in bank.parameters())
    print(f" Bank: {bank_params:,} params")
    bank_opt = torch.optim.AdamW(bank.parameters(), lr=1e-3, weight_decay=0.01)
    BANK_EPOCHS = 20
    BANK_BATCH = 256
    for epoch in range(BANK_EPOCHS):
        bank.train()
        perm = torch.randperm(n_train, device=DEVICE)
        total_loss = 0
        stats = {"expert_agreement": 0, "rotation_ortho": 0,
                 "anchor_spread": 0, "bank_cv": 0}
        n = 0
        t0 = time.time()
        for i in range(0, n_train, BANK_BATCH):
            idx = perm[i:i + BANK_BATCH]
            if len(idx) < 16:
                continue
            _, aux = bank(student_embs[idx])
            loss = bank.bank_loss(aux, cv_target=consensus_cv + 0.02)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(bank.parameters(), 1.0)
            bank_opt.step()
            bank_opt.zero_grad(set_to_none=True)
            total_loss += loss.item()
            for k in stats:
                if k in aux:
                    v = aux[k]
                    stats[k] += v.item() if torch.is_tensor(v) else v
            n += 1
        elapsed = time.time() - t0
        d = max(n, 1)
        # Validation
        bank.eval()
        with torch.no_grad():
            _, v_aux = bank(val_student_embs)
            v_loss = bank.bank_loss(v_aux, cv_target=consensus_cv + 0.02).item()
        print(f" E{epoch+1:2d}: {elapsed:.0f}s loss={total_loss/d:.4f} "
              f"v_loss={v_loss:.4f} "
              f"expert_agr={stats['expert_agreement']/d:.5f} "
              f"ortho={stats['rotation_ortho']/d:.5f} "
              f"spread={stats['anchor_spread']/d:.5f} "
              f"cv={stats['bank_cv']/d:.4f} "
              f"anchor_max={v_aux['anchor_max_cos']:.3f} "
              f"expert_cos={v_aux['expert_cos_mean']:.3f}±{v_aux['expert_cos_std']:.3f}")
    torch.save(bank.state_dict(), "alignment_bank.pt")

    # ── Phase 3: Verify Geometry ──
    print(f"\n{'='*65}")
    print("PHASE 3: GEOMETRIC VERIFICATION")
    print(f"{'='*65}")
    bank.eval()
    with torch.no_grad():
        enriched_val, _ = bank(val_student_embs)
        original_part = enriched_val[:, :EMB_DIM]  # passthrough embedding
        geo_context = enriched_val[:, EMB_DIM:]    # D_BANK-dim annotation
        # The passthrough part must equal the student embedding.
        passthrough_cos = F.cosine_similarity(
            original_part[:100], val_student_embs[:100], dim=-1).mean().item()
        geo_cv = cv_metric(F.normalize(geo_context[:1000], dim=-1))
        # Participation-ratio effective dimension of the centered context.
        ctx = geo_context[:1000].float()
        sv2 = torch.linalg.svdvals(ctx - ctx.mean(0)).pow(2)
        geo_eff_dim = (sv2.sum() ** 2) / (sv2.pow(2).sum() + 1e-12)
    print(f" Passthrough integrity: {passthrough_cos:.6f} (should be ~1.000)")
    print(f" Geo context CV: {geo_cv:.4f}")
    print(f" Geo context eff_dim: {geo_eff_dim:.1f}")
    print(f" Geo context shape: {geo_context.shape}")

    # ── Phase 4: Quick Classifier Test ──
    print(f"\n{'='*65}")
    print("PHASE 4: CLASSIFIER STABILITY TEST")
    print(f"{'='*65}")
    # Synthetic 3-class task from cosine terciles of random pairs:
    # 0 = similar (top third), 1 = medium, 2 = different (bottom third).
    with torch.no_grad():
        embs = val_student_embs[:1000]
        n_pool = embs.shape[0]
        n_pairs = 3000
        # Distinct pairs: offsetting idx_a by 1..n_pool-1 (mod n_pool)
        # guarantees idx_b != idx_a. Previously self-pairs were forced to
        # cos = -1 and mislabelled "different".
        idx_a = torch.randint(0, n_pool, (n_pairs,), device=embs.device)
        idx_b = (idx_a + torch.randint(1, n_pool, (n_pairs,), device=embs.device)) % n_pool
        # Student embeddings are L2-normalized, so the dot product is the cosine.
        pair_cos = (embs[idx_a] * embs[idx_b]).sum(-1)
        sorted_cos, _ = pair_cos.sort()
        t1 = sorted_cos[n_pairs // 3].item()
        t2 = sorted_cos[2 * n_pairs // 3].item()
        labels = torch.ones(n_pairs, dtype=torch.long, device=embs.device)  # medium
        labels[pair_cos > t2] = 0   # similar
        labels[pair_cos <= t1] = 2  # different
        enriched_a, _ = bank(embs[idx_a])
        enriched_b, _ = bank(embs[idx_b])

    # Train a tiny classifier: with bank vs without bank.
    n_clf_train = (4 * n_pairs) // 5
    for mode in ["with_bank", "without_bank"]:
        if mode == "with_bank":
            features = torch.cat([enriched_a, enriched_b], dim=-1)
        else:
            features = torch.cat([embs[idx_a], embs[idx_b]], dim=-1)
        feat_dim = features.shape[-1]
        clf = nn.Sequential(
            nn.Linear(feat_dim, 128), nn.GELU(),
            nn.Linear(128, 3)
        ).to(DEVICE)
        clf_opt = torch.optim.Adam(clf.parameters(), lr=1e-3)
        train_f = features[:n_clf_train].detach()
        train_l = labels[:n_clf_train]
        val_f = features[n_clf_train:].detach()
        val_l = labels[n_clf_train:]
        for _ in range(20):
            clf.train()
            logits = clf(train_f)
            loss = F.cross_entropy(logits, train_l)
            loss.backward()
            clf_opt.step()
            clf_opt.zero_grad()
        clf.eval()
        with torch.no_grad():
            val_logits = clf(val_f)
            val_acc = (val_logits.argmax(-1) == val_l).float().mean().item()
            train_logits = clf(train_f)
            train_acc = (train_logits.argmax(-1) == train_l).float().mean().item()
        print(f" {mode:15s}: train_acc={train_acc:.3f} val_acc={val_acc:.3f} "
              f"gap={train_acc-val_acc:.3f}")

    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
    print(f"\n Student: mini_student.pt")
    print(f" Bank: alignment_bank.pt")
    print(f" Consensus CV: {consensus_cv:.4f}")
    print(f" Student v_cos: {v_cos:.3f}")


if __name__ == "__main__":
    run()