Spaces:
Running
Running
File size: 6,137 Bytes
caf6ee7 906fcb9 caf6ee7 906fcb9 caf6ee7 906fcb9 1baebae 906fcb9 caf6ee7 906fcb9 caf6ee7 906fcb9 caf6ee7 906fcb9 caf6ee7 906fcb9 1baebae 906fcb9 caf6ee7 906fcb9 95dc457 6f43d62 95dc457 6f43d62 95dc457 6f43d62 1baebae 95dc457 caf6ee7 906fcb9 1baebae 906fcb9 caf6ee7 906fcb9 1baebae caf6ee7 1baebae 906fcb9 1baebae 906fcb9 1baebae 906fcb9 1baebae 906fcb9 1baebae caf6ee7 1baebae 906fcb9 1baebae 906fcb9 1baebae 906fcb9 1baebae 906fcb9 caf6ee7 6f43d62 95dc457 6f43d62 906fcb9 1baebae 906fcb9 1baebae 906fcb9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import argparse
import os
from typing import Literal
import numpy as np
import torch
from monai.data import PersistentDataset, load_decathlon_datalist
from monai.transforms import (
Compose,
ConcatItemsd,
DeleteItemsd,
EnsureTyped,
LoadImaged,
NormalizeIntensityd,
RandCropByPosNegLabeld,
RandWeightedCropd,
ToTensord,
Transform,
Transposed,
)
from torch.utils.data.dataloader import default_collate
from .custom_transforms import (
ClipMaskIntensityPercentilesd,
ElementwiseProductd,
NormalizeIntensity_customd,
)
class DummyMILDataset(torch.utils.data.Dataset):
    """Synthetic multiple-instance dataset used for dry runs.

    Every sample is a "bag": a list of ``args.tile_count`` dictionaries, each
    holding random tensors. No files are read.

    Args:
        args: Namespace-like object providing ``tile_count``, ``depth``,
            ``tile_size`` and ``use_heatmap``.
        num_samples: Number of bags the dataset reports via ``len()``.
    """

    def __init__(self, args, num_samples=8):
        self.num_samples = num_samples
        self.args = args

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, index):
        """Return one synthetic bag.

        Returns:
            A list of ``args.tile_count`` dicts with an ``"image"`` tensor of
            shape ``(3, depth, tile_size, tile_size)``, a scalar float32
            ``"label"`` (``float(index % 2)``), and, when ``args.use_heatmap``
            is truthy, a ``"final_heatmap"`` tensor of shape
            ``(1, depth, tile_size, tile_size)``.

        Raises:
            IndexError: If ``index`` lies outside
                ``[-num_samples, num_samples)``.
        """
        # Without this check, plain iteration (``for bag in ds`` / ``list(ds)``)
        # would never stop: the sequence-protocol fallback relies on
        # ``IndexError`` to signal the end of the data.
        if not -self.num_samples <= index < self.num_samples:
            raise IndexError(
                f"index {index} is out of range for dataset of size {self.num_samples}"
            )

        args = self.args
        label_value = float(index % 2)  # alternate 0.0 / 1.0 labels by index parity
        patch_shape = (args.depth, args.tile_size, args.tile_size)

        bag = []
        for _ in range(args.tile_count):
            item = {
                "image": torch.randn(3, *patch_shape),
                "label": torch.tensor(label_value, dtype=torch.float32),
            }
            if args.use_heatmap:
                item["final_heatmap"] = torch.randn(1, *patch_shape)
            bag.append(item)
        return bag
def list_data_collate(batch: list):
    """Collate a batch of bags (lists of instance dicts) into one batched dict.

    For every bag, the ``"image"`` tensors of its instances are stacked along
    a new leading dimension; ``"final_heatmap"`` is stacked the same way when
    every instance in the bag carries it. All other keys are taken from the
    bag's first instance. The per-bag dicts are then passed to
    ``default_collate``, which adds the batch dimension, e.g.
    ``[{'image': CxDxHxW}, ...] * N`` per bag -> ``{'image': BxNxCxDxHxW}``.

    The input is left untouched: neither the ``batch`` list nor the instance
    dicts inside it are modified.

    Args:
        batch: List of bags; each bag is a non-empty list of dicts that all
            contain an ``"image"`` tensor of identical shape.

    Returns:
        The result of ``default_collate`` over the per-bag merged dicts.

    Raises:
        IndexError: If a bag is empty.
    """
    merged_bags = []
    for bag in batch:
        # Shallow copy so the caller's first instance dict is not overwritten.
        merged = dict(bag[0])
        merged["image"] = torch.stack([inst["image"] for inst in bag], dim=0)
        if all("final_heatmap" in inst for inst in bag):
            merged["final_heatmap"] = torch.stack(
                [inst["final_heatmap"] for inst in bag], dim=0
            )
        merged_bags.append(merged)
    return default_collate(merged_bags)
