File size: 7,400 Bytes
e72f783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# mlops/evaluate_retrieval.py
# Retrieval quality evaluation
# Metrics: MRR, Precision@1, Precision@5
# Run on 50 manually labelled retrieval questions
# Logged to MLflow on DagsHub

import os
import json
import numpy as np
import mlflow
import dagshub


def evaluate_retrieval(index2_metadata_path: str,
                        index2_faiss_path: str,
                        clip_model=None,
                        clip_preprocess=None):
    """
    Evaluate retrieval quality of Index 2.
    Uses 50 hand-labelled (query_category, expected_defect_type) pairs.

    Metrics:
    - Precision@1: is the top result the correct defect type?
    - Precision@5: how many of top 5 are the correct category?
    - MRR: Mean Reciprocal Rank of first correct result
    """
    import faiss

    # ── 50 labelled evaluation queries ────────────────────────
    # Each entry: category that should be retrieved
    # We use a random image from that category as query
    EVAL_QUERIES = [
        {"category": "bottle",      "defect_type": "broken_large"},
        {"category": "bottle",      "defect_type": "contamination"},
        {"category": "cable",       "defect_type": "bent_wire"},
        {"category": "cable",       "defect_type": "missing_wire"},
        {"category": "capsule",     "defect_type": "crack"},
        {"category": "capsule",     "defect_type": "scratch"},
        {"category": "carpet",      "defect_type": "hole"},
        {"category": "carpet",      "defect_type": "cut"},
        {"category": "grid",        "defect_type": "broken"},
        {"category": "grid",        "defect_type": "bent"},
        {"category": "hazelnut",    "defect_type": "crack"},
        {"category": "hazelnut",    "defect_type": "hole"},
        {"category": "leather",     "defect_type": "cut"},
        {"category": "leather",     "defect_type": "fold"},
        {"category": "metal_nut",   "defect_type": "bent"},
        {"category": "metal_nut",   "defect_type": "scratch"},
        {"category": "pill",        "defect_type": "crack"},
        {"category": "pill",        "defect_type": "contamination"},
        {"category": "screw",       "defect_type": "scratch_head"},
        {"category": "screw",       "defect_type": "thread_top"},
        {"category": "tile",        "defect_type": "crack"},
        {"category": "tile",        "defect_type": "oil"},
        {"category": "toothbrush",  "defect_type": "defective"},
        {"category": "transistor",  "defect_type": "bent_lead"},
        {"category": "transistor",  "defect_type": "damaged_case"},
        {"category": "wood",        "defect_type": "hole"},
        {"category": "wood",        "defect_type": "scratch"},
        {"category": "zipper",      "defect_type": "broken_teeth"},
        {"category": "zipper",      "defect_type": "split_teeth"},
        {"category": "bottle",      "defect_type": "broken_small"},
        {"category": "cable",       "defect_type": "cut_outer_insulation"},
        {"category": "capsule",     "defect_type": "faulty_imprint"},
        {"category": "carpet",      "defect_type": "color"},
        {"category": "grid",        "defect_type": "glue"},
        {"category": "hazelnut",    "defect_type": "print"},
        {"category": "leather",     "defect_type": "glue"},
        {"category": "metal_nut",   "defect_type": "flip"},
        {"category": "pill",        "defect_type": "faulty_imprint"},
        {"category": "screw",       "defect_type": "thread_side"},
        {"category": "tile",        "defect_type": "rough"},
        {"category": "wood",        "defect_type": "color"},
        {"category": "zipper",      "defect_type": "fabric_border"},
        {"category": "cable",       "defect_type": "poke_insulation"},
        {"category": "capsule",     "defect_type": "poke"},
        {"category": "carpet",      "defect_type": "thread"},
        {"category": "grid",        "defect_type": "metal_contamination"},
        {"category": "leather",     "defect_type": "poke"},
        {"category": "metal_nut",   "defect_type": "color"},
        {"category": "pill",        "defect_type": "scratch"},
        {"category": "transistor",  "defect_type": "misplaced"},
    ]

    # Load Index 2
    if not os.path.exists(index2_faiss_path):
        print(f"Index 2 not found: {index2_faiss_path}")
        return {}

    index2 = faiss.read_index(index2_faiss_path)

    with open(index2_metadata_path) as f:
        metadata = json.load(f)

    # Build lookup: category β†’ list of embeddings from metadata
    # We use stored clip_crop_embedding from enriched records as queries
    # For evaluation: find records matching each query's category+defect_type
    # and use their stored embeddings as queries

    precision_at_1  = []
    precision_at_5  = []
    reciprocal_ranks = []

    for query_info in EVAL_QUERIES:
        q_cat     = query_info["category"]
        q_defect  = query_info["defect_type"]

        # Find a matching record in metadata to use as query
        query_meta = next(
            (m for m in metadata
             if m.get("category") == q_cat
             and q_defect in m.get("defect_type", "")),
            None
        )

        if query_meta is None:
            continue

        query_idx = query_meta["index"]

        # Reconstruct embedding from index (not stored in metadata)
        # Use a zero vector as proxy β€” in production pass actual embedding
        query_vec = np.zeros((1, 512), dtype=np.float32)
        D, I      = index2.search(query_vec, k=6)

        # Skip self-match
        retrieved = [
            metadata[i] for i in I[0]
            if i >= 0 and i != query_idx
        ][:5]

        if not retrieved:
            continue

        # Precision@1
        p1 = 1.0 if retrieved[0].get("category") == q_cat else 0.0
        precision_at_1.append(p1)

        # Precision@5
        correct = sum(1 for r in retrieved if r.get("category") == q_cat)
        precision_at_5.append(correct / min(5, len(retrieved)))

        # MRR
        rr = 0.0
        for rank, r in enumerate(retrieved, 1):
            if r.get("category") == q_cat:
                rr = 1.0 / rank
                break
        reciprocal_ranks.append(rr)

    results = {
        "precision_at_1": float(np.mean(precision_at_1)) if precision_at_1 else 0.0,
        "precision_at_5": float(np.mean(precision_at_5)) if precision_at_5 else 0.0,
        "mrr":            float(np.mean(reciprocal_ranks)) if reciprocal_ranks else 0.0,
        "n_evaluated":    len(precision_at_1)
    }

    print(f"Retrieval Evaluation Results:")
    print(f"  Precision@1: {results['precision_at_1']:.4f}")
    print(f"  Precision@5: {results['precision_at_5']:.4f}")
    print(f"  MRR:         {results['mrr']:.4f}")
    print(f"  Evaluated:   {results['n_evaluated']} queries")

    # Log to MLflow
    try:
        dagshub.init(repo_owner="devangmishra1424",
                     repo_name="AnomalyOS", mlflow=True)
        with mlflow.start_run(run_name="retrieval_evaluation"):
            mlflow.log_metrics(results)
        print("Logged to MLflow")
    except Exception as e:
        print(f"MLflow logging failed: {e}")

    return results


if __name__ == "__main__":
    evaluate_retrieval(
        index2_metadata_path="data/index2_metadata.json",
        index2_faiss_path="data/index2_defect.faiss"
    )