nl-sql / src /nl_sql /render /picker.py
liovina's picture
Deploy NL_SQL HEAD to HF Space
942050b verified
"""Deterministic format / chart-type picker.
Pure Python heuristics. No LLM call. Replaces the v1 plan that asked the LLM
to emit Vega-Lite specs (high failure rate per CX/KM review).
Decision tree (in order, first match wins):
1. Empty result -> Sentence("the query returned no rows")
2. 1 row, 1 column -> Scalar
3. <= 200 rows, >= 2 columns, first col temporal -> LineChart over the numeric Y cols (Table if no other column is numeric)
4. 2 columns, <= 12 rows, col 0 categorical and col 1 numeric -> BarChart
(or PieChart if it has <= 6 rows and looks like a share-of-total breakdown)
5. 2 numeric columns, >= 6 rows -> ScatterChart
6. Otherwise -> Table
"""
from __future__ import annotations
import datetime as dt
from collections.abc import Sequence
from typing import Any
from nl_sql.render.formats import (
BarChart,
LineChart,
OutputFormat,
PieChart,
Scalar,
ScatterChart,
Sentence,
Table,
)
# Row-count thresholds consulted by pick_format.
# Largest result (in rows) still considered for the temporal line-chart rule.
_MAX_LINE_ROWS = 200
# Largest two-column categorical/numeric result still considered for a bar chart.
_MAX_BAR_ROWS = 12
# Largest bar-chart candidate that may be rendered as a pie chart instead.
_MAX_PIE_ROWS = 6
# Smallest two-column numeric result considered for a scatter chart.
_MIN_SCATTER_ROWS = 6
def pick_format(
    columns: Sequence[str],
    rows: Sequence[Sequence[Any]],
) -> OutputFormat:
    """Choose the right output format from the executed query result shape.

    Rules are tried in order and the first match wins:

    1. No rows -> ``Sentence``.
    2. Exactly one row and one column -> ``Scalar``.
    3. At least two columns, a temporal first column and at most
       ``_MAX_LINE_ROWS`` rows -> ``LineChart`` over every numeric column
       after the first, or ``Table`` when none of them is numeric.
    4. Two columns (categorical, numeric) and at most ``_MAX_BAR_ROWS`` rows
       -> ``PieChart`` when there are at most ``_MAX_PIE_ROWS`` rows and the
       values look like a share-of-total breakdown, otherwise ``BarChart``.
    5. Two numeric columns and at least ``_MIN_SCATTER_ROWS`` rows
       -> ``ScatterChart``.
    6. Anything else -> ``Table``.

    Args:
        columns: Column names of the result, in result order.
        rows: Result rows; each row is indexed by column position.

    Returns:
        The format object selected by the first matching rule.
    """
    cols = list(columns)
    data = [list(r) for r in rows]
    n_rows = len(data)
    n_cols = len(cols)

    if n_rows == 0:
        return Sentence(text="the query returned no rows")

    if n_rows == 1 and n_cols == 1:
        return Scalar(value=data[0][0], column=cols[0])

    if n_cols >= 2 and _is_temporal_column(data, 0) and n_rows <= _MAX_LINE_ROWS:
        # Walk the Y candidates by position rather than looking each name up
        # with cols.index(): result sets may repeat a column name, and
        # index() always resolves to the first occurrence, which would probe
        # the wrong column (possibly the temporal column 0 itself).
        y_fields = [
            name
            for position, name in enumerate(cols[1:], start=1)
            if _is_numeric_column(data, position)
        ]
        if y_fields:
            return LineChart(
                columns=cols,
                rows=data,
                x_field=cols[0],
                y_fields=y_fields,
            )
        # A temporal axis with nothing numeric to plot stays tabular.
        return Table(columns=cols, rows=data)

    if (
        n_cols == 2
        and n_rows <= _MAX_BAR_ROWS
        and _is_categorical_column(data, 0)
        and _is_numeric_column(data, 1)
    ):
        if n_rows <= _MAX_PIE_ROWS and _looks_like_share(data):
            return PieChart(columns=cols, rows=data, x_field=cols[0], y_fields=[cols[1]])
        return BarChart(columns=cols, rows=data, x_field=cols[0], y_fields=[cols[1]])

    if (
        n_cols == 2
        and n_rows >= _MIN_SCATTER_ROWS
        and _is_numeric_column(data, 0)
        and _is_numeric_column(data, 1)
    ):
        return ScatterChart(columns=cols, rows=data, x_field=cols[0], y_fields=[cols[1]])

    return Table(columns=cols, rows=data)
