| """ |
| post_train_pipeline.py |
| ====================== |
| Automated post-training pipeline. Runs after graph_hpo_best training completes: |
| |
| 1. Wait for training to finish (polls train.log) |
| 2. Re-run threshold_comparison.py on the new checkpoint |
| 3. Copy best thresholds to webapp per_label_thresholds.json |
| 4. Test the local server with 5 known proteins and validate output |
| 5. Push to HuggingFace Space |
| 6. Write summary to artifacts/graph_hpo/pipeline_summary.json |
| |
| Usage: |
| python scripts/post_train_pipeline.py |
| python scripts/post_train_pipeline.py --skip-wait # if training already done |
| python scripts/post_train_pipeline.py --dry-run # skip HF push |
| """ |
|
|
| import argparse |
| import json |
| import math |
| import shlex |
| import subprocess |
| import sys |
| import time |
| from pathlib import Path |
|
|
# --- Repository layout -------------------------------------------------------
BASE = Path(__file__).parent.parent  # repo root (this script lives in scripts/)
ART = BASE / "artifacts"
GRAPH_ART = ART / "graph_hpo"
# NOTE(review): hard-coded absolute path to the webapp checkout — confirm this
# matches the machine the pipeline actually runs on.
WEBAPP = Path("/Users/siddhantbhat/insecta_webapp")
SCRIPTS = BASE / "scripts"


# --- Pipeline inputs / outputs -----------------------------------------------
TRAIN_LOG = GRAPH_ART / "train.log"                   # polled by wait_for_training()
CKPT_OUT = GRAPH_ART / "graph_hpo_best.pth"           # checkpoint written by training
THRESH_OUT = GRAPH_ART / "graph_hpo_best_thresholds.json"  # new model's per-label thresholds
LOG_OUT = GRAPH_ART / "graph_hpo_best_log.json"
COMPARISON = ART / "threshold_comparison_results.json"
PREC_THRESH = ART / "thresholds" / "precision_ic_thresholds.json"  # fallback thresholds
SUMMARY_OUT = GRAPH_ART / "pipeline_summary.json"     # written by _save_summary()
DEPLOY_THRESH = WEBAPP / "per_label_thresholds.json"  # file the webapp serves


