"""Routes retrieval requests to the appropriate retriever based on source_hint.

Cross-retriever merging uses Reciprocal Rank Fusion (RRF) on per-retriever
ranked lists — score scales differ across retrievers (RRF, cosine, distance)
and aren't directly comparable, so we rank-merge instead of score-merge.
"""

import asyncio
import hashlib
import json
from dataclasses import asdict, replace
from typing import Literal

from src.db.redis.connection import get_redis
from src.middlewares.logging import get_logger
from src.rag.base import BaseRetriever, RetrievalResult

logger = get_logger("retrieval_router")

_CACHE_TTL = 3600  # 1 hour
_CACHE_KEY_PREFIX = "retrieval"
_RRF_K = 60  # standard RRF constant

# Redis glob metacharacters, each mapped to its backslash-escaped form so a
# value can be embedded literally in a SCAN MATCH pattern.
_GLOB_ESCAPES = str.maketrans({ch: "\\" + ch for ch in "\\*?[]"})

SourceHint = Literal["document", "schema", "both"]


def _result_dedup_key(r: RetrievalResult) -> tuple:
    """Cross-retriever dedup key.

    Distinguishes DB columns vs DB tables vs tabular columns vs prose chunks
    vs sheet-level chunks. Fields missing from ``metadata["data"]`` contribute
    ``None`` to the key.
    """
    # `or {}` also covers an explicit `"data": None` entry, which would
    # otherwise raise AttributeError on `.get`.
    data = r.metadata.get("data") or {}
    return (
        r.source_type,
        data.get("table_name"),
        data.get("column_name"),
        data.get("filename"),
        data.get("sheet_name"),
        data.get("chunk_index"),  # disambiguates multiple prose chunks per doc
        r.metadata.get("chunk_level"),  # distinguishes sheet vs column chunks
    )


def _rrf_merge(
    ranked_lists: list[list[RetrievalResult]],
    top_k: int,
    k_rrf: int = _RRF_K,
) -> list[RetrievalResult]:
    """Reciprocal Rank Fusion across retriever batches.

    Each input list is treated as already best-first ordered. Items are
    deduped via ``_result_dedup_key`` and re-ranked by aggregated reciprocal
    rank across all lists. Within a single list a key contributes only once,
    at its best (first) rank, so duplicates inside one retriever's output do
    not inflate the fused score.

    Args:
        ranked_lists: One best-first list per retriever.
        top_k: Maximum number of results to return.
        k_rrf: RRF smoothing constant added to the 1-based rank.

    Returns:
        Up to ``top_k`` results, best first. Each is a copy of the first
        occurrence of its key with ``score`` set to the aggregated RRF score
        (uniform scale across legs); the input objects are left untouched.
    """
    scores: dict[tuple, float] = {}
    first_seen: dict[tuple, RetrievalResult] = {}

    for ranked in ranked_lists:
        seen_in_list: set[tuple] = set()
        for rank, result in enumerate(ranked):
            key = _result_dedup_key(result)
            if key in seen_in_list:
                # Same item repeated by one retriever: keep its best rank only.
                continue
            seen_in_list.add(key)
            scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
            # Keep the first occurrence across lists as the representative.
            first_seen.setdefault(key, result)

    # sorted() is stable, so ties keep first-seen (insertion) order.
    ordered_keys = sorted(scores, key=scores.__getitem__, reverse=True)
    # Return copies carrying the RRF score rather than mutating caller-owned
    # objects.
    return [
        replace(first_seen[key], score=scores[key])
        for key in ordered_keys[:top_k]
    ]


async def invalidate_retrieval_cache(user_id: str) -> int:
    """Delete every cached retrieval entry for `user_id`.

    Called by ingest/upload/delete API handlers after a successful write so
    the next retrieval picks up the new data instead of stale cached top-k.

    Glob metacharacters in `user_id` are escaped before building the match
    pattern, so an id such as ``"*"`` cannot match other users' keys.

    Returns the number of keys removed.
    """
    redis = await get_redis()
    safe_user_id = user_id.translate(_GLOB_ESCAPES)
    pattern = f"{_CACHE_KEY_PREFIX}:{safe_user_id}:*"
    keys = [key async for key in redis.scan_iter(match=pattern)]
    if not keys:
        return 0
    deleted = await redis.delete(*keys)
    logger.info("retrieval cache invalidated", user_id=user_id, deleted=deleted)
    return int(deleted)


