"""Node: ask codestral (or any LLMProvider) for SQL given the schema context. Builds the prompt from the active context bundle, dispatches to the provider, parses the JSON response into a `GenerateSQLOutput`. The same node powers the *initial* generation pass — the repair pass is a separate node that calls this same provider with a different prompt. """ from __future__ import annotations from collections.abc import Callable from nl_sql.agent.nodes._support import ( parse_generate_sql_output, render_fewshot_block, render_m_schema, render_schema_block, ) from nl_sql.agent.prompts import load_prompt from nl_sql.agent.state import PipelineState from nl_sql.llm.providers.base import GenerateRequest, LLMProvider def make_generate_sql_node( provider: LLMProvider, *, max_tokens: int = 1024, temperature: float = 0.0, sort_schema_block: bool = False, use_m_schema: bool = False, use_dac_prompt: bool = False, ) -> Callable[[PipelineState], PipelineState]: def node(state: PipelineState) -> PipelineState: question = state.get("question", "") dialect = state.get("dialect", "sqlite") context = state.get("context") plan_raw = (state.get("plan") or "").strip() plan_block = plan_raw if plan_raw else "(no plan — generate SQL directly from question)" # Schema rendering: M-Schema (XiYan-SQL compact) vs verbose card layout. # Driven by `PipelineConfig.use_m_schema`; api/main.py bootstraps the # flag from `NLSQL_M_SCHEMA=1` env so existing eval scripts keep working. if use_m_schema: schema_text = render_m_schema(context) else: schema_text = render_schema_block(context, sort_alphabetically=sort_schema_block) # CHASE-SQL divide-and-conquer prompt — decomposes multi-clause questions # into sub-questions before composing SQL. Driven by # `PipelineConfig.use_dac_prompt`; api/main.py bootstraps from `NLSQL_DAC=1`. prompt_name = "generate_sql_dac" if use_dac_prompt else "generate_sql" prompt = load_prompt( prompt_name, dialect=dialect, schema_block=schema_text, fewshot_block=render_fewshot_block(context), plan_block=plan_block, question=question, ) response = provider.generate( GenerateRequest(prompt=prompt, max_tokens=max_tokens, temperature=temperature) ) parsed = parse_generate_sql_output(response.text) trace = list(state.get("trace") or []) trace.append( { "node": "generate_sql", "model": response.model, "confidence": parsed.confidence, "tables_used": list(parsed.tables_used), "input_tokens": response.input_tokens, "output_tokens": response.output_tokens, } ) # Reset any stale outcome / error from a previous repair iteration. return { "generated": parsed, "outcome": None, "last_error": "", "trace": trace, } return node