Spaces:
Running
Running
File size: 23,814 Bytes
6cbe52d e6abf01 c1833b6 6cbe52d 750ec75 600a077 dff64e9 750ec75 478c826 750ec75 478c826 750ec75 478c826 750ec75 6cbe52d 750ec75 dff64e9 750ec75 6cbe52d e6abf01 00e4d88 e6abf01 1f836a4 750ec75 e6abf01 750ec75 e6abf01 dff64e9 fd5af45 1e95da3 1f836a4 1e95da3 e6abf01 750ec75 e6abf01 51d2595 750ec75 e6abf01 750ec75 e6abf01 699f6c2 1e95da3 fd5af45 478c826 699f6c2 1e95da3 1f836a4 699f6c2 00e4d88 750ec75 1f836a4 750ec75 00e4d88 750ec75 e6abf01 00e4d88 e6abf01 750ec75 699f6c2 6cbe52d 750ec75 478c826 750ec75 00e4d88 750ec75 6cbe52d 00e4d88 893dcb6 09e9bfe 893dcb6 6cbe52d 893dcb6 6cbe52d 478c826 750ec75 6cbe52d 893dcb6 fd5af45 893dcb6 750ec75 893dcb6 6cbe52d dff64e9 00e4d88 8d06964 00e4d88 893dcb6 1f836a4 893dcb6 00e4d88 fd5af45 00e4d88 750ec75 6cbe52d 00e4d88 6cbe52d dff64e9 6cbe52d dff64e9 6cbe52d dff64e9 00e4d88 6cbe52d dff64e9 6cbe52d dff64e9 6cbe52d 00e4d88 f858aef 00e4d88 f858aef 00e4d88 dff64e9 6cbe52d dff64e9 6cbe52d 8d3f0e2 dff64e9 8d3f0e2 dff64e9 8d3f0e2 1f836a4 dff64e9 8d3f0e2 750ec75 00e4d88 fd5af45 1f836a4 00e4d88 1f836a4 00e4d88 1f836a4 00e4d88 1f836a4 00e4d88 1f836a4 00e4d88 750ec75 dff64e9 b278d0d dff64e9 00e4d88 dff64e9 600a077 00e4d88 dff64e9 00e4d88 dff64e9 b278d0d dff64e9 1f836a4 dff64e9 b278d0d dff64e9 1f836a4 dff64e9 750ec75 6cbe52d 00e4d88 478c826 51d2595 1f836a4 00e4d88 51d2595 00e4d88 fd5af45 00e4d88 1f836a4 699f6c2 6cbe52d 8d3f0e2 00e4d88 fd5af45 00e4d88 1f836a4 750ec75 8d3f0e2 6cbe52d dff64e9 8d3f0e2 dff64e9 6cbe52d 8317b77 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 
161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 | """
Shape2Force (S2F) - GUI for force map prediction from bright-field microscopy images.
"""
import os
import sys
import traceback
# Suppress OpenCV verbose logging (cv2.utils.logging not reliably available in all builds)
os.environ.setdefault("OPENCV_LOG_LEVEL", "ERROR")
import cv2
import numpy as np
import streamlit as st
# Absolute directory containing this file. It is passed to the path helpers
# (get_ckp_base, get_sample_folder) and used to locate static/s2f_styles.css.
S2F_ROOT = os.path.dirname(os.path.abspath(__file__))
# Put that directory at the front of sys.path (once) — presumably so the
# project-local imports that follow resolve regardless of the launch directory.
if S2F_ROOT not in sys.path:
sys.path.insert(0, S2F_ROOT)
from config.constants import (
BATCH_INFERENCE_SIZE,
BATCH_MAX_IMAGES,
COLORMAPS,
DEFAULT_SUBSTRATE,
MODEL_TYPE_LABELS,
SAMPLE_EXTENSIONS,
SAMPLE_THUMBNAIL_LIMIT,
)
from utils.paths import get_ckp_base, get_ckp_folder, get_sample_folder, list_files_in_folder, model_subfolder
from utils.segmentation import estimate_cell_mask
from utils.substrate_settings import list_substrates
from utils.display import apply_display_scale
from ui.components import (
build_original_vals,
build_cell_vals,
render_batch_results,
render_result_display,
render_region_canvas,
ST_DIALOG,
HAS_DRAWABLE_CANVAS,
)
# Citation text shown in the page footer. It contains inline HTML (<b>) and is
# interpolated into a st.markdown call that uses unsafe_allow_html=True.
CITATION = (
"Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
"<b>\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"</b>, 2026."
)
def _inference_cache_condition_key(model_type, use_manual, substrate_val, substrate_config):
"""Hashable key for substrate / manual conditions so cache invalidates when single-cell inputs change."""
if model_type != "single_cell":
return None
if use_manual and substrate_config is not None:
return (
"manual",
round(float(substrate_config["pixelsize"]), 6),
round(float(substrate_config["young"]), 2),
)
return ("preset", str(substrate_val))
# Measure tool dialog: defined early so it exists before render_result_display uses it
if HAS_DRAWABLE_CANVAS and ST_DIALOG:
    @ST_DIALOG("Measure tool", width="medium")
    def measure_region_dialog():
        """Render the region-measure canvas from the values stored under the measure_* session keys.

        Shows a warning and returns early when no raw heatmap has been stored.
        """
        state = st.session_state
        raw_heatmap = state.get("measure_raw_heatmap")
        if raw_heatmap is None:
            st.warning("No prediction available to measure.")
            return
        # Re-apply the display scaling that was stored alongside the raw heatmap.
        scaled_heatmap = apply_display_scale(
            raw_heatmap,
            state.get("measure_display_mode", "Default"),
            clip_min=state.get("measure_clip_min", 0),
            clip_max=state.get("measure_clip_max", 1),
            clamp_only=state.get("measure_clamp_only", False),
        )
        render_region_canvas(
            scaled_heatmap,
            raw_heatmap=raw_heatmap,
            bf_img=state.get("measure_bf_img"),
            original_vals=state.get("measure_original_vals"),
            cell_vals=state.get("measure_cell_vals"),
            cell_mask=state.get("measure_cell_mask"),
            key_suffix="dialog",
            input_filename=state.get("measure_input_filename", "image"),
            colormap_name=state.get("measure_colormap", "Jet"),
        )
