Spaces:
Sleeping
Sleeping
File size: 7,582 Bytes
35bda59 76894b4 35bda59 76894b4 35bda59 76894b4 35bda59 76894b4 79a5c64 35bda59 76894b4 35bda59 76894b4 35bda59 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | import faiss
import numpy as np
import streamlit as st
from typing import List, Dict, Optional, Union
from doc_preprocessing import process_files
class VectorDatabase:
    """
    A class to manage a vector database using FAISS for efficient similarity search.

    Embeddings live in a ``faiss.IndexFlatL2`` (exact L2-distance search). The text
    chunks and their metadata are kept in parallel Python lists whose positions match
    the row ids FAISS assigns. A search hit with id ``i`` therefore maps to
    ``self.chunks[i]`` and ``self.chunks_metadata[i]``.
    """

    def __init__(self, dimension: int = 0):
        """
        Initializes the VectorDatabase.

        Args:
            dimension (int, optional): The dimension of the embeddings. If 0 (the
                default), the dimension is inferred from the first batch passed to
                ``add_data``. In either case the FAISS index is created lazily on the
                first successful ``add_data`` call.
        """
        self.dimension = dimension
        self.index: Optional[faiss.Index] = None
        self.chunks: List[str] = []
        self.chunks_metadata: List[Dict] = []

    def add_data(self, embeddings: List[np.ndarray], chunks: List[str], chunks_metadata: List[Dict]):
        """
        Adds embeddings, text chunks, and metadata to the vector database.

        Repeated calls append to the existing data. The FAISS index accumulates
        vectors, so the chunk and metadata lists are extended to stay aligned
        with the index's row ids.

        Args:
            embeddings (List[np.ndarray]): One 1-D embedding per chunk (lists of floats,
                numpy arrays, or a 2-D array of shape (n, dimension)).
            chunks (List[str]): The text chunks, in the same order as ``embeddings``.
            chunks_metadata (List[Dict]): One metadata dict per chunk. ``query`` later
                reads the "file_name" and "chunk_index" keys from each dict.

        Errors (empty input, mismatched lengths, ragged or wrong-dimension embeddings)
        are reported with ``st.error`` and leave the database unchanged.
        """
        # len() instead of truthiness: `not embeddings` raises ValueError for a numpy array.
        if len(embeddings) == 0:
            st.error("No embeddings to add to the database.")
            return
        if not (len(embeddings) == len(chunks) == len(chunks_metadata)):
            st.error(
                f"Mismatched inputs: {len(embeddings)} embeddings, {len(chunks)} chunks, "
                f"{len(chunks_metadata)} metadata entries."
            )
            return
        # FAISS expects a C-contiguous float32 matrix of shape (n, dimension).
        try:
            embeddings_np = np.ascontiguousarray(embeddings, dtype=np.float32)
        except ValueError:
            st.error("Embeddings must all have the same length.")
            return
        if embeddings_np.ndim != 2:
            st.error("Each embedding must be a one-dimensional vector.")
            return

        new_dim = embeddings_np.shape[1]
        if self.dimension == 0:
            self.dimension = new_dim
        elif self.dimension != new_dim:
            st.error(f"Embedding dimension ({new_dim}) does not match database dimension ({self.dimension}).")
            return

        if self.index is None:
            self.index = faiss.IndexFlatL2(self.dimension)  # Use L2 distance
        self.index.add(embeddings_np)
        # Extend rather than replace, so chunk positions keep matching FAISS row ids
        # across multiple add_data calls.
        self.chunks.extend(chunks)
        self.chunks_metadata.extend(chunks_metadata)

    def query(self, query_embedding: Union[List[float], np.ndarray], k: int = 3) -> List[Dict]:
        """
        Queries the vector database for the most similar chunks to a query embedding.

        Args:
            query_embedding (List[float] or np.ndarray): The embedding of the query.
            k (int, optional): The number of nearest neighbors to retrieve. It is capped
                at the number of stored vectors. Defaults to 3.

        Returns:
            List[Dict]: Results ordered from nearest to farthest. Each dictionary contains:
                - "chunk_text" (str): The text of the retrieved chunk.
                - "file_name" (str): The name of the file the chunk came from.
                - "chunk_index" (int): The index of the chunk in the file.
                - "score" (float): The L2 distance reported by FAISS (lower is more similar).
            An empty list is returned (after ``st.error``) if the database is empty or the
            query dimension does not match.
        """
        if self.index is None or self.index.ntotal == 0:
            st.error("Vector database is empty. Please upload files and process them first.")
            return []

        # FAISS searches a batch of queries, so reshape to (1, dimension).
        query_np = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        if query_np.shape[1] != self.dimension:
            st.error(
                f"Query embedding dimension ({query_np.shape[1]}) does not match "
                f"database dimension ({self.dimension})."
            )
            return []

        # Never request more neighbours than stored vectors.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []
        distances, indices = self.index.search(query_np, k=k)

        results = []
        for idx, distance in zip(indices[0], distances[0]):
            # FAISS pads missing results with id -1. Without this check,
            # self.chunks[-1] would silently return the last chunk.
            if idx < 0:
                continue
            metadata = self.chunks_metadata[idx]
            results.append({
                "chunk_text": self.chunks[idx],
                "file_name": metadata["file_name"],
                "chunk_index": metadata["chunk_index"],
                "score": float(distance),
            })
        return results

    def is_empty(self) -> bool:
        """
        Checks if the vector database is empty.

        Returns:
            bool: True if no index has been created yet or the index holds no vectors,
            False otherwise.
        """
        return self.index is None or self.index.ntotal == 0
if __name__ == "__main__":
    # Manual smoke test for the VectorDatabase class. It runs only when this file is
    # executed directly:
    #     python vector_database.py [FILE ...]
    # Files to index may be passed on the command line. Without arguments, the
    # original sample document path below is used.
    # NOTE(review): assumes process_files accepts a list of file paths, as in the
    # original call — confirm against doc_preprocessing.
    import sys

    def _print_results(results):
        """Print each query result's chunk text, source file, chunk index and score."""
        for result in results:
            print(f"Chunk: {result['chunk_text']}")
            print(f" File: {result['file_name']}")
            print(f" Index: {result['chunk_index']}")
            print(f" Score: {result['score']}")

    dummy_files = sys.argv[1:] or ['/Users/zac/Downloads/Janna/verbatimprocs/FZ- revenante - sept24.docx']
    chunks, embeddings, chunks_metadata = process_files(dummy_files)
    if len(embeddings) == 0:
        sys.exit("process_files returned no embeddings; nothing to index.")
    dim = embeddings[0].shape[0]

    # 1. Initialize the VectorDatabase with an explicit dimension.
    vector_db = VectorDatabase(dimension=dim)
    # 2. Add data to the VectorDatabase.
    vector_db.add_data(embeddings, chunks, chunks_metadata)
    print("Data added to VectorDatabase.")
    # 3. Perform a query with a random embedding; get the top 2 results.
    query_embedding = np.random.rand(dim).astype(np.float32)
    print("\nQuery results:")
    _print_results(vector_db.query(query_embedding, k=2))
    # 4. Check whether the database is empty.
    print(f"\nIs the database empty? {vector_db.is_empty()}")

    # 5. Initialize without a dimension (inferred on add_data), then add data and query.
    vector_db2 = VectorDatabase()
    vector_db2.add_data(embeddings, chunks, chunks_metadata)
    print("\nData added to VectorDatabase2 (without initial dimension).")
    query_embedding_2 = np.random.rand(dim).astype(np.float32)
    print("\nQuery results from VectorDatabase2:")
    _print_results(vector_db2.query(query_embedding_2, k=1))
"""
Key improvements and explanations:
Class structure: The VectorDatabase class encapsulates the FAISS index, the text chunks, and their metadata, giving one organized place to manage the vector database.
Initialization: __init__ takes an optional dimension argument (default 0). When it is 0, the dimension is inferred from the first batch of embeddings passed to add_data, which makes the class more flexible to use.
Data handling: add_data takes lists of embeddings, chunks, and metadata and stores them on the object. It converts the embeddings to a float32 NumPy array, the format FAISS expects, and checks that their dimension matches the database's.
Querying: query runs a similarity search with FAISS and returns a list of dictionaries containing the chunk text, file name, chunk index, and distance score. It also handles the case where the database is empty.
Error handling: add_data and query report an empty input, a dimension mismatch, or an empty database to the user via st.error instead of raising an exception.
Clarity: Comments explain the key steps of each method.
Testing: The if __name__ == "__main__": block exercises the VectorDatabase class end to end: it adds data, runs queries, and checks whether the database is empty. It also covers initializing the database without a dimension.
"""