uncleMehrzad commited on
Commit
f9d46c5
·
verified ·
1 Parent(s): 302bc19

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +1099 -0
train.py ADDED
@@ -0,0 +1,1099 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import json
4
+ import numpy as np
5
+ import pandas as pd
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+ from datetime import datetime
9
+ import matplotlib.pyplot as plt
10
+ from sklearn.model_selection import train_test_split
11
+ from scipy.ndimage import morphology
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.utils.data import Dataset, DataLoader
17
+ from torch.optim import AdamW
18
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
19
+
20
+ from transformers import AutoModel
21
+ import albumentations as A
22
+ from albumentations.pytorch import ToTensorV2
23
+
24
+ import cv2
25
+ import warnings
26
+ import math
27
+ warnings.filterwarnings('ignore')
28
+
29
+ # Set seeds for reproducibility
30
+ def set_seed(seed=42):
31
+ np.random.seed(seed)
32
+ torch.manual_seed(seed)
33
+ torch.cuda.manual_seed_all(seed)
34
+ torch.backends.cudnn.deterministic = True
35
+ torch.backends.cudnn.benchmark = False
36
+
37
+ set_seed(42)
38
+
39
+ # ============================================================================
40
+ # CONFIGURATION
41
+ # ============================================================================
42
+
43
+ class Config:
44
+ # Model - USING YOUR LOCAL DOWNLOADED MODEL
45
+ model_name = "facebook/dinov3-vitl16-pretrain-lvd1689m"
46
+ local_model_path = "/data/F/VoiceNegar/models/pe_models/dino7b/checkpoints/initial_dinov3-vitl16-pretrain-lvd1689m_backbone"
47
+
48
+ # Data paths
49
+ dataset_path = "/home/PeBigModelForVilab/dinov3/toy-project/Kvasir-SEG/"
50
+ image_size = 256
51
+ patch_size = 16
52
+
53
+ # Training
54
+ batch_size = 96
55
+ num_epochs = 150
56
+ learning_rate = 1e-4
57
+ min_lr = 1e-6
58
+ weight_decay = 1e-4
59
+
60
+ # Cosine Annealing with Warm Restarts
61
+ T_0 = 10 # Initial restart period (epochs)
62
+ T_mult = 2 # Period multiplier after each restart
63
+
64
+ # Validation
65
+ val_split = 0.1
66
+ test_split = 0.05
67
+
68
+ # Device
69
+ device = "cuda" if torch.cuda.is_available() else "cpu"
70
+
71
+ # Logging
72
+ save_dir = "./checkpoints"
73
+ log_interval = 10
74
+
75
+ # Image normalization (ImageNet stats)
76
+ mean = [0.485, 0.456, 0.406]
77
+ std = [0.229, 0.224, 0.225]
78
+
79
+ resume_from = None
80
+ # Multi‑scale ViT layers
81
+ multi_scale_layers = [5, 10, 16, 18, 20, 22, 23]
82
+ # Loss parameters (Focal+Dice)
83
+ focal_weight = 0.69
84
+ dice_weight = 0.3
85
+ boundary_weight = 0.01
86
+ # HD95 parameter
87
+ hd95_threshold = 0.5
88
+
89
+ config = Config()
90
+ os.makedirs(config.save_dir, exist_ok=True)
91
+ print(f"Using device: {config.device}")
92
+ print(f"Model: {config.model_name}")
93
+ print(f"Local model path: {config.local_model_path}")
94
+ print(f"Exists: {os.path.exists(config.local_model_path)}")
95
+
96
+ # ============================================================================
97
+ # DATASET CLASS
98
+ # ============================================================================
99
+
100
+ class PolypDataset(Dataset):
101
+ """Kvasir-SEG dataset with manual preprocessing"""
102
+
103
+ def __init__(self, image_paths, mask_paths, transform=None, target_size=(256, 256)):
104
+ self.image_paths = image_paths
105
+ self.mask_paths = mask_paths
106
+ self.transform = transform
107
+ self.target_size = target_size
108
+
109
+ # ImageNet normalization values
110
+ self.mean = torch.tensor(config.mean).view(3, 1, 1)
111
+ self.std = torch.tensor(config.std).view(3, 1, 1)
112
+
113
+ def __len__(self):
114
+ return len(self.image_paths)
115
+
116
+ def __getitem__(self, idx):
117
+ # Load image
118
+ image = cv2.imread(self.image_paths[idx])
119
+ if image is None:
120
+ raise ValueError(f"Could not load image: {self.image_paths[idx]}")
121
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
122
+
123
+ # Load mask
124
+ mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
125
+ if mask is None:
126
+ raise ValueError(f"Could not load mask: {self.mask_paths[idx]}")
127
+ mask = (mask > 127).astype(np.float32)
128
+
129
+ # Apply augmentations
130
+ if self.transform:
131
+ augmented = self.transform(image=image, mask=mask)
132
+ image, mask = augmented['image'], augmented['mask']
133
+ else:
134
+ image = cv2.resize(image, self.target_size)
135
+ mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)
136
+
137
+ # Manual preprocessing
138
+ if isinstance(image, np.ndarray):
139
+ image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
140
+ elif isinstance(image, torch.Tensor):
141
+ image = image.float() / 255.0
142
+
143
+ # Apply ImageNet normalization
144
+ image = (image - self.mean) / self.std
145
+
146
+ # Ensure mask is tensor
147
+ if isinstance(mask, np.ndarray):
148
+ mask = torch.from_numpy(mask).float()
149
+
150
+ return image, mask.unsqueeze(0)
151
+
152
+ # ============================================================================
153
+ # FIXED DINOv3 ENCODER
154
+ # ============================================================================
155
+
156
+ class DINOv3Encoder(nn.Module):
157
+ """Frozen DINOv3 encoder that can return concatenated multi‑scale features."""
158
+
159
+ def __init__(self, model_name="facebook/dinov3-vitl16-pretrain-lvd1689m",
160
+ local_path=None, freeze=True, layers=None):
161
+ super().__init__()
162
+
163
+ # Load model
164
+ if local_path and os.path.exists(local_path):
165
+ print(f"Loading DINOv3 model from local path: {local_path}")
166
+ self.model = AutoModel.from_pretrained(local_path, local_files_only=True)
167
+ else:
168
+ print(f"Loading DINOv3 from HuggingFace hub: {model_name}")
169
+ self.model = AutoModel.from_pretrained(model_name)
170
+
171
+ self.embed_dim = self.model.config.hidden_size
172
+ self.patch_size = self.model.config.patch_size
173
+ self.layers = layers
174
+
175
+ if self.layers is not None:
176
+ self.out_channels = self.embed_dim * len(self.layers)
177
+ else:
178
+ self.out_channels = self.embed_dim
179
+
180
+ print(f"DINOv3 loaded - embed_dim: {self.embed_dim}, patch_size: {self.patch_size}")
181
+ if self.layers:
182
+ print(f" Multi‑scale layers: {self.layers}, output channels: {self.out_channels}")
183
+
184
+ if freeze:
185
+ for param in self.model.parameters():
186
+ param.requires_grad = False
187
+
188
+ def _reshape_to_2d(self, patch_tokens, B):
189
+ """Robust reshaping of patch tokens to 2D grid."""
190
+ N = patch_tokens.shape[1]
191
+ D = patch_tokens.shape[2]
192
+
193
+ H_grid = int(math.sqrt(N))
194
+ W_grid = H_grid
195
+
196
+ while H_grid * W_grid != N:
197
+ if H_grid * W_grid < N:
198
+ W_grid += 1
199
+ else:
200
+ found = False
201
+ for h in range(int(math.sqrt(N)), 0, -1):
202
+ if N % h == 0:
203
+ H_grid = h
204
+ W_grid = N // h
205
+ found = True
206
+ break
207
+ if not found:
208
+ W_grid += 1
209
+ else:
210
+ break
211
+
212
+ if H_grid * W_grid != N:
213
+ print(f" Warning: Cannot reshape {N} patches into grid. Interpolating to square.")
214
+ target_size = int(math.sqrt(N))
215
+ patch_tokens_flat = patch_tokens.transpose(1, 2)
216
+ patch_tokens_2d = F.interpolate(
217
+ patch_tokens_flat.unsqueeze(-2) if patch_tokens_flat.dim() == 3 else patch_tokens_flat,
218
+ size=target_size * target_size,
219
+ mode='linear',
220
+ align_corners=False
221
+ ).reshape(B, D, target_size, target_size)
222
+ return patch_tokens_2d
223
+
224
+ feat_map = patch_tokens.transpose(1, 2).reshape(B, D, H_grid, W_grid)
225
+ return feat_map
226
+
227
+ def forward(self, pixel_values):
228
+ B, C, H, W = pixel_values.shape
229
+
230
+ if self.layers is not None:
231
+ outputs = self.model(pixel_values, output_hidden_states=True)
232
+ hidden_states = outputs.hidden_states
233
+
234
+ feature_list = []
235
+ for idx in self.layers:
236
+ hidden = hidden_states[idx]
237
+ patch_tokens = hidden[:, 1:, :]
238
+ feat_map = self._reshape_to_2d(patch_tokens, B)
239
+ feature_list.append(feat_map)
240
+
241
+ target_h, target_w = feature_list[0].shape[-2:]
242
+
243
+ resized_features = []
244
+ for feat in feature_list:
245
+ if feat.shape[-2:] != (target_h, target_w):
246
+ feat = F.interpolate(feat, size=(target_h, target_w),
247
+ mode='bilinear', align_corners=False)
248
+ resized_features.append(feat)
249
+
250
+ features = torch.cat(resized_features, dim=1)
251
+ else:
252
+ outputs = self.model(pixel_values, output_hidden_states=False)
253
+ last_hidden = outputs.last_hidden_state[:, 1:, :]
254
+ features = self._reshape_to_2d(last_hidden, B)
255
+
256
+ return features
257
+
258
+ # ============================================================================
259
+ # SHALLOW STEM FOR SKIP CONNECTIONS
260
+ # ============================================================================
261
+ class ShallowStem(nn.Module):
262
+ """Extracts multi‑scale features from the input image."""
263
+ def __init__(self, in_channels=3, base_channels=64):
264
+ super().__init__()
265
+ self.conv1 = nn.Sequential(
266
+ nn.Conv2d(in_channels, base_channels, 3, padding=1, bias=False),
267
+ nn.BatchNorm2d(base_channels),
268
+ nn.ReLU(inplace=True)
269
+ )
270
+ self.conv2 = nn.Sequential(
271
+ nn.Conv2d(base_channels, base_channels*2, 3, stride=2, padding=1, bias=False),
272
+ nn.BatchNorm2d(base_channels*2),
273
+ nn.ReLU(inplace=True)
274
+ )
275
+ self.conv3 = nn.Sequential(
276
+ nn.Conv2d(base_channels*2, base_channels*4, 3, stride=2, padding=1, bias=False),
277
+ nn.BatchNorm2d(base_channels*4),
278
+ nn.ReLU(inplace=True)
279
+ )
280
+ self.conv4 = nn.Sequential(
281
+ nn.Conv2d(base_channels*4, base_channels*8, 3, stride=2, padding=1, bias=False),
282
+ nn.BatchNorm2d(base_channels*8),
283
+ nn.ReLU(inplace=True)
284
+ )
285
+
286
+ def forward(self, x):
287
+ x = self.conv1(x)
288
+ f2 = self.conv2(x)
289
+ f3 = self.conv3(f2)
290
+ f4 = self.conv4(f3)
291
+ return [f4, f3, f2]
292
+
293
+ # ============================================================================
294
+ # U‑Net DECODER WITH SKIP CONNECTIONS
295
+ # ============================================================================
296
+ class UNetDecoder(nn.Module):
297
+ """Decoder that progressively upsamples ViT features."""
298
+ def __init__(self, vit_channels=1024, stem_channels=[512,256,128], num_classes=1):
299
+ super().__init__()
300
+ self.up1 = self._up_block(vit_channels, 256)
301
+ self.conv1 = self._conv_block(256 + stem_channels[0], 256)
302
+
303
+ self.up2 = self._up_block(256, 128)
304
+ self.conv2 = self._conv_block(128 + stem_channels[1], 128)
305
+
306
+ self.up3 = self._up_block(128, 64)
307
+ self.conv3 = self._conv_block(64 + stem_channels[2], 64)
308
+
309
+ self.up4 = nn.UpsamplingBilinear2d(scale_factor=2)
310
+ self.final = nn.Conv2d(64, num_classes, kernel_size=1)
311
+
312
+ def _up_block(self, in_ch, out_ch):
313
+ return nn.Sequential(
314
+ nn.UpsamplingBilinear2d(scale_factor=2),
315
+ nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
316
+ nn.BatchNorm2d(out_ch),
317
+ nn.ReLU(inplace=True)
318
+ )
319
+
320
+ def _conv_block(self, in_ch, out_ch):
321
+ return nn.Sequential(
322
+ nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
323
+ nn.BatchNorm2d(out_ch),
324
+ nn.ReLU(inplace=True),
325
+ nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
326
+ nn.BatchNorm2d(out_ch),
327
+ nn.ReLU(inplace=True)
328
+ )
329
+
330
+ def forward(self, vit_features, skip_features):
331
+ x = self.up1(vit_features)
332
+
333
+ if x.shape[-2:] != skip_features[0].shape[-2:]:
334
+ x = F.interpolate(x, size=skip_features[0].shape[-2:], mode='bilinear', align_corners=False)
335
+
336
+ x = torch.cat([x, skip_features[0]], dim=1)
337
+ x = self.conv1(x)
338
+
339
+ x = self.up2(x)
340
+ if x.shape[-2:] != skip_features[1].shape[-2:]:
341
+ x = F.interpolate(x, size=skip_features[1].shape[-2:], mode='bilinear', align_corners=False)
342
+
343
+ x = torch.cat([x, skip_features[1]], dim=1)
344
+ x = self.conv2(x)
345
+
346
+ x = self.up3(x)
347
+ if x.shape[-2:] != skip_features[2].shape[-2:]:
348
+ x = F.interpolate(x, size=skip_features[2].shape[-2:], mode='bilinear', align_corners=False)
349
+
350
+ x = torch.cat([x, skip_features[2]], dim=1)
351
+ x = self.conv3(x)
352
+
353
+ x = self.up4(x)
354
+ return self.final(x)
355
+
356
+ # ============================================================================
357
+ # LOSS FUNCTIONS
358
+ # ============================================================================
359
+
360
+ class DiceLoss(nn.Module):
361
+ def __init__(self, smooth=1e-6):
362
+ super().__init__()
363
+ self.smooth = smooth
364
+
365
+ def forward(self, pred, target):
366
+ pred = torch.sigmoid(pred)
367
+ pred_flat = pred.view(-1)
368
+ target_flat = target.view(-1)
369
+
370
+ intersection = (pred_flat * target_flat).sum()
371
+ dice = (2. * intersection + self.smooth) / (pred_flat.sum() + target_flat.sum() + self.smooth)
372
+
373
+ return 1 - dice
374
+
375
+
376
+ class FocalLoss(nn.Module):
377
+ def __init__(self, alpha=0.25, gamma=2.0):
378
+ super().__init__()
379
+ self.alpha = alpha
380
+ self.gamma = gamma
381
+
382
+ def forward(self, pred, target):
383
+ bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
384
+ pt = torch.exp(-bce)
385
+ focal = self.alpha * (1 - pt) ** self.gamma * bce
386
+ return focal.mean()
387
+
388
+ class BoundaryLoss(nn.Module):
389
+ """Boundary loss using Sobel edge detection for sharper edges"""
390
+ def __init__(self):
391
+ super().__init__()
392
+ # Sobel kernels for edge detection
393
+ self.sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
394
+ self.sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
395
+
396
+ def forward(self, pred, target):
397
+ device = pred.device
398
+ self.sobel_x = self.sobel_x.to(device)
399
+ self.sobel_y = self.sobel_y.to(device)
400
+
401
+ # Get probabilities
402
+ pred_prob = torch.sigmoid(pred)
403
+
404
+ # Compute edges for prediction and target
405
+ pred_edges_x = F.conv2d(pred_prob, self.sobel_x, padding=1)
406
+ pred_edges_y = F.conv2d(pred_prob, self.sobel_y, padding=1)
407
+ pred_edges = torch.sqrt(pred_edges_x**2 + pred_edges_y**2 + 1e-6)
408
+
409
+ target_edges_x = F.conv2d(target, self.sobel_x, padding=1)
410
+ target_edges_y = F.conv2d(target, self.sobel_y, padding=1)
411
+ target_edges = torch.sqrt(target_edges_x**2 + target_edges_y**2 + 1e-6)
412
+
413
+ # MSE between edge maps
414
+ boundary_loss = F.mse_loss(pred_edges, target_edges)
415
+ return boundary_loss
416
+
417
+ class FocalDiceBoundaryLoss(nn.Module):
418
+ def __init__(self, focal_weight=0.6, dice_weight=0.3, boundary_weight=0.1):
419
+ super().__init__()
420
+ self.focal = FocalLoss()
421
+ self.dice = DiceLoss()
422
+ self.boundary = BoundaryLoss()
423
+ self.w_f = focal_weight
424
+ self.w_d = dice_weight
425
+ self.w_b = boundary_weight
426
+
427
+ def forward(self, pred, target):
428
+ return (self.w_f * self.focal(pred, target) +
429
+ self.w_d * self.dice(pred, target) +
430
+ self.w_b * self.boundary(pred, target))
431
+ # ============================================================================
432
+ # METRICS
433
+ # ============================================================================
434
+
435
+ def compute_dice(pred, target, threshold=0.5):
436
+ """Compute Dice score"""
437
+ pred_binary = (torch.sigmoid(pred) > threshold).float()
438
+ intersection = (pred_binary * target).sum()
439
+ dice = (2. * intersection) / (pred_binary.sum() + target.sum() + 1e-6)
440
+ return dice.item()
441
+
442
+
443
+ def compute_iou(pred, target, threshold=0.5):
444
+ """Compute IoU (Jaccard index)"""
445
+ pred_binary = (torch.sigmoid(pred) > threshold).float()
446
+ intersection = (pred_binary * target).sum()
447
+ union = pred_binary.sum() + target.sum() - intersection
448
+ iou = intersection / (union + 1e-6)
449
+ return iou.item()
450
+
451
+
452
+ def compute_precision_recall(pred, target, threshold=0.5):
453
+ """Compute precision and recall"""
454
+ pred_binary = (torch.sigmoid(pred) > threshold).float()
455
+ tp = (pred_binary * target).sum()
456
+ fp = (pred_binary * (1 - target)).sum()
457
+ fn = ((1 - pred_binary) * target).sum()
458
+
459
+ precision = tp / (tp + fp + 1e-6)
460
+ recall = tp / (tp + fn + 1e-6)
461
+
462
+ return precision.item(), recall.item()
463
+
464
+
465
+ def compute_hd95(pred, target, threshold=0.5, voxel_spacing=None):
466
+ """
467
+ Compute Hausdorff Distance 95th percentile.
468
+
469
+ Args:
470
+ pred: Tensor [B, 1, H, W] logits
471
+ target: Tensor [B, 1, H, W] ground truth
472
+ threshold: threshold for binarization
473
+ voxel_spacing: not used for 2D but kept for compatibility
474
+
475
+ Returns:
476
+ hd95: 95th percentile Hausdorff distance
477
+ """
478
+ # Convert to numpy and binarize
479
+ pred_binary = (torch.sigmoid(pred) > threshold).float().cpu().numpy().squeeze()
480
+ target_binary = target.cpu().numpy().squeeze()
481
+
482
+ # Handle batch dimension
483
+ if pred_binary.ndim == 3:
484
+ hd95_values = []
485
+ for i in range(pred_binary.shape[0]):
486
+ hd95_values.append(_compute_hd95_single(pred_binary[i], target_binary[i]))
487
+ return np.mean(hd95_values)
488
+ else:
489
+ return _compute_hd95_single(pred_binary, target_binary)
490
+
491
+
492
+ def _compute_hd95_single(pred, target):
493
+ """Compute HD95 for a single 2D image"""
494
+ if pred.sum() == 0 or target.sum() == 0:
495
+ return 100.0 # Return a high value if either is empty
496
+
497
+ # Get surface points
498
+ pred_border = pred - morphology.binary_erosion(pred)
499
+ target_border = target - morphology.binary_erosion(target)
500
+
501
+ if pred_border.sum() == 0 or target_border.sum() == 0:
502
+ return 100.0
503
+
504
+ # Get coordinates of border points
505
+ pred_coords = np.argwhere(pred_border > 0)
506
+ target_coords = np.argwhere(target_border > 0)
507
+
508
+ # Compute pairwise distances
509
+ distances_pred_to_target = []
510
+ for p in pred_coords:
511
+ dist = np.min(np.sqrt(np.sum((target_coords - p) ** 2, axis=1)))
512
+ distances_pred_to_target.append(dist)
513
+
514
+ distances_target_to_pred = []
515
+ for t in target_coords:
516
+ dist = np.min(np.sqrt(np.sum((pred_coords - t) ** 2, axis=1)))
517
+ distances_target_to_pred.append(dist)
518
+
519
+ # Get 95th percentile
520
+ all_distances = distances_pred_to_target + distances_target_to_pred
521
+ hd95 = np.percentile(all_distances, 95)
522
+
523
+ return hd95
524
+
525
+
526
+ def compute_all_metrics(pred, target, threshold=0.5):
527
+ """Compute all metrics at once"""
528
+ dice = compute_dice(pred, target, threshold)
529
+ iou = compute_iou(pred, target, threshold)
530
+ precision, recall = compute_precision_recall(pred, target, threshold)
531
+ hd95 = compute_hd95(pred, target, threshold)
532
+
533
+ return {
534
+ 'dice': dice,
535
+ 'iou': iou,
536
+ 'precision': precision,
537
+ 'recall': recall,
538
+ 'hd95': hd95
539
+ }
540
+
541
+
542
+ def evaluate(decoder, stem, encoder, loader, device):
543
+ """Comprehensive evaluation"""
544
+ decoder.eval()
545
+ stem.eval()
546
+ encoder.eval()
547
+
548
+ all_metrics = {
549
+ 'dice': [], 'iou': [], 'precision': [], 'recall': [], 'hd95': []
550
+ }
551
+
552
+ with torch.no_grad():
553
+ for images, masks in tqdm(loader, desc="Evaluating"):
554
+ images, masks = images.to(device), masks.to(device)
555
+ vit_features = encoder(images)
556
+ skip = stem(images)
557
+ logits = decoder(vit_features, skip)
558
+
559
+ metrics = compute_all_metrics(logits, masks)
560
+
561
+ for key in all_metrics:
562
+ all_metrics[key].append(metrics[key])
563
+
564
+ # Compute mean and std for each metric
565
+ results = {}
566
+ for key in all_metrics:
567
+ results[key] = np.mean(all_metrics[key])
568
+ results[f'{key}_std'] = np.std(all_metrics[key])
569
+
570
+ return results
571
+
572
+ # ============================================================================
573
+ # TRAINING FUNCTION
574
+ # ============================================================================
575
+
576
+ def train_model(decoder, stem, encoder, train_loader, val_loader, config):
577
+ """Enhanced training loop with cosine annealing restarts and comprehensive logging"""
578
+ device = config.device
579
+ best_score = -float('inf')
580
+ criterion = FocalDiceBoundaryLoss(focal_weight=config.focal_weight, dice_weight=config.dice_weight, boundary_weight=config.boundary_weight)
581
+
582
+ # Optimizer includes both stem and decoder parameters
583
+ optimizer = AdamW(
584
+ list(decoder.parameters()) + list(stem.parameters()),
585
+ lr=config.learning_rate,
586
+ weight_decay=config.weight_decay
587
+ )
588
+
589
+ # Cosine Annealing with Warm Restarts
590
+ scheduler = CosineAnnealingWarmRestarts(
591
+ optimizer,
592
+ T_0=config.T_0,
593
+ T_mult=config.T_mult,
594
+ eta_min=config.min_lr
595
+ )
596
+
597
+
598
+ history = {
599
+ 'train_loss': [],
600
+ 'val_metrics': [], # Store full metrics dict per epoch
601
+ 'lr': []
602
+ }
603
+
604
+ for epoch in range(config.num_epochs):
605
+ # Training
606
+ decoder.train()
607
+ stem.train()
608
+ encoder.eval()
609
+
610
+ epoch_loss = 0
611
+ progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
612
+
613
+ for batch_idx, (images, masks) in enumerate(progress_bar):
614
+ images, masks = images.to(device), masks.to(device)
615
+
616
+ # Frozen encoder
617
+ with torch.no_grad():
618
+ vit_features = encoder(images)
619
+
620
+ # Trainable stem
621
+ skip_features = stem(images)
622
+
623
+ # Trainable decoder
624
+ logits = decoder(vit_features, skip_features)
625
+ loss = criterion(logits, masks)
626
+
627
+ optimizer.zero_grad()
628
+ loss.backward()
629
+ torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
630
+ torch.nn.utils.clip_grad_norm_(stem.parameters(), max_norm=1.0)
631
+ optimizer.step()
632
+
633
+ # Step scheduler per batch for cosine annealing
634
+ scheduler.step(epoch + batch_idx / len(train_loader))
635
+
636
+ epoch_loss += loss.item()
637
+ current_lr = optimizer.param_groups[0]['lr']
638
+ progress_bar.set_postfix({'loss': loss.item(), 'lr': f'{current_lr:.2e}'})
639
+
640
+ avg_loss = epoch_loss / len(train_loader)
641
+
642
+ # Validation
643
+ val_metrics = evaluate(decoder, stem, encoder, val_loader, device)
644
+
645
+ # Store metrics
646
+ history['train_loss'].append(avg_loss)
647
+ history['val_metrics'].append(val_metrics)
648
+ history['lr'].append(current_lr)
649
+
650
+ # Save best model
651
+
652
+
653
+ current_score = (0.6 * val_metrics['dice'] +
654
+ 0.3 * val_metrics['iou'] -
655
+ 0.1 * min(val_metrics['hd95'] / 100.0, 1.0))
656
+
657
+ if current_score > best_score : # Rename best_dice to best_score for clarity
658
+ best_score = current_score
659
+ print(f"✓ Saved new best model with Dice: {val_metrics['dice']:.4f}, "
660
+ f"IoU: {val_metrics['iou']:.4f}, HD95: {val_metrics['hd95']:.2f}")
661
+ torch.save({
662
+ 'epoch': epoch,
663
+ 'decoder_state_dict': decoder.state_dict(),
664
+ 'stem_state_dict': stem.state_dict(),
665
+ 'encoder_state_dict': encoder.state_dict(),
666
+ 'optimizer_state_dict': optimizer.state_dict(),
667
+ 'best_score': best_score,
668
+ 'config': config,
669
+ }, os.path.join(config.save_dir, "best_unet_model.pth"))
670
+ print(f"✓ Saved new best model with Score: {best_score:.4f}")
671
+
672
+ # Print epoch summary
673
+ print(f"\n{'='*60}")
674
+ print(f"Epoch {epoch+1}/{config.num_epochs} Summary:")
675
+ print(f" Learning Rate: {current_lr:.6f}")
676
+ print(f" Train Loss: {avg_loss:.4f}")
677
+ print(f" Val Dice: {val_metrics['dice']:.4f} ± {val_metrics['dice_std']:.4f}")
678
+ print(f" Val IoU: {val_metrics['iou']:.4f} ± {val_metrics['iou_std']:.4f}")
679
+ print(f" Val Precision: {val_metrics['precision']:.4f} ± {val_metrics['precision_std']:.4f}")
680
+ print(f" Val Recall: {val_metrics['recall']:.4f} ± {val_metrics['recall_std']:.4f}")
681
+ print(f" Val HD95: {val_metrics['hd95']:.4f} ± {val_metrics['hd95_std']:.4f}")
682
+ print(f"{'='*60}\n")
683
+
684
+ return history, best_score
685
+
686
+ # ============================================================================
687
+ # VISUALIZATION
688
+ # ============================================================================
689
+
690
+ def visualize_predictions(decoder, stem, encoder, dataset, device, num_samples=5,
691
+ save_path="predictions.png", subset_name="Test"):
692
+ """Visualize sample predictions with all metrics"""
693
+ decoder.eval()
694
+ stem.eval()
695
+ encoder.eval()
696
+
697
+ # Create a larger figure for 5 columns (image, mask, pred, overlay, metrics)
698
+ fig, axes = plt.subplots(num_samples, 5, figsize=(20, 4*num_samples))
699
+
700
+ if num_samples == 1:
701
+ axes = axes.reshape(1, -1)
702
+
703
+ indices = np.random.choice(len(dataset), num_samples, replace=False)
704
+
705
+ with torch.no_grad():
706
+ for i, idx in enumerate(indices):
707
+ image, mask = dataset[idx]
708
+ image_batch = image.unsqueeze(0).to(device)
709
+ mask_np = mask.cpu().numpy().squeeze()
710
+
711
+ vit_features = encoder(image_batch)
712
+ skip = stem(image_batch)
713
+ logits = decoder(vit_features, skip)
714
+ pred = torch.sigmoid(logits).cpu().numpy().squeeze()
715
+ pred_binary = (pred > 0.5).astype(np.float32)
716
+
717
+ # Compute metrics
718
+ metrics = compute_all_metrics(logits, mask.to(device))
719
+
720
+ # Denormalize image for display
721
+ img_display = image.cpu().squeeze().permute(1, 2, 0).numpy()
722
+ mean = np.array(config.mean).reshape(1, 1, 3)
723
+ std = np.array(config.std).reshape(1, 1, 3)
724
+ img_display = img_display * std + mean
725
+ img_display = np.clip(img_display, 0, 1)
726
+
727
+ # Create overlay
728
+ overlay = img_display.copy()
729
+ overlay[pred_binary > 0.5] = [1, 0, 0] # Red for predictions
730
+ overlay = 0.7 * img_display + 0.3 * overlay
731
+
732
+ # Plot images
733
+ axes[i, 0].imshow(img_display)
734
+ axes[i, 0].set_title("Input Image")
735
+ axes[i, 0].axis('off')
736
+
737
+ axes[i, 1].imshow(mask_np, cmap='gray')
738
+ axes[i, 1].set_title("Ground Truth")
739
+ axes[i, 1].axis('off')
740
+
741
+ axes[i, 2].imshow(pred_binary, cmap='gray')
742
+ axes[i, 2].set_title("Prediction")
743
+ axes[i, 2].axis('off')
744
+
745
+ axes[i, 3].imshow(overlay)
746
+ axes[i, 3].set_title("Overlay")
747
+ axes[i, 3].axis('off')
748
+
749
+ # Display metrics in text
750
+ metrics_text = f"Dice: {metrics['dice']:.3f}\nIoU: {metrics['iou']:.3f}\nHD95: {metrics['hd95']:.1f}"
751
+ axes[i, 4].text(0.1, 0.5, metrics_text, fontsize=12, verticalalignment='center',
752
+ transform=axes[i, 4].transAxes)
753
+ axes[i, 4].axis('off')
754
+
755
+ plt.suptitle(f"{subset_name} Set - Sample Predictions", fontsize=16, y=1.02)
756
+ plt.tight_layout()
757
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
758
+ plt.close()
759
+ print(f"Visualization saved to {save_path}")
760
+
761
+ # ============================================================================
762
+ # MAIN PIPELINE
763
+ # ============================================================================
764
+
765
+ def load_and_prepare_data(config):
766
+ """Load Kvasir-SEG dataset and create train/val/test splits"""
767
+
768
+ images_path = os.path.join(config.dataset_path, "images")
769
+ masks_path = os.path.join(config.dataset_path, "masks")
770
+
771
+ if not os.path.exists(images_path):
772
+ images_path = config.dataset_path
773
+ masks_path = config.dataset_path
774
+
775
+ image_files = sorted(glob.glob(os.path.join(images_path, "*.jpg")))
776
+ mask_files = sorted(glob.glob(os.path.join(masks_path, "*.jpg")))
777
+
778
+ if len(image_files) == 0:
779
+ image_files = sorted(glob.glob(os.path.join(images_path, "*.png")))
780
+ mask_files = sorted(glob.glob(os.path.join(masks_path, "*.png")))
781
+
782
+ print(f"Found {len(image_files)} images and {len(mask_files)} masks")
783
+
784
+ if len(image_files) == 0:
785
+ raise FileNotFoundError(f"No images found in {config.dataset_path}")
786
+
787
+ assert len(image_files) == len(mask_files), f"Mismatch: {len(image_files)} images vs {len(mask_files)} masks"
788
+
789
+ # Split into train/val/test
790
+ train_files, temp_files = train_test_split(
791
+ list(zip(image_files, mask_files)),
792
+ test_size=config.val_split + config.test_split,
793
+ random_state=42
794
+ )
795
+ val_files, test_files = train_test_split(
796
+ temp_files,
797
+ test_size=config.test_split / (config.val_split + config.test_split),
798
+ random_state=42
799
+ )
800
+
801
+ train_images, train_masks = zip(*train_files) if train_files else ([], [])
802
+ val_images, val_masks = zip(*val_files) if val_files else ([], [])
803
+ test_images, test_masks = zip(*test_files) if test_files else ([], [])
804
+
805
+ print(f"Train: {len(train_images)}, Val: {len(val_images)}, Test: {len(test_images)}")
806
+
807
+ return (list(train_images), list(train_masks)), (list(val_images), list(val_masks)), (list(test_images), list(test_masks))
808
+
809
+
810
+ def plot_training_history(history, save_dir):
811
+ """Plot training history"""
812
+ epochs = range(1, len(history['train_loss']) + 1)
813
+
814
+ # Extract validation metrics
815
+ val_dice = [m['dice'] for m in history['val_metrics']]
816
+ val_iou = [m['iou'] for m in history['val_metrics']]
817
+ val_hd95 = [m['hd95'] for m in history['val_metrics']]
818
+ val_precision = [m['precision'] for m in history['val_metrics']]
819
+ val_recall = [m['recall'] for m in history['val_metrics']]
820
+
821
+ fig, axes = plt.subplots(2, 3, figsize=(18, 10))
822
+
823
+ # Loss
824
+ axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss')
825
+ axes[0, 0].set_title('Training Loss')
826
+ axes[0, 0].set_xlabel('Epoch')
827
+ axes[0, 0].set_ylabel('Loss')
828
+ axes[0, 0].grid(True)
829
+ axes[0, 0].legend()
830
+
831
+ # Learning Rate
832
+ axes[0, 1].plot(epochs, history['lr'], 'g-')
833
+ axes[0, 1].set_title('Learning Rate')
834
+ axes[0, 1].set_xlabel('Epoch')
835
+ axes[0, 1].set_ylabel('LR')
836
+ axes[0, 1].set_yscale('log')
837
+ axes[0, 1].grid(True)
838
+
839
+ # Dice
840
+ axes[0, 2].plot(epochs, val_dice, 'r-', label='Val Dice')
841
+ axes[0, 2].set_title('Validation Dice')
842
+ axes[0, 2].set_xlabel('Epoch')
843
+ axes[0, 2].set_ylabel('Dice')
844
+ axes[0, 2].grid(True)
845
+ axes[0, 2].legend()
846
+
847
+ # IoU
848
+ axes[1, 0].plot(epochs, val_iou, 'm-', label='Val IoU')
849
+ axes[1, 0].set_title('Validation IoU')
850
+ axes[1, 0].set_xlabel('Epoch')
851
+ axes[1, 0].set_ylabel('IoU')
852
+ axes[1, 0].grid(True)
853
+ axes[1, 0].legend()
854
+
855
+ # HD95
856
+ axes[1, 1].plot(epochs, val_hd95, 'c-', label='Val HD95')
857
+ axes[1, 1].set_title('Validation HD95')
858
+ axes[1, 1].set_xlabel('Epoch')
859
+ axes[1, 1].set_ylabel('HD95 (pixels)')
860
+ axes[1, 1].grid(True)
861
+ axes[1, 1].legend()
862
+
863
+ # Precision & Recall
864
+ axes[1, 2].plot(epochs, val_precision, 'orange', label='Precision')
865
+ axes[1, 2].plot(epochs, val_recall, 'purple', label='Recall')
866
+ axes[1, 2].set_title('Validation Precision & Recall')
867
+ axes[1, 2].set_xlabel('Epoch')
868
+ axes[1, 2].set_ylabel('Value')
869
+ axes[1, 2].grid(True)
870
+ axes[1, 2].legend()
871
+
872
+ plt.tight_layout()
873
+ plt.savefig(os.path.join(save_dir, 'training_history.png'), dpi=150, bbox_inches='tight')
874
+ plt.close()
875
+
876
+ # Save history to CSV
877
+ history_df = pd.DataFrame({
878
+ 'epoch': epochs,
879
+ 'train_loss': history['train_loss'],
880
+ 'val_dice': val_dice,
881
+ 'val_iou': val_iou,
882
+ 'val_hd95': val_hd95,
883
+ 'val_precision': val_precision,
884
+ 'val_recall': val_recall,
885
+ 'lr': history['lr']
886
+ })
887
+ history_df.to_csv(os.path.join(save_dir, 'training_history.csv'), index=False)
888
+
889
+
890
+ def main():
891
+ print("=" * 60)
892
+ print("DINOv3 Polyp Segmentation Training - With HD95 & Cosine Annealing")
893
+ print("=" * 60)
894
+
895
+ # Load data
896
+ print("\n1. Loading dataset...")
897
+ train_data, val_data, test_data = load_and_prepare_data(config)
898
+
899
+ # Data augmentations
900
+ train_transform = A.Compose([
901
+ A.Resize(config.image_size, config.image_size),
902
+ A.RandomRotate90(p=0.5),
903
+ A.HorizontalFlip(p=0.5),
904
+ A.VerticalFlip(p=0.5),
905
+ A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
906
+ A.OneOf([
907
+ A.MotionBlur(p=0.2),
908
+ A.GaussianBlur(blur_limit=3, p=0.2),
909
+ ], p=0.3),
910
+ A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05, p=0.3),
911
+ ToTensorV2(),
912
+ ])
913
+
914
+ val_transform = A.Compose([
915
+ A.Resize(config.image_size, config.image_size),
916
+ ToTensorV2(),
917
+ ])
918
+
919
+ # Create datasets
920
+ train_dataset = PolypDataset(
921
+ train_data[0], train_data[1],
922
+ transform=train_transform,
923
+ target_size=(config.image_size, config.image_size)
924
+ )
925
+ val_dataset = PolypDataset(
926
+ val_data[0], val_data[1],
927
+ transform=val_transform,
928
+ target_size=(config.image_size, config.image_size)
929
+ )
930
+ test_dataset = PolypDataset(
931
+ test_data[0], test_data[1],
932
+ transform=val_transform,
933
+ target_size=(config.image_size, config.image_size)
934
+ )
935
+
936
+ # Dataloaders
937
+ train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
938
+ val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
939
+ test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
940
+
941
+ print(f"\n2. Initializing DINOv3 encoder...")
942
+
943
+ encoder = DINOv3Encoder(
944
+ model_name=config.model_name,
945
+ local_path=config.local_model_path,
946
+ freeze=True,
947
+ layers=config.multi_scale_layers
948
+ ).to(config.device)
949
+
950
+ # Test encoder
951
+ print(" Testing encoder with sample batch...")
952
+ sample_images, _ = next(iter(train_loader))
953
+ sample_images = sample_images.to(config.device)
954
+ with torch.no_grad():
955
+ sample_features = encoder(sample_images)
956
+ print(f" Encoder output shape: {sample_features.shape}")
957
+
958
+ print("\n3. Building U‑Net decoder with skip connections...")
959
+
960
+ stem = ShallowStem(in_channels=3, base_channels=64).to(config.device)
961
+ decoder = UNetDecoder(
962
+ vit_channels=encoder.out_channels,
963
+ stem_channels=[512, 256, 128],
964
+ num_classes=1
965
+ ).to(config.device)
966
+
967
+ trainable = sum(p.numel() for p in decoder.parameters()) + sum(p.numel() for p in stem.parameters())
968
+ print(f" Trainable parameters (stem + decoder): {trainable:,}")
969
+
970
+ print("\n4. Starting training with Cosine Annealing Warm Restarts...")
971
+ print(f" Initial LR: {config.learning_rate:.6f}")
972
+ print(f" T_0: {config.T_0}, T_mult: {config.T_mult}")
973
+ print(f" Min LR: {config.min_lr:.6f}")
974
+
975
+ history, best_score = train_model(decoder, stem, encoder, train_loader, val_loader, config)
976
+
977
+ print(f"\n✓ Training complete! Best validation Score: {best_score:.4f}")
978
+
979
+ # Final evaluation on all sets
980
+ print("\n5. Final evaluation on all sets...")
981
+
982
+ # Load best model for final evaluation
983
+ checkpoint = torch.load(os.path.join(config.save_dir, "best_unet_model.pth"),weights_only=False)
984
+ decoder.load_state_dict(checkpoint['decoder_state_dict'])
985
+ stem.load_state_dict(checkpoint['stem_state_dict'])
986
+
987
+ # Evaluate on all splits
988
+ print("\nEvaluating on Training Set...")
989
+ train_metrics = evaluate(decoder, stem, encoder, train_loader, config.device)
990
+
991
+ print("Evaluating on Validation Set...")
992
+ val_metrics = evaluate(decoder, stem, encoder, val_loader, config.device)
993
+
994
+ print("Evaluating on Test Set...")
995
+ test_metrics = evaluate(decoder, stem, encoder, test_loader, config.device)
996
+
997
+ # Print comprehensive results
998
+ print("\n" + "=" * 80)
999
+ print("FINAL RESULTS - ALL METRICS")
1000
+ print("=" * 80)
1001
+
1002
+ print(f"\n{'Metric':<15} {'Train':<20} {'Validation':<20} {'Test':<20}")
1003
+ print("-" * 75)
1004
+
1005
+ for metric in ['dice', 'iou', 'precision', 'recall', 'hd95']:
1006
+ print(f"{metric.upper():<15} "
1007
+ f"{train_metrics[metric]:.4f} ± {train_metrics[f'{metric}_std']:.4f} "
1008
+ f"{val_metrics[metric]:.4f} ± {val_metrics[f'{metric}_std']:.4f} "
1009
+ f"{test_metrics[metric]:.4f} ± {test_metrics[f'{metric}_std']:.4f}")
1010
+
1011
+ print("=" * 80)
1012
+
1013
+ # Plot training history
1014
+ print("\n6. Plotting training history...")
1015
+ plot_training_history(history, config.save_dir)
1016
+
1017
+ # Visualize predictions for all subsets
1018
+ print("\n7. Generating visualizations for all subsets...")
1019
+ visualize_predictions(decoder, stem, encoder, train_dataset, config.device,
1020
+ num_samples=5, save_path=os.path.join(config.save_dir, "train_predictions.png"),
1021
+ subset_name="Training")
1022
+ visualize_predictions(decoder, stem, encoder, val_dataset, config.device,
1023
+ num_samples=5, save_path=os.path.join(config.save_dir, "val_predictions.png"),
1024
+ subset_name="Validation")
1025
+ visualize_predictions(decoder, stem, encoder, test_dataset, config.device,
1026
+ num_samples=5, save_path=os.path.join(config.save_dir, "test_predictions.png"),
1027
+ subset_name="Test")
1028
+
1029
+ # Save comprehensive results
1030
+ results = {
1031
+ 'best_val_score': float(best_score),
1032
+ 'final_epoch': len(history['train_loss']),
1033
+ 'train_metrics': {k: float(v) for k, v in train_metrics.items()},
1034
+ 'val_metrics': {k: float(v) for k, v in val_metrics.items()},
1035
+ 'test_metrics': {k: float(v) for k, v in test_metrics.items()},
1036
+ 'training_history': {
1037
+ 'train_loss': [float(x) for x in history['train_loss']],
1038
+ 'lr': [float(x) for x in history['lr']],
1039
+ 'val_metrics': [{k: float(v) for k, v in m.items()} for m in history['val_metrics']]
1040
+ },
1041
+ 'config': {
1042
+ 'model_name': config.model_name,
1043
+ 'image_size': config.image_size,
1044
+ 'batch_size': config.batch_size,
1045
+ 'num_epochs': config.num_epochs,
1046
+ 'learning_rate': config.learning_rate,
1047
+ 'min_lr': config.min_lr,
1048
+ 'T_0': config.T_0,
1049
+ 'T_mult': config.T_mult,
1050
+ 'scheduler': 'CosineAnnealingWarmRestarts',
1051
+ 'focal_weight': config.focal_weight,
1052
+ 'dice_weight': config.dice_weight,
1053
+ 'multi_scale_layers': config.multi_scale_layers
1054
+ }
1055
+ }
1056
+
1057
+ # Save as JSON
1058
+ with open(os.path.join(config.save_dir, "comprehensive_results.json"), 'w') as f:
1059
+ json.dump(results, f, indent=2)
1060
+
1061
+ # Save as formatted text report
1062
+ with open(os.path.join(config.save_dir, "results_report.txt"), 'w') as f:
1063
+ f.write("=" * 80 + "\n")
1064
+ f.write("DINOv3 POLYP SEGMENTATION - FINAL REPORT\n")
1065
+ f.write("=" * 80 + "\n\n")
1066
+ f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
1067
+
1068
+ f.write("CONFIGURATION:\n")
1069
+ f.write("-" * 40 + "\n")
1070
+ for key, value in results['config'].items():
1071
+ f.write(f" {key}: {value}\n")
1072
+
1073
+ f.write("\n\nFINAL METRICS:\n")
1074
+ f.write("-" * 40 + "\n")
1075
+ f.write(f"{'Metric':<15} {'Train':<25} {'Validation':<25} {'Test':<25}\n")
1076
+ f.write("-" * 90 + "\n")
1077
+
1078
+ for metric in ['dice', 'iou', 'precision', 'recall', 'hd95']:
1079
+ f.write(f"{metric.upper():<15} "
1080
+ f"{train_metrics[metric]:.4f} ± {train_metrics[f'{metric}_std']:.4f} "
1081
+ f"{val_metrics[metric]:.4f} ± {val_metrics[f'{metric}_std']:.4f} "
1082
+ f"{test_metrics[metric]:.4f} ± {test_metrics[f'{metric}_std']:.4f}\n")
1083
+
1084
+ f.write("\n\nBest Validation Score (Dice+IoU-HD95/100): {:.4f}\n".format(best_score))
1085
+ f.write("Training completed at epoch: {}\n".format(len(history['train_loss'])))
1086
+
1087
+ print(f"\n✓ Comprehensive results saved to {config.save_dir}/")
1088
+ print(f" - comprehensive_results.json")
1089
+ print(f" - results_report.txt")
1090
+ print(f" - training_history.csv")
1091
+ print(f" - training_history.png")
1092
+ print(f" - train_predictions.png")
1093
+ print(f" - val_predictions.png")
1094
+ print(f" - test_predictions.png")
1095
+ print("\n🎉 Enhanced training pipeline complete!")
1096
+
1097
+
1098
+ if __name__ == "__main__":
1099
+ main()