| |
|
|
| from dataclasses import asdict, dataclass, field |
| from typing import Dict, List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
@dataclass
class RotaryEmbeddingConfig:
    """
    Rotary Positional Embedding (RoPE) configuration.

    Attributes:
        max_seq_len: The number of positions to encode and cache (size of the
            precomputed sin/cos table).
        dim: Number of head dimensions that RoPE rotates; remaining head
            dimensions pass through unrotated. The sin/cos packing in
            TorchRotaryEmbedding effectively assumes an even value.
        theta: Base used to derive the inverse rotation frequencies
            (inv_freq = 1 / theta**(2i / dim)); larger theta gives slower
            rotation for higher dimensions.
    """

    max_seq_len: int
    dim: int
    theta: float
|
|
|
|
@dataclass
class PerceiverResamplerConfig:
    """
    Parameters to initialize a PerceiverResampler model.

    Args:
        emb_layer_norm_before: Whether to use layer norm before the first attention
            layer.
        attention_heads: Number of attention heads.
        key_size: The dimension of the query, key, and values within each attention
            head, if not specified, it is set to embed_dim // attention_heads.
            It can be useful to set a custom key size if we want to impose the size of
            the query, key and value tensor ( for example, tensors shaped with
            power of 2 are more efficiently handled on TPUs ).
            Note: Parametrizing the model with a custom key size has been done in :
            Brown, Tom, et al. "Language models are few-shot learners."
            Advances in neural information processing systems 33 (2020): 1877-1901.
        embed_dim: Embedding dimension.
        ffn_embed_dim: Feed forward embedding dimension.
        num_layers: Number of attention blocks.
        add_bias_kv: Add bias in the attention key/value projections.
        add_bias_ffn: Add bias in the feed forward network block.
        ffn_activation_name: Activation function to be used in FFN block. Supported
            names are "gelu", "relu", "swish".
        use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
            Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
            to True and use swish as ffn_activation_name.
            Same principle for a gated-relu. To keep the same number of parameters in
            the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
            See https://arxiv.org/pdf/2002.05202.pdf for more details.
        resampled_length: length of the resampled output of the module
        use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
            gradients in the forward pass to reduce the computation in the backward).
    """

    emb_layer_norm_before: bool = False
    attention_heads: int = 20
    key_size: Optional[int] = None
    embed_dim: int = 1280
    ffn_embed_dim: int = 5120
    num_layers: int = 24
    add_bias_kv: bool = False
    add_bias_ffn: bool = True
    ffn_activation_name: str = "gelu-no-approx"
    use_glu_in_ffn: bool = False
    resampled_length: int = 64

    use_gradient_checkpointing: bool = False

    def __post_init__(self) -> None:
        """
        Checks that the given values are compatible.

        Raises:
            ValueError: If key_size is unset and embed_dim is not divisible by
                attention_heads.
        """
        # Derive the per-head key size when the caller did not fix it.
        if self.key_size is None:
            if not self.embed_dim % self.attention_heads == 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension should be "
                    f"divisible by the number of heads, however provided embedding "
                    f"dimension is {self.embed_dim} and the number of heads is "
                    f"{self.attention_heads}."
                )
            self.key_size = self.embed_dim // self.attention_heads
|
|
|
|
@dataclass
class GptConfig:
    """
    Parameters to initialize a Gpt model.

    NOTE: the pad token is not defined

    Args:
        vocab_size: Token vocabulary.
        eos_token_id: used to stop sentence generation
        embed_dim: Embedding dimension.
        ffn_embed_dim: Feed forward embedding dimension.
        num_heads: Number of attention heads.
        num_kv_heads: Number of key and value heads to support Grouped-Query and
            Multi-Query Attention. If None, the number of key and value heads is
            equal to the number of attention heads.
        num_layers: Number of Decoder layer_stack
        rope_config: The configuration for the rotary positional embeddings
        add_bias_ffn: Add bias in feed forward network block.
        ffn_activation_name: Activation function to be used in FFN block. Supported
            names are "gelu", "gelu-no-approx", "relu", "swish".
        use_glu_in_ffn: whether to use Gated Linear Unit (GLU) in Feed
            Forward Network (FFN) block.
            example: To do a swiGLU (gated-swish) put this arg
            to True and use swish as ffn_activation_name.
            Same principle for a gated-relu.
        add_bias_lm_head: whether to use bias in the final LM layer
        norm_type: The type of norm used ( pre normalization scheme ) used. can be
            one of ["layer_norm", "RMS_norm"]
        parallel_attention_ff: Whether to do the attention and the MLP in parallel,
            and then sum up the results as it is done in Gpt-NeoX :
            Black, Sid, et al. "Gpt-neox-20b: An open-source autoregressive
            language model." arXiv preprint arXiv:2204.06745 (2022).
            It is said to improve the training time of 15% when compiling with JAX
        use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
            gradients in the forward pass to reduce the computation in the backward).
        add_bias_attn: Add bias to the attention mechanism (key, query, value, and
            output projections).
    """

    vocab_size: int
    eos_token_id: int

    embed_dim: int = 16
    ffn_embed_dim: int = 64
    num_heads: int = 2
    num_kv_heads: Optional[int] = None
    num_layers: int = 2
    # Forward reference (string annotation) so this class does not require
    # RotaryEmbeddingConfig to be defined at class-creation time.
    rope_config: "RotaryEmbeddingConfig" = field(
        default_factory=lambda: RotaryEmbeddingConfig(
            max_seq_len=512, dim=8, theta=10000.0
        )
    )
    add_bias_ffn: bool = False
    ffn_activation_name: str = "swish"
    use_glu_in_ffn: bool = True
    add_bias_lm_head: bool = False
    norm_type: str = "RMS_norm"
    rms_norm_eps: float = 1e-6
    parallel_attention_ff: bool = True

    use_gradient_checkpointing: bool = False

    add_bias_attn: bool = False

    def __post_init__(self) -> None:
        """
        Checks that the given values are compatible.

        Raises:
            ValueError: If embed_dim is not divisible by num_heads, or if the
                per-head dimension is too small for rotary embeddings.
        """
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"The embedding dimension should be "
                f"divisible by the number of heads, however provided embedding "
                f"dimension is {self.embed_dim} and the number of heads is "
                f"{self.num_heads}."
            )

        # The check is `head_dim >= 2`; the previous message incorrectly said
        # "higher than 2".
        if not self.embed_dim // self.num_heads > 1:
            raise ValueError(
                "embed_dim // num_heads must be at least 2 to apply rotary embeddings"
            )

        # The check is `head_dim >= rope dim`; the previous message incorrectly
        # said "higher than".
        if not self.embed_dim // self.num_heads >= self.rope_config.dim:
            raise ValueError(
                "embed_dim // num_heads must be at least rope_config.dim "
                "to apply rotary embeddings"
            )

    def to_dict(self) -> Dict:
        """Serialize the config (including the nested rope config) to a dict."""
        output = asdict(self)
        output["rope_config"] = asdict(self.rope_config)
        return output
