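"""Fine-tune CodeT5 for binary vulnerability classification.

Trains VulnerabilityCodeT5 with mixed precision (on CUDA), gradient
accumulation, gradient clipping, and a linear warmup/decay schedule,
checkpointing the best model by validation accuracy.
"""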
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.amp import autocast, GradScaler
from transformers import get_linear_schedule_with_warmup
from pathlib import Path
from tqdm import tqdm
import argparse
import json
import gc
import sys
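# Make the project root importable so the src.v2 modules resolve.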
sys.path.append(str(Path(__file__).parent.parent))
from src.v2.data_processor import load_tokenizer, create_dataloader
from src.v2.model import VulnerabilityCodeT5, count_parameters

class Trainer:
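    """Training and validation loop with AMP, gradient accumulation, and LR scheduling."""
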
def __init__(
self,
model,
train_loader,
valid_loader,
device,
learning_rate=2e-5,
num_epochs=5,
gradient_accumulation_steps=4,
):
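        """Move the model to `device` and set up optimizer, scaler, and scheduler."""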
self.model = model.to(device)
self.train_loader = train_loader
self.valid_loader = valid_loader
self.device = device
self.num_epochs = num_epochs
self.gradient_accumulation_steps = gradient_accumulation_steps
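        # Mixed precision only pays off on CUDA; GradScaler is a no-op when disabled.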
self.use_amp = device.type == "cuda"
self.scaler = GradScaler(enabled=self.use_amp)
self.optimizer = AdamW(
self.model.parameters(), lr=learning_rate, weight_decay=0.01
)
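        # One optimizer update per accumulation window; warm up over the first ~10% of updates.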
total_steps = (
len(self.train_loader) * num_epochs
) // gradient_accumulation_steps
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps=max(1, total_steps // 10),
num_training_steps=total_steps,
)
self.best_val_acc = 0.0
self.history = {
"train_loss": [],
"train_acc": [],
"val_loss": [],
"val_acc": [],
}

    def clear_memory(self):
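        """Release cached CUDA memory and run Python garbage collection."""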
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()

    def train_epoch(self):
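        """Run one training epoch and return (mean loss, accuracy %)."""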
self.model.train()
total_loss = 0.0
correct = 0
total = 0
self.optimizer.zero_grad(set_to_none=True)
pbar = tqdm(self.train_loader, desc="Training")
for step, batch in enumerate(pbar):
input_ids = batch["input_ids"].to(self.device, non_blocking=True)
attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
labels = batch["labels"].to(self.device, non_blocking=True)
            with autocast(device_type=self.device.type, enabled=self.use_amp):
outputs = self.model(input_ids, attention_mask, labels)
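                # Scale the loss so gradients average over the accumulation window.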
loss = outputs["loss"] / self.gradient_accumulation_steps
self.scaler.scale(loss).backward()
            # Step on every full accumulation window and on the epoch's final batch,
            # so gradients from a trailing partial window are not discarded.
            is_update_step = (step + 1) % self.gradient_accumulation_steps == 0
            if is_update_step or (step + 1) == len(self.train_loader):
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
self.scheduler.step()
self.optimizer.zero_grad(set_to_none=True)
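            # Track running accuracy and the unscaled loss for reporting.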
with torch.no_grad():
preds = torch.argmax(outputs["logits"], dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
total_loss += loss.item() * self.gradient_accumulation_steps
gpu_mem = (
torch.cuda.memory_allocated() / 1024 ** 3
if torch.cuda.is_available()
else 0
)
pbar.set_postfix(
{
"loss": f"{loss.item() * self.gradient_accumulation_steps:.4f}",
"acc": f"{100 * correct / max(1, total):.2f}%",
"gpu": f"{gpu_mem:.2f}GB",
}
)
del input_ids, attention_mask, labels, outputs, loss
self.clear_memory()
return total_loss / len(self.train_loader), 100 * correct / total

    def validate(self):
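        """Evaluate on the validation split and return (mean loss, accuracy %)."""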
self.model.eval()
total_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for batch in tqdm(self.valid_loader, desc="Validating"):
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
labels = batch["labels"].to(self.device)
                with autocast(device_type=self.device.type, enabled=self.use_amp):
outputs = self.model(input_ids, attention_mask, labels)
loss = outputs["loss"]
preds = torch.argmax(outputs["logits"], dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
total_loss += loss.item()
self.clear_memory()
return total_loss / len(self.valid_loader), 100 * correct / total

    def train(self, save_dir="models/v2"):
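        """Run the full training loop, saving the best and final checkpoints."""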
print(f"Training samples: {len(self.train_loader.dataset)}")
print(f"Validation samples: {len(self.valid_loader.dataset)}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
for epoch in range(self.num_epochs):
print(f"\n{'=' * 60}")
print(f"Epoch {epoch + 1}/{self.num_epochs}")
print(f"{'=' * 60}")
train_loss, train_acc = self.train_epoch()
val_loss, val_acc = self.validate()
print(
f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%"
)
print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
self.history["train_loss"].append(train_loss)
self.history["train_acc"].append(train_acc)
self.history["val_loss"].append(val_loss)
self.history["val_acc"].append(val_acc)
if val_acc > self.best_val_acc:
self.best_val_acc = val_acc
torch.save(
{
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"val_acc": val_acc,
},
save_dir / "best_model.pt",
)
print("Saved best model")
torch.save(
{
"model_state_dict": self.model.state_dict(),
"history": self.history,
},
save_dir / "final_model.pt",
)
with open(save_dir / "training_history.json", "w") as f:
json.dump(self.history, f, indent=2)
print(f"\nTraining complete. Best Val Acc: {self.best_val_acc:.2f}%")

def main(args):
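    """Build dataloaders and the model from CLI arguments, then train."""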
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_dir = (
Path("data/processed/sample") if args.use_sample else Path("data/processed")
)
train_path = data_dir / "train.jsonl"
valid_path = data_dir / "valid.jsonl"
test_path = data_dir / "test.jsonl"
tokenizer = load_tokenizer(args.model_name)
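    # The test split is loaded alongside train/valid but is not used during training.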
train_loader, valid_loader, test_loader = create_dataloader(
train_path,
valid_path,
test_path,
tokenizer,
batch_size=args.batch_size,
max_length=args.max_length,
num_workers=2,
)
model = VulnerabilityCodeT5(model_name=args.model_name, num_labels=2)
print(f"Trainable parameters: {count_parameters(model):,}")
trainer = Trainer(
model,
train_loader,
valid_loader,
device,
learning_rate=args.learning_rate,
num_epochs=args.epochs,
gradient_accumulation_steps=args.gradient_accumulation,
)
trainer.train(args.output_dir)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
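    # Defaults give an effective batch size of 16 (batch_size 4 x gradient_accumulation 4).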
parser.add_argument("--model_name", default="Salesforce/codet5-base")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_length", type=int, default=256)
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--gradient_accumulation", type=int, default=4)
parser.add_argument("--output_dir", default="models/v2")
parser.add_argument("--use_sample", action="store_true")
main(parser.parse_args())