mahmoudalyosify committed on
Commit
b81c9f2
·
verified ·
1 Parent(s): 2e6a9a9

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +18 -25
src/streamlit_app.py CHANGED
@@ -1,7 +1,7 @@
1
  """
2
  ================================================================================
3
  SimCLR ResNet-50 Visual Search Engine GUI
4
- -----------------------------------------
5
  A premium, interactive web-based graphical user interface (GUI) built with
6
  Streamlit, powered by ONNX Runtime and FAISS for real-time visual similarity
7
  retrieval on CIFAR-10.
@@ -15,10 +15,9 @@
15
  6. Dedicated interactive "Ablation Study Dashboard" showing the Exp 38-42 findings
16
 
17
  Usage:
18
- streamlit run app.py
19
  ================================================================================
20
  """
21
-
22
  import os
23
  import sys
24
  import time
@@ -26,7 +25,6 @@ import json
26
  import random
27
  import numpy as np
28
  from PIL import Image
29
-
30
  import streamlit as st
31
 
32
  # Windows console encoding fix
@@ -34,14 +32,16 @@ sys.stdout.reconfigure(encoding='utf-8')
34
  sys.stderr.reconfigure(encoding='utf-8')
35
 
36
  # Ensure we can import from src/
37
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
38
- sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
39
 
40
  # ==========================================================================
41
  # Caching Assets for Instant Performance
42
  # ==========================================================================
43
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
44
- DEPLOY_DIR = os.path.join(BASE_DIR, "deployment")
 
 
45
  ONNX_PATH = os.path.join(DEPLOY_DIR, "simclr_encoder_exp41.onnx")
46
  INDEX_PATH = os.path.join(DEPLOY_DIR, "cifar10_index.faiss")
47
  META_PATH = os.path.join(DEPLOY_DIR, "metadata.json")
@@ -49,7 +49,6 @@ META_PATH = os.path.join(DEPLOY_DIR, "metadata.json")
49
  CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
50
  CIFAR_STD = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)
51
 
52
-
53
  @st.cache_resource
54
  def load_onnx_model():
55
  """Load ONNX inference session and cache it."""
@@ -58,21 +57,18 @@ def load_onnx_model():
58
  providers = ['CPUExecutionProvider']
59
  return ort.InferenceSession(ONNX_PATH, providers=providers)
60
 
61
-
62
  @st.cache_resource
63
  def load_faiss_index():
64
  """Load FAISS index and cache it."""
65
  import faiss
66
  return faiss.read_index(INDEX_PATH)
67
 
68
-
69
  @st.cache_data
70
  def load_metadata():
71
  """Load metadata dictionary and cache it."""
72
  with open(META_PATH, "r", encoding="utf-8") as f:
73
  return json.load(f)
74
 
75
-
76
  # -- Check that deployment files exist ---------------------------------------
77
  missing_files = []
78
  for p, name in [(ONNX_PATH, "ONNX Model"), (INDEX_PATH, "FAISS Index"), (META_PATH, "Metadata JSON")]:
@@ -81,7 +77,7 @@ for p, name in [(ONNX_PATH, "ONNX Model"), (INDEX_PATH, "FAISS Index"), (META_PA
81
 
82
  if missing_files:
83
  st.error(f" Missing deployment assets: {', '.join(missing_files)}")
84
- st.warning(" Please run `export_onnx.py` then `build_faiss.py` to generate these files before running the GUI!")
85
  st.stop()
86
 
87
  # Load cached assets
@@ -118,7 +114,6 @@ def preprocess_image(pil_img):
118
 
119
  return img_batch
120
 
121
-
122
  def perform_search(features_batch, top_k=5):
123
  """
124
  Normalizes features, queries FAISS, and returns matched metadata.
@@ -149,7 +144,6 @@ def perform_search(features_batch, top_k=5):
149
 
150
  return results
151
 
152
-
153
  # ==========================================================================
154
  # Streamlit UI Configuration
155
  # ==========================================================================
@@ -230,7 +224,6 @@ st.sidebar.markdown("""
230
  * **Similarity Metric**: Exact Cosine Similarity
231
  * **Database Size**: 10,000 Images
232
  """)
233
-
234
  st.sidebar.caption("Group 20, CISC 867, Queen's University, Spring 2026")
235
 
236
  # ==========================================================================
@@ -240,7 +233,7 @@ st.title("Real-time SimCLR Image Retrieval Engine")
240
  st.markdown("##### Self-Supervised Representation Learning with ResNet-50 & FAISS Indexing")
241
 
242
  # Define Tabs
243
- tab1, tab2 = st.tabs([" Real-Time Search", " Ablation Study Dashboard"])
244
 
245
  # ==========================================================================
246
  # TAB 1: Visual Search Engine
@@ -256,7 +249,7 @@ with tab1:
256
  query_info = None
257
 
258
  with col_left:
259
- st.markdown("### Select Query Image")
260
 
261
  # Method 1: Upload a file
262
  uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "png", "jpeg"])
@@ -289,7 +282,7 @@ with tab1:
289
  st.error("Reference image file not found.")
290
 
291
  with col_right:
292
- st.markdown("###Visual Similarity Search Results")
293
 
294
  if query_image is not None:
295
  t_start = time.time()
@@ -324,7 +317,7 @@ with tab1:
324
  st.markdown(f'<div class="glass-card"><div class="metric-value">N/A</div><div class="metric-label">Upload class unknown</div></div>', unsafe_allow_html=True)
325
 
326
  # Display Results Grid
327
- st.markdown("#### Retreived Nearest Neighbors")
328
 
329
  # Create dynamic grid columns
330
  grid_cols = st.columns(3)
@@ -371,13 +364,13 @@ with tab1:
371
  col.markdown(card_content, unsafe_allow_html=True)
372
 
373
  else:
374
- st.info(" Please upload a custom image or click the button to select a random test image to query the visual search engine!")
375
 
376
  # ==========================================================================
377
  # TAB 2: Ablation Study Dashboard
378
  # ==========================================================================
379
  with tab2:
380
- st.markdown("### Experiment Ablation Study & Color Jitter Findings")
381
  st.markdown("##### Evaluating Natalie's Midterm Pipelines against Mahmoud's Final Jittered Re-runs")
382
 
383
  st.markdown("""
@@ -412,7 +405,7 @@ with tab2:
412
  st.markdown(card_html, unsafe_allow_html=True)
413
 
414
  st.markdown("---")
415
- st.markdown("#### Key Project Takeaways")
416
  col1, col2 = st.columns(2)
417
 
418
  with col1:
@@ -427,10 +420,10 @@ with tab2:
427
  with col2:
428
  st.markdown("""
429
  ##### 3. Model Architecture & Stem Tuning ⚙️
430
- Modifying the standard ResNet-50 conv1 stem from $7\\times7$ (stride 2) to a custom $3\\times3$ (stride 1) and removing the initial MaxPool was crucial to preserve the resolution of $32\\times32$ CIFAR-10 images.
431
 
432
  ##### 4. Near Foundation-Model Upper Bound 🏆
433
  Our zero-shot CLIP ViT-B/32 foundation model evaluation sets the academic upper bound at **88.80%**. Our custom-trained SimCLR ResNet-50 achieves **95% of this performance** (**84.30%**) while using **8,000x less training data**!
434
  """)
435
 
436
- st.success("🎉 Weights & Biases Dashboard has archived all 5 experiments complete with training curves, checkpoints, and t-SNE files.")
 
1
  """
2
  ================================================================================
3
  SimCLR ResNet-50 Visual Search Engine GUI
4
+ --------------------------------------------------------------------------------
5
  A premium, interactive web-based graphical user interface (GUI) built with
6
  Streamlit, powered by ONNX Runtime and FAISS for real-time visual similarity
7
  retrieval on CIFAR-10.
 
15
  6. Dedicated interactive "Ablation Study Dashboard" showing the Exp 38-42 findings
16
 
17
  Usage:
18
+ streamlit run src/streamlit_app.py
19
  ================================================================================
20
  """
 
21
  import os
22
  import sys
23
  import time
 
25
  import random
26
  import numpy as np
27
  from PIL import Image
 
28
  import streamlit as st
29
 
30
  # Windows console encoding fix
 
32
  sys.stderr.reconfigure(encoding='utf-8')
33
 
34
  # Ensure we can import from src/
35
+ SRC_DIR = os.path.dirname(os.path.abspath(__file__))
36
+ sys.path.append(SRC_DIR)
37
 
38
  # ==========================================================================
39
  # Caching Assets for Instant Performance
40
  # ==========================================================================
41
+ # Navigate up one level from 'src' to the root directory to access 'deployment'
42
+ ROOT_DIR = os.path.dirname(SRC_DIR)
43
+ DEPLOY_DIR = os.path.join(ROOT_DIR, "deployment")
44
+
45
  ONNX_PATH = os.path.join(DEPLOY_DIR, "simclr_encoder_exp41.onnx")
46
  INDEX_PATH = os.path.join(DEPLOY_DIR, "cifar10_index.faiss")
47
  META_PATH = os.path.join(DEPLOY_DIR, "metadata.json")
 
49
  CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
50
  CIFAR_STD = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)
51
 
 
52
  @st.cache_resource
53
  def load_onnx_model():
54
  """Load ONNX inference session and cache it."""
 
57
  providers = ['CPUExecutionProvider']
58
  return ort.InferenceSession(ONNX_PATH, providers=providers)
59
 
 
60
  @st.cache_resource
61
  def load_faiss_index():
62
  """Load FAISS index and cache it."""
63
  import faiss
64
  return faiss.read_index(INDEX_PATH)
65
 
 
66
  @st.cache_data
67
  def load_metadata():
68
  """Load metadata dictionary and cache it."""
69
  with open(META_PATH, "r", encoding="utf-8") as f:
70
  return json.load(f)
71
 
 
72
  # -- Check that deployment files exist ---------------------------------------
73
  missing_files = []
74
  for p, name in [(ONNX_PATH, "ONNX Model"), (INDEX_PATH, "FAISS Index"), (META_PATH, "Metadata JSON")]:
 
77
 
78
  if missing_files:
79
  st.error(f" Missing deployment assets: {', '.join(missing_files)}")
80
+ st.warning(" Please ensure the 'deployment' folder contains the required ONNX, FAISS, and JSON files before running the GUI!")
81
  st.stop()
82
 
83
  # Load cached assets
 
114
 
115
  return img_batch
116
 
 
117
  def perform_search(features_batch, top_k=5):
118
  """
119
  Normalizes features, queries FAISS, and returns matched metadata.
 
144
 
145
  return results
146
 
 
147
  # ==========================================================================
148
  # Streamlit UI Configuration
149
  # ==========================================================================
 
224
  * **Similarity Metric**: Exact Cosine Similarity
225
  * **Database Size**: 10,000 Images
226
  """)
 
