OpenLab-NLP committed on
Commit
585706a
·
verified ·
1 Parent(s): 0f1138d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +242 -0
app.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sentencepiece as spm
2
+ import os, numpy as np, tensorflow as tf
3
+ from tensorflow.keras import layers
4
+ import gradio as gr
5
+
6
+ # --- 1. Environment setup and model architecture definition (unchanged) ---
7
# Path of the SentencePiece model file loaded below.
TOKENIZER_PATH = "tokenizer.model"

# Tokenizer shared by the model definition and the inference engine.
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)

# Fall back to id 0 for padding when the "<pad>" lookup reports -1.
_pad_lookup = sp.piece_to_id("<pad>")
pad_id = 0 if _pad_lookup == -1 else _pad_lookup

end_id = sp.piece_to_id("</s>")
vocab_size = sp.get_piece_size()
12
+
13
class TimeMix(layers.Layer):
    """Recurrent time-mixing layer that processes one step and carries state.

    The layer blends the current input with the previous input, projects the
    blends into receptance / key / value tensors, and combines the value with
    two running accumulators using exponentials that are shifted by a running
    maximum before being evaluated.

    NOTE(review): the structure looks like an RWKV-style WKV recurrence with
    extra input-dependent projections -- confirm against the training code.
    """

    def __init__(self, d_model, layer_id, n_layers):
        """Create the variables and projections for one layer.

        Args:
            d_model: Width of every projection created here.
            layer_id: Index of this layer inside the stack.
            n_layers: Total number of layers; used with ``layer_id`` to derive
                the depth-dependent initial values.
        """
        super().__init__()
        self.d_model = d_model

        # Relative depth of this layer; a single-layer stack uses 0.5.
        if n_layers > 1:
            ratio = layer_id / (n_layers - 1)
        else:
            ratio = 0.5

        # Per-channel initial decay, ramping from -5 towards 3 across channels.
        channel_index = np.arange(d_model)
        decay_init = -5 + 8 * (channel_index / (d_model - 1)) ** (0.7 + 1.3 * ratio)
        self.time_decay = tf.Variable(decay_init, dtype=tf.float32)
        self.time_first = tf.Variable(np.ones(d_model) * np.log(0.3), dtype=tf.float32)

        def zero_dense():
            # Bias-free projection whose kernel starts at zero.
            return layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)

        def plain_dense():
            # Bias-free projection with the default kernel initializer.
            return layers.Dense(d_model, use_bias=False)

        # Attribute order is kept exactly as in the original definition.
        self.w_proj = zero_dense()
        self.r_proj = zero_dense()
        self.k_proj = zero_dense()
        self.v_proj = zero_dense()
        self.key = plain_dense()
        self.value = plain_dense()
        self.receptance = plain_dense()
        self.output_projection = plain_dense()

        # Blend coefficients between the current and the previous input.
        self.tm_w = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_k = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_v = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_r = tf.Variable(1 - (ratio ** 0.2), dtype=tf.float32)

    def call(self, x, state):
        """Run one step.

        Args:
            x: Current input tensor.
            state: Sequence of four tensors: the previous input, the two
                running accumulators, and the running maximum exponent.

        Returns:
            A pair ``(output, new_state)`` where ``new_state`` is a list with
            the same four-slot layout as ``state``.
        """
        prev_x, acc_num, acc_den, max_exp = state
        in_dtype = x.dtype

        mix_w = tf.cast(self.tm_w, in_dtype)
        mix_k = tf.cast(self.tm_k, in_dtype)
        mix_v = tf.cast(self.tm_v, in_dtype)
        mix_r = tf.cast(self.tm_r, in_dtype)

        def blend(coef):
            # Interpolate between the current and the previous input.
            return x * coef + prev_x * (1 - coef)

        dx = blend(mix_w)

        # Input-dependent decay, evaluated in float32 and always negative.
        decay = tf.cast(self.time_decay, in_dtype) + tf.cast(self.w_proj(dx), in_dtype)
        decay = -tf.exp(tf.cast(decay, tf.float32))

        r = self.receptance(blend(mix_r)) + self.r_proj(dx)
        k = self.key(blend(mix_k)) + self.k_proj(dx)
        v = self.value(blend(mix_v)) + self.v_proj(dx)

        bonus = tf.cast(self.time_first, tf.float32)
        k32 = tf.cast(k, tf.float32)
        v32 = tf.cast(v, tf.float32)

        # Output for the current step: subtract the larger exponent before
        # exponentiating so neither exp() argument is positive.
        boosted_k = bonus + k32
        shift = tf.maximum(max_exp, boosted_k)
        weight_old = tf.exp(max_exp - shift)
        weight_new = tf.exp(boosted_k - shift)
        wkv = (weight_old * acc_num + weight_new * v32) / (weight_old * acc_den + weight_new + 1e-12)

        # State for the next step: decay the accumulators, then add this step.
        decayed_exp = decay + max_exp
        next_max = tf.maximum(decayed_exp, k32)
        keep = tf.exp(decayed_exp - next_max)
        add = tf.exp(k32 - next_max)
        next_state = [x, keep * acc_num + add * v32, keep * acc_den + add, next_max]

        gated = tf.nn.sigmoid(r) * tf.cast(wkv, in_dtype)
        return self.output_projection(gated), next_state
