File size: 7,621 Bytes
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""Static SQL guard backed by sqlglot AST.

Layer 2 of the 3-layer execution defence (per docs/02_architecture_v2.md §5):
- DB role enforces read-only at session level.
- This module enforces shape: SELECT-only, single statement, no DML in CTEs,
  function allowlist (deny pg_sleep, generate_series above N, ATTACH, file
  reads), no schema escalation.
- Runner enforces statement_timeout + result-payload cap.

The guard returns a ValidationReport describing every detected issue rather
than raising on the first one — the caller (repair_once node) needs the full
list to build a useful error context for the LLM retry.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Final, cast

import sqlglot
from sqlglot import exp
from sqlglot.errors import ParseError, TokenError

from nl_sql.db.connection import Dialect

# Map our SQLAlchemy-style dialect names to sqlglot's canonical names.
# _safe_parse indexes this mapping and passes the result as `read=` to
# sqlglot.parse, so a dialect missing from it makes that lookup raise KeyError.
_SQLGLOT_DIALECT: Final[dict[Dialect, str]] = {
    "sqlite": "sqlite",
    "postgresql": "postgres",
}

# Functions banned outright, whatever their arguments. Mostly resource-abuse
# vectors. Entries are lowercase because _check_function_allowlist lowercases
# the parsed function name before testing membership.
BANNED_FUNCTIONS: Final[frozenset[str]] = frozenset(
    {
        "pg_sleep",
        "pg_read_file",
        "pg_read_binary_file",
        "pg_ls_dir",
        "lo_import",
        "lo_export",
        "dblink",
        "load_extension",
        "readfile",
    }
)

# generate_series is legitimate but easy to weaponize for DoS — capped at N.
# N is the largest permitted abs(stop - start); _check_generate_series_bounds
# reports `generate_series_too_large` when a call's literal span exceeds it.
GENERATE_SERIES_MAX_RANGE: Final[int] = 1_000_000

# Tables that queries must not reference. _check_table_denylist compares the
# lowercased `name` of every exp.Table node against this set and reports
# `denied_table` on a match.
# NOTE(review): no allowlist override is implemented in this module — confirm
# whether one is expected elsewhere.
DENIED_TABLES: Final[frozenset[str]] = frozenset(
    {
        "pg_user",
        "pg_authid",
        "pg_shadow",
        "pg_roles",
    }
)


@dataclass(frozen=True, slots=True)
class GuardViolation:
    """A single issue detected by the SQL guard (immutable)."""

    # Short identifier of the rule that fired, e.g. "parse_error" or
    # "banned_function".
    code: str
    # Human-readable explanation of the issue.
    message: str


@dataclass(slots=True)
class ValidationReport:
    sql: str
    dialect: Dialect
    violations: list[GuardViolation] = field(default_factory=list)
    parsed: exp.Expression | None = None

    @property
    def ok(self) -> bool:
        return not self.violations

    def add(self, code: str, message: str) -> None:
        self.violations.append(GuardViolation(code=code, message=message))


def validate_sql(sql: str, dialect: Dialect = "sqlite") -> ValidationReport:
    """Run all static checks against `sql` and collect the findings in a report.

    Args:
        sql: SQL text to validate.
        dialect: Dialect used for parsing; defaults to "sqlite".

    Returns:
        A ValidationReport. Validation stops early (with the violation already
        recorded) when parsing fails, when the text does not contain exactly
        one statement, or when that statement is empty. Otherwise
        `report.parsed` is set and every shape check runs, so the report lists
        all issues rather than only the first.
    """
    report = ValidationReport(sql=sql, dialect=dialect)

    statements = _safe_parse(sql, dialect, report)
    if statements is None:
        # _safe_parse has already recorded why parsing failed.
        return report

    statement_count = len(statements)
    if statement_count != 1:
        report.add("multi_statement", f"expected 1 statement, got {statement_count}")
        return report

    (statement,) = statements
    if statement is None:
        report.add("empty_statement", "parsed statement is empty")
        return report

    report.parsed = statement

    # Each check appends to the report and never aborts, so all of them run.
    checks = (
        _check_select_only,
        _check_no_dml_anywhere,
        _check_function_allowlist,
        _check_generate_series_bounds,
        _check_table_denylist,
        _check_no_attach_or_pragma,
    )
    for check in checks:
        check(statement, report)

    return report


