"""
Gradio App for Chart Generation using LLM Agents
Deployable on HuggingFace Spaces
"""
|
|
|
|
|
import json
import os
import re
import tempfile
import textwrap
import traceback
from pathlib import Path

import gradio as gr

import utils
|
|
|
|
|
|
|
|
def generate_chart_code(instruction: str, model: str, out_path_v1: str) -> str:
    """Ask an LLM for matplotlib code that draws the chart described by *instruction*.

    Args:
        instruction: Natural-language description of the chart to draw.
        model: Model identifier forwarded to ``utils.get_response``.
        out_path_v1: Path the generated code is told to save the figure to.

    Returns:
        The raw model response. The prompt asks for Python code wrapped in
        ``<execute_python>`` tags, but the response is returned unvalidated.
    """
    # The prompt pins the output format, the DataFrame schema and the save
    # path so the caller can extract and execute the code without editing it.
    prompt = f"""
    You are a data visualization expert.

    Return your answer *strictly* in this format:

    <execute_python>
    # valid python code here
    </execute_python>

    Do not add explanations, only the tags and the code.

    The code should create a visualization from a DataFrame 'df' with these columns:
    - date (M/D/YY)
    - time (HH:MM)
    - cash_type (card or cash)
    - card (string)
    - price (number)
    - coffee_name (string)
    - quarter (1-4)
    - month (1-12)
    - year (YYYY)

    User instruction: {instruction}

    Requirements for the code:
    1. Assume the DataFrame is already loaded as 'df'.
    2. Use matplotlib for plotting.
    3. Add clear title, axis labels, and legend if needed.
    4. Save the figure as '{out_path_v1}' with dpi=300.
    5. Do not call plt.show().
    6. Close all plots with plt.close().
    7. Add all necessary import python statements

    Return ONLY the code wrapped in <execute_python> tags.
    """
    return utils.get_response(model, prompt)
|
|
|
|
|
|
|
|
# Matches the tagged code section of a model response.
_REFLECTION_CODE_RE = re.compile(r"<execute_python>([\s\S]*?)</execute_python>")


