File size: 17,198 Bytes
942050b
 
 
 
 
 
 
 
 
 
 
d48602c
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d48602c
 
 
 
 
 
 
 
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d48602c
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d48602c
 
 
 
 
 
 
 
 
 
 
 
 
 
942050b
 
 
 
 
 
 
 
 
 
 
 
 
d48602c
 
942050b
 
 
 
 
 
 
 
 
d48602c
 
 
 
 
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d48602c
942050b
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
"""Live eval-baseline: configurations A or C on a BIRD Mini-Dev slice.

Runs one ablation configuration (A, C, D, E, F or G; see `--config`) through
the selected LLM provider (`--provider`, default: mistral) against N BIRD
examples (default 50), prints per-question status, and writes both JSON and
HTML artefacts to `eval/reports/<date>/`. Configuration B is not yet
implemented; it will join the same CLI shape when it ships.

Usage:
    uv run python scripts/eval_baseline.py --config A --n 50 --seed 0
    uv run python scripts/eval_baseline.py --config C --n 50 --seed 0
    uv run python scripts/eval_baseline.py --n 5 --db bird_california_schools
    uv run python scripts/eval_baseline.py --config C --only-qids 1399,1205
"""

from __future__ import annotations

import argparse
import sys
import time
from pathlib import Path

import chromadb

from nl_sql.config import get_settings
from nl_sql.db.registry import get_default_registry
from nl_sql.eval import (
    EvalRecord,
    EvalRun,
    dev_split,
    load_bird_mini_dev,
    load_run_from_json,
    run_config_a,
    run_config_c,
    run_config_d,
    run_config_e,
    run_config_f,
    run_config_g,
    write_html_report,
    write_json_report,
)
from nl_sql.eval.dataset import DEFAULT_BIRD_ROOT
from nl_sql.llm.cache import CachingEmbeddingProvider, CachingLLMProvider
from nl_sql.llm.providers import build_provider
from nl_sql.llm.providers.base import EmbeddingProvider, LLMProvider
from nl_sql.llm.providers.mistral import MistralProvider
from nl_sql.schema_index.indexer import SchemaIndex


