| |
| import math |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| from mmcv.cnn import ConvModule |
| from mmcv.cnn import build_norm_layer |
| from timm.models.registry import register_model |
|
|
|
|
| def _make_divisible(v, divisor, min_value=None): |
| """ |
| This function is taken from the original tf repo. |
| It ensures that all layers have a channel number that is divisible by 8 |
| It can be seen here: |
| https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py |
| :param v: |
| :param divisor: |
| :param min_value: |
| :return: |
| """ |
| if min_value is None: |
| min_value = divisor |
| new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) |
| |
| if new_v < 0.9 * v: |
| new_v += divisor |
| return new_v |
|
|
|
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl used for EfficientNet-style networks; it is named
    'drop path' because 'Drop Connect' is a different form of dropout from a separate paper.
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Each sample along dim 0 is either zeroed entirely (probability ``drop_prob``) or kept and
    scaled by ``1 / (1 - drop_prob)``.

    Args:
        x: input tensor; the mask is broadcast over every dimension except dim 0.
        drop_prob: probability of dropping a sample. ``None`` and ``0`` both disable dropping.
        training: when False, ``x`` is returned unchanged.

    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor. With ``drop_prob == 1``
        every sample is dropped and the result is all zeros (no division by zero).
    """
    # `not drop_prob` covers both 0 and None (DropPath's default), which previously
    # crashed on `1 - None` in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    # One random value per sample, broadcast over the remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob > 0:
        return x.div(keep_prob) * random_tensor
    # keep_prob == 0: the mask is all zeros; skip the 1/0 rescale that would yield NaN.
    return x * random_tensor
|
|
|
|
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin module wrapper around ``drop_path`` that uses the module's ``training`` flag.

    Args:
        drop_prob: probability of dropping a sample's residual branch. ``None`` (the default)
            is stored as ``0.``, i.e. the module acts as identity.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        # Normalize None -> 0. so forward never evaluates `1 - None` when training.
        self.drop_prob = 0. if drop_prob is None else drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
|
|
|
|
def get_shape(tensor):
    """Return the shape of ``tensor``.

    Outside ONNX export this is ``tensor.shape`` unchanged. While exporting, each
    dimension is converted with ``.cpu().numpy()`` and a list is returned instead.
    """
    if not torch.onnx.is_in_onnx_export():
        return tensor.shape
    return [dim.cpu().numpy() for dim in tensor.shape]
|
|
|
|
class Conv2d_BN(nn.Sequential):
    """``nn.Conv2d`` followed by a norm layer built from ``norm_cfg``.

    Children are registered in order as ``c`` (the convolution) and ``bn`` (the norm layer).
    The norm layer's weight is initialised to ``bn_weight_init`` and its bias to 0.

    Args:
        a: input channels.
        b: output channels.
        ks, stride, pad, dilation, groups, bias: forwarded to ``nn.Conv2d``.
        bn_weight_init: constant initial value for the norm layer's weight.
        norm_cfg: config dict passed to mmcv ``build_norm_layer``.
    """

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, bias=False,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super().__init__()
        # Keep the conv hyper-parameters available as plain attributes.
        self.inp_channel, self.out_channel = a, b
        self.ks, self.pad, self.stride = ks, pad, stride
        self.dilation, self.groups = dilation, groups

        conv = nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=bias)
        self.add_module('c', conv)

        norm = build_norm_layer(norm_cfg, b)[1]
        nn.init.constant_(norm.weight, bn_weight_init)
        nn.init.constant_(norm.bias, 0)
        self.add_module('bn', norm)
|
|
|
|
class Mlp(nn.Module):
    """Convolutional feed-forward block.

    Pipeline: fc1 (1x1 Conv2d_BN) -> 3x3 depthwise conv -> activation -> dropout
    -> fc2 (1x1 Conv2d_BN) -> dropout.

    Args:
        in_features: input channels.
        hidden_features: hidden channels; falsy values fall back to ``in_features``.
        out_features: output channels; falsy values fall back to ``in_features``.
        act_layer: activation class, instantiated once.
        drop: dropout probability used after the activation and after fc2.
        norm_cfg: norm config forwarded to both Conv2d_BN layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super().__init__()
        width = hidden_features or in_features
        out_width = out_features or in_features
        self.fc1 = Conv2d_BN(in_features, width, norm_cfg=norm_cfg)
        # groups == channels makes this a depthwise 3x3 conv (with bias, padding 1).
        self.dwconv = nn.Conv2d(width, width, 3, 1, 1, bias=True, groups=width)
        self.act = act_layer()
        self.fc2 = Conv2d_BN(width, out_width, norm_cfg=norm_cfg)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.dwconv(self.fc1(x)))
        return self.drop(self.fc2(self.drop(hidden)))
