| | import asyncio |
| | import os |
| | from flask import Blueprint, request, Response |
| | import json |
| | import datetime |
| | import logging |
| | import traceback |
| |
|
| | from pymongo import MongoClient |
| | from bson.objectid import ObjectId |
| | from transformers import GPT2TokenizerFast |
| |
|
| |
|
| |
|
| | from application.core.settings import settings |
| | from application.vectorstore.vector_creator import VectorCreator |
| | from application.llm.llm_creator import LLMCreator |
| | from application.error import bad_request |
| |
|
| |
|
| |
|
logger = logging.getLogger(__name__)

# MongoDB connection and the collections this blueprint reads/writes.
mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
conversations_collection = db["conversations"]
vectors_collection = db["vectors"]
prompts_collection = db["prompts"]
answer = Blueprint('answer', __name__)

# Map the configured backend to the concrete model name passed to llm.gen/gen_stream.
if settings.LLM_NAME == "gpt4":
    gpt_model = 'gpt-4'
elif settings.LLM_NAME == "anthropic":
    gpt_model = 'claude-2'
else:
    gpt_model = 'gpt-3.5-turbo'

# Load the built-in prompt templates at import time.
# current_dir is three directory levels above this file — presumably the
# application package root containing the "prompts" folder (verify on refactor).
# NOTE(review): files are opened without an explicit encoding; relies on the
# platform default — confirm templates are ASCII/UTF-8 on all deployments.
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
    chat_combine_template = f.read()

with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
    chat_reduce_template = f.read()

with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
    chat_combine_creative = f.read()

with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
    chat_combine_strict = f.read()

# When a key is configured globally, request handlers will not require the
# client to send it per-request (see stream/api_answer/api_search below).
api_key_set = settings.API_KEY is not None
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
| |
|
| |
|
async def async_generate(chain, question, chat_history):
    """Invoke the chain asynchronously and return its result."""
    payload = {"question": question, "chat_history": chat_history}
    return await chain.arun(payload)
| |
|
| |
|
def count_tokens(string):
    """Return the number of GPT-2 BPE tokens in *string*.

    The tokenizer is loaded once and cached on the function object:
    ``GPT2TokenizerFast.from_pretrained`` reads/downloads vocabulary files,
    which is far too expensive to repeat on every call (this runs per
    history entry on every request).
    """
    tokenizer = getattr(count_tokens, "_tokenizer", None)
    if tokenizer is None:
        tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
        count_tokens._tokenizer = tokenizer
    return len(tokenizer(string)['input_ids'])
