from flask import Flask, jsonify, request, render_template
from flask_cors import CORS
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

app = Flask(__name__)
CORS(app)

# Global model/tokenizer state, populated once by load_chatbot_model().
MODEL_PATH = "./models/fine-tuned-gpt2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = None
model = None


def load_chatbot_model():
    """Load the fine-tuned GPT-2 model and tokenizer onto `device`.

    Idempotent: does nothing if the model is already loaded. Mutates the
    module-level `tokenizer` and `model` globals.
    """
    global tokenizer, model
    if model is None:
        print(f"Loading chatbot model from {MODEL_PATH}...")
        print(f"Using device: {device}")
        tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)
        model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
        model.to(device)
        print("Model loaded successfully!")


# Load model on startup
load_chatbot_model()


@app.route("/")
def index():
    """Serve the chat interface."""
    return render_template('index.html')


@app.route("/api")
def root():
    """Basic API metadata endpoint."""
    return jsonify({
        "message": "Chatbot API",
        "status": "running",
        "model": "fine-tuned-gpt2",
        "device": str(device)
    })


@app.route("/health")
def health():
    """Health-check endpoint: reports whether the model is loaded."""
    return jsonify({
        "status": "healthy",
        "model_loaded": model is not None,
        "device": str(device)
    })


@app.route("/chat", methods=["POST"])
def chat():
    """Generate a chatbot response from conversation history.

    Expects JSON of the form ``{"user": [...], "ai": [...]}`` where the
    lists are parallel message histories (user has one more entry — the
    current message). Returns ``{"response": ..., "device": ...}`` or a
    500 with an ``error`` key on failure.
    """
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 500

    try:
        data = request.get_json()
        user_messages = data.get("user", [])
        ai_messages = data.get("ai", [])

        # Limit history to the most recent exchanges (7 user / 6 AI turns)
        # to keep the prompt within the model's context window.
        user_msgs = user_messages[-7:]
        ai_msgs = ai_messages[-6:]

        # Build the prompt: alternating user/AI turns separated by EOS,
        # then the current user message terminated by EOS.
        combined_prompt = ""
        for user_message, ai_message in zip(user_msgs[:-1], ai_msgs):
            combined_prompt += f" {user_message}{tokenizer.eos_token} {ai_message}{tokenizer.eos_token}"
        if user_msgs:
            combined_prompt += f" {user_msgs[-1]}{tokenizer.eos_token}"

        inputs = tokenizer.encode(combined_prompt, return_tensors="pt").to(device)
        attention_mask = torch.ones(inputs.shape, device=device)

        # no_grad: inference only — avoids building an autograd graph.
        # NOTE(review): temperature/top_k/top_p have no effect without
        # do_sample=True (beam search is deterministic); kept as-is to
        # preserve the original behavior.
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=50,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=2,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id,
                attention_mask=attention_mask,
                repetition_penalty=1.2
            )

        # BUG FIX: the original isolated the new reply with
        # response.split("") — an empty separator raises ValueError, so
        # every request fell into the except branch and returned a 500.
        # (Splitting on EOS also could never work after decoding with
        # skip_special_tokens=True.) Decode only the tokens generated
        # beyond the prompt instead.
        response = tokenizer.decode(
            outputs[0][inputs.shape[-1]:], skip_special_tokens=True
        ).strip()

        # Fall back to a default reply if the model produced nothing.
        if not response:
            response = "I'm not sure how to respond to that."

        return jsonify({
            "response": response,
            "device": str(device)
        })

    except Exception as e:
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=False)