| | from typing import Optional, Dict, Any, List |
| | import torch |
| | import torch.nn as nn |
| |
|
| | |
| | |
| | |
| |
|
class Conv2d(nn.Module):
    """2D convolution, optionally followed by a PReLU activation.

    Expects inputs shaped ``[b, c, h, w]``:
        b -- batch size
        c -- number of channels (``in_channels``)
        h -- height
        w -- width
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        padding: int,
        do_activation: bool = True,
    ):
        super().__init__()
        layers: List[nn.Module] = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        ]
        layers += [nn.PReLU()] if do_activation else []
        # Kept inside an nn.Sequential so parameter names stay ``conv.0.*`` / ``conv.1.*``.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        # Convolution, then (if enabled) PReLU.
        return self.conv(x)
| | |
| | |
| | |
| | |
| |
|
class _UNet(nn.Module):
    """UNet backbone whose convolution block can be swapped out.

    Encoder: one ``conv`` block per entry of ``features``, each followed by a
    2x2 max-pool (stride 2). A ``bottleneck`` block sits at the bottom.
    Decoder: for each encoder stage, deepest first, bilinear 2x upsampling,
    concatenation with that stage's output (skip connection) and another
    ``conv`` block. A final 1x1 ``conv`` block (built with
    ``do_activation=False``) maps to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 block.
        features: Output channels of each encoder stage, shallowest first.
            Defaults to ``[64, 64, 64, 64, 64]``. Entries need not be equal.
        conv_kernel_size: Kernel size of every block except the final one.
            Padding is ``(conv_kernel_size - 1) // 2``.
        conv: Block factory (e.g. a module class), called as
            ``conv(in_ch, out_ch, kernel_size=..., padding=..., **conv_kwargs)``.
            The final block is additionally given ``do_activation=False``, so
            the factory must accept that keyword. Despite the annotation, this
            is a class/callable, not a module instance.
        conv_kwargs: Extra keyword arguments forwarded to every ``conv`` call.

    Raises:
        ValueError: if ``conv`` is ``None`` or ``features`` is empty.
    """

    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 features: Optional[List[int]] = None,
                 conv_kernel_size: int = 3,
                 conv: Optional[nn.Module] = None,
                 conv_kwargs: Optional[Dict[str, Any]] = None,
                 ):
        super(_UNet, self).__init__()

        # None sentinels instead of literal list/dict defaults, which would be
        # shared by every call.
        features = [64, 64, 64, 64, 64] if features is None else list(features)
        conv_kwargs = {} if conv_kwargs is None else dict(conv_kwargs)
        if conv is None:
            raise ValueError("_UNet requires a `conv` block factory (e.g. Conv2d)")
        if not features:
            raise ValueError("`features` must contain at least one entry")

        self.in_channels = in_channels

        # "Same" padding for odd kernel sizes.
        padding = (conv_kernel_size - 1) // 2

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder blocks.
        for feat in features:
            self.downs.append(
                conv(in_channels, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs)
            )
            in_channels = feat

        # Decoder: self.ups alternates [upsample, conv, upsample, conv, ...].
        # Bilinear upsampling keeps the channel count, so each conv sees
        # (channels coming from below) + (channels of the skip connection).
        # A fixed ``2 * feat`` would only be right when all features are equal.
        below = features[-1]  # bottleneck output channels
        for feat in reversed(features):
            self.ups.append(nn.UpsamplingBilinear2d(scale_factor=2))
            self.ups.append(
                conv(below + feat, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs)
            )
            below = feat

        self.bottleneck = conv(
            features[-1], features[-1], kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
        )
        # Merge rather than pass separately: a ``do_activation`` key in
        # conv_kwargs would otherwise raise "multiple values for keyword".
        self.final_conv = conv(
            features[0], out_channels, kernel_size=1, padding=0,
            **{**conv_kwargs, "do_activation": False}
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the UNet on ``x``.

        ``x`` is presumably ``(B, in_channels, H, W)``. H and W need not be
        divisible by ``2 ** len(features)``. When upsampling does not
        reproduce a skip connection's spatial size (because max-pooling floored
        an odd size), the upsampled map is bilinearly resized to match before
        concatenation.
        """
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Pair each (upsample, conv) with the skip connections, deepest first.
        for upsample, up_conv, skip in zip(self.ups[0::2], self.ups[1::2], reversed(skip_connections)):
            x = upsample(x)
            if x.shape[-2:] != skip.shape[-2:]:
                # align_corners=True matches nn.UpsamplingBilinear2d.
                x = nn.functional.interpolate(
                    x, size=skip.shape[-2:], mode="bilinear", align_corners=True
                )
            x = up_conv(torch.cat((skip, x), dim=1))

        return self.final_conv(x)
| | |
| |
|
class UNet(_UNet):
    """
    UNet built from ``Conv2d`` blocks (convolution, followed by PReLU unless
    ``do_activation=False`` is requested).

    input shape:  B x in_channels x H x W
    output shape: B x out_channels x H x W

    All keyword arguments are forwarded to ``_UNet``. ``conv`` is fixed to
    ``Conv2d``, so also passing ``conv=...`` raises ``TypeError``.
    """
    def __init__(self, **kwargs) -> None:
        super().__init__(conv=Conv2d, **kwargs)
| | |