55
+
56
class ChannelMix(layers.Layer):
    """Gated feed-forward layer that mixes the current and previous input."""

    def __init__(self, d_model, layer_id, n_layers):
        """Create the blend coefficients and the three projections.

        Args:
            d_model: Output width of the receptance and value projections.
            layer_id: Index of this layer inside the stack.
            n_layers: Total number of layers; used with ``layer_id`` to derive
                the initial blend coefficients.
        """
        super().__init__()

        # Relative depth of this layer; a single-layer stack uses 0.5.
        if n_layers > 1:
            ratio = layer_id / (n_layers - 1)
        else:
            ratio = 0.5

        self.time_mix_k = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.time_mix_r = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)

        # The key projection expands to 4.25x the model width (truncated).
        hidden_dim = int(d_model * 4.25)
        self.key = layers.Dense(hidden_dim, use_bias=False)
        self.receptance = layers.Dense(d_model, use_bias=False)
        self.value = layers.Dense(d_model, use_bias=False)

    def call(self, x, last_x):
        """Run one step.

        Args:
            x: Current input tensor.
            last_x: Input tensor from the previous step.

        Returns:
            A pair ``(output, x)``; the second element is the unchanged
            current input, returned so the caller can store it as the next
            ``last_x``.
        """
        in_dtype = x.dtype
        mix_k = tf.cast(self.time_mix_k, in_dtype)
        mix_r = tf.cast(self.time_mix_r, in_dtype)

        key_input = x * mix_k + last_x * (1 - mix_k)
        gate_input = x * mix_r + last_x * (1 - mix_r)

        k = self.key(key_input)
        r = self.receptance(gate_input)

        # Squared-ReLU activation on the expanded key, projected back down.
        projected = self.value(tf.square(tf.nn.relu(k)))
        return tf.nn.sigmoid(r) * projected, x
73
+
74
class Block(layers.Layer):
    """One residual block: a shared layer norm feeding TimeMix and ChannelMix."""

    def __init__(self, d_model, layer_id, n_layers):
        """Build the layer norm and both mixing sub-layers for this block."""
        super().__init__()
        self.ln = layers.LayerNormalization(epsilon=1e-5)
        self.time_mix = TimeMix(d_model, layer_id, n_layers)
        self.channel_mix = ChannelMix(d_model, layer_id, n_layers)

    def call(self, x, state):
        """Run one step of the block.

        Args:
            x: Block input tensor.
            state: Sequence of five tensors; the first four go to the
                time-mix sub-layer, the fifth is the channel-mix ``last_x``.

        Returns:
            A pair ``(output, new_state)`` where ``new_state`` is a list of
            five tensors in the same layout as ``state``.
        """
        # Both sub-layers read the same normalized input; each result is
        # added to the residual stream in turn.
        normed = self.ln(x)

        time_out, time_state = self.time_mix(normed, state[:4])
        residual = x + time_out

        channel_out, channel_last_x = self.channel_mix(normed, state[4])
        residual = residual + channel_out

        return residual, time_state + [channel_last_x]
