File size: 3,360 Bytes
"""Node: combine retrieve_schema + retrieve_examples into one ContextBundle.

Thin wrapper over `nl_sql.schema_index.retriever.retrieve_context`. Per arch v2
§3, this node also owns dialect-adapter hints (Postgres vs SQLite). For v1 we
just pass the dialect through state; the prompt assembler will pick
dialect-specific phrasing once we have observed model failure modes during eval.
"""
from __future__ import annotations
from collections.abc import Callable
from sqlalchemy.engine import Engine
from nl_sql.agent.state import PipelineState
from nl_sql.db.registry import DatabaseRegistry
from nl_sql.schema_index.indexer import SchemaIndex
from nl_sql.schema_index.retriever import retrieve_context
def make_context_builder_node(
    index: SchemaIndex,
    *,
    schema_top_k: int = 5,
    fewshot_top_k: int = 3,
    fk_hops: int = 1,
    table_budget: int = 12,
    registry: DatabaseRegistry | None = None,
    primary_sample_size: int = 3,
    extended_sample_size: int = 0,
    cross_db_fewshot: bool = False,
) -> Callable[[PipelineState], PipelineState]:
    """Build the pipeline node that assembles retrieval context for a question.

    The returned callable reads ``question`` and ``db_id`` from the pipeline
    state and forwards them, together with the retrieval knobs captured here,
    to ``retrieve_context``. It returns a partial state update containing
    ``context`` and an extended ``trace``. ``context`` is the resulting bundle,
    or ``None`` when either input is missing or empty.

    Sample mixture: only when ``registry`` is supplied *and*
    ``extended_sample_size`` is strictly greater than ``primary_sample_size``
    does the node create an engine through ``registry.get(db_id).make_engine()``.
    It passes that engine to ``retrieve_context`` so the bundle can carry an
    "extended samples" appendix, which ``render_schema_block`` formats as a
    supplementary block. The engine is disposed once retrieval finishes, even
    if retrieval raises. With the default arguments the mixture path is off
    and no engine is ever created.

    Args:
        index: Schema index searched by ``retrieve_context``.
        schema_top_k: Forwarded to ``retrieve_context``.
        fewshot_top_k: Forwarded to ``retrieve_context``.
        fk_hops: Forwarded to ``retrieve_context``.
        table_budget: Forwarded to ``retrieve_context``.
        registry: Database registry used to open a per-call engine; ``None``
            disables the sample mixture.
        primary_sample_size: Forwarded to ``retrieve_context``.
        extended_sample_size: Forwarded to ``retrieve_context``; must exceed
            ``primary_sample_size`` (with a registry) to enable the mixture.
        cross_db_fewshot: Forwarded to ``retrieve_context``.

    Returns:
        A ``node(state) -> state`` callable suitable for the pipeline graph.
    """
    use_sample_mixture = registry is not None and extended_sample_size > primary_sample_size

    def node(state: PipelineState) -> PipelineState:
        question = state.get("question", "")
        db_id = state.get("db_id", "")

        # Guard clause: nothing to retrieve without both a question and a db.
        if not (question and db_id):
            return {
                "context": None,
                "trace": _append_trace(state, "context_builder", note="missing question or db_id"),
            }

        # The `registry is not None` test is implied by `use_sample_mixture`;
        # it is repeated here only so type checkers can narrow `registry`.
        engine: Engine | None = (
            registry.get(db_id).make_engine()
            if use_sample_mixture and registry is not None
            else None
        )
        try:
            bundle = retrieve_context(
                index,
                question,
                db_id=db_id,
                schema_top_k=schema_top_k,
                fewshot_top_k=fewshot_top_k,
                fk_hops=fk_hops,
                table_budget=table_budget,
                engine=engine,
                primary_sample_size=primary_sample_size,
                extended_sample_size=extended_sample_size,
                cross_db_fewshot=cross_db_fewshot,
            )
        finally:
            # Release the per-call engine whether or not retrieval succeeded.
            if engine is not None:
                engine.dispose()

        trace = _append_trace(
            state,
            "context_builder",
            tables=bundle.all_tables,
            fewshots=len(bundle.fewshots),
            truncated=bundle.truncated,
            extended_sample_tables=(
                sorted(bundle.extended_samples) if bundle.extended_samples else []
            ),
        )
        return {"context": bundle, "trace": trace}

    return node
def _append_trace(state: PipelineState, node: str, **details: object) -> list[dict[str, object]]:
trace = list(state.get("trace") or [])
trace.append({"node": node, **details})
return trace
|