kaveh commited on
Commit
607b168
·
1 Parent(s): b16c4e9

optimised

Browse files
S2FApp/app.py CHANGED
@@ -17,10 +17,12 @@ if S2F_ROOT not in sys.path:
17
 
18
  from config.constants import (
19
  COLORMAPS,
 
20
  MODEL_TYPE_LABELS,
21
  SAMPLE_EXTENSIONS,
 
22
  )
23
- from utils.paths import get_ckp_base, model_subfolder
24
  from utils.segmentation import estimate_cell_mask
25
  from utils.substrate_settings import list_substrates
26
  from utils.display import apply_display_scale
@@ -70,9 +72,10 @@ def _get_measure_dialog_fn():
70
 
71
 
72
  def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name,
73
- display_mode, auto_cell_boundary):
74
- """Populate session state for the measure tool."""
75
- cell_mask = estimate_cell_mask(heatmap)
 
76
  st.session_state["measure_raw_heatmap"] = heatmap.copy()
77
  st.session_state["measure_display_mode"] = display_mode
78
  st.session_state["measure_bf_img"] = img.copy()
@@ -135,25 +138,6 @@ st.caption("Predict traction force maps from bright-field microscopy images of c
135
 
136
  # Folders
137
  ckp_base = get_ckp_base(S2F_ROOT)
138
- ckp_single_cell = os.path.join(ckp_base, "single_cell")
139
- ckp_spheroid = os.path.join(ckp_base, "spheroid")
140
- sample_base = os.path.join(S2F_ROOT, "samples")
141
- sample_single_cell = os.path.join(sample_base, "single_cell")
142
- sample_spheroid = os.path.join(sample_base, "spheroid")
143
-
144
-
145
- def get_ckp_files_for_model(model_type):
146
- folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid
147
- if os.path.isdir(folder):
148
- return sorted(f for f in os.listdir(folder) if f.endswith(".pth"))
149
- return []
150
-
151
-
152
- def get_sample_files_for_model(model_type):
153
- folder = sample_single_cell if model_type == "single_cell" else sample_spheroid
154
- if os.path.isdir(folder):
155
- return sorted(f for f in os.listdir(folder) if f.lower().endswith(SAMPLE_EXTENSIONS))
156
- return []
157
 
158
 
159
  def get_cached_sample_thumbnails(model_type, sample_folder, sample_files):
@@ -164,7 +148,7 @@ def get_cached_sample_thumbnails(model_type, sample_folder, sample_files):
164
  cache = st.session_state["sample_thumbnails"]
165
  if cache_key not in cache:
166
  thumbnails = []
167
- for fname in sample_files[:8]:
168
  path = os.path.join(sample_folder, fname)
169
  img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
170
  thumbnails.append((fname, img))
@@ -184,8 +168,8 @@ with st.sidebar:
184
  help="Single cell: substrate-aware force prediction. Spheroid: spheroid force maps.",
185
  )
186
 
187
- ckp_files = get_ckp_files_for_model(model_type)
188
- ckp_folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid
189
  ckp_subfolder_name = model_subfolder(model_type)
190
 
191
  if ckp_files:
@@ -199,7 +183,7 @@ with st.sidebar:
199
  checkpoint = None
200
 
201
  substrate_config = None
202
- substrate_val = "Fibroblasts_Fibronectin_6KPa"
203
  use_manual = True
204
  if model_type == "single_cell":
205
  try:
@@ -270,8 +254,8 @@ if img_source == "Upload":
270
  img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
271
  uploaded.seek(0)
272
  else:
273
- sample_files = get_sample_files_for_model(model_type)
274
- sample_folder = sample_single_cell if model_type == "single_cell" else sample_spheroid
275
  sample_subfolder_name = model_subfolder(model_type)
276
  if sample_files:
277
  selected_sample = st.selectbox(
@@ -333,7 +317,7 @@ if just_ran:
333
  with st.spinner("Loading model and predicting..."):
334
  try:
335
  predictor = get_or_create_predictor(model_type, checkpoint, ckp_folder)
336
- sub_val = substrate_val if model_type == "single_cell" and not use_manual else "Fibroblasts_Fibronectin_6KPa"
337
  heatmap, force, pixel_sum = predictor.predict(
338
  image_array=img,
339
  substrate=sub_val,
@@ -352,17 +336,18 @@ if just_ran:
352
  "pixel_sum": pixel_sum,
353
  "cache_key": cache_key,
354
  }
 
355
  _populate_measure_session_state(
356
  heatmap, img, pixel_sum, force, key_img, colormap_name,
357
- display_mode, auto_cell_boundary,
358
  )
359
-
360
  render_result_display(
361
  img, heatmap, display_heatmap, pixel_sum, force, key_img,
362
  colormap_name=colormap_name,
363
  display_mode=display_mode,
364
  measure_region_dialog=_get_measure_dialog_fn(),
365
  auto_cell_boundary=auto_cell_boundary,
 
366
  )
367
 
368
  except Exception as e:
@@ -373,10 +358,10 @@ elif has_cached:
373
  r = st.session_state["prediction_result"]
374
  img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
375
  display_heatmap = apply_display_scale(heatmap, display_mode)
376
-
377
  _populate_measure_session_state(
378
  heatmap, img, pixel_sum, force, key_img, colormap_name,
379
- display_mode, auto_cell_boundary,
380
  )
381
 
382
  if st.session_state.pop("open_measure_dialog", False):
@@ -390,6 +375,7 @@ elif has_cached:
390
  display_mode=display_mode,
391
  measure_region_dialog=_get_measure_dialog_fn(),
392
  auto_cell_boundary=auto_cell_boundary,
 
393
  )
