"""Measure tool: drawable canvas, region metrics, and downloads."""

import csv
import html
import io
import os

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from config.constants import CANVAS_SIZE, DRAW_TOOLS, TOOL_LABELS
from utils.report import heatmap_to_rgb_with_contour, create_measure_pdf_report
from ui.heatmaps import make_annotated_heatmap_multi_regions

try:
    from streamlit_drawable_canvas import st_canvas

    HAS_DRAWABLE_CANVAS = True
except (ImportError, AttributeError):
    HAS_DRAWABLE_CANVAS = False

# Number of vertices used to approximate circles/ellipses as polygons.
_ELLIPSE_VERTICES = 32

# Styling shared by both canvas layouts (cyan outline, translucent cyan fill).
_CANVAS_FILL = "rgba(0, 188, 212, 0.25)"
_CANVAS_STROKE = "#00bcd4"

# Inline style for the cells of the per-region metrics table.
_TABLE_CELL_STYLE = "padding: 2px 8px; text-align: left; border-bottom: 1px solid #ddd;"


def _ellipse_points(cx, cy, rx, ry, n=_ELLIPSE_VERTICES):
    """Return an (n, 2) float32 polygon approximating the axis-aligned ellipse at (cx, cy)."""
    angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
    return np.column_stack([cx + rx * np.cos(angles), cy + ry * np.sin(angles)]).astype(np.float32)


def _path_points(path):
    """Collect the end point of every M/L/Q/C command of a canvas path as [x, y] floats.

    Commands that are not lists/tuples, are too short, or use another verb (e.g. "z") are skipped.
    """
    pts = []
    for cmd in path:
        if not isinstance(cmd, (list, tuple)) or len(cmd) < 3:
            continue
        verb = cmd[0]
        if verb in ("M", "L"):
            pts.append([float(cmd[1]), float(cmd[2])])
        elif verb == "Q" and len(cmd) >= 5:
            # Quadratic curve: keep only its end point, not the control point.
            pts.append([float(cmd[3]), float(cmd[4])])
        elif verb == "C" and len(cmd) >= 7:
            # Cubic curve: keep only its end point, not the two control points.
            pts.append([float(cmd[5]), float(cmd[6])])
    return pts


def _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h):
    """Convert a single canvas object to polygon points in heatmap coords.

    Supported object types are "rect", "circle"/"ellipse" and "path"; anything else
    (e.g. a line) encloses no area and yields None.

    Args:
        obj: one entry of the canvas JSON "objects" list.
        scale_x, scale_y: factors mapping canvas x/y coordinates to heatmap coordinates.
        heatmap_w, heatmap_h: heatmap size; vertices are clipped to it.

    Returns:
        (N, 2) int32 array of (x, y) vertices, or None if the object is unsupported
        or degenerate.
    """
    obj_type = obj.get("type", "")

    if obj_type == "rect":
        left = obj.get("left", 0)
        top = obj.get("top", 0)
        w = obj.get("width", 0)
        h = obj.get("height", 0)
        if not w or not h:
            # A click without a drag yields a zero-size rect; it would otherwise be
            # rasterized into a spurious one-pixel region.
            return None
        pts = np.array(
            [[left, top], [left + w, top], [left + w, top + h], [left, top + h]],
            dtype=np.float32,
        )

    elif obj_type in ("circle", "ellipse"):
        left = obj.get("left", 0)
        top = obj.get("top", 0)
        radius = obj.get("radius", 0)
        if radius > 0:
            # The centre is taken one radius away from (left, top) along `angle`.
            # NOTE(review): this assumes the canvas anchors circles at the drag start and
            # rotates them toward the pointer -- confirm against the canvas library.
            rx = ry = radius
            angle_rad = np.deg2rad(obj.get("angle", 0))
            cx = left + radius * np.cos(angle_rad)
            cy = top + radius * np.sin(angle_rad)
        else:
            width = obj.get("width", 0)
            height = obj.get("height", 0)
            rx = width / 2 if width > 0 else 0
            ry = height / 2 if height > 0 else 0
            cx = left + rx
            cy = top + ry
        if rx <= 0 or ry <= 0:
            return None
        pts = _ellipse_points(cx, cy, rx, ry)

    elif obj_type == "path":
        path_pts = _path_points(obj.get("path") or [])
        if len(path_pts) < 3:
            # Fewer than three vertices cannot enclose an area.
            return None
        pts = np.array(path_pts, dtype=np.float32)

    else:
        return None

    pts[:, 0] *= scale_x
    pts[:, 1] *= scale_y
    return np.clip(pts, 0, [heatmap_w - 1, heatmap_h - 1]).astype(np.int32)


