import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.amp import autocast, GradScaler
from transformers import get_linear_schedule_with_warmup
from pathlib import Path
from tqdm import tqdm
import argparse
import json
import gc
import sys

sys.path.append(str(Path(__file__).parent.parent))

from src.v2.data_processor import load_tokenizer, create_dataloader
from src.v2.model import VulnerabilityCodeT5, count_parameters

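# Fine-tunes the VulnerabilityCodeT5 binary classifier with automatic mixed
# precision and gradient accumulation.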
class Trainer:
    def __init__(
        self,
        model,
        train_loader,
        valid_loader,
        device,
        learning_rate=2e-5,
        num_epochs=5,
        gradient_accumulation_steps=4,
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
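        # Mixed precision is only enabled on CUDA; GradScaler becomes a no-op
        # when disabled, so the same code path also runs on CPU.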
        self.use_amp = device.type == "cuda"
        self.scaler = GradScaler(enabled=self.use_amp)
        self.optimizer = AdamW(
            self.model.parameters(), lr=learning_rate, weight_decay=0.01
        )
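        # The scheduler advances once per optimizer step (i.e. once per
        # accumulation window), with the first ~10% of steps used for warmup.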
        total_steps = (
            len(self.train_loader) * num_epochs
        ) // gradient_accumulation_steps
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=max(1, total_steps // 10),
            num_training_steps=total_steps,
        )
        self.best_val_acc = 0.0
        self.history = {
            "train_loss": [],
            "train_acc": [],
            "val_loss": [],
            "val_acc": [],
        }
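
    # Release cached CUDA blocks and run garbage collection between phases to
    # keep peak memory low.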
    def clear_memory(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
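
    # One pass over the training set: forward under autocast, scale the loss
    # for accumulation, and take an optimizer step every
    # `gradient_accumulation_steps` micro-batches.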
    def train_epoch(self):
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        self.optimizer.zero_grad(set_to_none=True)
        pbar = tqdm(self.train_loader, desc="Training")
        for step, batch in enumerate(pbar):
            input_ids = batch["input_ids"].to(self.device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
            labels = batch["labels"].to(self.device, non_blocking=True)
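            # Forward pass in mixed precision; dividing the loss makes the
            # accumulated gradients equal the average over the window.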
            with autocast(device_type="cuda", enabled=self.use_amp):
                outputs = self.model(input_ids, attention_mask, labels)
                loss = outputs["loss"] / self.gradient_accumulation_steps
            self.scaler.scale(loss).backward()
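            # Step only once per accumulation window: unscale, clip, step the
            # optimizer and scaler, then advance the LR schedule.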
            if (step + 1) % self.gradient_accumulation_steps == 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                self.optimizer.zero_grad(set_to_none=True)
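            # Track running accuracy without building an autograd graph.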
            with torch.no_grad():
                preds = torch.argmax(outputs["logits"], dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
            total_loss += loss.item() * self.gradient_accumulation_steps
            gpu_mem = (
                torch.cuda.memory_allocated() / 1024 ** 3
                if torch.cuda.is_available()
                else 0
            )
            pbar.set_postfix(
                {
                    "loss": f"{loss.item() * self.gradient_accumulation_steps:.4f}",
                    "acc": f"{100 * correct / max(1, total):.2f}%",
                    "gpu": f"{gpu_mem:.2f}GB",
                }
            )
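            # Drop references to per-batch tensors so they can be freed promptly.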
            del input_ids, attention_mask, labels, outputs, loss
        self.clear_memory()
        return total_loss / len(self.train_loader), 100 * correct / total
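
    # Evaluation pass: no gradients, no accumulation; returns mean loss and
    # accuracy over the validation set.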
    def validate(self):
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in tqdm(self.valid_loader, desc="Validating"):
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)
                with autocast(device_type="cuda", enabled=self.use_amp):
                    outputs = self.model(input_ids, attention_mask, labels)
                loss = outputs["loss"]
                preds = torch.argmax(outputs["logits"], dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
                total_loss += loss.item()
        self.clear_memory()
        return total_loss / len(self.valid_loader), 100 * correct / total
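
    # Full training loop: one train/validate cycle per epoch, checkpointing the
    # best model by validation accuracy and writing the history to disk.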
    def train(self, save_dir="models/v2"):
        print(f"Training samples: {len(self.train_loader.dataset)}")
        print(f"Validation samples: {len(self.valid_loader.dataset)}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name(0)}")
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        for epoch in range(self.num_epochs):
            print(f"\n{'=' * 60}")
            print(f"Epoch {epoch + 1}/{self.num_epochs}")
            print(f"{'=' * 60}")
            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()
            print(
                f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%"
            )
            print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
            self.history["train_loss"].append(train_loss)
            self.history["train_acc"].append(train_acc)
            self.history["val_loss"].append(val_loss)
            self.history["val_acc"].append(val_acc)
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                torch.save(
                    {
                        "model_state_dict": self.model.state_dict(),
                        "optimizer_state_dict": self.optimizer.state_dict(),
                        "val_acc": val_acc,
                    },
                    save_dir / "best_model.pt",
                )
                print("Saved best model")
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "history": self.history,
            },
            save_dir / "final_model.pt",
        )
        with open(save_dir / "training_history.json", "w") as f:
            json.dump(self.history, f, indent=2)
        print(f"\nTraining complete. Best Val Acc: {self.best_val_acc:.2f}%")
def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data_dir = (
        Path("data/processed/sample") if args.use_sample else Path("data/processed")
    )
    train_path = data_dir / "train.jsonl"
    valid_path = data_dir / "valid.jsonl"
    test_path = data_dir / "test.jsonl"
    tokenizer = load_tokenizer(args.model_name)
    train_loader, valid_loader, test_loader = create_dataloader(
        train_path,
        valid_path,
        test_path,
        tokenizer,
        batch_size=args.batch_size,
        max_length=args.max_length,
        num_workers=2,
    )
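    # The test loader is built alongside the others but is not used in this
    # training script.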
    model = VulnerabilityCodeT5(model_name=args.model_name, num_labels=2)
    print(f"Trainable parameters: {count_parameters(model):,}")
    trainer = Trainer(
        model,
        train_loader,
        valid_loader,
        device,
        learning_rate=args.learning_rate,
        num_epochs=args.epochs,
        gradient_accumulation_steps=args.gradient_accumulation,
    )
    trainer.train(args.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="Salesforce/codet5-base")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--gradient_accumulation", type=int, default=4)
    parser.add_argument("--output_dir", default="models/v2")
    parser.add_argument("--use_sample", action="store_true")
    main(parser.parse_args())