"""
eval_checkpoint.py — run test-set evaluation on an existing checkpoint
Usage:
  python scripts/eval_checkpoint.py \
      --checkpoint artifacts/graph_hpo/graph_hpo_best.pth \
      --out        artifacts/graph_hpo/test_eval_results.json
"""
import argparse, json, sys
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
import torch

# Repo root; the scripts/ directory is also put on sys.path so the training module
# train_v3_fixed (which lives alongside this script) can be imported below.
BASE    = Path(__file__).parent.parent
sys.path.insert(0, str(Path(__file__).parent))

from train_v3_fixed import (
    ImprovedResidualMLP, EmbeddingDataset,
    eval_micro_fmax, eval_cafa_fmax,
    load_go_parents, parse_labels,
    DATA_BASE, DATA_SUPP, SUPP_COLS, ESM_DIM, OUT_DIM,
    SPLITS_NPZ, MLB_PATH, OBO_PATH,
)
from torch.utils.data import DataLoader


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", default="artifacts/graph_hpo/graph_hpo_best.pth")
    ap.add_argument("--out", default="artifacts/graph_hpo/test_eval_results.json")
    ap.add_argument("--batch", type=int, default=2048)
    args = ap.parse_args()

    # Relative --checkpoint/--out paths are resolved against the repository root.
    ckpt_path = BASE / args.checkpoint if not Path(args.checkpoint).is_absolute() else Path(args.checkpoint)
    out_path  = BASE / args.out        if not Path(args.out).is_absolute()        else Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    device = torch.device(
        "mps"  if torch.backends.mps.is_available() else
        "cuda" if torch.cuda.is_available()          else "cpu"
    )
    if device.type == "mps":
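        # Let the MPS allocator use up to 95% of the recommended per-process memory budget.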
        torch.mps.set_per_process_memory_fraction(0.95)
    print(f"Device: {device}")

    # Load checkpoint
    print(f"Loading checkpoint: {ckpt_path}")
    raw = torch.load(ckpt_path, map_location="cpu")
    sd      = raw["model"]
    in_dim  = raw["in_dim"]
    hidden  = raw["hidden"]
    n_blks  = raw["n_blocks"]
    supp_mu = np.array(raw["supp_mu"], dtype=np.float32)
    supp_sd = np.array(raw["supp_sd"], dtype=np.float32)
    feat_label = raw.get("feature_label", "unknown")
    val_fmax   = raw.get("val_fmax", float("nan"))
    print(f"  feature_label={feat_label}  in_dim={in_dim}  hidden={hidden}  n_blocks={n_blks}  val_fmax={val_fmax:.4f}")

    # Load data
    print("Loading insect dataset...")
    df_base = pd.read_parquet(DATA_BASE)
    df_supp = pd.read_parquet(DATA_SUPP)
    mlb     = joblib.load(MLB_PATH)

    emb_cols = [f"Dim_{i}" for i in range(ESM_DIM)]
    X_base   = df_base[emb_cols].to_numpy(np.float32)
    S_raw    = df_supp[SUPP_COLS].to_numpy(np.float32)
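    # f_af_present is a 0/1 presence flag; it is appended as-is (not z-scored) after the supplementary features.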
    m_flag   = df_supp["f_af_present"].to_numpy(np.float32).reshape(-1, 1)

    # Normalise using stored checkpoint stats
    S_z      = (S_raw - supp_mu) / (supp_sd + 1e-12)
    X_full   = np.concatenate([X_base, S_z, m_flag], axis=1).astype(np.float32)
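    # Keep only the first in_dim columns so the feature width matches what this checkpoint was trained on
    # (e.g. an embeddings-only checkpoint simply drops the supplementary columns appended above).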
    X_all    = X_full[:, :in_dim]
    print(f"  X_all shape: {X_all.shape}")

    # Labels: build a dense multi-hot matrix from each protein's GO term index list
    label_lists = [parse_labels(x) for x in df_base["Label_Indices"]]
    Y_all = np.zeros((len(df_base), OUT_DIM), dtype=np.uint8)
    for r, labs in enumerate(label_lists):
        for j in labs:
            if 0 <= j < OUT_DIM:
                Y_all[r, j] = 1

    # Splits
    splits   = np.load(SPLITS_NPZ)
    test_idx = splits["test_idx"]
    print(f"  Test set: {len(test_idx):,} proteins")

    # Model
    model = ImprovedResidualMLP(in_dim=in_dim, out_dim=OUT_DIM, hidden=hidden, n_blocks=n_blks)
    model.load_state_dict(sd)
    model = model.to(device).eval()

    # Loaders
    ds_te  = EmbeddingDataset(X_all, Y_all, test_idx)
    ld_te  = DataLoader(ds_te, batch_size=args.batch, shuffle=False, num_workers=0, pin_memory=False)

    # Evaluate
    print("Running micro-Fmax eval...")
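    # Micro-averaged Fmax: best micro-F1 over a threshold sweep; t* is the threshold that attains it.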
    test_micro = eval_micro_fmax(model, ld_te, device)
    print(f"  Test micro-Fmax={test_micro['micro_fmax']:.4f}  t*={test_micro['t_star']:.2f}  "
          f"P={test_micro['precision']:.4f}  R={test_micro['recall']:.4f}")

    print("Running CAFA-Fmax eval...")
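    # CAFA-style Fmax computed with the GO parent relations parsed from the ontology OBO file.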
    go_parents = load_go_parents(OBO_PATH)
    test_cafa  = eval_cafa_fmax(model, ld_te, device, mlb.classes_, go_parents)
    print(f"  Test CAFA-Fmax={test_cafa.get('cafa_fmax', 'N/A')}")

    result = {
        "checkpoint":      str(ckpt_path),
        "feature_label":   feat_label,
        "in_dim":          in_dim,
        "val_fmax":        val_fmax,
        "test_micro_fmax": test_micro["micro_fmax"],
        "test_t_star":     test_micro["t_star"],
        "test_precision":  test_micro["precision"],
        "test_recall":     test_micro["recall"],
        "test_cafa_fmax":  test_cafa.get("cafa_fmax"),
        "test_cafa_t_star":test_cafa.get("t_star"),
    }
    with open(out_path, "w") as f:
        json.dump(result, f, indent=2)
    print(f"\nSaved to {out_path}")


if __name__ == "__main__":
    main()