| """ |
Node-JEPA V4.2 — Round 2: ogbn-arxiv + CiteSeer Fix
| ===================================================== |
Round 1 results:
Cora:     82.8% (rank=65) ✓ healthy
PubMed:   80.1% (rank=62) ✓ healthy
CiteSeer: 53.9% (rank=1)  ✗ collapsed
ogbn-arxiv: SKIPPED (OGB install issue)
| |
| This script: |
| 1. Runs ogbn-arxiv with proper loading (to_undirected, StandardScaler on features) |
| 2. Tests 4 CiteSeer fix hypotheses to break the collapse |
| |
| Key fixes from research: |
| - ogbn-arxiv: use to_undirected() (it's directed!), StandardScaler on features (GraphMAE does this) |
| - ogbn-arxiv: BGRL uses NO feature dropout on arxiv (p_f=0.0), edge dropout only |
| - CiteSeer collapse likely caused by: GAT encoder, high feat dim (3703), low LR, reconstruction on sparse features |
| """ |
|
|
| import math, copy, os, sys, json, time, gc, functools |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| |
# PyTorch 2.6 flipped torch.load's default to weights_only=True, which breaks
# loading of pickled dataset objects (e.g. PyG/OGB caches). Wrap the original
# loader so weights_only defaults to False unless the caller overrides it.
_original_torch_load = torch.load


@functools.wraps(_original_torch_load)
def _patched_torch_load(*args, **kwargs):
    kwargs.setdefault('weights_only', False)
    return _original_torch_load(*args, **kwargs)


torch.load = _patched_torch_load
print("Patched torch.load for PyTorch 2.6+ compatibility")
|
|
| from torch_geometric.nn import GATConv, GCNConv |
| from torch_geometric.datasets import Planetoid |
| from torch_geometric.transforms import NormalizeFeatures |
| from torch_geometric.utils import to_undirected |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.preprocessing import StandardScaler |
|
|
| try: |
| from ogb.nodeproppred import PygNodePropPredDataset |
| HAS_OGB = True |
| print("OGB loaded successfully") |
| except Exception as e: |
| HAS_OGB = False |
| print(f"OGB import failed: {e}") |
|
|
| try: import trackio; HAS_TRACKIO = True |
| except: HAS_TRACKIO = False |
|
|
| try: from huggingface_hub import HfApi; HAS_HF_HUB = True |
| except: HAS_HF_HUB = False |
|
|
|
|
| |
| |
| |
|
|
def decorrelation_loss(z):
    """Penalize correlation between embedding dimensions.

    Standardizes each column of ``z`` (N x D), forms the empirical
    correlation matrix, and returns the mean squared deviation from the
    identity, scaled by 1/D^2.
    """
    n_samples, n_dims = z.shape
    standardized = (z - z.mean(dim=0)) / (z.std(dim=0) + 1e-8)
    corr = standardized.t().mm(standardized) / n_samples
    identity = torch.eye(n_dims, device=z.device)
    return (corr - identity).square().sum() / (n_dims * n_dims)
|
|
def sce_loss(pred, target, alpha=1):
    """Scaled cosine error (GraphMAE-style): mean of (1 - cos_sim) ** alpha."""
    pred_unit = F.normalize(pred, p=2, dim=-1)
    target_unit = F.normalize(target, p=2, dim=-1)
    cos_sim = (pred_unit * target_unit).sum(dim=-1)
    return (1 - cos_sim).pow(alpha).mean()
|
|
class GraphAugmentor:
    """Random graph-view generator: feature masking plus edge dropping.

    Each call independently zeroes feature entries with probability
    ``feat_mask_p`` and drops edges with probability ``edge_drop_p``.
    With both probabilities at 0 the inputs pass through untouched.
    """

    def __init__(self, feat_mask_p=0.0, edge_drop_p=0.0):
        self.feat_mask_p = feat_mask_p
        self.edge_drop_p = edge_drop_p

    def __call__(self, x, edge_index):
        if self.feat_mask_p > 0:
            keep = torch.bernoulli(torch.full_like(x, 1 - self.feat_mask_p))
            x = x * keep
        if self.edge_drop_p > 0:
            keep_prob = torch.full((edge_index.size(1),), 1 - self.edge_drop_p, device=x.device)
            edge_keep = torch.bernoulli(keep_prob).bool()
            edge_index = edge_index[:, edge_keep]
        return x, edge_index
