nl-sql / src /nl_sql /api /main.py
liovina's picture
Deploy NL_SQL HEAD to HF Space
424ea19 verified
"""FastAPI surface for the NL→SQL Assistant.
Endpoints:
GET /healthz — liveness probe + provider configuration snapshot.
GET /readyz — readiness probe (Chroma + DB registry reachable).
GET /databases — list registered DBs + table counts.
POST /ask — translate question to SQL, execute, return result.
GET /eval/latest — metadata of the latest committed eval report.
Auth:
Set ``NL_SQL_API_KEY`` (env, .env, or settings). When set, every request
to /ask and /databases must include ``X-API-Key`` matching it. /healthz
and /readyz are always open for orchestrator probes.
Rate limit:
In-process sliding-window limiter per API key (60 req/min default), enforced
only when ``NL_SQL_API_KEY`` is set (unauthenticated mode is not limited).
No external Redis — this is a single-replica portfolio demo, not a fleet service.
"""
from __future__ import annotations

import hmac
import json
import logging
import math
import os
import time
import uuid
from collections import defaultdict, deque
from functools import lru_cache
from pathlib import Path
from typing import Any, NamedTuple

from fastapi import Depends, FastAPI, Header, HTTPException, Request, status
from pydantic import BaseModel, Field

from nl_sql import __version__
from nl_sql.agent.graph import (
    PipelineConfig,
    PipelineRunResult,
    build_pipeline,
    run_pipeline,
)
from nl_sql.config import Settings, get_settings
from nl_sql.db.registry import DatabaseRegistry, get_default_registry
from nl_sql.llm.cache import CachingEmbeddingProvider, CachingLLMProvider
from nl_sql.llm.providers import build_provider
from nl_sql.llm.providers.base import EmbeddingProvider, LLMProvider
from nl_sql.llm.providers.mistral import MistralProvider
from nl_sql.schema_index.indexer import SchemaIndex
# ---------------------------------------------------------- response models
class HealthResponse(BaseModel):
    """Payload returned by ``GET /healthz`` (liveness probe)."""

    # The /healthz handler in this module always reports "ok".
    status: str
    # Package version, taken from ``nl_sql.__version__``.
    version: str
    # Sorted provider names: "ollama" is always listed; "mistral",
    # "github_models" and "groq" appear only when their credential setting is truthy.
    providers_configured: list[str]
class ReadyResponse(BaseModel):
    """Payload returned by ``GET /readyz`` (readiness probe)."""

    # "ok" when both ``chroma_ok`` and ``registry_ok`` are true, else "not_ready".
    status: str
    # True when the schema collection reports at least one entry.
    chroma_ok: bool
    # True when the registry lists at least one database id.
    registry_ok: bool
    # Number of ids returned by the registry (0 if the check failed).
    registered_dbs: int
    # ``count()`` of the schema collection (0 if the check failed).
    schema_chunks: int
class DatabaseInfo(BaseModel):
    """One registered database as listed by ``GET /databases``."""

    # Registry identifier of the database.
    db_id: str
    # ``str()`` of the registry spec's ``dialect`` attribute.
    dialect: str
    # Spec ``description`` when present; empty string otherwise.
    description: str = ""
    # Number of metadata records the schema collection holds for this db_id;
    # presumably one record per table — confirm against the indexer. 0 if the lookup fails.
    table_count: int
class DatabasesResponse(BaseModel):
    """Payload returned by ``GET /databases``."""

    # One entry per id returned by the registry, in registry order.
    databases: list[DatabaseInfo]
class AskRequest(BaseModel):
    """Request body for ``POST /ask``."""

    # Natural-language question; must be 1–2000 characters long.
    question: str = Field(min_length=1, max_length=2000)
    # Target database id; /ask answers 404 when it is not in the registry.
    db_id: str = Field(min_length=1)
class TraceStep(BaseModel):
    """One pipeline trace entry, as exposed in ``AskResponse.trace``."""

    # Value of the step's "node" key ("?" when the key is missing).
    node: str
    # Value of the step's "model" key, if any.
    model: str | None = None
    # Value of the step's "input_tokens" key, if any.
    tokens_in: int | None = None
    # Value of the step's "output_tokens" key, if any.
    tokens_out: int | None = None
    # Value of the step's "confidence" key, if any.
    confidence: float | None = None
