"""Deterministic format / chart-type picker.

Pure Python heuristics; no LLM call. Replaces the v1 plan that asked the LLM
to emit Vega-Lite specs (high failure rate per CX/KM review).

Decision tree (in order, first match wins):

1. Empty result -> Sentence("the query returned no rows")
2. 1 row, 1 column -> Scalar
3. <= 200 rows, >= 2 columns, first col temporal -> LineChart (with numeric
   Y cols); falls back to Table when none of the other columns is numeric
4. 2 columns, <= 12 rows, col 0 categorical and col 1 numeric -> BarChart
   (or PieChart if there are <= 6 rows and it looks like a share-of-total
   breakdown)
5. 2 numeric columns, >= 6 rows -> ScatterChart
6. Otherwise -> Table
"""
|
|
| from __future__ import annotations |
|
|
| import datetime as dt |
| from collections.abc import Sequence |
| from typing import Any |
|
|
| from nl_sql.render.formats import ( |
| BarChart, |
| LineChart, |
| OutputFormat, |
| PieChart, |
| Scalar, |
| ScatterChart, |
| Sentence, |
| Table, |
| ) |
|
|
# Row-count thresholds consulted by pick_format's decision tree.
_MAX_LINE_ROWS = 200  # a line chart is only chosen at or below this many rows
_MAX_BAR_ROWS = 12  # a bar (or pie) chart is only chosen at or below this many rows
_MAX_PIE_ROWS = 6  # a pie chart is only considered at or below this many rows
_MIN_SCATTER_ROWS = 6  # a scatter chart needs at least this many rows
|
|
|
|
def pick_format(
    columns: Sequence[str],
    rows: Sequence[Sequence[Any]],
) -> OutputFormat:
    """Choose the output format from the shape of an executed query result.

    Rules are tried in order and the first match wins:

    1. No rows -> ``Sentence``.
    2. One row and one column -> ``Scalar``.
    3. At least two columns, a temporal first column and at most
       ``_MAX_LINE_ROWS`` rows -> ``LineChart`` over the numeric remaining
       columns, or ``Table`` when none of them is numeric.
    4. Exactly two columns (categorical, numeric) and at most
       ``_MAX_BAR_ROWS`` rows -> ``PieChart`` when there are at most
       ``_MAX_PIE_ROWS`` rows and ``_looks_like_share`` agrees, otherwise
       ``BarChart``.
    5. Exactly two numeric columns and at least ``_MIN_SCATTER_ROWS`` rows
       -> ``ScatterChart``.
    6. Anything else -> ``Table``.

    Args:
        columns: Column names, in result order.
        rows: Result rows; each row is indexed positionally against
            ``columns``.

    Returns:
        The selected format object, built from list copies of the inputs.
    """
    cols = list(columns)
    data = [list(r) for r in rows]
    n_rows = len(data)
    n_cols = len(cols)

    if n_rows == 0:
        return Sentence(text="the query returned no rows")

    if n_rows == 1 and n_cols == 1:
        return Scalar(value=data[0][0], column=cols[0])

    if n_cols >= 2 and _is_temporal_column(data, 0) and n_rows <= _MAX_LINE_ROWS:
        # Walk positions rather than looking names up with cols.index():
        # index() returns the first occurrence, so duplicate column names
        # (e.g. from a join) would test the wrong column.
        y_fields = [
            name
            for pos, name in enumerate(cols[1:], start=1)
            if _is_numeric_column(data, pos)
        ]
        if not y_fields:
            # Nothing numeric to plot against the time axis.
            return Table(columns=cols, rows=data)
        return LineChart(columns=cols, rows=data, x_field=cols[0], y_fields=y_fields)

    if (
        n_cols == 2
        and n_rows <= _MAX_BAR_ROWS
        and _is_categorical_column(data, 0)
        and _is_numeric_column(data, 1)
    ):
        if n_rows <= _MAX_PIE_ROWS and _looks_like_share(data):
            return PieChart(columns=cols, rows=data, x_field=cols[0], y_fields=[cols[1]])
        return BarChart(columns=cols, rows=data, x_field=cols[0], y_fields=[cols[1]])

    if (
        n_cols == 2
        and n_rows >= _MIN_SCATTER_ROWS
        and _is_numeric_column(data, 0)
        and _is_numeric_column(data, 1)
    ):
        return ScatterChart(columns=cols, rows=data, x_field=cols[0], y_fields=[cols[1]])

    return Table(columns=cols, rows=data)
