# Node-JEPA / node_jepa_baseline_eval.py
# (Hugging Face repo header, commit 861dbe5:
#  "Fix total_mem -> total_memory for PyTorch compatibility")
"""
Node-JEPA V4.2 — Comprehensive Baseline Evaluation
====================================================
Evaluates Node-JEPA on standard benchmarks with:
1. Linear probe accuracy (raw, L2-norm, StandardScaler)
2. Embedding geometry metrics (isotropy, alignment, uniformity, effective rank)
3. Inference timing
4. Multi-seed robustness
Benchmarks: Cora, CiteSeer, PubMed, ogbn-arxiv (if scalable)
Published SSL SOTA for reference:
Cora: 84.2% (GraphMAE)
CiteSeer: 73.4% (GraphMAE)
PubMed: 81.1% (GraphMAE)
ogbn-arxiv: 71.87% (GraphMAE)
Run in Colab:
!pip install torch torch_geometric trackio huggingface_hub ogb scikit-learn
!wget https://huggingface.co/EPSAGR/node-jepa/resolve/main/node_jepa_baseline_eval.py
!python node_jepa_baseline_eval.py
"""
import math, copy, os, sys, json, time, gc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
try:
from torch_geometric.nn import GATConv, GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
HAS_PYG = True
except ImportError:
HAS_PYG = False
print("ERROR: torch_geometric not installed. Run: pip install torch_geometric")
sys.exit(1)
try:
from ogb.nodeproppred import PygNodePropPredDataset
HAS_OGB = True
except ImportError:
HAS_OGB = False
print("WARNING: ogb not installed, skipping ogbn-arxiv. Run: pip install ogb")
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
# Optional integrations: experiment tracking (trackio) and result upload
# (huggingface_hub). Both are best-effort — any import failure simply
# disables the feature. Narrowed from a bare `except:` so that
# KeyboardInterrupt/SystemExit are no longer swallowed.
try:
    import trackio
    HAS_TRACKIO = True
except Exception:
    HAS_TRACKIO = False
try:
    from huggingface_hub import HfApi
    HAS_HF_HUB = True
except Exception:
    HAS_HF_HUB = False
def decorrelation_loss(z):
    """Penalize off-identity structure in the feature correlation matrix.

    Standardizes every embedding dimension (zero mean, unit std), forms the
    D x D correlation matrix, and returns its mean squared deviation from
    the identity — pushing embedding dimensions toward decorrelation.
    """
    num_samples, dim = z.shape
    centered = z - z.mean(dim=0)
    standardized = centered / (centered.std(dim=0) + 1e-8)
    corr = standardized.t().mm(standardized) / num_samples
    identity = torch.eye(dim, device=z.device)
    return (corr - identity).square().sum() / (dim * dim)
def sce_loss(pred, target, alpha=1):
    """Scaled cosine error (GraphMAE): mean of (1 - cos_sim)^alpha per row."""
    pred_unit = F.normalize(pred, p=2, dim=-1)
    target_unit = F.normalize(target, p=2, dim=-1)
    cos_sim = (pred_unit * target_unit).sum(dim=-1)
    return torch.pow(1 - cos_sim, alpha).mean()
class GraphAugmentor:
    """Stochastic graph-view augmentation: feature masking and edge dropping.

    Each call independently zeroes feature entries with probability
    `feat_mask_p` and removes edges with probability `edge_drop_p`.
    A probability of 0 disables that augmentation entirely (identity).
    """

    def __init__(self, feat_mask_p=0.0, edge_drop_p=0.0):
        self.feat_mask_p = feat_mask_p
        self.edge_drop_p = edge_drop_p

    def __call__(self, x, edge_index):
        if self.feat_mask_p > 0:
            # Bernoulli keep-mask with per-entry keep probability 1 - p.
            keep = torch.bernoulli(torch.full_like(x, 1 - self.feat_mask_p))
            x = x * keep
        if self.edge_drop_p > 0:
            num_edges = edge_index.size(1)
            edge_keep = torch.bernoulli(
                torch.full((num_edges,), 1 - self.edge_drop_p, device=x.device)
            ).bool()
            edge_index = edge_index[:, edge_keep]
        return x, edge_index
