import json
import traceback

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, LogitsProcessorList
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
import gradio as gr


# -----------------------------------------------------------------------------
# 1. Helpers
# -----------------------------------------------------------------------------
def make_json_serializable(obj):
    """
    Recursively convert any torch.Tensor in obj to Python lists.

    Walks dicts and lists; any other type is returned unchanged and is
    assumed to be handled by json.dumps directly.
    """
    if isinstance(obj, torch.Tensor):
        # .cpu() first so GPU tensors can be converted safely
        return obj.cpu().tolist()
    elif isinstance(obj, dict):
        return {k: make_json_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [make_json_serializable(v) for v in obj]
    return obj


def safe_json_dumps(data):
    """
    Dump JSON with our converter to avoid Tensor serialization errors.
    """
    return json.dumps(
        make_json_serializable(data),
        indent=2,
        ensure_ascii=False
    )


# -----------------------------------------------------------------------------
# 2. Load Models and Initialize PPO Agent
# -----------------------------------------------------------------------------
MODEL_NAME = "google/flan-t5-base"

# Core seq2seq tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# BUG FIX: TRL's PPOTrainer requires a model with a value head
# (a PreTrainedModelWrapper); constructing it with a plain
# AutoModelForSeq2SeqLM raises a ValueError. The value-head wrapper
# still exposes .generate(), so the chat callback is unaffected.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(MODEL_NAME)

# PPO configuration
ppo_config = PPOConfig(
    model_name=MODEL_NAME,
    learning_rate=1e-5,
    batch_size=1,
    log_with=None  # switch to "wandb" or "tensorboard" if you like
)

# Wrap FLAN-T5 in a PPO agent
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    tokenizer=tokenizer
)

# -----------------------------------------------------------------------------
# 3. Session State
# -----------------------------------------------------------------------------
current_session = {
    "dialog": []  # each entry: {"user": str, "bot": str, "reward": float or None}
}
# -----------------------------------------------------------------------------
# 4. Core Callback Functions
# -----------------------------------------------------------------------------
def reset_session():
    """
    Clear the conversation and return an empty chat history.
    """
    global current_session
    current_session = {"dialog": []}
    return []


def chat_with_agent(user_input: str):
    """
    Generate the model's reply, append it to the session, and return the
    full chat history as a list of (user, bot) pairs for gr.Chatbot.
    """
    global current_session
    try:
        # Tokenize user prompt and generate
        inputs = tokenizer(user_input, return_tensors="pt").input_ids
        outputs = model.generate(
            inputs,
            max_new_tokens=128,
            do_sample=True,
            top_p=0.9,
            temperature=0.8
        )
        bot_reply = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Store in session; reward stays None until the user rates the turn
        current_session["dialog"].append({
            "user": user_input,
            "bot": bot_reply,
            "reward": None
        })

        # Prepare for Gradio Chatbot: list of (user, bot)
        history = [
            (turn["user"], turn["bot"])
            for turn in current_session["dialog"]
        ]
        return history
    except Exception as e:
        print("🔥 Error in chat_with_agent:", e)
        traceback.print_exc()
        # On failure, leave session untouched
        return [("Error:", "Failed to generate reply. Check logs.")]


def rate_and_train(rating: float):
    """
    Attach the rating to the last bot reply, run a single PPO step on that
    (query, response, reward) example, and return the session as JSON.
    """
    global current_session
    try:
        if not current_session["dialog"]:
            return "No dialog to rate. Chat first."

        # Attach reward
        last = current_session["dialog"][-1]
        last["reward"] = float(rating)

        # Prepare for PPO step
        user_text = last["user"]
        bot_text = last["bot"]

        # Token IDs for PPO: step() wants 1-D tensors, hence squeeze(0)
        query_ids = tokenizer(user_text, return_tensors="pt").input_ids.squeeze(0)
        response_ids = tokenizer(bot_text, return_tensors="pt").input_ids.squeeze(0)

        # BUG FIX: PPOTrainer.step expects rewards as a list of scalar
        # torch tensors; a plain Python float makes the step raise and
        # every rating silently fell through to the except branch.
        stats = ppo_trainer.step(
            [query_ids],
            [response_ids],
            [torch.tensor(last["reward"])]
        )
        print("🚀 PPO step stats:", stats)

        # Return the entire session as JSON
        return safe_json_dumps(current_session)
    except Exception as e:
        print("🔥 Error in rate_and_train:", e)
        traceback.print_exc()
        return "Failed to apply training step. See logs."


# -----------------------------------------------------------------------------
# 5. Gradio UI
# -----------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## FLAN-T5 Chatbot with On-the-Fly Reinforcement Learning")

    chat_box = gr.Chatbot(label="Chat History")
    user_input = gr.Textbox(placeholder="Type your message here…", label="You")
    send_btn = gr.Button("Send")
    reset_btn = gr.Button("Reset Conversation")

    with gr.Row():
        rating = gr.Slider(0, 5, step=1, value=0, label="Rate Last Reply")
        rate_btn = gr.Button("Apply Rating & Train")

    export_json = gr.Textbox(label="Session JSON", lines=10)

    # Reset chat
    reset_btn.click(
        fn=reset_session,
        inputs=None,
        outputs=chat_box
    )

    # Send user message
    send_btn.click(
        fn=chat_with_agent,
        inputs=user_input,
        outputs=chat_box
    )

    # Rate & train
    rate_btn.click(
        fn=rate_and_train,
        inputs=rating,
        outputs=export_json
    )


if __name__ == "__main__":
    demo.launch()