# protfunc/server.py
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
from pydantic import BaseModel
import torch
import torch.nn as nn
import pandas as pd
import joblib
import json
import math
import os
import re
import time
import hashlib
import warnings
from functools import lru_cache
warnings.filterwarnings("ignore")
# Use up to 8 CPU threads for faster ESM inference
torch.set_num_threads(min(os.cpu_count() or 4, 8))
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(BASE_DIR, "static")
os.makedirs(STATIC_DIR, exist_ok=True)
HF_REPO = "Sbhat2026/protfunc-models"
# Load priority (see ckpt_candidates in lifespan): 35M > unified_v1 > mammal_enriched > v3_fixed > improved > v3 > supp_res2 > baseline
HF_FILES = [
"unified_35M_v1.pth", "unified_35M_v1_thresholds.json",
"unified_v1.pth", "unified_v1_recalibrated.json",
"mammal_enriched.pth", "mammal_enriched_thresholds.json",
"protfunc_v3_fixed.pth", "protfunc_v3_fixed_thresholds.json",
"protfunc_v3.pth", "protfunc_v3_thresholds.json",
"improved_res.pth", "improved_per_label_thresholds.json",
"supp_res2.pth",
"baseline_res.pth", "mlb_public_v1.pkl", "go_names.json",
]
OPTIONAL = {
"go_names.json",
"unified_35M_v1.pth", "unified_35M_v1_thresholds.json",
"mammal_enriched.pth", "mammal_enriched_thresholds.json",
"protfunc_v3_fixed.pth", "protfunc_v3_fixed_thresholds.json",
"protfunc_v3.pth", "protfunc_v3_thresholds.json",
"improved_res.pth", "improved_per_label_thresholds.json",
"supp_res2.pth",
}
# Globals populated during lifespan startup
device = torch.device("cpu") # always CPU on HF Space
model = None
esm_model = None
batch_converter = None
mlb = None
go_map = {}
go_defs = {} # GO ID -> definition string (from OBO def: field)
mf_terms = set()
go_parents = {} # GO ID -> set of direct parent GO IDs (MF DAG)
go_ancestors = {} # GO ID -> full set of ancestor GO IDs (transitive)
go_depth = {} # GO ID -> min depth from MF root (root = 0)
go_replaced = {} # obsolete GO ID -> replacement GO ID
mf_indices = None
thresholds = {} # label_idx (str) -> float threshold (mammal-calibrated)
temperature = 1.0 # temperature scaling T for mammal inference (logit/T before sigmoid)
# Insect-specific inference params (flat threshold, no temperature scaling needed)
insect_temperature = 1.0
insect_threshold_default = 0.68
NUM_LABELS = 0
_ESM_DIM = 320 # updated to 480 when unified_35M_v1 is loaded
# Supplemented model stats (loaded from checkpoint if present)
supp_mu = None # np.ndarray shape (SUPP_DIM,)
supp_sd = None # np.ndarray shape (SUPP_DIM,)
supp_cols = None # list[str]
model_uses_supp = False # True when model expects supp features
# Taxon probe (logistic regression on ESM embeddings, loaded from taxon_probe.json)
taxon_probe = None # dict with scaler_mean, scaler_std, coef, intercept
# Platt scaling (per-label logistic regression on logits, loaded from platt_mammal.json)
platt_params = {} # label_idx str -> [A, B]
# UniProt taxonomy + annotation cache
_uniprot_cache: dict = {}
# Biological complexity filter constants
MIN_SEQ_LENGTH = 30
MIN_ENTROPY_BITS = 2.5
MAX_DOMINANT_FRAC = 0.60
MIN_DISTINCT_AA = 5
INVALID_AA = set("BJOUXZ")
MF_ROOT = "GO:0003674"
# Kyte-Doolittle hydrophobicity scale
_KD = {'A':1.8,'R':-4.5,'N':-3.5,'D':-3.5,'C':2.5,'Q':-3.5,'E':-3.5,
'G':-0.4,'H':-3.2,'I':4.5,'L':3.8,'K':-3.9,'M':1.9,'F':2.8,
'P':-1.6,'S':-0.8,'T':-0.7,'W':-0.9,'Y':-1.3,'V':4.2}
# Chou-Fasman helix/sheet propensities
_CF_HELIX = {'A':1.42,'R':0.98,'N':0.67,'D':1.01,'C':0.70,'Q':1.11,'E':1.51,
'G':0.57,'H':1.00,'I':1.08,'L':1.21,'K':1.16,'M':1.45,'F':1.13,
'P':0.57,'S':0.77,'T':0.83,'W':1.08,'Y':0.69,'V':1.06}
_CF_SHEET = {'A':0.83,'R':0.93,'N':0.89,'D':0.54,'C':1.19,'Q':1.10,'E':0.37,
'G':0.75,'H':0.87,'I':1.60,'L':1.30,'K':0.74,'M':1.05,'F':1.38,
'P':0.55,'S':0.75,'T':1.19,'W':1.37,'Y':1.47,'V':1.70}
# Disorder-promoting residues (Uversky)
_DISORDER_PROMOTING = set("AERSQGKPTD")
_TM_HYDROPHOBIC = set("AVILMFYW")
_CHARGE = {"K": 1.0, "R": 1.0, "D": -1.0, "E": -1.0}
INSECT_LINEAGE = {"Insecta", "Hexapoda", "Arthropoda", "Chelicerata", "Myriapoda"}
MAMMAL_LINEAGE = {"Mammalia", "Theria", "Eutheria", "Metatheria", "Monotremata"}
def _detect_taxon_composition(seq: str) -> tuple:
"""
Heuristic taxon detection from amino acid composition.
Uses a linear discriminant calibrated from insect/mammal proteome statistics.
Returns ('insect'|'mammal', confidence_float).
Confidence < 0.60 → ambiguous, caller should treat as 'auto'.
"""
n = len(seq)
if n < 60:
return "mammal", 0.50
seq_u = seq.upper()
freq = {aa: seq_u.count(aa) / n for aa in "ACDEFGHIKLMNPQRSTVWY"}
# Linear discriminant (positive = mammal, negative = insect)
# Derived from empirical insect vs mammal proteome frequency differences
score = 0.0
score += (freq.get("K", 0) - 0.058) * 18.0 # Lys enriched in mammals
score += (freq.get("G", 0) - 0.073) * -12.0 # Gly enriched in insects
score += (freq.get("A", 0) - 0.072) * -9.0 # Ala enriched in insects
score += (freq.get("S", 0) - 0.077) * 7.0 # Ser enriched in mammals
score += (freq.get("T", 0) - 0.056) * 6.0 # Thr enriched in mammals
score += (freq.get("P", 0) - 0.055) * -5.0 # Pro enriched in insects
score += (freq.get("R", 0) - 0.053) * 4.0 # Arg enriched in mammals
p_mammal = 1.0 / (1.0 + math.exp(-score * 2.5))
if p_mammal >= 0.68:
return "mammal", round(p_mammal, 3)
elif p_mammal <= 0.32:
return "insect", round(1.0 - p_mammal, 3)
return "mammal", 0.50 # uncertain → default mammal
def _detect_taxon_probe(emb_np) -> tuple:
"""Use the trained logistic probe on ESM embedding if loaded."""
if taxon_probe is None:
return None, 0.0
import numpy as np
w = taxon_probe["coef"]
b = taxon_probe["intercept"]
mu = taxon_probe["scaler_mean"]
sd = taxon_probe["scaler_std"]
x = (emb_np - mu) / (sd + 1e-12)
logit = float(np.dot(x, w) + b)
p_mammal = 1.0 / (1.0 + math.exp(-logit))
if p_mammal >= 0.65:
return "mammal", round(p_mammal, 3)
elif p_mammal <= 0.35:
return "insect", round(1.0 - p_mammal, 3)
return "mammal", 0.50
FEATURE_META = [
{"key": "f_seq_len", "label": "Sequence Length", "desc": "Global protein length (uniform per-residue)", "color": "#888888"},
{"key": "f_mean_hydro", "label": "Hydrophobicity", "desc": "Kyte-Doolittle hydrophobicity (window=5)", "color": "#f4a261"},
{"key": "f_net_charge", "label": "Net Charge", "desc": "Local charge balance K+R−D−E (window=9)", "color": "#457b9d"},
{"key": "f_uversky_disorder", "label": "Disorder Score", "desc": "Uversky charge-hydrophobicity disorder criterion (window=11)", "color": "#9b5de5"},
{"key": "f_idr_frac_proxy", "label": "IDR Residues", "desc": "Disorder-promoting residues: A,E,R,S,Q,G,K,P,T,D", "color": "#00b4d8"},
{"key": "f_lowcomp_proxy", "label": "Low Complexity", "desc": "Repetitive amino acid runs (length ≥5)", "color": "#adb5bd"},
{"key": "f_tm_frac_proxy", "label": "TM Helix Windows", "desc": "Transmembrane helix windows (≥17/20 hydrophobic residues)", "color": "#e63946"},
{"key": "f_tm_any_proxy", "label": "TM Present", "desc": "Presence of any transmembrane window", "color": "#c1121f"},
{"key": "f_signal_peptide_proxy", "label": "Signal Peptide", "desc": "N-terminal hydrophobic signal (linear decay, first 30 aa)", "color": "#2d6a4f"},
{"key": "f_cf_helix_mean", "label": "α-Helix Propensity", "desc": "Chou-Fasman α-helix propensity per residue", "color": "#4361ee"},
{"key": "f_cf_sheet_mean", "label": "β-Sheet Propensity", "desc": "Chou-Fasman β-sheet propensity per residue", "color": "#e76f51"},
]
def compute_seq_features(seq: str) -> dict:
"""
Compute the 11 sequence-based supplementary features that are always
available at inference time. Returns a dict keyed by SUPP_COL names.
AF-derived features (f_afdb_has_model, f_plddt_*, f_distbin_*, f_pae_*,
f_af_present) are set to 0 — they will be z-scored to near-zero against
training means where >99% of proteins also had no AF data; f_seqfeat_present
is set to 1 because sequence features are always computable.
"""
seq_u = seq.upper()
n = len(seq_u)
kd = [_KD.get(aa, 0.0) for aa in seq_u]
mean_hydro = sum(kd) / n
net_charge = (seq_u.count('R') + seq_u.count('K') -
seq_u.count('D') - seq_u.count('E')) / n
# Uversky charge-hydrophobicity disorder criterion
uversky_disorder = float(abs(mean_hydro) - abs(net_charge) < 0.06)
idr_frac = sum(1 for aa in seq_u if aa in _DISORDER_PROMOTING) / n
# Low-complexity: runs of the same amino acid
lowcomp = 0
prev, run = '', 0
for aa in seq_u:
run = run + 1 if aa == prev else 1
if run >= 5:
lowcomp += 1
prev = aa
lowcomp_proxy = lowcomp / n
# TM helix proxy: windows of ≥17 hydrophobic residues in 20-aa window
tm_count = 0
for i in range(n - 19):
window = seq_u[i:i+20]
if sum(1 for aa in window if aa in _TM_HYDROPHOBIC) >= 17:
tm_count += 1
tm_frac = tm_count / max(n - 19, 1)
tm_any = float(tm_count > 0)
# Signal peptide proxy: first 30 aa have a hydrophobic core
sp_window = seq_u[:30]
sp_proxy = float(sum(1 for aa in sp_window if aa in _TM_HYDROPHOBIC) >= 8)
# Chou-Fasman secondary structure propensity
cf_helix = sum(_CF_HELIX.get(aa, 1.0) for aa in seq_u) / n
cf_sheet = sum(_CF_SHEET.get(aa, 1.0) for aa in seq_u) / n
return {
"f_seq_len": float(n),
"f_mean_hydro": float(mean_hydro),
"f_net_charge": float(net_charge),
"f_uversky_disorder": float(uversky_disorder),
"f_idr_frac_proxy": float(idr_frac),
"f_lowcomp_proxy": float(lowcomp_proxy),
"f_tm_frac_proxy": float(tm_frac),
"f_tm_any_proxy": float(tm_any),
"f_signal_peptide_proxy":float(sp_proxy),
"f_cf_helix_mean": float(cf_helix),
"f_cf_sheet_mean": float(cf_sheet),
# AF-derived features: absent at inference → use 0 (imputed to mean)
"f_afdb_has_model":0.0,"f_plddt_mean":0.0,"f_plddt_std":0.0,
"f_plddt_q10":0.0,"f_plddt_q50":0.0,"f_plddt_q90":0.0,
"f_plddt_frac_gt90":0.0,"f_plddt_frac_gt70":0.0,"f_plddt_frac_lt50":0.0,
"f_distbin_0":0.0,"f_distbin_1":0.0,"f_distbin_2":0.0,"f_distbin_3":0.0,
"f_distbin_4":0.0,"f_distbin_5":0.0,"f_distbin_6":0.0,"f_distbin_7":0.0,
"f_distbin_8":0.0,"f_distbin_9":0.0,
"f_pae_mean":0.0,"f_pae_median":0.0,"f_pae_p90":0.0,"f_pae_p95":0.0,
"f_pae_frac_lt5":0.0,"f_pae_frac_lt10":0.0,"f_pae_frac_gt20":0.0,
"f_seqfeat_present":1.0,"f_af_present":0.0,
}
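# Usage sketch (hypothetical sequence; mirrors the z-scoring done in _build_model_input):
#   feats = compute_seq_features("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
#   feats["f_seq_len"]         -> 33.0
#   feats["f_seqfeat_present"] -> 1.0   # sequence features always computable
#   feats["f_af_present"]      -> 0.0   # AF data absent at inference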
def compute_per_residue_features(seq: str) -> dict:
"""Return per-residue vectors (normalized [0,1]) for the 11 supp features."""
seq_u = seq.upper()
n = len(seq_u)
def smooth(arr, w):
hw = w // 2
out = []
for i in range(n):
lo, hi = max(0, i - hw), min(n, i + hw + 1)
out.append(sum(arr[lo:hi]) / (hi - lo))
return out
def normalize(arr):
mn, mx = min(arr), max(arr)
if mx == mn:
return [0.5] * n
return [(v - mn) / (mx - mn) for v in arr]
kd_raw = [_KD.get(aa, 0.0) for aa in seq_u]
charge_raw = [_CHARGE.get(aa, 0.0) for aa in seq_u]
# f_seq_len: uniform, no per-residue signal
r_seq_len = [1.0] * n
# f_mean_hydro: KD per residue smoothed window=5
r_mean_hydro = normalize(smooth(kd_raw, 5))
# f_net_charge: sliding charge window=9
r_net_charge = normalize(smooth(charge_raw, 9))
# f_uversky_disorder: window=11, high = more disordered tendency
uv_raw = []
for i in range(n):
lo, hi = max(0, i - 5), min(n, i + 6)
wkd = sum(kd_raw[lo:hi]) / (hi - lo)
wch = sum(charge_raw[lo:hi]) / (hi - lo)
uv_raw.append(max(0.0, 0.20 - (abs(wkd) - abs(wch))))
r_uversky = normalize(uv_raw)
# f_idr_frac_proxy: binary disorder residue indicator, smoothed
idr_raw = [1.0 if aa in _DISORDER_PROMOTING else 0.0 for aa in seq_u]
r_idr = normalize(smooth(idr_raw, 7))
# f_lowcomp_proxy: residues in amino-acid runs ≥5
lowcomp_raw = [0.0] * n
prev, run = '', 0
for j, aa in enumerate(seq_u):
run = run + 1 if aa == prev else 1
prev = aa
if run >= 5:
for k in range(max(0, j - run + 1), j + 1):
lowcomp_raw[k] = 1.0
r_lowcomp = lowcomp_raw
# f_tm_frac_proxy: residues covered by TM windows (≥17/20 hydrophobic)
tm_raw = [0.0] * n
if n >= 20:
for i in range(n - 19):
if sum(1 for aa in seq_u[i:i+20] if aa in _TM_HYDROPHOBIC) >= 17:
for k in range(i, i + 20):
tm_raw[k] = 1.0
r_tm_frac = tm_raw
# f_tm_any_proxy: same map (at residue level = tm_frac)
r_tm_any = tm_raw
# f_signal_peptide_proxy: linear decay × hydrophobicity, first 30 aa
sp_mod = []
for i in range(n):
weight = max(0.0, 1.0 - i / 30.0)
sp_mod.append(weight * max(0.0, kd_raw[i] / 4.5))
r_sp = normalize(sp_mod) if any(v > 0 for v in sp_mod) else [max(0.0, 1.0 - i / 30.0) for i in range(n)]
# f_cf_helix_mean / f_cf_sheet_mean: propensity per residue smoothed
r_cf_helix = normalize(smooth([_CF_HELIX.get(aa, 1.0) for aa in seq_u], 3))
r_cf_sheet = normalize(smooth([_CF_SHEET.get(aa, 1.0) for aa in seq_u], 3))
return {
"f_seq_len": [round(v, 3) for v in r_seq_len],
"f_mean_hydro": [round(v, 3) for v in r_mean_hydro],
"f_net_charge": [round(v, 3) for v in r_net_charge],
"f_uversky_disorder": [round(v, 3) for v in r_uversky],
"f_idr_frac_proxy": [round(v, 3) for v in r_idr],
"f_lowcomp_proxy": [round(v, 3) for v in r_lowcomp],
"f_tm_frac_proxy": [round(v, 3) for v in r_tm_frac],
"f_tm_any_proxy": [round(v, 3) for v in r_tm_any],
"f_signal_peptide_proxy": [round(v, 3) for v in r_sp],
"f_cf_helix_mean": [round(v, 3) for v in r_cf_helix],
"f_cf_sheet_mean": [round(v, 3) for v in r_cf_sheet],
}
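# Usage sketch: these per-residue tracks pair 1:1 with FEATURE_META for the UI, e.g.
#   tracks = compute_per_residue_features(seq)
#   for meta in FEATURE_META:
#       draw(tracks[meta["key"]], color=meta["color"], label=meta["label"])  # draw() is hypothetical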
def build_ancestor_cache(go_parents_map: dict) -> dict:
"""Compute full transitive ancestor sets for all GO terms (memoized DFS)."""
cache = {}
def _anc(gid):
if gid in cache:
return cache[gid]
parents = go_parents_map.get(gid, set())
all_anc = set(parents)
for p in parents:
all_anc |= _anc(p)
cache[gid] = all_anc
return all_anc
for gid in go_parents_map:
_anc(gid)
return cache
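# Toy example (hypothetical IDs):
#   build_ancestor_cache({"a": set(), "b": {"a"}, "c": {"b"}})
#   -> {"a": set(), "b": {"a"}, "c": {"a", "b"}}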
def _download_with_retry(fname):
from huggingface_hub import hf_hub_download
max_attempts = 6
for attempt in range(1, max_attempts + 1):
try:
print(f" [{attempt}/{max_attempts}] Downloading {fname}...")
path = hf_hub_download(
repo_id=HF_REPO, filename=fname,
local_dir=BASE_DIR, repo_type="model",
token=os.environ.get("HF_TOKEN"),
)
print(f" saved -> {path}")
return
except Exception as e:
if fname in OPTIONAL:
print(f" {fname} is optional, skipping ({e})")
return
if attempt == max_attempts:
raise RuntimeError(f"Could not download '{fname}' after {max_attempts} attempts: {e}")
wait = 2 ** attempt
print(f" Network error, retrying in {wait}s... ({e})")
time.sleep(wait)
def ensure_model_files():
missing = [f for f in HF_FILES if not os.path.exists(os.path.join(BASE_DIR, f))]
if not missing:
print("All model files already present.")
return
print(f"Downloading {len(missing)} file(s) from HuggingFace Hub...")
for fname in missing:
_download_with_retry(fname)
def load_go_map():
try:
df = pd.read_csv(os.path.join(BASE_DIR, "go_annotations_fixed.csv"))
mapping = {}
for _, row in df.iterrows():
go_id = str(row["GO Annotation"]).strip()
raw_name = str(row.get("Gene Ontology (molecular function)", "Unknown"))
mapping[go_id] = raw_name.split(" [")[0].strip()
print(f"GO map: {len(mapping)} labels loaded")
return mapping
except Exception as e:
print(f"GO map load error: {e}")
return {}
def load_thresholds():
"""
Load per-label thresholds and return (mammal_thresholds, mammal_T, insect_T, insect_t_default).
Threshold JSON formats accepted:
{"0": 0.68, ...} — plain float
{"0": {"threshold": 0.68, "tier": 0, "temperature": 3.69}} — rich dict
{"_meta": {...}, "0": 0.68, ...} — new format with metadata
Returns:
flat : dict str_idx -> float (mammal per-label thresholds)
mammal_T : float temperature for mammal inference
ins_T : float temperature for insect inference (usually 1.0)
ins_t_default : float flat threshold for insect labels with no per-label data
"""
for path in [
os.path.join(BASE_DIR, "unified_35M_v1_thresholds.json"), # 35M model thresholds (latest)
os.path.join(BASE_DIR, "unified_v1_recalibrated.json"), # 8M recalibrated: T=3.85, precision-tuned
os.path.join(BASE_DIR, "unified_v1_thresholds.json"),
os.path.join(BASE_DIR, "mammal_enriched_thresholds.json"),
os.path.join(BASE_DIR, "protfunc_v3_fixed_thresholds.json"),
os.path.join(BASE_DIR, "improved_per_label_thresholds.json"),
os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"),
os.path.join(BASE_DIR, "per_label_thresholds.json"),
os.path.join(BASE_DIR, "artifacts", "per_label_thresholds.json"),
]:
if not os.path.exists(path):
continue
print(f"Thresholds loaded from {path}")
with open(path) as f:
raw = json.load(f)
flat = {}
mammal_T = 1.0
ins_T = 1.0
ins_t_def = 0.68
# Extract metadata block if present
meta = raw.pop("_meta", None)
if meta:
mammal_T = float(meta.get("temperature", mammal_T))
ins_T = float(meta.get("insect_temperature", ins_T))
ins_t_def = float(meta.get("insect_global_t", ins_t_def))
for k, v in raw.items():
if isinstance(v, dict):
flat[k] = float(v.get("threshold", 0.5))
mammal_T = float(v.get("temperature", mammal_T))
else:
try:
flat[k] = float(v)
except (TypeError, ValueError):
pass
print(f" {len(flat)} mammal thresholds | mammal_T={mammal_T:.4f} | "
f"insect_T={ins_T:.4f} insect_t={ins_t_def:.2f}")
return flat, mammal_T, ins_T, ins_t_def
print("Thresholds not found, using defaults")
return {}, 1.0, 1.0, 0.68
def parse_obo(path):
"""
Parse go-basic.obo and return:
mf_terms : set of active (non-obsolete) GO IDs with namespace == molecular_function
go_parents : dict GO ID -> set of direct parent GO IDs (is_a, part_of, regulates edges; MF only)
go_names_ob : dict GO ID -> canonical name from OBO (authoritative)
go_replaced : dict obsolete GO ID -> replacement GO ID
go_depth : dict GO ID -> minimum depth from MF root (root = 0)
All relationships are restricted to the MF namespace.
"""
ns_map = {} # id -> namespace
par_map = {} # id -> {parent ids}
name_map = {} # id -> canonical name
def_map = {} # id -> definition string
rep_map = {} # obsolete id -> replaced_by id
alt_map = {} # alt_id -> canonical id
obs_set = set()
cur_id = None
cur_ns = None
cur_nm = None
cur_df = None
cur_par = set()
cur_rep = None
cur_obs = False
cur_alt = []
in_term = False
def flush():
nonlocal cur_id, cur_ns, cur_nm, cur_df, cur_par, cur_rep, cur_obs, cur_alt
if cur_id:
if cur_obs:
obs_set.add(cur_id)
if cur_rep:
rep_map[cur_id] = cur_rep
else:
ns_map[cur_id] = cur_ns or ""
par_map[cur_id] = cur_par
name_map[cur_id] = cur_nm or cur_id
if cur_df:
def_map[cur_id] = cur_df
for a in cur_alt:
alt_map[a] = cur_id
cur_id = None; cur_ns = None; cur_nm = None; cur_df = None
cur_par = set(); cur_rep = None; cur_obs = False; cur_alt = []
with open(path, "r", encoding="utf-8") as fh:
for raw in fh:
line = raw.strip()
if line == "[Term]":
flush(); in_term = True; continue
if line.startswith("[") and line != "[Term]":
flush(); in_term = False; continue
if not in_term:
continue
if line.startswith("id:"):
cur_id = line.split("id:", 1)[1].strip().split()[0]
elif line.startswith("name:"):
cur_nm = line.split("name:", 1)[1].strip()
elif line.startswith("namespace:"):
cur_ns = line.split("namespace:", 1)[1].strip()
elif line.startswith("alt_id:"):
cur_alt.append(line.split("alt_id:", 1)[1].strip().split()[0])
elif line.startswith("def:"):
# def: "description text" [source] — strip quotes and source
raw_def = line.split("def:", 1)[1].strip()
if raw_def.startswith('"'):
end_q = raw_def.find('"', 1)
cur_df = raw_def[1:end_q] if end_q > 0 else raw_def
else:
cur_df = raw_def.split("[")[0].strip()
elif line.startswith("is_obsolete:") and "true" in line:
cur_obs = True
elif line.startswith("replaced_by:"):
cur_rep = line.split("replaced_by:", 1)[1].strip().split()[0]
elif line.startswith("is_a:"):
parent = line.split("is_a:", 1)[1].strip().split()[0]
cur_par.add(parent)
elif line.startswith("relationship:"):
parts = line.split("relationship:", 1)[1].strip().split()
if len(parts) >= 2 and parts[0] in ("part_of", "regulates",
"positively_regulates",
"negatively_regulates"):
cur_par.add(parts[1])
flush()
mf = {gid for gid, n in ns_map.items() if n == "molecular_function"}
go_parents_mf = {gid: (par_map[gid] & mf) for gid in mf}
go_names_ob = {gid: name_map[gid] for gid in mf}
go_defs_mf = {gid: def_map[gid] for gid in mf if gid in def_map}
n_edges = sum(len(v) for v in go_parents_mf.values())
print(f"OBO parsed: {len(mf)} MF terms, {n_edges} parent edges, "
f"{len(rep_map)} replacements, {len(alt_map)} alt-ids, "
f"{len(go_defs_mf)} definitions")
# BFS from root to compute minimum depth for each MF term
go_depth: dict = {}
go_depth[MF_ROOT] = 0
# Build children map for BFS (reverse of parents)
children: dict = {gid: set() for gid in mf}
for gid, parents in go_parents_mf.items():
for p in parents:
if p in children:
children[p].add(gid)
queue = [MF_ROOT]
while queue:
nxt = []
for gid in queue:
d = go_depth[gid]
for child in children.get(gid, ()):
if child not in go_depth:
go_depth[child] = d + 1
nxt.append(child)
queue = nxt
print(f"Depth computed: {len(go_depth)} MF terms, "
f"max depth={max(go_depth.values(), default=0)}")
return mf, go_parents_mf, go_names_ob, rep_map, go_depth, go_defs_mf
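# For reference, a typical [Term] stanza this parser consumes:
#   [Term]
#   id: GO:0003824
#   name: catalytic activity
#   namespace: molecular_function
#   is_a: GO:0003674 ! molecular_function
# yields an mf_terms entry, the parent edge GO:0003824 -> GO:0003674,
# and the canonical name used to override CSV names.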
def compute_dynamic_cap(sorted_probs: list, seq_len: int) -> int:
"""
Return a protein-proportional cap on the number of direct predictions.
Combines three signals:
1. Sequence-length prior (log2 scaling).
2. Probability-gap detection — largest relative drop within budget.
3. Diffuse-activation penalty — when predictions are bunched at low
confidence with no clear outlier, the cap is tightened. This prevents
the model outputting many near-threshold terms for proteins it is
uncertain about (e.g. sparse mammal annotations).
"""
n = len(sorted_probs)
if n == 0:
return 0
if n <= 3:
return n
# Length prior (log2 scaling)
length_prior = max(2, int(2.5 * math.log2(max(seq_len, 50) / 50) + 2))
abs_cap = min(15, length_prior * 2)
# ── Diffuse-activation penalty ──────────────────────────────────────────
# If the top prediction is below 0.75 AND the spread across all predictions
# is narrow (< 0.12), we're seeing uniform noise rather than clear signal.
top_prob = sorted_probs[0]
spread = sorted_probs[0] - sorted_probs[-1]
if top_prob < 0.75 and spread < 0.12:
# Tight cluster at low confidence: cut cap to length_prior (conservative)
abs_cap = max(3, length_prior)
elif top_prob < 0.72:
# Moderately uncertain: mild tightening
abs_cap = max(3, min(abs_cap, length_prior + 2))
search_end = min(n, abs_cap)
if search_end < 2:
return min(n, abs_cap)
best_score = -1.0
best_idx = search_end # default: all up to abs_cap
for i in range(1, search_end):
prev = sorted_probs[i - 1]
curr = sorted_probs[i]
rel_gap = (prev - curr) / (prev + 1e-6)
dist = abs(i - length_prior) / max(length_prior, 1)
score = rel_gap * (1.0 - 0.25 * min(dist, 1.0))
if score > best_score and rel_gap >= 0.08:
best_score = score
best_idx = i
return max(3, best_idx)
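# Worked example (hypothetical probabilities; seq_len=300 gives length_prior=8):
#   compute_dynamic_cap([0.95, 0.93, 0.91, 0.52, 0.50], 300) -> 3
# The largest relative gap (0.91 -> 0.52) wins, so the cap cuts at the cliff.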
def propagate_and_filter(preds, go_parents_map, go_ancestors_map, prob_map):
"""
1. Propagate predictions upward: for every predicted term, all its MF
ancestors are implicitly predicted. Ancestors not already above threshold
are added as 'implied' predictions with the child's probability.
2. Filter: a term is 'suppressed' only if it has MF parents and NONE of
them appear in the final visible set (direct or implied).
3. Specificity: implied ancestors with depth ≤ 1 (root-adjacent, trivially
general terms) are hidden from the visible list. Each prediction carries
its depth so the UI can convey specificity.
Returns (visible, suppressed) where visible includes informative implied parents.
"""
if not go_ancestors_map:
# Still annotate with depth even without ancestors
for p in preds:
p["depth"] = go_depth.get(p["go_id"], -1)
return preds, []
predicted_ids = {p["go_id"] for p in preds}
implied = {} # go_id -> max child prob
for pred in preds:
gid = pred["go_id"]
prob = pred["prob"]
for anc in go_ancestors_map.get(gid, set()):
if anc not in predicted_ids and anc != MF_ROOT:
implied[anc] = max(implied.get(anc, 0.0), prob)
all_visible_ids = predicted_ids | set(implied.keys())
# Classify direct predictions: visible if any MF parent is visible (or root / no parents)
suppressed = []
direct_ok = []
for pred in preds:
gid = pred["go_id"]
parents = go_parents_map.get(gid, set())
pred["depth"] = go_depth.get(gid, -1)
if gid == MF_ROOT or not parents or (parents & all_visible_ids):
direct_ok.append(pred)
else:
pred["reason"] = "no_visible_parent"
suppressed.append(pred)
# Add implied ancestor terms — skip root-adjacent (depth ≤ 1) and root itself
# Depth ≤ 1 = trivially general terms like "binding", "catalytic activity"
# that carry no predictive specificity on their own.
MIN_IMPLIED_DEPTH = 2
implied_preds = []
for gid, prob in implied.items():
d = go_depth.get(gid, -1)
if d < MIN_IMPLIED_DEPTH:
continue # too general — still implicitly true, just not displayed
implied_preds.append({
"go_id": gid,
"name": go_map.get(gid, gid),
"prob": round(prob, 3),
"implied": True,
"depth": d,
})
implied_preds.sort(key=lambda x: (-x["prob"], -x["depth"]))
visible = direct_ok + implied_preds
visible.sort(key=lambda x: (-x["prob"], -x.get("depth", 0)))
return visible, suppressed
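# Behaviour sketch (hypothetical terms): a direct hit on a deep term implies all
# of its MF ancestors; implied ancestors at depth <= 1 (e.g. "binding") stay
# hidden, and a direct prediction whose parents are all invisible is moved to
# `suppressed` with reason "no_visible_parent".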
def sequence_entropy(seq):
seq_upper = seq.upper()
counts = {}
for aa in seq_upper:
counts[aa] = counts.get(aa, 0) + 1
n = len(seq_upper)
return -sum((c / n) * math.log2(c / n) for c in counts.values())
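# e.g. sequence_entropy("ACDE") == 2.0 bits (four equiprobable residues),
#      sequence_entropy("AAAA") == 0.0 bits (single residue, no information)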
def validate_sequence(name, seq):
"""Returns an error string if the sequence should be rejected, else None."""
if len(seq) < MIN_SEQ_LENGTH:
return (f"'{name}' is too short ({len(seq)} aa — minimum {MIN_SEQ_LENGTH} aa). "
f"Sequences this short are unlikely to fold into a stable domain.")
# Reject non-letter characters (digits, spaces, symbols)
non_letter = sorted({c for c in seq if not c.isalpha()})
if non_letter:
display = ", ".join(f"'{c}'" for c in non_letter[:5])
return (f"'{name}' contains non-amino-acid characters: {display}. "
f"Only single-letter amino acid codes are accepted.")
# Detect DNA/RNA sequences (>85% ATCGU with ≤6 distinct characters)
seq_upper_set = {c.upper() for c in seq}
nucleotide_frac = sum(seq.upper().count(c) for c in "ATCGU") / len(seq)
if nucleotide_frac > 0.85 and len(seq_upper_set) <= 6:
return (f"'{name}' appears to be a nucleotide sequence (DNA/RNA), not a protein. "
f"Please enter an amino acid sequence in single-letter code.")
bad = sorted({c.upper() for c in seq if c.upper() in INVALID_AA})
if bad:
return (f"'{name}' contains invalid amino acid character(s): "
f"{', '.join(bad)}. These ambiguity codes are not accepted.")
counts = {}
for aa in seq.upper():
counts[aa] = counts.get(aa, 0) + 1
if len(counts) < MIN_DISTINCT_AA:
return (f"'{name}' uses only {len(counts)} distinct residue type(s). "
f"Real proteins require at least {MIN_DISTINCT_AA}.")
dominant_frac = max(counts.values()) / len(seq)
if dominant_frac > MAX_DOMINANT_FRAC:
dominant_aa = max(counts, key=counts.get)
return (f"'{name}' is dominated by a single residue "
f"({dominant_aa} = {dominant_frac:.0%}). "
f"Low-complexity sequences produce unreliable embeddings.")
H = sequence_entropy(seq)
if H < MIN_ENTROPY_BITS:
return (f"'{name}' has very low sequence complexity "
f"(Shannon entropy {H:.2f} bits, minimum {MIN_ENTROPY_BITS:.1f} bits). "
f"Repetitive or artificially constructed sequences are not accepted.")
return None
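# e.g. validate_sequence("demo", "ATG" * 20) rejects a nucleotide-like input
# (100% ATCGU, 3 distinct characters), while a diverse natural protein
# sequence of >= 30 aa returns None and proceeds to inference.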
@asynccontextmanager
async def lifespan(app: FastAPI):
global device, model, esm_model, batch_converter
global mlb, go_map, go_defs, mf_terms, go_parents, go_ancestors, go_depth, go_replaced
global mf_indices, thresholds, temperature, insect_temperature, insect_threshold_default, NUM_LABELS, _ESM_DIM
global supp_mu, supp_sd, supp_cols, model_uses_supp
global taxon_probe, platt_params
# Step 1: download missing files
ensure_model_files()
# Step 2: GO name map
go_map = load_go_map()
go_names_path = os.path.join(BASE_DIR, "go_names.json")
if os.path.exists(go_names_path):
with open(go_names_path) as f:
go_map.update(json.load(f))
print(f"Canonical GO names loaded: {len(go_map)} total entries")
# Step 3: MLB — load BEFORE anything references mlb.classes_
mlb = joblib.load(os.path.join(BASE_DIR, "mlb_public_v1.pkl"))
NUM_LABELS = len(mlb.classes_)
print(f"MLB loaded: {NUM_LABELS} labels")
# Step 4: OBO — parse MF namespace, parent DAG, names, depth, replacements
obo_path = os.path.join(BASE_DIR, "go-basic.obo")
if os.path.exists(obo_path):
mf_terms, go_parents, go_names_obo, go_replaced, go_depth, go_defs_obo = parse_obo(obo_path)
go_defs.update(go_defs_obo)
# OBO canonical names are the most authoritative — merge over CSV names
go_map.update(go_names_obo)
print(f"OBO names merged: {len(go_names_obo)} MF term names")
mf_in_mlb = sum(1 for gid in mlb.classes_ if gid in mf_terms)
rep_in_mlb = sum(1 for gid in mlb.classes_ if gid in go_replaced)
print(f"OBO cross-check: {mf_in_mlb}/{NUM_LABELS} active MF, "
f"{rep_in_mlb} replaced/obsolete labels remapped")
# Build full transitive ancestor cache for parental propagation
go_ancestors = build_ancestor_cache(go_parents)
print(f"Ancestor cache built: {len(go_ancestors)} terms")
else:
print("WARNING: go-basic.obo not found — hierarchy filtering disabled. "
"Download from https://current.geneontology.org/ontology/go-basic.obo "
"and place it alongside server.py.")
# Step 5: MF-only whitelist — OBO namespace is authoritative, CSV is fallback
if mf_terms:
mf_indices = [i for i, gid in enumerate(mlb.classes_) if gid in mf_terms]
print(f"MF whitelist (OBO): {len(mf_indices)} active indices")
else:
mf_go_ids = {
go_id for go_id, name in go_map.items()
if name and name != go_id and not name.startswith("GO:")
}
mf_indices = [i for i, gid in enumerate(mlb.classes_) if gid in mf_go_ids] or list(range(NUM_LABELS))
print(f"MF whitelist (CSV fallback): {len(mf_indices)} active indices")
app.state.mf_indices = mf_indices
# Step 6: per-label thresholds (mammal-calibrated) + insect fallback params
thresholds, temperature, insect_temperature, insect_threshold_default = load_thresholds()
# Step 7: classifier — auto-detect architecture from checkpoint keys
class ResBlock(nn.Module):
"""Pre-activation residual block with BatchNorm (improved model)."""
def __init__(self, dim, dropout=0.2):
super().__init__()
self.net = nn.Sequential(
nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim),
nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim),
)
def forward(self, x):
return x + self.net(x)
class ImprovedResidualMLP(nn.Module):
"""4-block, hidden=2048, BatchNorm — trained by train_improved.py."""
def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=2048, n_blocks=4, dropout=0.2):
super().__init__()
self.fc_in = nn.Linear(in_dim, hidden)
self.blocks = nn.ModuleList([ResBlock(hidden, dropout) for _ in range(n_blocks)])
self.fc_out = nn.Sequential(
nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout),
nn.Linear(hidden, out_dim),
)
def forward(self, x):
h = self.fc_in(x)
for blk in self.blocks:
h = blk(h)
return self.fc_out(h)
class ResidualMLP(nn.Module):
"""Original 2-block notebook model (fallback)."""
def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=1024, dropout=0.2):
super().__init__()
self.fc_in = nn.Linear(in_dim, hidden)
self.block1 = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, hidden))
self.block2 = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, hidden))
self.fc_out = nn.Sequential(nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, out_dim))
def forward(self, x):
h = self.fc_in(x)
h = torch.relu(h)
h = h + self.block1(h)
h = h + self.block2(h)
return self.fc_out(h)
class RecoveredBaselineModel(nn.Module):
"""Earlier server-side architecture — retained for backward compatibility."""
def __init__(self, in_dim=320, out_dim=NUM_LABELS, hidden=1024, dropout=0.2):
super().__init__()
self.fc1 = nn.Linear(in_dim, hidden)
self.proj = nn.Linear(in_dim, hidden)
self.fc2 = nn.Linear(hidden, hidden)
self.out = nn.Linear(hidden, out_dim)
self.relu = nn.ReLU()
self.drop = nn.Dropout(dropout)
def forward(self, x):
h = self.relu(self.fc1(x))
h = h + self.proj(x)
h = self.relu(self.fc2(h))
h = self.drop(h)
return self.out(h)
import numpy as np
device = torch.device("cpu")
# Prefer checkpoints in priority order: 35M > unified_v1 > mammal_enriched > v3_fixed > improved > v3 > supp_res2 > baseline
ckpt_candidates = [
os.path.join(BASE_DIR, "unified_35M_v1.pth"),
os.path.join(BASE_DIR, "unified_v1.pth"),
os.path.join(BASE_DIR, "mammal_enriched.pth"),
os.path.join(BASE_DIR, "protfunc_v3_fixed.pth"),
os.path.join(BASE_DIR, "improved_res.pth"),
os.path.join(BASE_DIR, "protfunc_v3.pth"),
os.path.join(BASE_DIR, "supp_res2.pth"),
os.path.join(BASE_DIR, "baseline_res.pth"),
]
_ESM_DIM = 320 # updated after checkpoint load if esm_dim present
_model = None
for ckpt_path in ckpt_candidates:
if not os.path.exists(ckpt_path):
continue
print(f"Trying classifier: {os.path.basename(ckpt_path)}")
try:
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
except Exception as e:
print(f" Failed to load {os.path.basename(ckpt_path)}: {e} — skipping")
continue
sd = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
keys = set(sd.keys())
# Reset supp globals for each candidate
_c_supp_mu = None
_c_supp_sd = None
_c_supp_cols = None
_c_uses_supp = False
if isinstance(ckpt, dict) and "supp_mu" in ckpt:
_c_supp_mu = np.array(ckpt["supp_mu"], dtype=np.float32)
_c_supp_sd = np.array(ckpt["supp_sd"], dtype=np.float32)
_c_supp_cols = ckpt["supp_cols"]
_c_uses_supp = True
# Detect architecture and validate in_dim against supp metadata.
# Checkpoints that store explicit "in_dim" may use a truncated supp feature
# set (e.g. mammal_enriched uses SUPP_COLS[:11] → in_dim=331, but stores
# all 39 supp_cols for reference). Trust fc_in.weight shape in that case.
_ckpt_has_explicit_in_dim = isinstance(ckpt, dict) and "in_dim" in ckpt
# Detect out_dim from checkpoint to avoid mismatch with MLB size
_out_dim_ckpt = None
for _out_key in ("fc_out.3.weight", "fc_out.2.weight", "out.weight"):
if _out_key in sd:
_out_dim_ckpt = sd[_out_key].shape[0]
break
if "blocks.0.net.0.weight" in keys:
hidden_dim = sd["fc_in.weight"].shape[0]
n_blocks = sum(1 for k in keys if k.startswith("blocks.") and k.endswith(".net.0.weight"))
in_dim_ckpt = sd["fc_in.weight"].shape[1]
if _c_uses_supp and _c_supp_cols is not None and not _ckpt_has_explicit_in_dim:
expected = _ESM_DIM + len(_c_supp_cols) + 1
if in_dim_ckpt != expected:
print(f" SKIP: supp metadata says in_dim={expected} but fc_in has {in_dim_ckpt} — corrupted checkpoint")
continue
out_dim_use = _out_dim_ckpt or NUM_LABELS
_model = ImprovedResidualMLP(in_dim=in_dim_ckpt, hidden=hidden_dim, n_blocks=n_blocks, out_dim=out_dim_use).to(device)
print(f" ImprovedResidualMLP (in={in_dim_ckpt} hidden={hidden_dim} blocks={n_blocks} out={out_dim_use})")
elif any(k.startswith("fc_in") for k in keys):
in_dim_ckpt = sd["fc_in.weight"].shape[1]
if _c_uses_supp and _c_supp_cols is not None and not _ckpt_has_explicit_in_dim:
expected = _ESM_DIM + len(_c_supp_cols) + 1
if in_dim_ckpt != expected:
print(f" SKIP: supp metadata says in_dim={expected} but fc_in has {in_dim_ckpt} — corrupted checkpoint")
continue
out_dim_use = _out_dim_ckpt or NUM_LABELS
_model = ResidualMLP(in_dim=in_dim_ckpt, out_dim=out_dim_use).to(device)
print(f" ResidualMLP (in={in_dim_ckpt} out={out_dim_use})")
elif any(k.startswith("fc1") for k in keys):
_c_uses_supp = False
out_dim_use = _out_dim_ckpt or NUM_LABELS
_model = RecoveredBaselineModel(out_dim=out_dim_use).to(device)
print(f" RecoveredBaselineModel (legacy out={out_dim_use})")
else:
print(f" SKIP: unrecognised architecture — keys: {sorted(keys)[:8]}")
continue
try:
_model.load_state_dict(sd, strict=True)
except Exception as e:
print(f" SKIP: load_state_dict failed: {e}")
_model = None
continue
# Commit globals and break
supp_mu = _c_supp_mu
supp_sd = _c_supp_sd
supp_cols = _c_supp_cols
model_uses_supp = _c_uses_supp
if isinstance(ckpt, dict) and "val_fmax" in ckpt:
print(f" val_fmax={ckpt['val_fmax']:.4f} epoch={ckpt.get('epoch','?')}")
if _c_uses_supp:
print(f" Supplemented model: {len(_c_supp_cols)} extra features")
# Detect ESM dim from checkpoint metadata
if isinstance(ckpt, dict) and "esm_dim" in ckpt:
_ESM_DIM = int(ckpt["esm_dim"])
print(f" ESM dim from checkpoint: {_ESM_DIM}")
print(f"Classifier loaded: {os.path.basename(ckpt_path)}")
break
if _model is None:
raise RuntimeError("No valid classifier checkpoint found.")
_model.eval()
model = _model
# Step 8: ESM-2 — choose model based on detected esm_dim
import esm as esm_lib
if _ESM_DIM == 480:
_esm_model, alphabet = esm_lib.pretrained.esm2_t12_35M_UR50D()
print("ESM-2 (35M, 480-dim) loaded OK")
else:
_esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D()
print("ESM-2 (8M, 320-dim) loaded OK")
esm_model = _esm_model.to(device).eval()
batch_converter = alphabet.get_batch_converter()
# Step 9: Taxon probe (optional, generated by calibrate_server.py / calibrate_probe_35M.py)
probe_path = os.path.join(BASE_DIR, "taxon_probe.json")
if os.path.exists(probe_path):
with open(probe_path) as f:
taxon_probe = json.load(f)
acc = taxon_probe.get("train_accuracy", 0)
probe_esm_dim = taxon_probe.get("esm_dim", 320)
if probe_esm_dim != _ESM_DIM:
print(f"Taxon probe ESM dim mismatch ({probe_esm_dim} vs {_ESM_DIM}) — disabling probe")
taxon_probe = None
else:
for k in ("coef", "intercept", "scaler_mean", "scaler_std"):
if k in taxon_probe:
taxon_probe[k] = np.asarray(taxon_probe[k], dtype=np.float32)
print(f"Taxon probe loaded (train_acc={acc:.4f}, esm_dim={probe_esm_dim})")
else:
print("Taxon probe not found — using composition heuristic for auto-detection")
# Step 10: Platt scaling (optional, generated by calibrate_server.py)
platt_path = os.path.join(BASE_DIR, "platt_mammal.json")
if os.path.exists(platt_path):
with open(platt_path) as f:
platt_params = json.load(f)
print(f"Platt scaling loaded: {len(platt_params)} labels calibrated")
else:
print("Platt params not found — using temperature scaling only")
# Step 11: Override temperature from calibration sweep if available
temp_path = os.path.join(BASE_DIR, "temperature_best.json")
if os.path.exists(temp_path):
with open(temp_path) as f:
temp_data = json.load(f)
new_T = float(temp_data.get("optimal_T", temperature))
if abs(new_T - temperature) > 0.1:
print(f"Temperature updated by calibration sweep: {temperature:.4f}{new_T:.4f}")
temperature = new_T
else:
print(f"Temperature unchanged by sweep: {temperature:.4f}")
yield
print("Shutting down.")
app = FastAPI(lifespan=lifespan)
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
@app.get("/")
async def root():
return FileResponse(os.path.join(STATIC_DIR, "interface.html"), headers={"Cache-Control": "no-store"})
@app.get("/api/model/info")
async def model_info():
"""Return model metadata and configuration."""
unified_35M = os.path.exists(os.path.join(BASE_DIR, "unified_35M_v1.pth"))
unified_v1 = os.path.exists(os.path.join(BASE_DIR, "unified_v1.pth"))
mammal_enriched = os.path.exists(os.path.join(BASE_DIR, "mammal_enriched.pth"))
v3_fixed = os.path.exists(os.path.join(BASE_DIR, "protfunc_v3_fixed.pth"))
improved = os.path.exists(os.path.join(BASE_DIR, "improved_res.pth"))
# model name reflects actual loaded model (unified_35M_v1 takes highest priority)
if unified_35M and model_uses_supp and _ESM_DIM == 480:
name, version, active = "ProtFunc v5.0 (35M ESM, GOA-enriched, insect+mammal)", "5.0.0", "unified_35M_v1"
elif unified_v1 and model_uses_supp:
name, version, active = "ProtFunc v4.0 (unified insect+mammal, CAFA5-evaluated)", "4.0.0", "unified_v1"
elif mammal_enriched and model_uses_supp:
name, version, active = "ProtFunc v3.2 (mammal-enriched, CAFA5-evaluated)", "3.2.0", "mammal_enriched"
elif v3_fixed and model_uses_supp:
name, version, active = "ProtFunc v3-fixed (ablation best, CAFA-correct)", "3.1.0", "protfunc_v3_fixed"
elif model_uses_supp:
name, version, active = "ProtFunc v3 (supplemented + mammal)", "3.0.0", "protfunc_v3"
elif improved:
name, version, active = "ProtFunc Enhanced", "2.1.0", "improved"
else:
name, version, active = "ProtFunc", "2.0.0", "baseline"
return {
"model_name": name,
"model": active,
"version": version,
"esm_model": "esm2_t12_35M_UR50D" if _ESM_DIM == 480 else "esm2_t6_8M_UR50D",
"embed_dim": _ESM_DIM,
"num_labels": NUM_LABELS,
"supported_namespaces": ["molecular_function"],
"max_sequence_length": 1500,
"thresholds_loaded": len(thresholds) > 0,
"temperature_scaling": temperature != 1.0,
"temperature": round(temperature, 4),
"insect_temperature": round(insect_temperature, 4),
"insect_threshold_default": round(insect_threshold_default, 4),
"supported_taxa": ["insect", "mammal"],
"taxon_routing": True,
"hierarchy_filtering": len(go_parents) > 0,
"parental_propagation": len(go_ancestors) > 0,
"depth_annotation": len(go_depth) > 0,
"mf_terms_loaded": len(mf_terms) if mf_terms else 0,
"mf_max_depth": max(go_depth.values(), default=0) if go_depth else 0,
"supplemented_features": model_uses_supp,
"supp_feature_count": len(supp_cols) if supp_cols else 0,
}
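# Quick check (host/port are deployment-specific; 7860 shown as an assumption):
#   curl -s http://localhost:7860/api/model/info | python -m json.tool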
@app.get("/api/generalization")
async def get_generalization():
"""
Return cross-taxon generalization results from eval_generalization.py output.
Serves generalization_results.json (from artifacts/generalization/ or BASE_DIR) if present, otherwise returns empty.
"""
candidates = [
os.path.join(BASE_DIR, "artifacts", "generalization", "generalization_results.json"),
os.path.join(BASE_DIR, "generalization_results.json"),
]
for path in candidates:
if os.path.exists(path):
with open(path) as f:
data = json.load(f)
return {
"available": True,
"taxa": list(data.keys()),
"results": data,
}
return {"available": False, "taxa": [], "results": {}}
@app.get("/api/structure")
async def get_structure(uniprot_id: str):
"""Look up AlphaFold structure data for a UniProt accession."""
import urllib.request, urllib.error
uid = uniprot_id.upper().strip()
if not re.match(r'^[A-Z0-9]{4,10}$', uid):
return {"found": False, "error": "Invalid accession format"}
try:
req = urllib.request.Request(
f"https://alphafold.ebi.ac.uk/api/prediction/{uid}",
headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"}
)
with urllib.request.urlopen(req, timeout=8) as resp:
entries = json.loads(resp.read())
d = entries[0]
return {
"found": True,
"accession": uid,
"organism": d.get("organismScientificName", ""),
"gene": d.get("gene", ""),
"cif_url": d.get("cifUrl", ""),
"pae_image_url": d.get("paeImageUrl", ""),
"entry_url": f"https://alphafold.ebi.ac.uk/entry/{uid}",
"uniprot_url": f"https://www.uniprot.org/uniprot/{uid}",
"model_version": d.get("latestVersion", 4),
}
except urllib.error.HTTPError as e:
if e.code == 404:
return {"found": False, "accession": uid,
"uniprot_url": f"https://www.uniprot.org/uniprot/{uid}"}
return {"found": False, "error": f"HTTP {e.code}"}
except Exception as e:
return {"found": False, "error": str(e)[:100]}
@app.get("/api/uniprot/annotations")
async def get_uniprot_annotations(uniprot_id: str):
"""
Fetch GO-MF annotations and organism info from UniProt REST API.
Returns known annotations (evidence-coded) for comparison with predictions.
"""
import urllib.request, urllib.error
uid = uniprot_id.upper().strip()
if not re.match(r'^[A-Z0-9]{4,10}$', uid):
return {"found": False, "error": "Invalid accession format"}
if uid in _uniprot_cache:
return _uniprot_cache[uid]
try:
url = f"https://rest.uniprot.org/uniprotkb/{uid}.json?fields=go,organism,protein_name,gene_names,organism_lineage"
req = urllib.request.Request(url, headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"})
with urllib.request.urlopen(req, timeout=10) as resp:
data = json.loads(resp.read())
except urllib.error.HTTPError as e:
if e.code == 404:
return {"found": False, "accession": uid, "error": "Not found in UniProt"}
return {"found": False, "error": f"HTTP {e.code}"}
except Exception as e:
return {"found": False, "error": str(e)[:100]}
# Parse organism lineage for taxon detection
organism = data.get("organism", {})
org_name = organism.get("scientificName", "")
lineage = [x.get("scientificName", "") for x in organism.get("lineage", [])]
if any(t in lineage for t in INSECT_LINEAGE):
detected_taxon = "insect"
elif any(t in lineage for t in MAMMAL_LINEAGE):
detected_taxon = "mammal"
else:
detected_taxon = "auto"
# Parse protein name
pn_block = data.get("proteinDescription", {})
prot_name = ""
if "recommendedName" in pn_block:
prot_name = pn_block["recommendedName"].get("fullName", {}).get("value", "")
elif "submissionNames" in pn_block:
prot_name = pn_block["submissionNames"][0].get("fullName", {}).get("value", "")
gene_names = []
for gn in data.get("genes", []):
if "geneName" in gn:
gene_names.append(gn["geneName"]["value"])
# Parse GO-MF annotations
go_mf = []
for xref in data.get("uniProtKBCrossReferences", []):
if xref.get("database") != "GO":
continue
go_id = xref.get("id", "")
props = {p["key"]: p["value"] for p in xref.get("properties", [])}
term_str = props.get("GoTerm", "")
if not term_str.startswith("F:"):
continue # only molecular function
evidence = props.get("GoEvidenceType", "")
ev_code = evidence.split(":")[0] if ":" in evidence else evidence
go_name = term_str[2:].strip()
# Resolve canonical name from our go_map if available
display_name = go_map.get(go_id, go_name)
# Classify evidence tier
exp_codes = {"EXP","IDA","IPI","IMP","IGI","IEP","HTP","HDA","HMP","HGI","HEP"}
comp_codes = {"ISS","ISO","ISA","ISM","IGC","IBA","IBD","IKR","IRD","RCA"}
if ev_code in exp_codes:
ev_tier = "experimental"
elif ev_code in comp_codes:
ev_tier = "computational"
elif ev_code == "IEA":
ev_tier = "electronic"
else:
ev_tier = "other"
go_mf.append({
"go_id": go_id,
"name": display_name,
"evidence": ev_code,
"ev_tier": ev_tier,
})
# Deduplicate by go_id, keeping best evidence tier
tier_rank = {"experimental": 0, "computational": 1, "other": 2, "electronic": 3}
seen = {}
for entry in go_mf:
gid = entry["go_id"]
if gid not in seen or tier_rank[entry["ev_tier"]] < tier_rank[seen[gid]["ev_tier"]]:
seen[gid] = entry
go_mf = sorted(seen.values(), key=lambda x: (tier_rank[x["ev_tier"]], x["name"]))
result = {
"found": True,
"accession": uid,
"protein_name": prot_name,
"gene_names": gene_names,
"organism": org_name,
"lineage": lineage[-5:] if lineage else [],
"detected_taxon": detected_taxon,
"go_mf": go_mf,
"n_experimental": sum(1 for e in go_mf if e["ev_tier"] == "experimental"),
"n_total": len(go_mf),
}
_uniprot_cache[uid] = result
return result
@app.get("/api/health")
async def health_check():
"""Health check endpoint for monitoring."""
return {
"status": "healthy",
"model_loaded": model is not None,
"esm_loaded": esm_model is not None,
"labels": NUM_LABELS,
}
class ProteinRequest(BaseModel):
sequence: str
taxon: str = "auto" # "auto" | "insect" | "mammal"
uniprot_id: str = "" # optional accession for structure lookup
class SaliencyRequest(BaseModel):
sequence: str
uniprot_id: str = ""
taxon: str = "auto"
top_k: int = 20 # top-k predicted labels to use as saliency objective
class ExplainRequest(BaseModel):
sequence: str
uniprot_id: str = ""
taxon: str = "auto"
top_k: int = 10
class BatchPredictRequest(BaseModel):
sequences: list # List of {name: str, sequence: str} objects
threshold: float = 0.5
include_suppressed: bool = True
def parse_sequences(text):
text = text.strip()
if text.startswith(">"):
blocks = re.split(r"(>.*)", text)
names, seqs = [], []
i = 1
while i < len(blocks):
name = blocks[i][1:].strip()
seq = re.sub(r"\s+", "", blocks[i + 1]) if i + 1 < len(blocks) else ""
if seq:
names.append(name)
seqs.append(seq)
i += 2
return list(zip(names, seqs))
seqs = [line.strip() for line in text.splitlines() if line.strip()]
return [(f"Sequence {i + 1}", s) for i, s in enumerate(seqs)]
def _build_model_input(emb: torch.Tensor, sequence: str) -> torch.Tensor:
"""
Build the model input tensor from an ESM-2 embedding.
For supplemented models: appends z-scored seq/structural features.
For base models: returns the embedding as-is.
Handles two supplemented feature sets:
- v3 models (in_dim=360): 39 sequence-derived features + 1 missingness flag
- supp_res2 (in_dim=709): 388 features including f_Dim_* (ESM dims), AF
structural features (zeros at inference), and sequence features
"""
import numpy as np
if not model_uses_supp or supp_mu is None:
return emb.unsqueeze(0)
feats = compute_seq_features(sequence)
# For models that include f_Dim_* features (e.g. supp_res2), populate them
# from the ESM embedding rather than defaulting to 0.
emb_np = emb.detach().cpu().numpy()
for c in supp_cols:
if c.startswith('f_Dim_'):
try:
dim_idx = int(c.split('_')[-1])
if dim_idx < len(emb_np):
feats[c] = float(emb_np[dim_idx])
except (ValueError, IndexError):
pass
s_vec = np.array([feats.get(c, 0.0) for c in supp_cols], dtype=np.float32)
s_z = (s_vec - supp_mu) / (supp_sd + 1e-12)
# missingness flag: 1 = feature data available (seq features always computable)
flag = np.array([1.0], dtype=np.float32)
extra = torch.from_numpy(np.concatenate([s_z, flag]))
full_input = torch.cat([emb, extra], dim=0).unsqueeze(0)
# Verify the built input matches what the model actually expects.
# Some checkpoints (e.g. mammal_enriched) store all supp_cols for reference but
# were trained with only the first N features (via [:in_dim] truncation).
# Truncate to match rather than falling back to bare ESM embedding.
if model is not None and hasattr(model, 'fc_in'):
expected_dim = model.fc_in.weight.shape[1]
if full_input.shape[1] != expected_dim:
if full_input.shape[1] > expected_dim:
full_input = full_input[:, :expected_dim]
else:
return emb.unsqueeze(0) # can't expand — fall back to bare ESM
return full_input
# Bounded in-memory cache for ESM embeddings (FIFO eviction): avoids recomputing repeated/identical sequences
_ESM_CACHE: dict = {}
_ESM_CACHE_MAX = 256
def _get_esm_embedding(sequence: str) -> torch.Tensor:
"""Return mean-pooled ESM-2 embedding, using in-memory cache."""
key = hashlib.md5(sequence.encode()).hexdigest()
if key in _ESM_CACHE:
return _ESM_CACHE[key]
_, _, tokens = batch_converter([("p", sequence)])
# Use the final layer of the loaded ESM model: 12 for the 35M (480-dim), 6 for the 8M (320-dim)
repr_layer = 12 if _ESM_DIM == 480 else 6
with torch.no_grad():
rep = esm_model(tokens.to(device), repr_layers=[repr_layer])["representations"][repr_layer]
emb = rep[0, 1:len(sequence) + 1].mean(0).cpu()
if len(_ESM_CACHE) >= _ESM_CACHE_MAX:
_ESM_CACHE.pop(next(iter(_ESM_CACHE))) # evict oldest (FIFO)
_ESM_CACHE[key] = emb
return emb
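# Usage sketch (requires the models loaded in lifespan):
#   emb = _get_esm_embedding("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
#   emb.shape -> torch.Size([320]) for the 8M model, torch.Size([480]) for the 35M model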
@app.post("/predict")
async def predict(request: ProteinRequest):
import numpy as np
import urllib.request, urllib.error
entries = parse_sequences(request.sequence)
results = []
device_cpu = torch.device("cpu")
mf_idx = app.state.mf_indices
uid = (request.uniprot_id or "").upper().strip()
req_taxon = (request.taxon or "auto").lower()
# Fetch UniProt annotations once for the whole batch (blocking HTTP call, only if an accession was given)
uniprot_data = None
detected_taxon_from_uniprot = None
if uid and re.match(r'^[A-Z0-9]{4,10}$', uid):
try:
url = f"https://rest.uniprot.org/uniprotkb/{uid}.json?fields=go,organism,protein_name,gene_names,organism_lineage"
req_obj = urllib.request.Request(url, headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"})
with urllib.request.urlopen(req_obj, timeout=8) as resp:
raw = json.loads(resp.read())
organism = raw.get("organism", {})
lineage = [x.get("scientificName", "") for x in organism.get("lineage", [])]
if any(t in lineage for t in INSECT_LINEAGE):
detected_taxon_from_uniprot = "insect"
elif any(t in lineage for t in MAMMAL_LINEAGE):
detected_taxon_from_uniprot = "mammal"
# Parse GO-MF annotations
go_mf_known = {}
tier_rank = {"experimental": 0, "computational": 1, "other": 2, "electronic": 3}
exp_ev = {"EXP","IDA","IPI","IMP","IGI","IEP","HTP","HDA","HMP","HGI","HEP"}
comp_ev = {"ISS","ISO","ISA","ISM","IGC","IBA","IBD","IKR","IRD","RCA"}
for xref in raw.get("uniProtKBCrossReferences", []):
if xref.get("database") != "GO":
continue
props = {p["key"]: p["value"] for p in xref.get("properties", [])}
if not props.get("GoTerm", "").startswith("F:"):
continue
go_id = xref.get("id", "")
ev_code = props.get("GoEvidenceType", "").split(":")[0]
ev_tier = "experimental" if ev_code in exp_ev else (
"computational" if ev_code in comp_ev else (
"electronic" if ev_code == "IEA" else "other"))
entry = {"go_id": go_id, "name": go_map.get(go_id, props["GoTerm"][2:]),
"evidence": ev_code, "ev_tier": ev_tier}
if go_id not in go_mf_known or tier_rank[ev_tier] < tier_rank[go_mf_known[go_id]["ev_tier"]]:
go_mf_known[go_id] = entry
uniprot_data = {
"go_mf_known": sorted(go_mf_known.values(), key=lambda x: (tier_rank[x["ev_tier"]], x["name"])),
"organism": organism.get("scientificName", ""),
"detected_taxon": detected_taxon_from_uniprot,
}
except Exception:
pass
for name, sequence in entries:
err = validate_sequence(name, sequence)
if err:
results.append({"name": name, "error": err})
continue
if len(sequence) > 1500:
results.append({"name": name, "error": "Sequence too long (max 1500 aa)"})
continue
try:
emb = _get_esm_embedding(sequence).to(device_cpu)
emb_np = emb.detach().cpu().numpy()
# ── Taxon auto-detection ────────────────────────────────────────
taxon_source = req_taxon
taxon_conf = 1.0
if req_taxon == "auto":
if detected_taxon_from_uniprot:
taxon_source = detected_taxon_from_uniprot
taxon_conf = 1.0
elif taxon_probe is not None:
taxon_source, taxon_conf = _detect_taxon_probe(emb_np)
else:
taxon_source, taxon_conf = _detect_taxon_composition(sequence)
if taxon_source == "insect":
t_apply = insect_temperature
thresh_lookup = {}
t_default = insect_threshold_default
mammal_floor = 0.0
else:
t_apply = temperature
thresh_lookup = thresholds
t_default = 0.68
mammal_floor = 0.56
# ── Forward pass ────────────────────────────────────────────────
with torch.no_grad():
inp = _build_model_input(emb, sequence)
logits = model(inp).squeeze()
# Apply Platt scaling per-label (if available, mammal only)
if platt_params and taxon_source != "insect":
probs_list = []
for i in range(len(logits)):
l = float(logits[i])
if str(i) in platt_params:
A, B = platt_params[str(i)]
p = 1.0 / (1.0 + math.exp(-(A * l + B)))
else:
p = 1.0 / (1.0 + math.exp(-l / t_apply))
probs_list.append(p)
prob = torch.tensor(probs_list)
else:
prob = torch.sigmoid(logits / t_apply)
if prob.dim() == 0:
prob = prob.unsqueeze(0)
# ── Threshold + collect ─────────────────────────────────────────
raw_preds = []
prob_map = {}
for i in mf_idx:
pv = float(prob[i])
label_thresh = max(float(thresh_lookup.get(str(i), t_default)), mammal_floor)
if pv >= label_thresh:
go_id = mlb.classes_[i]
display_id = go_replaced.get(go_id, go_id)
display_nm = go_map.get(display_id, go_map.get(go_id, go_id))
entry = {"go_id": display_id, "name": display_nm,
"prob": round(pv, 4), "depth": go_depth.get(display_id, -1)}
if display_id != go_id:
entry["original_id"] = go_id
raw_preds.append(entry)
prob_map[display_id] = pv
raw_preds.sort(key=lambda x: x["prob"], reverse=True)
cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(sequence))
raw_preds = raw_preds[:cap]
for rp in raw_preds:
prob_map[rp["go_id"]] = rp["prob"]
visible, suppressed = propagate_and_filter(raw_preds, go_parents, go_ancestors, prob_map)
result = {
"name": name,
"sequence_length": len(sequence),
"predictions": visible,
"suppressed": suppressed,
"n_above_threshold": len(raw_preds),
"n_implied_parents": sum(1 for p in visible if p.get("implied")),
"taxon_applied": taxon_source,
"taxon_source": "uniprot" if detected_taxon_from_uniprot and req_taxon == "auto"
else ("probe" if taxon_probe and req_taxon == "auto"
else ("composition" if req_taxon == "auto" else "manual")),
"taxon_confidence": round(taxon_conf, 3),
"temperature_applied": round(t_apply, 4),
"platt_applied": bool(platt_params) and taxon_source != "insect",
}
if uniprot_data:
result["uniprot"] = uniprot_data
results.append(result)
except Exception as e:
results.append({"name": name, "error": str(e)})
return {"results": results}
@app.post("/api/predict/batch")
async def predict_batch(request: BatchPredictRequest):
"""
Batch prediction endpoint for multiple sequences.
Accepts a list of sequence objects and returns predictions for all.
Sequences are processed sequentially server-side; the gain over repeated
single /predict calls is avoiding per-request HTTP and parsing overhead.
"""
results = []
mf_idx = app.state.mf_indices
custom_threshold = request.threshold
for item in request.sequences:
name = item.get("name", "Unknown")
sequence = item.get("sequence", "")
# Validate sequence
err = validate_sequence(name, sequence)
if err:
results.append({"name": name, "error": err})
continue
if len(sequence) > 1500:
results.append({"name": name, "error": "Sequence too long (max 1500 aa)"})
continue
try:
emb = _get_esm_embedding(sequence).to(device)
with torch.no_grad():
inp = _build_model_input(emb, sequence)
prob = torch.sigmoid(model(inp) / temperature).squeeze()
if prob.dim() == 0:
prob = prob.unsqueeze(0)
raw_preds = []
prob_map = {}
for i in mf_idx:
pv = float(prob[i])
thresh = float(thresholds.get(str(i), custom_threshold))
if pv >= thresh:
go_id = mlb.classes_[i]
display_id = go_replaced.get(go_id, go_id)
display_nm = go_map.get(display_id, go_map.get(go_id, go_id))
entry = {
"go_id": display_id,
"name": display_nm,
"prob": round(pv, 4),
"depth": go_depth.get(display_id, -1),
}
if display_id != go_id:
entry["original_id"] = go_id
raw_preds.append(entry)
prob_map[display_id] = pv
raw_preds.sort(key=lambda x: x["prob"], reverse=True)
cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(sequence))
raw_preds = raw_preds[:cap]
for rp in raw_preds:
prob_map[rp["go_id"]] = rp["prob"]
visible, suppressed = propagate_and_filter(
raw_preds, go_parents, go_ancestors, prob_map
)
if not request.include_suppressed:
suppressed = []
results.append({
"name": name,
"sequence_length": len(sequence),
"predictions": visible,
"suppressed": suppressed,
"n_above_threshold": len(raw_preds),
"n_implied_parents": sum(1 for p in visible if p.get("implied")),
})
except Exception as e:
results.append({"name": name, "error": str(e)})
return {
"results": results,
"total": len(results),
"successful": sum(1 for r in results if "error" not in r),
}
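# Example request (illustrative; "MKT..." stands in for a real amino-acid
# sequence that passes validate_sequence):
#
#   curl -X POST http://localhost:7860/api/predict/batch \
#     -H "Content-Type: application/json" \
#     -d '{"sequences": [{"name": "demo", "sequence": "MKT..."}],
#          "threshold": 0.68, "include_suppressed": false}'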
class GoTermsRequest(BaseModel):
sequence: str
uniprot_id: str = ""
taxon: str = "auto"
top_k: int = 20
include_implied: bool = False
min_prob: float = 0.0
_go_terms_cache: dict = {}  # in-process cache keyed by (md5(sequence), uniprot_id, taxon); unbounded
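# The cache above grows without bound for the life of the process. A bounded
# variant (illustrative sketch only, not wired in) could evict the oldest entry
# once a cap is reached, relying on dict insertion order:
#
#     MAX_CACHE = 10_000  # hypothetical cap
#     if len(_go_terms_cache) >= MAX_CACHE:
#         _go_terms_cache.pop(next(iter(_go_terms_cache)))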
@app.post("/api/go_terms")
async def get_go_terms(request: GoTermsRequest):
"""
Lightweight GO-MF prediction endpoint for programmatic/pipeline use.
    Returns only the predicted term list: no suppressed terms, no taxon
    explanation, no Platt metadata. ~2x faster than /predict for downstream tools.
Results are cached by (sequence_hash, uniprot_id, taxon).
"""
seq = request.sequence.strip()
cache_key = (hashlib.md5(seq.encode()).hexdigest(), request.uniprot_id.upper(), request.taxon)
if cache_key in _go_terms_cache:
return _go_terms_cache[cache_key]
err = validate_sequence("seq", seq)
if err:
return {"error": err, "predictions": []}
if len(seq) > 1500:
return {"error": "Sequence too long (max 1500 aa)", "predictions": []}
try:
mf_idx = app.state.mf_indices
emb = _get_esm_embedding(seq).to(device)
emb_np = emb.detach().cpu().numpy()
taxon_source = request.taxon
taxon_conf = 1.0
if taxon_source == "auto":
if taxon_probe is not None:
taxon_source, taxon_conf = _detect_taxon_probe(emb_np)
else:
taxon_source, taxon_conf = _detect_taxon_composition(seq)
if taxon_source == "insect":
t_apply = insect_temperature
thresh_lookup = {}
t_default = insect_threshold_default
mammal_floor = 0.0
else:
t_apply = temperature
thresh_lookup = thresholds
t_default = 0.68
mammal_floor = 0.56
with torch.no_grad():
inp = _build_model_input(emb, seq)
logits = model(inp).squeeze()
if platt_params and taxon_source != "insect":
probs_list = []
for i in range(len(logits)):
l = float(logits[i])
if str(i) in platt_params:
A, B = platt_params[str(i)]
p = 1.0 / (1.0 + math.exp(-(A * l + B)))
else:
p = 1.0 / (1.0 + math.exp(-l / t_apply))
probs_list.append(p)
prob = torch.tensor(probs_list)
else:
prob = torch.sigmoid(logits / t_apply)
if prob.dim() == 0:
prob = prob.unsqueeze(0)
raw_preds = []
for i in mf_idx:
pv = float(prob[i])
label_thresh = max(float(thresh_lookup.get(str(i), t_default)), mammal_floor)
if pv >= label_thresh:
go_id = mlb.classes_[i]
display_id = go_replaced.get(go_id, go_id)
display_nm = go_map.get(display_id, go_map.get(go_id, go_id))
raw_preds.append({
"go_id": display_id,
"name": display_nm,
"prob": round(pv, 4),
"depth": go_depth.get(display_id, -1),
})
raw_preds.sort(key=lambda x: x["prob"], reverse=True)
cap = compute_dynamic_cap([p["prob"] for p in raw_preds], len(seq))
raw_preds = raw_preds[:cap]
        if not request.include_implied:
            predictions = raw_preds[:request.top_k]
        else:
            prob_map = {p["go_id"]: p["prob"] for p in raw_preds}
            visible, _ = propagate_and_filter(raw_preds, go_parents, go_ancestors, prob_map)
            # include_implied=True: keep implied ancestor terms in the output
            predictions = visible[:request.top_k]
if request.min_prob > 0:
predictions = [p for p in predictions if p["prob"] >= request.min_prob]
result = {
"predictions": predictions,
"n_predicted": len(predictions),
"taxon_applied": taxon_source,
"taxon_conf": round(taxon_conf, 3),
"taxon_source_method": (
"probe" if taxon_probe is not None and request.taxon == "auto"
else ("composition" if request.taxon == "auto" else "manual")
),
"platt_applied": bool(platt_params) and taxon_source != "insect",
"threshold_default": t_default,
}
_go_terms_cache[cache_key] = result
return result
except Exception as e:
return {"error": str(e)[:300], "predictions": []}
@app.get("/api/explain_terms")
async def explain_terms(ids: str = ""):
"""
Return name + definition for a comma-separated list of GO IDs.
Used by the frontend "Why?" panel to show term descriptions inline.
"""
if not ids:
return {"terms": []}
result = []
for gid in ids.split(",")[:50]:
gid = gid.strip()
if not gid:
continue
name = go_map.get(gid, gid)
defn = go_defs.get(gid, "")
result.append({"id": gid, "name": name, "definition": defn})
return {"terms": result}
@app.post("/api/saliency")
async def compute_saliency(request: SaliencyRequest):
"""
Compute per-residue gradient saliency for a protein sequence.
Uses d(sum_of_top_k_probs)/d(ESM_residue_representations) via backprop.
Optionally fetches AlphaFold structure if uniprot_id is provided.
Returns normalized per-residue importance scores in [0, 1].
"""
    import urllib.request  # lazy import; only needed for the optional AlphaFold fetch
sequence = re.sub(r"\s+", "", request.sequence.upper())
if not sequence:
return {"error": "Empty sequence"}
if len(sequence) > 1200:
return {"error": "Sequence too long for saliency (max 1200 aa)"}
err = validate_sequence("query", sequence)
if err:
return {"error": err}
taxon = (request.taxon or "auto").lower()
t_apply = insect_temperature if taxon == "insect" else temperature
try:
_, _, tokens = batch_converter([("p", sequence)])
tokens = tokens.to(device)
L = len(sequence)
with torch.enable_grad():
            # Run ESM keeping the computation graph; retain grad on residue reps.
            # NOTE: repr layer 6 is the final layer of the 320-dim 8M ESM2 backbone;
            # if the 480-dim 35M backbone (12 layers) is active, this taps a
            # mid-network layer instead.
            out = esm_model(tokens, repr_layers=[6])
            residue_reps = out["representations"][6]  # (1, L+2, _ESM_DIM)
            residue_reps.retain_grad()
            # Mean-pool over residues (stays in the graph)
            emb = residue_reps[0, 1:L + 1].mean(0)  # (_ESM_DIM,)
# Build MLP input in-graph (gradient-safe version of _build_model_input)
if model_uses_supp and supp_mu is not None:
feats = compute_seq_features(sequence)
s_vec = torch.tensor(
[(feats.get(c, 0.0) - float(supp_mu[j])) / (float(supp_sd[j]) + 1e-12)
for j, c in enumerate(supp_cols)],
dtype=torch.float32, device=device
)
# flag tensor (no grad needed)
flag = torch.ones(1, dtype=torch.float32, device=device)
inp_full = torch.cat([emb, s_vec, flag]).unsqueeze(0)
expected = model.fc_in.weight.shape[1]
inp = inp_full[:, :expected]
else:
inp = emb.unsqueeze(0)
logits = model(inp) / t_apply
            probs = torch.sigmoid(logits[0])  # (NUM_LABELS,)
            # Objective: sum of the top-k predicted probabilities
            # (k clamped to [5, n_labels] rather than a hardcoded label count)
            k = min(request.top_k, int((probs > 0.3).sum().item()), probs.shape[0])
            k = max(k, 5)
top_vals = probs.topk(k).values
objective = top_vals.sum()
objective.backward()
scores = [0.0] * L
if residue_reps.grad is not None:
grad = residue_reps.grad[0, 1:L + 1].norm(dim=-1).detach().cpu().numpy()
mn, mx = grad.min(), grad.max()
scores = ((grad - mn) / (mx - mn + 1e-8)).tolist()
# Fetch AlphaFold structure if accession given
structure = None
uid = request.uniprot_id.upper().strip()
if uid and re.match(r'^[A-Z0-9]{4,10}$', uid):
try:
req = urllib.request.Request(
f"https://alphafold.ebi.ac.uk/api/prediction/{uid}",
headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"}
)
with urllib.request.urlopen(req, timeout=8) as resp:
d = json.loads(resp.read())[0]
structure = {
"found": True,
"accession": uid,
"cif_url": d.get("cifUrl", ""),
"pdb_url": d.get("pdbUrl", ""),
"organism": d.get("organismScientificName", ""),
"gene": d.get("gene", ""),
"entry_url": f"https://alphafold.ebi.ac.uk/entry/{uid}",
"uniprot_url": f"https://www.uniprot.org/uniprot/{uid}",
}
except Exception:
structure = {"found": False, "accession": uid}
return {
"sequence_length": L,
"per_residue_scores": scores,
"taxon": taxon,
"structure": structure,
}
except Exception as e:
return {"error": str(e)[:200]}
@app.post("/api/explainability")
async def compute_explainability(request: ExplainRequest):
"""
Compute feature-level importance for the 11 sequence features via gradient × input.
Also returns per-residue feature maps for 3D structure coloring.
Gradient flows only through the first 11 supp features; ESM embedding is detached.
"""
import numpy as np
    import urllib.request  # urllib.error unused; failures are caught broadly below
sequence = re.sub(r"\s+", "", request.sequence.upper())
if not sequence:
return {"error": "Empty sequence"}
if len(sequence) > 1200:
return {"error": "Sequence too long for explainability (max 1200 aa)"}
err = validate_sequence("query", sequence)
if err:
return {"error": err}
if not model_uses_supp or supp_mu is None:
return {"error": "Feature importance requires a supplemented model (unified_v1 or later)"}
taxon = (request.taxon or "auto").lower()
t_apply = insect_temperature if taxon == "insect" else temperature
try:
emb = _get_esm_embedding(sequence).to(device) # (320,) detached
feats = compute_seq_features(sequence)
s_vec = np.array([feats.get(c, 0.0) for c in supp_cols], dtype=np.float32)
s_z = (s_vec - supp_mu) / (supp_sd + 1e-12)
n_tracked = min(11, len(supp_cols))
s_z_11 = torch.tensor(s_z[:n_tracked], requires_grad=True,
dtype=torch.float32, device=device)
s_z_rest = torch.tensor(s_z[n_tracked:], dtype=torch.float32, device=device)
flag = torch.ones(1, dtype=torch.float32, device=device)
inp_full = torch.cat([emb.detach(), s_z_11, s_z_rest, flag]).unsqueeze(0)
expected = model.fc_in.weight.shape[1]
inp = inp_full[:, :expected]
with torch.enable_grad():
logits = model(inp) / t_apply
probs = torch.sigmoid(logits[0])
            k = min(request.top_k, int((probs > 0.3).sum().item()), probs.shape[0])
            k = max(k, 5)
probs.topk(k).values.sum().backward()
if s_z_11.grad is None:
return {"error": "Gradient computation failed — no gradient on supp features"}
grad_np = s_z_11.grad.detach().cpu().numpy()
s_z_11np = s_z_11.detach().cpu().numpy()
attribution = grad_np * s_z_11np # signed grad × input
per_residue_maps = compute_per_residue_features(sequence)
feat_meta_by_key = {fm["key"]: fm for fm in FEATURE_META}
features = []
for i in range(n_tracked):
col = supp_cols[i]
meta = feat_meta_by_key.get(col, {"key": col, "label": col, "desc": col, "color": "#888888"})
attr = float(attribution[i])
features.append({
"key": col,
"label": meta["label"],
"desc": meta["desc"],
"color": meta["color"],
"importance": round(attr, 4),
"abs_importance": round(abs(attr), 4),
"per_residue": per_residue_maps.get(col, [0.5] * len(sequence)),
})
features.sort(key=lambda x: x["abs_importance"], reverse=True)
# Fetch AlphaFold structure
structure = None
uid = request.uniprot_id.upper().strip()
if uid and re.match(r'^[A-Z0-9]{4,10}$', uid):
try:
req = urllib.request.Request(
f"https://alphafold.ebi.ac.uk/api/prediction/{uid}",
headers={"Accept": "application/json", "User-Agent": "ProtFunc/1.0"}
)
with urllib.request.urlopen(req, timeout=8) as resp:
d = json.loads(resp.read())[0]
structure = {
"found": True,
"accession": uid,
"cif_url": d.get("cifUrl", ""),
"organism": d.get("organismScientificName", ""),
"gene": d.get("gene", ""),
"entry_url": f"https://alphafold.ebi.ac.uk/entry/{uid}",
"uniprot_url": f"https://www.uniprot.org/uniprot/{uid}",
}
except Exception:
structure = {"found": False, "accession": uid}
return {
"sequence_length": len(sequence),
"features": features,
"top_feature": features[0]["key"] if features else None,
"structure": structure,
"taxon": taxon,
}
except Exception as e:
return {"error": str(e)[:300]}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)