File size: 13,559 Bytes
9439512
f3fd40f
9439512
 
 
f3fd40f
9439512
 
 
 
 
 
 
 
 
 
 
cb6c215
 
 
 
 
962831e
cb6c215
 
 
962831e
cb6c215
 
 
f3fd40f
cb6c215
 
 
 
 
f3fd40f
cb6c215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9439512
cb6c215
 
 
 
 
9439512
 
 
 
 
 
cb6c215
 
 
 
9439512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3fd40f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9439512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb6c215
9439512
 
cb6c215
 
 
9439512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb6c215
9439512
 
cb6c215
 
 
9439512
 
 
 
 
 
 
 
 
 
 
 
 
cb6c215
9439512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb6c215
962831e
 
 
 
cb6c215
962831e
 
cb6c215
9439512
 
 
 
 
 
 
 
 
f3fd40f
 
9439512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
import ast
import difflib
import json
import logging
import os
import re
import time

from dotenv import load_dotenv

from chart_generator import ChartGenerator
from data_processor import DataProcessor

load_dotenv()  # pull API keys (GEMINI_API_KEY, GROK_API_KEY, ...) from a local .env into os.environ

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Model IDs (downloaded at Docker build, cached in HF_HOME)
# ---------------------------------------------------------------------------
QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-Coder-0.5B-Instruct")  # primary local backend
BART_MODEL_ID = os.getenv("BART_MODEL_ID", "ArchCoder/fine-tuned-bart-large")   # lightweight Seq2Seq fallback

# ---------------------------------------------------------------------------
# Prompt templates with few-shot examples
# ---------------------------------------------------------------------------

_SYSTEM_PROMPT = """\
You are a data visualization expert. Given the user request and dataset schema, \
output ONLY a valid JSON object. No explanation, no markdown fences, no extra text.

Required JSON keys:
  "x"          : string  β€” exact column name for the x-axis
  "y"          : array   β€” one or more exact column names for the y-axis
  "chart_type" : string  β€” one of: line, bar, scatter, pie, histogram, box, area
  "color"      : string or null β€” optional CSS color like "red", "#4f8cff"

Rules:
- Use ONLY column names from the schema. Never invent names.
- For pie charts: y must contain exactly one column.
- For histogram/box: x may equal the first element of y.
- Default to "line" if chart type is ambiguous.

### Examples

Example 1:
Schema: Year (integer), Sales (float), Profit (float)
User: "plot sales over the years with a red line"
Output: {"x": "Year", "y": ["Sales"], "chart_type": "line", "color": "red"}

Example 2:
Schema: Month (string), Revenue (float), Expenses (float)
User: "bar chart comparing revenue and expenses by month"
Output: {"x": "Month", "y": ["Revenue", "Expenses"], "chart_type": "bar", "color": null}

Example 3:
Schema: Category (string), Count (integer)
User: "pie chart of count by category"
Output: {"x": "Category", "y": ["Count"], "chart_type": "pie", "color": null}

Example 4:
Schema: Date (string), Temperature (float), Humidity (float)
User: "scatter plot of temperature vs humidity in blue"
Output: {"x": "Temperature", "y": ["Humidity"], "chart_type": "scatter", "color": "blue"}

Example 5:
Schema: Year (integer), Sales (float), Employee expense (float), Marketing expense (float)
User: "show me an area chart of sales and marketing expense over years"
Output: {"x": "Year", "y": ["Sales", "Marketing expense"], "chart_type": "area", "color": null}
"""


def _user_message(query: str, columns: list, dtypes: dict, sample_rows: list) -> str:
    schema = "\n".join(f"  - {c} ({dtypes.get(c, 'unknown')})" for c in columns)
    samples = "".join(f"  {json.dumps(r)}\n" for r in sample_rows[:3])
    return (
        f"Schema:\n{schema}\n\n"
        f"Sample rows:\n{samples}\n"
        f"User: \"{query}\"\n"
        f"Output:"
    )


# ---------------------------------------------------------------------------
# Output parsing & validation
# ---------------------------------------------------------------------------

def _parse_output(text: str):
    text = text.strip()
    if "```" in text:
        for part in text.split("```"):
            part = part.strip().lstrip("json").strip()
            if part.startswith("{"):
                text = part
                break
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(text)
    except (SyntaxError, ValueError):
        pass
    return None


def _validate(args: dict, columns: list):
    if not isinstance(args, dict):
        return None
    if not all(k in args for k in ("x", "y", "chart_type")):
        return None
    if isinstance(args["y"], str):
        args["y"] = [args["y"]]
    valid = {"line", "bar", "scatter", "pie", "histogram", "box", "area"}
    if args["chart_type"] not in valid:
        args["chart_type"] = "line"
    if args["x"] not in columns:
        return None
    if not all(c in columns for c in args["y"]):
        return None
    return args


