import html
import math
import time
from typing import Any, Dict, List, Optional

import torch
from tqdm import tqdm
from logger.logger import TrainerLogger
from torch.utils.data import DataLoader
from transformers import PreTrainedModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


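# Single-device trainer: the model and every batch are pinned to the `device`
# chosen above; multi-GPU setups (e.g. DDP) would need different plumbing.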
class BaseTrainer:
    def __init__(
        self,
        model: PreTrainedModel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        tokenizer: Any,
        train_loader: DataLoader,
        test_loader: Optional[DataLoader] = None,
        logger_config: Optional[Dict[str, Any]] = None,
        use_amp: bool = True,
    ):
        if logger_config is None:
            raise ValueError(
                "logger_config must provide 'tracking_uri', 'experiment' and 'model_name'"
            )

        self.model = model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.tokenizer = tokenizer
        self.train_loader = train_loader
        self.test_loader = test_loader
        # AMP only makes sense on CUDA; fall back to full precision elsewhere.
        self.use_amp = use_amp and device.type == "cuda"
        self.scaler = torch.amp.GradScaler("cuda") if self.use_amp else None
        self.train_step = 0
        self._best_perplexity = float("inf")
        self._epochs_no_improve = 0

        total_params = sum(p.numel() for p in model.parameters())
        self.logger = TrainerLogger(
            tracking_uri=logger_config["tracking_uri"],
            experiment=logger_config["experiment"],
            run_name=logger_config["model_name"],
            model_name=logger_config["model_name"],
            total_params=total_params,
            tags={"version": "1.0", "environment": "development"},
        )

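    # Qualitative check used during training: decode a few prompts and render
    # them as HTML for the logger. Combining num_beams with do_sample=True
    # selects Hugging Face's beam-sample decoding; drop num_beams for plain
    # top-k/top-p sampling.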
    def _generate_sample(self, sample_prompts: Optional[List[str]] = None):
        self.model.eval()
        samples_html = ""
        for prompt in sample_prompts or []:
            try:
                inputs = self.tokenizer(prompt, return_tensors="pt")
                input_ids = inputs.input_ids.to(self.model.device)

                with torch.no_grad(), torch.autocast(
                    device_type="cuda", dtype=torch.float16, enabled=self.use_amp
                ):
                    generated_ids = self.model.generate(
                        input_ids=input_ids,
                        max_length=100,
                        num_beams=5,
                        do_sample=True,
                        top_k=50,
                        top_p=0.95,
                        temperature=0.7,
                        repetition_penalty=1.2,
                        use_cache=True,
                        eos_token_id=self.tokenizer.eos_token_id,
                        pad_token_id=self.tokenizer.pad_token_id,
                    )

                generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            except Exception as e:
                generated_text = f"Error: {e}"
            # Escape both strings so angle brackets in prompts/outputs don't break the HTML.
            samples_html += (
                f"<h4><b>Prompt:</b> {html.escape(prompt)}</h4>"
                f"<p><b>Response:</b> {html.escape(generated_text)}</p>"
            )
        self.model.train()
        return samples_html

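    # Passing labels=input_ids relies on the Hugging Face causal-LM convention:
    # the model shifts labels internally, so position t is scored against token
    # t+1 and no manual shifting is needed here.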
    def _calc_loss_batch(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Compute only the cross-entropy loss for a batch of input_ids,
        with the key/value cache disabled during training.
        """
        ignore_idx = -100
        # Every label must be the ignore index or a valid token id. Note that
        # tokenizer.vocab_size excludes added tokens; prefer len(self.tokenizer)
        # if the tokenizer carries extra special tokens.
        valid = ((inputs >= 0) | (inputs == ignore_idx)) & (inputs < self.tokenizer.vocab_size)
        assert valid.all(), f"Invalid labels found: min={inputs.min().item()}, max={inputs.max().item()}"

        inputs = inputs.to(device)
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.use_amp):
            outputs = self.model(
                input_ids=inputs,
                labels=inputs,
                use_cache=False,
                return_dict=True,
            )
        loss = outputs.loss
        logits = outputs.logits
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            raise RuntimeError("Invalid logits detected")
        return loss

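    # One pass over train_loader. With AMP the required order of operations is
    # scale(loss).backward() -> unscale_() -> clip_grad_norm_() -> scaler.step()
    # -> scaler.update(); unscaling first ensures clipping sees true gradient
    # magnitudes rather than scaled ones.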
    def _train_epoch(self, epoch: int, sample_prompts: Optional[List[str]] = None) -> List[float]:
        if sample_prompts is None:
            sample_prompts = []

        self.model.train()
        losses = []
        num_batches = len(self.train_loader)
        pbar = tqdm(
            self.train_loader,
            total=num_batches,
            desc=f"Epoch {epoch + 1}",
            unit="batch",
            leave=False,
        )

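        # The scheduler is stepped once per batch, so it should be a per-step
        # schedule (e.g. linear warmup), not a per-epoch one. Metrics are
        # logged every 100 batches, samples every 500, checkpoints every 1000.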
        for i, batch in enumerate(pbar):
            start_time = time.time()
            self.optimizer.zero_grad()
            loss = self._calc_loss_batch(batch["input_ids"])
            losses.append(loss.item())

            if self.use_amp:
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

            self.scheduler.step()
            # Cap the exponent so a divergent loss can't overflow math.exp.
            perplexity = math.exp(min(loss.item(), 20.0))
            current_lr = self.optimizer.param_groups[0].get("lr", 0.0)
            elapsed_time = time.time() - start_time

            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "perplexity": f"{perplexity:.4f}",
                "lr": f"{current_lr:.2e}",
                "elapsed_time": f"{elapsed_time:.2f}s",
            })

            if (i + 1) % 100 == 0:
                self.train_step += 1
                avg_loss = sum(losses[-100:]) / 100
                avg_perplexity = math.exp(avg_loss)
                self.logger.log_metrics(
                    {
                        "train_loss": avg_loss,
                        "train_perplexity": avg_perplexity,
                        "lr": current_lr,
                    },
                    step=self.train_step,
                )

            if (i + 1) % 500 == 0:
                samples_html = self._generate_sample(sample_prompts)
                self.logger.log_html(
                    f"<html><head><meta charset='utf-8'></head><body>{samples_html}</body></html>",
                    step=self.train_step,
                )

            if (i + 1) % 1000 == 0:
                avg_loss = sum(losses[-1000:]) / 1000
                avg_perplexity = math.exp(avg_loss)
                self.logger.log_checkpoint_table(current_lr, avg_loss, avg_perplexity, i + 1)
                self.logger.checkpoint_model(self.model)
                self.model.save_pretrained("../")

        return losses

    def train(self, num_epochs: int = 500, sample_prompts: Optional[List[str]] = None):
        for epoch in range(num_epochs):
            train_losses = self._train_epoch(epoch, sample_prompts)
            mean_train_loss = sum(train_losses) / max(len(train_losses), 1)
            self.logger.log_metrics(
                {"mean_train_loss": mean_train_loss},
                step=epoch,
            )
            print(f"Epoch {epoch + 1} | Train Loss: {mean_train_loss:.4f}")

        self.logger.finish()
        print("Training complete!")


class TuningTrainer(BaseTrainer):
    """Fine-tuning trainer; currently reuses the BaseTrainer loop unchanged."""


class PreTrainer(BaseTrainer):
    """Pre-training trainer; currently reuses the BaseTrainer loop unchanged."""
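

# ---------------------------------------------------------------------------
# Minimal wiring sketch (illustrative only). The checkpoint name "gpt2", the
# toy corpus, and the tracking URI below are placeholders, not part of this
# project; swap in your own model, data, and logger settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.optim.lr_scheduler import LambdaLR
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # GPT-2 ships without a pad token
    lm = AutoModelForCausalLM.from_pretrained("gpt2")

    # Toy dataset: identical short sequences, batched as {"input_ids": ...}
    # the way _train_epoch expects.
    ids = tok(["hello world"] * 8, return_tensors="pt").input_ids
    loader = DataLoader(ids, batch_size=2, collate_fn=lambda b: {"input_ids": torch.stack(b)})

    opt = torch.optim.AdamW(lm.parameters(), lr=3e-4)
    # Per-step schedule: linear warmup over 100 steps, then constant.
    sched = LambdaLR(opt, lambda step: min(1.0, (step + 1) / 100))

    trainer = BaseTrainer(
        model=lm,
        optimizer=opt,
        scheduler=sched,
        tokenizer=tok,
        train_loader=loader,
        logger_config={
            "tracking_uri": "http://localhost:5000",  # placeholder tracking server
            "experiment": "demo",
            "model_name": "gpt2-demo",
        },
        use_amp=torch.cuda.is_available(),
    )
    trainer.train(num_epochs=1, sample_prompts=["Once upon a time"])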