def _first_json_object(text: str):
    """Return the first JSON object that decodes from *text*, or None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one complete value starting at `start`, so
            # nested braces and braces inside JSON strings are handled.
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return candidate
    return None


def _extract_feedback(content: str) -> str:
    """Pull the "feedback" string out of a reflection response.

    Tries the first line as JSON (the format the prompt requests), then falls
    back to scanning the text outside the code block for a JSON object. When
    nothing parses, a "Failed to ..." diagnostic is returned instead.
    """
    text = content.strip()
    first_line = text.splitlines()[0].strip() if text else ""

    try:
        parsed = json.loads(first_line)
    except json.JSONDecodeError as exc:
        parsed, first_error = None, str(exc)
    else:
        first_error = "first line is not a JSON object"

    if not isinstance(parsed, dict):
        # Drop the code block first so braces in the Python code are never
        # mistaken for the feedback object.
        prose = _REFLECTION_CODE_RE.sub("", text)
        parsed = _first_json_object(prose)
        if parsed is None:
            problem = "parse" if "{" in prose else "find"
            return f"Failed to {problem} JSON: {first_error}"

    return str(parsed.get("feedback", "")).strip()


def reflect_on_image_and_regenerate(
    chart_path: str,
    instruction: str,
    model_name: str,
    out_path_v2: str,
    code_v1: str,
) -> tuple[str, str]:
    """Critique a rendered chart and obtain refined matplotlib code.

    The chart image, the original code and the instruction are sent to a
    vision-capable model, which is asked for one line of JSON feedback
    followed by improved code in ``<execute_python>`` tags.

    Args:
        chart_path: Path of the V1 chart image to critique.
        instruction: The user's original chart instruction.
        model_name: Model identifier; names containing "claude" or
            "anthropic" go through ``utils.image_anthropic_call``, everything
            else through ``utils.image_openai_call``.
        out_path_v2: Path the refined code is told to save the new figure to.
        code_v1: The first-pass code, included in the prompt for context.

    Returns:
        ``(feedback, refined_code_with_tags)``. ``feedback`` is a diagnostic
        string starting with "Failed to" when no JSON could be recovered. The
        code is passed through ``utils.ensure_execute_python_tags``; its body
        is empty when the response contained no tagged block.
    """
    media_type, b64 = utils.encode_image_b64(chart_path)

    prompt = f"""
    You are a data visualization expert.
    Your task: critique the attached chart and the original code against the given instruction,
    then return improved matplotlib code.

    Original code (for context):
    {code_v1}

    OUTPUT FORMAT (STRICT):
    1) First line: a valid JSON object with ONLY the "feedback" field.
    Example: {{"feedback": "The legend is unclear and the axis labels overlap."}}

    2) After a newline, output ONLY the refined Python code wrapped in:
    <execute_python>
    ...
    </execute_python>

    3) Import all necessary libraries in the code. Don't assume any imports from the original code.

    HARD CONSTRAINTS:
    - Do NOT include Markdown, backticks, or any extra prose outside the two parts above.
    - Use pandas/matplotlib only (no seaborn).
    - Assume df already exists; do not read from files.
    - Save to '{out_path_v2}' with dpi=300.
    - Always call plt.close() at the end (no plt.show()).
    - Include all necessary import statements.

    Schema (columns available in df):
    - date (M/D/YY)
    - time (HH:MM)
    - cash_type (card or cash)
    - card (string)
    - price (number)
    - coffee_name (string)
    - quarter (1-4)
    - month (1-12)
    - year (YYYY)

    Instruction:
    {instruction}
    """

    # Pick the provider-specific image call from the model name.
    lower = model_name.lower()
    if "claude" in lower or "anthropic" in lower:
        content = utils.image_anthropic_call(model_name, prompt, media_type, b64)
    else:
        content = utils.image_openai_call(model_name, prompt, media_type, b64)

    # Guard against a provider returning no text at all.
    content = content or ""

    feedback = _extract_feedback(content)

    m_code = _REFLECTION_CODE_RE.search(content)
    refined_code_body = m_code.group(1).strip() if m_code else ""
    refined_code = utils.ensure_execute_python_tags(refined_code_body)

    return feedback, refined_code
|
|
|
|
|
|
|
|
# Matches the tagged code section of a model response.
_EXECUTE_TAG_RE = re.compile(r"<execute_python>([\s\S]*?)</execute_python>")


def _extract_tagged_code(response: str) -> "str | None":
    """Return the code inside ``<execute_python>`` tags, or None if absent."""
    found = _EXECUTE_TAG_RE.search(response or "")
    if found is None:
        return None
    # Dedent before stripping: a uniformly indented block would otherwise
    # keep its indentation on every line but the first and fail to compile.
    return textwrap.dedent(found.group(1)).strip()


def _exec_chart_code(code: str, df) -> "str | None":
    """Run chart code with ``df`` in scope; return the error text or None."""
    # SECURITY: this executes model-generated code with the full privileges
    # of the app process. Only deploy where that is acceptable (sandboxed
    # Space, trusted users); the instruction text is user-controlled.
    try:
        exec(code, {"df": df})
    except Exception as exc:  # generated code can fail in arbitrary ways
        return str(exc)
    return None


def _workflow_result(
    status,
    chart_v1=None,
    code_v1=None,
    feedback=None,
    code_v2=None,
    chart_v2=None,
):
    """Pack values in the order the Gradio click handler's outputs expect."""
    return (chart_v1, code_v1, feedback, code_v2, chart_v2, status)


