File size: 14,702 Bytes
0e1a816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6727da5
0e1a816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7068f5c
 
 
 
 
 
 
 
 
 
0e1a816
 
 
6727da5
0e1a816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
"""Measure tool: drawable canvas, region metrics, and downloads."""
import csv
import html
import io
import os

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from config.constants import CANVAS_SIZE, DRAW_TOOLS, TOOL_LABELS
from utils.report import heatmap_to_rgb_with_contour, create_measure_pdf_report
from ui.heatmaps import make_annotated_heatmap_multi_regions

try:
    from streamlit_drawable_canvas import st_canvas
    HAS_DRAWABLE_CANVAS = True
except (ImportError, AttributeError):
    HAS_DRAWABLE_CANVAS = False


def _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h):
    """Convert a single canvas object to polygon points in heatmap coords. Returns None if invalid."""
    obj_type = obj.get("type", "")
    pts = []
    if obj_type == "rect":
        left = obj.get("left", 0)
        top = obj.get("top", 0)
        w = obj.get("width", 0)
        h = obj.get("height", 0)
        pts = np.array([
            [left, top], [left + w, top], [left + w, top + h], [left, top + h]
        ], dtype=np.float32)
    elif obj_type == "circle" or obj_type == "ellipse":
        left = obj.get("left", 0)
        top = obj.get("top", 0)
        width = obj.get("width", 0)
        height = obj.get("height", 0)
        radius = obj.get("radius", 0)
        angle_deg = obj.get("angle", 0)
        if radius > 0:
            rx = ry = radius
            angle_rad = np.deg2rad(angle_deg)
            cx = left + radius * np.cos(angle_rad)
            cy = top + radius * np.sin(angle_rad)
        else:
            rx = width / 2 if width > 0 else 0
            ry = height / 2 if height > 0 else 0
            if rx <= 0 or ry <= 0:
                return None
            cx = left + rx
            cy = top + ry
        if rx <= 0 or ry <= 0:
            return None
        n = 32
        angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
        pts = np.column_stack([cx + rx * np.cos(angles), cy + ry * np.sin(angles)]).astype(np.float32)
    elif obj_type == "path":
        path = obj.get("path", [])
        for cmd in path:
            if isinstance(cmd, (list, tuple)) and len(cmd) >= 3:
                if cmd[0] in ("M", "L"):
                    pts.append([float(cmd[1]), float(cmd[2])])
                elif cmd[0] == "Q" and len(cmd) >= 5:
                    pts.append([float(cmd[3]), float(cmd[4])])
                elif cmd[0] == "C" and len(cmd) >= 7:
                    pts.append([float(cmd[5]), float(cmd[6])])
        if len(pts) < 3:
            return None
        pts = np.array(pts, dtype=np.float32)
    else:
        return None
    pts[:, 0] *= scale_x
    pts[:, 1] *= scale_y
    pts = np.clip(pts, 0, [heatmap_w - 1, heatmap_h - 1]).astype(np.int32)
    return pts


def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
    """Rasterise each drawn canvas shape into its own heatmap-sized mask.

    Args:
        json_data: canvas JSON; its ``objects`` list holds the drawn shapes.
        canvas_h, canvas_w: canvas size in pixels (shapes are given in these coordinates).
        heatmap_h, heatmap_w: size of the returned masks.

    Returns:
        One uint8 mask (1 inside the shape, 0 elsewhere) per convertible shape, in drawing
        order. Shapes ``_obj_to_pts`` rejects are skipped; no shapes gives ``[]``.
    """
    objects = json_data.get("objects") if json_data else None
    if not objects:
        return []
    sx = heatmap_w / canvas_w
    sy = heatmap_h / canvas_h
    result = []
    for shape in objects:
        polygon = _obj_to_pts(shape, sx, sy, heatmap_w, heatmap_h)
        if polygon is not None:
            shape_mask = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
            cv2.fillPoly(shape_mask, [polygon], 1)
            result.append(shape_mask)
    return result


def build_original_vals(raw_heatmap, pixel_sum, force):
    """Summarise the full heatmap for the measure tool.

    ``pixel_sum`` and ``force`` are passed through unchanged; ``max`` and ``mean`` are computed
    over every element of ``raw_heatmap`` and returned as Python floats.
    """
    peak = float(np.max(raw_heatmap))
    average = float(np.mean(raw_heatmap))
    return dict(pixel_sum=pixel_sum, force=force, max=peak, mean=average)


