# Hugging Face Spaces page header (scrape residue) — Space status was "Sleeping".
| import gradio as gr | |
| import tensorflow as tf | |
| import threading | |
| import time | |
| import os | |
| import json | |
| from model import VedaProgrammingLLM | |
| from tokenizer import VedaTokenizer | |
| from database import db | |
| from train import VedaTrainer | |
| from teacher import teacher | |
| from config import MODEL_DIR | |
# Student model and tokenizer, populated by init() at startup.
model = None
tokenizer = None
# DB row id of the most recently saved conversation (-1 = none yet); the
# feedback buttons target this row. NOTE(review): a single module-level id
# assumes one active user — concurrent sessions will clobber each other.
current_id = -1
def clean_response(text: str) -> str:
    """Turn raw model output into display-ready text.

    ``<CODE>``/``<ENDCODE>`` markers become fenced markdown python code
    blocks; every other special token is stripped, and the result is
    whitespace-trimmed. Falsy input yields "".
    """
    if not text:
        return ""
    cleaned = text.replace("<CODE>", "\n```python\n")
    cleaned = cleaned.replace("<ENDCODE>", "\n```\n")
    for marker in ("<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>"):
        cleaned = cleaned.replace(marker, "")
    return cleaned.strip()
def init():
    """Load the student model from MODEL_DIR, training one first if absent.

    Side effects: sets the module globals ``model`` and ``tokenizer``.

    Raises:
        RuntimeError: if training completes but the expected checkpoint
            files still do not exist on disk.
    """
    global model, tokenizer
    conf_path = os.path.join(MODEL_DIR, "config.json")
    weights_path = os.path.join(MODEL_DIR, "weights.h5")
    tok_path = os.path.join(MODEL_DIR, "tokenizer.json")

    def _checkpoint_ready() -> bool:
        # All three artifacts must exist before a load is attempted.
        return all(os.path.exists(p) for p in (weights_path, conf_path, tok_path))

    def _load():
        # Rebuild the model from its saved config, run one dummy forward
        # pass to materialize the variables, then restore the weights.
        global model, tokenizer
        with open(conf_path, "r") as f:
            conf = json.load(f)
        tokenizer = VedaTokenizer()
        tokenizer.load(tok_path)
        model = VedaProgrammingLLM(**conf)
        max_len = conf.get("max_length", 512)
        model(tf.zeros((1, max_len), dtype=tf.int32))
        model.load_weights(weights_path)
        print("[Init] Model loaded.")

    if _checkpoint_ready():
        _load()
        return

    print("[Init] No model found -> Training initial model...")
    VedaTrainer().train(epochs=10)
    print("[Init] Training done -> Loading model...")
    # BUG FIX: the original recursed into init() here. If training ever
    # finished without writing the checkpoint files, that recursion would
    # retrain (and recurse) forever. Load directly and fail loudly instead.
    if not _checkpoint_ready():
        raise RuntimeError("[Init] Training finished but checkpoint files are missing.")
    _load()
def auto_train_loop():
    """Daemon loop: periodically fine-tune the student on teacher samples.

    Every 5 minutes, if at least 5 unused distillation rows are queued,
    run a short training pass on them, mark them used, and reload the
    model. Any failure is logged and the loop keeps running.
    """
    interval_seconds = 300  # 5 min
    min_batch = 5
    while True:
        time.sleep(interval_seconds)
        try:
            samples = db.get_unused_distillation()
            if not samples or len(samples) < min_batch:
                continue
            print(f"[AutoTrain] Training on {len(samples)} teacher samples...")
            # Rows are (id, prompt, answer, ...); format as chat transcript.
            corpus = "\n".join(
                f"<USER> {row[1]}\n<ASSISTANT> {row[2]}" for row in samples
            )
            VedaTrainer().train(epochs=3, extra_data=corpus)
            db.mark_used([row[0] for row in samples])
            init()  # reload the freshly trained weights
        except Exception as e:
            # Best-effort background job: never let one failure kill the thread.
            print("[AutoTrain] skipped:", e)
