| """Node: plan_query — emit a structured plan before SQL generation. |
| |
| Two-stage decomposition (DIN-SQL / MAC-SQL style). The plan node produces a |
| JSON skeleton (tables/joins/filters/group/agg/projection/expected_row_count) |
| that the downstream `generate_sql` node sees as additional grounding context. |
| |
| Empirically, forcing the model to commit to row-shape and projection BEFORE |
| writing SQL fixes a large fraction of "row_count_off" and "projection_diff" |
| failures observed in the BIRD baseline taxonomy (see scripts/error_taxonomy.py). |
| """ |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Callable |
|
|
| from nl_sql.agent.nodes._support import ( |
| render_fewshot_block, |
| render_schema_block, |
| ) |
| from nl_sql.agent.prompts import load_prompt |
| from nl_sql.agent.state import PipelineState |
| from nl_sql.llm.providers.base import GenerateRequest, LLMProvider |
|
|
|
|
def make_plan_node(
    provider: LLMProvider,
    *,
    max_tokens: int = 600,
    temperature: float = 0.0,
    sort_schema_block: bool = False,
) -> Callable[[PipelineState], PipelineState]:
    """Build the ``plan_query`` node bound to *provider*.

    Args:
        provider: LLM backend whose ``generate`` method produces the plan text.
        max_tokens: Token limit forwarded in the ``GenerateRequest``.
        temperature: Sampling temperature forwarded in the ``GenerateRequest``.
        sort_schema_block: Passed to ``render_schema_block`` as
            ``sort_alphabetically``.

    Returns:
        A callable that takes the pipeline state and returns a dict holding
        only the ``"plan"`` and ``"trace"`` keys.
    """

    def node(state: PipelineState) -> PipelineState:
        """Render the plan prompt, query the provider, and record usage."""
        user_question = state.get("question", "")
        sql_dialect = state.get("dialect", "sqlite")
        retrieval_context = state.get("context")

        # Render both grounding blocks up front so the prompt call stays flat.
        schema_text = render_schema_block(
            retrieval_context,
            sort_alphabetically=sort_schema_block,
        )
        fewshot_text = render_fewshot_block(retrieval_context)
        plan_prompt = load_prompt(
            "plan",
            dialect=sql_dialect,
            schema_block=schema_text,
            fewshot_block=fewshot_text,
            question=user_question,
        )

        request = GenerateRequest(
            prompt=plan_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        reply = provider.generate(request)

        # A missing/None completion is normalised to an empty plan string.
        plan_text = (reply.text or "").strip()

        # Build a fresh list so the incoming state's trace is never mutated.
        earlier_entries = state.get("trace") or []
        usage_entry = {
            "node": "plan_query",
            "model": reply.model,
            "input_tokens": reply.input_tokens,
            "output_tokens": reply.output_tokens,
        }
        return {
            "plan": plan_text,
            "trace": [*earlier_entries, usage_entry],
        }

    return node
|
|