File size: 2,909 Bytes
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
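"""Node: run the static SQL guard (``validate_sql``) on the current candidate.

Always sets ``outcome``.

- If no SQL was generated, or the guard rejects the SQL, ``outcome`` carries
  ``error_kind=INVALID_SQL`` and the error message. The message is also copied
  into ``last_error`` and ``error_message``.
- Otherwise ``outcome`` carries only the SQL and its validation report, with
  no error, so the execute node can run.

Failure here routes to ``repair_once`` if no repair has been tried yet. If we
already burned the single retry, it routes to ``deterministic_format``, with
the validation error visible.
"""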

from __future__ import annotations

from collections.abc import Callable

from nl_sql.agent.state import PipelineState
from nl_sql.db.connection import Dialect
from nl_sql.execution.errors import ExecutionErrorKind
from nl_sql.execution.guards import GuardViolation, ValidationReport, validate_sql
from nl_sql.execution.runner import ExecutionOutcome


def make_validate_node() -> Callable[[PipelineState], PipelineState]:
    """Create the ``validate`` pipeline node.

    The returned callable reads ``generated``, ``dialect`` (default
    ``"sqlite"``) and ``trace`` from the state. It returns a partial state
    update:

    * no SQL generated -> an INVALID_SQL ``outcome`` whose report holds a
      single ``no_sql`` entry;
    * ``validate_sql`` rejects -> an INVALID_SQL ``outcome`` whose message is
      the violation messages joined with ``"; "``;
    * otherwise -> an error-free ``outcome`` carrying the validation report.

    Failure updates also set ``last_error``, ``error_kind`` and
    ``error_message`` (both message keys hold the same text). Every update
    includes ``trace``, which is a copy of the incoming trace extended with one
    ``{"node": "validate", ...}`` entry. The incoming list is never mutated.
    """

    def _reject(
        sql: str,
        report: ValidationReport,
        message: str,
        trace: list[dict[str, object]],
        detail: dict[str, object],
    ) -> PipelineState:
        # Both failure paths share this update shape; only the trace detail
        # ("reason" vs "violations") and the message differ.
        outcome = ExecutionOutcome(
            sql=sql,
            validation=report,
            error_kind=ExecutionErrorKind.INVALID_SQL,
            error_message=message,
        )
        trace.append({"node": "validate", "ok": False, **detail})
        return {
            "outcome": outcome,
            "last_error": message,
            "error_kind": ExecutionErrorKind.INVALID_SQL,
            "error_message": message,
            "trace": trace,
        }

    def node(state: PipelineState) -> PipelineState:
        candidate = state.get("generated")
        dialect: Dialect = state.get("dialect", "sqlite")
        # Copy so the caller's trace list is not mutated in place.
        trace = list(state.get("trace") or [])

        sql = None if candidate is None else candidate.sql
        if not sql:
            missing_msg = "generate_sql produced no SQL"
            empty_report = ValidationReport(sql="", dialect=dialect)
            empty_report.add("no_sql", missing_msg)
            return _reject("", empty_report, missing_msg, trace, {"reason": "no_sql"})

        report = validate_sql(sql, dialect=dialect)
        if not report.ok:
            reasons = "; ".join(v.message for v in report.violations)
            codes = [v.code for v in report.violations]
            return _reject(sql, report, reasons, trace, {"violations": codes})

        trace.append({"node": "validate", "ok": True})
        passed = ExecutionOutcome(sql=sql, validation=report)
        return {"outcome": passed, "trace": trace}

    return node


# Public API of this module.
# GuardViolation is not used by this module's own code. It is re-exported
# solely so direct unit tests of this node don't need a separate import from
# nl_sql.execution.guards. Listing it here also signals to linters that the
# otherwise-unused import is intentional.
__all__ = ["GuardViolation", "make_validate_node"]