File size: 13,916 Bytes
"""LangGraph StateGraph wiring + a thin run-result wrapper.

Topology (per docs/02_architecture_v2.md §3):

    START
      │
      ▼
    context_builder
      │
      ▼
    generate_sql
      │
      ▼
    validate ──fail──► repair_once    (fired exactly once,
      │    ▲                │          guarded by repair_attempted)
      │    └────────────────┘
      ▼ ok
    execute ──fail──► repair_once     (same node; also loops back to validate)
      │
      ▼ ok
    deterministic_format
      │
      ▼
    explain_trace
      │
      ▼
    END

repair_once always feeds its rewritten SQL back into validate, so a repaired
query is re-validated before it is executed again.

Optional nodes: with `enable_planner`, plan_query sits between context_builder
and generate_sql; with `enable_grounded_critique`, an ok execute goes through
grounded_critique, which may send the run to repair_once once.

An EMPTY_RESULT from execute is not treated as a failure unless
`verify_retry_on_empty` is set.

Failure fall-through: when a fail happens AND repair was already attempted,
we route directly to deterministic_format with the error attached, so the
user always sees a structured caption + trace instead of a 500.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Literal, cast
from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CompiledStateGraph
from nl_sql.agent.nodes import (
make_context_builder_node,
make_execute_node,
make_explain_trace_node,
make_format_node,
make_generate_sql_node,
make_grounded_critique_node,
make_plan_node,
make_repair_once_node,
make_validate_node,
)
from nl_sql.agent.state import GenerateSQLOutput, PipelineState
from nl_sql.db.connection import Dialect
from nl_sql.db.registry import DatabaseRegistry
from nl_sql.execution.errors import ExecutionErrorKind
from nl_sql.execution.runner import ExecutionOutcome
from nl_sql.llm.providers.base import LLMProvider
from nl_sql.render.formats import OutputFormat
from nl_sql.schema_index.indexer import SchemaIndex
@dataclass(slots=True)
class PipelineConfig:
"""All runtime dependencies. Tests inject fakes via this object."""
sql_provider: LLMProvider
explain_provider: LLMProvider
schema_index: SchemaIndex
registry: DatabaseRegistry
schema_top_k: int = 5
fewshot_top_k: int = 3
fk_hops: int = 1
table_budget: int = 12
statement_timeout_ms: int = 30_000
row_cap: int = 10_000
sort_schema_block: bool = True
"""Render schema_block in alphabetical-by-table-name order instead of
retrieval-distance + FK BFS order. Empirically the single biggest
retrieval-side EA lever on BIRD Mini-Dev under codestral
(+3pp moderate, +5.5pp challenging at n=100; +5pp moderate at n=200).
Default ON since 2026-05-11 per docs/SESSION_HANDOFF.md item #2.
Set to False explicitly to recover the unsorted retrieval-distance
baseline for ablation."""
primary_sample_size: int = 3
"""Sample density already baked into the chunks stored in Chroma.
Must match the `--sample-size` used by `scripts/build_index.py` when
the current `chroma_data/` was built. Used together with
`extended_sample_size` to compute the tail for the mixture appendix.
"""
extended_sample_size: int = 0
"""Per-difficulty sample mixture (off by default). When > 0 and
> `primary_sample_size`, the context_builder fetches sample values
rows `primary..extended` per column for retrieved tables and
`render_schema_block` appends them as an "additional sample values"
section. Empirically: s=3 cards favour moderate-tier accuracy, s=5
cards favour challenging-tier; the mixture exposes both densities
to the model in a single prompt. Requires registry access β see
docs/SESSION_HANDOFF.md item #1."""
sql_temperature: float = 0.0
"""Sampling temperature for the generate_sql / repair_once LLM calls.
Default 0.0 = greedy / deterministic. Higher values inject diversity
needed by config F (self-consistency execution-based voting), where
each candidate runs at a different temperature so the cache stores
them as distinct entries."""
cross_db_fewshot: bool = False
"""When True, few-shot retrieval skips the `db_id` filter and pulls
QβSQL hits from any database in the `fewshot_qsql` collection. Needed
for BIRD, whose train and dev splits are partitioned by db_id (zero
overlap) β same-db retrieval would return zero hits. Set ON by
`run_config_d`; off everywhere else."""
verify_retry_on_empty: bool = False
"""When True, route an EMPTY_RESULT outcome to `repair_once` instead
of short-circuiting to deterministic_format. Empty rows often mean
the model got the filter value wrong (case mismatch, LIKE pattern
missing, NULL handling); a second pass with the empty-result hint
can recover them. Subject to the standard `repair_attempted` guard β
one extra LLM call per question, capped. Set ON by `run_config_g`."""
enable_planner: bool = False
"""When True, insert a `plan_query` node before `generate_sql`. The
planner emits a structured JSON skeleton (intent / expected_row_count
/ tables / joins / filters / group_by / aggregations / projection /
sort / limit) which `generate_sql` and `repair_once` then condition
on via the {{plan_block}} prompt slot. Doubles per-question LLM cost
on cache miss; intended for moderate/challenging-tier difficulty
where the row-shape commitment delta justifies the extra call.
Empirically targets the row_count_off + projection_diff failure
buckets identified by `scripts/error_taxonomy.py`."""
enable_grounded_critique: bool = False
"""When True, run a cheap post-execution row-shape critique before
deterministic formatting and route one failed critique to `repair_once`.
"""
use_m_schema: bool = False
"""When True, render the schema block as M-Schema (XiYan-SQL compact
one-line-per-column with inline samples + trailing FK pairs block) instead
of the default verbose card layout. Replaces the legacy `NLSQL_M_SCHEMA=1`
env toggle; `api/main.py` reads the env once at boot and threads it here so
individual nodes no longer touch `os.environ` at runtime."""
use_dac_prompt: bool = False
"""When True, use the CHASE-SQL divide-and-conquer prompt
(`generate_sql_dac.txt`) which decomposes multi-clause questions into
sub-questions before composing SQL. Replaces the legacy `NLSQL_DAC=1`
env toggle; `api/main.py` reads the env once at boot and threads it here."""
@dataclass(slots=True)
class PipelineRunResult:
"""Flat snapshot of the terminal state β what the caller needs."""
question: str
db_id: str
sql: str
rationale: str
confidence: float
outcome: ExecutionOutcome | None
output_format: OutputFormat | None
caption: str
error_kind: ExecutionErrorKind | None
error_message: str
repair_attempted: bool
trace: list[dict[str, object]]
@property
def ok(self) -> bool:
return self.outcome is not None and self.outcome.ok and self.error_kind is None
def build_pipeline(config: PipelineConfig) -> CompiledStateGraph[Any, Any, Any, Any]:
    """Wire the NL-to-SQL nodes into a LangGraph ``StateGraph`` and compile it.

    Core nodes (always registered): context_builder, generate_sql, validate,
    repair_once, execute, deterministic_format, explain_trace. Optional:

    * ``plan_query`` (``config.enable_planner``) is inserted between
      context_builder and generate_sql.
    * ``grounded_critique`` (``config.enable_grounded_critique``) runs after
      an execute whose outcome is ok and may send the run to repair_once once.

    repair_once always loops back into validate. The ``_route_after_*``
    routers choose between repair and fall-through to deterministic_format
    using the ``repair_attempted`` guard.

    ``config.verify_retry_on_empty`` is honoured here. When True, the
    execute router behaves as if ``verify_retry_on_empty`` were set in the
    run state, so an EMPTY_RESULT gets one repair attempt even when the
    caller did not pass the per-call flag to ``run_pipeline``.

    Args:
        config: runtime dependencies and feature toggles.

    Returns:
        The compiled graph, suitable for ``run_pipeline``.
    """
    graph: StateGraph[PipelineState, None, PipelineState, PipelineState] = StateGraph(PipelineState)

    # NOTE(review): repair_once is not given `temperature`, `use_m_schema` or
    # `use_dac_prompt`, and plan_query is not given `use_m_schema`, although
    # PipelineConfig documents `sql_temperature` as covering repair_once.
    # Confirm whether these factories accept/need those kwargs before
    # threading them through.
    nodes: dict[str, Any] = {
        "context_builder": make_context_builder_node(
            config.schema_index,
            schema_top_k=config.schema_top_k,
            fewshot_top_k=config.fewshot_top_k,
            fk_hops=config.fk_hops,
            table_budget=config.table_budget,
            registry=config.registry,
            primary_sample_size=config.primary_sample_size,
            extended_sample_size=config.extended_sample_size,
            cross_db_fewshot=config.cross_db_fewshot,
        ),
        "generate_sql": make_generate_sql_node(
            config.sql_provider,
            sort_schema_block=config.sort_schema_block,
            temperature=config.sql_temperature,
            use_m_schema=config.use_m_schema,
            use_dac_prompt=config.use_dac_prompt,
        ),
        "validate": make_validate_node(),
        "repair_once": make_repair_once_node(
            config.sql_provider,
            sort_schema_block=config.sort_schema_block,
        ),
        "execute": make_execute_node(
            registry=config.registry,
            statement_timeout_ms=config.statement_timeout_ms,
            row_cap=config.row_cap,
        ),
        "deterministic_format": make_format_node(),
        "explain_trace": make_explain_trace_node(config.explain_provider),
    }
    if config.enable_planner:
        nodes["plan_query"] = make_plan_node(
            config.sql_provider,
            sort_schema_block=config.sort_schema_block,
            temperature=config.sql_temperature,
        )
    if config.enable_grounded_critique:
        nodes["grounded_critique"] = make_grounded_critique_node()
    for name, action in nodes.items():
        graph.add_node(name, action)

    graph.add_edge(START, "context_builder")
    if config.enable_planner:
        graph.add_edge("context_builder", "plan_query")
        graph.add_edge("plan_query", "generate_sql")
    else:
        graph.add_edge("context_builder", "generate_sql")
    graph.add_edge("generate_sql", "validate")
    graph.add_conditional_edges("validate", _route_after_validate)
    # Repaired SQL is re-validated before it is executed again.
    graph.add_edge("repair_once", "validate")

    execute_router: Any = (
        _route_after_execute_with_critique
        if config.enable_grounded_critique
        else _route_after_execute
    )
    if config.verify_retry_on_empty:
        # Without this wrapper the config field would be dead: the routers
        # only read the per-run state flag that run_pipeline seeds.
        execute_router = _force_retry_on_empty(execute_router)
    graph.add_conditional_edges("execute", execute_router)
    if config.enable_grounded_critique:
        graph.add_conditional_edges("grounded_critique", _route_after_grounded_critique)

    graph.add_edge("deterministic_format", "explain_trace")
    graph.add_edge("explain_trace", END)
    return graph.compile()


