import os
import json
import pydicom
import numpy as np
import torch
from typing import Callable, Optional, Tuple
from torch import Tensor
from torch.utils.data import Dataset
from sklearn.preprocessing import RobustScaler

# dtype used for labels, regression targets and sampling weights
DTYPE = torch.float16


class SyntaxDataset(Dataset):
    """Dataset of angiography DICOM videos labeled with a SYNTAX-like score.

    Each metadata record describes one study; the videos for the requested
    artery are listed under the ``videos_left`` / ``videos_right`` keys, each
    entry carrying a ``path`` to a DICOM file whose pixel array is a
    (T, H, W) grayscale clip.
    """

    # Score threshold used to binarize the label in __getitem__,
    # keyed by artery bin (0 = left, 1 = right).
    # NOTE(review): values taken verbatim from the original code —
    # confirm against the labeling protocol.
    _BINARY_THRESHOLD = {0: 15, 1: 5}

    # Bin edges used by get_sample_weights(), keyed by artery bin.
    _BIN_EDGES = {
        0: [0, 5, 10, 15],  # left artery
        1: [0, 2, 5, 8],    # right artery
    }

    def __init__(
        self,
        root: str,                            # dataset dir
        meta: str,                            # metadata
        train: bool,                          # training mode
        length: int,                          # video length
        label: str,                           # label field name
        artery: str,                          # left or right artery
        inference: bool = False,
        validation: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        """Load the metadata JSON and filter it for the requested mode.

        Args:
            root: dataset directory; relative video/meta paths resolve here.
            meta: metadata JSON file, absolute or relative to ``root``.
            train: if True, __getitem__ crops a random temporal window
                (otherwise the center window).
            length: number of frames in each returned clip.
            label: name of the score field in each record; falsy disables
                label/target computation.
            artery: ``"left"`` or ``"right"`` (case-insensitive).
            inference: if True, keep studies with zero videos and return
                every video of a study instead of 4 random clips.
            validation: if True, keep only records with a positive score.
            transform: optional callable applied to each (T, H, W, 3)
                uint8 clip tensor.

        Raises:
            ValueError: if ``artery`` is not ``"left"`` or ``"right"``.
        """
        # Validate the artery name up front, before any file I/O.
        artery_bin = {"left": 0, "right": 1}.get(artery.lower())
        if artery_bin is None:
            raise ValueError(f"Unknown artery '{artery}'")
        self.artery_bin = artery_bin

        self.root = root
        self.train = train
        self.length = length
        self.label = label
        self.artery = artery
        self.inference = inference
        self.transform = transform
        self.validation = validation

        meta_path = meta if os.path.isabs(meta) else os.path.join(root, meta)
        with open(meta_path) as f:
            dataset = json.load(f)

        if not self.inference:
            # Training/validation need at least one video per study.
            dataset = [rec for rec in dataset
                       if len(rec[f"videos_{artery}"]) > 0]
        if validation:
            # Validation additionally keeps only positively-scored studies.
            dataset = [rec for rec in dataset if rec[self.label] > 0]
        self.dataset = dataset

    def __len__(self) -> int:
        """Number of studies after filtering."""
        return len(self.dataset)

    def get_sample_weights(self) -> Tensor:
        """Return one inverse-frequency weight per record.

        Records are bucketed into five score intervals (edges depend on the
        artery); each record's weight is ``total / bucket_size`` so that a
        weighted sampler draws the buckets roughly uniformly.  Also stores
        the per-bucket record lists/weights as attributes and prints the
        bucket counts (side effects kept from the original implementation).
        """
        thr0, thr1, thr2, thr3 = self._BIN_EDGES[self.artery_bin]

        # Partition the dataset by score interval.
        self.dataset_0 = [rec for rec in self.dataset
                          if rec[self.label] == thr0]
        self.dataset_1 = [rec for rec in self.dataset
                          if thr0 < rec[self.label] <= thr1]
        self.dataset_2 = [rec for rec in self.dataset
                          if thr1 < rec[self.label] <= thr2]
        self.dataset_3 = [rec for rec in self.dataset
                          if thr2 < rec[self.label] <= thr3]
        self.dataset_4 = [rec for rec in self.dataset
                          if rec[self.label] > thr3]

        buckets = [self.dataset_0, self.dataset_1, self.dataset_2,
                   self.dataset_3, self.dataset_4]
        total = sum(len(b) for b in buckets)

        def safe_weight(count: int) -> float:
            # Empty buckets get weight 0 instead of dividing by zero.
            return total / count if count > 0 else 0.0

        self.weights_0 = safe_weight(len(self.dataset_0))
        self.weights_1 = safe_weight(len(self.dataset_1))
        self.weights_2 = safe_weight(len(self.dataset_2))
        self.weights_3 = safe_weight(len(self.dataset_3))
        self.weights_4 = safe_weight(len(self.dataset_4))
        # print("Weights: ", self.weights_0, self.weights_1, self.weights_2, self.weights_3, self.weights_4)
        print("Counts: ", len(self.dataset_0), len(self.dataset_1),
              len(self.dataset_2), len(self.dataset_3), len(self.dataset_4))

        bucket_weights = [self.weights_0, self.weights_1, self.weights_2,
                          self.weights_3, self.weights_4]

        def bucket_of(score) -> int:
            if score == thr0:
                return 0
            if thr0 < score <= thr1:
                return 1
            if thr1 < score <= thr2:
                return 2
            if thr2 < score <= thr3:
                return 3
            return 4

        weights = [bucket_weights[bucket_of(rec[self.label])]
                   for rec in self.dataset]
        self.weights = torch.tensor(weights, dtype=DTYPE)
        return self.weights

    def _to_uint8(self, video: np.ndarray) -> np.ndarray:
        """Rescale a uint16 pixel array to uint8; pass uint8 through.

        Raises ValueError on an all-zero uint16 clip or an unsupported
        dtype (explicit raise instead of ``assert``, which is stripped
        under ``python -O``).
        """
        if video.dtype == np.uint16:
            vmax = np.max(video)
            if vmax <= 0:
                raise ValueError("uint16 video has no positive pixels")
            video = video.astype(np.float32)
            video = video * (255. / vmax)
            video = video.astype(np.uint8)
        if video.dtype != np.uint8:
            raise ValueError(f"unsupported video dtype {video.dtype}")
        return video

    def _crop_time(self, video: np.ndarray) -> np.ndarray:
        """Tile the clip until it has >= self.length frames, then cut a
        self.length-frame window: random when training, centered otherwise."""
        while len(video) < self.length:
            video = np.concatenate([video, video])
        t = len(video)
        if self.train:
            # torch RNG (not numpy) kept so seeding behavior is unchanged.
            begin = int(torch.randint(low=0, high=t - self.length + 1,
                                      size=(1,)))
        else:
            begin = (t - self.length) // 2
        return video[begin:begin + self.length, :, :]

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, str]:
        """Return ``(videos, label, target, study_uid)`` for one study.

        ``videos`` is a stack of clips, each shaped (T, H, W, 3) uint8
        (before ``transform``): 4 randomly chosen clips (with replacement)
        in training mode, or every clip in inference mode.  ``label`` is
        the binarized score, ``target`` is ``log(1 + score)``.

        In inference mode a study with no videos returns the int ``0`` in
        place of ``videos`` (sentinel kept for caller compatibility).
        """
        rec = self.dataset[idx]
        suid = rec["study_uid"]

        if self.label:
            score = rec[self.label]
            threshold = self._BINARY_THRESHOLD[self.artery_bin]
            label = torch.tensor([int(score > threshold)], dtype=DTYPE)
            target = torch.tensor([np.log(1.0 + score)], dtype=DTYPE)
        else:
            label = torch.tensor([0], dtype=DTYPE)
            target = torch.tensor([0], dtype=DTYPE)

        video_recs = rec[f"videos_{self.artery}"]
        nv = len(video_recs)
        if self.inference:
            if nv == 0:
                return 0, label, target, suid
            seq = range(nv)  # every clip, in order
        else:
            # 4 random clips, sampled WITH replacement.
            seq = torch.randint(low=0, high=nv, size=(4,))

        videos = []
        for vi in seq:
            path = video_recs[vi]["path"]
            if os.path.isabs(path):
                full_path = path
            else:
                full_path = os.path.join(self.root, path)
            # Time, HW or WH
            video = pydicom.dcmread(full_path).pixel_array
            video = self._to_uint8(video)
            video = self._crop_time(video)
            # Replicate grayscale into 3 channels -> (T, H, W, 3).
            video = torch.tensor(np.stack([video, video, video], axis=-1))
            if self.transform is not None:
                video = self.transform(video)
            videos.append(video)

        videos = torch.stack(videos, dim=0)
        return videos, label, target, suid