import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as ArrowDataset
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

import config
from src import utils


class TranslationDataset(Dataset):
    """
    A "lazy" map-style Dataset: raw text is stored as-is and tokenization
    happens on the fly in __getitem__, via the high-level
    PreTrainedTokenizerFast wrapper.
    """

    def __init__(
        self,
        dataset: ArrowDataset,
        tokenizer: PreTrainedTokenizerFast,
        max_len_src: int,
        max_len_tgt: int,
        src_lang: str = "en",
        tgt_lang: str = "vi",
    ):
        super().__init__()
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_len_src = max_len_src
        self.max_len_tgt = max_len_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> dict[str, list[int]]:
        item = self.dataset[index]["translation"]
        src_text = item[self.src_lang]
        tgt_text = item[self.tgt_lang]

        # Source: tokenize without special tokens; the encoder consumes the
        # raw ids directly.
        src_encoding = self.tokenizer(
            src_text,
            truncation=True,
            max_length=self.max_len_src,
            add_special_tokens=False,
        )

        # Target: reserve 2 positions for the SOS/EOS tokens added below.
        tgt_encoding = self.tokenizer(
            tgt_text,
            truncation=True,
            max_length=self.max_len_tgt - 2,
            add_special_tokens=False,
        )

        src_ids = src_encoding["input_ids"]
        tgt_ids = (
            [config.SOS_TOKEN_ID] + tgt_encoding["input_ids"] + [config.EOS_TOKEN_ID]
        )

        return {"src_ids": src_ids, "tgt_ids": tgt_ids}
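
# Example item (illustrative token ids): for the pair
# {"en": "hello world", "vi": "xin chào"}, __getitem__ might return
#   {"src_ids": [17, 42], "tgt_ids": [SOS_TOKEN_ID, 99, 7, EOS_TOKEN_ID]}.
# Padding and tensor conversion are deferred to the DataCollator, so each
# batch is padded only to its own longest sequence.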


class DataCollator:
    """
    Implements a custom collate_fn.

    Given a list of dicts from TranslationDataset.__getitem__ (SOS/EOS are
    already attached to the target there), it:
    1. Creates decoder inputs and labels (shifted by one position)
    2. Dynamically pads all sequences to the longest in the batch
    3. Builds the required attention masks (source padding, target padding,
       look-ahead)
    4. Returns a single dict of tensors
    """

    def __init__(self, pad_token_id: int):
        self.pad_token_id = pad_token_id

    def __call__(self, batch: list[dict[str, list[int]]]) -> dict[str, Tensor]:
        src_ids_list = [item["src_ids"] for item in batch]
        tgt_ids_list = [item["tgt_ids"] for item in batch]

        # Teacher forcing: the decoder input drops the final token (EOS) and
        # the labels drop the first token (SOS).
        dec_input_ids_list = [ids[:-1] for ids in tgt_ids_list]
        labels_list = [ids[1:] for ids in tgt_ids_list]
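        # Worked example (illustrative ids, SOS=1, EOS=2):
        #   tgt_ids       = [1, 5, 6, 2]
        #   decoder input = [1, 5, 6]      (EOS dropped)
        #   labels        = [5, 6, 2]      (SOS dropped)
        # so each position is trained to predict the next token.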

        # Dynamically pad each field to the longest sequence in this batch.
        src_ids_padded = nn.utils.rnn.pad_sequence(
            [torch.tensor(ids) for ids in src_ids_list],
            batch_first=True,
            padding_value=self.pad_token_id,
        )

        dec_input_ids_padded = nn.utils.rnn.pad_sequence(
            [torch.tensor(ids) for ids in dec_input_ids_list],
            batch_first=True,
            padding_value=self.pad_token_id,
        )

        labels_padded = nn.utils.rnn.pad_sequence(
            [torch.tensor(ids) for ids in labels_list],
            batch_first=True,
            padding_value=self.pad_token_id,
        )

        _, T_tgt = dec_input_ids_padded.shape

        # Encoder attention mask: hide padding positions in the source.
        src_mask = utils.create_padding_mask(src_ids_padded, self.pad_token_id)

        # Decoder masks: hide padding in the decoder input, and forbid
        # attending to future positions (causal / look-ahead mask).
        tgt_padding_mask = utils.create_padding_mask(
            dec_input_ids_padded, self.pad_token_id
        )
        look_ahead_mask = utils.create_look_ahead_mask(T_tgt)

        # Combine: a target position may be attended to only if it is both
        # non-padding and not in the future.
        tgt_mask = tgt_padding_mask & look_ahead_mask

        return {
            "src_ids": src_ids_padded,
            "tgt_input_ids": dec_input_ids_padded,
            "labels": labels_padded,
            "src_mask": src_mask,
            "tgt_mask": tgt_mask,
        }
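

# For reference, a minimal sketch of what the mask helpers in src.utils are
# assumed to look like; the shapes are chosen so that the boolean "&" above
# broadcasts to (B, 1, T, T), but the real implementations may differ:
#
#   def create_padding_mask(ids: Tensor, pad_token_id: int) -> Tensor:
#       # (B, T) -> (B, 1, 1, T); True where attention is allowed.
#       return (ids != pad_token_id)[:, None, None, :]
#
#   def create_look_ahead_mask(size: int) -> Tensor:
#       # (1, 1, T, T) lower-triangular causal mask; True on/below the diagonal.
#       return torch.tril(torch.ones(size, size, dtype=torch.bool))[None, None]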


def get_translation_datasets(
    tokenizer: PreTrainedTokenizerFast,
) -> tuple[TranslationDataset, TranslationDataset, TranslationDataset]:
    """
    A factory function to automate the data pipeline setup.

    It performs 3 steps:
    1. Loads and cleans the raw data (using src.utils).
    2. Instantiates a TranslationDataset for the train, val, and test splits.
    3. Returns the 3 PyTorch datasets, ready for the DataLoader.

    Args:
        tokenizer: The trained tokenizer.

    Returns:
        Tuple containing (train_ds, val_ds, test_ds)
    """
    train_data, val_data, test_data = utils.get_raw_data(
        config.DATA_PATH, num_workers=config.NUM_WORKERS
    )
    # Cap the training set size (e.g. for faster experimentation).
    train_data = train_data.select(range(config.NUM_SAMPLES_TO_USE))

    print("Building PyTorch Datasets...")

    train_ds = TranslationDataset(
        dataset=train_data,
        tokenizer=tokenizer,
        max_len_src=config.MAX_SEQ_LEN,
        max_len_tgt=config.MAX_SEQ_LEN,
    )

    val_ds = TranslationDataset(
        dataset=val_data,
        tokenizer=tokenizer,
        max_len_src=config.MAX_SEQ_LEN,
        max_len_tgt=config.MAX_SEQ_LEN,
    )

    test_ds = TranslationDataset(
        dataset=test_data,
        tokenizer=tokenizer,
        max_len_src=config.MAX_SEQ_LEN,
        max_len_tgt=config.MAX_SEQ_LEN,
    )

    print(
        f"Datasets created: Train={len(train_ds)}, Val={len(val_ds)}, Test={len(test_ds)}"
    )

    return train_ds, val_ds, test_ds


def get_dataloaders(
    tokenizer: PreTrainedTokenizerFast,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """
    A high-level factory function to create DataLoaders.

    This function abstracts away all the data pipeline complexity:
    - Loading/cleaning the raw data
    - Creating the PyTorch Datasets
    - Instantiating the DataCollator (dynamic padding)
    - Creating DataLoaders with the correct batch size and workers

    Args:
        tokenizer: The trained tokenizer.

    Returns:
        Tuple containing (train_loader, val_loader, test_loader)
    """
    train_ds, val_ds, test_ds = get_translation_datasets(tokenizer)

    collator = DataCollator(pad_token_id=config.PAD_TOKEN_ID)

    print(
        f"Building DataLoaders (Batch Size: {config.BATCH_SIZE}, Workers: {config.NUM_WORKERS})..."
    )

    # Note: prefetch_factor and persistent_workers require NUM_WORKERS > 0;
    # DataLoader raises a ValueError otherwise.
    train_loader = DataLoader(
        train_ds,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        collate_fn=collator,
        pin_memory=(config.DEVICE == "cuda"),
        prefetch_factor=2,
        persistent_workers=True,
    )

    # Evaluation stores no gradients, so a larger batch size fits in memory.
    val_loader = DataLoader(
        val_ds,
        batch_size=2 * config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        collate_fn=collator,
        pin_memory=(config.DEVICE == "cuda"),
        prefetch_factor=2,
        persistent_workers=True,
    )

    # The test loader runs once, so a couple of workers are enough.
    test_loader = DataLoader(
        test_ds,
        batch_size=2 * config.BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        collate_fn=collator,
        pin_memory=(config.DEVICE == "cuda"),
        prefetch_factor=2,
    )

    print(f"DataLoader (train) created with {len(train_loader)} batches.")
    print(f"DataLoader (val) created with {len(val_loader)} batches.")
    print(f"DataLoader (test) created with {len(test_loader)} batches.")

    return train_loader, val_loader, test_loader
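

# A minimal smoke test of the collator, assuming config defines PAD_TOKEN_ID,
# SOS_TOKEN_ID, and EOS_TOKEN_ID; the other token ids below are illustrative.
if __name__ == "__main__":
    fake_batch = [
        {
            "src_ids": [11, 12, 13],
            "tgt_ids": [config.SOS_TOKEN_ID, 21, config.EOS_TOKEN_ID],
        },
        {
            "src_ids": [14],
            "tgt_ids": [config.SOS_TOKEN_ID, 22, 23, 24, config.EOS_TOKEN_ID],
        },
    ]
    collated = DataCollator(pad_token_id=config.PAD_TOKEN_ID)(fake_batch)
    for name, tensor in collated.items():
        print(f"{name}: {tuple(tensor.shape)}")
    # Expected: src_ids (2, 3), tgt_input_ids (2, 4), labels (2, 4), plus the
    # two masks with matching batch/sequence dimensions.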