# plexi-api / rag.py
# LazyHuman10
# Initial commit for HF Space
# 3b6130d
"""
rag.py — Plexi RAG Engine
=========================
Handles everything related to the LlamaIndex vector index:
- Downloading the pre-built index from GitHub
- Loading HuggingFace sentence-transformer embeddings
- Embedding queries and retrieving top-k chunks scoped by semester + subject
- Extracting text from PDFs for full-context fallback
- Formatting retrieved chunks for the LLM system prompt
"""
import io
import logging
import os
import shutil
import tempfile
from pathlib import Path

import requests
# ---------------------------------------------------------------------------
# Optional dependencies (LlamaIndex, PyPDF2) — graceful degradation if not
# installed. Each import sets an *_AVAILABLE flag instead of failing at import.
# ---------------------------------------------------------------------------
try:
    from llama_index.core import Settings, StorageContext, load_index_from_storage
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
    LLAMA_INDEX_AVAILABLE = True
except ImportError:
    # Missing either llama_index package disables the index/embedding loaders
    # (load_index and load_embed_model check this flag).
    LLAMA_INDEX_AVAILABLE = False
try:
    import PyPDF2
    PYPDF2_AVAILABLE = True
except ImportError:
    # Without PyPDF2, read_pdf_text() returns an empty string.
    PYPDF2_AVAILABLE = False
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# GitHub "owner/repo" and branch that host the pre-built index; both can be
# overridden through environment variables of the same name.
MATERIALS_REPO = os.environ.get("MATERIALS_REPO", "KunalGupta25/plexi-materials")
MANIFEST_BRANCH = os.environ.get("MANIFEST_BRANCH", "main")

# Sentence-transformer used to embed queries. Presumably this must be the same
# model the index was built with — confirm before changing it.
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"

# Files fetched from the repo's "index/" folder into the local persist dir.
INDEX_FILES = [
    "default__vector_store.json",
    "docstore.json",
    "graph_store.json",
    "image__vector_store.json",
    "index_store.json",
]

# Number of in-scope chunks retrieve_chunks() returns when top_k is omitted.
DEFAULT_TOP_K = 5
# ---------------------------------------------------------------------------
# Index loading (called once at FastAPI startup)
# ---------------------------------------------------------------------------
def load_index():
    """
    Download the pre-built LlamaIndex from the materials repo and return a
    VectorStoreIndex ready for querying.

    Every file in INDEX_FILES is fetched from
    ``https://raw.githubusercontent.com/{MATERIALS_REPO}/{MANIFEST_BRANCH}/index``
    into a fresh temporary directory. The HuggingFace embedding model is then
    registered on the global LlamaIndex ``Settings`` and the index is loaded
    from that directory.

    Returns (index, error_msg). On success error_msg is None. On failure index
    is None and error_msg describes what went wrong; the temporary directory
    is removed in that case.
    """
    if not LLAMA_INDEX_AVAILABLE:
        return None, "llama-index-core is not installed."
    index_base_url = (
        f"https://raw.githubusercontent.com/{MATERIALS_REPO}/{MANIFEST_BRANCH}/index"
    )
    index_dir = tempfile.mkdtemp(prefix="plexi_index_")

    for filename in INDEX_FILES:
        url = f"{index_base_url}/{filename}"
        try:
            resp = requests.get(url, timeout=30)
            resp.raise_for_status()
            with open(os.path.join(index_dir, filename), "wb") as fh:
                fh.write(resp.content)
        except Exception as err:
            # Don't leave a half-populated temp directory behind.
            shutil.rmtree(index_dir, ignore_errors=True)
            return None, f"Failed to download index file '{filename}': {err}"

    try:
        embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_ID)
        Settings.embed_model = embed_model
        # Explicitly disable LlamaIndex's default LLM; this module only uses
        # the index for retrieval.
        Settings.llm = None
        storage_ctx = StorageContext.from_defaults(persist_dir=index_dir)
        index = load_index_from_storage(storage_ctx)
    except Exception as err:
        shutil.rmtree(index_dir, ignore_errors=True)
        return None, f"Failed to load index from storage: {err}"

    # NOTE(review): index_dir is kept after a successful load. With the default
    # simple stores the data is likely held in memory once loaded, so it could
    # probably be deleted here too — confirm before changing.
    return index, None
def load_embed_model():
    """
    Build a fresh HuggingFace embedding model (used by health checks).

    Returns None when the llama_index packages could not be imported.
    """
    if LLAMA_INDEX_AVAILABLE:
        return HuggingFaceEmbedding(model_name=EMBED_MODEL_ID)
    return None
