""" eval_generalization.py — Per-taxon generalization evaluation for ProtFunc ========================================================================= Given a trained checkpoint and a taxon parquet (standard embedding format), computes micro-Fmax, CAFA-Fmax, precision, recall, AUPRC, and per-label coverage metrics. Works with any taxonomic group produced by prep_taxon.py. Usage ----- python scripts/eval_generalization.py \\ --checkpoint artifacts/protfunc_v3_fixed.pth \\ --thresholds artifacts/protfunc_v3_fixed_thresholds.json \\ --mlb "Important Files/mlb_public_v1.pkl" \\ --taxon_parquet artifacts/mammal_embeddings_v3.parquet \\ --taxon_name mammals \\ --obo go-basic.obo \\ --out artifacts/generalization_mammals.json Multiple taxon runs are accumulated: if --out already exists its contents are merged so you can build up a single generalization_results.json across all taxa over time. Output (JSON) ------------- { "mammals": { "n_proteins": 7, "n_labeled": 7, "micro_fmax": 0.82, "t_star": 0.40, "precision": 0.85, "recall": 0.79, "cafa_fmax": 0.91, "macro_f1": 0.74, "micro_auprc": 0.88, "label_coverage": 0.05, # fraction of 8124 GO terms seen in taxon "insect_test_fmax": 0.884, # reference (if insect log available) "generalization_ratio": 0.93, "model_checkpoint": "protfunc_v3_fixed.pth", "feature_label": "C_ESM_all", "evaluated_at": "2026-04-09T20:42:00" } } """ import argparse import ast import json import os import warnings from datetime import datetime, timezone from pathlib import Path import joblib import numpy as np import pandas as pd import torch import torch.nn as nn from sklearn.metrics import average_precision_score warnings.filterwarnings("ignore") # ── Architecture (must match training) ─────────────────────────────────────── class ResBlock(nn.Module): def __init__(self, dim: int, dropout: float): super().__init__() self.net = nn.Sequential( nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim), nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim), ) def forward(self, x): return x + self.net(x) class ImprovedResidualMLP(nn.Module): def __init__(self, in_dim, out_dim=8124, hidden=2048, n_blocks=4, dropout=0.2): super().__init__() self.fc_in = nn.Linear(in_dim, hidden) self.blocks = nn.ModuleList([ResBlock(hidden, dropout) for _ in range(n_blocks)]) self.fc_out = nn.Sequential( nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, out_dim), ) def forward(self, x): h = self.fc_in(x) for blk in self.blocks: h = blk(h) return self.fc_out(h) # ── GO hierarchy ───────────────────────────────────────────────────────────── def load_go_parents(obo_path: Path) -> dict: if not obo_path.exists(): print(f" WARNING: {obo_path} not found — CAFA eval disabled") return {} ns_map, par_map = {}, {} cur_id, cur_ns, cur_par, in_term = None, None, set(), False def flush(): nonlocal cur_id, cur_ns, cur_par if cur_id and cur_ns: ns_map[cur_id] = cur_ns par_map[cur_id] = cur_par cur_id, cur_ns, cur_par = None, None, set() with open(obo_path, encoding="utf-8") as fh: for raw in fh: line = raw.strip() if line == "[Term]": flush(); in_term = True; continue if line.startswith("[") and line != "[Term]": flush(); in_term = False; continue if not in_term: continue if line.startswith("id:"): cur_id = line.split("id:", 1)[1].strip().split()[0] elif line.startswith("namespace:"): cur_ns = line.split("namespace:", 1)[1].strip() elif line.startswith("is_obsolete:") and "true" in line: cur_id = None elif line.startswith("is_a:"): cur_par.add(line.split("is_a:", 1)[1].strip().split()[0]) elif line.startswith("relationship:"): pts = line.split("relationship:", 1)[1].strip().split() if len(pts) >= 2 and pts[0] == "part_of": cur_par.add(pts[1]) flush() mf = {g for g, n in ns_map.items() if n == "molecular_function"} return {g: (par_map[g] & mf) for g in mf} # ── Label parsing ───────────────────────────────────────────────────────────── def parse_labels(x): if x is None: return [] if isinstance(x, (list, np.ndarray)): return [int(v) for v in x] if isinstance(x, str): s = x.strip() if not s or s.lower() == "nan": return [] try: v = ast.literal_eval(s) return [int(i) for i in (v if isinstance(v, (list, tuple)) else [v])] except Exception: return [] return [] # ── Metrics ─────────────────────────────────────────────────────────────────── @torch.no_grad() def compute_metrics(model, X_tensor, Y_mat, device, step=0.02): """ Returns micro-Fmax, AUPRC, macro-F1 (at t_star), and per-label details. Y_mat: (N, C) float32 numpy array of ground-truth binary labels. """ model.eval() batch = 512 all_probs = [] for i in range(0, len(X_tensor), batch): xb = X_tensor[i:i+batch].to(device) all_probs.append(torch.sigmoid(model(xb)).cpu().numpy()) probs = np.concatenate(all_probs, axis=0).astype(np.float32) # ── micro-Fmax (sweep threshold) ───────────────────────────────────────── edges = np.arange(0.0, 1.0 + step, step) nbins = len(edges) hp = np.zeros(nbins, np.int64) hn = np.zeros(nbins, np.int64) tp_total = int(Y_mat.sum()) p_flat = probs.ravel() y_flat = Y_mat.ravel() > 0.5 bi = np.minimum(np.floor(p_flat / step + 1e-9).astype(np.int64), nbins - 1) if y_flat.any(): hp += np.bincount(bi[y_flat], minlength=nbins) if (~y_flat).any(): hn += np.bincount(bi[~y_flat], minlength=nbins) cum_tp = np.cumsum(hp[::-1])[::-1].astype(float) cum_fp = np.cumsum(hn[::-1])[::-1].astype(float) pred_c = cum_tp + cum_fp prec_c = np.where(pred_c > 0, cum_tp / pred_c, 0.0) rec_c = cum_tp / max(tp_total, 1) denom = prec_c + rec_c f1_c = np.where(denom > 0, 2 * prec_c * rec_c / denom, 0.0) b = int(np.argmax(f1_c)) micro_fmax = float(f1_c[b]) t_star = float(edges[b]) precision = float(prec_c[b]) recall = float(rec_c[b]) # ── Macro-F1 at t_star ─────────────────────────────────────────────────── pred_bin = (probs >= t_star).astype(np.float32) label_f1s = [] for j in range(Y_mat.shape[1]): tp = float((pred_bin[:, j] * Y_mat[:, j]).sum()) fp = float((pred_bin[:, j] * (1 - Y_mat[:, j])).sum()) fn = float(((1 - pred_bin[:, j]) * Y_mat[:, j]).sum()) d = 2 * tp + fp + fn label_f1s.append((2 * tp / d) if d > 0 else 0.0) macro_f1 = float(np.mean(label_f1s)) # ── Micro-AUPRC ────────────────────────────────────────────────────────── try: micro_auprc = float(average_precision_score( Y_mat.ravel(), probs.ravel(), average="micro" )) except Exception: micro_auprc = float("nan") # ── Label coverage ─────────────────────────────────────────────────────── label_coverage = float((Y_mat.sum(axis=0) > 0).mean()) return { "micro_fmax": round(micro_fmax, 4), "t_star": round(t_star, 3), "precision": round(precision, 4), "recall": round(recall, 4), "macro_f1": round(macro_f1, 4), "micro_auprc": round(micro_auprc, 4), "label_coverage": round(label_coverage, 5), "probs": probs, "Y_mat": Y_mat, } def compute_cafa_fmax(probs, Y_mat, mlb_classes, go_parents, step=0.05): if not go_parents: return {"cafa_fmax": float("nan"), "t_star": float("nan")} go2idx = {g: i for i, g in enumerate(mlb_classes)} anc_map = {} for gid in mlb_classes: parents = go_parents.get(gid, set()) visited, stack = set(), list(parents) while stack: p = stack.pop() if p not in visited: visited.add(p) stack.extend(go_parents.get(p, set())) anc_map[gid] = {p for p in visited if p in go2idx} anc_idx = [ np.array([go2idx[a] for a in anc_map.get(g, set())], dtype=np.int64) for g in mlb_classes ] has_label = Y_mat.sum(axis=1) > 0 p2 = probs[has_label] y2 = Y_mat[has_label] if len(p2) == 0: return {"cafa_fmax": float("nan"), "t_star": float("nan")} thresholds = np.arange(0.05, 0.96, step) best_f1, best_t = -1.0, 0.5 for t in thresholds: pred_bin = (p2 >= t).astype(np.float32) prop = pred_bin.copy() for j, aidx in enumerate(anc_idx): if len(aidx) == 0: continue mask = pred_bin[:, j] > 0 if mask.any(): prop[np.ix_(np.where(mask)[0], aidx)] = 1.0 tp_per = (prop * y2).sum(axis=1) pp_per = prop.sum(axis=1) rp_per = y2.sum(axis=1) prec_per = np.where(pp_per > 0, tp_per / pp_per, 0.0) rec_per = np.where(rp_per > 0, tp_per / rp_per, 0.0) has_pred = pp_per > 0 if has_pred.sum() == 0: continue avg_prec = prec_per[has_pred].mean() avg_rec = rec_per.mean() denom = avg_prec + avg_rec f1 = (2 * avg_prec * avg_rec / denom) if denom > 0 else 0.0 if f1 > best_f1: best_f1, best_t = f1, float(t) return {"cafa_fmax": round(float(best_f1), 4), "t_star": round(best_t, 3)} # ── Build feature matrix from parquet ──────────────────────────────────────── def build_feature_matrix(df: pd.DataFrame, in_dim: int, mu: np.ndarray, sd: np.ndarray, supp_cols: list) -> np.ndarray: """ Reconstruct X from a standard embedding parquet. Handles 3 model configs: in_dim=320 → ESM only in_dim=331 → ESM + 11 seq features in_dim=360 → ESM + all 39 supp features """ ESM_DIM = 320 emb_cols = [f"Dim_{i}" for i in range(ESM_DIM)] X_esm = df[emb_cols].to_numpy(np.float32) if in_dim == ESM_DIM: return X_esm # Need supplemental features — subset to what the model was trained on n_supp = in_dim - ESM_DIM # 11 or 39 (or 40 if af_present appended separately) cols_needed = supp_cols[:n_supp] if n_supp <= len(supp_cols) else supp_cols # Normalise — mu/sd from checkpoint were computed on insect train set mu_s = mu[:len(cols_needed)] sd_s = sd[:len(cols_needed)] # Impute missing columns with the training mean (so z-score → 0, not (0-mu)/sd). # This matters for taxon parquets that lack seq/AF features (e.g. mammal_embeddings_v3 # only stores ESM dims + f_af_present, so all other supp cols would be 0 → extreme # z-scores without this correction). S = np.tile(mu_s, (len(df), 1)).astype(np.float32) # start at training mean for ci, col in enumerate(cols_needed): if col in df.columns: v = df[col].to_numpy(np.float32) S[:, ci] = np.where(np.isnan(v), mu_s[ci], v) # NaN → mean S_z = (S - mu_s) / (sd_s + 1e-12) X = np.concatenate([X_esm, S_z], axis=1).astype(np.float32) # Some checkpoints include af_present as a separate flag appended after supp if X.shape[1] < in_dim: af_col = np.tile(mu[:in_dim - X.shape[1]], (len(df), 1)).astype(np.float32) if "f_af_present" in df.columns: af_col[:, 0] = df["f_af_present"].to_numpy(np.float32) X = np.concatenate([X, af_col], axis=1) return X[:, :in_dim] # ── Main ────────────────────────────────────────────────────────────────────── def main(): ap = argparse.ArgumentParser(description="ProtFunc per-taxon generalization eval") ap.add_argument("--checkpoint", required=True, help="Path to .pth checkpoint") ap.add_argument("--thresholds", required=True, help="Path to per-label thresholds JSON") ap.add_argument("--mlb", required=True, help="Path to mlb_public_v1.pkl") ap.add_argument("--taxon_parquet", required=True, help="Parquet with embeddings (prep_taxon.py output)") ap.add_argument("--taxon_name", required=True, help="Short name, e.g. 'mammals', 'fungi'") ap.add_argument("--obo", default="go-basic.obo", help="GO OBO file for CAFA eval") ap.add_argument("--out", default="artifacts/generalization_results.json", help="Output JSON (accumulates across taxa)") ap.add_argument("--insect_log", default="artifacts/protfunc_v3_fixed_log.json", help="Training log for reference insect test Fmax") ap.add_argument("--device", default="auto") args = ap.parse_args() if args.device == "auto": device = torch.device( "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu" ) else: device = torch.device(args.device) print(f"Device: {device}") # ── Load checkpoint ────────────────────────────────────────────────────── print(f"\nLoading checkpoint: {args.checkpoint}") ckpt = torch.load(args.checkpoint, map_location="cpu") in_dim = ckpt["in_dim"] hidden = ckpt.get("hidden", 2048) n_blocks = ckpt.get("n_blocks", 4) feature_label = ckpt.get("feature_label", "unknown") mu = np.array(ckpt.get("supp_mu", []), dtype=np.float32) sd = np.array(ckpt.get("supp_sd", []), dtype=np.float32) supp_cols = ckpt.get("supp_cols", []) print(f" in_dim={in_dim} hidden={hidden} n_blocks={n_blocks} feature_label={feature_label}") mlb = joblib.load(args.mlb) out_dim = len(mlb.classes_) model = ImprovedResidualMLP(in_dim, out_dim, hidden, n_blocks).to(device) model.load_state_dict(ckpt["model"]) model.eval() print(f" Model loaded: {sum(p.numel() for p in model.parameters()):,} params") # ── Load taxon data ────────────────────────────────────────────────────── print(f"\nLoading taxon data: {args.taxon_parquet}") df = pd.read_parquet(args.taxon_parquet) print(f" {len(df)} proteins") label_col = "Label_Indices" if "Label_Indices" in df.columns else None if label_col is None: print(" ERROR: taxon parquet has no Label_Indices column — cannot evaluate") return label_lists = [parse_labels(x) for x in df[label_col]] Y_mat = np.zeros((len(df), out_dim), dtype=np.float32) for r, labs in enumerate(label_lists): for j in labs: if 0 <= j < out_dim: Y_mat[r, j] = 1.0 n_labeled = int((Y_mat.sum(axis=1) > 0).sum()) print(f" {n_labeled}/{len(df)} proteins have ≥1 GO-MF label") if n_labeled == 0: print(" No labeled proteins — nothing to evaluate.") return X = build_feature_matrix(df, in_dim, mu, sd, supp_cols) X_tensor = torch.tensor(X, dtype=torch.float32) print(f" Feature matrix: {X.shape}") # ── Evaluate ───────────────────────────────────────────────────────────── print(f"\nEvaluating '{args.taxon_name}'...") m = compute_metrics(model, X_tensor, Y_mat, device) probs = m.pop("probs") m.pop("Y_mat") # CAFA-style go_parents = load_go_parents(Path(args.obo)) cafa = compute_cafa_fmax(probs, Y_mat, mlb.classes_, go_parents) # Reference insect Fmax insect_fmax = None if Path(args.insect_log).exists(): try: with open(args.insect_log) as f: log = json.load(f) for entry in log: if isinstance(entry, dict) and "test_micro" in entry: insect_fmax = round(float(entry["test_micro"]["micro_fmax"]), 4) except Exception: pass gen_ratio = None if insect_fmax and m["micro_fmax"] > 0: gen_ratio = round(m["micro_fmax"] / insect_fmax, 4) result = { "n_proteins": len(df), "n_labeled": n_labeled, **m, "cafa_fmax": cafa["cafa_fmax"], "cafa_t_star": cafa["t_star"], "insect_test_fmax": insect_fmax, "generalization_ratio": gen_ratio, "model_checkpoint": Path(args.checkpoint).name, "feature_label": feature_label, "evaluated_at": datetime.now(timezone.utc).isoformat(timespec="seconds"), } print(f"\n Results for '{args.taxon_name}':") for k, v in result.items(): if not isinstance(v, str) or k == "evaluated_at": print(f" {k:30s}: {v}") # ── Write output (merge with existing) ─────────────────────────────────── out_path = Path(args.out) existing = {} if out_path.exists(): with open(out_path) as f: existing = json.load(f) existing[args.taxon_name] = result out_path.parent.mkdir(exist_ok=True) with open(out_path, "w") as f: json.dump(existing, f, indent=2) print(f"\nSaved to {out_path}") # Also write a flat CSV for easy inspection across taxa csv_path = out_path.with_suffix(".csv") rows = [] for taxon, vals in existing.items(): row = {"taxon": taxon} for k, v in vals.items(): if isinstance(v, (int, float, str, type(None))): row[k] = v rows.append(row) pd.DataFrame(rows).to_csv(csv_path, index=False) print(f"CSV summary: {csv_path}") if __name__ == "__main__": main()