File size: 3,931 Bytes
4d0e37d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b02a875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d0e37d
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from flask import Flask, jsonify, request, render_template
from flask_cors import CORS
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

# Flask app with CORS enabled so a browser UI served from another
# origin can call the JSON endpoints.
app = Flask(__name__)
CORS(app)

# Global variables for model and tokenizer
# Directory holding the fine-tuned GPT-2 weights and tokenizer files.
MODEL_PATH = "./models/fine-tuned-gpt2"
# Prefer GPU when available; the model and all input tensors live here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Populated by load_chatbot_model(); None until loading completes.
tokenizer = None
model = None

def load_chatbot_model():
    """Load the fine-tuned GPT-2 tokenizer and model into the module globals.

    Idempotent: returns immediately if the model is already loaded.
    """
    global tokenizer, model
    if model is not None:
        return

    print(f"Loading chatbot model from {MODEL_PATH}...")
    print(f"Using device: {device}")

    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)
    model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
    model.to(device)

    print("Model loaded successfully!")

# Load model on startup
# Eager-load at import time so the first /chat request does not pay
# the (slow) model-loading cost.
load_chatbot_model()

@app.route("/")
def index():
    """Serve the browser chat interface (templates/index.html)."""
    page = render_template('index.html')
    return page

@app.route("/api")
def root():
    """Return basic service metadata as JSON."""
    info = {
        "message": "Chatbot API",
        "status": "running",
        "model": "fine-tuned-gpt2",
        "device": str(device),
    }
    return jsonify(info)

@app.route("/health")
def health():
    """Liveness probe: report load state of the model and the active device."""
    status = {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": str(device),
    }
    return jsonify(status)

def _build_prompt(user_messages, ai_messages):
    """Assemble the generation prompt from the conversation history.

    Keeps at most the last 7 user turns and 6 AI turns, interleaves past
    turns as "<user> ...<eos><AI> ...<eos>" pairs, and ends with the newest
    user message followed by an open "<AI>" tag for the model to complete.
    """
    # Slicing already handles lists shorter than the window; no length
    # checks needed (the originals were redundant).
    user_msgs = user_messages[-7:]
    ai_msgs = ai_messages[-6:]

    prompt = ""
    # Pair every past user turn (all but the newest) with its AI reply.
    for user_message, ai_message in zip(user_msgs[:-1], ai_msgs):
        prompt += f"<user> {user_message}{tokenizer.eos_token}<AI> {ai_message}{tokenizer.eos_token}"

    # Append the newest user message and leave the AI tag open.
    if user_msgs:
        prompt += f"<user> {user_msgs[-1]}{tokenizer.eos_token}<AI>"
    return prompt


def _extract_reply(decoded):
    """Pull the newest AI reply out of the decoded model output string."""
    # The new reply follows the last "<AI>" tag in the decoded text.
    if "<AI>" in decoded:
        decoded = decoded.split("<AI>")[-1].strip()
    # The model may generate a hallucinated follow-up "<user>" turn; drop it.
    if "<user>" in decoded:
        decoded = decoded.split("<user>")[0].strip()
    # Scrub any stray tags that survived the splits above.
    decoded = decoded.replace("<AI>", "").replace("<user>", "").strip()
    # Fall back to a canned reply rather than returning an empty string.
    return decoded or "I'm not sure how to respond to that."


@app.route("/chat", methods=["POST"])
def chat():
    """Generate a chatbot response based on conversation history.

    Expects a JSON body {"user": [...], "ai": [...]} and returns
    {"response": str, "device": str}. Responds 400 for a missing or
    malformed JSON body and 500 for model/generation failures.
    """
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 500

    # silent=True returns None instead of raising on a bad body, so a
    # client mistake yields a 400 rather than a misleading 500.
    data = request.get_json(silent=True)
    if data is None:
        return jsonify({"error": "Invalid or missing JSON body"}), 400

    try:
        combined_prompt = _build_prompt(data.get("user", []), data.get("ai", []))

        # Tokenize and generate
        inputs = tokenizer.encode(combined_prompt, return_tensors="pt").to(device)
        # ones_like matches the input ids' integer dtype and device; no
        # padding is present, so the mask is all ones.
        attention_mask = torch.ones_like(inputs)

        # Inference only: skip autograd bookkeeping to save time and memory.
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=50,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=2,
                # NOTE(review): temperature/top_k/top_p only take effect when
                # do_sample=True; with pure beam search they are ignored.
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id,
                attention_mask=attention_mask,
                repetition_penalty=1.2,
            )

        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = _extract_reply(decoded)

        return jsonify({
            "response": response,
            "device": str(device)
        })

    except Exception as e:
        # Surface unexpected tokenizer/model failures to the client.
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Bind all interfaces so the server is reachable from outside a
    # container; debug disabled (no reloader/werkzeug debugger).
    app.run(host="0.0.0.0", port=7860, debug=False)