from __future__ import annotations import asyncio import json import os import subprocess from contextlib import asynccontextmanager from dataclasses import dataclass from pathlib import Path from openai import OpenAI from workflow_arena import WorkflowArenaAction, WorkflowArenaEnv from workflow_arena.models import ( DifficultyPreset, WorkflowActionType, WorkflowArenaObservation, WorkflowTaskView, ) BENCHMARK = "WorkflowArena" PRESETS = [ DifficultyPreset.EASY, DifficultyPreset.MEDIUM, DifficultyPreset.HARD, ] PROJECT_DIR = Path(__file__).resolve().parent IMAGE_NAME = "workflow-arena-inference:latest" DOCKERFILE_PATH = PROJECT_DIR / "server" / "Dockerfile" API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") MODEL_NAME = os.getenv("MODEL_NAME", "qwen/qwen3.5-9b") HF_TOKEN = os.getenv("HF_TOKEN") LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") DEFAULT_BASE_URL = os.getenv("WORKFLOW_ARENA_BASE_URL", "http://localhost:8000") TEMPERATURE = 0.0 MAX_STEPS = 256 SYSTEM_PROMPT = ( "You are scheduling a dependency-constrained workflow on limited workers. " "Respond with compact JSON only. " 'Valid formats: {"action_type":"wait","task_ids":[]} or ' '{"action_type":"dispatch","task_ids":["task_01","task_02"]}. ' "Only dispatch task ids that appear in ready_tasks for the current observation. " "Never exceed free_workers. " 'If free_workers is 0 and running_tasks is non-empty, respond with {"action_type":"wait","task_ids":[]}. ' "If your previous action was invalid, use validation_error to correct it while still reasoning from the current observation. " "Never repeat a previously dispatched task unless it still appears in ready_tasks." ) def log_start(task: str, env: str, model: str) -> None: print(f"[START] task={task} env={env} model={model}", flush=True) def log_step( step: int, action: str, reward: float, done: bool, error: str | None ) -> None: error_val = error if error else "null" done_val = str(done).lower() print( f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True, ) def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None: rewards_str = ",".join(f"{reward:.2f}" for reward in rewards) print( f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True, ) def log_warning(message: str) -> None: print(f"[WARN] {message}", flush=True) def compact_task(task: WorkflowTaskView) -> dict[str, object]: return { "task_id": task.task_id, "duration": task.duration, "priority": task.priority, "deadline": task.deadline, "criticality": task.criticality, "slack": task.slack, "downstream_count": task.downstream_count, "dependencies": task.dependencies, "attempt_count": task.attempt_count, } def make_user_prompt(observation: WorkflowArenaObservation) -> str: must_wait = observation.free_workers == 0 and bool(observation.running_tasks) return json.dumps( { "instruction": observation.instruction, "current_time": observation.current_time, "effective_workers": observation.effective_workers, "degraded_workers": observation.degraded_workers, "free_workers": observation.free_workers, "time_budget": observation.time_budget, "time_remaining": observation.time_remaining, "must_wait": must_wait, "ready_tasks": [compact_task(task) for task in observation.ready_tasks], "running_tasks": [compact_task(task) for task in observation.running_tasks], "progress": observation.progress.model_dump(mode="json"), "reward_breakdown": observation.last_reward_breakdown.model_dump( mode="json" ), "note": observation.note, "validation_error": observation.validation_error, "recent_failure_events": [ event.model_dump(mode="json") for event in observation.recent_failure_events ], "last_action": observation.received_action, }, separators=(",", ":"), ) def heuristic_action(observation: WorkflowArenaObservation) -> WorkflowArenaAction: if observation.free_workers <= 0 and observation.running_tasks: return WorkflowArenaAction(action_type=WorkflowActionType.WAIT, task_ids=[]) if not observation.ready_tasks or observation.free_workers <= 0: return WorkflowArenaAction(action_type=WorkflowActionType.WAIT, task_ids=[]) time_remaining = observation.time_remaining ranked = sorted( observation.ready_tasks, key=lambda task: ( time_remaining is not None and task.duration > time_remaining, max(0, task.duration - time_remaining) if time_remaining is not None else 0, task.deadline if task.deadline is not None else 10**9, -(task.criticality or 0.0), -task.priority, task.duration, task.task_id, ), ) selected = [task.task_id for task in ranked[: observation.free_workers]] return WorkflowArenaAction( action_type=WorkflowActionType.DISPATCH, task_ids=selected, ) def parse_action( text: str, observation: WorkflowArenaObservation ) -> WorkflowArenaAction: text = text.strip() if not text: raise ValueError("Model response did not include JSON action") payload = json.loads(text) return WorkflowArenaAction.model_validate(payload) def get_model_action( client: OpenAI, model_name: str, observation: WorkflowArenaObservation, ) -> WorkflowArenaAction: prompt = make_user_prompt(observation) completion = client.chat.completions.create( model=model_name, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt}, ], temperature=TEMPERATURE, max_tokens=120, ) text = (completion.choices[0].message.content or "").strip() return parse_action(text, observation) def action_to_log_string(action: WorkflowArenaAction) -> str: payload = action.model_dump(mode="json") if payload.get("metadata") == {}: payload.pop("metadata", None) return json.dumps(payload, separators=(",", ":")) def resolve_model_client() -> tuple[OpenAI | None, str]: api_key = ( os.getenv("API_KEY") or HF_TOKEN or os.getenv("OPENAI_API_KEY") ) missing = [] if not api_key: missing.append("API_KEY or HF_TOKEN") if missing: log_warning( "Missing model configuration (" + ", ".join(missing) + "). Falling back to heuristic policy." ) return None, "heuristic" try: return OpenAI(base_url=API_BASE_URL, api_key=api_key), MODEL_NAME except Exception as exc: # pragma: no cover - defensive initialization fallback log_warning( f"Failed to initialize model client: {exc}. Falling back to heuristic policy." ) return None, "heuristic" def compute_score(observation: WorkflowArenaObservation) -> float: score = observation.benchmark_score if score is None: score = observation.success_metrics.benchmark_score return max(0.0, min(1.0, float(score or 0.0))) def is_success(observation: WorkflowArenaObservation) -> bool: return bool( observation.done and observation.success_metrics.makespan is not None and observation.termination_reason is None ) @dataclass class EpisodeResult: success: bool steps: int score: float rewards: list[float] def ensure_local_image() -> None: local_image_name = LOCAL_IMAGE_NAME or IMAGE_NAME try: inspect_result = subprocess.run( ["docker", "image", "inspect", local_image_name], cwd=PROJECT_DIR, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False, ) except OSError as exc: raise RuntimeError(f"Failed to execute docker: {exc}") from exc if inspect_result.returncode == 0: return try: build_result = subprocess.run( ["docker", "build", "-t", local_image_name, "-f", str(DOCKERFILE_PATH), "."], cwd=PROJECT_DIR, capture_output=True, text=True, check=False, ) except OSError as exc: raise RuntimeError(f"Failed to execute docker build: {exc}") from exc if build_result.returncode != 0: raise RuntimeError( "Failed to build Docker image for inference.\n" f"Command: docker build -t {local_image_name} -f {DOCKERFILE_PATH} .\n" f"Exit code: {build_result.returncode}\n" f"Stdout: {build_result.stdout}\n" f"Stderr: {build_result.stderr}" ) @asynccontextmanager async def managed_env(): try: async with WorkflowArenaEnv(base_url=DEFAULT_BASE_URL) as env: yield env return except Exception as exc: log_warning( f"Failed to connect to environment at {DEFAULT_BASE_URL}: {exc}. " "Trying local Docker fallback." ) ensure_local_image() env = await WorkflowArenaEnv.from_docker_image(LOCAL_IMAGE_NAME or IMAGE_NAME) try: yield env finally: try: await env.close() except Exception as exc: # pragma: no cover - teardown failures should not fail inference log_warning(f"Failed to close Docker environment cleanly: {exc}") async def run_episode( env, client: OpenAI | None, model_name: str, preset: DifficultyPreset, seed: int, ) -> EpisodeResult: rewards: list[float] = [] steps_taken = 0 success = False score = 0.0 log_start(task=preset.value, env=BENCHMARK, model=model_name) try: result = await env.reset( seed=seed, preset=preset.value, ) except Exception as exc: # pragma: no cover - env availability failures are external log_warning(f"Failed to reset preset={preset.value}: {exc}") log_end(success=False, steps=steps_taken, score=score, rewards=rewards) return EpisodeResult( success=success, steps=steps_taken, score=score, rewards=rewards ) observation = result.observation while not observation.done and steps_taken < MAX_STEPS: try: if client is None: action = heuristic_action(observation) else: action = get_model_action(client, model_name, observation) except ( Exception ): # pragma: no cover - network/model failures are expected sometimes action = heuristic_action(observation) try: result = await env.step(action) except Exception as exc: # pragma: no cover - preserve log format and continue safely fallback_action = heuristic_action(observation) if fallback_action != action: log_warning( f"Step failed for preset={preset.value} with model action: {exc}. " "Retrying with heuristic action." ) action = fallback_action try: result = await env.step(action) except Exception as retry_exc: log_warning( f"Step failed for preset={preset.value} even with heuristic action: {retry_exc}" ) break observation = result.observation reward = float(result.reward or 0.0) rewards.append(reward) steps_taken += 1 log_step( step=steps_taken, action=action_to_log_string(action), reward=reward, done=bool(result.done), error=observation.validation_error, ) success = is_success(observation) score = compute_score(observation) if observation.done else 0.0 log_end(success=success, steps=steps_taken, score=score, rewards=rewards) return EpisodeResult( success=success, steps=steps_taken, score=score, rewards=rewards ) async def main() -> None: client, model_name = resolve_model_client() async with managed_env() as env: for index, preset in enumerate(PRESETS): await run_episode( env=env, client=client, model_name=model_name, preset=preset, seed=100 + index, ) if __name__ == "__main__": try: asyncio.run(main()) except Exception as exc: # pragma: no cover - final safeguard for validator stability log_warning(f"Fatal inference error: {exc}")