Priyansh Saxena committed on
Commit
f3fd40f
·
1 Parent(s): 962831e

fix: remove runtime model dependency and repair chart generation

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. chart_generator.py +8 -10
  3. llm_agent.py +107 -5
app.py CHANGED
@@ -17,7 +17,7 @@ logging.getLogger('plotly').setLevel(logging.WARNING)
17
 
18
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19
 
20
- app = Flask(__name__, static_folder=os.path.join(BASE_DIR, '..', 'static'))
21
 
22
  CORS(app, origins=[
23
  "https://llm-integrated-excel-plotter-app.vercel.app",
@@ -27,7 +27,7 @@ CORS(app, origins=[
27
 
28
  agent = LLM_Agent()
29
 
30
- UPLOAD_FOLDER = os.path.join(BASE_DIR, '..', 'data', 'uploads')
31
  ALLOWED_EXTENSIONS = {'csv', 'xls', 'xlsx'}
32
  MAX_UPLOAD_BYTES = 10 * 1024 * 1024 # 10 MB
33
 
 
17
 
18
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19
 
20
+ app = Flask(__name__, static_folder=os.path.join(BASE_DIR, 'static'))
21
 
22
  CORS(app, origins=[
23
  "https://llm-integrated-excel-plotter-app.vercel.app",
 
27
 
28
  agent = LLM_Agent()
29
 
30
+ UPLOAD_FOLDER = os.path.join(BASE_DIR, 'data', 'uploads')
31
  ALLOWED_EXTENSIONS = {'csv', 'xls', 'xlsx'}
32
  MAX_UPLOAD_BYTES = 10 * 1024 * 1024 # 10 MB
33
 
chart_generator.py CHANGED
@@ -130,7 +130,7 @@ class ChartGenerator:
130
  if chart_type not in ("pie", "histogram", "box") and len(x) > 5:
131
  plt.xticks(rotation=45, ha="right")
132
 
133
- output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static", "images")
134
  os.makedirs(output_dir, exist_ok=True)
135
  filename = f"chart_{uuid.uuid4().hex[:12]}.png"
136
  full_path = os.path.join(output_dir, filename)
@@ -174,14 +174,12 @@ class ChartGenerator:
174
  line=dict(color=c, width=2),
175
  marker=dict(size=6)).to_plotly_json())
176
 
177
- layout = dict(
178
- **_PLOTLY_LAYOUT,
179
- title=dict(
180
- text=f"{chart_type.title()} \u2014 {', '.join(y_cols)} vs {x_col}",
181
- font=dict(size=15, color="#e2e8f0"),
182
- ),
183
- xaxis=dict(**_PLOTLY_LAYOUT["xaxis"], title=x_col),
184
- yaxis=dict(**_PLOTLY_LAYOUT["yaxis"], title=" / ".join(y_cols)),
185
- )
186
 
187
  return {"data": traces, "layout": layout}
 
130
  if chart_type not in ("pie", "histogram", "box") and len(x) > 5:
131
  plt.xticks(rotation=45, ha="right")
132
 
133
+ output_dir = os.path.join(os.path.dirname(__file__), "static", "images")
134
  os.makedirs(output_dir, exist_ok=True)
135
  filename = f"chart_{uuid.uuid4().hex[:12]}.png"
136
  full_path = os.path.join(output_dir, filename)
 
174
  line=dict(color=c, width=2),
175
  marker=dict(size=6)).to_plotly_json())
176
 
177
+ layout = {**_PLOTLY_LAYOUT}
178
+ layout["title"] = {
179
+ "text": f"{chart_type.title()} \u2014 {', '.join(y_cols)} vs {x_col}",
180
+ "font": {"size": 15, "color": "#e2e8f0"},
181
+ }
182
+ layout["xaxis"] = {**_PLOTLY_LAYOUT["xaxis"], "title": x_col}
183
+ layout["yaxis"] = {**_PLOTLY_LAYOUT["yaxis"], "title": " / ".join(y_cols)}
 
 
184
 
185
  return {"data": traces, "layout": layout}
llm_agent.py CHANGED
@@ -1,7 +1,9 @@
1
  import ast
 
2
  import json
3
  import logging
4
  import os
 
5
  import time
6
 
7
  from dotenv import load_dotenv
@@ -17,6 +19,18 @@ logger = logging.getLogger(__name__)
17
  def _model_dir(dirname: str) -> str:
18
  return os.path.join(os.path.dirname(os.path.abspath(__file__)), dirname)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # ---------------------------------------------------------------------------
21
  # Prompt templates
22
  # ---------------------------------------------------------------------------
@@ -88,6 +102,89 @@ def _validate(args: dict, columns: list):
88
  return args
89
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  # ---------------------------------------------------------------------------
92
  # Agent
93
  # ---------------------------------------------------------------------------
@@ -101,16 +198,19 @@ class LLM_Agent:
101
  self._bart_model = None
102
  self._qwen_tokenizer = None
103
  self._qwen_model = None
 
 
104
 
105
  # -- model runners -------------------------------------------------------
106
 
107
  def _run_qwen(self, user_msg: str) -> str:
108
  if self._qwen_model is None:
109
  from transformers import AutoModelForCausalLM, AutoTokenizer
110
- # Prefer a local model path in Spaces to avoid any runtime network dependency.
111
- model_id = os.getenv("QWEN_LOCAL_PATH", "")
112
  if not model_id:
113
  raise ValueError("Qwen local model is not configured in this Space")
 
 
114
  logger.info("Loading Qwen model (first request)...")
115
  self._qwen_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
116
  self._qwen_model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=True)
@@ -162,7 +262,9 @@ class LLM_Agent:
162
  def _run_bart(self, query: str) -> str:
163
  if self._bart_model is None:
164
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
165
- model_id = os.getenv("BART_LOCAL_PATH", _model_dir("fine-tuned-bart-large"))
 
 
166
  logger.info("Loading BART model (first request)...")
167
  self._bart_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
168
  self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, local_files_only=True)
