InPeerReview committed on
Commit
8b04570
·
verified ·
1 Parent(s): 7ad9dfd

Upload 3 files

Browse files
Files changed (3) hide show
  1. requirement.txt +70 -0
  2. test_cd.py +340 -0
  3. train_cd.py +312 -0
requirement.txt ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ addict==2.4.0
2
+ aliyun-python-sdk-core==2.15.0
3
+ aliyun-python-sdk-kms==2.16.2
4
+ cffi==1.16.0
5
+ click==8.1.7
6
+ colorama==0.4.6
7
+ contourpy==1.2.0
8
+ crcmod==1.7
9
+ cryptography==42.0.5
10
+ cycler==0.12.1
11
+ einops==0.7.0
12
+ fonttools==4.49.0
13
+ fsspec==2024.2.0
14
+ ftfy==6.1.3
15
+ fvcore==0.1.5.post20221221
16
+ huggingface-hub==0.21.4
17
+ importlib-metadata==7.0.2
18
+ iopath==0.1.10
19
+ jmespath==0.10.0
20
+ kiwisolver==1.4.5
21
+ markdown==3.5.2
22
+ markdown-it-py==3.0.0
23
+ matplotlib==3.8.3
24
+ mdurl==0.1.2
25
+ mmcv==2.1.0
26
+ mmengine==0.10.3
27
+ mmsegmentation==1.2.2
28
+ model-index==0.1.11
29
+ monai==1.3.0
30
+ ninja==1.11.1.1
31
+ opencv-python==4.9.0.80
32
+ opendatalab==0.0.10
33
+ openmim==0.3.9
34
+ openxlab==0.0.35
35
+ ordered-set==4.1.0
36
+ oss2==2.17.0
37
+ packaging==24.0
38
+ pandas==2.2.1
39
+ platformdirs==4.2.0
40
+ portalocker==2.8.2
41
+ prettytable==3.10.0
42
+ pycparser==2.21
43
+ pycryptodome==3.20.0
44
+ pygments==2.17.2
45
+ pyparsing==3.1.2
46
+ python-dateutil==2.9.0.post0
47
+ pytoolconfig==1.3.1
48
+ pytz==2023.4
49
+ regex==2023.12.25
50
+ requests==2.28.2
51
+ rich==13.4.2
52
+ rope==1.12.0
53
+ safetensors==0.4.2
54
+ scipy==1.12.0
55
+ setuptools==60.2.0
56
+ six==1.16.0
57
+ tabulate==0.9.0
58
+ termcolor==2.4.0
59
+ thop==0.1.1-2209072238
60
+ timm==0.9.16
61
+ tokenizers==0.15.2
62
+ tomli==2.0.1
63
+ tqdm==4.65.2
64
+ transformers==4.38.2
65
+ tzdata==2024.1
66
+ urllib3==1.26.18
67
+ wcwidth==0.2.13
68
+ yacs==0.1.8
69
+ yapf==0.40.2
70
+ zipp==3.17.0
test_cd.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim as optim
3
+ import data as Data
4
+ import models as Model
5
+ import torch.nn as nn
6
+ import argparse
7
+ import logging
8
+ import core.logger as Logger
9
+ import os
10
+ import numpy as np
11
+ from misc.metric_tools import ConfuseMatrixMeter
12
+ from models.loss import *
13
+ from collections import OrderedDict
14
+ import core.metrics as Metrics
15
+ from misc.torchutils import get_scheduler, save_network
16
+
17
if __name__ == '__main__':
    # Change-detection evaluation driver (also supports a training phase).
    # NOTE(review): reconstructed from an indentation-stripped diff dump; the
    # block structure below follows the statement order of the original and
    # the conventions of its sibling script train_cd.py — confirm against the
    # upstream repository.

    # ----------------------------------------------------------------- CLI --
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./config/whu/whu_test.json',
                        help='JSON file for configuration')
    parser.add_argument('--phase', type=str, default='test',
                        choices=['train', 'test'], help='Run either train(training + validation) or testing',)
    parser.add_argument('--gpu_ids', type=str, default=None)
    parser.add_argument('-log_eval', action='store_true')

    args = parser.parse_args()
    opt = Logger.parse(args)
    # Convert to a NoneDict so missing config keys read as None instead of
    # raising KeyError.
    opt = Logger.dict_to_nonedict(opt)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # ------------------------------------------------------------- logging --
    Logger.setup_logger(logger_name=None, root=opt['path_cd']['log'], phase='train',
                        level=logging.INFO, screen=True)
    Logger.setup_logger(logger_name='test', root=opt['path_cd']['log'], phase='test',
                        level=logging.INFO)
    logger = logging.getLogger('base')
    logger.info(Logger.dict2str(opt))

    # --------------------------------------------------------- dataloaders --
    # Build only the loaders the selected phase actually needs.
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train' and args.phase != 'test':
            print("Create [train] change-detection dataloader")
            train_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            train_loader = Data.create_cd_dataloader(train_set, dataset_opt, phase)
            opt['len_train_dataloader'] = len(train_loader)

        elif phase == 'val' and args.phase != 'test':
            print("Create [val] change-detection dataloader")
            val_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            val_loader = Data.create_cd_dataloader(val_set, dataset_opt, phase)
            opt['len_val_dataloader'] = len(val_loader)

        elif phase == 'test' and args.phase == 'test':
            print("Create [test] change-detection dataloader")
            test_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            test_loader = Data.create_cd_dataloader(test_set, dataset_opt, phase)
            opt['len_test_dataloader'] = len(test_loader)

    logger.info('Initial Dataset Finished')
    cd_model = Model.create_CD_model(opt)

    # Loss selection. NOTE(review): an unrecognised loss name leaves loss_fun
    # undefined (NameError at first use) — same as the original; consider an
    # explicit `else: raise ValueError(...)`.
    if opt['model']['loss'] == 'ce_dice':
        loss_fun = ce_dice
    elif opt['model']['loss'] == 'ce':
        loss_fun = cross_entropy

    # Optimizer selection.
    if opt['train']["optimizer"]["type"] == 'adam':
        optimer = optim.Adam(cd_model.parameters(), lr=opt['train']["optimizer"]["lr"])
    elif opt['train']["optimizer"]["type"] == 'adamw':
        optimer = optim.AdamW(cd_model.parameters(), lr=opt['train']["optimizer"]["lr"])

    # FIX: the original evaluated len(opt['gpu_ids']) unconditionally, which
    # raises TypeError when gpu_ids is None (the CLI default). Also treat an
    # empty gpu_ids list as CPU.
    device = torch.device('cuda' if opt['gpu_ids'] else 'cpu')
    cd_model.to(device)
    if opt['gpu_ids'] and len(opt['gpu_ids']) > 0:
        cd_model = nn.DataParallel(cd_model)
    metric = ConfuseMatrixMeter(n_class=2)
    log_dict = OrderedDict()

    if opt['phase'] == 'train':
        best_mF1 = 0.0
        # FIX: build the LR scheduler ONCE. The original called
        # get_scheduler(optimizer=optimer, args=opt['train']).step() inside the
        # epoch loop, recreating the scheduler each epoch and resetting its
        # internal state, so the learning-rate schedule never advanced.
        scheduler = get_scheduler(optimizer=optimer, args=opt['train'])
        for current_epoch in range(0, opt['train']['n_epoch']):
            print("......Begin Training......")
            metric.clear()
            cd_model.train()
            train_result_path = '{}/train/{}'.format(opt['path_cd']['result'], current_epoch)
            os.makedirs(train_result_path, exist_ok=True)

            message = 'lr: %0.7f\n \n' % optimer.param_groups[0]['lr']
            logger.info(message)
            for current_step, train_data in enumerate(train_loader):
                train_im1 = train_data['A'].to(device)
                train_im2 = train_data['B'].to(device)
                pred_img = cd_model(train_im1, train_im2)
                gt = train_data['L'].to(device).long()
                train_loss = loss_fun(pred_img, gt)
                optimer.zero_grad()
                train_loss.backward()
                optimer.step()
                log_dict['loss'] = train_loss.item()

                # Running confusion-matrix update on the hard predictions.
                G_pred = pred_img.detach()
                G_pred = torch.argmax(G_pred, dim=1)
                current_score = metric.update_cm(pr=G_pred.cpu().numpy(), gt=gt.detach().cpu().numpy())
                log_dict['running_acc'] = current_score.item()

                if current_step % opt['train']['train_print_iter'] == 0:
                    logs = log_dict
                    message = '[Training CD]. epoch: [%d/%d]. Itter: [%d/%d], CD_loss: %.5f, running_mf1: %.5f\n' % \
                              (current_epoch, opt['train']['n_epoch'] - 1, current_step, len(train_loader), logs['loss'],
                               logs['running_acc'])
                    logger.info(message)

                    # Periodic visualisation of predictions vs. ground truth.
                    out_dict = OrderedDict()
                    out_dict['pred_cm'] = torch.argmax(pred_img, dim=1, keepdim=False)
                    out_dict['gt_cm'] = gt
                    visuals = out_dict

                    img_mode = "grid"
                    if img_mode == "single":
                        img_A = Metrics.tensor2img(train_data['A'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                        img_B = Metrics.tensor2img(train_data['B'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                        gt_cm = Metrics.tensor2img(visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1), out_type=np.uint8,
                                                   min_max=(0, 1))  # uint8
                        pred_cm = Metrics.tensor2img(visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1), out_type=np.uint8,
                                                     min_max=(0, 1))  # uint8

                        Metrics.save_img(
                            img_A, '{}/img_A_e{}_b{}.png'.format(train_result_path, current_epoch, current_step))
                        Metrics.save_img(
                            img_B, '{}/img_B_e{}_b{}.png'.format(train_result_path, current_epoch, current_step))
                        Metrics.save_img(
                            pred_cm, '{}/img_pred_e{}_b{}.png'.format(train_result_path, current_epoch, current_step))
                        Metrics.save_img(
                            gt_cm, '{}/img_gt_e{}_b{}.png'.format(train_result_path, current_epoch, current_step))
                    else:
                        # Rescale {0,1} masks to [-1,1] so they share the input
                        # images' value range inside the grid.
                        visuals['pred_cm'] = visuals['pred_cm'] * 2.0 - 1.0
                        visuals['gt_cm'] = visuals['gt_cm'] * 2.0 - 1.0
                        grid_img = torch.cat((train_data['A'].to(device),
                                              train_data['B'].to(device),
                                              visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                              visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1)),
                                             dim=0)
                        grid_img = Metrics.tensor2img(grid_img)  # uint8
                        Metrics.save_img(
                            grid_img,
                            '{}/img_A_B_pred_gt_e{}_b{}.png'.format(train_result_path, current_epoch, current_step))

            # -------------------------------------------- epoch summary ----
            scores = metric.get_scores()
            epoch_acc = scores['mf1']
            log_dict['epoch_acc'] = epoch_acc.item()
            for k, v in scores.items():
                log_dict[k] = v
            logs = log_dict
            message = '[Training CD (epoch summary)]: epoch: [%d/%d]. epoch_mF1=%.5f \n' % \
                      (current_epoch, opt['train']['n_epoch'] - 1, logs['epoch_acc'])
            for k, v in logs.items():
                message += '{:s}: {:.4e} '.format(k, v)
            message += '\n'
            logger.info(message)

            metric.clear()

            # ---------------------------------------------- validation -----
            # FIX: default is_best_model each epoch; the original only assigned
            # it inside the val_freq branch, so save_network could hit a
            # NameError on epochs where validation was skipped.
            is_best_model = False
            cd_model.eval()
            with torch.no_grad():
                if current_epoch % opt['train']['val_freq'] == 0:
                    val_result_path = '{}/val/{}'.format(opt['path_cd']['result'], current_epoch)
                    os.makedirs(val_result_path, exist_ok=True)

                    for current_step, val_data in enumerate(val_loader):
                        val_img1 = val_data['A'].to(device)
                        val_img2 = val_data['B'].to(device)
                        pred_img = cd_model(val_img1, val_img2)
                        gt = val_data['L'].to(device).long()
                        val_loss = loss_fun(pred_img, gt)
                        log_dict['loss'] = val_loss.item()

                        G_pred = pred_img.detach()
                        G_pred = torch.argmax(G_pred, dim=1)
                        current_score = metric.update_cm(pr=G_pred.cpu().numpy(), gt=gt.detach().cpu().numpy())
                        log_dict['running_acc'] = current_score.item()

                        if current_step % opt['train']['val_print_iter'] == 0:
                            logs = log_dict
                            message = '[Validation CD]. epoch: [%d/%d]. Itter: [%d/%d], running_mf1: %.5f\n' % \
                                      (current_epoch, opt['train']['n_epoch'] - 1, current_step, len(val_loader), logs['running_acc'])
                            logger.info(message)

                            out_dict = OrderedDict()
                            out_dict['pred_cm'] = torch.argmax(pred_img, dim=1, keepdim=False)
                            out_dict['gt_cm'] = gt
                            visuals = out_dict

                            img_mode = "single"
                            if img_mode == "single":
                                img_A = Metrics.tensor2img(val_data['A'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                                img_B = Metrics.tensor2img(val_data['B'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                                gt_cm = Metrics.tensor2img(visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                           out_type=np.uint8, min_max=(0, 1))  # uint8
                                pred_cm = Metrics.tensor2img(visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                             out_type=np.uint8, min_max=(0, 1))  # uint8

                                Metrics.save_img(
                                    img_A, '{}/img_A_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                                Metrics.save_img(
                                    img_B, '{}/img_B_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                                Metrics.save_img(
                                    pred_cm, '{}/img_pred_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                                Metrics.save_img(
                                    gt_cm, '{}/img_gt_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                            else:
                                visuals['pred_cm'] = visuals['pred_cm'] * 2.0 - 1.0
                                visuals['gt_cm'] = visuals['gt_cm'] * 2.0 - 1.0
                                grid_img = torch.cat((val_data['A'].to(device),
                                                      val_data['B'].to(device),
                                                      visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                      visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1)),
                                                     dim=0)
                                grid_img = Metrics.tensor2img(grid_img)  # uint8
                                Metrics.save_img(
                                    grid_img, '{}/img_A_B_pred_gt_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))

                    scores = metric.get_scores()
                    epoch_acc = scores['mf1']
                    log_dict['epoch_acc'] = epoch_acc.item()
                    for k, v in scores.items():
                        log_dict[k] = v
                    logs = log_dict
                    message = '[Validation CD (epoch summary)]: epoch: [%d/%d]. epoch_mF1=%.5f \n' % \
                              (current_epoch, opt['train']['n_epoch'], logs['epoch_acc'])
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                    message += '\n'
                    logger.info(message)

                    if logs['epoch_acc'] > best_mF1:
                        is_best_model = True
                        best_mF1 = logs['epoch_acc']
                        logger.info('[Validation CD] Best model updated. Saving the models (current + best) and training states.')
                    else:
                        is_best_model = False
                        logger.info('[Validation CD] Saving the current cd model and training states.')
                    logger.info('--- Proceed To The Next Epoch ----\n \n')

                    save_network(opt, current_epoch, cd_model, optimer, is_best_model)
                    metric.clear()

            scheduler.step()
        logger.info('End of training.')

    else:
        # ----------------------------------------------------- test phase --
        logger.info('Begin model evaluation (testing phase)')
        test_result_path = '{}/test/'.format(opt['path_cd']['result'])
        os.makedirs(test_result_path, exist_ok=True)
        logger_test = logging.getLogger('test')

        load_path = opt["path_cd"]["resume_state"]
        print(load_path)
        if load_path is not None:
            logger.info('Loading pre-trained change detection model [{:s}] ...'.format(load_path))
            gen_path = '{}_gen.pth'.format(load_path)
            opt_path = '{}_opt.pth'.format(load_path)

            cd_model = Model.create_CD_model(opt)
            # FIX: map_location so a GPU-trained checkpoint also loads on a
            # CPU-only machine instead of raising a CUDA deserialization error.
            cpkt_state = torch.load(gen_path, map_location=device)
            missing_keys, unexpected_keys = cd_model.load_state_dict(cpkt_state, strict=False)
            print(missing_keys)
            cd_model.to(device)
            metric.clear()
            cd_model.eval()
            with torch.no_grad():
                for current_step, test_data in enumerate(test_loader):
                    test_img1 = test_data['A'].to(device)
                    test_img2 = test_data['B'].to(device)
                    pred_img = cd_model(test_img1, test_img2)

                    # Some models return (logits, aux); keep the main output.
                    if isinstance(pred_img, tuple):
                        pred_img = pred_img[0]

                    gt = test_data['L'].to(device).long()

                    G_pred = pred_img.detach()
                    G_pred = torch.argmax(G_pred, dim=1)
                    current_score = metric.update_cm(pr=G_pred.cpu().numpy(), gt=gt.detach().cpu().numpy())
                    log_dict['running_acc'] = current_score.item()

                    logs = log_dict
                    message = '[Test Change Detection] Iteration: [%d/%d], current mF1: %.5f\n' % \
                              (current_step, len(test_loader), logs['running_acc'])
                    logger_test.info(message)

                    out_dict = OrderedDict()
                    out_dict['pred_cm'] = torch.argmax(pred_img, dim=1, keepdim=False)
                    out_dict['gt_cm'] = gt
                    visuals = out_dict

                    img_mode = 'single'
                    if img_mode == 'single':
                        # NOTE(review): masks are rescaled to [-1,1] here but
                        # rendered with min_max=(0, 1) below, which clips the
                        # -1 values to 0 — looks unintended; confirm before
                        # changing the saved-image semantics.
                        visuals['pred_cm'] = visuals['pred_cm'] * 2.0 - 1.0
                        visuals['gt_cm'] = visuals['gt_cm'] * 2.0 - 1.0
                        img_A = Metrics.tensor2img(test_data['A'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                        img_B = Metrics.tensor2img(test_data['B'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                        gt_cm = Metrics.tensor2img(visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                   out_type=np.uint8, min_max=(0, 1))  # uint8
                        pred_cm = Metrics.tensor2img(visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                     out_type=np.uint8, min_max=(0, 1))  # uint8

                        Metrics.save_img(
                            img_A, '{}/img_A_{}.png'.format(test_result_path, current_step))
                        Metrics.save_img(
                            img_B, '{}/img_B_{}.png'.format(test_result_path, current_step))
                        Metrics.save_img(
                            pred_cm, '{}/img_pred_cm{}.png'.format(test_result_path, current_step))
                        Metrics.save_img(
                            gt_cm, '{}/img_gt_cm{}.png'.format(test_result_path, current_step))
                    else:
                        visuals['pred_cm'] = visuals['pred_cm'] * 2.0 - 1.0
                        visuals['gt_cm'] = visuals['gt_cm'] * 2.0 - 1.0
                        grid_img = torch.cat((test_data['A'],
                                              test_data['B'],
                                              visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                              visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1)),
                                             dim=0)
                        grid_img = Metrics.tensor2img(grid_img)  # uint8
                        Metrics.save_img(
                            grid_img, '{}/img_A_B_pred_gt_{}.png'.format(test_result_path, current_step))

            scores = metric.get_scores()
            epoch_acc = scores['mf1']
            log_dict['epoch_acc'] = epoch_acc.item()
            for k, v in scores.items():
                log_dict[k] = v
            logs = log_dict
            message = '[Test Change Detection Summary]: Test mF1=%.5f \n' % \
                      (logs['epoch_acc'])
            for k, v in logs.items():
                message += '{:s}: {:.4e} '.format(k, v)
            message += '\n'
            logger_test.info(message)
            logger.info('Testing finished...')
train_cd.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim as optim
3
+ import data as Data
4
+ import models as Model
5
+ import torch.nn as nn
6
+ import argparse
7
+ import logging
8
+ import core.logger as Logger
9
+ import os
10
+ import numpy as np
11
+ from misc.metric_tools import ConfuseMatrixMeter
12
+ from models.loss import *
13
+ from collections import OrderedDict
14
+ import core.metrics as Metrics
15
+ from misc.torchutils import get_scheduler, save_network
16
+
17
+
18
if __name__ == '__main__':
    # Change-detection training driver (with a fallback testing phase).
    # NOTE(review): reconstructed from an indentation-stripped diff dump; the
    # block structure follows the statement order of the original and mirrors
    # the sibling script test_cd.py — confirm against the upstream repository.

    # ----------------------------------------------------------------- CLI --
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./config/whu/whu.json',
                        help='JSON configuration file for training')
    parser.add_argument('--phase', type=str, default='train',
                        choices=['train', 'test'], help='Choose between training or testing')
    parser.add_argument('--gpu_ids', type=str, default=None, help='Specify GPU device')
    parser.add_argument('-log_eval', action='store_true', help='Whether to log evaluation')

    args = parser.parse_args()
    opt = Logger.parse(args)
    # Convert to a NoneDict so missing config keys read as None.
    opt = Logger.dict_to_nonedict(opt)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # ------------------------------------------------------------- logging --
    Logger.setup_logger(logger_name=None, root=opt['path_cd']['log'], phase='train',
                        level=logging.INFO, screen=True)
    Logger.setup_logger(logger_name='test', root=opt['path_cd']['log'], phase='test',
                        level=logging.INFO)

    logger = logging.getLogger('base')
    logger.info(Logger.dict2str(opt))

    # --------------------------------------------------------- dataloaders --
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train' and args.phase != 'test':
            print("Creating [train] change-detection dataloader")
            train_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            train_loader = Data.create_cd_dataloader(train_set, dataset_opt, phase)
            opt['len_train_dataloader'] = len(train_loader)

        elif phase == 'val' and args.phase != 'test':
            print("Creating [val] change-detection dataloader")
            val_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            val_loader = Data.create_cd_dataloader(val_set, dataset_opt, phase)
            opt['len_val_dataloader'] = len(val_loader)

        elif phase == 'test':
            print("Creating [test] change-detection dataloader")
            test_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            test_loader = Data.create_cd_dataloader(test_set, dataset_opt, phase)
            opt['len_test_dataloader'] = len(test_loader)

    logger.info('Dataset initialization completed')

    cd_model = Model.create_CD_model(opt)

    # Loss selection. NOTE(review): an unrecognised name leaves loss_fun
    # undefined (NameError at first use) — same as the original.
    if opt['model']['loss'] == 'ce_dice':
        loss_fun = ce_dice
    elif opt['model']['loss'] == 'ce':
        loss_fun = cross_entropy
    elif opt['model']['loss'] == 'dice':
        loss_fun = dice
    elif opt['model']['loss'] == 'ce2_dice1':
        loss_fun = ce2_dice1
    elif opt['model']['loss'] == 'ce1_dice2':
        loss_fun = ce1_dice2

    # Optimizer selection.
    if opt['train']["optimizer"]["type"] == 'adam':
        optimer = optim.Adam(cd_model.parameters(), lr=opt['train']["optimizer"]["lr"])
    elif opt['train']["optimizer"]["type"] == 'adamw':
        optimer = optim.AdamW(cd_model.parameters(), lr=opt['train']["optimizer"]["lr"])
    elif opt['train']["optimizer"]["type"] == 'sgd':
        optimer = optim.SGD(cd_model.parameters(), lr=opt['train']["optimizer"]["lr"],
                            momentum=0.9, weight_decay=5e-4)

    # FIX: the original evaluated len(opt['gpu_ids']) unconditionally and
    # crashed with a TypeError when gpu_ids was None (the CLI default).
    device = torch.device('cuda' if opt['gpu_ids'] else 'cpu')
    cd_model.to(device)
    if opt['gpu_ids'] and len(opt['gpu_ids']) > 0:
        cd_model = nn.DataParallel(cd_model)
    metric = ConfuseMatrixMeter(n_class=2)
    log_dict = OrderedDict()

    if opt['phase'] == 'train':
        best_mF1 = 0.0
        # FIX: build the LR scheduler ONCE. The original called
        # get_scheduler(optimizer=optimer, args=opt['train']).step() inside the
        # epoch loop, recreating the scheduler each epoch and resetting its
        # state, so the learning rate never actually followed the schedule.
        scheduler = get_scheduler(optimizer=optimer, args=opt['train'])
        for current_epoch in range(0, opt['train']['n_epoch']):
            print("......Training Started......")
            metric.clear()
            cd_model.train()
            train_result_path = '{}/train/{}'.format(opt['path_cd']['result'], current_epoch)
            os.makedirs(train_result_path, exist_ok=True)

            message = 'Current learning rate: %0.7f\n \n' % optimer.param_groups[0]['lr']
            logger.info(message)
            for current_step, train_data in enumerate(train_loader):
                train_im1 = train_data['A'].to(device)
                train_im2 = train_data['B'].to(device)
                pred_img = cd_model(train_im1, train_im2)
                gt = train_data['L'].to(device).long()
                train_loss = loss_fun(pred_img, gt)

                optimer.zero_grad()
                train_loss.backward()
                optimer.step()
                log_dict['loss'] = train_loss.item()

                # Running confusion-matrix update on the hard predictions.
                G_pred = pred_img.detach()
                G_pred = torch.argmax(G_pred, dim=1)
                current_score = metric.update_cm(pr=G_pred.cpu().numpy(), gt=gt.detach().cpu().numpy())
                log_dict['running_acc'] = current_score.item()

                if current_step % opt['train']['train_print_iter'] == 0:
                    logs = log_dict
                    message = '[Training Change Detection]. Epoch: [%d/%d]. Iteration: [%d/%d], Loss: %.5f, Current mF1: %.5f\n' % \
                              (current_epoch, opt['train']['n_epoch'] - 1, current_step, len(train_loader), logs['loss'],
                               logs['running_acc'])
                    logger.info(message)

            # -------------------------------------------- epoch summary ----
            scores = metric.get_scores()
            epoch_acc = scores['mf1']
            log_dict['epoch_acc'] = epoch_acc.item()
            for k, v in scores.items():
                log_dict[k] = v
            logs = log_dict
            message = '[Training Change Detection (Epoch Summary)]: Epoch: [%d/%d]. Current mF1=%.5f \n' % \
                      (current_epoch, opt['train']['n_epoch'] - 1, logs['epoch_acc'])
            for k, v in logs.items():
                message += '{:s}: {:.4e} '.format(k, v)
            message += '\n'
            logger.info(message)

            metric.clear()

            # ---------------------------------------------- validation -----
            cd_model.eval()
            with torch.no_grad():
                if current_epoch % opt['train']['val_freq'] == 0:
                    val_result_path = '{}/val/{}'.format(opt['path_cd']['result'], current_epoch)
                    os.makedirs(val_result_path, exist_ok=True)

                    for current_step, val_data in enumerate(val_loader):
                        val_img1 = val_data['A'].to(device)
                        val_img2 = val_data['B'].to(device)
                        pred_img = cd_model(val_img1, val_img2)
                        gt = val_data['L'].to(device).long()
                        val_loss = loss_fun(pred_img, gt)
                        log_dict['loss'] = val_loss.item()

                        G_pred = pred_img.detach()
                        G_pred = torch.argmax(G_pred, dim=1)
                        current_score = metric.update_cm(pr=G_pred.cpu().numpy(), gt=gt.detach().cpu().numpy())
                        log_dict['running_acc'] = current_score.item()

                        if current_step % opt['train']['val_print_iter'] == 0:
                            logs = log_dict
                            message = '[Validation Change Detection]. Epoch: [%d/%d]. Iteration: [%d/%d], Current mF1: %.5f\n' % \
                                      (current_epoch, opt['train']['n_epoch'] - 1, current_step, len(val_loader), logs['running_acc'])
                            logger.info(message)

                            out_dict = OrderedDict()
                            out_dict['pred_cm'] = torch.argmax(pred_img, dim=1, keepdim=False)
                            out_dict['gt_cm'] = gt
                            visuals = out_dict

                            img_mode = "grid"
                            if img_mode == "single":
                                img_A = Metrics.tensor2img(val_data['A'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                                img_B = Metrics.tensor2img(val_data['B'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                                gt_cm = Metrics.tensor2img(visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                           out_type=np.uint8, min_max=(0, 1))  # uint8
                                pred_cm = Metrics.tensor2img(visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                             out_type=np.uint8, min_max=(0, 1))  # uint8

                                Metrics.save_img(
                                    img_A, '{}/img_A_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                                Metrics.save_img(
                                    img_B, '{}/img_B_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                                Metrics.save_img(
                                    pred_cm, '{}/img_pred_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                                Metrics.save_img(
                                    gt_cm, '{}/img_gt_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))
                            else:
                                # Rescale {0,1} masks to [-1,1] to match the
                                # input images' value range inside the grid.
                                visuals['pred_cm'] = visuals['pred_cm'] * 2.0 - 1.0
                                visuals['gt_cm'] = visuals['gt_cm'] * 2.0 - 1.0
                                grid_img = torch.cat((val_data['A'].to(device),
                                                      val_data['B'].to(device),
                                                      visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                      visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1)),
                                                     dim=0)
                                grid_img = Metrics.tensor2img(grid_img)  # uint8
                                Metrics.save_img(
                                    grid_img, '{}/img_A_B_pred_gt_e{}_b{}.png'.format(val_result_path, current_epoch, current_step))

                    scores = metric.get_scores()
                    epoch_acc = scores['mf1']
                    log_dict['epoch_acc'] = epoch_acc.item()
                    for k, v in scores.items():
                        log_dict[k] = v
                    logs = log_dict
                    message = '[Validation Change Detection (Epoch Summary)]: Epoch: [%d/%d]. Epoch mF1=%.5f \n' % \
                              (current_epoch, opt['train']['n_epoch'], logs['epoch_acc'])
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                    message += '\n'
                    logger.info(message)

                    if logs['epoch_acc'] > best_mF1:
                        is_best_model = True
                        best_mF1 = logs['epoch_acc']
                        logger.info('[Validation CD Phase] Best model updated, saving current best model and training state.')
                    else:
                        is_best_model = False
                        logger.info('[Validation CD Phase] Saving current change detection model and training state.')
                    # FIX: the original only called save_network inside the
                    # "best" branch, so the else-branch's "saving current
                    # model" log line never actually wrote a checkpoint.
                    save_network(opt, current_epoch, cd_model, optimer, is_best_model)
                    logger.info('--- Proceed to next epoch ----\n \n')

                    metric.clear()

            scheduler.step()
        logger.info('Training finished.')

    else:
        # ----------------------------------------------------- test phase --
        logger.info('Begin Model Evaluation (testing).')
        test_result_path = '{}/test/'.format(opt['path_cd']['result'])
        os.makedirs(test_result_path, exist_ok=True)
        logger_test = logging.getLogger('test')

        load_path = opt["path_cd"]["resume_state"]
        print(load_path)
        if load_path is not None:
            logger.info('Loading pretrained model for CD model [{:s}] ...'.format(load_path))
            gen_path = '{}_gen.pth'.format(load_path)
            opt_path = '{}_opt.pth'.format(load_path)

            cd_model = Model.create_CD_model(opt)
            # FIX: map_location so a GPU-trained checkpoint also loads on a
            # CPU-only machine instead of raising a CUDA deserialization error.
            cd_model.load_state_dict(torch.load(gen_path, map_location=device), strict=True)
            cd_model.to(device)
            metric.clear()
            cd_model.eval()
            with torch.no_grad():
                for current_step, test_data in enumerate(test_loader):
                    test_img1 = test_data['A'].to(device)
                    test_img2 = test_data['B'].to(device)
                    pred_img = cd_model(test_img1, test_img2)
                    gt = test_data['L'].to(device).long()

                    G_pred = pred_img.detach()
                    G_pred = torch.argmax(G_pred, dim=1)
                    current_score = metric.update_cm(pr=G_pred.cpu().numpy(), gt=gt.detach().cpu().numpy())
                    log_dict['running_acc'] = current_score.item()

                    logs = log_dict
                    message = '[Testing Change Detection]. Iteration: [%d/%d], running mF1: %.5f\n' % \
                              (current_step, len(test_loader), logs['running_acc'])
                    logger_test.info(message)

                    out_dict = OrderedDict()
                    out_dict['pred_cm'] = torch.argmax(pred_img, dim=1, keepdim=False)
                    out_dict['gt_cm'] = gt
                    visuals = out_dict

                    img_mode = 'single'
                    if img_mode == 'single':
                        # NOTE(review): masks rescaled to [-1,1] here but
                        # rendered with min_max=(0, 1) below, clipping -1 to 0
                        # — looks unintended; confirm before changing.
                        visuals['pred_cm'] = visuals['pred_cm'] * 2.0 - 1.0
                        visuals['gt_cm'] = visuals['gt_cm'] * 2.0 - 1.0
                        img_A = Metrics.tensor2img(test_data['A'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                        img_B = Metrics.tensor2img(test_data['B'], out_type=np.uint8, min_max=(-1, 1))  # uint8
                        gt_cm = Metrics.tensor2img(visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1), out_type=np.uint8,
                                                   min_max=(0, 1))  # uint8
                        pred_cm = Metrics.tensor2img(visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                                     out_type=np.uint8, min_max=(0, 1))  # uint8

                        Metrics.save_img(
                            img_A, '{}/img_A_{}.png'.format(test_result_path, current_step))
                        Metrics.save_img(
                            img_B, '{}/img_B_{}.png'.format(test_result_path, current_step))
                        Metrics.save_img(
                            pred_cm, '{}/img_pred_cm{}.png'.format(test_result_path, current_step))
                        Metrics.save_img(
                            gt_cm, '{}/img_gt_cm{}.png'.format(test_result_path, current_step))
                    else:
                        visuals['pred_cm'] = visuals['pred_cm'] * 2.0 - 1.0
                        visuals['gt_cm'] = visuals['gt_cm'] * 2.0 - 1.0
                        grid_img = torch.cat((test_data['A'],
                                              test_data['B'],
                                              visuals['pred_cm'].unsqueeze(1).repeat(1, 3, 1, 1),
                                              visuals['gt_cm'].unsqueeze(1).repeat(1, 3, 1, 1)),
                                             dim=0)
                        grid_img = Metrics.tensor2img(grid_img)  # uint8
                        Metrics.save_img(
                            grid_img, '{}/img_A_B_pred_gt_{}.png'.format(test_result_path, current_step))

            scores = metric.get_scores()
            epoch_acc = scores['mf1']
            log_dict['epoch_acc'] = epoch_acc.item()
            for k, v in scores.items():
                log_dict[k] = v
            logs = log_dict
            message = '[Test Change Detection Summary]: Test mF1=%.5f \n' % \
                      (logs['epoch_acc'])
            for k, v in logs.items():
                message += '{:s}: {:.4e} '.format(k, v)
            message += '\n'
            logger_test.info(message)
            logger.info('End of testing...')