|
|
|
|
@dataclass
class NucleotideTransformerConfig:
    """
    Parameters to initialize an NT model.

    Args:
        alphabet_size: Token vocabulary.
        pad_token_id: ID of pad token.
        mask_token_id: ID of mask token.
        max_positions: Maximum sequence length.
        embed_scale: Correction ratio applied to the embeddings to make up for the
            norm difference between the input during training and inference.
        emb_layer_norm_before: Whether to use layer norm before the first attention
            layer.
        attention_heads: Number of attention heads.
        key_size: The dimension of the query, key, and values within each attention
            head, if not specified, it is set to embed_dim // attention_heads.
            It can be useful to set a custom key size if we want to impose the size of
            the query, key and value tensor ( for example, tensors shaped with
            power of 2 are more efficiently handled on TPUs ).
            Note: Parametrizing the model with a custom key size has been done in :
            Brown, Tom, et al. "Language models are few-shot learners."
            Advances in neural information processing systems 33 (2020): 1877-1901.
        embed_dim: Embedding dimension.
        ffn_embed_dim: Feed forward embedding dimension.
        num_layers: Number of attention blocks.
        positional_embedding: Type of positional embedding to use before the first
            attention layer. Options: "learned", "learned_standard", "sinusoidal",
            "alibi_dnabert_2" or None.
            NOTE: "learned" is the positional embedding of ESM, and "learned_standard"
            is a more standard one, used for example in DNAbert.
        lm_head: type of language model head. Options: "simple", "roberta" or None.
        add_bias_kv: Add bias in attention layer.
        add_bias_ffn: Add bias in feed forward network block.
        use_rotary_embedding: Whether to use rotary embeddings. Requires:
            positional_embeddings = None.
        rescaling_factor: Scaling factor to use for rotary embeddings.
        ffn_activation_name: Activation function to be used in FFN block. Supported
            names are "gelu", "relu", "swish".
        use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
            Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
            to True and use swish as ffn_activation_name.
            Same principle for a gated-relu. To keep the same number of parameters in
            the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
            See https://arxiv.org/pdf/2002.05202.pdf for more details.
        mask_before_attention: Use mask before attention layers.
        layer_norm_eps: the eps factor in the different layer norms of the model (refer
            to layer norm implementation)
        token_dropout: Token dropout.
        masking_ratio: Masking ratio (used if token dropout is enabled).
        masking_prob: Masking probability (used if token dropout is enabled).
        use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
            gradients in the forward pass to reduce the computation in the backward).
    """

    alphabet_size: int
    pad_token_id: int
    mask_token_id: int

    max_positions: int = 1024
    embed_scale: float = 1.0

    emb_layer_norm_before: bool = False
    attention_heads: int = 20
    key_size: Optional[int] = None
    embed_dim: int = 1280
    ffn_embed_dim: int = 5120
    num_layers: int = 24
    positional_embedding: Optional[str] = "learned"
    lm_head: Optional[str] = "simple"
    add_bias_kv: bool = False
    add_bias_ffn: bool = True
    use_rotary_embedding: bool = False
    rescaling_factor: Optional[float] = None
    ffn_activation_name: str = "gelu-no-approx"
    use_glu_in_ffn: bool = False
    mask_before_attention: bool = False
    layer_norm_eps: float = 1e-5
    pre_layer_norm: bool = True
    bias_word_embedding: bool = False

    token_dropout: bool = False
    masking_ratio: float = 0.1
    masking_prob: float = 0.8

    use_gradient_checkpointing: bool = False

    embeddings_layers_to_save: List[int] = field(default_factory=list)
    attention_maps_to_save: List[Tuple[int, int]] = field(default_factory=list)

    def __post_init__(self) -> None:
        """
        Checks that the given values are compatible.

        Raises:
            TypeError: If positional_embedding or lm_head is neither None nor a
                string.
            ValueError: If any combination of arguments is inconsistent.
        """
        # Derive the per-head key size when the caller did not fix it.
        if self.key_size is None:
            if self.embed_dim % self.attention_heads != 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension should be "
                    f"divisible by the number of heads, however provided embedding "
                    f"dimension is {self.embed_dim} and the number of heads is "
                    f"{self.attention_heads}."
                )
            self.key_size = self.embed_dim // self.attention_heads
        if self.positional_embedding is not None:
            # isinstance (not type comparison) so str subclasses are accepted.
            if not isinstance(self.positional_embedding, str):
                raise TypeError("positional_embedding must be a string or None.")

            if self.positional_embedding not in [
                "learned",
                "sinusoidal",
                "learned_standard",
                "alibi_dnabert_2",
            ]:
                raise ValueError(
                    "The positional_embedding argument should either be None,"
                    "`learned`, `sinusoidal`, 'learned_standard' or 'alibi_dnabert_2'."
                )
        if self.lm_head is not None:
            if not isinstance(self.lm_head, str):
                raise TypeError("lm_head must be a string or None.")

            if self.lm_head not in ["simple", "roberta"]:
                raise ValueError(
                    "The lm_head argument should either be None,"
                    "`simple` or `roberta`."
                )

        if self.use_rotary_embedding and self.positional_embedding is not None:
            raise ValueError(
                "When using rotary embedding, positional_embedding must be set to none"
            )

        if self.add_bias_kv and self.use_rotary_embedding:
            raise ValueError(
                "Biases on key and values are not compatible with Rotary embeddings."
            )

        # Was a bare `assert`, which silently disappears under `python -O`.
        if self.positional_embedding == "alibi_dnabert_2" and self.add_bias_kv:
            raise ValueError(
                "The alibi_dnabert_2 positional embedding is not compatible "
                "with add_bias_kv."
            )
|
|
|
|
@dataclass
class ChatNTConfig(PretrainedConfig):
    """HuggingFace config bundling the three ChatNT sub-model configurations
    (GPT decoder, nucleotide-transformer encoder, perceiver resampler) plus the
    special token ids."""

    model_type = "ChatNT"

    def __init__(self, **kwargs):
        # Each sub-config falls back to the reference ChatNT hyper-parameters
        # when not supplied explicitly.
        self.gpt_config: GptConfig = kwargs.get("gpt_config", GptConfig(32000, 3))
        self.nt_config: NucleotideTransformerConfig = kwargs.get(
            "nt_config", NucleotideTransformerConfig(4000, 1, 4)
        )
        self.perceiver_resampler_config: PerceiverResamplerConfig = kwargs.get(
            "perceiver_resampler_config", PerceiverResamplerConfig()
        )
        self.seq_token_id: int = kwargs.get("seq_token_id", 32000)
        self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1)
        self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2)
        super().__init__(**kwargs)

    def to_dict(self):
        """Serialize to a plain dict, recursing into the three sub-configs."""
        output = super().to_dict()

        # Prefer each sub-config's own serializer; fall back to its __dict__.
        for key in ("gpt_config", "nt_config", "perceiver_resampler_config"):
            sub_config = getattr(self, key)
            if hasattr(sub_config, "to_dict"):
                output[key] = sub_config.to_dict()
            else:
                output[key] = vars(sub_config)
        return output
