| """Routes retrieval requests to the appropriate retriever based on source_hint. |
| |
| Cross-retriever merging uses Reciprocal Rank Fusion (RRF) on per-retriever |
| ranked lists — score scales differ across retrievers (RRF, cosine, distance) |
| and aren't directly comparable, so we rank-merge instead of score-merge. |
| """ |
|
|
import asyncio
import hashlib
import json
from dataclasses import asdict, replace
from typing import Literal
|
|
| from src.db.redis.connection import get_redis |
| from src.middlewares.logging import get_logger |
| from src.rag.base import BaseRetriever, RetrievalResult |
|
|
# Structured logger shared by the router and the cache helpers below.
logger = get_logger("retrieval_router")

# Lifetime of a cached retrieval result, in seconds (one hour).
_CACHE_TTL = 60 * 60
# Namespace for cache keys; keys look like "<prefix>:<user_id>:...".
_CACHE_KEY_PREFIX = "retrieval"
# RRF smoothing constant (60 is the value commonly used for RRF).
_RRF_K = 60

# Which retriever(s) a request is routed to.
SourceHint = Literal["document", "schema", "both"]
|
|
|
|
def _result_dedup_key(r: "RetrievalResult") -> tuple:
    """Return the identity key used to dedupe results across retrievers.

    The key is ``source_type`` plus the identifying fields found under
    ``metadata["data"]`` (table name, column name, filename, sheet name,
    chunk index) and ``metadata["chunk_level"]``. It is meant to distinguish
    DB columns vs DB tables vs tabular columns vs prose chunks vs
    sheet-level chunks. Any missing field contributes ``None``.
    """
    # Tolerate ``metadata`` or ``metadata["data"]`` being present but None.
    # ``dict.get(key, default)`` only falls back when the key is absent, so
    # an explicit None would otherwise raise AttributeError mid-merge.
    metadata = r.metadata or {}
    data = metadata.get("data") or {}
    return (
        r.source_type,
        data.get("table_name"),
        data.get("column_name"),
        data.get("filename"),
        data.get("sheet_name"),
        data.get("chunk_index"),
        metadata.get("chunk_level"),
    )
|
|
|
|
def _rrf_merge(
    ranked_lists: list[list[RetrievalResult]],
    top_k: int,
    k_rrf: int = _RRF_K,
) -> list[RetrievalResult]:
    """Merge several best-first result lists with Reciprocal Rank Fusion.

    Each input list is treated as already ordered best-first. An item at
    0-based position ``rank`` contributes ``1 / (k_rrf + rank + 1)`` to its
    fused score, and contributions are summed across lists. Items are
    identified across lists by ``_result_dedup_key``. Within one list only
    the first (best-ranked) occurrence of a key counts, as in standard RRF,
    so a list that repeats an item cannot inflate its score.

    Args:
        ranked_lists: Per-retriever result lists, each best-first.
        top_k: Maximum number of merged results; ``<= 0`` yields ``[]``.
        k_rrf: RRF smoothing constant; larger values flatten rank weights.

    Returns:
        Up to ``top_k`` results ordered by descending fused score. Each is a
        copy of the first-seen input result with ``score`` replaced by its
        aggregated RRF score (uniform scale across legs). Input objects are
        not modified. Ties keep first-seen order.
    """
    scores: dict[tuple, float] = {}
    first_seen: dict[tuple, RetrievalResult] = {}

    for ranked in ranked_lists:
        counted_in_list: set[tuple] = set()
        for rank, result in enumerate(ranked):
            key = _result_dedup_key(result)
            if key in counted_in_list:
                # Duplicate within the same list: only the best rank counts.
                continue
            counted_in_list.add(key)
            scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
            first_seen.setdefault(key, result)

    if top_k <= 0:
        return []

    # ``scores`` iterates in first-seen order and sorted() is stable even with
    # reverse=True, so equal scores keep first-seen order.
    best_keys = sorted(scores, key=scores.__getitem__, reverse=True)[:top_k]
    # Return copies rather than overwriting ``score`` on the retrievers' own
    # objects, which callers or retriever-side caches may still hold.
    return [replace(first_seen[key], score=scores[key]) for key in best_keys]
|
|
|
|
# Redis glob metacharacters. They are backslash-escaped so a user_id is
# matched literally by SCAN MATCH; otherwise e.g. user_id "a*" would also
# match (and delete) the keys of user "ab".
_GLOB_ESCAPES = str.maketrans({ch: "\\" + ch for ch in "\\*?[]"})
# Upper bound on keys per DEL command (also used as the SCAN COUNT hint).
_INVALIDATE_BATCH = 500


async def invalidate_retrieval_cache(user_id: str) -> int:
    """Delete every cached retrieval entry for `user_id`.

    Intended to be called by ingest/upload/delete API handlers after a
    successful write, so the next retrieval picks up the new data instead of
    a stale cached top-k. Keys are deleted in bounded batches while scanning
    rather than collected into a single unbounded DEL.

    Returns:
        The number of keys removed.
    """
    redis = await get_redis()
    pattern = f"{_CACHE_KEY_PREFIX}:{user_id.translate(_GLOB_ESCAPES)}:*"

    matched = 0
    deleted = 0
    batch: list = []
    # Deleting during SCAN is safe: SCAN still returns every key that exists
    # for the whole iteration, and re-deleting a duplicate just counts 0.
    async for key in redis.scan_iter(match=pattern, count=_INVALIDATE_BATCH):
        batch.append(key)
        if len(batch) >= _INVALIDATE_BATCH:
            matched += len(batch)
            deleted += int(await redis.delete(*batch))
            batch = []
    if batch:
        matched += len(batch)
        deleted += int(await redis.delete(*batch))

    if not matched:
        return 0
    logger.info("retrieval cache invalidated", user_id=user_id, deleted=deleted)
    return deleted
