File size: 3,849 Bytes
7b07ae4
 
 
 
00e4d88
7b07ae4
 
 
 
fa3a167
7b07ae4
 
 
fa3a167
 
 
7b07ae4
 
 
 
 
 
 
 
 
 
fa3a167
 
 
7b07ae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00e4d88
 
dee51a1
 
 
 
 
 
 
 
 
7b07ae4
fa3a167
 
7b07ae4
 
 
 
 
 
 
 
 
fa3a167
7b07ae4
fa3a167
 
 
 
 
 
7b07ae4
fa3a167
 
 
7b07ae4
 
 
00e4d88
7b07ae4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Cell segmentation from force map for background exclusion."""
import numpy as np
from scipy.ndimage import gaussian_filter
from skimage.filters import threshold_otsu
from skimage.morphology import closing, opening, dilation, remove_small_objects, disk
from skimage.measure import label, regionprops


def estimate_cell_mask(heatmap, sigma=2, min_size=200, exclude_full_image=True,
                       threshold_relax=0.85, dilate_radius=4, min_area_ratio=0.2):
    """
    Estimate cell region from force map using Otsu thresholding and morphological cleanup.

    Pipeline: sanitize and clip the input to [0, 1], Gaussian-smooth it, threshold at a
    relaxed Otsu level, close/open the binary mask, drop small objects, keep the
    significant connected components, and finally dilate the result.

    Supports multiple disconnected regions (e.g., two cells): components whose area is
    at least min_area_ratio of the largest are merged into the final mask.

    Args:
        heatmap: 2D float array [0, 1] - predicted force map. Values outside [0, 1]
            are clipped; NaN entries are treated as 0.
        sigma: Gaussian smoothing sigma to reduce noise. Default 2.
        min_size: Minimum object size in pixels; smaller objects removed. Default 200.
        exclude_full_image: If True, exclude the largest connected component when it
            covers most of the image (>70%) and at least one other component exists;
            the next largest then becomes the reference. Default True.
        threshold_relax: Multiply Otsu threshold by this (<1 = looser, include more pixels).
            Default 0.85.
        dilate_radius: Radius to dilate mask outward to include surrounding pixels.
            No dilation is applied when this is <= 0. Default 4.
        min_area_ratio: Include components with area >= this fraction of the largest
            (reference) component (0–1). E.g. 0.2 = include regions at least 20% the
            size of the largest. Handles multiple disconnected force regions. Default 0.2.

    Returns:
        mask: Binary uint8 array with the same shape as ``heatmap``, 1 = estimated cell,
            0 = background. An all-zero mask is returned when the input is empty, has
            no positive value, or no component survives the cleanup.
    """
    # Replace NaN with 0 before clipping so a few bad pixels cannot contaminate the
    # smoothing/threshold steps or slip past the "no signal" check below
    # (NaN <= 0 evaluates to False). Infinities end up clipped to 0 or 1.
    heatmap = np.nan_to_num(np.asarray(heatmap, dtype=np.float64))
    heatmap = np.clip(heatmap, 0, 1)

    # Nothing to segment: empty array (np.max would raise) or no positive force at all.
    if heatmap.size == 0 or np.max(heatmap) <= 0:
        return np.zeros_like(heatmap, dtype=np.uint8)

    # Smooth to reduce noise
    smoothed = gaussian_filter(heatmap, sigma=sigma)

    # Otsu automatic threshold, relaxed to include more pixels
    thresh = threshold_otsu(smoothed) * threshold_relax
    mask = (smoothed > thresh).astype(np.uint8)

    # Morphological cleanup: closing fills small gaps, opening removes thin specks.
    mask = closing(mask, disk(5)).astype(np.uint8)
    mask = opening(mask, disk(3)).astype(np.uint8)
    mask_bool = mask.astype(bool)
    min_size_int = max(int(min_size), 0)
    # skimage >=0.26 deprecates `min_size` in favor of `max_size` (inclusive threshold).
    # Use `min_size - 1` so behavior stays equivalent to the prior strict `< min_size` rule.
    try:
        mask_bool = remove_small_objects(mask_bool, max_size=max(min_size_int - 1, 0))
    except TypeError:
        # Older skimage without the `max_size` keyword.
        mask_bool = remove_small_objects(mask_bool, min_size=min_size_int)
    mask = mask_bool.astype(np.uint8)

    # Select component(s): optionally exclude full-image background, then merge
    # all significant components (handles multiple disconnected force regions)
    labeled = label(mask)
    props_sorted = sorted(regionprops(labeled), key=lambda p: p.area, reverse=True)

    if not props_sorted:
        return np.zeros_like(heatmap, dtype=np.uint8)

    total_px = heatmap.size

    # Skip largest if it covers most of image (likely background); only possible
    # when another component remains to take its place.
    if exclude_full_image and len(props_sorted) >= 2 and props_sorted[0].area > 0.7 * total_px:
        props_sorted = props_sorted[1:]

    # Reference area for "significant" components
    ref_area = props_sorted[0].area
    # Include all components with area >= min_area_ratio * ref_area
    labels_to_keep = [p.label for p in props_sorted if p.area >= min_area_ratio * ref_area]

    # Single vectorized membership test instead of one full-array pass per label.
    mask = np.isin(labeled, labels_to_keep).astype(np.uint8)

    # Dilate to include surrounding pixels
    if dilate_radius > 0:
        mask = dilation(mask, disk(dilate_radius)).astype(np.uint8)

    return mask