| from fastapi import FastAPI |
| from fastapi.responses import FileResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.staticfiles import StaticFiles |
| from contextlib import asynccontextmanager |
| from pydantic import BaseModel |
| import torch |
| import torch.nn as nn |
| import pandas as pd |
| import joblib |
| import json |
| import math |
| import os |
| import re |
| import time |
| import hashlib |
| import warnings |
| from functools import lru_cache |
|
|
| warnings.filterwarnings("ignore") |
|
|
| |
| torch.set_num_threads(min(os.cpu_count() or 4, 8)) |
|
|
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) |
| STATIC_DIR = os.path.join(BASE_DIR, "static") |
| os.makedirs(STATIC_DIR, exist_ok=True) |
|
|
| HF_REPO = "Sbhat2026/protfunc-models" |
| |
| HF_FILES = [ |
| "unified_35M_v1.pth", "unified_35M_v1_thresholds.json", |
| "unified_v1.pth", "unified_v1_recalibrated.json", |
| "mammal_enriched.pth", "mammal_enriched_thresholds.json", |
| "protfunc_v3_fixed.pth", "protfunc_v3_fixed_thresholds.json", |
| "protfunc_v3.pth", "protfunc_v3_thresholds.json", |
| "improved_res.pth", "improved_per_label_thresholds.json", |
| "supp_res2.pth", |
| "baseline_res.pth", "mlb_public_v1.pkl", "go_names.json", |
| ] |
| OPTIONAL = { |
| "go_names.json", |
| "unified_35M_v1.pth", "unified_35M_v1_thresholds.json", |
| "mammal_enriched.pth", "mammal_enriched_thresholds.json", |
| "protfunc_v3_fixed.pth", "protfunc_v3_fixed_thresholds.json", |
| "protfunc_v3.pth", "protfunc_v3_thresholds.json", |
| "improved_res.pth", "improved_per_label_thresholds.json", |
| "supp_res2.pth", |
| } |
|
|
| |
| device = torch.device("cpu") |
| model = None |
| esm_model = None |
| batch_converter = None |
| mlb = None |
| go_map = {} |
| go_defs = {} |
| mf_terms = set() |
| go_parents = {} |
| go_ancestors = {} |
| go_depth = {} |
| go_replaced = {} |
| mf_indices = None |
| thresholds = {} |
| temperature = 1.0 |
| |
| insect_temperature = 1.0 |
| insect_threshold_default = 0.68 |
| NUM_LABELS = 0 |
| _ESM_DIM = 320 |
| |
| supp_mu = None |
| supp_sd = None |
| supp_cols = None |
| model_uses_supp = False |
| |
| taxon_probe = None |
| |
| platt_params = {} |
| |
| _uniprot_cache: dict = {} |
|
|
| |
| MIN_SEQ_LENGTH = 30 |
| MIN_ENTROPY_BITS = 2.5 |
| MAX_DOMINANT_FRAC = 0.60 |
| MIN_DISTINCT_AA = 5 |
| INVALID_AA = set("BJOUXZ") |
| MF_ROOT = "GO:0003674" |
|
|
| |
| _KD = {'A':1.8,'R':-4.5,'N':-3.5,'D':-3.5,'C':2.5,'Q':-3.5,'E':-3.5, |
| 'G':-0.4,'H':-3.2,'I':4.5,'L':3.8,'K':-3.9,'M':1.9,'F':2.8, |
| 'P':-1.6,'S':-0.8,'T':-0.7,'W':-0.9,'Y':-1.3,'V':4.2} |
| |
| _CF_HELIX = {'A':1.42,'R':0.98,'N':0.67,'D':1.01,'C':0.70,'Q':1.11,'E':1.51, |
| 'G':0.57,'H':1.00,'I':1.08,'L':1.21,'K':1.16,'M':1.45,'F':1.13, |
| 'P':0.57,'S':0.77,'T':0.83,'W':1.08,'Y':0.69,'V':1.06} |
| _CF_SHEET = {'A':0.83,'R':0.93,'N':0.89,'D':0.54,'C':1.19,'Q':1.10,'E':0.37, |
| 'G':0.75,'H':0.87,'I':1.60,'L':1.30,'K':0.74,'M':1.05,'F':1.38, |
| 'P':0.55,'S':0.75,'T':1.19,'W':1.37,'Y':1.47,'V':1.70} |
| |
| _DISORDER_PROMOTING = set("AERSQGKPTD") |
| _TM_HYDROPHOBIC = set("AVILMFYW") |
| _CHARGE = {"K": 1.0, "R": 1.0, "D": -1.0, "E": -1.0} |
|
|
| INSECT_LINEAGE = {"Insecta", "Hexapoda", "Arthropoda", "Chelicerata", "Myriapoda"} |
| MAMMAL_LINEAGE = {"Mammalia", "Theria", "Eutheria", "Metatheria", "Monotremata"} |
|
|
|
|
| def _detect_taxon_composition(seq: str) -> tuple: |
| """ |
| Heuristic taxon detection from amino acid composition. |
| Uses a linear discriminant calibrated from insect/mammal proteome statistics. |
    Returns ('insect' | 'mammal', confidence_float). Confident calls carry
    confidence >= 0.68; an ambiguous composition (or a sequence under 60 aa)
    falls back to ('mammal', 0.50).
| """ |
| n = len(seq) |
| if n < 60: |
| return "mammal", 0.50 |
| seq_u = seq.upper() |
| freq = {aa: seq_u.count(aa) / n for aa in "ACDEFGHIKLMNPQRSTVWY"} |
| |
| |
| score = 0.0 |
| score += (freq.get("K", 0) - 0.058) * 18.0 |
| score += (freq.get("G", 0) - 0.073) * -12.0 |
| score += (freq.get("A", 0) - 0.072) * -9.0 |
| score += (freq.get("S", 0) - 0.077) * 7.0 |
| score += (freq.get("T", 0) - 0.056) * 6.0 |
| score += (freq.get("P", 0) - 0.055) * -5.0 |
| score += (freq.get("R", 0) - 0.053) * 4.0 |
| p_mammal = 1.0 / (1.0 + math.exp(-score * 2.5)) |
| if p_mammal >= 0.68: |
| return "mammal", round(p_mammal, 3) |
| elif p_mammal <= 0.32: |
| return "insect", round(1.0 - p_mammal, 3) |
| return "mammal", 0.50 |
|
|
|
|
| def _detect_taxon_probe(emb_np) -> tuple: |
| """Use the trained logistic probe on ESM embedding if loaded.""" |
| if taxon_probe is None: |
| return None, 0.0 |
| import numpy as np |
| w = taxon_probe["coef"] |
| b = taxon_probe["intercept"] |
| mu = taxon_probe["scaler_mean"] |
| sd = taxon_probe["scaler_std"] |
| x = (emb_np - mu) / (sd + 1e-12) |
| logit = float(np.dot(x, w) + b) |
| p_mammal = 1.0 / (1.0 + math.exp(-logit)) |
| if p_mammal >= 0.65: |
| return "mammal", round(p_mammal, 3) |
| elif p_mammal <= 0.35: |
| return "insect", round(1.0 - p_mammal, 3) |
| return "mammal", 0.50 |
|
|
|
|
| FEATURE_META = [ |
| {"key": "f_seq_len", "label": "Sequence Length", "desc": "Global protein length (uniform per-residue)", "color": "#888888"}, |
| {"key": "f_mean_hydro", "label": "Hydrophobicity", "desc": "Kyte-Doolittle hydrophobicity (window=5)", "color": "#f4a261"}, |
| {"key": "f_net_charge", "label": "Net Charge", "desc": "Local charge balance K+R−D−E (window=9)", "color": "#457b9d"}, |
| {"key": "f_uversky_disorder", "label": "Disorder Score", "desc": "Uversky charge-hydrophobicity disorder criterion (window=11)", "color": "#9b5de5"}, |
| {"key": "f_idr_frac_proxy", "label": "IDR Residues", "desc": "Disorder-promoting residues: A,E,R,S,Q,G,K,P,T,D", "color": "#00b4d8"}, |
| {"key": "f_lowcomp_proxy", "label": "Low Complexity", "desc": "Repetitive amino acid runs (length ≥5)", "color": "#adb5bd"}, |
| {"key": "f_tm_frac_proxy", "label": "TM Helix Windows", "desc": "Transmembrane helix windows (≥17/20 hydrophobic residues)", "color": "#e63946"}, |
| {"key": "f_tm_any_proxy", "label": "TM Present", "desc": "Presence of any transmembrane window", "color": "#c1121f"}, |
| {"key": "f_signal_peptide_proxy", "label": "Signal Peptide", "desc": "N-terminal hydrophobic signal (linear decay, first 30 aa)", "color": "#2d6a4f"}, |
| {"key": "f_cf_helix_mean", "label": "α-Helix Propensity", "desc": "Chou-Fasman α-helix propensity per residue", "color": "#4361ee"}, |
| {"key": "f_cf_sheet_mean", "label": "β-Sheet Propensity", "desc": "Chou-Fasman β-sheet propensity per residue", "color": "#e76f51"}, |
| ] |
|
|
|
|
| def compute_seq_features(seq: str) -> dict: |
| """ |
| Compute the 11 sequence-based supplementary features that are always |
| available at inference time. Returns a dict keyed by SUPP_COL names. |
| AF-derived features (f_afdb_has_model, f_plddt_*, f_distbin_*, f_pae_*, |
| f_seqfeat_present, f_af_present) are set to 0 — they will be z-scored to |
| near-zero against training means where >99% of proteins also had no AF data. |
| """ |
| seq_u = seq.upper() |
| n = len(seq_u) |
| kd = [_KD.get(aa, 0.0) for aa in seq_u] |
|
|
| mean_hydro = sum(kd) / n |
| net_charge = (seq_u.count('R') + seq_u.count('K') - |
| seq_u.count('D') - seq_u.count('E')) / n |
| |
| uversky_disorder = float(abs(mean_hydro) - abs(net_charge) < 0.06) |
| idr_frac = sum(1 for aa in seq_u if aa in _DISORDER_PROMOTING) / n |
|
|
| |
| lowcomp = 0 |
    prev, run = '', 0
| for aa in seq_u: |
| run = run + 1 if aa == prev else 1 |
| if run >= 5: |
| lowcomp += 1 |
| prev = aa |
| lowcomp_proxy = lowcomp / n |
|
|
| |
| tm_count = 0 |
| for i in range(n - 19): |
| window = seq_u[i:i+20] |
| if sum(1 for aa in window if aa in _TM_HYDROPHOBIC) >= 17: |
| tm_count += 1 |
| tm_frac = tm_count / max(n - 19, 1) |
| tm_any = float(tm_count > 0) |
|
|
| |
| sp_window = seq_u[:30] |
| sp_proxy = float(sum(1 for aa in sp_window if aa in _TM_HYDROPHOBIC) >= 8) |
|
|
| |
| cf_helix = sum(_CF_HELIX.get(aa, 1.0) for aa in seq_u) / n |
| cf_sheet = sum(_CF_SHEET.get(aa, 1.0) for aa in seq_u) / n |
|
|
| return { |
| "f_seq_len": float(n), |
| "f_mean_hydro": float(mean_hydro), |
| "f_net_charge": float(net_charge), |
| "f_uversky_disorder": float(uversky_disorder), |
| "f_idr_frac_proxy": float(idr_frac), |
| "f_lowcomp_proxy": float(lowcomp_proxy), |
| "f_tm_frac_proxy": float(tm_frac), |
| "f_tm_any_proxy": float(tm_any), |
| "f_signal_peptide_proxy":float(sp_proxy), |
| "f_cf_helix_mean": float(cf_helix), |
| "f_cf_sheet_mean": float(cf_sheet), |
| |
| "f_afdb_has_model":0.0,"f_plddt_mean":0.0,"f_plddt_std":0.0, |
| "f_plddt_q10":0.0,"f_plddt_q50":0.0,"f_plddt_q90":0.0, |
| "f_plddt_frac_gt90":0.0,"f_plddt_frac_gt70":0.0,"f_plddt_frac_lt50":0.0, |
| "f_distbin_0":0.0,"f_distbin_1":0.0,"f_distbin_2":0.0,"f_distbin_3":0.0, |
| "f_distbin_4":0.0,"f_distbin_5":0.0,"f_distbin_6":0.0,"f_distbin_7":0.0, |
| "f_distbin_8":0.0,"f_distbin_9":0.0, |
| "f_pae_mean":0.0,"f_pae_median":0.0,"f_pae_p90":0.0,"f_pae_p95":0.0, |
| "f_pae_frac_lt5":0.0,"f_pae_frac_lt10":0.0,"f_pae_frac_gt20":0.0, |
| "f_seqfeat_present":1.0,"f_af_present":0.0, |
| } |
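# Example (hypothetical peptide): for a 33-aa input,
#   feats = compute_seq_features("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
# gives feats["f_seq_len"] == 33.0, feats["f_seqfeat_present"] == 1.0, and
# every AF-derived key (e.g. feats["f_af_present"]) fixed at 0.0.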
|
|
|
|
| def compute_per_residue_features(seq: str) -> dict: |
| """Return per-residue vectors (normalized [0,1]) for the 11 supp features.""" |
| seq_u = seq.upper() |
| n = len(seq_u) |
|
|
| def smooth(arr, w): |
| hw = w // 2 |
| out = [] |
| for i in range(n): |
| lo, hi = max(0, i - hw), min(n, i + hw + 1) |
| out.append(sum(arr[lo:hi]) / (hi - lo)) |
| return out |
|
|
| def normalize(arr): |
| mn, mx = min(arr), max(arr) |
| if mx == mn: |
| return [0.5] * n |
| return [(v - mn) / (mx - mn) for v in arr] |
|
|
| kd_raw = [_KD.get(aa, 0.0) for aa in seq_u] |
| charge_raw = [_CHARGE.get(aa, 0.0) for aa in seq_u] |
|
|
| |
| r_seq_len = [1.0] * n |
|
|
| |
| r_mean_hydro = normalize(smooth(kd_raw, 5)) |
|
|
| |
| r_net_charge = normalize(smooth(charge_raw, 9)) |
|
|
| |
| uv_raw = [] |
| for i in range(n): |
| lo, hi = max(0, i - 5), min(n, i + 6) |
| wkd = sum(kd_raw[lo:hi]) / (hi - lo) |
| wch = sum(charge_raw[lo:hi]) / (hi - lo) |
| uv_raw.append(max(0.0, 0.20 - (abs(wkd) - abs(wch)))) |
| r_uversky = normalize(uv_raw) |
|
|
| |
| idr_raw = [1.0 if aa in _DISORDER_PROMOTING else 0.0 for aa in seq_u] |
| r_idr = normalize(smooth(idr_raw, 7)) |
|
|
| |
| lowcomp_raw = [0.0] * n |
| prev, run = '', 0 |
| for j, aa in enumerate(seq_u): |
| run = run + 1 if aa == prev else 1 |
| prev = aa |
| if run >= 5: |
| for k in range(max(0, j - run + 1), j + 1): |
| lowcomp_raw[k] = 1.0 |
| r_lowcomp = lowcomp_raw |
|
|
| |
| tm_raw = [0.0] * n |
| if n >= 20: |
| for i in range(n - 19): |
| if sum(1 for aa in seq_u[i:i+20] if aa in _TM_HYDROPHOBIC) >= 17: |
| for k in range(i, i + 20): |
| tm_raw[k] = 1.0 |
| r_tm_frac = tm_raw |
|
|
| |
| r_tm_any = tm_raw |
|
|
| |
| sp_mod = [] |
| for i in range(n): |
| weight = max(0.0, 1.0 - i / 30.0) |
| sp_mod.append(weight * max(0.0, kd_raw[i] / 4.5)) |
| r_sp = normalize(sp_mod) if any(v > 0 for v in sp_mod) else [max(0.0, 1.0 - i / 30.0) for i in range(n)] |
|
|
| |
| r_cf_helix = normalize(smooth([_CF_HELIX.get(aa, 1.0) for aa in seq_u], 3)) |
| r_cf_sheet = normalize(smooth([_CF_SHEET.get(aa, 1.0) for aa in seq_u], 3)) |
|
|
| return { |
| "f_seq_len": [round(v, 3) for v in r_seq_len], |
| "f_mean_hydro": [round(v, 3) for v in r_mean_hydro], |
| "f_net_charge": [round(v, 3) for v in r_net_charge], |
| "f_uversky_disorder": [round(v, 3) for v in r_uversky], |
| "f_idr_frac_proxy": [round(v, 3) for v in r_idr], |
| "f_lowcomp_proxy": [round(v, 3) for v in r_lowcomp], |
| "f_tm_frac_proxy": [round(v, 3) for v in r_tm_frac], |
| "f_tm_any_proxy": [round(v, 3) for v in r_tm_any], |
| "f_signal_peptide_proxy": [round(v, 3) for v in r_sp], |
| "f_cf_helix_mean": [round(v, 3) for v in r_cf_helix], |
| "f_cf_sheet_mean": [round(v, 3) for v in r_cf_sheet], |
| } |
|
|
|
|
| def build_ancestor_cache(go_parents_map: dict) -> dict: |
| """Compute full transitive ancestor sets for all GO terms (memoized DFS).""" |
| cache = {} |
| def _anc(gid): |
| if gid in cache: |
| return cache[gid] |
| parents = go_parents_map.get(gid, set()) |
| all_anc = set(parents) |
| for p in parents: |
| all_anc |= _anc(p) |
| cache[gid] = all_anc |
| return all_anc |
| for gid in go_parents_map: |
| _anc(gid) |
| return cache |
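# Toy example (hypothetical IDs): for go_parents_map = {"GO:A": set(),
# "GO:B": {"GO:A"}, "GO:C": {"GO:B"}}, build_ancestor_cache returns
# {"GO:A": set(), "GO:B": {"GO:A"}, "GO:C": {"GO:A", "GO:B"}}.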
|
|
|
|
| def _download_with_retry(fname): |
| from huggingface_hub import hf_hub_download |
| max_attempts = 6 |
| for attempt in range(1, max_attempts + 1): |
| try: |
| print(f" [{attempt}/{max_attempts}] Downloading {fname}...") |
| path = hf_hub_download( |
| repo_id=HF_REPO, filename=fname, |
| local_dir=BASE_DIR, repo_type="model", |
| token=os.environ.get("HF_TOKEN"), |
| ) |
| print(f" saved -> {path}") |
| return |
| except Exception as e: |
| if fname in OPTIONAL: |
| print(f" {fname} is optional, skipping ({e})") |
| return |
| if attempt == max_attempts: |
| raise RuntimeError(f"Could not download '{fname}' after {max_attempts} attempts: {e}") |
| wait = 2 ** attempt |
| print(f" Network error, retrying in {wait}s... ({e})") |
| time.sleep(wait) |
|
|
|
|
| def ensure_model_files(): |
| missing = [f for f in HF_FILES if not os.path.exists(os.path.join(BASE_DIR, f))] |
| if not missing: |
| print("All model files already present.") |
| return |
| print(f"Downloading {len(missing)} file(s) from HuggingFace Hub...") |
| for fname in missing: |
| _download_with_retry(fname) |
|
|
|
|
| def load_go_map(): |
| try: |
| df = pd.read_csv(os.path.join(BASE_DIR, "go_annotations_fixed.csv")) |
| mapping = {} |
| for _, row in df.iterrows(): |
| go_id = str(row["GO Annotation"]).strip() |
| raw_name = str(row.get("Gene Ontology (molecular function)", "Unknown")) |
| mapping[go_id] = raw_name.split(" [")[0].strip() |
| print(f"GO map: {len(mapping)} labels loaded") |
| return mapping |
| except Exception as e: |
| print(f"GO map load error: {e}") |
| return {} |
|
|
|
|
| def load_thresholds(): |
| """ |
| Load per-label thresholds and return (mammal_thresholds, mammal_T, insect_T, insect_t_default). |
| |
| Threshold JSON formats accepted: |
| {"0": 0.68, ...} — plain float |
| {"0": {"threshold": 0.68, "tier": 0, "temperature": 3.69}} — rich dict |
| {"_meta": {...}, "0": 0.68, ...} — new format with metadata |
| |
| Returns: |
| flat : dict str_idx -> float (mammal per-label thresholds) |
| mammal_T : float temperature for mammal inference |
| ins_T : float temperature for insect inference (usually 1.0) |
| ins_t_default : float flat threshold for insect labels with no per-label data |
| """ |
| for path in [ |
| os.path.join(BASE_DIR, "unified_35M_v1_thresholds.json"), |
| os.path.join(BASE_DIR, "unified_v1_recalibrated.json"), |
| os.path.join(BASE_DIR, "unified_v1_thresholds.json"), |
| os.path.join(BASE_DIR, "mammal_enriched_thresholds.json"), |
| os.path.join(BASE_DIR, "protfunc_v3_fixed_thresholds.json"), |
| os.path.join(BASE_DIR, "improved_per_label_thresholds.json"), |
| os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"), |
| os.path.join(BASE_DIR, "per_label_thresholds.json"), |
| os.path.join(BASE_DIR, "artifacts", "per_label_thresholds.json"), |
| ]: |
        if not os.path.exists(path):
            continue
        with open(path) as f:
            raw = json.load(f)
        print(f"Thresholds loaded from {path}")
| flat = {} |
| mammal_T = 1.0 |
| ins_T = 1.0 |
| ins_t_def = 0.68 |
| |
| meta = raw.pop("_meta", None) |
| if meta: |
| mammal_T = float(meta.get("temperature", mammal_T)) |
| ins_T = float(meta.get("insect_temperature", ins_T)) |
| ins_t_def = float(meta.get("insect_global_t", ins_t_def)) |
| for k, v in raw.items(): |
| if isinstance(v, dict): |
| flat[k] = float(v.get("threshold", 0.5)) |
| mammal_T = float(v.get("temperature", mammal_T)) |
| else: |
| try: |
| flat[k] = float(v) |
| except (TypeError, ValueError): |
| pass |
| print(f" {len(flat)} mammal thresholds | mammal_T={mammal_T:.4f} | " |
| f"insect_T={ins_T:.4f} insect_t={ins_t_def:.2f}") |
| return flat, mammal_T, ins_T, ins_t_def |
| print("Thresholds not found, using defaults") |
| return {}, 1.0, 1.0, 0.68 |
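# Illustrative threshold file (values hypothetical), matching the formats above:
# {
#   "_meta": {"temperature": 3.69, "insect_temperature": 1.0, "insect_global_t": 0.68},
#   "0": 0.62,
#   "1": {"threshold": 0.71, "temperature": 3.69}
# }
# -> flat = {"0": 0.62, "1": 0.71}, mammal_T = 3.69, ins_T = 1.0, ins_t_def = 0.68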
|
|
|
|
| def parse_obo(path): |
| """ |
| Parse go-basic.obo and return: |
| mf_terms : set of active (non-obsolete) GO IDs with namespace == molecular_function |
        go_parents  : dict GO ID -> set of direct parent GO IDs
                      (is_a, part_of, and regulates-family relationships, MF only)
| go_names_ob : dict GO ID -> canonical name from OBO (authoritative) |
| go_replaced : dict obsolete GO ID -> replacement GO ID |
| go_depth : dict GO ID -> minimum depth from MF root (root = 0) |
| |
| All relationships are restricted to the MF namespace. |
| """ |
| ns_map = {} |
| par_map = {} |
| name_map = {} |
| def_map = {} |
| rep_map = {} |
| alt_map = {} |
| obs_set = set() |
|
|
| cur_id = None |
| cur_ns = None |
| cur_nm = None |
| cur_df = None |
| cur_par = set() |
| cur_rep = None |
| cur_obs = False |
| cur_alt = [] |
| in_term = False |
|
|
| def flush(): |
| nonlocal cur_id, cur_ns, cur_nm, cur_df, cur_par, cur_rep, cur_obs, cur_alt |
| if cur_id: |
| if cur_obs: |
| obs_set.add(cur_id) |
| if cur_rep: |
| rep_map[cur_id] = cur_rep |
| else: |
| ns_map[cur_id] = cur_ns or "" |
| par_map[cur_id] = cur_par |
| name_map[cur_id] = cur_nm or cur_id |
| if cur_df: |
| def_map[cur_id] = cur_df |
| for a in cur_alt: |
| alt_map[a] = cur_id |
| cur_id = None; cur_ns = None; cur_nm = None; cur_df = None |
| cur_par = set(); cur_rep = None; cur_obs = False; cur_alt = [] |
|
|
| with open(path, "r", encoding="utf-8") as fh: |
| for raw in fh: |
| line = raw.strip() |
| if line == "[Term]": |
| flush(); in_term = True; continue |
| if line.startswith("[") and line != "[Term]": |
| flush(); in_term = False; continue |
| if not in_term: |
| continue |
| if line.startswith("id:"): |
| cur_id = line.split("id:", 1)[1].strip().split()[0] |
| elif line.startswith("name:"): |
| cur_nm = line.split("name:", 1)[1].strip() |
| elif line.startswith("namespace:"): |
| cur_ns = line.split("namespace:", 1)[1].strip() |
| elif line.startswith("alt_id:"): |
| cur_alt.append(line.split("alt_id:", 1)[1].strip().split()[0]) |
| elif line.startswith("def:"): |
| |
| raw_def = line.split("def:", 1)[1].strip() |
| if raw_def.startswith('"'): |
| end_q = raw_def.find('"', 1) |
| cur_df = raw_def[1:end_q] if end_q > 0 else raw_def |
| else: |
| cur_df = raw_def.split("[")[0].strip() |
| elif line.startswith("is_obsolete:") and "true" in line: |
| cur_obs = True |
| elif line.startswith("replaced_by:"): |
| cur_rep = line.split("replaced_by:", 1)[1].strip().split()[0] |
| elif line.startswith("is_a:"): |
| parent = line.split("is_a:", 1)[1].strip().split()[0] |
| cur_par.add(parent) |
| elif line.startswith("relationship:"): |
| parts = line.split("relationship:", 1)[1].strip().split() |
| if len(parts) >= 2 and parts[0] in ("part_of", "regulates", |
| "positively_regulates", |
| "negatively_regulates"): |
| cur_par.add(parts[1]) |
| flush() |
|
|
| mf = {gid for gid, n in ns_map.items() if n == "molecular_function"} |
| go_parents_mf = {gid: (par_map[gid] & mf) for gid in mf} |
| go_names_ob = {gid: name_map[gid] for gid in mf} |
| go_defs_mf = {gid: def_map[gid] for gid in mf if gid in def_map} |
| n_edges = sum(len(v) for v in go_parents_mf.values()) |
| print(f"OBO parsed: {len(mf)} MF terms, {n_edges} parent edges, " |
| f"{len(rep_map)} replacements, {len(alt_map)} alt-ids, " |
| f"{len(go_defs_mf)} definitions") |
|
|
| |
| go_depth: dict = {} |
| go_depth[MF_ROOT] = 0 |
| |
| children: dict = {gid: set() for gid in mf} |
| for gid, parents in go_parents_mf.items(): |
| for p in parents: |
| if p in children: |
| children[p].add(gid) |
| queue = [MF_ROOT] |
| while queue: |
| nxt = [] |
| for gid in queue: |
| d = go_depth[gid] |
| for child in children.get(gid, ()): |
| if child not in go_depth: |
| go_depth[child] = d + 1 |
| nxt.append(child) |
| queue = nxt |
| print(f"Depth computed: {len(go_depth)} MF terms, " |
| f"max depth={max(go_depth.values(), default=0)}") |
|
|
| return mf, go_parents_mf, go_names_ob, rep_map, go_depth, go_defs_mf |
|
|
|
|
| def compute_dynamic_cap(sorted_probs: list, seq_len: int) -> int: |
| """ |
| Return a protein-proportional cap on the number of direct predictions. |
| |
| Combines three signals: |
| 1. Sequence-length prior (log2 scaling). |
| 2. Probability-gap detection — largest relative drop within budget. |
| 3. Diffuse-activation penalty — when predictions are bunched at low |
| confidence with no clear outlier, the cap is tightened. This prevents |
         the model from emitting many near-threshold terms for proteins it is
| uncertain about (e.g. sparse mammal annotations). |
| """ |
| n = len(sorted_probs) |
| if n == 0: |
| return 0 |
| if n <= 3: |
| return n |
|
|
| |
| length_prior = max(2, int(2.5 * math.log2(max(seq_len, 50) / 50) + 2)) |
| abs_cap = min(15, length_prior * 2) |
|
|
| |
| |
| |
| top_prob = sorted_probs[0] |
| spread = sorted_probs[0] - sorted_probs[-1] |
| if top_prob < 0.75 and spread < 0.12: |
| |
| abs_cap = max(3, length_prior) |
| elif top_prob < 0.72: |
| |
| abs_cap = max(3, min(abs_cap, length_prior + 2)) |
|
|
| search_end = min(n, abs_cap) |
| if search_end < 2: |
| return min(n, abs_cap) |
|
|
| best_score = -1.0 |
| best_idx = search_end |
|
|
| for i in range(1, search_end): |
| prev = sorted_probs[i - 1] |
| curr = sorted_probs[i] |
| rel_gap = (prev - curr) / (prev + 1e-6) |
| dist = abs(i - length_prior) / max(length_prior, 1) |
| score = rel_gap * (1.0 - 0.25 * min(dist, 1.0)) |
| if score > best_score and rel_gap >= 0.08: |
| best_score = score |
| best_idx = i |
|
|
| return max(3, best_idx) |
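# Worked example (hypothetical probabilities): compute_dynamic_cap(
#     [0.92, 0.88, 0.61, 0.58], seq_len=300) gives length_prior = 8, abs_cap = 15;
# the largest qualifying relative gap (0.88 -> 0.61, rel_gap ~= 0.31) falls at
# i = 2, so the cap is max(3, 2) = 3 and only the top three predictions survive.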
|
|
|
|
| def propagate_and_filter(preds, go_parents_map, go_ancestors_map, prob_map): |
| """ |
| 1. Propagate predictions upward: for every predicted term, all its MF |
| ancestors are implicitly predicted. Ancestors not already above threshold |
| are added as 'implied' predictions with the child's probability. |
| 2. Filter: a term is 'suppressed' only if it has MF parents and NONE of |
| them appear in the final visible set (direct or implied). |
| 3. Specificity: implied ancestors with depth ≤ 1 (root-adjacent, trivially |
| general terms) are hidden from the visible list. Each prediction carries |
| its depth so the UI can convey specificity. |
| |
| Returns (visible, suppressed) where visible includes informative implied parents. |
| """ |
| if not go_ancestors_map: |
| |
| for p in preds: |
| p["depth"] = go_depth.get(p["go_id"], -1) |
| return preds, [] |
|
|
| predicted_ids = {p["go_id"] for p in preds} |
| implied = {} |
|
|
| for pred in preds: |
| gid = pred["go_id"] |
| prob = pred["prob"] |
| for anc in go_ancestors_map.get(gid, set()): |
| if anc not in predicted_ids and anc != MF_ROOT: |
| implied[anc] = max(implied.get(anc, 0.0), prob) |
|
|
| all_visible_ids = predicted_ids | set(implied.keys()) |
|
|
| |
| suppressed = [] |
| direct_ok = [] |
| for pred in preds: |
| gid = pred["go_id"] |
| parents = go_parents_map.get(gid, set()) |
| pred["depth"] = go_depth.get(gid, -1) |
| if gid == MF_ROOT or not parents or (parents & all_visible_ids): |
| direct_ok.append(pred) |
| else: |
| pred["reason"] = "no_visible_parent" |
| suppressed.append(pred) |
|
|
| |
| |
| |
| MIN_IMPLIED_DEPTH = 2 |
| implied_preds = [] |
| for gid, prob in implied.items(): |
| d = go_depth.get(gid, -1) |
| if d < MIN_IMPLIED_DEPTH: |
| continue |
| implied_preds.append({ |
| "go_id": gid, |
| "name": go_map.get(gid, gid), |
| "prob": round(prob, 3), |
| "implied": True, |
| "depth": d, |
| }) |
| implied_preds.sort(key=lambda x: (-x["prob"], -x["depth"])) |
|
|
| visible = direct_ok + implied_preds |
| visible.sort(key=lambda x: (-x["prob"], -x.get("depth", 0))) |
| return visible, suppressed |
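# Mini example (hypothetical graph): with preds = [{"go_id": "GO:C", "prob": 0.9}],
# ancestors GO:C -> {GO:A, GO:B}, parents GO:C -> {GO:B}, and depths A=1, B=2, C=3:
# GO:B is added as an implied parent (depth 2 >= MIN_IMPLIED_DEPTH), GO:A is
# hidden as root-adjacent, and GO:C stays visible because GO:B appears in the set.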
|
|
|
|
| def sequence_entropy(seq): |
| seq_upper = seq.upper() |
| counts = {} |
| for aa in seq_upper: |
| counts[aa] = counts.get(aa, 0) + 1 |
| n = len(seq_upper) |
| return -sum((c / n) * math.log2(c / n) for c in counts.values()) |
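# Example: "MKMK" has two residue types at frequency 0.5 each, so
# H = -2 * (0.5 * log2(0.5)) = 1.0 bit; a homopolymer scores 0 bits.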
|
|
|
|
| def validate_sequence(name, seq): |
| """Returns an error string if the sequence should be rejected, else None.""" |
| if len(seq) < MIN_SEQ_LENGTH: |
| return (f"'{name}' is too short ({len(seq)} aa — minimum {MIN_SEQ_LENGTH} aa). " |
| f"Sequences this short are unlikely to fold into a stable domain.") |
|
|
| |
| non_letter = sorted({c for c in seq if not c.isalpha()}) |
| if non_letter: |
| display = ", ".join(f"'{c}'" for c in non_letter[:5]) |
| return (f"'{name}' contains non-amino-acid characters: {display}. " |
| f"Only single-letter amino acid codes are accepted.") |
|
|
| |
    # Heuristic DNA/RNA check: nearly all characters drawn from {A,T,C,G,U}
    # with very low alphabet diversity.
    seq_upper_set = {c.upper() for c in seq}
    nucleotide_frac = sum(seq.upper().count(c) for c in "ATCGU") / len(seq)
    if nucleotide_frac > 0.85 and len(seq_upper_set) <= 6:
| return (f"'{name}' appears to be a nucleotide sequence (DNA/RNA), not a protein. " |
| f"Please enter an amino acid sequence in single-letter code.") |
|
|
    bad = sorted({c.upper() for c in seq if c.upper() in INVALID_AA})
    if bad:
        return (f"'{name}' contains unsupported amino acid character(s): "
                f"{', '.join(bad)}. Ambiguity codes (B, J, X, Z) and the rare "
                f"residues U (Sec) and O (Pyl) are not accepted.")
|
|
| counts = {} |
| for aa in seq.upper(): |
| counts[aa] = counts.get(aa, 0) + 1 |
|
|
| if len(counts) < MIN_DISTINCT_AA: |
| return (f"'{name}' uses only {len(counts)} distinct residue type(s). " |
| f"Real proteins require at least {MIN_DISTINCT_AA}.") |
|
|
| dominant_frac = max(counts.values()) / len(seq) |
| if dominant_frac > MAX_DOMINANT_FRAC: |
| dominant_aa = max(counts, key=counts.get) |
| return (f"'{name}' is dominated by a single residue " |
| f"({dominant_aa} = {dominant_frac:.0%}). " |
| f"Low-complexity sequences produce unreliable embeddings.") |
|
|
| H = sequence_entropy(seq) |
| if H < MIN_ENTROPY_BITS: |
| return (f"'{name}' has very low sequence complexity " |
| f"(Shannon entropy {H:.2f} bits, minimum {MIN_ENTROPY_BITS:.1f} bits). " |
| f"Repetitive or artificially constructed sequences are not accepted.") |
|
|
| return None |
|
|
|
|
| @asynccontextmanager |
| async def lifespan(app: FastAPI): |
| global device, model, esm_model, batch_converter |
| global mlb, go_map, go_defs, mf_terms, go_parents, go_ancestors, go_depth, go_replaced |
| global mf_indices, thresholds, temperature, insect_temperature, insect_threshold_default, NUM_LABELS, _ESM_DIM |
| global supp_mu, supp_sd, supp_cols, model_uses_supp |
| global taxon_probe, platt_params |
|
|
| |
| ensure_model_files() |
|
|
| |
| go_map = load_go_map() |
| go_names_path = os.path.join(BASE_DIR, "go_names.json") |
| if os.path.exists(go_names_path): |
| with open(go_names_path) as f: |
| go_map.update(json.load(f)) |
| print(f"Canonical GO names loaded: {len(go_map)} total entries") |
|
|
| |
| mlb = joblib.load(os.path.join(BASE_DIR, "mlb_public_v1.pkl")) |
| NUM_LABELS = len(mlb.classes_) |
| print(f"MLB loaded: {NUM_LABELS} labels") |
|
|
| |
| obo_path = os.path.join(BASE_DIR, "go-basic.obo") |
| if os.path.exists(obo_path): |
| mf_terms, go_parents, go_names_obo, go_replaced, go_depth, go_defs_obo = parse_obo(obo_path) |
| go_defs.update(go_defs_obo) |
| |
| go_map.update(go_names_obo) |
| print(f"OBO names merged: {len(go_names_obo)} MF term names") |
| mf_in_mlb = sum(1 for gid in mlb.classes_ if gid in mf_terms) |
| rep_in_mlb = sum(1 for gid in mlb.classes_ if gid in go_replaced) |
| print(f"OBO cross-check: {mf_in_mlb}/{NUM_LABELS} active MF, " |
| f"{rep_in_mlb} replaced/obsolete labels remapped") |
| |
| go_ancestors = build_ancestor_cache(go_parents) |
| print(f"Ancestor cache built: {len(go_ancestors)} terms") |
| else: |
| print("WARNING: go-basic.obo not found — hierarchy filtering disabled. " |
| "Download from https://current.geneontology.org/ontology/go-basic.obo " |
| "and place it alongside server.py.") |
|
|
| |
| if mf_terms: |
| mf_indices = [i for i, gid in enumerate(mlb.classes_) if gid in mf_terms] |
| print(f"MF whitelist (OBO): {len(mf_indices)} active indices") |
| else: |
| mf_go_ids = { |
| go_id for go_id, name in go_map.items() |
| if name and name != go_id and not name.startswith("GO:") |
| } |
| mf_indices = [i for i, gid in enumerate(mlb.classes_) if gid in mf_go_ids] or list(range(NUM_LABELS)) |
| print(f"MF whitelist (CSV fallback): {len(mf_indices)} active indices") |
|
|
| app.state.mf_indices = mf_indices |
|
|
| |
| thresholds, temperature, insect_temperature, insect_threshold_default = load_thresholds() |
|
|
| |
|
|
| class ResBlock(nn.Module): |
| """Pre-activation residual block with BatchNorm (improved model).""" |
| def __init__(self, dim, dropout=0.2): |
| super().__init__() |
| self.net = nn.Sequential( |
| nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim), |
| nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim), |
| ) |
| def forward(self, x): |
| return x + self.net(x) |
|
|
| class ImprovedResidualMLP(nn.Module): |
| """4-block, hidden=2048, BatchNorm — trained by train_improved.py.""" |
| def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=2048, n_blocks=4, dropout=0.2): |
| super().__init__() |
| self.fc_in = nn.Linear(in_dim, hidden) |
| self.blocks = nn.ModuleList([ResBlock(hidden, dropout) for _ in range(n_blocks)]) |
| self.fc_out = nn.Sequential( |
| nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout), |
| nn.Linear(hidden, out_dim), |
| ) |
| def forward(self, x): |
| h = self.fc_in(x) |
| for blk in self.blocks: |
| h = blk(h) |
| return self.fc_out(h) |
|
|
| class ResidualMLP(nn.Module): |
| """Original 2-block notebook model (fallback).""" |
| def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=1024, dropout=0.2): |
| super().__init__() |
| self.fc_in = nn.Linear(in_dim, hidden) |
| self.block1 = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, hidden)) |
| self.block2 = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, hidden)) |
| self.fc_out = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, out_dim)) |
| def forward(self, x): |
| h = self.fc_in(x) |
| h = torch.relu(h) |
| h = h + self.block1(h) |
| h = h + self.block2(h) |
| return self.fc_out(h) |
|
|
| class RecoveredBaselineModel(nn.Module): |
| """Earlier server-side architecture — retained for backward compatibility.""" |
| def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=1024, dropout=0.2): |
| super().__init__() |
| self.fc1 = nn.Linear(in_dim, hidden) |
| self.proj = nn.Linear(in_dim, hidden) |
| self.fc2 = nn.Linear(hidden, hidden) |
| self.out = nn.Linear(hidden, out_dim) |
| self.relu = nn.ReLU() |
| self.drop = nn.Dropout(dropout) |
| def forward(self, x): |
| h = self.relu(self.fc1(x)) |
| h = h + self.proj(x) |
| h = self.relu(self.fc2(h)) |
| h = self.drop(h) |
| return self.out(h) |
|
|
| import numpy as np |
| device = torch.device("cpu") |
|
|
| |
| ckpt_candidates = [ |
| os.path.join(BASE_DIR, "unified_35M_v1.pth"), |
| os.path.join(BASE_DIR, "unified_v1.pth"), |
| os.path.join(BASE_DIR, "mammal_enriched.pth"), |
| os.path.join(BASE_DIR, "protfunc_v3_fixed.pth"), |
| os.path.join(BASE_DIR, "improved_res.pth"), |
| os.path.join(BASE_DIR, "protfunc_v3.pth"), |
| os.path.join(BASE_DIR, "supp_res2.pth"), |
| os.path.join(BASE_DIR, "baseline_res.pth"), |
| ] |
|
|
| _ESM_DIM = 320 |
| _model = None |
|
|
| for ckpt_path in ckpt_candidates: |
| if not os.path.exists(ckpt_path): |
| continue |
|
|
| print(f"Trying classifier: {os.path.basename(ckpt_path)}") |
| try: |
| ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) |
| except Exception as e: |
| print(f" Failed to load {os.path.basename(ckpt_path)}: {e} — skipping") |
| continue |
|
|
| sd = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt |
| keys = set(sd.keys()) |
|
|
| |
| _c_supp_mu = None |
| _c_supp_sd = None |
| _c_supp_cols = None |
| _c_uses_supp = False |
|
|
| if isinstance(ckpt, dict) and "supp_mu" in ckpt: |
| _c_supp_mu = np.array(ckpt["supp_mu"], dtype=np.float32) |
| _c_supp_sd = np.array(ckpt["supp_sd"], dtype=np.float32) |
| _c_supp_cols = ckpt["supp_cols"] |
| _c_uses_supp = True |
|
|
| |
| |
| |
| |
| _ckpt_has_explicit_in_dim = isinstance(ckpt, dict) and "in_dim" in ckpt |
| |
| _out_dim_ckpt = None |
| for _out_key in ("fc_out.3.weight", "fc_out.2.weight", "out.weight"): |
| if _out_key in sd: |
| _out_dim_ckpt = sd[_out_key].shape[0] |
| break |
|
|
| if "blocks.0.net.0.weight" in keys: |
| hidden_dim = sd["fc_in.weight"].shape[0] |
| n_blocks = sum(1 for k in keys if k.startswith("blocks.") and k.endswith(".net.0.weight")) |
| in_dim_ckpt = sd["fc_in.weight"].shape[1] |
| if _c_uses_supp and _c_supp_cols is not None and not _ckpt_has_explicit_in_dim: |
| expected = _ESM_DIM + len(_c_supp_cols) + 1 |
| if in_dim_ckpt != expected: |
| print(f" SKIP: supp metadata says in_dim={expected} but fc_in has {in_dim_ckpt} — corrupted checkpoint") |
| continue |
| out_dim_use = _out_dim_ckpt or NUM_LABELS |
| _model = ImprovedResidualMLP(in_dim=in_dim_ckpt, hidden=hidden_dim, n_blocks=n_blocks, out_dim=out_dim_use).to(device) |
| print(f" ImprovedResidualMLP (in={in_dim_ckpt} hidden={hidden_dim} blocks={n_blocks} out={out_dim_use})") |
| elif any(k.startswith("fc_in") for k in keys): |
| in_dim_ckpt = sd["fc_in.weight"].shape[1] |
| if _c_uses_supp and _c_supp_cols is not None and not _ckpt_has_explicit_in_dim: |
| expected = _ESM_DIM + len(_c_supp_cols) + 1 |
| if in_dim_ckpt != expected: |
| print(f" SKIP: supp metadata says in_dim={expected} but fc_in has {in_dim_ckpt} — corrupted checkpoint") |
| continue |
| out_dim_use = _out_dim_ckpt or NUM_LABELS |
| _model = ResidualMLP(in_dim=in_dim_ckpt, out_dim=out_dim_use).to(device) |
| print(f" ResidualMLP (in={in_dim_ckpt} out={out_dim_use})") |
| elif any(k.startswith("fc1") for k in keys): |
| _c_uses_supp = False |
| out_dim_use = _out_dim_ckpt or NUM_LABELS |
| _model = RecoveredBaselineModel(out_dim=out_dim_use).to(device) |
| print(f" RecoveredBaselineModel (legacy out={out_dim_use})") |
| else: |
| print(f" SKIP: unrecognised architecture — keys: {sorted(keys)[:8]}") |
| continue |
|
|
| try: |
| _model.load_state_dict(sd, strict=True) |
| except Exception as e: |
| print(f" SKIP: load_state_dict failed: {e}") |
| _model = None |
| continue |
|
|
| |
| supp_mu = _c_supp_mu |
| supp_sd = _c_supp_sd |
| supp_cols = _c_supp_cols |
| model_uses_supp = _c_uses_supp |
| if isinstance(ckpt, dict) and "val_fmax" in ckpt: |
| print(f" val_fmax={ckpt['val_fmax']:.4f} epoch={ckpt.get('epoch','?')}") |
| if _c_uses_supp: |
| print(f" Supplemented model: {len(_c_supp_cols)} extra features") |
| |
| if isinstance(ckpt, dict) and "esm_dim" in ckpt: |
| _ESM_DIM = int(ckpt["esm_dim"]) |
| print(f" ESM dim from checkpoint: {_ESM_DIM}") |
| print(f"Classifier loaded: {os.path.basename(ckpt_path)}") |
| break |
|
|
| if _model is None: |
| raise RuntimeError("No valid classifier checkpoint found.") |
| _model.eval() |
| model = _model |
|
|
| |
| import esm as esm_lib |
| if _ESM_DIM == 480: |
| _esm_model, alphabet = esm_lib.pretrained.esm2_t12_35M_UR50D() |
| print("ESM-2 (35M, 480-dim) loaded OK") |
| else: |
| _esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D() |
| print("ESM-2 (8M, 320-dim) loaded OK") |
| esm_model = _esm_model.to(device).eval() |
| batch_converter = alphabet.get_batch_converter() |
|
|
| |
| probe_path = os.path.join(BASE_DIR, "taxon_probe.json") |
| if os.path.exists(probe_path): |
| with open(probe_path) as f: |
| taxon_probe = json.load(f) |
| acc = taxon_probe.get("train_accuracy", 0) |
| probe_esm_dim = taxon_probe.get("esm_dim", 320) |
| if probe_esm_dim != _ESM_DIM: |
| print(f"Taxon probe ESM dim mismatch ({probe_esm_dim} vs {_ESM_DIM}) — disabling probe") |
| taxon_probe = None |
| else: |
| for k in ("coef", "intercept", "scaler_mean", "scaler_std"): |
| if k in taxon_probe: |
| taxon_probe[k] = np.asarray(taxon_probe[k], dtype=np.float32) |
| print(f"Taxon probe loaded (train_acc={acc:.4f}, esm_dim={probe_esm_dim})") |
| else: |
| print("Taxon probe not found — using composition heuristic for auto-detection") |
|
|
| |
| platt_path = os.path.join(BASE_DIR, "platt_mammal.json") |
| if os.path.exists(platt_path): |
| with open(platt_path) as f: |
| platt_params = json.load(f) |
| print(f"Platt scaling loaded: {len(platt_params)} labels calibrated") |
| else: |
| print("Platt params not found — using temperature scaling only") |
|
|
| |
| temp_path = os.path.join(BASE_DIR, "temperature_best.json") |
| if os.path.exists(temp_path): |
| with open(temp_path) as f: |
| temp_data = json.load(f) |
| new_T = float(temp_data.get("optimal_T", temperature)) |
| if abs(new_T - temperature) > 0.1: |
| print(f"Temperature updated by calibration sweep: {temperature:.4f} → {new_T:.4f}") |
| temperature = new_T |
| else: |
| print(f"Temperature unchanged by sweep: {temperature:.4f}") |
|
|
| yield |
|
|
| print("Shutting down.") |
|
|
|
|
| app = FastAPI(lifespan=lifespan) |
| app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) |
| app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") |
|
|
|
|
| @app.get("/") |
| async def root(): |
| return FileResponse(os.path.join(STATIC_DIR, "interface.html"), headers={"Cache-Control": "no-store"}) |
|
|
|
|
| @app.get("/api/model/info") |
| async def model_info(): |
| """Return model metadata and configuration.""" |
| unified_35M = os.path.exists(os.path.join(BASE_DIR, "unified_35M_v1.pth")) |
| unified_v1 = os.path.exists(os.path.join(BASE_DIR, "unified_v1.pth")) |
| mammal_enriched = os.path.exists(os.path.join(BASE_DIR, "mammal_enriched.pth")) |
| v3_fixed = os.path.exists(os.path.join(BASE_DIR, "protfunc_v3_fixed.pth")) |
| improved = os.path.exists(os.path.join(BASE_DIR, "improved_res.pth")) |
| |
| if unified_35M and model_uses_supp and _ESM_DIM == 480: |
| name, version, active = "ProtFunc v5.0 (35M ESM, GOA-enriched, insect+mammal)", "5.0.0", "unified_35M_v1" |
| elif unified_v1 and model_uses_supp: |
| name, version, active = "ProtFunc v4.0 (unified insect+mammal, CAFA5-evaluated)", "4.0.0", "unified_v1" |
| elif mammal_enriched and model_uses_supp: |
| name, version, active = "ProtFunc v3.2 (mammal-enriched, CAFA5-evaluated)", "3.2.0", "mammal_enriched" |
| elif v3_fixed and model_uses_supp: |
| name, version, active = "ProtFunc v3-fixed (ablation best, CAFA-correct)", "3.1.0", "protfunc_v3_fixed" |
| elif model_uses_supp: |
| name, version, active = "ProtFunc v3 (supplemented + mammal)", "3.0.0", "protfunc_v3" |
| elif improved: |
| name, version, active = "ProtFunc Enhanced", "2.1.0", "improved" |
| else: |
| name, version, active = "ProtFunc", "2.0.0", "baseline" |
| return { |
| "model_name": name, |
| "model": active, |
| "version": version, |
| "esm_model": "esm2_t12_35M_UR50D" if _ESM_DIM == 480 else "esm2_t6_8M_UR50D", |
| "embed_dim": _ESM_DIM, |
| "num_labels": NUM_LABELS, |
| "supported_namespaces": ["molecular_function"], |
| "max_sequence_length": 1500, |
| "thresholds_loaded": len(thresholds) > 0, |
| "temperature_scaling": temperature != 1.0, |
| "temperature": round(temperature, 4), |
| "insect_temperature": round(insect_temperature, 4), |
| "insect_threshold_default": round(insect_threshold_default, 4), |
| "supported_taxa": ["insect", "mammal"], |
| "taxon_routing": True, |
| "hierarchy_filtering": len(go_parents) > 0, |
| "parental_propagation": len(go_ancestors) > 0, |
| "depth_annotation": len(go_depth) > 0, |
| "mf_terms_loaded": len(mf_terms) if mf_terms else 0, |
| "mf_max_depth": max(go_depth.values(), default=0) if go_depth else 0, |
| "supplemented_features": model_uses_supp, |
| "supp_feature_count": len(supp_cols) if supp_cols else 0, |
| } |
|
|
|
|
| @app.get("/api/generalization") |
| async def get_generalization(): |
| """ |
| Return cross-taxon generalization results from eval_generalization.py output. |
| Serves artifacts/generalization_results.json if present, otherwise returns empty. |
| """ |
| candidates = [ |
| os.path.join(BASE_DIR, "artifacts", "generalization", "generalization_results.json"), |
| os.path.join(BASE_DIR, "generalization_results.json"), |
| ] |
| for path in candidates: |
| if os.path.exists(path): |
| with open(path) as f: |
| data = json.load(f) |
| return { |
| "available": True, |
| "taxa": list(data.keys()), |
| "results": data, |
| } |
| return {"available": False, "taxa": [], "results": {}} |
|
|
|
|
| @app.get("/api/structure") |
| async def get_structure(uniprot_id: str): |
| """Look up AlphaFold structure data for a UniProt accession.""" |
| import urllib.request, urllib.error |
| uid = uniprot_id.upper().strip() |
| if not re.match(r'^[A-Z0-9]{4,10}$', uid): |
| return {"found": False, "error": "Invalid accession format"} |
| try: |
| req = urllib.request.Request( |
| f"https://alphafold.ebi.ac.uk/api/prediction/{uid}", |
| headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"} |
| ) |
| with urllib.request.urlopen(req, timeout=8) as resp: |
| entries = json.loads(resp.read()) |
| d = entries[0] |
| return { |
| "found": True, |
| "accession": uid, |
| "organism": d.get("organismScientificName", ""), |
| "gene": d.get("gene", ""), |
| "cif_url": d.get("cifUrl", ""), |
| "pae_image_url": d.get("paeImageUrl", ""), |
| "entry_url": f"https://alphafold.ebi.ac.uk/entry/{uid}", |
| "uniprot_url": f"https://www.uniprot.org/uniprot/{uid}", |
| "model_version": d.get("latestVersion", 4), |
| } |
| except urllib.error.HTTPError as e: |
| if e.code == 404: |
| return {"found": False, "accession": uid, |
| "uniprot_url": f"https://www.uniprot.org/uniprot/{uid}"} |
| return {"found": False, "error": f"HTTP {e.code}"} |
| except Exception as e: |
| return {"found": False, "error": str(e)[:100]} |
|
|
|
|
| @app.get("/api/uniprot/annotations") |
| async def get_uniprot_annotations(uniprot_id: str): |
| """ |
| Fetch GO-MF annotations and organism info from UniProt REST API. |
| Returns known annotations (evidence-coded) for comparison with predictions. |
| """ |
| import urllib.request, urllib.error |
| uid = uniprot_id.upper().strip() |
| if not re.match(r'^[A-Z0-9]{4,10}$', uid): |
| return {"found": False, "error": "Invalid accession format"} |
|
|
| if uid in _uniprot_cache: |
| return _uniprot_cache[uid] |
|
|
| try: |
| url = f"https://rest.uniprot.org/uniprotkb/{uid}.json?fields=go,organism,protein_name,gene_names,organism_lineage" |
| req = urllib.request.Request(url, headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"}) |
| with urllib.request.urlopen(req, timeout=10) as resp: |
| data = json.loads(resp.read()) |
| except urllib.error.HTTPError as e: |
| if e.code == 404: |
| return {"found": False, "accession": uid, "error": "Not found in UniProt"} |
| return {"found": False, "error": f"HTTP {e.code}"} |
| except Exception as e: |
| return {"found": False, "error": str(e)[:100]} |
|
|
| |
| organism = data.get("organism", {}) |
| org_name = organism.get("scientificName", "") |
| lineage = [x.get("scientificName", "") for x in organism.get("lineage", [])] |
| if any(t in lineage for t in INSECT_LINEAGE): |
| detected_taxon = "insect" |
| elif any(t in lineage for t in MAMMAL_LINEAGE): |
| detected_taxon = "mammal" |
| else: |
| detected_taxon = "auto" |
|
|
| |
| pn_block = data.get("proteinDescription", {}) |
| prot_name = "" |
| if "recommendedName" in pn_block: |
| prot_name = pn_block["recommendedName"].get("fullName", {}).get("value", "") |
| elif "submissionNames" in pn_block: |
| prot_name = pn_block["submissionNames"][0].get("fullName", {}).get("value", "") |
|
|
| gene_names = [] |
| for gn in data.get("genes", []): |
| if "geneName" in gn: |
| gene_names.append(gn["geneName"]["value"]) |
|
|
| |
| go_mf = [] |
| for xref in data.get("uniProtKBCrossReferences", []): |
| if xref.get("database") != "GO": |
| continue |
| go_id = xref.get("id", "") |
| props = {p["key"]: p["value"] for p in xref.get("properties", [])} |
| term_str = props.get("GoTerm", "") |
| if not term_str.startswith("F:"): |
| continue |
| evidence = props.get("GoEvidenceType", "") |
| ev_code = evidence.split(":")[0] if ":" in evidence else evidence |
| go_name = term_str[2:].strip() |
| |
| display_name = go_map.get(go_id, go_name) |
| |
| exp_codes = {"EXP","IDA","IPI","IMP","IGI","IEP","HTP","HDA","HMP","HGI","HEP"} |
| comp_codes = {"ISS","ISO","ISA","ISM","IGC","IBA","IBD","IKR","IRD","RCA"} |
| if ev_code in exp_codes: |
| ev_tier = "experimental" |
| elif ev_code in comp_codes: |
| ev_tier = "computational" |
| elif ev_code == "IEA": |
| ev_tier = "electronic" |
| else: |
| ev_tier = "other" |
| go_mf.append({ |
| "go_id": go_id, |
| "name": display_name, |
| "evidence": ev_code, |
| "ev_tier": ev_tier, |
| }) |
|
|
| |
| tier_rank = {"experimental": 0, "computational": 1, "other": 2, "electronic": 3} |
| seen = {} |
| for entry in go_mf: |
| gid = entry["go_id"] |
| if gid not in seen or tier_rank[entry["ev_tier"]] < tier_rank[seen[gid]["ev_tier"]]: |
| seen[gid] = entry |
| go_mf = sorted(seen.values(), key=lambda x: (tier_rank[x["ev_tier"]], x["name"])) |
|
|
| result = { |
| "found": True, |
| "accession": uid, |
| "protein_name": prot_name, |
| "gene_names": gene_names, |
| "organism": org_name, |
| "lineage": lineage[-5:] if lineage else [], |
| "detected_taxon": detected_taxon, |
| "go_mf": go_mf, |
| "n_experimental": sum(1 for e in go_mf if e["ev_tier"] == "experimental"), |
| "n_total": len(go_mf), |
| } |
| _uniprot_cache[uid] = result |
| return result |
|
|
|
|
| @app.get("/api/health") |
| async def health_check(): |
| """Health check endpoint for monitoring.""" |
| return { |
| "status": "healthy", |
| "model_loaded": model is not None, |
| "esm_loaded": esm_model is not None, |
| "labels": NUM_LABELS, |
| } |
|
|
|
|
| class ProteinRequest(BaseModel): |
| sequence: str |
| taxon: str = "auto" |
| uniprot_id: str = "" |
|
|
|
|
| class SaliencyRequest(BaseModel): |
| sequence: str |
| uniprot_id: str = "" |
| taxon: str = "auto" |
| top_k: int = 20 |
|
|
|
|
| class ExplainRequest(BaseModel): |
| sequence: str |
| uniprot_id: str = "" |
| taxon: str = "auto" |
| top_k: int = 10 |
|
|
|
|
| class BatchPredictRequest(BaseModel): |
| sequences: list |
| threshold: float = 0.5 |
| include_suppressed: bool = True |
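# Example request body for /api/predict/batch (illustrative):
# {"sequences": [{"name": "seq1", "sequence": "MKVL..."}],
#  "threshold": 0.6, "include_suppressed": false}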
|
|
|
|
| def parse_sequences(text): |
| text = text.strip() |
| if text.startswith(">"): |
| blocks = re.split(r"(>.*)", text) |
| names, seqs = [], [] |
| i = 1 |
| while i < len(blocks): |
| name = blocks[i][1:].strip() |
| seq = re.sub(r"\s+", "", blocks[i + 1]) if i + 1 < len(blocks) else "" |
| if seq: |
| names.append(name) |
| seqs.append(seq) |
| i += 2 |
| return list(zip(names, seqs)) |
| seqs = [line.strip() for line in text.splitlines() if line.strip()] |
| return [(f"Sequence {i + 1}", s) for i, s in enumerate(seqs)] |
|
|
|
|
| def _build_model_input(emb: torch.Tensor, sequence: str) -> torch.Tensor: |
| """ |
| Build the model input tensor from an ESM-2 embedding. |
| For supplemented models: appends z-scored seq/structural features. |
| For base models: returns the embedding as-is. |
| |
| Handles two supplemented feature sets: |
| - v3 models (in_dim=360): 39 sequence-derived features + 1 missingness flag |
| - supp_res2 (in_dim=709): 388 features including f_Dim_* (ESM dims), AF |
| structural features (zeros at inference), and sequence features |
| """ |
| import numpy as np |
| if not model_uses_supp or supp_mu is None: |
| return emb.unsqueeze(0) |
|
|
| feats = compute_seq_features(sequence) |
|
|
| |
| |
| emb_np = emb.detach().cpu().numpy() |
| for c in supp_cols: |
| if c.startswith('f_Dim_'): |
| try: |
| dim_idx = int(c.split('_')[-1]) |
| if dim_idx < len(emb_np): |
| feats[c] = float(emb_np[dim_idx]) |
| except (ValueError, IndexError): |
| pass |
|
|
| s_vec = np.array([feats.get(c, 0.0) for c in supp_cols], dtype=np.float32) |
| s_z = (s_vec - supp_mu) / (supp_sd + 1e-12) |
| |
| flag = np.array([1.0], dtype=np.float32) |
| extra = torch.from_numpy(np.concatenate([s_z, flag])) |
| full_input = torch.cat([emb, extra], dim=0).unsqueeze(0) |
|
|
| |
| |
| |
| |
| if model is not None and hasattr(model, 'fc_in'): |
| expected_dim = model.fc_in.weight.shape[1] |
| if full_input.shape[1] != expected_dim: |
| if full_input.shape[1] > expected_dim: |
| full_input = full_input[:, :expected_dim] |
| else: |
| return emb.unsqueeze(0) |
|
|
| return full_input |
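# Dimension check (per the docstring above): for a v3-style checkpoint the
# returned tensor has shape (1, 360) = 320-dim ESM embedding + 39 z-scored
# supplementary features + 1 missingness flag.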
|
|
|
|
| |
| _ESM_CACHE: dict = {} |
| _ESM_CACHE_MAX = 256 |
|
|
def _get_esm_embedding(sequence: str) -> torch.Tensor:
    """Return the mean-pooled ESM-2 embedding, using an in-memory FIFO cache."""
    key = hashlib.md5(sequence.encode()).hexdigest()
    if key in _ESM_CACHE:
        return _ESM_CACHE[key]
    # Use the final layer of whichever ESM-2 variant was loaded: layer 12 for
    # esm2_t12_35M_UR50D (480-dim), layer 6 for esm2_t6_8M_UR50D (320-dim).
    # A hardcoded layer 6 would silently read a mid-stack layer of the 35M model.
    layer = 12 if _ESM_DIM == 480 else 6
    _, _, tokens = batch_converter([("p", sequence)])
    with torch.no_grad():
        rep = esm_model(tokens.to(device), repr_layers=[layer])["representations"][layer]
    emb = rep[0, 1:len(sequence) + 1].mean(0).cpu()
    if len(_ESM_CACHE) >= _ESM_CACHE_MAX:
        _ESM_CACHE.pop(next(iter(_ESM_CACHE)))  # evict oldest entry (insertion order)
    _ESM_CACHE[key] = emb
    return emb
|
|
|
|
| @app.post("/predict") |
| async def predict(request: ProteinRequest): |
| import numpy as np |
| import urllib.request, urllib.error |
|
|
| entries = parse_sequences(request.sequence) |
| results = [] |
| device_cpu = torch.device("cpu") |
| mf_idx = app.state.mf_indices |
| uid = (request.uniprot_id or "").upper().strip() |
| req_taxon = (request.taxon or "auto").lower() |
|
|
| |
| uniprot_data = None |
| detected_taxon_from_uniprot = None |
| if uid and re.match(r'^[A-Z0-9]{4,10}$', uid): |
| try: |
| url = f"https://rest.uniprot.org/uniprotkb/{uid}.json?fields=go,organism,protein_name,gene_names,organism_lineage" |
| req_obj = urllib.request.Request(url, headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"}) |
| with urllib.request.urlopen(req_obj, timeout=8) as resp: |
| raw = json.loads(resp.read()) |
| organism = raw.get("organism", {}) |
| lineage = [x.get("scientificName", "") for x in organism.get("lineage", [])] |
| if any(t in lineage for t in INSECT_LINEAGE): |
| detected_taxon_from_uniprot = "insect" |
| elif any(t in lineage for t in MAMMAL_LINEAGE): |
| detected_taxon_from_uniprot = "mammal" |
| |
| go_mf_known = {} |
| tier_rank = {"experimental": 0, "computational": 1, "other": 2, "electronic": 3} |
| exp_ev = {"EXP","IDA","IPI","IMP","IGI","IEP","HTP","HDA","HMP","HGI","HEP"} |
| comp_ev = {"ISS","ISO","ISA","ISM","IGC","IBA","IBD","IKR","IRD","RCA"} |
| for xref in raw.get("uniProtKBCrossReferences", []): |
| if xref.get("database") != "GO": |
| continue |
| props = {p["key"]: p["value"] for p in xref.get("properties", [])} |
| if not props.get("GoTerm", "").startswith("F:"): |
| continue |
| go_id = xref.get("id", "") |
| ev_code = props.get("GoEvidenceType", "").split(":")[0] |
| ev_tier = "experimental" if ev_code in exp_ev else ( |
| "computational" if ev_code in comp_ev else ( |
| "electronic" if ev_code == "IEA" else "other")) |
| entry = {"go_id": go_id, "name": go_map.get(go_id, props["GoTerm"][2:]), |
| "evidence": ev_code, "ev_tier": ev_tier} |
| if go_id not in go_mf_known or tier_rank[ev_tier] < tier_rank[go_mf_known[go_id]["ev_tier"]]: |
| go_mf_known[go_id] = entry |
| uniprot_data = { |
| "go_mf_known": sorted(go_mf_known.values(), key=lambda x: (tier_rank[x["ev_tier"]], x["name"])), |
| "organism": organism.get("scientificName", ""), |
| "detected_taxon": detected_taxon_from_uniprot, |
| } |
| except Exception: |
| pass |
|
|
| for name, sequence in entries: |
| err = validate_sequence(name, sequence) |
| if err: |
| results.append({"name": name, "error": err}) |
| continue |
| if len(sequence) > 1500: |
| results.append({"name": name, "error": "Sequence too long (max 1500 aa)"}) |
| continue |
|
|
| try: |
| emb = _get_esm_embedding(sequence).to(device_cpu) |
| emb_np = emb.detach().cpu().numpy() |
|
|
| |
| taxon_source = req_taxon |
| taxon_conf = 1.0 |
| if req_taxon == "auto": |
| if detected_taxon_from_uniprot: |
| taxon_source = detected_taxon_from_uniprot |
| taxon_conf = 1.0 |
| elif taxon_probe is not None: |
| taxon_source, taxon_conf = _detect_taxon_probe(emb_np) |
| else: |
| taxon_source, taxon_conf = _detect_taxon_composition(sequence) |
|
|
| if taxon_source == "insect": |
| t_apply = insect_temperature |
| thresh_lookup = {} |
| t_default = insect_threshold_default |
| mammal_floor = 0.0 |
| else: |
| t_apply = temperature |
| thresh_lookup = thresholds |
| t_default = 0.68 |
| mammal_floor = 0.56 |
|
|
| |
| with torch.no_grad(): |
| inp = _build_model_input(emb, sequence) |
| logits = model(inp).squeeze() |
|
|
| |
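            # Per-label Platt scaling (mammal path only): p_i = sigmoid(A_i * logit_i + B_i);
            # labels without fitted (A, B) fall back to temperature scaling,
            # p_i = sigmoid(logit_i / T).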
| if platt_params and taxon_source != "insect": |
| probs_list = [] |
| for i in range(len(logits)): |
| l = float(logits[i]) |
| if str(i) in platt_params: |
| A, B = platt_params[str(i)] |
| p = 1.0 / (1.0 + math.exp(-(A * l + B))) |
| else: |
| p = 1.0 / (1.0 + math.exp(-l / t_apply)) |
| probs_list.append(p) |
| prob = torch.tensor(probs_list) |
| else: |
| prob = torch.sigmoid(logits / t_apply) |
|
|
| if prob.dim() == 0: |
| prob = prob.unsqueeze(0) |
|
|
| |
| raw_preds = [] |
| prob_map = {} |
| for i in mf_idx: |
| pv = float(prob[i]) |
| label_thresh = max(float(thresh_lookup.get(str(i), t_default)), mammal_floor) |
| if pv >= label_thresh: |
| go_id = mlb.classes_[i] |
| display_id = go_replaced.get(go_id, go_id) |
| display_nm = go_map.get(display_id, go_map.get(go_id, go_id)) |
| entry = {"go_id": display_id, "name": display_nm, |
| "prob": round(pv, 4), "depth": go_depth.get(display_id, -1)} |
| if display_id != go_id: |
| entry["original_id"] = go_id |
| raw_preds.append(entry) |
| prob_map[display_id] = pv |
| raw_preds.sort(key=lambda x: x["prob"], reverse=True) |
| cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(sequence)) |
| raw_preds = raw_preds[:cap] |
| for rp in raw_preds: |
| prob_map[rp["go_id"]] = rp["prob"] |
|
|
| visible, suppressed = propagate_and_filter(raw_preds, go_parents, go_ancestors, prob_map) |
|
|
| result = { |
| "name": name, |
| "sequence_length": len(sequence), |
| "predictions": visible, |
| "suppressed": suppressed, |
| "n_above_threshold": len(raw_preds), |
| "n_implied_parents": sum(1 for p in visible if p.get("implied")), |
| "taxon_applied": taxon_source, |
| "taxon_source": "uniprot" if detected_taxon_from_uniprot and req_taxon == "auto" |
| else ("probe" if taxon_probe and req_taxon == "auto" |
| else ("composition" if req_taxon == "auto" else "manual")), |
| "taxon_confidence": round(taxon_conf, 3), |
| "temperature_applied": round(t_apply, 4), |
| "platt_applied": bool(platt_params) and taxon_source != "insect", |
| } |
| if uniprot_data: |
| result["uniprot"] = uniprot_data |
| results.append(result) |
|
|
| except Exception as e: |
| results.append({"name": name, "error": str(e)}) |
|
|
| return {"results": results} |
|
|
|
|
| @app.post("/api/predict/batch") |
| async def predict_batch(request: BatchPredictRequest): |
| """ |
| Batch prediction endpoint for multiple sequences. |
| |
| Accepts a list of sequence objects and returns predictions for all. |
| More efficient than multiple single predictions due to batching. |
| """ |
| results = [] |
| mf_idx = app.state.mf_indices |
| custom_threshold = request.threshold |
| |
| for item in request.sequences: |
| name = item.get("name", "Unknown") |
| sequence = item.get("sequence", "") |
| |
| |
| err = validate_sequence(name, sequence) |
| if err: |
| results.append({"name": name, "error": err}) |
| continue |
| |
| if len(sequence) > 1500: |
| results.append({"name": name, "error": "Sequence too long (max 1500 aa)"}) |
| continue |
| |
| try: |
| emb = _get_esm_embedding(sequence).to(device) |
| with torch.no_grad(): |
| inp = _build_model_input(emb, sequence) |
| prob = torch.sigmoid(model(inp) / temperature).squeeze() |
|
|
| if prob.dim() == 0: |
| prob = prob.unsqueeze(0) |
|
|
| raw_preds = [] |
| prob_map = {} |
| for i in mf_idx: |
| pv = float(prob[i]) |
| thresh = float(thresholds.get(str(i), custom_threshold)) |
| if pv >= thresh: |
| go_id = mlb.classes_[i] |
| display_id = go_replaced.get(go_id, go_id) |
| display_nm = go_map.get(display_id, go_map.get(go_id, go_id)) |
| entry = { |
| "go_id": display_id, |
| "name": display_nm, |
| "prob": round(pv, 4), |
| "depth": go_depth.get(display_id, -1), |
| } |
| if display_id != go_id: |
| entry["original_id"] = go_id |
| raw_preds.append(entry) |
| prob_map[display_id] = pv |
| raw_preds.sort(key=lambda x: x["prob"], reverse=True) |
| cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(sequence)) |
| raw_preds = raw_preds[:cap] |
| for rp in raw_preds: |
| prob_map[rp["go_id"]] = rp["prob"] |
|
|
| visible, suppressed = propagate_and_filter( |
| raw_preds, go_parents, go_ancestors, prob_map |
| ) |
| if not request.include_suppressed: |
| suppressed = [] |
|
|
| results.append({ |
| "name": name, |
| "sequence_length": len(sequence), |
| "predictions": visible, |
| "suppressed": suppressed, |
| "n_above_threshold": len(raw_preds), |
| "n_implied_parents": sum(1 for p in visible if p.get("implied")), |
| }) |
| except Exception as e: |
| results.append({"name": name, "error": str(e)}) |
| |
| return { |
| "results": results, |
| "total": len(results), |
| "successful": sum(1 for r in results if "error" not in r), |
| } |
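| # Illustrative client sketch (comments only, nothing runs on import). |
| # Field names mirror BatchPredictRequest as used above; the host and port |
| # are taken from the __main__ block at the bottom of this file, and the |
| # sequence value is a placeholder that must pass validate_sequence. |
| # |
| #   import requests |
| #   payload = { |
| #       "sequences": [{"name": "query1", "sequence": "MKT..."}], |
| #       "threshold": 0.5, |
| #       "include_suppressed": False, |
| #   } |
| #   r = requests.post("http://localhost:7860/api/predict/batch", json=payload) |
| #   data = r.json() |
| #   print(data["successful"], "of", data["total"], "sequences succeeded") |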
|
|
|
|
| class GoTermsRequest(BaseModel): |
| sequence: str |
| uniprot_id: str = "" |
| taxon: str = "auto" |
| top_k: int = 20 |
| include_implied: bool = False |
| min_prob: float = 0.0 |
|
|
|
|
| _go_terms_cache: dict = {} |
|
|
|
|
| @app.post("/api/go_terms") |
| async def get_go_terms(request: GoTermsRequest): |
| """ |
| Lightweight GO-MF prediction endpoint for programmatic/pipeline use. |
| Returns only predicted term list — no suppressed, no taxon explanation, |
| no Platt metadata. ~2× faster than /predict for downstream tools. |
| Results are cached by (sequence_hash, uniprot_id, taxon). |
| """ |
| seq = request.sequence.strip() |
| cache_key = (hashlib.md5(seq.encode()).hexdigest(), request.uniprot_id.upper(), request.taxon) |
| if cache_key in _go_terms_cache: |
| return _go_terms_cache[cache_key] |
|
|
| err = validate_sequence("seq", seq) |
| if err: |
| return {"error": err, "predictions": []} |
| if len(seq) > 1500: |
| return {"error": "Sequence too long (max 1500 aa)", "predictions": []} |
|
|
| try: |
| mf_idx = app.state.mf_indices |
| emb = _get_esm_embedding(seq).to(device) |
| emb_np = emb.detach().cpu().numpy() |
|
|
| taxon_source = request.taxon |
| taxon_conf = 1.0 |
| if taxon_source == "auto": |
| if taxon_probe is not None: |
| taxon_source, taxon_conf = _detect_taxon_probe(emb_np) |
| else: |
| taxon_source, taxon_conf = _detect_taxon_composition(seq) |
|
|
| if taxon_source == "insect": |
| t_apply = insect_temperature |
| thresh_lookup = {} |
| t_default = insect_threshold_default |
| mammal_floor = 0.0 |
| else: |
| t_apply = temperature |
| thresh_lookup = thresholds |
| t_default = 0.68 |
| mammal_floor = 0.56 |
|
|
| with torch.no_grad(): |
| inp = _build_model_input(emb, seq) |
| logits = model(inp).squeeze() |
|
|
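| # Per-label Platt recalibration: p = sigmoid(A * logit + B), with (A, B) |
| # fit offline per label; labels without fitted parameters fall back to |
| # plain temperature scaling. The insect path skips Platt entirely. |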
| if platt_params and taxon_source != "insect": |
| probs_list = [] |
| for i in range(len(logits)): |
| l = float(logits[i]) |
| if str(i) in platt_params: |
| A, B = platt_params[str(i)] |
| p = 1.0 / (1.0 + math.exp(-(A * l + B))) |
| else: |
| p = 1.0 / (1.0 + math.exp(-l / t_apply)) |
| probs_list.append(p) |
| prob = torch.tensor(probs_list) |
| else: |
| prob = torch.sigmoid(logits / t_apply) |
|
|
| if prob.dim() == 0: |
| prob = prob.unsqueeze(0) |
|
|
| raw_preds = [] |
| for i in mf_idx: |
| pv = float(prob[i]) |
| label_thresh = max(float(thresh_lookup.get(str(i), t_default)), mammal_floor) |
| if pv >= label_thresh: |
| go_id = mlb.classes_[i] |
| display_id = go_replaced.get(go_id, go_id) |
| display_nm = go_map.get(display_id, go_map.get(go_id, go_id)) |
| raw_preds.append({ |
| "go_id": display_id, |
| "name": display_nm, |
| "prob": round(pv, 4), |
| "depth": go_depth.get(display_id, -1), |
| }) |
|
|
| raw_preds.sort(key=lambda x: x["prob"], reverse=True) |
| cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(seq)) |
| raw_preds = raw_preds[:cap] |
|
|
| if not request.include_implied: |
| predictions = raw_preds[:request.top_k] |
| else: |
| prob_map = {p["go_id"]: p["prob"] for p in raw_preds} |
| visible, _ = propagate_and_filter(raw_preds, go_parents, go_ancestors, prob_map) |
| # Keep the implied ancestors; the previous filter dropped them, so |
| # include_implied=True never actually added implied terms. |
| predictions = visible[:request.top_k] |
|
|
| if request.min_prob > 0: |
| predictions = [p for p in predictions if p["prob"] >= request.min_prob] |
|
|
| result = { |
| "predictions": predictions, |
| "n_predicted": len(predictions), |
| "taxon_applied": taxon_source, |
| "taxon_conf": round(taxon_conf, 3), |
| "taxon_source_method": ( |
| "probe" if taxon_probe is not None and request.taxon == "auto" |
| else ("composition" if request.taxon == "auto" else "manual") |
| ), |
| "platt_applied": bool(platt_params) and taxon_source != "insect", |
| "threshold_default": t_default, |
| } |
| if len(_go_terms_cache) > 2048: |
| _go_terms_cache.clear()  # crude eviction so the cache cannot grow without bound |
| _go_terms_cache[cache_key] = result |
| return result |
|
|
| except Exception as e: |
| return {"error": str(e)[:300], "predictions": []} |
|
|
|
|
| @app.get("/api/explain_terms") |
| async def explain_terms(ids: str = ""): |
| """ |
| Return name + definition for a comma-separated list of GO IDs. |
| Used by the frontend "Why?" panel to show term descriptions inline. |
| """ |
| if not ids: |
| return {"terms": []} |
| result = [] |
| for gid in ids.split(",")[:50]: |
| gid = gid.strip() |
| if not gid: |
| continue |
| name = go_map.get(gid, gid) |
| defn = go_defs.get(gid, "") |
| result.append({"id": gid, "name": name, "definition": defn}) |
| return {"terms": result} |
|
|
|
|
| @app.post("/api/saliency") |
| async def compute_saliency(request: SaliencyRequest): |
| """ |
| Compute per-residue gradient saliency for a protein sequence. |
| Uses d(sum_of_top_k_probs)/d(ESM_residue_representations) via backprop. |
| Optionally fetches AlphaFold structure if uniprot_id is provided. |
| Returns normalized per-residue importance scores in [0, 1]. |
| """ |
| import urllib.request |
|
|
| sequence = re.sub(r"\s+", "", request.sequence.upper()) |
| if not sequence: |
| return {"error": "Empty sequence"} |
| if len(sequence) > 1200: |
| return {"error": "Sequence too long for saliency (max 1200 aa)"} |
|
|
| err = validate_sequence("query", sequence) |
| if err: |
| return {"error": err} |
|
|
| taxon = (request.taxon or "auto").lower() |
| t_apply = insect_temperature if taxon == "insect" else temperature |
|
|
| try: |
| _, _, tokens = batch_converter([("p", sequence)]) |
| tokens = tokens.to(device) |
| L = len(sequence) |
|
|
| with torch.enable_grad(): |
| |
| out = esm_model(tokens, repr_layers=[6]) |
| residue_reps = out["representations"][6] |
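| # residue_reps is a non-leaf tensor, so autograd discards its gradient by |
| # default; retain_grad() keeps it readable after backward(). |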
| residue_reps.retain_grad() |
|
|
| |
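| # Mean-pool the residue representations into one embedding; index 0 is |
| # the BOS token, so residues occupy positions 1..L. |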
| emb = residue_reps[0, 1:L + 1].mean(0) |
|
|
| |
| if model_uses_supp and supp_mu is not None: |
| feats = compute_seq_features(sequence) |
| s_vec = torch.tensor( |
| [(feats.get(c, 0.0) - float(supp_mu[j])) / (float(supp_sd[j]) + 1e-12) |
| for j, c in enumerate(supp_cols)], |
| dtype=torch.float32, device=device |
| ) |
| |
| flag = torch.ones(1, dtype=torch.float32, device=device) |
| inp_full = torch.cat([emb, s_vec, flag]).unsqueeze(0) |
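| # Trim to the input width fc_in expects, in case the checkpoint was |
| # trained on a shorter feature vector. |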
| expected = model.fc_in.weight.shape[1] |
| inp = inp_full[:, :expected] |
| else: |
| inp = emb.unsqueeze(0) |
|
|
| logits = model(inp) / t_apply |
| probs = torch.sigmoid(logits[0]) |
|
|
| |
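| # Saliency objective: sum of the top-k probabilities, with k tracking the |
| # number of confident terms (> 0.3), floored at 5 and hard-capped. |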
| k = min(request.top_k, int((probs > 0.3).sum().item()), 8124) |
| k = max(k, 5) |
| top_vals = probs.topk(k).values |
| objective = top_vals.sum() |
| objective.backward() |
|
|
| scores = [0.0] * L |
| if residue_reps.grad is not None: |
| grad = residue_reps.grad[0, 1:L + 1].norm(dim=-1).detach().cpu().numpy() |
| mn, mx = grad.min(), grad.max() |
| scores = ((grad - mn) / (mx - mn + 1e-8)).tolist() |
|
|
| |
| structure = None |
| uid = request.uniprot_id.upper().strip() |
| if uid and re.match(r'^[A-Z0-9]{4,10}$', uid): |
| try: |
| req = urllib.request.Request( |
| f"https://alphafold.ebi.ac.uk/api/prediction/{uid}", |
| headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"} |
| ) |
| with urllib.request.urlopen(req, timeout=8) as resp: |
| d = json.loads(resp.read())[0] |
| structure = { |
| "found": True, |
| "accession": uid, |
| "cif_url": d.get("cifUrl", ""), |
| "pdb_url": d.get("pdbUrl", ""), |
| "organism": d.get("organismScientificName", ""), |
| "gene": d.get("gene", ""), |
| "entry_url": f"https://alphafold.ebi.ac.uk/entry/{uid}", |
| "uniprot_url": f"https://www.uniprot.org/uniprot/{uid}", |
| } |
| except Exception: |
| structure = {"found": False, "accession": uid} |
|
|
| return { |
| "sequence_length": L, |
| "per_residue_scores": scores, |
| "taxon": taxon, |
| "structure": structure, |
| } |
| except Exception as e: |
| return {"error": str(e)[:200]} |
|
|
|
|
| @app.post("/api/explainability") |
| async def compute_explainability(request: ExplainRequest): |
| """ |
| Compute feature-level importance for the 11 sequence features via gradient × input. |
| Also returns per-residue feature maps for 3D structure coloring. |
| Gradient flows only through the first 11 supp features; ESM embedding is detached. |
| """ |
| import numpy as np |
| import urllib.request |
|
|
| sequence = re.sub(r"\s+", "", request.sequence.upper()) |
| if not sequence: |
| return {"error": "Empty sequence"} |
| if len(sequence) > 1200: |
| return {"error": "Sequence too long for explainability (max 1200 aa)"} |
| err = validate_sequence("query", sequence) |
| if err: |
| return {"error": err} |
| if not model_uses_supp or supp_mu is None: |
| return {"error": "Feature importance requires a supplemented model (unified_v1 or later)"} |
|
|
| taxon = (request.taxon or "auto").lower() |
| t_apply = insect_temperature if taxon == "insect" else temperature |
|
|
| try: |
| emb = _get_esm_embedding(sequence).to(device) |
|
|
| feats = compute_seq_features(sequence) |
| s_vec = np.array([feats.get(c, 0.0) for c in supp_cols], dtype=np.float32) |
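| # Standardize features with the stored training-set mean/sd; the epsilon |
| # guards zero-variance columns. |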
| s_z = (s_vec - supp_mu) / (supp_sd + 1e-12) |
|
|
| n_tracked = min(11, len(supp_cols)) |
| s_z_11 = torch.tensor(s_z[:n_tracked], requires_grad=True, |
| dtype=torch.float32, device=device) |
| s_z_rest = torch.tensor(s_z[n_tracked:], dtype=torch.float32, device=device) |
| flag = torch.ones(1, dtype=torch.float32, device=device) |
|
|
| inp_full = torch.cat([emb.detach(), s_z_11, s_z_rest, flag]).unsqueeze(0) |
| expected = model.fc_in.weight.shape[1] |
| inp = inp_full[:, :expected] |
|
|
| with torch.enable_grad(): |
| logits = model(inp) / t_apply |
| probs = torch.sigmoid(logits[0]) |
| k = min(request.top_k, int((probs > 0.3).sum().item()), 8124) |
| k = max(k, 5) |
| probs.topk(k).values.sum().backward() |
|
|
| if s_z_11.grad is None: |
| return {"error": "Gradient computation failed — no gradient on supp features"} |
|
|
| grad_np = s_z_11.grad.detach().cpu().numpy() |
| s_z_11np = s_z_11.detach().cpu().numpy() |
| attribution = grad_np * s_z_11np |
|
|
| per_residue_maps = compute_per_residue_features(sequence) |
|
|
| feat_meta_by_key = {fm["key"]: fm for fm in FEATURE_META} |
| features = [] |
| for i in range(n_tracked): |
| col = supp_cols[i] |
| meta = feat_meta_by_key.get(col, {"key": col, "label": col, "desc": col, "color": "#888888"}) |
| attr = float(attribution[i]) |
| features.append({ |
| "key": col, |
| "label": meta["label"], |
| "desc": meta["desc"], |
| "color": meta["color"], |
| "importance": round(attr, 4), |
| "abs_importance": round(abs(attr), 4), |
| "per_residue": per_residue_maps.get(col, [0.5] * len(sequence)), |
| }) |
| features.sort(key=lambda x: x["abs_importance"], reverse=True) |
|
|
| |
| structure = None |
| uid = request.uniprot_id.upper().strip() |
| if uid and re.match(r'^[A-Z0-9]{4,10}$', uid): |
| try: |
| req = urllib.request.Request( |
| f"https://alphafold.ebi.ac.uk/api/prediction/{uid}", |
| headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"} |
| ) |
| with urllib.request.urlopen(req, timeout=8) as resp: |
| d = json.loads(resp.read())[0] |
| structure = { |
| "found": True, |
| "accession": uid, |
| "cif_url": d.get("cifUrl", ""), |
| "organism": d.get("organismScientificName", ""), |
| "gene": d.get("gene", ""), |
| "entry_url": f"https://alphafold.ebi.ac.uk/entry/{uid}", |
| "uniprot_url": f"https://www.uniprot.org/uniprot/{uid}", |
| } |
| except Exception: |
| structure = {"found": False, "accession": uid} |
|
|
| return { |
| "sequence_length": len(sequence), |
| "features": features, |
| "top_feature": features[0]["key"] if features else None, |
| "structure": structure, |
| "taxon": taxon, |
| } |
| except Exception as e: |
| return {"error": str(e)[:300]} |
|
|
|
|
| if __name__ == "__main__": |
| import uvicorn |
| uvicorn.run(app, host="0.0.0.0", port=7860) |
|
|