| """ |
eval_generalization.py - Per-taxon generalization evaluation for ProtFunc
| ========================================================================= |
| Given a trained checkpoint and a taxon parquet (standard embedding format), |
| computes micro-Fmax, CAFA-Fmax, precision, recall, AUPRC, and per-label |
| coverage metrics. Works with any taxonomic group produced by prep_taxon.py. |
| |
| Usage |
| ----- |
| python scripts/eval_generalization.py \\ |
| --checkpoint artifacts/protfunc_v3_fixed.pth \\ |
| --thresholds artifacts/protfunc_v3_fixed_thresholds.json \\ |
| --mlb "Important Files/mlb_public_v1.pkl" \\ |
| --taxon_parquet artifacts/mammal_embeddings_v3.parquet \\ |
| --taxon_name mammals \\ |
| --obo go-basic.obo \\ |
| --out artifacts/generalization_mammals.json |
| |
Multiple taxon runs are accumulated: if --out already exists, its contents
are merged, so you can build up a single generalization_results.json across
all taxa over time.
| |
| Output (JSON) |
| ------------- |
| { |
| "mammals": { |
| "n_proteins": 7, |
| "n_labeled": 7, |
| "micro_fmax": 0.82, |
| "t_star": 0.40, |
| "precision": 0.85, |
| "recall": 0.79, |
| "cafa_fmax": 0.91, |
| "macro_f1": 0.74, |
| "micro_auprc": 0.88, |
| "label_coverage": 0.05, # fraction of 8124 GO terms seen in taxon |
| "insect_test_fmax": 0.884, # reference (if insect log available) |
| "generalization_ratio": 0.93, |
| "model_checkpoint": "protfunc_v3_fixed.pth", |
| "feature_label": "C_ESM_all", |
| "evaluated_at": "2026-04-09T20:42:00" |
| } |
| } |
| """ |
|
|
| import argparse |
| import ast |
| import json |
| import warnings |
| from datetime import datetime, timezone |
| from pathlib import Path |
|
|
| import joblib |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from sklearn.metrics import average_precision_score |
|
|
| warnings.filterwarnings("ignore") |
|
|
| |
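# --- Model (must mirror the training checkpoint's architecture) ---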
|
|
| class ResBlock(nn.Module): |
| def __init__(self, dim: int, dropout: float): |
| super().__init__() |
| self.net = nn.Sequential( |
| nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim), |
| nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim), |
| ) |
| def forward(self, x): |
| return x + self.net(x) |
|
|
|
|
| class ImprovedResidualMLP(nn.Module): |
| def __init__(self, in_dim, out_dim=8124, hidden=2048, n_blocks=4, dropout=0.2): |
| super().__init__() |
| self.fc_in = nn.Linear(in_dim, hidden) |
| self.blocks = nn.ModuleList([ResBlock(hidden, dropout) for _ in range(n_blocks)]) |
| self.fc_out = nn.Sequential( |
| nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout), |
| nn.Linear(hidden, out_dim), |
| ) |
| def forward(self, x): |
| h = self.fc_in(x) |
| for blk in self.blocks: |
| h = blk(h) |
| return self.fc_out(h) |
|
|
|
|
| |
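# --- GO ontology parsing (molecular_function namespace only) ---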
|
|
| def load_go_parents(obo_path: Path) -> dict: |
| if not obo_path.exists(): |
| print(f" WARNING: {obo_path} not found β CAFA eval disabled") |
| return {} |
| ns_map, par_map = {}, {} |
| cur_id, cur_ns, cur_par, in_term = None, None, set(), False |
|
|
| def flush(): |
| nonlocal cur_id, cur_ns, cur_par |
| if cur_id and cur_ns: |
| ns_map[cur_id] = cur_ns |
| par_map[cur_id] = cur_par |
| cur_id, cur_ns, cur_par = None, None, set() |
|
|
| with open(obo_path, encoding="utf-8") as fh: |
| for raw in fh: |
| line = raw.strip() |
| if line == "[Term]": |
| flush(); in_term = True; continue |
| if line.startswith("[") and line != "[Term]": |
| flush(); in_term = False; continue |
| if not in_term: |
| continue |
| if line.startswith("id:"): |
| cur_id = line.split("id:", 1)[1].strip().split()[0] |
| elif line.startswith("namespace:"): |
| cur_ns = line.split("namespace:", 1)[1].strip() |
| elif line.startswith("is_obsolete:") and "true" in line: |
| cur_id = None |
| elif line.startswith("is_a:"): |
| cur_par.add(line.split("is_a:", 1)[1].strip().split()[0]) |
| elif line.startswith("relationship:"): |
| pts = line.split("relationship:", 1)[1].strip().split() |
| if len(pts) >= 2 and pts[0] == "part_of": |
| cur_par.add(pts[1]) |
| flush() |
| mf = {g for g, n in ns_map.items() if n == "molecular_function"} |
| return {g: (par_map[g] & mf) for g in mf} |
|
|
|
|
| |
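# --- Label parsing ---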
|
|
| def parse_labels(x): |
| if x is None: |
| return [] |
| if isinstance(x, (list, np.ndarray)): |
| return [int(v) for v in x] |
| if isinstance(x, str): |
| s = x.strip() |
| if not s or s.lower() == "nan": |
| return [] |
| try: |
| v = ast.literal_eval(s) |
| return [int(i) for i in (v if isinstance(v, (list, tuple)) else [v])] |
| except Exception: |
| return [] |
| return [] |
|
|
|
|
| |
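# --- Metrics ---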
|
|
| @torch.no_grad() |
| def compute_metrics(model, X_tensor, Y_mat, device, step=0.02): |
| """ |
| Returns micro-Fmax, AUPRC, macro-F1 (at t_star), and per-label details. |
| Y_mat: (N, C) float32 numpy array of ground-truth binary labels. |
| """ |
| model.eval() |
| batch = 512 |
| all_probs = [] |
| for i in range(0, len(X_tensor), batch): |
| xb = X_tensor[i:i+batch].to(device) |
| all_probs.append(torch.sigmoid(model(xb)).cpu().numpy()) |
| probs = np.concatenate(all_probs, axis=0).astype(np.float32) |
|
|
| |
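    # Histogram the flattened scores once, then recover TP/FP at every
    # threshold from reverse cumulative sums (O(N*C) instead of O(N*C*T)).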
| edges = np.arange(0.0, 1.0 + step, step) |
| nbins = len(edges) |
| hp = np.zeros(nbins, np.int64) |
| hn = np.zeros(nbins, np.int64) |
| tp_total = int(Y_mat.sum()) |
|
|
| p_flat = probs.ravel() |
| y_flat = Y_mat.ravel() > 0.5 |
| bi = np.minimum(np.floor(p_flat / step + 1e-9).astype(np.int64), nbins - 1) |
| if y_flat.any(): |
| hp += np.bincount(bi[y_flat], minlength=nbins) |
| if (~y_flat).any(): |
| hn += np.bincount(bi[~y_flat], minlength=nbins) |
|
|
| cum_tp = np.cumsum(hp[::-1])[::-1].astype(float) |
| cum_fp = np.cumsum(hn[::-1])[::-1].astype(float) |
| pred_c = cum_tp + cum_fp |
| prec_c = np.where(pred_c > 0, cum_tp / pred_c, 0.0) |
| rec_c = cum_tp / max(tp_total, 1) |
| denom = prec_c + rec_c |
| f1_c = np.where(denom > 0, 2 * prec_c * rec_c / denom, 0.0) |
| b = int(np.argmax(f1_c)) |
| micro_fmax = float(f1_c[b]) |
| t_star = float(edges[b]) |
| precision = float(prec_c[b]) |
| recall = float(rec_c[b]) |
|
|
| |
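    # Per-label F1 at the chosen global threshold t* (averaged into macro-F1).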
| pred_bin = (probs >= t_star).astype(np.float32) |
| label_f1s = [] |
| for j in range(Y_mat.shape[1]): |
| tp = float((pred_bin[:, j] * Y_mat[:, j]).sum()) |
| fp = float((pred_bin[:, j] * (1 - Y_mat[:, j])).sum()) |
| fn = float(((1 - pred_bin[:, j]) * Y_mat[:, j]).sum()) |
| d = 2 * tp + fp + fn |
| label_f1s.append((2 * tp / d) if d > 0 else 0.0) |
| macro_f1 = float(np.mean(label_f1s)) |
|
|
| |
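    # Micro-AUPRC over the flattened (protein, label) matrix; NaN if it cannot be scored.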
| try: |
| micro_auprc = float(average_precision_score( |
| Y_mat.ravel(), probs.ravel(), average="micro" |
| )) |
| except Exception: |
| micro_auprc = float("nan") |
|
|
| |
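    # Fraction of the label space with at least one positive in this taxon.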
| label_coverage = float((Y_mat.sum(axis=0) > 0).mean()) |
|
|
| return { |
| "micro_fmax": round(micro_fmax, 4), |
| "t_star": round(t_star, 3), |
| "precision": round(precision, 4), |
| "recall": round(recall, 4), |
| "macro_f1": round(macro_f1, 4), |
| "micro_auprc": round(micro_auprc, 4), |
| "label_coverage": round(label_coverage, 5), |
| "probs": probs, |
| "Y_mat": Y_mat, |
| } |
|
|
|
|
def compute_cafa_fmax(probs, Y_mat, mlb_classes, go_parents, step=0.05):
    """
    Protein-centric CAFA Fmax. Predictions are propagated up is_a/part_of
    ancestors (within MF) before scoring; ground-truth labels are assumed
    to be ancestor-propagated already.
    """
    if not go_parents:
        return {"cafa_fmax": float("nan"), "t_star": float("nan")}
|
|
| go2idx = {g: i for i, g in enumerate(mlb_classes)} |
| anc_map = {} |
| for gid in mlb_classes: |
| parents = go_parents.get(gid, set()) |
| visited, stack = set(), list(parents) |
| while stack: |
| p = stack.pop() |
| if p not in visited: |
| visited.add(p) |
| stack.extend(go_parents.get(p, set())) |
| anc_map[gid] = {p for p in visited if p in go2idx} |
|
|
| anc_idx = [ |
| np.array([go2idx[a] for a in anc_map.get(g, set())], dtype=np.int64) |
| for g in mlb_classes |
| ] |
|
|
| has_label = Y_mat.sum(axis=1) > 0 |
| p2 = probs[has_label] |
| y2 = Y_mat[has_label] |
| if len(p2) == 0: |
| return {"cafa_fmax": float("nan"), "t_star": float("nan")} |
|
|
| thresholds = np.arange(0.05, 0.96, step) |
| best_f1, best_t = -1.0, 0.5 |
| for t in thresholds: |
| pred_bin = (p2 >= t).astype(np.float32) |
| prop = pred_bin.copy() |
| for j, aidx in enumerate(anc_idx): |
| if len(aidx) == 0: |
| continue |
| mask = pred_bin[:, j] > 0 |
| if mask.any(): |
| prop[np.ix_(np.where(mask)[0], aidx)] = 1.0 |
|
|
| tp_per = (prop * y2).sum(axis=1) |
| pp_per = prop.sum(axis=1) |
| rp_per = y2.sum(axis=1) |
| prec_per = np.where(pp_per > 0, tp_per / pp_per, 0.0) |
| rec_per = np.where(rp_per > 0, tp_per / rp_per, 0.0) |
| has_pred = pp_per > 0 |
| if has_pred.sum() == 0: |
| continue |
| avg_prec = prec_per[has_pred].mean() |
| avg_rec = rec_per.mean() |
| denom = avg_prec + avg_rec |
| f1 = (2 * avg_prec * avg_rec / denom) if denom > 0 else 0.0 |
| if f1 > best_f1: |
| best_f1, best_t = f1, float(t) |
|
|
| return {"cafa_fmax": round(float(best_f1), 4), "t_star": round(best_t, 3)} |
|
|
|
|
| |
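# --- Feature reconstruction ---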
|
|
| def build_feature_matrix(df: pd.DataFrame, in_dim: int, mu: np.ndarray, |
| sd: np.ndarray, supp_cols: list) -> np.ndarray: |
| """ |
| Reconstruct X from a standard embedding parquet. |
| Handles 3 model configs: |
| in_dim=320 β ESM only |
| in_dim=331 β ESM + 11 seq features |
| in_dim=360 β ESM + all 39 supp features |
| """ |
| ESM_DIM = 320 |
| emb_cols = [f"Dim_{i}" for i in range(ESM_DIM)] |
| X_esm = df[emb_cols].to_numpy(np.float32) |
|
|
| if in_dim == ESM_DIM: |
| return X_esm |
|
|
| |
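    # Width of the supplementary block implied by the checkpoint's input dim.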
| n_supp = in_dim - ESM_DIM |
    cols_needed = supp_cols[:n_supp]  # slicing clamps, so overshoot is safe
|
|
| |
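    # mu/sd are stored in supp_cols order; slice to the columns in use.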
| mu_s = mu[:len(cols_needed)] |
| sd_s = sd[:len(cols_needed)] |
|
|
| |
| |
| |
| |
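    # Start each supplementary column at its training mean, then overwrite
    # with observed values where the parquet has the column (NaN -> mean).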
| S = np.tile(mu_s, (len(df), 1)).astype(np.float32) |
| for ci, col in enumerate(cols_needed): |
| if col in df.columns: |
| v = df[col].to_numpy(np.float32) |
| S[:, ci] = np.where(np.isnan(v), mu_s[ci], v) |
|
|
| S_z = (S - mu_s) / (sd_s + 1e-12) |
|
|
| X = np.concatenate([X_esm, S_z], axis=1).astype(np.float32) |
| |
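    # Pad any remaining width with training means; column 0 of the pad is
    # assumed to be the raw (unstandardized) f_af_present flag.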
| if X.shape[1] < in_dim: |
| af_col = np.tile(mu[:in_dim - X.shape[1]], (len(df), 1)).astype(np.float32) |
| if "f_af_present" in df.columns: |
| af_col[:, 0] = df["f_af_present"].to_numpy(np.float32) |
| X = np.concatenate([X, af_col], axis=1) |
| return X[:, :in_dim] |
|
|
|
|
| |
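# --- Main ---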
|
|
| def main(): |
| ap = argparse.ArgumentParser(description="ProtFunc per-taxon generalization eval") |
| ap.add_argument("--checkpoint", required=True, help="Path to .pth checkpoint") |
| ap.add_argument("--thresholds", required=True, help="Path to per-label thresholds JSON") |
| ap.add_argument("--mlb", required=True, help="Path to mlb_public_v1.pkl") |
| ap.add_argument("--taxon_parquet", required=True, help="Parquet with embeddings (prep_taxon.py output)") |
| ap.add_argument("--taxon_name", required=True, help="Short name, e.g. 'mammals', 'fungi'") |
| ap.add_argument("--obo", default="go-basic.obo", help="GO OBO file for CAFA eval") |
| ap.add_argument("--out", default="artifacts/generalization_results.json", |
| help="Output JSON (accumulates across taxa)") |
| ap.add_argument("--insect_log", default="artifacts/protfunc_v3_fixed_log.json", |
| help="Training log for reference insect test Fmax") |
| ap.add_argument("--device", default="auto") |
| args = ap.parse_args() |
|
|
| if args.device == "auto": |
| device = torch.device( |
| "mps" if torch.backends.mps.is_available() else |
| "cuda" if torch.cuda.is_available() else "cpu" |
| ) |
| else: |
| device = torch.device(args.device) |
| print(f"Device: {device}") |
|
|
| |
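    # Checkpoint carries architecture hyperparameters and supp-feature stats.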
| print(f"\nLoading checkpoint: {args.checkpoint}") |
    # weights_only=False: the checkpoint stores metadata (lists/strings) alongside tensors
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
| in_dim = ckpt["in_dim"] |
| hidden = ckpt.get("hidden", 2048) |
| n_blocks = ckpt.get("n_blocks", 4) |
| feature_label = ckpt.get("feature_label", "unknown") |
| mu = np.array(ckpt.get("supp_mu", []), dtype=np.float32) |
| sd = np.array(ckpt.get("supp_sd", []), dtype=np.float32) |
| supp_cols = ckpt.get("supp_cols", []) |
| print(f" in_dim={in_dim} hidden={hidden} n_blocks={n_blocks} feature_label={feature_label}") |
|
|
| mlb = joblib.load(args.mlb) |
| out_dim = len(mlb.classes_) |
|
|
| model = ImprovedResidualMLP(in_dim, out_dim, hidden, n_blocks).to(device) |
| model.load_state_dict(ckpt["model"]) |
| model.eval() |
| print(f" Model loaded: {sum(p.numel() for p in model.parameters()):,} params") |
|
|
| |
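    # Taxon embeddings + labels (standard parquet from prep_taxon.py).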
| print(f"\nLoading taxon data: {args.taxon_parquet}") |
| df = pd.read_parquet(args.taxon_parquet) |
| print(f" {len(df)} proteins") |
|
|
| label_col = "Label_Indices" if "Label_Indices" in df.columns else None |
| if label_col is None: |
| print(" ERROR: taxon parquet has no Label_Indices column β cannot evaluate") |
| return |
|
|
| label_lists = [parse_labels(x) for x in df[label_col]] |
| Y_mat = np.zeros((len(df), out_dim), dtype=np.float32) |
| for r, labs in enumerate(label_lists): |
| for j in labs: |
| if 0 <= j < out_dim: |
| Y_mat[r, j] = 1.0 |
|
|
| n_labeled = int((Y_mat.sum(axis=1) > 0).sum()) |
| print(f" {n_labeled}/{len(df)} proteins have β₯1 GO-MF label") |
| if n_labeled == 0: |
| print(" No labeled proteins β nothing to evaluate.") |
| return |
|
|
| X = build_feature_matrix(df, in_dim, mu, sd, supp_cols) |
| X_tensor = torch.tensor(X, dtype=torch.float32) |
| print(f" Feature matrix: {X.shape}") |
|
|
| |
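    # Core metrics (global-threshold sweep over the full probability matrix).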
| print(f"\nEvaluating '{args.taxon_name}'...") |
| m = compute_metrics(model, X_tensor, Y_mat, device) |
| probs = m.pop("probs") |
| m.pop("Y_mat") |
|
|
| |
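    # CAFA-style protein-centric Fmax with ancestor propagation.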
| go_parents = load_go_parents(Path(args.obo)) |
| cafa = compute_cafa_fmax(probs, Y_mat, mlb.classes_, go_parents) |
|
|
| |
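    # Reference in-distribution Fmax from the insect training log, if available.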
| insect_fmax = None |
| if Path(args.insect_log).exists(): |
| try: |
| with open(args.insect_log) as f: |
| log = json.load(f) |
| for entry in log: |
| if isinstance(entry, dict) and "test_micro" in entry: |
| insect_fmax = round(float(entry["test_micro"]["micro_fmax"]), 4) |
| except Exception: |
| pass |
|
|
| gen_ratio = None |
| if insect_fmax and m["micro_fmax"] > 0: |
| gen_ratio = round(m["micro_fmax"] / insect_fmax, 4) |
|
|
| result = { |
| "n_proteins": len(df), |
| "n_labeled": n_labeled, |
| **m, |
| "cafa_fmax": cafa["cafa_fmax"], |
| "cafa_t_star": cafa["t_star"], |
| "insect_test_fmax": insect_fmax, |
| "generalization_ratio": gen_ratio, |
| "model_checkpoint": Path(args.checkpoint).name, |
| "feature_label": feature_label, |
| "evaluated_at": datetime.now(timezone.utc).isoformat(timespec="seconds"), |
| } |
|
|
| print(f"\n Results for '{args.taxon_name}':") |
| for k, v in result.items(): |
| if not isinstance(v, str) or k == "evaluated_at": |
| print(f" {k:30s}: {v}") |
|
|
| |
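    # Merge into any existing output so one JSON accumulates all taxa.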
| out_path = Path(args.out) |
| existing = {} |
| if out_path.exists(): |
| with open(out_path) as f: |
| existing = json.load(f) |
| existing[args.taxon_name] = result |
    out_path.parent.mkdir(parents=True, exist_ok=True)
| with open(out_path, "w") as f: |
| json.dump(existing, f, indent=2) |
| print(f"\nSaved to {out_path}") |
|
|
| |
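    # Flat CSV summary next to the JSON for quick comparison across taxa.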
| csv_path = out_path.with_suffix(".csv") |
| rows = [] |
| for taxon, vals in existing.items(): |
| row = {"taxon": taxon} |
| for k, v in vals.items(): |
| if isinstance(v, (int, float, str, type(None))): |
| row[k] = v |
| rows.append(row) |
| pd.DataFrame(rows).to_csv(csv_path, index=False) |
| print(f"CSV summary: {csv_path}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|