def _safe_parse(
    sql: str, dialect: Dialect, report: ValidationReport
) -> list[exp.Expression | None] | None:
    """Parse `sql` with sqlglot, recording a violation instead of raising.

    Args:
        sql: Raw SQL text.
        dialect: Project dialect name, translated through `_SQLGLOT_DIALECT`
            (a dialect missing from that mapping raises KeyError).
        report: Receives a ``parse_error`` violation when the text cannot be
            tokenized or parsed.

    Returns:
        The list produced by ``sqlglot.parse`` (entries may be None), or None
        when tokenizing or parsing failed.
    """
    sqlglot_name = _SQLGLOT_DIALECT[dialect]
    try:
        # sqlglot.parse returns list[Optional[Expression]] but its internal
        # Expr alias trips strict mypy; cast keeps the public API typed.
        return cast("list[exp.Expression | None]", sqlglot.parse(sql, read=sqlglot_name))
    except (ParseError, TokenError) as exc:
        # The tokenizer runs before the parser and raises TokenError (e.g. for
        # an unterminated string literal). TokenError is not a ParseError
        # subclass, so it must be caught explicitly to keep the guard from
        # crashing on malformed input.
        report.add("parse_error", str(exc))
        return None


def _check_select_only(parsed: exp.Expression, report: ValidationReport) -> None:
    """Report `not_select` unless the root node is a Select, Union or Subquery.

    NOTE(review): whether EXCEPT / INTERSECT roots pass depends on how the
    installed sqlglot relates those classes to exp.Union — confirm the
    intended behaviour.
    """
    allowed_roots = (exp.Select, exp.Union, exp.Subquery)
    if isinstance(parsed, allowed_roots):
        return
    root_kind = type(parsed).__name__
    report.add("not_select", f"top-level statement must be SELECT/UNION; got {root_kind}")


def _check_no_dml_anywhere(parsed: exp.Expression, report: ValidationReport) -> None:
    """Report write/DDL/ATTACH/PRAGMA nodes found anywhere in the tree.

    Violations are grouped: all `dml_in_tree` entries first, then
    `ddl_in_tree`, then `attach_database` / `pragma_statement`.

    NOTE(review): SELECT ... INTO (exp.Into) is not matched by any of the
    type lists below — confirm whether it needs its own check.
    """
    dml_types = (exp.Insert, exp.Update, exp.Delete, exp.Merge, exp.Alter)
    ddl_types = (exp.Create, exp.Drop)

    # Walk the tree once and reuse the node list for the three scans below;
    # scanning per category keeps the violations grouped by kind.
    nodes = list(parsed.walk())

    for item in nodes:
        if isinstance(item, dml_types):
            report.add("dml_in_tree", f"{type(item).__name__} not permitted anywhere in query")

    # DDL via Create / Drop
    for item in nodes:
        if isinstance(item, ddl_types):
            report.add("ddl_in_tree", f"{type(item).__name__} not permitted")

    # SQLite-specific safety: ATTACH / PRAGMA arrive as their own AST nodes
    for item in nodes:
        if isinstance(item, exp.Attach):
            report.add("attach_database", "ATTACH DATABASE is denied")
        if isinstance(item, exp.Pragma):
            report.add("pragma_statement", "PRAGMA is denied at the query layer")


