# protfunc/scripts/post_train_pipeline.py
# (perf: ESM embedding cache + 1500aa limit, add research scripts)
"""
post_train_pipeline.py
======================
Automated post-training pipeline. Runs after graph_hpo_best training completes:
1. Wait for training to finish (polls train.log)
2. Re-run threshold_comparison.py on the new checkpoint
3. Copy best thresholds to webapp per_label_thresholds.json
4. Test the local server with 5 known proteins and validate output
5. Push to HuggingFace Space
6. Write summary to artifacts/graph_hpo/pipeline_summary.json
Usage:
python scripts/post_train_pipeline.py
python scripts/post_train_pipeline.py --skip-wait # if training already done
python scripts/post_train_pipeline.py --dry-run # skip HF push
"""
import argparse
import json
import math
import shlex
import subprocess
import sys
import time
from pathlib import Path
# Project layout: all artifact paths hang off the repo root (parent of scripts/).
BASE = Path(__file__).parent.parent
ART = BASE / "artifacts"
GRAPH_ART = ART / "graph_hpo"
# NOTE(review): machine-specific absolute path to the webapp checkout -- update per host.
WEBAPP = Path("/Users/siddhantbhat/insecta_webapp")
SCRIPTS = BASE / "scripts"
TRAIN_LOG = GRAPH_ART / "train.log"          # polled by wait_for_training()
CKPT_OUT = GRAPH_ART / "graph_hpo_best.pth"  # checkpoint produced by training
THRESH_OUT = GRAPH_ART / "graph_hpo_best_thresholds.json"
LOG_OUT = GRAPH_ART / "graph_hpo_best_log.json"
COMPARISON = ART / "threshold_comparison_results.json"
PREC_THRESH = ART / "thresholds" / "precision_ic_thresholds.json"  # fallback thresholds
SUMMARY_OUT = GRAPH_ART / "pipeline_summary.json"
DEPLOY_THRESH = WEBAPP / "per_label_thresholds.json"  # webapp's live threshold file
POLL_INTERVAL = 60 # seconds between log polls
# -- Test proteins (known functions for sanity checking) -----------------------
# Each entry: name, FASTA sequence, expected GO terms that must appear in top-10.
# "expect_contains" are substrings that should appear among the top predicted GO
# term names; "expect_not_contains" are substrings that must NOT appear there.
TEST_PROTEINS = [
    {
        # NOTE(review): sequence looks low-complexity (long Q/S runs) for a
        # kinase -- confirm this is the intended Hippo fragment.
        "name": "Drosophila_Hippo_kinase",
        "expect_contains": ["kinase", "ATP", "phospho"],
        "expect_not_contains": ["chorion", "neurotransmitter", "olfactory"],
        "sequence": (
            "MSRSSGSSGSAATPVGKRSGQNLSTSIGKGDSFSIRSRSVASTSSSGLNNSGSSTTLNRSN"
            "TSSSSVNTSSSQNRSSTLSTSSANNVTSSSSTTMQDNQFLSSSQLEKIQRELEQTLKQLNR"
            "QQAELQRQLSQSQSQSQSESQSESQSMSSRSTPVAIPPTQAPPSQSSQSSQSSQSSQSAQD"
            "AAAPLNSSSSSSSSSSSSSQQQQLQEQQQQQQQQHQQQHQLHQQHQHHQQHQQQQPQQQPQ"
            "QQPQQQPQQQPQAPQAQPQPQPQPQQQQPQQQQPQQQQPQQNPQSQNPQSQNPQSQNAQAQ"
            "AQGQAQGQAQGQAQGPQAQAPQSSQGGMGGNNMGGSGMGGSGMGGSGMGNSGMGSSIGGIT"
        ),
    },
    {
        # Light-sensing GPCR: expect photoreception-related terms.
        "name": "Opsin_Rh6",
        "expect_contains": ["photoreceptor", "retinal", "light", "G protein", "receptor"],
        "expect_not_contains": ["kinase", "DNA", "ribosom"],
        "sequence": (
            "MSSENGDLQFLNGTEGPNFYVPMSNATGVVRSPAEYNQTRESPIFTYTNSNSTRGPFEGPN"
            "YHIAPRWVYHLTSVWMIFVVTASVFTNGLVLAATMKFKKLRHPLNWILVNLAVADLAETVIASTISVVNQF"
            "FGYFILGHPMCVLEGYTVSACGITALWSLAIISWERWLVVCKPFGNIKTPLAQRLIAAIWLFSVWIGVPFN"
            "IPEGPQNPVSTFKAMLACNPWVLTPLFKTDPESAGYSIFNIYIFLCHFFIPMAVIVFSCYGNIVMTLHSH"
            "TKKEALMQFPQTPQQEASSTTVSKTETSQVAPA"
        ),
    },
    {
        # Tumor-suppressor transcription factor.
        "name": "Human_TP53",
        "expect_contains": ["DNA binding", "transcription", "p53"],
        "expect_not_contains": ["kinase activity", "photoreceptor"],
        "sequence": (
            "MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGP"
            "DEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYPQGLNGTVNLFRNLNKALK"
            "SLSERNSTAHSQAPQLEQPQSSAPISGKQPQSRSAPLEHHLQKSQAPLPNASPRPPMSSNI"
            "SQGSSQSKASGSTMGPILSSSSDDIEQWFTEDPSTSEELNEALELKDAQAGKESLHSQATLE"
            "STAHPLTSESVQISAVQILANAQREALESLPTKMPVQPEAEELDNFYINQQQTNSASSLSSS"
            "RAENTLPDPHYREVGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGQM"
            "NRRPILTIITLEDSSGKLLGRNSFEVRVCACPGRDRRTEEENLRKKGQVLKEIREGQRLSDP"
            "NTCQKHKKLSELLGQSSFEV"
        ),
    },
    {
        # Ubiquitin: short, extremely conserved -- a strong positive control.
        "name": "Ubiquitin_Drosophila",
        "expect_contains": ["ubiquitin", "protein binding"],
        "expect_not_contains": ["kinase", "photoreceptor"],
        "sequence": "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG",
    },
    {
        # Catalytic beta subunit of the F1 ATP synthase complex.
        "name": "ATP_synthase_beta",
        "expect_contains": ["ATP", "hydrolase", "synthase"],
        "expect_not_contains": ["photoreceptor", "DNA binding", "olfactory"],
        "sequence": (
            "MATLRILNIGDQMDVGLNAKSSISRLVQTLPQRKVLISAQDMLQRIREKAEMLGDVPIMVD"
            "STSRQLDAEGQVVLSGTSIAIDAMIPIGRGQRELIIGDRQTGKTAVAIDTIINQKGSITSVQ"
            "AIYVPADDLTDPAPATTFAHLDATTVLSRAIAKGIYPAVDPLDSTSTMLQPRIVGQEHYDTA"
            "RGVQKILQDYKSLQDIIAILGMDELSEEDKLTVSRARLAEHSSKARELKQLVAERQEIVAKA"
            "HARAQKIPQAVAAEREALGSDPQIALTASDSYLKQLPQLTDQLAQLMKKQSQQELAQLMQQP"
        ),
    },
]
def py() -> str:
    """Return the interpreter to use: the project venv's python3 if present,
    otherwise the system `python3` found on PATH."""
    candidate = BASE / ".venv" / "bin" / "python3"
    if candidate.exists():
        return str(candidate)
    return "python3"