394
 
395
  elif run and not checkpoint:
 
17
 
18
  from config.constants import (
19
  COLORMAPS,
20
+ DEFAULT_SUBSTRATE,
21
  MODEL_TYPE_LABELS,
22
  SAMPLE_EXTENSIONS,
23
+ SAMPLE_THUMBNAIL_LIMIT,
24
  )
25
+ from utils.paths import get_ckp_base, get_ckp_folder, get_sample_folder, list_files_in_folder, model_subfolder
26
  from utils.segmentation import estimate_cell_mask
27
  from utils.substrate_settings import list_substrates
28
  from utils.display import apply_display_scale
 
72
 
73
 
74
  def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name,
75
+ display_mode, auto_cell_boundary, cell_mask=None):
76
+ """Populate session state for the measure tool. If cell_mask is None and auto_cell_boundary, computes it."""
77
+ if cell_mask is None and auto_cell_boundary:
78
+ cell_mask = estimate_cell_mask(heatmap)
79
  st.session_state["measure_raw_heatmap"] = heatmap.copy()
80
  st.session_state["measure_display_mode"] = display_mode
81
  st.session_state["measure_bf_img"] = img.copy()
 
138
 
139
  # Folders
140
  ckp_base = get_ckp_base(S2F_ROOT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
 
143
  def get_cached_sample_thumbnails(model_type, sample_folder, sample_files):
 
148
  cache = st.session_state["sample_thumbnails"]
149
  if cache_key not in cache:
150
  thumbnails = []
151
+ for fname in sample_files[:SAMPLE_THUMBNAIL_LIMIT]:
152
  path = os.path.join(sample_folder, fname)
153
  img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
154
  thumbnails.append((fname, img))
 
168
  help="Single cell: substrate-aware force prediction. Spheroid: spheroid force maps.",
169
  )
170
 
171
+ ckp_folder = get_ckp_folder(ckp_base, model_type)
172
+ ckp_files = list_files_in_folder(ckp_folder, ".pth")
173
  ckp_subfolder_name = model_subfolder(model_type)
174
 
175
  if ckp_files:
 
183
  checkpoint = None
184
 
185
  substrate_config = None
186
+ substrate_val = DEFAULT_SUBSTRATE
187
  use_manual = True
188
  if model_type == "single_cell":
189
  try:
 
254
  img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
255
  uploaded.seek(0)
256
  else:
257
+ sample_folder = get_sample_folder(S2F_ROOT, model_type)
258
+ sample_files = list_files_in_folder(sample_folder, SAMPLE_EXTENSIONS)
259
  sample_subfolder_name = model_subfolder(model_type)
260
  if sample_files:
261
  selected_sample = st.selectbox(
 
317
  with st.spinner("Loading model and predicting..."):
318
  try:
319
  predictor = get_or_create_predictor(model_type, checkpoint, ckp_folder)
320
+ sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
321
  heatmap, force, pixel_sum = predictor.predict(
322
  image_array=img,
323
  substrate=sub_val,
 
336
  "pixel_sum": pixel_sum,
337
  "cache_key": cache_key,
338
  }
339
+ cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
340
  _populate_measure_session_state(
341
  heatmap, img, pixel_sum, force, key_img, colormap_name,
342
+ display_mode, auto_cell_boundary, cell_mask=cell_mask,
343
  )
 
344
  render_result_display(
345
  img, heatmap, display_heatmap, pixel_sum, force, key_img,
346
  colormap_name=colormap_name,
347
  display_mode=display_mode,
348
  measure_region_dialog=_get_measure_dialog_fn(),
349
  auto_cell_boundary=auto_cell_boundary,
350
+ cell_mask=cell_mask,
351
  )
352
 
353
  except Exception as e:
 
358
  r = st.session_state["prediction_result"]
359
  img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
360
  display_heatmap = apply_display_scale(heatmap, display_mode)
361
+ cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
362
  _populate_measure_session_state(
363
  heatmap, img, pixel_sum, force, key_img, colormap_name,
364
+ display_mode, auto_cell_boundary, cell_mask=cell_mask,
365
  )
366
 
367
  if st.session_state.pop("open_measure_dialog", False):
 
375
  display_mode=display_mode,
376
  measure_region_dialog=_get_measure_dialog_fn(),
377
  auto_cell_boundary=auto_cell_boundary,
378
+ cell_mask=cell_mask,
379
  )
380
 
