| """ |
| threshold_comparison.py |
| ======================= |
| Compare current MF thresholding vs a stricter precision-first alternative on |
| the ProtFunc v3 pipeline. |
| |
| Strategies: |
| A. Current ProtFunc v3 thresholds |
| B. Precision-first MF thresholds + IC scaling for top-25 most common MF terms |
| C. B + novelty gating on the most novel proteins (bottom similarity quantile) |
| |
| The script intentionally evaluates only molecular-function labels on direct |
| annotations. It keeps non-MF thresholds unchanged in the saved JSON and reports: |
| - overall metrics on a random test subset |
| - novelty-subset metrics on the bottom-20% KNN-similarity proteins |
| |
Outputs:
artifacts/graph_hpo/precision_ic_thresholds_graph_hpo.json
artifacts/graph_hpo/threshold_comparison_graph_hpo.json
| """ |
|
|
| import ast |
| import json |
| import math |
| import time |
| import warnings |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| import joblib |
| import torch |
| import torch.nn as nn |
|
|
# NOTE(review): blanket suppression hides ALL warnings, not just known-noisy
# ones — consider narrowing to specific categories.
warnings.filterwarnings("ignore")
|
|
# Repository layout: this script lives one level below the repo root.
BASE = Path(__file__).parent.parent
ART = BASE / "artifacts"
IMPORTANT = BASE / "Important Files"

# Input artifacts: embeddings/features, label binarizer, frozen splits,
# GO ontology, trained checkpoint and its currently-deployed thresholds.
DATA_BASE = IMPORTANT / "merged_full_struct.parquet"
DATA_SUPP = IMPORTANT / "merged_full_struct_with_features.parquet"
MLB_PATH = IMPORTANT / "mlb_public_v1.pkl"
SPLITS_NPZ = ART / "splits" / "splits_n250000_seed42.npz"
OBO_PATH = BASE / "go-basic.obo"
CKPT_PATH = ART / "graph_hpo" / "graph_hpo_best.pth"
CURRENT_THRESH = ART / "graph_hpo" / "graph_hpo_best_thresholds.json"
# Outputs written by this script.
OUT_PATH = ART / "graph_hpo" / "threshold_comparison_graph_hpo.json"
PREC_THRESH_OUT = ART / "graph_hpo" / "precision_ic_thresholds_graph_hpo.json"

# Experiment hyperparameters.
SUBSET_SIZE = 2000     # random test proteins scored per strategy
TRAIN_KNN = 5000       # training proteins in the KNN novelty reference
TOP_COMMON = 25        # most frequent MF terms receiving IC scaling
KNN_K = 10             # neighbours averaged for the novelty similarity
NOVELTY_Q = 0.20       # bottom similarity quantile treated as "novel"
NOVELTY_HI_T = 0.996   # ceiling threshold novel proteins are pushed toward
SEED = 42              # single seed for all subsampling

