# import faiss
import numpy as np
import hnswlib
from src.utils import print_rank

# class FAISSIndex:
#     """
#     Manages FAISS indices for different candidate types and their associated keys.
#     """
#     # BUG: incompatible with numpy >= 2.0.0
#     def __init__(self, ngpus=None):
#         self.indices = {}  # Stores FAISS indices for each candidate type
#         self.keys_dict = {}  # Stores candidate keys for each candidate type
#         self.ngpus = ngpus or faiss.get_num_gpus()
#         print_rank(f"FAISS Index initialized with {self.ngpus} GPUs")
        
#     def create_index(self, cand_type, cand_vectors, cand_keys):
#         """
#         Create a multi-GPU FAISS index for a candidate type.
        
#         Args:
#             cand_type (str): Candidate type (state, trajectory, interval)
#             cand_vectors (np.ndarray): Embeddings for the candidates
#             cand_keys (list): List of candidate IDs
#         """
#         print_rank(f"Building FAISS index for {cand_type}")
#         assert len(cand_keys) == cand_vectors.shape[0]
#         # Store candidate keys for this type
#         self.keys_dict[cand_type] = cand_keys
        
#         # Normalize vectors for cosine similarity
#         vectors = cand_vectors.astype(np.float32).copy()
#         faiss.normalize_L2(vectors)
        
#         # Create CPU index
#         d = vectors.shape[1]  # Embedding dimension
#         cpu_index = faiss.IndexFlatIP(d)  # Inner product similarity
#         cpu_index.add(vectors)
        
#         # Distribute the index across multiple GPUs
#         co = faiss.GpuMultipleClonerOptions()
#         co.shard = True  # Shard the index across GPUs
#         gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=co, ngpu=self.ngpus)
        
#         # Store the GPU index
#         self.indices[cand_type] = gpu_index
        
#     def search(self, cand_type, query_vectors, k):
#         """
#         Search for nearest neighbors in the index for a specific candidate type.
        
#         Args:
#             cand_type (str): Candidate type (state, trajectory, interval)
#             query_vector (np.ndarray): Query embedding(s)
#             k (int): Number of results to retrieve
            
#         Returns:
#             tuple: (scores, predictions) where:
#                 - scores is a list of lists of similarity scores
#                 - predictions is a list of lists of candidate IDs
#         """
#         if cand_type not in self.indices:
#             raise ValueError(f"Index for {cand_type} not found")
        
#         if len(query_vectors.shape) == 1:
#             q = query_vectors.reshape(1, -1).astype(np.float32)
#         else:
#             q = query_vectors.astype(np.float32)

#         # Normalize vectors for cosine similarity
#         faiss.normalize_L2(q)

#         assert q.shape[1] == self.indices[cand_type].d, \
#             f"Query dimension {q.shape[1]} doesn't match index dimension {self.indices[cand_type].d}"
    
#         # Search in the appropriate index
#         scores, indices = self.indices[cand_type].search(q, k)
        
#         # Process results - create a list of predictions for each query
#         all_predictions = []
#         for i in range(indices.shape[0]):
#             predictions = [self.keys_dict[cand_type][int(idx)] for idx in indices[i]]
#             all_predictions.append(predictions)
        
#         return scores.tolist(), all_predictions


