| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| from torch.autograd import Function |
|
|
| from ..utils import ext_loader |
|
|
# Load the compiled '_ext' extension, requesting the two TIN shift ops that
# TINShiftFunction calls: the forward kernel and the backward kernel.
ext_module = ext_loader.load_ext('_ext',
                                 ['tin_shift_forward', 'tin_shift_backward'])
|
|
|
|
class TINShiftFunction(Function):
    """Autograd function wrapping the compiled temporal interlace shift ops.

    ``forward`` validates the input shapes, then lets
    ``ext_module.tin_shift_forward`` fill a zero-initialised output.
    ``backward`` propagates the incoming gradient through
    ``ext_module.tin_shift_backward``. ``shift`` always receives an all-zero
    gradient.
    """

    @staticmethod
    def forward(ctx, input, shift):
        """Apply the temporal interlace shift.

        Args:
            ctx: Autograd context; ``shift`` is saved on it for ``backward``.
            input (Tensor): Feature map whose dim 0 is the batch (N) and whose
                dim 2 is the channel count ``C``.
            shift (Tensor): Shift tensor whose dim 0 is the batch (N). Its
                dim 1 must evenly divide ``C``.

        Returns:
            Tensor: Shifted feature map with the same shape, dtype and device
            as ``input``.

        Raises:
            ValueError: If ``input`` and ``shift`` have different batch sizes,
                or if ``C`` is not a positive multiple of ``shift.size(1)``.
        """
        # Both tensors are indexed by the same batch index inside the kernel,
        # so a mismatch must be rejected here rather than reaching the kernel.
        if input.size(0) != shift.size(0):
            raise ValueError(
                'The first dim (batch) of `input` and `shift` should be '
                f'same, but got {input.size(0)} and {shift.size(0)}.')
        C = input.size(2)
        num_segments = shift.size(1)
        # Test num_segments == 0 first so an empty shift raises ValueError
        # instead of ZeroDivisionError from the floor division.
        if (num_segments == 0 or C // num_segments <= 0
                or C % num_segments != 0):
            raise ValueError('C should be a multiple of num_segments, '
                             f'but got C={C} and num_segments={num_segments}.')

        ctx.save_for_backward(shift)

        # The extension op writes its result into this pre-allocated buffer,
        # which is then returned as-is.
        out = torch.zeros_like(input)
        ext_module.tin_shift_forward(input, shift, out)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Back-propagate through the shift.

        Args:
            ctx: Autograd context holding the ``shift`` saved in ``forward``.
            grad_output (Tensor): Gradient w.r.t. the output of ``forward``.

        Returns:
            tuple[Tensor, Tensor]: Gradient w.r.t. ``input`` (filled by the
            extension op), and an all-zero gradient w.r.t. ``shift``.
        """
        shift, = ctx.saved_tensors
        # new_zeros keeps dtype/device and yields a fresh contiguous buffer,
        # matching the previous ``new(*size).zero_()`` idiom.
        data_grad_input = grad_output.new_zeros(grad_output.size())
        shift_grad_input = shift.new_zeros(shift.size())
        ext_module.tin_shift_backward(grad_output, shift, data_grad_input)

        return data_grad_input, shift_grad_input
|
|
|
|
# Functional entry point: ``tin_shift(input, shift)`` runs TINShiftFunction
# through autograd so gradients flow back to ``input``.
tin_shift = TINShiftFunction.apply
|
|
|
|
class TINShift(nn.Module):
    """Temporal Interlace Shift.

    A differentiable, temporal-wise frame-shifting operation proposed in
    "Temporal Interlacing Network" (https://arxiv.org/abs/2001.06499).
    The code is adapted from
    https://github.com/mit-han-lab/temporal-shift-module.

    The module defines no parameters or buffers; ``forward`` delegates all
    work to ``tin_shift``.
    """

    def forward(self, input, shift):
        """Run the temporal interlace shift on ``input``.

        Args:
            input (Tensor): Feature map of shape [N, num_segments, C, H * W].
            shift (Tensor): Shift tensor of shape [N, num_segments].

        Returns:
            Tensor: The feature map after temporal interlace shift.
        """
        shifted = tin_shift(input, shift)
        return shifted
|
|