saadkhi commited on
Commit
f5903a4
Β·
verified Β·
1 Parent(s): 15e7b42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -16
app.py CHANGED
@@ -7,9 +7,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
8
  torch.set_num_threads(1)
9
 
10
- # ─────────────────────────
11
- # MODEL
12
- # ─────────────────────────
13
  BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
14
 
15
  print("Loading model...")
@@ -20,22 +17,49 @@ model = AutoModelForCausalLM.from_pretrained(
20
  )
21
 
22
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
23
-
24
  model.eval()
 
25
  print("Model ready")
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # ─────────────────────────
28
  # GENERATION
29
  # ─────────────────────────
30
- def generate_sql(question):
31
- if not question.strip():
 
 
 
 
 
 
 
 
 
 
 
32
  return "Enter SQL question."
33
 
 
 
 
 
34
  prompt = f"""
35
- You are a SQL expert.
36
- Return ONLY SQL query.
37
 
38
- User: {question}
39
  SQL:
40
  """
41
 
@@ -45,28 +69,42 @@ SQL:
45
  output = model.generate(
46
  **inputs,
47
  max_new_tokens=120,
48
- temperature=0.2,
49
  do_sample=False,
50
  pad_token_id=tokenizer.eos_token_id,
51
  )
52
 
53
  text = tokenizer.decode(output[0], skip_special_tokens=True)
54
 
55
- return text.split("SQL:")[-1].strip()
 
 
 
 
 
 
56
 
57
  # ─────────────────────────
58
  # UI
59
  # ─────────────────────────
60
  demo = gr.Interface(
61
  fn=generate_sql,
62
- inputs=gr.Textbox(lines=3, label="SQL Question"),
63
- outputs=gr.Textbox(lines=8, label="Generated SQL"),
64
- title="SQL Generator (Portfolio Demo)",
65
- description="Fast CPU AI SQL generator.",
 
 
 
 
 
 
 
66
  examples=[
67
  ["Find duplicate emails in users table"],
68
  ["Top 5 highest paid employees"],
69
- ["Orders per customer last month"],
 
70
  ],
71
  )
72
 
 
7
 
8
  torch.set_num_threads(1)
9
 
 
 
 
10
  BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
11
 
12
  print("Loading model...")
 
17
  )
18
 
19
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
 
20
  model.eval()
21
+
22
  print("Model ready")
23
 
24
+ # ─────────────────────────
25
+ # SQL FILTER
26
+ # ─────────────────────────
27
+ SQL_KEYWORDS = [
28
+ "sql", "database", "table", "select", "insert",
29
+ "update", "delete", "join", "group by",
30
+ "postgres", "mysql", "sqlite", "query"
31
+ ]
32
+
33
+ def is_sql_related(text):
34
+ text = text.lower()
35
+ return any(k in text for k in SQL_KEYWORDS)
36
+
37
  # ─────────────────────────
38
  # GENERATION
39
  # ─────────────────────────
40
+ SYSTEM_PROMPT = """
41
+ You are an expert SQL generator.
42
+
43
+ Rules:
44
+ - Only respond to SQL or database related questions.
45
+ - If the question is not about SQL or databases, refuse.
46
+ - Output ONLY SQL query.
47
+ - Do not explain.
48
+ """
49
+
50
+ def generate_sql(user_input):
51
+
52
+ if not user_input.strip():
53
  return "Enter SQL question."
54
 
55
+ # πŸ”΄ HARD GUARD
56
+ if not is_sql_related(user_input):
57
+ return "❌ This demo only supports SQL and database related questions."
58
+
59
  prompt = f"""
60
+ {SYSTEM_PROMPT}
 
61
 
62
+ User request: {user_input}
63
  SQL:
64
  """
65
 
 
69
  output = model.generate(
70
  **inputs,
71
  max_new_tokens=120,
72
+ temperature=0.1,
73
  do_sample=False,
74
  pad_token_id=tokenizer.eos_token_id,
75
  )
76
 
77
  text = tokenizer.decode(output[0], skip_special_tokens=True)
78
 
79
+ # return only SQL part
80
+ result = text.split("SQL:")[-1].strip()
81
+
82
+ # extra safety: remove explanations
83
+ result = result.split("\n\n")[0]
84
+
85
+ return result
86
 
87
  # ─────────────────────────
88
  # UI
89
  # ─────────────────────────
90
  demo = gr.Interface(
91
  fn=generate_sql,
92
+ inputs=gr.Textbox(
93
+ lines=3,
94
+ label="SQL Question",
95
+ placeholder="Find duplicate emails in users table"
96
+ ),
97
+ outputs=gr.Textbox(
98
+ lines=8,
99
+ label="Generated SQL"
100
+ ),
101
+ title="AI SQL Generator (Portfolio Project)",
102
+ description="This model ONLY responds to SQL/database queries.",
103
  examples=[
104
  ["Find duplicate emails in users table"],
105
  ["Top 5 highest paid employees"],
106
+ ["Count orders per customer last month"],
107
+ ["Write a joke about cats"] # will be blocked
108
  ],
109
  )
110