#!/usr/bin/env bash
# post_training.sh — Run after Model C training completes.
# Evaluates generalization for all 3 ablation models on mammals,
# prints a summary table naming the best-scoring model, then prints the
# HuggingFace upload commands for the Model C checkpoint.
#
# Usage: bash scripts/post_training.sh
#
# Environment overrides:
#   PY          Python interpreter (default: the miniforge python3 below)
#   WEBAPP_DIR  Local checkout of the HF Space web app (used in printed hints)
#
# Exit status: 0 if every evaluation that ran succeeded, 1 otherwise.
set -euo pipefail

# Repository root: the parent of the directory containing this script.
BASE="$(cd "$(dirname "$0")/.." && pwd)"
PY="${PY:-/Users/siddhantbhat/miniforge3/bin/python3}"
WEBAPP_DIR="${WEBAPP_DIR:-/Users/siddhantbhat/insecta_webapp}"
GEN_OUT="$BASE/artifacts/generalization/generalization_results.json"

# Fail fast with a clear message instead of three identical eval failures.
if ! command -v "$PY" >/dev/null 2>&1; then
  echo "error: Python interpreter not found: $PY (set PY to override)" >&2
  exit 1
fi

echo "============================================================"
echo "ProtFunc Post-Training: Mammal Generalization Eval"
echo "============================================================"

#######################################
# Evaluate one ablation checkpoint on the mammal embeddings.
# Globals:   BASE, PY, GEN_OUT (read)
# Arguments: $1 - taxon label recorded in the results (mammals_A/B/C)
#            $2 - model checkpoint (.pth)
#            $3 - thresholds JSON for that checkpoint
#            $4 - training log JSON, forwarded as --insect_log
# Returns:   0 if skipped (missing checkpoint/thresholds) or on success;
#            otherwise the exit status of eval_generalization.py.
#######################################
run_eval() {
  local label="$1"
  local ckpt="$2"
  local thresh="$3"
  local log="$4"

  if [[ ! -f "$ckpt" ]]; then
    echo "  SKIP $label — checkpoint not found: $ckpt"
    return 0
  fi
  if [[ ! -f "$thresh" ]]; then
    echo "  SKIP $label — thresholds not found: $thresh"
    return 0
  fi

  echo ""
  echo ">>> Evaluating $label ..."
  "$PY" "$BASE/scripts/eval_generalization.py" \
    --checkpoint "$ckpt" \
    --thresholds "$thresh" \
    --mlb "$BASE/Important Files/mlb_public_v1.pkl" \
    --taxon_parquet "$BASE/artifacts/generalization/mammal_embeddings_v3.parquet" \
    --taxon_name "$label" \
    --obo "$BASE/go-basic.obo" \
    --insect_log "$log" \
    --out "$GEN_OUT"
}

# Labels whose evaluation failed. Failures are recorded instead of aborting
# via set -e so the remaining (independent) models are still evaluated.
# A plain string rather than an array keeps `set -u` safe on bash 3.2.
failed=""

run_eval "mammals_A" \
  "$BASE/artifacts/checkpoints/ablation_A_ESM_only.pth" \
  "$BASE/artifacts/thresholds/ablation_A_ESM_only_thresholds.json" \
  "$BASE/artifacts/logs/ablation_A_ESM_only_log.json" \
  || failed="$failed mammals_A"

run_eval "mammals_B" \
  "$BASE/artifacts/checkpoints/ablation_B_ESM_seq.pth" \
  "$BASE/artifacts/thresholds/ablation_B_ESM_seq_thresholds.json" \
  "$BASE/artifacts/logs/ablation_B_ESM_seq_log.json" \
  || failed="$failed mammals_B"

run_eval "mammals_C" \
  "$BASE/artifacts/checkpoints/ablation_C_ESM_seq_AF.pth" \
  "$BASE/artifacts/thresholds/ablation_C_ESM_seq_AF_thresholds.json" \
  "$BASE/artifacts/logs/ablation_C_ESM_seq_AF_log.json" \
  || failed="$failed mammals_C"

echo ""
echo "============================================================"
echo "Generalization results written to: $GEN_OUT"
if [[ -n "$failed" ]]; then
  echo "WARNING: evaluation failed for:$failed" >&2
fi
echo ""

# Print summary. GEN_OUT is passed as argv[1] so the summary reads exactly
# the file the evaluations wrote, independent of the current directory.
"$PY" - "$GEN_OUT" <<'EOF'
import json
import sys
from pathlib import Path

path = Path(sys.argv[1])
if not path.is_file():
    # Every evaluation was skipped or failed before writing results.
    print(f"No results file at {path}; skipping summary.")
    sys.exit(0)

with path.open() as f:
    results = json.load(f)


def rank_key(value):
    # Missing or non-numeric micro_fmax ranks below any real score
    # instead of raising TypeError inside max().
    try:
        return float(value)
    except (TypeError, ValueError):
        return float("-inf")


mammal_taxa = [t for t in results if "mammals" in t]

print(f"{'Taxon':<14} | {'micro_Fmax':>10} | {'CAFA_Fmax':>9} | {'gen_ratio':>9} | model_checkpoint")
print("-" * 80)
for taxon in mammal_taxa:
    r = results[taxon]
    # str() first so None / non-numeric values don't break the width spec.
    print(
        f"{taxon:<14} | {str(r.get('micro_fmax', '—')):>10} | "
        f"{str(r.get('cafa_fmax', '—')):>9} | "
        f"{str(r.get('generalization_ratio', '—')):>9} | "
        f"{r.get('model_checkpoint', '?')}"
    )

# Pick best
best_taxon = max(
    mammal_taxa,
    key=lambda t: rank_key(results[t].get("micro_fmax")),
    default=None,
)
if best_taxon:
    best = results[best_taxon]
    ckpt = best.get("model_checkpoint", "?")
    print(f"\nBest mammal model: {best_taxon} → {ckpt} (micro_fmax={best.get('micro_fmax')})")
EOF

echo ""
echo "============================================================"
# NOTE: these upload commands always target Model C (ablation_C_ESM_seq_AF),
# regardless of which model scored best in the summary above.
echo "To upload best checkpoint to HuggingFace:"
echo ""
echo "  huggingface-cli upload Sbhat2026/protfunc-models \\"
echo "    \"$BASE/artifacts/checkpoints/ablation_C_ESM_seq_AF.pth\" \\"
echo "    protfunc_v3_fixed.pth"
echo ""
echo "  huggingface-cli upload Sbhat2026/protfunc-models \\"
echo "    \"$BASE/artifacts/thresholds/ablation_C_ESM_seq_AF_thresholds.json\" \\"
echo "    protfunc_v3_fixed_thresholds.json"
echo ""
echo "Then push server.py to HF Space:"
printf '  cd %q\n' "$WEBAPP_DIR"
echo "  git add server.py && git commit -m 'Load protfunc_v3_fixed as top-priority checkpoint'"
echo "  git push"
echo "============================================================"

# Report evaluation failures through the exit status after all output.
if [[ -n "$failed" ]]; then
  exit 1
fi