# opencode-env-rollout / server/opencode_environment.py
# Uploaded with huggingface_hub by AdithyaSK (HF Staff), commit d4d3fde (verified).
"""OpenCode MCP Environment.
Exposes a single tool ``run_rollout`` that runs one OpenCode agent rollout
end-to-end against a caller-supplied LLM endpoint:
1. Spawn a fresh E2B sandbox (via the primitive's ``E2BSandboxBackend``).
2. Install opencode in the sandbox, write its config pointing at an
in-sandbox proxy (Mode B) or the caller's LLM URL directly (Mode A).
3. Stage the caller-supplied task: instruction + test.sh + any extra files.
4. Run ``opencode run`` to completion.
5. Execute the verifier script; read the scalar reward from
``/home/user/logs/verifier/reward.txt``.
6. Collect proxy trace + workdir contents.
7. Return a JSON-serialized :class:`RolloutResult`.
The env is deliberately task-agnostic β€” the training script passes the
full task (instruction + verifier) through the tool arguments.
"""
from __future__ import annotations
import json
import os
import threading
import time
from typing import Any, Optional
from uuid import uuid4
from dotenv import load_dotenv
from fastmcp import FastMCP
from openenv.core.env_server.mcp_environment import MCPEnvironment
from openenv.core.env_server.types import Action, Observation
try:
from .catalog import resolve_endpoint
except ImportError: # pragma: no cover
from catalog import resolve_endpoint # type: ignore
load_dotenv()
# One rollout (sandbox create + opencode install + opencode run + verifier)
# typically takes 30-180s and can spike to 600s under load. Override
# OpenEnv's 30s MCP tool-call default so the server doesn't cut our tool
# off mid-flight. This default still defers to an explicit ``timeout_s``
# on the StepRequest when the client specifies one.
_RUN_ROLLOUT_TIMEOUT_S = 900.0

