File size: 13,916 Bytes
d48602c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b4ff9e
 
 
 
 
 
 
 
 
 
 
d48602c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b4ff9e
 
d48602c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
"""LangGraph StateGraph wiring + a thin run-result wrapper.



Topology (per docs/02_architecture_v2.md §3):

    START
      │
      ▼
    context_builder
      │
      ▼
    plan_query               (only when enable_planner)
      │
      ▼
    generate_sql
      │
      ▼
    validate ◄──────────── repair_once ◄──┐  (entered at most once,
      │  │                      ▲         │   guarded by repair_attempted)
      │  └──────fail────────────┘         │
      ▼ ok                                │
    execute ──────fail────────────────────┤
      │                                   │
      ▼ ok                                │
    grounded_critique ───fail─────────────┘  (only when
      │                                       enable_grounded_critique)
      ▼ ok
    deterministic_format
      │
      ▼
    explain_trace
      │
      ▼
    END



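Failure fall-through: when a fail happens AND repair was already attempted,
we route directly to deterministic_format with the error attached, so the
user always sees a structured caption + trace instead of a 500. An
EMPTY_RESULT from execute is treated as a normal (empty) answer and also
goes to deterministic_format, unless the run state sets
`verify_retry_on_empty` and repair has not been attempted yet, in which
case it is routed to repair_once.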

"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Literal, cast

from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CompiledStateGraph

from nl_sql.agent.nodes import (
    make_context_builder_node,
    make_execute_node,
    make_explain_trace_node,
    make_format_node,
    make_generate_sql_node,
    make_grounded_critique_node,
    make_plan_node,
    make_repair_once_node,
    make_validate_node,
)
from nl_sql.agent.state import GenerateSQLOutput, PipelineState
from nl_sql.db.connection import Dialect
from nl_sql.db.registry import DatabaseRegistry
from nl_sql.execution.errors import ExecutionErrorKind
from nl_sql.execution.runner import ExecutionOutcome
from nl_sql.llm.providers.base import LLMProvider
from nl_sql.render.formats import OutputFormat
from nl_sql.schema_index.indexer import SchemaIndex


@dataclass(slots=True)
class PipelineConfig:
    """All runtime dependencies. Tests inject fakes via this object."""

    sql_provider: LLMProvider
    explain_provider: LLMProvider
    schema_index: SchemaIndex
    registry: DatabaseRegistry
    schema_top_k: int = 5
    fewshot_top_k: int = 3
    fk_hops: int = 1
    table_budget: int = 12
    statement_timeout_ms: int = 30_000
    row_cap: int = 10_000
    sort_schema_block: bool = True
    """Render schema_block in alphabetical-by-table-name order instead of

    retrieval-distance + FK BFS order. Empirically the single biggest

    retrieval-side EA lever on BIRD Mini-Dev under codestral

    (+3pp moderate, +5.5pp challenging at n=100; +5pp moderate at n=200).

    Default ON since 2026-05-11 per docs/SESSION_HANDOFF.md item #2.

    Set to False explicitly to recover the unsorted retrieval-distance

    baseline for ablation."""
    primary_sample_size: int = 3
    """Sample density already baked into the chunks stored in Chroma.

    Must match the `--sample-size` used by `scripts/build_index.py` when

    the current `chroma_data/` was built. Used together with

    `extended_sample_size` to compute the tail for the mixture appendix.

    """
    extended_sample_size: int = 0
    """Per-difficulty sample mixture (off by default). When > 0 and

    > `primary_sample_size`, the context_builder fetches sample values

    rows `primary..extended` per column for retrieved tables and

    `render_schema_block` appends them as an "additional sample values"

    section. Empirically: s=3 cards favour moderate-tier accuracy, s=5

    cards favour challenging-tier; the mixture exposes both densities

    to the model in a single prompt. Requires registry access β€” see

    docs/SESSION_HANDOFF.md item #1."""
    sql_temperature: float = 0.0
    """Sampling temperature for the generate_sql / repair_once LLM calls.

    Default 0.0 = greedy / deterministic. Higher values inject diversity

    needed by config F (self-consistency execution-based voting), where

    each candidate runs at a different temperature so the cache stores

    them as distinct entries."""
    cross_db_fewshot: bool = False
    """When True, few-shot retrieval skips the `db_id` filter and pulls

    Q→SQL hits from any database in the `fewshot_qsql` collection. Needed

    for BIRD, whose train and dev splits are partitioned by db_id (zero

    overlap) β€” same-db retrieval would return zero hits. Set ON by

    `run_config_d`; off everywhere else."""
    verify_retry_on_empty: bool = False
    """When True, route an EMPTY_RESULT outcome to `repair_once` instead

    of short-circuiting to deterministic_format. Empty rows often mean

    the model got the filter value wrong (case mismatch, LIKE pattern

    missing, NULL handling); a second pass with the empty-result hint

    can recover them. Subject to the standard `repair_attempted` guard β€”

    one extra LLM call per question, capped. Set ON by `run_config_g`."""
    enable_planner: bool = False
    """When True, insert a `plan_query` node before `generate_sql`. The

    planner emits a structured JSON skeleton (intent / expected_row_count

    / tables / joins / filters / group_by / aggregations / projection /

    sort / limit) which `generate_sql` and `repair_once` then condition

    on via the {{plan_block}} prompt slot. Doubles per-question LLM cost

    on cache miss; intended for moderate/challenging-tier difficulty

    where the row-shape commitment delta justifies the extra call.

    Empirically targets the row_count_off + projection_diff failure

    buckets identified by `scripts/error_taxonomy.py`."""
    enable_grounded_critique: bool = False
    """When True, run a cheap post-execution row-shape critique before

    deterministic formatting and route one failed critique to `repair_once`.

    """
    use_m_schema: bool = False
    """When True, render the schema block as M-Schema (XiYan-SQL compact

    one-line-per-column with inline samples + trailing FK pairs block) instead

    of the default verbose card layout. Replaces the legacy `NLSQL_M_SCHEMA=1`

    env toggle; `api/main.py` reads the env once at boot and threads it here so

    individual nodes no longer touch `os.environ` at runtime."""
    use_dac_prompt: bool = False
    """When True, use the CHASE-SQL divide-and-conquer prompt

    (`generate_sql_dac.txt`) which decomposes multi-clause questions into

    sub-questions before composing SQL. Replaces the legacy `NLSQL_DAC=1`

    env toggle; `api/main.py` reads the env once at boot and threads it here."""


@dataclass(slots=True)
class PipelineRunResult:
    """Flat snapshot of the terminal state β€” what the caller needs."""

    question: str
    db_id: str
    sql: str
    rationale: str
    confidence: float
    outcome: ExecutionOutcome | None
    output_format: OutputFormat | None
    caption: str
    error_kind: ExecutionErrorKind | None
    error_message: str
    repair_attempted: bool
    trace: list[dict[str, object]]

    @property
    def ok(self) -> bool:
        return self.outcome is not None and self.outcome.ok and self.error_kind is None


def build_pipeline(config: PipelineConfig) -> CompiledStateGraph[Any, Any, Any, Any]:
    """Wire the nodes selected by `config` into a compiled LangGraph.

    context_builder, generate_sql, validate, repair_once, execute,
    deterministic_format and explain_trace are always registered.
    plan_query is added between context_builder and generate_sql when
    `config.enable_planner` is set. grounded_critique is added after a
    successful execute when `config.enable_grounded_critique` is set.

    Args:
        config: dependencies and knobs forwarded to the node factories.

    Returns:
        The compiled graph, ready to be passed to `run_pipeline`.
    """
    graph: StateGraph[PipelineState, None, PipelineState, PipelineState] = StateGraph(PipelineState)

    # Node name -> node callable. Optional nodes are added to this dict below,
    # so the add_node loop registers exactly the nodes that get wired.
    nodes: dict[str, Any] = {
        "context_builder": make_context_builder_node(
            config.schema_index,
            schema_top_k=config.schema_top_k,
            fewshot_top_k=config.fewshot_top_k,
            fk_hops=config.fk_hops,
            table_budget=config.table_budget,
            registry=config.registry,
            primary_sample_size=config.primary_sample_size,
            extended_sample_size=config.extended_sample_size,
            cross_db_fewshot=config.cross_db_fewshot,
        ),
        "generate_sql": make_generate_sql_node(
            config.sql_provider,
            sort_schema_block=config.sort_schema_block,
            temperature=config.sql_temperature,
            use_m_schema=config.use_m_schema,
            use_dac_prompt=config.use_dac_prompt,
        ),
        "validate": make_validate_node(),
        # NOTE(review): repair_once only gets sort_schema_block. Unlike
        # generate_sql, it is not given sql_temperature, use_m_schema or
        # use_dac_prompt, although PipelineConfig.sql_temperature is documented
        # as covering repair_once calls. Confirm whether make_repair_once_node
        # accepts or needs them.
        "repair_once": make_repair_once_node(
            config.sql_provider,
            sort_schema_block=config.sort_schema_block,
        ),
        "execute": make_execute_node(
            registry=config.registry,
            statement_timeout_ms=config.statement_timeout_ms,
            row_cap=config.row_cap,
        ),
        "deterministic_format": make_format_node(),
        "explain_trace": make_explain_trace_node(config.explain_provider),
    }
    if config.enable_planner:
        nodes["plan_query"] = make_plan_node(
            config.sql_provider,
            sort_schema_block=config.sort_schema_block,
            temperature=config.sql_temperature,
        )
    if config.enable_grounded_critique:
        nodes["grounded_critique"] = make_grounded_critique_node()
    for name, action in nodes.items():
        graph.add_node(name, action)

    graph.add_edge(START, "context_builder")
    if config.enable_planner:
        graph.add_edge("context_builder", "plan_query")
        graph.add_edge("plan_query", "generate_sql")
    else:
        graph.add_edge("context_builder", "generate_sql")
    graph.add_edge("generate_sql", "validate")
    # Conditional edges are added without a path_map. Presumably LangGraph
    # infers each router's destinations from its Literal return annotation;
    # confirm against the installed LangGraph version.
    graph.add_conditional_edges("validate", _route_after_validate)
    # Repaired SQL goes back through validate before it can reach execute.
    graph.add_edge("repair_once", "validate")
    # NOTE(review): config.verify_retry_on_empty is not read here. The
    # execute router checks the "verify_retry_on_empty" key of the run state,
    # which run_pipeline seeds from its own argument. Confirm callers keep
    # the two values in sync.
    if config.enable_grounded_critique:
        # A successful execute goes to grounded_critique first. Failures use
        # the same routing as the non-critique graph.
        graph.add_conditional_edges("execute", _route_after_execute_with_critique)
        graph.add_conditional_edges("grounded_critique", _route_after_grounded_critique)
    else:
        graph.add_conditional_edges("execute", _route_after_execute)
    graph.add_edge("deterministic_format", "explain_trace")
    graph.add_edge("explain_trace", END)

    return graph.compile()


# Return-type aliases for the conditional-edge routers below. Each Literal
# lists the node names its router may return. build_pipeline passes no
# path_map to add_conditional_edges, so presumably LangGraph derives each
# branch's possible destinations from these annotations; confirm against
# the installed LangGraph version.
_AfterValidate = Literal["repair_once", "execute", "deterministic_format"]
_AfterExecute = Literal["repair_once", "deterministic_format"]
_AfterExecuteWithCritique = Literal["repair_once", "deterministic_format", "grounded_critique"]
_AfterGroundedCritique = Literal["repair_once", "deterministic_format"]


def _route_after_validate(state: PipelineState) -> _AfterValidate:
    """Pick the node that follows `validate`.

    A validation outcome with no `error_kind` proceeds to `execute`. Any
    other result (including a missing outcome) goes to `repair_once` the
    first time. Once `repair_attempted` is set, it falls through to
    `deterministic_format` so the caller still gets a structured result.
    """
    validation = state.get("outcome")
    passed = validation is not None and validation.error_kind is None
    if passed:
        return "execute"
    return "deterministic_format" if state.get("repair_attempted") else "repair_once"


def _route_after_execute(state: PipelineState) -> _AfterExecute:
    """Pick the node that follows `execute`.

    A missing or successful outcome is rendered directly. EMPTY_RESULT is
    normally a legitimate answer (zero rows), so rendering handles the
    empty-set messaging. The exception is when the run state sets
    `verify_retry_on_empty` (config G) and repair is still available: then
    the model gets one retry, on the theory that it picked a wrong filter
    value (case mismatch, LIKE pattern, NULL handling). Every other failure
    is repaired once and afterwards falls through to formatting.
    """
    result = state.get("outcome")
    if result is None or result.ok:
        return "deterministic_format"

    already_repaired = bool(state.get("repair_attempted"))

    if result.error_kind == ExecutionErrorKind.EMPTY_RESULT:
        wants_retry = bool(state.get("verify_retry_on_empty"))
        return "repair_once" if wants_retry and not already_repaired else "deterministic_format"

    return "deterministic_format" if already_repaired else "repair_once"


def _route_after_execute_with_critique(state: PipelineState) -> _AfterExecuteWithCritique:
    """Pick the node that follows `execute` when grounded critique is enabled.

    Only a present, successful outcome is sent to `grounded_critique`.
    Everything else is routed exactly like the critique-free graph.
    """
    current = state.get("outcome")
    if current is None or not current.ok:
        return _route_after_execute(state)
    return "grounded_critique"


def _route_after_grounded_critique(state: PipelineState) -> _AfterGroundedCritique:
    """Pick the node that follows `grounded_critique`.

    A failed critique earns one `repair_once` pass if repair has not been
    attempted yet. Otherwise the result goes straight to formatting.
    """
    should_retry = bool(state.get("critique_failed")) and not state.get("repair_attempted")
    return "repair_once" if should_retry else "deterministic_format"


def run_pipeline(
    pipeline: CompiledStateGraph[Any, Any, Any, Any],
    *,
    question: str,
    db_id: str,
    dialect: Dialect = "sqlite",
    disable_repair: bool = False,
    verify_retry_on_empty: bool = False,
) -> PipelineRunResult:
    """Invoke a compiled pipeline on one question and return a flat result.

    Args:
        pipeline: graph produced by `build_pipeline`.
        question: natural-language question to answer.
        db_id: identifier of the target database.
        dialect: SQL dialect placed in the initial state (default "sqlite").
        disable_repair: when True, seeds `repair_attempted=True`. That makes
            `_route_after_validate` and `_route_after_execute` skip the
            repair branch on the first failure and fall through to
            deterministic_format. Used by eval configurations A-D, where the
            methodology specifies "no repair" as a measured baseline.
        verify_retry_on_empty: when True, an EMPTY_RESULT outcome routes to
            repair_once (subject to the repair_attempted guard), so the model
            can take a second swing at the filter values. Used by config G;
            the corresponding `last_error` payload comes from the execute
            node and includes the empty-result hint.

    Returns:
        A PipelineRunResult built from the graph's final state. If no SQL was
        generated, `sql` is "" and rationale/confidence come from
        `GenerateSQLOutput(sql="")`.
    """
    seed: PipelineState = {
        "question": question,
        "db_id": db_id,
        "dialect": dialect,
        # Pre-setting the repair guard is how repair gets switched off.
        "repair_attempted": disable_repair,
        "verify_retry_on_empty": verify_retry_on_empty,
        "trace": [],
    }
    terminal = cast(PipelineState, pipeline.invoke(seed))

    produced = terminal.get("generated")
    if not produced:
        # No generation reached the final state; use an empty SQL record.
        produced = GenerateSQLOutput(sql="")

    trace_entries = list(terminal.get("trace") or [])
    return PipelineRunResult(
        question=terminal.get("question", question),
        db_id=terminal.get("db_id", db_id),
        sql=produced.sql,
        rationale=produced.rationale,
        confidence=produced.confidence,
        outcome=terminal.get("outcome"),
        output_format=terminal.get("output_format"),
        caption=terminal.get("caption", ""),
        error_kind=terminal.get("error_kind"),
        error_message=terminal.get("error_message", ""),
        repair_attempted=bool(terminal.get("repair_attempted")),
        trace=trace_entries,
    )


# Public surface of this module. The routing helpers and their Literal
# aliases are underscore-prefixed and intentionally not exported.
__all__ = [
    "PipelineConfig",
    "PipelineRunResult",
    "build_pipeline",
    "run_pipeline",
]