|
|
|
|
class InvertedResidual(nn.Module):
    """Inverted residual block.

    Layers: [1x1 expansion Conv2d_BN + act] -> ks x ks depthwise Conv2d_BN + act
    -> 1x1 projection Conv2d_BN (no activation). The expansion pair is omitted
    when ``expand_ratio == 1``. An identity shortcut is added when
    ``stride == 1`` and ``inp == oup``.

    Args:
        inp: input channels.
        oup: output channels.
        ks: depthwise kernel size (padding ``ks // 2``).
        stride: depthwise stride; must be 1 or 2 (checked with ``assert``).
        expand_ratio: hidden width is ``int(round(inp * expand_ratio))``.
        activations: activation class instantiated per use; ``None`` means ``nn.ReLU``.
        norm_cfg: norm config forwarded to every Conv2d_BN.
    """

    def __init__(
        self,
        inp: int,
        oup: int,
        ks: int,
        stride: int,
        expand_ratio: int,
        activations=None,
        norm_cfg=dict(type='BN', requires_grad=True)
    ) -> None:
        super().__init__()
        self.stride = stride
        self.expand_ratio = expand_ratio
        assert stride in [1, 2]

        act = nn.ReLU if activations is None else activations
        hidden_dim = int(round(inp * expand_ratio))
        # Shortcut only when the block preserves both resolution and width.
        self.use_res_connect = self.stride == 1 and inp == oup

        stages = []
        if expand_ratio != 1:
            # pointwise expansion
            stages += [Conv2d_BN(inp, hidden_dim, ks=1, norm_cfg=norm_cfg), act()]
        stages += [
            # depthwise (groups == channels), carries the stride
            Conv2d_BN(hidden_dim, hidden_dim, ks=ks, stride=stride, pad=ks // 2,
                      groups=hidden_dim, norm_cfg=norm_cfg),
            act(),
            # pointwise linear projection
            Conv2d_BN(hidden_dim, oup, ks=1, norm_cfg=norm_cfg),
        ]
        self.conv = nn.Sequential(*stages)
        self.out_channels = oup
        self._is_cn = stride > 1

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
|
|
|
|
class StackedMV2Block(nn.Module):
    """A stage of InvertedResidual blocks, optionally preceded by a stem.

    Args:
        cfgs: rows of ``[kernel_size, expand_ratio, out_channels, stride]``, one per block.
        stem: if truthy, prepend ``Conv2d_BN(3, inp_channel, 3, 2, 1)`` + activation.
        inp_channel: channels entering the first InvertedResidual block.
        activation: activation class passed to the stem and every block.
        norm_cfg: norm config forwarded to every Conv2d_BN.
        width_mult: multiplier for each block's output channels, rounded with
            ``_make_divisible(..., 8)``.

    Sub-blocks are registered as ``layer1``, ``layer2``, ... and their names are
    kept in ``self.layers`` in execution order.
    """

    def __init__(
            self,
            cfgs,
            stem,
            inp_channel=16,
            activation=nn.ReLU,
            norm_cfg=dict(type='BN', requires_grad=True),
            width_mult=1.):
        super().__init__()
        self.stem = stem
        if stem:
            self.stem_block = nn.Sequential(
                Conv2d_BN(3, inp_channel, 3, 2, 1, norm_cfg=norm_cfg),
                activation()
            )
        self.cfgs = cfgs

        self.layers = []
        for i, (k, t, c, s) in enumerate(cfgs):
            output_channel = _make_divisible(c * width_mult, 8)
            # The hidden width is derived inside InvertedResidual from expand_ratio=t,
            # so no separate expanded-size computation is needed here.
            layer_name = 'layer{}'.format(i + 1)
            layer = InvertedResidual(inp_channel, output_channel, ks=k, stride=s, expand_ratio=t,
                                     norm_cfg=norm_cfg, activations=activation)
            self.add_module(layer_name, layer)
            inp_channel = output_channel
            self.layers.append(layer_name)

    def forward(self, x):
        if self.stem:
            x = self.stem_block(x)
        for layer_name in self.layers:
            x = getattr(self, layer_name)(x)
        return x