# One shared generator so subset draws are reproducible run-to-run.
rng = np.random.default_rng(SEED)
|
|
|
|
| |
|
|
class ResBlock(nn.Module):
    """Pre-activation residual block: x + MLP(x) with two BN/ReLU/Dropout/Linear stages."""

    def __init__(self, dim, dropout=0.2):
        super().__init__()
        # Build the two identical stages in a loop; layer creation order is
        # unchanged, so parameter initialisation consumes RNG identically.
        layers = []
        for _ in range(2):
            layers.extend([
                nn.BatchNorm1d(dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(dim, dim),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Residual connection around the whole stack.
        return x + self.net(x)
|
|
class ImprovedResidualMLP(nn.Module):
    """Residual MLP classifier: Linear in-projection, n_blocks ResBlocks, BN+ReLU head."""

    def __init__(self, in_dim, out_dim=8124, hidden=2048, n_blocks=4, dropout=0.2):
        super().__init__()
        self.fc_in = nn.Linear(in_dim, hidden)
        self.blocks = nn.ModuleList(
            ResBlock(hidden, dropout) for _ in range(n_blocks)
        )
        self.fc_out = nn.Sequential(
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        hidden_state = self.fc_in(x)
        for block in self.blocks:
            hidden_state = block(hidden_state)
        return self.fc_out(hidden_state)
|
|
|
|
| |
|
|
def load_go_hierarchy(obo_path):
    """Parse a go-basic OBO file and return the molecular-function DAG.

    Returns a dict mapping each non-obsolete MF term id to the set of its
    MF parents (``is_a`` plus ``part_of`` relationships); parents outside
    the molecular_function namespace are filtered out.
    """
    namespaces = {}
    parent_map = {}
    term_id, term_ns, term_parents = None, None, set()
    inside_term = False

    def commit():
        # Persist the current term (if complete) and reset the accumulator.
        nonlocal term_id, term_ns, term_parents
        if term_id and term_ns:
            namespaces[term_id] = term_ns
            parent_map[term_id] = set(term_parents)
        term_id, term_ns, term_parents = None, None, set()

    with open(obo_path) as handle:
        for raw in handle:
            line = raw.strip()
            if line == "[Term]":
                commit()
                inside_term = True
                continue
            if line.startswith("["):
                # Any other stanza ([Typedef], ...) ends term parsing.
                commit()
                inside_term = False
                continue
            if not inside_term:
                continue
            if line.startswith("id:"):
                term_id = line.split("id:", 1)[1].strip().split()[0]
            elif line.startswith("namespace:"):
                term_ns = line.split("namespace:", 1)[1].strip()
            elif "is_obsolete:" in line and "true" in line:
                # Obsolete terms are dropped entirely.
                term_id = None
            elif line.startswith("is_a:"):
                term_parents.add(line.split("is_a:", 1)[1].strip().split()[0])
            elif line.startswith("relationship:"):
                fields = line.split("relationship:", 1)[1].strip().split()
                if len(fields) >= 2 and fields[0] == "part_of":
                    term_parents.add(fields[1])
    commit()  # flush the final term

    mf_terms = {gid for gid, ns in namespaces.items() if ns == "molecular_function"}
    return {gid: (parent_map.get(gid, set()) & mf_terms) for gid in mf_terms}
|
|
|
|
| |
|
|
def parse_labels(x):
    """Coerce a Label_Indices cell into a list of ints.

    Accepts None, an already-materialised list/array, or a string repr
    such as "[1, 2]" / "3"; anything empty, "nan", or unparseable yields [].
    """
    if x is None:
        return []
    if isinstance(x, (list, np.ndarray)):
        return [int(item) for item in x]
    if not isinstance(x, str):
        return []
    text = x.strip()
    if not text or text.lower() == "nan":
        return []
    try:
        parsed = ast.literal_eval(text)
        items = parsed if isinstance(parsed, (list, tuple)) else [parsed]
        # int() stays inside the try so non-numeric payloads fall back to [].
        return [int(item) for item in items]
    except Exception:
        return []
|
|
|
|
def build_inputs(df_base, df_supp, indices, ckpt):
    """
    Assemble the model input matrix for a set of row positions.

    Base ESM embedding columns (Dim_*) are optionally concatenated with
    z-scored supplementary features (using the checkpoint's saved mu/sd)
    and, for checkpoints trained with it, the raw f_af_present flag.

    Args:
        df_base: dataframe holding the Dim_* embedding columns.
        df_supp: dataframe holding the supplementary feature columns.
        indices: row positions (iloc) to extract.
        ckpt: checkpoint dict; reads supp_cols, supp_mu, supp_sd, in_dim.

    Returns:
        (len(indices), in_dim) float32 matrix.

    Raises:
        ValueError: if the checkpoint's in_dim cannot be reconciled with
            the available base + supplementary columns.
    """
    emb_cols = [c for c in df_base.columns if c.startswith("Dim_")]
    x_base = df_base.iloc[indices][emb_cols].to_numpy(np.float32)

    supp_cols = ckpt.get("supp_cols", [])
    if not supp_cols:
        return x_base

    # z-score supplementary features with the checkpoint's training stats.
    mu = np.asarray(ckpt["supp_mu"], dtype=np.float32)
    sd = np.asarray(ckpt["supp_sd"], dtype=np.float32)
    s = df_supp.iloc[indices][supp_cols].to_numpy(np.float32)
    s_z = (s - mu) / (sd + 1e-12)

    in_dim = ckpt.get("in_dim")
    n_supp_used = in_dim - x_base.shape[1] if in_dim else len(supp_cols)

    # Guard: an in_dim smaller than the base embedding width would previously
    # truncate via a negative slice and silently return a wrong-width matrix.
    if n_supp_used < 0:
        raise ValueError(
            f"Checkpoint in_dim={in_dim} is smaller than the base embedding "
            f"width {x_base.shape[1]}"
        )

    # Checkpoint consumed (a prefix of) the supplementary features.
    if n_supp_used <= len(supp_cols):
        s_z = s_z[:, :n_supp_used]
        return np.concatenate([x_base, s_z], axis=1).astype(np.float32)

    # Checkpoint additionally consumed the binary AlphaFold-present flag.
    if in_dim == x_base.shape[1] + len(supp_cols) + 1:
        af_present = df_supp.iloc[indices]["f_af_present"].to_numpy(np.float32).reshape(-1, 1)
        return np.concatenate([x_base, s_z, af_present], axis=1).astype(np.float32)

    raise ValueError(
        f"Unsupported input shape for checkpoint: in_dim={in_dim} "
        f"vs base={x_base.shape[1]} supp={len(supp_cols)}"
    )
|
|
|
|
| |
|
|
def compute_fbeta_thresholds(probs, true, beta=0.5, steps=None, min_support=10, floor=0.90):
    """
    Per label: find the threshold maximising F-beta in the high-threshold regime.

    For this v3 MF model, useful separation happens around 0.80+, so sweeping
    low thresholds only reproduces the overprediction failure mode.

    Args:
        probs: (N, n_labels) predicted probabilities.
        true:  (N, n_labels) binary ground truth.
        beta:  F-beta weighting (beta < 1 favours precision).
        steps: candidate thresholds; defaults to a coarse 0.80-0.98 sweep
               plus a fine 0.982-0.994 sweep.
        min_support: labels with fewer positives keep the floor threshold.
        floor: default threshold for unsupported labels.

    Returns:
        (n_labels,) float32 array of per-label thresholds.
    """
    if steps is None:
        coarse = np.arange(0.80, 0.981, 0.01, dtype=np.float32)
        fine = np.arange(0.982, 0.996, 0.002, dtype=np.float32)
        steps = np.concatenate([coarse, fine]).astype(np.float32)
    n_labels = probs.shape[1]
    thr = np.full(n_labels, floor, dtype=np.float32)
    b2 = beta ** 2
    for j in range(n_labels):
        pj = probs[:, j]
        tj = true[:, j].astype(bool)
        n_pos = int(tj.sum())
        if n_pos < min_support:
            continue
        # Vectorised confusion counts for every candidate threshold at once
        # (replaces the original O(steps * N) pure-Python inner loop).
        tp = (pj[tj][:, None] >= steps[None, :]).sum(axis=0).astype(np.float64)
        fp = (pj[~tj][:, None] >= steps[None, :]).sum(axis=0).astype(np.float64)
        fn = n_pos - tp
        prec = tp / (tp + fp + 1e-9)
        denom = (1 + b2) * tp + b2 * fn + fp
        fb = np.where(denom > 0, (1 + b2) * tp / denom, 0.0)
        # Same tie-breaking as before: first best F-beta; on a tie, keep the
        # candidate with strictly higher precision.
        best_fb, best_t, best_prec = -1.0, floor, -1.0
        for s, t in enumerate(steps):
            if fb[s] > best_fb or (abs(fb[s] - best_fb) < 1e-12 and prec[s] > best_prec):
                best_fb, best_t, best_prec = float(fb[s]), float(t), float(prec[s])
        thr[j] = best_t
    return thr
|
|
|
|
| |
|
|
def ic_scaled_thresholds(base_thr, label_freq, mlb_classes, mf_idx, top_n=25, cap=0.996):
    """
    Raise thresholds for the top_n most annotated MF GO terms.

    Broad terms (low information content = high annotation frequency) are
    raised proportionally more:
        IC_i  = -log2(freq_i / total_annotations)
        new_t = base_t * (1 + alpha * (1 - IC_i / max_IC)), alpha = 0.40,
    capped at `cap` (default mirrors NOVELTY_HI_T = 0.996).

    Args:
        base_thr: (n_labels,) starting thresholds; a modified copy is returned.
        label_freq: (n_labels,) training-set annotation counts.
        mlb_classes: label index -> GO id mapping.
        mf_idx: indices of molecular-function labels.
        top_n: number of most frequent MF terms to adjust.
        cap: upper bound for any adjusted threshold.

    Returns:
        (thresholds, adjustments) where adjustments is a list of
        (label_idx, go_id, train_freq, old_thr, new_thr) tuples.
    """
    thr = base_thr.copy()
    total = label_freq.sum() + 1e-9

    # Vectorised IC (the original looped over every label in Python).
    ic = np.zeros(len(mlb_classes))
    pos = label_freq > 0
    ic[pos] = -np.log2(label_freq[pos] / total)

    # Frequencies restricted to MF labels (zero elsewhere).
    mf_freq = np.zeros(len(mlb_classes))
    mf_freq[mf_idx] = label_freq[mf_idx]

    top_idx = np.argsort(mf_freq)[-top_n:]
    max_ic = ic[top_idx].max() if len(top_idx) else 1.0
    alpha = 0.40
    adjustments = []
    for j in top_idx:
        ic_norm = ic[j] / (max_ic + 1e-9)
        scale = 1.0 + alpha * (1.0 - ic_norm)
        new_t = min(cap, float(base_thr[j]) * scale)
        adjustments.append((j, mlb_classes[j], int(label_freq[j]),
                            float(base_thr[j]), new_t))
        thr[j] = new_t

    print(f" IC-scaled top-{top_n} most frequent MF terms:")
    for j, gid, freq, old_t, new_t in sorted(adjustments, key=lambda x: -x[2])[:8]:
        # was a mojibake arrow ("β") in the original log line
        print(f" {gid} freq={freq:,} base={old_t:.3f} -> {new_t:.3f}")
    return thr, adjustments
|
|
|
|
| |
|
|
def build_knn_ref(embs):
    """Row-normalise embeddings so a dot product equals cosine similarity."""
    return embs / (np.linalg.norm(embs, axis=1, keepdims=True) + 1e-9)
|
|
def compute_novelty_sim(query_emb, knn_ref, k=KNN_K):
    """Mean cosine similarity of a query to its k nearest reference rows.

    Higher values mean the protein looks familiar; low values flag novelty.
    `knn_ref` must already be row-normalised (see build_knn_ref).
    """
    unit_q = query_emb / (np.linalg.norm(query_emb) + 1e-9)
    cos = knn_ref @ unit_q
    # partition is O(n) vs a full sort; take the k largest similarities.
    return float(np.partition(cos, -k)[-k:].mean())
|
|
def apply_novelty_gate(base_thr, sim, lo, sim_min, hi_t=NOVELTY_HI_T):
    """
    Quantile-gated thresholding.

    Proteins at or above the novelty cutoff `lo` keep `base_thr` unchanged;
    below it, thresholds are linearly interpolated toward `hi_t` as the
    similarity approaches `sim_min` (the most novel protein observed).
    """
    if sim >= lo:
        return base_thr
    frac = (lo - sim) / (lo - sim_min + 1e-9)
    frac = min(1.0, max(0.0, frac))
    return base_thr + frac * (hi_t - base_thr)
|
|
|
|
| |
|
|
def evaluate(probs, true_y, thr_arr, mf_idx, embs=None, knn_ref=None, novelty_cut=None, novelty_min=None):
    """
    Compute per-protein MF predictions, then micro-aggregated metrics.

    Args:
        probs:  (N, n_labels) sigmoid probabilities.
        true_y: (N, n_labels) binary direct labels (not propagated).
        thr_arr: (n_labels,) per-label decision thresholds.
        mf_idx: indices of molecular-function labels to score.
        embs, knn_ref, novelty_cut, novelty_min: if ALL are supplied,
            per-protein novelty gating is applied on top of thr_arr.

    Returns:
        dict of rounded micro precision/recall/F1 plus prediction-count stats.
    """
    n = probs.shape[0]
    tp_tot = fp_tot = fn_tot = 0
    n_preds_list = []

    # Loop invariants hoisted out of the per-protein loop (the original
    # re-sliced thr_arr[mf_idx] and re-evaluated the 4-way gate check per row).
    base_thr = thr_arr[mf_idx]
    gate = (embs is not None and knn_ref is not None
            and novelty_cut is not None and novelty_min is not None)

    for i in range(n):
        pv = probs[i, mf_idx]
        tv = true_y[i, mf_idx]

        thr = base_thr
        if gate:
            # More novel proteins (lower similarity) get stricter thresholds.
            sim = compute_novelty_sim(embs[i], knn_ref)
            thr = apply_novelty_gate(thr, sim, novelty_cut, novelty_min)

        pred = (pv >= thr).astype(np.float32)
        tp = (pred * tv).sum(); fp = (pred * (1 - tv)).sum(); fn = ((1 - pred) * tv).sum()
        tp_tot += tp; fp_tot += fp; fn_tot += fn
        n_preds_list.append(int(pred.sum()))

    prec = tp_tot / (tp_tot + fp_tot + 1e-9)
    rec = tp_tot / (tp_tot + fn_tot + 1e-9)
    f1 = 2 * prec * rec / (prec + rec + 1e-9)
    ndl = np.array(n_preds_list)

    return {
        "micro_precision": round(float(prec), 4),
        "micro_recall": round(float(rec), 4),
        "micro_f1": round(float(f1), 4),
        "mean_preds_per_protein": round(float(ndl.mean()), 2),
        "median_preds": float(np.median(ndl)),
        "pct_tight_le5": round(float((ndl <= 5).mean() * 100), 1),
        "pct_noisy_gt15": round(float((ndl > 15).mean() * 100), 1),
        "pct_zero_preds": round(float((ndl == 0).mean() * 100), 1),
        "coverage_pct": round(float((ndl > 0).mean() * 100), 1),
    }
|
|
|
|
def subset_metrics(probs, true_y, thr_arr, mf_idx, subset_mask, embs=None, knn_ref=None, novelty_cut=None, novelty_min=None):
    """Run `evaluate` on the proteins selected by a boolean mask."""
    sel = np.flatnonzero(subset_mask)
    sub_embs = embs[sel] if embs is not None else None
    return evaluate(
        probs[sel],
        true_y[sel],
        thr_arr,
        mf_idx,
        embs=sub_embs,
        knn_ref=knn_ref,
        novelty_cut=novelty_cut,
        novelty_min=novelty_min,
    )
|
|
|
|
def thr_stats(thr, idx):
    """Summarise the threshold distribution over the selected label indices."""
    sel = thr[idx]
    stats = {
        "mean": float(sel.mean()),
        "min": float(sel.min()),
        "max": float(sel.max()),
    }
    stats = {key: round(val, 4) for key, val in stats.items()}
    stats["pct_lt03"] = round(float((sel < 0.3).mean() * 100), 1)
    stats["pct_lt05"] = round(float((sel < 0.5).mean() * 100), 1)
    stats["pct_ge07"] = round(float((sel >= 0.7).mean() * 100), 1)
    return stats
|
|
|
|
| |
|
|
def main():
    """Run the full A/B/C threshold comparison and write the JSON report."""
    t0 = time.time()
    device = torch.device("cpu")

    # ---- Load dataframes, label binarizer and frozen train/val/test splits.
    print("Loading data...")
    df_base = pd.read_parquet(DATA_BASE)
    df_supp = pd.read_parquet(DATA_SUPP)
    mlb = joblib.load(MLB_PATH)
    n_labels = len(mlb.classes_)

    splits = np.load(SPLITS_NPZ, allow_pickle=True)
    train_idx = splits["train_idx"]
    val_idx = splits["val_idx"]
    test_idx = splits["test_idx"]

    # Random evaluation subset + KNN reference pool (seeded module-level rng).
    test_sub = rng.choice(test_idx, size=SUBSET_SIZE, replace=False)
    train_sub = rng.choice(train_idx, size=TRAIN_KNN, replace=False)
    print(f" Splits: train={len(train_idx)}, val={len(val_idx)}, "
          f"test subset={len(test_sub)}, knn_ref={len(train_sub)}")

    # Build a dense multi-hot label matrix from the Label_Indices column.
    def make_Y(indices):
        Y = np.zeros((len(indices), n_labels), dtype=np.float32)
        for r, row in enumerate(df_supp.iloc[indices]["Label_Indices"]):
            for v in parse_labels(row):
                if 0 <= int(v) < n_labels:
                    Y[r, int(v)] = 1.0
        return Y

    # ---- GO hierarchy: restrict evaluation to molecular-function labels.
    print("Loading GO hierarchy...")
    go_parents = load_go_hierarchy(OBO_PATH)
    mf_go_ids = set(go_parents.keys())
    mf_idx = np.array([j for j, c in enumerate(mlb.classes_) if c in mf_go_ids])
    print(f" MF labels in MLB: {len(mf_idx)}")

    # ---- Restore the trained model; architecture hyperparams come from the
    # checkpoint with fallbacks matching the training defaults.
    print(f"Loading model {CKPT_PATH.name}...")
    ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
    state = ckpt.get("model", ckpt)
    in_dim = state["fc_in.weight"].shape[1]
    model = ImprovedResidualMLP(
        in_dim=in_dim,
        out_dim=n_labels,
        hidden=ckpt.get("hidden", 2048),
        n_blocks=ckpt.get("n_blocks", 4),
    ).to(device)
    model.load_state_dict(state)
    model.eval()
    print(f" in_dim={in_dim}")

    # Batched sigmoid inference; also returns the raw ESM embeddings (Dim_*)
    # so the caller can reuse them for KNN novelty scoring.
    def run_inference(indices, desc):
        full_x = build_inputs(df_base, df_supp, indices, ckpt)
        esm_x = df_base.iloc[indices][[c for c in df_base.columns if c.startswith("Dim_")]].values.astype(np.float32)
        result = []
        with torch.no_grad():
            for s in range(0, len(full_x), 512):
                xb = torch.tensor(full_x[s:s+512]).to(device)
                result.append(torch.sigmoid(model(xb)).cpu().numpy())
        print(f" {desc}: {len(indices)} proteins done")
        return np.concatenate(result, axis=0), esm_x

    print("Running inference...")
    val_probs, val_embs = run_inference(val_idx, "val")
    test_probs, test_embs = run_inference(test_sub, "test subset")
    Y_val = make_Y(val_idx)
    Y_test = make_Y(test_sub)

    print(f" Mean direct labels/protein in test subset: "
          f"{Y_test[:, mf_idx].sum(1).mean():.2f}")

    # ---- KNN reference: L2-normalised ESM embeddings of sampled train rows.
    print("Building KNN reference...")
    train_embs = df_base.iloc[train_sub][[c for c in df_base.columns if c.startswith('Dim_')]].values.astype(np.float32)
    knn_ref = build_knn_ref(train_embs)

    # ---- Training-set label frequencies (input to IC scaling).
    print("Computing label frequencies...")
    label_freq = np.zeros(n_labels, dtype=np.float32)
    for row in df_supp.iloc[train_idx]["Label_Indices"]:
        for v in parse_labels(row):
            if int(v) < n_labels:
                label_freq[int(v)] += 1

    # ---- Strategy A: thresholds currently shipped with ProtFunc v3.
    # Labels absent from the JSON fall back to 0.5.
    print("\n=== Strategy A: current ProtFunc v3 thresholds ===")
    with open(CURRENT_THRESH) as f:
        curr_dict = json.load(f)
    thr_A = np.full(n_labels, 0.5, dtype=np.float32)
    for k, v in curr_dict.items():
        thr_A[int(k)] = float(v)
    metrics_A = evaluate(test_probs, Y_test, thr_A, mf_idx)
    print(json.dumps(metrics_A, indent=2))

    # ---- Strategy B: precision-first MF thresholds (swept on the val set)
    # plus IC scaling of the most common MF terms; non-MF thresholds keep A.
    print("\n=== Strategy B: MF precision thresholds + IC-scaled top-25 ===")
    print(" Sweeping MF thresholds on val set in the high-confidence regime...")
    thr_B = thr_A.copy()
    thr_B_mf = compute_fbeta_thresholds(val_probs[:, mf_idx], Y_val[:, mf_idx], beta=0.5, floor=0.90)
    thr_B[mf_idx] = thr_B_mf
    thr_B, ic_adj = ic_scaled_thresholds(thr_B, label_freq, mlb.classes_, mf_idx, top_n=TOP_COMMON)
    metrics_B = evaluate(test_probs, Y_test, thr_B, mf_idx)
    print(json.dumps(metrics_B, indent=2))

    # Persist strategy-B thresholds for every label (non-MF entries = A values).
    thr_B_dict = {str(j): round(float(thr_B[j]), 4) for j in range(n_labels)}
    with open(PREC_THRESH_OUT, "w") as f:
        json.dump(thr_B_dict, f)
    print(f" Saved to {PREC_THRESH_OUT.name}")

    # ---- Strategy C: B + per-protein novelty gating. The novelty cutoff is
    # the NOVELTY_Q quantile of test-subset KNN similarities.
    print("\n=== Strategy C: novelty-gated (B + KNN on ESM embeddings) ===")
    test_sims = np.array([compute_novelty_sim(test_embs[i], knn_ref) for i in range(len(test_embs))], dtype=np.float32)
    novelty_cut = float(np.quantile(test_sims, NOVELTY_Q))
    novelty_min = float(test_sims.min())
    novelty_mask = test_sims <= novelty_cut
    print(f" Novelty gate: bottom {int(NOVELTY_Q * 100)}% proteins by KNN similarity")
    print(f" Similarity stats: min={test_sims.min():.3f} cut={novelty_cut:.3f} "
          f"mean={test_sims.mean():.3f} max={test_sims.max():.3f}")
    metrics_C = evaluate(test_probs, Y_test, thr_B, mf_idx,
                         embs=test_embs, knn_ref=knn_ref,
                         novelty_cut=novelty_cut, novelty_min=novelty_min)
    print(json.dumps(metrics_C, indent=2))

    # Compare A vs C restricted to the most novel proteins only.
    novelty_subset_A = subset_metrics(test_probs, Y_test, thr_A, mf_idx, novelty_mask)
    novelty_subset_C = subset_metrics(
        test_probs, Y_test, thr_B, mf_idx, novelty_mask,
        embs=test_embs, knn_ref=knn_ref, novelty_cut=novelty_cut, novelty_min=novelty_min
    )

    # ---- Pick the winner by overall micro-F1 and assemble the JSON report.
    winner = max(
        [("A_current", metrics_A),
         ("B_precision", metrics_B),
         ("C_novelty", metrics_C)],
        key=lambda x: x[1]["micro_f1"]
    )[0]

    results = {
        "metadata": {
            "model": CKPT_PATH.name,
            "test_subset_size": SUBSET_SIZE,
            "train_knn_ref_size": TRAIN_KNN,
            "top_common_ic_scaled": TOP_COMMON,
            "knn_k": KNN_K,
            "current_thresholds": CURRENT_THRESH.name,
            "novelty_quantile": NOVELTY_Q,
            "novelty_subset_size": int(novelty_mask.sum()),
            "novelty_similarity_cut": round(novelty_cut, 6),
            "novelty_hi_thr": NOVELTY_HI_T,
            "n_mf_labels": int(len(mf_idx)),
            "mean_direct_labels_per_protein": round(
                float(Y_test[:, mf_idx].sum(1).mean()), 2),
        },
        "threshold_distributions": {
            "A_current_v3": thr_stats(thr_A, mf_idx),
            "B_precision_ic": thr_stats(thr_B, mf_idx),
        },
        "ic_scaled_top25": [
            {"label_idx": int(j), "go_id": gid, "train_freq": freq,
             "old_thr": round(old, 4), "new_thr": round(new, 4)}
            for j, gid, freq, old, new in
            sorted(ic_adj, key=lambda x: -x[2])[:TOP_COMMON]
        ],
        "metrics": {
            "A_current_thresholds": metrics_A,
            "B_precision_ic": metrics_B,
            "C_novelty_gated": metrics_C,
            "novelty_subset": {
                "A_current_thresholds": novelty_subset_A,
                "C_novelty_gated": novelty_subset_C,
            },
        },
        # Pairwise deltas so the report is readable without recomputation.
        "deltas": {
            "A_vs_B": {
                "precision_delta": round(metrics_B["micro_precision"] - metrics_A["micro_precision"], 4),
                "recall_delta": round(metrics_B["micro_recall"] - metrics_A["micro_recall"], 4),
                "f1_delta": round(metrics_B["micro_f1"] - metrics_A["micro_f1"], 4),
                "mean_preds_delta": round(metrics_B["mean_preds_per_protein"] -
                                          metrics_A["mean_preds_per_protein"], 2),
            },
            "B_vs_C": {
                "precision_delta": round(metrics_C["micro_precision"] - metrics_B["micro_precision"], 4),
                "recall_delta": round(metrics_C["micro_recall"] - metrics_B["micro_recall"], 4),
                "f1_delta": round(metrics_C["micro_f1"] - metrics_B["micro_f1"], 4),
                "mean_preds_delta": round(metrics_C["mean_preds_per_protein"] -
                                          metrics_B["mean_preds_per_protein"], 2),
            },
            "novelty_subset_A_vs_C": {
                "precision_delta": round(novelty_subset_C["micro_precision"] - novelty_subset_A["micro_precision"], 4),
                "recall_delta": round(novelty_subset_C["micro_recall"] - novelty_subset_A["micro_recall"], 4),
                "f1_delta": round(novelty_subset_C["micro_f1"] - novelty_subset_A["micro_f1"], 4),
                "mean_preds_delta": round(novelty_subset_C["mean_preds_per_protein"] -
                                          novelty_subset_A["mean_preds_per_protein"], 2),
            },
        },
        "summary": {
            "winner_by_f1": winner,
            "mean_preds": {"A": metrics_A["mean_preds_per_protein"],
                           "B": metrics_B["mean_preds_per_protein"],
                           "C": metrics_C["mean_preds_per_protein"]},
            "pct_noisy_gt15": {"A": metrics_A["pct_noisy_gt15"],
                               "B": metrics_B["pct_noisy_gt15"],
                               "C": metrics_C["pct_noisy_gt15"]},
            "precision": {"A": metrics_A["micro_precision"],
                          "B": metrics_B["micro_precision"],
                          "C": metrics_C["micro_precision"]},
            "novelty_subset_f1": {
                "A": novelty_subset_A["micro_f1"],
                "C": novelty_subset_C["micro_f1"],
            },
        },
        "elapsed_seconds": round(time.time() - t0, 1),
    }

    with open(OUT_PATH, "w") as f:
        json.dump(results, f, indent=2)

    # Console summary table mirroring the saved report.
    print(f"\n{'='*60}")
    print(f"Results saved to {OUT_PATH}")
    print(f"Elapsed: {results['elapsed_seconds']}s")
    print(f"\nSummary:")
    print(f" {'Strategy':<35} {'Prec':>6} {'Rec':>6} {'F1':>6} {'AvgN':>6} {'>15%':>6}")
    for name, m in [("A current webapp", metrics_A),
                    ("B precision+IC", metrics_B),
                    ("C novelty-gated", metrics_C)]:
        print(f" {name:<35} {m['micro_precision']:>6.4f} {m['micro_recall']:>6.4f} "
              f"{m['micro_f1']:>6.4f} {m['mean_preds_per_protein']:>6.1f} "
              f"{m['pct_noisy_gt15']:>5.1f}%")
|
|
|
|
# Script entry point: run the full threshold-comparison experiment.
if __name__ == "__main__":
    main()
|
|