# ---------------------------------------------------------------------------
# Retrieval
# ---------------------------------------------------------------------------
def _matches_scope(node, semester: str, subject: str) -> bool:
"""Return True when a retrieved node belongs to the active semester + subject."""
metadata = getattr(node.node, "metadata", {}) or {}
return (
metadata.get("semester") == semester
and metadata.get("subject") == subject
)
def retrieve_chunks(
    index,
    query: str,
    semester: str,
    subject: str,
    top_k: int = DEFAULT_TOP_K,
) -> list[dict]:
    """
    Embed the query and return up to ``top_k`` chunks scoped to the given
    semester + subject.

    Scoping is a post-filter. The retriever is asked for
    ``max(top_k * 5, 10)`` candidates from the whole index. Only candidates
    whose node metadata matches both ``semester`` and ``subject`` are kept,
    in the order the retriever returned them. In-scope chunks outside that
    candidate pool are not found, so fewer than ``top_k`` results may be
    returned.

    Returns a list of dicts:
        { text, score, filename, subject }
    ``score`` is rounded to 4 decimals, or None when the retriever supplied
    no score. Returns [] when ``index`` is None, ``top_k`` is not positive,
    or retrieval raises (the error is logged with its traceback).
    """
    if index is None:
        return []
    try:
        # Nothing requested. This also stops a negative top_k from acting as
        # a negative slice bound below.
        if top_k <= 0:
            return []
        # Fetch more than needed so there is room to filter by scope.
        retriever = index.as_retriever(similarity_top_k=max(top_k * 5, 10))
        candidates = retriever.retrieve(query)
        in_scope = [n for n in candidates if _matches_scope(n, semester, subject)]

        chunks = []
        for scored in in_scope[:top_k]:
            metadata = getattr(scored.node, "metadata", {}) or {}
            chunks.append(
                {
                    "text": scored.node.get_content(),
                    "score": (
                        round(float(scored.score), 4)
                        if scored.score is not None
                        else None
                    ),
                    "filename": metadata.get("filename"),
                    "subject": metadata.get("subject"),
                }
            )
        return chunks
    except Exception:
        # Best effort: a retrieval failure yields no context rather than an error.
        logging.getLogger(__name__).exception("Retrieval error")
        return []
# ---------------------------------------------------------------------------
# Context formatting (for system prompt injection)
# ---------------------------------------------------------------------------
def format_context(chunks: list[dict]) -> str:
    """
    Format retrieved chunks as a numbered block for the LLM system prompt.

    Each chunk is rendered as a header line
    ``--- Chunk i | <source> [relevance: <score>] ---`` followed by its text.
    <source> is the chunk's filename, else its subject, else
    "Unknown source". The relevance tag is omitted only when the chunk has no
    score (missing or None). Consecutive chunks are separated by a blank line.

    Returns a placeholder sentence when ``chunks`` is empty or None.
    Raises KeyError if a chunk has no "text" key.
    """
    if not chunks:
        return "(No relevant context retrieved for this query.)"
    parts = []
    for i, chunk in enumerate(chunks, start=1):
        score = chunk.get("score")
        # Compare with None explicitly so a legitimate score of 0.0 is still shown.
        score_info = f" [relevance: {score}]" if score is not None else ""
        source = chunk.get("filename") or chunk.get("subject") or "Unknown source"
        parts.append(f"--- Chunk {i} | {source}{score_info} ---\n{chunk['text']}\n")
    return "\n".join(parts)
# ---------------------------------------------------------------------------
# PDF text extraction (used for full-context fallback loading)
# ---------------------------------------------------------------------------
def read_pdf_text(pdf_bytes: bytes) -> str:
    """
    Extract plain text from PDF bytes, page by page.

    Page texts are joined with newlines. Pages whose extraction raises or
    yields no text are skipped.

    Fallbacks:
    - PyPDF2 not installed: returns "".
    - The PDF cannot be parsed or iterated: returns the text of the pages
      extracted before the failure, if any. Otherwise it returns ``pdf_bytes``
      decoded as UTF-8 with undecodable bytes dropped, so plain-text input
      still yields content, or "" when ``pdf_bytes`` is empty.
    """
    if not PYPDF2_AVAILABLE:
        return ""
    text_parts = []
    try:
        reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
        for page in reader.pages:
            try:
                page_text = page.extract_text()
                if page_text:
                    # Round-trip through UTF-16 so lone surrogates, which some
                    # PDFs produce, are dropped (they can't be encoded later).
                    filtered = page_text.encode("utf-16", "surrogatepass").decode(
                        "utf-16", "ignore"
                    )
                    text_parts.append(filtered)
            except Exception:
                # Best effort: one unreadable page should not abort the document.
                continue
    except Exception:
        # Prefer real page text already recovered over decoding raw PDF bytes.
        if text_parts:
            return "\n".join(text_parts)
        return pdf_bytes.decode("utf-8", errors="ignore") if pdf_bytes else ""
    return "\n".join(text_parts)