"""
threshold_comparison.py
=======================
Compare current MF thresholding vs a stricter precision-first alternative on
the ProtFunc v3 pipeline.

Strategies:
  A. Current thresholds, loaded from ``CURRENT_THRESH``
  B. Precision-first MF thresholds (F-beta with beta=0.5, swept on the
     validation split) + IC scaling for the top-25 most common MF terms
  C. B + novelty gating on the most novel proteins (bottom similarity quantile)

The script intentionally evaluates only molecular-function labels on direct
annotations. It keeps non-MF thresholds unchanged in the saved JSON and
reports:
  - overall metrics on a random test subset
  - novelty-subset metrics on the bottom-20% KNN-similarity proteins

Outputs:
  artifacts/graph_hpo/precision_ic_thresholds_graph_hpo.json  (PREC_THRESH_OUT)
  artifacts/graph_hpo/threshold_comparison_graph_hpo.json     (OUT_PATH)
"""
import ast
import json
import math
import time
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import joblib
import torch
import torch.nn as nn

warnings.filterwarnings("ignore")

# Project root: two levels above this file (presumably the script lives in a
# sub-directory such as scripts/ — confirm if the layout changes).
BASE = Path(__file__).parent.parent
ART = BASE / "artifacts"
IMPORTANT = BASE / "Important Files"
DATA_BASE = IMPORTANT / "merged_full_struct.parquet"
DATA_SUPP = IMPORTANT / "merged_full_struct_with_features.parquet"
MLB_PATH = IMPORTANT / "mlb_public_v1.pkl"
SPLITS_NPZ = ART / "splits" / "splits_n250000_seed42.npz"
OBO_PATH = BASE / "go-basic.obo"
CKPT_PATH = ART / "graph_hpo" / "graph_hpo_best.pth"
CURRENT_THRESH = ART / "graph_hpo" / "graph_hpo_best_thresholds.json"
OUT_PATH = ART / "graph_hpo" / "threshold_comparison_graph_hpo.json"
PREC_THRESH_OUT = ART / "graph_hpo" / "precision_ic_thresholds_graph_hpo.json"

SUBSET_SIZE = 2000    # test proteins
TRAIN_KNN = 5000      # training proteins for KNN reference
TOP_COMMON = 25       # top-N by frequency for IC scaling
KNN_K = 10
NOVELTY_Q = 0.20      # bottom quantile of proteins treated as "novel"
NOVELTY_HI_T = 0.996  # ceiling for the most novel proteins (also caps IC scaling)
SEED = 42

rng = np.random.default_rng(SEED)


# ─── Architecture ─────────────────────────────────────────────────────────────
# NOTE: module structure/attribute names must stay exactly as below so that
# the saved checkpoint's state_dict keys (fc_in.*, blocks.N.net.K.*, fc_out.K.*)
# load without remapping.
class ResBlock(nn.Module):
    """Residual block: ``x + net(x)``.

    ``net`` is two repetitions of BatchNorm1d -> ReLU -> Dropout -> Linear,
    each Linear mapping ``dim -> dim`` so the skip connection is shape-safe.
    """

    def __init__(self, dim, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim),
            nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim),
        )

    def forward(self, x):
        return x + self.net(x)


