"""
================================================================================
SimCLR ResNet-50 Visual Search Engine GUI
--------------------------------------------------------------------------------
A premium, interactive web-based graphical user interface (GUI) built with
Streamlit, powered by ONNX Runtime and FAISS for real-time visual similarity
retrieval on CIFAR-10.

Key Features:
    1. Premium Dark Mode UI with tailored CSS glassmorphism styling
    2. Upload any local image or select a random test image from CIFAR-10
    3. Lightning-fast feature extraction (2048-d) via optimized ONNX Runtime
    4. Sub-millisecond exact Cosine Similarity search using FAISS vector database
    5. Clean visual results grid with similarity scores and class matches
    6. Dedicated interactive "Ablation Study Dashboard" showing the Exp 38-42 findings

Usage:
    streamlit run src/streamlit_app.py
================================================================================
"""
import os
import sys
import time
import json
import random
import base64
from io import BytesIO

import numpy as np
from PIL import Image
import streamlit as st

# Windows console encoding fix: force UTF-8 on the standard streams.
# reconfigure() only exists on io.TextIOWrapper; a hosting environment may
# replace sys.stdout / sys.stderr with another object (or None, e.g. under
# pythonw), so streams without it are skipped instead of crashing at import.
for _stream in (sys.stdout, sys.stderr):
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8")

