"""FastAPI application for the SupportDesk environment.""" from __future__ import annotations import os from typing import Any import uvicorn from fastapi import Body, HTTPException from fastapi.routing import APIRoute try: from openenv.core.env_server import http_server as openenv_http_server except ImportError: try: from openenv_core.env_server import http_server as openenv_http_server except Exception: # Minimal fallback for test runs when openenv is unavailable. from pydantic import BaseModel, ValidationError as _PydValidationError from fastapi import FastAPI class _ResetRequest(BaseModel): seed: int | None = None episode_id: str | None = None task_id: str | None = None timeout_s: float | None = None class _StepRequest(BaseModel): action: dict timeout_s: float | None = None episode_id: str | None = None def _deserialize_action(data, ActionCls): return ActionCls.model_validate(data) def _create_app(env_cls, action_cls, obs_cls, env_name: str = "env", max_concurrent_envs: int = 1): app = FastAPI() @app.post("/reset") def _reset(req: _ResetRequest = _ResetRequest()): env = env_cls() kwargs = req.model_dump(exclude_none=True) obs = env.reset(**kwargs) return {"observation": obs.model_dump(), "reward": obs.reward, "done": obs.done} @app.post("/step") def _step(req: _StepRequest): env = env_cls() action = _deserialize_action(req.action, action_cls) obs = env.step(action, timeout_s=req.timeout_s, episode_id=req.episode_id) return {"observation": obs.model_dump(), "reward": obs.reward, "done": obs.done} @app.get("/state") def _state(): env = env_cls() return env.state.model_dump() return app class _Shim: ResetRequest = _ResetRequest StepRequest = _StepRequest ValidationError = _PydValidationError deserialize_action = staticmethod(_deserialize_action) create_app = staticmethod(_create_app) openenv_http_server = _Shim() from models import SupportDeskAction, SupportDeskObservation, SupportDeskState from server.supportdesk_environment import SupportDeskEnvironment from tasks import TASKS # Bind the default OpenEnv /state route to the full typed state model. openenv_http_server.State = SupportDeskState create_app = openenv_http_server.create_app # Create the app with web interface and README integration. app = create_app( SupportDeskEnvironment, SupportDeskAction, SupportDeskObservation, env_name="supportdesk_env", max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions ) TASK_GRADER_PATHS = { "billing_refund_easy": "graders:BillingRefundEasyGrader", "account_takeover_medium": "graders:AccountTakeoverMediumGrader", "api_incident_hard": "graders:ApiIncidentHardGrader", "regulated_export_exception_hard": "graders:RegulatedExportExceptionHardGrader", } def _replace_route(path: str, methods: set[str]) -> None: """Remove a generated route so we can register a score-aware replacement.""" app.router.routes = [ route for route in app.router.routes if not ( isinstance(route, APIRoute) and route.path == path and methods.issubset(set(route.methods or set())) ) ] def _score_response(env: SupportDeskEnvironment, observation: SupportDeskObservation) -> dict[str, Any]: """Return the standard OpenEnv shape plus an explicit top-level score.""" return { "observation": observation.model_dump(), "reward": observation.reward, "done": observation.done, "score": env.state.current_score, } _replace_route("/reset", {"POST"}) _replace_route("/step", {"POST"}) @app.post("/reset") async def reset_with_score( request: openenv_http_server.ResetRequest = Body(default_factory=openenv_http_server.ResetRequest), ) -> dict[str, Any]: """Reset the environment and expose the initial deterministic score at top level.""" env = SupportDeskEnvironment() try: kwargs = request.model_dump(exclude_unset=True) observation = env.reset(**kwargs) return _score_response(env, observation) finally: env.close() @app.post("/step") async def step_with_score(request: openenv_http_server.StepRequest) -> dict[str, Any]: """Execute a step and expose the current deterministic score at top level.""" action_data = request.action try: action = openenv_http_server.deserialize_action(action_data, SupportDeskAction) except openenv_http_server.ValidationError as exc: raise HTTPException(status_code=422, detail=exc.errors()) from exc env = SupportDeskEnvironment() try: kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) observation = env.step(action, **kwargs) return _score_response(env, observation) finally: env.close() @app.get("/tasks") def list_tasks() -> dict[str, Any]: """Expose a stable task catalog for UI, debugging, and pre-submit checks.""" return { "environment": { "name": "supportdesk_env", "version": "0.1.0", "grader_type": "deterministic", "score_range": [0.0, 1.0], }, "total_tasks": len(TASKS), "tasks": [ { "task_id": task.task_id, "grader": TASK_GRADER_PATHS[task.task_id], "title": task.title, "difficulty": task.difficulty, "objective": task.objective, "max_steps": task.max_steps, "gold_issue_type": task.gold_issue_type, "gold_queue": task.gold_queue, "gold_priority": task.gold_priority, "ticket_context": { "customer_tier": task.ticket.customer_tier, "region": task.ticket.region, "affected_users": task.ticket.affected_users, "sla_minutes_remaining": task.ticket.sla_minutes_remaining, }, } for task in TASKS.values() ], } @app.get("/episodes/{episode_id}/state", response_model=SupportDeskState) def get_episode_state(episode_id: str) -> SupportDeskState: """Optional explicit state helper for robust episode-addressable inspection.""" try: return SupportDeskEnvironment.state_for_episode(episode_id) except ValueError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc @app.post("/episodes/{episode_id}/step") def step_episode( episode_id: str, payload: dict[str, Any] = Body(...), ) -> dict[str, Any]: """Optional explicit step helper that does not require sticky request context.""" action_payload = payload.get("action") if not isinstance(action_payload, dict): raise HTTPException(status_code=422, detail="Request body must include an 'action' object.") timeout_s = payload.get("timeout_s") try: action = SupportDeskAction.model_validate(action_payload) env = SupportDeskEnvironment() observation = env.step(action, timeout_s=timeout_s, episode_id=episode_id) except ValueError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc return { "observation": observation.model_dump(), "reward": observation.reward, "done": observation.done, "score": SupportDeskEnvironment.state_for_episode(episode_id).current_score, } def main(host: str = "0.0.0.0", port: int = 8000) -> None: """ Entry point for direct execution via uv run or python -m. This function enables running the server without Docker: uv run --project . server uv run --project . server --port 8001 python -m server.app Args: host: Host address to bind to (default: "0.0.0.0") port: Port number to listen on (default: 8000) For production deployments, consider using uvicorn directly with multiple workers: uvicorn server.app:app --workers 4 """ uvicorn.run("server.app:app", host=host, port=port) if __name__ == '__main__': main()