|
|
|
|
class TorchBioBrainDecoder(nn.Module):
    def __init__(
        self,
        gpt_config: "GptConfig",
        seq_token_id: int,
    ):
        """
        Initializes the BioBrain decoder, using a GPT model for text generation with
        bio embeddings.

        Args:
            gpt_config: Configuration for the GPT model
            seq_token_id: Index of the SEQ token
        """
        super().__init__()
        self.gpt_config = gpt_config
        self.seq_token_id = seq_token_id

        # GPT decoder used for token embedding, transformer stack and LM head.
        self.gpt_model = TorchGptDecoder(self.gpt_config)

    def forward(
        self, english_token_ids: torch.Tensor, projected_bio_embeddings: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass through the model.

        Args:
            english_token_ids: Tensor of English token IDs with shape
                (batch_size, num_english_tokens).
            projected_bio_embeddings: Optional tensor of bio embeddings with shape
                (batch_size, num_bio_sequences, ?, embed_dim).

        Returns:
            torch.Tensor: The logits from the GPT model,
            shaped (batch_size, num_english_tokens, vocab_size).
        """
        # Embed the English prompt tokens.
        tokens_embeddings = self.gpt_model.token_embed(english_token_ids)

        if projected_bio_embeddings is not None:
            (
                batch_size,
                num_bio_sequences,
                _,
                bio_embed_dim,
            ) = projected_bio_embeddings.shape

            # Splice each bio sequence's embeddings in at its SEQ token
            # position; each insertion marks its SEQ token as consumed (-1).
            processed_tokens_ids = english_token_ids.clone()
            for bio_seq_num in range(num_bio_sequences):
                tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
                    processed_tokens_ids,
                    tokens_embeddings,
                    projected_bio_embeddings[:, bio_seq_num, :, :],
                    bio_seq_num=bio_seq_num,
                )

        # Regular GPT pass through.
        embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
        embeddings = self.gpt_model.final_norm(embeddings)

        # Project back to vocabulary logits.
        logits = self.gpt_model.lm_head(embeddings)

        if projected_bio_embeddings is not None:
            # Remove the logit positions occupied by the inserted bio
            # embeddings, one bio sequence at a time.
            processed_tokens_ids = english_token_ids.clone()
            resampled_length = projected_bio_embeddings.shape[-2]
            for _ in range(num_bio_sequences):
                logits, processed_tokens_ids = self.cleanup_logits(
                    tokens=processed_tokens_ids,
                    logits=logits,
                    resampled_length=resampled_length,
                )

        return logits

    def insert_embeddings(
        self,
        tokens: torch.Tensor,
        input_embeddings: torch.Tensor,
        resampled_embeddings: torch.Tensor,
        bio_seq_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Inserts resampled embeddings in input_embeddings, starting at the SEQ token.
        The output keeps the original number of token positions: the tail of the
        sequence is truncated to make room for the inserted block.

        Args:
            tokens (torch.Tensor): Shape (batch_size, num_tokens)
            input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
            resampled_embeddings (torch.Tensor):
                Shape (batch_size, bio_sequence_length, embed_dim)
            bio_seq_num: Index of the bio sequence being inserted; earlier
                insertions shift the effective insertion point by
                bio_sequence_length each.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - input_embeddings with resampled_embeddings inserted at the SEQ token
                - tokens with the SEQ token set to -1
        """

        def _insert(
            tokens_1d: torch.Tensor,
            input_embeddings_1d: torch.Tensor,
            resampled_embeddings_1d: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Single-sample insertion.

            Args:
                tokens_1d (torch.Tensor): Shape (num_tokens,)
                input_embeddings_1d (torch.Tensor): Shape (num_tokens, embed_dim)
                resampled_embeddings_1d (torch.Tensor):
                    Shape (bio_sequence_length, embed_dim)
            """
            indices = torch.where(tokens_1d == self.seq_token_id)[0]
            if indices.numel() > 0:
                idx = indices[0].item()
                # Previously inserted bio blocks shifted positions to the right.
                insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
                # Insert the resampled block, then truncate back to the
                # original sequence length. (The original code additionally
                # applied torch.roll by -idx then +idx, which is an identity.)
                x = torch.cat(
                    [
                        input_embeddings_1d[:insertion_pos, :],
                        resampled_embeddings_1d,
                        input_embeddings_1d[insertion_pos:, :],
                    ],
                    dim=0,
                )[: tokens_1d.shape[0], :]
                # Mark this SEQ token as consumed for subsequent insertions.
                tokens_1d[idx] = -1
                return x, tokens_1d
            else:
                # BUGFIX: previously returned the outer batched
                # `input_embeddings`, which broke torch.stack below whenever a
                # sample had no SEQ token.
                return input_embeddings_1d, tokens_1d

        tokens_acc = []
        embeddings_acc = []

        for i in range(tokens.shape[0]):
            embeddings_out, tokens_out = _insert(
                tokens[i].clone(),
                input_embeddings[i].clone(),
                resampled_embeddings[i].clone(),
            )
            tokens_acc.append(tokens_out)
            embeddings_acc.append(embeddings_out)
        tokens_acc = torch.stack(tokens_acc)
        embeddings_acc = torch.stack(embeddings_acc)

        return embeddings_acc, tokens_acc

    def cleanup_logits(
        self, tokens: torch.Tensor, logits: torch.Tensor, resampled_length: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Removes the logits corresponding to the unused embeddings.

        Args:
            tokens: Input english tokens.
            logits: Input logits.
            resampled_length: Number of logit positions (one inserted bio
                block) to drop after the SEQ token.

        Returns:
            Cleaned logits, last values will be equal to 0.
        """

        def _clean(
            token: torch.Tensor, logit: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            indices = torch.where(token == self.seq_token_id)[0]
            if indices.numel() > 0:
                idx = indices[0].item()

                # Positions strictly after the SEQ token take their logits
                # from `resampled_length` further down the sequence.
                mask_idx = (
                    torch.arange(logit.shape[0] - resampled_length, device=logit.device)
                    > idx
                )
                mask_idx = mask_idx.unsqueeze(1)

                logit = (
                    logit[:-resampled_length] * (~mask_idx)
                    + logit[resampled_length:] * mask_idx
                )

                # Pad back to the original length with zeros.
                logit = torch.cat(
                    (
                        logit,
                        torch.zeros(
                            (resampled_length, logit.shape[1]),
                            dtype=logit.dtype,
                            device=logit.device,
                        ),
                    )
                )

                # Mark this SEQ token as consumed for subsequent cleanups.
                token[idx] = -1

                return logit, token

            else:
                return logit, token

        tokens_acc = []
        logits_acc = []

        for i in range(tokens.shape[0]):
            logits_out, tokens_out = _clean(tokens[i].clone(), logits[i].clone())
            tokens_acc.append(tokens_out)
            logits_acc.append(logits_out)
        tokens_acc = torch.stack(tokens_acc)
        logits_acc = torch.stack(logits_acc)

        return logits_acc, tokens_acc
|
|
|
|
class TorchMultiOmicsModel(PreTrainedModel):
    """ChatNT multi-omics model: a nucleotide-transformer bio encoder, a
    perceiver-resampler projection, and a GPT decoder tied together behind the
    HuggingFace PreTrainedModel interface."""

    config_class = ChatNTConfig

    def __init__(self, config: ChatNTConfig) -> None:
        # Configs may arrive as plain dicts (e.g. deserialized from a JSON
        # checkpoint); rebuild the typed config objects before use.
        if isinstance(config, dict):
            config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
                **config["gpt_config"]["rope_config"]
            )
            config["gpt_config"] = GptConfig(**config["gpt_config"])
            config["nt_config"] = NucleotideTransformerConfig(**config["nt_config"])
            config["perceiver_resampler_config"] = PerceiverResamplerConfig(
                **config["perceiver_resampler_config"]
            )
            config = ChatNTConfig(**config)

        else:
            # A ChatNTConfig instance whose sub-configs may still be dicts.
            if isinstance(config.gpt_config, dict):
                config.gpt_config["rope_config"] = RotaryEmbeddingConfig(
                    **config.gpt_config["rope_config"]
                )
                config.gpt_config = GptConfig(**config.gpt_config)

            if isinstance(config.nt_config, dict):
                config.nt_config = NucleotideTransformerConfig(**config.nt_config)

            if isinstance(config.perceiver_resampler_config, dict):
                config.perceiver_resampler_config = PerceiverResamplerConfig(
                    **config.perceiver_resampler_config
                )

        super().__init__(config=config)
        self.gpt_config = config.gpt_config
        self.nt_config = config.nt_config
        self.perceiver_resampler_config = config.perceiver_resampler_config
        self.seq_token_id = config.seq_token_id
        self.bio_pad_token_id = config.bio_pad_token_id
        self.english_pad_token_id = config.english_pad_token_id

        # Shift the SEQ token id down by one so it matches the id remapping
        # applied to the english token ids in forward() (see NOTE there).
        self.seq_token_id -= 1

        self.biobrain_encoder = TorchBioBrainEncoder(nt_config=self.nt_config)
        self.biobrain_decoder = TorchBioBrainDecoder(
            gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
        )
        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
            perceiver_resampler_config=self.perceiver_resampler_config,
            input_embed_dim=self.nt_config.embed_dim,
            embed_dim=self.gpt_config.embed_dim,
            english_vocab_size=self.gpt_config.vocab_size,
            bio_pad_token_id=self.bio_pad_token_id,
            english_pad_token_id=self.english_pad_token_id,
        )

    def forward(
        self,
        multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
        projection_english_tokens_ids: torch.Tensor,
        projected_bio_embeddings: torch.Tensor = None,
    ) -> dict[str, torch.Tensor]:
        """

        Args:
            multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
                english_tokens_ids: Represents the prompt tokens (english tokens)
                    Shape (batch_size, num_english_tokens)

                bio_tokens_ids: Represents the bio sequences tokens
                    Shape (batch_size, num_bio_sequences, num_bio_tokens)

            projection_english_tokens_ids (torch.Tensor):
                Shape (batch_size, num_english_tokens)

            projected_bio_embeddings (projected_bio_embeddings, optional):
                Shape (batch_size, num_bio_sequences, ?, embed_dim).
                Defaults to None.

        Returns:
            dict[str, torch.Tensor] containing:
                - logits:
                    Shape (batch_size, num_tokens, vocab_size)

                - projected_bio_embeddings:
                    Shape (batch_size, num_bio_sequences, ?, embed_dim)
        """
        english_token_ids, bio_token_ids = multi_omics_tokens_ids
        # Work on copies so the caller's tensors are never mutated in place
        # (the remapping below writes into these tensors).
        english_token_ids = english_token_ids.clone()
        bio_token_ids = bio_token_ids.clone()
        projection_english_tokens_ids = projection_english_tokens_ids.clone()
        if projected_bio_embeddings is not None:
            projected_bio_embeddings = projected_bio_embeddings.clone()

        # NOTE(review): id remapping — ids equal to vocab_size - 1 are sent to
        # 0 and ids equal to vocab_size are folded into vocab_size - 1. With
        # the default config (vocab_size=32000, seq_token_id=32000) this moves
        # the SEQ token into the last in-vocabulary slot, matching the
        # `seq_token_id -= 1` applied in __init__ — confirm for custom
        # vocabularies.
        vocab_size = self.gpt_config.vocab_size

        english_token_ids[english_token_ids == vocab_size - 1] = 0
        projection_english_tokens_ids[
            projection_english_tokens_ids == vocab_size - 1
        ] = 0
        english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
        projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
            vocab_size - 1
        )

        # NOTE(review): bio_token_ids was already .clone()d above, which would
        # have raised if it were None — this guard appears unreachable.
        if bio_token_ids is None:
            projected_bio_embeddings = None
        else:
            num_bio_sequences = bio_token_ids.shape[1]

            if projected_bio_embeddings is None:
                # Encode each bio sequence independently with the NT encoder.
                bio_embeddings_list = [
                    self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                    for bio_seq_num in range(num_bio_sequences)
                ]

                # Project each encoding into the decoder's embedding space.
                projected_bio_embeddings = [
                    self.projection_model(
                        bio_token_ids=bio_token_ids[:, bio_seq_num],
                        bio_embeddings=bio_embeddings,
                        english_token_ids=projection_english_tokens_ids,
                    )
                    for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
                ]
                projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)

        # Decode the prompt with the bio embeddings spliced in.
        logits = self.biobrain_decoder(
            english_token_ids=english_token_ids,
            projected_bio_embeddings=projected_bio_embeddings,
        )
        logits = logits.to(torch.float32)

        outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}

        return outs
