protfunc / scripts /threshold_comparison.py
Sbhat2026's picture
perf: ESM embedding cache + 1500aa limit, add research scripts
7f7a890
"""
threshold_comparison.py
=======================
Compare current MF thresholding vs a stricter precision-first alternative on
the ProtFunc v3 pipeline.
Strategies:
A. Current ProtFunc v3 thresholds
B. Precision-first MF thresholds + IC scaling for top-25 most common MF terms
C. B + novelty gating on the most novel proteins (bottom similarity quantile)
The script intentionally evaluates only molecular-function labels on direct
annotations. It keeps non-MF thresholds unchanged in the saved JSON and reports:
- overall metrics on a random test subset
- novelty-subset metrics on the bottom-20% KNN-similarity proteins
Outputs (paths relative to the repository root):
artifacts/graph_hpo/precision_ic_thresholds_graph_hpo.json
artifacts/graph_hpo/threshold_comparison_graph_hpo.json
"""
import ast
import json
import math
import time
import warnings
from pathlib import Path
import numpy as np
import pandas as pd
import joblib
import torch
import torch.nn as nn
# Suppress every Python warning for the whole run (keeps the progress log terse).
warnings.filterwarnings("ignore")
# Repository root: two directory levels above this file.
BASE = Path(__file__).parent.parent
ART = BASE / "artifacts"
IMPORTANT = BASE / "Important Files"
# Base table: the "Dim_*" embedding columns are read from here.
DATA_BASE = IMPORTANT / "merged_full_struct.parquet"
# Supplementary table: "Label_Indices", the checkpoint's supplementary
# feature columns and "f_af_present" are read from here. It is indexed with
# the same positional (iloc) indices as DATA_BASE, so the two are assumed to
# be row-aligned -- TODO confirm.
DATA_SUPP = IMPORTANT / "merged_full_struct_with_features.parquet"
# Fitted label binarizer; len(classes_) is the model's output width and the
# entries are matched against GO term ids from the OBO file.
MLB_PATH = IMPORTANT / "mlb_public_v1.pkl"
# Precomputed row indices with keys "train_idx", "val_idx", "test_idx".
SPLITS_NPZ = ART / "splits" / "splits_n250000_seed42.npz"
# GO ontology (OBO format); used only to identify molecular-function terms.
OBO_PATH = BASE / "go-basic.obo"
# Checkpoint under evaluation and its per-label thresholds (strategy A).
CKPT_PATH = ART / "graph_hpo" / "graph_hpo_best.pth"
CURRENT_THRESH = ART / "graph_hpo" / "graph_hpo_best_thresholds.json"
# Outputs: the full comparison report and the strategy-B threshold table.
OUT_PATH = ART / "graph_hpo" / "threshold_comparison_graph_hpo.json"
PREC_THRESH_OUT = ART / "graph_hpo" / "precision_ic_thresholds_graph_hpo.json"
SUBSET_SIZE = 2000  # test proteins sampled for evaluation
TRAIN_KNN = 5000  # training proteins sampled as the KNN similarity reference
TOP_COMMON = 25  # top-N MF terms by training frequency that get IC scaling
KNN_K = 10  # neighbours averaged in the novelty similarity score
NOVELTY_Q = 0.20  # bottom quantile of proteins treated as "novel"
# Upper bound for IC-scaled thresholds, and the value the novelty gate
# blends thresholds toward for the least familiar proteins.
NOVELTY_HI_T = 0.996
SEED = 42
# Shared generator for sampling the test subset and the KNN reference in main().
rng = np.random.default_rng(SEED)
# ─── Architecture ─────────────────────────────────────────────────────────────
class ResBlock(nn.Module):
    """Residual block computing ``x + F(x)`` at constant width ``dim``.

    ``F`` is two repetitions of BatchNorm -> ReLU -> Dropout -> Linear, so
    normalisation and activation come before each Linear layer. All eight
    layers live in one ``nn.Sequential`` called ``net``; the checkpoint keys
    therefore read ``net.0`` ... ``net.7``.
    """

    def __init__(self, dim, dropout=0.2):
        super().__init__()
        layers = []
        for _ in range(2):
            layers.extend([
                nn.BatchNorm1d(dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(dim, dim),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.net(x)
        return x + residual
class ImprovedResidualMLP(nn.Module):
    """Residual MLP classifier head producing ``out_dim`` logits.

    Layout: Linear(in_dim -> hidden), then ``n_blocks`` ``ResBlock`` layers,
    then BatchNorm -> ReLU -> Dropout -> Linear(hidden -> out_dim). Submodule
    names (``fc_in``, ``blocks``, ``fc_out``) match the saved checkpoints.
    The output is raw logits; callers apply the sigmoid.
    """

    def __init__(self, in_dim, out_dim=8124, hidden=2048, n_blocks=4, dropout=0.2):
        super().__init__()
        self.fc_in = nn.Linear(in_dim, hidden)
        self.blocks = nn.ModuleList(
            ResBlock(hidden, dropout) for _ in range(n_blocks)
        )
        head = [nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout)]
        head.append(nn.Linear(hidden, out_dim))
        self.fc_out = nn.Sequential(*head)

    def forward(self, x):
        hidden_state = self.fc_in(x)
        for block in self.blocks:
            hidden_state = block(hidden_state)
        return self.fc_out(hidden_state)
# ─── GO hierarchy ─────────────────────────────────────────────────────────────
def load_go_hierarchy(obo_path):
    """
    Parse a GO OBO file and return the molecular-function parent map.

    Only ``[Term]`` stanzas are read (any other ``[...]`` header such as
    ``[Typedef]`` ends term parsing). Terms marked ``is_obsolete: true`` are
    dropped. Parents come from ``is_a:`` lines and ``relationship: part_of``
    lines.

    Parameters
    ----------
    obo_path : str or Path
        Path to the OBO file (read as UTF-8).

    Returns
    -------
    dict[str, set[str]]
        Every non-obsolete ``molecular_function`` term id mapped to the set of
        its direct parents that are themselves molecular-function terms;
        parents in other namespaces are removed.
    """
    ns_map, par_map = {}, {}
    cur_id, cur_ns, cur_par, in_term = None, None, set(), False

    def flush():
        # Commit the stanza being built (if it has both an id and a namespace)
        # and reset the accumulators for the next one.
        nonlocal cur_id, cur_ns, cur_par
        if cur_id and cur_ns:
            ns_map[cur_id] = cur_ns
            par_map[cur_id] = set(cur_par)
        cur_id, cur_ns, cur_par = None, None, set()

    # OBO files are UTF-8; do not depend on the platform's locale encoding.
    with open(obo_path, encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if line.startswith("["):
                # Any stanza header closes the previous stanza; only [Term]
                # stanzas are parsed.
                flush()
                in_term = line == "[Term]"
                continue
            if not in_term:
                continue
            if line.startswith("id:"):
                cur_id = line.split("id:", 1)[1].strip().split()[0]
            elif line.startswith("namespace:"):
                cur_ns = line.split("namespace:", 1)[1].strip()
            elif line.startswith("is_obsolete:") and "true" in line:
                # Clearing the id makes flush() discard this stanza.
                cur_id = None
            elif line.startswith("is_a:"):
                cur_par.add(line.split("is_a:", 1)[1].strip().split()[0])
            elif line.startswith("relationship:"):
                pts = line.split("relationship:", 1)[1].strip().split()
                if len(pts) >= 2 and pts[0] == "part_of":
                    cur_par.add(pts[1])
    flush()
    mf = {g for g, n in ns_map.items() if n == "molecular_function"}
    return {g: (par_map.get(g, set()) & mf) for g in mf}
# ─── Label parsing + feature assembly ─────────────────────────────────────────
def parse_labels(x):
    """Convert one ``Label_Indices`` cell into a list of ints.

    Accepted forms:
      * list / numpy array -> each element cast with ``int``;
      * string holding a Python literal (``"[1, 2]"``, ``"(1, 2)"``, ``"7"``)
        -> parsed with ``ast.literal_eval``;
      * ``None``, empty / ``"nan"`` strings, unparsable strings and any other
        type -> ``[]``.
    """
    if isinstance(x, (list, np.ndarray)):
        return list(map(int, x))
    if not isinstance(x, str):
        return []
    text = x.strip()
    if text == "" or text.lower() == "nan":
        return []
    try:
        parsed = ast.literal_eval(text)
        items = parsed if isinstance(parsed, (list, tuple)) else [parsed]
        return [int(item) for item in items]
    except Exception:
        # Malformed cells are treated as "no labels" (deliberate best-effort).
        return []
def build_inputs(df_base, df_supp, indices, ckpt):
    """
    Assemble the float32 model-input matrix for the rows ``indices``.

    Columns are the ``Dim_*`` embedding columns of ``df_base``. When the
    checkpoint lists ``supp_cols``, those columns from ``df_supp`` are appended
    after z-scoring with the checkpoint's ``supp_mu`` / ``supp_sd``. Both frames
    are indexed positionally with the same ``indices``.

    Layouts (chosen from ``ckpt["in_dim"]`` when present, else all supp cols):
      * embeddings only, if the checkpoint has no ``supp_cols``;
      * embeddings + the first ``in_dim - n_emb`` z-scored supplementary columns;
      * embeddings + all supplementary columns + the ``f_af_present`` flag,
        when ``in_dim == n_emb + len(supp_cols) + 1``.

    Raises
    ------
    ValueError
        If ``in_dim`` is smaller than the embedding width or matches none of
        the layouts above.
    """
    emb_cols = [c for c in df_base.columns if c.startswith("Dim_")]
    x_base = df_base.iloc[indices][emb_cols].to_numpy(np.float32)
    # Accept None, list, tuple or array for supp_cols (truthiness of an array
    # would be ambiguous).
    supp_cols = ckpt.get("supp_cols")
    supp_cols = [] if supp_cols is None else list(supp_cols)
    if not supp_cols:
        return x_base
    mu = np.asarray(ckpt["supp_mu"], dtype=np.float32)
    sd = np.asarray(ckpt["supp_sd"], dtype=np.float32)
    s = df_supp.iloc[indices][supp_cols].to_numpy(np.float32)
    s_z = (s - mu) / (sd + 1e-12)
    n_emb = x_base.shape[1]
    in_dim = ckpt.get("in_dim")
    n_supp_used = in_dim - n_emb if in_dim else len(supp_cols)
    if n_supp_used < 0:
        # A negative count would silently slice columns off the end (s_z[:, :-k]).
        raise ValueError(
            f"Checkpoint in_dim={in_dim} is smaller than the embedding width {n_emb}"
        )
    # esm_seq / partial supp: use only the first n_supp_used columns.
    if n_supp_used <= len(supp_cols):
        s_z = s_z[:, :n_supp_used]
        return np.concatenate([x_base, s_z], axis=1).astype(np.float32)
    # esm_all with the AlphaFold-present flag appended as the last column.
    if in_dim == n_emb + len(supp_cols) + 1:
        af_present = df_supp.iloc[indices]["f_af_present"].to_numpy(np.float32).reshape(-1, 1)
        return np.concatenate([x_base, s_z, af_present], axis=1).astype(np.float32)
    raise ValueError(
        f"Unsupported input shape for checkpoint: in_dim={in_dim} "
        f"vs base={n_emb} supp={len(supp_cols)}"
    )
# ─── Precision-biased threshold sweep (MF only) ──────────────────────────────
def compute_fbeta_thresholds(probs, true, beta=0.5, steps=None, min_support=10, floor=0.90):
    """
    Per label: pick the threshold that maximises F-beta on a high-threshold grid.

    For this v3 MF model, useful separation happens around 0.80+, so sweeping
    low thresholds only reproduces the overprediction failure mode.

    Parameters
    ----------
    probs : (N, L) array of predicted probabilities.
    true : (N, L) array of 0/1 ground-truth labels.
    beta : F-beta weight; beta < 1 favours precision.
    steps : candidate thresholds, tried in order. Defaults to 0.80..0.98 in
        0.01 steps followed by 0.982..0.994 in 0.002 steps.
    min_support : labels with fewer positives than this keep ``floor``.
    floor : fallback threshold for unsupported labels, and for labels where
        no candidate threshold recovers a single true positive.

    Returns
    -------
    (L,) float32 array of thresholds. Ties in F-beta are broken by higher
    precision, then by the earliest candidate.
    """
    if steps is None:
        coarse = np.arange(0.80, 0.981, 0.01, dtype=np.float32)
        fine = np.arange(0.982, 0.996, 0.002, dtype=np.float32)
        steps = np.concatenate([coarse, fine])
    steps = np.asarray(steps, dtype=np.float32)
    n_labels = probs.shape[1]
    thr = np.full(n_labels, floor, dtype=np.float32)
    b2 = beta ** 2
    for j in range(n_labels):
        tj = np.asarray(true[:, j], dtype=np.float64)
        n_pos = tj.sum()
        if n_pos < min_support:
            continue
        # (N, S) prediction matrix: one column per candidate threshold.
        pred = (probs[:, j][:, None] >= steps[None, :]).astype(np.float64)
        tp = tj @ pred
        fp = pred.sum(axis=0) - tp
        fn = n_pos - tp
        if not np.any(tp > 0):
            # Every candidate has F-beta == 0; the tie-break below would
            # select the loosest (first) threshold. Keep the conservative floor.
            continue
        prec = tp / (tp + fp + 1e-9)
        denom = (1 + b2) * tp + b2 * fn + fp
        fb = np.divide((1 + b2) * tp, denom, out=np.zeros_like(tp), where=denom > 0)
        best_fb, best_t, best_prec = -1.0, floor, -1.0
        for t, f, p in zip(steps, fb, prec):
            if f > best_fb or (abs(f - best_fb) < 1e-12 and p > best_prec):
                best_fb, best_t, best_prec = float(f), float(t), float(p)
        thr[j] = best_t
    return thr
# ─── IC-scaled thresholds for top-N most common terms ─────────────────────────
def ic_scaled_thresholds(base_thr, label_freq, mlb_classes, mf_idx, top_n=25, *, alpha=0.40, cap=None):
    """
    Raise the thresholds of the most frequently annotated MF GO terms.

    Broad (low-IC, high-frequency) terms get the largest increase:

        IC_i      = -log2(freq_i / total_annotations)
        new_thr_i = base_i * (1 + alpha * (1 - IC_i / max_IC))

    where ``max_IC`` is taken over the selected terms. The result is capped at
    ``cap`` but never set below the original threshold (the step only raises).

    Parameters
    ----------
    base_thr : (n_labels,) thresholds; left unmodified (a copy is returned).
    label_freq : (n_labels,) training annotation counts per label.
    mlb_classes : label ids indexed like ``base_thr``.
    mf_idx : indices of the molecular-function labels.
    top_n : how many of the most frequent MF labels to adjust. Only MF labels
        with a non-zero training count are eligible, so fewer may be adjusted;
        ``top_n <= 0`` adjusts none.
    alpha : strength of the IC-based increase (default 0.40).
    cap : upper bound for scaled thresholds; ``None`` means ``NOVELTY_HI_T``.

    Returns
    -------
    (thr, adjustments)
        ``thr`` is the adjusted copy; ``adjustments`` is a list of
        ``(label_idx, class_id, train_freq, old_thr, new_thr)`` tuples.
    """
    if cap is None:
        cap = NOVELTY_HI_T
    thr = base_thr.copy()
    freq = np.asarray(label_freq, dtype=np.float64)
    total = freq.sum() + 1e-9
    # Labels never annotated in training keep IC = 0.
    ic = np.zeros(len(mlb_classes))
    seen = freq > 0
    ic[seen] = -np.log2(freq[seen] / total)
    # Eligible: MF labels that actually occur in training. Restricting here
    # prevents zero-frequency non-MF labels from being picked up when fewer
    # than top_n MF terms were seen.
    candidates = [int(j) for j in np.asarray(mf_idx).ravel() if freq[int(j)] > 0]
    # Most frequent first; ties broken by label index for determinism.
    candidates.sort(key=lambda j: (-freq[j], j))
    top_idx = candidates[:max(0, int(top_n))]
    max_ic = max((ic[j] for j in top_idx), default=1.0)
    adjustments = []
    for j in top_idx:
        ic_norm = ic[j] / (max_ic + 1e-9)
        scale = 1.0 + alpha * (1.0 - ic_norm)
        old_t = float(base_thr[j])
        # Cap the raised value, but never lower a threshold already above the cap.
        new_t = max(old_t, min(cap, old_t * scale))
        adjustments.append((j, mlb_classes[j], int(freq[j]), old_t, new_t))
        thr[j] = new_t
    print(f" IC-scaled top-{top_n} most frequent MF terms:")
    for j, gid, f, old_t, new_t in sorted(adjustments, key=lambda x: -x[2])[:8]:
        print(f" {gid} freq={f:,} base={old_t:.3f} -> {new_t:.3f}")
    return thr, adjustments
# ─── KNN novelty ──────────────────────────────────────────────────────────────
def build_knn_ref(embs):
    """Return ``embs`` with every row scaled to unit L2 norm.

    With unit-length rows, a plain dot product against a unit-length query is
    the cosine similarity. ``1e-9`` is added to each norm so an all-zero row
    stays zero instead of dividing by zero. Output has the same shape as
    ``embs``.
    """
    row_norms = np.linalg.norm(embs, axis=1, keepdims=True)
    return embs / (row_norms + 1e-9)
def compute_novelty_sim(query_emb, knn_ref, k=KNN_K):
    """
    Mean cosine similarity between ``query_emb`` and its ``k`` most similar
    reference rows (higher = more familiar).

    Parameters
    ----------
    query_emb : 1-D embedding; L2-normalised here before the dot products.
    knn_ref : 2-D array of reference embeddings, one per row. Rows are used
        as-is, so they should already be unit length (``build_knn_ref``) for
        the result to be a cosine similarity.
    k : number of neighbours to average. Clamped to the number of reference
        rows so a reference set smaller than ``k`` still works.

    Raises
    ------
    ValueError
        If ``knn_ref`` has no rows or ``k`` < 1.
    """
    q = query_emb / (np.linalg.norm(query_emb) + 1e-9)
    sims = knn_ref @ q  # one similarity per reference row
    n_ref = sims.shape[0]
    if n_ref == 0:
        raise ValueError("knn_ref is empty; cannot compute a novelty similarity")
    # np.partition raises when k exceeds the number of rows.
    k = min(int(k), n_ref)
    if k < 1:
        # With k == 0, partition(...)[-0:] would silently average *all* rows.
        raise ValueError(f"k must be >= 1, got {k}")
    top_k = np.partition(sims, -k)[-k:]
    return float(top_k.mean())
def apply_novelty_gate(base_thr, sim, lo, sim_min, hi_t=NOVELTY_HI_T):
    """
    Quantile-gated thresholding for one protein.

    If ``sim`` is at or above the cutoff ``lo`` the protein counts as familiar
    and ``base_thr`` is returned unchanged (the same object). Otherwise every
    threshold is moved linearly toward ``hi_t`` by a weight that is 0 at
    ``lo`` and reaches 1 at ``sim_min`` (the least familiar protein), clipped
    to [0, 1]. The interpolation is symmetric, so a threshold that starts
    above ``hi_t`` is pulled down toward it.
    """
    if sim >= lo:
        return base_thr
    novelty_span = lo - sim_min + 1e-9  # epsilon avoids 0/0 when lo == sim_min
    raw_weight = (lo - sim) / novelty_span
    weight = min(1.0, max(0.0, raw_weight))
    return base_thr + weight * (hi_t - base_thr)
# ─── Evaluation ───────────────────────────────────────────────────────────────
def evaluate(probs, true_y, thr_arr, mf_idx, embs=None, knn_ref=None, novelty_cut=None, novelty_min=None):
    """
    Micro-averaged metrics for thresholded MF predictions.

    Only the label columns in ``mf_idx`` are scored. If ``embs``, ``knn_ref``,
    ``novelty_cut`` and ``novelty_min`` are all supplied, each protein's
    thresholds are first passed through ``apply_novelty_gate`` using its KNN
    similarity (``compute_novelty_sim``).
    true_y: (N, n_labels) from Label_Indices (direct labels, not propagated).

    Returns a dict of rounded metrics: micro precision/recall/F1 plus the
    distribution of predictions per protein. An empty input yields zeros
    rather than NaN so the result stays JSON-serialisable.
    """
    n = probs.shape[0]
    pv = probs[:, mf_idx]
    tv = true_y[:, mf_idx]
    base_thr = thr_arr[mf_idx]  # identical for every protein: compute once
    use_gate = (embs is not None and knn_ref is not None
                and novelty_cut is not None and novelty_min is not None)
    if use_gate and n > 0:
        thr = np.stack([
            apply_novelty_gate(base_thr, compute_novelty_sim(embs[i], knn_ref),
                               novelty_cut, novelty_min)
            for i in range(n)
        ])
    else:
        thr = base_thr[None, :]  # one threshold row broadcast over all proteins
    pred = (pv >= thr).astype(np.float32)
    # Accumulate counts in float64 so large matrices stay exact.
    tp = float((pred * tv).sum(dtype=np.float64))
    fp = float((pred * (1 - tv)).sum(dtype=np.float64))
    fn = float(((1 - pred) * tv).sum(dtype=np.float64))
    n_preds = pred.sum(axis=1).astype(np.int64)
    prec = tp / (tp + fp + 1e-9)
    rec = tp / (tp + fn + 1e-9)
    f1 = 2 * prec * rec / (prec + rec + 1e-9)

    def pct(mask):
        # Percentage of proteins satisfying ``mask``; 0.0 for empty input.
        return round(float(mask.mean() * 100), 1) if n else 0.0

    return {
        "micro_precision": round(prec, 4),
        "micro_recall": round(rec, 4),
        "micro_f1": round(f1, 4),
        "mean_preds_per_protein": round(float(n_preds.mean()), 2) if n else 0.0,
        "median_preds": float(np.median(n_preds)) if n else 0.0,
        "pct_tight_le5": pct(n_preds <= 5),
        "pct_noisy_gt15": pct(n_preds > 15),
        "pct_zero_preds": pct(n_preds == 0),
        "coverage_pct": pct(n_preds > 0),
    }
def subset_metrics(probs, true_y, thr_arr, mf_idx, subset_mask, embs=None, knn_ref=None, novelty_cut=None, novelty_min=None):
    """Run ``evaluate`` on only the rows where ``subset_mask`` is true.

    ``probs``, ``true_y`` and (if given) ``embs`` are row-filtered; the
    thresholds, label indices and novelty-gate settings pass through as-is.
    """
    rows = np.nonzero(np.ravel(subset_mask))[0]
    sub_embs = embs[rows] if embs is not None else None
    return evaluate(
        probs[rows],
        true_y[rows],
        thr_arr,
        mf_idx,
        embs=sub_embs,
        knn_ref=knn_ref,
        novelty_cut=novelty_cut,
        novelty_min=novelty_min,
    )
def thr_stats(thr, idx):
    """Summarise the thresholds at positions ``idx`` for the JSON report.

    Returns mean/min/max rounded to 4 decimals and the percentage (1 decimal)
    of selected thresholds below 0.3, below 0.5, and at or above 0.7.
    """
    sel = thr[idx]

    def pct(mask):
        return round(float(mask.mean() * 100), 1)

    summary = {}
    summary["mean"] = round(float(sel.mean()), 4)
    summary["min"] = round(float(sel.min()), 4)
    summary["max"] = round(float(sel.max()), 4)
    summary["pct_lt03"] = pct(sel < 0.3)
    summary["pct_lt05"] = pct(sel < 0.5)
    summary["pct_ge07"] = pct(sel >= 0.7)
    return summary
# ─── Main ─────────────────────────────────────────────────────────────────────
def main():
    """
    Compare threshold strategies A/B/C on a random test subset and save results.

    Pipeline: load the feature tables, splits and label binarizer; find the MF
    label columns from the OBO file; run the checkpoint on the full validation
    split and a random test subset; tune strategy-B thresholds on validation;
    score A, B and C on the test subset and on its most-novel slice; write the
    strategy-B threshold table (PREC_THRESH_OUT) and the report (OUT_PATH).
    """
    t0 = time.time()
    device = torch.device("cpu")
    # ── Data ──────────────────────────────────────────────────────────────
    print("Loading data...")
    df_base = pd.read_parquet(DATA_BASE)
    df_supp = pd.read_parquet(DATA_SUPP)
    # NOTE(security): joblib.load, np.load(allow_pickle=True) and
    # torch.load(weights_only=False) unpickle their input and can execute
    # arbitrary code -- only point these paths at trusted local artifacts.
    mlb = joblib.load(MLB_PATH)
    n_labels = len(mlb.classes_)
    splits = np.load(SPLITS_NPZ, allow_pickle=True)
    train_idx = splits["train_idx"]
    val_idx = splits["val_idx"]
    test_idx = splits["test_idx"]
    # Clamp sample sizes so a split smaller than the configured size does not
    # make rng.choice(replace=False) raise.
    test_sub = rng.choice(test_idx, size=min(SUBSET_SIZE, len(test_idx)), replace=False)
    train_sub = rng.choice(train_idx, size=min(TRAIN_KNN, len(train_idx)), replace=False)
    print(f" Splits: train={len(train_idx)}, val={len(val_idx)}, "
          f"test subset={len(test_sub)}, knn_ref={len(train_sub)}")
    # Embedding columns, shared by inference and the KNN reference.
    emb_cols = [c for c in df_base.columns if c.startswith("Dim_")]

    # ── Label matrix helper ───────────────────────────────────────────────
    def make_Y(indices):
        """Dense 0/1 (len(indices), n_labels) matrix from Label_Indices."""
        Y = np.zeros((len(indices), n_labels), dtype=np.float32)
        for r, row in enumerate(df_supp.iloc[indices]["Label_Indices"]):
            for v in parse_labels(row):
                if 0 <= int(v) < n_labels:
                    Y[r, int(v)] = 1.0
        return Y

    # ── GO hierarchy -> MF indices ────────────────────────────────────────
    print("Loading GO hierarchy...")
    go_parents = load_go_hierarchy(OBO_PATH)
    mf_go_ids = set(go_parents.keys())
    mf_idx = np.array([j for j, c in enumerate(mlb.classes_) if c in mf_go_ids])
    print(f" MF labels in MLB: {len(mf_idx)}")

    # ── Model ─────────────────────────────────────────────────────────────
    print(f"Loading model {CKPT_PATH.name}...")
    ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
    # Checkpoints may wrap the weights under "model" or be a bare state dict.
    state = ckpt.get("model", ckpt)
    in_dim = state["fc_in.weight"].shape[1]
    model = ImprovedResidualMLP(
        in_dim=in_dim,
        out_dim=n_labels,
        hidden=ckpt.get("hidden", 2048),
        n_blocks=ckpt.get("n_blocks", 4),
    ).to(device)
    model.load_state_dict(state)
    model.eval()
    print(f" in_dim={in_dim}")

    # ── Inference ─────────────────────────────────────────────────────────
    def run_inference(indices, desc):
        """Return (sigmoid probabilities, raw Dim_* embeddings) for ``indices``."""
        full_x = build_inputs(df_base, df_supp, indices, ckpt)
        esm_x = df_base.iloc[indices][emb_cols].values.astype(np.float32)
        result = []
        with torch.no_grad():
            for s in range(0, len(full_x), 512):
                xb = torch.tensor(full_x[s:s + 512]).to(device)
                result.append(torch.sigmoid(model(xb)).cpu().numpy())
        print(f" {desc}: {len(indices)} proteins done")
        return np.concatenate(result, axis=0), esm_x

    print("Running inference...")
    val_probs, val_embs = run_inference(val_idx, "val")
    test_probs, test_embs = run_inference(test_sub, "test subset")
    Y_val = make_Y(val_idx)
    Y_test = make_Y(test_sub)
    print(f" Mean direct labels/protein in test subset: "
          f"{Y_test[:, mf_idx].sum(1).mean():.2f}")

    # ── KNN reference ─────────────────────────────────────────────────────
    print("Building KNN reference...")
    train_embs = df_base.iloc[train_sub][emb_cols].values.astype(np.float32)
    knn_ref = build_knn_ref(train_embs)  # one unit-norm row per reference protein

    # ── Label frequency from training set ─────────────────────────────────
    print("Computing label frequencies...")
    label_freq = np.zeros(n_labels, dtype=np.float32)
    for row in df_supp.iloc[train_idx]["Label_Indices"]:
        for v in parse_labels(row):
            # Same bounds as make_Y: a negative index must not count toward
            # a label at the end of the array.
            if 0 <= int(v) < n_labels:
                label_freq[int(v)] += 1

    # ─────────────────────────────────────────────────────────────────────
    # STRATEGY A: current v3 thresholds
    # ─────────────────────────────────────────────────────────────────────
    print("\n=== Strategy A: current ProtFunc v3 thresholds ===")
    with open(CURRENT_THRESH) as f:
        curr_dict = json.load(f)
    # Labels missing from the JSON fall back to 0.5.
    thr_A = np.full(n_labels, 0.5, dtype=np.float32)
    for k, v in curr_dict.items():
        thr_A[int(k)] = float(v)
    metrics_A = evaluate(test_probs, Y_test, thr_A, mf_idx)
    print(json.dumps(metrics_A, indent=2))

    # ─────────────────────────────────────────────────────────────────────
    # STRATEGY B: precision-biased (F-beta, beta=0.5) + IC-scaled top-25
    # ─────────────────────────────────────────────────────────────────────
    print("\n=== Strategy B: MF precision thresholds + IC-scaled top-25 ===")
    print(" Sweeping MF thresholds on val set in the high-confidence regime...")
    thr_B = thr_A.copy()  # non-MF labels keep their strategy-A thresholds
    thr_B_mf = compute_fbeta_thresholds(val_probs[:, mf_idx], Y_val[:, mf_idx], beta=0.5, floor=0.90)
    thr_B[mf_idx] = thr_B_mf
    thr_B, ic_adj = ic_scaled_thresholds(thr_B, label_freq, mlb.classes_, mf_idx, top_n=TOP_COMMON)
    metrics_B = evaluate(test_probs, Y_test, thr_B, mf_idx)
    print(json.dumps(metrics_B, indent=2))
    # Save
    thr_B_dict = {str(j): round(float(thr_B[j]), 4) for j in range(n_labels)}
    PREC_THRESH_OUT.parent.mkdir(parents=True, exist_ok=True)
    with open(PREC_THRESH_OUT, "w") as f:
        json.dump(thr_B_dict, f)
    print(f" Saved to {PREC_THRESH_OUT.name}")

    # ─────────────────────────────────────────────────────────────────────
    # STRATEGY C: novelty-gated (B thresholds + per-protein KNN gate)
    # ─────────────────────────────────────────────────────────────────────
    print("\n=== Strategy C: novelty-gated (B + KNN on ESM embeddings) ===")
    test_sims = np.array([compute_novelty_sim(test_embs[i], knn_ref) for i in range(len(test_embs))], dtype=np.float32)
    # The novelty cutoff is a quantile of the test subset's own similarities.
    novelty_cut = float(np.quantile(test_sims, NOVELTY_Q))
    novelty_min = float(test_sims.min())
    novelty_mask = test_sims <= novelty_cut
    print(f" Novelty gate: bottom {int(NOVELTY_Q * 100)}% proteins by KNN similarity")
    print(f" Similarity stats: min={test_sims.min():.3f} cut={novelty_cut:.3f} "
          f"mean={test_sims.mean():.3f} max={test_sims.max():.3f}")
    metrics_C = evaluate(test_probs, Y_test, thr_B, mf_idx,
                         embs=test_embs, knn_ref=knn_ref,
                         novelty_cut=novelty_cut, novelty_min=novelty_min)
    print(json.dumps(metrics_C, indent=2))
    novelty_subset_A = subset_metrics(test_probs, Y_test, thr_A, mf_idx, novelty_mask)
    novelty_subset_C = subset_metrics(
        test_probs, Y_test, thr_B, mf_idx, novelty_mask,
        embs=test_embs, knn_ref=knn_ref, novelty_cut=novelty_cut, novelty_min=novelty_min
    )

    # ─────────────────────────────────────────────────────────────────────
    # Compile and save
    # ─────────────────────────────────────────────────────────────────────
    winner = max(
        [("A_current", metrics_A),
         ("B_precision", metrics_B),
         ("C_novelty", metrics_C)],
        key=lambda x: x[1]["micro_f1"]
    )[0]
    results = {
        "metadata": {
            "model": CKPT_PATH.name,
            # Actual sample sizes (may be below the configured ones for small splits).
            "test_subset_size": int(len(test_sub)),
            "train_knn_ref_size": int(len(train_sub)),
            "top_common_ic_scaled": TOP_COMMON,
            "knn_k": KNN_K,
            "current_thresholds": CURRENT_THRESH.name,
            "novelty_quantile": NOVELTY_Q,
            "novelty_subset_size": int(novelty_mask.sum()),
            "novelty_similarity_cut": round(novelty_cut, 6),
            "novelty_hi_thr": NOVELTY_HI_T,
            "n_mf_labels": int(len(mf_idx)),
            "mean_direct_labels_per_protein": round(
                float(Y_test[:, mf_idx].sum(1).mean()), 2),
        },
        "threshold_distributions": {
            "A_current_v3": thr_stats(thr_A, mf_idx),
            "B_precision_ic": thr_stats(thr_B, mf_idx),
        },
        "ic_scaled_top25": [
            {"label_idx": int(j), "go_id": gid, "train_freq": freq,
             "old_thr": round(old, 4), "new_thr": round(new, 4)}
            for j, gid, freq, old, new in
            sorted(ic_adj, key=lambda x: -x[2])[:TOP_COMMON]
        ],
        "metrics": {
            "A_current_thresholds": metrics_A,
            "B_precision_ic": metrics_B,
            "C_novelty_gated": metrics_C,
            "novelty_subset": {
                "A_current_thresholds": novelty_subset_A,
                "C_novelty_gated": novelty_subset_C,
            },
        },
        "deltas": {
            "A_vs_B": {
                "precision_delta": round(metrics_B["micro_precision"] - metrics_A["micro_precision"], 4),
                "recall_delta": round(metrics_B["micro_recall"] - metrics_A["micro_recall"], 4),
                "f1_delta": round(metrics_B["micro_f1"] - metrics_A["micro_f1"], 4),
                "mean_preds_delta": round(metrics_B["mean_preds_per_protein"] -
                                          metrics_A["mean_preds_per_protein"], 2),
            },
            "B_vs_C": {
                "precision_delta": round(metrics_C["micro_precision"] - metrics_B["micro_precision"], 4),
                "recall_delta": round(metrics_C["micro_recall"] - metrics_B["micro_recall"], 4),
                "f1_delta": round(metrics_C["micro_f1"] - metrics_B["micro_f1"], 4),
                "mean_preds_delta": round(metrics_C["mean_preds_per_protein"] -
                                          metrics_B["mean_preds_per_protein"], 2),
            },
            "novelty_subset_A_vs_C": {
                "precision_delta": round(novelty_subset_C["micro_precision"] - novelty_subset_A["micro_precision"], 4),
                "recall_delta": round(novelty_subset_C["micro_recall"] - novelty_subset_A["micro_recall"], 4),
                "f1_delta": round(novelty_subset_C["micro_f1"] - novelty_subset_A["micro_f1"], 4),
                "mean_preds_delta": round(novelty_subset_C["mean_preds_per_protein"] -
                                          novelty_subset_A["mean_preds_per_protein"], 2),
            },
        },
        "summary": {
            "winner_by_f1": winner,
            "mean_preds": {"A": metrics_A["mean_preds_per_protein"],
                           "B": metrics_B["mean_preds_per_protein"],
                           "C": metrics_C["mean_preds_per_protein"]},
            "pct_noisy_gt15": {"A": metrics_A["pct_noisy_gt15"],
                               "B": metrics_B["pct_noisy_gt15"],
                               "C": metrics_C["pct_noisy_gt15"]},
            "precision": {"A": metrics_A["micro_precision"],
                          "B": metrics_B["micro_precision"],
                          "C": metrics_C["micro_precision"]},
            "novelty_subset_f1": {
                "A": novelty_subset_A["micro_f1"],
                "C": novelty_subset_C["micro_f1"],
            },
        },
        "elapsed_seconds": round(time.time() - t0, 1),
    }
    OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(OUT_PATH, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n{'='*60}")
    print(f"Results saved to {OUT_PATH}")
    print(f"Elapsed: {results['elapsed_seconds']}s")
    print("\nSummary:")
    print(f" {'Strategy':<35} {'Prec':>6} {'Rec':>6} {'F1':>6} {'AvgN':>6} {'>15%':>6}")
    for name, m in [("A current webapp", metrics_A),
                    ("B precision+IC", metrics_B),
                    ("C novelty-gated", metrics_C)]:
        print(f" {name:<35} {m['micro_precision']:>6.4f} {m['micro_recall']:>6.4f} "
              f"{m['micro_f1']:>6.4f} {m['mean_preds_per_protein']:>6.1f} "
              f"{m['pct_noisy_gt15']:>5.1f}%")
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()