| | import torch |
| | import re |
| | import cupy |
| |
|
| | from modules.cupy_module.cupy_utils import cupy_launch |
| |
|
| | |
| |
|
# CUDA source for the forward splatting pass. `cupy_kernel` expands the
# SIZE_* / VALUE_* / OFFSET_* macros before compilation.
#
# Grid-stride loop over every element of `output` (N x C x H x W). The flat
# index is decoded into (intN, intC, intY, intX). That source pixel is moved
# by its flow vector (flow channel 0 = x, channel 1 = y), and its value is
# distributed onto the four integer neighbours of the target position with
# bilinear weights. atomicAdd is needed because several source pixels can
# land on the same target pixel, so `output` must be zero-initialised.
#
# Flow vectors that are NaN/inf are skipped. Converting a NaN float to int is
# undefined behaviour, and the NaN-weighted contribution could otherwise be
# added onto in-bounds pixels and poison them. The input value is read once
# per element instead of once per neighbour.
kernel_Softsplat_updateOutput = '''
    extern "C" __global__ void kernel_Softsplat_updateOutput(
        const int n,
        const float* input,
        const float* flow,
        float* output
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output);
        const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output);
        const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output);
        const int intX = ( intIndex ) % SIZE_3(output);

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        if (!isfinite(fltOutputX) || !isfinite(fltOutputY)) {
            continue;
        }

        const float fltInput = VALUE_4(input, intN, intC, intY, intX);

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
        float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
        float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
        float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));

        if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], fltInput * fltNorthwest);
        }

        if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], fltInput * fltNortheast);
        }

        if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], fltInput * fltSouthwest);
        }

        if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], fltInput * fltSoutheast);
        }
    } }
'''
| |
|
# CUDA source for the gradient w.r.t. the splatted `input`. `cupy_kernel`
# expands the SIZE_* / VALUE_* macros before compilation.
#
# Grid-stride loop over every element of `gradInput` (N x C x H x W). Each
# element recomputes the same displaced position and bilinear weights as the
# forward kernel. It gathers the matching `gradOutput` values from the four
# in-bounds neighbours, weighted accordingly, and writes the sum to
# `gradInput[intIndex]`. That write assumes `gradInput` is contiguous; backward
# allocates it with new_zeros.
#
# Non-finite flow vectors are skipped, because casting NaN to int is
# undefined. The skipped entry keeps the zero it was allocated with.
kernel_Softsplat_updateGradInput = '''
    extern "C" __global__ void kernel_Softsplat_updateGradInput(
        const int n,
        const float* input,
        const float* flow,
        const float* gradOutput,
        float* gradInput,
        float* gradFlow
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput);
        const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput);
        const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput);
        const int intX = ( intIndex ) % SIZE_3(gradInput);

        float fltGradInput = 0.0;

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        if (!isfinite(fltOutputX) || !isfinite(fltOutputY)) {
            continue;
        }

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
        float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
        float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
        float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));

        if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
        }

        if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
        }

        if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
        }

        if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
        }

        gradInput[intIndex] = fltGradInput;
    } }
'''
| |
|
# CUDA source for the gradient w.r.t. `flow`. `cupy_kernel` expands the
# SIZE_* / VALUE_* macros before compilation.
#
# Grid-stride loop over every element of `gradFlow` (N x 2 x H x W). For flow
# channel intC (0 = x displacement, 1 = y displacement) it forms the partial
# derivatives of the four bilinear weights w.r.t. that coordinate. It then
# sums input[c] * gradOutput[c, neighbour] * dWeight over every channel c and
# every in-bounds neighbour. The neighbour bounds checks do not depend on the
# channel, so they are evaluated once, before the channel loop.
#
# Non-finite flow vectors are skipped, because casting NaN to int is
# undefined. The skipped entry keeps the zero it was allocated with (backward
# uses new_zeros).
kernel_Softsplat_updateGradFlow = '''
    extern "C" __global__ void kernel_Softsplat_updateGradFlow(
        const int n,
        const float* input,
        const float* flow,
        const float* gradOutput,
        float* gradInput,
        float* gradFlow
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        float fltGradFlow = 0.0;

        const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow);
        const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow);
        const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow);
        const int intX = ( intIndex ) % SIZE_3(gradFlow);

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        if (!isfinite(fltOutputX) || !isfinite(fltOutputY)) {
            continue;
        }

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = 0.0;
        float fltNortheast = 0.0;
        float fltSouthwest = 0.0;
        float fltSoutheast = 0.0;

        if (intC == 0) {
            fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY);
            fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY);
            fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY));
            fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY));

        } else if (intC == 1) {
            fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (-1.0));
            fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0));
            fltSouthwest = ((float) (intNortheastX) - fltOutputX) * ((float) (+1.0));
            fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0));

        }

        const bool bolNorthwest = (intNorthwestX >= 0) && (intNorthwestX < SIZE_3(gradOutput)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(gradOutput));
        const bool bolNortheast = (intNortheastX >= 0) && (intNortheastX < SIZE_3(gradOutput)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(gradOutput));
        const bool bolSouthwest = (intSouthwestX >= 0) && (intSouthwestX < SIZE_3(gradOutput)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(gradOutput));
        const bool bolSoutheast = (intSoutheastX >= 0) && (intSoutheastX < SIZE_3(gradOutput)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(gradOutput));

        for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) {
            float fltInput = VALUE_4(input, intN, intChannel, intY, intX);

            if (bolNorthwest) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest;
            }

            if (bolNortheast) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast;
            }

            if (bolSouthwest) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest;
            }

            if (bolSoutheast) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast;
            }
        }

        gradFlow[intIndex] = fltGradFlow;
    } }
'''
| |
|
def cupy_kernel(strFunction, objVariables):
    """Expand the tensor macros of a CUDA kernel template into plain CUDA.

    The template is the module-level string whose name is `strFunction`
    (e.g. ``'kernel_Softsplat_updateOutput'``). Three macro families are
    rewritten using the sizes and strides of the tensors in `objVariables`:

    * ``SIZE_k(t)``             -> ``t.size()[k]`` as an integer literal
    * ``OFFSET_k(t, i_0, ...)`` -> ``(((i_0)*s_0)+...)``, the flat element offset
    * ``VALUE_k(t, i_0, ...)``  -> ``t[((i_0)*s_0)+...]``, the element itself

    Here ``s_j`` is ``t.stride()[j]``. Macro arguments are split on commas and
    end at the first ``)``, so an index expression cannot contain parentheses
    directly. Use ``{...}`` instead; it becomes ``(...)`` after splitting,
    e.g. ``VALUE_4(t, n, c, {y + 1}, x)``. Whitespace around the tensor name is
    ignored.

    Args:
        strFunction: Name of the module-level template string.
        objVariables: Mapping from the tensor names used in the macros to
            tensor-like objects exposing ``size()`` and ``stride()``.

    Returns:
        The expanded kernel source.

    Raises:
        KeyError: if the template or a referenced tensor name is unknown.
    """
    strKernel = globals()[strFunction]

    def _expand(strPattern, fnReplace, strText):
        # Substitute repeatedly until nothing matches, which mirrors a
        # search-and-replace-until-done loop.
        objPattern = re.compile(strPattern)
        while True:
            strText, intCount = objPattern.subn(fnReplace, strText)
            if intCount == 0:
                return strText

    def _strided(objMatch):
        # Shared by OFFSET_ and VALUE_: returns (tensor name, '((i_0)*s_0)+...').
        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')
        strTensor = strArgs[0].strip()
        intStrides = objVariables[strTensor].stride()
        strIndex = '+'.join(
            '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')'
            for intArg in range(intArgs)
        )
        return strTensor, strIndex

    def _size(objMatch):
        intSizes = objVariables[objMatch.group(4).strip()].size()
        return str(intSizes[int(objMatch.group(2))])

    def _offset(objMatch):
        return '(' + _strided(objMatch)[1] + ')'

    def _value(objMatch):
        strTensor, strIndex = _strided(objMatch)
        return strTensor + '[' + strIndex + ']'

    # The patterns are raw strings because '\(' in a normal literal is an
    # invalid escape sequence (a SyntaxWarning since Python 3.12). SIZE_ is
    # expanded first so that no SIZE_(...) is left inside an OFFSET_/VALUE_
    # argument list, whose pattern stops at the first ')'.
    strKernel = _expand(r'(SIZE_)([0-4])(\()([^\)]*)(\))', _size, strKernel)
    strKernel = _expand(r'(OFFSET_)([0-4])(\()([^\)]+)(\))', _offset, strKernel)
    strKernel = _expand(r'(VALUE_)([0-4])(\()([^\)]+)(\))', _value, strKernel)

    return strKernel
