kaveh commited on
Commit
699f6c2
·
1 Parent(s): 750ec75

cleaned code

Browse files
Files changed (5) hide show
  1. app.py +27 -32
  2. predictor.py +3 -6
  3. ui/components.py +2 -15
  4. utils/paths.py +17 -0
  5. utils/report.py +9 -7
app.py CHANGED
@@ -20,6 +20,7 @@ from config.constants import (
20
  MODEL_TYPE_LABELS,
21
  SAMPLE_EXTENSIONS,
22
  )
 
23
  from utils.segmentation import estimate_cell_mask
24
  from utils.substrate_settings import list_substrates
25
  from utils.display import apply_display_scale
@@ -32,11 +33,6 @@ from ui.components import (
32
  HAS_DRAWABLE_CANVAS,
33
  )
34
 
35
- try:
36
- from streamlit_drawable_canvas import st_canvas
37
- except (ImportError, AttributeError):
38
- pass # HAS_DRAWABLE_CANVAS from ui.components
39
-
40
  CITATION = (
41
  "Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
42
  "**\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"**, 2026."
@@ -73,6 +69,21 @@ def _get_measure_dialog_fn():
73
  return measure_region_dialog if (HAS_DRAWABLE_CANVAS and ST_DIALOG) else None
74
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="centered")
77
 
78
  # Theme CSS (inject based on sidebar selection)
@@ -123,11 +134,7 @@ st.title("🦠 Shape2Force (S2F)")
123
  st.caption("Predict traction force maps from bright-field microscopy images of cells or spheroids")
124
 
125
  # Folders
126
- ckp_base = os.path.join(S2F_ROOT, "ckp")
127
- if not os.path.isdir(ckp_base):
128
- project_root = os.path.dirname(S2F_ROOT)
129
- if os.path.isdir(os.path.join(project_root, "ckp")):
130
- ckp_base = os.path.join(project_root, "ckp")
131
  ckp_single_cell = os.path.join(ckp_base, "single_cell")
132
  ckp_spheroid = os.path.join(ckp_base, "spheroid")
133
  sample_base = os.path.join(S2F_ROOT, "samples")
@@ -179,7 +186,7 @@ with st.sidebar:
179
 
180
  ckp_files = get_ckp_files_for_model(model_type)
181
  ckp_folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid
182
- ckp_subfolder_name = "single_cell" if model_type == "single_cell" else "spheroid"
183
 
184
  if ckp_files:
185
  checkpoint = st.selectbox(
@@ -265,7 +272,7 @@ if img_source == "Upload":
265
  else:
266
  sample_files = get_sample_files_for_model(model_type)
267
  sample_folder = sample_single_cell if model_type == "single_cell" else sample_spheroid
268
- sample_subfolder_name = "single_cell" if model_type == "single_cell" else "spheroid"
269
  if sample_files:
270
  selected_sample = st.selectbox(
271
  f"Select example image (from `samples/{sample_subfolder_name}/`)",
@@ -345,16 +352,10 @@ if just_ran:
345
  "pixel_sum": pixel_sum,
346
  "cache_key": cache_key,
347
  }
348
- st.session_state["measure_raw_heatmap"] = heatmap.copy()
349
- st.session_state["measure_display_mode"] = display_mode
350
- st.session_state["measure_bf_img"] = img.copy()
351
- st.session_state["measure_input_filename"] = key_img or "image"
352
- st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
353
- st.session_state["measure_colormap"] = colormap_name
354
- cell_mask = estimate_cell_mask(heatmap)
355
- st.session_state["measure_auto_cell_on"] = auto_cell_boundary
356
- st.session_state["measure_cell_vals"] = build_cell_vals(heatmap, cell_mask, pixel_sum, force) if auto_cell_boundary else None
357
- st.session_state["measure_cell_mask"] = cell_mask if auto_cell_boundary else None
358
 
359
  render_result_display(
360
  img, heatmap, display_heatmap, pixel_sum, force, key_img,
@@ -373,16 +374,10 @@ elif has_cached:
373
  img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
374
  display_heatmap = apply_display_scale(heatmap, display_mode)
375
 
376
- st.session_state["measure_raw_heatmap"] = heatmap.copy()
377
- st.session_state["measure_display_mode"] = display_mode
378
- st.session_state["measure_bf_img"] = img.copy()
379
- st.session_state["measure_input_filename"] = key_img or "image"
380
- st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
381
- st.session_state["measure_colormap"] = colormap_name
382
- cell_mask = estimate_cell_mask(heatmap)
383
- st.session_state["measure_auto_cell_on"] = auto_cell_boundary
384
- st.session_state["measure_cell_vals"] = build_cell_vals(heatmap, cell_mask, pixel_sum, force) if auto_cell_boundary else None
385
- st.session_state["measure_cell_mask"] = cell_mask if auto_cell_boundary else None
386
 
387
  if st.session_state.pop("open_measure_dialog", False):
388
  measure_region_dialog()
 
20
  MODEL_TYPE_LABELS,
21
  SAMPLE_EXTENSIONS,
22
  )
23
+ from utils.paths import get_ckp_base, model_subfolder
24
  from utils.segmentation import estimate_cell_mask
25
  from utils.substrate_settings import list_substrates
26
  from utils.display import apply_display_scale
 
33
  HAS_DRAWABLE_CANVAS,
34
  )
