nl-sql / src /nl_sql /execution /guards.py
liovina's picture
Deploy NL_SQL HEAD to HF Space
942050b verified
"""Static SQL guard backed by sqlglot AST.
Layer 2 of the 3-layer execution defence (per docs/02_architecture_v2.md §5):
- DB role enforces read-only at session level.
- This module enforces shape: SELECT-only, single statement, no DML in CTEs,
function allowlist (deny pg_sleep, generate_series above N, ATTACH, file
reads), no schema escalation.
- Runner enforces statement_timeout + result-payload cap.
The guard returns a ValidationReport describing every detected issue rather
than raising on the first one — the caller (repair_once node) needs the full
list to build a useful error context for the LLM retry.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Final, cast
import sqlglot
from sqlglot import exp
from sqlglot.errors import ParseError
from nl_sql.db.connection import Dialect
# Map our SQLAlchemy-style dialect names to sqlglot's canonical names.
_SQLGLOT_DIALECT: Final[dict[Dialect, str]] = {
"sqlite": "sqlite",
"postgresql": "postgres",
}
# Functions banned outright. Mostly resource-abuse vectors.
BANNED_FUNCTIONS: Final[frozenset[str]] = frozenset(
{
"pg_sleep",
"pg_read_file",
"pg_read_binary_file",
"pg_ls_dir",
"lo_import",
"lo_export",
"dblink",
"load_extension",
"readfile",
}
)
# generate_series is legitimate but easy to weaponize for DoS — capped at N.
GENERATE_SERIES_MAX_RANGE: Final[int] = 1_000_000
# Reading from these schemas/tables is denied unless on an explicit allowlist.
DENIED_TABLES: Final[frozenset[str]] = frozenset(
{
"pg_user",
"pg_authid",
"pg_shadow",
"pg_roles",
}
)
@dataclass(frozen=True, slots=True)
class GuardViolation:
code: str
message: str
@dataclass(slots=True)
class ValidationReport:
sql: str
dialect: Dialect
violations: list[GuardViolation] = field(default_factory=list)
parsed: exp.Expression | None = None
@property
def ok(self) -> bool:
return not self.violations
def add(self, code: str, message: str) -> None:
self.violations.append(GuardViolation(code=code, message=message))
def validate_sql(sql: str, dialect: Dialect = "sqlite") -> ValidationReport:
"""Run all static checks against `sql`. Always returns a report."""
report = ValidationReport(sql=sql, dialect=dialect)
parsed_list = _safe_parse(sql, dialect, report)
if parsed_list is None:
return report
if len(parsed_list) != 1:
report.add("multi_statement", f"expected 1 statement, got {len(parsed_list)}")
return report
parsed = parsed_list[0]
if parsed is None:
report.add("empty_statement", "parsed statement is empty")
return report
report.parsed = parsed
_check_select_only(parsed, report)
_check_no_dml_anywhere(parsed, report)
_check_function_allowlist(parsed, report)
_check_generate_series_bounds(parsed, report)
_check_table_denylist(parsed, report)
_check_no_attach_or_pragma(parsed, report)
return report
def _safe_parse(
sql: str, dialect: Dialect, report: ValidationReport
) -> list[exp.Expression | None] | None:
sqlglot_name = _SQLGLOT_DIALECT[dialect]
try:
# sqlglot.parse returns list[Optional[Expression]] but its internal
# Expr alias trips strict mypy; cast keeps the public API typed.
return cast("list[exp.Expression | None]", sqlglot.parse(sql, read=sqlglot_name))
except ParseError as exc:
report.add("parse_error", str(exc))
return None
def _check_select_only(parsed: exp.Expression, report: ValidationReport) -> None:
if not isinstance(parsed, exp.Select | exp.Union | exp.Subquery):
report.add(
"not_select",
f"top-level statement must be SELECT/UNION; got {type(parsed).__name__}",
)
def _check_no_dml_anywhere(parsed: exp.Expression, report: ValidationReport) -> None:
forbidden = (exp.Insert, exp.Update, exp.Delete, exp.Merge, exp.Alter)
for node in parsed.walk():
if isinstance(node, forbidden):
report.add("dml_in_tree", f"{type(node).__name__} not permitted anywhere in query")
# DDL via Create / Drop
for node in parsed.walk():
if isinstance(node, exp.Create | exp.Drop):
report.add("ddl_in_tree", f"{type(node).__name__} not permitted")
# SQLite-specific safety: ATTACH / PRAGMA arrive as their own AST nodes
for node in parsed.walk():
if isinstance(node, exp.Attach):
report.add("attach_database", "ATTACH DATABASE is denied")
if isinstance(node, exp.Pragma):
report.add("pragma_statement", "PRAGMA is denied at the query layer")
def _check_function_allowlist(parsed: exp.Expression, report: ValidationReport) -> None:
seen: set[str] = set()
# Anonymous covers most user-named functions (pg_sleep, generate_series, etc.).
for anon in parsed.find_all(exp.Anonymous):
anon_name = (anon.name or "").lower()
if anon_name in BANNED_FUNCTIONS and anon_name not in seen:
seen.add(anon_name)
report.add("banned_function", f"function {anon_name!r} is denied")
# Some functions render as named exp.Func subclasses with key()-derived names.
for node in parsed.walk():
if isinstance(node, exp.Func):
key = (node.key or "").lower()
if key in BANNED_FUNCTIONS and key not in seen:
seen.add(key)
report.add("banned_function", f"function {key!r} is denied")
def _check_generate_series_bounds(parsed: exp.Expression, report: ValidationReport) -> None:
# Postgres dialect produces a typed GenerateSeries node; SQLite-style or
# unrecognized parses fall back to Anonymous("generate_series", ...).
for node in parsed.walk():
bounds = _generate_series_bounds(node)
if bounds is None:
continue
start, stop = bounds
span = abs(stop - start)
if span > GENERATE_SERIES_MAX_RANGE:
report.add(
"generate_series_too_large",
f"generate_series span {span} exceeds cap {GENERATE_SERIES_MAX_RANGE}",
)
def _generate_series_bounds(node: Any) -> tuple[int, int] | None:
start_node: Any
end_node: Any
if isinstance(node, exp.GenerateSeries):
start_node = node.args.get("start")
end_node = node.args.get("end")
elif isinstance(node, exp.Anonymous) and (node.name or "").lower() == "generate_series":
args = list(node.expressions)
if len(args) < 2:
return None
start_node, end_node = args[0], args[1]
else:
return None
if not (isinstance(start_node, exp.Literal) and start_node.is_int):
return None
if not (isinstance(end_node, exp.Literal) and end_node.is_int):
return None
return int(start_node.this), int(end_node.this)
def _check_table_denylist(parsed: exp.Expression, report: ValidationReport) -> None:
for table in parsed.find_all(exp.Table):
name = (table.name or "").lower()
if name in DENIED_TABLES:
report.add("denied_table", f"table {name!r} is on the denylist")
def _check_no_attach_or_pragma(parsed: exp.Expression, report: ValidationReport) -> None:
# Already covered by _check_no_dml_anywhere via exp.Attach / exp.Pragma nodes;
# kept as a no-op seam so future dialects (e.g. DuckDB ATTACH variants) can plug
# additional textual heuristics without adding a new top-level check.
return