| | """ |
| | SPARKNET RAG API Routes |
| | Endpoints for RAG queries, search, and indexing management. |
| | """ |
| |
|
| | from fastapi import APIRouter, HTTPException, Query, Depends |
| | from fastapi.responses import StreamingResponse |
| | from typing import List, Optional |
| | from pathlib import Path |
| | from datetime import datetime |
| | import time |
| | import json |
| | import sys |
| | import asyncio |
| |
|
| | |
# Make the repository root importable so `api.*` and `src.*` packages resolve
# regardless of where the server process was launched from.
# NOTE(review): mutating sys.path at import time is fragile — prefer proper
# packaging / entry points if this grows.
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
| |
|
| | from api.schemas import ( |
| | QueryRequest, RAGResponse, Citation, QueryPlan, QueryIntentType, |
| | SearchRequest, SearchResponse, SearchResult, |
| | StoreStatus, CollectionInfo |
| | ) |
| | from loguru import logger |
| |
|
# FastAPI router for the RAG endpoints; mounted by the application elsewhere.
router = APIRouter()

# In-process query cache: cache_key -> {"response": RAGResponse, "timestamp": float}.
# NOTE(review): per-process only — with multiple workers each process has its
# own cache; confirm that is acceptable for this deployment.
_query_cache = {}
# Cached responses older than this many seconds are treated as expired.
CACHE_TTL_SECONDS = 3600
| |
|
| |
|
def get_cache_key(query: str, doc_ids: Optional[List[str]]) -> str:
    """Build a deterministic cache key from a query and optional doc-id scope."""
    import hashlib

    # Sort the doc ids so the same filter set always yields the same key;
    # no filter collapses to the shared "all" scope.
    scope = ",".join(sorted(doc_ids)) if doc_ids else "all"
    return hashlib.md5(f"{query}:{scope}".encode()).hexdigest()
| |
|
| |
|
def get_cached_response(cache_key: str) -> Optional[RAGResponse]:
    """Return the cached response for *cache_key*, or None if absent/expired.

    Expired entries are evicted lazily, on access. A deep copy of the stored
    response is returned: callers mutate the result (e.g. overwriting
    ``latency_ms`` per request), and handing out the shared cached object
    would let one request's mutation leak into another's response.
    """
    import copy

    cached = _query_cache.get(cache_key)
    if cached is None:
        return None
    if time.time() - cached["timestamp"] >= CACHE_TTL_SECONDS:
        # Stale: drop it so the next query repopulates the entry.
        del _query_cache[cache_key]
        return None
    response = copy.deepcopy(cached["response"])
    response.from_cache = True
    return response
| |
|
| |
|
def cache_response(cache_key: str, response: RAGResponse):
    """Store *response* under *cache_key*, bounding the cache at 1000 entries."""
    _query_cache[cache_key] = {
        "response": response,
        "timestamp": time.time(),
    }

    # Simple size bound: once past 1000 entries, evict the stalest one.
    if len(_query_cache) > 1000:
        stalest = min(_query_cache.items(), key=lambda kv: kv[1]["timestamp"])[0]
        _query_cache.pop(stalest)
| |
|
| |
|
def _get_rag_system():
    """Get or initialize the RAG system.

    The constructed pipeline is memoized on the function object so repeated
    requests reuse a single instance instead of rebuilding the whole agent
    stack on every call (the previous behavior, despite this docstring).
    A failed initialization is NOT cached, so later calls retry once the
    underlying service is available.

    Returns:
        An ``AgenticRAG`` instance, or ``None`` if initialization failed.
    """
    cached = getattr(_get_rag_system, "_instance", None)
    if cached is not None:
        return cached
    try:
        from src.rag.agentic.orchestrator import AgenticRAG, RAGConfig

        config = RAGConfig(
            model_name="llama3.2:latest",
            max_revision_attempts=2,
            retrieval_top_k=10,
            final_top_k=5,
            min_confidence=0.5,
        )
        rag = AgenticRAG(config)
        # Cache only on success so transient failures can recover.
        # NOTE(review): assumes AgenticRAG is safe to reuse across queries —
        # confirm it keeps no per-query state.
        _get_rag_system._instance = rag
        return rag
    except Exception as e:
        logger.error(f"Failed to initialize RAG system: {e}")
        return None