def _pick_chart_type(query: str) -> str:
    lowered = query.lower()
    aliases = {
        "scatter": ["scatter", "scatterplot"],
        "bar": ["bar", "column"],
        "pie": ["pie", "donut"],
        "histogram": ["histogram", "distribution"],
        "box": ["box", "boxplot"],
        "area": ["area"],
        "line": ["line", "trend", "over time", "over the years"],
    }
    for chart_type, keywords in aliases.items():
        if any(keyword in lowered for keyword in keywords):
            return chart_type
    return "line"


def _pick_color(query: str):
    lowered = query.lower()
    colors = [
        "red", "blue", "green", "yellow", "orange", "purple", "pink",
        "black", "white", "gray", "grey", "cyan", "teal", "indigo",
    ]
    for color in colors:
        if re.search(rf"\b{re.escape(color)}\b", lowered):
            return color
    return None


def _pick_columns(query: str, columns: list, dtypes: dict):
    lowered = query.lower()
    query_tokens = re.findall(r"[a-zA-Z0-9_]+", lowered)

    def score_column(column: str) -> float:
        col_lower = column.lower()
        score = 0.0
        if col_lower in lowered:
            score += 10.0
        for token in query_tokens:
            if token and token in col_lower:
                score += 2.0
        score += difflib.SequenceMatcher(None, lowered, col_lower).ratio()
        return score

    sorted_columns = sorted(columns, key=score_column, reverse=True)
    numeric_columns = [col for col in columns if dtypes.get(col) in {"integer", "float"}]
    temporal_columns = [col for col in columns if dtypes.get(col) == "datetime"]
    year_like = [col for col in columns if "year" in col.lower() or "date" in col.lower() or "month" in col.lower()]

    x_col = None
    for candidate in year_like + temporal_columns + sorted_columns:
        if candidate in columns:
            x_col = candidate
            break
    if x_col is None and columns:
        x_col = columns[0]

    y_candidates = [col for col in sorted_columns if col != x_col and col in numeric_columns]
    if not y_candidates:
        y_candidates = [col for col in numeric_columns if col != x_col]
    if not y_candidates:
        y_candidates = [col for col in columns if col != x_col]

    return x_col, y_candidates[:1]


def _heuristic_plot_args(query: str, columns: list, dtypes: dict) -> dict:
    """Build plot args from keyword heuristics when the LLM output is unusable.

    Always returns a dict with "x", "y" (a list, empty only when the dataset
    has no columns at all), "chart_type", and "color".
    """
    x_col, y_cols = _pick_columns(query, columns, dtypes)
    if not x_col:
        # Last-resort default; matches default_args in process_request.
        x_col = "Year"
    if not y_cols:
        # Prefer any column other than x; otherwise whatever exists.
        # (The old code carried a dead isinstance(..., tuple) branch — the
        # fallback was only ever a str or a list.)
        fallback = next((col for col in columns if col != x_col), None)
        y_cols = [fallback] if fallback is not None else columns[:1]
    return {
        "x": x_col,
        "y": y_cols,
        "chart_type": _pick_chart_type(query),
        "color": _pick_color(query),
    }


# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------

