File size: 8,609 Bytes
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f75098
 
 
 
 
 
 
 
942050b
 
 
 
9f75098
 
 
 
 
 
942050b
 
 
 
 
 
9f75098
 
 
 
 
 
 
942050b
 
 
 
 
9f75098
 
942050b
4b4ff9e
9f75098
 
 
 
942050b
 
9f75098
 
 
942050b
4b4ff9e
942050b
 
3f1a281
 
 
 
 
 
4b4ff9e
3f1a281
4b4ff9e
3f1a281
 
 
 
 
 
 
4b4ff9e
 
 
 
 
 
 
 
 
 
3f1a281
 
 
 
 
 
4b4ff9e
3f1a281
4b4ff9e
 
 
 
 
 
 
3f1a281
 
 
 
 
 
 
 
 
 
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""Execution Accuracy (EA) β€” primary BIRD Mini-Dev metric.

Reference implementation: bird-bench/mini_dev `evaluation_ex.py`. The official
script does set-equality on row tuples after running gold + pred against the
same sqlite DB. We match that behaviour and add three guards:

1. Floats compared with absolute tolerance (1e-6) so trivial CAST/precision
   differences don't flip a correct query to a fail.
2. Rows are normalised to tuples; column names are NOT compared (BIRD
   accepts any aliasing as long as values match).
3. ORDER BY in gold → order-sensitive comparison. Otherwise set equality.
   This is stricter than the stock BIRD script (which is always set-eq), but
   more honest: a "top 5 by sales" question with gold ORDER BY is wrong
   when the predicted result is in arbitrary order.

