Openlm / app.py
OpenLab-NLP's picture
Update app.py
4bbac72 verified
Raw
History Blame Contribute Delete
10.5 kB
import sentencepiece as spm
import os, numpy as np, tensorflow as tf
from tensorflow.keras import layers
import gradio as gr
# --- 1. Environment setup and model architecture definition ---
# Use the bare file name only (the file must be in the current working directory)
TOKENIZER_PATH = "tokenizer.model"
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)

# Padding id: use the "<pad>" piece, falling back to 0 when the lookup
# reports -1.
# NOTE(review): SentencePiece appears to return the unk id (not -1) for
# unknown pieces, so this fallback may never trigger - confirm.
pad_id = sp.piece_to_id("<pad>")
if pad_id == -1:
    pad_id = 0

# End-of-sequence id and vocabulary size as reported by the tokenizer.
end_id = sp.piece_to_id("</s>")
vocab_size = sp.get_piece_size()
class TimeMix(layers.Layer):
    """Single-step time-mixing layer with an explicit recurrent state.

    ``call`` consumes the current input together with a four-element state
    ``[last_x, aa, bb, pp]`` and returns the layer output plus the updated
    state in the same order.
    """

    def __init__(self, d_model, layer_id, n_layers):
        """Create the decay/bonus parameters, projections and mix coefficients.

        Args:
            d_model: Width of every per-channel parameter and Dense layer.
            layer_id: Index of this layer; together with ``n_layers`` it
                determines the depth ratio used for initialisation.
            n_layers: Total number of layers in the model.
        """
        super().__init__()
        self.d_model = d_model
        # Relative depth of this layer in [0, 1]; 0.5 for a single-layer model.
        ratio = (layer_id / (n_layers - 1)) if n_layers > 1 else 0.5
        decay_speed = np.arange(d_model)
        # Per-channel decay parameter: runs from -5 (channel 0) to 3 (last
        # channel) along a power curve whose exponent grows with depth.
        # NOTE(review): divides by (d_model - 1); d_model == 1 is not usable.
        self.time_decay = tf.Variable(-5 + 8 * (decay_speed / (d_model - 1)) ** (0.7 + 1.3 * ratio), dtype=tf.float32)
        # Per-channel term added to the key of the current token only (see
        # ``ww`` in ``call``).
        self.time_first = tf.Variable(np.ones(d_model) * np.log(0.3), dtype=tf.float32)
        # Zero-initialised projections of the mixed input ``dx``; they add an
        # input-dependent offset to the decay, receptance, key and value.
        self.w_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        self.r_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        self.k_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        self.v_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        # Main projections.
        self.key = layers.Dense(d_model, use_bias=False)
        self.value = layers.Dense(d_model, use_bias=False)
        self.receptance = layers.Dense(d_model, use_bias=False)
        self.output_projection = layers.Dense(d_model, use_bias=False)
        # Scalar interpolation weights between the current and the previous
        # input; the receptance weight uses a different exponent (0.2).
        self.tm_w = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_k = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_v = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_r = tf.Variable(1 - (ratio ** 0.2), dtype=tf.float32)

    def call(self, x, state):
        """Advance the layer by one step.

        Args:
            x: Input for the current step.
            state: Four tensors ``[last_x, aa, bb, pp]``: the previous input,
                the numerator accumulator, the denominator accumulator and the
                exponent offset the accumulators are scaled by.

        Returns:
            ``(output, new_state)`` where ``new_state`` is a list laid out
            like ``state``.
        """
        last_x, aa, bb, pp = state
        t_type = x.dtype
        tm_w, tm_k, tm_v, tm_r = tf.cast(self.tm_w, t_type), tf.cast(self.tm_k, t_type), tf.cast(self.tm_v, t_type), tf.cast(self.tm_r, t_type)
        # Blend of current and previous input that feeds the zero-initialised
        # projections.
        dx = x * tm_w + last_x * (1 - tm_w)
        w = tf.cast(self.time_decay, t_type) + tf.cast(self.w_proj(dx), t_type)
        # Decay in log space, always negative; evaluated in float32.
        w = -tf.exp(tf.cast(w, tf.float32))
        # Each of r/k/v is its own blend through the main projection plus the
        # offset projected from ``dx``.
        r = self.receptance(x * tm_r + last_x * (1 - tm_r)) + self.r_proj(dx)
        k = self.key(x * tm_k + last_x * (1 - tm_k)) + self.k_proj(dx)
        v = self.value(x * tm_v + last_x * (1 - tm_v)) + self.v_proj(dx)
        u = tf.cast(self.time_first, tf.float32)
        kv, vv = tf.cast(k, tf.float32), tf.cast(v, tf.float32)
        # Output for this step: weighted average of the accumulated values and
        # the current value. The running maximum ``p`` is subtracted before
        # exponentiating so both exponents are <= 0.
        ww = u + kv
        p = tf.maximum(pp, ww)
        e1, e2 = tf.exp(pp - p), tf.exp(ww - p)
        # The 1e-12 keeps the denominator away from zero.
        wkv = (e1 * aa + e2 * vv) / (e1 * bb + e2 + 1e-12)
        # State update: decay the old accumulators by ``w`` and add the
        # current key/value, again relative to a new running maximum.
        ww_next = w + pp
        p_next = tf.maximum(ww_next, kv)
        e1_next, e2_next = tf.exp(ww_next - p_next), tf.exp(kv - p_next)
        new_state = [x, e1_next * aa + e2_next * vv, e1_next * bb + e2_next, p_next]
        # Gate the mixed value with sigmoid(r) and project it.
        return self.output_projection(tf.nn.sigmoid(r) * tf.cast(wkv, t_type)), new_state
