File size: 21,162 Bytes
bb5f7e2
 
 
b81c9f2
bb5f7e2
 
 
 
 
 
 
 
 
 
 
 
 
b81c9f2
bb5f7e2
 
 
 
 
 
 
cd30a0d
 
bb5f7e2
 
 
 
 
 
 
 
91771b8
 
 
 
 
 
 
bb5f7e2
91771b8
 
bb5f7e2
 
 
 
b81c9f2
 
bb5f7e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91771b8
2e6a9a9
91771b8
 
bb5f7e2
 
 
 
 
 
 
 
cd30a0d
bb5f7e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd30a0d
 
 
 
 
 
 
 
 
bb5f7e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db2a1fe
bb5f7e2
 
 
2740843
bb5f7e2
 
 
 
 
 
 
 
 
 
 
db2a1fe
bb5f7e2
 
 
 
 
 
2740843
bb5f7e2
 
 
 
2e6a9a9
bb5f7e2
 
 
2740843
bb5f7e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2740843
bb5f7e2
 
 
 
 
 
db2a1fe
bb5f7e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2740843
bb5f7e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd30a0d
b81c9f2
bb5f7e2
cd30a0d
db2a1fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd30a0d
 
 
 
 
db2a1fe
cd30a0d
 
db2a1fe
 
 
 
 
 
 
 
 
 
bb5f7e2
2740843
bb5f7e2
 
 
 
 
db2a1fe
 
bb5f7e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db2a1fe
bb5f7e2
 
 
 
db2a1fe
bb5f7e2
 
2740843
bb5f7e2
 
 
 
 
2740843
91771b8
bb5f7e2
2740843
bb5f7e2
 
 
db2a1fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
"""
================================================================================
  SimCLR ResNet-50 Visual Search Engine GUI
--------------------------------------------------------------------------------
  A premium, interactive web-based graphical user interface (GUI) built with
  Streamlit, powered by ONNX Runtime and FAISS for real-time visual similarity
  retrieval on CIFAR-10.

  Key Features:
    1. Premium Dark Mode UI with tailored CSS glassmorphism styling
    2. Upload any local image or select a random test image from CIFAR-10
    3. Lightning-fast feature extraction (2048-d) via optimized ONNX Runtime
    4. Sub-millisecond exact Cosine Similarity search using FAISS vector database
    5. Clean visual results grid with similarity scores and class matches
    6. Dedicated interactive "Ablation Study Dashboard" showing the Exp 38-42 findings

  Usage:
    streamlit run src/streamlit_app.py
================================================================================
"""
import os
import sys
import time
import json
import random
import base64
from io import BytesIO
import numpy as np
from PIL import Image
import streamlit as st

# Windows console encoding fix: force UTF-8 on the standard streams so that
# non-ASCII output does not fail on legacy console code pages.
def reconfigure_utf8(stream):
    """
    Switch *stream* to UTF-8 if it supports ``reconfigure``.

    ``sys.stdout`` / ``sys.stderr`` can be ``None`` (e.g. apps started with
    pythonw) or be replaced by objects that have no ``reconfigure`` method
    (e.g. ``io.StringIO``). Such streams are left untouched instead of raising
    AttributeError at import time.
    """
    reconfigure = getattr(stream, "reconfigure", None)
    if reconfigure is not None:
        reconfigure(encoding="utf-8")


reconfigure_utf8(sys.stdout)
reconfigure_utf8(sys.stderr)

# ==========================================================================
# 1. Robust Path Resolution (Works for Local & Hugging Face Spaces)
# ==========================================================================
def resolve_root_dir(script_path=None):
    """
    Return the repository root, i.e. the directory that holds ``deployment/``.

    The current working directory is preferred, because Hugging Face Spaces
    launches the app from the repository root. If the CWD has no
    ``deployment/`` folder (e.g. the app was started from another directory),
    the script's parent directory and then the script's own directory are
    tried. If none match, the CWD is returned so the missing-asset
    diagnostics still report where the app looked.

    Args:
        script_path: path of this script; defaults to ``__file__``.

    Returns:
        Absolute path of the chosen root directory.
    """
    cwd = os.getcwd()
    if os.path.isdir(os.path.join(cwd, "deployment")):
        return cwd
    script_dir = os.path.dirname(os.path.abspath(script_path or __file__))
    for candidate in (os.path.dirname(script_dir), script_dir):
        if os.path.isdir(os.path.join(candidate, "deployment")):
            return candidate
    return cwd


