#!/usr/bin/env python
"""Local preflight: validate every component the H200 training job touches
WITHOUT spending GPU time.

Each test prints PASS/FAIL with a short reason. Run::

    python scripts/preflight_check.py

The script exits non-zero if any required test fails. Optional tests
(network/Hub) print SKIP if HF_TOKEN is not set or the env Space is down.
"""
from __future__ import annotations

import json
import os
import sys
import tempfile
import traceback
from pathlib import Path
from typing import Callable

# Repository root is the parent of the directory holding this script; it is
# put on sys.path so the in-repo ``forgeenv`` package can be imported.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

PASS = "[PASS]"
FAIL = "[FAIL]"
SKIP = "[SKIP]"

# One (tag, label, detail) tuple per executed check, in execution order.
_results: list[tuple[str, str, str]] = []


class _Skip(Exception):
    """Raised by a check to report itself as skipped rather than failed."""


def _run(label: str, fn: Callable[[], str | None], required: bool = True) -> None:
    """Execute one check, print its outcome and record it in ``_results``.

    Args:
        label: Human-readable name printed next to the status tag.
        fn: Zero-argument check. Its return value (if any) is shown as the
            detail text; raising ``_Skip`` marks the check as skipped.
        required: When False, any other exception is downgraded from FAIL to
            SKIP and no traceback is printed.
    """
    try:
        detail = fn() or ""
        _results.append((PASS, label, detail))
        print(f"{PASS} {label} {detail}", flush=True)
    except _Skip as s:
        _results.append((SKIP, label, str(s)))
        print(f"{SKIP} {label} {s}", flush=True)
    except Exception as e:  # noqa: BLE001
        # Deliberately broad: a preflight must survive any failure of a
        # single check and keep going.
        tag = FAIL if required else SKIP
        _results.append((tag, label, f"{type(e).__name__}: {e}"))
        print(f"{tag} {label} {type(e).__name__}: {e}", flush=True)
        if required:
            traceback.print_exc()


def _load_jsonl(path: Path, limit: int | None = None) -> list[dict]:
    """Parse the non-blank lines of a JSONL file.

    Args:
        path: File to read (UTF-8).
        limit: If given, stop after this many rows; later lines are not parsed.

    Returns:
        The decoded rows, in file order.
    """
    rows: list[dict] = []
    for line in path.read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue
        if limit is not None and len(rows) >= limit:
            break
        rows.append(json.loads(line))
    return rows


def t1_imports() -> str:
    """Import every third-party and in-repo module the checks rely on."""
    import forgeenv  # noqa: F401
    import trl  # noqa: F401
    import peft  # noqa: F401
    import datasets  # noqa: F401
    import transformers  # noqa: F401
    import accelerate  # noqa: F401
    from forgeenv.training.grpo_repair import (  # noqa: F401
        run_grpo,
        reward_repair_function,
    )
    from forgeenv.training.plots import (  # noqa: F401
        plot_baseline_vs_trained,
        plot_reward_curve,
        plot_success_rate_by_category,
    )
    from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction  # noqa: F401
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff  # noqa: F401
    from forgeenv.env.forge_environment import ForgeEnvironment  # noqa: F401
    from forgeenv.roles.repair_agent import extract_diff  # noqa: F401
    from forgeenv.tasks.task_sampler import TaskSampler  # noqa: F401

    return f"trl={trl.__version__} transformers={transformers.__version__}"


def t1b_openenv_job_extras() -> str:
    """On HF Jobs we ``pip install openenv-core --no-deps`` then add the packages
    openenv lists as requirements so ``import openenv.core`` works."""
    import fastmcp  # noqa: F401

    return "fastmcp (required by openenv.core.env_server on import)"