class ChannelMix(layers.Layer):
    """Gated feed-forward layer that blends the current and previous inputs."""

    def __init__(self, d_model, layer_id, n_layers):
        """Create the two mix coefficients and the three bias-free projections.

        Args:
            d_model: Model width; the hidden projection is ``int(d_model * 4.25)``.
            layer_id: Index of this layer, used for the depth ratio.
            n_layers: Total number of layers.
        """
        super().__init__()
        if n_layers > 1:
            ratio = layer_id / (n_layers - 1)
        else:
            ratio = 0.5
        self.time_mix_k = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.time_mix_r = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        hidden_width = int(d_model * 4.25)
        self.key = layers.Dense(hidden_width, use_bias=False)
        self.receptance = layers.Dense(d_model, use_bias=False)
        self.value = layers.Dense(d_model, use_bias=False)

    def call(self, x, last_x):
        """Apply the layer to ``x`` given the previous step's input.

        Returns:
            ``(output, x)``: the gated output and the input that becomes
            ``last_x`` on the next step.
        """
        compute_dtype = x.dtype
        mix_k = tf.cast(self.time_mix_k, compute_dtype)
        mix_r = tf.cast(self.time_mix_r, compute_dtype)

        key_input = x * mix_k + last_x * (1 - mix_k)
        gate_input = x * mix_r + last_x * (1 - mix_r)

        hidden = self.key(key_input)
        gate_logits = self.receptance(gate_input)
        # Squared-ReLU activation on the hidden projection.
        projected = self.value(tf.square(tf.nn.relu(hidden)))
        return tf.nn.sigmoid(gate_logits) * projected, x
