| import torch | |
| import torch.nn as nn | |
| import timm | |
| from timm.models.vision_transformer import PatchEmbed | |
| from functools import partial | |
class vit(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer with a trainable linear head on a frozen backbone.

    After construction every parameter has ``requires_grad = False`` except
    those of ``self.head``, so only the classification head is optimised.

    Args:
        global_pool: When true, features are the mean over all tokens except
            the leading class token, normalised by ``self.fc_norm``; the
            base model's ``self.norm`` is removed. When false, the class
            token is read out after ``self.norm``.
        **kwargs: Forwarded unchanged to the timm ``VisionTransformer``
            constructor (``embed_dim``, ``depth``, ``num_classes``, ...).
    """

    def __init__(self, global_pool=False, **kwargs):
        # The configuration must reach the base constructor: calling it
        # with no arguments builds timm's default-sized backbone, which then
        # disagrees with the head / fc_norm dimensions created below.
        super().__init__(**kwargs)
        self.global_pool = global_pool

        # Read the resolved sizes from the base class rather than indexing
        # kwargs, so an omitted ``num_classes`` uses the base-class default
        # instead of raising KeyError.
        embed_dim = self.embed_dim
        num_classes = self.num_classes
        self.head = nn.Linear(embed_dim, num_classes, bias=True)

        if self.global_pool:
            # Fall back to the LayerNorm configuration used by the factory
            # helpers in this file when the caller supplies no norm_layer.
            norm_layer = kwargs.get('norm_layer') or partial(nn.LayerNorm, eps=1e-6)
            self.fc_norm = norm_layer(embed_dim)
            # The pooled path normalises with fc_norm, so the base model's
            # final norm is unused and dropped.
            del self.norm

        # Freeze everything first, then re-enable gradients for the head only.
        for param in self.parameters():
            param.requires_grad = False
        for param in self.head.parameters():
            param.requires_grad = True

    def forward_features(self, x):
        """Encode an input batch into one feature vector per sample.

        NOTE(review): assumes ``x`` is an image batch accepted by
        ``self.patch_embed`` -- confirm against the caller.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        # Prepend the class token to every sequence in the batch.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            # Average over all tokens except the class token at index 0.
            x = x[:, 1:, :].mean(dim=1)
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            # Read out the class token.
            outcome = x[:, 0]

        return outcome

    def forward(self, x):
        """Return ``self.head`` applied to the output of ``forward_features``."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
def vit_base_patch16(**kwargs):
    """Build a ``vit`` with patch size 16, width 768, depth 12 and 12 heads.

    Args:
        **kwargs: Extra keyword arguments forwarded to ``vit`` (for example
            ``num_classes`` or ``global_pool``). ``img_size`` defaults to 224
            but may be overridden here.

    Returns:
        The constructed ``vit`` model.
    """
    # Supply the 224 default through kwargs instead of as a fixed keyword so
    # a caller-provided img_size does not collide with it (TypeError).
    kwargs.setdefault('img_size', 224)
    model = vit(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
                norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
def vit_large_patch16(**kwargs):
    """Build a ``vit`` with patch size 16, width 1024, depth 24 and 16 heads.

    Args:
        **kwargs: Extra keyword arguments forwarded to ``vit``; they must not
            repeat any of the preset keys below.

    Returns:
        The constructed ``vit`` model.
    """
    preset = {
        'patch_size': 16,
        'embed_dim': 1024,
        'depth': 24,
        'num_heads': 16,
        'mlp_ratio': 4,
        'qkv_bias': True,
        'norm_layer': partial(nn.LayerNorm, eps=1e-6),
    }
    return vit(**preset, **kwargs)
def vit_huge_patch14(**kwargs):
    """Build a ``vit`` with patch size 14, width 1280, depth 32 and 16 heads.

    Args:
        **kwargs: Extra keyword arguments forwarded to ``vit``; they must not
            repeat any of the preset keys below.

    Returns:
        The constructed ``vit`` model.
    """
    preset = {
        'patch_size': 14,
        'embed_dim': 1280,
        'depth': 32,
        'num_heads': 16,
        'mlp_ratio': 4,
        'qkv_bias': True,
        'norm_layer': partial(nn.LayerNorm, eps=1e-6),
    }
    return vit(**preset, **kwargs)