| | import os |
| | import json |
| | import pydicom |
| | import numpy as np |
| | import torch |
| |
|
| | from typing import Callable, Optional, Tuple |
| |
|
| | from torch import Tensor |
| | from torch.utils.data import Dataset |
| | from sklearn.preprocessing import RobustScaler |
| |
|
# dtype of the sample-weight, label and target tensors built by SyntaxDataset.
DTYPE = torch.float16
| |
|
| |
|
class SyntaxDataset(Dataset):
    """Per-study DICOM video dataset driven by a JSON metadata file.

    The metadata file holds a list of records. Each record is expected to
    provide ``"study_uid"``, a numeric score under the key named by ``label``
    and a list ``"videos_<artery>"`` whose entries carry a ``"path"`` to a
    DICOM file (absolute, or relative to ``root``).

    Args:
        root: Base directory for relative metadata and video paths.
        meta: Path of the JSON metadata file (absolute or relative to
            ``root``).
        train: If true, clips are cropped at a random temporal offset;
            otherwise the centre ``length`` frames are taken.
        length: Number of frames in every returned clip.
        label: Record key holding the score. A falsy value (e.g. ``""``)
            makes ``__getitem__`` return zero label/target tensors.
        artery: ``"left"`` or ``"right"``; validated case-insensitively, but
            the ``videos_<artery>`` record key is built from the argument
            exactly as given.
        inference: If true, records without videos are kept and every video
            of a study is returned instead of a random draw.
        validation: If true, only records whose score is > 0 are kept.
        transform: Optional callable applied to each clip tensor.

    Raises:
        ValueError: If ``artery`` is neither "left" nor "right".
    """

    # Artery name -> index into the threshold tables below.
    _ARTERY_BINS = {"left": 0, "right": 1}
    # Score bin edges used for sample weighting, per artery index.
    _WEIGHT_BIN_EDGES = {
        0: (0, 5, 10, 15),
        1: (0, 2, 5, 8),
    }
    # Score above which a study gets binary label 1, per artery index.
    _POSITIVE_THRESHOLD = {
        0: 15,
        1: 5,
    }
    # Number of videos drawn (with replacement) per study outside inference.
    _VIDEOS_PER_SAMPLE = 4

    def __init__(
        self,
        root: str,
        meta: str,
        train: bool,
        length: int,
        label: str,
        artery: str,
        inference: bool = False,
        validation: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()

        # Validate before the artery name is used to build record keys, so a
        # bad name is reported as ValueError rather than as a KeyError coming
        # from the record filter below.
        artery_bin = self._ARTERY_BINS.get(artery.lower())
        if artery_bin is None:
            raise ValueError(f"Unknown artery '{artery}'")

        self.root = root
        self.train = train
        self.length = length
        self.label = label
        self.artery = artery
        self.artery_bin = artery_bin
        self.inference = inference
        self.transform = transform
        self.validation = validation

        meta_path = meta if os.path.isabs(meta) else os.path.join(root, meta)
        with open(meta_path, encoding="utf-8") as f:
            dataset = json.load(f)

        if not self.inference:
            # Outside inference every sample must have a video to draw from.
            videos_key = f"videos_{artery}"
            dataset = [rec for rec in dataset if len(rec[videos_key]) > 0]

        if validation:
            dataset = [rec for rec in dataset if rec[self.label] > 0]

        self.dataset = dataset

    def __len__(self) -> int:
        """Return the number of records kept after filtering."""
        return len(self.dataset)

    @staticmethod
    def _bin_index(score, edges) -> int:
        """Map a score to its weighting bin (0-4) for the given four edges.

        Bin 0 is ``score == edges[0]``; bins 1-3 are the half-open ranges
        ``(edges[i-1], edges[i]]``; everything else lands in bin 4.
        """
        thr0, thr1, thr2, thr3 = edges
        if score == thr0:
            return 0
        if thr0 < score <= thr1:
            return 1
        if thr1 < score <= thr2:
            return 2
        if thr2 < score <= thr3:
            return 3
        return 4

    def get_sample_weights(self) -> Tensor:
        """Compute one sampling weight per record from its score bin.

        Each record's weight is ``total / count`` of the bin it falls in, so
        rarer bins get larger weights; an empty bin has weight 0.0.

        Side effects: sets ``self.dataset_0`` .. ``self.dataset_4`` (records
        per bin), ``self.weights_0`` .. ``self.weights_4`` (weight per bin)
        and ``self.weights``, and prints the bin counts.

        Returns:
            1-D tensor of dtype ``DTYPE`` aligned with ``self.dataset``.
        """
        edges = self._WEIGHT_BIN_EDGES[self.artery_bin]

        # Bin every record exactly once; the same index drives both the
        # per-bin counts and the per-record weight so they cannot disagree.
        bin_of = [self._bin_index(rec[self.label], edges) for rec in self.dataset]
        groups = [[] for _ in range(5)]
        for rec, bin_idx in zip(self.dataset, bin_of):
            groups[bin_idx].append(rec)

        (
            self.dataset_0,
            self.dataset_1,
            self.dataset_2,
            self.dataset_3,
            self.dataset_4,
        ) = groups

        counts = [len(group) for group in groups]
        total = sum(counts)
        bin_weights = [total / count if count > 0 else 0.0 for count in counts]

        (
            self.weights_0,
            self.weights_1,
            self.weights_2,
            self.weights_3,
            self.weights_4,
        ) = bin_weights

        print("Counts: ", *counts)

        # NOTE(review): weights are stored as DTYPE; if that is a half
        # precision type, large total/count ratios lose precision - confirm
        # this is acceptable for the sampler consuming them.
        self.weights = torch.tensor([bin_weights[b] for b in bin_of], dtype=DTYPE)
        return self.weights

    def _make_targets(self, rec) -> Tuple[Tensor, Tensor]:
        """Build the (binary label, log1p-score target) tensors for a record."""
        if not self.label:
            return torch.tensor([0], dtype=DTYPE), torch.tensor([0], dtype=DTYPE)

        score = rec[self.label]
        is_positive = score > self._POSITIVE_THRESHOLD[self.artery_bin]
        label = torch.tensor([int(is_positive)], dtype=DTYPE)
        target = torch.tensor([np.log(1.0 + score)], dtype=DTYPE)
        return label, target

    def _load_video(self, path: str):
        """Read a DICOM file and return its pixel array as uint8."""
        full_path = path if os.path.isabs(path) else os.path.join(self.root, path)
        video = pydicom.dcmread(full_path).pixel_array

        if video.dtype == np.uint16:
            # Rescale so the brightest pixel maps to 255 before narrowing.
            vmax = np.max(video)
            assert vmax > 0, f"uint16 video {full_path!r} is entirely zero; cannot rescale"
            video = video.astype(np.float32)
            video = video * (255.0 / vmax)
            video = video.astype(np.uint8)
        assert video.dtype == np.uint8, (
            f"unsupported pixel dtype {video.dtype} in {full_path!r}"
        )
        return video

    def _crop(self, video):
        """Return ``self.length`` consecutive frames, looping short videos."""
        if self.length > 0 and len(video) == 0:
            # Doubling an empty array never grows it; fail instead of hanging.
            raise ValueError("Cannot build a clip from a video with no frames")

        # Repeat the video end-to-end until it is long enough to crop from.
        while len(video) < self.length:
            video = np.concatenate([video, video])

        total_frames = len(video)
        if self.train:
            begin = torch.randint(
                low=0, high=total_frames - self.length + 1, size=(1,)
            ).item()
        else:
            begin = (total_frames - self.length) // 2
        return video[begin:begin + self.length, :, :]

    def __getitem__(self, idx: int) -> Tuple[object, Tensor, Tensor, object]:
        """Load the clips and targets of the record at ``idx``.

        Outside inference, ``_VIDEOS_PER_SAMPLE`` videos are drawn at random
        with replacement; in inference mode every video of the study is used.

        Returns:
            ``(videos, label, target, suid)`` where ``videos`` stacks the
            clips along a new first dimension, ``label`` and ``target`` are
            one-element tensors and ``suid`` is the record's ``"study_uid"``.
            Before ``transform`` each clip is uint8 with a trailing dimension
            of 3 (the frame replicated three times). In inference mode a
            study with no videos yields the integer ``0`` in place of
            ``videos``.
        """
        rec = self.dataset[idx]
        suid = rec["study_uid"]
        label, target = self._make_targets(rec)

        video_recs = rec[f"videos_{self.artery}"]
        nv = len(video_recs)
        if self.inference:
            if nv == 0:
                return 0, label, target, suid
            seq = range(nv)
        else:
            seq = torch.randint(
                low=0, high=nv, size=(self._VIDEOS_PER_SAMPLE,)
            ).tolist()

        clips = []
        for vi in seq:
            video = self._load_video(video_recs[vi]["path"])
            video = self._crop(video)
            # Replicate the single channel three times along a new last axis.
            video = torch.tensor(np.stack([video, video, video], axis=-1))
            if self.transform is not None:
                video = self.transform(video)
            clips.append(video)

        videos = torch.stack(clips, dim=0)
        return videos, label, target, suid
| |
|