# ishaq101's picture
# feat/Catalog Retrieval System (#1)
# 6bff5d9
"""DbExecutor — runs a compiled IR against a user's external SQL database.

Pipeline:

    IR → SqlCompiler.compile() → CompiledSql(sql, params)
    sqlglot guard (defense-in-depth: SELECT-only, no DML / DDL)
    resolve creds (catalog.location_ref → dbclient://{client_id} →
        DatabaseClient row → Fernet decrypt)
    asyncio.to_thread(_run_sync)
      └ db_pipeline_service.engine_scope(db_type, creds)
      └ session-level: default_transaction_read_only + statement_timeout=30s
        (postgres / supabase only)
      └ conn.execute(text(sql), params)
    QueryResult (always returned — errors populate `.error`, never raised)
"""
from __future__ import annotations
import asyncio
import time
from typing import Any
import sqlglot
import sqlglot.expressions as exp
from sqlalchemy import text
from ...catalog.models import Catalog, Source
from ...database_client.database_client_service import database_client_service
from ...db.postgres.connection import AsyncSessionLocal
from ...middlewares.logging import get_logger
from ...pipeline.db_pipeline import db_pipeline_service
from ...utils.db_credential_encryption import decrypt_credentials_dict
from ..compiler.sql import CompiledSql, SqlCompiler
from ..ir.models import QueryIR
from .base import BaseExecutor, QueryResult
# Module-level logger, tagged "db_executor".
logger = get_logger("db_executor")
# Time budget for one query: used as the asyncio.wait_for timeout around the
# worker thread and, converted to milliseconds, as the Postgres statement_timeout.
_QUERY_TIMEOUT_SECONDS = 30
# Maximum number of rows handed back to the caller; anything beyond this is
# dropped and the result is flagged as truncated.
_ROW_HARD_CAP = 10_000  # belt-and-suspenders cap regardless of LIMIT
# Prefix a source's location_ref must start with; the remainder is the client id.
_DBCLIENT_PREFIX = "dbclient://"
# db_type values that receive the read-only / statement_timeout SET commands.
_POSTGRES_LIKE = frozenset({"postgres", "supabase"})
class DbExecutor(BaseExecutor):
    """Executes compiled SQL on the user's registered DB.

    Constructed once per query with the user's catalog. The catalog is the
    source of truth for identifiers; the executor never touches the user's
    DB metadata at execution time.
    """

    def __init__(self, catalog: Catalog) -> None:
        """Bind the executor to *catalog* and build the SQL compiler for it."""
        self._catalog = catalog
        self._compiler = SqlCompiler(catalog)

    async def run(self, ir: QueryIR) -> QueryResult:
        """Compile *ir*, execute it on the source's database and return the rows.

        Steps: locate the source in the catalog, compile the IR to SQL, run
        the SELECT-only guard, resolve and decrypt the DatabaseClient
        credentials, then execute in a worker thread under a timeout.

        Returns:
            A ``QueryResult``. On success it carries columns and at most
            ``_ROW_HARD_CAP`` rows (``truncated`` is set when more were
            available). On any ``Exception`` the failure is logged and a
            ``QueryResult`` with a non-empty ``error`` is returned instead of
            raising.
        """
        started = time.perf_counter()
        table_name = ""
        source_name = ""
        try:
            source = self._find_source(ir.source_id)
            source_name = source.name
            table_name = next(
                (t.name for t in source.tables if t.table_id == ir.table_id), ""
            )
            if source.source_type != "schema":
                raise ValueError(
                    f"DbExecutor cannot run on source_type={source.source_type!r}; "
                    "expected 'schema'"
                )

            compiled = self._compiler.compile(ir)
            self._sqlglot_guard(compiled.sql)

            client_id = self._parse_client_id(source.location_ref)
            client = await self._fetch_client(client_id)
            # The catalog must only reference database clients owned by the
            # same user the catalog belongs to.
            if client.user_id != self._catalog.user_id:
                raise PermissionError(
                    f"DatabaseClient {client_id!r} owner mismatch "
                    f"(client.user_id != catalog.user_id)"
                )
            creds = decrypt_credentials_dict(client.credentials)

            # The driver call is blocking, so it runs in a worker thread.
            # NOTE: on timeout the awaiting side gives up, but the thread
            # itself cannot be interrupted from here.
            columns, rows = await asyncio.wait_for(
                asyncio.to_thread(self._run_sync, client.db_type, creds, compiled),
                timeout=_QUERY_TIMEOUT_SECONDS,
            )

            truncated = len(rows) > _ROW_HARD_CAP
            capped = rows[:_ROW_HARD_CAP]
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            logger.info(
                "db query complete",
                source_id=ir.source_id,
                rows=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                columns=columns,
                rows=capped,
                row_count=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )
        except Exception as e:
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            # str(e) is empty for the timeout raised by asyncio.wait_for; an
            # empty ``error`` would be indistinguishable from success.
            message = self._describe_error(e)
            logger.error(
                "db executor failed",
                source_id=ir.source_id,
                error=message,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                elapsed_ms=elapsed_ms,
                error=message,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return a non-empty, human-readable message for *exc*."""
        message = str(exc)
        if message:
            return message
        if isinstance(exc, asyncio.TimeoutError):
            return f"query timed out after {_QUERY_TIMEOUT_SECONDS}s"
        return type(exc).__name__

    def _find_source(self, source_id: str) -> Source:
        """Return the catalog source whose ``source_id`` matches.

        Raises:
            ValueError: if no source in the catalog has that id.
        """
        for s in self._catalog.sources:
            if s.source_id == source_id:
                return s
        raise ValueError(f"source_id {source_id!r} not in catalog")

    @staticmethod
    def _parse_client_id(location_ref: str) -> str:
        """Extract the client id from a ``dbclient://<client_id>`` reference.

        Raises:
            ValueError: if the prefix is missing or nothing follows it.
        """
        if not location_ref.startswith(_DBCLIENT_PREFIX):
            raise ValueError(
                f"DbExecutor expects 'dbclient://...' location_ref, got {location_ref!r}"
            )
        client_id = location_ref[len(_DBCLIENT_PREFIX):]
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")
        return client_id

    @staticmethod
    async def _fetch_client(client_id: str) -> Any:
        """Load the DatabaseClient row for *client_id* and check it is usable.

        Raises:
            ValueError: if the client does not exist or its status is not
                ``"active"``.
        """
        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
            if client is None:
                raise ValueError(f"DatabaseClient {client_id!r} not found")
            if client.status != "active":
                raise ValueError(
                    f"DatabaseClient {client_id!r} is not active "
                    f"(status={client.status!r})"
                )
            return client

    @staticmethod
    def _sqlglot_guard(sql: str) -> None:
        """Defense-in-depth: ensure the compiled SQL is a single SELECT statement.

        The compiler is already deterministic and only constructs SELECTs from
        validated IR, but this guard catches any future bug that could leak
        DML/DDL through.

        Raises:
            ValueError: if the SQL does not parse, is not exactly one
                statement, is not a SELECT, or contains a forbidden node.
        """
        try:
            # parse() rather than parse_one(): parse_one() returns only the
            # first statement, so anything after a ';' would go unchecked.
            statements = [
                stmt for stmt in sqlglot.parse(sql, read="postgres") if stmt is not None
            ]
        except sqlglot.errors.ParseError as e:
            raise ValueError(f"compiled SQL failed to parse: {e}") from e
        if len(statements) != 1:
            raise ValueError(
                f"compiled SQL must be exactly one statement (got {len(statements)})"
            )
        parsed = statements[0]
        if not isinstance(parsed, exp.Select):
            raise ValueError(
                f"compiled SQL is not a SELECT (got {type(parsed).__name__})"
            )
        # A top-level SELECT can still embed writes (e.g. inside a CTE).
        forbidden = (exp.Insert, exp.Update, exp.Delete, exp.Drop, exp.Alter)
        offender = next(parsed.find_all(*forbidden), None)
        if offender is not None:
            raise ValueError(
                f"compiled SQL contains forbidden DML/DDL: {type(offender).__name__}"
            )

    @staticmethod
    def _run_sync(db_type: str, creds: dict, compiled: CompiledSql) -> tuple[list[str], list[dict]]:
        """Execute *compiled* synchronously; intended to run in a worker thread.

        Returns:
            ``(columns, rows)`` where ``rows`` holds at most
            ``_ROW_HARD_CAP + 1`` dicts. The extra row lets the caller detect
            truncation without converting the whole result set.
        """
        with db_pipeline_service.engine_scope(db_type, creds) as engine:
            with engine.connect() as conn:
                if db_type in _POSTGRES_LIKE:
                    # session-level read-only + per-statement timeout (ms)
                    conn.execute(text("SET default_transaction_read_only = on"))
                    # The session default only applies to transactions started
                    # later. If a transaction is already open on this
                    # connection, mark it read-only too. Outside a transaction
                    # block Postgres only emits a warning for this statement.
                    conn.execute(text("SET TRANSACTION READ ONLY"))
                    conn.execute(
                        text(f"SET statement_timeout = {_QUERY_TIMEOUT_SECONDS * 1000}")
                    )
                result = conn.execute(text(compiled.sql), compiled.params)
                columns = list(result.keys())
                # Stop at cap + 1 instead of building a dict for every row.
                # NOTE(review): the DB driver may still buffer the full result
                # set client-side; this only bounds what is built here.
                fetched = result.mappings().fetchmany(_ROW_HARD_CAP + 1)
                rows = [dict(row) for row in fetched]
                return columns, rows