Spaces:
Sleeping
Sleeping
File size: 4,420 Bytes
761f203 4a1d6d9 761f203 4a1d6d9 761f203 4a1d6d9 761f203 4a1d6d9 761f203 4a1d6d9 f53d90b 4a1d6d9 761f203 4c6255d 4a1d6d9 4c6255d 4a1d6d9 4c6255d 761f203 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | from __future__ import annotations
from pathlib import Path
from typing import Any
from fastapi import Body, FastAPI, HTTPException, Query
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field, ValidationError
from env.environment import FlakySleuthEnv
from env.models import FlakySleuthAction, FlakySleuthObservation
from server.inference_runner import InferenceRunner
from server.ui import render_home_page
# ASGI application; main() below serves it with uvicorn.
app = FastAPI(title="FlakySleuth Environment")
# Single module-level environment instance: /reset, /step and /state all
# operate on this same object.
# NOTE(review): FastAPI runs sync handlers in a threadpool, so concurrent
# requests may touch this instance simultaneously — confirm that is acceptable.
env = FlakySleuthEnv()
# Runner behind the /web/inference/* endpoints, constructed with the directory
# two levels above this file (presumably the project root — verify against
# InferenceRunner).
inference_runner = InferenceRunner(Path(__file__).resolve().parent.parent)
# JSON Schema advertised under the "state" key of GET /schema.
# NOTE(review): GET /state returns env.state() without validating it against
# this model, so this class presumably mirrors that dict — keep them in sync.
# No class docstring on purpose: pydantic copies it into the "description" of
# model_json_schema(), which would change the /schema response.
class FlakySleuthState(BaseModel):
    # Optional descriptors of the current task; None when unset (presumably
    # before the first reset — confirm against FlakySleuthEnv).
    repo_url: str | None = None
    test_name: str | None = None
    task_type: str | None = None
    # Required counters/progress; presumably per-episode — verify in the env.
    step_count: int
    files_read: list[str]
    cumulative_progress: float
# Request body for POST /web/inference/start. Its model_dump() is handed
# unchanged to InferenceRunner.start().
class InferenceRunRequest(BaseModel):
    # Task dataset location; presumably resolved relative to the project root.
    dataset_path: str = Field(default="dataset/py_tasks.csv")
    # Number of episodes per task, validated to the range 1..100.
    episodes_per_task: int = Field(default=1, ge=1, le=100)
    # Comma-separated task types to run.
    task_types: str = Field(default="classify,root_cause,fix_proposal")
    # Step limit, validated to 1..100 (presumably per episode — confirm).
    max_steps: int = Field(default=20, ge=1, le=100)
    # Label for the run.
    benchmark_name: str = Field(default="flakysleuth")
    # Optional model-endpoint overrides; None presumably means "use the
    # runner's default" — TODO confirm in InferenceRunner.
    api_base_url: str | None = None
    model_name: str | None = None
    api_key: str | None = None
@app.post("/reset")
def reset() -> dict[str, Any]:
    # Start a new episode. There is no reward yet and the episode is not over.
    fresh_obs = env.reset()
    response: dict[str, Any] = {"observation": fresh_obs.model_dump()}
    response["reward"] = None
    response["done"] = False
    return response
@app.post("/step")
def step(payload: dict[str, Any] = Body(...)) -> dict[str, Any]:
    """Accept either {'action': {...}} or direct action payload.

    The action is validated as a FlakySleuthAction and applied to the shared
    environment. The response carries the resulting observation, reward,
    done flag and info.

    Raises:
        HTTPException(422): the action payload fails validation.
        HTTPException(400): the environment raises RuntimeError on step.
    """
    import json  # local import, in the same on-demand style as main()

    action_payload = payload.get("action", payload)
    try:
        action = FlakySleuthAction.model_validate(action_payload)
    except ValidationError as exc:
        # exc.errors() may hold non-JSON-serializable objects in "ctx" (for
        # example the exception raised by a custom validator). That would make
        # the HTTPException response fail to serialize and surface as a 500.
        # exc.json() renders such values as strings, so round-trip through it.
        raise HTTPException(status_code=422, detail=json.loads(exc.json())) from exc
    try:
        observation, reward, done, info = env.step(action)
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {
        "observation": observation.model_dump(),
        "reward": reward,
        "done": done,
        "info": info,
    }