POLL_INTERVAL = 60  # seconds between train.log polls
|
|
|
|
| |
| |
# Known proteins used as a smoke/validation set after deployment.
# Each entry: a display name, substrings that SHOULD appear among the top
# predicted GO-term names ("expect_contains"), substrings that should NOT
# ("expect_not_contains"), and the amino-acid sequence fed to the model.
TEST_PROTEINS = [
    {
        "name": "Drosophila_Hippo_kinase",
        "expect_contains": ["kinase", "ATP", "phospho"],
        "expect_not_contains": ["chorion", "neurotransmitter", "olfactory"],
        "sequence": (
            "MSRSSGSSGSAATPVGKRSGQNLSTSIGKGDSFSIRSRSVASTSSSGLNNSGSSTTLNRSN"
            "TSSSSVNTSSSQNRSSTLSTSSANNVTSSSSTTMQDNQFLSSSQLEKIQRELEQTLKQLNR"
            "QQAELQRQLSQSQSQSQSESQSESQSMSSRSTPVAIPPTQAPPSQSSQSSQSSQSSQSAQD"
            "AAAPLNSSSSSSSSSSSSSQQQQLQEQQQQQQQQHQQQHQLHQQHQHHQQHQQQQPQQQPQ"
            "QQPQQQPQQQPQAPQAQPQPQPQPQQQQPQQQQPQQQQPQQNPQSQNPQSQNPQSQNAQAQ"
            "AQGQAQGQAQGQAQGPQAQAPQSSQGGMGGNNMGGSGMGGSGMGGSGMGNSGMGSSIGGIT"
        ),
    },
    {
        "name": "Opsin_Rh6",
        "expect_contains": ["photoreceptor", "retinal", "light", "G protein", "receptor"],
        "expect_not_contains": ["kinase", "DNA", "ribosom"],
        "sequence": (
            "MSSENGDLQFLNGTEGPNFYVPMSNATGVVRSPAEYNQTRESPIFTYTNSNSTRGPFEGPN"
            "YHIAPRWVYHLTSVWMIFVVTASVFTNGLVLAATMKFKKLRHPLNWILVNLAVADLAETVIASTISVVNQF"
            "FGYFILGHPMCVLEGYTVSACGITALWSLAIISWERWLVVCKPFGNIKTPLAQRLIAAIWLFSVWIGVPFN"
            "IPEGPQNPVSTFKAMLACNPWVLTPLFKTDPESAGYSIFNIYIFLCHFFIPMAVIVFSCYGNIVMTLHSH"
            "TKKEALMQFPQTPQQEASSTTVSKTETSQVAPA"
        ),
    },
    {
        "name": "Human_TP53",
        "expect_contains": ["DNA binding", "transcription", "p53"],
        "expect_not_contains": ["kinase activity", "photoreceptor"],
        "sequence": (
            "MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGP"
            "DEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYPQGLNGTVNLFRNLNKALK"
            "SLSERNSTAHSQAPQLEQPQSSAPISGKQPQSRSAPLEHHLQKSQAPLPNASPRPPMSSNI"
            "SQGSSQSKASGSTMGPILSSSSDDIEQWFTEDPSTSEELNEALELKDAQAGKESLHSQATLE"
            "STAHPLTSESVQISAVQILANAQREALESLPTKMPVQPEAEELDNFYINQQQTNSASSLSSS"
            "RAENTLPDPHYREVGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGQM"
            "NRRPILTIITLEDSSGKLLGRNSFEVRVCACPGRDRRTEEENLRKKGQVLKEIREGQRLSDP"
            "NTCQKHKKLSELLGQSSFEV"
        ),
    },
    {
        "name": "Ubiquitin_Drosophila",
        "expect_contains": ["ubiquitin", "protein binding"],
        "expect_not_contains": ["kinase", "photoreceptor"],
        "sequence": "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG",
    },
    {
        "name": "ATP_synthase_beta",
        "expect_contains": ["ATP", "hydrolase", "synthase"],
        "expect_not_contains": ["photoreceptor", "DNA binding", "olfactory"],
        "sequence": (
            "MATLRILNIGDQMDVGLNAKSSISRLVQTLPQRKVLISAQDMLQRIREKAEMLGDVPIMVD"
            "STSRQLDAEGQVVLSGTSIAIDAMIPIGRGQRELIIGDRQTGKTAVAIDTIINQKGSITSVQ"
            "AIYVPADDLTDPAPATTFAHLDATTVLSRAIAKGIYPAVDPLDSTSTMLQPRIVGQEHYDTA"
            "RGVQKILQDYKSLQDIIAILGMDELSEEDKLTVSRARLAEHSSKARELKQLVAERQEIVAKA"
            "HARAQKIPQAVAAEREALGSDPQIALTASDSYLKQLPQLTDQLAQLMKKQSQQELAQLMQQP"
        ),
    },
]
|
|
|
|
def py() -> str:
    """Return the interpreter to run subprocess scripts with.

    Prefers the project virtualenv's python3; falls back to whatever
    ``python3`` resolves to on PATH.
    """
    candidate = BASE / ".venv" / "bin" / "python3"
    if candidate.exists():
        return str(candidate)
    return "python3"
|
|
|
|
def run(cmd, check=True, capture=False):
    """Echo *cmd* and run it from the repo root.

    Args:
        cmd: argv list (never a shell string).
        check: raise CalledProcessError on non-zero exit (non-capture mode
            only; capturing callers inspect ``returncode`` themselves).
        capture: capture stdout/stderr as text instead of inheriting them.

    Returns:
        The ``subprocess.CompletedProcess``. (FIX: previously the non-capture
        branch fell off the end and returned None; now both branches return
        the result, which is backward-compatible since callers discarded it.)
    """
    print(f"$ {shlex.join(cmd)}")
    if capture:
        return subprocess.run(cmd, capture_output=True, text=True, cwd=BASE)
    return subprocess.run(cmd, cwd=BASE, check=check)
