File size: 5,221 Bytes
942050b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | """Failure taxonomy for an eval report.
Given a `G_*hybrid*.json` (or any report with the standard record schema), classify
every failing example into an actionable bucket so accuracy work can target the
biggest pile instead of guessing.
Buckets:
    exec_failed        — pred SQL crashed at runtime
    exec_timeout       — pred SQL hit statement_timeout
    empty_result       — pred returned 0 rows, gold has rows
    row_count_off      — both ran, row counts differ (missing/extra GROUP BY,
                         missing LIMIT, missing JOIN filter)
    projection_diff    — same row count, column shape differs (wrong SELECT list)
    filter_or_value    — same shape, values differ (wrong WHERE, wrong JOIN,
                         wrong aggregation expression)
    order_by_off       — gold has ORDER BY and gold[0] != pred[0]
    numeric_precision  — single-row scalar, off by <1e-3 rel or CAST flavour
Usage:
uv run python scripts/error_taxonomy.py eval/reports/2026-05-11/G_*hybrid*.json
"""
from __future__ import annotations
import json
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any
_RE_ORDER_BY = re.compile(r"\border\s+by\b", re.IGNORECASE)
_RE_GROUP_BY = re.compile(r"\bgroup\s+by\b", re.IGNORECASE)
_RE_LIMIT = re.compile(r"\blimit\b", re.IGNORECASE)
_RE_AGG = re.compile(r"\b(sum|avg|count|min|max|cast)\s*\(", re.IGNORECASE)
_RE_CASE = re.compile(r"\bcase\s+when\b|\biif\s*\(", re.IGNORECASE)
def _classify(rec: dict[str, Any]) -> str:
if rec.get("match"):
return "match"
ek = (rec.get("error_kind") or "").lower()
if "timeout" in ek:
return "exec_timeout"
if ek in {"execution_failed", "validation_failed", "parse_failed"}:
return "exec_failed"
if ek == "empty_result":
return "empty_result"
gc = rec.get("gold_row_count") or 0
pc = rec.get("pred_row_count") or 0
gold = rec.get("gold_sql") or ""
reason = (rec.get("comparison_reason") or "").lower()
if gc != pc:
return "row_count_off"
if (
gc == pc == 1
and _RE_AGG.search(gold)
and (_RE_CASE.search(gold) or "/ " in gold or "*100" in gold.replace(" ", ""))
):
return "filter_or_value"
if reason.startswith("ordered row"):
if _RE_ORDER_BY.search(gold):
return "order_by_off"
return "filter_or_value"
if "column" in reason or "projection" in reason or "shape" in reason:
return "projection_diff"
return "filter_or_value"
def _exemplars(records: list[dict[str, Any]], bucket: str, n: int = 3) -> list[dict[str, Any]]:
items = [r for r in records if _classify(r) == bucket]
return items[:n]
def summarise(report_path: Path) -> dict[str, Any]:
    """Print a failure-taxonomy summary for one report and return the tallies.

    The report is read as UTF-8 JSON. Its ``records`` list is classified with
    ``_classify`` and counted per bucket and per difficulty; its ``overall``
    mapping supplies the headline metrics. Keys that are absent *or* null are
    treated as empty / zero, so a partially written report still summarises.

    Printed sections, in order: a header with the headline metrics, the
    bucket table, a per-difficulty breakdown, and up to two exemplar records
    for each of the four largest non-``match`` buckets.

    Args:
        report_path: Path to the JSON report.

    Returns:
        ``{"by_bucket": {bucket: count}, "by_difficulty": {difficulty:
        {bucket: count}}}``. Records without a difficulty are grouped under
        ``"?"``.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    data = json.loads(report_path.read_text(encoding="utf-8"))
    # `or` also covers keys that are present but null, which `.get(k, default)`
    # would hand back as None.
    records = data.get("records") or []
    overall = data.get("overall") or {}
    n = len(records)

    by_bucket: Counter[str] = Counter()
    by_diff_bucket: dict[str, Counter[str]] = defaultdict(Counter)
    for r in records:
        b = _classify(r)
        by_bucket[b] += 1
        by_diff_bucket[r.get("difficulty") or "?"][b] += 1

    print(f"\n=== {report_path.name} (n={n}) ===")
    ea, first_pass, valid, recall = (
        overall.get(key) or 0
        for key in ("ea", "first_pass_ea", "validity_rate", "schema_recall_at_k")
    )
    print(
        f"EA={ea:.3f} first_pass={first_pass:.3f} "
        f"valid={valid:.3f} recall={recall:.3f}"
    )

    print("\n bucket n % lift_if_solved")
    print(" ----------------- --- ---- --------------")
    # The loop body runs only when at least one record was counted, so n > 0.
    for bucket, cnt in by_bucket.most_common():
        pct = 100.0 * cnt / n
        # Solving a failure bucket would lift accuracy by its whole share;
        # "match" is already counted as correct.
        lift = 0.0 if bucket == "match" else pct
        print(f" {bucket:17s} {cnt:3d} {pct:5.1f}% +{lift:5.1f}pp")

    print("\n by difficulty:")
    known = ("simple", "moderate", "challenging")
    # Show the standard difficulties first, then any other label (including
    # the "?" fallback) so that no counted record is left out of the display.
    extras = sorted((d for d in by_diff_bucket if d not in known), key=str)
    for diff in (*known, *extras):
        bd = by_diff_bucket.get(diff)
        if not bd:
            continue
        total = sum(bd.values())
        miss = total - bd.get("match", 0)
        print(
            f" {str(diff):12s} n={total:3d} miss={miss:3d} "
            + ", ".join(f"{k}={v}" for k, v in bd.most_common() if k != "match")
        )

    print("\n top failure buckets — exemplars:")
    failing = [b for b, _ in by_bucket.most_common() if b != "match"]
    for bucket in failing[:4]:
        print(f"\n [{bucket}]")
        for r in _exemplars(records, bucket, n=2):
            q = (r.get("question") or "").replace("\n", " ")[:120]
            # Tolerate records lacking identifying fields, the same way the
            # tally above tolerates a missing difficulty.
            qid = r.get("question_id", "?")
            difficulty = r.get("difficulty") or "?"
            db_id = r.get("db_id", "?")
            print(f" qid={qid} ({difficulty}, {db_id}): {q}")
            print(f" gold: {(r.get('gold_sql') or '')[:140]}")
            print(f" pred: {(r.get('pred_sql') or '')[:140]}")
            print(f" reason: {r.get('comparison_reason') or r.get('error_kind') or 'n/a'}")

    return {
        "by_bucket": dict(by_bucket),
        "by_difficulty": {k: dict(v) for k, v in by_diff_bucket.items()},
    }
def main() -> int:
    """Command-line entry point.

    Every argument after the program name is taken as a report path and
    passed to ``summarise`` in order.

    Returns:
        ``2`` after printing the module docstring when no path was given,
        otherwise ``0`` once every path has been summarised.
    """
    report_args = sys.argv[1:]
    if not report_args:
        print(__doc__)
        return 2
    for arg in report_args:
        summarise(Path(arg))
    return 0
# Run the CLI only when this file is executed as a script; the integer
# returned by main() becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|