@app.get("/state")
def state() -> dict[str, Any]:
    # Return whatever the environment reports as its current state, unmodified.
    current = env.state()
    return current
@app.get("/schema")
def schema() -> dict[str, Any]:
    # JSON Schemas of the action, observation and state models, keyed by role
    # (insertion order: action, observation, state).
    role_models = (
        ("action", FlakySleuthAction),
        ("observation", FlakySleuthObservation),
        ("state", FlakySleuthState),
    )
    return {role: model.model_json_schema() for role, model in role_models}
@app.get("/health")
def health() -> dict[str, str]:
    # Constant payload; answering at all signals the process is serving
    # (presumably polled as a health check).
    status_payload = {"status": "healthy"}
    return status_payload
@app.get("/", include_in_schema=False)
def root() -> HTMLResponse:
    """Serve the HTML home page at the site root."""
    page_html = render_home_page()
    return HTMLResponse(content=page_html)
@app.get("/web", include_in_schema=False)
def web() -> HTMLResponse:
    """Serve the same HTML home page under /web."""
    page_html = render_home_page()
    return HTMLResponse(content=page_html)
@app.post("/web/inference/start", include_in_schema=False)
def start_inference(payload: InferenceRunRequest) -> dict[str, Any]:
    """Launch an inference run and return whatever the runner reports.

    Runner errors are mapped to HTTP statuses, checked in this order:
    FileNotFoundError -> 404, ValueError -> 422, RuntimeError -> 409.
    """
    run_config = payload.model_dump()
    try:
        return inference_runner.start(run_config)
    except (FileNotFoundError, ValueError, RuntimeError) as exc:
        # Same precedence as separate except clauses: first match wins.
        if isinstance(exc, FileNotFoundError):
            http_status = 404
        elif isinstance(exc, ValueError):
            http_status = 422
        else:
            http_status = 409
        raise HTTPException(status_code=http_status, detail=str(exc)) from exc
@app.get("/web/inference/status", include_in_schema=False)
def inference_status(tail: int = Query(default=450, ge=20, le=2000)) -> dict[str, Any]:
    """Return the runner's snapshot, forwarding ``tail`` (validated to 20..2000).

    ``tail`` presumably bounds how much trailing output the snapshot includes;
    confirm in InferenceRunner.snapshot.
    """
    current_snapshot = inference_runner.snapshot(tail=tail)
    return current_snapshot
@app.post("/web/inference/stop", include_in_schema=False)
def stop_inference() -> dict[str, Any]:
    """Ask the runner to stop, then return a fresh snapshot tagged with the outcome.

    The snapshot is taken after stop() and uses the same tail (450) as the
    status endpoint's default. The stop() result is written into it under
    "stopped".
    """
    was_stopped = inference_runner.stop()
    result = inference_runner.snapshot(tail=450)
    # Mutates the dict returned by snapshot(), exactly as before.
    result["stopped"] = was_stopped
    return result
@app.get("/metadata")
def metadata() -> dict[str, str]:
    # Static name/description pair identifying this environment.
    description = "RL environment for flaky-test investigation in Python repositories."
    return {"name": "FlakySleuth Environment", "description": description}
@app.post("/mcp")
def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    # Minimal JSON-RPC 2.0 stub: every request, whatever its method, gets an
    # "ok" result echoing the caller's id (None when absent or no body is sent).
    echoed_id = payload.get("id")
    return {"jsonrpc": "2.0", "id": echoed_id, "result": {"status": "ok"}}
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve ``app`` with uvicorn on ``host``:``port`` (blocks until shutdown).

    uvicorn is imported here, so it is only needed when actually serving.
    """
    import uvicorn

    uvicorn.run(app, port=port, host=host)
# Running this file directly starts the server with main()'s default host/port.
if __name__ == "__main__":
    main()
|