| | from qdrant_client import QdrantClient |
| | from qdrant_client.http import models |
| | from tqdm import tqdm |
| | import os |
| | import time |
| | import numpy as np |
| | from loguru import logger |
| | import stamina |
| | from typing import Any, List, Tuple, Type, Literal, Optional, Union, Dict |
| |
|
class MyQdrantClient:
    """Wrapper around a path-backed ``QdrantClient``.

    Provides helpers to create/delete collections holding either ColBERT-style
    multi-vectors (compared with Qdrant's MaxSim comparator) or one dense
    vector per point, to upsert vectors under consecutive integer ids, and to
    run top-k queries that return point ids.

    Call :meth:`close` (or use the instance as a context manager) to release
    the underlying client deterministically; ``__del__`` only closes it on a
    best-effort basis.
    """

    def __init__(self, path: str):
        """Create a Qdrant client backed by local storage at ``path``."""
        self.qdrant_client = QdrantClient(path=path)
        self._closed = False
        logger.debug(f"Qdrant client created at {path}")

    def create_collection(self, collection_name: str, vector_dim: int = 128, vector_type: str = "colbert"):
        """Create a cosine-distance collection with on-disk vectors and payload.

        Args:
            collection_name: Name of the collection to create.
            vector_dim: Dimensionality of each vector (of each token vector
                for ``"colbert"``).
            vector_type: ``"colbert"`` for multi-vector points compared with
                MaxSim, or ``"dense"`` for a single vector per point.

        Raises:
            ValueError: If ``vector_type`` is neither ``"colbert"`` nor
                ``"dense"``. Validation happens before any call to Qdrant.
        """
        if vector_type == "colbert":
            # Each point stores several vectors; scoring uses MaxSim.
            multivector_config = models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            )
        elif vector_type == "dense":
            # None is VectorParams' default, i.e. a plain single-vector point.
            multivector_config = None
        else:
            raise ValueError(f"Vector type {vector_type} not supported")

        self.qdrant_client.create_collection(
            collection_name=collection_name,
            on_disk_payload=True,
            vectors_config=models.VectorParams(
                size=vector_dim,
                distance=models.Distance.COSINE,
                on_disk=True,
                multivector_config=multivector_config,
            ),
        )
        logger.debug(f"Qdrant collection of type {vector_type} : {collection_name} created")

    def delete_collection(self, collection_name: str):
        """Delete the collection named ``collection_name``."""
        self.qdrant_client.delete_collection(collection_name=collection_name)

    def upsert_to_qdrant(self, batch, collection_name: str) -> bool:
        """Upsert a batch of points, retrying failed attempts.

        Args:
            batch: Points to upsert (e.g. a list of ``models.PointStruct``).
            collection_name: Target collection.

        Returns:
            True if the upsert call succeeded, False if it still failed after
            all retry attempts (the last error is logged).
        """
        try:
            self._upsert_with_retry(batch, collection_name)
        except Exception as e:
            logger.error(f"Error during upsert: {e}")
            return False
        return True

    @stamina.retry(on=Exception, attempts=3)
    def _upsert_with_retry(self, batch, collection_name: str) -> None:
        """Perform one upsert call; letting exceptions escape is what allows
        ``stamina`` to retry (up to 3 attempts in total)."""
        self.qdrant_client.upsert(
            collection_name=collection_name,
            points=batch,
            # NOTE(review): wait=False presumably returns before the write is
            # fully applied — confirm against Qdrant's upsert docs if callers
            # query immediately afterwards.
            wait=False,
        )

    def upsert_multivector(self, index: int, multivector_input_list: list[Any], collection_name: str) -> bool:
        """Upsert vectors as points with consecutive integer ids.

        The j-th element of ``multivector_input_list`` is stored under id
        ``index + j`` with payload ``{"source": "user uploaded data"}``.

        Args:
            index: Id assigned to the first vector.
            multivector_input_list: Vectors (or multi-vectors) to store.
            collection_name: Target collection.

        Returns:
            True if the upsert succeeded, False if building the points or the
            upsert (after retries) failed. The error is logged either way.
        """
        try:
            points = [
                models.PointStruct(
                    id=index + offset,
                    vector=multivector,
                    payload={"source": "user uploaded data"},
                )
                for offset, multivector in enumerate(multivector_input_list)
            ]
            # Report the upsert outcome instead of silently discarding it.
            return self.upsert_to_qdrant(points, collection_name)
        except Exception as e:
            logger.error(f"Vector DB client - error during upsert: {e}")
            return False

    def query_multivector(self, multivector_input, collection_name: str, top_k: int = 10) -> Optional[list[int]]:
        """Return the ids of the ``top_k`` best-matching points.

        Args:
            multivector_input: Query vector(s), in the same form as stored.
            collection_name: Collection to search.
            top_k: Maximum number of results.

        Returns:
            Point ids in the order returned by Qdrant, or None if the query
            failed (the error is logged). Ids are ints for points written by
            ``upsert_multivector``; points inserted elsewhere may use other
            id types.
        """
        try:
            # perf_counter is monotonic, unlike time.time(), so the measured
            # duration can't be skewed by wall-clock adjustments.
            start_time = time.perf_counter()
            search_result = self.qdrant_client.query_points(
                collection_name=collection_name,
                query=multivector_input,
                limit=top_k,
            )
            elapsed_time = time.perf_counter() - start_time
            logger.debug(f"Search completed in {elapsed_time:.4f} seconds")
            return [point.id for point in search_result.points]
        except Exception as e:
            logger.error(f"Error during query: {e}")
            return None

    def close(self) -> None:
        """Close the underlying Qdrant client; calling it again is a no-op."""
        # getattr guards: __init__ may have raised before these were assigned.
        client = getattr(self, "qdrant_client", None)
        if client is None or getattr(self, "_closed", False):
            return
        self._closed = True
        client.close()

    def __enter__(self):
        """Allow ``with MyQdrantClient(path) as client: ...``."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the client when leaving the ``with`` block."""
        self.close()

    def __del__(self):
        # Best-effort cleanup only: finalizers run at unpredictable times
        # (including interpreter shutdown, when modules may be torn down),
        # and an exception here cannot be handled by any caller.
        try:
            self.close()
        except Exception:
            pass
| | |