227
  st.sidebar.caption("Group 20, CISC 867, Queen's University, Spring 2026")
228
 
229
  # ==========================================================================
 
233
  st.markdown("##### Self-Supervised Representation Learning with ResNet-50 & FAISS Indexing")
234
 
235
  # Define Tabs
236
+ tab1, tab2 = st.tabs(["🔍 Real-Time Search", "📊 Ablation Study Dashboard"])
237
 
238
  # ==========================================================================
239
  # TAB 1: Visual Search Engine
 
249
  query_info = None
250
 
251
  with col_left:
252
+ st.markdown("### 📤 Select Query Image")
253
 
254
  # Method 1: Upload a file
255
  uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "png", "jpeg"])
 
282
  st.error("Reference image file not found.")
283
 
284
  with col_right:
285
+ st.markdown("### 🔎 Visual Similarity Search Results")
286
 
287
  if query_image is not None:
288
  t_start = time.time()
 
317
  st.markdown(f'<div class="glass-card"><div class="metric-value">N/A</div><div class="metric-label">Upload class unknown</div></div>', unsafe_allow_html=True)
318
 
319
  # Display Results Grid
320
+ st.markdown("#### Retrieved Nearest Neighbors")
321
 
322
  # Create dynamic grid columns
323
  grid_cols = st.columns(3)
 
364
  col.markdown(card_content, unsafe_allow_html=True)
365
 
366
  else:
367
+ st.info("💡 Please upload a custom image or click the button to select a random test image to query the visual search engine!")
368
 
369
  # ==========================================================================
370
  # TAB 2: Ablation Study Dashboard
371
  # ==========================================================================
372
  with tab2:
373
+ st.markdown("### 🧪 Experiment Ablation Study & Color Jitter Findings")
374
  st.markdown("##### Evaluating Natalie's Midterm Pipelines against Mahmoud's Final Jittered Re-runs")
375
 
376
  st.markdown("""
 
405
  st.markdown(card_html, unsafe_allow_html=True)
406
 
407
  st.markdown("---")
408
+ st.markdown("#### 💡 Key Project Takeaways")
409
  col1, col2 = st.columns(2)
410
 
411
  with col1:
 
420
  with col2:
421
  st.markdown("""
422
  ##### 3. Model Architecture & Stem Tuning ⚙️
423
+ Modifying the standard ResNet-50 conv1 stem from 7x7 (stride 2) to a custom 3x3 (stride 1) and removing the initial MaxPool was crucial to preserve the resolution of 32x32 CIFAR-10 images.
424
 
425
  ##### 4. Near Foundation-Model Upper Bound 🏆
426
  Our zero-shot CLIP ViT-B/32 foundation model evaluation sets the academic upper bound at **88.80%**. Our custom-trained SimCLR ResNet-50 achieves **95% of this performance** (**84.30%**) while using **8,000x less training data**!
427
  """)
428
 
429
+ st.success("🎉 Weights & Biases Dashboard has archived all 5 experiments complete with training curves, checkpoints, and t-SNE files.")