File size: 3,439 Bytes
95cbc5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import annotations

import logging
import os
import sys
from pathlib import Path
from typing import Any

# Unbuffered diagnostic output used while debugging startup on HF.
def print_now(msg: str):
    """Emit ``DEBUG: <msg>`` on stdout and flush immediately.

    The explicit flush makes the line visible right away even when stdout is
    block-buffered (for example when output is piped to a log collector).
    """
    stream = sys.stdout
    line = f"DEBUG: {msg}\n"
    stream.write(line)
    stream.flush()

# Startup breadcrumbs: each print_now below marks an import stage that has
# completed, so a hang or crash during import can be localized from the log.
print_now("Server process started, beginning imports...")

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dataclasses import asdict
from pydantic import BaseModel

print_now("FastAPI imported.")

# Package-relative imports: this module only works when imported as part of
# its package (e.g. via ``python -m``), not when run as a bare script file.
from .environment import CommitGuardEnvironment
from .parse_action import action_from_json, parse_action

print_now("Local modules imported.")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Dataset location. A non-empty COMMITGUARD_DATA_PATH environment variable
# takes precedence; otherwise fall back to data/devign_filtered.jsonl in the
# directory that contains this package (two levels up from this file), which
# is intended to match the Docker layout /app/data/...
DATA_PATH_STR = os.environ.get("COMMITGUARD_DATA_PATH", "")
DATA_PATH = (
    Path(DATA_PATH_STR)
    if DATA_PATH_STR
    else Path(__file__).resolve().parent.parent / "data" / "devign_filtered.jsonl"
)

print_now(f"DATA_PATH resolved to: {DATA_PATH}")

app = FastAPI(version="0.1.0", title="CommitGuard Env Server")

# Fully permissive CORS (any origin, method and header) with credentials
# disabled.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# One environment instance shared by the /reset, /step and /state handlers;
# its data is loaded via env.load() in the startup hook.
env = CommitGuardEnvironment(data_path=DATA_PATH)

@app.on_event("startup")
def startup_event() -> None:
    """Load the dataset into ``env`` when the server starts.

    Loading is deliberately best-effort: any exception is logged (now with
    its full traceback) and printed, but not re-raised, so the server still
    starts even when the dataset is missing or cannot be loaded.
    """
    print_now("FastAPI startup event triggered.")
    logger.info("Loading data from %s...", DATA_PATH)
    try:
        if not DATA_PATH.exists():
            # Warn loudly, but still attempt the load so its error is logged too.
            print_now(f"CRITICAL: Data path {DATA_PATH} DOES NOT EXIST")
        env.load()
        n_samples = len(env._samples)
        logger.info("Successfully loaded %d samples.", n_samples)
        print_now(f"Loaded {n_samples} samples.")
    except Exception as e:
        # logger.exception records the traceback; logger.error(str(e)) lost it,
        # which is exactly the information needed to diagnose a failed load.
        logger.exception("FAILED to load data: %s", e)
        print_now(f"ERROR during load: {e}")

class StepRequest(BaseModel):
    """JSON body for ``POST /step``.

    Either send ``action`` as a single string (handed to ``parse_action``), or
    leave it unset and send the structured fields, whose non-None values are
    handed to ``action_from_json`` as a dict.
    """

    # Raw action string; when not None it takes precedence over the fields below.
    action: str | None = None
    # Structured action fields, consulted only when ``action`` is None.
    action_type: str | None = None
    file_path: str | None = None
    reasoning: str | None = None
    is_vulnerable: bool | None = None
    vuln_type: str | None = None
    exploit_sketch: str | None = None
    # Forwarded to env.step(..., episode_id=...). NOTE(review): when set it is
    # also included in the dict passed to action_from_json — confirm that is
    # tolerated there.
    episode_id: str | None = None


@app.get("/health")
def health() -> dict[str, str]:
    """Liveness check: unconditionally reports ``{"status": "healthy"}``."""
    return dict(status="healthy")


class ResetRequest(BaseModel):
    """Optional JSON body for ``POST /reset``."""

    # Forwarded unchanged to env.reset(sample_id=...); None is passed through as-is.
    sample_id: str | None = None

@app.post("/reset")
def reset(req: ResetRequest = ResetRequest()) -> dict[str, Any]:
    """Reset the environment and return its initial observation.

    ``req.sample_id`` is forwarded to ``env.reset``. A ``ValueError`` raised
    while resetting is reported in the response body as ``{"error": <message>}``
    instead of being propagated.
    """
    try:
        first_obs = env.reset(sample_id=req.sample_id)
        payload: dict[str, Any] = {
            "observation": asdict(first_obs),
            "done": False,
            "reward": 0.0,
        }
    except ValueError as exc:
        payload = {"error": str(exc)}
    return payload


@app.post("/step")
def step(req: StepRequest) -> dict[str, Any]:
    """Apply one action to the environment and return the transition.

    The action is parsed from ``req.action`` when that string is present;
    otherwise it is built by ``action_from_json`` from the request's non-None
    fields. Any parse error is surfaced under ``info.parse_error``.
    """
    action = (
        parse_action(req.action)
        if req.action is not None
        else action_from_json(req.model_dump(exclude_none=True))
    )
    next_obs, reward, done = env.step(action, episode_id=req.episode_id)
    return {
        "observation": asdict(next_obs),
        "done": done,
        "reward": reward,
        "info": {"parse_error": action.parse_error},
    }


@app.get("/state")
def state(episode_id: str | None = None) -> dict[str, Any]:
    """Return ``env.state(episode_id=...)``, converted with ``asdict``, under the ``"state"`` key."""
    return {"state": asdict(env.state(episode_id=episode_id))}


def main() -> None:
    """Serve ``app`` with uvicorn on host 0.0.0.0, port ``$PORT`` (default 8000).

    The application object is passed directly instead of the import string
    ``"commitguard_env.server:app"``. When this module runs as ``__main__``
    (e.g. ``python -m commitguard_env.server``), the import string makes
    uvicorn import the module a second time under its package name. That
    repeats every module-level side effect (diagnostic prints,
    ``CommitGuardEnvironment`` construction) and serves a different ``app``
    than the one defined here. Passing the object is valid because reload is
    off and only a single worker is used.
    """
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)


if __name__ == "__main__":
    main()