forgeenv-source / scripts /preflight_check.py
akhiilll's picture
forgeenv source snapshot for training job
a15535e verified
#!/usr/bin/env python
"""Local preflight: validate every component the H200 training job touches
WITHOUT spending GPU time. Each test prints PASS/FAIL with a short reason.
Run::
python scripts/preflight_check.py
The script exits non-zero if any required test fails. Optional tests
(network/Hub) print SKIP if HF_TOKEN is not set or the env Space is down.
"""
from __future__ import annotations
import json
import os
import sys
import tempfile
import traceback
from pathlib import Path
from typing import Callable
# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Put the repo root first on sys.path so the checks below can import from it.
sys.path.insert(0, str(REPO_ROOT))
# Status tags used both in the per-check output lines and in the summary.
PASS = "[PASS]"
FAIL = "[FAIL]"
SKIP = "[SKIP]"
# One (tag, label, detail) tuple per executed check, in execution order.
_results: list[tuple[str, str, str]] = []
def _run(label: str, fn: Callable[[], str | None], required: bool = True) -> None:
    """Execute one check, record its outcome in ``_results`` and print it.

    Args:
        label: Human-readable name shown in the status line and the summary.
        fn: The check itself; its (optional) return value is the detail text.
        required: When false, an unexpected exception is reported as SKIP
            instead of FAIL and no traceback is printed.
    """

    def _emit(tag: str, detail: str) -> None:
        # Record first, then print, so the summary reflects every outcome.
        _results.append((tag, label, detail))
        print(f"{tag} {label} {detail}", flush=True)

    try:
        _emit(PASS, fn() or "")
    except _Skip as skipped:
        # The check opted out on its own (e.g. missing credentials).
        _emit(SKIP, str(skipped))
    except Exception as exc:  # noqa: BLE001
        _emit(FAIL if required else SKIP, f"{type(exc).__name__}: {exc}")
        if required:
            traceback.print_exc()
class _Skip(Exception):
    """Raised by a check to signal that it should be reported as skipped."""
def t1_imports() -> str:
    """Import every package and project symbol the checks rely on.

    Any ImportError propagates to the caller. On success the installed
    ``trl`` and ``transformers`` versions are reported.
    """
    import forgeenv  # noqa: F401
    import trl  # noqa: F401
    import peft  # noqa: F401
    import datasets  # noqa: F401
    import transformers  # noqa: F401
    import accelerate  # noqa: F401
    from forgeenv.training.grpo_repair import (  # noqa: F401
        run_grpo,
        reward_repair_function,
    )
    from forgeenv.training.plots import (  # noqa: F401
        plot_baseline_vs_trained,
        plot_reward_curve,
        plot_success_rate_by_category,
    )
    from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction  # noqa: F401
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff  # noqa: F401
    from forgeenv.env.forge_environment import ForgeEnvironment  # noqa: F401
    from forgeenv.roles.repair_agent import extract_diff  # noqa: F401
    from forgeenv.tasks.task_sampler import TaskSampler  # noqa: F401

    trl_version = trl.__version__
    transformers_version = transformers.__version__
    return f"trl={trl_version} transformers={transformers_version}"
def t1b_openenv_job_extras() -> str:
    """Check that ``fastmcp`` is importable.

    On HF Jobs ``openenv-core`` is installed with ``--no-deps`` and its
    requirements are added separately so that ``import openenv.core`` works;
    this check covers the ``fastmcp`` requirement.
    """
    import fastmcp  # noqa: F401

    summary = "fastmcp (required by openenv.core.env_server on import)"
    return summary
def t2_dataset_load_and_format() -> str:
    """Load the warm-start JSONL with ``datasets`` and sanity-check row 0.

    Raises:
        FileNotFoundError: the JSONL file is absent.
        ValueError: fewer than 10 rows, or the first row lacks one of the
            system/user/assistant roles.
        KeyError: the first row has no (or an empty) ``messages`` field.
    """
    import datasets as ds

    jsonl_path = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl"
    if not jsonl_path.exists():
        raise FileNotFoundError(jsonl_path)

    sft_ds = ds.load_dataset("json", data_files=str(jsonl_path), split="train")
    row_count = len(sft_ds)
    if row_count < 10:
        raise ValueError(f"too few rows in repair_pairs.jsonl: {row_count}")

    first_row = sft_ds[0]
    if not ("messages" in first_row and first_row["messages"]):
        raise KeyError("row missing 'messages' field")

    roles = {message["role"] for message in first_row["messages"]}
    expected_roles = {"system", "user", "assistant"}
    if not expected_roles <= roles:
        raise ValueError(f"unexpected role set: {roles}")
    return f"rows={row_count} roles={sorted(roles)}"
