File size: 2,867 Bytes
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""End-to-end SQL execution: validate → execute under runtime limits → report.

Single entry point `execute_validated`. Pipeline nodes call this instead of
threading the guard + DB layers themselves.
"""

from __future__ import annotations

from dataclasses import dataclass

from sqlalchemy import Engine
from sqlalchemy.exc import OperationalError, SQLAlchemyError

from nl_sql.db.connection import Dialect, QueryResult, execute_readonly
from nl_sql.execution.errors import ExecutionErrorKind
from nl_sql.execution.guards import ValidationReport, validate_sql


@dataclass(frozen=True, slots=True)
class ExecutionOutcome:
    """Combined result + error taxonomy.

    Exactly one of `result` or `error_kind` is set. `validation` is always
    present so callers can render the AST-level diagnostics regardless of
    whether execution actually ran.
    """

    sql: str
    validation: ValidationReport
    result: QueryResult | None = None
    error_kind: ExecutionErrorKind | None = None
    error_message: str = ""

    @property
    def ok(self) -> bool:
        return self.result is not None and self.error_kind is None


def execute_validated(
    engine: Engine,
    sql: str,
    *,
    dialect: Dialect = "sqlite",
    statement_timeout_ms: int = 30_000,
    row_cap: int = 10_000,
) -> ExecutionOutcome:
    validation = validate_sql(sql, dialect=dialect)
    if not validation.ok:
        return ExecutionOutcome(
            sql=sql,
            validation=validation,
            error_kind=ExecutionErrorKind.INVALID_SQL,
            error_message="; ".join(v.message for v in validation.violations),
        )

    try:
        with execute_readonly(
            engine,
            sql,
            statement_timeout_ms=statement_timeout_ms,
            row_cap=row_cap,
        ) as result:
            if result.row_count == 0:
                return ExecutionOutcome(
                    sql=sql,
                    validation=validation,
                    result=result,
                    error_kind=ExecutionErrorKind.EMPTY_RESULT,
                    error_message="query returned 0 rows",
                )
            return ExecutionOutcome(sql=sql, validation=validation, result=result)
    except OperationalError as exc:
        kind = (
            ExecutionErrorKind.EXECUTION_TIMEOUT
            if "timeout" in str(exc).lower() or "interrupted" in str(exc).lower()
            else ExecutionErrorKind.EXECUTION_FAILED
        )
        return ExecutionOutcome(
            sql=sql,
            validation=validation,
            error_kind=kind,
            error_message=str(exc),
        )
    except SQLAlchemyError as exc:
        return ExecutionOutcome(
            sql=sql,
            validation=validation,
            error_kind=ExecutionErrorKind.EXECUTION_FAILED,
            error_message=str(exc),
        )