def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
    """Parse drawn shapes and return a list of individual masks (one per shape).

    Each mask is a (heatmap_h, heatmap_w) uint8 array with 1 inside the shape and 0
    elsewhere. Shapes that cannot be converted to a polygon are skipped.
    """
    objects = json_data.get("objects") if json_data else None
    if not objects:
        return []

    scale_x = heatmap_w / canvas_w
    scale_y = heatmap_h / canvas_h
    masks = []
    for obj in objects:
        pts = _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h)
        if pts is None:
            continue
        mask = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
        cv2.fillPoly(mask, [pts], 1)
        masks.append(mask)
    return masks


def build_original_vals(raw_heatmap, pixel_sum, force):
    """Build original_vals dict for measure tool (full map).

    Keys: "pixel_sum" and "force" (passed through), plus "max" and "mean" of raw_heatmap.
    """
    return {
        "pixel_sum": pixel_sum,
        "force": force,
        "max": float(np.max(raw_heatmap)),
        "mean": float(np.mean(raw_heatmap)),
    }


def _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force):
    """Compute metrics over estimated cell area only.

    Returns:
        (cell_pixel_sum, cell_force, cell_mean), or (None, None, None) when the mask is
        empty. The force is scaled by force / pixel_sum when pixel_sum > 0, otherwise it
        equals the raw pixel sum.
    """
    area_px = int(np.sum(cell_mask))
    if area_px == 0:
        return None, None, None
    cell_pixel_sum = float(np.sum(raw_heatmap * cell_mask))
    cell_force = cell_pixel_sum * (force / pixel_sum) if pixel_sum > 0 else cell_pixel_sum
    cell_mean = cell_pixel_sum / area_px
    return cell_pixel_sum, cell_force, cell_mean


def build_cell_vals(raw_heatmap, cell_mask, pixel_sum, force):
    """Build cell_vals dict for measure tool (estimated cell area).

    Returns a dict with the same keys as build_original_vals, or None if the mask is empty.
    """
    cell_pixel_sum, cell_force, cell_mean = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force)
    if cell_pixel_sum is None:
        return None
    inside = (raw_heatmap * cell_mask)[cell_mask > 0]
    cell_max = float(np.max(inside)) if inside.size > 0 else 0.0
    return {
        "pixel_sum": cell_pixel_sum,
        "force": cell_force,
        "max": cell_max,
        "mean": cell_mean,
    }


def compute_region_metrics(raw_heatmap, mask, original_vals=None):
    """Compute region metrics from mask.

    Returns a dict with:
        area_px: sum of the mask (pixel count for a 0/1 mask).
        force_sum: sum of raw_heatmap * mask.
        density: force_sum / area_px (0 for an empty mask).
        max, mean: max / mean of the masked values inside the mask (0 if empty).
        force_scaled: force_sum scaled by original_vals["force"] / original_vals["pixel_sum"]
            when that pixel_sum is positive, else force_sum unchanged.
    """
    area_px = int(np.sum(mask))
    region_values = raw_heatmap * mask
    inside = region_values[mask > 0]
    force_sum = float(np.sum(region_values))
    density = force_sum / area_px if area_px > 0 else 0
    region_max = float(np.max(inside)) if inside.size > 0 else 0
    region_mean = float(np.mean(inside)) if inside.size > 0 else 0
    if original_vals and original_vals.get("pixel_sum", 0) > 0:
        force_scaled = force_sum * (original_vals["force"] / original_vals["pixel_sum"])
    else:
        force_scaled = force_sum
    return {
        "area_px": area_px,
        "force_sum": force_sum,
        "density": density,
        "max": region_max,
        "mean": region_mean,
        "force_scaled": force_scaled,
    }


def _draw_contour_on_image(img_rgb, mask, stroke_color=(255, 0, 0), stroke_width=3):
    """Draw the external contours of mask onto img_rgb in place and return it.

    The mask is resized (nearest neighbour) to the image size if the shapes differ.
    """
    h, w = img_rgb.shape[:2]
    # findContours requires an 8-bit single-channel image, and cv2.resize rejects bool
    # arrays; binarize up front so bool/float/int64 masks work whether or not they need resizing.
    mask_u8 = (np.asarray(mask) > 0).astype(np.uint8)
    if mask_u8.shape[:2] != (h, w):
        mask_u8 = cv2.resize(mask_u8, (w, h), interpolation=cv2.INTER_NEAREST)
    contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        cv2.drawContours(img_rgb, contours, -1, stroke_color, stroke_width)
    return img_rgb


def _region_label(index, first_region_label):
    """Return the label for 1-based row index; drawn regions are numbered after the custom first row."""
    if first_region_label:
        return first_region_label if index == 1 else f"Region {index - 1}"
    return f"Region {index}"