def _check_function_allowlist(parsed: exp.Expression, report: ValidationReport) -> None:
    """Report each banned function at most once per query.

    Two passes are made: exp.Anonymous nodes are matched by their lowercased
    name, then every exp.Func node is matched by its lowercased `key`.
    """
    reported: set[str] = set()

    def flag_if_banned(func_name: str) -> None:
        # Emit one violation per distinct banned name, across both passes.
        if func_name in BANNED_FUNCTIONS and func_name not in reported:
            reported.add(func_name)
            report.add("banned_function", f"function {func_name!r} is denied")

    # Anonymous covers most user-named functions (pg_sleep, generate_series, etc.).
    for call in parsed.find_all(exp.Anonymous):
        flag_if_banned((call.name or "").lower())

    # Some functions render as named exp.Func subclasses with key()-derived names.
    for candidate in parsed.walk():
        if isinstance(candidate, exp.Func):
            flag_if_banned((candidate.key or "").lower())


def _check_generate_series_bounds(parsed: exp.Expression, report: ValidationReport) -> None:
    """Report every generate_series call whose literal span exceeds the cap.

    Each node in the tree is handed to `_generate_series_bounds`; nodes for
    which it returns None (not a generate_series call, or bounds it cannot
    read) are skipped.
    """
    for candidate in parsed.walk():
        if (limits := _generate_series_bounds(candidate)) is None:
            continue
        width = abs(limits[1] - limits[0])
        if width <= GENERATE_SERIES_MAX_RANGE:
            continue
        report.add(
            "generate_series_too_large",
            f"generate_series span {width} exceeds cap {GENERATE_SERIES_MAX_RANGE}",
        )


def _generate_series_bounds(node: Any) -> tuple[int, int] | None:
    """Return the integer (start, stop) bounds of a generate_series call.

    Args:
        node: Any AST node. Both the typed exp.GenerateSeries node and an
            exp.Anonymous call named ``generate_series`` are recognised.

    Returns:
        ``(start, stop)`` when both bounds are integer literals (optionally
        negated), otherwise None — including when `node` is not a
        generate_series call at all.
    """
    start_node: Any
    end_node: Any
    if isinstance(node, exp.GenerateSeries):
        start_node = node.args.get("start")
        end_node = node.args.get("end")
    elif isinstance(node, exp.Anonymous) and (node.name or "").lower() == "generate_series":
        args = list(node.expressions)
        if len(args) < 2:
            # No explicit stop bound to measure.
            # NOTE(review): confirm how the target engines treat a missing
            # stop argument; such calls are left unchecked here.
            return None
        start_node, end_node = args[0], args[1]
    else:
        return None

    start = _signed_int_literal(start_node)
    stop = _signed_int_literal(end_node)
    if start is None or stop is None:
        return None
    return start, stop


def _signed_int_literal(node: Any) -> int | None:
    """Return the value of an integer literal, allowing leading unary minus signs."""
    sign = 1
    # A negative number such as -5 is parsed as exp.Neg wrapping the literal 5;
    # without unwrapping it, negative bounds would bypass the span check.
    while isinstance(node, exp.Neg):
        sign = -sign
        node = node.this
    if isinstance(node, exp.Literal) and node.is_int:
        return sign * int(node.this)
    return None


def _check_table_denylist(parsed: exp.Expression, report: ValidationReport) -> None:
    """Report `denied_table` for every exp.Table node whose name is in DENIED_TABLES.

    Names are lowercased before the lookup; a table referenced several times
    is reported once per reference.
    """
    table_names = ((ref.name or "").lower() for ref in parsed.find_all(exp.Table))
    for table_name in table_names:
        if table_name in DENIED_TABLES:
            report.add("denied_table", f"table {table_name!r} is on the denylist")


def _check_no_attach_or_pragma(parsed: exp.Expression, report: ValidationReport) -> None:
    """Placeholder check that intentionally records nothing.

    exp.Attach / exp.Pragma nodes are already reported by
    `_check_no_dml_anywhere`. This function is kept as a no-op seam so future
    dialects (e.g. DuckDB ATTACH variants) can plug in additional textual
    heuristics without adding a new top-level check.
    """
    return None