class LLM_Agent:
    """Turn a natural-language query over tabular data into a rendered chart.

    Plot arguments are extracted by one of several LLM backends (local Qwen,
    local BART, remote Gemini or Grok), validated against the dataset schema,
    and fall back to keyword heuristics when the model output is unusable.
    """

    def __init__(self, data_path=None):
        logger.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        # Local models are loaded lazily on first use (see _run_qwen/_run_bart)
        # so agent construction stays cheap when a remote backend is chosen.
        self._bart_tokenizer = None
        self._bart_model = None
        self._qwen_tokenizer = None
        self._qwen_model = None

    # -- model runners -------------------------------------------------------

    def _run_qwen(self, user_msg: str) -> str:
        """Qwen2.5-Coder-0.5B-Instruct — fast structured-JSON generation."""
        if self._qwen_model is None:
            # Deferred import: transformers is heavy and may never be needed.
            from transformers import AutoModelForCausalLM, AutoTokenizer
            logger.info(f"Loading Qwen model: {QWEN_MODEL_ID}")
            self._qwen_tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID)
            self._qwen_model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID)
            logger.info("Qwen model loaded.")
        messages = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user",   "content": user_msg},
        ]
        text = self._qwen_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self._qwen_tokenizer(text, return_tensors="pt")
        outputs = self._qwen_model.generate(
            **inputs, max_new_tokens=256, temperature=0.1, do_sample=True
        )
        # Decode only the newly generated tokens, skipping the prompt echo.
        return self._qwen_tokenizer.decode(
            outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
        )

    def _run_gemini(self, user_msg: str) -> str:
        """Call Gemini 2.0 Flash; raises ValueError when GEMINI_API_KEY is unset."""
        import google.generativeai as genai
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY is not set")
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(
            "gemini-2.0-flash",
            system_instruction=_SYSTEM_PROMPT,
        )
        return model.generate_content(user_msg).text

    def _run_grok(self, user_msg: str) -> str:
        """Call Grok via the OpenAI-compatible x.ai endpoint; raises ValueError
        when GROK_API_KEY is unset."""
        from openai import OpenAI
        api_key = os.getenv("GROK_API_KEY")
        if not api_key:
            raise ValueError("GROK_API_KEY is not set")
        client = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
        resp = client.chat.completions.create(
            model="grok-3-mini",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user",   "content": user_msg},
            ],
            max_tokens=256,
            temperature=0.1,
        )
        return resp.choices[0].message.content

    def _run_bart(self, query: str) -> str:
        """ArchCoder/fine-tuned-bart-large — lightweight Seq2Seq fallback."""
        if self._bart_model is None:
            # Deferred import/load, same rationale as _run_qwen.
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
            logger.info(f"Loading BART model: {BART_MODEL_ID}")
            self._bart_tokenizer = AutoTokenizer.from_pretrained(BART_MODEL_ID)
            self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(BART_MODEL_ID)
            logger.info("BART model loaded.")
        # NOTE: BART receives only the raw query, not the schema-augmented
        # message — presumably the fine-tune was trained on bare queries.
        inputs = self._bart_tokenizer(
            query, return_tensors="pt", max_length=512, truncation=True
        )
        outputs = self._bart_model.generate(**inputs, max_length=100)
        return self._bart_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # -- main entry point ----------------------------------------------------

    def process_request(self, data: dict) -> dict:
        """Handle one chart request.

        Expected keys in ``data``: "query" (natural-language request),
        optional "file_path" (dataset to load), optional "model"
        ("qwen" default, "gemini", "grok", or "bart").

        Returns a dict with "response", "chart_path", "chart_spec",
        "verified" (False only when chart generation failed), and
        "plot_args". Never raises: LLM failures fall back to heuristics,
        and chart failures are reported in the returned dict.
        """
        t0        = time.time()
        query     = data.get("query", "")
        data_path = data.get("file_path")
        model     = data.get("model", "qwen")

        # Reload the dataset when the caller supplies a new file.
        if data_path and os.path.exists(data_path):
            self.data_processor  = DataProcessor(data_path)
            self.chart_generator = ChartGenerator(self.data_processor.data)

        columns     = self.data_processor.get_columns()
        dtypes      = self.data_processor.get_dtypes()
        sample_rows = self.data_processor.preview(3)

        # Absolute last-resort args if even the heuristics fail validation.
        default_args = {
            "x":          columns[0] if columns else "Year",
            "y":          [columns[1]] if len(columns) > 1 else ["Sales"],
            "chart_type": "line",
        }

        raw_text  = ""
        plot_args = None
        try:
            user_msg = _user_message(query, columns, dtypes, sample_rows)
            if   model == "gemini": raw_text = self._run_gemini(user_msg)
            elif model == "grok":   raw_text = self._run_grok(user_msg)
            elif model == "bart":   raw_text = self._run_bart(query)
            elif model == "qwen":
                # Qwen is the default; BART is its local fallback.
                try:
                    raw_text = self._run_qwen(user_msg)
                except Exception as qwen_exc:
                    logger.warning(f"Qwen failed, falling back to BART: {qwen_exc}")
                    raw_text = self._run_bart(query)
            else:
                # Unknown model names silently route to Qwen.
                raw_text = self._run_qwen(user_msg)

            logger.info(f"LLM [{model}] output: {raw_text}")
            parsed    = _parse_output(raw_text)
            plot_args = _validate(parsed, columns) if parsed else None
        except Exception as exc:
            logger.error(f"LLM error [{model}]: {exc}")
            raw_text = str(exc)

        # Any LLM failure (exception, unparseable, or invalid columns) lands here.
        if not plot_args:
            logger.warning("Falling back to heuristic plot args")
            plot_args = _validate(_heuristic_plot_args(query, columns, dtypes), columns) or default_args

        try:
            chart_result = self.chart_generator.generate_chart(plot_args)
            chart_path   = chart_result["chart_path"]
            chart_spec   = chart_result["chart_spec"]
        except Exception as exc:
            logger.error(f"Chart generation error: {exc}")
            return {
                "response":   f"Chart generation failed: {exc}",
                "chart_path": "",
                "chart_spec": None,
                "verified":   False,
                "plot_args":  plot_args,
            }

        logger.info(f"Request processed in {time.time() - t0:.2f}s")
        return {
            "response":   json.dumps(plot_args),
            "chart_path": chart_path,
            "chart_spec": chart_spec,
            "verified":   True,
            "plot_args":  plot_args,
        }