def run_workflow(
    user_instructions: str,
    generation_model: str,
    reflection_model: str,
    progress=gr.Progress(),
):
    """Run the generate -> execute -> reflect -> execute chart pipeline.

    Args:
        user_instructions: Natural-language chart request.
        generation_model: Model used to write the first-pass (V1) code.
        reflection_model: Model used to critique V1 and write V2 code.
        progress: Gradio progress tracker, updated at each stage.

    Returns:
        A 6-tuple ``(chart_v1_path, code_v1, feedback, code_v2,
        chart_v2_path, status_message)``. Entries not produced before a
        failure are ``None``; the status message then describes the failure.
        Exceptions are not propagated; they are reported through the status.
    """
    try:
        csv_path = "coffee_sales_local.csv"
        if not os.path.exists(csv_path):
            return _workflow_result(
                f"Error: CSV file '{csv_path}' not found. Please ensure the file exists."
            )

        progress(0.1, desc="Loading dataset...")
        df = utils.load_and_prepare_data(csv_path)

        # Gradio reads the returned image paths only after this function has
        # returned, so the charts must outlive it. A per-run directory also
        # keeps concurrent requests from overwriting each other's files. It
        # is intentionally not removed here.
        out_dir = tempfile.mkdtemp(prefix="chart_run_")
        out_v1 = os.path.join(out_dir, "chart_v1.png")
        out_v2 = os.path.join(out_dir, "chart_v2.png")

        progress(0.2, desc="Generating initial chart code (V1)...")
        code_v1 = generate_chart_code(
            instruction=user_instructions,
            model=generation_model,
            out_path_v1=out_v1,
        )

        progress(0.4, desc="Executing V1 code...")
        initial_code = _extract_tagged_code(code_v1)
        if initial_code is None:
            return _workflow_result(
                "Error: Could not extract code from V1 response. No <execute_python> tags found."
            )
        error = _exec_chart_code(initial_code, df)
        if error is not None:
            return _workflow_result(
                f"Error executing V1 code: {error}\n\nCode:\n{initial_code}"
            )
        if not os.path.exists(out_v1):
            return _workflow_result(
                f"Error: Chart V1 was not generated. Check if the code saves to '{out_v1}'."
            )

        progress(0.6, desc="Reflecting on V1 and generating improvements...")
        feedback, code_v2 = reflect_on_image_and_regenerate(
            chart_path=out_v1,
            instruction=user_instructions,
            model_name=reflection_model,
            out_path_v2=out_v2,
            code_v1=code_v1,
        )

        # From here on V1 succeeded, so every outcome still shows V1 together
        # with the reflection output, even when V2 fails.
        so_far = {
            "chart_v1": out_v1,
            "code_v1": code_v1,
            "feedback": feedback,
            "code_v2": code_v2,
        }

        progress(0.8, desc="Executing improved chart code (V2)...")
        reflected_code = _extract_tagged_code(code_v2)
        if reflected_code is None:
            return _workflow_result(
                "Error: Could not extract code from V2 response. No <execute_python> tags found.",
                **so_far,
            )
        error = _exec_chart_code(reflected_code, df)
        if error is not None:
            return _workflow_result(
                f"Error executing V2 code: {error}\n\nCode:\n{reflected_code}",
                **so_far,
            )
        if not os.path.exists(out_v2):
            return _workflow_result(
                f"Error: Chart V2 was not generated. Check if the code saves to '{out_v2}'.",
                **so_far,
            )

        progress(1.0, desc="Complete!")
        return _workflow_result(
            "✅ Chart generation complete!",
            chart_v2=out_v2,
            **so_far,
        )

    except ValueError as e:
        # Treated as a configuration problem; presumably raised by the utils
        # helpers when an API key is missing (verify against utils).
        error_msg = f"❌ Configuration Error: {str(e)}\n\nPlease check your API keys in HuggingFace Spaces settings."
        return _workflow_result(error_msg)
    except Exception as e:
        error_type = type(e).__name__
        error_msg = f"❌ Error ({error_type}): {str(e)}\n\n"

        # Add a hint for the two most common failure classes.
        if "API" in error_type or "Connection" in error_type or "Illegal header" in str(e):
            error_msg += "💡 Tip: Check your API key in HuggingFace Spaces settings. Make sure there are no extra spaces or newlines."
        elif "model" in str(e).lower():
            error_msg += "💡 Tip: The selected model might not be available. Try a different model."

        # The full traceback is only shown when DEBUG=true is set.
        if os.getenv("DEBUG", "false").lower() == "true":
            error_msg += f"\n\nFull traceback:\n{traceback.format_exc()}"

        return _workflow_result(error_msg)
