"""CodeT5 vulnerability detection model.

Binary classification: Safe (0) vs. Vulnerable (1).
"""
import torch
import torch.nn as nn
from transformers import T5ForConditionalGeneration, RobertaTokenizer
class VulnerabilityCodeT5(nn.Module):
    """CodeT5-based binary classifier for vulnerability detection.

    Wraps a pretrained CodeT5 encoder-decoder but uses only the encoder:
    the first token's hidden state is fed through a two-layer MLP head to
    produce class logits (0 = safe, 1 = vulnerable).
    """

    def __init__(self, model_name="Salesforce/codet5-base", num_labels=2):
        """Load the pretrained backbone and build the classification head.

        Args:
            model_name: HuggingFace checkpoint to load.
            num_labels: number of output classes.
        """
        super().__init__()
        self.encoder_decoder = T5ForConditionalGeneration.from_pretrained(model_name)
        # d_model is the encoder hidden width (768 for codet5-base).
        width = self.encoder_decoder.config.d_model
        # Two-layer MLP head with dropout regularization.
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(width, num_labels),
        )
        self.num_labels = num_labels

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and classification head.

        Args:
            input_ids: tokenized code, shape [batch_size, seq_len].
            attention_mask: attention mask, shape [batch_size, seq_len].
            labels: optional ground-truth labels, shape [batch_size].

        Returns:
            dict with 'loss' (None when labels is None), 'logits'
            [batch, num_labels], and 'hidden_states'
            [batch, seq_len, hidden].
        """
        enc = self.encoder_decoder.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        sequence_output = enc.last_hidden_state  # [batch, seq_len, hidden]
        # Pool by taking the first token's representation.
        # NOTE(review): T5-style encoders have no dedicated [CLS] token;
        # this mirrors the original pooling choice — confirm it is intended.
        first_token = sequence_output[:, 0, :]  # [batch, hidden]
        logits = self.classifier(first_token)  # [batch, num_labels]

        # Cross-entropy loss only when supervision is provided.
        loss = nn.CrossEntropyLoss()(logits, labels) if labels is not None else None

        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': sequence_output,
        }

    def predict(self, input_ids, attention_mask):
        """Return (predicted class indices, class probabilities).

        Side effect: switches the module into eval mode.
        """
        self.eval()
        with torch.no_grad():
            out = self.forward(input_ids, attention_mask)
            probabilities = torch.softmax(out["logits"], dim=1)
            predicted = torch.argmax(probabilities, dim=1)
        return predicted, probabilities
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total