def _force_retry_on_empty(router: Any) -> Any:
    """Wrap an execute router so EMPTY_RESULT always qualifies for one repair.

    The wrapper passes the router a shallow copy of the state with
    ``verify_retry_on_empty=True``; the graph state itself is not modified.
    ``functools.wraps`` copies the router's ``__annotations__`` and sets
    ``__wrapped__``, so the original ``Literal`` return annotation (from
    which LangGraph presumably infers branch targets, since no path_map is
    passed) is preserved.
    """
    import functools

    @functools.wraps(router)
    def routed(state: PipelineState) -> Any:
        if not state.get("verify_retry_on_empty"):
            state = cast("PipelineState", {**state, "verify_retry_on_empty": True})
        return router(state)

    return routed
# Return types of the conditional routers below. No path_map is passed to
# add_conditional_edges, so LangGraph presumably infers each branch's
# destinations from these Literal annotations; every member must therefore
# name a node that build_pipeline registers.
_AfterExecute = Literal["repair_once", "deterministic_format"]
# typing flattens nested Literals (3.9.1+), so this is
# Literal["repair_once", "deterministic_format", "grounded_critique"].
_AfterExecuteWithCritique = Literal[_AfterExecute, "grounded_critique"]
_AfterGroundedCritique = Literal["repair_once", "deterministic_format"]
_AfterValidate = Literal["repair_once", "execute", "deterministic_format"]
def _route_after_validate(state: PipelineState) -> _AfterValidate:
    """Choose the successor of ``validate``.

    An outcome with no ``error_kind`` proceeds to ``execute``. Anything else
    (no outcome, or one carrying an error) gets one trip through
    ``repair_once``; once ``repair_attempted`` is set it falls through to
    ``deterministic_format``.
    """
    validation = state.get("outcome")
    passed = validation is not None and validation.error_kind is None
    if passed:
        return "execute"
    return "deterministic_format" if state.get("repair_attempted") else "repair_once"
