| """ |
| Implementation of the following modules is borrowed from ml-cvnets repo: |
| https://github.com/apple/ml-cvnets/blob/main/cvnets/layers/multi_head_attention.py |
| https://github.com/apple/ml-cvnets/blob/main/cvnets/text_encoders/transformer.py |
| |
| Please see ACKNOWLEDGEMENTS for license details. |
| """ |
|
|
import warnings
from typing import List, Optional, Union
|
|
import torch
from torch import Size, Tensor, nn
from torch.nn import functional as F
from torchvision.ops import StochasticDepth
|
|
|
|
class LayerNormFP32(nn.LayerNorm):
    """
    Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over an input tensor,
    computing the normalization in FP32 precision.
    """

    def __init__(
        self,
        normalized_shape: Union[int, List[int], Size],
        eps: Optional[float] = 1e-5,
        elementwise_affine: Optional[bool] = True,
        *args,
        **kwargs,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            *args,
            **kwargs,
        )

    def forward(self, x: Tensor) -> Tensor:
        # Normalize in float32 and cast the result back to the input dtype.
        inp_dtype = x.dtype
        return super().forward(x.to(torch.float32)).to(inp_dtype)
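

# Minimal usage sketch (illustrative only, not part of the original module):
# normalizing a half-precision tensor. Statistics are computed in float32 and the
# output is cast back to float16.
#
#   norm = LayerNormFP32(512)
#   x = torch.randn(2, 16, 512, dtype=torch.float16)
#   y = norm(x)  # y.dtype == torch.float16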
|
|
|
|
def get_normalization_layer(norm_type: str, num_features: int) -> nn.Module:
    """Build a normalization layer from its name."""
    if norm_type == "layer_norm":
        return nn.LayerNorm(num_features)
    elif norm_type == "layer_norm_fp32":
        return LayerNormFP32(num_features)
    else:
        raise NotImplementedError(f"Option: {norm_type} not supported.")
|
|
|
|
class PositionalEmbedding(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        is_learnable: Optional[bool] = False,
        interpolation_mode: Optional[str] = "bilinear",
        *args,
        **kwargs,
    ):
        super().__init__()
        # Only the learnable variant is implemented in this module;
        # `is_learnable` is accepted but unused.
        module = LearnablePositionalEmbedding

        self.pos_embed = module(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            interpolation_mode=interpolation_mode,
            *args,
            **kwargs,
        )

    def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
        return self.pos_embed(seq_len, *args, **kwargs)

    def __repr__(self):
        return self.pos_embed.__repr__()
|
|
|
|
class LearnablePositionalEmbedding(nn.Module):
    """Learnable positional embedding"""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        interpolation_mode: Optional[str] = "bilinear",
        *args,
        **kwargs,
    ):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.empty(1, 1, num_embeddings, embedding_dim))
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.padding_idx = padding_idx
        self.interpolation_mode = interpolation_mode

        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.pos_embed, mean=0, std=self.embedding_dim**-0.5)
        if self.padding_idx is not None:
            with torch.no_grad():
                self.pos_embed[:, :, self.padding_idx, ...] = 0.0

    def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
        # Keep the embedding at the padding index at zero.
        pos_embed = self.pos_embed
        if self.padding_idx is not None:
            with torch.no_grad():
                pos_embed[:, :, self.padding_idx, ...] = 0.0

        # Interpolate the positional embeddings when the requested sequence length
        # differs from the number of embeddings learned during training.
        if seq_len != self.num_embeddings:
            pos_embed = F.interpolate(
                pos_embed,
                size=(seq_len, self.embedding_dim),
                mode=self.interpolation_mode,
            )

        # [1, 1, seq_len, embedding_dim] --> [1, seq_len, embedding_dim]
        return pos_embed.reshape(1, seq_len, self.embedding_dim)

    def __repr__(self):
        return "{}(num_embeddings={}, embedding_dim={}, padding_idx={})".format(
            self.__class__.__name__,
            self.num_embeddings,
            self.embedding_dim,
            self.padding_idx,
        )
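

# Minimal usage sketch (illustrative only, not part of the original module):
# positional embeddings learned for 77 positions are bilinearly interpolated when a
# different sequence length is requested.
#
#   pos_embed = PositionalEmbedding(num_embeddings=77, embedding_dim=512, is_learnable=True)
#   pe = pos_embed(seq_len=77)   # shape: [1, 77, 512]
#   pe = pos_embed(seq_len=100)  # shape: [1, 100, 512], interpolated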
|
|
|
|
class MultiHeadAttention(nn.Module):
    """
    This layer applies multi-head self- or cross-attention as described in the
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper.

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})`
        num_heads (int): Number of heads in multi-head attention
        attn_dropout (Optional[float]): Attention dropout. Default: 0.0
        bias (Optional[bool]): Use bias or not. Default: ``True``

    Shape:
        - Input:
            - Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is the batch size,
              :math:`S` is the number of source tokens, and :math:`C_{in}` is the input embedding dim
            - Optional key-value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is the number of target tokens
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_dropout: Optional[float] = 0.0,
        bias: Optional[bool] = True,
        output_dim: Optional[int] = None,
        *args,
        **kwargs,
    ) -> None:
        if output_dim is None:
            output_dim = embed_dim
        super().__init__()
        if embed_dim % num_heads != 0:
            warnings.warn(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )

        self.qkv_proj = nn.Linear(
            in_features=embed_dim, out_features=3 * embed_dim, bias=bias
        )

        self.attn_dropout = nn.Dropout(p=attn_dropout)
        self.out_proj = nn.Linear(
            in_features=embed_dim, out_features=output_dim, bias=bias
        )

        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5
        self.softmax = nn.Softmax(dim=-1)
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.use_separate_proj_weight = embed_dim != output_dim

    def __repr__(self):
        return "{}(head_dim={}, num_heads={}, attn_dropout={})".format(
            self.__class__.__name__, self.head_dim, self.num_heads, self.attn_dropout.p
        )

    def _forward_impl(
        self,
        x_q: Tensor,
        x_kv: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # [N, S, C]
        b_sz, S_len, in_channels = x_q.shape

        if x_kv is None:
            # Self-attention
            # [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] where C = h * c
            qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1)
            # [N, S, 3, h, c] --> [N, h, 3, S, c]
            qkv = qkv.transpose(1, 3).contiguous()

            # [N, h, 3, S, c] --> [N, h, S, c] x 3
            query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        else:
            # Cross-attention
            T_len = x_kv.shape[1]

            # Queries come from x_q, projected with the query slice of qkv_proj
            # [N, S, C] --> [N, S, C]
            query = F.linear(
                x_q,
                weight=self.qkv_proj.weight[: self.embed_dim, ...],
                bias=self.qkv_proj.bias[: self.embed_dim]
                if self.qkv_proj.bias is not None
                else None,
            )
            # [N, S, C] --> [N, S, h, c] --> [N, h, S, c]
            query = (
                query.reshape(b_sz, S_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
                .contiguous()
            )

            # Keys and values come from x_kv, projected with the key-value slice of qkv_proj
            # [N, T, C] --> [N, T, 2C]
            kv = F.linear(
                x_kv,
                weight=self.qkv_proj.weight[self.embed_dim :, ...],
                bias=self.qkv_proj.bias[self.embed_dim :]
                if self.qkv_proj.bias is not None
                else None,
            )
            # [N, T, 2C] --> [N, T, 2, h, c]
            kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim)
            # [N, T, 2, h, c] --> [N, h, 2, T, c]
            kv = kv.transpose(1, 3).contiguous()
            key, value = kv[:, :, 0], kv[:, :, 1]

        query = query * self.scaling

        # [N, h, T, c] --> [N, h, c, T]
        key = key.transpose(-1, -2)

        # Attention scores
        # [N, h, S, c] x [N, h, c, T] --> [N, h, S, T]
        attn = torch.matmul(query, key)

        batch_size, num_heads, num_src_tokens, num_tgt_tokens = attn.shape
        if attn_mask is not None:
            # attn_mask is an additive mask of shape [N, S, T]
            assert list(attn_mask.shape) == [
                batch_size,
                num_src_tokens,
                num_tgt_tokens,
            ], "Shape of attention mask should be [{}, {}, {}]. Got: {}".format(
                batch_size, num_src_tokens, num_tgt_tokens, attn_mask.shape
            )
            # [N, S, T] --> [N, 1, S, T]
            attn_mask = attn_mask.unsqueeze(1)
            attn = attn + attn_mask

        if key_padding_mask is not None:
            # Do not attend to padding positions
            # key_padding_mask has shape [N, T]
            assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [
                batch_size,
                num_tgt_tokens,
            ], "Key_padding_mask should be 2-dimensional with shape [{}, {}]. Got: {}".format(
                batch_size, num_tgt_tokens, key_padding_mask.shape
            )
            attn = attn.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        # Compute the softmax in FP32 for numerical stability, then cast back
        attn_dtype = attn.dtype
        attn_as_float = self.softmax(attn.float())
        attn = attn_as_float.to(attn_dtype)
        attn = self.attn_dropout(attn)

        # Weighted sum of values
        # [N, h, S, T] x [N, h, T, c] --> [N, h, S, c]
        out = torch.matmul(attn, value)

        # [N, h, S, c] --> [N, S, h, c] --> [N, S, C]
        out = out.transpose(1, 2).reshape(b_sz, S_len, -1)
        out = self.out_proj(out)

        return out

    def forward(
        self,
        x_q: Tensor,
        x_kv: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        *args,
        **kwargs,
    ) -> Tensor:
        return self._forward_impl(
            x_q=x_q,
            x_kv=x_kv,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )
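

# Minimal usage sketch (illustrative only, not part of the original module):
#
#   mha = MultiHeadAttention(embed_dim=512, num_heads=8)
#   x_q = torch.randn(2, 16, 512)    # [N, S, C]
#   y = mha(x_q)                     # self-attention: [2, 16, 512]
#
#   x_kv = torch.randn(2, 24, 512)   # [N, T, C]
#   y = mha(x_q, x_kv=x_kv)          # cross-attention: [2, 16, 512]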
|
|
|
|
class TransformerEncoder(nn.Module):
    """
    This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_.

    Args:
        embed_dim: :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`.
        ffn_latent_dim: Inner dimension of the FFN.
        num_heads: Number of heads in multi-head attention. Default: 8.
        attn_dropout: Dropout rate for attention in multi-head attention. Default: 0.0.
        dropout: Dropout rate. Default: 0.0.
        ffn_dropout: Dropout between FFN layers. Default: 0.0.
        transformer_norm_layer: Normalization layer. Default: layer_norm.
        stochastic_dropout: Stochastic depth drop probability. Default: 0.0.

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is the batch size, :math:`P` is the number
          of patches, and :math:`C_{in}` is the input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_latent_dim: int,
        num_heads: Optional[int] = 8,
        attn_dropout: Optional[float] = 0.0,
        dropout: Optional[float] = 0.0,
        ffn_dropout: Optional[float] = 0.0,
        transformer_norm_layer: Optional[str] = "layer_norm",
        stochastic_dropout: Optional[float] = 0.0,
        *args,
        **kwargs,
    ) -> None:

        super().__init__()

        # Pre-norm multi-head attention block
        attn_unit = MultiHeadAttention(
            embed_dim,
            num_heads,
            attn_dropout=attn_dropout,
            bias=True,
        )

        self.pre_norm_mha = nn.Sequential(
            get_normalization_layer(
                norm_type=transformer_norm_layer, num_features=embed_dim
            ),
            attn_unit,
            nn.Dropout(p=dropout),
        )

        # Pre-norm feed-forward network block
        act_name = nn.GELU()
        self.pre_norm_ffn = nn.Sequential(
            get_normalization_layer(
                norm_type=transformer_norm_layer, num_features=embed_dim
            ),
            nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
            act_name,
            nn.Dropout(p=ffn_dropout),
            nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
            nn.Dropout(p=dropout),
        )

        self.drop_path = nn.Identity()
        if stochastic_dropout > 0.0:
            if dropout > 0.0:
                warnings.warn(
                    "Stochastic dropout and dropout are mutually exclusive. "
                    "Use either of them, but not both. "
                    "Got: {} and {}".format(stochastic_dropout, dropout)
                )
            self.drop_path = StochasticDepth(p=stochastic_dropout, mode="row")

        self.embed_dim = embed_dim
        self.ffn_dim = ffn_latent_dim
        self.ffn_dropout = ffn_dropout
        self.stochastic_dropout = stochastic_dropout
        self.std_dropout = dropout
        self.attn_fn_name = attn_unit.__class__.__name__
        self.act_fn_name = act_name.__class__.__name__
        self.norm_type = transformer_norm_layer

    def __repr__(self) -> str:
        return "{}(embed_dim={}, ffn_dim={}, dropout={}, ffn_dropout={}, stochastic_dropout={}, attn_fn={}, act_fn={}, norm_fn={})".format(
            self.__class__.__name__,
            self.embed_dim,
            self.ffn_dim,
            self.std_dropout,
            self.ffn_dropout,
            self.stochastic_dropout,
            self.attn_fn_name,
            self.act_fn_name,
            self.norm_type,
        )

    def forward(
        self,
        x: Tensor,
        x_prev: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        *args,
        **kwargs,
    ) -> Tensor:

        # Multi-head attention with a residual connection
        res = x
        x = self.pre_norm_mha[0](x)  # normalization
        x = self.pre_norm_mha[1](  # multi-head attention
            x_q=x,
            x_kv=x_prev,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            *args,
            **kwargs,
        )
        x = self.drop_path(self.pre_norm_mha[2](x))  # dropout
        x = x + res

        # Feed-forward network with a residual connection
        x = x + self.drop_path(self.pre_norm_ffn(x))
        return x
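

# Minimal usage sketch (illustrative only, not part of the original module):
# a single pre-norm encoder block applied to a batch of token embeddings with a
# causal additive attention mask of shape [N, S, S] (0 for visible positions,
# -inf for masked ones).
#
#   block = TransformerEncoder(embed_dim=512, ffn_latent_dim=2048, num_heads=8)
#   x = torch.randn(2, 16, 512)                                   # [N, P, C]
#   mask = torch.full((16, 16), float("-inf")).triu(diagonal=1)   # causal mask
#   mask = mask.unsqueeze(0).expand(2, -1, -1)                    # [N, P, P]
#   y = block(x, attn_mask=mask)                                  # [2, 16, 512]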
|
|