| | from enum import Enum |
| | from typing import Any, Dict, List, Optional |
| |
|
| | from langchain_core.callbacks import ( |
| | AsyncCallbackManagerForRetrieverRun, |
| | CallbackManagerForRetrieverRun, |
| | ) |
| | from langchain_core.documents import Document |
| | from langchain_core.retrievers import BaseRetriever |
| | from langchain_core.stores import BaseStore, ByteStore |
| | from langchain_core.vectorstores import VectorStore |
| | from pydantic import Field, model_validator |
| |
|
| | from langchain.storage._lc_store import create_kv_docstore |
| |
|
| |
|
class SearchType(str, Enum):
    """Kinds of search that can be performed."""

    similarity = "similarity"
    """Plain similarity search."""

    similarity_score_threshold = "similarity_score_threshold"
    """Similarity search restricted by a score threshold."""

    mmr = "mmr"
    """Similarity search reranked with Maximal Marginal Relevance."""
| |
|
| |
|
class MultiVectorRetriever(BaseRetriever):
    """Retrieve from a set of multiple embeddings for the same document.

    The query is run against ``vectorstore``; the value stored under
    ``id_key`` in each hit's metadata is collected (duplicates removed,
    first-seen order kept) and those IDs are looked up in ``docstore``.
    Lookups that come back as ``None`` are dropped from the result.
    """

    vectorstore: VectorStore
    """The underlying vectorstore to use to store small chunks
    and their embedding vectors"""
    byte_store: Optional[ByteStore] = None
    """The lower-level backing storage layer for the parent documents"""
    docstore: BaseStore[str, Document]
    """The storage interface for the parent documents"""
    id_key: str = "doc_id"
    """Metadata key on each search hit that holds the ID used for the
    docstore lookup."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""
    search_type: SearchType = SearchType.similarity
    """Type of search to perform (similarity / similarity_score_threshold / mmr)"""

    @model_validator(mode="before")
    @classmethod
    def shim_docstore(cls, values: Dict) -> Any:
        """Fill in ``docstore`` from ``byte_store`` before field validation.

        When ``byte_store`` is given it is wrapped with ``create_kv_docstore``
        and the result replaces any ``docstore`` that was also passed.

        Args:
            values: Raw constructor inputs; updated in place.

        Returns:
            The same mapping with ``docstore`` set.

        Raises:
            ValueError: If neither ``byte_store`` nor ``docstore`` is provided.
        """
        byte_store = values.get("byte_store")
        docstore = values.get("docstore")
        if byte_store is not None:
            docstore = create_kv_docstore(byte_store)
        elif docstore is None:
            # ValueError (not a bare Exception) so pydantic reports the
            # problem as a validation error rather than an opaque failure.
            raise ValueError("You must pass a `byte_store` parameter.")
        values["docstore"] = docstore
        return values

    def _collect_parent_ids(self, sub_docs: List[Document]) -> List[Any]:
        """Return the distinct ``id_key`` values of ``sub_docs`` in first-seen order.

        Hits whose metadata lacks ``id_key`` are skipped.
        """
        # A dict preserves insertion order and makes the duplicate check O(1),
        # replacing the repeated linear ``not in list`` scan.
        seen: Dict[Any, None] = {}
        for sub_doc in sub_docs:
            if self.id_key in sub_doc.metadata:
                seen.setdefault(sub_doc.metadata[self.id_key], None)
        return list(seen)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        if self.search_type == SearchType.mmr:
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            # Only the documents are needed; the scores are discarded.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)

        ids = self._collect_parent_ids(sub_docs)
        docs = self.docstore.mget(ids)
        # Skip lookups that returned nothing.
        return [d for d in docs if d is not None]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        if self.search_type == SearchType.mmr:
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            # Only the documents are needed; the scores are discarded.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )

        ids = self._collect_parent_ids(sub_docs)
        docs = await self.docstore.amget(ids)
        # Skip lookups that returned nothing.
        return [d for d in docs if d is not None]
| |
|