"""DbExecutor — runs a compiled IR against a user's external SQL database. Pipeline: IR → SqlCompiler.compile() → CompiledSql(sql, params) ↓ sqlglot guard (defense-in-depth: SELECT-only, no DML / DDL) ↓ resolve creds (catalog.location_ref → dbclient://{client_id} → DatabaseClient row → Fernet decrypt) ↓ asyncio.to_thread(_run_sync) └ db_pipeline_service.engine_scope(db_type, creds) └ session-level: default_transaction_read_only + statement_timeout=30s (postgres / supabase only) └ engine.execute(text(sql), params) ↓ QueryResult (always returned — errors populate `.error`, never raised) """ from __future__ import annotations import asyncio import time from typing import Any import sqlglot import sqlglot.expressions as exp from sqlalchemy import text from ...catalog.models import Catalog, Source from ...database_client.database_client_service import database_client_service from ...db.postgres.connection import AsyncSessionLocal from ...middlewares.logging import get_logger from ...pipeline.db_pipeline import db_pipeline_service from ...utils.db_credential_encryption import decrypt_credentials_dict from ..compiler.sql import CompiledSql, SqlCompiler from ..ir.models import QueryIR from .base import BaseExecutor, QueryResult logger = get_logger("db_executor") _QUERY_TIMEOUT_SECONDS = 30 _ROW_HARD_CAP = 10_000 # belt-and-suspenders cap regardless of LIMIT _DBCLIENT_PREFIX = "dbclient://" _POSTGRES_LIKE = frozenset({"postgres", "supabase"}) class DbExecutor(BaseExecutor): """Executes compiled SQL on the user's registered DB. Constructed once per query with the user's catalog. The catalog is the source of truth for identifiers; the executor never touches the user's DB metadata at execution time. """ def __init__(self, catalog: Catalog) -> None: self._catalog = catalog self._compiler = SqlCompiler(catalog) async def run(self, ir: QueryIR) -> QueryResult: started = time.perf_counter() table_name = "" source_name = "" try: source = self._find_source(ir.source_id) source_name = source.name table_name = next( (t.name for t in source.tables if t.table_id == ir.table_id), "" ) if source.source_type != "schema": raise ValueError( f"DbExecutor cannot run on source_type={source.source_type!r}; " "expected 'schema'" ) compiled = self._compiler.compile(ir) self._sqlglot_guard(compiled.sql) client_id = self._parse_client_id(source.location_ref) client = await self._fetch_client(client_id) if client.user_id != self._catalog.user_id: raise PermissionError( f"DatabaseClient {client_id!r} owner mismatch " f"(client.user_id != catalog.user_id)" ) creds = decrypt_credentials_dict(client.credentials) columns, rows = await asyncio.wait_for( asyncio.to_thread(self._run_sync, client.db_type, creds, compiled), timeout=_QUERY_TIMEOUT_SECONDS, ) truncated = len(rows) > _ROW_HARD_CAP capped = rows[:_ROW_HARD_CAP] elapsed_ms = int((time.perf_counter() - started) * 1000) logger.info( "db query complete", source_id=ir.source_id, rows=len(capped), truncated=truncated, elapsed_ms=elapsed_ms, ) return QueryResult( source_id=ir.source_id, backend="sql", columns=columns, rows=capped, row_count=len(capped), truncated=truncated, elapsed_ms=elapsed_ms, table_id=ir.table_id, table_name=table_name, source_name=source_name, ) except Exception as e: elapsed_ms = int((time.perf_counter() - started) * 1000) logger.error( "db executor failed", source_id=ir.source_id, error=str(e), elapsed_ms=elapsed_ms, ) return QueryResult( source_id=ir.source_id, backend="sql", elapsed_ms=elapsed_ms, error=str(e), table_id=ir.table_id, table_name=table_name, source_name=source_name, ) # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ def _find_source(self, source_id: str) -> Source: for s in self._catalog.sources: if s.source_id == source_id: return s raise ValueError(f"source_id {source_id!r} not in catalog") @staticmethod def _parse_client_id(location_ref: str) -> str: if not location_ref.startswith(_DBCLIENT_PREFIX): raise ValueError( f"DbExecutor expects 'dbclient://...' location_ref, got {location_ref!r}" ) client_id = location_ref[len(_DBCLIENT_PREFIX):] if not client_id: raise ValueError("location_ref is missing client_id after 'dbclient://'") return client_id @staticmethod async def _fetch_client(client_id: str) -> Any: async with AsyncSessionLocal() as session: client = await database_client_service.get(session, client_id) if client is None: raise ValueError(f"DatabaseClient {client_id!r} not found") if client.status != "active": raise ValueError( f"DatabaseClient {client_id!r} is not active " f"(status={client.status!r})" ) return client @staticmethod def _sqlglot_guard(sql: str) -> None: """Defense-in-depth: ensure the compiled SQL is a SELECT statement. The compiler is already deterministic and only constructs SELECTs from validated IR, but this guard catches any future bug that could leak DML/DDL through. """ try: parsed = sqlglot.parse_one(sql, read="postgres") except sqlglot.errors.ParseError as e: raise ValueError(f"compiled SQL failed to parse: {e}") from e if not isinstance(parsed, exp.Select): raise ValueError( f"compiled SQL is not a SELECT (got {type(parsed).__name__})" ) forbidden = (exp.Insert, exp.Update, exp.Delete, exp.Drop, exp.Alter) for node in parsed.find_all(forbidden): raise ValueError( f"compiled SQL contains forbidden DML/DDL: {type(node).__name__}" ) @staticmethod def _run_sync(db_type: str, creds: dict, compiled: CompiledSql) -> tuple[list[str], list[dict]]: with db_pipeline_service.engine_scope(db_type, creds) as engine: with engine.connect() as conn: if db_type in _POSTGRES_LIKE: # session-level read-only + per-statement timeout (ms) conn.execute(text("SET default_transaction_read_only = on")) conn.execute( text(f"SET statement_timeout = {_QUERY_TIMEOUT_SECONDS * 1000}") ) result = conn.execute(text(compiled.sql), compiled.params) columns = list(result.keys()) rows = [dict(row) for row in result.mappings()] return columns, rows