Spaces:
Running
Running
| import torch.nn as nn | |
| import torch | |
class ResidualBlock(nn.Module):
    """Two-convolution residual block with an optional 1x1 projection shortcut.

    The main branch is ``conv3x3 -> BN -> ReLU -> conv3x3 -> BN``. Its output
    is added to the shortcut branch and the sum is passed through a final
    ReLU. Both 3x3 convolutions use ``padding=1`` with the default stride, so
    the spatial size of the input is preserved.

    Attributes:
        conv1, bn1: First 3x3 convolution (``in_channels -> out_channels``)
            and its batch normalisation.
        conv2, bn2: Second 3x3 convolution (``out_channels -> out_channels``)
            and its batch normalisation.
        relu: In-place ReLU module, used after ``bn1`` and after the addition.
        downsample: 1x1 convolution projecting the input to ``out_channels``
            when the channel counts differ, otherwise ``None``. Despite the
            name it only changes the channel count, not the spatial size.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the block's layers.

        Args:
            in_channels: Number of channels of the input feature map.
            out_channels: Number of channels produced by the block.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # A projection is only needed when the shortcut cannot be added to the
        # main branch as-is, i.e. when the channel counts differ.
        self.downsample = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual block to ``x``.

        Args:
            x: Input feature map whose channel dimension equals
                ``in_channels`` (e.g. ``(N, in_channels, H, W)``).

        Returns:
            A tensor with ``out_channels`` channels and the same spatial size
            as ``x``. ``x`` itself is not modified.
        """
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Compare against None explicitly: truthiness of a module is not a
        # reliable "is present" test (containers with __len__ can be falsy).
        shortcut = x if self.downsample is None else self.downsample(x)

        # In-place add mutates `out` only, so the caller's `x` stays intact.
        out += shortcut
        return self.relu(out)