Sbhat2026 committed on
Commit
b91bbd2
·
verified ·
1 Parent(s): 621deb3

fix: revert to improved model, add top-30 prediction cap

Browse files
Files changed (1) hide show
  1. server.py +21 -8
server.py CHANGED
@@ -215,8 +215,8 @@ def load_go_map():
215
 
216
  def load_thresholds():
217
  for path in [
218
- os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"),
219
  os.path.join(BASE_DIR, "improved_per_label_thresholds.json"),
 
220
  os.path.join(BASE_DIR, "per_label_thresholds.json"),
221
  os.path.join(BASE_DIR, "artifacts", "per_label_thresholds.json"),
222
  ]:
@@ -518,10 +518,12 @@ async def lifespan(app: FastAPI):
518
  import numpy as np
519
  device = torch.device("cpu")
520
 
521
- # Prefer checkpoints in priority order: v3 > improved > baseline
 
 
522
  ckpt_candidates = [
523
- os.path.join(BASE_DIR, "protfunc_v3.pth"),
524
  os.path.join(BASE_DIR, "improved_res.pth"),
 
525
  os.path.join(BASE_DIR, "baseline_res.pth"),
526
  ]
527
  ckpt_path = next((p for p in ckpt_candidates if os.path.exists(p)), None)
@@ -589,16 +591,18 @@ async def root():
589
  @app.get("/api/model/info")
590
  async def model_info():
591
  """Return model metadata and configuration."""
592
- v3 = os.path.exists(os.path.join(BASE_DIR, "protfunc_v3.pth"))
593
  improved = os.path.exists(os.path.join(BASE_DIR, "improved_res.pth"))
594
- if v3:
595
- name, version = "ProtFunc v3 (supplemented + mammal)", "3.0.0"
 
 
596
  elif improved:
597
- name, version = "ProtFunc Enhanced", "2.1.0"
598
  else:
599
- name, version = "ProtFunc", "2.0.0"
600
  return {
601
  "model_name": name,
 
602
  "version": version,
603
  "esm_model": "esm2_t6_8M_UR50D",
604
  "embed_dim": 320,
@@ -749,6 +753,12 @@ async def predict(request: ProteinRequest):
749
  })
750
  prob_map[mlb.classes_[i]] = pv
751
  raw_preds.sort(key=lambda x: x["prob"], reverse=True)
 
 
 
 
 
 
752
 
753
  # Propagate up GO DAG and filter
754
  visible, suppressed = propagate_and_filter(
@@ -821,6 +831,9 @@ async def predict_batch(request: BatchPredictRequest):
821
  })
822
  prob_map[mlb.classes_[i]] = pv
823
  raw_preds.sort(key=lambda x: x["prob"], reverse=True)
 
 
 
824
 
825
  visible, suppressed = propagate_and_filter(
826
  raw_preds, go_parents, go_ancestors, prob_map
 
215
 
216
  def load_thresholds():
217
  for path in [
 
218
  os.path.join(BASE_DIR, "improved_per_label_thresholds.json"),
219
+ os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"),
220
  os.path.join(BASE_DIR, "per_label_thresholds.json"),
221
  os.path.join(BASE_DIR, "artifacts", "per_label_thresholds.json"),
222
  ]:
 
518
  import numpy as np
519
  device = torch.device("cpu")
520
 
521
+ # Prefer checkpoints in priority order: improved > v3 > baseline
522
+ # (v3 was trained with propagated training labels which inflates predictions;
523
+ # improved_res.pth has verified Fmax=0.8846 on specific GO annotations)
524
  ckpt_candidates = [
 
525
  os.path.join(BASE_DIR, "improved_res.pth"),
526
+ os.path.join(BASE_DIR, "protfunc_v3.pth"),
527
  os.path.join(BASE_DIR, "baseline_res.pth"),
528
  ]
529
  ckpt_path = next((p for p in ckpt_candidates if os.path.exists(p)), None)
 
591
  @app.get("/api/model/info")
592
  async def model_info():
593
  """Return model metadata and configuration."""
 
594
  improved = os.path.exists(os.path.join(BASE_DIR, "improved_res.pth"))
595
+ v3 = os.path.exists(os.path.join(BASE_DIR, "protfunc_v3.pth"))
596
+ # model name reflects actual loaded model (improved takes priority)
597
+ if model_uses_supp:
598
+ name, version, active = "ProtFunc v3 (supplemented + mammal)", "3.0.0", "protfunc_v3"
599
  elif improved:
600
+ name, version, active = "ProtFunc Enhanced", "2.1.0", "improved"
601
  else:
602
+ name, version, active = "ProtFunc", "2.0.0", "baseline"
603
  return {
604
  "model_name": name,
605
+ "model": active,
606
  "version": version,
607
  "esm_model": "esm2_t6_8M_UR50D",
608
  "embed_dim": 320,
 
753
  })
754
  prob_map[mlb.classes_[i]] = pv
755
  raw_preds.sort(key=lambda x: x["prob"], reverse=True)
756
+ # Cap to top 30 most confident direct predictions before propagation.
757
+ # Without this, models trained on propagated labels return hundreds of
758
+ # broad ancestor terms that overwhelm the signal.
759
+ raw_preds = raw_preds[:30]
760
+ for rp in raw_preds:
761
+ prob_map[rp["go_id"]] = rp["prob"]
762
 
763
  # Propagate up GO DAG and filter
764
  visible, suppressed = propagate_and_filter(
 
831
  })
832
  prob_map[mlb.classes_[i]] = pv
833
  raw_preds.sort(key=lambda x: x["prob"], reverse=True)
834
+ raw_preds = raw_preds[:30]
835
+ for rp in raw_preds:
836
+ prob_map[rp["go_id"]] = rp["prob"]
837
 
838
  visible, suppressed = propagate_and_filter(
839
  raw_preds, go_parents, go_ancestors, prob_map