# Default test-script and reward paths inside the sandbox. The server writes
# the caller-supplied ``test_script`` text to this path; the verifier reads
# the reward file back out after it finishes.
REMOTE_TEST_PATH = "/home/user/tests/test.sh"
REMOTE_REWARD_PATH = "/home/user/logs/verifier/reward.txt"
# Working directory the agent operates in; verifier scripts run with this cwd.
WORKDIR_PATH = "/home/user/workdir"
# Hard cap (seconds) on a single verifier-script execution in the sandbox.
VERIFIER_TIMEOUT_S = 120
class _RolloutRegistry:
"""Per-environment bookkeeping for in-flight non-blocking rollouts.
Every ``start_rollout`` call spawns a background thread that owns a
session through its full lifecycle (create sandbox β†’ install opencode
β†’ start proxy + opencode serve β†’ fire instruction β†’ wait for agent β†’
finalize on demand). The registry keeps one :class:`_RolloutHandle`
per ``rollout_id`` so subsequent ``subscribe_events`` / ``abort_rollout``
/ ``finalize`` tool calls can find it.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._handles: dict[str, "_RolloutHandle"] = {}
def add(self, handle: "_RolloutHandle") -> None:
with self._lock:
self._handles[handle.rollout_id] = handle
def get(self, rollout_id: str) -> "_RolloutHandle | None":
with self._lock:
return self._handles.get(rollout_id)
def drop(self, rollout_id: str) -> None:
with self._lock:
self._handles.pop(rollout_id, None)
class _RolloutHandle:
"""Everything the server needs to finalize / observe / abort a rollout."""
def __init__(
self,
rollout_id: str,
task_id: str,
session_factory_kwargs: dict[str, Any],
task: Any,
) -> None:
self.rollout_id = rollout_id
self.task_id = task_id
self._kwargs = session_factory_kwargs
self._task = task
# Set by the worker thread as the session progresses.
self.session: Any | None = None
self.error: str | None = None
self.started_at = time.time()
self.finished_at: float | None = None
self._done = threading.Event()
self._worker: threading.Thread | None = None
self._verifier_stdout = ""
self._verifier_stderr = ""
self._test_exit_code: int | None = None
self._test_script: str = ""
def is_done(self) -> bool:
return self._done.is_set()
def wait(self, timeout: float | None = None) -> bool:
return self._done.wait(timeout)
class OpenCodeEnvironment(MCPEnvironment):
    """Two-tier MCP environment for OpenCode rollouts.

    The **macro tool** ``run_rollout`` runs a blocking end-to-end rollout
    and returns the final :class:`RolloutResult` (kept for training, which
    doesn't benefit from per-turn streaming).

    The **fine-grained tools** ``start_rollout`` / ``subscribe_events``
    (server-side SSE via a helper endpoint) / ``abort_rollout`` / ``finalize``
    are meant for the Gradio UI and interactive inspection — they expose
    the PR #471-style session primitives directly.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        # Per-environment rollout registry (one per HTTP session/worker).
        self._registry = _RolloutRegistry()
        # Import inside __init__ to keep module import cheap and to allow
        # patching for tests. Dual-import pattern: package when installed,
        # flat when run directly out of the repo via ``server.app:app``.
        try:
            from ..models import OpenCodeState, RolloutResult, RolloutTurn
        except ImportError:
            from models import OpenCodeState, RolloutResult, RolloutTurn  # type: ignore
        from opencode_env import (
            E2BSandboxBackend,
            OpenCodeConfig,
            OpenCodeSessionFactory,
            collect_rollout_summary,
        )
        from openenv.core.harness import VerifyResult
        # Stash the lazily-imported classes on the instance so the rollout
        # helpers below can use them without re-importing.
        self._state_cls = OpenCodeState
        self._result_cls = RolloutResult
        self._turn_cls = RolloutTurn
        self._OpenCodeConfig = OpenCodeConfig
        self._OpenCodeSessionFactory = OpenCodeSessionFactory
        self._E2BSandboxBackend = E2BSandboxBackend
        self._collect_rollout_summary = collect_rollout_summary
        # NOTE(review): ``_VerifyResult`` is stored but not referenced anywhere
        # else in this module — presumably kept for subclasses or tests; confirm.
        self._VerifyResult = VerifyResult
        # Require E2B credentials up front — fail loudly if unset.
        if not os.environ.get("E2B_API_KEY"):
            raise RuntimeError(
                "E2B_API_KEY environment variable is required for OpenCodeEnvironment"
            )
        self._state = self._state_cls(episode_id=str(uuid4()))
        mcp = FastMCP("opencode_env")

        @mcp.tool
        def run_rollout(
            model_key: str,
            instruction: str,
            test_script: str,
            vllm_url: str = "",
            hf_token: str = "",
            thinking: bool = False,
            task_id: str = "",
            setup_shell: str = "",
            upload_files: Optional[dict[str, str]] = None,
            mode: str = "transparent_proxy",
            max_tokens_cap: int = 4096,
            agent_timeout_s: float = 600.0,
        ) -> str:
            """Run one OpenCode rollout end-to-end.

            Args:
                model_key: Catalog key — one of the entries in
                    :data:`server.catalog.CATALOG`. Shape is
                    ``"vllm://<repo>"`` or ``"hf-router://<repo>:<provider>"``.
                instruction: Prompt passed to ``opencode run``.
                test_script: Bash verifier. Must write a float reward to
                    ``/home/user/logs/verifier/reward.txt``.
                vllm_url: Required when ``model_key`` is a ``vllm://...``
                    entry. The tunneled or in-cluster ``/v1`` endpoint.
                hf_token: Required when ``model_key`` is a
                    ``hf-router://...`` entry. User's HF token.
                thinking: Enable Qwen-style thinking mode. Ignored for
                    models where ``supports_thinking`` is False. Passed to
                    the proxy as ``chat_template_kwargs.enable_thinking``.
                task_id: Optional identifier echoed back for traceability.
                setup_shell: Optional shell run before opencode starts.
                upload_files: Optional ``{remote_path: content}`` staged
                    into the sandbox.
                mode: ``"transparent_proxy"`` (captures per-turn logprobs)
                    or ``"black_box"`` (direct connection, no logprobs).
                max_tokens_cap: Clamp forwarded ``max_tokens``.
                agent_timeout_s: Max opencode runtime in seconds.

            Returns:
                JSON-serialized :class:`RolloutResult`.
            """
            # Resolve the catalog key to a concrete endpoint + model id.
            base_url, api_key, model, _entry = resolve_endpoint(
                model_key, vllm_url=vllm_url, hf_token=hf_token
            )
            return self._run_rollout_impl(
                vllm_url=base_url,
                model=model,
                instruction=instruction,
                test_script=test_script,
                task_id=task_id,
                setup_shell=setup_shell,
                upload_files=upload_files or {},
                provider="openai_compatible",
                api_key=api_key,
                mode=mode,
                disable_thinking=not bool(thinking),
                max_tokens_cap=max_tokens_cap,
                agent_timeout_s=agent_timeout_s,
            )

        # ── Fine-grained tools (non-blocking) — Phase 2b ────────────────────
        #
        # These wrap the primitive's ``driver="serve"`` path so the UI and
        # interactive clients can observe a rollout in flight. Training keeps
        # using the blocking ``run_rollout`` tool above.
        @mcp.tool
        def start_rollout(
            model_key: str,
            instruction: str,
            test_script: str = "",
            vllm_url: str = "",
            hf_token: str = "",
            thinking: bool = False,
            task_id: str = "",
            setup_shell: str = "",
            upload_files: Optional[dict[str, str]] = None,
            mode: str = "transparent_proxy",
            max_tokens_cap: int = 4096,
            agent_timeout_s: float = 600.0,
        ) -> str:
            """Start a rollout asynchronously; return a ``rollout_id`` immediately.

            Same uniform args as :func:`run_rollout`: ``model_key``, plus
            ``vllm_url`` OR ``hf_token`` (depending on backend), plus
            ``thinking``. Spawns a background worker that creates the
            sandbox, installs opencode, boots ``opencode serve``, and
            fires the instruction. The caller then uses
            ``subscribe_events`` / ``get_state`` / ``abort_rollout`` /
            ``finalize`` with the returned id.
            """
            base_url, api_key, model, _entry = resolve_endpoint(
                model_key, vllm_url=vllm_url, hf_token=hf_token
            )
            # Short random id is enough — uniqueness only needs to hold
            # within this environment's registry.
            rid = uuid4().hex[:12]
            handle = self._spawn_async_rollout(
                rollout_id=rid,
                vllm_url=base_url,
                model=model,
                instruction=instruction,
                test_script=test_script,
                task_id=task_id,
                setup_shell=setup_shell,
                upload_files=upload_files or {},
                provider="openai_compatible",
                api_key=api_key,
                mode=mode,
                disable_thinking=not bool(thinking),
                max_tokens_cap=max_tokens_cap,
                agent_timeout_s=agent_timeout_s,
            )
            return json.dumps({
                "rollout_id": handle.rollout_id,
                "task_id": handle.task_id,
                "status": "starting",
                "started_at": handle.started_at,
            })

        @mcp.tool
        def get_state(rollout_id: str) -> str:
            """Return a small JSON snapshot of the rollout's current state.

            Useful for UI polling between SSE reconnects. The payload mirrors
            what an ``agent_done`` event would carry plus a ``status`` field.
            """
            handle = self._registry.get(rollout_id)
            if handle is None:
                return json.dumps({"rollout_id": rollout_id, "status": "unknown"})
            session = handle.session
            serve_session_id = getattr(session, "serve_session_id", None)
            # Cheap turn estimate: one JSONL line per proxy-captured turn.
            turns_n = 0
            if session is not None and session._proxy_trace_path:
                try:
                    content = session.sandbox.read_text(session._proxy_trace_path)
                    turns_n = sum(1 for line in content.splitlines() if line.strip())
                except Exception:
                    # Best-effort: trace may not exist yet or sandbox may be gone.
                    pass
            return json.dumps({
                "rollout_id": rollout_id,
                "task_id": handle.task_id,
                "status": "done" if handle.is_done() else "running",
                "serve_session_id": serve_session_id,
                "proxy_turns_so_far": turns_n,
                "error": handle.error,
                "started_at": handle.started_at,
                "finished_at": handle.finished_at,
            })

        @mcp.tool
        def get_messages(rollout_id: str) -> str:
            """Return the sandbox-side opencode serve transcript for a rollout.

            Shape matches opencode's ``GET /session/:id/message`` —
            ``{"messages": [{info, parts}, ...]}``. Empty ``messages`` list
            if the rollout hasn't created its serve session yet, isn't
            running under the ``serve`` driver, or fetching the transcript
            failed. Designed for UI polling to render a live chat view.
            """
            handle = self._registry.get(rollout_id)
            if handle is None:
                return json.dumps({"rollout_id": rollout_id, "messages": [], "status": "unknown"})
            session = handle.session
            status = "done" if handle.is_done() else "running"
            if session is None:
                # Worker hasn't created the sandbox/session yet (or failed early).
                return json.dumps({
                    "rollout_id": rollout_id,
                    "messages": [],
                    "status": status,
                    "error": handle.error,
                })
            serve_client = getattr(session, "serve_client", None)
            serve_sid = getattr(session, "serve_session_id", None)
            if serve_client is None or not serve_sid:
                # CLI driver or serve session not booted — no transcript source.
                return json.dumps({
                    "rollout_id": rollout_id,
                    "messages": [],
                    "status": status,
                    "note": "no serve driver (transcript unavailable)",
                })
            try:
                msgs = serve_client.list_messages(serve_sid) or []
            except Exception as exc:  # noqa: BLE001
                return json.dumps({
                    "rollout_id": rollout_id,
                    "messages": [],
                    "status": status,
                    "error": f"list_messages failed: {type(exc).__name__}: {exc}",
                })
            return json.dumps({
                "rollout_id": rollout_id,
                "messages": msgs,
                "status": status,
                "serve_session_id": serve_sid,
            })

        @mcp.tool
        def abort_rollout(rollout_id: str) -> str:
            """Cancel an in-flight rollout.

            Calls ``POST /session/:id/abort`` on the sandbox's opencode serve
            (which stops token generation immediately), then lets the worker
            unwind and close the sandbox. Returns a small JSON payload.
            """
            handle = self._registry.get(rollout_id)
            if handle is None:
                return json.dumps({"rollout_id": rollout_id, "aborted": False, "reason": "unknown"})
            aborted = False
            try:
                if handle.session is not None and handle.session.driver == "serve":
                    aborted = bool(handle.session.abort())
                elif handle.session is not None:
                    # CLI driver: kill the sandbox to stop opencode mid-run.
                    handle.session.sandbox.kill()
                    aborted = True
            except Exception as exc:  # noqa: BLE001
                return json.dumps({
                    "rollout_id": rollout_id, "aborted": False, "reason": str(exc),
                })
            return json.dumps({"rollout_id": rollout_id, "aborted": aborted})

        @mcp.tool
        def finalize_rollout(rollout_id: str, wait_s: float = 300.0) -> str:
            """Block until the rollout's worker finishes, then return its
            full :class:`RolloutResult` as JSON (same shape as ``run_rollout``).

            Must be called exactly once per rollout — this also de-registers
            the handle and closes the sandbox.
            """
            handle = self._registry.get(rollout_id)
            if handle is None:
                # Unknown id — return an empty-shaped result with an error field
                # so clients can parse it uniformly.
                return json.dumps({
                    "task_id": "", "sandbox_id": "", "reward": None,
                    "exit_code": 0, "wall_s": 0.0, "mode": "",
                    "proxy_turns": [], "workdir_files": {},
                    "agent_log_tail": "", "verifier_stdout": "",
                    "verifier_stderr": "", "test_exit_code": None,
                    "error": f"unknown rollout_id: {rollout_id}",
                    "proxy_log_tail": "", "install_log_tail": "",
                })
            if not handle.wait(timeout=wait_s):
                # Timed out waiting for the worker. The handle stays registered
                # so the caller may retry finalize later.
                return json.dumps({
                    "task_id": handle.task_id, "sandbox_id": "",
                    "reward": None, "exit_code": -1, "wall_s": 0.0,
                    "mode": "", "proxy_turns": [], "workdir_files": {},
                    "agent_log_tail": "", "verifier_stdout": "",
                    "verifier_stderr": "", "test_exit_code": None,
                    "error": f"timeout waiting for rollout_id={rollout_id}",
                    "proxy_log_tail": "", "install_log_tail": "",
                })
            payload = self._finalize_handle(handle)
            self._registry.drop(rollout_id)
            return payload

        # Hand the configured FastMCP app to the OpenEnv MCP base class.
        super().__init__(mcp)
# ── OpenEnv lifecycle ───────────────────────────────────────────────────
def reset(
self,
seed: Optional[int] = None,
episode_id: Optional[str] = None,
**_: Any,
) -> Observation:
self._state = self._state_cls(episode_id=episode_id or str(uuid4()))
return Observation(
done=False,
reward=None,
metadata={
"status": "ready",
"message": "OpenCode env ready. Call run_rollout(...) with a task.",
},
)
def _step_impl(
self,
action: Action,
timeout_s: Optional[float] = None,
**_: Any,
) -> Observation:
return Observation(
done=False,
reward=None,
metadata={
"error": (
f"Unknown action type: {type(action).__name__}. "
"Use CallToolAction(name='run_rollout', ...)."
),
},
)
def step(
self,
action: Action,
timeout_s: Optional[float] = None,
**kwargs: Any,
) -> Observation:
"""Override the MCP 30s default with a rollout-friendly budget."""
if timeout_s is None:
timeout_s = _RUN_ROLLOUT_TIMEOUT_S
return super().step(action, timeout_s=timeout_s, **kwargs)
async def step_async(
self,
action: Action,
timeout_s: Optional[float] = None,
**kwargs: Any,
) -> Observation:
if timeout_s is None:
timeout_s = _RUN_ROLLOUT_TIMEOUT_S
return await super().step_async(action, timeout_s=timeout_s, **kwargs)
    @property
    def state(self) -> Any:
        """Current :class:`OpenCodeState` snapshot (episode id + rollout stats)."""
        return self._state
# ── Rollout implementation ──────────────────────────────────────────────
    def _run_rollout_impl(
        self,
        *,
        vllm_url: str,
        model: str,
        instruction: str,
        test_script: str,
        task_id: str,
        setup_shell: str,
        upload_files: dict[str, str],
        provider: str,
        api_key: str,
        mode: str,
        disable_thinking: bool,
        max_tokens_cap: int,
        agent_timeout_s: float,
    ) -> str:
        """Blocking rollout: sandbox → opencode run → verifier → JSON result.

        Creates a session via the primitive's factory, waits for the agent
        to finish, runs the caller-supplied ``test_script`` in the sandbox,
        reads the scalar reward from ``REMOTE_REWARD_PATH``, collects
        artifacts, and always closes the sandbox. Errors never propagate —
        they are serialized into the returned :class:`RolloutResult`.

        Returns:
            JSON-serialized :class:`RolloutResult` (``model_dump_json``).
        """
        from opencode_env import OpenCodeTask
        result = self._result_cls(task_id=task_id, mode=mode)
        t0 = time.time()
        # Pass the resolved model id straight through — the primitive now
        # preserves ``config.model`` verbatim as the upstream model override,
        # so any ``_qualify_model`` wrapping here would double-prefix and
        # cause a 404 (``openai_compatible/Qwen/Qwen3.5-4B does not exist``).
        config = self._OpenCodeConfig(
            provider=provider,
            base_url=vllm_url.rstrip("/"),
            api_key=api_key,
            model=model,
            agent_timeout_s=agent_timeout_s,
            proxy_disable_thinking=disable_thinking,
            proxy_max_tokens_cap=max_tokens_cap if max_tokens_cap > 0 else None,
        )
        factory = self._OpenCodeSessionFactory(
            config=config,
            sandbox_backend=self._E2BSandboxBackend(),
            mode=mode,  # "transparent_proxy" or "black_box"
            verifier=None,  # we run the caller's test_script ourselves below
        )
        # Stage the verifier script alongside any caller-supplied files.
        merged_uploads = dict(upload_files)
        merged_uploads[REMOTE_TEST_PATH] = test_script
        task = OpenCodeTask(
            instruction=instruction,
            setup_shell=setup_shell or None,
            upload_files=merged_uploads,
            metadata={"task_id": task_id},
        )
        session = None
        try:
            session = factory.create(task=task)
            result.sandbox_id = session.sandbox.sandbox_id
            exit_code = session.wait_for_completion(timeout_s=agent_timeout_s)
            result.exit_code = int(exit_code)
            # Run the verifier. Exit code is ignored; the reward file is the
            # source of truth.
            session.sandbox.exec(
                f"mkdir -p /home/user/logs/verifier /home/user/tests && "
                f"chmod +x {REMOTE_TEST_PATH}",
                timeout=15,
            )
            verifier_run = session.sandbox.exec(
                f"bash {REMOTE_TEST_PATH}",
                cwd=WORKDIR_PATH,
                timeout=VERIFIER_TIMEOUT_S,
            )
            result.test_exit_code = int(verifier_run.exit_code)
            # Truncate captured output to keep the serialized result small.
            result.verifier_stdout = (verifier_run.stdout or "")[:4000]
            result.verifier_stderr = (verifier_run.stderr or "")[:2000]
            result.reward = _read_reward(session.sandbox, REMOTE_REWARD_PATH)
            # Collect artifacts via the primitive's summary helper.
            summary = self._collect_rollout_summary(session)
            result.agent_log_tail = _tail(summary.opencode_events, 20)
            result.workdir_files = {
                path: (contents or "")[:8000]
                for path, contents in (summary.workdir_contents or {}).items()
            }
            for raw in summary.proxy_turns:
                result.proxy_turns.append(self._turn_cls(**_clamp_turn(raw)))
            # Diagnostic capture — always try to grab the in-sandbox proxy log.
            # Useful when productive_turns=0 or opencode exits abnormally.
            result.proxy_log_tail = _read_safe(
                session.sandbox, "/home/user/logs/agent/proxy.log", tail_chars=4000
            )
        except Exception as exc:
            result.error = f"{type(exc).__name__}: {exc}"
            # Even on failure, try to pull the proxy log if the sandbox still
            # exists — it often contains the actual upstream HTTP error.
            if session is not None:
                result.proxy_log_tail = _read_safe(
                    session.sandbox, "/home/user/logs/agent/proxy.log", tail_chars=4000
                )
        finally:
            if session is not None:
                try:
                    session.close()
                except Exception:
                    # Best-effort cleanup; never mask the rollout outcome.
                    pass
        result.wall_s = round(time.time() - t0, 3)
        # Persist lightweight state for bookkeeping.
        self._state.rollouts_completed += 1
        self._state.last_reward = result.reward
        self._state.last_task_id = task_id or None
        self._state.last_sandbox_id = result.sandbox_id or None
        return result.model_dump_json()
    # ── Async rollout plumbing (Phase 2b) ────────────────────────────────
    def _spawn_async_rollout(
        self,
        *,
        rollout_id: str,
        vllm_url: str,
        model: str,
        instruction: str,
        test_script: str,
        task_id: str,
        setup_shell: str,
        upload_files: dict[str, str],
        provider: str,
        api_key: str,
        mode: str,
        disable_thinking: bool,
        max_tokens_cap: int,
        agent_timeout_s: float,
    ) -> _RolloutHandle:
        """Kick off a rollout on a daemon thread and return its handle.

        Builds the task + config, registers a :class:`_RolloutHandle` so
        the fine-grained tools can find it, then starts a worker that
        creates the session under ``driver="serve"`` and waits for the
        agent. The handle's ``_done`` event is always set when the worker
        unwinds, even on failure.
        """
        from opencode_env import OpenCodeTask
        merged_uploads = dict(upload_files)
        # Unlike the blocking path, test_script is optional here; only stage
        # it when provided so finalize knows whether to run a verifier.
        if test_script:
            merged_uploads[REMOTE_TEST_PATH] = test_script
        task = OpenCodeTask(
            instruction=instruction,
            setup_shell=setup_shell or None,
            upload_files=merged_uploads,
            metadata={"task_id": task_id},
        )
        # Pass model verbatim (no _qualify_model) — primitive now uses
        # ``config.model`` as the upstream override directly.
        config = self._OpenCodeConfig(
            provider=provider,
            base_url=vllm_url.rstrip("/"),
            api_key=api_key,
            model=model,
            agent_timeout_s=agent_timeout_s,
            proxy_disable_thinking=disable_thinking,
            proxy_max_tokens_cap=max_tokens_cap if max_tokens_cap > 0 else None,
        )
        handle = _RolloutHandle(
            rollout_id=rollout_id,
            task_id=task_id,
            session_factory_kwargs={"config": config, "mode": mode,
                                    "agent_timeout_s": agent_timeout_s},
            task=task,
        )
        handle._test_script = test_script
        # Register before the thread starts so tool calls racing the worker
        # can already resolve the rollout_id.
        self._registry.add(handle)

        def worker() -> None:
            try:
                # serve driver: opencode serve runs inside the sandbox, the
                # primitive fires the prompt via POST /session/:id/prompt_async,
                # and ``list_messages(serve_session_id)`` is what powers the
                # live chat transcript exposed via the ``get_messages`` tool.
                factory = self._OpenCodeSessionFactory(
                    config=config,
                    sandbox_backend=self._E2BSandboxBackend(),
                    mode=mode,
                    verifier=None,
                    driver="serve",
                )
                handle.session = factory.create(task=task)
                try:
                    handle.session.wait_for_completion(timeout_s=agent_timeout_s)
                except Exception as exc:  # noqa: BLE001
                    # Keep the session alive for finalize; just record the wait error.
                    handle.error = f"wait_for_completion: {type(exc).__name__}: {exc}"
            except Exception as exc:  # noqa: BLE001
                handle.error = f"{type(exc).__name__}: {exc}"
            finally:
                handle.finished_at = time.time()
                handle._done.set()

        t = threading.Thread(target=worker, daemon=True, name=f"rollout-{rollout_id}")
        t.start()
        handle._worker = t
        return handle
    def _finalize_handle(self, handle: _RolloutHandle) -> str:
        """Run the verifier (if present), collect the trace + workdir, and
        return a JSON-serialized :class:`RolloutResult`. Closes the session.

        Error precedence: the worker's recorded ``handle.error`` wins over
        any error raised during finalization itself.
        """
        result = self._result_cls(task_id=handle.task_id,
                                  mode=handle._kwargs.get("mode", ""))
        session = handle.session
        if session is None:
            # Worker failed before creating a session — nothing to collect.
            result.error = handle.error or "session never created"
            return result.model_dump_json()
        result.sandbox_id = session.sandbox.sandbox_id
        # NOTE(review): exit_code is hard-coded to 0 on this path — the serve
        # driver apparently has no CLI exit code to report; confirm.
        result.exit_code = 0
        wall_s = (handle.finished_at or time.time()) - handle.started_at
        result.wall_s = round(wall_s, 3)
        try:
            # Only run the verifier when a test script was staged at start.
            if handle._test_script:
                session.sandbox.exec(
                    "mkdir -p /home/user/logs/verifier /home/user/tests && "
                    f"chmod +x {REMOTE_TEST_PATH}",
                    timeout=15,
                )
                v = session.sandbox.exec(
                    f"bash {REMOTE_TEST_PATH}",
                    cwd=WORKDIR_PATH,
                    timeout=VERIFIER_TIMEOUT_S,
                )
                result.test_exit_code = int(v.exit_code)
                result.verifier_stdout = (v.stdout or "")[:4000]
                result.verifier_stderr = (v.stderr or "")[:2000]
                result.reward = _read_reward(session.sandbox, REMOTE_REWARD_PATH)
            # Collect artifacts + proxy trace regardless of verifier presence.
            summary = self._collect_rollout_summary(session)
            result.agent_log_tail = _tail(summary.opencode_events, 20)
            result.workdir_files = {
                path: (contents or "")[:8000]
                for path, contents in (summary.workdir_contents or {}).items()
            }
            for raw in summary.proxy_turns:
                result.proxy_turns.append(self._turn_cls(**_clamp_turn(raw)))
            result.proxy_log_tail = _read_safe(
                session.sandbox, "/home/user/logs/agent/proxy.log", tail_chars=4000
            )
        except Exception as exc:  # noqa: BLE001
            if not handle.error:
                result.error = f"finalize: {type(exc).__name__}: {exc}"
            else:
                result.error = handle.error
        finally:
            try:
                session.close()
            except Exception:
                # Best-effort sandbox teardown.
                pass
        if handle.error and not result.error:
            result.error = handle.error
        return result.model_dump_json()
# ── Helpers ─────────────────────────────────────────────────────────────────
def _qualify_model(provider: str, model: str) -> str:
"""Return a ``<provider>/<model>`` string the primitive can split cleanly.
The primitive splits ``config.model`` on the first ``/`` to recover the
upstream model id. If the caller passes a model that already contains a
slash (e.g. ``Qwen/Qwen3.5-4B``), we still prepend the provider so the
split separates provider from model and the model part round-trips
intact (``openai_compatible/Qwen/Qwen3.5-4B`` β†’ upstream ``Qwen/Qwen3.5-4B``).
"""
# Strip an existing <provider>/ prefix only if it matches the configured
# provider verbatim β€” otherwise treat the whole string as the model id.
if model.startswith(provider + "/"):
return model
return f"{provider}/{model}"
def _read_reward(sandbox: Any, reward_path: str) -> Optional[float]:
try:
raw = sandbox.read_text(reward_path).strip()
except Exception:
return None
if not raw:
return None
try:
return float(raw)
except ValueError:
return None
def _clamp_turn(turn: dict[str, Any]) -> dict[str, Any]:
"""Clamp per-turn payload sizes to keep responses under a reasonable cap."""
out = dict(turn)
raw_response = out.get("response") or {}
choices = raw_response.get("choices") or []
first_choice = choices[0] if choices else {}
compact: dict[str, Any] = {
"finish_reason": first_choice.get("finish_reason"),
"usage": raw_response.get("usage"),
}
# Surface upstream errors captured by the proxy so they reach the client.
if raw_response.get("upstream_error") is not None:
compact["upstream_error"] = raw_response["upstream_error"]
if raw_response.get("upstream_status") is not None:
compact["upstream_status"] = raw_response["upstream_status"]
out["response"] = compact
req = out.get("request") or {}
messages = req.get("messages") or []
# Keep request messages (trainer needs them) but drop very long tool schemas.
req = {
"model": req.get("model"),
"messages": messages,
"temperature": req.get("temperature"),
"top_p": req.get("top_p"),
"max_tokens": req.get("max_tokens"),
"max_completion_tokens": req.get("max_completion_tokens"),
"logprobs": req.get("logprobs"),
"top_logprobs": req.get("top_logprobs"),
"stream": req.get("stream"),
}
out["request"] = req
return out
def _tail(events: list[dict[str, Any]], n: int) -> str:
"""Return the last ``n`` opencode event lines as a newline-joined string."""
if not events:
return ""
return "\n".join(json.dumps(e) for e in events[-n:])
def _read_safe(sandbox: Any, path: str, *, tail_chars: int = 4000) -> str:
"""Read a file from the sandbox, returning its last ``tail_chars`` chars.
Returns empty string if the file is missing or unreadable. Used for
diagnostic tails so we never mask a real failure with a read error.
"""
try:
content = sandbox.read_text(path) or ""
except Exception:
return ""
return content[-tail_chars:]