File size: 9,661 Bytes
3ca1d12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import os
import json
import pydicom
import numpy as np
import torch

from typing import Callable, Optional, Tuple
from torch import Tensor
from torch.utils.data import Dataset

# Полуточность достаточно для хранения весов и таргетов,
# а сами вычисления в модели идут в float32 / bf16.
DTYPE = torch.float16


class SyntaxDataset(Dataset):
    """
    PyTorch Dataset для обучения видеобэкбона на задаче SYNTAX.

    Функциональность:
    - читает метаданные из JSON (относительный путь относительно root);
    - фильтрует по артерии (левая / правая);
    - опционально отфильтровывает только примеры с положительным SYNTAX
      (validation=True);
    - рассчитывает sample weights по бинам SYNTAX (для WeightedRandomSampler);
    - конвертирует DICOM-видео в тензор (T, H, W, 3) c uint8 [0–255];
    - возвращает:
        video, label_bin, target_log, weight, rel_path, original_label.
    """

    def __init__(
        self,
        root: str,                      # корневая директория датасета
        meta: str,                      # относительный путь к JSON с метаданными
        train: bool,                    # режим: train / eval
        length: int,                    # длина клипа (кол-во кадров)
        label: str,                     # имя поля с SYNTAX score в JSON
        artery_bin: int,                # 0 — левая, 1 — правая артерия
        validation: bool = False,       # отбрасывать ли нулевые SYNTAX
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()
        self.root = root
        self.train = train
        self.length = length
        self.label = label
        self.transform = transform
        self.validation = validation

        # meta теперь трактуется как ОТНОСИТЕЛЬНЫЙ путь от root
        meta_path = os.path.join(root, meta)
        with open(meta_path, "r") as f:
            dataset = json.load(f)

        # Фильтр по артерии (0 — левая, 1 — правая)
        if artery_bin is not None:
            assert artery_bin in (0, 1), "artery_bin должен быть 0 (левая) или 1 (правая)"
            dataset = [rec for rec in dataset if rec["artery"] == artery_bin]
            self.artery_bin = artery_bin
        else:
            # Для корректной работы get_sample_weights ожидаем известный artery_bin
            raise ValueError("artery_bin должен быть явно задан (0 или 1).")

        # Валидационный набор: берём только записи с положительным SYNTAX
        if validation:
            dataset = [rec for rec in dataset if rec[self.label] > 0]

        # Инициализируем веса с единиц
        for rec in dataset:
            rec["weight"] = 1.0

        self.dataset = dataset

    # ------------------------------------------------------------------
    # Веса для WeightedRandomSampler
    # ------------------------------------------------------------------
    def get_sample_weights(self) -> Tensor:
        """
        Считает веса для примеров по бинам SYNTAX.

        Для каждой артерии определён свой набор порогов,
        после чего каждый пример получает вес, обратный частоте своего бина.
        """
        # Пороговые значения по артериям (подбирались эмпирически)
        bin_thresholds = {
            0: [0, 5, 10, 15],  # левая
            1: [0, 2, 5, 8],    # правая
        }

        thresholds = bin_thresholds[self.artery_bin]
        thr0, thr1, thr2, thr3 = thresholds

        # Бины по значениям SYNTAX
        self.dataset_0 = [rec for rec in self.dataset if rec[self.label] == thr0]
        self.dataset_1 = [rec for rec in self.dataset if thr0 < rec[self.label] <= thr1]
        self.dataset_2 = [rec for rec in self.dataset if thr1 < rec[self.label] <= thr2]
        self.dataset_3 = [rec for rec in self.dataset if thr2 < rec[self.label] <= thr3]
        self.dataset_4 = [rec for rec in self.dataset if rec[self.label] > thr3]

        total = (
            len(self.dataset_0)
            + len(self.dataset_1)
            + len(self.dataset_2)
            + len(self.dataset_3)
            + len(self.dataset_4)
        )

        def safe_weight(count: int) -> float:
            # Если в бине нет примеров, вес ставим 0.0
            return total / count if count > 0 else 0.0

        self.weights_0 = safe_weight(len(self.dataset_0))
        self.weights_1 = safe_weight(len(self.dataset_1))
        self.weights_2 = safe_weight(len(self.dataset_2))
        self.weights_3 = safe_weight(len(self.dataset_3))
        self.weights_4 = safe_weight(len(self.dataset_4))

        print(
            "Weights: ",
            self.weights_0,
            self.weights_1,
            self.weights_2,
            self.weights_3,
            self.weights_4,
        )
        print(
            "Counts: ",
            len(self.dataset_0),
            len(self.dataset_1),
            len(self.dataset_2),
            len(self.dataset_3),
            len(self.dataset_4),
        )

        # Назначаем вес каждому примеру
        weights = []
        for rec in self.dataset:
            syntax_score = rec[self.label]
            if syntax_score == thr0:
                weights.append(self.weights_0)
            elif thr0 < syntax_score <= thr1:
                weights.append(self.weights_1)
            elif thr1 < syntax_score <= thr2:
                weights.append(self.weights_2)
            elif thr2 < syntax_score <= thr3:
                weights.append(self.weights_3)
            else:
                weights.append(self.weights_4)

        self.weights = torch.tensor(weights, dtype=DTYPE)
        return self.weights

    # ------------------------------------------------------------------
    def __len__(self) -> int:
        return len(self.dataset)

    # ------------------------------------------------------------------
    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, str, Tensor]:
        """
        Возвращает один пример:
        - video: Tensor (T, H, W, 3) → после transform обычно (C, T, H, W)
        - label: бинарный таргет по порогу для конкретной артерии
        - target: логарифмированный SYNTAX score (регрессия)
        - weight: вес примера (для самплера / лосса)
        - path: относительный путь к DICOM файлу
        - original_label: исходный SYNTAX score
        """
        rec = self.dataset[idx]

        # Относительный путь к DICOM из JSON (мы не храним абсолютные пути)
        path = rec["path"]
        weight = rec["weight"]

        full_path = os.path.join(self.root, path)
        video = pydicom.dcmread(full_path).pixel_array  # (T, H, W)

        # Приводим 16-битный сигнал к диапазону [0, 255] uint8
        if video.dtype == np.uint16:
            vmax = np.max(video)
            assert vmax > 0
            video = video.astype(np.float32)
            video = video * (255.0 / vmax)
            video = video.astype(np.uint8)
        assert video.dtype == np.uint8

        # Порог для бинарной классификации зависит от артерии
        bin_thresholds = {
            0: 15,  # левая
            1: 5,   # правая
        }

        syntax_value = rec[self.label]
        label = torch.tensor(
            [int(syntax_value > bin_thresholds[self.artery_bin])],
            dtype=DTYPE,
        )
        target = torch.tensor([np.log(1.0 + syntax_value)], dtype=DTYPE)
        original_label = torch.tensor([syntax_value], dtype=DTYPE)

        # Дублируем видео по времени, пока не наберём нужную длину клипа
        while len(video) < self.length:
            video = np.concatenate([video, video])
        t = len(video)

        if self.train:
            # Случайный подотрезок длины self.length
            begin = torch.randint(low=0, high=t - self.length + 1, size=(1,))
            end = begin + self.length
            video = video[begin:end, :, :]
        else:
            # В валидации используем весь видеоряд (обрежется трансформами / моделью)
            video = video

        # Превращаем (T, H, W) → (T, H, W, 3) путём копирования каналов (grayscale→RGB)
        video = torch.tensor(np.stack([video, video, video], axis=-1))

        if self.transform is not None:
            video = self.transform(video)

        sample_weight = torch.tensor([weight], dtype=DTYPE)

        return video, label, target, sample_weight, path, original_label