Spaces:
Sleeping
Sleeping
| """FastAPI application for the SupportDesk environment.""" | |
| from __future__ import annotations | |
| import os | |
| from typing import Any | |
| import uvicorn | |
| from fastapi import Body, HTTPException | |
| from fastapi.routing import APIRoute | |
try:
    from openenv.core.env_server import http_server as openenv_http_server
except ImportError:
    try:
        from openenv_core.env_server import http_server as openenv_http_server
    except ImportError:
        # Minimal fallback for test runs when neither openenv package is importable.
        # It provides only the attributes this module reads from ``http_server``:
        # ResetRequest, StepRequest, ValidationError, deserialize_action and
        # create_app (plus a writable ``State`` attribute, since it is an instance).
        from fastapi import FastAPI
        from pydantic import BaseModel, ValidationError as _PydValidationError

        class _ResetRequest(BaseModel):
            """Body of POST /reset; None-valued fields are not forwarded to reset()."""

            seed: int | None = None
            episode_id: str | None = None
            task_id: str | None = None
            timeout_s: float | None = None

        class _StepRequest(BaseModel):
            """Body of POST /step: a raw action mapping plus optional step options."""

            action: dict
            timeout_s: float | None = None
            episode_id: str | None = None

        def _deserialize_action(data, ActionCls):
            """Validate ``data`` into ``ActionCls`` via ``model_validate``.

            Raises:
                pydantic.ValidationError: if ``data`` does not fit ``ActionCls``.
            """
            return ActionCls.model_validate(data)

        def _create_app(env_cls, action_cls, obs_cls, env_name: str = "env", max_concurrent_envs: int = 1):
            """Build a FastAPI app serving POST /reset, POST /step and GET /state.

            Every request constructs a fresh ``env_cls()`` and closes it afterwards,
            matching the score-aware handlers registered later in this module.
            ``obs_cls``, ``env_name`` and ``max_concurrent_envs`` are accepted only
            for signature compatibility with openenv's ``create_app``; unused here.
            """
            app = FastAPI()

            @app.post("/reset")
            def _reset(req: _ResetRequest = Body(default_factory=_ResetRequest)):
                env = env_cls()
                try:
                    obs = env.reset(**req.model_dump(exclude_none=True))
                    return {"observation": obs.model_dump(), "reward": obs.reward, "done": obs.done}
                finally:
                    env.close()

            @app.post("/step")
            def _step(req: _StepRequest):
                # Validate before creating an environment; a bad payload is a client error.
                try:
                    action = _deserialize_action(req.action, action_cls)
                except _PydValidationError as exc:
                    raise HTTPException(status_code=422, detail=exc.errors()) from exc
                env = env_cls()
                try:
                    obs = env.step(action, timeout_s=req.timeout_s, episode_id=req.episode_id)
                    return {"observation": obs.model_dump(), "reward": obs.reward, "done": obs.done}
                finally:
                    env.close()

            @app.get("/state")
            def _state():
                env = env_cls()
                try:
                    return env.state.model_dump()
                finally:
                    env.close()

            return app

        class _Shim:
            """Stand-in exposing the subset of openenv's ``http_server`` used here."""

            ResetRequest = _ResetRequest
            StepRequest = _StepRequest
            ValidationError = _PydValidationError
            deserialize_action = staticmethod(_deserialize_action)
            create_app = staticmethod(_create_app)

        openenv_http_server = _Shim()
| from models import SupportDeskAction, SupportDeskObservation, SupportDeskState | |
| from server.supportdesk_environment import SupportDeskEnvironment | |
| from tasks import TASKS | |
create_app = openenv_http_server.create_app

# The default OpenEnv /state route should serialize the full typed state model;
# this must be set before create_app() is called.
openenv_http_server.State = SupportDeskState