@@ -218,8 +320,8 @@ class LLM_Agent:
218
  raw_text = str(exc)
219
 
220
  if not plot_args:
221
- logger.warning("Falling back to default plot args")
222
- plot_args = default_args
223
 
224
  try:
225
  chart_result = self.chart_generator.generate_chart(plot_args)
 
1
  import ast
2
+ import difflib
3
  import json
4
  import logging
5
  import os
6
+ import re
7
  import time
8
 
9
  from dotenv import load_dotenv
 
19
  def _model_dir(dirname: str) -> str:
20
  return os.path.join(os.path.dirname(os.path.abspath(__file__)), dirname)
21
 
22
+
23
def _has_model_weights(model_dir: str) -> bool:
    """Return True when *model_dir* is a directory that contains model weights.

    Both single-file checkpoints and sharded checkpoints are recognised; a
    sharded checkpoint is detected through its ``*.index.json`` file, since
    the individual shard names vary with the shard count.

    Args:
        model_dir: Path of the candidate model directory. Empty or ``None``
            is treated as "no model configured".

    Returns:
        ``True`` if the directory exists and holds at least one known weight
        (or weight-index) file, ``False`` otherwise.
    """
    weight_files = (
        # Single-file checkpoints.
        "pytorch_model.bin",
        "model.safetensors",
        "tf_model.h5",
        "flax_model.msgpack",
        # Sharded checkpoints ship an index instead of one monolithic file.
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
    )
    # Guard first: os.path.isdir(None) raises TypeError rather than returning False.
    if not model_dir or not os.path.isdir(model_dir):
        return False
    return any(
        os.path.isfile(os.path.join(model_dir, filename)) for filename in weight_files
    )
33
+
34
  # ---------------------------------------------------------------------------
35
  # Prompt templates
36
  # ---------------------------------------------------------------------------
 
102
  return args
103
 
104
 
