S-Dreamer commited on
Commit
f40a9d3
·
verified ·
1 Parent(s): cc3d721

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -11
app.py CHANGED
@@ -1,25 +1,45 @@
1
  import gradio as gr
 
 
2
 
3
- # Function to process the input
4
- def chatbot_response(user_input):
5
- # Simple logic for chatbot (e.g., greeting or keywords detection)
6
- if "hello" in user_input.lower():
7
- return "Hello! How can I help you today?"
8
- elif "how are you" in user_input.lower():
9
- return "I'm doing great, thank you for asking!"
10
- else:
11
- return "I'm sorry, I didn't understand that. Can you rephrase?"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Define the Gradio interface
14
  def create_interface():
15
  # Gradio interface setup
16
  with gr.Blocks() as demo:
17
- gr.Markdown("## Chatbot Application")
18
  with gr.Row():
19
  chat_box = gr.Chatbot()
20
  input_box = gr.Textbox(placeholder="Type a message...")
21
 
22
- input_box.submit(chatbot_response, input_box, chat_box)
 
23
 
24
  return demo
25
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
+ # Load pre-trained Hugging Face model (DialoGPT-medium)
6
+ model_name = "microsoft/DialoGPT-medium"
7
+ model = AutoModelForCausalLM.from_pretrained(model_name)
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+
10
+ # Chat history to maintain context
11
+ chat_history_ids = None
12
+
13
+ # Function to generate chatbot response using Hugging Face model
14
+ def chatbot_response(user_input, chat_history):
15
+ # Tokenize user input and add chat history
16
+ new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
17
+
18
+ # Append new input to the chat history
19
+ bot_input_ids = torch.cat([chat_history, new_user_input_ids], dim=-1) if chat_history is not None else new_user_input_ids
20
+
21
+ # Generate response from the model
22
+ chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id,
23
+ temperature=0.7, top_k=50, top_p=0.95, num_return_sequences=1,
24
+ no_repeat_ngram_size=3, do_sample=True)
25
+
26
+ # Decode the response and return
27
+ chat_history_ids = chat_history_ids[:, new_user_input_ids.shape[-1]:] # Keep the new response only
28
+ bot_output = tokenizer.decode(chat_history_ids[0], skip_special_tokens=True)
29
+
30
+ return bot_output, chat_history_ids
31
 
32
  # Define the Gradio interface
33
  def create_interface():
34
  # Gradio interface setup
35
  with gr.Blocks() as demo:
36
+ gr.Markdown("## Chatbot powered by DialoGPT")
37
  with gr.Row():
38
  chat_box = gr.Chatbot()
39
  input_box = gr.Textbox(placeholder="Type a message...")
40
 
41
+ # Submit the input and get the response
42
+ input_box.submit(chatbot_response, [input_box, chat_box], [chat_box, gr.State()])
43
 
44
  return demo
45