"""Drop-in eval driver for plan-then-SQL ablation. Why a dedicated script instead of `scripts/eval_baseline.py --config G`: we need the `enable_planner=True` knob on PipelineConfig which the existing driver doesn't surface yet, and we want robust progress logging + resumable JSON output without the background-shell-pipe issues we hit when running long evals via the harness. Usage: uv run python scripts/run_planner_eval.py \\ --difficulty moderate --n 200 --seed 0 \\ --out eval/reports/2026-05-11/G_planner-moderate-n99.json """ from __future__ import annotations import argparse import json import sys import time import traceback from pathlib import Path from nl_sql.agent.graph import PipelineConfig, build_pipeline, run_pipeline from nl_sql.config import get_settings from nl_sql.db.registry import get_default_registry from nl_sql.eval.dataset import dev_split, load_bird_mini_dev from nl_sql.eval.metrics.execution_accuracy import compare_results from nl_sql.eval.runner import _compose_question, _execute_gold from nl_sql.llm.cache import CachingEmbeddingProvider, CachingLLMProvider from nl_sql.llm.providers import build_provider from nl_sql.llm.providers.mistral import MistralProvider from nl_sql.schema_index.indexer import SchemaIndex def main() -> int: p = argparse.ArgumentParser(description=__doc__) p.add_argument("--difficulty", choices=["simple", "moderate", "challenging"], default=None) p.add_argument("--n", type=int, default=200, help="prefix size BEFORE difficulty filter") p.add_argument("--seed", type=int, default=0) p.add_argument("--out", type=Path, required=True) p.add_argument( "--log", type=Path, default=None, help="per-example progress log; default .progress.log", ) p.add_argument("--enable-planner", action="store_true", default=False) p.add_argument("--no-planner", dest="enable_planner", action="store_false") p.add_argument("--enable-grounded-critique", action="store_true", default=False) p.add_argument("--bird-root", default="data/bird_mini_dev/MINIDEV") p.add_argument("--provider", default="mistral") p.add_argument("--limit", type=int, default=0, help="cap examples after filtering (0=all)") args = p.parse_args() log_path = args.log or args.out.with_suffix(".progress.log") log_path.parent.mkdir(parents=True, exist_ok=True) args.out.parent.mkdir(parents=True, exist_ok=True) s = get_settings() sql_prov = CachingLLMProvider( build_provider(args.provider, settings=s), cache_dir=s.llm_cache_dir ) emb = CachingEmbeddingProvider( MistralProvider(api_key=s.mistral_api_key), cache_dir=s.llm_cache_dir ) idx = SchemaIndex(persist_dir="chroma_data", embedder=emb) registry = get_default_registry() examples = load_bird_mini_dev(Path(args.bird_root)) sample = dev_split(examples, n=args.n, seed=args.seed) if args.difficulty: sample = [e for e in sample if e.difficulty == args.difficulty] if args.limit: sample = sample[: args.limit] cfg = PipelineConfig( sql_provider=sql_prov, explain_provider=sql_prov, schema_index=idx, registry=registry, fewshot_top_k=3, sort_schema_block=True, cross_db_fewshot=True, verify_retry_on_empty=True, enable_planner=args.enable_planner, enable_grounded_critique=args.enable_grounded_critique, statement_timeout_ms=30_000, row_cap=10_000, ) pipe = build_pipeline(cfg) def log(msg: str) -> None: ts = time.strftime("%H:%M:%S") line = f"[{ts}] {msg}\n" log_path.open("a", encoding="utf-8").write(line) sys.stderr.write(line) sys.stderr.flush() log( f"start: n={len(sample)} difficulty={args.difficulty} enable_planner={args.enable_planner} out={args.out}" ) records: list[dict] = [] matched = 0 for i, ex in enumerate(sample, 1): started = time.perf_counter() spec = registry.get(ex.registry_db_id) gold_engine = spec.make_engine() try: try: res = run_pipeline( pipe, question=_compose_question(ex), db_id=ex.registry_db_id, dialect="sqlite", verify_retry_on_empty=True, ) except Exception as exc: log(f"[{i:3d}/{len(sample)}] EXC qid={ex.question_id}: {type(exc).__name__}: {exc}") traceback.print_exc(file=sys.stderr) continue try: gold_rows, _ = _execute_gold( gold_engine, ex.sql, statement_timeout_ms=30_000, row_cap=10_000 ) except Exception: gold_rows = [] if res.outcome is not None and res.outcome.result is not None: cmp = compare_results(gold_rows, res.outcome.result.rows, gold_sql=ex.sql) ok = cmp.match reason = cmp.reason gc, pc = cmp.gold_rows, cmp.pred_rows else: ok = False reason = res.error_kind.value if res.error_kind else "no result" gc, pc = len(gold_rows), 0 if ok: matched += 1 records.append( { "question_id": ex.question_id, "db_id": ex.db_id, "difficulty": ex.difficulty, "dialect": ex.dialect, "question": ex.question, "gold_sql": ex.sql, "pred_sql": res.sql, "match": bool(ok), "comparison_reason": reason, "gold_row_count": gc, "pred_row_count": pc, "error_kind": res.error_kind.value if res.error_kind else None, "confidence": res.confidence, "repair_attempted": res.repair_attempted, } ) elapsed = (time.perf_counter() - started) * 1000.0 log( f"[{i:3d}/{len(sample)}] {'OK ' if ok else ' '} ({elapsed:6.0f}ms) " f"qid={ex.question_id} {ex.registry_db_id}/{ex.difficulty} — " f"{ex.question[:60]}" ) # incremental dump every 10 to survive crashes if i % 10 == 0: args.out.write_text( json.dumps( { "configuration": "G_planner", "sql_model": "codestral-latest", "overall": {"ea": matched / len(records), "n": len(records)}, "records": records, }, indent=2, ), encoding="utf-8", ) finally: gold_engine.dispose() ea = matched / len(records) if records else 0.0 args.out.write_text( json.dumps( { "configuration": "G_planner", "sql_model": "codestral-latest", "overall": {"ea": ea, "n": len(records), "matched": matched}, "records": records, }, indent=2, ), encoding="utf-8", ) log(f"done: EA={matched}/{len(records)} = {100 * ea:.1f}% → {args.out}") return 0 if __name__ == "__main__": raise SystemExit(main())