| | import os |
| |
|
| | import streamlit as st |
| |
|
| | import pandas as pd |
| |
|
| | import pickle |
| |
|
| | import base64 |
| |
|
| | from io import BytesIO, StringIO |
| |
|
| | import sys |
| |
|
| | import operator |
| |
|
| | from typing import Literal, Sequence, TypedDict, Annotated, List, Dict, Tuple |
| |
|
| | import tempfile |
| |
|
| | import shutil |
| |
|
| | import plotly.io as pio |
| |
|
| | import io |
| |
|
| | import re |
| |
|
| | import json |
| |
|
| | import openai |
| |
|
| | |
| |
|
| | import base64 |
| |
|
| | from datetime import datetime |
| |
|
| | from reportlab.lib.pagesizes import letter |
| |
|
| | from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image |
| |
|
| | from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle |
| |
|
| | from reportlab.lib.units import inch |
| |
|
| | from PIL import Image as PILImage |
| |
|
| | |
| |
|
| | |
| |
|
| | from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage |
| |
|
| | from langchain_core.prompts import ChatPromptTemplate |
| |
|
| | from langchain_openai import ChatOpenAI |
| |
|
| | from langchain_experimental.utilities import PythonREPL |
| |
|
| | from langgraph.prebuilt import ToolInvocation, ToolExecutor |
| |
|
| | from langchain_core.tools import tool |
| |
|
| | from langgraph.prebuilt import InjectedState |
| |
|
| | from langgraph.graph import StateGraph, END |
| |
|
| | from reportlab.platypus import PageBreak |
| |
|
| | from PIL import Image as PILImage |
| |
|
| | |
| |
|
| | |
| |
|
# --- Session-state bootstrap: provider / credentials / model selection ---
# Streamlit reruns the whole script on every interaction, so each default is
# seeded only when its key is not already present in the session.
if 'ai_provider' not in st.session_state:
    st.session_state.ai_provider = "openai"

if 'api_key' not in st.session_state:
    st.session_state.api_key = ""

if 'selected_model' not in st.session_state:
    st.session_state.selected_model = "gpt-4"

# Model choices offered for each provider in the settings UI.
OPENAI_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-4-mini", "gpt-3.5-turbo"]
GROQ_MODELS = ["llama3.3-70b-versatile", "gemma2-9b-it", "llama-3-8b-8192"]

# Per-session scratch directory where agent-generated plotly figures are
# pickled.  NOTE(review): never explicitly cleaned up — relies on OS temp
# cleanup; confirm that is acceptable for long-running deployments.
if 'temp_dir' not in st.session_state:
    st.session_state.temp_dir = tempfile.mkdtemp()
    st.session_state.images_dir = os.path.join(st.session_state.temp_dir, "images/plotly_figures/pickle")
    os.makedirs(st.session_state.images_dir, exist_ok=True)
    print(f"Created temporary directory: {st.session_state.temp_dir}")
    print(f"Created images directory: {st.session_state.images_dir}")
| |
|
| | |
| |
|
| | |
| |
|
# System prompt for the LLM agent: defines its role, the single available
# tool, and hard constraints on generated code (allowed libraries, how plots
# must be returned).  Keep the library whitelist and the `plotly_figures`
# contract in sync with complete_python_task below.
SYSTEM_PROMPT = """## Role
You are a professional data scientist helping a non-technical user understand, analyze, and visualize their data.

## Capabilities
1. **Execute python code** using the `complete_python_task` tool.

## Goals
1. Understand the user's objectives clearly.
2. Take the user on a data analysis journey, iterating to find the best way to visualize or analyse their data to solve their problems.
3. Investigate if the goal is achievable by running Python code via the `python_code` field.
4. Gain input from the user at every step to ensure the analysis is on the right track and to understand business nuances.

## Code Guidelines
- **ALL INPUT DATA IS LOADED ALREADY**, so use the provided variable names to access the data.
- **VARIABLES PERSIST BETWEEN RUNS**, so reuse previously defined variables if needed.
- **TO SEE CODE OUTPUT**, use `print()` statements. You won't be able to see outputs of `pd.head()`, `pd.describe()` etc. otherwise.
- **ONLY USE THE FOLLOWING LIBRARIES**:
    - `pandas`
    - `sklearn` (including all major ML models)
    - `plotly`
    - `numpy`

All these libraries are already imported for you.

## Machine Learning Guidelines
- For regression tasks:
    - Linear Regression: `LinearRegression`
    - Logistic Regression: `LogisticRegression`
    - Ridge Regression: `Ridge`
    - Lasso Regression: `Lasso`
    - Random Forest Regression: `RandomForestRegressor`

- For classification tasks:
    - Logistic Regression: `LogisticRegression`
    - Decision Trees: `DecisionTreeClassifier`
    - Random Forests: `RandomForestClassifier`
    - Support Vector Machines: `SVC`
    - K-Nearest Neighbors: `KNeighborsClassifier`
    - Naive Bayes: `GaussianNB`

- For clustering:
    - K-Means: `KMeans`
    - DBSCAN: `DBSCAN`

- For dimensionality reduction:
    - PCA: `PCA`

- Always preprocess data appropriately:
    - Scale numerical features with `StandardScaler` or `MinMaxScaler`
    - Encode categorical variables with `OneHotEncoder` when needed
    - Handle missing values with `SimpleImputer`

- Always split data into training and testing sets using `train_test_split`
- Evaluate models using appropriate metrics:
    - For regression: `mean_squared_error`, `mean_absolute_error`, `r2_score`
    - For classification: `accuracy_score`, `confusion_matrix`, `classification_report`
    - For clustering: `silhouette_score`

- Consider using `cross_val_score` for more robust evaluation
- Visualize ML results with plotly when possible

## Plotting Guidelines
- Always use the `plotly` library for plotting.
- Store all plotly figures inside a `plotly_figures` list, they will be saved automatically.
- Do not try and show the plots inline with `fig.show()`.
"""
| |
|
| | |
| |
|
| | |
| |
|
class AgentState(TypedDict):
    """Shared state threaded through the LangGraph workflow.

    Fields annotated with ``operator.add`` act as reducers: values returned
    by graph nodes for those keys are concatenated onto the existing value
    instead of replacing it.
    """

    # Full conversation so far; node outputs are appended.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Metadata dicts describing the uploaded datasets (built by process_file_upload).
    input_data: Annotated[List[Dict], operator.add]
    # thought/code/output records from each tool execution (for debugging/UI).
    intermediate_outputs: Annotated[List[dict], operator.add]
    # Variables persisting between code executions (replaced wholesale, not merged).
    current_variables: dict
    # Filenames of pickled plotly figures produced by tool runs.
    output_image_paths: Annotated[List[str], operator.add]
| |
|
| | |
| |
|
| | |
| |
|
# --- Session-state bootstrap: data, chat, and dashboard containers ---
if 'in_memory_datasets' not in st.session_state:
    st.session_state.in_memory_datasets = {}  # variable_name -> DataFrame

if 'persistent_vars' not in st.session_state:
    st.session_state.persistent_vars = {}  # variables surviving between code runs

if 'dataset_metadata_list' not in st.session_state:
    st.session_state.dataset_metadata_list = []  # metadata dicts fed to the agent

if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []  # list of (user_msg, assistant_msg) pairs

if 'dashboard_plots' not in st.session_state:
    st.session_state.dashboard_plots = [None, None, None, None]  # 2x2 dashboard slots

if 'columns' not in st.session_state:
    st.session_state.columns = ["No columns available"]

if 'custom_plots_to_save' not in st.session_state:
    st.session_state.custom_plots_to_save = {}

# NOTE(review): `repl` is not referenced anywhere in the visible code —
# execution goes through exec() inside complete_python_task. Confirm it is
# unused before removing.
repl = PythonREPL()

