Sbhat2026 commited on
Commit
83c0418
·
verified ·
1 Parent(s): b91bbd2

Update server.py: v3-fixed model priority, generalization API + UI

Browse files
Files changed (1) hide show
  1. server.py +33 -7
server.py CHANGED
@@ -22,14 +22,16 @@ STATIC_DIR = os.path.join(BASE_DIR, "static")
22
  os.makedirs(STATIC_DIR, exist_ok=True)
23
 
24
  HF_REPO = "Sbhat2026/protfunc-models"
25
- # Priority order: v3 supplemented > improved base > original baseline
26
  HF_FILES = [
 
27
  "protfunc_v3.pth", "protfunc_v3_thresholds.json",
28
  "improved_res.pth", "improved_per_label_thresholds.json",
29
  "baseline_res.pth", "mlb_public_v1.pkl", "go_annotations_fixed.csv", "go_names.json",
30
  ]
31
  OPTIONAL = {
32
  "go_names.json",
 
33
  "protfunc_v3.pth", "protfunc_v3_thresholds.json",
34
  "improved_res.pth", "improved_per_label_thresholds.json",
35
  }
@@ -215,6 +217,7 @@ def load_go_map():
215
 
216
  def load_thresholds():
217
  for path in [
 
218
  os.path.join(BASE_DIR, "improved_per_label_thresholds.json"),
219
  os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"),
220
  os.path.join(BASE_DIR, "per_label_thresholds.json"),
@@ -518,10 +521,9 @@ async def lifespan(app: FastAPI):
518
  import numpy as np
519
  device = torch.device("cpu")
520
 
521
- # Prefer checkpoints in priority order: improved > v3 > baseline
522
- # (v3 was trained with propagated training labels which inflates predictions;
523
- # improved_res.pth has verified Fmax=0.8846 on specific GO annotations)
524
  ckpt_candidates = [
 
525
  os.path.join(BASE_DIR, "improved_res.pth"),
526
  os.path.join(BASE_DIR, "protfunc_v3.pth"),
527
  os.path.join(BASE_DIR, "baseline_res.pth"),
@@ -591,10 +593,12 @@ async def root():
591
  @app.get("/api/model/info")
592
  async def model_info():
593
  """Return model metadata and configuration."""
 
594
  improved = os.path.exists(os.path.join(BASE_DIR, "improved_res.pth"))
595
- v3 = os.path.exists(os.path.join(BASE_DIR, "protfunc_v3.pth"))
596
- # model name reflects actual loaded model (improved takes priority)
597
- if model_uses_supp:
 
598
  name, version, active = "ProtFunc v3 (supplemented + mammal)", "3.0.0", "protfunc_v3"
599
  elif improved:
600
  name, version, active = "ProtFunc Enhanced", "2.1.0", "improved"
@@ -618,6 +622,28 @@ async def model_info():
618
  }
619
 
620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  @app.get("/api/structure")
622
  async def get_structure(uniprot_id: str):
623
  """Look up AlphaFold structure data for a UniProt accession."""
 
22
  os.makedirs(STATIC_DIR, exist_ok=True)
23
 
24
  HF_REPO = "Sbhat2026/protfunc-models"
25
+ # Priority order: v3_fixed (ablation best) > v3 supplemented > improved base > original baseline
26
  HF_FILES = [
27
+ "protfunc_v3_fixed.pth", "protfunc_v3_fixed_thresholds.json",
28
  "protfunc_v3.pth", "protfunc_v3_thresholds.json",
29
  "improved_res.pth", "improved_per_label_thresholds.json",
30
  "baseline_res.pth", "mlb_public_v1.pkl", "go_annotations_fixed.csv", "go_names.json",
31
  ]
32
  OPTIONAL = {
33
  "go_names.json",
34
+ "protfunc_v3_fixed.pth", "protfunc_v3_fixed_thresholds.json",
35
  "protfunc_v3.pth", "protfunc_v3_thresholds.json",
36
  "improved_res.pth", "improved_per_label_thresholds.json",
37
  }
 
217
 
218
  def load_thresholds():
219
  for path in [
220
+ os.path.join(BASE_DIR, "protfunc_v3_fixed_thresholds.json"),
221
  os.path.join(BASE_DIR, "improved_per_label_thresholds.json"),
222
  os.path.join(BASE_DIR, "protfunc_v3_thresholds.json"),
223
  os.path.join(BASE_DIR, "per_label_thresholds.json"),
 
521
  import numpy as np
522
  device = torch.device("cpu")
523
 
524
+ # Prefer checkpoints in priority order: v3_fixed (ablation best, CAFA-correct) > improved > v3 > baseline
 
 
525
  ckpt_candidates = [
526
+ os.path.join(BASE_DIR, "protfunc_v3_fixed.pth"),
527
  os.path.join(BASE_DIR, "improved_res.pth"),
528
  os.path.join(BASE_DIR, "protfunc_v3.pth"),
529
  os.path.join(BASE_DIR, "baseline_res.pth"),
 
593
  @app.get("/api/model/info")
594
  async def model_info():
595
  """Return model metadata and configuration."""
596
+ v3_fixed = os.path.exists(os.path.join(BASE_DIR, "protfunc_v3_fixed.pth"))
597
  improved = os.path.exists(os.path.join(BASE_DIR, "improved_res.pth"))
598
+ # model name reflects actual loaded model (v3_fixed takes highest priority)
599
+ if v3_fixed and model_uses_supp:
600
+ name, version, active = "ProtFunc v3-fixed (ablation best, CAFA-correct)", "3.1.0", "protfunc_v3_fixed"
601
+ elif model_uses_supp:
602
  name, version, active = "ProtFunc v3 (supplemented + mammal)", "3.0.0", "protfunc_v3"
603
  elif improved:
604
  name, version, active = "ProtFunc Enhanced", "2.1.0", "improved"
 
622
  }
623
 
624
 
625
+ @app.get("/api/generalization")
626
+ async def get_generalization():
627
+ """
628
+ Return cross-taxon generalization results from eval_generalization.py output.
629
+ Serves artifacts/generalization_results.json if present, otherwise returns empty.
630
+ """
631
+ candidates = [
632
+ os.path.join(BASE_DIR, "artifacts", "generalization", "generalization_results.json"),
633
+ os.path.join(BASE_DIR, "generalization_results.json"),
634
+ ]
635
+ for path in candidates:
636
+ if os.path.exists(path):
637
+ with open(path) as f:
638
+ data = json.load(f)
639
+ return {
640
+ "available": True,
641
+ "taxa": list(data.keys()),
642
+ "results": data,
643
+ }
644
+ return {"available": False, "taxa": [], "results": {}}
645
+
646
+
647
  @app.get("/api/structure")
648
  async def get_structure(uniprot_id: str):
649
  """Look up AlphaFold structure data for a UniProt accession."""