File size: 4,030 Bytes
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Node: execute the validated SQL via the read-only runner.

The validate node has already produced a `ValidationReport`; we re-validate
inside `execute_validated` (cheap) so this node still works if it's ever
called directly without the validate predecessor (e.g. in unit tests).
"""

from __future__ import annotations

from collections.abc import Callable

from nl_sql.agent.state import PipelineState
from nl_sql.db.connection import Dialect
from nl_sql.db.registry import DatabaseRegistry, get_default_registry
from nl_sql.execution.runner import execute_validated


# Repair hint that replaces the raw error message in `last_error` when a query
# runs cleanly but returns zero rows and the state sets `verify_retry_on_empty`.
# Empty rows almost always come from a filter-value miss, not a structural
# error, so the hint steers the repair_once prompt at the WHERE clause.
_EMPTY_RESULT_HINT = (
    "Your SQL ran successfully but returned 0 rows. The schema "
    "and joins look fine; the most likely cause is a wrong "
    "filter value. Re-examine the WHERE clause: check exact "
    "case and spelling of string literals (compare against the "
    "sample values in the schema cards), consider whether the "
    "match needs LIKE '%value%' instead of `= 'value'`, and "
    "verify NULL handling (`IS NULL` vs `= NULL`). If the "
    "question is open-ended, try widening the filter; do not "
    "narrow it further."
)


def make_execute_node(
    *,
    registry: DatabaseRegistry | None = None,
    statement_timeout_ms: int = 30_000,
    row_cap: int = 10_000,
) -> Callable[[PipelineState], PipelineState]:
    """Build the pipeline node that executes `state["generated"].sql`.

    Args:
        registry: Registry used to resolve `state["db_id"]` to an engine
            factory. Injected for tests; when omitted (`None`), production
            code falls back to `get_default_registry()`. An explicitly passed
            registry is always honoured, even if it happens to be falsy.
        statement_timeout_ms: Forwarded unchanged to `execute_validated`.
        row_cap: Forwarded unchanged to `execute_validated`.

    Returns:
        A node callable that takes the pipeline state and returns a partial
        state update containing only the keys it sets.

    The engine is created and disposed per call. SQLite engine setup is
    essentially free; pooling is unnecessary for one query at a time and
    risks leaking SQLite connections under pytest's strict `ResourceWarning`
    regime on Windows.
    """
    # `is None` rather than `or`: a falsy-but-valid injected registry must not
    # be silently swapped for the default scan.
    reg = get_default_registry() if registry is None else registry

    def node(state: PipelineState) -> PipelineState:
        """Run the generated SQL and report the outcome as a state update.

        Returns:
            - `{"trace"}` only, when there is no generated SQL.
            - `outcome`, `error_kind`, `error_message`, `trace` on success.
            - The same keys plus `last_error` on failure.
        """
        generated = state.get("generated")
        db_id = state.get("db_id", "")
        dialect: Dialect = state.get("dialect", "sqlite")
        # Copy so the incoming state's trace list is never mutated in place.
        trace = list(state.get("trace") or [])

        if generated is None or not generated.sql:
            trace.append({"node": "execute", "ok": False, "reason": "no_sql"})
            # validate already populated outcome+error fields; pass them through.
            return {"trace": trace}

        engine = reg.get(db_id).make_engine()
        try:
            outcome = execute_validated(
                engine,
                generated.sql,
                dialect=dialect,
                statement_timeout_ms=statement_timeout_ms,
                row_cap=row_cap,
            )
        finally:
            # Always release the per-call engine, even if execution raises.
            engine.dispose()

        if outcome.ok:
            result = outcome.result
            trace.append(
                {
                    "node": "execute",
                    "ok": True,
                    "row_count": result.row_count if result else 0,
                    "elapsed_ms": result.elapsed_ms if result else 0.0,
                }
            )
            # NOTE(review): `last_error` from an earlier failed attempt is not
            # reset here; confirm downstream nodes ignore it after success.
            return {
                "outcome": outcome,
                "error_kind": None,
                "error_message": "",
                "trace": trace,
            }

        trace.append(
            {
                "node": "execute",
                "ok": False,
                "kind": outcome.error_kind.value if outcome.error_kind else None,
                "message": outcome.error_message,
            }
        )
        last_error = outcome.error_message
        if (
            outcome.error_kind is not None
            and outcome.error_kind.value == "empty_result"
            and state.get("verify_retry_on_empty")
        ):
            # Give the repair prompt concrete, filter-focused guidance instead
            # of the bare "empty result" message.
            last_error = _EMPTY_RESULT_HINT
        return {
            "outcome": outcome,
            "last_error": last_error,
            "error_kind": outcome.error_kind,
            "error_message": outcome.error_message,
            "trace": trace,
        }

    return node