Spaces:
Running
Running
| """ | |
| DAB (3,3'-Diaminobenzidine) stain extraction via color deconvolution. | |
| Reference: Ruifrok & Johnston, "Quantification of histochemical | |
| staining by color deconvolution", Anal Quant Cytol Histol 2001 | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
class DABExtractor:
    """Extract DAB stain intensity from IHC images using color deconvolution.

    Uses the Ruifrok & Johnston H-DAB stain matrix, then applies a softplus
    so the map is non-negative and smooth, which keeps it usable as a
    differentiable training-loss target.
    """

    def __init__(self, device='cuda'):
        """Build the stain matrix and its pseudo-inverse.

        Args:
            device: Device on which the constant matrices are created. If a
                CUDA device is requested but CUDA is unavailable, falls back
                to "cpu" instead of failing; ``extract_dab_intensity`` moves
                the matrix to the input's device on every call anyway.
        """
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            device = 'cpu'
        self.device = device
        # Standard H-DAB stain matrix (Ruifrok & Johnston), shape [2, 3].
        # Each row is a stain vector in RGB optical-density space; the row
        # order fixes the order of the deconvolved concentration channels.
        self.stain_matrix = torch.tensor([
            [0.268, 0.570, 0.776],  # DAB (brown)
            [0.650, 0.704, 0.286],  # Hematoxylin (blue)
        ], device=device, dtype=torch.float32)
        # Per pixel, OD (column vector) = stain_matrix.T @ concentrations, so
        # the least-squares inverse is pinv(stain_matrix.T), shape [2, 3].
        self.deconv_matrix = torch.linalg.pinv(self.stain_matrix.T)

    def rgb_to_od(self, rgb_images: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Convert RGB [0,1] to optical density: OD = -log10(I/I0), I0 = 1.

        Intensities are clamped to [eps, 1] so black pixels give a finite OD
        (-log10(eps)) and white pixels give exactly 0.
        """
        return -torch.log10(rgb_images.clamp(eps, 1.0))

    def extract_dab_intensity(
        self,
        images: torch.Tensor,
        normalize: str = "max",
        input_range: str = "auto",
    ) -> torch.Tensor:
        """Extract DAB stain intensity from IHC images.

        Args:
            images: [B, 3, H, W] RGB images in [-1, 1] or [0, 1].
            normalize: "none", "max" (divide each map by its maximum) or
                "meanstd" (per-image z-score). ``True``/``False`` are accepted
                as aliases for "max"/"none".
            input_range: "[0,1]", "[-1,1]", or "auto" (default). "auto"
                treats the whole batch as [-1, 1] iff it contains a negative
                value, which misreads [-1, 1] images that have no value below
                mid-gray; pass the range explicitly when it is known.

        Returns:
            dab_intensity: [B, H, W] DAB intensity map, computed in at least
            float32 (half/bfloat16 inputs are upcast, float64 is kept).
            Before normalization a pixel with zero DAB concentration maps to
            softplus(0, beta=5) = ln(2)/5 ~= 0.139, not 0.

        Raises:
            ValueError: if ``images`` is not [B, 3, H, W], or ``normalize`` /
                ``input_range`` is not one of the accepted values.
        """
        if images.dim() != 4:
            raise ValueError(
                f"Expected images of shape [B, 3, H, W], got {tuple(images.shape)}"
            )
        if images.shape[1] != 3:
            raise ValueError("Input must be RGB images")

        # Validate options up front so a bad value fails before any compute.
        if normalize is True:
            normalize = "max"
        elif normalize is False:
            normalize = "none"
        if normalize not in ("none", "max", "meanstd"):
            raise ValueError(f"Unknown normalization: {normalize}")
        if input_range not in ("auto", "[-1,1]", "[0,1]"):
            raise ValueError(f"Unknown input_range: {input_range}")

        # The stain matrices are float32; a matmul against a float16/float64
        # input would raise on the dtype mismatch, so work in >= float32.
        work_dtype = torch.promote_types(images.dtype, torch.float32)
        images = images.to(work_dtype)

        if input_range == "auto":
            # numel() guard: .min() raises on an empty tensor.
            signed = images.numel() > 0 and bool(images.min() < 0)
        else:
            signed = input_range == "[-1,1]"
        if signed:
            images = (images + 1.0) / 2.0

        od = self.rgb_to_od(images)  # [B, 3, H, W]
        deconv_matrix = self.deconv_matrix.to(device=od.device, dtype=work_dtype)
        # Concentrations = OD @ deconv_matrix.T; only the DAB row (0) is needed.
        dab_conc = od.permute(0, 2, 3, 1) @ deconv_matrix[0]  # [B, H, W]

        # Softplus for smooth gradients (beta=5.0 for sharper transition); it
        # also maps the negative concentrations the least-squares projection
        # can produce smoothly toward 0.
        dab = F.softplus(dab_conc, beta=5.0)

        if normalize == "max":
            mx = dab.amax(dim=(1, 2), keepdim=True).clamp(min=1e-6)
            dab = dab / mx
        elif normalize == "meanstd":
            # Unbiased std: a single-pixel (H * W == 1) map yields NaN here.
            mean = dab.mean(dim=(1, 2), keepdim=True)
            std = dab.std(dim=(1, 2), keepdim=True).clamp(min=1e-6)
            dab = (dab - mean) / std
        return dab