|
|
|
|
| |
|
|
def wait_for_training():
    """Poll train.log until training completes, errors, or stalls.

    Returns True on apparent success (completion marker seen, or the process
    is gone but the checkpoint exists), False on error/absence/stall.
    NOTE(review): there is no overall timeout — a log that keeps growing
    without ever hitting a completion marker loops forever.
    """
    print("\n[1/5] Waiting for training to complete...")
    if not TRAIN_LOG.exists():
        print(f"  train.log not found at {TRAIN_LOG}; assuming training not started")
        return False

    last_size = 0   # log size at the previous poll, for stall detection
    stall_count = 0  # consecutive polls with no log growth
    while True:
        size = TRAIN_LOG.stat().st_size
        content = TRAIN_LOG.read_text(errors="replace")

        # Completion markers emitted by the training script.
        if "Test evaluation" in content or "[6/6] Done" in content:
            print("  Training complete.")
            print_last_epoch(content)
            return True

        # Cheap pre-filter, then confirm the failure is recent (last 2 KB)
        # so an old, recovered error doesn't abort the wait.
        if "Traceback" in content or "Error" in content.split("Traceback")[0]:
            recent = content[-2000:]
            if "Traceback" in recent or ("Error" in recent and "Warning" not in recent):
                print(f"  ERROR detected in training log:")
                for line in content.splitlines()[-20:]:
                    print(f"    {line}")
                return False

        if size == last_size:
            stall_count += 1
            # ~30 polls (30 * POLL_INTERVAL seconds) with no growth: check
            # whether the training process is even alive.
            if stall_count > 30:
                print("  Training appears stalled. Checking process...")
                r = subprocess.run(["pgrep", "-f", "train_v3_fixed"], capture_output=True, text=True)
                if not r.stdout.strip():
                    print("  Training process not running. Checking log for completion...")
                    if CKPT_OUT.exists():
                        print("  Checkpoint exists β treating as complete.")
                        return True
                    return False
        else:
            stall_count = 0
            last_size = size
            print_last_epoch(content)

        time.sleep(POLL_INTERVAL)
|
|
|
|
def print_last_epoch(content):
    """Print the most recent progress line found in the training-log text."""
    log_lines = content.splitlines()
    progress = [ln for ln in log_lines if "micro-fmax" in ln or "Epoch" in ln]
    if progress:
        print(f"  Latest: {progress[-1].strip()}")
        return
    if log_lines:
        print(f"  Latest: {log_lines[-1].strip()}")
|
|
|
|
| |
|
|
def run_threshold_comparison():
    """Re-run scripts/threshold_comparison.py against the new checkpoint.

    The stock comparison script is hard-wired to the v3 checkpoint, so this
    writes a textually-patched copy under artifacts/graph_hpo pointing at
    CKPT_OUT/THRESH_OUT and runs it in a subprocess. Returns True when the
    comparison ran successfully.
    """
    print("\n[2/5] Running threshold comparison on new checkpoint...")
    if not CKPT_OUT.exists():
        print(f"  Checkpoint not found at {CKPT_OUT} β skipping comparison")
        return False

    comp_script = SCRIPTS / "threshold_comparison.py"
    original = comp_script.read_text()

    # NOTE(review): string-level patch relies on these exact source lines
    # existing in threshold_comparison.py; if they change, the replace is a
    # silent no-op and the OLD checkpoint gets compared instead.
    patched = original.replace(
        'CKPT_PATH = ART / "checkpoints" / "protfunc_v3.pth"',
        f'CKPT_PATH = Path("{CKPT_OUT}")',
    ).replace(
        'CURRENT_THRESH = ART / "thresholds" / "protfunc_v3_thresholds.json"',
        f'CURRENT_THRESH = Path("{THRESH_OUT}")',
    )

    tmp_script = GRAPH_ART / "threshold_comparison_newmodel.py"
    tmp_script.write_text(patched)

    # check=False: we inspect returncode ourselves below.
    r = run([py(), str(tmp_script)], check=False, capture=True)
    print(r.stdout[-3000:] if r.stdout else "(no stdout)")
    if r.returncode != 0:
        print(f"  STDERR: {r.stderr[-1000:]}")
        return False

    # Summarize the per-strategy metrics the comparison script wrote.
    new_results_path = ART / "threshold_comparison_results.json"
    if new_results_path.exists():
        with open(new_results_path) as f:
            res = json.load(f)
        m = res.get("metrics", {})
        print(f"\n  New model comparison:")
        for strat, metrics in m.items():
            if isinstance(metrics, dict) and "micro_f1" in metrics:
                print(f"    {strat}: P={metrics['micro_precision']:.4f} "
                      f"R={metrics['micro_recall']:.4f} F1={metrics['micro_f1']:.4f} "
                      f"avg_preds={metrics['mean_preds_per_protein']:.1f}")
    return True
