| import torch | |
| import torch.nn as nn | |
| from torch.nn import Parameter | |
class Decompose_conv(nn.Module):
    """Lift a 2D convolution to 3D, optionally adding learnable temporal mixing.

    ``time_dim == 1``: wraps a single ``nn.Conv3d`` whose kernel is the 2D
    kernel with a singleton time axis, so every frame is convolved
    independently with the original 2D weights.

    Any other ``time_dim``: the 2D kernel is applied per frame
    (``conv3d_spatial``). The first two frames of that result (``f0``, ``f1``)
    are then mixed by three bias-free 1x1x1 convolutions::

        out[t=0] = time_2(f0) + time_3(f1)
        out[t=1] = time_1(f0) + time_2(f1)

    ``time_2`` starts as the identity and ``time_1``/``time_3`` start at zero.
    At construction the output therefore equals the per-frame 2D convolution.
    In this branch the input needs at least two frames. Frames beyond index 1
    are not used, and the output always has exactly two frames.
    ``time_padding``, ``time_stride`` and ``time_dilation`` are also unused here.

    Args:
        conv2d: source ``nn.Conv2d``. The 3D weight is built from a view of
            ``conv2d.weight.data`` (storage is shared, not copied), and the
            ``conv2d.bias`` Parameter object is reused as-is.
        time_dim: ``1`` selects the plain inflated conv; any other value selects
            the spatial + temporal-mixing variant.
        time_padding: temporal padding, used only when ``time_dim == 1``.
        time_stride: temporal stride, used only when ``time_dim == 1``.
        time_dilation: temporal dilation, used only when ``time_dim == 1``.
        center: accepted for signature parity with ``inflate_conv``; ignored.
    """

    def __init__(self, conv2d, time_dim=3, time_padding=0, time_stride=1, time_dilation=1,
                 center=False):
        super(Decompose_conv, self).__init__()
        self.time_dim = time_dim
        kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])
        padding = (time_padding, conv2d.padding[0], conv2d.padding[1])
        # Width stride must come from stride[1]; repeating stride[0] silently broke
        # non-square strides such as (1, 2).
        stride = (time_stride, conv2d.stride[0], conv2d.stride[1])
        dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1])
        # unsqueeze on .data is a view: the 3D weight shares storage with conv2d.weight.
        weight_3d = conv2d.weight.data.unsqueeze(2)
        if time_dim == 1:
            # groups must match the source conv or the reused weight has the wrong shape.
            self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, kernel_dim,
                                    padding=padding, dilation=dilation, stride=stride,
                                    groups=conv2d.groups)
            self.conv3d.weight = Parameter(weight_3d)
            self.conv3d.bias = conv2d.bias
        else:
            # Per-frame spatial conv: singleton time kernel, no temporal padding/stride/dilation.
            self.conv3d_spatial = nn.Conv3d(conv2d.in_channels, conv2d.out_channels,
                                            kernel_size=(1,) + kernel_dim[1:],
                                            padding=(0,) + padding[1:],
                                            dilation=(1,) + dilation[1:],
                                            stride=(1,) + stride[1:],
                                            groups=conv2d.groups)
            self.conv3d_spatial.weight = Parameter(weight_3d)
            self.conv3d_spatial.bias = conv2d.bias
            channels = conv2d.out_channels
            self.conv3d_time_1 = nn.Conv3d(channels, channels, [1, 1, 1], bias=False)
            self.conv3d_time_2 = nn.Conv3d(channels, channels, [1, 1, 1], bias=False)
            self.conv3d_time_3 = nn.Conv3d(channels, channels, [1, 1, 1], bias=False)
            # Start as a pure per-frame conv: cross-frame terms are zero and the
            # same-frame term is the identity (dirac_ on a CxCx1x1x1 weight == eye).
            nn.init.constant_(self.conv3d_time_1.weight, 0.0)
            nn.init.constant_(self.conv3d_time_3.weight, 0.0)
            nn.init.dirac_(self.conv3d_time_2.weight)

    def forward(self, x):
        """Apply the lifted convolution.

        ``x`` is the 5D ``(N, C, T, H, W)`` input consumed by ``nn.Conv3d``.
        In the mixing branch only frames 0 and 1 are used, and two frames are
        returned.
        """
        if self.time_dim == 1:
            return self.conv3d(x)
        x_spatial = self.conv3d_spatial(x)
        frame_0 = x_spatial[:, :, 0:1, :, :]
        frame_1 = x_spatial[:, :, 1:2, :, :]
        out_0 = self.conv3d_time_2(frame_0) + self.conv3d_time_3(frame_1)
        out_1 = self.conv3d_time_1(frame_0) + self.conv3d_time_2(frame_1)
        return torch.cat([out_0, out_1], dim=2)
def Decompose_norm(batch2d):
    """Make a 2D batch-norm module accept 3D (5D tensor) input, in place.

    A temporary ``BatchNorm3d`` with the same ``num_features`` is created only
    to borrow its bound ``_check_input_dim`` method. That method replaces the
    2D module's own input-rank check. Nothing else on ``batch2d`` is touched,
    including its affine parameters and running statistics. The same
    ``batch2d`` object is returned.
    """
    donor = torch.nn.BatchNorm3d(batch2d.num_features)
    setattr(batch2d, "_check_input_dim", donor._check_input_dim)
    return batch2d
