File size: 4,353 Bytes
d48602c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 | """Execution-based self-consistency voting for SQL candidates.
For a single question we run the LangGraph pipeline N times at distinct
sampling temperatures, collect the candidates, and pick the one whose
execution result has the largest agreement cluster.
This is the standard NL→SQL technique from Wang et al. (2023) — clustering
on the *execution result* (not the SQL string) tolerates equivalent SQL
spelt differently and is robust to small surface-level diversity.
"""
from __future__ import annotations
import hashlib
from collections import defaultdict
from dataclasses import dataclass
from typing import Any
from nl_sql.agent.graph import PipelineRunResult
from nl_sql.execution.errors import ExecutionErrorKind
@dataclass(frozen=True, slots=True)
class Candidate:
"""One pipeline pass + its sampling temperature."""
result: PipelineRunResult
temperature: float
def fingerprint_rows(rows: list[tuple[Any, ...]]) -> str:
    """Return a SHA-256 hex digest that identifies *rows* regardless of row order.

    BIRD-style execution accuracy ignores row order unless the gold SQL has an
    ORDER BY, so rows are canonicalised and sorted before hashing. Every cell
    first goes through ``_normalise_value``: floats are rounded to 6 decimals,
    so candidates that differ only in CAST precision merge, and strings are
    stripped.

    Sorting compares ``(type name, repr)`` pairs instead of raw values. Rows
    that mix ``None`` with ``str``/``int`` therefore never raise ``TypeError``
    during the sort. Duplicate rows are kept, so two results match only when
    they are equal as multisets of rows.
    """

    def _order_key(row: tuple[Any, ...]) -> tuple[tuple[str, str], ...]:
        # Per-cell (type name, repr) gives a total order over heterogeneous values.
        return tuple((type(cell).__name__, repr(cell)) for cell in row)

    normalised = [tuple(map(_normalise_value, row)) for row in rows]
    normalised.sort(key=_order_key)
    payload = repr(normalised).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()
def _normalise_value(v: Any) -> Any:
if isinstance(v, float):
return round(v, 6)
if isinstance(v, str):
return v.strip()
return v
def vote(candidates: list[Candidate]) -> Candidate:
    """Choose the winning candidate by clustering on execution results.

    Steps:
      1. Keep only candidates accepted by ``_executed``: they have an
         execution result, and their error kind is neither INVALID_SQL nor
         EXECUTION_FAILED. EMPTY_RESULT stays eligible, because an empty
         answer can be the right answer.
      2. If nothing is eligible, return the candidate with the highest
         self-rated confidence. Ties go to the lowest temperature, so greedy
         decoding wins.
      3. Otherwise group eligible candidates by ``fingerprint_rows`` of their
         result rows. The largest group wins. Ties go to the group with the
         highest member confidence, then to the group holding the lowest
         temperature. Within the winning group, the same (confidence, lowest
         temperature) rule picks the returned candidate.

    Any remaining exact tie resolves in favour of whatever appears (or whose
    group first appears) earliest in *candidates*.

    Raises:
        ValueError: if *candidates* is empty.
    """
    if not candidates:
        raise ValueError("vote() requires at least one candidate")

    def _rank(cand: Candidate) -> tuple[float, float]:
        # Higher self-rated confidence first. On a tie, the lower temperature
        # (closer to greedy decoding) wins because its negation is larger.
        return (_confidence(cand), -cand.temperature)

    eligible = list(filter(_executed, candidates))
    if not eligible:
        # Nothing executed: fall back to the model's own self-rating.
        return max(candidates, key=_rank)

    # Group by the fingerprint of the returned rows, not by SQL text, so
    # differently written but equivalent queries land in the same group.
    groups: dict[str, list[Candidate]] = defaultdict(list)
    for cand in eligible:
        outcome = cand.result.outcome
        if outcome and outcome.result:
            rows = outcome.result.rows
        else:
            rows = []
        groups[fingerprint_rows(rows)].append(cand)

    def _group_score(members: list[Candidate]) -> tuple[int, float, float]:
        top_confidence = max(_confidence(m) for m in members)
        coolest = min(m.temperature for m in members)
        return (len(members), top_confidence, -coolest)

    winning_group = max(groups.values(), key=_group_score)
    return max(winning_group, key=_rank)
def _executed(c: Candidate) -> bool:
"""True iff the candidate produced rows we can vote on.
Treat EMPTY_RESULT as runnable: zero rows is a legitimate answer
(e.g. "list customers with no purchases"). INVALID_SQL and
EXECUTION_FAILED are not eligible.
"""
if c.result.outcome is None or c.result.outcome.result is None:
return False
kind = c.result.error_kind
return kind not in (ExecutionErrorKind.INVALID_SQL, ExecutionErrorKind.EXECUTION_FAILED)
def _confidence(c: Candidate) -> float:
"""LLM self-rating from generate_sql trace, default 0.0 if missing."""
for step in reversed(c.result.trace):
if step.get("node") in ("generate_sql", "repair_once"):
value = step.get("confidence")
if isinstance(value, int | float):
return float(value)
return 0.0
|