def is_good(text: str) -> bool:
    """Heuristic quality gate for the student model's output.

    Rejects empty or very short replies plus two known degenerate
    patterns (orphan array indexing with no function structure, and a
    parroted greeting). Returns True for everything else.
    """
    if not text:
        return False
    stripped = text.strip()
    if len(stripped) < 20:
        return False
    # "arr[" with neither "def " nor "return" reads as gibberish output.
    has_array_ref = "arr[" in stripped
    has_structure = ("def " in stripped) or ("return" in stripped)
    if has_array_ref and not has_structure:
        return False
    # The model sometimes parrots this greeting regardless of the prompt.
    if "hello how are you" in stripped.lower():
        return False
    return True
def respond(user_msg, history):
    """Chat handler: generate a student reply, falling back to the teacher.

    ``history`` is (and stays) a list of ``{"role": ..., "content": ...}``
    dicts in messages format. Returns ``("", updated_history)`` so the
    input textbox is cleared on submit.

    Side effects: saves the exchange to the DB, updates the module-level
    ``current_id`` so the feedback buttons target this turn, and logs a
    distillation pair whenever the teacher supplies the answer.
    """
    global current_id
    if history is None:
        history = []
    user_msg = (user_msg or "").strip()
    if not user_msg:
        return "", history

    # --- student generation ---
    prompt_ids = tokenizer.encode(f"<USER> {user_msg}\n<ASSISTANT>")
    generated = tokenizer.decode(model.generate(prompt_ids, max_new_tokens=200))

    # Keep only the text after the last <ASSISTANT> marker and before any
    # hallucinated follow-up <USER> turn.
    _, sep, tail = generated.rpartition("<ASSISTANT>")
    reply = tail if sep else generated
    reply = reply.split("<USER>", 1)[0]
    reply = clean_response(reply)

    # --- teacher fallback + distillation logging ---
    if not is_good(reply) and teacher.is_available():
        teacher_reply = teacher.ask(user_msg)
        if teacher_reply:
            reply = teacher_reply
            try:
                db.save_distillation(user_msg, teacher_reply)
            except Exception as e:
                # Losing one distillation row is not worth failing the chat turn.
                print("[DB] save_distillation failed:", e)

    current_id = db.save_conversation(user_msg, reply)
    history.append({"role": "user", "content": user_msg})
    history.append({"role": "assistant", "content": reply})
    return "", history
def feedback_up():
    """Record a positive rating for the latest turn (no-op before first message)."""
    has_turn = current_id > 0
    if has_turn:
        db.update_feedback(current_id, 1)
    return "Saved π"
def feedback_down():
    """Record a negative rating for the latest turn (no-op before first message)."""
    has_turn = current_id > 0
    if has_turn:
        db.update_feedback(current_id, -1)
    return "Saved π"
# --- startup ---
# Load (or train) the model synchronously before the UI is built, then start
# the background distillation/retraining thread. daemon=True so the thread
# dies with the main process instead of blocking shutdown.
init()
threading.Thread(target=auto_train_loop, daemon=True).start()
# Gradio UI: chat window, message box, status line, and feedback buttons.
# NOTE(review): several button/label strings ("π", "ποΈ") look like
# mojibake'd emoji from a bad encoding round-trip — confirm intended glyphs.
with gr.Blocks(title="Veda Assistant") as demo:
    gr.Markdown("# ποΈ Veda Assistant")
    # DO NOT pass type= here (your Gradio rejects it)
    # NOTE(review): respond() appends role/content dicts (messages format);
    # older Gradio Chatbot defaults to tuple pairs — verify this renders.
    chat = gr.Chatbot(height=400, value=[])
    msg = gr.Textbox(label="Message", placeholder="Write bubble sort in python")
    status = gr.Textbox(label="Status", interactive=False)
    with gr.Row():
        up = gr.Button("π")
        down = gr.Button("π")
    # Enter in the textbox submits; outputs clear the box and refresh history.
    msg.submit(respond, inputs=[msg, chat], outputs=[msg, chat])
    up.click(feedback_up, outputs=status)
    down.click(feedback_down, outputs=status)
demo.launch(server_name="0.0.0.0", server_port=7860)