File size: 4,420 Bytes
761f203
 
4a1d6d9
761f203
 
4a1d6d9
 
 
761f203
 
 
4a1d6d9
 
761f203
 
 
4a1d6d9
761f203
 
 
 
 
 
 
 
 
 
 
4a1d6d9
 
f53d90b
4a1d6d9
 
 
 
 
 
 
 
761f203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c6255d
4a1d6d9
 
4c6255d
 
 
4a1d6d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c6255d
 
761f203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from __future__ import annotations

from pathlib import Path
from typing import Any

from fastapi import Body, FastAPI, HTTPException, Query
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field, ValidationError

from env.environment import FlakySleuthEnv
from env.models import FlakySleuthAction, FlakySleuthObservation
from server.inference_runner import InferenceRunner
from server.ui import render_home_page

# ASGI application object; the route handlers below are registered on it.
app = FastAPI(title="FlakySleuth Environment")
# One environment instance, created at import time. The /reset, /step and
# /state handlers all operate on this single shared object.
env = FlakySleuthEnv()
# Runner used by the /web/inference/* handlers. It is constructed with the
# parent of the directory containing this file (presumably the project root —
# confirm against InferenceRunner).
inference_runner = InferenceRunner(Path(__file__).resolve().parent.parent)


class FlakySleuthState(BaseModel):
    # Shape of the environment state as advertised by the /schema endpoint
    # (published there under the "state" key).
    # NOTE(review): the /state handler returns `env.state()` directly, so
    # nothing in this file validates the live state against this model.
    # Documented with comments rather than a docstring because a class
    # docstring would be emitted as `description` in the JSON schema.

    # Optional string fields; each defaults to None when not supplied.
    repo_url: str | None = None  # presumably the repository under investigation — confirm
    test_name: str | None = None  # presumably the flaky test being examined — confirm
    task_type: str | None = None  # presumably one of the task kinds the env serves — confirm

    # Required fields (no default).
    step_count: int
    files_read: list[str]
    cumulative_progress: float


class InferenceRunRequest(BaseModel):
    # Request body for `start_inference`: the validated fields are dumped to a
    # plain dict and handed to `inference_runner.start`.

    # `model_name` begins with Pydantic's reserved `model_` prefix, which makes
    # some Pydantic v2 releases emit a protected-namespace UserWarning when the
    # class is created. Clearing the protected namespaces silences that warning
    # without renaming the public field (the request schema is unchanged).
    model_config = {"protected_namespaces": ()}

    # Run configuration with defaults; the integer fields are bounded to 1..100.
    dataset_path: str = Field(default="dataset/py_tasks.csv")
    episodes_per_task: int = Field(default=1, ge=1, le=100)
    # Single comma-separated string, not a list.
    task_types: str = Field(default="classify,root_cause,fix_proposal")
    max_steps: int = Field(default=20, ge=1, le=100)
    benchmark_name: str = Field(default="flakysleuth")

    # Optional overrides; None when the client does not send them.
    # NOTE(review): presumably these select the model endpoint and credentials
    # used by the runner — confirm against InferenceRunner.
    api_base_url: str | None = None
    model_name: str | None = None
    api_key: str | None = None


@app.post("/reset")
def reset() -> dict[str, Any]:
    # Reset the shared environment and return its first observation.
    # The reply always carries reward=None and done=False.
    # (Comment, not docstring: a docstring would change the published
    # OpenAPI description of this route.)
    first_observation = env.reset()
    return dict(
        observation=first_observation.model_dump(),
        reward=None,
        done=False,
    )


@app.post("/step")
def step(payload: dict[str, Any] = Body(...)) -> dict[str, Any]:
    """Accept either {'action': {...}} or direct action payload."""
    # Use the nested "action" value when present, otherwise treat the whole
    # body as the action.
    raw_action = payload.get("action", payload)
    try:
        parsed_action = FlakySleuthAction.model_validate(raw_action)
    except ValidationError as exc:
        # Malformed action -> 422 carrying pydantic's error list.
        raise HTTPException(status_code=422, detail=exc.errors()) from exc

    try:
        observation, reward, done, info = env.step(parsed_action)
    except RuntimeError as exc:
        # The environment refused the step -> 400 with its message.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return dict(
        observation=observation.model_dump(),
        reward=reward,
        done=done,
        info=info,
    )


@app.get("/state")
def state() -> dict[str, Any]:
    # Return whatever the shared environment reports from `env.state()`,
    # unmodified. (Comment, not docstring: a docstring would change the
    # published OpenAPI description of this route.)
    current_state = env.state()
    return current_state


@app.get("/schema")
def schema() -> dict[str, Any]:
    # Publish the JSON schemas of the action, observation and state models,
    # keyed by those three labels in that order.
    # (Comment, not docstring: a docstring would change the published
    # OpenAPI description of this route.)
    labelled_models = (
        ("action", FlakySleuthAction),
        ("observation", FlakySleuthObservation),
        ("state", FlakySleuthState),
    )
    return {label: model.model_json_schema() for label, model in labelled_models}


@app.get("/health")
def health() -> dict[str, str]:
    # Liveness probe: always answers with a fixed "healthy" status.
    # (Comment, not docstring: a docstring would change the published
    # OpenAPI description of this route.)
    return dict(status="healthy")


@app.get("/", include_in_schema=False)
def root() -> HTMLResponse:
    """Serve the rendered home page at the site root."""
    page_html = render_home_page()
    return HTMLResponse(page_html)


@app.get("/web", include_in_schema=False)
def web() -> HTMLResponse:
    """Serve the same rendered home page under ``/web``."""
    page_html = render_home_page()
    return HTMLResponse(page_html)


@app.post("/web/inference/start", include_in_schema=False)
def start_inference(payload: InferenceRunRequest) -> dict[str, Any]:
    """Start an inference run from the validated request and return the runner's reply.

    Errors raised by ``inference_runner.start`` are translated to HTTP errors:
    ``FileNotFoundError`` -> 404, ``ValueError`` -> 422, ``RuntimeError`` -> 409.
    The exception message becomes the response detail.
    """
    run_config = payload.model_dump()
    # Checked in this order, so the first matching type decides the status code.
    status_for_error = (
        (FileNotFoundError, 404),
        (ValueError, 422),
        (RuntimeError, 409),
    )
    try:
        return inference_runner.start(run_config)
    except (FileNotFoundError, ValueError, RuntimeError) as exc:
        status_code = next(
            code for error_type, code in status_for_error if isinstance(exc, error_type)
        )
        raise HTTPException(status_code=status_code, detail=str(exc)) from exc


@app.get("/web/inference/status", include_in_schema=False)
def inference_status(tail: int = Query(default=450, ge=20, le=2000)) -> dict[str, Any]:
    """Return the runner's snapshot for the requested ``tail`` value.

    ``tail`` is a query parameter constrained to 20..2000 (default 450) and is
    forwarded unchanged to ``inference_runner.snapshot``.
    """
    status_snapshot = inference_runner.snapshot(tail=tail)
    return status_snapshot


@app.post("/web/inference/stop", include_in_schema=False)
def stop_inference() -> dict[str, Any]:
    """Ask the runner to stop, then return a fresh snapshot.

    The value returned by ``inference_runner.stop()`` is added to the snapshot
    under the ``"stopped"`` key. The snapshot is taken with ``tail=450``.
    """
    stop_result = inference_runner.stop()
    report = inference_runner.snapshot(tail=450)
    report.update(stopped=stop_result)
    return report


@app.get("/metadata")
def metadata() -> dict[str, str]:
    # Static name/description pair describing this environment.
    # (Comment, not docstring: a docstring would change the published
    # OpenAPI description of this route.)
    summary = "RL environment for flaky-test investigation in Python repositories."
    return {"name": "FlakySleuth Environment", "description": summary}


@app.post("/mcp")
def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    # Reply with a JSON-RPC 2.0 envelope that echoes the request "id" (None
    # when absent) and a fixed {"status": "ok"} result. No other part of the
    # payload is inspected.
    # (Comment, not docstring: a docstring would change the published
    # OpenAPI description of this route.)
    ok_result = dict(status="ok")
    return dict(jsonrpc="2.0", id=payload.get("id"), result=ok_result)


def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Run ``app`` under uvicorn, listening on ``host`` and ``port``.

    Args:
        host: Interface address passed to uvicorn (default ``"0.0.0.0"``).
        port: TCP port passed to uvicorn (default ``8000``).
    """
    # uvicorn is imported only when main() is actually called.
    from uvicorn import run as serve

    serve(app, host=host, port=port)


# Script entry point: start the server with the default host and port.
if __name__ == "__main__":
    main()