87
+
88
class Head(tf.keras.Model):
    """Output head: a bias-free dense projection to vocabulary logits."""

    def __init__(self, vocab_size):
        """Create the projection with ``vocab_size`` output units."""
        super().__init__()
        self.lm_head = layers.Dense(vocab_size, use_bias=False, name="output_head")

    def call(self, x):
        """Project ``x`` to logits and return them as float32."""
        logits = self.lm_head(x)
        return tf.cast(logits, tf.float32)
94
+
95
class LM(tf.keras.Model):
    """Token embedding, a stack of ``Block`` layers and a final layer norm."""

    def __init__(self, d_model, n_layers):
        """Build the embedding, ``n_layers`` blocks and the final norm.

        The embedding size is taken from the module-level ``vocab_size``.
        """
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        self.blocks = [Block(d_model, i, n_layers) for i in range(n_layers)]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)

    def call(self, x, states):
        """Run one step through every block.

        Args:
            x: Token ids to embed.
            states: Flat sequence holding five state tensors per block, in
                block order.

        Returns:
            A pair ``(hidden, new_states)``: the normalized output of the last
            block and a flat list of updated states in the same layout.
        """
        slots_per_block = 5
        hidden = self.token_embedding(x)
        updated_states = []
        for index, block in enumerate(self.blocks):
            start = index * slots_per_block
            block_state = states[start:start + slots_per_block]
            hidden, next_block_state = block(hidden, block_state)
            updated_states.extend(next_block_state)
        return self.ln_f(hidden), updated_states
108
+
109
+ # --- 2. Model loading and initialization ---
110
# Model hyper-parameters used to build the network and its initial state.
d_model = 512
n_layers = 10

# Backbone and output head are separate models.
blocklm = LM(d_model, n_layers)
head = Head(vocab_size)
113
+
114
def get_init_state():
    """Return a fresh flat state list with five tensors per layer.

    Each tensor has shape ``(1, 1, d_model)``. Slots 0, 1, 2 and 4 of every
    layer are zeros; slot 3 is filled with ``-1e30``.
    """
    shape = (1, 1, d_model)
    state = []
    for _layer in range(n_layers):
        for slot in range(5):
            if slot == 3:
                state.append(tf.ones(shape) * -1e30)
            else:
                state.append(tf.zeros(shape))
    return state
116
+
117
# Dummy call so that the models create their weight structure before loading.
_dummy_token = tf.constant([[0]])
_o, _s = blocklm(_dummy_token, get_init_state())
_ = head(_o)

# Load the weight files for the backbone and then the output head.
for _target, _weights_file in ((blocklm, "blocklm.weights.h5"), (head, "head.weights.h5")):
    _target.load_weights(_weights_file)