def data_transform(args: argparse.Namespace) -> Transform:
    """Build the preprocessing and patch-sampling pipeline.

    Two variants exist, selected by ``args.use_heatmap``:

    * heatmap mode: additionally loads ``"heatmap"``, normalises with
      ``NormalizeIntensity_customd`` (given ``mask_key="mask"``), combines
      ``"mask"`` and ``"heatmap"`` into ``"final_heatmap"`` via
      ``ElementwiseProductd`` and samples patches with ``RandWeightedCropd``
      weighted by ``"final_heatmap"``.
    * plain mode: normalises with ``NormalizeIntensityd`` and samples patches
      with ``RandCropByPosNegLabeld`` using ``"mask"`` as label
      (``pos=1, neg=0``).

    Both variants draw ``args.tile_count`` crops of spatial size
    ``(args.tile_size, args.tile_size, args.depth)``.

    Args:
        args: Parsed options; reads ``use_heatmap``, ``tile_size``, ``depth``
            and ``tile_count``.

    Returns:
        The composed transform.
    """
    heatmap_mode = args.use_heatmap
    patch_size = (args.tile_size, args.tile_size, args.depth)

    # Keys that only exist in heatmap mode (raw input / derived output).
    heatmap_inputs = ["heatmap"] if heatmap_mode else []
    heatmap_outputs = ["final_heatmap"] if heatmap_mode else []

    # --- shared front end: load, clip, merge modalities -------------------
    pipeline = [
        LoadImaged(
            keys=["image", "mask", "dwi", "adc"] + heatmap_inputs,
            reader="ITKReader",
            ensure_channel_first=True,
            dtype=np.float32,
        ),
        ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
        # Concatenates the three modalities along dim 0 into "image"
        # (presumably one channel each, giving 3 channels -- TODO confirm).
        ConcatItemsd(keys=["image", "dwi", "adc"], name="image", dim=0),
    ]

    # --- mode-specific normalisation and patch sampling --------------------
    if heatmap_mode:
        pipeline.append(
            NormalizeIntensity_customd(keys=["image"], channel_wise=True, mask_key="mask")
        )
        pipeline.append(
            ElementwiseProductd(keys=["mask", "heatmap"], output_key="final_heatmap")
        )
        pipeline.append(
            RandWeightedCropd(
                keys=["image", "final_heatmap"],
                w_key="final_heatmap",
                spatial_size=patch_size,
                num_samples=args.tile_count,
            )
        )
    else:
        pipeline.append(NormalizeIntensityd(keys=["image"], channel_wise=True))
        pipeline.append(
            RandCropByPosNegLabeld(
                keys=["image"],
                label_key="mask",
                spatial_size=patch_size,
                pos=1,
                neg=0,
                num_samples=args.tile_count,
            )
        )

    # --- shared back end: typing, axis reorder, cleanup --------------------
    # NOTE(review): only "image" is transposed; "final_heatmap" keeps the crop's
    # axis order, whereas DummyMILDataset emits it as (1, depth, H, W) -- confirm
    # which layout downstream code expects.
    pipeline.extend(
        [
            EnsureTyped(keys=["label"], dtype=torch.float32),
            Transposed(keys=["image"], indices=(0, 3, 1, 2)),
            DeleteItemsd(keys=["mask", "dwi", "adc"] + heatmap_inputs),
            ToTensord(keys=["image", "label"] + heatmap_outputs),
        ]
    )
    return Compose(pipeline)
def get_dataloader(
    args: argparse.Namespace, split: Literal["train", "test"]
) -> torch.utils.data.DataLoader:
    """Create the ``DataLoader`` for the requested split.

    With ``args.dry_run`` set, a loader over ``DummyMILDataset`` holding
    ``2 * args.batch_size`` synthetic bags is returned and nothing is read
    from disk. Otherwise the datalist for ``split`` is loaded from
    ``args.dataset_json`` (relative to ``args.data_root``) and wrapped in a
    ``PersistentDataset`` whose cache directory is
    ``<args.logdir>/cache/<split>`` (created if missing).

    Args:
        args: Parsed options; reads ``dry_run``, ``batch_size``,
            ``dataset_json``, ``data_root``, ``logdir`` and ``workers`` (plus
            whatever ``data_transform`` / ``DummyMILDataset`` read).
        split: Datalist key to load; the loader shuffles only for ``"train"``.

    Returns:
        A ``DataLoader`` that collates bags with ``list_data_collate``.
    """
    if args.dry_run:
        print(f"🛠️ DRY RUN: Creating synthetic {split} dataloader...")
        synthetic_ds = DummyMILDataset(args, num_samples=args.batch_size * 2)
        return torch.utils.data.DataLoader(
            synthetic_ds,
            batch_size=args.batch_size,
            collate_fn=list_data_collate,
            num_workers=0,  # single-process loading is enough for a dry run
        )

    records = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key=split,
        base_dir=args.data_root,
    )

    # One cache directory per split, below the log directory.
    split_cache_dir = os.path.join(args.logdir, "cache", split)
    os.makedirs(split_cache_dir, exist_ok=True)

    dataset = PersistentDataset(
        data=records,
        transform=data_transform(args),
        cache_dir=split_cache_dir,
    )

    is_training = split == "train"
    if args.workers > 0:
        mp_context = "fork"
    else:
        mp_context = None

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=is_training,
        num_workers=args.workers,
        pin_memory=True,
        multiprocessing_context=mp_context,
        sampler=None,
        collate_fn=list_data_collate,
    )
|