| """ |
| graph_hpo_sequence.py |
| ===================== |
| Set up and optionally execute a graph-aware HPO -> joint insect+mammal training |
| -> threshold comparison -> mammal generalization pipeline. |
| |
| The "graph-based" part is implemented as dependency-aware Optuna TPE: |
| - multivariate TPE |
| - grouped TPE |
| - median pruning |
| |
| Artifacts are written under artifacts/graph_hpo/ so methodology and metrics are |
| preserved separately from the default training outputs. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import shlex |
| import subprocess |
| from pathlib import Path |
|
|
|
|
# Repo root: this file is assumed to live one directory below it (e.g. scripts/).
BASE = Path(__file__).parent.parent
# Shared artifact root used by the other pipelines as well.
ART = BASE / "artifacts"
# Dedicated output directory for this graph-aware HPO pipeline.
GRAPH_ART = ART / "graph_hpo"
# Fallback location for previously produced mammal generalization metrics.
GENERALIZATION_ART = ART / "generalization"
|
|
|
|
def python_bin() -> str:
    """Return the project venv's python3 if present, otherwise plain "python3"."""
    candidate = BASE / ".venv" / "bin" / "python3"
    if candidate.exists():
        return str(candidate)
    return "python3"
|
|
|
|
def build_paths() -> dict[str, Path]:
    """Map logical artifact names to their files under GRAPH_ART."""
    filenames = {
        "methodology_json": "methodology.json",
        "methodology_md": "methodology.md",
        "summary_json": "sequence_summary.json",
        "hpo_json": "hpo_results.json",
        "train_ckpt": "graph_hpo_best.pth",
        "train_thr": "graph_hpo_best_thresholds.json",
        "train_log": "graph_hpo_best_log.json",
        "generalization_json": "generalization_results.json",
        "generalization_csv": "generalization_results.csv",
    }
    return {key: GRAPH_ART / name for key, name in filenames.items()}
