| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import logging |
| | import os |
| | import torch |
| | import torch.nn as nn |
| | import torch.utils.checkpoint as checkpoint |
| | import numpy as np |
| | from einops import rearrange, repeat |
| | from einops.layers.torch import Rearrange |
| | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ |
| |
|
| | from .registry import register_image_encoder |
| |
|
# Module-level logger; used for layer-scale notices and pretrained-weight loading messages.
logger = logging.getLogger(__name__)
| |
|
| |
|
| |
|
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> act -> Dropout -> Linear -> Dropout.

    Args:
        in_features (int): Size of the last input dimension.
        hidden_features (int | None): Hidden width; falls back to ``in_features`` when falsy.
        out_features (int | None): Output width; falls back to ``in_features`` when falsy.
        act_layer (type[nn.Module]): Activation class, instantiated without arguments.
            Default: nn.GELU.
        drop (float): Dropout probability used after the activation and after ``fc2``.
            Default: 0.0.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_dim = out_features if out_features else in_features
        hidden_dim = hidden_features if hidden_features else in_features
        # Attribute names below are state_dict keys (fc1.*, fc2.*) and must stay stable.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
| |
|
| |
|
def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C); the ``view`` below requires H and W to be
            multiples of ``window_size``.
        window_size (int): side length of each window.

    Returns:
        Tensor of shape (num_windows*B, window_size, window_size, C), ordered by
        batch, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    grid_rows = height // window_size
    grid_cols = width // window_size
    tiles = x.view(batch, grid_rows, window_size, grid_cols, window_size, channels)
    # Put the two window-grid axes next to each other so they fold into the batch axis.
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, channels)
| |
|
| |
|
def window_reverse(windows, window_size, H, W):
    """Stitch windows produced by ``window_partition`` back into a feature map.

    Args:
        windows: tensor of shape (num_windows*B, window_size, window_size, C).
        window_size (int): side length of each window.
        H (int): height of the reconstructed map.
        W (int): width of the reconstructed map.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    tiles = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    # Interleave the grid axes with the in-window axes to rebuild full rows/columns.
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(batch, H, W, -1)
| |
|
| |
|
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) with a learned relative position bias.

    Works for both regular and shifted windows; shifted windows pass an additive mask.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): Window height and width.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): Learnable bias on the query/key/value projection. Default: True
        qk_scale (float | None, optional): Used instead of head_dim ** -0.5 when truthy.
        attn_drop (float, optional): Dropout ratio on attention weights. Default: 0.0
        proj_drop (float, optional): Dropout ratio on the output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        win_h, win_w = self.window_size[0], self.window_size[1]
        # One bias per head for each relative offset (dy, dx) with
        # dy in [-(Wh-1), Wh-1] and dx in [-(Ww-1), Ww-1].
        table_rows = (2 * window_size[0] - 1) * (2 * window_size[1] - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_rows, num_heads))

        # For each token pair (i, j) inside a window, precompute the table row
        # that holds the bias for their relative offset.
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        offsets[:, :, 0] += win_h - 1  # shift dy to start at 0
        offsets[:, :, 1] += win_w - 1  # shift dx to start at 0
        offsets[:, :, 0] *= 2 * win_w - 1  # row-major index into the (dy, dx) grid
        self.register_buffer("relative_position_index", offsets.sum(-1))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        scores = (q * self.scale) @ k.transpose(-2, -1)

        tokens = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(tokens, tokens, -1).permute(2, 0, 1).contiguous()  # heads, N, N
        scores = scores + bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window mask over the batch and head axes.
            n_windows = mask.shape[0]
            scores = scores.view(B_ // n_windows, n_windows, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            scores = scores.view(-1, self.num_heads, N, N)

        weights = self.attn_drop(self.softmax(scores))
        out = (weights @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj_drop(self.proj(out))

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        """Operation count for one window of ``N`` tokens (projections + two attention matmuls)."""
        head_dim = self.dim // self.num_heads
        qkv_cost = N * self.dim * 3 * self.dim
        score_cost = self.num_heads * N * head_dim * N
        mix_cost = self.num_heads * N * N * head_dim
        proj_cost = N * self.dim * self.dim
        return qkv_cost + score_cost + mix_cost + proj_cost
| |
|
| |
|
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Operates on tokens of shape (B, H*W, C): LayerNorm -> (shifted) window
    attention -> residual, then LayerNorm -> MLP -> residual.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution (H, W).
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        layer_scale (bool, optional): If True, scale both residual branches by a learnable
            per-channel factor initialised to 1e-4. Default: False
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # The map is no larger than one window along its shorter side:
            # use a single window of that size and disable shifting.
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        # NOTE: creation order of the submodules below determines RNG consumption
        # during initialisation and the state_dict key order.
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # Attention mask for shifted windows (SW-MSA): label the 3x3 grid of
            # regions created by the cyclic shift with ids 0..8, then forbid attention
            # between tokens in the same window that carry different ids.
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, ws, ws, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # Pairs from different regions get -100, which softmax turns into ~0 weight.
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        # Residual scaling: scalar 1.0 by default, learnable per-channel vector with layer_scale.
        self.gamma = 1.0
        if layer_scale:
            logger.info('=> enable layer scale')
            self.gamma = nn.Parameter(
                1e-4*torch.ones(dim), requires_grad=True
            )

        # Registered as a buffer so it follows .to(device); it is also saved in the state_dict.
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift (towards the top-left) for SW-MSA
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows -> (nW*B, window_size*window_size, C)
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # W-MSA / SW-MSA (attn_mask is None for non-shifted blocks)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        # merge windows back into a (B, H, W, C) map
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # residual attention branch, then residual MLP branch (both optionally layer-scaled)
        x = shortcut + self.drop_path(self.gamma*x)
        x = x + self.drop_path(self.gamma*self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA over all nW windows
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp (two linear layers)
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
| |
|
| |
|
class PatchMerging(nn.Module):
    r""" Patch Merging Layer: halves H and W and maps 4*C channels to 2*C.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C  ->  B, (H/2)*(W/2), 2*C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)
        # Stack the four pixels of each 2x2 cell in the order
        # (top-left, bottom-left, top-right, bottom-right); ``reduction``'s
        # weights depend on this channel layout.
        corners = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(corners, -1).view(B, -1, 4 * C)
        return self.reduction(self.norm(merged))

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        norm_cost = H * W * self.dim
        reduction_cost = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return norm_cost + reduction_cost
| |
|
| |
|
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either
            one value shared by all blocks or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (type | None, optional): Downsample layer class applied at the end of the
            stage. It is called with ``input_resolution``, ``patch_size=3``, ``in_chans=dim``,
            ``embed_dim=dim*2``, ``stride=2``, ``padding=1`` and ``norm_layer``. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        layer_scale (bool): Forwarded to every SwinTransformerBlock. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
                 use_checkpoint=False, layer_scale=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # A sequence supplies one stochastic-depth rate per block. Tuples were documented
        # but previously fell through to DropPath(tuple) and crashed on the `> 0.` check.
        per_block_drop_path = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim, input_resolution=input_resolution,
                num_heads=num_heads, window_size=window_size,
                # alternate regular (W-MSA) and shifted (SW-MSA) windows
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[i] if per_block_drop_path else drop_path,
                norm_layer=norm_layer,
                layer_scale=layer_scale
            )
            for i in range(depth)])

        # patch merging / downsampling layer: a stride-2 3x3 conv embedding that doubles channels
        if downsample is not None:
            self.downsample = downsample(
                input_resolution=input_resolution, patch_size=3, in_chans=dim, embed_dim=dim*2,
                stride=2, padding=1, norm_layer=norm_layer
            )
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # Recompute activations in backward to trade compute for memory.
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
| |
|
| |
|
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding using non-overlapping patches (kernel == stride).

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """(B, C, H, W) image -> (B, num_patches, embed_dim) tokens."""
        _, _, H, W = x.shape
        # The model is built for a fixed input size.
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        return tokens if self.norm is None else self.norm(tokens)

    def flops(self):
        out_h, out_w = self.patches_resolution
        kernel_area = self.patch_size[0] * self.patch_size[1]
        total = out_h * out_w * self.embed_dim * self.in_chans * kernel_area
        if self.norm is not None:
            total += out_h * out_w * self.embed_dim
        return total
