"""Fine-tune a CodeT5-based vulnerability classifier (v2).

Trains with mixed precision (on CUDA) and gradient accumulation, tracks
train/validation loss and accuracy, and checkpoints the best model by
validation accuracy.
"""

import argparse
import gc
import json
import sys
from pathlib import Path

import torch
from torch.amp import GradScaler, autocast
from torch.optim import AdamW
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

# Path shim so the `src.*` imports below resolve when the script is run directly.
sys.path.append(str(Path(__file__).parent.parent))

from src.v2.data_processor import load_tokenizer, create_dataloader
from src.v2.model import VulnerabilityCodeT5, count_parameters


class Trainer:
    def __init__(
        self,
        model,
        train_loader,
        valid_loader,
        device,
        learning_rate=2e-5,
        num_epochs=5,
        gradient_accumulation_steps=4,
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps

        # Mixed precision is only enabled on CUDA devices.
        self.use_amp = device.type == "cuda"
        self.scaler = GradScaler(enabled=self.use_amp)

        self.optimizer = AdamW(
            self.model.parameters(), lr=learning_rate, weight_decay=0.01
        )

        # One optimizer step per accumulation window; warm up for ~10% of steps.
        total_steps = (
            len(self.train_loader) * num_epochs
        ) // gradient_accumulation_steps
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=max(1, total_steps // 10),
            num_training_steps=total_steps,
        )

        self.best_val_acc = 0.0
        self.history = {
            "train_loss": [],
            "train_acc": [],
            "val_loss": [],
            "val_acc": [],
        }

    def clear_memory(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def train_epoch(self):
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        self.optimizer.zero_grad(set_to_none=True)

        pbar = tqdm(self.train_loader, desc="Training")
        for step, batch in enumerate(pbar):
            input_ids = batch["input_ids"].to(self.device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(
                self.device, non_blocking=True
            )
            labels = batch["labels"].to(self.device, non_blocking=True)

            # Use the actual device type so CPU runs don't request CUDA autocast.
            with autocast(device_type=self.device.type, enabled=self.use_amp):
                outputs = self.model(input_ids, attention_mask, labels)
                # Scale down so gradients average over the accumulation window.
                loss = outputs["loss"] / self.gradient_accumulation_steps

            self.scaler.scale(loss).backward()

            if (step + 1) % self.gradient_accumulation_steps == 0:
                # Unscale before clipping so the norm is computed on true gradients.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                self.optimizer.zero_grad(set_to_none=True)

            with torch.no_grad():
                preds = torch.argmax(outputs["logits"], dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

            total_loss += loss.item() * self.gradient_accumulation_steps

            gpu_mem = (
                torch.cuda.memory_allocated() / 1024 ** 3
                if torch.cuda.is_available()
                else 0
            )
            pbar.set_postfix(
                {
                    "loss": f"{loss.item() * self.gradient_accumulation_steps:.4f}",
                    "acc": f"{100 * correct / max(1, total):.2f}%",
                    "gpu": f"{gpu_mem:.2f}GB",
                }
            )

            # Drop references promptly to keep peak GPU memory low.
            del input_ids, attention_mask, labels, outputs, loss

        # Flush a trailing partial accumulation window so its gradients are
        # not silently dropped when the batch count doesn't divide evenly.
        if len(self.train_loader) % self.gradient_accumulation_steps != 0:
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad(set_to_none=True)

        self.clear_memory()
        return total_loss / len(self.train_loader), 100 * correct / total

    def validate(self):
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in tqdm(self.valid_loader, desc="Validating"):
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                with autocast(device_type=self.device.type, enabled=self.use_amp):
                    outputs = self.model(input_ids, attention_mask, labels)

                loss = outputs["loss"]
                preds = torch.argmax(outputs["logits"], dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
                total_loss += loss.item()

        self.clear_memory()
        return total_loss / len(self.valid_loader), 100 * correct / total

    def train(self, save_dir="models/v2"):
        print(f"Training samples: {len(self.train_loader.dataset)}")
        print(f"Validation samples: {len(self.valid_loader.dataset)}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name(0)}")

        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        for epoch in range(self.num_epochs):
            print(f"\n{'=' * 60}")
            print(f"Epoch {epoch + 1}/{self.num_epochs}")
            print(f"{'=' * 60}")

            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()

            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

            self.history["train_loss"].append(train_loss)
            self.history["train_acc"].append(train_acc)
            self.history["val_loss"].append(val_loss)
            self.history["val_acc"].append(val_acc)

            # Checkpoint whenever validation accuracy improves.
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                torch.save(
                    {
                        "model_state_dict": self.model.state_dict(),
                        "optimizer_state_dict": self.optimizer.state_dict(),
                        "val_acc": val_acc,
                    },
                    save_dir / "best_model.pt",
                )
                print("Saved best model")

        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "history": self.history,
            },
            save_dir / "final_model.pt",
        )
        with open(save_dir / "training_history.json", "w") as f:
            json.dump(self.history, f, indent=2)

        print(f"\nTraining complete. Best Val Acc: {self.best_val_acc:.2f}%")


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data_dir = (
        Path("data/processed/sample") if args.use_sample else Path("data/processed")
    )
    train_path = data_dir / "train.jsonl"
    valid_path = data_dir / "valid.jsonl"
    test_path = data_dir / "test.jsonl"

    tokenizer = load_tokenizer(args.model_name)
    # The test split is loaded for completeness but not used during training.
    train_loader, valid_loader, test_loader = create_dataloader(
        train_path,
        valid_path,
        test_path,
        tokenizer,
        batch_size=args.batch_size,
        max_length=args.max_length,
        num_workers=2,
    )

    model = VulnerabilityCodeT5(model_name=args.model_name, num_labels=2)
    print(f"Trainable parameters: {count_parameters(model):,}")

    trainer = Trainer(
        model,
        train_loader,
        valid_loader,
        device,
        learning_rate=args.learning_rate,
        num_epochs=args.epochs,
        gradient_accumulation_steps=args.gradient_accumulation,
    )
    trainer.train(args.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="Salesforce/codet5-base")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--gradient_accumulation", type=int, default=4)
    parser.add_argument("--output_dir", default="models/v2")
    parser.add_argument("--use_sample", action="store_true")
    main(parser.parse_args())
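
# Example invocations from the project root (the script path is illustrative
# and depends on where this file lives; all flags are defined in the argument
# parser above, and data/processed/sample serves as a quick smoke test):
#   python src/v2/train.py --use_sample --epochs 1
#   python src/v2/train.py --batch_size 8 --gradient_accumulation 2 --epochs 5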