# SUMMIT / vit_model.py
# Uploaded by DYunt ("Upload 26 files"), commit 2659b26 (verified).
import torch
import torch.nn as nn
import timm
from timm.models.vision_transformer import PatchEmbed
from functools import partial
class vit(timm.models.vision_transformer.VisionTransformer):
    """timm ``VisionTransformer`` with a frozen backbone and a trainable head.

    After construction every parameter has ``requires_grad=False`` except
    those of ``self.head``, which is re-created as a fresh ``nn.Linear``.

    Args:
        global_pool: If True, features are the mean of the patch tokens
            (class token excluded) passed through a newly created
            ``fc_norm``, and the backbone's final ``norm`` is deleted.
            If False, features are the class token after ``norm``.
        **kwargs: Forwarded to timm's ``VisionTransformer`` constructor
            (e.g. ``img_size``, ``patch_size``, ``embed_dim``, ``depth``,
            ``num_heads``, ``mlp_ratio``, ``qkv_bias``, ``norm_layer``,
            ``num_classes``).
    """

    def __init__(self, global_pool=False, **kwargs):
        # Pass the requested architecture through to timm. Calling the parent
        # with no arguments built timm's default backbone regardless of the
        # requested embed_dim/depth/patch_size, so e.g. a 1024-input head sat
        # on top of 768-dim features.
        super().__init__(**kwargs)
        self.global_pool = global_pool

        # timm's VisionTransformer stores the resolved embed_dim/num_classes
        # as attributes, so omitted kwargs fall back to timm's defaults
        # instead of raising KeyError.
        embed_dim = self.embed_dim
        self.head = nn.Linear(embed_dim, self.num_classes, bias=True)

        if self.global_pool:
            # Default matches the LayerNorm configuration the factory
            # functions in this module pass in.
            norm_layer = kwargs.get('norm_layer') or partial(nn.LayerNorm, eps=1e-6)
            self.fc_norm = norm_layer(embed_dim)
            # forward_features uses fc_norm instead of the final norm here.
            del self.norm

        # Freeze the whole model, then re-enable gradients for the head only.
        self.requires_grad_(False)
        self.head.requires_grad_(True)

    def forward_features(self, x):
        """Embed ``x``, run the transformer blocks and return pooled features.

        ``x`` is presumably an image batch (B, C, H, W) matching ``img_size``
        (TODO confirm). The result has shape (B, D). It is the mean patch
        token passed through ``fc_norm`` when ``global_pool`` is set, and
        otherwise the class token after ``norm``.
        """
        batch_size = x.shape[0]
        x = self.patch_embed(x)
        # Prepend one class token per sample, then add positional embeddings.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        if self.global_pool:
            # Average over patch tokens only (index 0 is the class token).
            x = x[:, 1:, :].mean(dim=1)
            return self.fc_norm(x)
        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        """Return class logits, i.e. ``head(forward_features(x))``."""
        return self.head(self.forward_features(x))
def vit_base_patch16(**kwargs):
    """Build a ViT-Base/16 ``vit`` (patch 16, 768-dim, 12 blocks, 12 heads).

    ``img_size`` defaults to 224 but can be overridden by the caller. All
    other keyword arguments (e.g. ``num_classes``, ``global_pool``) are
    forwarded to ``vit``.
    """
    # setdefault instead of a hard-coded keyword: passing img_size used to
    # raise "got multiple values for keyword argument 'img_size'".
    kwargs.setdefault('img_size', 224)
    return vit(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
               qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
def vit_large_patch16(**kwargs):
    """Build a ViT-Large/16 ``vit`` (patch 16, 1024-dim, 24 blocks, 16 heads).

    Extra keyword arguments (e.g. ``num_classes``, ``global_pool``) are
    forwarded unchanged to ``vit``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
def vit_huge_patch14(**kwargs):
    """Build a ViT-Huge/14 ``vit`` (patch 14, 1280-dim, 32 blocks, 16 heads).

    Extra keyword arguments (e.g. ``num_classes``, ``global_pool``) are
    forwarded unchanged to ``vit``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )