|
|
|
|
|
|
| import torch
|
| from torch import nn
|
| from torch.nn import functional as F
|
| from torch.nn.modules.utils import _pair
|
|
|
|
|
class MultiScaleTridentConv(nn.Module):
    """Shared-weight 2-D convolution run over several branches.

    A single ``weight`` (and optional ``bias``) is shared by every branch.
    Each branch convolves its own input tensor with its own stride and
    padding (``self.strides[i]`` / ``self.paddings[i]``), while ``dilation``
    and ``groups`` are common to all branches. ``norm`` and ``activation``,
    when given, are applied to every branch output.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        strides=1,
        paddings=0,
        dilations=1,
        dilation=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        """Create the shared weight and the per-branch stride/padding tables.

        Args:
            in_channels: Input channels of every branch.
            out_channels: Output channels of every branch.
            kernel_size: Kernel size, an int or a pair.
            stride: Stored (as a pair) in ``self.stride``; not used by
                ``forward``, which uses the per-branch ``strides`` instead.
            strides: Per-branch strides. An ``int`` is broadcast to all
                branches; any other iterable supplies one entry per branch
                (each entry an int or a pair).
            paddings: Per-branch paddings, same convention as ``strides``.
            dilations: Stored (as pairs) in ``self.dilations``; not used by
                ``forward``. NOTE(review): confirm whether per-branch
                dilation was intended.
            dilation: Dilation passed to the convolution of every branch.
            groups: Number of convolution groups.
            num_branch: Number of branches.
            test_branch_idx: Eval-mode behaviour. ``-1`` runs every branch;
                any other value runs only that branch (Python indexing into
                ``strides``/``paddings``) on a single input tensor. Because
                ``-1`` is the "all branches" sentinel, select the last branch
                with ``num_branch - 1``.
            bias: If true, add a learnable bias initialized to zero.
            norm: Optional callable applied to each branch output.
            activation: Optional callable applied after ``norm``.

        Raises:
            AssertionError: If the number of strides or paddings differs from
                ``num_branch``.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation

        # A bare int means "same value for every branch".
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(p) for p in paddings]
        self.dilations = [_pair(d) for d in dilations]
        self.strides = [_pair(s) for s in strides]

        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1, (
            "num_branch, len(paddings) and len(strides) must all be equal"
        )

        # One weight shared by all branches; filled in just below.
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        """Convolve each active branch input with the shared weight.

        Args:
            inputs: Sequence of input tensors. It must hold ``num_branch``
                tensors in training mode or when ``test_branch_idx == -1``,
                and exactly one tensor otherwise.

        Returns:
            list: One output tensor per active branch, after the optional
            ``norm`` and ``activation``.

        Raises:
            AssertionError: If ``len(inputs)`` does not match the number of
                active branches.
            IndexError: In single-branch eval mode, if ``test_branch_idx`` is
                out of range for ``strides``/``paddings``.
        """
        run_all_branches = self.training or self.test_branch_idx == -1
        expected = self.num_branch if run_all_branches else 1
        assert len(inputs) == expected, (
            f"expected {expected} input(s), got {len(inputs)}"
        )

        if run_all_branches:
            outputs = [
                F.conv2d(x, self.weight, self.bias, s, p, self.dilation, self.groups)
                for x, s, p in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # Single-branch eval: test_branch_idx is never -1 on this path, so
            # it directly selects which branch's stride/padding to apply.
            idx = self.test_branch_idx
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[idx],
                    self.paddings[idx],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(y) for y in outputs]
        if self.activation is not None:
            outputs = [self.activation(y) for y in outputs]
        return outputs
|
|
|