# Snippet exec()'d after user code runs: pickles every figure collected in
# the `plotly_figures` list to a uniquely named file under `images_dir`.
plotly_saving_code = """import pickle
import uuid
import os
for figure in plotly_figures:
    pickle_filename = f"{images_dir}/{uuid.uuid4()}.pickle"
    with open(pickle_filename, 'wb') as f:
        pickle.dump(figure, f)
"""
| |
|
| | |
| |
|
@tool
def complete_python_task(
    graph_state: Annotated[dict, InjectedState],
    thought: str,
    python_code: str
) -> Tuple[str, dict]:
    """Execute Python code for data analysis and visualization."""
    # NOTE: the docstring above doubles as the tool description shown to the
    # LLM via @tool, so it is kept short and unchanged.
    #
    # Returns (captured_stdout_or_error, state_update_dict).

    current_variables = graph_state.get("current_variables", {})

    # Lazily pull any referenced input datasets out of Streamlit's
    # session-level store into the execution namespace.
    for input_dataset in graph_state.get("input_data", []):
        var_name = input_dataset.get("variable_name")
        if var_name and var_name not in current_variables and var_name in st.session_state.in_memory_datasets:
            print(f"Loading {var_name} from in-memory storage")
            current_variables[var_name] = st.session_state.in_memory_datasets[var_name]
    # Snapshot of existing pickled figures; used after execution to detect
    # files created by this run.
    current_image_pickle_files = os.listdir(st.session_state.images_dir)

    try:
        # Redirect stdout so print() output from the executed code can be
        # captured and returned to the model.
        old_stdout = sys.stdout
        sys.stdout = StringIO()

        # Execution namespace: module globals + variables persisted from
        # earlier runs + the input datasets.
        exec_globals = globals().copy()
        exec_globals.update(st.session_state.persistent_vars)
        exec_globals.update(current_variables)

        import sklearn
        import numpy as np

        # Pre-import the whitelisted sklearn API so generated code can use
        # the names promised in SYSTEM_PROMPT without importing anything.
        from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge, Lasso
        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier
        from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
        from sklearn.svm import SVC, SVR
        from sklearn.naive_bayes import GaussianNB
        from sklearn.decomposition import PCA
        from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
        from sklearn.cluster import KMeans, DBSCAN
        from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
        from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
        from sklearn.metrics import (
            accuracy_score, confusion_matrix, classification_report,
            mean_squared_error, r2_score, mean_absolute_error, silhouette_score
        )
        from sklearn.pipeline import Pipeline
        from sklearn.impute import SimpleImputer

        # Expose the figure buffer, output directory, and the whole sklearn
        # whitelist to the executed code by name.
        exec_globals.update({
            "plotly_figures": [],
            "images_dir": st.session_state.images_dir,
            "np": np,

            "LinearRegression": LinearRegression,
            "LogisticRegression": LogisticRegression,
            "Ridge": Ridge,
            "Lasso": Lasso,

            "DecisionTreeClassifier": DecisionTreeClassifier,
            "DecisionTreeRegressor": DecisionTreeRegressor,
            "RandomForestClassifier": RandomForestClassifier,
            "RandomForestRegressor": RandomForestRegressor,
            "GradientBoostingClassifier": GradientBoostingClassifier,

            "SVC": SVC,
            "SVR": SVR,

            "GaussianNB": GaussianNB,
            "PCA": PCA,
            "KNeighborsClassifier": KNeighborsClassifier,
            "KNeighborsRegressor": KNeighborsRegressor,
            "KMeans": KMeans,
            "DBSCAN": DBSCAN,

            "StandardScaler": StandardScaler,
            "MinMaxScaler": MinMaxScaler,
            "OneHotEncoder": OneHotEncoder,
            "SimpleImputer": SimpleImputer,

            "train_test_split": train_test_split,
            "cross_val_score": cross_val_score,
            "GridSearchCV": GridSearchCV,
            "accuracy_score": accuracy_score,
            "confusion_matrix": confusion_matrix,
            "classification_report": classification_report,
            "mean_squared_error": mean_squared_error,
            "r2_score": r2_score,
            "mean_absolute_error": mean_absolute_error,
            "silhouette_score": silhouette_score,

            "Pipeline": Pipeline
        })

        # SECURITY: executes model-generated code in-process with no sandbox.
        exec(python_code, exec_globals)

        # Persist every name the executed code defined (anything that is not
        # already a module global) so later runs can reuse it.
        st.session_state.persistent_vars.update({k: v for k, v in exec_globals.items() if k not in globals()})

        output = sys.stdout.getvalue()

        # Restore stdout before building the result.  NOTE(review): if
        # getvalue() above ever raised, stdout would stay redirected — a
        # try/finally would be more robust; confirm before changing.
        sys.stdout = old_stdout

        updated_state = {
            "intermediate_outputs": [{"thought": thought, "code": python_code, "output": output}],
            "current_variables": st.session_state.persistent_vars
        }

        # If the code produced figures, pickle each one into images_dir.
        if 'plotly_figures' in exec_globals and exec_globals['plotly_figures']:
            exec(plotly_saving_code, exec_globals)

        # Diff the directory listing to find figures created by this run.
        new_image_folder_contents = os.listdir(st.session_state.images_dir)
        new_image_files = [file for file in new_image_folder_contents if file not in current_image_pickle_files]

        if new_image_files:
            updated_state["output_image_paths"] = new_image_files
            # Reset the figure buffer so old figures are not re-saved later.
            st.session_state.persistent_vars["plotly_figures"] = []
        return output, updated_state

    except Exception as e:
        # Restore stdout, then return the error text to the model so it can
        # attempt to fix its own code.
        sys.stdout = old_stdout
        print(f"Error in complete_python_task: {str(e)}")
        return str(e), {"intermediate_outputs": [{"thought": thought, "code": python_code, "output": str(e)}]}
| |
|
| | |
| |
|
| | |
| |
|
def initialize_llm():
    """Build a chat-model client for the provider selected in session state.

    Returns a ChatOpenAI or ChatGroq instance, or None when no API key is
    configured, the provider is unrecognized, or client construction fails.
    """
    key = st.session_state.api_key
    chosen_model = st.session_state.selected_model

    # No key configured -> nothing to build.
    if not key:
        return None

    try:
        provider = st.session_state.ai_provider
        if provider == "openai":
            # The langchain client reads the key from the environment.
            os.environ["OPENAI_API_KEY"] = key
            return ChatOpenAI(model=chosen_model, temperature=0)
        if provider == "groq":
            os.environ["GROQ_API_KEY"] = key
            # Imported lazily so the groq integration is only required when used.
            from langchain_groq import ChatGroq
            return ChatGroq(model=chosen_model, temperature=0)
    except Exception as exc:
        print(f"Error initializing LLM: {str(exc)}")
        return None
| |
|
| | |
| |
|
| | |
| |
|
# The single tool exposed to the agent, wrapped in a ToolExecutor so
# call_tools can batch-invoke requested tool calls.
tools = [complete_python_task]
tool_executor = ToolExecutor(tools)

# Prompt template: system instructions followed by the running conversation
# (filled in via the "messages" placeholder at invoke time).
chat_template = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    ("placeholder", "{messages}"),
])
| |
|
| | |
| |
|
def create_data_summary(state: AgentState) -> str:
    """Render a textual summary of the data available to the agent.

    Covers each uploaded dataset (name, description, first five rows) and
    then any other non-underscore variables left over from previous code
    executions, marking DataFrames with their shape.
    """
    pieces = []
    described = []

    # Uploaded datasets first: description plus a small sample.
    for meta in state.get("input_data", []):
        name = meta.get("variable_name")
        if not name:
            continue
        described.append(name)
        pieces.append(f"\n\nVariable: {name}\n")
        pieces.append(f"Description: {meta.get('data_description', 'No description')}\n")

        if name in st.session_state.in_memory_datasets:
            frame = st.session_state.in_memory_datasets[name]
            pieces.append("\nSample Data (first 5 rows):\n")
            pieces.append(frame.head(5).to_string())

    # Then every other persisted variable not already covered above.
    if "current_variables" in state:
        for name in state["current_variables"]:
            if name in described or name.startswith("_"):
                continue
            value = state["current_variables"].get(name)
            if isinstance(value, pd.DataFrame):
                pieces.append(f"\n\nVariable: {name} (DataFrame with shape {value.shape})")
            else:
                pieces.append(f"\n\nVariable: {name}")
    return "".join(pieces)
| |
|
| | |
| |
|
def route_to_tools(state: AgentState) -> Literal["tools", "__end__"]:
    """Route the graph after the agent node.

    Returns "tools" when the last message requests at least one tool call,
    otherwise "__end__".  Raises ValueError if the state has no messages.
    """
    messages = state.get("messages", [])
    if not messages:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")

    ai_message = messages[-1]
    wants_tools = hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0
    return "tools" if wants_tools else "__end__"
