File size: 9,442 Bytes
d48602c 3f1a281 4b4ff9e d48602c 4b4ff9e d48602c 3f1a281 d48602c 3f1a281 d48602c 4b4ff9e d48602c 4b4ff9e d48602c 4b4ff9e d48602c 4b4ff9e d48602c 3f1a281 4b4ff9e 3f1a281 d48602c 4b4ff9e d48602c 4b4ff9e d48602c 4b4ff9e d48602c 4b4ff9e d48602c 4b4ff9e d48602c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 | """Re-score a v10-style BIRD eval report against Arcwise-Plat corrected gold.
Jin et al. (CIDR/VLDB 2026, arXiv:2601.08778) audited BIRD Mini-Dev and found
~52.8% of questions have annotation errors. Their corrected artifacts
(`arcwise_plat_sql_only` = SQL-only fixes, `arcwise_plat_full` = SQL + question +
evidence + schema fixes) live at
https://github.com/uiuc-kang-lab/text_to_sql_benchmarks/blob/main/data/.
This script keeps our predictions unchanged and only swaps the gold SQL used
for execution-accuracy scoring. It writes a comparison report grouped into
buckets: same / gained (pred now matches corrected gold) / lost (pred matched
original gold but no longer matches corrected) per source variant.
Outputs:
- eval/reports/2026-05-17/arcwise_rescored.json (full per-record audit)
- stderr summary table
Usage:
uv run python scripts/rescore_arcwise.py \
--report eval/reports/2026-05-17/hybrid-vote-critique-selfcon-sonnet-fewshot5-groq4-mschema-v10.json \
--sql-only data/arcwise_plat_sql_only.json \
--full data/arcwise_plat_full.json \
--out eval/reports/2026-05-17/arcwise_rescored.json
"""
from __future__ import annotations
import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any
from nl_sql.db.connection import execute_readonly
from nl_sql.db.registry import get_default_registry
from nl_sql.eval.metrics.execution_accuracy import safe_compare_pred
from nl_sql.eval.runner import _execute_gold_with_status
def _load_arcwise(path: Path) -> dict[int, dict[str, Any]]:
    """Read an Arcwise-Plat JSON file and index its entries by question id.

    The file at ``path`` is decoded as UTF-8 and parsed as JSON, then
    iterated entry by entry. Each entry's ``question_id`` is coerced to
    ``int`` and used as the key; the entry itself is stored unchanged as the
    value. When an id occurs more than once, the last entry wins.
    """
    entries = json.loads(path.read_text(encoding="utf-8"))
    return {int(item["question_id"]): item for item in entries}
