import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Per-channel attention weights from a shared bottleneck over pooled features.

    The input is reduced per channel by global average pooling and by global
    max pooling. Both pooled descriptors go through the same
    ``fc1 -> ReLU -> fc2`` bottleneck (bias-free 1x1 convolutions). The two
    results are summed and passed through a sigmoid. The output has one weight
    per channel and spatial size 1x1, so it broadcasts over the input's spatial
    dimensions.

    Args:
        in_channels: Number of channels of the input feature map.
        ratio: Channel reduction factor of the bottleneck. The hidden width is
            ``in_channels // ratio``, clamped to at least 1.
    """

    def __init__(self, in_channels, ratio=16):
        super().__init__()
        # Clamp so that in_channels < ratio cannot produce a degenerate
        # zero-width bottleneck layer.
        hidden = max(1, in_channels // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_channels, 1, bias=False)
        # NOTE: the attribute name "sigmod" (sic) is kept so any existing code
        # that references it keeps working.
        self.sigmod = nn.Sigmoid()

    def _shared_mlp(self, pooled):
        """Apply the bottleneck MLP shared by both pooling branches."""
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x):
        """Return per-channel attention weights for ``x``.

        Args:
            x: Feature map, expected (N, C, H, W) with C == ``in_channels``.

        Returns:
            Tensor of shape (N, C, 1, 1) with values in [0, 1].
        """
        avg_out = self._shared_mlp(self.avg_pool(x))
        max_out = self._shared_mlp(self.max_pool(x))
        return self.sigmod(avg_out + max_out)


class ECAM(nn.Module):
    """ECAM fusion head (SNUNet-CD with ECAM).

    Fuses ``num_branches`` same-shaped feature maps of ``base_ch`` channels each:

    1. The branches are concatenated along the channel axis.
    2. A second ("intra") channel attention is computed on their element-wise
       sum and tiled to the concatenated width.
    3. This tiled attention is added to the concatenation.
    4. The result is scaled by a channel attention computed on the
       concatenation itself.
    5. A final 1x1 convolution maps to ``out_ch`` channels.

    Args:
        out_ch: Output channels of the final 1x1 convolution.
        base_ch: Channels of each input branch (originally hard-coded to 32).
        num_branches: Number of input feature maps fused (originally 4).
    """

    def __init__(self, out_ch=2, *, base_ch=32, num_branches=4):
        super().__init__()
        # NOTE(review): this sets a class attribute on torch.nn.Module itself,
        # so it affects every module in the process, not just this one. It is
        # kept from the original implementation; PyTorch presumably consults
        # it only when (un)pickling whole Module objects -- confirm it is
        # still needed.
        torch.nn.Module.dump_patches = True
        self.num_branches = num_branches
        fused_ch = base_ch * num_branches
        self.ca = ChannelAttention(fused_ch, ratio=16)
        self.ca1 = ChannelAttention(base_ch, ratio=16 // 4)
        self.conv_final = nn.Conv2d(fused_ch, out_ch, kernel_size=1)

    def forward(self, x):
        """Fuse the first ``num_branches`` feature maps of ``x``.

        Args:
            x: Indexable sequence whose items ``x[0] .. x[num_branches - 1]``
                are feature maps of identical shape, expected
                (N, base_ch, H, W). Any further items are ignored.

        Returns:
            Tensor of shape (N, out_ch, H, W).
        """
        branches = [x[i] for i in range(self.num_branches)]
        out = torch.cat(branches, 1)
        intra = torch.sum(torch.stack(branches), dim=0)
        ca1 = self.ca1(intra)
        # Tile the per-branch attention so it lines up with the concatenation.
        out = self.ca(out) * (out + ca1.repeat(1, self.num_branches, 1, 1))
        return self.conv_final(out)