| |
|
|
class SqueezeAxialPositionalEmbedding(nn.Module):
    """Learnable 1-D positional embedding added to a 3-D ``(B, C, N)`` input.

    The embedding is stored with length ``shape`` and resampled along its last
    axis to the input's length ``N`` (linear, ``align_corners=False``) before
    being broadcast-added, so any ``N`` is accepted.

    Args:
        dim: embedding channels; should match the input's ``C``.
        shape: stored embedding length.
    """

    def __init__(self, dim, shape):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn([1, dim, shape]))

    def forward(self, x):
        _, _, length = x.shape
        resized = F.interpolate(self.pos_embed, size=length, mode='linear', align_corners=False)
        return x + resized
| |
| |
class Sea_Attention(torch.nn.Module):
    """Attention block combining a local detail path with row/column axial attention.

    From 1x1 Conv2d_BN projections q, k (``nh_kd`` channels) and v (``dh`` channels):

    * detail path: depthwise 3x3 Conv2d_BN over ``cat(q, k, v)``, activation,
      then a 1x1 Conv2d_BN back to ``dim`` channels;
    * axial path: q/k/v are averaged over W (row branch, length H) and over H
      (column branch, length W); per-head softmax attention runs along each 1-D
      sequence (q/k get learnable positional embeddings); the two results are
      broadcast-added to ``(B, dh, H, W)``, added to v, and projected to ``dim``.

    Output: ``h_sigmoid(axial) * detail``, shape ``(B, dim, H, W)``.

    Args:
        dim: input/output channels.
        key_dim: per-head query/key channels.
        num_heads: number of heads.
        attn_ratio: per-head value channels are ``int(attn_ratio * key_dim)``.
        activation: activation class instantiated for each use; ``None`` falls back
            to ``nn.ReLU``.
        norm_cfg: norm config forwarded to every Conv2d_BN.
    """

    def __init__(self, dim, key_dim, num_heads,
                 attn_ratio=2,
                 activation=None,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super().__init__()
        # `activation()` is instantiated several times below, so None cannot be used
        # directly; fall back to ReLU, the same fallback InvertedResidual uses.
        if activation is None:
            activation = nn.ReLU

        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads          # total query/key channels
        self.d = int(attn_ratio * key_dim)                # value channels per head
        self.dh = int(attn_ratio * key_dim) * num_heads   # total value channels
        self.attn_ratio = attn_ratio

        self.to_q = Conv2d_BN(dim, nh_kd, 1, norm_cfg=norm_cfg)
        self.to_k = Conv2d_BN(dim, nh_kd, 1, norm_cfg=norm_cfg)
        self.to_v = Conv2d_BN(dim, self.dh, 1, norm_cfg=norm_cfg)

        # bn_weight_init=0 starts these projections' norm weight at zero.
        self.proj = torch.nn.Sequential(activation(), Conv2d_BN(
            self.dh, dim, bn_weight_init=0, norm_cfg=norm_cfg))
        self.proj_encode_row = torch.nn.Sequential(activation(), Conv2d_BN(
            self.dh, self.dh, bn_weight_init=0, norm_cfg=norm_cfg))
        self.pos_emb_rowq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
        self.pos_emb_rowk = SqueezeAxialPositionalEmbedding(nh_kd, 16)

        self.proj_encode_column = torch.nn.Sequential(activation(), Conv2d_BN(
            self.dh, self.dh, bn_weight_init=0, norm_cfg=norm_cfg))
        self.pos_emb_columnq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
        self.pos_emb_columnk = SqueezeAxialPositionalEmbedding(nh_kd, 16)

        # Detail path: depthwise 3x3 over concat(q, k, v), then 1x1 back to `dim`.
        self.dwconv = Conv2d_BN(self.dh + 2 * self.nh_kd, 2 * self.nh_kd + self.dh, ks=3, stride=1, pad=1,
                                dilation=1, groups=2 * self.nh_kd + self.dh, norm_cfg=norm_cfg)
        self.act = activation()
        self.pwconv = Conv2d_BN(2 * self.nh_kd + self.dh, dim, ks=1, norm_cfg=norm_cfg)
        self.sigmoid = h_sigmoid()

    def forward(self, x):
        B, C, H, W = x.shape

        q = self.to_q(x)  # (B, nh_kd, H, W)
        k = self.to_k(x)  # (B, nh_kd, H, W)
        v = self.to_v(x)  # (B, dh, H, W)

        # Detail path.
        qkv = torch.cat([q, k, v], dim=1)
        qkv = self.act(self.dwconv(qkv))
        qkv = self.pwconv(qkv)  # (B, dim, H, W)

        # Row branch: average over W, attend across the H positions per head.
        qrow = self.pos_emb_rowq(q.mean(-1)).reshape(B, self.num_heads, -1, H).permute(0, 1, 3, 2)
        krow = self.pos_emb_rowk(k.mean(-1)).reshape(B, self.num_heads, -1, H)
        vrow = v.mean(-1).reshape(B, self.num_heads, -1, H).permute(0, 1, 3, 2)
        attn_row = torch.matmul(qrow, krow) * self.scale  # (B, heads, H, H)
        attn_row = attn_row.softmax(dim=-1)
        xx_row = torch.matmul(attn_row, vrow)  # (B, heads, H, d)
        xx_row = self.proj_encode_row(xx_row.permute(0, 1, 3, 2).reshape(B, self.dh, H, 1))

        # Column branch: average over H, attend across the W positions per head.
        qcolumn = self.pos_emb_columnq(q.mean(-2)).reshape(B, self.num_heads, -1, W).permute(0, 1, 3, 2)
        kcolumn = self.pos_emb_columnk(k.mean(-2)).reshape(B, self.num_heads, -1, W)
        vcolumn = v.mean(-2).reshape(B, self.num_heads, -1, W).permute(0, 1, 3, 2)
        attn_column = torch.matmul(qcolumn, kcolumn) * self.scale  # (B, heads, W, W)
        attn_column = attn_column.softmax(dim=-1)
        xx_column = torch.matmul(attn_column, vcolumn)  # (B, heads, W, d)
        xx_column = self.proj_encode_column(xx_column.permute(0, 1, 3, 2).reshape(B, self.dh, 1, W))

        # (B, dh, H, 1) + (B, dh, 1, W) broadcasts to (B, dh, H, W); add v and project to dim.
        xx = xx_row.add(xx_column)
        xx = v.add(xx)
        xx = self.proj(xx)

        # Gate the detail path with the hard-sigmoid of the axial path.
        xx = self.sigmoid(xx) * qkv
        return xx
|
|
|
|
class Block(nn.Module):
    """Residual block: ``x = x + DropPath(attn(x))`` then ``x = x + DropPath(mlp(x))``.

    Args:
        dim: channels of the input/output feature map.
        key_dim, num_heads, attn_ratio: forwarded to Sea_Attention.
        mlp_ratio: Mlp hidden width is ``int(dim * mlp_ratio)``.
        drop: dropout probability inside the Mlp.
        drop_path: stochastic-depth rate; values ``<= 0`` use ``nn.Identity``.
        act_layer: activation class for attention and Mlp.
        norm_cfg: norm config forwarded to attention and Mlp.
    """

    def __init__(self, dim, key_dim, num_heads, mlp_ratio=4., attn_ratio=2., drop=0.,
                 drop_path=0., act_layer=nn.ReLU, norm_cfg=dict(type='BN2d', requires_grad=True)):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        self.attn = Sea_Attention(dim, key_dim=key_dim, num_heads=num_heads, attn_ratio=attn_ratio,
                                  activation=act_layer, norm_cfg=norm_cfg)
        # A single stochastic-depth module is shared by both residual branches.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer,
                       drop=drop, norm_cfg=norm_cfg)

    def forward(self, x1):
        for branch in (self.attn, self.mlp):
            x1 = x1 + self.drop_path(branch(x1))
        return x1
