import math
from typing import Dict, Optional, Sequence

import torch
from torch import Tensor, nn

import open_clip

from .mobile_clip_transformer import (
    PositionalEmbedding,
    TransformerEncoder,
    get_normalization_layer,
)


class TextTransformer(nn.Module):
    def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None:
        super().__init__()

        model_dim = cfg["dim"]
        no_scale_embedding = cfg.get("no_scale_embedding", False)
        no_pos_embedding = cfg.get("no_pos_embedding", False)
        embed_dropout = cfg.get("embed_dropout", 0.0)
        norm_layer = cfg["norm_layer"]
        variant = cfg["model_name"]
        self.vocab_size = cfg["vocab_size"]
        self.projection_dim = projection_dim

        # Token embedding table: one model_dim-dimensional vector per vocab entry.
        self.embedding_layer = nn.Embedding(
            embedding_dim=model_dim, num_embeddings=self.vocab_size
        )
        self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5
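        # NOTE: `embed_scale` (model_dim**-0.5 unless `no_scale_embedding` is
        # set) is stored here but not applied in `forward_embedding` below;
        # presumably any embedding scaling happens elsewhere in the codebase.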

        context_length = cfg["context_length"]
        assert (
            context_length is not None
        ), "Context length cannot be None; set `context_length` in the text config."

        # Optional positional embeddings, added to the token embeddings.
        self.positional_embedding = (
            None
            if no_pos_embedding
            else PositionalEmbedding(
                num_embeddings=context_length, embedding_dim=model_dim
            )
        )

        self.embedding_dropout = nn.Dropout(p=embed_dropout)

        n_transformer_layers = cfg["n_transformer_layers"]

        # FFN multipliers: a scalar is broadcast to all layers; otherwise one
        # multiplier per transformer layer is required.
        ffn_multipliers = cfg["ffn_multiplier_per_layer"]
        if isinstance(ffn_multipliers, (float, int)):
            ffn_multipliers = [ffn_multipliers] * n_transformer_layers

        if not isinstance(ffn_multipliers, Sequence):
            raise ValueError(
                "{} expects FFN multipliers as a list whose length equals the"
                " number of transformer layers. Got: {}".format(
                    self.__class__.__name__, type(ffn_multipliers)
                )
            )
        elif len(ffn_multipliers) != n_transformer_layers:
            raise ValueError(
                "We need one FFN multiplier per transformer layer. Got {} FFN"
                " multipliers while number of transformer layers = {}".format(
                    len(ffn_multipliers), n_transformer_layers
                )
            )
        # Round each FFN latent dimension up to a multiple of 16.
        ffn_dims = [
            int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0)
            for ffn_mult in ffn_multipliers
        ]
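        # Worked example of the rounding above (illustrative values, not from
        # any shipped config): with model_dim=512, ffn_mult=4.0 gives 2048,
        # while ffn_mult=3.3 gives ceil(512 * 3.3 / 16) * 16 = 106 * 16 = 1696.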

        # Number of attention heads: broadcast a scalar, else one per layer.
        mha_heads = cfg["n_heads_per_layer"]
        if isinstance(mha_heads, int):
            mha_heads = [mha_heads] * n_transformer_layers

        if not isinstance(mha_heads, Sequence):
            raise ValueError(
                "{} expects MHA heads as a list whose length equals the number"
                " of transformer layers. Got: {}".format(
                    self.__class__.__name__, type(mha_heads)
                )
            )
        elif len(mha_heads) != n_transformer_layers:
            raise ValueError(
                "{} needs MHA heads for each transformer layer. Got {} MHA heads"
                " while number of transformer layers = {}".format(
                    self.__class__.__name__, len(mha_heads), n_transformer_layers
                )
            )

        if variant == "base":
            self.transformer = nn.ModuleList(
                [
                    TransformerEncoder(
                        embed_dim=model_dim,
                        num_heads=mha_heads[layer_idx],
                        ffn_latent_dim=ffn_dims[layer_idx],
                        transformer_norm_layer=norm_layer,
                    )
                    for layer_idx in range(n_transformer_layers)
                ]
            )
        elif variant == "mct":
            raise NotImplementedError
        else:
            raise ValueError("Unrecognized text encoder variant {}".format(variant))

        self.final_layer_norm = get_normalization_layer(
            num_features=model_dim, norm_type=norm_layer
        )

        # Projection from the transformer width to the shared embedding space.
        self.projection_layer = nn.Parameter(
            torch.empty(model_dim, self.projection_dim)
        )
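        # NOTE: `torch.empty` leaves this parameter uninitialized; presumably a
        # weight-initialization routine elsewhere in the surrounding codebase
        # fills it in before the model is used.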
        self.model_dim = model_dim
        self.causal_masking = cfg["causal_masking"]

    def forward_embedding(self, text_tokens: Tensor) -> Tensor:
        """Return text embeddings for all tokens.

        Args:
            text_tokens: A tensor of token indices. Shape: [batch_size, context_length].

        Returns:
            A tensor of shape [batch_size, context_length, hidden_dim].
        """
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.embedding_layer(text_tokens)
        seq_len = token_emb.shape[1]
        if self.positional_embedding is not None:
            token_emb = token_emb + self.positional_embedding(seq_len).to(
                token_emb.dtype
            )
        token_emb = self.embedding_dropout(token_emb)
        return token_emb

    def build_attention_mask(self, context_length: int, batch_size: int) -> Tensor:
        """Build causal attention mask [batch_size, context_length, context_length]."""
        # Fill everything with -inf, then keep -inf only strictly above the
        # diagonal, so position i cannot attend to positions j > i.
        mask = torch.empty(context_length, context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and the lower triangle
        mask = mask.unsqueeze(0)  # [1, context_length, context_length]
        mask = mask.expand(batch_size, -1, -1)
        return mask
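
    # Illustration: for context_length=3, each batch entry of the mask is
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]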

    def encode_text(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs
    ) -> Tensor:
        """Return text token embeddings.

        Args:
            text_tokens: A tensor of token indices. Shape: [batch_size, context_length].
            key_padding_mask: An optional tensor of boolean values used as the
                padding mask. Shape: [batch_size, context_length].
            return_all_tokens: If True, return embeddings for all tokens; if
                False (the default), return only the EOT token embedding.

        Returns:
            A tensor of shape [batch_size, context_length, hidden_dim] if
            return_all_tokens is True, otherwise a tensor of shape
            [batch_size, hidden_dim].
        """
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.forward_embedding(text_tokens)

        # Build the causal mask if requested. The padding mask is dropped in
        # that case: padded positions come after the EOT token and cannot
        # influence it under causal attention.
        attn_mask = None
        if self.causal_masking:
            attn_mask = self.build_attention_mask(
                context_length=text_tokens.shape[1], batch_size=text_tokens.shape[0]
            )
            attn_mask = attn_mask.to(device=token_emb.device, dtype=token_emb.dtype)
            key_padding_mask = None

        for layer in self.transformer:
            token_emb = layer(
                token_emb,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )

        token_emb = self.final_layer_norm(token_emb)

        if return_all_tokens:
            return token_emb

        # Pool the EOT embedding: the EOT token has the highest id in the CLIP
        # vocabulary, so argmax over token ids locates it in each sequence.
        token_emb = token_emb[
            torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1)
        ]

        # Project into the shared image-text embedding space.
        token_emb = token_emb @ self.projection_layer
        return token_emb

    def forward(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs
    ) -> Tensor:
        return self.encode_text(
            text_tokens=text_tokens,
            key_padding_mask=key_padding_mask,
            return_all_tokens=return_all_tokens,
            *args,
            **kwargs
        )


class ClipTokenizer(nn.Module):
    def __init__(self, cfg, *args, **kwargs):
        super().__init__()
        self.context_length = cfg["text_cfg"]["context_length"]
        # Tokenizer variant; defaults to open_clip's ViT-B-16 tokenizer.
        model_name = cfg["text_cfg"].get("open_clip_tokenizer", "ViT-B-16")
        self.tokenizer = open_clip.get_tokenizer(model_name)

    def get_vocab_size(self) -> int:
        return len(self.tokenizer.encoder)

    def get_encodings(self) -> Dict[str, int]:
        return self.tokenizer.encoder

    def get_eot_token(self) -> int:
        # Tokenizing the empty string yields [sot, eot, 0, ...] in a tensor of
        # shape [1, context_length], so the EOT id sits at row 0, position 1.
        return int(self.tokenizer("")[0, 1])

    def get_sot_token(self) -> int:
        # The start-of-text token is always first in a tokenized sequence.
        return int(self.tokenizer("")[0, 0])

    def forward(self, input_sentence: str, *args, **kwargs) -> Tensor:
        # Tokenize and pad/truncate to exactly `context_length` tokens.
        tokenized_sentence = self.tokenizer(input_sentence, self.context_length)
        assert (
            tokenized_sentence.shape[-1] == self.context_length
        ), "Tokenized tensor should be exactly `context_length` long."
        return tokenized_sentence
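

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the config values below are assumptions,
# not values shipped with this module. Real configs come from the model's
# config files, and `norm_layer` must be a name that `get_normalization_layer`
# accepts. Run within the package so the relative import of
# `.mobile_clip_transformer` resolves.
# ---------------------------------------------------------------------------
# tokenizer = ClipTokenizer({"text_cfg": {"context_length": 77}})
# tokens = tokenizer("a photo of a dog")  # LongTensor of shape [1, 77]
# text_cfg = {
#     "dim": 512,
#     "vocab_size": tokenizer.get_vocab_size(),
#     "context_length": 77,
#     "norm_layer": "layer_norm",  # assumed name
#     "model_name": "base",
#     "n_transformer_layers": 12,
#     "ffn_multiplier_per_layer": 4.0,
#     "n_heads_per_layer": 8,
#     "causal_masking": True,
# }
# text_encoder = TextTransformer(cfg=text_cfg, projection_dim=512)
# text_features = text_encoder(tokens)  # shape [1, 512]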