File size: 5,221 Bytes
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""Failure taxonomy for an eval report.

Given a `G_*hybrid*.json` (or any report with the standard record schema), classify
every failing example into an actionable bucket so accuracy work can target the
biggest pile instead of guessing.

Buckets:
    match               — example passed (shown in the summary table for reference)
    exec_failed         — pred SQL failed to parse, validate, or execute
    exec_timeout        — pred SQL hit statement_timeout
    empty_result        — pred returned 0 rows, gold has rows
    row_count_off       — both ran, row counts differ (missing/extra GROUP BY,
                          missing LIMIT, missing JOIN filter)
    projection_diff     — same row count, column shape differs (wrong SELECT list)
    filter_or_value     — same shape, values differ (wrong WHERE, wrong JOIN,
                          wrong aggregation expression)
    order_by_off        — gold has ORDER BY and the ordered row comparison failed
    numeric_precision   — single-row scalar, off by <1e-3 rel or CAST flavour
                          (reserved: not currently emitted by `_classify`)

Usage:
    uv run python scripts/error_taxonomy.py eval/reports/2026-05-11/G_*hybrid*.json
"""

from __future__ import annotations

import json
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

# Case-insensitive keyword detectors applied to SQL text.
# _RE_GROUP_BY and _RE_LIMIT are not currently referenced elsewhere in this script.
_RE_ORDER_BY = re.compile(r"\border\s+by\b", re.I)  # ORDER BY, any whitespace between words
_RE_GROUP_BY = re.compile(r"\bgroup\s+by\b", re.I)
_RE_LIMIT = re.compile(r"\blimit\b", re.I)
# Aggregate calls plus CAST(...); whitespace is allowed before the parenthesis.
_RE_AGG = re.compile(r"\b(sum|avg|count|min|max|cast)\s*\(", re.I)
# Conditional expressions: CASE WHEN ... or an IIF(...) call.
_RE_CASE = re.compile(r"\bcase\s+when\b|\biif\s*\(", re.I)


def _classify(rec: dict[str, Any]) -> str:
    """Return the taxonomy bucket for one eval record.

    Rules are applied in priority order and the first hit wins:

    1. ``"match"`` when the record's ``match`` flag is truthy.
    2. The record's ``error_kind`` (case-insensitive): anything containing
       "timeout", then parse/validation/execution failures, then
       ``empty_result``.
    3. ``"row_count_off"`` when gold and pred row counts differ (a missing
       count is treated as 0).
    4. ``"filter_or_value"`` for single-row results whose gold SQL has an
       aggregate/CAST call together with CASE WHEN / IIF, a ``"/ "``
       division, or a ``*100`` factor.
    5. The ``comparison_reason`` text: ``"ordered row..."`` gives
       ``"order_by_off"`` if gold has ORDER BY (else ``"filter_or_value"``).
       A mention of column/projection/shape gives ``"projection_diff"``.
    6. ``"filter_or_value"`` as the catch-all.
    """
    if rec.get("match"):
        return "match"

    kind = (rec.get("error_kind") or "").lower()
    if "timeout" in kind:
        return "exec_timeout"
    if kind in ("execution_failed", "validation_failed", "parse_failed"):
        return "exec_failed"
    if kind == "empty_result":
        return "empty_result"

    gold_rows = rec.get("gold_row_count") or 0
    pred_rows = rec.get("pred_row_count") or 0
    gold_sql = rec.get("gold_sql") or ""
    why = (rec.get("comparison_reason") or "").lower()

    if gold_rows != pred_rows:
        return "row_count_off"

    # Single-row ratio / CASE aggregates: treat a mismatch as a value error
    # without consulting the order/projection heuristics below.
    if gold_rows == pred_rows == 1 and _RE_AGG.search(gold_sql):
        if _RE_CASE.search(gold_sql) or "/ " in gold_sql or "*100" in gold_sql.replace(" ", ""):
            return "filter_or_value"

    if why.startswith("ordered row"):
        return "order_by_off" if _RE_ORDER_BY.search(gold_sql) else "filter_or_value"

    if any(word in why for word in ("column", "projection", "shape")):
        return "projection_diff"

    return "filter_or_value"


def _exemplars(records: list[dict[str, Any]], bucket: str, n: int = 3) -> list[dict[str, Any]]:
    """Return the first ``n`` records (in input order) that ``_classify`` puts in ``bucket``."""
    in_bucket = list(filter(lambda record: _classify(record) == bucket, records))
    return in_bucket[:n]


def summarise(report_path: Path) -> dict[str, Any]:
    """Print a failure-taxonomy summary for one eval report and return the tallies.

    Reads the JSON report at ``report_path``, classifies every record with
    ``_classify``, and prints the following to stdout:

    * the headline metrics from the report's ``overall`` section;
    * a bucket table with the count, the share of all records, and the lift
      in percentage points if that bucket were fully solved;
    * per-difficulty miss counts for ``simple`` / ``moderate`` / ``challenging``;
    * up to two exemplar records for each of the four largest failure buckets.

    Args:
        report_path: Path to a report JSON with a top-level ``records`` list
            and, optionally, an ``overall`` metrics dict.

    Returns:
        ``{"by_bucket": {bucket: count},
        "by_difficulty": {difficulty: {bucket: count}}}``.

    Raises:
        OSError: if the report cannot be read.
        json.JSONDecodeError: if the report is not valid JSON.
    """
    data = json.loads(report_path.read_text(encoding="utf-8"))
    # `or` guards tolerate explicit JSON nulls as well as missing keys.
    records: list[dict[str, Any]] = data.get("records") or []
    n = len(records)

    by_bucket: Counter[str] = Counter()
    by_diff_bucket: dict[str, Counter[str]] = defaultdict(Counter)
    for r in records:
        b = _classify(r)
        by_bucket[b] += 1
        by_diff_bucket[r.get("difficulty", "?")][b] += 1

    print(f"\n=== {report_path.name} (n={n}) ===")
    overall = data.get("overall") or {}
    ea, first_pass, valid, recall = (
        overall.get(key) or 0
        for key in ("ea", "first_pass_ea", "validity_rate", "schema_recall_at_k")
    )
    print(f"EA={ea:.3f}  first_pass={first_pass:.3f}  valid={valid:.3f}  recall={recall:.3f}")

    print("\n  bucket             n     %   lift_if_solved")
    print("  ----------------- ---  ----  --------------")
    for bucket, cnt in by_bucket.most_common():
        # n > 0 here: by_bucket only has entries when there are records.
        pct = 100.0 * cnt / n
        # A failure bucket's lift is simply its share of all records.
        lift = 0.0 if bucket == "match" else pct
        print(f"  {bucket:17s} {cnt:3d}  {pct:5.1f}%  +{lift:5.1f}pp")

    print("\n  by difficulty:")
    for diff in ("simple", "moderate", "challenging"):
        bd = by_diff_bucket.get(diff)
        if not bd:
            continue
        total = sum(bd.values())
        miss = total - bd["match"]  # Counter yields 0 for a missing key
        failures = ", ".join(f"{k}={v}" for k, v in bd.most_common() if k != "match")
        print(f"    {diff:12s} n={total:3d} miss={miss:3d}  {failures}")

    print("\n  top failure buckets — exemplars:")
    failure_buckets = [b for b, _ in by_bucket.most_common() if b != "match"]
    for bucket in failure_buckets[:4]:
        ex = _exemplars(records, bucket, n=2)
        print(f"\n  [{bucket}]")
        for r in ex:
            # Collapse newlines so multi-line questions/SQL stay on one output line.
            q = (r.get("question") or "").replace("\n", " ")[:120]
            gold = (r.get("gold_sql") or "").replace("\n", " ")[:140]
            pred = (r.get("pred_sql") or "").replace("\n", " ")[:140]
            # .get() so one record with missing metadata cannot abort the whole report.
            qid = r.get("question_id", "?")
            print(f"    qid={qid} ({r.get('difficulty', '?')}, {r.get('db_id', '?')}): {q}")
            print(f"      gold: {gold}")
            print(f"      pred: {pred}")
            print(f"      reason: {r.get('comparison_reason') or r.get('error_kind') or 'n/a'}")

    return {
        "by_bucket": dict(by_bucket),
        "by_difficulty": {k: dict(v) for k, v in by_diff_bucket.items()},
    }


def main() -> int:
    """Summarise every report path given on the command line.

    With no arguments, prints the module docstring as usage and returns 2.
    A path that is not an existing file, or whose contents are not valid
    JSON, is reported on stderr and skipped, so the remaining reports are
    still summarised.

    Returns:
        The process exit status: 0 if every report was summarised, 1 if any
        report was skipped, and 2 on a usage error.
    """
    if len(sys.argv) < 2:
        print(__doc__)
        return 2
    status = 0
    for path in sys.argv[1:]:
        report = Path(path)
        # An unmatched shell glob reaches us as a literal pattern; skip it cleanly.
        if not report.is_file():
            print(f"error: no such report file: {path}", file=sys.stderr)
            status = 1
            continue
        try:
            summarise(report)
        except json.JSONDecodeError as exc:
            print(f"error: {path} is not valid JSON: {exc}", file=sys.stderr)
            status = 1
    return status


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())