def run(cmd, check=True, capture=False):
    """Echo *cmd* and execute it with BASE as the working directory.

    With capture=True, returns the CompletedProcess (stdout/stderr as text);
    otherwise runs it directly (raising on failure when check=True) and
    returns None.
    """
    print(f"$ {shlex.join(cmd)}")
    if not capture:
        subprocess.run(cmd, cwd=BASE, check=check)
        return None
    return subprocess.run(cmd, capture_output=True, text=True, cwd=BASE)
# -- Step 1: wait for training --------------------------------------------------
def wait_for_training():
    """Poll TRAIN_LOG until training completes, errors out, or stalls.

    Returns True when training looks complete (completion markers in the log,
    or the process died but the checkpoint exists); False on a detected error
    or a stall with no checkpoint.
    """
    print("\n[1/5] Waiting for training to complete...")
    if not TRAIN_LOG.exists():
        print(f" train.log not found at {TRAIN_LOG}; assuming training not started")
        return False
    last_size = 0
    stall_count = 0
    while True:
        size = TRAIN_LOG.stat().st_size
        content = TRAIN_LOG.read_text(errors="replace")
        # Completion markers emitted by the training script.
        if "Test evaluation" in content or "[6/6] Done" in content:
            print(" Training complete.")
            print_last_epoch(content)
            return True
        # NOTE(review): the second clause only scans text *before* the first
        # "Traceback" for "Error" -- confirm this matches the log format.
        if "Traceback" in content or "Error" in content.split("Traceback")[0]:
            # Check if error is in the last 2000 chars (recent)
            recent = content[-2000:]
            if "Traceback" in recent or ("Error" in recent and "Warning" not in recent):
                print(f" ERROR detected in training log:")
                for line in content.splitlines()[-20:]:
                    print(f" {line}")
                return False
        if size == last_size:
            # Log has not grown since the last poll.
            stall_count += 1
            if stall_count > 30: # 30 min stall
                print(" Training appears stalled. Checking process...")
                r = subprocess.run(["pgrep", "-f", "train_v3_fixed"], capture_output=True, text=True)
                if not r.stdout.strip():
                    print(" Training process not running. Checking log for completion...")
                    if CKPT_OUT.exists():
                        print(" Checkpoint exists β€” treating as complete.")
                        return True
                    return False
        else:
            stall_count = 0
        last_size = size
        print_last_epoch(content)
        time.sleep(POLL_INTERVAL)