|
|
|
|
class BasicLayer(nn.Module):
    """A stack of ``block_num`` Block modules applied sequentially.

    Args:
        block_num: number of Blocks.
        embedding_dim: channels of the feature map.
        key_dim, num_heads, mlp_ratio, attn_ratio, drop, norm_cfg: forwarded to every Block.
        attn_drop: accepted for interface compatibility; not used.
        drop_path: a single stochastic-depth rate shared by all Blocks, or a per-block
            list/tuple indexed by block position.
        act_layer: activation class for the Blocks. ``None`` falls back to ``nn.ReLU``
            (Block's own default); the Blocks instantiate it, so None cannot be passed on.
    """

    def __init__(self, block_num, embedding_dim, key_dim, num_heads,
                 mlp_ratio=4., attn_ratio=2., drop=0., attn_drop=0., drop_path=0.,
                 norm_cfg=dict(type='BN2d', requires_grad=True),
                 act_layer=None):
        super().__init__()
        self.block_num = block_num
        if act_layer is None:
            act_layer = nn.ReLU

        self.transformer_blocks = nn.ModuleList()
        for i in range(self.block_num):
            rate = drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path
            self.transformer_blocks.append(Block(
                embedding_dim, key_dim=key_dim, num_heads=num_heads,
                mlp_ratio=mlp_ratio, attn_ratio=attn_ratio,
                drop=drop, drop_path=rate,
                norm_cfg=norm_cfg,
                act_layer=act_layer))

    def forward(self, x):
        for block in self.transformer_blocks:
            x = block(x)
        return x