def t3_trl_configs_accept_our_kwargs() -> str:
    """Validate that every kwarg name the job passes is accepted by the
    current TRL config classes.

    Dataclass fields are inspected directly so this works on CPU-only
    Windows without tripping bf16/use_cpu validation in transformers'
    ``TrainingArguments.__post_init__``. Afterwards a best-effort
    instantiation is attempted with CPU-safe overrides; a failure there is
    reported in the result text but does not fail the check.

    Returns:
        A summary with the number of fields found on each config class and
        whether the instantiation attempt succeeded.

    Raises:
        TypeError: a kwarg is not a field of ``SFTConfig`` / ``GRPOConfig``.
    """
    import dataclasses
    from trl import GRPOConfig, SFTConfig

    sft_kwargs = {
        "output_dir": "/tmp/forge_sft",
        "max_steps": 10,
        "per_device_train_batch_size": 4,
        "gradient_accumulation_steps": 4,
        "learning_rate": 2e-4,
        "logging_steps": 25,
        "save_steps": 250,
        "bf16": True,
        "fp16": False,
        "max_length": 2048,
        "packing": True,
        "packing_strategy": "bfd",
        "report_to": [],
    }
    grpo_kwargs = {
        "output_dir": "/tmp/forge_grpo",
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "learning_rate": 5e-6,
        "max_steps": 5,
        "num_generations": 4,
        "max_completion_length": 1024,
        "logging_steps": 5,
        "save_steps": 50,
        "save_total_limit": 2,
        "seed": 0,
        "report_to": "none",
        "beta": 0.04,
    }

    def _field_names(cls) -> set[str]:
        # Collect dataclass field names from the class and all its bases.
        names: set[str] = set()
        for c in cls.__mro__:
            if dataclasses.is_dataclass(c):
                names.update(f.name for f in dataclasses.fields(c))
        return names

    sft_fields = _field_names(SFTConfig)
    missing_sft = [k for k in sft_kwargs if k not in sft_fields]
    if missing_sft:
        raise TypeError(f"SFTConfig missing fields: {missing_sft}")

    grpo_fields = _field_names(GRPOConfig)
    missing_grpo = [k for k in grpo_kwargs if k not in grpo_fields]
    if missing_grpo:
        raise TypeError(f"GRPOConfig missing fields: {missing_grpo}")

    # Best-effort: try actually instantiating with use_cpu=True so even
    # __post_init__ runs cleanly under our preflight conditions.
    # The overrides are merged into a fresh dict: passing ``bf16=False``
    # next to ``**sft_kwargs`` (which already holds "bf16") would raise
    # "got multiple values for keyword argument" before the config is built.
    try:
        SFTConfig(**{**sft_kwargs, "bf16": False, "use_cpu": True})
        GRPOConfig(**{**grpo_kwargs, "use_cpu": True})
        instantiated = "instantiated OK"
    except Exception as e:  # noqa: BLE001
        instantiated = f"field-check OK; instantiation skipped ({type(e).__name__})"

    return (
        f"SFT/GRPO kwargs all valid; sft_fields={len(sft_fields)} "
        f"grpo_fields={len(grpo_fields)}; {instantiated}"
    )