105
def _pick_chart_type(query: str) -> str:
    """Infer the chart type requested in a free-text *query*.

    Matching is done on whole words (with an optional plural ``s``) so that
    words merely containing a keyword -- "barcelona", "embargo" -- do not
    select a chart type. Two passes are made:

    1. A keyword directly followed by "chart", "plot", "graph" or "diagram"
       ("line chart", "boxplot", "bar-graph") is an explicit request and wins.
    2. Otherwise the first chart type with any stand-alone keyword in the
       query is returned, checked in the order of the alias table.

    Args:
        query: The user's request; ``None`` or empty falls through to the default.

    Returns:
        One of ``"scatter"``, ``"bar"``, ``"pie"``, ``"histogram"``, ``"box"``,
        ``"area"`` or ``"line"``; ``"line"`` when nothing matches.
    """
    lowered = (query or "").lower()
    aliases = {
        "scatter": ["scatter", "scatterplot"],
        "bar": ["bar", "column"],
        "pie": ["pie", "donut"],
        "histogram": ["histogram", "distribution"],
        "box": ["box", "boxplot"],
        "area": ["area"],
        "line": ["line", "trend", "over time", "over the years"],
    }
    alternations = {
        chart_type: "|".join(re.escape(keyword) for keyword in keywords)
        for chart_type, keywords in aliases.items()
    }

    # Pass 1: "<keyword> chart/plot/graph" is unambiguous, so it beats generic
    # words such as "column" or "box" that may describe the data instead.
    for chart_type, alternation in alternations.items():
        explicit = rf"\b(?:{alternation})s?[\s-]*(?:chart|plot|graph|diagram)s?\b"
        if re.search(explicit, lowered):
            return chart_type

    # Pass 2: any stand-alone keyword, in alias-table priority order.
    for chart_type, alternation in alternations.items():
        if re.search(rf"\b(?:{alternation})s?\b", lowered):
            return chart_type

    return "line"
120
+
121
+
122
def _pick_color(query: str):
    """Return the first colour name mentioned in *query*, or ``None``.

    Only whole words are matched, so "reddish" or "bluetooth" do not count.
    When several colours appear, the one that occurs earliest in the query is
    returned, independent of the order of the palette below.

    Args:
        query: The user's request; ``None`` or empty yields ``None``.

    Returns:
        The lower-cased colour word as written in the query, or ``None``.
    """
    colors = (
        "red", "blue", "green", "yellow", "orange", "purple", "pink",
        "black", "white", "gray", "grey", "cyan", "teal", "indigo",
    )
    # One alternation scanned left to right: re.search reports the leftmost
    # match, i.e. the colour the user mentioned first.
    pattern = r"\b(?:" + "|".join(re.escape(color) for color in colors) + r")\b"
    match = re.search(pattern, (query or "").lower())
    return match.group(0) if match else None
132
+
133
+
134
def _pick_columns(query: str, columns: list, dtypes: dict):
    """Guess an x column and a single y column for *query*.

    Every column receives a relevance score: +10 when its lower-cased name
    occurs verbatim in the query, +2 for each query token contained in the
    name, plus a ``difflib`` similarity ratio that acts as a tie-breaker.

    The x column is chosen from, in order: columns whose name contains
    "year", "date" or "month"; columns typed ``"datetime"``; all columns.
    Within each group the most relevant column comes first, so a query naming
    ``end_date`` is not answered with ``start_date``.

    The y column is the most relevant ``"integer"``/``"float"`` column other
    than x; if there is none, the first remaining column in table order.

    Args:
        query: Free-text user request.
        columns: Column names of the table.
        dtypes: Mapping of column name to dtype label; ``None`` is treated as
            an empty mapping.

    Returns:
        ``(x_col, y_cols)``. ``x_col`` is ``None`` when *columns* is empty and
        ``y_cols`` is a list holding at most one column name.
    """
    dtypes = dtypes or {}
    lowered = (query or "").lower()
    query_tokens = re.findall(r"[a-zA-Z0-9_]+", lowered)

    def score_column(column: str) -> float:
        col_lower = column.lower()
        score = 10.0 if col_lower in lowered else 0.0
        score += 2.0 * sum(1 for token in query_tokens if token in col_lower)
        score += difflib.SequenceMatcher(None, lowered, col_lower).ratio()
        return score

    # Stable sort: columns with equal scores keep their table order.
    sorted_columns = sorted(columns, key=score_column, reverse=True)
    numeric_columns = {col for col in columns if dtypes.get(col) in {"integer", "float"}}

    time_hints = ("year", "date", "month")
    year_like = [
        col for col in sorted_columns if any(hint in col.lower() for hint in time_hints)
    ]
    temporal_columns = [col for col in sorted_columns if dtypes.get(col) == "datetime"]

    x_candidates = year_like + temporal_columns + sorted_columns
    x_col = x_candidates[0] if x_candidates else None

    # NOTE(review): when no time-like column exists the top-scored column
    # becomes x even if it is numeric -- confirm this suits the chart generator.
    y_candidates = [
        col for col in sorted_columns if col != x_col and col in numeric_columns
    ]
    if not y_candidates:
        y_candidates = [col for col in columns if col != x_col]

    return x_col, y_candidates[:1]
