"""Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular)."""
import math
from langchain_postgres import PGVector
from langchain_postgres.vectorstores import DistanceStrategy
from langchain_openai import AzureOpenAIEmbeddings
from sqlalchemy import text
from src.config.settings import settings
from src.db.postgres.connection import _pgvector_engine
from src.db.postgres.vector_store import get_vector_store
from src.middlewares.logging import get_logger
from src.rag.base import BaseRetriever, RetrievalResult
logger = get_logger("document_retriever")

# Change this one line to switch retrieval method
# Options: "mmr" | "cosine" | "euclidean" | "inner_product" | "manhattan"
_RETRIEVAL_METHOD = "mmr"

# File types treated as tabular; chunks carrying one of these are skipped here.
_TABULAR_TYPES = {"csv", "xlsx"}

# MMR tuning: candidate pool size and the lambda_mult value passed to the store.
_FETCH_K = 20
_LAMBDA_MULT = 0.5

# Name of the PGVector collection queried by every store in this module.
_COLLECTION_NAME = "document_embeddings"

# Azure OpenAI embedding client, configured entirely from project settings.
_embeddings = AzureOpenAIEmbeddings(
    azure_deployment=settings.azureai_deployment_name_embedding,
    openai_api_version=settings.azureai_api_version_embedding,
    azure_endpoint=settings.azureai_endpoint_url_embedding,
    api_key=settings.azureai_api_key_embedding,
)


def _build_store(strategy: DistanceStrategy) -> PGVector:
    """Return an async PGVector store on the shared collection using *strategy*."""
    return PGVector(
        embeddings=_embeddings,
        connection=_pgvector_engine,
        collection_name=_COLLECTION_NAME,
        distance_strategy=strategy,
        use_jsonb=True,
        async_mode=True,
        create_extension=False,
    )


# Alternate stores that differ only in the distance strategy they rank by.
_euclidean_store = _build_store(DistanceStrategy.EUCLIDEAN)
_ip_store = _build_store(DistanceStrategy.MAX_INNER_PRODUCT)

