import os
from typing import List
from chainlit.types import AskFileResponse
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
    AssistantRolePrompt,
)
from aimakerspace.openai_utils.embedding import EmbeddingModel
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
import chainlit as cl


system_template = """\
Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)


user_prompt_template = """\
Context:
{context}

Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)


class RetrievalAugmentedQAPipeline:
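    """A minimal RAG pipeline: retrieve relevant chunks from the vector
    database, pack them into the prompt, and stream the LLM's answer."""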
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
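        """Answer `user_query` using retrieved context; returns a dict with a
        streaming response generator and the retrieved context list."""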
        # Retrieve the top 4 most similar chunks for the query
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

        # Concatenate the retrieved chunks into a single context string
        context_prompt = ""
        for context in context_list:
            context_prompt += context[0] + "\n"

        formatted_system_prompt = system_role_prompt.create_message()

        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)

        # Stream the LLM's answer chunk by chunk
        async def generate_response():
            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
                yield chunk

        return {"response": generate_response(), "context": context_list}


text_splitter = CharacterTextSplitter()


def process_file(file: AskFileResponse):
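    """Copy the uploaded file to a temporary path, load it with the matching
    loader (PDF or plain text), split it into chunks, and return the chunks."""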
    import tempfile
    import shutil

    print(f"Processing file: {file.name}")

    # Create a temporary file with the same extension as the upload
    suffix = f".{file.name.split('.')[-1]}"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        # Copy the uploaded file's contents into the temporary file
        shutil.copyfile(file.path, temp_file.name)
        print(f"Created temporary file at: {temp_file.name}")

    # Pick a loader based on the file extension
    if file.name.lower().endswith('.pdf'):
        loader = PDFLoader(temp_file.name)
    else:
        loader = TextFileLoader(temp_file.name)

    try:
        # Load the documents and split them into chunks
        documents = loader.load_documents()
        texts = text_splitter.split_texts(documents)
        return texts
    finally:
        # Clean up the temporary file
        try:
            os.unlink(temp_file.name)
        except Exception as e:
            print(f"Error cleaning up temporary file: {e}")


@cl.on_chat_start
async def on_chat_start():
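    """Ask the user for a text or PDF file, index it into a vector database,
    and store the resulting QA pipeline in the user session."""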
    files = None

    # Wait for the user to upload a file
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a Text or PDF file to begin!",
            accept=["text/plain", "application/pdf"],
            max_size_mb=2,
            timeout=180,
        ).send()

    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`..."
    )
    await msg.send()

    # Load and chunk the uploaded file
    texts = process_file(file)

    print(f"Processing {len(texts)} text chunks")

    # Build the vector database from the text chunks
    vector_db = VectorDatabase()
    vector_db = await vector_db.abuild_from_list(texts)

    chat_openai = ChatOpenAI()

    # Create the retrieval-augmented QA pipeline
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
        vector_db_retriever=vector_db,
        llm=chat_openai
    )

    # Let the user know the system is ready
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", retrieval_augmented_qa_pipeline)


@cl.on_message
async def main(message: cl.Message):
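    """Run the incoming message through the QA pipeline and stream the answer."""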
    chain = cl.user_session.get("chain")

    msg = cl.Message(content="")
    result = await chain.arun_pipeline(message.content)

    # Stream tokens into the UI as they arrive
    async for stream_resp in result["response"]:
        await msg.stream_token(stream_resp)

    await msg.send()