169
+
170
+
171
def _heuristic_plot_args(query: str, columns: list, dtypes: dict) -> dict:
    """Build plot arguments from *query* without calling any model.

    Columns come from ``_pick_columns``, the chart type from
    ``_pick_chart_type`` and the colour from ``_pick_color``.

    Args:
        query: Free-text user request.
        columns: Column names of the table.
        dtypes: Mapping of column name to dtype label.

    Returns:
        A dict with keys ``"x"`` (column name), ``"y"`` (list of column
        names), ``"chart_type"`` and ``"color"`` (colour name or ``None``).
    """
    x_col, y_cols = _pick_columns(query, columns, dtypes)
    if not x_col:
        # No usable column at all: fall back to a placeholder name.
        # NOTE(review): presumably mirrors the caller's default args -- confirm.
        x_col = "Year"
    if not y_cols:
        # Prefer any column other than x; with a single-column table reuse it.
        others = [col for col in columns if col != x_col]
        y_cols = others[:1] or list(columns[:1])
    elif isinstance(y_cols, str):
        # "y" is always emitted as a list, even for a single column name.
        y_cols = [y_cols]
    return {
        "x": x_col,
        "y": y_cols,
        "chart_type": _pick_chart_type(query),
        "color": _pick_color(query),
    }
186
+
187
+
188
  # ---------------------------------------------------------------------------
189
  # Agent
190
  # ---------------------------------------------------------------------------
 
198
  self._bart_model = None
199
  self._qwen_tokenizer = None
200
  self._qwen_model = None
201
+ self._bart_model_dir = os.getenv("BART_LOCAL_PATH", _model_dir("fine-tuned-bart-large"))
202
+ self._qwen_model_dir = os.getenv("QWEN_LOCAL_PATH", "")
203
 
204
  # -- model runners -------------------------------------------------------
205
 
206
  def _run_qwen(self, user_msg: str) -> str:
207
  if self._qwen_model is None:
208
  from transformers import AutoModelForCausalLM, AutoTokenizer
209
+ model_id = self._qwen_model_dir
 
210
  if not model_id:
211
  raise ValueError("Qwen local model is not configured in this Space")
212
+ if not _has_model_weights(model_id):
213
+ raise ValueError(f"Qwen model weights not found in {model_id}")
214
  logger.info("Loading Qwen model (first request)...")
215
  self._qwen_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
216
  self._qwen_model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=True)
 
262
  def _run_bart(self, query: str) -> str:
263
  if self._bart_model is None:
264
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
265
+ model_id = self._bart_model_dir
266
+ if not _has_model_weights(model_id):
267
+ raise ValueError(f"BART model weights not found in {model_id}")
268
  logger.info("Loading BART model (first request)...")
269
  self._bart_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
270
  self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, local_files_only=True)
 
320
  raw_text = str(exc)
321
 
322
  if not plot_args:
323
+ logger.warning("Falling back to heuristic plot args")
324
+ plot_args = _validate(_heuristic_plot_args(query, columns, dtypes), columns) or default_args
325
 
326
  try:
327
  chart_result = self.chart_generator.generate_chart(plot_args)