"""Inference module for vulnerability detection.

Load trained models and make predictions.
"""
import torch
from transformers import RobertaTokenizer
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent.parent))
from src.model import VulnerabilityCodeT5
class VulnerabilityDetector:
    """Load a trained ``VulnerabilityCodeT5`` checkpoint and classify code.

    The constructor sets ``device``, ``max_length``, ``tokenizer``,
    ``model`` and ``labels`` (class index -> human-readable name).
    """

    def __init__(self, model_path="models/best_model_clean.pt",
                 model_name="Salesforce/codet5-base", max_length=256):
        """Build the tokenizer and model, then load the saved weights.

        Args:
            model_path: Path of the saved state dict handed to ``torch.load``.
            model_name: Pretrained identifier used for both the tokenizer
                and the ``VulnerabilityCodeT5`` constructor.
            max_length: Token length each snippet is padded/truncated to
                when ``predict`` tokenizes it.
        """
        # Deployment choice: inference is pinned to the CPU.
        self.device = torch.device('cpu')
        self.max_length = max_length
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.model = VulnerabilityCodeT5(model_name=model_name, num_labels=2)
        # NOTE(review): torch.load unpickles the file; only point model_path
        # at checkpoints from a trusted source.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # eval() switches the module to inference mode before any prediction.
        self.model.eval()
        print("Model Loaded Successfully")
        self.labels = {
            0: "Safe Code",
            1: "Vulnerable Code"
        }

    def predict(self, code_snippet):
        """Predict the vulnerability class of one code snippet.

        Args:
            code_snippet: String containing source code.

        Returns:
            dict with keys ``'prediction'`` (class index), ``'label'``
            (entry of ``self.labels``), ``'confidence'`` (probability of the
            predicted class) and ``'probabilities'`` (a dict with ``'safe'``
            and ``'vulnerable'`` values).
        """
        inputs = self.tokenizer(
            code_snippet,
            # Honour the length configured in __init__ instead of a fixed 256.
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)
        # Gradients are not needed for inference.
        with torch.no_grad():
            predictions, probs = self.model.predict(input_ids, attention_mask)
        # A single snippet is submitted, so only row 0 of each output is read.
        pred_label = predictions[0].item()
        confidence = probs[0][pred_label].item()
        return {
            'prediction': pred_label,
            'label': self.labels[pred_label],
            'confidence': confidence,
            'probabilities': {
                'safe': probs[0][0].item(),
                'vulnerable': probs[0][1].item()
            }
        }

    def analyze_batch(self, code_snippets):
        """Run ``predict`` on each snippet in turn and return the results as a list."""
        return [self.predict(code) for code in code_snippets]
def test_inference():
    """Smoke-test the detector on six hand-written C snippets.

    Builds a ``VulnerabilityDetector`` with its default arguments, runs
    ``predict`` on every sample and prints the label, the confidence and
    the per-class probabilities.
    """
    detector = VulnerabilityDetector()

    # (display name, C source) pairs: three safe samples, then three
    # vulnerable ones.
    # NOTE(review): the snippet lines carry no leading indentation in the
    # reviewed copy of this file — confirm against the repository version.
    samples = [
        (
            "Safe Bounded Copy",
            "void copy_input(const char *input) {\n"
            "char buffer[32];\n"
            "strncpy(buffer, input, sizeof(buffer) - 1);\n"
            "buffer[sizeof(buffer) - 1] = '\\0';\n"
            "}",
        ),
        (
            "Safe fgets Input",
            'void read_input() {\n'
            'char buffer[64];\n'
            'if (fgets(buffer, sizeof(buffer), stdin) != NULL) {\n'
            'printf("%s", buffer);\n'
            '}\n'
            '}',
        ),
        (
            "Safe malloc usage",
            'void allocate() {\n'
            'char *buf = (char *)malloc(128);\n'
            'if (buf == NULL) {\n'
            'return;\n'
            '}\n'
            'strcpy(buf, "safe");\n'
            'free(buf);\n'
            '}',
        ),
        (
            "Stack Buffer Overflow",
            "void copy_input(char *input) {\n"
            "char buffer[8];\n"
            "strcpy(buffer, input);\n"
            "}",
        ),
        (
            "Integer Overflow",
            "void allocate(int size) {\n"
            "char *buf = (char *)malloc(size * sizeof(char));\n"
            "if (buf == NULL) return;\n"
            "memset(buf, 'A', size + 10);\n"
            "}",
        ),
        (
            "Use After Free",
            'void uaf() {\n'
            'char *buf = (char *)malloc(16);\n'
            'free(buf);\n'
            'strcpy(buf, "UAF");\n'
            '}',
        ),
    ]

    rule = "=" * 60
    print("\n" + rule)
    print("Testing Vulnerability Detection")
    print(rule)

    for title, source in samples:
        print(f"\nTest: {title}")
        # Only the first 60 characters of the snippet are echoed.
        print(f"Code: {source[:60]}...")
        outcome = detector.predict(source)
        class_probs = outcome['probabilities']
        print(f"Prediction: {outcome['label']}")
        print(f"Confidence: {outcome['confidence']:.2%}")
        print(f" - Safe: {class_probs['safe']:.2%}")
        print(f" - Vulnerable: {class_probs['vulnerable']:.2%}")
# Run the smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_inference()