else:
    def measure_region_dialog():
        """No-op stand-in used when the canvas component or the dialog decorator is unavailable."""
        return None
def _get_measure_dialog_fn():
    """Return the measure dialog callable when both the canvas and dialog support exist, else None."""
    if HAS_DRAWABLE_CANVAS and ST_DIALOG:
        return measure_region_dialog
    return None
def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name,
                                    display_mode, auto_cell_boundary, cell_mask=None,
                                    clip_min=0, clip_max=1, clamp_only=False):
    """Store everything the measure tool reads under the measure_* session-state keys.

    When auto_cell_boundary is on and no cell_mask is passed, the mask is
    computed with estimate_cell_mask(heatmap). When auto_cell_boundary is off,
    the cell values and cell mask entries are stored as None even if a mask
    was passed in. Copies of heatmap and img are stored, not the originals.
    """
    state = st.session_state
    if auto_cell_boundary and cell_mask is None:
        cell_mask = estimate_cell_mask(heatmap)

    state["measure_raw_heatmap"] = heatmap.copy()
    state["measure_display_mode"] = display_mode
    state["measure_clip_min"] = clip_min
    state["measure_clip_max"] = clip_max
    state["measure_clamp_only"] = clamp_only
    state["measure_bf_img"] = img.copy()
    # Fall back to a generic name when the image has no key (e.g. nothing selected).
    state["measure_input_filename"] = key_img or "image"
    state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
    state["measure_colormap"] = colormap_name
    state["measure_auto_cell_on"] = auto_cell_boundary

    if auto_cell_boundary:
        state["measure_cell_vals"] = build_cell_vals(heatmap, cell_mask, pixel_sum, force)
        state["measure_cell_mask"] = cell_mask
    else:
        state["measure_cell_vals"] = None
        state["measure_cell_mask"] = None
