"""
Node-JEPA V4.2 β€” Round 2: ogbn-arxiv + CiteSeer Fix
=====================================================
Round 1 results:
Cora: 82.8% (rank=65) βœ… healthy
PubMed: 80.1% (rank=62) βœ… healthy
CiteSeer: 53.9% (rank=1) ❌ collapsed
ogbn-arxiv: SKIPPED (OGB install issue)
This script:
1. Runs ogbn-arxiv with proper loading (to_undirected, StandardScaler on features)
2. Tests 4 CiteSeer fix hypotheses to break the collapse
Key fixes from research:
- ogbn-arxiv: use to_undirected() (it's directed!), StandardScaler on features (GraphMAE does this)
- ogbn-arxiv: BGRL uses NO feature dropout on arxiv (p_f=0.0), edge dropout only
- CiteSeer collapse likely caused by: GAT encoder, high feat dim (3703), low LR, reconstruction on sparse features
"""
import math, copy, os, sys, json, time, gc, functools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
import warnings
warnings.filterwarnings("ignore")
# Fix PyTorch 2.6+ torch.load(weights_only=True) breaking OGB/PyG dataset loading
_original_torch_load = torch.load
@functools.wraps(_original_torch_load)
def _patched_torch_load(*args, **kwargs):
if 'weights_only' not in kwargs:
kwargs['weights_only'] = False
return _original_torch_load(*args, **kwargs)
torch.load = _patched_torch_load
print("Patched torch.load for PyTorch 2.6+ compatibility")
from torch_geometric.nn import GATConv, GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import to_undirected
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
try:
from ogb.nodeproppred import PygNodePropPredDataset
HAS_OGB = True
print("OGB loaded successfully")
except Exception as e:
HAS_OGB = False
print(f"OGB import failed: {e}")
try: import trackio; HAS_TRACKIO = True
except ImportError: HAS_TRACKIO = False
try: from huggingface_hub import HfApi; HAS_HF_HUB = True
except ImportError: HAS_HF_HUB = False
# ═══════════════════════════════════════════════════════════════
# Model (same as V4.2 baseline)
# ═══════════════════════════════════════════════════════════════
def decorrelation_loss(z):
N, D = z.shape
zc = z - z.mean(0); zs = zc / (zc.std(0) + 1e-8)
R = (zs.T @ zs) / N; I = torch.eye(D, device=z.device)
return (R - I).pow(2).sum() / (D * D)
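# decorrelation_loss is a Barlow-Twins-style redundancy penalty: R is the D x D
# feature correlation matrix and the loss is ||R - I||_F^2 / D^2. Sanity check:
# if every row of z is identical (total collapse), zc = 0, so zs ~ 0, R ~ 0,
# and the loss saturates at ||I||_F^2 / D^2 = 1/D.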
def sce_loss(pred, target, alpha=1):
p = F.normalize(pred, p=2, dim=-1); t = F.normalize(target, p=2, dim=-1)
return (1 - (p * t).sum(dim=-1)).pow(alpha).mean()
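# sce_loss is GraphMAE's scaled cosine error, mean((1 - cos(pred, target))^alpha);
# alpha > 1 down-weights already-aligned pairs so the reconstruction gradient
# concentrates on poorly reconstructed nodes.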
class GraphAugmentor:
def __init__(self, feat_mask_p=0.0, edge_drop_p=0.0):
self.feat_mask_p = feat_mask_p; self.edge_drop_p = edge_drop_p
def __call__(self, x, edge_index):
if self.feat_mask_p > 0:
x = x * torch.bernoulli(torch.ones_like(x) * (1 - self.feat_mask_p))
if self.edge_drop_p > 0:
m = torch.bernoulli(torch.ones(edge_index.size(1), device=x.device) * (1 - self.edge_drop_p)).bool()
edge_index = edge_index[:, m]
return x, edge_index
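# Example: GraphAugmentor(feat_mask_p=0.3, edge_drop_p=0.4) zeroes ~30% of
# feature entries and drops ~40% of edges, with fresh Bernoulli masks drawn on
# every call, so the two views produced in NodeJEPA.forward differ each epoch.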
class GATEncoder(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, num_heads=4, dropout=0.0):
super().__init__()
self.num_layers = num_layers
self.reg_token = nn.Parameter(torch.zeros(1, in_dim)); nn.init.trunc_normal_(self.reg_token, std=0.02)
self.input_proj = nn.Linear(in_dim, hidden_dim)
self.input_bn = nn.BatchNorm1d(hidden_dim, momentum=0.01)
self.gat_layers = nn.ModuleList(); self.bns = nn.ModuleList(); self.acts = nn.ModuleList()
for _ in range(num_layers):
self.gat_layers.append(GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads, concat=True, dropout=dropout, add_self_loops=True))
self.bns.append(nn.BatchNorm1d(hidden_dim, momentum=0.01)); self.acts.append(nn.PReLU())
self.output_proj = nn.Linear(hidden_dim, out_dim)
def forward(self, x, edge_index):
N = x.size(0); device = x.device
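        # Append one learnable "register" node (row N) wired bidirectionally to
        # every real node, a global scratch slot in the spirit of ViT register
        # tokens; it is stripped again by the [:N] slice below.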
x_aug = torch.cat([x, self.reg_token.expand(1, -1).to(device)], dim=0)
re = torch.cat([torch.stack([torch.full((N,), N, dtype=torch.long, device=device), torch.arange(N, device=device)]),
torch.stack([torch.arange(N, device=device), torch.full((N,), N, dtype=torch.long, device=device)])], dim=1)
ei = torch.cat([edge_index, re], dim=1)
h = F.prelu(self.input_bn(self.input_proj(x_aug)), torch.tensor(0.25, device=device))
for i in range(self.num_layers): h = self.acts[i](self.bns[i](self.gat_layers[i](h, ei))) + h
return self.output_proj(h)[:N]
class GCNEncoder(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, dropout=0.0):
super().__init__()
self.num_layers = num_layers
self.input_proj = nn.Linear(in_dim, hidden_dim)
self.input_bn = nn.BatchNorm1d(hidden_dim, momentum=0.01)
self.gcn_layers = nn.ModuleList(); self.bns = nn.ModuleList(); self.acts = nn.ModuleList()
for _ in range(num_layers):
self.gcn_layers.append(GCNConv(hidden_dim, hidden_dim))
self.bns.append(nn.BatchNorm1d(hidden_dim, momentum=0.01)); self.acts.append(nn.PReLU())
self.output_proj = nn.Linear(hidden_dim, out_dim)
def forward(self, x, edge_index):
h = F.prelu(self.input_bn(self.input_proj(x)), torch.tensor(0.25, device=x.device))
for i in range(self.num_layers): h = self.acts[i](self.bns[i](self.gcn_layers[i](h, edge_index))) + h
return self.output_proj(h)
class MLPPredictor(nn.Module):
def __init__(self, dim, hidden=512):
super().__init__()
self.net = nn.Sequential(nn.Linear(dim, hidden), nn.BatchNorm1d(hidden, momentum=0.01), nn.PReLU(), nn.Linear(hidden, dim))
def forward(self, x): return self.net(x)
class FeatureDecoder(nn.Module):
def __init__(self, ed, hd, od):
super().__init__()
self.net = nn.Sequential(nn.Linear(ed, hd), nn.BatchNorm1d(hd, momentum=0.01), nn.PReLU(), nn.Linear(hd, od))
def forward(self, x): return self.net(x)
class NodeMasker(nn.Module):
def __init__(self, feat_dim, mask_rate=0.5, replace_rate=0.05):
super().__init__()
self.mask_rate = mask_rate; self.replace_rate = replace_rate
self.mask_token = nn.Parameter(torch.zeros(1, feat_dim)); nn.init.trunc_normal_(self.mask_token, std=0.02)
def forward(self, x):
N = x.size(0); nm = max(1, int(N * self.mask_rate))
perm = torch.randperm(N, device=x.device)
mask = torch.zeros(N, dtype=torch.bool, device=x.device); mask[perm[:nm]] = True
xm = x.clone(); nr = max(0, int(nm * self.replace_rate))
xm[perm[:nm - nr]] = self.mask_token.expand(nm - nr, -1)
if nr > 0: xm[perm[nm - nr:nm]] = x[torch.randint(0, N, (nr,), device=x.device)]
return xm, mask
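# GraphMAE-style input masking. Example: with N=1000, mask_rate=0.5,
# replace_rate=0.05, 500 nodes are selected as targets; 475 get the learned
# mask_token and 25 get features copied from random other nodes. The returned
# boolean mask marks all 500 as prediction targets.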
class NodeJEPA(nn.Module):
def __init__(self, in_dim, hidden_dim=512, out_dim=256, num_layers=2, num_heads=4,
predictor_hidden=512, mask_rate=0.5, replace_rate=0.05,
ema_tau_base=0.99, aug1_feat_p=0.0, aug1_edge_p=0.0,
aug2_feat_p=0.0, aug2_edge_p=0.0, lambda_dec=0.0, mu_rec=0.0,
sce_alpha=1, encoder_type='gat', dropout=0.0):
super().__init__()
self.lambda_dec = lambda_dec; self.mu_rec = mu_rec; self.sce_alpha = sce_alpha
if encoder_type == 'gat':
self.online_encoder = GATEncoder(in_dim, hidden_dim, out_dim, num_layers, num_heads, dropout)
else:
self.online_encoder = GCNEncoder(in_dim, hidden_dim, out_dim, num_layers, dropout)
self.target_encoder = copy.deepcopy(self.online_encoder)
for p in self.target_encoder.parameters(): p.requires_grad = False
self.predictor = MLPPredictor(out_dim, predictor_hidden)
self.masker = NodeMasker(in_dim, mask_rate, replace_rate)
self.aug1 = GraphAugmentor(aug1_feat_p, aug1_edge_p)
self.aug2 = GraphAugmentor(aug2_feat_p, aug2_edge_p)
self.feat_decoder = FeatureDecoder(out_dim, hidden_dim, in_dim) if mu_rec > 0 else None
@torch.no_grad()
def update_target(self, tau):
for po, pt in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
pt.data.mul_(tau).add_(po.data, alpha=1 - tau)
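    # EMA update: pt <- tau * pt + (1 - tau) * po. tau is annealed from
    # ema_tau_base toward 1.0 (see cosine_scheduler), so the target encoder
    # tracks the online encoder quickly early on and is nearly frozen late.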
def _jepa_loss(self, x, ei, aug_o, aug_t, do_mask=True):
xo, eio = aug_o(x, ei); xt, eit = aug_t(x, ei)
if do_mask: xo, mask = self.masker(xo)
else: mask = torch.ones(x.size(0), dtype=torch.bool, device=x.device)
ho = self.online_encoder(xo, eio)
with torch.no_grad(): ht = self.target_encoder(xt, eit)
pred = F.normalize(self.predictor(ho[mask]), dim=-1)
tgt = F.normalize(ht[mask], dim=-1)
return 2 - 2 * (pred * tgt).sum(-1).mean(), ho, mask
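    # One JEPA pass: the online encoder sees one augmented (optionally masked)
    # view, the frozen target encoder sees the other, and the per-node loss on
    # the masked set is the BYOL-style 2 - 2*cos(predictor(h_on), h_tgt).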
def forward(self, data):
x, ei = data.x, data.edge_index
l1, ho1, m1 = self._jepa_loss(x, ei, self.aug1, self.aug2, True)
l2, ho2, m2 = self._jepa_loss(x, ei, self.aug2, self.aug1, False)
lj = (l1 + l2) / 2
ld = self.lambda_dec * (decorrelation_loss(ho1) + decorrelation_loss(ho2)) / 2 if self.lambda_dec > 0 else torch.tensor(0.0, device=x.device)
lr = torch.tensor(0.0, device=x.device)
if self.mu_rec > 0 and self.feat_decoder is not None:
lr = self.mu_rec * sce_loss(self.feat_decoder(ho1[m1]), x[m1], self.sce_alpha)
return lj + ld + lr, lj, ld, lr
@torch.no_grad()
def encode(self, data):
self.online_encoder.eval()
return self.online_encoder(data.x, data.edge_index)
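# Minimal per-epoch usage sketch (illustrative only; run() below is the real
# driver and follows this same order of operations):
#   model = NodeJEPA(in_dim=data.num_features).to(device)
#   loss, lj, ld, lrec = model(data)
#   opt.zero_grad(); loss.backward()
#   torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0); opt.step()
#   model.update_target(tau)          # EMA step after each optimizer step
#   embs = model.encode(data)         # eval-mode embeddings for the linear probe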
def cosine_scheduler(base, final, steps):
return [final - (final - base) * (math.cos(math.pi * s / max(steps - 1, 1)) + 1) / 2 for s in range(steps)]
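# e.g. cosine_scheduler(0.99, 1.0, 5) -> [0.9900, 0.9915, 0.9950, 0.9985, 1.0000]:
# a half-cosine ramp from `base` at step 0 to `final` at the last step.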
# ═══════════════════════════════════════════════════════════════
# Evaluation functions
# ═══════════════════════════════════════════════════════════════
def compute_geometry(embs, labels=None, n=2000):
N, D = embs.shape
idx = np.random.choice(N, min(N, n), replace=False) if N > n else np.arange(N)
es = embs[idx]; ls = labels[idx] if labels is not None else None
ec = es - es.mean(0)
try: _, S, _ = np.linalg.svd(ec, full_matrices=False)
    except np.linalg.LinAlgError: S = np.ones(min(es.shape))
tv = (S**2).sum()
if tv < 1e-12: rk95, rk99 = 1, 1
else:
ev = np.cumsum(S**2) / tv
rk95 = int(np.searchsorted(ev, 0.95) + 1); rk99 = int(np.searchsorted(ev, 0.99) + 1)
iso = float(S[-1] / S[0]) if len(S) > 1 and S[0] > 1e-12 else 0.0
norms = np.linalg.norm(es, axis=1, keepdims=True) + 1e-8; en = es / norms
cm = en @ en.T; np.fill_diagonal(cm, 0); nn_ = es.shape[0]
avg_cos = cm.sum() / (nn_ * (nn_ - 1))
align = 0.0
if ls is not None:
sims = []
for c in np.unique(ls):
mc = ls == c
if mc.sum() > 1:
ec_ = en[mc]; cc = ec_ @ ec_.T; np.fill_diagonal(cc, 0)
sims.extend(cc[np.triu_indices(ec_.shape[0], k=1)].tolist())
if sims: align = float(np.mean(sims))
    # For unit vectors, ||u - v||^2 = 2 - 2*cos(u, v), so reuse cm instead of
    # materialising the O(n^2 * D) pairwise-difference tensor.
    sq = np.clip(2.0 - 2.0 * cm, 0.0, None)
    um = ~np.eye(nn_, dtype=bool)
    unif = float(np.log(np.exp(-2 * sq[um]).mean() + 1e-12))
return {"rank95": rk95, "rank99": rk99, "isotropy": iso, "alignment": align,
"uniformity": unif, "avg_cos": float(avg_cos), "std_dim": float(embs.std(0).mean())}
def linear_eval(embs, labels, tr_mask, val_mask, te_mask, seeds=20):
results = {}
    for nm, fn in [("raw", lambda e: e), ("l2", lambda e: e / (np.linalg.norm(e, axis=1, keepdims=True) + 1e-8)), ("std", lambda e: StandardScaler().fit_transform(e))]:
e = fn(embs.copy()); Xtr, ytr = e[tr_mask], labels[tr_mask]; Xv, yv = e[val_mask], labels[val_mask]; Xte, yte = e[te_mask], labels[te_mask]
accs = []
for s in range(seeds):
bv, bt = -1, -1
for C in [0.01, 0.1, 1.0, 10.0, 100.0]:
clf = LogisticRegression(max_iter=2000, solver='lbfgs', random_state=s, C=C); clf.fit(Xtr, ytr)
v = clf.score(Xv, yv)
if v > bv: bv = v; bt = clf.score(Xte, yte)
accs.append(bt * 100)
results[nm] = {"mean": float(np.mean(accs)), "std": float(np.std(accs))}
return results
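# linear_eval: for each normalization (raw / row-L2 / StandardScaler), sweep C
# over {0.01, ..., 100} by validation accuracy, record the test accuracy at the
# best-val C, and average over `seeds` logistic-regression seeds.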
def get_lr(ep, warmup, total, base_lr, min_lr=1e-6):
if ep < warmup: return base_lr * (ep + 1) / warmup
p = (ep - warmup) / max(total - warmup, 1)
return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * p)) / 2
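# get_lr: linear warmup from base_lr/warmup up to base_lr over `warmup` epochs,
# then cosine decay down to min_lr by the final epoch.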
# ═══════════════════════════════════════════════════════════════
# Train + eval one config
# ═══════════════════════════════════════════════════════════════
def run(name, data, device, cfg, seed=42):
torch.manual_seed(seed); np.random.seed(seed)
if device.type == 'cuda': torch.cuda.manual_seed(seed)
epochs, lr, warmup = cfg['epochs'], cfg['lr'], cfg['warmup']
model = NodeJEPA(
in_dim=data.num_features, hidden_dim=cfg['hidden_dim'], out_dim=cfg['out_dim'],
num_layers=cfg['num_layers'], num_heads=cfg.get('num_heads', 4),
predictor_hidden=cfg['hidden_dim'], mask_rate=cfg['mask_rate'],
ema_tau_base=cfg['ema_tau_base'],
aug1_feat_p=cfg.get('aug1_feat_p', 0), aug1_edge_p=cfg.get('aug1_edge_p', 0),
aug2_feat_p=cfg.get('aug2_feat_p', 0), aug2_edge_p=cfg.get('aug2_edge_p', 0),
lambda_dec=cfg.get('lambda_dec', 0), mu_rec=cfg.get('mu_rec', 0),
encoder_type=cfg.get('encoder_type', 'gat'),
).to(device)
opt = AdamW([p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=cfg.get('wd', 1e-5))
ema = cosine_scheduler(cfg['ema_tau_base'], 1.0, epochs)
t0 = time.time(); best_acc, best_ep = 0, 0; ei = max(1, epochs // 10); li = max(1, epochs // 20)
for ep in range(epochs):
model.train(); clr = get_lr(ep, warmup, epochs, lr)
for pg in opt.param_groups: pg['lr'] = clr
loss, lj, ld, lr_ = model(data); opt.zero_grad(); loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0); opt.step(); model.update_target(ema[ep])
if (ep+1) % li == 0 or ep == 0:
print(f" [{name}] Ep {ep+1:4d}/{epochs} | L={loss.item():.4f} (j={lj.item():.4f} d={ld.item():.4f} r={lr_.item():.4f}) | LR={clr:.1e} | {time.time()-t0:.0f}s"); sys.stdout.flush()
if (ep+1) % ei == 0:
model.eval()
with torch.no_grad(): e = model.encode(data).cpu().numpy()
ln = data.y.cpu().numpy();
if len(ln.shape) > 1: ln = ln.squeeze()
Xtr = e[data.train_mask.cpu().numpy()]; ytr = ln[data.train_mask.cpu().numpy()]
Xte = e[data.test_mask.cpu().numpy()]; yte = ln[data.test_mask.cpu().numpy()]
clf = LogisticRegression(max_iter=2000, solver='lbfgs', C=1.0); clf.fit(Xtr, ytr)
acc = clf.score(Xte, yte) * 100
if acc > best_acc: best_acc = acc; best_ep = ep + 1
print(f" [{name}] EVAL @{ep+1}: {acc:.1f}% | BEST={best_acc:.1f}%@{best_ep}"); sys.stdout.flush()
# Final eval
model.eval()
with torch.no_grad(): embs = model.encode(data).cpu().numpy()
labels = data.y.cpu().numpy()
if len(labels.shape) > 1: labels = labels.squeeze()
tr = data.train_mask.cpu().numpy(); va = data.val_mask.cpu().numpy() if hasattr(data, 'val_mask') and data.val_mask is not None else tr; te = data.test_mask.cpu().numpy()
probe = linear_eval(embs, labels, tr, va, te, seeds=20)
geom = compute_geometry(embs, labels)
best = max(probe[k]['mean'] for k in probe); bn = max(probe, key=lambda k: probe[k]['mean'])
print(f" [{name}] FINAL: raw={probe['raw']['mean']:.1f} l2={probe['l2']['mean']:.1f} std={probe['std']['mean']:.1f} | BEST={best:.1f}%({bn})")
print(f" [{name}] GEOM: rank95={geom['rank95']} iso={geom['isotropy']:.4f} align={geom['alignment']:.4f} unif={geom['uniformity']:.4f} cos={geom['avg_cos']:.4f}")
return {"name": name, "accuracy": probe, "best": best, "best_norm": bn, "geometry": geom, "time": time.time()-t0}
# ═══════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════
def main():
torch.manual_seed(42); np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
if device.type == "cuda":
print(f" GPU: {torch.cuda.get_device_name(0)}")
try: print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        except Exception: pass
torch.cuda.manual_seed(42)
tracker = None
if HAS_TRACKIO:
try: trackio.init(name="node-jepa-round2", project="node-jepa"); tracker = True; print("Trackio initialized")
        except Exception: pass
all_results = {}
# ═══════════════════════════════════════════════════════════
# PART 1: ogbn-arxiv
# ═══════════════════════════════════════════════════════════
if HAS_OGB:
print(f"\n{'='*80}")
print(f"PART 1: ogbn-arxiv (169K nodes, SSL SOTA: 71.87%)")
print(f"{'='*80}")
try:
dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='/tmp/ogb_data')
data = dataset[0]
si = dataset.get_idx_split()
# Make undirected (ogbn-arxiv is directed!)
print(f" Original: {data.num_nodes} nodes, {data.edge_index.size(1)} edges (directed)")
data.edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)
print(f" Undirected: {data.edge_index.size(1)} edges")
# StandardScaler on features (GraphMAE does this, improves stability)
feat_np = data.x.numpy()
scaler = StandardScaler()
data.x = torch.tensor(scaler.fit_transform(feat_np), dtype=torch.float32)
print(f" Features: {data.x.size(1)} dim, StandardScaled")
# Masks
N = data.num_nodes
tm = torch.zeros(N, dtype=torch.bool); tm[si['train']] = True
vm = torch.zeros(N, dtype=torch.bool); vm[si['valid']] = True
tsm = torch.zeros(N, dtype=torch.bool); tsm[si['test']] = True
data.train_mask = tm; data.val_mask = vm; data.test_mask = tsm
data.y = data.y.squeeze()
data = data.to(device)
print(f" Train: {tm.sum()}, Val: {vm.sum()}, Test: {tsm.sum()}, Classes: {data.y.max()+1}")
# Config: BGRL-matched (hidden=256, 2L, no rec β€” fits T4 16GB)
# BGRL paper got 71.64% with this exact size on arxiv
# Larger configs OOM on T4 due to dual-encoder + 169K nodes
arxiv_cfg = {
'epochs': 500, 'lr': 1e-3, 'warmup': 100,
'hidden_dim': 256, 'out_dim': 256,
'num_layers': 2, 'num_heads': 4,
'mask_rate': 0.5, 'ema_tau_base': 0.996, 'wd': 1e-5,
'aug1_feat_p': 0.0, 'aug1_edge_p': 0.4, # NO feat dropout (BGRL finding)
'aug2_feat_p': 0.0, 'aug2_edge_p': 0.4,
'encoder_type': 'gcn',
'lambda_dec': 0.0, 'mu_rec': 0.0, # no rec decoder β€” saves ~30% memory
}
            if device.type == 'cuda':
                torch.cuda.empty_cache()
                print(f" VRAM before training: {torch.cuda.memory_allocated()/1e9:.2f} GB used")
r = run("ogbn-arxiv", data, device, arxiv_cfg, seed=42)
all_results["ogbn-arxiv"] = r
if tracker:
trackio.log({"arxiv/best": r['best'], "arxiv/rank95": r['geometry']['rank95']})
del data, dataset; gc.collect()
if device.type == 'cuda': torch.cuda.empty_cache()
except Exception as e:
print(f" ogbn-arxiv FAILED: {e}")
import traceback; traceback.print_exc()
all_results["ogbn-arxiv"] = {"error": str(e)}
else:
print("\nSKIPPING ogbn-arxiv (OGB not installed)")
# ═══════════════════════════════════════════════════════════
# PART 2: CiteSeer Fix Attempts
# ═══════════════════════════════════════════════════════════
print(f"\n{'='*80}")
print(f"PART 2: CiteSeer Collapse Fix (4 hypotheses)")
print(f"{'='*80}")
dataset = Planetoid(root="/tmp/pyg_data", name="CiteSeer", transform=NormalizeFeatures())
data = dataset[0].to(device)
print(f" {data.num_nodes} nodes, {data.num_edges} edges, {data.num_features} feat, {dataset.num_classes} classes")
# Baseline (what collapsed in round 1)
base = {
'epochs': 600, 'lr': 3e-4, 'warmup': 200,
'hidden_dim': 512, 'out_dim': 256, 'num_layers': 2, 'num_heads': 4,
'mask_rate': 0.5, 'ema_tau_base': 0.99, 'wd': 1e-5,
'aug1_feat_p': 0.5, 'aug1_edge_p': 0.4, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.4,
'encoder_type': 'gat', 'lambda_dec': 0.0, 'mu_rec': 2.0,
}
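    # Fixes 1-3 below each change exactly one factor relative to `base`
    # (encoder type, reconstruction loss, learning rate); Fix 4 stacks all
    # three, so comparing the four runs isolates what drives the collapse.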
citeseer_results = []
# Fix 1: GCN instead of GAT (Cora/PubMed use GCN and don't collapse)
print(f"\n --- Fix 1: GCN encoder (GAT may amplify collapse) ---")
cfg1 = {**base, 'encoder_type': 'gcn'}
r1 = run("cs_gcn", data, device, cfg1); citeseer_results.append(r1)
# Fix 2: No reconstruction (V3 got 70.7% without it; rec on 3703-dim sparse may hurt)
print(f"\n --- Fix 2: No reconstruction (mu_rec=0, V3 style) ---")
cfg2 = {**base, 'mu_rec': 0.0}
r2 = run("cs_no_rec", data, device, cfg2); citeseer_results.append(r2)
# Fix 3: Higher LR (1e-3 like Cora/PubMed instead of 3e-4)
print(f"\n --- Fix 3: Higher LR (1e-3 instead of 3e-4) ---")
cfg3 = {**base, 'lr': 1e-3}
r3 = run("cs_lr1e3", data, device, cfg3); citeseer_results.append(r3)
# Fix 4: GCN + no rec + higher LR (all fixes combined)
print(f"\n --- Fix 4: GCN + no rec + lr=1e-3 (all fixes) ---")
cfg4 = {**base, 'encoder_type': 'gcn', 'mu_rec': 0.0, 'lr': 1e-3}
r4 = run("cs_all_fix", data, device, cfg4); citeseer_results.append(r4)
all_results["citeseer_fixes"] = citeseer_results
if tracker:
for r in citeseer_results:
trackio.log({f"cs/{r['name']}/best": r['best'], f"cs/{r['name']}/rank95": r['geometry']['rank95']})
# ═══════════════════════════════════════════════════════════
# SUMMARY
# ═══════════════════════════════════════════════════════════
print(f"\n{'='*100}")
print(f"NODE-JEPA V4.2 β€” ROUND 2 RESULTS")
print(f"{'='*100}")
print(f"\n{'Config':<20} {'Best%':>8} {'Norm':>6} {'Rank95':>7} {'Iso':>8} {'Align':>8} {'Unif':>8} {'Cos':>8}")
print("-" * 85)
# ogbn-arxiv
if "ogbn-arxiv" in all_results and "error" not in all_results["ogbn-arxiv"]:
r = all_results["ogbn-arxiv"]; g = r['geometry']
print(f" {'ogbn-arxiv':<18} {r['best']:>7.2f}% {r['best_norm']:>6} {g['rank95']:>7} {g['isotropy']:>7.4f} {g['alignment']:>7.4f} {g['uniformity']:>7.4f} {g['avg_cos']:>7.4f}")
else:
print(f" {'ogbn-arxiv':<18} {'FAILED':>8}")
# CiteSeer fixes
for r in citeseer_results:
g = r['geometry']
collapsed = "❌" if g['rank95'] <= 1 else "βœ…"
print(f" {r['name']:<18} {r['best']:>7.2f}% {r['best_norm']:>6} {collapsed}{g['rank95']:>6} {g['isotropy']:>7.4f} {g['alignment']:>7.4f} {g['uniformity']:>7.4f} {g['avg_cos']:>7.4f}")
# Context
print(f"\n Round 1 reference:")
print(f" {'Cora (R1)':<18} {'82.77':>7}% {'raw':>6} {'65':>7} {'0.0977':>8} {'0.2172':>8} {'-3.4168':>8}")
print(f" {'PubMed (R1)':<18} {'80.10':>7}% {'std':>6} {'62':>7} {'0.0910':>8} {'0.5689':>8} {'-1.7647':>8}")
print(f" {'CiteSeer (R1)':<18} {'53.93':>7}% {'?':>6} {'1':>7} {'0.0019':>8} {'1.0000':>8} {'0.0000':>8} ← collapsed")
print(f"\n SSL SOTA: Cora=84.2%, CiteSeer=73.4%, PubMed=81.1%, ogbn-arxiv=71.87%")
# Save
out = "/tmp/v4.2_round2_results.json"
with open(out, "w") as f: json.dump(all_results, f, indent=2, default=str)
if HAS_HF_HUB:
try:
api = HfApi()
api.upload_file(path_or_fileobj=out, path_in_repo="v4.2_round2_results.json", repo_id="EPSAGR/Node-JEPA", repo_type="model")
print(f"\nPushed to https://huggingface.co/EPSAGR/Node-JEPA")
        except Exception: pass
if tracker: trackio.finish()
print("\nDone!")
if __name__ == "__main__":
main()