def _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force):
    """Compute metrics over estimated cell area only."""
    area_px = int(np.sum(cell_mask))
    if area_px == 0:
        return None, None, None
    region_values = raw_heatmap * cell_mask
    cell_pixel_sum = float(np.sum(region_values))
    cell_force = cell_pixel_sum * (force / pixel_sum) if pixel_sum > 0 else cell_pixel_sum
    cell_mean = cell_pixel_sum / area_px
    return cell_pixel_sum, cell_force, cell_mean


def build_cell_vals(raw_heatmap, cell_mask, pixel_sum, force):
    """Summarise the estimated cell area for the measure tool.

    Returns a dict with the same keys as ``build_original_vals`` (``pixel_sum``, ``force``,
    ``max``, ``mean``) restricted to ``cell_mask``, or None when the mask is empty.
    """
    total, scaled, average = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force)
    if total is None:
        return None
    # Values of the masked product at mask>0 positions (heatmap zeros inside the cell included).
    inside = (raw_heatmap * cell_mask)[cell_mask > 0]
    peak = float(np.max(inside)) if inside.size > 0 else 0
    return {"pixel_sum": total, "force": scaled, "max": peak, "mean": average}


def compute_region_metrics(raw_heatmap, mask, original_vals=None):
    """Measure ``raw_heatmap`` inside ``mask``.

    Returns a dict with:
        area_px: sum of the mask (pixel count for a 0/1 mask).
        force_sum: sum of ``raw_heatmap * mask``.
        density: ``force_sum / area_px`` (0 for an empty mask).
        max / mean: over the masked product at ``mask > 0`` positions (0 if there are none).
        force_scaled: ``force_sum`` rescaled by ``original_vals["force"] / original_vals["pixel_sum"]``
            when ``original_vals`` is given with a positive ``pixel_sum``; otherwise ``force_sum``.
    """
    masked = raw_heatmap * mask
    inside = masked[mask > 0]
    area_px = int(np.sum(mask))
    force_sum = float(np.sum(masked))
    has_pixels = len(inside) > 0
    if original_vals and original_vals.get("pixel_sum", 0) > 0:
        force_scaled = force_sum * (original_vals["force"] / original_vals["pixel_sum"])
    else:
        force_scaled = force_sum
    return {
        "area_px": area_px,
        "force_sum": force_sum,
        "density": force_sum / area_px if area_px > 0 else 0,
        "max": float(np.max(inside)) if has_pixels else 0,
        "mean": float(np.mean(inside)) if has_pixels else 0,
        "force_scaled": force_scaled,
    }


def _draw_contour_on_image(img_rgb, mask, stroke_color=(255, 0, 0), stroke_width=3):
    """Draw the outer contour(s) of ``mask`` onto ``img_rgb`` and return it.

    Any mask pixel ``> 0`` counts as foreground, so bool, 0/1, 0/255 and float masks all work.
    The mask is resized (nearest neighbour) when its size differs from the image.

    Args:
        img_rgb: image to draw on; modified in place by ``cv2.drawContours``.
        mask: 2-D region mask.
        stroke_color: contour colour (default red for an RGB image).
        stroke_width: contour line thickness in pixels.

    Returns:
        ``img_rgb`` (the same array object).
    """
    h, w = img_rgb.shape[:2]
    # findContours needs 8-bit single-channel input; binarise first so non-uint8 masks
    # (e.g. bool or float) do not raise when no resize happens.
    mask_u8 = (np.asarray(mask) > 0).astype(np.uint8)
    if mask_u8.shape[:2] != (h, w):
        mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
    # [-2] selects the contour list under both the 3-tuple (OpenCV 3.x) and 2-tuple (4.x) returns.
    contours = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if contours:
        cv2.drawContours(img_rgb, contours, -1, stroke_color, stroke_width)
    return img_rgb


