ishaq101's picture
[KM-438][KM-439] Improve Retrieval and Querying feature (#15)
c93ec90
"""Routes retrieval requests to the appropriate retriever based on source_hint.
Cross-retriever merging uses Reciprocal Rank Fusion (RRF) on per-retriever
ranked lists — score scales differ across retrievers (RRF, cosine, distance)
and aren't directly comparable, so we rank-merge instead of score-merge.
"""
import asyncio
import hashlib
import json
from dataclasses import asdict
from typing import Literal
from src.db.redis.connection import get_redis
from src.middlewares.logging import get_logger
from src.rag.base import BaseRetriever, RetrievalResult
# Logger obtained from the project's logging middleware under the name
# "retrieval_router".
logger = get_logger("retrieval_router")
# Lifetime of a cached retrieval entry, in seconds; passed as the TTL to
# redis.setex when results are written to the cache.
_CACHE_TTL = 3600 # 1 hour
# First segment of every retrieval cache key. Both the key built when caching
# results and the pattern used for per-user invalidation start with it.
_CACHE_KEY_PREFIX = "retrieval"
# Default smoothing constant k in the RRF term 1 / (k + rank), rank being
# 1-based; used as the default `k_rrf` of _rrf_merge.
_RRF_K = 60 # standard RRF constant
# Routing hints accepted by the router: a single retriever, or both.
SourceHint = Literal["document", "schema", "both"]
def _result_dedup_key(r: RetrievalResult) -> tuple:
    """Build the identity tuple used to dedupe results across retrievers.

    The tuple combines the result's source type with the locating fields found
    under ``metadata["data"]`` and the top-level ``chunk_level``, so that DB
    tables, DB columns, tabular columns, prose chunks and sheet-level chunks
    all get distinct keys. Fields absent from the metadata contribute ``None``.
    """
    meta = r.metadata
    payload = meta.get("data", {})
    # Order is part of the key. `chunk_index` separates multiple prose chunks
    # of one document; `chunk_level` (appended last) separates sheet-level
    # chunks from column-level ones.
    payload_fields = (
        "table_name",
        "column_name",
        "filename",
        "sheet_name",
        "chunk_index",
    )
    return (
        r.source_type,
        *(payload.get(field) for field in payload_fields),
        meta.get("chunk_level"),
    )
def _rrf_merge(
    ranked_lists: list[list[RetrievalResult]],
    top_k: int,
    k_rrf: int = _RRF_K,
) -> list[RetrievalResult]:
    """Fuse several best-first result lists with Reciprocal Rank Fusion.

    Every appearance of an item at 0-based position ``rank`` contributes
    ``1 / (k_rrf + rank + 1)`` to that item's total. Items are identified by
    ``_result_dedup_key``; the first object seen for a key is the one returned.

    Args:
        ranked_lists: One list per retriever, each already ordered best-first.
        top_k: Maximum number of fused results to return.
        k_rrf: RRF smoothing constant.

    Returns:
        The fused results, highest total first, truncated to ``top_k``. Each
        returned object's ``score`` is overwritten in place with its RRF total
        so that all results share one scale.
    """
    rrf_totals: dict[tuple, float] = {}
    first_seen: dict[tuple, RetrievalResult] = {}

    for batch in ranked_lists:
        for rank, item in enumerate(batch):
            dedup_key = _result_dedup_key(item)
            rrf_totals[dedup_key] = rrf_totals.get(dedup_key, 0.0) + 1.0 / (k_rrf + rank + 1)
            # Only the first object per key is kept; later copies just add
            # to the running total.
            first_seen.setdefault(dedup_key, item)

    # Stable sort over keys in first-seen order: ties keep that order.
    ordered_keys = sorted(rrf_totals, key=rrf_totals.__getitem__, reverse=True)

    fused: list[RetrievalResult] = []
    for dedup_key in ordered_keys:
        item = first_seen[dedup_key]
        # Replace the retriever-specific score with the fused one.
        item.score = rrf_totals[dedup_key]
        fused.append(item)
    return fused[:top_k]
async def invalidate_retrieval_cache(user_id: str) -> int:
    """Delete every cached retrieval entry for `user_id`.

    Meant to be called by ingest/upload/delete API handlers after a successful
    write so the next retrieval picks up the new data instead of stale cached
    top-k.

    Args:
        user_id: Owner of the cache entries. Matched literally: glob
            metacharacters inside the id are escaped before scanning.

    Returns:
        The number of keys removed (0 when nothing was cached for the user).
    """
    redis = await get_redis()
    # SCAN's MATCH argument is a glob pattern. Escape its metacharacters so an
    # id containing `*`, `?`, `[`, `]` or `\` matches only itself instead of
    # widening the scan to (and deleting) other users' entries.
    literal_user_id = user_id.translate(
        str.maketrans({ch: "\\" + ch for ch in "\\*?[]"})
    )
    pattern = f"{_CACHE_KEY_PREFIX}:{literal_user_id}:*"
    keys = [key async for key in redis.scan_iter(match=pattern)]
    if not keys:
        return 0
    deleted = await redis.delete(*keys)
    logger.info("retrieval cache invalidated", user_id=user_id, deleted=deleted)
    return int(deleted)
class RetrievalRouter:
    """Route retrieval requests to the schema and/or document retriever.

    Results are cached in Redis per (user, source_hint, query, k). When several
    retrievers are queried, their ranked lists are fused with `_rrf_merge`.
    """

    def __init__(
        self,
        schema_retriever: BaseRetriever,
        document_retriever: BaseRetriever,
    ):
        # Insertion order fixes the order in which ranked lists reach the
        # merge when every retriever is queried.
        self._retrievers: dict[str, BaseRetriever] = {
            "schema": schema_retriever,
            "document": document_retriever,
        }

    def _route(self, source_hint: SourceHint) -> list[tuple[str, BaseRetriever]]:
        """Return the (name, retriever) pairs to query for `source_hint`."""
        if source_hint == "schema":
            return [("schema", self._retrievers["schema"])]
        if source_hint == "document":
            return [("document", self._retrievers["document"])]
        # "both" (and any other value) fans out to every retriever.
        return list(self._retrievers.items())

    async def retrieve(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint = "both",
        k: int = 10,
    ) -> list[RetrievalResult]:
        """Return up to `k` results for `query`, served from cache when possible.

        Args:
            query: The user's query text.
            user_id: Owner of the data; also part of the cache key.
            source_hint: Which retriever(s) to query.
            k: Maximum number of results.

        Returns:
            The cached results if a readable entry exists, otherwise freshly
            retrieved results. If a single-source hint yields nothing, the
            query is retried once against both retrievers; those results are
            stored under the original hint's cache key.
        """
        redis = await get_redis()
        # MD5 is only a compact cache-key fingerprint here, not a security
        # primitive. Declaring that keeps the call usable on FIPS-restricted
        # builds; the digest itself is unchanged.
        query_hash = hashlib.md5(query.encode(), usedforsecurity=False).hexdigest()
        cache_key = f"{_CACHE_KEY_PREFIX}:{user_id}:{source_hint}:{query_hash}:{k}"

        cached = await redis.get(cache_key)
        if cached:
            try:
                cached_results = [RetrievalResult(**r) for r in json.loads(cached)]
            except Exception:
                # Best-effort cache: an unreadable entry is treated as a miss.
                logger.warning("corrupted retrieval cache, fetching fresh", cache_key=cache_key)
            else:
                logger.info("returning cached retrieval results", source_hint=source_hint)
                return cached_results

        results, any_failed = await self._retrieve_with_status(
            query, user_id, source_hint, k
        )

        # Empty-result fallback: orchestrator may have misclassified intent.
        # Retry once with "both" before giving up. No-op when source_hint is
        # already "both".
        if not results and source_hint != "both":
            logger.warning(
                "empty retrieval, falling back to source_hint='both'",
                original_source_hint=source_hint,
            )
            results, any_failed = await self._retrieve_with_status(
                query, user_id, "both", k
            )

        if any_failed:
            # A retriever error is likely transient; caching the degraded
            # (possibly empty) answer would keep serving it for the whole TTL.
            logger.warning(
                "retriever failure, skipping retrieval cache write",
                cache_key=cache_key,
            )
            return results

        await redis.setex(
            cache_key,
            _CACHE_TTL,
            json.dumps([asdict(r) for r in results]),
        )
        return results

    async def _retrieve_uncached(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint,
        k: int,
    ) -> list[RetrievalResult]:
        """Query the routed retrievers and return the merged top-`k` results.

        Retriever failures are logged and skipped rather than raised.
        """
        results, _ = await self._retrieve_with_status(query, user_id, source_hint, k)
        return results

    async def _retrieve_with_status(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint,
        k: int,
    ) -> tuple[list[RetrievalResult], bool]:
        """Run the routed retrievers concurrently; return (results, any_failed)."""
        routed = self._route(source_hint)
        batches = await asyncio.gather(
            *[r.retrieve(query, user_id, k) for _, r in routed],
            return_exceptions=True,
        )

        valid_lists: list[list[RetrievalResult]] = []
        per_retriever: dict[str, int | str] = {}
        for (name, _), batch in zip(routed, batches):
            # BaseException, not Exception: with return_exceptions=True a
            # cancelled leg comes back as CancelledError, which does not
            # derive from Exception and must not be merged as a result list.
            if isinstance(batch, BaseException):
                logger.error("retriever failed", retriever=name, error=str(batch))
                per_retriever[name] = "error"
                continue
            valid_lists.append(batch)
            per_retriever[name] = len(batch)

        results = _rrf_merge(valid_lists, top_k=k)
        logger.info(
            "router result",
            source_hint=source_hint,
            per_retriever=per_retriever,
            final_count=len(results),
            top_score=results[0].score if results else None,
            bottom_score=results[-1].score if results else None,
        )
        any_failed = len(valid_lists) != len(routed)
        return results, any_failed