# Author: Pranav Pc
# Commit: Align model path with actual file (2e08b23)
"""Inference module for vulnerability detection
Load trained models and make predictions"""
import torch
from transformers import RobertaTokenizer
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent.parent))
from src.model import VulnerabilityCodeT5
class VulnerabilityDetector:
    """Run inference with a trained CodeT5 vulnerability-detection model.

    Loads a fine-tuned binary classifier (0 = safe, 1 = vulnerable) and
    exposes single-snippet and batch prediction helpers. Inference is pinned
    to CPU — an intentional deployment choice.
    """

    def __init__(self, model_path="models/best_model_clean.pt",
                 model_name="Salesforce/codet5-base", max_length=256):
        """Load the tokenizer and trained model weights.

        Args:
            model_path: Path to the saved state_dict checkpoint (.pt).
            model_name: HuggingFace model id for the tokenizer and backbone.
            max_length: Token budget used for padding/truncation in predict().
        """
        ### CHANGED FOR DEPLOYMENT
        self.device = torch.device('cpu')
        self.max_length = max_length
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.model = VulnerabilityCodeT5(model_name=model_name, num_labels=2)
        # NOTE(review): torch.load unpickles arbitrary objects by default —
        # only load checkpoints from trusted sources (consider weights_only=True
        # if the checkpoint format supports it).
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        print("Model Loaded Successfully")
        # Human-readable names for the two class indices.
        self.labels = {
            0: "Safe Code",
            1: "Vulnerable Code"
        }

    def predict(self, code_snippet):
        """Predict Vulnerability of Code Snippet

        Args:
            code_snippet: String containing source code.

        Returns:
            dict with prediction (int class id), label (str), confidence
            (float, probability of the predicted class), and the full
            per-class probabilities.
        """
        inputs = self.tokenizer(
            code_snippet,
            # BUG FIX: was hard-coded to 256, silently ignoring the
            # max_length passed to the constructor.
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)
        # No gradients needed for inference.
        with torch.no_grad():
            predictions, probs = self.model.predict(input_ids, attention_mask)
        pred_label = predictions[0].item()
        confidence = probs[0][pred_label].item()
        return {
            'prediction': pred_label,
            'label': self.labels[pred_label],
            'confidence': confidence,
            'probabilities': {
                'safe': probs[0][0].item(),
                'vulnerable': probs[0][1].item()
            }
        }

    def analyze_batch(self, code_snippets):
        """Analyze multiple code snippets at once; returns a list of dicts."""
        return [self.predict(code) for code in code_snippets]
def test_inference():
    """Smoke-test the detector on a mix of safe and vulnerable C snippets."""
    detector = VulnerabilityDetector()
    # (name, C source) pairs — three expected-safe, three expected-vulnerable.
    samples = [
        ("Safe Bounded Copy", """void copy_input(const char *input) {
char buffer[32];
strncpy(buffer, input, sizeof(buffer) - 1);
buffer[sizeof(buffer) - 1] = '\\0';
}"""),
        ("Safe fgets Input", """void read_input() {
char buffer[64];
if (fgets(buffer, sizeof(buffer), stdin) != NULL) {
printf("%s", buffer);
}
}"""),
        ("Safe malloc usage", """void allocate() {
char *buf = (char *)malloc(128);
if (buf == NULL) {
return;
}
strcpy(buf, "safe");
free(buf);
}"""),
        ("Stack Buffer Overflow", """void copy_input(char *input) {
char buffer[8];
strcpy(buffer, input);
}"""),
        ("Integer Overflow", """void allocate(int size) {
char *buf = (char *)malloc(size * sizeof(char));
if (buf == NULL) return;
memset(buf, 'A', size + 10);
}"""),
        ("Use After Free", """void uaf() {
char *buf = (char *)malloc(16);
free(buf);
strcpy(buf, "UAF");
}"""),
    ]
    banner = "=" * 60
    print("\n" + banner)
    print("Testing Vulnerability Detection")
    print(banner)
    for name, code in samples:
        print(f"\nTest: {name}")
        # Show only the first 60 characters of the snippet for brevity.
        print(f"Code: {code[:60]}...")
        result = detector.predict(code)
        print(f"Prediction: {result['label']}")
        print(f"Confidence: {result['confidence']:.2%}")
        print(f" - Safe: {result['probabilities']['safe']:.2%}")
        print(f" - Vulnerable: {result['probabilities']['vulnerable']:.2%}")
# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    test_inference()