| |
|
| | |
| |
|
def call_model(state: AgentState):
    """Call the LLM to get a response"""
    # Prepend an ephemeral human message describing the currently available
    # data so the model always sees fresh dataset context.  It is not added
    # to state["messages"]; only its text is recorded in intermediate_outputs.
    current_data_template = """The following data is available:\n{data_summary}"""
    current_data_message = HumanMessage(
        content=current_data_template.format(data_summary=create_data_summary(state))
    )
    messages = [current_data_message] + state["messages"]

    llm = initialize_llm()
    if llm is None:
        # Surface the configuration problem to the user instead of raising.
        return {"messages": [AIMessage(content="Please configure a valid API key and model in the settings tab.")]}

    # Bind the tool schema, then compose prompt -> model into one runnable.
    model = llm.bind_tools(tools)
    model = chat_template | model

    llm_outputs = model.invoke({"messages": messages})
    return {"messages": [llm_outputs], "intermediate_outputs": [current_data_message.content]}
| |
|
| | |
| |
|
def call_tools(state: AgentState):
    """Execute tools called by the LLM"""
    last_message = state["messages"][-1]
    tool_invocations = []

    if isinstance(last_message, AIMessage) and hasattr(last_message, 'tool_calls'):
        # One invocation per requested tool call; the whole graph state is
        # passed along so the tool can read current_variables / input_data.
        tool_invocations = [
            ToolInvocation(
                tool=tool_call["name"],
                tool_input={**tool_call["args"], "graph_state": state}
            ) for tool_call in last_message.tool_calls
        ]
    responses = tool_executor.batch(tool_invocations, return_exceptions=True)

    tool_messages = []
    state_updates = {}

    for tc, response in zip(last_message.tool_calls, responses):
        if isinstance(response, Exception):
            # Tool crashed: report the error back to the model as a
            # ToolMessage rather than aborting the graph run.
            print(f"Exception in tool execution: {str(response)}")
            tool_messages.append(ToolMessage(
                content=f"Error: {str(response)}",
                name=tc["name"],
                tool_call_id=tc["id"]
            ))
            continue

        # Successful tools return (output_text, state_update_dict).
        message, updates = response
        tool_messages.append(ToolMessage(
            content=str(message),
            name=tc["name"],
            tool_call_id=tc["id"]
        ))

        # Merge this tool's state updates into the accumulated updates,
        # extending lists and merging dicts rather than overwriting.
        for key, value in updates.items():
            if key in state_updates:
                if isinstance(value, list) and isinstance(state_updates[key], list):
                    state_updates[key].extend(value)
                elif isinstance(value, dict) and isinstance(state_updates[key], dict):
                    state_updates[key].update(value)
                else:
                    state_updates[key] = value
            else:
                state_updates[key] = value

    if 'messages' not in state_updates:
        state_updates["messages"] = []

    # NOTE(review): this unconditionally replaces any "messages" entry the
    # merge loop may have accumulated (and makes the guard above a no-op) —
    # confirm tool-provided "messages" updates are meant to be discarded.
    state_updates["messages"] = tool_messages
    return state_updates
| |
|
| | |
| |
|
| | |
| |
|
# --- Agent graph wiring: agent -> (tools -> agent)* -> END ---
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", call_tools)
# After each agent turn, either execute the requested tools or finish.
workflow.add_conditional_edges(
    "agent",
    route_to_tools,
    {
        "tools": "tools",
        "__end__": END
    }
)
# Tool results always flow back to the agent for another turn.
workflow.add_edge("tools", "agent")
workflow.set_entry_point("agent")

# Compiled runnable used by the UI to drive a full conversation turn.
chain = workflow.compile()
| |
|
| | |
| |
|
def process_file_upload(files):
    """Process uploaded files and return dataframe previews and column names"""
    # Uploading replaces everything: previously loaded datasets, their
    # metadata, and any variables left over from earlier analysis runs.
    st.session_state.in_memory_datasets = {}
    st.session_state.dataset_metadata_list = []
    st.session_state.persistent_vars.clear()

    if not files:
        return "No files uploaded.", [], ["No columns available"]

    previews = []
    collected_columns = []

    for uploaded in files:
        try:
            # Only CSV and Excel uploads are supported.
            if uploaded.name.endswith('.csv'):
                frame = pd.read_csv(uploaded)
            elif uploaded.name.endswith(('.xls', '.xlsx')):
                frame = pd.read_excel(uploaded)
            else:
                previews.append(f"Unsupported file format: {uploaded.name}. Please upload CSV or Excel files.")
                continue

            # Derive a python-identifier-friendly variable name from the filename.
            var_name = uploaded.name.split('.')[0].replace('-', '_').replace(' ', '_').lower()
            st.session_state.in_memory_datasets[var_name] = frame
            collected_columns.extend(frame.columns.tolist())

            # Metadata record handed to the agent for context.
            st.session_state.dataset_metadata_list.append({
                "variable_name": var_name,
                "data_path": "in_memory",
                "data_description": f"Dataset containing {frame.shape[0]} rows and {frame.shape[1]} columns. Columns: {', '.join(frame.columns.tolist())}",
                "original_filename": uploaded.name
            })

            # Markdown preview shown in the UI.
            header = f"### Dataset: {uploaded.name}\nVariable name: `{var_name}`\n\n"
            previews.append(header + frame.head(10).to_markdown())
            print(f"Successfully processed {uploaded.name}")

        except Exception as e:
            print(f"Error processing {uploaded.name}: {str(e)}")
            previews.append(f"Error processing {uploaded.name}: {str(e)}")

    # Order-preserving de-duplication of column names across all datasets.
    unique_columns = list(dict.fromkeys(collected_columns))
    if not unique_columns:
        unique_columns = ["No columns available"]

    print(f"Found {len(unique_columns)} unique columns across datasets")
    return "\n\n".join(previews), st.session_state.dataset_metadata_list, unique_columns
| |
|
| | |
| |
|
def get_columns():
    """Directly gets columns from in-memory datasets"""
    collected = []
    # Gather every column from every loaded DataFrame.
    for frame in st.session_state.in_memory_datasets.values():
        if isinstance(frame, pd.DataFrame):
            collected.extend(frame.columns.tolist())

    # De-duplicate while keeping first-seen order.
    unique_columns = list(dict.fromkeys(collected))
    if not unique_columns:
        unique_columns = ["No columns available"]

    print(f"Populating dropdowns with {len(unique_columns)} columns")
    return unique_columns
| |
|
| | |
| |
|
| | |
| |
|
| | import openai |
| |
|
| | import pandas as pd |
| |
|
| | import json |
| |
|
| | import re |
| |
|
| | |
| |
|
def standard_clean(df):
    """Apply generic, dataset-agnostic cleanup to *df* in place and return it.

    Steps, in order:
      1. normalize column names (non-word runs -> "_", lowercased),
      2. drop exact duplicate rows,
      3. drop columns, then rows, that are entirely NaN,
      4. strip surrounding whitespace from every object (string) column.
    """
    # 1. Normalize column names.
    normalized = []
    for name in df.columns:
        normalized.append(re.sub(r'\W+', '_', name).strip().lower())
    df.columns = normalized

    # 2-3. Remove duplicate rows and fully-empty columns/rows.
    df.drop_duplicates(inplace=True)
    df.dropna(axis=1, how='all', inplace=True)
    df.dropna(axis=0, how='all', inplace=True)

    # 4. Trim whitespace in string columns.  Note: astype(str) also turns
    # NaN into the literal string 'nan', matching the original behavior.
    for name in df.select_dtypes(include='object').columns:
        df[name] = df[name].astype(str).str.strip()
    return df
| |
|
| | |
| |
|
def query_openai(prompt):
    """Send *prompt* as a single user message to the configured provider and
    return the reply text.

    Returns "{}" on any API error so callers that parse the reply as a dict
    fall back to an empty suggestion set.  NOTE(review): if ai_provider is
    neither "openai" nor "groq" the function falls through and returns
    None — confirm callers tolerate that.
    """
    try:
        # Provider configuration comes from the Streamlit session.
        api_key = st.session_state.api_key
        model = st.session_state.selected_model

        if st.session_state.ai_provider == "openai":
            client = openai.OpenAI(api_key=api_key)
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7
            )
            return response.choices[0].message.content
        elif st.session_state.ai_provider == "groq":
            # Imported lazily so the groq package is only required when used.
            from groq import Groq
            client = Groq(api_key=api_key)
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7
            )
            return response.choices[0].message.content
    except Exception as e:
        print(f"API Error: {e}")
        return "{}"