app = create_app(
    SupportDeskEnvironment,
    SupportDeskAction,
    SupportDeskObservation,
    env_name="supportdesk_env",
    # Raise this to allow more concurrent WebSocket sessions.
    max_concurrent_envs=1,
)
# task_id -> grader reference string ("graders:<GraderClass>"), surfaced per task
# in the catalog built by list_tasks().
TASK_GRADER_PATHS = dict(
    billing_refund_easy="graders:BillingRefundEasyGrader",
    account_takeover_medium="graders:AccountTakeoverMediumGrader",
    api_incident_hard="graders:ApiIncidentHardGrader",
    regulated_export_exception_hard="graders:RegulatedExportExceptionHardGrader",
)
def _replace_route(path: str, methods: set[str]) -> None:
    """Drop every ``APIRoute`` on ``app`` at ``path`` that serves all of ``methods``.

    This clears a generated route so a score-aware replacement can be registered.
    Non-``APIRoute`` routes, and routes at ``path`` lacking any of ``methods``,
    are kept. The router's route list is rebound, not mutated in place.
    """

    def _is_target(route: Any) -> bool:
        if not isinstance(route, APIRoute):
            return False
        if route.path != path:
            return False
        return methods.issubset(set(route.methods or set()))

    app.router.routes = [route for route in app.router.routes if not _is_target(route)]
def _score_response(env: SupportDeskEnvironment, observation: SupportDeskObservation) -> dict[str, Any]:
    """Build a response body: OpenEnv's observation/reward/done plus ``score``.

    ``score`` is read from ``env.state.current_score`` at call time.
    """
    payload: dict[str, Any] = dict(
        observation=observation.model_dump(),
        reward=observation.reward,
        done=observation.done,
    )
    payload["score"] = env.state.current_score
    return payload
# Clear the generated POST /reset and POST /step routes so score-aware
# handlers can take their place.
for _route_path in ("/reset", "/step"):
    _replace_route(_route_path, {"POST"})
del _route_path
@app.post("/reset")
async def reset_with_score(
    request: openenv_http_server.ResetRequest = Body(default_factory=openenv_http_server.ResetRequest),
) -> dict[str, Any]:
    """Reset a fresh environment and return the OpenEnv response plus a top-level ``score``.

    This is registered on POST /reset; ``_replace_route`` removes the generated
    route first, so this handler is the one served. Only fields the client
    actually sent are forwarded to ``env.reset`` (``exclude_unset``). The
    environment is always closed, even when ``reset`` raises.
    """
    env = SupportDeskEnvironment()
    try:
        observation = env.reset(**request.model_dump(exclude_unset=True))
        return _score_response(env, observation)
    finally:
        env.close()
@app.post("/step")
async def step_with_score(request: openenv_http_server.StepRequest) -> dict[str, Any]:
    """Validate the action, step a fresh environment, and add a top-level ``score``.

    This is registered on POST /step; ``_replace_route`` removes the generated
    route first, so this handler is the one served. Every request field other
    than ``action`` that the client actually sent is forwarded to ``env.step``
    as a keyword argument. The environment is always closed.

    Raises:
        HTTPException: 422 when the action payload fails validation.
    """
    try:
        action = openenv_http_server.deserialize_action(request.action, SupportDeskAction)
    except openenv_http_server.ValidationError as exc:
        raise HTTPException(status_code=422, detail=exc.errors()) from exc

    env = SupportDeskEnvironment()
    try:
        step_kwargs = request.model_dump(exclude_unset=True, exclude={"action"})
        observation = env.step(action, **step_kwargs)
        return _score_response(env, observation)
    finally:
        env.close()
def list_tasks() -> dict[str, Any]:
    """Return environment metadata plus one catalog entry per task in ``TASKS``.

    This is a stable task catalog for UI, debugging, and pre-submit checks.
    Each entry carries the task's grader reference from ``TASK_GRADER_PATHS``
    (a task missing from that mapping raises ``KeyError``), its gold labels,
    and a summary of its ticket.
    """
    # NOTE(review): no route decorator is visible for this function; confirm it
    # is registered on ``app`` if it is meant to be served over HTTP.

    def _entry(task: Any) -> dict[str, Any]:
        entry = {
            "task_id": task.task_id,
            "grader": TASK_GRADER_PATHS[task.task_id],
            "title": task.title,
            "difficulty": task.difficulty,
            "objective": task.objective,
            "max_steps": task.max_steps,
            "gold_issue_type": task.gold_issue_type,
            "gold_queue": task.gold_queue,
            "gold_priority": task.gold_priority,
        }
        ticket = task.ticket
        entry["ticket_context"] = {
            "customer_tier": ticket.customer_tier,
            "region": ticket.region,
            "affected_users": ticket.affected_users,
            "sla_minutes_remaining": ticket.sla_minutes_remaining,
        }
        return entry

    environment = {
        "name": "supportdesk_env",
        "version": "0.1.0",
        "grader_type": "deterministic",
        "score_range": [0.0, 1.0],
    }
    return {
        "environment": environment,
        "total_tasks": len(TASKS),
        "tasks": [_entry(task) for task in TASKS.values()],
    }
