# Uploaded via huggingface_hub by AdithyaSK (HF Staff) — commit 698641d (verified).
"""FastAPI application for the OpenCode Environment.
Mirrors the ``e2b_desktop`` pattern: pass a ``gradio_builder`` to
``create_app`` and let OpenEnv handle the Gradio mount (including the
HF-Space-compatible ``/web`` path). No manual ``gr.mount_gradio_app``.
Also mounts a bespoke SSE endpoint at ``GET /rollouts/{rollout_id}/events``
that multiplexes opencode serve's ``/event`` stream with our proxy's
per-turn frames. MCP tools don't support streaming; this gives the UI
and interactive clients a live feed.
Usage::
# Development:
E2B_API_KEY=... uv run uvicorn server.app:app --reload
# Via uv project script:
E2B_API_KEY=... uv run --project . server
# Docker:
docker run -p 8000:8000 -e E2B_API_KEY=... opencode-openenv
"""
from __future__ import annotations
import os
try:
from openenv.core.env_server.http_server import create_app
from openenv.core.env_server.mcp_types import (
CallToolAction,
CallToolObservation,
)
from .opencode_environment import OpenCodeEnvironment
from .gradio_ui import opencode_ui_builder
except ImportError:
from openenv.core.env_server.http_server import create_app
from openenv.core.env_server.mcp_types import (
CallToolAction,
CallToolObservation,
)
from server.opencode_environment import OpenCodeEnvironment
from server.gradio_ui import opencode_ui_builder
def _custom_gradio_builder(
    web_manager,
    action_fields,
    metadata,
    is_chat_env,
    title,
    quick_start_md,
):
    """Build the custom Gradio UI when invoked by ``create_app``.

    Every parameter is intentionally ignored: ``web_manager``'s public API
    is ``reset_environment`` / ``step_environment`` / ``connect_websocket``
    rather than an env instance, so the UI is handed the env class
    directly — the same pattern e2b_desktop uses.
    """
    # Explicitly discard the callback arguments we never read.
    del web_manager, action_fields, metadata, is_chat_env, title, quick_start_md
    ui = opencode_ui_builder(env_factory=OpenCodeEnvironment)
    return ui
# Enable OpenEnv's built-in Gradio mounting at the standard /web path.
# setdefault keeps any operator-provided override intact.
os.environ.setdefault("ENABLE_WEB_INTERFACE", "true")

# The ASGI application: OpenEnv wires the MCP tool endpoints for
# OpenCodeEnvironment and mounts the Gradio UI via _custom_gradio_builder.
app = create_app(
    OpenCodeEnvironment,
    CallToolAction,
    CallToolObservation,
    env_name="opencode_env",
    # Cap on concurrent sandboxed envs; overridable via MAX_CONCURRENT_ENVS.
    max_concurrent_envs=int(os.getenv("MAX_CONCURRENT_ENVS", "4")),
    gradio_builder=_custom_gradio_builder,
)
def _find_active_environment(request):
    """Locate a currently-active ``OpenCodeEnvironment`` instance.

    ``create_app`` stores per-session envs internally; we don't have a
    public accessor, so we poke at ``app.state`` attributes that match
    OpenEnv's conventions. As a last resort we create a fresh env —
    fine for single-worker Spaces because registries live in-process
    and the default env is idle until a tool is called.

    Args:
        request: Unused; accepted only for call-site compatibility.

    Returns:
        An env instance, or ``None`` if none was found and a fresh one
        could not be constructed.
    """
    # Probe the app.state attribute names OpenEnv conventions suggest for
    # the per-session env cache; return any env from the first hit.
    for attr_name in ("env_cache", "envs", "environments", "_envs"):
        cache = getattr(app.state, attr_name, None)
        if not cache:
            continue
        # Narrowed from a blanket `except Exception: pass`: with the
        # truthiness guard above, the only realistic failure is the
        # container emptying out from under us between check and read.
        try:
            if isinstance(cache, dict):
                return next(iter(cache.values()))
            if isinstance(cache, (list, tuple)):
                return cache[-1]
        except (StopIteration, IndexError):
            pass
    # Fallback — make a new env. Safe because the SSE endpoint only
    # needs the _registry dict, which we then look up rollout_id in.
    try:
        return OpenCodeEnvironment()
    except Exception:  # best-effort: construction may need e.g. E2B creds
        return None
