| """Evaluate the pipeline on the curated demo benchmark. |
| |
| Unlike `eval_baseline.py`, which measures BIRD Mini-Dev (research baseline |
| ~50% ceiling), this script measures a *product workload*: 30 realistic |
| business questions on Chinook where we target ≥90% Execution Accuracy. |
| |
| Usage: |
| uv run python scripts/eval_demo.py |
| uv run python scripts/eval_demo.py --benchmark eval/demo_benchmark.json |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| import time |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Any |
|
|
| import chromadb |
|
|
| from nl_sql.agent import PipelineConfig, build_pipeline, run_pipeline |
| from nl_sql.config import get_settings |
| from nl_sql.db.connection import Dialect, execute_readonly |
| from nl_sql.db.registry import get_default_registry |
| from nl_sql.eval.metrics.execution_accuracy import compare_results |
| from nl_sql.llm.cache import CachingEmbeddingProvider, CachingLLMProvider |
| from nl_sql.llm.providers.mistral import MistralProvider |
| from nl_sql.schema_index.indexer import SchemaIndex |
|
|
| DEFAULT_BENCHMARK = Path(__file__).parent.parent / "eval" / "demo_benchmark.json" |
|
|
|
|
| def main(argv: list[str] | None = None) -> int: |
| parser = argparse.ArgumentParser(description=__doc__) |
| parser.add_argument( |
| "--benchmark", |
| type=Path, |
| default=DEFAULT_BENCHMARK, |
| help=f"path to benchmark JSON (default: {DEFAULT_BENCHMARK})", |
| ) |
| parser.add_argument("--persist", default="chroma_data") |
| parser.add_argument("--no-cache", action="store_true") |
| parser.add_argument( |
| "--report", |
| type=Path, |
| help="optional output JSON path for the full per-question report", |
| ) |
| args = parser.parse_args(argv) |
|
|
| settings = get_settings() |
| if not settings.mistral_api_key: |
| print("[error] MISTRAL_API_KEY not set in .env", file=sys.stderr) |
| return 2 |
|
|
| bench = json.loads(args.benchmark.read_text(encoding="utf-8")) |
| db_id = bench["db_id"] |
| dialect: Dialect = bench.get("dialect", "sqlite") |
| questions = bench["questions"] |
| print(f"[info] benchmark: {bench['name']} ({len(questions)} questions on {db_id!r})") |
|
|
| persist = Path(args.persist) |
| if not persist.is_dir(): |
| print(f"[error] index not found at {persist}; run scripts/build_index.py first") |
| return 3 |
|
|
| client = chromadb.PersistentClient(path=str(persist)) |
| raw_embed = MistralProvider( |
| api_key=settings.mistral_api_key, |
| gen_model=settings.mistral_gen_model, |
| embed_model=settings.mistral_embed_model, |
| base_url=settings.mistral_base_url, |
| ) |
| raw_sql = MistralProvider( |
| api_key=settings.mistral_api_key, |
| gen_model=settings.mistral_gen_model, |
| embed_model=settings.mistral_embed_model, |
| base_url=settings.mistral_base_url, |
| ) |
| raw_explain = MistralProvider( |
| api_key=settings.mistral_api_key, |
| gen_model="mistral-large-latest", |
| embed_model=settings.mistral_embed_model, |
| base_url=settings.mistral_base_url, |
| ) |
| embedder = ( |
| raw_embed |
| if args.no_cache |
| else CachingEmbeddingProvider( |
| raw_embed, |
| cache_dir=settings.llm_cache_dir, |
| size_limit_gb=settings.llm_cache_size_limit_gb, |
| ) |
| ) |
| sql_provider = ( |
| raw_sql |
| if args.no_cache |
| else CachingLLMProvider( |
| raw_sql, |
| cache_dir=settings.llm_cache_dir, |
| size_limit_gb=settings.llm_cache_size_limit_gb, |
| ) |
| ) |
| explain_provider = ( |
| raw_explain |
| if args.no_cache |
| else CachingLLMProvider( |
| raw_explain, |
| cache_dir=settings.llm_cache_dir, |
| size_limit_gb=settings.llm_cache_size_limit_gb, |
| ) |
| ) |
|
|
| index = SchemaIndex(persist_dir=persist, embedder=embedder, client=client) |
| registry = get_default_registry() |
| pipeline = build_pipeline( |
| PipelineConfig( |
| sql_provider=sql_provider, |
| explain_provider=explain_provider, |
| schema_index=index, |
| registry=registry, |
| schema_top_k=5, |
| fewshot_top_k=0, |
| fk_hops=1, |
| table_budget=12, |
| sort_schema_block=True, |
| primary_sample_size=3, |
| ) |
| ) |
|
|
| spec = registry.get(db_id) |
| gold_engine = spec.make_engine() |
|
|
| records: list[dict[str, Any]] = [] |
| started_all = time.perf_counter() |
| try: |
| for i, q in enumerate(questions, start=1): |
| t0 = time.perf_counter() |
| try: |
| result = run_pipeline( |
| pipeline, question=q["question"], db_id=db_id, dialect=dialect |
| ) |
| except Exception as exc: |
| elapsed = (time.perf_counter() - t0) * 1000 |
| records.append( |
| { |
| "id": q["id"], |
| "category": q.get("category", ""), |
| "difficulty": q.get("difficulty", ""), |
| "split": q.get("split", "dev"), |
| "question": q["question"], |
| "gold_sql": q["gold_sql"], |
| "pred_sql": "", |
| "match": False, |
| "reason": f"pipeline raised: {exc!r}", |
| "latency_ms": elapsed, |
| } |
| ) |
| print(f" [{i:>2}/{len(questions)}] EXCEPTION {q['id']}: {exc}") |
| continue |
|
|
| with execute_readonly( |
| gold_engine, q["gold_sql"], statement_timeout_ms=30_000, row_cap=10_000 |
| ) as gold: |
| gold_rows = list(gold.rows) |
|
|
| if result.outcome is not None and result.outcome.result is not None: |
| cmp = compare_results(gold_rows, result.outcome.result.rows, gold_sql=q["gold_sql"]) |
| match = cmp.match |
| reason = cmp.reason |
| else: |
| match = False |
| reason = ( |
| f"pred failed: {result.error_kind.value if result.error_kind else 'unknown'}" |
| ) |
|
|
| elapsed = (time.perf_counter() - t0) * 1000 |
| flag = "OK " if match else "MISS" |
| print( |
| f" [{i:>2}/{len(questions)}] {flag} ({elapsed:5.0f}ms) {q['id']} — {q['question'][:70]}" |
| ) |
| if not match: |
| print(f" gold: {q['gold_sql'][:140]}") |
| print(f" pred: {result.sql[:140]}") |
| print(f" why: {reason}") |
| records.append( |
| { |
| "id": q["id"], |
| "category": q.get("category", ""), |
| "difficulty": q.get("difficulty", ""), |
| "split": q.get("split", "dev"), |
| "question": q["question"], |
| "gold_sql": q["gold_sql"], |
| "pred_sql": result.sql, |
| "match": match, |
| "reason": reason, |
| "latency_ms": elapsed, |
| } |
| ) |
| finally: |
| gold_engine.dispose() |
|
|
| elapsed_total = time.perf_counter() - started_all |
| matches = sum(1 for r in records if r["match"]) |
| ea = matches / len(records) if records else 0.0 |
| print() |
| print("=" * 78) |
| print(f"Demo benchmark: {bench['name']}") |
| print(f"DB: {db_id} ({dialect})") |
| print(f"Questions: {len(records)}") |
| print(f"Match: {matches}/{len(records)} = {ea * 100:.1f}%") |
| by_cat: dict[str, list[bool]] = defaultdict(list) |
| by_diff: dict[str, list[bool]] = defaultdict(list) |
| by_split: dict[str, list[bool]] = defaultdict(list) |
| for r in records: |
| by_cat[r["category"]].append(r["match"]) |
| by_diff[r["difficulty"]].append(r["match"]) |
| by_split[r["split"]].append(r["match"]) |
| print("per category:") |
| for cat, ms in sorted(by_cat.items()): |
| print(f" {cat:14s} {sum(ms):>2}/{len(ms):<2} ({sum(ms) / len(ms) * 100:5.1f}%)") |
| print("per difficulty:") |
| for d, ms in sorted(by_diff.items()): |
| print(f" {d:14s} {sum(ms):>2}/{len(ms):<2} ({sum(ms) / len(ms) * 100:5.1f}%)") |
| print("per split:") |
| for s, ms in sorted(by_split.items()): |
| print(f" {s:14s} {sum(ms):>2}/{len(ms):<2} ({sum(ms) / len(ms) * 100:5.1f}%)") |
| print(f"Wall time: {elapsed_total:.1f}s") |
|
|
| if args.report: |
| args.report.parent.mkdir(parents=True, exist_ok=True) |
| args.report.write_text( |
| json.dumps( |
| { |
| "benchmark": bench["name"], |
| "db_id": db_id, |
| "dialect": dialect, |
| "n": len(records), |
| "matches": matches, |
| "ea": ea, |
| "records": records, |
| }, |
| indent=2, |
| ensure_ascii=False, |
| ), |
| encoding="utf-8", |
| ) |
| print(f"[report] {args.report}") |
|
|
| return 0 if ea >= 0.9 else 1 |
|
|
|
|
| if __name__ == "__main__": |
| sys.exit(main()) |
|
|