def t2_dataset_load_and_format() -> str:
    """Load the warmstart JSONL through ``datasets`` and sanity-check row 0.

    Raises:
        FileNotFoundError: The JSONL file is absent.
        ValueError: Fewer than 10 rows, or the first row lacks one of the
            system/user/assistant roles.
        KeyError: The first row has no (or an empty) ``messages`` field.
    """
    import datasets as ds

    p = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl"
    if not p.exists():
        raise FileNotFoundError(p)
    sft_ds = ds.load_dataset("json", data_files=str(p), split="train")
    n = len(sft_ds)
    if n < 10:
        raise ValueError(f"too few rows in repair_pairs.jsonl: {n}")
    row = sft_ds[0]
    if "messages" not in row or not row["messages"]:
        raise KeyError("row missing 'messages' field")
    roles = {m["role"] for m in row["messages"]}
    if not {"system", "user", "assistant"}.issubset(roles):
        raise ValueError(f"unexpected role set: {roles}")
    return f"rows={n} roles={sorted(roles)}"


def t3_trl_configs_accept_our_kwargs() -> str:
    """Validate every kwarg name the job passes is accepted by the current TRL
    Config classes.

    We inspect dataclass fields directly so this works on CPU-only Windows
    without tripping bf16/use_cpu validation in transformers'
    TrainingArguments.__post_init__.

    Raises:
        TypeError: A kwarg name is not a field of SFTConfig / GRPOConfig.
    """
    import dataclasses

    from trl import GRPOConfig, SFTConfig

    sft_kwargs = {
        "output_dir": "/tmp/forge_sft",
        "max_steps": 10,
        "per_device_train_batch_size": 4,
        "gradient_accumulation_steps": 4,
        "learning_rate": 2e-4,
        "logging_steps": 25,
        "save_steps": 250,
        "bf16": True,
        "fp16": False,
        "max_length": 2048,
        "packing": True,
        "packing_strategy": "bfd",
        "report_to": [],
    }
    grpo_kwargs = {
        "output_dir": "/tmp/forge_grpo",
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "learning_rate": 5e-6,
        "max_steps": 5,
        "num_generations": 4,
        "max_completion_length": 1024,
        "logging_steps": 5,
        "save_steps": 50,
        "save_total_limit": 2,
        "seed": 0,
        "report_to": "none",
        "beta": 0.04,
    }

    def _field_names(cls) -> set[str]:
        # Union of dataclass field names across the whole MRO.
        names: set[str] = set()
        for c in cls.__mro__:
            if dataclasses.is_dataclass(c):
                names.update(f.name for f in dataclasses.fields(c))
        return names

    sft_fields = _field_names(SFTConfig)
    missing_sft = [k for k in sft_kwargs if k not in sft_fields]
    if missing_sft:
        raise TypeError(f"SFTConfig missing fields: {missing_sft}")

    grpo_fields = _field_names(GRPOConfig)
    missing_grpo = [k for k in grpo_kwargs if k not in grpo_fields]
    if missing_grpo:
        raise TypeError(f"GRPOConfig missing fields: {missing_grpo}")

    # Best-effort: try actually instantiating with use_cpu=True so even
    # __post_init__ runs cleanly under our preflight conditions.
    # The CPU overrides are merged into the dict rather than passed as extra
    # keywords: ``bf16`` is already a key of sft_kwargs, and supplying it a
    # second time raises "got multiple values for keyword argument" before
    # the config class is ever reached.
    try:
        SFTConfig(**{**sft_kwargs, "use_cpu": True, "bf16": False})
        GRPOConfig(**{**grpo_kwargs, "use_cpu": True})
        instantiated = "instantiated OK"
    except Exception as e:  # noqa: BLE001
        instantiated = f"field-check OK; instantiation skipped ({type(e).__name__})"

    return (
        f"SFT/GRPO kwargs all valid; sft_fields={len(sft_fields)} "
        f"grpo_fields={len(grpo_fields)}; {instantiated}"
    )


def t4_reward_function_returns_float() -> str:
    """Call the GRPO reward function once with a hand-written diff completion.

    Raises:
        RuntimeError: The task sampler exposes no tasks.
        ValueError: The reward list does not have exactly one element.
        TypeError: The reward is not a ``float``.
    """
    from forgeenv.training.grpo_repair import reward_repair_function
    from forgeenv.tasks.task_sampler import TaskSampler

    sampler = TaskSampler()
    if not sampler.tasks:
        raise RuntimeError("TaskSampler has no tasks")
    task_id = sampler.tasks[0].task_id
    broken = "x = 1\nprint(x)\n"
    fake_completion = (
        "--- a/train.py\n"
        "+++ b/train.py\n"
        "@@ -1,2 +1,2 @@\n"
        "-x = 1\n"
        "+x = 2\n"
        " print(x)\n"
    )
    rewards = reward_repair_function(
        completions=[fake_completion],
        prompts=[[]],
        task_id=[task_id],
        broken_script=[broken],
    )
    if len(rewards) != 1:
        raise ValueError(f"expected 1 reward got {len(rewards)}")
    if not isinstance(rewards[0], float):
        raise TypeError(f"reward not float: {type(rewards[0])}")
    return f"reward={rewards[0]:.3f} (single fake completion)"


def t5_diff_utils_roundtrip() -> str:
    """Make a diff, wrap it in a fenced block with prose, extract and re-apply it."""
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff
    from forgeenv.roles.repair_agent import extract_diff

    a = "x = 1\nprint(x)\n"
    b = "x = 2\nprint(x)\n"
    d = make_unified_diff(a, b)
    if not d.strip():
        raise ValueError("make_unified_diff returned empty")
    blob = "Some thinking...\n```diff\n" + d + "\n```\nmore prose"
    extracted = extract_diff(blob)
    if not extracted.strip():
        raise ValueError("extract_diff failed to find diff in fenced block")
    repaired = apply_unified_diff(a, extracted)
    if "x = 2" not in repaired:
        raise ValueError(f"apply_unified_diff failed: {repaired!r}")
    return f"diff_len={len(d)} extract+apply OK"


def t6_live_env_health() -> str:
    """GET the env Space's ``/health`` endpoint.

    A transport-level failure is reported as a skip; an HTTP status >= 400
    is an error.
    """
    import requests

    user = os.environ.get("HF_USERNAME", "akhiilll")
    url = f"https://{user}-forgeenv.hf.space/health"
    try:
        r = requests.get(url, timeout=15)
    except Exception as e:  # noqa: BLE001
        # Chain the cause so the original network error stays inspectable.
        raise _Skip(f"network: {e}") from e
    if r.status_code >= 400:
        raise RuntimeError(f"{url} -> {r.status_code} {r.text[:80]}")
    return f"{r.status_code} {r.text[:60]!r}"


def t7_source_repo_exists() -> str:
    """Check the ``<user>/forgeenv-source`` Hub repo contains the training script.

    Skipped when HF_TOKEN is not set.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise _Skip("HF_TOKEN not set")
    from huggingface_hub import HfApi

    api = HfApi()
    user = os.environ.get("HF_USERNAME", "akhiilll")
    repo_id = f"{user}/forgeenv-source"
    files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=token)
    needed = "scripts/jobs/train_repair_agent.py"
    if needed not in files:
        raise FileNotFoundError(f"{needed} missing from {repo_id} (files: {len(files)})")
    return f"{repo_id} has {len(files)} files incl. train_repair_agent.py"


def t8_qwen_tokenizer_loads() -> str:
    """Load the base-model tokenizer and render a three-message chat.

    The rendered text must contain the ``<|im_start|>`` marker and the user
    message content.
    """
    base = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct")
    token = os.environ.get("HF_TOKEN")
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(base, token=token, trust_remote_code=False)
    msgs = [
        {"role": "system", "content": "you are a repair agent"},
        {"role": "user", "content": "fix this"},
        {"role": "assistant", "content": "--- a/train.py\n+++ b/train.py\n"},
    ]
    text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
    if "<|im_start|>" not in text:
        raise ValueError("ChatML tokens missing from rendered template")
    if "fix this" not in text:
        raise ValueError("user content not in rendered template")
    return f"{base} chat_template renders ChatML ({len(text)} chars)"


def t9_hfapi_auth_and_namespace() -> str:
    """Verify HF_TOKEN authenticates via ``whoami``.

    A mismatch between the token's user and HF_USERNAME is reported in the
    returned detail as ``WARN: ...``; the check still counts as passed.
    Skipped when HF_TOKEN is not set.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise _Skip("HF_TOKEN not set")
    from huggingface_hub import HfApi

    api = HfApi()
    info = api.whoami(token=token)
    user = info.get("name") or info.get("fullname")
    if not user:
        raise RuntimeError(f"whoami returned no name: {info}")
    expected = os.environ.get("HF_USERNAME", "akhiilll")
    if user != expected:
        return f"WARN: token user={user} but HF_USERNAME={expected}"
    return f"authed as {user}"


