| import torch |
| from torch import Tensor |
| import torch.nn as nn |
| from safetensors.torch import load_model |
| from jaxtyping import Bool, Int, Float |
| from huggingface_hub import hf_hub_download |
| from embedding import InputEmbeddings, PositionalEncoding |
| from modules import Encoder, Decoder |
| import config |
|
|
|
|
class Generator(nn.Module):
    """
    Output head that maps decoder states to vocabulary logits.

    A single bias-free linear layer takes the Decoder output of shape
    (B, T, D) and projects it to (B, T, vocab_size). No softmax is applied
    in this module: the values returned by `forward` are raw logits.

    The projection weight may be shared with the target embedding matrix;
    that tying is performed by the owning model, not by this class.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Build the output projection.

        Args:
            d_model (int): Width of the incoming decoder features (D).
            vocab_size (int): Number of entries in the target vocabulary.
        """
        super().__init__()

        # Bias-free so the layer consists of exactly one (vocab_size, D)
        # weight matrix.
        projection = nn.Linear(
            in_features=d_model, out_features=vocab_size, bias=False
        )
        self.proj: nn.Linear = projection

    def forward(
        self, x: Float[Tensor, "B T_tgt D"]
    ) -> Float[Tensor, "B T_tgt vocab_size"]:
        """
        Project decoder features onto the vocabulary.

        Args:
            x (Tensor): Final output of the Decoder stack.

        Returns:
            Tensor: Unnormalised logits over the vocabulary.
        """
        return self.proj(x)
|
|
|
|
class Transformer(nn.Module):
    """
    Sequence-to-sequence Transformer built from an Encoder stack, a Decoder
    stack and an output projection, after "Attention Is All You Need".

    The source and target sides have separate token embeddings but share a
    single positional-encoding module. The output projection shares its
    weight matrix with the target token embedding (weight tying).

    The Encoder/Decoder internals (described by the original author as
    Pre-LN) are defined in the `modules` package, outside this class.
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int,
        n_heads: int,
        n_layers: int,
        d_ff: int,
        dropout: float = 0.1,
        max_seq_len: int = 512,
    ) -> None:
        """
        Initializes the full Transformer model.

        Args:
            src_vocab_size (int): Vocabulary size for the source language.
            tgt_vocab_size (int): Vocabulary size for the target language.
            d_model (int): The dimension of the model (D).
            n_heads (int): The number of attention heads (H).
            n_layers (int): The number of Encoder/Decoder layers (N).
            d_ff (int): The inner dimension of the Feed-Forward Network (D_FF).
            dropout (float): The dropout rate.
            max_seq_len (int): The maximum sequence length for positional encoding.
        """
        super().__init__()

        # Read by _init_weights to scale the embedding initialisation, so it
        # must be set before self.apply(...) runs.
        self.d_model = d_model

        # Separate token embeddings for the source and target vocabularies.
        self.src_embed: InputEmbeddings = InputEmbeddings(d_model, src_vocab_size)
        self.tgt_embed: InputEmbeddings = InputEmbeddings(d_model, tgt_vocab_size)

        # One positional-encoding module, applied to both source and target.
        self.pos_enc: PositionalEncoding = PositionalEncoding(
            d_model, max_seq_len, dropout
        )

        self.encoder: Encoder = Encoder(d_model, n_heads, d_ff, n_layers, dropout)
        self.decoder: Decoder = Decoder(d_model, n_heads, d_ff, n_layers, dropout)

        # Projects decoder output to target-vocabulary logits.
        self.generator: Generator = Generator(d_model, tgt_vocab_size)

        # Initialise every submodule BEFORE tying the weights.
        # nn.Module.apply visits children in registration order, so
        # `generator.proj` (an nn.Linear) is visited after `tgt_embed`. If the
        # two already shared one Parameter at this point, the Xavier init for
        # the Linear layer would silently overwrite the embedding-specific
        # normal(0, d_model**-0.5) init of the shared matrix.
        self.apply(self._init_weights)

        # Weight tying: the output projection reuses the target embedding
        # matrix, which keeps the embedding initialisation applied above.
        # NOTE(review): assumes `token_emb` is an nn.Embedding of shape
        # (tgt_vocab_size, d_model), matching the Linear weight layout.
        self.generator.proj.weight = self.tgt_embed.token_emb.weight

    def _init_weights(self, module: nn.Module) -> None:
        """
        Per-module initialiser passed to `self.apply`.

        - nn.Linear: Xavier/Glorot uniform weights, zero bias (if present).
        - nn.Embedding: normal weights with mean 0 and std d_model**-0.5.
        - Any other module type is left untouched.

        Args:
            module (nn.Module): The submodule currently being visited.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)

            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0, std=self.d_model**-0.5)

    def forward(
        self,
        src: Int[Tensor, "B T_src"],
        tgt: Int[Tensor, "B T_tgt"],
        src_mask: Bool[Tensor, "B 1 1 T_src"],
        tgt_mask: Bool[Tensor, "B 1 T_tgt T_tgt"],
    ) -> Float[Tensor, "B T_tgt vocab_size"]:
        """
        Defines the main forward pass of the Transformer model.

        Args:
            src (Tensor): Source sequence token IDs.
            tgt (Tensor): Target sequence token IDs (shifted right).
            src_mask (Tensor): Padding mask for the source sequence.
            tgt_mask (Tensor): Combined padding and look-ahead mask
                for the target sequence.

        Returns:
            Tensor: The output logits from the model (B, T_tgt, vocab_size).
        """
        # Source side: embed tokens, add positional information, encode.
        src_embeded = self.src_embed(src)
        src_with_pos = self.pos_enc(src_embeded)
        enc_output: Tensor = self.encoder(src_with_pos, src_mask)

        # Target side: embed tokens and add positional information using the
        # same positional-encoding module as the source side.
        tgt_embeded = self.tgt_embed(tgt)
        tgt_with_pos = self.pos_enc(tgt_embeded)

        # The decoder receives the encoder output together with both masks.
        dec_output: Tensor = self.decoder(tgt_with_pos, enc_output, src_mask, tgt_mask)

        # Project decoder features to vocabulary logits.
        logits: Tensor = self.generator(dec_output)

        return logits
