Anirudh Balaraman
add ci
caf6ee7
from __future__ import annotations
from typing import cast
import torch
import torch.nn as nn
from monai.networks.nets import resnet
from monai.utils.module import optional_import
models, _ = optional_import("torchvision.models")
class MILModel3D(nn.Module):
    """
    Multiple Instance Learning (MIL) model, with a backbone classification model.

    Adapted from MONAI, modified for 3D images. The expected shape of input data is `[B, N, C, D, H, W]`,
    where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances
    extracted from every original image in the batch.

    Args:
        num_classes: number of output classes.
        mil_mode: MIL algorithm, matched case-insensitively (Defaults to ``"att"``):

            - ``"mean"`` - average features from all instances, equivalent to pure CNN (non MIL).
            - ``"max"`` - retain only the instance with the max probability for loss calculation.
            - ``"att"`` - attention based MIL https://arxiv.org/abs/1802.04712.
            - ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556.
            - ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556.
              Only available with the default (``backbone=None``) 3D ResNet18.

        pretrained: forwarded as ``pretrained=`` to the torchvision constructor when ``backbone``
            is a string. Ignored for the default ResNet18 and for a custom ``nn.Module``.
        backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features,
            or a string name of a torchvision model).
            Defaults to ``None``, in which case a 3D ResNet18 with 3 input channels is used.
        backbone_num_features: Number of output features of the backbone CNN.
            Defaults to ``None`` (necessary only when using a custom backbone).
        trans_blocks: number of the blocks in `TransformEncoder` layer.
        trans_dropout: dropout rate in `TransformEncoder` layer.

    Raises:
        ValueError: for a non-positive ``num_classes``, an unknown ``mil_mode``, an unknown or
            unsupported backbone, a custom backbone without ``backbone_num_features``, or
            ``"att_trans_pyramid"`` combined with a non-default backbone.
    """

    _MIL_MODES = ("mean", "max", "att", "att_trans", "att_trans_pyramid")

    def __init__(
        self,
        num_classes: int,
        mil_mode: str = "att",
        pretrained: bool = True,
        backbone: str | nn.Module | None = None,
        backbone_num_features: int | None = None,
        trans_blocks: int = 4,
        trans_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        if num_classes <= 0:
            raise ValueError("Number of classes must be positive: " + str(num_classes))
        if mil_mode.lower() not in self._MIL_MODES:
            raise ValueError("Unsupported mil_mode: " + str(mil_mode))

        # Use the normalised mode everywhere below; comparing the raw argument would skip the
        # pyramid hooks (or reject valid modes) when the caller passes e.g. "ATT_TRANS_PYRAMID".
        self.mil_mode = mil_mode.lower()
        # Filled by forward hooks on the default backbone; read only in "att_trans_pyramid" mode.
        self.extra_outputs: dict[str, torch.Tensor] = {}
        self.attention = nn.Sequential()
        self.transformer: nn.Module | None = None

        # Reject the unsupported combination before building anything, so a failing call neither
        # constructs a torchvision model nor mutates a user-supplied module (its ``fc``).
        if backbone is not None and self.mil_mode == "att_trans_pyramid":
            raise ValueError(
                "Cannot use att_trans_pyramid with custom backbone. "
                "Have to use the default ResNet 18 backbone."
            )

        net, nfc = self._build_backbone(backbone, pretrained, backbone_num_features)
        nfc = self._build_head(nfc, trans_blocks, trans_dropout)
        self.myfc = nn.Linear(nfc, num_classes)
        self.net = net

    def _build_backbone(
        self,
        backbone: str | nn.Module | None,
        pretrained: bool,
        backbone_num_features: int | None,
    ) -> tuple[nn.Module, int]:
        """Create the per-instance feature extractor; return it with its output feature size."""
        net: nn.Module
        if backbone is None:
            net = resnet.resnet18(
                spatial_dims=3,
                n_input_channels=3,
                num_classes=5,  # irrelevant: the fc layer is replaced by Identity just below
            )
            assert net.fc is not None
            nfc = net.fc.in_features  # save the number of final features
            net.fc = torch.nn.Identity()  # type: ignore[assignment]
            if self.mil_mode == "att_trans_pyramid":
                # capture the output of every residual stage for the pyramid head
                for layer_name in ("layer1", "layer2", "layer3", "layer4"):
                    getattr(net, layer_name).register_forward_hook(
                        self._make_capture_hook(layer_name)
                    )
        elif isinstance(backbone, str):
            # assume torchvision model string is provided
            torch_model = getattr(models, backbone, None)
            if torch_model is None:
                raise ValueError("Unknown torch vision model " + str(backbone))
            net = torch_model(pretrained=pretrained)
            if getattr(net, "fc", None) is None:
                raise ValueError(
                    "Unable to detect FC layer for the torchvision model "
                    + str(backbone)
                    + ". Please initialize the backbone model manually."
                )
            nfc = net.fc.in_features  # save the number of final features
            net.fc = torch.nn.Identity()  # type: ignore[assignment]
        elif isinstance(backbone, nn.Module):
            # use a custom backbone; its feature size cannot be inferred reliably
            if backbone_num_features is None:
                raise ValueError(
                    "Number of encoder features must be provided for a custom backbone model"
                )
            net = backbone
            nfc = backbone_num_features
            net.fc = torch.nn.Identity()  # type: ignore[assignment]
        else:
            raise ValueError("Unsupported backbone")
        return net, nfc

    def _make_capture_hook(self, layer_name: str):
        """Return a forward hook that stores a layer's output under ``layer_name``."""

        def hook(module, inputs, output):
            self.extra_outputs[layer_name] = output

        return hook

    def _build_head(self, nfc: int, trans_blocks: int, trans_dropout: float) -> int:
        """Create attention/transformer modules for ``self.mil_mode``; return the classifier input size."""
        if self.mil_mode in ("mean", "max"):
            return nfc
        if self.mil_mode == "att_trans":
            transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout)
            self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks)
        elif self.mil_mode == "att_trans_pyramid":
            # Widths follow the default ResNet18 stages (64/128/256/512 channels): each stage is
            # concatenated with the running 64-d sequence, projected back to 64, and encoded.
            self.transformer = nn.ModuleList(
                [
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout),
                        num_layers=trans_blocks,
                    ),
                    nn.Sequential(
                        nn.Linear(192, 64),
                        nn.TransformerEncoder(
                            nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout),
                            num_layers=trans_blocks,
                        ),
                    ),
                    nn.Sequential(
                        nn.Linear(320, 64),
                        nn.TransformerEncoder(
                            nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout),
                            num_layers=trans_blocks,
                        ),
                    ),
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=576, nhead=8, dropout=trans_dropout),
                        num_layers=trans_blocks,
                    ),
                ]
            )
            nfc = nfc + 64
        elif self.mil_mode != "att":
            raise ValueError("Unsupported mil_mode: " + str(self.mil_mode))
        self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))
        return nfc

    def _pooled_stage(self, layer_name: str, batch: int, instances: int) -> torch.Tensor:
        """Average a captured stage output over dims 2-4 and return it as ``[N, B, C]``."""
        feats = torch.mean(self.extra_outputs[layer_name], dim=(2, 3, 4))
        return feats.reshape(batch, instances, -1).permute(1, 0, 2)

    def calc_head(self, x: torch.Tensor) -> torch.Tensor:
        """Aggregate per-instance features into one prediction per bag.

        Args:
            x: per-instance features of shape ``[B, N, F]`` (as produced by ``forward``).

        Returns:
            Bag-level logits of shape ``[B, num_classes]``.

        Raises:
            ValueError: if ``self.mil_mode`` has no usable head.
        """
        sh = x.shape
        if self.mil_mode == "mean":
            # classify every instance, then average the instance logits
            return torch.mean(self.myfc(x), dim=1)
        if self.mil_mode == "max":
            # classify every instance, keep the per-class maximum over instances
            x, _ = torch.max(self.myfc(x), dim=1)
            return x

        if self.mil_mode == "att_trans" and self.transformer is not None:
            # the encoder is built without batch_first, so it expects [N, B, F]
            x = self.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)
        elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None:
            pyramid = cast(nn.ModuleList, self.transformer)
            l1, l2, l3, l4 = (
                self._pooled_stage(name, sh[0], sh[1])
                for name in ("layer1", "layer2", "layer3", "layer4")
            )
            x = pyramid[0](l1)
            x = pyramid[1](torch.cat((x, l2), dim=2))
            x = pyramid[2](torch.cat((x, l3), dim=2))
            x = pyramid[3](torch.cat((x, l4), dim=2))
            x = x.permute(1, 0, 2)
        elif self.mil_mode != "att":
            raise ValueError("Wrong model mode " + str(self.mil_mode))

        # attention pooling: softmax weights over the instance axis, weighted sum, classify
        a = torch.softmax(self.attention(x), dim=1)
        return self.myfc(torch.sum(x * a, dim=1))

    def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor:
        """Encode every instance with the backbone and, unless ``no_head``, aggregate per bag.

        Args:
            x: bag of instances ``[B, N, C, D, H, W]``; any ``[B, N, ...]`` layout accepted by
                the backbone works (e.g. ``[B, N, C, H, W]`` for a 2D torchvision model).
            no_head: if True, return per-instance features ``[B, N, F]`` without aggregation.

        Returns:
            ``[B, num_classes]`` logits, or ``[B, N, F]`` features when ``no_head`` is True.
        """
        sh = x.shape
        # fold the instance axis into the batch axis so the backbone sees ordinary images
        x = x.reshape(sh[0] * sh[1], *sh[2:])
        x = self.net(x)
        x = x.reshape(sh[0], sh[1], -1)
        if not no_head:
            x = self.calc_head(x)
        return x