Update server.py: v3-fixed model priority, generalization API + UI
Browse files
server.py
CHANGED
|
@@ -22,14 +22,16 @@ STATIC_DIR = os.path.join(BASE_DIR, "static")
|
|
| 22 |
os.makedirs(STATIC_DIR, exist_ok=True)
|
| 23 |
|
| 24 |
HF_REPO = "Sbhat2026/protfunc-models"
|
| 25 |
-
# Priority order: v3 supplemented > improved base > original baseline
|
| 26 |
HF_FILES = [
|
|
|
|
| 27 |
"protfunc_v3.pth", "protfunc_v3_thresholds.json",
|
| 28 |
"improved_res.pth", "improved_per_label_thresholds.json",
|
| 29 |
"baseline_res.pth", "mlb_public_v1.pkl", "go_annotations_fixed.csv", "go_names.json",
|
| 30 |
]
|
| 31 |
OPTIONAL = {
|
| 32 |
"go_names.json",
|
|
|
|
| 33 |
"protfunc_v3.pth", "protfunc_v3_thresholds.json",
|
| 34 |
"improved_res.pth", "improved_per_label_thresholds.json",
|
| 35 |
}
|
|
@@ -215,6 +217,7 @@ def load_go_map():
|
|
| 215 |
|
| 216 |
def load_thresholds():
|
| 217 |
for path in [
|
|
|
|
| 218 |
os.path.join(BASE_DIR, "improved_per_label_thresholds.json"),
|
| 219 |
os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"),
|
| 220 |
os.path.join(BASE_DIR, "per_label_thresholds.json"),
|
|
@@ -518,10 +521,9 @@ async def lifespan(app: FastAPI):
|
|
| 518 |
import numpy as np
|
| 519 |
device = torch.device("cpu")
|
| 520 |
|
| 521 |
-
# Prefer checkpoints in priority order: improved > v3 > baseline
|
| 522 |
-
# (v3 was trained with propagated training labels which inflates predictions;
|
| 523 |
-
# improved_res.pth has verified Fmax=0.8846 on specific GO annotations)
|
| 524 |
ckpt_candidates = [
|
|
|
|
| 525 |
os.path.join(BASE_DIR, "improved_res.pth"),
|
| 526 |
os.path.join(BASE_DIR, "protfunc_v3.pth"),
|
| 527 |
os.path.join(BASE_DIR, "baseline_res.pth"),
|
|
@@ -591,10 +593,12 @@ async def root():
|
|
| 591 |
@app.get("/api/model/info")
|
| 592 |
async def model_info():
|
| 593 |
"""Return model metadata and configuration."""
|
|
|
|
| 594 |
improved = os.path.exists(os.path.join(BASE_DIR, "improved_res.pth"))
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
|
|
|
| 598 |
name, version, active = "ProtFunc v3 (supplemented + mammal)", "3.0.0", "protfunc_v3"
|
| 599 |
elif improved:
|
| 600 |
name, version, active = "ProtFunc Enhanced", "2.1.0", "improved"
|
|
@@ -618,6 +622,28 @@ async def model_info():
|
|
| 618 |
}
|
| 619 |
|
| 620 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
@app.get("/api/structure")
|
| 622 |
async def get_structure(uniprot_id: str):
|
| 623 |
"""Look up AlphaFold structure data for a UniProt accession."""
|
|
|
|
| 22 |
os.makedirs(STATIC_DIR, exist_ok=True)
|
| 23 |
|
| 24 |
HF_REPO = "Sbhat2026/protfunc-models"
|
| 25 |
+
# Priority order: v3_fixed (ablation best) > v3 supplemented > improved base > original baseline
|
| 26 |
HF_FILES = [
|
| 27 |
+
"protfunc_v3_fixed.pth", "protfunc_v3_fixed_thresholds.json",
|
| 28 |
"protfunc_v3.pth", "protfunc_v3_thresholds.json",
|
| 29 |
"improved_res.pth", "improved_per_label_thresholds.json",
|
| 30 |
"baseline_res.pth", "mlb_public_v1.pkl", "go_annotations_fixed.csv", "go_names.json",
|
| 31 |
]
|
| 32 |
OPTIONAL = {
|
| 33 |
"go_names.json",
|
| 34 |
+
"protfunc_v3_fixed.pth", "protfunc_v3_fixed_thresholds.json",
|
| 35 |
"protfunc_v3.pth", "protfunc_v3_thresholds.json",
|
| 36 |
"improved_res.pth", "improved_per_label_thresholds.json",
|
| 37 |
}
|
|
|
|
| 217 |
|
| 218 |
def load_thresholds():
|
| 219 |
for path in [
|
| 220 |
+
os.path.join(BASE_DIR, "protfunc_v3_fixed_thresholds.json"),
|
| 221 |
os.path.join(BASE_DIR, "improved_per_label_thresholds.json"),
|
| 222 |
os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"),
|
| 223 |
os.path.join(BASE_DIR, "per_label_thresholds.json"),
|
|
|
|
| 521 |
import numpy as np
|
| 522 |
device = torch.device("cpu")
|
| 523 |
|
| 524 |
+
# Prefer checkpoints in priority order: v3_fixed (ablation best, CAFA-correct) > improved > v3 > baseline
|
|
|
|
|
|
|
| 525 |
ckpt_candidates = [
|
| 526 |
+
os.path.join(BASE_DIR, "protfunc_v3_fixed.pth"),
|
| 527 |
os.path.join(BASE_DIR, "improved_res.pth"),
|
| 528 |
os.path.join(BASE_DIR, "protfunc_v3.pth"),
|
| 529 |
os.path.join(BASE_DIR, "baseline_res.pth"),
|
|
|
|
| 593 |
@app.get("/api/model/info")
|
| 594 |
async def model_info():
|
| 595 |
"""Return model metadata and configuration."""
|
| 596 |
+
v3_fixed = os.path.exists(os.path.join(BASE_DIR, "protfunc_v3_fixed.pth"))
|
| 597 |
improved = os.path.exists(os.path.join(BASE_DIR, "improved_res.pth"))
|
| 598 |
+
# model name reflects actual loaded model (v3_fixed takes highest priority)
|
| 599 |
+
if v3_fixed and model_uses_supp:
|
| 600 |
+
name, version, active = "ProtFunc v3-fixed (ablation best, CAFA-correct)", "3.1.0", "protfunc_v3_fixed"
|
| 601 |
+
elif model_uses_supp:
|
| 602 |
name, version, active = "ProtFunc v3 (supplemented + mammal)", "3.0.0", "protfunc_v3"
|
| 603 |
elif improved:
|
| 604 |
name, version, active = "ProtFunc Enhanced", "2.1.0", "improved"
|
|
|
|
| 622 |
}
|
| 623 |
|
| 624 |
|
| 625 |
+
@app.get("/api/generalization")
|
| 626 |
+
async def get_generalization():
|
| 627 |
+
"""
|
| 628 |
+
Return cross-taxon generalization results from eval_generalization.py output.
|
| 629 |
+
Serves artifacts/generalization_results.json if present, otherwise returns empty.
|
| 630 |
+
"""
|
| 631 |
+
candidates = [
|
| 632 |
+
os.path.join(BASE_DIR, "artifacts", "generalization", "generalization_results.json"),
|
| 633 |
+
os.path.join(BASE_DIR, "generalization_results.json"),
|
| 634 |
+
]
|
| 635 |
+
for path in candidates:
|
| 636 |
+
if os.path.exists(path):
|
| 637 |
+
with open(path) as f:
|
| 638 |
+
data = json.load(f)
|
| 639 |
+
return {
|
| 640 |
+
"available": True,
|
| 641 |
+
"taxa": list(data.keys()),
|
| 642 |
+
"results": data,
|
| 643 |
+
}
|
| 644 |
+
return {"available": False, "taxa": [], "results": {}}
|
| 645 |
+
|
| 646 |
+
|
| 647 |
@app.get("/api/structure")
|
| 648 |
async def get_structure(uniprot_id: str):
|
| 649 |
"""Look up AlphaFold structure data for a UniProt accession."""
|