"""Failure taxonomy for an eval report. Given a `G_*hybrid*.json` (or any report with the standard record schema), classify every failing example into an actionable bucket so accuracy work can target the biggest pile instead of guessing. Buckets: exec_failed — pred SQL crashed at runtime exec_timeout — pred SQL hit statement_timeout empty_result — pred returned 0 rows, gold has rows row_count_off — both ran, row counts differ (missing/extra GROUP BY, missing LIMIT, missing JOIN filter) projection_diff — same row count, column shape differs (wrong SELECT list) filter_or_value — same shape, values differ (wrong WHERE, wrong JOIN, wrong aggregation expression) order_by_off — gold has ORDER BY and gold[0] != pred[0] numeric_precision — single-row scalar, off by <1e-3 rel or CAST flavour Usage: uv run python scripts/error_taxonomy.py eval/reports/2026-05-11/G_*hybrid*.json """ from __future__ import annotations import json import re import sys from collections import Counter, defaultdict from pathlib import Path from typing import Any _RE_ORDER_BY = re.compile(r"\border\s+by\b", re.IGNORECASE) _RE_GROUP_BY = re.compile(r"\bgroup\s+by\b", re.IGNORECASE) _RE_LIMIT = re.compile(r"\blimit\b", re.IGNORECASE) _RE_AGG = re.compile(r"\b(sum|avg|count|min|max|cast)\s*\(", re.IGNORECASE) _RE_CASE = re.compile(r"\bcase\s+when\b|\biif\s*\(", re.IGNORECASE) def _classify(rec: dict[str, Any]) -> str: if rec.get("match"): return "match" ek = (rec.get("error_kind") or "").lower() if "timeout" in ek: return "exec_timeout" if ek in {"execution_failed", "validation_failed", "parse_failed"}: return "exec_failed" if ek == "empty_result": return "empty_result" gc = rec.get("gold_row_count") or 0 pc = rec.get("pred_row_count") or 0 gold = rec.get("gold_sql") or "" reason = (rec.get("comparison_reason") or "").lower() if gc != pc: return "row_count_off" if ( gc == pc == 1 and _RE_AGG.search(gold) and (_RE_CASE.search(gold) or "/ " in gold or "*100" in gold.replace(" ", "")) ): return "filter_or_value" if reason.startswith("ordered row"): if _RE_ORDER_BY.search(gold): return "order_by_off" return "filter_or_value" if "column" in reason or "projection" in reason or "shape" in reason: return "projection_diff" return "filter_or_value" def _exemplars(records: list[dict[str, Any]], bucket: str, n: int = 3) -> list[dict[str, Any]]: items = [r for r in records if _classify(r) == bucket] return items[:n] def summarise(report_path: Path) -> dict[str, Any]: data = json.loads(report_path.read_text(encoding="utf-8")) records = data.get("records", []) n = len(records) by_bucket: Counter[str] = Counter() by_diff_bucket: dict[str, Counter[str]] = defaultdict(Counter) for r in records: b = _classify(r) by_bucket[b] += 1 by_diff_bucket[r.get("difficulty", "?")][b] += 1 print(f"\n=== {report_path.name} (n={n}) ===") overall = data.get("overall", {}) print( f"EA={overall.get('ea', 0):.3f} first_pass={overall.get('first_pass_ea', 0):.3f} " f"valid={overall.get('validity_rate', 0):.3f} recall={overall.get('schema_recall_at_k', 0):.3f}" ) print("\n bucket n % lift_if_solved") print(" ----------------- --- ---- --------------") for bucket, cnt in by_bucket.most_common(): pct = 100.0 * cnt / n lift = 0.0 if bucket == "match" else 100.0 * cnt / n print(f" {bucket:17s} {cnt:3d} {pct:5.1f}% +{lift:5.1f}pp") print("\n by difficulty:") for diff in ("simple", "moderate", "challenging"): bd = by_diff_bucket.get(diff) if not bd: continue total = sum(bd.values()) miss = total - bd.get("match", 0) if total: print( f" {diff:12s} n={total:3d} miss={miss:3d} " + ", ".join(f"{k}={v}" for k, v in bd.most_common() if k != "match") ) print("\n top failure buckets — exemplars:") for bucket, _ in [(b, c) for b, c in by_bucket.most_common() if b != "match"][:4]: ex = _exemplars(records, bucket, n=2) print(f"\n [{bucket}]") for r in ex: q = (r.get("question") or "").replace("\n", " ")[:120] print(f" qid={r['question_id']} ({r['difficulty']}, {r['db_id']}): {q}") print(f" gold: {(r.get('gold_sql') or '')[:140]}") print(f" pred: {(r.get('pred_sql') or '')[:140]}") print(f" reason: {r.get('comparison_reason') or r.get('error_kind') or 'n/a'}") return { "by_bucket": dict(by_bucket), "by_difficulty": {k: dict(v) for k, v in by_diff_bucket.items()}, } def main() -> int: if len(sys.argv) < 2: print(__doc__) return 2 for path in sys.argv[1:]: summarise(Path(path)) return 0 if __name__ == "__main__": raise SystemExit(main())