| |
|
| | |
| |
|
def llm_suggest_cleaning(df):
    """Ask the configured LLM for a dataset-specific cleaning plan.

    Sends the first 10 rows of *df* (as CSV) and expects a Python dict
    literal back with keys rename_columns / convert_types / fill_missing /
    value_map (consumed by apply_suggestions).  On any parse failure an
    empty suggestion set is returned.
    """
    import ast  # local import: only needed to parse the model's reply safely

    sample = df.head(10).to_csv(index=False)
    prompt = f"""
You are a professional data wrangler. Below is a sample of a messy dataset.

Return a Python dictionary with the following keys:

1. rename_columns – fix unclear or inconsistent column names
2. convert_types – correct datatypes: int, float, str, or date
3. fill_missing – use 'mean', 'median', 'mode', or a constant like 'Unknown' or 0
4. value_map – map inconsistent values (e.g., yes/Yes/Y → Yes)

Do not drop any rows or columns. Your output must be a valid Python dict.

Example:
{{
"rename_columns": {{"dob": "date_of_birth"}},
"convert_types": {{"age": "int", "salary": "float", "signup_date": "date"}},
"fill_missing": {{"gender": "mode", "salary": -1}},
"value_map": {{
"gender": {{"M": "Male", "F": "Female"}},
"subscribed": {{"Y": "Yes", "N": "No"}}
}}
}}
Apart from these mentioned steps, study the data and also do whatever things are good and needed for that particular dataset and do the cleaning.
Sample data:
{sample}
"""
    raw_response = query_openai(prompt)
    try:
        # SECURITY: the reply is untrusted model output. ast.literal_eval
        # parses literals only (replacing the previous eval(), which would
        # execute arbitrary expressions embedded in the reply).
        suggestions = ast.literal_eval(raw_response)
        if not isinstance(suggestions, dict):
            raise ValueError("expected a dict of suggestions")
        return suggestions
    except Exception:
        # Covers malformed replies (SyntaxError/ValueError) and the "{}"
        # sentinel-gone-wrong cases; fall back to a no-op plan.
        print("Could not parse suggestions.")
        return {
            "rename_columns": {},
            "convert_types": {},
            "fill_missing": {},
            "value_map": {}
        }
| |
|
| | |
| |
|
def apply_suggestions(df, suggestions):
    """Apply an LLM-generated cleaning plan to *df* (mutated in place and
    returned).

    *suggestions* may contain:
      - "rename_columns": {old_name: new_name}
      - "convert_types":  {col: "int" | "float" | "str" | "date"}
      - "fill_missing":   {col: "mean" | "median" | "mode" | constant}
      - "value_map":      {col: {old_value: new_value}}

    Columns missing from *df* are skipped; per-column failures are logged
    and the remaining steps still run.
    """
    df.rename(columns=suggestions.get("rename_columns", {}), inplace=True)

    for col, dtype in suggestions.get("convert_types", {}).items():
        if col not in df.columns:
            continue
        try:
            if dtype == "int":
                # Nullable Int64 keeps NaNs representable after coercion.
                df[col] = pd.to_numeric(df[col], errors='coerce').astype("Int64")
            elif dtype == "float":
                df[col] = pd.to_numeric(df[col], errors='coerce')
            elif dtype == "str":
                df[col] = df[col].astype(str)
            elif dtype == "date":
                df[col] = pd.to_datetime(df[col], errors='coerce')
        except Exception:
            print(f"Failed to convert {col} to {dtype}")

    for col, method in suggestions.get("fill_missing", {}).items():
        if col not in df.columns:
            continue
        try:
            # Assign instead of Series.fillna(inplace=True): in-place fills
            # on a column view are deprecated in pandas 2.x and unreliable
            # under copy-on-write.
            if method == "mean":
                df[col] = df[col].fillna(df[col].mean())
            elif method == "median":
                df[col] = df[col].fillna(df[col].median())
            elif method == "mode":
                df[col] = df[col].fillna(df[col].mode().iloc[0])
            else:
                # Constant fill. Previously only string constants were
                # applied; numeric constants (e.g. 0, -1 — explicitly allowed
                # by the prompt in llm_suggest_cleaning) were silently
                # ignored. Now both are handled.
                df[col] = df[col].fillna(method)
        except Exception:
            print(f"Could not fill missing values for {col}")

    for col, mapping in suggestions.get("value_map", {}).items():
        if col in df.columns:
            try:
                df[col] = df[col].replace(mapping)
            except Exception:
                print(f"Could not map values in {col}")

    return df