def print_last_epoch(content):
    """Print the most recent epoch/metric line from the log text, falling back
    to the final line when no epoch line is found; silent on empty input."""
    log_lines = content.splitlines()
    progress = [ln for ln in log_lines if "micro-fmax" in ln or "Epoch" in ln]
    if progress:
        print(f" Latest: {progress[-1].strip()}")
    elif log_lines:
        print(f" Latest: {log_lines[-1].strip()}")
# -- Step 2: re-run threshold comparison on new model ---------------------------
def run_threshold_comparison():
    """Re-run scripts/threshold_comparison.py against the new checkpoint.

    Writes a patched copy of the comparison script (pointing CKPT_PATH /
    CURRENT_THRESH at the graph-HPO artifacts) into GRAPH_ART and executes it
    with the project interpreter.

    Returns:
        True when the comparison ran and produced a results file; False when
        the checkpoint is missing, the script failed, or no results appeared.
    """
    print("\n[2/5] Running threshold comparison on new checkpoint...")
    if not CKPT_OUT.exists():
        print(f" Checkpoint not found at {CKPT_OUT} β€” skipping comparison")
        return False
    # Patch comparison script to use the new checkpoint (the original script
    # on disk is left untouched; only a copy is written).
    comp_script = SCRIPTS / "threshold_comparison.py"
    original = comp_script.read_text()
    patched = original.replace(
        'CKPT_PATH = ART / "checkpoints" / "protfunc_v3.pth"',
        f'CKPT_PATH = Path("{CKPT_OUT}")',
    ).replace(
        'CURRENT_THRESH = ART / "thresholds" / "protfunc_v3_thresholds.json"',
        f'CURRENT_THRESH = Path("{THRESH_OUT}")',
    )
    tmp_script = GRAPH_ART / "threshold_comparison_newmodel.py"
    tmp_script.write_text(patched)
    r = run([py(), str(tmp_script)], check=False, capture=True)
    print(r.stdout[-3000:] if r.stdout else "(no stdout)")
    if r.returncode != 0:
        print(f" STDERR: {r.stderr[-1000:]}")
        return False
    # Read new results and print a per-strategy summary.
    new_results_path = ART / "threshold_comparison_results.json"
    if new_results_path.exists():
        with open(new_results_path) as f:
            res = json.load(f)
        m = res.get("metrics", {})
        print("\n New model comparison:")
        for strat, metrics in m.items():
            if isinstance(metrics, dict) and "micro_f1" in metrics:
                print(f" {strat}: P={metrics['micro_precision']:.4f} "
                      f"R={metrics['micro_recall']:.4f} F1={metrics['micro_f1']:.4f} "
                      f"avg_preds={metrics['mean_preds_per_protein']:.1f}")
        return True
    # BUG FIX: previously fell off the end (implicitly returning None) when the
    # results file was never produced; report it explicitly instead.
    print(f" Results file not produced at {new_results_path}")
    return False