def t4_reward_function_returns_float() -> str:
    """Call the GRPO reward function once with a hand-written diff.

    Checks that exactly one reward comes back for one completion and that
    it is a ``float``.
    """
    from forgeenv.training.grpo_repair import reward_repair_function
    from forgeenv.tasks.task_sampler import TaskSampler

    sampler = TaskSampler()
    if not sampler.tasks:
        raise RuntimeError("TaskSampler has no tasks")
    first_task_id = sampler.tasks[0].task_id

    broken_source = "x = 1\nprint(x)\n"
    # The trailing empty element yields the final newline of the diff.
    fake_completion = "\n".join(
        [
            "--- a/train.py",
            "+++ b/train.py",
            "@@ -1,2 +1,2 @@",
            "-x = 1",
            "+x = 2",
            " print(x)",
            "",
        ]
    )

    rewards = reward_repair_function(
        completions=[fake_completion],
        prompts=[[]],
        task_id=[first_task_id],
        broken_script=[broken_source],
    )

    reward_count = len(rewards)
    if reward_count != 1:
        raise ValueError(f"expected 1 reward got {reward_count}")
    reward = rewards[0]
    if not isinstance(reward, float):
        raise TypeError(f"reward not float: {type(reward)}")
    return f"reward={reward:.3f} (single fake completion)"
def t5_diff_utils_roundtrip() -> str:
    """Round-trip a tiny edit: make a diff, extract it from prose, apply it.

    The diff is wrapped in a fenced ``diff`` block surrounded by free text
    before extraction.
    """
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff
    from forgeenv.roles.repair_agent import extract_diff

    before = "x = 1\nprint(x)\n"
    after = "x = 2\nprint(x)\n"

    diff_text = make_unified_diff(before, after)
    if not diff_text.strip():
        raise ValueError("make_unified_diff returned empty")

    blob = f"Some thinking...\n```diff\n{diff_text}\n```\nmore prose"
    extracted = extract_diff(blob)
    if not extracted.strip():
        raise ValueError("extract_diff failed to find diff in fenced block")

    repaired = apply_unified_diff(before, extracted)
    if "x = 2" not in repaired:
        raise ValueError(f"apply_unified_diff failed: {repaired!r}")
    return f"diff_len={len(diff_text)} extract+apply OK"
def t6_live_env_health() -> str:
    """GET the ``/health`` endpoint of the environment Space.

    A failure to perform the request at all is reported as a skip; an HTTP
    status of 400 or above raises ``RuntimeError``.
    """
    import requests

    hf_user = os.environ.get("HF_USERNAME", "akhiilll")
    url = f"https://{hf_user}-forgeenv.hf.space/health"
    try:
        response = requests.get(url, timeout=15)
    except Exception as e:  # noqa: BLE001
        raise _Skip(f"network: {e}")

    status = response.status_code
    if status >= 400:
        raise RuntimeError(f"{url} -> {status} {response.text[:80]}")
    return f"{status} {response.text[:60]!r}"
def t7_source_repo_exists() -> str:
    """Check that the source snapshot repo on the Hub holds the job script.

    Skips when ``HF_TOKEN`` is not set.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise _Skip("HF_TOKEN not set")

    from huggingface_hub import HfApi

    api = HfApi()
    hf_user = os.environ.get("HF_USERNAME", "akhiilll")
    repo_id = f"{hf_user}/forgeenv-source"
    files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=hf_token)

    needed = "scripts/jobs/train_repair_agent.py"
    if needed in files:
        return f"{repo_id} has {len(files)} files incl. train_repair_agent.py"
    raise FileNotFoundError(f"{needed} missing from {repo_id} (files: {len(files)})")
def t8_qwen_tokenizer_loads() -> str:
    """Load the base model's tokenizer and render a 3-message chat.

    The rendered text must contain the ``<|im_start|>`` marker and the
    user message.
    """
    base = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct")
    hf_token = os.environ.get("HF_TOKEN")

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(base, token=hf_token, trust_remote_code=False)
    conversation = [
        {"role": "system", "content": "you are a repair agent"},
        {"role": "user", "content": "fix this"},
        {"role": "assistant", "content": "--- a/train.py\n+++ b/train.py\n"},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )

    # (fragment that must be present, error raised when it is not)
    expectations = (
        ("<|im_start|>", "ChatML tokens missing from rendered template"),
        ("fix this", "user content not in rendered template"),
    )
    for fragment, problem in expectations:
        if fragment not in rendered:
            raise ValueError(problem)
    return f"{base} chat_template renders ChatML ({len(rendered)} chars)"
def t9_hfapi_auth_and_namespace() -> str:
    """Confirm the token authenticates and report which user it belongs to.

    Skips when ``HF_TOKEN`` is not set. A mismatch between the token's user
    and ``HF_USERNAME`` is only reported as a warning in the result text.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise _Skip("HF_TOKEN not set")

    from huggingface_hub import HfApi

    info = HfApi().whoami(token=hf_token)
    user = info.get("name") or info.get("fullname")
    if not user:
        raise RuntimeError(f"whoami returned no name: {info}")

    expected = os.environ.get("HF_USERNAME", "akhiilll")
    if user == expected:
        return f"authed as {user}"
    return f"WARN: token user={user} but HF_USERNAME={expected}"