def _regions_table_html(table_rows):
    """Render table_rows (header row first) as an HTML table with every cell escaped."""
    header, *body = table_rows
    th_cells = "".join(f'<th style="{_TABLE_CELL_STYLE}">{html.escape(str(h))}</th>' for h in header)
    rows_html = "".join(
        "<tr>" + "".join(f'<td style="{_TABLE_CELL_STYLE}">{html.escape(str(c))}</td>' for c in row) + "</tr>"
        for row in body
    )
    return (
        '<table style="border-collapse: collapse; margin-bottom: 0.5rem;">'
        f"<thead><tr>{th_cells}</tr></thead>"
        f"<tbody>{rows_html}</tbody>"
        "</table>"
    )


def render_region_metrics_and_downloads(metrics_list, masks, heatmap_rgb, input_filename, key_suffix,
                                        has_original_vals, first_region_label=None, bf_img=None,
                                        cell_mask=None, colormap_name="Jet"):
    """Render per-shape metrics table and download buttons.

    Args:
        metrics_list: dicts from compute_region_metrics, one per table row.
        masks: list of region masks (user-drawn only; used for labeled heatmap with R1, R2...).
        heatmap_rgb: RGB heatmap the region labels are drawn onto.
        input_filename: source file name; its stem prefixes every download file name.
        key_suffix: suffix making the widget keys unique.
        has_original_vals: when True, include scaled force and max columns.
        first_region_label: custom label for first row (e.g. 'Auto boundary').
        bf_img: bright-field image; the PDF report is offered only when it is provided.
        cell_mask, colormap_name: accepted for call compatibility; not used here.
    """
    # basename keeps any directory part of the name out of the download file names.
    base_name = os.path.splitext(os.path.basename(input_filename or ""))[0] or "image"
    st.markdown("**Regions (each selection = one row)**")

    if has_original_vals:
        headers = ["Region", "Area", "F.sum", "Force", "Max", "Mean"]
        csv_rows = [["image", "region"] + headers[1:]]
    else:
        headers = ["Region", "Area (px²)", "Force sum", "Mean"]
        csv_rows = [["image", "region", "Area", "Force sum", "Mean"]]

    table_rows = [headers]
    for i, metrics in enumerate(metrics_list, 1):
        region_label = _region_label(i, first_region_label)
        if has_original_vals:
            values = [
                f"{metrics['force_sum']:.3f}",
                f"{metrics['force_scaled']:.1f}",
                f"{metrics['max']:.3f}",
                f"{metrics['mean']:.4f}",
            ]
        else:
            values = [f"{metrics['force_sum']:.4f}", f"{metrics['mean']:.6f}"]
        table_rows.append([region_label, str(metrics["area_px"]), *values])
        csv_rows.append([base_name, region_label, metrics["area_px"], *values])

    # Render as HTML table to avoid Streamlit's default row/column indices
    st.markdown(_regions_table_html(table_rows), unsafe_allow_html=True)

    buf_csv = io.StringIO()
    csv.writer(buf_csv).writerows(csv_rows)

    # Annotated heatmap: each region separate with R1, R2 labels (no merging)
    region_labels = [f"R{i + 1}" for i in range(len(masks))]
    heatmap_labeled = make_annotated_heatmap_multi_regions(heatmap_rgb.copy(), masks, region_labels, cell_mask=None)
    buf_img = io.BytesIO()
    Image.fromarray(heatmap_labeled).save(buf_img, format="PNG")
    buf_img.seek(0)

    # PDF report (requires bf_img)
    pdf_bytes = None
    if bf_img is not None:
        pdf_bytes = create_measure_pdf_report(bf_img, heatmap_labeled, table_rows, base_name)

    dl_cols = st.columns(3 if pdf_bytes is not None else 2)
    with dl_cols[0]:
        st.download_button("Download all regions", data=buf_csv.getvalue(),
                           file_name=f"{base_name}_all_regions.csv", mime="text/csv",
                           key=f"download_all_regions_{key_suffix}", icon=":material/download:")
    with dl_cols[1]:
        st.download_button("Download heatmap", data=buf_img.getvalue(),
                           file_name=f"{base_name}_annotated_heatmap.png", mime="image/png",
                           key=f"download_annotated_{key_suffix}", icon=":material/image:")
    if pdf_bytes is not None:
        with dl_cols[2]:
            st.download_button("Download report", data=pdf_bytes,
                               file_name=f"{base_name}_measure_report.pdf", mime="application/pdf",
                               key=f"download_measure_pdf_{key_suffix}", icon=":material/picture_as_pdf:")