class AskResponse(BaseModel):
    """Payload returned by ``POST /ask``."""

    # Fresh random UUID4 generated for each response.
    trace_id: str
    # Database id reported by the pipeline result.
    db_id: str
    # SQL text produced by the pipeline.
    sql: str
    # Pipeline-provided rationale for the SQL.
    rationale: str
    # Numeric confidence reported by the pipeline.
    confidence: float
    # "High" / "Medium" / "Low" / "Unknown", derived from ``confidence``.
    confidence_label: str
    # Result rows as lists; None when the pipeline produced no execution result.
    rows: list[list[Any]] | None
    # Result column names; None when the pipeline produced no execution result.
    columns: list[str] | None
    # Row count reported by the execution result; 0 when there is none.
    row_count: int
    # Caption text from the pipeline result.
    caption: str
    # Class name of the pipeline's ``output_format`` object, or None.
    output_format: str | None
    # ``.value`` of the pipeline's error kind, or None when no error kind is set.
    error_kind: str | None
    # Error message from the pipeline result.
    error_message: str
    # Whether the pipeline reported a repair attempt.
    repair_attempted: bool
    # Time spent inside ``run_pipeline``, in milliseconds.
    latency_ms: float
    # Per-node trace entries converted from the pipeline trace.
    trace: list[TraceStep]
class EvalLatestResponse(BaseModel):
    """Payload returned by ``GET /eval/latest``, read from the committed baseline JSON."""

    # "configuration" field of the report ("unknown" when absent).
    configuration: str
    # "sql_model" field of the report ("unknown" when absent).
    sql_model: str
    # ``overall.ea`` from the report; None when the report does not carry it.
    overall_ea: float | None
    # ``overall.n`` from the report; 0 when missing or falsy.
    n: int
    # Path of the baseline file the values were read from.
    report_path: str
# ---------------------------------------------------------- helpers
def _confidence_label(value: float) -> str:
if value >= 0.8:
return "High"
if value >= 0.5:
return "Medium"
if value > 0.0:
return "Low"
return "Unknown"
def _result_to_response(result: PipelineRunResult, *, latency_ms: float) -> AskResponse:
    """Convert a pipeline run result into the public ``AskResponse`` payload.

    Args:
        result: Outcome of ``run_pipeline`` for a single question.
        latency_ms: Measured pipeline latency, passed through unchanged.

    Returns:
        An ``AskResponse`` carrying a freshly generated ``trace_id``. ``rows``
        and ``columns`` stay ``None`` (and ``row_count`` 0) when the pipeline
        has no execution result.
    """
    outcome = result.outcome
    executed = None if outcome is None else outcome.result

    rows: list[list[Any]] | None
    columns: list[str] | None
    if executed is None:
        rows, columns, row_count = None, None, 0
    else:
        rows = [list(row) for row in executed.rows]
        columns = list(executed.columns)
        row_count = executed.row_count

    # Trace entries are plain mappings; missing keys simply become None.
    trace_steps = [
        TraceStep(
            node=str(entry.get("node", "?")),
            model=entry.get("model"),  # type: ignore[arg-type]
            tokens_in=entry.get("input_tokens"),  # type: ignore[arg-type]
            tokens_out=entry.get("output_tokens"),  # type: ignore[arg-type]
            confidence=entry.get("confidence"),  # type: ignore[arg-type]
        )
        for entry in result.trace
    ]

    # Only the class name of the output format is exposed to clients.
    if result.output_format is not None:
        fmt_name = type(result.output_format).__name__
    else:
        fmt_name = None

    error_kind = result.error_kind.value if result.error_kind else None

    return AskResponse(
        trace_id=str(uuid.uuid4()),
        db_id=result.db_id,
        sql=result.sql,
        rationale=result.rationale,
        confidence=result.confidence,
        confidence_label=_confidence_label(result.confidence),
        rows=rows,
        columns=columns,
        row_count=row_count,
        caption=result.caption,
        output_format=fmt_name,
        error_kind=error_kind,
        error_message=result.error_message,
        repair_attempted=result.repair_attempted,
        latency_ms=latency_ms,
        trace=trace_steps,
    )
