""" Node-JEPA V4.2 — Comprehensive Baseline Evaluation ==================================================== Evaluates Node-JEPA on standard benchmarks with: 1. Linear probe accuracy (raw, L2-norm, StandardScaler) 2. Embedding geometry metrics (isotropy, alignment, uniformity, effective rank) 3. Inference timing 4. Multi-seed robustness Benchmarks: Cora, CiteSeer, PubMed, ogbn-arxiv (if scalable) Published SSL SOTA for reference: Cora: 84.2% (GraphMAE) CiteSeer: 73.4% (GraphMAE) PubMed: 81.1% (GraphMAE) ogbn-arxiv: 71.87% (GraphMAE) Run in Colab: !pip install torch torch_geometric trackio huggingface_hub ogb scikit-learn !wget https://huggingface.co/EPSAGR/node-jepa/resolve/main/node_jepa_baseline_eval.py !python node_jepa_baseline_eval.py """ import math, copy, os, sys, json, time, gc import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.optim import AdamW import warnings warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) try: from torch_geometric.nn import GATConv, GCNConv from torch_geometric.datasets import Planetoid from torch_geometric.transforms import NormalizeFeatures HAS_PYG = True except ImportError: HAS_PYG = False print("ERROR: torch_geometric not installed. Run: pip install torch_geometric") sys.exit(1) try: from ogb.nodeproppred import PygNodePropPredDataset HAS_OGB = True except ImportError: HAS_OGB = False print("WARNING: ogb not installed, skipping ogbn-arxiv. Run: pip install ogb") from sklearn.linear_model import LogisticRegression from sklearn.preprocessing import StandardScaler try: import trackio HAS_TRACKIO = True except: HAS_TRACKIO = False try: from huggingface_hub import HfApi HAS_HF_HUB = True except: HAS_HF_HUB = False def decorrelation_loss(z): N, D = z.shape z_centered = z - z.mean(0) z_std = z_centered / (z_centered.std(0) + 1e-8) R = (z_std.T @ z_std) / N I = torch.eye(D, device=z.device) return (R - I).pow(2).sum() / (D * D) def sce_loss(pred, target, alpha=1): pred = F.normalize(pred, p=2, dim=-1) target = F.normalize(target, p=2, dim=-1) return (1 - (pred * target).sum(dim=-1)).pow(alpha).mean() class GraphAugmentor: def __init__(self, feat_mask_p=0.0, edge_drop_p=0.0): self.feat_mask_p = feat_mask_p self.edge_drop_p = edge_drop_p def __call__(self, x, edge_index): device = x.device if self.feat_mask_p > 0: feat_mask = torch.bernoulli(torch.ones_like(x) * (1 - self.feat_mask_p)) x = x * feat_mask if self.edge_drop_p > 0: edge_mask = torch.bernoulli(torch.ones(edge_index.size(1), device=device) * (1 - self.edge_drop_p)).bool() edge_index = edge_index[:, edge_mask] return x, edge_index class GATEncoder(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, num_heads=4, dropout=0.0): super().__init__() self.num_layers = num_layers self.reg_token = nn.Parameter(torch.zeros(1, in_dim)) nn.init.trunc_normal_(self.reg_token, std=0.02) self.input_proj = nn.Linear(in_dim, hidden_dim) self.input_bn = nn.BatchNorm1d(hidden_dim, momentum=0.01) self.gat_layers = nn.ModuleList() self.bns = nn.ModuleList() self.activations = nn.ModuleList() for _ in range(num_layers): self.gat_layers.append(GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads, concat=True, dropout=dropout, add_self_loops=True)) self.bns.append(nn.BatchNorm1d(hidden_dim, momentum=0.01)) self.activations.append(nn.PReLU()) self.output_proj = nn.Linear(hidden_dim, out_dim) def forward(self, x, edge_index): device = x.device; N = x.size(0) x_aug = torch.cat([x, self.reg_token.expand(1, -1).to(device)], dim=0) reg_edges = torch.cat([ torch.stack([torch.full((N,), N, dtype=torch.long, device=device), torch.arange(N, device=device)]), torch.stack([torch.arange(N, device=device), torch.full((N,), N, dtype=torch.long, device=device)])], dim=1) ei_aug = torch.cat([edge_index, reg_edges], dim=1) h = F.prelu(self.input_bn(self.input_proj(x_aug)), torch.tensor(0.25, device=device)) for i in range(self.num_layers): h = self.activations[i](self.bns[i](self.gat_layers[i](h, ei_aug))) + h return self.output_proj(h)[:N] class GCNEncoder(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, dropout=0.0): super().__init__() self.num_layers = num_layers self.input_proj = nn.Linear(in_dim, hidden_dim) self.input_bn = nn.BatchNorm1d(hidden_dim, momentum=0.01) self.gcn_layers = nn.ModuleList() self.bns = nn.ModuleList() self.activations = nn.ModuleList() for _ in range(num_layers): self.gcn_layers.append(GCNConv(hidden_dim, hidden_dim)) self.bns.append(nn.BatchNorm1d(hidden_dim, momentum=0.01)) self.activations.append(nn.PReLU()) self.output_proj = nn.Linear(hidden_dim, out_dim) def forward(self, x, edge_index): h = F.prelu(self.input_bn(self.input_proj(x)), torch.tensor(0.25, device=x.device)) for i in range(self.num_layers): h = self.activations[i](self.bns[i](self.gcn_layers[i](h, edge_index))) + h return self.output_proj(h) class MLPPredictor(nn.Module): def __init__(self, dim, hidden=512): super().__init__() self.net = nn.Sequential(nn.Linear(dim, hidden), nn.BatchNorm1d(hidden, momentum=0.01), nn.PReLU(), nn.Linear(hidden, dim)) def forward(self, x): return self.net(x) class FeatureDecoder(nn.Module): def __init__(self, embed_dim, hidden_dim, out_dim): super().__init__() self.net = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.BatchNorm1d(hidden_dim, momentum=0.01), nn.PReLU(), nn.Linear(hidden_dim, out_dim)) def forward(self, x): return self.net(x) class NodeMasker(nn.Module): def __init__(self, feat_dim, mask_rate=0.5, replace_rate=0.05): super().__init__() self.mask_rate = mask_rate; self.replace_rate = replace_rate self.mask_token = nn.Parameter(torch.zeros(1, feat_dim)) nn.init.trunc_normal_(self.mask_token, std=0.02) def forward(self, x): N = x.size(0); num_mask = max(1, int(N * self.mask_rate)) perm = torch.randperm(N, device=x.device) mask = torch.zeros(N, dtype=torch.bool, device=x.device); mask[perm[:num_mask]] = True x_m = x.clone(); nr = max(0, int(num_mask * self.replace_rate)) x_m[perm[:num_mask - nr]] = self.mask_token.expand(num_mask - nr, -1) if nr > 0: x_m[perm[num_mask - nr:num_mask]] = x[torch.randint(0, N, (nr,), device=x.device)] return x_m, mask class NodeJEPA(nn.Module): def __init__(self, in_dim, hidden_dim=512, out_dim=256, num_layers=2, num_heads=4, predictor_hidden=512, mask_rate=0.5, replace_rate=0.05, ema_tau_base=0.99, ema_tau_final=1.0, aug1_feat_p=0.0, aug1_edge_p=0.0, aug2_feat_p=0.0, aug2_edge_p=0.0, lambda_dec=0.0, mu_rec=0.0, sce_alpha=1, encoder_type='gat', dropout=0.0): super().__init__() self.ema_tau_base = ema_tau_base; self.ema_tau_final = ema_tau_final self.lambda_dec = lambda_dec; self.mu_rec = mu_rec; self.sce_alpha = sce_alpha if encoder_type == 'gat': self.online_encoder = GATEncoder(in_dim, hidden_dim, out_dim, num_layers, num_heads, dropout) else: self.online_encoder = GCNEncoder(in_dim, hidden_dim, out_dim, num_layers, dropout) self.target_encoder = copy.deepcopy(self.online_encoder) for p in self.target_encoder.parameters(): p.requires_grad = False self.predictor = MLPPredictor(out_dim, predictor_hidden) self.masker = NodeMasker(in_dim, mask_rate, replace_rate) self.aug1 = GraphAugmentor(aug1_feat_p, aug1_edge_p) self.aug2 = GraphAugmentor(aug2_feat_p, aug2_edge_p) self.feat_decoder = FeatureDecoder(out_dim, hidden_dim, in_dim) if mu_rec > 0 else None @torch.no_grad() def update_target(self, tau): for po, pt in zip(self.online_encoder.parameters(), self.target_encoder.parameters()): pt.data.mul_(tau).add_(po.data, alpha=1 - tau) def _jepa_loss(self, x, edge_index, aug_o, aug_t, do_mask=True): xo, eio = aug_o(x, edge_index); xt, eit = aug_t(x, edge_index) if do_mask: xo, mask = self.masker(xo) else: mask = torch.ones(x.size(0), dtype=torch.bool, device=x.device) ho = self.online_encoder(xo, eio) with torch.no_grad(): ht = self.target_encoder(xt, eit) pred = F.normalize(self.predictor(ho[mask]), dim=-1) tgt = F.normalize(ht[mask], dim=-1) return 2 - 2 * (pred * tgt).sum(-1).mean(), ho, mask def forward(self, data): x, edge_index = data.x, data.edge_index l1, ho1, m1 = self._jepa_loss(x, edge_index, self.aug1, self.aug2, True) l2, ho2, m2 = self._jepa_loss(x, edge_index, self.aug2, self.aug1, False) l_jepa = (l1 + l2) / 2 l_dec = self.lambda_dec * (decorrelation_loss(ho1) + decorrelation_loss(ho2)) / 2 if self.lambda_dec > 0 else torch.tensor(0.0, device=x.device) l_rec = torch.tensor(0.0, device=x.device) if self.mu_rec > 0 and self.feat_decoder is not None: l_rec = self.mu_rec * sce_loss(self.feat_decoder(ho1[m1]), x[m1], self.sce_alpha) return l_jepa + l_dec + l_rec, l_jepa, l_dec, l_rec @torch.no_grad() def encode(self, data): self.online_encoder.eval() return self.online_encoder(data.x, data.edge_index) def cosine_scheduler(base, final, steps): return [final - (final - base) * (math.cos(math.pi * s / max(steps - 1, 1)) + 1) / 2 for s in range(steps)] def compute_geometry_metrics(embeddings, labels=None, sample_size=2000): N, D = embeddings.shape idx = np.random.choice(N, min(N, sample_size), replace=False) if N > sample_size else np.arange(N) embs_sub = embeddings[idx]; labels_sub = labels[idx] if labels is not None else None embs_centered = embs_sub - embs_sub.mean(0) try: U, S, Vt = np.linalg.svd(embs_centered, full_matrices=False) except: S = np.ones(min(embs_sub.shape)) total_var = (S**2).sum() if total_var < 1e-12: eff_rank_95, eff_rank_99 = 1, 1 else: ev = np.cumsum(S**2) / total_var eff_rank_95 = int(np.searchsorted(ev, 0.95) + 1); eff_rank_99 = int(np.searchsorted(ev, 0.99) + 1) if len(S) > 1 and S[0] > 1e-12: isotropy = float(S[-1] / S[0]) log_s2 = np.log(S**2 + 1e-12) isotropy_ratio = np.exp(log_s2.mean()) / ((S**2).mean() + 1e-12) else: isotropy = 0.0; isotropy_ratio = 0.0 norms = np.linalg.norm(embs_sub, axis=1, keepdims=True) + 1e-8 embs_norm = embs_sub / norms; n = embs_sub.shape[0] cos_matrix = embs_norm @ embs_norm.T; np.fill_diagonal(cos_matrix, 0) avg_cos = cos_matrix.sum() / (n * (n - 1)) alignment = 0.0 if labels_sub is not None: sims = [] for c in np.unique(labels_sub): mc = labels_sub == c if mc.sum() > 1: ec = embs_norm[mc]; cc = ec @ ec.T; np.fill_diagonal(cc, 0) sims.extend(cc[np.triu_indices(ec.shape[0], k=1)].tolist()) if sims: alignment = float(np.mean(sims)) diff = embs_norm[:, None, :] - embs_norm[None, :, :] sq_dist = (diff ** 2).sum(-1); umask = ~np.eye(n, dtype=bool) uniformity = float(np.log(np.exp(-2 * sq_dist[umask]).mean() + 1e-12)) return { "isotropy_minmax": float(isotropy), "isotropy_ratio": float(isotropy_ratio), "alignment": float(alignment), "uniformity": float(uniformity), "effective_rank_95": eff_rank_95, "effective_rank_99": eff_rank_99, "avg_cos_sim": float(avg_cos), "std_per_dim": float(embeddings.std(axis=0).mean()), "embed_dim": D, "num_nodes": N, } def linear_eval_multi(embeddings, labels, train_mask, val_mask, test_mask, num_seeds=20, max_iter=2000): results = {} for norm_name, norm_fn in [ ("raw", lambda e: e), ("l2", lambda e: e / (np.linalg.norm(e, axis=1, keepdims=True) + 1e-8)), ("standard", lambda e: StandardScaler().fit_transform(e)), ]: embs = norm_fn(embeddings.copy()) X_train, y_train = embs[train_mask], labels[train_mask] X_val, y_val = embs[val_mask], labels[val_mask] X_test, y_test = embs[test_mask], labels[test_mask] test_accs = [] for seed in range(num_seeds): best_val, best_test = -1, -1 for C in [0.001, 0.01, 0.1, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0]: clf = LogisticRegression(max_iter=max_iter, solver='lbfgs', random_state=seed, C=C) clf.fit(X_train, y_train) v = clf.score(X_val, y_val) if v > best_val: best_val = v best_test = clf.score(X_test, y_test) test_accs.append(best_test * 100) results[norm_name] = {"mean": float(np.mean(test_accs)), "std": float(np.std(test_accs))} return results def measure_inference_time(model, data, num_runs=100, warmup=10): model.eval(); device = next(model.parameters()).device; data = data.to(device) with torch.no_grad(): for _ in range(warmup): _ = model.encode(data) if device.type == 'cuda': torch.cuda.synchronize() times = [] with torch.no_grad(): for _ in range(num_runs): if device.type == 'cuda': torch.cuda.synchronize() t0 = time.perf_counter(); _ = model.encode(data) if device.type == 'cuda': torch.cuda.synchronize() times.append((time.perf_counter() - t0) * 1000) return {"mean_ms": float(np.mean(times)), "std_ms": float(np.std(times)), "median_ms": float(np.median(times)), "num_nodes": data.num_nodes, "ms_per_1k_nodes": float(np.mean(times) / (data.num_nodes / 1000))} def get_lr(epoch, warmup_epochs, total_epochs, base_lr, min_lr=1e-6): if epoch < warmup_epochs: return base_lr * (epoch + 1) / warmup_epochs progress = (epoch - warmup_epochs) / max(total_epochs - warmup_epochs, 1) return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * progress)) / 2 def train_and_evaluate(dataset_name, data, device, config, seed=42): torch.manual_seed(seed); np.random.seed(seed) if device.type == 'cuda': torch.cuda.manual_seed(seed) epochs, lr, warmup = config['epochs'], config['lr'], config['warmup'] model = NodeJEPA( in_dim=data.num_features, hidden_dim=config['hidden_dim'], out_dim=config['out_dim'], num_layers=config['num_layers'], num_heads=config.get('num_heads', 4), predictor_hidden=config['hidden_dim'], mask_rate=config['mask_rate'], ema_tau_base=config['ema_tau_base'], aug1_feat_p=config['aug1_feat_p'], aug1_edge_p=config['aug1_edge_p'], aug2_feat_p=config['aug2_feat_p'], aug2_edge_p=config['aug2_edge_p'], lambda_dec=config.get('lambda_dec', 0.0), mu_rec=config.get('mu_rec', 0.0), encoder_type=config.get('encoder_type', 'gat'), ).to(device) optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=config.get('weight_decay', 1e-5)) ema_sched = cosine_scheduler(config['ema_tau_base'], 1.0, epochs) t0 = time.time(); eval_interval = max(1, epochs // 10); log_interval = max(1, epochs // 30) best_eval, best_ep, train_losses = 0, 0, [] for ep in range(epochs): model.train(); cur_lr = get_lr(ep, warmup, epochs, lr) for pg in optimizer.param_groups: pg['lr'] = cur_lr total_loss, l_jepa, l_dec, l_rec = model(data) optimizer.zero_grad(); total_loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0) optimizer.step(); model.update_target(ema_sched[ep]) train_losses.append(total_loss.item()) if (ep + 1) % log_interval == 0 or ep == 0: print(f" [{dataset_name}] Ep {ep+1:4d}/{epochs} | L={total_loss.item():.4f} (j={l_jepa.item():.4f} d={l_dec.item():.4f} r={l_rec.item():.4f}) | LR={cur_lr:.1e} | {time.time()-t0:.0f}s"); sys.stdout.flush() if (ep + 1) % eval_interval == 0: model.eval() with torch.no_grad(): embs = model.encode(data).cpu().numpy() ln = data.y.cpu().numpy() if len(ln.shape) > 1: ln = ln.squeeze() X_tr, y_tr = embs[data.train_mask.cpu().numpy()], ln[data.train_mask.cpu().numpy()] X_te, y_te = embs[data.test_mask.cpu().numpy()], ln[data.test_mask.cpu().numpy()] clf = LogisticRegression(max_iter=2000, solver='lbfgs', C=1.0); clf.fit(X_tr, y_tr) acc = clf.score(X_te, y_te) * 100 if acc > best_eval: best_eval = acc; best_ep = ep + 1 print(f" [{dataset_name}] EVAL @{ep+1}: raw={acc:.1f}% | BEST={best_eval:.1f}%@{best_ep}"); sys.stdout.flush() train_time = time.time() - t0 print(f"\n [{dataset_name}] === FINAL COMPREHENSIVE EVALUATION ===") model.eval() with torch.no_grad(): embeddings = model.encode(data).cpu().numpy() labels_np = data.y.cpu().numpy() if len(labels_np.shape) > 1: labels_np = labels_np.squeeze() train_mask = data.train_mask.cpu().numpy() val_mask = data.val_mask.cpu().numpy() if hasattr(data, 'val_mask') and data.val_mask is not None else train_mask test_mask = data.test_mask.cpu().numpy() print(f" [{dataset_name}] Running 20-seed linear probe...") probe = linear_eval_multi(embeddings, labels_np, train_mask, val_mask, test_mask, num_seeds=20) best_acc = max(probe[k]['mean'] for k in probe) best_norm = max(probe, key=lambda k: probe[k]['mean']) print(f" [{dataset_name}] ACCURACY: raw={probe['raw']['mean']:.2f}+/-{probe['raw']['std']:.2f} | l2={probe['l2']['mean']:.2f}+/-{probe['l2']['std']:.2f} | std={probe['standard']['mean']:.2f}+/-{probe['standard']['std']:.2f}") print(f" [{dataset_name}] BEST: {best_acc:.2f}% ({best_norm})") print(f" [{dataset_name}] Computing embedding geometry...") geometry = compute_geometry_metrics(embeddings, labels_np) print(f" [{dataset_name}] GEOMETRY: isotropy={geometry['isotropy_ratio']:.6f} align={geometry['alignment']:.4f} uniform={geometry['uniformity']:.4f} rank95={geometry['effective_rank_95']} cos={geometry['avg_cos_sim']:.4f}") print(f" [{dataset_name}] Measuring inference time...") timing = measure_inference_time(model, data) print(f" [{dataset_name}] TIMING: {timing['mean_ms']:.2f}+/-{timing['std_ms']:.2f}ms | {timing['ms_per_1k_nodes']:.2f}ms/1K nodes") checks = {"isotropy_ratio > 0.3": geometry['isotropy_ratio'] > 0.3, "alignment 0.4-0.8": 0.4 <= geometry['alignment'] <= 0.8, "uniformity < -2.0": geometry['uniformity'] < -2.0, "inference <= 5ms/1K": timing['ms_per_1k_nodes'] <= 5.0} print(f" [{dataset_name}] CHECKS: " + " | ".join(f"{'PASS' if v else 'FAIL'}:{k}" for k, v in checks.items())) return {"dataset": dataset_name, "config": {k: v for k, v in config.items()}, "seed": seed, "accuracy": probe, "best_accuracy": best_acc, "best_norm": best_norm, "geometry": geometry, "timing": timing, "train_time": train_time, "train_losses": train_losses[-10:], "checks": {k: bool(v) for k, v in checks.items()}} CONFIGS = { "Cora": {'epochs': 600, 'lr': 1e-3, 'warmup': 100, 'hidden_dim': 512, 'out_dim': 256, 'num_layers': 2, 'num_heads': 4, 'mask_rate': 0.5, 'ema_tau_base': 0.99, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.5, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.5, 'encoder_type': 'gcn', 'lambda_dec': 0.0, 'mu_rec': 2.0}, "CiteSeer": {'epochs': 600, 'lr': 3e-4, 'warmup': 200, 'hidden_dim': 512, 'out_dim': 256, 'num_layers': 2, 'num_heads': 4, 'mask_rate': 0.5, 'ema_tau_base': 0.99, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.4, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.4, 'encoder_type': 'gat', 'lambda_dec': 0.0, 'mu_rec': 2.0}, "PubMed": {'epochs': 600, 'lr': 1e-3, 'warmup': 100, 'hidden_dim': 512, 'out_dim': 256, 'num_layers': 2, 'num_heads': 4, 'mask_rate': 0.6, 'ema_tau_base': 0.99, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.4, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.5, 'encoder_type': 'gcn', 'lambda_dec': 0.0, 'mu_rec': 2.0}, } ARXIV_CONFIG = {'epochs': 300, 'lr': 1e-3, 'warmup': 100, 'hidden_dim': 1024, 'out_dim': 512, 'num_layers': 3, 'num_heads': 4, 'mask_rate': 0.5, 'ema_tau_base': 0.996, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.4, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.4, 'encoder_type': 'gcn', 'lambda_dec': 0.0, 'mu_rec': 2.0} def main(): torch.manual_seed(42); np.random.seed(42) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") if device.type == "cuda": print(f" GPU: {torch.cuda.get_device_name(0)}") try: print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") except AttributeError: print(f" VRAM: unknown") torch.cuda.manual_seed(42) tracker = None if HAS_TRACKIO: try: trackio.init(name="node-jepa-v4.2-baseline", project="node-jepa"); tracker = True; print("Trackio initialized") except Exception as e: print(f"Trackio init failed: {e}") all_results = {}; ssl_sota = {"Cora": 84.2, "CiteSeer": 73.4, "PubMed": 81.1} for ds in ["Cora", "CiteSeer", "PubMed"]: print(f"\n{'='*80}\nBENCHMARK: {ds} (SSL SOTA: {ssl_sota[ds]}%)\n{'='*80}") dataset = Planetoid(root="/tmp/pyg_data", name=ds, transform=NormalizeFeatures()) data = dataset[0].to(device) print(f" {data.num_nodes} nodes, {data.num_edges} edges, {data.num_features} feat, {dataset.num_classes} classes") seeds_res = [] for seed in [42, 123, 456]: print(f"\n --- Seed {seed} ---") r = train_and_evaluate(ds, data, device, CONFIGS[ds], seed=seed); seeds_res.append(r) if tracker: trackio.log({f"{ds}/s{seed}/acc": r['best_accuracy'], f"{ds}/s{seed}/rank95": r['geometry']['effective_rank_95']}) accs = [r['best_accuracy'] for r in seeds_res] all_results[ds] = {"seeds": seeds_res, "mean_accuracy": float(np.mean(accs)), "std_accuracy": float(np.std(accs)), "max_accuracy": float(np.max(accs)), "ssl_sota": ssl_sota[ds], "gap_to_sota": float(ssl_sota[ds] - np.mean(accs))} print(f"\n [{ds}] SUMMARY: {np.mean(accs):.2f}+/-{np.std(accs):.2f}% (max {np.max(accs):.2f}%) | SOTA: {ssl_sota[ds]}% | Gap: {ssl_sota[ds]-np.mean(accs):.1f}pts") del data, dataset; gc.collect() if device.type == 'cuda': torch.cuda.empty_cache() if HAS_OGB and device.type == 'cuda': print(f"\n{'='*80}\nBENCHMARK: ogbn-arxiv (SSL SOTA: 71.87%)\n{'='*80}") try: dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='/tmp/ogb_data'); data = dataset[0] si = dataset.get_idx_split(); N = data.x.size(0) tm, vm, tsm = torch.zeros(N, dtype=torch.bool), torch.zeros(N, dtype=torch.bool), torch.zeros(N, dtype=torch.bool) tm[si['train']] = True; vm[si['valid']] = True; tsm[si['test']] = True data.train_mask, data.val_mask, data.test_mask = tm, vm, tsm; data.y = data.y.squeeze() ei = data.edge_index; data.edge_index = torch.cat([ei, torch.stack([ei[1], ei[0]])], dim=1) data = data.to(device); print(f" {data.num_nodes} nodes, {data.edge_index.size(1)} edges") r = train_and_evaluate("ogbn-arxiv", data, device, ARXIV_CONFIG, seed=42) all_results["ogbn-arxiv"] = {"seeds": [r], "mean_accuracy": r['best_accuracy'], "ssl_sota": 71.87, "gap_to_ssl_sota": 71.87 - r['best_accuracy']} except Exception as e: print(f" ogbn-arxiv FAILED: {e}"); import traceback; traceback.print_exc() print(f"\n{'='*100}\nNODE-JEPA V4.2 — COMPREHENSIVE BASELINE RESULTS\n{'='*100}") print(f"\n{'Dataset':<15} {'Ours':>10} {'SOTA':>8} {'Gap':>7} {'Isotropy':>10} {'Align':>8} {'Uniform':>10} {'Rank95':>8} {'ms/1K':>8}") print("-" * 95) for ds in ["Cora", "CiteSeer", "PubMed", "ogbn-arxiv"]: if ds not in all_results or "error" in all_results.get(ds, {}): print(f" {ds:<13} {'SKIP':>10}"); continue r = all_results[ds]; g = r['seeds'][0]['geometry']; t = r['seeds'][0]['timing'] print(f" {ds:<13} {r['mean_accuracy']:>9.2f}% {r['ssl_sota']:>7.1f}% {r.get('gap_to_sota', r['ssl_sota']-r['mean_accuracy']):>+6.1f} {g['isotropy_ratio']:>9.4f} {g['alignment']:>7.4f} {g['uniformity']:>9.4f} {g['effective_rank_95']:>8} {t['ms_per_1k_nodes']:>7.2f}") print(f"\nRef: GraphMAE (KDD'22) | Note: ogbn-arxiv SSL ceiling ~72%, ogbn-products has NO published SSL baseline") out = "/tmp/v4.2_baseline_results.json" with open(out, "w") as f: json.dump(all_results, f, indent=2, default=str) if HAS_HF_HUB: try: api = HfApi() api.upload_file(path_or_fileobj=out, path_in_repo="v4.2_baseline_results.json", repo_id="EPSAGR/Node-JEPA", repo_type="model") print(f"Pushed to https://huggingface.co/EPSAGR/Node-JEPA") except: pass if tracker: trackio.finish() print("\nDone!") if __name__ == "__main__": main()