|
|
|
|
class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``, i.e. ``(x + 3) / 6`` clamped to ``[0, 1]``.

    Args:
        inplace: forwarded to ``nn.ReLU6``. It acts on the temporary ``x + 3``,
            so the caller's tensor is not modified.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3
        return self.relu(shifted) / 6
|
|
|
|
| # SeaFormer backbone: len(cfgs) StackedMV2Block stages (smb1, smb2, ...); the last |
| # len(depths) of them are each followed by a BasicLayer stage (trans1, trans2, ...). |
| # forward() returns the list `res`, not classifier logits. |
| class SeaFormer(nn.Module): |
| def __init__(self, cfgs, |
| channels, |
| emb_dims, |
| key_dims, |
| depths=[2,2], |
| num_heads=4, |
| attn_ratios=2, |
| mlp_ratios=[2, 4], |
| drop_path_rate=0., |
| norm_cfg=dict(type='BN', requires_grad=True), |
| act_layer=nn.ReLU6, |
| init_cfg=None, |
| num_classes=1000): |
| super().__init__() |
| self.num_classes = num_classes |
| self.channels = channels |
| self.depths = depths |
| self.cfgs = cfgs |
| self.norm_cfg = norm_cfg |
| self.init_cfg = init_cfg |
| # `self.pretrained` is only set when init_cfg is given (read from its 'checkpoint' key). |
| if self.init_cfg is not None: |
| self.pretrained = self.init_cfg['checkpoint'] |


| # Convolutional stages; channels[i] is stage i's input width, and only stage 0 gets the stride-2 stem. |
| for i in range(len(cfgs)): |
| smb = StackedMV2Block(cfgs=cfgs[i], stem=True if i == 0 else False, inp_channel=channels[i], norm_cfg=norm_cfg) |
| setattr(self, f"smb{i + 1}", smb) |