# ==========================================================================
# 1.
Robust Path Resolution (Works for Local & Hugging Face Spaces) # ========================================================================== # Hugging Face sets the Current Working Directory (CWD) to the repository root ROOT_DIR = os.getcwd() SRC_DIR = os.path.join(ROOT_DIR, "src") # Ensure we can import from src/ if SRC_DIR not in sys.path: sys.path.append(SRC_DIR) # ========================================================================== # Caching Assets for Instant Performance # ========================================================================== DEPLOY_DIR = os.path.join(ROOT_DIR, "deployment") ONNX_PATH = os.path.join(DEPLOY_DIR, "simclr_encoder_exp41.onnx") INDEX_PATH = os.path.join(DEPLOY_DIR, "cifar10_index.faiss") META_PATH = os.path.join(DEPLOY_DIR, "metadata.json") CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32) CIFAR_STD = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32) @st.cache_resource def load_onnx_model(): """Load ONNX inference session and cache it.""" import onnxruntime as ort # CPU is fast enough for single image inference, ensures compatibility providers = ['CPUExecutionProvider'] return ort.InferenceSession(ONNX_PATH, providers=providers) @st.cache_resource def load_faiss_index(): """Load FAISS index and cache it.""" import faiss return faiss.read_index(INDEX_PATH) @st.cache_data def load_metadata(): """Load metadata dictionary and cache it.""" with open(META_PATH, "r", encoding="utf-8") as f: return json.load(f) # -- Check that deployment files exist --------------------------------------- missing_files = [] for p, name in [(ONNX_PATH, "ONNX Model"), (INDEX_PATH, "FAISS Index"), (META_PATH, "Metadata JSON")]: if not os.path.exists(p): missing_files.append(name) if missing_files: # Debugging Info: Prints exact paths for easy troubleshooting on HF st.error(f" Missing deployment assets: {', '.join(missing_files)}") st.warning(" Please ensure the 'deployment' folder contains the required ONNX, FAISS, and JSON 
files.") st.code(f"Current Root Directory: {ROOT_DIR}\nLooking in: {DEPLOY_DIR}", language="bash") st.stop() # Load cached assets ort_sess = load_onnx_model() faiss_index = load_faiss_index() metadata = load_metadata() # ========================================================================== # Preprocessing & Helper Logic # ========================================================================== def preprocess_image(pil_img): """ Resizes image to 32x32 and applies CIFAR-10 mean/std normalization. Returns: numpy array of shape (1, 3, 32, 32) """ # 1. Resize to 32x32 img_resized = pil_img.resize((32, 32), Image.Resampling.BILINEAR) # 2. Convert to numpy array and scale to [0, 1] img_np = np.array(img_resized, dtype=np.float32) / 255.0 # Handle grayscale images if uploaded if len(img_np.shape) == 2: img_np = np.stack([img_np, img_np, img_np], axis=-1) elif img_np.shape[2] == 4: img_np = img_np[:, :, :3] # Drop alpha channel # 3. Normalize: (x - mean) / std img_normalized = (img_np - CIFAR_MEAN) / CIFAR_STD # 4. Transpose to channel-first (HWC -> CHW) and add batch dimension (1, C, H, W) img_transposed = np.transpose(img_normalized, (2, 0, 1)) img_batch = np.expand_dims(img_transposed, axis=0) return img_batch def perform_search(features_batch, top_k=5): """ Normalizes features, queries FAISS, and returns matched metadata. """ # 1. L2-Normalize vector norm = np.linalg.norm(features_batch, axis=1, keepdims=True) features_normalized = features_batch / np.maximum(norm, 1e-12) # 2. 
Search FAISS Index similarities, indices = faiss_index.search(features_normalized, top_k) results = [] for rank in range(top_k): vector_id = int(indices[0][rank]) score = float(similarities[0][rank]) # Load from metadata catalog match_info = metadata["images"][vector_id] results.append({ "rank": rank + 1, "id": vector_id, "similarity": score, "image_path": os.path.join(DEPLOY_DIR, match_info["image_path"]), "class_name": match_info["class_name"], "label_id": match_info["label_id"] }) return results def get_image_base64(image_path): """Converts a local image to base64 string for direct HTML rendering.""" if os.path.exists(image_path): img = Image.open(image_path) buffered = BytesIO() img.save(buffered, format="PNG") return base64.b64encode(buffered.getvalue()).decode() return "" # ========================================================================== # Streamlit UI Configuration # ========================================================================== st.set_page_config( page_title="SimCLR ResNet-50 Visual Search", layout="wide", initial_sidebar_state="expanded" ) # Custom CSS for Premium Styling st.markdown(""" """, unsafe_allow_html=True) # ========================================================================== # Sidebar Configuration # ========================================================================== st.sidebar.markdown("### Engine Settings") top_k_slider = st.sidebar.slider("Number of results (K)", min_value=3, max_value=20, value=6, step=1) st.sidebar.markdown("---") st.sidebar.markdown("### Trained Backbone Model") st.sidebar.markdown(""" * **Architecture**: ResNet-50 (CIFAR STEM adjusted) * **Training Type**: Self-Supervised (SimCLR) * **Best Experiment**: **Exp 41** * **Augmentations**: Crop + Flip + Blur + Color Jitter * **Midterm Baseline**: 64.49% * **Final Linear Probe Accuracy**: **84.30%** (+19.81% Gain!) 
* **Embedding Dimension**: 2048 """) st.sidebar.markdown("---") st.sidebar.markdown("### Backend Pipeline") st.sidebar.markdown(""" * **Inference Engine**: ONNX Runtime CPU * **Vector Database**: FAISS (IndexFlatIP) * **Similarity Metric**: Exact Cosine Similarity * **Database Size**: 10,000 Images """) st.sidebar.caption("Mahmoud Alyosify - Natalie Monged & Mirna Embaby, CISC 867, Queen's University, Spring 2026") # ========================================================================== # Main Title # ========================================================================== st.title("Real-time SimCLR Image Retrieval Engine") st.markdown("##### Self-Supervised Representation Learning with ResNet-50 & FAISS Indexing") # Define Tabs tab1, tab2 = st.tabs(["Real-Time Search", "Ablation Study Dashboard"]) # ========================================================================== # TAB 1: Visual Search Engine # ========================================================================== with tab1: col_left, col_right = st.columns([1, 2], gap="large") # Initialize session state for random query image selection if "random_img_id" not in st.session_state: st.session_state.random_img_id = None query_image = None query_info = None with col_left: st.markdown("### Select Query Image") # Method 1: Upload a file uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "png", "jpeg"]) # Method 2: Pick a random image from the database st.markdown("
— OR —
", unsafe_allow_html=True) if st.button("Pick a Random Test Image from Database", use_container_width=True): st.session_state.random_img_id = random.randint(0, len(metadata["images"]) - 1) # Display selected/uploaded image if uploaded_file is not None: st.session_state.random_img_id = None # Clear random image query_image = Image.open(uploaded_file).convert("RGB") st.image(query_image, caption="Uploaded Query Image", use_container_width=True) elif st.session_state.random_img_id is not None: idx = st.session_state.random_img_id img_info = metadata["images"][idx] local_path = os.path.join(DEPLOY_DIR, img_info["image_path"]) if os.path.exists(local_path): query_image = Image.open(local_path).convert("RGB") query_info = img_info st.image( query_image, caption=f"Selected Reference Image #{idx} (Class: {img_info['class_name']})", use_container_width=True ) else: st.error("Reference image file not found.") with col_right: st.markdown("### Visual Similarity Search Results") if query_image is not None: t_start = time.time() # 1. Preprocess batch_img = preprocess_image(query_image) # 2. ONNX Inference onnx_outputs = ort_sess.run(None, {"images": batch_img}) feature_vector = onnx_outputs[0] # (1, 2048) # 3. FAISS Query t_search_start = time.time() results = perform_search(feature_vector, top_k=top_k_slider) t_total_ms = (time.time() - t_start) * 1000 t_search_ms = (time.time() - t_search_start) * 1000 # Show Metrics m_col1, m_col2, m_col3 = st.columns(3) with m_col1: st.markdown(f'