# -- Step 3: deploy thresholds to webapp -----------------------------------------
def deploy_thresholds():
    """Copy the best per-label thresholds into the webapp.

    Prefers the new model's thresholds (THRESH_OUT) when they look sane
    (mean >= 0.5 and fewer than 300 labels below 0.3); otherwise falls back
    to the precision+IC thresholds (PREC_THRESH).

    Returns:
        True when a threshold file was copied to DEPLOY_THRESH, else False.
    """
    import shutil

    print("\n[3/5] Deploying thresholds to webapp...")
    # Fallback: existing precision+IC thresholds, if present.
    src = PREC_THRESH if PREC_THRESH.exists() else None
    if THRESH_OUT.exists():
        # Use new model's thresholds but only if they look reasonable.
        with open(THRESH_OUT) as f:
            t = json.load(f)
        vals = list(t.values())
        if not vals:
            # BUG FIX: an empty threshold file previously crashed with
            # ZeroDivisionError; treat it as unusable and keep the fallback.
            print(" New model threshold file is empty; using fallback")
        else:
            mean_t = sum(vals) / len(vals)
            n_low = sum(1 for v in vals if v < 0.3)
            print(f" New model thresholds: N={len(vals)}, mean={mean_t:.3f}, <0.3: {n_low}")
            if mean_t >= 0.5 and n_low < 300:
                src = THRESH_OUT
                print(f" Using new model thresholds (mean {mean_t:.3f})")
            else:
                print(f" New thresholds look poor (mean {mean_t:.3f}, {n_low} < 0.3); "
                      f"using precision+IC fallback")
    if src is None:
        print(" No suitable threshold file found β€” skipping deploy")
        return False
    shutil.copy(src, DEPLOY_THRESH)
    print(f" Deployed {src.name} β†’ {DEPLOY_THRESH}")
    return True
# -- Step 4: test with known proteins --------------------------------------------
def test_proteins():
    """Run headless predictions for the known test proteins and validate them.

    Adds the webapp directory to sys.path so the prediction stack can resolve
    webapp-local assets, then runs the same model/threshold logic as server.py
    without starting an HTTP server.

    Returns:
        True when every protein passed validation; False on any error or
        forbidden term in the top predictions.
    """
    print("\n[4/5] Testing with known proteins...")
    # Import server components directly rather than spinning up HTTP.
    sys.path.insert(0, str(WEBAPP))
    try:
        # Headless prediction using the same logic as server.py.
        # (Removed unused `import importlib, types` left over from an earlier
        # approach that loaded the server module directly.)
        results = run_headless_predictions()
        return validate_results(results)
    except Exception as e:
        print(f" Headless test failed: {e}")
        import traceback; traceback.print_exc()
        return False
def run_headless_predictions():
    """Run predictions using the same model/threshold stack as the webapp.

    Loads the newest available checkpoint, embeds each TEST_PROTEINS sequence
    with ESM-2 (random embedding fallback for a smoke test), scores it with
    the residual MLP, applies per-label thresholds restricted to
    molecular_function GO terms, and returns one result dict per protein:
    {"protein", "n_preds", "top5", "expect", "expect_not"} on success or
    {"protein", "error"} on failure. Returns [] when required assets
    (checkpoint, MLB pickle) are missing.
    """
    import os, math
    import joblib
    import torch
    import torch.nn as nn
    import numpy as np
    # Load model (same priority as server.py): prefer the freshly trained
    # checkpoint, then the webapp's bundled one (either directory casing).
    ckpt_candidates = [
        CKPT_OUT, # new model if ready
        WEBAPP / "models" / "protfunc_v3.pth",
        WEBAPP / "Models" / "protfunc_v3.pth",
    ]
    ckpt_path = next((p for p in ckpt_candidates if p.exists()), None)
    if ckpt_path is None:
        print(" No checkpoint found in webapp/models; skipping test")
        return []
    # Load ESM embedder; on failure fall back to random vectors so the rest
    # of the pipeline can still be smoke-tested.
    print(f" Loading ESM-2 (this takes ~30s)...")
    try:
        import esm as esm_lib
        esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D()
        batch_converter = alphabet.get_batch_converter()
        esm_model.eval()
    except Exception as e:
        print(f" ESM load failed: {e} β€” using random embeddings for smoke test")
        esm_model = None
    # Model definition mirrors the training architecture so the state dict
    # loads cleanly.
    class ResBlock(nn.Module):
        # Residual unit: two BN -> ReLU -> Dropout -> Linear stages plus skip.
        def __init__(self, d, dr=0.2):
            super().__init__()
            self.net = nn.Sequential(nn.BatchNorm1d(d),nn.ReLU(),nn.Dropout(dr),nn.Linear(d,d),
                                     nn.BatchNorm1d(d),nn.ReLU(),nn.Dropout(dr),nn.Linear(d,d))
        def forward(self, x): return x + self.net(x)
    class MLP(nn.Module):
        # Residual MLP head; default output width 8124 is the GO label count.
        def __init__(self, i, o=8124, h=2048, n=4, dr=0.2):
            super().__init__()
            self.fc_in = nn.Linear(i, h)
            self.blocks = nn.ModuleList([ResBlock(h, dr) for _ in range(n)])
            self.fc_out = nn.Sequential(nn.BatchNorm1d(h),nn.ReLU(),nn.Dropout(dr),nn.Linear(h,o))
        def forward(self, x):
            h = self.fc_in(x)
            for b in self.blocks: h = b(h)
            return self.fc_out(h)
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state = ckpt.get("model", ckpt)
    # Infer the input width from the first layer so either checkpoint variant
    # (with or without supplementary features) loads.
    in_dim = state["fc_in.weight"].shape[1]
    model = MLP(in_dim)
    model.load_state_dict(state)
    model.eval()
    print(f" Model loaded: {ckpt_path.name}, in_dim={in_dim}")
    # Load MLB (label binarizer) + per-label thresholds + GO id -> name map.
    mlb_path = next(
        (p for p in [WEBAPP / "mlb_public_v1.pkl",
                     BASE / "Important Files" / "mlb_public_v1.pkl"] if p.exists()), None
    )
    if mlb_path is None:
        print(" MLB not found β€” skipping test"); return []
    mlb = joblib.load(mlb_path)
    with open(WEBAPP / "per_label_thresholds.json") as f:
        thresholds = {int(k): float(v) for k, v in json.load(f).items()}
    with open(WEBAPP / "go_map.json") as f:
        go_map = json.load(f)
    # Parse go-basic.obo to find molecular_function term ids; if the file is
    # missing, fall back to scoring every label.
    obo_path = WEBAPP / "go-basic.obo"
    if not obo_path.exists():
        obo_path = BASE / "go-basic.obo"
    mf_ids = set()
    if obo_path.exists():
        ns_map = {}
        with open(obo_path) as fh:
            cur_id, cur_ns, in_term = None, None, False
            for line in fh:
                line = line.strip()
                if line == "[Term]": in_term = True; continue
                if line.startswith("[") and line != "[Term]": in_term = False; continue
                if not in_term: continue
                if line.startswith("id:"): cur_id = line.split("id:",1)[1].strip().split()[0]
                elif line.startswith("namespace:"): ns_map[cur_id] = line.split("namespace:",1)[1].strip()
        mf_ids = {g for g, n in ns_map.items() if n == "molecular_function"}
    mf_idx = [j for j, c in enumerate(mlb.classes_) if c in mf_ids] if mf_ids else list(range(len(mlb.classes_)))
    # Supplementary feature columns stored in the checkpoint.
    # NOTE(review): presumably z-score stats for extra input dims produced
    # during training -- confirm against the training script.
    supp_cols = ckpt.get("supp_cols", []) if isinstance(ckpt, dict) else []
    supp_mu = np.array(ckpt["supp_mu"], dtype=np.float32) if "supp_mu" in (ckpt if isinstance(ckpt, dict) else {}) else None
    supp_sd = np.array(ckpt["supp_sd"], dtype=np.float32) if "supp_sd" in (ckpt if isinstance(ckpt, dict) else {}) else None
    results = []
    for prot in TEST_PROTEINS:
        name = prot["name"]
        seq = prot["sequence"].replace("\n", "").replace(" ", "")
        print(f" Testing {name} ({len(seq)} aa)...")
        try:
            if esm_model is not None:
                _, _, tokens = batch_converter([("p", seq)])
                with torch.no_grad():
                    rep = esm_model(tokens, repr_layers=[6])["representations"][6]
                # Mean-pool residue representations, skipping the BOS token.
                emb = rep[0, 1:len(seq)+1].mean(0)
            else:
                emb = torch.randn(320) # smoke test
            # Build the model input: base embedding plus (optionally) the
            # z-scored supplementary features and a presence flag.
            emb_np = emb.detach().cpu().numpy()
            if supp_cols and supp_mu is not None:
                feats = {}
                for c in supp_cols:
                    if c.startswith("f_Dim_"):
                        # NOTE(review): bare except silently skips columns
                        # whose suffix is not an int -- intentional best-effort.
                        try:
                            idx = int(c.split("_")[-1])
                            if idx < len(emb_np): feats[c] = float(emb_np[idx])
                        except: pass
                s_vec = np.array([feats.get(c, 0.0) for c in supp_cols], dtype=np.float32)
                s_z = (s_vec - supp_mu) / (supp_sd + 1e-12)
                flag = np.array([1.0], dtype=np.float32)
                inp_np = np.concatenate([emb_np, s_z, flag])
            else:
                inp_np = emb_np[:in_dim]
            inp = torch.tensor(inp_np).unsqueeze(0)
            with torch.no_grad():
                prob = torch.sigmoid(model(inp)).squeeze()
            # Apply per-label thresholds (MF terms only; default 0.5).
            raw_preds = []
            for i in mf_idx:
                pv = float(prob[i])
                if pv >= thresholds.get(i, 0.5):
                    raw_preds.append({
                        "go_id": mlb.classes_[i],
                        "name": go_map.get(mlb.classes_[i], mlb.classes_[i]),
                        "prob": round(pv, 4),
                    })
            raw_preds.sort(key=lambda x: -x["prob"])
            # Keep between 3 and 20 predictions (never fewer than available).
            cap = min(20, max(3, len(raw_preds)))
            raw_preds = raw_preds[:cap]
            results.append({
                "protein": name,
                "n_preds": len(raw_preds),
                "top5": raw_preds[:5],
                "expect": prot["expect_contains"],
                "expect_not": prot["expect_not_contains"],
            })
        except Exception as e:
            # Per-protein failures are recorded, not fatal to the run.
            results.append({"protein": name, "error": str(e)})
    return results
