Spaces:
Running
Running
added different scales
Browse files- app.py +65 -14
- utils/display.py +42 -9
app.py
CHANGED
|
@@ -49,8 +49,14 @@ if HAS_DRAWABLE_CANVAS and ST_DIALOG:
|
|
| 49 |
if raw_heatmap is None:
|
| 50 |
st.warning("No prediction available to measure.")
|
| 51 |
return
|
| 52 |
-
display_mode = st.session_state.get("measure_display_mode", "
|
| 53 |
-
display_heatmap = apply_display_scale(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
bf_img = st.session_state.get("measure_bf_img")
|
| 55 |
original_vals = st.session_state.get("measure_original_vals")
|
| 56 |
cell_vals = st.session_state.get("measure_cell_vals")
|
|
@@ -73,12 +79,17 @@ def _get_measure_dialog_fn():
|
|
| 73 |
|
| 74 |
|
| 75 |
def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name,
|
| 76 |
-
display_mode, auto_cell_boundary, cell_mask=None
|
|
|
|
| 77 |
"""Populate session state for the measure tool. If cell_mask is None and auto_cell_boundary, computes it."""
|
| 78 |
if cell_mask is None and auto_cell_boundary:
|
| 79 |
cell_mask = estimate_cell_mask(heatmap)
|
| 80 |
st.session_state["measure_raw_heatmap"] = heatmap.copy()
|
| 81 |
st.session_state["measure_display_mode"] = display_mode
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
st.session_state["measure_bf_img"] = img.copy()
|
| 83 |
st.session_state["measure_input_filename"] = key_img or "image"
|
| 84 |
st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
|
|
@@ -216,24 +227,48 @@ with st.sidebar:
|
|
| 216 |
except FileNotFoundError:
|
| 217 |
st.error("config/substrate_settings.json not found")
|
| 218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
display_mode = st.radio(
|
| 220 |
-
"
|
| 221 |
-
["
|
| 222 |
-
help="
|
| 223 |
horizontal=True,
|
|
|
|
| 224 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
colormap_name = st.selectbox(
|
| 226 |
"Heatmap colormap",
|
| 227 |
list(COLORMAPS.keys()),
|
| 228 |
help="Color scheme for the force map. Viridis is often preferred for accessibility.",
|
| 229 |
)
|
| 230 |
|
| 231 |
-
auto_cell_boundary = st.checkbox(
|
| 232 |
-
"Auto boundary",
|
| 233 |
-
value=False,
|
| 234 |
-
help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.",
|
| 235 |
-
)
|
| 236 |
-
|
| 237 |
theme = st.radio("Theme", ["Light", "Dark"], horizontal=True, key="theme_selector")
|
| 238 |
_inject_theme_css(theme)
|
| 239 |
|
|
@@ -327,7 +362,13 @@ if just_ran:
|
|
| 327 |
|
| 328 |
st.success("Prediction complete!")
|
| 329 |
|
| 330 |
-
display_heatmap = apply_display_scale(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
cache_key = (model_type, checkpoint, key_img)
|
| 333 |
st.session_state["prediction_result"] = {
|
|
@@ -341,6 +382,8 @@ if just_ran:
|
|
| 341 |
_populate_measure_session_state(
|
| 342 |
heatmap, img, pixel_sum, force, key_img, colormap_name,
|
| 343 |
display_mode, auto_cell_boundary, cell_mask=cell_mask,
|
|
|
|
|
|
|
| 344 |
)
|
| 345 |
render_result_display(
|
| 346 |
img, heatmap, display_heatmap, pixel_sum, force, key_img,
|
|
@@ -358,11 +401,19 @@ if just_ran:
|
|
| 358 |
elif has_cached:
|
| 359 |
r = st.session_state["prediction_result"]
|
| 360 |
img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
|
| 361 |
-
display_heatmap = apply_display_scale(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
|
| 363 |
_populate_measure_session_state(
|
| 364 |
heatmap, img, pixel_sum, force, key_img, colormap_name,
|
| 365 |
display_mode, auto_cell_boundary, cell_mask=cell_mask,
|
|
|
|
|
|
|
| 366 |
)
|
| 367 |
|
| 368 |
if st.session_state.pop("open_measure_dialog", False):
|
|
|
|
| 49 |
if raw_heatmap is None:
|
| 50 |
st.warning("No prediction available to measure.")
|
| 51 |
return
|
| 52 |
+
display_mode = st.session_state.get("measure_display_mode", "Full")
|
| 53 |
+
display_heatmap = apply_display_scale(
|
| 54 |
+
raw_heatmap, display_mode,
|
| 55 |
+
min_percentile=st.session_state.get("measure_min_percentile", 0),
|
| 56 |
+
max_percentile=st.session_state.get("measure_max_percentile", 100),
|
| 57 |
+
clip_min=st.session_state.get("measure_clip_min", 0),
|
| 58 |
+
clip_max=st.session_state.get("measure_clip_max", 1),
|
| 59 |
+
)
|
| 60 |
bf_img = st.session_state.get("measure_bf_img")
|
| 61 |
original_vals = st.session_state.get("measure_original_vals")
|
| 62 |
cell_vals = st.session_state.get("measure_cell_vals")
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name,
|
| 82 |
+
display_mode, auto_cell_boundary, cell_mask=None,
|
| 83 |
+
min_percentile=0, max_percentile=100, clip_min=0, clip_max=1):
|
| 84 |
"""Populate session state for the measure tool. If cell_mask is None and auto_cell_boundary, computes it."""
|
| 85 |
if cell_mask is None and auto_cell_boundary:
|
| 86 |
cell_mask = estimate_cell_mask(heatmap)
|
| 87 |
st.session_state["measure_raw_heatmap"] = heatmap.copy()
|
| 88 |
st.session_state["measure_display_mode"] = display_mode
|
| 89 |
+
st.session_state["measure_min_percentile"] = min_percentile
|
| 90 |
+
st.session_state["measure_max_percentile"] = max_percentile
|
| 91 |
+
st.session_state["measure_clip_min"] = clip_min
|
| 92 |
+
st.session_state["measure_clip_max"] = clip_max
|
| 93 |
st.session_state["measure_bf_img"] = img.copy()
|
| 94 |
st.session_state["measure_input_filename"] = key_img or "image"
|
| 95 |
st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
|
|
|
|
| 227 |
except FileNotFoundError:
|
| 228 |
st.error("config/substrate_settings.json not found")
|
| 229 |
|
| 230 |
+
auto_cell_boundary = st.toggle(
|
| 231 |
+
"Auto boundary",
|
| 232 |
+
value=False,
|
| 233 |
+
help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.",
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
st.markdown('<p style="font-size: 0.95rem; font-weight: 500; margin-bottom: 0.5rem;">Heatmap display</p>', unsafe_allow_html=True)
|
| 237 |
display_mode = st.radio(
|
| 238 |
+
"Mode",
|
| 239 |
+
["Full", "Percentile", "Rescale", "Clip", "Filter"],
|
| 240 |
+
help="Full: 0–1 as-is. Percentile: min/max percentiles. Rescale: stretch range to colors. Clip: clip, keep scale. Filter: show only in range.",
|
| 241 |
horizontal=True,
|
| 242 |
+
label_visibility="collapsed",
|
| 243 |
)
|
| 244 |
+
min_percentile, max_percentile = 0, 100
|
| 245 |
+
clip_min, clip_max = 0.0, 1.0
|
| 246 |
+
if display_mode == "Percentile":
|
| 247 |
+
col_pmin, col_pmax = st.columns(2)
|
| 248 |
+
with col_pmin:
|
| 249 |
+
min_percentile = st.slider("Min percentile", 0, 100, 2, 1, help="Values below this percentile → black")
|
| 250 |
+
with col_pmax:
|
| 251 |
+
max_percentile = st.slider("Max percentile", 0, 100, 99, 1, help="Values above this percentile → white")
|
| 252 |
+
if min_percentile >= max_percentile:
|
| 253 |
+
st.warning("Min percentile must be less than max. Using min=0, max=100.")
|
| 254 |
+
min_percentile, max_percentile = 0, 100
|
| 255 |
+
elif display_mode in ("Rescale", "Clip", "Filter"):
|
| 256 |
+
col_cmin, col_cmax = st.columns(2)
|
| 257 |
+
with col_cmin:
|
| 258 |
+
clip_min = st.number_input("Min", value=0.0, min_value=None, max_value=None, step=0.01, format="%.3f",
|
| 259 |
+
help="Rescale: below → black. Clip: clamp to min. Filter: below → discarded.")
|
| 260 |
+
with col_cmax:
|
| 261 |
+
clip_max = st.number_input("Max", value=1.0, min_value=None, max_value=None, step=0.01, format="%.3f",
|
| 262 |
+
help="Rescale: above → white. Clip: clamp to max. Filter: above → discarded.")
|
| 263 |
+
if clip_min >= clip_max:
|
| 264 |
+
st.warning("Min must be less than max. Using min=0, max=1.")
|
| 265 |
+
clip_min, clip_max = 0.0, 1.0
|
| 266 |
colormap_name = st.selectbox(
|
| 267 |
"Heatmap colormap",
|
| 268 |
list(COLORMAPS.keys()),
|
| 269 |
help="Color scheme for the force map. Viridis is often preferred for accessibility.",
|
| 270 |
)
|
| 271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
theme = st.radio("Theme", ["Light", "Dark"], horizontal=True, key="theme_selector")
|
| 273 |
_inject_theme_css(theme)
|
| 274 |
|
|
|
|
| 362 |
|
| 363 |
st.success("Prediction complete!")
|
| 364 |
|
| 365 |
+
display_heatmap = apply_display_scale(
|
| 366 |
+
heatmap, display_mode,
|
| 367 |
+
min_percentile=min_percentile,
|
| 368 |
+
max_percentile=max_percentile,
|
| 369 |
+
clip_min=clip_min,
|
| 370 |
+
clip_max=clip_max,
|
| 371 |
+
)
|
| 372 |
|
| 373 |
cache_key = (model_type, checkpoint, key_img)
|
| 374 |
st.session_state["prediction_result"] = {
|
|
|
|
| 382 |
_populate_measure_session_state(
|
| 383 |
heatmap, img, pixel_sum, force, key_img, colormap_name,
|
| 384 |
display_mode, auto_cell_boundary, cell_mask=cell_mask,
|
| 385 |
+
min_percentile=min_percentile, max_percentile=max_percentile,
|
| 386 |
+
clip_min=clip_min, clip_max=clip_max,
|
| 387 |
)
|
| 388 |
render_result_display(
|
| 389 |
img, heatmap, display_heatmap, pixel_sum, force, key_img,
|
|
|
|
| 401 |
elif has_cached:
|
| 402 |
r = st.session_state["prediction_result"]
|
| 403 |
img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
|
| 404 |
+
display_heatmap = apply_display_scale(
|
| 405 |
+
heatmap, display_mode,
|
| 406 |
+
min_percentile=min_percentile,
|
| 407 |
+
max_percentile=max_percentile,
|
| 408 |
+
clip_min=clip_min,
|
| 409 |
+
clip_max=clip_max,
|
| 410 |
+
)
|
| 411 |
cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
|
| 412 |
_populate_measure_session_state(
|
| 413 |
heatmap, img, pixel_sum, force, key_img, colormap_name,
|
| 414 |
display_mode, auto_cell_boundary, cell_mask=cell_mask,
|
| 415 |
+
min_percentile=min_percentile, max_percentile=max_percentile,
|
| 416 |
+
clip_min=clip_min, clip_max=clip_max,
|
| 417 |
)
|
| 418 |
|
| 419 |
if st.session_state.pop("open_measure_dialog", False):
|
utils/display.py
CHANGED
|
@@ -19,16 +19,49 @@ def cv_colormap_to_plotly_colorscale(colormap_name, n_samples=None):
|
|
| 19 |
return scale
|
| 20 |
|
| 21 |
|
| 22 |
-
def apply_display_scale(heatmap, mode):
|
| 23 |
"""
|
| 24 |
-
Apply display scaling (Fiji
|
| 25 |
-
-
|
| 26 |
-
-
|
|
|
|
|
|
|
|
|
|
| 27 |
"""
|
| 28 |
-
if mode == "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
return np.clip(heatmap, 0, 1).astype(np.float32)
|
| 30 |
-
hmin, hmax = float(np.min(heatmap)), float(np.max(heatmap))
|
| 31 |
-
if hmax > hmin:
|
| 32 |
-
out = (heatmap.astype(np.float32) - hmin) / (hmax - hmin)
|
| 33 |
-
return np.clip(out, 0, 1).astype(np.float32)
|
| 34 |
return np.clip(heatmap, 0, 1).astype(np.float32)
|
|
|
|
| 19 |
return scale
|
| 20 |
|
| 21 |
|
| 22 |
+
def apply_display_scale(heatmap, mode, min_percentile=0, max_percentile=100, clip_min=0, clip_max=1):
|
| 23 |
"""
|
| 24 |
+
Apply display scaling (Fiji ImageJ B&C style). Display only—does not change underlying values.
|
| 25 |
+
- Full: use 0–1 range as-is (clip values outside)
|
| 26 |
+
- Percentile: map data at min_percentile..max_percentile to 0..1 (values outside clipped).
|
| 27 |
+
- Rescale: map [clip_min, clip_max] to [0, 1]; values outside → black/white. Stretches range for contrast.
|
| 28 |
+
- Clip: clip values to [clip_min, clip_max], display with original data scale (no stretching).
|
| 29 |
+
- Filter: discard values outside [clip_min, clip_max] (black); only values in range show color.
|
| 30 |
"""
|
| 31 |
+
if mode == "Full":
|
| 32 |
+
return np.clip(heatmap, 0, 1).astype(np.float32)
|
| 33 |
+
if mode == "Percentile":
|
| 34 |
+
pmin = float(np.percentile(heatmap, min_percentile))
|
| 35 |
+
pmax = float(np.percentile(heatmap, max_percentile))
|
| 36 |
+
if pmax > pmin:
|
| 37 |
+
out = (heatmap.astype(np.float32) - pmin) / (pmax - pmin)
|
| 38 |
+
return np.clip(out, 0, 1).astype(np.float32)
|
| 39 |
+
return np.clip(heatmap, 0, 1).astype(np.float32)
|
| 40 |
+
if mode == "Rescale":
|
| 41 |
+
vmin, vmax = float(clip_min), float(clip_max)
|
| 42 |
+
if vmax > vmin:
|
| 43 |
+
out = (heatmap.astype(np.float32) - vmin) / (vmax - vmin)
|
| 44 |
+
return np.clip(out, 0, 1).astype(np.float32)
|
| 45 |
+
return np.clip(heatmap, 0, 1).astype(np.float32)
|
| 46 |
+
if mode == "Clip":
|
| 47 |
+
vmin, vmax = float(clip_min), float(clip_max)
|
| 48 |
+
if vmax > vmin:
|
| 49 |
+
h = np.clip(heatmap.astype(np.float32), vmin, vmax)
|
| 50 |
+
dmin, dmax = float(np.min(heatmap)), float(np.max(heatmap))
|
| 51 |
+
if dmax > dmin:
|
| 52 |
+
out = (h - dmin) / (dmax - dmin)
|
| 53 |
+
return np.clip(out, 0, 1).astype(np.float32)
|
| 54 |
+
return np.clip(h, 0, 1).astype(np.float32)
|
| 55 |
+
return np.clip(heatmap, 0, 1).astype(np.float32)
|
| 56 |
+
if mode == "Threshold":
|
| 57 |
+
mode = "Filter" # backward compat
|
| 58 |
+
if mode == "Filter":
|
| 59 |
+
vmin, vmax = float(clip_min), float(clip_max)
|
| 60 |
+
if vmax > vmin:
|
| 61 |
+
h = heatmap.astype(np.float32)
|
| 62 |
+
mask = (h >= vmin) & (h <= vmax)
|
| 63 |
+
out = np.zeros_like(h)
|
| 64 |
+
out[mask] = (h[mask] - vmin) / (vmax - vmin)
|
| 65 |
+
return np.clip(out, 0, 1).astype(np.float32)
|
| 66 |
return np.clip(heatmap, 0, 1).astype(np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
return np.clip(heatmap, 0, 1).astype(np.float32)
|