|
|
|
|
def _is_temporal_column(rows: Sequence[Sequence[Any]], idx: int) -> bool:
    """Report whether column ``idx`` holds dates, judged from a small sample.

    Only the first 10 non-``None`` values of the column are inspected. The
    column counts as temporal when every sampled value is a ``date`` /
    ``datetime`` object, or when every sampled value is a string accepted by
    ``_looks_like_iso_date``. Empty input and all-``None`` columns are not
    temporal.
    """
    if not rows:
        return False

    probe = [cell for cell in (record[idx] for record in rows) if cell is not None][:10]
    if not probe:
        return False

    only_date_objects = all(isinstance(cell, (dt.date, dt.datetime)) for cell in probe)
    if only_date_objects:
        return True

    # Otherwise every sampled value must be an ISO-looking string.
    for cell in probe:
        if not isinstance(cell, str) or not _looks_like_iso_date(cell):
            return False
    return True
|
|
|
|
def _looks_like_iso_date(s: str) -> bool:
    """Return True when ``s`` begins with a parseable ISO calendar date.

    Only the first 10 characters are handed to ``date.fromisoformat``, so a
    timestamp such as ``"2024-01-01T12:00:00"`` qualifies through its date
    prefix. Strings shorter than 7 characters are rejected without parsing.
    """
    # NOTE(review): the 7-character floor suggests "YYYY-MM" was meant to
    # pass, but date.fromisoformat does not accept reduced-precision dates,
    # so such strings still come back False -- confirm the intent.
    parseable = len(s) >= 7
    if parseable:
        try:
            dt.date.fromisoformat(s[:10])
        except ValueError:
            parseable = False
    return parseable
|
|
|
|
def _is_numeric_column(rows: Sequence[Sequence[Any]], idx: int) -> bool:
    """Report whether column ``idx`` is numeric, judged from a small sample.

    Only the first 20 non-``None`` values of the column are inspected. Every
    sampled value must be an ``int`` or ``float``; ``bool`` is excluded even
    though it subclasses ``int``. Empty input and all-``None`` columns are
    not numeric.
    """
    # NOTE(review): decimal.Decimal values are not counted as numeric here;
    # confirm that is intended if the DB driver returns them.
    if not rows:
        return False

    probe = [cell for cell in (record[idx] for record in rows) if cell is not None][:20]
    if not probe:
        return False

    for cell in probe:
        if isinstance(cell, bool) or not isinstance(cell, (int, float)):
            return False
    return True
|
|
|
|
def _is_categorical_column(rows: Sequence[Sequence[Any]], idx: int) -> bool:
    """Report whether column ``idx`` is categorical, judged from a small sample.

    Only the first 20 non-``None`` values of the column are inspected, and
    every one of them must be a ``str``. Empty input and all-``None`` columns
    are not categorical.

    No separate "is it numeric?" guard is needed: a non-empty sample made up
    entirely of strings can never consist entirely of ints/floats.
    """
    if not rows:
        return False
    sample = [row[idx] for row in rows if row[idx] is not None][:20]
    return bool(sample) and all(isinstance(v, str) for v in sample)
|
|
|
|
def _looks_like_share(rows: Sequence[Sequence[Any]]) -> bool:
    """Guess whether the second column is a share-of-total breakdown.

    Only ``int`` / ``float`` values (``bool`` excluded) from column 1 are
    considered. The answer is True when there are at least two such values,
    none is negative, their sum is positive, and the largest one is below
    80% of that sum.
    """
    amounts = []
    for record in rows:
        cell = record[1]
        if isinstance(cell, bool) or not isinstance(cell, (int, float)):
            continue
        amounts.append(cell)

    if len(amounts) < 2:
        return False
    if any(amount < 0 for amount in amounts):
        return False

    grand_total = sum(amounts)
    # A single dominant slice (>= 80%) does not read as a breakdown.
    return grand_total > 0 and max(amounts) / grand_total < 0.80
|
|