Spaces:
Running
Running
| """ | |
| FastAPI route definitions for CodeSentry Backend. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| from typing import Any, AsyncGenerator | |
| import httpx | |
| from fastapi import APIRouter, HTTPException, Request | |
| from fastapi.responses import JSONResponse | |
| from sse_starlette.sse import EventSourceResponse | |
| from agents.orchestrator import Orchestrator | |
| from api.models import AnalyzeRequest, HealthResponse, PrivacyCertificate, AMDMetricsSnapshot | |
| from amd_metrics import AMDMetricsCollector | |
| from memory.session_store import get_store | |
| logger = logging.getLogger(__name__) | |
| router = APIRouter() | |
| VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://localhost:8080") | |
| MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-Coder-32B-Instruct") | |
# Shared orchestrator instance (lazily initialised)
_orchestrator: Orchestrator | None = None


def get_orchestrator() -> Orchestrator:
    """Return the process-wide Orchestrator, constructing it on first use."""
    global _orchestrator
    if _orchestrator is not None:
        return _orchestrator
    _orchestrator = Orchestrator()
    return _orchestrator
# Shared AMD metrics collector for the health endpoint
_amd_collector: AMDMetricsCollector | None = None


def get_amd_collector() -> AMDMetricsCollector:
    """Return the shared AMDMetricsCollector, constructing it on first use."""
    global _amd_collector
    if _amd_collector is not None:
        return _amd_collector
    _amd_collector = AMDMetricsCollector()
    return _amd_collector
# ──────────────────────────────────────────
# Health
# ──────────────────────────────────────────
async def _probe_vllm_ready() -> bool:
    """Best-effort readiness probe of the vLLM server; never raises.

    Tries ``/health`` first, then ``/v1/models`` — the server is considered
    ready if either endpoint answers 200 within the timeout.
    """
    for path in ("/health", "/v1/models"):
        try:
            async with httpx.AsyncClient(timeout=3.0) as client:
                resp = await client.get(f"{VLLM_BASE_URL}{path}")
            if resp.status_code == 200:
                return True
        except Exception:
            # Server down / unreachable — try the next probe.
            continue
    return False


def _read_gpu_memory_free_gb() -> float | None:
    """Best-effort free-VRAM readout in GiB, or None if unavailable.

    Prefers ``rocm-smi`` (AMD / ROCm); falls back to ``torch.cuda`` when
    rocm-smi is absent, fails, or reports no card data.
    """
    import subprocess  # local import: only needed on this best-effort path

    try:
        result = subprocess.run(
            ["rocm-smi", "--showmeminfo", "vram", "--json"],
            capture_output=True, text=True, timeout=5,
        )
        if result.returncode == 0:
            data = json.loads(result.stdout)
            # Parse first GPU's free VRAM (free = total - used).
            for card_data in data.values():
                if isinstance(card_data, dict):
                    total_bytes = card_data.get("VRAM Total Memory (B)", 0)
                    used_bytes = card_data.get("VRAM Total Used Memory (B)", 0)
                    return round((total_bytes - used_bytes) / (1024 ** 3), 1)
    except Exception:
        pass
    # Fallback for non-AMD systems (and ROCm builds of torch). Unlike the
    # original, this also runs when rocm-smi exists but returned no usable
    # data, not only when invoking it raised.
    try:
        import torch
        if torch.cuda.is_available():
            free, _total = torch.cuda.mem_get_info()
            return round(free / (1024 ** 3), 1)
    except Exception:
        pass
    return None


async def _read_amd_snapshot() -> AMDMetricsSnapshot | None:
    """Collect AMD hardware metrics, or None when collection fails."""
    try:
        collector = get_amd_collector()
        metrics = await collector.collect()
        return AMDMetricsSnapshot(**metrics)
    except Exception:
        return None


async def health_check() -> HealthResponse:
    """
    Returns vLLM readiness and available GPU memory.
    Works even if vLLM is not running (vllm_ready=false).
    """
    vllm_ready = await _probe_vllm_ready()
    gpu_memory_free_gb = _read_gpu_memory_free_gb()
    amd_hw = await _read_amd_snapshot()
    return HealthResponse(
        status="ok",
        model=MODEL_NAME,
        vllm_ready=vllm_ready,
        gpu_memory_free_gb=gpu_memory_free_gb,
        vllm_endpoint=VLLM_BASE_URL,
        amd_hardware=amd_hw,
    )
# ──────────────────────────────────────────
# Main analysis endpoint (SSE streaming)
# ──────────────────────────────────────────
async def create_scan(request: AnalyzeRequest) -> JSONResponse:
    """Create a new scan session."""
    payload = {
        "source": request.source,
        "source_type": request.source_type.value,
    }
    await get_store().create(request.session_id, payload)
    return JSONResponse(content={"scanId": request.session_id})
async def scan_stream(scan_id: str) -> EventSourceResponse:
    """Stream the analysis results using SSE."""
    session = await get_store().get(scan_id)
    if not session:
        raise HTTPException(status_code=404, detail="Scan session not found")

    orchestrator = get_orchestrator()

    async def event_generator() -> AsyncGenerator[dict, None]:
        # Translate orchestrator events into SSE frames; any unhandled
        # failure is surfaced to the client as a terminal "error" event.
        try:
            stream = orchestrator.run_stream(
                source=session.get("source"),
                source_type=session.get("source_type"),
                session_id=scan_id,
            )
            async for event in stream:
                yield {
                    "event": event["event"],
                    "data": json.dumps(event["data"], default=str),
                }
        except Exception as exc:
            logger.error("[Routes] Unhandled error in analysis stream: %s", exc, exc_info=True)
            yield {"event": "error", "data": json.dumps({"message": str(exc)})}

    return EventSourceResponse(event_generator())
# ──────────────────────────────────────────
# Demo endpoint (no GPU required)
# ──────────────────────────────────────────
async def analyze_demo() -> JSONResponse:
    """
    Returns a pre-computed analysis result using the vulnerable_ml_code fixture.
    No vLLM / GPU required — safe for CI and frontend development.
    """
    try:
        demo_result = await get_orchestrator().run_demo(session_id="demo-session")
    except Exception as exc:
        logger.error("[Routes] Demo endpoint error: %s", exc, exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
    return JSONResponse(content=demo_result.model_dump(mode="json"))
# ──────────────────────────────────────────
# Session retrieval
# ──────────────────────────────────────────
async def get_session(session_id: str) -> JSONResponse:
    """
    Retrieve the full analysis result for a completed session.
    Returns 404 if session not found or expired.
    """
    session = await get_store().get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found or expired.")
    result = session.get("result")
    if result is not None:
        return JSONResponse(content=result)
    # No result yet — report the session's current status instead.
    return JSONResponse(content={
        "session_id": session_id,
        "status": session.get("_status", "pending"),
    })
# ──────────────────────────────────────────
# Privacy certificate
# ──────────────────────────────────────────
async def get_privacy_certificate(session_id: str) -> JSONResponse:
    """
    Return the Zero Data Retention audit certificate for a completed session.

    Raises:
        HTTPException: 404 when the session is unknown, or when the
            certificate has not been generated yet.
    """
    store = get_store()
    session = await store.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found.")
    # "result" may be stored as an explicit None while the scan is still
    # pending (see get_session); `or {}` covers that case as well as a
    # missing key, where .get("result", {}) would only cover the latter
    # and a None result would crash with AttributeError (HTTP 500).
    result = session.get("result") or {}
    cert = result.get("privacy_certificate")
    if cert is None:
        raise HTTPException(status_code=404, detail="Privacy certificate not yet generated for this session.")
    return JSONResponse(content=cert)
# ──────────────────────────────────────────
# Session list (debug / admin)
# ──────────────────────────────────────────
async def list_sessions() -> JSONResponse:
    """List all active session IDs (debug endpoint)."""
    store = get_store()
    payload = {
        "active_sessions": await store.list_sessions(),
        "count": await store.count(),
    }
    return JSONResponse(content=payload)