"""LangGraph StateGraph wiring + a thin run-result wrapper.

Topology (per docs/02_architecture_v2.md §3):

    START
      │
      ▼
    context_builder
      │
      ▼
    generate_sql ◄────────────┐
      │                       │
      ▼                       │
    validate ──fail──► repair_once   (fired exactly once,
      │                               guarded by repair_attempted)
      ▼ ok
    execute  ──fail──► repair_once
      │
      ▼ ok
    deterministic_format
      │
      ▼
    explain_trace
      │
      ▼
    END

Failure fall-through: when a fail happens AND repair was already attempted,
we route directly to deterministic_format with the error attached, so the
user always sees a structured caption + trace instead of a 500.

Optional nodes wired by `build_pipeline` (not drawn above):

* `plan_query` sits between `context_builder` and `generate_sql` when
  `PipelineConfig.enable_planner` is set.
* `grounded_critique` sits between a successful `execute` and
  `deterministic_format` when `PipelineConfig.enable_grounded_critique` is
  set.

In the wiring below the outgoing edge of `repair_once` goes to `validate`.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Literal, cast

from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CompiledStateGraph

from nl_sql.agent.nodes import (
    make_context_builder_node,
    make_execute_node,
    make_explain_trace_node,
    make_format_node,
    make_generate_sql_node,
    make_grounded_critique_node,
    make_plan_node,
    make_repair_once_node,
    make_validate_node,
)
from nl_sql.agent.state import GenerateSQLOutput, PipelineState
from nl_sql.db.connection import Dialect
from nl_sql.db.registry import DatabaseRegistry
from nl_sql.execution.errors import ExecutionErrorKind
from nl_sql.execution.runner import ExecutionOutcome
from nl_sql.llm.providers.base import LLMProvider
from nl_sql.render.formats import OutputFormat
from nl_sql.schema_index.indexer import SchemaIndex


@dataclass(slots=True)
class PipelineConfig:
    """All runtime dependencies. Tests inject fakes via this object."""

    sql_provider: LLMProvider
    explain_provider: LLMProvider
    schema_index: SchemaIndex
    registry: DatabaseRegistry
    schema_top_k: int = 5
    fewshot_top_k: int = 3
    fk_hops: int = 1
    table_budget: int = 12
    statement_timeout_ms: int = 30_000
    row_cap: int = 10_000

    sort_schema_block: bool = True
    """Render schema_block in alphabetical-by-table-name order instead of
    retrieval-distance + FK BFS order.

    Empirically the single biggest retrieval-side EA lever on BIRD Mini-Dev
    under codestral (+3pp moderate, +5.5pp challenging at n=100; +5pp
    moderate at n=200). Default ON since 2026-05-11 per
    docs/SESSION_HANDOFF.md item #2. Set to False explicitly to recover the
    unsorted retrieval-distance baseline for ablation."""

    primary_sample_size: int = 3
    """Sample density already baked into the chunks stored in Chroma.

    Must match the `--sample-size` used by `scripts/build_index.py` when the
    current `chroma_data/` was built. Used together with
    `extended_sample_size` to compute the tail for the mixture appendix.
    """

    extended_sample_size: int = 0
    """Per-difficulty sample mixture (off by default).

    When > 0 and > `primary_sample_size`, the context_builder fetches sample
    values rows `primary..extended` per column for retrieved tables and
    `render_schema_block` appends them as an "additional sample values"
    section. Empirically: s=3 cards favour moderate-tier accuracy, s=5 cards
    favour challenging-tier; the mixture exposes both densities to the model
    in a single prompt. Requires registry access — see
    docs/SESSION_HANDOFF.md item #1."""

    sql_temperature: float = 0.0
    """Sampling temperature for the generate_sql / repair_once LLM calls.

    Default 0.0 = greedy / deterministic. Higher values inject diversity
    needed by config F (self-consistency execution-based voting), where each
    candidate runs at a different temperature so the cache stores them as
    distinct entries."""

    cross_db_fewshot: bool = False
    """When True, few-shot retrieval skips the `db_id` filter and pulls Q→SQL
    hits from any database in the `fewshot_qsql` collection.

    Needed for BIRD, whose train and dev splits are partitioned by db_id
    (zero overlap) — same-db retrieval would return zero hits. Set ON by
    `run_config_d`; off everywhere else."""

    verify_retry_on_empty: bool = False
    """When True, route an EMPTY_RESULT outcome to `repair_once` instead of
    short-circuiting to deterministic_format.

    Empty rows often mean the model got the filter value wrong (case
    mismatch, LIKE pattern missing, NULL handling); a second pass with the
    empty-result hint can recover them. Subject to the standard
    `repair_attempted` guard — one extra LLM call per question, capped. Set
    ON by `run_config_g`."""

    enable_planner: bool = False
    """When True, insert a `plan_query` node before `generate_sql`.

    The planner emits a structured JSON skeleton (intent /
    expected_row_count / tables / joins / filters / group_by / aggregations
    / projection / sort / limit) which `generate_sql` and `repair_once` then
    condition on via the {{plan_block}} prompt slot. Doubles per-question
    LLM cost on cache miss; intended for moderate/challenging-tier
    difficulty where the row-shape commitment delta justifies the extra
    call. Empirically targets the row_count_off + projection_diff failure
    buckets identified by `scripts/error_taxonomy.py`."""

    enable_grounded_critique: bool = False
    """When True, run a cheap post-execution row-shape critique before
    deterministic formatting and route one failed critique to `repair_once`.
    """

    use_m_schema: bool = False
    """When True, render the schema block as M-Schema (XiYan-SQL compact
    one-line-per-column with inline samples + trailing FK pairs block)
    instead of the default verbose card layout.

    Replaces the legacy `NLSQL_M_SCHEMA=1` env toggle; `api/main.py` reads
    the env once at boot and threads it here so individual nodes no longer
    touch `os.environ` at runtime."""

    use_dac_prompt: bool = False
    """When True, use the CHASE-SQL divide-and-conquer prompt
    (`generate_sql_dac.txt`) which decomposes multi-clause questions into
    sub-questions before composing SQL.

    Replaces the legacy `NLSQL_DAC=1` env toggle; `api/main.py` reads the
    env once at boot and threads it here."""


