| | |
| |
|
| | from typing import Sequence, Tuple, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer |
| |
|
| |
|
class UNesTBlock(nn.Module):
    """
    Decoder block: upsample with a transposed convolution, concatenate a skip
    tensor along the channel axis, then apply one convolutional block.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: kernel size of the convolutional block applied after concatenation.
            stride: not read by this block; the convolutional block is always built with stride 1.
            upsample_kernel_size: kernel size of the transposed convolution; also used as its stride.
            norm_name: feature normalization type and arguments.
            res_block: if True a ``UnetResBlock`` is used, otherwise a ``UnetBasicBlock``.

        """
        super().__init__()

        # Kernel size and stride of the transposed convolution are deliberately the same value.
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_kernel_size,
            conv_only=True,
            is_transposed=True,
        )

        # The block consumes the upsampled tensor concatenated with the skip tensor,
        # hence twice ``out_channels`` on its input side.
        block_type = UnetResBlock if res_block else UnetBasicBlock
        self.conv_block = block_type(
            spatial_dims,
            out_channels + out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            norm_name=norm_name,
        )

    def forward(self, inp, skip):
        """
        Upsample ``inp``, concatenate it with ``skip`` on dim 1 and run the conv block.

        ``skip`` is expected to carry ``out_channels`` channels so that the
        concatenated tensor matches the conv block's input width.
        """
        upsampled = self.transp_conv(inp)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv_block(merged)
| |
|
| |
|
class UNestUpBlock(nn.Module):
    """
    Projection/upsampling block: an initial transposed convolution followed by
    ``num_layer`` further stages applied one after another.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_layer: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        conv_block: bool = False,
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            num_layer: number of stages appended after the initial transposed convolution.
            kernel_size: kernel size of the convolutional block in each stage
                (only used when ``conv_block`` is True).
            stride: stride of the convolutional block in each stage
                (only used when ``conv_block`` is True).
            upsample_kernel_size: kernel size of the transposed convolutions; also used as their stride.
            norm_name: feature normalization type and arguments.
            conv_block: if True each stage is a transposed convolution followed by a
                convolutional block; if False each stage is a single transposed
                convolution with kernel size 1 and stride 1.
            res_block: when ``conv_block`` is True, selects ``UnetResBlock`` (True) or
                ``UnetBasicBlock`` (False) as the convolutional block.

        """
        super().__init__()

        # Kernel size and stride of the transposed convolutions are deliberately the same value.
        upsample_stride = upsample_kernel_size
        self.transp_conv_init = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        if conv_block:
            block_type = UnetResBlock if res_block else UnetBasicBlock
            self.blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        get_conv_layer(
                            spatial_dims,
                            out_channels,
                            out_channels,
                            kernel_size=upsample_kernel_size,
                            stride=upsample_stride,
                            conv_only=True,
                            is_transposed=True,
                        ),
                        block_type(
                            # Use the caller's dimensionality; this was previously hard-coded
                            # to 3, which mismatched the transposed convolution for non-3D input.
                            spatial_dims=spatial_dims,
                            in_channels=out_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            norm_name=norm_name,
                        ),
                    )
                    for _ in range(num_layer)
                ]
            )
        else:
            self.blocks = nn.ModuleList(
                [
                    get_conv_layer(
                        spatial_dims,
                        out_channels,
                        out_channels,
                        kernel_size=1,
                        stride=1,
                        conv_only=True,
                        is_transposed=True,
                    )
                    for _ in range(num_layer)
                ]
            )

    def forward(self, x):
        """Apply the initial transposed convolution, then every stage in order."""
        x = self.transp_conv_init(x)
        for blk in self.blocks:
            x = blk(x)
        return x
| |
|
| |
|
class UNesTConvBlock(nn.Module):
    """
    Thin wrapper around a single convolutional block, either ``UnetResBlock``
    or ``UnetBasicBlock`` depending on ``res_block``.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            norm_name: feature normalization type and arguments.
            res_block: if True a ``UnetResBlock`` is built, otherwise a ``UnetBasicBlock``.

        """
        super().__init__()

        # Both block types take the same constructor arguments, so pick the class once.
        block_type = UnetResBlock if res_block else UnetBasicBlock
        self.layer = block_type(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            norm_name=norm_name,
        )

    def forward(self, inp):
        """Return the wrapped block's output for ``inp``."""
        return self.layer(inp)
| |
|