import math
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F


def initialize_weights(m):
    """Per-module weight init; typically applied with `model.apply(initialize_weights)`."""
    if isinstance(m, nn.Conv1d):
        # He-style init with fan-out = kernel_size * out_channels
        n = m.kernel_size[0] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2 / n))
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight.data, 1)
        nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
|
|
class SELayer(nn.Module):
    """Squeeze-and-Excitation block for 1D feature maps."""

    def __init__(self, inp, reduction=4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(inp, inp // reduction),
            nn.SiLU(),
            nn.Linear(inp // reduction, inp),
            nn.Sigmoid()
        )
|
|
    def forward(self, x):
        # Squeeze: global average over the length dimension
        b, c, _ = x.size()
        y = x.view(b, c, -1).mean(dim=2)
        # Excite: per-channel gates in (0, 1) rescale the input
        y = self.fc(y).view(b, c, 1)
        return x * y
|
|
class EffBlock(nn.Module):
    """Inverted-bottleneck block: 1x1 expansion, depthwise conv, squeeze-excitation, 1x1 projection."""

    def __init__(self, in_ch, ks, resize_factor, activation, out_ch=None, se_reduction=None):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = self.in_ch if out_ch is None else out_ch
        self.resize_factor = resize_factor
        self.se_reduction = resize_factor if se_reduction is None else se_reduction
        self.ks = ks
        self.inner_dim = self.in_ch * self.resize_factor

        self.block = nn.Sequential(
            # 1x1 expansion to inner_dim channels
            nn.Conv1d(
                in_channels=self.in_ch,
                out_channels=self.inner_dim,
                kernel_size=1,
                padding='same',
                bias=False
            ),
            nn.BatchNorm1d(self.inner_dim),
            activation(),

            # depthwise convolution
            nn.Conv1d(
                in_channels=self.inner_dim,
                out_channels=self.inner_dim,
                kernel_size=ks,
                groups=self.inner_dim,
                padding='same',
                bias=False
            ),
            nn.BatchNorm1d(self.inner_dim),
            activation(),
            SELayer(self.inner_dim, reduction=self.se_reduction),

            # 1x1 projection back to in_ch channels
            nn.Conv1d(
                in_channels=self.inner_dim,
                out_channels=self.in_ch,
                kernel_size=1,
                padding='same',
                bias=False
            ),
            nn.BatchNorm1d(self.in_ch),
            activation(),
        )

    def forward(self, x):
        return self.block(x)


class LocalBlock(nn.Module):
    """Plain conv -> batchnorm -> activation block."""

    def __init__(self, in_ch, ks, activation, out_ch=None):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = self.in_ch if out_ch is None else out_ch
        self.ks = ks

        self.block = nn.Sequential(
            nn.Conv1d(
                in_channels=self.in_ch,
                out_channels=self.out_ch,
                kernel_size=self.ks,
                padding='same',
                bias=False
            ),
            nn.BatchNorm1d(self.out_ch),
            activation()
        )

    def forward(self, x):
        return self.block(x)


class ResidualConcat(nn.Module):
    """Concatenates the wrapped module's output with its input along the channel dimension."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return torch.cat([self.fn(x, **kwargs), x], dim=1)
|
|
class MapperBlock(nn.Module):
    """BatchNorm followed by a pointwise (1x1) convolution that remaps the channel count."""

    def __init__(self, in_features, out_features, activation=nn.SiLU):
        # NOTE: `activation` is accepted here but not used inside the block.
        super().__init__()
        self.block = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Conv1d(in_channels=in_features,
                      out_channels=out_features,
                      kernel_size=1),
        )

    def forward(self, x):
        return self.block(x)
|
|
class LegNet(nn.Module):
    """1D convolutional network: a stem, a stack of SE-augmented inverted-bottleneck
    blocks with optional pooling, and a scalar regression head."""

    def __init__(self,
                 in_ch,
                 stem_ch,
                 stem_ks,
                 ef_ks,
                 ef_block_sizes,
                 pool_sizes,
                 resize_factor,
                 activation=nn.SiLU,
                 ):
        super().__init__()
        assert len(pool_sizes) == len(ef_block_sizes)

        self.in_ch = in_ch
        self.stem = LocalBlock(in_ch=in_ch,
                               out_ch=stem_ch,
                               ks=stem_ks,
                               activation=activation)

        blocks = []
        in_ch = stem_ch
        out_ch = stem_ch
        for pool_sz, out_ch in zip(pool_sizes, ef_block_sizes):
            blc = nn.Sequential(
                # ResidualConcat doubles the channel count, so the following
                # LocalBlock maps 2 * in_ch back down to out_ch.
                ResidualConcat(
                    EffBlock(
                        in_ch=in_ch,
                        out_ch=in_ch,
                        ks=ef_ks,
                        resize_factor=resize_factor,
                        activation=activation)
                ),
                LocalBlock(in_ch=in_ch * 2,
                           out_ch=out_ch,
                           ks=ef_ks,
                           activation=activation),
                nn.MaxPool1d(pool_sz) if pool_sz != 1 else nn.Identity()
            )
            in_ch = out_ch
            blocks.append(blc)
        self.main = nn.Sequential(*blocks)

        self.mapper = MapperBlock(in_features=out_ch,
                                  out_features=out_ch * 2)
        self.head = nn.Sequential(nn.Linear(out_ch * 2, out_ch * 2),
                                  nn.BatchNorm1d(out_ch * 2),
                                  activation(),
                                  nn.Linear(out_ch * 2, 1))

    def forward(self, x):
        x = self.stem(x)
        x = self.main(x)
        x = self.mapper(x)
        # Global average pooling over the length dimension, then the MLP head
        x = F.adaptive_avg_pool1d(x, 1)
        x = x.squeeze(-1)
        x = self.head(x)
        x = x.squeeze(-1)
        return x
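

# Minimal smoke-test sketch (not part of the original module). All hyperparameter
# values below are illustrative assumptions, not settings prescribed by this code;
# the input is just a random (batch, channels, length) tensor.
if __name__ == "__main__":
    model = LegNet(
        in_ch=4,                            # assumed input channel count
        stem_ch=64,                         # assumed
        stem_ks=7,                          # assumed
        ef_ks=9,                            # assumed
        ef_block_sizes=[80, 96, 112, 128],  # assumed; one entry per stage
        pool_sizes=[2, 2, 2, 1],            # assumed; must match ef_block_sizes in length
        resize_factor=4,                    # assumed expansion factor
    )
    model.apply(initialize_weights)
    model.eval()

    x = torch.randn(8, 4, 200)              # (batch, channels, length)
    with torch.no_grad():
        y = model(x)
    print(y.shape)                          # expected: torch.Size([8])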
|
|