import inspect
import os
import threading
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# TODO: three changes to make
# 1. Remove the Sample reasoning and Sample answer boxes
# 2. Raise MAX_TOKEN so generation runs to completion, then take a screenshot showing both the reasoning and the assistant output
# 3. (Optional) Also edit the system prompt / your message and run with a new prompt-message combination
# Model and generation budget, all overridable via environment variables.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen3-0.6B")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "2048"))  # raised from 256 so replies finish
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "1536"))  # left-truncation budget for the prompt
MAX_HISTORY_TURNS = int(os.getenv("MAX_HISTORY_TURNS", "3"))  # user/assistant pairs kept per request
# Default to every available core for CPU intra-op parallelism.
N_THREADS = int(os.getenv("N_THREADS", str(max(1, os.cpu_count() or 1))))
# Fix: the previous default ended with the dangling fragment "If user",
# which was sent to the model verbatim; the sentence is now complete.
DEFAULT_SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "You are a helpful assistant. Keep answers clear and concise. "
    "If unsure, say so instead of guessing.",
)
# Demo presets selectable from the UI dropdown. Each entry provides the
# system prompt, a sample user prompt, and whether "thinking" mode starts
# enabled. The sample_reasoning / sample_answer fields are canned example
# texts; their display boxes are currently removed from the UI, so these
# two fields are presently unused by the interface.
PRESETS = {
    "Math": {
        "system": "You are a careful math tutor. Think through the problem, then give a short final answer.",
        "prompt": "Solve: If 2x^2 - 7x + 3 = 0, what are the real solutions?",
        "thinking": True,
        "sample_reasoning": "The discriminant is 49 - 24 = 25, so the roots are easy to compute with the quadratic formula.",
        "sample_answer": "The real solutions are x = 3 and x = 1/2.",
    },
    "Coding": {
        "system": "You are a Python assistant. Prefer short, readable code.",
        "prompt": "Write a Python function that merges two sorted lists into one sorted list.",
        "thinking": True,
        "sample_reasoning": "Use two pointers. Compare the current elements, append the smaller one, then append the leftovers.",
        "sample_answer": "Here is a compact merge function plus a tiny example.",
    },
    "Structured output": {
        "system": "Return compact JSON and avoid extra commentary.",
        "prompt": "Extract JSON from: Call Mina by Friday, priority high, budget about $2400, topic is launch video edits.",
        "thinking": False,
        "sample_reasoning": "Reasoning is disabled here so the output stays short and machine-friendly.",
        "sample_answer": '{"person":"Mina","deadline":"Friday","priority":"high","budget_usd":2400,"topic":"launch video edits"}',
    },
    "Function calling style": {
        "system": "You are an assistant that plans tool use when it helps. If a tool would help, say what tool you would call and with which arguments.",
        "prompt": "Pretend you have tools. For 18.75 * 42 - 199 and converting 12 km to miles, explain which tool calls you would make, then give the result.",
        "thinking": True,
        "sample_reasoning": "I would use a calculator tool for the arithmetic and a unit-conversion tool for the distance conversion.",
        "sample_answer": "Calculator(18.75 * 42 - 199) -> 588.5\nConvert(12 km -> miles) -> about 7.46 miles",
    },
    "Creative writing": {
        "system": "Write vivid, tight prose.",
        "prompt": "Write a two-sentence opening for a sci-fi heist story set on a drifting museum ship.",
        "thinking": False,
        "sample_reasoning": "Reasoning is disabled for a faster clean draft.",
        "sample_answer": "By the time the museum ship crossed into the dead zone, every priceless relic aboard had started broadcasting a heartbeat. Nia took that as her cue to cut the lights and steal the one artifact already trying to escape.",
    },
}
# Configure torch threading for CPU inference.
torch.set_num_threads(N_THREADS)
try:
    # Inter-op thread count may only be set once per process; if it was
    # already configured (e.g. on module reload), torch raises RuntimeError,
    # which we deliberately ignore.
    torch.set_num_interop_threads(max(1, min(2, N_THREADS)))
except RuntimeError:
    pass
# Lazily-initialized tokenizer/model singletons (see get_model), plus locks:
# _load_lock guards one-time loading, _generate_lock serializes generate calls.
_tokenizer = None
_model = None
_load_lock = threading.Lock()
_generate_lock = threading.Lock()
def make_chatbot(label, height=520):
    """Build a gr.Chatbot, requesting the "messages" format when supported.

    Older Gradio releases do not accept a ``type`` parameter, so it is only
    passed when it appears in ``gr.Chatbot.__init__``'s signature.
    """
    accepted = inspect.signature(gr.Chatbot.__init__).parameters
    if "type" in accepted:
        return gr.Chatbot(label=label, height=height, type="messages")
    return gr.Chatbot(label=label, height=height)