class Block(layers.Layer):
    """One residual block: a shared LayerNorm feeding TimeMix and ChannelMix."""

    def __init__(self, d_model, layer_id, n_layers):
        """Build the normalisation layer and both mixing sub-layers."""
        super().__init__()
        self.ln = layers.LayerNormalization(epsilon=1e-5)
        self.time_mix = TimeMix(d_model, layer_id, n_layers)
        self.channel_mix = ChannelMix(d_model, layer_id, n_layers)

    def call(self, x, state):
        """Run one step of the block.

        Args:
            x: Block input.
            state: Five entries; the first four belong to the time-mix
                layer, the fifth is the channel-mix layer's previous input.

        Returns:
            ``(output, new_state)`` with ``new_state`` a five-element list in
            the same layout as ``state``.
        """
        time_state = state[:4]
        channel_last_x = state[4]

        # Both sub-layers read the same normalised input.
        normed = self.ln(x)
        time_out, next_time_state = self.time_mix(normed, time_state)
        x = x + time_out
        channel_out, next_channel_last_x = self.channel_mix(normed, channel_last_x)
        x = x + channel_out
        return x, next_time_state + [next_channel_last_x]
class Head(tf.keras.Model):
    """Output head that projects hidden states to vocabulary logits."""

    def __init__(self, vocab_size):
        """Create the bias-free projection with ``vocab_size`` outputs."""
        super().__init__()
        self.lm_head = layers.Dense(vocab_size, use_bias=False, name="output_head")

    def call(self, x):
        """Return the logits for ``x`` as float32."""
        logits = self.lm_head(x)
        return tf.cast(logits, tf.float32)
class LM(tf.keras.Model):
    """Token embedding, a stack of ``Block`` layers and a final LayerNorm."""

    def __init__(self, d_model, n_layers):
        """Build the model; the embedding size comes from the module-level ``vocab_size``."""
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        self.blocks = [Block(d_model, layer_id, n_layers) for layer_id in range(n_layers)]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)

    def call(self, x, states):
        """Run the stack on token ids ``x``.

        Args:
            x: Token ids.
            states: Flat list holding five state tensors per block, block
                after block.

        Returns:
            ``(hidden, new_states)``: the normalised output and the updated
            flat state list in the same layout.
        """
        hidden = self.token_embedding(x)
        next_states = []
        offset = 0
        for block in self.blocks:
            # Each block owns the next five entries of the flat state list.
            hidden, block_state = block(hidden, states[offset:offset + 5])
            next_states += block_state
            offset += 5
        return self.ln_f(hidden), next_states
# --- 2. Initialization and weight loading ---
# Model width and depth used to build the network and its initial state.
d_model = 512
n_layers = 10

# Backbone and the separate output head.
blocklm = LM(d_model=d_model, n_layers=n_layers)
head = Head(vocab_size=vocab_size)
def get_init_state():
    """Build the recurrent state for the start of a sequence.

    Returns:
        A flat list of ``n_layers * 5`` tensors of shape ``(1, 1, d_model)``.
        In every group of five, the entry at index 3 is filled with ``-1e30``
        and the remaining entries are zeros.
    """
    state = []
    for slot in range(n_layers * 5):
        if slot % 5 == 3:
            state.append(tf.ones((1, 1, d_model)) * -1e30)
        else:
            state.append(tf.zeros((1, 1, d_model)))
    return state
