protfunc / scripts /post_training.sh
Sbhat2026's picture
perf: ESM embedding cache + 1500aa limit, add research scripts
7f7a890
#!/usr/bin/env bash
# post_training.sh — Run after Model C training completes.
# Evaluates generalization for all 3 ablation models (A, B, C) on mammals,
# prints a summary table naming the best mammal model, then prints the
# HuggingFace upload commands for the Model C checkpoint and thresholds.
#
# Usage: bash scripts/post_training.sh
#        PY=/path/to/python3 bash scripts/post_training.sh   # override interpreter
set -euo pipefail

# Repository root: the parent of the directory that holds this script.
BASE="$(cd "$(dirname "$0")/.." && pwd)"

# Python interpreter. Honour $PY from the environment; otherwise prefer the
# author's miniforge install, falling back to python3 on PATH when that path
# does not exist on this machine.
if [[ -z "${PY:-}" ]]; then
  PY=/Users/siddhantbhat/miniforge3/bin/python3
  if [[ ! -x "$PY" ]] && command -v python3 >/dev/null 2>&1; then
    PY="$(command -v python3)"
  fi
fi

# Results file; every model's eval_generalization.py run is given this as --out.
GEN_OUT="$BASE/artifacts/generalization/generalization_results.json"

echo "============================================================"
echo "ProtFunc Post-Training: Mammal Generalization Eval"
echo "============================================================"
#######################################
# Evaluate one ablation checkpoint on the mammal taxon split.
# Globals:
#   PY      (read) - python interpreter used to run the evaluator
#   BASE    (read) - repository root
#   GEN_OUT (read) - results JSON passed to the evaluator as --out
# Arguments:
#   $1 - taxon label recorded in the results (mammals_A, mammals_B, mammals_C)
#   $2 - model checkpoint (.pth)
#   $3 - thresholds JSON
#   $4 - training log JSON, passed to the evaluator as --insect_log
# Outputs:
#   Progress and SKIP lines on stdout; a FAIL line on stderr when the
#   evaluator exits non-zero.
# Returns:
#   0 always, so one missing or crashing model does not abort the remaining
#   evaluations under `set -e`.
#######################################
run_eval() {
  local label="$1" # mammals_A, mammals_B, mammals_C
  local ckpt="$2"
  local thresh="$3"
  local log="$4"
  local rc=0

  if [[ ! -f "$ckpt" ]]; then
    echo " SKIP $label — checkpoint not found: $ckpt"
    return 0
  fi
  if [[ ! -f "$thresh" ]]; then
    echo " SKIP $label — thresholds not found: $thresh"
    return 0
  fi

  echo ""
  echo ">>> Evaluating $label ..."
  # `|| rc=$?` keeps `set -e` from killing the whole script on one bad model.
  "$PY" "$BASE/scripts/eval_generalization.py" \
    --checkpoint "$ckpt" \
    --thresholds "$thresh" \
    --mlb "$BASE/Important Files/mlb_public_v1.pkl" \
    --taxon_parquet "$BASE/artifacts/generalization/mammal_embeddings_v3.parquet" \
    --taxon_name "$label" \
    --obo "$BASE/go-basic.obo" \
    --insect_log "$log" \
    --out "$GEN_OUT" || rc=$?

  if (( rc != 0 )); then
    echo " FAIL $label — eval_generalization.py exited with status $rc" >&2
  fi
  return 0
}
# Evaluate each ablation model in turn. The suffix after "ablation_" names the
# checkpoint, thresholds and log files; its leading letter (A/B/C) becomes the
# taxon label (mammals_A, mammals_B, mammals_C).
for ablation in A_ESM_only B_ESM_seq C_ESM_seq_AF; do
  run_eval "mammals_${ablation%%_*}" \
    "$BASE/artifacts/checkpoints/ablation_${ablation}.pth" \
    "$BASE/artifacts/thresholds/ablation_${ablation}_thresholds.json" \
    "$BASE/artifacts/logs/ablation_${ablation}_log.json"
done
echo ""
echo "============================================================"
echo "Generalization results written to: $GEN_OUT"
echo ""

# Print a table of the mammal rows in $GEN_OUT and name the model with the
# highest micro_fmax. The results path is passed as argv[1] so the summary reads
# the same file the evaluations wrote, independent of CWD or machine.
"$PY" - "$GEN_OUT" <<'EOF'
import json
import sys
from pathlib import Path

path = Path(sys.argv[1])
if not path.is_file():
    # Every model may have been skipped, in which case nothing was written.
    # Report it instead of crashing so the upload instructions still print.
    print(f"No generalization results found at {path}; nothing to summarize.")
    sys.exit(0)

with path.open() as f:
    results = json.load(f)


def as_score(value):
    """Coerce a metric to float for ranking; missing/non-numeric ranks last."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return float("-inf")


def cell(row, key, default="—"):
    """Render a metric as text so None/str/float all accept width padding."""
    return str(row.get(key, default))


mammal_taxa = [t for t in results if "mammals" in t]

print("Taxon | micro_Fmax | CAFA_Fmax | gen_ratio | model_checkpoint")
print("-" * 80)
for taxon in mammal_taxa:
    r = results[taxon]
    print(
        f"{taxon:<14} | {cell(r, 'micro_fmax'):>10} | {cell(r, 'cafa_fmax'):>9}"
        f" | {cell(r, 'generalization_ratio'):>9} | {cell(r, 'model_checkpoint', '?')}"
    )

# Pick best by micro_fmax.
best_taxon = max(
    mammal_taxa,
    key=lambda t: as_score(results[t].get("micro_fmax")),
    default=None,
)
if best_taxon:
    best = results[best_taxon]
    ckpt = best.get("model_checkpoint", "?")
    print(f"\nBest mammal model: {best_taxon} → {ckpt} (micro_fmax={best.get('micro_fmax')})")
EOF
# Print the manual follow-up steps.
# NOTE: these commands always name the Model C (ablation_C_ESM_seq_AF)
# checkpoint and thresholds, not whichever model the summary above reported as
# best. Edit them by hand if another ablation wins.
# Local paths are built from $BASE so they are valid wherever the repo lives.
# (Unquoted here-doc: $BASE expands; "\\" prints a single trailing backslash.)
cat <<EOF

============================================================
To upload the Model C checkpoint to HuggingFace:

 huggingface-cli upload Sbhat2026/protfunc-models \\
 "$BASE/artifacts/checkpoints/ablation_C_ESM_seq_AF.pth" \\
 protfunc_v3_fixed.pth

 huggingface-cli upload Sbhat2026/protfunc-models \\
 "$BASE/artifacts/thresholds/ablation_C_ESM_seq_AF_thresholds.json" \\
 protfunc_v3_fixed_thresholds.json

Then push server.py to HF Space:
 cd /Users/siddhantbhat/insecta_webapp
 git add server.py && git commit -m 'Load protfunc_v3_fixed as top-priority checkpoint'
 git push
============================================================
EOF