|
|
|
|
def methodology_payload(paths: dict[str, Path], args: argparse.Namespace) -> dict:
    """Build a JSON-serializable record describing the whole pipeline run.

    Captures the objective, sampler configuration, dataset paths, stage
    ordering, and the exact (or templated) command line for each stage so the
    methodology can be archived alongside the produced metrics.

    Args:
        paths: Artifact paths as produced by build_paths().
        args: Parsed CLI arguments (alpha, trial counts, epochs, etc.).

    Returns:
        A plain dict safe to serialize with json.dump.
    """
    py = python_bin()
    return {
        "pipeline_name": "graph_hpo_joint_insect_mammal",
        # Objective blends insect and mammal validation micro-Fmax via alpha.
        "objective": {
            "type": "weighted_joint_micro_fmax",
            "formula": "alpha * insect_val_micro_fmax + (1-alpha) * mammal_val_micro_fmax",
            "alpha": args.alpha,
        },
        # Dependency-aware Optuna TPE settings (the "graph-based" part, per the
        # module docstring): multivariate + grouped TPE with median pruning.
        "graph_hpo": {
            "sampler": "Optuna TPESampler",
            "multivariate": True,
            "group": True,
            "pruner": "MedianPruner",
            "startup_trials": args.startup_trials,
            "warmup_steps": args.warmup_steps,
        },
        # Input dims per feature slice; presumably ESM embeddings plus optional
        # extra features — TODO confirm against the training script.
        "feature_levels": {
            "esm_only": 320,
            "esm_seq": 331,
            "esm_all": 360,
        },
        "datasets": {
            "insect_base": "Important Files/merged_full_struct.parquet",
            "insect_supp": "Important Files/merged_full_struct_with_features.parquet",
            "mammal": "artifacts/generalization/mammal_full_v1.parquet",
            "splits": "artifacts/splits/splits_n250000_seed42.npz",
            "mlb": "Important Files/mlb_public_v1.pkl",
        },
        # Stage order mirrors the pipeline described in the module docstring.
        "stages": [
            {
                "name": "hpo",
                "purpose": "Graph-aware joint insect+mammal hyperparameter tuning",
                "output": str(paths["hpo_json"]),
            },
            {
                "name": "train",
                "purpose": "Train exact best feature slice with mammal data included",
                "output": str(paths["train_ckpt"]),
            },
            {
                "name": "thresholds",
                "purpose": "Compare current thresholds vs precision+IC vs novelty-gated thresholds",
                # Threshold comparison writes under ART, not GRAPH_ART (see notes).
                "output": str(ART / "threshold_comparison_results.json"),
            },
            {
                "name": "generalization",
                "purpose": "Evaluate mammal generalization with trained checkpoint and thresholds",
                "output": str(paths["generalization_json"]),
            },
            {
                "name": "summary",
                "purpose": "Aggregate methodology and metrics into a single saved record",
                "output": str(paths["summary_json"]),
            },
        ],
        "notes": [
            "Threshold comparison currently saves metrics to artifacts/threshold_comparison_results.json.",
            "Final training uses train_v3_fixed.py with explicit feature_level to match HPO output.",
            "Mammal data are merged into training unless --skip-mammal is explicitly used.",
        ],
        # NOTE(review): script paths below are relative, while build_commands()
        # uses absolute paths under BASE; both resolve because run() executes
        # with cwd=BASE, but it would be cleaner to pick one convention.
        "stage_commands": {
            "hpo": [
                py, "scripts/hpo.py",
                "--mammal", "artifacts/generalization/mammal_full_v1.parquet",
                "--n_trials", str(args.n_trials),
                "--epochs", str(args.hpo_epochs),
                "--patience", str(args.hpo_patience),
                "--alpha", str(args.alpha),
                "--startup_trials", str(args.startup_trials),
                "--warmup_steps", str(args.warmup_steps),
                "--multivariate_tpe", "--group_tpe",
                "--out", str(paths["hpo_json"]),
            ],
            # "<best.*>" placeholders are filled from the HPO winner; this is a
            # documentation template, not a directly runnable command.
            "train_template": [
                py, "scripts/train_v3_fixed.py",
                "--hidden", "<best.hidden>",
                "--blocks", "<best.n_blocks>",
                "--dropout", "<best.dropout>",
                "--lr", "<best.lr>",
                "--weight_decay", "<best.weight_decay>",
                "--batch", "<best.batch>",
                "--label_smooth", "<best.label_smooth>",
                "--feature_level", "<best.feat_level>",
                "--feature_label", "graph_hpo_best",
                "--checkpoint_out", str(paths["train_ckpt"]),
                "--threshold_out", str(paths["train_thr"]),
                "--log_out", str(paths["train_log"]),
            ],
            "threshold_comparison": [
                py, "scripts/threshold_comparison.py",
            ],
            "generalization_template": [
                py, "scripts/eval_generalization.py",
                "--checkpoint", str(paths["train_ckpt"]),
                "--thresholds", str(paths["train_thr"]),
                "--mlb", "Important Files/mlb_public_v1.pkl",
                "--taxon_parquet", "artifacts/generalization/mammal_embeddings_v3.parquet",
                "--taxon_name", "mammals_graph_hpo",
                "--obo", "go-basic.obo",
                "--out", str(paths["generalization_json"]),
            ],
        },
    }
|
|
|
|
def methodology_markdown(payload: dict) -> str:
    """Render the methodology payload as a human-readable markdown document."""
    obj = payload["objective"]
    hpo = payload["graph_hpo"]
    out: list[str] = [
        "# Graph HPO Training Sequence",
        "",
        "## Objective",
        f"- `{obj['formula']}`",
        f"- `alpha = {obj['alpha']}`",
        "",
        "## Graph-Aware HPO",
        f"- Sampler: `{hpo['sampler']}`",
        f"- Multivariate: `{hpo['multivariate']}`",
        f"- Grouped: `{hpo['group']}`",
        f"- Pruner: `{hpo['pruner']}`",
        "",
        "## Feature Levels",
    ]
    out.extend(f"- `{name}` -> `{dim}` dims" for name, dim in payload["feature_levels"].items())
    out.extend(["", "## Stages"])
    for stage in payload["stages"]:
        out.append(f"- `{stage['name']}`: {stage['purpose']}")
        out.append(f"  Output: `{stage['output']}`")
    out.extend(["", "## Notes"])
    out.extend(f"- {note}" for note in payload["notes"])
    out.extend(["", "## Stage Commands"])
    for label, argv in payload["stage_commands"].items():
        out.append(f"- `{label}`")
        # Shell-quote so the command can be copy-pasted back into a terminal.
        out.append(f"  `{shlex.join(argv)}`")
    out.append("")  # trailing newline after join
    return "\n".join(out)
