"""
rag.py — Plexi RAG Engine
=========================
Handles everything related to the LlamaIndex vector index:
  - Downloading the pre-built index from GitHub
  - Loading HuggingFace sentence-transformer embeddings
  - Embedding queries and retrieving top-k chunks scoped by semester + subject
  - Extracting text from PDFs for full-context fallback
  - Formatting retrieved chunks for the LLM system prompt
"""

import io
import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import Any, Optional

# ---------------------------------------------------------------------------
# Optional third-party dependencies — graceful degradation if not installed
# ---------------------------------------------------------------------------
# requests is only used by load_index(); it is guarded like the other optional
# dependencies so the formatting / PDF helpers stay importable without it.
try:
    import requests

    REQUESTS_AVAILABLE = True
except ImportError:
    requests = None
    REQUESTS_AVAILABLE = False

try:
    from llama_index.core import Settings, StorageContext, load_index_from_storage
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding

    LLAMA_INDEX_AVAILABLE = True
except ImportError:
    LLAMA_INDEX_AVAILABLE = False

try:
    import PyPDF2

    PYPDF2_AVAILABLE = True
except ImportError:
    PYPDF2_AVAILABLE = False

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
MATERIALS_REPO = os.getenv("MATERIALS_REPO", "KunalGupta25/plexi-materials")
MANIFEST_BRANCH = os.getenv("MANIFEST_BRANCH", "main")
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"

# Files fetched from the "index" directory of the materials repo.
INDEX_FILES = [
    "default__vector_store.json",
    "docstore.json",
    "graph_store.json",
    "image__vector_store.json",
    "index_store.json",
]

DEFAULT_TOP_K = 5


# ---------------------------------------------------------------------------
# Index loading (called once at FastAPI startup)
# ---------------------------------------------------------------------------
def load_index() -> tuple[Any, Optional[str]]:
    """
    Download the pre-built LlamaIndex from the materials repo and load it.

    Every file in INDEX_FILES is downloaded into a fresh temporary directory,
    the embedding model is registered on the global llama-index ``Settings``
    (with the LLM explicitly set to None), and the index is loaded from that
    directory.

    Returns:
        ``(index, error_msg)``. On success ``error_msg`` is None; on any
        failure ``index`` is None and ``error_msg`` describes what went wrong.
        Failures are reported through the return value, not raised.
    """
    if not LLAMA_INDEX_AVAILABLE:
        return None, "llama-index-core is not installed."
    if not REQUESTS_AVAILABLE:
        return None, "requests is not installed."

    index_base_url = (
        f"https://raw.githubusercontent.com/{MATERIALS_REPO}/{MANIFEST_BRANCH}/index"
    )
    index_dir = Path(tempfile.mkdtemp(prefix="plexi_index_"))

    for filename in INDEX_FILES:
        url = f"{index_base_url}/{filename}"
        try:
            resp = requests.get(url, timeout=30)
            resp.raise_for_status()
            (index_dir / filename).write_bytes(resp.content)
        except Exception as err:
            # Do not leave a half-populated temp directory behind.
            shutil.rmtree(index_dir, ignore_errors=True)
            return None, f"Failed to download index file '{filename}': {err}"

    try:
        embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_ID)
        Settings.embed_model = embed_model
        Settings.llm = None
        storage_ctx = StorageContext.from_defaults(persist_dir=str(index_dir))
        index = load_index_from_storage(storage_ctx)
    except Exception as err:
        shutil.rmtree(index_dir, ignore_errors=True)
        return None, f"Failed to load index from storage: {err}"

    # NOTE(review): the directory is kept on success because it is not visible
    # here whether the loaded stores still read from it — confirm before
    # deleting it after a successful load.
    return index, None


def load_embed_model():
    """
    Load and return the HuggingFace embedding model (for health checks).

    Returns None when llama-index is not installed. Each call constructs a new
    ``HuggingFaceEmbedding`` for EMBED_MODEL_ID.
    """
    if not LLAMA_INDEX_AVAILABLE:
        return None
    return HuggingFaceEmbedding(model_name=EMBED_MODEL_ID)


# ---------------------------------------------------------------------------
# Retrieval
# ---------------------------------------------------------------------------
def _node_metadata(node) -> dict:
    """Return the metadata dict of a retrieved node, or {} if missing/empty."""
    return getattr(node.node, "metadata", {}) or {}


def _matches_scope(node, semester: str, subject: str) -> bool:
    """Return True when a retrieved node belongs to the active semester + subject."""
    metadata = _node_metadata(node)
    return (
        metadata.get("semester") == semester
        and metadata.get("subject") == subject
    )


