File size: 21,162 Bytes
bb5f7e2 b81c9f2 bb5f7e2 b81c9f2 bb5f7e2 cd30a0d bb5f7e2 91771b8 bb5f7e2 91771b8 bb5f7e2 b81c9f2 bb5f7e2 91771b8 2e6a9a9 91771b8 bb5f7e2 cd30a0d bb5f7e2 cd30a0d bb5f7e2 db2a1fe bb5f7e2 2740843 bb5f7e2 db2a1fe bb5f7e2 2740843 bb5f7e2 2e6a9a9 bb5f7e2 2740843 bb5f7e2 2740843 bb5f7e2 db2a1fe bb5f7e2 2740843 bb5f7e2 cd30a0d b81c9f2 bb5f7e2 cd30a0d db2a1fe cd30a0d db2a1fe cd30a0d db2a1fe bb5f7e2 2740843 bb5f7e2 db2a1fe bb5f7e2 db2a1fe bb5f7e2 db2a1fe bb5f7e2 2740843 bb5f7e2 2740843 91771b8 bb5f7e2 2740843 bb5f7e2 db2a1fe | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 | """
================================================================================
SimCLR ResNet-50 Visual Search Engine GUI
--------------------------------------------------------------------------------
A premium, interactive web-based graphical user interface (GUI) built with
Streamlit, powered by ONNX Runtime and FAISS for real-time visual similarity
retrieval on CIFAR-10.
Key Features:
1. Premium Dark Mode UI with tailored CSS glassmorphism styling
2. Upload any local image or select a random test image from CIFAR-10
3. Lightning-fast feature extraction (2048-d) via optimized ONNX Runtime
4. Sub-millisecond exact Cosine Similarity search using FAISS vector database
5. Clean visual results grid with similarity scores and class matches
6. Dedicated interactive "Ablation Study Dashboard" showing the Exp 38-42 findings
Usage:
streamlit run src/streamlit_app.py
================================================================================
"""
import os
import sys
import time
import json
import random
import base64
from io import BytesIO
import numpy as np
from PIL import Image
import streamlit as st
# Windows console encoding fix
sys.stdout.reconfigure(encoding='utf-8')
sys.stderr.reconfigure(encoding='utf-8')
# ==========================================================================
# 1. Robust Path Resolution (Works for Local & Hugging Face Spaces)
# ==========================================================================
# Hugging Face sets the Current Working Directory (CWD) to the repository root
ROOT_DIR = os.getcwd()
SRC_DIR = os.path.join(ROOT_DIR, "src")
# Ensure we can import from src/
if SRC_DIR not in sys.path:
sys.path.append(SRC_DIR)
# ==========================================================================
# Caching Assets for Instant Performance
# ==========================================================================
DEPLOY_DIR = os.path.join(ROOT_DIR, "deployment")
ONNX_PATH = os.path.join(DEPLOY_DIR, "simclr_encoder_exp41.onnx")
INDEX_PATH = os.path.join(DEPLOY_DIR, "cifar10_index.faiss")
META_PATH = os.path.join(DEPLOY_DIR, "metadata.json")
CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
CIFAR_STD = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)
@st.cache_resource
def load_onnx_model():
"""Load ONNX inference session and cache it."""
import onnxruntime as ort
# CPU is fast enough for single image inference, ensures compatibility
providers = ['CPUExecutionProvider']
return ort.InferenceSession(ONNX_PATH, providers=providers)
@st.cache_resource
def load_faiss_index():
"""Load FAISS index and cache it."""
import faiss
return faiss.read_index(INDEX_PATH)
@st.cache_data
def load_metadata():
"""Load metadata dictionary and cache it."""
with open(META_PATH, "r", encoding="utf-8") as f:
return json.load(f)
# -- Check that deployment files exist ---------------------------------------
missing_files = []
for p, name in [(ONNX_PATH, "ONNX Model"), (INDEX_PATH, "FAISS Index"), (META_PATH, "Metadata JSON")]:
if not os.path.exists(p):
missing_files.append(name)
if missing_files:
# Debugging Info: Prints exact paths for easy troubleshooting on HF
st.error(f" Missing deployment assets: {', '.join(missing_files)}")
st.warning(" Please ensure the 'deployment' folder contains the required ONNX, FAISS, and JSON files.")
st.code(f"Current Root Directory: {ROOT_DIR}\nLooking in: {DEPLOY_DIR}", language="bash")
st.stop()
# Load cached assets
ort_sess = load_onnx_model()
faiss_index = load_faiss_index()
metadata = load_metadata()
# ==========================================================================
# Preprocessing & Helper Logic
# ==========================================================================
def preprocess_image(pil_img):
"""
Resizes image to 32x32 and applies CIFAR-10 mean/std normalization.
Returns: numpy array of shape (1, 3, 32, 32)
"""
# 1. Resize to 32x32
img_resized = pil_img.resize((32, 32), Image.Resampling.BILINEAR)
# 2. Convert to numpy array and scale to [0, 1]
img_np = np.array(img_resized, dtype=np.float32) / 255.0
# Handle grayscale images if uploaded
if len(img_np.shape) == 2:
img_np = np.stack([img_np, img_np, img_np], axis=-1)
elif img_np.shape[2] == 4:
img_np = img_np[:, :, :3] # Drop alpha channel
# 3. Normalize: (x - mean) / std
img_normalized = (img_np - CIFAR_MEAN) / CIFAR_STD
# 4. Transpose to channel-first (HWC -> CHW) and add batch dimension (1, C, H, W)
img_transposed = np.transpose(img_normalized, (2, 0, 1))
img_batch = np.expand_dims(img_transposed, axis=0)
return img_batch
def perform_search(features_batch, top_k=5):
"""
Normalizes features, queries FAISS, and returns matched metadata.
"""
# 1. L2-Normalize vector
norm = np.linalg.norm(features_batch, axis=1, keepdims=True)
features_normalized = features_batch / np.maximum(norm, 1e-12)
# 2. Search FAISS Index
similarities, indices = faiss_index.search(features_normalized, top_k)
results = []
for rank in range(top_k):
vector_id = int(indices[0][rank])
score = float(similarities[0][rank])
# Load from metadata catalog
match_info = metadata["images"][vector_id]
results.append({
"rank": rank + 1,
"id": vector_id,
"similarity": score,
"image_path": os.path.join(DEPLOY_DIR, match_info["image_path"]),
"class_name": match_info["class_name"],
"label_id": match_info["label_id"]
})
return results
def get_image_base64(image_path):
"""Converts a local image to base64 string for direct HTML rendering."""
if os.path.exists(image_path):
img = Image.open(image_path)
buffered = BytesIO()
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
return ""
# ==========================================================================
# Streamlit UI Configuration
# ==========================================================================
st.set_page_config(
page_title="SimCLR ResNet-50 Visual Search",
layout="wide",
initial_sidebar_state="expanded"
)
# Custom CSS for Premium Styling
st.markdown("""
<style>
/* Dark theme customizations */
.stApp {
background-color: #0F172A;
color: #E2E8F0;
}
/* Title and Header designs */
h1 {
background: linear-gradient(90deg, #38BDF8, #818CF8);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-family: 'Outfit', sans-serif;
font-weight: 800;
}
/* Custom Card/Glassmorphism block */
.glass-card {
background: rgba(30, 41, 59, 0.7);
border: 1px solid rgba(255, 255, 255, 0.05);
border-radius: 12px;
padding: 20px;
margin-bottom: 20px;
box-shadow: 0 4px 30px rgba(0, 0, 0, 0.2);
backdrop-filter: blur(5px);
}
/* Metric styling */
.metric-value {
font-size: 2.2rem;
font-weight: 700;
color: #38BDF8;
}
.metric-label {
font-size: 0.9rem;
color: #94A3B8;
text-transform: uppercase;
letter-spacing: 0.05em;
}
</style>
""", unsafe_allow_html=True)
# ==========================================================================
# Sidebar Configuration
# ==========================================================================
st.sidebar.markdown("### Engine Settings")
top_k_slider = st.sidebar.slider("Number of results (K)", min_value=3, max_value=20, value=6, step=1)
st.sidebar.markdown("---")
st.sidebar.markdown("### Trained Backbone Model")
st.sidebar.markdown("""
* **Architecture**: ResNet-50 (CIFAR STEM adjusted)
* **Training Type**: Self-Supervised (SimCLR)
* **Best Experiment**: **Exp 41**
* **Augmentations**: Crop + Flip + Blur + Color Jitter
* **Midterm Baseline**: 64.49%
* **Final Linear Probe Accuracy**: **84.30%** (+19.81% Gain!)
* **Embedding Dimension**: 2048
""")
st.sidebar.markdown("---")
st.sidebar.markdown("### Backend Pipeline")
st.sidebar.markdown("""
* **Inference Engine**: ONNX Runtime CPU
* **Vector Database**: FAISS (IndexFlatIP)
* **Similarity Metric**: Exact Cosine Similarity
* **Database Size**: 10,000 Images
""")
st.sidebar.caption("Mahmoud Alyosify - Natalie Monged & Mirna Embaby, CISC 867, Queen's University, Spring 2026")
# ==========================================================================
# Main Title
# ==========================================================================
st.title("Real-time SimCLR Image Retrieval Engine")
st.markdown("##### Self-Supervised Representation Learning with ResNet-50 & FAISS Indexing")
# Define Tabs
tab1, tab2 = st.tabs(["Real-Time Search", "Ablation Study Dashboard"])
# ==========================================================================
# TAB 1: Visual Search Engine
# ==========================================================================
with tab1:
col_left, col_right = st.columns([1, 2], gap="large")
# Initialize session state for random query image selection
if "random_img_id" not in st.session_state:
st.session_state.random_img_id = None
query_image = None
query_info = None
with col_left:
st.markdown("### Select Query Image")
# Method 1: Upload a file
uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "png", "jpeg"])
# Method 2: Pick a random image from the database
st.markdown("<p style='text-align: center; color: #94A3B8; margin: 10px 0;'>— OR —</p>", unsafe_allow_html=True)
if st.button("Pick a Random Test Image from Database", use_container_width=True):
st.session_state.random_img_id = random.randint(0, len(metadata["images"]) - 1)
# Display selected/uploaded image
if uploaded_file is not None:
st.session_state.random_img_id = None # Clear random image
query_image = Image.open(uploaded_file).convert("RGB")
st.image(query_image, caption="Uploaded Query Image", use_container_width=True)
elif st.session_state.random_img_id is not None:
idx = st.session_state.random_img_id
img_info = metadata["images"][idx]
local_path = os.path.join(DEPLOY_DIR, img_info["image_path"])
if os.path.exists(local_path):
query_image = Image.open(local_path).convert("RGB")
query_info = img_info
st.image(
query_image,
caption=f"Selected Reference Image #{idx} (Class: {img_info['class_name']})",
use_container_width=True
)
else:
st.error("Reference image file not found.")
with col_right:
st.markdown("### Visual Similarity Search Results")
if query_image is not None:
t_start = time.time()
# 1. Preprocess
batch_img = preprocess_image(query_image)
# 2. ONNX Inference
onnx_outputs = ort_sess.run(None, {"images": batch_img})
feature_vector = onnx_outputs[0] # (1, 2048)
# 3. FAISS Query
t_search_start = time.time()
results = perform_search(feature_vector, top_k=top_k_slider)
t_total_ms = (time.time() - t_start) * 1000
t_search_ms = (time.time() - t_search_start) * 1000
# Show Metrics
m_col1, m_col2, m_col3 = st.columns(3)
with m_col1:
st.markdown(f'<div class="glass-card"><div class="metric-value">{t_search_ms:.2f} ms</div><div class="metric-label">FAISS Query Time</div></div>', unsafe_allow_html=True)
with m_col2:
st.markdown(f'<div class="glass-card"><div class="metric-value">{t_total_ms:.1f} ms</div><div class="metric-label">Total Pipeline Latency</div></div>', unsafe_allow_html=True)
with m_col3:
# Calculate class precision
if query_info is not None:
matches = sum(1 for r in results if r["class_name"] == query_info["class_name"])
precision = (matches / len(results)) * 100
st.markdown(f'<div class="glass-card"><div class="metric-value">{precision:.1f}%</div><div class="metric-label">Top-{len(results)} Class Precision</div></div>', unsafe_allow_html=True)
else:
st.markdown(f'<div class="glass-card"><div class="metric-value">N/A</div><div class="metric-label">Upload class unknown</div></div>', unsafe_allow_html=True)
# Display Results Grid (Fixed completely with Base64 HTML rendering)
st.markdown("#### Retrieved Nearest Neighbors")
# Create dynamic rows
for i in range(0, len(results), 3):
cols = st.columns(3)
for j in range(3):
if i + j < len(results):
res = results[i + j]
with cols[j]:
# Determine styling color based on class match if query info exists
is_match = False
border_color = "rgba(255, 255, 255, 0.05)"
bg_color = "rgba(30, 41, 59, 0.7)"
badge_html = f'<span style="background-color: #475569; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; font-weight: 600; color: white;">{res["class_name"]}</span>'
if query_info is not None:
is_match = (res["class_name"] == query_info["class_name"])
if is_match:
border_color = "rgba(16, 185, 129, 0.3)"
bg_color = "rgba(6, 78, 59, 0.3)"
badge_html = f'<span style="background-color: #10B981; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; font-weight: 600; color: white;">{res["class_name"]} (MATCH)</span>'
else:
border_color = "rgba(239, 68, 68, 0.2)"
bg_color = "rgba(127, 29, 29, 0.1)"
badge_html = f'<span style="background-color: #EF4444; padding: 2px 8px; border-radius: 4px; font-size: 0.8rem; font-weight: 600; color: white;">{res["class_name"]}</span>'
# Convert local image to base64
img_b64 = get_image_base64(res["image_path"])
img_html = f'<img src="data:image/png;base64,{img_b64}" style="width: 100%; border-radius: 8px; margin-bottom: 12px; display: block;">' if img_b64 else '<div style="color: #EF4444; padding: 20px; border: 1px solid #EF4444; border-radius: 8px; margin-bottom: 12px;">Image missing</div>'
# Custom card container integrating image and text as a single HTML block
card_content = f"""
<div style="background: {bg_color}; border: 1px solid {border_color}; border-radius: 8px; padding: 12px; margin-bottom: 12px; text-align: center; height: 100%;">
{img_html}
<div style="color: #38BDF8; font-size: 1.1rem; font-weight: 700; margin-bottom: 8px;">Rank #{res["rank"]}</div>
<div style="margin-bottom: 10px;">{badge_html}</div>
<div style="color: #94A3B8; font-size: 0.9rem; margin-bottom: 4px;">Cosine Similarity:</div>
<div style="color: #E2E8F0; font-size: 1.4rem; font-weight: 700; margin-bottom: 10px;">{res["similarity"] * 100:.2f}%</div>
<div style="color: #64748B; font-size: 0.8rem;">Vector ID: {res["id"]}</div>
</div>
"""
st.markdown(card_content, unsafe_allow_html=True)
else:
st.info("Please upload a custom image or click the button to select a random test image to query the visual search engine!")
# ==========================================================================
# TAB 2: Ablation Study Dashboard
# ==========================================================================
with tab2:
st.markdown("### Experiment Ablation Study & Color Jitter Findings")
st.markdown("##### Evaluating Midterm Pipelines against Jittered Re-runs")
st.markdown("""
This section presents the results of the complete **Color Jitter Shortcut-Learning Ablation Study** (Experiments 38-42).
By comparing the self-supervised representations learned with and without photometric distortion (Color Jitter), we demonstrate how SimCLR ResNet-50 encoders learn rich semantic representations instead of exploiting simple pixel-level color shortcuts.
""")
# 3x2 Grid for Ablation Statistics
ablation_data = [
{"exp": "Exp 38 vs 36", "desc": "Pure Discrete Rotation", "base": "34.40%", "jitter": "51.21%", "gain": "+16.81 pp"},
{"exp": "Exp 39 vs 35", "desc": "Weak Spatial Baseline", "base": "59.22%", "jitter": "80.53%", "gain": "+21.31 pp"},
{"exp": "Exp 40 vs 9", "desc": "Crop + Gaussian Blur", "base": "63.01%", "jitter": "80.65%", "gain": "+17.64 pp"},
{"exp": "Exp 41 vs 13", "desc": "Crop + Flip + Blur", "base": "64.49%", "jitter": "84.30%", "gain": "+19.81 pp"},
{"exp": "Exp 42 vs 10", "desc": "Crop + Random Cutout", "base": "66.27%", "jitter": "81.21%", "gain": "+14.94 pp"}
]
cols_ablation = st.columns(5)
for i, data in enumerate(ablation_data):
with cols_ablation[i]:
card_html = f"""
<div style="background: rgba(30, 41, 59, 0.5); border: 1px solid rgba(56, 189, 248, 0.2); border-radius: 8px; padding: 15px; text-align: center;">
<div style="color: #38BDF8; font-size: 1.1rem; font-weight: 700;">{data["exp"]}</div>
<div style="color: #94A3B8; font-size: 0.85rem; height: 35px; overflow: hidden; margin-top: 4px;">{data["desc"]}</div>
<hr style="border: 0; border-top: 1px solid rgba(255,255,255,0.05); margin: 8px 0;"/>
<div style="color: #94A3B8; font-size: 0.8rem;">Without Jitter: <b>{data["base"]}</b></div>
<div style="color: #10B981; font-size: 0.85rem; font-weight: 600; margin-top: 2px;">With Jitter: {data["jitter"]}</div>
<div style="background-color: rgba(16, 185, 129, 0.15); color: #10B981; border-radius: 4px; padding: 2px 6px; font-size: 0.9rem; font-weight: 700; margin-top: 8px; display: inline-block;">
{data["gain"]} Gain
</div>
</div>
"""
st.markdown(card_html, unsafe_allow_html=True)
st.markdown("---")
st.markdown("#### Key Project Takeaways")
col1, col2 = st.columns(2)
with col1:
st.markdown("""
##### 1. The Shortcut-Learning Problem
Without Color Jitter, contrastive self-supervised encoders exploit low-level **color histograms** as a shortcut to maximize mutual information, rather than learning general shapes and semantic features. This results in weaker downstream representations, showing a severe performance limit (e.g. baseline Crop+Flip+Blur achieves only **64.49%**).
##### 2. The Color Jitter Shield
Adding photometric distortion (color jittering) forces the model to ignore color profiles and focus on invariant structures, spatial boundaries, and contours. This single ablation yields a massive average boost of **+18.1 pp** across all settings, pushing our best encoder (Exp 41) to a stellar **84.30% Top-1 accuracy**!
""")
with col2:
st.markdown("""
##### 3. Model Architecture & Stem Tuning
Modifying the standard ResNet-50 conv1 stem from 3x3 (stride 1) and removing the initial MaxPool was crucial to preserve the resolution of 32x32 CIFAR-10 images.
##### 4. Near Foundation-Model Upper Bound
Our zero-shot CLIP ViT-B/32 foundation model evaluation sets the academic upper bound at **88.80%**. Our custom-trained SimCLR ResNet-50 achieves **95% of this performance** (**84.30%**) while using **8,000x less training data**!
""")
st.success("Weights & Biases Dashboard has archived all 5 experiments complete with training curves, checkpoints, and t-SNE files.") |