# protfunc/scripts/eval_checkpoint.py
# Uploaded by Sbhat2026 — commit 7f7a890:
#   "perf: ESM embedding cache + 1500aa limit, add research scripts"
"""
eval_checkpoint.py — run test-set evaluation on an existing checkpoint
Usage:
python scripts/eval_checkpoint.py \
--checkpoint artifacts/graph_hpo/graph_hpo_best.pth \
--out artifacts/graph_hpo/test_eval_results.json
"""
import argparse, json, sys
from pathlib import Path
import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
BASE = Path(__file__).parent.parent
sys.path.insert(0, str(Path(__file__).parent))
from train_v3_fixed import (
ImprovedResidualMLP, EmbeddingDataset,
eval_micro_fmax, eval_cafa_fmax,
load_go_parents, parse_labels,
DATA_BASE, DATA_SUPP, SUPP_COLS, ESM_DIM, OUT_DIM,
SPLITS_NPZ, MLB_PATH, OBO_PATH,
)
from torch.utils.data import DataLoader
def _jsonable(x):
    """Coerce numpy scalars to plain Python numbers for json.dump.

    The eval helpers may return numpy floats/ints, which the stdlib json
    encoder rejects with TypeError. None and native types pass through.
    """
    if x is None:
        return None
    if isinstance(x, (np.floating, np.integer)):
        return x.item()
    return x


def main():
    """Evaluate an existing checkpoint on the held-out test split.

    Rebuilds the feature matrix using the normalisation statistics stored in
    the checkpoint, scores the test split with micro-Fmax and CAFA-Fmax, and
    writes a JSON summary to --out.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", default="artifacts/graph_hpo/graph_hpo_best.pth")
    ap.add_argument("--out", default="artifacts/graph_hpo/test_eval_results.json")
    ap.add_argument("--batch", type=int, default=2048)
    args = ap.parse_args()

    # Resolve relative paths against the repo root (parent of scripts/).
    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.is_absolute():
        ckpt_path = BASE / ckpt_path
    out_path = Path(args.out)
    if not out_path.is_absolute():
        out_path = BASE / out_path
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Prefer Apple MPS, then CUDA, then CPU.
    device = torch.device(
        "mps" if torch.backends.mps.is_available() else
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    if device.type == "mps":
        # Allow MPS to use most of unified memory for the large batch size.
        torch.mps.set_per_process_memory_fraction(0.95)
    print(f"Device: {device}")

    # ---- Load checkpoint --------------------------------------------------
    if not ckpt_path.exists():
        sys.exit(f"error: checkpoint not found: {ckpt_path}")
    print(f"Loading checkpoint: {ckpt_path}")
    # weights_only=False: the checkpoint stores plain-Python/numpy metadata
    # (hyper-params, normalisation stats) alongside the state dict, which the
    # weights-only unpickler (the default since PyTorch 2.6) would reject.
    raw = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    sd = raw["model"]
    in_dim = raw["in_dim"]
    hidden = raw["hidden"]
    n_blks = raw["n_blocks"]
    supp_mu = np.array(raw["supp_mu"], dtype=np.float32)
    supp_sd = np.array(raw["supp_sd"], dtype=np.float32)
    feat_label = raw.get("feature_label", "unknown")
    val_fmax = raw.get("val_fmax", float("nan"))
    print(f" feature_label={feat_label} in_dim={in_dim} hidden={hidden} n_blocks={n_blks} val_fmax={val_fmax:.4f}")

    # ---- Load data --------------------------------------------------------
    print("Loading insect dataset...")
    df_base = pd.read_parquet(DATA_BASE)
    df_supp = pd.read_parquet(DATA_SUPP)
    mlb = joblib.load(MLB_PATH)
    emb_cols = [f"Dim_{i}" for i in range(ESM_DIM)]
    X_base = df_base[emb_cols].to_numpy(np.float32)
    S_raw = df_supp[SUPP_COLS].to_numpy(np.float32)
    m_flag = df_supp["f_af_present"].to_numpy(np.float32).reshape(-1, 1)
    # Normalise supplementary features with the *training* statistics stored
    # in the checkpoint; the eps guards against zero-variance columns.
    S_z = (S_raw - supp_mu) / (supp_sd + 1e-12)
    X_full = np.concatenate([X_base, S_z, m_flag], axis=1).astype(np.float32)
    # Truncate to the feature subset the model was actually trained on.
    X_all = X_full[:, :in_dim]
    print(f" X_all shape: {X_all.shape}")

    # ---- Labels: multi-hot matrix over OUT_DIM GO-term indices ------------
    label_lists = [parse_labels(x) for x in df_base["Label_Indices"]]
    Y_all = np.zeros((len(df_base), OUT_DIM), dtype=np.uint8)
    for r, labs in enumerate(label_lists):
        for j in labs:
            if 0 <= j < OUT_DIM:  # drop out-of-range indices defensively
                Y_all[r, j] = 1

    # ---- Test split -------------------------------------------------------
    splits = np.load(SPLITS_NPZ)
    test_idx = splits["test_idx"]
    print(f" Test set: {len(test_idx):,} proteins")

    # ---- Model ------------------------------------------------------------
    model = ImprovedResidualMLP(in_dim=in_dim, out_dim=OUT_DIM, hidden=hidden, n_blocks=n_blks)
    model.load_state_dict(sd)
    model = model.to(device).eval()

    # ---- Loader (deterministic order; single worker is fine for eval) -----
    ds_te = EmbeddingDataset(X_all, Y_all, test_idx)
    ld_te = DataLoader(ds_te, batch_size=args.batch, shuffle=False, num_workers=0, pin_memory=False)

    # ---- Evaluate ---------------------------------------------------------
    print("Running micro-Fmax eval...")
    test_micro = eval_micro_fmax(model, ld_te, device)
    print(f" Test micro-Fmax={test_micro['micro_fmax']:.4f} t*={test_micro['t_star']:.2f} "
          f"P={test_micro['precision']:.4f} R={test_micro['recall']:.4f}")
    print("Running CAFA-Fmax eval...")
    go_parents = load_go_parents(OBO_PATH)
    test_cafa = eval_cafa_fmax(model, ld_te, device, mlb.classes_, go_parents)
    print(f" Test CAFA-Fmax={test_cafa.get('cafa_fmax', 'N/A')}")

    # Coerce numpy scalars so json.dump cannot raise TypeError.
    result = {
        "checkpoint": str(ckpt_path),
        "feature_label": feat_label,
        "in_dim": _jsonable(in_dim),
        "val_fmax": _jsonable(val_fmax),
        "test_micro_fmax": _jsonable(test_micro["micro_fmax"]),
        "test_t_star": _jsonable(test_micro["t_star"]),
        "test_precision": _jsonable(test_micro["precision"]),
        "test_recall": _jsonable(test_micro["recall"]),
        "test_cafa_fmax": _jsonable(test_cafa.get("cafa_fmax")),
        "test_cafa_t_star": _jsonable(test_cafa.get("t_star")),
    }
    with open(out_path, "w") as f:
        json.dump(result, f, indent=2)
    print(f"\nSaved to {out_path}")


if __name__ == "__main__":
    main()