""" graph_hpo_sequence.py ===================== Set up and optionally execute a graph-aware HPO -> joint insect+mammal training -> threshold comparison -> mammal generalization pipeline. The "graph-based" part is implemented as dependency-aware Optuna TPE: - multivariate TPE - grouped TPE - median pruning Artifacts are written under artifacts/graph_hpo/ so methodology and metrics are preserved separately from the default training outputs. """ from __future__ import annotations import argparse import json import shlex import subprocess from pathlib import Path BASE = Path(__file__).parent.parent ART = BASE / "artifacts" GRAPH_ART = ART / "graph_hpo" GENERALIZATION_ART = ART / "generalization" def python_bin() -> str: venv = BASE / ".venv" / "bin" / "python3" return str(venv if venv.exists() else Path("python3")) def build_paths() -> dict[str, Path]: return { "methodology_json": GRAPH_ART / "methodology.json", "methodology_md": GRAPH_ART / "methodology.md", "summary_json": GRAPH_ART / "sequence_summary.json", "hpo_json": GRAPH_ART / "hpo_results.json", "train_ckpt": GRAPH_ART / "graph_hpo_best.pth", "train_thr": GRAPH_ART / "graph_hpo_best_thresholds.json", "train_log": GRAPH_ART / "graph_hpo_best_log.json", "generalization_json": GRAPH_ART / "generalization_results.json", "generalization_csv": GRAPH_ART / "generalization_results.csv", } def methodology_payload(paths: dict[str, Path], args: argparse.Namespace) -> dict: py = python_bin() return { "pipeline_name": "graph_hpo_joint_insect_mammal", "objective": { "type": "weighted_joint_micro_fmax", "formula": "alpha * insect_val_micro_fmax + (1-alpha) * mammal_val_micro_fmax", "alpha": args.alpha, }, "graph_hpo": { "sampler": "Optuna TPESampler", "multivariate": True, "group": True, "pruner": "MedianPruner", "startup_trials": args.startup_trials, "warmup_steps": args.warmup_steps, }, "feature_levels": { "esm_only": 320, "esm_seq": 331, "esm_all": 360, }, "datasets": { "insect_base": "Important Files/merged_full_struct.parquet", "insect_supp": "Important Files/merged_full_struct_with_features.parquet", "mammal": "artifacts/generalization/mammal_full_v1.parquet", "splits": "artifacts/splits/splits_n250000_seed42.npz", "mlb": "Important Files/mlb_public_v1.pkl", }, "stages": [ { "name": "hpo", "purpose": "Graph-aware joint insect+mammal hyperparameter tuning", "output": str(paths["hpo_json"]), }, { "name": "train", "purpose": "Train exact best feature slice with mammal data included", "output": str(paths["train_ckpt"]), }, { "name": "thresholds", "purpose": "Compare current thresholds vs precision+IC vs novelty-gated thresholds", "output": str(ART / "threshold_comparison_results.json"), }, { "name": "generalization", "purpose": "Evaluate mammal generalization with trained checkpoint and thresholds", "output": str(paths["generalization_json"]), }, { "name": "summary", "purpose": "Aggregate methodology and metrics into a single saved record", "output": str(paths["summary_json"]), }, ], "notes": [ "Threshold comparison currently saves metrics to artifacts/threshold_comparison_results.json.", "Final training uses train_v3_fixed.py with explicit feature_level to match HPO output.", "Mammal data are merged into training unless --skip-mammal is explicitly used.", ], "stage_commands": { "hpo": [ py, "scripts/hpo.py", "--mammal", "artifacts/generalization/mammal_full_v1.parquet", "--n_trials", str(args.n_trials), "--epochs", str(args.hpo_epochs), "--patience", str(args.hpo_patience), "--alpha", str(args.alpha), "--startup_trials", str(args.startup_trials), "--warmup_steps", str(args.warmup_steps), "--multivariate_tpe", "--group_tpe", "--out", str(paths["hpo_json"]), ], "train_template": [ py, "scripts/train_v3_fixed.py", "--hidden", "", "--blocks", "", "--dropout", "", "--lr", "", "--weight_decay", "", "--batch", "", "--label_smooth", "", "--feature_level", "", "--feature_label", "graph_hpo_best", "--checkpoint_out", str(paths["train_ckpt"]), "--threshold_out", str(paths["train_thr"]), "--log_out", str(paths["train_log"]), ], "threshold_comparison": [ py, "scripts/threshold_comparison.py", ], "generalization_template": [ py, "scripts/eval_generalization.py", "--checkpoint", str(paths["train_ckpt"]), "--thresholds", str(paths["train_thr"]), "--mlb", "Important Files/mlb_public_v1.pkl", "--taxon_parquet", "artifacts/generalization/mammal_embeddings_v3.parquet", "--taxon_name", "mammals_graph_hpo", "--obo", "go-basic.obo", "--out", str(paths["generalization_json"]), ], }, } def methodology_markdown(payload: dict) -> str: lines = [ "# Graph HPO Training Sequence", "", "## Objective", f"- `{payload['objective']['formula']}`", f"- `alpha = {payload['objective']['alpha']}`", "", "## Graph-Aware HPO", f"- Sampler: `{payload['graph_hpo']['sampler']}`", f"- Multivariate: `{payload['graph_hpo']['multivariate']}`", f"- Grouped: `{payload['graph_hpo']['group']}`", f"- Pruner: `{payload['graph_hpo']['pruner']}`", "", "## Feature Levels", ] for name, dim in payload["feature_levels"].items(): lines.append(f"- `{name}` -> `{dim}` dims") lines += ["", "## Stages"] for stage in payload["stages"]: lines.append(f"- `{stage['name']}`: {stage['purpose']}") lines.append(f" Output: `{stage['output']}`") lines += ["", "## Notes"] for note in payload["notes"]: lines.append(f"- {note}") lines += ["", "## Stage Commands"] for name, cmd in payload["stage_commands"].items(): lines.append(f"- `{name}`") lines.append(f" `{shlex.join(cmd)}`") lines.append("") return "\n".join(lines) def build_commands(paths: dict[str, Path], args: argparse.Namespace) -> dict[str, list[str]]: py = python_bin() hpo_cmd = [ py, str(BASE / "scripts" / "hpo.py"), "--mammal", "artifacts/generalization/mammal_full_v1.parquet", "--n_trials", str(args.n_trials), "--epochs", str(args.hpo_epochs), "--patience", str(args.hpo_patience), "--alpha", str(args.alpha), "--startup_trials", str(args.startup_trials), "--warmup_steps", str(args.warmup_steps), "--multivariate_tpe", "--group_tpe", "--out", str(paths["hpo_json"]), ] return { "hpo": hpo_cmd, } def run(cmd: list[str]) -> None: print("$", shlex.join(cmd)) subprocess.run(cmd, cwd=BASE, check=True) def load_json(path: Path): if path.exists(): with open(path) as f: return json.load(f) return None def write_summary(paths: dict[str, Path], payload: dict) -> None: hpo = load_json(paths["hpo_json"]) threshold_metrics = load_json(ART / "threshold_comparison_results.json") generalization = load_json(paths["generalization_json"]) or load_json(GENERALIZATION_ART / "generalization_results.json") summary = { "methodology": payload, "artifacts": {k: str(v) for k, v in paths.items()}, "existing_metrics": { "hpo": hpo, "threshold_comparison": threshold_metrics, "generalization": generalization, }, } with open(paths["summary_json"], "w") as f: json.dump(summary, f, indent=2) def main() -> None: p = argparse.ArgumentParser(description="Set up / run graph-aware HPO training pipeline") p.add_argument("--n_trials", type=int, default=40) p.add_argument("--hpo_epochs", type=int, default=20) p.add_argument("--hpo_patience", type=int, default=6) p.add_argument("--alpha", type=float, default=0.6) p.add_argument("--startup_trials", type=int, default=5) p.add_argument("--warmup_steps", type=int, default=5) p.add_argument("--run_hpo", action="store_true") p.add_argument("--smoke", action="store_true", help="Run a minimal 1-trial / 1-epoch HPO smoke test") args = p.parse_args() GRAPH_ART.mkdir(parents=True, exist_ok=True) paths = build_paths() payload = methodology_payload(paths, args) with open(paths["methodology_json"], "w") as f: json.dump(payload, f, indent=2) with open(paths["methodology_md"], "w") as f: f.write(methodology_markdown(payload)) cmds = build_commands(paths, args) if args.smoke: smoke_cmd = cmds["hpo"].copy() for flag, value in [("--n_trials", "1"), ("--epochs", "1"), ("--patience", "1")]: idx = smoke_cmd.index(flag) smoke_cmd[idx + 1] = value smoke_idx = smoke_cmd.index("--out") smoke_cmd[smoke_idx + 1] = str(GRAPH_ART / "hpo_smoke.json") run(smoke_cmd) elif args.run_hpo: run(cmds["hpo"]) write_summary(paths, payload) print(f"Methodology JSON: {paths['methodology_json']}") print(f"Methodology MD: {paths['methodology_md']}") print(f"Summary JSON: {paths['summary_json']}") if __name__ == "__main__": main()