File size: 8,609 Bytes
"""Execution Accuracy (EA) — primary BIRD Mini-Dev metric.

Reference implementation: bird-bench/mini_dev `evaluation_ex.py`. The official
script does set-equality on row tuples after running gold + pred against the
same sqlite DB. We match that behaviour and add three guards:

1. Floats are compared with an absolute tolerance (1e-6) so trivial
   CAST/precision differences don't flip a correct query to a fail.
2. Rows are normalised to tuples; column names are NOT compared (BIRD
   accepts any aliasing as long as values match).
3. ORDER BY in gold → order-sensitive comparison. Otherwise set equality.
   This is stricter than the stock BIRD script (which is always set-eq), but
   more honest: a "top 5 by sales" question with gold ORDER BY is wrong
   when the predicted result is in arbitrary order.

`compare_results` is the single source of truth used by `runner.py` and by
unit tests; `execution_accuracy(matches)` aggregates a list of per-question
match bits into a fraction in [0, 1].
"""
from __future__ import annotations

import math
import re
from collections.abc import Sequence
from dataclasses import dataclass
from decimal import Decimal
from typing import Any
_FLOAT_TOLERANCE = 1e-6
_ORDER_BY_RE = re.compile(r"\border\s+by\b", re.IGNORECASE)
@dataclass(frozen=True, slots=True)
class ResultComparison:
"""Outcome of comparing one (gold, pred) row pair.
`match` is the EA bit. `reason` describes why a comparison failed, used
for slicing the report (e.g. "row count mismatch" vs "value mismatch").
"""
match: bool
reason: str = ""
gold_rows: int = 0
pred_rows: int = 0
def compare_results(
    gold_rows: Sequence[Sequence[Any]],
    pred_rows: Sequence[Sequence[Any]],
    *,
    gold_sql: str | None = None,
) -> ResultComparison:
    """Compare two result sets BIRD-style with two extensions.

    Default — BIRD-official set-equality on row tuples (so reported EA stays
    apples-to-apples with the BIRD Mini-Dev leaderboard, AskData+GPT-4o,
    CHESS, XiYan, etc.). A pred with ``DISTINCT`` over a gold without one is
    still a match if the underlying unique rows agree — which is exactly how
    the official `bird-bench/mini_dev/evaluation_ex.py` script scores it.

    Extensions on top of vanilla BIRD:

    - Float tolerance: ordered mode compares cells with
      ``abs(a - b) <= 1e-6`` (``_cell_equal``); set mode buckets floats onto
      a 1e-6 grid (``_hashable``).
    - Order-sensitive iff ``gold_sql`` has an ``ORDER BY`` (case-insensitive)
      at the *outermost* level of the statement. An ORDER BY nested inside
      parentheses (scalar / ``IN`` subqueries such as
      ``WHERE id = (SELECT id ... ORDER BY x LIMIT 1)``, derived tables, CTE
      bodies, window ``OVER (...)`` specs, aggregate arguments), inside a
      string literal / quoted identifier, or inside an SQL comment does not
      determine the order of the final result set, so it does not trigger
      ordered comparison. Pass ``gold_sql=None`` to force set-equality
      (matches stock BIRD).

    Earlier revisions of this function used multiset (``collections.Counter``)
    equality. That was strictly more conservative than BIRD's own scoring and
    silently penalised pred SQLs that legitimately deduplicated, making
    reported numbers incomparable to the leaderboard. This function does not
    offer strict-duplicate (multiset) semantics.

    ``gold_rows`` / ``pred_rows`` are the fetched rows of each query;
    ``gold_sql`` is only inspected for a top-level ORDER BY. Returns a
    ``ResultComparison`` whose ``reason`` explains the first detected
    difference on failure.
    """
    gold_norm = [_normalise_row(r) for r in gold_rows]
    pred_norm = [_normalise_row(r) for r in pred_rows]
    n_gold, n_pred = len(gold_norm), len(pred_norm)

    # Only an outermost ORDER BY fixes the order of the rows we compare.
    order_sensitive = gold_sql is not None and bool(
        _ORDER_BY_RE.search(_top_level_sql(gold_sql))
    )

    if order_sensitive:
        if n_gold != n_pred:
            return ResultComparison(
                match=False,
                reason=f"ordered row count mismatch: gold={n_gold}, pred={n_pred}",
                gold_rows=n_gold,
                pred_rows=n_pred,
            )
        for i, (g, p) in enumerate(zip(gold_norm, pred_norm, strict=True)):
            if not _row_equal(g, p):
                return ResultComparison(
                    match=False,
                    reason=f"ordered row {i} mismatch: gold={g!r}, pred={p!r}",
                    gold_rows=n_gold,
                    pred_rows=n_pred,
                )
        return ResultComparison(match=True, gold_rows=n_gold, pred_rows=n_pred)

    gold_set = {_hashable(g) for g in gold_norm}
    pred_set = {_hashable(p) for p in pred_norm}
    if gold_set != pred_set:
        return ResultComparison(
            match=False,
            reason=f"set mismatch (unique rows differ): |gold|={len(gold_set)}, |pred|={len(pred_set)}",
            gold_rows=n_gold,
            pred_rows=n_pred,
        )
    return ResultComparison(match=True, gold_rows=n_gold, pred_rows=n_pred)


