"""Inference module for vulnerability detection.

Load trained models and make predictions.
"""

import sys
from pathlib import Path

import torch
from transformers import RobertaTokenizer

# Make the repository root importable so that `src.model` resolves when this
# file is executed directly as a script.
sys.path.append(str(Path(__file__).parent.parent.parent))

from src.model import VulnerabilityCodeT5  # noqa: E402  (needs sys.path tweak above)


class VulnerabilityDetector:
    """Wrap a trained ``VulnerabilityCodeT5`` checkpoint for inference.

    The constructor loads the tokenizer and the model weights, moves the model
    to CPU and switches it to evaluation mode. ``predict`` classifies a single
    snippet; ``analyze_batch`` applies ``predict`` to each snippet in turn.
    """

    def __init__(
        self,
        model_path: str = "models/best_model_clean.pt",
        model_name: str = "Salesforce/codet5-base",
        max_length: int = 256,
    ) -> None:
        """Load the tokenizer and model weights.

        Args:
            model_path: Path to the saved state dict passed to ``torch.load``.
            model_name: Pretrained name used for both the tokenizer and the
                ``VulnerabilityCodeT5`` backbone.
            max_length: Token length every snippet is padded/truncated to
                when tokenized in ``predict``.
        """
        # CHANGED FOR DEPLOYMENT: inference always runs on CPU.
        self.device = torch.device('cpu')
        self.max_length = max_length

        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.model = VulnerabilityCodeT5(model_name=model_name, num_labels=2)

        # SECURITY: torch.load relies on pickle, which can execute arbitrary
        # code when given a malicious file. Only load checkpoints from a
        # trusted source (consider `weights_only=True` if the installed torch
        # version supports it and the file is a plain state dict).
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        print("Model Loaded Successfully")

        # Human-readable names for the two class indices.
        self.labels = {
            0: "Safe Code",
            1: "Vulnerable Code",
        }

    def predict(self, code_snippet: str) -> dict:
        """Predict the vulnerability of a code snippet.

        Args:
            code_snippet: String containing source code.

        Returns:
            dict with keys ``prediction`` (class index), ``label`` (entry of
            ``self.labels`` for that index), ``confidence`` (probability of
            the predicted class) and ``probabilities`` (``safe`` /
            ``vulnerable`` probabilities taken from columns 0 and 1).
        """
        inputs = self.tokenizer(
            code_snippet,
            # Honor the length configured in __init__ instead of a hard-coded
            # 256, so a non-default `max_length` actually takes effect.
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)

        # No gradients are needed for inference.
        with torch.no_grad():
            predictions, probs = self.model.predict(input_ids, attention_mask)

        # Only one snippet was tokenized, so read the first row of each result.
        pred_label = predictions[0].item()
        confidence = probs[0][pred_label].item()

        return {
            'prediction': pred_label,
            'label': self.labels[pred_label],
            'confidence': confidence,
            'probabilities': {
                'safe': probs[0][0].item(),
                'vulnerable': probs[0][1].item(),
            },
        }

    def analyze_batch(self, code_snippets) -> list:
        """Analyze multiple code snippets at once.

        Each snippet is passed to ``predict`` individually; the result dicts
        are returned in the same order as the input.
        """
        return [self.predict(code) for code in code_snippets]


def test_inference():
    """Smoke-test the detector on a few C snippets and print the results.

    Builds a ``VulnerabilityDetector`` with its default arguments, so the
    default checkpoint path must exist for this to run.
    """
    detector = VulnerabilityDetector()

    test_cases = [
        {
            "name": "Safe Bounded Copy",
            "code": (
                "void copy_input(const char *input) { char buffer[32]; "
                "strncpy(buffer, input, sizeof(buffer) - 1); \n"
                "buffer[sizeof(buffer) - 1] = '\\0'; }"
            ),
        },
        {
            "name": "Safe fgets Input",
            "code": (
                'void read_input() { char buffer[64]; '
                'if (fgets(buffer, sizeof(buffer), stdin) != NULL) '
                '{ printf("%s", buffer); } }'
            ),
        },
        {
            "name": "Safe malloc usage",
            "code": (
                'void allocate() { char *buf = (char *)malloc(128); '
                'if (buf == NULL) { return; } '
                'strcpy(buf, "safe"); free(buf); }'
            ),
        },
        {
            "name": "Stack Buffer Overflow",
            "code": (
                "void copy_input(char *input) { char buffer[8]; "
                "strcpy(buffer, input); }"
            ),
        },
        {
            "name": "Integer Overflow",
            "code": (
                "void allocate(int size) { "
                "char *buf = (char *)malloc(size * sizeof(char)); "
                "if (buf == NULL) return; "
                "memset(buf, 'A', size + 10); }"
            ),
        },
        {
            "name": "Use After Free",
            "code": (
                'void uaf() { char *buf = (char *)malloc(16); '
                'free(buf); strcpy(buf, "UAF"); }'
            ),
        },
    ]

    print("\n" + "="*60)
    print("Testing Vulnerability Detection")
    print("="*60)

    for test in test_cases:
        print(f"\nTest: {test['name']}")
        # Show only the first 60 characters of the snippet as a preview.
        print(f"Code: {test['code'][:60]}...")

        result = detector.predict(test['code'])

        print(f"Prediction: {result['label']}")
        print(f"Confidence: {result['confidence']:.2%}")
        print(f" - Safe: {result['probabilities']['safe']:.2%}")
        print(f" - Vulnerable: {result['probabilities']['vulnerable']:.2%}")


if __name__ == "__main__":
    test_inference()