class RetrievalRouter:
    """Fans a query out to the schema and/or document retriever and merges results."""

    def __init__(
        self,
        schema_retriever: BaseRetriever,
        document_retriever: BaseRetriever,
    ):
        # Insertion order fixes the leg order used for source_hint="both".
        self._retrievers: dict[str, BaseRetriever] = {
            "schema": schema_retriever,
            "document": document_retriever,
        }

    def _route(self, source_hint: SourceHint) -> list[tuple[str, BaseRetriever]]:
        """Return the ``(name, retriever)`` pairs to query for `source_hint`.

        Any value other than "schema" or "document" selects every retriever.
        """
        if source_hint in ("schema", "document"):
            return [(source_hint, self._retrievers[source_hint])]
        return list(self._retrievers.items())

    async def retrieve(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint = "both",
        k: int = 10,
    ) -> list[RetrievalResult]:
        """Return up to `k` results for `query`, served from cache when possible.

        A cache entry is keyed on user, source hint, query hash and `k`. On a
        miss (or an unreadable entry) the retrievers are queried; if a
        narrowed hint yields nothing, the query is retried once with "both".
        Results are written back to the cache only when every retriever leg
        succeeded, so a transient retriever failure is not pinned for the
        whole TTL.
        """
        redis = await get_redis()
        # MD5 is only a compact cache-key digest here, not a security
        # primitive; flagging that keeps it usable on FIPS-restricted builds.
        query_hash = hashlib.md5(query.encode(), usedforsecurity=False).hexdigest()
        cache_key = f"{_CACHE_KEY_PREFIX}:{user_id}:{source_hint}:{query_hash}:{k}"

        cached = await redis.get(cache_key)
        if cached:
            try:
                hits = [RetrievalResult(**r) for r in json.loads(cached)]
            except Exception:
                # Best-effort cache: any unreadable entry falls through to a
                # fresh retrieval and gets overwritten below.
                logger.warning("corrupted retrieval cache, fetching fresh", cache_key=cache_key)
            else:
                logger.info("returning cached retrieval results", source_hint=source_hint)
                return hits

        results, had_failure = await self._retrieve_with_status(
            query, user_id, source_hint, k
        )

        # Empty-result fallback: orchestrator may have misclassified intent.
        # Retry once with "both" before giving up. No-op when source_hint is
        # already "both".
        if not results and source_hint != "both":
            logger.warning(
                "empty retrieval, falling back to source_hint='both'",
                original_source_hint=source_hint,
            )
            results, fallback_failed = await self._retrieve_with_status(
                query, user_id, "both", k
            )
            had_failure = had_failure or fallback_failed

        if had_failure:
            # Degraded output (a leg errored) must not be served for the next
            # hour; return it to this caller but leave the cache untouched.
            logger.warning(
                "retriever failure, skipping retrieval cache write",
                cache_key=cache_key,
            )
            return results

        await redis.setex(
            cache_key,
            _CACHE_TTL,
            json.dumps([asdict(r) for r in results]),
        )
        return results

    async def _retrieve_uncached(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint,
        k: int,
    ) -> list[RetrievalResult]:
        """Query the routed retrievers and return the RRF-merged top `k`.

        Failed retrievers are logged and skipped; use
        ``_retrieve_with_status`` when the caller needs to know about them.
        """
        results, _ = await self._retrieve_with_status(query, user_id, source_hint, k)
        return results

    async def _retrieve_with_status(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint,
        k: int,
    ) -> tuple[list[RetrievalResult], bool]:
        """Run the routed retrievers concurrently and merge their outputs.

        Returns:
            ``(results, had_failure)`` where `results` is the RRF-merged top
            `k` from the retrievers that succeeded and `had_failure` is True
            if at least one retriever raised.
        """
        routed = self._route(source_hint)
        batches = await asyncio.gather(
            *(retriever.retrieve(query, user_id, k) for _, retriever in routed),
            return_exceptions=True,
        )

        valid_lists: list[list[RetrievalResult]] = []
        per_retriever: dict[str, int | str] = {}
        had_failure = False
        for (name, _), batch in zip(routed, batches):
            # gather(return_exceptions=True) can hand back BaseException
            # instances (e.g. CancelledError), not only Exception subclasses;
            # none of them is a result list.
            if isinstance(batch, BaseException):
                logger.error("retriever failed", retriever=name, error=str(batch))
                per_retriever[name] = "error"
                had_failure = True
                continue
            valid_lists.append(batch)
            per_retriever[name] = len(batch)

        results = _rrf_merge(valid_lists, top_k=k)

        logger.info(
            "router result",
            source_hint=source_hint,
            per_retriever=per_retriever,
            final_count=len(results),
            top_score=results[0].score if results else None,
            bottom_score=results[-1].score if results else None,
        )
        return results, had_failure