modelbuilderhq's picture
Upload folder using huggingface_hub
2ade2c6 verified
"""FastAPI application for the SupportDesk environment."""
from __future__ import annotations
import os
from typing import Any
import uvicorn
from fastapi import Body, HTTPException
from fastapi.routing import APIRoute
# Resolve the OpenEnv HTTP server module: prefer the `openenv` package, then the
# legacy `openenv_core` package, and finally a tiny in-file stand-in.
try:
    from openenv.core.env_server import http_server as openenv_http_server
except ImportError:
    try:
        from openenv_core.env_server import http_server as openenv_http_server
    except Exception:
        # Minimal fallback for test runs when openenv is unavailable.
        from pydantic import BaseModel, ValidationError as _PydValidationError
        from fastapi import FastAPI

        class _ResetRequest(BaseModel):
            """Body accepted by the fallback /reset route; every field is optional."""

            seed: int | None = None
            episode_id: str | None = None
            task_id: str | None = None
            timeout_s: float | None = None

        class _StepRequest(BaseModel):
            """Body accepted by the fallback /step route; only `action` is required."""

            action: dict
            timeout_s: float | None = None
            episode_id: str | None = None

        def _deserialize_action(data, ActionCls):
            """Validate raw action data against the given pydantic action class."""
            return ActionCls.model_validate(data)

        def _create_app(env_cls, action_cls, obs_cls, env_name: str = "env", max_concurrent_envs: int = 1):
            """Build a bare FastAPI app exposing /reset, /step and /state.

            `obs_cls`, `env_name` and `max_concurrent_envs` are accepted only for
            signature parity with the real factory; this fallback does not use them.
            """
            fallback_app = FastAPI()

            def _as_payload(result):
                # Shared response shape for the reset and step routes.
                return {"observation": result.model_dump(), "reward": result.reward, "done": result.done}

            @fallback_app.post("/reset")
            def _reset(req: _ResetRequest = _ResetRequest()):
                environment = env_cls()
                # Forward only the fields that carry a value.
                reset_options = req.model_dump(exclude_none=True)
                return _as_payload(environment.reset(**reset_options))

            @fallback_app.post("/step")
            def _step(req: _StepRequest):
                environment = env_cls()
                parsed_action = _deserialize_action(req.action, action_cls)
                result = environment.step(parsed_action, timeout_s=req.timeout_s, episode_id=req.episode_id)
                return _as_payload(result)

            @fallback_app.get("/state")
            def _state():
                return env_cls().state.model_dump()

            return fallback_app

        class _Shim:
            """Exposes the attributes this module reads from the real http_server module."""

            ResetRequest = _ResetRequest
            StepRequest = _StepRequest
            ValidationError = _PydValidationError
            deserialize_action = staticmethod(_deserialize_action)
            create_app = staticmethod(_create_app)

        openenv_http_server = _Shim()
from models import SupportDeskAction, SupportDeskObservation, SupportDeskState
from server.supportdesk_environment import SupportDeskEnvironment
from tasks import TASKS
# Factory used to build the FastAPI application (real OpenEnv module or the fallback shim).
create_app = openenv_http_server.create_app

# Point the default OpenEnv /state route at the full typed SupportDesk state model.
# This must happen before the app is created below.
openenv_http_server.State = SupportDeskState

# Build the application for the SupportDesk environment.
# NOTE(review): the upstream template describes this factory as also wiring in the
# web interface and README integration; that is not visible from this file.
app = create_app(
    SupportDeskEnvironment,
    SupportDeskAction,
    SupportDeskObservation,
    env_name="supportdesk_env",
    max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
)
# Maps each task id to the "module:ClassName" path of its grader.
# The /tasks catalog looks every task up here by its task_id.
TASK_GRADER_PATHS = {
    "billing_refund_easy": "graders:BillingRefundEasyGrader",
    "account_takeover_medium": "graders:AccountTakeoverMediumGrader",
    "api_incident_hard": "graders:ApiIncidentHardGrader",
    "regulated_export_exception_hard": "graders:RegulatedExportExceptionHardGrader",
}
def _replace_route(path: str, methods: set[str]) -> None:
    """Drop generated routes for `path` so a score-aware handler can take their place.

    A route is removed when it is an ``APIRoute``, its path equals `path`, and
    its HTTP methods include every entry of `methods`. Other routes are kept in
    their existing order. This only removes; the replacement is registered
    separately by the caller.
    """

    def _is_generated_target(candidate) -> bool:
        if not isinstance(candidate, APIRoute):
            return False
        if candidate.path != path:
            return False
        return methods.issubset(set(candidate.methods or set()))

    surviving_routes = []
    for candidate in app.router.routes:
        if _is_generated_target(candidate):
            continue
        surviving_routes.append(candidate)
    app.router.routes = surviving_routes
def _score_response(env: SupportDeskEnvironment, observation: SupportDeskObservation) -> dict[str, Any]:
    """Build the standard OpenEnv response and add the current score at top level.

    Args:
        env: Environment whose ``state.current_score`` is reported as ``score``.
        observation: Observation supplying the dumped body, ``reward`` and ``done``.

    Returns:
        A dict with the keys ``observation``, ``reward``, ``done`` and ``score``.
    """
    payload: dict[str, Any] = {
        "observation": observation.model_dump(),
        "reward": observation.reward,
        "done": observation.done,
    }
    payload["score"] = env.state.current_score
    return payload
# Remove the factory-generated POST /reset and POST /step routes before the
# score-aware handlers below are registered on the same paths.
_replace_route("/reset", {"POST"})
_replace_route("/step", {"POST"})
@app.post("/reset")
async def reset_with_score(
    request: openenv_http_server.ResetRequest = Body(default_factory=openenv_http_server.ResetRequest),
) -> dict[str, Any]:
    """Reset the environment and report the initial score as a top-level field.

    Args:
        request: Optional reset body; only fields the client actually sent are
            forwarded to ``env.reset``.

    Returns:
        The standard response (observation, reward, done) plus ``score``.
    """
    environment = SupportDeskEnvironment()
    try:
        # exclude_unset keeps model defaults out of the reset call.
        reset_options = request.model_dump(exclude_unset=True)
        return _score_response(environment, environment.reset(**reset_options))
    finally:
        # Close the per-request environment whether or not reset succeeded.
        environment.close()
@app.post("/step")
async def step_with_score(request: openenv_http_server.StepRequest) -> dict[str, Any]:
    """Run one step and report the current score as a top-level field.

    Args:
        request: Step body; ``action`` is validated as a ``SupportDeskAction`` and
            any other fields the client sent are forwarded to ``env.step``.

    Returns:
        The standard response (observation, reward, done) plus ``score``.

    Raises:
        HTTPException: 422 when the action payload fails validation.
    """
    # Validate the action before an environment is created.
    try:
        action = openenv_http_server.deserialize_action(request.action, SupportDeskAction)
    except openenv_http_server.ValidationError as exc:
        raise HTTPException(status_code=422, detail=exc.errors()) from exc

    environment = SupportDeskEnvironment()
    try:
        # Forward everything the client set except the action itself.
        step_options = request.model_dump(exclude_unset=True, exclude={"action"})
        return _score_response(environment, environment.step(action, **step_options))
    finally:
        environment.close()
@app.get("/tasks")
def list_tasks() -> dict[str, Any]:
    """Return a stable task catalog for UI, debugging, and pre-submit checks.

    Returns:
        A dict with static ``environment`` metadata, ``total_tasks`` (the number
        of entries in ``TASKS``) and ``tasks``, one entry per task in ``TASKS``
        iteration order.
    """

    def _catalog_entry(task) -> dict[str, Any]:
        # One catalog row: task metadata, its grader path, and ticket context.
        return {
            "task_id": task.task_id,
            "grader": TASK_GRADER_PATHS[task.task_id],
            "title": task.title,
            "difficulty": task.difficulty,
            "objective": task.objective,
            "max_steps": task.max_steps,
            "gold_issue_type": task.gold_issue_type,
            "gold_queue": task.gold_queue,
            "gold_priority": task.gold_priority,
            "ticket_context": {
                "customer_tier": task.ticket.customer_tier,
                "region": task.ticket.region,
                "affected_users": task.ticket.affected_users,
                "sla_minutes_remaining": task.ticket.sla_minutes_remaining,
            },
        }

    environment_info = {
        "name": "supportdesk_env",
        "version": "0.1.0",
        "grader_type": "deterministic",
        "score_range": [0.0, 1.0],
    }
    return {
        "environment": environment_info,
        "total_tasks": len(TASKS),
        "tasks": [_catalog_entry(task) for task in TASKS.values()],
    }
