| from pathlib import Path |
| import random |
| import re |
| from datetime import datetime |
| import numpy as np |
| import torch |
| from torch import Tensor |
| from transformers.tokenization_utils_fast import PreTrainedTokenizerFast |
| from jaxtyping import Bool, Int |
| import model |
|
|
|
|
| |
def seed_everything(seed: int = 42) -> None:
    """
    Seed every RNG in play (Python `random`, NumPy, PyTorch CPU and CUDA)
    so runs are reproducible.

    Also forces cuDNN into deterministic mode and disables its autotuner,
    trading some speed for repeatability.

    Args:
        seed (int): The seed value to apply to all RNGs.
    """
    # Apply the same seed to each framework's RNG in turn.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Deterministic cuDNN kernels; benchmark=False stops algorithm autotuning,
    # which could otherwise pick different kernels run-to-run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
def make_run_name(model_name: str, d_model: int) -> str:
    """Build a unique run identifier: ``<model_name>-<d_model>d-<YYYYMMDD_HHMMSS>``."""
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    return "-".join([model_name, f"{d_model}d", stamp])
|
|
|
|
| |
def load_tokenizer(tokenizer_path: str | Path) -> PreTrainedTokenizerFast:
    """
    Load a trained tokenizer from a JSON file and register its special tokens.

    Args:
        tokenizer_path (str | Path): Path to the tokenizer JSON file
            (a `tokenizers`-library export).

    Returns:
        PreTrainedTokenizerFast: Loaded tokenizer with [PAD]/[UNK]/[SOS]/[EOS]
        mapped to the pad/unk/bos/eos special-token slots, so attributes like
        `pad_token_id` and `bos_token_id` resolve correctly.
    """
    print(f"Loading tokenizer from {tokenizer_path}...")
    # Wrap the raw tokenizer JSON in the HuggingFace fast-tokenizer API.
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
    # Register the project's special tokens; assumes these strings exist in
    # the trained vocab — otherwise their ids will be None. TODO confirm.
    tokenizer.pad_token = "[PAD]"
    tokenizer.unk_token = "[UNK]"
    tokenizer.bos_token = "[SOS]"
    tokenizer.eos_token = "[EOS]"
    return tokenizer
|
|
|
|
def create_padding_mask(
    input_ids: Int[Tensor, "B T_k"], pad_token_id: int
) -> Bool[Tensor, "B 1 1 T_k"]:
    """
    Build a key-padding mask for the attention mechanism.

    Every position holding the pad token becomes False ("mask out") and
    every real token becomes True ("keep"). Two singleton dimensions are
    inserted so the mask broadcasts across the attention-score tensor of
    shape (B, H, T_q, T_k).

    Args:
        input_ids (Tensor): Token ids, shape (B, T_k).
        pad_token_id (int): Id of the padding token.

    Returns:
        Tensor: Boolean mask of shape (B, 1, 1, T_k);
        True = keep, False = pad.
    """
    keep = input_ids.ne(pad_token_id)
    # (B, T_k) -> (B, 1, 1, T_k): broadcastable over heads and query length.
    return keep[:, None, None, :]
|
|
|
|
def create_look_ahead_mask(seq_len: int) -> Bool[Tensor, "1 1 T_q T_q"]:
    """
    Creates a causal (look-ahead) mask for the Decoder's self-attention.

    This mask prevents positions from attending to subsequent positions.
    It's a square matrix where the upper triangle (future) is False
    and the lower triangle (past/present) is True.

    The mask is created on the CPU; callers move it to the right device
    (e.g. via ``.to(device)``).

    Args:
        seq_len (int): The sequence length (T_q).

    Returns:
        Tensor: A boolean mask of shape (1, 1, T_q, T_q).
                'True' means "keep" (allowed to see).
                'False' means "mask out" (future token).
    """
    # Build the mask directly in bool dtype: tril keeps past/present (True)
    # and zeroes out the future (False). Avoids the float-ones + `== 1`
    # round-trip of the naive formulation.
    causal: Tensor = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Prepend broadcast dims for batch and heads: (T_q, T_q) -> (1, 1, T_q, T_q).
    return causal.unsqueeze(0).unsqueeze(0)
