Priyansh Saxena committed on
Commit ·
f3fd40f
1
Parent(s): 962831e
fix: remove runtime model dependency and repair chart generation
Browse files- app.py +2 -2
- chart_generator.py +8 -10
- llm_agent.py +107 -5
app.py
CHANGED
|
@@ -17,7 +17,7 @@ logging.getLogger('plotly').setLevel(logging.WARNING)
|
|
| 17 |
|
| 18 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
|
| 20 |
-
app = Flask(__name__, static_folder=os.path.join(BASE_DIR, '
|
| 21 |
|
| 22 |
CORS(app, origins=[
|
| 23 |
"https://llm-integrated-excel-plotter-app.vercel.app",
|
|
@@ -27,7 +27,7 @@ CORS(app, origins=[
|
|
| 27 |
|
| 28 |
agent = LLM_Agent()
|
| 29 |
|
| 30 |
-
UPLOAD_FOLDER = os.path.join(BASE_DIR, '
|
| 31 |
ALLOWED_EXTENSIONS = {'csv', 'xls', 'xlsx'}
|
| 32 |
MAX_UPLOAD_BYTES = 10 * 1024 * 1024 # 10 MB
|
| 33 |
|
|
|
|
| 17 |
|
| 18 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
|
| 20 |
+
app = Flask(__name__, static_folder=os.path.join(BASE_DIR, 'static'))
|
| 21 |
|
| 22 |
CORS(app, origins=[
|
| 23 |
"https://llm-integrated-excel-plotter-app.vercel.app",
|
|
|
|
| 27 |
|
| 28 |
agent = LLM_Agent()
|
| 29 |
|
| 30 |
+
UPLOAD_FOLDER = os.path.join(BASE_DIR, 'data', 'uploads')
|
| 31 |
ALLOWED_EXTENSIONS = {'csv', 'xls', 'xlsx'}
|
| 32 |
MAX_UPLOAD_BYTES = 10 * 1024 * 1024 # 10 MB
|
| 33 |
|
chart_generator.py
CHANGED
|
@@ -130,7 +130,7 @@ class ChartGenerator:
|
|
| 130 |
if chart_type not in ("pie", "histogram", "box") and len(x) > 5:
|
| 131 |
plt.xticks(rotation=45, ha="right")
|
| 132 |
|
| 133 |
-
output_dir = os.path.join(os.path.dirname(
|
| 134 |
os.makedirs(output_dir, exist_ok=True)
|
| 135 |
filename = f"chart_{uuid.uuid4().hex[:12]}.png"
|
| 136 |
full_path = os.path.join(output_dir, filename)
|
|
@@ -174,14 +174,12 @@ class ChartGenerator:
|
|
| 174 |
line=dict(color=c, width=2),
|
| 175 |
marker=dict(size=6)).to_plotly_json())
|
| 176 |
|
| 177 |
-
layout =
|
| 178 |
-
|
| 179 |
-
title
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
yaxis=dict(**_PLOTLY_LAYOUT["yaxis"], title=" / ".join(y_cols)),
|
| 185 |
-
)
|
| 186 |
|
| 187 |
return {"data": traces, "layout": layout}
|
|
|
|
| 130 |
if chart_type not in ("pie", "histogram", "box") and len(x) > 5:
|
| 131 |
plt.xticks(rotation=45, ha="right")
|
| 132 |
|
| 133 |
+
output_dir = os.path.join(os.path.dirname(__file__), "static", "images")
|
| 134 |
os.makedirs(output_dir, exist_ok=True)
|
| 135 |
filename = f"chart_{uuid.uuid4().hex[:12]}.png"
|
| 136 |
full_path = os.path.join(output_dir, filename)
|
|
|
|
| 174 |
line=dict(color=c, width=2),
|
| 175 |
marker=dict(size=6)).to_plotly_json())
|
| 176 |
|
| 177 |
+
layout = {**_PLOTLY_LAYOUT}
|
| 178 |
+
layout["title"] = {
|
| 179 |
+
"text": f"{chart_type.title()} \u2014 {', '.join(y_cols)} vs {x_col}",
|
| 180 |
+
"font": {"size": 15, "color": "#e2e8f0"},
|
| 181 |
+
}
|
| 182 |
+
layout["xaxis"] = {**_PLOTLY_LAYOUT["xaxis"], "title": x_col}
|
| 183 |
+
layout["yaxis"] = {**_PLOTLY_LAYOUT["yaxis"], "title": " / ".join(y_cols)}
|
|
|
|
|
|
|
| 184 |
|
| 185 |
return {"data": traces, "layout": layout}
|
llm_agent.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
| 1 |
import ast
|
|
|
|
| 2 |
import json
|
| 3 |
import logging
|
| 4 |
import os
|
|
|
|
| 5 |
import time
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
|
@@ -17,6 +19,18 @@ logger = logging.getLogger(__name__)
|
|
| 17 |
def _model_dir(dirname: str) -> str:
|
| 18 |
return os.path.join(os.path.dirname(os.path.abspath(__file__)), dirname)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# ---------------------------------------------------------------------------
|
| 21 |
# Prompt templates
|
| 22 |
# ---------------------------------------------------------------------------
|
|
@@ -88,6 +102,89 @@ def _validate(args: dict, columns: list):
|
|
| 88 |
return args
|
| 89 |
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
# ---------------------------------------------------------------------------
|
| 92 |
# Agent
|
| 93 |
# ---------------------------------------------------------------------------
|
|
@@ -101,16 +198,19 @@ class LLM_Agent:
|
|
| 101 |
self._bart_model = None
|
| 102 |
self._qwen_tokenizer = None
|
| 103 |
self._qwen_model = None
|
|
|
|
|
|
|
| 104 |
|
| 105 |
# -- model runners -------------------------------------------------------
|
| 106 |
|
| 107 |
def _run_qwen(self, user_msg: str) -> str:
|
| 108 |
if self._qwen_model is None:
|
| 109 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 110 |
-
|
| 111 |
-
model_id = os.getenv("QWEN_LOCAL_PATH", "")
|
| 112 |
if not model_id:
|
| 113 |
raise ValueError("Qwen local model is not configured in this Space")
|
|
|
|
|
|
|
| 114 |
logger.info("Loading Qwen model (first request)...")
|
| 115 |
self._qwen_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
|
| 116 |
self._qwen_model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=True)
|
|
@@ -162,7 +262,9 @@ class LLM_Agent:
|
|
| 162 |
def _run_bart(self, query: str) -> str:
|
| 163 |
if self._bart_model is None:
|
| 164 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 165 |
-
model_id =
|
|
|
|
|
|
|
| 166 |
logger.info("Loading BART model (first request)...")
|
| 167 |
self._bart_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
|
| 168 |
self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, local_files_only=True)
|
|
@@ -218,8 +320,8 @@ class LLM_Agent:
|
|
| 218 |
raw_text = str(exc)
|
| 219 |
|
| 220 |
if not plot_args:
|
| 221 |
-
logger.warning("Falling back to
|
| 222 |
-
plot_args = default_args
|
| 223 |
|
| 224 |
try:
|
| 225 |
chart_result = self.chart_generator.generate_chart(plot_args)
|
|
|
|
| 1 |
import ast
|
| 2 |
+
import difflib
|
| 3 |
import json
|
| 4 |
import logging
|
| 5 |
import os
|
| 6 |
+
import re
|
| 7 |
import time
|
| 8 |
|
| 9 |
from dotenv import load_dotenv
|
|
|
|
| 19 |
def _model_dir(dirname: str) -> str:
|
| 20 |
return os.path.join(os.path.dirname(os.path.abspath(__file__)), dirname)
|
| 21 |
|
| 22 |
+
|
| 23 |
+
def _has_model_weights(model_dir: str) -> bool:
    """Return True when *model_dir* is a directory that contains model weights.

    A directory qualifies when it holds one of the known single-file
    checkpoints, or the ``<name>.index.json`` manifest that accompanies a
    sharded checkpoint (large models are usually saved as several shard
    files plus such an index, with no single weight file present).

    Args:
        model_dir: Path of the local model directory. Empty or ``None``
            values are treated as "no weights".

    Returns:
        ``True`` if at least one recognised weight file or shard index exists
        in ``model_dir``; ``False`` otherwise.
    """
    # Guard first: os.path.isdir(None) raises TypeError instead of returning False.
    if not model_dir or not os.path.isdir(model_dir):
        return False

    weight_files = (
        "pytorch_model.bin",
        "model.safetensors",
        "tf_model.h5",
        "flax_model.msgpack",
    )
    # Accept either the weight file itself or its sharded-checkpoint index.
    candidates = [
        name
        for filename in weight_files
        for name in (filename, f"{filename}.index.json")
    ]
    return any(
        os.path.exists(os.path.join(model_dir, name)) for name in candidates
    )
|
| 33 |
+
|
| 34 |
# ---------------------------------------------------------------------------
|
| 35 |
# Prompt templates
|
| 36 |
# ---------------------------------------------------------------------------
|
|
|
|
| 102 |
return args
|
| 103 |
|
| 104 |
|
| 105 |
+
def _pick_chart_type(query: str) -> str:
    """Infer the chart type requested in a natural-language *query*.

    Each chart type has a list of keywords. Keywords are matched
    case-insensitively as whole words (an optional trailing ``s`` is
    tolerated), so that words which merely contain a keyword -- e.g.
    "species" or "copies" containing "pie" -- do not select a chart type.
    Chart types are tried in the order they are declared below and the first
    one with a matching keyword wins.

    Args:
        query: Free-text user request. ``None`` or an empty string is allowed.

    Returns:
        One of ``"scatter"``, ``"bar"``, ``"pie"``, ``"histogram"``, ``"box"``,
        ``"area"`` or ``"line"``; ``"line"`` is also the default when no
        keyword is found.
    """
    lowered = (query or "").lower()
    aliases = {
        "scatter": ["scatter", "scatterplot"],
        "bar": ["bar", "column"],
        "pie": ["pie", "donut"],
        "histogram": ["histogram", "distribution"],
        "box": ["box", "boxplot"],
        "area": ["area"],
        "line": ["line", "trend", "over time", "over the years"],
    }
    for chart_type, keywords in aliases.items():
        alternatives = "|".join(re.escape(keyword) for keyword in keywords)
        # \b anchors keep the match on word boundaries; "s?" accepts plurals.
        if re.search(rf"\b(?:{alternatives})s?\b", lowered):
            return chart_type
    return "line"
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _pick_color(query: str):
    """Return the first colour name mentioned in *query*, or ``None``.

    Colour names are matched case-insensitively as whole words, so "reduced"
    does not count as "red". When several colours are mentioned, the one that
    appears earliest in the query is returned, regardless of the order of the
    internal colour list.

    Args:
        query: Free-text user request. ``None`` or an empty string is allowed.

    Returns:
        The lower-case colour name found in the query, or ``None`` when the
        query mentions none of the known colours.
    """
    lowered = (query or "").lower()
    colors = [
        "red", "blue", "green", "yellow", "orange", "purple", "pink",
        "black", "white", "gray", "grey", "cyan", "teal", "indigo",
    ]
    # A single alternation lets the regex engine report the leftmost colour.
    pattern = r"\b(?:" + "|".join(re.escape(color) for color in colors) + r")\b"
    match = re.search(pattern, lowered)
    return match.group(0) if match else None
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _pick_columns(query: str, columns: list, dtypes: dict):
    """Heuristically choose an x column and one y column for a plot.

    Columns are scored against the query: +10 when the whole column name
    occurs in the query, +2 for every query token contained in the column
    name, plus a ``difflib`` similarity ratio as a tie-breaker.

    The x axis prefers, in order: columns whose name contains "year", "date"
    or "month"; columns typed ``"datetime"``; then the best-scored column.
    In that last case a column is skipped if choosing it would leave no
    numeric column for the y axis (for example when it is the only numeric
    column), provided some other column does not have that problem.

    The y axis is the best-scored numeric column other than x; if there is
    none, any other column is used.

    Args:
        query: Free-text user request (``None`` is treated as empty).
        columns: Available column names.
        dtypes: Mapping of column name to a type label; ``"integer"`` and
            ``"float"`` count as numeric, ``"datetime"`` as temporal.
            ``None`` is treated as an empty mapping.

    Returns:
        A ``(x_col, y_cols)`` tuple where ``x_col`` is a column name (or
        ``None`` when there are no columns) and ``y_cols`` is a list holding
        at most one column name.
    """
    lowered = (query or "").lower()
    dtypes = dtypes or {}
    columns = list(columns or [])
    if not columns:
        return None, []

    query_tokens = re.findall(r"[a-zA-Z0-9_]+", lowered)

    def score_column(column: str) -> float:
        col_lower = column.lower()
        score = 0.0
        if col_lower in lowered:
            score += 10.0
        score += 2.0 * sum(1 for token in query_tokens if token in col_lower)
        score += difflib.SequenceMatcher(None, lowered, col_lower).ratio()
        return score

    sorted_columns = sorted(columns, key=score_column, reverse=True)
    numeric_columns = [col for col in columns if dtypes.get(col) in {"integer", "float"}]
    temporal_columns = [col for col in columns if dtypes.get(col) == "datetime"]
    year_like = [
        col for col in columns
        if any(hint in col.lower() for hint in ("year", "date", "month"))
    ]

    def leaves_numeric_y(candidate: str) -> bool:
        # True when a numeric column other than the candidate stays available.
        return any(col != candidate for col in numeric_columns)

    time_axis = year_like + temporal_columns
    if time_axis:
        # A time-like column is always the preferred x axis.
        x_col = time_axis[0]
    else:
        # Otherwise take the best-scored column that does not use up the
        # only numeric column; fall back to the top-scored column.
        x_col = next(
            (col for col in sorted_columns if leaves_numeric_y(col)),
            sorted_columns[0],
        )

    y_candidates = [col for col in sorted_columns if col != x_col and col in numeric_columns]
    if not y_candidates:
        y_candidates = [col for col in numeric_columns if col != x_col]
    if not y_candidates:
        y_candidates = [col for col in columns if col != x_col]

    return x_col, y_candidates[:1]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _heuristic_plot_args(query: str, columns: list, dtypes: dict) -> dict:
    """Build plot arguments from the query text alone, without any model.

    Columns come from ``_pick_columns``, the chart type from
    ``_pick_chart_type`` and the colour from ``_pick_color``.

    Args:
        query: Free-text user request.
        columns: Available column names.
        dtypes: Mapping of column name to a type label, passed through to
            ``_pick_columns``.

    Returns:
        A dict with the keys ``"x"`` (column name), ``"y"`` (list of column
        names), ``"chart_type"`` and ``"color"`` (colour name or ``None``).
    """
    x_col, y_cols = _pick_columns(query, columns, dtypes)
    if not x_col:
        # No column could be chosen; use the literal "Year" as a placeholder.
        x_col = "Year"
    if not y_cols:
        # Prefer any column other than x; otherwise reuse the first column
        # (this yields an empty list when there are no columns at all).
        fallback_y = next((col for col in columns if col != x_col), None)
        y_cols = [fallback_y] if fallback_y is not None else list(columns[:1])
    elif isinstance(y_cols, str):
        y_cols = [y_cols]
    return {
        "x": x_col,
        "y": y_cols,
        "chart_type": _pick_chart_type(query),
        "color": _pick_color(query),
    }
|
| 186 |
+
|
| 187 |
+
|
| 188 |
# ---------------------------------------------------------------------------
|
| 189 |
# Agent
|
| 190 |
# ---------------------------------------------------------------------------
|
|
|
|
| 198 |
self._bart_model = None
|
| 199 |
self._qwen_tokenizer = None
|
| 200 |
self._qwen_model = None
|
| 201 |
+
self._bart_model_dir = os.getenv("BART_LOCAL_PATH", _model_dir("fine-tuned-bart-large"))
|
| 202 |
+
self._qwen_model_dir = os.getenv("QWEN_LOCAL_PATH", "")
|
| 203 |
|
| 204 |
# -- model runners -------------------------------------------------------
|
| 205 |
|
| 206 |
def _run_qwen(self, user_msg: str) -> str:
|
| 207 |
if self._qwen_model is None:
|
| 208 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 209 |
+
model_id = self._qwen_model_dir
|
|
|
|
| 210 |
if not model_id:
|
| 211 |
raise ValueError("Qwen local model is not configured in this Space")
|
| 212 |
+
if not _has_model_weights(model_id):
|
| 213 |
+
raise ValueError(f"Qwen model weights not found in {model_id}")
|
| 214 |
logger.info("Loading Qwen model (first request)...")
|
| 215 |
self._qwen_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
|
| 216 |
self._qwen_model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=True)
|
|
|
|
| 262 |
def _run_bart(self, query: str) -> str:
|
| 263 |
if self._bart_model is None:
|
| 264 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 265 |
+
model_id = self._bart_model_dir
|
| 266 |
+
if not _has_model_weights(model_id):
|
| 267 |
+
raise ValueError(f"BART model weights not found in {model_id}")
|
| 268 |
logger.info("Loading BART model (first request)...")
|
| 269 |
self._bart_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
|
| 270 |
self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, local_files_only=True)
|
|
|
|
| 320 |
raw_text = str(exc)
|
| 321 |
|
| 322 |
if not plot_args:
|
| 323 |
+
logger.warning("Falling back to heuristic plot args")
|
| 324 |
+
plot_args = _validate(_heuristic_plot_args(query, columns, dtypes), columns) or default_args
|
| 325 |
|
| 326 |
try:
|
| 327 |
chart_result = self.chart_generator.generate_chart(plot_args)
|