| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.init as init |
|
|
# Names exported by `from <module> import *`; must match the class defined below.
__all__ = ["BatchNorm2dReimpl"]
|
|
|
|
class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Per-channel statistics are computed in one pass from the sum and the sum of
    squares (``E[x^2] - E[x]^2`` form) rather than from a two-pass variance.
    The forward pass always normalizes with the current batch statistics and
    always updates the running statistics; ``self.training`` is not consulted.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        """
        Args:
            num_features: number of channels ``C`` of the ``(N, C, H, W)``
                input that ``forward`` expects.
            eps: value added to the batch variance before the square root.
            momentum: weight given to the current batch when blending it into
                ``running_mean`` / ``running_var``.
        """
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Allocated uninitialised here; filled in by reset_parameters().
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running mean to 0 and running variance to 1, in place."""
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        """Reset running stats; draw ``weight`` via ``init.uniform_`` and zero ``bias``."""
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        """
        Normalize ``input_`` with its own per-channel batch statistics.

        Args:
            input_: 4-D tensor of shape ``(N, C, H, W)``.

        Returns:
            Tensor of the same shape:
            ``(x - mean) / sqrt(var + eps) * weight + bias``, where ``mean`` and
            ``var`` are the biased per-channel statistics of this batch.

        Raises:
            ValueError: if ``input_`` is not 4-D, or holds fewer than two values
                per channel (the unbiased variance would divide by ``<= 0``).

        Side effects:
            Rebinds ``running_mean`` and ``running_var`` (the latter blended
            with the unbiased batch variance) using ``momentum``.
        """
        if input_.dim() != 4:
            raise ValueError(f"Expected 4D input (got {input_.dim()}D input)")
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        if numel < 2:
            # `numel - 1` below would be <= 0 and silently turn running_var into
            # NaN/inf; mirrors the check nn.BatchNorm2d performs in training mode.
            raise ValueError(
                "Expected more than 1 value per channel, "
                f"got input size {tuple(input_.size())}"
            )

        # (N, C, H, W) -> (C, N*H*W): one row of samples per channel.
        flat = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = flat.sum(1)
        sum_of_square = flat.pow(2).sum(1)
        mean = sum_ / numel
        # sum((x - mean)^2) rewritten in terms of the one-pass sums.
        sumvar = sum_of_square - sum_ * mean

        # Buffers are rebound to new tensors (nn.Module.__setattr__ keeps them
        # registered as buffers); detach() keeps running stats out of autograd.
        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        # Normalization itself uses the biased (population) variance.
        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (flat - mean.unsqueeze(1))
            * inv_std.unsqueeze(1)
            * self.weight.unsqueeze(1)
            + self.bias.unsqueeze(1)
        )

        # (C, N*H*W) -> (C, N, H, W) -> (N, C, H, W)
        return (
            output.view(channels, batchsize, height, width)
            .permute(1, 0, 2, 3)
            .contiguous()
        )
|
|