| | |
| |
|
| | import os |
| | from typing import Any, Callable, Sequence |
| |
|
| | import monai |
| | import monai.transforms as mt |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from monai.data.meta_obj import get_track_meta |
| | from monai.networks.blocks import ConvDenseBlock, Convolution |
| | from monai.networks.layers import Flatten, Reshape |
| | from monai.networks.nets import Regressor |
| | from monai.networks.utils import meshgrid_ij |
| | from monai.utils import CommonKeys |
| | from monai.utils import ImageMetaKey as Key |
| | from monai.utils import convert_to_numpy, convert_to_tensor |
| |
|
| | |
# Maps the pixel value that marks a landmark in a landmark image to the column index that landmark occupies
# in the coordinate array built by `convert_lm_image_t`. Values are listed in column order.
LM_INDICES = {
    pixel_value: column for column, pixel_value in enumerate((10, 15, 20, 25, 30, 35, 100, 150, 200, 250))
}
| |
|
# Callable pulling the "pred" and "label" entries out of the data it is given.
output_trans = monai.handlers.from_engine(["pred", "label"])


def _output_lm_trans(data):
    """Extract the predictions and labels from `data` and swap the two dimensions of every item."""
    predictions, labels = output_trans(data)

    swapped_predictions = [item.permute(1, 0) for item in predictions]
    swapped_labels = [item.permute(1, 0) for item in labels]

    return swapped_predictions, swapped_labels
| |
|
| |
|
def convert_lm_image_t(lm_image):
    """
    Convert a landmark image into a (2,N) tensor of landmark coordinates.

    Each non-zero pixel of `lm_image` is treated as a landmark: its value is looked up in `LM_INDICES` to
    select the output column, row 0 of that column receives the pixel's y index and row 1 its x index.
    Columns of landmarks that do not appear in the image are left as zero. The result is float32 and lives
    on the same device as `lm_image`.
    """
    coords = torch.zeros((2, len(LM_INDICES)), dtype=torch.float32, device=lm_image.device)
    nonzero_positions = np.argwhere(lm_image.cpu().numpy() != 0)

    # positions are (channel, y, x) triples; the value is always read from channel 0
    for _channel, row, col in nonzero_positions:
        column = LM_INDICES[int(lm_image[0, row, col])]
        coords[0, column] = row
        coords[1, column] = col

    return coords
| |
|
| |
|
class ParallelCat(nn.Module):
    """
    Feed one input through every given module and join the outputs along a single dimension.

    Args:
        catmodules: sequence of nn.Module objects, each of which receives the same input
        cat_dim: dimension along which the module outputs are concatenated
    """

    def __init__(self, catmodules: Sequence[nn.Module], cat_dim: int = 1):
        super().__init__()
        self.cat_dim = cat_dim

        # register every branch as a named child so its parameters belong to this module
        for index, branch in enumerate(catmodules):
            self.add_module(f"catmodule_{index}", branch)

    def forward(self, x):
        # children are visited in registration order, so outputs follow the order of `catmodules`
        return torch.cat([branch(x) for branch in self.children()], dim=self.cat_dim)
