UNIStainNet / src /utils /dab.py
faceless-void's picture
Upload folder using huggingface_hub
4db9215 verified
"""
DAB (3,3'-Diaminobenzidine) stain extraction via color deconvolution.
Reference: Ruifrok & Johnston, "Quantification of histochemical
staining by color deconvolution", Anal Quant Cytol Histol 2001
"""
import torch
import torch.nn.functional as F
class DABExtractor:
"""Extract DAB stain intensity from IHC images using color deconvolution.
Uses the Ruifrok & Johnston H-DAB stain matrix with softplus smoothing
for differentiable training loss computation.
"""
def __init__(self, device='cuda'):
self.device = device
# Standard H-DAB stain matrix (Ruifrok & Johnston)
# Each row is a stain vector in RGB optical density space
self.stain_matrix = torch.tensor([
[0.268, 0.570, 0.776], # DAB (brown)
[0.650, 0.704, 0.286], # Hematoxylin (blue)
], device=device, dtype=torch.float32)
# Pseudo-inverse for deconvolution: [3, 2]
self.deconv_matrix = torch.linalg.pinv(self.stain_matrix.T)
def rgb_to_od(self, rgb_images: torch.Tensor) -> torch.Tensor:
"""Convert RGB [0,1] to optical density: OD = -log10(I/I0)."""
rgb_images = rgb_images.clamp(1e-6, 1.0)
return -torch.log10(rgb_images + 1e-6)
def extract_dab_intensity(
self,
images: torch.Tensor,
normalize: str = "max"
) -> torch.Tensor:
"""Extract DAB stain intensity from IHC images.
Args:
images: [B, 3, H, W] RGB images in [-1, 1] or [0, 1]
normalize: "none", "max", or "meanstd"
Returns:
dab_intensity: [B, H, W] DAB intensity map
"""
B, C, H, W = images.shape
assert C == 3, "Input must be RGB images"
# Auto-convert [-1, 1] -> [0, 1] if needed
if images.min() < 0:
images = (images + 1.0) / 2.0
od = self.rgb_to_od(images)
od_flat = od.permute(0, 2, 3, 1).reshape(-1, 3)
# Ensure deconv_matrix is on same device as input
deconv_matrix = self.deconv_matrix.to(od_flat.device)
# Deconvolve: concentrations = OD @ M_inv^T
concentrations = od_flat @ deconv_matrix.T
dab_flat = concentrations[:, 0] # DAB channel
dab_intensity = dab_flat.reshape(B, H, W)
# Softplus for smooth gradients (beta=5.0 for sharper transition)
dab = F.softplus(dab_intensity, beta=5.0)
if normalize == "max" or normalize is True:
mx = dab.amax(dim=(1, 2), keepdim=True).clamp(min=1e-6)
dab = dab / mx
elif normalize == "meanstd":
mean = dab.mean(dim=(1, 2), keepdim=True)
std = dab.std(dim=(1, 2), keepdim=True).clamp(min=1e-6)
dab = (dab - mean) / std
elif normalize == "none" or normalize is False:
pass
else:
raise ValueError(f"Unknown normalization: {normalize}")
return dab