| | from cmath import nan |
| | import csv |
| | import json |
| | import logging |
| | import os |
| | import sys |
| | import pydicom |
| |
|
| | from abc import abstractmethod |
| | from itertools import islice |
| | from typing import List, Tuple, Dict, Any |
| | from torch.utils.data import DataLoader |
| | import PIL |
| | from torch.utils.data import Dataset |
| | import numpy as np |
| | import pandas as pd |
| | from torchvision import transforms |
| | from PIL import Image |
| | from skimage import exposure |
| | import torch |
| |
|
| | |
| | from torchvision.transforms import InterpolationMode |
| |
|
| |
|
class RSNA2018_Dataset(Dataset):
    """RSNA 2018 dataset yielding images, class labels and box masks.

    The index CSV is read positionally:
      * column 1 is the DICOM file path;
      * column 2 is the bounding-box string;
      * column 3 is the class label.

    For rows whose label equals 1, the box string (``"x;y;w;h"`` entries
    joined by ``"|"``) is rasterised into a binary segmentation mask.
    """

    def __init__(self, csv_path):
        """Load the index CSV at *csv_path* and build the transforms."""
        data_info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(data_info.iloc[:, 1])
        self.class_list = np.asarray(data_info.iloc[:, 3])
        # Presumably empty/NaN for negative rows; only read when label == 1.
        self.bbox = np.asarray(data_info.iloc[:, 2])
        # Standard ImageNet per-channel mean / std.
        normalize = transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        )

        self.transform = transforms.Compose(
            [transforms.Resize([224, 224]), transforms.ToTensor(), normalize]
        )
        # Nearest-neighbour resizing keeps the 0/1 mask strictly binary.
        # NOTE: the misspelled attribute name is kept for existing callers.
        self.seg_transfrom = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize(
                    [224, 224], interpolation=InterpolationMode.NEAREST
                ),
            ]
        )

    def __getitem__(self, index):
        """Return the sample at *index* as a dict.

        Keys:
          * ``"image"``: the RGB image passed through ``self.transform``
            (resized to 224x224, tensorised, normalised).
          * ``"label"``: a 1-element numpy array holding the class value.
          * ``"image_path"``: the DICOM path from the CSV.
          * ``"seg_map"``: the box mask passed through
            ``self.seg_transfrom`` (resized to 224x224, nearest-neighbour).
        """
        img_path = self.img_path_list[index]
        class_label = np.array([self.class_list[index]])

        img = self.read_dcm(img_path)
        image = self.transform(img)

        # Rasterise the boxes at the source resolution so they line up with
        # the image before both are resized to 224x224. PIL reports (W, H).
        width, height = img.size
        if self.class_list[index] == 1:
            seg_map = self._boxes_to_mask(self.bbox[index], height, width)
        else:
            seg_map = np.zeros((height, width))
        seg_map = self.seg_transfrom(seg_map)
        return {
            "image": image,
            "label": class_label,
            "image_path": img_path,
            "seg_map": seg_map,
        }

    @staticmethod
    def _boxes_to_mask(bbox, height, width):
        """Rasterise a ``"x;y;w;h|x;y;w;h..."`` string into a float mask.

        Returns a ``(height, width)`` float64 array with 1 inside each box and
        0 elsewhere.
          * Coordinates are truncated to integers.
          * Empty segments (e.g. from a trailing ``"|"``) are skipped.
          * Negative box edges are clamped to 0, so they do not wrap around
            through negative indexing.
        """
        mask = np.zeros((height, width))
        for box in bbox.split("|"):
            if not box.strip():
                continue
            x, y, w, h = (int(float(v)) for v in box.split(";")[:4])
            mask[max(y, 0) : max(y + h, 0), max(x, 0) : max(x + w, 0)] = 1
        return mask

    def read_dcm(self, dcm_path):
        """Load a DICOM file as a histogram-equalised 8-bit RGB PIL image."""
        # ``pydicom.read_file`` was deprecated and removed in pydicom 3.0;
        # ``dcmread`` is the supported equivalent.
        dcm_data = pydicom.dcmread(dcm_path)
        img = dcm_data.pixel_array.astype(float) / 255.0
        # equalize_hist returns floats in [0, 1]; rescale to the uint8 range.
        img = exposure.equalize_hist(img)

        img = (255 * img).astype(np.uint8)
        img = PIL.Image.fromarray(img).convert("RGB")
        return img

    def __len__(self):
        """Return the number of rows in the index CSV."""
        return len(self.img_path_list)
| |
|
| |
|
def create_loader_RSNA(
    datasets, samplers, batch_size, num_workers, is_trains, collate_fns
):
    """Build one ``DataLoader`` per dataset from parallel per-loader settings.

    Every argument is a sequence with one entry per loader. They are walked
    in lockstep with ``zip``, so the result is as long as the shortest
    sequence.

    Training entries (truthy ``is_trains`` item):
      * shuffle only when no sampler is supplied;
      * drop the final incomplete batch.

    Evaluation entries:
      * never shuffle;
      * keep every sample.

    All loaders use ``pin_memory=True``.

    Returns:
        A list of ``DataLoader`` objects in input order.
    """

    def _build(dataset, sampler, bs, workers, training, collate):
        # DataLoader rejects an explicit sampler combined with shuffle=True,
        # so a supplied sampler takes over ordering for training loaders.
        return DataLoader(
            dataset,
            batch_size=bs,
            num_workers=workers,
            pin_memory=True,
            sampler=sampler,
            shuffle=(sampler is None) if training else False,
            collate_fn=collate,
            drop_last=bool(training),
        )

    settings = zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    )
    return [_build(*cfg) for cfg in settings]
| |
|