| |
|
| |
|
class PointRegressor(Regressor):
    """Regressor defined as a sequence of dense blocks followed by convolution/linear layers for each landmark."""

    def _get_layer(self, in_channels, out_channels, strides, is_last):
        """
        Build one encoder layer: a dilated dense block followed by a strided convolution.

        Args:
            in_channels: number of channels entering the layer
            out_channels: number of channels leaving the layer
            strides: stride of the convolution following the dense block
            is_last: passed as `conv_only` to the convolution, leaving it without norm/activation when True

        Returns:
            nn.Sequential of the dense block and the convolution.
        """
        # the dense block contributes the channels missing between input and output, split over three
        # convolutions with dilations 1, 2 and 4; the remainder of the division goes to the last one
        dout = out_channels - in_channels
        dilations = [1, 2, 4]
        dchannels = [dout // 3, dout // 3, dout // 3 + dout % 3]

        db = ConvDenseBlock(
            spatial_dims=self.dimensions,
            in_channels=in_channels,
            channels=dchannels,
            dilations=dilations,
            kernel_size=self.kernel_size,
            num_res_units=self.num_res_units,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            bias=self.bias,
        )

        # the convolution expects `out_channels` inputs, i.e. the dense block output is taken to be its
        # input channels plus the channels it adds
        conv = Convolution(
            spatial_dims=self.dimensions,
            in_channels=out_channels,
            out_channels=out_channels,
            strides=strides,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            bias=self.bias,
            conv_only=is_last,
        )

        return nn.Sequential(db, conv)

    def _get_final_layer(self, in_shape):
        """
        Build the output head: one independent conv -> flatten -> linear path per landmark.

        Args:
            in_shape: shape (channels first, without batch) of the features entering the head

        Returns:
            nn.Sequential applying all paths in parallel, concatenating their outputs and reshaping the
            result to `self.out_shape`.
        """
        point_paths = []

        # The stride-2 convolution doubles the channels, and the linear layer is sized for half the
        # number of input elements.
        # NOTE(review): that size matches two spatial dimensions of even size -- confirm for other inputs.
        # `np.prod` is used because the `np.product` alias was removed in NumPy 2.0.
        linear_in_features = int(np.prod(in_shape)) // 2

        for _ in range(self.out_shape[1]):
            conv = Convolution(
                spatial_dims=self.dimensions,
                in_channels=in_shape[0],
                out_channels=in_shape[0] * 2,
                strides=2,
                kernel_size=self.kernel_size,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
                conv_only=True,
            )
            linear = nn.Linear(linear_in_features, self.out_shape[0])
            point_paths.append(nn.Sequential(conv, Flatten(), linear))

        return torch.nn.Sequential(ParallelCat(point_paths), Reshape(*self.out_shape))
| |
|
| |
|
class LandmarkInferer(monai.inferers.Inferer):
    """
    Applies inference on 2D slices from 3D volumes.

    Args:
        spatial_dim: index of the spatial dimension to slice along, counted after the batch and channel dims
        stack_dim: dimension along which the per-slice network outputs are stacked
    """

    def __init__(self, spatial_dim=0, stack_dim=-1):
        self.spatial_dim = spatial_dim
        self.stack_dim = stack_dim

    def __call__(self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any):
        """
        Run `network` on every 2D slice of `inputs` and stack the results.

        Args:
            inputs: 5D input volume
            network: callable applied to each slice, which has the sliced dimension removed
            args: extra positional arguments forwarded to `network`
            kwargs: extra keyword arguments forwarded to `network`

        Returns:
            The per-slice outputs stacked along `stack_dim`.

        Raises:
            ValueError: if `inputs` is not 5-dimensional.
        """
        if inputs.ndim != 5:
            raise ValueError(f"Input volume to inferer must have shape BCDHW, input shape is {inputs.shape}")

        slice_axis = self.spatial_dim + 2  # skip the batch and channel dimensions
        input_slices = [slice(None)] * inputs.ndim
        results = []

        for idx in range(inputs.shape[slice_axis]):
            input_slices[slice_axis] = idx
            # index with a tuple: using a list for multidimensional indexing is deprecated in PyTorch
            input_2d = inputs[tuple(input_slices)]
            results.append(network(input_2d, *args, **kwargs))

        return torch.stack(results, self.stack_dim)
| |
|
| |
|
class NpySaverd(mt.MapTransform):
    """
    Saves tensors/arrays to Numpy npy files.

    Output names are produced by a `monai.data.FolderLayout` from the source filename stored in each item's
    metadata, with the dictionary key passed on as the `key` naming argument.

    Args:
        keys: keys of the items to save
        output_dir: directory the files are written into
        data_root_dir: passed to the folder layout as `data_root_dir`
    """

    def __init__(self, keys, output_dir, data_root_dir):
        super().__init__(keys)
        self.output_dir = output_dir
        self.data_root_dir = data_root_dir
        self.folder_layout = monai.data.FolderLayout(
            self.output_dir, extension=".npy", data_root_dir=self.data_root_dir
        )

    def __call__(self, d):
        """Save every selected item of `d` to a npy file and return `d` unchanged."""
        # exist_ok makes a separate existence check unnecessary (and avoids the race it implies)
        os.makedirs(self.output_dir, exist_ok=True)

        for key in self.key_iterator(d):
            orig_filename = d[key].meta[Key.FILENAME_OR_OBJ]
            if isinstance(orig_filename, (list, tuple)):
                orig_filename = orig_filename[0]

            out_filename = self.folder_layout.filename(orig_filename, key=key)

            # the layout may place the file in a sub-folder of `output_dir` (presumably mirroring the
            # source path relative to `data_root_dir`), which must exist before saving
            out_folder = os.path.dirname(out_filename)
            if out_folder:
                os.makedirs(out_folder, exist_ok=True)

            np.save(out_filename, convert_to_numpy(d[key]))

        return d
| |
|
| |
|
class FourierDropout(mt.Transform, mt.Fourier):
    """
    Corrupt images by dropping out values in Fourier space. A value is zeroed with a probability that grows
    with its distance from the centre: everything nearer than `min_dist` to the centre is kept, everything
    farther than `max_dist` is set to 0. The distance from the centre to an edge along one dimension is 1.0.

    Args:
        min_dist: minimum distance to apply dropout, must be >0, smaller values will cause greater corruption
        max_dist: maximal distance to apply dropout, must be greater than `min_dist`, all pixels beyond become 0
    """

    def __init__(self, min_dist: float = 0.1, max_dist: float = 0.9):
        super().__init__()
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.prob_field = None
        self.field_shape = None

    def _get_prob_field(self, shape):
        """Return the field of distances from the centre for `shape`, reusing the cached one if it matches."""
        shape = tuple(shape)
        if shape == self.field_shape:
            return self.prob_field

        self.field_shape = shape

        # one coordinate axis in [-1, 1] per dimension after the first
        axes = [torch.linspace(-1, 1, size) for size in shape[1:]]
        coords = torch.stack(meshgrid_ij(*axes))

        # Euclidean distance of every position from the centre
        self.prob_field = coords.pow(2).sum(dim=0).sqrt()
        return self.prob_field

    def __call__(self, im):
        distances = self._get_prob_field(im.shape).to(im.device)

        # random thresholds drawn uniformly from [min_dist, max_dist)
        span = self.max_dist - self.min_dist
        thresholds = torch.rand_like(im).mul_(span).add_(self.min_dist)

        # 1 where the threshold reaches the distance (value kept), 0 elsewhere (value dropped)
        keep_mask = thresholds.ge_(distances)

        spatial_dims = im.ndim - 1
        spectrum = self.shift_fourier(im, spatial_dims)
        spectrum.mul_(keep_mask)
        restored = self.inv_shift_fourier(spectrum, spatial_dims)

        return convert_to_tensor(restored, track_meta=get_track_meta())
| |
|
| |
|
class RandFourierDropout(mt.RandomizableTransform):
    """
    Randomly apply `FourierDropout` to an image.

    Args:
        min_dist: `min_dist` argument of the wrapped `FourierDropout`
        max_dist: `max_dist` argument of the wrapped `FourierDropout`
        prob: probability passed to `RandomizableTransform`
    """

    def __init__(self, min_dist=0.1, max_dist=0.9, prob=0.1):
        mt.RandomizableTransform.__init__(self, prob)
        self.dropper = FourierDropout(min_dist, max_dist)

    def __call__(self, im, randomize: bool = True):
        if randomize:
            self.randomize(None)

        # when the transform is not selected the input is only converted to a tensor
        if not self._do_transform:
            return convert_to_tensor(im, track_meta=get_track_meta())

        return self.dropper(im)
| |
|
| |
|
class RandFourierDropoutd(mt.RandomizableTransform, mt.MapTransform):
    """
    Dictionary version of the random Fourier dropout: one random decision is made per call and applied to
    every selected key.

    Args:
        keys: keys of the items to transform
        min_dist: `min_dist` argument of the wrapped `FourierDropout`
        max_dist: `max_dist` argument of the wrapped `FourierDropout`
        prob: probability passed to `RandomizableTransform`
    """

    def __init__(self, keys, min_dist=0.1, max_dist=0.9, prob=0.1):
        mt.RandomizableTransform.__init__(self, prob)
        mt.MapTransform.__init__(self, keys)
        self.dropper = FourierDropout(min_dist, max_dist)

    def __call__(self, data, randomize: bool = True):
        result = dict(data)

        if randomize:
            self.randomize(None)

        for key in self.key_iterator(result):
            item = result[key]
            # items are converted to tensors even when the dropout itself is skipped
            result[key] = (
                self.dropper(item) if self._do_transform else convert_to_tensor(item, track_meta=get_track_meta())
            )

        return result
| |
|
| |
|
class RandImageLMDeformd(mt.RandSmoothDeform):
    """Apply smooth random deformation to the image and landmark locations."""

    def __call__(self, d):
        """
        Deform the item under `CommonKeys.IMAGE` and move the landmarks under `CommonKeys.LABEL` to match.

        Args:
            d: dictionary holding the image and the landmark label image

        Returns:
            A shallow copy of `d` with the deformed image and, if the deformation was applied, a new label
            image with every landmark moved to its displaced position.
        """
        d = dict(d)
        old_label = d[CommonKeys.LABEL]

        d[CommonKeys.IMAGE] = super().__call__(d[CommonKeys.IMAGE])

        if self._do_transform:
            field = self.sfield()
            new_label = torch.zeros_like(old_label)
            height, width = new_label.shape[1], new_label.shape[2]
            labels = np.argwhere(old_label[0].cpu().numpy() > 0)

            # landmarks are moved one by one instead of resampling the label image, presumably so that
            # single landmark pixels are not lost to interpolation
            for y, x in labels:
                dy = int(field[0, y, x] * height / 2)
                dx = int(field[1, y, x] * width / 2)

                # clamp the target to the image: a negative index would silently wrap around to the
                # opposite edge and an index past the end would raise IndexError
                new_y = min(max(int(y) - dy, 0), height - 1)
                new_x = min(max(int(x) - dx, 0), width - 1)

                new_label[:, new_y, new_x] = old_label[:, y, x]

            d[CommonKeys.LABEL] = new_label

        return d
| |
|
| |
|
class RandLMShiftd(mt.RandomizableTransform, mt.MapTransform):
    """
    Randomly shift the image and landmark image in either direction in integer amounts.

    Each selected item is padded by `max_shift`, rolled by the random shift and centre-cropped back to
    `spatial_size`. The same shift is applied to every key.

    Args:
        keys: keys of the items to shift
        spatial_size: spatial size the shifted items are cropped back to
        max_shift: largest shift magnitude along any dimension, the shift is drawn from [-max_shift, max_shift]
        prob: probability of applying the shift
    """

    def __init__(self, keys, spatial_size, max_shift=0, prob=0.1):
        mt.RandomizableTransform.__init__(self, prob=prob)
        mt.MapTransform.__init__(self, keys=keys)

        self.spatial_size = tuple(spatial_size)
        self.max_shift = max_shift
        self.padder = mt.BorderPad(self.max_shift)
        self.unpadder = mt.CenterSpatialCrop(self.spatial_size)
        self.shift = (0,) * len(self.spatial_size)
        # dimension 0 is skipped, only the following `len(spatial_size)` dimensions are rolled
        self.roll_dims = list(range(1, len(self.spatial_size) + 1))

    def randomize(self, data):
        """Decide whether to apply the transform and, if so, draw one integer shift per spatial dimension."""
        super().randomize(None)
        if self._do_transform:
            # the upper bound of torch.randint is exclusive, so `max_shift + 1` is needed to make the range
            # symmetric and to keep the default `max_shift=0` from raising on an empty range
            rs = torch.randint(
                -self.max_shift, self.max_shift + 1, (len(self.spatial_size),), dtype=torch.int32
            )
            self.shift = tuple(rs.tolist())

    def __call__(self, d, randomize: bool = True):
        """Return a shallow copy of `d` with every selected item shifted if the transform is applied."""
        d = dict(d)

        if randomize:
            self.randomize(None)

        if self._do_transform:
            for key in self.key_iterator(d):
                imp = self.padder(d[key])
                ims = torch.roll(imp, self.shift, self.roll_dims)
                d[key] = self.unpadder(ims)

        return d
| |
|