File size: 2,910 Bytes
4db9215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
DAB (3,3'-Diaminobenzidine) stain extraction via color deconvolution.

Reference: Ruifrok & Johnston, "Quantification of histochemical
staining by color deconvolution", Anal Quant Cytol Histol 2001
"""

import torch
import torch.nn.functional as F


class DABExtractor:
    """Extract DAB stain intensity from IHC images using color deconvolution.

    Uses the Ruifrok & Johnston H-DAB stain matrix with softplus smoothing
    for differentiable training loss computation.

    Args:
        device: Device on which the stain and deconvolution matrices are
            initially created. ``extract_dab_intensity`` moves what it needs
            to the input's device and dtype on every call.
    """

    def __init__(self, device='cuda'):
        self.device = device

        # Standard H-DAB stain matrix (Ruifrok & Johnston): [2, 3].
        # Each row is a stain vector in RGB optical density space.
        self.stain_matrix = torch.tensor([
            [0.268, 0.570, 0.776],  # DAB (brown)
            [0.650, 0.704, 0.286],  # Hematoxylin (blue)
        ], device=device, dtype=torch.float32)

        # Least-squares unmixing matrix. stain_matrix.T is [3, 2] (stains as
        # columns), so its pseudo-inverse is [2, 3]: row k maps an RGB OD
        # vector to the concentration of stain k (row 0 = DAB).
        self.deconv_matrix = torch.linalg.pinv(self.stain_matrix.T)

    def rgb_to_od(self, rgb_images: torch.Tensor) -> torch.Tensor:
        """Convert RGB in [0, 1] to optical density: OD = -log10(I/I0).

        I0 is taken as 1 (white). Values are clamped to [1e-6, 1] and a
        further 1e-6 is added before the log, so OD is bounded above by
        -log10(2e-6) (~5.7) and is marginally negative (~-4.3e-7) for pure
        white pixels.
        """
        rgb_images = rgb_images.clamp(1e-6, 1.0)
        return -torch.log10(rgb_images + 1e-6)

    def extract_dab_intensity(
        self,
        images: torch.Tensor,
        normalize: str = "max",
        input_range: str = "auto",
    ) -> torch.Tensor:
        """Extract DAB stain intensity from IHC images.

        Args:
            images: [B, 3, H, W] RGB images in [-1, 1] or [0, 1].
            normalize: "none", "max", or "meanstd". ``True`` is accepted as
                an alias for "max" and ``False`` for "none".
                - "max": divide each image by its own maximum.
                - "meanstd": per-image standardization (unbiased std).
            input_range: How to interpret the pixel range of ``images``.
                - "auto" (default, original behavior): treat the whole batch
                  as [-1, 1] if any value is negative, else as [0, 1]. A
                  [-1, 1] batch with no negative pixel (e.g. uniformly bright
                  tissue) is misread as [0, 1]; prefer an explicit value
                  when the range is known.
                - "unit": images are already in [0, 1].
                - "symmetric": images are in [-1, 1].

        Returns:
            dab_intensity: [B, H, W] DAB intensity map, in the dtype and on
            the device of ``images``. Because of the softplus, stain-free
            pixels map to ln(2)/5 (~0.139) rather than 0 before
            normalization.

        Raises:
            ValueError: if ``images`` is not [B, 3, H, W], or if
                ``normalize`` / ``input_range`` is not a recognized option.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if images.dim() != 4 or images.shape[1] != 3:
            raise ValueError(
                "Input must be RGB images of shape [B, 3, H, W], "
                f"got {tuple(images.shape)}"
            )

        if input_range == "auto":
            # Batch-level heuristic; `images.min()` also forces a host sync.
            symmetric = bool(images.min() < 0)
        elif input_range == "symmetric":
            symmetric = True
        elif input_range == "unit":
            symmetric = False
        else:
            raise ValueError(f"Unknown input range: {input_range}")

        if symmetric:
            images = (images + 1.0) / 2.0

        od = self.rgb_to_od(images)

        # Only the DAB row of the unmixing matrix is needed. Match both the
        # device AND the dtype of the input: matmul rejects mixed dtypes
        # (e.g. float64 or float16 images against a float32 matrix).
        dab_vector = self.deconv_matrix[0].to(device=od.device, dtype=od.dtype)

        # [B, 3, H, W] -> [B, H, W, 3] @ [3] -> [B, H, W]
        dab_intensity = od.permute(0, 2, 3, 1) @ dab_vector

        # Softplus for smooth gradients (beta=5.0 for sharper transition)
        dab = F.softplus(dab_intensity, beta=5.0)

        if normalize == "max" or normalize is True:
            mx = dab.amax(dim=(1, 2), keepdim=True).clamp(min=1e-6)
            dab = dab / mx
        elif normalize == "meanstd":
            mean = dab.mean(dim=(1, 2), keepdim=True)
            std = dab.std(dim=(1, 2), keepdim=True).clamp(min=1e-6)
            dab = (dab - mean) / std
        elif normalize == "none" or normalize is False:
            pass
        else:
            raise ValueError(f"Unknown normalization: {normalize}")

        return dab