Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -256,266 +256,286 @@ if selected_app == "3) Upload Datasets":
|
|
| 256 |
st.markdown("Go to this [google colab link](https://colab.research.google.com/drive/1eCpk9HUoCKZb--tiNyQSHFW2ojoaA35m) to get started")
|
| 257 |
|
| 258 |
if selected_app == "4) Create Chatbot":
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
requirements = '''
|
| 264 |
-
openai
|
| 265 |
-
scipy
|
| 266 |
-
streamlit
|
| 267 |
-
chromadb
|
| 268 |
-
datasets
|
| 269 |
-
'''
|
| 270 |
-
|
| 271 |
-
st.write("requirements.txt")
|
| 272 |
-
st.code(requirements, language='python')
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
import streamlit as st
|
| 277 |
-
from datasets import load_dataset
|
| 278 |
-
import chromadb
|
| 279 |
-
import string
|
| 280 |
-
|
| 281 |
-
from openai import OpenAI
|
| 282 |
-
|
| 283 |
-
import numpy as np
|
| 284 |
-
import pandas as pd
|
| 285 |
-
|
| 286 |
-
from scipy.spatial.distance import cosine
|
| 287 |
-
|
| 288 |
-
from typing import Dict, List
|
| 289 |
-
|
| 290 |
-
def merge_dataframes(dataframes):
|
| 291 |
-
# Concatenate the list of dataframes
|
| 292 |
-
combined_dataframe = pd.concat(dataframes, ignore_index=True)
|
| 293 |
-
|
| 294 |
-
# Ensure that the resulting dataframe only contains the columns "context", "questions", "answers"
|
| 295 |
-
combined_dataframe = combined_dataframe[['context', 'questions', 'answers']]
|
| 296 |
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
Args:
|
| 303 |
-
prompt: A string representing the prompt to send to the OpenAI API.
|
| 304 |
-
Returns:
|
| 305 |
-
A string representing the AI's generated response.
|
| 306 |
-
'''
|
| 307 |
-
|
| 308 |
-
# Use the OpenAI API to generate a response based on the input prompt.
|
| 309 |
-
client = OpenAI(api_key = os.environ["OPENAI_API_KEY"])
|
| 310 |
-
|
| 311 |
-
completion = client.chat.completions.create(
|
| 312 |
-
model="gpt-3.5-turbo-0125",
|
| 313 |
-
messages=[
|
| 314 |
-
{"role": "system", "content": directions},
|
| 315 |
-
{"role": "user", "content": prompt}
|
| 316 |
-
]
|
| 317 |
-
)
|
| 318 |
-
|
| 319 |
-
# Extract the text from the first (and only) choice in the response output.
|
| 320 |
-
ans = completion.choices[0].message.content
|
| 321 |
-
|
| 322 |
-
# Return the generated AI response.
|
| 323 |
-
return ans
|
| 324 |
-
|
| 325 |
-
def openai_text_embedding(prompt: str) -> str:
|
| 326 |
-
return openai.Embedding.create(input=prompt, model="text-embedding-ada-002")[
|
| 327 |
-
"data"
|
| 328 |
-
][0]["embedding"]
|
| 329 |
-
|
| 330 |
-
def calculate_sts_openai_score(sentence1: str, sentence2: str) -> float:
|
| 331 |
-
# Compute sentence embeddings
|
| 332 |
-
embedding1 = openai_text_embedding(sentence1) # Flatten the embedding array
|
| 333 |
-
embedding2 = openai_text_embedding(sentence2) # Flatten the embedding array
|
| 334 |
-
|
| 335 |
-
# Convert to array
|
| 336 |
-
embedding1 = np.asarray(embedding1)
|
| 337 |
-
embedding2 = np.asarray(embedding2)
|
| 338 |
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
return similarity_score
|
| 343 |
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
lambda x: calculate_sts_openai_score(str(x), sentence)
|
| 349 |
-
)
|
| 350 |
-
|
| 351 |
-
sorted_dataframe = dataframe.sort_values(by="stsopenai", ascending=False)
|
| 352 |
|
|
|
|
| 353 |
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
df: A pandas DataFrame with columns named 'questions' and 'answers'.
|
| 361 |
-
Returns:
|
| 362 |
-
A list of dictionaries, with each dictionary containing a 'question' and 'answer' key-value pair.
|
| 363 |
'''
|
| 364 |
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
qa_dict_quest = {"role": "user", "content": row["questions"]}
|
| 372 |
-
qa_dict_ans = {"role": "assistant", "content": row["answers"]}
|
| 373 |
-
|
| 374 |
-
# Add the dictionary to the result list
|
| 375 |
-
result.append(qa_dict_quest)
|
| 376 |
-
result.append(qa_dict_ans)
|
| 377 |
-
|
| 378 |
-
# Return the list of dictionaries
|
| 379 |
-
return result
|
| 380 |
-
|
| 381 |
-
st.sidebar.markdown(f'''This is a chatbot to help you learn more about {organization_name}''')
|
| 382 |
-
|
| 383 |
-
domain = st.sidebar.selectbox(f"Select a topic", {domains})
|
| 384 |
-
|
| 385 |
-
special_threshold = 0.3
|
| 386 |
-
|
| 387 |
-
n_results = 3
|
| 388 |
-
|
| 389 |
-
clear_button = st.sidebar.button("Clear Conversation", key="clear")
|
| 390 |
-
|
| 391 |
-
if clear_button:
|
| 392 |
-
st.session_state.messages = []
|
| 393 |
-
st.session_state.curr_domain = ""
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
###
|
| 398 |
-
###
|
| 399 |
-
### Load the dataset from a provided source.
|
| 400 |
-
###
|
| 401 |
-
###
|
| 402 |
-
|
| 403 |
-
initial_input = f"Tell me about {organization_name}"
|
| 404 |
-
|
| 405 |
-
# Initialize a new client for ChromeDB.
|
| 406 |
-
client = chromadb.Client()
|
| 407 |
-
|
| 408 |
-
# Generate a random number between 1 billion and 10 billion.
|
| 409 |
-
random_number: int = np.random.randint(low=1e9, high=1e10)
|
| 410 |
-
|
| 411 |
-
# Generate a random string consisting of 10 uppercase letters and digits.
|
| 412 |
-
random_string: str = "".join(
|
| 413 |
-
np.random.choice(list(string.ascii_uppercase + string.digits), size=10)
|
| 414 |
-
)
|
| 415 |
-
|
| 416 |
-
# Combine the random number and random string into one identifier.
|
| 417 |
-
combined_string: str = f"{random_number}{random_string}"
|
| 418 |
-
|
| 419 |
-
# Create a new collection in ChromeDB with the combined string as its name.
|
| 420 |
-
collection = client.create_collection(combined_string)
|
| 421 |
-
|
| 422 |
-
st.title(f"{organization_name} Chatbot")
|
| 423 |
-
|
| 424 |
-
# Initialize chat history
|
| 425 |
-
if "messages" not in st.session_state:
|
| 426 |
-
st.session_state.messages = []
|
| 427 |
-
|
| 428 |
-
if "curr_domain" not in st.session_state:
|
| 429 |
-
st.session_state.curr_domain = ""
|
| 430 |
-
|
| 431 |
-
###
|
| 432 |
-
###
|
| 433 |
-
### init_messages dict (one key per domain)
|
| 434 |
-
###
|
| 435 |
-
###
|
| 436 |
-
|
| 437 |
-
###
|
| 438 |
-
###
|
| 439 |
-
### chatbot_instructions dict (one key per domain)
|
| 440 |
-
###
|
| 441 |
-
###
|
| 442 |
-
|
| 443 |
-
# Embed and store the first N supports for this demo
|
| 444 |
-
with st.spinner("Loading, please be patient with us ... 🙏"):
|
| 445 |
-
L = len(dataset["train"]["questions"])
|
| 446 |
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
|
| 453 |
-
|
| 454 |
-
|
| 455 |
|
| 456 |
-
|
| 457 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
-
answer = call_chatgpt(engineered_prompt, directions)
|
| 509 |
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
st.markdown("Go to this [google colab link](https://colab.research.google.com/drive/1eCpk9HUoCKZb--tiNyQSHFW2ojoaA35m) to get started")
|
| 257 |
|
| 258 |
if selected_app == "4) Create Chatbot":
|
| 259 |
+
if st.session_state.error != "":
|
| 260 |
+
st.error(st.session_state.error)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
+
if st.session_state.success != None:
|
| 263 |
+
st.success("Success! Copy/paste the requirements.txt and app.py files into your HuggingFace Space")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
+
st.write('requirements.txt')
|
| 266 |
+
st.code(st.session_state.success[0], language='python')
|
| 267 |
+
|
| 268 |
+
st.write('app.py')
|
| 269 |
+
st.code(st.session_state.success[1], language='python')
|
| 270 |
|
| 271 |
+
if st.button('Reset'):
|
| 272 |
+
st.session_state.clear()
|
| 273 |
+
st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
+
else:
|
| 276 |
+
organization_name = st.text_input("What is the name of your organization", "")
|
|
|
|
|
|
|
| 277 |
|
| 278 |
+
# num_domains = st.number_input("Number sentences per Q/A pair", value=2, step=1, min_value=1, max_value=3)
|
| 279 |
+
submit = st.button("Submit")
|
| 280 |
+
if submit:
|
| 281 |
+
st.session_state.submit = True
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
+
if st.session_state.submit:
|
| 284 |
|
| 285 |
+
requirements = '''
|
| 286 |
+
openai
|
| 287 |
+
scipy
|
| 288 |
+
streamlit
|
| 289 |
+
chromadb
|
| 290 |
+
datasets
|
|
|
|
|
|
|
|
|
|
| 291 |
'''
|
| 292 |
|
| 293 |
+
app = """
|
| 294 |
+
import os
|
| 295 |
+
import streamlit as st
|
| 296 |
+
from datasets import load_dataset
|
| 297 |
+
import chromadb
|
| 298 |
+
import string
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
+
from openai import OpenAI
|
| 301 |
+
|
| 302 |
+
import numpy as np
|
| 303 |
+
import pandas as pd
|
| 304 |
+
|
| 305 |
+
from scipy.spatial.distance import cosine
|
| 306 |
+
|
| 307 |
+
from typing import Dict, List
|
| 308 |
+
|
| 309 |
+
def merge_dataframes(dataframes):
|
| 310 |
+
# Concatenate the list of dataframes
|
| 311 |
+
combined_dataframe = pd.concat(dataframes, ignore_index=True)
|
| 312 |
|
| 313 |
+
# Ensure that the resulting dataframe only contains the columns "context", "questions", "answers"
|
| 314 |
+
combined_dataframe = combined_dataframe[['context', 'questions', 'answers']]
|
| 315 |
|
| 316 |
+
return combined_dataframe
|
| 317 |
+
|
| 318 |
+
def call_chatgpt(prompt: str, directions: str) -> str:
|
| 319 |
+
'''
|
| 320 |
+
Uses the OpenAI API to generate an AI response to a prompt.
|
| 321 |
+
Args:
|
| 322 |
+
prompt: A string representing the prompt to send to the OpenAI API.
|
| 323 |
+
Returns:
|
| 324 |
+
A string representing the AI's generated response.
|
| 325 |
+
'''
|
| 326 |
|
| 327 |
+
# Use the OpenAI API to generate a response based on the input prompt.
|
| 328 |
+
client = OpenAI(api_key = os.environ["OPENAI_API_KEY"])
|
| 329 |
+
|
| 330 |
+
completion = client.chat.completions.create(
|
| 331 |
+
model="gpt-3.5-turbo-0125",
|
| 332 |
+
messages=[
|
| 333 |
+
{"role": "system", "content": directions},
|
| 334 |
+
{"role": "user", "content": prompt}
|
| 335 |
+
]
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# Extract the text from the first (and only) choice in the response output.
|
| 339 |
+
ans = completion.choices[0].message.content
|
| 340 |
+
|
| 341 |
+
# Return the generated AI response.
|
| 342 |
+
return ans
|
| 343 |
+
|
| 344 |
+
def openai_text_embedding(prompt: str) -> str:
|
| 345 |
+
return openai.Embedding.create(input=prompt, model="text-embedding-ada-002")[
|
| 346 |
+
"data"
|
| 347 |
+
][0]["embedding"]
|
| 348 |
+
|
| 349 |
+
def calculate_sts_openai_score(sentence1: str, sentence2: str) -> float:
|
| 350 |
+
# Compute sentence embeddings
|
| 351 |
+
embedding1 = openai_text_embedding(sentence1) # Flatten the embedding array
|
| 352 |
+
embedding2 = openai_text_embedding(sentence2) # Flatten the embedding array
|
| 353 |
+
|
| 354 |
+
# Convert to array
|
| 355 |
+
embedding1 = np.asarray(embedding1)
|
| 356 |
+
embedding2 = np.asarray(embedding2)
|
| 357 |
+
|
| 358 |
+
# Calculate cosine similarity between the embeddings
|
| 359 |
+
similarity_score = 1 - cosine(embedding1, embedding2)
|
| 360 |
+
|
| 361 |
+
return similarity_score
|
| 362 |
+
|
| 363 |
+
def add_dist_score_column(
|
| 364 |
+
dataframe: pd.DataFrame, sentence: str,
|
| 365 |
+
) -> pd.DataFrame:
|
| 366 |
+
dataframe["stsopenai"] = dataframe["questions"].apply(
|
| 367 |
+
lambda x: calculate_sts_openai_score(str(x), sentence)
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
sorted_dataframe = dataframe.sort_values(by="stsopenai", ascending=False)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
return sorted_dataframe.iloc[:5, :]
|
| 374 |
+
|
| 375 |
+
def convert_to_list_of_dict(df: pd.DataFrame) -> List[Dict[str, str]]:
|
| 376 |
+
'''
|
| 377 |
+
Reads in a pandas DataFrame and produces a list of dictionaries with two keys each, 'question' and 'answer.'
|
| 378 |
+
Args:
|
| 379 |
+
df: A pandas DataFrame with columns named 'questions' and 'answers'.
|
| 380 |
+
Returns:
|
| 381 |
+
A list of dictionaries, with each dictionary containing a 'question' and 'answer' key-value pair.
|
| 382 |
+
'''
|
| 383 |
+
|
| 384 |
+
# Initialize an empty list to store the dictionaries
|
| 385 |
+
result = []
|
| 386 |
+
|
| 387 |
+
# Loop through each row of the DataFrame
|
| 388 |
+
for index, row in df.iterrows():
|
| 389 |
+
# Create a dictionary with the current question and answer
|
| 390 |
+
qa_dict_quest = {"role": "user", "content": row["questions"]}
|
| 391 |
+
qa_dict_ans = {"role": "assistant", "content": row["answers"]}
|
| 392 |
+
|
| 393 |
+
# Add the dictionary to the result list
|
| 394 |
+
result.append(qa_dict_quest)
|
| 395 |
+
result.append(qa_dict_ans)
|
| 396 |
+
|
| 397 |
+
# Return the list of dictionaries
|
| 398 |
+
return result
|
| 399 |
+
|
| 400 |
+
st.sidebar.markdown(f'''This is a chatbot to help you learn more about {organization_name}''')
|
| 401 |
+
|
| 402 |
+
domain = st.sidebar.selectbox("Select a topic", domains)
|
| 403 |
+
|
| 404 |
+
special_threshold = 0.3
|
| 405 |
+
|
| 406 |
+
n_results = 3
|
| 407 |
+
|
| 408 |
+
clear_button = st.sidebar.button("Clear Conversation", key="clear")
|
| 409 |
+
|
| 410 |
+
if clear_button:
|
| 411 |
+
st.session_state.messages = []
|
| 412 |
+
st.session_state.curr_domain = ""
|
| 413 |
|
|
|
|
| 414 |
|
| 415 |
+
|
| 416 |
+
###
|
| 417 |
+
###
|
| 418 |
+
### Load the dataset from a provided source.
|
| 419 |
+
###
|
| 420 |
+
###
|
| 421 |
+
|
| 422 |
+
initial_input = f"Tell me about {organization_name}"
|
| 423 |
+
|
| 424 |
+
# Initialize a new client for ChromeDB.
|
| 425 |
+
client = chromadb.Client()
|
| 426 |
+
|
| 427 |
+
# Generate a random number between 1 billion and 10 billion.
|
| 428 |
+
random_number: int = np.random.randint(low=1e9, high=1e10)
|
| 429 |
+
|
| 430 |
+
# Generate a random string consisting of 10 uppercase letters and digits.
|
| 431 |
+
random_string: str = "".join(
|
| 432 |
+
np.random.choice(list(string.ascii_uppercase + string.digits), size=10)
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
# Combine the random number and random string into one identifier.
|
| 436 |
+
combined_string: str = f"{random_number}{random_string}"
|
| 437 |
+
|
| 438 |
+
# Create a new collection in ChromeDB with the combined string as its name.
|
| 439 |
+
collection = client.create_collection(combined_string)
|
| 440 |
+
|
| 441 |
+
st.title(f"{organization_name} Chatbot")
|
| 442 |
+
|
| 443 |
+
# Initialize chat history
|
| 444 |
+
if "messages" not in st.session_state:
|
| 445 |
+
st.session_state.messages = []
|
| 446 |
+
|
| 447 |
+
if "curr_domain" not in st.session_state:
|
| 448 |
+
st.session_state.curr_domain = ""
|
| 449 |
+
|
| 450 |
+
###
|
| 451 |
+
###
|
| 452 |
+
### init_messages dict (one key per domain)
|
| 453 |
+
###
|
| 454 |
+
###
|
| 455 |
+
|
| 456 |
+
###
|
| 457 |
+
###
|
| 458 |
+
### chatbot_instructions dict (one key per domain)
|
| 459 |
+
###
|
| 460 |
+
###
|
| 461 |
+
|
| 462 |
+
# Embed and store the first N supports for this demo
|
| 463 |
+
with st.spinner("Loading, please be patient with us ... 🙏"):
|
| 464 |
+
L = len(dataset["train"]["questions"])
|
| 465 |
+
|
| 466 |
+
collection.add(
|
| 467 |
+
ids=[str(i) for i in range(0, L)], # IDs are just strings
|
| 468 |
+
documents=dataset["train"]["questions"], # Enter questions here
|
| 469 |
+
metadatas=[{"type": "support"} for _ in range(0, L)],
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
if st.session_state.curr_domain != domain:
|
| 473 |
+
st.session_state.messages = []
|
| 474 |
+
|
| 475 |
+
init_message = init_messages[domain]
|
| 476 |
+
st.session_state.messages.append({"role": "assistant", "content": init_message})
|
| 477 |
+
|
| 478 |
+
st.session_state.curr_domain = domain
|
| 479 |
+
|
| 480 |
+
# Display chat messages from history on app rerun
|
| 481 |
+
for message in st.session_state.messages:
|
| 482 |
+
with st.chat_message(message["role"]):
|
| 483 |
+
st.markdown(message["content"])
|
| 484 |
+
|
| 485 |
+
# React to user input
|
| 486 |
+
if prompt := st.chat_input(f"Tell me about {organization_name}"):
|
| 487 |
+
# Display user message in chat message container
|
| 488 |
+
st.chat_message("user").markdown(prompt)
|
| 489 |
+
# Add user message to chat history
|
| 490 |
+
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 491 |
+
|
| 492 |
+
question = prompt
|
| 493 |
+
|
| 494 |
+
results = collection.query(query_texts=question, n_results=n_results)
|
| 495 |
+
|
| 496 |
+
idx = results["ids"][0]
|
| 497 |
+
idx = [int(i) for i in idx]
|
| 498 |
+
ref = pd.DataFrame(
|
| 499 |
+
{
|
| 500 |
+
"idx": idx,
|
| 501 |
+
"questions": [dataset["train"]["questions"][i] for i in idx],
|
| 502 |
+
"answers": [dataset["train"]["answers"][i] for i in idx],
|
| 503 |
+
"distances": results["distances"][0],
|
| 504 |
+
}
|
| 505 |
+
)
|
| 506 |
+
# special_threshold = st.sidebar.slider('How old are you?', 0, 0.6, 0.1) # 0.3
|
| 507 |
+
# special_threshold = 0.3
|
| 508 |
+
filtered_ref = ref[ref["distances"] < special_threshold]
|
| 509 |
+
if filtered_ref.shape[0] > 0:
|
| 510 |
+
# st.success("There are highly relevant information in our database.")
|
| 511 |
+
ref_from_db_search = filtered_ref["answers"].str.cat(sep=" ")
|
| 512 |
+
final_ref = filtered_ref
|
| 513 |
+
else:
|
| 514 |
+
# st.warning(
|
| 515 |
+
# "The database may not have relevant information to help your question so please be aware of hallucinations."
|
| 516 |
+
# )
|
| 517 |
+
ref_from_db_search = ref["answers"].str.cat(sep=" ")
|
| 518 |
+
final_ref = ref
|
| 519 |
+
|
| 520 |
+
engineered_prompt = f'''
|
| 521 |
+
Based on the context: {ref_from_db_search},
|
| 522 |
+
answer the user question: {question}.
|
| 523 |
+
'''
|
| 524 |
+
|
| 525 |
+
directions = chatbot_instructions[domain]
|
| 526 |
+
|
| 527 |
+
answer = call_chatgpt(engineered_prompt, directions)
|
| 528 |
+
|
| 529 |
+
response = answer
|
| 530 |
+
# Display assistant response in chat message container
|
| 531 |
+
with st.chat_message("assistant"):
|
| 532 |
+
st.markdown(response)
|
| 533 |
+
with st.expander("See reference:"):
|
| 534 |
+
st.table(final_ref)
|
| 535 |
+
# Add assistant response to chat history
|
| 536 |
+
st.session_state.messages.append({"role": "assistant", "content": response})
|
| 537 |
+
"""
|
| 538 |
+
|
| 539 |
+
st.session_state.clear()
|
| 540 |
+
st.session_state.success = (requirements, app)
|
| 541 |
+
st.rerun()
|