Omartificial-Intelligence-Space committed on
Commit
be50698
·
verified ·
1 Parent(s): 877d8a2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +422 -0
app.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio App for Chart Generation using LLM Agents
3
+ Deployable on HuggingFace Spaces
4
+ """
5
+
6
+ import re
7
+ import json
8
+ import os
9
+ import tempfile
10
+ from pathlib import Path
11
+ import gradio as gr
12
+ import utils
13
+
14
+ # Chart generation functions
15
def generate_chart_code(instruction: str, model: str, out_path_v1: str) -> str:
    """Ask an LLM for matplotlib code that draws the requested chart.

    Args:
        instruction: The user's natural-language description of the chart.
        model: Model identifier forwarded to ``utils.get_response``.
        out_path_v1: Path the generated code is told to save the figure to.

    Returns:
        Whatever ``utils.get_response`` returns for the prompt; the prompt
        asks for Python code wrapped in ``<execute_python>`` tags.
    """
    # The prompt pins the output format, the DataFrame schema and the
    # save/close requirements so the reply can be extracted and executed.
    chart_prompt = f"""
You are a data visualization expert.

Return your answer *strictly* in this format:

<execute_python>
# valid python code here
</execute_python>

Do not add explanations, only the tags and the code.

The code should create a visualization from a DataFrame 'df' with these columns:
- date (M/D/YY)
- time (HH:MM)
- cash_type (card or cash)
- card (string)
- price (number)
- coffee_name (string)
- quarter (1-4)
- month (1-12)
- year (YYYY)

User instruction: {instruction}

Requirements for the code:
1. Assume the DataFrame is already loaded as 'df'.
2. Use matplotlib for plotting.
3. Add clear title, axis labels, and legend if needed.
4. Save the figure as '{out_path_v1}' with dpi=300.
5. Do not call plt.show().
6. Close all plots with plt.close().
7. Add all necessary import python statements

Return ONLY the code wrapped in <execute_python> tags.
"""
    return utils.get_response(model, chart_prompt)
54
+
55
+
56
def reflect_on_image_and_regenerate(
    chart_path: str,
    instruction: str,
    model_name: str,
    out_path_v2: str,
    code_v1: str,
) -> tuple[str, str]:
    """Critique a rendered chart and request improved matplotlib code.

    The chart image at ``chart_path`` is base64-encoded and sent, together
    with the original code and the instruction, to a vision-capable model.

    Args:
        chart_path: Path of the V1 chart image to critique.
        instruction: The user's original chart instruction.
        model_name: Model identifier; names containing "claude" or
            "anthropic" are routed to ``utils.image_anthropic_call``,
            everything else to ``utils.image_openai_call``.
        out_path_v2: Path the refined code is told to save its figure to.
        code_v1: The first-round code, included in the prompt for context.

    Returns:
        ``(feedback, refined_code_with_tags)``. ``feedback`` is the text of
        the "feedback" field, or a "Failed to ..." message when no JSON
        object could be parsed from the response.
    """
    media_type, b64 = utils.encode_image_b64(chart_path)

    prompt = f"""
You are a data visualization expert.
Your task: critique the attached chart and the original code against the given instruction,
then return improved matplotlib code.

Original code (for context):
{code_v1}

OUTPUT FORMAT (STRICT):
1) First line: a valid JSON object with ONLY the "feedback" field.
Example: {{"feedback": "The legend is unclear and the axis labels overlap."}}

2) After a newline, output ONLY the refined Python code wrapped in:
<execute_python>
...
</execute_python>

3) Import all necessary libraries in the code. Don't assume any imports from the original code.

HARD CONSTRAINTS:
- Do NOT include Markdown, backticks, or any extra prose outside the two parts above.
- Use pandas/matplotlib only (no seaborn).
- Assume df already exists; do not read from files.
- Save to '{out_path_v2}' with dpi=300.
- Always call plt.close() at the end (no plt.show()).
- Include all necessary import statements.