class HNSWIndex:
    """
    Manages HNSW (hnswlib) approximate-nearest-neighbor indices for different
    candidate types and their associated keys.

    Serves as a drop-in replacement for the (commented-out, numpy>=2-incompatible)
    FAISSIndex above: one index per candidate type, built over L2-normalized
    embeddings and queried with cosine similarity.
    """

    def __init__(self, ef_construction=200, M=48):
        """
        Args:
            ef_construction (int): Size of the dynamic candidate list used while
                building the graph; larger values improve recall at the cost of
                build time.
            M (int): Number of bi-directional links per node; controls graph
                connectivity (memory vs. recall trade-off).
        """
        self.indices = {}      # cand_type -> hnswlib.Index
        self.keys_dict = {}    # cand_type -> list of candidate IDs
        self.dimensions = {}   # cand_type -> embedding dimension
        self.ef_construction = ef_construction  # Controls index quality
        self.M = M  # Controls graph connectivity
        print_rank(f"HNSW Index initialized with ef_construction={ef_construction}, M={M}")

    def create_index(self, cand_type, cand_vectors, cand_keys):
        """
        Create an HNSW index for a candidate type.

        Args:
            cand_type (str): Candidate type (state, trajectory, interval)
            cand_vectors (np.ndarray): 2-D float embeddings, one row per candidate
            cand_keys (list): List of candidate IDs, aligned with cand_vectors rows

        Raises:
            ValueError: If keys/vectors are misaligned or a zero-norm vector is found.
        """
        print_rank(f"Building HNSW index for {cand_type}")
        if len(cand_keys) != cand_vectors.shape[0]:
            # Explicit raise instead of assert: validation must survive `python -O`.
            raise ValueError(
                f"Got {len(cand_keys)} keys for {cand_vectors.shape[0]} vectors"
            )
        # Store candidate keys for this type
        self.keys_dict[cand_type] = cand_keys

        # L2-normalize (equivalent to faiss.normalize_L2). The 'cosine' space
        # below normalizes internally as well, but normalizing here keeps the
        # stored vectors consistent with the FAISS implementation and lets us
        # reject degenerate zero vectors early.
        vectors = cand_vectors.astype(np.float32, copy=True)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        if np.any(norms == 0):
            raise ValueError("Zero norm found in candidate vectors")
        vectors /= norms

        num_elements, dim = vectors.shape
        # FIX: self.dimensions was declared in __init__ but never populated.
        self.dimensions[cand_type] = dim

        # Initialize the index; cosine metric (distance = 1 - cosine similarity)
        index = hnswlib.Index(space='cosine', dim=dim)
        index.init_index(max_elements=num_elements,
                         ef_construction=self.ef_construction, M=self.M)

        # Add all vectors with sequential integer labels (positions into cand_keys)
        index.add_items(vectors, np.arange(num_elements))

        # Default search-time quality parameter; raised per-query in search()
        # when k exceeds it.
        index.set_ef(100)

        # Store the index
        self.indices[cand_type] = index

    def search(self, cand_type, query_vectors, k):
        """
        Search for nearest neighbors in the index for a specific candidate type.

        Args:
            cand_type (str): Candidate type (state, trajectory, interval)
            query_vectors (np.ndarray): Query embedding(s), 1-D (single query)
                or 2-D (batch of queries)
            k (int): Number of results to retrieve

        Returns:
            tuple: (scores, predictions) where:
                - scores is a list of lists of cosine-similarity scores
                - predictions is a list of lists of candidate IDs

        Raises:
            ValueError: If no index exists for cand_type, the query dimension
                mismatches the index, or a zero-norm query is found.
        """
        if cand_type not in self.indices:
            raise ValueError(f"Index for {cand_type} not found")

        # Accept a single 1-D query or a 2-D batch.
        if query_vectors.ndim == 1:
            q = query_vectors.reshape(1, -1).astype(np.float32)
        else:
            q = query_vectors.astype(np.float32)

        # Normalize query vectors
        norms = np.linalg.norm(q, axis=1, keepdims=True)
        if np.any(norms == 0):
            raise ValueError("Zero norm found in query vectors")
        q = q / norms

        index = self.indices[cand_type]
        if q.shape[1] != index.dim:
            raise ValueError(
                f"Query dimension {q.shape[1]} doesn't match index dimension {index.dim}"
            )

        # BUG FIX: hnswlib requires ef >= k at query time; with the fixed
        # set_ef(100) from create_index, any k > 100 made knn_query raise
        # ("Cannot return the results in a contiguous 2D array"). Keep ef at
        # least as large as both the default quality setting and k.
        index.set_ef(max(100, k))

        # Search in the HNSW index; returns (labels, distances)
        indices, distances = index.knn_query(q, k=k)

        # Convert cosine distances back to similarity scores
        scores = 1 - distances

        # Map integer labels back to candidate IDs, one list per query
        keys = self.keys_dict[cand_type]
        all_predictions = [
            [keys[int(idx)] for idx in row] for row in indices
        ]

        return scores.tolist(), all_predictions