def retrieve_chunks(
    index,
    query: str,
    semester: str,
    subject: str,
    top_k: int = DEFAULT_TOP_K,
) -> list[dict]:
    """
    Retrieve the top-k chunks for ``query`` scoped to a semester + subject.

    The retriever is asked for more candidates than ``top_k`` because the
    scope filter is applied after retrieval and may discard most of them.

    Args:
        index: Object exposing ``as_retriever(similarity_top_k=...)``; None
            yields an empty result.
        query: Query text passed to the retriever.
        semester: Required value of the node's ``semester`` metadata.
        subject: Required value of the node's ``subject`` metadata.
        top_k: Maximum number of chunks to return; non-positive values yield
            an empty list.

    Returns:
        A list of dicts ``{text, score, filename, subject}`` in retriever
        order. ``score`` is rounded to 4 decimals, or None when the retriever
        gave no score. Any error during retrieval is logged and results in [].
    """
    if index is None:
        return []

    try:
        # A negative top_k would otherwise slice off the tail of the results.
        if top_k <= 0:
            return []

        # Fetch more than needed so we have room to filter by scope
        retriever = index.as_retriever(similarity_top_k=max(top_k * 5, 10))
        nodes = retriever.retrieve(query)
        scoped = [n for n in nodes if _matches_scope(n, semester, subject)]

        chunks = []
        for node in scoped[:top_k]:
            metadata = _node_metadata(node)
            score = node.score
            chunks.append(
                {
                    "text": node.node.get_content(),
                    "score": round(float(score), 4) if score is not None else None,
                    "filename": metadata.get("filename"),
                    "subject": metadata.get("subject"),
                }
            )
        return chunks
    except Exception as err:
        logger.exception("Retrieval error: %s", err)
        return []


# ---------------------------------------------------------------------------
# Context formatting (for system prompt injection)
# ---------------------------------------------------------------------------
def format_context(chunks: list[dict]) -> str:
    """
    Format retrieved chunks as a numbered block for the LLM system prompt.

    Each chunk becomes a ``--- Chunk N | source [relevance: score] ---`` header
    followed by its text. The source is the chunk's filename, falling back to
    its subject, then to "Unknown source". The relevance tag is omitted only
    when the chunk has no score; a score of 0.0 is still shown.

    Returns a fixed placeholder sentence when ``chunks`` is empty.
    """
    if not chunks:
        return "(No relevant context retrieved for this query.)"

    parts = []
    for i, chunk in enumerate(chunks, start=1):
        score = chunk.get("score")
        # Compare against None: 0.0 is a real score, not a missing one.
        score_info = f" [relevance: {score}]" if score is not None else ""
        source = chunk.get("filename") or chunk.get("subject") or "Unknown source"
        parts.append(
            f"--- Chunk {i} | {source}{score_info} ---\n{chunk['text']}\n"
        )
    return "\n".join(parts)


# ---------------------------------------------------------------------------
# PDF text extraction (used for full-context fallback loading)
# ---------------------------------------------------------------------------
def read_pdf_text(pdf_bytes: bytes) -> str:
    """
    Extract plain text from PDF bytes.

    Pages whose text extraction fails are skipped (with a logged warning) and
    the text of the remaining pages is joined with newlines.

    Returns:
        - "" when PyPDF2 is not installed;
        - the extracted text on success (possibly "" if no page had text);
        - when the document cannot be opened or iterated at all, the raw bytes
          decoded as UTF-8 with undecodable bytes dropped ("" for empty
          input). NOTE(review): presumably this lets non-PDF text payloads
          pass through — confirm against callers.
    """
    if not PYPDF2_AVAILABLE:
        return ""

    text_parts = []
    try:
        reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
        for page_number, page in enumerate(reader.pages, start=1):
            try:
                page_text = page.extract_text()
                if page_text:
                    # Sanitise surrogate pairs that can appear in some PDFs:
                    # the UTF-16 round-trip drops code units that do not
                    # decode.
                    filtered = page_text.encode("utf-16", "surrogatepass").decode(
                        "utf-16", "ignore"
                    )
                    text_parts.append(filtered)
            except Exception as err:
                # Best effort: one bad page must not discard the others.
                logger.warning(
                    "Skipping PDF page %d: text extraction failed: %s",
                    page_number,
                    err,
                )
                continue
    except Exception:
        return pdf_bytes.decode("utf-8", errors="ignore") if pdf_bytes else ""

    return "\n".join(text_parts)