nl-sql / scripts /rescore_arcwise.py
liovina's picture
Deploy NL_SQL HEAD to HF Space
4b4ff9e verified
"""Re-score a v10-style BIRD eval report against Arcwise-Plat corrected gold.
Jin et al. (CIDR/VLDB 2026, arXiv:2601.08778) audited BIRD Mini-Dev and found
~52.8% of questions have annotation errors. Their corrected artifacts
(`arcwise_plat_sql_only` = SQL-only fixes, `arcwise_plat_full` = SQL + question +
evidence + schema fixes) live at
https://github.com/uiuc-kang-lab/text_to_sql_benchmarks/blob/main/data/.
This script keeps our predictions unchanged and only swaps the gold SQL used
for execution-accuracy scoring. It writes a comparison report grouped into
buckets: same / gained (pred now matches corrected gold) / lost (pred matched
original gold but no longer matches corrected) per source variant.
Outputs:
- eval/reports/2026-05-17/arcwise_rescored.json (full per-record audit)
- stderr summary table
Usage:
uv run python scripts/rescore_arcwise.py \
--report eval/reports/2026-05-17/hybrid-vote-critique-selfcon-sonnet-fewshot5-groq4-mschema-v10.json \
--sql-only data/arcwise_plat_sql_only.json \
--full data/arcwise_plat_full.json \
--out eval/reports/2026-05-17/arcwise_rescored.json
"""
from __future__ import annotations
import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any
from nl_sql.db.connection import execute_readonly
from nl_sql.db.registry import get_default_registry
from nl_sql.eval.metrics.execution_accuracy import safe_compare_pred
from nl_sql.eval.runner import _execute_gold_with_status
def _load_arcwise(path: Path) -> dict[int, dict[str, Any]]:
    """Read an Arcwise-Plat JSON file and index its entries by question id.

    The file is parsed as UTF-8 JSON and iterated as a sequence of entries.
    Each entry's ``"question_id"`` is coerced with ``int()`` and used as the
    key; when the same id occurs more than once, the last entry wins.
    """
    entries = json.loads(path.read_text(encoding="utf-8"))
    return {int(item["question_id"]): item for item in entries}
# Execution limits applied to every pred/gold statement run by this script.
_STATEMENT_TIMEOUT_MS = 30_000
_ROW_CAP = 10_000

# Gold-SQL sources that every record is scored against, in reporting order.
_VARIANTS = ("original", "sql_only", "full")
# Difficulty tiers shown in the per-tier summary and the JSON payload.
_DIFFICULTIES = ("simple", "moderate", "challenging")


def _arcwise_key(qid: Any) -> Any:
    """Return *qid* coerced to ``int`` when possible, otherwise unchanged.

    ``_load_arcwise`` keys its mapping by ``int(question_id)``; a report that
    stores ids as strings would otherwise never find its corrected gold SQL.
    """
    try:
        return int(qid)
    except (TypeError, ValueError):
        return qid


def _execute_pred(
    engine: Any, pred_sql: str, out_entry: dict[str, Any]
) -> tuple[list[tuple[Any, ...]], bool]:
    """Run the predicted SQL once and return ``(rows, failed)``.

    Blank SQL counts as a failure without touching the database. Any exception
    is recorded under ``out_entry["pred_exec_error"]`` instead of propagating.
    """
    rows: list[tuple[Any, ...]] = []
    if not pred_sql.strip():
        return rows, True
    # Route pred through `execute_readonly` directly (matches canonical
    # `scripts/audit_rescore.py`): the `_execute_gold` SQLAlchemyError fallback
    # is intended only for trusted BIRD gold SQL, not for model-generated pred
    # SQL — using it on pred can mask validator-style failures and yields
    # non-deterministic engine state across sequential records.
    try:
        with execute_readonly(
            engine,
            pred_sql,
            statement_timeout_ms=_STATEMENT_TIMEOUT_MS,
            row_cap=_ROW_CAP,
        ) as result:
            rows = list(result.rows)
    except Exception as exc:
        out_entry["pred_exec_error"] = str(exc)
        return rows, True
    return rows, False