124
+
125
+ # --- 3. Inference engine definition (unchanged) ---
126
class InferenceEngine:
    """Step-by-step text generation around the backbone model and its head.

    The engine feeds one token id at a time through ``model`` and ``head``,
    applies a repetition penalty and temperature / top-k / top-p sampling to
    the resulting logits, and streams the decoded text.
    """

    def __init__(self, model, head, sp):
        """Store the collaborators and resolve the special token ids.

        Args:
            model: Callable ``model(token_id, states) -> (out, next_states)``.
            head: Callable mapping ``out`` to logits.
            sp: Tokenizer exposing ``encode``, ``decode`` and ``piece_to_id``.
        """
        self.model = model
        self.head = head
        self.sp = sp

        # Padding falls back to id 0 when the "<pad>" lookup reports -1.
        pad_lookup = sp.piece_to_id("<pad>")
        self.pad_id = pad_lookup if pad_lookup != -1 else 0

        # End-of-sequence falls back to the "[EOS]" piece.
        eos_lookup = sp.piece_to_id("</s>")
        self.eos_id = eos_lookup if eos_lookup != -1 else sp.piece_to_id("[EOS]")

    def apply_repetition_penalty(self, logits, generated_ids, penalty, window):
        """Penalize tokens that occur in the last ``window`` ids, in place.

        Positive logits are divided by ``penalty`` and the others are
        multiplied by it.

        Args:
            logits: 1-D NumPy array of logits; modified in place.
            generated_ids: List of token ids seen so far.
            penalty: Penalty factor.
            window: Number of most recent ids to consider. Values <= 0 disable
                the penalty (a plain ``[-0:]`` slice would otherwise select
                the whole history).

        Returns:
            The same ``logits`` array.
        """
        window = int(window)
        if not generated_ids or window <= 0:
            return logits

        recent_ids = np.fromiter(set(generated_ids[-window:]), dtype=np.int64)
        values = logits[recent_ids]
        logits[recent_ids] = np.where(values > 0, values / penalty, values * penalty)
        return logits

    def sample(self, logits, temp, top_k, top_p):
        """Pick a token id from ``logits``.

        Args:
            logits: 1-D NumPy array of logits; not modified.
            temp: Temperature; values <= 0 select the arg-max.
            top_k: Keep only the ``top_k`` largest logits when > 0. Coerced to
                ``int`` so values arriving as floats still work as an index.
            top_p: Nucleus threshold; tokens after the one whose cumulative
                probability first exceeds it are dropped. Values >= 1 disable
                the cut so rounding in the cumulative sum cannot trim tokens.

        Returns:
            The chosen token id (a NumPy integer).
        """
        if temp <= 0:
            return np.argmax(logits)

        # Division creates a new array, so the masking below is local.
        logits = logits / temp

        top_k = int(top_k)
        if top_k > 0:
            keep = min(top_k, logits.shape[-1])
            kth_largest = np.sort(logits)[-keep]
            logits[logits < kth_largest] = -float('inf')

        probs = tf.nn.softmax(logits).numpy()
        sorted_indices = np.argsort(probs)[::-1]

        if top_p < 1.0:
            cumulative_probs = np.cumsum(probs[sorted_indices])
            exceeds = cumulative_probs > top_p
            if np.any(exceeds):
                # Keep everything up to and including the first token that
                # pushes the cumulative probability over the threshold.
                cutoff_idx = int(np.argmax(exceeds)) + 1
                probs[sorted_indices[cutoff_idx:]] = 0

        total = np.sum(probs)
        if total > 0:
            probs /= total
        else:
            probs[sorted_indices[0]] = 1.0
        return np.random.choice(len(probs), p=probs)

    @tf.function(reduce_retracing=True)
    def model_step(self, token_id, states):
        """Run one token through the model and head.

        Returns:
            A pair ``(logits, next_states)``.
        """
        out, next_states = self.model(token_id, states)
        logits = self.head(out)
        return logits, next_states

    def generate_stream(self, prompt, max_new_tokens, temperature, top_k, top_p, penalty, window):
        """Generate text for ``prompt`` and yield it piece by piece.

        Args:
            prompt: Text to condition on.
            max_new_tokens: Upper bound on sampled tokens (coerced to ``int``).
            temperature: Passed to :meth:`sample` as ``temp``.
            top_k: Passed to :meth:`sample`.
            top_p: Passed to :meth:`sample`.
            penalty: Repetition penalty factor.
            window: Repetition penalty window.

        Yields:
            Newly decoded text since the previous yield. Nothing is yielded
            when the prompt encodes to no tokens.
        """
        input_ids = self.sp.encode(prompt)
        if not input_ids:
            # Without a token to feed there is nothing to condition on.
            return

        states = get_init_state()
        generated = []

        # Feed every prompt token except the last to build up the state.
        for prompt_token in input_ids[:-1]:
            _, states = self.model_step(tf.constant([[prompt_token]]), states)

        curr_token_id = input_ids[-1]
        prev_text = ""
        full_text = ""

        for _ in range(int(max_new_tokens)):
            logits_out, states = self.model_step(tf.constant([[curr_token_id]]), states)
            logits = logits_out[0, 0].numpy()
            logits = self.apply_repetition_penalty(logits, input_ids + generated, penalty, window)
            logits[self.pad_id] = -float('inf')

            next_id = int(self.sample(logits, temperature, top_k, top_p))
            if next_id == self.eos_id:
                break

            generated.append(next_id)
            curr_token_id = next_id
            full_text = self.sp.decode(generated)

            # NOTE(review): assumes an unfinished multi-byte character decodes
            # to U+FFFD (byte-fallback pieces) -- confirm for this tokenizer.
            # Emitting it would break the length-based diff below once the
            # character completes, so wait for the next token instead.
            if full_text.endswith("\ufffd"):
                continue

            new_part = full_text[len(prev_text):]
            if new_part:
                yield new_part
                prev_text = full_text

        # Flush text that was held back when generation stopped.
        remainder = full_text[len(prev_text):]
        if remainder:
            yield remainder
