import torch.nn as nn

"""
This code refers to "Pyramid attention network for semantic segmentation":
https://github.com/JaveyWang/Pyramid-Attention-Networks-pytorch/blob/f719365c1780f062058dd0c94550c6c4766cd937/networks.py#L41
"""


class FPM(nn.Module):
    def __init__(self, channels=1024):
        """
        Feature Pyramid Attention

        :param channels: number of input (and output) channels
        :type channels: int
        """
        super(FPM, self).__init__()
        channels_mid = channels // 4

        self.channels_cond = channels

        # Master branch: 1x1 convolution that keeps the channel count.
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Downsampling path of the pyramid: strided 7x7 -> 5x5 -> 3x3 convolutions.
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=(7, 7), stride=2, padding=3, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=2, padding=2, bias=False)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=2, padding=1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)

        # Second convolution at each pyramid level (stride 1, spatial size preserved).
        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(7, 7), stride=1, padding=3, bias=False)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=1, padding=2, bias=False)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Upsampling path: each transposed convolution doubles the spatial size.
        self.conv_upsample_3 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_3 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_2 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_2 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_1 = nn.ConvTranspose2d(channels_mid, channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_1 = nn.BatchNorm2d(channels)

        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """
        :param x: input feature maps, shape [b, channels, h, w]
                  (h and w are assumed to be divisible by 8 so the strided and
                  transposed convolutions produce matching spatial sizes)
        :return: out: feature maps reweighted by the pyramid attention, shape [b, channels, h, w]
        """

        # Master branch: 1x1 conv + BN on the input features.
        x_master = self.conv_master(x)
        x_master = self.bn_master(x_master)

        # Pyramid level 1 (1/2 resolution, 7x7 convolutions).
        x1_1 = self.conv7x7_1(x)
        x1_1 = self.bn1_1(x1_1)
        x1_1 = self.relu(x1_1)
        x1_2 = self.conv7x7_2(x1_1)
        x1_2 = self.bn1_2(x1_2)

        # Pyramid level 2 (1/4 resolution, 5x5 convolutions).
        x2_1 = self.conv5x5_1(x1_1)
        x2_1 = self.bn2_1(x2_1)
        x2_1 = self.relu(x2_1)
        x2_2 = self.conv5x5_2(x2_1)
        x2_2 = self.bn2_2(x2_2)

        # Pyramid level 3 (1/8 resolution, 3x3 convolutions).
        x3_1 = self.conv3x3_1(x2_1)
        x3_1 = self.bn3_1(x3_1)
        x3_1 = self.relu(x3_1)
        x3_2 = self.conv3x3_2(x3_1)
        x3_2 = self.bn3_2(x3_2)

        # Merge the pyramid levels from coarse to fine via transposed convolutions.
        x3_upsample = self.relu(self.bn_upsample_3(self.conv_upsample_3(x3_2)))
        x2_merge = self.relu(x2_2 + x3_upsample)
        x2_upsample = self.relu(self.bn_upsample_2(self.conv_upsample_2(x2_merge)))
        x1_merge = self.relu(x1_2 + x2_upsample)

        # Reweight the master branch with the upsampled pyramid features.
        x_master = x_master * self.relu(self.bn_upsample_1(self.conv_upsample_1(x1_merge)))

        out = self.relu(x_master)
        return out
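

# A minimal usage sketch (an illustrative assumption, not part of the referenced
# repository): build an FPM with its default channel count and run a dummy
# forward pass. The spatial size is chosen divisible by 8 so the strided and
# transposed convolutions line up.
if __name__ == "__main__":
    import torch

    fpm = FPM(channels=1024)
    dummy = torch.randn(2, 1024, 32, 32)  # [b, channels, h, w]
    out = fpm(dummy)
    print(out.shape)  # torch.Size([2, 1024, 32, 32])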