| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from omegaconf import DictConfig |
| | from typing import Any, Dict, Tuple |
| | from utils import instantiate |
| | import cv2 |
| | from PIL import Image |
| | import numpy as np |
| | |
class ResidualBlock(nn.Module):
    """Residual block that keeps the channel count and spatial size unchanged.

    The residual branch is ReLU -> 3x3 conv (stride 1, padding 1) -> ReLU ->
    1x1 conv. Its output is added to the untouched input.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        # Sub-modules are created in the order relu, conv1, conv2 so that
        # parameter initialisation draws from the RNG in a fixed sequence.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        branch = self.conv1(self.relu(x))
        branch = self.conv2(self.relu(branch))
        return x + branch
| |
|
| |
|
class Encoder2D(nn.Module):
    """Convolutional encoder for 3-channel inputs.

    The network is five (4x4 conv, stride 2, padding 1) + ReLU stages, then a
    3x3 conv (stride 1, padding 1) and two ``ResidualBlock`` modules. Each
    strided stage halves an even spatial size, so inputs whose height and
    width are divisible by 32 come out 32x smaller per side.

    Args:
        output_channels: Channel width used by every layer after the first.
    """

    def __init__(self, output_channels=512):
        super().__init__()

        # Build the layers in their final order so the nn.Sequential indices
        # (and therefore the state_dict keys) stay 0..12.
        layers = []
        in_channels = 3
        for _ in range(5):
            layers += [
                nn.Conv2d(in_channels, output_channels,
                          kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
            ]
            in_channels = output_channels
        layers.append(
            nn.Conv2d(output_channels, output_channels,
                      kernel_size=3, stride=1, padding=1)
        )
        for _ in range(2):
            layers.append(ResidualBlock(output_channels))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the feature map."""
        return self.block(x)
| | |
| |
|
class Decoder2D(nn.Module):
    """Convolutional decoder that maps a feature map to a 3-channel output.

    The network is a 3x3 conv (stride 1, padding 1), two ``ResidualBlock``
    modules, then five 4x4 transposed convs (stride 2, padding 1). Each
    transposed conv doubles the spatial size; all but the last are followed
    by ReLU. No activation is applied to the final 3-channel output.

    Args:
        input_dim: Channel count of the incoming feature map; also the width
            of every hidden layer.
    """

    def __init__(self, input_dim=512):
        super().__init__()

        # Side length of the square feature map. It is not read by this class
        # itself; it is exposed as an attribute for subclasses.
        self.fea_map_size = 16

        # Build the layers in their final order so the nn.Sequential indices
        # (and therefore the state_dict keys) stay 0..11.
        stages = [
            nn.Conv2d(input_dim, input_dim, kernel_size=3, stride=1, padding=1),
            ResidualBlock(input_dim),
            ResidualBlock(input_dim),
        ]
        for _ in range(4):
            stages += [
                nn.ConvTranspose2d(input_dim, input_dim,
                                   kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
            ]
        stages.append(
            nn.ConvTranspose2d(input_dim, 3, kernel_size=4, stride=2, padding=1)
        )

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Decode feature map ``x`` and return the 3-channel reconstruction."""
        return self.block(x)
| | |
| |
|
class Encoder(Encoder2D):
    """``Encoder2D`` followed by global average pooling.

    The convolutional feature map is averaged down to a 1x1 spatial size, so
    the output has shape ``(batch, output_channels, 1, 1)``.

    Args:
        output_channels: Channel width forwarded to ``Encoder2D``.
    """

    def __init__(self, output_channels=512):
        super().__init__(output_channels)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Encode ``x`` and collapse the spatial dimensions by averaging."""
        feature_map = super().forward(x)
        return self.pool(feature_map)
| | |
| |
|
class Decoder(Decoder2D):
    """``Decoder2D`` that starts from a flat latent vector per sample.

    A linear layer expands the latent to ``input_dim * fea_map_size**2``
    values. These are reshaped to a
    ``(batch, input_dim, fea_map_size, fea_map_size)`` feature map and passed
    through the inherited convolutional stack.

    Args:
        input_dim: Size of the latent vector and channel count of the
            reshaped feature map.
    """

    def __init__(self, input_dim=512):
        super().__init__(input_dim)

        self.fc = nn.Linear(
            input_dim, input_dim * self.fea_map_size * self.fea_map_size
        )

    def forward(self, x):
        """Decode latent ``x`` into a 3-channel output.

        Args:
            x: Tensor whose leading dimension is the batch and whose remaining
                dimensions flatten to ``input_dim`` values per sample, for
                example ``(batch, input_dim, 1, 1)``.

        Returns:
            The output of the inherited decoder stack for the reshaped input.
        """
        batch_size = x.size(0)
        # reshape (not view) so non-contiguous latents are accepted as well.
        latent = x.reshape(batch_size, -1)
        features = self.fc(latent)
        # Use the configured channel count rather than a hard-coded 512, so a
        # Decoder built with any input_dim reshapes correctly.
        features = features.view(
            batch_size,
            self.fc.in_features,
            self.fea_map_size,
            self.fea_map_size,
        )
        return self.block(features)
| |
|
| |
|