Spaces:
Running
Running
added batch mode and corrected force scale
Browse files- app.py +172 -49
- config/constants.py +1 -0
- ui/components.py +159 -1
- utils/display.py +6 -26
app.py
CHANGED
|
@@ -16,6 +16,7 @@ if S2F_ROOT not in sys.path:
|
|
| 16 |
sys.path.insert(0, S2F_ROOT)
|
| 17 |
|
| 18 |
from config.constants import (
|
|
|
|
| 19 |
COLORMAPS,
|
| 20 |
DEFAULT_SUBSTRATE,
|
| 21 |
MODEL_TYPE_LABELS,
|
|
@@ -29,6 +30,7 @@ from utils.display import apply_display_scale
|
|
| 29 |
from ui.components import (
|
| 30 |
build_original_vals,
|
| 31 |
build_cell_vals,
|
|
|
|
| 32 |
render_result_display,
|
| 33 |
render_region_canvas,
|
| 34 |
render_system_status,
|
|
@@ -49,7 +51,7 @@ if HAS_DRAWABLE_CANVAS and ST_DIALOG:
|
|
| 49 |
if raw_heatmap is None:
|
| 50 |
st.warning("No prediction available to measure.")
|
| 51 |
return
|
| 52 |
-
display_mode = st.session_state.get("measure_display_mode", "
|
| 53 |
display_heatmap = apply_display_scale(
|
| 54 |
raw_heatmap, display_mode,
|
| 55 |
min_percentile=st.session_state.get("measure_min_percentile", 0),
|
|
@@ -121,6 +123,9 @@ def _inject_theme_css(theme):
|
|
| 121 |
st.markdown("""
|
| 122 |
<style>
|
| 123 |
section[data-testid="stSidebar"] { width: 380px !important; }
|
|
|
|
|
|
|
|
|
|
| 124 |
section[data-testid="stSidebar"] h2 {
|
| 125 |
font-size: 1.25rem !important;
|
| 126 |
font-weight: 600 !important;
|
|
@@ -233,13 +238,17 @@ with st.sidebar:
|
|
| 233 |
help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.",
|
| 234 |
)
|
| 235 |
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
display_mode = st.radio(
|
| 238 |
-
"
|
| 239 |
-
["
|
| 240 |
-
help="Full: 0–1 as-is. Percentile: min/max percentiles. Rescale: stretch range to colors. Clip: clip, keep scale. Filter: show only in range.",
|
| 241 |
horizontal=True,
|
| 242 |
-
|
| 243 |
)
|
| 244 |
min_percentile, max_percentile = 0, 100
|
| 245 |
clip_min, clip_max = 0.0, 1.0
|
|
@@ -252,14 +261,14 @@ with st.sidebar:
|
|
| 252 |
if min_percentile >= max_percentile:
|
| 253 |
st.warning("Min percentile must be less than max. Using min=0, max=100.")
|
| 254 |
min_percentile, max_percentile = 0, 100
|
| 255 |
-
elif display_mode
|
| 256 |
col_cmin, col_cmax = st.columns(2)
|
| 257 |
with col_cmin:
|
| 258 |
-
clip_min = st.number_input("Min", value=0.0, min_value=
|
| 259 |
-
help="
|
| 260 |
with col_cmax:
|
| 261 |
-
clip_max = st.number_input("Max", value=1.0, min_value=
|
| 262 |
-
help="
|
| 263 |
if clip_min >= clip_max:
|
| 264 |
st.warning("Min must be less than max. Using min=0, max=1.")
|
| 265 |
clip_min, clip_max = 0.0, 1.0
|
|
@@ -275,44 +284,93 @@ with st.sidebar:
|
|
| 275 |
# Main area: image input
|
| 276 |
img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
|
| 277 |
img = None
|
|
|
|
| 278 |
uploaded = None
|
|
|
|
| 279 |
selected_sample = None
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
else:
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
sample_files,
|
| 300 |
-
format_func=lambda x: x,
|
| 301 |
-
key=f"sample_{model_type}",
|
| 302 |
)
|
| 303 |
-
if
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
n_cols = min(5, len(thumbnails))
|
| 309 |
-
cols = st.columns(n_cols)
|
| 310 |
-
for i, (fname, sample_img) in enumerate(thumbnails):
|
| 311 |
-
if sample_img is not None:
|
| 312 |
-
with cols[i % n_cols]:
|
| 313 |
-
st.image(sample_img, caption=fname, width=120)
|
| 314 |
else:
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
col_btn, col_model, col_path = st.columns([1, 1, 3])
|
| 318 |
with col_btn:
|
|
@@ -322,16 +380,24 @@ with col_model:
|
|
| 322 |
with col_path:
|
| 323 |
ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}" if checkpoint else f"ckp/{ckp_subfolder_name}/"
|
| 324 |
st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>Checkpoint: <code>{ckp_path}</code></span>", unsafe_allow_html=True)
|
|
|
|
| 325 |
has_image = img is not None
|
|
|
|
| 326 |
|
| 327 |
if "prediction_result" not in st.session_state:
|
| 328 |
st.session_state["prediction_result"] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
-
|
| 331 |
-
cached = st.session_state["prediction_result"]
|
| 332 |
key_img = (uploaded.name if uploaded else None) if img_source == "Upload" else selected_sample
|
| 333 |
current_key = (model_type, checkpoint, key_img)
|
| 334 |
-
|
|
|
|
|
|
|
|
|
|
| 335 |
|
| 336 |
|
| 337 |
def get_or_create_predictor(model_type, checkpoint, ckp_folder):
|
|
@@ -348,7 +414,62 @@ def get_or_create_predictor(model_type, checkpoint, ckp_folder):
|
|
| 348 |
return st.session_state["predictor"]
|
| 349 |
|
| 350 |
|
| 351 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
st.session_state["prediction_result"] = None
|
| 353 |
with st.spinner("Loading model and predicting..."):
|
| 354 |
try:
|
|
@@ -432,8 +553,10 @@ elif has_cached:
|
|
| 432 |
|
| 433 |
elif run and not checkpoint:
|
| 434 |
st.warning("Please add checkpoint files to the ckp/ folder and select one.")
|
| 435 |
-
elif run and not has_image:
|
| 436 |
st.warning("Please upload an image or select an example.")
|
|
|
|
|
|
|
| 437 |
|
| 438 |
st.sidebar.divider()
|
| 439 |
render_system_status()
|
|
|
|
| 16 |
sys.path.insert(0, S2F_ROOT)
|
| 17 |
|
| 18 |
from config.constants import (
|
| 19 |
+
BATCH_MAX_IMAGES,
|
| 20 |
COLORMAPS,
|
| 21 |
DEFAULT_SUBSTRATE,
|
| 22 |
MODEL_TYPE_LABELS,
|
|
|
|
| 30 |
from ui.components import (
|
| 31 |
build_original_vals,
|
| 32 |
build_cell_vals,
|
| 33 |
+
render_batch_results,
|
| 34 |
render_result_display,
|
| 35 |
render_region_canvas,
|
| 36 |
render_system_status,
|
|
|
|
| 51 |
if raw_heatmap is None:
|
| 52 |
st.warning("No prediction available to measure.")
|
| 53 |
return
|
| 54 |
+
display_mode = st.session_state.get("measure_display_mode", "Default")
|
| 55 |
display_heatmap = apply_display_scale(
|
| 56 |
raw_heatmap, display_mode,
|
| 57 |
min_percentile=st.session_state.get("measure_min_percentile", 0),
|
|
|
|
| 123 |
st.markdown("""
|
| 124 |
<style>
|
| 125 |
section[data-testid="stSidebar"] { width: 380px !important; }
|
| 126 |
+
@media (max-width: 768px) {
|
| 127 |
+
section[data-testid="stSidebar"] { width: 100% !important; max-width: 100% !important; }
|
| 128 |
+
}
|
| 129 |
section[data-testid="stSidebar"] h2 {
|
| 130 |
font-size: 1.25rem !important;
|
| 131 |
font-weight: 600 !important;
|
|
|
|
| 238 |
help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.",
|
| 239 |
)
|
| 240 |
|
| 241 |
+
batch_mode = st.toggle(
|
| 242 |
+
"Batch mode",
|
| 243 |
+
value=False,
|
| 244 |
+
help=f"Process up to {BATCH_MAX_IMAGES} images at once. Upload multiple files or select multiple examples.",
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
display_mode = st.radio(
|
| 248 |
+
"Heatmap display",
|
| 249 |
+
["Default", "Percentile", "Range"],
|
|
|
|
| 250 |
horizontal=True,
|
| 251 |
+
help="Default: full 0–1 range. Percentile: map a percentile range to improve contrast when few bright pixels dominate. Range: show only values in [min, max]; others hidden (black).",
|
| 252 |
)
|
| 253 |
min_percentile, max_percentile = 0, 100
|
| 254 |
clip_min, clip_max = 0.0, 1.0
|
|
|
|
| 261 |
if min_percentile >= max_percentile:
|
| 262 |
st.warning("Min percentile must be less than max. Using min=0, max=100.")
|
| 263 |
min_percentile, max_percentile = 0, 100
|
| 264 |
+
elif display_mode == "Range":
|
| 265 |
col_cmin, col_cmax = st.columns(2)
|
| 266 |
with col_cmin:
|
| 267 |
+
clip_min = st.number_input("Min", value=0.0, min_value=0.0, max_value=1.0, step=0.01, format="%.3f",
|
| 268 |
+
help="Values below this range → hidden (black)")
|
| 269 |
with col_cmax:
|
| 270 |
+
clip_max = st.number_input("Max", value=1.0, min_value=0.0, max_value=1.0, step=0.01, format="%.3f",
|
| 271 |
+
help="Values above this range → hidden (black)")
|
| 272 |
if clip_min >= clip_max:
|
| 273 |
st.warning("Min must be less than max. Using min=0, max=1.")
|
| 274 |
clip_min, clip_max = 0.0, 1.0
|
|
|
|
| 284 |
# Main area: image input
|
| 285 |
img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
|
| 286 |
img = None
|
| 287 |
+
imgs_batch = [] # list of (img, key_img) for batch mode
|
| 288 |
uploaded = None
|
| 289 |
+
uploaded_list = []
|
| 290 |
selected_sample = None
|
| 291 |
+
selected_samples = []
|
| 292 |
+
|
| 293 |
+
if batch_mode:
|
| 294 |
+
# Batch mode: multiple images (max BATCH_MAX_IMAGES)
|
| 295 |
+
if img_source == "Upload":
|
| 296 |
+
uploaded_list = st.file_uploader(
|
| 297 |
+
"Upload bright-field images",
|
| 298 |
+
type=["tif", "tiff", "png", "jpg", "jpeg"],
|
| 299 |
+
accept_multiple_files=True,
|
| 300 |
+
help=f"Select up to {BATCH_MAX_IMAGES} images. Bright-field microscopy (grayscale or RGB).",
|
| 301 |
+
)
|
| 302 |
+
if uploaded_list:
|
| 303 |
+
uploaded_list = uploaded_list[:BATCH_MAX_IMAGES]
|
| 304 |
+
for u in uploaded_list:
|
| 305 |
+
bytes_data = u.read()
|
| 306 |
+
nparr = np.frombuffer(bytes_data, np.uint8)
|
| 307 |
+
decoded = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
|
| 308 |
+
if decoded is not None:
|
| 309 |
+
imgs_batch.append((decoded, u.name))
|
| 310 |
+
u.seek(0)
|
| 311 |
+
else:
|
| 312 |
+
sample_folder = get_sample_folder(S2F_ROOT, model_type)
|
| 313 |
+
sample_files = list_files_in_folder(sample_folder, SAMPLE_EXTENSIONS)
|
| 314 |
+
sample_subfolder_name = model_subfolder(model_type)
|
| 315 |
+
if sample_files:
|
| 316 |
+
selected_samples = st.multiselect(
|
| 317 |
+
f"Select example images (max {BATCH_MAX_IMAGES})",
|
| 318 |
+
sample_files,
|
| 319 |
+
default=None,
|
| 320 |
+
max_selections=BATCH_MAX_IMAGES,
|
| 321 |
+
key=f"sample_batch_{model_type}",
|
| 322 |
+
)
|
| 323 |
+
if selected_samples:
|
| 324 |
+
for fname in selected_samples[:BATCH_MAX_IMAGES]:
|
| 325 |
+
sample_path = os.path.join(sample_folder, fname)
|
| 326 |
+
loaded = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
|
| 327 |
+
if loaded is not None:
|
| 328 |
+
imgs_batch.append((loaded, fname))
|
| 329 |
+
thumbnails = get_cached_sample_thumbnails(model_type, sample_folder, sample_files)
|
| 330 |
+
n_cols = min(5, len(thumbnails))
|
| 331 |
+
cols = st.columns(n_cols)
|
| 332 |
+
for i, (fname, sample_img) in enumerate(thumbnails):
|
| 333 |
+
if sample_img is not None:
|
| 334 |
+
with cols[i % n_cols]:
|
| 335 |
+
st.image(sample_img, caption=fname, width=120)
|
| 336 |
+
else:
|
| 337 |
+
st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
|
| 338 |
else:
|
| 339 |
+
# Single image mode
|
| 340 |
+
if img_source == "Upload":
|
| 341 |
+
uploaded = st.file_uploader(
|
| 342 |
+
"Upload bright-field image",
|
| 343 |
+
type=["tif", "tiff", "png", "jpg", "jpeg"],
|
| 344 |
+
help="Bright-field microscopy image of a cell or spheroid on a substrate (grayscale or RGB).",
|
|
|
|
|
|
|
|
|
|
| 345 |
)
|
| 346 |
+
if uploaded:
|
| 347 |
+
bytes_data = uploaded.read()
|
| 348 |
+
nparr = np.frombuffer(bytes_data, np.uint8)
|
| 349 |
+
img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
|
| 350 |
+
uploaded.seek(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
else:
|
| 352 |
+
sample_folder = get_sample_folder(S2F_ROOT, model_type)
|
| 353 |
+
sample_files = list_files_in_folder(sample_folder, SAMPLE_EXTENSIONS)
|
| 354 |
+
sample_subfolder_name = model_subfolder(model_type)
|
| 355 |
+
if sample_files:
|
| 356 |
+
selected_sample = st.selectbox(
|
| 357 |
+
f"Select example image (from `samples/{sample_subfolder_name}/`)",
|
| 358 |
+
sample_files,
|
| 359 |
+
format_func=lambda x: x,
|
| 360 |
+
key=f"sample_{model_type}",
|
| 361 |
+
)
|
| 362 |
+
if selected_sample:
|
| 363 |
+
sample_path = os.path.join(sample_folder, selected_sample)
|
| 364 |
+
img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
|
| 365 |
+
thumbnails = get_cached_sample_thumbnails(model_type, sample_folder, sample_files)
|
| 366 |
+
n_cols = min(5, len(thumbnails))
|
| 367 |
+
cols = st.columns(n_cols)
|
| 368 |
+
for i, (fname, sample_img) in enumerate(thumbnails):
|
| 369 |
+
if sample_img is not None:
|
| 370 |
+
with cols[i % n_cols]:
|
| 371 |
+
st.image(sample_img, caption=fname, width=120)
|
| 372 |
+
else:
|
| 373 |
+
st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
|
| 374 |
|
| 375 |
col_btn, col_model, col_path = st.columns([1, 1, 3])
|
| 376 |
with col_btn:
|
|
|
|
| 380 |
with col_path:
|
| 381 |
ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}" if checkpoint else f"ckp/{ckp_subfolder_name}/"
|
| 382 |
st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>Checkpoint: <code>{ckp_path}</code></span>", unsafe_allow_html=True)
|
| 383 |
+
|
| 384 |
has_image = img is not None
|
| 385 |
+
has_batch = len(imgs_batch) > 0
|
| 386 |
|
| 387 |
if "prediction_result" not in st.session_state:
|
| 388 |
st.session_state["prediction_result"] = None
|
| 389 |
+
if "batch_results" not in st.session_state:
|
| 390 |
+
st.session_state["batch_results"] = None
|
| 391 |
+
if not batch_mode:
|
| 392 |
+
st.session_state["batch_results"] = None # Clear when switching to single mode
|
| 393 |
|
| 394 |
+
# Single-image keys (for non-batch)
|
|
|
|
| 395 |
key_img = (uploaded.name if uploaded else None) if img_source == "Upload" else selected_sample
|
| 396 |
current_key = (model_type, checkpoint, key_img)
|
| 397 |
+
cached = st.session_state["prediction_result"]
|
| 398 |
+
has_cached = cached is not None and cached.get("cache_key") == current_key and not batch_mode
|
| 399 |
+
just_ran = run and checkpoint and has_image and not batch_mode
|
| 400 |
+
just_ran_batch = run and checkpoint and has_batch and batch_mode
|
| 401 |
|
| 402 |
|
| 403 |
def get_or_create_predictor(model_type, checkpoint, ckp_folder):
|
|
|
|
| 414 |
return st.session_state["predictor"]
|
| 415 |
|
| 416 |
|
| 417 |
+
if just_ran_batch:
|
| 418 |
+
st.session_state["prediction_result"] = None
|
| 419 |
+
st.session_state["batch_results"] = None
|
| 420 |
+
with st.spinner("Loading model and predicting..."):
|
| 421 |
+
try:
|
| 422 |
+
predictor = get_or_create_predictor(model_type, checkpoint, ckp_folder)
|
| 423 |
+
sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
|
| 424 |
+
batch_results = []
|
| 425 |
+
progress_bar = st.progress(0, text="Processing images...")
|
| 426 |
+
for idx, (img_b, key_b) in enumerate(imgs_batch):
|
| 427 |
+
progress_bar.progress((idx + 1) / len(imgs_batch), text=f"Processing {key_b}...")
|
| 428 |
+
heatmap, force, pixel_sum = predictor.predict(
|
| 429 |
+
image_array=img_b,
|
| 430 |
+
substrate=sub_val,
|
| 431 |
+
substrate_config=substrate_config if model_type == "single_cell" else None,
|
| 432 |
+
)
|
| 433 |
+
cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
|
| 434 |
+
batch_results.append({
|
| 435 |
+
"img": img_b.copy(),
|
| 436 |
+
"heatmap": heatmap.copy(),
|
| 437 |
+
"force": force,
|
| 438 |
+
"pixel_sum": pixel_sum,
|
| 439 |
+
"key_img": key_b,
|
| 440 |
+
"cell_mask": cell_mask,
|
| 441 |
+
})
|
| 442 |
+
progress_bar.empty()
|
| 443 |
+
st.session_state["batch_results"] = batch_results
|
| 444 |
+
st.success(f"Prediction complete for {len(batch_results)} image(s)!")
|
| 445 |
+
render_batch_results(
|
| 446 |
+
batch_results,
|
| 447 |
+
colormap_name=colormap_name,
|
| 448 |
+
display_mode=display_mode,
|
| 449 |
+
min_percentile=min_percentile,
|
| 450 |
+
max_percentile=max_percentile,
|
| 451 |
+
clip_min=clip_min,
|
| 452 |
+
clip_max=clip_max,
|
| 453 |
+
auto_cell_boundary=auto_cell_boundary,
|
| 454 |
+
)
|
| 455 |
+
except Exception as e:
|
| 456 |
+
st.error(f"Prediction failed: {e}")
|
| 457 |
+
st.code(traceback.format_exc())
|
| 458 |
+
|
| 459 |
+
elif batch_mode and st.session_state.get("batch_results"):
|
| 460 |
+
st.success("Prediction complete!")
|
| 461 |
+
render_batch_results(
|
| 462 |
+
st.session_state["batch_results"],
|
| 463 |
+
colormap_name=colormap_name,
|
| 464 |
+
display_mode=display_mode,
|
| 465 |
+
min_percentile=min_percentile,
|
| 466 |
+
max_percentile=max_percentile,
|
| 467 |
+
clip_min=clip_min,
|
| 468 |
+
clip_max=clip_max,
|
| 469 |
+
auto_cell_boundary=auto_cell_boundary,
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
elif just_ran:
|
| 473 |
st.session_state["prediction_result"] = None
|
| 474 |
with st.spinner("Loading model and predicting..."):
|
| 475 |
try:
|
|
|
|
| 553 |
|
| 554 |
elif run and not checkpoint:
|
| 555 |
st.warning("Please add checkpoint files to the ckp/ folder and select one.")
|
| 556 |
+
elif run and not has_image and not has_batch:
|
| 557 |
st.warning("Please upload an image or select an example.")
|
| 558 |
+
elif run and batch_mode and not has_batch:
|
| 559 |
+
st.warning(f"Please upload or select 1–{BATCH_MAX_IMAGES} images for batch processing.")
|
| 560 |
|
| 561 |
st.sidebar.divider()
|
| 562 |
render_system_status()
|
config/constants.py
CHANGED
|
@@ -12,6 +12,7 @@ DEFAULT_SUBSTRATE = "Fibroblasts_Fibronectin_6KPa"
|
|
| 12 |
# UI
|
| 13 |
CANVAS_SIZE = 320
|
| 14 |
SAMPLE_THUMBNAIL_LIMIT = 8
|
|
|
|
| 15 |
COLORMAP_N_SAMPLES = 64
|
| 16 |
|
| 17 |
# Model type labels
|
|
|
|
| 12 |
# UI
|
| 13 |
CANVAS_SIZE = 320
|
| 14 |
SAMPLE_THUMBNAIL_LIMIT = 8
|
| 15 |
+
BATCH_MAX_IMAGES = 5
|
| 16 |
COLORMAP_N_SAMPLES = 64
|
| 17 |
|
| 18 |
# Model type labels
|
ui/components.py
CHANGED
|
@@ -3,6 +3,7 @@ import csv
|
|
| 3 |
import html
|
| 4 |
import io
|
| 5 |
import os
|
|
|
|
| 6 |
|
| 7 |
import cv2
|
| 8 |
import numpy as np
|
|
@@ -115,6 +116,135 @@ def render_system_status():
|
|
| 115 |
pass
|
| 116 |
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
# Distinct colors for each region (RGB - heatmap_rgb is RGB)
|
| 119 |
_REGION_COLORS = [
|
| 120 |
(255, 102, 0), # orange
|
|
@@ -508,7 +638,7 @@ def _add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2):
|
|
| 508 |
|
| 509 |
|
| 510 |
def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, key_img, download_key_suffix="",
|
| 511 |
-
colormap_name="Jet", display_mode="
|
| 512 |
cell_mask=None):
|
| 513 |
"""
|
| 514 |
Render prediction result: plot, metrics, expander, and download/measure buttons.
|
|
@@ -583,6 +713,34 @@ def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, k
|
|
| 583 |
with col4:
|
| 584 |
st.metric("Heatmap mean", f"{np.mean(raw_heatmap):.4f}", help="Average force intensity (full FOV)")
|
| 585 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
with st.expander("How to read the results"):
|
| 587 |
if use_cell_metrics:
|
| 588 |
st.markdown("""
|
|
|
|
| 3 |
import html
|
| 4 |
import io
|
| 5 |
import os
|
| 6 |
+
import zipfile
|
| 7 |
|
| 8 |
import cv2
|
| 9 |
import numpy as np
|
|
|
|
| 116 |
pass
|
| 117 |
|
| 118 |
|
| 119 |
+
def render_batch_results(batch_results, colormap_name="Jet", display_mode="Default",
|
| 120 |
+
min_percentile=0, max_percentile=100, clip_min=0, clip_max=1,
|
| 121 |
+
auto_cell_boundary=False):
|
| 122 |
+
"""
|
| 123 |
+
Render batch prediction results: summary table, bright-field row, heatmap row, and bulk download.
|
| 124 |
+
batch_results: list of dicts with img, heatmap, force, pixel_sum, key_img, cell_mask.
|
| 125 |
+
cell_mask is computed on-the-fly when auto_cell_boundary is True and not stored.
|
| 126 |
+
"""
|
| 127 |
+
if not batch_results:
|
| 128 |
+
return
|
| 129 |
+
st.markdown("### Batch results")
|
| 130 |
+
# Resolve cell_mask for each result (compute if needed when auto_cell_boundary toggled on)
|
| 131 |
+
for r in batch_results:
|
| 132 |
+
if auto_cell_boundary and (r.get("cell_mask") is None or not np.any(r.get("cell_mask", 0) > 0)):
|
| 133 |
+
r["_cell_mask"] = estimate_cell_mask(r["heatmap"])
|
| 134 |
+
else:
|
| 135 |
+
r["_cell_mask"] = r.get("cell_mask") if auto_cell_boundary else None
|
| 136 |
+
# Build table rows - consistent column names for both modes
|
| 137 |
+
headers = ["Image", "Force", "Sum", "Max", "Mean"]
|
| 138 |
+
rows = []
|
| 139 |
+
csv_rows = [["image"] + headers[1:]]
|
| 140 |
+
for r in batch_results:
|
| 141 |
+
heatmap = r["heatmap"]
|
| 142 |
+
cell_mask = r.get("_cell_mask")
|
| 143 |
+
key = r["key_img"] or "image"
|
| 144 |
+
if auto_cell_boundary and cell_mask is not None and np.any(cell_mask > 0):
|
| 145 |
+
vals = heatmap[cell_mask > 0]
|
| 146 |
+
cell_pixel_sum = float(np.sum(vals))
|
| 147 |
+
cell_force = cell_pixel_sum * (r["force"] / r["pixel_sum"]) if r["pixel_sum"] > 0 else cell_pixel_sum
|
| 148 |
+
cell_mean = cell_pixel_sum / np.sum(cell_mask) if np.sum(cell_mask) > 0 else 0
|
| 149 |
+
row = [key, f"{cell_force:.2f}", f"{cell_pixel_sum:.2f}",
|
| 150 |
+
f"{np.max(heatmap):.4f}", f"{cell_mean:.4f}"]
|
| 151 |
+
else:
|
| 152 |
+
row = [key, f"{r['force']:.2f}", f"{r['pixel_sum']:.2f}",
|
| 153 |
+
f"{np.max(heatmap):.4f}", f"{np.mean(heatmap):.4f}"]
|
| 154 |
+
rows.append(row)
|
| 155 |
+
csv_rows.append([os.path.splitext(key)[0]] + row[1:])
|
| 156 |
+
# Bright-field row
|
| 157 |
+
st.markdown("**Input: Bright-field images**")
|
| 158 |
+
n_cols = min(5, len(batch_results))
|
| 159 |
+
bf_cols = st.columns(n_cols)
|
| 160 |
+
for i, r in enumerate(batch_results):
|
| 161 |
+
img = r["img"]
|
| 162 |
+
if img.ndim == 2:
|
| 163 |
+
img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
| 164 |
+
else:
|
| 165 |
+
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 166 |
+
with bf_cols[i % n_cols]:
|
| 167 |
+
st.image(img_rgb, caption=r["key_img"], use_container_width=True)
|
| 168 |
+
# Heatmap row
|
| 169 |
+
st.markdown("**Output: Predicted force maps**")
|
| 170 |
+
hm_cols = st.columns(n_cols)
|
| 171 |
+
for i, r in enumerate(batch_results):
|
| 172 |
+
display_heatmap = apply_display_scale(
|
| 173 |
+
r["heatmap"], display_mode,
|
| 174 |
+
min_percentile=min_percentile, max_percentile=max_percentile,
|
| 175 |
+
clip_min=clip_min, clip_max=clip_max,
|
| 176 |
+
)
|
| 177 |
+
hm_rgb = heatmap_to_rgb_with_contour(
|
| 178 |
+
display_heatmap, colormap_name,
|
| 179 |
+
r.get("_cell_mask") if auto_cell_boundary else None,
|
| 180 |
+
)
|
| 181 |
+
with hm_cols[i % n_cols]:
|
| 182 |
+
st.image(hm_rgb, caption=r["key_img"], use_container_width=True)
|
| 183 |
+
# Table
|
| 184 |
+
st.dataframe(
|
| 185 |
+
{h: [r[i] for r in rows] for i, h in enumerate(headers)},
|
| 186 |
+
use_container_width=True,
|
| 187 |
+
hide_index=True,
|
| 188 |
+
)
|
| 189 |
+
# Histograms in accordion (one per row for visibility)
|
| 190 |
+
with st.expander("Force distribution (histograms)", expanded=False):
|
| 191 |
+
for i, r in enumerate(batch_results):
|
| 192 |
+
heatmap = r["heatmap"]
|
| 193 |
+
cell_mask = r.get("_cell_mask")
|
| 194 |
+
vals = heatmap[cell_mask > 0] if (cell_mask is not None and np.any(cell_mask > 0) and auto_cell_boundary) else heatmap.flatten()
|
| 195 |
+
vals = vals[vals > 0] if np.any(vals > 0) else vals
|
| 196 |
+
st.markdown(f"**{r['key_img']}**")
|
| 197 |
+
if len(vals) > 0:
|
| 198 |
+
fig = go.Figure(data=[go.Histogram(x=vals, nbinsx=50, marker_color="#0d9488")])
|
| 199 |
+
fig.update_layout(
|
| 200 |
+
height=220, margin=dict(l=40, r=20, t=10, b=40),
|
| 201 |
+
xaxis_title="Force value", yaxis_title="Count",
|
| 202 |
+
showlegend=False,
|
| 203 |
+
)
|
| 204 |
+
st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
|
| 205 |
+
else:
|
| 206 |
+
st.caption("No data")
|
| 207 |
+
if i < len(batch_results) - 1:
|
| 208 |
+
st.divider()
|
| 209 |
+
# Bulk downloads: CSV and heatmaps (zip)
|
| 210 |
+
buf_csv = io.StringIO()
|
| 211 |
+
csv.writer(buf_csv).writerows(csv_rows)
|
| 212 |
+
zip_buf = io.BytesIO()
|
| 213 |
+
with zipfile.ZipFile(zip_buf, "w", zipfile.ZIP_DEFLATED) as zf:
|
| 214 |
+
for r in batch_results:
|
| 215 |
+
display_heatmap = apply_display_scale(
|
| 216 |
+
r["heatmap"], display_mode,
|
| 217 |
+
min_percentile=min_percentile, max_percentile=max_percentile,
|
| 218 |
+
clip_min=clip_min, clip_max=clip_max,
|
| 219 |
+
)
|
| 220 |
+
hm_bytes = heatmap_to_png_bytes(
|
| 221 |
+
display_heatmap, colormap_name,
|
| 222 |
+
r.get("_cell_mask") if auto_cell_boundary else None,
|
| 223 |
+
)
|
| 224 |
+
base = os.path.splitext(r["key_img"] or "image")[0]
|
| 225 |
+
zf.writestr(f"{base}_heatmap.png", hm_bytes.getvalue())
|
| 226 |
+
zip_buf.seek(0)
|
| 227 |
+
dl_col1, dl_col2 = st.columns(2)
|
| 228 |
+
with dl_col1:
|
| 229 |
+
st.download_button(
|
| 230 |
+
"Download all as CSV",
|
| 231 |
+
data=buf_csv.getvalue(),
|
| 232 |
+
file_name="s2f_batch_results.csv",
|
| 233 |
+
mime="text/csv",
|
| 234 |
+
key="download_batch_csv",
|
| 235 |
+
icon=":material/download:",
|
| 236 |
+
)
|
| 237 |
+
with dl_col2:
|
| 238 |
+
st.download_button(
|
| 239 |
+
"Download all heatmaps",
|
| 240 |
+
data=zip_buf.getvalue(),
|
| 241 |
+
file_name="s2f_batch_heatmaps.zip",
|
| 242 |
+
mime="application/zip",
|
| 243 |
+
key="download_batch_heatmaps",
|
| 244 |
+
icon=":material/image:",
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
# Distinct colors for each region (RGB - heatmap_rgb is RGB)
|
| 249 |
_REGION_COLORS = [
|
| 250 |
(255, 102, 0), # orange
|
|
|
|
| 638 |
|
| 639 |
|
| 640 |
def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, key_img, download_key_suffix="",
|
| 641 |
+
colormap_name="Jet", display_mode="Default", measure_region_dialog=None, auto_cell_boundary=True,
|
| 642 |
cell_mask=None):
|
| 643 |
"""
|
| 644 |
Render prediction result: plot, metrics, expander, and download/measure buttons.
|
|
|
|
| 713 |
with col4:
|
| 714 |
st.metric("Heatmap mean", f"{np.mean(raw_heatmap):.4f}", help="Average force intensity (full FOV)")
|
| 715 |
|
| 716 |
+
# Statistics panel (mean, std, percentiles, histogram)
|
| 717 |
+
with st.expander("Statistics"):
|
| 718 |
+
vals = raw_heatmap[cell_mask > 0] if (cell_mask is not None and np.any(cell_mask > 0) and use_cell_metrics) else raw_heatmap.flatten()
|
| 719 |
+
if len(vals) > 0:
|
| 720 |
+
st.markdown("**Summary**")
|
| 721 |
+
stat_col1, stat_col2, stat_col3 = st.columns(3)
|
| 722 |
+
with stat_col1:
|
| 723 |
+
st.metric("Mean", f"{float(np.mean(vals)):.4f}")
|
| 724 |
+
st.metric("Std", f"{float(np.std(vals)):.4f}")
|
| 725 |
+
with stat_col2:
|
| 726 |
+
p25, p50, p75 = float(np.percentile(vals, 25)), float(np.percentile(vals, 50)), float(np.percentile(vals, 75))
|
| 727 |
+
st.metric("P25", f"{p25:.4f}")
|
| 728 |
+
st.metric("P50 (median)", f"{p50:.4f}")
|
| 729 |
+
st.metric("P75", f"{p75:.4f}")
|
| 730 |
+
with stat_col3:
|
| 731 |
+
p90 = float(np.percentile(vals, 90))
|
| 732 |
+
st.metric("P90", f"{p90:.4f}")
|
| 733 |
+
st.markdown("**Histogram**")
|
| 734 |
+
hist_fig = go.Figure(data=[go.Histogram(x=vals, nbinsx=50, marker_color="#0d9488")])
|
| 735 |
+
hist_fig.update_layout(
|
| 736 |
+
height=220, margin=dict(l=40, r=20, t=20, b=40),
|
| 737 |
+
xaxis_title="Force value", yaxis_title="Count",
|
| 738 |
+
showlegend=False,
|
| 739 |
+
)
|
| 740 |
+
st.plotly_chart(hist_fig, use_container_width=True, config={"displayModeBar": False})
|
| 741 |
+
else:
|
| 742 |
+
st.caption("No nonzero values to compute statistics.")
|
| 743 |
+
|
| 744 |
with st.expander("How to read the results"):
|
| 745 |
if use_cell_metrics:
|
| 746 |
st.markdown("""
|
utils/display.py
CHANGED
|
@@ -21,14 +21,12 @@ def cv_colormap_to_plotly_colorscale(colormap_name, n_samples=None):
|
|
| 21 |
|
| 22 |
def apply_display_scale(heatmap, mode, min_percentile=0, max_percentile=100, clip_min=0, clip_max=1):
|
| 23 |
"""
|
| 24 |
-
Apply display scaling (Fiji
|
| 25 |
-
-
|
| 26 |
-
- Percentile: map
|
| 27 |
-
-
|
| 28 |
-
- Clip: clip values to [clip_min, clip_max], display with original data scale (no stretching).
|
| 29 |
-
- Filter: discard values outside [clip_min, clip_max] (black); only values in range show color.
|
| 30 |
"""
|
| 31 |
-
if mode == "Full":
|
| 32 |
return np.clip(heatmap, 0, 1).astype(np.float32)
|
| 33 |
if mode == "Percentile":
|
| 34 |
pmin = float(np.percentile(heatmap, min_percentile))
|
|
@@ -37,25 +35,7 @@ def apply_display_scale(heatmap, mode, min_percentile=0, max_percentile=100, cli
|
|
| 37 |
out = (heatmap.astype(np.float32) - pmin) / (pmax - pmin)
|
| 38 |
return np.clip(out, 0, 1).astype(np.float32)
|
| 39 |
return np.clip(heatmap, 0, 1).astype(np.float32)
|
| 40 |
-
if mode == "
|
| 41 |
-
vmin, vmax = float(clip_min), float(clip_max)
|
| 42 |
-
if vmax > vmin:
|
| 43 |
-
out = (heatmap.astype(np.float32) - vmin) / (vmax - vmin)
|
| 44 |
-
return np.clip(out, 0, 1).astype(np.float32)
|
| 45 |
-
return np.clip(heatmap, 0, 1).astype(np.float32)
|
| 46 |
-
if mode == "Clip":
|
| 47 |
-
vmin, vmax = float(clip_min), float(clip_max)
|
| 48 |
-
if vmax > vmin:
|
| 49 |
-
h = np.clip(heatmap.astype(np.float32), vmin, vmax)
|
| 50 |
-
dmin, dmax = float(np.min(heatmap)), float(np.max(heatmap))
|
| 51 |
-
if dmax > dmin:
|
| 52 |
-
out = (h - dmin) / (dmax - dmin)
|
| 53 |
-
return np.clip(out, 0, 1).astype(np.float32)
|
| 54 |
-
return np.clip(h, 0, 1).astype(np.float32)
|
| 55 |
-
return np.clip(heatmap, 0, 1).astype(np.float32)
|
| 56 |
-
if mode == "Threshold":
|
| 57 |
-
mode = "Filter" # backward compat
|
| 58 |
-
if mode == "Filter":
|
| 59 |
vmin, vmax = float(clip_min), float(clip_max)
|
| 60 |
if vmax > vmin:
|
| 61 |
h = heatmap.astype(np.float32)
|
|
|
|
| 21 |
|
| 22 |
def apply_display_scale(heatmap, mode, min_percentile=0, max_percentile=100, clip_min=0, clip_max=1):
|
| 23 |
"""
|
| 24 |
+
Apply display scaling (Fiji/ImageJ style). Display only—does not change underlying values.
|
| 25 |
+
- Default: full 0–1 range as-is.
|
| 26 |
+
- Percentile: map min..max percentiles to 0..1.
|
| 27 |
+
- Range: show only values in [clip_min, clip_max]; others hidden (black).
|
|
|
|
|
|
|
| 28 |
"""
|
| 29 |
+
if mode == "Default" or mode == "Auto" or mode == "Full":
|
| 30 |
return np.clip(heatmap, 0, 1).astype(np.float32)
|
| 31 |
if mode == "Percentile":
|
| 32 |
pmin = float(np.percentile(heatmap, min_percentile))
|
|
|
|
| 35 |
out = (heatmap.astype(np.float32) - pmin) / (pmax - pmin)
|
| 36 |
return np.clip(out, 0, 1).astype(np.float32)
|
| 37 |
return np.clip(heatmap, 0, 1).astype(np.float32)
|
| 38 |
+
if mode == "Range" or mode == "Filter" or mode == "Threshold":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
vmin, vmax = float(clip_min), float(clip_max)
|
| 40 |
if vmax > vmin:
|
| 41 |
h = heatmap.astype(np.float32)
|