from __future__ import annotations import torch import torch.nn as nn from monai.utils.module import optional_import models, _ = optional_import("torchvision.models") class SimpleNN(nn.Module): """ A simple Multi-Layer Perceptron (MLP) for binary classification. This network consists of two hidden layers with ReLU activation and a dropout layer, followed by a final sigmoid activation for probability output. Args: input_dim (int): The number of input features. """ def __init__(self, input_dim: int) -> None: super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU(), nn.Dropout(p=0.3), nn.Linear(128, 1), nn.Sigmoid(), # since binary classification ) def forward(self, x): """ Forward pass of the classifier. Args: x (torch.Tensor): Input tensor of shape (Batch, input_dim). Returns: torch.Tensor: Output probabilities of shape (Batch, 1). """ return self.net(x) class CSPCAModel(nn.Module): """ Clinically Significant Prostate Cancer (csPCa) risk prediction model using a MIL backbone. This model repurposes a pre-trained Multiple Instance Learning (MIL) backbone (originally designed for PI-RADS prediction) for binary csPCa risk assessment. It utilizes the backbone's feature extractor, transformer, and attention mechanism to aggregate instance-level features into a bag-level embedding. The original fully connected classification head of the backbone is replaced by a custom :class:`SimpleNN` head for the new task. Args: backbone (nn.Module): A pre-trained MIL model. The backbone must possess the following attributes/sub-modules: - ``net``: The CNN feature extractor. - ``transformer``: A sequence modeling module. - ``attention``: An attention mechanism for pooling. - ``myfc``: The original fully connected layer (used to determine feature dimensions). Attributes: fc_cspca (SimpleNN): The new classification head for csPCa prediction. backbone: The MIL based PI-RADS classifier. """ def __init__(self, backbone: nn.Module) -> None: super().__init__() self.backbone = backbone self.fc_dim = backbone.myfc.in_features self.fc_cspca = SimpleNN(input_dim=self.fc_dim) def forward(self, x): sh = x.shape x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5]) x = self.backbone.net(x) x = x.reshape(sh[0], sh[1], -1) x = x.permute(1, 0, 2) x = self.backbone.transformer(x) x = x.permute(1, 0, 2) a = self.backbone.attention(x) a = torch.softmax(a, dim=1) x = torch.sum(x * a, dim=1) x = self.fc_cspca(x) return x