| |
|
| |
|
def run_async_chain(chain, question, chat_history):
    """Synchronously drive the async chain and return ``{"answer": ...}``.

    A fresh event loop is created (and always closed) because this is called
    from synchronous Flask handlers where no loop is running.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    result = {}
    try:
        # Assign inside the try block: previously the assignment happened
        # after the try/finally, so a failure in the coroutine left `answer`
        # unbound and the real exception was followed by a NameError.
        result["answer"] = loop.run_until_complete(
            async_generate(chain, question, chat_history)
        )
    finally:
        loop.close()
    return result
| |
|
| |
|
def get_vectorstore(data):
    """Map a request's ``active_docs`` value to an on-disk vector store path.

    ``"default"`` (or a missing key) selects the empty/default store;
    ``"local/..."`` documents live under ``indexes/``; everything else under
    ``vectors/``.  The result is always rooted at the ``application`` dir.
    """
    vectorstore = ""
    if "active_docs" in data:
        source = data["active_docs"].split("/")[0]
        if source == "local":
            vectorstore = "indexes/" + data["active_docs"]
        elif source != "default":
            vectorstore = "vectors/" + data["active_docs"]
        # "default" keeps the empty path.  (The original re-checked
        # data["active_docs"] == "default" after the branch, but that case
        # had already produced "" — dead code, removed.)
    return os.path.join("application", vectorstore)
| |
|
| |
|
def is_azure_configured():
    """Truthy only when every Azure OpenAI setting is present.

    Returns the short-circuited `and` chain (not a bool), falsy as soon as
    any of the three settings is unset.
    """
    return (settings.OPENAI_API_BASE
            and settings.OPENAI_API_VERSION
            and settings.AZURE_DEPLOYMENT_NAME)
| |
|
| |
|
def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conversation_id):
    """Stream an answer as server-sent events and persist the exchange.

    Yields ``data: ...`` SSE chunks: the answer piece-by-piece, then a
    ``{"type": "id"}`` event carrying the conversation's Mongo id, then a
    ``{"type": "end"}`` sentinel.

    Args:
        question: the user's current question.
        docsearch: vector store used to retrieve supporting documents.
        chat_history: list of ``{"prompt", "response"}`` dicts (oldest first).
        api_key: key handed to the LLM backend.
        prompt_id: ``'default'``/``'creative'``/``'strict'``, or the ObjectId
            of a custom prompt stored in ``prompts_collection``.
        conversation_id: existing conversation to append to, or ``None`` to
            create a new one (named by an LLM-generated 3-word summary).
    """
    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)

    if prompt_id == 'default':
        prompt = chat_combine_template
    elif prompt_id == 'creative':
        prompt = chat_combine_creative
    elif prompt_id == 'strict':
        prompt = chat_combine_strict
    else:
        prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]

    docs = docsearch.search(question, k=2)
    if settings.LLM_NAME == "llama.cpp":
        # Keep at most one document for llama.cpp.  Slicing instead of
        # [docs[0]] avoids an IndexError when the search returns nothing.
        docs = docs[:1]
    docs_together = "\n".join([doc.page_content for doc in docs])
    p_chat_combine = prompt.replace("{summaries}", docs_together)
    messages_combine = [{"role": "system", "content": p_chat_combine}]
    source_log_docs = []
    for doc in docs:
        if doc.metadata:
            source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
        else:
            source_log_docs.append({"title": doc.page_content, "text": doc.page_content})

    if len(chat_history) > 1:
        tokens_current_history = 0
        # Walk the history newest-first and stop including turns once the
        # token budget is spent.  reversed() iterates without mutating the
        # caller's list (the original used chat_history.reverse() in place).
        for turn in reversed(chat_history):
            if "prompt" in turn and "response" in turn:
                tokens_batch = count_tokens(turn["prompt"]) + count_tokens(turn["response"])
                if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
                    tokens_current_history += tokens_batch
                    messages_combine.append({"role": "user", "content": turn["prompt"]})
                    messages_combine.append({"role": "system", "content": turn["response"]})
    messages_combine.append({"role": "user", "content": question})

    response_full = ""
    completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
                                messages=messages_combine)
    for line in completion:
        data = json.dumps({"answer": str(line)})
        response_full += str(line)
        yield f"data: {data}\n\n"

    if conversation_id is not None:
        # Append this turn to the existing conversation.
        conversations_collection.update_one(
            {"_id": ObjectId(conversation_id)},
            {"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}},
        )
    else:
        # New conversation: ask the LLM for a <=3-word name, then insert.
        messages_summary = [{"role": "assistant", "content": "Summarise following conversation in no more than 3 "
                                                             "words, respond ONLY with the summary, use the same "
                                                             "language as the system \n\nUser: " + question + "\n\n" +
                                                             "AI: " +
                                                             response_full},
                            {"role": "user", "content": "Summarise following conversation in no more than 3 words, "
                                                        "respond ONLY with the summary, use the same language as the "
                                                        "system"}]

        completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
                             messages=messages_summary, max_tokens=30)
        conversation_id = conversations_collection.insert_one(
            {"user": "local",
             "date": datetime.datetime.utcnow(),
             "name": completion,
             "queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]}
        ).inserted_id

    # Tell the client which conversation this belongs to, then end the stream.
    data = json.dumps({"type": "id", "id": str(conversation_id)})
    yield f"data: {data}\n\n"
    data = json.dumps({"type": "end"})
    yield f"data: {data}\n\n"
| |
|
| |
|
@answer.route("/stream", methods=["POST"])
def stream():
    """SSE endpoint: stream an answer for the posted question.

    Body fields: "question", "history" (a JSON-encoded string),
    "conversation_id", and optionally "prompt_id", "api_key",
    "embeddings_key", "active_docs".
    """
    data = request.get_json()
    question = data["question"]
    # On this endpoint the history arrives JSON-encoded inside the body.
    history = json.loads(data["history"])
    conversation_id = data["conversation_id"]
    prompt_id = data.get("prompt_id", 'default')

    # Globally configured keys win; otherwise the client must supply them.
    api_key = settings.API_KEY if api_key_set else data["api_key"]
    embeddings_key = settings.EMBEDDINGS_KEY if embeddings_key_set else data["embeddings_key"]

    if "active_docs" in data:
        vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
    else:
        vectorstore = ""
    docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)

    return Response(
        complete_stream(question, docsearch,
                        chat_history=history, api_key=api_key,
                        prompt_id=prompt_id,
                        conversation_id=conversation_id), mimetype="text/event-stream"
    )
| |
|
| |
|
@answer.route("/api/answer", methods=["POST"])
def api_answer():
    """Non-streaming answer endpoint.

    Expects JSON with "question", "history" (a list of prompt/response
    dicts), and optionally "conversation_id", "prompt_id", "api_key",
    "embeddings_key", "active_docs".  Returns a dict with "answer",
    "sources" and "conversation_id", or a 500 bad_request payload on error.
    """
    data = request.get_json()
    question = data["question"]
    history = data["history"]
    conversation_id = data.get("conversation_id")
    prompt_id = data.get("prompt_id", 'default')

    # Globally configured keys win; otherwise the client must supply them.
    api_key = settings.API_KEY if api_key_set else data["api_key"]
    embeddings_key = settings.EMBEDDINGS_KEY if embeddings_key_set else data["embeddings_key"]

    if prompt_id == 'default':
        prompt = chat_combine_template
    elif prompt_id == 'creative':
        prompt = chat_combine_creative
    elif prompt_id == 'strict':
        prompt = chat_combine_strict
    else:
        prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]

    try:
        vectorstore = get_vectorstore(data)
        docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)
        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)

        # Retrieve supporting documents and build the system prompt.
        docs = docsearch.search(question, k=2)
        docs_together = "\n".join([doc.page_content for doc in docs])
        p_chat_combine = prompt.replace("{summaries}", docs_together)
        messages_combine = [{"role": "system", "content": p_chat_combine}]
        source_log_docs = []
        for doc in docs:
            if doc.metadata:
                source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
            else:
                source_log_docs.append({"title": doc.page_content, "text": doc.page_content})

        if len(history) > 1:
            tokens_current_history = 0
            # Include turns newest-first until the token budget is spent.
            # reversed() iterates without mutating the request's history list
            # (the original used history.reverse() in place).
            for turn in reversed(history):
                if "prompt" in turn and "response" in turn:
                    tokens_batch = count_tokens(turn["prompt"]) + count_tokens(turn["response"])
                    if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
                        tokens_current_history += tokens_batch
                        messages_combine.append({"role": "user", "content": turn["prompt"]})
                        messages_combine.append({"role": "system", "content": turn["response"]})
        messages_combine.append({"role": "user", "content": question})

        completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
                             messages=messages_combine)

        result = {"answer": completion, "sources": source_log_docs}
        logger.debug(result)

        if conversation_id is not None:
            # Append this turn to the existing conversation.
            conversations_collection.update_one(
                {"_id": ObjectId(conversation_id)},
                {"$push": {"queries": {"prompt": question,
                                       "response": result["answer"], "sources": result['sources']}}},
            )
        else:
            # New conversation: name it with an LLM-generated <=3-word summary.
            messages_summary = [
                {"role": "assistant", "content": "Summarise following conversation in no more than 3 words, "
                                                 "respond ONLY with the summary, use the same language as the system \n\n"
                                                 "User: " + question + "\n\n" + "AI: " + result["answer"]},
                {"role": "user", "content": "Summarise following conversation in no more than 3 words, "
                                            "respond ONLY with the summary, use the same language as the system"}
            ]

            completion = llm.gen(
                model=gpt_model,
                engine=settings.AZURE_DEPLOYMENT_NAME,
                messages=messages_summary,
                max_tokens=30
            )
            conversation_id = conversations_collection.insert_one(
                {"user": "local",
                 "date": datetime.datetime.utcnow(),
                 "name": completion,
                 "queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]}
            ).inserted_id

        result["conversation_id"] = str(conversation_id)
        return result
    except Exception as e:
        # Log the full traceback through the module logger instead of the
        # original traceback.print_exc() / print(str(e)) stdout debugging.
        logger.exception("api_answer failed")
        return bad_request(500, str(e))
| |
|
| |
|
@answer.route("/api/search", methods=["POST"])
def api_search():
    """Return the source documents matching the posted question (no LLM call).

    Body fields: "question", optional "embeddings_key" and "active_docs".
    """
    data = request.get_json()
    question = data["question"]

    embeddings_key = settings.EMBEDDINGS_KEY if embeddings_key_set else data["embeddings_key"]
    vectorstore = get_vectorstore({"active_docs": data["active_docs"]}) if "active_docs" in data else ""
    docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)

    docs = docsearch.search(question, k=2)

    sources = []
    for doc in docs:
        title = doc.metadata['title'].split('/')[-1] if doc.metadata else doc.page_content
        sources.append({"title": title, "text": doc.page_content})
    return sources
| |
|
| |
|