| | import kornia.filters |
| | import kornia.filters |
| | import scipy.ndimage |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import random |
| |
|
| |
|
| |
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d``.

    Padding is set equal to ``dilation`` so that, with ``stride=1``, the
    output keeps the input's spatial size.
    """
    layer_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **layer_kwargs)
| |
|
| |
|
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free pointwise (1x1) ``nn.Conv2d`` with the given stride."""
    layer = nn.Conv2d(in_planes, out_planes, 1, stride, bias=False)
    return layer
| |
|
| |
|
class DoubleConv(nn.Module):
    """Residual block of two 3x3 convolutions normalised with InstanceNorm.

    Main path: conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm.
    The input (projected by a 1x1 conv + BatchNorm when the channel counts
    differ) is added to the main path, and a final ReLU is applied.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        mid_channels: channels between the two 3x3 convs; any falsy value
            (``None`` or ``0``) means ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels

        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
        # bn1 / bn2 are never used in forward (the instance norms are). They still
        # register parameters and buffers, so removing them would change the
        # state_dict keys — NOTE(review): presumably kept for checkpoint compatibility.
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.inst1 = nn.InstanceNorm2d(mid_channels)

        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.inst2 = nn.InstanceNorm2d(out_channels)

        # Project the skip connection only when it cannot be added as-is.
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Apply the residual block; the spatial size is unchanged (3x3 convs, padding 1)."""
        identity = x

        out = self.conv1(x)
        out = self.inst1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.inst2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # `out` is a fresh tensor here, so the in-place add does not touch `x`.
        out += identity
        out = self.relu(out)
        return out
