| """Tests for the simulation-mode executor.""" |
| from forgeenv.sandbox.simulation_mode import SimulationExecutor |
| from forgeenv.tasks.models import Task |
|
|
# Script fixture that parses cleanly: it imports transformers/datasets/torch,
# trains via ``Trainer`` and prints the TRAINING_COMPLETE marker the tests
# look for in stdout.
VALID_HF = (
    "\n"
    "from transformers import Trainer, TrainingArguments\n"
    "from datasets import load_dataset\n"
    "import torch\n"
    "\n"
    'dataset = load_dataset("glue", "sst2")\n'
    "trainer = Trainer(model=None, args=None, train_dataset=dataset)\n"
    "trainer.train()\n"
    'trainer.save_model("/tmp/forge_output/checkpoint")\n'
    'print("TRAINING_COMPLETE")\n'
)


# Script fixture that cannot be parsed (unterminated ``def`` signature).
SYNTAX_ERROR = "\n".join(("def foo(", " broken"))


# Script fixture that imports ``os``; the tests expect it to fail validation.
OS_IMPORT = "\n".join(("import os", "os.listdir('.')"))
|
|
|
|
def _task(content: str) -> Task:
    """Wrap *content* in a minimal ``Task`` fixture.

    Only ``script_content`` varies between callers; the id, description and
    difficulty are fixed placeholder values.
    """
    fields = {
        "task_id": "t",
        "description": "d",
        "script_content": content,
        "difficulty": "easy",
    }
    return Task(**fields)
|
|
|
|
def test_valid_script_can_succeed():
    """Run the valid HF script once with seed 0 and check the outcome shape.

    The exit code must be 0 or 1; when it is 0, stdout must contain the
    TRAINING_COMPLETE marker.
    """
    outcome = SimulationExecutor(seed=0).execute(VALID_HF, _task(VALID_HF))

    code = outcome.exit_code
    assert code in (0, 1)
    if code != 0:
        # A failing run carries no further guarantees at this layer.
        return
    assert "TRAINING_COMPLETE" in outcome.stdout
|
|
|
|
def test_syntax_error_fails():
    """The SYNTAX_ERROR fixture yields exit code 1 and names SyntaxError.

    The error name is expected to appear in stderr.
    """
    outcome = SimulationExecutor(seed=0).execute(
        SYNTAX_ERROR, _task(SYNTAX_ERROR)
    )
    assert outcome.exit_code == 1
    assert "SyntaxError" in outcome.stderr
|
|
|
|
def test_forbidden_import_fails():
    """The OS_IMPORT fixture yields exit code 1 with a validation failure.

    stderr is expected to contain the text "Validation failed".
    """
    outcome = SimulationExecutor(seed=0).execute(OS_IMPORT, _task(OS_IMPORT))
    assert outcome.exit_code == 1
    assert "Validation failed" in outcome.stderr
|
|
|
|
def test_simulation_is_fast():
    """A single simulated execution must finish in under 200 ms of real time.

    The reported wall_time_ms field includes a synthetic delay so we measure
    real elapsed time at this layer instead.

    NOTE(review): this docstring previously stated a <100 ms requirement
    while the assertion allows 200 ms — presumably headroom for slow CI
    machines; confirm which figure is the intended budget.
    """
    import time

    executor = SimulationExecutor(seed=0)
    # perf_counter() is monotonic and high-resolution; time.time() follows the
    # system clock and can jump if it is adjusted mid-test, making the
    # measured interval unreliable.
    started = time.perf_counter()
    executor.execute(VALID_HF, _task(VALID_HF))
    elapsed_ms = (time.perf_counter() - started) * 1000
    assert elapsed_ms < 200, f"Simulation took {elapsed_ms:.1f}ms"
|
|
|
|
def test_seed_is_deterministic():
    """Two executors built with seed 42 agree on exit code and stderr."""
    # Build both executors before running either, as separate instances.
    first, second = (SimulationExecutor(seed=42) for _ in range(2))
    left, right = [
        runner.execute(VALID_HF, _task(VALID_HF)) for runner in (first, second)
    ]
    assert left.exit_code == right.exit_code
    assert left.stderr == right.stderr
|
|