def _top_level_sql(sql: str) -> str:
    """Return *sql* with everything except outermost-level text blanked out.

    Text inside parentheses (at any depth), string literals / quoted
    identifiers (``'...'``, ``"..."``, backtick-quoted, ``[...]``) and
    comments (``--`` to end of line, ``/* ... */``) are replaced by spaces,
    so keyword searches only see clauses of the outermost statement. An
    unterminated quote or comment blanks the rest of the string. SQL's
    doubled-quote escape (``'it''s'``) needs no special case: it is scanned
    as two adjacent literals, both of which are blanked.
    """
    closers = {"'": "'", '"': '"', "`": "`", "[": "]"}
    pieces: list[str] = []
    depth = 0
    i = 0
    n = len(sql)
    while i < n:
        ch = sql[i]
        if sql.startswith("--", i):
            # Line comment: resume at the newline (kept as ordinary text).
            newline = sql.find("\n", i + 2)
            i = n if newline == -1 else newline
            pieces.append(" ")
        elif sql.startswith("/*", i):
            close = sql.find("*/", i + 2)
            i = n if close == -1 else close + 2
            pieces.append(" ")
        elif ch in closers:
            close = sql.find(closers[ch], i + 1)
            i = n if close == -1 else close + 1
            pieces.append(" ")
        else:
            if ch == "(":
                depth += 1
                pieces.append(" ")
            elif ch == ")":
                # Clamp so a stray ")" cannot hide the rest of the statement.
                depth = max(depth - 1, 0)
                pieces.append(" ")
            else:
                pieces.append(ch if depth == 0 else " ")
            i += 1
    return "".join(pieces)
def safe_compare_pred(
    gold_rows: Sequence[Sequence[Any]],
    pred_rows: Sequence[Sequence[Any]],
    *,
    gold_sql: str | None = None,
    pred_failed: bool,
    gold_failed: bool = False,
) -> ResultComparison:
    """Compare results, but force a mismatch when either query failed to run.

    ``compare_results`` only sees rows, so it cannot tell "pred returned zero
    rows" apart from "pred raised before producing any". When gold is also
    empty (BIRD quirks: empty filter results, missing Banned legalities,
    etc.), ``compare_results([], [])`` reports match=True — a silent false
    positive for malformed pred SQL. The gold side has the mirror-image hole:
    `_execute_gold` historically returned ``([], [])`` when BIRD's gold SQL
    crashed (~1% of cases), so any pred that also returned zero rows was
    blessed as match=True.

    The runner's `_run_one` and `_run_one_via_pipeline` paths already route
    pred-failure and gold-failure through `_compare_outcome` / a direct
    `ResultComparison(match=False)`. Voting and rescoring scripts that bypass
    the runner must call this helper instead of `compare_results`.

    Pass ``pred_failed=True`` when the pred SQL raised and
    ``gold_failed=True`` when the gold SQL raised. Gold failure is checked
    first, so it determines the reason when both failed. With neither flag
    set, the call is delegated unchanged to ``compare_results``.

    History: discovered via Codex review of c74b46c — qid 518 (card_games
    moderate "format with most banned cards"): pred CTE missing the WITH
    prefix, SyntaxError on every execution, gold returns 0 rows for that DB,
    scoring blessed it as match=True since v13 (helallao grok-4.1-reasoning
    rescue). Re-merge v22-v29 + 2026-05-25 EOD fix lands the correction.
    Gold-side mirror is Codex audit 2026-05-25 #1 (`runner.py:960`).
    """
    if not (gold_failed or pred_failed):
        return compare_results(gold_rows, pred_rows, gold_sql=gold_sql)

    if gold_failed:
        # Gold crashed: nothing to compare against, whatever pred produced.
        return ResultComparison(
            match=False,
            reason="gold execution failed",
            gold_rows=0,
            pred_rows=len(pred_rows),
        )

    # Only pred crashed: an empty pred must not be blessed against empty gold.
    return ResultComparison(
        match=False,
        reason="pred execution failed",
        gold_rows=len(gold_rows),
        pred_rows=0,
    )
def execution_accuracy(matches: Sequence[bool]) -> float:
    """Return the share of truthy entries in *matches* (EA in [0, 1]).

    An empty sequence yields 0.0 rather than dividing by zero.
    """
    if not matches:
        return 0.0
    hits = sum(map(bool, matches))
    return hits / len(matches)
def _normalise_row(row: Sequence[Any]) -> tuple[Any, ...]:
    """Return *row* as a tuple with every cell passed through ``_normalise_cell``.

    Column order and width are preserved; the per-cell rules (bool/int
    pass-through, NaN sentinel, bytes decoding, ...) live in
    ``_normalise_cell``.
    """
    return tuple(map(_normalise_cell, row))
def _normalise_cell(value: Any) -> Any:
if isinstance(value, bool): # bool is a subclass of int β don't promote
return value
if isinstance(value, int):
return value
if isinstance(value, float):
# NaN compares unequal to itself; map all NaN to a sentinel so two
# NaN rows from the same query compare equal.
if value != value:
return "__NaN__"
return float(value)
if isinstance(value, bytes):
try:
return value.decode("utf-8")
except UnicodeDecodeError:
return value.hex()
return value
def _row_equal(a: tuple[Any, ...], b: tuple[Any, ...]) -> bool:
if len(a) != len(b):
return False
return all(_cell_equal(x, y) for x, y in zip(a, b, strict=True))
def _cell_equal(a: Any, b: Any) -> bool:
if isinstance(a, float) or isinstance(b, float):
try:
return abs(float(a) - float(b)) <= _FLOAT_TOLERANCE
except (TypeError, ValueError):
return False
return bool(a == b)
def _hashable(row: tuple[Any, ...]) -> tuple[Any, ...]:
"""Project a row into a hashable representation for multiset comparison.
Floats are quantised to the tolerance grid so that 1.0000001 and 1.0
bucket together. Strings/ints/None pass through.
"""
out: list[Any] = []
for v in row:
if isinstance(v, float):
out.append(round(v / _FLOAT_TOLERANCE) if v == v else "__NaN__")
else:
out.append(v)
return tuple(out)
|