import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import logging

from .model import CodeEmbedder

# Setup Logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CodeTrainer:
    def __init__(
        self,
        model: CodeEmbedder,
        train_loader: DataLoader,
        val_loader: DataLoader = None,
        epochs: int = 3,
        learning_rate: float = 2e-5,
        accumulation_steps: int = 1,
        mixed_precision: bool = True,
        output_dir: str = "./output",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs
        self.lr = learning_rate
        self.accumulation_steps = accumulation_steps
        self.mixed_precision = mixed_precision
        self.output_dir = output_dir
        self.device = device

        # Optimizer
        self.optimizer = AdamW(self.model.parameters(), lr=self.lr)

        # Scheduler (optional: constant for now, can switch to a linear schedule later)
        # self.scheduler = ...

        # Mixed Precision Scaler
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)

        # Loss Function: Triplet Margin Loss (standard for sentence embeddings).
        # Minimizes the anchor-positive distance while maximizing the anchor-negative distance.
        self.criterion = nn.TripletMarginLoss(margin=1.0, p=2)

    def train_step(self, batch):
        """
        Runs one training step. Returns the loss.
        """
        # Unpack the triplet batch.
        # The Dataset is expected to return keys: 'anchor_input_ids', 'anchor_attention_mask', etc.

        # Helper to move tensors to the device
        to_device = lambda x: x.to(self.device)

        # Autocast for mixed precision
        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
            # 1. Forward pass for all three components
            anchor_emb = self.model(to_device(batch['anchor_input_ids']),
                                    to_device(batch['anchor_attention_mask']))
            positive_emb = self.model(to_device(batch['positive_input_ids']),
                                      to_device(batch['positive_attention_mask']))
            negative_emb = self.model(to_device(batch['negative_input_ids']),
                                      to_device(batch['negative_attention_mask']))

            # 2. Compute triplet loss
            loss = self.criterion(anchor_emb, positive_emb, negative_emb)

        return loss
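
    # --- Optional: validation sketch (not part of the original trainer) ---
    # The constructor accepts a val_loader that is otherwise unused; the method below
    # is a minimal sketch of a validation pass, assuming val_loader yields batches
    # with the same triplet keys as the training batches.
    @torch.no_grad()
    def evaluate(self):
        if self.val_loader is None:
            return None
        self.model.eval()
        total_loss = 0.0
        for batch in self.val_loader:
            total_loss += self.train_step(batch).item()
        self.model.train()
        return total_loss / len(self.val_loader)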

    def train(self):
        logger.info(f"Starting training on {self.device}...")
        logger.info(f"Batch Size: {self.train_loader.batch_size}, Accumulation Steps: {self.accumulation_steps}")
        logger.info(f"Effective Batch Size: {self.train_loader.batch_size * self.accumulation_steps}")

        self.model.train()

        for epoch in range(self.epochs):
            total_loss = 0
            self.optimizer.zero_grad()

            progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")

            for step, batch in enumerate(progress_bar):
                # Forward + loss calculation
                loss = self.train_step(batch)

                # Gradient accumulation: normalize the loss
                loss = loss / self.accumulation_steps

                # Backward pass (scaled)
                self.scaler.scale(loss).backward()

                if (step + 1) % self.accumulation_steps == 0:
                    # Update weights
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()

                total_loss += loss.item() * self.accumulation_steps
                progress_bar.set_postfix({'loss': total_loss / (step + 1)})

            # Save checkpoint
            self.save_model(epoch + 1)

    def save_model(self, epoch):
        save_path = os.path.join(self.output_dir, f"checkpoint-{epoch}")
        os.makedirs(save_path, exist_ok=True)

        logger.info(f"Saving model to {save_path}...")

        # Save explicitly as safetensors via the transformers API.
        # Note: we save the 'encoder' (the underlying AutoModel) so the checkpoint
        # can be loaded easily by others.
        self.model.encoder.save_pretrained(save_path, safe_serialization=True)
        self.model.config.save_pretrained(save_path)
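
# --- Usage sketch (illustrative only, not part of the original module) ---
# Shows how the trainer could be wired up. `TripletCodeDataset` and the bare
# CodeEmbedder() constructor call are hypothetical placeholders; the only real
# requirement is that the DataLoader yields dicts with the keys train_step expects
# ('anchor_input_ids', 'anchor_attention_mask', 'positive_input_ids', ...).
#
# if __name__ == "__main__":
#     from .dataset import TripletCodeDataset  # hypothetical triplet dataset
#
#     model = CodeEmbedder()
#     train_loader = DataLoader(TripletCodeDataset("data/train.jsonl"),
#                               batch_size=16, shuffle=True)
#
#     trainer = CodeTrainer(
#         model=model,
#         train_loader=train_loader,
#         epochs=3,
#         accumulation_steps=4,  # effective batch size 16 * 4 = 64
#         output_dir="./output",
#     )
#     trainer.train()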