def validate_results(results):
    """Print a per-protein validation report and return the overall verdict.

    A result fails when it carries an "error" key or when any forbidden term
    from expect_not appears among the top-5 prediction names. Expected-term
    hits are reported but a miss alone does not fail the run.
    """
    print("\n Validation:")
    all_ok = True
    for entry in results:
        if "error" in entry:
            print(f" βœ— {entry['protein']}: ERROR β€” {entry['error']}")
            all_ok = False
            continue
        joined_names = " ".join(pred["name"].lower() for pred in entry["top5"])
        probs = [pred["prob"] for pred in entry["top5"]]
        count = entry["n_preds"]
        # Substring match of expected / forbidden terms against top-5 names.
        hits = [term for term in entry["expect"] if term.lower() in joined_names]
        misses = [term for term in entry["expect_not"] if term.lower() in joined_names]
        if misses:
            all_ok = False
        status = "βœ“" if (hits or count <= 15) and not misses else "~"
        print(f" {status} {entry['protein']}: {count} preds | top probs {[round(p,3) for p in probs[:3]]}")
        print(f" top: {[pred['name'] for pred in entry['top5'][:3]]}")
        if hits:
            print(f" hits: {hits}")
        if misses:
            print(f" FALSE POSITIVES: {misses}")
    return all_ok
