| """Re-score a v10-style BIRD eval report against Arcwise-Plat corrected gold.
|
|
|
| Jin et al. (CIDR/VLDB 2026, arXiv:2601.08778) audited BIRD Mini-Dev and found
|
| ~52.8% of questions have annotation errors. Their corrected artifacts
|
| (`arcwise_plat_sql_only` = SQL-only fixes, `arcwise_plat_full` = SQL + question +
|
| evidence + schema fixes) live at
|
| https://github.com/uiuc-kang-lab/text_to_sql_benchmarks/blob/main/data/.
|
|
|
| This script keeps our predictions unchanged and only swaps the gold SQL used
|
| for execution-accuracy scoring. It writes a comparison report grouped into
|
| buckets: same / gained (pred now matches corrected gold) / lost (pred matched
|
| original gold but no longer matches corrected) per source variant.
|
|
|
| Outputs:
|
| - eval/reports/2026-05-17/arcwise_rescored.json (full per-record audit)
|
| - stderr summary table (overall accuracy, per-tier accuracy, transition counts)
|
|
|
| Usage:
|
| uv run python scripts/rescore_arcwise.py \
|
| --report eval/reports/2026-05-17/hybrid-vote-critique-selfcon-sonnet-fewshot5-groq4-mschema-v10.json \
|
| --sql-only data/arcwise_plat_sql_only.json \
|
| --full data/arcwise_plat_full.json \
|
| --out eval/reports/2026-05-17/arcwise_rescored.json
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import json
|
| import sys
|
| from collections import defaultdict
|
| from pathlib import Path
|
| from typing import Any
|
|
|
| from nl_sql.db.connection import execute_readonly
|
| from nl_sql.db.registry import get_default_registry
|
| from nl_sql.eval.metrics.execution_accuracy import safe_compare_pred
|
| from nl_sql.eval.runner import _execute_gold_with_status
|
|
|
|
|
def _load_arcwise(path: Path) -> dict[int, dict[str, Any]]:
    """Read an Arcwise JSON file and index its entries by question id.

    Args:
        path: Location of a UTF-8 JSON file whose top level is iterated as a
            sequence of entries, each carrying a ``"question_id"`` key.

    Returns:
        A mapping from ``int(entry["question_id"])`` to the entry itself.
        When several entries share an id, the last one in the file wins.

    Raises:
        KeyError: If an entry has no ``"question_id"`` key.
        ValueError: If a ``"question_id"`` cannot be converted to ``int``.
    """
    entries = json.loads(path.read_text(encoding="utf-8"))
    return {int(item["question_id"]): item for item in entries}
|
|
|
|
|
def main() -> int:
    """Re-score a report's predictions against original and Arcwise gold SQL.

    Parses ``--report``, ``--sql-only``, ``--full`` and ``--out`` from the
    command line. For every record in ``report["records"]`` the predicted SQL
    is executed once and then compared against up to three gold variants:

    * ``original`` - the record's own ``gold_sql``
    * ``sql_only`` - the ``SQL`` field of the matching ``--sql-only`` entry
    * ``full``     - the ``SQL`` field of the matching ``--full`` entry

    A variant with no gold SQL for a record is skipped and does not count
    toward that variant's total.

    Side effects:
        Writes a JSON audit (summary, per-difficulty counts, transitions and
        per-record details) to ``--out``, creating parent directories as
        needed, and prints progress plus a summary to stderr.

    Returns:
        ``0`` on completion, intended as the process exit status.

    Raises:
        KeyError: If the report has no ``"records"`` key or a record lacks
            ``"question_id"``, ``"db_id"`` or ``"difficulty"``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring contains a multi-line usage example; keep its
        # line breaks instead of letting argparse reflow it into a paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--report", type=Path, required=True)
    parser.add_argument("--sql-only", type=Path, required=True)
    parser.add_argument("--full", type=Path, required=True)
    parser.add_argument("--out", type=Path, required=True)
    args = parser.parse_args()

    report = json.loads(args.report.read_text(encoding="utf-8"))
    arc_sql = _load_arcwise(args.sql_only)
    arc_full = _load_arcwise(args.full)
    arc_by_variant = {"sql_only": arc_sql, "full": arc_full}

    registry = get_default_registry()
    records = report["records"]

    variants = ("original", "sql_only", "full")
    tiers = ("simple", "moderate", "challenging")
    matched: dict[str, int] = {v: 0 for v in variants}
    total_scored: dict[str, int] = {v: 0 for v in variants}
    # per_diff[variant][difficulty] == [matched, total]
    per_diff: dict[str, dict[str, list[int]]] = {
        v: defaultdict(lambda: [0, 0]) for v in variants
    }

    transitions: dict[str, list[dict[str, Any]]] = {
        "gained": [],
        "lost": [],
        "changed_gold": [],
    }

    per_record: list[dict[str, Any]] = []

    for i, rec in enumerate(records, 1):
        qid = rec["question_id"]
        db_id = rec["db_id"]
        difficulty = rec["difficulty"]
        pred_sql = rec.get("pred_sql") or ""
        # Match flag as recorded in the source report; the gained/lost
        # transitions below are measured against this value.
        orig_match = bool(rec.get("match"))

        # _load_arcwise keys its maps by int(question_id). Normalise string
        # ids the same way so they are not silently treated as "no Arcwise
        # entry" (which would skip both corrected variants).
        arc_key = qid
        if isinstance(qid, str):
            try:
                arc_key = int(qid)
            except ValueError:
                # Non-numeric id: keep it as-is; it cannot match any entry.
                arc_key = qid

        spec = registry.get(f"bird_{db_id}")
        engine = spec.make_engine()
        # NOTE: "original_match" starts as the report's flag and is replaced
        # by the re-scored value below whenever the record has a gold_sql.
        out_entry: dict[str, Any] = {
            "question_id": qid,
            "db_id": db_id,
            "difficulty": difficulty,
            "pred_sql": pred_sql,
            "original_match": orig_match,
        }
        try:
            # Execute the prediction once; its rows are reused for every
            # gold variant.
            pred_rows: list[tuple[Any, ...]] = []
            pred_failed = False
            if pred_sql.strip():
                try:
                    with execute_readonly(
                        engine, pred_sql, statement_timeout_ms=30_000, row_cap=10_000
                    ) as result:
                        pred_rows = list(result.rows)
                except Exception as exc:
                    # Best effort: record the failure and keep scoring.
                    out_entry["pred_exec_error"] = str(exc)
                    pred_failed = True
            else:
                # An empty prediction is treated as a failed one.
                pred_failed = True

            gold_sources = (
                ("original", rec.get("gold_sql") or ""),
                ("sql_only", arc_sql.get(arc_key, {}).get("SQL") or ""),
                ("full", arc_full.get(arc_key, {}).get("SQL") or ""),
            )
            for variant, source in gold_sources:
                if not source:
                    continue
                gold_failed = False
                try:
                    gold_rows, _, gold_failed = _execute_gold_with_status(
                        engine, source, statement_timeout_ms=30_000, row_cap=10_000
                    )
                    if gold_failed:
                        out_entry[f"{variant}_gold_exec_error"] = (
                            "gold SQL crashed in both execute_readonly and raw-connection paths"
                        )
                except Exception as exc:
                    gold_rows = []
                    gold_failed = True
                    out_entry[f"{variant}_gold_exec_error"] = str(exc)
                cmp = safe_compare_pred(
                    gold_rows,
                    pred_rows,
                    gold_sql=source,
                    pred_failed=pred_failed,
                    gold_failed=gold_failed,
                )
                is_match = bool(cmp.match)
                out_entry[f"{variant}_match"] = is_match
                out_entry[f"{variant}_reason"] = cmp.reason
                out_entry[f"{variant}_gold_rows"] = len(gold_rows)
                total_scored[variant] += 1
                matched[variant] += int(is_match)
                per_diff[variant][difficulty][1] += 1
                per_diff[variant][difficulty][0] += int(is_match)

            # Classify how each corrected variant moved relative to the
            # report's original verdict.
            for variant in ("sql_only", "full"):
                v_match = out_entry.get(f"{variant}_match")
                if v_match is None:
                    # Variant was not scored for this record.
                    continue
                arc_entry = arc_by_variant[variant].get(arc_key) or {}
                gold_changed = (
                    (arc_entry.get("SQL") or "").strip()
                    != (rec.get("gold_sql") or "").strip()
                )
                transition = {"qid": qid, "variant": variant, "difficulty": difficulty}
                if gold_changed:
                    out_entry[f"{variant}_gold_changed"] = True
                    # Previously this bucket was declared but never filled.
                    transitions["changed_gold"].append(dict(transition))
                if orig_match and not v_match:
                    transitions["lost"].append(transition)
                elif (not orig_match) and v_match:
                    transitions["gained"].append(transition)
        finally:
            # One engine per record; always release it, even on error.
            engine.dispose()
        per_record.append(out_entry)
        if i % 25 == 0:
            print(f"[{i:3d}/{len(records)}] processed", file=sys.stderr)

    print("\n=== Arcwise rescoring summary ===", file=sys.stderr)
    for variant in variants:
        total = total_scored[variant]
        count = matched[variant]
        pct = (count / total * 100) if total else 0.0
        print(f" {variant:10s}: {count}/{total} = {pct:.2f}%", file=sys.stderr)

    print("\n=== Per-tier ===", file=sys.stderr)
    for variant in variants:
        line = f" {variant:10s}: "
        for diff in tiers:
            mt, tot = per_diff[variant][diff]
            pct = (mt / tot * 100) if tot else 0.0
            line += f"{diff[:4]}={mt}/{tot}({pct:.1f}%) "
        print(line, file=sys.stderr)

    print("\n=== Transitions (vs original gold) ===", file=sys.stderr)
    # Every count is filtered by variant; the "gained (sql_only)" line used
    # to report the combined sql_only + full total.
    for variant in ("sql_only", "full"):
        for bucket in ("gained", "lost"):
            n = sum(1 for t in transitions[bucket] if t["variant"] == variant)
            print(f" {bucket} ({variant}): {n}", file=sys.stderr)

    out_payload = {
        "source_report": str(args.report),
        "summary": {
            v: {"matched": matched[v], "total": total_scored[v]} for v in variants
        },
        "per_difficulty": {
            v: {
                d: {"matched": per_diff[v][d][0], "total": per_diff[v][d][1]}
                for d in tiers
            }
            for v in variants
        },
        "transitions": transitions,
        "records": per_record,
    }
    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(json.dumps(out_payload, indent=2, default=str), encoding="utf-8")
    print(f"\n[info] wrote {args.out}", file=sys.stderr)
    return 0
|
|
|
|
|
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())
|
|
|