# Data Loading Reference
## get_dataloader
```python
def get_dataloader(args, split: Literal["train", "test"]) -> DataLoader
```
Creates a PyTorch DataLoader with MONAI transforms and persistent caching.
**Parameters:**
| Parameter | Description |
|-----------|-------------|
| `args` | Namespace with `dataset_json`, `data_root`, `tile_size`, `tile_count`, `depth`, `use_heatmap`, `batch_size`, `workers`, `dry_run`, `logdir` |
| `split` | `"train"` or `"test"` |
**Behavior:**
- Loads data lists from a MONAI decathlon-format JSON
- In `dry_run` mode, limits to 8 samples
- Uses `PersistentDataset` with cache stored at `<logdir>/cache/<split>/`
- Training split is shuffled; test split is not
- Uses `list_data_collate` to stack patches into `[B, N, C, D, H, W]`
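The data-list handling described above (decathlon JSON loading, dry-run capping, path resolution) can be sketched with the standard library alone. `load_data_list` is a hypothetical helper, not the actual function; it mirrors the documented behavior under the assumption that every non-`label` value in an entry is a path relative to `data_root`:

```python
import json
from pathlib import Path

def load_data_list(dataset_json, data_root, split, dry_run=False):
    """Hypothetical helper mirroring the behavior above: read a
    decathlon-style JSON, cap samples in dry-run mode, and resolve
    every path-valued key against data_root."""
    with open(dataset_json) as f:
        datalist = json.load(f)[split]
    if dry_run:
        datalist = datalist[:8]  # dry_run limits the split to 8 samples
    root = Path(data_root)
    return [
        {key: (value if key == "label" else str(root / value))
         for key, value in item.items()}
        for item in datalist
    ]
```

The real pipeline wraps the resulting list in a `PersistentDataset` so deterministic transforms are cached to `<logdir>/cache/<split>/` after the first epoch.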
## Transform Pipeline
Two variants depending on `args.use_heatmap`:
### With Heatmaps (default)
| Step | Transform | Description |
|------|-----------|-------------|
| 1 | `LoadImaged` | Load T2, mask, DWI, ADC, heatmap (ITKReader, channel-first) |
| 2 | `ClipMaskIntensityPercentilesd` | Clip T2 intensity to [0, 99.5] percentiles within mask |
| 3 | `ConcatItemsd` | Stack T2 + DWI + ADC → 3-channel image |
| 4 | `NormalizeIntensity_customd` | Z-score normalize per channel using mask-only statistics |
| 5 | `ElementwiseProductd` | Multiply mask * heatmap → `final_heatmap` |
| 6 | `RandWeightedCropd` | Extract N patches weighted by `final_heatmap` |
| 7 | `EnsureTyped` | Cast labels to float32 |
| 8 | `Transposed` | Reorder image dims for 3D convolution |
| 9 | `DeleteItemsd` | Remove intermediate keys (mask, dwi, adc, heatmap) |
| 10 | `ToTensord` | Convert to PyTorch tensors |
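The key step in this pipeline is step 6: patch centers are drawn with probability proportional to `final_heatmap`, so sampling concentrates where both the mask and the attention heatmap are high. A toy NumPy sketch of that idea (not MONAI's actual `RandWeightedCropd` implementation):

```python
import numpy as np

def weighted_patch_centers(weight_map, num_patches, rng=None):
    """Toy sketch of the idea behind RandWeightedCropd: draw patch-center
    voxels with probability proportional to the weight map (here, the
    mask * heatmap product)."""
    rng = np.random.default_rng(rng)
    flat = weight_map.ravel().astype(np.float64)
    probs = flat / flat.sum()
    idx = rng.choice(flat.size, size=num_patches, p=probs)
    # Convert flat indices back to (z, y, x) coordinates, one row per patch
    return np.stack(np.unravel_index(idx, weight_map.shape), axis=1)
```

MONAI additionally handles patch extraction around each center and border conditions; this sketch only shows the weighted sampling itself.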
### Without Heatmaps
| Step | Transform | Description |
|------|-----------|-------------|
| 1 | `LoadImaged` | Load T2, mask, DWI, ADC |
| 2 | `ClipMaskIntensityPercentilesd` | Clip T2 intensity to [0, 99.5] percentiles within mask |
| 3 | `ConcatItemsd` | Stack T2 + DWI + ADC → 3-channel image |
| 4 | `NormalizeIntensityd` | Standard channel-wise normalization (MONAI built-in) |
| 5 | `RandCropByPosNegLabeld` | Extract N patches from positive (mask) regions |
| 6 | `EnsureTyped` | Cast labels to float32 |
| 7 | `Transposed` | Reorder image dims |
| 8 | `DeleteItemsd` | Remove intermediate keys |
| 9 | `ToTensord` | Convert to tensors |
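Without a heatmap, step 5 instead samples patch centers from voxels where the prostate mask is positive. A minimal NumPy sketch of the positive-only case of `RandCropByPosNegLabeld` (the real transform also balances positive/negative samples and extracts the patches):

```python
import numpy as np

def positive_patch_centers(label_mask, num_patches, rng=None):
    """Toy sketch of RandCropByPosNegLabeld's positive-only case: draw
    patch centers uniformly from voxels where the mask is positive."""
    rng = np.random.default_rng(rng)
    pos = np.argwhere(label_mask > 0)  # (K, ndim) coordinates of mask voxels
    return pos[rng.integers(0, len(pos), size=num_patches)]
```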
## list_data_collate
```python
def list_data_collate(batch: Sequence) -> dict
```
Custom collation function that stacks per-patient patch lists into batch tensors.
Each sample from the dataset is a list of N patch dictionaries. This function:
1. Stacks `image` across patches: `[N, C, D, H, W]` per sample
2. Stacks `final_heatmap` if present
3. Applies PyTorch's `default_collate` to form the batch dimension
Result: `{"image": [B, N, C, D, H, W], "label": [B], ...}`
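The two-level stacking above can be sketched with NumPy standing in for PyTorch tensors (the real function uses `torch.stack` and `default_collate`; this is an illustrative stand-in, not the actual implementation):

```python
import numpy as np

def list_data_collate_sketch(batch):
    """Sketch of the collate logic above: each element of `batch` is a
    list of N patch dicts. Stack arrays across patches per sample, then
    stack samples into the batch dimension."""
    stacked = []
    for sample in batch:  # sample: list of N patch dicts
        out = {}
        for key in sample[0]:
            vals = [patch[key] for patch in sample]
            # Arrays get a leading N (patch) axis; scalars (e.g. label)
            # are shared across patches, so keep one copy per sample
            out[key] = np.stack(vals) if isinstance(vals[0], np.ndarray) else vals[0]
        stacked.append(out)
    return {
        key: (np.stack([s[key] for s in stacked])
              if isinstance(stacked[0][key], np.ndarray)
              else np.asarray([s[key] for s in stacked]))
        for key in stacked[0]
    }
```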
## Custom Transforms
### ClipMaskIntensityPercentilesd
```python
ClipMaskIntensityPercentilesd(
keys: KeysCollection,
mask_key: str,
lower: float | None,
upper: float | None,
sharpness_factor: float | None = None,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
)
```
Clips image intensity to percentiles computed only from the **masked region**. Supports both hard clipping (default) and soft clipping (via `sharpness_factor`).
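The core idea, stripped of the MONAI transform machinery, is that the percentile bounds come from masked voxels only while the clip is applied to the whole image. A minimal sketch of the hard-clipping path (the soft-clipping variant via `sharpness_factor` is not shown):

```python
import numpy as np

def clip_to_mask_percentiles(img, mask, lower=0.0, upper=99.5):
    """Sketch of the transform's core: compute percentile bounds from
    masked voxels only, then hard-clip the whole image to that range."""
    masked_vals = img[mask > 0]
    lo = np.percentile(masked_vals, lower) if lower is not None else None
    hi = np.percentile(masked_vals, upper) if upper is not None else None
    return np.clip(img, lo, hi)
```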
### NormalizeIntensity_customd
```python
NormalizeIntensity_customd(
keys: KeysCollection,
mask_key: str,
subtrahend: NdarrayOrTensor | None = None,
divisor: NdarrayOrTensor | None = None,
nonzero: bool = False,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
)
```
Z-score normalization where mean and standard deviation are computed only from **masked voxels**. Supports channel-wise normalization.
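The single-channel case reduces to a few lines: statistics from the masked region, normalization applied everywhere. A minimal sketch (the real transform also supports `subtrahend`/`divisor` overrides and channel-wise statistics):

```python
import numpy as np

def masked_zscore(img, mask):
    """Sketch of mask-restricted z-score normalization: mean and std come
    from masked voxels only, but the whole image is normalized."""
    vals = img[mask > 0]
    return (img - vals.mean()) / vals.std()
```

After normalization, the masked region has zero mean and unit variance by construction; voxels outside the mask are shifted and scaled by the same statistics.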
### ElementwiseProductd
```python
ElementwiseProductd(
keys: KeysCollection,
output_key: str,
)
```
Computes the element-wise product of two arrays from the data dictionary and stores the result in `output_key`. Used to combine the prostate mask with the attention heatmap.
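The effect on the data dictionary can be sketched in a few lines (an illustrative stand-in for the dictionary-transform, not the MONAI class itself):

```python
import numpy as np

def elementwise_product(data, keys, output_key):
    """Sketch of the transform's effect: multiply the two named arrays
    and store the product under output_key, leaving the inputs intact."""
    a, b = (data[k] for k in keys)
    data = dict(data)  # shallow copy; don't mutate the caller's dict
    data[output_key] = a * b
    return data
```

Multiplying the binary prostate mask by the heatmap zeroes out attention weights outside the prostate, so the subsequent weighted crop never samples outside the organ.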
## Dataset JSON Format
The pipeline expects a MONAI decathlon-format JSON file:
```json
{
"train": [
{
"image": "relative/path/to/t2.nrrd",
"dwi": "relative/path/to/dwi.nrrd",
"adc": "relative/path/to/adc.nrrd",
"mask": "relative/path/to/mask.nrrd",
"heatmap": "relative/path/to/heatmap.nrrd",
"label": 2
}
],
"test": [...]
}
```
Paths are relative to `data_root`. The `heatmap` key is only required when `use_heatmap=True`.
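A simple sanity check over this format can catch missing keys before the (slow) first caching epoch. `validate_datalist` is a hypothetical helper, not part of the pipeline; it assumes the key set documented above:

```python
REQUIRED_KEYS = {"image", "dwi", "adc", "mask", "label"}

def validate_datalist(datalist, use_heatmap):
    """Hypothetical sanity check for the JSON format above: every entry
    must carry the modality paths and label; `heatmap` is required only
    when use_heatmap=True."""
    required = REQUIRED_KEYS | ({"heatmap"} if use_heatmap else set())
    for i, item in enumerate(datalist):
        missing = required - item.keys()
        if missing:
            raise ValueError(f"entry {i} missing keys: {sorted(missing)}")
```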