# Configure the browser tab (title, icon) and use the wide page layout.
st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="wide")
# Load the Inter web font used by the custom stylesheet.
st.markdown(
    '<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">',
    unsafe_allow_html=True,
)
# Inject the project stylesheet when it exists; the app still runs without it.
_css_path = os.path.join(S2F_ROOT, "static", "s2f_styles.css")
if os.path.exists(_css_path):
    # Read as UTF-8 explicitly: without an encoding, open() decodes with the
    # platform locale, which can garble or reject non-ASCII characters in the CSS.
    with open(_css_path, "r", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# Page header banner (styled by the .s2f-header rules of the stylesheet).
st.markdown("""
<div class="s2f-header">
<h1>🦠 Shape2Force (S2F)</h1>
<p>Predict traction force maps from bright-field microscopy images of cells or spheroids</p>
</div>
""", unsafe_allow_html=True)
# Folders
ckp_base = get_ckp_base(S2F_ROOT)
def get_cached_sample_thumbnails(model_type, sample_folder, sample_files):
    """Return (filename, grayscale image) pairs for the first SAMPLE_THUMBNAIL_LIMIT sample files.

    Results are memoised in st.session_state["sample_thumbnails"], keyed by
    (model_type, tuple(sample_files)), so the files are read from disk only
    the first time a given file listing is seen in the session. The image
    element is whatever cv2.imread returns for that path.
    """
    cache_key = (model_type, tuple(sample_files))
    if "sample_thumbnails" not in st.session_state:
        st.session_state["sample_thumbnails"] = {}
    thumbnail_cache = st.session_state["sample_thumbnails"]
    if cache_key in thumbnail_cache:
        return thumbnail_cache[cache_key]
    loaded = [
        (name, cv2.imread(os.path.join(sample_folder, name), cv2.IMREAD_GRAYSCALE))
        for name in sample_files[:SAMPLE_THUMBNAIL_LIMIT]
    ]
    thumbnail_cache[cache_key] = loaded
    return loaded
def _render_sample_selector(model_type, batch_mode):
"""
Render the example-image picker (Example mode) and load the chosen image(s).

Lists the files of the sample folder for ``model_type`` that match SAMPLE_EXTENSIONS.
When there are none, an info message is shown and the empty defaults are returned.

Returns (img, imgs_batch, selected_sample, selected_samples):
- single mode: img is the cv2.imread result for the selected file, imgs_batch=[].
- batch mode: img=None, imgs_batch is a list of (image, filename) for every selected
  file that cv2.imread could load (at most BATCH_MAX_IMAGES).
"""
sample_folder = get_sample_folder(S2F_ROOT, model_type)
sample_files = list_files_in_folder(sample_folder, SAMPLE_EXTENSIONS)
sample_subfolder_name = model_subfolder(model_type)
# Defaults returned unchanged for whichever mode is not active.
img = None
imgs_batch = []
selected_sample = None
selected_samples = []
# Nothing to choose from: tell the user where to put sample images and bail out.
if not sample_files:
st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
return img, imgs_batch, selected_sample, selected_samples
if batch_mode:
# Widget keys include model_type, so each model type keeps its own selection.
selected_samples = st.multiselect(
f"Select example images (max {BATCH_MAX_IMAGES})",
sample_files,
default=None,
max_selections=BATCH_MAX_IMAGES,
key=f"sample_batch_{model_type}",
)
if selected_samples:
for fname in selected_samples[:BATCH_MAX_IMAGES]:
sample_path = os.path.join(sample_folder, fname)
loaded = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
# Files that cv2 cannot read are skipped without a message.
if loaded is not None:
imgs_batch.append((loaded, fname))
else:
selected_sample = st.selectbox(
f"Select example image (from `samples/{sample_subfolder_name}/`)",
sample_files,
format_func=lambda x: x,
key=f"sample_{model_type}",
)
if selected_sample:
sample_path = os.path.join(sample_folder, selected_sample)
img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
# Thumbnail strip: cached previews laid out in rows of up to five columns.
# NOTE(review): indentation is not visible in this view, so it is unclear whether the
# strip is drawn only in single mode or in both modes — confirm against the repository.
thumbnails = get_cached_sample_thumbnails(model_type, sample_folder, sample_files)
n_cols = min(5, len(thumbnails))
cols = st.columns(n_cols)
for i, (fname, sample_img) in enumerate(thumbnails):
# Entries whose image failed to load (None) are not drawn.
if sample_img is not None:
with cols[i % n_cols]:
st.image(sample_img, caption=fname, width=120)
return img, imgs_batch, selected_sample, selected_samples
# Sidebar
# Collects every prediction/display setting used further down the script: model type,
# checkpoint, substrate conditions (single-cell only), batch toggle, auto-boundary toggle,
# force display scale and colormap.
# NOTE(review): indentation is not visible in this view; which widgets sit inside the
# sidebar and inside each keyed container is inferred — confirm against the repository.
with st.sidebar:
st.markdown("""
<div class="sidebar-brand">
<span class="brand-text">Shape2Force</span>
</div>
""", unsafe_allow_html=True)
with st.container(border=False, key="s2f_grp_model"):
model_type = st.radio(
"Model type",
["single_cell", "spheroid"],
format_func=lambda x: MODEL_TYPE_LABELS[x],
horizontal=False,
help="Single cell: substrate-aware force prediction. Spheroid: spheroid force maps.",
)
# Checkpoints offered are the .pth files found in the folder for the chosen model type.
ckp_folder = get_ckp_folder(ckp_base, model_type)
ckp_files = list_files_in_folder(ckp_folder, ".pth")
ckp_subfolder_name = model_subfolder(model_type)
if ckp_files:
# The widget key includes model_type, so each model type keeps its own selection.
checkpoint = st.selectbox(
"Checkpoint",
ckp_files,
key=f"checkpoint_{model_type}",
help=f"Select a .pth file from ckp/{ckp_subfolder_name}/",
)
else:
st.warning(f"No .pth files in ckp/{ckp_subfolder_name}/. Add checkpoints to load.")
checkpoint = None
# Condition defaults. They remain in effect when model_type is not "single_cell", and
# when FileNotFoundError is raised below before the widgets assign new values.
substrate_config = None
substrate_val = DEFAULT_SUBSTRATE
use_manual = True
if model_type == "single_cell":
try:
with st.container(border=False, key="s2f_grp_conditions"):
st.markdown('<p class="s2f-form-label s2f-form-label--section">Conditions</p>', unsafe_allow_html=True)
conditions_source = st.radio(
"Conditions",
["From config", "Manually"],
horizontal=True,
label_visibility="collapsed",
)
from_config = conditions_source == "From config"
if from_config:
# Preset path: the substrate is chosen by name; no manual config dict is used.
substrate_config = None
substrates = list_substrates()
substrate_val = st.selectbox(
"Conditions (from config)",
substrates,
help="Select a preset from config/substrate_settings.json",
label_visibility="collapsed",
)
use_manual = False
else:
# Manual path: the two numbers are packed into the dict passed to the predictor
# as substrate_config (keys "pixelsize" and "young").
manual_pixelsize = st.number_input("Pixel size (µm/px)", min_value=0.1, max_value=50.0,
value=3.0769, step=0.1, format="%.4f")
manual_young = st.number_input("Pascals", min_value=100.0, max_value=100000.0,
value=6000.0, step=100.0, format="%.0f")
substrate_config = {"pixelsize": manual_pixelsize, "young": manual_young}
use_manual = True
except FileNotFoundError:
# Presumably raised by list_substrates() when the settings file is missing — verify.
st.error("config/substrate_settings.json not found")
batch_mode = st.toggle(
"Batch mode",
value=False,
help=f"Process up to {BATCH_MAX_IMAGES} images at once. Upload multiple files or select multiple examples.",
)
auto_cell_boundary = st.toggle(
"Auto boundary",
value=False,
help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.",
)
# Force display scale. "Default" uses the full 0–1 range with clamp_only=True; "Range"
# uses a user-chosen [clip_min, clip_max] with clamp_only=False.
with st.container(border=False, key="s2f_grp_force"):
force_scale_mode = st.radio(
"Force scale",
["Default", "Range"],
horizontal=True,
key="s2f_force_scale",
help="Default: display forces on the full 0–1 scale. Range: set a sub-range; values outside are zeroed and the rest is stretched to the colormap.",
)
if force_scale_mode == "Default":
clip_min, clip_max = 0.0, 1.0
display_mode = "Default"
clamp_only = True
else:
mn_col, mx_col = st.columns(2)
with mn_col:
clip_min = st.number_input(
"Min",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.01,
format="%.2f",
key="s2f_clip_min",
help="Lower bound of the display range (0–1).",
)
with mx_col:
clip_max = st.number_input(
"Max",
min_value=0.0,
max_value=1.0,
value=1.0,
step=0.01,
format="%.2f",
key="s2f_clip_max",
help="Upper bound of the display range (0–1).",
)
# An empty or inverted range is replaced by the full range (with a warning) instead of
# being passed on to the display scaling.
if clip_min >= clip_max:
st.warning("Min must be less than max. Using 0.00–1.00 for display.")
clip_min, clip_max = 0.0, 1.0
display_mode = "Range"
clamp_only = False
# Colormap picker: a custom HTML label in the narrow column, and the selectbox (its own
# label collapsed) in the wide column.
cm_col_lbl, cm_col_sb = st.columns([1, 2])
with cm_col_lbl:
st.markdown('<p class="s2f-form-label s2f-form-label--colormap">Colormap</p>', unsafe_allow_html=True)
with cm_col_sb:
colormap_name = st.selectbox(
"Colormap",
list(COLORMAPS.keys()),
key="s2f_colormap",
label_visibility="collapsed",
help="Color scheme for the force map. Viridis is often preferred for accessibility.",
)
def _decode_uploaded_image(uploaded_file):
    """Decode an uploaded file into a grayscale image array.

    Returns None when the upload is empty or cannot be decoded. The bytes are
    read with getvalue(), which does not move the file position, so no seek
    back to the start is needed afterwards.
    """
    raw = np.frombuffer(uploaded_file.getvalue(), np.uint8)
    # Guard the zero-byte case: cv2.imdecode is expected to reject an empty buffer
    # with an error rather than return None.
    if raw.size == 0:
        return None
    return cv2.imdecode(raw, cv2.IMREAD_GRAYSCALE)


# Main area: image input
img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed", key="s2f_img_source")
img = None
imgs_batch = []  # list of (img, key_img) for batch mode
uploaded = None
uploaded_list = []
selected_sample = None
selected_samples = []
if batch_mode:
    # Batch mode: multiple images (max BATCH_MAX_IMAGES)
    if img_source == "Upload":
        uploaded_list = st.file_uploader(
            "Upload bright-field images",
            type=["tif", "tiff", "png", "jpg", "jpeg"],
            accept_multiple_files=True,
            help=f"Select up to {BATCH_MAX_IMAGES} images. Bright-field microscopy (grayscale or RGB).",
        )
        if uploaded_list:
            # The uploader does not enforce the limit, so say so when files are dropped.
            if len(uploaded_list) > BATCH_MAX_IMAGES:
                st.warning(
                    f"Only the first {BATCH_MAX_IMAGES} of {len(uploaded_list)} uploaded images will be processed."
                )
            uploaded_list = uploaded_list[:BATCH_MAX_IMAGES]
            undecodable = []
            for u in uploaded_list:
                decoded = _decode_uploaded_image(u)
                if decoded is None:
                    undecodable.append(u.name)
                else:
                    imgs_batch.append((decoded, u.name))
            # Report skipped files instead of dropping them silently.
            if undecodable:
                st.warning("Skipped file(s) that could not be decoded as images: " + ", ".join(undecodable))
    else:
        img, imgs_batch, selected_sample, selected_samples = _render_sample_selector(model_type, batch_mode=True)
else:
    # Single image mode
    if img_source == "Upload":
        uploaded = st.file_uploader(
            "Upload bright-field image",
            type=["tif", "tiff", "png", "jpg", "jpeg"],
            help="Bright-field microscopy image of a cell or spheroid on a substrate (grayscale or RGB).",
        )
        if uploaded:
            img = _decode_uploaded_image(uploaded)
            # Without this message a bad file only surfaces later as "please upload an image".
            if img is None:
                st.error(f"Could not decode '{uploaded.name}' as an image.")
    else:
        img, imgs_batch, selected_sample, selected_samples = _render_sample_selector(model_type, batch_mode=False)
# Spacer, then the run button next to a summary of the active model and checkpoint path.
st.markdown("")
col_btn, col_info = st.columns([1, 3])
run = col_btn.button("Run prediction", type="primary", use_container_width=True)
if checkpoint:
    ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}"