| |
|
| |
|
class Down(nn.Module):
    """Halve the spatial resolution with 2x2 max-pooling, then apply ``DoubleConv``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the ``DoubleConv`` stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        block = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, block)

    def forward(self, x):
        """Pool, then convolve ``x``."""
        return self.maxpool_conv(x)
| |
|
| |
|
class Up(nn.Module):
    """Upsample ``x1``, align it to the skip tensor ``x2``, concatenate, then apply ``DoubleConv``.

    Args:
        in_channels: channel count that ``DoubleConv`` receives after the
            concatenation. When a transposed conv is used it is also the
            channel count of ``x1``.
        out_channels: channel count produced by ``DoubleConv``.
        bilinear: if True, upsample 2x with fixed bilinear interpolation.
            Otherwise use a learnable 2x transposed convolution that halves
            the channels, or no upsampling at all (``nn.Identity``) when
            ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if not bilinear:
            self.up = (
                nn.Identity()
                if in_channels == out_channels
                else nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            )
            self.conv = DoubleConv(in_channels, out_channels)
            return

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Fuse the upsampled ``x1`` with skip features ``x2`` (concatenated as ``[x2, x1]``)."""
        upsampled = self.up(x1)
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        # F.pad lists the last dim first: (left, right, top, bottom). A negative
        # difference crops instead of padding, so the output always matches x2.
        upsampled = F.pad(
            upsampled,
            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
        )
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)
| |
|
| |
|
class OutConv(nn.Module):
    """Pointwise (1x1) convolution with bias, mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` channel-wise; the spatial size is unchanged."""
        projected = self.conv(x)
        return projected
| |
|
class GaussianLayer(nn.Module):
    """Single-channel 2-D Gaussian blur implemented as a trainable ``nn.Conv2d``.

    The convolution weight is initialised by ``scipy.ndimage.gaussian_filter``
    applied to a unit impulse at the exact centre of the kernel window. At
    initialisation the layer therefore blurs without shifting the image. The
    weight remains a regular (learnable) parameter afterwards.

    Args:
        kernel_size: odd side length of the square kernel (default 5).
        sigma: standard deviation of the Gaussian, in pixels (default 1.0).

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=5, sigma=1.0):
        super(GaussianLayer, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        self.sigma = sigma
        # padding = kernel_size // 2 keeps the output the same size as the input.
        self.seq = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size, stride=1, padding=kernel_size // 2, bias=False)
        )
        self.weights_init()

    def forward(self, x):
        """Blur ``x``; the conv has ``in_channels=1``, so ``x`` must be single-channel."""
        return self.seq(x)

    def weights_init(self):
        """(Re)set the conv weight to a Gaussian kernel centred in the window."""
        weight = self.seq[0].weight
        size = weight.shape[-1]
        impulse = np.zeros((size, size))
        # The centre of an odd window is size // 2 (index 2 for 5x5). An
        # off-centre impulse would shift the blurred output by that offset.
        centre = size // 2
        impulse[centre, centre] = 1.0
        kernel = scipy.ndimage.gaussian_filter(impulse, sigma=self.sigma)
        with torch.no_grad():
            # (size, size) float64 broadcasts/casts into the (1, 1, size, size) weight.
            weight.copy_(torch.from_numpy(kernel))
| |
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a 3x3 ``nn.Conv2d`` without bias, padded by ``dilation``.

    NOTE(review): this re-defines a ``conv3x3`` with the same behaviour that
    appears earlier in the file. Being the later binding, it is the one that
    code running after import (e.g. ``Decoder.__init__``) resolves to.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=(3, 3),
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
| |
|
class Decoder(nn.Module):
    """Holds the decoder layers that ``Resnet_with_skip.forward_decode`` calls one by one.

    The class defines no ``forward``; the owning model wires the layers
    together. ``instance1``, ``upsample4_conv``, ``gn2`` and ``instance3``
    are not referenced by ``Resnet_with_skip.forward_decode`` in this file.
    NOTE(review): presumably they are kept so existing checkpoints still load;
    confirm before removing.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.up1 = Up(2048, 1024, False)
        self.up2 = Up(1024, 512, False)
        self.up3 = Up(512, 256, False)
        self.conv2d_2_1 = conv3x3(256, 128)
        self.gn1 = nn.GroupNorm(4, 128)
        self.instance1 = nn.InstanceNorm2d(128)
        self.up4 = Up(128, 64, False)
        self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upsample4_conv = DoubleConv(64, 64, 64 // 2)
        # in_channels == out_channels, so this Up uses nn.Identity and does no
        # resizing; upsample4 above performs the 2x resize instead.
        self.up_ = Up(128, 128, False)
        self.conv2d_2_2 = conv3x3(128, 6)
        self.instance2 = nn.InstanceNorm2d(6)
        self.gn2 = nn.GroupNorm(3, 6)
        # Constructed once. A second, identical construction would only be
        # thrown away; this keeps the module at its original position in the
        # state_dict / parameters() ordering.
        self.gaussian_blur = GaussianLayer()
        self.up5 = Up(6, 3, False)
        self.conv2d_2_3 = conv3x3(3, 1)
        self.instance3 = nn.InstanceNorm2d(1)

        # Eight learnable 3x3 directional kernels for non-maximum suppression.
        # Each has a centre tap of 1 and one neighbour tap drawn from U(-1, 0).
        # Neighbour (row, col) per output channel: right, down-right, down,
        # down-left, left, up-left, up, up-right. Resnet_with_skip.forward_decode
        # picks channel k for a quantised gradient angle of k * 45 degrees.
        nms_offsets = ((1, 2), (2, 2), (2, 1), (2, 0), (1, 0), (0, 0), (0, 1), (0, 2))
        kernel = torch.zeros(len(nms_offsets), 1, 3, 3, dtype=torch.float32)
        kernel[:, 0, 1, 1] = 1.0
        for channel, (row, col) in enumerate(nms_offsets):
            kernel[channel, 0, row, col] = random.uniform(-1.0, 0.0)
        self.kernel = nn.Parameter(kernel)

        self.nms_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False, groups=1)
        # Share the kernel Parameter with nms_conv, so the conv emits 8 channels
        # (one per direction) and a single set of weights is trained. The kernel
        # is created as float32 so it can be assigned as-is. A dtype conversion
        # here would produce a plain Tensor, which nn.Module refuses to assign
        # to a parameter slot.
        self.nms_conv.weight = self.kernel
| |
|
| |
|
class Resnet_with_skip(nn.Module):
    """Wrap a ResNet backbone with a skip-connected ``Decoder``.

    ``forward`` returns the backbone's own prediction together with the
    decoder's single-channel output map.

    NOTE(review): ``model`` is assumed to expose the torchvision-ResNet
    attributes used below (``conv1``, ``bn1``, ``relu``, ``maxpool``,
    ``layer1``..``layer4``) and to produce 2048 channels in ``layer4``, as
    ``Decoder.up1`` (``Up(2048, ...)``) requires — confirm against the caller.
    """

    def __init__(self, model):
        super(Resnet_with_skip, self).__init__()
        self.model = model
        self.decoder = Decoder()

    def forward_pred(self, image):
        """Return the wrapped model's own output for ``image``."""
        return self.model(image)

    def _encode(self, image):
        """Run the backbone stem and residual stages, returning every skip tensor.

        Returns ``(stem, pooled, feat1, feat2, feat3, feat4)``. ``stem`` is the
        post-ReLU activation before max-pooling, ``pooled`` is the max-pooled
        stem, and ``feat1``..``feat4`` are the outputs of ``layer1``..``layer4``.
        """
        stem = self.model.relu(self.model.bn1(self.model.conv1(image)))
        pooled = self.model.maxpool(stem)
        feat1 = self.model.layer1(pooled)
        feat2 = self.model.layer2(feat1)
        feat3 = self.model.layer3(feat2)
        feat4 = self.model.layer4(feat3)
        return stem, pooled, feat1, feat2, feat3, feat4

    def _decode(self, image, stem, pooled, feat1, feat2, feat3, feat4):
        """Upsample ``feat4`` back through the skip tensors to a 1-channel map the size of ``image``."""
        dec = self.decoder
        x = dec.up1(feat4, feat3)
        x = dec.up2(x, feat2)
        x = dec.up3(x, feat1)
        x = F.relu(dec.gn1(dec.conv2d_2_1(x)))
        x = dec.up4(x, pooled)
        x = dec.upsample4(x)
        x = dec.up_(x, stem)
        x = F.relu(dec.instance2(dec.conv2d_2_2(x)))
        # up5 concatenates the raw input image as the final skip connection.
        x = dec.up5(x, image)
        return F.relu(dec.conv2d_2_3(x))

    def _edge_nms(self, reconst):
        """Weight ``reconst`` by a soft non-maximum-suppression response, then gamma-correct it.

        The response is computed along the quantised gradient direction of the
        blurred map.
        """
        import math

        blurred = self.decoder.gaussian_blur(reconst)

        # Per kornia's docs the result is (B, C, 2, H, W): index 0 is d/dx, 1 is d/dy.
        gradients = kornia.filters.spatial_gradient(blurred, normalized=False)
        gx = gradients[:, :, 0]
        gy = gradients[:, :, 1]

        # Gradient direction in degrees, rounded to the nearest multiple of 45.
        angle = torch.atan2(gy, gx)
        angle = 180.0 * angle / math.pi
        angle = torch.round(angle / 45) * 45

        # One response channel per direction: nms_conv's weight is the shared
        # (8, 1, 3, 3) decoder.kernel (centre tap plus one neighbour tap).
        nms_magnitude = self.decoder.nms_conv(blurred)

        # Channel for the neighbour along the gradient, and for the opposite
        # neighbour (+180 degrees = +4 bins). `%` maps negative bins into [0, 8).
        positive_idx = ((angle / 45) % 8).long()
        negative_idx = (((angle / 45) + 4) % 8).long()
        response_positive = torch.gather(nms_magnitude, 1, positive_idx)
        response_negative = torch.gather(nms_magnitude, 1, negative_idx)
        responses = torch.stack([response_positive, response_negative], 1)

        # Smaller of the two directional responses, floored at 0.01
        # (F.threshold replaces every value <= 0.01 with 0.01).
        nms_weight = responses.min(dim=1)[0]
        nms_weight = F.threshold(nms_weight, 0.01, 0.01)
        magnitude = torch.mul(reconst, nms_weight)

        return kornia.enhance.adjust_gamma(magnitude, 2.0)

    def forward_decode(self, image):
        """Encode ``image``, decode it to a 1-channel map, and apply the soft NMS weighting.

        NOTE(review): ``image`` presumably has 3 channels — ``Decoder.up5`` is
        ``Up(6, 3, False)`` and concatenates ``image`` with a 3-channel tensor.
        """
        stem, pooled, feat1, feat2, feat3, feat4 = self._encode(image)
        reconst = self._decode(image, stem, pooled, feat1, feat2, feat3, feat4)
        return self._edge_nms(reconst)

    def forward(self, image):
        """Return ``(pred, reconst)``: the backbone prediction and the decoder output for ``image``."""
        reconst = self.forward_decode(image)
        pred = self.forward_pred(image)
        return pred, reconst
| |
|