def t10_find_trainer_state() -> str:
    """Exercise a local copy of the trainer-state finder on a synthetic layout.

    A ``checkpoint-80/trainer_state.json`` is written to a temp dir and must
    be located and loaded back with both log entries.
    """
    # NOTE(review): nothing below imports from scripts/jobs; presumably kept
    # for parity with the job environment -- confirm before removing.
    sys.path.insert(0, str(REPO_ROOT / "scripts" / "jobs"))
    with tempfile.TemporaryDirectory() as td:
        td_p = Path(td)
        ckpt = td_p / "checkpoint-80"
        ckpt.mkdir()
        state = {
            "log_history": [
                {"step": 5, "rewards/reward_repair_function/mean": 0.12},
                {"step": 10, "rewards/reward_repair_function/mean": 0.34},
            ]
        }
        (ckpt / "trainer_state.json").write_text(json.dumps(state))

        from importlib import util as _util

        spec = _util.spec_from_file_location(
            "_train_mod", REPO_ROOT / "scripts" / "jobs" / "train_repair_agent.py"
        )
        if spec is None or spec.loader is None:
            raise RuntimeError("can't spec the training script")
        # Don't actually load the module (it has top-level CUDA/HF effects).
        # Re-implement the same finder here from source.
        # The script uses: prefer GRPO_DIR/trainer_state.json, else newest checkpoint-*.
        direct = td_p / "trainer_state.json"
        if direct.exists():
            found = direct
        else:
            ckpts = sorted(
                (p for p in td_p.glob("checkpoint-*") if (p / "trainer_state.json").exists()),
                key=lambda p: int(p.name.split("-")[-1]),
            )
            found = (ckpts[-1] / "trainer_state.json") if ckpts else None
        if found is None or not found.exists():
            raise RuntimeError("finder did not locate the synthesized state")
        loaded = json.loads(found.read_text())
        if len(loaded["log_history"]) != 2:
            raise RuntimeError("finder loaded wrong file")
    return "checkpoint-N/trainer_state.json discoverable"


def t11_warmstart_rows_all_valid() -> str:
    """Walk every warmstart row and check every row has system+user+assistant.

    A row is bad when its first three roles are not exactly
    system/user/assistant, or when any message content is not a non-blank
    string. Each bad row is counted once.

    Raises:
        ValueError: The file has no rows, or at least one row is bad.
    """
    p = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl"
    rows = _load_jsonl(p)
    if not rows:
        # Without this an empty file would pass vacuously.
        raise ValueError("repair_pairs.jsonl contains no rows")
    bad = []
    for i, row in enumerate(rows):
        msgs = row.get("messages") or []
        roles = [m.get("role") for m in msgs]
        if roles[:3] != ["system", "user", "assistant"]:
            bad.append((i, roles))
            # One entry per row keeps len(bad) an honest row count.
            continue
        if any(
            not isinstance(m.get("content"), str) or not m["content"].strip()
            for m in msgs
        ):
            bad.append((i, "empty content"))
    if bad:
        raise ValueError(f"{len(bad)} bad rows; first: {bad[0]}")
    return f"all {len(rows)} rows have system/user/assistant with non-empty content"


