import pickle
from pathlib import Path
from typing import List, Optional, Tuple

import faiss
import numpy as np


class FAISSVectorStore:
    """Flat L2 FAISS index paired with a dictionary of per-vector metadata.

    Vectors are stored in a ``faiss.IndexFlatL2``; metadata dictionaries are
    kept in ``id_to_metadata`` keyed by a sequential integer ID that this
    class hands out in insertion order. The store is persisted as two files,
    ``<index_path>.faiss`` (the index) and ``<index_path>.meta`` (a pickle of
    the metadata map, the next ID and the embedding dimension).
    """

    def __init__(self, embedding_dim: int, index_path: Optional[str] = None):
        """Open the store at *index_path*, or create an empty one.

        Args:
            embedding_dim: Dimensionality of the vectors to be stored. When an
                existing store is loaded successfully, the persisted dimension
                replaces this value.
            index_path: Path prefix (without extension) of the persisted
                files. Defaults to ``"./data/faiss_index"``.
        """
        self.embedding_dim = embedding_dim
        self.index_path = index_path or "./data/faiss_index"
        self.index = None
        self.id_to_metadata = {}  # Map FAISS ID to metadata
        self.current_id = 0

        # Make sure the directory that will hold the index files exists.
        Path(self.index_path).parent.mkdir(parents=True, exist_ok=True)

        if Path(f"{self.index_path}.faiss").exists():
            self.load()
        else:
            self._create_new_index()

    def _create_new_index(self):
        """Reset the store to an empty flat L2 index with no metadata."""
        self.index = faiss.IndexFlatL2(self.embedding_dim)
        self.id_to_metadata = {}
        self.current_id = 0
        print(f"Created new FAISS index with dimension {self.embedding_dim}")

    def add_embeddings(
        self, embeddings: np.ndarray, metadata: List[dict]
    ) -> List[int]:
        """Append vectors and their metadata to the store.

        Args:
            embeddings: Array of shape ``(n, embedding_dim)``. A 1-D array of
                length ``embedding_dim`` is treated as a single vector.
            metadata: One metadata dictionary per vector, in the same order.

        Returns:
            The IDs assigned to the added vectors, in input order.

        Raises:
            ValueError: If the array is not 2-D, its second dimension differs
                from ``embedding_dim``, or the number of metadata entries does
                not match the number of vectors.
        """
        # FAISS expects C-contiguous float32 input.
        embeddings = np.ascontiguousarray(embeddings, dtype="float32")
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)
        if embeddings.ndim != 2:
            raise ValueError(
                f"Embeddings must be a 2-D array, got {embeddings.ndim} dimensions"
            )
        if embeddings.shape[1] != self.embedding_dim:
            raise ValueError(
                f"Embedding dimension mismatch: expected {self.embedding_dim}, "
                f"got {embeddings.shape[1]}"
            )

        num_vectors = embeddings.shape[0]
        # zip() would silently truncate on a length mismatch, leaving vectors
        # without metadata (or dropping metadata), so reject it up front.
        if len(metadata) != num_vectors:
            raise ValueError(
                f"Metadata count mismatch: expected {num_vectors}, "
                f"got {len(metadata)}"
            )

        ids = list(range(self.current_id, self.current_id + num_vectors))
        self.index.add(embeddings)
        self.id_to_metadata.update(zip(ids, metadata))
        self.current_id += num_vectors

        print(f"Added {num_vectors} vectors. Total: {self.index.ntotal}")
        return ids

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 5,
        version_filter: Optional[int] = None,
    ) -> List[Tuple[float, dict]]:
        """Return up to *k* nearest stored vectors for a query.

        Args:
            query_embedding: Query vector, either 1-D or of shape ``(1, dim)``.
                Only the first row of a 2-D array is used.
            k: Maximum number of results to return.
            version_filter: When not ``None``, only entries whose metadata
                ``"version_id"`` equals this value are returned.

        Returns:
            ``(distance, metadata)`` tuples in the order reported by the
            index. Filtering is applied after the nearest-neighbour search
            over ``k * 10`` candidates, so fewer than *k* results may come
            back even if more matching vectors exist further away.
        """
        if k <= 0 or self.index.ntotal == 0:
            return []

        query = np.ascontiguousarray(query_embedding, dtype="float32")
        if query.ndim == 1:
            query = query.reshape(1, -1)

        # Oversample when filtering so enough candidates survive the filter.
        # Compare against None explicitly: version 0 is a valid filter value.
        search_k = k * 10 if version_filter is not None else k
        distances, indices = self.index.search(
            query, min(search_k, self.index.ntotal)
        )

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS reports -1 for result slots it could not fill.
            if idx == -1:
                continue
            entry = self.id_to_metadata.get(int(idx), {})
            if (
                version_filter is not None
                and entry.get("version_id") != version_filter
            ):
                continue
            results.append((float(dist), entry))
            if len(results) >= k:
                break
        return results

    def save(self):
        """Write the index and its metadata to ``<index_path>.faiss``/``.meta``."""
        faiss.write_index(self.index, f"{self.index_path}.faiss")
        state = {
            "id_to_metadata": self.id_to_metadata,
            "current_id": self.current_id,
            "embedding_dim": self.embedding_dim,
        }
        with open(f"{self.index_path}.meta", "wb") as f:
            pickle.dump(state, f)
        print(f"Saved index to {self.index_path}")

    def load(self):
        """Load the index and metadata from disk, falling back to an empty store.

        Any failure (missing or corrupt file, missing key) is reported on
        stdout and the store is reset to a fresh, empty index.
        """
        try:
            self.index = faiss.read_index(f"{self.index_path}.faiss")
            # SECURITY: pickle.load can execute arbitrary code; only load
            # metadata files that this application wrote itself.
            with open(f"{self.index_path}.meta", "rb") as f:
                data = pickle.load(f)
            self.id_to_metadata = data["id_to_metadata"]
            self.current_id = data["current_id"]
            self.embedding_dim = data["embedding_dim"]
            print(
                f"Loaded index from {self.index_path} ({self.index.ntotal} vectors)"
            )
        except Exception as e:
            # Deliberate best-effort: an unreadable store must not prevent
            # start-up, so start over with an empty index.
            print(f"Error loading index: {e}")
            self._create_new_index()

    def get_stats(self) -> dict:
        """Return the vector count, embedding dimension and index path."""
        return {
            "total_vectors": self.index.ntotal if self.index is not None else 0,
            "embedding_dim": self.embedding_dim,
            "index_path": self.index_path,
        }