def _draw_region_canvas(background, draw_mode, key_suffix):
    """Show the drawable canvas over background with the given tool and return its result."""
    return st_canvas(
        fill_color=_CANVAS_FILL,
        stroke_width=2,
        stroke_color=_CANVAS_STROKE,
        background_image=background,
        drawing_mode=draw_mode,
        update_streamlit=True,
        height=CANVAS_SIZE,
        width=CANVAS_SIZE,
        display_toolbar=True,
        key=f"region_measure_canvas_{key_suffix}",
    )


def _bright_field_to_rgb(bf_img):
    """Resize a bright-field image to CANVAS_SIZE x CANVAS_SIZE and convert it to RGB.

    Two-dimensional results are treated as grayscale; others are treated as BGR(A).
    """
    resized = cv2.resize(bf_img, (CANVAS_SIZE, CANVAS_SIZE))
    # Test the resized array: cv2.resize drops a trailing singleton channel, so an
    # (H, W, 1) input comes back 2-D and must take the grayscale path.
    if resized.ndim == 2:
        return cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
    return cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)


def _render_vals_summary(label, vals):
    """Show the sum / force / max / mean summary for the full map or cell area."""
    st.markdown(
        '<div style="margin-bottom: 0.5rem;">'
        f"<div><strong>{html.escape(label)}</strong></div>"
        "<div>"
        f"<span>Sum: {vals['pixel_sum']:.1f}</span> &nbsp; "
        f"<span>Force: {vals['force']:.1f}</span> &nbsp; "
        f"<span>Max: {vals['max']:.3f}</span> &nbsp; "
        f"<span>Mean: {vals['mean']:.3f}</span>"
        "</div>"
        "</div>",
        unsafe_allow_html=True,
    )


def render_region_canvas(display_heatmap, raw_heatmap=None, bf_img=None, original_vals=None, cell_vals=None,
                         cell_mask=None, key_suffix="", input_filename=None, colormap_name="Jet"):
    """Render drawable canvas and region metrics.

    With a bright-field image, the canvas and the bright-field view are shown side by side.
    The bright-field column also shows the cell area summary (cell_vals) when available,
    else the full-map summary (original_vals). Without a bright-field image, only the
    canvas is shown. Metrics are computed on raw_heatmap, which defaults to display_heatmap.
    """
    if not HAS_DRAWABLE_CANVAS:
        st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
        return

    raw_heatmap = raw_heatmap if raw_heatmap is not None else display_heatmap
    h, w = display_heatmap.shape
    has_cell = cell_mask is not None and bool(np.any(cell_mask > 0))
    heatmap_rgb = heatmap_to_rgb_with_contour(display_heatmap, colormap_name, cell_mask)
    pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)

    if bf_img is not None:
        bf_rgb = _bright_field_to_rgb(bf_img)
        left_col, right_col = st.columns(2, gap=None)
        with left_col:
            draw_mode = st.selectbox("Tool", DRAW_TOOLS, format_func=lambda x: TOOL_LABELS[x],
                                     key=f"draw_mode_region_{key_suffix}")
            st.caption("Left-click add, right-click close. \nForce map (draw region)")
            canvas_result = _draw_region_canvas(pil_bg, draw_mode, key_suffix)
        with right_col:
            vals = cell_vals if cell_vals else original_vals
            if vals:
                _render_vals_summary("Cell area" if cell_vals else "Full map", vals)
            st.caption("Bright-field")
            bf_display = bf_rgb.copy()
            if has_cell:
                bf_display = _draw_contour_on_image(bf_display, cell_mask, stroke_color=(255, 0, 0), stroke_width=3)
            st.image(bf_display, width=CANVAS_SIZE)
    else:
        st.markdown("**Draw a region** on the heatmap.")
        draw_mode = st.selectbox(
            "Drawing tool",
            DRAW_TOOLS,
            format_func=lambda x: "Polygon (free shape)" if x == "polygon" else TOOL_LABELS[x],
            key=f"draw_mode_region_{key_suffix}",
        )
        st.caption("Polygon: left-click to add points, right-click to close.")
        canvas_result = _draw_region_canvas(pil_bg, draw_mode, key_suffix)

    if not canvas_result.json_data:
        return
    masks = parse_canvas_shapes_to_masks(canvas_result.json_data, CANVAS_SIZE, CANVAS_SIZE, h, w)
    if not masks:
        return

    metrics_list = [compute_region_metrics(raw_heatmap, m, original_vals) for m in masks]
    if has_cell:
        # The auto-detected cell boundary is reported as the first row, before the drawn regions.
        metrics_list.insert(0, compute_region_metrics(raw_heatmap, cell_mask, original_vals))
    render_region_metrics_and_downloads(
        metrics_list, masks, heatmap_rgb, input_filename, key_suffix, original_vals is not None,
        first_region_label="Auto boundary" if has_cell else None,
        bf_img=bf_img,
        cell_mask=cell_mask,
        colormap_name=colormap_name,
    )