| | |
| | import streamlit as st |
| | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| | import sqlparse |
| |
|
| | |
# Browser-tab title/icon and page layout. This must run before any other
# Streamlit command (a hard requirement in older Streamlit versions).
st.set_page_config(
    layout="centered",
    page_icon="🤖",
    page_title="AI SQL Query Generator",
)
| |
|
| | |
@st.cache_resource
def load_model(model_name="tscholak/cxmefzzi"):
    """Load the tokenizer and seq2seq model for *model_name*.

    ``st.cache_resource`` caches the returned objects per distinct
    ``model_name``. Script reruns and other sessions reuse the same instances
    instead of calling ``from_pretrained`` again.

    Args:
        model_name: Hugging Face model id (or local path) passed to both
            ``from_pretrained`` calls. Defaults to the checkpoint the app was
            originally written for.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model
| |
|
| | |
def format_sql(sql):
    """Pretty-print *sql* with sqlparse: re-indented, keywords upper-cased."""
    formatting_options = {"reindent": True, "keyword_case": "upper"}
    return sqlparse.format(sql, **formatting_options)
| |
|
| | |
def generate_sql(
    input_text,
    tokenizer,
    model,
    prefix="Translate English to SQL: ",
    max_input_length=512,
    max_output_length=256,
):
    """Translate a natural-language question into a raw SQL string.

    Args:
        input_text: The user's question.
        tokenizer: Tokenizer paired with *model*. It is called like a Hugging
            Face tokenizer and must provide ``decode``.
        model: Seq2seq model providing ``generate``.
        prefix: Text prepended to *input_text* before tokenization.
            NOTE(review): the tscholak/cxmefzzi model card appears to document
            a ``question | db_id | table : columns ...`` input format rather
            than an instruction prefix. Confirm this prompt matches what the
            loaded checkpoint was trained on.
        max_input_length: Maximum number of prompt tokens; longer prompts are
            truncated.
        max_output_length: ``max_length`` passed to ``model.generate``.

    Returns:
        The first generated sequence decoded to text, with special tokens
        removed.
    """
    inputs = tokenizer(
        prefix + input_text,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    )
    output_ids = model.generate(**inputs, max_length=max_output_length)
    # generate() returns a batch; only one prompt was encoded, so take row 0.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
| |
|
| | |
def _render_generated_sql(user_input, tokenizer, model):
    """Generate SQL for *user_input* and render the result (or an error)."""
    with st.spinner("Generating SQL query..."):
        # Keep the try limited to the calls that can actually fail
        # (model inference and formatting); rendering happens afterwards.
        try:
            raw_sql = generate_sql(user_input, tokenizer, model)
            formatted_sql = format_sql(raw_sql)
        except Exception as e:  # UI boundary: report any failure to the user
            st.error(f"Error generating SQL: {e}")
            return

    if formatted_sql.strip():
        st.subheader("Generated SQL Query:")
        st.code(formatted_sql, language="sql")
        st.success("Query generated successfully!")
    else:
        # Don't claim success when the model produced nothing usable.
        st.warning("The model returned an empty query. Try rephrasing your question.")

    with st.expander("Debug Info"):
        st.write("Model: tscholak/cxmefzzi")
        # Render raw output in a code block rather than an inline `...` span,
        # so backticks in the output (e.g. MySQL-quoted identifiers) cannot
        # break the markdown.
        st.write("Raw Output:")
        st.code(raw_sql, language="text")


def _render_usage_help():
    """Render the static usage instructions and example questions."""
    st.markdown("---")
    st.markdown("### How to use:")
    st.markdown(
        "1. Enter a question about data you want to query\n"
        "2. Click 'Generate SQL'\n"
        "3. Copy the generated SQL and use it in your database"
    )

    st.markdown("### Example queries:")
    for example in (
        "Show the total sales per product category in 2022",
        "List employees hired before 2020 with salary above $50,000",
        "Count orders by customer country and sort descending",
    ):
        st.code(example, language="text")


def main():
    """Render the SQL-generator page: header, question input, result and help."""
    st.title("🤖 AI-Powered SQL Query Generator")
    st.markdown("Convert natural language questions to SQL queries")

    # load_model is wrapped in st.cache_resource, so reruns reuse the
    # already-loaded tokenizer/model pair.
    tokenizer, model = load_model()

    user_input = st.text_area(
        "Enter your question in natural language:",
        placeholder="e.g., Show all customers from California who made purchases after January 2023",
        height=150,
    )

    if st.button("Generate SQL"):
        if not user_input.strip():
            st.warning("Please enter a question")
        else:
            _render_generated_sql(user_input, tokenizer, model)

    _render_usage_help()
| |
|
if __name__ == "__main__":
    # `streamlit run <file>` executes this script as __main__, so the page is
    # built here on every rerun.
    main()