JoeStrout committed on
Commit
9a1939c
Β·
verified Β·
1 Parent(s): 2a35216

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +234 -0
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Gradio Space: MiniScript Code Helper (LoRA + RAG).
3
+
4
+ Loads the fine-tuned Qwen2.5-Coder-7B-Instruct LoRA adapter and a ChromaDB
5
+ vector index built from MiniScript documentation, then serves a chat interface.
6
+ """
7
+
8
+ import os
9
+ import re
10
+
11
+ os.environ.setdefault("USE_TF", "0")
12
+
13
+ import chromadb
14
+ import gradio as gr
15
+ import torch
16
+ from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
17
+ from peft import PeftModel
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
19
+
20
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct"  # backbone LLM the LoRA was trained on
ADAPTER_REPO = "JoeStrout/miniscript-code-helper-lora"  # LoRA adapter + tokenizer repo
RAG_DIR = "./RAG_sources"  # directory of .md/.txt documentation files to index
DB_DIR = "./chroma_db"  # ChromaDB persistence directory
COLLECTION = "miniscript_docs"  # ChromaDB collection name
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # sentence-transformers model for embeddings
TOP_K = 5  # number of documentation chunks retrieved per user query
MAX_NEW_TOKENS = 1024  # generation budget per chat reply
MAX_CHUNK_CHARS = 1500  # target maximum size of one indexed chunk

BASE_SYSTEM_PROMPT = "You are a helpful assistant specializing in MiniScript programming."
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # RAG index builder (inline so app is self-contained)
38
+ # ---------------------------------------------------------------------------
39
+
40
def strip_leanpub(text: str) -> str:
    """Strip Leanpub markup from *text*, keeping only readable content.

    Directive lines ({chapterHead...}, {width...}, {pagebreak}, etc.) are
    dropped, except that a {caption: "..."} directive is preserved as a
    bracketed placeholder. Markdown image lines are removed entirely, and
    Leanpub block-quote prefixes (Q>, A>, D>, X>) are stripped from the
    start of lines.
    """
    out = []
    for raw in text.splitlines():
        # Leanpub directive line: keep only an embedded caption, if any.
        if re.match(r'^\s*\{(chapterHead|width|i:|caption|pagebreak|startingPageNum)', raw):
            caption = re.search(r'\{caption:\s*"([^"]+)"\}', raw)
            if caption:
                out.append(f"[{caption.group(1)}]")
            continue
        # Drop markdown image references — useless without the image itself.
        if re.match(r'^\s*!\[.*\]\(.*\)\s*$', raw):
            continue
        # Remove a leading Leanpub quote marker (e.g. "Q> ").
        out.append(re.sub(r'^([QADX])>\s?', '', raw))
    return '\n'.join(out)
54
+
55
+
56
def split_long_chunk(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list:
    """Split *text* into chunks of at most *max_chars* characters.

    Splitting happens on paragraph boundaries (blank lines) where possible.
    Fix over the previous version: a single paragraph longer than
    *max_chars* is now hard-sliced, so no returned chunk ever exceeds the
    limit (previously an oversize paragraph was emitted as-is).

    Returns a list of non-empty, stripped chunk strings; text that already
    fits the budget is returned unchanged as a single-element list.
    """
    if len(text) <= max_chars:
        return [text]
    chunks, current = [], ""
    for para in re.split(r'\n\n+', text):
        # Hard-split any paragraph that alone exceeds the budget.
        while len(para) > max_chars:
            if current.strip():
                chunks.append(current.strip())
                current = ""
            chunks.append(para[:max_chars].strip())
            para = para[max_chars:]
        if not para:
            continue
        if current and len(current) + len(para) + 2 > max_chars:
            # Adding this paragraph would overflow: flush and start fresh.
            chunks.append(current.strip())
            current = para
        else:
            current = current + "\n\n" + para if current else para
    if current.strip():
        chunks.append(current.strip())
    return chunks
70
+
71
+
72
def chunk_document(text: str, filename: str) -> list:
    """Split one documentation file into section-aware chunks.

    Markdown-style headings (# through ####) start a new section; the text
    accumulated under each heading is further divided by split_long_chunk().
    For .txt files the Leanpub markup is stripped first. Returns a list of
    dicts with "text", "source", and "section" keys; the section defaults
    to the filename until the first heading is seen.
    """
    is_txt = filename.endswith('.txt')
    if is_txt:
        text = strip_leanpub(text)

    chunks = []
    section = filename
    buffer = []

    def emit():
        # Flush the accumulated section body as one or more sized chunks.
        body = '\n'.join(buffer).strip()
        if not body:
            return
        for piece in split_long_chunk(body):
            if piece.strip():
                chunks.append({"text": piece, "source": filename, "section": section})

    for line in text.splitlines():
        heading = None
        if is_txt:
            m = re.match(r'^(#{1,4})\s+(.*)', line)
            if m:
                heading = m.group(2).strip()
        elif re.match(r'^#{1,4}\s', line):
            heading = re.sub(r'^#+\s*', '', line).strip()

        if heading:
            # New section: close out the previous one first.
            emit()
            section = heading
            buffer = []
        else:
            buffer.append(line)

    emit()
    return chunks
103
+
104
+
105
def build_rag_index():
    """Build (or reuse) the persistent ChromaDB index over RAG_DIR.

    Every .md/.txt file in RAG_DIR is chunked via chunk_document() and
    indexed with sentence-transformer embeddings under cosine distance.
    When the collection already exists on disk it is reused untouched,
    which makes restarts fast. Returns the ChromaDB collection object.
    """
    print(f"Building ChromaDB index from {RAG_DIR}/ ...")
    embed = SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL)
    client = chromadb.PersistentClient(path=DB_DIR)

    # Reuse a previously built index when one is already persisted.
    if any(c.name == COLLECTION for c in client.list_collections()):
        col = client.get_collection(name=COLLECTION, embedding_function=embed)
        print(f" Reusing existing collection ({col.count()} chunks)")
        return col

    col = client.create_collection(
        name=COLLECTION,
        embedding_function=embed,
        metadata={"hnsw:space": "cosine"},  # cosine similarity for text
    )

    all_chunks = []
    for fname in sorted(f for f in os.listdir(RAG_DIR) if f.endswith(('.md', '.txt'))):
        with open(os.path.join(RAG_DIR, fname), encoding='utf-8') as fh:
            doc_chunks = chunk_document(fh.read(), fname)
        print(f" {fname}: {len(doc_chunks)} chunks")
        all_chunks.extend(doc_chunks)

    # Insert in batches to keep each ChromaDB call small.
    batch_size = 100
    for start in range(0, len(all_chunks), batch_size):
        batch = all_chunks[start:start + batch_size]
        col.add(
            ids=[f"chunk_{start + j}" for j in range(len(batch))],
            documents=[c["text"] for c in batch],
            metadatas=[{"source": c["source"], "section": c["section"]} for c in batch],
        )
    print(f" Indexed {col.count()} chunks total.")
    return col
140
+
141
+
142
+ # ---------------------------------------------------------------------------
143
+ # Model loading
144
+ # ---------------------------------------------------------------------------
145
+
146
def load_model():
    """Load the 4-bit quantized base model and attach the LoRA adapter.

    The tokenizer comes from the adapter repo (it may carry chat-template
    tweaks from fine-tuning). Returns a (tokenizer, model) pair with the
    model already switched to eval mode.
    """
    print(f"Loading tokenizer from {ADAPTER_REPO} ...")
    tok = AutoTokenizer.from_pretrained(ADAPTER_REPO)

    print(f"Loading base model {BASE_MODEL} in 4-bit ...")
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",  # NF4 quantization scheme
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quant_cfg,
        device_map="auto",  # let accelerate place layers on available devices
    )

    print(f"Loading LoRA adapter from {ADAPTER_REPO} ...")
    lora = PeftModel.from_pretrained(base, ADAPTER_REPO)
    lora.eval()  # inference only — disables dropout etc.
    print("Model ready!")
    return tok, lora
167
+
168
+
169
+ # ---------------------------------------------------------------------------
170
+ # Startup
171
+ # ---------------------------------------------------------------------------
172
+
173
# Build the retrieval index and load the model once at import time, so the
# Gradio worker has both ready before the first request arrives.
collection = build_rag_index()
tokenizer, model = load_model()
175
+
176
+
177
+ # ---------------------------------------------------------------------------
178
+ # Chat logic
179
+ # ---------------------------------------------------------------------------
180
+
181
def build_system_prompt(results: dict) -> str:
    """Compose the system prompt, embedding retrieved documentation.

    *results* is a ChromaDB query result dict. When it holds no documents
    the bare BASE_SYSTEM_PROMPT is returned; otherwise each retrieved chunk
    is appended with a [Source: ..., Section: ...] attribution header.
    """
    docs = results["documents"] if results else None
    if not docs or not docs[0]:
        return BASE_SYSTEM_PROMPT
    blocks = [
        f"[Source: {meta['source']}, Section: {meta['section']}]\n{doc}"
        for doc, meta in zip(docs[0], results["metadatas"][0])
    ]
    return (
        f"{BASE_SYSTEM_PROMPT}\n\n"
        f"Use the following reference material to help answer the user's question:\n\n"
        + "\n\n".join(blocks)
    )
193
+
194
+
195
def chat(message: str, history: list) -> str:
    """Answer one chat turn: retrieve docs, build the prompt, generate.

    Fix: *history* may arrive either as (user, assistant) tuples (gradio's
    classic "tuples" format) or as {"role": ..., "content": ...} dicts (the
    newer "messages" format); both are now normalized, where previously the
    dict form crashed on tuple unpacking. Returns the assistant reply text.
    """
    # Retrieve the TOP_K most relevant documentation chunks for this message.
    results = collection.query(query_texts=[message], n_results=TOP_K)
    system_prompt = build_system_prompt(results)

    messages = [{"role": "system", "content": system_prompt}]
    for turn in history:
        if isinstance(turn, dict):
            # "messages" format: already a role/content dict.
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            # "tuples" format: a (user_msg, assistant_msg) pair.
            user_msg, assistant_msg = turn
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # Decode only the newly generated tokens, skipping the echoed prompt.
    prompt_len = len(inputs.input_ids[0])
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
211
+
212
+
213
+ # ---------------------------------------------------------------------------
214
+ # Gradio UI
215
+ # ---------------------------------------------------------------------------
216
+
217
# Chat UI: gr.ChatInterface wires the chat() callback to a web front end.
demo = gr.ChatInterface(
    fn=chat,
    title="MiniScript Code Helper",
    description=(
        "Ask questions about the [MiniScript](https://miniscript.org) programming language. "
        "Powered by a fine-tuned Qwen2.5-Coder-7B-Instruct model with RAG over MiniScript documentation."
    ),
    examples=[
        "How do I define a function in MiniScript?",
        "How do I iterate over a list?",
        "What is the difference between `and` and `&&` in MiniScript?",
        "How do I read a file in MiniScript?",
    ],
    # Caching examples would run full model generation at startup — skip it.
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()