def t10_find_trainer_state() -> str:
    """Exercise a local copy of the trainer-state lookup on a fake run dir.

    A temporary directory with a single ``checkpoint-80/trainer_state.json``
    is created; the lookup prefers a top-level ``trainer_state.json`` and
    otherwise takes the highest-numbered ``checkpoint-*`` that contains one.
    The training script itself is only turned into an import spec, never
    executed.
    """
    jobs_dir = REPO_ROOT / "scripts" / "jobs"
    sys.path.insert(0, str(jobs_dir))

    with tempfile.TemporaryDirectory() as tmp:
        run_dir = Path(tmp)
        checkpoint_dir = run_dir / "checkpoint-80"
        checkpoint_dir.mkdir()
        synthetic_state = {
            "log_history": [
                {"step": 5, "rewards/reward_repair_function/mean": 0.12},
                {"step": 10, "rewards/reward_repair_function/mean": 0.34},
            ]
        }
        (checkpoint_dir / "trainer_state.json").write_text(json.dumps(synthetic_state))

        from importlib import util as _util

        spec = _util.spec_from_file_location(
            "_train_mod", REPO_ROOT / "scripts" / "jobs" / "train_repair_agent.py"
        )
        if spec is None or spec.loader is None:
            raise RuntimeError("can't spec the training script")

        # Don't actually load the module (it has top-level CUDA/HF effects).
        # The lookup is re-implemented here instead: a state file directly in
        # the run dir wins, otherwise the newest checkpoint-* is used.
        found = run_dir / "trainer_state.json"
        if not found.exists():
            with_state = [
                d for d in run_dir.glob("checkpoint-*") if (d / "trainer_state.json").exists()
            ]
            with_state.sort(key=lambda d: int(d.name.split("-")[-1]))
            found = with_state[-1] / "trainer_state.json" if with_state else None

        if found is None or not found.exists():
            raise RuntimeError("finder did not locate the synthesized state")

        loaded = json.loads(found.read_text())
        if len(loaded["log_history"]) != 2:
            raise RuntimeError("finder loaded wrong file")
    return "checkpoint-N/trainer_state.json discoverable"
def t11_warmstart_rows_all_valid() -> str:
    """Walk every warmstart row and check every row has system+user+assistant.

    A row is bad when its first three roles are not exactly
    system/user/assistant, or when any message has missing, non-string or
    whitespace-only content. Each bad row is recorded once, so the count in
    the error message equals the number of offending rows.

    Raises:
        ValueError: the file has no rows, or at least one row is bad.
    """
    p = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl"
    rows = [json.loads(line) for line in p.read_text(encoding="utf-8").splitlines() if line.strip()]
    if not rows:
        # Without this an empty file would pass as "all 0 rows" valid.
        raise ValueError("repair_pairs.jsonl contains no rows")

    bad = []
    for i, row in enumerate(rows):
        msgs = row.get("messages") or []
        roles = [m.get("role") for m in msgs]
        if roles[:3] != ["system", "user", "assistant"]:
            bad.append((i, roles))
            continue  # already recorded; don't count the same row twice
        if any(
            not isinstance(m.get("content"), str) or not m["content"].strip()
            for m in msgs
        ):
            bad.append((i, "empty content"))

    if bad:
        raise ValueError(f"{len(bad)} bad rows; first: {bad[0]}")
    return f"all {len(rows)} rows have system/user/assistant with non-empty content"
