import faiss
import numpy as np
import pickle
from pathlib import Path
from typing import List, Tuple, Optional


class FAISSVectorStore:
    """Persistent FAISS-backed vector store with pickle-serialized metadata.

    Vectors live in an ``IndexFlatL2`` (exact L2 / brute-force search); each
    added vector is assigned a sequential integer ID that maps to a
    caller-supplied metadata dict. The index and the metadata map are
    persisted side by side as ``<index_path>.faiss`` and ``<index_path>.meta``.
    """

    def __init__(self, embedding_dim: int, index_path: Optional[str] = None):
        """Open (or create) the store rooted at ``index_path``.

        Args:
            embedding_dim: Dimensionality of the vectors to be stored.
            index_path: Base path for the ``.faiss``/``.meta`` file pair.
                Defaults to ``./data/faiss_index``.
        """
        self.embedding_dim = embedding_dim
        self.index_path = index_path or "./data/faiss_index"
        self.index = None
        self.id_to_metadata = {}  # sequential FAISS ID -> metadata dict
        self.current_id = 0  # next ID to assign

        Path(self.index_path).parent.mkdir(parents=True, exist_ok=True)

        if Path(f"{self.index_path}.faiss").exists():
            self.load()
        else:
            self._create_new_index()

    def _create_new_index(self):
        """Reset to an empty exact-L2 index with no metadata."""
        self.index = faiss.IndexFlatL2(self.embedding_dim)
        self.id_to_metadata = {}
        self.current_id = 0
        print(f"Created new FAISS index with dimension {self.embedding_dim}")

    def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[int]:
        """Append vectors with their metadata and return the assigned IDs.

        Args:
            embeddings: Array of shape ``(n, embedding_dim)``; any float dtype
                (converted to float32, which FAISS requires).
            metadata: One dict per vector, aligned with ``embeddings`` rows.

        Returns:
            The sequential integer IDs assigned to the new vectors.

        Raises:
            ValueError: If the embedding dimension does not match the index,
                or ``len(metadata)`` does not match the number of vectors.
        """
        if embeddings.shape[1] != self.embedding_dim:
            raise ValueError(
                f"Embedding dimension mismatch: expected {self.embedding_dim}, "
                f"got {embeddings.shape[1]}"
            )

        num_vectors = embeddings.shape[0]

        # Fix: previously a metadata/vector count mismatch was silently
        # truncated by zip(), leaving vectors in the index with no metadata.
        if len(metadata) != num_vectors:
            raise ValueError(
                f"Metadata length mismatch: expected {num_vectors} entries, "
                f"got {len(metadata)}"
            )

        embeddings = embeddings.astype("float32")

        ids = list(range(self.current_id, self.current_id + num_vectors))

        self.index.add(embeddings)

        for i, meta in zip(ids, metadata):
            self.id_to_metadata[i] = meta

        self.current_id += num_vectors

        print(f"Added {num_vectors} vectors. Total: {self.index.ntotal}")
        return ids

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 5,
        version_filter: Optional[int] = None,
    ) -> List[Tuple[float, dict]]:
        """Return up to ``k`` nearest neighbors as ``(distance, metadata)``.

        Args:
            query_embedding: Query vector of shape ``(embedding_dim,)`` or
                ``(1, embedding_dim)``.
            k: Maximum number of results to return.
            version_filter: If given, only results whose metadata
                ``version_id`` equals this value are returned.

        Returns:
            List of ``(l2_distance, metadata)`` tuples, closest first; empty
            if the index holds no vectors.
        """
        if self.index.ntotal == 0:
            return []

        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)
        query_embedding = query_embedding.astype("float32")

        # When filtering, over-fetch candidates so enough survive the filter.
        # Fix: use `is not None` — the old truthiness test treated a filter
        # value of 0 as "no filter" and skipped the over-fetch, even though
        # the per-result check below correctly filtered on version 0.
        search_k = k * 10 if version_filter is not None else k
        distances, indices = self.index.search(
            query_embedding, min(search_k, self.index.ntotal)
        )

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            if idx == -1:  # FAISS pads with -1 when fewer results exist
                continue

            metadata = self.id_to_metadata.get(int(idx), {})

            if version_filter is not None:
                if metadata.get("version_id") != version_filter:
                    continue

            results.append((float(dist), metadata))

            if len(results) >= k:
                break

        return results

    def save(self):
        """Persist the FAISS index and the metadata map to disk."""
        faiss.write_index(self.index, f"{self.index_path}.faiss")

        with open(f"{self.index_path}.meta", "wb") as f:
            pickle.dump(
                {
                    "id_to_metadata": self.id_to_metadata,
                    "current_id": self.current_id,
                    "embedding_dim": self.embedding_dim,
                },
                f,
            )

        print(f"Saved index to {self.index_path}")

    def load(self):
        """Load index + metadata from disk; fall back to a fresh index on error.

        NOTE(security): the ``.meta`` sidecar is unpickled — only load files
        this process wrote itself; pickle is unsafe on untrusted input.
        """
        try:
            self.index = faiss.read_index(f"{self.index_path}.faiss")

            with open(f"{self.index_path}.meta", "rb") as f:
                data = pickle.load(f)
                self.id_to_metadata = data["id_to_metadata"]
                self.current_id = data["current_id"]
                self.embedding_dim = data["embedding_dim"]

            print(f"Loaded index from {self.index_path} ({self.index.ntotal} vectors)")
        except Exception as e:
            # Broad on purpose: a corrupt/partial file pair should degrade to
            # an empty store rather than crash startup.
            print(f"Error loading index: {e}")
            self._create_new_index()

    def get_stats(self) -> dict:
        """Return a summary dict: total vectors, dimension, and storage path."""
        return {
            "total_vectors": self.index.ntotal if self.index else 0,
            "embedding_dim": self.embedding_dim,
            "index_path": self.index_path,
        }