def t12_tokenizer_renders_real_rows() -> str:
    """Render the chat template on the FIRST 5 real rows. Mirrors the SFT map
    step (`_format_chat`) the job runs after dataset.load_dataset.

    Raises:
        ValueError: The file has no rows, or a rendered row exceeds 4096 tokens.
    """
    from transformers import AutoTokenizer

    base = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct")
    token = os.environ.get("HF_TOKEN")
    tok = AutoTokenizer.from_pretrained(base, token=token)
    p = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl"
    # Only the first five non-blank lines are parsed.
    rows = _load_jsonl(p, limit=5)
    if not rows:
        raise ValueError("repair_pairs.jsonl contains no rows to render")
    lengths = []
    for r in rows:
        text = tok.apply_chat_template(
            r["messages"], tokenize=False, add_generation_prompt=False
        )
        toks = tok(text, return_tensors=None)["input_ids"]
        lengths.append(len(toks))
    # NOTE(review): the hard limit here is 4096 tokens while the message below
    # cites a 2048 budget -- confirm which threshold is intended.
    if max(lengths) > 4096:
        raise ValueError(f"row tokens > 4096 (would need bigger max_length): {lengths}")
    return f"{len(lengths)} rows render OK; token lengths={lengths} (max_length=2048 budget)"


def t13_baseline_drift_generator_each_category() -> str:
    """Walk every primitive category listed below and confirm the baseline
    drift generator returns a spec with ``primitive_type`` and ``params``."""
    from forgeenv.roles.drift_generator import BaselineDriftGenerator

    gen = BaselineDriftGenerator()
    cats = ["api_drift", "type_signature", "import_path", "config_schema", "deprecated_kwarg"]
    script = (
        "from transformers import AutoTokenizer, Trainer\n"
        "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n"
        "trainer = Trainer(model=None)\n"
        "trainer.train()\n"
    )
    out: list[str] = []
    for c in cats:
        spec = gen.propose(target_category=c, script=script)
        if "primitive_type" not in spec or "params" not in spec:
            raise ValueError(f"bad spec for {c}: {spec}")
        out.append(f"{c}->{spec['primitive_type']}")
    return "; ".join(out)


def t14_forge_environment_reset_step() -> str:
    """End-to-end env smoke: reset() then step() with a real BreakageAction.

    Catches signature/serialisation drift between forgeenv and openenv.
    """
    from forgeenv.env.actions import BreakageAction, ForgeAction
    from forgeenv.env.forge_environment import ForgeEnvironment
    from forgeenv.roles.drift_generator import BaselineDriftGenerator

    env = ForgeEnvironment(seed=0)
    obs = env.reset(difficulty="easy")
    if not getattr(obs, "script_content", "").strip():
        raise ValueError("reset() returned empty script_content")

    spec = BaselineDriftGenerator().propose(
        target_category=getattr(obs, "target_category", "api_drift"),
        script=obs.script_content,
    )
    obs2 = env.step(
        ForgeAction(
            breakage=BreakageAction(
                primitive_type=spec["primitive_type"], params=spec["params"]
            )
        )
    )
    if not getattr(obs2, "script_content", "").strip():
        raise ValueError("step() returned empty script_content")
    return f"reset+breakage step OK (task={obs.task_id}, primitive={spec['primitive_type']})"


def t15_build_repair_prompt() -> str:
    """Run the exact `_build_repair_prompt` the GRPO loop calls per episode."""
    from forgeenv.env.forge_environment import ForgeEnvironment
    from forgeenv.training.grpo_repair import _build_repair_prompt

    env = ForgeEnvironment(seed=0)
    ex = _build_repair_prompt(env)
    for k in ("prompt", "task_id", "primitive_type", "broken_script"):
        if k not in ex:
            raise KeyError(f"missing {k} in built example: {list(ex)}")
    if not isinstance(ex["prompt"], list) or len(ex["prompt"]) < 2:
        raise ValueError("prompt is not a chat-format list")
    if not ex["broken_script"].strip():
        raise ValueError("empty broken_script")
    return f"task={ex['task_id']} primitive={ex['primitive_type']} prompt_msgs={len(ex['prompt'])}"