35
 
 
 
 
 
 
36
  CITATION = (
37
  "Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
38
  "**\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"**, 2026."
 
69
  return measure_region_dialog if (HAS_DRAWABLE_CANVAS and ST_DIALOG) else None
70
 
71
 
72
+ def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name,
73
+ display_mode, auto_cell_boundary):
74
+ """Populate session state for the measure tool."""
75
+ cell_mask = estimate_cell_mask(heatmap)
76
+ st.session_state["measure_raw_heatmap"] = heatmap.copy()
77
+ st.session_state["measure_display_mode"] = display_mode
78
+ st.session_state["measure_bf_img"] = img.copy()
79
+ st.session_state["measure_input_filename"] = key_img or "image"
80
+ st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
81
+ st.session_state["measure_colormap"] = colormap_name
82
+ st.session_state["measure_auto_cell_on"] = auto_cell_boundary
83
+ st.session_state["measure_cell_vals"] = build_cell_vals(heatmap, cell_mask, pixel_sum, force) if auto_cell_boundary else None
84
+ st.session_state["measure_cell_mask"] = cell_mask if auto_cell_boundary else None
85
+
86
+
87
  st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="centered")
88
 
89
  # Theme CSS (inject based on sidebar selection)
 
134
  st.caption("Predict traction force maps from bright-field microscopy images of cells or spheroids")
135
 
136
  # Folders
137
+ ckp_base = get_ckp_base(S2F_ROOT)
 
 
 
 
138
  ckp_single_cell = os.path.join(ckp_base, "single_cell")
139
  ckp_spheroid = os.path.join(ckp_base, "spheroid")
140
  sample_base = os.path.join(S2F_ROOT, "samples")
 
186
 
187
  ckp_files = get_ckp_files_for_model(model_type)
188
  ckp_folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid
189
+ ckp_subfolder_name = model_subfolder(model_type)
190
 
191
  if ckp_files:
192
  checkpoint = st.selectbox(
 
272
  else:
273
  sample_files = get_sample_files_for_model(model_type)
274
  sample_folder = sample_single_cell if model_type == "single_cell" else sample_spheroid
275
+ sample_subfolder_name = model_subfolder(model_type)
276
  if sample_files:
277
  selected_sample = st.selectbox(
278
  f"Select example image (from `samples/{sample_subfolder_name}/`)",
 
352
  "pixel_sum": pixel_sum,
353
  "cache_key": cache_key,
354
  }
355
+ _populate_measure_session_state(
356
+ heatmap, img, pixel_sum, force, key_img, colormap_name,
357
+ display_mode, auto_cell_boundary,
358
+ )
 
 
 
 
 
 
359
 
360
  render_result_display(
361
  img, heatmap, display_heatmap, pixel_sum, force, key_img,
 
374
  img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
375
  display_heatmap = apply_display_scale(heatmap, display_mode)
376
 
377
+ _populate_measure_session_state(
378
+ heatmap, img, pixel_sum, force, key_img, colormap_name,
379
+ display_mode, auto_cell_boundary,
380
+ )
 
 
 
 
 
 
381
 
382
  if st.session_state.pop("open_measure_dialog", False):
383
  measure_region_dialog()
predictor.py CHANGED
@@ -14,6 +14,7 @@ if S2F_ROOT not in sys.path:
14
  sys.path.insert(0, S2F_ROOT)
15
 
16
  from models.s2f_model import create_s2f_model
 
17
  from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
18
  from utils import config
19
 
@@ -89,12 +90,8 @@ class S2FPredictor:
89
  """
90
  self.model_type = model_type
91
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
92
- ckp_base = os.path.join(S2F_ROOT, "ckp")
93
- if not os.path.isdir(ckp_base):
94
- project_root = os.path.dirname(S2F_ROOT)
95
- if os.path.isdir(os.path.join(project_root, "ckp")):
96
- ckp_base = os.path.join(project_root, "ckp")
97
- subfolder = "single_cell" if model_type == "single_cell" else "spheroid"
98
  ckp_dir = ckp_folder if ckp_folder else os.path.join(ckp_base, subfolder)
99
  if not os.path.isdir(ckp_dir):
100
  ckp_dir = ckp_base # fallback if subfolders not used
 
14
  sys.path.insert(0, S2F_ROOT)
15
 
16
  from models.s2f_model import create_s2f_model
17
+ from utils.paths import get_ckp_base, model_subfolder
18
  from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
19
  from utils import config
20
 
 
90
  """
91
  self.model_type = model_type
92
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
93
+ ckp_base = get_ckp_base(S2F_ROOT)
94
+ subfolder = model_subfolder(model_type)
 
 
 
 
95
  ckp_dir = ckp_folder if ckp_folder else os.path.join(ckp_base, subfolder)
96
  if not os.path.isdir(ckp_dir):
97
  ckp_dir = ckp_base # fallback if subfolders not used
ui/components.py CHANGED
@@ -17,7 +17,7 @@ from config.constants import (
17
  TOOL_LABELS,
18
  )
19
  from utils.display import apply_display_scale, cv_colormap_to_plotly_colorscale
20
- from utils.report import heatmap_to_rgb, heatmap_to_png_bytes, create_pdf_report
21
  from utils.segmentation import estimate_cell_mask
22
 
23
  try:
@@ -102,17 +102,6 @@ def _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h):
102
  return pts
103
 
104
 
105
- def parse_canvas_shapes_to_mask(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
106
- """Parse drawn shapes from streamlit-drawable-canvas json_data and create binary mask (combined)."""
107
- masks, _ = parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w)
108
- if not masks:
109
- return None, 0
110
- combined = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
111
- for m in masks:
112
- combined = np.maximum(combined, m)
113
- return combined, len(masks)
114
-
115
-
116
  def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
117
  """Parse drawn shapes and return a list of individual masks (one per shape)."""
118
  if not json_data or "objects" not in json_data or not json_data["objects"]:
@@ -237,9 +226,7 @@ def render_region_canvas(display_heatmap, raw_heatmap=None, bf_img=None, origina
237
  """Render drawable canvas and region metrics. When cell_vals: show cell area (replaces Full map). Else: show Full map."""
238
  raw_heatmap = raw_heatmap if raw_heatmap is not None else display_heatmap
239
  h, w = display_heatmap.shape
240
- heatmap_rgb = heatmap_to_rgb(display_heatmap, colormap_name)
241
- if cell_mask is not None and np.any(cell_mask > 0):
242
- heatmap_rgb = _draw_contour_on_image(heatmap_rgb.copy(), cell_mask, stroke_color=(255, 0, 0), stroke_width=2)
243
  pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)
244
 
245
  st.markdown("""
 
17
  TOOL_LABELS,
18
  )
19
  from utils.display import apply_display_scale, cv_colormap_to_plotly_colorscale
20
+ from utils.report import heatmap_to_rgb, heatmap_to_rgb_with_contour, heatmap_to_png_bytes, create_pdf_report
21
  from utils.segmentation import estimate_cell_mask
22
 
23
  try:
 
102
  return pts
103
 
104
 
 
 
 
 
 
 
 
 
 
 
 
105
  def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
106
  """Parse drawn shapes and return a list of individual masks (one per shape)."""
107
  if not json_data or "objects" not in json_data or not json_data["objects"]:
 
226
  """Render drawable canvas and region metrics. When cell_vals: show cell area (replaces Full map). Else: show Full map."""
227
  raw_heatmap = raw_heatmap if raw_heatmap is not None else display_heatmap
228
  h, w = display_heatmap.shape
229
+ heatmap_rgb = heatmap_to_rgb_with_contour(display_heatmap, colormap_name, cell_mask)
 
 
230
  pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)
231
 
232
  st.markdown("""
utils/paths.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Path resolution utilities for S2F App."""
2
+ import os
3
+
4
+
5
+ def get_ckp_base(root):
6
+ """Resolve checkpoint base directory (S2FApp/ckp or project/ckp)."""
7
+ ckp_base = os.path.join(root, "ckp")
8
+ if not os.path.isdir(ckp_base):
9
+ project_root = os.path.dirname(root)
10
+ if os.path.isdir(os.path.join(project_root, "ckp")):
11
+ ckp_base = os.path.join(project_root, "ckp")
12
+ return ckp_base
13
+
14
+
15
+ def model_subfolder(model_type):
16
+ """Return subfolder name for model type: 'single_cell' or 'spheroid'."""
17
+ return "single_cell" if model_type == "single_cell" else "spheroid"
utils/report.py CHANGED
@@ -21,13 +21,19 @@ def heatmap_to_rgb(scaled_heatmap, colormap_name="Jet"):
21
  return heatmap_rgb
22
 
23
 
24
- def heatmap_to_png_bytes(scaled_heatmap, colormap_name="Jet", cell_mask=None):
25
- """Convert scaled heatmap (float 0-1) to PNG bytes buffer. Optionally draw red cell contour."""
26
  heatmap_rgb = heatmap_to_rgb(scaled_heatmap, colormap_name)
27
  if cell_mask is not None and np.any(cell_mask > 0):
28
  contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
29
  if contours:
30
  cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 2)
 
 
 
 
 
 
31
  buf = io.BytesIO()
32
  Image.fromarray(heatmap_rgb).save(buf, format="PNG")
33
  buf.seek(0)
@@ -62,11 +68,7 @@ def create_pdf_report(img, display_heatmap, raw_heatmap, pixel_sum, force, base_
62
  c.setFont("Helvetica", 9)
63
  c.drawString(72, img_top - img_h - 12, "Input: Bright-field")
64
 
65
- heatmap_rgb = heatmap_to_rgb(display_heatmap, colormap_name)
66
- if cell_mask is not None and np.any(cell_mask > 0):
67
- contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
68
- if contours:
69
- cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 2)
70
  hm_buf = io.BytesIO()
71
  Image.fromarray(heatmap_rgb).save(hm_buf, format="PNG")
72
  hm_buf.seek(0)
 
21
  return heatmap_rgb
22
 
23
 
24
+ def heatmap_to_rgb_with_contour(scaled_heatmap, colormap_name="Jet", cell_mask=None):
25
+ """Convert heatmap to RGB, optionally drawing red cell contour. Mask must match heatmap shape."""
26
  heatmap_rgb = heatmap_to_rgb(scaled_heatmap, colormap_name)
27
  if cell_mask is not None and np.any(cell_mask > 0):
28
  contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
29
  if contours:
30
  cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 2)
31
+ return heatmap_rgb
32
+
33
+
34
+ def heatmap_to_png_bytes(scaled_heatmap, colormap_name="Jet", cell_mask=None):
35
+ """Convert scaled heatmap (float 0-1) to PNG bytes buffer. Optionally draw red cell contour."""
36
+ heatmap_rgb = heatmap_to_rgb_with_contour(scaled_heatmap, colormap_name, cell_mask)
37
  buf = io.BytesIO()
38
  Image.fromarray(heatmap_rgb).save(buf, format="PNG")
39
  buf.seek(0)
 
68
  c.setFont("Helvetica", 9)
69
  c.drawString(72, img_top - img_h - 12, "Input: Bright-field")
70
 
71
+ heatmap_rgb = heatmap_to_rgb_with_contour(display_heatmap, colormap_name, cell_mask)
 
 
 
 
72
  hm_buf = io.BytesIO()
73
  Image.fromarray(heatmap_rgb).save(hm_buf, format="PNG")
74
  hm_buf.seek(0)