fix: revert to improved model, add top-30 prediction cap
Browse files
server.py
CHANGED
|
@@ -215,8 +215,8 @@ def load_go_map():
|
|
| 215 |
|
| 216 |
def load_thresholds():
|
| 217 |
for path in [
|
| 218 |
-
os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"),
|
| 219 |
os.path.join(BASE_DIR, "improved_per_label_thresholds.json"),
|
|
|
|
| 220 |
os.path.join(BASE_DIR, "per_label_thresholds.json"),
|
| 221 |
os.path.join(BASE_DIR, "artifacts", "per_label_thresholds.json"),
|
| 222 |
]:
|
|
@@ -518,10 +518,12 @@ async def lifespan(app: FastAPI):
|
|
| 518 |
import numpy as np
|
| 519 |
device = torch.device("cpu")
|
| 520 |
|
| 521 |
-
# Prefer checkpoints in priority order:
|
|
|
|
|
|
|
| 522 |
ckpt_candidates = [
|
| 523 |
-
os.path.join(BASE_DIR, "protfunc_v3.pth"),
|
| 524 |
os.path.join(BASE_DIR, "improved_res.pth"),
|
|
|
|
| 525 |
os.path.join(BASE_DIR, "baseline_res.pth"),
|
| 526 |
]
|
| 527 |
ckpt_path = next((p for p in ckpt_candidates if os.path.exists(p)), None)
|
|
@@ -589,16 +591,18 @@ async def root():
|
|
| 589 |
@app.get("/api/model/info")
|
| 590 |
async def model_info():
|
| 591 |
"""Return model metadata and configuration."""
|
| 592 |
-
v3 = os.path.exists(os.path.join(BASE_DIR, "protfunc_v3.pth"))
|
| 593 |
improved = os.path.exists(os.path.join(BASE_DIR, "improved_res.pth"))
|
| 594 |
-
|
| 595 |
-
|
|
|
|
|
|
|
| 596 |
elif improved:
|
| 597 |
-
name, version = "ProtFunc Enhanced", "2.1.0"
|
| 598 |
else:
|
| 599 |
-
name, version = "ProtFunc", "2.0.0"
|
| 600 |
return {
|
| 601 |
"model_name": name,
|
|
|
|
| 602 |
"version": version,
|
| 603 |
"esm_model": "esm2_t6_8M_UR50D",
|
| 604 |
"embed_dim": 320,
|
|
@@ -749,6 +753,12 @@ async def predict(request: ProteinRequest):
|
|
| 749 |
})
|
| 750 |
prob_map[mlb.classes_[i]] = pv
|
| 751 |
raw_preds.sort(key=lambda x: x["prob"], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 752 |
|
| 753 |
# Propagate up GO DAG and filter
|
| 754 |
visible, suppressed = propagate_and_filter(
|
|
@@ -821,6 +831,9 @@ async def predict_batch(request: BatchPredictRequest):
|
|
| 821 |
})
|
| 822 |
prob_map[mlb.classes_[i]] = pv
|
| 823 |
raw_preds.sort(key=lambda x: x["prob"], reverse=True)
|
|
|
|
|
|
|
|
|
|
| 824 |
|
| 825 |
visible, suppressed = propagate_and_filter(
|
| 826 |
raw_preds, go_parents, go_ancestors, prob_map
|
|
|
|
| 215 |
|
| 216 |
def load_thresholds():
|
| 217 |
for path in [
|
|
|
|
| 218 |
os.path.join(BASE_DIR, "improved_per_label_thresholds.json"),
|
| 219 |
+
os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"),
|
| 220 |
os.path.join(BASE_DIR, "per_label_thresholds.json"),
|
| 221 |
os.path.join(BASE_DIR, "artifacts", "per_label_thresholds.json"),
|
| 222 |
]:
|
|
|
|
| 518 |
import numpy as np
|
| 519 |
device = torch.device("cpu")
|
| 520 |
|
| 521 |
+
# Prefer checkpoints in priority order: improved > v3 > baseline
|
| 522 |
+
# (v3 was trained on propagated labels, which inflates its predictions;
|
| 523 |
+
# improved_res.pth has verified Fmax=0.8846 on specific GO annotations)
|
| 524 |
ckpt_candidates = [
|
|
|
|
| 525 |
os.path.join(BASE_DIR, "improved_res.pth"),
|
| 526 |
+
os.path.join(BASE_DIR, "protfunc_v3.pth"),
|
| 527 |
os.path.join(BASE_DIR, "baseline_res.pth"),
|
| 528 |
]
|
| 529 |
ckpt_path = next((p for p in ckpt_candidates if os.path.exists(p)), None)
|
|
|
|
| 591 |
@app.get("/api/model/info")
|
| 592 |
async def model_info():
|
| 593 |
"""Return model metadata and configuration."""
|
|
|
|
| 594 |
improved = os.path.exists(os.path.join(BASE_DIR, "improved_res.pth"))
|
| 595 |
+
v3 = os.path.exists(os.path.join(BASE_DIR, "protfunc_v3.pth"))
|
| 596 |
+
# Model name reflects the model actually loaded (improved_res.pth takes priority over protfunc_v3.pth at load time)
|
| 597 |
+
if model_uses_supp:
|
| 598 |
+
name, version, active = "ProtFunc v3 (supplemented + mammal)", "3.0.0", "protfunc_v3"
|
| 599 |
elif improved:
|
| 600 |
+
name, version, active = "ProtFunc Enhanced", "2.1.0", "improved"
|
| 601 |
else:
|
| 602 |
+
name, version, active = "ProtFunc", "2.0.0", "baseline"
|
| 603 |
return {
|
| 604 |
"model_name": name,
|
| 605 |
+
"model": active,
|
| 606 |
"version": version,
|
| 607 |
"esm_model": "esm2_t6_8M_UR50D",
|
| 608 |
"embed_dim": 320,
|
|
|
|
| 753 |
})
|
| 754 |
prob_map[mlb.classes_[i]] = pv
|
| 755 |
raw_preds.sort(key=lambda x: x["prob"], reverse=True)
|
| 756 |
+
# Cap to top 30 most confident direct predictions before propagation.
|
| 757 |
+
# Without this, models trained on propagated labels return hundreds of
|
| 758 |
+
# broad ancestor terms that overwhelm the signal.
|
| 759 |
+
raw_preds = raw_preds[:30]
|
| 760 |
+
for rp in raw_preds:
|
| 761 |
+
prob_map[rp["go_id"]] = rp["prob"]
|
| 762 |
|
| 763 |
# Propagate up GO DAG and filter
|
| 764 |
visible, suppressed = propagate_and_filter(
|
|
|
|
| 831 |
})
|
| 832 |
prob_map[mlb.classes_[i]] = pv
|
| 833 |
raw_preds.sort(key=lambda x: x["prob"], reverse=True)
|
| 834 |
+
raw_preds = raw_preds[:30]
|
| 835 |
+
for rp in raw_preds:
|
| 836 |
+
prob_map[rp["go_id"]] = rp["prob"]
|
| 837 |
|
| 838 |
visible, suppressed = propagate_and_filter(
|
| 839 |
raw_preds, go_parents, go_ancestors, prob_map
|