import inspect
import os
import threading
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Silence the HF tokenizers fork-parallelism warning before any model loads.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# TODO notes (translated from Korean):
# 1. Remove the Sample reasoning and Sample answer boxes.
# 2. Raise MAX_NEW_TOKENS so generation runs to completion, then screenshot
#    both the reasoning and the assistant output.
# 3. (Optional) also change the system prompt / user message and rerun with a
#    new prompt-message combination.

# Runtime configuration, each overridable via environment variables.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen3-0.6B")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "2048"))  # raised from 256 per TODO #2
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "1536"))
MAX_HISTORY_TURNS = int(os.getenv("MAX_HISTORY_TURNS", "3"))
N_THREADS = int(os.getenv("N_THREADS", str(max(1, os.cpu_count() or 1))))
# NOTE(review): the previous default ended mid-sentence ("... If user"),
# which sent a truncated instruction to the model; completed here.
DEFAULT_SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "You are a helpful assistant. Keep answers clear and concise.",
)
# Demo presets selectable from the dropdown. Each entry provides a system
# prompt, a user message, and whether Qwen3 "thinking" mode should be on.
# NOTE(review): sample_reasoning / sample_answer are kept for reference, but
# the UI boxes that displayed them are commented out below.
PRESETS = {
    "Math": {
        "system": "You are a careful math tutor. Think through the problem, then give a short final answer.",
        "prompt": "Solve: If 2x^2 - 7x + 3 = 0, what are the real solutions?",
        "thinking": True,
        "sample_reasoning": "The discriminant is 49 - 24 = 25, so the roots are easy to compute with the quadratic formula.",
        "sample_answer": "The real solutions are x = 3 and x = 1/2.",
    },
    "Coding": {
        "system": "You are a Python assistant. Prefer short, readable code.",
        "prompt": "Write a Python function that merges two sorted lists into one sorted list.",
        "thinking": True,
        "sample_reasoning": "Use two pointers. Compare the current elements, append the smaller one, then append the leftovers.",
        "sample_answer": "Here is a compact merge function plus a tiny example.",
    },
    "Structured output": {
        "system": "Return compact JSON and avoid extra commentary.",
        "prompt": "Extract JSON from: Call Mina by Friday, priority high, budget about $2400, topic is launch video edits.",
        "thinking": False,
        "sample_reasoning": "Reasoning is disabled here so the output stays short and machine-friendly.",
        "sample_answer": '{"person":"Mina","deadline":"Friday","priority":"high","budget_usd":2400,"topic":"launch video edits"}',
    },
    "Function calling style": {
        "system": "You are an assistant that plans tool use when it helps. If a tool would help, say what tool you would call and with which arguments.",
        "prompt": "Pretend you have tools. For 18.75 * 42 - 199 and converting 12 km to miles, explain which tool calls you would make, then give the result.",
        "thinking": True,
        "sample_reasoning": "I would use a calculator tool for the arithmetic and a unit-conversion tool for the distance conversion.",
        "sample_answer": "Calculator(18.75 * 42 - 199) -> 588.5\nConvert(12 km -> miles) -> about 7.46 miles",
    },
    "Creative writing": {
        "system": "Write vivid, tight prose.",
        "prompt": "Write a two-sentence opening for a sci-fi heist story set on a drifting museum ship.",
        "thinking": False,
        "sample_reasoning": "Reasoning is disabled for a faster clean draft.",
        "sample_answer": "By the time the museum ship crossed into the dead zone, every priceless relic aboard had started broadcasting a heartbeat. Nia took that as her cue to cut the lights and steal the one artifact already trying to escape.",
    },
}
# Keep PyTorch's intra-op CPU threads within the configured budget.
torch.set_num_threads(N_THREADS)
try:
    # Interop threads can only be set before parallel work starts; some builds
    # raise RuntimeError if it is too late, so this is best-effort.
    torch.set_num_interop_threads(max(1, min(2, N_THREADS)))
except RuntimeError:
    pass
# Lazily-loaded singletons; _load_lock guards first-time loading (see
# get_model), _generate_lock ensures only one model.generate runs at a time.
_tokenizer = None
_model = None
_load_lock = threading.Lock()
_generate_lock = threading.Lock()
def make_chatbot(label, height=520):
    """Build a Chatbot component, requesting the "messages" format when available.

    Older Gradio releases have no ``type`` parameter on ``gr.Chatbot``, so it
    is passed only if the constructor signature accepts it.
    """
    params = inspect.signature(gr.Chatbot.__init__).parameters
    options = dict(label=label, height=height)
    if "type" in params:
        options["type"] = "messages"
    return gr.Chatbot(**options)
