| | import ast |
| | import difflib |
| | import json |
| | import logging |
| | import os |
| | import re |
| | import time |
| |
|
| | from dotenv import load_dotenv |
| |
|
| | from chart_generator import ChartGenerator |
| | from data_processor import DataProcessor |
| |
|
# Load environment variables (API keys, model-ID overrides) from a local .env file.
load_dotenv()

# Module-level logger; handler/level configuration is left to the application entry point.
logger = logging.getLogger(__name__)

# Hugging Face model identifiers for the local backends, overridable via environment.
QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-Coder-0.5B-Instruct")
BART_MODEL_ID = os.getenv("BART_MODEL_ID", "ArchCoder/fine-tuned-bart-large")
| |
|
| | |
| | |
| | |
| |
|
| | _SYSTEM_PROMPT = """\ |
| | You are a data visualization expert. Given the user request and dataset schema, \ |
| | output ONLY a valid JSON object. No explanation, no markdown fences, no extra text. |
| | |
| | Required JSON keys: |
| | "x" : string β exact column name for the x-axis |
| | "y" : array β one or more exact column names for the y-axis |
| | "chart_type" : string β one of: line, bar, scatter, pie, histogram, box, area |
| | "color" : string or null β optional CSS color like "red", "#4f8cff" |
| | |
| | Rules: |
| | - Use ONLY column names from the schema. Never invent names. |
| | - For pie charts: y must contain exactly one column. |
| | - For histogram/box: x may equal the first element of y. |
| | - Default to "line" if chart type is ambiguous. |
| | |
| | ### Examples |
| | |
| | Example 1: |
| | Schema: Year (integer), Sales (float), Profit (float) |
| | User: "plot sales over the years with a red line" |
| | Output: {"x": "Year", "y": ["Sales"], "chart_type": "line", "color": "red"} |
| | |
| | Example 2: |
| | Schema: Month (string), Revenue (float), Expenses (float) |
| | User: "bar chart comparing revenue and expenses by month" |
| | Output: {"x": "Month", "y": ["Revenue", "Expenses"], "chart_type": "bar", "color": null} |
| | |
| | Example 3: |
| | Schema: Category (string), Count (integer) |
| | User: "pie chart of count by category" |
| | Output: {"x": "Category", "y": ["Count"], "chart_type": "pie", "color": null} |
| | |
| | Example 4: |
| | Schema: Date (string), Temperature (float), Humidity (float) |
| | User: "scatter plot of temperature vs humidity in blue" |
| | Output: {"x": "Temperature", "y": ["Humidity"], "chart_type": "scatter", "color": "blue"} |
| | |
| | Example 5: |
| | Schema: Year (integer), Sales (float), Employee expense (float), Marketing expense (float) |
| | User: "show me an area chart of sales and marketing expense over years" |
| | Output: {"x": "Year", "y": ["Sales", "Marketing expense"], "chart_type": "area", "color": null} |
| | """ |
| |
|
| |
|
| | def _user_message(query: str, columns: list, dtypes: dict, sample_rows: list) -> str: |
| | schema = "\n".join(f" - {c} ({dtypes.get(c, 'unknown')})" for c in columns) |
| | samples = "".join(f" {json.dumps(r)}\n" for r in sample_rows[:3]) |
| | return ( |
| | f"Schema:\n{schema}\n\n" |
| | f"Sample rows:\n{samples}\n" |
| | f"User: \"{query}\"\n" |
| | f"Output:" |
| | ) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def _parse_output(text: str): |
| | text = text.strip() |
| | if "```" in text: |
| | for part in text.split("```"): |
| | part = part.strip().lstrip("json").strip() |
| | if part.startswith("{"): |
| | text = part |
| | break |
| | try: |
| | return json.loads(text) |
| | except json.JSONDecodeError: |
| | pass |
| | try: |
| | return ast.literal_eval(text) |
| | except (SyntaxError, ValueError): |
| | pass |
| | return None |
| |
|
| |
|
| | def _validate(args: dict, columns: list): |
| | if not isinstance(args, dict): |
| | return None |
| | if not all(k in args for k in ("x", "y", "chart_type")): |
| | return None |
| | if isinstance(args["y"], str): |
| | args["y"] = [args["y"]] |
| | valid = {"line", "bar", "scatter", "pie", "histogram", "box", "area"} |
| | if args["chart_type"] not in valid: |
| | args["chart_type"] = "line" |
| | if args["x"] not in columns: |
| | return None |
| | if not all(c in columns for c in args["y"]): |
| | return None |
| | return args |
| |
|
| |
|
| | def _pick_chart_type(query: str) -> str: |
| | lowered = query.lower() |
| | aliases = { |
| | "scatter": ["scatter", "scatterplot"], |
| | "bar": ["bar", "column"], |
| | "pie": ["pie", "donut"], |
| | "histogram": ["histogram", "distribution"], |
| | "box": ["box", "boxplot"], |
| | "area": ["area"], |
| | "line": ["line", "trend", "over time", "over the years"], |
| | } |
| | for chart_type, keywords in aliases.items(): |
| | if any(keyword in lowered for keyword in keywords): |
| | return chart_type |
| | return "line" |
| |
|
| |
|
| | def _pick_color(query: str): |
| | lowered = query.lower() |
| | colors = [ |
| | "red", "blue", "green", "yellow", "orange", "purple", "pink", |
| | "black", "white", "gray", "grey", "cyan", "teal", "indigo", |
| | ] |
| | for color in colors: |
| | if re.search(rf"\b{re.escape(color)}\b", lowered): |
| | return color |
| | return None |
| |
|
| |
|
| | def _pick_columns(query: str, columns: list, dtypes: dict): |
| | lowered = query.lower() |
| | query_tokens = re.findall(r"[a-zA-Z0-9_]+", lowered) |
| |
|
| | def score_column(column: str) -> float: |
| | col_lower = column.lower() |
| | score = 0.0 |
| | if col_lower in lowered: |
| | score += 10.0 |
| | for token in query_tokens: |
| | if token and token in col_lower: |
| | score += 2.0 |
| | score += difflib.SequenceMatcher(None, lowered, col_lower).ratio() |
| | return score |
| |
|
| | sorted_columns = sorted(columns, key=score_column, reverse=True) |
| | numeric_columns = [col for col in columns if dtypes.get(col) in {"integer", "float"}] |
| | temporal_columns = [col for col in columns if dtypes.get(col) == "datetime"] |
| | year_like = [col for col in columns if "year" in col.lower() or "date" in col.lower() or "month" in col.lower()] |
| |
|
| | x_col = None |
| | for candidate in year_like + temporal_columns + sorted_columns: |
| | if candidate in columns: |
| | x_col = candidate |
| | break |
| | if x_col is None and columns: |
| | x_col = columns[0] |
| |
|
| | y_candidates = [col for col in sorted_columns if col != x_col and col in numeric_columns] |
| | if not y_candidates: |
| | y_candidates = [col for col in numeric_columns if col != x_col] |
| | if not y_candidates: |
| | y_candidates = [col for col in columns if col != x_col] |
| |
|
| | return x_col, y_candidates[:1] |
| |
|
| |
|
def _heuristic_plot_args(query: str, columns: list, dtypes: dict) -> dict:
    """Build plot arguments from keyword heuristics when the LLM path fails.

    Always returns a dict with keys "x", "y", "chart_type", "color"; "y"
    is empty only when *columns* itself is empty.  Fix: removed the dead
    isinstance(..., tuple) branch — the fallback was never a tuple — and
    made the y-fallback explicit.
    """
    x_col, y_cols = _pick_columns(query, columns, dtypes)
    if not x_col:
        # Empty dataset: fall back to a conventional default axis name.
        x_col = "Year"
    if not y_cols:
        # Prefer the first column that is not the x-axis; otherwise reuse the
        # first column (an empty dataset yields an empty y list).
        fallback = next((col for col in columns if col != x_col), None)
        y_cols = [fallback] if fallback is not None else columns[:1]
    return {
        "x": x_col,
        "y": y_cols,
        "chart_type": _pick_chart_type(query),
        "color": _pick_color(query),
    }