else:
    ckp_path = f"ckp/{ckp_subfolder_name}/"
col_info.markdown(f"""
<div class="run-info">
<span class="run-info-tag">{MODEL_TYPE_LABELS[model_type]}</span>
<code>{ckp_path}</code>
</div>
""", unsafe_allow_html=True)

has_image = img is not None
has_batch = len(imgs_batch) > 0

# Make sure both result slots exist before they are read below.
for _state_key in ("prediction_result", "batch_results"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
if not batch_mode:
    st.session_state["batch_results"] = None  # Clear when switching to single mode

# Single-image keys (for non-batch)
if img_source == "Upload":
    key_img = uploaded.name if uploaded else None
else:
    key_img = selected_sample
_cond_key = _inference_cache_condition_key(model_type, use_manual, substrate_val, substrate_config)
current_key = (model_type, checkpoint, key_img, _cond_key)
cached = st.session_state["prediction_result"]
# A stored single-image result is reusable only when it was produced for exactly these inputs.
has_cached = cached is not None and cached.get("cache_key") == current_key and not batch_mode
just_ran = run and checkpoint and has_image and not batch_mode
just_ran_batch = run and checkpoint and has_batch and batch_mode
@st.cache_resource
def _load_predictor(model_type, checkpoint, ckp_folder):
    """Build an S2FPredictor for the given model type, checkpoint file and checkpoint folder.

    Wrapped in st.cache_resource, so the call is memoised on its arguments.
    The predictor module is imported inside the function, on first call.
    """
    from predictor import S2FPredictor

    predictor_kwargs = {
        "model_type": model_type,
        "checkpoint_path": checkpoint,
        "ckp_folder": ckp_folder,
    }
    return S2FPredictor(**predictor_kwargs)
def _prepare_and_render_cached_result(r, key_img, colormap_name, display_mode, auto_cell_boundary,
                                      clip_min, clip_max, clamp_only,
                                      download_key_suffix="", check_measure_dialog=False,
                                      show_success=False):
    """Render a stored single-image result with the current display settings.

    ``r`` is a result dict with the keys "img", "heatmap", "force" and
    "pixel_sum". The function rescales the heatmap for display, optionally
    estimates the cell mask, refreshes the measure-tool session state, and
    hands everything to render_result_display. With check_measure_dialog set,
    a pending "open_measure_dialog" flag is consumed and the dialog opened.
    With show_success set, a success banner is shown first. Shared by the
    fresh-run and cached-result paths.
    """
    img = r["img"]
    heatmap = r["heatmap"]
    force = r["force"]
    pixel_sum = r["pixel_sum"]
    scale_kwargs = {"clip_min": clip_min, "clip_max": clip_max, "clamp_only": clamp_only}

    display_heatmap = apply_display_scale(heatmap, display_mode, **scale_kwargs)

    cell_mask = None
    if auto_cell_boundary:
        cell_mask = estimate_cell_mask(heatmap)

    _populate_measure_session_state(
        heatmap, img, pixel_sum, force, key_img, colormap_name,
        display_mode, auto_cell_boundary, cell_mask=cell_mask,
        **scale_kwargs,
    )

    # The flag is popped only when the caller asked for the check, so it is consumed once.
    if check_measure_dialog:
        if st.session_state.pop("open_measure_dialog", False):
            measure_region_dialog()

    if show_success:
        st.success("Prediction complete!")

    render_result_display(
        img, heatmap, display_heatmap, pixel_sum, force, key_img,
        download_key_suffix=download_key_suffix,
        colormap_name=colormap_name,
        display_mode=display_mode,
        measure_region_dialog=_get_measure_dialog_fn(),
        auto_cell_boundary=auto_cell_boundary,
        cell_mask=cell_mask,
        **scale_kwargs,
    )
# Decide what to show for this rerun. Branch order matters: a fresh batch run, stored batch
# results, a fresh single run, a cached single result, and finally the "cannot run" warnings.
if just_ran_batch:
    # Drop earlier results first so a failed run never leaves stale output in the session.
    st.session_state["prediction_result"] = None
    st.session_state["batch_results"] = None
    with st.spinner("Loading model and predicting..."):
        progress_bar = None
        try:
            predictor = _load_predictor(model_type, checkpoint, ckp_folder)
            # A preset substrate applies only to single-cell runs that use config conditions.
            sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
            n_images = len(imgs_batch)
            progress_bar = st.progress(0, text=f"Predicting 0 / {n_images} images")
            pred_results = []
            # Predict in chunks of BATCH_INFERENCE_SIZE so the progress bar can advance.
            for start in range(0, n_images, BATCH_INFERENCE_SIZE):
                chunk = imgs_batch[start : start + BATCH_INFERENCE_SIZE]
                chunk_results = predictor.predict_batch(
                    chunk,
                    substrate=sub_val,
                    substrate_config=substrate_config if model_type == "single_cell" else None,
                )
                pred_results.extend(chunk_results)
                progress_bar.progress(min(start + len(chunk), n_images) / n_images,
                                      text=f"Predicting {len(pred_results)} / {n_images} images")
            # Pair each input with its prediction; copies are stored in session state.
            batch_results = [
                {
                    "img": img_b.copy(),
                    "heatmap": heatmap.copy(),
                    "force": force,
                    "pixel_sum": pixel_sum,
                    "key_img": key_b,
                    "cell_mask": estimate_cell_mask(heatmap) if auto_cell_boundary else None,
                }
                for (img_b, key_b), (heatmap, force, pixel_sum) in zip(imgs_batch, pred_results)
            ]
            st.session_state["batch_results"] = batch_results
            progress_bar.empty()
            st.success(f"Prediction complete for {len(batch_results)} image(s)!")
            render_batch_results(
                batch_results,
                colormap_name=colormap_name,
                display_mode=display_mode,
                clip_min=clip_min,
                clip_max=clip_max,
                auto_cell_boundary=auto_cell_boundary,
                clamp_only=clamp_only,
            )
        except Exception as e:
            # Top-level UI boundary: show the failure and traceback on the page.
            if progress_bar is not None:
                progress_bar.empty()
            st.error(f"Prediction failed: {e}")
            st.code(traceback.format_exc())
elif batch_mode and st.session_state.get("batch_results"):
    # Rerun without a new prediction: redraw stored batch results with the current display settings.
    render_batch_results(
        st.session_state["batch_results"],
        colormap_name=colormap_name,
        display_mode=display_mode,
        clip_min=clip_min,
        clip_max=clip_max,
        auto_cell_boundary=auto_cell_boundary,
        clamp_only=clamp_only,
    )
elif just_ran:
    st.session_state["prediction_result"] = None
    with st.spinner("Loading model and predicting..."):
        try:
            predictor = _load_predictor(model_type, checkpoint, ckp_folder)
            sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
            heatmap, force, pixel_sum = predictor.predict(
                image_array=img,
                substrate=sub_val,
                substrate_config=substrate_config if model_type == "single_cell" else None,
            )
            # Stored with the result so later reruns can tell whether it still matches the inputs.
            cache_key = (model_type, checkpoint, key_img, _cond_key)
            r = {
                "img": img.copy(),
                "heatmap": heatmap.copy(),
                "force": force,
                "pixel_sum": pixel_sum,
                "cache_key": cache_key,
            }
            st.session_state["prediction_result"] = r
            _prepare_and_render_cached_result(
                r, key_img, colormap_name, display_mode, auto_cell_boundary,
                clip_min, clip_max, clamp_only,
                download_key_suffix="", check_measure_dialog=False,
                show_success=True,
            )
        except Exception as e:
            # Top-level UI boundary: show the failure and traceback on the page.
            st.error(f"Prediction failed: {e}")
            st.code(traceback.format_exc())
elif has_cached:
    r = st.session_state["prediction_result"]
    _prepare_and_render_cached_result(
        r, key_img, colormap_name, display_mode, auto_cell_boundary,
        clip_min, clip_max, clamp_only,
        download_key_suffix="_cached", check_measure_dialog=True,
        show_success=False,
    )
elif run and not checkpoint:
    st.warning("Please add checkpoint files to the ckp/ folder and select one.")
elif run and batch_mode and not has_batch:
    # Must come before the generic check below: in batch mode no single image is ever set,
    # so the generic condition would always match first and this message would never show.
    st.warning(f"Please upload or select 1–{BATCH_MAX_IMAGES} images for batch processing.")
elif run and not has_image and not has_batch:
    st.warning("Please upload an image or select an example.")
# Footer: citation request, rendered as raw HTML because CITATION contains <b> markup.
_footer_html = f"""
<div class="footer-citation">
<span>If you find this software useful, please cite: {CITATION}</span>
</div>
"""
st.markdown(_footer_html, unsafe_allow_html=True)
|