Spaces:
Sleeping
Sleeping
"""
FastAPI server exposing the OpenEnv SQL Optimizer environment.
Endpoints:
  POST /reset     → Observation
  POST /step      → {observation, reward, done, info}
  GET  /state     → state dict
  GET  /tasks     → list of tasks + action schema
  GET  /grader    → grader score for the current or last episode
  POST /baseline  → trigger baseline inference on all 3 tasks
"""
| from __future__ import annotations | |
| import os | |
| import subprocess | |
| import sys | |
| from typing import Any, Dict, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from env.environment import SQLOptimizerEnv | |
| from env.models import Action, Observation, Reward | |
| from env.tasks import TASKS | |
app = FastAPI(
    title="SQL Query Optimizer — OpenEnv",
    description=(
        "An OpenEnv-compliant environment where AI agents learn to rewrite "
        "and optimise SQL queries across three difficulty levels."
    ),
    version="1.0.0",
)

# Permissive CORS: any origin, method and header is allowed to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Single shared environment instance (stateful, per-process).
# NOTE(review): the reset/step/state/grader handlers all operate on this one
# object, so concurrent clients would interleave episodes — confirm the server
# is only ever driven by one agent at a time.
_env = SQLOptimizerEnv()
# ──────────────────────────────────────────────────────────────────────────────
# Request / Response schemas
# ──────────────────────────────────────────────────────────────────────────────
# Request body for reset(): selects which task the new episode uses.
# task_id defaults to 1 (task ids: 1=easy, 2=medium, 3=hard).
class ResetRequest(BaseModel):
    task_id: int = 1
# Response of step(): the (observation, reward, done, info) tuple returned by
# the environment's step() call, repackaged as a single JSON object.
class StepResponse(BaseModel):
    observation: Observation
    reward: Reward
    done: bool  # episode-finished flag as reported by the environment
    info: Dict[str, Any]  # extra data from the environment; schema not visible here
# Response of grader(): values read from the environment's state() dict.
class GraderResponse(BaseModel):
    # No default: under pydantic v2 this field must be passed explicitly,
    # but None is an accepted value.
    task_id: Optional[int]
    grader_score: float  # from state()["last_grader_score"]; grader() falls back to 0.0
    cumulative_score: float  # from state()["cumulative_score"]; grader() falls back to 0.0
    done: bool  # from state()["done"]; grader() falls back to False
# One entry returned by list_tasks(): task metadata plus the Action JSON schema.
class TaskInfo(BaseModel):
    id: int
    name: str
    difficulty: str  # presumably "easy"/"medium"/"hard" — TODO confirm against env.tasks
    description: str
    action_schema: Dict[str, Any]  # Action.model_json_schema(); identical for every task
# Response of baseline(): scores parsed from the JSON that baseline.py prints.
class BaselineResponse(BaseModel):
    task_results: Dict[str, float]  # key -> score; key format is decided by baseline.py
    message: str  # human-readable status line
# ──────────────────────────────────────────────────────────────────────────────
# Endpoints
# ──────────────────────────────────────────────────────────────────────────────
def _health_payload() -> Dict[str, str]:
    """Build the liveness payload shared by the health handlers.

    A brand-new dict is produced on every call, so callers may mutate it.
    """
    payload: Dict[str, str] = dict(
        status="ok",
        environment="sql-query-optimizer",
        version="1.0.0",
    )
    return payload


def health() -> Dict[str, str]:
    """Report that the server process is alive."""
    payload = _health_payload()
    return payload


def web_health() -> Dict[str, str]:
    """Return the same liveness payload as health(), via a second handler."""
    return health()
def reset(req: ResetRequest) -> Observation:
    """Start a fresh episode for ``req.task_id`` (1=easy, 2=medium, 3=hard).

    Returns the environment's initial observation. Any ValueError raised by
    the environment (presumably an unknown task id — TODO confirm) is turned
    into an HTTP 400 carrying the error text.
    """
    try:
        observation = _env.reset(task_id=req.task_id)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    else:
        return observation
def step(action: Action) -> StepResponse:
    """Feed ``action`` to the environment and return the resulting transition.

    A RuntimeError from the environment (presumably stepping without an
    active episode — TODO confirm) is reported as HTTP 400.
    """
    try:
        observation, reward, done, info = _env.step(action)
    except RuntimeError as err:
        raise HTTPException(status_code=400, detail=str(err))
    else:
        return StepResponse(
            observation=observation,
            reward=reward,
            done=done,
            info=info,
        )
def state() -> Dict[str, Any]:
    """Expose the environment's internal state dict exactly as it reports it."""
    snapshot = _env.state()
    return snapshot
def list_tasks() -> list[TaskInfo]:
    """Describe every registered task, each paired with the Action JSON schema.

    The schema is generated once per call and the same dict is attached to
    every entry.
    """
    schema = Action.model_json_schema()

    def _describe(task):
        # Copy the task's public metadata into the response model.
        return TaskInfo(
            id=task.id,
            name=task.name,
            difficulty=task.difficulty,
            description=task.description,
            action_schema=schema,
        )

    return [_describe(task) for task in TASKS.values()]
def grader() -> GraderResponse:
    """Report grading results for the episode the environment currently holds.

    The episode need not be finished; the ``done`` field says whether it is.
    Raises HTTP 400 when the environment's state has status "not_started".
    Missing scores default to 0.0 and a missing ``done`` flag to False.
    """
    snapshot = _env.state()
    if snapshot.get("status") == "not_started":
        raise HTTPException(status_code=400, detail="No episode started. Call /reset first.")
    fields = {
        "task_id": snapshot.get("task_id"),
        "grader_score": snapshot.get("last_grader_score", 0.0),
        "cumulative_score": snapshot.get("cumulative_score", 0.0),
        "done": snapshot.get("done", False),
    }
    return GraderResponse(**fields)
def baseline() -> BaselineResponse:
    """
    Trigger the baseline inference script (baseline.py) and return scores.

    Runs ``baseline.py --json`` with the current interpreter, waits up to
    300 s, and parses the script's stdout as a JSON object of scores.
    Requires OPENAI_API_KEY to be set in the environment.

    Raises:
        HTTPException(400): OPENAI_API_KEY is not set.
        HTTPException(500): the script could not be launched, timed out,
            exited with a non-zero code, or printed output that is not a
            valid ``{str: float}`` JSON object.
    """
    import json  # local import: the module header does not import json

    timeout_s = 300

    if not os.getenv("OPENAI_API_KEY"):
        raise HTTPException(
            status_code=400,
            detail="OPENAI_API_KEY environment variable not set. Cannot run baseline.",
        )

    # NOTE(review): "baseline.py" is resolved against the process's current
    # working directory, not this file's location — confirm the server is
    # always started from the project root.
    try:
        result = subprocess.run(
            [sys.executable, "baseline.py", "--json"],
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
    except subprocess.TimeoutExpired as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Baseline script timed out after {timeout_s}s.",
        ) from exc
    except OSError as exc:  # the interpreter process could not be started
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    # Checked outside any try block so this HTTPException reaches the client
    # unchanged instead of being re-wrapped (and its detail mangled to
    # "500: ...") by a catch-all handler.
    if result.returncode != 0:
        raise HTTPException(
            status_code=500,
            detail=f"Baseline script failed:\n{result.stderr}",
        )

    try:
        scores = json.loads(result.stdout)
        response = BaselineResponse(
            task_results=scores,
            message="Baseline completed successfully.",
        )
    except ValueError as exc:
        # json.JSONDecodeError and pydantic's ValidationError both subclass ValueError.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return response