# Models Reference
## MILModel_3D
```python
class MILModel_3D(nn.Module):
def __init__(
self,
num_classes: int,
mil_mode: str = "att",
pretrained: bool = True,
backbone: str | nn.Module | None = None,
backbone_num_features: int | None = None,
trans_blocks: int = 4,
trans_dropout: float = 0.0,
)
```
**Constructor arguments:**
| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `num_classes` | `int` | required | Number of output classes |
| `mil_mode` | `str` | `"att"` | MIL aggregation mode |
| `pretrained` | `bool` | `True` | Use pretrained backbone weights |
| `backbone` | `str \| nn.Module \| None` | `None` | Backbone CNN (None = ResNet18-3D) |
| `backbone_num_features` | `int \| None` | `None` | Output features of custom backbone |
| `trans_blocks` | `int` | `4` | Number of transformer encoder layers |
| `trans_dropout` | `float` | `0.0` | Transformer dropout rate |
**MIL modes:**
| Mode | Description |
|------|-------------|
| `mean` | Average logits across all patches (equivalent to a pure CNN) |
| `max` | Keep only the max-probability instance for loss |
| `att` | Attention-based MIL ([Ilse et al., 2018](https://arxiv.org/abs/1802.04712)) |
| `att_trans` | Transformer + attention MIL ([Shao et al., 2021](https://arxiv.org/abs/2111.01556)) |
| `att_trans_pyramid` | Pyramid transformer using intermediate ResNet layers |
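For intuition, the following is a minimal sketch of the attention pooling behind the `att` mode, in the style of Ilse et al., 2018. The class and names below are illustrative only and are not part of `src.model.MIL`:
```python
import torch
import torch.nn as nn

class AttentionMILPool(nn.Module):
    """Hypothetical sketch of attention-based MIL pooling (Ilse et al., 2018)."""

    def __init__(self, in_features: int, hidden: int = 128):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        )

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # feats: [B, N, F] patch features
        a = self.attention(feats)           # [B, N, 1] unnormalized scores
        a = torch.softmax(a, dim=1)         # attention weights over N patches
        return torch.sum(a * feats, dim=1)  # [B, F] weighted bag embedding
```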
**Key methods:**
- `forward(x, no_head=False)`: Full forward pass. If `no_head=True`, returns patch-level features `[B, N, 512]` before transformer and attention pooling (used during attention loss computation).
- `calc_head(x)`: Applies the MIL aggregation and classification head to patch features. The sketch below shows how the two methods compose.
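A minimal sketch of that composition, assuming `calc_head` accepts the `[B, N, 512]` features that `forward(x, no_head=True)` returns, as documented above:
```python
import torch
from src.model.MIL import MILModel_3D

model = MILModel_3D(num_classes=4, mil_mode="att")
x = torch.randn(2, 24, 3, 3, 64, 64)

# Two-step equivalent of model(x): extract patch-level features,
# then apply the MIL aggregation and classification head.
feats = model(x, no_head=True)   # [2, 24, 512]
logits = model.calc_head(feats)  # [2, 4]
```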
**Example:**
```python
import torch
from src.model.MIL import MILModel_3D
model = MILModel_3D(num_classes=4, mil_mode="att_trans")
# Input: [batch, patches, channels, depth, height, width]
x = torch.randn(2, 24, 3, 3, 64, 64)
logits = model(x) # [2, 4]
```
## csPCa_Model
```python
class csPCa_Model(nn.Module):
def __init__(self, backbone: nn.Module)
```
Wraps a pre-trained `MILModel_3D` backbone for binary csPCa prediction. The backbone's feature extractor, transformer, and attention mechanism are reused. The original classification head (`myfc`) is replaced by a `SimpleNN`.
**Attributes:**
| Attribute | Type | Description |
|-----------|------|-------------|
| `backbone` | `MILModel_3D` | Frozen PI-RADS backbone |
| `fc_cspca` | `SimpleNN` | Binary classification head |
| `fc_dim` | `int` | Feature dimension (512 for ResNet18) |
**Example:**
```python
import torch
from src.model.MIL import MILModel_3D
from src.model.csPCa_model import csPCa_Model
backbone = MILModel_3D(num_classes=4, mil_mode="att_trans")
model = csPCa_Model(backbone=backbone)
x = torch.randn(2, 24, 3, 3, 64, 64)
prob = model(x) # [2, 1], sigmoid probabilities
```
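Since the PI-RADS backbone is meant to stay frozen, a typical training setup optimizes only the csPCa head. The freezing shown below is an illustrative sketch, not necessarily how `csPCa_Model` handles it internally:
```python
import torch
from src.model.MIL import MILModel_3D
from src.model.csPCa_model import csPCa_Model

backbone = MILModel_3D(num_classes=4, mil_mode="att_trans")
model = csPCa_Model(backbone=backbone)

# Illustrative: exclude the frozen backbone from optimization so that
# only the SimpleNN head (fc_cspca) receives gradient updates.
for p in model.backbone.parameters():
    p.requires_grad = False

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)
criterion = torch.nn.BCELoss()  # model outputs sigmoid probabilities
```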
## SimpleNN
```python
class SimpleNN(nn.Module):
def __init__(self, input_dim: int)
```
A lightweight MLP for binary classification:
```
Linear(input_dim, 256) → ReLU
Linear(256, 128) → ReLU → Dropout(0.3)
Linear(128, 1) → Sigmoid
```
Input: `[B, input_dim]` → Output: `[B, 1]` (probability).
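A PyTorch sketch matching this stack follows; attribute names are illustrative, and the shipped `SimpleNN` may organize its layers differently:
```python
import torch.nn as nn

class SimpleNN(nn.Module):
    """Sketch of the documented layer stack (names illustrative)."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x: [B, input_dim] -> [B, 1] probability
        return self.net(x)
```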