protfunc / scripts /eval_generalization.py
Sbhat2026's picture
perf: ESM embedding cache + 1500aa limit, add research scripts
7f7a890
"""
eval_generalization.py — Per-taxon generalization evaluation for ProtFunc
=========================================================================
Given a trained checkpoint and a taxon parquet (standard embedding format),
computes micro-Fmax (with precision and recall at the best threshold),
CAFA-style Fmax, macro-F1, micro-AUPRC, and GO-term label coverage.
Works with any taxonomic group produced by prep_taxon.py.
Usage
-----
python scripts/eval_generalization.py \\
--checkpoint artifacts/protfunc_v3_fixed.pth \\
--thresholds artifacts/protfunc_v3_fixed_thresholds.json \\
--mlb "Important Files/mlb_public_v1.pkl" \\
--taxon_parquet artifacts/mammal_embeddings_v3.parquet \\
--taxon_name mammals \\
--obo go-basic.obo \\
--out artifacts/generalization_mammals.json
Multiple taxon runs are accumulated: if --out already exists, its contents
are merged, so you can build up a single generalization_results.json across
all taxa over time. A flat CSV summary (same path, .csv suffix) is rewritten
alongside it on every run.
Output (JSON)
-------------
{
"mammals": {
"n_proteins": 7,
"n_labeled": 7,
"micro_fmax": 0.82,
"t_star": 0.40,
"precision": 0.85,
"recall": 0.79,
"cafa_fmax": 0.91,
"macro_f1": 0.74,
"micro_auprc": 0.88,
"label_coverage": 0.05, # fraction of 8124 GO terms seen in taxon
"insect_test_fmax": 0.884, # reference (if insect log available)
"generalization_ratio": 0.93,
"model_checkpoint": "protfunc_v3_fixed.pth",
"feature_label": "C_ESM_all",
"evaluated_at": "2026-04-09T20:42:00"
}
}
"""
import argparse
import ast
import json
import os
import warnings
from datetime import datetime, timezone
from pathlib import Path
import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import average_precision_score
# NOTE(review): this silences *all* warnings process-wide as soon as the module
# is imported, not just while main() runs; consider scoping it to main() or to
# specific warning categories.
warnings.filterwarnings("ignore")
# ── Architecture (must match training) ───────────────────────────────────────
class ResBlock(nn.Module):
    """Residual block computing ``x + net(x)``.

    ``net`` is two identical stages of BatchNorm1d -> ReLU -> Dropout -> Linear,
    all of width ``dim``. The submodule order fixes the ``net.<i>`` state-dict
    keys, so it must match the architecture used at training time.
    """

    def __init__(self, dim: int, dropout: float):
        super().__init__()

        def _stage():
            # One pre-activation stage; built twice, in order, to mirror training.
            return [nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim)]

        self.net = nn.Sequential(*_stage(), *_stage())

    def forward(self, x):
        delta = self.net(x)
        return x + delta
class ImprovedResidualMLP(nn.Module):
    """Residual MLP: input projection -> ``n_blocks`` ResBlocks -> output head.

    Attribute names (``fc_in``, ``blocks``, ``fc_out``) and layer order define
    the state-dict keys, so they must stay in sync with the training script.

    Args:
        in_dim: number of input features.
        out_dim: number of output logits (one per label).
        hidden: width of the residual trunk.
        n_blocks: number of ``ResBlock`` layers in the trunk.
        dropout: dropout probability used in every block and in the head.
    """

    def __init__(self, in_dim, out_dim=8124, hidden=2048, n_blocks=4, dropout=0.2):
        super().__init__()
        self.fc_in = nn.Linear(in_dim, hidden)
        trunk = [ResBlock(hidden, dropout) for _ in range(n_blocks)]
        self.blocks = nn.ModuleList(trunk)
        head = [
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        ]
        self.fc_out = nn.Sequential(*head)

    def forward(self, x):
        hidden_state = self.fc_in(x)
        for residual_block in self.blocks:
            hidden_state = residual_block(hidden_state)
        return self.fc_out(hidden_state)
