MesserMMP commited on
Commit
3ca1d12
·
1 Parent(s): 5547c55

add backbone model

Browse files
Files changed (3) hide show
  1. backbone/dataset.py +222 -0
  2. backbone/pl_model.py +244 -0
  3. backbone/pl_train.py +278 -0
backbone/dataset.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import pydicom
4
+ import numpy as np
5
+ import torch
6
+
7
+ from typing import Callable, Optional, Tuple
8
+ from torch import Tensor
9
+ from torch.utils.data import Dataset
10
+
11
+ # Полуточность достаточно для хранения весов и таргетов,
12
+ # а сами вычисления в модели идут в float32 / bf16.
13
+ DTYPE = torch.float16
14
+
15
+
16
+ class SyntaxDataset(Dataset):
17
+ """
18
+ PyTorch Dataset для обучения видеобэкбона на задаче SYNTAX.
19
+
20
+ Функциональность:
21
+ - читает метаданные из JSON (относительный путь относительно root);
22
+ - фильтрует по артерии (левая / правая);
23
+ - опционально отфильтровывает только примеры с положительным SYNTAX
24
+ (validation=True);
25
+ - рассчитывает sample weights по бинам SYNTAX (для WeightedRandomSampler);
26
+ - конвертирует DICOM-видео в тензор (T, H, W, 3) c uint8 [0–255];
27
+ - возвращает:
28
+ video, label_bin, target_log, weight, rel_path, original_label.
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ root: str, # корневая директория датасета
34
+ meta: str, # относительный путь к JSON с метаданными
35
+ train: bool, # режим: train / eval
36
+ length: int, # длина клипа (кол-во кадров)
37
+ label: str, # имя поля с SYNTAX score в JSON
38
+ artery_bin: int, # 0 — левая, 1 — правая артерия
39
+ validation: bool = False, # отбрасывать ли нулевые SYNTAX
40
+ transform: Optional[Callable] = None,
41
+ ) -> None:
42
+ super().__init__()
43
+ self.root = root
44
+ self.train = train
45
+ self.length = length
46
+ self.label = label
47
+ self.transform = transform
48
+ self.validation = validation
49
+
50
+ # meta теперь трактуется как ОТНОСИТЕЛЬНЫЙ путь от root
51
+ meta_path = os.path.join(root, meta)
52
+ with open(meta_path, "r") as f:
53
+ dataset = json.load(f)
54
+
55
+ # Фильтр по артерии (0 — левая, 1 — правая)
56
+ if artery_bin is not None:
57
+ assert artery_bin in (0, 1), "artery_bin должен быть 0 (левая) или 1 (правая)"
58
+ dataset = [rec for rec in dataset if rec["artery"] == artery_bin]
59
+ self.artery_bin = artery_bin
60
+ else:
61
+ # Для корректной работы get_sample_weights ожидаем известный artery_bin
62
+ raise ValueError("artery_bin должен быть явно задан (0 или 1).")
63
+
64
+ # Валидационный набор: берём только записи с положительным SYNTAX
65
+ if validation:
66
+ dataset = [rec for rec in dataset if rec[self.label] > 0]
67
+
68
+ # Инициализируем веса с единиц
69
+ for rec in dataset:
70
+ rec["weight"] = 1.0
71
+
72
+ self.dataset = dataset
73
+
74
+ # ------------------------------------------------------------------
75
+ # Веса для WeightedRandomSampler
76
+ # ------------------------------------------------------------------
77
+ def get_sample_weights(self) -> Tensor:
78
+ """
79
+ Считает веса для примеров по бинам SYNTAX.
80
+
81
+ Для каждой артерии определён свой набор порогов,
82
+ после чего каждый пример получает вес, обратный частоте своего бина.
83
+ """
84
+ # Пороговые значения по артериям (подбирались эмпирически)
85
+ bin_thresholds = {
86
+ 0: [0, 5, 10, 15], # левая
87
+ 1: [0, 2, 5, 8], # правая
88
+ }
89
+
90
+ thresholds = bin_thresholds[self.artery_bin]
91
+ thr0, thr1, thr2, thr3 = thresholds
92
+
93
+ # Бины по значениям SYNTAX
94
+ self.dataset_0 = [rec for rec in self.dataset if rec[self.label] == thr0]
95
+ self.dataset_1 = [rec for rec in self.dataset if thr0 < rec[self.label] <= thr1]
96
+ self.dataset_2 = [rec for rec in self.dataset if thr1 < rec[self.label] <= thr2]
97
+ self.dataset_3 = [rec for rec in self.dataset if thr2 < rec[self.label] <= thr3]
98
+ self.dataset_4 = [rec for rec in self.dataset if rec[self.label] > thr3]
99
+
100
+ total = (
101
+ len(self.dataset_0)
102
+ + len(self.dataset_1)
103
+ + len(self.dataset_2)
104
+ + len(self.dataset_3)
105
+ + len(self.dataset_4)
106
+ )
107
+
108
+ def safe_weight(count: int) -> float:
109
+ # Если в би��е нет примеров, вес ставим 0.0
110
+ return total / count if count > 0 else 0.0
111
+
112
+ self.weights_0 = safe_weight(len(self.dataset_0))
113
+ self.weights_1 = safe_weight(len(self.dataset_1))
114
+ self.weights_2 = safe_weight(len(self.dataset_2))
115
+ self.weights_3 = safe_weight(len(self.dataset_3))
116
+ self.weights_4 = safe_weight(len(self.dataset_4))
117
+
118
+ print(
119
+ "Weights: ",
120
+ self.weights_0,
121
+ self.weights_1,
122
+ self.weights_2,
123
+ self.weights_3,
124
+ self.weights_4,
125
+ )
126
+ print(
127
+ "Counts: ",
128
+ len(self.dataset_0),
129
+ len(self.dataset_1),
130
+ len(self.dataset_2),
131
+ len(self.dataset_3),
132
+ len(self.dataset_4),
133
+ )
134
+
135
+ # Назначаем вес каждому примеру
136
+ weights = []
137
+ for rec in self.dataset:
138
+ syntax_score = rec[self.label]
139
+ if syntax_score == thr0:
140
+ weights.append(self.weights_0)
141
+ elif thr0 < syntax_score <= thr1:
142
+ weights.append(self.weights_1)
143
+ elif thr1 < syntax_score <= thr2:
144
+ weights.append(self.weights_2)
145
+ elif thr2 < syntax_score <= thr3:
146
+ weights.append(self.weights_3)
147
+ else:
148
+ weights.append(self.weights_4)
149
+
150
+ self.weights = torch.tensor(weights, dtype=DTYPE)
151
+ return self.weights
152
+
153
+ # ------------------------------------------------------------------
154
+ def __len__(self) -> int:
155
+ return len(self.dataset)
156
+
157
+ # ------------------------------------------------------------------
158
+ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, str, Tensor]:
159
+ """
160
+ Возвращает один пример:
161
+ - video: Tensor (T, H, W, 3) → после transform обычно (C, T, H, W)
162
+ - label: бинарный таргет по порогу для конкретной артерии
163
+ - target: логарифмированный SYNTAX score (регрессия)
164
+ - weight: вес примера (для самплера / лосса)
165
+ - path: относительный путь к DICOM файлу
166
+ - original_label: исходный SYNTAX score
167
+ """
168
+ rec = self.dataset[idx]
169
+
170
+ # Относительный путь к DICOM из JSON (мы не храним абсолютные пути)
171
+ path = rec["path"]
172
+ weight = rec["weight"]
173
+
174
+ full_path = os.path.join(self.root, path)
175
+ video = pydicom.dcmread(full_path).pixel_array # (T, H, W)
176
+
177
+ # Приводим 16-битный сигнал к диапазону [0, 255] uint8
178
+ if video.dtype == np.uint16:
179
+ vmax = np.max(video)
180
+ assert vmax > 0
181
+ video = video.astype(np.float32)
182
+ video = video * (255.0 / vmax)
183
+ video = video.astype(np.uint8)
184
+ assert video.dtype == np.uint8
185
+
186
+ # Порог для бинарной классификации зависит от артерии
187
+ bin_thresholds = {
188
+ 0: 15, # левая
189
+ 1: 5, # правая
190
+ }
191
+
192
+ syntax_value = rec[self.label]
193
+ label = torch.tensor(
194
+ [int(syntax_value > bin_thresholds[self.artery_bin])],
195
+ dtype=DTYPE,
196
+ )
197
+ target = torch.tensor([np.log(1.0 + syntax_value)], dtype=DTYPE)
198
+ original_label = torch.tensor([syntax_value], dtype=DTYPE)
199
+
200
+ # Дублируем видео по времени, пока не наберём нужную длину клипа
201
+ while len(video) < self.length:
202
+ video = np.concatenate([video, video])
203
+ t = len(video)
204
+
205
+ if self.train:
206
+ # Случайный подотрезок длины self.length
207
+ begin = torch.randint(low=0, high=t - self.length + 1, size=(1,))
208
+ end = begin + self.length
209
+ video = video[begin:end, :, :]
210
+ else:
211
+ # В валидации используем весь видеоряд (обрежется трансформами / моделью)
212
+ video = video
213
+
214
+ # Превращаем (T, H, W) → (T, H, W, 3) путём копирования каналов (grayscale→RGB)
215
+ video = torch.tensor(np.stack([video, video, video], axis=-1))
216
+
217
+ if self.transform is not None:
218
+ video = self.transform(video)
219
+
220
+ sample_weight = torch.tensor([weight], dtype=DTYPE)
221
+
222
+ return video, label, target, sample_weight, path, original_label
backbone/pl_model.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import torch
3
+ from torch import nn, optim
4
+ import lightning.pytorch as pl
5
+ import torchvision.models.video as tvmv
6
+ import sklearn.metrics as skm
7
+ import numpy as np
8
+
9
+
10
+ class SyntaxLightningModule(pl.LightningModule):
11
+ """
12
+ LightningModule для обучения 3D-ResNet (r3d_18) как backbone
13
+ в задаче предсказания SYNTAX score по видеоангиографии.
14
+
15
+ Модель предсказывает:
16
+ - yp_clf: вероятность поражения (syntax > порог) — бинарная классификация
17
+ - yp_reg: логарифмированное значение SYNTAX — регрессия
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ num_classes: int,
23
+ lr: float,
24
+ weight_decay: float = 0.0,
25
+ max_epochs: int = None,
26
+ weight_path: str = None,
27
+ sigma_a: float = 0.0,
28
+ sigma_b: float = 1.0,
29
+ **kwargs,
30
+ ):
31
+ super().__init__()
32
+ self.save_hyperparameters()
33
+
34
+ self.num_classes = num_classes
35
+ self.lr = lr
36
+ self.weight_decay = weight_decay
37
+ self.max_epochs = max_epochs
38
+ self.weight_path = weight_path
39
+ self.sigma_a = sigma_a
40
+ self.sigma_b = sigma_b
41
+
42
+ # Базовый 3D-ResNet с ImageNet Kinetics-предобученными весами
43
+ self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)
44
+
45
+ # Последний слой заменяем на Linear с num_classes выходами:
46
+ # 1 канал для классификации, 1 для регрессии
47
+ in_features = self.model.fc.in_features
48
+ self.model.fc = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)
49
+
50
+ # Если передан путь к чекпоинту Lightning — загружаем backbone
51
+ if self.weight_path is not None:
52
+ ckpt = torch.load(self.weight_path, map_location="cpu", weights_only=False)
53
+ state_dict = ckpt["state_dict"]
54
+ # Чистим префикс "model." у ключей
55
+ new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
56
+ self.model.load_state_dict(new_state_dict, strict=False)
57
+
58
+ # Лоссы
59
+ self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
60
+ self.loss_reg = nn.MSELoss(reduction="none")
61
+
62
+ # Буферы для валидационных метрик
63
+ self.y_val = []
64
+ self.p_val = []
65
+ self.r_val = []
66
+ self.ty_val = []
67
+ self.tp_val = []
68
+
69
+ # ------------------------------------------------------------------
70
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
71
+ return self.model(x)
72
+
73
+ # ------------------------------------------------------------------
74
+ def training_step(self, batch, batch_idx):
75
+ """
76
+ Один шаг обучения:
77
+ - бинарная классификация поражения (BCE с down-weight для нулей);
78
+ - регрессия логарифмированного SYNTAX с учётом get_sigma(target).
79
+ """
80
+ x, y, target, sample_weight, path, original_label = batch
81
+
82
+ y_hat = self(x)
83
+ yp_clf = y_hat[:, 0:1] # logits для классификации
84
+ yp_reg = y_hat[:, 1:] # регрессия (лог SYNTAX)
85
+
86
+ # BCE с меньшим весом для класса 0 (нет поражения)
87
+ weights_clf = torch.where(y > 0, 1.0, 0.45)
88
+ clf_loss = self.loss_clf(yp_clf, y)
89
+ clf_loss = (clf_loss * weights_clf).mean()
90
+
91
+ # Регрессионный лосс с «вариабельностью по красной линии»
92
+ reg_loss_raw = self.loss_reg(yp_reg, target)
93
+ sigma = self.sigma_a * target + self.sigma_b
94
+ reg_loss = (reg_loss_raw / (sigma ** 2)).mean()
95
+
96
+ loss = clf_loss + 0.5 * reg_loss
97
+
98
+ # Метрики на бинарную задачу
99
+ y_pred = torch.sigmoid(yp_clf)
100
+ y_bin = torch.round(y.detach().cpu()).int()
101
+ y_pred_bin = torch.round(y_pred.detach().cpu()).int()
102
+
103
+ self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
104
+ self.log("train_val_loss", reg_loss, prog_bar=True, sync_dist=True)
105
+ self.log("train_full_loss", loss, prog_bar=True, sync_dist=True)
106
+ self.log(
107
+ "train_f1",
108
+ skm.f1_score(y_bin, y_pred_bin, zero_division=0),
109
+ prog_bar=True,
110
+ sync_dist=True,
111
+ )
112
+ self.log(
113
+ "train_acc",
114
+ skm.accuracy_score(y_bin, y_pred_bin),
115
+ prog_bar=True,
116
+ sync_dist=True,
117
+ )
118
+
119
+ return loss
120
+
121
+ # ------------------------------------------------------------------
122
+ def validation_step(self, batch, batch_idx):
123
+ """
124
+ Валидационный шаг: считаем тот же комбини��ованный лосс и
125
+ аккумулируем предсказания для расчёта метрик на эпоху.
126
+ """
127
+ x, y, target, sample_weight, path, original_label = batch
128
+
129
+ y_hat = self(x)
130
+ yp_clf = y_hat[:, 0:1]
131
+ yp_reg = y_hat[:, 1:]
132
+
133
+ # Комбинированный лосс
134
+ clf_loss = self.loss_clf(yp_clf, y)
135
+ reg_loss_raw = self.loss_reg(yp_reg, target)
136
+ sigma = self.sigma_a * target + self.sigma_b
137
+ reg_loss = (reg_loss_raw / (sigma ** 2)).mean()
138
+ loss = clf_loss.mean() + 0.5 * reg_loss
139
+
140
+ # Для метрик
141
+ y_pred = torch.sigmoid(yp_clf)
142
+
143
+ self.y_val.append(int(y[..., 0].cpu()))
144
+ self.p_val.append(float(y_pred[..., 0].cpu()))
145
+ self.r_val.append(round(float(y_pred[..., 0].cpu())))
146
+
147
+ self.ty_val.append(float(target[..., 0].cpu()))
148
+ self.tp_val.append(float(yp_reg[..., 0].cpu()))
149
+
150
+ return loss
151
+
152
+ # ------------------------------------------------------------------
153
+ def on_validation_epoch_end(self) -> None:
154
+ """
155
+ Подсчёт валидационных метрик по всей эпохе и логирование в Logger.
156
+ """
157
+ try:
158
+ auc = skm.roc_auc_score(self.y_val, self.p_val)
159
+ f1 = skm.f1_score(self.y_val, self.r_val, zero_division=0)
160
+ acc = skm.accuracy_score(self.y_val, self.r_val)
161
+ mae = skm.mean_absolute_error(self.y_val, self.r_val)
162
+ rmse = skm.root_mean_squared_error(self.ty_val, self.tp_val)
163
+
164
+ self.log("val_auc", auc, prog_bar=True, sync_dist=True)
165
+ self.log("val_f1", f1, prog_bar=True, sync_dist=True)
166
+ self.log("val_acc", acc, prog_bar=True, sync_dist=True)
167
+ self.log("val_mae", mae, prog_bar=True, sync_dist=True)
168
+ self.log("val_rmse", rmse, prog_bar=True, sync_dist=True)
169
+
170
+ except ValueError as err:
171
+ # Случаи, когда метрики нельзя посчитать (например, только один класс)
172
+ print(err)
173
+ print("Y_VAL", self.y_val)
174
+ print("P_VAL", self.p_val)
175
+
176
+ # Чистим буферы к следующей эпохе
177
+ self.y_val.clear()
178
+ self.p_val.clear()
179
+ self.r_val.clear()
180
+ self.ty_val.clear()
181
+ self.tp_val.clear()
182
+
183
+ # ------------------------------------------------------------------
184
+ def on_train_epoch_end(self) -> None:
185
+ """Логирование текущего learning rate."""
186
+ opt = self.optimizers()
187
+ if hasattr(opt, "optimizer"):
188
+ lr = opt.optimizer.param_groups[0]["lr"]
189
+ else:
190
+ lr = opt.param_groups[0]["lr"]
191
+ self.log("lr", lr, on_step=False, on_epoch=True, sync_dist=True)
192
+
193
+ # ------------------------------------------------------------------
194
+ def configure_optimizers(self):
195
+ """
196
+ - Если weight_path не задан → pretrain: обучаем только финальный fc-слой.
197
+ - Если weight_path задан → full fine-tuning: обучаем весь backbone.
198
+ """
199
+ if not self.weight_path:
200
+ # Pretrain: замораживаем всё, кроме финального слоя
201
+ for param in self.parameters():
202
+ param.requires_grad = False
203
+ for p in self.model.fc.parameters():
204
+ p.requires_grad = True
205
+ params = list(self.model.fc.parameters())
206
+ else:
207
+ # Full fine-tune: обучаем все параметры модели
208
+ for param in self.parameters():
209
+ param.requires_grad = True
210
+ params = self.parameters()
211
+
212
+ optimizer = optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)
213
+
214
+ if self.max_epochs is not None:
215
+ scheduler = optim.lr_scheduler.OneCycleLR(
216
+ optimizer=optimizer,
217
+ max_lr=self.lr,
218
+ total_steps=self.max_epochs,
219
+ )
220
+ return [optimizer], [scheduler]
221
+ else:
222
+ return optimizer
223
+
224
+ # ------------------------------------------------------------------
225
+ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
226
+ """
227
+ Инференс: возвращает словарь с бинарным предсказанием, вероятностями
228
+ и регрессионным выходом.
229
+ """
230
+ x, y, target, sample_weight, path, original_label = batch
231
+ y_hat = self(x)
232
+ yp_clf = y_hat[:, 0:1]
233
+ yp_reg = y_hat[:, 1:]
234
+ y_prob = torch.sigmoid(yp_clf)
235
+ y_pred = torch.round(y_prob)
236
+
237
+ return {
238
+ "y": y,
239
+ "y_pred": y_pred,
240
+ "y_prob": y_prob,
241
+ "y_reg": yp_reg,
242
+ "target": target,
243
+ "original_label": original_label,
244
+ }
backbone/pl_train.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import numpy as np
5
+ import click
6
+ import lightning.pytorch as pl
7
+ from lightning.pytorch.loggers import TensorBoardLogger
8
+ from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
9
+ from lightning.pytorch.profilers import AdvancedProfiler, PyTorchProfiler
10
+
11
+ from pytorchvideo.transforms import Normalize, Permute, RandAugment
12
+ from torch.utils.data import DataLoader, WeightedRandomSampler
13
+ from torchvision.transforms import transforms as T
14
+ from torchvision.transforms._transforms_video import ToTensorVideo
15
+ from torchvision.transforms import InterpolationMode
16
+
17
+ from dataset import SyntaxDataset
18
+ from pl_model import SyntaxLightningModule
19
+
20
+ import warnings
21
+ warnings.filterwarnings("ignore", message="No device id is provided via `init_process_group`")
22
+
23
+ torch.set_float32_matmul_precision("medium")
24
+
25
+
26
+ """
27
+ Скрипт обучения backbone (3D-ResNet) для предсказания SYNTAX score.
28
+
29
+ Шаги:
30
+ 1) предварительное обучение (pretrain) — обучается только последний слой;
31
+ 2) полное дообучение (full) — fine-tuning всего backbone.
32
+ """
33
+
34
+
35
+ # ------------------- Трансформации -------------------
36
+ def get_transforms(video_size, imagenet_mean, imagenet_std, train=True):
37
+ interpolation_choices = [
38
+ InterpolationMode.BILINEAR,
39
+ InterpolationMode.BICUBIC,
40
+ ]
41
+ if train:
42
+ return T.Compose([
43
+ ToTensorVideo(), # (T, H, W, 3) -> (C, T, H, W)
44
+ Permute(dims=[1, 0, 2, 3]), # (C, T, H, W) -> (T, C, H, W)
45
+ RandAugment(magnitude=10, num_layers=2),
46
+ T.RandomHorizontalFlip(),
47
+ Permute(dims=[1, 0, 2, 3]), # обратно: (T, C, H, W) -> (C, T, H, W)
48
+ T.RandomChoice([
49
+ T.Resize(size=video_size, interpolation=interp, antialias=True)
50
+ for interp in interpolation_choices
51
+ ]),
52
+ Normalize(mean=imagenet_mean, std=imagenet_std),
53
+ ])
54
+ else:
55
+ return T.Compose([
56
+ ToTensorVideo(),
57
+ T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
58
+ Normalize(mean=imagenet_mean, std=imagenet_std),
59
+ ])
60
+
61
+
62
+ # ------------------- DataLoader -------------------
63
+ def make_dataloader(dataset, batch_size, num_workers):
64
+ """
65
+ Создаёт DataLoader; по умолчанию используем shuffle,
66
+ но можно легко переключиться на WeightedRandomSampler.
67
+ """
68
+ sample_weights = dataset.get_sample_weights()
69
+ # sampler = WeightedRandomSampler(sample_weights, len(dataset), replacement=True)
70
+ return DataLoader(
71
+ dataset,
72
+ batch_size=batch_size,
73
+ num_workers=num_workers,
74
+ # sampler=sampler,
75
+ shuffle=True,
76
+ drop_last=True,
77
+ pin_memory=True,
78
+ )
79
+
80
+
81
+ # ------------------- Модель -------------------
82
+ def make_model(num_classes, video_shape, lr, weight_decay, max_epochs, weight_path=None):
83
+ """
84
+ Обёртка над SyntaxLightningModule для единообразного создания модели
85
+ на этапах pretrain и full fine-tuning.
86
+ """
87
+ model = SyntaxLightningModule(
88
+ num_classes=num_classes,
89
+ lr=lr,
90
+ weight_decay=weight_decay,
91
+ max_epochs=max_epochs,
92
+ weight_path=weight_path,
93
+ )
94
+ return model
95
+
96
+
97
+ # ------------------- Callbacks -------------------
98
+ def make_callbacks(artery: str, fold: int, phase: str):
99
+ """
100
+ Возвращает набор callback'ов:
101
+ - LearningRateMonitor
102
+ - ModelCheckpoint с сохранением по наилучшему val_mae.
103
+ """
104
+ lr_monitor = LearningRateMonitor(logging_interval="epoch")
105
+
106
+ if phase == "pre":
107
+ checkpoint = ModelCheckpoint(
108
+ monitor="val_mae",
109
+ save_top_k=1,
110
+ mode="min",
111
+ filename="model" + "-{epoch:02d}-{val_rmse:.3f}",
112
+ save_last=True,
113
+ )
114
+ elif phase == "full":
115
+ checkpoint = ModelCheckpoint(
116
+ monitor="val_mae",
117
+ save_top_k=3,
118
+ mode="min",
119
+ filename="model" + "-{epoch:02d}-{val_rmse:.3f}",
120
+ save_last=True,
121
+ )
122
+ else:
123
+ raise ValueError(f"Unknown phase '{phase}', expected 'pre' or 'full'")
124
+
125
+ return [lr_monitor, checkpoint]
126
+
127
+
128
+ # ------------------- Trainer -------------------
129
+ def make_trainer(max_epochs, logger_name, callbacks):
130
+ """
131
+ Создаёт Lightning Trainer c TensorBoardLogger.
132
+
133
+ Важно: пути к логам и устройствам можно адаптировать под свой кластер.
134
+ """
135
+ logger = TensorBoardLogger(
136
+ save_dir="backbone_logs",
137
+ name=logger_name,
138
+ )
139
+ trainer = pl.Trainer(
140
+ max_epochs=max_epochs,
141
+ accelerator="gpu",
142
+ devices=1,
143
+ strategy="ddp_find_unused_parameters_true",
144
+ precision="bf16-mixed",
145
+ callbacks=callbacks,
146
+ log_every_n_steps=10,
147
+ logger=logger,
148
+ )
149
+ return trainer
150
+
151
+
152
+ @click.command()
153
+ @click.option(
154
+ "-r",
155
+ "--dataset-root",
156
+ type=click.Path(exists=True),
157
+ default=".",
158
+ required=True,
159
+ help="Путь к корню датасета (директория, внутри которой лежат JSON и DICOM).",
160
+ )
161
+ @click.option("--fold", type=int, default=0, required=True, help="Номер фолда (0–4).")
162
+ @click.option(
163
+ "-a",
164
+ "--artery",
165
+ type=str,
166
+ default="right",
167
+ required=True,
168
+ help="Название артерии: 'left' или 'right'.",
169
+ )
170
+ @click.option("-nc", "--num-classes", type=int, default=2, help="Число выходных каналов модели.")
171
+ @click.option("-b", "--batch-size", type=int, default=50, help="Размер batch.")
172
+ @click.option("-f", "--frames-per-clip", type=int, default=32, help="Количество кадров в клипе.")
173
+ @click.option(
174
+ "-v",
175
+ "--video-size",
176
+ type=click.Tuple([int, int]),
177
+ default=(256, 256),
178
+ help="Размер кадра (H, W).",
179
+ )
180
+ @click.option("--max-epochs", type=int, default=10, help="Число эпох на этапе full fine-tuning.")
181
+ @click.option("--num-workers", type=int, default=8, help="Число воркеров для DataLoader.")
182
+ @click.option(
183
+ "--fast-dev-run",
184
+ is_flag=True,
185
+ default=False,
186
+ show_default=True,
187
+ help="Режим быстрой проверки пайплайна (1–2 батча).",
188
+ )
189
+ @click.option("--seed", type=int, default=42, help="Сид для воспроизводимости.")
190
+ def main(
191
+ dataset_root,
192
+ fold,
193
+ artery,
194
+ num_classes,
195
+ batch_size,
196
+ frames_per_clip,
197
+ video_size,
198
+ max_epochs,
199
+ num_workers,
200
+ fast_dev_run,
201
+ seed,
202
+ ):
203
+ pl.seed_everything(seed)
204
+
205
+ artery = artery.lower()
206
+ artery_bin = {"left": 0, "right": 1}.get(artery)
207
+ if artery_bin is None:
208
+ raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'.")
209
+
210
+ imagenet_mean = [0.485, 0.456, 0.406]
211
+ imagenet_std = [0.229, 0.224, 0.225]
212
+
213
+ # ------------------- Datasets -------------------
214
+ # Путь к JSON теперь относительный относительно dataset_root
215
+ train_meta = os.path.join("folds", f"step2_fold{fold:02d}_train.json")
216
+ val_meta = os.path.join("folds", f"step2_fold{fold:02d}_eval.json")
217
+
218
+ train_set = SyntaxDataset(
219
+ root=dataset_root,
220
+ meta=train_meta,
221
+ train=True,
222
+ length=frames_per_clip,
223
+ label=f"syntax_{artery}",
224
+ artery_bin=artery_bin,
225
+ validation=False,
226
+ transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
227
+ )
228
+
229
+ val_set = SyntaxDataset(
230
+ root=dataset_root,
231
+ meta=val_meta,
232
+ train=False,
233
+ length=frames_per_clip,
234
+ label=f"syntax_{artery}",
235
+ artery_bin=artery_bin,
236
+ validation=True,
237
+ transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
238
+ )
239
+
240
+ train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers)
241
+ train_loader_post = make_dataloader(train_set, batch_size, num_workers)
242
+ val_loader = make_dataloader(val_set, 1, num_workers)
243
+
244
+ # Получаем форму входного видео (C, T, H, W) из одного батча
245
+ x, *_ = next(iter(train_loader_pre))
246
+ video_shape = x.shape[1:]
247
+
248
+ # ------------------- Callbacks -------------------
249
+ callbacks_pre = make_callbacks(artery=artery, fold=fold, phase="pre")
250
+ callbacks_full = make_callbacks(artery=artery, fold=fold, phase="full")
251
+
252
+ # ------------------- Pretrain -------------------
253
+ num_pre_epochs = 10
254
+ model_pre = make_model(
255
+ num_classes=num_classes,
256
+ video_shape=video_shape,
257
+ lr=3e-4,
258
+ weight_decay=0.01,
259
+ max_epochs=num_pre_epochs,
260
+ )
261
+ trainer_pre = make_trainer(num_pre_epochs, f"{artery}BinSyntax_R3D_pre_fold{fold:02d}", callbacks_pre)
262
+ trainer_pre.fit(model_pre, train_loader_pre, val_loader, ckpt_path=None)
263
+
264
+ # ------------------- Full train -------------------
265
+ model_full = make_model(
266
+ num_classes=num_classes,
267
+ video_shape=video_shape,
268
+ lr=1e-4,
269
+ weight_decay=0.01,
270
+ max_epochs=max_epochs,
271
+ weight_path=trainer_pre.checkpoint_callback.last_model_path,
272
+ )
273
+ trainer_full = make_trainer(max_epochs, f"{artery}BinSyntax_R3D_full_fold{fold:02d}", callbacks_full)
274
+ trainer_full.fit(model_full, train_loader_post, val_loader, ckpt_path=None)
275
+
276
+
277
+ if __name__ == "__main__":
278
+ main()