"""
eval_checkpoint.py — run test-set evaluation on an existing checkpoint
Usage:
python scripts/eval_checkpoint.py \
--checkpoint artifacts/graph_hpo/graph_hpo_best.pth \
--out artifacts/graph_hpo/test_eval_results.json
"""
import argparse, json, sys
from pathlib import Path
import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
BASE = Path(__file__).parent.parent
sys.path.insert(0, str(Path(__file__).parent))
from train_v3_fixed import (
ImprovedResidualMLP, EmbeddingDataset,
eval_micro_fmax, eval_cafa_fmax,
load_go_parents, parse_labels,
DATA_BASE, DATA_SUPP, SUPP_COLS, ESM_DIM, OUT_DIM,
SPLITS_NPZ, MLB_PATH, OBO_PATH,
)
from torch.utils.data import DataLoader
def _resolve_path(path_str: str) -> Path:
    """Resolve *path_str* against the repo base dir unless it is already absolute."""
    p = Path(path_str)
    return p if p.is_absolute() else BASE / p


def _pick_device() -> torch.device:
    """Select the best available torch device (MPS > CUDA > CPU)."""
    device = torch.device(
        "mps" if torch.backends.mps.is_available() else
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    if device.type == "mps":
        # Cap MPS memory so the process does not starve the rest of the system.
        torch.mps.set_per_process_memory_fraction(0.95)
    return device


def _build_features(df_base, df_supp, supp_mu: np.ndarray, supp_sd: np.ndarray,
                    in_dim: int) -> np.ndarray:
    """Assemble [ESM embedding | z-scored supplementary | AF-present flag].

    Normalisation uses the *training-time* statistics stored in the checkpoint
    so evaluation matches what the model saw during training. The final slice
    to ``in_dim`` columns lets one script serve checkpoints trained on feature
    subsets (the checkpoint records its expected input width).
    """
    emb_cols = [f"Dim_{i}" for i in range(ESM_DIM)]
    X_base = df_base[emb_cols].to_numpy(np.float32)
    S_raw = df_supp[SUPP_COLS].to_numpy(np.float32)
    m_flag = df_supp["f_af_present"].to_numpy(np.float32).reshape(-1, 1)
    # Epsilon guards against zero-variance supplementary columns.
    S_z = (S_raw - supp_mu) / (supp_sd + 1e-12)
    X_full = np.concatenate([X_base, S_z, m_flag], axis=1).astype(np.float32)
    return X_full[:, :in_dim]


def _build_labels(df_base) -> np.ndarray:
    """Build the dense multi-hot label matrix; out-of-range indices are dropped."""
    Y = np.zeros((len(df_base), OUT_DIM), dtype=np.uint8)
    for row, raw_labels in enumerate(df_base["Label_Indices"]):
        for j in parse_labels(raw_labels):
            if 0 <= j < OUT_DIM:
                Y[row, j] = 1
    return Y


def main():
    """Evaluate a saved checkpoint on the held-out test split.

    Loads the checkpoint named by ``--checkpoint``, rebuilds the feature
    matrix with the normalisation stats stored inside it, runs micro-Fmax and
    CAFA-Fmax evaluation on the test split, and writes a JSON summary to
    ``--out``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", default="artifacts/graph_hpo/graph_hpo_best.pth")
    ap.add_argument("--out", default="artifacts/graph_hpo/test_eval_results.json")
    ap.add_argument("--batch", type=int, default=2048)
    args = ap.parse_args()

    ckpt_path = _resolve_path(args.checkpoint)
    out_path = _resolve_path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    device = _pick_device()
    print(f"Device: {device}")

    # --- Load checkpoint -------------------------------------------------
    print(f"Loading checkpoint: {ckpt_path}")
    # torch>=2.6 flipped the torch.load default to weights_only=True, which
    # rejects the non-tensor entries stored here (numpy stats, metadata).
    # Be explicit so the script keeps working across torch versions; the
    # checkpoint is produced by our own training script, so this is safe.
    raw = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    sd = raw["model"]
    in_dim = raw["in_dim"]
    hidden = raw["hidden"]
    n_blks = raw["n_blocks"]
    supp_mu = np.array(raw["supp_mu"], dtype=np.float32)
    supp_sd = np.array(raw["supp_sd"], dtype=np.float32)
    feat_label = raw.get("feature_label", "unknown")
    val_fmax = raw.get("val_fmax", float("nan"))
    print(f" feature_label={feat_label} in_dim={in_dim} hidden={hidden} n_blocks={n_blks} val_fmax={val_fmax:.4f}")

    # --- Load data -------------------------------------------------------
    print("Loading insect dataset...")
    df_base = pd.read_parquet(DATA_BASE)
    df_supp = pd.read_parquet(DATA_SUPP)
    mlb = joblib.load(MLB_PATH)

    X_all = _build_features(df_base, df_supp, supp_mu, supp_sd, in_dim)
    print(f" X_all shape: {X_all.shape}")
    Y_all = _build_labels(df_base)

    # --- Test split ------------------------------------------------------
    splits = np.load(SPLITS_NPZ)
    test_idx = splits["test_idx"]
    print(f" Test set: {len(test_idx):,} proteins")

    # --- Model -----------------------------------------------------------
    model = ImprovedResidualMLP(in_dim=in_dim, out_dim=OUT_DIM, hidden=hidden, n_blocks=n_blks)
    model.load_state_dict(sd)
    model = model.to(device).eval()

    ds_te = EmbeddingDataset(X_all, Y_all, test_idx)
    ld_te = DataLoader(ds_te, batch_size=args.batch, shuffle=False, num_workers=0, pin_memory=False)

    # --- Evaluate --------------------------------------------------------
    print("Running micro-Fmax eval...")
    test_micro = eval_micro_fmax(model, ld_te, device)
    print(f" Test micro-Fmax={test_micro['micro_fmax']:.4f} t*={test_micro['t_star']:.2f} "
          f"P={test_micro['precision']:.4f} R={test_micro['recall']:.4f}")

    print("Running CAFA-Fmax eval...")
    go_parents = load_go_parents(OBO_PATH)
    test_cafa = eval_cafa_fmax(model, ld_te, device, mlb.classes_, go_parents)
    print(f" Test CAFA-Fmax={test_cafa.get('cafa_fmax', 'N/A')}")

    # --- Persist results -------------------------------------------------
    result = {
        "checkpoint": str(ckpt_path),
        "feature_label": feat_label,
        "in_dim": in_dim,
        "val_fmax": val_fmax,
        "test_micro_fmax": test_micro["micro_fmax"],
        "test_t_star": test_micro["t_star"],
        "test_precision": test_micro["precision"],
        "test_recall": test_micro["recall"],
        "test_cafa_fmax": test_cafa.get("cafa_fmax"),
        "test_cafa_t_star": test_cafa.get("t_star"),
    }
    with open(out_path, "w") as f:
        json.dump(result, f, indent=2)
    print(f"\nSaved to {out_path}")
# Script entry point: only run evaluation when executed directly, not on import.
if __name__ == "__main__":
    main()