""" post_train_pipeline.py ====================== Automated post-training pipeline. Runs after graph_hpo_best training completes: 1. Wait for training to finish (polls train.log) 2. Re-run threshold_comparison.py on the new checkpoint 3. Copy best thresholds to webapp per_label_thresholds.json 4. Test the local server with 5 known proteins and validate output 5. Push to HuggingFace Space 6. Write summary to artifacts/graph_hpo/pipeline_summary.json Usage: python scripts/post_train_pipeline.py python scripts/post_train_pipeline.py --skip-wait # if training already done python scripts/post_train_pipeline.py --dry-run # skip HF push """ import argparse import json import math import shlex import subprocess import sys import time from pathlib import Path BASE = Path(__file__).parent.parent ART = BASE / "artifacts" GRAPH_ART = ART / "graph_hpo" WEBAPP = Path("/Users/siddhantbhat/insecta_webapp") SCRIPTS = BASE / "scripts" TRAIN_LOG = GRAPH_ART / "train.log" CKPT_OUT = GRAPH_ART / "graph_hpo_best.pth" THRESH_OUT = GRAPH_ART / "graph_hpo_best_thresholds.json" LOG_OUT = GRAPH_ART / "graph_hpo_best_log.json" COMPARISON = ART / "threshold_comparison_results.json" PREC_THRESH = ART / "thresholds" / "precision_ic_thresholds.json" SUMMARY_OUT = GRAPH_ART / "pipeline_summary.json" DEPLOY_THRESH = WEBAPP / "per_label_thresholds.json" POLL_INTERVAL = 60 # seconds between log polls # ── Test proteins (known functions for sanity checking) ─────────────────────── # Each entry: name, FASTA sequence, expected GO terms that must appear in top-10 TEST_PROTEINS = [ { "name": "Drosophila_Hippo_kinase", "expect_contains": ["kinase", "ATP", "phospho"], "expect_not_contains": ["chorion", "neurotransmitter", "olfactory"], "sequence": ( "MSRSSGSSGSAATPVGKRSGQNLSTSIGKGDSFSIRSRSVASTSSSGLNNSGSSTTLNRSN" "TSSSSVNTSSSQNRSSTLSTSSANNVTSSSSTTMQDNQFLSSSQLEKIQRELEQTLKQLNR" "QQAELQRQLSQSQSQSQSESQSESQSMSSRSTPVAIPPTQAPPSQSSQSSQSSQSSQSAQD" "AAAPLNSSSSSSSSSSSSSQQQQLQEQQQQQQQQHQQQHQLHQQHQHHQQHQQQQPQQQPQ" "QQPQQQPQQQPQAPQAQPQPQPQPQQQQPQQQQPQQQQPQQNPQSQNPQSQNPQSQNAQAQ" "AQGQAQGQAQGQAQGPQAQAPQSSQGGMGGNNMGGSGMGGSGMGGSGMGNSGMGSSIGGIT" ), }, { "name": "Opsin_Rh6", "expect_contains": ["photoreceptor", "retinal", "light", "G protein", "receptor"], "expect_not_contains": ["kinase", "DNA", "ribosom"], "sequence": ( "MSSENGDLQFLNGTEGPNFYVPMSNATGVVRSPAEYNQTRESPIFTYTNSNSTRGPFEGPN" "YHIAPRWVYHLTSVWMIFVVTASVFTNGLVLAATMKFKKLRHPLNWILVNLAVADLAETVIASTISVVNQF" "FGYFILGHPMCVLEGYTVSACGITALWSLAIISWERWLVVCKPFGNIKTPLAQRLIAAIWLFSVWIGVPFN" "IPEGPQNPVSTFKAMLACNPWVLTPLFKTDPESAGYSIFNIYIFLCHFFIPMAVIVFSCYGNIVMTLHSH" "TKKEALMQFPQTPQQEASSTTVSKTETSQVAPA" ), }, { "name": "Human_TP53", "expect_contains": ["DNA binding", "transcription", "p53"], "expect_not_contains": ["kinase activity", "photoreceptor"], "sequence": ( "MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGP" "DEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYPQGLNGTVNLFRNLNKALK" "SLSERNSTAHSQAPQLEQPQSSAPISGKQPQSRSAPLEHHLQKSQAPLPNASPRPPMSSNI" "SQGSSQSKASGSTMGPILSSSSDDIEQWFTEDPSTSEELNEALELKDAQAGKESLHSQATLE" "STAHPLTSESVQISAVQILANAQREALESLPTKMPVQPEAEELDNFYINQQQTNSASSLSSS" "RAENTLPDPHYREVGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGQM" "NRRPILTIITLEDSSGKLLGRNSFEVRVCACPGRDRRTEEENLRKKGQVLKEIREGQRLSDP" "NTCQKHKKLSELLGQSSFEV" ), }, { "name": "Ubiquitin_Drosophila", "expect_contains": ["ubiquitin", "protein binding"], "expect_not_contains": ["kinase", "photoreceptor"], "sequence": "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG", }, { "name": "ATP_synthase_beta", "expect_contains": ["ATP", "hydrolase", "synthase"], "expect_not_contains": ["photoreceptor", "DNA binding", "olfactory"], "sequence": ( "MATLRILNIGDQMDVGLNAKSSISRLVQTLPQRKVLISAQDMLQRIREKAEMLGDVPIMVD" "STSRQLDAEGQVVLSGTSIAIDAMIPIGRGQRELIIGDRQTGKTAVAIDTIINQKGSITSVQ" "AIYVPADDLTDPAPATTFAHLDATTVLSRAIAKGIYPAVDPLDSTSTMLQPRIVGQEHYDTA" "RGVQKILQDYKSLQDIIAILGMDELSEEDKLTVSRARLAEHSSKARELKQLVAERQEIVAKA" "HARAQKIPQAVAAEREALGSDPQIALTASDSYLKQLPQLTDQLAQLMKKQSQQELAQLMQQP" ), }, ] def py() -> str: venv = BASE / ".venv" / "bin" / "python3" return str(venv if venv.exists() else Path("python3")) def run(cmd, check=True, capture=False): print(f"$ {shlex.join(cmd)}") if capture: r = subprocess.run(cmd, capture_output=True, text=True, cwd=BASE) return r subprocess.run(cmd, cwd=BASE, check=check) # ── Step 1: wait for training ───────────────────────────────────────────────── def wait_for_training(): print("\n[1/5] Waiting for training to complete...") if not TRAIN_LOG.exists(): print(f" train.log not found at {TRAIN_LOG}; assuming training not started") return False last_size = 0 stall_count = 0 while True: size = TRAIN_LOG.stat().st_size content = TRAIN_LOG.read_text(errors="replace") if "Test evaluation" in content or "[6/6] Done" in content: print(" Training complete.") print_last_epoch(content) return True if "Traceback" in content or "Error" in content.split("Traceback")[0]: # Check if error is in the last 2000 chars (recent) recent = content[-2000:] if "Traceback" in recent or ("Error" in recent and "Warning" not in recent): print(f" ERROR detected in training log:") for line in content.splitlines()[-20:]: print(f" {line}") return False if size == last_size: stall_count += 1 if stall_count > 30: # 30 min stall print(" Training appears stalled. Checking process...") r = subprocess.run(["pgrep", "-f", "train_v3_fixed"], capture_output=True, text=True) if not r.stdout.strip(): print(" Training process not running. Checking log for completion...") if CKPT_OUT.exists(): print(" Checkpoint exists — treating as complete.") return True return False else: stall_count = 0 last_size = size print_last_epoch(content) time.sleep(POLL_INTERVAL) def print_last_epoch(content): lines = content.splitlines() epoch_lines = [l for l in lines if "micro-fmax" in l or "Epoch" in l] if epoch_lines: print(f" Latest: {epoch_lines[-1].strip()}") elif lines: print(f" Latest: {lines[-1].strip()}") # ── Step 2: re-run threshold comparison on new model ───────────────────────── def run_threshold_comparison(): print("\n[2/5] Running threshold comparison on new checkpoint...") if not CKPT_OUT.exists(): print(f" Checkpoint not found at {CKPT_OUT} — skipping comparison") return False # Patch comparison script to use the new checkpoint comp_script = SCRIPTS / "threshold_comparison.py" original = comp_script.read_text() # Temporarily redirect to new checkpoint and thresholds patched = original.replace( 'CKPT_PATH = ART / "checkpoints" / "protfunc_v3.pth"', f'CKPT_PATH = Path("{CKPT_OUT}")', ).replace( 'CURRENT_THRESH = ART / "thresholds" / "protfunc_v3_thresholds.json"', f'CURRENT_THRESH = Path("{THRESH_OUT}")', ) tmp_script = GRAPH_ART / "threshold_comparison_newmodel.py" tmp_script.write_text(patched) r = run([py(), str(tmp_script)], check=False, capture=True) print(r.stdout[-3000:] if r.stdout else "(no stdout)") if r.returncode != 0: print(f" STDERR: {r.stderr[-1000:]}") return False # Read new results new_results_path = ART / "threshold_comparison_results.json" if new_results_path.exists(): with open(new_results_path) as f: res = json.load(f) m = res.get("metrics", {}) print(f"\n New model comparison:") for strat, metrics in m.items(): if isinstance(metrics, dict) and "micro_f1" in metrics: print(f" {strat}: P={metrics['micro_precision']:.4f} " f"R={metrics['micro_recall']:.4f} F1={metrics['micro_f1']:.4f} " f"avg_preds={metrics['mean_preds_per_protein']:.1f}") return True # ── Step 3: deploy thresholds to webapp ─────────────────────────────────────── def deploy_thresholds(): print("\n[3/5] Deploying thresholds to webapp...") # Prefer new model's precision+IC thresholds; fallback to existing src = PREC_THRESH if PREC_THRESH.exists() else None if THRESH_OUT.exists(): # Use new model's thresholds but only if they look reasonable with open(THRESH_OUT) as f: t = json.load(f) vals = list(t.values()) mean_t = sum(vals) / len(vals) n_low = sum(1 for v in vals if v < 0.3) print(f" New model thresholds: N={len(vals)}, mean={mean_t:.3f}, <0.3: {n_low}") if mean_t >= 0.5 and n_low < 300: src = THRESH_OUT print(f" Using new model thresholds (mean {mean_t:.3f})") else: print(f" New thresholds look poor (mean {mean_t:.3f}, {n_low} < 0.3); " f"using precision+IC fallback") if src is None: print(" No suitable threshold file found — skipping deploy") return False import shutil shutil.copy(src, DEPLOY_THRESH) print(f" Deployed {src.name} → {DEPLOY_THRESH}") return True # ── Step 4: test with known proteins ───────────────────────────────────────── def test_proteins(): print("\n[4/5] Testing with known proteins...") # Import server components directly rather than spinning up HTTP sys.path.insert(0, str(WEBAPP)) try: import importlib, types # Load server module without running startup (we need model + predict logic) # Instead run a headless prediction using the same logic as server.py results = run_headless_predictions() return validate_results(results) except Exception as e: print(f" Headless test failed: {e}") import traceback; traceback.print_exc() return False def run_headless_predictions(): """Run predictions using the same model/threshold stack as the webapp.""" import os, math import joblib import torch import torch.nn as nn import numpy as np # Load model (same priority as server.py) ckpt_candidates = [ CKPT_OUT, # new model if ready WEBAPP / "models" / "protfunc_v3.pth", WEBAPP / "Models" / "protfunc_v3.pth", ] ckpt_path = next((p for p in ckpt_candidates if p.exists()), None) if ckpt_path is None: print(" No checkpoint found in webapp/models; skipping test") return [] # Load ESM print(f" Loading ESM-2 (this takes ~30s)...") try: import esm as esm_lib esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D() batch_converter = alphabet.get_batch_converter() esm_model.eval() except Exception as e: print(f" ESM load failed: {e} — using random embeddings for smoke test") esm_model = None # Load model class ResBlock(nn.Module): def __init__(self, d, dr=0.2): super().__init__() self.net = nn.Sequential(nn.BatchNorm1d(d),nn.ReLU(),nn.Dropout(dr),nn.Linear(d,d), nn.BatchNorm1d(d),nn.ReLU(),nn.Dropout(dr),nn.Linear(d,d)) def forward(self, x): return x + self.net(x) class MLP(nn.Module): def __init__(self, i, o=8124, h=2048, n=4, dr=0.2): super().__init__() self.fc_in = nn.Linear(i, h) self.blocks = nn.ModuleList([ResBlock(h, dr) for _ in range(n)]) self.fc_out = nn.Sequential(nn.BatchNorm1d(h),nn.ReLU(),nn.Dropout(dr),nn.Linear(h,o)) def forward(self, x): h = self.fc_in(x) for b in self.blocks: h = b(h) return self.fc_out(h) ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) state = ckpt.get("model", ckpt) in_dim = state["fc_in.weight"].shape[1] model = MLP(in_dim) model.load_state_dict(state) model.eval() print(f" Model loaded: {ckpt_path.name}, in_dim={in_dim}") # Load MLB + thresholds + go_map mlb_path = next( (p for p in [WEBAPP / "mlb_public_v1.pkl", BASE / "Important Files" / "mlb_public_v1.pkl"] if p.exists()), None ) if mlb_path is None: print(" MLB not found — skipping test"); return [] mlb = joblib.load(mlb_path) with open(WEBAPP / "per_label_thresholds.json") as f: thresholds = {int(k): float(v) for k, v in json.load(f).items()} with open(WEBAPP / "go_map.json") as f: go_map = json.load(f) # Get MF indices from GO hierarchy obo_path = WEBAPP / "go-basic.obo" if not obo_path.exists(): obo_path = BASE / "go-basic.obo" mf_ids = set() if obo_path.exists(): ns_map = {} with open(obo_path) as fh: cur_id, cur_ns, in_term = None, None, False for line in fh: line = line.strip() if line == "[Term]": in_term = True; continue if line.startswith("[") and line != "[Term]": in_term = False; continue if not in_term: continue if line.startswith("id:"): cur_id = line.split("id:",1)[1].strip().split()[0] elif line.startswith("namespace:"): ns_map[cur_id] = line.split("namespace:",1)[1].strip() mf_ids = {g for g, n in ns_map.items() if n == "molecular_function"} mf_idx = [j for j, c in enumerate(mlb.classes_) if c in mf_ids] if mf_ids else list(range(len(mlb.classes_))) # Supp cols handling supp_cols = ckpt.get("supp_cols", []) if isinstance(ckpt, dict) else [] supp_mu = np.array(ckpt["supp_mu"], dtype=np.float32) if "supp_mu" in (ckpt if isinstance(ckpt, dict) else {}) else None supp_sd = np.array(ckpt["supp_sd"], dtype=np.float32) if "supp_sd" in (ckpt if isinstance(ckpt, dict) else {}) else None results = [] for prot in TEST_PROTEINS: name = prot["name"] seq = prot["sequence"].replace("\n", "").replace(" ", "") print(f" Testing {name} ({len(seq)} aa)...") try: if esm_model is not None: _, _, tokens = batch_converter([("p", seq)]) with torch.no_grad(): rep = esm_model(tokens, repr_layers=[6])["representations"][6] emb = rep[0, 1:len(seq)+1].mean(0) else: emb = torch.randn(320) # smoke test # Build model input emb_np = emb.detach().cpu().numpy() if supp_cols and supp_mu is not None: feats = {} for c in supp_cols: if c.startswith("f_Dim_"): try: idx = int(c.split("_")[-1]) if idx < len(emb_np): feats[c] = float(emb_np[idx]) except: pass s_vec = np.array([feats.get(c, 0.0) for c in supp_cols], dtype=np.float32) s_z = (s_vec - supp_mu) / (supp_sd + 1e-12) flag = np.array([1.0], dtype=np.float32) inp_np = np.concatenate([emb_np, s_z, flag]) else: inp_np = emb_np[:in_dim] inp = torch.tensor(inp_np).unsqueeze(0) with torch.no_grad(): prob = torch.sigmoid(model(inp)).squeeze() # Apply thresholds (MF only) raw_preds = [] for i in mf_idx: pv = float(prob[i]) if pv >= thresholds.get(i, 0.5): raw_preds.append({ "go_id": mlb.classes_[i], "name": go_map.get(mlb.classes_[i], mlb.classes_[i]), "prob": round(pv, 4), }) raw_preds.sort(key=lambda x: -x["prob"]) cap = min(20, max(3, len(raw_preds))) raw_preds = raw_preds[:cap] results.append({ "protein": name, "n_preds": len(raw_preds), "top5": raw_preds[:5], "expect": prot["expect_contains"], "expect_not": prot["expect_not_contains"], }) except Exception as e: results.append({"protein": name, "error": str(e)}) return results def validate_results(results): print("\n Validation:") all_ok = True for r in results: if "error" in r: print(f" ✗ {r['protein']}: ERROR — {r['error']}") all_ok = False continue top_names = " ".join(p["name"].lower() for p in r["top5"]) top_probs = [p["prob"] for p in r["top5"]] n = r["n_preds"] # Check expected terms appear somewhere in top predictions hits = [e for e in r["expect"] if e.lower() in top_names] misses = [e for e in r["expect_not"] if e.lower() in top_names] status = "✓" if (hits or n <= 15) and not misses else "~" if misses: all_ok = False print(f" {status} {r['protein']}: {n} preds | top probs {[round(p,3) for p in top_probs[:3]]}") print(f" top: {[p['name'] for p in r['top5'][:3]]}") if hits: print(f" hits: {hits}") if misses: print(f" FALSE POSITIVES: {misses}") return all_ok # ── Step 5: push to HuggingFace Space ───────────────────────────────────────── def push_to_hf(dry_run=False): print("\n[5/5] Pushing to HuggingFace Space...") if dry_run: print(" --dry-run: skipping push") return True try: import subprocess as sp # Get latest commit on webapp main r = sp.run(["git", "log", "--oneline", "-1"], cwd=WEBAPP, capture_output=True, text=True) latest_sha = r.stdout.split()[0] print(f" Latest commit: {r.stdout.strip()}") # Cherry-pick onto space/main cmds = [ ["git", "stash"], ["git", "checkout", "-b", "space-postrain", "space/main"], ["git", "cherry-pick", latest_sha], ["git", "push", "space", "space-postrain:main"], ["git", "checkout", "main"], ["git", "stash", "pop"], ["git", "branch", "-d", "space-postrain"], ] for cmd in cmds: result = sp.run(cmd, cwd=WEBAPP, capture_output=True, text=True) if result.returncode != 0 and "cherry-pick" in " ".join(cmd): # May already be on space — try skip sp.run(["git", "cherry-pick", "--skip"], cwd=WEBAPP) print(f" {' '.join(cmd[-2:])}: {'ok' if result.returncode == 0 else result.stderr[:80]}") return True except Exception as e: print(f" Push failed: {e}") return False # ── Main ────────────────────────────────────────────────────────────────────── def main(): p = argparse.ArgumentParser() p.add_argument("--skip-wait", action="store_true", help="Don't wait for training") p.add_argument("--skip-test", action="store_true", help="Skip protein validation") p.add_argument("--dry-run", action="store_true", help="Skip HF push") p.add_argument("--skip-threshold-comparison", action="store_true") args = p.parse_args() t0 = time.time() summary = {"steps": {}, "start_time": time.strftime("%Y-%m-%dT%H:%M:%S")} # Step 1 if not args.skip_wait: ok = wait_for_training() summary["steps"]["wait_training"] = "ok" if ok else "failed" if not ok: print("Training failed or timed out. Aborting.") _save_summary(summary); return else: print("[1/5] Skipping wait (--skip-wait)") # Step 2 if not args.skip_threshold_comparison: ok = run_threshold_comparison() summary["steps"]["threshold_comparison"] = "ok" if ok else "skipped" else: print("[2/5] Skipping threshold comparison") # Step 3 ok = deploy_thresholds() summary["steps"]["deploy_thresholds"] = "ok" if ok else "failed" # Step 4 if not args.skip_test: ok = test_proteins() summary["steps"]["protein_test"] = "passed" if ok else "warnings" else: print("[4/5] Skipping protein test (--skip-test)") # Step 5 ok = push_to_hf(dry_run=args.dry_run) summary["steps"]["hf_push"] = "ok" if ok else "failed" summary["elapsed_minutes"] = round((time.time() - t0) / 60, 1) _save_summary(summary) print(f"\nPipeline complete in {summary['elapsed_minutes']} min.") print(f"Summary: {SUMMARY_OUT}") def _save_summary(summary): GRAPH_ART.mkdir(parents=True, exist_ok=True) with open(SUMMARY_OUT, "w") as f: json.dump(summary, f, indent=2) if __name__ == "__main__": main()