def get_model():
    """Return the (tokenizer, model) pair, loading them on first use.

    Uses double-checked locking on ``_load_lock`` so concurrent first
    requests trigger only a single download/load. The model is loaded in
    float32 (CPU) and switched to eval mode.
    """
    global _tokenizer, _model
    if _model is None or _tokenizer is None:
        with _load_lock:
            # Re-check inside the lock: another thread may have just loaded.
            if _model is None or _tokenizer is None:
                _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
                _model = AutoModelForCausalLM.from_pretrained(
                    MODEL_ID,
                    torch_dtype=torch.float32,
                )
                _model.eval()
    return _tokenizer, _model
def clone_messages(messages):
    """Return a per-item shallow copy of a chat message list.

    Each message dict is copied so later in-place edits to the result do not
    mutate the original history; a None/empty input yields an empty list.
    """
    copied = []
    for entry in messages or []:
        copied.append(dict(entry))
    return copied
def load_preset(name):
    """Return (system prompt, user prompt, thinking flag) for a named preset.

    Fix: this used to also return the preset's sample_reasoning and
    sample_answer texts (five values), but the dropdown's change handler
    wires only three outputs (system_prompt, user_input, thinking) since the
    sample boxes were removed from the UI — a five-value return makes
    Gradio's output count mismatch. Only the three wired values are returned.

    Raises KeyError if ``name`` is not a PRESETS key.
    """
    preset = PRESETS[name]
    return preset["system"], preset["prompt"], preset["thinking"]
def clear_all():
    """Reset both chat panes, the model history state, and the input box."""
    fresh_reasoning = []
    fresh_answers = []
    fresh_history = []
    return fresh_reasoning, fresh_answers, fresh_history, ""
def strip_non_think_specials(text):
    """Drop end-of-text/turn marker tokens while keeping <think> tags intact.

    A None input is treated as the empty string.
    """
    cleaned = text or ""
    cleaned = cleaned.replace("<|im_end|>", "")
    cleaned = cleaned.replace("<|endoftext|>", "")
    cleaned = cleaned.replace("<|end▁of▁sentence|>", "")
    return cleaned
def final_cleanup(text):
    """Remove marker tokens and <think>/</think> tags, then trim whitespace."""
    remaining = strip_non_think_specials(text)
    for tag in ("<think>", "</think>"):
        remaining = remaining.replace(tag, "")
    return remaining.strip()
def split_stream_text(raw_text, thinking):
    """Split streamed model text into (reasoning, answer, saw_end_of_think).

    With thinking disabled, everything is cleaned answer text. Otherwise the
    text before a closing </think> tag is reasoning and the remainder is the
    answer; until that tag appears, all streamed text counts as reasoning.
    """
    cleaned = strip_non_think_specials(raw_text)
    if not thinking:
        return "", final_cleanup(cleaned), False
    cleaned = cleaned.replace("<think>", "")
    closer = "</think>"
    if closer not in cleaned:
        return cleaned.strip(), "", False
    reasoning_part, answer_part = cleaned.split(closer, 1)
    return reasoning_part.strip(), answer_part.strip(), True