|
|
|
|
|
|
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI and wire the button to ``run_workflow``.

    Returns:
        The configured ``gr.Blocks`` app; the caller is expected to launch it.
    """
    with gr.Blocks(title="Chart Generation with LLM Agents", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # 📊 Chart Generation with LLM Agents

            This app uses **LLM Agents with Reflection Pattern** to generate and improve data visualizations.

            **How it works:**
            1. Enter your chart instruction (e.g., "Create a plot comparing Q1 coffee sales in 2024 and 2025")
            2. The LLM generates initial chart code (V1)
            3. The system reflects on V1 and generates improved code (V2)
            4. Both charts are displayed for comparison

            **Dataset:** Coffee sales data with columns: date, time, cash_type, card, price, coffee_name, quarter, month, year
            """
        )

        # Top row: inputs on the left, status on the right.
        with gr.Row():
            with gr.Column(scale=2):
                instruction_input = gr.Textbox(
                    label="Chart Instruction",
                    placeholder="Create a plot comparing Q1 coffee sales in 2024 and 2025 using the data in coffee_sales.csv.",
                    lines=3,
                    value="Create a plot comparing Q1 coffee sales in 2024 and 2025 using the data in coffee_sales.csv.",
                )

                with gr.Row():
                    generation_model = gr.Dropdown(
                        label="Generation Model (for V1)",
                        choices=[
                            "gpt-4o-mini",
                            "gpt-4o",
                            "o1-mini",
                            "o1-preview",
                            "claude-3-5-sonnet-20241022",
                            "claude-3-opus-20240229",
                        ],
                        value="gpt-4o-mini",
                    )

                    # NOTE(review): reflection sends the V1 chart image to the
                    # selected model; confirm every choice listed here accepts
                    # image input.
                    reflection_model = gr.Dropdown(
                        label="Reflection Model (for V2)",
                        choices=[
                            "o1-mini",
                            "o1-preview",
                            "gpt-4o",
                            "gpt-4o-mini",
                            "claude-3-5-sonnet-20241022",
                            "claude-3-opus-20240229",
                        ],
                        value="gpt-4o-mini",
                    )

                generate_btn = gr.Button("Generate Charts", variant="primary", size="lg")

            with gr.Column(scale=1):
                status_output = gr.Textbox(
                    label="Status",
                    interactive=False,
                    value="Ready to generate charts...",
                )

        # Result row: V1 and V2 side by side, each with its chart and code.
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 📈 Chart V1 (Initial)")
                chart_v1_output = gr.Image(label="Generated Chart V1", type="filepath")
                code_v1_output = gr.Code(
                    label="Code V1",
                    language="python",
                    interactive=False,
                )

            with gr.Column():
                gr.Markdown("### ✨ Chart V2 (Improved)")
                chart_v2_output = gr.Image(label="Generated Chart V2", type="filepath")
                code_v2_output = gr.Code(
                    label="Code V2",
                    language="python",
                    interactive=False,
                )

        feedback_output = gr.Textbox(
            label="📝 Reflection Feedback",
            lines=5,
            interactive=False,
            value="",
        )

        # The output order must match the tuple returned by run_workflow.
        generate_btn.click(
            fn=run_workflow,
            inputs=[instruction_input, generation_model, reflection_model],
            outputs=[
                chart_v1_output,
                code_v1_output,
                feedback_output,
                code_v2_output,
                chart_v2_output,
                status_output,
            ],
        )

        gr.Markdown(
            """
            ---
            ### 💡 Tips:
            - Be specific in your instructions (mention time periods, chart types, etc.)
            - Use a faster model for generation (V1) and a stronger model for reflection (V2)
            - The reflection model analyzes the V1 chart image and suggests improvements
            """
        )

    return demo
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Warn up front when neither provider key is set; the app still starts so
    # the UI is reachable, but generation requests are expected to fail.
    if not os.getenv("OPENAI_API_KEY") and not os.getenv("ANTHROPIC_API_KEY"):
        print("⚠️ Warning: No API keys found. Please set OPENAI_API_KEY or ANTHROPIC_API_KEY")
        print(" For HuggingFace Spaces, add them as secrets in the Space settings")

    demo = create_interface()
    # Listen on all interfaces at port 7860, without a public share link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
|
|
|
|
|
|