File size: 3,360 Bytes
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Node: combine retrieve_schema + retrieve_examples into one ContextBundle.

Thin wrapper over `nl_sql.schema_index.retrieve_context`. Per arch v2 §3,
this node also owns dialect-adapter hints (Postgres vs SQLite). For v1 we
just pass dialect through state — the prompt assembler picks dialect-specific
phrasing once we observe model failure modes during eval.
"""

from __future__ import annotations

from collections.abc import Callable

from sqlalchemy.engine import Engine

from nl_sql.agent.state import PipelineState
from nl_sql.db.registry import DatabaseRegistry
from nl_sql.schema_index.indexer import SchemaIndex
from nl_sql.schema_index.retriever import retrieve_context


def make_context_builder_node(
    index: SchemaIndex,
    *,
    schema_top_k: int = 5,
    fewshot_top_k: int = 3,
    fk_hops: int = 1,
    table_budget: int = 12,
    registry: DatabaseRegistry | None = None,
    primary_sample_size: int = 3,
    extended_sample_size: int = 0,
    cross_db_fewshot: bool = False,
) -> Callable[[PipelineState], PipelineState]:
    """Build the pipeline node that fills `context` via `retrieve_context`.

    The returned callable reads `question` and `db_id` from the state and
    returns a partial state update with two keys: `context` (the retrieved
    bundle, or None when either input is missing/empty) and `trace` (the
    existing trace plus one "context_builder" entry).

    Sample mixture: only when a `registry` is supplied AND
    `extended_sample_size > primary_sample_size` does the node create an
    engine for the question's database and hand it to `retrieve_context`;
    that engine is disposed as soon as retrieval finishes, whether it
    succeeded or raised. In every other configuration — including the
    defaults — `engine=None` is passed and no engine is created.

    All remaining keyword arguments are forwarded to `retrieve_context`
    unchanged on every call.
    """

    # Resolve the mixture switch once: hold on to the registry only when the
    # extended-sample path is actually active, so the node has a single
    # optional to test per call.
    sample_registry = (
        registry
        if registry is not None and extended_sample_size > primary_sample_size
        else None
    )

    def node(state: PipelineState) -> PipelineState:
        question = state.get("question", "")
        db_id = state.get("db_id", "")

        # Guard: without both inputs there is nothing to retrieve.
        if not (question and db_id):
            skipped_trace = _append_trace(
                state, "context_builder", note="missing question or db_id"
            )
            return {"context": None, "trace": skipped_trace}

        engine: Engine | None = (
            sample_registry.get(db_id).make_engine() if sample_registry is not None else None
        )
        try:
            bundle = retrieve_context(
                index,
                question,
                db_id=db_id,
                schema_top_k=schema_top_k,
                fewshot_top_k=fewshot_top_k,
                fk_hops=fk_hops,
                table_budget=table_budget,
                engine=engine,
                primary_sample_size=primary_sample_size,
                extended_sample_size=extended_sample_size,
                cross_db_fewshot=cross_db_fewshot,
            )
        finally:
            # The engine is created per call, so release it per call too.
            if engine is not None:
                engine.dispose()

        updated_trace = _append_trace(
            state,
            "context_builder",
            tables=bundle.all_tables,
            fewshots=len(bundle.fewshots),
            truncated=bundle.truncated,
            # An absent or empty appendix is recorded as an empty list.
            extended_sample_tables=sorted(bundle.extended_samples or []),
        )
        return {"context": bundle, "trace": updated_trace}

    return node


def _append_trace(state: PipelineState, node: str, **details: object) -> list[dict[str, object]]:
    trace = list(state.get("trace") or [])
    trace.append({"node": node, **details})
    return trace