| |
|
|
| from qdrant_client import QdrantClient |
| from qdrant_client.http import models |
| import os |
| import logging |
| import uuid |
|
|
| |
# Configure the root logger for the whole process: INFO and above, each record
# prefixed with a timestamp and level. (Per the stdlib docs, basicConfig is a
# no-op if the root logger already has handlers.)
# NOTE(review): configuring logging at import time affects every module in the
# process; consider moving this to the application entry point.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


# Module-wide Qdrant client used by the functions in this module. Connection
# details are read once, at import time, from the QDRANT_URL and QDRANT_API_KEY
# environment variables; an unset variable is passed through as None.
# NOTE(review): what QdrantClient does with url=None (e.g. fall back to a
# default host) is decided by qdrant_client -- confirm that is acceptable.
qdrant_client = QdrantClient(
    api_key=os.environ.get('QDRANT_API_KEY'),
    url=os.environ.get('QDRANT_URL'),
)
|
|
def _collection_exists(collection_name):
    """Return True if the Qdrant server lists a collection named *collection_name*."""
    return any(
        collection.name == collection_name
        for collection in qdrant_client.get_collections().collections
    )


def create_collection_if_not_exists(collection_name, vector_size):
    """Create the Qdrant collection *collection_name* unless it already exists.

    A newly created collection gets a single unnamed dense vector of
    ``vector_size`` dimensions compared with cosine distance. An existing
    collection is left untouched -- its vector size is NOT compared with
    ``vector_size``.

    The existence check and the creation are two separate requests, so another
    caller may create the collection in between. If creation fails but the
    collection exists afterwards, this is treated as success rather than an
    error.

    Args:
        collection_name: Name of the collection to ensure.
        vector_size: Vector dimensionality, used only when the collection has
            to be created.

    Raises:
        Exception: Any error from the Qdrant client is logged (with traceback)
            and re-raised unchanged.
    """
    try:
        if _collection_exists(collection_name):
            logging.info(f"Collection {collection_name} already exists")
            return
        try:
            qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size, distance=models.Distance.COSINE
                ),
            )
        except Exception:
            # Creation failed. If the collection exists now (most likely created
            # concurrently by another caller after our check), the goal of this
            # function is met; otherwise propagate the original error.
            if not _collection_exists(collection_name):
                raise
            logging.info(f"Collection {collection_name} already exists")
            return
        logging.info(f"Created new collection: {collection_name}")
    except Exception as e:
        # logging.exception logs at ERROR level and includes the traceback.
        logging.exception(f"Error creating collection: {str(e)}")
        raise
|
|
def _as_vector(embedding):
    """Return *embedding* as a plain Python list for the Qdrant client.

    Objects exposing a callable ``tolist()`` (e.g. numpy arrays) are converted
    with it; any other iterable is converted with ``list()``.
    """
    to_list = getattr(embedding, "tolist", None)
    return to_list() if callable(to_list) else list(embedding)


def store_embeddings(chunks, embeddings, user_id, data_source_id, file_id,
                     organization_id, s3_bucket_key, total_tokens, *,
                     collection_name="embed", batch_size=None):
    """Upsert one Qdrant point per (chunk, embedding) pair.

    The target collection is created first if missing, sized from the length
    of the first embedding. Each point gets a fresh random UUID as its id, the
    embedding as its vector, and a payload holding the given identifiers, the
    chunk's position in *chunks* (``chunk_index``), the chunk itself
    (``chunk_text``), ``s3_bucket_key`` and ``total_tokens`` (the same value
    on every point).

    Args:
        chunks: Iterable of chunk texts, aligned one-to-one with *embeddings*.
        embeddings: Indexable sequence of vectors (numpy arrays or plain
            sequences of numbers), all of the same length.
        user_id, data_source_id, file_id, organization_id, s3_bucket_key,
        total_tokens: Stored verbatim in every point's payload.
        collection_name: Target collection (keyword-only; default ``"embed"``).
        batch_size: Maximum points per upsert request (keyword-only). ``None``
            sends all points in a single request, as before.

    Raises:
        ValueError: If *chunks* and *embeddings* differ in length, or
            *batch_size* is less than 1.
        Exception: Any error from the Qdrant client. All errors are logged
            (with traceback) before being re-raised.
    """
    try:
        if batch_size is not None and batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        chunks = list(chunks)
        if len(chunks) != len(embeddings):
            # zip() would silently drop the unmatched tail; refuse instead.
            raise ValueError(
                "chunks and embeddings differ in length: "
                f"{len(chunks)} != {len(embeddings)}"
            )
        if len(embeddings) == 0:
            # Nothing to store, and no first embedding to size a collection from.
            logging.info("No embeddings to store")
            return

        create_collection_if_not_exists(collection_name, len(embeddings[0]))

        # NOTE: ids are random UUIDs, so storing the same file again adds new
        # points next to the existing ones instead of replacing them.
        points = [
            models.PointStruct(
                id=str(uuid.uuid4()),
                vector=_as_vector(embedding),
                payload={
                    "user_id": user_id,
                    "data_source_id": data_source_id,
                    "file_id": file_id,
                    "organization_id": organization_id,
                    "chunk_index": i,
                    "chunk_text": chunk,
                    "s3_bucket_key": s3_bucket_key,
                    "total_tokens": total_tokens,
                },
            )
            for i, (chunk, embedding) in enumerate(zip(chunks, embeddings))
        ]

        step = batch_size or len(points)
        for start in range(0, len(points), step):
            qdrant_client.upsert(
                collection_name=collection_name,
                points=points[start:start + step],
            )
        logging.info(f"Successfully stored {len(points)} embeddings")
    except Exception as e:
        logging.exception(f"Error storing embeddings in Qdrant: {str(e)}")
        raise