ROOT_DIR = resolve_root_dir()
SRC_DIR = os.path.join(ROOT_DIR, "src")

# Ensure we can import from src/
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

# ==========================================================================
# Deployment Asset Paths & Normalization Constants
# ==========================================================================
DEPLOY_DIR = os.path.join(ROOT_DIR, "deployment")

# Exported encoder, vector index and image catalog, all under deployment/.
ONNX_PATH, INDEX_PATH, META_PATH = (
    os.path.join(DEPLOY_DIR, asset_name)
    for asset_name in ("simclr_encoder_exp41.onnx", "cifar10_index.faiss", "metadata.json")
)

# Per-channel (R, G, B) normalization statistics applied in preprocess_image.
CIFAR_MEAN = np.asarray((0.4914, 0.4822, 0.4465), dtype=np.float32)
CIFAR_STD = np.asarray((0.2023, 0.1994, 0.2010), dtype=np.float32)

@st.cache_resource
def load_onnx_model():
    """Create the ONNX Runtime session for the encoder; cached by Streamlit."""
    from onnxruntime import InferenceSession
    # Single-image inference is cheap on CPU, and the CPU provider keeps the
    # app portable across hosts.
    return InferenceSession(ONNX_PATH, providers=["CPUExecutionProvider"])

@st.cache_resource
def load_faiss_index():
    """Read the FAISS vector index from disk; cached by Streamlit."""
    import faiss as faiss_lib
    index = faiss_lib.read_index(INDEX_PATH)
    return index

@st.cache_data
def load_metadata():
    """Parse the image catalog JSON (UTF-8) and return it; cached by Streamlit."""
    with open(META_PATH, encoding="utf-8") as handle:
        catalog = json.load(handle)
    return catalog

# -- Check that deployment files exist ---------------------------------------
_REQUIRED_ASSETS = (
    (ONNX_PATH, "ONNX Model"),
    (INDEX_PATH, "FAISS Index"),
    (META_PATH, "Metadata JSON"),
)
missing_files = [label for asset_path, label in _REQUIRED_ASSETS if not os.path.exists(asset_path)]

if missing_files:
    # Halt the page and print the exact directories searched, which makes
    # misplaced assets easy to diagnose on a remote deployment.
    st.error(f" Missing deployment assets: {', '.join(missing_files)}")
    st.warning(" Please ensure the 'deployment' folder contains the required ONNX, FAISS, and JSON files.")
    st.code(f"Current Root Directory: {ROOT_DIR}\nLooking in: {DEPLOY_DIR}", language="bash")
    st.stop()

# Load the model, index and catalog through the cached loaders.
ort_sess, faiss_index, metadata = load_onnx_model(), load_faiss_index(), load_metadata()

# ==========================================================================
# Preprocessing & Helper Logic
# ==========================================================================
def preprocess_image(pil_img):
    """
    Convert a PIL image into the encoder's input tensor.

    The image is forced to 3-channel RGB, resized to 32x32 with bilinear
    resampling, scaled to [0, 1] and normalized with the CIFAR-10
    per-channel mean/std.

    Args:
        pil_img: a PIL image in any mode (RGB, RGBA, "L", "P", "LA", ...).

    Returns:
        C-contiguous float32 numpy array of shape (1, 3, 32, 32).
    """
    # Normalize the colour mode *before* resizing. Inferring channels from the
    # array shape afterwards treats palette ("P") indices as intensities and
    # fails on two-channel "LA" images.
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    # 1. Resize to 32x32
    img_resized = pil_img.resize((32, 32), Image.Resampling.BILINEAR)

    # 2. Convert to an (H, W, 3) float array scaled to [0, 1]
    img_np = np.asarray(img_resized, dtype=np.float32) / 255.0

    # 3. Normalize: (x - mean) / std
    img_normalized = (img_np - CIFAR_MEAN) / CIFAR_STD

    # 4. HWC -> CHW and add a batch dimension -> (1, C, H, W). The transpose
    #    yields a strided view, so materialize a contiguous buffer for the runtime.
    img_batch = np.transpose(img_normalized, (2, 0, 1))[np.newaxis, ...]
    return np.ascontiguousarray(img_batch, dtype=np.float32)