|
|
class GATEncoder(nn.Module):
    """GAT encoder with a learnable global "register" node.

    One extra node (``reg_token``) is appended to the graph and connected
    bidirectionally to every real node, giving every layer a one-hop global
    information pathway. Residual connections plus BatchNorm wrap each
    GATConv block; the register node's row is stripped before returning.
    """
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, num_heads=4, dropout=0.0):
        super().__init__()
        self.num_layers = num_layers
        # Learnable feature vector for the appended register node.
        self.reg_token = nn.Parameter(torch.zeros(1, in_dim)); nn.init.trunc_normal_(self.reg_token, std=0.02)
        self.input_proj = nn.Linear(in_dim, hidden_dim)
        # Low BN momentum (0.01) -> slowly-moving running statistics.
        self.input_bn = nn.BatchNorm1d(hidden_dim, momentum=0.01)
        self.gat_layers = nn.ModuleList(); self.bns = nn.ModuleList(); self.acts = nn.ModuleList()
        for _ in range(num_layers):
            # hidden_dim // num_heads channels per head with concat=True keeps
            # the layer output hidden_dim wide, so the residual add lines up.
            self.gat_layers.append(GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads, concat=True, dropout=dropout, add_self_loops=True))
            self.bns.append(nn.BatchNorm1d(hidden_dim, momentum=0.01)); self.acts.append(nn.PReLU())
        self.output_proj = nn.Linear(hidden_dim, out_dim)
    def forward(self, x, edge_index):
        N = x.size(0); device = x.device
        # Append the register node as node index N.
        x_aug = torch.cat([x, self.reg_token.expand(1, -1).to(device)], dim=0)
        # 2N extra edges: register -> every real node and every real node -> register.
        re = torch.cat([torch.stack([torch.full((N,), N, dtype=torch.long, device=device), torch.arange(N, device=device)]),
                        torch.stack([torch.arange(N, device=device), torch.full((N,), N, dtype=torch.long, device=device)])], dim=1)
        ei = torch.cat([edge_index, re], dim=1)
        # Input stem uses a fixed-slope (0.25) PReLU via the functional form.
        h = F.prelu(self.input_bn(self.input_proj(x_aug)), torch.tensor(0.25, device=device))
        for i in range(self.num_layers): h = self.acts[i](self.bns[i](self.gat_layers[i](h, ei))) + h  # residual GAT block
        return self.output_proj(h)[:N]  # drop the register node's row