def t12_tokenizer_renders_real_rows() -> str:
    """Render the chat template on the first (up to) 5 real rows.

    Mirrors the SFT map step (`_format_chat`) the job runs after
    dataset.load_dataset. Blank lines in the JSONL are skipped and only the
    rows that are actually rendered get parsed.

    Raises:
        ValueError: the file has no rows, or a rendered row exceeds 4096
            tokens.
    """
    from itertools import islice

    from transformers import AutoTokenizer

    base = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct")
    token = os.environ.get("HF_TOKEN")
    tok = AutoTokenizer.from_pretrained(base, token=token)

    p = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl"
    non_blank = (ln for ln in p.read_text(encoding="utf-8").splitlines() if ln.strip())
    rows = [json.loads(ln) for ln in islice(non_blank, 5)]
    if not rows:
        raise ValueError(f"no rows found in {p}")

    lengths = []
    for r in rows:
        text = tok.apply_chat_template(
            r["messages"], tokenize=False, add_generation_prompt=False
        )
        toks = tok(text, return_tensors=None)["input_ids"]
        lengths.append(len(toks))

    # NOTE(review): the hard limit here is 4096 while the result text quotes
    # a 2048 max_length budget — confirm which threshold is intended.
    if max(lengths) > 4096:
        raise ValueError(f"row tokens > 4096 (would need bigger max_length): {lengths}")
    return f"{len(rows)} rows render OK; token lengths={lengths} (max_length=2048 budget)"
def t13_baseline_drift_generator_each_category() -> str:
    """Ask the baseline drift generator for a spec in each listed category.

    Every returned spec must expose both ``primitive_type`` and ``params``.
    The result lists which primitive was proposed per category.
    """
    from forgeenv.roles.drift_generator import BaselineDriftGenerator

    generator = BaselineDriftGenerator()
    categories = ["api_drift", "type_signature", "import_path", "config_schema", "deprecated_kwarg"]
    sample_script = "".join(
        [
            "from transformers import AutoTokenizer, Trainer\n",
            "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n",
            "trainer = Trainer(model=None)\n",
            "trainer.train()\n",
        ]
    )

    pairs: list[str] = []
    for category in categories:
        spec = generator.propose(target_category=category, script=sample_script)
        if any(key not in spec for key in ("primitive_type", "params")):
            raise ValueError(f"bad spec for {category}: {spec}")
        pairs.append(f"{category}->{spec['primitive_type']}")
    return "; ".join(pairs)
def t14_forge_environment_reset_step() -> str:
    """End-to-end env smoke: ``reset()`` then one ``step()`` with a breakage.

    The breakage is whatever the baseline drift generator proposes for the
    observation returned by ``reset()``. Both observations must carry a
    non-blank ``script_content``.
    """
    from forgeenv.env.actions import BreakageAction, ForgeAction
    from forgeenv.env.forge_environment import ForgeEnvironment

    env = ForgeEnvironment(seed=0)
    first_obs = env.reset(difficulty="easy")
    if not getattr(first_obs, "script_content", "").strip():
        raise ValueError("reset() returned empty script_content")

    from forgeenv.roles.drift_generator import BaselineDriftGenerator

    category = getattr(first_obs, "target_category", "api_drift")
    spec = BaselineDriftGenerator().propose(
        target_category=category,
        script=first_obs.script_content,
    )

    breakage = BreakageAction(primitive_type=spec["primitive_type"], params=spec["params"])
    second_obs = env.step(ForgeAction(breakage=breakage))
    if not getattr(second_obs, "script_content", "").strip():
        raise ValueError("step() returned empty script_content")
    return f"reset+breakage step OK (task={first_obs.task_id}, primitive={spec['primitive_type']})"
def t15_build_repair_prompt() -> str:
    """Call ``_build_repair_prompt`` once and validate the example it builds.

    The example must contain the prompt/task_id/primitive_type/broken_script
    keys, a list prompt with at least two entries, and a non-blank script.
    """
    from forgeenv.env.forge_environment import ForgeEnvironment
    from forgeenv.training.grpo_repair import _build_repair_prompt

    example = _build_repair_prompt(ForgeEnvironment(seed=0))

    expected_keys = ("prompt", "task_id", "primitive_type", "broken_script")
    absent = next((key for key in expected_keys if key not in example), None)
    if absent is not None:
        raise KeyError(f"missing {absent} in built example: {list(example)}")

    prompt = example["prompt"]
    if not isinstance(prompt, list) or len(prompt) < 2:
        raise ValueError("prompt is not a chat-format list")
    if not example["broken_script"].strip():
        raise ValueError("empty broken_script")
    return (
        f"task={example['task_id']} primitive={example['primitive_type']} "
        f"prompt_msgs={len(prompt)}"
    )