|
|
|
|
| |
|
|
def deploy_thresholds():
    """Copy the best available per-label thresholds into the webapp.

    Prefers the new model's thresholds when they look sane (mean >= 0.5 and
    fewer than 300 labels below 0.3); otherwise falls back to the
    precision+IC thresholds. Returns True iff a file was deployed.
    """
    print("\n[3/5] Deploying thresholds to webapp...")

    # Default to the precision+IC fallback if it exists.
    src = PREC_THRESH if PREC_THRESH.exists() else None
    if THRESH_OUT.exists():
        with open(THRESH_OUT) as f:
            t = json.load(f)
        vals = list(t.values())
        # FIX: guard against an empty thresholds file, which previously
        # crashed with ZeroDivisionError on the mean below.
        if vals:
            mean_t = sum(vals) / len(vals)
            n_low = sum(1 for v in vals if v < 0.3)
            print(f"  New model thresholds: N={len(vals)}, mean={mean_t:.3f}, <0.3: {n_low}")
            if mean_t >= 0.5 and n_low < 300:
                src = THRESH_OUT
                print(f"  Using new model thresholds (mean {mean_t:.3f})")
            else:
                print(f"  New thresholds look poor (mean {mean_t:.3f}, {n_low} < 0.3); "
                      f"using precision+IC fallback")
        else:
            print("  New model thresholds file is empty; using precision+IC fallback")

    if src is None:
        print("  No suitable threshold file found β skipping deploy")
        return False

    import shutil
    shutil.copy(src, DEPLOY_THRESH)
    print(f"  Deployed {src.name} β {DEPLOY_THRESH}")
    return True
|
|
|
|
| |
|
|
def test_proteins():
    """Run headless predictions for the known proteins and validate them.

    Returns True when validation passes, False otherwise. Any unexpected
    exception is printed (with traceback) rather than propagated so the
    pipeline can continue to the push step.
    """
    print("\n[4/5] Testing with known proteins...")

    # Make webapp-local modules importable by the prediction stack.
    sys.path.insert(0, str(WEBAPP))

    try:
        # FIX: dropped unused `import importlib, types` left over from an
        # earlier in-process module-reload approach.
        results = run_headless_predictions()
        return validate_results(results)
    except Exception as e:
        print(f"  Headless test failed: {e}")
        import traceback; traceback.print_exc()
        return False
