| import contextlib
|
| import warnings
|
|
|
| import torch
|
| from torch import autograd
|
| from torch.nn import functional as F
|
|
|
| enabled = True
|
| weight_gradients_disabled = False
|
|
|
|
|
| @contextlib.contextmanager
|
| def no_weight_gradients():
|
| global weight_gradients_disabled
|
|
|
| old = weight_gradients_disabled
|
| weight_gradients_disabled = True
|
| yield
|
| weight_gradients_disabled = old
|
|
|
|
|
| def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
| if could_use_op(input):
|
| return conv2d_gradfix(
|
| transpose=False,
|
| weight_shape=weight.shape,
|
| stride=stride,
|
| padding=padding,
|
| output_padding=0,
|
| dilation=dilation,
|
| groups=groups,
|
| ).apply(input, weight, bias)
|
|
|
| return F.conv2d(
|
| input=input,
|
| weight=weight,
|
| bias=bias,
|
| stride=stride,
|
| padding=padding,
|
| dilation=dilation,
|
| groups=groups,
|
| )
|
|
|
|
|
| def conv_transpose2d(
|
| input,
|
| weight,
|
| bias=None,
|
| stride=1,
|
| padding=0,
|
| output_padding=0,
|
| groups=1,
|
| dilation=1,
|
| ):
|
| if could_use_op(input):
|
| return conv2d_gradfix(
|
| transpose=True,
|
| weight_shape=weight.shape,
|
| stride=stride,
|
| padding=padding,
|
| output_padding=output_padding,
|
| groups=groups,
|
| dilation=dilation,
|
| ).apply(input, weight, bias)
|
|
|
| return F.conv_transpose2d(
|
| input=input,
|
| weight=weight,
|
| bias=bias,
|
| stride=stride,
|
| padding=padding,
|
| output_padding=output_padding,
|
| dilation=dilation,
|
| groups=groups,
|
| )
|
|
|
|
|
| def could_use_op(input):
|
| if (not enabled) or (not torch.backends.cudnn.enabled):
|
| return False
|
|
|
| if input.device.type != "cuda":
|
| return False
|
|
|
| if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
|
| return True
|
|
|
| warnings.warn(
|
| f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
|
| )
|
|
|
| return False
|
|
|
|
|
| def ensure_tuple(xs, ndim):
|
| xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
|
|
| return xs
|
|
|
|
|
| conv2d_gradfix_cache = dict()
|
|
|
|
|
| def conv2d_gradfix(
|
| transpose, weight_shape, stride, padding, output_padding, dilation, groups
|
| ):
|
| ndim = 2
|
| weight_shape = tuple(weight_shape)
|
| stride = ensure_tuple(stride, ndim)
|
| padding = ensure_tuple(padding, ndim)
|
| output_padding = ensure_tuple(output_padding, ndim)
|
| dilation = ensure_tuple(dilation, ndim)
|
|
|
| key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
| if key in conv2d_gradfix_cache:
|
| return conv2d_gradfix_cache[key]
|
|
|
| common_kwargs = dict(
|
| stride=stride, padding=padding, dilation=dilation, groups=groups
|
| )
|
|
|
| def calc_output_padding(input_shape, output_shape):
|
| if transpose:
|
| return [0, 0]
|
|
|
| return [
|
| input_shape[i + 2]
|
| - (output_shape[i + 2] - 1) * stride[i]
|
| - (1 - 2 * padding[i])
|
| - dilation[i] * (weight_shape[i + 2] - 1)
|
| for i in range(ndim)
|
| ]
|
|
|
| class Conv2d(autograd.Function):
|
| @staticmethod
|
| def forward(ctx, input, weight, bias):
|
| if not transpose:
|
| out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
|
|
| else:
|
| out = F.conv_transpose2d(
|
| input=input,
|
| weight=weight,
|
| bias=bias,
|
| output_padding=output_padding,
|
| **common_kwargs,
|
| )
|
|
|
| ctx.save_for_backward(input, weight)
|
|
|
| return out
|
|
|
| @staticmethod
|
| def backward(ctx, grad_output):
|
| input, weight = ctx.saved_tensors
|
| grad_input, grad_weight, grad_bias = None, None, None
|
|
|
| if ctx.needs_input_grad[0]:
|
| p = calc_output_padding(
|
| input_shape=input.shape, output_shape=grad_output.shape
|
| )
|
| grad_input = conv2d_gradfix(
|
| transpose=(not transpose),
|
| weight_shape=weight_shape,
|
| output_padding=p,
|
| **common_kwargs,
|
| ).apply(grad_output, weight, None)
|
|
|
| if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
| grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
|
|
| if ctx.needs_input_grad[2]:
|
| grad_bias = grad_output.sum((0, 2, 3))
|
|
|
| return grad_input, grad_weight, grad_bias
|
|
|
| class Conv2dGradWeight(autograd.Function):
|
| @staticmethod
|
| def forward(ctx, grad_output, input):
|
| op = torch._C._jit_get_operation(
|
| "aten::cudnn_convolution_backward_weight"
|
| if not transpose
|
| else "aten::cudnn_convolution_transpose_backward_weight"
|
| )
|
| flags = [
|
| torch.backends.cudnn.benchmark,
|
| torch.backends.cudnn.deterministic,
|
| torch.backends.cudnn.allow_tf32,
|
| ]
|
| grad_weight = op(
|
| weight_shape,
|
| grad_output,
|
| input,
|
| padding,
|
| stride,
|
| dilation,
|
| groups,
|
| *flags,
|
| )
|
| ctx.save_for_backward(grad_output, input)
|
|
|
| return grad_weight
|
|
|
| @staticmethod
|
| def backward(ctx, grad_grad_weight):
|
| grad_output, input = ctx.saved_tensors
|
| grad_grad_output, grad_grad_input = None, None
|
|
|
| if ctx.needs_input_grad[0]:
|
| grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
|
|
|
| if ctx.needs_input_grad[1]:
|
| p = calc_output_padding(
|
| input_shape=input.shape, output_shape=grad_output.shape
|
| )
|
| grad_grad_input = conv2d_gradfix(
|
| transpose=(not transpose),
|
| weight_shape=weight_shape,
|
| output_padding=p,
|
| **common_kwargs,
|
| ).apply(grad_output, grad_grad_weight, None)
|
|
|
| return grad_grad_output, grad_grad_input
|
|
|
| conv2d_gradfix_cache[key] = Conv2d
|
|
|
| return Conv2d
|
|
|