File size: 7,582 Bytes
35bda59
 
 
76894b4
35bda59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76894b4
35bda59
 
 
 
76894b4
35bda59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76894b4
79a5c64
35bda59
76894b4
35bda59
 
 
 
 
 
76894b4
35bda59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import faiss
import numpy as np
import streamlit as st
from typing import List, Dict, Optional, Union

from doc_preprocessing import process_files 

class VectorDatabase:
    """
    A class to manage a vector database using FAISS for efficient similarity search.

    Row ``i`` of the FAISS index corresponds to ``self.chunks[i]`` and
    ``self.chunks_metadata[i]``; ``add_data`` keeps the three in step so that
    positions returned by a search can be used directly as list indices.
    """
    def __init__(self, dimension: Optional[int] = 0):
        """
        Initializes the VectorDatabase.

        Args:
            dimension (int, optional): The dimension of the embeddings. If 0
                (or None), the dimension is inferred from the first batch
                passed to ``add_data``. Defaults to 0.
        """
        # Normalise None to 0 so "dimension not known yet" has one representation.
        self.dimension = dimension or 0
        # The FAISS index is created lazily on the first successful add_data call.
        self.index: Optional[faiss.Index] = None
        self.chunks: List[str] = []
        self.chunks_metadata: List[Dict] = []

    def add_data(self, embeddings: List[np.ndarray], chunks: List[str], chunks_metadata: List[Dict]):
        """
        Appends embeddings, text chunks, and metadata to the vector database.

        May be called several times; each call appends to what is already
        stored. Problems with the input are reported through ``st.error`` and
        leave the database unchanged.

        Args:
            embeddings (List[np.ndarray]): A list of embeddings (each a list or numpy array).
            chunks (List[str]): A list of corresponding text chunks.
            chunks_metadata (List[Dict]): A list of metadata dictionaries, one for each chunk.
        """
        # len() rather than truthiness so a 2-D numpy array is accepted as well.
        if len(embeddings) == 0:
            st.error("No embeddings to add to the database.")
            return

        # Index rows and the two lists are matched by position, so a length
        # mismatch would make every later lookup point at the wrong chunk.
        if not (len(embeddings) == len(chunks) == len(chunks_metadata)):
            st.error(
                f"Mismatched input lengths: {len(embeddings)} embeddings, "
                f"{len(chunks)} chunks, {len(chunks_metadata)} metadata entries."
            )
            return

        incoming_dim = np.asarray(embeddings[0]).shape[0]
        if self.dimension not in (0, incoming_dim):
            st.error(f"Embedding dimension ({incoming_dim}) does not match database dimension ({self.dimension}).")
            return

        # Convert embeddings to a float32 numpy array for FAISS. Done before any
        # state is mutated so a failed conversion leaves the database untouched.
        embeddings_np = np.array([np.asarray(emb) for emb in embeddings], dtype=np.float32)

        self.dimension = incoming_dim
        if self.index is None:
            self.index = faiss.IndexFlatL2(self.dimension)  # Use L2 distance
        self.index.add(embeddings_np)

        # Extend (not replace): index.add appends rows, so the lists must grow
        # in the same way to stay aligned with the index positions.
        self.chunks.extend(chunks)
        self.chunks_metadata.extend(chunks_metadata)

    def query(self, query_embedding: Union[List[float], np.ndarray], k: int = 3) -> List[Dict]:
        """
        Queries the vector database for the most similar chunks to a query embedding.

        Args:
            query_embedding (List[float] or np.ndarray): The embedding of the query.
            k (int, optional): The number of nearest neighbors to retrieve. Defaults to 3.

        Returns:
            List[Dict]: At most ``k`` dictionaries, in the order returned by the
            index search, where each dictionary contains:
                - "chunk_text" (str): The text of the retrieved chunk.
                - "file_name" (str): The name of the file the chunk came from.
                - "chunk_index" (int): The index of the chunk in the file.
                - "score": The distance reported by the index for this chunk.
            An empty list is returned when the database is empty or ``k <= 0``.
        """
        if self.index is None:
            st.error("Vector database is empty. Please upload files and process them first.")
            return []

        if k <= 0:
            return []

        # Ensure query_embedding is a float32 row vector, the layout FAISS searches with.
        query_embedding = np.array(query_embedding, dtype=np.float32).reshape(1, -1)

        dist, indices = self.index.search(query_embedding, k=k)

        results = []
        for position, distance in zip(indices[0], dist[0]):
            # FAISS pads the result with -1 when fewer than k vectors are stored;
            # used as a list index that would silently return the last chunk.
            if position < 0:
                continue
            metadata = self.chunks_metadata[position]
            results.append({
                "chunk_text": self.chunks[position],
                "file_name": metadata["file_name"],
                "chunk_index": metadata["chunk_index"],
                "score": distance
            })
        return results

    def is_empty(self) -> bool:
        """
        Checks if the vector database is empty.

        Returns:
            bool: True if no index has been created yet (no data was ever
            added successfully), False otherwise.
        """
        return self.index is None

