| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | import collections
|
| |
|
| | import torch
|
| | import torch.nn.functional as F
|
| |
|
| | from torch.nn.modules.batchnorm import _BatchNorm
|
| | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
| |
|
| | from .comm import SyncMaster
|
| |
|
| | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
|
| |
|
| |
|
def _sum_ft(tensor):
    """Sum over the first and last dimension.

    The first dimension is reduced first, then the last dimension of the
    intermediate result.
    """
    first_reduced = torch.sum(tensor, dim=0)
    return torch.sum(first_reduced, dim=-1)
|
| |
|
| |
|
def _unsqueeze_ft(tensor):
    """Add a new size-1 dimension at the front and at the tail."""
    with_front = torch.unsqueeze(tensor, 0)
    return torch.unsqueeze(with_front, -1)
|
| |
|
| |
|
# Message a replica sends to the master: per-channel sum, per-channel
# square-sum, and the number of elements that went into each of them.
_ChildMessage = collections.namedtuple('_ChildMessage', 'sum ssum sum_size')

# Message the master sends back to each replica. NOTE: the field is named
# ``sum`` but the master fills it with the reduced mean.
_MasterMessage = collections.namedtuple('_MasterMessage', 'sum inv_std')
|
| |
|
| |
|
class _SynchronizedBatchNorm(_BatchNorm):
    """Batch normalization whose training statistics are reduced across replicas.

    While the module is replicated (``_is_parallel``) and in training mode,
    every replica sends its per-channel sum and square-sum to the master
    replica, which reduces them, computes the mean and inverse standard
    deviation, and broadcasts the result back. In every other situation the
    module falls back to ``F.batch_norm``.

    ``forward`` additionally accepts an optional multiplicative ``gain`` and
    additive ``bias`` that are applied to the normalized activations.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        # The master replica reduces the statistics through this object.
        self._sync_master = SyncMaster(self._data_parallel_master)

        # Replication state; filled in by ``__data_parallel_replicate__``.
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input, gain=None, bias=None):
        """Normalize ``input`` and optionally apply an external gain and bias.

        Args:
            input: tensor whose first dimension is the batch and whose second
                dimension is the channel; any trailing dimensions are
                flattened for the statistics and restored on return.
            gain: optional scale multiplied onto the normalized output.
            bias: optional shift added to the normalized output.

        Returns:
            A tensor with the same shape as ``input``.
        """
        if not (self._is_parallel and self.training):
            # Not replicated, or evaluating: use PyTorch's implementation.
            out = F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)
            if gain is not None:
                # The gain is a scale, exactly as in the synchronized branch
                # below; adding it here would make the two branches disagree.
                out = out * gain
            if bias is not None:
                out = out + bias
            return out

        # Flatten everything after the channel dimension: (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), input.size(1), -1)

        # Local sum and square-sum per channel, plus the element count.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)
        message = _ChildMessage(input_sum, input_ssum, sum_size)

        # Reduce-and-broadcast the statistics across replicas.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(message)
        else:
            mean, inv_std = self._slave_pipe.run_slave(message)

        centered = input - _unsqueeze_ft(mean)
        if gain is not None:
            # gain/bias are squeezed on the last dimension before use.
            # NOTE(review): presumably they arrive as (N or 1, C, 1, 1) so
            # that the squeezed form broadcasts against (B, C, L) - confirm
            # against the callers.
            output = centered * (_unsqueeze_ft(inv_std) * gain.squeeze(-1))
            if bias is not None:
                output = output + bias.squeeze(-1)
        elif self.affine:
            # Fuse the two multiplications into one.
            output = centered * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = centered * _unsqueeze_ft(inv_std)
        # NOTE(review): a ``bias`` passed without ``gain`` is ignored in this
        # branch (unlike the fallback branch); left as in the original.

        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Mark this module as a replica and wire it to the master.

        Replica 0 publishes its ``SyncMaster`` on ``ctx``; every other replica
        registers itself with ``ctx.sync_master`` and keeps the returned pipe.
        Presumably invoked by the data-parallel replication callback with
        replica 0 first - verify against the caller.
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it.

        Args:
            intermediates: sequence of ``(identifier, _ChildMessage)`` pairs,
                one per replica.

        Returns:
            A list of ``(identifier, _MasterMessage)`` pairs, ordered by the
            device index of each replica's tensors.
        """
        # Order by device so reduced inputs and broadcast outputs line up.
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        # Flatten to [sum_0, ssum_0, sum_1, ssum_1, ...].
        to_reduce = [tensor for i in intermediates for tensor in i[1][:2]]
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum(i[1].sum_size for i in intermediates)
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        # Flat tuple holding two tensors (mean, inv_std) per target device.
        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse standard-deviation from sum and square-sum.

        This method also updates the running mean/variance of this module.

        Args:
            sum_: per-channel sum of the inputs.
            ssum: per-channel sum of the squared inputs.
            size: number of elements accumulated per channel; must exceed 1.

        Returns:
            ``(mean, inv_std)`` where ``inv_std`` is
            ``rsqrt(biased_variance + eps)``.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        # ssum - sum * mean can come out slightly negative through floating
        # point cancellation; clamp so rsqrt below never receives a negative.
        sumvar = (ssum - sum_ * mean).clamp(min=0)
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        return mean, torch.rsqrt(bias_var + self.eps)
|
| |
|
| |
|
| |
|
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device
    using the statistics only on that device, which accelerates the
    computation and is also easy to implement, but the statistics might be
    inaccurate. Instead, in this synchronized version, the statistics are
    computed over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimate is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Validate that ``input`` is 2D or 3D and has ``num_features`` channels.

        Raises:
            ValueError: if the rank or the size of dimension 1 is unexpected.
        """
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        # The base-class hook raises NotImplementedError in current PyTorch
        # releases, so the channel-count check is performed here instead of
        # being delegated to it.
        if input.size(1) != self.num_features:
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))
|
| |
|
| |
|
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 4d input that is seen as a
    mini-batch of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device
    using the statistics only on that device, which accelerates the
    computation and is also easy to implement, but the statistics might be
    inaccurate. Instead, in this synchronized version, the statistics are
    computed over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimate is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Validate that ``input`` is 4D and has ``num_features`` channels.

        Raises:
            ValueError: if the rank or the size of dimension 1 is unexpected.
        """
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        # The base-class hook raises NotImplementedError in current PyTorch
        # releases, so the channel-count check is performed here instead of
        # being delegated to it.
        if input.size(1) != self.num_features:
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))
|
| |
|
| |
|
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 5d input that is seen as a
    mini-batch of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device
    using the statistics only on that device, which accelerates the
    computation and is also easy to implement, but the statistics might be
    inaccurate. Instead, in this synchronized version, the statistics are
    computed over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimate is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric
    BatchNorm or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Validate that ``input`` is 5D and has ``num_features`` channels.

        Raises:
            ValueError: if the rank or the size of dimension 1 is unexpected.
        """
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        # The base-class hook raises NotImplementedError in current PyTorch
        # releases, so the channel-count check is performed here instead of
        # being delegated to it.
        if input.size(1) != self.num_features:
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))