def perform_search(features_batch, top_k=5):
    """
    L2-normalize a query embedding, search the FAISS index and resolve hits.

    Args:
        features_batch: array of shape (1, D) holding the query embedding;
            D must match the index dimension. Only the first row's hits are
            returned.
        top_k: number of neighbours to request from the index.

    Returns:
        List of dicts with keys rank, id, similarity, image_path, class_name
        and label_id, best match first. The list can be shorter than top_k
        when the index holds fewer than top_k vectors.
    """
    # FAISS expects a C-contiguous float32 matrix.
    query = np.ascontiguousarray(features_batch, dtype=np.float32)

    # 1. L2-normalize so the inner-product score behaves as cosine similarity
    #    (assumes the indexed vectors were normalized the same way at build time).
    norm = np.linalg.norm(query, axis=1, keepdims=True)
    query = query / np.maximum(norm, 1e-12)

    # 2. Search FAISS Index
    similarities, indices = faiss_index.search(query, top_k)

    catalog = metadata["images"]
    results = []
    for score, raw_id in zip(similarities[0], indices[0]):
        vector_id = int(raw_id)
        # FAISS pads with -1 when fewer than top_k neighbours exist. Indexing
        # the catalog with -1 would silently return its last entry.
        if vector_id < 0:
            continue
        match_info = catalog[vector_id]
        results.append({
            "rank": len(results) + 1,
            "id": vector_id,
            "similarity": float(score),
            "image_path": os.path.join(DEPLOY_DIR, match_info["image_path"]),
            "class_name": match_info["class_name"],
            "label_id": match_info["label_id"],
        })

    return results

def get_image_base64(image_path):
    """
    Return the image at *image_path* re-encoded as PNG and base64-encoded.

    The result is meant for an inline ``data:image/png;base64,...`` URI, so
    the image is always re-encoded to PNG whatever its on-disk format.
    Returns an empty string when the path is not a regular file or cannot be
    decoded as an image, so the caller can show its "missing" placeholder.
    """
    # isfile (not exists): a directory would otherwise reach Image.open and raise.
    if not os.path.isfile(image_path):
        return ""
    buffered = BytesIO()
    try:
        # The context manager releases the file handle even if encoding fails.
        with Image.open(image_path) as img:
            img.save(buffered, format="PNG")
    except OSError:
        # Includes PIL.UnidentifiedImageError for corrupt/unsupported files.
        return ""
    return base64.b64encode(buffered.getvalue()).decode("ascii")

# ==========================================================================
# Streamlit UI Configuration
# ==========================================================================
_PAGE_CONFIG = {
    "page_title": "SimCLR ResNet-50 Visual Search",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)