# Raw query used for the "manhattan" method: orders one user's
# source_type='document' rows of a collection by the `<+>` operator result.
_MANHATTAN_SQL = text("""
SELECT
lpe.document,
lpe.cmetadata,
lpe.embedding <+> CAST(:embedding AS vector) AS distance
FROM langchain_pg_embedding lpe
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
WHERE lpc.name = :collection
AND lpe.cmetadata->>'user_id' = :user_id
AND lpe.cmetadata->>'source_type' = 'document'
ORDER BY distance ASC
LIMIT :k
""")
class DocumentRetriever(BaseRetriever):
    """Retrieve a user's non-tabular document chunks from the vector store.

    The search strategy is chosen by the module-level ``_RETRIEVAL_METHOD``
    constant. Every strategy restricts hits to the given ``user_id`` and to
    ``source_type == "document"``, then drops chunks whose
    ``metadata["data"]["file_type"]`` is one of ``_TABULAR_TYPES``.
    """

    def __init__(self) -> None:
        """Bind the default vector store returned by ``get_vector_store()``."""
        self.vector_store = get_vector_store()

    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Return up to ``k`` non-tabular chunks matching ``query``.

        Args:
            query: Free-text query to embed and search with.
            user_id: Only chunks whose metadata ``user_id`` equals this value
                are considered.
            k: Maximum number of results. Values ``<= 0`` yield an empty list.

        Returns:
            ``RetrievalResult`` objects in the order the store returned them.
            ``score`` is the raw value reported by the underlying search; in
            ``"mmr"`` mode it is looked up from a similarity search over the
            same candidate pool and falls back to ``0.0`` when no entry with
            the same page content is found there.
        """
        if k <= 0:
            # Nothing requested: skip the embedding and database round trips.
            return []

        filter_ = {"user_id": user_id, "source_type": "document"}
        # Over-fetch slightly so dropping tabular chunks can still leave k hits.
        fetch_k = k + len(_TABULAR_TYPES)

        if _RETRIEVAL_METHOD == "manhattan":
            return await self._retrieve_manhattan(query, user_id, k, fetch_k)

        if _RETRIEVAL_METHOD == "mmr":
            # The MMR candidate pool must be at least as large as the number
            # of documents requested from it, otherwise large k is truncated.
            pool_size = max(_FETCH_K, fetch_k)
            docs = await self.vector_store.amax_marginal_relevance_search(
                query=query,
                k=fetch_k,
                fetch_k=pool_size,
                lambda_mult=_LAMBDA_MULT,
                filter=filter_,
            )
            # Score the same pool MMR selected from: MMR may pick documents
            # that are not among the top fetch_k by plain similarity, and
            # those would otherwise all be reported with the 0.0 fallback.
            scored_pool = await self.vector_store.asimilarity_search_with_score(
                query=query, k=pool_size, filter=filter_,
            )
            score_map = {doc.page_content: score for doc, score in scored_pool}
            docs_with_scores = [
                (doc, score_map.get(doc.page_content, 0.0)) for doc in docs
            ]
        else:
            if _RETRIEVAL_METHOD == "euclidean":
                store = _euclidean_store
            elif _RETRIEVAL_METHOD == "inner_product":
                store = _ip_store
            else:  # "cosine", and any unrecognized value
                store = self.vector_store
            docs_with_scores = await store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )

        results = self._collect_non_tabular(
            ((doc.page_content, doc.metadata, score) for doc, score in docs_with_scores),
            k,
        )
        logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
        return results

    async def _retrieve_manhattan(
        self, query: str, user_id: str, k: int, fetch_k: int
    ) -> list[RetrievalResult]:
        """Return up to ``k`` non-tabular chunks ranked by ``_MANHATTAN_SQL``.

        Embeds ``query`` with the module-level embedding client and executes
        the raw SQL statement directly on the pgvector engine.

        Args:
            query: Free-text query to embed.
            user_id: Bound to the statement's ``:user_id`` parameter.
            k: Maximum number of results to return.
            fetch_k: Row limit passed to the statement before filtering.

        Raises:
            ValueError: If the embedding contains NaN or infinite components.
        """
        query_vector = await _embeddings.aembed_query(query)
        if not all(math.isfinite(v) for v in query_vector):
            raise ValueError("Embedding vector contains NaN or Infinity values.")

        # Serialized as a bracketed list so the statement can CAST it to vector.
        vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(_MANHATTAN_SQL, {
                "embedding": vector_str,
                "collection": _COLLECTION_NAME,
                "user_id": user_id,
                "k": fetch_k,
            })
            rows = result.fetchall()

        results = self._collect_non_tabular(
            ((row.document, row.cmetadata, float(row.distance)) for row in rows),
            k,
        )
        logger.info("retrieved chunks", method="manhattan", count=len(results))
        return results

    @staticmethod
    def _collect_non_tabular(candidates, limit: int) -> list[RetrievalResult]:
        """Build results from ``(content, metadata, score)`` triples.

        Skips candidates whose ``metadata["data"]["file_type"]`` is in
        ``_TABULAR_TYPES`` and stops once ``limit`` results are collected.
        A missing, ``None`` or non-dict ``"data"`` entry counts as non-tabular.
        """
        results: list[RetrievalResult] = []
        if limit <= 0:
            return results
        for content, metadata, score in candidates:
            data = metadata.get("data")
            file_type = data.get("file_type", "") if isinstance(data, dict) else ""
            if file_type in _TABULAR_TYPES:
                continue
            results.append(RetrievalResult(
                content=content,
                metadata=metadata,
                score=score,
                source_type="document",
            ))
            if len(results) >= limit:
                break
        return results
# Module-level retriever instance, constructed when this module is imported.
document_retriever = DocumentRetriever()
|