def t16_rollout_one_episode() -> str:
    """Run ``rollout_one_episode`` on a fresh env and report its reward.

    The result must expose a ``visible_reward`` attribute.
    """
    from forgeenv.env.forge_environment import ForgeEnvironment
    from forgeenv.training.rollout import rollout_one_episode

    episode = rollout_one_episode(ForgeEnvironment(seed=0))
    if not hasattr(episode, "visible_reward"):
        raise AttributeError("rollout result missing visible_reward")

    primitive = getattr(episode, "primitive_type", "?")
    return f"reward={episode.visible_reward:.3f} primitive={primitive}"
def t17_plots_render() -> str:
    """Run all 3 plot helpers on synthetic data and check files appear.

    Each helper is given an explicit output path inside a temporary
    directory. The check fails when any of those files is absent or smaller
    than 1000 bytes.

    Raises:
        FileNotFoundError: a helper did not produce its PNG.
        ValueError: a produced PNG is suspiciously small.
    """
    from forgeenv.training.plots import (
        plot_baseline_vs_trained,
        plot_reward_curve,
        plot_success_rate_by_category,
    )

    expected = ("rc.png", "bvt.png", "succ.png")
    with tempfile.TemporaryDirectory() as td:
        td_p = Path(td)
        plot_reward_curve([0.1, 0.2, 0.3, 0.4], str(td_p / "rc.png"))
        plot_baseline_vs_trained(
            [0.1, 0.2, 0.15], [0.4, 0.5, 0.6], str(td_p / "bvt.png")
        )
        plot_success_rate_by_category(
            {"api_drift": [True, False, True], "type_signature": [False, True]},
            str(td_p / "succ.png"),
        )
        sizes = {p.name: p.stat().st_size for p in td_p.glob("*.png")}

    # An empty or partial ``sizes`` would sail through the size check below,
    # so confirm every expected file was actually written first.
    missing = [name for name in expected if name not in sizes]
    if missing:
        raise FileNotFoundError(f"plot file(s) not written: {missing}; found: {sizes}")
    if any(s < 1000 for s in sizes.values()):
        raise ValueError(f"plot file too small (<1KB): {sizes}")
    return f"3 plots rendered: {sizes}"
def t18_simulation_executor_and_reward() -> str:
    """Score a seed-corpus script and a junk script with the visible reward.

    Both scripts are run through ``SimulationExecutor`` for the first task
    of the sampler; the seed script's reward must not be lower than the
    junk script's.
    """
    from forgeenv.sandbox.simulation_mode import SimulationExecutor
    from forgeenv.tasks.task_sampler import TaskSampler
    from forgeenv.verifier.visible_verifier import compute_visible_reward

    sampler = TaskSampler()
    if not sampler.tasks:
        raise RuntimeError("no tasks in TaskSampler")
    task = sampler.tasks[0]
    executor = SimulationExecutor()

    # Path 1: the seed script named after the task, or any seed file if
    # that one is not present.
    canonical = REPO_ROOT / "forgeenv" / "tasks" / "seed_corpus" / f"{task.task_id}.py"
    if not canonical.exists():
        candidates = list((REPO_ROOT / "forgeenv" / "tasks" / "seed_corpus").glob("*.py"))
        if not candidates:
            raise FileNotFoundError("no seed corpus files")
        canonical = candidates[0]
    good_source = canonical.read_text(encoding="utf-8")
    good_result = executor.execute(good_source, task)
    good_result.script_content = good_source
    r_ok, _ = compute_visible_reward(good_result, task)

    # Path 2: gibberish, expected to score no higher than the seed script.
    junk_source = "not_a_real_python_file = "
    junk_result = executor.execute(junk_source, task)
    junk_result.script_content = junk_source
    r_bad, _ = compute_visible_reward(junk_result, task)

    if r_ok < r_bad:
        raise AssertionError(f"canonical reward {r_ok} should be >= gibberish {r_bad}")
    return f"r_canonical={r_ok:.3f} r_gibberish={r_bad:.3f} (delta {r_ok-r_bad:.3f})"
