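# Gradio demo: compare answers from a hosted chat model with and without
# retrieval-augmented generation (RAG) over an uploaded text file.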
import gradio as gr
import os
from langchain.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

API_TOKEN = os.environ.get('HUGGINGFACE_API_KEY')
HF_API_KEY = API_TOKEN
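
# Hosted Inference API endpoints for the two selectable chat models.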
llm_urls = {
    "Mistral 7B": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
    "Llama 8B": "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
}

def initialize_vector_store_retriever(file):
    """Load a text file, split it into chunks, embed the chunks, and
    return a FAISS retriever over the resulting vector store."""
    raw_documents = TextLoader(file).load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    documents = text_splitter.split_documents(raw_documents)

    # Embeddings are computed remotely by the Hugging Face Inference API;
    # HuggingFaceInferenceAPIEmbeddings expects the model name (not a raw
    # endpoint URL), so BAAI/bge-base-en-v1.5 is passed as model_name.
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=HF_API_KEY,
        model_name="BAAI/bge-base-en-v1.5",
    )
    db = FAISS.from_documents(documents, embeddings)
    retriever = db.as_retriever()
    return retriever
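
# A minimal usage sketch (the file name is hypothetical): the retriever
# returned by initialize_vector_store_retriever("notes.txt") can be queried
# directly with retriever.invoke("some question") to inspect retrieved chunks.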

def generate_llm_rag_prompt() -> ChatPromptTemplate:
    # Llama-2-style instruction template: the system prompt sits inside
    # <<SYS>> tags, followed by the retrieved context and the user prompt.
    template = "<s>[INST] <<SYS>>{system}<</SYS>>{context} {prompt} [/INST]"
    prompt_template = ChatPromptTemplate.from_template(template)
    return prompt_template
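
# With this template, the system message is interpolated into the <<SYS>>
# tags, the retrieved documents (stringified) replace {context}, and the
# user's question replaces {prompt}.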

def create_chain(retriever, llm):
    url = llm_urls[llm]
    model_endpoint = HuggingFaceEndpoint(
        endpoint_url=url,
        huggingfacehub_api_token=HF_API_KEY,
        task="text-generation",  # decoder-only instruct models are served as text-generation
        max_new_tokens=200,
    )

    if retriever is not None:
        def get_system(input):
            return "You are a helpful and honest assistant. Please respond concisely and truthfully."

        # The retrieval step fills the prompt template: retrieved documents
        # become {context}, the question passes through as {prompt}, and
        # get_system supplies {system}.
        retrieval = {"context": retriever, "prompt": RunnablePassthrough(), "system": get_system}
        chain = retrieval | generate_llm_rag_prompt() | model_endpoint
        return chain, model_endpoint
    else:
        return None, model_endpoint
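
# The assembled RAG chain runs: question -> {context, prompt, system} ->
# prompt template -> LLM endpoint, so chain.invoke("What is FAISS?") (a
# hypothetical question) returns an answer grounded in the retrieved text.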

def query(question_text, llm, session_data):
    if question_text == "":
        without_rag_text = "Query result without RAG is not available. Enter a question first."
        rag_text = "Query result with RAG is not available. Enter a question first."
        return without_rag_text, rag_text

    if len(session_data) > 0:
        retriever = session_data[0]
    else:
        retriever = None
    chain, model_endpoint = create_chain(retriever, llm)
    without_rag_text = "Query result without RAG:\n\n" + model_endpoint.invoke(question_text).strip()
    if retriever is None:
        rag_text = "Query result with RAG is not available. Load the vector store first."
    else:
        ans = chain.invoke(question_text).strip()
        # The endpoint may echo the prompt; keep only the completion after
        # the final [/INST] marker when one is present.
        if "[/INST]" in ans:
            ans = ans.split("[/INST]")[-1].strip()
        rag_text = "Query result with RAG:\n\n" + ans
    return without_rag_text, rag_text

def upload_file(file, session_data):
    # Build the vector store from the uploaded file and keep its retriever
    # in the per-session state.
    session_data = [initialize_vector_store_retriever(file)]
    return gr.File(value=file, visible=True), session_data

def initialize_vector_store(file, session_data):
    # Rebuild the retriever for a given file; mirrors upload_file but is not
    # wired to a UI event below.
    session_data = [initialize_vector_store_retriever(file)]
    return session_data

with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Retrieval Augmented Generation</h1>""")
    session_data = gr.State([])

    file_output = gr.File(visible=False)
    upload_button = gr.UploadButton("Click to Upload a Text File to the Vector Store", file_types=["text"], file_count="single")
    upload_button.upload(upload_file, [upload_button, session_data], [file_output, session_data])

    with gr.Row():
        with gr.Column(scale=4):
            question_text = gr.Textbox(show_label=False, placeholder="Ask a question", lines=2)
        with gr.Column(scale=1):
            llm_choice = gr.Radio(["Llama 8B", "Mistral 7B"], value="Mistral 7B", label="Select language model:", info="")
            query_button = gr.Button("Query")

    with gr.Row():
        with gr.Column(scale=1):
            without_rag_text = gr.Textbox(show_label=False, placeholder="Query result without using RAG", lines=15)
        with gr.Column(scale=1):
            rag_text = gr.Textbox(show_label=False, placeholder="Query result with RAG", lines=15)

    query_button.click(
        query,
        [question_text, llm_choice, session_data],
        [without_rag_text, rag_text],
    )

demo.queue().launch(share=False, inbrowser=True)