# protfunc / scripts/graph_hpo_sequence.py
# Last commit by Sbhat2026 (7f7a890):
#   perf: ESM embedding cache + 1500aa limit, add research scripts
"""
graph_hpo_sequence.py
=====================
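Set up a graph-aware HPO -> joint insect+mammal training -> threshold comparison
-> mammal generalization pipeline, and optionally execute its HPO stage (or a
1-trial / 1-epoch HPO smoke test). The later stages are recorded as command
templates in the methodology files rather than executed by this script.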
The "graph-based" part is implemented as dependency-aware Optuna TPE:
- multivariate TPE
- grouped TPE
- median pruning
Artifacts are written under artifacts/graph_hpo/ so methodology and metrics are
preserved separately from the default training outputs.
"""
from __future__ import annotations
import argparse
import json
import shlex
import subprocess
from pathlib import Path
# Project root: this script lives in <root>/scripts/. resolve() makes the path
# absolute so it does not depend on the working directory or on whether the
# interpreter reported __file__ as a relative path.
BASE = Path(__file__).resolve().parent.parent
# All pipeline outputs live under <root>/artifacts.
ART = BASE / "artifacts"
# Outputs of this graph-HPO sequence, kept apart from default training outputs.
GRAPH_ART = ART / "graph_hpo"
# Default-pipeline generalization outputs (used as a fallback when summarising).
GENERALIZATION_ART = ART / "generalization"
def python_bin() -> str:
    """Return the interpreter used to launch pipeline stages.

    Prefers the project virtualenv interpreter at ``<BASE>/.venv/bin/python3``
    when that file exists; otherwise returns the bare name ``python3`` so the
    stage is resolved through ``PATH``.
    """
    candidate = BASE / ".venv" / "bin" / "python3"
    if candidate.exists():
        return str(candidate)
    return "python3"
def build_paths() -> dict[str, Path]:
    """Map each artifact role to its file location under ``GRAPH_ART``.

    Returns:
        Dict keyed by role (``"methodology_json"``, ``"hpo_json"``,
        ``"train_ckpt"``, ...) whose values are paths inside ``GRAPH_ART``.
    """
    filenames = {
        "methodology_json": "methodology.json",
        "methodology_md": "methodology.md",
        "summary_json": "sequence_summary.json",
        "hpo_json": "hpo_results.json",
        "train_ckpt": "graph_hpo_best.pth",
        "train_thr": "graph_hpo_best_thresholds.json",
        "train_log": "graph_hpo_best_log.json",
        "generalization_json": "generalization_results.json",
        "generalization_csv": "generalization_results.csv",
    }
    return {role: GRAPH_ART / filename for role, filename in filenames.items()}
def methodology_payload(paths: dict[str, Path], args: argparse.Namespace) -> dict:
    """Describe the pipeline's methodology as a JSON-serialisable record.

    The record captures the objective, the TPE / pruner settings taken from
    ``args``, the feature-level widths, dataset locations, the ordered stage
    list with output paths, free-form notes, and argv-style command lists per
    stage. The ``*_template`` commands contain ``<best.*>`` placeholders and
    are not directly runnable.

    Args:
        paths: Output locations as returned by ``build_paths()``.
        args: Parsed CLI namespace; reads ``alpha``, ``startup_trials``,
            ``warmup_steps``, ``n_trials``, ``hpo_epochs`` and ``hpo_patience``.

    Returns:
        Nested dict of strings, numbers, booleans, lists and dicts.
    """
    interpreter = python_bin()
    # Paths shared by the dataset record and the stage commands.
    mammal_parquet = "artifacts/generalization/mammal_full_v1.parquet"
    mlb_pickle = "Important Files/mlb_public_v1.pkl"

    objective = {
        "type": "weighted_joint_micro_fmax",
        "formula": "alpha * insect_val_micro_fmax + (1-alpha) * mammal_val_micro_fmax",
        "alpha": args.alpha,
    }
    tpe_settings = {
        "sampler": "Optuna TPESampler",
        "multivariate": True,
        "group": True,
        "pruner": "MedianPruner",
        "startup_trials": args.startup_trials,
        "warmup_steps": args.warmup_steps,
    }
    # Feature width of each selectable feature slice.
    feature_levels = {"esm_only": 320, "esm_seq": 331, "esm_all": 360}
    datasets = {
        "insect_base": "Important Files/merged_full_struct.parquet",
        "insect_supp": "Important Files/merged_full_struct_with_features.parquet",
        "mammal": mammal_parquet,
        "splits": "artifacts/splits/splits_n250000_seed42.npz",
        "mlb": mlb_pickle,
    }

    stage_specs = (
        ("hpo", "Graph-aware joint insect+mammal hyperparameter tuning",
         paths["hpo_json"]),
        ("train", "Train exact best feature slice with mammal data included",
         paths["train_ckpt"]),
        ("thresholds", "Compare current thresholds vs precision+IC vs novelty-gated thresholds",
         ART / "threshold_comparison_results.json"),
        ("generalization", "Evaluate mammal generalization with trained checkpoint and thresholds",
         paths["generalization_json"]),
        ("summary", "Aggregate methodology and metrics into a single saved record",
         paths["summary_json"]),
    )
    stages = [
        {"name": stage_name, "purpose": purpose, "output": str(output)}
        for stage_name, purpose, output in stage_specs
    ]

    notes = [
        "Threshold comparison currently saves metrics to artifacts/threshold_comparison_results.json.",
        "Final training uses train_v3_fixed.py with explicit feature_level to match HPO output.",
        "Mammal data are merged into training unless --skip-mammal is explicitly used.",
    ]

    hpo_command = [
        interpreter, "scripts/hpo.py",
        "--mammal", mammal_parquet,
        "--n_trials", str(args.n_trials),
        "--epochs", str(args.hpo_epochs),
        "--patience", str(args.hpo_patience),
        "--alpha", str(args.alpha),
        "--startup_trials", str(args.startup_trials),
        "--warmup_steps", str(args.warmup_steps),
        "--multivariate_tpe", "--group_tpe",
        "--out", str(paths["hpo_json"]),
    ]

    # Training flags are filled from the best HPO trial; record placeholders.
    best_param_flags = (
        ("--hidden", "hidden"),
        ("--blocks", "n_blocks"),
        ("--dropout", "dropout"),
        ("--lr", "lr"),
        ("--weight_decay", "weight_decay"),
        ("--batch", "batch"),
        ("--label_smooth", "label_smooth"),
        ("--feature_level", "feat_level"),
    )
    train_template = [interpreter, "scripts/train_v3_fixed.py"]
    for flag, param in best_param_flags:
        train_template += [flag, f"<best.{param}>"]
    train_template += [
        "--feature_label", "graph_hpo_best",
        "--checkpoint_out", str(paths["train_ckpt"]),
        "--threshold_out", str(paths["train_thr"]),
        "--log_out", str(paths["train_log"]),
    ]

    generalization_template = [
        interpreter, "scripts/eval_generalization.py",
        "--checkpoint", str(paths["train_ckpt"]),
        "--thresholds", str(paths["train_thr"]),
        "--mlb", mlb_pickle,
        "--taxon_parquet", "artifacts/generalization/mammal_embeddings_v3.parquet",
        "--taxon_name", "mammals_graph_hpo",
        "--obo", "go-basic.obo",
        "--out", str(paths["generalization_json"]),
    ]

    return {
        "pipeline_name": "graph_hpo_joint_insect_mammal",
        "objective": objective,
        "graph_hpo": tpe_settings,
        "feature_levels": feature_levels,
        "datasets": datasets,
        "stages": stages,
        "notes": notes,
        "stage_commands": {
            "hpo": hpo_command,
            "train_template": train_template,
            "threshold_comparison": [interpreter, "scripts/threshold_comparison.py"],
            "generalization_template": generalization_template,
        },
    }
def methodology_markdown(payload: dict) -> str:
    """Render a methodology payload as a Markdown document.

    Sections, in order: objective, graph-aware HPO settings, feature levels,
    datasets, stages (with output paths), notes, and per-stage commands
    (shell-quoted with ``shlex.join``). ``graph_hpo.startup_trials``,
    ``graph_hpo.warmup_steps`` and ``datasets`` are optional: each is rendered
    only when present (``datasets`` only when non-empty), so payloads without
    them still render.

    Args:
        payload: Record shaped like the output of ``methodology_payload``.

    Returns:
        Newline-joined Markdown text, ending with a newline.
    """
    objective = payload["objective"]
    hpo = payload["graph_hpo"]
    lines = [
        "# Graph HPO Training Sequence",
        "",
        "## Objective",
        f"- `{objective['formula']}`",
        f"- `alpha = {objective['alpha']}`",
        "",
        "## Graph-Aware HPO",
        f"- Sampler: `{hpo['sampler']}`",
        f"- Multivariate: `{hpo['multivariate']}`",
        f"- Grouped: `{hpo['group']}`",
        f"- Pruner: `{hpo['pruner']}`",
    ]
    # Previously dropped from the Markdown although recorded in the JSON.
    if "startup_trials" in hpo:
        lines.append(f"- Startup trials: `{hpo['startup_trials']}`")
    if "warmup_steps" in hpo:
        lines.append(f"- Warmup steps: `{hpo['warmup_steps']}`")

    lines += ["", "## Feature Levels"]
    lines += [f"- `{name}` -> `{dim}` dims" for name, dim in payload["feature_levels"].items()]

    datasets = payload.get("datasets")
    if datasets:
        lines += ["", "## Datasets"]
        lines += [f"- `{name}`: `{location}`" for name, location in datasets.items()]

    lines += ["", "## Stages"]
    for stage in payload["stages"]:
        lines.append(f"- `{stage['name']}`: {stage['purpose']}")
        lines.append(f" Output: `{stage['output']}`")

    lines += ["", "## Notes"]
    lines += [f"- {note}" for note in payload["notes"]]

    lines += ["", "## Stage Commands"]
    for name, cmd in payload["stage_commands"].items():
        lines += [f"- `{name}`", f" `{shlex.join(cmd)}`", ""]
    return "\n".join(lines)
def build_commands(paths: dict[str, Path], args: argparse.Namespace) -> dict[str, list[str]]:
    """Return runnable argv lists for the stages this script can execute.

    Only the ``"hpo"`` stage is provided. The script path is absolute
    (``BASE/scripts/hpo.py``), while the ``--mammal`` parquet path is relative
    to the project root.

    Args:
        paths: Output locations as returned by ``build_paths()``.
        args: Parsed CLI namespace supplying the HPO budget and TPE settings.

    Returns:
        ``{"hpo": argv}``.
    """
    valued_options = (
        ("--mammal", "artifacts/generalization/mammal_full_v1.parquet"),
        ("--n_trials", str(args.n_trials)),
        ("--epochs", str(args.hpo_epochs)),
        ("--patience", str(args.hpo_patience)),
        ("--alpha", str(args.alpha)),
        ("--startup_trials", str(args.startup_trials)),
        ("--warmup_steps", str(args.warmup_steps)),
    )
    argv = [python_bin(), str(BASE / "scripts" / "hpo.py")]
    for flag, value in valued_options:
        argv.extend((flag, value))
    argv.extend(("--multivariate_tpe", "--group_tpe", "--out", str(paths["hpo_json"])))
    return {"hpo": argv}
def run(cmd: list[str]) -> None:
    """Echo ``cmd`` shell-style, then execute it with the project root as CWD.

    The echo is flushed before the child starts. When stdout is a pipe or a
    file, Python block-buffers it, while the child writes straight to the
    inherited descriptor. Without the flush, the ``$ ...`` line could appear
    after the child's output.

    Raises:
        subprocess.CalledProcessError: if the command exits with non-zero status.
    """
    print("$", shlex.join(cmd), flush=True)
    subprocess.run(cmd, cwd=BASE, check=True)
def load_json(path: Path):
    """Parse the JSON file at ``path``, or return ``None`` if it does not exist.

    The file is decoded as UTF-8, the encoding JSON mandates, rather than the
    platform's locale default. Malformed content is not masked: it raises
    ``json.JSONDecodeError``.
    """
    try:
        f = open(path, encoding="utf-8")
    except FileNotFoundError:
        return None
    with f:
        return json.load(f)
def write_summary(paths: dict[str, Path], payload: dict) -> None:
    """Write ``paths["summary_json"]`` combining methodology and saved metrics.

    Each metrics file is loaded if it exists (``None`` otherwise).
    Generalization results prefer this sequence's own output and fall back to
    ``GENERALIZATION_ART/generalization_results.json`` only when that file is
    missing. ``"metric_sources"`` records which file each metric came from
    (``None`` when absent), so fallback results from the default pipeline are
    not mistaken for graph-HPO results.
    """
    hpo_path = paths["hpo_json"]
    threshold_path = ART / "threshold_comparison_results.json"
    hpo = load_json(hpo_path)
    threshold_metrics = load_json(threshold_path)

    generalization_path = paths["generalization_json"]
    generalization = load_json(generalization_path)
    if generalization is None:
        # Only fall back when the file is absent; an empty-but-present result
        # is still this sequence's result.
        generalization_path = GENERALIZATION_ART / "generalization_results.json"
        generalization = load_json(generalization_path)

    def _source(path: Path, data) -> str | None:
        # Record the file a metric was read from, or None if nothing was loaded.
        return None if data is None else str(path)

    summary = {
        "methodology": payload,
        "artifacts": {k: str(v) for k, v in paths.items()},
        "existing_metrics": {
            "hpo": hpo,
            "threshold_comparison": threshold_metrics,
            "generalization": generalization,
        },
        "metric_sources": {
            "hpo": _source(hpo_path, hpo),
            "threshold_comparison": _source(threshold_path, threshold_metrics),
            "generalization": _source(generalization_path, generalization),
        },
    }
    with open(paths["summary_json"], "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
def _override_flags(cmd: list[str], overrides: dict[str, str]) -> list[str]:
    """Return a copy of ``cmd`` with the value following each given flag replaced.

    Raises:
        ValueError: if a flag in ``overrides`` does not occur in ``cmd``.
    """
    patched = list(cmd)
    for flag, value in overrides.items():
        patched[patched.index(flag) + 1] = value
    return patched


def main() -> None:
    """Command-line entry point.

    Always (re)writes the methodology JSON/Markdown and the summary JSON. With
    ``--smoke`` it first runs a 1-trial / 1-epoch / patience-1 HPO whose
    results go to ``hpo_smoke.json`` instead of the real HPO output.
    Otherwise, with ``--run_hpo``, it runs the full HPO stage. ``--smoke``
    wins when both flags are given.
    """
    p = argparse.ArgumentParser(description="Set up / run graph-aware HPO training pipeline")
    p.add_argument("--n_trials", type=int, default=40, help="Number of HPO trials")
    p.add_argument("--hpo_epochs", type=int, default=20, help="Max epochs per HPO trial")
    p.add_argument("--hpo_patience", type=int, default=6, help="Early-stopping patience per HPO trial")
    p.add_argument("--alpha", type=float, default=0.6,
                   help="Insect weight in the joint objective; must be in [0, 1]")
    p.add_argument("--startup_trials", type=int, default=5)
    p.add_argument("--warmup_steps", type=int, default=5)
    p.add_argument("--run_hpo", action="store_true", help="Execute the full HPO stage")
    p.add_argument("--smoke", action="store_true",
                   help="Run a minimal 1-trial / 1-epoch HPO smoke test")
    args = p.parse_args()
    # The objective is alpha*insect + (1-alpha)*mammal; weights outside [0, 1]
    # (or NaN) would make it meaningless, so reject them up front.
    if not 0.0 <= args.alpha <= 1.0:
        p.error(f"--alpha must be in [0, 1], got {args.alpha}")

    GRAPH_ART.mkdir(parents=True, exist_ok=True)
    paths = build_paths()
    payload = methodology_payload(paths, args)
    with open(paths["methodology_json"], "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    # Explicit UTF-8: the Markdown embeds filesystem paths that may be non-ASCII.
    with open(paths["methodology_md"], "w", encoding="utf-8") as f:
        f.write(methodology_markdown(payload))

    cmds = build_commands(paths, args)
    if args.smoke:
        run(_override_flags(cmds["hpo"], {
            "--n_trials": "1",
            "--epochs": "1",
            "--patience": "1",
            "--out": str(GRAPH_ART / "hpo_smoke.json"),
        }))
    elif args.run_hpo:
        run(cmds["hpo"])

    write_summary(paths, payload)
    print(f"Methodology JSON: {paths['methodology_json']}")
    print(f"Methodology MD: {paths['methodology_md']}")
    print(f"Summary JSON: {paths['summary_json']}")


if __name__ == "__main__":
    main()