# -- Step 5: push to HuggingFace Space -------------------------------------------
def push_to_hf(dry_run=False):
    """Cherry-pick the latest webapp commit onto the HF Space branch and push.

    With dry_run=True, only prints a notice and returns True. Returns False
    when any unexpected exception escapes the git sequence.
    """
    print("\n[5/5] Pushing to HuggingFace Space...")
    if dry_run:
        print(" --dry-run: skipping push")
        return True
    try:
        import subprocess as sp
        # Identify the latest commit on the webapp's main branch.
        head = sp.run(["git", "log", "--oneline", "-1"], cwd=WEBAPP, capture_output=True, text=True)
        latest_sha = head.stdout.split()[0]
        print(f" Latest commit: {head.stdout.strip()}")
        # Cherry-pick onto space/main via a throwaway branch.
        sequence = (
            ["git", "stash"],
            ["git", "checkout", "-b", "space-postrain", "space/main"],
            ["git", "cherry-pick", latest_sha],
            ["git", "push", "space", "space-postrain:main"],
            ["git", "checkout", "main"],
            ["git", "stash", "pop"],
            ["git", "branch", "-d", "space-postrain"],
        )
        for cmd in sequence:
            result = sp.run(cmd, cwd=WEBAPP, capture_output=True, text=True)
            if result.returncode != 0 and "cherry-pick" in " ".join(cmd):
                # Commit may already exist on the space branch β€” skip it.
                sp.run(["git", "cherry-pick", "--skip"], cwd=WEBAPP)
            print(f" {' '.join(cmd[-2:])}: {'ok' if result.returncode == 0 else result.stderr[:80]}")
        return True
    except Exception as e:
        print(f" Push failed: {e}")
        return False
# -- Main -------------------------------------------------------------------------
def main():
    """Entry point: run the 5-step post-training pipeline and save a summary."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--skip-wait", action="store_true", help="Don't wait for training")
    parser.add_argument("--skip-test", action="store_true", help="Skip protein validation")
    parser.add_argument("--dry-run", action="store_true", help="Skip HF push")
    parser.add_argument("--skip-threshold-comparison", action="store_true")
    args = parser.parse_args()
    started = time.time()
    summary = {"steps": {}, "start_time": time.strftime("%Y-%m-%dT%H:%M:%S")}
    # Step 1: block until training finishes (unless told otherwise).
    if args.skip_wait:
        print("[1/5] Skipping wait (--skip-wait)")
    else:
        trained = wait_for_training()
        summary["steps"]["wait_training"] = "ok" if trained else "failed"
        if not trained:
            print("Training failed or timed out. Aborting.")
            _save_summary(summary)
            return
    # Step 2: threshold comparison on the fresh checkpoint.
    if args.skip_threshold_comparison:
        print("[2/5] Skipping threshold comparison")
    else:
        summary["steps"]["threshold_comparison"] = "ok" if run_threshold_comparison() else "skipped"
    # Step 3: deploy thresholds to the webapp.
    summary["steps"]["deploy_thresholds"] = "ok" if deploy_thresholds() else "failed"
    # Step 4: sanity-check predictions on known proteins.
    if args.skip_test:
        print("[4/5] Skipping protein test (--skip-test)")
    else:
        summary["steps"]["protein_test"] = "passed" if test_proteins() else "warnings"
    # Step 5: push to the HuggingFace Space.
    summary["steps"]["hf_push"] = "ok" if push_to_hf(dry_run=args.dry_run) else "failed"
    summary["elapsed_minutes"] = round((time.time() - started) / 60, 1)
    _save_summary(summary)
    print(f"\nPipeline complete in {summary['elapsed_minutes']} min.")
    print(f"Summary: {SUMMARY_OUT}")
def _save_summary(summary):
    """Persist *summary* as pretty-printed JSON at SUMMARY_OUT."""
    GRAPH_ART.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(summary, indent=2)
    with open(SUMMARY_OUT, "w") as fh:
        fh.write(payload)
# Script entry point.
if __name__ == "__main__":
    main()