@app.get("/rollouts/{rollout_id}/events")
async def rollout_events(rollout_id: str):
    """Server-Sent Events feed for a rollout started via ``start_rollout``.

    Merges two streams:

    1. opencode serve's ``GET /event`` (session-level events: message
       parts, tool calls, idle/abort markers) — forwarded as-is.
    2. our proxy's ``proxy_trace.jsonl`` in the sandbox (per-turn
       LLM turns + logprobs) — tailed and emitted as
       ``{type: "proxy.turn", turn, tokens, logprobs, finish_reason, ...}``.

    Terminates on a final ``{"type": "rollout.done", ...}`` frame once the
    session has idled or erred.
    """
    # Local imports keep these dependencies off the module import path.
    import asyncio
    import json as _json
    from starlette.responses import StreamingResponse

    # Best-effort lookup of the env that holds the rollout registry.
    env = _find_active_environment(None)
    if env is None:
        # One-shot error frame; a plain iterator satisfies StreamingResponse.
        return StreamingResponse(
            iter([f"data: {_json.dumps({'type': 'error', 'reason': 'env not found'})}\n\n"]),
            media_type="text/event-stream",
        )
    # NOTE(review): relies on the env's private ``_registry`` mapping of
    # rollout_id -> handle — confirm against OpenCodeEnvironment.
    registry = getattr(env, "_registry", None)
    handle = registry.get(rollout_id) if registry else None
    if handle is None:
        # Unknown rollout id: emit a single SSE error frame and close.
        async def _single_error():
            yield (
                "data: "
                + _json.dumps({"type": "error", "rollout_id": rollout_id, "reason": "unknown rollout"})
                + "\n\n"
            )
        return StreamingResponse(_single_error(), media_type="text/event-stream")

    async def _gen():
        """Async generator yielding the merged SSE frame stream."""
        # Wait briefly for the serve client to be wired by the worker.
        # At most 60 * 0.25s = 15s before proceeding anyway.
        for _ in range(60):
            if handle.session is not None and getattr(handle.session, "serve_client", None):
                break
            if handle.is_done():
                break
            await asyncio.sleep(0.25)
        session = handle.session
        if session is None:
            # Worker never created a session — report the error and stop.
            yield (
                "data: "
                + _json.dumps({
                    "type": "error",
                    "rollout_id": rollout_id,
                    "reason": "session never created",
                    "detail": handle.error,
                })
                + "\n\n"
            )
            return
        sandbox = session.sandbox
        # NOTE(review): private attribute on the session — presumably the
        # in-sandbox path of proxy_trace.jsonl; confirm in session class.
        proxy_trace_path = session._proxy_trace_path
        serve_client = getattr(session, "serve_client", None)

        # Task A: forward opencode serve events.
        # Both producer tasks feed this one queue; ``None`` is the
        # end-of-stream sentinel posted by forward_serve's ``finally``.
        serve_q: asyncio.Queue = asyncio.Queue()

        async def forward_serve():
            # Early return posts NO sentinel; the consumer's 1s timeout
            # plus is_done() polling covers that case.
            if serve_client is None:
                return
            try:
                async for ev in serve_client.astream_events():
                    await serve_q.put({"source": "serve", **ev})
                    if handle.is_done():
                        break
            except Exception as exc:  # noqa: BLE001
                # Surface stream failures to the client as an SSE frame.
                await serve_q.put({"source": "serve", "type": "error", "reason": str(exc)})
            finally:
                # End-of-stream sentinel for the consumer loop.
                await serve_q.put(None)

        # Task B: tail proxy trace file (incremental) from the sandbox.
        async def tail_proxy():
            last_len = 0  # characters of the trace file already emitted
            while not handle.is_done():
                try:
                    if proxy_trace_path:
                        # Re-read the whole file each tick; emit only the
                        # suffix beyond what was seen last time.
                        content = sandbox.read_text(proxy_trace_path) or ""
                        if len(content) > last_len:
                            new = content[last_len:]
                            last_len = len(content)
                            for line in new.splitlines():
                                line = line.strip()
                                if not line:
                                    continue
                                try:
                                    turn = _json.loads(line)
                                except Exception:
                                    # Partial/corrupt JSONL line — skip it.
                                    continue
                                # Compact per-turn summary frame.
                                await serve_q.put({
                                    "source": "proxy",
                                    "type": "proxy.turn",
                                    "turn": turn.get("turn"),
                                    "finish_reason": turn.get("finish_reason"),
                                    "n_tokens": len(turn.get("completion_tokens") or []),
                                    "first_tokens": (turn.get("completion_tokens") or [])[:6],
                                    "first_logps": (turn.get("per_token_logps") or [])[:6],
                                    "latency_s": turn.get("latency_s"),
                                })
                except Exception:
                    # Best-effort tail: sandbox reads can transiently fail.
                    pass
                await asyncio.sleep(1.0)

        t_serve = asyncio.create_task(forward_serve())
        t_proxy = asyncio.create_task(tail_proxy())
        try:
            while True:
                try:
                    ev = await asyncio.wait_for(serve_q.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    ev = None  # fall through to the is_done() check below
                if ev is None:
                    # Either a timeout or the serve sentinel: only stop
                    # once the rollout itself has actually finished.
                    if handle.is_done():
                        break
                    continue
                yield "data: " + _json.dumps(ev) + "\n\n"
        finally:
            # Stop producers on normal exit or client disconnect.
            t_serve.cancel()
            t_proxy.cancel()
        # Final frame signalling completion (plus any terminal error).
        yield "data: " + _json.dumps({
            "source": "server",
            "type": "rollout.done",
            "rollout_id": rollout_id,
            "error": handle.error,
        }) + "\n\n"

    return StreamingResponse(_gen(), media_type="text/event-stream")
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve the FastAPI app with uvicorn; blocks until shutdown."""
    from uvicorn import run as _serve

    _serve(app, host=host, port=port)


if __name__ == "__main__":
    main()