class GATEncoder(nn.Module):
    """GAT encoder with a learnable global "register" node.

    A single virtual node (``reg_token``) is appended to the graph and
    connected bidirectionally to every real node before message passing,
    giving every layer access to a global summary. Each GAT layer applies
    BatchNorm + PReLU with a residual connection; the register node's row
    is stripped before the output projection result is returned.
    """
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, num_heads=4, dropout=0.0):
        super().__init__()
        self.num_layers = num_layers
        # Learnable feature vector for the virtual register node.
        self.reg_token = nn.Parameter(torch.zeros(1, in_dim))
        nn.init.trunc_normal_(self.reg_token, std=0.02)
        self.input_proj = nn.Linear(in_dim, hidden_dim)
        self.input_bn = nn.BatchNorm1d(hidden_dim, momentum=0.01)
        self.gat_layers = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.activations = nn.ModuleList()
        for _ in range(num_layers):
            # heads * (hidden_dim // num_heads) keeps the width at hidden_dim
            # (assumes hidden_dim is divisible by num_heads), so the residual
            # addition below is shape-safe.
            self.gat_layers.append(GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads,
                                           concat=True, dropout=dropout, add_self_loops=True))
            self.bns.append(nn.BatchNorm1d(hidden_dim, momentum=0.01))
            self.activations.append(nn.PReLU())
        self.output_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        """Encode node features; returns (N, out_dim) embeddings for the N real nodes."""
        device = x.device; N = x.size(0)
        # Append the register node as row index N of the feature matrix.
        x_aug = torch.cat([x, self.reg_token.expand(1, -1).to(device)], dim=0)
        # Bidirectional register edges: (N -> i) and (i -> N) for every real node i.
        reg_edges = torch.cat([
            torch.stack([torch.full((N,), N, dtype=torch.long, device=device), torch.arange(N, device=device)]),
            torch.stack([torch.arange(N, device=device), torch.full((N,), N, dtype=torch.long, device=device)])], dim=1)
        ei_aug = torch.cat([edge_index, reg_edges], dim=1)
        # Fixed-slope (0.25) PReLU on the input projection; per-layer PReLUs are learnable.
        h = F.prelu(self.input_bn(self.input_proj(x_aug)), torch.tensor(0.25, device=device))
        for i in range(self.num_layers):
            # GAT -> BatchNorm -> PReLU with a residual skip connection.
            h = self.activations[i](self.bns[i](self.gat_layers[i](h, ei_aug))) + h
        # Drop the register node's row before returning.
        return self.output_proj(h)[:N]
