Note_Retriever / vector_DB.py
ZacBl's picture
Optimisation for faster exec
79a5c64 verified
import faiss
import numpy as np
import streamlit as st
from typing import List, Dict, Optional, Union
from doc_preprocessing import process_files
class VectorDatabase:
    """
    A class to manage a vector database using FAISS for efficient similarity search.

    Embeddings are stored in a flat L2 FAISS index. ``chunks`` and
    ``chunks_metadata`` are parallel lists: position ``i`` in each list
    describes the vector stored at row ``i`` of the index, so all three must
    always grow together.
    """

    def __init__(self, dimension: int = 0):
        """
        Initializes the VectorDatabase.

        Args:
            dimension (int, optional): The dimension of the embeddings. If 0
                (or None), the dimension is inferred from the first batch
                passed to ``add_data``. Defaults to 0.
        """
        self.dimension = dimension
        # The FAISS index is created lazily on the first successful add_data call.
        self.index: Optional[faiss.Index] = None
        self.chunks: List[str] = []
        self.chunks_metadata: List[Dict] = []

    def add_data(self, embeddings: List[np.ndarray], chunks: List[str], chunks_metadata: List[Dict]) -> None:
        """
        Adds embeddings, text chunks, and metadata to the vector database.

        Data is appended: calling this method several times accumulates
        vectors in the index and keeps the chunk/metadata lists aligned with
        the index rows. Invalid input is reported through ``st.error`` and
        leaves the database unchanged.

        Args:
            embeddings (List[np.ndarray]): A list of embeddings (each a list or numpy array),
                all of the same dimension.
            chunks (List[str]): A list of corresponding text chunks.
            chunks_metadata (List[Dict]): A list of metadata dictionaries, one for each chunk.
        """
        # len() instead of truthiness so a 2-D numpy array is accepted as well.
        if embeddings is None or len(embeddings) == 0:
            st.error("No embeddings to add to the database.")
            return

        if not (len(embeddings) == len(chunks) == len(chunks_metadata)):
            st.error(
                f"Mismatched input sizes: {len(embeddings)} embeddings, "
                f"{len(chunks)} chunks, {len(chunks_metadata)} metadata entries."
            )
            return

        # FAISS expects a C-contiguous float32 matrix of shape (n, dimension).
        embeddings_np = np.ascontiguousarray(embeddings, dtype=np.float32)
        if embeddings_np.ndim != 2:
            st.error("Embeddings must be a list of one-dimensional vectors.")
            return

        batch_dimension = embeddings_np.shape[1]
        if not self.dimension:
            # Dimension not fixed yet (0 or None): infer it from this batch.
            self.dimension = batch_dimension
        elif self.dimension != batch_dimension:
            st.error(f"Embedding dimension ({batch_dimension}) does not match database dimension ({self.dimension}).")
            return

        if self.index is None:
            self.index = faiss.IndexFlatL2(self.dimension)  # Use L2 distance

        self.index.add(embeddings_np)
        # Extend rather than replace: the index keeps earlier vectors, so the
        # lookup lists must keep the earlier chunks at the same positions.
        self.chunks.extend(chunks)
        self.chunks_metadata.extend(chunks_metadata)

    def query(self, query_embedding: Union[List[float], np.ndarray], k: int = 3) -> List[Dict]:
        """
        Queries the vector database for the most similar chunks to a query embedding.

        Args:
            query_embedding (List[float] or np.ndarray): The embedding of the query.
            k (int, optional): The number of nearest neighbors to retrieve. Defaults to 3.

        Returns:
            List[Dict]: At most ``k`` dictionaries, nearest first, where each dictionary contains:
                - "chunk_text" (str): The text of the retrieved chunk.
                - "file_name" (str): The name of the file the chunk came from.
                - "chunk_index" (int): The index of the chunk in the file.
                - "score" (float): The L2 distance reported by FAISS (lower is closer).
            An empty list is returned when the database is empty, ``k`` is not
            positive, or the query dimension does not match the database.
        """
        if self.is_empty():
            st.error("Vector database is empty. Please upload files and process them first.")
            return []
        if k <= 0:
            return []

        # FAISS searches a batch of queries, so reshape the single query to (1, dimension).
        query_vec = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        if self.dimension and query_vec.shape[1] != self.dimension:
            st.error(f"Query embedding dimension ({query_vec.shape[1]}) does not match database dimension ({self.dimension}).")
            return []

        distances, indices = self.index.search(query_vec, k=k)

        results = []
        for idx, distance in zip(indices[0], distances[0]):
            # FAISS pads the result with -1 when fewer than k vectors are
            # stored; indexing with -1 would wrongly return the last chunk.
            if idx < 0:
                continue
            position = int(idx)
            metadata = self.chunks_metadata[position]
            results.append({
                "chunk_text": self.chunks[position],
                "file_name": metadata["file_name"],
                "chunk_index": metadata["chunk_index"],
                "score": float(distance),
            })
        return results

    def is_empty(self) -> bool:
        """
        Checks if the vector database is empty.

        Returns:
            bool: True if no index exists yet or the index holds no vectors,
            False otherwise.
        """
        return self.index is None or self.index.ntotal == 0