|
|
|
|
class TorchRotaryEmbedding(torch.nn.Module):
    """Rotary positional embeddings (RoPE) applied to key and query tensors.

    The sin/cos table is built lazily on the first call and cached on the
    device of that first input (NOTE: the cache is not moved if the module is
    later transferred to another device).
    """

    def __init__(self, config: "RotaryEmbeddingConfig"):
        super().__init__()

        self.max_seq_len = config.max_seq_len
        # Number of head dimensions rotated; remaining dims pass through.
        self.dim = config.dim
        self.theta = config.theta
        self.sincos_cache: Optional[torch.Tensor] = None

    def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
        """
        Create the sines and cosines for the RoPE.

        Returns:
            Sinusoidal positions of shape (self.max_seq_len, self.dim); the
            first half of the last axis holds sines, the second half cosines.
        """
        # Inverse frequencies 1 / theta**(2i / dim), i = 0 .. dim/2 - 1.
        inv_freq = 1.0 / (
            self.theta
            ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
        )

        # Outer product: position index x frequency.
        sinusoid_inp = torch.einsum(
            "i,j->ij", torch.arange(self.max_seq_len, device=device).float(), inv_freq
        )

        sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()

        sincos = torch.zeros(
            (self.max_seq_len, self.dim), dtype=torch.float32, device=device
        )

        # Pack as [sin | cos] along the last dimension; torch.chunk in
        # __call__ splits them apart again.
        sentinel = self.dim // 2 + self.dim % 2
        sincos[:, :sentinel] = sin
        sincos[:, sentinel:] = cos

        return sincos

    def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
        """
        Prepare a tensor to apply the RoPE mechanism.

        Args:
            x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                typically this is the key or query tensor.

        Returns:
            Tensor of shape (batch_size, seq_len, num_heads, head_dim) where
            each even/odd pair (x0, x1) along the last axis is replaced by
            (-x1, x0), i.e. rotated by 90 degrees.
        """
        rotate_half = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1)

        # Interleave back into the original channel layout.
        rotate_half = rotate_half.view(rotate_half.shape[:-2] + (-1,))
        return rotate_half

    def _apply_rotary_pos_emb(
        self, x: torch.Tensor, sincos: torch.Tensor
    ) -> torch.Tensor:
        """
        Applies rotary embeddings to x.

        Args:
            x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                typically this is the key or query tensor.
            sincos: Tuple of sine and cosine tensors for position encoding,
                each of shape (batch_size, seq_len, dim / 2).

        Returns:
            RoPE embeddings tensor.
        """
        sin_pos, cos_pos = sincos

        # Duplicate each sin/cos entry so it multiplies both members of the
        # corresponding even/odd channel pair; add a broadcast head axis.
        sin_pos = torch.repeat_interleave(sin_pos.unsqueeze(2), repeats=2, dim=-1)
        cos_pos = torch.repeat_interleave(cos_pos.unsqueeze(2), repeats=2, dim=-1)

        return (x * cos_pos) + (self._rotate_every_two(x) * sin_pos)

    def __call__(
        self, k: torch.Tensor, q: torch.Tensor, positions: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Applies rotary embeddings to k and q.

        Args:
            k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
            q: query tensor of shape (batch_size, seq_len, num_heads, head_dim),
            positions: optional positions offset useful when caching,

        Returns:
            RoPE embeddings for the keys and queries (only the first
            `self.dim` channels of the head dimension are rotated; the rest
            pass through unchanged).
        """
        if self.sincos_cache is None:
            # Lazily build the table on the device of the first input.
            device = k.device
            self.sincos_cache = self._create_sinusoidal_positions(device=device)

        batch_size, seq_len, num_heads, head_dim = k.shape

        position_ids = (
            torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1)
        )

        if positions is not None:
            # BUGFIX: out-of-place add. `position_ids` is an expanded
            # (stride-0) view, so an in-place `+=` raises a RuntimeError for
            # batch_size > 1.
            position_ids = position_ids + positions

        sincos = self.sincos_cache[position_ids]

        # Split the packed [sin | cos] table back into the two halves.
        sincos = torch.chunk(sincos, 2, dim=-1)

        k_rot = self._apply_rotary_pos_emb(k[..., : self.dim], sincos)
        k_pass = k[..., self.dim :]

        q_rot = self._apply_rotary_pos_emb(q[..., : self.dim], sincos)
        q_pass = q[..., self.dim :]

        keys = torch.cat([k_rot, k_pass], dim=-1)
        queries = torch.cat([q_rot, q_pass], dim=-1)

        return keys, queries
|
|
|
|
class TorchGptGroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA): `num_heads` query heads share
    `num_kv_heads` key/value heads, with rotary embeddings applied to the
    queries and keys."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        rope_config: RotaryEmbeddingConfig,
        num_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        add_bias_attn: bool = False,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Defaults to standard multi-head attention when no KV head count is
        # given (num_kv_heads == num_heads => replication factor of 1).
        self.num_kv_heads = num_kv_heads or num_heads
        self.embed_dim = embed_dim
        self.head_dim = head_dim or (embed_dim // num_heads)
        self.add_bias_attn = add_bias_attn
        self.rope = TorchRotaryEmbedding(rope_config)

        self.query_linear = nn.Linear(
            embed_dim, self.num_heads * self.head_dim, bias=add_bias_attn
        )
        self.key_linear = nn.Linear(
            embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
        )
        self.value_linear = nn.Linear(
            embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
        )
        self.out_linear = nn.Linear(
            self.num_heads * self.head_dim, embed_dim, bias=add_bias_attn
        )

    def forward(
        self,
        query_inputs: torch.Tensor,
        key_inputs: torch.Tensor,
        value_inputs: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NOTE(review): all three reshapes use the query's seq_len, so this
        # implementation assumes query/key/value inputs share the same
        # sequence length — confirm if cross-attention use is intended.
        batch_size, seq_len, _ = query_inputs.shape

        # Project and split into per-head channels: (batch, seq, heads, head_dim).
        queries = self.query_linear(query_inputs).view(
            batch_size, seq_len, self.num_heads, self.head_dim
        )
        keys = self.key_linear(key_inputs).view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        )
        values = self.value_linear(value_inputs).view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        )

        # Rotary position encoding on keys and queries.
        keys, queries = self.rope(keys, queries)

        # Replicate each KV head so every query head has a matching K/V.
        n_rep = self.num_heads // self.num_kv_heads
        keys = keys.repeat_interleave(n_rep, dim=2)
        values = values.repeat_interleave(n_rep, dim=2)

        # Scaled dot-product scores: (batch, heads, q_pos, k_pos).
        attention_logits = torch.einsum("bthd,bThd->bhtT", queries, keys) / (
            self.head_dim**0.5
        )

        # Positions where the mask is 0 are excluded from the softmax.
        if attention_mask is not None:
            attention_logits = attention_logits.masked_fill(
                attention_mask == 0, float("-inf")
            )

        attention_weights = nn.functional.softmax(attention_logits, dim=-1)
        attention_weights = attention_weights.to(values.dtype)

        # Weighted sum of values, then merge heads back into one channel axis.
        values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
        values = values.contiguous().view(batch_size, seq_len, -1)

        return self.out_linear(values)
|
|
|
|
class TorchGptDecoder(nn.Module):
    """GPT-style decoder: token embedding, a stack of decoder layers, a final
    normalization and a simple LM head."""

    def __init__(self, config: GptConfig, name: Optional[str] = None):
        # `name` is accepted for signature compatibility but is unused here.
        super().__init__()
        self.config = config

        self.token_embed = nn.Embedding(config.vocab_size, config.embed_dim)

        # Final (pre-logits) normalization, selected by config.
        if config.norm_type == "layer_norm":
            self.final_norm = nn.LayerNorm(config.embed_dim)
        elif config.norm_type == "RMS_norm":
            self.final_norm = TorchRMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        else:
            raise ValueError(f"unrecognized norm_type in config {config.norm_type}")

        self.layers = nn.ModuleList(
            [
                TorchGptDecoderLayer(
                    embed_dim=config.embed_dim,
                    ffn_embed_dim=config.ffn_embed_dim,
                    num_heads=config.num_heads,
                    rope_config=config.rope_config,
                    norm_type=config.norm_type,
                    parallel_attention_ff=config.parallel_attention_ff,
                    add_bias_ffn=config.add_bias_ffn,
                    ffn_activation_name=config.ffn_activation_name,
                    use_glu_in_ffn=config.use_glu_in_ffn,
                    num_kv_heads=config.num_kv_heads,
                    add_bias_attn=config.add_bias_attn,
                    rms_norm_eps=config.rms_norm_eps,
                )
                for _ in range(config.num_layers)
            ]
        )

        self.lm_head = TorchSimpleLMHead(
            embed_dim=config.embed_dim,
            alphabet_size=config.vocab_size,
            add_bias_lm_head=config.add_bias_lm_head,
        )

    def apply_transformer_layers(
        self, embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the embeddings through the decoder-layer stack; defaults to a
        causal mask when none is supplied."""
        if attention_mask is None:
            # Mask built with batch dim 1; presumably broadcast across the
            # batch by the attention layers — confirm in
            # build_causal_attention_mask.
            attention_mask = build_causal_attention_mask(
                1, embeddings.shape[1], device=embeddings.device
            )
        for layer in self.layers:
            embeddings = layer(embeddings, attention_mask)

        return embeddings

    def forward(
        self, token_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> dict[str, torch.Tensor]:
        """Full decoder pass: embed token ids, apply the layer stack, final
        norm and LM head. Returns both the normalized embeddings and logits."""
        if attention_mask is None:
            attention_mask = build_causal_attention_mask(
                1, token_ids.shape[1], device=token_ids.device
            )

        tokens_embeddings = self.token_embed(token_ids)

        after_transformer_embeddings = self.apply_transformer_layers(
            tokens_embeddings, attention_mask=attention_mask
        )

        embeddings = self.final_norm(after_transformer_embeddings)
        logits = self.lm_head(embeddings)
        return {"embeddings": embeddings, "logits": logits}
|
|
|
|
class TorchSimpleLMHead(nn.Module):
    """Minimal language-model head: a single linear projection to logits."""

    def __init__(
        self, embed_dim: int, alphabet_size: int, add_bias_lm_head: bool = True
    ) -> None:
        """
        Args:
            embed_dim: Dimension of the incoming embeddings.
            alphabet_size: Size of the output vocabulary.
            add_bias_lm_head: Whether the projection carries a bias term.
        """
        super().__init__()
        self.fc = nn.Linear(embed_dim, alphabet_size, bias=add_bias_lm_head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project embeddings to per-position vocabulary logits."""
        logits = self.fc(x)
        return logits
|
|
|
|
| class TorchGptDecoderLayer(nn.Module): |
| def __init__( |
| self, |
| embed_dim: int, |
| ffn_embed_dim: int, |
| num_heads: int, |
| rope_config: RotaryEmbeddingConfig, |
| norm_type: str, |
| parallel_attention_ff: bool, |
| add_bias_ffn: bool, |
| ffn_activation_name: str, |
| use_glu_in_ffn: bool, |
| num_kv_heads: int, |
| add_bias_attn: bool, |
| rms_norm_eps: float = 1e-6, |
| ) -> None: |
| super().__init__() |
| self.num_heads = num_heads |
| self.parallel_attention_ff = parallel_attention_ff |
| self.use_glu_in_ffn = use_glu_in_ffn |
|
|
| |
| self.self_attn = TorchGptGroupedQueryAttention( |
| embed_dim=embed_dim, |
| num_heads=num_heads, |
| num_kv_heads=num_kv_heads, |
| rope_config=rope_config, |
| add_bias_attn=add_bias_attn, |
| ) |
|
|
| |
| if norm_type == "layer_norm": |
| self.attn_norm = nn.LayerNorm(embed_dim) |
| if not self.parallel_attention_ff: |
| self.ffn_norm = nn.LayerNorm(embed_dim) |
| elif norm_type == "RMS_norm": |
| self.attn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps) |
| if not self.parallel_attention_ff: |
| self.ffn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps) |
| else: |
| raise ValueError(f"unrecognized norm_type: {norm_type}") |
|
|
| |
| self.activation = get_activation_fn(ffn_activation_name) |
| ffn_hidden_dim = ffn_embed_dim * (2 if use_glu_in_ffn else 1) |
| self.fc1 = nn.Linear(embed_dim, ffn_hidden_dim, bias=add_bias_ffn) |
| self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn) |
|
|
| def forward( |
| self, embeddings: torch.Tensor, attention_mask: torch.Tensor |
| ) -> torch.Tensor: |
| residuals = embeddings |
|
|
| if self.parallel_attention_ff: |
| |
| embeddings_normed = self.attn_norm(embeddings) |
|
|
| attn_output, _ = self.self_attn( |
| embeddings_normed, |
| embeddings_normed, |
| embeddings_normed, |
| attn_mask=attention_mask, |
| ) |
| ffn_output = self.mlp(embeddings_normed) |
|
|
| return residuals + attn_output + ffn_output |
| else: |
| |
| normed_embeddings = self.attn_norm(embeddings) |
|
|
| attn_output = embeddings + self.self_attn( |
| normed_embeddings, |
| normed_embeddings, |
| normed_embeddings, |
| attention_mask=attention_mask, |
| ) |
|
|
| normed_embeddings2 = self.ffn_norm(attn_output) |
| ffn_output = self.mlp(normed_embeddings2) |
| return attn_output + ffn_output |
|
|
| def mlp(self, x: torch.Tensor) -> torch.Tensor: |
| """Applies the feedforward network (MLP) with optional GLU.""" |
| ffn_output = self.fc1(x) |
|
|
| if self.use_glu_in_ffn: |
| ffn_output1, ffn_output2 = ffn_output.chunk(2, dim=-1) |
| ffn_output = self.activation(ffn_output1) * ffn_output2 |
| else: |
| ffn_output = self.activation(ffn_output) |
|
|
| return self.fc2(ffn_output) |
|
|
|
|
class TorchRMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-channel gain."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        """
        Args:
            dim: Size of the last (normalized) dimension.
            eps: Small constant added inside the square root for stability.
        """
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize by the RMS of the last dimension, then apply the gain."""
        # Same operation order as x * scale / sqrt(mean(x^2) + eps).
        scaled = x * self.scale
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return scaled / rms
|
|
|
|
def get_activation_fn(activation_name: str):
    """Map an activation name to its torch.nn.functional implementation.

    Supported names: "gelu", "relu", "swish", "silu" ("swish" and "silu" are
    aliases). Unrecognized names silently fall back to ReLU.
    """
    if activation_name == "gelu":
        return nn.functional.gelu
    if activation_name in ("swish", "silu"):
        return nn.functional.silu
    # "relu" and any unknown name both resolve to ReLU (historical behavior).
    return nn.functional.relu
|
|
|
|
def build_causal_attention_mask(
    batch_size: int, seq_len: int, device: torch.device
) -> torch.Tensor:
    """
    Builds a batch of causal (lower-triangular) masks of shape
    (batch_size, 1, seq_len, seq_len) to feed to an attention layer.

    Args:
        batch_size: Batch size.
        seq_len: Length of the sequences.
        device: Device on which to allocate the mask.

    Returns:
        Batch of causal masks (1.0 where attention is allowed, 0.0 elsewhere).
    """
    full = torch.ones((batch_size, 1, seq_len, seq_len), device=device)
    return torch.tril(full)
|
|
|
|
@dataclass
class RotaryEmbeddingConfigBis:
    """
    Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows
    to adapt the rotary embeddings to larger lengths than what was used for training.
    One of this strategy is presented in the Yarn paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa
    Args:
        rescaling_factor: Context-length rescaling factor. None disables
            rescaling; otherwise the rotary base frequency is enlarged
            accordingly inside RotaryEmbeddingBis.forward.
    """


    rescaling_factor: Optional[float]
|
|
|
|
class RotaryEmbeddingBis(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
    Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """


    def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfigBis):
        """
        Args:
            dim: Rotary dimension (this file instantiates it with the per-head
                key size).
            rotary_embedding_config: Carries the optional rescaling factor used
                to extend the usable context length.
        """
        super().__init__()


        # None -> vanilla RoPE; otherwise the base is enlarged in forward().
        self.rescaling_factor = rotary_embedding_config.rescaling_factor
        self.upper_freq = 10000  # RoPE base frequency (theta).
        self.dim = dim


        # NOTE(review): these cache fields are assigned but never consulted --
        # _compute_cos_sin_tables recomputes the tables on every forward call.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None


    def _apply_rotary_pos_emb(
        self,
        heads: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Rotate each head vector: split the last dim into two halves and apply
        the (cos, sin) rotation pairing element i of the first half with element
        i of the second half."""
        x_first, x_second = (
            heads[..., : heads.shape[-1] // 2],
            heads[..., heads.shape[-1] // 2 :],
        )


        first_part = x_first * cos - x_second * sin
        second_part = x_second * cos + x_first * sin


        return torch.cat((first_part, second_part), dim=-1)


    def _compute_cos_sin_tables(
        self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build cos/sin tables of shape (1, seq_len, 1, dim // 2).

        The singleton axes line up with batch and heads, so this assumes inputs
        laid out as (batch, seq, heads, head_dim) — TODO confirm with callers.
        """
        seq_len = x.shape[seq_dimension]

        self._seq_len_cached = seq_len
        t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
        # Outer product position x frequency -> per-position rotation angles.
        freqs = torch.einsum("i, j -> ij", t, inv_freq)


        self._cos_cached = torch.cos(freqs)[None, :, None, :]
        self._sin_cached = torch.sin(freqs)[None, :, None, :]


        return self._cos_cached, self._sin_cached


    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to queries and keys.

        The inverse frequencies are rebuilt on every call; with a rescaling
        factor set, the base is enlarged (NTK-style) by
        factor ** (dim / (dim - 2)).
        """
        if self.rescaling_factor is None:
            inv_freq = 1.0 / (
                self.upper_freq
                ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
            )
        else:
            updated_base = self.upper_freq * (
                self.rescaling_factor ** (self.dim / (self.dim - 2))
            )
            inv_freq = 1.0 / (
                updated_base
                ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
            )


        # seq_dimension=-3 selects the sequence axis of (batch, seq, heads, dim).
        self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
            q,
            inv_freq,
            seq_dimension=-3,
        )


        return (
            self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional rotary position embeddings.

    Returns both the post-softmax attention weights and the output embeddings.
    """

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        """
        Args:
            num_heads: Number of attention heads.
            key_size: Per-head dimension of queries and keys.
            rotary_embedding_config: If given, rotary embeddings are applied to
                queries and keys.
            value_size: Per-head value dimension; defaults to key_size.
            model_size: Input/output embedding dimension; defaults to
                num_heads * key_size.
            name: Stored for reference only; not used in computation.

        NOTE(review): add_bias_kv is stored but never applied — the q/k/v
        projections below are nn.Linear with the default bias=True regardless
        of this flag. Confirm against the reference implementation before
        changing, as checkpoints may depend on the bias parameters existing.
        """
        super().__init__()
        if not model_size:
            model_size = key_size * num_heads
        if not value_size:
            value_size = key_size
        self.model_size = model_size
        self.key_size = key_size
        self.value_size = value_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads
        self._rotary_embedding_config = rotary_embedding_config


        self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
        self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
        if self._rotary_embedding_config:
            self._rotary_embedding = RotaryEmbeddingBis(
                self.key_size, self._rotary_embedding_config
            )


    def apply_rotary_embeddings(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embeddings to query and key head tensors."""
        query, key = self._rotary_embedding(query, key)
        return query, key


    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            query/key/value: Tensors whose last dimension is model_size.
            attention_mask: Boolean mask; positions where it is False are
                suppressed (filled with -1e30 before the softmax).
            attention_weight_bias: Optional additive bias applied to the scaled
                scores before the softmax.

        Returns:
            dictionary containing attention weights
            and outputs.
        """
        # Project and split into heads: (..., tokens, heads, head_dim).
        key_heads = self.w_k(key).reshape(
            (*key.shape[:-1], self.num_heads, self.key_size)
        )
        query_heads = self.w_q(query).reshape(
            (*query.shape[:-1], self.num_heads, self.key_size)
        )
        value_heads = self.w_v(value).reshape(
            (*value.shape[:-1], self.num_heads, self.value_size)
        )
        if self._rotary_embedding_config:
            query_heads, key_heads = self.apply_rotary_embeddings(
                query_heads, key_heads
            )
        # Scores: t = query positions, T = key positions, h = heads.
        attention_weights = torch.einsum(
            "...thd, ...Thd -> ...htT", query_heads, key_heads
        )
        sqrt_key_size = np.sqrt(self.key_size)
        attention_weights = attention_weights / sqrt_key_size
        if attention_mask is not None:
            # Large negative fill (rather than -inf) so the softmax stays finite.
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)


        # Cast to the value dtype before the softmax / weighted sum.
        attention_weights = attention_weights.to(value_heads.dtype)


        if attention_weight_bias is not None:
            attention_weights = F.softmax(
                attention_weights + attention_weight_bias, dim=-1
            )
        else:
            attention_weights = F.softmax(attention_weights, dim=-1)


        # Weighted sum over key positions, then merge heads back together.
        value_out = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        value_out = value_out.reshape((*value_out.shape[:-2], -1))
        embeddings = self.output(value_out)


        return {"attention_weights": attention_weights, "embeddings": embeddings}
|
|
|
|
| class SelfAttentionBlock(nn.Module): |
| def __init__( |
| self, |
| num_heads: int, |
| embed_dim: int, |
| ffn_embed_dim: int, |
| key_size: Optional[int] = None, |
| add_bias_kv: bool = False, |
| add_bias_fnn: bool = True, |
| ffn_activation_name: str = "gelu-no-approx", |
| use_glu_in_ffn: bool = False, |
| layer_norm_eps: float = 1e-5, |
| pre_layer_norm: bool = True, |
| name: Optional[str] = None, |
| rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None, |
| ): |
| super().__init__() |
| if key_size is None: |
| if embed_dim % num_heads != 0: |
| raise ValueError( |
| f"The embedding dimension should be divisible by the number of " |
| f"heads, however provided embedding dimension is {embed_dim} and " |
| f"the number of heads is {num_heads}." |
| ) |
| else: |
| key_size = embed_dim // num_heads |
|
|
| |
| self._pre_layer_norm = pre_layer_norm |
| self._use_glu_in_fnn = use_glu_in_ffn |
| |
| if use_glu_in_ffn: |
| |
| |
| |
| |
| self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn) |
| else: |
| self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn) |
|
|
| self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn) |
|
|
| self.layer_norm_self_attention = nn.LayerNorm( |
| embed_dim, |
| ) |
| self.layer_norm_mlp = nn.LayerNorm(embed_dim) |
| if ffn_activation_name == "swish": |
| self._ffn_activation_fn = nn.SiLU() |
| elif ffn_activation_name == "gelu-no-approx": |
| self._ffn_activation_fn = nn.GELU(approximate="tanh") |
| else: |
| self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name) |
|
|
| self.mha = MultiHeadAttention( |
| num_heads=num_heads, |
| key_size=key_size, |
| add_bias_kv=add_bias_kv, |
| model_size=embed_dim, |
| name="self_attention", |
| rotary_embedding_config=rotary_embedding_config, |
| ) |
|
|
| def mlp(self, embed: torch.Tensor) -> torch.Tensor: |
|
|
| if self._pre_layer_norm: |
| x = self.layer_norm_mlp(embed) |
| else: |
| x = embed |
|
|
| if self._use_glu_in_fnn: |
| x = self.fc1(x) |
| x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1) |
| x = self._ffn_activation_fn(x1) * x2 |
| else: |
| x = self._ffn_activation_fn(self.fc1(x)) |
| x = self.fc2(x) |
|
|
| if not self._pre_layer_norm: |
| x = self.layer_norm_mlp(x + embed) |
| return x |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| attention_weight_bias: Optional[torch.Tensor] = None, |
| ) -> dict[str, torch.Tensor]: |
|
|
| res = x |
| if self._pre_layer_norm: |
| x = self.layer_norm_self_attention(x) |
|
|
| output: dict[str, torch.Tensor] = self.mha( |
| x, |
| x, |
| x, |
| attention_mask=attention_mask, |
| attention_weight_bias=attention_weight_bias, |
| ) |
|
|
| if not self._pre_layer_norm: |
| output["embeddings"] = self.layer_norm_self_attention( |
| output["embeddings"] + res |
| ) |
|
|
| x = output["embeddings"] |
| else: |
| x = output["embeddings"] |
| x = res + x |
|
|
| |
| if not self._pre_layer_norm: |
| x = self.mlp(x) |
| else: |
| x = x + self.mlp(x) |
|
|
| output["embeddings"] = x |
| return output |
|
|
|
|
class RobertaLMHead(nn.Module):
    """
    Roberta Language Model head. Transforms final attention layer output into a
    distribution over tokens at each position.
    """

    def __init__(self, embed_dim: int, alphabet_size: int):
        """
        Args:
            embed_dim: Embedding dimension.
            alphabet_size: Number of tokens in the alphabet.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.alphabet_size = alphabet_size

        self._first_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
        self._fc1 = nn.Linear(embed_dim, embed_dim)
        self._second_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
        self._final_fc = nn.Linear(embed_dim, alphabet_size)

    def forward(self, x: torch.Tensor) -> dict:
        """Return normalized embeddings and per-position vocabulary logits.

        "embeddings" are the hidden states right after the first layer norm;
        "logits" come from dense -> GELU -> layer norm -> projection.
        """
        normed = self._first_layer_norm(x)
        hidden = nn.functional.gelu(self._fc1(normed))
        logits = self._final_fc(self._second_layer_norm(hidden))
        return {"embeddings": normed, "logits": logits}
|
|
|
|
class TorchNucleotideTransformer(nn.Module):
    """
    Nucleotide Transformer encoder: scaled token embeddings, a stack of
    self-attention blocks with rotary embeddings, and a Roberta-style head
    whose normalized embeddings are returned.
    """

    def __init__(
        self,
        nt_config: NucleotideTransformerConfig,
    ):
        """
        Args:
            nt_config: Model hyperparameters. Only the configuration subset
                asserted below is supported by this torch port.
        """
        super().__init__()
        self.nt_config = nt_config

        # Fail fast on unsupported configurations rather than silently
        # diverging from the reference implementation.
        assert nt_config.positional_embedding is None
        assert nt_config.lm_head == "roberta"
        assert nt_config.use_rotary_embedding is True
        assert nt_config.token_dropout is False
        assert nt_config.emb_layer_norm_before is False
        assert nt_config.mask_before_attention is False
        assert nt_config.bias_word_embedding is False
        assert nt_config.use_gradient_checkpointing is False

        self.embed_layer = nn.Embedding(nt_config.alphabet_size, nt_config.embed_dim)

        self.lm_head = RobertaLMHead(
            embed_dim=nt_config.embed_dim,
            alphabet_size=nt_config.alphabet_size,
        )

        self.rotary_embedding_config = RotaryEmbeddingConfigBis(
            rescaling_factor=nt_config.rescaling_factor
        )

        self.attention_blocks = nn.ModuleList(
            [
                SelfAttentionBlock(
                    num_heads=nt_config.attention_heads,
                    embed_dim=nt_config.embed_dim,
                    key_size=nt_config.key_size,
                    ffn_embed_dim=nt_config.ffn_embed_dim,
                    add_bias_kv=nt_config.add_bias_kv,
                    add_bias_fnn=nt_config.add_bias_ffn,
                    ffn_activation_name=nt_config.ffn_activation_name,
                    use_glu_in_ffn=nt_config.use_glu_in_ffn,
                    rotary_embedding_config=self.rotary_embedding_config,
                    layer_norm_eps=nt_config.layer_norm_eps,
                    pre_layer_norm=nt_config.pre_layer_norm,
                )
                for _ in range(nt_config.num_layers)
            ]
        )

    def forward(
        self, tokens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Computes the embeddings based on the input tokens.

        Args:
            tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
            attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
                If no mask is provided, a mask by default which equals 1 over all non
                pad tokens and 0 over pad tokens is computed.

        Returns:
            Final embeddings (the LM head's normalized hidden states) of shape
            (batch_size, seq_len, embed_dim). Note: logits are computed by the
            head but not returned here.
        """
        x = self.embed_layer(tokens)

        # Scale embeddings as in the reference implementation.
        x = self.nt_config.embed_scale * x

        if attention_mask is None:
            attention_mask = build_padding_attention_mask(
                tokens=tokens, pad_token_id=self.nt_config.pad_token_id
            )

        for layer in self.attention_blocks:
            x = layer(x, attention_mask)["embeddings"]

        assert self.nt_config.lm_head == "roberta"
        x = self.lm_head(x)["embeddings"]

        return x
|
|
|
|
def build_padding_attention_mask(
    tokens: torch.Tensor, pad_token_id: int
) -> torch.Tensor:
    """
    Builds a padding mask from a sequence of tokens by masking <pad> in the attention.

    Args:
        tokens: Batch of sequences of shape (batch_size, seq_len).
        pad_token_id: Int corresponding to the <pad> token to mask.

    Returns:
        Batch of attention masks of shape (batch_size, 1, seq_len, seq_len);
        entry (t, T) is allowed only when both positions t and T are non-pad.
    """
    keep = (tokens != pad_token_id).unsqueeze(1)
    # Outer product over positions via einsum: mask[b, h, t, T] = keep[t] & keep[T].
    return torch.einsum("bhT, bht -> bhtT", keep, keep)
|
|
|
|
class TorchBioBrainEncoder(nn.Module):
    """Thin wrapper exposing a nucleotide transformer as a bio-token encoder."""

    def __init__(
        self,
        nt_config: NucleotideTransformerConfig,
    ):
        """
        Args:
            nt_config: Configuration of the underlying nucleotide transformer.
        """
        super().__init__()
        self.nt_config = nt_config
        self.nt_model = TorchNucleotideTransformer(self.nt_config)

    def forward(
        self,
        bio_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode biological token ids.

        Args:
            bio_token_ids (torch.Tensor):
                Shape (batch_size, num_bio_tokens)

        Returns:
            torch.Tensor:
                Shape (batch_size, num_bio_tokens, embed_dim)
        """
        return self.nt_model(tokens=bio_token_ids)
|
|
|
|
| class TorchMultiModalPerceiverResamplerBlock(nn.Module): |
| def __init__( |
| self, |
| num_heads: int, |
| embed_dim: int, |
| ffn_embed_dim: int, |
| key_size: Optional[int] = None, |
| add_bias_kv: bool = False, |
| add_bias_ffn: bool = True, |
| ffn_activation_name: str = "gelu", |
| use_glu_in_ffn: bool = False, |
| ): |
| super().__init__() |
|
|
| if key_size is None: |
| if embed_dim % num_heads != 0: |
| raise ValueError( |
| f"Embedding dimension {embed_dim} should be divisible by " |
| f"num_heads {num_heads}." |
| ) |
| key_size = embed_dim // num_heads |
|
|
| self.num_heads = num_heads |
| self.embed_dim = embed_dim |
| self.ffn_embed_dim = ffn_embed_dim * 2 if use_glu_in_ffn else ffn_embed_dim |
| self.use_glu_in_ffn = use_glu_in_ffn |
|
|
| self.cross_attention_1 = MultiHeadAttention( |
| num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv |
| ) |
| self.cross_attention_2 = MultiHeadAttention( |
| num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv |
| ) |
|
|
| self.norm_cross_attention_1 = nn.LayerNorm(embed_dim) |
| self.norm_cross_attention_2 = nn.LayerNorm(embed_dim) |
| self.norm_mlp = nn.LayerNorm(embed_dim) |
|
|
| self.fc1 = nn.Linear(embed_dim, self.ffn_embed_dim, bias=add_bias_ffn) |
| self.fc2 = nn.Linear(self.ffn_embed_dim, embed_dim, bias=add_bias_ffn) |
|
|
| self.activation_fn = getattr( |
| nn.functional, ffn_activation_name, nn.functional.gelu |
| ) |
|
|
| def mlp(self, x: torch.Tensor) -> torch.Tensor: |
| x = self.norm_mlp(x) |
| if self.use_glu_in_ffn: |
| x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1) |
| x = self.activation_fn(x1) * x2 |
| else: |
| x = self.activation_fn(self.fc1(x)) |
| return self.fc2(x) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| cross_attention_embeddings_1: torch.Tensor, |
| cross_attention_embeddings_2: torch.Tensor, |
| attention_mask_1: Optional[torch.Tensor] = None, |
| attention_mask_2: Optional[torch.Tensor] = None, |
| ) -> Dict[str, torch.Tensor]: |
| res = x |
| x = self.norm_cross_attention_1(x) |
|
|
| attn_output = self.cross_attention_1( |
| query=x, |
| key=cross_attention_embeddings_1, |
| value=cross_attention_embeddings_1, |
| attention_mask=attention_mask_1, |
| )["embeddings"] |
| x = res + attn_output |
|
|
| res = x |
| x = self.norm_cross_attention_2(x) |
| attn_output = self.cross_attention_2( |
| query=x, |
| key=cross_attention_embeddings_2, |
| value=cross_attention_embeddings_2, |
| attention_mask=attention_mask_2, |
| )["embeddings"] |
| x = res + attn_output |
|
|
| x = x + self.mlp(x) |
|
|
| return {"embeddings": x} |
|
|
|
|
class TorchMultiModalPerceiverResampler(nn.Module):
    """
    Perceiver Resampler model, made of successive PerceiverResamplerBlocks.

    A fixed set of learned latent queries cross-attends to two input modalities
    and is returned as a fixed-length ("resampled") embedding sequence.
    """


    def __init__(
        self,
        config: PerceiverResamplerConfig,
        name: Optional[str] = None,
    ):
        """
        Initialize a Perceiver Resampler model.

        Args:
            config: Dataclass containing model hyperparameters.
            name: Name for module (custom will break weight loading).
        """
        super().__init__()
        self.config = config
        self.name = name
        self.layers = nn.ModuleList(
            [
                TorchMultiModalPerceiverResamplerBlock(
                    num_heads=self.config.attention_heads,
                    embed_dim=self.config.embed_dim,
                    key_size=self.config.key_size,
                    ffn_embed_dim=self.config.ffn_embed_dim,
                    add_bias_kv=self.config.add_bias_kv,
                    add_bias_ffn=self.config.add_bias_ffn,
                    ffn_activation_name=self.config.ffn_activation_name,
                    use_glu_in_ffn=self.config.use_glu_in_ffn,
                )
                for _ in range(self.config.num_layers)
            ]
        )


        # Learned latent queries, initialized ~N(0, 1/embed_dim) (randn scaled
        # by 1/sqrt(embed_dim)). One row per resampled output position.
        self.latent_queries = torch.nn.Parameter(
            torch.randn(self.config.resampled_length, self.config.embed_dim)
            * (
                1.0
                / torch.sqrt(torch.tensor(self.config.embed_dim, dtype=torch.float32))
            )
        )


    def apply_attention_blocks(
        self,
        x: torch.Tensor,
        xf_1: torch.Tensor,
        xf_2: torch.Tensor,
        outs: Dict[str, torch.Tensor],
        attention_mask_1: Optional[torch.Tensor] = None,
        attention_mask_2: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Create the blocks of attention layers and applies them.

        Each block cross-attends the latents to [modality; latents] — the
        current latents are concatenated to each modality's keys/values.
        Note: `outs` is passed through unmodified.
        """
        for layer in self.layers:
            # Latents are re-concatenated each iteration so keys/values always
            # reflect the current latent state.
            concat_input_1 = torch.cat([xf_1, x], dim=1)
            concat_input_2 = torch.cat([xf_2, x], dim=1)


            output = layer(
                x=x,
                cross_attention_embeddings_1=concat_input_1,
                cross_attention_embeddings_2=concat_input_2,
                attention_mask_1=attention_mask_1,
                attention_mask_2=attention_mask_2,
            )
            x = output["embeddings"]


        return x, outs


    def forward(
        self,
        input_embeddings_1: torch.Tensor,
        input_embeddings_2: torch.Tensor,
        attention_mask_1: Optional[torch.Tensor] = None,
        attention_mask_2: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Computes the embeddings based on the input tokens.

        Returns a dict whose "embeddings" entry has shape
        (batch_size, resampled_length, embed_dim).
        """
        assert (
            input_embeddings_1.shape[-1] == self.config.embed_dim
        ), "The input embedding dim should match the model embed dim"
        assert (
            input_embeddings_2.shape[-1] == self.config.embed_dim
        ), "The input embedding dim should match the model embed dim"


        batch_size = input_embeddings_1.shape[0]


        # Broadcast the learned latents across the batch.
        latent_queries = self.latent_queries.unsqueeze(0).repeat(batch_size, 1, 1)


        outs: Dict[str, torch.Tensor] = {}
        x = latent_queries


        x, outs = self.apply_attention_blocks(
            x=x,
            xf_1=input_embeddings_1,
            xf_2=input_embeddings_2,
            outs=outs,
            attention_mask_1=attention_mask_1,
            attention_mask_2=attention_mask_2,
        )


        outs["embeddings"] = x


        return outs
|
|
|
|
class TorchMultiModalPerceiverResamplerProjection(nn.Module):
    """
    Projects bio-model embeddings into the target space, embeds English tokens,
    and compresses both modalities to a fixed length with a multimodal
    perceiver resampler.
    """

    def __init__(
        self,
        perceiver_resampler_config: PerceiverResamplerConfig,
        input_embed_dim: int,
        embed_dim: int,
        bio_pad_token_id: int,
        english_pad_token_id: int,
        english_vocab_size: int,
    ):
        """
        Args:
            perceiver_resampler_config: Hyperparameters of the resampler.
            input_embed_dim: Dimension of incoming bio embeddings.
            embed_dim: Target embedding dimension.
            bio_pad_token_id: Pad token id of the bio tokenizer.
            english_pad_token_id: Pad token id of the English tokenizer.
            english_vocab_size: Size of the English vocabulary.
        """
        super().__init__()
        self.config = perceiver_resampler_config
        self.input_embed_dim = input_embed_dim
        self.embed_dim = embed_dim
        self.bio_pad_token_id = bio_pad_token_id
        self.english_pad_token_id = english_pad_token_id
        self.english_vocab_size = english_vocab_size

        self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
        self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
        self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)

    def forward(
        self,
        bio_token_ids: torch.Tensor,
        bio_embeddings: torch.Tensor,
        english_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            bio_token_ids (torch.Tensor):
                Shape (batch_size, num_bio_tokens)

            bio_embeddings (torch.Tensor):
                Shape (batch_size, num_bio_tokens, embed_dim)

            english_token_ids (torch.Tensor):
                Shape (batch_size, num_english_tokens)

        Returns:
            Resampled embeddings of shape
            (batch_size, resampled_length, embed_dim).
        """
        bio_inputs = self.bio_projection(bio_embeddings)
        english_inputs = self.token_embedding(english_token_ids)

        # Padding masks cover each modality's tokens plus the latent slots.
        bio_mask = build_perceiver_padding_attention_mask(
            bio_token_ids, self.config.resampled_length, self.bio_pad_token_id
        )
        english_mask = build_perceiver_padding_attention_mask(
            english_token_ids, self.config.resampled_length, self.english_pad_token_id
        )

        resampled = self.perceiver_resampler(
            input_embeddings_1=bio_inputs,
            attention_mask_1=bio_mask,
            input_embeddings_2=english_inputs,
            attention_mask_2=english_mask,
        )["embeddings"]
        return resampled
|
|
|
|
def build_perceiver_padding_attention_mask(
    tokens: torch.Tensor, resampled_length: int, pad_token_id: int
) -> torch.Tensor:
    """
    Builds the key padding mask for perceiver cross-attention, where the keys
    are the modality tokens concatenated with `resampled_length` latent slots.

    Args:
        tokens: Batch of sequences of shape (batch_size, seq_len).
        resampled_length: Number of latent (always-attendable) positions
            appended after the tokens.
        pad_token_id: Id of the <pad> token to mask out.

    Returns:
        Boolean mask of shape (batch_size, 1, resampled_length,
        seq_len + resampled_length); pad positions are False, latent slots
        are always True.
    """
    batch_size = tokens.shape[0]
    token_keep = tokens != pad_token_id
    latent_keep = torch.ones(
        (batch_size, resampled_length), dtype=torch.bool, device=tokens.device
    )
    keep = torch.cat([token_keep, latent_keep], dim=1)

    # Same key mask for every latent query position.
    return keep[:, None, None, :].repeat(1, 1, resampled_length, 1)
|
|