@app.get("/episodes/{episode_id}/state", response_model=SupportDeskState)
def get_episode_state(episode_id: str) -> SupportDeskState:
    """Return the state of one episode, addressed explicitly by its id.

    Args:
        episode_id: Identifier of the episode to inspect.

    Raises:
        HTTPException: 404 when the state lookup raises ``ValueError``.
    """
    try:
        episode_state = SupportDeskEnvironment.state_for_episode(episode_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    return episode_state
@app.post("/episodes/{episode_id}/step")
def step_episode(
    episode_id: str,
    payload: dict[str, Any] = Body(...),
) -> dict[str, Any]:
    """Step a specific episode without relying on sticky request context.

    Args:
        episode_id: Identifier of the episode to advance (taken from the path).
        payload: JSON body; must contain an ``action`` object and may contain
            ``timeout_s``, which is forwarded to ``env.step`` unchanged.

    Returns:
        The standard response (observation, reward, done) plus the episode's
        current ``score``.

    Raises:
        HTTPException: 422 when ``action`` is missing, is not an object, or fails
            ``SupportDeskAction`` validation; 404 when stepping the episode
            raises ``ValueError``.
    """
    action_payload = payload.get("action")
    if not isinstance(action_payload, dict):
        raise HTTPException(status_code=422, detail="Request body must include an 'action' object.")
    timeout_s = payload.get("timeout_s")

    # Validate the action on its own. Pydantic's ValidationError is a ValueError
    # subclass, so it would otherwise be caught by the 404 handler below and a
    # malformed action would be reported as "not found" instead of 422.
    try:
        action = SupportDeskAction.model_validate(action_payload)
    except openenv_http_server.ValidationError as exc:
        raise HTTPException(status_code=422, detail=exc.errors()) from exc

    env = SupportDeskEnvironment()
    try:
        observation = env.step(action, timeout_s=timeout_s, episode_id=episode_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    finally:
        # Close the per-request environment, as the /reset and /step handlers do.
        env.close()

    return {
        "observation": observation.model_dump(),
        "reward": observation.reward,
        "done": observation.done,
        "score": SupportDeskEnvironment.state_for_episode(episode_id).current_score,
    }
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """
    Entry point for direct execution via uv run or python -m.

    This function enables running the server without Docker:
        uv run --project . server
        uv run --project . server --port 8001
        python -m server.app

    ``--host`` and ``--port`` on the command line override the arguments below.
    When a flag is absent the corresponding argument is used, so programmatic
    calls behave as before. Unrecognized command-line arguments are ignored.

    Args:
        host: Host address to bind to (default: "0.0.0.0")
        port: Port number to listen on (default: 8000)

    For production deployments, consider using uvicorn directly with
    multiple workers:
        uvicorn server.app:app --workers 4
    """
    # Local import keeps this block self-contained.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the SupportDesk environment server.",
        allow_abbrev=False,
    )
    parser.add_argument("--host", default=host, help="Host address to bind to.")
    parser.add_argument("--port", type=int, default=port, help="Port number to listen on.")
    # parse_known_args: tolerate unrelated arguments instead of exiting, because
    # this function may also be called from a process that owns sys.argv.
    cli_args, _unknown = parser.parse_known_args()

    uvicorn.run("server.app:app", host=cli_args.host, port=cli_args.port)
# Start the server when this module is executed as a script.
if __name__ == "__main__":
    main()