Omini3D / Dataloader /dataloader0.py
maxmo2009's picture
Initial upload: OmniMorph codebase
75854b3 verified
Raw
History Blame Contribute Delete
15.4 kB
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import nibabel as nib
from skimage.transform import rescale, resize, downscale_local_mean
from scipy.ndimage import zoom
import numpy as np
# import SimpleITK as sitk
# print(os.getcwd())
import sys
sys.path.append('./')
from Dataloader.dataloader_utils import *
EPS = 1e-7
def get_dataloader(data_name='cmr',mode='train'):
if data_name=='cmr':
if mode=='train':
dataloader=CMR_loader
elif mode =='aug':
dataloader=CMR_tgt_loader
else:
print('mode not exist')
elif data_name=='lct':
if mode=='train':
dataloader=LCT_loader
elif mode =='aug':
dataloader=LCT_tgt_loader
else:
print('mode not exist')
else:
print('dataloader not exist')
return dataloader
class LCT_loader(Dataset):
def __init__(self, data_root_path = f'Data/Src_data/CTLung_processed/', target_res = (256, 256),transforms = None, noise_scale=0.0, patient_index = None):
# def __init__(self, data_root_path = '/home/data/jzheng/CTLung_processed/', target_res = (256, 256),transforms = None, noise_scale=0.0, patient_index = None):
self.files = [data_root_path + f for f in os.listdir(data_root_path) if f.endswith('.npy')]
self.transforms = transforms
self.noise_scale=noise_scale
self.d_p = data_root_path
def __getitem__(self, item):
array = np.load(self.files[item])
if 'process' not in self.d_p:
array = (array - array.min()) / (array.max() - array.min() + EPS) # Normalize to 0 to 1
array = array[None,:,:,:] # add a channel to array make it (‘C’,H,W,Z)
if self.transforms != None:
array = self.transforms(array)
# print(array.shape)
return array, array, item # -> (B, C, H, W, Z)
# return array, array # -> (B, C, H, W, Z)
def __len__(self):
return len(self.files)
class LCT_tgt_loader(Dataset):
def __init__(self, data_root_path = "Data/Tgt_data/lct/",noise_scale=0.0, patient_index = None):
self.files_gt = [data_root_path + "Gt/" + f for f in os.listdir(data_root_path + "Gt/")]
self.files_tr = [data_root_path + 'Tr/' + f for f in os.listdir(data_root_path + "Tr/")]
self.files_tr.sort()
self.files_gt.sort()
self.transforms = transforms
self.noise_scale=noise_scale
def __getitem__(self, item):
img_nib = nib.load(self.files_tr[item])
mask_nib = nib.load(self.files_gt[item])
image = img_nib.get_fdata()
mask = mask_nib.get_fdata()
image = image[None,:,:,:]
mask = mask[None,:,:,:]
print(self.files_tr[item],self.files_gt[item])
return image, mask, item
def __len__(self):
assert len(self.files_gt) == len(self.files_tr)
return len(self.files_gt)
class LCT_seg(Dataset):
def __init__(self, data_root_path = "/home/data/jzheng/CTLung_processed/testset/modality_0001/",noise_scale=0.0, patient_index = None):
self.files_gt = [data_root_path + "Gt/" + f for f in os.listdir(data_root_path + "Gt/")]
self.files_tr = [data_root_path + 'Tr/' + f for f in os.listdir(data_root_path + "Tr/")]
self.files_tr.sort()
self.files_gt.sort()
self.transforms = transforms
self.noise_scale=noise_scale
def __getitem__(self, item):
img_nib = nib.load(self.files_tr[item])
mask_nib = nib.load(self.files_gt[item])
image = img_nib.get_fdata()
mask = mask_nib.get_fdata()
image = image[None,:,:,:]
mask = mask[None,:,:,:]
print(self.files_tr[item],self.files_gt[item])
return image, mask, item
def __len__(self):
assert len(self.files_gt) == len(self.files_tr)
return len(self.files_gt)
class CMR_loader_preprocess(Dataset):
# This is for pre_processing for CMR. not use for training model
def __init__(self, data_path = 'Data/CTLung_processed/', target_res = (256, 256), transforms = None, noise_scale=0.0):
# def __init__(self, data_path = '/home/data/jzheng/CMR_processed/', target_res = (256, 256), transforms = None, noise_scale=0.0):
self.d_p = data_path
self.target_res = target_res
self.files = [self.d_p + x for x in os.listdir(self.d_p)]
self.transforms = transforms
self.noise_scale=noise_scale
def __getitem__(self, item):
array = nib.load(self.files[item]).get_fdata()
array = resize(array, self.target_res, anti_aliasing = True, preserve_range = True)
array = array[None, :, :]
array = remove_background(array) # jzheng 20240228
array = (array - array.min()) / (array.max() - array.min() + EPS)
if self.noise_scale > 0:
array = thresh_img(array,[0,self.noise_scale])
array = array * (np.random.normal(1, self.noise_scale*2))
if self.transforms != None:
array = self.transforms(array)
return array, self.files[item]
def __len__(self):
return len(self.files)
class CMR_loader(Dataset):
# niff format size is (H,W) for CMR
# CMR_processed_rmbg_resize means the niif image has been gone throught rmbg and resize offline to make trainig fast
def __init__(self, data_path = f'Data/Src_data/CMR_processed_rmbg_resize/', target_res = (256, 256), transforms = None, noise_scale=0.0):
# def __init__(self, data_path = '/home/data/jzheng/CMR_processed_rmbg_resize/', target_res = (256, 256), transforms = None, noise_scale=0.0):
self.d_p = data_path
self.ndims = 2
self.target_res = target_res
self.files = [self.d_p + x for x in os.listdir(self.d_p)]
self.transforms = transforms
# self.get_transform()
self.noise_scale=noise_scale
self.preprocessed='resize' in data_path
def __getitem__(self, item):
array = nib.load(self.files[item]).get_fdata()
if not self.preprocessed:
array = resize(array, self.target_res, anti_aliasing = True, preserve_range = True)
array = array[None, :, :]
if not self.preprocessed:
array = remove_background(array) # jzheng 20240228
array = (array - array.min()) / (array.max() - array.min() + EPS)
# if self.noise_scale > 0:
# array = thresh_img(array,[0,self.noise_scale])
# array = array * (np.random.normal(1, self.noise_scale*2)) + np.random.normal(0, self.noise_scale*2)
if self.transforms != None:
array = self.transforms(array)
return array, array, item
def __len__(self):
return len(self.files)
def get_transform(self,degrees=np.pi,translate=0.125):
# self.transforms = torchvision.transforms.RandomAffine(degrees=degrees,translate=[translate]*self.ndims,interpolation=torchvision.transforms.InterpolationMode.BILINEAR)
self.transforms = torchvision.transforms.Compose([
# torchvision.transforms.Resize((hyp_parameters['img_size'], hyp_parameters['img_size'])),
torchvision.transforms.ToTensor(),
torchvision.transforms.RandomAffine(degrees=degrees,translate=[translate]*self.ndims,interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
# torchvision.transforms.ToTensor(),
# torchvision.transforms.Normalize(0.5, 0.5)
# Lambda(lambda x: (x - 0.5) * 2)
])
return
class CMR_tgt_loader(Dataset):
def __init__(self,
data_path = 'Data/Tgt_data/cmr/',
# gt_path = '/home/data/jzheng/acdc/train_gt/',
target_res = (256,256),
is_3d = False,
patient_index = [],
):
# parameter initialize
self.d_p = os.path.join(data_path,'Tr','')
self.gt_p = os.path.join(data_path,'Gt','')
self.img_files = os.listdir(self.d_p)
self.gt_files = os.listdir(self.gt_p)
self.p_indice = patient_index
self.target_res_2d = target_res
self.img_files.sort()
self.gt_files.sort()
self.img_samples = []
self.gt_samples = []
self.p_id = []
if len(self.p_indice) == 0:
self.p_indice = [x for x in range(1,101)]
# build patient-to-file correspondence
p2f = {}
assert len(self.gt_files) == len(self.img_files)
print(self.p_indice)
for i in self.p_indice:
for gt_f, img_f in zip(self.gt_files, self.img_files):
pf_id = gt_f.split('_')[0]
pf_id = pf_id[-3:]
if i == int(pf_id):
img_volume = nib.load(self.d_p + img_f).get_fdata()
gt_volume = nib.load(self.gt_p + gt_f).get_fdata()
assert img_volume.shape == gt_volume.shape
depth = img_volume.shape[2]
for si in range(depth):
img = resize(img_volume[:, :, si], self.target_res_2d, anti_aliasing=True, preserve_range=True)
img = (img - img.min()) / (img.max() - img.min() + EPS)
gt = gt_volume[:, :, si]
gt_1_index = gt == 1
gt_2_index = gt == 2
gt_3_index = gt == 3
gt_4_index = gt == 4
gt_1 = gt * gt_1_index
gt_2 = gt * gt_2_index
gt_3 = gt * gt_3_index
gt_4 = gt * gt_4_index
gt_1 = resize(gt_1, self.target_res_2d, anti_aliasing=True, preserve_range=True)
gt_2 = resize(gt_2, self.target_res_2d, anti_aliasing=True, preserve_range=True)
gt_3 = resize(gt_3, self.target_res_2d, anti_aliasing=True, preserve_range=True)
gt_4 = resize(gt_4, self.target_res_2d, anti_aliasing=True, preserve_range=True)
self.img_samples.append(img[np.newaxis, :, :])
self.gt_samples.append(np.array([gt_1, gt_2, gt_3, gt_4]))
self.p_id.append(i)
def __getitem__(self, item):
return self.img_samples[item], self.gt_samples[item], self.p_id[item]
def __len__(self):
assert len(self.img_samples) == len(self.gt_samples)
return len(self.img_samples)
class acdc_seg(Dataset):
def __init__(self,
data_path = '/home/data/jzheng/acdc/train_images/',
gt_path = '/home/data/jzheng/acdc/train_gt/',
target_res = (256,256),
is_3d = False,
patient_index = [],
):
# parameter initialize
self.d_p = data_path
self.gt_p = gt_path
self.img_files = os.listdir(self.d_p)
self.gt_files = os.listdir(self.gt_p)
self.p_indice = patient_index
self.target_res_2d = target_res
self.img_files.sort()
self.gt_files.sort()
self.img_samples = []
self.gt_samples = []
self.p_id = []
if len(self.p_indice) == 0:
self.p_indice = [x for x in range(1,101)]
# build patient-to-file correspondence
p2f = {}
assert len(self.gt_files) == len(self.img_files)
print(self.p_indice)
for i in self.p_indice:
for gt_f, img_f in zip(self.gt_files, self.img_files):
pf_id = gt_f.split('_')[0]
pf_id = pf_id[-3:]
if i == int(pf_id):
img_volume = nib.load(self.d_p + img_f).get_fdata()
gt_volume = nib.load(self.gt_p + gt_f).get_fdata()
assert img_volume.shape == gt_volume.shape
depth = img_volume.shape[2]
for si in range(depth):
img = resize(img_volume[:, :, si], self.target_res_2d, anti_aliasing=True, preserve_range=True)
img = (img - img.min()) / (img.max() - img.min() + EPS)
gt = gt_volume[:, :, si]
gt_1_index = gt == 1
gt_2_index = gt == 2
gt_3_index = gt == 3
gt_4_index = gt == 4
gt_1 = gt * gt_1_index
gt_2 = gt * gt_2_index
gt_3 = gt * gt_3_index
gt_4 = gt * gt_4_index
gt_1 = resize(gt_1, self.target_res_2d, anti_aliasing=True, preserve_range=True)
gt_2 = resize(gt_2, self.target_res_2d, anti_aliasing=True, preserve_range=True)
gt_3 = resize(gt_3, self.target_res_2d, anti_aliasing=True, preserve_range=True)
gt_4 = resize(gt_4, self.target_res_2d, anti_aliasing=True, preserve_range=True)
self.img_samples.append(img[np.newaxis, :, :])
self.gt_samples.append(np.array([gt_1, gt_2, gt_3, gt_4]))
self.p_id.append(i)
def __getitem__(self, item):
return self.img_samples[item], self.gt_samples[item], self.p_id[item]
def __len__(self):
assert len(self.img_samples) == len(self.gt_samples)
return len(self.img_samples)
class acdc_gan(Dataset):
def __init__(self,
train_path = '/home/data/jzheng/acdc/images/',
target_res = (32, 256, 256),
is_3d = False,
transforms = None
):
self.t_p = train_path
self.files = os.listdir(self.t_p)
self.sample_list_2d = []
self.is_3d = is_3d
self.target_res = target_res
self.res_2d = (target_res[1], target_res[2])
self.transforms = transforms
if self.is_3d == False:
for f in self.files:
img = nib.load(self.t_p + f).get_fdata()
depth = img.shape[2]
f_i = int(round(depth*0.1))
b_i = int(round(depth*0.9))
interval_slice = img[:, :, f_i:b_i]
for ii in range(interval_slice.shape[2]):
single_slice = interval_slice[:,:,ii]
single_slice = resize(single_slice, self.res_2d, anti_aliasing=True, preserve_range=True)
single_slice = (single_slice - single_slice.min()) / ( single_slice.max() - single_slice.min() + EPS)
self.sample_list_2d.append(single_slice[None,:,:])
def __len__(self):
if self.is_3d == False:
return len(self.sample_list_2d)
else:
return len(self.files )
def __getitem__(self, index):
if self.is_3d == False:
return self.sample_list_2d[index], self.sample_list_2d[index]
for f in self.files:
img = nib.load(self.t_p + f).get_fdata()
target_d_ratio = self.target_res[0] / img.shape[2]
target_w_ratio = self.target_res[1] / img.shape[0]
target_h_ratio = self.target_res[2] / img.shape[1]
resize_img = zoom(img, (target_w_ratio, target_h_ratio, target_d_ratio))
resize_img = np.swapaxes(resize_img, 0, 2)
resize_img = np.swapaxes(resize_img, 1, 2)
resize_img = (resize_img - resize_img.min()) / (resize_img.max() - resize_img.min() + EPS)
if transforms != None:
resize_img = self.transforms(resize_img)
return resize_img, resize_img
class acdc_gan_single_slice(Dataset):
def __init__(self, train_path = '/well/papiez/shared/ACDC/clean_training/images/'):
self.t_p = train_path
self.files = os.listdir(self.t_p)
def __len__(self):
return len(self.files)
def __getitem__(self, index):
img = self.files[index]
img = nib.load(self.t_p + img).get_fdata()
depth = img.shape[2]
mid_d = int(depth/2)
mid_slice = img[:,:,mid_d]
mid_slice = resize(mid_slice, (128, 128), anti_aliasing=True, preserve_range=True)
mid_slice = (mid_slice-mid_slice.min())/(mid_slice.max()-mid_slice.min()+EPS)
# print(mid_slice.max(),mid_slice.min())
return mid_slice, mid_slice