def render_region_metrics_and_downloads(metrics_list, masks, heatmap_rgb, input_filename, key_suffix, has_original_vals,
                                        first_region_label=None, bf_img=None, cell_mask=None, colormap_name="Jet"):
    """
    Show one table row per measured region and offer CSV / PNG (and, with bf_img, PDF) downloads.

    metrics_list: dicts from compute_region_metrics, in display order.
    masks: user-drawn region masks only; they are labelled R1, R2, ... on the exported heatmap.
    has_original_vals: selects the wider table (F.sum, Force, Max, Mean) vs. (Force sum, Mean).
    first_region_label: if set, used for the first row (e.g. 'Auto boundary'); later rows are then
        numbered from "Region 1".
    bf_img: when given, a PDF report is built and a third download button is shown.
    cell_mask, colormap_name: accepted for interface compatibility; not read here.
    """
    base_name = os.path.splitext(input_filename or "image")[0]
    st.markdown("**Regions (each selection = one row)**")

    if has_original_vals:
        headers = ["Region", "Area", "F.sum", "Force", "Max", "Mean"]
        csv_rows = [["image", "region"] + headers[1:]]
    else:
        headers = ["Region", "Area (px²)", "Force sum", "Mean"]
        csv_rows = [["image", "region", "Area", "Force sum", "Mean"]]
    table_rows = [headers]

    # When the first row carries a custom label, "Region N" numbering starts at the second row.
    offset = 1 if first_region_label else 0
    for idx, m in enumerate(metrics_list, start=1):
        if idx == 1 and first_region_label:
            row_label = first_region_label
        else:
            row_label = f"Region {idx - offset}"
        if has_original_vals:
            cells = [f"{m['force_sum']:.3f}", f"{m['force_scaled']:.1f}", f"{m['max']:.3f}", f"{m['mean']:.4f}"]
        else:
            cells = [f"{m['force_sum']:.4f}", f"{m['mean']:.6f}"]
        # The table (also fed to the PDF) needs the area as text; the CSV writer stringifies it.
        table_rows.append([row_label, str(m["area_px"])] + cells)
        csv_rows.append([base_name, row_label, m["area_px"]] + cells)

    # Plain HTML table so Streamlit does not add its own row/column indices.
    th_style = "border: 1px solid #ddd; padding: 8px; text-align: left;"
    td_style = "border: 1px solid #ddd; padding: 8px;"
    header_html = "".join(f'<th style="{th_style}">{html.escape(str(cell))}</th>' for cell in table_rows[0])
    body_html = "".join(
        "<tr>" + "".join(f'<td style="{td_style}">{html.escape(str(cell))}</td>' for cell in row) + "</tr>"
        for row in table_rows[1:]
    )
    st.markdown(
        '<table style="border-collapse: collapse; width: 100%;">'
        f"<thead><tr>{header_html}</tr></thead>"
        f"<tbody>{body_html}</tbody></table>",
        unsafe_allow_html=True,
    )

    csv_buffer = io.StringIO()
    csv.writer(csv_buffer).writerows(csv_rows)

    # Each drawn region is labelled separately (R1, R2, ...); the cell outline is not redrawn here.
    region_tags = [f"R{n}" for n in range(1, len(masks) + 1)]
    heatmap_labeled = make_annotated_heatmap_multi_regions(heatmap_rgb.copy(), masks, region_tags, cell_mask=None)
    png_buffer = io.BytesIO()
    Image.fromarray(heatmap_labeled).save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()

    pdf_bytes = (
        create_measure_pdf_report(bf_img, heatmap_labeled, table_rows, base_name)
        if bf_img is not None
        else None
    )

    # (label, data, file name, mime, widget key, icon) for each download button, left to right.
    buttons = [
        ("Download all regions", csv_buffer.getvalue(), f"{base_name}_all_regions.csv", "text/csv",
         f"download_all_regions_{key_suffix}", ":material/download:"),
        ("Download heatmap", png_bytes, f"{base_name}_annotated_heatmap.png", "image/png",
         f"download_annotated_{key_suffix}", ":material/image:"),
    ]
    if pdf_bytes is not None:
        buttons.append(("Download report", pdf_bytes, f"{base_name}_measure_report.pdf", "application/pdf",
                        f"download_measure_pdf_{key_suffix}", ":material/picture_as_pdf:"))
    for column, (btn_label, data, file_name, mime, widget_key, icon) in zip(st.columns(len(buttons)), buttons):
        with column:
            st.download_button(btn_label, data=data, file_name=file_name, mime=mime, key=widget_key, icon=icon)


