| """ |
| eval_checkpoint.py — run test-set evaluation on an existing checkpoint |
| Usage: |
| python scripts/eval_checkpoint.py \ |
| --checkpoint artifacts/graph_hpo/graph_hpo_best.pth \ |
| --out artifacts/graph_hpo/test_eval_results.json |
| """ |
| import argparse, json, sys |
| from pathlib import Path |
|
|
| import joblib |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
|
|
| BASE = Path(__file__).parent.parent |
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from train_v3_fixed import ( |
| ImprovedResidualMLP, EmbeddingDataset, |
| eval_micro_fmax, eval_cafa_fmax, |
| load_go_parents, parse_labels, |
| DATA_BASE, DATA_SUPP, SUPP_COLS, ESM_DIM, OUT_DIM, |
| SPLITS_NPZ, MLB_PATH, OBO_PATH, |
| ) |
| from torch.utils.data import DataLoader |
|
|
|
|
def _resolve(p: str) -> Path:
    """Resolve *p* against the repo root (BASE) unless it is already absolute."""
    path = Path(p)
    return path if path.is_absolute() else BASE / path


def _jsonable(x):
    """Coerce NumPy scalar types to plain Python numbers for json.dump.

    The eval helpers and checkpoint metadata commonly hold np.float32 /
    np.float64 values, which the stdlib json encoder rejects with TypeError.
    None (e.g. a missing metric from dict.get) passes through unchanged.
    """
    if isinstance(x, np.floating):
        return float(x)
    if isinstance(x, np.integer):
        return int(x)
    return x


def main():
    """Evaluate a saved checkpoint on the held-out test split.

    Loads model weights and feature-normalization statistics from
    --checkpoint, rebuilds the input feature matrix the same way training
    did (ESM embeddings + z-scored supplementary features + presence flag,
    truncated to the checkpoint's in_dim), runs micro-Fmax and CAFA-Fmax
    evaluation on the test split, and writes the metrics to --out as JSON.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", default="artifacts/graph_hpo/graph_hpo_best.pth")
    ap.add_argument("--out", default="artifacts/graph_hpo/test_eval_results.json")
    ap.add_argument("--batch", type=int, default=2048)
    args = ap.parse_args()

    ckpt_path = _resolve(args.checkpoint)
    out_path = _resolve(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Prefer Apple-Silicon MPS, then CUDA, then CPU.
    device = torch.device(
        "mps" if torch.backends.mps.is_available() else
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    if device.type == "mps":
        # Allow MPS to use nearly all unified memory for large eval batches.
        torch.mps.set_per_process_memory_fraction(0.95)
    print(f"Device: {device}")

    # NOTE(security): torch.load unpickles arbitrary objects; only load
    # checkpoints you produced yourself (this one is a local artifact).
    print(f"Loading checkpoint: {ckpt_path}")
    raw = torch.load(ckpt_path, map_location="cpu")
    sd = raw["model"]
    in_dim = raw["in_dim"]
    hidden = raw["hidden"]
    n_blks = raw["n_blocks"]
    # Normalization stats saved at train time; reused here so the test
    # features are standardized identically to training.
    supp_mu = np.array(raw["supp_mu"], dtype=np.float32)
    supp_sd = np.array(raw["supp_sd"], dtype=np.float32)
    feat_label = raw.get("feature_label", "unknown")
    val_fmax = raw.get("val_fmax", float("nan"))
    print(f"  feature_label={feat_label} in_dim={in_dim} hidden={hidden} "
          f"n_blocks={n_blks} val_fmax={val_fmax:.4f}")

    print("Loading insect dataset...")
    df_base = pd.read_parquet(DATA_BASE)
    df_supp = pd.read_parquet(DATA_SUPP)
    mlb = joblib.load(MLB_PATH)

    emb_cols = [f"Dim_{i}" for i in range(ESM_DIM)]
    X_base = df_base[emb_cols].to_numpy(np.float32)
    S_raw = df_supp[SUPP_COLS].to_numpy(np.float32)
    m_flag = df_supp["f_af_present"].to_numpy(np.float32).reshape(-1, 1)

    # Z-score supplementary features with the training-time stats; the
    # epsilon guards against zero-variance columns.
    S_z = (S_raw - supp_mu) / (supp_sd + 1e-12)
    X_full = np.concatenate([X_base, S_z, m_flag], axis=1).astype(np.float32)
    # Truncate to the feature width the checkpointed model was trained on
    # (e.g. embeddings-only checkpoints use in_dim < X_full width).
    X_all = X_full[:, :in_dim]
    print(f"  X_all shape: {X_all.shape}")

    # Build the dense multi-hot label matrix; out-of-range indices are
    # silently dropped to tolerate label-space mismatches.
    label_lists = [parse_labels(x) for x in df_base["Label_Indices"]]
    Y_all = np.zeros((len(df_base), OUT_DIM), dtype=np.uint8)
    for r, labs in enumerate(label_lists):
        for j in labs:
            if 0 <= j < OUT_DIM:
                Y_all[r, j] = 1

    splits = np.load(SPLITS_NPZ)
    test_idx = splits["test_idx"]
    print(f"  Test set: {len(test_idx):,} proteins")

    model = ImprovedResidualMLP(in_dim=in_dim, out_dim=OUT_DIM, hidden=hidden, n_blocks=n_blks)
    model.load_state_dict(sd)
    model = model.to(device).eval()

    ds_te = EmbeddingDataset(X_all, Y_all, test_idx)
    ld_te = DataLoader(ds_te, batch_size=args.batch, shuffle=False, num_workers=0, pin_memory=False)

    print("Running micro-Fmax eval...")
    test_micro = eval_micro_fmax(model, ld_te, device)
    print(f"  Test micro-Fmax={test_micro['micro_fmax']:.4f} t*={test_micro['t_star']:.2f} "
          f"P={test_micro['precision']:.4f} R={test_micro['recall']:.4f}")

    print("Running CAFA-Fmax eval...")
    go_parents = load_go_parents(OBO_PATH)
    test_cafa = eval_cafa_fmax(model, ld_te, device, mlb.classes_, go_parents)
    print(f"  Test CAFA-Fmax={test_cafa.get('cafa_fmax', 'N/A')}")

    # Coerce NumPy scalars to native Python numbers: json.dump raises
    # TypeError on np.float32 / np.float64 values otherwise.
    result = {
        "checkpoint": str(ckpt_path),
        "feature_label": feat_label,
        "in_dim": _jsonable(in_dim),
        "val_fmax": _jsonable(val_fmax),
        "test_micro_fmax": _jsonable(test_micro["micro_fmax"]),
        "test_t_star": _jsonable(test_micro["t_star"]),
        "test_precision": _jsonable(test_micro["precision"]),
        "test_recall": _jsonable(test_micro["recall"]),
        "test_cafa_fmax": _jsonable(test_cafa.get("cafa_fmax")),
        "test_cafa_t_star": _jsonable(test_cafa.get("t_star")),
    }
    with open(out_path, "w") as f:
        json.dump(result, f, indent=2)
    print(f"\nSaved to {out_path}")
|
|
|
|
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|