| | from functools import partial |
| | from typing import Optional, Tuple, Type |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer |
| | from segment_anything.modeling.common import LayerNorm2d |
| | from segment_anything.modeling.image_encoder import ( |
| | Block, |
| | PatchEmbed, |
| | window_partition, |
| | window_unpartition, |
| | ) |
| |
|
| |
|
class CustomBlock(Block):
    """``Block`` subclass with an explicitly spelled-out ``forward``.

    Construction is delegated unchanged to ``Block``; only the forward pass
    is overridden here.
    """

    def __init__(self, **kargs) -> None:
        # All keyword arguments are forwarded verbatim to the parent Block.
        super().__init__(**kargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply normalised attention and the MLP, each with a residual add.

        Args:
            x: Input tensor. Axes 1 and 2 are read as the spatial height and
                width when windowed attention is active.

        Returns:
            The tensor after the attention and MLP residual branches.
        """
        residual = x
        normed = self.norm1(x)

        if self.window_size <= 0:
            # No windowing configured: attend over the full input.
            attended = self.attn(normed)
        else:
            # Remember the pre-partition spatial size so it can be restored.
            height, width = normed.shape[1], normed.shape[2]
            windows, pad_hw = window_partition(normed, self.window_size)
            windows = self.attn(windows)
            attended = window_unpartition(
                windows, self.window_size, pad_hw, (height, width)
            )

        out = residual + attended
        return out + self.mlp(self.norm2(out))
| |
|
| |
|
class CustomImageEncoderViT(nn.Module):
    """Image encoder assembled from ``PatchEmbed``, ``CustomBlock`` layers and a conv neck.

    The forward pass embeds the input into patches, optionally adds a learned
    absolute position embedding, runs every block in order, then moves the
    last axis to position 1 and applies the convolutional neck.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """Build the patch embedding, position embedding, blocks and neck.

        Args:
            img_size: Input image side length; stored on the instance and
                used with ``patch_size`` to derive the patch-grid side.
            patch_size: Kernel size and stride handed to ``PatchEmbed``.
            in_chans: Input channel count handed to ``PatchEmbed``.
            embed_dim: Embedding width used by the patch embedding, the
                blocks and the first neck convolution.
            depth: Number of ``CustomBlock`` layers.
            num_heads: Forwarded to every block.
            mlp_ratio: Forwarded to every block.
            out_chans: Channel count produced by the neck.
            qkv_bias: Forwarded to every block.
            norm_layer: Forwarded to every block.
            act_layer: Forwarded to every block.
            use_abs_pos: When true, create a zero-initialised learned
                position embedding that is added after patch embedding.
            use_rel_pos: Forwarded to every block.
            rel_pos_zero_init: Forwarded to every block.
            window_size: Window size given to blocks whose index is not in
                ``global_attn_indexes``.
            global_attn_indexes: Block indices that receive ``window_size=0``.
        """
        super().__init__()
        self.img_size = img_size

        # Side length of the patch grid, shared by the position embedding
        # and each block's ``input_size``.
        grid_side = img_size // patch_size
        patch_hw = (patch_size, patch_size)

        self.patch_embed = PatchEmbed(
            kernel_size=patch_hw,
            stride=patch_hw,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, grid_side, grid_side, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for layer_idx in range(depth):
            # Blocks listed in global_attn_indexes get window_size 0.
            uses_global_attn = layer_idx in global_attn_indexes
            self.blocks.append(
                CustomBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_rel_pos=use_rel_pos,
                    rel_pos_zero_init=rel_pos_zero_init,
                    window_size=0 if uses_global_attn else window_size,
                    input_size=(grid_side, grid_side),
                )
            )

        # Neck: 1x1 conv to out_chans, norm, 3x3 conv (padding 1), norm.
        neck_layers = [
            nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False),
            LayerNorm2d(out_chans),
            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
            LayerNorm2d(out_chans),
        ]
        self.neck = nn.Sequential(*neck_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return the output of the neck.

        Args:
            x: Input tensor accepted by the patch embedding.

        Returns:
            The neck output computed from the final block's features.
        """
        tokens = self.patch_embed(x)
        if self.pos_embed is not None:
            tokens = tokens + self.pos_embed

        for blk in self.blocks:
            tokens = blk(tokens)

        # Move the last axis to position 1 before the convolutional neck.
        return self.neck(tokens.permute(0, 3, 1, 2))
| |
|
| |
|
def _build_sam_torchscript(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """Assemble a ``Sam`` model that uses ``CustomImageEncoderViT`` as its image encoder.

    Args:
        encoder_embed_dim: ``embed_dim`` passed to the image encoder.
        encoder_depth: ``depth`` (number of blocks) passed to the image encoder.
        encoder_num_heads: ``num_heads`` passed to the image encoder.
        encoder_global_attn_indexes: ``global_attn_indexes`` passed to the
            image encoder.
        checkpoint: Optional path to a file containing a state dict. When
            given, it is loaded into the model before returning.

    Returns:
        The constructed ``Sam`` instance, switched to eval mode.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    # Side length of the patch grid, shared with the prompt encoder.
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=CustomImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        # NOTE(security): torch.load unpickles the file; only pass
        # checkpoints from a trusted source.
        with open(checkpoint, "rb") as f:
            # Deserialize onto the CPU so a checkpoint saved from CUDA
            # tensors still loads on a host without a GPU. load_state_dict
            # copies values into the existing parameters, so the resulting
            # model is the same either way.
            state_dict = torch.load(f, map_location="cpu")
        sam.load_state_dict(state_dict)
    return sam
| |
|
| |
|
def build_sam_vit_h_torchscript(checkpoint=None):
    """Build the model through ``_build_sam_torchscript`` with the ViT-H encoder settings.

    Args:
        checkpoint: Optional checkpoint path, forwarded unchanged.

    Returns:
        Whatever ``_build_sam_torchscript`` returns for this configuration.
    """
    vit_h_encoder = {
        "encoder_embed_dim": 1280,
        "encoder_depth": 32,
        "encoder_num_heads": 16,
        "encoder_global_attn_indexes": [7, 15, 23, 31],
    }
    return _build_sam_torchscript(checkpoint=checkpoint, **vit_h_encoder)
| |
|