File size: 7,829 Bytes
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""DbExecutor β€” runs a compiled IR against a user's external SQL database.

Pipeline:
  IR β†’ SqlCompiler.compile()  β†’  CompiledSql(sql, params)
       ↓
  sqlglot guard  (defense-in-depth: SELECT-only, no DML / DDL)
       ↓
  resolve creds (catalog.location_ref β†’ dbclient://{client_id} β†’ DatabaseClient
                 row β†’ Fernet decrypt)
       ↓
  asyncio.to_thread(_run_sync)
    β”” db_pipeline_service.engine_scope(db_type, creds)
       β”” session-level: default_transaction_read_only + statement_timeout=30s
                        (postgres / supabase only)
       β”” engine.execute(text(sql), params)
       ↓
  QueryResult (always returned β€” errors populate `.error`, never raised)
"""

from __future__ import annotations

import asyncio
import time
from typing import Any

import sqlglot
import sqlglot.expressions as exp
from sqlalchemy import text

from ...catalog.models import Catalog, Source
from ...database_client.database_client_service import database_client_service
from ...db.postgres.connection import AsyncSessionLocal
from ...middlewares.logging import get_logger
from ...pipeline.db_pipeline import db_pipeline_service
from ...utils.db_credential_encryption import decrypt_credentials_dict
from ..compiler.sql import CompiledSql, SqlCompiler
from ..ir.models import QueryIR
from .base import BaseExecutor, QueryResult

logger = get_logger("db_executor")

# Per-query time budget in seconds: bounds the asyncio wait in DbExecutor.run
# and, for postgres-like sources, the server-side statement_timeout.
_QUERY_TIMEOUT_SECONDS = 30
# Upper bound on rows handed back to callers, applied regardless of any LIMIT.
_ROW_HARD_CAP = 10000
# Scheme that Source.location_ref must carry for DB-backed sources.
_DBCLIENT_PREFIX = "dbclient://"
# db_type values that receive the Postgres session settings before querying.
_POSTGRES_LIKE = frozenset(("postgres", "supabase"))


class DbExecutor(BaseExecutor):
    """Executes compiled SQL on the user's registered DB.

    Constructed once per query with the user's catalog. The catalog is the
    source of truth for identifiers; the executor never touches the user's
    DB metadata at execution time.
    """

    def __init__(self, catalog: Catalog) -> None:
        self._catalog = catalog
        self._compiler = SqlCompiler(catalog)

    async def run(self, ir: QueryIR) -> QueryResult:
        """Compile, guard and execute ``ir`` against its source database.

        Never raises: every failure (unknown source, wrong source type, guard
        rejection, missing/inactive/foreign DatabaseClient, driver error,
        timeout) is logged and reported through ``QueryResult.error``.

        Args:
            ir: Query IR whose ``source_id`` must name a catalog source with
                ``source_type == "schema"``.

        Returns:
            A ``QueryResult`` holding at most ``_ROW_HARD_CAP`` rows;
            ``truncated`` is True when the database produced more than that.
        """
        started = time.perf_counter()
        table_name = ""
        source_name = ""
        try:
            source = self._find_source(ir.source_id)
            source_name = source.name
            table_name = next(
                (t.name for t in source.tables if t.table_id == ir.table_id), ""
            )
            if source.source_type != "schema":
                raise ValueError(
                    f"DbExecutor cannot run on source_type={source.source_type!r}; "
                    "expected 'schema'"
                )

            compiled = self._compiler.compile(ir)
            self._sqlglot_guard(compiled.sql)

            client_id = self._parse_client_id(source.location_ref)
            client = await self._fetch_client(client_id)
            if client.user_id != self._catalog.user_id:
                raise PermissionError(
                    f"DatabaseClient {client_id!r} owner mismatch "
                    f"(client.user_id != catalog.user_id)"
                )
            creds = decrypt_credentials_dict(client.credentials)

            try:
                columns, rows = await asyncio.wait_for(
                    asyncio.to_thread(self._run_sync, client.db_type, creds, compiled),
                    timeout=_QUERY_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                # wait_for raises a message-less TimeoutError, so str(e) would be
                # "" and the failure would be reported with an empty error.
                # The worker thread itself is not cancelled here; only the
                # server-side statement_timeout (postgres-like) stops the query.
                raise TimeoutError(
                    f"query exceeded {_QUERY_TIMEOUT_SECONDS}s timeout"
                ) from None

            # _run_sync returns at most _ROW_HARD_CAP + 1 rows; the extra row
            # only signals that the result was cut off.
            truncated = len(rows) > _ROW_HARD_CAP
            capped = rows[:_ROW_HARD_CAP]
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            logger.info(
                "db query complete",
                source_id=ir.source_id,
                rows=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                columns=columns,
                rows=capped,
                row_count=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )

        except Exception as e:
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            # Some exceptions stringify to "" -- fall back to the type name so
            # a failed query never carries an empty (falsy) error.
            message = str(e) or type(e).__name__
            logger.error(
                "db executor failed",
                source_id=ir.source_id,
                error=message,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                elapsed_ms=elapsed_ms,
                error=message,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _find_source(self, source_id: str) -> Source:
        """Return the catalog source with ``source_id``.

        Raises:
            ValueError: if no source in the catalog has that id.
        """
        for s in self._catalog.sources:
            if s.source_id == source_id:
                return s
        raise ValueError(f"source_id {source_id!r} not in catalog")

    @staticmethod
    def _parse_client_id(location_ref: str) -> str:
        """Extract the client id from a ``dbclient://{client_id}`` reference.

        Raises:
            ValueError: if the prefix is missing or the id part is empty.
        """
        if not location_ref.startswith(_DBCLIENT_PREFIX):
            raise ValueError(
                f"DbExecutor expects 'dbclient://...' location_ref, got {location_ref!r}"
            )
        client_id = location_ref[len(_DBCLIENT_PREFIX):]
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")
        return client_id

    @staticmethod
    async def _fetch_client(client_id: str) -> Any:
        """Load the DatabaseClient row and require ``status == "active"``.

        Raises:
            ValueError: if the client does not exist or is not active.
        """
        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
        if client is None:
            raise ValueError(f"DatabaseClient {client_id!r} not found")
        if client.status != "active":
            raise ValueError(
                f"DatabaseClient {client_id!r} is not active "
                f"(status={client.status!r})"
            )
        return client

    @staticmethod
    def _sqlglot_guard(sql: str) -> None:
        """Defense-in-depth: ensure the compiled SQL is a single SELECT statement.

        The compiler is already deterministic and only constructs SELECTs from
        validated IR, but this guard catches any future bug that could leak
        DML/DDL through.

        Raises:
            ValueError: on parse failure, more than one statement, a non-SELECT
                root, or any forbidden node (DML, DDL, ``SELECT ... INTO``).
        """
        try:
            # parse_one() returns only the FIRST statement, so "SELECT 1; DROP ..."
            # would slip through; parse() exposes every statement.
            statements = [s for s in sqlglot.parse(sql, read="postgres") if s is not None]
        except sqlglot.errors.ParseError as e:
            raise ValueError(f"compiled SQL failed to parse: {e}") from e
        if len(statements) != 1:
            raise ValueError(
                f"compiled SQL must contain exactly one statement (got {len(statements)})"
            )
        parsed = statements[0]
        if not isinstance(parsed, exp.Select):
            raise ValueError(
                f"compiled SQL is not a SELECT (got {type(parsed).__name__})"
            )
        # exp.Into covers Postgres "SELECT ... INTO new_table", which creates a
        # table despite being rooted at a SELECT.
        forbidden = (
            exp.Insert,
            exp.Update,
            exp.Delete,
            exp.Drop,
            exp.Alter,
            exp.Create,
            exp.Merge,
            exp.Into,
        )
        node = next(parsed.find_all(*forbidden), None)
        if node is not None:
            raise ValueError(
                f"compiled SQL contains forbidden DML/DDL: {type(node).__name__}"
            )

    @staticmethod
    def _run_sync(db_type: str, creds: dict, compiled: CompiledSql) -> tuple[list[str], list[dict]]:
        """Blocking execution of ``compiled`` inside an engine scope.

        Called via ``asyncio.to_thread`` from ``run``. At most
        ``_ROW_HARD_CAP + 1`` rows are fetched, so the caller can detect
        truncation without turning an unbounded result set into dicts.

        Returns:
            ``(columns, rows)`` where each row is a ``{column: value}`` dict.
        """
        with db_pipeline_service.engine_scope(db_type, creds) as engine:
            with engine.connect() as conn:
                if db_type in _POSTGRES_LIKE:
                    # Unless engine_scope configures AUTOCOMMIT (not visible here),
                    # the connection is already inside a transaction after its
                    # first execute. default_transaction_read_only only affects
                    # *new* transactions, so it would not cover the query below.
                    # SET TRANSACTION READ ONLY makes the current transaction
                    # read-only. The session default still covers autocommit,
                    # where SET TRANSACTION is only a warning.
                    conn.execute(text("SET TRANSACTION READ ONLY"))
                    conn.execute(text("SET default_transaction_read_only = on"))
                    # per-statement timeout (ms)
                    conn.execute(
                        text(f"SET statement_timeout = {_QUERY_TIMEOUT_SECONDS * 1000}")
                    )
                result = conn.execute(text(compiled.sql), compiled.params)
                columns = list(result.keys())
                rows = [
                    dict(row)
                    for row in result.mappings().fetchmany(_ROW_HARD_CAP + 1)
                ]
                return columns, rows