saadkhi committed
Commit 0fad5f5 · verified
1 Parent(s): 81f0e97

Update app.py

Files changed (1)
  1. app.py +31 -30
app.py CHANGED
@@ -1,58 +1,56 @@
-# app.py - CPU SAFE VERSION (No CUDA, No GPU)
+# app.py - ZeroGPU safe: no caching + CPU load + GPU only in inference
 
 import torch
 import gradio as gr
+import spaces
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from peft import PeftModel
 
-# ─────────────────────────────────────────────
+# ────────────────────────────────────────────────────────────────
 BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
-LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
+LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
 
 MAX_NEW_TOKENS = 180
-TEMPERATURE = 0.0
-DO_SAMPLE = False
+TEMPERATURE = 0.0
+DO_SAMPLE = False
 
-print("Loading model on CPU...")
-
-# 4-bit config (works on CPU but slower)
+print("Loading quantized base model on CPU (GPU only during inference)...")
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
     bnb_4bit_compute_dtype=torch.bfloat16
 )
 
-# Load base model on CPU
 model = AutoModelForCausalLM.from_pretrained(
     BASE_MODEL,
     quantization_config=bnb_config,
-    device_map="cpu",
+    device_map="cpu",  # ← Force CPU load at startup
     trust_remote_code=True
 )
 
-print("Loading LoRA...")
+print("Loading & merging LoRA...")
 model = PeftModel.from_pretrained(model, LORA_PATH)
-
-# Merge LoRA for simpler inference
-model = model.merge_and_unload()
+model = model.merge_and_unload()  # Merge once for speed
 
 tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
 model.eval()
 
-# ─────────────────────────────────────────────
+# ────────────────────────────────────────────────────────────────
+@spaces.GPU(duration=60)  # Requests GPU slice only here
 def generate_sql(prompt: str):
     messages = [{"role": "user", "content": prompt}]
-
-    # Tokenize (CPU)
+
+    # Tokenize on CPU
     inputs = tokenizer.apply_chat_template(
         messages,
         tokenize=True,
         add_generation_prompt=True,
         return_tensors="pt"
     )
-
-    input_length = inputs.shape[-1]  # length of prompt tokens
-
+
+    # Move to GPU only now (GPU is allocated)
+    inputs = inputs.to("cuda")
+
     with torch.inference_mode():
         outputs = model.generate(
             input_ids=inputs,
@@ -63,31 +61,34 @@ def generate_sql(prompt: str):
             pad_token_id=tokenizer.eos_token_id,
         )
 
-    # 🔑 Remove the prompt tokens from the output
-    generated_tokens = outputs[0][input_length:]
-
-    response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Clean up output
+    if "<|assistant|>" in response:
+        response = response.split("<|assistant|>", 1)[-1].strip()
+    if "<|end|>" in response:
+        response = response.split("<|end|>")[0].strip()
 
     return response
 
-# ─────────────────────────────────────────────
+# ────────────────────────────────────────────────────────────────
 demo = gr.Interface(
     fn=generate_sql,
     inputs=gr.Textbox(
-        label="Ask SQL question",
+        label="Ask an SQL question",
         placeholder="Delete duplicate rows from users table based on email",
         lines=3
     ),
     outputs=gr.Textbox(label="Generated SQL"),
-    title="SQL Chatbot (CPU Mode)",
-    description="Phi-3-mini 4bit + LoRA (CPU only, slower inference)",
+    title="SQL Chatbot (ZeroGPU)",
+    description="Phi-3-mini 4bit + LoRA - GPU only during generation",
     examples=[
         ["Find duplicate emails in users table"],
         ["Top 5 highest paid employees"],
         ["Count orders per customer last month"]
     ],
-    cache_examples=False
+    cache_examples=False  # ← This is critical! Prevents startup crash
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
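
For reference, a minimal sketch of the ZeroGPU pattern the updated app.py relies on: load everything on CPU at startup and request a GPU slice only inside the function decorated with @spaces.GPU. The model id, prompt handling, and generation settings below are placeholders rather than this Space's actual configuration, and the sketch skips 4-bit quantization so the weights can simply be moved with .to("cuda").

# zero_gpu_sketch.py - illustrative only; assumes a Hugging Face ZeroGPU Space
# with the `spaces` package available.
import torch
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"  # placeholder model id

# Startup runs without a GPU on ZeroGPU Spaces, so everything loads on CPU here.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
model.eval()

@spaces.GPU(duration=60)  # a GPU is attached only while this function runs
def generate(prompt: str) -> str:
    # Move weights and inputs onto the GPU that was just allocated
    # (re-done on each call for simplicity).
    model.to("cuda")
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.inference_mode():
        out = model.generate(**inputs, max_new_tokens=64)
    return tokenizer.decode(out[0], skip_special_tokens=True)

Wrapping such a function in gr.Interface with cache_examples=False, as the diff does, keeps Gradio from calling the GPU-decorated function to pre-compute example outputs while the Space is still starting up.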