| |
|
| |
|
@router.post("/query", response_model=RAGResponse)
async def query_documents(request: QueryRequest):
    """
    Execute a RAG query across indexed documents.

    The query goes through the 5-agent pipeline:
    1. QueryPlanner - Intent classification and query decomposition
    2. Retriever - Hybrid dense+sparse search
    3. Reranker - Cross-encoder reranking with MMR
    4. Synthesizer - Answer generation with citations
    5. Critic - Hallucination detection and validation

    Raises:
        HTTPException: 503 if the RAG system cannot be initialized,
            500 for any other failure during query execution.
    """
    start_time = time.time()

    # Serve from the in-process cache when the client allows it; the cached
    # response's latency field is overwritten with this (cheap) request's.
    if request.use_cache:
        cache_key = get_cache_key(request.query, request.doc_ids)
        cached = get_cached_response(cache_key)
        if cached:
            cached.latency_ms = (time.time() - start_time) * 1000
            return cached

    try:
        rag = _get_rag_system()
        if not rag:
            raise HTTPException(status_code=503, detail="RAG system not available")

        # Restrict retrieval to the requested documents, if any.
        # NOTE(review): "$in" looks like a Chroma-style metadata where-clause —
        # confirm against the retriever's filter format.
        filters = {}
        if request.doc_ids:
            filters["document_id"] = {"$in": request.doc_ids}

        logger.info(f"Executing RAG query: {request.query[:50]}...")

        result = rag.query(
            query=request.query,
            filters=filters if filters else None,
            top_k=request.top_k,
        )

        # Convert raw source dicts into Citation models; chunk text is
        # truncated to 300 chars for the response payload.
        citations = []
        for i, source in enumerate(result.get("sources", [])):
            citations.append(Citation(
                citation_id=i + 1,
                doc_id=source.get("document_id", "unknown"),
                document_name=source.get("filename", source.get("document_id", "unknown")),
                chunk_id=source.get("chunk_id", f"chunk_{i}"),
                chunk_text=source.get("text", "")[:300],
                page_num=source.get("page_num"),
                relevance_score=source.get("relevance_score", source.get("score", 0.0)),
                bbox=source.get("bbox"),
            ))

        # Surface the planner's output (intent, sub-queries, strategy) when present.
        query_plan = None
        if "plan" in result:
            plan = result["plan"]
            query_plan = QueryPlan(
                intent=QueryIntentType(plan.get("intent", "factoid").lower()),
                sub_queries=plan.get("sub_queries", []),
                keywords=plan.get("keywords", []),
                strategy=plan.get("strategy", "hybrid"),
            )

        response = RAGResponse(
            query=request.query,
            answer=result.get("answer", "I could not find an answer to your question."),
            confidence=result.get("confidence", 0.0),
            citations=citations,
            source_count=len(citations),
            query_plan=query_plan,
            from_cache=False,
            validation=result.get("validation"),
            latency_ms=(time.time() - start_time) * 1000,
            revision_count=result.get("revision_count", 0),
        )

        # Only cache answers meeting the caller's own confidence floor, so a
        # low-quality answer isn't replayed to later identical queries.
        if request.use_cache and response.confidence >= request.min_confidence:
            cache_key = get_cache_key(request.query, request.doc_ids)
            cache_response(cache_key, response)

        return response

    except HTTPException:
        # Re-raise intentional HTTP errors (e.g. the 503 above) unchanged.
        raise
    except Exception as e:
        logger.error(f"RAG query failed: {e}")
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
| |
|
| |
|
@router.post("/query/stream")
async def query_documents_stream(request: QueryRequest):
    """
    Stream RAG response for real-time updates.

    Returns Server-Sent Events (SSE) with partial responses: one `data:` JSON
    event per pipeline stage, the answer text in chunks, then a final
    "complete" event carrying confidence, citations, and validation.
    """
    async def generate():
        try:
            rag = _get_rag_system()
            if not rag:
                # Headers are already committed for a stream, so errors are
                # reported in-band as an SSE event rather than an HTTP status.
                yield f"data: {json.dumps({'error': 'RAG system not available'})}\n\n"
                return

            yield f"data: {json.dumps({'stage': 'planning', 'message': 'Analyzing query...'})}\n\n"
            # Brief pause so the client can render the stage update.
            await asyncio.sleep(0.1)

            # Restrict retrieval to the requested documents, if any
            # ("$in" metadata where-clause, Chroma-style).
            filters = {}
            if request.doc_ids:
                filters["document_id"] = {"$in": request.doc_ids}

            yield f"data: {json.dumps({'stage': 'retrieving', 'message': 'Searching documents...'})}\n\n"

            # NOTE(review): rag.query() is a synchronous call inside an async
            # generator — it blocks the event loop for its full duration;
            # consider run_in_executor / to_thread.
            result = rag.query(
                query=request.query,
                filters=filters if filters else None,
                top_k=request.top_k,
            )

            yield f"data: {json.dumps({'stage': 'sources', 'count': len(result.get('sources', []))})}\n\n"

            yield f"data: {json.dumps({'stage': 'synthesizing', 'message': 'Generating answer...'})}\n\n"

            # The full answer already exists at this point; it is re-chunked
            # into 50-char pieces with a small delay to simulate token streaming.
            answer = result.get("answer", "")
            chunk_size = 50
            for i in range(0, len(answer), chunk_size):
                chunk = answer[i:i+chunk_size]
                yield f"data: {json.dumps({'stage': 'answer', 'chunk': chunk})}\n\n"
                await asyncio.sleep(0.02)

            # Compact citation dicts (text truncated to 200 chars) for the
            # final event.
            citations = []
            for i, source in enumerate(result.get("sources", [])):
                citations.append({
                    "citation_id": i + 1,
                    "doc_id": source.get("document_id", "unknown"),
                    "chunk_text": source.get("text", "")[:200],
                    "relevance_score": source.get("score", 0.0),
                })

            final = {
                "stage": "complete",
                "confidence": result.get("confidence", 0.0),
                "citations": citations,
                "validation": result.get("validation"),
            }
            yield f"data: {json.dumps(final)}\n\n"

        except Exception as e:
            logger.error(f"Streaming query failed: {e}")
            # In-band error event; HTTP status remains 200 mid-stream.
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            # Standard SSE headers: disable proxy caching, hold the socket open.
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
| |
|
| |
|
@router.post("/search", response_model=SearchResponse)
async def search_documents(request: SearchRequest):
    """
    Semantic search across indexed documents.

    Returns matching chunks without answer synthesis. Any internal failure
    degrades to an empty (200) result set rather than an error status.
    """
    start_time = time.time()

    try:
        from src.rag.store import get_vector_store
        from src.rag.embeddings import get_embedding_model

        store = get_vector_store()
        embeddings = get_embedding_model()

        # Embed the query once; the store performs the similarity search.
        query_embedding = embeddings.embed_query(request.query)

        # Optional restriction to specific documents ("$in" metadata
        # where-clause, Chroma-style).
        where_filter = None
        if request.doc_ids:
            where_filter = {"document_id": {"$in": request.doc_ids}}

        results = store.similarity_search_with_score(
            query_embedding=query_embedding,
            k=request.top_k,
            where=where_filter,
        )

        # Keep only hits at or above the caller's score floor.
        search_results = []
        for doc, score in results:
            if score >= request.min_score:
                search_results.append(SearchResult(
                    chunk_id=doc.metadata.get("chunk_id", "unknown"),
                    doc_id=doc.metadata.get("document_id", "unknown"),
                    document_name=doc.metadata.get("filename", "unknown"),
                    text=doc.page_content,
                    score=score,
                    page_num=doc.metadata.get("page_num"),
                    chunk_type=doc.metadata.get("chunk_type", "text"),
                ))

        return SearchResponse(
            query=request.query,
            total_results=len(search_results),
            results=search_results,
            latency_ms=(time.time() - start_time) * 1000,
        )

    except Exception as e:
        logger.error(f"Search failed: {e}")
        # Best-effort degradation: failures yield an empty success response.
        # NOTE(review): this masks store/embedding outages from clients —
        # confirm an empty 200 is the intended contract (vs. a 5xx).
        return SearchResponse(
            query=request.query,
            total_results=0,
            results=[],
            latency_ms=(time.time() - start_time) * 1000,
        )