def main() -> int:
    """Re-score a saved eval report against original and corrected gold SQL.

    Command-line arguments (all required):
        --report    JSON eval report; must contain a ``records`` list.
        --sql-only  Arcwise-Plat JSON file with SQL-only corrections.
        --full      Arcwise-Plat JSON file with the full corrections.
        --out       Destination path for the JSON audit written by this run.

    For every record the predicted SQL is executed once and the resulting
    rows are compared against up to three gold variants (``original``,
    ``sql_only``, ``full``). A variant with no gold SQL for the record is
    skipped and does not count towards that variant's total. Per-variant and
    per-difficulty tallies, plus gained / lost / changed-gold transitions
    relative to the ``match`` flag stored in the report, are printed to
    stderr and written to ``--out``.

    Returns:
        ``0`` once the output file has been written.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's layout (usage example, bullet list)
        # instead of letting argparse reflow it into one paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--report", type=Path, required=True)
    parser.add_argument("--sql-only", type=Path, required=True)
    parser.add_argument("--full", type=Path, required=True)
    parser.add_argument("--out", type=Path, required=True)
    args = parser.parse_args()

    report = json.loads(args.report.read_text(encoding="utf-8"))
    arc_sql = _load_arcwise(args.sql_only)
    arc_full = _load_arcwise(args.full)
    registry = get_default_registry()
    records = report["records"]

    # Per-variant aggregates.
    variants = ("original", "sql_only", "full")
    tiers = ("simple", "moderate", "challenging")
    matched: dict[str, int] = {v: 0 for v in variants}
    total_scored: dict[str, int] = {v: 0 for v in variants}
    # per_diff[variant][difficulty] == [matched, total]
    per_diff: dict[str, dict[str, list[int]]] = {
        v: defaultdict(lambda: [0, 0]) for v in variants
    }
    # Per-qid transitions of sql_only / full relative to the original verdict.
    transitions: dict[str, list[dict[str, Any]]] = {
        "gained": [],
        "lost": [],
        "changed_gold": [],
    }
    per_record: list[dict[str, Any]] = []

    def _count(bucket: str, variant: str) -> int:
        """Count the entries of one transition bucket for one variant."""
        return sum(1 for t in transitions[bucket] if t["variant"] == variant)

    for i, rec in enumerate(records, 1):
        qid = rec["question_id"]
        db_id = rec["db_id"]
        difficulty = rec["difficulty"]
        pred_sql = rec.get("pred_sql") or ""
        orig_match = bool(rec.get("match"))
        original_gold = rec.get("gold_sql") or ""

        # `_load_arcwise` keys its dicts by int(question_id); normalise the
        # report's id the same way so a string id does not silently miss
        # every corrected gold.
        try:
            arc_key = int(qid)
        except (TypeError, ValueError):
            arc_key = qid

        # Gold SQL per variant, resolved once and reused for both scoring
        # and the changed-gold check below. Missing/null SQL becomes "".
        gold_sources: dict[str, str] = {
            "original": original_gold,
            "sql_only": arc_sql.get(arc_key, {}).get("SQL") or "",
            "full": arc_full.get(arc_key, {}).get("SQL") or "",
        }

        spec = registry.get(f"bird_{db_id}")
        engine = spec.make_engine()
        out_entry: dict[str, Any] = {
            "question_id": qid,
            "db_id": db_id,
            "difficulty": difficulty,
            "pred_sql": pred_sql,
            "original_match": orig_match,
        }
        try:
            # Execute pred once, reuse rows. Route pred through `execute_readonly`
            # directly (matches canonical `scripts/audit_rescore.py`): the
            # `_execute_gold` SQLAlchemyError fallback is intended only for
            # trusted BIRD gold SQL, not for model-generated pred SQL — using
            # it on pred can mask validator-style failures and yields
            # non-deterministic engine state across sequential records.
            pred_rows: list[tuple[Any, ...]] = []
            pred_failed = False
            if pred_sql.strip():
                try:
                    with execute_readonly(
                        engine, pred_sql, statement_timeout_ms=30_000, row_cap=10_000
                    ) as result:
                        pred_rows = list(result.rows)
                except Exception as exc:
                    # Best effort: a failing prediction is recorded and
                    # scored as a failure rather than aborting the run.
                    out_entry["pred_exec_error"] = str(exc)
                    pred_failed = True
            else:
                # An empty prediction is treated as a failed execution.
                pred_failed = True

            # Score against each variant.
            for variant in variants:
                source = gold_sources[variant]
                if not source:
                    continue
                gold_failed = False
                try:
                    gold_rows, _, gold_failed = _execute_gold_with_status(
                        engine, source, statement_timeout_ms=30_000, row_cap=10_000
                    )
                    if gold_failed:
                        out_entry[f"{variant}_gold_exec_error"] = (
                            "gold SQL crashed in both execute_readonly and raw-connection paths"
                        )
                except Exception as exc:
                    gold_rows = []
                    gold_failed = True
                    out_entry[f"{variant}_gold_exec_error"] = str(exc)

                cmp = safe_compare_pred(
                    gold_rows,
                    pred_rows,
                    gold_sql=source,
                    pred_failed=pred_failed,
                    gold_failed=gold_failed,
                )
                is_match = bool(cmp.match)
                out_entry[f"{variant}_match"] = is_match
                out_entry[f"{variant}_reason"] = cmp.reason
                out_entry[f"{variant}_gold_rows"] = len(gold_rows)
                total_scored[variant] += 1
                matched[variant] += int(is_match)
                per_diff[variant][difficulty][1] += 1
                per_diff[variant][difficulty][0] += int(is_match)

            # Transitions vs sql_only and vs full.
            for variant in ("sql_only", "full"):
                v_match = out_entry.get(f"{variant}_match")
                if v_match is None:
                    # Variant was not scored for this record (no gold SQL).
                    continue
                tag = {"qid": qid, "variant": variant, "difficulty": difficulty}
                if gold_sources[variant].strip() != original_gold.strip():
                    out_entry[f"{variant}_gold_changed"] = True
                    # Previously this bucket was declared and written to the
                    # output file but never filled.
                    transitions["changed_gold"].append(dict(tag))
                # The baseline is the `match` flag stored in the report, not
                # the freshly recomputed `original_match` verdict.
                if orig_match and not v_match:
                    transitions["lost"].append(dict(tag))
                elif (not orig_match) and v_match:
                    transitions["gained"].append(dict(tag))
        finally:
            engine.dispose()

        per_record.append(out_entry)
        if i % 25 == 0:
            print(f"[{i:3d}/{len(records)}] processed", file=sys.stderr)

    # Summary.
    print("\n=== Arcwise rescoring summary ===", file=sys.stderr)
    for variant in variants:
        total = total_scored[variant]
        count = matched[variant]
        pct = (count / total * 100) if total else 0.0
        print(f" {variant:10s}: {count}/{total} = {pct:.2f}%", file=sys.stderr)

    print("\n=== Per-tier ===", file=sys.stderr)
    for variant in variants:
        line = f" {variant:10s}: "
        for diff in tiers:
            mt, tot = per_diff[variant][diff]
            pct = (mt / tot * 100) if tot else 0.0
            line += f"{diff[:4]}={mt}/{tot}({pct:.1f}%) "
        print(line, file=sys.stderr)

    print("\n=== Transitions (vs original gold) ===", file=sys.stderr)
    # Every line filters by variant; the sql_only "gained" line used to
    # report the combined sql_only + full count.
    print(f" gained (sql_only): {_count('gained', 'sql_only')}", file=sys.stderr)
    print(f" lost (sql_only): {_count('lost', 'sql_only')}", file=sys.stderr)
    print(f" gained (full): {_count('gained', 'full')}", file=sys.stderr)
    print(f" lost (full): {_count('lost', 'full')}", file=sys.stderr)

    out_payload = {
        "source_report": str(args.report),
        "summary": {v: {"matched": matched[v], "total": total_scored[v]} for v in variants},
        "per_difficulty": {
            v: {
                d: {"matched": per_diff[v][d][0], "total": per_diff[v][d][1]}
                for d in tiers
            }
            for v in variants
        },
        "transitions": transitions,
        "records": per_record,
    }
    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(json.dumps(out_payload, indent=2, default=str), encoding="utf-8")
    print(f"\n[info] wrote {args.out}", file=sys.stderr)
    return 0
# Script entry point: run the CLI and hand its return code to the interpreter
# as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|