Models Reference
MILModel_3D
```python
class MILModel_3D(nn.Module):
    def __init__(
        self,
        num_classes: int,
        mil_mode: str = "att",
        pretrained: bool = True,
        backbone: str | nn.Module | None = None,
        backbone_num_features: int | None = None,
        trans_blocks: int = 4,
        trans_dropout: float = 0.0,
    )
```
Constructor arguments:
| Argument | Type | Default | Description |
|---|---|---|---|
| `num_classes` | `int` | required | Number of output classes |
| `mil_mode` | `str` | `"att"` | MIL aggregation mode |
| `pretrained` | `bool` | `True` | Use pretrained backbone weights |
| `backbone` | `str \| nn.Module \| None` | `None` | Backbone CNN (`None` = ResNet18-3D) |
| `backbone_num_features` | `int \| None` | `None` | Output features of custom backbone |
| `trans_blocks` | `int` | `4` | Number of transformer encoder layers |
| `trans_dropout` | `float` | `0.0` | Transformer dropout rate |
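To make `trans_blocks` and `trans_dropout` concrete, the sketch below shows how such arguments commonly map onto PyTorch's built-in transformer encoder. It is an illustration rather than the repository's code: the feature width of 512 and `nhead=8` are assumptions that the reference does not state.

```python
import torch
import torch.nn as nn

feat_dim = 512       # assumed patch-feature width (ResNet18-style)
trans_blocks = 4     # default from the table above
trans_dropout = 0.0  # default from the table above

# nhead=8 is an assumption; the reference does not give the head count.
layer = nn.TransformerEncoderLayer(
    d_model=feat_dim, nhead=8, dropout=trans_dropout, batch_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=trans_blocks)
encoder.eval()

patch_feats = torch.randn(2, 24, feat_dim)  # [B, N, feat_dim]
with torch.no_grad():
    mixed = encoder(patch_feats)  # same shape; each patch has attended to the others
```

The encoder preserves the `[B, N, feat_dim]` shape, so it can sit between feature extraction and any pooling step without changing the interface.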
MIL modes:
| Mode | Description |
|---|---|
| `mean` | Average logits across all patches (equivalent to a pure CNN) |
| `max` | Keep only the max-probability instance for loss |
| `att` | Attention-based MIL (Ilse et al., 2018) |
| `att_trans` | Transformer + attention MIL (Shao et al., 2021) |
| `att_trans_pyramid` | Pyramid transformer using intermediate ResNet layers |
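As a rough illustration of how the `mean` and `att` modes differ, here is a minimal sketch operating on already-extracted patch features. The hidden width of 128 and the non-gated attention form are assumptions, and all variable names are illustrative. The `max` mode is left out because the table does not specify how the instance is selected.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, D, num_classes = 2, 24, 512, 4
feats = torch.randn(B, N, D)            # patch-level features [B, N, D]
classifier = nn.Linear(D, num_classes)  # classification head shared by both modes

# "mean": classify every patch, then average the logits over the patch axis.
# Because the head is linear, this equals classifying the mean feature vector.
mean_logits = classifier(feats).mean(dim=1)  # [B, num_classes]

# "att": score each patch, softmax the scores over patches, and classify
# the attention-weighted sum of features (in the style of Ilse et al., 2018).
attention = nn.Sequential(nn.Linear(D, 128), nn.Tanh(), nn.Linear(128, 1))
scores = attention(feats)                # [B, N, 1]
weights = torch.softmax(scores, dim=1)   # sums to 1 across the N patches
bag = (weights * feats).sum(dim=1)       # [B, D]
att_logits = classifier(bag)             # [B, num_classes]
```

The attention weights also give a per-patch importance score, which is what makes attention-based pooling easier to inspect than plain averaging.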
Key methods:
- `forward(x, no_head=False)`: Full forward pass. If `no_head=True`, returns patch-level features `[B, N, 512]` before transformer and attention pooling (used during attention loss computation).
- `calc_head(x)`: Applies the MIL aggregation and classification head to patch features.
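The split between the two methods can be sketched with a toy module. `ToyMIL` is a hypothetical stand-in, not the real class: it uses a tiny 3D convolution instead of a ResNet and mean pooling instead of attention, and it mirrors only the `forward(x, no_head=...)` / `calc_head(x)` interface described above.

```python
import torch
import torch.nn as nn


class ToyMIL(nn.Module):
    """Hypothetical stand-in that mirrors only the forward/calc_head split."""

    def __init__(self, in_channels=3, feat_dim=16, num_classes=4):
        super().__init__()
        # Tiny 3D encoder standing in for the ResNet backbone.
        self.net = nn.Sequential(
            nn.Conv3d(in_channels, feat_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
        )
        self.myfc = nn.Linear(feat_dim, num_classes)

    def calc_head(self, feats):
        # feats: [B, N, feat_dim] -> pool over patches, then classify.
        return self.myfc(feats.mean(dim=1))

    def forward(self, x, no_head=False):
        b, n = x.shape[:2]
        # Fold patches into the batch axis so the encoder sees 5-D volumes,
        # then unfold back to [B, N, feat_dim].
        feats = self.net(x.reshape(b * n, *x.shape[2:])).reshape(b, n, -1)
        return feats if no_head else self.calc_head(feats)


model = ToyMIL().eval()
x = torch.randn(2, 6, 3, 3, 16, 16)  # [B, N, C, D, H, W]
with torch.no_grad():
    feats = model(x, no_head=True)            # patch features [2, 6, 16]
    logits_two_step = model.calc_head(feats)  # [2, 4]
    logits_direct = model(x)                  # [2, 4], same result in one call
```

Splitting the pass this way lets a training loop reuse the patch features for an auxiliary loss before applying the head, without running the encoder twice.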
Example:
```python
import torch
from src.model.MIL import MILModel_3D

model = MILModel_3D(num_classes=4, mil_mode="att_trans")

# Input: [batch, patches, channels, depth, height, width]
x = torch.randn(2, 24, 3, 3, 64, 64)
logits = model(x)  # [2, 4]
```
csPCa_Model
```python
class csPCa_Model(nn.Module):
    def __init__(self, backbone: nn.Module)
```
Wraps a pretrained `MILModel_3D` backbone for binary prediction of clinically significant prostate cancer (csPCa). The backbone's feature extractor, transformer, and attention mechanism are reused; the original classification head (`myfc`) is replaced by a `SimpleNN`.
Attributes:
| Attribute | Type | Description |
|---|---|---|
| `backbone` | `MILModel_3D` | Frozen PI-RADS backbone |
| `fc_cspca` | `SimpleNN` | Binary classification head |
| `fc_dim` | `int` | Feature dimension (512 for ResNet18) |
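The general pattern of freezing a backbone and training only a new head can be sketched as follows. `StandInBackbone` is a hypothetical placeholder whose only borrowed detail is the `myfc` attribute name; reading `myfc.in_features` assumes the original head is an `nn.Linear`, which the reference does not confirm.

```python
import torch
import torch.nn as nn


class StandInBackbone(nn.Module):
    """Hypothetical placeholder for a pretrained MIL backbone."""

    def __init__(self, feat_dim=512, num_classes=4):
        super().__init__()
        self.encoder = nn.Linear(32, feat_dim)
        self.myfc = nn.Linear(feat_dim, num_classes)  # original head


backbone = StandInBackbone()

# Freeze every backbone parameter so no gradients are computed for them.
for p in backbone.parameters():
    p.requires_grad = False
backbone.eval()  # also fixes dropout/batch-norm behaviour, if any is present

# Size the new binary head from the old head's input width (assumed nn.Linear).
fc_dim = backbone.myfc.in_features
head = nn.Sequential(nn.Linear(fc_dim, 1), nn.Sigmoid())

# Hand only the head's parameters to the optimizer.
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
trainable = [name for name, p in head.named_parameters() if p.requires_grad]
```

Passing only the head's parameters to the optimizer keeps the frozen weights out of the update step even if a flag is later flipped by mistake.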
Example:
```python
import torch
from src.model.MIL import MILModel_3D
from src.model.csPCa_model import csPCa_Model

backbone = MILModel_3D(num_classes=4, mil_mode="att_trans")
model = csPCa_Model(backbone=backbone)

x = torch.randn(2, 24, 3, 3, 64, 64)
prob = model(x)  # [2, 1], sigmoid probabilities
```
SimpleNN
```python
class SimpleNN(nn.Module):
    def __init__(self, input_dim: int)
```
A lightweight MLP for binary classification:
```text
Linear(input_dim, 256) → ReLU
Linear(256, 128) → ReLU → Dropout(0.3)
Linear(128, 1) → Sigmoid
```

Input: `[B, input_dim]` → Output: `[B, 1]` (probability).
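The layer list above translates directly into an `nn.Sequential`. The class below is a reconstruction from that description, named `SimpleNNSketch` to make clear it is not the repository's source; the attribute name `net` is an assumption.

```python
import torch
import torch.nn as nn


class SimpleNNSketch(nn.Module):
    """Reconstruction of the described MLP; not the repository's code."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 1), nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)  # [B, input_dim] -> [B, 1]


head = SimpleNNSketch(input_dim=512).eval()  # eval() disables dropout
with torch.no_grad():
    prob = head(torch.randn(4, 512))  # [4, 1], each value in [0, 1]
```

Because the final layer is already a sigmoid, a head like this pairs with `nn.BCELoss`; `nn.BCEWithLogitsLoss` expects raw logits and would apply the sigmoid a second time.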