def _build_parser() -> argparse.ArgumentParser:
    """Build the argparse parser holding every eval-baseline CLI option.

    The module docstring is reused as the ``--help`` description.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--n", type=int, default=50, help="number of BIRD examples (default: 50)")
    parser.add_argument("--seed", type=int, default=0, help="dev_split seed")
    parser.add_argument(
        "--db",
        default=None,
        help=(
            "optional registry-id filter (e.g. bird_california_schools); "
            "if set, only examples for that DB are kept"
        ),
    )
    parser.add_argument(
        "--difficulty",
        choices=["simple", "moderate", "challenging"],
        default=None,
        help=(
            "optional difficulty filter; useful for tier-specific runs "
            "(e.g. --difficulty challenging to run config F only on the "
            "hard tier and merge with G for the rest — see "
            "docs/SESSION_HANDOFF.md for the hybrid recipe)."
        ),
    )
    parser.add_argument(
        "--only-qids",
        default="",
        help=(
            "comma-separated BIRD question IDs to run exactly, preserving "
            "argument order and bypassing --n/--seed sampling"
        ),
    )
    parser.add_argument(
        "--bird-root",
        default=str(DEFAULT_BIRD_ROOT),
        help=f"path to MINIDEV/ root (default: {DEFAULT_BIRD_ROOT})",
    )
    parser.add_argument("--reports", default="eval/reports", help="output root")
    parser.add_argument(
        "--config",
        choices=["A", "C", "D", "E", "F", "G"],
        default="A",
        help=(
            "ablation configuration "
            "(A=full_schema, C=dense+FK no repair, "
            "E=dense+FK+repair_once, F=dense+FK+self-consistency)"
        ),
    )
    parser.add_argument(
        "--sql-candidate-temperatures",
        default="0.2,0.4,0.6,0.8",
        help=(
            "comma-separated sampling temperatures for config F "
            "(self-consistency). One pipeline pass per temperature; "
            "default 4 candidates at 0.2/0.4/0.6/0.8."
        ),
    )
    parser.add_argument(
        "--persist",
        default="chroma_data",
        # Every configuration except A opens the Chroma index from this directory.
        help="chroma persist directory (all configs except A; default: chroma_data/)",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help=(
            "disable diskcache wrappers around the LLM/embedding providers. "
            "Default is cached — re-running the same examples is then $0 + "
            "deterministic, so ablations compare apples to apples."
        ),
    )
    parser.add_argument(
        "--schema-top-k",
        type=int,
        default=5,
        help="dense schema retrieval top-k (configs C/E; default: 5)",
    )
    parser.add_argument(
        "--fk-hops",
        type=int,
        default=1,
        help="FK graph expansion hops (configs C/E; default: 1)",
    )
    parser.add_argument(
        "--table-budget",
        type=int,
        default=12,
        help="max tables in the schema block (configs C/E; default: 12)",
    )
    parser.add_argument(
        "--report-suffix",
        default="",
        help=(
            "extra string appended to <config>.json so knob-bump runs don't "
            "overwrite the baseline (e.g. '--report-suffix=topk8' → "
            "C_dense_cards-topk8.json)"
        ),
    )
    parser.add_argument(
        "--sort-schema-block",
        action="store_true",
        help=(
            "render schema_block in alphabetical-by-table-name order "
            "(configs C/E only; default: retrieval-distance + FK BFS order). "
            "Tests the hypothesis that codestral is order-sensitive on "
            "moderate-tier BIRD questions."
        ),
    )
    parser.add_argument(
        "--primary-sample-size",
        type=int,
        default=3,
        help=(
            "sample density baked into the chunks stored in Chroma "
            "(must match the --sample-size used at build_index time; "
            "default: 3)"
        ),
    )
    parser.add_argument(
        "--fewshot-top-k",
        type=int,
        default=3,
        help=(
            "number of fewshot Q→SQL pairs to retrieve from the "
            "fewshot_qsql collection (configs D/G/F-with-fewshot; "
            "default: 3). Higher values give the LLM more templates "
            "but inflate prompt token count and risk distracting the "
            "generator with off-topic examples."
        ),
    )
    parser.add_argument(
        "--with-fewshot",
        action="store_true",
        help=(
            "enable cross-db fewshot retrieval for config F "
            "(self-consistency). D and G have fewshot ON by default; "
            "for F it's opt-in so old F runs stay comparable."
        ),
    )
    parser.add_argument(
        "--extended-sample-size",
        type=int,
        default=0,
        help=(
            "per-difficulty sample mixture (configs C/E only; default: 0 "
            "= disabled). When > primary_sample_size, the schema_block "
            "appendix lists samples primary..extended per column for "
            "retrieved tables, so the model has both densities in one "
            "prompt. Re-introspects the live DB at runtime — no chroma "
            "rebuild needed. Recommended value: 5."
        ),
    )
    parser.add_argument(
        "--provider",
        choices=["mistral", "groq", "github_models", "ollama", "perplexity", "openrouter"],
        default="mistral",
        help=(
            "LLM provider for generation (embedding stays mistral — only "
            "Mistral implements EmbeddingProvider). Used for the "
            "architecture §1 provider bakeoff."
        ),
    )
    return parser


def _parse_candidate_temperatures(raw: str) -> tuple[float, ...]:
    """Parse the ``--sql-candidate-temperatures`` value into a tuple of floats.

    Blank entries (e.g. a trailing comma) are ignored.

    Raises:
        ValueError: if an entry is not a number, or if no temperature remains.
    """
    temps = tuple(float(part) for part in raw.split(",") if part.strip())
    if not temps:
        raise ValueError("no temperatures given")
    return temps


def main(argv: list[str] | None = None) -> int:
    """Run one ablation configuration and write its JSON/HTML reports.

    Args:
        argv: command-line arguments without the program name; ``None`` lets
            argparse fall back to ``sys.argv``.

    Returns:
        Process exit code: 0 on success, 2 when MISTRAL_API_KEY is unset,
        3 when a selection argument is invalid or selects nothing (``--db``,
        ``--only-qids``, ``--difficulty``, ``--sql-candidate-temperatures``),
        4 when sampled examples reference unregistered DBs, and 5 when the
        chroma persist directory is missing.
    """
    args = _build_parser().parse_args(argv)

    examples = load_bird_mini_dev(Path(args.bird_root))
    if args.db:
        examples = [e for e in examples if e.registry_db_id == args.db]
        if not examples:
            print(f"[error] no examples for db {args.db!r}", file=sys.stderr)
            return 3

    try:
        only_qids = [int(x) for x in args.only_qids.split(",") if x.strip()]
    except ValueError:
        print("[error] invalid --only-qids: expected comma-separated integers", file=sys.stderr)
        return 3
    if only_qids:
        # Explicit question IDs bypass --n/--seed sampling and keep argument order.
        examples_by_qid = {e.question_id: e for e in examples}
        missing_qids = [qid for qid in only_qids if qid not in examples_by_qid]
        if missing_qids:
            print(f"[error] qids not found after filters: {missing_qids}", file=sys.stderr)
            return 3
        sample = [examples_by_qid[qid] for qid in only_qids]
    else:
        sample = dev_split(examples, n=args.n, seed=args.seed)
    if args.difficulty:
        # Apply AFTER dev_split so the same shuffle-prefix examples appear
        # as in unfiltered runs — needed for hybrid merging (e.g., F on
        # challenging tier blended with G on the rest).
        sample = [e for e in sample if e.difficulty == args.difficulty]
        if not sample:
            # Name the pool that was actually filtered: --only-qids skips the n-prefix.
            scope = "the requested --only-qids" if only_qids else f"the n={args.n} prefix"
            print(
                f"[error] no examples for difficulty {args.difficulty!r} within {scope}",
                file=sys.stderr,
            )
            return 3
    print(f"[info] loaded {len(examples)} examples → sampled {len(sample)} (seed={args.seed})")

    registry = get_default_registry()
    missing = sorted({e.registry_db_id for e in sample} - set(registry.ids()))
    if missing:
        print(
            f"[error] sampled examples reference unregistered DBs: {missing}\n"
            f"  registered: {registry.ids()}",
            file=sys.stderr,
        )
        return 4

    settings = get_settings()
    # The Mistral key is required for every run; non-A configurations build a
    # MistralProvider embedder from it below.
    # NOTE(review): config A with a non-Mistral --provider may not need the
    # key — confirm against build_provider before relaxing this check.
    if not settings.mistral_api_key:
        print("[error] MISTRAL_API_KEY not set in .env", file=sys.stderr)
        return 2

    raw_sql_provider = build_provider(args.provider, settings=settings)
    print(f"[info] provider: {args.provider} (model={raw_sql_provider.model})")
    sql_provider: LLMProvider
    if args.no_cache:
        sql_provider = raw_sql_provider
        print("[info] cache: DISABLED (--no-cache)")
    else:
        sql_provider = CachingLLMProvider(
            raw_sql_provider,
            cache_dir=settings.llm_cache_dir,
            size_limit_gb=settings.llm_cache_size_limit_gb,
        )
        print(f"[info] cache: ENABLED at {settings.llm_cache_dir}/")

    started = time.perf_counter()

    def _on_progress(idx: int, total: int, rec: EvalRecord) -> None:
        """Print a one-line status for each evaluated example."""
        flag = "OK " if rec.match else "MISS"
        err = f" [{rec.error_kind}]" if rec.error_kind else ""
        recall = "rec✓" if rec.schema_recall else "rec✗"
        # A space separates the difficulty/error tag from the question text.
        print(
            f"  [{idx:>3}/{total}] {flag} {recall} ({rec.latency_ms:6.0f}ms) "
            f"{rec.db_id}/{rec.difficulty}{err} {rec.question[:80]}"
        )

    print(f"[info] running configuration {args.config} on {len(sample)} examples …")
    run: EvalRun
    if args.config == "A":
        run = run_config_a(
            sample,
            sql_provider=sql_provider,
            registry=registry,
            progress=_on_progress,
        )
    else:  # "C", "D", "E", "F", or "G" — all need the Chroma index
        persist_dir = Path(args.persist)
        if not persist_dir.is_dir():
            print(
                f"[error] chroma persist dir not found: {persist_dir}. "
                f"Run `python scripts/build_index.py --db all` first.",
                file=sys.stderr,
            )
            return 5
        chroma_client = chromadb.PersistentClient(path=str(persist_dir))
        # Embedding provider also Mistral — same key, same `mistral-embed`.
        raw_embedder = MistralProvider(
            api_key=settings.mistral_api_key,
            gen_model=settings.mistral_gen_model,
            embed_model=settings.mistral_embed_model,
            base_url=settings.mistral_base_url,
        )
        embedder: EmbeddingProvider = (
            raw_embedder
            if args.no_cache
            else CachingEmbeddingProvider(
                raw_embedder,
                cache_dir=settings.llm_cache_dir,
                size_limit_gb=settings.llm_cache_size_limit_gb,
            )
        )
        index = SchemaIndex(persist_dir=persist_dir, embedder=embedder, client=chroma_client)
        explain_provider = sql_provider  # codestral works for caption too in eval
        if args.config == "F":
            # Report a malformed or empty temperature list like the other
            # argument errors instead of crashing with a traceback.
            try:
                temps = _parse_candidate_temperatures(args.sql_candidate_temperatures)
            except ValueError:
                print(
                    "[error] invalid --sql-candidate-temperatures: expected at least "
                    "one comma-separated number",
                    file=sys.stderr,
                )
                return 3
            print(f"[info] self-consistency: {len(temps)} candidates @ {temps}")
            run = run_config_f(
                sample,
                sql_provider=sql_provider,
                explain_provider=explain_provider,
                schema_index=index,
                registry=registry,
                schema_top_k=args.schema_top_k,
                # Fewshot is opt-in for F: without --with-fewshot retrieve nothing.
                fewshot_top_k=args.fewshot_top_k if args.with_fewshot else 0,
                fk_hops=args.fk_hops,
                table_budget=args.table_budget,
                sort_schema_block=args.sort_schema_block,
                primary_sample_size=args.primary_sample_size,
                extended_sample_size=args.extended_sample_size,
                sql_candidate_temperatures=temps,
                cross_db_fewshot=args.with_fewshot,
                progress=_on_progress,
            )
        elif args.config in ("D", "G"):
            # D and G take the same arguments, including the fewshot top-k.
            fewshot_runner = run_config_d if args.config == "D" else run_config_g
            run = fewshot_runner(
                sample,
                sql_provider=sql_provider,
                explain_provider=explain_provider,
                schema_index=index,
                registry=registry,
                schema_top_k=args.schema_top_k,
                fewshot_top_k=args.fewshot_top_k,
                fk_hops=args.fk_hops,
                table_budget=args.table_budget,
                sort_schema_block=args.sort_schema_block,
                primary_sample_size=args.primary_sample_size,
                extended_sample_size=args.extended_sample_size,
                progress=_on_progress,
            )
        else:
            runner = run_config_c if args.config == "C" else run_config_e
            run = runner(
                sample,
                sql_provider=sql_provider,
                explain_provider=explain_provider,
                schema_index=index,
                registry=registry,
                schema_top_k=args.schema_top_k,
                fk_hops=args.fk_hops,
                table_budget=args.table_budget,
                sort_schema_block=args.sort_schema_block,
                primary_sample_size=args.primary_sample_size,
                extended_sample_size=args.extended_sample_size,
                progress=_on_progress,
            )
    elapsed = time.perf_counter() - started

    print()
    print("=" * 78)
    print(f"Configuration: {run.configuration.value}")
    print(f"Model:         {run.sql_model}")
    print(f"Examples:      {run.overall.n}")
    print(f"EA (final):    {run.overall.ea * 100:.1f}%")
    print(f"EA (1st pass): {run.overall.first_pass_ea * 100:.1f}%")
    for tier in ("simple", "moderate", "challenging"):
        tier_stats = run.per_difficulty[tier]
        label = f"{tier}:"
        # Pad the label so the percentages line up under the headline figures.
        print(f"  {label:<13}{tier_stats.ea * 100:.1f}% (n={tier_stats.n})")
    print(f"Validity:      {run.overall.validity_rate * 100:.1f}%")
    repairs_fired = sum(1 for r in run.records if r.repair_attempted)
    print(
        f"Repair fired:  {repairs_fired}/{run.overall.n}; "
        f"success rate {run.overall.repair_success_rate * 100:.1f}%"
    )
    # Only configuration A feeds the full schema, so the "≈ 100%" expectation
    # applies to A alone; retrieval-based configs report a real recall figure.
    recall_note = "  (k = full schema, so recall ≈ 100% expected)" if args.config == "A" else ""
    print(f"Schema rec@k:  {run.overall.schema_recall_at_k * 100:.1f}%{recall_note}")
    print(f"Empty result:  {run.overall.empty_result_rate * 100:.1f}%")
    print(f"Latency P50:   {run.overall.latency_p50_ms:.0f} ms")
    print(f"Latency P95:   {run.overall.latency_p95_ms:.0f} ms")
    print(f"Tokens P50:    {run.overall.tokens_p50:.0f}")
    print(f"Tokens P95:    {run.overall.tokens_p95:.0f}")
    print(f"Wall time:     {elapsed:.1f}s")

    json_path = write_json_report(run, root=args.reports, name_suffix=args.report_suffix)

    # Combine today's run with any other configurations that finished earlier
    # so the HTML index keeps a single side-by-side ablation table per day.
    today_dir = json_path.parent
    prior_runs: list[EvalRun] = []
    for other in sorted(today_dir.glob("*.json")):
        if other == json_path:
            continue
        try:
            prior_runs.append(load_run_from_json(other))
        except (KeyError, TypeError, ValueError) as exc:
            # Best effort: an unreadable sibling report must not fail this run.
            print(f"[warn] skipped {other.name}: {exc}", file=sys.stderr)
    html_path = write_html_report([*prior_runs, run], root=args.reports)
    print()
    print(f"[json] {json_path}")
    print(f"[html] {html_path}")
    return 0


# Script entry point: run the CLI and hand main()'s return value to the
# interpreter as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())