# SQL_chatbot_API / app.py
# saadkhi's picture
# Update app.py
# a0fbd48 verified
import re
import warnings

warnings.filterwarnings("ignore")

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Limit PyTorch to a single thread for intra-op parallelism on CPU.
torch.set_num_threads(1)

BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"


def _load_model_and_tokenizer(model_name):
    """Load *model_name* in float32, switch it to eval mode, and load its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
    ).eval()
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer


print("Loading model...")
model, tokenizer = _load_model_and_tokenizer(BASE_MODEL)
print("Model ready")
# ─────────────────────────
# SQL FILTER
# ─────────────────────────
# Terms that mark a request as SQL/database related. Each term is matched
# case-insensitively at the START of a word (see ``is_sql_related``), so
# "table" also matches "tables" but not "vegetable".
SQL_KEYWORDS = [
    # SQL / database vocabulary
    "sql", "nosql", "mssql", "database", "db", "table", "column", "schema",
    "select", "insert", "update", "delete", "join", "group by", "order by",
    "postgres", "mysql", "sqlite",
    "query", "queries", "subquery", "subqueries",
    # Analytical phrasing common in plain-English data requests, e.g.
    # "Top 5 highest paid employees", "Count orders per customer".
    "count", "how many", "top", "highest", "lowest", "average",
    "duplicate", "distinct",
]


def is_sql_related(text, keywords=None):
    """Return True if *text* looks like a SQL/database request.

    A keyword matches when it occurs at the start of a word in *text*. The
    match is case-insensitive, and runs of whitespace count as one space.
    Prefix matching tolerates inflected forms ("tables", "joined"). The
    trade-off is some false positives, e.g. "country" matches "count". This is
    a cheap pre-filter, not a classifier.

    Args:
        text: The user's request. ``None`` or empty text is never SQL related.
        keywords: Terms to look for. Defaults to ``SQL_KEYWORDS``, which is
            read at call time. Each term should begin with a letter or digit
            so the leading word-boundary anchor applies.

    Returns:
        bool: True if any keyword was found.
    """
    if keywords is None:
        keywords = SQL_KEYWORDS
    if not text or not keywords:
        return False

    # Collapse whitespace so multi-word keywords such as "group by" still
    # match input like "group   by".
    normalized = " ".join(text.lower().split())

    # \b anchors each keyword to a word start. There is deliberately no
    # trailing anchor, so plural/suffixed forms still match.
    pattern = r"\b(?:" + "|".join(re.escape(k.lower()) for k in keywords) + ")"
    return re.search(pattern, normalized) is not None
# ─────────────────────────
# GENERATION
# ─────────────────────────
SYSTEM_PROMPT = """
You are an expert SQL generator.
Rules:
- Only respond to SQL or database related questions.
- If the question is not about SQL or databases, refuse.
- Output ONLY SQL query.
- Do not explain.
"""

# Text after which the completion is cut off: a blank line (start of an
# explanation) or the model continuing the prompt pattern with a new turn.
_STOP_MARKERS = ("\n\n", "\nUser request:")


def _build_prompt(user_input):
    """Return the plain-completion prompt sent to the model for *user_input*."""
    return f"""
{SYSTEM_PROMPT}
User request: {user_input}
SQL:
"""


def _extract_sql(completion):
    """Reduce a raw model completion to just its leading SQL text.

    If the completion opens with a Markdown code fence (e.g. ```sql), the
    fence line and everything from the closing fence onward are removed.
    The text is then cut at the first stop marker, so trailing explanations
    are dropped.
    """
    text = completion.strip()
    if text.startswith("```"):
        # Drop the opening fence line, which may carry a language tag...
        _, _, text = text.partition("\n")
        # ...and everything from the closing fence on.
        text = text.split("```", 1)[0]
    for marker in _STOP_MARKERS:
        text = text.split(marker, 1)[0]
    return text.strip()


def generate_sql(user_input):
    """Generate a SQL query for a natural-language request.

    Args:
        user_input: The user's question. ``None`` and whitespace-only input
            are treated as empty.

    Returns:
        The generated SQL text. Returns a fixed message instead when the
        input is empty or is rejected by ``is_sql_related``.
    """
    # Guard None as well as "": calling .strip() on None would raise.
    if not user_input or not user_input.strip():
        return "Enter SQL question."

    # HARD GUARD: refuse before spending any model time on off-topic input.
    if not is_sql_related(user_input):
        return "I only respond to SQL and database related questions. If you want, I can craft helpful database queries for you."

    inputs = tokenizer(_build_prompt(user_input), return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd bookkeeping needed
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            # Greedy decoding. `temperature` is intentionally not passed:
            # it has no effect when do_sample=False.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens so the prompt never leaks into
    # the result, whatever the completion itself contains.
    prompt_len = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return _extract_sql(completion)
# ─────────────────────────
# UI
# ─────────────────────────
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        lines=3,
        label="SQL Question",
        placeholder="Find duplicate emails in users table",
    ),
    outputs=gr.Textbox(
        lines=8,
        label="Generated SQL",
    ),
    title="AI SQL Generator (Portfolio Project)",
    description="This model ONLY responds to SQL/database queries.",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
        ["Write a joke about cats"],  # off-topic: refused by the is_sql_related guard
    ],
)

# Start the server only when this file is run as a script. Importing the
# module (e.g. from tests) then builds `demo` without launching a server.
if __name__ == "__main__":
    # Bind to all network interfaces (0.0.0.0), not just localhost.
    demo.launch(server_name="0.0.0.0")