import math

import torch.nn as nn
from jaxtyping import Bool, Float
from torch import Tensor


class MultiHeadAttention(nn.Module):
    """
    Implements Multi-Head Attention.
    (Ref: "Attention Is All You Need", Section 3.2.2)

    Terminology (jaxtyping):
        B: batch_size
        T_q: target sequence length (query)
        T_k: source sequence length (key/value)
        D: d_model (model dimension)
        H: n_heads (number of heads)
        d_k: dimension of each head (d_model / n_heads)
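
    Example (illustrative usage sketch; the d_model/n_heads values are arbitrary):
        >>> import torch
        >>> mha = MultiHeadAttention(d_model=512, n_heads=8)
        >>> x = torch.randn(2, 10, 512)
        >>> mha(x, x, x).shape  # self-attention: q, k, v are the same tensor
        torch.Size([2, 10, 512])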
| """ |
|
|
| def __init__(self, d_model: int, n_heads: int) -> None: |
| super().__init__() |
| assert d_model % n_heads == 0, "d_model must be divisible by n_heads" |
|
|
| self.d_model: int = d_model |
| self.n_heads: int = n_heads |
| self.d_k: int = d_model // n_heads |
|
|
| self.w_q: nn.Linear = nn.Linear(d_model, d_model, bias=False) |
| self.w_k: nn.Linear = nn.Linear(d_model, d_model, bias=False) |
| self.w_v: nn.Linear = nn.Linear(d_model, d_model, bias=False) |
| self.w_o: nn.Linear = nn.Linear(d_model, d_model, bias=False) |
|
|
| self.attention_weights: Tensor | None = None |
|
|
    @staticmethod
    def attention(
        query: Float[Tensor, "B H T_q d_k"],
        key: Float[Tensor, "B H T_k d_k"],
        value: Float[Tensor, "B H T_k d_k"],
        mask: Bool[Tensor, "... 1 T_q T_k"] | None,
    ) -> tuple[Float[Tensor, "B H T_q d_k"], Float[Tensor, "B H T_q T_k"]]:
        """
        Static method for the Scaled Dot-Product Attention calculation.
        This is pure, stateless logic, making it easy to test.
        (Ref: "Attention Is All You Need", Equation 1)

        Args:
            query (Tensor): Query tensor.
            key (Tensor): Key tensor.
            value (Tensor): Value tensor.
            mask (Tensor | None): Optional mask (for padding or look-ahead).
                Positions where the mask is False/0 are excluded from attention.

        Returns:
            tuple[Tensor, Tensor]:
                - context_vector: The output of the attention mechanism.
                - attention_weights: The softmax-normalized attention weights.
        """

        d_k: int = query.shape[-1]

        # Scaled dot-product scores:
        # (B, H, T_q, d_k) @ (B, H, d_k, T_k) -> (B, H, T_q, T_k)
        attention_scores: Tensor = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)

        # Set masked positions to -inf so they receive zero weight after the softmax.
        if mask is not None:
            attention_scores = attention_scores.masked_fill(
                mask == 0, value=float("-inf")
            )

        # Normalize the scores over the key dimension.
        attention_weights: Tensor = attention_scores.softmax(dim=-1)

        # Weighted sum of the values:
        # (B, H, T_q, T_k) @ (B, H, T_k, d_k) -> (B, H, T_q, d_k)
        context_vector: Tensor = attention_weights @ value

        return context_vector, attention_weights

    def forward(
        self,
        q: Float[Tensor, "B T_q D"],
        k: Float[Tensor, "B T_k D"],
        v: Float[Tensor, "B T_k D"],
        mask: Bool[Tensor, "... 1 T_q T_k"] | None = None,
    ) -> Float[Tensor, "B T_q D"]:
        """
        Forward pass for Multi-Head Attention.

        In Self-Attention (Encoder), q, k, and v are all the same tensor.
        In Cross-Attention (Decoder), q comes from the Decoder, while k and v
        come from the Encoder's output.

        Args:
            q: Query tensor.
            k: Key tensor.
            v: Value tensor.
            mask: Optional mask to apply (padding or look-ahead).

        Returns:
            The context vector after multi-head attention and output projection.
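
        Example (illustrative sketch; shapes and sizes are arbitrary):
            >>> import torch
            >>> mha = MultiHeadAttention(d_model=512, n_heads=8)
            >>> dec = torch.randn(2, 7, 512)   # decoder-side queries
            >>> enc = torch.randn(2, 11, 512)  # encoder output (keys/values)
            >>> mha(dec, enc, enc).shape       # cross-attention
            torch.Size([2, 7, 512])
            >>> causal = torch.tril(torch.ones(7, 7, dtype=torch.bool)).view(1, 1, 7, 7)
            >>> mha(dec, dec, dec, mask=causal).shape  # masked (look-ahead) self-attention
            torch.Size([2, 7, 512])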
| """ |
|
|
| B, T_q, _ = q.shape |
| _, T_k, _ = k.shape |
|
|
| |
| Q: Tensor = self.w_q(q) |
| K: Tensor = self.w_k(k) |
| V: Tensor = self.w_v(v) |
|
|
| |
| Q = Q.view(B, T_q, self.n_heads, self.d_k).transpose(1, 2) |
| K = K.view(B, T_k, self.n_heads, self.d_k).transpose(1, 2) |
| V = V.view(B, T_k, self.n_heads, self.d_k).transpose(1, 2) |
|
|
| context_vector, self.attention_weights = self.attention(Q, K, V, mask) |
|
|
| |
| context_vector = context_vector.transpose(1, 2).contiguous() |
|
|
| |
| context_vector = context_vector.view(B, T_q, self.d_model) |
|
|
| |
| output: Tensor = self.w_o(context_vector) |
|
|
| return output |
|
|
|
|
class PositionwiseFeedForward(nn.Module):
    """
    Implements the Position-wise Feed-Forward Network (FFN) sublayer.
    (Ref: "Attention Is All You Need", Section 3.3)

    This is a two-layer MLP (Multi-Layer Perceptron) applied independently
    to each position in the sequence:

        FFN(x) = max(0, x * W_1 + b_1) * W_2 + b_2

    where max(0, .) is the ReLU activation applied between the two linear layers.

    Terminology (jaxtyping):
        B: batch_size
        T: seq_len (context_length)
        D: d_model (model dimension)
        D_FF: d_ff (inner feed-forward dimension)
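
    Example (illustrative usage sketch; the d_model/d_ff values are arbitrary):
        >>> import torch
        >>> ffn = PositionwiseFeedForward(d_model=512, d_ff=2048)
        >>> x = torch.randn(2, 10, 512)
        >>> ffn(x).shape
        torch.Size([2, 10, 512])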
| """ |
|
|
| def __init__(self, d_model: int, d_ff: int) -> None: |
| """ |
| Initializes the FFN. |
| |
| Args: |
| d_model (int): Dimension of the model (e.g., 512). |
| d_ff (int): Inner dimension of the FFN (e.g., 2048). |
| Paper suggests d_ff = 4 * d_model. |
| dropout (float): Dropout probability (applied *before* the |
| second linear layer in some implementations, |
| or as part of ResidualConnection). |
| """ |
| super().__init__() |
|
|
| |
| self.linear_1: nn.Linear = nn.Linear(d_model, d_ff) |
|
|
| self.activation: nn.ReLU = nn.ReLU() |
|
|
| |
| self.linear_2: nn.Linear = nn.Linear(d_ff, d_model) |
|
|
    def forward(self, x: Float[Tensor, "B T D"]) -> Float[Tensor, "B T D"]:
        """
        Forward pass for the FFN.
        Applies two linear transformations with a ReLU activation in between.

        Args:
            x: Input tensor from the previous sublayer
                (e.g., MultiHeadAttention output).

        Returns:
            Output tensor of the same shape.
        """
        # Expand: (B, T, D) -> (B, T, D_FF).
        x = self.linear_1(x)

        # Element-wise non-linearity.
        x = self.activation(x)

        # Project back: (B, T, D_FF) -> (B, T, D).
        x = self.linear_2(x)

        return x