| # Transformer stages; within each stage the per-block drop-path rates rise linearly from 0 to drop_path_rate. |
| for i in range(len(depths)): |
| dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths[i])] |
| trans = BasicLayer( |
| block_num=depths[i], |
| embedding_dim=emb_dims[i], |
| key_dim=key_dims[i], |
| num_heads=num_heads, |
| mlp_ratio=mlp_ratios[i], |
| attn_ratio=attn_ratios, |
| drop=0, attn_drop=0, |
| drop_path=dpr, |
| norm_cfg=norm_cfg, |
| act_layer=act_layer) |
| setattr(self, f"trans{i + 1}", trans) |


| # NOTE(review): the head width is hard-coded to 1000 and ignores num_classes; |
| # nn.Linear(channels[-1], num_classes) looks intended (identical for the default 1000). |
| self.linear = nn.Linear(channels[-1], 1000) |
| self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) |


| def init_weights(self): |
| # Re-initialise parameters in place: Conv2d weight ~ normal(0, sqrt(2 / n)) with |
| # n = k_h * k_w * out_channels // groups; BatchNorm2d weight = 1, bias = 0; |
| # Linear weight ~ normal(0, 0.01). Existing Conv2d/Linear biases are zeroed. |
| # Not called from __init__. |
| for m in self.modules(): |
| if isinstance(m, nn.Conv2d): |
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels |
| n //= m.groups |
| m.weight.data.normal_(0, math.sqrt(2. / n)) |
| if m.bias is not None: |
| m.bias.data.zero_() |
| elif isinstance(m, nn.BatchNorm2d): |
| m.weight.data.fill_(1) |
| m.bias.data.zero_() |
| elif isinstance(m, nn.Linear): |
| m.weight.data.normal_(0, 0.01) |
| if m.bias is not None: |
| m.bias.data.zero_() |


| def forward(self, x): |
| # Stage i (0-based) runs smb{i+1}; when i >= len(cfgs) - len(depths) it then runs |
| # trans{i - (len(cfgs) - len(depths)) + 1}. |
| num_smb_stage = len(self.cfgs) |
| num_trans_stage = len(self.depths) |
| res = [] |
| for i in range(num_smb_stage): |
| smb = getattr(self, f"smb{i + 1}") |
| x = smb(x) |
| if num_trans_stage + i >= num_smb_stage: |
| trans = getattr(self, f"trans{i + num_trans_stage - num_smb_stage + 1}") |
| x = trans(x) |
| # NOTE(review): indentation was lost in this copy. Confirm whether res.append(x) |
| # belongs inside the trans branch (collect only transformer-stage outputs) or at |
| # loop level (collect every stage's output). |
| res.append(x) |


| # NOTE(review): the classifier result `out` is computed but never used or returned; |
| # this pooling + linear pass is wasted work on every call. |
| out = self.avgpool(x).view(-1, x.shape[1]) |
| out = self.linear(out) |
| return res |
|
|
|
|
@register_model
def SeaFormer_T(pretrained=False, **kwargs):
    """Build the SeaFormer-T variant.

    Args:
        pretrained: this builder never loads weights. Passing True emits a
            ``UserWarning`` instead of silently returning an untrained model.
        **kwargs: accepted and ignored.

    Returns:
        SeaFormer: a freshly initialised model.
    """
    if pretrained:
        import warnings
        warnings.warn('SeaFormer_T: no pretrained weights are loaded by this builder; '
                      'returning a randomly initialised model.', stacklevel=2)
    model_cfgs = dict(
        # Stage rows are [kernel_size, expand_ratio, out_channels, stride].
        cfg1=[
            [3, 1, 16, 1],
            [3, 4, 16, 2],
            [3, 3, 16, 1]],
        cfg2=[
            [5, 3, 32, 2],
            [5, 3, 32, 1]],
        cfg3=[
            [3, 3, 64, 2],
            [3, 3, 64, 1]],
        cfg4=[
            [5, 3, 128, 2]],
        cfg5=[
            [3, 6, 160, 2]],
        channels=[16, 16, 32, 64, 128, 160],
        num_heads=4,
        depths=[2, 2],
        emb_dims=[128, 160],
        key_dims=[16, 24],
        drop_path_rate=0.1,
        attn_ratios=2,
        mlp_ratios=[2, 4])
    return SeaFormer(
        cfgs=[model_cfgs['cfg1'], model_cfgs['cfg2'], model_cfgs['cfg3'], model_cfgs['cfg4'], model_cfgs['cfg5']],
        channels=model_cfgs['channels'],
        emb_dims=model_cfgs['emb_dims'],
        key_dims=model_cfgs['key_dims'],
        depths=model_cfgs['depths'],
        attn_ratios=model_cfgs['attn_ratios'],
        mlp_ratios=model_cfgs['mlp_ratios'],
        num_heads=model_cfgs['num_heads'],
        drop_path_rate=model_cfgs['drop_path_rate'])
