import os
import uuid
from copy import deepcopy
from typing import Dict, List, Any

import hydra
from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

from aiflows.base_flows import AtomicFlow
from aiflows.messages import FlowMessage


class ChromaDBFlow(AtomicFlow):
    """ A flow that uses a Chroma vector store to write and read memories stored in a database.

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "chroma_db"
    - `description` (str): A description of the flow, used to generate its help message.
        Default: "ChromaDB is a document store that uses vector embeddings to store and retrieve documents."
    - `backend` (Dict[str, Any]): The configuration of the backend, which is used to fetch API keys.
        Default: LiteLLMBackend with its default parameters (see aiflows.backends.LiteLLMBackend),
        except for the following parameters, whose default values are overwritten:
            - `api_infos` (List[Dict[str, Any]]): The list of API infos. No default value; this parameter is required.
            - `model_name` (str): The name of the model. Default: "". Not used in the current implementation.
    - `persist_directory` (str): The directory in which the Chroma database is persisted.
    - `paths_to_data` (List[str]): The paths of the text files indexed when the database is first created.
    - `chunk_size` (int): The chunk size used by the text splitter when indexing the data files.
    - `chunk_overlap` (int): The chunk overlap used by the text splitter when indexing the data files.
    - `similarity_search_kwargs` (Dict[str, Any]): Keyword arguments passed to the similarity search when reading
        (e.g. `k`, the number of results to retrieve).
    - Other parameters are inherited from the default configuration of AtomicFlow (see AtomicFlow).

    *Input Interface*:

    - `operation` (str): The operation to perform. It can be "write" or "read".
    - `content` (str or List[str]): The content to write or read. If the operation is "write", it must be a string
        or a list of strings. If the operation is "read", it must be a string.

    *Output Interface*:

    - `retrieved` (str or List[str]): The retrieved content. If the operation is "write", it is an empty string.
        If the operation is "read", it is the list of retrieved documents.

    :param backend: The backend of the flow (used to retrieve the API key)
    :type backend: LiteLLMBackend
    :param \**kwargs: Additional arguments to pass to the flow.
    """
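    # Illustrative configuration sketch (not part of the original source): only the key
    # names below are taken from the code in this class; the example values are
    # assumptions, and in practice they would come from the flow's default config.
    #
    #   persist_directory: ./chroma_db            # where Chroma persists its index
    #   paths_to_data: ["data/notes.txt"]         # text files indexed on first use
    #   chunk_size: 1000                          # CharacterTextSplitter settings
    #   chunk_overlap: 0
    #   similarity_search_kwargs: {"k": 5}        # forwarded to the similarity search
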
    def __init__(self, backend, **kwargs):
        super().__init__(**kwargs)

        self.backend = backend

    def set_up_flow_state(self):
        super().set_up_flow_state()
        self.flow_state["db_created"] = False

    @classmethod
    def _set_up_backend(cls, config):
        """ This instantiates the backend of the flow from a configuration file.

        :param config: The configuration of the backend.
        :type config: Dict[str, Any]
        :return: A dictionary containing the instantiated backend of the flow.
        :rtype: Dict[str, LiteLLMBackend]
        """
        kwargs = {}

        kwargs["backend"] = \
            hydra.utils.instantiate(config['backend'], _convert_="partial")

        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """ This method instantiates the flow from a configuration file.

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: ChromaDBFlow
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        kwargs.update(cls._set_up_backend(flow_config))

        return cls(**kwargs)

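    # For illustration (assumed, not taken from the original source): hydra builds the
    # backend from a config entry of roughly the following shape, where `_target_`
    # points at the backend class; the exact target path and field values below are
    # assumptions.
    #
    #   backend:
    #     _target_: aiflows.backends.llm_lite.LiteLLMBackend
    #     api_infos: ???
    #     model_name: ""
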
    def get_embeddings_model(self):
        """ Returns the OpenAI embeddings model, using the API key retrieved from the backend
        when available and falling back to the OPENAI_API_KEY environment variable otherwise.
        """
        api_information = self.backend.get_key()
        if api_information.backend_used == "openai":
            embeddings = OpenAIEmbeddings(openai_api_key=api_information.api_key)
        else:
            # Fall back to the environment variable when the backend key is not an OpenAI key.
            embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))
        return embeddings

    def get_db(self):
        """ Returns the Chroma database. On first use, it builds the database from the
        configured data files; afterwards, it reloads it from the persist directory.
        """
        db_created = self.flow_state["db_created"]

        if hasattr(self, 'db'):
            # The database has already been loaded by this instance.
            db = self.db

        elif db_created:
            # The database already exists on disk; reload it.
            db = Chroma(
                persist_directory=self.flow_config["persist_directory"],
                embedding_function=self.get_embeddings_model()
            )
        else:
            # First use: load the configured data files, split them into chunks,
            # and build the database from the resulting documents.
            full_docs = []
            text_splitter = CharacterTextSplitter(
                chunk_size=self.flow_config["chunk_size"],
                chunk_overlap=self.flow_config["chunk_overlap"]
            )

            for path in self.flow_config["paths_to_data"]:
                loader = TextLoader(path)
                documents = loader.load()
                docs = text_splitter.split_documents(documents)
                full_docs.extend(docs)

            db = Chroma.from_documents(
                full_docs,
                self.get_embeddings_model(),
                persist_directory=self.flow_config["persist_directory"]
            )

        self.flow_state["db_created"] = True
        return db

    def run(self, input_message: FlowMessage):
        """ This method runs the flow: it either writes memories to or reads memories from the database.

        :param input_message: The input message of the flow.
        :type input_message: FlowMessage
        """
        self.db = self.get_db()

        input_data = input_message.data

        embeddings = self.get_embeddings_model()

        response = {}

        operation = input_data["operation"]
        if operation not in ["write", "read"]:
            raise ValueError(f"Operation '{operation}' not supported")

        content = input_data["content"]

        if operation == "read":
            if not isinstance(content, str):
                raise ValueError(f"content (query) must be a string during read, got {type(content)}: {content}")
            if content == "":
                response["retrieved"] = [[""]]
            else:
                query = content
                query_result = self.db.similarity_search(query, **self.flow_config["similarity_search_kwargs"])

                response["retrieved"] = [doc.page_content for doc in query_result]

        elif operation == "write":
            if content != "":
                if not isinstance(content, list):
                    content = [content]
                documents = content
                # Embed the documents and add them to the underlying Chroma collection with fresh ids.
                self.db._collection.add(
                    ids=[str(uuid.uuid4()) for _ in range(len(documents))],
                    embeddings=embeddings.embed_documents(documents),
                    documents=documents
                )
            response["retrieved"] = ""

        reply = self.package_output_message(
            input_message=input_message,
            response=response
        )
        self.send_message(reply)
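
# Illustrative input/output payloads (not part of the original source; the orchestration
# code needed to build and send the FlowMessage around this flow is omitted here):
#
#   {"operation": "write", "content": "Alice's favourite fruit is apples."}
#       -> {"retrieved": ""}
#
#   {"operation": "read", "content": "What is Alice's favourite fruit?"}
#       -> {"retrieved": ["Alice's favourite fruit is apples.", ...]}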