def _as_pair(value):
    """Return ``value`` as a 2-tuple, expanding a scalar to ``(value, value)``."""
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value, value)


def Decompose_pool(pool2d, time_dim=1, time_padding=0, time_stride=None, time_dilation=1):
    """Build the 3D counterpart of a 2D pooling module.

    Supported conversions:

    * ``nn.AdaptiveAvgPool2d`` -> ``nn.AdaptiveAvgPool3d`` with output size
      ``(1, *pool2d.output_size)``. Time is always pooled down to one step.
    * ``nn.MaxPool2d`` -> ``nn.MaxPool3d``, keeping the spatial kernel, stride,
      padding, dilation, ``ceil_mode`` and ``return_indices`` of ``pool2d``.
    * ``nn.AvgPool2d`` -> ``nn.AvgPool3d``, keeping the spatial kernel, stride,
      padding, ``ceil_mode`` and ``count_include_pad`` of ``pool2d``.
      ``divisor_override`` is not carried over: the pooled volume changes with
      ``time_dim``.

    Spatial settings on ``pool2d`` may be ints or pairs.

    Args:
        pool2d: the 2D pooling module to mirror.
        time_dim: temporal kernel size (unused for adaptive pooling).
        time_padding: temporal padding (unused for adaptive pooling).
        time_stride: temporal stride; ``None`` means ``time_dim``.
        time_dilation: temporal dilation, used only for max pooling.

    Returns:
        A new 3D pooling module.

    Raises:
        ValueError: if ``pool2d`` is not one of the supported classes.
    """
    if isinstance(pool2d, torch.nn.AdaptiveAvgPool2d):
        # Collapse time to one step; the spatial output follows the 2D module.
        return torch.nn.AdaptiveAvgPool3d((1,) + _as_pair(pool2d.output_size))
    # Type check first so unsupported modules get the intended ValueError rather
    # than an AttributeError from reading kernel_size/padding below.
    if not isinstance(pool2d, (torch.nn.MaxPool2d, torch.nn.AvgPool2d)):
        raise ValueError('{} is not among known pooling classes'.format(type(pool2d)))
    if time_stride is None:
        time_stride = time_dim
    kernel_dim = (time_dim,) + _as_pair(pool2d.kernel_size)
    padding = (time_padding,) + _as_pair(pool2d.padding)
    stride = (time_stride,) + _as_pair(pool2d.stride)
    if isinstance(pool2d, torch.nn.MaxPool2d):
        dilation = (time_dilation,) + _as_pair(pool2d.dilation)
        return torch.nn.MaxPool3d(kernel_dim, padding=padding, dilation=dilation, stride=stride,
                                  ceil_mode=pool2d.ceil_mode,
                                  return_indices=pool2d.return_indices)
    # Average pooling: padding/ceil_mode/count_include_pad were previously dropped.
    return torch.nn.AvgPool3d(kernel_dim, stride=stride, padding=padding,
                              ceil_mode=pool2d.ceil_mode,
                              count_include_pad=pool2d.count_include_pad)
def inflate_conv(conv2d, time_dim=3, time_padding=0, time_stride=1, time_dilation=1, center=False):
    """Inflate a 2D convolution into a 3D one (I3D-style weight inflation).

    The 2D kernel is extended along a new time axis of length ``time_dim``:

    * ``center=False``: the kernel is replicated on every time step and divided
      by ``time_dim``, so the kernel summed over time equals the 2D kernel.
    * ``center=True``: the kernel is placed on time step ``time_dim // 2`` only;
      every other time step is zero.

    Either way, consider an input whose frames are all identical. Any output
    frame whose temporal window contains no padding equals ``conv2d`` applied
    to a single frame.

    Args:
        conv2d: source ``nn.Conv2d``.
        time_dim: temporal kernel size.
        time_padding: temporal padding.
        time_stride: temporal stride.
        time_dilation: temporal dilation.
        center: choose centred placement instead of uniform replication.

    Returns:
        A new ``nn.Conv3d``. Its weight is a fresh tensor with the same dtype and
        device as ``conv2d.weight``. Its bias is the ``conv2d.bias`` Parameter
        object itself.
    """
    kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])
    padding = (time_padding, conv2d.padding[0], conv2d.padding[1])
    # Width stride must come from stride[1]; repeating stride[0] broke non-square strides.
    stride = (time_stride, conv2d.stride[0], conv2d.stride[1])
    dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1])
    # groups must match the source conv or the inflated weight has the wrong shape.
    conv3d = torch.nn.Conv3d(conv2d.in_channels, conv2d.out_channels, kernel_dim, padding=padding,
                             dilation=dilation, stride=stride, groups=conv2d.groups)
    weight_2d = conv2d.weight.data
    if center:
        out_ch, in_ch, k_h, k_w = weight_2d.shape
        # new_zeros keeps weight_2d's dtype/device (torch.zeros defaulted to float32 on CPU).
        weight_3d = weight_2d.new_zeros((out_ch, in_ch, time_dim, k_h, k_w))
        weight_3d[:, :, time_dim // 2, :, :] = weight_2d
    else:
        weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) / time_dim
    conv3d.weight = Parameter(weight_3d)
    conv3d.bias = conv2d.bias
    return conv3d