"""DbExecutor — runs a compiled IR against a user's external SQL database.

Pipeline:
    IR → SqlCompiler.compile() → CompiledSql(sql, params)
        ↓
    sqlglot guard (defense-in-depth: SELECT-only, no DML / DDL)
        ↓
    resolve creds (catalog.location_ref → dbclient://{client_id} → DatabaseClient
                   row → Fernet decrypt)
        ↓
    asyncio.to_thread(_run_sync)
        └ db_pipeline_service.engine_scope(db_type, creds)
            └ session-level: default_transaction_read_only + statement_timeout=30s
              (postgres / supabase only)
            └ conn.execute(text(sql), params)
        ↓
    QueryResult (always returned — errors populate `.error`, never raised)
"""
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import time |
| from typing import Any |
|
|
| import sqlglot |
| import sqlglot.expressions as exp |
| from sqlalchemy import text |
|
|
| from ...catalog.models import Catalog, Source |
| from ...database_client.database_client_service import database_client_service |
| from ...db.postgres.connection import AsyncSessionLocal |
| from ...middlewares.logging import get_logger |
| from ...pipeline.db_pipeline import db_pipeline_service |
| from ...utils.db_credential_encryption import decrypt_credentials_dict |
| from ..compiler.sql import CompiledSql, SqlCompiler |
| from ..ir.models import QueryIR |
| from .base import BaseExecutor, QueryResult |
|
|
# Module-level logger, tagged "db_executor".
logger = get_logger("db_executor")


