File size: 7,330 Bytes
e72f783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbfd492
 
e72f783
 
 
 
cbfd492
e72f783
 
 
 
 
 
 
 
 
 
 
 
 
 
cbfd492
 
 
e72f783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# src/retriever.py
# Loads and searches all 3 FAISS indexes
#
# Index 1 β€” Category (15 vectors, IndexFlatIP, CLIP full-image)
# Index 2 β€” Defect pattern (5354 vectors, IndexFlatIP, CLIP crop)
# Index 3 β€” PatchCore coreset (per-category, IndexFlatL2, WideResNet patches)
#           LAZY LOADED β€” only loaded on first request per category
#           Reduces startup time from ~45s to ~15s

import os
import json
import numpy as np
import faiss

# Paths β€” relative to repo root, mounted in Docker at /app/data/
DATA_DIR = os.environ.get("DATA_DIR", "data")

CATEGORIES = [
    'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
    'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
    'transistor', 'wood', 'zipper'
]


class FAISSRetriever:
    """
    Manages all 3 FAISS indexes with lazy loading for Index 3.
    Loaded once at FastAPI startup, kept in memory for server lifetime.

    Index 1 — category routing (CLIP full-image vectors, inner product).
    Index 2 — defect pattern retrieval (CLIP crop vectors, inner product,
              memory-mapped so it is not fully resident in RAM).
    Index 3 — PatchCore coresets, one file per category (L2 distance),
              lazy-loaded on first request to keep startup fast.
    """

    def __init__(self, data_dir=None):
        """
        data_dir: directory holding the *.faiss / *_metadata.json files.
                  Defaults to DATA_DIR, resolved at call time rather than
                  bound as a def-time default, so tests/config can override
                  DATA_DIR after this module is imported.
        """
        self.data_dir = DATA_DIR if data_dir is None else data_dir
        self.index1 = None           # Category index (set by load_indexes)
        self.index1_metadata = None  # parsed index1_metadata.json
        self.index2 = None           # Defect pattern index (mmapped)
        self.index2_metadata = None  # per-vector metadata dicts for Index 2
        self.index3_cache = {}       # category -> loaded FAISS index (lazy)

    def _read_index_with_meta(self, index_file, meta_file, label, mmap=False):
        """
        Read one FAISS index plus its JSON metadata from self.data_dir.

        Raises FileNotFoundError with `label` in the message when the index
        file is missing (a missing metadata file surfaces as the plain
        open() FileNotFoundError, matching the original behaviour).
        Returns (index, metadata).
        """
        idx_path = os.path.join(self.data_dir, index_file)
        if not os.path.exists(idx_path):
            raise FileNotFoundError(f"{label} not found: {idx_path}")

        if mmap:
            # Memory-mapped — not fully loaded into RAM
            index = faiss.read_index(idx_path, faiss.IO_FLAG_MMAP)
        else:
            index = faiss.read_index(idx_path)
        with open(os.path.join(self.data_dir, meta_file)) as f:
            metadata = json.load(f)
        return index, metadata

    def load_indexes(self):
        """
        Load Index 1 and Index 2 at startup.
        Index 3 is lazy-loaded per category on first request
        (see _load_index3).
        """
        # ── Index 1: small, fully resident ───────────────────
        self.index1, self.index1_metadata = self._read_index_with_meta(
            "index1_category.faiss", "index1_metadata.json", "Index 1"
        )
        print(f"Index 1 loaded: {self.index1.ntotal} category vectors")

        # ── Index 2: large, memory-mapped ────────────────────
        self.index2, self.index2_metadata = self._read_index_with_meta(
            "index2_defect.faiss", "index2_metadata.json", "Index 2",
            mmap=True,
        )
        print(f"Index 2 loaded: {self.index2.ntotal} defect pattern vectors")

    def _load_index3(self, category: str):
        """
        Lazy load Index 3 for a specific category, caching it for the
        server's lifetime. Raises FileNotFoundError if the per-category
        coreset file is absent.
        """
        if category not in self.index3_cache:
            path = os.path.join(self.data_dir, f"index3_{category}.faiss")
            if not os.path.exists(path):
                raise FileNotFoundError(f"Index 3 not found for {category}: {path}")
            self.index3_cache[category] = faiss.read_index(
                path, faiss.IO_FLAG_MMAP
            )
            print(f"Index 3 lazy-loaded: {category} "
                  f"({self.index3_cache[category].ntotal} coreset vectors)")
        return self.index3_cache[category]

    # ── Index 1: Category routing ─────────────────────────────
    def route_category(self, clip_full_embedding: np.ndarray) -> dict:
        """
        Given a full-image CLIP embedding, return the predicted category.

        NOTE(review): the row id returned by Index 1 is mapped straight
        into CATEGORIES, so Index 1 must contain exactly one vector per
        category in CATEGORIES order — confirm against the build script.

        Returns: {"category": str, "confidence": float}
        """
        query = clip_full_embedding.reshape(1, -1).astype(np.float32)
        # Normalise so the inner-product score is cosine similarity;
        # the epsilon guards against a zero vector.
        query = query / (np.linalg.norm(query) + 1e-8)
        D, I = self.index1.search(query, k=1)
        cat_idx = int(I[0][0])
        return {
            "category": CATEGORIES[cat_idx],
            "confidence": float(D[0][0])
        }

    # ── Index 2: Defect pattern retrieval ────────────────────
    def retrieve_similar_defects(self,
                                  clip_crop_embedding: np.ndarray,
                                  k: int = 5,
                                  exclude_hash: str = None,
                                  category_filter: str = None) -> list:
        """
        Given a defect-crop CLIP embedding, return up to k most similar
        historical defect cases.

        exclude_hash: skip self-match (same image submitted again)
        category_filter: only return cases from specified category

        Returns: list of metadata dicts, each with a "similarity_score"
        key added (cosine similarity, higher is more similar).
        """
        query = clip_crop_embedding.reshape(1, -1).astype(np.float32)
        query = query / (np.linalg.norm(query) + 1e-8)

        # Candidate pool size. k+1 is enough to absorb a single self-match,
        # but a category filter can reject arbitrarily many hits — the old
        # k+1 fetch could return fewer than k results even when k matches
        # exist. Over-fetch when filtering; FAISS pads missing slots with
        # idx == -1, which the loop below skips, so over-asking is safe.
        if category_filter is None:
            pool = k + 1
        else:
            pool = max(4 * (k + 1), 50)
        D, I = self.index2.search(query, k=pool)

        results = []
        for dist, idx in zip(D[0], I[0]):
            if idx < 0:
                # FAISS padding for "no more neighbours"
                continue
            meta = self.index2_metadata[idx]
            # Filter BEFORE copying — no wasted dict copies on rejects.
            if category_filter and meta.get("category") != category_filter:
                continue
            if exclude_hash and meta.get("image_hash") == exclude_hash:
                continue
            meta = meta.copy()
            meta["similarity_score"] = float(dist)
            results.append(meta)
            if len(results) == k:
                break

        return results

    # ── Index 3: PatchCore k-NN scoring ──────────────────────
    def score_patches(self,
                       patches: np.ndarray,
                       category: str,
                       k: int = 1) -> tuple:
        """
        Given [n_patches, dim] patch features (n_patches must be a perfect
        square, e.g. 784 -> 28x28 grid), return the anomaly score and
        per-patch distance grid.

        k: number of nearest neighbours requested; at least 5 are always
           searched (nearest for scoring, rest for the confidence spread).
           Previously this parameter was silently ignored.

        Returns a 4-tuple:
            image_score:  float — max nearest-neighbour patch distance
            patch_scores: [side, side] array of per-patch NN distances
            score_std:    float — std of the max(k, 5) distances at the
                          most anomalous patch (confidence interval)
            D:            [n_patches, max(k, 5)] all k-NN distances

        Raises ValueError if n_patches is not a perfect square.
        """
        index3 = self._load_index3(category)
        patches_f32 = patches.astype(np.float32)

        # At least 5 neighbours: first for scoring, rest for the
        # confidence interval. max() keeps default k=1 behaviour identical
        # while honouring explicit larger k.
        n_neighbors = max(k, 5)
        D, _ = index3.search(patches_f32, k=n_neighbors)

        # Grid side length — was hard-coded 28 (784 patches); generalized.
        n_patches = patches_f32.shape[0]
        side = int(round(n_patches ** 0.5))
        if side * side != n_patches:
            raise ValueError(f"patch count {n_patches} is not a perfect square")

        # Primary score: nearest-neighbour distance per patch
        patch_scores = D[:, 0].reshape(side, side)
        image_score = float(patch_scores.max())

        # Confidence interval: std of distances at the most anomalous patch
        max_patch_idx = int(np.argmax(D[:, 0]))
        score_std = float(np.std(D[max_patch_idx]))

        return image_score, patch_scores, score_std, D

    def get_status(self) -> dict:
        """Returns index sizes for /health endpoint."""
        return {
            "index1_vectors": self.index1.ntotal if self.index1 else 0,
            "index2_vectors": self.index2.ntotal if self.index2 else 0,
            "index3_loaded_categories": list(self.index3_cache.keys()),
            "index3_total_categories": len(CATEGORIES)
        }


# Global instance — initialised in api/startup.py.
# Module-level singleton: load_indexes() must be called (at startup)
# before any search method, since index1/index2 start out as None.
retriever = FAISSRetriever()