|
|
|
|
def run_headless_predictions():
    """Run predictions using the same model/threshold stack as the webapp.

    Loads the checkpoint (new one preferred, webapp copies as fallback),
    embeds each TEST_PROTEINS sequence with ESM-2, scores it with the MLP
    head, applies per-label thresholds, and returns one result dict per
    protein (or ``{"protein": name, "error": ...}`` on per-protein failure).
    """
    # NOTE(review): `os` and `math` are imported but unused in this function.
    import os, math
    import joblib
    import torch
    import torch.nn as nn
    import numpy as np

    # Checkpoint search order: new training output first, then webapp copies.
    ckpt_candidates = [
        CKPT_OUT,
        WEBAPP / "models" / "protfunc_v3.pth",
        WEBAPP / "Models" / "protfunc_v3.pth",
    ]
    ckpt_path = next((p for p in ckpt_candidates if p.exists()), None)
    if ckpt_path is None:
        print("  No checkpoint found in webapp/models; skipping test")
        return []

    # Sequence embedder. On failure we fall back to random embeddings so the
    # rest of the stack still gets smoke-tested.
    print(f"  Loading ESM-2 (this takes ~30s)...")
    try:
        import esm as esm_lib
        esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D()
        batch_converter = alphabet.get_batch_converter()
        esm_model.eval()
    except Exception as e:
        print(f"  ESM load failed: {e} β using random embeddings for smoke test")
        esm_model = None

    # Local re-definitions of the training architecture so the checkpoint's
    # state dict can be loaded without importing the training code.
    class ResBlock(nn.Module):
        # Pre-activation residual block: x + MLP(x).
        def __init__(self, d, dr=0.2):
            super().__init__()
            self.net = nn.Sequential(nn.BatchNorm1d(d),nn.ReLU(),nn.Dropout(dr),nn.Linear(d,d),
                                     nn.BatchNorm1d(d),nn.ReLU(),nn.Dropout(dr),nn.Linear(d,d))
        def forward(self, x): return x + self.net(x)

    class MLP(nn.Module):
        # Input projection -> n residual blocks -> output head (o labels).
        def __init__(self, i, o=8124, h=2048, n=4, dr=0.2):
            super().__init__()
            self.fc_in = nn.Linear(i, h)
            self.blocks = nn.ModuleList([ResBlock(h, dr) for _ in range(n)])
            self.fc_out = nn.Sequential(nn.BatchNorm1d(h),nn.ReLU(),nn.Dropout(dr),nn.Linear(h,o))
        def forward(self, x):
            h = self.fc_in(x)
            for b in self.blocks: h = b(h)
            return self.fc_out(h)

    # weights_only=False: checkpoint may carry non-tensor metadata
    # (supp_cols/supp_mu/supp_sd below). Only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state = ckpt.get("model", ckpt)
    # Infer input width from the first layer's weight shape.
    in_dim = state["fc_in.weight"].shape[1]
    model = MLP(in_dim)
    model.load_state_dict(state)
    model.eval()
    print(f"  Model loaded: {ckpt_path.name}, in_dim={in_dim}")

    # Label binarizer mapping label index -> GO id.
    mlb_path = next(
        (p for p in [WEBAPP / "mlb_public_v1.pkl",
                     BASE / "Important Files" / "mlb_public_v1.pkl"] if p.exists()), None
    )
    if mlb_path is None:
        print("  MLB not found β skipping test"); return []
    mlb = joblib.load(mlb_path)

    # Deployed per-label thresholds (keyed by label index) and GO id -> name map.
    with open(WEBAPP / "per_label_thresholds.json") as f:
        thresholds = {int(k): float(v) for k, v in json.load(f).items()}
    with open(WEBAPP / "go_map.json") as f:
        go_map = json.load(f)

    # Minimal OBO scan to restrict predictions to molecular_function terms;
    # if no OBO file is found, all labels are considered.
    obo_path = WEBAPP / "go-basic.obo"
    if not obo_path.exists():
        obo_path = BASE / "go-basic.obo"
    mf_ids = set()
    if obo_path.exists():
        ns_map = {}
        with open(obo_path) as fh:
            cur_id, cur_ns, in_term = None, None, False
            for line in fh:
                line = line.strip()
                if line == "[Term]": in_term = True; continue
                if line.startswith("[") and line != "[Term]": in_term = False; continue
                if not in_term: continue
                if line.startswith("id:"): cur_id = line.split("id:",1)[1].strip().split()[0]
                elif line.startswith("namespace:"): ns_map[cur_id] = line.split("namespace:",1)[1].strip()
        mf_ids = {g for g, n in ns_map.items() if n == "molecular_function"}
    mf_idx = [j for j, c in enumerate(mlb.classes_) if c in mf_ids] if mf_ids else list(range(len(mlb.classes_)))

    # Optional supplementary-feature normalization stats stored in the checkpoint.
    supp_cols = ckpt.get("supp_cols", []) if isinstance(ckpt, dict) else []
    supp_mu = np.array(ckpt["supp_mu"], dtype=np.float32) if "supp_mu" in (ckpt if isinstance(ckpt, dict) else {}) else None
    supp_sd = np.array(ckpt["supp_sd"], dtype=np.float32) if "supp_sd" in (ckpt if isinstance(ckpt, dict) else {}) else None

    results = []
    for prot in TEST_PROTEINS:
        name = prot["name"]
        seq = prot["sequence"].replace("\n", "").replace(" ", "")
        print(f"  Testing {name} ({len(seq)} aa)...")

        try:
            if esm_model is not None:
                _, _, tokens = batch_converter([("p", seq)])
                with torch.no_grad():
                    rep = esm_model(tokens, repr_layers=[6])["representations"][6]
                # Mean-pool token representations, skipping BOS at position 0.
                emb = rep[0, 1:len(seq)+1].mean(0)
            else:
                # esm2_t6_8M embeddings are 320-dimensional.
                emb = torch.randn(320)

            # Rebuild the supplementary-feature vector ("f_Dim_<i>" columns
            # are raw embedding dims) and z-score it, mirroring training.
            emb_np = emb.detach().cpu().numpy()
            if supp_cols and supp_mu is not None:
                feats = {}
                for c in supp_cols:
                    if c.startswith("f_Dim_"):
                        try:
                            idx = int(c.split("_")[-1])
                            if idx < len(emb_np): feats[c] = float(emb_np[idx])
                        except: pass  # NOTE(review): bare except silently drops malformed column names
                s_vec = np.array([feats.get(c, 0.0) for c in supp_cols], dtype=np.float32)
                s_z = (s_vec - supp_mu) / (supp_sd + 1e-12)
                flag = np.array([1.0], dtype=np.float32)  # "has supplementary features" indicator
                inp_np = np.concatenate([emb_np, s_z, flag])
            else:
                inp_np = emb_np[:in_dim]

            inp = torch.tensor(inp_np).unsqueeze(0)
            with torch.no_grad():
                prob = torch.sigmoid(model(inp)).squeeze()

            # Threshold per label (0.5 default), then keep the top 3..20.
            raw_preds = []
            for i in mf_idx:
                pv = float(prob[i])
                if pv >= thresholds.get(i, 0.5):
                    raw_preds.append({
                        "go_id": mlb.classes_[i],
                        "name": go_map.get(mlb.classes_[i], mlb.classes_[i]),
                        "prob": round(pv, 4),
                    })
            raw_preds.sort(key=lambda x: -x["prob"])
            cap = min(20, max(3, len(raw_preds)))
            raw_preds = raw_preds[:cap]

            results.append({
                "protein": name,
                "n_preds": len(raw_preds),
                "top5": raw_preds[:5],
                "expect": prot["expect_contains"],
                "expect_not": prot["expect_not_contains"],
            })
        except Exception as e:
            # Per-protein failures are recorded, not fatal.
            results.append({"protein": name, "error": str(e)})

    return results
