nl-sql / scripts /eval_demo.py
liovina's picture
Deploy NL_SQL HEAD to HF Space
942050b verified
"""Evaluate the pipeline on the curated demo benchmark.
Unlike `eval_baseline.py`, which measures BIRD Mini-Dev (research baseline
~50% ceiling), this script measures a *product workload*: 30 realistic
business questions on Chinook where we target ≥90% Execution Accuracy.
Usage:
uv run python scripts/eval_demo.py
uv run python scripts/eval_demo.py --benchmark eval/demo_benchmark.json
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Any
import chromadb
from nl_sql.agent import PipelineConfig, build_pipeline, run_pipeline
from nl_sql.config import get_settings
from nl_sql.db.connection import Dialect, execute_readonly
from nl_sql.db.registry import get_default_registry
from nl_sql.eval.metrics.execution_accuracy import compare_results
from nl_sql.llm.cache import CachingEmbeddingProvider, CachingLLMProvider
from nl_sql.llm.providers.mistral import MistralProvider
from nl_sql.schema_index.indexer import SchemaIndex
# Default benchmark file: go two directories up from this script's path, then
# into eval/demo_benchmark.json.
DEFAULT_BENCHMARK = Path(__file__).parent.parent.joinpath("eval", "demo_benchmark.json")
def _build_providers(settings: Any, *, no_cache: bool) -> tuple[Any, Any, Any]:
    """Create the ``(embedder, sql_provider, explain_provider)`` triple.

    Each role gets its own ``MistralProvider``. The embedder and the SQL
    provider use ``settings.mistral_gen_model``. The explanation provider is
    pinned to ``"mistral-large-latest"``. Unless ``no_cache`` is true, each
    provider is wrapped in the on-disk cache configured by
    ``settings.llm_cache_dir`` / ``settings.llm_cache_size_limit_gb``.
    """

    def mistral(gen_model: str) -> MistralProvider:
        return MistralProvider(
            api_key=settings.mistral_api_key,
            gen_model=gen_model,
            embed_model=settings.mistral_embed_model,
            base_url=settings.mistral_base_url,
        )

    raw_embed = mistral(settings.mistral_gen_model)
    raw_sql = mistral(settings.mistral_gen_model)
    raw_explain = mistral("mistral-large-latest")
    if no_cache:
        return raw_embed, raw_sql, raw_explain
    cache_opts = {
        "cache_dir": settings.llm_cache_dir,
        "size_limit_gb": settings.llm_cache_size_limit_gb,
    }
    return (
        CachingEmbeddingProvider(raw_embed, **cache_opts),
        CachingLLMProvider(raw_sql, **cache_opts),
        CachingLLMProvider(raw_explain, **cache_opts),
    )


def _make_record(
    q: dict[str, Any],
    *,
    pred_sql: str,
    match: bool,
    reason: str,
    latency_ms: float,
) -> dict[str, Any]:
    """Build one per-question report row (same keys on success and failure)."""
    return {
        "id": q["id"],
        "category": q.get("category", ""),
        "difficulty": q.get("difficulty", ""),
        "split": q.get("split", "dev"),
        "question": q["question"],
        "gold_sql": q["gold_sql"],
        "pred_sql": pred_sql,
        "match": match,
        "reason": reason,
        "latency_ms": latency_ms,
    }


def _evaluate_question(
    pipeline: Any,
    gold_engine: Any,
    q: dict[str, Any],
    *,
    db_id: str,
    dialect: Dialect,
    position: int,
    total: int,
) -> dict[str, Any]:
    """Run one question through the pipeline, score it against gold, print progress.

    Returns the report row for the question. An exception raised by the
    pipeline is recorded as a miss rather than aborting the run. Errors from
    the gold query itself still propagate.
    """
    prefix = f" [{position:>2}/{total}]"
    t0 = time.perf_counter()
    try:
        result = run_pipeline(
            pipeline, question=q["question"], db_id=db_id, dialect=dialect
        )
    except Exception as exc:  # one bad question must not abort the whole benchmark
        elapsed = (time.perf_counter() - t0) * 1000
        print(f"{prefix} EXCEPTION {q['id']}: {exc}")
        return _make_record(
            q,
            pred_sql="",
            match=False,
            reason=f"pipeline raised: {exc!r}",
            latency_ms=elapsed,
        )
    # Stop the clock before running the gold query so latency_ms measures only
    # the pipeline, the same as on the exception path above.
    elapsed = (time.perf_counter() - t0) * 1000
    # Normalise a missing prediction to "" (as the exception path does) so the
    # slicing below cannot raise on None.
    pred_sql = result.sql or ""

    with execute_readonly(
        gold_engine, q["gold_sql"], statement_timeout_ms=30_000, row_cap=10_000
    ) as gold:
        gold_rows = list(gold.rows)

    if result.outcome is not None and result.outcome.result is not None:
        cmp = compare_results(
            gold_rows, result.outcome.result.rows, gold_sql=q["gold_sql"]
        )
        match, reason = cmp.match, cmp.reason
    else:
        match = False
        kind = result.error_kind.value if result.error_kind else "unknown"
        reason = f"pred failed: {kind}"

    flag = "OK " if match else "MISS"
    print(f"{prefix} {flag} ({elapsed:5.0f}ms) {q['id']}: {q['question'][:70]}")
    if not match:
        print(f" gold: {q['gold_sql'][:140]}")
        print(f" pred: {pred_sql[:140]}")
        print(f" why: {reason}")
    return _make_record(
        q, pred_sql=pred_sql, match=match, reason=reason, latency_ms=elapsed
    )


def _print_breakdown(title: str, records: list[dict[str, Any]], key: str) -> None:
    """Print match count and accuracy for ``records`` grouped by ``record[key]``."""
    groups: dict[str, list[bool]] = defaultdict(list)
    for r in records:
        groups[r[key]].append(r["match"])
    print(f"{title}:")
    for name, ms in sorted(groups.items()):
        print(f" {name:14s} {sum(ms):>2}/{len(ms):<2} ({sum(ms) / len(ms) * 100:5.1f}%)")


def main(argv: list[str] | None = None) -> int:
    """Evaluate the NL->SQL pipeline on the demo benchmark and print a report.

    Loads the benchmark JSON and builds the pipeline over the persisted Chroma
    schema index. It then runs every question, compares the predicted result
    rows with the gold SQL's rows, and prints overall, per-category,
    per-difficulty and per-split accuracy. With ``--report``, it also writes
    the per-question records as JSON.

    Args:
        argv: command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 if Execution Accuracy >= ``--target`` (default
        0.9), 1 if it falls short, 2 if ``MISTRAL_API_KEY`` is unset, 3 if the
        index directory does not exist.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--benchmark",
        type=Path,
        default=DEFAULT_BENCHMARK,
        help=f"path to benchmark JSON (default: {DEFAULT_BENCHMARK})",
    )
    parser.add_argument("--persist", default="chroma_data")
    parser.add_argument("--no-cache", action="store_true")
    parser.add_argument(
        "--report",
        type=Path,
        help="optional output JSON path for the full per-question report",
    )
    parser.add_argument(
        "--target",
        type=float,
        default=0.9,
        help="minimum Execution Accuracy (0-1) for a zero exit code (default: 0.9)",
    )
    args = parser.parse_args(argv)

    settings = get_settings()
    if not settings.mistral_api_key:
        print("[error] MISTRAL_API_KEY not set in .env", file=sys.stderr)
        return 2

    bench = json.loads(args.benchmark.read_text(encoding="utf-8"))
    db_id = bench["db_id"]
    dialect: Dialect = bench.get("dialect", "sqlite")
    questions = bench["questions"]
    print(f"[info] benchmark: {bench['name']} ({len(questions)} questions on {db_id!r})")

    persist = Path(args.persist)
    if not persist.is_dir():
        # Diagnostics go to stderr, like the missing-API-key error above.
        print(
            f"[error] index not found at {persist}; run scripts/build_index.py first",
            file=sys.stderr,
        )
        return 3

    client = chromadb.PersistentClient(path=str(persist))
    embedder, sql_provider, explain_provider = _build_providers(
        settings, no_cache=args.no_cache
    )
    index = SchemaIndex(persist_dir=persist, embedder=embedder, client=client)
    registry = get_default_registry()
    pipeline = build_pipeline(
        PipelineConfig(
            sql_provider=sql_provider,
            explain_provider=explain_provider,
            schema_index=index,
            registry=registry,
            schema_top_k=5,
            fewshot_top_k=0,
            fk_hops=1,
            table_budget=12,
            sort_schema_block=True,
            primary_sample_size=3,
        )
    )

    gold_engine = registry.get(db_id).make_engine()
    records: list[dict[str, Any]] = []
    started_all = time.perf_counter()
    try:
        for i, q in enumerate(questions, start=1):
            records.append(
                _evaluate_question(
                    pipeline,
                    gold_engine,
                    q,
                    db_id=db_id,
                    dialect=dialect,
                    position=i,
                    total=len(questions),
                )
            )
    finally:
        # Dispose of the gold engine even if a gold query raises mid-run.
        gold_engine.dispose()
    elapsed_total = time.perf_counter() - started_all

    matches = sum(1 for r in records if r["match"])
    ea = matches / len(records) if records else 0.0
    print()
    print("=" * 78)
    print(f"Demo benchmark: {bench['name']}")
    print(f"DB: {db_id} ({dialect})")
    print(f"Questions: {len(records)}")
    print(f"Match: {matches}/{len(records)} = {ea * 100:.1f}%")
    _print_breakdown("per category", records, "category")
    _print_breakdown("per difficulty", records, "difficulty")
    _print_breakdown("per split", records, "split")
    print(f"Wall time: {elapsed_total:.1f}s")

    if args.report:
        args.report.parent.mkdir(parents=True, exist_ok=True)
        args.report.write_text(
            json.dumps(
                {
                    "benchmark": bench["name"],
                    "db_id": db_id,
                    "dialect": dialect,
                    "n": len(records),
                    "matches": matches,
                    "ea": ea,
                    "records": records,
                },
                indent=2,
                ensure_ascii=False,
            ),
            encoding="utf-8",
        )
        print(f"[report] {args.report}")
    return 0 if ea >= args.target else 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (0 = target met).
    raise SystemExit(main())