381
  elif run and not checkpoint:
S2FApp/config/constants.py CHANGED
@@ -6,8 +6,12 @@ import cv2
6
  # Model & paths
7
  MODEL_INPUT_SIZE = 1024
8
 
 
 
 
9
  # UI
10
  CANVAS_SIZE = 320
 
11
  COLORMAP_N_SAMPLES = 64
12
 
13
  # Model type labels
 
6
  # Model & paths
7
  MODEL_INPUT_SIZE = 1024
8
 
9
+ # Default substrate (used when config lookup fails or manual mode fallback)
10
+ DEFAULT_SUBSTRATE = "Fibroblasts_Fibronectin_6KPa"
11
+
12
  # UI
13
  CANVAS_SIZE = 320
14
+ SAMPLE_THUMBNAIL_LIMIT = 8
15
  COLORMAP_N_SAMPLES = 64
16
 
17
  # Model type labels
S2FApp/predictor.py CHANGED
@@ -13,20 +13,22 @@ S2F_ROOT = os.path.dirname(os.path.abspath(__file__))
13
  if S2F_ROOT not in sys.path:
14
  sys.path.insert(0, S2F_ROOT)
15
 
 
16
  from models.s2f_model import create_s2f_model
17
  from utils.paths import get_ckp_base, model_subfolder
18
  from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
19
  from utils import config
20
 
21
 
22
- def load_image(filepath, target_size=1024):
23
  """Load and preprocess a bright field image."""
24
  img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
25
  if img is None:
26
  raise ValueError(f"Could not load image: {filepath}")
27
- if isinstance(target_size, int):
28
- target_size = (target_size, target_size)
29
- img = cv2.resize(img, target_size)
 
30
  img = img.astype(np.float32) / 255.0
31
  return img
32
 
@@ -126,7 +128,7 @@ class S2FPredictor:
126
  self._use_tanh_output = model_type == "single_cell" # single_cell uses tanh, spheroid uses sigmoid
127
  self.config_path = os.path.join(S2F_ROOT, "config", "substrate_settings.json")
128
 
129
- def predict(self, image_path=None, image_array=None, substrate="Fibroblasts_Fibronectin_6KPa",
130
  substrate_config=None):
131
  """
132
  Run prediction on an image.
@@ -138,7 +140,7 @@ class S2FPredictor:
138
  substrate_config: Optional dict with 'pixelsize' and 'young'. Overrides substrate lookup.
139
 
140
  Returns:
141
- heatmap: numpy array (1024, 1024) in [0, 1]
142
  force: scalar cell force (sum of heatmap * SCALE_FACTOR_FORCE)
143
  pixel_sum: raw sum of all pixel values in heatmap
144
  """
@@ -150,15 +152,16 @@ class S2FPredictor:
150
  img = img[:, :, 0] if img.shape[-1] >= 1 else img
151
  if img.max() > 1.0:
152
  img = img / 255.0
153
- img = cv2.resize(img, (1024, 1024))
154
  else:
155
  raise ValueError("Provide image_path or image_array")
156
 
157
  x = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0).to(self.device) # [1,1,H,W]
158
 
159
  if self.model_type == "single_cell" and self.norm_params is not None:
 
160
  settings_ch = create_settings_channels_single(
161
- substrate, self.device, x.shape[2], x.shape[3],
162
  config_path=self.config_path, substrate_config=substrate_config
163
  )
164
  x = torch.cat([x, settings_ch], dim=1) # [1,3,H,W]
 
13
  if S2F_ROOT not in sys.path:
14
  sys.path.insert(0, S2F_ROOT)
15
 
16
+ from config.constants import DEFAULT_SUBSTRATE, MODEL_INPUT_SIZE
17
  from models.s2f_model import create_s2f_model
18
  from utils.paths import get_ckp_base, model_subfolder
19
  from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
20
  from utils import config
21
 
22
 
23
+ def load_image(filepath, target_size=None):
24
  """Load and preprocess a bright field image."""
25
  img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
26
  if img is None:
27
  raise ValueError(f"Could not load image: {filepath}")
28
+ size = target_size if target_size is not None else MODEL_INPUT_SIZE
29
+ if isinstance(size, int):
30
+ size = (size, size)
31
+ img = cv2.resize(img, size)
32
  img = img.astype(np.float32) / 255.0
33
  return img
34
 
 
128
  self._use_tanh_output = model_type == "single_cell" # single_cell uses tanh, spheroid uses sigmoid
129
  self.config_path = os.path.join(S2F_ROOT, "config", "substrate_settings.json")
130
 
131
+ def predict(self, image_path=None, image_array=None, substrate=None,
132
  substrate_config=None):
133
  """
134
  Run prediction on an image.
 
140
  substrate_config: Optional dict with 'pixelsize' and 'young'. Overrides substrate lookup.
141
 
142
  Returns:
143
+ heatmap: numpy array (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE) in [0, 1]
144
  force: scalar cell force (sum of heatmap * SCALE_FACTOR_FORCE)
145
  pixel_sum: raw sum of all pixel values in heatmap
146
  """
 
