| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from transformers.tokenization_utils_fast import PreTrainedTokenizerFast |
| from torchmetrics.text import BLEUScore, SacreBLEUScore |
| from tqdm.auto import tqdm |
| import config |
| from src import model, utils |
|
|
|
|
# Target-vocabulary size; used to flatten decoder logits to
# (batch * seq_len, vocab) before the cross-entropy loss.
TGT_VOCAB_SIZE: int = config.VOCAB_SIZE
|
|
|
|
def train_one_epoch(
    model: model.Transformer,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    device: torch.device,
    logger=None,
) -> float:
    """
    Runs a single training epoch.

    Args:
        model: The Transformer model.
        dataloader: The training DataLoader.
        optimizer: The optimizer.
        criterion: The loss function (e.g., CrossEntropyLoss).
        scheduler: Per-step LR scheduler, stepped after every optimizer step.
        device: The device to run on (e.g., 'cuda').
        logger: Optional object with a ``log(dict)`` method (e.g., a wandb
            run); batch loss and current LR are logged every 100 batches
            when provided.

    Returns:
        The average training loss for the epoch.
    """
    model.train()
    total_loss = 0.0

    progress_bar = tqdm(dataloader, desc="Training", leave=False)

    # enumerate replaces the hand-rolled batch counter; start=1 keeps the
    # "% 100" logging cadence identical to the original behavior.
    for batch_idx, batch in enumerate(progress_bar, start=1):
        # Move only tensor fields to the target device; non-tensor entries
        # (if any) are intentionally dropped from the GPU batch.
        batch_gpu = {
            k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)
        }

        optimizer.zero_grad()

        # Teacher forcing: decoder consumes tgt_input_ids, loss compares
        # the logits against the shifted "labels" targets.
        logits = model(
            src=batch_gpu["src_ids"],
            tgt=batch_gpu["tgt_input_ids"],
            src_mask=batch_gpu["src_mask"],
            tgt_mask=batch_gpu["tgt_mask"],
        )

        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for CrossEntropyLoss.
        loss = criterion(logits.view(-1, TGT_VOCAB_SIZE), batch_gpu["labels"].view(-1))

        loss.backward()

        # Clip gradient norm to stabilize Transformer training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()  # per-batch schedule (e.g., warmup), not per-epoch

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

        if logger and batch_idx % 100 == 0:
            logger.log(
                {
                    "train/batch_loss": loss.item(),
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                }
            )

    return total_loss / len(dataloader)
|
|
|
|
def validate_one_epoch(
    model: model.Transformer,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> float:
    """
    Evaluate the model over the validation set for one full pass.

    Args:
        model: The Transformer model.
        dataloader: The validation DataLoader.
        criterion: The loss function (e.g., CrossEntropyLoss).
        device: The device to run on (e.g., 'cuda').

    Returns:
        The average validation loss for the epoch.
    """
    model.eval()

    running_loss = 0.0
    pbar = tqdm(dataloader, desc="Validating", leave=False)

    # No gradients are needed for validation.
    with torch.no_grad():
        for batch in pbar:
            # Transfer tensor entries to the compute device.
            on_device = {
                key: val.to(device)
                for key, val in batch.items()
                if isinstance(val, torch.Tensor)
            }

            logits = model(
                src=on_device["src_ids"],
                tgt=on_device["tgt_input_ids"],
                src_mask=on_device["src_mask"],
                tgt_mask=on_device["tgt_mask"],
            )

            # Flatten logits/targets so CrossEntropyLoss sees one token per row.
            batch_loss = criterion(
                logits.view(-1, TGT_VOCAB_SIZE), on_device["labels"].view(-1)
            )

            running_loss += batch_loss.item()
            pbar.set_postfix(loss=batch_loss.item())

    return running_loss / len(dataloader)
|
|
|
|
def evaluate_model(
    model: model.Transformer,
    dataloader: DataLoader,
    tokenizer: PreTrainedTokenizerFast,
    device: torch.device,
    table=None,
) -> tuple[float, float]:
    """
    Runs final evaluation on the test set using greedy decoding and
    computes corpus-level BLEU and SacreBLEU scores.

    Args:
        model: The Transformer model.
        dataloader: The test DataLoader.
        tokenizer: Tokenizer used to map token ids back to token strings.
        device: The device to run on (e.g., 'cuda').
        table: Optional table-like logger (e.g., a wandb.Table) with an
            ``add_data(expected, predicted)`` method; example pairs are
            added only when it is provided.

    Returns:
        A ``(bleu, sacrebleu)`` tuple, both scaled to the 0-100 range.
    """
    print("\n--- Starting Evaluation (BLEU + SacreBLEU) ---")

    model.eval()

    all_predicted_strings = []
    all_expected_strings = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            # Move only tensor fields to the target device.
            batch_gpu = {
                k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)
            }

            src_ids = batch_gpu["src_ids"]
            src_mask = batch_gpu["src_mask"]
            expected_ids = batch_gpu["labels"]

            B = src_ids.size(0)

            # Detokenize the reference translations for this batch.
            batch_expected_strings = []
            expected_id_lists = expected_ids.cpu().tolist()
            for id_list in expected_id_lists:
                token_list = tokenizer.convert_ids_to_tokens(id_list)
                batch_expected_strings.append(
                    utils.filter_and_detokenize(token_list, skip_special=True)
                )

            # Greedy-decode each source sentence individually.
            batch_predicted_strings = []
            for i in tqdm(range(B), desc="Decoding Batch", leave=False):
                src_sentence = src_ids[i].unsqueeze(0)
                src_sentence_mask = src_mask[i].unsqueeze(0)

                predicted_ids = utils.greedy_decode_sentence(
                    model,
                    src_sentence,
                    src_sentence_mask,
                    max_len=config.MAX_SEQ_LEN,
                    sos_token_id=config.SOS_TOKEN_ID,
                    eos_token_id=config.EOS_TOKEN_ID,
                    device=device,
                )

                predicted_id_list = predicted_ids.cpu().tolist()
                predicted_token_list = tokenizer.convert_ids_to_tokens(
                    predicted_id_list
                )
                decoded_str = utils.filter_and_detokenize(
                    predicted_token_list, skip_special=True
                )
                batch_predicted_strings.append(decoded_str)

            all_predicted_strings.extend(batch_predicted_strings)
            # torchmetrics BLEU expects a list of reference lists per prediction.
            all_expected_strings.extend([[s] for s in batch_expected_strings])

    # Use the `device` argument (not config.DEVICE) so the metrics live on
    # the same device the caller selected for evaluation.
    bleu_metric = BLEUScore(n_gram=4, smooth=True).to(device)
    sacrebleu_metric = SacreBLEUScore(
        n_gram=4, smooth=True, tokenize="intl", lowercase=False
    ).to(device)

    print("\nCalculating final BLEU score...")
    final_bleu = bleu_metric(all_predicted_strings, all_expected_strings)

    print("\nCalculating final SacreBLEU score...")
    final_sacrebleu = sacrebleu_metric(all_predicted_strings, all_expected_strings)

    print("\n--- Translation Examples (Pred vs Exp) ---")
    for i in range(min(5, len(all_predicted_strings))):
        print(f" PRED: {all_predicted_strings[i]}")
        print(f" EXP:  {all_expected_strings[i][0]}")
        print(" ---")

        # Guard: `table` defaults to None; the original called add_data
        # unconditionally and crashed when no table was supplied.
        if table is not None:
            table.add_data(all_expected_strings[i][0], all_predicted_strings[i])

    return final_bleu.item() * 100, final_sacrebleu.item() * 100
|
|