|
|
class GCNEncoder(nn.Module):
    """Residual GCN encoder.

    Input stem: Linear -> BatchNorm -> fixed-slope PReLU. Then ``num_layers``
    blocks of GCNConv -> BatchNorm -> PReLU with a skip connection, followed
    by a linear output head. ``dropout`` is accepted for signature parity
    with GATEncoder but is not used.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, dropout=0.0):
        super().__init__()
        self.num_layers = num_layers
        self.input_proj = nn.Linear(in_dim, hidden_dim)
        self.input_bn = nn.BatchNorm1d(hidden_dim, momentum=0.01)
        self.gcn_layers = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.acts = nn.ModuleList()
        for _ in range(num_layers):
            self.gcn_layers.append(GCNConv(hidden_dim, hidden_dim))
            self.bns.append(nn.BatchNorm1d(hidden_dim, momentum=0.01))
            self.acts.append(nn.PReLU())
        self.output_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        h = F.prelu(self.input_bn(self.input_proj(x)), torch.tensor(0.25, device=x.device))
        for conv, bn, act in zip(self.gcn_layers, self.bns, self.acts):
            h = act(bn(conv(h, edge_index))) + h  # residual block
        return self.output_proj(h)
|
|
class MLPPredictor(nn.Module):
    """Two-layer MLP head (Linear -> BN -> PReLU -> Linear) mapping dim -> dim."""

    def __init__(self, dim, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.BatchNorm1d(hidden, momentum=0.01),
            nn.PReLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)
|
|
class FeatureDecoder(nn.Module):
    """MLP decoder mapping embedding dim ``ed`` through hidden dim ``hd``
    back to the original feature dim ``od`` (Linear -> BN -> PReLU -> Linear)."""

    def __init__(self, ed, hd, od):
        super().__init__()
        layers = [
            nn.Linear(ed, hd),
            nn.BatchNorm1d(hd, momentum=0.01),
            nn.PReLU(),
            nn.Linear(hd, od),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
|
class NodeMasker(nn.Module):
    """GraphMAE-style node-feature masking.

    Selects ``mask_rate`` of the nodes. Most selected nodes are overwritten
    with a learnable mask token; a ``replace_rate`` fraction of them is
    instead replaced with features copied from uniformly random nodes.
    Returns the corrupted feature matrix and the boolean selection mask.
    """

    def __init__(self, feat_dim, mask_rate=0.5, replace_rate=0.05):
        super().__init__()
        self.mask_rate = mask_rate
        self.replace_rate = replace_rate
        self.mask_token = nn.Parameter(torch.zeros(1, feat_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)

    def forward(self, x):
        num_nodes = x.size(0)
        num_masked = max(1, int(num_nodes * self.mask_rate))
        perm = torch.randperm(num_nodes, device=x.device)
        mask = torch.zeros(num_nodes, dtype=torch.bool, device=x.device)
        mask[perm[:num_masked]] = True
        num_replaced = max(0, int(num_masked * self.replace_rate))
        num_tokened = num_masked - num_replaced
        corrupted = x.clone()
        corrupted[perm[:num_tokened]] = self.mask_token.expand(num_tokened, -1)
        if num_replaced > 0:
            donors = torch.randint(0, num_nodes, (num_replaced,), device=x.device)
            corrupted[perm[num_tokened:num_masked]] = x[donors]
        return corrupted, mask
|
|
class NodeJEPA(nn.Module):
    """Joint-embedding predictive architecture for node representations.

    Two augmented views of the graph are produced per step. The online
    encoder sees a (optionally masked) view, and an MLP predictor must match
    the EMA target encoder's stop-gradient embeddings on the masked nodes.
    Optional auxiliary losses: a decorrelation penalty on the online
    embeddings (``lambda_dec``) and a GraphMAE-style SCE feature
    reconstruction of masked nodes (``mu_rec``).
    """
    def __init__(self, in_dim, hidden_dim=512, out_dim=256, num_layers=2, num_heads=4,
                 predictor_hidden=512, mask_rate=0.5, replace_rate=0.05,
                 ema_tau_base=0.99, aug1_feat_p=0.0, aug1_edge_p=0.0,
                 aug2_feat_p=0.0, aug2_edge_p=0.0, lambda_dec=0.0, mu_rec=0.0,
                 sce_alpha=1, encoder_type='gat', dropout=0.0):
        # NOTE: ema_tau_base is accepted for config symmetry but unused here;
        # the EMA schedule is driven externally via update_target(tau).
        super().__init__()
        self.lambda_dec = lambda_dec; self.mu_rec = mu_rec; self.sce_alpha = sce_alpha
        if encoder_type == 'gat':
            self.online_encoder = GATEncoder(in_dim, hidden_dim, out_dim, num_layers, num_heads, dropout)
        else:
            self.online_encoder = GCNEncoder(in_dim, hidden_dim, out_dim, num_layers, dropout)
        # Target branch: frozen copy of the online encoder, updated only by EMA.
        self.target_encoder = copy.deepcopy(self.online_encoder)
        for p in self.target_encoder.parameters(): p.requires_grad = False
        self.predictor = MLPPredictor(out_dim, predictor_hidden)
        self.masker = NodeMasker(in_dim, mask_rate, replace_rate)
        self.aug1 = GraphAugmentor(aug1_feat_p, aug1_edge_p)
        self.aug2 = GraphAugmentor(aug2_feat_p, aug2_edge_p)
        # Feature decoder only exists when the reconstruction loss is active.
        self.feat_decoder = FeatureDecoder(out_dim, hidden_dim, in_dim) if mu_rec > 0 else None


    @torch.no_grad()
    def update_target(self, tau):
        """EMA update: target <- tau * target + (1 - tau) * online."""
        for po, pt in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            pt.data.mul_(tau).add_(po.data, alpha=1 - tau)


    def _jepa_loss(self, x, ei, aug_o, aug_t, do_mask=True):
        """One JEPA direction: online view from ``aug_o`` (masked when
        ``do_mask``), target view from ``aug_t``; cosine-style loss on the
        masked nodes (all nodes when ``do_mask`` is False)."""
        xo, eio = aug_o(x, ei); xt, eit = aug_t(x, ei)
        if do_mask: xo, mask = self.masker(xo)
        else: mask = torch.ones(x.size(0), dtype=torch.bool, device=x.device)
        ho = self.online_encoder(xo, eio)
        with torch.no_grad(): ht = self.target_encoder(xt, eit)  # stop-gradient target
        pred = F.normalize(self.predictor(ho[mask]), dim=-1)
        tgt = F.normalize(ht[mask], dim=-1)
        # 2 - 2*cos is the squared L2 distance between unit vectors.
        return 2 - 2 * (pred * tgt).sum(-1).mean(), ho, mask


    def forward(self, data):
        """Return (total_loss, jepa_loss, decorrelation_loss, reconstruction_loss)."""
        x, ei = data.x, data.edge_index
        # Symmetric JEPA: masked direction (aug1 online) plus unmasked reverse.
        l1, ho1, m1 = self._jepa_loss(x, ei, self.aug1, self.aug2, True)
        l2, ho2, m2 = self._jepa_loss(x, ei, self.aug2, self.aug1, False)
        lj = (l1 + l2) / 2
        ld = self.lambda_dec * (decorrelation_loss(ho1) + decorrelation_loss(ho2)) / 2 if self.lambda_dec > 0 else torch.tensor(0.0, device=x.device)
        lr = torch.tensor(0.0, device=x.device)
        if self.mu_rec > 0 and self.feat_decoder is not None:
            # Reconstruct original features of masked nodes from online embeddings.
            lr = self.mu_rec * sce_loss(self.feat_decoder(ho1[m1]), x[m1], self.sce_alpha)
        return lj + ld + lr, lj, ld, lr


    @torch.no_grad()
    def encode(self, data):
        """Embed all nodes with the online encoder (eval mode, no grads)."""
        self.online_encoder.eval()
        return self.online_encoder(data.x, data.edge_index)
|
|
def cosine_scheduler(base, final, steps):
    """Cosine ramp from ``base`` (step 0) to ``final`` (last step), one value per step."""
    denom = max(steps - 1, 1)
    values = []
    for step in range(steps):
        cos_term = (math.cos(math.pi * step / denom) + 1) / 2
        values.append(final - (final - base) * cos_term)
    return values
|
|
|
|
| |
| |
| |
|
|
def compute_geometry(embs, labels=None, n=2000):
    """Embedding-geometry diagnostics used to detect representation collapse.

    On a random subsample of at most ``n`` rows, computes:
      - rank95 / rank99: number of singular values covering 95% / 99% of the
        variance (rank ~1 means total collapse),
      - isotropy: smallest / largest singular value ratio,
      - alignment: mean intra-class cosine similarity (when ``labels`` given),
      - uniformity: log mean exp(-2 * squared distance) on the unit sphere
        (Wang & Isola), over distinct pairs,
      - avg_cos: mean pairwise cosine similarity,
      - std_dim: mean per-dimension std over the FULL embedding matrix.
    """
    N, D = embs.shape
    # Subsample so the O(n^2) pairwise statistics stay cheap.
    idx = np.random.choice(N, min(N, n), replace=False) if N > n else np.arange(N)
    es = embs[idx]
    ls = labels[idx] if labels is not None else None
    ec = es - es.mean(0)
    try:
        _, S, _ = np.linalg.svd(ec, full_matrices=False)
    except np.linalg.LinAlgError:
        # SVD failed to converge; fall back to a flat spectrum. (Was a bare
        # `except:` which also swallowed KeyboardInterrupt/SystemExit.)
        S = np.ones(min(es.shape))
    tv = (S ** 2).sum()
    if tv < 1e-12:
        rk95, rk99 = 1, 1
    else:
        ev = np.cumsum(S ** 2) / tv
        rk95 = int(np.searchsorted(ev, 0.95) + 1)
        rk99 = int(np.searchsorted(ev, 0.99) + 1)
    iso = float(S[-1] / S[0]) if len(S) > 1 and S[0] > 1e-12 else 0.0
    # Row-normalize for all cosine-based statistics.
    norms = np.linalg.norm(es, axis=1, keepdims=True) + 1e-8
    en = es / norms
    cm = en @ en.T
    np.fill_diagonal(cm, 0)
    nn_ = es.shape[0]
    avg_cos = cm.sum() / (nn_ * (nn_ - 1))
    align = 0.0
    if ls is not None:
        sims = []
        for c in np.unique(ls):
            mc = ls == c
            if mc.sum() > 1:
                ec_ = en[mc]
                cc = ec_ @ ec_.T
                np.fill_diagonal(cc, 0)
                # Upper triangle only: each intra-class pair counted once.
                sims.extend(cc[np.triu_indices(ec_.shape[0], k=1)].tolist())
        if sims:
            align = float(np.mean(sims))
    diff = en[:, None, :] - en[None, :, :]
    sq = (diff ** 2).sum(-1)
    um = ~np.eye(nn_, dtype=bool)
    unif = float(np.log(np.exp(-2 * sq[um]).mean() + 1e-12))
    return {"rank95": rk95, "rank99": rk99, "isotropy": iso, "alignment": align,
            "uniformity": unif, "avg_cos": float(avg_cos), "std_dim": float(embs.std(0).mean())}
|
|
def linear_eval(embs, labels, tr_mask, val_mask, te_mask, seeds=20):
    """Linear-probe evaluation under three feature normalizations.

    For each of raw / row-wise L2 / StandardScaler features, fits logistic
    regression over a small C grid, selects C on the validation split, and
    reports mean/std test accuracy (%) across ``seeds`` refits.

    FIX: the "l2" normalizer used np.linalg.norm(e, 1, True), which binds
    ord=1 and axis=True (==1) with keepdims=False -- an L1 norm with a
    shape that cannot broadcast against (N, D). Replaced with the intended
    row-wise L2 norm with keepdims=True.
    """
    normalizers = [
        ("raw", lambda e: e),
        # Row-wise L2 normalization (unit-length embeddings).
        ("l2", lambda e: e / (np.linalg.norm(e, axis=1, keepdims=True) + 1e-8)),
        ("std", lambda e: StandardScaler().fit_transform(e)),
    ]
    results = {}
    for nm, fn in normalizers:
        e = fn(embs.copy())
        Xtr, ytr = e[tr_mask], labels[tr_mask]
        Xv, yv = e[val_mask], labels[val_mask]
        Xte, yte = e[te_mask], labels[te_mask]
        accs = []
        for s in range(seeds):
            bv, bt = -1, -1
            # Model-select C on the validation split; report the matching test score.
            for C in [0.01, 0.1, 1.0, 10.0, 100.0]:
                clf = LogisticRegression(max_iter=2000, solver='lbfgs', random_state=s, C=C)
                clf.fit(Xtr, ytr)
                v = clf.score(Xv, yv)
                if v > bv:
                    bv = v
                    bt = clf.score(Xte, yte)
            accs.append(bt * 100)
        results[nm] = {"mean": float(np.mean(accs)), "std": float(np.std(accs))}
    return results
|
|
def get_lr(ep, warmup, total, base_lr, min_lr=1e-6):
    """Linear warmup to ``base_lr`` over ``warmup`` epochs, then cosine decay to ``min_lr``."""
    if ep < warmup:
        return base_lr * (ep + 1) / warmup
    progress = (ep - warmup) / max(total - warmup, 1)
    cosine = (1 + math.cos(math.pi * progress)) / 2
    return min_lr + (base_lr - min_lr) * cosine
|
|
|
|
| |
| |
| |
|
|
def run(name, data, device, cfg, seed=42):
    """Pretrain a NodeJEPA model on ``data`` (full batch) and evaluate it.

    Trains for cfg['epochs'] with warmup+cosine LR and a cosine EMA-tau
    schedule, periodically runs a quick linear probe (train -> test, fixed
    C=1.0) to track progress, then runs the full multi-normalization probe
    plus geometry diagnostics on the final embeddings.

    Returns a dict with probe accuracies, best accuracy / normalization,
    geometry stats, and wall-clock time.
    """
    torch.manual_seed(seed); np.random.seed(seed)
    if device.type == 'cuda': torch.cuda.manual_seed(seed)
    epochs, lr, warmup = cfg['epochs'], cfg['lr'], cfg['warmup']
    model = NodeJEPA(
        in_dim=data.num_features, hidden_dim=cfg['hidden_dim'], out_dim=cfg['out_dim'],
        num_layers=cfg['num_layers'], num_heads=cfg.get('num_heads', 4),
        predictor_hidden=cfg['hidden_dim'], mask_rate=cfg['mask_rate'],
        ema_tau_base=cfg['ema_tau_base'],
        aug1_feat_p=cfg.get('aug1_feat_p', 0), aug1_edge_p=cfg.get('aug1_edge_p', 0),
        aug2_feat_p=cfg.get('aug2_feat_p', 0), aug2_edge_p=cfg.get('aug2_edge_p', 0),
        lambda_dec=cfg.get('lambda_dec', 0), mu_rec=cfg.get('mu_rec', 0),
        encoder_type=cfg.get('encoder_type', 'gat'),
    ).to(device)
    # Only online-branch parameters are optimized; target encoder is EMA-only.
    opt = AdamW([p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=cfg.get('wd', 1e-5))
    # EMA tau ramps from ema_tau_base up to 1.0 over training.
    ema = cosine_scheduler(cfg['ema_tau_base'], 1.0, epochs)
    # ei: eval interval (10 evals total); li: log interval (20 log lines total).
    t0 = time.time(); best_acc, best_ep = 0, 0; ei = max(1, epochs // 10); li = max(1, epochs // 20)
    for ep in range(epochs):
        model.train(); clr = get_lr(ep, warmup, epochs, lr)
        for pg in opt.param_groups: pg['lr'] = clr
        loss, lj, ld, lr_ = model(data); opt.zero_grad(); loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0); opt.step(); model.update_target(ema[ep])
        if (ep+1) % li == 0 or ep == 0:
            print(f" [{name}] Ep {ep+1:4d}/{epochs} | L={loss.item():.4f} (j={lj.item():.4f} d={ld.item():.4f} r={lr_.item():.4f}) | LR={clr:.1e} | {time.time()-t0:.0f}s"); sys.stdout.flush()
        if (ep+1) % ei == 0:
            # Quick mid-training probe to monitor representation quality.
            model.eval()
            with torch.no_grad(): e = model.encode(data).cpu().numpy()
            ln = data.y.cpu().numpy();
            if len(ln.shape) > 1: ln = ln.squeeze()
            Xtr = e[data.train_mask.cpu().numpy()]; ytr = ln[data.train_mask.cpu().numpy()]
            Xte = e[data.test_mask.cpu().numpy()]; yte = ln[data.test_mask.cpu().numpy()]
            clf = LogisticRegression(max_iter=2000, solver='lbfgs', C=1.0); clf.fit(Xtr, ytr)
            acc = clf.score(Xte, yte) * 100
            if acc > best_acc: best_acc = acc; best_ep = ep + 1
            print(f" [{name}] EVAL @{ep+1}: {acc:.1f}% | BEST={best_acc:.1f}%@{best_ep}"); sys.stdout.flush()

    # Final evaluation: full probe (3 normalizations) + geometry diagnostics.
    model.eval()
    with torch.no_grad(): embs = model.encode(data).cpu().numpy()
    labels = data.y.cpu().numpy()
    if len(labels.shape) > 1: labels = labels.squeeze()
    # Fall back to the train mask when the dataset carries no validation split.
    tr = data.train_mask.cpu().numpy(); va = data.val_mask.cpu().numpy() if hasattr(data, 'val_mask') and data.val_mask is not None else tr; te = data.test_mask.cpu().numpy()
    probe = linear_eval(embs, labels, tr, va, te, seeds=20)
    geom = compute_geometry(embs, labels)
    best = max(probe[k]['mean'] for k in probe); bn = max(probe, key=lambda k: probe[k]['mean'])
    print(f" [{name}] FINAL: raw={probe['raw']['mean']:.1f} l2={probe['l2']['mean']:.1f} std={probe['std']['mean']:.1f} | BEST={best:.1f}%({bn})")
    print(f" [{name}] GEOM: rank95={geom['rank95']} iso={geom['isotropy']:.4f} align={geom['alignment']:.4f} unif={geom['uniformity']:.4f} cos={geom['avg_cos']:.4f}")
    return {"name": name, "accuracy": probe, "best": best, "best_norm": bn, "geometry": geom, "time": time.time()-t0}
|
|
|
|
| |
| |
| |
|
|
def main():
    """Round-2 driver.

    Part 1: ogbn-arxiv with corrected loading (to_undirected + StandardScaler,
    BGRL-style config). Part 2: four CiteSeer configurations attacking the
    rank-1 collapse. Prints a summary table, dumps results to JSON, and
    optionally logs to trackio / uploads to the HF Hub when available.

    FIX: the `collapsed` indicator string literal was broken across two lines
    (mojibake) and did not parse; restored as a single-line cross/check mark.
    Bare `except:` clauses narrowed to `except Exception:` so Ctrl-C still works.
    """
    torch.manual_seed(42); np.random.seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if device.type == "cuda":
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        # VRAM report is best-effort only.
        try: print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        except Exception: pass
        torch.cuda.manual_seed(42)

    # Experiment tracking is optional; never let it kill the run.
    tracker = None
    if HAS_TRACKIO:
        try: trackio.init(name="node-jepa-round2", project="node-jepa"); tracker = True; print("Trackio initialized")
        except Exception: pass

    all_results = {}

    # ---------------- Part 1: ogbn-arxiv ----------------
    if HAS_OGB:
        print(f"\n{'='*80}")
        print(f"PART 1: ogbn-arxiv (169K nodes, SSL SOTA: 71.87%)")
        print(f"{'='*80}")
        try:
            dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='/tmp/ogb_data')
            data = dataset[0]
            si = dataset.get_idx_split()

            # ogbn-arxiv ships as a DIRECTED citation graph; symmetrize it.
            print(f" Original: {data.num_nodes} nodes, {data.edge_index.size(1)} edges (directed)")
            data.edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)
            print(f" Undirected: {data.edge_index.size(1)} edges")

            # StandardScaler on node features (GraphMAE recipe).
            feat_np = data.x.numpy()
            scaler = StandardScaler()
            data.x = torch.tensor(scaler.fit_transform(feat_np), dtype=torch.float32)
            print(f" Features: {data.x.size(1)} dim, StandardScaled")

            # Convert OGB index splits into Planetoid-style boolean masks.
            N = data.num_nodes
            tm = torch.zeros(N, dtype=torch.bool); tm[si['train']] = True
            vm = torch.zeros(N, dtype=torch.bool); vm[si['valid']] = True
            tsm = torch.zeros(N, dtype=torch.bool); tsm[si['test']] = True
            data.train_mask = tm; data.val_mask = vm; data.test_mask = tsm
            data.y = data.y.squeeze()
            data = data.to(device)

            print(f" Train: {tm.sum()}, Val: {vm.sum()}, Test: {tsm.sum()}, Classes: {data.y.max()+1}")

            # BGRL-style arxiv recipe: no feature dropout, edge dropout only,
            # GCN encoder.
            arxiv_cfg = {
                'epochs': 500, 'lr': 1e-3, 'warmup': 100,
                'hidden_dim': 256, 'out_dim': 256,
                'num_layers': 2, 'num_heads': 4,
                'mask_rate': 0.5, 'ema_tau_base': 0.996, 'wd': 1e-5,
                'aug1_feat_p': 0.0, 'aug1_edge_p': 0.4,
                'aug2_feat_p': 0.0, 'aug2_edge_p': 0.4,
                'encoder_type': 'gcn',
                'lambda_dec': 0.0, 'mu_rec': 0.0,
            }
            torch.cuda.empty_cache()
            print(f" VRAM before training: {torch.cuda.memory_allocated()/1e9:.2f} GB used")
            r = run("ogbn-arxiv", data, device, arxiv_cfg, seed=42)
            all_results["ogbn-arxiv"] = r

            if tracker:
                trackio.log({"arxiv/best": r['best'], "arxiv/rank95": r['geometry']['rank95']})

            # Free the 169K-node graph before the CiteSeer runs.
            del data, dataset; gc.collect()
            if device.type == 'cuda': torch.cuda.empty_cache()
        except Exception as e:
            print(f" ogbn-arxiv FAILED: {e}")
            import traceback; traceback.print_exc()
            all_results["ogbn-arxiv"] = {"error": str(e)}
    else:
        print("\nSKIPPING ogbn-arxiv (OGB not installed)")

    # ---------------- Part 2: CiteSeer collapse fixes ----------------
    print(f"\n{'='*80}")
    print(f"PART 2: CiteSeer Collapse Fix (4 hypotheses)")
    print(f"{'='*80}")

    dataset = Planetoid(root="/tmp/pyg_data", name="CiteSeer", transform=NormalizeFeatures())
    data = dataset[0].to(device)
    print(f" {data.num_nodes} nodes, {data.num_edges} edges, {data.num_features} feat, {dataset.num_classes} classes")

    # Round-1 CiteSeer config (the one that collapsed) as the baseline.
    base = {
        'epochs': 600, 'lr': 3e-4, 'warmup': 200,
        'hidden_dim': 512, 'out_dim': 256, 'num_layers': 2, 'num_heads': 4,
        'mask_rate': 0.5, 'ema_tau_base': 0.99, 'wd': 1e-5,
        'aug1_feat_p': 0.5, 'aug1_edge_p': 0.4, 'aug2_feat_p': 0.3, 'aug2_edge_p': 0.4,
        'encoder_type': 'gat', 'lambda_dec': 0.0, 'mu_rec': 2.0,
    }

    citeseer_results = []

    print(f"\n --- Fix 1: GCN encoder (GAT may amplify collapse) ---")
    cfg1 = {**base, 'encoder_type': 'gcn'}
    r1 = run("cs_gcn", data, device, cfg1); citeseer_results.append(r1)

    print(f"\n --- Fix 2: No reconstruction (mu_rec=0, V3 style) ---")
    cfg2 = {**base, 'mu_rec': 0.0}
    r2 = run("cs_no_rec", data, device, cfg2); citeseer_results.append(r2)

    print(f"\n --- Fix 3: Higher LR (1e-3 instead of 3e-4) ---")
    cfg3 = {**base, 'lr': 1e-3}
    r3 = run("cs_lr1e3", data, device, cfg3); citeseer_results.append(r3)

    print(f"\n --- Fix 4: GCN + no rec + lr=1e-3 (all fixes) ---")
    cfg4 = {**base, 'encoder_type': 'gcn', 'mu_rec': 0.0, 'lr': 1e-3}
    r4 = run("cs_all_fix", data, device, cfg4); citeseer_results.append(r4)

    all_results["citeseer_fixes"] = citeseer_results

    if tracker:
        for r in citeseer_results:
            trackio.log({f"cs/{r['name']}/best": r['best'], f"cs/{r['name']}/rank95": r['geometry']['rank95']})

    # ---------------- Summary table ----------------
    print(f"\n{'='*100}")
    print(f"NODE-JEPA V4.2 β ROUND 2 RESULTS")
    print(f"{'='*100}")

    print(f"\n{'Config':<20} {'Best%':>8} {'Norm':>6} {'Rank95':>7} {'Iso':>8} {'Align':>8} {'Unif':>8} {'Cos':>8}")
    print("-" * 85)

    if "ogbn-arxiv" in all_results and "error" not in all_results["ogbn-arxiv"]:
        r = all_results["ogbn-arxiv"]; g = r['geometry']
        print(f" {'ogbn-arxiv':<18} {r['best']:>7.2f}% {r['best_norm']:>6} {g['rank95']:>7} {g['isotropy']:>7.4f} {g['alignment']:>7.4f} {g['uniformity']:>7.4f} {g['avg_cos']:>7.4f}")
    else:
        print(f" {'ogbn-arxiv':<18} {'FAILED':>8}")

    for r in citeseer_results:
        g = r['geometry']
        # rank95 <= 1 means the embedding space degenerated to a single direction.
        collapsed = "✗" if g['rank95'] <= 1 else "✓"
        print(f" {r['name']:<18} {r['best']:>7.2f}% {r['best_norm']:>6} {collapsed}{g['rank95']:>6} {g['isotropy']:>7.4f} {g['alignment']:>7.4f} {g['uniformity']:>7.4f} {g['avg_cos']:>7.4f}")

    print(f"\n Round 1 reference:")
    print(f" {'Cora (R1)':<18} {'82.77':>7}% {'raw':>6} {'65':>7} {'0.0977':>8} {'0.2172':>8} {'-3.4168':>8}")
    print(f" {'PubMed (R1)':<18} {'80.10':>7}% {'std':>6} {'62':>7} {'0.0910':>8} {'0.5689':>8} {'-1.7647':>8}")
    print(f" {'CiteSeer (R1)':<18} {'53.93':>7}% {'?':>6} {'1':>7} {'0.0019':>8} {'1.0000':>8} {'0.0000':>8} β collapsed")
    print(f"\n SSL SOTA: Cora=84.2%, CiteSeer=73.4%, PubMed=81.1%, ogbn-arxiv=71.87%")

    # Persist results; default=str keeps numpy scalars JSON-serializable.
    out = "/tmp/v4.2_round2_results.json"
    with open(out, "w") as f: json.dump(all_results, f, indent=2, default=str)
    if HAS_HF_HUB:
        # Upload is best-effort (no token / offline should not fail the run).
        try:
            api = HfApi()
            api.upload_file(path_or_fileobj=out, path_in_repo="v4.2_round2_results.json", repo_id="EPSAGR/Node-JEPA", repo_type="model")
            print(f"\nPushed to https://huggingface.co/EPSAGR/Node-JEPA")
        except Exception: pass
    if tracker: trackio.finish()
    print("\nDone!")
|
|
# Script entry point; allows importing this file without triggering the experiments.
if __name__ == "__main__":
    main()
|
|