| """Node: combine retrieve_schema + retrieve_examples into one ContextBundle. |
| |
| Thin wrapper over `nl_sql.schema_index.retrieve_context`. Per arch v2 §3, |
| this node also owns dialect-adapter hints (Postgres vs SQLite). For v1 we |
| just pass dialect through state — the prompt assembler picks dialect-specific |
| phrasing once we observe model failure modes during eval. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Callable |
|
|
| from sqlalchemy.engine import Engine |
|
|
| from nl_sql.agent.state import PipelineState |
| from nl_sql.db.registry import DatabaseRegistry |
| from nl_sql.schema_index.indexer import SchemaIndex |
| from nl_sql.schema_index.retriever import retrieve_context |
|
|
|
|
def make_context_builder_node(
    index: SchemaIndex,
    *,
    schema_top_k: int = 5,
    fewshot_top_k: int = 3,
    fk_hops: int = 1,
    table_budget: int = 12,
    registry: DatabaseRegistry | None = None,
    primary_sample_size: int = 3,
    extended_sample_size: int = 0,
    cross_db_fewshot: bool = False,
) -> Callable[[PipelineState], PipelineState]:
    """Build the pipeline node that assembles retrieval context for a question.

    The returned callable reads `question` and `db_id` from the state, calls
    `retrieve_context` with the knobs captured here, and returns a partial
    state update holding the resulting bundle under `"context"` plus an
    extended `"trace"` list.

    Extended-sample mixture: only when a `registry` is supplied AND
    `extended_sample_size > primary_sample_size` does the node open an
    engine for the question's database (via `registry.get(db_id).make_engine()`)
    and hand it to `retrieve_context`, so an "extended samples" appendix can
    be attached to the bundle. With the defaults (no registry, extended size
    0) no engine is ever opened and `engine=None` is passed through.

    Args:
        index: Schema index forwarded to `retrieve_context`.
        schema_top_k: Forwarded to `retrieve_context` unchanged.
        fewshot_top_k: Forwarded to `retrieve_context` unchanged.
        fk_hops: Forwarded to `retrieve_context` unchanged.
        table_budget: Forwarded to `retrieve_context` unchanged.
        registry: Source of per-database engines; required for the mixture.
        primary_sample_size: Forwarded; also the threshold the extended size
            must exceed for the mixture to switch on.
        extended_sample_size: Forwarded; see the mixture rule above.
        cross_db_fewshot: Forwarded to `retrieve_context` unchanged.

    Returns:
        A node function mapping a `PipelineState` to a partial state update.
    """
    # Decided once at construction time: both conditions depend only on the
    # factory arguments, never on the per-question state.
    if registry is None:
        attach_extended = False
    else:
        attach_extended = extended_sample_size > primary_sample_size

    def node(state: PipelineState) -> PipelineState:
        question = state.get("question", "")
        db_id = state.get("db_id", "")

        # Guard clause: without both inputs there is nothing to retrieve.
        if not (question and db_id):
            skipped = _append_trace(
                state, "context_builder", note="missing question or db_id"
            )
            return {"context": None, "trace": skipped}

        # The engine is created per call and only when the mixture is on.
        # `attach_extended` already implies a registry; the explicit check
        # narrows the optional type inside this closure.
        sample_engine: Engine | None = None
        if attach_extended and registry is not None:
            sample_engine = registry.get(db_id).make_engine()

        try:
            bundle = retrieve_context(
                index,
                question,
                db_id=db_id,
                engine=sample_engine,
                schema_top_k=schema_top_k,
                fewshot_top_k=fewshot_top_k,
                fk_hops=fk_hops,
                table_budget=table_budget,
                primary_sample_size=primary_sample_size,
                extended_sample_size=extended_sample_size,
                cross_db_fewshot=cross_db_fewshot,
            )
        finally:
            # Dispose the per-call engine whether or not retrieval succeeded.
            if sample_engine is not None:
                sample_engine.dispose()

        # Key order matches the order the details appear in the trace entry.
        details: dict[str, object] = {
            "tables": bundle.all_tables,
            "fewshots": len(bundle.fewshots),
            "truncated": bundle.truncated,
            "extended_sample_tables": (
                sorted(bundle.extended_samples) if bundle.extended_samples else []
            ),
        }
        return {
            "context": bundle,
            "trace": _append_trace(state, "context_builder", **details),
        }

    return node
|
|
|
|
def _append_trace(state: PipelineState, node: str, **details: object) -> list[dict[str, object]]:
    """Return a new trace list: the state's existing entries plus one for `node`.

    The existing `"trace"` value is copied rather than mutated; a missing or
    falsy trace is treated as empty. The new entry is a dict whose `"node"`
    key holds `node`, followed by every keyword passed in `details`.
    """
    history = state.get("trace") or []
    entry: dict[str, object] = {"node": node}
    entry.update(details)
    # Unpack into a fresh list so the caller's state is left untouched.
    return [*history, entry]
|
|