| |
|
| |
|
@router.get("/store/status", response_model=StoreStatus)
async def get_store_status():
    """Get vector store status and statistics.

    Never raises: failures are reported as a StoreStatus with status="error".
    """
    try:
        from src.rag.store import get_vector_store

        store = get_vector_store()

        # Reaches into the store's private collection handle for raw counts.
        # NOTE(review): relies on a Chroma-like private `_collection` API —
        # confirm this is stable across store implementations.
        collection = store._collection
        count = collection.count()

        # Count distinct documents by scanning every chunk's metadata.
        # NOTE(review): O(total chunks) time/memory — fine for small stores,
        # may need an aggregate query for large ones.
        all_metadata = collection.get(include=["metadatas"])
        doc_ids = set()
        for meta in all_metadata.get("metadatas", []):
            if meta and "document_id" in meta:
                doc_ids.add(meta["document_id"])

        collections = [CollectionInfo(
            name=store.collection_name,
            document_count=len(doc_ids),
            chunk_count=count,
            # 1024 is a fallback when the store doesn't expose its dimension.
            embedding_dimension=store.embedding_dimension if hasattr(store, 'embedding_dimension') else 1024,
        )]

        return StoreStatus(
            status="healthy",
            collections=collections,
            total_documents=len(doc_ids),
            total_chunks=count,
        )

    except Exception as e:
        logger.error(f"Store status check failed: {e}")
        # Degrade to an explicit "error" payload instead of a 5xx response.
        return StoreStatus(
            status="error",
            collections=[],
            total_documents=0,
            total_chunks=0,
        )
