from functools import partial
from typing import Callable, List, Optional

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from timm.models.layers import DropPath, trunc_normal_


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # Single projection to q, k, v: (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class MultiheadAttention(nn.MultiheadAttention):
    # Adapts torch.nn.MultiheadAttention to the (x, attn_mask) interface used by
    # BlockWithMasking: x serves as query, key and value (self-attention).
    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]


class ViTAttention(Attention):
    # Adapts Attention to the (x, attn_mask) interface; this implementation does
    # not support masking, so attn_mask must be None.
    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        assert attn_mask is None
        return super().forward(x)


class BlockWithMasking(nn.Module):
    def __init__(
        self,
        dim: int,
        attn_target: Callable,
        mlp_ratio: int = 4,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        ffn_dropout_rate: float = 0.0,
        drop_path: float = 0.0,
        layer_scale_type: Optional[str] = None,
        layer_scale_init_value: float = 1e-4,
    ):
        super().__init__()

        assert not isinstance(
            attn_target, nn.Module
        ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
        self.attn = attn_target()
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm_1 = norm_layer(dim)
        mlp_hidden_dim = int(mlp_ratio * dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=ffn_dropout_rate,
        )
        self.norm_2 = norm_layer(dim)
        self.layer_scale_type = layer_scale_type
        if self.layer_scale_type is not None:
            assert self.layer_scale_type in [
                "per_channel",
                "scalar",
            ], f"Unknown layer_scale_type: {self.layer_scale_type}"
            if self.layer_scale_type == "per_channel":
                # one gamma value per channel
                gamma_shape = [1, 1, dim]
            elif self.layer_scale_type == "scalar":
                # single gamma value for the whole block
                gamma_shape = [1, 1, 1]
            # two gammas: one for the attention residual, one for the FFN residual
            self.layer_scale_gamma1 = nn.Parameter(
                torch.ones(size=gamma_shape) * layer_scale_init_value,
                requires_grad=True,
            )
            self.layer_scale_gamma2 = nn.Parameter(
                torch.ones(size=gamma_shape) * layer_scale_init_value,
                requires_grad=True,
            )

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        if self.layer_scale_type is None:
            x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
            x = x + self.drop_path(self.mlp(self.norm_2(x)))
        else:
            x = (
                x
                + self.drop_path(self.attn(self.norm_1(x), attn_mask))
                * self.layer_scale_gamma1
            )
            x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
        return x


_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)


class SimpleTransformer(nn.Module):
    def __init__(
        self,
        attn_target: Callable,
        embed_dim: int,
        num_blocks: int,
        block: Callable = BlockWithMasking,
        pre_transformer_layer: Optional[Callable] = None,
        post_transformer_layer: Optional[Callable] = None,
        drop_path_rate: float = 0.0,
        drop_path_type: str = "progressive",
        norm_layer: Callable = _LAYER_NORM,
        mlp_ratio: int = 4,
        ffn_dropout_rate: float = 0.0,
        layer_scale_type: Optional[str] = None,
        layer_scale_init_value: float = 1e-4,
        weight_init_style: str = "jax",
    ):
        """
        Simple Transformer with the following features
        1. Supports masked attention
        2. Supports DropPath
        3. Supports LayerScale
        4. Supports Dropout in Attention and FFN
        5. Makes few assumptions about the input except that it is a Tensor
        """
        super().__init__()
        self.pre_transformer_layer = pre_transformer_layer
        if drop_path_type == "progressive":
            # linearly increase the drop-path rate from 0 to drop_path_rate across blocks
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
        elif drop_path_type == "uniform":
            dpr = [drop_path_rate for i in range(num_blocks)]
        else:
            raise ValueError(f"Unknown drop_path_type: {drop_path_type}")

        self.blocks = nn.Sequential(
            *[
                block(
                    dim=embed_dim,
                    attn_target=attn_target,
                    mlp_ratio=mlp_ratio,
                    ffn_dropout_rate=ffn_dropout_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    layer_scale_type=layer_scale_type,
                    layer_scale_init_value=layer_scale_init_value,
                )
                for i in range(num_blocks)
            ]
        )
        self.post_transformer_layer = post_transformer_layer
        self.weight_init_style = weight_init_style
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            if self.weight_init_style == "jax":
                # Xavier uniform, as in the JAX ViT implementation
                torch.nn.init.xavier_uniform_(m.weight)
            elif self.weight_init_style == "pytorch":
                # Truncated normal, as in the PyTorch ViT implementation
                trunc_normal_(m.weight, std=0.02)

            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        tokens: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        use_checkpoint: bool = False,
        checkpoint_every_n: int = 1,
        checkpoint_blk_ids: Optional[List[int]] = None,
    ):
        """
        Inputs
        - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
        - attn_mask: mask of shape L x L

        Output
        - x: data of shape N x L x D (or L x N x D depending on the attention implementation)
        """
        if self.pre_transformer_layer:
            tokens = self.pre_transformer_layer(tokens)
        if use_checkpoint and checkpoint_blk_ids is None:
            # by default, checkpoint every n-th block
            checkpoint_blk_ids = [
                blk_id
                for blk_id in range(len(self.blocks))
                if blk_id % checkpoint_every_n == 0
            ]
        if checkpoint_blk_ids:
            checkpoint_blk_ids = set(checkpoint_blk_ids)
        for blk_id, blk in enumerate(self.blocks):
            if use_checkpoint and blk_id in checkpoint_blk_ids:
                tokens = checkpoint.checkpoint(
                    blk, tokens, attn_mask, use_reentrant=False
                )
            else:
                tokens = blk(tokens, attn_mask=attn_mask)
        if self.post_transformer_layer:
            tokens = self.post_transformer_layer(tokens)
        return tokens
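

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API):
# building a SimpleTransformer whose attention is the MultiheadAttention
# wrapper above. attn_target must be a zero-argument callable (e.g. a
# functools.partial) so that every block constructs its own attention
# instance; the dimensions and hyperparameters below are arbitrary
# assumptions for the sketch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embed_dim = 64
    model = SimpleTransformer(
        attn_target=partial(
            MultiheadAttention, embed_dim=embed_dim, num_heads=8, batch_first=True
        ),
        embed_dim=embed_dim,
        num_blocks=2,
        drop_path_rate=0.1,
        layer_scale_type="per_channel",
    )
    tokens = torch.randn(4, 16, embed_dim)  # N x L x D (batch_first)
    out = model(tokens)
    assert out.shape == tokens.shape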
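    # Sketch of masked attention plus activation checkpointing: an additive
    # causal L x L mask with float('-inf') above the diagonal is the convention
    # expected by torch.nn.MultiheadAttention (an assumption tied to the
    # attn_target chosen above, not a requirement of SimpleTransformer itself).
    L = tokens.shape[1]
    causal_mask = torch.triu(torch.full((L, L), float("-inf")), diagonal=1)
    out_masked = model(
        tokens,
        attn_mask=causal_mask,
        use_checkpoint=True,
        checkpoint_every_n=2,  # checkpoint every other block
    )
    assert out_masked.shape == tokens.shape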