|
|
|
|
def build_commands(paths: dict[str, Path], args: argparse.Namespace) -> dict[str, list[str]]:
    """Assemble the runnable stage commands (currently only the HPO stage)."""
    cmd: list[str] = [python_bin(), str(BASE / "scripts" / "hpo.py")]
    cmd += ["--mammal", "artifacts/generalization/mammal_full_v1.parquet"]
    # Numeric CLI options, stringified in the order hpo.py expects them.
    for flag, value in (
        ("--n_trials", args.n_trials),
        ("--epochs", args.hpo_epochs),
        ("--patience", args.hpo_patience),
        ("--alpha", args.alpha),
        ("--startup_trials", args.startup_trials),
        ("--warmup_steps", args.warmup_steps),
    ):
        cmd += [flag, str(value)]
    cmd += ["--multivariate_tpe", "--group_tpe", "--out", str(paths["hpo_json"])]
    return {"hpo": cmd}
|
|
|
|
def run(cmd: list[str]) -> None:
    """Echo *cmd* shell-quoted, then execute it from BASE; raises CalledProcessError on failure."""
    quoted = shlex.join(cmd)
    print("$", quoted)
    # cwd=BASE so the relative paths baked into the commands resolve.
    subprocess.run(cmd, cwd=BASE, check=True)
|
|
|
|
def load_json(path: Path):
    """Load JSON from *path*, returning None if the file does not exist.

    Malformed JSON still raises json.JSONDecodeError, so a corrupt artifact
    is surfaced rather than silently treated as missing.
    """
    # EAFP: avoids the exists()/open() race and the redundant stat call of the
    # check-then-open pattern.
    try:
        with open(path) as f:
            return json.load(f)
    except FileNotFoundError:
        return None
|
|
|
|
def write_summary(paths: dict[str, Path], payload: dict) -> None:
    """Aggregate the methodology plus any already-produced stage metrics into one JSON file."""
    # Fall back to the shared generalization artifact if this pipeline's own
    # result is missing or empty (original `or` semantics preserved).
    generalization = (
        load_json(paths["generalization_json"])
        or load_json(GENERALIZATION_ART / "generalization_results.json")
    )
    summary = {
        "methodology": payload,
        "artifacts": {name: str(p) for name, p in paths.items()},
        "existing_metrics": {
            "hpo": load_json(paths["hpo_json"]),
            "threshold_comparison": load_json(ART / "threshold_comparison_results.json"),
            "generalization": generalization,
        },
    }
    paths["summary_json"].write_text(json.dumps(summary, indent=2))
|
|
|
|
def main() -> None:
    """CLI entry point: write methodology artifacts, optionally run HPO, then summarize."""
    parser = argparse.ArgumentParser(description="Set up / run graph-aware HPO training pipeline")
    parser.add_argument("--n_trials", type=int, default=40)
    parser.add_argument("--hpo_epochs", type=int, default=20)
    parser.add_argument("--hpo_patience", type=int, default=6)
    parser.add_argument("--alpha", type=float, default=0.6)
    parser.add_argument("--startup_trials", type=int, default=5)
    parser.add_argument("--warmup_steps", type=int, default=5)
    parser.add_argument("--run_hpo", action="store_true")
    parser.add_argument("--smoke", action="store_true",
                        help="Run a minimal 1-trial / 1-epoch HPO smoke test")
    args = parser.parse_args()

    GRAPH_ART.mkdir(parents=True, exist_ok=True)
    paths = build_paths()
    payload = methodology_payload(paths, args)

    # Methodology is always written, even when no stage is executed.
    paths["methodology_json"].write_text(json.dumps(payload, indent=2))
    paths["methodology_md"].write_text(methodology_markdown(payload))

    cmds = build_commands(paths, args)
    if args.smoke:
        # Shrink the real HPO command to a 1-trial/1-epoch run with its own
        # output file, by patching the value slot after each flag in place.
        smoke_cmd = list(cmds["hpo"])
        overrides = {
            "--n_trials": "1",
            "--epochs": "1",
            "--patience": "1",
            "--out": str(GRAPH_ART / "hpo_smoke.json"),
        }
        for flag, value in overrides.items():
            smoke_cmd[smoke_cmd.index(flag) + 1] = value
        run(smoke_cmd)
    elif args.run_hpo:
        run(cmds["hpo"])

    write_summary(paths, payload)
    print(f"Methodology JSON: {paths['methodology_json']}")
    print(f"Methodology MD: {paths['methodology_md']}")
    print(f"Summary JSON: {paths['summary_json']}")
|
|
|
|
# Allow use both as an importable module and as a standalone CLI script.
if __name__ == "__main__":
    main()
|
|