| |
|
| | |
| |
|
def capture_dashboard_screenshot():
    """Capture the entire dashboard as a single image"""
    try:
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots

        # 2x2 grid mirroring the four dashboard slots.
        fig = make_subplots(rows=2, cols=2,
                            subplot_titles=["Visualization 1", "Visualization 2",
                                            "Visualization 3", "Visualization 4"])

        for i, plot in enumerate(st.session_state.dashboard_plots):
            if plot is not None:
                # Map the flat slot index to the grid cell.
                row = (i // 2) + 1
                col = (i % 2) + 1

                # Copy every trace from the stored figure into its cell.
                for trace in plot.data:
                    fig.add_trace(trace, row=row, col=col)

                # Attempt to carry the source figure's axis settings over.
                for axis_type in ['xaxis', 'yaxis']:
                    axis_name = f"{axis_type}{i+1 if i > 0 else ''}"
                    subplot_name = f"{axis_type}{row}{col}"

                    # NOTE(review): plotly numbers subplot axes sequentially
                    # ("xaxis", "xaxis2", ...), so "{axis_type}{row}{col}"
                    # (e.g. "xaxis12") and indexing the *source* figure's
                    # layout by the dashboard slot index both look wrong —
                    # verify this axis styling is ever actually applied.
                    if hasattr(plot.layout, axis_name):
                        axis_props = getattr(plot.layout, axis_name)
                        fig.update_layout({subplot_name: axis_props})

        fig.update_layout(
            height=800,
            width=1000,
            title_text="Dashboard Overview",
            showlegend=False,
        )

        # Static PNG export; presumably requires the kaleido engine — any
        # failure falls through to the except block below.
        dashboard_path = f"{st.session_state.temp_dir}/dashboard_combined.png"
        fig.write_image(dashboard_path, scale=2)
        return dashboard_path

    except Exception as e:
        # Best-effort feature: log the full traceback and signal failure
        # with None instead of crashing the app.
        import traceback
        print(f"Error capturing dashboard: {str(e)}")
        print(traceback.format_exc())
        return None
| |
|
| | |
| |
|
| | def generate_enhanced_pdf_report(): |
| |
|
| | """Generate an enhanced PDF report with proper handling of base64 image data""" |
| |
|
| | try: |
| |
|
| | |
| |
|
| | buffer = io.BytesIO() |
| |
|
| | |
| |
|
| | |
| |
|
| | doc = SimpleDocTemplate(buffer, pagesize=letter, |
| |
|
| | leftMargin=36, rightMargin=36, |
| |
|
| | topMargin=36, bottomMargin=36) |
| |
|
| | |
| |
|
| | |
| |
|
| | styles = getSampleStyleSheet() |
| |
|
| | |
| |
|
| | |
| |
|
| | styles.add(ParagraphStyle( |
| |
|
| | name='ReportTitle', |
| |
|
| | parent=styles['Heading1'], |
| |
|
| | fontSize=24, |
| |
|
| | alignment=1, |
| |
|
| | spaceAfter=20, |
| |
|
| | textColor='#2C3E50' |
| |
|
| | )) |
| |
|
| | |
| |
|
| | styles.add(ParagraphStyle( |
| |
|
| | name='SectionHeader', |
| |
|
| | parent=styles['Heading2'], |
| |
|
| | fontSize=16, |
| |
|
| | spaceBefore=15, |
| |
|
| | spaceAfter=10, |
| |
|
| | textColor='#2C3E50', |
| |
|
| | borderWidth=1, |
| |
|
| | borderColor='#95A5A6', |
| |
|
| | borderPadding=5, |
| |
|
| | borderRadius=5 |
| |
|
| | )) |
| |
|
| | |
| |
|
| | styles.add(ParagraphStyle( |
| |
|
| | name='SubHeader', |
| |
|
| | parent=styles['Heading3'], |
| |
|
| | fontSize=14, |
| |
|
| | spaceBefore=10, |
| |
|
| | spaceAfter=8, |
| |
|
| | textColor='#34495E', |
| |
|
| | fontWeight='bold' |
| |
|
| | )) |
| |
|
| | styles.add(ParagraphStyle( |
| |
|
| | name='UserMessage', |
| |
|
| | parent=styles['Normal'], |
| |
|
| | fontSize=11, |
| |
|
| | leftIndent=10, |
| |
|
| | spaceBefore=8, |
| |
|
| | spaceAfter=4 |
| |
|
| | )) |
| |
|
| | |
| |
|
| | styles.add(ParagraphStyle( |
| |
|
| | name='AssistantMessage', |
| |
|
| | parent=styles['Normal'], |
| |
|
| | fontSize=11, |
| |
|
| | leftIndent=10, |
| |
|
| | spaceBefore=4, |
| |
|
| | spaceAfter=12, |
| |
|
| | textColor='#2980B9' |
| |
|
| | )) |
| |
|
| | |
| |
|
| | styles.add(ParagraphStyle( |
| |
|
| | name='Timestamp', |
| |
|
| | parent=styles['Italic'], |
| |
|
| | fontSize=10, |
| |
|
| | textColor='#7F8C8D', |
| |
|
| | alignment=2 |
| |
|
| | )) |
| |
|
| | |
| |
|
| | |
| |
|
| | elements = [] |
| |
|
| | |
| |
|
| | |
| |
|
| | elements.append(Paragraph('Data Analysis Report', styles['ReportTitle'])) |
| |
|
| | |
| |
|
| | |
| |
|
| | elements.append(Paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}', |
| |
|
| | styles['Timestamp'])) |
| |
|
| | elements.append(Spacer(1, 0.5*inch)) |
| |
|
| | |
| |
|
| | |
| |
|
| | elements.append(Paragraph('Analysis Conversation History', styles['SectionHeader'])) |
| |
|
| | |
| |
|
| | if st.session_state.chat_history: |
| |
|
| | for i, (user_msg, assistant_msg) in enumerate(st.session_state.chat_history): |
| |
|
| | |
| |
|
| | elements.append(Paragraph(f'<b>You:</b>', styles['SubHeader'])) |
| |
|
| | user_msg_formatted = user_msg.replace('\n', '<br/>') |
| |
|
| | elements.append(Paragraph(user_msg_formatted, styles['UserMessage'])) |
| |
|
| | |
| |
|
| | |
| |
|
| | |
| |
|
| | base64_pattern = r'!\[Visualization\]\(data:image\/png;base64,([^\)]+)\)' |
| |
|
| | |
| |
|
| | |
| |
|
| | if '### Visualizations' in assistant_msg or re.search(base64_pattern, assistant_msg): |
| |
|
| | |
| |
|
| | if '### Visualizations' in assistant_msg: |
| |
|
| | parts = assistant_msg.split('### Visualizations', 1) |
| |
|
| | text_part = parts[0] |
| |
|
| | viz_part = "### Visualizations" + parts[1] if len(parts) > 1 else "" |
| |
|
| | else: |
| |
|
| | |
| |
|
| | match = re.search(base64_pattern, assistant_msg) |
| |
|
| | text_part = assistant_msg[:match.start()] |
| |
|
| | viz_part = assistant_msg[match.start():] |
| |
|
| | |
| |
|
| | |
| |
|
| | elements.append(Paragraph(f'<b>Assistant:</b>', styles['SubHeader'])) |
| |
|
| | text_part = text_part.replace('\n', '<br/>') |
| |
|
| | elements.append(Paragraph(text_part, styles['AssistantMessage'])) |
| |
|
| | |
| |
|
| | |
| |
|
| | matches = re.findall(base64_pattern, viz_part) |
| |
|
| | for j, base64_data in enumerate(matches): |
| |
|
| | try: |
| |
|
| | |
| |
|
| | image_data = base64.b64decode(base64_data) |
| |
|
| | |
| |
|
| | |
| |
|
| | temp_img_path = f"{st.session_state.temp_dir}/chat_viz_{i}_{j}.png" |
| |
|
| | |
| |
|
| | with open(temp_img_path, 'wb') as f: |
| |
|
| | f.write(image_data) |
| |
|
| | |
| |
|
| | |
| |
|
| | elements.append(Paragraph(f'<b>Visualization:</b>', styles['SubHeader'])) |
| |
|
| | elements.append(Spacer(1, 0.1*inch)) |
| |
|
| | img = Image(temp_img_path, width=6*inch, height=4*inch) |
| |
|
| | elements.append(img) |
| |
|
| | elements.append(Spacer(1, 0.2*inch)) |
| |
|
| | except Exception as e: |
| |
|
| | print(f"Error processing base64 image: {str(e)}") |
| |
|
| | elements.append(Paragraph(f"[Error displaying visualization: {str(e)}]", |
| |
|
| | styles['Normal'])) |
| |
|
| | else: |
| |
|
| | |
| |
|
| | elements.append(Paragraph(f'<b>Assistant:</b>', styles['SubHeader'])) |
| |
|
| | assistant_msg_formatted = assistant_msg.replace('\n', '<br/>') |
| |
|
| | if len(assistant_msg_formatted) > 1500: |
| |
|
| | assistant_msg_formatted = assistant_msg_formatted[:1500] + '...' |
| |
|
| | elements.append(Paragraph(assistant_msg_formatted, styles['AssistantMessage'])) |
| |
|
| | |
| |
|
| | elements.append(Spacer(1, 0.2*inch)) |
| |
|
| | else: |
| |
|
| | elements.append(Paragraph('No conversation history available.', styles['Normal'])) |
| |
|
| | |
| |
|
| | |
| |
|
| | elements.append(PageBreak()) |
| |
|
| | |
| |
|
| | |
| |
|
| | elements.append(Paragraph('Dashboard Overview', styles['SectionHeader'])) |
| |
|
| | elements.append(Spacer(1, 0.2*inch)) |
| |
|
| | |
| |
|
| | |
| |
|
| | dashboard_img_path = capture_dashboard_screenshot() |
| |
|
| | |
| |
|
| | if dashboard_img_path: |
| |
|
| | |
| |
|
| | available_width = doc.width |
| |
|
| | |
| |
|
| | |
| |
|
| | pil_img = PILImage.open(dashboard_img_path) |
| |
|
| | img_width, img_height = pil_img.size |
| |
|
| | |
| |
|
| | |
| |
|
| | scale_factor = available_width / img_width |
| |
|
| | |
| |
|
| | |
| |
|
| | new_height = img_height * scale_factor |
| |
|
| | |
| |
|
| | |
| |
|
| | img = Image(dashboard_img_path, width=available_width, height=new_height) |
| |
|
| | elements.append(img) |
| |
|
| | else: |
| |
|
| | |
| |
|
| | plot_count = 0 |
| |
|
| | for i, plot in enumerate(st.session_state.dashboard_plots): |
| |
|
| | if plot is not None: |
| |
|
| | plot_count += 1 |
| |
|
| | |
| |
|
| | |
| |
|
| | img_bytes = io.BytesIO() |
| |
|
| | plot.write_image(img_bytes, format='png', width=500, height=300) |
| |
|
| | img_bytes.seek(0) |
| |
|
| | |
| |
|
| | |
| |
|
| | temp_img_path = f"{st.session_state.temp_dir}/plot_{i}.png" |
| |
|
| | |
| |
|
| | with open(temp_img_path, 'wb') as f: |
| |
|
| | f.write(img_bytes.getvalue()) |
| |
|
| | |
| |
|
| | |
| |
|
| | elements.append(Paragraph(f'Dashboard Visualization {i+1}', styles['SubHeader'])) |
| |
|
| | elements.append(Spacer(1, 0.1*inch)) |
| |
|
| | |
| |
|
| | |
| |
|
| | img = Image(temp_img_path, width=6.5*inch, height=4*inch) |
| |
|
| | elements.append(img) |
| |
|
| | elements.append(Spacer(1, 0.3*inch)) |
| |
|
| | |
| |
|
| | if plot_count == 0: |
| |
|
| | elements.append(Paragraph('No visualizations have been added to the dashboard.', |
| |
|
| | styles['Normal'])) |
| |
|
| | |
| |
|
| | |
| |
|
| | doc.build(elements) |
| |
|
| | |
| |
|
| | |
| |
|
| | pdf_value = buffer.getvalue() |
| |
|
| | buffer.close() |
| |
|
| | |
| |
|
| | return pdf_value |
| |
|
| | |
| |
|
| | except Exception as e: |
| |
|
| | import traceback |
| |
|
| | print(f"Error generating enhanced PDF report: {str(e)}") |
| |
|
| | print(traceback.format_exc()) |
| |
|
| | return None |
| |
|
| | |
| |
|
def chat_with_workflow(message, history, dataset_info):
    """Send a user query through the LangGraph workflow and return the reply.

    Args:
        message: The user's current question.
        history: List of (user_msg, assistant_msg) tuples from earlier turns.
        dataset_info: Metadata list describing the uploaded datasets.

    Returns:
        The assistant's reply as a markdown string. Any plots generated by
        the workflow are embedded inline as base64-encoded PNG images in the
        ``![Visualization](data:image/png;base64,...)`` format that the PDF
        report generator's regex expects.
    """
    if not dataset_info:
        return "Please upload at least one dataset before asking questions."

    if not st.session_state.api_key:
        return "Please set up your API key and model in the Settings tab before chatting."

    print(f"Chat with workflow called with {len(dataset_info)} datasets")

    try:
        # Replay only the most recent exchanges to keep the prompt short.
        max_history = 3
        previous_messages = []

        if history:
            start_idx = max(0, len(history) - max_history)
            recent_history = history[start_idx:]

            for exchange in recent_history:
                if exchange[0]:
                    previous_messages.append(HumanMessage(content=exchange[0]))
                if exchange[1]:
                    previous_messages.append(AIMessage(content=exchange[1]))

        state = AgentState(
            messages=previous_messages + [HumanMessage(content=message)],
            input_data=dataset_info,
            intermediate_outputs=[],
            current_variables=st.session_state.persistent_vars,
            output_image_paths=[]
        )

        print("Executing workflow...")
        result = chain.invoke(state)
        print("Workflow execution completed")

        messages = result["messages"]

        response = ""
        if messages:
            latest_message = messages[-1]
            if hasattr(latest_message, "content"):
                content = latest_message.content

                # Some models echo the prompt; keep only the text after it.
                if message in content:
                    content = content.split(message)[-1].strip()

                # Drop any speaker-label lines the model may have produced.
                content_lines = content.split('\n')
                filtered_lines = [line for line in content_lines
                                  if not line.strip().startswith(("You:", "User:", "Human:", "Assistant:"))]
                content = '\n'.join(filtered_lines)

                response = content.strip() + "\n\n"

        if "output_image_paths" in result and result["output_image_paths"]:
            response += "### Visualizations\n\n"
            for img_path in result["output_image_paths"]:
                try:
                    # Each figure was pickled to disk by the workflow.
                    full_path = os.path.join(st.session_state.images_dir, img_path)
                    with open(full_path, 'rb') as f:
                        fig = pickle.load(f)

                    img_bytes = BytesIO()
                    fig.update_layout(width=800, height=500)
                    pio.write_image(fig, img_bytes, format='png')
                    img_bytes.seek(0)

                    b64_img = base64.b64encode(img_bytes.read()).decode()
                    # BUG FIX: b64_img was previously computed but never
                    # appended, so the chat showed a blank line instead of
                    # the plot and the PDF exporter's base64 regex could
                    # never match.  Embed the image in the expected format.
                    response += f"![Visualization](data:image/png;base64,{b64_img})\n\n"
                except Exception as e:
                    response += f"Error loading visualization: {str(e)}\n\n"

        return response

    except Exception as e:
        import traceback
        print(f"Error in chat_with_workflow: {str(e)}")
        print(traceback.format_exc())
        return f"Error executing workflow: {str(e)}"