def get_episode_state(episode_id: str) -> SupportDeskState:
    """Return the stored state for ``episode_id`` (explicit, episode-addressable lookup).

    Raises:
        HTTPException: 404, carrying the error message, when
            ``SupportDeskEnvironment.state_for_episode`` raises ``ValueError``.
    """
    # NOTE(review): no route decorator is visible for this function; confirm it
    # is registered on ``app`` if it is meant to be served over HTTP.
    try:
        state = SupportDeskEnvironment.state_for_episode(episode_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    return state
def step_episode(
    episode_id: str,
    payload: dict[str, Any] = Body(...),
) -> dict[str, Any]:
    """Apply one action to episode ``episode_id`` without relying on sticky request context.

    ``payload`` must contain an ``action`` object and may contain ``timeout_s``.
    The response is the OpenEnv observation/reward/done shape plus the episode's
    current ``score`` (looked up via ``SupportDeskEnvironment.state_for_episode``).

    Raises:
        HTTPException: 422 when ``action`` is missing, is not an object, or fails
            validation; 404 when stepping or the state lookup raises ``ValueError``
            (presumably an unknown or unusable ``episode_id``).
    """
    # NOTE(review): no route decorator is visible for this function; confirm it
    # is registered on ``app`` if it is meant to be served over HTTP.
    action_payload = payload.get("action")
    if not isinstance(action_payload, dict):
        raise HTTPException(status_code=422, detail="Request body must include an 'action' object.")
    timeout_s = payload.get("timeout_s")

    # Validate outside the 404 handler below: pydantic's ValidationError is a
    # ValueError subclass, so a malformed action would otherwise become a 404.
    try:
        action = SupportDeskAction.model_validate(action_payload)
    except openenv_http_server.ValidationError as exc:
        raise HTTPException(status_code=422, detail=exc.errors()) from exc

    env = SupportDeskEnvironment()
    try:
        observation = env.step(action, timeout_s=timeout_s, episode_id=episode_id)
        score = SupportDeskEnvironment.state_for_episode(episode_id).current_score
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    finally:
        # Close like the /reset and /step handlers do.
        env.close()

    return {
        "observation": observation.model_dump(),
        "reward": observation.reward,
        "done": observation.done,
        "score": score,
    }
def main(host: str = "0.0.0.0", port: int = 8000, argv: list[str] | None = None) -> None:
    """
    Entry point for direct execution via uv run or python -m.

    This function enables running the server without Docker:
        uv run --project . server
        uv run --project . server --port 8001
        python -m server.app

    ``--host`` / ``--port`` given on the command line override the ``host`` /
    ``port`` arguments; other command-line arguments are ignored.

    Args:
        host: Host address to bind to (default: "0.0.0.0")
        port: Port number to listen on (default: 8000)
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    For production deployments, consider using uvicorn directly with
    multiple workers:
        uvicorn server.app:app --workers 4
    NOTE(review): if episode state is held in process memory, separate workers
    will not share it — confirm before scaling out.
    """
    import argparse

    # add_help=False so a stray -h from a wrapping launcher does not exit the
    # process; parse_known_args so unrelated flags are tolerated, not fatal.
    parser = argparse.ArgumentParser(prog="server", add_help=False)
    parser.add_argument("--host", default=host)
    parser.add_argument("--port", type=int, default=port)
    args, _unknown = parser.parse_known_args(argv)
    uvicorn.run("server.app:app", host=args.host, port=args.port)


# Allow `python -m server.app` to start the server.
if __name__ == '__main__':
    main()