"""Node: run the static SQL guard (`validate_sql`) on the current candidate.

If the guard rejects the SQL (or no SQL was generated at all), sets `outcome`
to an ``ExecutionOutcome`` with ``INVALID_SQL``. Otherwise sets `outcome` to a
clean ``ExecutionOutcome`` (no error kind) that carries the passing validation
report, so the execute node can run. Failure here routes to ``repair_once`` if
no repair has been tried yet, or to ``deterministic_format`` (with the
validation error visible) if the single retry has already been used.
"""
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Callable |
|
|
| from nl_sql.agent.state import PipelineState |
| from nl_sql.db.connection import Dialect |
| from nl_sql.execution.errors import ExecutionErrorKind |
| from nl_sql.execution.guards import GuardViolation, ValidationReport, validate_sql |
| from nl_sql.execution.runner import ExecutionOutcome |
|
|
|
|
def _invalid_sql_update(
    *,
    sql: str,
    report: ValidationReport,
    message: str,
    trace: list[dict[str, object]],
    trace_entry: dict[str, object],
) -> PipelineState:
    """Return the partial state update for an INVALID_SQL rejection.

    Appends ``trace_entry`` to ``trace`` (the caller passes its own copy) and
    mirrors ``message`` into ``last_error``, ``error_kind`` and
    ``error_message`` next to the failing ``ExecutionOutcome``.
    """
    trace.append(trace_entry)
    outcome = ExecutionOutcome(
        sql=sql,
        validation=report,
        error_kind=ExecutionErrorKind.INVALID_SQL,
        error_message=message,
    )
    return {
        "outcome": outcome,
        "last_error": message,
        "error_kind": ExecutionErrorKind.INVALID_SQL,
        "error_message": message,
        "trace": trace,
    }


def make_validate_node() -> Callable[[PipelineState], PipelineState]:
    """Build the ``validate`` pipeline node.

    The returned callable reads ``generated``, ``dialect`` and ``trace`` from
    the state and returns a partial update containing ``outcome`` and
    ``trace``; on rejection it also sets ``last_error``, ``error_kind`` and
    ``error_message``.
    """

    def node(state: PipelineState) -> PipelineState:
        """Validate ``state["generated"].sql`` with ``validate_sql``.

        Returns:
            On success, an ``outcome`` with no error kind that carries the
            validation report. On failure (missing/blank SQL or guard
            violations), an ``INVALID_SQL`` outcome plus the error fields.
        """
        generated = state.get("generated")
        # A missing key and an explicit None both fall back to SQLite.
        dialect: Dialect = state.get("dialect") or "sqlite"
        # Copy so the incoming state's trace list is never mutated in place.
        trace = list(state.get("trace") or [])

        # Whitespace-only output is as unusable as no output at all.
        if generated is None or not (generated.sql or "").strip():
            message = "generate_sql produced no SQL"
            report = ValidationReport(sql="", dialect=dialect)
            report.add("no_sql", message)
            return _invalid_sql_update(
                sql="",
                report=report,
                message=message,
                trace=trace,
                trace_entry={"node": "validate", "ok": False, "reason": "no_sql"},
            )

        report = validate_sql(generated.sql, dialect=dialect)
        if not report.ok:
            # Never surface an empty error string, even if a rejecting report
            # carries no violation entries.
            message = (
                "; ".join(v.message for v in report.violations)
                or "SQL failed validation"
            )
            return _invalid_sql_update(
                sql=generated.sql,
                report=report,
                message=message,
                trace=trace,
                trace_entry={
                    "node": "validate",
                    "ok": False,
                    "violations": [v.code for v in report.violations],
                },
            )

        trace.append({"node": "validate", "ok": True})
        update = {
            "outcome": ExecutionOutcome(sql=generated.sql, validation=report),
            "trace": trace,
        }
        # This node may run again after a repair. Assuming the graph merges
        # this partial update into the existing state, a passing
        # re-validation must explicitly reset the previous attempt's error
        # fields, or a stale INVALID_SQL leaks downstream. Keys are only
        # reset when already set, so a first-try success yields the same
        # update as before. `last_error` is left untouched, as originally.
        for key in ("error_kind", "error_message"):
            if state.get(key) is not None:
                update[key] = None
        return update

    return node
|
|
|
|
| |
| |
# Public names of this module. ``GuardViolation`` comes from
# ``nl_sql.execution.guards``; this module imports it but does not otherwise
# use it, so listing it here re-exports it.
__all__ = [
    "GuardViolation",
    "make_validate_node",
]
|
|