def _is_temporal_column(rows: Sequence[Sequence[Any]], idx: int) -> bool:
    """Report whether column ``idx`` looks like a date/time axis.

    Only the first 10 non-None values of the column are inspected. The column
    qualifies when all of them are ``date``/``datetime`` objects, or when all
    of them are strings accepted by ``_looks_like_iso_date``. An empty result
    or an all-None column does not qualify.
    """
    if not rows:
        return False

    non_null = []
    for row in rows:
        cell = row[idx]
        if cell is not None:
            non_null.append(cell)
    probe = non_null[:10]
    if not probe:
        return False

    date_types = (dt.date, dt.datetime)
    if all(isinstance(cell, date_types) for cell in probe):
        return True

    # Otherwise every probed value must be an ISO-date-looking string.
    for cell in probe:
        if not isinstance(cell, str):
            return False
        if not _looks_like_iso_date(cell):
            return False
    return True
def _looks_like_iso_date(s: str) -> bool:
    """Report whether the first 10 characters of ``s`` parse as an ISO date.

    Strings shorter than 7 characters are rejected without parsing. Anything
    after the tenth character (for example a time-of-day suffix) is ignored.
    """
    if len(s) >= 7:
        date_part = s[:10]
        try:
            dt.date.fromisoformat(date_part)
        except ValueError:
            pass
        else:
            return True
    return False
def _is_numeric_column(rows: Sequence[Sequence[Any]], idx: int) -> bool:
    """Report whether column ``idx`` holds plain numbers.

    Only the first 20 non-None values of the column are inspected; each must
    be an ``int`` or ``float`` and must not be a ``bool``. An empty result or
    an all-None column is not numeric.

    NOTE(review): other number types such as ``decimal.Decimal`` fail the
    ``int``/``float`` check, so columns of them are reported as non-numeric.
    """
    if not rows:
        return False

    head = [cell for cell in (row[idx] for row in rows) if cell is not None][:20]
    if not head:
        return False

    for cell in head:
        # bool is a subclass of int, so it has to be ruled out explicitly.
        if isinstance(cell, bool) or not isinstance(cell, (int, float)):
            return False
    return True
def _is_categorical_column(rows: Sequence[Sequence[Any]], idx: int) -> bool:
    """Report whether column ``idx`` holds string labels.

    Only the first 20 non-None values of the column are inspected; the column
    qualifies when it is not numeric (per ``_is_numeric_column``) and every
    inspected value is a ``str``. An empty result or an all-None column does
    not qualify.
    """
    non_null = [row[idx] for row in rows if row[idx] is not None][:20] if rows else []
    if not non_null or _is_numeric_column(rows, idx):
        return False

    for cell in non_null:
        if not isinstance(cell, str):
            return False
    return True
def _looks_like_share(rows: Sequence[Sequence[Any]]) -> bool:
    """Heuristic check for a share-of-total breakdown in column 1.

    Non-numeric cells (including bools and None) are ignored. The remaining
    values qualify when there are at least two of them, none is negative,
    their sum is positive, and the largest one is under 80% of that sum.
    """
    amounts = []
    for row in rows:
        cell = row[1]
        # bool is a subclass of int; it does not count as a measure here.
        if isinstance(cell, bool) or not isinstance(cell, (int, float)):
            continue
        amounts.append(cell)

    if len(amounts) < 2:
        return False
    if any(amount < 0 for amount in amounts):
        return False

    grand_total = sum(amounts)
    if grand_total <= 0:
        return False

    largest_share = max(amounts) / grand_total
    return largest_share < 0.80