def t16_rollout_one_episode() -> str:
    """Drive the full baseline rollout — drift -> repair -> reward."""
    from forgeenv.env.forge_environment import ForgeEnvironment
    from forgeenv.training.rollout import rollout_one_episode

    env = ForgeEnvironment(seed=0)
    res = rollout_one_episode(env)
    if not hasattr(res, "visible_reward"):
        raise AttributeError("rollout result missing visible_reward")
    return (
        f"reward={res.visible_reward:.3f} primitive={getattr(res,'primitive_type','?')}"
    )


def t17_plots_render() -> str:
    """Run all 3 plot helpers on synthetic data and check files appear.

    Raises:
        FileNotFoundError: A helper did not write its target PNG.
        ValueError: A written PNG is smaller than 1000 bytes.
    """
    from forgeenv.training.plots import (
        plot_baseline_vs_trained,
        plot_reward_curve,
        plot_success_rate_by_category,
    )

    with tempfile.TemporaryDirectory() as td:
        td_p = Path(td)
        plot_reward_curve([0.1, 0.2, 0.3, 0.4], str(td_p / "rc.png"))
        plot_baseline_vs_trained(
            [0.1, 0.2, 0.15], [0.4, 0.5, 0.6], str(td_p / "bvt.png")
        )
        plot_success_rate_by_category(
            {"api_drift": [True, False, True], "type_signature": [False, True]},
            str(td_p / "succ.png"),
        )
        # Check the expected files explicitly: a size test over whatever the
        # glob finds would pass vacuously if nothing had been written.
        missing = [n for n in ("rc.png", "bvt.png", "succ.png") if not (td_p / n).exists()]
        if missing:
            raise FileNotFoundError(f"plot helper(s) wrote no file: {missing}")
        sizes = {p.name: p.stat().st_size for p in td_p.glob("*.png")}
        if any(s < 1000 for s in sizes.values()):
            raise ValueError(f"plot file too small (<1KB): {sizes}")
    return f"3 plots rendered: {sizes}"


def t18_simulation_executor_and_reward() -> str:
    """Run the SimulationExecutor + visible reward on a real corpus task, once
    with the unmodified script (success) and once with junk (fail)."""
    from forgeenv.sandbox.simulation_mode import SimulationExecutor
    from forgeenv.tasks.task_sampler import TaskSampler
    from forgeenv.verifier.visible_verifier import compute_visible_reward

    sampler = TaskSampler()
    if not sampler.tasks:
        raise RuntimeError("no tasks in TaskSampler")
    task = sampler.tasks[0]
    executor = SimulationExecutor()

    # Path 1: original (canonical) script — should have non-negative reward.
    corpus_dir = REPO_ROOT / "forgeenv" / "tasks" / "seed_corpus"
    canonical = corpus_dir / f"{task.task_id}.py"
    if not canonical.exists():
        # fall back: any seed file
        candidates = list(corpus_dir.glob("*.py"))
        if not candidates:
            raise FileNotFoundError("no seed corpus files")
        canonical = candidates[0]
    script = canonical.read_text(encoding="utf-8")
    res = executor.execute(script, task)
    res.script_content = script
    r_ok, _ = compute_visible_reward(res, task)

    # Path 2: gibberish — should clearly be lower.
    res2 = executor.execute("not_a_real_python_file = ", task)
    res2.script_content = "not_a_real_python_file = "
    r_bad, _ = compute_visible_reward(res2, task)

    if r_ok < r_bad:
        raise AssertionError(f"canonical reward {r_ok} should be >= gibberish {r_bad}")
    return f"r_canonical={r_ok:.3f} r_gibberish={r_bad:.3f} (delta {r_ok-r_bad:.3f})"


def t19_repair_library_artifact() -> str:
    """Confirm the artifact the job copies into the final adapter exists.

    The file must parse as a non-empty JSON list or object.
    """
    p = REPO_ROOT / "artifacts" / "repair_library.json"
    if not p.exists():
        raise FileNotFoundError(p)
    data = json.loads(p.read_text(encoding="utf-8"))
    if not isinstance(data, (list, dict)) or not data:
        raise ValueError("repair_library.json is empty")
    return f"repair_library.json has {len(data)} entries"