| |
|
| | |
| |
|
def auto_generate_dashboard(dataset_info):
    """Generate an automatic dashboard with four plots.

    Asks the LLM workflow to choose the four most insightful visualizations
    for the uploaded datasets, loads the pickled Plotly figures the workflow
    wrote to disk, and stores them in ``st.session_state.dashboard_plots``.

    Args:
        dataset_info: Metadata list describing the uploaded datasets.

    Returns:
        A ``(status_message, figures)`` tuple where ``figures`` is always a
        4-element list, padded with ``None`` for any slot left unfilled.
    """
    if not dataset_info:
        return "Please upload a dataset first.", [None, None, None, None]

    prompt = """
    You are a data visualization expert. Given a dataset, identify the top 4 most insightful plots using statistical reasoning or patterns (correlation, distribution, trends).

    Use plotly and store the plots in a list named plotly_figures.

    Include multivariate plots using color/size/facets when helpful.
    """

    state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[]
    )

    result = chain.invoke(state)
    figures = []

    if "output_image_paths" in result:
        # The workflow pickles each figure to images_dir; load at most four.
        for img_path in result["output_image_paths"][:4]:
            try:
                full_path = os.path.join(st.session_state.images_dir, img_path)
                with open(full_path, 'rb') as f:
                    fig = pickle.load(f)
                figures.append(fig)
            except Exception as e:
                print(f"Error loading figure: {e}")

    # Pad to exactly four entries so the 2x2 dashboard grid always has a slot.
    while len(figures) < 4:
        figures.append(None)

    st.session_state.dashboard_plots = figures
    return "Dashboard generated!", figures
| |
|
| | |
| |
|
def generate_custom_plots_with_llm(dataset_info, x_col, y_col, facet_col):
    """Generate custom plots based on user-selected columns.

    Prompts the LLM workflow to build up to three Plotly visualizations for
    the chosen X/Y (and optional facet) columns, then loads the pickled
    figures the workflow produced.

    Args:
        dataset_info: Metadata list describing the uploaded datasets.
        x_col: Column name for the X axis.
        y_col: Column name for the Y axis.
        facet_col: Optional facet column; the literal string 'None' disables it.

    Returns:
        A 3-element list of Plotly figures, padded with ``None`` as needed.
    """
    if not dataset_info or not x_col or not y_col:
        return [None, None, None]

    prompt = f"""
    You are a data visualization expert.

    Create 3 insightful visualizations using Plotly based on:

    - X-axis: {x_col}
    - Y-axis: {y_col}
    - Facet (optional): {facet_col if facet_col != 'None' else 'None'}

    Try to find interesting relationships, trends, or clusters using appropriate chart types.

    Use `plotly_figures` list and avoid using fig.show().
    """

    state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[]
    )

    result = chain.invoke(state)
    figures = []

    if "output_image_paths" in result:
        # The workflow pickles each figure to images_dir; load at most three.
        for img_path in result["output_image_paths"][:3]:
            try:
                full_path = os.path.join(st.session_state.images_dir, img_path)
                with open(full_path, 'rb') as f:
                    fig = pickle.load(f)
                figures.append(fig)
            except Exception as e:
                print(f"Error loading figure: {e}")

    # Pad to exactly three entries so the UI can index slots unconditionally.
    while len(figures) < 3:
        figures.append(None)
    return figures
| |
|
| | |
| |
|
def remove_plot(index):
    """Clear the dashboard slot at *index*, leaving an empty placeholder."""
    plots = st.session_state.dashboard_plots
    # Out-of-range indices are silently ignored.
    if 0 <= index < len(plots):
        plots[index] = None
| |
|
| | |
| |
|
def respond(message):
    """Route one chat message through the workflow and refresh the page."""
    if st.session_state.dataset_metadata_list:
        bot_message = chat_with_workflow(message, st.session_state.chat_history, st.session_state.dataset_metadata_list)
    else:
        # No datasets loaded yet - short-circuit with a helpful hint.
        bot_message = "Please upload at least one dataset before asking questions."

    # Persist the exchange, then rerun so the chat log re-renders.
    st.session_state.chat_history.append((message, bot_message))
    st.rerun()
