import logging
from typing import Optional

import chromadb
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches

from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    CHROMA_DATA_PATH,
    CHROMA_HTTP_HOST,
    CHROMA_HTTP_PORT,
    CHROMA_HTTP_HEADERS,
    CHROMA_HTTP_SSL,
    CHROMA_TENANT,
    CHROMA_DATABASE,
)
| |
|
| |
|
class ChromaClient:
    """Vector-store client backed by Chroma (remote HTTP server or local on-disk store)."""

    # Logger for errors that are handled here (reported, not re-raised).
    _log = logging.getLogger(__name__)

    def __init__(self):
        """Create the underlying Chroma client from the CHROMA_* config values.

        An HttpClient is used when CHROMA_HTTP_HOST is non-empty; otherwise a
        PersistentClient stores data under CHROMA_DATA_PATH. Both are bound to
        CHROMA_TENANT / CHROMA_DATABASE.
        """
        # allow_reset=True is what permits reset() below; telemetry is disabled.
        settings = Settings(allow_reset=True, anonymized_telemetry=False)
        if CHROMA_HTTP_HOST != "":
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=settings,
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=settings,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    def has_collection(self, collection_name: str) -> bool:
        """Return True if a collection named ``collection_name`` exists."""
        collections = self.client.list_collections()
        # Older chromadb releases return Collection objects here, newer ones
        # return plain name strings; accept either form.
        names = {getattr(collection, "name", collection) for collection in collections}
        return collection_name in names

    def delete_collection(self, collection_name: str):
        """Drop the whole collection ``collection_name`` and return Chroma's result."""
        return self.client.delete_collection(name=collection_name)

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Return the ``limit`` nearest neighbours for each query vector.

        Args:
            collection_name: Collection to search.
            vectors: One query embedding per entry.
            limit: Results per query (passed to Chroma as ``n_results``).

        Returns:
            A SearchResult holding Chroma's per-query ``ids``, ``distances``,
            ``documents`` and ``metadatas`` lists, or None if the collection
            could not be read or the query failed.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None
            result = collection.query(query_embeddings=vectors, n_results=limit)
            return SearchResult(
                ids=result["ids"],
                distances=result["distances"],
                documents=result["documents"],
                metadatas=result["metadatas"],
            )
        except Exception as e:
            # Best-effort: get_collection raises for a missing collection, which
            # callers treat as "no results". Logged at debug level so failures
            # are diagnosable without flooding the log on expected misses.
            self._log.debug(
                "Chroma search on collection %r failed: %s", collection_name, e
            )
            return None

    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch records whose metadata matches the Chroma ``where`` filter.

        Args:
            collection_name: Collection to read.
            filter: Chroma ``where`` clause applied to metadata.
            limit: Maximum number of records, or None for no limit.

        Returns:
            A GetResult whose fields are each wrapped in a one-element outer
            list (the same list-of-lists layout ``search`` returns), or None
            if the collection could not be read or the lookup failed.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None
            result = collection.get(where=filter, limit=limit)
            return GetResult(
                ids=[result["ids"]],
                documents=[result["documents"]],
                metadatas=[result["metadatas"]],
            )
        except Exception as e:
            self._log.warning(
                "Chroma query on collection %r failed: %s", collection_name, e
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every record in ``collection_name``.

        Fields are wrapped in a one-element outer list, as in ``query``. Unlike
        ``search``/``query``, errors from Chroma (e.g. a missing collection)
        propagate to the caller.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return None
        result = collection.get()
        return GetResult(
            ids=[result["ids"]],
            documents=[result["documents"]],
            metadatas=[result["metadatas"]],
        )

    @staticmethod
    def _unpack_items(items: list[VectorItem]):
        """Split items into parallel (ids, documents, embeddings, metadatas) lists."""
        ids = [item["id"] for item in items]
        documents = [item["text"] for item in items]
        embeddings = [item["vector"] for item in items]
        metadatas = [item["metadata"] for item in items]
        return ids, documents, embeddings, metadatas

    def _write_in_batches(self, write, items: list[VectorItem]):
        """Send ``items`` through ``write`` (e.g. ``collection.add``) in server-sized batches."""
        ids, documents, embeddings, metadatas = self._unpack_items(items)
        # create_batches splits the payload so no single call exceeds the
        # client's maximum batch size; each batch is a positional tuple in
        # the (ids, embeddings, metadatas, documents) order add/upsert accept.
        for batch in create_batches(
            api=self.client,
            documents=documents,
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
        ):
            write(*batch)

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add ``items`` to ``collection_name``, creating the collection if needed."""
        collection = self.client.get_or_create_collection(name=collection_name)
        self._write_in_batches(collection.add, items)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update ``items`` by id, creating the collection if needed.

        Batched like ``insert`` so large payloads do not exceed the client's
        maximum batch size.
        """
        collection = self.client.get_or_create_collection(name=collection_name)
        self._write_in_batches(collection.upsert, items)

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete records from ``collection_name`` by id, or else by ``where`` filter.

        ``ids`` takes precedence when both are given; when neither is given
        nothing is deleted.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return
        if ids:
            collection.delete(ids=ids)
        elif filter:
            collection.delete(where=filter)

    def reset(self):
        """Wipe the whole Chroma store (allowed because allow_reset=True)."""
        return self.client.reset()
| |
|