|
|
|
|
@register_model
def SeaFormer_S(pretrained=False, **kwargs):
    """Build the SeaFormer-S variant.

    Args:
        pretrained: this builder never loads weights. Passing True emits a
            ``UserWarning`` instead of silently returning an untrained model.
        **kwargs: accepted and ignored.

    Returns:
        SeaFormer: a freshly initialised model.
    """
    if pretrained:
        import warnings
        warnings.warn('SeaFormer_S: no pretrained weights are loaded by this builder; '
                      'returning a randomly initialised model.', stacklevel=2)
    model_cfgs = dict(
        # Stage rows are [kernel_size, expand_ratio, out_channels, stride].
        cfg1=[
            [3, 1, 16, 1],
            [3, 4, 24, 2],
            [3, 3, 24, 1]],
        cfg2=[
            [5, 3, 48, 2],
            [5, 3, 48, 1]],
        cfg3=[
            [3, 3, 96, 2],
            [3, 3, 96, 1]],
        cfg4=[
            [5, 4, 160, 2]],
        cfg5=[
            [3, 6, 192, 2]],
        channels=[16, 24, 48, 96, 160, 192],
        num_heads=6,
        depths=[3, 3],
        key_dims=[16, 24],
        emb_dims=[160, 192],
        drop_path_rate=0.1,
        attn_ratios=2,
        mlp_ratios=[2, 4])
    return SeaFormer(
        cfgs=[model_cfgs['cfg1'], model_cfgs['cfg2'], model_cfgs['cfg3'], model_cfgs['cfg4'], model_cfgs['cfg5']],
        channels=model_cfgs['channels'],
        emb_dims=model_cfgs['emb_dims'],
        key_dims=model_cfgs['key_dims'],
        depths=model_cfgs['depths'],
        attn_ratios=model_cfgs['attn_ratios'],
        mlp_ratios=model_cfgs['mlp_ratios'],
        num_heads=model_cfgs['num_heads'],
        drop_path_rate=model_cfgs['drop_path_rate'])