| |
|
| |
|
@router.delete("/store/collection/{collection_name}")
async def clear_collection(collection_name: str, confirm: bool = Query(False)):
    """Clear a vector store collection (dangerous operation).

    Requires ``confirm=true`` as an explicit safety latch.

    Raises:
        HTTPException: 400 without confirmation, 404 if the name does not
            match the active collection, 500 on deletion failure.
    """
    if not confirm:
        raise HTTPException(
            status_code=400,
            detail="This operation will delete all data. Set confirm=true to proceed."
        )

    try:
        from src.rag.store import get_vector_store

        # Only the single active collection can be cleared; any other name
        # is treated as not found.
        store = get_vector_store()
        if store.collection_name != collection_name:
            raise HTTPException(status_code=404, detail=f"Collection not found: {collection_name}")

        # Delete every record via the store's private collection handle.
        # NOTE(review): assumes an empty where-clause matches all rows —
        # confirm against the vector store's delete() semantics.
        store._collection.delete(where={})

        return {"status": "cleared", "collection": collection_name, "message": "Collection cleared successfully"}

    except HTTPException:
        # Re-raise the intentional 404 unchanged.
        raise
    except Exception as e:
        logger.error(f"Collection clear failed: {e}")
        raise HTTPException(status_code=500, detail=f"Clear failed: {str(e)}")
| |
|
| |
|
@router.get("/cache/stats")
async def get_cache_stats():
    """Report cache size broken down into valid vs expired entries."""
    now = time.time()
    total = len(_query_cache)

    # Walk the cache once, counting entries still within the TTL window.
    fresh = 0
    for entry in _query_cache.values():
        if now - entry["timestamp"] < CACHE_TTL_SECONDS:
            fresh += 1

    return {
        "total_entries": total,
        "valid_entries": fresh,
        "expired_entries": total - fresh,
        "ttl_seconds": CACHE_TTL_SECONDS,
    }
| |
|
| |
|
@router.delete("/cache")
async def clear_cache():
    """Empty the query cache and report how many entries were removed."""
    removed = len(_query_cache)
    _query_cache.clear()
    return {"status": "cleared", "entries_removed": removed}
| |
|