Neon-AI committed on
Commit
be43da3
·
verified ·
1 Parent(s): 5f665f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -34
app.py CHANGED
@@ -1,13 +1,22 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
-
5
- st.set_page_config(page_title="Niche AI", layout="centered")
6
-
7
- st.title("🧠 Niche AI (CPU Test)")
8
- st.caption("HF Free Space · 2B params · slow but real")
9
 
 
10
  MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
 
 
 
 
 
 
 
 
11
 
12
  @st.cache_resource
13
  def load_model():
@@ -15,50 +24,71 @@ def load_model():
15
  MODEL_ID,
16
  trust_remote_code=True
17
  )
 
18
  model = AutoModelForCausalLM.from_pretrained(
19
  MODEL_ID,
20
  torch_dtype=torch.float32,
21
- device_map=None # 👈 IMPORTANT
22
  )
 
23
  model.to("cpu")
24
  model.eval()
25
  return tokenizer, model
26
 
27
  tokenizer, model = load_model()
28
 
29
- # Session chat history
30
  if "history" not in st.session_state:
31
  st.session_state.history = []
32
 
33
- prompt = st.text_input("You", placeholder="Say something...")
 
34
 
35
- if st.button("Send"):
36
- if prompt.strip():
37
- st.session_state.history.append(("You", prompt))
38
 
39
- chat = [{"role": "user", "content": prompt}]
40
- inputs = tokenizer.apply_chat_template(
41
- chat,
42
- add_generation_prompt=True,
43
- return_tensors="pt",
44
- return_dict=True # ← this is the key
45
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- with torch.no_grad():
48
- output = model.generate(
49
- **inputs,
50
- max_new_tokens=64,
51
- do_sample=True,
52
- temperature=0.8,
53
- top_p=0.95,
54
- eos_token_id=tokenizer.eos_token_id, # ← add this
55
- pad_token_id=tokenizer.eos_token_id
56
- )
57
-
58
- reply = tokenizer.decode(output[0], skip_special_tokens=True)
59
- st.session_state.history.append(("Niche", reply))
60
-
61
- # Display chat
62
  for speaker, text in st.session_state.history:
63
  if speaker == "You":
64
  st.markdown(f"**You:** {text}")
 
1
  import streamlit as st
2
  import torch
3
+ import threading
4
+ from transformers import (
5
+ AutoModelForCausalLM,
6
+ AutoTokenizer,
7
+ TextIteratorStreamer
8
+ )
9
 
10
+ # ---------------- CONFIG ----------------
11
  MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
12
+ MAX_NEW_TOKENS = 256
13
+ TEMPERATURE = 0.7
14
+ TOP_P = 0.9
15
+ # ----------------------------------------
16
+
17
+ st.set_page_config(page_title="Niche AI", layout="centered")
18
+ st.title("🧠 Niche AI")
19
+ st.caption("HF Free Space · CPU · Streaming")
20
 
21
  @st.cache_resource
22
  def load_model():
 
24
  MODEL_ID,
25
  trust_remote_code=True
26
  )
27
+
28
  model = AutoModelForCausalLM.from_pretrained(
29
  MODEL_ID,
30
  torch_dtype=torch.float32,
31
+ device_map=None
32
  )
33
+
34
  model.to("cpu")
35
  model.eval()
36
  return tokenizer, model
37
 
38
  tokenizer, model = load_model()
39
 
40
+ # -------- SESSION STATE --------
41
  if "history" not in st.session_state:
42
  st.session_state.history = []
43
 
44
+ # -------- INPUT --------
45
+ prompt = st.text_input("You", placeholder="Say something…")
46
 
47
+ if st.button("Send") and prompt.strip():
48
+ st.session_state.history.append(("You", prompt))
 
49
 
50
+ chat = [{"role": "user", "content": prompt}]
51
+
52
+ inputs = tokenizer.apply_chat_template(
53
+ chat,
54
+ add_generation_prompt=True,
55
+ return_tensors="pt",
56
+ return_dict=True
57
+ )
58
+
59
+ streamer = TextIteratorStreamer(
60
+ tokenizer,
61
+ skip_prompt=True,
62
+ skip_special_tokens=True
63
+ )
64
+
65
+ gen_kwargs = dict(
66
+ **inputs,
67
+ max_new_tokens=MAX_NEW_TOKENS,
68
+ do_sample=True,
69
+ temperature=TEMPERATURE,
70
+ top_p=TOP_P,
71
+ eos_token_id=tokenizer.eos_token_id,
72
+ pad_token_id=tokenizer.eos_token_id,
73
+ streamer=streamer
74
+ )
75
+
76
+ thread = threading.Thread(
77
+ target=model.generate,
78
+ kwargs=gen_kwargs
79
+ )
80
+ thread.start()
81
+
82
+ placeholder = st.empty()
83
+ output_text = ""
84
+
85
+ for token in streamer:
86
+ output_text += token
87
+ placeholder.markdown(f"**Niche:** {output_text}")
88
+
89
+ st.session_state.history.append(("Niche", output_text))
90
 
91
+ # -------- DISPLAY HISTORY --------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  for speaker, text in st.session_state.history:
93
  if speaker == "You":
94
  st.markdown(f"**You:** {text}")