Schema (columns available in df):
- date (M/D/YY)
- time (HH:MM)
- cash_type (card or cash)
- card (string)
- price (number)
- coffee_name (string)
- quarter (1-4)
- month (1-12)
- year (YYYY)

Instruction:
{instruction}
"""

    # Handle different model providers
    lower = model_name.lower()
    if "claude" in lower or "anthropic" in lower:
        content = utils.image_anthropic_call(model_name, prompt, media_type, b64)
    else:
        content = utils.image_openai_call(model_name, prompt, media_type, b64)

    feedback = _parse_reflection_feedback(content)

    # Extract refined code from <execute_python>...</execute_python>
    m_code = re.search(r"<execute_python>([\s\S]*?)</execute_python>", content)
    refined_code_body = m_code.group(1).strip() if m_code else ""
    refined_code = utils.ensure_execute_python_tags(refined_code_body)

    return feedback, refined_code


def _parse_reflection_feedback(content: str) -> str:
    """Pull the "feedback" text out of a reflection response.

    The first line is tried as JSON; if it is not a JSON object, the first
    brace-delimited span anywhere in ``content`` is tried instead. When both
    attempts fail, a "Failed to ..." message describing the first-line error
    is returned rather than raising.
    """
    lines = content.strip().splitlines()
    json_line = lines[0].strip() if lines else ""

    try:
        obj = json.loads(json_line)
        # A bare JSON string/number/list parses fine but has no .get();
        # treat it like a parse failure so the fallback below runs.
        if not isinstance(obj, dict):
            raise ValueError("first line is not a JSON object")
    except ValueError as e:  # json.JSONDecodeError is a ValueError
        # Fallback: try to capture the first {...} in all the content
        m_json = re.search(r"\{.*?\}", content, flags=re.DOTALL)
        if m_json is None:
            obj = {"feedback": f"Failed to find JSON: {e}"}
        else:
            try:
                obj = json.loads(m_json.group(0))
            except ValueError:
                obj = {"feedback": f"Failed to parse JSON: {e}"}

    return str(obj.get("feedback", "")).strip()
143
+
144
+
145
def run_workflow(
    user_instructions: str,
    generation_model: str,
    reflection_model: str,
    progress=gr.Progress(),
):
    """Run the generate -> execute -> reflect -> execute chart pipeline.

    Args:
        user_instructions: The chart request typed by the user.
        generation_model: Model used to write the first-round (V1) code.
        reflection_model: Model used to critique V1 and write V2 code.
        progress: Gradio progress tracker, called with a fraction and a
            description at each stage.

    Returns:
        A 6-tuple ``(chart_v1_path, code_v1, feedback, code_v2,
        chart_v2_path, status_message)``. Entries that could not be
        produced are ``None``; the status message carries the error text.
        Exceptions are not propagated: any unexpected error is reported
        through the status message together with its traceback.
    """
    import shutil
    import traceback

    try:
        # Use the CSV file in the same directory
        csv_path = "coffee_sales_local.csv"
        if not os.path.exists(csv_path):
            return (
                None,
                None,
                None,
                None,
                None,
                f"Error: CSV file '{csv_path}' not found. Please ensure the file exists.",
            )

        progress(0.1, desc="Loading dataset...")
        df = utils.load_and_prepare_data(csv_path)

        # Charts are rendered into a scratch directory that is deleted when
        # this block exits, so every path handed back to Gradio must first
        # be copied out of it.
        with tempfile.TemporaryDirectory() as temp_dir:
            out_v1 = os.path.join(temp_dir, "chart_v1.png")
            out_v2 = os.path.join(temp_dir, "chart_v2.png")

            # Step 1: Generate V1 code
            progress(0.2, desc="Generating initial chart code (V1)...")
            code_v1 = generate_chart_code(
                instruction=user_instructions,
                model=generation_model,
                out_path_v1=out_v1,
            )

            # Step 2: Execute V1
            progress(0.4, desc="Executing V1 code...")
            initial_code = _extract_execute_python(code_v1)
            if initial_code is None:
                return (
                    None,
                    None,
                    None,
                    None,
                    None,
                    "Error: Could not extract code from V1 response. No <execute_python> tags found.",
                )
            try:
                # SECURITY: this runs model-generated code with the full
                # privileges of the app process; sandbox it before exposing
                # the app to untrusted users.
                exec(initial_code, {"df": df})
            except Exception as e:
                return (
                    None,
                    None,
                    None,
                    None,
                    None,
                    f"Error executing V1 code: {str(e)}\n\nCode:\n{initial_code}",
                )

            if not os.path.exists(out_v1):
                return (
                    None,
                    None,
                    None,
                    None,
                    None,
                    f"Error: Chart V1 was not generated. Check if the code saves to '{out_v1}'.",
                )

            # Copy V1 to its permanent location right away: the V2 failure
            # paths below still return the V1 chart, and a path inside
            # temp_dir would no longer exist once this function returns.
            final_v1 = shutil.copy(out_v1, "chart_v1.png")

            # Step 3: Reflect and generate V2
            progress(0.6, desc="Reflecting on V1 and generating improvements...")
            feedback, code_v2 = reflect_on_image_and_regenerate(
                chart_path=out_v1,
                instruction=user_instructions,
                model_name=reflection_model,
                out_path_v2=out_v2,
                code_v1=code_v1,
            )

            # Step 4: Execute V2
            progress(0.8, desc="Executing improved chart code (V2)...")
            reflected_code = _extract_execute_python(code_v2)
            if reflected_code is None:
                return (
                    final_v1,
                    code_v1,
                    None,
                    None,
                    None,
                    "Error: Could not extract code from V2 response. No <execute_python> tags found.",
                )
            try:
                # SECURITY: same caveat as for V1 - untrusted generated code.
                exec(reflected_code, {"df": df})
            except Exception as e:
                return (
                    final_v1,
                    code_v1,
                    None,
                    None,
                    None,
                    f"Error executing V2 code: {str(e)}\n\nCode:\n{reflected_code}",
                )

            if not os.path.exists(out_v2):
                return (
                    final_v1,
                    code_v1,
                    feedback,
                    code_v2,
                    None,
                    f"Error: Chart V2 was not generated. Check if the code saves to '{out_v2}'.",
                )

            progress(1.0, desc="Complete!")

            # Copy V2 to a permanent location (Gradio needs accessible paths)
            final_v2 = shutil.copy(out_v2, "chart_v2.png")

            return (
                final_v1,
                code_v1,
                feedback,
                code_v2,
                final_v2,
                "✅ Chart generation complete!",
            )

    except Exception as e:
        # Top-level boundary for the UI callback: report instead of raising.
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        return (None, None, None, None, None, error_msg)


def _extract_execute_python(response):
    """Return the stripped code inside the first <execute_python> block, or None if there is none."""
    found = re.search(r"<execute_python>([\s\S]*?)</execute_python>", response)
    return found.group(1).strip() if found else None
290
+
291
+
292
+ # Gradio Interface
293
def create_interface():
    """Build the Gradio Blocks UI and wire the button to ``run_workflow``.

    The layout has an instruction box with two model dropdowns and a
    status box, side-by-side V1/V2 chart and code panes, and a feedback
    box. Clicking "Generate Charts" calls ``run_workflow`` and fills the
    outputs in the order that function returns them.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """

    with gr.Blocks(title="Chart Generation with LLM Agents", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
# 📊 Chart Generation with LLM Agents

This app uses **LLM Agents with Reflection Pattern** to generate and improve data visualizations.

**How it works:**
1. Enter your chart instruction (e.g., "Create a plot comparing Q1 coffee sales in 2024 and 2025")
2. The LLM generates initial chart code (V1)
3. The system reflects on V1 and generates improved code (V2)
4. Both charts are displayed for comparison

**Dataset:** Coffee sales data with columns: date, time, cash_type, card, price, coffee_name, quarter, month, year
"""
        )

        with gr.Row():
            with gr.Column(scale=2):
                instruction_input = gr.Textbox(
                    label="Chart Instruction",
                    placeholder="Create a plot comparing Q1 coffee sales in 2024 and 2025 using the data in coffee_sales.csv.",
                    lines=3,
                    value="Create a plot comparing Q1 coffee sales in 2024 and 2025 using the data in coffee_sales.csv.",
                )

                with gr.Row():
                    generation_model = gr.Dropdown(
                        label="Generation Model (for V1)",
                        choices=[
                            "gpt-4o-mini",
                            "gpt-4o",
                            "o1-mini",
                            "o1-preview",
                            "claude-3-5-sonnet-20241022",
                            "claude-3-opus-20240229",
                        ],
                        value="gpt-4o-mini",
                    )

                    # NOTE(review): the reflection step sends the V1 chart
                    # image to the selected model - confirm that every model
                    # listed here (including the default) accepts image input.
                    reflection_model = gr.Dropdown(
                        label="Reflection Model (for V2)",
                        choices=[
                            "o1-mini",
                            "o1-preview",
                            "gpt-4o",
                            "gpt-4o-mini",
                            "claude-3-5-sonnet-20241022",
                            "claude-3-opus-20240229",
                        ],
                        value="o1-mini",
                    )

                generate_btn = gr.Button("Generate Charts", variant="primary", size="lg")

            with gr.Column(scale=1):
                status_output = gr.Textbox(
                    label="Status",
                    interactive=False,
                    value="Ready to generate charts...",
                )

        with gr.Row():
            with gr.Column():
                gr.Markdown("### 📈 Chart V1 (Initial)")
                chart_v1_output = gr.Image(label="Generated Chart V1", type="filepath")
                code_v1_output = gr.Code(
                    label="Code V1",
                    language="python",
                    interactive=False,
                )

            with gr.Column():
                gr.Markdown("### ✨ Chart V2 (Improved)")
                chart_v2_output = gr.Image(label="Generated Chart V2", type="filepath")
                code_v2_output = gr.Code(
                    label="Code V2",
                    language="python",
                    interactive=False,
                )

        feedback_output = gr.Textbox(
            label="📝 Reflection Feedback",
            lines=5,
            interactive=False,
            value="",
        )

        # Connect the workflow; the outputs list must stay in the same
        # order as the tuple returned by run_workflow.
        generate_btn.click(
            fn=run_workflow,
            inputs=[instruction_input, generation_model, reflection_model],
            outputs=[
                chart_v1_output,
                code_v1_output,
                feedback_output,
                code_v2_output,
                chart_v2_output,
                status_output,
            ],
        )

        gr.Markdown(
            """
---
### 💡 Tips:
- Be specific in your instructions (mention time periods, chart types, etc.)
- Use a faster model for generation (V1) and a stronger model for reflection (V2)
- The reflection model analyzes the V1 chart image and suggests improvements
"""
        )

    return demo
409
+
410
+
411
if __name__ == "__main__":
    # Script entry point: warn when neither provider key is configured,
    # then build the UI and start the server.
    if not os.getenv("OPENAI_API_KEY") and not os.getenv("ANTHROPIC_API_KEY"):
        print("⚠️ Warning: No API keys found. Please set OPENAI_API_KEY or ANTHROPIC_API_KEY")
        print(" For HuggingFace Spaces, add them as secrets in the Space settings")

    demo = create_interface()
    # NOTE(review): share=True requests a public share link when run
    # locally; this app executes model-generated code, so confirm that
    # public exposure is intended.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )