kaveh commited on
Commit
1e95da3
·
1 Parent(s): 3f7dffe

added different scales

Browse files
Files changed (2) hide show
  1. app.py +65 -14
  2. utils/display.py +42 -9
app.py CHANGED
@@ -49,8 +49,14 @@ if HAS_DRAWABLE_CANVAS and ST_DIALOG:
49
  if raw_heatmap is None:
50
  st.warning("No prediction available to measure.")
51
  return
52
- display_mode = st.session_state.get("measure_display_mode", "Auto")
53
- display_heatmap = apply_display_scale(raw_heatmap, display_mode)
 
 
 
 
 
 
54
  bf_img = st.session_state.get("measure_bf_img")
55
  original_vals = st.session_state.get("measure_original_vals")
56
  cell_vals = st.session_state.get("measure_cell_vals")
@@ -73,12 +79,17 @@ def _get_measure_dialog_fn():
73
 
74
 
75
  def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name,
76
- display_mode, auto_cell_boundary, cell_mask=None):
 
77
  """Populate session state for the measure tool. If cell_mask is None and auto_cell_boundary, computes it."""
78
  if cell_mask is None and auto_cell_boundary:
79
  cell_mask = estimate_cell_mask(heatmap)
80
  st.session_state["measure_raw_heatmap"] = heatmap.copy()
81
  st.session_state["measure_display_mode"] = display_mode
 
 
 
 
82
  st.session_state["measure_bf_img"] = img.copy()
83
  st.session_state["measure_input_filename"] = key_img or "image"
84
  st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
@@ -216,24 +227,48 @@ with st.sidebar:
216
  except FileNotFoundError:
217
  st.error("config/substrate_settings.json not found")
218
 
 
 
 
 
 
 
 
219
  display_mode = st.radio(
220
- "Force scale",
221
- ["Auto", "Fixed"],
222
- help="Auto: map data range to full color scale (Fiji-style). Fixed: use 0-1 range. Metrics always show raw values.",
223
  horizontal=True,
 
224
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  colormap_name = st.selectbox(
226
  "Heatmap colormap",
227
  list(COLORMAPS.keys()),
228
  help="Color scheme for the force map. Viridis is often preferred for accessibility.",
229
  )
230
 
231
- auto_cell_boundary = st.checkbox(
232
- "Auto boundary",
233
- value=False,
234
- help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.",
235
- )
236
-
237
  theme = st.radio("Theme", ["Light", "Dark"], horizontal=True, key="theme_selector")
238
  _inject_theme_css(theme)
239
 
@@ -327,7 +362,13 @@ if just_ran:
327
 
328
  st.success("Prediction complete!")
329
 
330
- display_heatmap = apply_display_scale(heatmap, display_mode)
 
 
 
 
 
 
331
 
332
  cache_key = (model_type, checkpoint, key_img)
333
  st.session_state["prediction_result"] = {
@@ -341,6 +382,8 @@ if just_ran:
341
  _populate_measure_session_state(
342
  heatmap, img, pixel_sum, force, key_img, colormap_name,
343
  display_mode, auto_cell_boundary, cell_mask=cell_mask,
 
 
344
  )
345
  render_result_display(
346
  img, heatmap, display_heatmap, pixel_sum, force, key_img,
@@ -358,11 +401,19 @@ if just_ran:
358
  elif has_cached:
359
  r = st.session_state["prediction_result"]
360
  img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
361
- display_heatmap = apply_display_scale(heatmap, display_mode)
 
 
 
 
 
 
362
  cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
363
  _populate_measure_session_state(
364
  heatmap, img, pixel_sum, force, key_img, colormap_name,
365
  display_mode, auto_cell_boundary, cell_mask=cell_mask,
 
 
366
  )
367
 
368
  if st.session_state.pop("open_measure_dialog", False):
 
49
  if raw_heatmap is None:
50
  st.warning("No prediction available to measure.")
51
  return
52
+ display_mode = st.session_state.get("measure_display_mode", "Full")
53
+ display_heatmap = apply_display_scale(
54
+ raw_heatmap, display_mode,
55
+ min_percentile=st.session_state.get("measure_min_percentile", 0),
56
+ max_percentile=st.session_state.get("measure_max_percentile", 100),
57
+ clip_min=st.session_state.get("measure_clip_min", 0),
58
+ clip_max=st.session_state.get("measure_clip_max", 1),
59
+ )
60
  bf_img = st.session_state.get("measure_bf_img")
61
  original_vals = st.session_state.get("measure_original_vals")
62
  cell_vals = st.session_state.get("measure_cell_vals")
 
79
 
80
 
81
  def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name,
82
+ display_mode, auto_cell_boundary, cell_mask=None,
83
+ min_percentile=0, max_percentile=100, clip_min=0, clip_max=1):
84
  """Populate session state for the measure tool. If cell_mask is None and auto_cell_boundary, computes it."""
85
  if cell_mask is None and auto_cell_boundary:
86
  cell_mask = estimate_cell_mask(heatmap)
87
  st.session_state["measure_raw_heatmap"] = heatmap.copy()
88
  st.session_state["measure_display_mode"] = display_mode
89
+ st.session_state["measure_min_percentile"] = min_percentile
90
+ st.session_state["measure_max_percentile"] = max_percentile
91
+ st.session_state["measure_clip_min"] = clip_min
92
+ st.session_state["measure_clip_max"] = clip_max
93
  st.session_state["measure_bf_img"] = img.copy()
94
  st.session_state["measure_input_filename"] = key_img or "image"
95
  st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
 
227
  except FileNotFoundError:
228
  st.error("config/substrate_settings.json not found")
229
 
230
+ auto_cell_boundary = st.toggle(
231
+ "Auto boundary",
232
+ value=False,
233
+ help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.",
234
+ )
235
+
236
+ st.markdown('<p style="font-size: 0.95rem; font-weight: 500; margin-bottom: 0.5rem;">Heatmap display</p>', unsafe_allow_html=True)
237
  display_mode = st.radio(
238
+ "Mode",
239
+ ["Full", "Percentile", "Rescale", "Clip", "Filter"],
240
+ help="Full: 0–1 as-is. Percentile: min/max percentiles. Rescale: stretch range to colors. Clip: clip, keep scale. Filter: show only in range.",
241
  horizontal=True,
242
+ label_visibility="collapsed",
243
  )
244
+ min_percentile, max_percentile = 0, 100
245
+ clip_min, clip_max = 0.0, 1.0
246
+ if display_mode == "Percentile":
247
+ col_pmin, col_pmax = st.columns(2)
248
+ with col_pmin:
249
+ min_percentile = st.slider("Min percentile", 0, 100, 2, 1, help="Values below this percentile → black")
250
+ with col_pmax:
251
+ max_percentile = st.slider("Max percentile", 0, 100, 99, 1, help="Values above this percentile → white")
252
+ if min_percentile >= max_percentile:
253
+ st.warning("Min percentile must be less than max. Using min=0, max=100.")
254
+ min_percentile, max_percentile = 0, 100
255
+ elif display_mode in ("Rescale", "Clip", "Filter"):
256
+ col_cmin, col_cmax = st.columns(2)
257
+ with col_cmin:
258
+ clip_min = st.number_input("Min", value=0.0, min_value=None, max_value=None, step=0.01, format="%.3f",
259
+ help="Rescale: below → black. Clip: clamp to min. Filter: below → discarded.")
260
+ with col_cmax:
261
+ clip_max = st.number_input("Max", value=1.0, min_value=None, max_value=None, step=0.01, format="%.3f",
262
+ help="Rescale: above → white. Clip: clamp to max. Filter: above → discarded.")
263
+ if clip_min >= clip_max:
264
+ st.warning("Min must be less than max. Using min=0, max=1.")
265
+ clip_min, clip_max = 0.0, 1.0
266
  colormap_name = st.selectbox(
267
  "Heatmap colormap",
268
  list(COLORMAPS.keys()),
269
  help="Color scheme for the force map. Viridis is often preferred for accessibility.",
270
  )
271
 
 
 
 
 
 
 
272
  theme = st.radio("Theme", ["Light", "Dark"], horizontal=True, key="theme_selector")
273
  _inject_theme_css(theme)
274
 
 
362
 
363
  st.success("Prediction complete!")
364
 
365
+ display_heatmap = apply_display_scale(
366
+ heatmap, display_mode,
367
+ min_percentile=min_percentile,
368
+ max_percentile=max_percentile,
369
+ clip_min=clip_min,
370
+ clip_max=clip_max,
371
+ )
372
 
373
  cache_key = (model_type, checkpoint, key_img)
374
  st.session_state["prediction_result"] = {
 
382
  _populate_measure_session_state(
383
  heatmap, img, pixel_sum, force, key_img, colormap_name,
384
  display_mode, auto_cell_boundary, cell_mask=cell_mask,
385
+ min_percentile=min_percentile, max_percentile=max_percentile,
386
+ clip_min=clip_min, clip_max=clip_max,
387
  )
388
  render_result_display(
389
  img, heatmap, display_heatmap, pixel_sum, force, key_img,
 
401
  elif has_cached:
402
  r = st.session_state["prediction_result"]
403
  img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
404
+ display_heatmap = apply_display_scale(
405
+ heatmap, display_mode,
406
+ min_percentile=min_percentile,
407
+ max_percentile=max_percentile,
408
+ clip_min=clip_min,
409
+ clip_max=clip_max,
410
+ )
411
  cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
412
  _populate_measure_session_state(
413
  heatmap, img, pixel_sum, force, key_img, colormap_name,
414
  display_mode, auto_cell_boundary, cell_mask=cell_mask,
415
+ min_percentile=min_percentile, max_percentile=max_percentile,
416
+ clip_min=clip_min, clip_max=clip_max,
417
  )
418
 
419
  if st.session_state.pop("open_measure_dialog", False):
utils/display.py CHANGED
@@ -19,16 +19,49 @@ def cv_colormap_to_plotly_colorscale(colormap_name, n_samples=None):
19
  return scale
20
 
21
 
22
- def apply_display_scale(heatmap, mode):
23
  """
24
- Apply display scaling (Fiji-style). Display only—does not change underlying values.
25
- - Auto: map data min..max to 0..1 (full color range)
26
- - Fixed: use 0-1 range as-is
 
 
 
27
  """
28
- if mode == "Fixed":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  return np.clip(heatmap, 0, 1).astype(np.float32)
30
- hmin, hmax = float(np.min(heatmap)), float(np.max(heatmap))
31
- if hmax > hmin:
32
- out = (heatmap.astype(np.float32) - hmin) / (hmax - hmin)
33
- return np.clip(out, 0, 1).astype(np.float32)
34
  return np.clip(heatmap, 0, 1).astype(np.float32)
 
19
  return scale
20
 
21
 
22
+ def apply_display_scale(heatmap, mode, min_percentile=0, max_percentile=100, clip_min=0, clip_max=1):
23
  """
24
+ Apply display scaling (Fiji ImageJ B&C style). Display only—does not change underlying values.
25
+ - Full: use 01 range as-is (clip values outside)
26
+ - Percentile: map data at min_percentile..max_percentile to 0..1 (values outside clipped).
27
+ - Rescale: map [clip_min, clip_max] to [0, 1]; values outside → black/white. Stretches range for contrast.
28
+ - Clip: clip values to [clip_min, clip_max], display with original data scale (no stretching).
29
+ - Filter: discard values outside [clip_min, clip_max] (black); only values in range show color.
30
  """
31
+ if mode == "Full":
32
+ return np.clip(heatmap, 0, 1).astype(np.float32)
33
+ if mode == "Percentile":
34
+ pmin = float(np.percentile(heatmap, min_percentile))
35
+ pmax = float(np.percentile(heatmap, max_percentile))
36
+ if pmax > pmin:
37
+ out = (heatmap.astype(np.float32) - pmin) / (pmax - pmin)
38
+ return np.clip(out, 0, 1).astype(np.float32)
39
+ return np.clip(heatmap, 0, 1).astype(np.float32)
40
+ if mode == "Rescale":
41
+ vmin, vmax = float(clip_min), float(clip_max)
42
+ if vmax > vmin:
43
+ out = (heatmap.astype(np.float32) - vmin) / (vmax - vmin)
44
+ return np.clip(out, 0, 1).astype(np.float32)
45
+ return np.clip(heatmap, 0, 1).astype(np.float32)
46
+ if mode == "Clip":
47
+ vmin, vmax = float(clip_min), float(clip_max)
48
+ if vmax > vmin:
49
+ h = np.clip(heatmap.astype(np.float32), vmin, vmax)
50
+ dmin, dmax = float(np.min(heatmap)), float(np.max(heatmap))
51
+ if dmax > dmin:
52
+ out = (h - dmin) / (dmax - dmin)
53
+ return np.clip(out, 0, 1).astype(np.float32)
54
+ return np.clip(h, 0, 1).astype(np.float32)
55
+ return np.clip(heatmap, 0, 1).astype(np.float32)
56
+ if mode == "Threshold":
57
+ mode = "Filter" # backward compat
58
+ if mode == "Filter":
59
+ vmin, vmax = float(clip_min), float(clip_max)
60
+ if vmax > vmin:
61
+ h = heatmap.astype(np.float32)
62
+ mask = (h >= vmin) & (h <= vmax)
63
+ out = np.zeros_like(h)
64
+ out[mask] = (h[mask] - vmin) / (vmax - vmin)
65
+ return np.clip(out, 0, 1).astype(np.float32)
66
  return np.clip(heatmap, 0, 1).astype(np.float32)
 
 
 
 
67
  return np.clip(heatmap, 0, 1).astype(np.float32)