|
|
|
|
def load_trained_model(
    config_obj, checkpoint_path, device: torch.device
) -> Transformer:
    """
    Build a Transformer from `config_obj` and populate it with weights
    downloaded from the Hugging Face Hub.

    Args:
        config_obj: Object exposing the hyperparameters VOCAB_SIZE, D_MODEL,
            N_HEADS, N_LAYERS, D_FF, DROPOUT and MAX_SEQ_LEN. VOCAB_SIZE is
            used for both the source and the target vocabulary.
        checkpoint_path: Not used by this function. The weights file is
            always resolved through `hf_hub_download` using the module-level
            `config.REPO_ID` and `config.FILENAME`.
        device (torch.device): Device the model is moved to before the
            weights are loaded.

    Returns:
        Transformer: The model with the downloaded weights loaded into it.
    """
    print("Downloading safetensors from Hub...")
    hub_repo = config.REPO_ID
    hub_file = config.FILENAME
    weights_file = hf_hub_download(repo_id=hub_repo, filename=hub_file)

    print("Instantiating the Transformer model...")
    hyperparams = {
        "src_vocab_size": config_obj.VOCAB_SIZE,
        "tgt_vocab_size": config_obj.VOCAB_SIZE,
        "d_model": config_obj.D_MODEL,
        "n_heads": config_obj.N_HEADS,
        "n_layers": config_obj.N_LAYERS,
        "d_ff": config_obj.D_FF,
        "dropout": config_obj.DROPOUT,
        "max_seq_len": config_obj.MAX_SEQ_LEN,
    }
    network = Transformer(**hyperparams)
    network = network.to(device)

    print(f"Loading model from: {weights_file}")
    # safetensors' load_model (rather than a plain state-dict load) is used to
    # fill the module in place from the downloaded file.
    load_model(network, filename=weights_file)

    print(f"Successfully loaded trained weights from {weights_file}")
    return network
|
|