#!/usr/bin/env bash
#
# Mammal generalization evaluation driver.
#
# Runs scripts/eval_generalization.py once per ablation checkpoint (A, B, C),
# prints a summary table of the collected results, and then prints the manual
# steps for uploading a checkpoint to HuggingFace.
#
# Environment:
#   PY  Python interpreter to run (optional; defaults to the miniforge
#       interpreter path set below).

set -euo pipefail

# Repository root: the parent of the directory that contains this script.
# CDPATH is cleared for this one cd so a relative script path cannot be
# resolved through CDPATH, which would also make cd print the directory and
# corrupt the captured value.
BASE="$(CDPATH='' cd -- "$(dirname -- "$0")/.." && pwd)"
# Interpreter used for every Python invocation; override by exporting PY.
PY="${PY:-/Users/siddhantbhat/miniforge3/bin/python3}"
# JSON file handed to eval_generalization.py via --out.
GEN_OUT="$BASE/artifacts/generalization/generalization_results.json"
readonly BASE PY GEN_OUT

echo "============================================================"
echo "ProtFunc Post-Training: Mammal Generalization Eval"
echo "============================================================"

#######################################
# Run eval_generalization.py for one checkpoint against the mammal embeddings.
# The evaluation is skipped (message on stdout, status 0) when either the
# checkpoint or the thresholds file does not exist.
# Globals:
#   PY       (read) Python interpreter to invoke
#   BASE     (read) repository root used to build script and data paths
#   GEN_OUT  (read) path passed as --out
# Arguments:
#   $1 - label, passed as --taxon_name and used in progress messages
#   $2 - checkpoint file path
#   $3 - thresholds file path
#   $4 - log file path, passed as --insect_log
# Outputs:
#   Progress / skip messages on stdout, plus whatever the Python script prints.
# Returns:
#   0 when skipped; otherwise the exit status of the Python script.
#######################################
run_eval() {
  local label="$1"
  local ckpt="$2"
  local thresh="$3"
  local log="$4"

  if [[ ! -f "$ckpt" ]]; then
    echo " SKIP $label - checkpoint not found: $ckpt"
    return 0
  fi
  if [[ ! -f "$thresh" ]]; then
    echo " SKIP $label - thresholds not found: $thresh"
    return 0
  fi

  echo ""
  echo ">>> Evaluating $label ..."
  # "$PY" is quoted so an interpreter path containing spaces stays one word.
  "$PY" "$BASE/scripts/eval_generalization.py" \
    --checkpoint "$ckpt" \
    --thresholds "$thresh" \
    --mlb "$BASE/Important Files/mlb_public_v1.pkl" \
    --taxon_parquet "$BASE/artifacts/generalization/mammal_embeddings_v3.parquet" \
    --taxon_name "$label" \
    --obo "$BASE/go-basic.obo" \
    --insect_log "$log" \
    --out "$GEN_OUT"
}

# Evaluate every ablation variant in order. Each entry is "label:stem", where
# the stem is the shared file-name prefix of that variant's checkpoint,
# thresholds and log artifacts.
for variant in \
  "mammals_A:ablation_A_ESM_only" \
  "mammals_B:ablation_B_ESM_seq" \
  "mammals_C:ablation_C_ESM_seq_AF"; do
  taxon_label="${variant%%:*}"
  artifact_stem="${variant#*:}"
  run_eval "$taxon_label" \
    "$BASE/artifacts/checkpoints/${artifact_stem}.pth" \
    "$BASE/artifacts/thresholds/${artifact_stem}_thresholds.json" \
    "$BASE/artifacts/logs/${artifact_stem}_log.json"
done

printf '%s\n' \
  "" \
  "============================================================" \
  "Generalization results written to: $GEN_OUT" \
  ""

# Print a table of the mammal entries in the results file and name the entry
# with the highest micro_fmax. The results path is passed to the embedded
# Python as argv[1] so the summary reads the same file the evaluations were
# told to write (--out "$GEN_OUT"), wherever the repository lives.
"$PY" - "$GEN_OUT" <<'EOF'
import json
import sys
from pathlib import Path

# argv[1] is supplied by the calling shell; the relative default only applies
# when this snippet is run by hand without an argument.
path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("artifacts/generalization/generalization_results.json")
if not path.is_file():
    # Exits non-zero with a readable message instead of a traceback.
    sys.exit(f"No generalization results found at: {path}")

with path.open() as f:
    results = json.load(f)

MISSING = "-"


def cell(row, key):
    """Return row[key], or a placeholder when the key is absent or null."""
    value = row.get(key)
    return MISSING if value is None else value


# Only entries whose key contains "mammals" are reported.
mammals = {taxon: row for taxon, row in results.items() if "mammals" in taxon}

print("Taxon | micro_Fmax | CAFA_Fmax | gen_ratio | model_checkpoint")
print("-" * 80)
for taxon, r in mammals.items():
    print(f"{taxon:<14} | {cell(r, 'micro_fmax'):>10} | {cell(r, 'cafa_fmax'):>9} | {cell(r, 'generalization_ratio'):>9} | {r.get('model_checkpoint','?')}")

# A missing or null micro_fmax ranks as 0 so it cannot break the comparison.
best_taxon = max(
    mammals,
    key=lambda t: mammals[t].get("micro_fmax") or 0,
    default=None,
)
if best_taxon is not None:
    best = mammals[best_taxon]
    ckpt = best.get("model_checkpoint", "?")
    print(f"\nBest mammal model: {best_taxon} -> {ckpt} (micro_fmax={best.get('micro_fmax')})")
EOF

# Print the manual follow-up steps. Artifact paths are derived from $BASE so
# the suggested commands point at this checkout rather than one fixed home
# directory.
# NOTE(review): the commands always reference the ablation_C artifacts, whatever
# the summary above reports as best — confirm that is intended.
printf '%s\n' \
  "" \
  "============================================================" \
  "To upload best checkpoint to HuggingFace:" \
  "" \
  " huggingface-cli upload Sbhat2026/protfunc-models \\" \
  " \"$BASE/artifacts/checkpoints/ablation_C_ESM_seq_AF.pth\" \\" \
  " protfunc_v3_fixed.pth" \
  "" \
  " huggingface-cli upload Sbhat2026/protfunc-models \\" \
  " \"$BASE/artifacts/thresholds/ablation_C_ESM_seq_AF_thresholds.json\" \\" \
  " protfunc_v3_fixed_thresholds.json" \
  "" \
  "Then push server.py to HF Space:" \
  " cd /Users/siddhantbhat/insecta_webapp" \
  " git add server.py && git commit -m 'Load protfunc_v3_fixed as top-priority checkpoint'" \
  " git push" \
  "============================================================"
