File size: 4,296 Bytes
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Deterministic format / chart-type picker.

Pure Python heuristics. No LLM call. Replaces the v1 plan that asked the LLM
to emit Vega-Lite specs (high failure rate per CX/KM review).

Decision tree (in order, first match wins):

1. Empty result -> Sentence("the query returned no rows")
2. 1 row, 1 column -> Scalar
3. <= 200 rows, >= 2 columns, first col temporal -> LineChart over the numeric
   Y cols (falls back to Table when no other column is numeric)
4. 2 columns, <= 12 rows, col 0 categorical and col 1 numeric -> BarChart
   (or PieChart if <= 6 rows and it looks like a share-of-total breakdown)
5. 2 numeric columns, >= 6 rows -> ScatterChart
6. Otherwise -> Table
"""

from __future__ import annotations

import datetime as dt
from collections.abc import Sequence
from typing import Any

from nl_sql.render.formats import (
    BarChart,
    LineChart,
    OutputFormat,
    PieChart,
    Scalar,
    ScatterChart,
    Sentence,
    Table,
)

# Row-count thresholds consulted by pick_format when choosing a chart type.

# A temporal-first result is only considered for a line chart at or below this many rows.
_MAX_LINE_ROWS = 200
# Upper row bound for the two-column categorical/numeric (bar or pie) branch.
_MAX_BAR_ROWS = 12
# Within that branch, a pie chart is only considered at or below this many rows.
_MAX_PIE_ROWS = 6
# Two numeric columns need at least this many rows to become a scatter chart.
_MIN_SCATTER_ROWS = 6


def pick_format(
    columns: Sequence[str],
    rows: Sequence[Sequence[Any]],
) -> OutputFormat:
    """Choose the output format from the shape of an executed query result.

    The rules are tried in order and the first match wins:

    1. No rows -> ``Sentence``.
    2. Exactly one row and one column -> ``Scalar``.
    3. At least two columns, a temporal first column and at most
       ``_MAX_LINE_ROWS`` rows -> ``LineChart`` over the numeric remaining
       columns, or ``Table`` when none of them is numeric.
    4. Two columns (categorical, numeric) and at most ``_MAX_BAR_ROWS`` rows
       -> ``PieChart`` when there are at most ``_MAX_PIE_ROWS`` rows and the
       values look like a share-of-total breakdown, else ``BarChart``.
    5. Two numeric columns and at least ``_MIN_SCATTER_ROWS`` rows
       -> ``ScatterChart``.
    6. Anything else -> ``Table``.

    Args:
        columns: Column names, in result order.
        rows: Result rows; each row is indexed by column position.

    Returns:
        The selected format object, built from list copies of ``columns``
        and ``rows``.
    """
    cols = list(columns)
    data = [list(r) for r in rows]
    n_rows = len(data)
    n_cols = len(cols)

    if n_rows == 0:
        return Sentence(text="the query returned no rows")

    if n_rows == 1 and n_cols == 1:
        return Scalar(value=data[0][0], column=cols[0])

    if n_cols >= 2 and _is_temporal_column(data, 0) and n_rows <= _MAX_LINE_ROWS:
        # Test each column by position, not by name lookup: result sets can
        # carry duplicate column names, and cols.index(name) would always
        # resolve to the first one and check the wrong column.
        y_fields = [
            name
            for position, name in enumerate(cols[1:], start=1)
            if _is_numeric_column(data, position)
        ]
        if y_fields:
            return LineChart(columns=cols, rows=data, x_field=cols[0], y_fields=y_fields)
        # A time axis with nothing numeric to plot against it: show the raw table.
        return Table(columns=cols, rows=data)

    if (
        n_cols == 2
        and n_rows <= _MAX_BAR_ROWS
        and _is_categorical_column(data, 0)
        and _is_numeric_column(data, 1)
    ):
        if n_rows <= _MAX_PIE_ROWS and _looks_like_share(data):
            return PieChart(columns=cols, rows=data, x_field=cols[0], y_fields=[cols[1]])
        return BarChart(columns=cols, rows=data, x_field=cols[0], y_fields=[cols[1]])

    if (
        n_cols == 2
        and n_rows >= _MIN_SCATTER_ROWS
        and _is_numeric_column(data, 0)
        and _is_numeric_column(data, 1)
    ):
        return ScatterChart(columns=cols, rows=data, x_field=cols[0], y_fields=[cols[1]])

    return Table(columns=cols, rows=data)


def _is_temporal_column(rows: Sequence[Sequence[Any]], idx: int) -> bool:
    """Report whether column ``idx`` holds dates or ISO-looking date strings.

    Only the first 10 non-null values of the column are inspected. The column
    qualifies when every inspected value is a ``date``/``datetime``, or when
    every inspected value is a string accepted by ``_looks_like_iso_date``.
    A column with no rows or only nulls does not qualify.
    """
    if not rows:
        return False

    non_null = [row[idx] for row in rows if row[idx] is not None]
    probe = non_null[:10]
    if not probe:
        return False

    temporal_types = (dt.date, dt.datetime)
    if all(isinstance(value, temporal_types) for value in probe):
        return True

    # Otherwise every probed value must be a date-like string.
    for value in probe:
        if not isinstance(value, str) or not _looks_like_iso_date(value):
            return False
    return True


def _looks_like_iso_date(s: str) -> bool:
    """Report whether the first 10 characters of ``s`` parse as an ISO date.

    Strings shorter than 7 characters are rejected without parsing. Anything
    after the tenth character (for example a time component) is ignored.
    """
    if len(s) >= 7:
        try:
            dt.date.fromisoformat(s[:10])
        except ValueError:
            pass
        else:
            return True
    return False


def _is_numeric_column(rows: Sequence[Sequence[Any]], idx: int) -> bool:
    """Report whether column ``idx`` holds only ``int``/``float`` values.

    Only the first 20 non-null values of the column are inspected. Booleans
    are not counted as numbers even though ``bool`` subclasses ``int``. A
    column with no rows or only nulls is not numeric.
    """
    if not rows:
        return False

    present = [row[idx] for row in rows if row[idx] is not None]
    probe = present[:20]
    if not probe:
        return False

    for value in probe:
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return False
    return True


def _is_categorical_column(rows: Sequence[Sequence[Any]], idx: int) -> bool:
    """Report whether column ``idx`` holds only string values.

    Only the first 20 non-null values of the column are inspected. A column
    with no rows or only nulls is not categorical.

    Args:
        rows: Result rows; each row is indexed by column position.
        idx: Position of the column to inspect.

    Returns:
        ``True`` when at least one non-null value was sampled and every
        sampled value is a ``str``.
    """
    if not rows:
        return False
    sample = [row[idx] for row in rows if row[idx] is not None][:20]
    # No separate numeric check is needed: a value cannot be both a str and
    # an int/float, so a non-empty all-str sample is never a numeric column.
    # bool(sample) guards against all() being vacuously true on an empty sample.
    return bool(sample) and all(isinstance(v, str) for v in sample)


def _looks_like_share(rows: Sequence[Sequence[Any]]) -> bool:
    """Heuristically decide whether column 1 reads as a share-of-total split.

    Only ``int``/``float`` cells (booleans excluded) in the second column are
    considered. The result is ``True`` when there are at least two such
    values, none is negative, their sum is positive, and the largest one is
    strictly less than 80% of the sum.
    """
    amounts = []
    for row in rows:
        cell = row[1]
        if isinstance(cell, bool) or not isinstance(cell, (int, float)):
            continue
        amounts.append(cell)

    if len(amounts) < 2:
        return False
    if any(amount < 0 for amount in amounts):
        return False

    total = sum(amounts)
    if total <= 0:
        return False

    largest_share = max(amounts) / total
    return largest_share < 0.80