| """Public retrieval API — thin wrapper around RetrievalRouter.""" |
|
|
| from sqlalchemy.ext.asyncio import AsyncSession |
|
|
| from src.middlewares.logging import get_logger |
| from src.rag.base import RetrievalResult |
| from src.rag.retrievers.document import document_retriever |
| from src.rag.retrievers.schema import schema_retriever |
| from src.rag.router import RetrievalRouter, SourceHint |
|
|
# Module-level logger named "retriever"; RetrieverService.retrieve uses it
# to record retrieval failures before falling back to an empty result.
logger = get_logger("retriever")
|
|
|
|
class RetrieverService:
    """Public retrieval service used by chat.py and search tools.

    Delegates to RetrievalRouter which dispatches based on source_hint.
    Returns RetrievalResult objects directly so downstream consumers
    (db_executor, tabular_executor) can be fed without lossy dict
    conversion. The `db` parameter is accepted for call-site compatibility
    but currently unused — retrieval reads PGVector via _pgvector_engine
    inside each retriever.
    """

    def __init__(self):
        # Wire the router to the schema/document retriever instances
        # imported at the top of this module.
        self._router = RetrievalRouter(
            schema_retriever=schema_retriever,
            document_retriever=document_retriever,
        )

    async def retrieve(
        self,
        query: str,
        user_id: str,
        db: AsyncSession,
        k: int = 5,
        source_hint: SourceHint = "both",
    ) -> list[RetrievalResult]:
        """Run ``query`` through the retrieval router and return its results.

        Args:
            query: Free-text query forwarded to the router.
            user_id: Forwarded to the router (presumably for per-user
                scoping — confirm in RetrievalRouter).
            db: Unused; accepted only so existing call sites keep working.
            k: Result count forwarded to the router (presumably an upper
                bound — confirm in RetrievalRouter).
            source_hint: Which retriever(s) the router dispatches to;
                defaults to "both".

        Returns:
            The router's RetrievalResult list, or an empty list if the
            router raised. Failures are logged rather than propagated, so
            callers get ``[]`` instead of an exception.
        """
        try:
            # The router takes (query, user_id, source_hint, k) positionally;
            # note source_hint comes before k, unlike this method's signature.
            return await self._router.retrieve(query, user_id, source_hint, k)
        except Exception as e:
            # Deliberate best-effort boundary (asyncio.CancelledError is a
            # BaseException and still propagates). str(e) alone is empty for
            # many exceptions (e.g. a bare TimeoutError()), so also record the
            # exception type, the traceback, and the non-PII request params.
            logger.error(
                "retrieval failed",
                error=str(e),
                error_type=type(e).__name__,
                source_hint=source_hint,
                k=k,
                exc_info=True,
            )
            return []
|
|
|
|
# Shared module-level instance, constructed once at import time (which also
# builds its RetrievalRouter). Call sites presumably import this object
# rather than creating their own RetrieverService — verify against chat.py
# and the search tools.
retriever = RetrieverService()
|
|