import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel attention in the CBAM style: spatial information is squeezed
    with both average and max pooling, each result passes through a shared
    bottleneck MLP (two 1x1 convolutions), and the sum is gated by a sigmoid
    to produce per-channel weights."""

    def __init__(self, in_channels, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP: reduce the channel count by `ratio`, then restore it.
        self.fc1 = nn.Conv2d(in_channels, in_channels // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_channels // ratio, in_channels, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)
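

# For an input of shape (B, C, H, W), ChannelAttention yields weights of shape
# (B, C, 1, 1) meant to be broadcast-multiplied with the input; `in_channels`
# should be divisible by `ratio` so the bottleneck keeps at least one channel.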


class ECAM(nn.Module):
    """Fuses four same-shape feature maps: they are concatenated along the
    channel dimension, reweighted with channel attention, and projected to
    `out_ch` channels by a 1x1 convolution."""

    def __init__(self, out_ch=2):
        super(ECAM, self).__init__()
        n1 = 32
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        # Attention over the concatenated maps (4 * filters[0] channels) and
        # over their element-wise sum (filters[0] channels).
        self.ca = ChannelAttention(filters[0] * 4, ratio=16)
        self.ca1 = ChannelAttention(filters[0], ratio=16 // 4)

        self.conv_final = nn.Conv2d(filters[0] * 4, out_ch, kernel_size=1)

    def forward(self, x):
        # x: sequence of four feature maps, each of shape (B, filters[0], H, W).
        out = torch.cat([x[0], x[1], x[2], x[3]], 1)

        # Channel weights from the element-wise sum of the four maps, tiled
        # four times so they line up with the concatenated channels.
        intra = torch.sum(torch.stack((x[0], x[1], x[2], x[3])), dim=0)
        ca1 = self.ca1(intra)
        out = self.ca(out) * (out + ca1.repeat(1, 4, 1, 1))
        out = self.conv_final(out)

        return out
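

# A minimal smoke test (a sketch: it assumes ECAM receives four feature maps
# of equal spatial size, each with filters[0] = 32 channels, matching the
# forward pass above).
if __name__ == "__main__":
    feats = [torch.randn(1, 32, 64, 64) for _ in range(4)]
    model = ECAM(out_ch=2)
    out = model(feats)
    print(out.shape)  # torch.Size([1, 2, 64, 64])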