File size: 2,295 Bytes
e4f3d12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from __future__ import annotations

from pathlib import Path
from typing import Any

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dataclasses import asdict
from pydantic import BaseModel

from .environment import CommitGuardEnvironment
from .parse_action import action_from_json, parse_action


# Dataset location, resolved relative to this file (two directories up, then
# data/) so it does not depend on the process working directory.
DATA_PATH = Path(__file__).resolve().parent.parent.joinpath("data", "devign_filtered.jsonl")

app = FastAPI(title="CommitGuard Env Server", version="0.1.0")

# Permissive CORS: any origin, method and header is accepted, but browsers are
# not allowed to send credentials with cross-origin requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
)

# One process-wide environment instance, shared by the request handlers in
# this module.
env = CommitGuardEnvironment(data_path=DATA_PATH)


class StepRequest(BaseModel):
    """Request body for ``POST /step``.

    Either send ``action`` as raw XML text, or send the structured fields
    below (curl-friendly). All fields are optional.
    """

    # Raw action text (XML); when set, it takes precedence over the structured fields.
    action: str | None = None
    # Structured alternative to `action`; only the fields that are not None are
    # forwarded when building the action.
    action_type: str | None = None
    file_path: str | None = None
    reasoning: str | None = None
    is_vulnerable: bool | None = None
    vuln_type: str | None = None
    exploit_sketch: str | None = None


@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe: answer with a fixed "healthy" status payload."""
    return dict(status="healthy")


class ResetRequest(BaseModel):
    """Request body for ``POST /reset``; the whole body may be omitted."""

    # Passed straight to env.reset(sample_id=...). How a sample is chosen when
    # this is None is up to the environment (TODO confirm in CommitGuardEnvironment).
    sample_id: str | None = None

@app.post("/reset")
def reset(req: ResetRequest = ResetRequest()) -> dict[str, Any]:
    """Start a new episode and return its initial observation.

    The request body is optional; without it ``sample_id`` is ``None`` and the
    choice of sample is left to ``env.reset``.

    Returns:
        ``{"observation": ..., "done": False, "reward": 0.0}`` on success, or
        ``{"error": <message>}`` when ``env.reset`` raises ``ValueError``.
        NOTE(review): the error payload is still sent with HTTP 200, so clients
        must check for the ``"error"`` key rather than the status code.
    """
    try:
        obs = env.reset(sample_id=req.sample_id)
    except ValueError as e:
        # Only the environment call is guarded; a failure while serializing the
        # observation is a server bug and should not be reported as a bad request.
        return {"error": str(e)}
    return {
        "observation": asdict(obs),
        "done": False,
        "reward": 0.0,
    }


@app.post("/step")
def step(req: StepRequest) -> dict[str, Any]:
    """Apply one agent action to the shared environment.

    The action is parsed from ``req.action`` (raw text) when it is provided;
    otherwise it is built from the request's non-None structured fields.

    Returns the resulting observation, the ``done`` flag, the reward, and an
    ``info`` dict carrying the action's ``parse_error``.
    """
    raw_text = req.action
    action = (
        action_from_json(req.model_dump(exclude_none=True))
        if raw_text is None
        else parse_action(raw_text)
    )
    observation, reward, done = env.step(action)
    observation_dict = asdict(observation)
    return {
        "observation": observation_dict,
        "done": done,
        "reward": reward,
        "info": {"parse_error": action.parse_error},
    }


@app.get("/state")
def state() -> dict[str, Any]:
    """Return the environment's current state, converted to a plain dict."""
    return {"state": asdict(env.state())}


def main(*, host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port to listen on.

    The ``app`` object is handed to uvicorn directly rather than as the import
    string ``"commitguard_env.server:app"``. With ``reload=False`` an import
    string is not needed. When this module runs as ``__main__``
    (``python -m ...``), the string would make uvicorn import the module a
    second time, building a second ``env`` and reloading the dataset.
    """
    uvicorn.run(app, host=host, port=port, reload=False)


if __name__ == "__main__":
    # Start the server when the module is executed directly, e.g. via
    # `python -m <package>.server` (the package-relative imports above mean it
    # must be run as part of its package).
    main()