class GCNEncoder(nn.Module):
    """Residual GCN encoder.

    Pipeline: input projection + BatchNorm + fixed-slope PReLU, then
    `num_layers` blocks of (GCNConv -> BatchNorm -> learnable PReLU) with
    a skip connection each, followed by a linear output projection.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, dropout=0.0):
        super().__init__()
        self.num_layers = num_layers
        self.input_proj = nn.Linear(in_dim, hidden_dim)
        self.input_bn = nn.BatchNorm1d(hidden_dim, momentum=0.01)
        self.gcn_layers = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.activations = nn.ModuleList()
        # Modules are created interleaved (conv, bn, act per layer) so the
        # seeded parameter-initialization order matches prior runs.
        for _layer in range(num_layers):
            self.gcn_layers.append(GCNConv(hidden_dim, hidden_dim))
            self.bns.append(nn.BatchNorm1d(hidden_dim, momentum=0.01))
            self.activations.append(nn.PReLU())
        self.output_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        """Return (N, out_dim) node embeddings."""
        h = self.input_proj(x)
        h = F.prelu(self.input_bn(h), torch.tensor(0.25, device=x.device))
        for conv, bn, act in zip(self.gcn_layers, self.bns, self.activations):
            h = act(bn(conv(h, edge_index))) + h
        return self.output_proj(h)
class MLPPredictor(nn.Module):
    """Two-layer MLP predictor head: dim -> hidden -> dim with BN + PReLU."""

    def __init__(self, dim, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.BatchNorm1d(hidden, momentum=0.01),
            nn.PReLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)
class FeatureDecoder(nn.Module):
    """MLP that maps embeddings back to input-feature space for reconstruction."""

    def __init__(self, embed_dim, hidden_dim, out_dim):
        super().__init__()
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim, momentum=0.01),
            nn.PReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class NodeMasker(nn.Module):
    """GraphMAE-style random node masking.

    Selects `max(1, int(N * mask_rate))` random nodes. Most masked nodes
    are overwritten with a learnable mask token; a small fraction of them
    (`replace_rate` of the masked set) are instead replaced with features
    copied from uniformly random nodes.
    """

    def __init__(self, feat_dim, mask_rate=0.5, replace_rate=0.05):
        super().__init__()
        self.mask_rate = mask_rate
        self.replace_rate = replace_rate
        self.mask_token = nn.Parameter(torch.zeros(1, feat_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)

    def forward(self, x):
        """Return (masked feature matrix, boolean mask of masked rows)."""
        num_nodes = x.size(0)
        num_masked = max(1, int(num_nodes * self.mask_rate))
        order = torch.randperm(num_nodes, device=x.device)
        mask = torch.zeros(num_nodes, dtype=torch.bool, device=x.device)
        mask[order[:num_masked]] = True
        num_replaced = max(0, int(num_masked * self.replace_rate))
        num_tokened = num_masked - num_replaced
        out = x.clone()
        out[order[:num_tokened]] = self.mask_token.expand(num_tokened, -1)
        if num_replaced > 0:
            donors = torch.randint(0, num_nodes, (num_replaced,), device=x.device)
            out[order[num_tokened:num_masked]] = x[donors]
        return out, mask
class NodeJEPA(nn.Module):
    """Node-level JEPA: an online encoder predicts, for masked nodes, the
    embeddings produced by an EMA "target" copy of itself.

    The training loss combines:
      * a symmetric JEPA prediction loss over two augmented graph views,
      * an optional embedding-decorrelation penalty (``lambda_dec``),
      * an optional masked feature-reconstruction SCE loss (``mu_rec``).
    """
    def __init__(self, in_dim, hidden_dim=512, out_dim=256, num_layers=2, num_heads=4,
                 predictor_hidden=512, mask_rate=0.5, replace_rate=0.05,
                 ema_tau_base=0.99, ema_tau_final=1.0,
                 aug1_feat_p=0.0, aug1_edge_p=0.0, aug2_feat_p=0.0, aug2_edge_p=0.0,
                 lambda_dec=0.0, mu_rec=0.0, sce_alpha=1, encoder_type='gat', dropout=0.0):
        super().__init__()
        # NOTE(review): ema_tau_final is stored but not read in this class;
        # the EMA schedule appears to be built externally — confirm intent.
        self.ema_tau_base = ema_tau_base; self.ema_tau_final = ema_tau_final
        self.lambda_dec = lambda_dec; self.mu_rec = mu_rec; self.sce_alpha = sce_alpha
        if encoder_type == 'gat':
            self.online_encoder = GATEncoder(in_dim, hidden_dim, out_dim, num_layers, num_heads, dropout)
        else:
            self.online_encoder = GCNEncoder(in_dim, hidden_dim, out_dim, num_layers, dropout)
        # Target encoder starts as an exact copy and is only ever updated
        # via EMA — never by gradients.
        self.target_encoder = copy.deepcopy(self.online_encoder)
        for p in self.target_encoder.parameters(): p.requires_grad = False
        self.predictor = MLPPredictor(out_dim, predictor_hidden)
        self.masker = NodeMasker(in_dim, mask_rate, replace_rate)
        self.aug1 = GraphAugmentor(aug1_feat_p, aug1_edge_p)
        self.aug2 = GraphAugmentor(aug2_feat_p, aug2_edge_p)
        # Feature decoder exists only when the reconstruction loss is enabled.
        self.feat_decoder = FeatureDecoder(out_dim, hidden_dim, in_dim) if mu_rec > 0 else None

    @torch.no_grad()
    def update_target(self, tau):
        """EMA update of the target encoder: p_t <- tau * p_t + (1 - tau) * p_o."""
        for po, pt in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            pt.data.mul_(tau).add_(po.data, alpha=1 - tau)

    def _jepa_loss(self, x, edge_index, aug_o, aug_t, do_mask=True):
        """One directional JEPA term: the online view (optionally masked)
        predicts the target view's embeddings on the selected nodes.

        Returns (loss, online embeddings, boolean mask of predicted nodes).
        """
        xo, eio = aug_o(x, edge_index); xt, eit = aug_t(x, edge_index)
        if do_mask: xo, mask = self.masker(xo)
        else: mask = torch.ones(x.size(0), dtype=torch.bool, device=x.device)
        ho = self.online_encoder(xo, eio)
        with torch.no_grad(): ht = self.target_encoder(xt, eit)
        # Cosine-based prediction loss: 2 - 2*cos(pred, target), averaged.
        pred = F.normalize(self.predictor(ho[mask]), dim=-1)
        tgt = F.normalize(ht[mask], dim=-1)
        return 2 - 2 * (pred * tgt).sum(-1).mean(), ho, mask

    def forward(self, data):
        """Return (total_loss, jepa_loss, decorrelation_loss, reconstruction_loss)."""
        x, edge_index = data.x, data.edge_index
        # Symmetric objective: masked aug1->aug2 plus unmasked aug2->aug1.
        l1, ho1, m1 = self._jepa_loss(x, edge_index, self.aug1, self.aug2, True)
        l2, ho2, m2 = self._jepa_loss(x, edge_index, self.aug2, self.aug1, False)
        l_jepa = (l1 + l2) / 2
        l_dec = self.lambda_dec * (decorrelation_loss(ho1) + decorrelation_loss(ho2)) / 2 if self.lambda_dec > 0 else torch.tensor(0.0, device=x.device)
        l_rec = torch.tensor(0.0, device=x.device)
        if self.mu_rec > 0 and self.feat_decoder is not None:
            # Reconstruct the ORIGINAL features of masked nodes from the
            # online embeddings of the masked view.
            l_rec = self.mu_rec * sce_loss(self.feat_decoder(ho1[m1]), x[m1], self.sce_alpha)
        return l_jepa + l_dec + l_rec, l_jepa, l_dec, l_rec

    @torch.no_grad()
    def encode(self, data):
        """Return online-encoder embeddings for all nodes (switches encoder to eval)."""
        self.online_encoder.eval()
        return self.online_encoder(data.x, data.edge_index)
def cosine_scheduler(base, final, steps):
    """Per-step cosine ramp from `base` (step 0) to `final` (last step)."""
    denom = max(steps - 1, 1)
    schedule = []
    for step in range(steps):
        weight = (math.cos(math.pi * step / denom) + 1) / 2
        schedule.append(final - (final - base) * weight)
    return schedule
def compute_geometry_metrics(embeddings, labels=None, sample_size=2000):
N, D = embeddings.shape
idx = np.random.choice(N, min(N, sample_size), replace=False) if N > sample_size else np.arange(N)
embs_sub = embeddings[idx]; labels_sub = labels[idx] if labels is not None else None
embs_centered = embs_sub - embs_sub.mean(0)
try: U, S, Vt = np.linalg.svd(embs_centered, full_matrices=False)
except: S = np.ones(min(embs_sub.shape))
total_var = (S**2).sum()
if total_var < 1e-12: eff_rank_95, eff_rank_99 = 1, 1
else:
ev = np.cumsum(S**2) / total_var
eff_rank_95 = int(np.searchsorted(ev, 0.95) + 1); eff_rank_99 = int(np.searchsorted(ev, 0.99) + 1)
if len(S) > 1 and S[0] > 1e-12:
isotropy = float(S[-1] / S[0])
log_s2 = np.log(S**2 + 1e-12)
isotropy_ratio = np.exp(log_s2.mean()) / ((S**2).mean() + 1e-12)
else: isotropy = 0.0; isotropy_ratio = 0.0
norms = np.linalg.norm(embs_sub, axis=1, keepdims=True) + 1e-8
embs_norm = embs_sub / norms; n = embs_sub.shape[0]
cos_matrix = embs_norm @ embs_norm.T; np.fill_diagonal(cos_matrix, 0)
avg_cos = cos_matrix.sum() / (n * (n - 1))
alignment = 0.0
if labels_sub is not None:
sims = []
for c in np.unique(labels_sub):
mc = labels_sub == c
if mc.sum() > 1:
ec = embs_norm[mc]; cc = ec @ ec.T; np.fill_diagonal(cc, 0)
sims.extend(cc[np.triu_indices(ec.shape[0], k=1)].tolist())
if sims: alignment = float(np.mean(sims))
diff = embs_norm[:, None, :] - embs_norm[None, :, :]
sq_dist = (diff ** 2).sum(-1); umask = ~np.eye(n, dtype=bool)
uniformity = float(np.log(np.exp(-2 * sq_dist[umask]).mean() + 1e-12))
return {
"isotropy_minmax": float(isotropy), "isotropy_ratio": float(isotropy_ratio),
"alignment": float(alignment), "uniformity": float(uniformity),
"effective_rank_95": eff_rank_95, "effective_rank_99": eff_rank_99,
"avg_cos_sim": float(avg_cos), "std_per_dim": float(embeddings.std(axis=0).mean()),
"embed_dim": D, "num_nodes": N,
}
def linear_eval_multi(embeddings, labels, train_mask, val_mask, test_mask, num_seeds=20, max_iter=2000):
    """Linear-probe evaluation under three feature normalizations.

    For each of raw / L2-row-normalized / standardized features, fits
    logistic-regression probes over a grid of C values, picks C by
    validation accuracy, and reports mean/std test accuracy (percent)
    over `num_seeds` repetitions.
    """
    C_GRID = [0.001, 0.01, 0.1, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0]
    normalizers = {
        "raw": lambda e: e,
        "l2": lambda e: e / (np.linalg.norm(e, axis=1, keepdims=True) + 1e-8),
        "standard": lambda e: StandardScaler().fit_transform(e),
    }
    results = {}
    for norm_name, norm_fn in normalizers.items():
        feats = norm_fn(embeddings.copy())
        X_train, y_train = feats[train_mask], labels[train_mask]
        X_val, y_val = feats[val_mask], labels[val_mask]
        X_test, y_test = feats[test_mask], labels[test_mask]
        test_accs = []
        for seed in range(num_seeds):
            best_val_acc, best_test_acc = -1, -1
            for C in C_GRID:
                probe = LogisticRegression(max_iter=max_iter, solver='lbfgs', random_state=seed, C=C)
                probe.fit(X_train, y_train)
                val_acc = probe.score(X_val, y_val)
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_test_acc = probe.score(X_test, y_test)
            test_accs.append(best_test_acc * 100)
        results[norm_name] = {"mean": float(np.mean(test_accs)), "std": float(np.std(test_accs))}
    return results
def measure_inference_time(model, data, num_runs=100, warmup=10):
    """Benchmark `model.encode(data)` latency in milliseconds.

    Runs `warmup` untimed passes, then `num_runs` timed passes with CUDA
    synchronization around each timed call when running on a GPU.
    """
    model.eval()
    device = next(model.parameters()).device
    data = data.to(device)
    use_cuda = device.type == 'cuda'

    def _sync():
        if use_cuda:
            torch.cuda.synchronize()

    with torch.no_grad():
        for _ in range(warmup):
            model.encode(data)
        _sync()
        samples_ms = []
        for _ in range(num_runs):
            _sync()
            start = time.perf_counter()
            model.encode(data)
            _sync()
            samples_ms.append((time.perf_counter() - start) * 1000)
    mean_ms = float(np.mean(samples_ms))
    return {"mean_ms": mean_ms, "std_ms": float(np.std(samples_ms)),
            "median_ms": float(np.median(samples_ms)), "num_nodes": data.num_nodes,
            "ms_per_1k_nodes": mean_ms / (data.num_nodes / 1000)}
def get_lr(epoch, warmup_epochs, total_epochs, base_lr, min_lr=1e-6):
    """Linear warmup to `base_lr`, then cosine decay toward `min_lr`."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    decay_span = max(total_epochs - warmup_epochs, 1)
    progress = (epoch - warmup_epochs) / decay_span
    cosine_weight = (1 + math.cos(math.pi * progress)) / 2
    return min_lr + (base_lr - min_lr) * cosine_weight
