"""
eval_checkpoint.py — run test-set evaluation on an existing checkpoint

Usage:
    python scripts/eval_checkpoint.py \
        --checkpoint artifacts/graph_hpo/graph_hpo_best.pth \
        --out artifacts/graph_hpo/test_eval_results.json

Relative ``--checkpoint`` / ``--out`` paths are resolved against ``BASE``
(the directory two levels above this file); absolute paths are used as given.
"""
import argparse
import json
import sys
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

BASE = Path(__file__).parent.parent
# Make the sibling training module importable when run as a script.
sys.path.insert(0, str(Path(__file__).parent))

from train_v3_fixed import (
    ImprovedResidualMLP,
    EmbeddingDataset,
    eval_micro_fmax,
    eval_cafa_fmax,
    load_go_parents,
    parse_labels,
    DATA_BASE,
    DATA_SUPP,
    SUPP_COLS,
    ESM_DIM,
    OUT_DIM,
    SPLITS_NPZ,
    MLB_PATH,
    OBO_PATH,
)


def _resolve(path_arg):
    """Return *path_arg* as a Path, anchoring relative paths at BASE."""
    candidate = Path(path_arg)
    return candidate if candidate.is_absolute() else BASE / candidate


def _select_device():
    """Pick the compute device, preferring MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _multi_hot(label_lists, n_cols):
    """Build a uint8 multi-hot matrix; indices outside [0, n_cols) are skipped."""
    matrix = np.zeros((len(label_lists), n_cols), dtype=np.uint8)
    for row, labels in enumerate(label_lists):
        for col in labels:
            if 0 <= col < n_cols:
                matrix[row, col] = 1
    return matrix


def _to_native(value):
    """Convert NumPy / torch scalars and arrays to plain Python for json."""
    if isinstance(value, torch.Tensor):
        return value.item() if value.numel() == 1 else value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    return value


def main():
    """Evaluate a saved checkpoint on the test split and write a JSON summary.

    Reads ``--checkpoint``, ``--out`` and ``--batch`` from the command line,
    rebuilds the input features using the normalisation statistics stored in
    the checkpoint, runs the micro-Fmax and CAFA-Fmax evaluations, prints the
    results and writes them to the ``--out`` path.

    Raises:
        ValueError: if the base and supplementary tables have different row
            counts, or if the checkpoint expects more input features than the
            assembled feature matrix provides.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", default="artifacts/graph_hpo/graph_hpo_best.pth")
    ap.add_argument("--out", default="artifacts/graph_hpo/test_eval_results.json")
    ap.add_argument("--batch", type=int, default=2048)
    args = ap.parse_args()

    ckpt_path = _resolve(args.checkpoint)
    out_path = _resolve(args.out)
    # Create the output directory up front so an unwritable location fails
    # before the (slow) evaluation rather than after it.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    device = _select_device()
    if device.type == "mps":
        torch.mps.set_per_process_memory_fraction(0.95)
    print(f"Device: {device}")

    # Load checkpoint
    print(f"Loading checkpoint: {ckpt_path}")
    # NOTE(review): torch.load unpickles the file — only use trusted
    # checkpoints. The default of `weights_only` differs between PyTorch
    # releases; confirm the non-tensor fields below load on the installed one.
    raw = torch.load(ckpt_path, map_location="cpu")
    sd = raw["model"]
    in_dim = raw["in_dim"]
    hidden = raw["hidden"]
    n_blks = raw["n_blocks"]
    supp_mu = np.array(raw["supp_mu"], dtype=np.float32)
    supp_sd = np.array(raw["supp_sd"], dtype=np.float32)
    feat_label = raw.get("feature_label", "unknown")
    val_fmax = raw.get("val_fmax", float("nan"))
    print(f" feature_label={feat_label} in_dim={in_dim} hidden={hidden} n_blocks={n_blks} val_fmax={val_fmax:.4f}")

    # Load data
    print("Loading insect dataset...")
    df_base = pd.read_parquet(DATA_BASE)
    df_supp = pd.read_parquet(DATA_SUPP)
    # NOTE(review): joblib.load also unpickles; MLB_PATH must be trusted.
    mlb = joblib.load(MLB_PATH)
    if len(df_base) != len(df_supp):
        raise ValueError(
            f"Row count mismatch: base table has {len(df_base)} rows, "
            f"supplementary table has {len(df_supp)}"
        )

    emb_cols = [f"Dim_{i}" for i in range(ESM_DIM)]
    X_base = df_base[emb_cols].to_numpy(np.float32)
    S_raw = df_supp[SUPP_COLS].to_numpy(np.float32)
    m_flag = df_supp["f_af_present"].to_numpy(np.float32).reshape(-1, 1)

    # Normalise using stored checkpoint stats
    S_z = (S_raw - supp_mu) / (supp_sd + 1e-12)
    X_full = np.concatenate([X_base, S_z, m_flag], axis=1).astype(np.float32)
    if in_dim > X_full.shape[1]:
        # Slicing would silently return fewer columns than the model expects
        # and only fail later inside the first linear layer.
        raise ValueError(
            f"Checkpoint expects in_dim={in_dim} but only "
            f"{X_full.shape[1]} feature columns are available"
        )
    # The checkpoint may have been trained on a prefix of the full feature set.
    X_all = X_full[:, :in_dim]
    print(f" X_all shape: {X_all.shape}")

    # Labels
    label_lists = [parse_labels(x) for x in df_base["Label_Indices"]]
    Y_all = _multi_hot(label_lists, OUT_DIM)

    # Splits
    splits = np.load(SPLITS_NPZ)
    test_idx = splits["test_idx"]
    print(f" Test set: {len(test_idx):,} proteins")

    # Model
    model = ImprovedResidualMLP(in_dim=in_dim, out_dim=OUT_DIM, hidden=hidden, n_blocks=n_blks)
    model.load_state_dict(sd)
    model = model.to(device).eval()

    # Loaders
    ds_te = EmbeddingDataset(X_all, Y_all, test_idx)
    ld_te = DataLoader(ds_te, batch_size=args.batch, shuffle=False,
                       num_workers=0, pin_memory=False)

    # Evaluate
    print("Running micro-Fmax eval...")
    test_micro = eval_micro_fmax(model, ld_te, device)
    print(f" Test micro-Fmax={test_micro['micro_fmax']:.4f} t*={test_micro['t_star']:.2f} "
          f"P={test_micro['precision']:.4f} R={test_micro['recall']:.4f}")

    print("Running CAFA-Fmax eval...")
    go_parents = load_go_parents(OBO_PATH)
    test_cafa = eval_cafa_fmax(model, ld_te, device, mlb.classes_, go_parents)
    print(f" Test CAFA-Fmax={test_cafa.get('cafa_fmax', 'N/A')}")

    result = {
        "checkpoint": str(ckpt_path),
        "feature_label": feat_label,
        "in_dim": in_dim,
        "val_fmax": val_fmax,
        "test_micro_fmax": test_micro["micro_fmax"],
        "test_t_star": test_micro["t_star"],
        "test_precision": test_micro["precision"],
        "test_recall": test_micro["recall"],
        "test_cafa_fmax": test_cafa.get("cafa_fmax"),
        "test_cafa_t_star": test_cafa.get("t_star"),
    }
    # Metrics and checkpoint fields may be NumPy / torch scalars, which the
    # json module rejects; convert them to native Python values first.
    result = {key: _to_native(val) for key, val in result.items()}

    # Serialise before opening the file so a serialisation error cannot leave
    # a truncated results file behind.
    payload = json.dumps(result, indent=2)
    with open(out_path, "w") as f:
        f.write(payload)
    print(f"\nSaved to {out_path}")


if __name__ == "__main__":
    main()