import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster

__all__ = [
    "SynchronizedBatchNorm1d",
    "SynchronizedBatchNorm2d",
    "SynchronizedBatchNorm3d",
]


def _sum_ft(tensor):
    """Sum over the first and last dimension."""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple("_ChildMessage", ["sum", "ssum", "sum_size"])
_MasterMessage = collections.namedtuple("_MasterMessage", ["sum", "inv_std"])
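# Messages exchanged during a synchronized forward pass: every replica sends a
# _ChildMessage (its local sum, squared sum, and element count) to the master
# copy, which reduces them, updates the running statistics, and broadcasts the
# resulting mean and inverse standard deviation back as a _MasterMessage.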


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(
            num_features, eps=eps, momentum=momentum, affine=affine
        )

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input, gain=None, bias=None):
        # If not running data-parallel training, fall back to PyTorch's built-in
        # batch norm (this also covers evaluation mode).
        if not (self._is_parallel and self.training):
            out = F.batch_norm(
                input,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                self.training,
                self.momentum,
                self.eps,
            )
            # Optionally add externally supplied gain/bias terms.
            if gain is not None:
                out = out + gain
            if bias is not None:
                out = out + bias
            return out

        # Flatten the input to (B, C, -1) so statistics are computed per channel.
        input_shape = input.size()
        input = input.view(input.size(0), input.size(1), -1)

        # Compute this device's per-channel sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics: the master copy gathers all child
        # messages and every replica receives the synchronized mean and inverse std.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(
                _ChildMessage(input_sum, input_ssum, sum_size)
            )
        else:
            mean, inv_std = self._slave_pipe.run_slave(
                _ChildMessage(input_sum, input_ssum, sum_size)
            )

        # Normalize the input. An externally supplied gain/bias takes precedence
        # over the module's own affine parameters.
        if gain is not None:
            output = (input - _unsqueeze_ft(mean)) * (
                _unsqueeze_ft(inv_std) * gain.squeeze(-1)
            ) + bias.squeeze(-1)
        elif self.affine:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(
                inv_std * self.weight
            ) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Restore the original shape.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # The replica with id 0 owns the SyncMaster; every other replica registers
        # itself as a slave and receives a pipe for exchanging messages.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast them."""

        # Sort the intermediates by the device id of their tensors so the
        # reduction always runs in a consistent device order.
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        # Flatten the (sum, ssum) pairs from every device into a single list.
        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        # Reduce onto the first device and compute the global statistics.
        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        # Broadcast (mean, inv_std) back to every device.
        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        # Pair each copy id with its slice of the broadcast result.
        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2 : i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse standard deviation from the sum and
        square-sum. This method also maintains the moving averages on the master
        device."""
        assert (
            size > 1
        ), "BatchNorm computes unbiased standard-deviation, which requires size > 1."
        mean = sum_ / size
        # sum((x - mean)^2) == sum(x^2) - n * mean^2, and sum_ * mean == n * mean^2.
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        # Update the running estimates (kept on the master device) with the
        # unbiased variance, matching PyTorch's convention.
        self.running_mean = (
            1 - self.momentum
        ) * self.running_mean + self.momentum * mean.data
        self.running_var = (
            1 - self.momentum
        ) * self.running_var + self.momentum * unbias_var.data
        return mean, torch.rsqrt(bias_var + self.eps)


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen
    as a mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device, which accelerates the computation and
    is easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed across multiple devices.

    Note that, in the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it is common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(input.dim())
            )
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 4d input that is seen as a
    mini-batch of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device, which accelerates the computation and
    is easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed across multiple devices.

    Note that, in the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it is common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 5d input that is seen as a
    mini-batch of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device, which accelerates the computation and
    is easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed across multiple devices.

    Note that, in the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it is common terminology to call this Volumetric
    BatchNorm or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
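

# A minimal single-process sanity check, added as a sketch (not part of the
# original module): with a single device, or on CPU, forward() falls back to
# F.batch_norm, so the layers can be exercised without any parallel setup.
# Multi-GPU synchronization additionally requires a data-parallel wrapper that
# calls __data_parallel_replicate__ on each replica, typically provided by a
# companion replicate module of this package.
if __name__ == "__main__":
    m = SynchronizedBatchNorm2d(8)
    x = torch.randn(4, 8, 16, 16)
    y = m(x)
    assert y.shape == x.shape
    print("single-device forward OK:", tuple(y.shape))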