Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| from typing import cast | |
| import torch | |
| import torch.nn as nn | |
| from monai.networks.nets import resnet | |
| from monai.utils.module import optional_import | |
| models, _ = optional_import("torchvision.models") | |
| class MILModel3D(nn.Module): | |
| """ | |
| Multiple Instance Learning (MIL) model, with a backbone classification model. | |
| Adapted from MONAI, modified for 3D images. The expected shape of input data is `[B, N, C, D, H, W]`, | |
| where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances | |
| extracted from every original image in the batch. | |
| Args: | |
| num_classes: number of output classes. | |
| mil_mode: MIL algorithm, available values (Defaults to ``"att"``): | |
| - ``"mean"`` - average features from all instances, equivalent to pure CNN (non MIL). | |
| - ``"max"`` - retain only the instance with the max probability for loss calculation. | |
| - ``"att"`` - attention based MIL https://arxiv.org/abs/1802.04712. | |
| - ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556. | |
| - ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556. | |
| backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features, | |
| or a string name of a torchvision model). | |
| Defaults to ``None``, in which case ResNet18 is used. | |
| backbone_num_features: Number of output features of the backbone CNN | |
| Defaults to ``None`` (necessary only when using a custom backbone) | |
| trans_blocks: number of the blocks in `TransformEncoder` layer. | |
| trans_dropout: dropout rate in `TransformEncoder` layer. | |
| """ | |
| def __init__( | |
| self, | |
| num_classes: int, | |
| mil_mode: str = "att", | |
| pretrained: bool = True, | |
| backbone: str | nn.Module | None = None, | |
| backbone_num_features: int | None = None, | |
| trans_blocks: int = 4, | |
| trans_dropout: float = 0.0, | |
| ) -> None: | |
| super().__init__() | |
| if num_classes <= 0: | |
| raise ValueError("Number of classes must be positive: " + str(num_classes)) | |
| if mil_mode.lower() not in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]: | |
| raise ValueError("Unsupported mil_mode: " + str(mil_mode)) | |
| self.mil_mode = mil_mode.lower() | |
| self.attention = nn.Sequential() | |
| self.transformer: nn.Module | None = None | |
| net: nn.Module | |
| if backbone is None: | |
| net = resnet.resnet18( | |
| spatial_dims=3, | |
| n_input_channels=3, | |
| num_classes=5, | |
| ) | |
| assert net.fc is not None | |
| nfc = net.fc.in_features # save the number of final features | |
| net.fc = torch.nn.Identity() # type: ignore[assignment] | |
| self.extra_outputs: dict[str, torch.Tensor] = {} | |
| if mil_mode == "att_trans_pyramid": | |
| # register hooks to capture outputs of intermediate layers | |
| def forward_hook(layer_name): | |
| def hook(module, input, output): | |
| self.extra_outputs[layer_name] = output | |
| return hook | |
| net.layer1.register_forward_hook(forward_hook("layer1")) | |
| net.layer2.register_forward_hook(forward_hook("layer2")) | |
| net.layer3.register_forward_hook(forward_hook("layer3")) | |
| net.layer4.register_forward_hook(forward_hook("layer4")) | |
| elif isinstance(backbone, str): | |
| # assume torchvision model string is provided | |
| torch_model = getattr(models, backbone, None) | |
| if torch_model is None: | |
| raise ValueError("Unknown torch vision model" + str(backbone)) | |
| net = torch_model(pretrained=pretrained) | |
| if getattr(net, "fc", None) is not None: | |
| nfc = net.fc.in_features # save the number of final features | |
| net.fc = torch.nn.Identity() # type: ignore[assignment] | |
| else: | |
| raise ValueError( | |
| "Unable to detect FC layer for the torchvision model " + str(backbone), | |
| ". Please initialize the backbone model manually.", | |
| ) | |
| elif isinstance(backbone, nn.Module): | |
| # use a custom backbone | |
| net = backbone | |
| if backbone_num_features is None: | |
| raise ValueError( | |
| "Number of endencoder features must be provided for a custom backbone model" | |
| ) | |
| nfc = backbone_num_features | |
| net.fc = torch.nn.Identity() # type: ignore[assignment] | |
| if mil_mode == "att_trans_pyramid": | |
| # register hooks to capture outputs of intermediate layers | |
| raise ValueError( | |
| "Cannot use att_trans_pyramid with custom backbone. Have to use the default ResNet 18 backbone." | |
| ) | |
| else: | |
| raise ValueError("Unsupported backbone") | |
| if backbone is not None and mil_mode not in ["mean", "max", "att", "att_trans"]: | |
| raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode)) | |
| if self.mil_mode in ["mean", "max"]: | |
| pass | |
| elif self.mil_mode == "att": | |
| self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) | |
| elif self.mil_mode == "att_trans": | |
| transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout) | |
| self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks) | |
| self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) | |
| elif self.mil_mode == "att_trans_pyramid": | |
| transformer_list = nn.ModuleList( | |
| [ | |
| nn.TransformerEncoder( | |
| nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout), | |
| num_layers=trans_blocks, | |
| ), | |
| nn.Sequential( | |
| nn.Linear(192, 64), | |
| nn.TransformerEncoder( | |
| nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout), | |
| num_layers=trans_blocks, | |
| ), | |
| ), | |
| nn.Sequential( | |
| nn.Linear(320, 64), | |
| nn.TransformerEncoder( | |
| nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout), | |
| num_layers=trans_blocks, | |
| ), | |
| ), | |
| nn.TransformerEncoder( | |
| nn.TransformerEncoderLayer(d_model=576, nhead=8, dropout=trans_dropout), | |
| num_layers=trans_blocks, | |
| ), | |
| ] | |
| ) | |
| self.transformer = transformer_list | |
| nfc = nfc + 64 | |
| self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) | |
| else: | |
| raise ValueError("Unsupported mil_mode: " + str(mil_mode)) | |
| self.myfc = nn.Linear(nfc, num_classes) | |
| self.net = net | |
| def calc_head(self, x: torch.Tensor) -> torch.Tensor: | |
| sh = x.shape | |
| if self.mil_mode == "mean": | |
| x = self.myfc(x) | |
| x = torch.mean(x, dim=1) | |
| elif self.mil_mode == "max": | |
| x = self.myfc(x) | |
| x, _ = torch.max(x, dim=1) | |
| elif self.mil_mode == "att": | |
| a = self.attention(x) | |
| a = torch.softmax(a, dim=1) | |
| x = torch.sum(x * a, dim=1) | |
| x = self.myfc(x) | |
| elif self.mil_mode == "att_trans" and self.transformer is not None: | |
| x = x.permute(1, 0, 2) | |
| x = self.transformer(x) | |
| x = x.permute(1, 0, 2) | |
| a = self.attention(x) | |
| a = torch.softmax(a, dim=1) | |
| x = torch.sum(x * a, dim=1) | |
| x = self.myfc(x) | |
| elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None: | |
| l1 = ( | |
| torch.mean(self.extra_outputs["layer1"], dim=(2, 3, 4)) | |
| .reshape(sh[0], sh[1], -1) | |
| .permute(1, 0, 2) | |
| ) | |
| l2 = ( | |
| torch.mean(self.extra_outputs["layer2"], dim=(2, 3, 4)) | |
| .reshape(sh[0], sh[1], -1) | |
| .permute(1, 0, 2) | |
| ) | |
| l3 = ( | |
| torch.mean(self.extra_outputs["layer3"], dim=(2, 3, 4)) | |
| .reshape(sh[0], sh[1], -1) | |
| .permute(1, 0, 2) | |
| ) | |
| l4 = ( | |
| torch.mean(self.extra_outputs["layer4"], dim=(2, 3, 4)) | |
| .reshape(sh[0], sh[1], -1) | |
| .permute(1, 0, 2) | |
| ) | |
| transformer_list = cast(nn.ModuleList, self.transformer) | |
| x = transformer_list[0](l1) | |
| x = transformer_list[1](torch.cat((x, l2), dim=2)) | |
| x = transformer_list[2](torch.cat((x, l3), dim=2)) | |
| x = transformer_list[3](torch.cat((x, l4), dim=2)) | |
| x = x.permute(1, 0, 2) | |
| a = self.attention(x) | |
| a = torch.softmax(a, dim=1) | |
| x = torch.sum(x * a, dim=1) | |
| x = self.myfc(x) | |
| else: | |
| raise ValueError("Wrong model mode" + str(self.mil_mode)) | |
| return x | |
| def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor: | |
| sh = x.shape | |
| x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5]) | |
| x = self.net(x) | |
| x = x.reshape(sh[0], sh[1], -1) | |
| if not no_head: | |
| x = self.calc_head(x) | |
| return x | |