if __name__ == "__main__":
    # Manual smoke test for VectorDatabase; runs only when this file is
    # executed directly. It embeds a real document through process_files, so
    # the path below must exist on the machine running it. To try the class
    # without a document, substitute hand-built lists of embeddings, chunk
    # strings and {"file_name", "chunk_index"} metadata dictionaries.
    sample_paths = ['/Users/zac/Downloads/Janna/verbatimprocs/FZ- revenante - sept24.docx']
    chunks, embeddings, chunks_metadata = process_files(sample_paths)
    embedding_dim = embeddings[0].shape[0]

    def _print_hits(hits):
        # Print the text and origin of every retrieved chunk.
        for hit in hits:
            print(f"Chunk: {hit['chunk_text']}")
            print(f" File: {hit['file_name']}")
            print(f" Index: {hit['chunk_index']}")

    # Scenario 1: the dimension is supplied up front.
    first_db = VectorDatabase(dimension=embedding_dim)
    first_db.add_data(embeddings, chunks, chunks_metadata)
    print("Data added to VectorDatabase.")

    # A random vector is enough to exercise the search path.
    first_query = np.random.rand(embedding_dim).astype(np.float32)
    first_hits = first_db.query(first_query, k=2)
    print("\nQuery results:")
    _print_hits(first_hits)

    print(f"\nIs the database empty? {first_db.is_empty()}")

    # Scenario 2: the dimension is inferred from the first batch added.
    second_db = VectorDatabase()
    second_db.add_data(embeddings, chunks, chunks_metadata)
    print("\nData added to VectorDatabase2 (without initial dimension).")

    second_query = np.random.rand(embedding_dim).astype(np.float32)
    second_hits = second_db.query(second_query, k=1)
    print("\nQuery results from VectorDatabase2:")
    _print_hits(second_hits)
"""
Key improvements and explanations:
Class Structure: The VectorDatabase class encapsulates the FAISS index, chunks, and metadata, providing a clean and organized way to manage the vector database.
Initialization: The __init__ method now takes an optional dimension argument. If not provided during initialization, the dimension is inferred when the first data is added. This provides more flexibility.
Data Handling: The add_data method takes lists of embeddings, chunks, and metadata, and stores them in the object. It also converts the embeddings to a float32 numpy array, which is the format FAISS expects, and checks for dimension consistency.
Querying: The query method performs a similarity search using FAISS and returns a list of dictionaries containing the relevant information. It also handles the case where the database is empty.
Error Handling: The add_data and query methods include error handling for invalid input or an empty database.
Clarity: The code is well-commented and easy to understand.
Testing: The if __name__ == "__main__": block provides a comprehensive test of the VectorDatabase class, demonstrating how to add data, perform queries, and check if the database is empty. I've added a test for initializing the database without a dimension.
"""