|
|
|
|
class RetrievalRouter:
    """Route queries to the schema and/or document retriever, with caching.

    Results are cached in Redis for ``_CACHE_TTL`` seconds per
    (user_id, source_hint, query, k). When several retrievers are queried,
    their ranked lists are combined with ``_rrf_merge``.
    """

    def __init__(
        self,
        schema_retriever: BaseRetriever,
        document_retriever: BaseRetriever,
    ):
        # Insertion order matters: "both" queries retrievers in this order,
        # and _rrf_merge keeps the first-seen copy of a duplicate result.
        self._retrievers: dict[str, BaseRetriever] = {
            "schema": schema_retriever,
            "document": document_retriever,
        }

    def _route(self, source_hint: SourceHint) -> list[tuple[str, BaseRetriever]]:
        """Return the (name, retriever) pairs to query for ``source_hint``.

        "schema" and "document" select that single retriever. "both", or any
        unrecognised hint, selects all retrievers in registration order.
        """
        if source_hint in self._retrievers:
            return [(source_hint, self._retrievers[source_hint])]
        return list(self._retrievers.items())

    async def retrieve(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint = "both",
        k: int = 10,
    ) -> list[RetrievalResult]:
        """Return up to ``k`` results for ``query``, served from cache when possible.

        On a cache miss the routed retrievers are queried. If a single-source
        hint returns nothing (including when its retriever failed), the query
        is retried against all retrievers. Fresh results are written to the
        cache only when every queried retriever succeeded, so a transient
        failure is not served from cache for the whole TTL. A failed cache
        write is logged and does not fail the request.
        """
        redis = await get_redis()
        # md5 only shortens the query into a cache key; it is not a security
        # boundary (flagging this keeps it usable on FIPS-restricted builds).
        query_hash = hashlib.md5(query.encode(), usedforsecurity=False).hexdigest()
        cache_key = f"{_CACHE_KEY_PREFIX}:{user_id}:{source_hint}:{query_hash}:{k}"

        cached = await redis.get(cache_key)
        if cached:
            try:
                raw = json.loads(cached)
                cached_results = [RetrievalResult(**r) for r in raw]
            except Exception:
                # Deliberately broad: any undecodable entry is treated as a
                # cache miss and overwritten below.
                logger.warning("corrupted retrieval cache, fetching fresh", cache_key=cache_key)
            else:
                logger.info("returning cached retrieval results", source_hint=source_hint)
                return cached_results

        results, had_failures = await self._retrieve_with_status(query, user_id, source_hint, k)

        if not results and source_hint != "both":
            logger.warning(
                "empty retrieval, falling back to source_hint='both'",
                original_source_hint=source_hint,
            )
            results, had_failures = await self._retrieve_with_status(query, user_id, "both", k)

        # Don't pin a degraded (partial or empty) result set in the cache for
        # the whole TTL because one retriever had a transient error.
        if not had_failures:
            try:
                await redis.setex(
                    cache_key,
                    _CACHE_TTL,
                    json.dumps([asdict(r) for r in results]),
                )
            except Exception as exc:
                # Caching is an optimisation: never discard computed results
                # because serialisation or the Redis write failed.
                logger.warning(
                    "failed to write retrieval cache",
                    cache_key=cache_key,
                    error=str(exc),
                )
        return results

    async def _retrieve_uncached(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint,
        k: int,
    ) -> list[RetrievalResult]:
        """Query the routed retrievers concurrently and RRF-merge their results.

        Bypasses the cache. A failing retriever is logged and skipped.
        """
        results, _ = await self._retrieve_with_status(query, user_id, source_hint, k)
        return results

    async def _retrieve_with_status(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint,
        k: int,
    ) -> tuple[list[RetrievalResult], bool]:
        """Like ``_retrieve_uncached`` but also report whether any retriever failed.

        Returns:
            ``(merged_results, had_failures)``, where ``had_failures`` is True
            if at least one routed retriever raised.
        """
        routed = self._route(source_hint)
        batches = await asyncio.gather(
            *(retriever.retrieve(query, user_id, k) for _, retriever in routed),
            return_exceptions=True,
        )

        valid_lists: list[list[RetrievalResult]] = []
        per_retriever: dict[str, int | str] = {}
        for (name, _), batch in zip(routed, batches):
            # gather(return_exceptions=True) can also return BaseException
            # subclasses such as asyncio.CancelledError, not only Exception.
            if isinstance(batch, BaseException):
                logger.error(
                    "retriever failed",
                    retriever=name,
                    error=str(batch),
                    error_type=type(batch).__name__,
                )
                per_retriever[name] = "error"
                continue
            valid_lists.append(batch)
            per_retriever[name] = len(batch)

        results = _rrf_merge(valid_lists, top_k=k)

        logger.info(
            "router result",
            source_hint=source_hint,
            per_retriever=per_retriever,
            final_count=len(results),
            top_score=results[0].score if results else None,
            bottom_score=results[-1].score if results else None,
        )
        return results, len(valid_lists) < len(routed)
|
|