import gradio as gr
import numpy as np
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# =====================================
# 1. LOAD DOCUMENTS
# =====================================
def load_documents(path="documents.txt"):
    """Read one document per line, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        docs = f.readlines()
    return [doc.strip() for doc in docs if doc.strip()]

documents = load_documents()

# =====================================
# 2. LOAD EMBEDDING MODEL (HF Open Source)
# =====================================
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
doc_embeddings = embedding_model.encode(documents, convert_to_numpy=True)
dimension = doc_embeddings.shape[1]

# =====================================
# 3. BUILD FAISS INDEX
# =====================================
# FAISS expects float32 vectors; encode() returns float32 by default,
# but the explicit cast guards against other dtypes.
index = faiss.IndexFlatL2(dimension)
index.add(doc_embeddings.astype(np.float32))

# =====================================
# 4. LOAD OPEN-SOURCE LLM (HF)
# =====================================
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # change if needed

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)

generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# =====================================
# 5. RETRIEVAL FUNCTION
# =====================================
def retrieve(query, top_k=3):
    """Embed the query and return the top_k nearest documents by L2 distance."""
    # Clamp top_k: FAISS returns -1 indices when asked for more
    # neighbors than the index contains.
    top_k = min(top_k, len(documents))
    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_embedding.astype(np.float32), top_k)
    return [documents[i] for i in indices[0]]

# =====================================
# 6. GENERIC LLM CALL
# =====================================
def call_llm(prompt):
    response = generator(
        prompt,
        max_new_tokens=300,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        return_full_text=False,  # return only the completion, not prompt + completion
    )
    return response[0]["generated_text"].strip()

# =====================================
# 7. RAG PIPELINE
# =====================================
def rag_pipeline(query):
    """Retrieve context, build a grounded prompt, and generate an answer."""
    retrieved_docs = retrieve(query)
    context = "\n".join(retrieved_docs)

    prompt = f"""You are a helpful AI assistant.
Answer ONLY from the provided context.

Context:
{context}

Question: {query}

Answer:"""

    return call_llm(prompt)

# =====================================
# 8. GRADIO UI
# =====================================
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Open Source RAG (HF Only)")
    query = gr.Textbox(label="Ask your question")
    output = gr.Textbox(label="Answer")
    query.submit(rag_pipeline, query, output)

demo.launch()
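
# -------------------------------------------------------------------
# Usage notes (a sketch of assumptions, not part of the pipeline code):
#
# documents.txt is assumed to contain one plain-text document per line,
# for example (illustrative content only):
#
#   The Eiffel Tower is located in Paris, France.
#   FAISS is a library for efficient similarity search over dense vectors.
#
# Expected dependencies (faiss-cpu is the common pip name; accelerate is
# required because the model is loaded with device_map="auto"):
#
#   pip install gradio numpy faiss-cpu torch sentence-transformers transformers accelerate
#
# Then run this script with Python; Gradio prints a local URL to open.
# -------------------------------------------------------------------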