`compare_results` is the single source of truth used by `runner.py` and by
unit tests; `execution_accuracy(matches)` aggregates a sequence of match
flags into a fraction in [0, 1].
"""

from __future__ import annotations

import re
from collections.abc import Sequence
from dataclasses import dataclass
from decimal import Decimal
from typing import Any

_FLOAT_TOLERANCE = 1e-6
_ORDER_BY_RE = re.compile(r"\border\s+by\b", re.IGNORECASE)


@dataclass(frozen=True, slots=True)
class ResultComparison:
    """Outcome of comparing one (gold, pred) row pair.

    `match` is the EA bit. `reason` describes why a comparison failed, used
    for slicing the report (e.g. "row count mismatch" vs "value mismatch").
    """

    match: bool
    reason: str = ""
    gold_rows: int = 0
    pred_rows: int = 0


# Quoted literals / identifiers and comments are blanked out before scanning
# for a top-level ORDER BY, so parentheses or keywords inside them are ignored.
_SQL_OPAQUE_RE = re.compile(
    r"'(?:[^']|'')*'"  # 'string literal' ('' escapes a quote)
    r'|"(?:[^"]|"")*"'  # "quoted identifier"
    r"|`[^`]*`"  # `backtick identifier`
    r"|\[[^\]]*\]"  # [bracket identifier]
    r"|--[^\n]*"  # -- line comment
    r"|/\*.*?(?:\*/|\Z)",  # /* block comment */ (possibly unterminated)
    re.DOTALL,
)
# Tokens that matter for scope tracking: parentheses and the ORDER BY pair.
_SQL_SCOPE_TOKEN_RE = re.compile(r"\(|\)|\border\s+by\b", re.IGNORECASE)


def _has_top_level_order_by(sql: str) -> bool:
    """Return True iff *sql* has an ORDER BY outside every parenthesis.

    Only a top-level ORDER BY fixes the order of the rows the statement
    returns. An ORDER BY nested in parentheses (a subquery such as
    ``IN (SELECT ... ORDER BY ... LIMIT 1)``, a CTE body, or a window
    ``OVER (ORDER BY ...)``) leaves the final row order undefined.
    """
    if not _ORDER_BY_RE.search(sql):  # cheap pre-check for the common case
        return False
    depth = 0
    for token in _SQL_SCOPE_TOKEN_RE.finditer(_SQL_OPAQUE_RE.sub(" ", sql)):
        text = token.group()
        if text == "(":
            depth += 1
        elif text == ")":
            depth = max(0, depth - 1)  # tolerate unbalanced input
        elif depth == 0:
            return True
    return False


def compare_results(
    gold_rows: Sequence[Sequence[Any]],
    pred_rows: Sequence[Sequence[Any]],
    *,
    gold_sql: str | None = None,
) -> ResultComparison:
    """Compare two result sets BIRD-style with two extensions.

    Default — BIRD-official set-equality on row tuples (so reported EA stays
    apples-to-apples with the BIRD Mini-Dev leaderboard, AskData+GPT-4o,
    CHESS, XiYan, etc.). A pred with ``DISTINCT`` over a gold without one is
    still a match if the underlying unique rows agree — which is exactly how
    the official `bird-bench/mini_dev/evaluation_ex.py` script scores it.

    Extensions on top of vanilla BIRD:

    - Float tolerance: ``abs(a - b) <= 1e-6``.
    - Order-sensitive iff ``gold_sql`` has a *top-level* ``ORDER BY``
      (case-insensitive). An ORDER BY that only appears inside parentheses
      (a subquery, CTE body or window clause) does not determine the order
      of the final rows, so it does not force ordered comparison. ORDER BY
      text inside string literals, quoted identifiers or comments is ignored.
      Pass ``gold_sql=None`` to force set-equality (matches stock BIRD).

    Earlier revisions of this function used multiset (``collections.Counter``)
    equality. That was strictly more conservative than BIRD's own scoring and
    silently penalised pred SQLs that legitimately deduplicated, making
    reported numbers incomparable to the leaderboard.
    NOTE(review): this docstring used to point to a "dedicated multiset
    helper below", but no such helper is defined in this module. Confirm
    before relying on one.

    Args:
        gold_rows: Rows returned by the gold SQL.
        pred_rows: Rows returned by the predicted SQL.
        gold_sql: Gold SQL text, used only to decide ordered vs set comparison.

    Returns:
        A ``ResultComparison``. ``gold_rows`` / ``pred_rows`` are the raw row
        counts (duplicates included). ``reason`` explains any mismatch.
    """
    gold_norm = [_normalise_row(r) for r in gold_rows]
    pred_norm = [_normalise_row(r) for r in pred_rows]
    n_gold, n_pred = len(gold_norm), len(pred_norm)

    if gold_sql is not None and _has_top_level_order_by(gold_sql):
        if n_gold != n_pred:
            return ResultComparison(
                match=False,
                reason=f"ordered row count mismatch: gold={n_gold}, pred={n_pred}",
                gold_rows=n_gold,
                pred_rows=n_pred,
            )
        for i, (g, p) in enumerate(zip(gold_norm, pred_norm, strict=True)):
            if not _row_equal(g, p):
                return ResultComparison(
                    match=False,
                    reason=f"ordered row {i} mismatch: gold={g!r}, pred={p!r}",
                    gold_rows=n_gold,
                    pred_rows=n_pred,
                )
        return ResultComparison(match=True, gold_rows=n_gold, pred_rows=n_pred)

    # Set equality on hashable keys: duplicates collapse, as in stock BIRD.
    gold_set = {_hashable(g) for g in gold_norm}
    pred_set = {_hashable(p) for p in pred_norm}
    if gold_set != pred_set:
        return ResultComparison(
            match=False,
            reason=f"set mismatch (unique rows differ): |gold|={len(gold_set)}, |pred|={len(pred_set)}",
            gold_rows=n_gold,
            pred_rows=n_pred,
        )
    return ResultComparison(match=True, gold_rows=n_gold, pred_rows=n_pred)


def safe_compare_pred(
    gold_rows: Sequence[Sequence[Any]],
    pred_rows: Sequence[Sequence[Any]],
    *,
    gold_sql: str | None = None,
    pred_failed: bool,
    gold_failed: bool = False,
) -> ResultComparison:
    """Like `compare_results`, but an execution failure on either side is a miss.

    `compare_results` only sees rows, so ``pred_rows=[]`` looks the same
    whether pred returned zero rows or raised before producing any. Against
    an empty gold result (BIRD quirks: empty filter results, missing Banned
    legalities, etc.) that yields ``match=True`` for SQL that never ran.
    This is a silent false positive. The gold side has the mirror defect:
    `_execute_gold` historically returned ``([], [])`` when BIRD's gold SQL
    crashed (~1% of cases), blessing any pred that also returned zero rows.

    The runner's `_run_one` and `_run_one_via_pipeline` paths already route
    failures through `_compare_outcome` / a direct
    ``ResultComparison(match=False)``. Voting and rescoring scripts that
    bypass the runner must call this helper instead of `compare_results`.

    Args:
        gold_rows: Gold result rows (only counted when pred failed).
        pred_rows: Pred result rows (only counted when gold failed).
        gold_sql: Forwarded to `compare_results`.
        pred_failed: Pass True when pred SQL raised.
        gold_failed: Pass True when gold SQL raised. Checked before
            ``pred_failed``.

    Returns:
        ``match=False`` with reason ``"gold execution failed"`` or
        ``"pred execution failed"`` (the failed side's row count reported
        as 0). Otherwise, the result of `compare_results`.

    History: discovered via Codex review of c74b46c → qid 518 (card_games
    moderate, "format with most banned cards"). A pred CTE was missing its
    WITH prefix, giving a SyntaxError on every execution. Gold returned 0 rows
    for that DB, so scoring had blessed it as match=True since v13 (helallao
    grok-4.1-reasoning rescue). Re-merge v22-v29 + 2026-05-25 EOD fix lands
    the correction. The gold-side mirror is Codex audit 2026-05-25 #1
    (`runner.py:960`).
    """
    # Happy path first: both sides executed, so defer to row comparison.
    if not (gold_failed or pred_failed):
        return compare_results(gold_rows, pred_rows, gold_sql=gold_sql)

    # A failed gold gives no trustworthy reference, so it wins over pred_failed.
    if gold_failed:
        reason, n_gold, n_pred = "gold execution failed", 0, len(pred_rows)
    else:
        reason, n_gold, n_pred = "pred execution failed", len(gold_rows), 0
    return ResultComparison(
        match=False,
        reason=reason,
        gold_rows=n_gold,
        pred_rows=n_pred,
    )


def execution_accuracy(matches: Sequence[bool]) -> float:
    """Share of truthy entries in *matches*, i.e. EA in [0, 1]. Empty gives 0.0."""
    if not matches:
        return 0.0
    hits = sum(map(bool, matches))  # bools sum as 0/1
    return hits / len(matches)


def _normalise_row(row: Sequence[Any]) -> tuple[Any, ...]:
    """Return *row* as a tuple, with each cell passed through `_normalise_cell`.

    All per-cell rules (NaN sentinel, bytes decoding, numeric handling) live
    in `_normalise_cell`. This wrapper only fixes the container type to
    ``tuple``, whatever sequence type the driver returned.
    """
    return tuple(map(_normalise_cell, row))


def _normalise_cell(value: Any) -> Any:
    if isinstance(value, bool):  # bool is a subclass of int β€” don't promote
        return value
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # NaN compares unequal to itself; map all NaN to a sentinel so two
        # NaN rows from the same query compare equal.
        if value != value:
            return "__NaN__"
        return float(value)
    if isinstance(value, bytes):
        try:
            return value.decode("utf-8")
        except UnicodeDecodeError:
            return value.hex()
    return value


def _row_equal(a: tuple[Any, ...], b: tuple[Any, ...]) -> bool:
    """True when *a* and *b* have the same width and every cell pair passes `_cell_equal`."""
    return len(a) == len(b) and all(map(_cell_equal, a, b))


def _cell_equal(a: Any, b: Any) -> bool:
    """Cell equality for the order-sensitive path.

    If either side is a ``float``, numeric values are compared with absolute
    tolerance ``_FLOAT_TOLERANCE``. Otherwise plain ``==`` decides.

    Fixes over the previous version:

    - Exact equality is checked before the tolerance test, so ``inf`` equals
      ``inf`` (``abs(inf - inf)`` is NaN, which always fails ``<=``).
    - Strings/bytes are never parsed as numbers. Previously
      ``float("1.0")`` let a TEXT ``"1.0"`` match a REAL ``1.0`` here, while
      the set-equality path treats them as different values.
    - ``OverflowError`` (``float()`` of an int too large for a double) now
      yields "not equal" instead of propagating.
    """
    if not (isinstance(a, float) or isinstance(b, float)):
        return bool(a == b)
    if isinstance(a, (str, bytes)) or isinstance(b, (str, bytes)):
        return False
    if a == b:
        return True
    try:
        return abs(float(a) - float(b)) <= _FLOAT_TOLERANCE
    except (TypeError, ValueError, OverflowError):
        return False


def _hashable(row: tuple[Any, ...]) -> tuple[Any, ...]:
    """Project a normalised row into a hashable key for set comparison.

    Numeric cells (``int``, ``bool`` and ``float``) are placed on one integer
    grid with spacing ``_FLOAT_TOLERANCE``. As a result:

    - ``1.0000001`` and ``1.0`` bucket together.
    - ``1`` (INTEGER) and ``1.0`` (REAL) bucket together. Plain Python set
      equality treats these as equal, and so does `_cell_equal` on the
      ordered path. Previously ints were left raw, so they never matched floats.
    - An int can no longer collide with an unrelated float's bucket index.
      Previously int ``5000000`` and float ``5.0`` produced the same key.

    Some floats cannot be placed on the grid: ``inf``, or magnitudes so large
    that the scaled value overflows. These are kept as-is and compare by
    exact equality, instead of ``round`` raising ``OverflowError``. NaN maps
    to the ``"__NaN__"`` sentinel. Strings, None and other values pass
    through unchanged.
    """
    # Grid points per unit (1e-6 -> 1_000_000). An integer scale maps ints
    # exactly, with no float rounding for large IDs.
    scale = max(1, round(1 / _FLOAT_TOLERANCE))
    out: list[Any] = []
    for v in row:
        if isinstance(v, int):  # includes bool, which equals 0/1
            out.append(v * scale)
        elif isinstance(v, float):
            if v != v:
                out.append("__NaN__")
                continue
            try:
                out.append(round(v * scale))
            except OverflowError:
                out.append(v)
        else:
            out.append(v)
    return tuple(out)