| |
|
| | |
| |
|
def save_plot_to_dashboard(plot_index):
    """Callback for the Add Plot button.

    Copies the generated custom plot at *plot_index* into the first free
    dashboard slot; silently does nothing when every slot is occupied.
    """
    slots = st.session_state.dashboard_plots
    for slot, existing in enumerate(slots):
        if existing is None:
            slots[slot] = st.session_state.custom_plots_to_save[plot_index]
            return
| |
|
| | |
| |
|
| | |
| |
|
# --- Streamlit page configuration and top-level layout ---
st.set_page_config(page_title="QueryMind 🧠", layout="wide")
st.title("QueryMind 🧠 - Data Assistant")
st.markdown("Upload your datasets, ask questions, and generate visualizations to gain insights.")

# One tab per workflow stage: upload -> clean -> chat -> dashboard -> report -> settings.
tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(["Upload Datasets", "Data Cleaning", "Chat with AI Assistant", "Auto Dashboard Generator", "Generate Report", "Settings"])
| |
|
| | |
| |
|
with tab1:
    # --- Tab 1: dataset upload and preview ---
    st.header("Upload Datasets")
    uploaded_files = st.file_uploader("Upload CSV or Excel Files",
                                      accept_multiple_files=True,
                                      type=['csv', 'xlsx', 'xls'])

    if uploaded_files and st.button("Process Uploaded Files"):
        with st.spinner("Processing files..."):
            # process_file_upload is defined elsewhere in this file; it also
            # populates st.session_state.in_memory_datasets (see loop below).
            preview, metadata_list, columns = process_file_upload(uploaded_files)
            st.session_state.columns = columns

            st.success(f"✅ Successfully processed {len(uploaded_files)} file(s)")

            st.subheader("Dataset Previews")

            for dataset_name, df in st.session_state.in_memory_datasets.items():
                with st.expander(f"Preview: {dataset_name}"):
                    st.write(f"**Rows:** {df.shape[0]} | **Columns:** {df.shape[1]}")

                    # Per-column summary: dtype, non-null count, and up to
                    # three sample values rendered as a comma-joined string.
                    col_info = pd.DataFrame({
                        'Column Name': df.columns,
                        'Data Type': df.dtypes.astype(str),
                        'Non-Null Count': df.count().values,
                        'Sample Values': [', '.join(df[col].dropna().astype(str).head(3).tolist()) for col in df.columns]
                    })

                    st.write("**Column Information:**")
                    st.dataframe(col_info, use_container_width=True)

                    st.write("**Data Preview (First 10 rows):**")
                    st.dataframe(df.head(10), use_container_width=True)

            st.info("👆 Click on the dataset names above to see detailed previews. Then proceed to the Data Cleaning tab to clean your data or Chat with AI Assistant to analyze it.")
| |
|
| | |
| |
|
with tab2:
    # --- Tab 2: LLM-assisted data cleaning with interactive refinement ---
    st.header("Data Cleaning")

    # Lazily initialize this tab's session-state slots.
    if 'cleaning_done' not in st.session_state:
        st.session_state.cleaning_done = False

    if 'cleaned_datasets' not in st.session_state:
        st.session_state.cleaned_datasets = {}

    if 'cleaning_summaries' not in st.session_state:
        st.session_state.cleaning_summaries = {}

    if st.session_state.get("in_memory_datasets"):
        if not st.session_state.cleaning_done:
            if st.button("Run Data Cleaning"):
                with st.spinner("Running LLM-assisted cleaning..."):
                    # Pipeline per dataset: deterministic standard_clean, then
                    # LLM-suggested actions applied on top.  Helpers are
                    # defined elsewhere in this file.
                    for name, df in st.session_state.in_memory_datasets.items():
                        raw_df = df.copy()
                        df_std = standard_clean(raw_df.copy())
                        suggestions = llm_suggest_cleaning(df_std.copy())
                        df_clean = apply_suggestions(df_std.copy(), suggestions)
                        st.session_state.cleaned_datasets[name] = df_clean
                        st.session_state.cleaning_summaries[name] = suggestions
                    st.session_state.cleaning_done = True
                    st.rerun()
            else:
                st.info("Click Run Data Cleaning to clean your datasets using the LLM.")
        else:
            # Cleaning already ran: show before/after views plus refinement UI.
            for name, df_clean in st.session_state.cleaned_datasets.items():
                raw_df = st.session_state.in_memory_datasets[name]

                st.subheader(f"Dataset: {name}")
                col1, col2 = st.columns(2)

                with col1:
                    st.markdown("Original Data (First 5 Rows)")
                    st.dataframe(raw_df.head())

                with col2:
                    st.markdown("Cleaned Data (First 5 Rows)")
                    st.dataframe(df_clean.head())

                st.markdown("Summary of Cleaning Actions")
                suggestions = st.session_state.cleaning_summaries[name]
                summary_text = ""

                if suggestions:
                    for key, value in suggestions.items():
                        summary_text += f"**{key}**: {json.dumps(value, indent=2)}\n\n"
                st.markdown(summary_text)

                st.markdown("Refine the Cleaning (Natural Language Instructions)")
                user_input = st.text_input("Example: Convert 'dob' to datetime and fill missing with '2000-01-01'",
                                           key=f"user_input_{name}")

                if f'corrections_{name}' not in st.session_state:
                    st.session_state[f'corrections_{name}'] = []

                if st.button("Apply Correction", key=f'apply_correction_{name}'):
                    if user_input.strip():
                        correction_prompt = f"""
                        You are a data cleaning expert. Below is a previously cleaned dataset with these actions:

                        {summary_text}

                        The user now wants the following additional instruction:
                        \"{user_input.strip()}\"

                        Write only the Python code that modifies the pandas DataFrame `df` accordingly.
                        Do not include explanations or markdown.
                        """
                        correction_code = query_openai(correction_prompt)

                        try:
                            # Work on a copy so a failed exec leaves the
                            # stored cleaned dataset untouched.
                            df = st.session_state.cleaned_datasets[name].copy()
                            local_vars = {"df": df}
                            # SECURITY NOTE(review): this executes arbitrary
                            # LLM-generated Python in-process; consider
                            # sandboxing or at least reviewing the code first.
                            exec(correction_code, {}, local_vars)
                            df_updated = local_vars["df"]

                            st.session_state.cleaned_datasets[name] = df_updated
                            st.session_state[f'corrections_{name}'].append((user_input, correction_code))
                            st.success("Correction applied.")
                            st.rerun()

                        except Exception as e:
                            st.error(f"Failed to apply correction: {str(e)}")

                if st.session_state[f'corrections_{name}']:
                    st.markdown("Applied Corrections")
                    for i, (msg, code) in enumerate(st.session_state[f'corrections_{name}']):
                        st.markdown(f"**Instruction:** {msg}")
                        st.code(code, language='python')

            col1, col2 = st.columns([1, 2])
            with col1:
                if st.button("Reset Cleaning and Re-run"):
                    st.session_state.cleaning_done = False
                    st.rerun()

            with col2:
                if st.button("Finalize and Proceed to Visualizations"):
                    st.session_state.cleaning_finalized = True
                    st.rerun()
    else:
        st.info("Please upload and process datasets first.")
| |
|
| | |
| |
|
with tab3:
    # --- Tab 3: chat interface backed by the LangGraph workflow ---
    st.header("Chat with AI Assistant")

    if not st.session_state.api_key:
        st.warning("⚠️ Please set up your API key and model in the Settings tab before using the chat.")

    st.markdown("""
    ## Example Questions
    - "What analysis can you perform on this dataset?"
    - "Show me basic statistics for all columns"
    - "Create a correlation heatmap"
    - "Plot the distribution of a specific column"
    - "What is the relationship between two columns?"
    """)

    # Replay the stored history so the transcript survives reruns.
    for exchange in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(exchange[0])
        with st.chat_message("assistant"):
            st.write(exchange[1])

    # respond() appends to chat_history and triggers st.rerun().
    if prompt := st.chat_input("Your question"):
        with st.spinner("Thinking..."):
            respond(prompt)