195
+
196
# Single shared engine instance used by the chat callback.
engine = InferenceEngine(model=blocklm, head=head, sp=sp)
197
+
198
+ # --- 4. Gradio interface setup ---
199
def chat_response(message, history, max_tokens, temp, top_p, top_k, penalty):
    """Stream the model's answer to ``message``.

    Args:
        message: Latest user message.
        history: Conversation history supplied by the chat UI; it is not used
            when building the prompt.
        max_tokens: Maximum number of new tokens to generate.
        temp: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cutoff.
        penalty: Repetition penalty factor.

    Yields:
        The answer accumulated so far, growing with every generated piece.
    """
    # Simple single-turn template: Question: {msg}\nAnswer:
    prompt = f"Question: {message}\nAnswer:"

    stream = engine.generate_stream(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temp,
        top_k=top_k,
        top_p=top_p,
        penalty=penalty,
        window=64,
    )

    answer_so_far = ""
    for piece in stream:
        answer_so_far += piece
        yield answer_so_far
216
+
217
+ # Gradio ํ…Œ๋งˆ ๋ฐ ๋ ˆ์ด์•„์›ƒ ์„ค์ •
218
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
219
+ gr.Markdown("# ๐Ÿš€ Dynamic Engine Chatbot")
220
+ gr.Markdown("๋™์  ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ชจ๋ธ์„ ์œ„ํ•œ ์‹ค์‹œ๊ฐ„ ์ŠคํŠธ๋ฆฌ๋ฐ ์ฑ„ํŒ… UI์ž…๋‹ˆ๋‹ค.")
221
+
222
+ with gr.Row():
223
+ with gr.Column(scale=4):
224
+ chatbot = gr.ChatInterface(
225
+ fn=chat_response,
226
+ additional_inputs=[
227
+ gr.Slider(1, 2048, value=512, step=1, label="Max New Tokens"),
228
+ gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Temperature"),
229
+ gr.Slider(0.0, 1.0, value=0.92, step=0.01, label="Top-P"),
230
+ gr.Slider(0, 100, value=40, step=1, label="Top-K"),
231
+ gr.Slider(1.0, 2.0, value=1.2, step=0.05, label="Repetition Penalty"),
232
+ ],
233
+ examples=[["What is AI?"], ["Hello."]],
234
+ )
235
+
236
+ gr.Markdown("---")
237
+ gr.Markdown("### ๐Ÿ›  Model Info")
238
+ gr.Markdown(f"- **D_Model**: {d_model} | **Layers**: {n_layers} | **Vocab**: {vocab_size}")
239
+
240
if __name__ == "__main__":
    # Setting share=True makes Gradio create an external share link.
    queued_demo = demo.queue()
    queued_demo.launch(share=True)