| from collections import OrderedDict |
| from typing import Tuple, Union |
| import logging |
| import os |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
| from timm.models.layers import DropPath, trunc_normal_ |
|
|
| from .registry import register_lang_encoder |
| from ..Utils import is_main_process |
| from ..Utils import register_norm_module |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
@register_norm_module
class LayerNorm(nn.Module):
    """TF-style layer normalization: epsilon is added inside the square root.

    Statistics are computed in float32 for numerical stability, and the
    normalized tensor is cast back to the input dtype before the affine
    (weight/bias) transform is applied.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super(LayerNorm, self).__init__()
        # Learnable affine parameters (gain and shift), one per feature.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        orig_dtype = x.dtype
        x32 = x.float()
        mean = x32.mean(dim=-1, keepdim=True)
        centered = x32 - mean
        # Biased variance over the last (feature) dimension.
        var = centered.pow(2).mean(dim=-1, keepdim=True)
        normed = centered / torch.sqrt(var + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype) + self.bias
|
|
|
|
class QuickGELU(nn.Module):
    """Sigmoid-based approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x
|
|
|
|
class ResidualAttentionBlock(nn.Module):
    """Pre-LN Transformer block: multi-head self-attention followed by a
    4x-expansion MLP, each wrapped in a residual connection with optional
    stochastic depth (DropPath).
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None,
                 drop_path: float = 0.0):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        # Submodule names ("c_fc"/"gelu"/"c_proj") are state-dict keys; keep them.
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Optional additive attention mask (e.g. causal); None disables it.
        self.attn_mask = attn_mask
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        # Keep the additive mask on the same device/dtype as the activations.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        attn_out, _ = self.attn(
            x, x, x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            attn_mask=self.attn_mask
        )
        return attn_out

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        # x is (seq_len, batch, d_model) — nn.MultiheadAttention's default layout.
        attn_out = self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)
        x = x + self.drop_path(attn_out)
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x
|
|
|
|
class Transformer(nn.Module):
    """CLIP-style text Transformer encoder.

    Embeds token ids, adds learned positional embeddings, and runs a stack of
    pre-LN residual attention blocks. In autoregressive mode a causal mask is
    applied; otherwise padding positions are masked via ``attention_mask``.
    """

    def __init__(self,
                 context_length: int,
                 vocab_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 drop_path: float = 0.0,
                 autogressive: bool = True,
                 key_padding_token: int = 0,
                 ):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, width)
        # NOTE(review): stored but never read in forward(); padding is expected
        # to arrive via `attention_mask` instead — confirm before relying on it.
        self.key_padding_token = key_padding_token

        self.context_length = context_length
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, width)
        )

        self.width = width
        self.layers = layers
        self.autogressive = autogressive
        # Causal mask only in autoregressive (GPT-style) mode; the
        # bidirectional encoder relies solely on the key padding mask.
        attn_mask = self.build_attention_mask() if autogressive else None
        # Stochastic-depth rate grows linearly from 0 to `drop_path` per layer.
        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
                for i in range(layers)
            ]
        )

        self.ln_final = LayerNorm(width)

        trunc_normal_(self.positional_embedding, std=.02)
        trunc_normal_(self.token_embedding.weight, std=.02)
        self.apply(self._init_weights)

    @property
    def dim_out(self):
        """Output feature dimension (model width)."""
        return self.width

    def build_attention_mask(self):
        """Return an additive causal mask: -inf above the diagonal, 0 elsewhere."""
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower triangle (incl. diagonal)
        return mask

    def _init_weights(self, m):
        """Init hook applied via ``self.apply``: trunc-normal weights, zero biases."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            if is_main_process():
                logger.info('=> init weight of Linear/Conv2d from trunc norm')
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                if is_main_process():
                    logger.info('=> init bias of Linear/Conv2d to zeros')
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)

    def load_pretrained(self, pretrained='', pretrained_layers=None, verbose=True):
        """Load matching weights from a checkpoint file.

        Args:
            pretrained: path to a ``torch.save``-d state dict; a non-file path
                (including the default '') makes this a no-op.
            pretrained_layers: top-level parameter-name prefixes to load, or
                ``['*']`` to load everything. Fixes: the original used a
                mutable default ``[]`` and indexed ``pretrained_layers[0]``
                unconditionally, which raised IndexError for an existing
                checkpoint with the default argument.
            verbose: log each loaded key.
        """
        if not os.path.isfile(pretrained):
            return
        pretrained_layers = pretrained_layers or []
        pretrained_dict = torch.load(pretrained, map_location='cpu')
        # Fix: use the module logger (was the root `logging` module).
        logger.info(f'=> loading pretrained model {pretrained}')
        model_dict = self.state_dict()
        # Keep only keys this model actually has.
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict
        }
        need_init_state_dict = {}
        for k, v in pretrained_dict.items():
            need_init = (
                k.split('.')[0] in pretrained_layers
                or (len(pretrained_layers) > 0 and pretrained_layers[0] == '*')
            )
            if need_init:
                if verbose:
                    logger.info(f'=> init {k} from {pretrained}')
                need_init_state_dict[k] = v
        self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay by the optimizer setup."""
        return {
            'positional_embedding',
            'token_embedding',
        }

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids.

        Args:
            input_ids: (batch, seq_len) token ids; moved to the model's device.
            attention_mask: optional (batch, seq_len) mask where 0 marks
                padding; only used in non-autoregressive mode.

        Returns:
            dict with 'last_hidden_state': (batch, seq_len, width).
        """
        input_ids = input_ids.to(self.positional_embedding.device, non_blocking=True)

        # Fix: the original evaluated `attention_mask == 0` even when
        # attention_mask was None, producing a bare `False` that was then
        # passed to nn.MultiheadAttention as key_padding_mask.
        key_padding_mask = None
        if not self.autogressive and attention_mask is not None:
            key_padding_mask = (attention_mask == 0)

        x = self.token_embedding(input_ids)
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # (B, L, D) -> (L, B, D) for MultiheadAttention
        for block in self.resblocks:
            x = block(x, key_padding_mask)
        x = x.permute(1, 0, 2)  # back to (B, L, D)

        x = self.ln_final(x)

        return {'last_hidden_state': x}
|
|
|
|
@register_lang_encoder
def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
    """Build a text Transformer encoder from a config dict and tokenizer.

    Fix: ``load_pretrained()`` was previously called with no arguments, so
    ``pretrained=''`` never passed the file check and pretrained loading was a
    guaranteed no-op. Pass the configured checkpoint path/layers instead
    (an absent 'PRETRAINED' key degrades to the old no-op behavior).
    """
    transformer = Transformer(
        context_length=config_encoder['CONTEXT_LENGTH'],
        vocab_size=tokenizer.vocab_size,
        width=config_encoder['WIDTH'],
        layers=config_encoder['LAYERS'],
        heads=config_encoder['HEADS'],
        autogressive=config_encoder.get('AUTOGRESSIVE', True),
        key_padding_token=config_encoder.get('KEY_PADDING_TOKEN', 0),
    )

    if config_encoder.get('LOAD_PRETRAINED', False):
        transformer.load_pretrained(
            config_encoder.get('PRETRAINED', ''),
            config_encoder.get('PRETRAINED_LAYERS', ['*']),
            verbose,
        )

    return transformer
|
|