def render_region_canvas(display_heatmap, raw_heatmap=None, bf_img=None, original_vals=None, cell_vals=None,
                         cell_mask=None, key_suffix="", input_filename=None, colormap_name="Jet"):
    """Render the drawable region canvas and, once shapes are drawn, their metrics and downloads.

    With ``bf_img`` the canvas is shown next to a bright-field preview and a summary panel of
    ``cell_vals`` (labelled "Cell area") or, if that is falsy, ``original_vals`` ("Full map").
    Without ``bf_img`` only the canvas is shown.

    Args:
        display_heatmap: 2-D heatmap used for the canvas background and for the mask size.
        raw_heatmap: values the metrics are computed on; defaults to ``display_heatmap``.
        bf_img: optional bright-field image, grayscale (2-D) or BGR colour.
        original_vals: full-map summary; enables force scaling and the wider metrics table.
        cell_vals: optional cell-area summary shown in the side panel.
        cell_mask: optional cell mask; if it has any positive pixel it is outlined on the
            bright-field preview and reported as an extra first "Auto boundary" row.
        key_suffix: appended to every Streamlit widget key to keep keys unique.
        input_filename: used to name the downloaded files.
        colormap_name: colormap for the heatmap rendering.
    """
    if not HAS_DRAWABLE_CANVAS:
        st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
        return
    raw_heatmap = raw_heatmap if raw_heatmap is not None else display_heatmap
    h, w = display_heatmap.shape
    # Computed once: drives the bright-field outline, the "Auto boundary" row and its label.
    has_cell_mask = cell_mask is not None and bool(np.any(cell_mask > 0))
    heatmap_rgb = heatmap_to_rgb_with_contour(display_heatmap, colormap_name, cell_mask)
    pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)
    # st_canvas arguments shared by both layouts; only drawing_mode varies.
    canvas_kwargs = dict(
        fill_color="rgba(0, 188, 212, 0.25)", stroke_width=2, stroke_color="#00bcd4",
        background_image=pil_bg, update_streamlit=True,
        height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True,
        key=f"region_measure_canvas_{key_suffix}",
    )

    if bf_img is not None:
        bf_resized = cv2.resize(bf_img, (CANVAS_SIZE, CANVAS_SIZE))
        # Decide on the *resized* array: cv2.resize returns an (H, W, 1) input as 2-D, which
        # COLOR_BGR2RGB would reject.
        conversion = cv2.COLOR_GRAY2RGB if bf_resized.ndim == 2 else cv2.COLOR_BGR2RGB
        bf_rgb = cv2.cvtColor(bf_resized, conversion)
        left_col, right_col = st.columns(2, gap=None)
        with left_col:
            draw_mode = st.selectbox("Tool", DRAW_TOOLS, format_func=lambda x: TOOL_LABELS[x], key=f"draw_mode_region_{key_suffix}")
            st.caption("Left-click add, right-click close.  \nForce map (draw region)")
            canvas_result = st_canvas(drawing_mode=draw_mode, **canvas_kwargs)
        with right_col:
            vals = cell_vals if cell_vals else original_vals
            if vals:
                label = "Cell area" if cell_vals else "Full map"
                st.markdown(
                    f'<p class="s2f-measure-vals-heading">{html.escape(label)}</p>'
                    f'<div class="s2f-measure-vals-panel"><div class="s2f-measure-vals-grid">'
                    f"<span><strong>Sum:</strong> {vals['pixel_sum']:.1f}</span>"
                    f"<span><strong>Force:</strong> {vals['force']:.1f}</span>"
                    f"<span><strong>Max:</strong> {vals['max']:.3f}</span>"
                    f"<span><strong>Mean:</strong> {vals['mean']:.3f}</span>"
                    f"</div></div>",
                    unsafe_allow_html=True,
                )
            st.caption("Bright-field")
            bf_display = bf_rgb.copy()
            if has_cell_mask:
                bf_display = _draw_contour_on_image(bf_display, cell_mask, stroke_color=(255, 0, 0), stroke_width=3)
            st.image(bf_display, width=CANVAS_SIZE)
    else:
        st.markdown("**Draw a region** on the heatmap.")
        draw_mode = st.selectbox("Drawing tool", DRAW_TOOLS,
            format_func=lambda x: "Polygon (free shape)" if x == "polygon" else TOOL_LABELS[x],
            key=f"draw_mode_region_{key_suffix}")
        st.caption("Polygon: left-click to add points, right-click to close.")
        canvas_result = st_canvas(drawing_mode=draw_mode, **canvas_kwargs)

    if not canvas_result.json_data:
        return
    masks = parse_canvas_shapes_to_masks(canvas_result.json_data, CANVAS_SIZE, CANVAS_SIZE, h, w)
    if not masks:
        return
    metrics_list = [compute_region_metrics(raw_heatmap, m, original_vals) for m in masks]
    if has_cell_mask:
        # The cell-mask row is listed before the user-drawn regions.
        metrics_list.insert(0, compute_region_metrics(raw_heatmap, cell_mask, original_vals))
    render_region_metrics_and_downloads(
        metrics_list, masks, heatmap_rgb, input_filename, key_suffix, original_vals is not None,
        first_region_label="Auto boundary" if has_cell_mask else None,
        bf_img=bf_img, cell_mask=cell_mask, colormap_name=colormap_name,
    )