# ---------------------------------------------------------- rate limit
class _TokenBucket:
"""Sliding-window token bucket per key.
Default: 60 requests per 60 seconds. Single-process state — fine for the
portfolio demo. Move to Redis if/when running multiple replicas.
"""
def __init__(self, *, max_req: int = 60, window_s: int = 60) -> None:
self.max_req = max_req
self.window_s = window_s
self._hits: dict[str, deque[float]] = defaultdict(deque)
def check(self, key: str) -> tuple[bool, int]:
now = time.time()
bucket = self._hits[key]
cutoff = now - self.window_s
while bucket and bucket[0] < cutoff:
bucket.popleft()
if len(bucket) >= self.max_req:
retry_after = int(self.window_s - (now - bucket[0]))
return False, max(retry_after, 1)
bucket.append(now)
return True, 0
# ---------------------------------------------------------- bootstrap
def _build_pipeline_components(
    settings: Settings,
) -> tuple[DatabaseRegistry, SchemaIndex, LLMProvider, LLMProvider]:
    """Assemble the registry, schema index and LLM providers for the pipeline.

    Args:
        settings: Application settings supplying provider credentials, model
            names and the LLM cache directory.

    Returns:
        ``(registry, schema_index, sql_provider, explain_provider)``. The SQL
        and explanation providers are the same caching-wrapped instance.

    Raises:
        RuntimeError: If no Mistral API key is configured; the embedder is
            always built on the Mistral provider.
    """
    if not settings.mistral_api_key:
        raise RuntimeError("MISTRAL_API_KEY is not set — API can't bootstrap embeddings.")

    # Generation: the configured default provider behind the on-disk cache.
    generator: LLMProvider = CachingLLMProvider(
        build_provider(settings.default_provider, settings=settings),
        cache_dir=settings.llm_cache_dir,
    )

    # Embeddings: always Mistral, cached in the same directory.
    mistral: EmbeddingProvider = MistralProvider(
        api_key=settings.mistral_api_key,
        gen_model=settings.mistral_gen_model,
        embed_model=settings.mistral_embed_model,
        base_url=settings.mistral_base_url,
    )
    cached_embedder: EmbeddingProvider = CachingEmbeddingProvider(
        mistral, cache_dir=settings.llm_cache_dir
    )

    index = SchemaIndex(persist_dir="chroma_data", embedder=cached_embedder)
    # One provider serves both SQL generation and explanation.
    return get_default_registry(), index, generator, generator
class Singletons(NamedTuple):
    """The four runtime objects the API routes share.

    Exposed as a public type so tests can construct mock instances and feed
    them through ``app.dependency_overrides[get_singletons]``.
    """

    # Object returned by ``build_pipeline``; handed to ``run_pipeline`` by /ask.
    pipeline: Any
    # Database registry used to list ids and resolve a db_id to its spec.
    registry: DatabaseRegistry
    # Schema index; routes read its ``schema_collection`` for counts.
    schema_index: SchemaIndex
    # SQL-generation provider that the pipeline was configured with.
    sql_provider: LLMProvider
@lru_cache(maxsize=1)
def _make_singletons() -> Singletons:
    """Build the shared runtime objects on first use and memoise them.

    Nothing is constructed until the first call, so /healthz (which never asks
    for the singletons) stays fast. Note that /readyz, /databases and /ask all
    resolve the singletons and therefore trigger the build. ``lru_cache`` does
    not cache exceptions, so a failed build is retried on the next call.
    """
    import os

    def _env_flag(name: str) -> bool:
        # A toggle counts as enabled only when set to exactly "1".
        return os.environ.get(name) == "1"

    registry, schema_index, sql_provider, explain_provider = _build_pipeline_components(
        get_settings()
    )
    # Eval-script env toggles bootstrap into PipelineConfig once at boot;
    # individual nodes never read os.environ at runtime (see graph.py docstrings).
    pipeline = build_pipeline(
        PipelineConfig(
            sql_provider=sql_provider,
            explain_provider=explain_provider,
            schema_index=schema_index,
            registry=registry,
            fewshot_top_k=3,
            sort_schema_block=True,
            cross_db_fewshot=True,
            verify_retry_on_empty=True,
            use_m_schema=_env_flag("NLSQL_M_SCHEMA"),
            use_dac_prompt=_env_flag("NLSQL_DAC"),
        )
    )
    return Singletons(
        pipeline=pipeline,
        registry=registry,
        schema_index=schema_index,
        sql_provider=sql_provider,
    )
