kaveh commited on
Commit
dff64e9
·
1 Parent(s): 1e95da3

added batch mode and corrected force scale

Browse files
Files changed (4) hide show
  1. app.py +172 -49
  2. config/constants.py +1 -0
  3. ui/components.py +159 -1
  4. utils/display.py +6 -26
app.py CHANGED
@@ -16,6 +16,7 @@ if S2F_ROOT not in sys.path:
16
  sys.path.insert(0, S2F_ROOT)
17
 
18
  from config.constants import (
 
19
  COLORMAPS,
20
  DEFAULT_SUBSTRATE,
21
  MODEL_TYPE_LABELS,
@@ -29,6 +30,7 @@ from utils.display import apply_display_scale
29
  from ui.components import (
30
  build_original_vals,
31
  build_cell_vals,
 
32
  render_result_display,
33
  render_region_canvas,
34
  render_system_status,
@@ -49,7 +51,7 @@ if HAS_DRAWABLE_CANVAS and ST_DIALOG:
49
  if raw_heatmap is None:
50
  st.warning("No prediction available to measure.")
51
  return
52
- display_mode = st.session_state.get("measure_display_mode", "Full")
53
  display_heatmap = apply_display_scale(
54
  raw_heatmap, display_mode,
55
  min_percentile=st.session_state.get("measure_min_percentile", 0),
@@ -121,6 +123,9 @@ def _inject_theme_css(theme):
121
  st.markdown("""
122
  <style>
123
  section[data-testid="stSidebar"] { width: 380px !important; }
 
 
 
124
  section[data-testid="stSidebar"] h2 {
125
  font-size: 1.25rem !important;
126
  font-weight: 600 !important;
@@ -233,13 +238,17 @@ with st.sidebar:
233
  help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.",
234
  )
235
 
236
- st.markdown('<p style="font-size: 0.95rem; font-weight: 500; margin-bottom: 0.5rem;">Heatmap display</p>', unsafe_allow_html=True)
 
 
 
 
 
237
  display_mode = st.radio(
238
- "Mode",
239
- ["Full", "Percentile", "Rescale", "Clip", "Filter"],
240
- help="Full: 0–1 as-is. Percentile: min/max percentiles. Rescale: stretch range to colors. Clip: clip, keep scale. Filter: show only in range.",
241
  horizontal=True,
242
- label_visibility="collapsed",
243
  )
244
  min_percentile, max_percentile = 0, 100
245
  clip_min, clip_max = 0.0, 1.0
@@ -252,14 +261,14 @@ with st.sidebar:
252
  if min_percentile >= max_percentile:
253
  st.warning("Min percentile must be less than max. Using min=0, max=100.")
254
  min_percentile, max_percentile = 0, 100
255
- elif display_mode in ("Rescale", "Clip", "Filter"):
256
  col_cmin, col_cmax = st.columns(2)
257
  with col_cmin:
258
- clip_min = st.number_input("Min", value=0.0, min_value=None, max_value=None, step=0.01, format="%.3f",
259
- help="Rescale: below black. Clip: clamp to min. Filter: below discarded.")
260
  with col_cmax:
261
- clip_max = st.number_input("Max", value=1.0, min_value=None, max_value=None, step=0.01, format="%.3f",
262
- help="Rescale: above white. Clip: clamp to max. Filter: above discarded.")
263
  if clip_min >= clip_max:
264
  st.warning("Min must be less than max. Using min=0, max=1.")
265
  clip_min, clip_max = 0.0, 1.0
@@ -275,44 +284,93 @@ with st.sidebar:
275
  # Main area: image input
276
  img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
277
  img = None
 
278
  uploaded = None
 
279
  selected_sample = None
280
-
281
- if img_source == "Upload":
282
- uploaded = st.file_uploader(
283
- "Upload bright-field image",
284
- type=["tif", "tiff", "png", "jpg", "jpeg"],
285
- help="Bright-field microscopy image of a cell or spheroid on a substrate (grayscale or RGB).",
286
- )
287
- if uploaded:
288
- bytes_data = uploaded.read()
289
- nparr = np.frombuffer(bytes_data, np.uint8)
290
- img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
291
- uploaded.seek(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  else:
293
- sample_folder = get_sample_folder(S2F_ROOT, model_type)
294
- sample_files = list_files_in_folder(sample_folder, SAMPLE_EXTENSIONS)
295
- sample_subfolder_name = model_subfolder(model_type)
296
- if sample_files:
297
- selected_sample = st.selectbox(
298
- f"Select example image (from `samples/{sample_subfolder_name}/`)",
299
- sample_files,
300
- format_func=lambda x: x,
301
- key=f"sample_{model_type}",
302
  )
303
- if selected_sample:
304
- sample_path = os.path.join(sample_folder, selected_sample)
305
- img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
306
- # Cached thumbnails
307
- thumbnails = get_cached_sample_thumbnails(model_type, sample_folder, sample_files)
308
- n_cols = min(5, len(thumbnails))
309
- cols = st.columns(n_cols)
310
- for i, (fname, sample_img) in enumerate(thumbnails):
311
- if sample_img is not None:
312
- with cols[i % n_cols]:
313
- st.image(sample_img, caption=fname, width=120)
314
  else:
315
- st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
  col_btn, col_model, col_path = st.columns([1, 1, 3])
318
  with col_btn:
@@ -322,16 +380,24 @@ with col_model:
322
  with col_path:
323
  ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}" if checkpoint else f"ckp/{ckp_subfolder_name}/"
324
  st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>Checkpoint: <code>{ckp_path}</code></span>", unsafe_allow_html=True)
 
325
  has_image = img is not None
 
326
 
327
  if "prediction_result" not in st.session_state:
328
  st.session_state["prediction_result"] = None
 
 
 
 
329
 
330
- just_ran = run and checkpoint and has_image
331
- cached = st.session_state["prediction_result"]
332
  key_img = (uploaded.name if uploaded else None) if img_source == "Upload" else selected_sample
333
  current_key = (model_type, checkpoint, key_img)
334
- has_cached = cached is not None and cached.get("cache_key") == current_key
 
 
 
335
 
336
 
337
  def get_or_create_predictor(model_type, checkpoint, ckp_folder):
@@ -348,7 +414,62 @@ def get_or_create_predictor(model_type, checkpoint, ckp_folder):
348
  return st.session_state["predictor"]
349
 
350
 
351
- if just_ran:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  st.session_state["prediction_result"] = None
353
  with st.spinner("Loading model and predicting..."):
354
  try:
@@ -432,8 +553,10 @@ elif has_cached:
432
 
433
  elif run and not checkpoint:
434
  st.warning("Please add checkpoint files to the ckp/ folder and select one.")
435
- elif run and not has_image:
436
  st.warning("Please upload an image or select an example.")
 
 
437
 
438
  st.sidebar.divider()
439
  render_system_status()
 
16
  sys.path.insert(0, S2F_ROOT)
17
 
18
  from config.constants import (
19
+ BATCH_MAX_IMAGES,
20
  COLORMAPS,
21
  DEFAULT_SUBSTRATE,
22
  MODEL_TYPE_LABELS,
 
30
  from ui.components import (
31
  build_original_vals,
32
  build_cell_vals,
33
+ render_batch_results,
34
  render_result_display,
35
  render_region_canvas,
36
  render_system_status,
 
51
  if raw_heatmap is None:
52
  st.warning("No prediction available to measure.")
53
  return
54
+ display_mode = st.session_state.get("measure_display_mode", "Default")
55
  display_heatmap = apply_display_scale(
56
  raw_heatmap, display_mode,
57
  min_percentile=st.session_state.get("measure_min_percentile", 0),
 
123
  st.markdown("""
124
  <style>
125
  section[data-testid="stSidebar"] { width: 380px !important; }
126
+ @media (max-width: 768px) {
127
+ section[data-testid="stSidebar"] { width: 100% !important; max-width: 100% !important; }
128
+ }
129
  section[data-testid="stSidebar"] h2 {
130
  font-size: 1.25rem !important;
131
  font-weight: 600 !important;
 
238
  help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.",
239
  )
240
 
241
+ batch_mode = st.toggle(
242
+ "Batch mode",
243
+ value=False,
244
+ help=f"Process up to {BATCH_MAX_IMAGES} images at once. Upload multiple files or select multiple examples.",
245
+ )
246
+
247
  display_mode = st.radio(
248
+ "Heatmap display",
249
+ ["Default", "Percentile", "Range"],
 
250
  horizontal=True,
251
+ help="Default: full 0–1 range. Percentile: map a percentile range to improve contrast when few bright pixels dominate. Range: show only values in [min, max]; others hidden (black).",
252
  )
253
  min_percentile, max_percentile = 0, 100
254
  clip_min, clip_max = 0.0, 1.0
 
261
  if min_percentile >= max_percentile:
262
  st.warning("Min percentile must be less than max. Using min=0, max=100.")
263
  min_percentile, max_percentile = 0, 100
264
+ elif display_mode == "Range":
265
  col_cmin, col_cmax = st.columns(2)
266
  with col_cmin:
267
+ clip_min = st.number_input("Min", value=0.0, min_value=0.0, max_value=1.0, step=0.01, format="%.3f",
268
+ help="Values below this rangehidden (black)")
269
  with col_cmax:
270
+ clip_max = st.number_input("Max", value=1.0, min_value=0.0, max_value=1.0, step=0.01, format="%.3f",
271
+ help="Values above this rangehidden (black)")
272
  if clip_min >= clip_max:
273
  st.warning("Min must be less than max. Using min=0, max=1.")
274
  clip_min, clip_max = 0.0, 1.0
 
284
  # Main area: image input
285
  img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
286
  img = None
287
+ imgs_batch = [] # list of (img, key_img) for batch mode
288
  uploaded = None
289
+ uploaded_list = []
290
  selected_sample = None
291
+ selected_samples = []
292
+
293
+ if batch_mode:
294
+ # Batch mode: multiple images (max BATCH_MAX_IMAGES)
295
+ if img_source == "Upload":
296
+ uploaded_list = st.file_uploader(
297
+ "Upload bright-field images",
298
+ type=["tif", "tiff", "png", "jpg", "jpeg"],
299
+ accept_multiple_files=True,
300
+ help=f"Select up to {BATCH_MAX_IMAGES} images. Bright-field microscopy (grayscale or RGB).",
301
+ )
302
+ if uploaded_list:
303
+ uploaded_list = uploaded_list[:BATCH_MAX_IMAGES]
304
+ for u in uploaded_list:
305
+ bytes_data = u.read()
306
+ nparr = np.frombuffer(bytes_data, np.uint8)
307
+ decoded = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
308
+ if decoded is not None:
309
+ imgs_batch.append((decoded, u.name))
310
+ u.seek(0)
311
+ else:
312
+ sample_folder = get_sample_folder(S2F_ROOT, model_type)
313
+ sample_files = list_files_in_folder(sample_folder, SAMPLE_EXTENSIONS)
314
+ sample_subfolder_name = model_subfolder(model_type)
315
+ if sample_files:
316
+ selected_samples = st.multiselect(
317
+ f"Select example images (max {BATCH_MAX_IMAGES})",
318
+ sample_files,
319
+ default=None,
320
+ max_selections=BATCH_MAX_IMAGES,
321
+ key=f"sample_batch_{model_type}",
322
+ )
323
+ if selected_samples:
324
+ for fname in selected_samples[:BATCH_MAX_IMAGES]:
325
+ sample_path = os.path.join(sample_folder, fname)
326
+ loaded = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
327
+ if loaded is not None:
328
+ imgs_batch.append((loaded, fname))
329
+ thumbnails = get_cached_sample_thumbnails(model_type, sample_folder, sample_files)
330
+ n_cols = min(5, len(thumbnails))
331
+ cols = st.columns(n_cols)
332
+ for i, (fname, sample_img) in enumerate(thumbnails):
333
+ if sample_img is not None:
334
+ with cols[i % n_cols]:
335
+ st.image(sample_img, caption=fname, width=120)
336
+ else:
337
+ st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
338
  else:
339
+ # Single image mode
340
+ if img_source == "Upload":
341
+ uploaded = st.file_uploader(
342
+ "Upload bright-field image",
343
+ type=["tif", "tiff", "png", "jpg", "jpeg"],
344
+ help="Bright-field microscopy image of a cell or spheroid on a substrate (grayscale or RGB).",
 
 
 
345
  )
346
+ if uploaded:
347
+ bytes_data = uploaded.read()
348
+ nparr = np.frombuffer(bytes_data, np.uint8)
349
+ img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
350
+ uploaded.seek(0)
 
 
 
 
 
 
351
  else:
352
+ sample_folder = get_sample_folder(S2F_ROOT, model_type)
353
+ sample_files = list_files_in_folder(sample_folder, SAMPLE_EXTENSIONS)
354
+ sample_subfolder_name = model_subfolder(model_type)
355
+ if sample_files:
356
+ selected_sample = st.selectbox(
357
+ f"Select example image (from `samples/{sample_subfolder_name}/`)",
358
+ sample_files,
359
+ format_func=lambda x: x,
360
+ key=f"sample_{model_type}",
361
+ )
362
+ if selected_sample:
363
+ sample_path = os.path.join(sample_folder, selected_sample)
364
+ img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
365
+ thumbnails = get_cached_sample_thumbnails(model_type, sample_folder, sample_files)
366
+ n_cols = min(5, len(thumbnails))
367
+ cols = st.columns(n_cols)
368
+ for i, (fname, sample_img) in enumerate(thumbnails):
369
+ if sample_img is not None:
370
+ with cols[i % n_cols]:
371
+ st.image(sample_img, caption=fname, width=120)
372
+ else:
373
+ st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
374
 
375
  col_btn, col_model, col_path = st.columns([1, 1, 3])
376
  with col_btn:
 
380
  with col_path:
381
  ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}" if checkpoint else f"ckp/{ckp_subfolder_name}/"
382
  st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>Checkpoint: <code>{ckp_path}</code></span>", unsafe_allow_html=True)
383
+
384
  has_image = img is not None
385
+ has_batch = len(imgs_batch) > 0
386
 
387
  if "prediction_result" not in st.session_state:
388
  st.session_state["prediction_result"] = None
389
+ if "batch_results" not in st.session_state:
390
+ st.session_state["batch_results"] = None
391
+ if not batch_mode:
392
+ st.session_state["batch_results"] = None # Clear when switching to single mode
393
 
394
+ # Single-image keys (for non-batch)
 
395
  key_img = (uploaded.name if uploaded else None) if img_source == "Upload" else selected_sample
396
  current_key = (model_type, checkpoint, key_img)
397
+ cached = st.session_state["prediction_result"]
398
+ has_cached = cached is not None and cached.get("cache_key") == current_key and not batch_mode
399
+ just_ran = run and checkpoint and has_image and not batch_mode
400
+ just_ran_batch = run and checkpoint and has_batch and batch_mode
401
 
402
 
403
  def get_or_create_predictor(model_type, checkpoint, ckp_folder):
 
414
  return st.session_state["predictor"]
415
 
416
 
417
+ if just_ran_batch:
418
+ st.session_state["prediction_result"] = None
419
+ st.session_state["batch_results"] = None
420
+ with st.spinner("Loading model and predicting..."):
421
+ try:
422
+ predictor = get_or_create_predictor(model_type, checkpoint, ckp_folder)
423
+ sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
424
+ batch_results = []
425
+ progress_bar = st.progress(0, text="Processing images...")
426
+ for idx, (img_b, key_b) in enumerate(imgs_batch):
427
+ progress_bar.progress((idx + 1) / len(imgs_batch), text=f"Processing {key_b}...")
428
+ heatmap, force, pixel_sum = predictor.predict(
429
+ image_array=img_b,
430
+ substrate=sub_val,
431
+ substrate_config=substrate_config if model_type == "single_cell" else None,
432
+ )
433
+ cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
434
+ batch_results.append({
435
+ "img": img_b.copy(),
436
+ "heatmap": heatmap.copy(),
437
+ "force": force,
438
+ "pixel_sum": pixel_sum,
439
+ "key_img": key_b,
440
+ "cell_mask": cell_mask,
441
+ })
442
+ progress_bar.empty()
443
+ st.session_state["batch_results"] = batch_results
444
+ st.success(f"Prediction complete for {len(batch_results)} image(s)!")
445
+ render_batch_results(
446
+ batch_results,
447
+ colormap_name=colormap_name,
448
+ display_mode=display_mode,
449
+ min_percentile=min_percentile,
450
+ max_percentile=max_percentile,
451
+ clip_min=clip_min,
452
+ clip_max=clip_max,
453
+ auto_cell_boundary=auto_cell_boundary,
454
+ )
455
+ except Exception as e:
456
+ st.error(f"Prediction failed: {e}")
457
+ st.code(traceback.format_exc())
458
+
459
+ elif batch_mode and st.session_state.get("batch_results"):
460
+ st.success("Prediction complete!")
461
+ render_batch_results(
462
+ st.session_state["batch_results"],
463
+ colormap_name=colormap_name,
464
+ display_mode=display_mode,
465
+ min_percentile=min_percentile,
466
+ max_percentile=max_percentile,
467
+ clip_min=clip_min,
468
+ clip_max=clip_max,
469
+ auto_cell_boundary=auto_cell_boundary,
470
+ )
471
+
472
+ elif just_ran:
473
  st.session_state["prediction_result"] = None
474
  with st.spinner("Loading model and predicting..."):
475
  try:
 
553
 
554
  elif run and not checkpoint:
555
  st.warning("Please add checkpoint files to the ckp/ folder and select one.")
556
+ elif run and not has_image and not has_batch:
557
  st.warning("Please upload an image or select an example.")
558
+ elif run and batch_mode and not has_batch:
559
+ st.warning(f"Please upload or select 1–{BATCH_MAX_IMAGES} images for batch processing.")
560
 
561
  st.sidebar.divider()
562
  render_system_status()
config/constants.py CHANGED
@@ -12,6 +12,7 @@ DEFAULT_SUBSTRATE = "Fibroblasts_Fibronectin_6KPa"
12
  # UI
13
  CANVAS_SIZE = 320
14
  SAMPLE_THUMBNAIL_LIMIT = 8
 
15
  COLORMAP_N_SAMPLES = 64
16
 
17
  # Model type labels
 
12
  # UI
13
  CANVAS_SIZE = 320
14
  SAMPLE_THUMBNAIL_LIMIT = 8
15
+ BATCH_MAX_IMAGES = 5
16
  COLORMAP_N_SAMPLES = 64
17
 
18
  # Model type labels
ui/components.py CHANGED
@@ -3,6 +3,7 @@ import csv
3
  import html
4
  import io
5
  import os
 
6
 
7
  import cv2
8
  import numpy as np
@@ -115,6 +116,135 @@ def render_system_status():
115
  pass
116
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  # Distinct colors for each region (RGB - heatmap_rgb is RGB)
119
  _REGION_COLORS = [
120
  (255, 102, 0), # orange
@@ -508,7 +638,7 @@ def _add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2):
508
 
509
 
510
  def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, key_img, download_key_suffix="",
511
- colormap_name="Jet", display_mode="Auto", measure_region_dialog=None, auto_cell_boundary=True,
512
  cell_mask=None):
513
  """
514
  Render prediction result: plot, metrics, expander, and download/measure buttons.
@@ -583,6 +713,34 @@ def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, k
583
  with col4:
584
  st.metric("Heatmap mean", f"{np.mean(raw_heatmap):.4f}", help="Average force intensity (full FOV)")
585
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586
  with st.expander("How to read the results"):
587
  if use_cell_metrics:
588
  st.markdown("""
 
3
  import html
4
  import io
5
  import os
6
+ import zipfile
7
 
8
  import cv2
9
  import numpy as np
 
116
  pass
117
 
118
 
119
+ def render_batch_results(batch_results, colormap_name="Jet", display_mode="Default",
120
+ min_percentile=0, max_percentile=100, clip_min=0, clip_max=1,
121
+ auto_cell_boundary=False):
122
+ """
123
+ Render batch prediction results: summary table, bright-field row, heatmap row, and bulk download.
124
+ batch_results: list of dicts with img, heatmap, force, pixel_sum, key_img, cell_mask.
125
+ cell_mask is computed on-the-fly when auto_cell_boundary is True and not stored.
126
+ """
127
+ if not batch_results:
128
+ return
129
+ st.markdown("### Batch results")
130
+ # Resolve cell_mask for each result (compute if needed when auto_cell_boundary toggled on)
131
+ for r in batch_results:
132
+ if auto_cell_boundary and (r.get("cell_mask") is None or not np.any(r.get("cell_mask", 0) > 0)):
133
+ r["_cell_mask"] = estimate_cell_mask(r["heatmap"])
134
+ else:
135
+ r["_cell_mask"] = r.get("cell_mask") if auto_cell_boundary else None
136
+ # Build table rows - consistent column names for both modes
137
+ headers = ["Image", "Force", "Sum", "Max", "Mean"]
138
+ rows = []
139
+ csv_rows = [["image"] + headers[1:]]
140
+ for r in batch_results:
141
+ heatmap = r["heatmap"]
142
+ cell_mask = r.get("_cell_mask")
143
+ key = r["key_img"] or "image"
144
+ if auto_cell_boundary and cell_mask is not None and np.any(cell_mask > 0):
145
+ vals = heatmap[cell_mask > 0]
146
+ cell_pixel_sum = float(np.sum(vals))
147
+ cell_force = cell_pixel_sum * (r["force"] / r["pixel_sum"]) if r["pixel_sum"] > 0 else cell_pixel_sum
148
+ cell_mean = cell_pixel_sum / np.sum(cell_mask) if np.sum(cell_mask) > 0 else 0
149
+ row = [key, f"{cell_force:.2f}", f"{cell_pixel_sum:.2f}",
150
+ f"{np.max(heatmap):.4f}", f"{cell_mean:.4f}"]
151
+ else:
152
+ row = [key, f"{r['force']:.2f}", f"{r['pixel_sum']:.2f}",
153
+ f"{np.max(heatmap):.4f}", f"{np.mean(heatmap):.4f}"]
154
+ rows.append(row)
155
+ csv_rows.append([os.path.splitext(key)[0]] + row[1:])
156
+ # Bright-field row
157
+ st.markdown("**Input: Bright-field images**")
158
+ n_cols = min(5, len(batch_results))
159
+ bf_cols = st.columns(n_cols)
160
+ for i, r in enumerate(batch_results):
161
+ img = r["img"]
162
+ if img.ndim == 2:
163
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
164
+ else:
165
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
166
+ with bf_cols[i % n_cols]:
167
+ st.image(img_rgb, caption=r["key_img"], use_container_width=True)
168
+ # Heatmap row
169
+ st.markdown("**Output: Predicted force maps**")
170
+ hm_cols = st.columns(n_cols)
171
+ for i, r in enumerate(batch_results):
172
+ display_heatmap = apply_display_scale(
173
+ r["heatmap"], display_mode,
174
+ min_percentile=min_percentile, max_percentile=max_percentile,
175
+ clip_min=clip_min, clip_max=clip_max,
176
+ )
177
+ hm_rgb = heatmap_to_rgb_with_contour(
178
+ display_heatmap, colormap_name,
179
+ r.get("_cell_mask") if auto_cell_boundary else None,
180
+ )
181
+ with hm_cols[i % n_cols]:
182
+ st.image(hm_rgb, caption=r["key_img"], use_container_width=True)
183
+ # Table
184
+ st.dataframe(
185
+ {h: [r[i] for r in rows] for i, h in enumerate(headers)},
186
+ use_container_width=True,
187
+ hide_index=True,
188
+ )
189
+ # Histograms in accordion (one per row for visibility)
190
+ with st.expander("Force distribution (histograms)", expanded=False):
191
+ for i, r in enumerate(batch_results):
192
+ heatmap = r["heatmap"]
193
+ cell_mask = r.get("_cell_mask")
194
+ vals = heatmap[cell_mask > 0] if (cell_mask is not None and np.any(cell_mask > 0) and auto_cell_boundary) else heatmap.flatten()
195
+ vals = vals[vals > 0] if np.any(vals > 0) else vals
196
+ st.markdown(f"**{r['key_img']}**")
197
+ if len(vals) > 0:
198
+ fig = go.Figure(data=[go.Histogram(x=vals, nbinsx=50, marker_color="#0d9488")])
199
+ fig.update_layout(
200
+ height=220, margin=dict(l=40, r=20, t=10, b=40),
201
+ xaxis_title="Force value", yaxis_title="Count",
202
+ showlegend=False,
203
+ )
204
+ st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
205
+ else:
206
+ st.caption("No data")
207
+ if i < len(batch_results) - 1:
208
+ st.divider()
209
+ # Bulk downloads: CSV and heatmaps (zip)
210
+ buf_csv = io.StringIO()
211
+ csv.writer(buf_csv).writerows(csv_rows)
212
+ zip_buf = io.BytesIO()
213
+ with zipfile.ZipFile(zip_buf, "w", zipfile.ZIP_DEFLATED) as zf:
214
+ for r in batch_results:
215
+ display_heatmap = apply_display_scale(
216
+ r["heatmap"], display_mode,
217
+ min_percentile=min_percentile, max_percentile=max_percentile,
218
+ clip_min=clip_min, clip_max=clip_max,
219
+ )
220
+ hm_bytes = heatmap_to_png_bytes(
221
+ display_heatmap, colormap_name,
222
+ r.get("_cell_mask") if auto_cell_boundary else None,
223
+ )
224
+ base = os.path.splitext(r["key_img"] or "image")[0]
225
+ zf.writestr(f"{base}_heatmap.png", hm_bytes.getvalue())
226
+ zip_buf.seek(0)
227
+ dl_col1, dl_col2 = st.columns(2)
228
+ with dl_col1:
229
+ st.download_button(
230
+ "Download all as CSV",
231
+ data=buf_csv.getvalue(),
232
+ file_name="s2f_batch_results.csv",
233
+ mime="text/csv",
234
+ key="download_batch_csv",
235
+ icon=":material/download:",
236
+ )
237
+ with dl_col2:
238
+ st.download_button(
239
+ "Download all heatmaps",
240
+ data=zip_buf.getvalue(),
241
+ file_name="s2f_batch_heatmaps.zip",
242
+ mime="application/zip",
243
+ key="download_batch_heatmaps",
244
+ icon=":material/image:",
245
+ )
246
+
247
+
248
  # Distinct colors for each region (RGB - heatmap_rgb is RGB)
249
  _REGION_COLORS = [
250
  (255, 102, 0), # orange
 
638
 
639
 
640
  def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, key_img, download_key_suffix="",
641
+ colormap_name="Jet", display_mode="Default", measure_region_dialog=None, auto_cell_boundary=True,
642
  cell_mask=None):
643
  """
644
  Render prediction result: plot, metrics, expander, and download/measure buttons.
 
713
  with col4:
714
  st.metric("Heatmap mean", f"{np.mean(raw_heatmap):.4f}", help="Average force intensity (full FOV)")
715
 
716
+ # Statistics panel (mean, std, percentiles, histogram)
717
+ with st.expander("Statistics"):
718
+ vals = raw_heatmap[cell_mask > 0] if (cell_mask is not None and np.any(cell_mask > 0) and use_cell_metrics) else raw_heatmap.flatten()
719
+ if len(vals) > 0:
720
+ st.markdown("**Summary**")
721
+ stat_col1, stat_col2, stat_col3 = st.columns(3)
722
+ with stat_col1:
723
+ st.metric("Mean", f"{float(np.mean(vals)):.4f}")
724
+ st.metric("Std", f"{float(np.std(vals)):.4f}")
725
+ with stat_col2:
726
+ p25, p50, p75 = float(np.percentile(vals, 25)), float(np.percentile(vals, 50)), float(np.percentile(vals, 75))
727
+ st.metric("P25", f"{p25:.4f}")
728
+ st.metric("P50 (median)", f"{p50:.4f}")
729
+ st.metric("P75", f"{p75:.4f}")
730
+ with stat_col3:
731
+ p90 = float(np.percentile(vals, 90))
732
+ st.metric("P90", f"{p90:.4f}")
733
+ st.markdown("**Histogram**")
734
+ hist_fig = go.Figure(data=[go.Histogram(x=vals, nbinsx=50, marker_color="#0d9488")])
735
+ hist_fig.update_layout(
736
+ height=220, margin=dict(l=40, r=20, t=20, b=40),
737
+ xaxis_title="Force value", yaxis_title="Count",
738
+ showlegend=False,
739
+ )
740
+ st.plotly_chart(hist_fig, use_container_width=True, config={"displayModeBar": False})
741
+ else:
742
+ st.caption("No nonzero values to compute statistics.")
743
+
744
  with st.expander("How to read the results"):
745
  if use_cell_metrics:
746
  st.markdown("""
utils/display.py CHANGED
@@ -21,14 +21,12 @@ def cv_colormap_to_plotly_colorscale(colormap_name, n_samples=None):
21
 
22
  def apply_display_scale(heatmap, mode, min_percentile=0, max_percentile=100, clip_min=0, clip_max=1):
23
  """
24
- Apply display scaling (Fiji ImageJ B&C style). Display only—does not change underlying values.
25
- - Full: use 0–1 range as-is (clip values outside)
26
- - Percentile: map data at min_percentile..max_percentile to 0..1 (values outside clipped).
27
- - Rescale: map [clip_min, clip_max] to [0, 1]; values outside black/white. Stretches range for contrast.
28
- - Clip: clip values to [clip_min, clip_max], display with original data scale (no stretching).
29
- - Filter: discard values outside [clip_min, clip_max] (black); only values in range show color.
30
  """
31
- if mode == "Full":
32
  return np.clip(heatmap, 0, 1).astype(np.float32)
33
  if mode == "Percentile":
34
  pmin = float(np.percentile(heatmap, min_percentile))
@@ -37,25 +35,7 @@ def apply_display_scale(heatmap, mode, min_percentile=0, max_percentile=100, cli
37
  out = (heatmap.astype(np.float32) - pmin) / (pmax - pmin)
38
  return np.clip(out, 0, 1).astype(np.float32)
39
  return np.clip(heatmap, 0, 1).astype(np.float32)
40
- if mode == "Rescale":
41
- vmin, vmax = float(clip_min), float(clip_max)
42
- if vmax > vmin:
43
- out = (heatmap.astype(np.float32) - vmin) / (vmax - vmin)
44
- return np.clip(out, 0, 1).astype(np.float32)
45
- return np.clip(heatmap, 0, 1).astype(np.float32)
46
- if mode == "Clip":
47
- vmin, vmax = float(clip_min), float(clip_max)
48
- if vmax > vmin:
49
- h = np.clip(heatmap.astype(np.float32), vmin, vmax)
50
- dmin, dmax = float(np.min(heatmap)), float(np.max(heatmap))
51
- if dmax > dmin:
52
- out = (h - dmin) / (dmax - dmin)
53
- return np.clip(out, 0, 1).astype(np.float32)
54
- return np.clip(h, 0, 1).astype(np.float32)
55
- return np.clip(heatmap, 0, 1).astype(np.float32)
56
- if mode == "Threshold":
57
- mode = "Filter" # backward compat
58
- if mode == "Filter":
59
  vmin, vmax = float(clip_min), float(clip_max)
60
  if vmax > vmin:
61
  h = heatmap.astype(np.float32)
 
21
 
22
  def apply_display_scale(heatmap, mode, min_percentile=0, max_percentile=100, clip_min=0, clip_max=1):
23
  """
24
+ Apply display scaling (Fiji/ImageJ style). Display only—does not change underlying values.
25
+ - Default: full 0–1 range as-is.
26
+ - Percentile: map min..max percentiles to 0..1.
27
+ - Range: show only values in [clip_min, clip_max]; others hidden (black).
 
 
28
  """
29
+ if mode == "Default" or mode == "Auto" or mode == "Full":
30
  return np.clip(heatmap, 0, 1).astype(np.float32)
31
  if mode == "Percentile":
32
  pmin = float(np.percentile(heatmap, min_percentile))
 
35
  out = (heatmap.astype(np.float32) - pmin) / (pmax - pmin)
36
  return np.clip(out, 0, 1).astype(np.float32)
37
  return np.clip(heatmap, 0, 1).astype(np.float32)
38
+ if mode == "Range" or mode == "Filter" or mode == "Threshold":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  vmin, vmax = float(clip_min), float(clip_max)
40
  if vmax > vmin:
41
  h = heatmap.astype(np.float32)