| | import os |
| | import random |
| |
|
| | import torch |
| | from PIL import Image |
| | from torchstain.base.normalizers.he_normalizer import HENormalizer |
| | from torchstain.torch.utils import cov, percentile |
| | from torchvision import transforms |
| | from torchvision.transforms.functional import to_pil_image |
| |
|
| |
|
def preprocessor(pretrained=False, normalizer=None):
    """Build the resize / center-crop / (optional normalizer) / tensor pipeline.

    Args:
        pretrained: If True, normalize tensors with mean
            ``(0.485, 0.456, 0.406)`` and std ``(0.229, 0.224, 0.225)``
            (the usual ImageNet statistics); otherwise use mean and std of
            0.5 per channel, which maps ``[0, 1]`` to ``[-1, 1]``.
        normalizer: Optional callable inserted after ``CenterCrop(224)`` and
            before ``ToTensor`` (e.g. ``MacenkoNormalizer``). Its output must
            be something ``ToTensor`` accepts. When ``None`` the stage is
            omitted.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    if pretrained:
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    else:
        mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

    steps = [transforms.Resize(256), transforms.CenterCrop(224)]
    # Skip the stage instead of inserting an identity lambda: the output is
    # identical, and a Compose without lambdas stays picklable (needed e.g.
    # by DataLoader workers started with the "spawn" method).
    if normalizer is not None:
        steps.append(normalizer)
    steps += [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]

    return transforms.Compose(steps)
| |
|
| |
|
| | """ |
| | Source code ported from: https://github.com/schaugf/HEnorm_python |
| | Original implementation: https://github.com/mitkovetta/staining-normalization |
| | """ |
| |
|
| |
|
class TorchMacenkoNormalizer(HENormalizer):
    """Macenko H&E stain normalizer operating on torch tensors.

    ``HERef`` (one column per stain; column 0 is used as hematoxylin and
    column 1 as eosin by :meth:`normalize`) and ``maxCRef`` (one value per
    stain) start from built-in defaults and are replaced by :meth:`fit`.

    Reference:
        A method for normalizing histology slides for quantitative analysis.
        M. Macenko et al., ISBI 2009
    """

    def __init__(self):
        super().__init__()

        # Default reference stain matrix (channels x stains) and the default
        # per-stain maximum concentrations (fit() replaces these with the
        # 99th percentile of a target image's concentrations).
        self.HERef = torch.tensor(
            [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]]
        )
        self.maxCRef = torch.tensor([1.9705, 1.0308])

        # torch.lstsq is deprecated in favour of torch.linalg.lstsq; use the
        # new API whenever this torch build provides it.
        self.updated_lstsq = hasattr(torch.linalg, "lstsq")

    def __convert_rgb2od(self, I, Io, beta):
        """Convert a (C, H, W) image to optical density (OD).

        Returns:
            OD: one row per pixel (row-major over H, W), one column per
                channel.
            ODhat: the rows of OD in which no channel is below ``beta``
                (near-transparent pixels removed).
        """
        I = I.permute(1, 2, 0)

        # OD = -log((I + 1) / Io); the +1 keeps black pixels away from log(0).
        OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io)

        # Drop pixels whose OD falls below beta in any channel.
        ODhat = OD[~torch.any(OD < beta, dim=1)]

        return OD, ODhat

    def __find_HE(self, ODhat, eigvecs, alpha):
        """Estimate the stain matrix from the OD plane spanned by ``eigvecs``.

        ``alpha`` is the percentile used to pick robust extreme angles.
        """
        # Project onto the plane of the two principal eigenvectors and take
        # each pixel's angle within that plane.
        That = torch.matmul(ODhat, eigvecs)
        phi = torch.atan2(That[:, 1], That[:, 0])

        # Robust extremes of the angle distribution.
        minPhi = percentile(phi, alpha)
        maxPhi = percentile(phi, 100 - alpha)

        # Map the extreme angles back to OD space as column vectors.
        vMin = torch.matmul(
            eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))
        ).unsqueeze(1)
        vMax = torch.matmul(
            eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))
        ).unsqueeze(1)

        # Put the vector with the larger first-channel component first;
        # normalize() treats column 0 as hematoxylin and column 1 as eosin.
        HE = torch.where(
            vMin[0] > vMax[0],
            torch.cat((vMin, vMax), dim=1),
            torch.cat((vMax, vMin), dim=1),
        )

        return HE

    def __find_concentration(self, OD, HE):
        """Solve ``HE @ C = OD.T`` in the least-squares sense and return C."""
        # Rows are channels, columns are pixels.
        Y = OD.T

        if not self.updated_lstsq:
            # Legacy API: torch.lstsq(B, A) solves A X = B and returns a
            # padded solution, so only the first two rows are meaningful.
            return torch.lstsq(Y, HE)[0][:2]

        return torch.linalg.lstsq(HE, Y)[0]

    def __compute_matrices(self, I, Io, alpha, beta):
        """Return (stain matrix, per-pixel concentrations, 99th-pct maxima)."""
        OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

        # eigh returns eigenvalues in ascending order, so columns 1 and 2
        # hold the eigenvectors of the two largest eigenvalues.
        _, eigvecs = torch.linalg.eigh(cov(ODhat.T))
        eigvecs = eigvecs[:, [1, 2]]

        HE = self.__find_HE(ODhat, eigvecs, alpha)

        C = self.__find_concentration(OD, HE)
        maxC = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])

        return HE, C, maxC

    def fit(self, I, Io=240, alpha=1, beta=0.15):
        """Use image ``I`` (C, H, W) as the new normalization target."""
        HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)

        self.HERef = HE
        self.maxCRef = maxC

    def normalize(
        self, I, Io=240, alpha=1, beta=0.15, stains=True, form="chw", dtype="int"
    ):
        """Normalize the staining appearance of an H&E stained image.

        Input:
            I: RGB input image, tensor of shape [C, H, W] with intensities on
               a 0-255 scale (uint8 or float; converted to float internally)
            Io: (optional) transmitted light intensity
            alpha: percentile used to estimate the stain vectors
            beta: optical-density threshold below which pixels are treated
               as transparent
            stains: if true, also return the H & E components
            form, dtype: accepted for interface compatibility but currently
               ignored; the output layout and types are fixed (see Output)

        Output:
            Inorm: normalized image, float tensor of shape [C, H, W] in [0, 1]
            H: hematoxylin image, int32 tensor of shape [H, W, C], or None
            E: eosin image, int32 tensor of shape [H, W, C], or None

        Reference:
            A method for normalizing histology slides for quantitative
            analysis. M. Macenko et al., ISBI 2009
        """
        c, h, w = I.shape

        HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)

        # The reference may live on another device/dtype than the input
        # (e.g. fitted on a GPU or loaded from a checkpoint); align it with C.
        # This is a no-op when they already match.
        HERef = self.HERef.to(C)
        maxCRef = self.maxCRef.to(C)

        # Rescale each stain's concentrations to the reference maxima.
        C *= (maxCRef / maxC).unsqueeze(-1)

        # Recreate the image with the reference stain matrix, clamp to the
        # 8-bit range and rescale to [0, 1] (clamping first already bounds
        # the result to [0, 1]). Pixel columns are in row-major (H, W) order,
        # so a plain reshape yields CHW.
        Inorm = Io * torch.exp(-torch.matmul(HERef, C))
        Inorm = torch.clip(Inorm, 0, 255)
        Inorm = Inorm.reshape(c, h, w).float() / 255.0

        H, E = None, None

        if stains:
            # Single-stain reconstructions using only one reference column.
            H = torch.mul(
                Io,
                torch.exp(
                    torch.matmul(-HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0))
                ),
            )
            H[H > 255] = 255
            H = H.T.reshape(h, w, c).int()

            E = torch.mul(
                Io,
                torch.exp(
                    torch.matmul(-HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0))
                ),
            )
            E[E > 255] = 255
            E = E.T.reshape(h, w, c).int()

        return Inorm, H, E
| |
|
| |
|
class MacenkoNormalizer:
    """Image-in / PIL-image-out Macenko stain-normalization transform.

    Args:
        target_path: Normalization reference. An image file
            (``.jpg``/``.jpeg``/``.png``) is fitted on construction; a ``.pt``
            file must hold a dict with ``"HERef"`` and ``"maxCRef"`` entries.
            When ``None``, the default reference built into
            ``TorchMacenkoNormalizer`` is kept.
        prob: Probability of applying the normalization on each call. With
            the default of 1 it is always applied.

    Raises:
        ValueError: if ``target_path`` has an unsupported extension.
    """

    def __init__(self, target_path=None, prob=1):
        # Image -> float tensor (C, H, W) rescaled to a 0-255 range.
        self.transform_before_macenko = transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 255)]
        )
        self.normalizer = TorchMacenkoNormalizer()

        # Previously target_path=None crashed in os.path.splitext.
        if target_path is not None:
            self._load_target(target_path)
        self.prob = prob

    def _load_target(self, target_path):
        """Set the normalizer's reference from an image or a ``.pt`` file."""
        ext = os.path.splitext(target_path)[1].lower()
        if ext in (".jpg", ".jpeg", ".png"):
            # Close the file handle. Force 3 channels so RGBA / palette /
            # grayscale targets still yield a 3-channel stain matrix.
            with Image.open(target_path) as img:
                target = img.convert("RGB")
            self.normalizer.fit(self.transform_before_macenko(target))
        elif ext == ".pt":
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # trusted files (consider weights_only=True on recent torch).
            target = torch.load(target_path)
            self.normalizer.HERef = target["HERef"]
            self.normalizer.maxCRef = target["maxCRef"]
        else:
            raise ValueError(f"Invalid extension: {ext}")

    def __call__(self, image):
        """Return the stain-normalized PIL image, or ``image`` unchanged.

        ``image`` is returned as-is when the transform is skipped (see
        ``prob``), when normalization produces NaNs, or when it raises.
        Exceptions whose message mentions ``kthvalue()`` or ``linalg.eigh``
        are ignored silently; any other exception message is printed.
        """
        # Only draw a random number when prob < 1, so the default
        # (always-apply) path does not consume RNG state.
        if self.prob < 1 and random.random() >= self.prob:
            return image

        t_to_transform = self.transform_before_macenko(image)
        try:
            image_macenko, _, _ = self.normalizer.normalize(
                I=t_to_transform, stains=False, form="chw", dtype="float"
            )
            if torch.any(torch.isnan(image_macenko)):
                return image
            return to_pil_image(image_macenko)
        except Exception as e:
            # Best effort: fall back to the original image. kthvalue()/eigh
            # failures are treated as expected (presumably degenerate,
            # mostly-background inputs) and not reported.
            if "kthvalue()" not in str(e) and "linalg.eigh" not in str(e):
                print(str(e))
            return image
| |
|