# ── GO hierarchy ─────────────────────────────────────────────────────────────
def load_go_parents(obo_path: Path) -> dict:
    """Parse a GO OBO file into a molecular-function parent map.

    Reads every ``[Term]`` stanza (other stanza types such as ``[Typedef]`` are
    skipped), drops terms marked ``is_obsolete: true``, and records each term's
    namespace plus its direct parents from ``is_a:`` and
    ``relationship: part_of`` lines.

    Args:
        obo_path: path to an OBO file (e.g. ``go-basic.obo``).

    Returns:
        ``{go_id: set_of_parent_ids}`` restricted to terms in the
        ``molecular_function`` namespace, with each parent set also filtered to
        MF terms. Returns ``{}`` (after printing a warning) when the file does
        not exist; callers treat an empty map as "CAFA eval disabled".
    """
    if not obo_path.exists():
        print(f" WARNING: {obo_path} not found — CAFA eval disabled")
        return {}
    ns_map, par_map = {}, {}
    cur_id, cur_ns, cur_par, in_term = None, None, set(), False

    def flush():
        # Commit the stanza being built (only if it has an id and a namespace),
        # then reset the accumulators for the next stanza.
        nonlocal cur_id, cur_ns, cur_par
        if cur_id and cur_ns:
            ns_map[cur_id] = cur_ns
            par_map[cur_id] = cur_par
        cur_id, cur_ns, cur_par = None, None, set()

    with open(obo_path, encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if line.startswith("["):
                # Any stanza header closes the previous stanza; only [Term] is parsed.
                flush()
                in_term = line == "[Term]"
                continue
            if not in_term:
                continue
            if line.startswith("id:"):
                cur_id = line.split("id:", 1)[1].strip().split()[0]
            elif line.startswith("namespace:"):
                cur_ns = line.split("namespace:", 1)[1].strip()
            elif line.startswith("is_obsolete:") and "true" in line:
                cur_id = None  # obsolete term: flush() will discard it
            elif line.startswith("is_a:"):
                cur_par.add(line.split("is_a:", 1)[1].strip().split()[0])
            elif line.startswith("relationship:"):
                pts = line.split("relationship:", 1)[1].strip().split()
                if len(pts) >= 2 and pts[0] == "part_of":
                    cur_par.add(pts[1])
    flush()  # the file may end without a trailing stanza header
    mf = {g for g, n in ns_map.items() if n == "molecular_function"}
    return {g: (par_map[g] & mf) for g in mf}
# ── Label parsing ─────────────────────────────────────────────────────────────
def _parse_label_string(s):
    """Parse a string ``Label_Indices`` cell; see ``parse_labels`` for accepted forms."""
    s = s.strip()
    if not s or s.lower() == "nan":
        return []
    try:
        value = ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a Python literal, e.g. a numpy repr like "[1 2 3]" (no commas).
        # Drop all brackets, treat commas as separators, and require every
        # remaining token to be an integer.
        tokens = s.translate(str.maketrans(",", " ", "[]()")).split()
        try:
            return [int(tok) for tok in tokens]
        except ValueError:
            return []
    items = value if isinstance(value, (list, tuple)) else [value]
    try:
        return [int(i) for i in items]
    except (TypeError, ValueError, OverflowError):
        return []


def parse_labels(x):
    """Convert one ``Label_Indices`` cell into a list of integer label indices.

    Accepted inputs:
      * ``None``, float NaN, ``pd.NA``, ``""`` or ``"nan"``  -> ``[]``
      * list / tuple / set of integers
      * numpy array of integers (any rank; flattened in C order)
      * a single Python or numpy integer (bools excluded) -> ``[value]``
      * a string holding a Python literal, e.g. ``"[1, 2]"``, ``"(3,)"``, ``"4"``
      * a numpy-style string repr without commas, e.g. ``"[1 2 3]"``

    Malformed strings and unsupported scalar types yield ``[]`` so a bad cell
    marks the protein as unlabeled instead of aborting the run.
    """
    if x is None:
        return []
    if isinstance(x, np.ndarray):
        return [int(v) for v in x.ravel()]
    if isinstance(x, (list, tuple, set)):
        return [int(v) for v in x]
    if isinstance(x, (int, np.integer)) and not isinstance(x, bool):
        return [int(x)]
    if isinstance(x, str):
        return _parse_label_string(x)
    return []
# ── Metrics ───────────────────────────────────────────────────────────────────
@torch.no_grad()
def compute_metrics(model, X_tensor, Y_mat, device, step=0.02, *, macro_present_only=True):
    """Run ``model`` on one taxon and compute micro/macro multi-label metrics.

    Args:
        model: torch module mapping a batch of feature rows to per-label logits.
        X_tensor: feature tensor, row-aligned with ``Y_mat``; fed to ``device``
            in batches of 512 rows.
        Y_mat: (N, C) float32 numpy array of ground-truth binary labels.
        device: torch device used for inference.
        step: spacing of the threshold grid (0, step, 2*step, ..., 1.0) used
            for the micro-Fmax sweep.
        macro_present_only: if True (default), macro-F1 is averaged only over
            labels with at least one positive in ``Y_mat``. F1 is undefined for
            labels absent from the taxon, and scoring them as 0 would make
            macro-F1 roughly ``label_coverage`` times the real value. Pass
            False to average over all C labels, with absent labels scored 0.

    Returns:
        dict with rounded scalars ``micro_fmax``, ``t_star`` (grid threshold
        achieving it), ``precision`` and ``recall`` at ``t_star``, ``macro_f1``
        (at ``t_star``; NaN if no label is present), ``micro_auprc`` (NaN if it
        cannot be computed), ``label_coverage`` (fraction of labels with at
        least one positive), plus the raw ``probs`` (N, C) float32 array and
        the input ``Y_mat``.
    """
    model.eval()
    batch = 512
    prob_chunks = []
    for i in range(0, len(X_tensor), batch):
        xb = X_tensor[i:i + batch].to(device)
        prob_chunks.append(torch.sigmoid(model(xb)).cpu().numpy())
    probs = np.concatenate(prob_chunks, axis=0).astype(np.float32)

    # ── micro-Fmax (histogram sweep) ─────────────────────────────────────────
    # Each score is binned once on the grid ``edges``; reverse cumulative sums
    # then give TP/FP counts for "predict if p >= edges[b]" at every threshold
    # in a single pass over the N*C scores.
    edges = np.arange(0.0, 1.0 + step, step)
    nbins = len(edges)
    p_flat = probs.ravel()
    y_flat = Y_mat.ravel() > 0.5
    tp_total = int(np.count_nonzero(y_flat))
    # The tiny epsilon keeps scores lying exactly on a grid edge in the upper bin.
    bi = np.minimum(np.floor(p_flat / step + 1e-9).astype(np.int64), nbins - 1)
    hp = np.bincount(bi[y_flat], minlength=nbins)
    hn = np.bincount(bi[~y_flat], minlength=nbins)
    cum_tp = np.cumsum(hp[::-1])[::-1].astype(float)
    cum_fp = np.cumsum(hn[::-1])[::-1].astype(float)
    pred_c = cum_tp + cum_fp
    prec_c = np.divide(cum_tp, pred_c, out=np.zeros_like(cum_tp), where=pred_c > 0)
    rec_c = cum_tp / max(tp_total, 1)
    denom = prec_c + rec_c
    f1_c = np.divide(2 * prec_c * rec_c, denom, out=np.zeros_like(denom), where=denom > 0)
    b = int(np.argmax(f1_c))
    micro_fmax = float(f1_c[b])
    t_star = float(edges[b])
    precision = float(prec_c[b])
    recall = float(rec_c[b])

    # ── Macro-F1 at t_star (vectorised per-label counts) ─────────────────────
    y_bool = Y_mat > 0.5
    pred_bool = probs >= t_star
    tp = np.count_nonzero(pred_bool & y_bool, axis=0)
    fp = np.count_nonzero(pred_bool & ~y_bool, axis=0)
    fn = np.count_nonzero(~pred_bool & y_bool, axis=0)
    d = 2 * tp + fp + fn
    label_f1 = np.divide(2.0 * tp, d, out=np.zeros(d.shape, dtype=float), where=d > 0)
    present = y_bool.any(axis=0)
    if macro_present_only:
        macro_f1 = float(label_f1[present].mean()) if present.any() else float("nan")
    else:
        macro_f1 = float(label_f1.mean())

    # ── Micro-AUPRC ──────────────────────────────────────────────────────────
    try:
        micro_auprc = float(average_precision_score(
            Y_mat.ravel(), probs.ravel(), average="micro"
        ))
    except Exception:
        # Best-effort metric: degenerate inputs (or a missing sklearn) give NaN.
        micro_auprc = float("nan")

    # ── Label coverage ───────────────────────────────────────────────────────
    label_coverage = float(present.mean())
    return {
        "micro_fmax": round(micro_fmax, 4),
        "t_star": round(t_star, 3),
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "macro_f1": round(macro_f1, 4),
        "micro_auprc": round(micro_auprc, 4),
        "label_coverage": round(label_coverage, 5),
        "probs": probs,
        "Y_mat": Y_mat,
    }
def compute_cafa_fmax(probs, Y_mat, mlb_classes, go_parents, step=0.05):
    """Protein-centric (CAFA-style) Fmax with ancestor propagation of predictions.

    For each threshold t in ``np.arange(0.05, 0.96, step)``, a label is
    predicted when its probability is >= t, and every predicted label also
    switches on all of its ancestors (transitive closure of ``go_parents``)
    that appear in ``mlb_classes``. Precision is averaged over proteins with at
    least one prediction and recall over all labeled proteins; the F1 of those
    two averages is maximised over thresholds.

    NOTE(review): ``Y_mat`` is used as-is (not propagated). This assumes the
    ground-truth annotations are already ancestor-closed — TODO confirm.

    Args:
        probs: (N, C) predicted probabilities, columns aligned with ``mlb_classes``.
        Y_mat: (N, C) binary ground truth; proteins with no labels are ignored.
        mlb_classes: sequence of C GO ids.
        go_parents: ``{go_id: set_of_parent_ids}``, e.g. from ``load_go_parents``.
        step: threshold spacing.

    Returns:
        ``{"cafa_fmax": float, "t_star": float}``. Both are NaN when there is
        no hierarchy, no labeled protein, or no threshold yields any prediction.
    """
    def _nan_result():
        return {"cafa_fmax": float("nan"), "t_star": float("nan")}

    if not go_parents:
        return _nan_result()

    go2idx = {g: i for i, g in enumerate(mlb_classes)}

    def _ancestor_indices(gid):
        # Iterative DFS over the parent map; only ancestors that are model labels matter.
        seen, stack = set(), list(go_parents.get(gid, set()))
        while stack:
            p = stack.pop()
            if p not in seen:
                seen.add(p)
                stack.extend(go_parents.get(p, set()))
        return np.array([go2idx[a] for a in seen if a in go2idx], dtype=np.int64)

    anc_idx = [_ancestor_indices(g) for g in mlb_classes]

    has_label = Y_mat.sum(axis=1) > 0
    p2 = probs[has_label]
    y2 = Y_mat[has_label]
    if len(p2) == 0:
        return _nan_result()

    rp_per = y2.sum(axis=1)  # true labels per protein; > 0 after the filter above
    best_f1, best_t = -1.0, float("nan")
    for t in np.arange(0.05, 0.96, step):
        pred_bin = (p2 >= t).astype(np.float32)
        prop = pred_bin.copy()
        # Only labels predicted for at least one protein can switch on ancestors.
        for j in np.flatnonzero(pred_bin.any(axis=0)):
            aidx = anc_idx[j]
            if aidx.size:
                prop[np.ix_(np.flatnonzero(pred_bin[:, j]), aidx)] = 1.0
        tp_per = (prop * y2).sum(axis=1)
        pp_per = prop.sum(axis=1)
        has_pred = pp_per > 0
        if not has_pred.any():
            continue  # precision undefined at this threshold
        avg_prec = (tp_per[has_pred] / pp_per[has_pred]).mean()
        avg_rec = (tp_per / rp_per).mean()
        denom = avg_prec + avg_rec
        f1 = (2 * avg_prec * avg_rec / denom) if denom > 0 else 0.0
        if f1 > best_f1:
            best_f1, best_t = f1, float(t)

    if best_f1 < 0:
        # No threshold produced any prediction: Fmax is undefined, not -1.
        return _nan_result()
    return {"cafa_fmax": round(float(best_f1), 4), "t_star": round(best_t, 3)}
# ── Build feature matrix from parquet ────────────────────────────────────────
def build_feature_matrix(df: pd.DataFrame, in_dim: int, mu: np.ndarray,
                         sd: np.ndarray, supp_cols: list) -> np.ndarray:
    """Reconstruct the model input matrix X from a standard embedding parquet.

    Column layout of X (``n_supp = in_dim - 320``):
      * 0..319: ESM embedding columns ``Dim_0`` .. ``Dim_319``, used as-is.
      * next ``min(n_supp, len(supp_cols))`` columns: the leading entries of
        ``supp_cols``, z-scored with the checkpoint's ``mu``/``sd``. Missing
        columns and NaNs are imputed with the training mean, so they z-score
        to 0 instead of an extreme ``(0 - mu) / sd``.
      * any remaining columns: extra appended flags. The first one is filled
        from ``f_af_present`` when that column exists (raw, not z-scored).
        Missing values use the checkpoint mean for that same position, or 0.0
        when ``mu`` has no entry for it.

    Positions without stats in ``mu``/``sd`` fall back to mean 0 / sd 1.
    Known configs: 320 (ESM only), 331 (ESM + 11 seq features) and 360
    (ESM + 39 supp features + one extra flag, presumably af_present).

    Raises:
        ValueError: if ``in_dim`` is smaller than the ESM embedding width.
    """
    ESM_DIM = 320
    if in_dim < ESM_DIM:
        raise ValueError(
            f"in_dim={in_dim} is smaller than the ESM embedding width ({ESM_DIM})"
        )
    emb_cols = [f"Dim_{i}" for i in range(ESM_DIM)]
    X_esm = df[emb_cols].to_numpy(np.float32)
    if in_dim == ESM_DIM:
        return X_esm

    n_rows = len(df)
    n_supp = in_dim - ESM_DIM
    cols_needed = list(supp_cols)[:n_supp]
    n_named = len(cols_needed)

    # Per-position stats padded to n_supp (mu/sd were computed on the insect
    # train set); positions with no stats get identity normalisation.
    mu_arr = np.asarray(mu, dtype=np.float32).ravel()[:n_supp]
    sd_arr = np.asarray(sd, dtype=np.float32).ravel()[:n_supp]
    mu_full = np.zeros(n_supp, dtype=np.float32)
    sd_full = np.ones(n_supp, dtype=np.float32)
    mu_full[:len(mu_arr)] = mu_arr
    sd_full[:len(sd_arr)] = sd_arr

    # Named supplementary features: start at the training mean, overwrite with
    # observed values, then z-score.
    mu_named, sd_named = mu_full[:n_named], sd_full[:n_named]
    S = np.tile(mu_named, (n_rows, 1))
    for ci, col in enumerate(cols_needed):
        if col in df.columns:
            v = df[col].to_numpy(np.float32)
            S[:, ci] = np.where(np.isnan(v), mu_named[ci], v)  # NaN → mean
    S_z = (S - mu_named) / (sd_named + 1e-12)

    # Extra trailing columns (some checkpoints append af_present after supp).
    # Fill from the stats of *their own* positions, not from mu[0:k].
    n_extra = n_supp - n_named
    extra = np.tile(mu_full[n_named:], (n_rows, 1))
    if n_extra > 0 and "f_af_present" in df.columns:
        v = df["f_af_present"].to_numpy(np.float32)
        extra[:, 0] = np.where(np.isnan(v), extra[:, 0], v)

    return np.concatenate([X_esm, S_z, extra], axis=1).astype(np.float32)
# ── Main ──────────────────────────────────────────────────────────────────────
def _json_safe(value):
    """Return ``None`` for float NaN so ``json.dump`` emits valid JSON (``null``)."""
    if isinstance(value, float) and np.isnan(value):
        return None
    return value


def main():
    """CLI entry point: evaluate one checkpoint on one taxon and merge the results.

    Loads the checkpoint and label binarizer, rebuilds the feature matrix from
    ``--taxon_parquet``, computes micro/macro and CAFA-style metrics, prints
    them, and stores them under ``--taxon_name`` in the ``--out`` JSON (merged
    with any existing content) plus a flat CSV next to it. Returns early
    without writing anything if the parquet has no ``Label_Indices`` column or
    no labeled proteins.
    """
    ap = argparse.ArgumentParser(description="ProtFunc per-taxon generalization eval")
    ap.add_argument("--checkpoint", required=True, help="Path to .pth checkpoint")
    ap.add_argument("--thresholds", required=True, help="Path to per-label thresholds JSON")
    ap.add_argument("--mlb", required=True, help="Path to mlb_public_v1.pkl")
    ap.add_argument("--taxon_parquet", required=True, help="Parquet with embeddings (prep_taxon.py output)")
    ap.add_argument("--taxon_name", required=True, help="Short name, e.g. 'mammals', 'fungi'")
    ap.add_argument("--obo", default="go-basic.obo", help="GO OBO file for CAFA eval")
    ap.add_argument("--out", default="artifacts/generalization_results.json",
                    help="Output JSON (accumulates across taxa)")
    ap.add_argument("--insect_log", default="artifacts/protfunc_v3_fixed_log.json",
                    help="Training log for reference insect test Fmax")
    ap.add_argument("--device", default="auto")
    args = ap.parse_args()

    if args.device == "auto":
        device = torch.device(
            "mps" if torch.backends.mps.is_available() else
            "cuda" if torch.cuda.is_available() else "cpu"
        )
    else:
        device = torch.device(args.device)
    print(f"Device: {device}")

    # ── Load checkpoint ──────────────────────────────────────────────────────
    print(f"\nLoading checkpoint: {args.checkpoint}")
    # NOTE(review): torch.load and joblib.load unpickle arbitrary Python objects;
    # only point --checkpoint / --mlb at files from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    in_dim = ckpt["in_dim"]
    hidden = ckpt.get("hidden", 2048)
    n_blocks = ckpt.get("n_blocks", 4)
    feature_label = ckpt.get("feature_label", "unknown")
    mu = np.array(ckpt.get("supp_mu", []), dtype=np.float32)
    sd = np.array(ckpt.get("supp_sd", []), dtype=np.float32)
    supp_cols = ckpt.get("supp_cols", [])
    print(f"  in_dim={in_dim}  hidden={hidden}  n_blocks={n_blocks}  feature_label={feature_label}")
    mlb = joblib.load(args.mlb)
    out_dim = len(mlb.classes_)
    model = ImprovedResidualMLP(in_dim, out_dim, hidden, n_blocks).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    print(f"  Model loaded: {sum(p.numel() for p in model.parameters()):,} params")

    # ── Load taxon data ──────────────────────────────────────────────────────
    print(f"\nLoading taxon data: {args.taxon_parquet}")
    df = pd.read_parquet(args.taxon_parquet)
    print(f"  {len(df)} proteins")
    label_col = "Label_Indices"
    if label_col not in df.columns:
        print("  ERROR: taxon parquet has no Label_Indices column — cannot evaluate")
        return
    Y_mat = np.zeros((len(df), out_dim), dtype=np.float32)
    for r, cell in enumerate(df[label_col]):
        for j in parse_labels(cell):
            if 0 <= j < out_dim:  # silently drop indices outside the MLB vocabulary
                Y_mat[r, j] = 1.0
    n_labeled = int((Y_mat.sum(axis=1) > 0).sum())
    print(f"  {n_labeled}/{len(df)} proteins have ≥1 GO-MF label")
    if n_labeled == 0:
        print("  No labeled proteins — nothing to evaluate.")
        return
    X = build_feature_matrix(df, in_dim, mu, sd, supp_cols)
    X_tensor = torch.tensor(X, dtype=torch.float32)
    print(f"  Feature matrix: {X.shape}")

    # ── Evaluate ─────────────────────────────────────────────────────────────
    print(f"\nEvaluating '{args.taxon_name}'...")
    m = compute_metrics(model, X_tensor, Y_mat, device)
    probs = m.pop("probs")
    m.pop("Y_mat")
    # CAFA-style
    go_parents = load_go_parents(Path(args.obo))
    cafa = compute_cafa_fmax(probs, Y_mat, mlb.classes_, go_parents)

    # Reference insect Fmax: best-effort, a bad log only disables the ratio.
    insect_fmax = None
    insect_log = Path(args.insect_log)
    if insect_log.exists():
        try:
            with open(insect_log, encoding="utf-8") as f:
                log = json.load(f)
            # The last entry carrying a "test_micro" block wins.
            for entry in log:
                if isinstance(entry, dict) and "test_micro" in entry:
                    insect_fmax = round(float(entry["test_micro"]["micro_fmax"]), 4)
        except (OSError, ValueError, KeyError, TypeError) as exc:
            print(f"  WARNING: could not read reference Fmax from {insect_log}: {exc}")
    gen_ratio = None
    if insect_fmax and m["micro_fmax"] > 0:
        gen_ratio = round(m["micro_fmax"] / insect_fmax, 4)

    result = {
        "n_proteins": len(df),
        "n_labeled": n_labeled,
        **m,
        "cafa_fmax": cafa["cafa_fmax"],
        "cafa_t_star": cafa["t_star"],
        "insect_test_fmax": insect_fmax,
        "generalization_ratio": gen_ratio,
        "model_checkpoint": Path(args.checkpoint).name,
        "feature_label": feature_label,
        "evaluated_at": datetime.now(timezone.utc).isoformat(timespec="seconds"),
    }
    print(f"\n  Results for '{args.taxon_name}':")
    for k, v in result.items():
        if not isinstance(v, str) or k == "evaluated_at":
            print(f"    {k:30s}: {v}")

    # ── Write output (merge with existing) ───────────────────────────────────
    out_path = Path(args.out)
    existing = {}
    if out_path.exists():
        with open(out_path, encoding="utf-8") as f:
            existing = json.load(f)
    # NaN metrics (e.g. CAFA without an OBO file) become null: bare NaN is not valid JSON.
    existing[args.taxon_name] = {k: _json_safe(v) for k, v in result.items()}
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(existing, f, indent=2)
    print(f"\nSaved to {out_path}")

    # Also write a flat CSV for easy inspection across taxa
    csv_path = out_path.with_suffix(".csv")
    rows = []
    for taxon, vals in existing.items():
        row = {"taxon": taxon}
        for k, v in vals.items():
            if isinstance(v, (int, float, str, type(None))):
                row[k] = v
        rows.append(row)
    pd.DataFrame(rows).to_csv(csv_path, index=False)
    print(f"CSV summary: {csv_path}")
# Run the CLI only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()