nl-sql / src /nl_sql /agent /nodes /validate.py
liovina's picture
Deploy NL_SQL HEAD to HF Space
942050b verified
"""Node: run the static SQL guard (`validate_sql`) on the current candidate.
Sets `outcome` with INVALID_SQL if the guard rejects; otherwise returns an
`outcome` holding only the SQL and its passing validation report (no error
fields) so the execute node can run. Failure here routes to
``repair_once`` if no repair has been tried yet, or to ``deterministic_format``
(with the validation error visible) if we already burned the single retry.
"""
from __future__ import annotations
from collections.abc import Callable
from nl_sql.agent.state import PipelineState
from nl_sql.db.connection import Dialect
from nl_sql.execution.errors import ExecutionErrorKind
from nl_sql.execution.guards import GuardViolation, ValidationReport, validate_sql
from nl_sql.execution.runner import ExecutionOutcome
def make_validate_node() -> Callable[[PipelineState], PipelineState]:
    """Build the pipeline node that runs the static SQL guard.

    The returned callable reads ``generated``, ``dialect`` (defaulting to
    ``"sqlite"``) and ``trace`` from the state and returns a partial state
    update in one of three shapes:

    * no candidate, or a candidate whose SQL is empty: an ``INVALID_SQL``
      outcome whose report carries a single ``no_sql`` violation;
    * the guard accepts the SQL: an outcome holding the SQL and its report,
      together with the extended trace;
    * the guard rejects the SQL: an ``INVALID_SQL`` outcome whose message is
      every violation message joined with ``"; "``.

    Both failure shapes additionally set ``last_error``, ``error_kind`` and
    ``error_message``.
    """

    def node(state: PipelineState) -> PipelineState:
        candidate = state.get("generated")
        dialect: Dialect = state.get("dialect", "sqlite")
        # Copy the trace so appends never touch the list held by the
        # incoming state.
        history = [*(state.get("trace") or [])]

        if candidate is not None and candidate.sql:
            report = validate_sql(candidate.sql, dialect=dialect)

            if report.ok:
                history.append({"node": "validate", "ok": True})
                accepted = ExecutionOutcome(sql=candidate.sql, validation=report)
                return {"outcome": accepted, "trace": history}

            # The guard rejected the statement: surface every violation
            # message in the error text and every violation code in the trace.
            message = "; ".join(issue.message for issue in report.violations)
            rejected = ExecutionOutcome(
                sql=candidate.sql,
                validation=report,
                error_kind=ExecutionErrorKind.INVALID_SQL,
                error_message=message,
            )
            codes = [issue.code for issue in report.violations]
            history.append({"node": "validate", "ok": False, "violations": codes})
            return {
                "outcome": rejected,
                "last_error": message,
                "error_kind": ExecutionErrorKind.INVALID_SQL,
                "error_message": message,
                "trace": history,
            }

        # Nothing to validate: record a synthetic "no_sql" violation against
        # an empty statement so the failure has the same shape as a rejection.
        empty_report = ValidationReport(sql="", dialect=dialect)
        empty_report.add("no_sql", "generate_sql produced no SQL")
        missing = ExecutionOutcome(
            sql="",
            validation=empty_report,
            error_kind=ExecutionErrorKind.INVALID_SQL,
            error_message="generate_sql produced no SQL",
        )
        history.append({"node": "validate", "ok": False, "reason": "no_sql"})
        reason = missing.error_message
        return {
            "outcome": missing,
            "last_error": reason,
            "error_kind": ExecutionErrorKind.INVALID_SQL,
            "error_message": reason,
            "trace": history,
        }

    return node
# GuardViolation is re-exported solely so direct unit tests of this node
# don't need a separate import.
__all__ = ["GuardViolation", "make_validate_node"]