def t19_repair_library_artifact() -> str:
    """Check that ``artifacts/repair_library.json`` exists and is non-empty.

    The file must decode to a non-empty JSON list or object.
    """
    artifact = REPO_ROOT / "artifacts" / "repair_library.json"
    if not artifact.exists():
        raise FileNotFoundError(artifact)

    library = json.loads(artifact.read_text(encoding="utf-8"))
    is_container = isinstance(library, (list, dict))
    if not (is_container and library):
        raise ValueError("repair_library.json is empty")

    # len() counts elements of a list and keys of a dict alike.
    entry_count = len(library)
    return f"repair_library.json has {entry_count} entries"
def t20_hub_upload_roundtrip() -> str:
    """Upload a one-file folder to a private scratch repo and list it back.

    Skips when ``HF_TOKEN`` is not set. The scratch repo is created if it
    does not exist yet.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise _Skip("HF_TOKEN not set")

    from huggingface_hub import HfApi

    api = HfApi()
    hf_user = os.environ.get("HF_USERNAME", "akhiilll")
    repo_id = f"{hf_user}/forgeenv-preflight-scratch"
    api.create_repo(
        repo_id=repo_id, repo_type="model", token=hf_token, exist_ok=True, private=True
    )

    with tempfile.TemporaryDirectory() as scratch:
        scratch_dir = Path(scratch)
        marker = scratch_dir / "ok.txt"
        marker.write_text("preflight OK", encoding="utf-8")
        api.upload_folder(
            folder_path=str(scratch_dir),
            repo_id=repo_id,
            repo_type="model",
            token=hf_token,
            commit_message="preflight roundtrip",
        )

    files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=hf_token)
    if "ok.txt" in files:
        return f"{repo_id} round-trip OK ({len(files)} files)"
    raise RuntimeError(f"upload roundtrip failed; files={files}")
def main() -> int:
    """Run every check in order, print a summary and return the exit code.

    Returns:
        0 when no check was recorded as FAIL, otherwise 1.
    """
    print(f"\n=== ForgeEnv preflight (repo: {REPO_ROOT}) ===\n", flush=True)

    # (label, check function, required?) in execution order.
    checks = [
        ("01 imports", t1_imports, True),
        ("01b openenv extras (job: after --no-deps)", t1b_openenv_job_extras, True),
        ("02 dataset load + format", t2_dataset_load_and_format, True),
        ("03 TRL configs (SFT/GRPO) accept kwargs", t3_trl_configs_accept_our_kwargs, True),
        ("04 reward fn returns float", t4_reward_function_returns_float, True),
        ("05 diff utils round-trip", t5_diff_utils_roundtrip, True),
        ("06 live env /health", t6_live_env_health, False),
        ("07 forgeenv-source repo on Hub", t7_source_repo_exists, False),
        ("08 Qwen tokenizer + ChatML", t8_qwen_tokenizer_loads, True),
        ("09 HfApi auth", t9_hfapi_auth_and_namespace, False),
        ("10 _find_trainer_state logic", t10_find_trainer_state, True),
        ("11 every warmstart row valid", t11_warmstart_rows_all_valid, True),
        ("12 tokenizer renders real rows", t12_tokenizer_renders_real_rows, True),
        ("13 BaselineDriftGenerator each category", t13_baseline_drift_generator_each_category, True),
        ("14 ForgeEnvironment reset+step", t14_forge_environment_reset_step, True),
        ("15 _build_repair_prompt runs", t15_build_repair_prompt, True),
        ("16 rollout_one_episode runs", t16_rollout_one_episode, True),
        ("17 plots render to PNG", t17_plots_render, True),
        ("18 SimulationExecutor + reward", t18_simulation_executor_and_reward, True),
        ("19 repair_library.json artifact", t19_repair_library_artifact, True),
        ("20 Hub upload round-trip", t20_hub_upload_roundtrip, False),
    ]
    for label, check, is_required in checks:
        _run(label, check, required=is_required)

    print("\n=== Summary ===")
    tally = {
        tag: sum(1 for outcome in _results if outcome[0] == tag)
        for tag in (PASS, FAIL, SKIP)
    }
    for tag, label, _detail in _results:
        print(f"{tag} {label}")
    print(f"\n{tally[PASS]} passed, {tally[FAIL]} failed, {tally[SKIP]} skipped")
    return 1 if tally[FAIL] else 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())