| import torch |
| import math |
| from torch import nn |
| import torch.nn.functional as F |
| import einops |
| from rotary_embedding_torch import RotaryEmbedding |
|
|
class TransformerEncoder(torch.nn.Module):
    """
    Single pre-norm Transformer encoder layer.

    The layer runs two residual sub-blocks in sequence: layer norm followed
    by multi-head attention, then layer norm followed by a two-layer
    feed-forward network.

    Args:
        hidden_embed_size: Size of the embedding (last) dimension.
        n_attn_heads: Number of attention heads; must divide
            ``hidden_embed_size`` evenly.
        attn_dropout: Dropout probability handed to the attention module.
        layer_norm_eps: Epsilon used by both ``LayerNorm`` modules.
        a_fn: Name of the feed-forward activation, resolved through
            ``get_activation_fn``.
    """

    def __init__(
        self,
        hidden_embed_size,
        n_attn_heads,
        attn_dropout: float = 0.0,
        layer_norm_eps: float = 1e-05,
        a_fn: str = "gelu",
    ):
        super().__init__()

        assert hidden_embed_size % n_attn_heads == 0, \
            "Embedding dimension must be devisible with the number of heads."

        self.multihead_attention = MultiHeadAttention(
            embed_dim=hidden_embed_size,
            num_heads=n_attn_heads,
            attention_dropout_prob=attn_dropout,
        )

        # ``width_factor`` widens only the first projection: a gated
        # activation halves its input, so the second projection always
        # receives ``ffn_dim`` features.
        activation_cls, width_factor = get_activation_fn(a_fn)
        ffn_dim = hidden_embed_size * 4

        self.intermediate_layer = nn.Sequential(
            nn.Linear(hidden_embed_size, ffn_dim * width_factor),
            activation_cls(),
            nn.Linear(ffn_dim, hidden_embed_size),
        )

        self.pre_attn_layer_norm = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
        self.final_layer_norm = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)

    def forward(self, hidden_embed, attn_mask=None, return_attn_weights: bool = False):
        """
        Run the attention and feed-forward sub-blocks on ``hidden_embed``.

        Args:
            hidden_embed: Input embeddings; the attention module unpacks them
                as ``(batch, seq_len, embed_dim)``.
            attn_mask: Optional mask forwarded unchanged to the attention
                module.
            return_attn_weights: Forwarded to the attention module; controls
                whether attention weights are returned or ``None``.

        Returns:
            Tuple of the transformed embeddings and whatever the attention
            module returned as its second output.
        """
        # Attention sub-block. The norm layer is fed a copy of the input.
        normed = self.pre_attn_layer_norm(hidden_embed.clone())
        attn_out, attn_weights = self.multihead_attention(
            normed,
            attn_mask=attn_mask,
            return_attn_weights=return_attn_weights,
        )
        hidden_embed = hidden_embed + attn_out

        # Feed-forward sub-block.
        ffn_out = self.intermediate_layer(self.final_layer_norm(hidden_embed))
        return hidden_embed + ffn_out, attn_weights
