# (Hugging Face Space page-header residue, kept as a comment so the file parses)
# vedaco's picture / Update app.py / 06758b5 verified
import gradio as gr
import tensorflow as tf
import threading
import time
import os
import json
from model import VedaProgrammingLLM
from tokenizer import VedaTokenizer
from database import db
from train import VedaTrainer
from teacher import teacher
from config import MODEL_DIR
# Module-level state shared by the Gradio callbacks below.
model = None        # set by init() to a VedaProgrammingLLM instance
tokenizer = None    # set by init() to a VedaTokenizer instance
current_id = -1     # DB id of the most recent saved conversation; -1 = none yet
def clean_response(text: str) -> str:
    """Turn raw model output into display-ready text.

    ``<CODE>``/``<ENDCODE>`` markers become a fenced python code block;
    all other special tokens are stripped, and the result is trimmed.
    Returns "" for empty/None input.
    """
    if not text:
        return ""
    cleaned = text.replace("<CODE>", "\n```python\n")
    cleaned = cleaned.replace("<ENDCODE>", "\n```\n")
    special_tokens = ("<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>")
    for tok in special_tokens:
        cleaned = cleaned.replace(tok, "")
    return cleaned.strip()
def init():
    """Load the saved model into the module globals, training one first if absent.

    Sets ``model`` and ``tokenizer``. Looks for config.json, weights.h5 and
    tokenizer.json under MODEL_DIR; if any is missing, runs an initial
    training pass and then loads the produced files.

    Raises:
        RuntimeError: if training completes but the model files still do not
            exist. (The original version recursed into ``init()`` here, which
            could loop forever if training never wrote the files.)
    """
    global model, tokenizer
    conf_path = os.path.join(MODEL_DIR, "config.json")
    weights_path = os.path.join(MODEL_DIR, "weights.h5")
    tok_path = os.path.join(MODEL_DIR, "tokenizer.json")
    required = (conf_path, weights_path, tok_path)

    if not all(os.path.exists(p) for p in required):
        print("[Init] No model found -> Training initial model...")
        VedaTrainer().train(epochs=10)
        print("[Init] Training done -> Loading model...")
        missing = [p for p in required if not os.path.exists(p)]
        if missing:
            # Fail loudly instead of recursing forever on broken training runs.
            raise RuntimeError(f"[Init] Training did not produce: {missing}")

    with open(conf_path, "r") as f:
        conf = json.load(f)
    tokenizer = VedaTokenizer()
    tokenizer.load(tok_path)
    model = VedaProgrammingLLM(**conf)
    # Run a dummy forward pass to build the graph before loading weights.
    max_len = conf.get("max_length", 512)
    model(tf.zeros((1, max_len), dtype=tf.int32))
    model.load_weights(weights_path)
    print("[Init] Model loaded.")
def auto_train_loop():
    """Daemon loop: every 5 minutes, fine-tune on unused teacher samples.

    When at least 5 distillation rows are pending, trains for 3 epochs on
    them, marks them used, and reloads the model. Any failure is logged
    and the loop keeps running.
    """
    min_samples = 5
    while True:
        time.sleep(300)  # check every 5 minutes
        try:
            samples = db.get_unused_distillation()
            if not samples or len(samples) < min_samples:
                continue
            print(f"[AutoTrain] Training on {len(samples)} teacher samples...")
            corpus = "\n".join(
                f"<USER> {row[1]}\n<ASSISTANT> {row[2]}" for row in samples
            )
            VedaTrainer().train(epochs=3, extra_data=corpus)
            db.mark_used([row[0] for row in samples])
            init()  # reload the freshly trained weights
        except Exception as e:
            print("[AutoTrain] skipped:", e)
def is_good(text: str) -> bool:
    """Heuristic quality gate for a generated response.

    Rejects empty/None input, very short answers (< 20 chars after
    stripping), and two known gibberish patterns.
    """
    if not text:
        return False
    stripped = text.strip()
    if len(stripped) < 20:
        return False
    # Gibberish heuristic: array indexing without any function structure.
    looks_like_code = ("def " in stripped) or ("return" in stripped)
    if "arr[" in stripped and not looks_like_code:
        return False
    # Known degenerate small-talk loop.
    if "hello how are you" in stripped.lower():
        return False
    return True
def respond(user_msg, history):
    """Generate a reply for *user_msg* and append the turn to *history*.

    *history* must be a list of message dicts in Gradio "messages" format:
    {"role": "user"|"assistant", "content": "..."}.

    Falls back to the teacher model when the student output fails the
    is_good() heuristic, saving teacher answers for later distillation.
    Updates the module-level ``current_id`` with the saved conversation's
    DB id. Returns ("", updated_history) to clear the textbox.
    """
    global current_id
    if history is None:
        history = []
    user_msg = (user_msg or "").strip()
    if not user_msg:
        return "", history

    # Student model generation on the chat-formatted prompt.
    prompt_ids = tokenizer.encode(f"<USER> {user_msg}\n<ASSISTANT>")
    generated = tokenizer.decode(model.generate(prompt_ids, max_new_tokens=200))

    # Keep only the text after the last <ASSISTANT> and before any <USER>.
    _, sep, tail = generated.rpartition("<ASSISTANT>")
    reply = tail if sep else generated
    reply = reply.split("<USER>", 1)[0]
    reply = clean_response(reply)

    # Teacher fallback when the student's answer looks bad.
    if not is_good(reply) and teacher.is_available():
        teacher_reply = teacher.ask(user_msg)
        if teacher_reply:
            reply = teacher_reply
            try:
                db.save_distillation(user_msg, teacher_reply)
            except Exception as e:
                print("[DB] save_distillation failed:", e)

    current_id = db.save_conversation(user_msg, reply)

    # Append the turn in messages format.
    history.append({"role": "user", "content": user_msg})
    history.append({"role": "assistant", "content": reply})
    return "", history
def feedback_up():
    """Record a thumbs-up (+1) for the most recent conversation, if any."""
    # current_id is -1 until respond() has saved at least one conversation.
    if current_id > 0:
        db.update_feedback(current_id, 1)
    return "Saved πŸ‘"
def feedback_down():
    """Record a thumbs-down (-1) for the most recent conversation, if any."""
    # current_id is -1 until respond() has saved at least one conversation.
    if current_id > 0:
        db.update_feedback(current_id, -1)
    return "Saved πŸ‘Ž"
# --- startup ---
# Load (or train) the model, then start the background fine-tuning daemon.
init()
threading.Thread(target=auto_train_loop, daemon=True).start()

# --- UI ---
with gr.Blocks(title="Veda Assistant") as demo:
    gr.Markdown("# πŸ•‰οΈ Veda Assistant")
    # DO NOT pass type= here (your Gradio rejects it)
    chat = gr.Chatbot(height=400, value=[])
    msg = gr.Textbox(label="Message", placeholder="Write bubble sort in python")
    status = gr.Textbox(label="Status", interactive=False)
    with gr.Row():
        up = gr.Button("πŸ‘")
        down = gr.Button("πŸ‘Ž")
    # Submit clears the textbox and replaces the chat history.
    msg.submit(respond, inputs=[msg, chat], outputs=[msg, chat])
    up.click(feedback_up, outputs=status)
    down.click(feedback_down, outputs=status)

# Bind on all interfaces, port 7860 (standard for Hugging Face Spaces).
demo.launch(server_name="0.0.0.0", server_port=7860)