def respond_stream(
    message,
    system_prompt,
    thinking,
    model_history,
    reasoning_chat,
    answer_chat,
):
    """Stream one model reply as a Gradio generator handler.

    Yields tuples of (reasoning_chat, answer_chat, model_history, input_box)
    matching the wired outputs. The user message is appended to both chat
    panes, ``model.generate`` runs on a background thread feeding a
    TextIteratorStreamer, and the panes are re-yielded as text arrives. When
    ``thinking`` is enabled, streamed text before </think> goes to the
    reasoning pane and the remainder to the answer pane.
    """
    message = (message or "").strip()
    if not message:
        # Nothing to send: echo current state unchanged, clear the input box.
        yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history or []), ""
        return
    # Work on copies so Gradio state/pane inputs are never mutated in place.
    model_history = list(model_history or [])
    reasoning_chat = clone_messages(reasoning_chat)
    answer_chat = clone_messages(answer_chat)
    reasoning_chat.append({"role": "user", "content": message})
    reasoning_chat.append(
        {
            "role": "assistant",
            "content": "(thinking...)" if thinking else "(reasoning disabled)",
        }
    )
    answer_chat.append({"role": "user", "content": message})
    answer_chat.append({"role": "assistant", "content": ""})
    # First yield shows the user message and placeholders immediately.
    yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
    try:
        tokenizer, model = get_model()
        # Keep only the last MAX_HISTORY_TURNS user/assistant pairs.
        short_history = model_history[-2 * MAX_HISTORY_TURNS :]
        messages = [
            {"role": "system", "content": (system_prompt or "").strip() or DEFAULT_SYSTEM_PROMPT},
            *short_history,
            {"role": "user", "content": message},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=thinking,
        )
        inputs = tokenizer(prompt, return_tensors="pt")
        # Left-truncate so the most recent context survives the token budget.
        input_ids = inputs["input_ids"][:, -MAX_INPUT_TOKENS:]
        attention_mask = inputs["attention_mask"][:, -MAX_INPUT_TOKENS:]
        # skip_special_tokens=False so <think>/</think> reach split_stream_text.
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
            timeout=None,
        )
        generation_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": MAX_NEW_TOKENS,
            "do_sample": True,
            # Different sampling per mode — presumably following the model's
            # recommended thinking/non-thinking settings; TODO confirm against
            # the Qwen3 model card.
            "temperature": 0.6 if thinking else 0.7,
            "top_p": 0.95 if thinking else 0.8,
            "top_k": 20,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer,
        }
        generation_error = {}
        def run_generation():
            # Worker thread; _generate_lock serializes concurrent generate calls.
            try:
                with _generate_lock:
                    model.generate(**generation_kwargs)
            except Exception as exc:
                generation_error["message"] = str(exc)
                # Signal end-of-stream so the consuming loop below unblocks.
                streamer.on_finalized_text("", stream_end=True)
        thread = threading.Thread(target=run_generation, daemon=True)
        thread.start()
        raw_text = ""
        saw_end_think = False
        for chunk in streamer:
            raw_text += chunk
            reasoning_text, answer_text, saw_end_now = split_stream_text(raw_text, thinking)
            # Latch: once </think> is seen, stop showing the "(thinking...)" filler.
            saw_end_think = saw_end_think or saw_end_now
            if thinking:
                if saw_end_think:
                    reasoning_chat[-1]["content"] = reasoning_text or "(no reasoning text returned)"
                else:
                    reasoning_chat[-1]["content"] = reasoning_text or "(thinking...)"
            else:
                reasoning_chat[-1]["content"] = "(reasoning disabled)"
            answer_chat[-1]["content"] = answer_text
            yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
        thread.join()
        if generation_error:
            reasoning_chat[-1]["content"] = ""
            answer_chat[-1]["content"] = f"Error while running the local CPU model: {generation_error['message']}"
            yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
            return
        # Final split over the complete generated text.
        reasoning_text, answer_text, saw_end_think = split_stream_text(raw_text, thinking)
        if thinking and not saw_end_think:
            # Generation ended without </think>: treat everything as the answer.
            reasoning_text = ""
            answer_text = final_cleanup(raw_text)
        if thinking:
            reasoning_chat[-1]["content"] = reasoning_text or "(no reasoning text returned)"
        else:
            reasoning_chat[-1]["content"] = "(reasoning disabled)"
        answer_chat[-1]["content"] = answer_text or "(empty response)"
        # Persist only the trimmed history plus this turn (answer text only —
        # reasoning is never fed back to the model).
        model_history = short_history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": answer_chat[-1]["content"]},
        ]
        yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
    except Exception as exc:
        # Setup/tokenization failure path (model load, template, etc.).
        reasoning_chat[-1]["content"] = ""
        answer_chat[-1]["content"] = f"Error while preparing the local CPU model: {exc}"
        yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
# --- Gradio UI wiring (top-level script) ---
with gr.Blocks(title="Local CPU split-reasoning chat") as demo:
    gr.Markdown(
        "# Local CPU split-reasoning chat\n"
        f"Running a local safetensors model on CPU from `{MODEL_ID}`. No GGUF and no external inference provider.\n\n"
        "The first request downloads the model, so the cold start is slower."
    )
    with gr.Row():
        preset = gr.Dropdown(
            choices=list(PRESETS.keys()),
            value="Math",
            label="Preset prompt",
        )
        thinking = gr.Checkbox(label="Enable thinking", value=True)
    system_prompt = gr.Textbox(
        label="System prompt",
        value=PRESETS["Math"]["system"],
        lines=3,
    )
    user_input = gr.Textbox(
        label="Your message",
        value=PRESETS["Math"]["prompt"],
        lines=4,
    )
    # The "Sample reasoning" / "Sample answer" display boxes were removed per
    # the TODO at the top of the file (dead commented-out construction deleted).
    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        # Two side-by-side panes: model reasoning and the final assistant answer.
        reasoning_bot = make_chatbot("Reasoning", height=520)
        answer_bot = make_chatbot("Assistant", height=520)
    # Conversation history actually sent to the model (trimmed every turn).
    model_history_state = gr.State([])
    preset.change(
        fn=load_preset,
        inputs=preset,
        outputs=[system_prompt, user_input, thinking],
    )
    send_btn.click(
        fn=respond_stream,
        inputs=[user_input, system_prompt, thinking, model_history_state, reasoning_bot, answer_bot],
        outputs=[reasoning_bot, answer_bot, model_history_state, user_input],
    )
    # Pressing Enter in the textbox behaves like the Send button.
    user_input.submit(
        fn=respond_stream,
        inputs=[user_input, system_prompt, thinking, model_history_state, reasoning_bot, answer_bot],
        outputs=[reasoning_bot, answer_bot, model_history_state, user_input],
    )
    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=[reasoning_bot, answer_bot, model_history_state, user_input],
    )
# queue() is required for streaming generator handlers; launch() starts the app.
demo.queue()
demo.launch()