import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import logging
from .model import CodeEmbedder
# Setup Logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CodeTrainer:
    """Trains a CodeEmbedder with a triplet-margin objective."""

    def __init__(
        self,
        model: CodeEmbedder,
        train_loader: DataLoader,
        val_loader: DataLoader = None,
        epochs: int = 3,
        learning_rate: float = 2e-5,
        accumulation_steps: int = 1,
        mixed_precision: bool = True,
        output_dir: str = "./output",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader  # currently unused; reserved for a validation loop
        self.epochs = epochs
        self.lr = learning_rate
        self.accumulation_steps = accumulation_steps
        self.mixed_precision = mixed_precision
        self.output_dir = output_dir
        self.device = device

        # Optimizer
        self.optimizer = AdamW(self.model.parameters(), lr=self.lr)

        # Scheduler (optional: constant LR for now, can be switched to a linear warmup/decay later)
        # self.scheduler = ...

        # Mixed precision scaler (no-op when mixed_precision is False)
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)

        # Loss function: Triplet Margin Loss (standard for sentence/code embeddings).
        # It pushes the anchor-negative distance to be at least `margin` larger
        # than the anchor-positive distance.
        self.criterion = nn.TripletMarginLoss(margin=1.0, p=2)
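        # For reference, with d(x, y) = ||x - y||_2 (p=2 above), the loss is
        #     L(anchor, positive, negative) = max(d(anchor, positive) - d(anchor, negative) + margin, 0)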
    def train_step(self, batch):
        """
        Runs one training step on a triplet batch and returns the loss.
        """
        # The dataset/collator is expected to return the keys
        # 'anchor_input_ids', 'anchor_attention_mask', plus the matching
        # 'positive_*' and 'negative_*' keys (see the layout sketch below).

        # Helper to move a tensor to the training device
        to_device = lambda x: x.to(self.device)

        # Autocast for mixed precision
        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
            # 1. Forward pass for all three components of the triplet
            anchor_emb = self.model(to_device(batch['anchor_input_ids']), to_device(batch['anchor_attention_mask']))
            positive_emb = self.model(to_device(batch['positive_input_ids']), to_device(batch['positive_attention_mask']))
            negative_emb = self.model(to_device(batch['negative_input_ids']), to_device(batch['negative_attention_mask']))

            # 2. Compute the triplet loss
            loss = self.criterion(anchor_emb, positive_emb, negative_emb)

        return loss
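    # Illustrative batch layout expected by train_step (shapes depend on the
    # tokenizer/collator used to build the DataLoader; the key names are the
    # assumed contract with the dataset):
    # {
    #     'anchor_input_ids':      LongTensor [batch_size, seq_len],
    #     'anchor_attention_mask': LongTensor [batch_size, seq_len],
    #     'positive_input_ids':    ...,  'positive_attention_mask': ...,
    #     'negative_input_ids':    ...,  'negative_attention_mask': ...,
    # }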
    def train(self):
        logger.info(f"Starting training on {self.device}...")
        logger.info(f"Batch Size: {self.train_loader.batch_size}, Accumulation Steps: {self.accumulation_steps}")
        logger.info(f"Effective Batch Size: {self.train_loader.batch_size * self.accumulation_steps}")

        self.model.train()

        for epoch in range(self.epochs):
            total_loss = 0
            self.optimizer.zero_grad()
            progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")

            for step, batch in enumerate(progress_bar):
                # Forward pass + loss calculation
                loss = self.train_step(batch)

                # Gradient accumulation: scale the loss down so the accumulated
                # gradients match those of one large batch
                loss = loss / self.accumulation_steps

                # Backward pass (scaled for mixed precision)
                self.scaler.scale(loss).backward()

                # Update weights every `accumulation_steps` micro-batches.
                # Note: gradients from trailing batches that do not fill a full
                # accumulation window are discarded by the zero_grad at the start
                # of the next epoch.
                if (step + 1) % self.accumulation_steps == 0:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()

                # Track the un-normalized loss for logging
                total_loss += loss.item() * self.accumulation_steps
                progress_bar.set_postfix({'loss': total_loss / (step + 1)})

            # Save a checkpoint at the end of each epoch
            self.save_model(epoch + 1)
    def save_model(self, epoch):
        save_path = os.path.join(self.output_dir, f"checkpoint-{epoch}")
        os.makedirs(save_path, exist_ok=True)
        logger.info(f"Saving model to {save_path}...")

        # Save explicitly as safetensors via the transformers API.
        # We save the 'encoder' (the underlying AutoModel) and its config so the
        # checkpoint can be reloaded with AutoModel.from_pretrained by others.
        self.model.encoder.save_pretrained(save_path, safe_serialization=True)
        self.model.encoder.config.save_pretrained(save_path)
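
# Example usage sketch. `TripletCodeDataset` is a hypothetical dataset that yields
# the '{anchor,positive,negative}_{input_ids,attention_mask}' dicts described above;
# the CodeEmbedder constructor arguments depend on model.py. Run as a module
# (python -m ...) because of the relative imports.
if __name__ == "__main__":
    from .dataset import TripletCodeDataset  # hypothetical module

    dataset = TripletCodeDataset("data/triplets.jsonl")  # hypothetical path/format
    loader = DataLoader(dataset, batch_size=16, shuffle=True)

    trainer = CodeTrainer(
        model=CodeEmbedder(),
        train_loader=loader,
        epochs=3,
        accumulation_steps=4,
        output_dir="./output",
    )
    trainer.train()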