# llm-excel-plotter-agent / llm_agent.py
# Author: Priyansh Saxena
# feat: download Qwen2.5-Coder-0.5B + BART at build, add few-shot prompts
# commit: cb6c215
import ast
import difflib
import json
import logging
import os
import re
import time
from dotenv import load_dotenv
from chart_generator import ChartGenerator
from data_processor import DataProcessor
load_dotenv()
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Model IDs (downloaded at Docker build, cached in HF_HOME)
# ---------------------------------------------------------------------------
QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-Coder-0.5B-Instruct")
BART_MODEL_ID = os.getenv("BART_MODEL_ID", "ArchCoder/fine-tuned-bart-large")
# ---------------------------------------------------------------------------
# Prompt templates with few-shot examples
# ---------------------------------------------------------------------------
_SYSTEM_PROMPT = """\
You are a data visualization expert. Given the user request and dataset schema, \
output ONLY a valid JSON object. No explanation, no markdown fences, no extra text.
Required JSON keys:
"x" : string β€” exact column name for the x-axis
"y" : array β€” one or more exact column names for the y-axis
"chart_type" : string β€” one of: line, bar, scatter, pie, histogram, box, area
"color" : string or null β€” optional CSS color like "red", "#4f8cff"
Rules:
- Use ONLY column names from the schema. Never invent names.
- For pie charts: y must contain exactly one column.
- For histogram/box: x may equal the first element of y.
- Default to "line" if chart type is ambiguous.
### Examples
Example 1:
Schema: Year (integer), Sales (float), Profit (float)
User: "plot sales over the years with a red line"
Output: {"x": "Year", "y": ["Sales"], "chart_type": "line", "color": "red"}
Example 2:
Schema: Month (string), Revenue (float), Expenses (float)
User: "bar chart comparing revenue and expenses by month"
Output: {"x": "Month", "y": ["Revenue", "Expenses"], "chart_type": "bar", "color": null}
Example 3:
Schema: Category (string), Count (integer)
User: "pie chart of count by category"
Output: {"x": "Category", "y": ["Count"], "chart_type": "pie", "color": null}
Example 4:
Schema: Date (string), Temperature (float), Humidity (float)
User: "scatter plot of temperature vs humidity in blue"
Output: {"x": "Temperature", "y": ["Humidity"], "chart_type": "scatter", "color": "blue"}
Example 5:
Schema: Year (integer), Sales (float), Employee expense (float), Marketing expense (float)
User: "show me an area chart of sales and marketing expense over years"
Output: {"x": "Year", "y": ["Sales", "Marketing expense"], "chart_type": "area", "color": null}
"""
def _user_message(query: str, columns: list, dtypes: dict, sample_rows: list) -> str:
schema = "\n".join(f" - {c} ({dtypes.get(c, 'unknown')})" for c in columns)
samples = "".join(f" {json.dumps(r)}\n" for r in sample_rows[:3])
return (
f"Schema:\n{schema}\n\n"
f"Sample rows:\n{samples}\n"
f"User: \"{query}\"\n"
f"Output:"
)
# ---------------------------------------------------------------------------
# Output parsing & validation
# ---------------------------------------------------------------------------
def _parse_output(text: str):
text = text.strip()
if "```" in text:
for part in text.split("```"):
part = part.strip().lstrip("json").strip()
if part.startswith("{"):
text = part
break
try:
return json.loads(text)
except json.JSONDecodeError:
pass
try:
return ast.literal_eval(text)
except (SyntaxError, ValueError):
pass
return None
def _validate(args: dict, columns: list):
if not isinstance(args, dict):
return None
if not all(k in args for k in ("x", "y", "chart_type")):
return None
if isinstance(args["y"], str):
args["y"] = [args["y"]]
valid = {"line", "bar", "scatter", "pie", "histogram", "box", "area"}
if args["chart_type"] not in valid:
args["chart_type"] = "line"
if args["x"] not in columns:
return None
if not all(c in columns for c in args["y"]):
return None
return args
def _pick_chart_type(query: str) -> str:
lowered = query.lower()
aliases = {
"scatter": ["scatter", "scatterplot"],
"bar": ["bar", "column"],
"pie": ["pie", "donut"],
"histogram": ["histogram", "distribution"],
"box": ["box", "boxplot"],
"area": ["area"],
"line": ["line", "trend", "over time", "over the years"],
}
for chart_type, keywords in aliases.items():
if any(keyword in lowered for keyword in keywords):
return chart_type
return "line"
def _pick_color(query: str):
lowered = query.lower()
colors = [
"red", "blue", "green", "yellow", "orange", "purple", "pink",
"black", "white", "gray", "grey", "cyan", "teal", "indigo",
]
for color in colors:
if re.search(rf"\b{re.escape(color)}\b", lowered):
return color
return None
def _pick_columns(query: str, columns: list, dtypes: dict):
lowered = query.lower()
query_tokens = re.findall(r"[a-zA-Z0-9_]+", lowered)
def score_column(column: str) -> float:
col_lower = column.lower()
score = 0.0
if col_lower in lowered:
score += 10.0
for token in query_tokens:
if token and token in col_lower:
score += 2.0
score += difflib.SequenceMatcher(None, lowered, col_lower).ratio()
return score
sorted_columns = sorted(columns, key=score_column, reverse=True)
numeric_columns = [col for col in columns if dtypes.get(col) in {"integer", "float"}]
temporal_columns = [col for col in columns if dtypes.get(col) == "datetime"]
year_like = [col for col in columns if "year" in col.lower() or "date" in col.lower() or "month" in col.lower()]
x_col = None
for candidate in year_like + temporal_columns + sorted_columns:
if candidate in columns:
x_col = candidate
break
if x_col is None and columns:
x_col = columns[0]
y_candidates = [col for col in sorted_columns if col != x_col and col in numeric_columns]
if not y_candidates:
y_candidates = [col for col in numeric_columns if col != x_col]
if not y_candidates:
y_candidates = [col for col in columns if col != x_col]
return x_col, y_candidates[:1]
def _heuristic_plot_args(query: str, columns: list, dtypes: dict) -> dict:
    """Build plot args from keyword heuristics when the LLM path fails.

    Always returns a dict with "x", "y" (a list — non-empty whenever any
    column exists), "chart_type", and "color".

    Fix: the old code tested ``isinstance(fallback_y, tuple)`` but the
    fallback was always a str or a list, so that branch was dead and the
    result type depended on which fallback fired.
    """
    x_col, y_cols = _pick_columns(query, columns, dtypes)
    if not x_col:
        # No columns at all; use a conventional default so downstream
        # validation can reject it gracefully.
        x_col = "Year"
    if not y_cols:
        # Prefer any column other than x; tolerate a single-column dataset.
        fallback = next((c for c in columns if c != x_col), None)
        y_cols = [fallback] if fallback is not None else list(columns[:1])
    return {
        "x": x_col,
        "y": list(y_cols),
        "chart_type": _pick_chart_type(query),
        "color": _pick_color(query),
    }
# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------
class LLM_Agent:
    """Turns a natural-language chart request plus a dataset into a chart.

    Pipeline (see process_request): build a schema-aware prompt, ask the
    selected model for JSON plot args, validate them against the real
    columns, fall back to keyword heuristics when the LLM output is
    unusable, then render via ChartGenerator.
    """

    def __init__(self, data_path=None):
        logger.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        # Local models are lazy-loaded on first use (see _run_qwen/_run_bart)
        # so construction stays cheap and API-only deployments never pay the
        # local-model memory cost.
        self._bart_tokenizer = None
        self._bart_model = None
        self._qwen_tokenizer = None
        self._qwen_model = None

    # -- model runners -------------------------------------------------------
    def _run_qwen(self, user_msg: str) -> str:
        """Qwen2.5-Coder-0.5B-Instruct — fast structured-JSON generation."""
        if self._qwen_model is None:
            # Deferred import: transformers is heavy and only needed on the
            # local-model path.
            from transformers import AutoModelForCausalLM, AutoTokenizer
            logger.info(f"Loading Qwen model: {QWEN_MODEL_ID}")
            self._qwen_tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID)
            self._qwen_model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID)
            logger.info("Qwen model loaded.")
        messages = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": user_msg},
        ]
        text = self._qwen_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self._qwen_tokenizer(text, return_tensors="pt")
        outputs = self._qwen_model.generate(
            **inputs, max_new_tokens=256, temperature=0.1, do_sample=True
        )
        # Slice off the prompt tokens so only the generated reply is decoded.
        return self._qwen_tokenizer.decode(
            outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
        )

    def _run_gemini(self, user_msg: str) -> str:
        """Call Gemini 2.0 Flash with the system prompt; return raw text.

        Raises:
            ValueError: when GEMINI_API_KEY is unset.
        """
        import google.generativeai as genai
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY is not set")
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(
            "gemini-2.0-flash",
            system_instruction=_SYSTEM_PROMPT,
        )
        return model.generate_content(user_msg).text

    def _run_grok(self, user_msg: str) -> str:
        """Call Grok via x.ai's OpenAI-compatible endpoint; return raw text.

        Raises:
            ValueError: when GROK_API_KEY is unset.
        """
        from openai import OpenAI
        api_key = os.getenv("GROK_API_KEY")
        if not api_key:
            raise ValueError("GROK_API_KEY is not set")
        client = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
        resp = client.chat.completions.create(
            model="grok-3-mini",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
            max_tokens=256,
            temperature=0.1,
        )
        return resp.choices[0].message.content

    def _run_bart(self, query: str) -> str:
        """ArchCoder/fine-tuned-bart-large — lightweight Seq2Seq fallback.

        NOTE(review): receives only the raw query, not the schema-aware
        prompt — presumably the fine-tune was trained on bare queries;
        confirm against the model card.
        """
        if self._bart_model is None:
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
            logger.info(f"Loading BART model: {BART_MODEL_ID}")
            self._bart_tokenizer = AutoTokenizer.from_pretrained(BART_MODEL_ID)
            self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(BART_MODEL_ID)
            logger.info("BART model loaded.")
        inputs = self._bart_tokenizer(
            query, return_tensors="pt", max_length=512, truncation=True
        )
        outputs = self._bart_model.generate(**inputs, max_length=100)
        return self._bart_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # -- main entry point ----------------------------------------------------
    def process_request(self, data: dict) -> dict:
        """Handle one chart request end to end.

        Args:
            data: request dict with "query" (str), optional "file_path"
                (reloads the dataset when the file exists), and optional
                "model" ("qwen" default, "gemini", "grok", "bart").

        Returns:
            Dict with "response", "chart_path", "chart_spec", "verified"
            (False only when chart generation itself failed), "plot_args".
        """
        t0 = time.time()
        query = data.get("query", "")
        data_path = data.get("file_path")
        model = data.get("model", "qwen")
        # Swap in a new dataset when the request points at a fresh file.
        if data_path and os.path.exists(data_path):
            self.data_processor = DataProcessor(data_path)
            self.chart_generator = ChartGenerator(self.data_processor.data)
        columns = self.data_processor.get_columns()
        dtypes = self.data_processor.get_dtypes()
        sample_rows = self.data_processor.preview(3)
        # Last-resort args if both the LLM and the heuristics fail validation.
        default_args = {
            "x": columns[0] if columns else "Year",
            "y": [columns[1]] if len(columns) > 1 else ["Sales"],
            "chart_type": "line",
        }
        raw_text = ""
        plot_args = None
        try:
            user_msg = _user_message(query, columns, dtypes, sample_rows)
            if model == "gemini": raw_text = self._run_gemini(user_msg)
            elif model == "grok": raw_text = self._run_grok(user_msg)
            elif model == "bart": raw_text = self._run_bart(query)
            elif model == "qwen":
                # Local Qwen may fail (e.g. OOM / missing weights); degrade
                # to the lighter BART model rather than aborting.
                try:
                    raw_text = self._run_qwen(user_msg)
                except Exception as qwen_exc:
                    logger.warning(f"Qwen failed, falling back to BART: {qwen_exc}")
                    raw_text = self._run_bart(query)
            else:
                # Unknown model name: treat as the default (Qwen).
                raw_text = self._run_qwen(user_msg)
            logger.info(f"LLM [{model}] output: {raw_text}")
            parsed = _parse_output(raw_text)
            plot_args = _validate(parsed, columns) if parsed else None
        except Exception as exc:
            # Any LLM failure is non-fatal; the heuristic path below takes over.
            logger.error(f"LLM error [{model}]: {exc}")
            raw_text = str(exc)
        if not plot_args:
            logger.warning("Falling back to heuristic plot args")
            plot_args = _validate(_heuristic_plot_args(query, columns, dtypes), columns) or default_args
        try:
            chart_result = self.chart_generator.generate_chart(plot_args)
            chart_path = chart_result["chart_path"]
            chart_spec = chart_result["chart_spec"]
        except Exception as exc:
            logger.error(f"Chart generation error: {exc}")
            return {
                "response": f"Chart generation failed: {exc}",
                "chart_path": "",
                "chart_spec": None,
                "verified": False,
                "plot_args": plot_args,
            }
        logger.info(f"Request processed in {time.time() - t0:.2f}s")
        return {
            "response": json.dumps(plot_args),
            "chart_path": chart_path,
            "chart_spec": chart_spec,
            "verified": True,
            "plot_args": plot_args,
        }