| | from ConvLSTM import ConvLSTM |
| | import torch |
| | import torch.nn as nn |
| | from collections import defaultdict |
| |
|
| | |
class MLP_5D(nn.Module):
    """Pointwise MLP applied independently at every (t, h, w) position of a
    5D tensor.

    Input is ``(B, T, C, H, W)``; each C-dim feature vector is pushed through
    a 3-layer softplus MLP producing one strictly positive value per position.
    Output is ``(B, T, 1, H, W)``.
    """

    def __init__(self, height, width, in_channels=64):
        """
        Args:
            height: expected spatial height of the input.
            width: expected spatial width of the input.
            in_channels: channel count of the per-position feature vectors.
                Defaults to 64, matching the previously hard-coded value.
        """
        super(MLP_5D, self).__init__()

        self.fc1 = nn.Linear(in_channels, 128)
        self.dropout1 = nn.Dropout(0.05)
        self.fc2 = nn.Linear(128, 64)
        self.dropout2 = nn.Dropout(0.05)
        self.fc3 = nn.Linear(64, 1)

        self.in_channels = in_channels
        self.height = height
        self.width = width

    def forward(self, x):
        """Apply the MLP to every spatial/temporal position.

        Args:
            x: tensor of shape ``(batch, timesteps, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, timesteps, 1, height, width)``; values
            are strictly positive (final softplus).

        Raises:
            ValueError: if the spatial dimensions do not match those given at
                construction time.
        """
        batch_size, timesteps, channels, height, width = x.shape

        # Validate with a real exception: `assert` is silently stripped when
        # Python runs with -O, which would let mismatched shapes through to a
        # confusing failure in `view` below.
        if height != self.height or width != self.width:
            raise ValueError(
                f"Height and width mismatch: got ({height}, {width}), "
                f"expected ({self.height}, {self.width})"
            )

        # Flatten to (B*T*H*W, C) so each position is an independent sample.
        x = x.permute(0, 1, 3, 4, 2).reshape(-1, channels)

        x = self.fc1(x)
        x = torch.nn.functional.softplus(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = torch.nn.functional.softplus(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        # Final softplus keeps the regression output non-negative.
        x = torch.nn.functional.softplus(x)

        # Restore the 5D layout with a singleton channel dim: (B, T, 1, H, W).
        x = x.view(batch_size, timesteps, self.height, self.width, 1).permute(0, 1, 4, 2, 3)

        return x
| | |
| | |
| | |
| |
|
class ConvLSTMNetwork(nn.Module):
    """ConvLSTM-based spatiotemporal model with two output heads.

    Pipeline: stacked ConvLSTM -> per-layer BatchNorm3d -> (1, 3, 3) Conv3d,
    feeding:
      * a regression head: pointwise MLP (``MLP_5D``), strictly positive output;
      * a classification head: (1, 1, 1) Conv3d + Sigmoid, per-pixel probability.

    Every forward pass also records the mean spatial variance of each ConvLSTM
    layer's activations into ``self.activation_variance`` (diagnostics).
    """

    def __init__(self, input_dim, hidden_dims, kernel_size, num_layers,
                 output_channels, batch_first=True, pool_size=(2, 2),
                 mlp_height=81, mlp_width=97):
        """
        Args:
            input_dim: channel count of the input sequence.
            hidden_dims: hidden channel counts, one per ConvLSTM layer.
            kernel_size: ConvLSTM convolution kernel size.
            num_layers: number of stacked ConvLSTM layers.
            output_channels: channels produced by the 3D conv head.
            batch_first: whether inputs are (B, T, C, H, W).
            pool_size: currently unused; kept for backward compatibility.
            mlp_height: spatial height expected by the regression MLP
                (previously hard-coded to 81).
            mlp_width: spatial width expected by the regression MLP
                (previously hard-coded to 97).
        """
        super(ConvLSTMNetwork, self).__init__()

        self.convlstm = ConvLSTM(input_dim=input_dim,
                                 hidden_dim=hidden_dims,
                                 kernel_size=kernel_size,
                                 num_layers=num_layers,
                                 batch_first=batch_first,
                                 bias=True,
                                 return_all_layers=True)

        # One BatchNorm3d per ConvLSTM layer. BatchNorm3d expects
        # (N, C, D, H, W), so forward() permutes to channels-second first.
        self.batch_norms = nn.ModuleList([
            nn.BatchNorm3d(hidden_dim) for hidden_dim in hidden_dims
        ])

        # Spatial-only conv: kernel (1, 3, 3) with padding (0, 1, 1) leaves
        # T, H and W unchanged.
        self.conv3d = nn.Conv3d(in_channels=hidden_dims[-1],
                                out_channels=output_channels,
                                kernel_size=(1, 3, 3),
                                padding=(0, 1, 1))

        # NOTE(review): MLP_5D's first Linear expects 64 input channels, so
        # the regression head only works when output_channels == 64 — confirm.
        self.mlp = MLP_5D(height=mlp_height, width=mlp_width)

        self.classification_head = nn.Sequential(
            nn.Conv3d(output_channels, 1, kernel_size=(1, 1, 1)),
            nn.Sigmoid()
        )

        # Diagnostic log of per-layer activation variance.
        # NOTE(review): grows without bound across forward passes — clear it
        # periodically during long training runs.
        self.activation_variance = defaultdict(list)

    def forward(self, x):
        """
        Args:
            x: (B, T, input_dim, H, W)

        Returns:
            Tuple ``(regression_output, classification_output)``, both shaped
            (B, T, 1, H, W).
        """
        layer_output_list, last_state_list = self.convlstm(x)

        # Batch-normalize each layer's output and record its mean spatial
        # activation variance.
        for i, output in enumerate(layer_output_list):
            # (B, T, C, H, W) -> (B, C, T, H, W) for BatchNorm3d, then back.
            output = output.permute(0, 2, 1, 3, 4)
            output = self.batch_norms[i](output)
            output = output.permute(0, 2, 1, 3, 4)

            # Variance over the spatial dims (H, W), averaged to a scalar.
            # .item() forces a host sync; acceptable for monitoring.
            activation_variance = output.var(dim=(3, 4)).mean().item()
            self.activation_variance[f"ConvLSTM_layer_{i}"].append(activation_variance)

            layer_output_list[i] = output

        # Only the deepest layer feeds the heads.
        final_output = layer_output_list[-1]

        # (B, T, C, H, W) -> (B, C, T, H, W) for Conv3d.
        final_output = final_output.permute(0, 2, 1, 3, 4)
        final_output = self.conv3d(final_output)

        # Back to (B, T, C, H, W) for the per-position regression MLP.
        final_output_t = final_output.permute(0, 2, 1, 3, 4)
        regression_output = self.mlp(final_output_t)

        # Classification head consumes the channels-second layout directly.
        classification_output = self.classification_head(final_output)

        # Back to (B, T, 1, H, W) to match the regression output layout.
        classification_output = classification_output.permute(0, 2, 1, 3, 4)

        return regression_output, classification_output