Spaces:
Running
Running
| """OpenCode MCP Environment. | |
| Exposes a single tool ``run_rollout`` that runs one OpenCode agent rollout | |
| end-to-end against a caller-supplied LLM endpoint: | |
| 1. Spawn a fresh E2B sandbox (via the primitive's ``E2BSandboxBackend``). | |
| 2. Install opencode in the sandbox, write its config pointing at an | |
| in-sandbox proxy (Mode B) or the caller's LLM URL directly (Mode A). | |
| 3. Stage the caller-supplied task: instruction + test.sh + any extra files. | |
| 4. Run ``opencode run`` to completion. | |
| 5. Execute the verifier script; read the scalar reward from | |
| ``/home/user/logs/verifier/reward.txt``. | |
| 6. Collect proxy trace + workdir contents. | |
| 7. Return a JSON-serialized :class:`RolloutResult`. | |
The env is deliberately task-agnostic — the training script passes the
| full task (instruction + verifier) through the tool arguments. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import threading | |
| import time | |
| from typing import Any, Optional | |
| from uuid import uuid4 | |
| from dotenv import load_dotenv | |
| from fastmcp import FastMCP | |
| from openenv.core.env_server.mcp_environment import MCPEnvironment | |
| from openenv.core.env_server.types import Action, Observation | |
| try: | |
| from .catalog import resolve_endpoint | |
| except ImportError: # pragma: no cover | |
| from catalog import resolve_endpoint # type: ignore | |
# Load environment variables (e.g. E2B_API_KEY) from a local .env file, if present.
load_dotenv()
# One rollout (sandbox create + opencode install + opencode run + verifier)
# typically takes 30-180s and can spike to 600s under load. Override
# OpenEnv's 30s MCP tool-call default so the server doesn't cut our tool
# off mid-flight. This default still defers to an explicit ``timeout_s``
# on the StepRequest when the client specifies one.
_RUN_ROLLOUT_TIMEOUT_S = 900.0
# Default test-script and reward paths inside the sandbox. The server writes
# the caller-supplied ``test_script`` text to this path; the verifier reads
# the reward file back out after it finishes.
REMOTE_TEST_PATH = "/home/user/tests/test.sh"
REMOTE_REWARD_PATH = "/home/user/logs/verifier/reward.txt"
# Directory the agent works in; also used as the verifier's cwd.
WORKDIR_PATH = "/home/user/workdir"
# Hard cap (seconds) on a single verifier-script run inside the sandbox.
VERIFIER_TIMEOUT_S = 120
| class _RolloutRegistry: | |
| """Per-environment bookkeeping for in-flight non-blocking rollouts. | |
| Every ``start_rollout`` call spawns a background thread that owns a | |
| session through its full lifecycle (create sandbox β install opencode | |
| β start proxy + opencode serve β fire instruction β wait for agent β | |
| finalize on demand). The registry keeps one :class:`_RolloutHandle` | |
| per ``rollout_id`` so subsequent ``subscribe_events`` / ``abort_rollout`` | |
| / ``finalize`` tool calls can find it. | |
| """ | |
| def __init__(self) -> None: | |
| self._lock = threading.Lock() | |
| self._handles: dict[str, "_RolloutHandle"] = {} | |
| def add(self, handle: "_RolloutHandle") -> None: | |
| with self._lock: | |
| self._handles[handle.rollout_id] = handle | |
| def get(self, rollout_id: str) -> "_RolloutHandle | None": | |
| with self._lock: | |
| return self._handles.get(rollout_id) | |
| def drop(self, rollout_id: str) -> None: | |
| with self._lock: | |
| self._handles.pop(rollout_id, None) | |
| class _RolloutHandle: | |
| """Everything the server needs to finalize / observe / abort a rollout.""" | |
| def __init__( | |
| self, | |
| rollout_id: str, | |
| task_id: str, | |
| session_factory_kwargs: dict[str, Any], | |
| task: Any, | |
| ) -> None: | |
| self.rollout_id = rollout_id | |
| self.task_id = task_id | |
| self._kwargs = session_factory_kwargs | |
| self._task = task | |
| # Set by the worker thread as the session progresses. | |
| self.session: Any | None = None | |
| self.error: str | None = None | |
| self.started_at = time.time() | |
| self.finished_at: float | None = None | |
| self._done = threading.Event() | |
| self._worker: threading.Thread | None = None | |
| self._verifier_stdout = "" | |
| self._verifier_stderr = "" | |
| self._test_exit_code: int | None = None | |
| self._test_script: str = "" | |
| def is_done(self) -> bool: | |
| return self._done.is_set() | |
| def wait(self, timeout: float | None = None) -> bool: | |
| return self._done.wait(timeout) | |
class OpenCodeEnvironment(MCPEnvironment):
    """Two-tier MCP environment for OpenCode rollouts.

    The **macro tool** ``run_rollout`` runs a blocking end-to-end rollout
    and returns the final :class:`RolloutResult` (kept for training, which
    doesn't benefit from per-turn streaming).

    The **fine-grained tools** ``start_rollout`` / ``subscribe_events``
    (server-side SSE via a helper endpoint) / ``abort_rollout`` / ``finalize``
    are meant for the Gradio UI and interactive inspection — they expose
    the PR #471-style session primitives directly.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        # Per-environment rollout registry (one per HTTP session/worker).
        self._registry = _RolloutRegistry()
        # Import inside __init__ to keep module import cheap and to allow
        # patching for tests. Dual-import pattern: package when installed,
        # flat when run directly out of the repo via ``server.app:app``.
        try:
            from ..models import OpenCodeState, RolloutResult, RolloutTurn
        except ImportError:
            from models import OpenCodeState, RolloutResult, RolloutTurn  # type: ignore
        from opencode_env import (
            E2BSandboxBackend,
            OpenCodeConfig,
            OpenCodeSessionFactory,
            collect_rollout_summary,
        )
        from openenv.core.harness import VerifyResult
        # Stash the lazily-imported classes/helpers on the instance so the
        # rollout methods below can use them without re-importing.
        self._state_cls = OpenCodeState
        self._result_cls = RolloutResult
        self._turn_cls = RolloutTurn
        self._OpenCodeConfig = OpenCodeConfig
        self._OpenCodeSessionFactory = OpenCodeSessionFactory
        self._E2BSandboxBackend = E2BSandboxBackend
        self._collect_rollout_summary = collect_rollout_summary
        self._VerifyResult = VerifyResult
        # Require E2B credentials up front — fail loudly if unset.
        if not os.environ.get("E2B_API_KEY"):
            raise RuntimeError(
                "E2B_API_KEY environment variable is required for OpenCodeEnvironment"
            )
        self._state = self._state_cls(episode_id=str(uuid4()))
        mcp = FastMCP("opencode_env")

        # NOTE(review): the nested functions below are closures over ``self``
        # but no ``@mcp.tool`` registration is visible in this chunk —
        # confirm they are actually registered on ``mcp`` (decorator lines
        # may have been lost) before relying on them being exposed as tools.
        def run_rollout(
            model_key: str,
            instruction: str,
            test_script: str,
            vllm_url: str = "",
            hf_token: str = "",
            thinking: bool = False,
            task_id: str = "",
            setup_shell: str = "",
            upload_files: Optional[dict[str, str]] = None,
            mode: str = "transparent_proxy",
            max_tokens_cap: int = 4096,
            agent_timeout_s: float = 600.0,
        ) -> str:
            """Run one OpenCode rollout end-to-end.

            Args:
                model_key: Catalog key — one of the entries in
                    :data:`server.catalog.CATALOG`. Shape is
                    ``"vllm://<repo>"`` or ``"hf-router://<repo>:<provider>"``.
                instruction: Prompt passed to ``opencode run``.
                test_script: Bash verifier. Must write a float reward to
                    ``/home/user/logs/verifier/reward.txt``.
                vllm_url: Required when ``model_key`` is a ``vllm://...``
                    entry. The tunneled or in-cluster ``/v1`` endpoint.
                hf_token: Required when ``model_key`` is a
                    ``hf-router://...`` entry. User's HF token.
                thinking: Enable Qwen-style thinking mode. Ignored for
                    models where ``supports_thinking`` is False. Passed to
                    the proxy as ``chat_template_kwargs.enable_thinking``.
                task_id: Optional identifier echoed back for traceability.
                setup_shell: Optional shell run before opencode starts.
                upload_files: Optional ``{remote_path: content}`` staged
                    into the sandbox.
                mode: ``"transparent_proxy"`` (captures per-turn logprobs)
                    or ``"black_box"`` (direct connection, no logprobs).
                max_tokens_cap: Clamp forwarded ``max_tokens``.
                agent_timeout_s: Max opencode runtime in seconds.

            Returns:
                JSON-serialized :class:`RolloutResult`.
            """
            base_url, api_key, model, _entry = resolve_endpoint(
                model_key, vllm_url=vllm_url, hf_token=hf_token
            )
            return self._run_rollout_impl(
                vllm_url=base_url,
                model=model,
                instruction=instruction,
                test_script=test_script,
                task_id=task_id,
                setup_shell=setup_shell,
                upload_files=upload_files or {},
                provider="openai_compatible",
                api_key=api_key,
                mode=mode,
                disable_thinking=not bool(thinking),
                max_tokens_cap=max_tokens_cap,
                agent_timeout_s=agent_timeout_s,
            )

        # -- Fine-grained tools (non-blocking) — Phase 2b --------------------
        #
        # These wrap the primitive's ``driver="serve"`` path so the UI and
        # interactive clients can observe a rollout in flight. Training keeps
        # using the blocking ``run_rollout`` tool above.
        def start_rollout(
            model_key: str,
            instruction: str,
            test_script: str = "",
            vllm_url: str = "",
            hf_token: str = "",
            thinking: bool = False,
            task_id: str = "",
            setup_shell: str = "",
            upload_files: Optional[dict[str, str]] = None,
            mode: str = "transparent_proxy",
            max_tokens_cap: int = 4096,
            agent_timeout_s: float = 600.0,
        ) -> str:
            """Start a rollout asynchronously; return a ``rollout_id`` immediately.

            Same uniform args as :func:`run_rollout`: ``model_key``, plus
            ``vllm_url`` OR ``hf_token`` (depending on backend), plus
            ``thinking``. Spawns a background worker that creates the
            sandbox, installs opencode, boots ``opencode serve``, and
            fires the instruction. The caller then uses
            ``subscribe_events`` / ``get_state`` / ``abort_rollout`` /
            ``finalize`` with the returned id.
            """
            base_url, api_key, model, _entry = resolve_endpoint(
                model_key, vllm_url=vllm_url, hf_token=hf_token
            )
            # Short random id — only needs to be unique within this registry.
            rid = uuid4().hex[:12]
            handle = self._spawn_async_rollout(
                rollout_id=rid,
                vllm_url=base_url,
                model=model,
                instruction=instruction,
                test_script=test_script,
                task_id=task_id,
                setup_shell=setup_shell,
                upload_files=upload_files or {},
                provider="openai_compatible",
                api_key=api_key,
                mode=mode,
                disable_thinking=not bool(thinking),
                max_tokens_cap=max_tokens_cap,
                agent_timeout_s=agent_timeout_s,
            )
            return json.dumps({
                "rollout_id": handle.rollout_id,
                "task_id": handle.task_id,
                "status": "starting",
                "started_at": handle.started_at,
            })

        def get_state(rollout_id: str) -> str:
            """Return a small JSON snapshot of the rollout's current state.

            Useful for UI polling between SSE reconnects. The payload mirrors
            what an ``agent_done`` event would carry plus a ``status`` field.
            """
            handle = self._registry.get(rollout_id)
            if handle is None:
                return json.dumps({"rollout_id": rollout_id, "status": "unknown"})
            session = handle.session
            serve_session_id = getattr(session, "serve_session_id", None)
            # Best-effort turn count: one non-blank line per proxy-trace
            # entry. Any read failure just leaves the count at 0.
            turns_n = 0
            if session is not None and session._proxy_trace_path:
                try:
                    content = session.sandbox.read_text(session._proxy_trace_path)
                    turns_n = sum(1 for line in content.splitlines() if line.strip())
                except Exception:
                    pass
            return json.dumps({
                "rollout_id": rollout_id,
                "task_id": handle.task_id,
                "status": "done" if handle.is_done() else "running",
                "serve_session_id": serve_session_id,
                "proxy_turns_so_far": turns_n,
                "error": handle.error,
                "started_at": handle.started_at,
                "finished_at": handle.finished_at,
            })

        def get_messages(rollout_id: str) -> str:
            """Return the sandbox-side opencode serve transcript for a rollout.

            Shape matches opencode's ``GET /session/:id/message`` —
            ``{"messages": [{info, parts}, ...]}``. Empty ``messages`` list
            if the rollout hasn't created its serve session yet, isn't
            running under the ``serve`` driver, or fetching the transcript
            failed. Designed for UI polling to render a live chat view.
            """
            handle = self._registry.get(rollout_id)
            if handle is None:
                return json.dumps({"rollout_id": rollout_id, "messages": [], "status": "unknown"})
            session = handle.session
            status = "done" if handle.is_done() else "running"
            if session is None:
                return json.dumps({
                    "rollout_id": rollout_id,
                    "messages": [],
                    "status": status,
                    "error": handle.error,
                })
            serve_client = getattr(session, "serve_client", None)
            serve_sid = getattr(session, "serve_session_id", None)
            if serve_client is None or not serve_sid:
                return json.dumps({
                    "rollout_id": rollout_id,
                    "messages": [],
                    "status": status,
                    "note": "no serve driver (transcript unavailable)",
                })
            try:
                msgs = serve_client.list_messages(serve_sid) or []
            except Exception as exc:  # noqa: BLE001
                return json.dumps({
                    "rollout_id": rollout_id,
                    "messages": [],
                    "status": status,
                    "error": f"list_messages failed: {type(exc).__name__}: {exc}",
                })
            return json.dumps({
                "rollout_id": rollout_id,
                "messages": msgs,
                "status": status,
                "serve_session_id": serve_sid,
            })

        def abort_rollout(rollout_id: str) -> str:
            """Cancel an in-flight rollout.

            Calls ``POST /session/:id/abort`` on the sandbox's opencode serve
            (which stops token generation immediately), then lets the worker
            unwind and close the sandbox. Returns a small JSON payload.
            """
            handle = self._registry.get(rollout_id)
            if handle is None:
                return json.dumps({"rollout_id": rollout_id, "aborted": False, "reason": "unknown"})
            aborted = False
            try:
                if handle.session is not None and handle.session.driver == "serve":
                    aborted = bool(handle.session.abort())
                elif handle.session is not None:
                    # CLI driver: kill the sandbox to stop opencode mid-run.
                    handle.session.sandbox.kill()
                    aborted = True
            except Exception as exc:  # noqa: BLE001
                return json.dumps({
                    "rollout_id": rollout_id, "aborted": False, "reason": str(exc),
                })
            return json.dumps({"rollout_id": rollout_id, "aborted": aborted})

        def finalize_rollout(rollout_id: str, wait_s: float = 300.0) -> str:
            """Block until the rollout's worker finishes, then return its
            full :class:`RolloutResult` as JSON (same shape as ``run_rollout``).

            Must be called exactly once per rollout — this also de-registers
            the handle and closes the sandbox.
            """
            handle = self._registry.get(rollout_id)
            if handle is None:
                # Unknown id: return an empty result shell with the error set
                # so callers get the uniform RolloutResult shape regardless.
                return json.dumps({
                    "task_id": "", "sandbox_id": "", "reward": None,
                    "exit_code": 0, "wall_s": 0.0, "mode": "",
                    "proxy_turns": [], "workdir_files": {},
                    "agent_log_tail": "", "verifier_stdout": "",
                    "verifier_stderr": "", "test_exit_code": None,
                    "error": f"unknown rollout_id: {rollout_id}",
                    "proxy_log_tail": "", "install_log_tail": "",
                })
            if not handle.wait(timeout=wait_s):
                # Worker still running after wait_s: report a timeout without
                # de-registering, so a later finalize can still succeed.
                return json.dumps({
                    "task_id": handle.task_id, "sandbox_id": "",
                    "reward": None, "exit_code": -1, "wall_s": 0.0,
                    "mode": "", "proxy_turns": [], "workdir_files": {},
                    "agent_log_tail": "", "verifier_stdout": "",
                    "verifier_stderr": "", "test_exit_code": None,
                    "error": f"timeout waiting for rollout_id={rollout_id}",
                    "proxy_log_tail": "", "install_log_tail": "",
                })
            payload = self._finalize_handle(handle)
            self._registry.drop(rollout_id)
            return payload

        super().__init__(mcp)

    # -- OpenEnv lifecycle ---------------------------------------------------
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **_: Any,
    ) -> Observation:
        """Start a fresh episode; rollout bookkeeping state is re-created."""
        self._state = self._state_cls(episode_id=episode_id or str(uuid4()))
        return Observation(
            done=False,
            reward=None,
            metadata={
                "status": "ready",
                "message": "OpenCode env ready. Call run_rollout(...) with a task.",
            },
        )

    def _step_impl(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **_: Any,
    ) -> Observation:
        """Fallback for non-tool actions: report the supported entry point."""
        return Observation(
            done=False,
            reward=None,
            metadata={
                "error": (
                    f"Unknown action type: {type(action).__name__}. "
                    "Use CallToolAction(name='run_rollout', ...)."
                ),
            },
        )

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Override the MCP 30s default with a rollout-friendly budget."""
        if timeout_s is None:
            timeout_s = _RUN_ROLLOUT_TIMEOUT_S
        return super().step(action, timeout_s=timeout_s, **kwargs)

    async def step_async(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Async twin of :meth:`step`; applies the same timeout default."""
        if timeout_s is None:
            timeout_s = _RUN_ROLLOUT_TIMEOUT_S
        return await super().step_async(action, timeout_s=timeout_s, **kwargs)

    def state(self) -> Any:
        """Return the lightweight per-episode bookkeeping state object."""
        return self._state

    # -- Rollout implementation ----------------------------------------------
    def _run_rollout_impl(
        self,
        *,
        vllm_url: str,
        model: str,
        instruction: str,
        test_script: str,
        task_id: str,
        setup_shell: str,
        upload_files: dict[str, str],
        provider: str,
        api_key: str,
        mode: str,
        disable_thinking: bool,
        max_tokens_cap: int,
        agent_timeout_s: float,
    ) -> str:
        """Blocking rollout: sandbox -> agent -> verifier -> artifact collection.

        Returns a JSON-serialized RolloutResult; any exception along the way
        is captured into ``result.error`` rather than propagated.
        """
        from opencode_env import OpenCodeTask
        result = self._result_cls(task_id=task_id, mode=mode)
        t0 = time.time()  # wall-clock start, reported as result.wall_s
        # Pass the resolved model id straight through — the primitive now
        # preserves ``config.model`` verbatim as the upstream model override,
        # so any ``_qualify_model`` wrapping here would double-prefix and
        # cause a 404 (``openai_compatible/Qwen/Qwen3.5-4B does not exist``).
        config = self._OpenCodeConfig(
            provider=provider,
            base_url=vllm_url.rstrip("/"),
            api_key=api_key,
            model=model,
            agent_timeout_s=agent_timeout_s,
            proxy_disable_thinking=disable_thinking,
            proxy_max_tokens_cap=max_tokens_cap if max_tokens_cap > 0 else None,
        )
        factory = self._OpenCodeSessionFactory(
            config=config,
            sandbox_backend=self._E2BSandboxBackend(),
            mode=mode,  # "transparent_proxy" or "black_box"
            verifier=None,  # we run the caller's test_script ourselves below
        )
        merged_uploads = dict(upload_files)
        merged_uploads[REMOTE_TEST_PATH] = test_script
        task = OpenCodeTask(
            instruction=instruction,
            setup_shell=setup_shell or None,
            upload_files=merged_uploads,
            metadata={"task_id": task_id},
        )
        session = None
        try:
            session = factory.create(task=task)
            result.sandbox_id = session.sandbox.sandbox_id
            exit_code = session.wait_for_completion(timeout_s=agent_timeout_s)
            result.exit_code = int(exit_code)
            # Run the verifier. Exit code is ignored; the reward file is the
            # source of truth.
            session.sandbox.exec(
                f"mkdir -p /home/user/logs/verifier /home/user/tests && "
                f"chmod +x {REMOTE_TEST_PATH}",
                timeout=15,
            )
            verifier_run = session.sandbox.exec(
                f"bash {REMOTE_TEST_PATH}",
                cwd=WORKDIR_PATH,
                timeout=VERIFIER_TIMEOUT_S,
            )
            result.test_exit_code = int(verifier_run.exit_code)
            result.verifier_stdout = (verifier_run.stdout or "")[:4000]
            result.verifier_stderr = (verifier_run.stderr or "")[:2000]
            result.reward = _read_reward(session.sandbox, REMOTE_REWARD_PATH)
            # Collect artifacts via the primitive's summary helper.
            summary = self._collect_rollout_summary(session)
            result.agent_log_tail = _tail(summary.opencode_events, 20)
            result.workdir_files = {
                path: (contents or "")[:8000]
                for path, contents in (summary.workdir_contents or {}).items()
            }
            for raw in summary.proxy_turns:
                result.proxy_turns.append(self._turn_cls(**_clamp_turn(raw)))
            # Diagnostic capture — always try to grab the in-sandbox proxy log.
            # Useful when productive_turns=0 or opencode exits abnormally.
            result.proxy_log_tail = _read_safe(
                session.sandbox, "/home/user/logs/agent/proxy.log", tail_chars=4000
            )
        except Exception as exc:
            result.error = f"{type(exc).__name__}: {exc}"
            # Even on failure, try to pull the proxy log if the sandbox still
            # exists — it often contains the actual upstream HTTP error.
            if session is not None:
                result.proxy_log_tail = _read_safe(
                    session.sandbox, "/home/user/logs/agent/proxy.log", tail_chars=4000
                )
        finally:
            if session is not None:
                try:
                    session.close()
                except Exception:
                    pass
        result.wall_s = round(time.time() - t0, 3)
        # Persist lightweight state for bookkeeping.
        self._state.rollouts_completed += 1
        self._state.last_reward = result.reward
        self._state.last_task_id = task_id or None
        self._state.last_sandbox_id = result.sandbox_id or None
        return result.model_dump_json()

    # -- Async rollout plumbing (Phase 2b) -----------------------------------
    def _spawn_async_rollout(
        self,
        *,
        rollout_id: str,
        vllm_url: str,
        model: str,
        instruction: str,
        test_script: str,
        task_id: str,
        setup_shell: str,
        upload_files: dict[str, str],
        provider: str,
        api_key: str,
        mode: str,
        disable_thinking: bool,
        max_tokens_cap: int,
        agent_timeout_s: float,
    ) -> _RolloutHandle:
        """Register a handle and start a daemon worker that owns the session.

        Returns immediately; the worker sets ``handle.session`` /
        ``handle.error`` and fires ``handle._done`` when it exits.
        """
        from opencode_env import OpenCodeTask
        merged_uploads = dict(upload_files)
        # Unlike the blocking path, test_script is optional here.
        if test_script:
            merged_uploads[REMOTE_TEST_PATH] = test_script
        task = OpenCodeTask(
            instruction=instruction,
            setup_shell=setup_shell or None,
            upload_files=merged_uploads,
            metadata={"task_id": task_id},
        )
        # Pass model verbatim (no _qualify_model) — primitive now uses
        # ``config.model`` as the upstream override directly.
        config = self._OpenCodeConfig(
            provider=provider,
            base_url=vllm_url.rstrip("/"),
            api_key=api_key,
            model=model,
            agent_timeout_s=agent_timeout_s,
            proxy_disable_thinking=disable_thinking,
            proxy_max_tokens_cap=max_tokens_cap if max_tokens_cap > 0 else None,
        )
        handle = _RolloutHandle(
            rollout_id=rollout_id,
            task_id=task_id,
            session_factory_kwargs={"config": config, "mode": mode,
                                    "agent_timeout_s": agent_timeout_s},
            task=task,
        )
        handle._test_script = test_script
        self._registry.add(handle)

        def worker() -> None:
            try:
                # serve driver: opencode serve runs inside the sandbox, the
                # primitive fires the prompt via POST /session/:id/prompt_async,
                # and ``list_messages(serve_session_id)`` is what powers the
                # live chat transcript exposed via the ``get_messages`` tool.
                factory = self._OpenCodeSessionFactory(
                    config=config,
                    sandbox_backend=self._E2BSandboxBackend(),
                    mode=mode,
                    verifier=None,
                    driver="serve",
                )
                handle.session = factory.create(task=task)
                try:
                    handle.session.wait_for_completion(timeout_s=agent_timeout_s)
                except Exception as exc:  # noqa: BLE001
                    handle.error = f"wait_for_completion: {type(exc).__name__}: {exc}"
            except Exception as exc:  # noqa: BLE001
                handle.error = f"{type(exc).__name__}: {exc}"
            finally:
                # Always mark done so finalize_rollout's wait() can return.
                handle.finished_at = time.time()
                handle._done.set()

        t = threading.Thread(target=worker, daemon=True, name=f"rollout-{rollout_id}")
        t.start()
        handle._worker = t
        return handle

    def _finalize_handle(self, handle: _RolloutHandle) -> str:
        """Run the verifier (if present), collect the trace + workdir, and
        return a JSON-serialized :class:`RolloutResult`. Closes the session."""
        result = self._result_cls(task_id=handle.task_id,
                                  mode=handle._kwargs.get("mode", ""))
        session = handle.session
        if session is None:
            # Worker never got far enough to create a sandbox.
            result.error = handle.error or "session never created"
            return result.model_dump_json()
        result.sandbox_id = session.sandbox.sandbox_id
        result.exit_code = 0
        wall_s = (handle.finished_at or time.time()) - handle.started_at
        result.wall_s = round(wall_s, 3)
        try:
            if handle._test_script:
                session.sandbox.exec(
                    "mkdir -p /home/user/logs/verifier /home/user/tests && "
                    f"chmod +x {REMOTE_TEST_PATH}",
                    timeout=15,
                )
                v = session.sandbox.exec(
                    f"bash {REMOTE_TEST_PATH}",
                    cwd=WORKDIR_PATH,
                    timeout=VERIFIER_TIMEOUT_S,
                )
                result.test_exit_code = int(v.exit_code)
                result.verifier_stdout = (v.stdout or "")[:4000]
                result.verifier_stderr = (v.stderr or "")[:2000]
                result.reward = _read_reward(session.sandbox, REMOTE_REWARD_PATH)
            summary = self._collect_rollout_summary(session)
            result.agent_log_tail = _tail(summary.opencode_events, 20)
            result.workdir_files = {
                path: (contents or "")[:8000]
                for path, contents in (summary.workdir_contents or {}).items()
            }
            for raw in summary.proxy_turns:
                result.proxy_turns.append(self._turn_cls(**_clamp_turn(raw)))
            result.proxy_log_tail = _read_safe(
                session.sandbox, "/home/user/logs/agent/proxy.log", tail_chars=4000
            )
        except Exception as exc:  # noqa: BLE001
            # Prefer the worker's earlier error over a finalize-time one.
            if not handle.error:
                result.error = f"finalize: {type(exc).__name__}: {exc}"
            else:
                result.error = handle.error
        finally:
            try:
                session.close()
            except Exception:
                pass
        if handle.error and not result.error:
            result.error = handle.error
        return result.model_dump_json()
| # ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _qualify_model(provider: str, model: str) -> str: | |
| """Return a ``<provider>/<model>`` string the primitive can split cleanly. | |
| The primitive splits ``config.model`` on the first ``/`` to recover the | |
| upstream model id. If the caller passes a model that already contains a | |
| slash (e.g. ``Qwen/Qwen3.5-4B``), we still prepend the provider so the | |
| split separates provider from model and the model part round-trips | |
| intact (``openai_compatible/Qwen/Qwen3.5-4B`` β upstream ``Qwen/Qwen3.5-4B``). | |
| """ | |
| # Strip an existing <provider>/ prefix only if it matches the configured | |
| # provider verbatim β otherwise treat the whole string as the model id. | |
| if model.startswith(provider + "/"): | |
| return model | |
| return f"{provider}/{model}" | |
| def _read_reward(sandbox: Any, reward_path: str) -> Optional[float]: | |
| try: | |
| raw = sandbox.read_text(reward_path).strip() | |
| except Exception: | |
| return None | |
| if not raw: | |
| return None | |
| try: | |
| return float(raw) | |
| except ValueError: | |
| return None | |
| def _clamp_turn(turn: dict[str, Any]) -> dict[str, Any]: | |
| """Clamp per-turn payload sizes to keep responses under a reasonable cap.""" | |
| out = dict(turn) | |
| raw_response = out.get("response") or {} | |
| choices = raw_response.get("choices") or [] | |
| first_choice = choices[0] if choices else {} | |
| compact: dict[str, Any] = { | |
| "finish_reason": first_choice.get("finish_reason"), | |
| "usage": raw_response.get("usage"), | |
| } | |
| # Surface upstream errors captured by the proxy so they reach the client. | |
| if raw_response.get("upstream_error") is not None: | |
| compact["upstream_error"] = raw_response["upstream_error"] | |
| if raw_response.get("upstream_status") is not None: | |
| compact["upstream_status"] = raw_response["upstream_status"] | |
| out["response"] = compact | |
| req = out.get("request") or {} | |
| messages = req.get("messages") or [] | |
| # Keep request messages (trainer needs them) but drop very long tool schemas. | |
| req = { | |
| "model": req.get("model"), | |
| "messages": messages, | |
| "temperature": req.get("temperature"), | |
| "top_p": req.get("top_p"), | |
| "max_tokens": req.get("max_tokens"), | |
| "max_completion_tokens": req.get("max_completion_tokens"), | |
| "logprobs": req.get("logprobs"), | |
| "top_logprobs": req.get("top_logprobs"), | |
| "stream": req.get("stream"), | |
| } | |
| out["request"] = req | |
| return out | |
| def _tail(events: list[dict[str, Any]], n: int) -> str: | |
| """Return the last ``n`` opencode event lines as a newline-joined string.""" | |
| if not events: | |
| return "" | |
| return "\n".join(json.dumps(e) for e in events[-n:]) | |
| def _read_safe(sandbox: Any, path: str, *, tail_chars: int = 4000) -> str: | |
| """Read a file from the sandbox, returning its last ``tail_chars`` chars. | |
| Returns empty string if the file is missing or unreadable. Used for | |
| diagnostic tails so we never mask a real failure with a read error. | |
| """ | |
| try: | |
| content = sandbox.read_text(path) or "" | |
| except Exception: | |
| return "" | |
| return content[-tail_chars:] | |