def _count_transitions(items: list[dict[str, Any]], variant: str) -> int:
    """Count the transition entries that belong to *variant*."""
    return sum(1 for item in items if item["variant"] == variant)


def _print_summary(
    matched: dict[str, int],
    total_scored: dict[str, int],
    per_diff: dict[str, dict[str, list[int]]],
    transitions: dict[str, list[dict[str, Any]]],
) -> None:
    """Write the overall, per-tier and transition summaries to stderr."""
    print("\n=== Arcwise rescoring summary ===", file=sys.stderr)
    for variant in _VARIANTS:
        total = total_scored[variant]
        count = matched[variant]
        pct = (count / total * 100) if total else 0.0
        print(f" {variant:10s}: {count}/{total} = {pct:.2f}%", file=sys.stderr)

    print("\n=== Per-tier ===", file=sys.stderr)
    for variant in _VARIANTS:
        line = f" {variant:10s}: "
        for diff in _DIFFICULTIES:
            mt, tot = per_diff[variant][diff]
            pct = (mt / tot * 100) if tot else 0.0
            line += f"{diff[:4]}={mt}/{tot}({pct:.1f}%) "
        print(line, file=sys.stderr)

    print("\n=== Transitions (vs original gold) ===", file=sys.stderr)
    # Every count is filtered by variant; the "gained"/"lost" lists hold
    # entries for both `sql_only` and `full`.
    print(
        f" gained (sql_only): {_count_transitions(transitions['gained'], 'sql_only')}",
        file=sys.stderr,
    )
    print(
        f" lost (sql_only): {_count_transitions(transitions['lost'], 'sql_only')}",
        file=sys.stderr,
    )
    print(
        f" gained (full): {_count_transitions(transitions['gained'], 'full')}",
        file=sys.stderr,
    )
    print(
        f" lost (full): {_count_transitions(transitions['lost'], 'full')}",
        file=sys.stderr,
    )