|
|
|
|
def validate_results(results):
    """Check headless-prediction results against per-protein expectations.

    A protein fails when its prediction errored or when any forbidden term
    from ``expect_not`` shows up among the top-5 predicted names. Returns
    True only when no protein failed.
    """
    print("\n  Validation:")
    passed = True
    for entry in results:
        if "error" in entry:
            print(f"  β {entry['protein']}: ERROR β {entry['error']}")
            passed = False
            continue

        joined_names = " ".join(pred["name"].lower() for pred in entry["top5"])
        probs = [pred["prob"] for pred in entry["top5"]]
        count = entry["n_preds"]

        found = [term for term in entry["expect"] if term.lower() in joined_names]
        forbidden = [term for term in entry["expect_not"] if term.lower() in joined_names]

        # "OK" when something expected was found (or the prediction set is
        # small) and nothing forbidden appeared.
        ok = bool(found or count <= 15) and not forbidden
        status = "β" if ok else "~"
        if forbidden:
            passed = False

        print(f"  {status} {entry['protein']}: {count} preds | top probs {[round(p,3) for p in probs[:3]]}")
        print(f"  top: {[pred['name'] for pred in entry['top5'][:3]]}")
        if found:
            print(f"  hits: {found}")
        if forbidden:
            print(f"  FALSE POSITIVES: {forbidden}")

    return passed
