| """ |
| Node-JEPA V4.2 — Comprehensive Baseline Evaluation |
| ==================================================== |
| Evaluates Node-JEPA on standard benchmarks with: |
| 1. Linear probe accuracy (raw, L2-norm, StandardScaler) |
| 2. Embedding geometry metrics (isotropy, alignment, uniformity, effective rank) |
| 3. Inference timing |
| 4. Multi-seed robustness |
| |
Benchmarks: Cora, CiteSeer, PubMed, ogbn-arxiv (requires a GPU and the ogb package; see main())
| |
| Published SSL SOTA for reference: |
| Cora: 84.2% (GraphMAE) |
| CiteSeer: 73.4% (GraphMAE) |
| PubMed: 81.1% (GraphMAE) |
| ogbn-arxiv: 71.87% (GraphMAE) |
| |
| Run in Colab: |
| !pip install torch torch_geometric trackio huggingface_hub ogb scikit-learn |
| !wget https://huggingface.co/EPSAGR/node-jepa/resolve/main/node_jepa_baseline_eval.py |
| !python node_jepa_baseline_eval.py |
| """ |
|
|
| import math, copy, os, sys, json, time, gc |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| import warnings |
| warnings.filterwarnings("ignore", category=UserWarning) |
| warnings.filterwarnings("ignore", category=FutureWarning) |
|
|
| try: |
| from torch_geometric.nn import GATConv, GCNConv |
| from torch_geometric.datasets import Planetoid |
| from torch_geometric.transforms import NormalizeFeatures |
| HAS_PYG = True |
| except ImportError: |
| HAS_PYG = False |
| print("ERROR: torch_geometric not installed. Run: pip install torch_geometric") |
| sys.exit(1) |
|
|
| try: |
| from ogb.nodeproppred import PygNodePropPredDataset |
| HAS_OGB = True |
| except ImportError: |
| HAS_OGB = False |
| print("WARNING: ogb not installed, skipping ogbn-arxiv. Run: pip install ogb") |
|
|
| from sklearn.linear_model import LogisticRegression |
| from sklearn.preprocessing import StandardScaler |
|
|
try:
    import trackio
    HAS_TRACKIO = True
except ImportError:
    HAS_TRACKIO = False
|
|
try:
    from huggingface_hub import HfApi
    HAS_HF_HUB = True
except ImportError:
    HAS_HF_HUB = False
|
|
|
|
| def decorrelation_loss(z): |
| N, D = z.shape |
| z_centered = z - z.mean(0) |
| z_std = z_centered / (z_centered.std(0) + 1e-8) |
| R = (z_std.T @ z_std) / N |
| I = torch.eye(D, device=z.device) |
| return (R - I).pow(2).sum() / (D * D) |
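
# Sanity check (a hedged sketch, not part of the eval): for standardized features the
# correlation matrix R has ~1 on the diagonal, so (R - I) penalizes only cross-feature
# correlation. Independent dimensions give a loss near 0, duplicated dimensions do not:
#   >>> z = torch.randn(10000, 8)                         # independent dims -> R ~ I
#   >>> decorrelation_loss(z).item()                      # ~1e-4 (off-diagonals ~ 1/sqrt(N))
#   >>> decorrelation_loss(z[:, :1].repeat(1, 8)).item()  # identical dims -> large (~0.9)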
|
|
| def sce_loss(pred, target, alpha=1): |
| pred = F.normalize(pred, p=2, dim=-1) |
| target = F.normalize(target, p=2, dim=-1) |
| return (1 - (pred * target).sum(dim=-1)).pow(alpha).mean() |
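
# sce_loss is the scaled cosine error used by GraphMAE: 0 for aligned vectors, 1 for
# orthogonal, 2 for opposite (before the alpha power). Quick check (illustrative only):
#   >>> a = torch.tensor([[1.0, 0.0]]); b = torch.tensor([[0.0, 1.0]])
#   >>> sce_loss(a, a).item()   # 0.0
#   >>> sce_loss(a, b).item()   # 1.0
#   >>> sce_loss(a, -a).item()  # 2.0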
|
|
|
|
| class GraphAugmentor: |
| def __init__(self, feat_mask_p=0.0, edge_drop_p=0.0): |
| self.feat_mask_p = feat_mask_p |
| self.edge_drop_p = edge_drop_p |
| def __call__(self, x, edge_index): |
| device = x.device |
| if self.feat_mask_p > 0: |
| feat_mask = torch.bernoulli(torch.ones_like(x) * (1 - self.feat_mask_p)) |
| x = x * feat_mask |
| if self.edge_drop_p > 0: |
| edge_mask = torch.bernoulli(torch.ones(edge_index.size(1), device=device) * (1 - self.edge_drop_p)).bool() |
| edge_index = edge_index[:, edge_mask] |
| return x, edge_index |
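
# Usage sketch (illustrative probabilities, not the tuned configs below): each call
# draws a fresh random view, so repeated calls yield different feature masks and
# edge subsets.
#   >>> aug = GraphAugmentor(feat_mask_p=0.3, edge_drop_p=0.5)
#   >>> x_view, ei_view = aug(data.x, data.edge_index)
#   # ~30% of feature entries zeroed, ~50% of edges dropped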
|
|
| class GATEncoder(nn.Module): |
| def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, num_heads=4, dropout=0.0): |
| super().__init__() |
| self.num_layers = num_layers |
| self.reg_token = nn.Parameter(torch.zeros(1, in_dim)) |
| nn.init.trunc_normal_(self.reg_token, std=0.02) |
| self.input_proj = nn.Linear(in_dim, hidden_dim) |
| self.input_bn = nn.BatchNorm1d(hidden_dim, momentum=0.01) |
| self.gat_layers = nn.ModuleList() |
| self.bns = nn.ModuleList() |
| self.activations = nn.ModuleList() |
| for _ in range(num_layers): |
| self.gat_layers.append(GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads, |
| concat=True, dropout=dropout, add_self_loops=True)) |
| self.bns.append(nn.BatchNorm1d(hidden_dim, momentum=0.01)) |
| self.activations.append(nn.PReLU()) |
| self.output_proj = nn.Linear(hidden_dim, out_dim) |
    def forward(self, x, edge_index):
        device = x.device; N = x.size(0)
        # Append a learnable register node connected bidirectionally to every real
        # node; it acts as a global scratch slot and is stripped before output.
        x_aug = torch.cat([x, self.reg_token.to(device)], dim=0)
        reg_edges = torch.cat([
            torch.stack([torch.full((N,), N, dtype=torch.long, device=device), torch.arange(N, device=device)]),
            torch.stack([torch.arange(N, device=device), torch.full((N,), N, dtype=torch.long, device=device)])], dim=1)
        ei_aug = torch.cat([edge_index, reg_edges], dim=1)
        # Fixed-slope (0.25) functional PReLU on the input projection; the per-layer PReLUs are learnable.
        h = F.prelu(self.input_bn(self.input_proj(x_aug)), torch.tensor(0.25, device=device))
        for i in range(self.num_layers):
            h = self.activations[i](self.bns[i](self.gat_layers[i](h, ei_aug))) + h  # residual
        return self.output_proj(h)[:N]  # drop the register node
|
|
| class GCNEncoder(nn.Module): |
| def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, dropout=0.0): |
| super().__init__() |
| self.num_layers = num_layers |
| self.input_proj = nn.Linear(in_dim, hidden_dim) |
| self.input_bn = nn.BatchNorm1d(hidden_dim, momentum=0.01) |
| self.gcn_layers = nn.ModuleList() |
| self.bns = nn.ModuleList() |
| self.activations = nn.ModuleList() |
| for _ in range(num_layers): |
| self.gcn_layers.append(GCNConv(hidden_dim, hidden_dim)) |
| self.bns.append(nn.BatchNorm1d(hidden_dim, momentum=0.01)) |
| self.activations.append(nn.PReLU()) |
| self.output_proj = nn.Linear(hidden_dim, out_dim) |
| def forward(self, x, edge_index): |
| h = F.prelu(self.input_bn(self.input_proj(x)), torch.tensor(0.25, device=x.device)) |
| for i in range(self.num_layers): |
| h = self.activations[i](self.bns[i](self.gcn_layers[i](h, edge_index))) + h |
| return self.output_proj(h) |
|
|
| class MLPPredictor(nn.Module): |
| def __init__(self, dim, hidden=512): |
| super().__init__() |
| self.net = nn.Sequential(nn.Linear(dim, hidden), nn.BatchNorm1d(hidden, momentum=0.01), nn.PReLU(), nn.Linear(hidden, dim)) |
| def forward(self, x): return self.net(x) |
|
|
| class FeatureDecoder(nn.Module): |
| def __init__(self, embed_dim, hidden_dim, out_dim): |
| super().__init__() |
| self.net = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.BatchNorm1d(hidden_dim, momentum=0.01), nn.PReLU(), nn.Linear(hidden_dim, out_dim)) |
| def forward(self, x): return self.net(x) |
|
|
| class NodeMasker(nn.Module): |
| def __init__(self, feat_dim, mask_rate=0.5, replace_rate=0.05): |
| super().__init__() |
| self.mask_rate = mask_rate; self.replace_rate = replace_rate |
| self.mask_token = nn.Parameter(torch.zeros(1, feat_dim)) |
| nn.init.trunc_normal_(self.mask_token, std=0.02) |
| def forward(self, x): |
| N = x.size(0); num_mask = max(1, int(N * self.mask_rate)) |
| perm = torch.randperm(N, device=x.device) |
| mask = torch.zeros(N, dtype=torch.bool, device=x.device); mask[perm[:num_mask]] = True |
| x_m = x.clone(); nr = max(0, int(num_mask * self.replace_rate)) |
| x_m[perm[:num_mask - nr]] = self.mask_token.expand(num_mask - nr, -1) |
| if nr > 0: x_m[perm[num_mask - nr:num_mask]] = x[torch.randint(0, N, (nr,), device=x.device)] |
| return x_m, mask |
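
# Masking semantics (worked example with mask_rate=0.5, replace_rate=0.05, N=100):
# num_mask = 50 nodes are selected as prediction targets; 48 of them receive the
# learnable mask token and nr = int(50 * 0.05) = 2 receive features copied from
# random other nodes, a GraphMAE-style trick so the encoder cannot key on the mask
# token itself. The returned boolean `mask` flags all 50 selected nodes.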
|
|
|
|
| class NodeJEPA(nn.Module): |
| def __init__(self, in_dim, hidden_dim=512, out_dim=256, num_layers=2, num_heads=4, |
| predictor_hidden=512, mask_rate=0.5, replace_rate=0.05, |
| ema_tau_base=0.99, ema_tau_final=1.0, |
| aug1_feat_p=0.0, aug1_edge_p=0.0, aug2_feat_p=0.0, aug2_edge_p=0.0, |
| lambda_dec=0.0, mu_rec=0.0, sce_alpha=1, encoder_type='gat', dropout=0.0): |
| super().__init__() |
| self.ema_tau_base = ema_tau_base; self.ema_tau_final = ema_tau_final |
| self.lambda_dec = lambda_dec; self.mu_rec = mu_rec; self.sce_alpha = sce_alpha |
| if encoder_type == 'gat': |
| self.online_encoder = GATEncoder(in_dim, hidden_dim, out_dim, num_layers, num_heads, dropout) |
| else: |
| self.online_encoder = GCNEncoder(in_dim, hidden_dim, out_dim, num_layers, dropout) |
| self.target_encoder = copy.deepcopy(self.online_encoder) |
| for p in self.target_encoder.parameters(): p.requires_grad = False |
| self.predictor = MLPPredictor(out_dim, predictor_hidden) |
| self.masker = NodeMasker(in_dim, mask_rate, replace_rate) |
| self.aug1 = GraphAugmentor(aug1_feat_p, aug1_edge_p) |
| self.aug2 = GraphAugmentor(aug2_feat_p, aug2_edge_p) |
| self.feat_decoder = FeatureDecoder(out_dim, hidden_dim, in_dim) if mu_rec > 0 else None |
|
|
| @torch.no_grad() |
| def update_target(self, tau): |
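        # EMA momentum update: target <- tau * target + (1 - tau) * online; tau ramps to 1.0 via cosine_scheduler.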
| for po, pt in zip(self.online_encoder.parameters(), self.target_encoder.parameters()): |
| pt.data.mul_(tau).add_(po.data, alpha=1 - tau) |
|
|
| def _jepa_loss(self, x, edge_index, aug_o, aug_t, do_mask=True): |
| xo, eio = aug_o(x, edge_index); xt, eit = aug_t(x, edge_index) |
| if do_mask: xo, mask = self.masker(xo) |
| else: mask = torch.ones(x.size(0), dtype=torch.bool, device=x.device) |
| ho = self.online_encoder(xo, eio) |
| with torch.no_grad(): ht = self.target_encoder(xt, eit) |
| pred = F.normalize(self.predictor(ho[mask]), dim=-1) |
| tgt = F.normalize(ht[mask], dim=-1) |
| return 2 - 2 * (pred * tgt).sum(-1).mean(), ho, mask |
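
    # Objective note (derivation, for reference): for unit vectors p and t,
    # ||p - t||^2 = ||p||^2 + ||t||^2 - 2 p.t = 2 - 2*cos(p, t), so the value returned
    # above is exactly the BYOL-style normalized MSE between the predictor output and
    # the EMA-target embedding.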
|
|
| def forward(self, data): |
| x, edge_index = data.x, data.edge_index |
| l1, ho1, m1 = self._jepa_loss(x, edge_index, self.aug1, self.aug2, True) |
| l2, ho2, m2 = self._jepa_loss(x, edge_index, self.aug2, self.aug1, False) |
| l_jepa = (l1 + l2) / 2 |
| l_dec = self.lambda_dec * (decorrelation_loss(ho1) + decorrelation_loss(ho2)) / 2 if self.lambda_dec > 0 else torch.tensor(0.0, device=x.device) |
| l_rec = torch.tensor(0.0, device=x.device) |
| if self.mu_rec > 0 and self.feat_decoder is not None: |
| l_rec = self.mu_rec * sce_loss(self.feat_decoder(ho1[m1]), x[m1], self.sce_alpha) |
| return l_jepa + l_dec + l_rec, l_jepa, l_dec, l_rec |
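
    # Total objective: L = L_jepa + lambda_dec * L_decor + mu_rec * L_sce (masked-node
    # feature reconstruction). The configs below set lambda_dec=0.0 and mu_rec=2.0,
    # so only the JEPA and reconstruction terms are active.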
|
|
| @torch.no_grad() |
| def encode(self, data): |
| self.online_encoder.eval() |
| return self.online_encoder(data.x, data.edge_index) |
|
|
|
|
| def cosine_scheduler(base, final, steps): |
| return [final - (final - base) * (math.cos(math.pi * s / max(steps - 1, 1)) + 1) / 2 for s in range(steps)] |
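
# Worked example (values rounded): cosine_scheduler(0.99, 1.0, 5)
#   -> [0.99, 0.99146, 0.995, 0.99854, 1.0]
# The EMA momentum ramps from tau_base to 1.0 along a half-cosine, so the target
# encoder is effectively frozen by the final epoch.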
|
|
|
|
| def compute_geometry_metrics(embeddings, labels=None, sample_size=2000): |
| N, D = embeddings.shape |
| idx = np.random.choice(N, min(N, sample_size), replace=False) if N > sample_size else np.arange(N) |
| embs_sub = embeddings[idx]; labels_sub = labels[idx] if labels is not None else None |
| embs_centered = embs_sub - embs_sub.mean(0) |
    try: U, S, Vt = np.linalg.svd(embs_centered, full_matrices=False)
    except np.linalg.LinAlgError: S = np.ones(min(embs_sub.shape))
| total_var = (S**2).sum() |
| if total_var < 1e-12: eff_rank_95, eff_rank_99 = 1, 1 |
| else: |
| ev = np.cumsum(S**2) / total_var |
| eff_rank_95 = int(np.searchsorted(ev, 0.95) + 1); eff_rank_99 = int(np.searchsorted(ev, 0.99) + 1) |
| if len(S) > 1 and S[0] > 1e-12: |
| isotropy = float(S[-1] / S[0]) |
| log_s2 = np.log(S**2 + 1e-12) |
| isotropy_ratio = np.exp(log_s2.mean()) / ((S**2).mean() + 1e-12) |
| else: isotropy = 0.0; isotropy_ratio = 0.0 |
| norms = np.linalg.norm(embs_sub, axis=1, keepdims=True) + 1e-8 |
| embs_norm = embs_sub / norms; n = embs_sub.shape[0] |
| cos_matrix = embs_norm @ embs_norm.T; np.fill_diagonal(cos_matrix, 0) |
| avg_cos = cos_matrix.sum() / (n * (n - 1)) |
| alignment = 0.0 |
| if labels_sub is not None: |
| sims = [] |
| for c in np.unique(labels_sub): |
| mc = labels_sub == c |
| if mc.sum() > 1: |
| ec = embs_norm[mc]; cc = ec @ ec.T; np.fill_diagonal(cc, 0) |
| sims.extend(cc[np.triu_indices(ec.shape[0], k=1)].tolist()) |
| if sims: alignment = float(np.mean(sims)) |
    # For unit vectors ||a - b||^2 = 2 - 2*cos(a, b), so compute pairwise squared
    # distances from the Gram matrix instead of materializing an (n, n, D) tensor.
    sq_dist = np.clip(2.0 - 2.0 * (embs_norm @ embs_norm.T), 0.0, None); umask = ~np.eye(n, dtype=bool)
    uniformity = float(np.log(np.exp(-2 * sq_dist[umask]).mean() + 1e-12))
| return { |
| "isotropy_minmax": float(isotropy), "isotropy_ratio": float(isotropy_ratio), |
| "alignment": float(alignment), "uniformity": float(uniformity), |
| "effective_rank_95": eff_rank_95, "effective_rank_99": eff_rank_99, |
| "avg_cos_sim": float(avg_cos), "std_per_dim": float(embeddings.std(axis=0).mean()), |
| "embed_dim": D, "num_nodes": N, |
| } |
|
|
|
|
| def linear_eval_multi(embeddings, labels, train_mask, val_mask, test_mask, num_seeds=20, max_iter=2000): |
| results = {} |
| for norm_name, norm_fn in [ |
| ("raw", lambda e: e), |
| ("l2", lambda e: e / (np.linalg.norm(e, axis=1, keepdims=True) + 1e-8)), |
| ("standard", lambda e: StandardScaler().fit_transform(e)), |
| ]: |
| embs = norm_fn(embeddings.copy()) |
| X_train, y_train = embs[train_mask], labels[train_mask] |
| X_val, y_val = embs[val_mask], labels[val_mask] |
| X_test, y_test = embs[test_mask], labels[test_mask] |
        test_accs = []
        # Note: the lbfgs solver is deterministic (sklearn uses random_state only for
        # stochastic solvers), so repeated "seeds" here produce identical fits and the
        # reported std will be ~0; the loop is kept for interface parity.
        for seed in range(num_seeds):
| best_val, best_test = -1, -1 |
| for C in [0.001, 0.01, 0.1, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0]: |
| clf = LogisticRegression(max_iter=max_iter, solver='lbfgs', random_state=seed, C=C) |
| clf.fit(X_train, y_train) |
| v = clf.score(X_val, y_val) |
| if v > best_val: |
| best_val = v |
| best_test = clf.score(X_test, y_test) |
| test_accs.append(best_test * 100) |
| results[norm_name] = {"mean": float(np.mean(test_accs)), "std": float(np.std(test_accs))} |
| return results |
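
# Usage sketch (shapes illustrative): with embeddings of shape (N, D), integer labels
# of shape (N,), and boolean masks of shape (N,), the return value looks like
#   {"raw": {"mean": ..., "std": ...}, "l2": {...}, "standard": {...}}
# where each mean/std is test accuracy (%) under one feature normalization, with the
# inverse-regularization C selected on the validation split.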
|
|
| def measure_inference_time(model, data, num_runs=100, warmup=10): |
| model.eval(); device = next(model.parameters()).device; data = data.to(device) |
| with torch.no_grad(): |
| for _ in range(warmup): _ = model.encode(data) |
| if device.type == 'cuda': torch.cuda.synchronize() |
| times = [] |
| with torch.no_grad(): |
| for _ in range(num_runs): |
| if device.type == 'cuda': torch.cuda.synchronize() |
| t0 = time.perf_counter(); _ = model.encode(data) |
| if device.type == 'cuda': torch.cuda.synchronize() |
| times.append((time.perf_counter() - t0) * 1000) |
| return {"mean_ms": float(np.mean(times)), "std_ms": float(np.std(times)), |
| "median_ms": float(np.median(times)), "num_nodes": data.num_nodes, |
| "ms_per_1k_nodes": float(np.mean(times) / (data.num_nodes / 1000))} |
|
|
|
|
| def get_lr(epoch, warmup_epochs, total_epochs, base_lr, min_lr=1e-6): |
| if epoch < warmup_epochs: return base_lr * (epoch + 1) / warmup_epochs |
| progress = (epoch - warmup_epochs) / max(total_epochs - warmup_epochs, 1) |
| return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * progress)) / 2 |
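
# Worked example: get_lr(0, 100, 600, 1e-3) = 1e-5 (first linear-warmup step) and
# get_lr(99, 100, 600, 1e-3) = 1e-3 (warmup complete); afterwards the rate follows a
# half-cosine decay from base_lr down to ~min_lr at the final epoch.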
|
|
| def train_and_evaluate(dataset_name, data, device, config, seed=42): |
| torch.manual_seed(seed); np.random.seed(seed) |
| if device.type == 'cuda': torch.cuda.manual_seed(seed) |
| epochs, lr, warmup = config['epochs'], config['lr'], config['warmup'] |
| model = NodeJEPA( |
| in_dim=data.num_features, hidden_dim=config['hidden_dim'], out_dim=config['out_dim'], |
| num_layers=config['num_layers'], num_heads=config.get('num_heads', 4), |
| predictor_hidden=config['hidden_dim'], mask_rate=config['mask_rate'], |
| ema_tau_base=config['ema_tau_base'], |
| aug1_feat_p=config['aug1_feat_p'], aug1_edge_p=config['aug1_edge_p'], |
| aug2_feat_p=config['aug2_feat_p'], aug2_edge_p=config['aug2_edge_p'], |
| lambda_dec=config.get('lambda_dec', 0.0), mu_rec=config.get('mu_rec', 0.0), |
| encoder_type=config.get('encoder_type', 'gat'), |
| ).to(device) |
| optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=config.get('weight_decay', 1e-5)) |
| ema_sched = cosine_scheduler(config['ema_tau_base'], 1.0, epochs) |
| t0 = time.time(); eval_interval = max(1, epochs // 10); log_interval = max(1, epochs // 30) |
| best_eval, best_ep, train_losses = 0, 0, [] |
| for ep in range(epochs): |
| model.train(); cur_lr = get_lr(ep, warmup, epochs, lr) |
| for pg in optimizer.param_groups: pg['lr'] = cur_lr |
| total_loss, l_jepa, l_dec, l_rec = model(data) |
| optimizer.zero_grad(); total_loss.backward() |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0) |
| optimizer.step(); model.update_target(ema_sched[ep]) |
| train_losses.append(total_loss.item()) |
| if (ep + 1) % log_interval == 0 or ep == 0: |
| print(f" [{dataset_name}] Ep {ep+1:4d}/{epochs} | L={total_loss.item():.4f} (j={l_jepa.item():.4f} d={l_dec.item():.4f} r={l_rec.item():.4f}) | LR={cur_lr:.1e} | {time.time()-t0:.0f}s"); sys.stdout.flush() |
| if (ep + 1) % eval_interval == 0: |
| model.eval() |
| with torch.no_grad(): embs = model.encode(data).cpu().numpy() |
| ln = data.y.cpu().numpy() |
| if len(ln.shape) > 1: ln = ln.squeeze() |
| X_tr, y_tr = embs[data.train_mask.cpu().numpy()], ln[data.train_mask.cpu().numpy()] |
| X_te, y_te = embs[data.test_mask.cpu().numpy()], ln[data.test_mask.cpu().numpy()] |
| clf = LogisticRegression(max_iter=2000, solver='lbfgs', C=1.0); clf.fit(X_tr, y_tr) |
| acc = clf.score(X_te, y_te) * 100 |
| if acc > best_eval: best_eval = acc; best_ep = ep + 1 |
| print(f" [{dataset_name}] EVAL @{ep+1}: raw={acc:.1f}% | BEST={best_eval:.1f}%@{best_ep}"); sys.stdout.flush() |
| train_time = time.time() - t0 |
| print(f"\n [{dataset_name}] === FINAL COMPREHENSIVE EVALUATION ===") |
| model.eval() |
| with torch.no_grad(): embeddings = model.encode(data).cpu().numpy() |
| labels_np = data.y.cpu().numpy() |
| if len(labels_np.shape) > 1: labels_np = labels_np.squeeze() |
| train_mask = data.train_mask.cpu().numpy() |
| val_mask = data.val_mask.cpu().numpy() if hasattr(data, 'val_mask') and data.val_mask is not None else train_mask |
| test_mask = data.test_mask.cpu().numpy() |
| print(f" [{dataset_name}] Running 20-seed linear probe...") |
| probe = linear_eval_multi(embeddings, labels_np, train_mask, val_mask, test_mask, num_seeds=20) |
| best_acc = max(probe[k]['mean'] for k in probe) |
| best_norm = max(probe, key=lambda k: probe[k]['mean']) |
| print(f" [{dataset_name}] ACCURACY: raw={probe['raw']['mean']:.2f}+/-{probe['raw']['std']:.2f} | l2={probe['l2']['mean']:.2f}+/-{probe['l2']['std']:.2f} | std={probe['standard']['mean']:.2f}+/-{probe['standard']['std']:.2f}") |
| print(f" [{dataset_name}] BEST: {best_acc:.2f}% ({best_norm})") |
| print(f" [{dataset_name}] Computing embedding geometry...") |
| geometry = compute_geometry_metrics(embeddings, labels_np) |
| print(f" [{dataset_name}] GEOMETRY: isotropy={geometry['isotropy_ratio']:.6f} align={geometry['alignment']:.4f} uniform={geometry['uniformity']:.4f} rank95={geometry['effective_rank_95']} cos={geometry['avg_cos_sim']:.4f}") |
| print(f" [{dataset_name}] Measuring inference time...") |
| timing = measure_inference_time(model, data) |
| print(f" [{dataset_name}] TIMING: {timing['mean_ms']:.2f}+/-{timing['std_ms']:.2f}ms | {timing['ms_per_1k_nodes']:.2f}ms/1K nodes") |
| checks = {"isotropy_ratio > 0.3": geometry['isotropy_ratio'] > 0.3, "alignment 0.4-0.8": 0.4 <= geometry['alignment'] <= 0.8, |
| "uniformity < -2.0": geometry['uniformity'] < -2.0, "inference <= 5ms/1K": timing['ms_per_1k_nodes'] <= 5.0} |
| print(f" [{dataset_name}] CHECKS: " + " | ".join(f"{'PASS' if v else 'FAIL'}:{k}" for k, v in checks.items())) |
| return {"dataset": dataset_name, "config": {k: v for k, v in config.items()}, "seed": seed, |
| "accuracy": probe, "best_accuracy": best_acc, "best_norm": best_norm, |
| "geometry": geometry, "timing": timing, "train_time": train_time, |
| "train_losses": train_losses[-10:], "checks": {k: bool(v) for k, v in checks.items()}} |
|
|
|
|
| CONFIGS = { |
| "Cora": {'epochs': 600, 'lr': 1e-3, 'warmup': 100, 'hidden_dim': 512, 'out_dim': 256, 'num_layers': 2, 'num_heads': 4, 'mask_rate': 0.5, 'ema_tau_base': 0.99, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.5, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.5, 'encoder_type': 'gcn', 'lambda_dec': 0.0, 'mu_rec': 2.0}, |
| "CiteSeer": {'epochs': 600, 'lr': 3e-4, 'warmup': 200, 'hidden_dim': 512, 'out_dim': 256, 'num_layers': 2, 'num_heads': 4, 'mask_rate': 0.5, 'ema_tau_base': 0.99, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.4, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.4, 'encoder_type': 'gat', 'lambda_dec': 0.0, 'mu_rec': 2.0}, |
| "PubMed": {'epochs': 600, 'lr': 1e-3, 'warmup': 100, 'hidden_dim': 512, 'out_dim': 256, 'num_layers': 2, 'num_heads': 4, 'mask_rate': 0.6, 'ema_tau_base': 0.99, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.4, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.5, 'encoder_type': 'gcn', 'lambda_dec': 0.0, 'mu_rec': 2.0}, |
| } |
|
|
| ARXIV_CONFIG = {'epochs': 300, 'lr': 1e-3, 'warmup': 100, 'hidden_dim': 1024, 'out_dim': 512, 'num_layers': 3, 'num_heads': 4, 'mask_rate': 0.5, 'ema_tau_base': 0.996, 'weight_decay': 1e-5, 'aug1_feat_p': 0.5, 'aug1_edge_p': 0.4, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.4, 'encoder_type': 'gcn', 'lambda_dec': 0.0, 'mu_rec': 2.0} |
|
|
|
|
| def main(): |
| torch.manual_seed(42); np.random.seed(42) |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| print(f"Device: {device}") |
| if device.type == "cuda": |
| print(f" GPU: {torch.cuda.get_device_name(0)}") |
| try: |
| print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") |
| except AttributeError: |
| print(f" VRAM: unknown") |
| torch.cuda.manual_seed(42) |
| tracker = None |
| if HAS_TRACKIO: |
| try: trackio.init(name="node-jepa-v4.2-baseline", project="node-jepa"); tracker = True; print("Trackio initialized") |
| except Exception as e: print(f"Trackio init failed: {e}") |
| all_results = {}; ssl_sota = {"Cora": 84.2, "CiteSeer": 73.4, "PubMed": 81.1} |
| for ds in ["Cora", "CiteSeer", "PubMed"]: |
| print(f"\n{'='*80}\nBENCHMARK: {ds} (SSL SOTA: {ssl_sota[ds]}%)\n{'='*80}") |
| dataset = Planetoid(root="/tmp/pyg_data", name=ds, transform=NormalizeFeatures()) |
| data = dataset[0].to(device) |
| print(f" {data.num_nodes} nodes, {data.num_edges} edges, {data.num_features} feat, {dataset.num_classes} classes") |
| seeds_res = [] |
| for seed in [42, 123, 456]: |
| print(f"\n --- Seed {seed} ---") |
| r = train_and_evaluate(ds, data, device, CONFIGS[ds], seed=seed); seeds_res.append(r) |
| if tracker: trackio.log({f"{ds}/s{seed}/acc": r['best_accuracy'], f"{ds}/s{seed}/rank95": r['geometry']['effective_rank_95']}) |
| accs = [r['best_accuracy'] for r in seeds_res] |
| all_results[ds] = {"seeds": seeds_res, "mean_accuracy": float(np.mean(accs)), "std_accuracy": float(np.std(accs)), |
| "max_accuracy": float(np.max(accs)), "ssl_sota": ssl_sota[ds], "gap_to_sota": float(ssl_sota[ds] - np.mean(accs))} |
| print(f"\n [{ds}] SUMMARY: {np.mean(accs):.2f}+/-{np.std(accs):.2f}% (max {np.max(accs):.2f}%) | SOTA: {ssl_sota[ds]}% | Gap: {ssl_sota[ds]-np.mean(accs):.1f}pts") |
| del data, dataset; gc.collect() |
| if device.type == 'cuda': torch.cuda.empty_cache() |
| if HAS_OGB and device.type == 'cuda': |
| print(f"\n{'='*80}\nBENCHMARK: ogbn-arxiv (SSL SOTA: 71.87%)\n{'='*80}") |
| try: |
| dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='/tmp/ogb_data'); data = dataset[0] |
| si = dataset.get_idx_split(); N = data.x.size(0) |
| tm, vm, tsm = torch.zeros(N, dtype=torch.bool), torch.zeros(N, dtype=torch.bool), torch.zeros(N, dtype=torch.bool) |
| tm[si['train']] = True; vm[si['valid']] = True; tsm[si['test']] = True |
| data.train_mask, data.val_mask, data.test_mask = tm, vm, tsm; data.y = data.y.squeeze() |
            # ogbn-arxiv ships as a directed graph; add reverse edges so message
            # passing is bidirectional, matching the Planetoid setup.
            ei = data.edge_index; data.edge_index = torch.cat([ei, torch.stack([ei[1], ei[0]])], dim=1)
| data = data.to(device); print(f" {data.num_nodes} nodes, {data.edge_index.size(1)} edges") |
| r = train_and_evaluate("ogbn-arxiv", data, device, ARXIV_CONFIG, seed=42) |
| all_results["ogbn-arxiv"] = {"seeds": [r], "mean_accuracy": r['best_accuracy'], "ssl_sota": 71.87, "gap_to_ssl_sota": 71.87 - r['best_accuracy']} |
| except Exception as e: print(f" ogbn-arxiv FAILED: {e}"); import traceback; traceback.print_exc() |
| print(f"\n{'='*100}\nNODE-JEPA V4.2 — COMPREHENSIVE BASELINE RESULTS\n{'='*100}") |
| print(f"\n{'Dataset':<15} {'Ours':>10} {'SOTA':>8} {'Gap':>7} {'Isotropy':>10} {'Align':>8} {'Uniform':>10} {'Rank95':>8} {'ms/1K':>8}") |
| print("-" * 95) |
| for ds in ["Cora", "CiteSeer", "PubMed", "ogbn-arxiv"]: |
| if ds not in all_results or "error" in all_results.get(ds, {}): |
| print(f" {ds:<13} {'SKIP':>10}"); continue |
| r = all_results[ds]; g = r['seeds'][0]['geometry']; t = r['seeds'][0]['timing'] |
| print(f" {ds:<13} {r['mean_accuracy']:>9.2f}% {r['ssl_sota']:>7.1f}% {r.get('gap_to_sota', r['ssl_sota']-r['mean_accuracy']):>+6.1f} {g['isotropy_ratio']:>9.4f} {g['alignment']:>7.4f} {g['uniformity']:>9.4f} {g['effective_rank_95']:>8} {t['ms_per_1k_nodes']:>7.2f}") |
| print(f"\nRef: GraphMAE (KDD'22) | Note: ogbn-arxiv SSL ceiling ~72%, ogbn-products has NO published SSL baseline") |
| out = "/tmp/v4.2_baseline_results.json" |
| with open(out, "w") as f: json.dump(all_results, f, indent=2, default=str) |
    if HAS_HF_HUB:
        try:
            api = HfApi()
            api.upload_file(path_or_fileobj=out, path_in_repo="v4.2_baseline_results.json", repo_id="EPSAGR/Node-JEPA", repo_type="model")
            print("Pushed to https://huggingface.co/EPSAGR/Node-JEPA")
        except Exception as e:
            print(f"HF Hub upload failed: {e}")
| if tracker: trackio.finish() |
| print("\nDone!") |
|
|
| if __name__ == "__main__": |
| main() |
|
|