# Time budget for one query, in seconds: used both as the asyncio wait limit
# and (converted to milliseconds) as the Postgres statement_timeout.
_QUERY_TIMEOUT_SECONDS = 30
# Maximum number of rows returned in a QueryResult; anything beyond this is
# dropped and the result is flagged as truncated.
_ROW_HARD_CAP = 10_000
# Scheme prefix required on Source.location_ref; the text after it is the
# DatabaseClient id.
_DBCLIENT_PREFIX = "dbclient://"
# db_type values that get the Postgres session settings applied before the
# query is executed.
_POSTGRES_LIKE = frozenset({"postgres", "supabase"})
|
|
|
|
class DbExecutor(BaseExecutor):
    """Executes compiled SQL on the user's registered DB.

    Constructed once per query with the user's catalog. The catalog is the
    source of truth for identifiers; the executor never touches the user's
    DB metadata at execution time.
    """

    def __init__(self, catalog: Catalog) -> None:
        """Bind the executor to ``catalog`` and build the SQL compiler for it."""
        self._catalog = catalog
        self._compiler = SqlCompiler(catalog)

    async def run(self, ir: QueryIR) -> QueryResult:
        """Compile ``ir``, execute it on the source's database and report.

        Steps: look up the source in the catalog, compile the IR, run the
        sqlglot guard, load and decrypt the DatabaseClient credentials, then
        execute the query in a worker thread under a timeout.

        Returns:
            A QueryResult. On success it carries columns and at most
            ``_ROW_HARD_CAP`` rows (``truncated`` is set when more were
            available). Any ``Exception`` raised along the way is logged and
            returned with ``error`` populated instead of being propagated.
        """
        started = time.perf_counter()
        # Filled in as soon as they are known so that failure results can
        # still carry them.
        table_name = ""
        source_name = ""
        try:
            source = self._find_source(ir.source_id)
            source_name = source.name
            table_name = next(
                (t.name for t in source.tables if t.table_id == ir.table_id), ""
            )
            if source.source_type != "schema":
                raise ValueError(
                    f"DbExecutor cannot run on source_type={source.source_type!r}; "
                    "expected 'schema'"
                )

            compiled = self._compiler.compile(ir)
            self._sqlglot_guard(compiled.sql)

            client_id = self._parse_client_id(source.location_ref)
            client = await self._fetch_client(client_id)
            if client.user_id != self._catalog.user_id:
                raise PermissionError(
                    f"DatabaseClient {client_id!r} owner mismatch "
                    f"(client.user_id != catalog.user_id)"
                )
            creds = decrypt_credentials_dict(client.credentials)

            # NOTE: wait_for only stops *waiting*; a worker thread cannot be
            # cancelled, so on timeout the query keeps running until the
            # database itself gives up (statement_timeout on Postgres-like
            # backends).
            columns, rows = await asyncio.wait_for(
                asyncio.to_thread(self._run_sync, client.db_type, creds, compiled),
                timeout=_QUERY_TIMEOUT_SECONDS,
            )

            truncated = len(rows) > _ROW_HARD_CAP
            capped = rows[:_ROW_HARD_CAP]
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            logger.info(
                "db query complete",
                source_id=ir.source_id,
                rows=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                columns=columns,
                rows=capped,
                row_count=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )

        except Exception as e:
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            message = str(e)
            if not message:
                # Some exceptions (notably asyncio.TimeoutError) stringify to
                # "", which would leave `.error` empty and falsy on a failed
                # result. Always report something meaningful.
                if isinstance(e, asyncio.TimeoutError):
                    message = f"query timed out after {_QUERY_TIMEOUT_SECONDS}s"
                else:
                    message = type(e).__name__
            logger.error(
                "db executor failed",
                source_id=ir.source_id,
                error=message,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                elapsed_ms=elapsed_ms,
                error=message,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )

    def _find_source(self, source_id: str) -> Source:
        """Return the catalog source whose ``source_id`` matches.

        Raises:
            ValueError: if the catalog has no such source.
        """
        match = next(
            (s for s in self._catalog.sources if s.source_id == source_id), None
        )
        if match is None:
            raise ValueError(f"source_id {source_id!r} not in catalog")
        return match

    @staticmethod
    def _parse_client_id(location_ref: str) -> str:
        """Extract the client id from a ``dbclient://<client_id>`` reference.

        Raises:
            ValueError: if the prefix is missing or nothing follows it.
        """
        if not location_ref.startswith(_DBCLIENT_PREFIX):
            raise ValueError(
                f"DbExecutor expects 'dbclient://...' location_ref, got {location_ref!r}"
            )
        client_id = location_ref.removeprefix(_DBCLIENT_PREFIX)
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")
        return client_id

    @staticmethod
    async def _fetch_client(client_id: str) -> Any:
        """Load the DatabaseClient row for ``client_id`` in a fresh session.

        Raises:
            ValueError: if the row does not exist or its status is not
                ``"active"``.
        """
        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
            if client is None:
                raise ValueError(f"DatabaseClient {client_id!r} not found")
            if client.status != "active":
                raise ValueError(
                    f"DatabaseClient {client_id!r} is not active "
                    f"(status={client.status!r})"
                )
            return client

    @staticmethod
    def _sqlglot_guard(sql: str) -> None:
        """Defense-in-depth: ensure the compiled SQL is one SELECT statement.

        The compiler is already deterministic and only constructs SELECTs from
        validated IR, but this guard catches any future bug that could leak
        DML/DDL through.

        Raises:
            ValueError: if the SQL does not parse, is not exactly one
                statement, is not a SELECT, or contains a forbidden node.
        """
        try:
            # parse() returns every statement in the string; parse_one() would
            # return only the first and silently ignore anything after a ';'.
            statements = [
                stmt for stmt in sqlglot.parse(sql, read="postgres") if stmt is not None
            ]
        except (sqlglot.errors.ParseError, sqlglot.errors.TokenError) as e:
            raise ValueError(f"compiled SQL failed to parse: {e}") from e
        if not statements:
            raise ValueError("compiled SQL failed to parse: no statement found")
        if len(statements) > 1:
            raise ValueError(
                f"compiled SQL must be a single statement (got {len(statements)})"
            )

        parsed = statements[0]
        if not isinstance(parsed, exp.Select):
            raise ValueError(
                f"compiled SQL is not a SELECT (got {type(parsed).__name__})"
            )

        # Write operations can hide inside a SELECT tree (data-modifying CTEs,
        # SELECT ... INTO), so scan the whole tree rather than just the root.
        forbidden = (
            exp.Insert,
            exp.Update,
            exp.Delete,
            exp.Merge,
            exp.Into,
            exp.Drop,
            exp.Alter,
        )
        offender = parsed.find(*forbidden)
        if offender is not None:
            raise ValueError(
                f"compiled SQL contains forbidden DML/DDL: {type(offender).__name__}"
            )

    @staticmethod
    def _run_sync(db_type: str, creds: dict, compiled: CompiledSql) -> tuple[list[str], list[dict]]:
        """Execute ``compiled`` synchronously; intended to run in a worker thread.

        Returns:
            ``(columns, rows)`` where ``rows`` holds at most
            ``_ROW_HARD_CAP + 1`` dicts. The one extra row lets the caller
            detect truncation without loading the full result set.
        """
        with db_pipeline_service.engine_scope(db_type, creds) as engine:
            with engine.connect() as conn:
                if db_type in _POSTGRES_LIKE:
                    conn.execute(text("SET default_transaction_read_only = on"))
                    # The default above only applies to transactions started
                    # afterwards. If the connection has already begun one
                    # (SQLAlchemy's default autobegin), it would stay
                    # read-write, so mark the current transaction read-only
                    # too. Under autocommit this is only a server warning.
                    conn.execute(text("SET TRANSACTION READ ONLY"))
                    conn.execute(
                        text(f"SET statement_timeout = {_QUERY_TIMEOUT_SECONDS * 1000}")
                    )
                result = conn.execute(text(compiled.sql), compiled.params)
                columns = list(result.keys())
                # Bound memory: never pull more than one row past the cap.
                fetched = result.mappings().fetchmany(_ROW_HARD_CAP + 1)
                rows = [dict(row) for row in fetched]
                return columns, rows
|
|