"""
Residual Convolutional Autoencoder for Image Reconstruction
Architecture: 6-layer encoder/decoder with residual blocks
"""
|
|
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
class AEResidualBlock(nn.Module):
    """Two-convolution residual unit with batch norm and optional dropout.

    Computes ``relu(bn2(conv2(dropout(relu(bn1(conv1(x)))))) + x)``. Both
    convolutions are 3x3 with padding 1 and keep the channel count, so the
    output has the same shape as the input.

    Args:
        channels: Number of input (and output) feature channels.
        dropout: Probability handed to ``nn.Dropout2d``. Any value that is
            not greater than 0 disables dropout (``nn.Identity`` is used).
    """

    def __init__(self, channels, dropout=0.1):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x):
        """Apply the residual unit to ``x`` and return the activated sum."""
        skip = x

        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)

        hidden = self.bn2(self.conv2(hidden))
        # In-place add on the freshly computed branch; the input is untouched.
        hidden.add_(skip)
        return self.relu(hidden)
|
|
|
|
class ResidualConvAutoencoder(nn.Module):
    """
    Deep Convolutional Autoencoder with Residual Connections

    The encoder applies six stride-2 convolutions (the first five each
    followed by a residual block), a linear layer maps the flattened
    512x4x4 feature map to the latent code, and the decoder mirrors the
    encoder with six stride-2 transposed convolutions ending in ``Tanh``.

    Args:
        latent_dim (int): Dimension of latent space (512 or 768)
        dropout (float): Dropout rate for regularization (0.15 or 0.20)

    Input: (B, 3, 256, 256) RGB images
    Output: (B, 3, 256, 256) reconstructed images in [-1, 1] (Tanh) and
        (B, latent_dim) latent codes

    Note:
        The linear layers are sized for a 512x4x4 bottleneck, i.e. for
        256x256 inputs. Other spatial sizes do not match ``fc_encoder``.
    """

    # Channel plans: consecutive pairs give (in, out) channels per stage.
    _ENCODER_CHANNELS = (3, 64, 128, 256, 512, 512, 512)
    _DECODER_CHANNELS = (512, 512, 512, 256, 128, 64, 3)

    # Shape of the encoder output for 256x256 inputs (256 halved six times).
    _BOTTLENECK_CHANNELS = 512
    _BOTTLENECK_SIZE = 4

    def __init__(self, latent_dim=512, dropout=0.15):
        super().__init__()

        self.latent_dim = latent_dim
        self.dropout = dropout

        # The Sequential containers are kept flat and in the original order
        # so that state_dict keys (encoder.0, encoder.3.conv1, ...) match
        # previously saved checkpoints.
        encoder_layers = []
        enc_pairs = list(zip(self._ENCODER_CHANNELS, self._ENCODER_CHANNELS[1:]))
        for stage, (c_in, c_out) in enumerate(enc_pairs):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
            # The final down-sampling stage has no residual block.
            if stage < len(enc_pairs) - 1:
                encoder_layers.append(AEResidualBlock(c_out, dropout))
        self.encoder = nn.Sequential(*encoder_layers)

        flat_features = (
            self._BOTTLENECK_CHANNELS
            * self._BOTTLENECK_SIZE
            * self._BOTTLENECK_SIZE
        )
        self.fc_encoder = nn.Linear(flat_features, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, flat_features)

        decoder_layers = []
        dec_pairs = list(zip(self._DECODER_CHANNELS, self._DECODER_CHANNELS[1:]))
        for stage, (c_in, c_out) in enumerate(dec_pairs):
            decoder_layers.append(
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1)
            )
            if stage < len(dec_pairs) - 1:
                decoder_layers += [
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(inplace=True),
                    AEResidualBlock(c_out, dropout),
                ]
            else:
                # Output stage: squash to [-1, 1] instead of BN/ReLU/residual.
                decoder_layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """
        Forward pass

        Args:
            x: Input tensor (B, 3, 256, 256) in range [-1, 1]

        Returns:
            reconstructed: Reconstructed tensor (B, 3, 256, 256)
            latent: Latent representation (B, latent_dim)
        """
        latent = self.encode(x)
        reconstructed = self.decode(latent)
        return reconstructed, latent

    def encode(self, x):
        """Get latent representation only"""
        features = self.encoder(x)
        # flatten() copies when needed, so non-contiguous feature maps
        # (e.g. channels_last) work; view() would raise for those.
        return self.fc_encoder(torch.flatten(features, 1))

    def decode(self, latent):
        """Reconstruct from latent code"""
        x = self.fc_decoder(latent)
        x = x.reshape(
            x.size(0),
            self._BOTTLENECK_CHANNELS,
            self._BOTTLENECK_SIZE,
            self._BOTTLENECK_SIZE,
        )
        return self.decoder(x)
|
|
|
|
def load_model(checkpoint_path, latent_dim=512, dropout=0.15, device='cuda'):
    """
    Load a trained model from checkpoint

    Args:
        checkpoint_path: Path to .pth checkpoint file
        latent_dim: Latent dimension (512 for Model A, 768 for Model B)
        dropout: Dropout rate (0.15 for Model A, 0.20 for Model B)
        device: Device to load model on. If a CUDA device is requested but
            CUDA is not available, the model is loaded on the CPU instead
            and a ``RuntimeWarning`` is emitted.

    Returns:
        model: Loaded model in eval mode
        checkpoint: Full checkpoint dict with metadata

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    target = torch.device(device)
    if target.type == 'cuda' and not torch.cuda.is_available():
        warnings.warn(
            "CUDA was requested but is not available; loading the model on "
            "CPU instead.",
            RuntimeWarning,
            stacklevel=2,
        )
        target = torch.device('cpu')

    model = ResidualConvAutoencoder(latent_dim=latent_dim, dropout=dropout)

    # SECURITY: torch.load deserializes with pickle unless weights_only=True
    # is in effect; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=target)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.to(target)
    model.eval()
    return model, checkpoint
|
|