# Dummy call: run the backbone and the head once on a single token id before
# loading weights.
# NOTE(review): presumably this forces Keras to create the variables that
# load_weights() then fills - keep the calls ahead of the loads.
_o, _s = blocklm(tf.constant([[0]]), get_init_state())
_ = head(_o)
# Weight files are resolved relative to the current working directory.
blocklm.load_weights("blocklm.weights.h5")
head.load_weights("head.weights.h5")
# --- 3. Inference engine ---
class InferenceEngine:
    """Streams text from a recurrent language model one token at a time.

    The engine feeds the prompt through the model to build up its state, then
    repeatedly samples the next token (repetition penalty, temperature, top-k
    and top-p) and yields newly decoded text.
    """

    def __init__(self, model, head, sp):
        """Store the components and resolve the special-token ids.

        Args:
            model: Callable ``model(token_ids, states) -> (hidden, states)``.
            head: Callable mapping ``hidden`` to vocabulary logits.
            sp: SentencePiece processor used for encoding and decoding.
        """
        self.model = model
        self.head = head
        self.sp = sp

        pad = self._piece_id("<pad>")
        self.pad_id = 0 if pad is None else pad
        eos = self._piece_id("</s>")
        self.eos_id = sp.piece_to_id("[EOS]") if eos is None else eos

        # Compile the forward step once per engine. Wrapping here (instead of
        # decorating the method) keeps the class importable without
        # TensorFlow, so the NumPy-only helpers can be unit tested.
        self.model_step = tf.function(self._model_step, reduce_retracing=True)

    def _piece_id(self, piece):
        """Return the id of ``piece``, or ``None`` if the tokenizer lacks it.

        SentencePiece maps unknown pieces to the unk id rather than -1, so
        both values are treated as "not in the vocabulary".
        """
        piece_id = self.sp.piece_to_id(piece)
        if piece_id == -1 or piece_id == self.sp.unk_id():
            return None
        return piece_id

    def apply_repetition_penalty(self, logits, generated_ids, penalty, window=64):
        """Penalise tokens that occur in the last ``window`` ids.

        Positive logits are divided by ``penalty`` and non-positive ones are
        multiplied by it. ``logits`` is modified in place and returned.
        """
        if not generated_ids:
            return logits
        for token_id in set(generated_ids[-window:]):
            if logits[token_id] > 0:
                logits[token_id] /= penalty
            else:
                logits[token_id] *= penalty
        return logits

    def sample(self, logits, temp, top_k, top_p):
        """Pick the next token id from ``logits``.

        Args:
            logits: 1-D array of unnormalised scores; it is not modified.
            temp: Temperature; ``<= 0`` selects the arg-max deterministically.
            top_k: Keep only the ``top_k`` highest scores (``<= 0`` disables).
            top_p: Nucleus threshold; the smallest set of most likely tokens
                whose cumulative probability exceeds it is kept (at least one).

        Returns:
            The sampled token id as a Python ``int``.
        """
        if temp <= 0:
            return int(np.argmax(logits))

        # float64 keeps the probabilities summing to 1 within the tolerance
        # np.random.choice enforces; the division also makes a private copy.
        scaled = np.asarray(logits, dtype=np.float64) / temp

        if top_k > 0:
            k = min(int(top_k), scaled.shape[-1])
            kth_largest = np.partition(scaled, -k)[-k]
            scaled = np.where(scaled < kth_largest, -np.inf, scaled)

        # Softmax with the maximum subtracted so exp() cannot overflow.
        probs = np.exp(scaled - np.max(scaled))
        probs /= probs.sum()

        order = np.argsort(probs)[::-1]
        over_threshold = np.flatnonzero(np.cumsum(probs[order]) > top_p)
        if over_threshold.size:
            # Keep everything up to and including the token that crosses
            # the threshold, and never fewer than one token.
            cutoff = max(1, int(over_threshold[0]) + 1)
            probs[order[cutoff:]] = 0.0

        total = probs.sum()
        if total > 0:
            probs /= total
        else:
            probs[order[0]] = 1.0
        return int(np.random.choice(len(probs), p=probs))

    def _model_step(self, token_id, states):
        """Run the model and the head for one token; returns ``(logits, states)``."""
        out, next_states = self.model(token_id, states)
        logits = self.head(out)
        return logits, next_states

    def generate(self, prompt, max_new_tokens, temp, top_k, top_p, penalty):
        """Yield newly generated text for ``prompt`` piece by piece.

        Args:
            prompt: Text to condition on. If it encodes to no tokens, nothing
                is generated and the generator finishes immediately.
            max_new_tokens: Upper bound on the number of sampled tokens.
            temp, top_k, top_p: Sampling settings, see ``sample``.
            penalty: Repetition penalty, see ``apply_repetition_penalty``.

        Yields:
            Text fragments; concatenated they form the decoded output.
            Generation stops early when the end-of-sequence id is sampled.
        """
        input_ids = self.sp.encode(prompt)
        if not input_ids:
            # Nothing to condition on; previously this raised IndexError.
            return

        states = get_init_state()
        # Feed all prompt tokens except the last one; only the state is kept.
        for token_id in input_ids[:-1]:
            _, states = self.model_step(tf.constant([[token_id]]), states)

        curr_token_id = input_ids[-1]
        history = list(input_ids)  # prompt + generated ids, for the penalty
        generated = []
        full_text = ""
        emitted_len = 0

        for _ in range(max_new_tokens):
            logits_out, states = self.model_step(tf.constant([[curr_token_id]]), states)
            logits = logits_out[0, 0].numpy()
            logits = self.apply_repetition_penalty(logits, history, penalty)
            logits[self.pad_id] = -float("inf")  # never emit padding
            next_id = int(self.sample(logits, temp, top_k, top_p))
            if next_id == self.eos_id:
                break

            generated.append(next_id)
            history.append(next_id)
            curr_token_id = next_id

            # Re-decode everything so multi-token characters come out right.
            full_text = self.sp.decode(generated)
            if full_text.endswith("\ufffd"):
                # A trailing replacement character usually means a multi-byte
                # character is only partly generated (byte-fallback pieces);
                # hold it back until it is complete.
                continue
            new_part = full_text[emitted_len:]
            if new_part:
                yield new_part
                emitted_len = len(full_text)

        # Flush whatever was held back when generation ended.
        tail = full_text[emitted_len:]
        if tail:
            yield tail