@dataclass(slots=True)
class PipelineRunResult:
    """Flat snapshot of the terminal state — what the caller needs."""

    question: str
    db_id: str
    sql: str
    rationale: str
    confidence: float
    outcome: ExecutionOutcome | None
    output_format: OutputFormat | None
    caption: str
    error_kind: ExecutionErrorKind | None
    error_message: str
    repair_attempted: bool
    trace: list[dict[str, object]]

    @property
    def ok(self) -> bool:
        """True when an outcome exists, it reports ok, and no error kind is set."""
        if self.outcome is None:
            return False
        return self.outcome.ok and self.error_kind is None


def build_pipeline(config: PipelineConfig) -> CompiledStateGraph[Any, Any, Any, Any]:
    """Assemble and compile the pipeline graph described by `config`.

    All node callables are created first (in a fixed order), then registered
    on the graph, then the edges are added. `plan_query` and
    `grounded_critique` are only created and wired when the matching
    `enable_*` flag on `config` is set.

    Returns the compiled graph, ready to be passed to `run_pipeline`.
    """
    graph: StateGraph[PipelineState, None, PipelineState, PipelineState] = StateGraph(PipelineState)

    # (name, callable) pairs in registration order.
    node_specs: list[tuple[str, Any]] = [
        (
            "context_builder",
            make_context_builder_node(
                config.schema_index,
                schema_top_k=config.schema_top_k,
                fewshot_top_k=config.fewshot_top_k,
                fk_hops=config.fk_hops,
                table_budget=config.table_budget,
                registry=config.registry,
                primary_sample_size=config.primary_sample_size,
                extended_sample_size=config.extended_sample_size,
                cross_db_fewshot=config.cross_db_fewshot,
            ),
        ),
        (
            "generate_sql",
            make_generate_sql_node(
                config.sql_provider,
                sort_schema_block=config.sort_schema_block,
                temperature=config.sql_temperature,
                use_m_schema=config.use_m_schema,
                use_dac_prompt=config.use_dac_prompt,
            ),
        ),
        ("validate", make_validate_node()),
        (
            "repair_once",
            # NOTE(review): unlike generate_sql, this factory is not given
            # temperature / use_m_schema / use_dac_prompt, although the
            # `sql_temperature` field doc mentions repair_once. Confirm
            # against make_repair_once_node's signature whether that is
            # intentional.
            make_repair_once_node(
                config.sql_provider,
                sort_schema_block=config.sort_schema_block,
            ),
        ),
        (
            "execute",
            make_execute_node(
                registry=config.registry,
                statement_timeout_ms=config.statement_timeout_ms,
                row_cap=config.row_cap,
            ),
        ),
        ("deterministic_format", make_format_node()),
        ("explain_trace", make_explain_trace_node(config.explain_provider)),
    ]
    if config.enable_planner:
        planner = make_plan_node(
            config.sql_provider,
            sort_schema_block=config.sort_schema_block,
            temperature=config.sql_temperature,
        )
        node_specs.append(("plan_query", planner))
    if config.enable_grounded_critique:
        node_specs.append(("grounded_critique", make_grounded_critique_node()))

    for node_name, node_action in node_specs:
        graph.add_node(node_name, node_action)

    # Linear prefix: START -> context_builder -> [plan_query ->] generate_sql.
    graph.add_edge(START, "context_builder")
    generate_predecessor = "context_builder"
    if config.enable_planner:
        graph.add_edge(generate_predecessor, "plan_query")
        generate_predecessor = "plan_query"
    graph.add_edge(generate_predecessor, "generate_sql")

    graph.add_edge("generate_sql", "validate")
    graph.add_conditional_edges("validate", _route_after_validate)
    # A repaired statement is re-validated before it can reach execute.
    graph.add_edge("repair_once", "validate")

    if config.enable_grounded_critique:
        graph.add_conditional_edges("execute", _route_after_execute_with_critique)
        graph.add_conditional_edges("grounded_critique", _route_after_grounded_critique)
    else:
        graph.add_conditional_edges("execute", _route_after_execute)

    graph.add_edge("deterministic_format", "explain_trace")
    graph.add_edge("explain_trace", END)
    return graph.compile()


