# src/retriever.py
# Loads and searches all 3 FAISS indexes.
#
# Index 1 - Category (15 vectors, IndexFlatIP, CLIP full-image)
# Index 2 - Defect pattern (5354 vectors, IndexFlatIP, CLIP crop)
# Index 3 - PatchCore coreset (per-category, IndexFlatL2, WideResNet patches)
#           LAZY LOADED - only loaded on first request per category.
#           Reduces startup time from ~45s to ~15s.
import os
import json
import numpy as np
import faiss
# Location of the index/metadata files. Defaults to "data" relative to the
# repo root; inside Docker it is mounted at /app/data/ and may be overridden
# through the DATA_DIR environment variable.
DATA_DIR = os.environ.get("DATA_DIR", "data")

# The 15 product categories (the MVTec AD class list). Order matters:
# route_category() maps Index 1 row positions to names by indexing this list.
CATEGORIES = [
    'bottle',
    'cable',
    'capsule',
    'carpet',
    'grid',
    'hazelnut',
    'leather',
    'metal_nut',
    'pill',
    'screw',
    'tile',
    'toothbrush',
    'transistor',
    'wood',
    'zipper',
]
class FAISSRetriever:
    """
    Manages all 3 FAISS indexes with lazy loading for Index 3.

    Index 1 and Index 2 are read by ``load_indexes()``; each per-category
    Index 3 file is read on first use by ``_load_index3()`` and then cached
    in ``index3_cache`` for the lifetime of this instance.
    """

    def __init__(self, data_dir=DATA_DIR):
        """Remember where the index files live; nothing is read from disk here."""
        self.data_dir = data_dir
        self.index1 = None           # Category index (set by load_indexes)
        self.index1_metadata = None  # Parsed index1_metadata.json
        self.index2 = None           # Defect pattern index (set by load_indexes)
        self.index2_metadata = None  # Parsed index2_metadata.json, indexed by FAISS row id
        self.index3_cache = {}       # category -> loaded FAISS index (lazy)

    def load_indexes(self):
        """
        Load Index 1 and Index 2 together with their JSON metadata.
        Index 3 is lazy-loaded per category on first request.

        Raises:
            FileNotFoundError: if an index file or its metadata file is missing.
        """
        # -- Index 1 ----------------------------------------------------------
        idx1_path = os.path.join(self.data_dir, "index1_category.faiss")
        meta1_path = os.path.join(self.data_dir, "index1_metadata.json")
        if not os.path.exists(idx1_path):
            raise FileNotFoundError(f"Index 1 not found: {idx1_path}")
        self.index1 = faiss.read_index(idx1_path)
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(meta1_path, encoding="utf-8") as f:
            self.index1_metadata = json.load(f)
        print(f"Index 1 loaded: {self.index1.ntotal} category vectors")

        # -- Index 2 ----------------------------------------------------------
        idx2_path = os.path.join(self.data_dir, "index2_defect.faiss")
        meta2_path = os.path.join(self.data_dir, "index2_metadata.json")
        if not os.path.exists(idx2_path):
            raise FileNotFoundError(f"Index 2 not found: {idx2_path}")
        # Requested as memory-mapped. NOTE(review): whether FAISS honours
        # IO_FLAG_MMAP for a flat index depends on the FAISS version - confirm
        # before relying on reduced RSS.
        self.index2 = faiss.read_index(idx2_path, faiss.IO_FLAG_MMAP)
        with open(meta2_path, encoding="utf-8") as f:
            self.index2_metadata = json.load(f)
        print(f"Index 2 loaded: {self.index2.ntotal} defect pattern vectors")

    def _load_index3(self, category: str):
        """
        Return the Index 3 (PatchCore coreset) index for ``category``.

        The file ``index3_<category>.faiss`` is read on the first call for a
        category and served from ``index3_cache`` afterwards.

        Raises:
            FileNotFoundError: if the category's index file does not exist.
        """
        if category in self.index3_cache:
            return self.index3_cache[category]
        path = os.path.join(self.data_dir, f"index3_{category}.faiss")
        if not os.path.exists(path):
            raise FileNotFoundError(f"Index 3 not found for {category}: {path}")
        # Same IO_FLAG_MMAP caveat as Index 2 (version-dependent).
        index = faiss.read_index(path, faiss.IO_FLAG_MMAP)
        self.index3_cache[category] = index
        print(f"Index 3 lazy-loaded: {category} "
              f"({index.ntotal} coreset vectors)")
        return index

    # -- Index 1: Category routing --------------------------------------------
    def route_category(self, clip_full_embedding: np.ndarray) -> dict:
        """
        Given a full-image CLIP embedding, return the predicted category.

        Args:
            clip_full_embedding: the embedding; flattened into one query row.
        Returns:
            {"category": str, "confidence": float}, where confidence is the
            inner-product score of the best Index 1 match.
        Raises:
            RuntimeError: if Index 1 returns no neighbour (e.g. it is empty).
        """
        query = clip_full_embedding.reshape(1, -1).astype(np.float32)
        # L2-normalise so the inner product is a cosine similarity (assumes
        # the stored Index 1 vectors are normalised as well - TODO confirm).
        query = query / (np.linalg.norm(query) + 1e-8)
        scores, ids = self.index1.search(query, k=1)
        cat_idx = int(ids[0][0])
        # FAISS reports "no result" as label -1; CATEGORIES[-1] would otherwise
        # silently return the last category.
        if cat_idx < 0:
            raise RuntimeError("Index 1 returned no match; is it empty?")
        # Assumes Index 1 rows were added in CATEGORIES order.
        return {
            "category": CATEGORIES[cat_idx],
            "confidence": float(scores[0][0])
        }

    # -- Index 2: Defect pattern retrieval ------------------------------------
    def retrieve_similar_defects(self,
                                 clip_crop_embedding: np.ndarray,
                                 k: int = 5,
                                 exclude_hash: str = None,
                                 category_filter: str = None) -> list:
        """
        Given a defect-crop CLIP embedding, return up to k most similar
        historical defect cases, best match first.

        Args:
            clip_crop_embedding: the crop embedding; flattened into one query row.
            k: maximum number of cases to return.
            exclude_hash: skip entries whose "image_hash" equals this
                (self-match when the same image is submitted again).
            category_filter: only return entries whose "category" equals this.
        Returns:
            List of copies of Index 2 metadata dicts, each with an added
            "similarity_score" (inner-product score). Fewer than k entries are
            returned only if the whole index holds fewer qualifying cases.
        """
        ntotal = self.index2.ntotal
        if k <= 0 or ntotal == 0:
            return []
        query = clip_crop_embedding.reshape(1, -1).astype(np.float32)
        query = query / (np.linalg.norm(query) + 1e-8)

        # Start with k+1 candidates (room for one self-match). The filters -
        # category_filter in particular - can discard most candidates, so widen
        # the search until k cases survive or the whole index has been scanned.
        n_fetch = min(k + 1, ntotal)
        while True:
            scores, ids = self.index2.search(query, k=n_fetch)
            results = self._filter_defect_hits(
                scores[0], ids[0], k, exclude_hash, category_filter
            )
            if len(results) >= k or n_fetch >= ntotal:
                return results
            n_fetch = min(n_fetch * 4, ntotal)

    def _filter_defect_hits(self, scores, ids, k, exclude_hash, category_filter):
        """Turn raw Index 2 hits into at most k filtered, scored metadata copies."""
        results = []
        for score, idx in zip(scores, ids):
            if idx < 0:  # FAISS padding for "no result"
                continue
            meta = self.index2_metadata[int(idx)]
            # Filter by category if provided
            if category_filter and meta.get("category") != category_filter:
                continue
            # Skip self-match
            if exclude_hash and meta.get("image_hash") == exclude_hash:
                continue
            hit = meta.copy()  # never mutate the shared metadata entry
            hit["similarity_score"] = float(score)
            results.append(hit)
            if len(results) == k:
                break
        return results

    # -- Index 3: PatchCore k-NN scoring --------------------------------------
    def score_patches(self,
                      patches: np.ndarray,
                      category: str,
                      k: int = 1) -> tuple:
        """
        Score patch features against the category's PatchCore coreset.

        Args:
            patches: [784, 256] patch features (784 = a 28 x 28 patch grid).
            category: which Index 3 coreset to search (lazy-loaded).
            k: neighbours per patch to return in ``nn_distances``. At least 5
               are always retrieved because ``score_std`` uses the top 5.
        Returns:
            image_score: float - max nearest-neighbour patch distance
                (anomaly score)
            patch_scores: [28, 28] numpy array of per-patch NN distances
            score_std: float - std of the top-5 distances at the most
                anomalous patch (confidence interval)
            nn_distances: [784, n] raw k-NN distances, n = max(k, 5) capped
                at the coreset size
            Distances are squared L2 if Index 3 is an IndexFlatL2 as the
            module header states.
        Raises:
            FileNotFoundError: if the Index 3 file for ``category`` is missing.
            RuntimeError: if that coreset index is empty.
        """
        index3 = self._load_index3(category)
        if index3.ntotal == 0:
            raise RuntimeError(f"Index 3 for {category} is empty")
        # Never request more neighbours than the coreset holds: FAISS would pad
        # the extra columns with sentinel values (label -1), corrupting score_std.
        n_neighbors = min(max(k, 5), index3.ntotal)
        patches_f32 = np.ascontiguousarray(patches, dtype=np.float32)
        dists, _ = index3.search(patches_f32, k=n_neighbors)

        # Primary score: nearest-neighbour distance per patch
        nn_dist = dists[:, 0]
        patch_scores = nn_dist.reshape(28, 28)
        image_score = float(patch_scores.max())

        # Confidence interval: std of the top-5 distances at the most
        # anomalous patch
        max_patch_idx = int(np.argmax(nn_dist))
        score_std = float(np.std(dists[max_patch_idx, :5]))
        return image_score, patch_scores, score_std, dists

    def get_status(self) -> dict:
        """Return index sizes and lazy-load state for the /health endpoint."""
        return {
            "index1_vectors": self.index1.ntotal if self.index1 is not None else 0,
            "index2_vectors": self.index2.ntotal if self.index2 is not None else 0,
            "index3_loaded_categories": list(self.index3_cache.keys()),
            "index3_total_categories": len(CATEGORIES)
        }
# Process-wide singleton. Construction only records the data directory (no
# disk access); the indexes are loaded later from api/startup.py
# (presumably via retriever.load_indexes() - see that module).
retriever = FAISSRetriever(data_dir=DATA_DIR)