| |
|
| |
|
| | |
| | |
| | |
| |
|
class LLM_Agent:
    """Turns a natural-language chart request into a rendered chart.

    Delegates structured-argument extraction to one of several LLM backends
    (local Qwen/BART via transformers, or remote Gemini/Grok APIs) and falls
    back to keyword heuristics when the model output cannot be validated.
    """

    def __init__(self, data_path=None):
        """Load the dataset at *data_path* (or the processor's default) and
        prepare lazy-loading slots for the local model backends."""
        logger.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        # Local models are loaded lazily on first use to keep startup cheap.
        self._bart_tokenizer = None
        self._bart_model = None
        self._qwen_tokenizer = None
        self._qwen_model = None

    def _run_qwen(self, user_msg: str) -> str:
        """Qwen2.5-Coder-0.5B-Instruct — fast structured-JSON generation."""
        if self._qwen_model is None:
            # Import and load lazily so other backends don't pay the cost.
            from transformers import AutoModelForCausalLM, AutoTokenizer
            logger.info(f"Loading Qwen model: {QWEN_MODEL_ID}")
            self._qwen_tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID)
            self._qwen_model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID)
            logger.info("Qwen model loaded.")
        messages = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": user_msg},
        ]
        text = self._qwen_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self._qwen_tokenizer(text, return_tensors="pt")
        outputs = self._qwen_model.generate(
            **inputs, max_new_tokens=256, temperature=0.1, do_sample=True
        )
        # Decode only the newly generated tokens (slice off the echoed prompt).
        return self._qwen_tokenizer.decode(
            outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
        )

    def _run_gemini(self, user_msg: str) -> str:
        """Google Gemini backend; requires GEMINI_API_KEY in the environment."""
        import google.generativeai as genai
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY is not set")
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(
            "gemini-2.0-flash",
            system_instruction=_SYSTEM_PROMPT,
        )
        return model.generate_content(user_msg).text

    def _run_grok(self, user_msg: str) -> str:
        """xAI Grok backend via the OpenAI-compatible API; needs GROK_API_KEY."""
        from openai import OpenAI
        api_key = os.getenv("GROK_API_KEY")
        if not api_key:
            raise ValueError("GROK_API_KEY is not set")
        client = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
        resp = client.chat.completions.create(
            model="grok-3-mini",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
            max_tokens=256,
            temperature=0.1,
        )
        return resp.choices[0].message.content

    def _run_bart(self, query: str) -> str:
        """ArchCoder/fine-tuned-bart-large — lightweight Seq2Seq fallback."""
        if self._bart_model is None:
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
            logger.info(f"Loading BART model: {BART_MODEL_ID}")
            self._bart_tokenizer = AutoTokenizer.from_pretrained(BART_MODEL_ID)
            self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(BART_MODEL_ID)
            logger.info("BART model loaded.")
        # NOTE(review): BART receives only the raw query, not the schema prompt —
        # presumably the fine-tune expects that format; confirm against the model card.
        inputs = self._bart_tokenizer(
            query, return_tensors="pt", max_length=512, truncation=True
        )
        outputs = self._bart_model.generate(**inputs, max_length=100)
        return self._bart_tokenizer.decode(outputs[0], skip_special_tokens=True)

    def process_request(self, data: dict) -> dict:
        """Handle one chart request end to end.

        *data* keys: "query" (user text), optional "file_path" (dataset to
        reload), optional "model" ("qwen"|"gemini"|"grok"|"bart", default
        "qwen").  Returns a dict with "response", "chart_path", "chart_spec",
        "verified", and "plot_args"; chart-generation failures are reported
        in the returned dict rather than raised.
        """
        t0 = time.time()
        query = data.get("query", "")
        data_path = data.get("file_path")
        model = data.get("model", "qwen")

        # Reload the dataset when the request points at an existing file.
        if data_path and os.path.exists(data_path):
            self.data_processor = DataProcessor(data_path)
            self.chart_generator = ChartGenerator(self.data_processor.data)

        columns = self.data_processor.get_columns()
        dtypes = self.data_processor.get_dtypes()
        sample_rows = self.data_processor.preview(3)

        # Last-resort arguments if both the LLM and the heuristics fail.
        default_args = {
            "x": columns[0] if columns else "Year",
            "y": [columns[1]] if len(columns) > 1 else ["Sales"],
            "chart_type": "line",
        }

        raw_text = ""
        plot_args = None
        try:
            user_msg = _user_message(query, columns, dtypes, sample_rows)
            # Dispatch to the requested backend; unknown names use Qwen.
            if model == "gemini": raw_text = self._run_gemini(user_msg)
            elif model == "grok": raw_text = self._run_grok(user_msg)
            elif model == "bart": raw_text = self._run_bart(query)
            elif model == "qwen":
                try:
                    raw_text = self._run_qwen(user_msg)
                except Exception as qwen_exc:
                    # Local Qwen can fail (download, memory); degrade to BART.
                    logger.warning(f"Qwen failed, falling back to BART: {qwen_exc}")
                    raw_text = self._run_bart(query)
            else:
                raw_text = self._run_qwen(user_msg)

            logger.info(f"LLM [{model}] output: {raw_text}")
            parsed = _parse_output(raw_text)
            plot_args = _validate(parsed, columns) if parsed else None
        except Exception as exc:
            # Any backend/parse failure leaves plot_args None for the fallback.
            logger.error(f"LLM error [{model}]: {exc}")
            raw_text = str(exc)

        if not plot_args:
            # LLM output missing or invalid: derive arguments from keywords.
            logger.warning("Falling back to heuristic plot args")
            plot_args = _validate(_heuristic_plot_args(query, columns, dtypes), columns) or default_args

        try:
            chart_result = self.chart_generator.generate_chart(plot_args)
            chart_path = chart_result["chart_path"]
            chart_spec = chart_result["chart_spec"]
        except Exception as exc:
            # Report chart failures in-band so the caller gets a response dict.
            logger.error(f"Chart generation error: {exc}")
            return {
                "response": f"Chart generation failed: {exc}",
                "chart_path": "",
                "chart_spec": None,
                "verified": False,
                "plot_args": plot_args,
            }

        logger.info(f"Request processed in {time.time() - t0:.2f}s")
        return {
            "response": json.dumps(plot_args),
            "chart_path": chart_path,
            "chart_spec": chart_spec,
            "verified": True,
            "plot_args": plot_args,
        }
| |
|