class ImprovedResidualMLP(nn.Module):
    """Input projection -> ``n_blocks`` ResBlocks -> BN/ReLU/Dropout/Linear head.

    Returns raw logits of width ``out_dim``; callers apply the sigmoid.
    """

    def __init__(self, in_dim, out_dim=8124, hidden=2048, n_blocks=4, dropout=0.2):
        super().__init__()
        self.fc_in = nn.Linear(in_dim, hidden)
        self.blocks = nn.ModuleList([ResBlock(hidden, dropout) for _ in range(n_blocks)])
        self.fc_out = nn.Sequential(
            nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        h = self.fc_in(x)
        for b in self.blocks:
            h = b(h)
        return self.fc_out(h)


# ─── GO hierarchy ─────────────────────────────────────────────────────────────
def load_go_hierarchy(obo_path):
    """Parse a GO OBO file and return the molecular-function sub-DAG.

    Only ``[Term]`` stanzas are read; terms flagged ``is_obsolete: true`` are
    dropped. Parents are collected from ``is_a:`` lines and from
    ``relationship: part_of`` lines.

    Args:
        obo_path: path to the OBO file (read as UTF-8).

    Returns:
        dict mapping each non-obsolete ``molecular_function`` term ID to the
        set of its direct parents that are themselves MF terms.
    """
    ns_map, par_map = {}, {}
    cur_id, cur_ns, cur_par, in_term = None, None, set(), False

    def flush():
        # Commit the stanza being parsed (if it had both an id and a namespace)
        # and reset the accumulators for the next one.
        nonlocal cur_id, cur_ns, cur_par
        if cur_id and cur_ns:
            ns_map[cur_id] = cur_ns
            par_map[cur_id] = set(cur_par)
        cur_id, cur_ns, cur_par = None, None, set()

    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can fail
    # on non-ASCII characters in term definitions.
    with open(obo_path, encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if line == "[Term]":
                flush()
                in_term = True
                continue
            if line.startswith("[") and line != "[Term]":
                # Any other stanza type (e.g. [Typedef]) ends the current term.
                flush()
                in_term = False
                continue
            if not in_term:
                continue
            if line.startswith("id:"):
                cur_id = line.split("id:", 1)[1].strip().split()[0]
            elif line.startswith("namespace:"):
                cur_ns = line.split("namespace:", 1)[1].strip()
            elif "is_obsolete:" in line and "true" in line:
                # Clearing the id makes flush() discard this stanza.
                cur_id = None
            elif line.startswith("is_a:"):
                cur_par.add(line.split("is_a:", 1)[1].strip().split()[0])
            elif line.startswith("relationship:"):
                pts = line.split("relationship:", 1)[1].strip().split()
                if len(pts) >= 2 and pts[0] == "part_of":
                    cur_par.add(pts[1])
    flush()

    mf = {g for g, n in ns_map.items() if n == "molecular_function"}
    return {g: (par_map.get(g, set()) & mf) for g in mf}
# ─── Label parsing + feature assembly ─────────────────────────────────────────
def parse_labels(x):
    """Normalise one ``Label_Indices`` cell into a list of ints.

    Accepts None, a list/ndarray of numbers, or a string holding a Python
    literal (``"[1, 2]"`` or ``"3"``). Empty strings, ``"nan"``, unparseable
    strings and any other type (e.g. a float NaN cell) yield ``[]``.
    """
    if x is None:
        return []
    if isinstance(x, (list, np.ndarray)):
        return [int(v) for v in x]
    if isinstance(x, str):
        s = x.strip()
        if not s or s.lower() == "nan":
            return []
        try:
            v = ast.literal_eval(s)
            return [int(i) for i in (v if isinstance(v, (list, tuple)) else [v])]
        except Exception:
            # Deliberate best-effort: malformed cells are treated as unlabeled.
            return []
    return []


def build_inputs(df_base, df_supp, indices, ckpt, in_dim=None):
    """Assemble the model input matrix for the rows at ``indices``.

    Layout: the ``Dim_*`` columns of ``df_base``, then (only if the checkpoint
    lists ``supp_cols``) those supplementary columns z-scored with the
    checkpoint's ``supp_mu`` / ``supp_sd``:

    * ``n_supp <= len(supp_cols)`` -> only the first ``n_supp`` z-scored columns;
    * ``n_supp == len(supp_cols) + 1`` -> all of them plus ``f_af_present``.

    Args:
        df_base, df_supp: row-aligned frames (positional ``iloc`` indexing).
        indices: positional row indices.
        ckpt: checkpoint dict (``supp_cols``, ``supp_mu``, ``supp_sd``,
            optionally ``in_dim``).
        in_dim: optional expected total width; overrides ``ckpt["in_dim"]``.
            Pass the width of the model's ``fc_in`` weight so the layout always
            matches the network, even if the checkpoint lacks ``in_dim``.

    Returns:
        float32 array of shape (len(indices), width).

    Raises:
        ValueError: if the requested width cannot be built from the columns.
    """
    emb_cols = [c for c in df_base.columns if c.startswith("Dim_")]
    x_base = df_base.iloc[indices][emb_cols].to_numpy(np.float32)

    supp_cols = ckpt.get("supp_cols", [])
    if not supp_cols:
        return x_base

    mu = np.asarray(ckpt["supp_mu"], dtype=np.float32)
    sd = np.asarray(ckpt["supp_sd"], dtype=np.float32)
    s = df_supp.iloc[indices][supp_cols].to_numpy(np.float32)
    s_z = (s - mu) / (sd + 1e-12)

    if in_dim is None:
        in_dim = ckpt.get("in_dim")
    n_supp_used = in_dim - x_base.shape[1] if in_dim else len(supp_cols)
    if n_supp_used < 0:
        # A negative count would otherwise slice columns off the END silently.
        raise ValueError(
            f"Checkpoint in_dim={in_dim} is smaller than the embedding width "
            f"{x_base.shape[1]}"
        )

    # esm_seq / partial supp: use only the first n_supp_used cols
    if n_supp_used <= len(supp_cols):
        return np.concatenate([x_base, s_z[:, :n_supp_used]], axis=1).astype(np.float32)

    # esm_all with the f_af_present flag appended as the last column
    if n_supp_used == len(supp_cols) + 1:
        af_present = df_supp.iloc[indices]["f_af_present"].to_numpy(np.float32).reshape(-1, 1)
        return np.concatenate([x_base, s_z, af_present], axis=1).astype(np.float32)

    raise ValueError(
        f"Unsupported input shape for checkpoint: in_dim={in_dim} "
        f"vs base={x_base.shape[1]} supp={len(supp_cols)}"
    )


# ─── Precision-biased threshold sweep (MF only) ───────────────────────────────
def compute_fbeta_thresholds(probs, true, beta=0.5, steps=None, min_support=10, floor=0.90):
    """
    Per label: find threshold maximising F-beta on the high-threshold regime.
    For this v3 MF model, useful separation happens around 0.80+, so sweeping
    low thresholds only reproduces the overprediction failure mode.

    Args:
        probs: (n_samples, n_labels) predicted probabilities.
        true: (n_samples, n_labels) 0/1 targets aligned with ``probs``.
        beta: F-beta weight; beta < 1 favours precision.
        steps: candidate thresholds, tried in order; defaults to a 0.01 grid
            from 0.80 to 0.98 followed by a 0.002 grid starting at 0.982.
        min_support: labels with fewer positives than this keep ``floor``.
        floor: threshold used for low-support labels.

    Returns:
        float32 array (n_labels,). Ties in F-beta (within 1e-12) are broken in
        favour of higher precision; otherwise the first best step wins.
    """
    if steps is None:
        coarse = np.arange(0.80, 0.981, 0.01, dtype=np.float32)
        fine = np.arange(0.982, 0.996, 0.002, dtype=np.float32)
        steps = np.concatenate([coarse, fine]).astype(np.float32)

    n_labels = probs.shape[1]
    thr = np.full(n_labels, floor, dtype=np.float32)
    b2 = beta ** 2

    for j in range(n_labels):
        pj = probs[:, j]
        tj = true[:, j]
        if tj.sum() < min_support:
            continue
        best_fb, best_t, best_prec = -1.0, floor, -1.0
        for t in steps:
            pred = (pj >= t).astype(np.float32)
            tp = (pred * tj).sum()
            fp = (pred * (1 - tj)).sum()
            fn = ((1 - pred) * tj).sum()
            prec = tp / (tp + fp + 1e-9)
            denom = (1 + b2) * tp + b2 * fn + fp
            fb = ((1 + b2) * tp / denom) if denom > 0 else 0.0
            if fb > best_fb or (abs(fb - best_fb) < 1e-12 and prec > best_prec):
                best_fb, best_t, best_prec = float(fb), float(t), float(prec)
        thr[j] = best_t
    return thr


# ─── IC-scaled thresholds for top-N most common terms ─────────────────────────
def ic_scaled_thresholds(base_thr, label_freq, mlb_classes, mf_idx, top_n=25):
    """
    For the top_n most annotated MF GO terms, raise threshold proportionally
    to how broad the term is (low IC = high annotation frequency = raise
    threshold more).

    IC of term i = -log2(freq_i / total_annotations), total over all labels.
    New threshold = base * (1 + alpha*(1 - IC_i/max_IC)), capped at
    NOVELTY_HI_T, where max_IC is the largest IC among the selected terms.
    alpha=0.40.

    Only MF labels with a non-zero training frequency are candidates, so
    non-MF and unseen labels are never modified (even when fewer than
    ``top_n`` MF terms occur). ``top_n <= 0`` leaves all thresholds unchanged.

    Returns:
        (thr, adjustments) where ``thr`` is a copy of ``base_thr`` and
        ``adjustments`` is a list of (label_idx, go_id, train_freq, old_thr,
        new_thr) tuples for the rescaled terms.
    """
    thr = base_thr.copy()
    total = label_freq.sum() + 1e-9
    ic = np.zeros(len(mlb_classes))
    for j in range(len(mlb_classes)):
        if label_freq[j] > 0:
            ic[j] = -math.log2(label_freq[j] / total)

    # Candidates: MF labels observed in training, ranked by ascending frequency.
    mf_arr = np.asarray(mf_idx, dtype=np.int64)
    cand = mf_arr[label_freq[mf_arr] > 0]
    cand = cand[np.argsort(label_freq[cand], kind="stable")]
    top_idx = cand[-top_n:] if top_n > 0 else cand[:0]

    adjustments = []
    if top_idx.size:
        max_ic = ic[top_idx].max()
        alpha = 0.40
        for j in top_idx:
            ic_norm = ic[j] / (max_ic + 1e-9)
            scale = 1.0 + alpha * (1.0 - ic_norm)
            new_t = min(NOVELTY_HI_T, float(base_thr[j]) * scale)
            adjustments.append((j, mlb_classes[j], int(label_freq[j]), float(base_thr[j]), new_t))
            thr[j] = new_t

    print(f"  IC-scaled top-{top_n} most frequent MF terms:")
    for j, gid, freq, old_t, new_t in sorted(adjustments, key=lambda x: -x[2])[:8]:
        print(f"    {gid}  freq={freq:,}  base={old_t:.3f} → {new_t:.3f}")
    return thr, adjustments


# ─── KNN novelty ──────────────────────────────────────────────────────────────
def build_knn_ref(embs):
    """L2-normalise for cosine similarity."""
    norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-9
    return embs / norms  # (N, dim)


def compute_novelty_sim(query_emb, knn_ref, k=KNN_K):
    """Mean cosine similarity to the top-k neighbours (higher = more familiar).

    ``knn_ref`` must already be L2-normalised (see ``build_knn_ref``). ``k`` is
    clamped to the reference size so small references do not raise.
    """
    q = query_emb / (np.linalg.norm(query_emb) + 1e-9)  # (dim,)
    sims = knn_ref @ q                                   # (N,)
    k = min(k, sims.shape[0])
    top_k = np.partition(sims, -k)[-k:]
    return float(top_k.mean())


def apply_novelty_gate(base_thr, sim, lo, sim_min, hi_t=NOVELTY_HI_T):
    """
    Quantile-gated thresholding: proteins above the novelty cutoff keep the
    precision+IC thresholds; proteins below it are pushed toward hi_t based on
    relative novelty.

    The push is linear: alpha = (lo - sim) / (lo - sim_min), clipped to [0, 1],
    so a protein at ``sim_min`` gets (almost exactly) ``hi_t`` everywhere.
    """
    if sim >= lo:
        return base_thr
    alpha = min(1.0, max(0.0, (lo - sim) / (lo - sim_min + 1e-9)))
    return base_thr + alpha * (hi_t - base_thr)


# ─── Evaluation ───────────────────────────────────────────────────────────────
def evaluate(probs, true_y, thr_arr, mf_idx, embs=None, knn_ref=None,
             novelty_cut=None, novelty_min=None):
    """
    Compute per-protein predictions then aggregate metrics over MF labels.
    If embs, knn_ref, novelty_cut and novelty_min are all supplied, apply
    novelty gating per protein.
    true_y: (N, n_labels) from Label_Indices (direct labels, not propagated).

    Returns a dict of micro precision/recall/F1 (rounded to 4 dp) and
    statistics of the per-protein prediction count.
    """
    n = probs.shape[0]
    gated = not any(a is None for a in (embs, knn_ref, novelty_cut, novelty_min))

    # Loop invariants: MF slices of probs/targets/thresholds.
    pv_all = probs[:, mf_idx]
    tv_all = true_y[:, mf_idx] > 0
    thr_mf = thr_arr[mf_idx]

    # Integer counters: exact regardless of dataset size (float32 sums stop
    # being exact beyond 2**24).
    tp_tot = fp_tot = fn_tot = 0
    ndl = np.zeros(n, dtype=np.int64)
    for i in range(n):
        thr = thr_mf
        if gated:
            sim = compute_novelty_sim(embs[i], knn_ref)
            thr = apply_novelty_gate(thr_mf, sim, novelty_cut, novelty_min)
        pred = pv_all[i] >= thr
        tv = tv_all[i]
        tp_tot += int(np.count_nonzero(pred & tv))
        fp_tot += int(np.count_nonzero(pred & ~tv))
        fn_tot += int(np.count_nonzero(~pred & tv))
        ndl[i] = np.count_nonzero(pred)

    prec = tp_tot / (tp_tot + fp_tot + 1e-9)
    rec = tp_tot / (tp_tot + fn_tot + 1e-9)
    f1 = 2 * prec * rec / (prec + rec + 1e-9)
    return {
        "micro_precision": round(float(prec), 4),
        "micro_recall": round(float(rec), 4),
        "micro_f1": round(float(f1), 4),
        "mean_preds_per_protein": round(float(ndl.mean()), 2),
        "median_preds": float(np.median(ndl)),
        "pct_tight_le5": round(float((ndl <= 5).mean() * 100), 1),
        "pct_noisy_gt15": round(float((ndl > 15).mean() * 100), 1),
        "pct_zero_preds": round(float((ndl == 0).mean() * 100), 1),
        "coverage_pct": round(float((ndl > 0).mean() * 100), 1),
    }


def subset_metrics(probs, true_y, thr_arr, mf_idx, subset_mask, embs=None, knn_ref=None,
                   novelty_cut=None, novelty_min=None):
    """Run ``evaluate`` on the rows selected by the boolean ``subset_mask``."""
    idx = np.flatnonzero(subset_mask)
    return evaluate(
        probs[idx], true_y[idx], thr_arr, mf_idx,
        embs=None if embs is None else embs[idx],
        knn_ref=knn_ref, novelty_cut=novelty_cut, novelty_min=novelty_min,
    )


def thr_stats(thr, idx):
    """Summary statistics of ``thr`` restricted to label indices ``idx``."""
    v = thr[idx]
    return {
        "mean": round(float(v.mean()), 4),
        "min": round(float(v.min()), 4),
        "max": round(float(v.max()), 4),
        "pct_lt03": round(float((v < 0.3).mean() * 100), 1),
        "pct_lt05": round(float((v < 0.5).mean() * 100), 1),
        "pct_ge07": round(float((v >= 0.7).mean() * 100), 1),
    }


def _metric_deltas(new, old):
    """Headline metric differences ``new - old`` (precision/recall/F1/mean preds)."""
    return {
        "precision_delta": round(new["micro_precision"] - old["micro_precision"], 4),
        "recall_delta": round(new["micro_recall"] - old["micro_recall"], 4),
        "f1_delta": round(new["micro_f1"] - old["micro_f1"], 4),
        "mean_preds_delta": round(
            new["mean_preds_per_protein"] - old["mean_preds_per_protein"], 2),
    }


# ─── Main ─────────────────────────────────────────────────────────────────────
def main():
    """Run strategies A/B/C, save B's thresholds and the comparison JSON."""
    t0 = time.time()
    device = torch.device("cpu")

    # ── Data ───────────────────────────────────────────────────────────────────
    print("Loading data...")
    df_base = pd.read_parquet(DATA_BASE)
    df_supp = pd.read_parquet(DATA_SUPP)
    mlb = joblib.load(MLB_PATH)  # NOTE: unpickles; only load trusted files
    n_labels = len(mlb.classes_)
    emb_cols = [c for c in df_base.columns if c.startswith("Dim_")]

    splits = np.load(SPLITS_NPZ, allow_pickle=True)
    train_idx = splits["train_idx"]
    val_idx = splits["val_idx"]
    test_idx = splits["test_idx"]
    # Clamp so a small split doesn't make choice(replace=False) raise; when the
    # split is large enough the draws are identical to the unclamped call.
    test_sub = rng.choice(test_idx, size=min(SUBSET_SIZE, len(test_idx)), replace=False)
    train_sub = rng.choice(train_idx, size=min(TRAIN_KNN, len(train_idx)), replace=False)
    print(f"  Splits: train={len(train_idx)}, val={len(val_idx)}, "
          f"test subset={len(test_sub)}, knn_ref={len(train_sub)}")

    # ── Label matrix helper ────────────────────────────────────────────────────
    def make_Y(indices):
        """Dense 0/1 direct-label matrix for the given positional rows."""
        Y = np.zeros((len(indices), n_labels), dtype=np.float32)
        for r, row in enumerate(df_supp.iloc[indices]["Label_Indices"]):
            for v in parse_labels(row):
                if 0 <= v < n_labels:
                    Y[r, v] = 1.0
        return Y

    # ── GO hierarchy → MF indices ──────────────────────────────────────────────
    print("Loading GO hierarchy...")
    go_parents = load_go_hierarchy(OBO_PATH)
    mf_go_ids = set(go_parents.keys())
    mf_idx = np.array([j for j, c in enumerate(mlb.classes_) if c in mf_go_ids])
    print(f"  MF labels in MLB: {len(mf_idx)}")

    # ── Model ──────────────────────────────────────────────────────────────────
    print(f"Loading model {CKPT_PATH.name}...")
    # weights_only=False unpickles arbitrary objects: only use trusted checkpoints.
    ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
    state = ckpt.get("model", ckpt)
    # The true input width comes from the weights, not from checkpoint metadata.
    in_dim = state["fc_in.weight"].shape[1]
    model = ImprovedResidualMLP(
        in_dim=in_dim, out_dim=n_labels,
        hidden=ckpt.get("hidden", 2048), n_blocks=ckpt.get("n_blocks", 4),
    ).to(device)
    model.load_state_dict(state)
    model.eval()
    print(f"  in_dim={in_dim}")

    # ── Inference ──────────────────────────────────────────────────────────────
    def run_inference(indices, desc):
        """Return (sigmoid probabilities, raw Dim_* embeddings) for ``indices``."""
        full_x = build_inputs(df_base, df_supp, indices, ckpt, in_dim=in_dim)
        esm_x = df_base.iloc[indices][emb_cols].to_numpy(np.float32)
        result = []
        with torch.no_grad():
            for s in range(0, len(full_x), 512):
                xb = torch.tensor(full_x[s:s + 512]).to(device)
                result.append(torch.sigmoid(model(xb)).cpu().numpy())
        print(f"  {desc}: {len(indices)} proteins done")
        return np.concatenate(result, axis=0), esm_x

    print("Running inference...")
    val_probs, _ = run_inference(val_idx, "val")
    test_probs, test_embs = run_inference(test_sub, "test subset")
    Y_val = make_Y(val_idx)
    Y_test = make_Y(test_sub)
    print(f"  Mean direct labels/protein in test subset: "
          f"{Y_test[:, mf_idx].sum(1).mean():.2f}")

    # ── KNN reference ──────────────────────────────────────────────────────────
    print("Building KNN reference...")
    train_embs = df_base.iloc[train_sub][emb_cols].to_numpy(np.float32)
    knn_ref = build_knn_ref(train_embs)  # (n_train_sub, n_emb_cols)

    # ── Label frequency from training set ──────────────────────────────────────
    print("Computing label frequencies...")
    label_freq = np.zeros(n_labels, dtype=np.float32)
    for row in df_supp.iloc[train_idx]["Label_Indices"]:
        for v in parse_labels(row):
            # Same bounds check as make_Y: a negative index must not wrap around.
            if 0 <= v < n_labels:
                label_freq[v] += 1

    # ──────────────────────────────────────────────────────────────────────────
    # STRATEGY A: current v3 thresholds
    # ──────────────────────────────────────────────────────────────────────────
    print("\n=== Strategy A: current ProtFunc v3 thresholds ===")
    with open(CURRENT_THRESH) as f:
        curr_dict = json.load(f)
    # Labels missing from the JSON default to 0.5.
    thr_A = np.full(n_labels, 0.5, dtype=np.float32)
    for k, v in curr_dict.items():
        thr_A[int(k)] = float(v)

    metrics_A = evaluate(test_probs, Y_test, thr_A, mf_idx)
    print(json.dumps(metrics_A, indent=2))

    # ──────────────────────────────────────────────────────────────────────────
    # STRATEGY B: precision-biased (F-β=0.5) + IC-scaled top-25
    # ──────────────────────────────────────────────────────────────────────────
    print("\n=== Strategy B: MF precision thresholds + IC-scaled top-25 ===")
    print("  Sweeping MF thresholds on val set in the high-confidence regime...")
    thr_B = thr_A.copy()  # non-MF labels keep their current thresholds
    thr_B_mf = compute_fbeta_thresholds(val_probs[:, mf_idx], Y_val[:, mf_idx],
                                        beta=0.5, floor=0.90)
    thr_B[mf_idx] = thr_B_mf
    thr_B, ic_adj = ic_scaled_thresholds(thr_B, label_freq, mlb.classes_, mf_idx,
                                         top_n=TOP_COMMON)

    metrics_B = evaluate(test_probs, Y_test, thr_B, mf_idx)
    print(json.dumps(metrics_B, indent=2))

    # Save
    thr_B_dict = {str(j): round(float(thr_B[j]), 4) for j in range(n_labels)}
    with open(PREC_THRESH_OUT, "w") as f:
        json.dump(thr_B_dict, f)
    print(f"  Saved to {PREC_THRESH_OUT.name}")

    # ──────────────────────────────────────────────────────────────────────────
    # STRATEGY C: novelty-gated (B thresholds + per-protein KNN gate)
    # ──────────────────────────────────────────────────────────────────────────
    print("\n=== Strategy C: novelty-gated (B + KNN on ESM embeddings) ===")
    test_sims = np.array([compute_novelty_sim(test_embs[i], knn_ref)
                          for i in range(len(test_embs))], dtype=np.float32)
    # NOTE(review): the cut is a quantile of the test subset itself
    # (transductive); a deployment gate would need a fixed reference cut.
    novelty_cut = float(np.quantile(test_sims, NOVELTY_Q))
    novelty_min = float(test_sims.min())
    novelty_mask = test_sims <= novelty_cut
    print(f"  Novelty gate: bottom {int(NOVELTY_Q * 100)}% proteins by KNN similarity")
    print(f"  Similarity stats: min={test_sims.min():.3f}  cut={novelty_cut:.3f}  "
          f"mean={test_sims.mean():.3f}  max={test_sims.max():.3f}")

    metrics_C = evaluate(test_probs, Y_test, thr_B, mf_idx, embs=test_embs,
                         knn_ref=knn_ref, novelty_cut=novelty_cut,
                         novelty_min=novelty_min)
    print(json.dumps(metrics_C, indent=2))

    novelty_subset_A = subset_metrics(test_probs, Y_test, thr_A, mf_idx, novelty_mask)
    novelty_subset_C = subset_metrics(
        test_probs, Y_test, thr_B, mf_idx, novelty_mask,
        embs=test_embs, knn_ref=knn_ref, novelty_cut=novelty_cut,
        novelty_min=novelty_min,
    )

    # ──────────────────────────────────────────────────────────────────────────
    # Compile and save
    # ──────────────────────────────────────────────────────────────────────────
    winner = max(
        [("A_current", metrics_A), ("B_precision", metrics_B), ("C_novelty", metrics_C)],
        key=lambda x: x[1]["micro_f1"],
    )[0]

    results = {
        "metadata": {
            "model": CKPT_PATH.name,
            "test_subset_size": int(len(test_sub)),
            "train_knn_ref_size": int(len(train_sub)),
            "top_common_ic_scaled": TOP_COMMON,
            "knn_k": KNN_K,
            "current_thresholds": CURRENT_THRESH.name,
            "novelty_quantile": NOVELTY_Q,
            "novelty_subset_size": int(novelty_mask.sum()),
            "novelty_similarity_cut": round(novelty_cut, 6),
            "novelty_hi_thr": NOVELTY_HI_T,
            "n_mf_labels": int(len(mf_idx)),
            "mean_direct_labels_per_protein": round(
                float(Y_test[:, mf_idx].sum(1).mean()), 2),
        },
        "threshold_distributions": {
            "A_current_v3": thr_stats(thr_A, mf_idx),
            "B_precision_ic": thr_stats(thr_B, mf_idx),
        },
        "ic_scaled_top25": [
            {"label_idx": int(j), "go_id": gid, "train_freq": freq,
             "old_thr": round(old, 4), "new_thr": round(new, 4)}
            for j, gid, freq, old, new in sorted(ic_adj, key=lambda x: -x[2])[:TOP_COMMON]
        ],
        "metrics": {
            "A_current_thresholds": metrics_A,
            "B_precision_ic": metrics_B,
            "C_novelty_gated": metrics_C,
            "novelty_subset": {
                "A_current_thresholds": novelty_subset_A,
                "C_novelty_gated": novelty_subset_C,
            },
        },
        "deltas": {
            "A_vs_B": _metric_deltas(metrics_B, metrics_A),
            "B_vs_C": _metric_deltas(metrics_C, metrics_B),
            "novelty_subset_A_vs_C": _metric_deltas(novelty_subset_C, novelty_subset_A),
        },
        "summary": {
            "winner_by_f1": winner,
            "mean_preds": {"A": metrics_A["mean_preds_per_protein"],
                           "B": metrics_B["mean_preds_per_protein"],
                           "C": metrics_C["mean_preds_per_protein"]},
            "pct_noisy_gt15": {"A": metrics_A["pct_noisy_gt15"],
                               "B": metrics_B["pct_noisy_gt15"],
                               "C": metrics_C["pct_noisy_gt15"]},
            "precision": {"A": metrics_A["micro_precision"],
                          "B": metrics_B["micro_precision"],
                          "C": metrics_C["micro_precision"]},
            "novelty_subset_f1": {
                "A": novelty_subset_A["micro_f1"],
                "C": novelty_subset_C["micro_f1"],
            },
        },
        "elapsed_seconds": round(time.time() - t0, 1),
    }

    with open(OUT_PATH, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Results saved to {OUT_PATH}")
    print(f"Elapsed: {results['elapsed_seconds']}s")
    print(f"\nSummary:")
    print(f"  {'Strategy':<35} {'Prec':>6} {'Rec':>6} {'F1':>6} {'AvgN':>6} {'>15%':>6}")
    for name, m in [("A current webapp", metrics_A),
                    ("B precision+IC", metrics_B),
                    ("C novelty-gated", metrics_C)]:
        print(f"  {name:<35} {m['micro_precision']:>6.4f} {m['micro_recall']:>6.4f} "
              f"{m['micro_f1']:>6.4f} {m['mean_preds_per_protein']:>6.1f} "
              f"{m['pct_noisy_gt15']:>5.1f}%")


if __name__ == "__main__":
    main()