from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
from pydantic import BaseModel
import torch
import torch.nn as nn
import pandas as pd
import joblib
import json
import math
import os
import re
import time
import hashlib
import warnings
from functools import lru_cache

warnings.filterwarnings("ignore")

# Limit torch intra-op CPU threads used for ESM inference to min(cpu_count, 8);
# falls back to 4 threads when os.cpu_count() cannot determine the core count.
torch.set_num_threads(min(os.cpu_count() or 4, 8))

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(BASE_DIR, "static")
os.makedirs(STATIC_DIR, exist_ok=True)

HF_REPO = "Sbhat2026/protfunc-models"

# Artifacts fetched from HF_REPO into BASE_DIR by ensure_model_files() when they
# are missing locally. List order is only the download order; the checkpoint
# load priority is defined by ckpt_candidates in lifespan():
#   35M > unified_v1 > mammal_enriched > v3_fixed > improved > v3 > supp_res2 > baseline
HF_FILES = [
    "unified_35M_v1.pth",
    "unified_35M_v1_thresholds.json",
    "unified_v1.pth",
    "unified_v1_recalibrated.json",
    "mammal_enriched.pth",
    "mammal_enriched_thresholds.json",
    "protfunc_v3_fixed.pth",
    "protfunc_v3_fixed_thresholds.json",
    "protfunc_v3.pth",
    "protfunc_v3_thresholds.json",
    "improved_res.pth",
    "improved_per_label_thresholds.json",
    "supp_res2.pth",
    "baseline_res.pth",
    "mlb_public_v1.pkl",
    "go_names.json",
]
# Files whose download may fail without aborting startup: _download_with_retry
# logs and skips them on the FIRST error (no retry). Every other HF_FILES entry
# (unified_v1.pth, unified_v1_recalibrated.json, baseline_res.pth,
# mlb_public_v1.pkl) is required: it is retried with backoff and a final
# failure raises RuntimeError.
OPTIONAL = {
    "go_names.json",
    "unified_35M_v1.pth",
    "unified_35M_v1_thresholds.json",
    "mammal_enriched.pth",
    "mammal_enriched_thresholds.json",
    "protfunc_v3_fixed.pth",
    "protfunc_v3_fixed_thresholds.json",
    "protfunc_v3.pth",
    "protfunc_v3_thresholds.json",
    "improved_res.pth",
    "improved_per_label_thresholds.json",
    "supp_res2.pth",
}

# Globals populated during lifespan startup (reassigned there via `global`)
device = torch.device("cpu")  # always CPU on HF Space
model = None
esm_model = None
batch_converter = None
mlb = None
go_map = {}  # GO ID -> display name (CSV names, then go_names.json, then OBO names)
go_defs = {}  # GO ID -> definition string (from OBO def: field)
mf_terms = set()  # active molecular_function GO IDs from the OBO
go_parents = {}  # GO ID -> set of direct parent GO IDs (MF DAG)
go_ancestors = {}  # GO ID -> full set of ancestor GO IDs (transitive)
go_depth = {}  # GO ID -> min depth from MF root (root = 0)
go_replaced = {}  # obsolete GO ID -> replacement GO ID
mf_indices = None  # label indices kept by the MF whitelist (set in lifespan)
thresholds = {}  # label_idx (str) -> float threshold (mammal-calibrated)
temperature = 1.0  # temperature scaling T for mammal inference (logit/T before sigmoid)
# Insect-specific inference params (flat threshold, no temperature scaling needed)
insect_temperature = 1.0
insect_threshold_default = 0.68
NUM_LABELS = 0  # len(mlb.classes_), set in lifespan
_ESM_DIM = 320  # updated to 480 when unified_35M_v1 is loaded
# Supplemented model stats (loaded from checkpoint if present)
supp_mu = None  # np.ndarray shape (SUPP_DIM,)
supp_sd = None  # np.ndarray shape (SUPP_DIM,)
supp_cols = None  # list[str]
model_uses_supp = False  # True when model expects supp features
# Taxon probe (logistic regression on ESM embeddings, loaded from taxon_probe.json)
taxon_probe = None  # dict with scaler_mean, scaler_std, coef, intercept
# Platt scaling (per-label logistic regression on logits, loaded from platt_mammal.json)
platt_params = {}  # label_idx str -> [A, B]
# UniProt taxonomy + annotation cache
_uniprot_cache: dict = {}

# Biological complexity filter constants (applied by validate_sequence)
MIN_SEQ_LENGTH = 30
MIN_ENTROPY_BITS = 2.5
MAX_DOMINANT_FRAC = 0.60
MIN_DISTINCT_AA = 5
# Rejected residue letters: ambiguity codes B/J/X/Z plus non-standard O and U
INVALID_AA = set("BJOUXZ")
MF_ROOT = "GO:0003674"  # GO molecular_function root term

# Kyte-Doolittle hydrophobicity scale
_KD = {'A':1.8,'R':-4.5,'N':-3.5,'D':-3.5,'C':2.5,'Q':-3.5,'E':-3.5,
       'G':-0.4,'H':-3.2,'I':4.5,'L':3.8,'K':-3.9,'M':1.9,'F':2.8,
       'P':-1.6,'S':-0.8,'T':-0.7,'W':-0.9,'Y':-1.3,'V':4.2}
# Chou-Fasman helix/sheet propensities
_CF_HELIX = {'A':1.42,'R':0.98,'N':0.67,'D':1.01,'C':0.70,'Q':1.11,'E':1.51,
             'G':0.57,'H':1.00,'I':1.08,'L':1.21,'K':1.16,'M':1.45,'F':1.13,
             'P':0.57,'S':0.77,'T':0.83,'W':1.08,'Y':0.69,'V':1.06}
_CF_SHEET = {'A':0.83,'R':0.93,'N':0.89,'D':0.54,'C':1.19,'Q':1.10,'E':0.37,
             'G':0.75,'H':0.87,'I':1.60,'L':1.30,'K':0.74,'M':1.05,'F':1.38,
             'P':0.55,'S':0.75,'T':1.19,'W':1.37,'Y':1.47,'V':1.70}
