# syntax-model/full_model/rnn_dataset.py
# Origin: MesserMMP, commit c2d9714 ("add full model files").
import os
import json
import pydicom
import numpy as np
import torch
from typing import Callable, Optional, Tuple
from torch import Tensor
from torch.utils.data import Dataset
from sklearn.preprocessing import RobustScaler
DTYPE = torch.float16  # dtype for labels, regression targets and sample weights


class SyntaxDataset(Dataset):
    """Angiography DICOM video dataset for SYNTAX-score prediction.

    Each metadata record describes one study: a ``study_uid``, a numeric
    score field (named by ``label``) and per-artery video lists
    (``videos_left`` / ``videos_right``).  ``__getitem__`` returns
    ``(videos, label, target, study_uid)`` where ``videos`` stacks several
    fixed-length clips, ``label`` is a binary above-threshold flag and
    ``target`` is the log-compressed score.
    """

    def __init__(
        self,
        root: str,                 # dataset root directory
        meta: str,                 # metadata JSON: absolute path, or relative to ``root``
        train: bool,               # True -> random temporal crop, False -> center crop
        length: int,               # number of frames per clip
        label: str,                # name of the score field in each record
        artery: str,               # "left" or "right"
        inference: bool = False,   # True -> use every available video per study
        validation: bool = False,  # True -> keep only records with a positive score
        transform: Optional[Callable] = None  # applied per clip, uint8 (T, H, W, 3)
    ) -> None:
        self.root = root
        self.train = train
        self.length = length
        self.label = label
        self.artery = artery
        self.inference = inference
        self.transform = transform
        self.validation = validation
        meta_path = meta if os.path.isabs(meta) else os.path.join(root, meta)
        with open(meta_path) as f:
            dataset = json.load(f)
        if not self.inference:
            # Training/validation requires at least one video of this artery.
            dataset = [rec for rec in dataset if len(rec[f"videos_{artery}"]) > 0]
            if validation:
                # Validation additionally keeps only strictly positive scores.
                dataset = [rec for rec in dataset if rec[self.label] > 0]
        self.dataset = dataset
        # 0 = left, 1 = right; selects the artery-specific thresholds below.
        artery_bin = {"left": 0, "right": 1}.get(artery.lower())
        if artery_bin is None:
            raise ValueError(f"Unknown artery '{artery}'")
        self.artery_bin = artery_bin

    def __len__(self) -> int:
        return len(self.dataset)

    def get_sample_weights(self) -> Tensor:
        """Return a per-record weight tensor for balanced sampling.

        Scores are bucketed into five bins with artery-specific thresholds;
        each record gets weight ``total / bin_count`` (empty bins weigh 0.0),
        so records from rare score ranges are sampled more often.
        """
        # Bin edges for the left (0) and right (1) artery.
        bin_thresholds = {
            0: [0, 5, 10, 15],  # left
            1: [0, 2, 5, 8],    # right
        }
        thr0, thr1, thr2, thr3 = bin_thresholds[self.artery_bin]
        # Split the dataset into score intervals.
        self.dataset_0 = [rec for rec in self.dataset if rec[self.label] == thr0]
        self.dataset_1 = [rec for rec in self.dataset if thr0 < rec[self.label] <= thr1]
        self.dataset_2 = [rec for rec in self.dataset if thr1 < rec[self.label] <= thr2]
        self.dataset_3 = [rec for rec in self.dataset if thr2 < rec[self.label] <= thr3]
        self.dataset_4 = [rec for rec in self.dataset if rec[self.label] > thr3]
        total = (len(self.dataset_0) + len(self.dataset_1) + len(self.dataset_2)
                 + len(self.dataset_3) + len(self.dataset_4))

        def safe_weight(count):
            # Inverse-frequency weight; empty bins cannot be sampled anyway.
            return total / count if count > 0 else 0.0

        self.weights_0 = safe_weight(len(self.dataset_0))
        self.weights_1 = safe_weight(len(self.dataset_1))
        self.weights_2 = safe_weight(len(self.dataset_2))
        self.weights_3 = safe_weight(len(self.dataset_3))
        self.weights_4 = safe_weight(len(self.dataset_4))
        print("Counts: ", len(self.dataset_0), len(self.dataset_1),
              len(self.dataset_2), len(self.dataset_3), len(self.dataset_4))
        weights = []
        for rec in self.dataset:
            syntax_score = rec[self.label]
            if syntax_score == thr0:
                weights.append(self.weights_0)
            elif thr0 < syntax_score <= thr1:
                weights.append(self.weights_1)
            elif thr1 < syntax_score <= thr2:
                weights.append(self.weights_2)
            elif thr2 < syntax_score <= thr3:
                weights.append(self.weights_3)
            else:
                weights.append(self.weights_4)
        self.weights = torch.tensor(weights, dtype=DTYPE)
        return self.weights

    def _load_clip(self, video_rec) -> Tensor:
        """Read one DICOM video and return a fixed-length 3-channel clip.

        Pixel data is normalized to uint8, tiled to at least ``self.length``
        frames, temporally cropped (random in training, centered otherwise)
        and replicated into 3 channels.
        """
        path = video_rec["path"]
        full_path = path if os.path.isabs(path) else os.path.join(self.root, path)
        video = pydicom.dcmread(full_path).pixel_array  # Time, HW or WH
        if video.dtype == np.uint16:
            # Rescale 16-bit pixel data to uint8 by the per-video maximum.
            vmax = np.max(video)
            assert vmax > 0
            video = video.astype(np.float32)
            video = video * (255. / vmax)
            video = video.astype(np.uint8)
        assert video.dtype == np.uint8
        # Tile by doubling until the clip has at least ``length`` frames.
        while len(video) < self.length:
            video = np.concatenate([video, video])
        t = len(video)
        if self.train:
            # Random temporal crop start (same RNG draw as before).
            begin = int(torch.randint(low=0, high=t - self.length + 1, size=(1,)))
        else:
            # Deterministic center crop.
            begin = (t - self.length) // 2
        video = video[begin:begin + self.length, :, :]
        # Replicate the grayscale frames into 3 channels -> (T, H, W, 3).
        video = torch.tensor(np.stack([video, video, video], axis=-1))
        if self.transform is not None:
            video = self.transform(video)
        return video

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, str]:
        """Return ``(videos, label, target, study_uid)`` for record ``idx``.

        ``videos`` stacks the selected clips along dim 0: every available
        video in inference mode, otherwise 4 random ones (sampled with
        replacement).  When inference finds no videos for this artery, the
        integer sentinel ``0`` is returned in place of the tensor.
        """
        rec = self.dataset[idx]
        suid = rec["study_uid"]
        if self.label:
            # Cut-off for binarizing the score, per artery side.
            bin_thresholds = {
                0: 15,  # left
                1: 5,   # right
            }
            label = torch.tensor([int(rec[self.label] > bin_thresholds[self.artery_bin])], dtype=DTYPE)
            # Log-compressed regression target.
            target = torch.tensor([np.log(1.0 + rec[self.label])], dtype=DTYPE)
        else:
            label = torch.tensor([0], dtype=DTYPE)
            target = torch.tensor([0], dtype=DTYPE)
        nv = len(rec[f"videos_{self.artery}"])
        if self.inference:
            if nv == 0:
                # No videos for this artery: caller must handle the 0 sentinel.
                return 0, label, target, suid
            seq = range(nv)
        else:
            # 4 random video indices, sampled with replacement.
            seq = torch.randint(low=0, high=nv, size=(4,))
        videos = [self._load_clip(rec[f"videos_{self.artery}"][vi]) for vi in seq]
        return torch.stack(videos, dim=0), label, target, suid