Rifqi Hafizuddin committed on
Commit
e4f62b8
·
1 Parent(s): de32ab0

[NOTICKET] minor refactoring

Browse files
src/api/v1/chat.py CHANGED
@@ -48,43 +48,43 @@ class ChatRequest(BaseModel):
48
  message: str
49
 
50
 
51
- def _format_context(results: List[Dict[str, Any]]) -> str:
52
  """Format retrieval results as context string for the LLM."""
53
  lines = []
54
  for result in results:
55
- filename = result["metadata"].get("filename", "Unknown")
56
- page = result["metadata"].get("page_label")
57
  source_label = f"{filename}, p.{page}" if page else filename
58
- lines.append(f"[Source: {source_label}]\n{result['content']}\n")
59
  return "\n".join(lines)
60
 
61
 
62
- def _extract_sources(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
63
  """Extract deduplicated source references from retrieval results."""
64
  seen = set()
65
  sources = []
66
  for result in results:
67
- if "document_id" in result["metadata"].get("data", {}):
68
- meta = result["metadata"]
69
- key = (meta.get("data", {}).get("document_id"), meta.get("data", {}).get("page_label"))
 
70
  if key not in seen:
71
  seen.add(key)
72
  sources.append({
73
- "document_id": meta.get("data", {}).get("document_id"),
74
- "filename": meta.get("data", {}).get("filename", "Unknown"),
75
- "page_label": meta.get("data", {}).get("page_label", "Unknown"),
76
  })
77
  else:
78
- meta = result["metadata"]
79
- key = (meta.get("data", {}).get("table_name"), meta.get("data", {}).get("column_name"))
80
  if key not in seen:
81
  seen.add(key)
82
- table_name = meta.get("data", {}).get("table_name")
83
  user_id = meta.get("user_id")
84
  sources.append({
85
  "document_id": f"{user_id}_{table_name}",
86
- "filename": meta.get("data", {}).get("table_name", "Unknown"),
87
- "page_label": meta.get("data", {}).get("column_name", "Unknown"),
88
  })
89
 
90
  logger.debug(f"Extracted sources: {sources}")
@@ -229,17 +229,8 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
229
 
230
  source_hint = intent_result.get("source_hint", "both")
231
  if source_hint in ("schema", "both"):
232
- retrieval_objects = [
233
- RetrievalResult(
234
- content=r["content"],
235
- metadata=r["metadata"],
236
- score=0.0,
237
- source_type=r["metadata"].get("source_type", ""),
238
- )
239
- for r in raw_results
240
- ]
241
  query_results = await query_executor.execute(
242
- results=retrieval_objects,
243
  user_id=request.user_id,
244
  db=db,
245
  question=intent_result.get("search_query") or request.message,
 
48
  message: str
49
 
50
 
51
+ def _format_context(results: List[RetrievalResult]) -> str:
52
  """Format retrieval results as context string for the LLM."""
53
  lines = []
54
  for result in results:
55
+ filename = result.metadata.get("filename", "Unknown")
56
+ page = result.metadata.get("page_label")
57
  source_label = f"{filename}, p.{page}" if page else filename
58
+ lines.append(f"[Source: {source_label}]\n{result.content}\n")
59
  return "\n".join(lines)
60
 
61
 
62
+ def _extract_sources(results: List[RetrievalResult]) -> List[Dict[str, Any]]:
63
  """Extract deduplicated source references from retrieval results."""
64
  seen = set()
65
  sources = []
66
  for result in results:
67
+ meta = result.metadata
68
+ data = meta.get("data", {})
69
+ if "document_id" in data:
70
+ key = (data.get("document_id"), data.get("page_label"))
71
  if key not in seen:
72
  seen.add(key)
73
  sources.append({
74
+ "document_id": data.get("document_id"),
75
+ "filename": data.get("filename", "Unknown"),
76
+ "page_label": data.get("page_label", "Unknown"),
77
  })
78
  else:
79
+ key = (data.get("table_name"), data.get("column_name"))
 
80
  if key not in seen:
81
  seen.add(key)
82
+ table_name = data.get("table_name")
83
  user_id = meta.get("user_id")
84
  sources.append({
85
  "document_id": f"{user_id}_{table_name}",
86
+ "filename": data.get("table_name", "Unknown"),
87
+ "page_label": data.get("column_name", "Unknown"),
88
  })
89
 
90
  logger.debug(f"Extracted sources: {sources}")
 
229
 
230
  source_hint = intent_result.get("source_hint", "both")
231
  if source_hint in ("schema", "both"):
 
 
 
 
 
 
 
 
 
232
  query_results = await query_executor.execute(
233
+ results=raw_results,
234
  user_id=request.user_id,
235
  db=db,
236
  question=intent_result.get("search_query") or request.message,
src/rag/retriever.py CHANGED
@@ -1,10 +1,9 @@
1
  """Public retrieval API — thin wrapper around RetrievalRouter."""
2
 
3
- from typing import Any
4
-
5
  from sqlalchemy.ext.asyncio import AsyncSession
6
 
7
  from src.middlewares.logging import get_logger
 
8
  from src.rag.retrievers.document import document_retriever
9
  from src.rag.retrievers.schema import schema_retriever
10
  from src.rag.router import RetrievalRouter, SourceHint
@@ -16,7 +15,11 @@ class RetrieverService:
16
  """Public retrieval service used by chat.py and search tools.
17
 
18
  Delegates to RetrievalRouter which dispatches based on source_hint.
19
- Returns List[Dict] to preserve backward compatibility with chat.py.
 
 
 
 
20
  """
21
 
22
  def __init__(self):
@@ -32,10 +35,9 @@ class RetrieverService:
32
  db: AsyncSession,
33
  k: int = 5,
34
  source_hint: SourceHint = "both",
35
- ) -> list[dict[str, Any]]:
36
  try:
37
- results = await self._router.retrieve(query, user_id, source_hint, k)
38
- return [{"content": r.content, "metadata": r.metadata} for r in results]
39
  except Exception as e:
40
  logger.error("retrieval failed", error=str(e))
41
  return []
 
1
  """Public retrieval API — thin wrapper around RetrievalRouter."""
2
 
 
 
3
  from sqlalchemy.ext.asyncio import AsyncSession
4
 
5
  from src.middlewares.logging import get_logger
6
+ from src.rag.base import RetrievalResult
7
  from src.rag.retrievers.document import document_retriever
8
  from src.rag.retrievers.schema import schema_retriever
9
  from src.rag.router import RetrievalRouter, SourceHint
 
15
  """Public retrieval service used by chat.py and search tools.
16
 
17
  Delegates to RetrievalRouter which dispatches based on source_hint.
18
+ Returns RetrievalResult objects directly so downstream consumers
19
+ (db_executor, tabular_executor) can be fed without lossy dict
20
+ conversion. The `db` parameter is accepted for call-site compatibility
21
+ but currently unused — retrieval reads PGVector via _pgvector_engine
22
+ inside each retriever.
23
  """
24
 
25
  def __init__(self):
 
35
  db: AsyncSession,
36
  k: int = 5,
37
  source_hint: SourceHint = "both",
38
+ ) -> list[RetrievalResult]:
39
  try:
40
+ return await self._router.retrieve(query, user_id, source_hint, k)
 
41
  except Exception as e:
42
  logger.error("retrieval failed", error=str(e))
43
  return []
src/rag/retrievers/baseline.py CHANGED
@@ -13,8 +13,14 @@ logger = get_logger("retriever")
13
  _RETRIEVAL_CACHE_TTL = 3600 # 1 hour
14
 
15
 
16
- class RetrieverService:
17
- """Service for retrieving relevant documents."""
 
 
 
 
 
 
18
 
19
  def __init__(self):
20
  self.vector_store = get_vector_store()
@@ -67,4 +73,4 @@ class RetrieverService:
67
  return []
68
 
69
 
70
- retriever = RetrieverService()
 
13
  _RETRIEVAL_CACHE_TTL = 3600 # 1 hour
14
 
15
 
16
+ class BaselineRetrieverService:
17
+ """Baseline (pre-Phase-1) retriever preserved for benchmark comparison.
18
+
19
+ Renamed from RetrieverService so it doesn't shadow the production wrapper
20
+ at src/rag/retriever.py. Production code imports from src.rag.retriever;
21
+ benchmark scripts that want this baseline must import explicitly from
22
+ src.rag.retrievers.baseline.
23
+ """
24
 
25
  def __init__(self):
26
  self.vector_store = get_vector_store()
 
73
  return []
74
 
75
 
76
+ baseline_retriever = BaselineRetrieverService()
src/tools/search.py CHANGED
@@ -34,10 +34,10 @@ async def search_documents(
34
 
35
  formatted_results = []
36
  for result in results:
37
- filename = result["metadata"].get("filename", "Unknown")
38
- page = result["metadata"].get("page_label")
39
  source_label = f"{filename}, p.{page}" if page else filename
40
- formatted_results.append(f"[Source: {source_label}]\n{result['content']}\n")
41
 
42
  return "\n".join(formatted_results)
43
 
 
34
 
35
  formatted_results = []
36
  for result in results:
37
+ filename = result.metadata.get("filename", "Unknown")
38
+ page = result.metadata.get("page_label")
39
  source_label = f"{filename}, p.{page}" if page else filename
40
+ formatted_results.append(f"[Source: {source_label}]\n{result.content}\n")
41
 
42
  return "\n".join(formatted_results)
43