updated for clarity
Browse files- S2FApp/app.py +127 -34
S2FApp/app.py
CHANGED
|
@@ -27,7 +27,7 @@ st.markdown("""
|
|
| 27 |
</style>
|
| 28 |
""", unsafe_allow_html=True)
|
| 29 |
st.title("🔬 Shape2Force (S2F)")
|
| 30 |
-
st.caption("Predict force maps from bright
|
| 31 |
|
| 32 |
# Folders: checkpoints in subfolders by model type (single_cell / spheroid)
|
| 33 |
ckp_base = os.path.join(S2F_ROOT, "ckp")
|
|
@@ -111,24 +111,17 @@ with st.sidebar:
|
|
| 111 |
except FileNotFoundError:
|
| 112 |
st.error("config/substrate_settings.json not found")
|
| 113 |
|
| 114 |
-
st.divider()
|
| 115 |
-
st.subheader("Display")
|
| 116 |
-
display_size = st.slider("Image size (px)", min_value=200, max_value=800, value=350, step=50,
|
| 117 |
-
help="Adjust display size. Drag to pan, scroll to zoom.")
|
| 118 |
-
|
| 119 |
-
st.divider()
|
| 120 |
-
|
| 121 |
# Main area: image input
|
| 122 |
-
img_source = st.radio("Image source", ["Upload", "
|
| 123 |
img = None
|
| 124 |
uploaded = None
|
| 125 |
selected_sample = None
|
| 126 |
|
| 127 |
if img_source == "Upload":
|
| 128 |
uploaded = st.file_uploader(
|
| 129 |
-
"Upload bright
|
| 130 |
type=["tif", "tiff", "png", "jpg", "jpeg"],
|
| 131 |
-
help="Bright
|
| 132 |
)
|
| 133 |
if uploaded:
|
| 134 |
bytes_data = uploaded.read()
|
|
@@ -141,7 +134,7 @@ else:
|
|
| 141 |
sample_subfolder_name = "single_cell" if model_type == "single_cell" else "spheroid"
|
| 142 |
if sample_files:
|
| 143 |
selected_sample = st.selectbox(
|
| 144 |
-
"Select
|
| 145 |
sample_files,
|
| 146 |
format_func=lambda x: x,
|
| 147 |
key=f"sample_{model_type}",
|
|
@@ -149,8 +142,8 @@ else:
|
|
| 149 |
if selected_sample:
|
| 150 |
sample_path = os.path.join(sample_folder, selected_sample)
|
| 151 |
img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
|
| 152 |
-
# Show
|
| 153 |
-
st.caption(f"
|
| 154 |
n_cols = min(5, len(sample_files))
|
| 155 |
cols = st.columns(n_cols)
|
| 156 |
for i, fname in enumerate(sample_files[:8]): # show up to 8
|
|
@@ -160,12 +153,24 @@ else:
|
|
| 160 |
if sample_img is not None:
|
| 161 |
st.image(sample_img, caption=fname, width='content')
|
| 162 |
else:
|
| 163 |
-
st.info(f"No
|
| 164 |
|
| 165 |
run = st.button("Run prediction", type="primary")
|
| 166 |
has_image = img is not None
|
| 167 |
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
st.markdown(f"**Using checkpoint:** `ckp/{ckp_subfolder_name}/{checkpoint}`")
|
| 170 |
with st.spinner("Loading model and predicting..."):
|
| 171 |
try:
|
|
@@ -185,23 +190,18 @@ if run and checkpoint and has_image:
|
|
| 185 |
|
| 186 |
st.success("Prediction complete!")
|
| 187 |
|
| 188 |
-
#
|
| 189 |
-
|
| 190 |
-
with
|
| 191 |
-
st.
|
| 192 |
-
with
|
| 193 |
-
st.
|
| 194 |
-
|
| 195 |
-
st.metric("Heatmap max", f"{np.max(heatmap):.4f}")
|
| 196 |
-
with col4:
|
| 197 |
-
st.metric("Heatmap mean", f"{np.mean(heatmap):.4f}")
|
| 198 |
-
|
| 199 |
-
# Visualization - Plotly with zoom/pan
|
| 200 |
-
fig_pl = make_subplots(rows=1, cols=2, subplot_titles=["", ""])
|
| 201 |
fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
|
| 202 |
-
fig_pl.add_trace(go.Heatmap(z=heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True
|
|
|
|
| 203 |
fig_pl.update_layout(
|
| 204 |
-
height=
|
| 205 |
margin=dict(l=10, r=10, t=10, b=10),
|
| 206 |
xaxis=dict(scaleanchor="y", scaleratio=1),
|
| 207 |
xaxis2=dict(scaleanchor="y2", scaleratio=1),
|
|
@@ -210,6 +210,34 @@ if run and checkpoint and has_image:
|
|
| 210 |
fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
|
| 211 |
st.plotly_chart(fig_pl, use_container_width=True)
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
# Download
|
| 214 |
heatmap_uint8 = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
|
| 215 |
heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
|
|
@@ -219,18 +247,83 @@ if run and checkpoint and has_image:
|
|
| 219 |
pil_heatmap.save(buf_hm, format="PNG")
|
| 220 |
buf_hm.seek(0)
|
| 221 |
st.download_button("Download Heatmap", data=buf_hm.getvalue(),
|
| 222 |
-
file_name="s2f_heatmap.png", mime="image/png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
except Exception as e:
|
| 225 |
st.error(f"Prediction failed: {e}")
|
| 226 |
import traceback
|
| 227 |
st.code(traceback.format_exc())
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
elif run and not checkpoint:
|
| 230 |
st.warning("Please add checkpoint files to the ckp/ folder and select one.")
|
| 231 |
elif run and not has_image:
|
| 232 |
-
st.warning("Please upload an image or select
|
| 233 |
|
| 234 |
# Footer
|
| 235 |
st.sidebar.divider()
|
| 236 |
-
st.sidebar.caption("
|
|
|
|
|
|
| 27 |
</style>
|
| 28 |
""", unsafe_allow_html=True)
|
| 29 |
st.title("🔬 Shape2Force (S2F)")
|
| 30 |
+
st.caption("Predict traction force maps from bright-field microscopy images of cells or spheroids")
|
| 31 |
|
| 32 |
# Folders: checkpoints in subfolders by model type (single_cell / spheroid)
|
| 33 |
ckp_base = os.path.join(S2F_ROOT, "ckp")
|
|
|
|
| 111 |
except FileNotFoundError:
|
| 112 |
st.error("config/substrate_settings.json not found")
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
# Main area: image input
|
| 115 |
+
img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
|
| 116 |
img = None
|
| 117 |
uploaded = None
|
| 118 |
selected_sample = None
|
| 119 |
|
| 120 |
if img_source == "Upload":
|
| 121 |
uploaded = st.file_uploader(
|
| 122 |
+
"Upload bright-field image",
|
| 123 |
type=["tif", "tiff", "png", "jpg", "jpeg"],
|
| 124 |
+
help="Bright-field microscopy image of a cell or spheroid on a substrate (grayscale or RGB). The model will predict traction forces from the cell shape.",
|
| 125 |
)
|
| 126 |
if uploaded:
|
| 127 |
bytes_data = uploaded.read()
|
|
|
|
| 134 |
sample_subfolder_name = "single_cell" if model_type == "single_cell" else "spheroid"
|
| 135 |
if sample_files:
|
| 136 |
selected_sample = st.selectbox(
|
| 137 |
+
"Select example image",
|
| 138 |
sample_files,
|
| 139 |
format_func=lambda x: x,
|
| 140 |
key=f"sample_{model_type}",
|
|
|
|
| 142 |
if selected_sample:
|
| 143 |
sample_path = os.path.join(sample_folder, selected_sample)
|
| 144 |
img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
|
| 145 |
+
# Show example thumbnails (filtered by model type)
|
| 146 |
+
st.caption(f"Example images from `samples/{sample_subfolder_name}/`")
|
| 147 |
n_cols = min(5, len(sample_files))
|
| 148 |
cols = st.columns(n_cols)
|
| 149 |
for i, fname in enumerate(sample_files[:8]): # show up to 8
|
|
|
|
| 153 |
if sample_img is not None:
|
| 154 |
st.image(sample_img, caption=fname, width='content')
|
| 155 |
else:
|
| 156 |
+
st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
|
| 157 |
|
| 158 |
run = st.button("Run prediction", type="primary")
|
| 159 |
has_image = img is not None
|
| 160 |
|
| 161 |
+
# Persist results in session state so they survive re-runs (e.g. when clicking Download)
|
| 162 |
+
if "prediction_result" not in st.session_state:
|
| 163 |
+
st.session_state["prediction_result"] = None
|
| 164 |
+
|
| 165 |
+
# Show results if we just ran prediction OR we have cached results from a previous run
|
| 166 |
+
just_ran = run and checkpoint and has_image
|
| 167 |
+
cached = st.session_state["prediction_result"]
|
| 168 |
+
key_img = (uploaded.name if uploaded else None) if img_source == "Upload" else selected_sample
|
| 169 |
+
current_key = (model_type, checkpoint, key_img)
|
| 170 |
+
has_cached = cached is not None and cached.get("cache_key") == current_key
|
| 171 |
+
|
| 172 |
+
if just_ran:
|
| 173 |
+
st.session_state["prediction_result"] = None # Clear before new run
|
| 174 |
st.markdown(f"**Using checkpoint:** `ckp/{ckp_subfolder_name}/{checkpoint}`")
|
| 175 |
with st.spinner("Loading model and predicting..."):
|
| 176 |
try:
|
|
|
|
| 190 |
|
| 191 |
st.success("Prediction complete!")
|
| 192 |
|
| 193 |
+
# Visualization - Plotly with zoom/pan, annotated (titles in Streamlit to avoid clipping)
|
| 194 |
+
tit1, tit2 = st.columns(2)
|
| 195 |
+
with tit1:
|
| 196 |
+
st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Input: Bright-field image</p>', unsafe_allow_html=True)
|
| 197 |
+
with tit2:
|
| 198 |
+
st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
|
| 199 |
+
fig_pl = make_subplots(rows=1, cols=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
|
| 201 |
+
fig_pl.add_trace(go.Heatmap(z=heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
|
| 202 |
+
colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
|
| 203 |
fig_pl.update_layout(
|
| 204 |
+
height=400,
|
| 205 |
margin=dict(l=10, r=10, t=10, b=10),
|
| 206 |
xaxis=dict(scaleanchor="y", scaleratio=1),
|
| 207 |
xaxis2=dict(scaleanchor="y2", scaleratio=1),
|
|
|
|
| 210 |
fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
|
| 211 |
st.plotly_chart(fig_pl, use_container_width=True)
|
| 212 |
|
| 213 |
+
# Metrics with help (below plot)
|
| 214 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 215 |
+
with col1:
|
| 216 |
+
st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
|
| 217 |
+
with col2:
|
| 218 |
+
st.metric("Cell force (scaled)", f"{force:.2f}", help="Total traction force in physical units")
|
| 219 |
+
with col3:
|
| 220 |
+
st.metric("Heatmap max", f"{np.max(heatmap):.4f}", help="Peak force intensity in the map")
|
| 221 |
+
with col4:
|
| 222 |
+
st.metric("Heatmap mean", f"{np.mean(heatmap):.4f}", help="Average force intensity")
|
| 223 |
+
|
| 224 |
+
# How to read (below numbers)
|
| 225 |
+
with st.expander("ℹ️ How to read the results"):
|
| 226 |
+
st.markdown("""
|
| 227 |
+
**Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
|
| 228 |
+
This is the raw image you provided—it shows cell shape but not forces.
|
| 229 |
+
|
| 230 |
+
**Output (right):** Predicted traction force map.
|
| 231 |
+
- **Color** indicates force magnitude: blue = low, red = high
|
| 232 |
+
- **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
|
| 233 |
+
- Values are normalized to [0, 1] for visualization
|
| 234 |
+
|
| 235 |
+
**Metrics:**
|
| 236 |
+
- **Sum of all pixels:** Total force is the sum of all pixels in the force map. Each pixel represents the magnitude of force at that location; summing them gives the overall traction.
|
| 237 |
+
- **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
|
| 238 |
+
- **Heatmap max/mean:** Peak and average force intensity in the map
|
| 239 |
+
""")
|
| 240 |
+
|
| 241 |
# Download
|
| 242 |
heatmap_uint8 = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
|
| 243 |
heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
|
|
|
|
| 247 |
pil_heatmap.save(buf_hm, format="PNG")
|
| 248 |
buf_hm.seek(0)
|
| 249 |
st.download_button("Download Heatmap", data=buf_hm.getvalue(),
|
| 250 |
+
file_name="s2f_heatmap.png", mime="image/png", key="download_heatmap")
|
| 251 |
+
|
| 252 |
+
# Store in session state so results persist when user clicks Download
|
| 253 |
+
cache_key = (model_type, checkpoint, key_img)
|
| 254 |
+
st.session_state["prediction_result"] = {
|
| 255 |
+
"img": img.copy(),
|
| 256 |
+
"heatmap": heatmap.copy(),
|
| 257 |
+
"force": force,
|
| 258 |
+
"pixel_sum": pixel_sum,
|
| 259 |
+
"cache_key": cache_key,
|
| 260 |
+
}
|
| 261 |
|
| 262 |
except Exception as e:
|
| 263 |
st.error(f"Prediction failed: {e}")
|
| 264 |
import traceback
|
| 265 |
st.code(traceback.format_exc())
|
| 266 |
|
| 267 |
+
elif has_cached:
|
| 268 |
+
# Show cached results (e.g. after clicking Download)
|
| 269 |
+
r = st.session_state["prediction_result"]
|
| 270 |
+
img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
|
| 271 |
+
st.success("Prediction complete!")
|
| 272 |
+
tit1, tit2 = st.columns(2)
|
| 273 |
+
with tit1:
|
| 274 |
+
st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Input: Bright-field image</p>', unsafe_allow_html=True)
|
| 275 |
+
with tit2:
|
| 276 |
+
st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
|
| 277 |
+
fig_pl = make_subplots(rows=1, cols=2)
|
| 278 |
+
fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
|
| 279 |
+
fig_pl.add_trace(go.Heatmap(z=heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
|
| 280 |
+
colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
|
| 281 |
+
fig_pl.update_layout(height=400, margin=dict(l=10, r=10, t=10, b=10),
|
| 282 |
+
xaxis=dict(scaleanchor="y", scaleratio=1),
|
| 283 |
+
xaxis2=dict(scaleanchor="y2", scaleratio=1))
|
| 284 |
+
fig_pl.update_xaxes(showticklabels=False)
|
| 285 |
+
fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
|
| 286 |
+
st.plotly_chart(fig_pl, use_container_width=True)
|
| 287 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 288 |
+
with col1:
|
| 289 |
+
st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
|
| 290 |
+
with col2:
|
| 291 |
+
st.metric("Cell force (scaled)", f"{force:.2f}", help="Total traction force in physical units")
|
| 292 |
+
with col3:
|
| 293 |
+
st.metric("Heatmap max", f"{np.max(heatmap):.4f}", help="Peak force intensity in the map")
|
| 294 |
+
with col4:
|
| 295 |
+
st.metric("Heatmap mean", f"{np.mean(heatmap):.4f}", help="Average force intensity")
|
| 296 |
+
with st.expander("ℹ️ How to read the results"):
|
| 297 |
+
st.markdown("""
|
| 298 |
+
**Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
|
| 299 |
+
This is the raw image you provided—it shows cell shape but not forces.
|
| 300 |
+
|
| 301 |
+
**Output (right):** Predicted traction force map.
|
| 302 |
+
- **Color** indicates force magnitude: blue = low, red = high
|
| 303 |
+
- **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
|
| 304 |
+
- Values are normalized to [0, 1] for visualization
|
| 305 |
+
|
| 306 |
+
**Metrics:**
|
| 307 |
+
- **Sum of all pixels:** Total force is the sum of all pixels in the force map. Each pixel represents the magnitude of force at that location; summing them gives the overall traction.
|
| 308 |
+
- **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
|
| 309 |
+
- **Heatmap max/mean:** Peak and average force intensity in the map
|
| 310 |
+
""")
|
| 311 |
+
heatmap_uint8 = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
|
| 312 |
+
heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
|
| 313 |
+
heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
|
| 314 |
+
pil_heatmap = Image.fromarray(heatmap_rgb)
|
| 315 |
+
buf_hm = io.BytesIO()
|
| 316 |
+
pil_heatmap.save(buf_hm, format="PNG")
|
| 317 |
+
buf_hm.seek(0)
|
| 318 |
+
st.download_button("Download Heatmap", data=buf_hm.getvalue(),
|
| 319 |
+
file_name="s2f_heatmap.png", mime="image/png", key="download_cached")
|
| 320 |
+
|
| 321 |
elif run and not checkpoint:
|
| 322 |
st.warning("Please add checkpoint files to the ckp/ folder and select one.")
|
| 323 |
elif run and not has_image:
|
| 324 |
+
st.warning("Please upload an image or select an example.")
|
| 325 |
|
| 326 |
# Footer
|
| 327 |
st.sidebar.divider()
|
| 328 |
+
st.sidebar.caption(f"Checkpoint: `ckp/{ckp_subfolder_name}/`")
|
| 329 |
+
st.sidebar.caption(f"Examples: `samples/{ckp_subfolder_name}/`")
|