|
|
|
|
| |
|
|
def push_to_hf(dry_run=False):
    """Mirror the latest webapp commit onto the HuggingFace Space remote.

    Cherry-picks the newest local commit onto a throwaway branch based on
    ``space/main``, pushes it, then restores the local working state.
    Returns True unless an exception escapes; individual git command
    failures are printed but tolerated.
    """
    print("\n[5/5] Pushing to HuggingFace Space...")
    if dry_run:
        print("  --dry-run: skipping push")
        return True

    try:
        import subprocess as sp
        # Identify the commit to mirror.
        r = sp.run(["git", "log", "--oneline", "-1"], cwd=WEBAPP, capture_output=True, text=True)
        # NOTE(review): raises IndexError if the repo has no commits / git failed.
        latest_sha = r.stdout.split()[0]
        print(f"  Latest commit: {r.stdout.strip()}")

        # Stash local changes, replay the commit on a temp branch tracking the
        # Space remote, push, then restore main and clean up the branch.
        cmds = [
            ["git", "stash"],
            ["git", "checkout", "-b", "space-postrain", "space/main"],
            ["git", "cherry-pick", latest_sha],
            ["git", "push", "space", "space-postrain:main"],
            ["git", "checkout", "main"],
            ["git", "stash", "pop"],
            ["git", "branch", "-d", "space-postrain"],
        ]
        for cmd in cmds:
            result = sp.run(cmd, cwd=WEBAPP, capture_output=True, text=True)
            if result.returncode != 0 and "cherry-pick" in " ".join(cmd):
                # Cherry-pick conflict/empty commit: skip it and keep going.
                sp.run(["git", "cherry-pick", "--skip"], cwd=WEBAPP)
            print(f"  {' '.join(cmd[-2:])}: {'ok' if result.returncode == 0 else result.stderr[:80]}")
        return True
    except Exception as e:
        print(f"  Push failed: {e}")
        return False
|
|
|
|
| |
|
|
def main():
    """Run the five pipeline stages in order and write a JSON summary."""
    p = argparse.ArgumentParser()
    p.add_argument("--skip-wait", action="store_true", help="Don't wait for training")
    p.add_argument("--skip-test", action="store_true", help="Skip protein validation")
    p.add_argument("--dry-run", action="store_true", help="Skip HF push")
    p.add_argument("--skip-threshold-comparison", action="store_true")
    args = p.parse_args()

    t0 = time.time()
    summary = {"steps": {}, "start_time": time.strftime("%Y-%m-%dT%H:%M:%S")}

    # [1/5] Block until training finishes (unless skipped); abort on failure.
    if not args.skip_wait:
        ok = wait_for_training()
        summary["steps"]["wait_training"] = "ok" if ok else "failed"
        if not ok:
            print("Training failed or timed out. Aborting.")
            _save_summary(summary); return
    else:
        print("[1/5] Skipping wait (--skip-wait)")

    # [2/5] Re-evaluate thresholds against the new checkpoint (optional).
    if not args.skip_threshold_comparison:
        ok = run_threshold_comparison()
        summary["steps"]["threshold_comparison"] = "ok" if ok else "skipped"
    else:
        print("[2/5] Skipping threshold comparison")

    # [3/5] Deploy the chosen thresholds into the webapp.
    ok = deploy_thresholds()
    summary["steps"]["deploy_thresholds"] = "ok" if ok else "failed"

    # [4/5] Validate predictions on known proteins (optional; non-fatal).
    if not args.skip_test:
        ok = test_proteins()
        summary["steps"]["protein_test"] = "passed" if ok else "warnings"
    else:
        print("[4/5] Skipping protein test (--skip-test)")

    # [5/5] Push to the HuggingFace Space (skipped under --dry-run).
    ok = push_to_hf(dry_run=args.dry_run)
    summary["steps"]["hf_push"] = "ok" if ok else "failed"

    summary["elapsed_minutes"] = round((time.time() - t0) / 60, 1)
    _save_summary(summary)
    print(f"\nPipeline complete in {summary['elapsed_minutes']} min.")
    print(f"Summary: {SUMMARY_OUT}")
|
|
|
|
def _save_summary(summary):
    """Persist the pipeline summary dict as pretty-printed JSON at SUMMARY_OUT."""
    GRAPH_ART.mkdir(parents=True, exist_ok=True)
    SUMMARY_OUT.write_text(json.dumps(summary, indent=2))
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|