# mlops/evaluate_retrieval.py
# Retrieval quality evaluation
# Metrics: MRR, Precision@1, Precision@5
# Run on 50 manually labelled retrieval questions
# Logged to MLflow on DagsHub
import os
import json
import numpy as np
import mlflow
import dagshub
def evaluate_retrieval(index2_metadata_path: str,
                       index2_faiss_path: str,
                       clip_model=None,
                       clip_preprocess=None):
    """
    Evaluate retrieval quality of Index 2 (the defect-crop FAISS index).

    Uses 50 hand-labelled (category, defect_type) pairs. For each pair we
    locate one matching record in the index metadata and use its *stored*
    embedding (reconstructed from the FAISS index itself) as the query.

    Metrics:
        - Precision@1: is the top (non-self) result the correct category?
        - Precision@5: fraction of top 5 results in the correct category
        - MRR: Mean Reciprocal Rank of the first correct result

    Args:
        index2_metadata_path: path to the JSON list of per-vector metadata
            records; each record is expected to carry at least "index",
            "category" and "defect_type" keys.
        index2_faiss_path: path to the serialized FAISS index.
        clip_model, clip_preprocess: accepted for interface compatibility;
            not used because queries come from stored embeddings.

    Returns:
        dict with keys "precision_at_1", "precision_at_5", "mrr",
        "n_evaluated"; empty dict if the index file is missing.
    """
    import faiss

    # ββ 50 labelled evaluation queries ββββββββββββββββββββββββ
    # Each entry names the category/defect that should be retrieved.
    EVAL_QUERIES = [
        {"category": "bottle", "defect_type": "broken_large"},
        {"category": "bottle", "defect_type": "contamination"},
        {"category": "cable", "defect_type": "bent_wire"},
        {"category": "cable", "defect_type": "missing_wire"},
        {"category": "capsule", "defect_type": "crack"},
        {"category": "capsule", "defect_type": "scratch"},
        {"category": "carpet", "defect_type": "hole"},
        {"category": "carpet", "defect_type": "cut"},
        {"category": "grid", "defect_type": "broken"},
        {"category": "grid", "defect_type": "bent"},
        {"category": "hazelnut", "defect_type": "crack"},
        {"category": "hazelnut", "defect_type": "hole"},
        {"category": "leather", "defect_type": "cut"},
        {"category": "leather", "defect_type": "fold"},
        {"category": "metal_nut", "defect_type": "bent"},
        {"category": "metal_nut", "defect_type": "scratch"},
        {"category": "pill", "defect_type": "crack"},
        {"category": "pill", "defect_type": "contamination"},
        {"category": "screw", "defect_type": "scratch_head"},
        {"category": "screw", "defect_type": "thread_top"},
        {"category": "tile", "defect_type": "crack"},
        {"category": "tile", "defect_type": "oil"},
        {"category": "toothbrush", "defect_type": "defective"},
        {"category": "transistor", "defect_type": "bent_lead"},
        {"category": "transistor", "defect_type": "damaged_case"},
        {"category": "wood", "defect_type": "hole"},
        {"category": "wood", "defect_type": "scratch"},
        {"category": "zipper", "defect_type": "broken_teeth"},
        {"category": "zipper", "defect_type": "split_teeth"},
        {"category": "bottle", "defect_type": "broken_small"},
        {"category": "cable", "defect_type": "cut_outer_insulation"},
        {"category": "capsule", "defect_type": "faulty_imprint"},
        {"category": "carpet", "defect_type": "color"},
        {"category": "grid", "defect_type": "glue"},
        {"category": "hazelnut", "defect_type": "print"},
        {"category": "leather", "defect_type": "glue"},
        {"category": "metal_nut", "defect_type": "flip"},
        {"category": "pill", "defect_type": "faulty_imprint"},
        {"category": "screw", "defect_type": "thread_side"},
        {"category": "tile", "defect_type": "rough"},
        {"category": "wood", "defect_type": "color"},
        {"category": "zipper", "defect_type": "fabric_border"},
        {"category": "cable", "defect_type": "poke_insulation"},
        {"category": "capsule", "defect_type": "poke"},
        {"category": "carpet", "defect_type": "thread"},
        {"category": "grid", "defect_type": "metal_contamination"},
        {"category": "leather", "defect_type": "poke"},
        {"category": "metal_nut", "defect_type": "color"},
        {"category": "pill", "defect_type": "scratch"},
        {"category": "transistor", "defect_type": "misplaced"},
    ]

    # Load Index 2; bail out gracefully when the artifact is absent.
    if not os.path.exists(index2_faiss_path):
        print(f"Index 2 not found: {index2_faiss_path}")
        return {}
    index2 = faiss.read_index(index2_faiss_path)
    with open(index2_metadata_path) as f:
        metadata = json.load(f)

    precision_at_1 = []
    precision_at_5 = []
    reciprocal_ranks = []

    for query_info in EVAL_QUERIES:
        q_cat = query_info["category"]
        q_defect = query_info["defect_type"]
        # Find a matching record in metadata to use as the query.
        # Substring match on defect_type tolerates compound labels.
        query_meta = next(
            (m for m in metadata
             if m.get("category") == q_cat
             and q_defect in m.get("defect_type", "")),
            None
        )
        if query_meta is None:
            continue
        query_idx = query_meta["index"]
        # FIX: previously a zero vector was used as the query, making every
        # search identical and the metrics meaningless. Reconstruct the
        # stored embedding for this record directly from the FAISS index.
        try:
            query_vec = index2.reconstruct(int(query_idx))
        except RuntimeError:
            # Index type may not support reconstruction, or idx out of range.
            continue
        query_vec = np.asarray(query_vec, dtype=np.float32).reshape(1, -1)
        # k=6: one extra neighbour so the self-match can be dropped.
        D, I = index2.search(query_vec, k=6)
        # Skip the self-match and any invalid (-1 / out-of-range) ids.
        retrieved = [
            metadata[i] for i in I[0]
            if 0 <= i < len(metadata) and i != query_idx
        ][:5]
        if not retrieved:
            continue
        # Precision@1
        p1 = 1.0 if retrieved[0].get("category") == q_cat else 0.0
        precision_at_1.append(p1)
        # Precision@5 (normalized by how many results we actually got)
        correct = sum(1 for r in retrieved if r.get("category") == q_cat)
        precision_at_5.append(correct / min(5, len(retrieved)))
        # MRR: reciprocal rank of first correct-category hit, 0 if none.
        rr = 0.0
        for rank, r in enumerate(retrieved, 1):
            if r.get("category") == q_cat:
                rr = 1.0 / rank
                break
        reciprocal_ranks.append(rr)

    results = {
        "precision_at_1": float(np.mean(precision_at_1)) if precision_at_1 else 0.0,
        "precision_at_5": float(np.mean(precision_at_5)) if precision_at_5 else 0.0,
        "mrr": float(np.mean(reciprocal_ranks)) if reciprocal_ranks else 0.0,
        "n_evaluated": len(precision_at_1)
    }
    print("Retrieval Evaluation Results:")
    print(f"  Precision@1: {results['precision_at_1']:.4f}")
    print(f"  Precision@5: {results['precision_at_5']:.4f}")
    print(f"  MRR: {results['mrr']:.4f}")
    print(f"  Evaluated: {results['n_evaluated']} queries")

    # Log to MLflow (best-effort: evaluation results are still returned
    # even if the remote tracking server is unreachable).
    try:
        dagshub.init(repo_owner="devangmishra1424",
                     repo_name="AnomalyOS", mlflow=True)
        with mlflow.start_run(run_name="retrieval_evaluation"):
            mlflow.log_metrics(results)
        print("Logged to MLflow")
    except Exception as e:
        print(f"MLflow logging failed: {e}")
    return results
if __name__ == "__main__":
    # Script entry point: run the retrieval evaluation against the
    # default artifact locations under data/.
    metadata_path = "data/index2_metadata.json"
    faiss_path = "data/index2_defect.faiss"
    evaluate_retrieval(
        index2_metadata_path=metadata_path,
        index2_faiss_path=faiss_path,
    )