# Module-level engine instance built from the loaded backbone, head and tokenizer.
engine = InferenceEngine(blocklm, head, sp)
# --- 4. Gradio UI (plain text input/output) ---
with gr.Blocks(title="RWKV Text Generator") as demo:
gr.Markdown("## ๐Ÿ–‹๏ธ Dynamic RWKV Text Generation")
gr.Markdown("์งˆ๋ฌธ์„ ์ž…๋ ฅํ•˜๊ณ  Generate๋ฅผ ๋ˆ„๋ฅด๋ฉด ๋‹ต๋ณ€์ด ์•„๋ž˜ ํ…์ŠคํŠธ ๋ฐ•์Šค์— ์‹ค์‹œ๊ฐ„์œผ๋กœ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.")
with gr.Row():
with gr.Column():
input_text = gr.Textbox(lines=5, label="Input Prompt", placeholder="์—ฌ๊ธฐ์— ์งˆ๋ฌธ์ด๋‚˜ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜์„ธ์š”...")
with gr.Row():
temp_slider = gr.Slider(0, 2, value=0.7, label="Temperature")
top_p_slider = gr.Slider(0, 1, value=0.92, label="Top-P")
with gr.Row():
penalty_slider = gr.Slider(1, 2, value=1.2, label="Penalty")
max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max Tokens")
submit_btn = gr.Button("Generate", variant="primary")
clear_btn = gr.Button("Clear")
with gr.Column():
output_text = gr.Textbox(lines=15, label="Generated Output", interactive=False)
def run_generation(prompt, tokens, temp, top_p, penalty):
if not prompt.strip():
return "ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”."
full_prompt = f"Question: {prompt}\nAnswer:"
current_output = ""
for chunk in engine.generate(full_prompt, int(tokens), temp, 40, top_p, penalty):
current_output += chunk
yield current_output
# ๋ฒ„ํŠผ ํด๋ฆญ ๋ฐ ์—”ํ„ฐ ํ‚ค ์ž…๋ ฅ ์ด๋ฒคํŠธ
submit_btn.click(
fn=run_generation,
inputs=[input_text, max_tokens, temp_slider, top_p_slider, penalty_slider],
outputs=output_text
)
clear_btn.click(lambda: ("", ""), outputs=[input_text, output_text])
# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    # queue() is called before launch().
    # NOTE(review): presumably needed so generator callbacks can stream
    # partial output - confirm for the installed Gradio version.
    demo.queue().launch()