| | import mxnet as mx |
| | import numpy as np |
| | from nowcasting.operators.common import constant |
| |
|
| |
|
def CDNA(data, kernels, mask, batch_size, num_filter, kernel_size):
    """Convolve every sample with its own kernels and blend the results with a mask.

    We assume that the kernels and masks are the output of an identity activation
    (raw scores). Both are normalized with a softmax inside this function: each
    kernel over its K * K entries, the mask over its M channels.

    Parameters
    ----------
    data : mx.sym.symbol
        Shape: (batch_size, C, H, W)
    kernels : mx.sym.symbol
        Shape: (batch_size, M, K, K)
    mask : mx.sym.symbol
        Shape: (batch_size, M, H, W)
    batch_size : int
        Number of samples; the graph is unrolled once per sample.
    num_filter : int
        M
    kernel_size : int
        K, must be odd

    Returns
    -------
    ret : mx.sym.symbol
        Shape: (batch_size, C, H, W)

    Raises
    ------
    AssertionError
        If `kernel_size` is even.
    """
    assert kernel_size % 2 == 1, "Only support odd kernel size"
    # Floor division is required here: `kernel_size / 2` is a float on
    # Python 3, which is not a valid padding size for the convolution.
    pad = kernel_size // 2

    # Flatten each kernel to a row so the softmax runs over its K * K entries,
    # then restore the (batch_size, M, K, K) layout.
    kernels = mx.sym.SoftmaxActivation(
        mx.sym.Reshape(kernels, shape=(-1, kernel_size * kernel_size)))
    kernels = mx.sym.Reshape(kernels, shape=(-1, num_filter, kernel_size, kernel_size))
    # Softmax across the M mask channels.
    mask = mx.sym.SoftmaxActivation(mask, mode="channel")

    # Insert a singleton axis and split along the batch axis, so that sample i
    # contributes data of shape (C, 1, H, W) and weights of shape (M, 1, K, K).
    # The C channels then act as the convolution batch and are all filtered
    # with the same M single-channel kernels.
    data_sliced = mx.sym.SliceChannel(mx.sym.expand_dims(data, axis=2), axis=0,
                                      num_outputs=batch_size, squeeze_axis=True)
    kernels_sliced = mx.sym.SliceChannel(mx.sym.expand_dims(kernels, axis=2),
                                         axis=0, num_outputs=batch_size,
                                         squeeze_axis=True)
    # Each element is (1, C, M, H, W); "same" padding keeps H and W unchanged.
    out = [
        mx.sym.expand_dims(
            mx.sym.Convolution(data=data_sliced[i],
                               num_filter=num_filter,
                               kernel=(kernel_size, kernel_size),
                               pad=(pad, pad),
                               weight=kernels_sliced[i], no_bias=True),
            axis=0)
        for i in range(batch_size)
    ]
    out = mx.sym.Concat(*out, num_args=batch_size, dim=0)
    # Reshape the mask to (batch_size, 1, M, H, W) so it broadcasts over C,
    # then take the mask-weighted sum over the M filter responses.
    mask = mx.sym.Reshape(mask, reverse=True, shape=(batch_size, 1, num_filter, 0, 0))
    out = mx.sym.broadcast_mul(out, mask)
    out = mx.sym.sum(out, axis=2)
    return out
| |
|
def STP(data, affine_transform_matrices, mask, num_filter, kernel_size):
    """Spatial Transformer Predictor (stub, not implemented yet).

    Parameters
    ----------
    data : mx.sym.symbol
    affine_transform_matrices
    mask
    num_filter
    kernel_size

    Raises
    ------
    NotImplementedError
        Always; the function has no implementation.
    """
    raise NotImplementedError
| |
|
| |
|
def DFN(data, local_kernels, K, batch_size):
    """[NIPS2016] Dynamic Filter Network

    Filters every pixel of `data` with its own K x K kernel. The kernels are
    normalized with a softmax over their K * K entries before being applied.

    Parameters
    ----------
    data : mx.sym.symbol
        Shape: (batch_size, C, H, W)
    local_kernels : mx.sym.symbol
        Shape: (batch_size, K*K, H, W)
    K : int
        size of the local convolutional kernel, must be odd
    batch_size : int
        size of the minibatch

    Returns
    -------
    output : mx.sym.symbol
        Shape: (batch_size, C, H, W)

    Raises
    ------
    AssertionError
        If `K` is even.
    """
    # With an even K, padding by K // 2 produces (H + 1, W + 1) feature maps
    # that no longer line up with `local_kernels`, so reject it up front.
    assert K % 2 == 1, "Only support odd kernel size"

    # Softmax across the K * K kernel entries of every pixel.
    local_kernels = mx.sym.SoftmaxActivation(local_kernels, mode="channel")

    # Fixed one-hot filter bank of shape (K*K, 1, K, K): filter j has a single
    # 1 at flat position j, so convolving with it copies the j-th neighbour of
    # each pixel into output channel j.
    filter_localexpand = mx.sym.one_hot(indices=mx.sym.arange(K * K), depth=K * K)
    filter_localexpand = mx.sym.reshape(mx.sym.transpose(filter_localexpand, axes=(1, 0)),
                                        shape=(K * K, 1, K, K))
    data_sliced = mx.sym.SliceChannel(data, num_outputs=batch_size, axis=0, squeeze_axis=True)
    # Per sample: (C, H, W) -> (C, 1, H, W) -> conv -> (C, K*K, H, W) -> (1, C, K*K, H, W).
    vec = [
        mx.sym.expand_dims(
            mx.sym.Convolution(data=mx.sym.expand_dims(data_sliced[i], axis=1),
                               weight=filter_localexpand,
                               num_filter=K * K,
                               kernel=(K, K),
                               pad=(K // 2, K // 2), no_bias=True),
            axis=0)
        for i in range(batch_size)
    ]
    input_localexpanded = mx.sym.Concat(*vec, num_args=len(vec), dim=0)
    # Broadcast the kernels over C as (batch_size, 1, K*K, H, W), weight every
    # neighbour, and sum over the K * K axis.
    output = mx.sym.broadcast_mul(input_localexpanded, mx.sym.expand_dims(local_kernels, axis=1))
    output = mx.sym.sum(output, axis=2)
    return output
| |
|
| |
|
| |
|
if __name__ == '__main__':
    # Smoke test: chain ten DFN layers and run a single forward pass on the GPU.
    K = 11
    C = 3
    H = 60
    W = 60
    batch_size = 32
    data_shape = (batch_size, C, H, W)
    kernel_shape = (batch_size, K * K, H, W)

    data = mx.sym.Variable('data')
    local_kernels = mx.sym.Variable('local_kernels')

    # Random inputs, drawn in the same order as before (kernels, then data).
    local_kernels_npy = np.random.normal(size=kernel_shape)
    data_npy = np.random.normal(size=data_shape)

    # Every layer reuses the same `local_kernels` symbol.
    out = data
    for _ in range(10):
        out = DFN(data=out, local_kernels=local_kernels, K=K, batch_size=batch_size)

    exe = out.simple_bind(ctx=mx.gpu(), data=data_shape, local_kernels=kernel_shape)
    exe.forward(data=data_npy, local_kernels=local_kernels_npy)
    result = exe.outputs[0].asnumpy()
    print(result.shape)
| |
|