_AfterValidate = Literal["repair_once", "execute", "deterministic_format"]
_AfterExecute = Literal["repair_once", "deterministic_format"]
_AfterExecuteWithCritique = Literal["repair_once", "deterministic_format", "grounded_critique"]
_AfterGroundedCritique = Literal["repair_once", "deterministic_format"]


def _route_after_validate(state: PipelineState) -> _AfterValidate:
    """Choose the node that follows `validate`.

    Goes to `execute` when the state holds an outcome with no error kind;
    otherwise to `repair_once` if repair has not been attempted yet, else to
    `deterministic_format`.
    """
    outcome = state.get("outcome")
    passed = outcome is not None and outcome.error_kind is None
    if passed:
        return "execute"
    return "deterministic_format" if state.get("repair_attempted") else "repair_once"


def _route_after_execute(state: PipelineState) -> _AfterExecute:
    """Choose the node that follows `execute` (no critique node in the graph).

    A missing or ok outcome goes straight to `deterministic_format`. A
    failed outcome goes to `repair_once` while the repair budget is unspent,
    with one special case for EMPTY_RESULT described inline.
    """
    outcome = state.get("outcome")
    if outcome is None or outcome.ok:
        return "deterministic_format"

    repair_available = not state.get("repair_attempted")
    # EMPTY_RESULT is normally a valid outcome (zero rows is a legitimate
    # answer) → render handles the empty-set messaging. Config G flips this
    # to retry the empty case once, on the assumption that the model
    # confused a filter value (case mismatch, LIKE pattern, NULL handling).
    # NOTE(review): the opt-in is read from the state key set by
    # `run_pipeline`; `PipelineConfig.verify_retry_on_empty` is not consulted
    # in this module — confirm callers forward it.
    if outcome.error_kind == ExecutionErrorKind.EMPTY_RESULT:
        should_repair = bool(state.get("verify_retry_on_empty")) and repair_available
    else:
        should_repair = repair_available
    return "repair_once" if should_repair else "deterministic_format"


