File size: 7,829 Bytes
"""DbExecutor — runs a compiled IR against a user's external SQL database.

Pipeline:

    IR → SqlCompiler.compile() → CompiledSql(sql, params)
       ↓
    sqlglot guard (defense-in-depth: SELECT-only, no DML / DDL)
       ↓
    resolve creds (catalog.location_ref → dbclient://{client_id} → DatabaseClient
                   row → Fernet decrypt)
       ↓
    asyncio.to_thread(_run_sync)
       → db_pipeline_service.engine_scope(db_type, creds)
       → session-level: default_transaction_read_only + statement_timeout=30s
         (postgres / supabase only)
       → engine.execute(text(sql), params)
       ↓
    QueryResult (always returned — errors populate `.error`, never raised)
"""
from __future__ import annotations
import asyncio
import time
from typing import Any
import sqlglot
import sqlglot.expressions as exp
from sqlalchemy import text
from ...catalog.models import Catalog, Source
from ...database_client.database_client_service import database_client_service
from ...db.postgres.connection import AsyncSessionLocal
from ...middlewares.logging import get_logger
from ...pipeline.db_pipeline import db_pipeline_service
from ...utils.db_credential_encryption import decrypt_credentials_dict
from ..compiler.sql import CompiledSql, SqlCompiler
from ..ir.models import QueryIR
from .base import BaseExecutor, QueryResult
# Logger for this module; used by DbExecutor.run for success and failure events.
logger = get_logger("db_executor")

# Time budget for one query, in seconds. Used both as the asyncio.wait_for
# timeout and (converted to milliseconds) as the Postgres statement_timeout.
_QUERY_TIMEOUT_SECONDS = 30

# Belt-and-suspenders cap on returned rows, regardless of any LIMIT in the SQL.
_ROW_HARD_CAP = 10_000

# Prefix a Source.location_ref must carry; the remainder is the client id.
_DBCLIENT_PREFIX = "dbclient://"

