# Data Loading Reference

## get_dataloader

```python
def get_dataloader(args, split: Literal["train", "test"]) -> DataLoader
```

Creates a PyTorch DataLoader with MONAI transforms and persistent caching.

**Parameters:**

| Parameter | Description |
|-----------|-------------|
| `args` | Namespace with `dataset_json`, `data_root`, `tile_size`, `tile_count`, `depth`, `use_heatmap`, `batch_size`, `workers`, `dry_run`, `logdir` |
| `split` | `"train"` or `"test"` |

**Behavior:**

- Loads data lists from a MONAI decathlon-format JSON
- In `dry_run` mode, limits to 8 samples
- Uses `PersistentDataset` with cache stored at `<logdir>/cache/<split>/`
- Training split is shuffled; test split is not
- Uses `list_data_collate` to stack patches into `[B, N, C, D, H, W]`
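
A quick usage sketch (hypothetical `args`; shapes follow the collate contract described under `list_data_collate` below):

```python
# Hypothetical usage; `args` carries the fields from the parameter table above.
train_loader = get_dataloader(args, split="train")

batch = next(iter(train_loader))
print(batch["image"].shape)   # [B, N, C, D, H, W] -- B patients, N patches each
print(batch["label"].shape)   # [B] patient-level labels
```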

## Transform Pipeline

The pipeline has two variants, selected by `args.use_heatmap`; an illustrative composition sketch follows each table.

### With Heatmaps (default)

| Step | Transform | Description |
|------|-----------|-------------|
| 1 | `LoadImaged` | Load T2, mask, DWI, ADC, heatmap (ITKReader, channel-first) |
| 2 | `ClipMaskIntensityPercentilesd` | Clip T2 intensity to [0, 99.5] percentiles within mask |
| 3 | `ConcatItemsd` | Stack T2 + DWI + ADC → 3-channel image |
| 4 | `NormalizeIntensity_customd` | Z-score normalize per channel using mask-only statistics |
| 5 | `ElementwiseProductd` | Multiply mask * heatmap → `final_heatmap` |
| 6 | `RandWeightedCropd` | Extract N patches weighted by `final_heatmap` |
| 7 | `EnsureTyped` | Cast labels to float32 |
| 8 | `Transposed` | Reorder image dims for 3D convolution |
| 9 | `DeleteItemsd` | Remove intermediate keys (mask, dwi, adc, heatmap) |
| 10 | `ToTensord` | Convert to PyTorch tensors |
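
For orientation, one plausible composition of these steps with `monai.transforms.Compose` is sketched below. The custom-transform import path, the `Transposed` indices, and the crop geometry built from `args.tile_size`/`args.depth` are assumptions, not the repository's exact code:

```python
import torch
from monai.transforms import (
    Compose, ConcatItemsd, DeleteItemsd, EnsureTyped,
    LoadImaged, RandWeightedCropd, ToTensord, Transposed,
)

# Custom transforms documented below; the module path is an assumption.
from transforms import (
    ClipMaskIntensityPercentilesd, ElementwiseProductd, NormalizeIntensity_customd,
)

def heatmap_transforms(args):
    size = (args.tile_size, args.tile_size, args.depth)  # assumed patch geometry
    return Compose([
        LoadImaged(keys=["image", "mask", "dwi", "adc", "heatmap"],
                   reader="ITKReader", ensure_channel_first=True),
        ClipMaskIntensityPercentilesd(keys=["image"], mask_key="mask",
                                      lower=0, upper=99.5),
        ConcatItemsd(keys=["image", "dwi", "adc"], name="image"),  # 3-channel image
        NormalizeIntensity_customd(keys=["image"], mask_key="mask",
                                   channel_wise=True),
        ElementwiseProductd(keys=["mask", "heatmap"], output_key="final_heatmap"),
        RandWeightedCropd(keys=["image", "final_heatmap"], w_key="final_heatmap",
                          spatial_size=size, num_samples=args.tile_count),
        EnsureTyped(keys=["label"], dtype=torch.float32),
        Transposed(keys=["image"], indices=(0, 3, 1, 2)),  # assumed [C,H,W,D] -> [C,D,H,W]
        DeleteItemsd(keys=["mask", "dwi", "adc", "heatmap"]),
        ToTensord(keys=["image", "final_heatmap"]),
    ])
```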

### Without Heatmaps

| Step | Transform | Description |
|------|-----------|-------------|
| 1 | `LoadImaged` | Load T2, mask, DWI, ADC |
| 2 | `ClipMaskIntensityPercentilesd` | Clip T2 intensity to [0, 99.5] percentiles within mask |
| 3 | `ConcatItemsd` | Stack T2 + DWI + ADC → 3-channel image |
| 4 | `NormalizeIntensityd` | Standard channel-wise normalization (MONAI built-in) |
| 5 | `RandCropByPosNegLabeld` | Extract N patches from positive (mask) regions |
| 6 | `EnsureTyped` | Cast labels to float32 |
| 7 | `Transposed` | Reorder image dims |
| 8 | `DeleteItemsd` | Remove intermediate keys |
| 9 | `ToTensord` | Convert to tensors |
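
Relative to the sketch above, this variant swaps steps 4–5 for MONAI built-ins; the `pos`/`neg` weighting shown here is an assumption consistent with "positive (mask) regions":

```python
from monai.transforms import NormalizeIntensityd, RandCropByPosNegLabeld

def no_heatmap_steps(args):
    size = (args.tile_size, args.tile_size, args.depth)
    return [
        NormalizeIntensityd(keys=["image"], channel_wise=True),
        # pos=1.0, neg=0.0 centers every patch on a positive (mask) voxel.
        RandCropByPosNegLabeld(keys=["image"], label_key="mask",
                               spatial_size=size, pos=1.0, neg=0.0,
                               num_samples=args.tile_count),
    ]
```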

## list_data_collate

```python
def list_data_collate(batch: Sequence) -> dict
```

Custom collation function that stacks per-patient patch lists into batch tensors.

Each sample from the dataset is a list of N patch dictionaries. This function:

1. Stacks `image` across patches: `[N, C, D, H, W]` per sample
2. Stacks `final_heatmap` if present
3. Applies PyTorch's `default_collate` to form the batch dimension

Result: `{"image": [B, N, C, D, H, W], "label": [B], ...}`
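
A minimal sketch of that behavior, assuming only tensor and scalar keys remain after `DeleteItemsd` (the repository's version may differ in detail):

```python
import torch
from torch.utils.data.dataloader import default_collate

def list_data_collate(batch):
    merged = []
    for patches in batch:              # one list of N patch dicts per patient
        sample = dict(patches[0])      # carries scalar keys such as "label"
        sample["image"] = torch.stack([p["image"] for p in patches])  # [N, C, D, H, W]
        if "final_heatmap" in sample:
            sample["final_heatmap"] = torch.stack(
                [p["final_heatmap"] for p in patches])
        merged.append(sample)
    return default_collate(merged)     # adds the leading batch dimension B
```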

## Custom Transforms

### ClipMaskIntensityPercentilesd

```python
ClipMaskIntensityPercentilesd(
    keys: KeysCollection,
    mask_key: str,
    lower: float | None,
    upper: float | None,
    sharpness_factor: float | None = None,
    channel_wise: bool = False,
    dtype: DtypeLike = np.float32,
)
```

Clips image intensity to percentiles computed only from the **masked region**. Supports both hard clipping (default) and soft clipping (via `sharpness_factor`).
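
Conceptually, the hard-clipping path reduces to the following NumPy sketch; the real transform additionally handles `channel_wise`, `dtype`, and the soft-clip branch:

```python
import numpy as np

def clip_to_mask_percentiles(img, mask, lower=0.0, upper=99.5):
    vals = img[mask > 0]                          # statistics from masked voxels only
    lo, hi = np.percentile(vals, [lower, upper])  # percentile bounds within the mask
    return np.clip(img, lo, hi)                   # hard clip applied to the whole image
```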

### NormalizeIntensity_customd

```python
NormalizeIntensity_customd(
    keys: KeysCollection,
    mask_key: str,
    subtrahend: NdarrayOrTensor | None = None,
    divisor: NdarrayOrTensor | None = None,
    nonzero: bool = False,
    channel_wise: bool = False,
    dtype: DtypeLike = np.float32,
)
```

Z-score normalization where mean and standard deviation are computed only from **masked voxels**. Supports channel-wise normalization.
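
Stripped to its essence, the masked z-score looks like this (a sketch, not the repository's code):

```python
import numpy as np

def normalize_with_mask(img, mask, eps=1e-8):
    vals = img[mask > 0]                             # mean/std from masked voxels only
    return (img - vals.mean()) / (vals.std() + eps)  # eps guards against zero std
```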

### ElementwiseProductd

```python
ElementwiseProductd(
    keys: KeysCollection,
    output_key: str,
)
```

Computes the element-wise product of two arrays from the data dictionary and stores the result in `output_key`. Used to combine the prostate mask with the attention heatmap.
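
As used in step 5 of the heatmap pipeline (assuming it accepts NumPy arrays, like other MONAI dictionary transforms):

```python
import numpy as np

combine = ElementwiseProductd(keys=["mask", "heatmap"], output_key="final_heatmap")
data = combine({
    "mask": np.ones((1, 8, 8, 4), dtype=np.float32),
    "heatmap": np.random.rand(1, 8, 8, 4).astype(np.float32),
})
# data["final_heatmap"] = mask * heatmap: zero sampling weight outside the prostate
```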

## Dataset JSON Format

The pipeline expects a MONAI decathlon-format JSON file:

```json
{
    "train": [
        {
            "image": "relative/path/to/t2.nrrd",
            "dwi": "relative/path/to/dwi.nrrd",
            "adc": "relative/path/to/adc.nrrd",
            "mask": "relative/path/to/mask.nrrd",
            "heatmap": "relative/path/to/heatmap.nrrd",
            "label": 2
        }
    ],
    "test": [...]
}
```

Paths are relative to `data_root`. The `heatmap` entry is required only when `use_heatmap=True`.
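
MONAI's standard datalist helper reads this layout; note that `data_list_key` must be passed explicitly, since this JSON uses `"train"` rather than MONAI's default `"training"` (the paths below are placeholders):

```python
from monai.data import load_decathlon_datalist

train_files = load_decathlon_datalist(
    "dataset.json",          # args.dataset_json
    is_segmentation=False,   # "label" is a scalar class, not a file path
    data_list_key="train",   # this JSON uses "train", not the default "training"
    base_dir="/data",        # args.data_root; resolves the relative paths
)
print(train_files[0]["image"])  # absolute path to the T2 volume
```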