| | |
| |
|
| | class _FunctionSoftsplat(torch.autograd.Function): |
| | @staticmethod |
| | def forward(self, input, flow): |
| | intSamples = input.shape[0] |
| | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] |
| | intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] |
| |
|
| | assert(intFlowDepth == 2) |
| | assert(intInputHeight == intFlowHeight) |
| | assert(intInputWidth == intFlowWidth) |
| |
|
| | input = input.contiguous(); assert(input.is_cuda == True) |
| | flow = flow.contiguous(); assert(flow.is_cuda == True) |
| |
|
| | output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) |
| |
|
| | if input.is_cuda == True: |
| | n = output.nelement() |
| | cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { |
| | 'input': input, |
| | 'flow': flow, |
| | 'output': output |
| | }))( |
| | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), |
| | block=tuple([ 512, 1, 1 ]), |
| | args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), output.data_ptr() ] |
| | ) |
| |
|
| | elif input.is_cuda == False: |
| | raise NotImplementedError() |
| |
|
| | |
| |
|
| | self.save_for_backward(input, flow) |
| |
|
| | return output |
| | |
| |
|
| | @staticmethod |
| | def backward(self, gradOutput): |
| | input, flow = self.saved_tensors |
| |
|
| | intSamples = input.shape[0] |
| | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] |
| | intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] |
| |
|
| | assert(intFlowDepth == 2) |
| | assert(intInputHeight == intFlowHeight) |
| | assert(intInputWidth == intFlowWidth) |
| |
|
| | gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True) |
| |
|
| | gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) if self.needs_input_grad[0] == True else None |
| | gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ]) if self.needs_input_grad[1] == True else None |
| |
|
| | if input.is_cuda == True: |
| | if gradInput is not None: |
| | n = gradInput.nelement() |
| | cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { |
| | 'input': input, |
| | 'flow': flow, |
| | 'gradOutput': gradOutput, |
| | 'gradInput': gradInput, |
| | 'gradFlow': gradFlow |
| | }))( |
| | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), |
| | block=tuple([ 512, 1, 1 ]), |
| | args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] |
| | ) |
| | |
| |
|
| | if gradFlow is not None: |
| | n = gradFlow.nelement() |
| | cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { |
| | 'input': input, |
| | 'flow': flow, |
| | 'gradOutput': gradOutput, |
| | 'gradInput': gradInput, |
| | 'gradFlow': gradFlow |
| | }))( |
| | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), |
| | block=tuple([ 512, 1, 1 ]), |
| | args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] |
| | ) |
| | |
| |
|
| | elif input.is_cuda == False: |
| | raise NotImplementedError() |
| |
|
| | |
| |
|
| | return gradInput, gradFlow |
| | |
| | |
| |
|
def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType):
    """Forward-warp `tenInput` along `tenFlow` using one of four splatting modes.

    Args:
        tenInput: N x C x H x W tensor to splat.
        tenFlow: N x 2 x H x W flow (x displacement, then y displacement).
        tenMetric: N x 1 x H x W importance metric. Required for 'linear' and
            'softmax'; otherwise unused and may be None.
        strType: One of
            'summation' - plain sum of the splatted contributions;
            'average'   - contributions divided by the splatted weight sum;
            'linear'    - contributions weighted by `tenMetric`, then normalized;
            'softmax'   - contributions weighted by exp(`tenMetric`), then normalized.

    Returns:
        N x C x H x W tensor.

    Raises:
        AssertionError: if `strType` is unknown or `tenMetric` has more than
            one channel.
        TypeError: if `tenMetric` is None for 'linear' or 'softmax'.
    """
    assert(tenMetric is None or tenMetric.shape[1] == 1)
    assert(strType in ['summation', 'average', 'linear', 'softmax'])

    if strType in ('linear', 'softmax') and tenMetric is None:
        raise TypeError(f"strType '{strType}' requires a metric tensor, got tenMetric=None")

    # The normalizing modes splat one extra channel holding each pixel's
    # weight. After splatting, that channel holds the accumulated weight used
    # as the divisor.
    if strType == 'average':
        tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1)

    elif strType == 'linear':
        tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1)

    elif strType == 'softmax':
        tenWeight = tenMetric.exp()  # computed once; used for both the scaled input and the weight channel
        tenInput = torch.cat([ tenInput * tenWeight, tenWeight ], 1)

    tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow)

    if strType != 'summation':
        tenNormalize = tenOutput[:, -1:, :, :]

        # A zero accumulated weight (e.g. a pixel no source pixel reached) is
        # replaced by 1 to avoid dividing by zero. masked_fill is
        # out-of-place, so the autograd output `tenOutput` is not mutated.
        tenNormalize = tenNormalize.masked_fill(tenNormalize == 0.0, 1.0)

        tenOutput = tenOutput[:, :-1, :, :] / tenNormalize

    return tenOutput
| | |
| |
|
class ModuleSoftsplat(torch.nn.Module):
    """``torch.nn.Module`` wrapper that fixes the splatting mode of `FunctionSoftsplat`.

    Args:
        strType: Mode passed unchanged to ``FunctionSoftsplat`` on every call.
            It is not validated here; ``FunctionSoftsplat`` checks it.
    """

    def __init__(self, strType):
        super(ModuleSoftsplat, self).__init__()
        self.strType = strType

    def forward(self, tenInput, tenFlow, tenMetric):
        """Splat `tenInput` along `tenFlow` using the stored mode."""
        strMode = self.strType
        return FunctionSoftsplat(tenInput, tenFlow, tenMetric, strMode)
| | |
| | |