import logging
import os

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

from .model import CodeEmbedder

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CodeTrainer:
    def __init__(
        self,
        model: CodeEmbedder,
        train_loader: DataLoader,
        val_loader: DataLoader = None,
        epochs: int = 3,
        learning_rate: float = 2e-5,
        accumulation_steps: int = 1,
        mixed_precision: bool = True,
        output_dir: str = "./output",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs
        self.lr = learning_rate
        self.accumulation_steps = accumulation_steps
        self.mixed_precision = mixed_precision
        self.output_dir = output_dir
        self.device = device

        self.optimizer = AdamW(self.model.parameters(), lr=self.lr)

        # Gradient scaler for mixed-precision training; it acts as a no-op
        # pass-through when mixed_precision is False.
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)

        # Triplet margin loss: max(d(a, p) - d(a, n) + margin, 0), where d is
        # the Euclidean (p=2) distance between embeddings.
        self.criterion = nn.TripletMarginLoss(margin=1.0, p=2)

    def train_step(self, batch):
        """
        Runs one forward pass over a triplet batch and returns the triplet loss.
        """
        to_device = lambda x: x.to(self.device)

        # The forward passes run under autocast so activations use mixed
        # precision when enabled.
        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
            anchor_emb = self.model(
                to_device(batch['anchor_input_ids']),
                to_device(batch['anchor_attention_mask']),
            )
            positive_emb = self.model(
                to_device(batch['positive_input_ids']),
                to_device(batch['positive_attention_mask']),
            )
            negative_emb = self.model(
                to_device(batch['negative_input_ids']),
                to_device(batch['negative_attention_mask']),
            )

            loss = self.criterion(anchor_emb, positive_emb, negative_emb)

        return loss

    def train(self):
        logger.info(f"Starting training on {self.device}...")
        logger.info(f"Batch Size: {self.train_loader.batch_size}, Accumulation Steps: {self.accumulation_steps}")
        logger.info(f"Effective Batch Size: {self.train_loader.batch_size * self.accumulation_steps}")

        self.model.train()

        for epoch in range(self.epochs):
            total_loss = 0
            self.optimizer.zero_grad()

            progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")

            for step, batch in enumerate(progress_bar):
                loss = self.train_step(batch)

                # Scale the loss so gradients accumulated over several
                # micro-batches average out to one effective batch.
                loss = loss / self.accumulation_steps

                self.scaler.scale(loss).backward()

                # Step the optimizer only once per accumulation window.
                if (step + 1) % self.accumulation_steps == 0:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()

                # Track the unscaled loss for logging.
                total_loss += loss.item() * self.accumulation_steps
                progress_bar.set_postfix({'loss': total_loss / (step + 1)})

            # Flush gradients from a trailing partial accumulation window so
            # they are not carried over into the next epoch.
            if (step + 1) % self.accumulation_steps != 0:
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

            self.save_model(epoch + 1)

    def save_model(self, epoch):
        save_path = os.path.join(self.output_dir, f"checkpoint-{epoch}")
        os.makedirs(save_path, exist_ok=True)

        logger.info(f"Saving model to {save_path}...")

        # Persist the underlying encoder weights (as safetensors) and the model
        # config so the checkpoint can be reloaded with from_pretrained.
        self.model.encoder.save_pretrained(save_path, safe_serialization=True)
        self.model.config.save_pretrained(save_path)
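
# Usage sketch (illustrative only): one way this trainer might be wired up.
# Assumptions not defined in this module: a TripletCodeDataset whose batches
# are dicts with the keys consumed by train_step (anchor/positive/negative
# *_input_ids and *_attention_mask), CodeEmbedder's constructor arguments,
# and the module names used in the imports.
#
#     from torch.utils.data import DataLoader
#     from .data import TripletCodeDataset    # hypothetical dataset class
#     from .model import CodeEmbedder
#     from .trainer import CodeTrainer        # module name assumed
#
#     model = CodeEmbedder()  # constructor args omitted; see model.py
#     train_loader = DataLoader(
#         TripletCodeDataset("data/train.jsonl"),  # hypothetical path/format
#         batch_size=16,
#         shuffle=True,
#     )
#     trainer = CodeTrainer(model, train_loader, epochs=3, accumulation_steps=4)
#     trainer.train()  # writes ./output/checkpoint-1 ... checkpoint-3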