Spaces:
Sleeping
Sleeping
| import faiss | |
| import numpy as np | |
| import streamlit as st | |
| from typing import List, Dict, Optional, Union | |
| from doc_preprocessing import process_files | |
| class VectorDatabase: | |
| """ | |
| A class to manage a vector database using FAISS for efficient similarity search. | |
| """ | |
| def __init__(self, dimension: int = 0): | |
| """ | |
| Initializes the VectorDatabase. | |
| Args: | |
| dimension (int, optional): The dimension of the embeddings. If None, the | |
| index is not initialized until data is added. Defaults to None. | |
| """ | |
| self.dimension = dimension | |
| self.index: Optional[faiss.Index] = None | |
| self.chunks: List[str] = [] | |
| self.chunks_metadata: List[Dict] = [] | |
| def add_data(self, embeddings: List[np.ndarray], chunks: List[str], chunks_metadata: List[Dict]): | |
| """ | |
| Adds embeddings, text chunks, and metadata to the vector database. | |
| Args: | |
| embeddings (List[List[float]]): A list of embeddings (each a list or numpy array). | |
| chunks (List[str]): A list of corresponding text chunks. | |
| chunks_metadata (List[Dict]): A list of metadata dictionaries, one for each chunk. | |
| """ | |
| if not embeddings: | |
| st.error("No embeddings to add to the database.") | |
| return | |
| # Ensure embeddings are numpy arrays | |
| embeddings = [np.array(emb) for emb in embeddings] | |
| if self.dimension == 0: | |
| self.dimension = embeddings[0].shape[0] | |
| self.index = faiss.IndexFlatL2(self.dimension) # Use L2 distance | |
| elif self.dimension != embeddings[0].shape[0]: | |
| st.error(f"Embedding dimension ({embeddings[0].shape[0]}) does not match database dimension ({self.dimension}).") | |
| return | |
| # Convert embeddings to a float32 numpy array for FAISS | |
| embeddings_np = np.array(embeddings, dtype=np.float32) | |
| if self.index is None: | |
| self.index = faiss.IndexFlatL2(self.dimension) | |
| self.index.add(embeddings_np) | |
| self.chunks = chunks | |
| self.chunks_metadata = chunks_metadata | |
| def query(self, query_embedding: Union[List[float], np.ndarray], k: int = 3) -> List[Dict]: | |
| """ | |
| Queries the vector database for the most similar chunks to a query embedding. | |
| Args: | |
| query_embedding (List[float] or np.ndarray): The embedding of the query. | |
| k (int, optional): The number of nearest neighbors to retrieve. Defaults to 3. | |
| Returns: | |
| List[Dict]: A list of dictionaries, where each dictionary contains: | |
| - "chunk_text" (str): The text of the retrieved chunk. | |
| - "file_name" (str): The name of the file the chunk came from. | |
| - "chunk_index" (int): The index of the chunk in the file. | |
| """ | |
| if self.index is None: | |
| st.error("Vector database is empty. Please upload files and process them first.") | |
| return [] | |
| # Ensure query_embedding is a numpy array | |
| query_embedding = np.array(query_embedding, dtype=np.float32).reshape(1, -1) # Reshape for FAISS | |
| dist, indices = self.index.search(query_embedding, k=k) | |
| results = [] | |
| for (i, j) in zip(indices[0], dist[0]): | |
| chunk_text = self.chunks[i] | |
| metadata = self.chunks_metadata[i] | |
| results.append({ | |
| "chunk_text": chunk_text, | |
| "file_name": metadata["file_name"], | |
| "chunk_index": metadata["chunk_index"], | |
| "score": j | |
| }) | |
| return results | |
| def is_empty(self) -> bool: | |
| """ | |
| Checks if the vector database is empty. | |
| Returns: | |
| bool: True if the database is empty, False otherwise. | |
| """ | |
| return self.index is None | |
| if __name__ == "__main__": | |
| # This part is for testing the VectorDatabase class. | |
| # It will only run if you execute this file directly: python vector_database.py | |
| # # Create some dummy data | |
| # embeddings = [ | |
| # np.array([1.0, 2.0, 3.0]), | |
| # np.array([4.0, 5.0, 6.0]), | |
| # np.array([7.0, 8.0, 9.0]), | |
| # np.array([10.0, 11.0, 12.0]), | |
| # ] | |
| # chunks = [ | |
| # "This is chunk 1 from file A.", | |
| # "This is chunk 2 from file A.", | |
| # "This is chunk 1 from file B.", | |
| # "This is chunk 2 from file B.", | |
| # ] | |
| # chunks_metadata = [ | |
| # {"file_name": "file_a.pdf", "chunk_index": 0}, | |
| # {"file_name": "file_a.pdf", "chunk_index": 1}, | |
| # {"file_name": "file_b.docx", "chunk_index": 0}, | |
| # {"file_name": "file_b.docx", "chunk_index": 1}, | |
| # ] | |
| dummy_files = ['/Users/zac/Downloads/Janna/verbatimprocs/FZ- revenante - sept24.docx'] | |
| chunks, embeddings, chunks_metadata = process_files(dummy_files) | |
| # 1. Initialize the VectorDatabase | |
| vector_db = VectorDatabase(dimension=embeddings[0].shape[0]) # Initialize with dimension | |
| # 2. Add data to the VectorDatabase | |
| vector_db.add_data(embeddings, chunks, chunks_metadata) | |
| print("Data added to VectorDatabase.") | |
| # 3. Perform a query | |
| query_embedding = np.random.rand(embeddings[0].shape[0]).astype(np.float32) # Random query embedding | |
| results = vector_db.query(query_embedding, k=2) # Get the top 2 results | |
| print("\nQuery results:") | |
| for result in results: | |
| print(f"Chunk: {result['chunk_text']}") | |
| print(f" File: {result['file_name']}") | |
| print(f" Index: {result['chunk_index']}") | |
| # 4. Check if the database is empty | |
| print(f"\nIs the database empty? {vector_db.is_empty()}") # Check is_empty method | |
| # 5. Initialize without dimension and then add data | |
| vector_db2 = VectorDatabase() | |
| vector_db2.add_data(embeddings, chunks, chunks_metadata) | |
| print("\nData added to VectorDatabase2 (without initial dimension).") | |
| query_embedding_2 = np.random.rand(embeddings[0].shape[0]).astype(np.float32) # Random query embedding | |
| results_2 = vector_db2.query(query_embedding_2, k=1) | |
| print("\nQuery results from VectorDatabase2:") | |
| for result in results_2: | |
| print(f"Chunk: {result['chunk_text']}") | |
| print(f" File: {result['file_name']}") | |
| print(f" Index: {result['chunk_index']}") | |
| """ | |
| Key improvements and explanations: | |
| Class Structure: The VectorDatabase class encapsulates the FAISS index, chunks, and metadata, providing a clean and organized way to manage the vector database. | |
| Initialization: The __init__ method now takes an optional dimension argument. If not provided during initialization, the dimension is inferred when the first data is added. This provides more flexibility. | |
| Data Handling: The add_data method takes lists of embeddings, chunks, and metadata, and stores them in the object. It also converts the embeddings to a float32 numpy array, which is the format FAISS expects, and checks for dimension consistency. | |
| Querying: The query method performs a similarity search using FAISS and returns a list of dictionaries containing the relevant information. It also handles the case where the database is empty. | |
| Error Handling: The add_data and query methods include error handling for invalid input or an empty database. | |
| Clarity: The code is well-commented and easy to understand. | |
| Testing: The if __name__ == "__main__": block provides a comprehensive test of the VectorDatabase class, demonstrating how to add data, perform queries, and check if the database is empty. I've added a test for initializing the database without a dimension. | |
| """ |