# Global stylesheet: dark app background, gradient title text, and the
# ".glass-card" / ".metric-*" classes used by the metric tiles in tab 1.
_CUSTOM_CSS = """
<style>
    /* Dark theme customizations */
    .stApp {
        background-color: #0F172A;
        color: #E2E8F0;
    }
    
    /* Title and Header designs */
    h1 {
        background: linear-gradient(90deg, #38BDF8, #818CF8);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        font-family: 'Outfit', sans-serif;
        font-weight: 800;
    }
    
    /* Custom Card/Glassmorphism block */
    .glass-card {
        background: rgba(30, 41, 59, 0.7);
        border: 1px solid rgba(255, 255, 255, 0.05);
        border-radius: 12px;
        padding: 20px;
        margin-bottom: 20px;
        box-shadow: 0 4px 30px rgba(0, 0, 0, 0.2);
        backdrop-filter: blur(5px);
    }
    
    /* Metric styling */
    .metric-value {
        font-size: 2.2rem;
        font-weight: 700;
        color: #38BDF8;
    }
    .metric-label {
        font-size: 0.9rem;
        color: #94A3B8;
        text-transform: uppercase;
        letter-spacing: 0.05em;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# ==========================================================================
# Sidebar Configuration
# ==========================================================================
st.sidebar.markdown("### Engine Settings")
# Number of neighbours requested from FAISS for each query (used in tab 1).
top_k_slider = st.sidebar.slider("Number of results (K)", min_value=3, max_value=20, value=6, step=1)

st.sidebar.markdown("---")
st.sidebar.markdown("### Trained Backbone Model")
# The gain is a difference of two accuracies (84.30 - 64.49), i.e. percentage
# points, matching the "pp" unit used in the ablation dashboard.
st.sidebar.markdown("""
* **Architecture**: ResNet-50 (CIFAR STEM adjusted)
* **Training Type**: Self-Supervised (SimCLR)
* **Best Experiment**: **Exp 41**
* **Augmentations**: Crop + Flip + Blur + Color Jitter
* **Midterm Baseline**: 64.49%
* **Final Linear Probe Accuracy**: **84.30%** (+19.81 pp Gain!)
* **Embedding Dimension**: 2048
""")

st.sidebar.markdown("---")
st.sidebar.markdown("### Backend Pipeline")
# Report the size of the index actually loaded rather than a hard-coded count.
st.sidebar.markdown(f"""
* **Inference Engine**: ONNX Runtime CPU
* **Vector Database**: FAISS (IndexFlatIP)
* **Similarity Metric**: Exact Cosine Similarity
* **Database Size**: {faiss_index.ntotal:,} Images
""")
st.sidebar.caption("Mahmoud Alyosify - Natalie Monged & Mirna Embaby, CISC 867, Queen's University, Spring 2026")

# ==========================================================================
# Main Title
# ==========================================================================
_HEADLINE = "Real-time SimCLR Image Retrieval Engine"
_SUBHEADLINE = "##### Self-Supervised Representation Learning with ResNet-50 & FAISS Indexing"
st.title(_HEADLINE)
st.markdown(_SUBHEADLINE)

# Two views: the live search engine and the ablation-study results.
tab1, tab2 = st.tabs(["Real-Time Search", "Ablation Study Dashboard"])

# ==========================================================================
# TAB 1: Visual Search Engine
# ==========================================================================
def _result_card_html(res, query_info):
    """
    Build the single-line HTML card shown for one retrieved neighbour.

    Args:
        res: one result dict from perform_search (reads rank, id, similarity,
            image_path and class_name).
        query_info: catalog entry of the query image, or None for uploads.

    When query_info is given, the card is tinted green if the neighbour's
    class equals the query's class and red otherwise. Without a query class,
    a neutral style is used.
    """
    class_name = res["class_name"]
    border_color = "rgba(255, 255, 255, 0.05)"
    bg_color = "rgba(30, 41, 59, 0.7)"
    badge_color = "#475569"
    badge_label = class_name
    if query_info is not None:
        if class_name == query_info["class_name"]:
            border_color = "rgba(16, 185, 129, 0.3)"
            bg_color = "rgba(6, 78, 59, 0.3)"
            badge_color = "#10B981"
            badge_label = f"{class_name} (MATCH)"
        else:
            border_color = "rgba(239, 68, 68, 0.2)"
            bg_color = "rgba(127, 29, 29, 0.1)"
            badge_color = "#EF4444"
    badge_html = (
        f'<span style="background-color: {badge_color}; padding: 2px 8px; border-radius: 4px; '
        f'font-size: 0.8rem; font-weight: 600; color: white;">{badge_label}</span>'
    )

    # Inline the image as a base64 data URI so it renders inside the HTML card.
    img_b64 = get_image_base64(res["image_path"])
    if img_b64:
        img_html = f'<img src="data:image/png;base64,{img_b64}" style="width: 100%; border-radius: 8px; margin-bottom: 12px; display: block;">'
    else:
        img_html = '<div style="color: #EF4444; padding: 20px; border: 1px solid #EF4444; border-radius: 8px; margin-bottom: 12px;">Image missing</div>'

    # Emitted as one unindented line so Markdown cannot mistake indented HTML
    # for a code block.
    return "".join([
        f'<div style="background: {bg_color}; border: 1px solid {border_color}; border-radius: 8px; padding: 12px; margin-bottom: 12px; text-align: center; height: 100%;">',
        img_html,
        f'<div style="color: #38BDF8; font-size: 1.1rem; font-weight: 700; margin-bottom: 8px;">Rank #{res["rank"]}</div>',
        f'<div style="margin-bottom: 10px;">{badge_html}</div>',
        '<div style="color: #94A3B8; font-size: 0.9rem; margin-bottom: 4px;">Cosine Similarity:</div>',
        f'<div style="color: #E2E8F0; font-size: 1.4rem; font-weight: 700; margin-bottom: 10px;">{res["similarity"] * 100:.2f}%</div>',
        f'<div style="color: #64748B; font-size: 0.8rem;">Vector ID: {res["id"]}</div>',
        '</div>',
    ])


with tab1:
    col_left, col_right = st.columns([1, 2], gap="large")

    # Catalog index chosen via the "random" button; kept in session state so
    # it survives Streamlit reruns.
    if "random_img_id" not in st.session_state:
        st.session_state.random_img_id = None

    query_image = None
    query_info = None  # catalog entry of the query; only known for database picks

    with col_left:
        st.markdown("### Select Query Image")

        # Method 1: Upload a file
        uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "png", "jpeg"])

        # Method 2: Pick a random image from the database
        st.markdown("<p style='text-align: center; color: #94A3B8; margin: 10px 0;'>— OR —</p>", unsafe_allow_html=True)
        if st.button("Pick a Random Test Image from Database", use_container_width=True):
            st.session_state.random_img_id = random.randrange(len(metadata["images"]))

        # An uploaded file always takes precedence over a random pick.
        if uploaded_file is not None:
            st.session_state.random_img_id = None  # Clear random image
            try:
                query_image = Image.open(uploaded_file).convert("RGB")
            except (OSError, Image.DecompressionBombError) as err:
                # Corrupt/unsupported uploads get a message instead of a traceback.
                st.error(f"Could not read the uploaded file as an image: {err}")
            else:
                st.image(query_image, caption="Uploaded Query Image", use_container_width=True)

        elif st.session_state.random_img_id is not None:
            idx = st.session_state.random_img_id
            img_info = metadata["images"][idx]
            local_path = os.path.join(DEPLOY_DIR, img_info["image_path"])

            if os.path.exists(local_path):
                query_image = Image.open(local_path).convert("RGB")
                query_info = img_info
                st.image(
                    query_image,
                    caption=f"Selected Reference Image #{idx} (Class: {img_info['class_name']})",
                    use_container_width=True,
                )
            else:
                st.error("Reference image file not found.")

    with col_right:
        st.markdown("### Visual Similarity Search Results")

        if query_image is None:
            st.info("Please upload a custom image or click the button to select a random test image to query the visual search engine!")
        else:
            # perf_counter is monotonic and high-resolution; time.time() can be
            # too coarse on some platforms for sub-millisecond measurements.
            t_start = time.perf_counter()

            # 1. Preprocess
            batch_img = preprocess_image(query_image)

            # 2. ONNX Inference: feed the graph's declared input name instead
            #    of assuming it.
            input_name = ort_sess.get_inputs()[0].name
            feature_vector = ort_sess.run(None, {input_name: batch_img})[0]  # expected (1, 2048)

            # 3. FAISS Query
            t_search_start = time.perf_counter()
            results = perform_search(feature_vector, top_k=top_k_slider)
            t_end = time.perf_counter()

            t_total_ms = (t_end - t_start) * 1000
            t_search_ms = (t_end - t_search_start) * 1000

            # Show Metrics
            m_col1, m_col2, m_col3 = st.columns(3)
            with m_col1:
                st.markdown(f'<div class="glass-card"><div class="metric-value">{t_search_ms:.2f} ms</div><div class="metric-label">FAISS Query Time</div></div>', unsafe_allow_html=True)
            with m_col2:
                st.markdown(f'<div class="glass-card"><div class="metric-value">{t_total_ms:.1f} ms</div><div class="metric-label">Total Pipeline Latency</div></div>', unsafe_allow_html=True)
            with m_col3:
                # Class precision: share of neighbours with the query's class.
                if query_info is None:
                    st.markdown('<div class="glass-card"><div class="metric-value">N/A</div><div class="metric-label">Upload class unknown</div></div>', unsafe_allow_html=True)
                elif not results:
                    # Guard against division by zero when nothing was returned.
                    st.markdown('<div class="glass-card"><div class="metric-value">N/A</div><div class="metric-label">No neighbors returned</div></div>', unsafe_allow_html=True)
                else:
                    matches = sum(r["class_name"] == query_info["class_name"] for r in results)
                    precision = matches / len(results) * 100
                    st.markdown(f'<div class="glass-card"><div class="metric-value">{precision:.1f}%</div><div class="metric-label">Top-{len(results)} Class Precision</div></div>', unsafe_allow_html=True)

            # Results grid: rows of three cards.
            st.markdown("#### Retrieved Nearest Neighbors")
            for row_start in range(0, len(results), 3):
                row_cols = st.columns(3)
                for col, res in zip(row_cols, results[row_start:row_start + 3]):
                    with col:
                        st.markdown(_result_card_html(res, query_info), unsafe_allow_html=True)

# ==========================================================================
# TAB 2: Ablation Study Dashboard
# ==========================================================================
with tab2:
    st.markdown("### Experiment Ablation Study & Color Jitter Findings")
    st.markdown("##### Evaluating Midterm Pipelines against Jittered Re-runs")

    st.markdown("""
    This section presents the results of the complete **Color Jitter Shortcut-Learning Ablation Study** (Experiments 38-42).
    By comparing the self-supervised representations learned with and without photometric distortion (Color Jitter), we demonstrate how SimCLR ResNet-50 encoders learn rich semantic representations instead of exploiting simple pixel-level color shortcuts.
    """)

    # Each entry compares a jittered re-run with its non-jittered baseline;
    # "gain" is jitter minus base, in percentage points.
    ablation_data = [
        {"exp": "Exp 38 vs 36", "desc": "Pure Discrete Rotation", "base": "34.40%", "jitter": "51.21%", "gain": "+16.81 pp"},
        {"exp": "Exp 39 vs 35", "desc": "Weak Spatial Baseline", "base": "59.22%", "jitter": "80.53%", "gain": "+21.31 pp"},
        {"exp": "Exp 40 vs 9",  "desc": "Crop + Gaussian Blur", "base": "63.01%", "jitter": "80.65%", "gain": "+17.64 pp"},
        {"exp": "Exp 41 vs 13", "desc": "Crop + Flip + Blur", "base": "64.49%", "jitter": "84.30%", "gain": "+19.81 pp"},
        {"exp": "Exp 42 vs 10", "desc": "Crop + Random Cutout", "base": "66.27%", "jitter": "81.21%", "gain": "+14.94 pp"}
    ]

    # One row with one card per experiment pair; the column count follows the data.
    for col, data in zip(st.columns(len(ablation_data)), ablation_data):
        with col:
            card_html = f"""
            <div style="background: rgba(30, 41, 59, 0.5); border: 1px solid rgba(56, 189, 248, 0.2); border-radius: 8px; padding: 15px; text-align: center;">
                <div style="color: #38BDF8; font-size: 1.1rem; font-weight: 700;">{data["exp"]}</div>
                <div style="color: #94A3B8; font-size: 0.85rem; height: 35px; overflow: hidden; margin-top: 4px;">{data["desc"]}</div>
                <hr style="border: 0; border-top: 1px solid rgba(255,255,255,0.05); margin: 8px 0;"/>
                <div style="color: #94A3B8; font-size: 0.8rem;">Without Jitter: <b>{data["base"]}</b></div>
                <div style="color: #10B981; font-size: 0.85rem; font-weight: 600; margin-top: 2px;">With Jitter: {data["jitter"]}</div>
                <div style="background-color: rgba(16, 185, 129, 0.15); color: #10B981; border-radius: 4px; padding: 2px 6px; font-size: 0.9rem; font-weight: 700; margin-top: 8px; display: inline-block;">
                    {data["gain"]} Gain
                </div>
            </div>
            """
            st.markdown(card_html, unsafe_allow_html=True)

    st.markdown("---")
    st.markdown("#### Key Project Takeaways")
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("""
        ##### 1. The Shortcut-Learning Problem
        Without Color Jitter, contrastive self-supervised encoders exploit low-level **color histograms** as a shortcut to maximize mutual information, rather than learning general shapes and semantic features. This results in weaker downstream representations, showing a severe performance limit (e.g. baseline Crop+Flip+Blur achieves only **64.49%**).

        ##### 2. The Color Jitter Shield
        Adding photometric distortion (color jittering) forces the model to ignore color profiles and focus on invariant structures, spatial boundaries, and contours. This single ablation yields a massive average boost of **+18.1 pp** across all settings, pushing our best encoder (Exp 41) to a stellar **84.30% Top-1 accuracy**!
        """)

    with col2:
        st.markdown("""
        ##### 3. Model Architecture & Stem Tuning
        Modifying the standard ResNet-50 conv1 stem from 7x7 (stride 2) to 3x3 (stride 1) and removing the initial MaxPool was crucial to preserve the resolution of 32x32 CIFAR-10 images.

        ##### 4. Near Foundation-Model Upper Bound
        Our zero-shot CLIP ViT-B/32 foundation model evaluation sets the academic upper bound at **88.80%**. Our custom-trained SimCLR ResNet-50 achieves **95% of this performance** (**84.30%**) while using **8,000x less training data**!
        """)

    st.success("Weights & Biases Dashboard has archived all 5 experiments complete with training curves, checkpoints, and t-SNE files.")