if __name__ == "__main__":
    # Manual smoke test for the VectorDatabase class.
    # It only runs when this file is executed directly:
    #     python vector_database.py [file ...]
    # Paths given on the command line are processed; with no arguments the
    # original default document below is used.
    import sys

    # Alternative synthetic data, kept for reference when no document is at hand:
    # embeddings = [
    #     np.array([1.0, 2.0, 3.0]),
    #     np.array([4.0, 5.0, 6.0]),
    #     np.array([7.0, 8.0, 9.0]),
    #     np.array([10.0, 11.0, 12.0]),
    # ]
    # chunks = [
    #     "This is chunk 1 from file A.",
    #     "This is chunk 2 from file A.",
    #     "This is chunk 1 from file B.",
    #     "This is chunk 2 from file B.",
    # ]
    # chunks_metadata = [
    #     {"file_name": "file_a.pdf", "chunk_index": 0},
    #     {"file_name": "file_a.pdf", "chunk_index": 1},
    #     {"file_name": "file_b.docx", "chunk_index": 0},
    #     {"file_name": "file_b.docx", "chunk_index": 1},
    # ]

    def _print_results(results):
        # Shared printer for both query demonstrations below.
        for result in results:
            print(f"Chunk: {result['chunk_text']}")
            print(f"  File: {result['file_name']}")
            print(f"  Index: {result['chunk_index']}")

    default_files = ['/Users/zac/Downloads/Janna/verbatimprocs/FZ- revenante - sept24.docx']
    dummy_files = sys.argv[1:] or default_files
    chunks, embeddings, chunks_metadata = process_files(dummy_files)

    # Everything below indexes embeddings[0]; stop with a clear message instead
    # of an IndexError when processing produced nothing.
    if len(embeddings) == 0:
        sys.exit("No embeddings were produced from the input files; nothing to test.")

    embedding_dim = embeddings[0].shape[0]

    # 1. Initialize the VectorDatabase with a known dimension
    vector_db = VectorDatabase(dimension=embedding_dim)

    # 2. Add data to the VectorDatabase
    vector_db.add_data(embeddings, chunks, chunks_metadata)
    print("Data added to VectorDatabase.")

    # 3. Perform a query with a random embedding and fetch the top 2 results
    query_embedding = np.random.rand(embedding_dim).astype(np.float32)
    results = vector_db.query(query_embedding, k=2)

    print("\nQuery results:")
    _print_results(results)

    # 4. Check if the database is empty
    print(f"\nIs the database empty? {vector_db.is_empty()}")

    # 5. Initialize without dimension and then add data (dimension is inferred)
    vector_db2 = VectorDatabase()
    vector_db2.add_data(embeddings, chunks, chunks_metadata)
    print("\nData added to VectorDatabase2 (without initial dimension).")

    query_embedding_2 = np.random.rand(embedding_dim).astype(np.float32)
    results_2 = vector_db2.query(query_embedding_2, k=1)
    print("\nQuery results from VectorDatabase2:")
    _print_results(results_2)




"""

Key improvements and explanations:

    Class Structure: The VectorDatabase class encapsulates the FAISS index, chunks, and metadata, providing a clean and organized way to manage the vector database.
    Initialization: The __init__ method now takes an optional dimension argument. If not provided during initialization, the dimension is inferred when the first data is added. This provides more flexibility.
    Data Handling: The add_data method takes lists of embeddings, chunks, and metadata, and stores them in the object. It also converts the embeddings to a float32 numpy array, which is the format FAISS expects, and checks for dimension consistency.
    Querying: The query method performs a similarity search using FAISS and returns a list of dictionaries containing the relevant information. It also handles the case where the database is empty.
    Error Handling: The add_data and query methods include error handling for invalid input or an empty database.
    Clarity: The code is well-commented and easy to understand.
    Testing: The if __name__ == "__main__": block provides a comprehensive test of the VectorDatabase class, demonstrating how to add data, perform queries, and check if the database is empty. I've added a test for initializing the database without a dimension.



"""