Spaces:
Running
Running
File size: 9,177 Bytes
"""
FastAPI route definitions for CodeSentry Backend.
"""
from __future__ import annotations
import json
import logging
import os
from typing import Any, AsyncGenerator
import httpx
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from agents.orchestrator import Orchestrator
from api.models import AnalyzeRequest, HealthResponse, PrivacyCertificate, AMDMetricsSnapshot
from amd_metrics import AMDMetricsCollector
from memory.session_store import get_store
logger = logging.getLogger(__name__)
router = APIRouter()
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://localhost:8080")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-Coder-32B-Instruct")
# Process-wide Orchestrator, created lazily on first request.
_orchestrator: Orchestrator | None = None


def get_orchestrator() -> Orchestrator:
    """Return the shared :class:`Orchestrator`, creating it on first use."""
    global _orchestrator
    if _orchestrator is not None:
        return _orchestrator
    _orchestrator = Orchestrator()
    return _orchestrator
# Process-wide AMD metrics collector used by the health endpoint.
_amd_collector: AMDMetricsCollector | None = None


def get_amd_collector() -> AMDMetricsCollector:
    """Return the shared :class:`AMDMetricsCollector`, creating it on first use."""
    global _amd_collector
    if _amd_collector is not None:
        return _amd_collector
    _amd_collector = AMDMetricsCollector()
    return _amd_collector
# ──────────────────────────────────────────
# Health
# ──────────────────────────────────────────
@router.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check() -> HealthResponse:
    """
    Report service health: vLLM readiness, free GPU memory, AMD metrics.

    Every probe is best-effort — this endpoint always answers 200 OK, with
    ``vllm_ready=False`` / null fields when a probe fails, so it is safe to
    call even when vLLM is down or no GPU is present.
    """
    vllm_ready = await _probe_vllm_ready()
    gpu_memory_free_gb = _read_gpu_memory_free_gb()
    amd_hw = await _collect_amd_snapshot()
    return HealthResponse(
        status="ok",
        model=MODEL_NAME,
        vllm_ready=vllm_ready,
        gpu_memory_free_gb=gpu_memory_free_gb,
        vllm_endpoint=VLLM_BASE_URL,
        amd_hardware=amd_hw,
    )


async def _probe_vllm_ready() -> bool:
    """Return True if vLLM answers 200 on ``/health`` or ``/v1/models``."""
    for path in ("/health", "/v1/models"):
        try:
            async with httpx.AsyncClient(timeout=3.0) as client:
                resp = await client.get(f"{VLLM_BASE_URL}{path}")
            if resp.status_code == 200:
                return True
        except Exception:
            # Endpoint unreachable — try the next probe.
            continue
    return False


def _read_gpu_memory_free_gb() -> float | None:
    """
    Best-effort free-VRAM reading in GiB: rocm-smi first, then torch/CUDA.

    Returns None when neither source is available (non-GPU hosts, CI).
    """
    try:
        import subprocess
        result = subprocess.run(
            ["rocm-smi", "--showmeminfo", "vram", "--json"],
            capture_output=True, text=True, timeout=5,
        )
        if result.returncode == 0:
            data = json.loads(result.stdout)
            # Report the first GPU found in rocm-smi's per-card JSON.
            for card_data in data.values():
                if isinstance(card_data, dict):
                    # rocm-smi reports total and used; free = total - used.
                    # (The original code misnamed the total as "free_bytes".)
                    total_bytes = card_data.get("VRAM Total Memory (B)", 0)
                    used_bytes = card_data.get("VRAM Total Used Memory (B)", 0)
                    return round((total_bytes - used_bytes) / (1024 ** 3), 1)
    except Exception:
        # rocm-smi missing or unparsable (non-AMD / non-Linux) — fall through.
        pass
    # CUDA fallback — also reached when rocm-smi ran but produced no usable card.
    try:
        import torch
        if torch.cuda.is_available():
            free, _total = torch.cuda.mem_get_info()
            return round(free / (1024 ** 3), 1)
    except Exception:
        pass
    return None


async def _collect_amd_snapshot() -> AMDMetricsSnapshot | None:
    """Collect AMD hardware metrics; None when the collector is unavailable."""
    try:
        collector = get_amd_collector()
        metrics = await collector.collect()
        return AMDMetricsSnapshot(**metrics)
    except Exception:
        return None
# ──────────────────────────────────────────
# Main analysis endpoint (SSE streaming)
# ──────────────────────────────────────────
@router.post("/scan", tags=["Analysis"])
async def create_scan(request: AnalyzeRequest) -> JSONResponse:
    """Register a new scan session and echo back its identifier."""
    payload = {
        "source": request.source,
        "source_type": request.source_type.value,
    }
    await get_store().create(request.session_id, payload)
    return JSONResponse(content={"scanId": request.session_id})
@router.get("/scan/stream/{scan_id}", tags=["Analysis"])
async def scan_stream(scan_id: str) -> EventSourceResponse:
    """Stream analysis progress for *scan_id* as server-sent events."""
    store = get_store()
    session = await store.get(scan_id)
    if not session:
        raise HTTPException(status_code=404, detail="Scan session not found")

    orchestrator = get_orchestrator()

    async def event_generator() -> AsyncGenerator[dict, None]:
        # Forward orchestrator events; any failure becomes an SSE "error" event
        # so the client sees a terminal message instead of a dropped connection.
        try:
            stream = orchestrator.run_stream(
                source=session.get("source"),
                source_type=session.get("source_type"),
                session_id=scan_id,
            )
            async for event in stream:
                yield {
                    "event": event["event"],
                    "data": json.dumps(event["data"], default=str),
                }
        except Exception as exc:
            logger.error("[Routes] Unhandled error in analysis stream: %s", exc, exc_info=True)
            yield {
                "event": "error",
                "data": json.dumps({"message": str(exc)}),
            }

    return EventSourceResponse(event_generator())
# ──────────────────────────────────────────
# Demo endpoint (no GPU required)
# ──────────────────────────────────────────
@router.post("/analyze/demo", tags=["Analysis"])
async def analyze_demo() -> JSONResponse:
    """
    Serve a pre-computed analysis built from the vulnerable_ml_code fixture.

    Needs no vLLM / GPU, so it is safe for CI and frontend development.
    """
    try:
        result = await get_orchestrator().run_demo(session_id="demo-session")
    except Exception as exc:
        logger.error("[Routes] Demo endpoint error: %s", exc, exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
    return JSONResponse(content=result.model_dump(mode="json"))
# ──────────────────────────────────────────
# Session retrieval
# ──────────────────────────────────────────
@router.get("/session/{session_id}", tags=["Session"])
async def get_session(session_id: str) -> JSONResponse:
    """
    Retrieve the full analysis result for a completed session.

    Returns 404 if the session is unknown or has expired.
    """
    session = await get_store().get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found or expired.")
    result = session.get("result")
    if result is not None:
        return JSONResponse(content=result)
    # No result yet — report the session's current status instead.
    status = session.get("_status", "pending")
    return JSONResponse(content={"session_id": session_id, "status": status})
# ──────────────────────────────────────────
# Privacy certificate
# ──────────────────────────────────────────
@router.get("/privacy-certificate/{session_id}", tags=["Privacy"])
async def get_privacy_certificate(session_id: str) -> JSONResponse:
    """
    Return the Zero Data Retention audit certificate for a completed session.

    Raises 404 when the session is unknown, or when the certificate has not
    been generated yet.
    """
    store = get_store()
    session = await store.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found.")
    # "result" may be stored as an explicit None before the run finishes;
    # dict.get's default only covers a *missing* key, so `or {}` is needed
    # to avoid an AttributeError on `result.get(...)` below.
    result = session.get("result") or {}
    cert = result.get("privacy_certificate")
    if cert is None:
        raise HTTPException(status_code=404, detail="Privacy certificate not yet generated for this session.")
    return JSONResponse(content=cert)
# ──────────────────────────────────────────
# Session list (debug / admin)
# ──────────────────────────────────────────
@router.get("/sessions", tags=["Session"], include_in_schema=False)
async def list_sessions() -> JSONResponse:
    """List all active session IDs (hidden debug endpoint)."""
    store = get_store()
    payload = {
        "active_sessions": await store.list_sessions(),
        "count": await store.count(),
    }
    return JSONResponse(content=payload)
|