File size: 3,262 Bytes
06ed757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b4ff9e
 
06ed757
 
 
 
 
 
 
4b4ff9e
 
 
 
06ed757
 
 
4b4ff9e
 
 
 
06ed757
8b8c11e
06ed757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Node: ask codestral (or any LLMProvider) for SQL given the schema context.

Builds the prompt from the active context bundle, dispatches it to the provider,
and parses the JSON response into a `GenerateSQLOutput`. This node powers only
the *initial* generation pass — the repair pass is a separate node that calls
the same provider with a different prompt.
"""

from __future__ import annotations

from collections.abc import Callable

from nl_sql.agent.nodes._support import (
    parse_generate_sql_output,
    render_fewshot_block,
    render_m_schema,
    render_schema_block,
)
from nl_sql.agent.prompts import load_prompt
from nl_sql.agent.state import PipelineState
from nl_sql.llm.providers.base import GenerateRequest, LLMProvider


def make_generate_sql_node(
    provider: LLMProvider,
    *,
    max_tokens: int = 1024,
    temperature: float = 0.0,
    sort_schema_block: bool = False,
    use_m_schema: bool = False,
    use_dac_prompt: bool = False,
) -> Callable[[PipelineState], PipelineState]:
    """Build the pipeline node that asks ``provider`` for an initial SQL draft.

    Args:
        provider: LLM backend; its ``generate`` method receives one
            ``GenerateRequest`` per node invocation.
        max_tokens: Token budget forwarded on every ``GenerateRequest``.
        temperature: Sampling temperature forwarded on every ``GenerateRequest``.
        sort_schema_block: Passed as ``sort_alphabetically`` to
            ``render_schema_block``; has no effect when ``use_m_schema`` is set.
        use_m_schema: Render the schema with ``render_m_schema`` instead of
            ``render_schema_block``.
        use_dac_prompt: Load the ``generate_sql_dac`` prompt template instead
            of ``generate_sql``.

    Returns:
        A function named ``node`` that maps a ``PipelineState`` to a partial
        state update holding ``generated``, ``outcome``, ``last_error`` and
        ``trace``.
    """
    # CHASE-SQL divide-and-conquer prompt — decomposes multi-clause questions
    # into sub-questions before composing SQL. Driven by
    # `PipelineConfig.use_dac_prompt`; api/main.py bootstraps from `NLSQL_DAC=1`.
    # The choice depends only on factory arguments, so it is made once here.
    template = "generate_sql_dac" if use_dac_prompt else "generate_sql"
    no_plan_placeholder = "(no plan — generate SQL directly from question)"

    def _schema_text(context):
        # M-Schema (XiYan-SQL compact) vs verbose card layout. Driven by
        # `PipelineConfig.use_m_schema`; api/main.py bootstraps the flag from
        # `NLSQL_M_SCHEMA=1` env so existing eval scripts keep working.
        if use_m_schema:
            return render_m_schema(context)
        return render_schema_block(context, sort_alphabetically=sort_schema_block)

    def node(state: PipelineState) -> PipelineState:
        context = state.get("context")
        plan_text = (state.get("plan") or "").strip()
        prompt = load_prompt(
            template,
            dialect=state.get("dialect", "sqlite"),
            schema_block=_schema_text(context),
            fewshot_block=render_fewshot_block(context),
            plan_block=plan_text or no_plan_placeholder,
            question=state.get("question", ""),
        )

        request = GenerateRequest(prompt=prompt, max_tokens=max_tokens, temperature=temperature)
        response = provider.generate(request)
        parsed = parse_generate_sql_output(response.text)

        step = {
            "node": "generate_sql",
            "model": response.model,
            "confidence": parsed.confidence,
            "tables_used": list(parsed.tables_used),
            "input_tokens": response.input_tokens,
            "output_tokens": response.output_tokens,
        }
        # Build a new list so the incoming state's trace is never mutated.
        trace = [*(state.get("trace") or []), step]

        # A fresh draft clears any outcome / error left by an earlier repair round.
        return {"generated": parsed, "outcome": None, "last_error": "", "trace": trace}

    return node