| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .registry import CONV_LAYERS |
|
|
|
|
def conv_ws_2d(input,
               weight,
               bias=None,
               stride=1,
               padding=0,
               dilation=1,
               groups=1,
               eps=1e-5):
    """2D convolution with Weight Standardization.

    Each output-channel filter of ``weight`` is standardized to zero mean
    and (approximately) unit std before being passed to a regular
    ``F.conv2d``.

    Args:
        input (Tensor): Input feature map of shape (N, C_in, H, W).
        weight (Tensor): Conv filters of shape (C_out, C_in // groups,
            kH, kW).
        bias (Tensor, optional): Optional bias of shape (C_out,).
        stride (int or tuple): Convolution stride. Default: 1.
        padding (int or tuple): Zero-padding. Default: 0.
        dilation (int or tuple): Kernel dilation. Default: 1.
        groups (int): Number of blocked connections. Default: 1.
        eps (float): Small constant added to the std to avoid division
            by zero. Default: 1e-5.

    Returns:
        Tensor: Convolution output.
    """
    out_channels = weight.size(0)
    # Flatten each filter so mean/std are taken over all of its elements.
    flat = weight.reshape(out_channels, -1)
    mu = flat.mean(dim=1).view(out_channels, 1, 1, 1)
    sigma = flat.std(dim=1).view(out_channels, 1, 1, 1)
    standardized = (weight - mu) / (sigma + eps)
    return F.conv2d(input, standardized, bias, stride, padding, dilation,
                    groups)
|
|
|
|
@CONV_LAYERS.register_module('ConvWS')
class ConvWS2d(nn.Conv2d):
    """Conv2d layer with Weight Standardization.

    Drop-in replacement for :class:`torch.nn.Conv2d` that standardizes its
    filter weights (zero mean, unit std per output channel) on every
    forward pass via :func:`conv_ws_2d`.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int or tuple): Size of the conv kernel.
        stride (int or tuple): Convolution stride. Default: 1.
        padding (int or tuple): Zero-padding. Default: 0.
        dilation (int or tuple): Kernel dilation. Default: 1.
        groups (int): Number of blocked connections. Default: 1.
        bias (bool): Whether to add a learnable bias. Default: True.
        eps (float): Small constant added to the weight std for numerical
            stability. Default: 1e-5.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 eps=1e-5):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # Epsilon used when dividing by the per-filter std in conv_ws_2d.
        self.eps = eps

    def forward(self, x):
        """Convolve ``x`` with the weight-standardized filters."""
        return conv_ws_2d(x, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups, self.eps)
|
|
|
|
@CONV_LAYERS.register_module(name='ConvAWS')
class ConvAWS2d(nn.Conv2d):
    """AWS (Adaptive Weight Standardization)

    This is a variant of Weight Standardization
    (https://arxiv.org/pdf/1903.10520.pdf)
    It is used in DetectoRS to avoid NaN
    (https://arxiv.org/pdf/2006.02334.pdf)

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: True
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # Per-output-channel affine applied after standardization in
        # _get_weight.  Registered as buffers (not Parameters), so they are
        # saved/loaded with the state dict but not updated by the optimizer.
        # Defaults are the identity transform (gamma=1, beta=0); when loading
        # a checkpoint that lacks them, _load_from_state_dict recovers them
        # from the pretrained weights instead.
        self.register_buffer('weight_gamma',
                             torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta',
                             torch.zeros(self.out_channels, 1, 1, 1))

    def _get_weight(self, weight):
        """Standardize ``weight`` per output channel, then re-apply the
        stored affine (gamma * w + beta)."""
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        # var + 1e-5 guards against zero std for constant filters.
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        weight = (weight - mean) / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x):
        """Run conv2d with the adaptively standardized weights."""
        weight = self._get_weight(self.weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Override default load function.

        AWS overrides the function _load_from_state_dict to recover
        weight_gamma and weight_beta if they are missing. If weight_gamma and
        weight_beta are found in the checkpoint, this function will return
        after super()._load_from_state_dict. Otherwise, it will compute the
        mean and std of the pretrained weights and store them in weight_beta
        and weight_gamma.
        """

        # Sentinel trick: fill gamma with -1 before loading.  If the
        # checkpoint contains weight_gamma, loading overwrites the sentinel
        # with (presumably positive) real values; if it does not, gamma
        # stays at -1 and we know we must reconstruct it below.
        self.weight_gamma.data.fill_(-1)
        # Collect missing keys locally so we can filter out the
        # gamma/beta entries we recover ourselves before reporting.
        local_missing_keys = []
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, local_missing_keys,
                                      unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            # Checkpoint provided gamma (sentinel overwritten): forward all
            # missing keys unchanged and finish.
            for k in local_missing_keys:
                missing_keys.append(k)
            return
        # gamma/beta absent from checkpoint: initialize them from the
        # statistics of the just-loaded pretrained weights, so that
        # _get_weight reproduces those weights exactly at load time
        # (standardize, then scale/shift back by std/mean).
        weight = self.weight.data
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        self.weight_beta.data.copy_(mean)
        self.weight_gamma.data.copy_(std)
        # Since gamma/beta were recovered, do not report them as missing.
        missing_gamma_beta = [
            k for k in local_missing_keys
            if k.endswith('weight_gamma') or k.endswith('weight_beta')
        ]
        for k in missing_gamma_beta:
            local_missing_keys.remove(k)
        for k in local_missing_keys:
            missing_keys.append(k)
|
|