| from rscd.models.decoderheads.lgpnet.unet_parts import *
|
|
|
| class BCDNET(nn.Module):
|
| """ Local-Global Pyramid Network (LGPNet) """
|
| def __init__(self, n_channels, n_classes):
|
| super(BCDNET, self).__init__()
|
| self.n_channels = n_channels
|
| self.n_classes = n_classes
|
| self.conv = TribleConv(128, 64)
|
| self.final = OutConv(64, n_classes)
|
|
|
| def forward(self, x=[]):
|
|
|
|
|
| feat1 = x[2]
|
| feat2 = x[3]
|
| fusionfeats = torch.cat([feat1, feat2], dim=1)
|
|
|
| x = self.conv(fusionfeats)
|
| logits = self.final(x)
|
| return logits
|
|
|
|
|
| class TribleConv(nn.Module):
|
| """(convolution => [BN] => ReLU) 2次"""
|
|
|
| def __init__(self, in_channels, out_channels):
|
| super().__init__()
|
| self.trible_conv = nn.Sequential(
|
| nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
| nn.BatchNorm2d(out_channels),
|
| nn.ReLU(inplace=True),
|
| nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
| nn.BatchNorm2d(out_channels),
|
| nn.ReLU(inplace=True)
|
| )
|
|
|
| def forward(self, x):
|
| return self.trible_conv(x)
|
|
|
|
|
| if __name__ == '__main__':
|
| net = BCDNET(n_channels=3, n_classes=1)
|
| print(net) |