# db_type values that receive the Postgres-specific session settings.
_POSTGRES_LIKE = frozenset(("postgres", "supabase"))
class DbExecutor(BaseExecutor):
    """Executes compiled SQL on the user's registered DB.

    Constructed once per query with the user's catalog. The catalog is the
    source of truth for identifiers; the executor never touches the user's
    DB metadata at execution time.
    """

    def __init__(self, catalog: Catalog) -> None:
        """Bind the executor to ``catalog`` and build a ``SqlCompiler`` for it."""
        self._catalog = catalog
        self._compiler = SqlCompiler(catalog)

    async def run(self, ir: QueryIR) -> QueryResult:
        """Compile ``ir``, execute it on the source's database, and return the result.

        Args:
            ir: Query IR. ``ir.source_id`` and ``ir.table_id`` are looked up in
                the catalog passed to the constructor.

        Returns:
            A ``QueryResult``. On success it carries the column names and at
            most ``_ROW_HARD_CAP`` rows, with ``truncated`` set when more rows
            were available. If any ``Exception`` occurs, the result instead
            carries a non-empty ``error`` string; the failure is logged and
            not re-raised.
        """
        started = time.perf_counter()
        table_name = ""
        source_name = ""
        try:
            source = self._find_source(ir.source_id)
            source_name = source.name
            table_name = next(
                (t.name for t in source.tables if t.table_id == ir.table_id), ""
            )
            if source.source_type != "schema":
                raise ValueError(
                    f"DbExecutor cannot run on source_type={source.source_type!r}; "
                    "expected 'schema'"
                )

            compiled = self._compiler.compile(ir)
            self._sqlglot_guard(compiled.sql)

            client_id = self._parse_client_id(source.location_ref)
            client = await self._fetch_client(client_id)
            # The catalog's owner must also own the stored credentials.
            if client.user_id != self._catalog.user_id:
                raise PermissionError(
                    f"DatabaseClient {client_id!r} owner mismatch "
                    f"(client.user_id != catalog.user_id)"
                )
            creds = decrypt_credentials_dict(client.credentials)

            try:
                # NOTE: on timeout only the await is abandoned; the worker
                # thread cannot be interrupted and runs until the DB call
                # returns (bounded by statement_timeout on Postgres-like DBs).
                columns, rows = await asyncio.wait_for(
                    asyncio.to_thread(self._run_sync, client.db_type, creds, compiled),
                    timeout=_QUERY_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                # wait_for's timeout exception has an empty message, which
                # would otherwise produce an empty ``error`` string below.
                raise TimeoutError(
                    f"query exceeded the {_QUERY_TIMEOUT_SECONDS}s timeout"
                ) from None

            # _run_sync fetches at most _ROW_HARD_CAP + 1 rows, so one extra
            # row is enough to signal truncation.
            truncated = len(rows) > _ROW_HARD_CAP
            capped = rows[:_ROW_HARD_CAP]
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            logger.info(
                "db query complete",
                source_id=ir.source_id,
                rows=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                columns=columns,
                rows=capped,
                row_count=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )
        except Exception as e:
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            # Some exceptions stringify to ""; fall back to the class name so
            # a failure never yields an empty error message.
            message = str(e) or type(e).__name__
            logger.error(
                "db executor failed",
                source_id=ir.source_id,
                error=message,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                elapsed_ms=elapsed_ms,
                error=message,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _find_source(self, source_id: str) -> Source:
        """Return the catalog source whose ``source_id`` matches.

        Raises:
            ValueError: if no source in the catalog has that id.
        """
        for s in self._catalog.sources:
            if s.source_id == source_id:
                return s
        raise ValueError(f"source_id {source_id!r} not in catalog")

    @staticmethod
    def _parse_client_id(location_ref: str) -> str:
        """Extract the client id from a ``dbclient://<client_id>`` reference.

        Raises:
            ValueError: if the prefix is missing or nothing follows it.
        """
        if not location_ref.startswith(_DBCLIENT_PREFIX):
            raise ValueError(
                f"DbExecutor expects 'dbclient://...' location_ref, got {location_ref!r}"
            )
        client_id = location_ref[len(_DBCLIENT_PREFIX):]
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")
        return client_id

    @staticmethod
    async def _fetch_client(client_id: str) -> Any:
        """Load the DatabaseClient row for ``client_id`` and require it to be active.

        Raises:
            ValueError: if the row does not exist or its status is not
                ``"active"``.
        """
        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
        if client is None:
            raise ValueError(f"DatabaseClient {client_id!r} not found")
        if client.status != "active":
            raise ValueError(
                f"DatabaseClient {client_id!r} is not active "
                f"(status={client.status!r})"
            )
        return client

    @staticmethod
    def _sqlglot_guard(sql: str) -> None:
        """Defense-in-depth: ensure the compiled SQL is one SELECT statement.

        The compiler is already deterministic and only constructs SELECTs from
        validated IR, but this guard catches any future bug that could leak
        DML/DDL through.

        Raises:
            ValueError: if the SQL does not parse, is not exactly one
                statement, is not a SELECT at the top level, or contains a
                data-modifying node anywhere in its tree.
        """
        try:
            # parse() returns every statement in the string. parse_one() would
            # return only the first, letting "SELECT ...; DROP ..." through.
            statements = [
                stmt
                for stmt in sqlglot.parse(sql, read="postgres")
                if stmt is not None
            ]
        except (sqlglot.errors.ParseError, sqlglot.errors.TokenError) as e:
            raise ValueError(f"compiled SQL failed to parse: {e}") from e

        if len(statements) != 1:
            raise ValueError(
                f"compiled SQL must be exactly one statement (got {len(statements)})"
            )
        parsed = statements[0]
        if not isinstance(parsed, exp.Select):
            raise ValueError(
                f"compiled SQL is not a SELECT (got {type(parsed).__name__})"
            )

        # Nodes that can hide inside a SELECT tree (data-modifying CTEs,
        # SELECT ... INTO) plus DDL for completeness.
        forbidden = (
            exp.Insert,
            exp.Update,
            exp.Delete,
            exp.Merge,
            exp.Into,
            exp.Drop,
            exp.Alter,
        )
        node = parsed.find(*forbidden)
        if node is not None:
            raise ValueError(
                f"compiled SQL contains forbidden DML/DDL: {type(node).__name__}"
            )

    @staticmethod
    def _run_sync(db_type: str, creds: dict, compiled: CompiledSql) -> tuple[list[str], list[dict]]:
        """Execute ``compiled`` synchronously and return ``(columns, rows)``.

        Blocking; ``run`` calls it through ``asyncio.to_thread``.

        Args:
            db_type: Database kind; values in ``_POSTGRES_LIKE`` additionally
                get read-only and statement-timeout settings.
            creds: Decrypted credentials passed to ``engine_scope``.
            compiled: SQL text and bound parameters to execute.

        Returns:
            Column names, and at most ``_ROW_HARD_CAP + 1`` rows as dicts. The
            extra row lets the caller detect truncation.
        """
        with db_pipeline_service.engine_scope(db_type, creds) as engine:
            with engine.connect() as conn:
                if db_type in _POSTGRES_LIKE:
                    # Session default: applies to transactions started later
                    # (and to each statement if the engine runs in autocommit).
                    conn.execute(text("SET default_transaction_read_only = on"))
                    # The default above does not change a transaction that is
                    # already open, so also mark the current one read-only.
                    # Presumably the driver opened one implicitly on the first
                    # execute; outside a transaction this is only a warning.
                    conn.execute(text("SET TRANSACTION READ ONLY"))
                    # per-statement timeout (ms)
                    conn.execute(
                        text(f"SET statement_timeout = {_QUERY_TIMEOUT_SECONDS * 1000}")
                    )
                result = conn.execute(text(compiled.sql), compiled.params)
                columns = list(result.keys())
                # Stop one past the cap instead of converting every row.
                rows = [
                    dict(row)
                    for row in result.mappings().fetchmany(_ROW_HARD_CAP + 1)
                ]
                return columns, rows
|