152
  img = img[:, :, 0] if img.shape[-1] >= 1 else img
153
  if img.max() > 1.0:
154
  img = img / 255.0
155
+ img = cv2.resize(img, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
156
  else:
157
  raise ValueError("Provide image_path or image_array")
158
 
159
  x = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0).to(self.device) # [1,1,H,W]
160
 
161
  if self.model_type == "single_cell" and self.norm_params is not None:
162
+ sub = substrate if substrate is not None else DEFAULT_SUBSTRATE
163
  settings_ch = create_settings_channels_single(
164
+ sub, self.device, x.shape[2], x.shape[3],
165
  config_path=self.config_path, substrate_config=substrate_config
166
  )
167
  x = torch.cat([x, settings_ch], dim=1) # [1,3,H,W]
S2FApp/ui/components.py CHANGED
@@ -50,18 +50,23 @@ _REGION_COLORS = [
50
  ]
51
 
52
 
53
- def make_annotated_heatmap(heatmap_rgb, mask, fill_alpha=0.3, stroke_color=(255, 102, 0), stroke_width=2):
54
- """Composite heatmap with drawn region overlay."""
55
- annotated = heatmap_rgb.copy()
56
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
57
  overlay = annotated.copy()
58
- cv2.fillPoly(overlay, contours, stroke_color)
59
  mask_3d = np.stack([mask] * 3, axis=-1).astype(bool)
60
  annotated[mask_3d] = (
61
  (1 - fill_alpha) * annotated[mask_3d].astype(np.float32)
62
  + fill_alpha * overlay[mask_3d].astype(np.float32)
63
  ).astype(np.uint8)
64
- cv2.drawContours(annotated, contours, -1, stroke_color, stroke_width)
 
 
 
 
 
 
65
  return annotated
66
 
67
 
@@ -73,15 +78,7 @@ def make_annotated_heatmap_multi_regions(heatmap_rgb, masks, labels, cell_mask=N
73
  cv2.drawContours(annotated, contours, -1, (255, 0, 0), 2)
74
  for i, mask in enumerate(masks):
75
  color = _REGION_COLORS[i % len(_REGION_COLORS)]
76
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
77
- overlay = annotated.copy()
78
- cv2.fillPoly(overlay, contours, color)
79
- mask_3d = np.stack([mask] * 3, axis=-1).astype(bool)
80
- annotated[mask_3d] = (
81
- (1 - fill_alpha) * annotated[mask_3d].astype(np.float32)
82
- + fill_alpha * overlay[mask_3d].astype(np.float32)
83
- ).astype(np.uint8)
84
- cv2.drawContours(annotated, contours, -1, color, 2)
85
  # Label at centroid
86
  M = cv2.moments(mask)
87
  if M["m00"] > 0:
@@ -159,7 +156,7 @@ def _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h):
159
  def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
160
  """Parse drawn shapes and return a list of individual masks (one per shape)."""
161
  if not json_data or "objects" not in json_data or not json_data["objects"]:
162
- return [], 0
163
  scale_x = heatmap_w / canvas_w
164
  scale_y = heatmap_h / canvas_h
165
  masks = []
@@ -170,7 +167,7 @@ def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatm
170
  mask = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
171
  cv2.fillPoly(mask, [pts], 1)
172
  masks.append(mask)
173
- return masks, len(masks)
174
 
175
 
176
  def build_original_vals(raw_heatmap, pixel_sum, force):
@@ -385,8 +382,8 @@ def render_region_canvas(display_heatmap, raw_heatmap=None, bf_img=None, origina
385
  )
386
 
387
  if canvas_result.json_data:
388
- masks, n = parse_canvas_shapes_to_masks(canvas_result.json_data, CANVAS_SIZE, CANVAS_SIZE, h, w)
389
- if masks and n > 0:
390
  metrics_list = [compute_region_metrics(raw_heatmap, m, original_vals) for m in masks]
391
  if cell_mask is not None and np.any(cell_mask > 0):
392
  cell_metrics = compute_region_metrics(raw_heatmap, cell_mask, original_vals)
@@ -433,13 +430,18 @@ def _add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2):
433
 
434
 
435
  def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, key_img, download_key_suffix="",
436
- colormap_name="Jet", display_mode="Auto", measure_region_dialog=None, auto_cell_boundary=True):
 
437
  """
438
  Render prediction result: plot, metrics, expander, and download/measure buttons.
439
  measure_region_dialog: callable to open measure dialog (when ST_DIALOG available).
440
  auto_cell_boundary: when True, use estimated cell area for metrics; when False, use entire map.
 
441
  """
442
- cell_mask = estimate_cell_mask(raw_heatmap) if auto_cell_boundary else None
 
 
 
443
  cell_pixel_sum, cell_force, cell_mean = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force) if cell_mask is not None else (None, None, None)
444
  use_cell_metrics = auto_cell_boundary and cell_pixel_sum is not None and cell_force is not None and cell_mean is not None
445
 
 
50
  ]
51
 
52
 