def _route_after_execute_with_critique(state: PipelineState) -> _AfterExecuteWithCritique:
    """Choose the node that follows `execute` when the critique node is wired.

    An ok outcome is sent to `grounded_critique`; everything else is
    delegated to `_route_after_execute`.
    """
    outcome = state.get("outcome")
    if outcome is None or not outcome.ok:
        return _route_after_execute(state)
    return "grounded_critique"


def _route_after_grounded_critique(state: PipelineState) -> _AfterGroundedCritique:
    """Choose the node that follows `grounded_critique`.

    A failed critique goes to `repair_once` only while repair has not been
    attempted; in every other case the flow continues to
    `deterministic_format`.
    """
    needs_repair = state.get("critique_failed") and not state.get("repair_attempted")
    return "repair_once" if needs_repair else "deterministic_format"


def run_pipeline(
    pipeline: CompiledStateGraph[Any, Any, Any, Any],
    *,
    question: str,
    db_id: str,
    dialect: Dialect = "sqlite",
    disable_repair: bool = False,
    verify_retry_on_empty: bool = False,
) -> PipelineRunResult:
    """One-shot helper: invoke the compiled graph and flatten the result.

    `disable_repair` (default False): when True, sets repair_attempted in
    initial state, which causes both `_route_after_validate` and
    `_route_after_execute` to skip the repair branch on the first failure
    and fall through to deterministic_format. Used by eval configurations
    A-D where the methodology specifies "no repair" as a measured baseline.

    `verify_retry_on_empty` (default False): when True, an EMPTY_RESULT
    outcome routes to repair_once (subject to the repair_attempted guard) so
    the model can take a second swing at the filter values. Used by config
    G; the corresponding `last_error` payload comes from the execute node
    and includes the empty-result hint.

    Keys missing from the terminal state fall back to neutral defaults: the
    `question` / `db_id` arguments, empty strings, None, an empty trace, and
    an empty `GenerateSQLOutput` when no SQL was generated.
    """
    seed_state: PipelineState = {
        "question": question,
        "db_id": db_id,
        "dialect": dialect,
        "repair_attempted": disable_repair,
        "verify_retry_on_empty": verify_retry_on_empty,
        "trace": [],
    }
    terminal = cast(PipelineState, pipeline.invoke(seed_state))

    generated = terminal.get("generated") or GenerateSQLOutput(sql="")
    trace_entries = list(terminal.get("trace") or [])
    repaired = bool(terminal.get("repair_attempted"))

    return PipelineRunResult(
        question=terminal.get("question", question),
        db_id=terminal.get("db_id", db_id),
        sql=generated.sql,
        rationale=generated.rationale,
        confidence=generated.confidence,
        outcome=terminal.get("outcome"),
        output_format=terminal.get("output_format"),
        caption=terminal.get("caption", ""),
        error_kind=terminal.get("error_kind"),
        error_message=terminal.get("error_message", ""),
        repair_attempted=repaired,
        trace=trace_entries,
    )


__all__ = [
    "PipelineConfig",
    "PipelineRunResult",
    "build_pipeline",
    "run_pipeline",
]