Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Any | |
| import uvicorn | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from dataclasses import asdict | |
| from pydantic import BaseModel | |
| from .environment import CommitGuardEnvironment | |
| from .parse_action import action_from_json, parse_action | |
# Dataset file read by the environment: <this package's parent dir>/data/devign_filtered.jsonl.
DATA_PATH = Path(__file__).resolve().parents[1] / "data" / "devign_filtered.jsonl"

app = FastAPI(title="CommitGuard Env Server", version="0.1.0")

# Permissive CORS: any origin, method and header is accepted; credentials are not allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
)

# One environment instance, built at import time and used by every handler in this module.
# NOTE(review): since it is a single shared object, concurrent clients would share episode
# state. Also, no route registration (decorators / add_api_route) is visible for the
# handlers in this view of the file -- confirm they are attached to ``app``.
env = CommitGuardEnvironment(data_path=DATA_PATH)
class StepRequest(BaseModel):
    """Request body for a step call.

    The action can be supplied in one of two ways:

    * ``action`` -- the action as raw XML text; when present, the step handler
      parses this and ignores the structured fields, or
    * the structured fields below, which are convenient to send from curl.
    """

    action: str | None = None

    # Structured alternative to ``action``; every field is optional.
    action_type: str | None = None
    file_path: str | None = None
    reasoning: str | None = None
    is_vulnerable: bool | None = None
    vuln_type: str | None = None
    exploit_sketch: str | None = None
def health() -> dict[str, str]:
    """Return a constant payload signalling that the server process is up."""
    status = "healthy"
    return {"status": status}
class ResetRequest(BaseModel):
    """Request body for a reset call.

    ``sample_id`` is forwarded unchanged to ``CommitGuardEnvironment.reset``;
    leaving it as ``None`` lets the environment choose the sample.
    """

    sample_id: str | None = None
def reset(req: ResetRequest = ResetRequest()) -> dict[str, Any]:
    """Reset the shared environment and return its first observation.

    Args:
        req: Optional body; ``req.sample_id`` is passed to ``env.reset``.

    Returns:
        ``{"observation": <dict>, "done": False, "reward": 0.0}`` on success, or
        ``{"error": <message>}`` when resetting raises ``ValueError``.
    """
    try:
        first_obs = env.reset(sample_id=req.sample_id)
        payload: dict[str, Any] = {
            "observation": asdict(first_obs),
            "done": False,
            "reward": 0.0,
        }
    except ValueError as exc:
        # Returned (not raised), so the client receives the message in the response body.
        return {"error": str(exc)}
    return payload
def step(req: StepRequest) -> dict[str, Any]:
    """Apply one agent action to the shared environment.

    If ``req.action`` is set, that raw text is handed to ``parse_action``;
    otherwise the non-None fields of ``req`` are handed to ``action_from_json``.

    Returns:
        Dict with the new ``observation`` (converted via ``asdict``), ``done``,
        ``reward``, and ``info.parse_error`` taken from the built action.
    """
    raw_text = req.action
    action = (
        action_from_json(req.model_dump(exclude_none=True))
        if raw_text is None
        else parse_action(raw_text)
    )
    next_obs, reward, done = env.step(action)
    return {
        "observation": asdict(next_obs),
        "done": done,
        "reward": reward,
        "info": {"parse_error": action.parse_error},
    }
def state() -> dict[str, Any]:
    """Return the environment's current state, converted with ``asdict``, under ``"state"``."""
    return {"state": asdict(env.state())}
def main() -> None:
    """Run the API server with uvicorn.

    The port comes from the ``PORT`` environment variable (default 8000), and
    the server binds to all interfaces (``0.0.0.0``) with auto-reload disabled.
    """
    import os

    port = int(os.environ.get("PORT", 8000))
    # Pass the already-built ``app`` object instead of the import string
    # "commitguard_env.server:app". An import string is only required for
    # reload/multi-worker mode. Under ``python -m`` this module runs as
    # ``__main__``, so the string would import it a second time, creating a
    # second FastAPI app and a second CommitGuardEnvironment (reloading the
    # dataset). It would also break if the package were renamed.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)
# Script entry point: start the server when the module is executed directly (e.g. via
# ``python -m``; the relative imports at the top need package context).
if __name__ == "__main__":
    main()