|
|
|
|
def greedy_decode_sentence(
    model: model.Transformer,
    src: Int[Tensor, "1 T_src"],
    src_mask: Bool[Tensor, "1 1 1 T_src"],
    max_len: int,
    sos_token_id: int,
    eos_token_id: int,
    device: torch.device,
) -> Int[Tensor, " T_out"]:
    """
    Performs greedy decoding for a single sentence.
    This is an autoregressive process (token by token).

    Args:
        model: The trained Transformer model (already on device).
        src: The source token IDs (e.g., English). Shape (1, T_src).
        src_mask: The padding mask for the source.
        max_len: The maximum total length to generate (including [SOS]).
        sos_token_id: The ID for [SOS] token.
        eos_token_id: The ID for [EOS] token.
        device: The device to run on.

    Returns:
        Tensor: The generated target token IDs (e.g., Vietnamese) as a
        1-D tensor of shape (T_out,) — the batch dim is squeezed out.
        Includes the leading [SOS] and, if generated, the trailing [EOS].
    """
    # Disable dropout etc.; inference only.
    model.eval()

    with torch.no_grad():
        # Encode the source once; the encoder output is reused every
        # decoding step.
        src_embedded = model.src_embed(src)
        src_with_pos = model.pos_enc(src_embedded)
        enc_output: Tensor = model.encoder(src_with_pos, src_mask)

        # Start the target sequence with a single [SOS] token, shape (1, 1).
        decoder_input: Tensor = torch.tensor(
            [[sos_token_id]], dtype=torch.long, device=device
        )

        # Generate up to max_len - 1 new tokens ([SOS] already occupies
        # one slot), stopping early on [EOS].
        for _ in range(max_len - 1):
            # Re-embed the full (growing) target prefix each step.
            # NOTE(review): no KV caching — each step re-runs the decoder
            # over the whole prefix; fine for single-sentence inference.
            tgt_embedded = model.tgt_embed(decoder_input)
            tgt_with_pos = model.pos_enc(tgt_embedded)

            # Causal mask sized to the current prefix length so positions
            # cannot attend to later ones.
            T_tgt = decoder_input.size(1)
            tgt_mask = create_look_ahead_mask(T_tgt).to(device)

            # Decode against the cached encoder output.
            dec_output: Tensor = model.decoder(
                tgt_with_pos, enc_output, src_mask, tgt_mask
            )
            # Project decoder states to vocabulary logits.
            logits: Tensor = model.generator(dec_output)

            # Only the last position's logits matter for the next token.
            last_token_logits = logits[:, -1, :]

            # Greedy choice: argmax over vocab; unsqueeze back to (1, 1)
            # so it can be concatenated onto the (1, T) prefix.
            next_token: Tensor = torch.argmax(last_token_logits, dim=-1).unsqueeze(-1)

            # Append the chosen token to the decoder input.
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            # Stop as soon as the model emits [EOS].
            if next_token.item() == eos_token_id:
                break

    # Drop the batch dimension: (1, T_out) -> (T_out,).
    return decoder_input.squeeze(0)
|
|
|
|
def filter_and_detokenize(token_list: list[str], skip_special: bool = True) -> str:
    """
    Join whitespace tokens back into a sentence.

    Optionally drops the special tokens first, then fixes the spacing
    artifacts that naive space-joining creates around punctuation and
    contractions.
    """
    tokens = token_list
    if skip_special:
        # Drop the model's bookkeeping tokens before joining.
        drop = {"[PAD]", "[UNK]", "[SOS]", "[EOS]"}
        tokens = [tok for tok in tokens if tok not in drop]

    sentence = " ".join(tokens)

    # Pull punctuation back onto the preceding word ("word ," -> "word,").
    sentence = re.sub(r'\s([.,!?\'":;])', r"\1", sentence)
    # Re-attach apostrophe contractions ("don 't" -> "don't").
    sentence = re.sub(r"(\w)\s(\'\w)", r"\1\2", sentence)

    return sentence
|
|
|
|
| |
| |
def translate(
    model: model.Transformer,
    tokenizer: PreTrainedTokenizerFast,
    sentence_en: str,
    device: torch.device,
    max_len: int,
    sos_token_id: int,
    eos_token_id: int,
    pad_token_id: int,
) -> str:
    """
    Translates a single English sentence to Vietnamese.

    Pipeline: tokenize -> build source padding mask -> greedy decode ->
    convert ids back to tokens -> detokenize into a clean string.

    Args:
        model: The trained Transformer model.
        tokenizer: The (PreTrainedTokenizerFast) tokenizer.
        sentence_en: The raw English input string.
        device: The device to run on.
        max_len: The max sequence length (from config).
        sos_token_id: The ID for [SOS].
        eos_token_id: The ID for [EOS].
        pad_token_id: The ID for [PAD].

    Returns:
        str: The translated Vietnamese string.
    """
    # Inference mode: no dropout, no gradients.
    model.eval()

    with torch.no_grad():
        # Tokenize without [SOS]/[EOS]; the decode loop manages those.
        encoding = tokenizer(
            sentence_en,
            truncation=True,
            max_length=max_len,
            add_special_tokens=False,
        )

        # Batch of one: (1, T_src) long tensor on the target device.
        source_ids: Tensor = torch.tensor(
            [encoding["input_ids"]], dtype=torch.long
        ).to(device)

        # Mask out any pad positions in the source.
        source_mask: Tensor = create_padding_mask(source_ids, pad_token_id).to(device)

        # Autoregressive greedy generation of the target ids.
        generated: Tensor = greedy_decode_sentence(
            model,
            source_ids,
            source_mask,
            max_len=max_len,
            sos_token_id=sos_token_id,
            eos_token_id=eos_token_id,
            device=device,
        )

        # Back to Python land: ids -> token strings -> clean sentence.
        token_strings = tokenizer.convert_ids_to_tokens(generated.cpu().tolist())
        return filter_and_detokenize(token_strings, skip_special=True)
|
|
# Import-time status message. NOTE(review): consider moving this under an
# `if __name__ == "__main__":` guard so importing this module stays silent.
print("Inference function `translate()` defined.")
|
|