saadkhi committed
Commit 8b67be0 · verified · 1 Parent(s): 0fad5f5

Update app.py

Files changed (1)
  app.py +105 -58
app.py CHANGED
@@ -1,4 +1,4 @@
- # app.py - ZeroGPU safe: no caching + CPU load + GPU only in inference

  import torch
  import gradio as gr
@@ -7,6 +7,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  from peft import PeftModel

  # ────────────────────────────────────────────────────────────────
  BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
  LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"

@@ -14,81 +17,125 @@ MAX_NEW_TOKENS = 180
  TEMPERATURE = 0.0
  DO_SAMPLE = False

- print("Loading quantized base model on CPU (GPU only during inference)...")
- bnb_config = BitsAndBytesConfig(
-     load_in_4bit=True,
-     bnb_4bit_quant_type="nf4",
-     bnb_4bit_compute_dtype=torch.bfloat16
- )

- model = AutoModelForCausalLM.from_pretrained(
-     BASE_MODEL,
-     quantization_config=bnb_config,
-     device_map="cpu",  # ← Force CPU load at startup
-     trust_remote_code=True
- )

- print("Loading & merging LoRA...")
- model = PeftModel.from_pretrained(model, LORA_PATH)
- model = model.merge_and_unload()  # Merge once for speed

- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
- model.eval()

  # ────────────────────────────────────────────────────────────────
- @spaces.GPU(duration=60)  # Requests GPU slice only here
  def generate_sql(prompt: str):
-     messages = [{"role": "user", "content": prompt}]
-
-     # Tokenize on CPU
-     inputs = tokenizer.apply_chat_template(
-         messages,
-         tokenize=True,
-         add_generation_prompt=True,
-         return_tensors="pt"
-     )
-
-     # Move to GPU only now (GPU is allocated)
-     inputs = inputs.to("cuda")
-
-     with torch.inference_mode():
-         outputs = model.generate(
-             input_ids=inputs,
-             max_new_tokens=MAX_NEW_TOKENS,
-             temperature=TEMPERATURE,
-             do_sample=DO_SAMPLE,
-             use_cache=True,
-             pad_token_id=tokenizer.eos_token_id,
          )

-     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     # Clean up output
-     if "<|assistant|>" in response:
-         response = response.split("<|assistant|>", 1)[-1].strip()
-     if "<|end|>" in response:
-         response = response.split("<|end|>")[0].strip()

-     return response

  # ────────────────────────────────────────────────────────────────
  demo = gr.Interface(
      fn=generate_sql,
      inputs=gr.Textbox(
-         label="Ask an SQL question",
-         placeholder="Delete duplicate rows from users table based on email",
-         lines=3
      ),
-     outputs=gr.Textbox(label="Generated SQL"),
-     title="SQL Chatbot (ZeroGPU)",
-     description="Phi-3-mini 4bit + LoRA - GPU only during generation",
      examples=[
          ["Find duplicate emails in users table"],
-         ["Top 5 highest paid employees"],
-         ["Count orders per customer last month"]
      ],
-     cache_examples=False  # ← This is critical! Prevents startup crash
  )

  if __name__ == "__main__":
-     demo.launch()
+ # app.py

  import torch
  import gradio as gr

  from peft import PeftModel

  # ────────────────────────────────────────────────────────────────
+ # Configuration
+ # ────────────────────────────────────────────────────────────────
+
  BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
  LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"

  TEMPERATURE = 0.0
  DO_SAMPLE = False

+ # ────────────────────────────────────────────────────────────────
+ # Load model safely on CPU first
+ # ────────────────────────────────────────────────────────────────

+ print("Loading base model on CPU...")
+ try:
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16
+     )
+
+     model = AutoModelForCausalLM.from_pretrained(
+         BASE_MODEL,
+         quantization_config=bnb_config,
+         device_map="cpu",  # Critical for ZeroGPU + CPU spaces
+         trust_remote_code=True,
+         low_cpu_mem_usage=True
+     )
+
+     print("Loading and merging LoRA adapters...")
+     model = PeftModel.from_pretrained(model, LORA_PATH)
+     model = model.merge_and_unload()  # Merge once → faster inference

+     tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+     model.eval()

+     print("Model successfully loaded on CPU")
+ except Exception as e:
+     print(f"Model loading failed: {str(e)}")
+     raise

  # ────────────────────────────────────────────────────────────────
+ # Inference function – GPU only here
+ # ────────────────────────────────────────────────────────────────
+
+ @spaces.GPU(duration=60)  # 60 seconds is usually enough
  def generate_sql(prompt: str):
+     try:
+         messages = [{"role": "user", "content": prompt.strip()}]
+
+         # Tokenize on CPU
+         inputs = tokenizer.apply_chat_template(
+             messages,
+             tokenize=True,
+             add_generation_prompt=True,
+             return_tensors="pt"
          )
+
+         # Move to GPU only inside decorated function
+         if torch.cuda.is_available():
+             inputs = inputs.to("cuda")
+
+         with torch.inference_mode():
+             outputs = model.generate(
+                 input_ids=inputs,
+                 max_new_tokens=MAX_NEW_TOKENS,
+                 temperature=TEMPERATURE,
+                 do_sample=DO_SAMPLE,
+                 use_cache=True,
+                 pad_token_id=tokenizer.eos_token_id,
+             )

+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         # Clean output
+         if "<|assistant|>" in response:
+             response = response.split("<|assistant|>", 1)[-1].strip()
+         if "<|end|>" in response:
+             response = response.split("<|end|>")[0].strip()
+         if "<|user|>" in response:
+             response = response.split("<|user|>")[0].strip()

+         return response.strip() or "No valid response generated."

+     except Exception as e:
+         return f"Error during generation: {str(e)}"
+
+ # ────────────────────────────────────────────────────────────────
+ # Gradio Interface
  # ────────────────────────────────────────────────────────────────
+
  demo = gr.Interface(
      fn=generate_sql,
      inputs=gr.Textbox(
+         label="Your SQL-related question",
+         placeholder="e.g. Find duplicate emails in users table",
+         lines=3,
+         max_lines=6
+     ),
+     outputs=gr.Textbox(
+         label="Generated SQL / Answer",
+         lines=6
+     ),
+     title="SQL Chatbot – Phi-3-mini fine-tuned",
+     description=(
+         "Ask questions about SQL queries.\n\n"
+         "ZeroGPU version – first response may take 10–60 seconds (cold start)."
      ),
      examples=[
          ["Find duplicate emails in users table"],
+         ["Top 5 highest paid employees from employees table"],
+         ["Count total orders per customer in last 30 days"],
+         ["Delete duplicate rows based on email column"]
      ],
+     cache_examples=False,  # Very important for ZeroGPU
+     allow_flagging="never"
  )

  if __name__ == "__main__":
+     print("Starting Gradio server...")
+     try:
+         demo.launch(
+             server_name="0.0.0.0",
+             server_port=7860,
+             debug=False,
+             quiet=False,
+             show_error=True
+         )
+     except Exception as e:
+         print(f"Launch failed: {str(e)}")
+         raise
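
A minimal sketch of how the updated Space could be called programmatically with gradio_client once deployed. The Space id below is a placeholder (the commit only names the model repos), and the "/predict" endpoint is assumed because the app exposes a single gr.Interface, which registers that endpoint by default.

from gradio_client import Client

# Hypothetical Space id; replace with the actual "user/space-name" of the deployed app
client = Client("saadkhi/SQL_Chatbot")

# gr.Interface exposes its function under /predict by default
result = client.predict("Find duplicate emails in users table", api_name="/predict")
print(result)

On a ZeroGPU Space, the first call after a cold start may take noticeably longer while the model loads on CPU and a GPU slice is allocated for generation.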