| |
class MultiHeadAttention(torch.nn.Module):
    """
    Multi-head self-attention with rotary position embeddings.

    Queries, keys and values are linear projections of the same input.
    Queries and keys are rotated by ``RotaryEmbedding`` before scaled
    dot-product attention is computed per head.

    Note:
        Earlier revisions scaled the logits twice (queries were multiplied by
        ``head_dim ** -0.5`` in ``forward`` and the logits were divided by
        ``sqrt(head_dim)`` again in ``attention``), giving an effective scale
        of ``1 / head_dim``. The scale is now applied once, in ``attention``.
        Checkpoints trained with the double scaling will produce different
        outputs.

    Args:
        embed_dim: Size of the embedding (last) dimension of the input.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        attention_dropout_prob: Dropout probability applied to the attention
            probabilities.
        bias: Whether the four linear projections use a bias term.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        attention_dropout_prob: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()

        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
        # Kept as a public attribute for backward compatibility; the scale
        # itself is applied inside ``attention``.
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.reset_parameters()

        self.rotary_emb = RotaryEmbedding(dim=self.head_dim)

    def reset_parameters(self):
        """Re-initialise the projection weights with Xavier-uniform init.

        The q/k/v weights use a gain of ``1 / sqrt(2)``; the output
        projection uses the default gain and a zero bias. The q/k/v biases
        are left untouched.
        """
        qkv_gain = 1 / math.sqrt(2)
        for proj in (self.k_proj, self.v_proj, self.q_proj):
            nn.init.xavier_uniform_(proj.weight, gain=qkv_gain)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def attention(self, q, k, v, attn_mask=None):
        """Compute scaled dot-product attention.

        Args:
            q: Queries, ``(batch, heads, seq_len, head_dim)`` as produced by
                ``forward``.
            k: Keys, same layout as ``q``.
            v: Values, same layout as ``q``.
            attn_mask: Optional boolean tensor of shape ``(batch, seq_len)``.
                ``True`` marks key positions that no query may attend to.

        Returns:
            Tuple ``(attn, attn_weights)``: the weighted values and the
            softmax probabilities taken before dropout.
        """
        attn_weights = torch.matmul(q, k.transpose(-2, -1))
        attn_weights = attn_weights / math.sqrt(self.head_dim)

        if attn_mask is not None:
            # Add singleton head and query axes so the per-key mask
            # broadcasts over every head and every query position.
            attn_mask = einops.rearrange(
                attn_mask,
                'b_size (h1 h2 seq_len) -> b_size h1 h2 seq_len',
                h1=1, h2=1
            )
            # NOTE: a sequence whose keys are all masked becomes a row of
            # -inf, which softmax turns into NaN.
            attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))

        attn_weights = F.softmax(attn_weights, dim=-1)

        attn = self.attention_dropout(attn_weights)
        attn = torch.matmul(attn, v)
        return attn, attn_weights

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape ``(batch, seq_len, embed_dim)`` to ``(batch, heads, seq_len, head_dim)``."""
        return t.contiguous().view(
            batch_size,
            seq_len,
            self.num_heads,
            self.head_dim
        ).transpose(1, 2)

    def forward(self, x, attn_mask=None, return_attn_weights: bool = False):
        """Apply rotary multi-head self-attention to ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, embed_dim)``.
            attn_mask: Optional mask passed through to ``attention``.
            return_attn_weights: When true, the attention probabilities are
                returned as the second output; otherwise it is ``None``.

        Returns:
            Tuple ``(output, attn_weights_or_None)`` where ``output`` has the
            same shape as ``x``.
        """
        batch_size, seq_len, embed_dim = x.size()

        # Queries are NOT pre-scaled here: ``attention`` already divides the
        # logits by sqrt(head_dim), and doing both scaled them twice.
        q = self._split_heads(self.q_proj(x), batch_size, seq_len)
        k = self._split_heads(self.k_proj(x), batch_size, seq_len)
        v = self._split_heads(self.v_proj(x), batch_size, seq_len)

        # Positional information is injected by rotating queries and keys.
        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)

        attn, attn_weights = self.attention(
            q, k, v,
            attn_mask=attn_mask
        )

        # Merge the heads back into the embedding dimension.
        attn = attn.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        attn = self.out_proj(attn)

        if return_attn_weights:
            return attn, attn_weights
        return attn, None
| |
class SwiGLU(torch.nn.Module):
    """Gated activation: split the last dimension in two and gate one half.

    The first half of the features is the value, the second half is passed
    through SiLU and used as a multiplicative gate, so the output has half
    as many features as the input.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value
| |
def get_activation_fn(a_fn):
    """Resolve an activation name to its module class and width multiplier.

    Args:
        a_fn: Activation name, either ``"gelu"`` or ``"swiglu"``.

    Returns:
        Tuple ``(activation_class, scale)``. ``scale`` is the factor by which
        the layer feeding the activation must widen its output: 1 for GELU,
        2 for SwiGLU, which splits its input into a value half and a gate
        half.

    Raises:
        ValueError: If ``a_fn`` is not a supported activation name.
    """
    if a_fn == "gelu":
        return torch.nn.GELU, 1

    if a_fn == "swiglu":
        return SwiGLU, 2

    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an opaque tuple-unpacking TypeError in the caller.
    raise ValueError(
        f"Unsupported activation function {a_fn!r}; expected 'gelu' or 'swiglu'."
    )
| |