File size: 2,985 Bytes
906fcb9
caf6ee7
906fcb9
 
 
 
1baebae
906fcb9
 
 
1baebae
 
 
 
 
 
 
 
 
 
caf6ee7
 
906fcb9
 
 
1baebae
906fcb9
1baebae
906fcb9
1baebae
906fcb9
1baebae
906fcb9
1baebae
 
 
 
 
 
 
 
 
906fcb9
 
1baebae
caf6ee7
1baebae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
caf6ee7
906fcb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from __future__ import annotations

import torch
import torch.nn as nn
from monai.utils.module import optional_import

# Bind ``torchvision.models`` through MONAI's ``optional_import`` helper; the second
# element of the returned pair is deliberately discarded.
# NOTE(review): presumably this keeps the module importable when torchvision is not
# installed — confirm against the MONAI helper's contract. ``models`` is not referenced
# by the definitions visible in this file; verify it is still needed before removing it.
models, _ = optional_import("torchvision.models")


class SimpleNN(nn.Module):
    """
    Small fully connected binary classifier that emits probabilities.

    The layer stack is ``Linear(input_dim, 256) -> ReLU -> Linear(256, 128) -> ReLU ->
    Dropout(p=0.3) -> Linear(128, 1) -> Sigmoid``; the trailing sigmoid means the module
    returns values between 0 and 1 rather than raw logits.

    Args:
        input_dim (int): Size of the last dimension of the incoming feature tensor.
    """

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        # Layers are instantiated in forward order and wrapped in one ``nn.Sequential``
        # stored as ``net``, so parameter names keep the ``net.<index>.*`` layout.
        stages = [
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # squashes the single output unit into a probability
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """
        Run the feature tensor through the MLP.

        Args:
            x (torch.Tensor): Features whose last dimension equals ``input_dim``,
                typically of shape (Batch, input_dim).

        Returns:
            torch.Tensor: Sigmoid outputs with a trailing dimension of size 1,
            i.e. (Batch, 1) for a (Batch, input_dim) input.
        """
        probabilities = self.net(x)
        return probabilities


class CSPCAModel(nn.Module):
    """
    Clinically Significant Prostate Cancer (csPCa) risk prediction model using a MIL backbone.

    This model repurposes a pre-trained Multiple Instance Learning (MIL) backbone (originally
    designed for PI-RADS prediction) for binary csPCa risk assessment. It utilizes the
    backbone's feature extractor, transformer, and attention mechanism to aggregate instance-level
    features into a bag-level embedding.

    The original fully connected classification head of the backbone is not called; a
    custom :class:`SimpleNN` head is applied to the bag-level embedding instead.

    Args:
        backbone (nn.Module): A pre-trained MIL model. The backbone must possess the
            following attributes/sub-modules:
            - ``net``: The CNN feature extractor.
            - ``transformer``: A sequence modeling module.
            - ``attention``: An attention mechanism for pooling.
            - ``myfc``: The original fully connected layer; only its ``in_features``
              is read, to size the new head.

    Attributes:
        backbone: The MIL based PI-RADS classifier.
        fc_dim (int): ``backbone.myfc.in_features``, i.e. the width of the bag embedding.
        fc_cspca (SimpleNN): The new classification head for csPCa prediction.
    """

    def __init__(self, backbone: nn.Module) -> None:
        super().__init__()
        self.backbone = backbone
        # The new head must accept exactly what the backbone's original head accepted.
        self.fc_dim = backbone.myfc.in_features
        self.fc_cspca = SimpleNN(input_dim=self.fc_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the bag-level csPCa output for a batch of instance bags.

        Args:
            x (torch.Tensor): Tensor of shape ``(B, N, *instance_dims)`` where the first
                two axes are folded together before the feature extractor is applied.
                The usual case is a 6-D tensor (e.g. ``(B, N, C, D, H, W)`` for 3-D
                patches); any rank >= 3 is accepted as long as ``backbone.net`` can
                consume ``(B * N, *instance_dims)``.

        Returns:
            torch.Tensor: Output of the :class:`SimpleNN` head applied to the
            attention-pooled embedding, one row per bag (``(B, 1)``).

        Raises:
            ValueError: If ``x`` has fewer than three dimensions, i.e. there is no
                per-instance payload behind the batch and instance axes.
        """
        if x.dim() < 3:
            raise ValueError(
                f"Expected input of shape (B, N, *instance_dims) with at least 3 dimensions, "
                f"got shape {tuple(x.shape)}."
            )

        sh = x.shape
        # Fold batch and instance axes so the extractor sees every instance as one sample.
        # ``*sh[2:]`` keeps this valid for any instance rank, not only 4-D instances.
        x = x.reshape(sh[0] * sh[1], *sh[2:])
        x = self.backbone.net(x)
        # Unfold back to one feature vector per instance: (B, N, F).
        x = x.reshape(sh[0], sh[1], -1)

        # The transformer is fed with the instance axis first, (N, B, F), then restored.
        x = x.permute(1, 0, 2)
        x = self.backbone.transformer(x)
        x = x.permute(1, 0, 2)

        # Attention scores are normalised across the instance axis and used as weights
        # for a weighted sum of instance features -> (B, F).
        # NOTE(review): assumes ``attention`` returns one score per instance, broadcastable
        # against (B, N, F) (e.g. (B, N, 1)) — confirm against the backbone.
        a = self.backbone.attention(x)
        a = torch.softmax(a, dim=1)
        x = torch.sum(x * a, dim=1)

        x = self.fc_cspca(x)
        return x