def _route_after_execute(state: PipelineState) -> _AfterExecute:
    """Choose the successor of ``execute``.

    A missing or ok outcome goes straight to ``deterministic_format``.
    EMPTY_RESULT is retried only when ``verify_retry_on_empty`` is set and
    no repair has happened yet. Any other error gets one ``repair_once``
    pass, guarded by ``repair_attempted``.
    """
    result = state.get("outcome")
    if result is None or result.ok:
        return "deterministic_format"
    repaired = state.get("repair_attempted")
    # Zero rows is normally a legitimate answer and render handles the
    # empty-set messaging; config G opts into one retry on the assumption
    # that a filter value was wrong (case mismatch, LIKE pattern, NULLs).
    if result.error_kind == ExecutionErrorKind.EMPTY_RESULT:
        retry_empty = state.get("verify_retry_on_empty")
        if retry_empty and not repaired:
            return "repair_once"
        return "deterministic_format"
    return "deterministic_format" if repaired else "repair_once"
def _route_after_execute_with_critique(state: PipelineState) -> _AfterExecuteWithCritique:
    """Execute router used when grounded critique is enabled.

    Ok outcomes detour through ``grounded_critique``; every other case is
    delegated unchanged to ``_route_after_execute``.
    """
    executed = state.get("outcome")
    if executed is None or not executed.ok:
        return _route_after_execute(state)
    return "grounded_critique"