53
+ def _draw_region_overlay(annotated, mask, color, fill_alpha=0.3, stroke_width=2):
54
+ """Draw single region overlay on annotated heatmap (fill + alpha blend + contour). Modifies annotated in place."""
 
55
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
56
  overlay = annotated.copy()
57
+ cv2.fillPoly(overlay, contours, color)
58
  mask_3d = np.stack([mask] * 3, axis=-1).astype(bool)
59
  annotated[mask_3d] = (
60
  (1 - fill_alpha) * annotated[mask_3d].astype(np.float32)
61
  + fill_alpha * overlay[mask_3d].astype(np.float32)
62
  ).astype(np.uint8)
63
+ cv2.drawContours(annotated, contours, -1, color, stroke_width)
64
+
65
+
66
+ def make_annotated_heatmap(heatmap_rgb, mask, fill_alpha=0.3, stroke_color=(255, 102, 0), stroke_width=2):
67
+ """Composite heatmap with drawn region overlay."""
68
+ annotated = heatmap_rgb.copy()
69
+ _draw_region_overlay(annotated, mask, stroke_color, fill_alpha, stroke_width)
70
  return annotated
71
 
72
 
 
78
  cv2.drawContours(annotated, contours, -1, (255, 0, 0), 2)
79
  for i, mask in enumerate(masks):
80
  color = _REGION_COLORS[i % len(_REGION_COLORS)]
81
+ _draw_region_overlay(annotated, mask, color, fill_alpha, stroke_width=2)
 
 
 
 
 
 
 
 
82
  # Label at centroid
83
  M = cv2.moments(mask)
84
  if M["m00"] > 0:
 
156
  def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
157
  """Parse drawn shapes and return a list of individual masks (one per shape)."""
158
  if not json_data or "objects" not in json_data or not json_data["objects"]:
159
+ return []
160
  scale_x = heatmap_w / canvas_w
161
  scale_y = heatmap_h / canvas_h
162
  masks = []
 
167
  mask = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
168
  cv2.fillPoly(mask, [pts], 1)
169
  masks.append(mask)
170
+ return masks
171
 
172
 
173
  def build_original_vals(raw_heatmap, pixel_sum, force):
 
382
  )
383
 
384
  if canvas_result.json_data:
385
+ masks = parse_canvas_shapes_to_masks(canvas_result.json_data, CANVAS_SIZE, CANVAS_SIZE, h, w)
386
+ if masks:
387
  metrics_list = [compute_region_metrics(raw_heatmap, m, original_vals) for m in masks]
388
  if cell_mask is not None and np.any(cell_mask > 0):
389
  cell_metrics = compute_region_metrics(raw_heatmap, cell_mask, original_vals)
 
430
 
431
 
432
  def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, key_img, download_key_suffix="",
433
+ colormap_name="Jet", display_mode="Auto", measure_region_dialog=None, auto_cell_boundary=True,
434
+ cell_mask=None):
435
  """
436
  Render prediction result: plot, metrics, expander, and download/measure buttons.
437
  measure_region_dialog: callable to open measure dialog (when ST_DIALOG available).
438
  auto_cell_boundary: when True, use estimated cell area for metrics; when False, use entire map.
439
+ cell_mask: optional precomputed cell mask; if None and auto_cell_boundary, will be computed.
440
  """
441
+ if cell_mask is None and auto_cell_boundary:
442
+ cell_mask = estimate_cell_mask(raw_heatmap)
443
+ elif not auto_cell_boundary:
444
+ cell_mask = None
445
  cell_pixel_sum, cell_force, cell_mean = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force) if cell_mask is not None else (None, None, None)
446
  use_cell_metrics = auto_cell_boundary and cell_pixel_sum is not None and cell_force is not None and cell_mean is not None
447
 
