import math

import numpy as np
from scipy.ndimage import gaussian_filter, grey_dilation, grey_erosion

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    'supervised_training_iter',
    'soc_adaptation_iter',
]


class GaussianBlurLayer(nn.Module):
    """ Add Gaussian blur to a 4D tensor
    This layer takes a 4D tensor of {N, C, H, W} as input.
    The Gaussian blur is applied to each of the C channels separately.
    """

    def __init__(self, channels, kernel_size):
        """
        Arguments:
            channels (int): number of channels of the input tensor
            kernel_size (int): size of the kernel used in blurring
        """

        super(GaussianBlurLayer, self).__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        assert self.kernel_size % 2 != 0, 'kernel_size must be odd'

        # reflection padding keeps the spatial size; groups=channels makes the
        # convolution depth-wise, so every channel is blurred independently
        self.op = nn.Sequential(
            nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
            nn.Conv2d(channels, channels, self.kernel_size,
                      stride=1, padding=0, bias=None, groups=channels)
        )

        self._init_kernel()

    def forward(self, x):
        """
        Arguments:
            x (torch.Tensor): input 4D tensor
        Returns:
            torch.Tensor: Blurred version of the input
        """

        if not len(list(x.shape)) == 4:
            raise ValueError('\'GaussianBlurLayer\' requires a 4D tensor as input')
        elif not x.shape[1] == self.channels:
            raise ValueError('In \'GaussianBlurLayer\', the required channel ({0}) is '
                             'not the same as input ({1})'.format(self.channels, x.shape[1]))

        return self.op(x)

    def _init_kernel(self):
        # common rule of thumb (as used in OpenCV) to derive sigma from the kernel size
        sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8

        # obtain the Gaussian kernel by filtering a unit impulse
        n = np.zeros((self.kernel_size, self.kernel_size))
        i = math.floor(self.kernel_size / 2)
        n[i, i] = 1
        kernel = gaussian_filter(n, sigma)

        # copy the kernel into the weights of the depth-wise convolution
        for name, param in self.named_parameters():
            param.data.copy_(torch.from_numpy(kernel))

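
# A minimal usage sketch for 'GaussianBlurLayer' (hypothetical shapes; assumes
# a CUDA device, matching the module-level 'blurer' below):
#
#   blur = GaussianBlurLayer(channels=1, kernel_size=3).cuda()
#   x = torch.rand(4, 1, 32, 32).cuda()   # {N, C, H, W}
#   y = blur(x)                           # blurred output, same shape as x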


# global blurring layer used to generate the low-resolution semantic targets
# NOTE: created on CUDA at import time, matching the original training setup
blurer = GaussianBlurLayer(1, 3).cuda()
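
# If CUDA may be unavailable, a device-agnostic variant is possible (a sketch;
# the functions below otherwise keep the original CUDA assumption):
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   blurer = GaussianBlurLayer(1, 3).to(device)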


def supervised_training_iter(
    modnet, optimizer, image, trimap, gt_matte,
    semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
    """ Supervised training iteration of MODNet
    This function trains MODNet for one iteration on a labeled dataset.

    Arguments:
        modnet (torch.nn.Module): instance of MODNet
        optimizer (torch.optim.Optimizer): optimizer for supervised training
        image (torch.autograd.Variable): input RGB image
                                         its pixel values should be normalized
        trimap (torch.autograd.Variable): trimap used to calculate the losses
                                          its pixel values can be 0, 0.5, or 1
                                          (foreground=1, background=0, unknown=0.5)
        gt_matte (torch.autograd.Variable): ground truth alpha matte
                                            its pixel values are between [0, 1]
        semantic_scale (float): scale of the semantic loss
                                NOTE: please adjust according to your dataset
        detail_scale (float): scale of the detail loss
                              NOTE: please adjust according to your dataset
        matte_scale (float): scale of the matte loss
                             NOTE: please adjust according to your dataset

    Returns:
        semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
        detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
        matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]

    Example:
        import torch
        from src.models.modnet import MODNet
        from src.trainer import supervised_training_iter

        bs = 16         # batch size
        lr = 0.01       # learning rate
        epochs = 40     # total epochs

        modnet = torch.nn.DataParallel(MODNet()).cuda()
        optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function

        for epoch in range(0, epochs):
            for idx, (image, trimap, gt_matte) in enumerate(dataloader):
                semantic_loss, detail_loss, matte_loss = \
                    supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
            lr_scheduler.step()
    """

    global blurer

    # set the model to train mode and clear the optimizer
    modnet.train()
    optimizer.zero_grad()

    # forward the model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)

    # calculate the boundary mask from the trimap
    # (True where the trimap is known foreground/background, False in the
    # unknown transition region)
    boundaries = (trimap < 0.5) | (trimap > 0.5)

    # calculate the semantic loss: the low-resolution branch should match a
    # 16x downsampled and blurred version of the ground truth matte
    gt_semantic = F.interpolate(gt_matte, scale_factor=1/16, mode='bilinear')
    gt_semantic = blurer(gt_semantic)
    semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
    semantic_loss = semantic_scale * semantic_loss

    # calculate the detail loss: known pixels are replaced by the trimap on
    # both sides, so the L1 term only penalizes the unknown region
    pred_boundary_detail = torch.where(boundaries, trimap, pred_detail)
    gt_detail = torch.where(boundaries, trimap, gt_matte)
    detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail))
    detail_loss = detail_scale * detail_loss

    # calculate the matte loss: L1 + compositional loss, with an extra 4x
    # weight on the unknown region
    pred_boundary_matte = torch.where(boundaries, trimap, pred_matte)
    matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
    matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
        + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
    matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
    matte_loss = matte_scale * matte_loss

    # calculate the final loss, backward the loss, and update the model
    loss = semantic_loss + detail_loss + matte_loss
    loss.backward()
    optimizer.step()

    return semantic_loss, detail_loss, matte_loss

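
# A minimal sketch of the 'CREATE_YOUR_DATALOADER' placeholder used in the
# docstring examples (hypothetical; the dataset class is left to the user):
#
#   from torch.utils.data import DataLoader
#
#   def CREATE_YOUR_DATALOADER(bs):
#       dataset = YOUR_MATTING_DATASET()   # should yield (image, trimap, gt_matte)
#       return DataLoader(dataset, batch_size=bs, shuffle=True)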


def soc_adaptation_iter(
    modnet, backup_modnet, optimizer, image,
    soc_semantic_scale=100.0, soc_detail_scale=1.0):
    """ Self-supervised sub-objective consistency (SOC) adaptation iteration of MODNet
    This function fine-tunes MODNet for one iteration on an unlabeled dataset.
    Note that SOC can only fine-tune a converged MODNet, i.e., a MODNet that has
    already been trained on a labeled dataset.

    Arguments:
        modnet (torch.nn.Module): instance of MODNet
        backup_modnet (torch.nn.Module): backup of the trained MODNet
        optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC
        image (torch.autograd.Variable): input RGB image
                                         its pixel values should be normalized
        soc_semantic_scale (float): scale of the SOC semantic loss
                                    NOTE: please adjust according to your dataset
        soc_detail_scale (float): scale of the SOC detail loss
                                  NOTE: please adjust according to your dataset

    Returns:
        soc_semantic_loss (torch.Tensor): loss of the semantic SOC
        soc_detail_loss (torch.Tensor): loss of the detail SOC

    Example:
        import copy
        import torch
        from src.models.modnet import MODNet
        from src.trainer import soc_adaptation_iter

        bs = 1          # batch size
        lr = 0.00001    # learning rate
        epochs = 10     # total epochs

        modnet = torch.nn.DataParallel(MODNet()).cuda()
        modnet = LOAD_TRAINED_CKPT()    # NOTE: please finish this function

        optimizer = torch.optim.Adam(modnet.parameters(), lr=lr, betas=(0.9, 0.99))
        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function

        for epoch in range(0, epochs):
            backup_modnet = copy.deepcopy(modnet)
            for idx, image in enumerate(dataloader):
                soc_semantic_loss, soc_detail_loss = \
                    soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
    """

    global blurer

    # set the backup model to eval mode
    backup_modnet.eval()

    # set the main model to train mode and freeze its norm layers
    modnet.train()
    modnet.module.freeze_norm()

    # clear the optimizer
    optimizer.zero_grad()

    # forward the main model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)

    # forward the backup model
    with torch.no_grad():
        _, pred_backup_detail, pred_backup_matte = backup_modnet(image, False)

    # calculate the foreground mask from the predicted matte and semantics
    pred_matte_fg = (pred_matte.detach() > 0.1).float()
    pred_semantic_fg = (pred_semantic.detach() > 0.1).float()
    pred_semantic_fg = F.interpolate(pred_semantic_fg, scale_factor=16, mode='bilinear')
    pred_fg = pred_matte_fg * pred_semantic_fg

    # extract the boundary band of the foreground mask through morphological
    # dilation/erosion (the band width is ~5% of the mean image side)
    n, c, h, w = pred_matte.shape
    np_pred_fg = pred_fg.data.cpu().numpy()
    np_boundaries = np.zeros([n, c, h, w])
    for sdx in range(0, n):
        sample_np_boundaries = np_boundaries[sdx, 0, ...]
        sample_np_pred_fg = np_pred_fg[sdx, 0, ...]

        side = int((h + w) / 2 * 0.05)
        dilated = grey_dilation(sample_np_pred_fg, size=(side, side))
        eroded = grey_erosion(sample_np_pred_fg, size=(side, side))

        sample_np_boundaries[np.where(dilated - eroded != 0)] = 1
        np_boundaries[sdx, 0, ...] = sample_np_boundaries

    boundaries = torch.tensor(np_boundaries).float().cuda()

    # sub-objective consistency between `pred_semantic` and `pred_matte`:
    # generate the pseudo ground truth for `pred_semantic`
    downsampled_pred_matte = blurer(F.interpolate(pred_matte, scale_factor=1/16, mode='bilinear'))
    pseudo_gt_semantic = downsampled_pred_matte.detach()
    pseudo_gt_semantic = pseudo_gt_semantic * (pseudo_gt_semantic > 0.01).float()

    # generate the pseudo ground truth for `pred_matte`
    pseudo_gt_matte = pred_semantic.detach()
    pseudo_gt_matte = pseudo_gt_matte * (pseudo_gt_matte > 0.01).float()

    # calculate the SOC semantic loss
    soc_semantic_loss = F.mse_loss(pred_semantic, pseudo_gt_semantic) + F.mse_loss(downsampled_pred_matte, pseudo_gt_matte)
    soc_semantic_loss = soc_semantic_scale * torch.mean(soc_semantic_loss)

    # consistency between `pred_detail` and `pred_backup_detail`,
    # restricted to the boundary band
    backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail, reduction='none')
    backup_detail_loss = torch.sum(backup_detail_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
    backup_detail_loss = torch.mean(backup_detail_loss)

    # consistency between `pred_matte` and `pred_backup_matte`,
    # restricted to the boundary band
    backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte, reduction='none')
    backup_matte_loss = torch.sum(backup_matte_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
    backup_matte_loss = torch.mean(backup_matte_loss)

    # calculate the SOC detail loss
    soc_detail_loss = soc_detail_scale * (backup_detail_loss + backup_matte_loss)

    # calculate the final loss, backward the loss, and update the model
    loss = soc_semantic_loss + soc_detail_loss
    loss.backward()
    optimizer.step()

    return soc_semantic_loss, soc_detail_loss
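

if __name__ == '__main__':
    # Smoke test for GaussianBlurLayer (a sketch; assumes a CUDA device, since
    # the module-level `blurer` is created on CUDA at import time).
    x = torch.rand(2, 1, 64, 64).cuda()
    y = blurer(x)
    assert y.shape == x.shape
    print('GaussianBlurLayer OK:', tuple(y.shape))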