def get_singletons() -> Singletons:
    """Return the process-wide ``Singletons``, building them on first use.

    Meant to be used with FastAPI ``Depends``; tests replace it through
    ``app.dependency_overrides``.
    """
    shared = _make_singletons()
    return shared
def create_app() -> FastAPI:
    """Build the FastAPI application and register all routes.

    Routes: ``/healthz`` and ``/readyz`` (always open), ``/databases`` and
    ``/ask`` (guarded by ``X-API-Key`` when ``NL_SQL_API_KEY`` is set in the
    process environment), and ``/eval/latest``.

    Settings, the rate limiter and the expected API key are captured once,
    when the app is created; later environment changes are not picked up.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(
        title="NL→SQL Assistant",
        version=__version__,
        description=(
            "Portfolio API: natural-language questions → SQL → executed rows. "
            "BIRD Mini-Dev 57% hybrid, Chinook 100%, $0 budget, AST safety guards."
        ),
    )
    logger = logging.getLogger(__name__)
    settings = get_settings()
    rate_limiter = _TokenBucket(max_req=60, window_s=60)
    # `NL_SQL_API_KEY` via env, optional. Pre-encoded once so every request can
    # be checked with a constant-time bytes comparison.
    api_key_env = os.environ.get("NL_SQL_API_KEY", "")
    api_key_bytes = api_key_env.encode("utf-8")

    async def require_api_key(
        request: Request,
        x_api_key: str | None = Header(default=None, alias="X-API-Key"),
    ) -> str:
        """Authenticate and rate-limit one request; return the caller identity.

        Raises:
            HTTPException: 401 when the header is missing or wrong, 429 (with
                ``Retry-After``) when the key has exhausted its window.
        """
        if not api_key_env:
            # Auth off entirely when no key is configured — useful for local
            # eval drivers and the Streamlit UI bootstrapping side-by-side.
            return "anonymous"
        # compare_digest avoids leaking how much of the key matched via timing;
        # bytes are used because the str form only accepts ASCII.
        if x_api_key is None or not hmac.compare_digest(
            x_api_key.encode("utf-8"), api_key_bytes
        ):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="missing or invalid X-API-Key header",
            )
        ok, retry_after = rate_limiter.check(x_api_key)
        if not ok:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail=f"rate limit exceeded; retry in {retry_after}s",
                headers={"Retry-After": str(retry_after)},
            )
        return x_api_key

    # --------------------------------------------------------- health / ready
    @app.get("/healthz", response_model=HealthResponse, tags=["status"])
    def healthz() -> HealthResponse:
        """Liveness probe: report version and which providers have credentials."""
        configured: list[str] = []
        if settings.mistral_api_key:
            configured.append("mistral")
        if settings.github_token:
            configured.append("github_models")
        if settings.groq_api_key:
            configured.append("groq")
        # Ollama is listed unconditionally (no credential is checked for it).
        configured.append("ollama")
        return HealthResponse(
            status="ok",
            version=__version__,
            providers_configured=sorted(configured),
        )

    @app.get("/readyz", response_model=ReadyResponse, tags=["status"])
    def readyz() -> ReadyResponse:
        """Readiness probe: registry has databases and the schema index has entries.

        Any failure while resolving or querying the singletons is reported as
        "not_ready" rather than as an error response.
        """
        chroma_ok = False
        registry_ok = False
        registered = 0
        schema_chunks = 0
        try:
            # Honour test overrides even though this route does not use Depends.
            factory: Any = app.dependency_overrides.get(get_singletons, get_singletons)
            singletons: Singletons = factory()
            registered = len(singletons.registry.ids())
            registry_ok = registered > 0
            schema_chunks = singletons.schema_index.schema_collection.count()
            chroma_ok = schema_chunks > 0
        except Exception as exc:
            # Still best-effort, but leave a trace so "not_ready" is diagnosable.
            logger.warning("readiness check failed: %s: %s", type(exc).__name__, exc)
        all_ok = chroma_ok and registry_ok
        return ReadyResponse(
            status="ok" if all_ok else "not_ready",
            chroma_ok=chroma_ok,
            registry_ok=registry_ok,
            registered_dbs=registered,
            schema_chunks=schema_chunks,
        )

    # --------------------------------------------------------- product API
    @app.get("/databases", response_model=DatabasesResponse, tags=["catalog"])
    def databases(
        _auth: str = Depends(require_api_key),
        singletons: Singletons = Depends(get_singletons),  # noqa: B008
    ) -> DatabasesResponse:
        """List every registered database with its indexed record count."""
        infos: list[DatabaseInfo] = []
        for db_id in singletons.registry.ids():
            spec = singletons.registry.get(db_id)
            try:
                records = singletons.schema_index.schema_collection.get(
                    where={"db_id": db_id}, include=["metadatas"]
                )
                table_count = len(records.get("metadatas") or [])
            except Exception as exc:
                # Best-effort: an index lookup failure must not hide the database.
                logger.warning(
                    "table count lookup failed for %r: %s: %s",
                    db_id,
                    type(exc).__name__,
                    exc,
                )
                table_count = 0
            infos.append(
                DatabaseInfo(
                    db_id=db_id,
                    dialect=str(spec.dialect),
                    description=str(getattr(spec, "description", "") or ""),
                    table_count=table_count,
                )
            )
        return DatabasesResponse(databases=infos)

    @app.post("/ask", response_model=AskResponse, tags=["nl-sql"])
    def ask(
        req: AskRequest,
        _auth: str = Depends(require_api_key),
        singletons: Singletons = Depends(get_singletons),  # noqa: B008
    ) -> AskResponse:
        """Run the NL→SQL pipeline for one question against a registered database.

        Raises:
            HTTPException: 404 for an unknown ``db_id``; 500 if the pipeline
                itself raises.
        """
        if req.db_id not in singletons.registry.ids():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"unknown db_id: {req.db_id!r}; see /databases for the list",
            )
        spec = singletons.registry.get(req.db_id)
        t0 = time.perf_counter()
        try:
            result = run_pipeline(
                singletons.pipeline,
                question=req.question,
                db_id=req.db_id,
                dialect=spec.dialect,
                verify_retry_on_empty=True,
            )
        except Exception as exc:  # pragma: no cover — defensive
            logger.exception("pipeline crashed for db_id=%r", req.db_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"pipeline crashed: {type(exc).__name__}: {exc}",
            ) from exc
        latency_ms = (time.perf_counter() - t0) * 1000.0
        return _result_to_response(result, latency_ms=latency_ms)

    @app.get("/eval/latest", response_model=EvalLatestResponse, tags=["transparency"])
    def eval_latest() -> EvalLatestResponse:
        """Returns metadata of the latest hybrid eval report committed to repo."""
        # Relative path: resolved against the process working directory.
        baseline = Path("eval/baselines/hybrid_n200_v0.json")
        if not baseline.exists():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="no committed baseline yet — run scripts/eval_baseline.py",
            )
        data = json.loads(baseline.read_text(encoding="utf-8"))
        overall = data.get("overall") or {}
        return EvalLatestResponse(
            configuration=str(data.get("configuration", "unknown")),
            sql_model=str(data.get("sql_model", "unknown")),
            overall_ea=overall.get("ea"),
            n=int(overall.get("n") or 0),
            report_path=str(baseline),
        )

    return app
# Module-level application instance, created once when this module is imported.
app = create_app()