import streamlit as st
import os
import pandas as pd
from typing import Literal, TypedDict
from sqlalchemy import create_engine, inspect, text
from transformers import AutoTokenizer
from utils import pprint
import time
import re

from openai import OpenAI
import anthropic
from clients.openRouter import OpenRouter
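# Note: OpenRouter is assumed to be a project-local, OpenAI-compatible wrapper
# (it is used via client.chat.completions.create() further down).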

from dotenv import load_dotenv
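# Expected variables in .env: ANTHROPIC_API_KEY, OPENAI_API_KEY, OPENROUTER_API_KEY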
load_dotenv()


st.set_page_config(
    page_title="SQL Query Assistant",
    page_icon="💾",
    layout="centered",
    initial_sidebar_state="collapsed"
)

ModelType = Literal["CLAUDE_HAIKU", "CLAUDE_SONNET", "GPT_4o", "DEEPSEEK", "DEEPSEEK_R1"]
ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | anthropic.Anthropic,
    "model": str,
    "max_context": int,
    "tokenizer": AutoTokenizer
})

MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "CLAUDE_HAIKU": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-haiku-20241022",
        "max_context": 40000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
    },
    "CLAUDE_SONNET": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-sonnet-20240620",
        "max_context": 40000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
    },
    "GPT_4o": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o",
        "max_context": 15000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
    },
    "DEEPSEEK": {
        "client": OpenRouter(
            api_key=os.environ.get("OPENROUTER_API_KEY"),
        ),
        "model": "deepseek/deepseek-chat",
        "max_context": 30000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
    },
    "DEEPSEEK_R1": {
        "client": OpenRouter(
            api_key=os.environ.get("OPENROUTER_API_KEY"),
        ),
        "model": "deepseek/deepseek-r1",
        "max_context": 30000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
    },
}
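
# Tokenizers here are rough token-count estimators used for logging: the Claude entries
# use the community Xenova/claude-tokenizer port, and the DeepSeek entries reuse Xenova/gpt-4o.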


def get_model_type():
    """
    Let the user pick a model from the Streamlit sidebar and return its MODEL_CONFIG key.
    """
    available_models = list(MODEL_CONFIG.keys())

    # Use the underlying model identifiers as the display labels
    model_display_labels = [
        MODEL_CONFIG[model_type]['model']
        for model_type in available_models
    ]

    selected_model_name = st.sidebar.selectbox(
        "Select AI Model",
        model_display_labels,
        index=0
    )

    # Map the selected display label back to its MODEL_CONFIG key
    selected_model_type = next(
        model_type for model_type in available_models
        if MODEL_CONFIG[model_type]['model'] == selected_model_name
    )

    return selected_model_type


modelType = get_model_type()

client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]

isClaudeModel = modelType.startswith("CLAUDE")
isDeepSeekModel = modelType.startswith("DEEPSEEK")


def __countTokens(text):
    text = str(text)
    tokens = tokenizer.encode(text, add_special_tokens=False)
    return len(tokens)
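
# Note: token counts are only logged below; MAX_CONTEXT is not enforced here.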


# Session state initialization
if "ipAddress" not in st.session_state:
    st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")
if "connection_string" not in st.session_state:
    st.session_state.connection_string = None
if "selected_tables" not in st.session_state:
    st.session_state.selected_tables = {}
if "table_schemas" not in st.session_state:
    st.session_state.table_schemas = {}
if "sample_data" not in st.session_state:
    st.session_state.sample_data = {}
if "engine" not in st.session_state:
    st.session_state.engine = None


def connect_to_db(connection_string):
    try:
        engine = create_engine(connection_string)
        # Open and close a connection to verify the connection string works
        with engine.connect():
            pass
        st.session_state.engine = engine
        return True
    except Exception as e:
        st.error(f"Failed to connect to database: {str(e)}")
        return False


def get_table_schema(table_name):
    if not st.session_state.engine:
        return None

    inspector = inspect(st.session_state.engine)
    columns = inspector.get_columns(table_name)
    schema = {col['name']: str(col['type']) for col in columns}

    # Table-level comment from the PostgreSQL catalog
    table_comment_query = """
        SELECT obj_description(c.oid) as table_comment
        FROM pg_class c
        JOIN pg_namespace n ON n.oid = c.relnamespace
        WHERE c.relname = :table_name
        AND n.nspname = 'public'
    """

    # Column-level comments
    column_comments_query = """
        SELECT
            cols.column_name,
            (
                SELECT pg_catalog.col_description(c.oid, cols.ordinal_position::int)
                FROM pg_catalog.pg_class c
                WHERE c.oid = (SELECT ('"' || cols.table_name || '"')::regclass::oid)
                AND c.relname = cols.table_name
            ) as column_comment
        FROM information_schema.columns cols
        WHERE cols.table_name = :table_name
        AND cols.table_schema = 'public'
    """

    try:
        with st.session_state.engine.connect() as conn:
            table_comment_result = conn.execute(text(table_comment_query), {"table_name": table_name}).fetchone()
            table_comment = table_comment_result[0] if table_comment_result else None

            column_comments_result = conn.execute(text(column_comments_query), {"table_name": table_name}).fetchall()
            column_comments = {row[0]: row[1] for row in column_comments_result}

            # Merge column types and comments into a single schema description
            enhanced_schema = {
                "table_comment": table_comment,
                "columns": {
                    col_name: {
                        "type": schema[col_name],
                        "comment": column_comments.get(col_name)
                    }
                    for col_name in schema
                }
            }

            return enhanced_schema
    except Exception as e:
        st.error(f"Error fetching schema comments: {str(e)}")
        # Fall back to the plain column-type mapping
        return schema


def get_sample_data(table_name):
    if not st.session_state.engine:
        return pd.DataFrame()

    query = f"SELECT * FROM {table_name} ORDER BY 1 DESC LIMIT 3"
    try:
        with st.session_state.engine.connect() as conn:
            df = pd.read_sql(query, conn)
            return df
    except Exception as e:
        st.error(f"Error fetching sample data for {table_name}: {str(e)}")
        return pd.DataFrame()


def clean_sql_response(response: str) -> str:
    """Extract a clean SQL query from a potentially formatted response."""
    sql_block_match = re.search(r'```sql\n(.*?)\n```', response, re.DOTALL)
    if sql_block_match:
        return sql_block_match.group(1).strip()
    return response.strip()


def is_read_only_query(query: str) -> bool:
    """Check whether the query is read-only (SELECT only)."""
    query_upper = query.upper()

    # Statements that modify data or schema
    modification_statements = [
        'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 'TRUNCATE',
        'REPLACE', 'MERGE', 'UPSERT', 'GRANT', 'REVOKE'
    ]

    # Reject any query whose first keyword is a modification statement
    return not any(query_upper.strip().startswith(stmt) for stmt in modification_statements)
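
# Caveat: this is a prefix check only; it will not catch writes hidden behind CTEs
# (e.g. "WITH x AS (...) INSERT ...") or stacked statements, so a read-only database
# role is still the safer guarantee.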


def execute_query(query):
    if not st.session_state.engine:
        return None

    # Safety check: only read-only queries are executed
    if not is_read_only_query(query):
        st.error("Error: Only SELECT queries are allowed for security reasons.")
        return None

    try:
        start_time = time.time()
        with st.spinner("Executing SQL query..."):
            with st.session_state.engine.begin() as conn:
                result = conn.execute(text(query))
                df = pd.DataFrame(result.fetchall(), columns=result.keys())
                execution_time = time.time() - start_time
                pprint(f"[Query Execution] Latency: {execution_time:.2f}s")
                return df
    except Exception as e:
        st.error(f"Error executing query: {str(e)}")
        return None
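
# Note: engine.begin() wraps the query in a transaction that is committed on success;
# since only SELECT statements reach this point, engine.connect() would behave the same.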


def generate_sql_query(user_query):
    # Build a markdown description of each selected table/view for the prompt
    tables_context = []
    for table_name, table_type in st.session_state.selected_tables.items():
        schema_info = st.session_state.table_schemas[table_name]

        schema_md = [f"\n\n### {table_type}: {table_name}"]

        # Table-level comment, if any
        if schema_info.get("table_comment"):
            schema_md.append(f"> {schema_info['table_comment']}")

        # Column names, types, and comments
        schema_md.append("\n**Columns:**")
        for col_name, col_info in schema_info["columns"].items():
            col_type = col_info["type"]
            col_comment = col_info.get("comment")

            if col_comment:
                schema_md.append(f"- `{col_name}` ({col_type}) - {col_comment}")
            else:
                schema_md.append(f"- `{col_name}` ({col_type})")

        # A few sample rows, rendered as markdown
        schema_md.append("\n**Sample Data:**")
        schema_md.append(st.session_state.sample_data[table_name].to_markdown(index=False))

        tables_context.append("\n".join(schema_md))

| | prompt = f"""You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query. |
| | |
| | <AVAILABLE_OBJECTS> |
| | {chr(10).join(tables_context)} |
| | |
| | Important: |
| | 1. Only generate SELECT queries - no INSERT, UPDATE, DELETE, or other data modification statements |
| | 2. Only return the SQL query, nothing else |
| | 3. The query should be valid PostgreSQL syntax |
| | 4. Do not include any explanations or comments |
| | 5. Make sure to handle NULL values appropriately |
| | 6. If joining tables, use appropriate join conditions based on the schema |
| | 7. Use table names with appropriate qualifiers to avoid ambiguity |
| | |
| | User Query: {user_query} |
| | """ |

    prompt_tokens = __countTokens(prompt)
    print("\n")
    pprint(f"[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")

    # Show the full prompt in a debug expander when running locally
    if 'localhost' in st.context.headers.get("Origin", ""):
        with st.expander("Debug: Prompt Generation"):
            st.write(f"\nUser Query: {user_query}")
            st.write("\nFull Prompt:")
            st.code(prompt, language="text")

    start_time = time.time()
    with st.spinner(f"Generating SQL query using {MODEL}..."):
        if isClaudeModel:
            response = client.messages.create(
                model=MODEL,
                max_tokens=1000,
                messages=[
                    {"role": "user", "content": prompt},
                ]
            )
            raw_response = response.content[0].text
        else:
            response = client.chat.completions.create(
                model=MODEL,
                messages=[
                    {"role": "user", "content": prompt},
                ]
            )
            raw_response = response.choices[0].message.content

        generation_time = time.time() - start_time
        pprint(f"[{MODEL}] Query Generation Latency: {generation_time:.2f}s")

    return clean_sql_response(raw_response)


# ---------------- Streamlit UI ----------------
st.title("SQL Query Assistant")

# Step 1: database connection
st.header("1. Database Connection")
connection_string = st.text_input(
    "Enter PostgreSQL Connection String",
    value=st.session_state.connection_string if st.session_state.connection_string else "",
    type="password"
)

if connection_string and connection_string != st.session_state.connection_string:
    if connect_to_db(connection_string):
        st.session_state.connection_string = connection_string
        st.success("Successfully connected to database!")


# Step 2: table/view selection
if st.session_state.connection_string:
    st.header("2. Database Object Selection")
    inspector = inspect(st.session_state.engine)

    tables = inspector.get_table_names()
    views = inspector.get_view_names()

    # Combine tables and views into one sorted list of (name, type) pairs
    db_objects = [(table, 'Table') for table in tables] + [(view, 'View') for view in views]
    db_objects.sort(key=lambda x: x[0])

    object_names = [obj[0] for obj in db_objects]

    # Pre-select the lsq_leads table when it exists in this database
    default_selections = ['lsq_leads'] if 'lsq_leads' in object_names else []

    selected_objects = st.multiselect(
        "Select tables/views",
        options=object_names,
        default=default_selections,
        help="You can select multiple tables/views to query across them"
    )

    if selected_objects:
        st.write("Selected objects:")
        for obj in selected_objects:
            obj_type = next(obj_type for obj_name, obj_type in db_objects if obj_name == obj)
            st.write(f"- {obj}: {obj_type}")

    schema_container = st.container()
    data_container = st.container()

    if selected_objects:
        # Ensure the per-table session state entries are dictionaries
        if not isinstance(st.session_state.get("selected_tables"), dict):
            st.session_state.selected_tables = {}
        if not isinstance(st.session_state.get("table_schemas"), dict):
            st.session_state.table_schemas = {}
        if not isinstance(st.session_state.get("sample_data"), dict):
            st.session_state.sample_data = {}

        # Drop cached state for objects that were deselected
        current_tables = set(selected_objects)
        previous_tables = set(st.session_state.selected_tables.keys())
        removed_tables = previous_tables - current_tables

        for table in removed_tables:
            if table in st.session_state.selected_tables:
                del st.session_state.selected_tables[table]
            if table in st.session_state.table_schemas:
                del st.session_state.table_schemas[table]
            if table in st.session_state.sample_data:
                del st.session_state.sample_data[table]

        # Fetch schema and sample data for each selected object
        for obj in selected_objects:
            st.session_state.selected_tables[obj] = next(
                obj_type for obj_name, obj_type in db_objects if obj_name == obj
            )

            schema = get_table_schema(obj)
            if schema:
                st.session_state.table_schemas[obj] = schema

            sample_data = get_sample_data(obj)
            if not sample_data.empty:
                st.session_state.sample_data[obj] = sample_data

        with schema_container:
            st.subheader("Table/View Schemas")
            for obj in selected_objects:
                if obj in st.session_state.table_schemas:
                    st.write(f"**{obj} Schema:**")
                    st.json(st.session_state.table_schemas[obj])
                    st.write("---")
                else:
                    st.warning(f"Could not fetch schema for {obj}")

        with data_container:
            st.subheader("Sample Data")
            for obj in selected_objects:
                if obj in st.session_state.sample_data and not st.session_state.sample_data[obj].empty:
                    st.write(f"**{obj} (Last 3 rows):**")
                    st.dataframe(
                        st.session_state.sample_data[obj],
                        use_container_width=True,
                        hide_index=True
                    )
                    st.write("---")
                else:
                    st.warning(f"No sample data available for {obj}")

| | if st.session_state.get("selected_tables"): |
| | st.header("3. Query Input") |
| | user_query = st.text_area("Enter your query in plain English") |
| | |
| | if st.button("Generate and Execute Query"): |
| | if user_query: |
| | |
| | sql_query = generate_sql_query(user_query) |
| | |
| | |
| | st.subheader("Generated SQL Query") |
| | st.code(sql_query, language="sql") |
| | |
| | |
| | results = execute_query(sql_query) |
| | if results is not None: |
| | st.subheader("Query Results") |
| | st.dataframe(results) |
| |
|
| |
|
| |
|