File size: 4,634 Bytes
463fc7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import logging
from .model import CodeEmbedder

# Set up logging for this module.
# NOTE: basicConfig() runs at import time and configures the *root* logger at
# INFO level; per the stdlib docs it does nothing if the root logger already
# has handlers, so an application that configured logging earlier is unaffected.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class CodeTrainer:
    """Train an embedding model on (anchor, positive, negative) triplets.

    The trainer optimizes ``nn.TripletMarginLoss`` with ``AdamW``, supports
    gradient accumulation and (on CUDA only) automatic mixed precision, and
    writes a checkpoint directory after every epoch.

    Attributes:
        model: The model being trained, already moved to ``device``.
        train_loader: Iterable of training batches (see ``train_step`` for
            the keys each batch must provide).
        val_loader: Stored for callers; it is not used by any method of
            this class.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate given to ``AdamW``.
        accumulation_steps: Number of micro-batches whose gradients are
            summed before one optimizer step.
        mixed_precision: The caller's request for mixed precision. AMP is
            only actually enabled when ``device`` is a CUDA device.
        output_dir: Directory under which ``checkpoint-<epoch>`` folders are
            created.
        device: Device the model and batches are moved to.
    """

    def __init__(
        self,
        model: CodeEmbedder,
        train_loader: DataLoader,
        val_loader: DataLoader = None,
        epochs: int = 3,
        learning_rate: float = 2e-5,
        accumulation_steps: int = 1,
        mixed_precision: bool = True,
        output_dir: str = "./output",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """Build the optimizer, gradient scaler and loss for training.

        Raises:
            ValueError: If ``accumulation_steps`` is smaller than 1 (it is
                used as a divisor for the loss and as the step interval).
        """
        if accumulation_steps < 1:
            raise ValueError(
                f"accumulation_steps must be >= 1, got {accumulation_steps}"
            )

        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs
        self.lr = learning_rate
        self.accumulation_steps = accumulation_steps
        self.mixed_precision = mixed_precision
        self.output_dir = output_dir
        self.device = device

        # The torch.cuda.amp utilities only work with CUDA tensors. Enabling
        # them for a CPU run either just warns (no CUDA installed) or makes
        # GradScaler.scale() reject the CPU loss (CUDA installed but
        # device="cpu"), so AMP is gated on the device actually being CUDA.
        self._amp_enabled = bool(mixed_precision) and (
            torch.device(device).type == "cuda"
        )

        # Optimizer
        self.optimizer = AdamW(self.model.parameters(), lr=self.lr)

        # No LR scheduler: the learning rate stays constant for the whole run.

        # Mixed Precision Scaler (a transparent no-op when disabled)
        self.scaler = torch.cuda.amp.GradScaler(enabled=self._amp_enabled)

        # Loss Function: Triplet Margin Loss with Euclidean (p=2) distance.
        # Pushes anchor-negative distance to exceed anchor-positive distance
        # by at least the margin.
        self.criterion = nn.TripletMarginLoss(margin=1.0, p=2)

    def _embed(self, batch, prefix):
        """Run the model on the ``<prefix>_input_ids`` / ``<prefix>_attention_mask`` pair of ``batch``."""
        input_ids = batch[f"{prefix}_input_ids"].to(self.device)
        attention_mask = batch[f"{prefix}_attention_mask"].to(self.device)
        return self.model(input_ids, attention_mask)

    def _optimizer_step(self):
        """Apply the accumulated gradients and reset them for the next group."""
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

    def train_step(self, batch):
        """Run the forward pass for one triplet batch and return its loss.

        Args:
            batch: Mapping with the keys ``anchor_input_ids``,
                ``anchor_attention_mask``, ``positive_input_ids``,
                ``positive_attention_mask``, ``negative_input_ids`` and
                ``negative_attention_mask``; each value must support
                ``.to(device)``.

        Returns:
            The triplet margin loss for the batch. No backward pass is run
            here.
        """
        # Autocast for Mixed Precision (disabled on non-CUDA devices)
        with torch.cuda.amp.autocast(enabled=self._amp_enabled):
            # 1. Forward pass for all 3 components
            anchor_emb = self._embed(batch, "anchor")
            positive_emb = self._embed(batch, "positive")
            negative_emb = self._embed(batch, "negative")

            # 2. Compute Triplet Loss
            loss = self.criterion(anchor_emb, positive_emb, negative_emb)

        return loss

    def train(self):
        """Train for ``self.epochs`` epochs, saving a checkpoint after each.

        Gradients are accumulated over ``accumulation_steps`` micro-batches
        before each optimizer step. If an epoch ends with an incomplete
        group, the gradients collected so far are still applied instead of
        being thrown away by the next epoch's ``zero_grad()``.
        """
        logger.info("Starting training on %s...", self.device)

        # batch_size is None when the loader was built with a batch_sampler,
        # so the effective batch size cannot always be computed.
        batch_size = getattr(self.train_loader, "batch_size", None)
        logger.info(
            "Batch Size: %s, Accumulation Steps: %s",
            batch_size,
            self.accumulation_steps,
        )
        if batch_size is None:
            effective_batch_size = "unknown"
        else:
            effective_batch_size = batch_size * self.accumulation_steps
        logger.info("Effective Batch Size: %s", effective_batch_size)

        self.model.train()

        for epoch in range(self.epochs):
            total_loss = 0.0
            pending = 0  # micro-batches back-propagated since the last step
            self.optimizer.zero_grad()

            progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")

            for step, batch in enumerate(progress_bar):
                # Forward + loss, normalized so that the summed gradients of
                # one full accumulation group average over its micro-batches.
                loss = self.train_step(batch) / self.accumulation_steps

                # Backward Pass (Scaled)
                self.scaler.scale(loss).backward()
                pending += 1

                if pending == self.accumulation_steps:
                    self._optimizer_step()
                    pending = 0

                # Undo the normalization so the reported value is the
                # per-batch loss, then show the running mean.
                total_loss += loss.item() * self.accumulation_steps
                progress_bar.set_postfix({'loss': total_loss / (step + 1)})

            # Flush a trailing, incomplete accumulation group. Its loss was
            # still divided by accumulation_steps, so this last update is
            # proportionally smaller than a full one.
            if pending:
                self._optimizer_step()

            # Save Checkpoint
            self.save_model(epoch + 1)

    def save_model(self, epoch):
        """Write the model to ``<output_dir>/checkpoint-<epoch>``.

        Args:
            epoch: Value used in the checkpoint directory name.
        """
        save_path = os.path.join(self.output_dir, f"checkpoint-{epoch}")
        os.makedirs(save_path, exist_ok=True)

        logger.info("Saving model to %s...", save_path)

        # Save explicitly as safetensors via the save_pretrained API.
        # Only the 'encoder' sub-module and the config are written
        # (presumably a transformers AutoModel, so others can load it
        # directly -- confirm against CodeEmbedder).
        self.model.encoder.save_pretrained(save_path, safe_serialization=True)
        self.model.config.save_pretrained(save_path)