S2FApp/utils/metrics.py CHANGED
@@ -4,6 +4,8 @@ Includes: MSE, MS-SSIM, Pixel Correlation (Pearson), Relative Magnitude Error (W
4
  and evaluation helpers for notebooks and scripts.
5
  """
6
  import os
 
 
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
@@ -237,7 +239,7 @@ def evaluate_metrics_on_dataset(generator, data_loader, device=None, description
237
 
238
  if use_settings and normalization_params is not None:
239
  from models.s2f_model import create_settings_channels
240
- meta = metadata if has_metadata else {'substrate': [substrate_override or 'Fibroblasts_Fibronectin_6KPa'] * images.size(0)}
241
  settings_ch = create_settings_channels(meta, normalization_params, device, images.shape, config_path=config_path)
242
  images = torch.cat([images, settings_ch], dim=1)
243
 
@@ -420,7 +422,7 @@ def plot_predictions(loader, generator, n_samples, device, threshold=0.0,
420
  bf_batch = torch.stack(bf_list[:n]).to(device, dtype=torch.float32)
421
  if use_settings and normalization_params:
422
  from models.s2f_model import create_settings_channels
423
- sub = substrate_override or 'Fibroblasts_Fibronectin_6KPa'
424
  meta_dict = {'substrate': [sub] * n}
425
  settings_ch = create_settings_channels(meta_dict, normalization_params, device, bf_batch.shape, config_path=config_path)
426
  bf_batch = torch.cat([bf_batch, settings_ch], dim=1)
 
4
  and evaluation helpers for notebooks and scripts.
5
  """
6
  import os
7
+
8
+ from config.constants import DEFAULT_SUBSTRATE
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
 
239
 
240
  if use_settings and normalization_params is not None:
241
  from models.s2f_model import create_settings_channels
242
+ meta = metadata if has_metadata else {'substrate': [substrate_override or DEFAULT_SUBSTRATE] * images.size(0)}
243
  settings_ch = create_settings_channels(meta, normalization_params, device, images.shape, config_path=config_path)
244
  images = torch.cat([images, settings_ch], dim=1)
245
 
 
422
  bf_batch = torch.stack(bf_list[:n]).to(device, dtype=torch.float32)
423
  if use_settings and normalization_params:
424
  from models.s2f_model import create_settings_channels
425
+ sub = substrate_override or DEFAULT_SUBSTRATE
426
  meta_dict = {'substrate': [sub] * n}
427
  settings_ch = create_settings_channels(meta_dict, normalization_params, device, bf_batch.shape, config_path=config_path)
428
  bf_batch = torch.cat([bf_batch, settings_ch], dim=1)
S2FApp/utils/paths.py CHANGED
@@ -15,3 +15,29 @@ def get_ckp_base(root):
15
  def model_subfolder(model_type):
16
  """Return subfolder name for model type: 'single_cell' or 'spheroid'."""
17
  return "single_cell" if model_type == "single_cell" else "spheroid"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def model_subfolder(model_type):
16
  """Return subfolder name for model type: 'single_cell' or 'spheroid'."""
17
  return "single_cell" if model_type == "single_cell" else "spheroid"
18
+
19
+
20
+ def get_ckp_folder(ckp_base, model_type):
21
+ """Return checkpoint folder path for model type."""
22
+ return os.path.join(ckp_base, model_subfolder(model_type))
23
+
24
+
25
+ def get_sample_folder(root, model_type):
26
+ """Return sample folder path for model type (samples/<subfolder>)."""
27
+ return os.path.join(root, "samples", model_subfolder(model_type))
28
+
29
+
30
+ def list_files_in_folder(folder, extensions):
31
+ """
32
+ List files in folder matching extensions. Returns sorted list.
33
+ extensions: str or tuple of suffixes, e.g. '.pth' or ('.tif', '.png'). Matching is case-insensitive.
34
+ """
35
+ if not os.path.isdir(folder):
36
+ return []
37
+ ext_tuple = (extensions,) if isinstance(extensions, str) else extensions
38
+
39
+ def matches(fname):
40
+ fname_lower = fname.lower()
41
+ return any(fname_lower.endswith(e.lower()) for e in ext_tuple)
42
+
43
+ return sorted(f for f in os.listdir(folder) if matches(f))
S2FApp/utils/report.py CHANGED
@@ -13,6 +13,41 @@ from reportlab.pdfgen import canvas
13
  from config.constants import COLORMAPS
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def heatmap_to_rgb(scaled_heatmap, colormap_name="Jet"):
17
  """Convert scaled heatmap (float 0-1) to RGB array using the given colormap."""
18
  heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
@@ -47,35 +82,23 @@ def create_pdf_report(img, display_heatmap, raw_heatmap, pixel_sum, force, base_
47
  c = canvas.Canvas(buf, pagesize=A4)
48
  c.setTitle("Shape2Force")
49
  c.setAuthor("Angione-Lab")
50
- page_w_pt = A4[0]
51
- page_h_pt = A4[1]
52
  margin = 72
53
- img_w, img_h = 2.8 * inch, 2.8 * inch
54
- img_gap = 20
55
-
56
- # Center images
57
- total_img_width = 2 * img_w + img_gap
58
- img_left = margin + (page_w_pt - 2 * margin - total_img_width) / 2
59
- bf_x = img_left
60
- hm_x = img_left + img_w + img_gap
61
-
62
- footer_y = 40
63
- c.setFont("Helvetica", 8)
64
- c.setFillColorRGB(0.4, 0.4, 0.4)
65
- gen_date = datetime.now().strftime("%Y-%m-%d %H:%M")
66
- c.drawString(margin, footer_y, f"Generated by Shape2Force (S2F) on {gen_date}")
67
- c.drawString(margin, footer_y - 12, "Model: https://huggingface.co/Angione-Lab/Shape2Force")
68
- c.drawString(margin, footer_y - 24, "Web app: https://huggingface.co/spaces/Angione-Lab/Shape2force")
69
- c.setFillColorRGB(0, 0, 0)
70
-
71
- y_top = page_h_pt - 50
72
  c.setFont("Helvetica-Bold", 16)
73
  c.drawString(margin, y_top, "Shape2Force (S2F) - Prediction Report")
74
  c.setFont("Helvetica", 10)
75
  c.drawString(margin, y_top - 14, f"Image: {base_name}")
76
  y_top -= 35
77
 
78
- img_bottom = y_top - img_h
79
  img_pil = Image.fromarray(img) if img.ndim == 2 else Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
80
  img_buf = io.BytesIO()
81
  img_pil.save(img_buf, format="PNG")
@@ -133,36 +156,22 @@ def create_measure_pdf_report(bf_img, heatmap_labeled_rgb, table_rows, base_name
133
  c = canvas.Canvas(buf, pagesize=A4)
134
  c.setTitle("Shape2Force - Region Measurement")
135
  c.setAuthor("Angione-Lab")
136
- page_w_pt = A4[0]
137
- page_h_pt = A4[1]
138
  margin = 72
139
- img_w = 2.8 * inch
140
- img_h = 2.8 * inch
141
- img_gap = 20
142
-
143
- # Center images: two images side by side
144
- total_img_width = 2 * img_w + img_gap
145
- img_left = margin + (page_w_pt - 2 * margin - total_img_width) / 2
146
- bf_x = img_left
147
- hm_x = img_left + img_w + img_gap
148
-
149
- footer_y = 40
150
- c.setFont("Helvetica", 8)
151
- c.setFillColorRGB(0.4, 0.4, 0.4)
152
- gen_date = datetime.now().strftime("%Y-%m-%d %H:%M")
153
- c.drawString(margin, footer_y, f"Generated by Shape2Force (S2F) on {gen_date}")
154
- c.drawString(margin, footer_y - 12, "Model: https://huggingface.co/Angione-Lab/Shape2Force")
155
- c.setFillColorRGB(0, 0, 0)
156
-
157
- y_top = page_h_pt - 50
158
  c.setFont("Helvetica-Bold", 14)
159
  c.drawString(margin, y_top, "Region Measurement Report")
160
  c.setFont("Helvetica", 10)
161
  c.drawString(margin, y_top - 14, f"Image: {base_name}")
162
- y_top -= 35
163
 
164
- # Images (centered)
165
- img_bottom = y_top - img_h
166
  bf_pil = Image.fromarray(bf_img) if bf_img.ndim == 2 else Image.fromarray(
167
  cv2.cvtColor(bf_img, cv2.COLOR_BGR2RGB)
168
  )
 
13
  from config.constants import COLORMAPS
14
 
15
 
16
+ def _draw_pdf_footer(c, margin=72, footer_y=40, include_web_app=False):
17
+ """Draw common footer for S2F PDF reports."""
18
+ c.setFont("Helvetica", 8)
19
+ c.setFillColorRGB(0.4, 0.4, 0.4)
20
+ gen_date = datetime.now().strftime("%Y-%m-%d %H:%M")
21
+ c.drawString(margin, footer_y, f"Generated by Shape2Force (S2F) on {gen_date}")
22
+ c.drawString(margin, footer_y - 12, "Model: https://huggingface.co/Angione-Lab/Shape2Force")
23
+ if include_web_app:
24
+ c.drawString(margin, footer_y - 24, "Web app: https://huggingface.co/spaces/Angione-Lab/Shape2force")
25
+ c.setFillColorRGB(0, 0, 0)
26
+
27
+
28
+ def _pdf_image_layout(page_w_pt, page_h_pt, margin=72, n_images=2):
29
+ """Return layout dict for centered side-by-side images: img_w, img_h, img_gap, img_left, bf_x, hm_x, img_bottom, y_top."""
30
+ img_w = 2.8 * inch
31
+ img_h = 2.8 * inch
32
+ img_gap = 20
33
+ total_img_width = n_images * img_w + (n_images - 1) * img_gap
34
+ img_left = margin + (page_w_pt - 2 * margin - total_img_width) / 2
35
+ bf_x = img_left
36
+ hm_x = img_left + img_w + img_gap
37
+ y_top = page_h_pt - 50
38
+ img_bottom = y_top - 35 - img_h # header (title + image name) takes 35pt
39
+ return {
40
+ "img_w": img_w,
41
+ "img_h": img_h,
42
+ "img_gap": img_gap,
43
+ "img_left": img_left,
44
+ "bf_x": bf_x,
45
+ "hm_x": hm_x,
46
+ "img_bottom": img_bottom,
47
+ "y_top": y_top,
48
+ }
49
+
50
+
51
  def heatmap_to_rgb(scaled_heatmap, colormap_name="Jet"):
52
  """Convert scaled heatmap (float 0-1) to RGB array using the given colormap."""
53
  heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
 
82
  c = canvas.Canvas(buf, pagesize=A4)
83
  c.setTitle("Shape2Force")
84
  c.setAuthor("Angione-Lab")
85
+ page_w_pt, page_h_pt = A4[0], A4[1]
 
86
  margin = 72
87
+ layout = _pdf_image_layout(page_w_pt, page_h_pt, margin)
88
+ img_w = layout["img_w"]
89
+ img_h = layout["img_h"]
90
+ bf_x = layout["bf_x"]
91
+ hm_x = layout["hm_x"]
92
+ img_bottom = layout["img_bottom"]
93
+ y_top = layout["y_top"]
94
+
95
+ _draw_pdf_footer(c, margin=margin, include_web_app=True)
 
 
 
 
 
 
 
 
 
 
96
  c.setFont("Helvetica-Bold", 16)
97
  c.drawString(margin, y_top, "Shape2Force (S2F) - Prediction Report")
98
  c.setFont("Helvetica", 10)
99
  c.drawString(margin, y_top - 14, f"Image: {base_name}")
100
  y_top -= 35
101
 
 
102
  img_pil = Image.fromarray(img) if img.ndim == 2 else Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
103
  img_buf = io.BytesIO()
104
  img_pil.save(img_buf, format="PNG")
 
156
  c = canvas.Canvas(buf, pagesize=A4)
157
  c.setTitle("Shape2Force - Region Measurement")
158
  c.setAuthor("Angione-Lab")
159
+ page_w_pt, page_h_pt = A4[0], A4[1]
 
160
  margin = 72
161
+ layout = _pdf_image_layout(page_w_pt, page_h_pt, margin)
162
+ img_w = layout["img_w"]
163
+ img_h = layout["img_h"]
164
+ bf_x = layout["bf_x"]
165
+ hm_x = layout["hm_x"]
166
+ img_bottom = layout["img_bottom"]
167
+ y_top = layout["y_top"]
168
+
169
+ _draw_pdf_footer(c, margin=margin)
 
 
 
 
 
 
 
 
 
 
170
  c.setFont("Helvetica-Bold", 14)
171
  c.drawString(margin, y_top, "Region Measurement Report")
172
  c.setFont("Helvetica", 10)
173
  c.drawString(margin, y_top - 14, f"Image: {base_name}")
 
174
 
 
 
175
  bf_pil = Image.fromarray(bf_img) if bf_img.ndim == 2 else Image.fromarray(
176
  cv2.cvtColor(bf_img, cv2.COLOR_BGR2RGB)
177
  )
S2FApp/utils/substrate_settings.py CHANGED
@@ -5,6 +5,8 @@ Loads from config/substrate_settings.json - users can edit this file to add/modi
5
  import os
6
  import json
7
 
 
 
8
 
9
  def _default_config_path():
10
  """Default path to substrate settings config (S2F/config/substrate_settings.json)."""
@@ -50,7 +52,7 @@ def resolve_substrate(name, config=None, config_path=None):
50
 
51
  s = (name or '').strip()
52
  if not s:
53
- return config.get('default_substrate', 'Fibroblasts_Fibronectin_6KPa')
54
 
55
  substrates = config.get('substrates', {})
56
  s_lower = s.lower()
@@ -61,7 +63,7 @@ def resolve_substrate(name, config=None, config_path=None):
61
  if s_lower.startswith(key.lower()) or key.lower().startswith(s_lower):
62
  return key
63
 
64
- return config.get('default_substrate', 'Fibroblasts_Fibronectin_6KPa')
65
 
66
 
67
  def get_settings_of_category(substrate_name, config=None, config_path=None):
@@ -81,12 +83,15 @@ def get_settings_of_category(substrate_name, config=None, config_path=None):
81
 
82
  substrate_key = resolve_substrate(substrate_name, config=config)
83
  substrates = config.get('substrates', {})
84
- default = config.get('default_substrate', 'Fibroblasts_Fibronectin_6KPa')
85
 
86
  if substrate_key in substrates:
87
  return substrates[substrate_key].copy()
88
 
89
- default_settings = substrates.get(default, {'name': 'Fibroblasts on Fibronectin (6 kPa)', 'pixelsize': 3.0769, 'young': 6000})
 
 
 
90
  return default_settings.copy()
91
 
92
 
 
5
  import os
6
  import json
7
 
8
+ from config.constants import DEFAULT_SUBSTRATE
9
+
10
 
11
  def _default_config_path():
12
  """Default path to substrate settings config (S2F/config/substrate_settings.json)."""
 
52
 
53
  s = (name or '').strip()
54
  if not s:
55
+ return config.get('default_substrate', DEFAULT_SUBSTRATE)
56
 
57
  substrates = config.get('substrates', {})
58
  s_lower = s.lower()
 
63
  if s_lower.startswith(key.lower()) or key.lower().startswith(s_lower):
64
  return key
65
 
66
+ return config.get('default_substrate', DEFAULT_SUBSTRATE)
67
 
68
 
69
  def get_settings_of_category(substrate_name, config=None, config_path=None):
 
83
 
84
  substrate_key = resolve_substrate(substrate_name, config=config)
85
  substrates = config.get('substrates', {})
86
+ default = config.get('default_substrate', DEFAULT_SUBSTRATE)
87
 
88
  if substrate_key in substrates:
89
  return substrates[substrate_key].copy()
90
 
91
+ default_settings = substrates.get(
92
+ default,
93
+ {'name': 'Fibroblasts on Fibronectin (6 kPa)', 'pixelsize': 3.0769, 'young': 6000},
94
+ )
95
  return default_settings.copy()
96
 
97