def main() -> int:
    """Re-score a saved eval report against original and Arcwise gold SQL.

    Parses ``--report``, ``--sql-only``, ``--full`` and ``--out`` from the
    command line, executes each record's predicted SQL once, compares it with
    every available gold variant, prints a summary to stderr and writes the
    per-record audit as JSON to ``--out``.

    Returns:
        ``0`` on completion (used as the process exit status).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--report", type=Path, required=True)
    parser.add_argument("--sql-only", type=Path, required=True)
    parser.add_argument("--full", type=Path, required=True)
    parser.add_argument("--out", type=Path, required=True)
    args = parser.parse_args()

    report = json.loads(args.report.read_text(encoding="utf-8"))
    arc_sql = _load_arcwise(args.sql_only)
    arc_full = _load_arcwise(args.full)
    registry = get_default_registry()
    records = report["records"]

    # Per-variant aggregates; per_diff maps difficulty -> [matched, total].
    matched: dict[str, int] = dict.fromkeys(_VARIANTS, 0)
    total_scored: dict[str, int] = dict.fromkeys(_VARIANTS, 0)
    per_diff: dict[str, dict[str, list[int]]] = {
        v: defaultdict(lambda: [0, 0]) for v in _VARIANTS
    }
    # Per-qid transitions: sql_only vs original, full vs original.
    transitions: dict[str, list[dict[str, Any]]] = {
        "gained": [],
        "lost": [],
        "changed_gold": [],
    }
    per_record: list[dict[str, Any]] = []

    for i, rec in enumerate(records, 1):
        qid = rec["question_id"]
        db_id = rec["db_id"]
        difficulty = rec["difficulty"]
        pred_sql = rec.get("pred_sql") or ""
        # Match flag as stored in the source report; transitions are measured
        # against this value, not against the re-executed original gold.
        orig_match = bool(rec.get("match"))

        # Arcwise mappings are keyed by int; normalise so string ids still hit.
        arc_key = _arcwise_key(qid)
        gold_by_variant = {
            "original": rec.get("gold_sql") or "",
            "sql_only": arc_sql.get(arc_key, {}).get("SQL") or "",
            "full": arc_full.get(arc_key, {}).get("SQL") or "",
        }

        engine = registry.get(f"bird_{db_id}").make_engine()
        out_entry: dict[str, Any] = {
            "question_id": qid,
            "db_id": db_id,
            "difficulty": difficulty,
            "pred_sql": pred_sql,
            # Overwritten below by the rescored result when gold_sql is present.
            "original_match": orig_match,
        }
        try:
            # Execute pred once and reuse its rows for every variant.
            pred_rows, pred_failed = _execute_pred(engine, pred_sql, out_entry)

            # Score against each variant that has gold SQL available.
            for variant, source in gold_by_variant.items():
                if not source:
                    continue
                gold_failed = False
                try:
                    gold_rows, _, gold_failed = _execute_gold_with_status(
                        engine,
                        source,
                        statement_timeout_ms=_STATEMENT_TIMEOUT_MS,
                        row_cap=_ROW_CAP,
                    )
                    if gold_failed:
                        out_entry[f"{variant}_gold_exec_error"] = (
                            "gold SQL crashed in both execute_readonly and raw-connection paths"
                        )
                except Exception as exc:
                    gold_rows = []
                    gold_failed = True
                    out_entry[f"{variant}_gold_exec_error"] = str(exc)

                cmp = safe_compare_pred(
                    gold_rows,
                    pred_rows,
                    gold_sql=source,
                    pred_failed=pred_failed,
                    gold_failed=gold_failed,
                )
                is_match = bool(cmp.match)
                out_entry[f"{variant}_match"] = is_match
                out_entry[f"{variant}_reason"] = cmp.reason
                out_entry[f"{variant}_gold_rows"] = len(gold_rows)

                total_scored[variant] += 1
                matched[variant] += int(is_match)
                tier_counts = per_diff[variant][difficulty]
                tier_counts[0] += int(is_match)
                tier_counts[1] += 1

            # Transitions of each corrected variant relative to the report.
            original_gold = gold_by_variant["original"].strip()
            for variant in ("sql_only", "full"):
                v_match = out_entry.get(f"{variant}_match")
                if v_match is None:
                    # Variant had no gold SQL for this question; nothing to compare.
                    continue
                marker = {"qid": qid, "variant": variant, "difficulty": difficulty}
                if gold_by_variant[variant].strip() != original_gold:
                    out_entry[f"{variant}_gold_changed"] = True
                    # Record the change in the bucket that was declared for it.
                    transitions["changed_gold"].append(dict(marker))
                if orig_match and not v_match:
                    transitions["lost"].append(marker)
                elif (not orig_match) and v_match:
                    transitions["gained"].append(marker)
        finally:
            # Release the per-record engine even if scoring raised.
            engine.dispose()

        per_record.append(out_entry)
        if i % 25 == 0:
            print(f"[{i:3d}/{len(records)}] processed", file=sys.stderr)

    _print_summary(matched, total_scored, per_diff, transitions)

    out_payload = {
        "source_report": str(args.report),
        "summary": {
            v: {"matched": matched[v], "total": total_scored[v]} for v in _VARIANTS
        },
        "per_difficulty": {
            v: {
                d: {"matched": per_diff[v][d][0], "total": per_diff[v][d][1]}
                for d in _DIFFICULTIES
            }
            for v in _VARIANTS
        },
        "transitions": transitions,
        "records": per_record,
    }
    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(json.dumps(out_payload, indent=2, default=str), encoding="utf-8")
    print(f"\n[info] wrote {args.out}", file=sys.stderr)
    return 0
if __name__ == "__main__":
    # Run the CLI and use main()'s return value as the process exit status.
    sys.exit(main())