import json
import os
import subprocess
import sys
from typing import Any, Dict, List, Literal, Tuple, cast

import streamlit as st

from src.qa_chain import make_conversational_chain
from src.vectorstore import get_retriever

# Unconditionally import the KG modules; let import errors propagate so
# failures are visible instead of silently disabling KG features.
from src.kg.store import KGStore
from src.kg.retriever import KGRetriever

# Literal alias for the retrieval strategies get_retriever accepts.
RetrievalMode = Literal["mmr", "similarity", "hybrid"]


def run_ingest_cli(data_dir: str, persist_dir: str) -> str:
    """Run the ingestion module to rebuild the vectorstore.

    Runs the ingest CLI as a subprocess and returns its stdout on success.
    On failure, raises subprocess.CalledProcessError with the captured
    stdout/stderr so that callers (for example the Streamlit UI) can display
    a helpful error message.
    """
    cmd = [
        sys.executable,
        "-m",
        "src.ingest",
        "--data-dir",
        data_dir,
        "--persist-dir",
        persist_dir,
    ]
    try:
        # Time out rather than hang indefinitely; 600s (10 minutes) is
        # generous even for large ingests.
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    except subprocess.TimeoutExpired as te:
        # Re-raise as CalledProcessError (returncode 124, the shell timeout
        # convention) so callers handle a single exception type; preserve any
        # partial output captured before the timeout.
        raise subprocess.CalledProcessError(
            returncode=124,
            cmd=cmd,
            output=te.output or "",
            stderr=f"Ingest process timed out after {te.timeout} seconds",
        ) from te
    if completed.returncode != 0:
        # Raise with the captured output so it is easy to present to the user.
        raise subprocess.CalledProcessError(
            returncode=completed.returncode,
            cmd=cmd,
            output=completed.stdout,
            stderr=completed.stderr,
        )
    return completed.stdout
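

# Illustrative only: a minimal sketch of how the Streamlit UI might surface
# ingest failures. The button label and paths are assumptions, not part of
# this module's API.
def _example_rebuild_from_ui() -> None:
    if st.button("Rebuild index"):
        try:
            logs = run_ingest_cli(data_dir="data", persist_dir="chroma_db")
            st.success("Ingest finished.")
            st.code(logs)
        except subprocess.CalledProcessError as exc:
            # exc.stderr carries the timeout message or the subprocess's own errors.
            st.error(f"Ingest failed (exit {exc.returncode}): {exc.stderr}")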


def _load_chunks_index(persist_dir: str) -> Dict[str, Dict]:
    """Load the chunk-id -> chunk-metadata mapping from chunks_index.json.

    Returns an empty mapping if the index file is missing or unreadable.
    """
    idx_path = os.path.join(persist_dir, "chunks_index.json")
    if not os.path.exists(idx_path):
        return {}
    try:
        with open(idx_path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except Exception:
        return {}
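

# Note: the exact schema of chunks_index.json is an assumption here; the code
# in answer_with_kg below relies only on a "text" field per chunk, e.g.
# (illustrative shape, with "source" as a hypothetical extra field):
#
#   {"chunk-0001": {"text": "First chunk of the doc...", "source": "docs/a.md"}}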


def answer_with_kg(
    chain,
    question: str,
    chat_history: List[Tuple[str, str]],
    persist_dir: str,
    kg_hops: int = 1,
    kg_context_max_chars: int = 1000,
) -> Any:
    """Augment the question with KG context (if available) and run the chain.

    This is a low-risk integration: we build a short textual summary from the
    KG (node labels plus short chunk snippets looked up in chunks_index.json)
    and prepend it to the question. The chain's retriever still runs; the KG
    context is additional grounding, not a replacement.
    """
    kg_text_parts: List[str] = []
    # Map chunk ids referenced by the KG back to their text.
    chunks_index = _load_chunks_index(persist_dir)
    kg_path = os.path.join(persist_dir, "kg_store.ttl")
    try:
        kg = KGStore(path=kg_path)
        retr = KGRetriever(kg)
        chunk_ids, summaries = retr.get_context_for_question(question, hops=kg_hops)
        if summaries:
            kg_text_parts.append("KG entities: " + ", ".join(summaries))
        # Add short snippets from the chunks the KG points at.
        for cid in chunk_ids:
            info = chunks_index.get(cid)
            if info:
                txt = info.get("text", "")
                if txt:
                    # Slicing already caps the length; no min() is needed.
                    snippet = txt.strip().replace("\n", " ")[:kg_context_max_chars]
                    kg_text_parts.append(f"[KG chunk {cid}]: {snippet}")
    except Exception:
        # KG augmentation is best-effort: if loading or querying the KG
        # fails, fall back to answering from the vectorstore alone.
        kg_text_parts = []
    kg_context = "\n\n".join(kg_text_parts)
    if kg_context:
        augmented_question = f"KG CONTEXT:\n{kg_context}\n\nUser Question:\n{question}"
    else:
        augmented_question = question
    return chain({"question": augmented_question, "chat_history": chat_history})
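

# Illustrative only: how a caller might invoke answer_with_kg. The dict result
# and "answer" key follow the usual LangChain conversational-chain convention,
# which is an assumption about what make_conversational_chain returns.
def _example_kg_answer(chain: Any) -> None:
    result = answer_with_kg(
        chain,
        question="What does the ingest pipeline do?",
        chat_history=[],
        persist_dir="chroma_db",  # hypothetical path
        kg_hops=1,
    )
    st.write(result.get("answer", result))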


def build_or_load_retriever_cached(
    data_dir: str,
    persist_dir: str,
    top_k: int,
    retrieval_mode: str,
) -> Any:
    """Load a retriever from the persisted vectorstore or build a new one.

    If loading fails (usually because the vectorstore doesn't exist yet),
    this function triggers ingestion and retries the load once.

    Args:
        data_dir: Directory containing input documents.
        persist_dir: Directory where the Chroma vectorstore is stored.
        top_k: Number of chunks to retrieve.
        retrieval_mode: Retrieval strategy ("mmr", "similarity", or "hybrid").

    Returns:
        An initialized retriever instance.
    """
    # Cast retrieval_mode to the Literal type to satisfy type checkers.
    mode = cast(RetrievalMode, retrieval_mode)
    try:
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=mode,
        )
    except Exception:
        run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=mode,
        )
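

# Despite the "_cached" suffix, no cache decorator appears in this module.
# A sketch of one option, assuming caching is meant to happen at the UI
# layer via Streamlit's resource cache (an assumption, not confirmed here):
#
#   cached_build = st.cache_resource(build_or_load_retriever_cached)
#   retriever = cached_build("data", "chroma_db", 4, "mmr")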


def get_chain_cached(
    model_name: str,
    top_k: int,
    retrieval_mode: str,
    data_dir: str,
    persist_dir: str,
) -> Any:
    """Create or load a cached conversational QA chain.

    Args:
        model_name: The OpenAI model to use (e.g. gpt-3.5-turbo or gpt-4).
        top_k: Number of chunks to retrieve.
        retrieval_mode: Retrieval mode for the retriever.
        data_dir: Path to the data directory.
        persist_dir: Path to the vectorstore directory.

    Returns:
        A fully configured conversational QA chain.
    """
    retriever = build_or_load_retriever_cached(
        data_dir=data_dir,
        persist_dir=persist_dir,
        top_k=top_k,
        retrieval_mode=retrieval_mode,
    )
    return make_conversational_chain(retriever, model_name=model_name)
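

# Illustrative only: a minimal end-to-end sketch tying the helpers together.
# The model name, paths, and question are assumptions for demonstration.
def _example_chat_turn() -> None:
    chain = get_chain_cached(
        model_name="gpt-3.5-turbo",
        top_k=4,
        retrieval_mode="mmr",
        data_dir="data",          # hypothetical paths
        persist_dir="chroma_db",
    )
    result = answer_with_kg(chain, "Summarize the corpus.", [], "chroma_db")
    st.write(result)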