| | from transformers import RobertaTokenizer |
| | from torch.utils.data import Dataset, DataLoader |
| | import torch |
| | import json |
| | from pathlib import Path |
| |
|
| |
|
class VulnerabilityDataset(Dataset):
    """Dataset of (code, label) samples for vulnerability detection.

    Samples are read from a JSONL file: one JSON object per non-blank
    line, with the source code under ``"func"`` and the integer class
    label under ``"target"``.
    """

    def __init__(self, data_path, tokenizer, max_length=512):
        """Load every sample from *data_path* into memory.

        Raises:
            FileNotFoundError: if *data_path* does not exist.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        path = Path(data_path)
        if not path.exists():
            raise FileNotFoundError(f"Dataset file not found: {path}")

        # One JSON object per non-empty line (JSONL format).
        with open(path, "r", encoding="utf-8") as handle:
            self.data = [json.loads(row) for row in map(str.strip, handle) if row]

        print(f"{path.name}: {len(self.data)} samples")

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample *idx* and return model-ready tensors.

        Returns a dict with ``input_ids`` and ``attention_mask`` of
        shape ``(max_length,)`` plus a scalar long ``labels`` tensor.
        """
        record = self.data[idx]

        tokens = self.tokenizer(
            record["func"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer returns a leading batch axis of size 1; drop it
        # so the DataLoader's collation adds the real batch dimension.
        return {
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "labels": torch.tensor(record["target"], dtype=torch.long),
        }
| |
|
| |
|
def load_tokenizer(model_name="Salesforce/codet5-base"):
    """Download (or load from cache) the pretrained tokenizer *model_name*.

    Returns the ``RobertaTokenizer`` instance for the given hub name.
    """
    print(f"Tokenizer: {model_name}")
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    return tokenizer
| |
|
| |
|
def create_dataloader(
    train_path,
    valid_path,
    test_path,
    tokenizer,
    batch_size=8,
    max_length=512,
    num_workers=2,
):
    """Build train/valid/test DataLoaders over JSONL vulnerability datasets.

    Args:
        train_path, valid_path, test_path: JSONL dataset files, one sample
            per line (see VulnerabilityDataset).
        tokenizer: tokenizer callable forwarded to VulnerabilityDataset.
        batch_size: samples per batch, shared by all three loaders.
        max_length: tokenizer truncation/padding length.
        num_workers: DataLoader worker processes (0 = load in main process).

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)``; only the
        train loader shuffles.

    Raises:
        FileNotFoundError: if any dataset file is missing.
        RuntimeError: if the training set contains no samples.
    """
    train_dataset = VulnerabilityDataset(train_path, tokenizer, max_length)
    valid_dataset = VulnerabilityDataset(valid_path, tokenizer, max_length)
    test_dataset = VulnerabilityDataset(test_path, tokenizer, max_length)

    if len(train_dataset) == 0:
        raise RuntimeError(f"No samples found in {train_path}")

    # BUG FIX: persistent_workers=True is only legal when num_workers > 0;
    # hard-coding True made DataLoader raise ValueError for num_workers=0
    # (common for debugging and on platforms without fork).
    common = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **common)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **common)
    test_loader = DataLoader(test_dataset, shuffle=False, **common)

    return train_loader, valid_loader, test_loader
| |
|