"""Execution-based self-consistency voting for SQL candidates. For a single question we run the LangGraph pipeline N times at distinct sampling temperatures, collect the candidates, and pick the one whose execution result has the largest agreement cluster. This is the standard NL→SQL technique from Wang et al. (2023) — clustering on the *execution result* (not the SQL string) tolerates equivalent SQL spelt differently and is robust to small surface-level diversity. """ from __future__ import annotations import hashlib from collections import defaultdict from dataclasses import dataclass from typing import Any from nl_sql.agent.graph import PipelineRunResult from nl_sql.execution.errors import ExecutionErrorKind @dataclass(frozen=True, slots=True) class Candidate: """One pipeline pass + its sampling temperature.""" result: PipelineRunResult temperature: float def fingerprint_rows(rows: list[tuple[Any, ...]]) -> str: """Order-agnostic, type-stable fingerprint of a row set. BIRD-style execution accuracy is set-based unless the gold SQL has ORDER BY, so the canonical voting key sorts rows. Floats are rounded to 6 decimals to merge candidates that differ only in CAST precision. Heterogeneous types (None mixed with str/int) are made comparable by sorting on the repr — never on the raw value. """ canon_rows = [tuple(_normalise_value(v) for v in row) for row in rows] canon = sorted(canon_rows, key=lambda r: tuple((type(v).__name__, repr(v)) for v in r)) return hashlib.sha256(repr(canon).encode("utf-8")).hexdigest() def _normalise_value(v: Any) -> Any: if isinstance(v, float): return round(v, 6) if isinstance(v, str): return v.strip() return v def vote(candidates: list[Candidate]) -> Candidate: """Pick the winning candidate by execution-result clustering. Algorithm: 1. Drop candidates whose execution failed (INVALID_SQL or EXECUTION_FAILED). EMPTY_RESULT counts as a real cluster — an empty answer can be the right answer. 2. If no candidate executed, fall back to the highest-confidence candidate (the LLM's own self-rating, breaking ties by temperature ascending so greedy wins). 3. Otherwise cluster on the row fingerprint. Pick the largest cluster; ties broken by max confidence within cluster, then by lowest temperature (greedy preferred). """ if not candidates: raise ValueError("vote() requires at least one candidate") runnable = [c for c in candidates if _executed(c)] if not runnable: return max( candidates, key=lambda c: (_confidence(c), -c.temperature), ) clusters: dict[str, list[Candidate]] = defaultdict(list) for c in runnable: rows = c.result.outcome.result.rows if c.result.outcome and c.result.outcome.result else [] clusters[fingerprint_rows(rows)].append(c) def cluster_score(key: str) -> tuple[int, float, float]: members = clusters[key] return ( len(members), max(_confidence(m) for m in members), -min(m.temperature for m in members), ) best_key = max(clusters, key=cluster_score) return max( clusters[best_key], key=lambda c: (_confidence(c), -c.temperature), ) def _executed(c: Candidate) -> bool: """True iff the candidate produced rows we can vote on. Treat EMPTY_RESULT as runnable: zero rows is a legitimate answer (e.g. "list customers with no purchases"). INVALID_SQL and EXECUTION_FAILED are not eligible. """ if c.result.outcome is None or c.result.outcome.result is None: return False kind = c.result.error_kind return kind not in (ExecutionErrorKind.INVALID_SQL, ExecutionErrorKind.EXECUTION_FAILED) def _confidence(c: Candidate) -> float: """LLM self-rating from generate_sql trace, default 0.0 if missing.""" for step in reversed(c.result.trace): if step.get("node") in ("generate_sql", "repair_once"): value = step.get("confidence") if isinstance(value, int | float): return float(value) return 0.0