Rifqi Hafizuddin commited on
Commit ·
e4f62b8
1
Parent(s): de32ab0
[NOTICKET] minor refactoring
Browse files- src/api/v1/chat.py +17 -26
- src/rag/retriever.py +8 -6
- src/rag/retrievers/baseline.py +9 -3
- src/tools/search.py +3 -3
src/api/v1/chat.py
CHANGED
|
@@ -48,43 +48,43 @@ class ChatRequest(BaseModel):
|
|
| 48 |
message: str
|
| 49 |
|
| 50 |
|
| 51 |
-
def _format_context(results: List[
|
| 52 |
"""Format retrieval results as context string for the LLM."""
|
| 53 |
lines = []
|
| 54 |
for result in results:
|
| 55 |
-
filename = result
|
| 56 |
-
page = result
|
| 57 |
source_label = f"{filename}, p.{page}" if page else filename
|
| 58 |
-
lines.append(f"[Source: {source_label}]\n{result
|
| 59 |
return "\n".join(lines)
|
| 60 |
|
| 61 |
|
| 62 |
-
def _extract_sources(results: List[
|
| 63 |
"""Extract deduplicated source references from retrieval results."""
|
| 64 |
seen = set()
|
| 65 |
sources = []
|
| 66 |
for result in results:
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
| 70 |
if key not in seen:
|
| 71 |
seen.add(key)
|
| 72 |
sources.append({
|
| 73 |
-
"document_id":
|
| 74 |
-
"filename":
|
| 75 |
-
"page_label":
|
| 76 |
})
|
| 77 |
else:
|
| 78 |
-
|
| 79 |
-
key = (meta.get("data", {}).get("table_name"), meta.get("data", {}).get("column_name"))
|
| 80 |
if key not in seen:
|
| 81 |
seen.add(key)
|
| 82 |
-
table_name =
|
| 83 |
user_id = meta.get("user_id")
|
| 84 |
sources.append({
|
| 85 |
"document_id": f"{user_id}_{table_name}",
|
| 86 |
-
"filename":
|
| 87 |
-
"page_label":
|
| 88 |
})
|
| 89 |
|
| 90 |
logger.debug(f"Extracted sources: {sources}")
|
|
@@ -229,17 +229,8 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
|
| 229 |
|
| 230 |
source_hint = intent_result.get("source_hint", "both")
|
| 231 |
if source_hint in ("schema", "both"):
|
| 232 |
-
retrieval_objects = [
|
| 233 |
-
RetrievalResult(
|
| 234 |
-
content=r["content"],
|
| 235 |
-
metadata=r["metadata"],
|
| 236 |
-
score=0.0,
|
| 237 |
-
source_type=r["metadata"].get("source_type", ""),
|
| 238 |
-
)
|
| 239 |
-
for r in raw_results
|
| 240 |
-
]
|
| 241 |
query_results = await query_executor.execute(
|
| 242 |
-
results=
|
| 243 |
user_id=request.user_id,
|
| 244 |
db=db,
|
| 245 |
question=intent_result.get("search_query") or request.message,
|
|
|
|
| 48 |
message: str
|
| 49 |
|
| 50 |
|
| 51 |
+
def _format_context(results: List[RetrievalResult]) -> str:
|
| 52 |
"""Format retrieval results as context string for the LLM."""
|
| 53 |
lines = []
|
| 54 |
for result in results:
|
| 55 |
+
filename = result.metadata.get("filename", "Unknown")
|
| 56 |
+
page = result.metadata.get("page_label")
|
| 57 |
source_label = f"{filename}, p.{page}" if page else filename
|
| 58 |
+
lines.append(f"[Source: {source_label}]\n{result.content}\n")
|
| 59 |
return "\n".join(lines)
|
| 60 |
|
| 61 |
|
| 62 |
+
def _extract_sources(results: List[RetrievalResult]) -> List[Dict[str, Any]]:
|
| 63 |
"""Extract deduplicated source references from retrieval results."""
|
| 64 |
seen = set()
|
| 65 |
sources = []
|
| 66 |
for result in results:
|
| 67 |
+
meta = result.metadata
|
| 68 |
+
data = meta.get("data", {})
|
| 69 |
+
if "document_id" in data:
|
| 70 |
+
key = (data.get("document_id"), data.get("page_label"))
|
| 71 |
if key not in seen:
|
| 72 |
seen.add(key)
|
| 73 |
sources.append({
|
| 74 |
+
"document_id": data.get("document_id"),
|
| 75 |
+
"filename": data.get("filename", "Unknown"),
|
| 76 |
+
"page_label": data.get("page_label", "Unknown"),
|
| 77 |
})
|
| 78 |
else:
|
| 79 |
+
key = (data.get("table_name"), data.get("column_name"))
|
|
|
|
| 80 |
if key not in seen:
|
| 81 |
seen.add(key)
|
| 82 |
+
table_name = data.get("table_name")
|
| 83 |
user_id = meta.get("user_id")
|
| 84 |
sources.append({
|
| 85 |
"document_id": f"{user_id}_{table_name}",
|
| 86 |
+
"filename": data.get("table_name", "Unknown"),
|
| 87 |
+
"page_label": data.get("column_name", "Unknown"),
|
| 88 |
})
|
| 89 |
|
| 90 |
logger.debug(f"Extracted sources: {sources}")
|
|
|
|
| 229 |
|
| 230 |
source_hint = intent_result.get("source_hint", "both")
|
| 231 |
if source_hint in ("schema", "both"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
query_results = await query_executor.execute(
|
| 233 |
+
results=raw_results,
|
| 234 |
user_id=request.user_id,
|
| 235 |
db=db,
|
| 236 |
question=intent_result.get("search_query") or request.message,
|
src/rag/retriever.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
"""Public retrieval API — thin wrapper around RetrievalRouter."""
|
| 2 |
|
| 3 |
-
from typing import Any
|
| 4 |
-
|
| 5 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 6 |
|
| 7 |
from src.middlewares.logging import get_logger
|
|
|
|
| 8 |
from src.rag.retrievers.document import document_retriever
|
| 9 |
from src.rag.retrievers.schema import schema_retriever
|
| 10 |
from src.rag.router import RetrievalRouter, SourceHint
|
|
@@ -16,7 +15,11 @@ class RetrieverService:
|
|
| 16 |
"""Public retrieval service used by chat.py and search tools.
|
| 17 |
|
| 18 |
Delegates to RetrievalRouter which dispatches based on source_hint.
|
| 19 |
-
Returns
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
"""
|
| 21 |
|
| 22 |
def __init__(self):
|
|
@@ -32,10 +35,9 @@ class RetrieverService:
|
|
| 32 |
db: AsyncSession,
|
| 33 |
k: int = 5,
|
| 34 |
source_hint: SourceHint = "both",
|
| 35 |
-
) -> list[
|
| 36 |
try:
|
| 37 |
-
|
| 38 |
-
return [{"content": r.content, "metadata": r.metadata} for r in results]
|
| 39 |
except Exception as e:
|
| 40 |
logger.error("retrieval failed", error=str(e))
|
| 41 |
return []
|
|
|
|
| 1 |
"""Public retrieval API — thin wrapper around RetrievalRouter."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 4 |
|
| 5 |
from src.middlewares.logging import get_logger
|
| 6 |
+
from src.rag.base import RetrievalResult
|
| 7 |
from src.rag.retrievers.document import document_retriever
|
| 8 |
from src.rag.retrievers.schema import schema_retriever
|
| 9 |
from src.rag.router import RetrievalRouter, SourceHint
|
|
|
|
| 15 |
"""Public retrieval service used by chat.py and search tools.
|
| 16 |
|
| 17 |
Delegates to RetrievalRouter which dispatches based on source_hint.
|
| 18 |
+
Returns RetrievalResult objects directly so downstream consumers
|
| 19 |
+
(db_executor, tabular_executor) can be fed without lossy dict
|
| 20 |
+
conversion. The `db` parameter is accepted for call-site compatibility
|
| 21 |
+
but currently unused — retrieval reads PGVector via _pgvector_engine
|
| 22 |
+
inside each retriever.
|
| 23 |
"""
|
| 24 |
|
| 25 |
def __init__(self):
|
|
|
|
| 35 |
db: AsyncSession,
|
| 36 |
k: int = 5,
|
| 37 |
source_hint: SourceHint = "both",
|
| 38 |
+
) -> list[RetrievalResult]:
|
| 39 |
try:
|
| 40 |
+
return await self._router.retrieve(query, user_id, source_hint, k)
|
|
|
|
| 41 |
except Exception as e:
|
| 42 |
logger.error("retrieval failed", error=str(e))
|
| 43 |
return []
|
src/rag/retrievers/baseline.py
CHANGED
|
@@ -13,8 +13,14 @@ logger = get_logger("retriever")
|
|
| 13 |
_RETRIEVAL_CACHE_TTL = 3600 # 1 hour
|
| 14 |
|
| 15 |
|
| 16 |
-
class
|
| 17 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def __init__(self):
|
| 20 |
self.vector_store = get_vector_store()
|
|
@@ -67,4 +73,4 @@ class RetrieverService:
|
|
| 67 |
return []
|
| 68 |
|
| 69 |
|
| 70 |
-
|
|
|
|
| 13 |
_RETRIEVAL_CACHE_TTL = 3600 # 1 hour
|
| 14 |
|
| 15 |
|
| 16 |
+
class BaselineRetrieverService:
|
| 17 |
+
"""Baseline (pre-Phase-1) retriever — preserved for benchmark comparison.
|
| 18 |
+
|
| 19 |
+
Renamed from RetrieverService so it doesn't shadow the production wrapper
|
| 20 |
+
at src/rag/retriever.py. Production code imports from src.rag.retriever;
|
| 21 |
+
benchmark scripts that want this baseline must import explicitly from
|
| 22 |
+
src.rag.retrievers.baseline.
|
| 23 |
+
"""
|
| 24 |
|
| 25 |
def __init__(self):
|
| 26 |
self.vector_store = get_vector_store()
|
|
|
|
| 73 |
return []
|
| 74 |
|
| 75 |
|
| 76 |
+
baseline_retriever = BaselineRetrieverService()
|
src/tools/search.py
CHANGED
|
@@ -34,10 +34,10 @@ async def search_documents(
|
|
| 34 |
|
| 35 |
formatted_results = []
|
| 36 |
for result in results:
|
| 37 |
-
filename = result
|
| 38 |
-
page = result
|
| 39 |
source_label = f"{filename}, p.{page}" if page else filename
|
| 40 |
-
formatted_results.append(f"[Source: {source_label}]\n{result
|
| 41 |
|
| 42 |
return "\n".join(formatted_results)
|
| 43 |
|
|
|
|
| 34 |
|
| 35 |
formatted_results = []
|
| 36 |
for result in results:
|
| 37 |
+
filename = result.metadata.get("filename", "Unknown")
|
| 38 |
+
page = result.metadata.get("page_label")
|
| 39 |
source_label = f"{filename}, p.{page}" if page else filename
|
| 40 |
+
formatted_results.append(f"[Source: {source_label}]\n{result.content}\n")
|
| 41 |
|
| 42 |
return "\n".join(formatted_results)
|
| 43 |
|