| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Created in September 2022 |
| @author: fabrizio.guillaro |
| """ |
|
|
| import logging |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
|
|
| import torch |
| import torch.nn as nn |
|
|
| import matplotlib |
| matplotlib.use('agg') |
| import matplotlib.pyplot as plt |
| import itertools |
|
|
|
|
|
|
def adjust_learning_rate(optimizer, base_lr, max_iters, cur_iters, power=0.9):
    """Apply polynomial learning-rate decay to every parameter group.

    The rate is ``base_lr * (1 - cur_iters / max_iters) ** power``.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten.
        base_lr: learning rate at iteration 0.
        max_iters: iteration count at which the rate reaches zero.
        cur_iters: current iteration.
        power: exponent of the decay polynomial.

    Returns:
        The learning rate that was written into every parameter group.

    Raises:
        ZeroDivisionError: if ``max_iters`` is 0.
    """
    # Clamp at zero: once cur_iters exceeds max_iters the base would become
    # negative, and a negative float raised to a fractional power is a
    # complex number in Python 3, which is not a usable learning rate.
    remaining = max(0.0, 1.0 - float(cur_iters) / max_iters)
    lr = base_lr * (remaining ** power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
|
|
|
|
class FullModel(nn.Module):
    """
    Wrap a network together with its loss criteria so that the weighted loss
    is computed inside ``forward``.

    Original note: distribute the loss on multi-gpu to reduce the memory cost
    in the main gpu.
    """
    def __init__(self, model, config=None):
        """
        Args:
            model: network called as ``model(rgbs)``; it must return a
                4-tuple ``(outputs, conf, det, npp)``.
            config: configuration object providing ``MODEL.NAME`` and
                ``LOSS.LOSSES`` plus everything ``get_criterion`` reads.
                Required despite the ``None`` default.

        Raises:
            ValueError: if ``config`` is None.
        """
        super().__init__()
        if config is None:
            raise ValueError('FullModel requires a config object')
        self.model = model
        self.model_name = config.MODEL.NAME
        self.cfg = config
        # Iterable of (loss name, weight, criterion name) triples.
        self.losses = config.LOSS.LOSSES
        self.loss_loc, self.loss_conf, self.loss_det = get_criterion(config)

    def forward(self, labels=None, rgbs=None):
        """
        Run the wrapped model and accumulate the weighted sum of the
        configured losses.

        Returns:
            Tuple ``(final_loss, outputs, conf, det)``. ``final_loss`` is the
            integer 0 when no loss is configured.

        Raises:
            ValueError: if a configured loss name is not 'LOC', 'CONF' or
                'DET'.
        """
        # The fourth model output is not used by any loss.
        outputs, conf, det, _ = self.model(rgbs)
        final_loss = 0
        for (name, weight, _) in self.losses:
            if name == 'LOC':
                loss = self.loss_loc(outputs, labels)
            elif name == 'CONF':
                loss = self.loss_conf(outputs, labels, conf)
            elif name == 'DET':
                loss = self.loss_det(det, labels)
            else:
                # Without this branch an unknown name would silently reuse
                # the loss value left over from the previous iteration.
                raise ValueError('Unknown loss name: {!r}'.format(name))

            # Give the loss a leading dimension of size 1
            # (presumably so per-GPU losses can be gathered -- TODO confirm).
            loss = torch.unsqueeze(loss, 0)
            final_loss += weight * loss

        return final_loss, outputs, conf, det
|
|
|
|
|
|
|
|
|
|
def get_model(config):
    """
    Build the network selected by ``config.MODEL.NAME``.

    Only 'detconfcmx' is supported; any other name raises
    NotImplementedError. The model is constructed with ``cfg=config``.
    """
    if config.MODEL.NAME != 'detconfcmx':
        raise NotImplementedError("Model not implemented")

    # Imported lazily, only once this architecture is actually requested.
    from lib.models.cmx.builder_np_conf import EncoderDecoder as detconfcmx
    return detconfcmx(cfg=config)
|
|
|
|
def get_criterion(config):
    """
    Build the loss criteria listed in ``config.LOSS.LOSSES``.

    Each entry is a ``(name, weight, criterion)`` triple where ``name`` is
    'LOC', 'CONF' or 'DET' and ``criterion`` selects the implementation:

    * 'LOC':  'dice', 'binary_dice', 'cross_entropy' or 'dice_entropy'
    * 'CONF': 'mse'
    * 'DET':  'cross_entropy' (only built when
      ``config.MODEL.EXTRA.DETECTION`` is neither None nor 'none')

    Every criterion that is built is moved to the GPU with ``.cuda()``.

    Returns:
        Tuple ``(criterion_loc, criterion_conf, criterion_det)``; an element
        is None when the corresponding loss is not configured.

    Raises:
        ValueError: on an unknown loss name or an unsupported criterion.
    """
    ignore_label = config.TRAIN.IGNORE_LABEL
    smooth = config.LOSS.SMOOTH
    weight = torch.FloatTensor(config.DATASET.CLASS_WEIGHTS)

    losses = config.LOSS.LOSSES
    detection = config.MODEL.EXTRA.DETECTION

    criterion_loc, criterion_conf, criterion_det = None, None, None

    for (name, _, criterion) in losses:
        # Localization loss
        if name == 'LOC':
            if criterion == 'dice':
                from lib.core.criterion import DiceLoss
                criterion_loc = DiceLoss(ignore_label=ignore_label, smooth=smooth).cuda()
            elif criterion == 'binary_dice':
                from lib.core.criterion import BinaryDiceLoss
                criterion_loc = BinaryDiceLoss(ignore_label=ignore_label, smooth=smooth).cuda()
            elif criterion == 'cross_entropy':
                from lib.core.criterion import CrossEntropy
                criterion_loc = CrossEntropy(ignore_label=ignore_label, weight=weight).cuda()
            elif criterion == 'dice_entropy':
                from lib.core.criterion import DiceEntropyLoss
                criterion_loc = DiceEntropyLoss(ignore_label=ignore_label, weight=weight, smooth=smooth).cuda()
            else:
                raise ValueError('Localization loss not implemented')

        # Confidence loss
        elif name == 'CONF':
            if criterion == 'mse':
                from lib.core.criterion_conf import MSE
                criterion_conf = MSE().cuda()
            else:
                raise ValueError('Confidence loss not implemented')

        # Detection loss; left as None when detection is disabled
        elif name == 'DET':
            if detection is not None and not detection == 'none':
                if criterion == 'cross_entropy':
                    from lib.core.criterion_det import CrossEntropy
                    criterion_det = CrossEntropy().cuda()
                else:
                    raise ValueError('Detection loss not implemented')

        else:
            # An explicit exception instead of an assert, so the check is
            # still performed when Python runs with -O.
            raise ValueError(
                "Unknown loss name {!r}; expected 'LOC', 'CONF' or 'DET'".format(name))

    return criterion_loc, criterion_conf, criterion_det
|
|
|
|
|
|
def get_optimizer(model, config):
    """
    Create the optimizer selected by ``config.TRAIN.OPTIMIZER``.

    For model names containing 'cmx' the parameter groups come from
    ``group_weight``; otherwise a single group holding every parameter with
    ``requires_grad`` set is used. Supported optimizers are 'sgd' and
    'adam'; any other value raises ValueError.
    """
    if 'cmx' not in config.MODEL.NAME:
        trainable = [p for p in model.parameters() if p.requires_grad]
        params_list = [{'params': trainable, 'lr': config.TRAIN.LR}]
    else:
        from lib.models.cmx.init_func import group_weight
        params_list = group_weight([], model, nn.BatchNorm2d, config.TRAIN.LR)

    optimizer_name = config.TRAIN.OPTIMIZER

    if optimizer_name == 'sgd':
        return torch.optim.SGD(
            params_list,
            lr=config.TRAIN.LR,
            momentum=config.TRAIN.MOMENTUM,
            weight_decay=config.TRAIN.WD,
            nesterov=config.TRAIN.NESTEROV,
        )

    if optimizer_name == 'adam':
        return torch.optim.Adam(
            params_list,
            lr=config.TRAIN.LR,
            betas=(0.9, 0.999),
            weight_decay=config.TRAIN.WD,
        )

    raise ValueError('Optimizer not implemented')
|
|
|
|
|
|
|
|
class AverageMeter(object):
    """Keep the most recent value and the weighted running mean of a series."""

    def __init__(self):
        # All statistics stay None until the first update() call.
        self.initialized = False
        self.val = self.avg = self.sum = self.count = None

    def initialize(self, val, weight):
        """Seed the meter with its first observation."""
        self.initialized = True
        self.count = weight
        self.sum = val * weight
        self.val = val
        self.avg = val

    def update(self, val, weight=1):
        """Record ``val`` with the given weight, seeding the meter on first use."""
        if self.initialized:
            self.add(val, weight)
        else:
            self.initialize(val, weight)

    def add(self, val, weight):
        """Fold one more weighted observation into the running statistics."""
        self.sum += val * weight
        self.count += weight
        self.val = val
        self.avg = self.sum / self.count

    def value(self):
        """Return the most recently recorded value."""
        return self.val

    def average(self):
        """Return the weighted mean of everything recorded so far."""
        return self.avg
|
|
|
|
|
|
def create_logger(cfg, cfg_name, phase='train'):
    """
    Create the output directory for a run and configure the root logger.

    The log file is written to ``<cfg.OUTPUT_DIR>/<cfg_name>/`` and is named
    from ``cfg_name`` (with '/' replaced by '_'), a minute-resolution
    timestamp and ``phase``. Messages are also echoed to the console.

    Args:
        cfg: configuration object providing ``OUTPUT_DIR``, ``LOG_DIR`` and
            ``MODEL.NAME``.
        cfg_name: name of the run; may contain '/' to nest directories.
        phase: suffix used in the log file name.

    Returns:
        Tuple ``(logger, final_output_dir, tensorboard_log_dir)`` where the
        two directories are strings. The tensorboard directory is only
        computed here, not created.
    """
    root_output_dir = Path(cfg.OUTPUT_DIR)

    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        # parents=True: OUTPUT_DIR may be a nested path that does not exist.
        # exist_ok=True: another process may create it after the check above.
        root_output_dir.mkdir(parents=True, exist_ok=True)

    model = cfg.MODEL.NAME
    final_output_dir = root_output_dir / cfg_name

    print('=> creating {}'.format(final_output_dir))
    final_output_dir.mkdir(parents=True, exist_ok=True)

    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}_{}.log'.format(cfg_name.replace('/', '_'), time_str, phase)
    final_log_file = final_output_dir / log_file
    head = '%(asctime)-15s %(message)s'
    # basicConfig does nothing when the root logger already has handlers.
    logging.basicConfig(filename=str(final_log_file), format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Add the console handler only once, so repeated calls do not print every
    # message several times. The exact type is compared on purpose:
    # FileHandler is a subclass of StreamHandler and must not count.
    has_console = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_console:
        logger.addHandler(logging.StreamHandler())

    tensorboard_log_dir = Path(cfg.LOG_DIR) / model / (cfg_name + '_' + time_str)
    return logger, str(final_output_dir), str(tensorboard_log_dir)
|
|
|
|
|
|
def get_confusion_matrix(label, pred, size, num_class, ignore=-1):
    """
    Calculate the confusion matrix from a label map and per-class scores.

    Args:
        label: tensor of class indices, indexed as (batch, row, column).
        pred: 4-D tensor of scores with the class axis at position 1; the
            predicted class is the argmax over that axis.
        size: sequence whose last two entries give the number of rows and
            columns of ``label`` to keep.
        num_class: number of classes. Predictions are stored as uint8, so
            it must not exceed 256.
        ignore: label value to exclude from the count.

    Returns:
        float64 array of shape (num_class, num_class); entry [i, j] counts
        the elements with label i that were predicted as j.
    """
    output = pred.cpu().numpy().transpose(0, 2, 3, 1)
    seg_pred = np.asarray(np.argmax(output, axis=3), dtype=np.uint8)
    # The np.int alias was removed in NumPy 1.24; the builtin int selects
    # the same dtype.
    seg_gt = np.asarray(
        label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=int)

    valid = seg_gt != ignore
    seg_gt = seg_gt[valid]
    seg_pred = seg_pred[valid]

    # Encode each (label, prediction) pair as one flat bin index.
    index = (seg_gt * num_class + seg_pred).astype('int32')
    num_cells = num_class * num_class
    # minlength pads missing bins with zeros; the slice drops any bins
    # produced by labels outside [0, num_class).
    label_count = np.bincount(index, minlength=num_cells)[:num_cells]
    confusion_matrix = label_count.reshape(num_class, num_class).astype(np.float64)
    return confusion_matrix
|
|
|
|
|
|
| |
def get_confusion_matrix_1ch(label, confid, size, num_class, ignore=-1):
    """
    Calculate the confusion matrix from a label map and a single-channel
    score map thresholded at zero.

    Args:
        label: tensor of class indices, indexed as (batch, row, column).
        confid: tensor of scores whose axis 1 is squeezed away; an element
            is predicted as class 1 when its score is > 0, else class 0.
        size: sequence whose last two entries give the number of rows and
            columns of ``label`` to keep.
        num_class: number of classes (rows/columns of the result).
        ignore: label value to exclude from the count.

    Returns:
        float64 array of shape (num_class, num_class); entry [i, j] counts
        the elements with label i that were predicted as j.
    """
    output = confid.squeeze(dim=1).cpu().numpy()

    seg_pred = np.asarray(output > 0, dtype=np.uint8)
    # The np.int alias was removed in NumPy 1.24; the builtin int selects
    # the same dtype.
    seg_gt = np.asarray(
        label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=int)

    valid = seg_gt != ignore
    seg_gt = seg_gt[valid]
    seg_pred = seg_pred[valid]

    # Encode each (label, prediction) pair as one flat bin index.
    index = (seg_gt * num_class + seg_pred).astype('int32')
    num_cells = num_class * num_class
    # minlength pads missing bins with zeros; the slice drops any bins
    # produced by labels outside [0, num_class).
    label_count = np.bincount(index, minlength=num_cells)[:num_cells]
    confusion_matrix = label_count.reshape(num_class, num_class).astype(np.float64)
    return confusion_matrix
|
|
|
|
def plot_confusion_matrix(confusion_matrix):
    """
    Render a confusion matrix as an RGB image.

    Rows are labelled 'True Label' and columns 'Predicted'. Each cell is
    annotated with its value in scientific notation, or '.' when it is zero.

    Args:
        confusion_matrix: 2-D array-like of counts.

    Returns:
        uint8 array of shape (3, height, width) holding the rendered figure.
    """
    confusion_matrix = np.asarray(confusion_matrix)
    n_rows, n_cols = confusion_matrix.shape

    fig = plt.figure(figsize=(3, 3), dpi=200, facecolor='w', edgecolor='k')
    try:
        ax = fig.add_subplot(1, 1, 1)
        im = ax.imshow(confusion_matrix, cmap='bwr')

        ax.set_xlabel('Predicted', fontsize=10)
        ax.set_xticks(list(range(n_cols)))
        ax.xaxis.set_label_position('bottom')
        ax.xaxis.tick_bottom()

        ax.set_ylabel('True Label', fontsize=10)
        ax.set_yticks(list(range(n_rows)))
        ax.yaxis.set_label_position('left')
        ax.yaxis.tick_left()

        for i, j in itertools.product(range(n_rows), range(n_cols)):
            value = confusion_matrix[i, j]
            text = format(value, '.3e') if value != 0 else '.'
            ax.text(j, i, text, horizontalalignment="center", fontsize=10,
                    verticalalignment='center', color="black")

        fig.set_tight_layout(True)
        fig.colorbar(im, fraction=0.046, pad=0.04)

        fig.canvas.draw()
        # tostring_rgb() was deprecated in Matplotlib 3.8 and removed in
        # 3.10; buffer_rgba() exposes the same pixels as (height, width, 4).
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha, move channels first, and copy so the result does not
        # keep referencing the canvas buffer of the closed figure.
        cm = rgba[:, :, :3].transpose(2, 0, 1).copy()
    finally:
        # Close the figure even if rendering fails, so it is not leaked.
        plt.close(fig)
    return cm
|
|