# Disorder-promoting residues (Uversky)
_DISORDER_PROMOTING = set("AERSQGKPTD") _TM_HYDROPHOBIC = set("AVILMFYW") _CHARGE = {"K": 1.0, "R": 1.0, "D": -1.0, "E": -1.0} INSECT_LINEAGE = {"Insecta", "Hexapoda", "Arthropoda", "Chelicerata", "Myriapoda"} MAMMAL_LINEAGE = {"Mammalia", "Theria", "Eutheria", "Metatheria", "Monotremata"} def _detect_taxon_composition(seq: str) -> tuple: """ Heuristic taxon detection from amino acid composition. Uses a linear discriminant calibrated from insect/mammal proteome statistics. Returns ('insect'|'mammal', confidence_float). Confidence < 0.60 → ambiguous, caller should treat as 'auto'. """ n = len(seq) if n < 60: return "mammal", 0.50 seq_u = seq.upper() freq = {aa: seq_u.count(aa) / n for aa in "ACDEFGHIKLMNPQRSTVWY"} # Linear discriminant (positive = mammal, negative = insect) # Derived from empirical insect vs mammal proteome frequency differences score = 0.0 score += (freq.get("K", 0) - 0.058) * 18.0 # Lys enriched in mammals score += (freq.get("G", 0) - 0.073) * -12.0 # Gly enriched in insects score += (freq.get("A", 0) - 0.072) * -9.0 # Ala enriched in insects score += (freq.get("S", 0) - 0.077) * 7.0 # Ser enriched in mammals score += (freq.get("T", 0) - 0.056) * 6.0 # Thr enriched in mammals score += (freq.get("P", 0) - 0.055) * -5.0 # Pro enriched in insects score += (freq.get("R", 0) - 0.053) * 4.0 # Arg enriched in mammals p_mammal = 1.0 / (1.0 + math.exp(-score * 2.5)) if p_mammal >= 0.68: return "mammal", round(p_mammal, 3) elif p_mammal <= 0.32: return "insect", round(1.0 - p_mammal, 3) return "mammal", 0.50 # uncertain → default mammal def _detect_taxon_probe(emb_np) -> tuple: """Use the trained logistic probe on ESM embedding if loaded.""" if taxon_probe is None: return None, 0.0 import numpy as np w = taxon_probe["coef"] b = taxon_probe["intercept"] mu = taxon_probe["scaler_mean"] sd = taxon_probe["scaler_std"] x = (emb_np - mu) / (sd + 1e-12) logit = float(np.dot(x, w) + b) p_mammal = 1.0 / (1.0 + math.exp(-logit)) if p_mammal >= 0.65: return "mammal", 
round(p_mammal, 3) elif p_mammal <= 0.35: return "insect", round(1.0 - p_mammal, 3) return "mammal", 0.50 FEATURE_META = [ {"key": "f_seq_len", "label": "Sequence Length", "desc": "Global protein length (uniform per-residue)", "color": "#888888"}, {"key": "f_mean_hydro", "label": "Hydrophobicity", "desc": "Kyte-Doolittle hydrophobicity (window=5)", "color": "#f4a261"}, {"key": "f_net_charge", "label": "Net Charge", "desc": "Local charge balance K+R−D−E (window=9)", "color": "#457b9d"}, {"key": "f_uversky_disorder", "label": "Disorder Score", "desc": "Uversky charge-hydrophobicity disorder criterion (window=11)", "color": "#9b5de5"}, {"key": "f_idr_frac_proxy", "label": "IDR Residues", "desc": "Disorder-promoting residues: A,E,R,S,Q,G,K,P,T,D", "color": "#00b4d8"}, {"key": "f_lowcomp_proxy", "label": "Low Complexity", "desc": "Repetitive amino acid runs (length ≥5)", "color": "#adb5bd"}, {"key": "f_tm_frac_proxy", "label": "TM Helix Windows", "desc": "Transmembrane helix windows (≥17/20 hydrophobic residues)", "color": "#e63946"}, {"key": "f_tm_any_proxy", "label": "TM Present", "desc": "Presence of any transmembrane window", "color": "#c1121f"}, {"key": "f_signal_peptide_proxy", "label": "Signal Peptide", "desc": "N-terminal hydrophobic signal (linear decay, first 30 aa)", "color": "#2d6a4f"}, {"key": "f_cf_helix_mean", "label": "α-Helix Propensity", "desc": "Chou-Fasman α-helix propensity per residue", "color": "#4361ee"}, {"key": "f_cf_sheet_mean", "label": "β-Sheet Propensity", "desc": "Chou-Fasman β-sheet propensity per residue", "color": "#e76f51"}, ] def compute_seq_features(seq: str) -> dict: """ Compute the 11 sequence-based supplementary features that are always available at inference time. Returns a dict keyed by SUPP_COL names. AF-derived features (f_afdb_has_model, f_plddt_*, f_distbin_*, f_pae_*, f_seqfeat_present, f_af_present) are set to 0 — they will be z-scored to near-zero against training means where >99% of proteins also had no AF data. 
""" seq_u = seq.upper() n = len(seq_u) kd = [_KD.get(aa, 0.0) for aa in seq_u] mean_hydro = sum(kd) / n net_charge = (seq_u.count('R') + seq_u.count('K') - seq_u.count('D') - seq_u.count('E')) / n # Uversky charge-hydrophobicity disorder criterion uversky_disorder = float(abs(mean_hydro) - abs(net_charge) < 0.06) idr_frac = sum(1 for aa in seq_u if aa in _DISORDER_PROMOTING) / n # Low-complexity: runs of the same amino acid lowcomp = 0 i, prev, run = 0, '', 0 for aa in seq_u: run = run + 1 if aa == prev else 1 if run >= 5: lowcomp += 1 prev = aa lowcomp_proxy = lowcomp / n # TM helix proxy: windows of ≥17 hydrophobic residues in 20-aa window tm_count = 0 for i in range(n - 19): window = seq_u[i:i+20] if sum(1 for aa in window if aa in _TM_HYDROPHOBIC) >= 17: tm_count += 1 tm_frac = tm_count / max(n - 19, 1) tm_any = float(tm_count > 0) # Signal peptide proxy: first 30 aa have a hydrophobic core sp_window = seq_u[:30] sp_proxy = float(sum(1 for aa in sp_window if aa in _TM_HYDROPHOBIC) >= 8) # Chou-Fasman secondary structure propensity cf_helix = sum(_CF_HELIX.get(aa, 1.0) for aa in seq_u) / n cf_sheet = sum(_CF_SHEET.get(aa, 1.0) for aa in seq_u) / n return { "f_seq_len": float(n), "f_mean_hydro": float(mean_hydro), "f_net_charge": float(net_charge), "f_uversky_disorder": float(uversky_disorder), "f_idr_frac_proxy": float(idr_frac), "f_lowcomp_proxy": float(lowcomp_proxy), "f_tm_frac_proxy": float(tm_frac), "f_tm_any_proxy": float(tm_any), "f_signal_peptide_proxy":float(sp_proxy), "f_cf_helix_mean": float(cf_helix), "f_cf_sheet_mean": float(cf_sheet), # AF-derived features: absent at inference → use 0 (imputed to mean) "f_afdb_has_model":0.0,"f_plddt_mean":0.0,"f_plddt_std":0.0, "f_plddt_q10":0.0,"f_plddt_q50":0.0,"f_plddt_q90":0.0, "f_plddt_frac_gt90":0.0,"f_plddt_frac_gt70":0.0,"f_plddt_frac_lt50":0.0, "f_distbin_0":0.0,"f_distbin_1":0.0,"f_distbin_2":0.0,"f_distbin_3":0.0, "f_distbin_4":0.0,"f_distbin_5":0.0,"f_distbin_6":0.0,"f_distbin_7":0.0, 
"f_distbin_8":0.0,"f_distbin_9":0.0, "f_pae_mean":0.0,"f_pae_median":0.0,"f_pae_p90":0.0,"f_pae_p95":0.0, "f_pae_frac_lt5":0.0,"f_pae_frac_lt10":0.0,"f_pae_frac_gt20":0.0, "f_seqfeat_present":1.0,"f_af_present":0.0, } def compute_per_residue_features(seq: str) -> dict: """Return per-residue vectors (normalized [0,1]) for the 11 supp features.""" import numpy as np seq_u = seq.upper() n = len(seq_u) def smooth(arr, w): hw = w // 2 out = [] for i in range(n): lo, hi = max(0, i - hw), min(n, i + hw + 1) out.append(sum(arr[lo:hi]) / (hi - lo)) return out def normalize(arr): mn, mx = min(arr), max(arr) if mx == mn: return [0.5] * n return [(v - mn) / (mx - mn) for v in arr] kd_raw = [_KD.get(aa, 0.0) for aa in seq_u] charge_raw = [_CHARGE.get(aa, 0.0) for aa in seq_u] # f_seq_len: uniform, no per-residue signal r_seq_len = [1.0] * n # f_mean_hydro: KD per residue smoothed window=5 r_mean_hydro = normalize(smooth(kd_raw, 5)) # f_net_charge: sliding charge window=9 r_net_charge = normalize(smooth(charge_raw, 9)) # f_uversky_disorder: window=11, high = more disordered tendency uv_raw = [] for i in range(n): lo, hi = max(0, i - 5), min(n, i + 6) wkd = sum(kd_raw[lo:hi]) / (hi - lo) wch = sum(charge_raw[lo:hi]) / (hi - lo) uv_raw.append(max(0.0, 0.20 - (abs(wkd) - abs(wch)))) r_uversky = normalize(uv_raw) # f_idr_frac_proxy: binary disorder residue indicator, smoothed idr_raw = [1.0 if aa in _DISORDER_PROMOTING else 0.0 for aa in seq_u] r_idr = normalize(smooth(idr_raw, 7)) # f_lowcomp_proxy: residues in amino-acid runs ≥5 lowcomp_raw = [0.0] * n prev, run = '', 0 for j, aa in enumerate(seq_u): run = run + 1 if aa == prev else 1 prev = aa if run >= 5: for k in range(max(0, j - run + 1), j + 1): lowcomp_raw[k] = 1.0 r_lowcomp = lowcomp_raw # f_tm_frac_proxy: residues covered by TM windows (≥17/20 hydrophobic) tm_raw = [0.0] * n if n >= 20: for i in range(n - 19): if sum(1 for aa in seq_u[i:i+20] if aa in _TM_HYDROPHOBIC) >= 17: for k in range(i, i + 20): tm_raw[k] = 1.0 
r_tm_frac = tm_raw # f_tm_any_proxy: same map (at residue level = tm_frac) r_tm_any = tm_raw # f_signal_peptide_proxy: linear decay × hydrophobicity, first 30 aa sp_mod = [] for i in range(n): weight = max(0.0, 1.0 - i / 30.0) sp_mod.append(weight * max(0.0, kd_raw[i] / 4.5)) r_sp = normalize(sp_mod) if any(v > 0 for v in sp_mod) else [max(0.0, 1.0 - i / 30.0) for i in range(n)] # f_cf_helix_mean / f_cf_sheet_mean: propensity per residue smoothed r_cf_helix = normalize(smooth([_CF_HELIX.get(aa, 1.0) for aa in seq_u], 3)) r_cf_sheet = normalize(smooth([_CF_SHEET.get(aa, 1.0) for aa in seq_u], 3)) return { "f_seq_len": [round(v, 3) for v in r_seq_len], "f_mean_hydro": [round(v, 3) for v in r_mean_hydro], "f_net_charge": [round(v, 3) for v in r_net_charge], "f_uversky_disorder": [round(v, 3) for v in r_uversky], "f_idr_frac_proxy": [round(v, 3) for v in r_idr], "f_lowcomp_proxy": [round(v, 3) for v in r_lowcomp], "f_tm_frac_proxy": [round(v, 3) for v in r_tm_frac], "f_tm_any_proxy": [round(v, 3) for v in r_tm_any], "f_signal_peptide_proxy": [round(v, 3) for v in r_sp], "f_cf_helix_mean": [round(v, 3) for v in r_cf_helix], "f_cf_sheet_mean": [round(v, 3) for v in r_cf_sheet], } def build_ancestor_cache(go_parents_map: dict) -> dict: """Compute full transitive ancestor sets for all GO terms (memoized DFS).""" cache = {} def _anc(gid): if gid in cache: return cache[gid] parents = go_parents_map.get(gid, set()) all_anc = set(parents) for p in parents: all_anc |= _anc(p) cache[gid] = all_anc return all_anc for gid in go_parents_map: _anc(gid) return cache def _download_with_retry(fname): from huggingface_hub import hf_hub_download max_attempts = 6 for attempt in range(1, max_attempts + 1): try: print(f" [{attempt}/{max_attempts}] Downloading {fname}...") path = hf_hub_download( repo_id=HF_REPO, filename=fname, local_dir=BASE_DIR, repo_type="model", token=os.environ.get("HF_TOKEN"), ) print(f" saved -> {path}") return except Exception as e: if fname in OPTIONAL: print(f" 
{fname} is optional, skipping ({e})") return if attempt == max_attempts: raise RuntimeError(f"Could not download '{fname}' after {max_attempts} attempts: {e}") wait = 2 ** attempt print(f" Network error, retrying in {wait}s... ({e})") time.sleep(wait) def ensure_model_files(): missing = [f for f in HF_FILES if not os.path.exists(os.path.join(BASE_DIR, f))] if not missing: print("All model files already present.") return print(f"Downloading {len(missing)} file(s) from HuggingFace Hub...") for fname in missing: _download_with_retry(fname) def load_go_map(): try: df = pd.read_csv(os.path.join(BASE_DIR, "go_annotations_fixed.csv")) mapping = {} for _, row in df.iterrows(): go_id = str(row["GO Annotation"]).strip() raw_name = str(row.get("Gene Ontology (molecular function)", "Unknown")) mapping[go_id] = raw_name.split(" [")[0].strip() print(f"GO map: {len(mapping)} labels loaded") return mapping except Exception as e: print(f"GO map load error: {e}") return {} def load_thresholds(): """ Load per-label thresholds and return (mammal_thresholds, mammal_T, insect_T, insect_t_default). 
Threshold JSON formats accepted: {"0": 0.68, ...} — plain float {"0": {"threshold": 0.68, "tier": 0, "temperature": 3.69}} — rich dict {"_meta": {...}, "0": 0.68, ...} — new format with metadata Returns: flat : dict str_idx -> float (mammal per-label thresholds) mammal_T : float temperature for mammal inference ins_T : float temperature for insect inference (usually 1.0) ins_t_default : float flat threshold for insect labels with no per-label data """ for path in [ os.path.join(BASE_DIR, "unified_35M_v1_thresholds.json"), # 35M model thresholds (latest) os.path.join(BASE_DIR, "unified_v1_recalibrated.json"), # 8M recalibrated: T=3.85, precision-tuned os.path.join(BASE_DIR, "unified_v1_thresholds.json"), os.path.join(BASE_DIR, "mammal_enriched_thresholds.json"), os.path.join(BASE_DIR, "protfunc_v3_fixed_thresholds.json"), os.path.join(BASE_DIR, "improved_per_label_thresholds.json"), os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"), os.path.join(BASE_DIR, "per_label_thresholds.json"), os.path.join(BASE_DIR, "artifacts", "per_label_thresholds.json"), ]: if not os.path.exists(path): continue print(f"Thresholds loaded from {path}") with open(path) as f: raw = json.load(f) flat = {} mammal_T = 1.0 ins_T = 1.0 ins_t_def = 0.68 # Extract metadata block if present meta = raw.pop("_meta", None) if meta: mammal_T = float(meta.get("temperature", mammal_T)) ins_T = float(meta.get("insect_temperature", ins_T)) ins_t_def = float(meta.get("insect_global_t", ins_t_def)) for k, v in raw.items(): if isinstance(v, dict): flat[k] = float(v.get("threshold", 0.5)) mammal_T = float(v.get("temperature", mammal_T)) else: try: flat[k] = float(v) except (TypeError, ValueError): pass print(f" {len(flat)} mammal thresholds | mammal_T={mammal_T:.4f} | " f"insect_T={ins_T:.4f} insect_t={ins_t_def:.2f}") return flat, mammal_T, ins_T, ins_t_def print("Thresholds not found, using defaults") return {}, 1.0, 1.0, 0.68 def parse_obo(path): """ Parse go-basic.obo and return: mf_terms : set of 
active (non-obsolete) GO IDs with namespace == molecular_function go_parents : dict GO ID -> set of direct parent GO IDs (is_a + part_of, MF only) go_names_ob : dict GO ID -> canonical name from OBO (authoritative) go_replaced : dict obsolete GO ID -> replacement GO ID go_depth : dict GO ID -> minimum depth from MF root (root = 0) All relationships are restricted to the MF namespace. """ ns_map = {} # id -> namespace par_map = {} # id -> {parent ids} name_map = {} # id -> canonical name def_map = {} # id -> definition string rep_map = {} # obsolete id -> replaced_by id alt_map = {} # alt_id -> canonical id obs_set = set() cur_id = None cur_ns = None cur_nm = None cur_df = None cur_par = set() cur_rep = None cur_obs = False cur_alt = [] in_term = False def flush(): nonlocal cur_id, cur_ns, cur_nm, cur_df, cur_par, cur_rep, cur_obs, cur_alt if cur_id: if cur_obs: obs_set.add(cur_id) if cur_rep: rep_map[cur_id] = cur_rep else: ns_map[cur_id] = cur_ns or "" par_map[cur_id] = cur_par name_map[cur_id] = cur_nm or cur_id if cur_df: def_map[cur_id] = cur_df for a in cur_alt: alt_map[a] = cur_id cur_id = None; cur_ns = None; cur_nm = None; cur_df = None cur_par = set(); cur_rep = None; cur_obs = False; cur_alt = [] with open(path, "r", encoding="utf-8") as fh: for raw in fh: line = raw.strip() if line == "[Term]": flush(); in_term = True; continue if line.startswith("[") and line != "[Term]": flush(); in_term = False; continue if not in_term: continue if line.startswith("id:"): cur_id = line.split("id:", 1)[1].strip().split()[0] elif line.startswith("name:"): cur_nm = line.split("name:", 1)[1].strip() elif line.startswith("namespace:"): cur_ns = line.split("namespace:", 1)[1].strip() elif line.startswith("alt_id:"): cur_alt.append(line.split("alt_id:", 1)[1].strip().split()[0]) elif line.startswith("def:"): # def: "description text" [source] — strip quotes and source raw_def = line.split("def:", 1)[1].strip() if raw_def.startswith('"'): end_q = raw_def.find('"', 1) cur_df = 
raw_def[1:end_q] if end_q > 0 else raw_def else: cur_df = raw_def.split("[")[0].strip() elif line.startswith("is_obsolete:") and "true" in line: cur_obs = True elif line.startswith("replaced_by:"): cur_rep = line.split("replaced_by:", 1)[1].strip().split()[0] elif line.startswith("is_a:"): parent = line.split("is_a:", 1)[1].strip().split()[0] cur_par.add(parent) elif line.startswith("relationship:"): parts = line.split("relationship:", 1)[1].strip().split() if len(parts) >= 2 and parts[0] in ("part_of", "regulates", "positively_regulates", "negatively_regulates"): cur_par.add(parts[1]) flush() mf = {gid for gid, n in ns_map.items() if n == "molecular_function"} go_parents_mf = {gid: (par_map[gid] & mf) for gid in mf} go_names_ob = {gid: name_map[gid] for gid in mf} go_defs_mf = {gid: def_map[gid] for gid in mf if gid in def_map} n_edges = sum(len(v) for v in go_parents_mf.values()) print(f"OBO parsed: {len(mf)} MF terms, {n_edges} parent edges, " f"{len(rep_map)} replacements, {len(alt_map)} alt-ids, " f"{len(go_defs_mf)} definitions") # BFS from root to compute minimum depth for each MF term go_depth: dict = {} go_depth[MF_ROOT] = 0 # Build children map for BFS (reverse of parents) children: dict = {gid: set() for gid in mf} for gid, parents in go_parents_mf.items(): for p in parents: if p in children: children[p].add(gid) queue = [MF_ROOT] while queue: nxt = [] for gid in queue: d = go_depth[gid] for child in children.get(gid, ()): if child not in go_depth: go_depth[child] = d + 1 nxt.append(child) queue = nxt print(f"Depth computed: {len(go_depth)} MF terms, " f"max depth={max(go_depth.values(), default=0)}") return mf, go_parents_mf, go_names_ob, rep_map, go_depth, go_defs_mf def compute_dynamic_cap(sorted_probs: list, seq_len: int) -> int: """ Return a protein-proportional cap on the number of direct predictions. Combines three signals: 1. Sequence-length prior (log2 scaling). 2. Probability-gap detection — largest relative drop within budget. 3. 
Diffuse-activation penalty — when predictions are bunched at low confidence with no clear outlier, the cap is tightened. This prevents the model outputting many near-threshold terms for proteins it is uncertain about (e.g. sparse mammal annotations). """ n = len(sorted_probs) if n == 0: return 0 if n <= 3: return n # Length prior (log2 scaling) length_prior = max(2, int(2.5 * math.log2(max(seq_len, 50) / 50) + 2)) abs_cap = min(15, length_prior * 2) # ── Diffuse-activation penalty ────────────────────────────────────────── # If the top prediction is below 0.75 AND the spread across all predictions # is narrow (< 0.12), we're seeing uniform noise rather than clear signal. top_prob = sorted_probs[0] spread = sorted_probs[0] - sorted_probs[-1] if top_prob < 0.75 and spread < 0.12: # Tight cluster at low confidence: cut cap to length_prior (conservative) abs_cap = max(3, length_prior) elif top_prob < 0.72: # Moderately uncertain: mild tightening abs_cap = max(3, min(abs_cap, length_prior + 2)) search_end = min(n, abs_cap) if search_end < 2: return min(n, abs_cap) best_score = -1.0 best_idx = search_end # default: all up to abs_cap for i in range(1, search_end): prev = sorted_probs[i - 1] curr = sorted_probs[i] rel_gap = (prev - curr) / (prev + 1e-6) dist = abs(i - length_prior) / max(length_prior, 1) score = rel_gap * (1.0 - 0.25 * min(dist, 1.0)) if score > best_score and rel_gap >= 0.08: best_score = score best_idx = i return max(3, best_idx) def propagate_and_filter(preds, go_parents_map, go_ancestors_map, prob_map): """ 1. Propagate predictions upward: for every predicted term, all its MF ancestors are implicitly predicted. Ancestors not already above threshold are added as 'implied' predictions with the child's probability. 2. Filter: a term is 'suppressed' only if it has MF parents and NONE of them appear in the final visible set (direct or implied). 3. 
Specificity: implied ancestors with depth ≤ 1 (root-adjacent, trivially general terms) are hidden from the visible list. Each prediction carries its depth so the UI can convey specificity. Returns (visible, suppressed) where visible includes informative implied parents. """ if not go_ancestors_map: # Still annotate with depth even without ancestors for p in preds: p["depth"] = go_depth.get(p["go_id"], -1) return preds, [] predicted_ids = {p["go_id"] for p in preds} implied = {} # go_id -> max child prob for pred in preds: gid = pred["go_id"] prob = pred["prob"] for anc in go_ancestors_map.get(gid, set()): if anc not in predicted_ids and anc != MF_ROOT: implied[anc] = max(implied.get(anc, 0.0), prob) all_visible_ids = predicted_ids | set(implied.keys()) # Classify direct predictions: visible if any MF parent is visible (or root / no parents) suppressed = [] direct_ok = [] for pred in preds: gid = pred["go_id"] parents = go_parents_map.get(gid, set()) pred["depth"] = go_depth.get(gid, -1) if gid == MF_ROOT or not parents or (parents & all_visible_ids): direct_ok.append(pred) else: pred["reason"] = "no_visible_parent" suppressed.append(pred) # Add implied ancestor terms — skip root-adjacent (depth ≤ 1) and root itself # Depth ≤ 1 = trivially general terms like "binding", "catalytic activity" # that carry no predictive specificity on their own. 
MIN_IMPLIED_DEPTH = 2 implied_preds = [] for gid, prob in implied.items(): d = go_depth.get(gid, -1) if d < MIN_IMPLIED_DEPTH: continue # too general — still implicitly true, just not displayed implied_preds.append({ "go_id": gid, "name": go_map.get(gid, gid), "prob": round(prob, 3), "implied": True, "depth": d, }) implied_preds.sort(key=lambda x: (-x["prob"], -x["depth"])) visible = direct_ok + implied_preds visible.sort(key=lambda x: (-x["prob"], -x.get("depth", 0))) return visible, suppressed def sequence_entropy(seq): seq_upper = seq.upper() counts = {} for aa in seq_upper: counts[aa] = counts.get(aa, 0) + 1 n = len(seq_upper) return -sum((c / n) * math.log2(c / n) for c in counts.values()) def validate_sequence(name, seq): """Returns an error string if the sequence should be rejected, else None.""" if len(seq) < MIN_SEQ_LENGTH: return (f"'{name}' is too short ({len(seq)} aa — minimum {MIN_SEQ_LENGTH} aa). " f"Sequences this short are unlikely to fold into a stable domain.") # Reject non-letter characters (digits, spaces, symbols) non_letter = sorted({c for c in seq if not c.isalpha()}) if non_letter: display = ", ".join(f"'{c}'" for c in non_letter[:5]) return (f"'{name}' contains non-amino-acid characters: {display}. " f"Only single-letter amino acid codes are accepted.") # Detect DNA/RNA sequences (>85% ATCGU with ≤5 distinct chars) seq_upper_set = {c.upper() for c in seq} nucleotide_chars = seq_upper_set & set("ATCGU") nucleotide_frac = sum(seq.upper().count(c) for c in "ATCGU") / len(seq) if nucleotide_frac > 0.85 and len(seq_upper_set) <= 6: return (f"'{name}' appears to be a nucleotide sequence (DNA/RNA), not a protein. " f"Please enter an amino acid sequence in single-letter code.") bad = sorted({c.upper() for c in seq if c.upper() in INVALID_AA}) if bad: return (f"'{name}' contains invalid amino acid character(s): " f"{', '.join(bad)}. 
These ambiguity codes are not accepted.") counts = {} for aa in seq.upper(): counts[aa] = counts.get(aa, 0) + 1 if len(counts) < MIN_DISTINCT_AA: return (f"'{name}' uses only {len(counts)} distinct residue type(s). " f"Real proteins require at least {MIN_DISTINCT_AA}.") dominant_frac = max(counts.values()) / len(seq) if dominant_frac > MAX_DOMINANT_FRAC: dominant_aa = max(counts, key=counts.get) return (f"'{name}' is dominated by a single residue " f"({dominant_aa} = {dominant_frac:.0%}). " f"Low-complexity sequences produce unreliable embeddings.") H = sequence_entropy(seq) if H < MIN_ENTROPY_BITS: return (f"'{name}' has very low sequence complexity " f"(Shannon entropy {H:.2f} bits, minimum {MIN_ENTROPY_BITS:.1f} bits). " f"Repetitive or artificially constructed sequences are not accepted.") return None @asynccontextmanager async def lifespan(app: FastAPI): global device, model, esm_model, batch_converter global mlb, go_map, go_defs, mf_terms, go_parents, go_ancestors, go_depth, go_replaced global mf_indices, thresholds, temperature, insect_temperature, insect_threshold_default, NUM_LABELS, _ESM_DIM global supp_mu, supp_sd, supp_cols, model_uses_supp global taxon_probe, platt_params # Step 1: download missing files ensure_model_files() # Step 2: GO name map go_map = load_go_map() go_names_path = os.path.join(BASE_DIR, "go_names.json") if os.path.exists(go_names_path): with open(go_names_path) as f: go_map.update(json.load(f)) print(f"Canonical GO names loaded: {len(go_map)} total entries") # Step 3: MLB — load BEFORE anything references mlb.classes_ mlb = joblib.load(os.path.join(BASE_DIR, "mlb_public_v1.pkl")) NUM_LABELS = len(mlb.classes_) print(f"MLB loaded: {NUM_LABELS} labels") # Step 4: OBO — parse MF namespace, parent DAG, names, depth, replacements obo_path = os.path.join(BASE_DIR, "go-basic.obo") if os.path.exists(obo_path): mf_terms, go_parents, go_names_obo, go_replaced, go_depth, go_defs_obo = parse_obo(obo_path) go_defs.update(go_defs_obo) # OBO 
canonical names are the most authoritative — merge over CSV names go_map.update(go_names_obo) print(f"OBO names merged: {len(go_names_obo)} MF term names") mf_in_mlb = sum(1 for gid in mlb.classes_ if gid in mf_terms) rep_in_mlb = sum(1 for gid in mlb.classes_ if gid in go_replaced) print(f"OBO cross-check: {mf_in_mlb}/{NUM_LABELS} active MF, " f"{rep_in_mlb} replaced/obsolete labels remapped") # Build full transitive ancestor cache for parental propagation go_ancestors = build_ancestor_cache(go_parents) print(f"Ancestor cache built: {len(go_ancestors)} terms") else: print("WARNING: go-basic.obo not found — hierarchy filtering disabled. " "Download from https://current.geneontology.org/ontology/go-basic.obo " "and place it alongside server.py.") # Step 5: MF-only whitelist — OBO namespace is authoritative, CSV is fallback if mf_terms: mf_indices = [i for i, gid in enumerate(mlb.classes_) if gid in mf_terms] print(f"MF whitelist (OBO): {len(mf_indices)} active indices") else: mf_go_ids = { go_id for go_id, name in go_map.items() if name and name != go_id and not name.startswith("GO:") } mf_indices = [i for i, gid in enumerate(mlb.classes_) if gid in mf_go_ids] or list(range(NUM_LABELS)) print(f"MF whitelist (CSV fallback): {len(mf_indices)} active indices") app.state.mf_indices = mf_indices # Step 6: per-label thresholds (mammal-calibrated) + insect fallback params thresholds, temperature, insect_temperature, insect_threshold_default = load_thresholds() # Step 7: classifier — auto-detect architecture from checkpoint keys class ResBlock(nn.Module): """Pre-activation residual block with BatchNorm (improved model).""" def __init__(self, dim, dropout=0.2): super().__init__() self.net = nn.Sequential( nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim), nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim), ) def forward(self, x): return x + self.net(x) class ImprovedResidualMLP(nn.Module): """4-block, hidden=2048, BatchNorm — 
trained by train_improved.py.""" def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=2048, n_blocks=4, dropout=0.2): super().__init__() self.fc_in = nn.Linear(in_dim, hidden) self.blocks = nn.ModuleList([ResBlock(hidden, dropout) for _ in range(n_blocks)]) self.fc_out = nn.Sequential( nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, out_dim), ) def forward(self, x): h = self.fc_in(x) for blk in self.blocks: h = blk(h) return self.fc_out(h) class ResidualMLP(nn.Module): """Original 2-block notebook model (fallback).""" def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=1024, dropout=0.2): super().__init__() self.fc_in = nn.Linear(in_dim, hidden) self.block1 = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, hidden)) self.block2 = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, hidden)) self.fc_out = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, out_dim)) def forward(self, x): h = self.fc_in(x) h = torch.relu(h) h = h + self.block1(h) h = h + self.block2(h) return self.fc_out(h) class RecoveredBaselineModel(nn.Module): """Earlier server-side architecture — retained for backward compatibility.""" def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=1024, dropout=0.2): super().__init__() self.fc1 = nn.Linear(in_dim, hidden) self.proj = nn.Linear(in_dim, hidden) self.fc2 = nn.Linear(hidden, hidden) self.out = nn.Linear(hidden, out_dim) self.relu = nn.ReLU() self.drop = nn.Dropout(dropout) def forward(self, x): h = self.relu(self.fc1(x)) h = h + self.proj(x) h = self.relu(self.fc2(h)) h = self.drop(h) return self.out(h) import numpy as np device = torch.device("cpu") # Prefer checkpoints in priority order: 35M > unified_v1 > mammal_enriched > v3_fixed > improved > v3 > supp_res2 > baseline ckpt_candidates = [ os.path.join(BASE_DIR, "unified_35M_v1.pth"), os.path.join(BASE_DIR, "unified_v1.pth"), os.path.join(BASE_DIR, "mammal_enriched.pth"), os.path.join(BASE_DIR, 
"protfunc_v3_fixed.pth"), os.path.join(BASE_DIR, "improved_res.pth"), os.path.join(BASE_DIR, "protfunc_v3.pth"), os.path.join(BASE_DIR, "supp_res2.pth"), os.path.join(BASE_DIR, "baseline_res.pth"), ] _ESM_DIM = 320 # updated after checkpoint load if esm_dim present _model = None for ckpt_path in ckpt_candidates: if not os.path.exists(ckpt_path): continue print(f"Trying classifier: {os.path.basename(ckpt_path)}") try: ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) except Exception as e: print(f" Failed to load {os.path.basename(ckpt_path)}: {e} — skipping") continue sd = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt keys = set(sd.keys()) # Reset supp globals for each candidate _c_supp_mu = None _c_supp_sd = None _c_supp_cols = None _c_uses_supp = False if isinstance(ckpt, dict) and "supp_mu" in ckpt: _c_supp_mu = np.array(ckpt["supp_mu"], dtype=np.float32) _c_supp_sd = np.array(ckpt["supp_sd"], dtype=np.float32) _c_supp_cols = ckpt["supp_cols"] _c_uses_supp = True # Detect architecture and validate in_dim against supp metadata. # Checkpoints that store explicit "in_dim" may use a truncated supp feature # set (e.g. mammal_enriched uses SUPP_COLS[:11] → in_dim=331, but stores # all 39 supp_cols for reference). Trust fc_in.weight shape in that case. 
_ckpt_has_explicit_in_dim = isinstance(ckpt, dict) and "in_dim" in ckpt # Detect out_dim from checkpoint to avoid mismatch with MLB size _out_dim_ckpt = None for _out_key in ("fc_out.3.weight", "fc_out.2.weight", "out.weight"): if _out_key in sd: _out_dim_ckpt = sd[_out_key].shape[0] break if "blocks.0.net.0.weight" in keys: hidden_dim = sd["fc_in.weight"].shape[0] n_blocks = sum(1 for k in keys if k.startswith("blocks.") and k.endswith(".net.0.weight")) in_dim_ckpt = sd["fc_in.weight"].shape[1] if _c_uses_supp and _c_supp_cols is not None and not _ckpt_has_explicit_in_dim: expected = _ESM_DIM + len(_c_supp_cols) + 1 if in_dim_ckpt != expected: print(f" SKIP: supp metadata says in_dim={expected} but fc_in has {in_dim_ckpt} — corrupted checkpoint") continue out_dim_use = _out_dim_ckpt or NUM_LABELS _model = ImprovedResidualMLP(in_dim=in_dim_ckpt, hidden=hidden_dim, n_blocks=n_blocks, out_dim=out_dim_use).to(device) print(f" ImprovedResidualMLP (in={in_dim_ckpt} hidden={hidden_dim} blocks={n_blocks} out={out_dim_use})") elif any(k.startswith("fc_in") for k in keys): in_dim_ckpt = sd["fc_in.weight"].shape[1] if _c_uses_supp and _c_supp_cols is not None and not _ckpt_has_explicit_in_dim: expected = _ESM_DIM + len(_c_supp_cols) + 1 if in_dim_ckpt != expected: print(f" SKIP: supp metadata says in_dim={expected} but fc_in has {in_dim_ckpt} — corrupted checkpoint") continue out_dim_use = _out_dim_ckpt or NUM_LABELS _model = ResidualMLP(in_dim=in_dim_ckpt, out_dim=out_dim_use).to(device) print(f" ResidualMLP (in={in_dim_ckpt} out={out_dim_use})") elif any(k.startswith("fc1") for k in keys): _c_uses_supp = False out_dim_use = _out_dim_ckpt or NUM_LABELS _model = RecoveredBaselineModel(out_dim=out_dim_use).to(device) print(f" RecoveredBaselineModel (legacy out={out_dim_use})") else: print(f" SKIP: unrecognised architecture — keys: {sorted(keys)[:8]}") continue try: _model.load_state_dict(sd, strict=True) except Exception as e: print(f" SKIP: load_state_dict failed: {e}") 
_model = None continue # Commit globals and break supp_mu = _c_supp_mu supp_sd = _c_supp_sd supp_cols = _c_supp_cols model_uses_supp = _c_uses_supp if isinstance(ckpt, dict) and "val_fmax" in ckpt: print(f" val_fmax={ckpt['val_fmax']:.4f} epoch={ckpt.get('epoch','?')}") if _c_uses_supp: print(f" Supplemented model: {len(_c_supp_cols)} extra features") # Detect ESM dim from checkpoint metadata if isinstance(ckpt, dict) and "esm_dim" in ckpt: _ESM_DIM = int(ckpt["esm_dim"]) print(f" ESM dim from checkpoint: {_ESM_DIM}") print(f"Classifier loaded: {os.path.basename(ckpt_path)}") break if _model is None: raise RuntimeError("No valid classifier checkpoint found.") _model.eval() model = _model # Step 8: ESM-2 — choose model based on detected esm_dim import esm as esm_lib if _ESM_DIM == 480: _esm_model, alphabet = esm_lib.pretrained.esm2_t12_35M_UR50D() print("ESM-2 (35M, 480-dim) loaded OK") else: _esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D() print("ESM-2 (8M, 320-dim) loaded OK") esm_model = _esm_model.to(device).eval() batch_converter = alphabet.get_batch_converter() # Step 9: Taxon probe (optional, generated by calibrate_server.py / calibrate_probe_35M.py) probe_path = os.path.join(BASE_DIR, "taxon_probe.json") if os.path.exists(probe_path): with open(probe_path) as f: taxon_probe = json.load(f) acc = taxon_probe.get("train_accuracy", 0) probe_esm_dim = taxon_probe.get("esm_dim", 320) if probe_esm_dim != _ESM_DIM: print(f"Taxon probe ESM dim mismatch ({probe_esm_dim} vs {_ESM_DIM}) — disabling probe") taxon_probe = None else: for k in ("coef", "intercept", "scaler_mean", "scaler_std"): if k in taxon_probe: taxon_probe[k] = np.asarray(taxon_probe[k], dtype=np.float32) print(f"Taxon probe loaded (train_acc={acc:.4f}, esm_dim={probe_esm_dim})") else: print("Taxon probe not found — using composition heuristic for auto-detection") # Step 10: Platt scaling (optional, generated by calibrate_server.py) platt_path = os.path.join(BASE_DIR, "platt_mammal.json") if 
os.path.exists(platt_path): with open(platt_path) as f: platt_params = json.load(f) print(f"Platt scaling loaded: {len(platt_params)} labels calibrated") else: print("Platt params not found — using temperature scaling only") # Step 11: Override temperature from calibration sweep if available temp_path = os.path.join(BASE_DIR, "temperature_best.json") if os.path.exists(temp_path): with open(temp_path) as f: temp_data = json.load(f) new_T = float(temp_data.get("optimal_T", temperature)) if abs(new_T - temperature) > 0.1: print(f"Temperature updated by calibration sweep: {temperature:.4f} → {new_T:.4f}") temperature = new_T else: print(f"Temperature unchanged by sweep: {temperature:.4f}") yield print("Shutting down.") app = FastAPI(lifespan=lifespan) app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") @app.get("/") async def root(): return FileResponse(os.path.join(STATIC_DIR, "interface.html"), headers={"Cache-Control": "no-store"}) @app.get("/api/model/info") async def model_info(): """Return model metadata and configuration.""" unified_35M = os.path.exists(os.path.join(BASE_DIR, "unified_35M_v1.pth")) unified_v1 = os.path.exists(os.path.join(BASE_DIR, "unified_v1.pth")) mammal_enriched = os.path.exists(os.path.join(BASE_DIR, "mammal_enriched.pth")) v3_fixed = os.path.exists(os.path.join(BASE_DIR, "protfunc_v3_fixed.pth")) improved = os.path.exists(os.path.join(BASE_DIR, "improved_res.pth")) # model name reflects actual loaded model (unified_35M_v1 takes highest priority) if unified_35M and model_uses_supp and _ESM_DIM == 480: name, version, active = "ProtFunc v5.0 (35M ESM, GOA-enriched, insect+mammal)", "5.0.0", "unified_35M_v1" elif unified_v1 and model_uses_supp: name, version, active = "ProtFunc v4.0 (unified insect+mammal, CAFA5-evaluated)", "4.0.0", "unified_v1" elif mammal_enriched and model_uses_supp: name, version, active = "ProtFunc v3.2 
(mammal-enriched, CAFA5-evaluated)", "3.2.0", "mammal_enriched" elif v3_fixed and model_uses_supp: name, version, active = "ProtFunc v3-fixed (ablation best, CAFA-correct)", "3.1.0", "protfunc_v3_fixed" elif model_uses_supp: name, version, active = "ProtFunc v3 (supplemented + mammal)", "3.0.0", "protfunc_v3" elif improved: name, version, active = "ProtFunc Enhanced", "2.1.0", "improved" else: name, version, active = "ProtFunc", "2.0.0", "baseline" return { "model_name": name, "model": active, "version": version, "esm_model": "esm2_t12_35M_UR50D" if _ESM_DIM == 480 else "esm2_t6_8M_UR50D", "embed_dim": _ESM_DIM, "num_labels": NUM_LABELS, "supported_namespaces": ["molecular_function"], "max_sequence_length": 1500, "thresholds_loaded": len(thresholds) > 0, "temperature_scaling": temperature != 1.0, "temperature": round(temperature, 4), "insect_temperature": round(insect_temperature, 4), "insect_threshold_default": round(insect_threshold_default, 4), "supported_taxa": ["insect", "mammal"], "taxon_routing": True, "hierarchy_filtering": len(go_parents) > 0, "parental_propagation": len(go_ancestors) > 0, "depth_annotation": len(go_depth) > 0, "mf_terms_loaded": len(mf_terms) if mf_terms else 0, "mf_max_depth": max(go_depth.values(), default=0) if go_depth else 0, "supplemented_features": model_uses_supp, "supp_feature_count": len(supp_cols) if supp_cols else 0, } @app.get("/api/generalization") async def get_generalization(): """ Return cross-taxon generalization results from eval_generalization.py output. Serves artifacts/generalization_results.json if present, otherwise returns empty. 
""" candidates = [ os.path.join(BASE_DIR, "artifacts", "generalization", "generalization_results.json"), os.path.join(BASE_DIR, "generalization_results.json"), ] for path in candidates: if os.path.exists(path): with open(path) as f: data = json.load(f) return { "available": True, "taxa": list(data.keys()), "results": data, } return {"available": False, "taxa": [], "results": {}} @app.get("/api/structure") async def get_structure(uniprot_id: str): """Look up AlphaFold structure data for a UniProt accession.""" import urllib.request, urllib.error uid = uniprot_id.upper().strip() if not re.match(r'^[A-Z0-9]{4,10}$', uid): return {"found": False, "error": "Invalid accession format"} try: req = urllib.request.Request( f"https://alphafold.ebi.ac.uk/api/prediction/{uid}", headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"} ) with urllib.request.urlopen(req, timeout=8) as resp: entries = json.loads(resp.read()) d = entries[0] return { "found": True, "accession": uid, "organism": d.get("organismScientificName", ""), "gene": d.get("gene", ""), "cif_url": d.get("cifUrl", ""), "pae_image_url": d.get("paeImageUrl", ""), "entry_url": f"https://alphafold.ebi.ac.uk/entry/{uid}", "uniprot_url": f"https://www.uniprot.org/uniprot/{uid}", "model_version": d.get("latestVersion", 4), } except urllib.error.HTTPError as e: if e.code == 404: return {"found": False, "accession": uid, "uniprot_url": f"https://www.uniprot.org/uniprot/{uid}"} return {"found": False, "error": f"HTTP {e.code}"} except Exception as e: return {"found": False, "error": str(e)[:100]} @app.get("/api/uniprot/annotations") async def get_uniprot_annotations(uniprot_id: str): """ Fetch GO-MF annotations and organism info from UniProt REST API. Returns known annotations (evidence-coded) for comparison with predictions. 
""" import urllib.request, urllib.error uid = uniprot_id.upper().strip() if not re.match(r'^[A-Z0-9]{4,10}$', uid): return {"found": False, "error": "Invalid accession format"} if uid in _uniprot_cache: return _uniprot_cache[uid] try: url = f"https://rest.uniprot.org/uniprotkb/{uid}.json?fields=go,organism,protein_name,gene_names,organism_lineage" req = urllib.request.Request(url, headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"}) with urllib.request.urlopen(req, timeout=10) as resp: data = json.loads(resp.read()) except urllib.error.HTTPError as e: if e.code == 404: return {"found": False, "accession": uid, "error": "Not found in UniProt"} return {"found": False, "error": f"HTTP {e.code}"} except Exception as e: return {"found": False, "error": str(e)[:100]} # Parse organism lineage for taxon detection organism = data.get("organism", {}) org_name = organism.get("scientificName", "") lineage = [x.get("scientificName", "") for x in organism.get("lineage", [])] if any(t in lineage for t in INSECT_LINEAGE): detected_taxon = "insect" elif any(t in lineage for t in MAMMAL_LINEAGE): detected_taxon = "mammal" else: detected_taxon = "auto" # Parse protein name pn_block = data.get("proteinDescription", {}) prot_name = "" if "recommendedName" in pn_block: prot_name = pn_block["recommendedName"].get("fullName", {}).get("value", "") elif "submissionNames" in pn_block: prot_name = pn_block["submissionNames"][0].get("fullName", {}).get("value", "") gene_names = [] for gn in data.get("genes", []): if "geneName" in gn: gene_names.append(gn["geneName"]["value"]) # Parse GO-MF annotations go_mf = [] for xref in data.get("uniProtKBCrossReferences", []): if xref.get("database") != "GO": continue go_id = xref.get("id", "") props = {p["key"]: p["value"] for p in xref.get("properties", [])} term_str = props.get("GoTerm", "") if not term_str.startswith("F:"): continue # only molecular function evidence = props.get("GoEvidenceType", "") ev_code = evidence.split(":")[0] 
if ":" in evidence else evidence go_name = term_str[2:].strip() # Resolve canonical name from our go_map if available display_name = go_map.get(go_id, go_name) # Classify evidence tier exp_codes = {"EXP","IDA","IPI","IMP","IGI","IEP","HTP","HDA","HMP","HGI","HEP"} comp_codes = {"ISS","ISO","ISA","ISM","IGC","IBA","IBD","IKR","IRD","RCA"} if ev_code in exp_codes: ev_tier = "experimental" elif ev_code in comp_codes: ev_tier = "computational" elif ev_code == "IEA": ev_tier = "electronic" else: ev_tier = "other" go_mf.append({ "go_id": go_id, "name": display_name, "evidence": ev_code, "ev_tier": ev_tier, }) # Deduplicate by go_id, keeping best evidence tier tier_rank = {"experimental": 0, "computational": 1, "other": 2, "electronic": 3} seen = {} for entry in go_mf: gid = entry["go_id"] if gid not in seen or tier_rank[entry["ev_tier"]] < tier_rank[seen[gid]["ev_tier"]]: seen[gid] = entry go_mf = sorted(seen.values(), key=lambda x: (tier_rank[x["ev_tier"]], x["name"])) result = { "found": True, "accession": uid, "protein_name": prot_name, "gene_names": gene_names, "organism": org_name, "lineage": lineage[-5:] if lineage else [], "detected_taxon": detected_taxon, "go_mf": go_mf, "n_experimental": sum(1 for e in go_mf if e["ev_tier"] == "experimental"), "n_total": len(go_mf), } _uniprot_cache[uid] = result return result @app.get("/api/health") async def health_check(): """Health check endpoint for monitoring.""" return { "status": "healthy", "model_loaded": model is not None, "esm_loaded": esm_model is not None, "labels": NUM_LABELS, } class ProteinRequest(BaseModel): sequence: str taxon: str = "auto" # "auto" | "insect" | "mammal" uniprot_id: str = "" # optional accession for structure lookup class SaliencyRequest(BaseModel): sequence: str uniprot_id: str = "" taxon: str = "auto" top_k: int = 20 # top-k predicted labels to use as saliency objective class ExplainRequest(BaseModel): sequence: str uniprot_id: str = "" taxon: str = "auto" top_k: int = 10 class 
BatchPredictRequest(BaseModel): sequences: list # List of {name: str, sequence: str} objects threshold: float = 0.5 include_suppressed: bool = True def parse_sequences(text): text = text.strip() if text.startswith(">"): blocks = re.split(r"(>.*)", text) names, seqs = [], [] i = 1 while i < len(blocks): name = blocks[i][1:].strip() seq = re.sub(r"\s+", "", blocks[i + 1]) if i + 1 < len(blocks) else "" if seq: names.append(name) seqs.append(seq) i += 2 return list(zip(names, seqs)) seqs = [line.strip() for line in text.splitlines() if line.strip()] return [(f"Sequence {i + 1}", s) for i, s in enumerate(seqs)] def _build_model_input(emb: torch.Tensor, sequence: str) -> torch.Tensor: """ Build the model input tensor from an ESM-2 embedding. For supplemented models: appends z-scored seq/structural features. For base models: returns the embedding as-is. Handles two supplemented feature sets: - v3 models (in_dim=360): 39 sequence-derived features + 1 missingness flag - supp_res2 (in_dim=709): 388 features including f_Dim_* (ESM dims), AF structural features (zeros at inference), and sequence features """ import numpy as np if not model_uses_supp or supp_mu is None: return emb.unsqueeze(0) feats = compute_seq_features(sequence) # For models that include f_Dim_* features (e.g. supp_res2), populate them # from the ESM embedding rather than defaulting to 0. 
emb_np = emb.detach().cpu().numpy() for c in supp_cols: if c.startswith('f_Dim_'): try: dim_idx = int(c.split('_')[-1]) if dim_idx < len(emb_np): feats[c] = float(emb_np[dim_idx]) except (ValueError, IndexError): pass s_vec = np.array([feats.get(c, 0.0) for c in supp_cols], dtype=np.float32) s_z = (s_vec - supp_mu) / (supp_sd + 1e-12) # missingness flag: 1 = feature data available (seq features always computable) flag = np.array([1.0], dtype=np.float32) extra = torch.from_numpy(np.concatenate([s_z, flag])) full_input = torch.cat([emb, extra], dim=0).unsqueeze(0) # Verify the built input matches what the model actually expects. # Some checkpoints (e.g. mammal_enriched) store all supp_cols for reference but # were trained with only the first N features (via [:in_dim] truncation). # Truncate to match rather than falling back to bare ESM embedding. if model is not None and hasattr(model, 'fc_in'): expected_dim = model.fc_in.weight.shape[1] if full_input.shape[1] != expected_dim: if full_input.shape[1] > expected_dim: full_input = full_input[:, :expected_dim] else: return emb.unsqueeze(0) # can't expand — fall back to bare ESM return full_input # LRU cache for ESM embeddings — avoids recomputing for repeated/identical sequences _ESM_CACHE: dict = {} _ESM_CACHE_MAX = 256 def _get_esm_embedding(sequence: str) -> torch.Tensor: """Return mean-pooled ESM-2 embedding, using in-memory cache.""" key = hashlib.md5(sequence.encode()).hexdigest() if key in _ESM_CACHE: return _ESM_CACHE[key] _, _, tokens = batch_converter([("p", sequence)]) with torch.no_grad(): rep = esm_model(tokens.to(device), repr_layers=[6])["representations"][6] emb = rep[0, 1:len(sequence) + 1].mean(0).cpu() if len(_ESM_CACHE) >= _ESM_CACHE_MAX: _ESM_CACHE.pop(next(iter(_ESM_CACHE))) # evict oldest (FIFO) _ESM_CACHE[key] = emb return emb @app.post("/predict") async def predict(request: ProteinRequest): import numpy as np import urllib.request, urllib.error entries = parse_sequences(request.sequence) results 
= [] device_cpu = torch.device("cpu") mf_idx = app.state.mf_indices uid = (request.uniprot_id or "").upper().strip() req_taxon = (request.taxon or "auto").lower() # Fetch UniProt annotations async (if accession given) — do once for the batch uniprot_data = None detected_taxon_from_uniprot = None if uid and re.match(r'^[A-Z0-9]{4,10}$', uid): try: url = f"https://rest.uniprot.org/uniprotkb/{uid}.json?fields=go,organism,protein_name,gene_names,organism_lineage" req_obj = urllib.request.Request(url, headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"}) with urllib.request.urlopen(req_obj, timeout=8) as resp: raw = json.loads(resp.read()) organism = raw.get("organism", {}) lineage = [x.get("scientificName", "") for x in organism.get("lineage", [])] if any(t in lineage for t in INSECT_LINEAGE): detected_taxon_from_uniprot = "insect" elif any(t in lineage for t in MAMMAL_LINEAGE): detected_taxon_from_uniprot = "mammal" # Parse GO-MF annotations go_mf_known = {} tier_rank = {"experimental": 0, "computational": 1, "other": 2, "electronic": 3} exp_ev = {"EXP","IDA","IPI","IMP","IGI","IEP","HTP","HDA","HMP","HGI","HEP"} comp_ev = {"ISS","ISO","ISA","ISM","IGC","IBA","IBD","IKR","IRD","RCA"} for xref in raw.get("uniProtKBCrossReferences", []): if xref.get("database") != "GO": continue props = {p["key"]: p["value"] for p in xref.get("properties", [])} if not props.get("GoTerm", "").startswith("F:"): continue go_id = xref.get("id", "") ev_code = props.get("GoEvidenceType", "").split(":")[0] ev_tier = "experimental" if ev_code in exp_ev else ( "computational" if ev_code in comp_ev else ( "electronic" if ev_code == "IEA" else "other")) entry = {"go_id": go_id, "name": go_map.get(go_id, props["GoTerm"][2:]), "evidence": ev_code, "ev_tier": ev_tier} if go_id not in go_mf_known or tier_rank[ev_tier] < tier_rank[go_mf_known[go_id]["ev_tier"]]: go_mf_known[go_id] = entry uniprot_data = { "go_mf_known": sorted(go_mf_known.values(), key=lambda x: 
(tier_rank[x["ev_tier"]], x["name"])), "organism": organism.get("scientificName", ""), "detected_taxon": detected_taxon_from_uniprot, } except Exception: pass for name, sequence in entries: err = validate_sequence(name, sequence) if err: results.append({"name": name, "error": err}) continue if len(sequence) > 1500: results.append({"name": name, "error": "Sequence too long (max 1500 aa)"}) continue try: emb = _get_esm_embedding(sequence).to(device_cpu) emb_np = emb.detach().cpu().numpy() # ── Taxon auto-detection ──────────────────────────────────────── taxon_source = req_taxon taxon_conf = 1.0 if req_taxon == "auto": if detected_taxon_from_uniprot: taxon_source = detected_taxon_from_uniprot taxon_conf = 1.0 elif taxon_probe is not None: taxon_source, taxon_conf = _detect_taxon_probe(emb_np) else: taxon_source, taxon_conf = _detect_taxon_composition(sequence) if taxon_source == "insect": t_apply = insect_temperature thresh_lookup = {} t_default = insect_threshold_default mammal_floor = 0.0 else: t_apply = temperature thresh_lookup = thresholds t_default = 0.68 mammal_floor = 0.56 # ── Forward pass ──────────────────────────────────────────────── with torch.no_grad(): inp = _build_model_input(emb, sequence) logits = model(inp).squeeze() # Apply Platt scaling per-label (if available, mammal only) if platt_params and taxon_source != "insect": probs_list = [] for i in range(len(logits)): l = float(logits[i]) if str(i) in platt_params: A, B = platt_params[str(i)] p = 1.0 / (1.0 + math.exp(-(A * l + B))) else: p = 1.0 / (1.0 + math.exp(-l / t_apply)) probs_list.append(p) prob = torch.tensor(probs_list) else: prob = torch.sigmoid(logits / t_apply) if prob.dim() == 0: prob = prob.unsqueeze(0) # ── Threshold + collect ───────────────────────────────────────── raw_preds = [] prob_map = {} for i in mf_idx: pv = float(prob[i]) label_thresh = max(float(thresh_lookup.get(str(i), t_default)), mammal_floor) if pv >= label_thresh: go_id = mlb.classes_[i] display_id = 
go_replaced.get(go_id, go_id) display_nm = go_map.get(display_id, go_map.get(go_id, go_id)) entry = {"go_id": display_id, "name": display_nm, "prob": round(pv, 4), "depth": go_depth.get(display_id, -1)} if display_id != go_id: entry["original_id"] = go_id raw_preds.append(entry) prob_map[display_id] = pv raw_preds.sort(key=lambda x: x["prob"], reverse=True) cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(sequence)) raw_preds = raw_preds[:cap] for rp in raw_preds: prob_map[rp["go_id"]] = rp["prob"] visible, suppressed = propagate_and_filter(raw_preds, go_parents, go_ancestors, prob_map) result = { "name": name, "sequence_length": len(sequence), "predictions": visible, "suppressed": suppressed, "n_above_threshold": len(raw_preds), "n_implied_parents": sum(1 for p in visible if p.get("implied")), "taxon_applied": taxon_source, "taxon_source": "uniprot" if detected_taxon_from_uniprot and req_taxon == "auto" else ("probe" if taxon_probe and req_taxon == "auto" else ("composition" if req_taxon == "auto" else "manual")), "taxon_confidence": round(taxon_conf, 3), "temperature_applied": round(t_apply, 4), "platt_applied": bool(platt_params) and taxon_source != "insect", } if uniprot_data: result["uniprot"] = uniprot_data results.append(result) except Exception as e: results.append({"name": name, "error": str(e)}) return {"results": results} @app.post("/api/predict/batch") async def predict_batch(request: BatchPredictRequest): """ Batch prediction endpoint for multiple sequences. Accepts a list of sequence objects and returns predictions for all. More efficient than multiple single predictions due to batching. 
""" results = [] mf_idx = app.state.mf_indices custom_threshold = request.threshold for item in request.sequences: name = item.get("name", "Unknown") sequence = item.get("sequence", "") # Validate sequence err = validate_sequence(name, sequence) if err: results.append({"name": name, "error": err}) continue if len(sequence) > 1500: results.append({"name": name, "error": "Sequence too long (max 1500 aa)"}) continue try: emb = _get_esm_embedding(sequence).to(device) with torch.no_grad(): inp = _build_model_input(emb, sequence) prob = torch.sigmoid(model(inp) / temperature).squeeze() if prob.dim() == 0: prob = prob.unsqueeze(0) raw_preds = [] prob_map = {} for i in mf_idx: pv = float(prob[i]) thresh = float(thresholds.get(str(i), custom_threshold)) if pv >= thresh: go_id = mlb.classes_[i] display_id = go_replaced.get(go_id, go_id) display_nm = go_map.get(display_id, go_map.get(go_id, go_id)) entry = { "go_id": display_id, "name": display_nm, "prob": round(pv, 4), "depth": go_depth.get(display_id, -1), } if display_id != go_id: entry["original_id"] = go_id raw_preds.append(entry) prob_map[display_id] = pv raw_preds.sort(key=lambda x: x["prob"], reverse=True) cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(sequence)) raw_preds = raw_preds[:cap] for rp in raw_preds: prob_map[rp["go_id"]] = rp["prob"] visible, suppressed = propagate_and_filter( raw_preds, go_parents, go_ancestors, prob_map ) if not request.include_suppressed: suppressed = [] results.append({ "name": name, "sequence_length": len(sequence), "predictions": visible, "suppressed": suppressed, "n_above_threshold": len(raw_preds), "n_implied_parents": sum(1 for p in visible if p.get("implied")), }) except Exception as e: results.append({"name": name, "error": str(e)}) return { "results": results, "total": len(results), "successful": sum(1 for r in results if "error" not in r), } class GoTermsRequest(BaseModel): sequence: str uniprot_id: str = "" taxon: str = "auto" top_k: int = 20 include_implied: 
bool = False min_prob: float = 0.0 _go_terms_cache: dict = {} @app.post("/api/go_terms") async def get_go_terms(request: GoTermsRequest): """ Lightweight GO-MF prediction endpoint for programmatic/pipeline use. Returns only predicted term list — no suppressed, no taxon explanation, no Platt metadata. ~2× faster than /predict for downstream tools. Results are cached by (sequence_hash, uniprot_id, taxon). """ import hashlib seq = request.sequence.strip() cache_key = (hashlib.md5(seq.encode()).hexdigest(), request.uniprot_id.upper(), request.taxon) if cache_key in _go_terms_cache: return _go_terms_cache[cache_key] err = validate_sequence("seq", seq) if err: return {"error": err, "predictions": []} if len(seq) > 1500: return {"error": "Sequence too long (max 1500 aa)", "predictions": []} try: mf_idx = app.state.mf_indices emb = _get_esm_embedding(seq).to(device) emb_np = emb.detach().cpu().numpy() taxon_source = request.taxon taxon_conf = 1.0 if taxon_source == "auto": if taxon_probe is not None: taxon_source, taxon_conf = _detect_taxon_probe(emb_np) else: taxon_source, taxon_conf = _detect_taxon_composition(seq) if taxon_source == "insect": t_apply = insect_temperature thresh_lookup = {} t_default = insect_threshold_default mammal_floor = 0.0 else: t_apply = temperature thresh_lookup = thresholds t_default = 0.68 mammal_floor = 0.56 with torch.no_grad(): inp = _build_model_input(emb, seq) logits = model(inp).squeeze() if platt_params and taxon_source != "insect": probs_list = [] for i in range(len(logits)): l = float(logits[i]) if str(i) in platt_params: A, B = platt_params[str(i)] p = 1.0 / (1.0 + math.exp(-(A * l + B))) else: p = 1.0 / (1.0 + math.exp(-l / t_apply)) probs_list.append(p) prob = torch.tensor(probs_list) else: prob = torch.sigmoid(logits / t_apply) if prob.dim() == 0: prob = prob.unsqueeze(0) raw_preds = [] for i in mf_idx: pv = float(prob[i]) label_thresh = max(float(thresh_lookup.get(str(i), t_default)), mammal_floor) if pv >= label_thresh: go_id = 
mlb.classes_[i] display_id = go_replaced.get(go_id, go_id) display_nm = go_map.get(display_id, go_map.get(go_id, go_id)) raw_preds.append({ "go_id": display_id, "name": display_nm, "prob": round(pv, 4), "depth": go_depth.get(display_id, -1), }) raw_preds.sort(key=lambda x: x["prob"], reverse=True) cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(seq)) raw_preds = raw_preds[:cap] if not request.include_implied: predictions = raw_preds[:request.top_k] else: prob_map = {p["go_id"]: p["prob"] for p in raw_preds} visible, _ = propagate_and_filter(raw_preds, go_parents, go_ancestors, prob_map) predictions = [p for p in visible if not p.get("implied")][:request.top_k] if request.min_prob > 0: predictions = [p for p in predictions if p["prob"] >= request.min_prob] result = { "predictions": predictions, "n_predicted": len(predictions), "taxon_applied": taxon_source, "taxon_conf": round(taxon_conf, 3), "taxon_source_method": ( "probe" if taxon_probe is not None and request.taxon == "auto" else ("composition" if request.taxon == "auto" else "manual") ), "platt_applied": bool(platt_params) and taxon_source != "insect", "threshold_default": t_default, } _go_terms_cache[cache_key] = result return result except Exception as e: return {"error": str(e)[:300], "predictions": []} @app.get("/api/explain_terms") async def explain_terms(ids: str = ""): """ Return name + definition for a comma-separated list of GO IDs. Used by the frontend "Why?" panel to show term descriptions inline. """ if not ids: return {"terms": []} result = [] for gid in ids.split(",")[:50]: gid = gid.strip() if not gid: continue name = go_map.get(gid, gid) defn = go_defs.get(gid, "") result.append({"id": gid, "name": name, "definition": defn}) return {"terms": result} @app.post("/api/saliency") async def compute_saliency(request: SaliencyRequest): """ Compute per-residue gradient saliency for a protein sequence. Uses d(sum_of_top_k_probs)/d(ESM_residue_representations) via backprop. 
Optionally fetches AlphaFold structure if uniprot_id is provided. Returns normalized per-residue importance scores in [0, 1]. """ import numpy as np import urllib.request, urllib.error sequence = re.sub(r"\s+", "", request.sequence.upper()) if not sequence: return {"error": "Empty sequence"} if len(sequence) > 1200: return {"error": "Sequence too long for saliency (max 1200 aa)"} err = validate_sequence("query", sequence) if err: return {"error": err} taxon = (request.taxon or "auto").lower() t_apply = insect_temperature if taxon == "insect" else temperature try: _, _, tokens = batch_converter([("p", sequence)]) tokens = tokens.to(device) L = len(sequence) with torch.enable_grad(): # Run ESM keeping computation graph; retain grad on residue reps out = esm_model(tokens, repr_layers=[6]) residue_reps = out["representations"][6] # (1, L+2, 320) residue_reps.retain_grad() # Mean-pool (stays in graph) emb = residue_reps[0, 1:L + 1].mean(0) # (320,) # Build MLP input in-graph (gradient-safe version of _build_model_input) if model_uses_supp and supp_mu is not None: feats = compute_seq_features(sequence) s_vec = torch.tensor( [(feats.get(c, 0.0) - float(supp_mu[j])) / (float(supp_sd[j]) + 1e-12) for j, c in enumerate(supp_cols)], dtype=torch.float32, device=device ) # flag tensor (no grad needed) flag = torch.ones(1, dtype=torch.float32, device=device) inp_full = torch.cat([emb, s_vec, flag]).unsqueeze(0) expected = model.fc_in.weight.shape[1] inp = inp_full[:, :expected] else: inp = emb.unsqueeze(0) logits = model(inp) / t_apply probs = torch.sigmoid(logits[0]) # (8124,) # Objective: sum of top-k predicted probabilities k = min(request.top_k, int((probs > 0.3).sum().item()), 8124) k = max(k, 5) top_vals = probs.topk(k).values objective = top_vals.sum() objective.backward() scores = [0.0] * L if residue_reps.grad is not None: grad = residue_reps.grad[0, 1:L + 1].norm(dim=-1).detach().cpu().numpy() mn, mx = grad.min(), grad.max() scores = ((grad - mn) / (mx - mn + 
1e-8)).tolist() # Fetch AlphaFold structure if accession given structure = None uid = request.uniprot_id.upper().strip() if uid and re.match(r'^[A-Z0-9]{4,10}$', uid): try: req = urllib.request.Request( f"https://alphafold.ebi.ac.uk/api/prediction/{uid}", headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"} ) with urllib.request.urlopen(req, timeout=8) as resp: d = json.loads(resp.read())[0] structure = { "found": True, "accession": uid, "cif_url": d.get("cifUrl", ""), "pdb_url": d.get("pdbUrl", ""), "organism": d.get("organismScientificName", ""), "gene": d.get("gene", ""), "entry_url": f"https://alphafold.ebi.ac.uk/entry/{uid}", "uniprot_url": f"https://www.uniprot.org/uniprot/{uid}", } except Exception: structure = {"found": False, "accession": uid} return { "sequence_length": L, "per_residue_scores": scores, "taxon": taxon, "structure": structure, } except Exception as e: return {"error": str(e)[:200]} @app.post("/api/explainability") async def compute_explainability(request: ExplainRequest): """ Compute feature-level importance for the 11 sequence features via gradient × input. Also returns per-residue feature maps for 3D structure coloring. Gradient flows only through the first 11 supp features; ESM embedding is detached. 
""" import numpy as np import urllib.request, urllib.error sequence = re.sub(r"\s+", "", request.sequence.upper()) if not sequence: return {"error": "Empty sequence"} if len(sequence) > 1200: return {"error": "Sequence too long for explainability (max 1200 aa)"} err = validate_sequence("query", sequence) if err: return {"error": err} if not model_uses_supp or supp_mu is None: return {"error": "Feature importance requires a supplemented model (unified_v1 or later)"} taxon = (request.taxon or "auto").lower() t_apply = insect_temperature if taxon == "insect" else temperature try: emb = _get_esm_embedding(sequence).to(device) # (320,) detached feats = compute_seq_features(sequence) s_vec = np.array([feats.get(c, 0.0) for c in supp_cols], dtype=np.float32) s_z = (s_vec - supp_mu) / (supp_sd + 1e-12) n_tracked = min(11, len(supp_cols)) s_z_11 = torch.tensor(s_z[:n_tracked], requires_grad=True, dtype=torch.float32, device=device) s_z_rest = torch.tensor(s_z[n_tracked:], dtype=torch.float32, device=device) flag = torch.ones(1, dtype=torch.float32, device=device) inp_full = torch.cat([emb.detach(), s_z_11, s_z_rest, flag]).unsqueeze(0) expected = model.fc_in.weight.shape[1] inp = inp_full[:, :expected] with torch.enable_grad(): logits = model(inp) / t_apply probs = torch.sigmoid(logits[0]) k = min(request.top_k, int((probs > 0.3).sum().item()), 8124) k = max(k, 5) probs.topk(k).values.sum().backward() if s_z_11.grad is None: return {"error": "Gradient computation failed — no gradient on supp features"} grad_np = s_z_11.grad.detach().cpu().numpy() s_z_11np = s_z_11.detach().cpu().numpy() attribution = grad_np * s_z_11np # signed grad × input per_residue_maps = compute_per_residue_features(sequence) feat_meta_by_key = {fm["key"]: fm for fm in FEATURE_META} features = [] for i in range(n_tracked): col = supp_cols[i] meta = feat_meta_by_key.get(col, {"key": col, "label": col, "desc": col, "color": "#888888"}) attr = float(attribution[i]) features.append({ "key": col, "label": 
meta["label"], "desc": meta["desc"], "color": meta["color"], "importance": round(attr, 4), "abs_importance": round(abs(attr), 4), "per_residue": per_residue_maps.get(col, [0.5] * len(sequence)), }) features.sort(key=lambda x: x["abs_importance"], reverse=True) # Fetch AlphaFold structure structure = None uid = request.uniprot_id.upper().strip() if uid and re.match(r'^[A-Z0-9]{4,10}$', uid): try: req = urllib.request.Request( f"https://alphafold.ebi.ac.uk/api/prediction/{uid}", headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"} ) with urllib.request.urlopen(req, timeout=8) as resp: d = json.loads(resp.read())[0] structure = { "found": True, "accession": uid, "cif_url": d.get("cifUrl", ""), "organism": d.get("organismScientificName", ""), "gene": d.get("gene", ""), "entry_url": f"https://alphafold.ebi.ac.uk/entry/{uid}", "uniprot_url": f"https://www.uniprot.org/uniprot/{uid}", } except Exception: structure = {"found": False, "accession": uid} return { "sequence_length": len(sequence), "features": features, "top_feature": features[0]["key"] if features else None, "structure": structure, "taxon": taxon, } except Exception as e: return {"error": str(e)[:300]} if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)