def t20_hub_upload_roundtrip() -> str:
    """Real round-trip on a tiny scratch repo so we know `upload_folder` works
    end-to-end (network + auth + private flag) before the GPU run.

    Skipped when HF_TOKEN is not set.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise _Skip("HF_TOKEN not set")
    from huggingface_hub import HfApi

    api = HfApi()
    user = os.environ.get("HF_USERNAME", "akhiilll")
    repo_id = f"{user}/forgeenv-preflight-scratch"
    api.create_repo(repo_id=repo_id, repo_type="model", token=token,
                    exist_ok=True, private=True)
    with tempfile.TemporaryDirectory() as td:
        td_p = Path(td)
        (td_p / "ok.txt").write_text("preflight OK", encoding="utf-8")
        api.upload_folder(
            folder_path=str(td_p),
            repo_id=repo_id,
            repo_type="model",
            token=token,
            commit_message="preflight roundtrip",
        )
    files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=token)
    if "ok.txt" not in files:
        raise RuntimeError(f"upload roundtrip failed; files={files}")
    return f"{repo_id} round-trip OK ({len(files)} files)"


def main() -> int:
    """Run every check in order and print a summary.

    Returns:
        0 when no check was tagged FAIL, otherwise 1.
    """
    print(f"\n=== ForgeEnv preflight (repo: {REPO_ROOT}) ===\n", flush=True)

    # (label, check, required) -- optional checks degrade to SKIP on failure.
    checks: list[tuple[str, Callable[[], str | None], bool]] = [
        ("01 imports", t1_imports, True),
        ("01b openenv extras (job: after --no-deps)", t1b_openenv_job_extras, True),
        ("02 dataset load + format", t2_dataset_load_and_format, True),
        ("03 TRL configs (SFT/GRPO) accept kwargs", t3_trl_configs_accept_our_kwargs, True),
        ("04 reward fn returns float", t4_reward_function_returns_float, True),
        ("05 diff utils round-trip", t5_diff_utils_roundtrip, True),
        ("06 live env /health", t6_live_env_health, False),
        ("07 forgeenv-source repo on Hub", t7_source_repo_exists, False),
        ("08 Qwen tokenizer + ChatML", t8_qwen_tokenizer_loads, True),
        ("09 HfApi auth", t9_hfapi_auth_and_namespace, False),
        ("10 _find_trainer_state logic", t10_find_trainer_state, True),
        ("11 every warmstart row valid", t11_warmstart_rows_all_valid, True),
        ("12 tokenizer renders real rows", t12_tokenizer_renders_real_rows, True),
        ("13 BaselineDriftGenerator each category", t13_baseline_drift_generator_each_category, True),
        ("14 ForgeEnvironment reset+step", t14_forge_environment_reset_step, True),
        ("15 _build_repair_prompt runs", t15_build_repair_prompt, True),
        ("16 rollout_one_episode runs", t16_rollout_one_episode, True),
        ("17 plots render to PNG", t17_plots_render, True),
        ("18 SimulationExecutor + reward", t18_simulation_executor_and_reward, True),
        ("19 repair_library.json artifact", t19_repair_library_artifact, True),
        ("20 Hub upload round-trip", t20_hub_upload_roundtrip, False),
    ]
    for label, fn, required in checks:
        _run(label, fn, required=required)

    print("\n=== Summary ===")
    n_pass = sum(1 for r in _results if r[0] == PASS)
    n_fail = sum(1 for r in _results if r[0] == FAIL)
    n_skip = sum(1 for r in _results if r[0] == SKIP)
    for tag, label, _detail in _results:
        print(f"{tag} {label}")
    print(f"\n{n_pass} passed, {n_fail} failed, {n_skip} skipped")
    return 0 if n_fail == 0 else 1


if __name__ == "__main__":
    sys.exit(main())