| | import os |
| | from typing import Dict, List, Any |
| |
|
| | import uuid |
| | from copy import deepcopy |
| | from langchain.embeddings import OpenAIEmbeddings |
| |
|
| | from chromadb import Client as ChromaClient |
| |
|
| | from flows.base_flows import AtomicFlow |
| |
|
| | import hydra |
| |
|
class ChromaDBFlow(AtomicFlow):
    """Atomic flow that stores documents in, and queries, a ChromaDB collection.

    Texts are embedded with ``OpenAIEmbeddings`` before being added to the
    collection, and read queries are embedded the same way before the
    similarity lookup. The collection name comes from ``flow_config["name"]``
    and the number of query results from ``flow_config["n_results"]``.
    """

    def __init__(self, backend, **kwargs):
        """Initialise the flow and open (or create) its ChromaDB collection.

        Args:
            backend: Object exposing ``get_key()``; it is asked for an API key
                whenever an embedding client is needed.
            **kwargs: Forwarded to ``AtomicFlow.__init__``. Must include a
                ``flow_config`` containing ``"name"`` (read here via
                ``self.flow_config``, presumably populated by the base class).
        """
        super().__init__(**kwargs)
        # NOTE(review): default ChromaClient() settings — confirm whether
        # persistence across processes is expected.
        self.client = ChromaClient()
        self.collection = self.client.get_or_create_collection(
            name=self.flow_config["name"]
        )
        self.backend = backend

    @classmethod
    def _set_up_backend(cls, config):
        """Instantiate the backend described by ``config["backend"]`` with Hydra.

        Returns:
            dict: ``{"backend": <instantiated backend>}``, ready to be merged
            into the constructor keyword arguments.
        """
        return {
            "backend": hydra.utils.instantiate(config["backend"], _convert_="partial")
        }

    @classmethod
    def instantiate_from_config(cls, config):
        """Build a ``ChromaDBFlow`` from a configuration mapping.

        The config is deep-copied so the flow never shares (or mutates) the
        caller's object.
        """
        flow_config = deepcopy(config)
        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_backend(flow_config))
        return cls(**kwargs)

    def get_input_keys(self) -> List[str]:
        """Return the input keys declared in ``flow_config["input_keys"]``."""
        return self.flow_config["input_keys"]

    def get_output_keys(self) -> List[str]:
        """Return the output keys declared in ``flow_config["output_keys"]``."""
        return self.flow_config["output_keys"]

    def _make_embeddings(self):
        """Create an ``OpenAIEmbeddings`` client keyed from ``self.backend``.

        When the backend reports a provider other than ``"openai"``, the
        ``OPENAI_API_KEY`` environment variable is used instead.
        """
        api_information = self.backend.get_key()
        if api_information.backend_used == "openai":
            api_key = api_information.api_key
        else:
            # NOTE(review): non-OpenAI backends are not used for embeddings;
            # the env var is used instead — confirm this is intended.
            api_key = os.getenv("OPENAI_API_KEY")
        return OpenAIEmbeddings(openai_api_key=api_key)

    def _query_collection(self, content):
        """Return the ``"documents"`` of a similarity query for ``content``.

        An empty query short-circuits to ``[[""]]`` without building an
        embedding client or touching the collection.

        Raises:
            ValueError: If ``content`` is not a string.
        """
        if not isinstance(content, str):
            raise ValueError(f"content(query) must be a string during read, got {type(content)}: {content}")
        if content == "":
            return [[""]]
        embeddings = self._make_embeddings()
        query_result = self.collection.query(
            query_embeddings=embeddings.embed_query(content),
            n_results=self.flow_config["n_results"],
        )
        return list(query_result["documents"])

    def _add_to_collection(self, content) -> None:
        """Embed ``content`` (a string or list of strings) and add it to the collection.

        ``""`` and ``[]`` are no-ops: there is nothing to embed, and ChromaDB
        rejects an ``add`` call with an empty id list.
        """
        if content == "":
            return
        documents = content if isinstance(content, list) else [content]
        if not documents:
            return
        embeddings = self._make_embeddings()
        self.collection.add(
            # Fresh random ids so repeated writes never collide.
            ids=[str(uuid.uuid4()) for _ in documents],
            embeddings=embeddings.embed_documents(documents),
            documents=documents,
        )

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Perform one ``"read"`` or ``"write"`` operation on the collection.

        Args:
            input_data: Must contain ``"operation"`` (``"read"`` or
                ``"write"``) and ``"content"``. For reads, ``content`` is the
                query string; for writes, a document string or a list of them.

        Returns:
            ``{"retrieved": ...}`` — for reads, the ``"documents"`` field of
            the ChromaDB query result (``[[""]]`` for an empty query); for
            writes, always ``""``.

        Raises:
            ValueError: If the operation is unsupported, or a read query is
                not a string.
        """
        operation = input_data["operation"]
        if operation not in ("write", "read"):
            raise ValueError(f"Operation '{operation}' not supported")

        content = input_data["content"]
        if operation == "read":
            return {"retrieved": self._query_collection(content)}

        self._add_to_collection(content)
        return {"retrieved": ""}
| |
|