| |
|
| |
|
class ConvEmbed(nn.Module):
    """ Image to Patch Embedding with a (possibly overlapping) strided convolution.

    Also used as the between-stage downsampler: a 3-D token input (B, H*W, C) is
    first reshaped to (B, C, H, W) using ``input_resolution``.

    Args:
        input_resolution (tuple[int]): (H, W) used to unflatten 3-D token inputs and by
            ``flops``. Default: (224, 224)
        patch_size (int): Convolution kernel size. Default: 7
        in_chans (int): Number of input channels. Default: 3
        embed_dim (int): Number of output channels. Default: 64
        stride (int): Convolution stride. Default: 4
        padding (int): Convolution zero padding. Default: 2
        norm_layer (nn.Module, optional): Normalization applied to the output tokens. Default: None
    """

    def __init__(
        self,
        input_resolution=(224, 224),
        patch_size=7,
        in_chans=3,
        embed_dim=64,
        stride=4,
        padding=2,
        norm_layer=None
    ):
        super().__init__()
        self.patch_size = patch_size
        self.input_resolution = input_resolution

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        """Return tokens of shape (B, H_out*W_out, embed_dim)."""
        if len(x.size()) == 3:
            # token sequence -> feature map
            x = rearrange(
                x, 'b (h w) c -> b c h w',
                h=self.input_resolution[0],
                w=self.input_resolution[1]
            )

        x = self.proj(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm is not None:
            x = self.norm(x)

        return x

    def flops(self):
        """Operation count at ``input_resolution``, using the same convention as PatchEmbed.flops.

        Required by BasicLayer.flops, which calls ``downsample.flops()``. Assumes the
        default dilation of 1 and groups of 1 used by ``self.proj``.
        """
        in_h, in_w = self.input_resolution
        k_h, k_w = self.proj.kernel_size
        s_h, s_w = self.proj.stride
        p_h, p_w = self.proj.padding
        # Conv2d output size formula (dilation 1).
        out_h = (in_h + 2 * p_h - k_h) // s_h + 1
        out_w = (in_w + 2 * p_w - k_w) // s_w + 1
        flops = out_h * out_w * self.proj.out_channels * self.proj.in_channels * k_h * k_w
        if self.norm is not None:
            flops += out_h * out_w * self.proj.out_channels
        return flops
| |
|
| |
|
class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    This variant uses a convolutional stem (ConvEmbed) and stride-2 convolutional
    downsampling between stages.

    Args:
        img_size (int | tuple(int)): Input image size (H, W). Default 224
        patch_size (int): Kernel size of the stem convolution. Default: 7
        patch_padding (int): Zero padding of the stem convolution. Default: 2
        patch_stride (int): Stride of the stem convolution. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head; 0 disables the head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (sequence[int]): Depth of each Swin Transformer layer. Default: (2, 2, 6, 2)
        num_heads (sequence[int]): Number of attention heads in different layers. Default: (3, 6, 12, 24)
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        layer_scale (bool): Enable learnable per-channel residual scaling in every block. Default: False
        **kwargs: Accepted and ignored.
    """

    def __init__(self, img_size=224, patch_size=7, patch_padding=2, patch_stride=4, in_chans=3,
                 num_classes=1000, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, layer_scale=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        img_size = to_2tuple(img_size)

        # Stem. Forward the configured stride (previously ignored: the conv always
        # used stride 4) so the token grid matches ``patches_resolution`` below.
        self.patch_embed = ConvEmbed(
            input_resolution=img_size,
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            stride=patch_stride, padding=patch_padding,
            norm_layer=norm_layer if self.patch_norm else None
        )

        # Token grid after the stem: height from img_size[0], width from img_size[1].
        patches_resolution = (
            self._conv_out(img_size[0], patch_size, patch_stride, patch_padding),
            self._conv_out(img_size[1], patch_size, patch_stride, patch_padding),
        )
        num_patches = patches_resolution[0] * patches_resolution[1]
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth: rate grows linearly over all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        resolution = patches_resolution
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                input_resolution=resolution,
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=ConvEmbed if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                layer_scale=layer_scale
            )
            self.layers.append(layer)
            # BasicLayer downsamples with a kernel-3 / stride-2 / padding-1 ConvEmbed,
            # which yields ceil(r / 2); track that exactly (r // 2**i is wrong for odd r).
            resolution = tuple(self._conv_out(r, 3, 2, 1) for r in resolution)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    @staticmethod
    def _conv_out(size, kernel, stride, padding):
        """Conv2d output length along one axis (dilation 1)."""
        return (size + 2 * padding - kernel) // stride + 1

    @property
    def dim_out(self):
        return self.num_features

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def from_pretrained(self, pretrained='', pretrained_layers=None, verbose=True):
        """Load weights from the checkpoint file ``pretrained`` if it exists.

        Args:
            pretrained (str): Path to a file readable by ``torch.load``.
            pretrained_layers (sequence[str] | None): See ``from_state_dict``.
            verbose (bool): Log every loaded key.
        """
        if not os.path.isfile(pretrained):
            if pretrained:
                logger.warning(f'=> pretrained model {pretrained} not found, skipping')
            return
        logger.info(f'=> loading pretrained model {pretrained}')
        # NOTE(security): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        pretrained_dict = torch.load(pretrained, map_location='cpu')
        self.from_state_dict(pretrained_dict, pretrained_layers, verbose)

    @staticmethod
    def _is_pretrained_layer(key, pretrained_layers):
        """Whether ``key`` should be initialised from the checkpoint.

        ``relative_position_index`` and ``attn_mask`` are buffers rebuilt in
        ``__init__`` and are never loaded.
        """
        if 'relative_position_index' in key or 'attn_mask' in key:
            return False
        return key.split('.')[0] in pretrained_layers or '*' in pretrained_layers

    @staticmethod
    def _resize_relative_position_bias_table(table, target_size):
        """Bicubically resize an (L, num_heads) bias table to ``target_size``.

        Returns None when head counts differ or either L is not a perfect square.
        """
        L1, nH1 = table.size()
        L2, nH2 = target_size
        if nH1 != nH2:
            return None
        S1, S2 = int(L1 ** 0.5), int(L2 ** 0.5)
        if S1 * S1 != L1 or S2 * S2 != L2:
            return None
        resized = torch.nn.functional.interpolate(
            table.permute(1, 0).view(1, nH1, S1, S1),
            size=(S2, S2),
            mode='bicubic')
        return resized.view(nH2, L2).permute(1, 0)

    @staticmethod
    def _resize_absolute_pos_embed(pos_embed, target_size):
        """Bicubically resize a (1, L, C) position embedding to ``target_size``.

        Returns None when channel counts differ or either L is not a perfect square.
        """
        _, L1, C1 = pos_embed.size()
        _, L2, C2 = target_size
        if C1 != C2:
            return None
        S1, S2 = int(L1 ** 0.5), int(L2 ** 0.5)
        if S1 * S1 != L1 or S2 * S2 != L2:
            return None
        grid = pos_embed.reshape(-1, S1, S1, C1).permute(0, 3, 1, 2)
        resized = torch.nn.functional.interpolate(grid, size=(S2, S2), mode='bicubic')
        return resized.permute(0, 2, 3, 1).flatten(1, 2)

    def from_state_dict(self, pretrained_dict, pretrained_layers=None, verbose=True):
        """Initialise selected parameters from ``pretrained_dict``.

        An ``image_encoder.`` prefix is stripped from keys, and keys unknown to this
        model are dropped. Relative position bias tables and absolute position
        embeddings whose size differs are bicubically resized. Ones that cannot be
        resized are skipped with a log message instead of making
        ``load_state_dict`` fail.

        Args:
            pretrained_dict (dict[str, Tensor]): Source state dict.
            pretrained_layers (sequence[str] | None): Top-level module names (first
                dotted component of a key) to load. If the sequence contains ``'*'``,
                every key is loaded. ``None`` or empty loads nothing.
            verbose (bool): Log every key that is loaded.
        """
        pretrained_layers = list(pretrained_layers) if pretrained_layers else []
        model_dict = self.state_dict()
        prefix = 'image_encoder.'

        def stripped_key(key):
            return key[len(prefix):] if key.startswith(prefix) else key

        pretrained_dict = {
            stripped_key(k): v for k, v in pretrained_dict.items()
            if stripped_key(k) in model_dict
        }
        need_init_state_dict = {}
        for k, v in pretrained_dict.items():
            if not self._is_pretrained_layer(k, pretrained_layers):
                continue
            if verbose:
                logger.info(f'=> init {k} from pretrained state dict')

            target_size = model_dict[k].size()
            if v.size() != target_size:
                if 'relative_position_bias_table' in k:
                    resized = self._resize_relative_position_bias_table(v, target_size)
                elif 'absolute_pos_embed' in k:
                    resized = self._resize_absolute_pos_embed(v, target_size)
                else:
                    # Other mismatches are passed through, so load_state_dict still reports them.
                    resized = v
                if resized is None:
                    logger.info(f"Error in loading {k}, passing")
                    continue
                if resized is not v:
                    logger.info(
                        '=> load_pretrained: resized variant: {} to {}'
                        .format(tuple(v.size()), tuple(resized.size()))
                    )
                v = resized

            need_init_state_dict[k] = v
        self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        """(B, C, H, W) image -> (B, num_features) pooled features."""
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
| |
|
| |
|
@register_image_encoder
def image_encoder(config_encoder, verbose, **kwargs):
    """Build the head-less Swin image encoder described by ``config_encoder``.

    Hyper-parameters come from ``config_encoder['SPEC']``; QK_SCALE, LAYER_SCALE and
    ENABLE_CHECKPOINT are optional there. ``num_classes=0`` makes the classification
    head an identity. When ``LOAD_PRETRAINED`` is truthy, weights are loaded from
    ``PRETRAINED`` restricted to ``PRETRAINED_LAYERS``. Extra ``kwargs`` are ignored.
    """
    spec = config_encoder['SPEC']

    model_args = dict(
        img_size=config_encoder['IMAGE_SIZE'],
        patch_size=spec['PATCH_SIZE'],
        patch_padding=spec['PATCH_PADDING'],
        patch_stride=spec['PATCH_STRIDE'],
        in_chans=spec['IN_CHANS'],
        num_classes=0,
        embed_dim=spec['EMBED_DIM'],
        depths=spec['DEPTHS'],
        num_heads=spec['NUM_HEADS'],
        window_size=spec['WINDOW_SIZE'],
        mlp_ratio=spec['MLP_RATIO'],
        qkv_bias=spec['QKV_BIAS'],
        qk_scale=spec.get('QK_SCALE', None),
        drop_rate=spec['DROP_RATE'],
        drop_path_rate=spec['DROP_PATH_RATE'],
        ape=spec['APE'],
        patch_norm=spec['PATCH_NORM'],
        layer_scale=spec.get('LAYER_SCALE', False),
        use_checkpoint=spec.get('ENABLE_CHECKPOINT', False),
    )
    coswin = SwinTransformer(**model_args)

    if config_encoder['LOAD_PRETRAINED']:
        checkpoint_path = config_encoder['PRETRAINED']
        layers_to_load = config_encoder['PRETRAINED_LAYERS']
        coswin.from_pretrained(checkpoint_path, layers_to_load, verbose)

    return coswin
| |
|