| | import torch |
| |
|
class CNN2D(torch.nn.Module):
    """Stack of conv blocks (Conv2d* -> BatchNorm2d -> ReLU), interleaved with
    2x2 max pooling, followed by global average pooling and a linear classifier.

    Args:
        channels: list of lists; channels[i] holds the output channel count of
            each Conv2d inside block i (convs in a block share kernel/stride/pad).
        conv_kernels: per-block Conv2d kernel sizes.
        conv_strides: per-block Conv2d strides.
        conv_padding: per-block Conv2d paddings.
        pool_padding: padding for each MaxPool2d; block i is followed by a pool
            only when i < len(pool_padding).
        num_classes: output size of the final linear layer (default 15).

    Raises:
        ValueError: if the per-block hyperparameter lists differ in length.
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        # Explicit check instead of `assert`: asserts are stripped under -O.
        if not (len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)):
            raise ValueError("channels, conv_kernels, conv_strides and conv_padding must have the same length")
        super().__init__()

        self.conv_blocks = torch.nn.ModuleList()
        prev_channel = 1  # model expects single-channel input, shape (N, 1, H, W)

        for i in range(len(channels)):
            block = []
            for conv_channel in channels[i]:
                block.append(torch.nn.Conv2d(in_channels=prev_channel, out_channels=conv_channel,
                                             kernel_size=conv_kernels[i], stride=conv_strides[i],
                                             padding=conv_padding[i]))
                prev_channel = conv_channel
            # One BatchNorm + ReLU per block, after the last conv of the block.
            block.append(torch.nn.BatchNorm2d(prev_channel))
            block.append(torch.nn.ReLU())
            self.conv_blocks.append(torch.nn.Sequential(*block))

        self.pool_blocks = torch.nn.ModuleList()
        for pad in pool_padding:
            self.pool_blocks.append(torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pad))

        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Map a (N, 1, H, W) batch to (N, num_classes) logits."""
        for i, conv in enumerate(self.conv_blocks):
            inwav = conv(inwav)
            # Extra conv blocks beyond len(pool_blocks) simply skip pooling.
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # BUG FIX: the previous `.squeeze()` also removed the batch dimension
        # when N == 1, feeding a 1-D tensor to the linear layer. Flatten only
        # the (1, 1) spatial dims so the output is always (N, num_classes).
        out = torch.flatten(self.global_pool(inwav), 1)
        return self.linear(out)
| | |
class ResBlock2D(torch.nn.Module):
    """Two-conv residual block with a channel-matching skip connection.

    The skip path tolerates a channel mismatch between the input
    (prev_channel) and the residual output (channel) as long as one count
    divides the other:
      * equal channels: plain addition;
      * wider output: the input is tiled along the channel dimension;
      * narrower output: the input's channel groups are averaged down to
        the output width.

    BUG FIXES vs. the previous version:
      * `identity += x.repeat(...)` mutated the caller's input tensor
        in place; all skip additions are now out-of-place.
      * in the narrowing branch the result kept the INPUT channel count,
        which crashed in the final BatchNorm2d(channel); the identity is
        now reduced (group mean) to `channel` channels instead.

    Args:
        prev_channel: input channel count.
        channel: output channel count of both convs.
        conv_kernel, conv_stride, conv_pad: shared Conv2d hyperparameters.

    Raises:
        RuntimeError: in forward(), when the channel counts are not
            divisible one by the other.
    """

    def __init__(self, prev_channel, channel, conv_kernel, conv_stride, conv_pad):
        super().__init__()
        self.res = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=prev_channel, out_channels=channel,
                            kernel_size=conv_kernel, stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=channel, out_channels=channel,
                            kernel_size=conv_kernel, stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
        )
        # Post-skip normalization + activation.
        self.bn = torch.nn.BatchNorm2d(channel)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return relu(bn(res(x) + skip(x))); output has `channel` channels."""
        identity = x
        out = self.res(x)
        in_ch, out_ch = identity.shape[1], out.shape[1]
        if out_ch == in_ch:
            out = out + identity
        elif out_ch > in_ch:
            if out_ch % in_ch != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Tile the input along channels to match the wider output.
            out = out + identity.repeat(1, out_ch // in_ch, 1, 1)
        else:
            if in_ch % out_ch != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Average the input's channel groups down to out_ch channels so
            # the sum has the width BatchNorm2d(channel) expects; mean keeps
            # the residual at the same scale as plain addition.
            n, _, h, w = identity.shape
            out = out + identity.reshape(n, in_ch // out_ch, out_ch, h, w).mean(dim=1)
        return self.relu(self.bn(out))
| | |
class CNNRes2D(torch.nn.Module):
    """Residual CNN: a conv stem, then stacks of ResBlock2D interleaved with
    2x2 max pooling, global average pooling and a linear classifier.

    Args:
        channels: list of lists; channels[0][0] is the stem's output width,
            channels[i][j] (i >= 1) the width of the j-th ResBlock2D in stack i.
        conv_kernels, conv_strides, conv_padding: per-stack conv hyperparameters
            (index 0 configures the stem).
        pool_padding: padding for each MaxPool2d; index 0 belongs to the stem's
            pool, index i (i >= 1) to the pool after residual stack i.
        num_classes: output size of the final linear layer (default 15).

    Raises:
        ValueError: if the per-stack hyperparameter lists differ in length.
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        # Explicit check instead of `assert`: asserts are stripped under -O.
        if not (len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)):
            raise ValueError("channels, conv_kernels, conv_strides and conv_padding must have the same length")
        super().__init__()

        # Stem: single conv + BN + ReLU + pool on the 1-channel input.
        prev_channel = 1
        self.conv_block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=prev_channel, out_channels=channels[0][0],
                            kernel_size=conv_kernels[0], stride=conv_strides[0],
                            padding=conv_padding[0]),
            torch.nn.BatchNorm2d(channels[0][0]),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pool_padding[0]),
        )

        # Residual stacks; index 0 of every hyperparameter list is the stem's.
        prev_channel = channels[0][0]
        self.res_blocks = torch.nn.ModuleList()
        for i in range(1, len(channels)):
            block = []
            for conv_channel in channels[i]:
                block.append(ResBlock2D(prev_channel, conv_channel,
                                        conv_kernels[i], conv_strides[i], conv_padding[i]))
                prev_channel = conv_channel
            self.res_blocks.append(torch.nn.Sequential(*block))

        self.pool_blocks = torch.nn.ModuleList()
        for i in range(1, len(pool_padding)):
            self.pool_blocks.append(torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pool_padding[i]))

        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Map a (N, 1, H, W) batch to (N, num_classes) logits."""
        inwav = self.conv_block(inwav)
        for i, res in enumerate(self.res_blocks):
            inwav = res(inwav)
            # Extra residual stacks beyond len(pool_blocks) skip pooling.
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # BUG FIX: the previous `.squeeze()` also removed the batch dimension
        # when N == 1, feeding a 1-D tensor to the linear layer. Flatten only
        # the (1, 1) spatial dims so the output is always (N, num_classes).
        out = torch.flatten(self.global_pool(inwav), 1)
        return self.linear(out)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|