def train_and_evaluate(dataset_name, data, device, config, seed=42):
    """Train Node-JEPA on `data` with `config`, then run the full evaluation.

    Performs full-graph self-supervised training with LR warmup + cosine
    decay and a cosine EMA schedule for the target encoder, periodically
    probes accuracy with a quick logistic regression, and finally reports
    the 20-seed linear probe, embedding-geometry metrics, inference timing,
    and a set of pass/fail sanity checks.

    Returns a JSON-serializable result dict.
    """
    torch.manual_seed(seed); np.random.seed(seed)
    if device.type == 'cuda': torch.cuda.manual_seed(seed)
    epochs, lr, warmup = config['epochs'], config['lr'], config['warmup']
    model = NodeJEPA(
        in_dim=data.num_features, hidden_dim=config['hidden_dim'], out_dim=config['out_dim'],
        num_layers=config['num_layers'], num_heads=config.get('num_heads', 4),
        predictor_hidden=config['hidden_dim'], mask_rate=config['mask_rate'],
        ema_tau_base=config['ema_tau_base'],
        aug1_feat_p=config['aug1_feat_p'], aug1_edge_p=config['aug1_edge_p'],
        aug2_feat_p=config['aug2_feat_p'], aug2_edge_p=config['aug2_edge_p'],
        lambda_dec=config.get('lambda_dec', 0.0), mu_rec=config.get('mu_rec', 0.0),
        encoder_type=config.get('encoder_type', 'gat'),
    ).to(device)
    # Only online-branch parameters are optimized (target encoder is frozen).
    optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=config.get('weight_decay', 1e-5))
    # EMA tau ramps from the base value up to 1.0 over training.
    ema_sched = cosine_scheduler(config['ema_tau_base'], 1.0, epochs)
    t0 = time.time(); eval_interval = max(1, epochs // 10); log_interval = max(1, epochs // 30)
    best_eval, best_ep, train_losses = 0, 0, []
    for ep in range(epochs):
        model.train(); cur_lr = get_lr(ep, warmup, epochs, lr)
        for pg in optimizer.param_groups: pg['lr'] = cur_lr
        total_loss, l_jepa, l_dec, l_rec = model(data)
        optimizer.zero_grad(); total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        optimizer.step(); model.update_target(ema_sched[ep])
        train_losses.append(total_loss.item())
        if (ep + 1) % log_interval == 0 or ep == 0:
            print(f" [{dataset_name}] Ep {ep+1:4d}/{epochs} | L={total_loss.item():.4f} (j={l_jepa.item():.4f} d={l_dec.item():.4f} r={l_rec.item():.4f}) | LR={cur_lr:.1e} | {time.time()-t0:.0f}s"); sys.stdout.flush()
        if (ep + 1) % eval_interval == 0:
            # Quick single-split probe (C=1.0) to track progress during training.
            model.eval()
            with torch.no_grad(): embs = model.encode(data).cpu().numpy()
            ln = data.y.cpu().numpy()
            if len(ln.shape) > 1: ln = ln.squeeze()
            X_tr, y_tr = embs[data.train_mask.cpu().numpy()], ln[data.train_mask.cpu().numpy()]
            X_te, y_te = embs[data.test_mask.cpu().numpy()], ln[data.test_mask.cpu().numpy()]
            clf = LogisticRegression(max_iter=2000, solver='lbfgs', C=1.0); clf.fit(X_tr, y_tr)
            acc = clf.score(X_te, y_te) * 100
            if acc > best_eval: best_eval = acc; best_ep = ep + 1
            print(f" [{dataset_name}] EVAL @{ep+1}: raw={acc:.1f}% | BEST={best_eval:.1f}%@{best_ep}"); sys.stdout.flush()
    train_time = time.time() - t0
    print(f"\n [{dataset_name}] === FINAL COMPREHENSIVE EVALUATION ===")
    model.eval()
    with torch.no_grad(): embeddings = model.encode(data).cpu().numpy()
    labels_np = data.y.cpu().numpy()
    if len(labels_np.shape) > 1: labels_np = labels_np.squeeze()
    train_mask = data.train_mask.cpu().numpy()
    # Fall back to the train mask when the dataset has no validation split.
    val_mask = data.val_mask.cpu().numpy() if hasattr(data, 'val_mask') and data.val_mask is not None else train_mask
    test_mask = data.test_mask.cpu().numpy()
    print(f" [{dataset_name}] Running 20-seed linear probe...")
    probe = linear_eval_multi(embeddings, labels_np, train_mask, val_mask, test_mask, num_seeds=20)
    best_acc = max(probe[k]['mean'] for k in probe)
    best_norm = max(probe, key=lambda k: probe[k]['mean'])
    print(f" [{dataset_name}] ACCURACY: raw={probe['raw']['mean']:.2f}+/-{probe['raw']['std']:.2f} | l2={probe['l2']['mean']:.2f}+/-{probe['l2']['std']:.2f} | std={probe['standard']['mean']:.2f}+/-{probe['standard']['std']:.2f}")
    print(f" [{dataset_name}] BEST: {best_acc:.2f}% ({best_norm})")
    print(f" [{dataset_name}] Computing embedding geometry...")
    geometry = compute_geometry_metrics(embeddings, labels_np)
    print(f" [{dataset_name}] GEOMETRY: isotropy={geometry['isotropy_ratio']:.6f} align={geometry['alignment']:.4f} uniform={geometry['uniformity']:.4f} rank95={geometry['effective_rank_95']} cos={geometry['avg_cos_sim']:.4f}")
    print(f" [{dataset_name}] Measuring inference time...")
    timing = measure_inference_time(model, data)
    print(f" [{dataset_name}] TIMING: {timing['mean_ms']:.2f}+/-{timing['std_ms']:.2f}ms | {timing['ms_per_1k_nodes']:.2f}ms/1K nodes")
    # Heuristic health checks on embedding geometry and inference speed.
    checks = {"isotropy_ratio > 0.3": geometry['isotropy_ratio'] > 0.3, "alignment 0.4-0.8": 0.4 <= geometry['alignment'] <= 0.8,
              "uniformity < -2.0": geometry['uniformity'] < -2.0, "inference <= 5ms/1K": timing['ms_per_1k_nodes'] <= 5.0}
    print(f" [{dataset_name}] CHECKS: " + " | ".join(f"{'PASS' if v else 'FAIL'}:{k}" for k, v in checks.items()))
    return {"dataset": dataset_name, "config": {k: v for k, v in config.items()}, "seed": seed,
            "accuracy": probe, "best_accuracy": best_acc, "best_norm": best_norm,
            "geometry": geometry, "timing": timing, "train_time": train_time,
            "train_losses": train_losses[-10:], "checks": {k: bool(v) for k, v in checks.items()}}
# Per-dataset self-supervised training hyperparameters for the Planetoid
# benchmarks. aug{1,2}_{feat,edge}_p: per-view feature-mask / edge-drop
# probabilities; mu_rec: weight of the masked feature-reconstruction loss;
# lambda_dec: weight of the decorrelation penalty (disabled here).
CONFIGS = {
    "Cora": {'epochs': 600, 'lr': 1e-3, 'warmup': 100, 'hidden_dim': 512, 'out_dim': 256, 'num_layers': 2, 'num_heads': 4, 'mask_rate': 0.5, 'ema_tau_base': 0.99, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.5, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.5, 'encoder_type': 'gcn', 'lambda_dec': 0.0, 'mu_rec': 2.0},
    "CiteSeer": {'epochs': 600, 'lr': 3e-4, 'warmup': 200, 'hidden_dim': 512, 'out_dim': 256, 'num_layers': 2, 'num_heads': 4, 'mask_rate': 0.5, 'ema_tau_base': 0.99, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.4, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.4, 'encoder_type': 'gat', 'lambda_dec': 0.0, 'mu_rec': 2.0},
    "PubMed": {'epochs': 600, 'lr': 1e-3, 'warmup': 100, 'hidden_dim': 512, 'out_dim': 256, 'num_layers': 2, 'num_heads': 4, 'mask_rate': 0.6, 'ema_tau_base': 0.99, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.4, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.5, 'encoder_type': 'gcn', 'lambda_dec': 0.0, 'mu_rec': 2.0},
}
# ogbn-arxiv configuration: wider/deeper encoder, fewer epochs, slower EMA.
ARXIV_CONFIG = {'epochs': 300, 'lr': 1e-3, 'warmup': 100, 'hidden_dim': 1024, 'out_dim': 512, 'num_layers': 3, 'num_heads': 4, 'mask_rate': 0.5, 'ema_tau_base': 0.996, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.4, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.4, 'encoder_type': 'gcn', 'lambda_dec': 0.0, 'mu_rec': 2.0}
def main():
    """Run the full Node-JEPA baseline evaluation suite.

    Trains and evaluates on Cora/CiteSeer/PubMed with three seeds each,
    optionally on ogbn-arxiv (requires ogb and a CUDA device), prints a
    summary table, writes results to /tmp as JSON, and best-effort pushes
    the JSON to the Hugging Face Hub and logs to trackio when available.

    Fix: the Hub-upload `try` previously ended with a bare `except: pass`,
    silently swallowing every error (including KeyboardInterrupt). It now
    catches Exception and reports the failure while staying best-effort.
    """
    torch.manual_seed(42); np.random.seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if device.type == "cuda":
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        try:
            print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        except AttributeError:
            # Some PyTorch builds lack `total_memory` on device properties.
            print(f" VRAM: unknown")
        torch.cuda.manual_seed(42)
    tracker = None
    if HAS_TRACKIO:
        try: trackio.init(name="node-jepa-v4.2-baseline", project="node-jepa"); tracker = True; print("Trackio initialized")
        except Exception as e: print(f"Trackio init failed: {e}")
    all_results = {}; ssl_sota = {"Cora": 84.2, "CiteSeer": 73.4, "PubMed": 81.1}
    for ds in ["Cora", "CiteSeer", "PubMed"]:
        print(f"\n{'='*80}\nBENCHMARK: {ds} (SSL SOTA: {ssl_sota[ds]}%)\n{'='*80}")
        dataset = Planetoid(root="/tmp/pyg_data", name=ds, transform=NormalizeFeatures())
        data = dataset[0].to(device)
        print(f" {data.num_nodes} nodes, {data.num_edges} edges, {data.num_features} feat, {dataset.num_classes} classes")
        seeds_res = []
        for seed in [42, 123, 456]:
            print(f"\n --- Seed {seed} ---")
            r = train_and_evaluate(ds, data, device, CONFIGS[ds], seed=seed); seeds_res.append(r)
            if tracker: trackio.log({f"{ds}/s{seed}/acc": r['best_accuracy'], f"{ds}/s{seed}/rank95": r['geometry']['effective_rank_95']})
        accs = [r['best_accuracy'] for r in seeds_res]
        all_results[ds] = {"seeds": seeds_res, "mean_accuracy": float(np.mean(accs)), "std_accuracy": float(np.std(accs)),
                           "max_accuracy": float(np.max(accs)), "ssl_sota": ssl_sota[ds], "gap_to_sota": float(ssl_sota[ds] - np.mean(accs))}
        print(f"\n [{ds}] SUMMARY: {np.mean(accs):.2f}+/-{np.std(accs):.2f}% (max {np.max(accs):.2f}%) | SOTA: {ssl_sota[ds]}% | Gap: {ssl_sota[ds]-np.mean(accs):.1f}pts")
        # Free memory before the next (potentially larger) dataset.
        del data, dataset; gc.collect()
        if device.type == 'cuda': torch.cuda.empty_cache()
    if HAS_OGB and device.type == 'cuda':
        print(f"\n{'='*80}\nBENCHMARK: ogbn-arxiv (SSL SOTA: 71.87%)\n{'='*80}")
        try:
            dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='/tmp/ogb_data'); data = dataset[0]
            si = dataset.get_idx_split(); N = data.x.size(0)
            tm, vm, tsm = torch.zeros(N, dtype=torch.bool), torch.zeros(N, dtype=torch.bool), torch.zeros(N, dtype=torch.bool)
            tm[si['train']] = True; vm[si['valid']] = True; tsm[si['test']] = True
            data.train_mask, data.val_mask, data.test_mask = tm, vm, tsm; data.y = data.y.squeeze()
            # ogbn-arxiv edges are directed; add reversed edges to symmetrize.
            ei = data.edge_index; data.edge_index = torch.cat([ei, torch.stack([ei[1], ei[0]])], dim=1)
            data = data.to(device); print(f" {data.num_nodes} nodes, {data.edge_index.size(1)} edges")
            r = train_and_evaluate("ogbn-arxiv", data, device, ARXIV_CONFIG, seed=42)
            all_results["ogbn-arxiv"] = {"seeds": [r], "mean_accuracy": r['best_accuracy'], "ssl_sota": 71.87, "gap_to_ssl_sota": 71.87 - r['best_accuracy']}
        except Exception as e: print(f" ogbn-arxiv FAILED: {e}"); import traceback; traceback.print_exc()
    print(f"\n{'='*100}\nNODE-JEPA V4.2 — COMPREHENSIVE BASELINE RESULTS\n{'='*100}")
    print(f"\n{'Dataset':<15} {'Ours':>10} {'SOTA':>8} {'Gap':>7} {'Isotropy':>10} {'Align':>8} {'Uniform':>10} {'Rank95':>8} {'ms/1K':>8}")
    print("-" * 95)
    for ds in ["Cora", "CiteSeer", "PubMed", "ogbn-arxiv"]:
        if ds not in all_results or "error" in all_results.get(ds, {}):
            print(f" {ds:<13} {'SKIP':>10}"); continue
        r = all_results[ds]; g = r['seeds'][0]['geometry']; t = r['seeds'][0]['timing']
        print(f" {ds:<13} {r['mean_accuracy']:>9.2f}% {r['ssl_sota']:>7.1f}% {r.get('gap_to_sota', r['ssl_sota']-r['mean_accuracy']):>+6.1f} {g['isotropy_ratio']:>9.4f} {g['alignment']:>7.4f} {g['uniformity']:>9.4f} {g['effective_rank_95']:>8} {t['ms_per_1k_nodes']:>7.2f}")
    print(f"\nRef: GraphMAE (KDD'22) | Note: ogbn-arxiv SSL ceiling ~72%, ogbn-products has NO published SSL baseline")
    out = "/tmp/v4.2_baseline_results.json"
    with open(out, "w") as f: json.dump(all_results, f, indent=2, default=str)
    if HAS_HF_HUB:
        try:
            api = HfApi()
            api.upload_file(path_or_fileobj=out, path_in_repo="v4.2_baseline_results.json", repo_id="EPSAGR/Node-JEPA", repo_type="model")
            print(f"Pushed to https://huggingface.co/EPSAGR/Node-JEPA")
        except Exception as e:
            # Upload stays best-effort, but failures are no longer silent.
            print(f"HF Hub upload failed: {e}")
    if tracker: trackio.finish()
    print("\nDone!")
# Script entry point: run the full benchmark suite when executed directly.
if __name__ == "__main__":
    main()