def get_model():
    """Return the cached (tokenizer, model) pair, loading them on first use.

    Uses double-checked locking so that concurrent first requests trigger a
    single download/load. The model is CPU float32 and set to eval mode.
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    with _load_lock:
        # Re-check under the lock: another thread may have finished loading.
        if _tokenizer is None or _model is None:
            _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
            _model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,
            )
            _model.eval()
    return _tokenizer, _model
def clone_messages(messages):
    """Return shallow copies of each message dict (empty list for falsy input).

    Gradio chat state is mutated in place while streaming, so copies keep the
    yielded snapshots independent of later mutation.
    """
    source = messages or []
    return [dict(entry) for entry in source]
def load_preset(name):
    """Return (system prompt, user message, thinking flag) for a preset name.

    Fix: this previously returned five values (including sample_reasoning and
    sample_answer), but the preset.change handler wires only three output
    components, so Gradio rejected the extra return values at runtime. The
    sample fields remain in PRESETS but are no longer surfaced.
    """
    preset = PRESETS[name]
    return (
        preset["system"],
        preset["prompt"],
        preset["thinking"],
    )
def clear_all():
    """Reset both chat panes, the model-history state, and the input textbox."""
    empty_reasoning = []
    empty_answers = []
    empty_history = []
    return empty_reasoning, empty_answers, empty_history, ""
def strip_non_think_specials(text):
    """Remove end-of-text special tokens while leaving <think> markers intact.

    None is treated as the empty string.
    """
    text = text or ""
    for token in ["<|im_end|>", "<|endoftext|>", "<|end▁of▁sentence|>"]:
        text = text.replace(token, "")
    return text


def final_cleanup(text):
    """Strip special tokens and the <think>/</think> delimiters, then trim.

    Fix: the tag literals had been lost (empty strings), making the replace
    calls no-ops; restored to Qwen3's reasoning delimiters.
    """
    text = strip_non_think_specials(text)
    text = text.replace("<think>", "").replace("</think>", "")
    return text.strip()


def split_stream_text(raw_text, thinking):
    """Split a partially streamed completion into (reasoning, answer, saw_end).

    Returns the reasoning text so far, the answer text so far, and whether the
    closing </think> tag has appeared. With thinking disabled, everything is
    answer text.

    Fix: the <think>/</think> literals had been lost (empty strings), so
    ``"" in raw_text`` was always True and ``split("", 1)`` raised
    ``ValueError: empty separator``; restored to Qwen3's delimiters.
    """
    raw_text = strip_non_think_specials(raw_text)
    if not thinking:
        return "", final_cleanup(raw_text), False
    # Qwen3 emits "<think>...reasoning...</think>answer"; the opening tag can
    # simply be dropped since everything before </think> is reasoning.
    raw_text = raw_text.replace("<think>", "")
    if "</think>" in raw_text:
        reasoning, answer = raw_text.split("</think>", 1)
        return reasoning.strip(), answer.strip(), True
    # Still inside the reasoning phase: no answer text yet.
    return raw_text.strip(), "", False
def respond_stream(
    message,
    system_prompt,
    thinking,
    model_history,
    reasoning_chat,
    answer_chat,
):
    """Stream a reply from the local model, splitting reasoning from the answer.

    Generator used as a Gradio event handler. Each yield emits
    (reasoning_chat, answer_chat, model_history, input_textbox_value); the
    empty final string clears the user's input box.

    Parameters:
        message: raw text from the input box (may be empty/None).
        system_prompt: system prompt text; falls back to DEFAULT_SYSTEM_PROMPT.
        thinking: whether Qwen3 thinking mode is enabled for this turn.
        model_history: list of {"role", "content"} dicts fed back to the model.
        reasoning_chat / answer_chat: message lists backing the two chat panes.
    """
    message = (message or "").strip()
    if not message:
        # Nothing to send: re-emit current state unchanged and stop.
        yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history or []), ""
        return
    # Copy all incoming state so in-place edits never alias Gradio's objects.
    model_history = list(model_history or [])
    reasoning_chat = clone_messages(reasoning_chat)
    answer_chat = clone_messages(answer_chat)
    # Show the user turn plus placeholder assistant turns immediately, before
    # the (potentially slow) model load and generation begin.
    reasoning_chat.append({"role": "user", "content": message})
    reasoning_chat.append(
        {
            "role": "assistant",
            "content": "(thinking...)" if thinking else "(reasoning disabled)",
        }
    )
    answer_chat.append({"role": "user", "content": message})
    answer_chat.append({"role": "assistant", "content": ""})
    yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
    try:
        tokenizer, model = get_model()
        # Keep only the last MAX_HISTORY_TURNS user/assistant pairs.
        short_history = model_history[-2 * MAX_HISTORY_TURNS :]
        messages = [
            {"role": "system", "content": (system_prompt or "").strip() or DEFAULT_SYSTEM_PROMPT},
            *short_history,
            {"role": "user", "content": message},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=thinking,
        )
        inputs = tokenizer(prompt, return_tensors="pt")
        # Hard-truncate from the left so the prompt fits the input budget.
        # NOTE(review): truncating token ids after templating can cut the
        # leading special/system tokens for very long histories — confirm this
        # is acceptable for the model.
        input_ids = inputs["input_ids"][:, -MAX_INPUT_TOKENS:]
        attention_mask = inputs["attention_mask"][:, -MAX_INPUT_TOKENS:]
        # skip_special_tokens=False so the <think>/</think> markers survive
        # for split_stream_text to parse.
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
            timeout=None,
        )
        # Sampling parameters differ by mode (lower temperature while thinking).
        generation_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": MAX_NEW_TOKENS,
            "do_sample": True,
            "temperature": 0.6 if thinking else 0.7,
            "top_p": 0.95 if thinking else 0.8,
            "top_k": 20,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer,
        }
        generation_error = {}
        def run_generation():
            # Runs in a worker thread; _generate_lock serializes concurrent
            # requests since generation is stateful/CPU-bound.
            try:
                with _generate_lock:
                    model.generate(**generation_kwargs)
            except Exception as exc:
                generation_error["message"] = str(exc)
                # Push a terminal chunk so the consuming for-loop below wakes
                # up and exits instead of blocking forever.
                # NOTE(review): on_finalized_text is not a public streamer API
                # — may break across transformers versions.
                streamer.on_finalized_text("", stream_end=True)
        thread = threading.Thread(target=run_generation, daemon=True)
        thread.start()
        raw_text = ""
        saw_end_think = False
        # Re-split the accumulated text on every chunk and push UI updates.
        for chunk in streamer:
            raw_text += chunk
            reasoning_text, answer_text, saw_end_now = split_stream_text(raw_text, thinking)
            saw_end_think = saw_end_think or saw_end_now
            if thinking:
                if saw_end_think:
                    reasoning_chat[-1]["content"] = reasoning_text or "(no reasoning text returned)"
                else:
                    reasoning_chat[-1]["content"] = reasoning_text or "(thinking...)"
            else:
                reasoning_chat[-1]["content"] = "(reasoning disabled)"
            answer_chat[-1]["content"] = answer_text
            yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
        thread.join()
        if generation_error:
            reasoning_chat[-1]["content"] = ""
            answer_chat[-1]["content"] = f"Error while running the local CPU model: {generation_error['message']}"
            yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
            return
        # Final pass over the complete text (the stream may end mid-parse).
        reasoning_text, answer_text, saw_end_think = split_stream_text(raw_text, thinking)
        if thinking and not saw_end_think:
            # Generation stopped before </think>: treat everything as answer.
            reasoning_text = ""
            answer_text = final_cleanup(raw_text)
        if thinking:
            reasoning_chat[-1]["content"] = reasoning_text or "(no reasoning text returned)"
        else:
            reasoning_chat[-1]["content"] = "(reasoning disabled)"
        answer_chat[-1]["content"] = answer_text or "(empty response)"
        # Persist only the answer (never the reasoning) into model history.
        model_history = short_history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": answer_chat[-1]["content"]},
        ]
        yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
    except Exception as exc:
        # Surface load/templating failures in the answer pane instead of
        # crashing the Gradio handler.
        reasoning_chat[-1]["content"] = ""
        answer_chat[-1]["content"] = f"Error while preparing the local CPU model: {exc}"
        yield clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""
# UI layout and event wiring for the Gradio app.
with gr.Blocks(title="Local CPU split-reasoning chat") as demo:
    gr.Markdown(
        "# Local CPU split-reasoning chat\n"
        f"Running a local safetensors model on CPU from `{MODEL_ID}`. No GGUF and no external inference provider.\n\n"
        "The first request downloads the model, so the cold start is slower."
    )
    with gr.Row():
        preset = gr.Dropdown(
            choices=list(PRESETS.keys()),
            value="Math",
            label="Preset prompt",
        )
        thinking = gr.Checkbox(label="Enable thinking", value=True)
    system_prompt = gr.Textbox(
        label="System prompt",
        value=PRESETS["Math"]["system"],
        lines=3,
    )
    user_input = gr.Textbox(
        label="Your message",
        value=PRESETS["Math"]["prompt"],
        lines=4,
    )
    # The "Sample reasoning" / "Sample answer" display boxes were removed from
    # the UI (TODO #1); the sample text still lives in PRESETS.
    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")
    # Two side-by-side panes: streamed reasoning on the left, answer on the right.
    with gr.Row():
        reasoning_bot = make_chatbot("Reasoning", height=520)
        answer_bot = make_chatbot("Assistant", height=520)
    # Hidden state holding the conversation as fed back to the model.
    model_history_state = gr.State([])
    preset.change(
        fn=load_preset,
        inputs=preset,
        outputs=[system_prompt, user_input, thinking],
    )
    # Both the Send button and pressing Enter in the textbox trigger generation.
    send_btn.click(
        fn=respond_stream,
        inputs=[user_input, system_prompt, thinking, model_history_state, reasoning_bot, answer_bot],
        outputs=[reasoning_bot, answer_bot, model_history_state, user_input],
    )
    user_input.submit(
        fn=respond_stream,
        inputs=[user_input, system_prompt, thinking, model_history_state, reasoning_bot, answer_bot],
        outputs=[reasoning_bot, answer_bot, model_history_state, user_input],
    )
    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=[reasoning_bot, answer_bot, model_history_state, user_input],
    )
# Queueing is required for generator (streaming) event handlers.
demo.queue()
demo.launch()