|
|
|
|
@register_model
def SeaFormer_B(pretrained=False, **kwargs):
    """Build the SeaFormer-B variant.

    Args:
        pretrained: this builder never loads weights. Passing True emits a
            ``UserWarning`` instead of silently returning an untrained model.
        **kwargs: accepted and ignored.

    Returns:
        SeaFormer: a freshly initialised model.
    """
    if pretrained:
        import warnings
        warnings.warn('SeaFormer_B: no pretrained weights are loaded by this builder; '
                      'returning a randomly initialised model.', stacklevel=2)
    model_cfgs = dict(
        # Stage rows are [kernel_size, expand_ratio, out_channels, stride].
        cfg1=[
            [3, 1, 16, 1],
            [3, 4, 32, 2],
            [3, 3, 32, 1]],
        cfg2=[
            [5, 3, 64, 2],
            [5, 3, 64, 1]],
        cfg3=[
            [3, 3, 128, 2],
            [3, 3, 128, 1]],
        cfg4=[
            [5, 4, 192, 2]],
        cfg5=[
            [3, 6, 256, 2]],
        channels=[16, 32, 64, 128, 192, 256],
        num_heads=8,
        depths=[4, 4],
        key_dims=[16, 24],
        emb_dims=[192, 256],
        drop_path_rate=0.1,
        attn_ratios=2,
        mlp_ratios=[2, 4])
    return SeaFormer(
        cfgs=[model_cfgs['cfg1'], model_cfgs['cfg2'], model_cfgs['cfg3'], model_cfgs['cfg4'], model_cfgs['cfg5']],
        channels=model_cfgs['channels'],
        emb_dims=model_cfgs['emb_dims'],
        key_dims=model_cfgs['key_dims'],
        depths=model_cfgs['depths'],
        attn_ratios=model_cfgs['attn_ratios'],
        mlp_ratios=model_cfgs['mlp_ratios'],
        num_heads=model_cfgs['num_heads'],
        drop_path_rate=model_cfgs['drop_path_rate'])
|
|
| |
| |
@register_model
def SeaFormer_L(pretrained=False, weights='rscd/models/backbones/review_pretrain/SeaFormer_L_cls_79.9.pth.tar', **kwargs):
    """Build the SeaFormer-L variant, optionally loading a checkpoint.

    Args:
        pretrained: if True, load parameters from ``weights`` (strict key matching).
        weights: checkpoint path; a relative path resolves against the current
            working directory. The file may hold a dict with a ``'state_dict'``
            entry or a bare state dict.
        **kwargs: accepted and ignored.

    Returns:
        SeaFormer: the constructed (and possibly loaded) model.

    Errors raised by ``torch.load`` (e.g. a missing file) or by
    ``load_state_dict`` (mismatched keys/shapes) propagate to the caller.
    """
    model_cfgs = dict(
        # Stage rows are [kernel_size, expand_ratio, out_channels, stride].
        cfg1=[
            [3, 3, 32, 1],
            [3, 4, 64, 2],
            [3, 4, 64, 1]],
        cfg2=[
            [5, 4, 128, 2],
            [5, 4, 128, 1]],
        cfg3=[
            [3, 4, 192, 2],
            [3, 4, 192, 1]],
        cfg4=[
            [5, 4, 256, 2]],
        cfg5=[
            [3, 6, 320, 2]],
        channels=[32, 64, 128, 192, 256, 320],
        num_heads=8,
        depths=[3, 3, 3],
        key_dims=[16, 20, 24],
        emb_dims=[192, 256, 320],
        drop_path_rate=0.1,
        attn_ratios=2,
        mlp_ratios=[2, 4, 6])
    model = SeaFormer(
        cfgs=[model_cfgs['cfg1'], model_cfgs['cfg2'], model_cfgs['cfg3'], model_cfgs['cfg4'], model_cfgs['cfg5']],
        channels=model_cfgs['channels'],
        emb_dims=model_cfgs['emb_dims'],
        key_dims=model_cfgs['key_dims'],
        depths=model_cfgs['depths'],
        attn_ratios=model_cfgs['attn_ratios'],
        mlp_ratios=model_cfgs['mlp_ratios'],
        num_heads=model_cfgs['num_heads'],
        drop_path_rate=model_cfgs['drop_path_rate'])

    if pretrained:
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors into the model's own parameters.
        checkpoint = torch.load(weights, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)

    return model
|
|
if __name__ == '__main__':
    # Smoke test: build SeaFormer-L with its checkpoint, print the model and a FLOP
    # table, then print the shape of every tensor the backbone returns.
    model = SeaFormer_L(pretrained=True)

    dummy_input = torch.rand((1, 3, 512, 512))  # avoid shadowing the builtin `input`
    print(model)
    from fvcore.nn import FlopCountAnalysis, flop_count_table
    model.eval()
    flops = FlopCountAnalysis(model, dummy_input)
    print(flop_count_table(flops))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        res = model(dummy_input)
    for feat in res:
        print(feat.shape)
|
|