def _route_after_grounded_critique(state: PipelineState) -> _AfterGroundedCritique:
    """Choose the successor of ``grounded_critique``.

    A failed critique earns one ``repair_once`` pass unless a repair was
    already attempted; otherwise formatting proceeds.
    """
    if not state.get("critique_failed") or state.get("repair_attempted"):
        return "deterministic_format"
    return "repair_once"
def run_pipeline(
    pipeline: CompiledStateGraph[Any, Any, Any, Any],
    *,
    question: str,
    db_id: str,
    dialect: Dialect = "sqlite",
    disable_repair: bool = False,
    verify_retry_on_empty: bool = False,
) -> PipelineRunResult:
    """One-shot helper: invoke the compiled graph and flatten the result.

    Args:
        pipeline: graph returned by ``build_pipeline``.
        question: natural-language question, seeded into the initial state.
        db_id: target database id, seeded into the initial state.
        dialect: SQL dialect for the run (default ``"sqlite"``).
        disable_repair: when True, sets ``repair_attempted`` in the initial
            state, which causes both ``_route_after_validate`` and
            ``_route_after_execute`` to skip the repair branch on the first
            failure and fall through to deterministic_format. Used by eval
            configurations A-D where the methodology specifies "no repair"
            as a measured baseline.
        verify_retry_on_empty: when True, an EMPTY_RESULT outcome routes to
            repair_once (subject to the repair_attempted guard) so the model
            can take a second swing at the filter values. Used by config G;
            the corresponding ``last_error`` payload comes from the execute
            node and includes the empty-result hint.

    Returns:
        A ``PipelineRunResult`` snapshot of the terminal state. ``caption``
        and ``error_message`` are ``""`` when absent or ``None``; a missing
        ``generated`` payload yields an empty SQL string.
    """
    initial: PipelineState = {
        "question": question,
        "db_id": db_id,
        "dialect": dialect,
        "repair_attempted": disable_repair,
        "verify_retry_on_empty": verify_retry_on_empty,
        "trace": [],
    }
    final = cast(PipelineState, pipeline.invoke(initial))
    generated = final.get("generated") or GenerateSQLOutput(sql="")
    return PipelineRunResult(
        question=final.get("question", question),
        db_id=final.get("db_id", db_id),
        sql=generated.sql,
        rationale=generated.rationale,
        confidence=generated.confidence,
        outcome=final.get("outcome"),
        output_format=final.get("output_format"),
        # `or ""` (not a .get default) so an explicit None stored in state
        # does not leak into these str-typed fields.
        caption=final.get("caption") or "",
        error_kind=final.get("error_kind"),
        error_message=final.get("error_message") or "",
        # NOTE(review): with disable_repair=True this reports True even
        # though repair_once was presumably never reached (the guard flag is
        # seeded up front). Confirm downstream metrics don't read it as
        # "a repair happened".
        repair_attempted=bool(final.get("repair_attempted")),
        trace=list(final.get("trace") or []),
    )
# Public surface of this module; the `_route_*` routers and `_After*`
# Literal aliases are underscore-private implementation details.
__all__ = ["PipelineConfig", "PipelineRunResult", "build_pipeline", "run_pipeline"]
|