| |
|
| | |
| |
|
with tab4:
    # --- Tab 4: auto dashboard (2x2 grid) plus custom plot generator ---
    st.header("Auto Dashboard Generator")

    # NOTE(review): dashboard_title is captured but not used anywhere below.
    dashboard_title = st.text_input("Dashboard Title", placeholder="Enter your dashboard title")
    col1, col2 = st.columns(2)

    with col1:
        if st.button("Generate Suggested Dashboard (Auto)"):
            if not st.session_state.api_key:
                st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
            else:
                with st.spinner("Generating dashboard..."):
                    # Fills st.session_state.dashboard_plots as a side effect.
                    message, figures = auto_generate_dashboard(st.session_state.dataset_metadata_list)
                    st.success(message)

    with col2:
        if st.button("Refresh Column Options"):
            st.session_state.columns = get_columns()
            st.rerun()

    st.subheader("Dashboard")

    # 2x2 grid; each occupied slot gets its own Remove button.
    col1, col2 = st.columns(2)

    with col1:
        if st.session_state.dashboard_plots[0]:
            st.plotly_chart(st.session_state.dashboard_plots[0], use_container_width=True)
            if st.button("Remove Plot 1"):
                remove_plot(0)
                st.rerun()

    with col2:
        if st.session_state.dashboard_plots[1]:
            st.plotly_chart(st.session_state.dashboard_plots[1], use_container_width=True)
            if st.button("Remove Plot 2"):
                remove_plot(1)
                st.rerun()

    col3, col4 = st.columns(2)

    with col3:
        if st.session_state.dashboard_plots[2]:
            st.plotly_chart(st.session_state.dashboard_plots[2], use_container_width=True)
            if st.button("Remove Plot 3"):
                remove_plot(2)
                st.rerun()

    with col4:
        if st.session_state.dashboard_plots[3]:
            st.plotly_chart(st.session_state.dashboard_plots[3], use_container_width=True)
            if st.button("Remove Plot 4"):
                remove_plot(3)
                st.rerun()

    st.subheader("Custom Plot Generator")

    # Column pickers feeding generate_custom_plots_with_llm below.
    col1, col2, col3 = st.columns(3)

    with col1:
        x_axis = st.selectbox("X-axis Column", options=st.session_state.columns)

    with col2:
        y_axis = st.selectbox("Y-axis Column", options=st.session_state.columns)

    with col3:
        facet = st.selectbox("Facet (optional)", options=["None"] + st.session_state.columns)

    if st.button("Generate Custom Visualizations"):
        if not st.session_state.api_key:
            st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
        else:
            with st.spinner("Generating custom visualizations..."):
                custom_plots = generate_custom_plots_with_llm(st.session_state.dataset_metadata_list, x_axis, y_axis, facet)

                # Cache the plots so the on_click callback can copy them
                # into the dashboard on a later rerun.
                for i, plot in enumerate(custom_plots):
                    if plot:
                        st.session_state.custom_plots_to_save[i] = plot

                for i, plot in enumerate(custom_plots):
                    if plot:
                        st.plotly_chart(plot, use_container_width=True)
                        st.button(
                            f"Add Plot {i+1} to Dashboard",
                            key=f"add_plot_{i}",
                            on_click=save_plot_to_dashboard,
                            args=(i,)
                        )
| |
|
| | |
| |
|
with tab5:
    # --- Tab 5: PDF report generation and download link ---
    st.header("Generate Analysis Report")

    st.markdown("""
    Generate a PDF report containing:
    - Dashboard visualizations
    - Chat conversation history
    """)

    # NOTE(review): report_title is captured but the generator below does not
    # appear to receive it - confirm against generate_enhanced_pdf_report.
    report_title = st.text_input("Report Title (Optional)", "Data Analysis Report")

    if st.button("Generate PDF Report"):
        if not st.session_state.api_key:
            st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
        else:
            with st.spinner("Generating report..."):
                pdf_data = generate_enhanced_pdf_report()
                if pdf_data:
                    # Serve the PDF bytes via a base64 data-URI anchor tag.
                    b64_pdf = base64.b64encode(pdf_data).decode('utf-8')

                    pdf_download_link = f'<a href="data:application/pdf;base64,{b64_pdf}" download="data_analysis_report.pdf">Download PDF Report</a>'
                    st.markdown("### Your report is ready!")
                    st.markdown(pdf_download_link, unsafe_allow_html=True)

                    with st.expander("Preview Report"):
                        st.warning("PDF preview is not available in Streamlit, please download the report to view it.")
                else:
                    st.error("Failed to generate the report. Please try again.")
| |
|
| | |
| |
|
with tab6:
    # --- Tab 6: AI provider / API key / model selection ---
    st.header("AI Provider Settings")

    provider = st.radio("Select AI Provider",
                        options=["OpenAI", "Groq"],
                        index=0 if st.session_state.ai_provider == "openai" else 1,
                        horizontal=True)

    # Stored lowercase ("openai"/"groq") and compared that way everywhere.
    st.session_state.ai_provider = provider.lower()

    api_key = st.text_input("Enter API Key",
                            value=st.session_state.api_key,
                            type="password",
                            help="Your API key for the selected provider")

    # Model list and help text depend on the chosen provider.
    if st.session_state.ai_provider == "openai":
        model_options = OPENAI_MODELS
        model_help = "GPT-4 provides the best results but is slower. GPT-3.5-Turbo is faster but less capable."
    else:
        model_options = GROQ_MODELS
        model_help = "Llama 3.3 70B is most capable. Gemma 2 9B offers good balance. Llama 3 8B is fastest."

    selected_model = st.selectbox("Select Model",
                                  options=model_options,
                                  index=model_options.index(st.session_state.selected_model) if st.session_state.selected_model in model_options else 0,
                                  help=model_help)

    if st.button("Save Settings"):
        st.session_state.api_key = api_key
        st.session_state.selected_model = selected_model

        # Smoke-test the configuration by initializing the LLM once.
        try:
            test_llm = initialize_llm()
            if test_llm:
                st.success(f"✅ Successfully configured {provider} with model: {selected_model}")
            else:
                st.error("Failed to initialize the AI provider. Please check your API key and model selection.")
        except Exception as e:
            st.error(f"Error testing settings: {str(e)}")

    st.subheader("Current Settings")
    settings_info = f"""
    - **Provider**: {st.session_state.ai_provider.upper()}
    - **Model**: {st.session_state.selected_model}
    - **API Key**: {'✅ Set' if st.session_state.api_key else '❌ Not Set'}
    """
    st.markdown(settings_info)

    if st.session_state.ai_provider == "openai":
        st.info("""
        **OpenAI Models Information:**
        - **GPT-4**: Most powerful model, best for complex analysis and detailed explanations
        - **GPT-4-Turbo**: Faster than GPT-4 with similar capabilities
        - **GPT-4-Mini**: Economical option with good performance for standard tasks
        - **GPT-3.5-Turbo**: Fastest option, suitable for basic analysis and visualization
        """)
    else:
        st.info("""
        **Groq Models Information:**
        - **llama3.3-70b-versatile**: Most powerful model for comprehensive analysis
        - **gemma2-9b-it**: Good balance of speed and capabilities
        - **llama-3-8b-8192**: Fastest option for basic analysis tasks
        """)

    with st.expander("How to get API Keys"):
        if st.session_state.ai_provider == "openai":
            st.markdown("""
            ### Getting an OpenAI API Key

            1. Go to [OpenAI's platform](https://platform.openai.com)
            2. Sign up or log in to your account
            3. Navigate to the API section
            4. Create a new API key
            5. Copy the key and paste it above

            Note: OpenAI API usage incurs charges based on tokens used.
            """)
        else:
            st.markdown("""
            ### Getting a Groq API Key

            1. Go to [Groq's website](https://console.groq.com/keys)
            2. Sign up or log in to your account
            3. Navigate to API Keys section
            4. Create a new API key
            5. Copy the key and paste it above

            Note: Check Groq's pricing page for current rates.
            """)
| |
|
| | |
| |
|
| | |
| |
|
def cleanup():
    """Best-effort removal of the session temp directory at process exit."""
    try:
        tmp_dir = st.session_state.temp_dir
        shutil.rmtree(tmp_dir)
        print(f"Cleaned up temporary directory: {tmp_dir}")
    except Exception as e:
        # Session state may already be gone at interpreter shutdown; just log.
        print(f"Error cleaning up: {e}")
| |
|
| | |
| |
|
import atexit
# Remove the session's temporary files when the Streamlit process exits.
atexit.register(cleanup)