"""Inference module for vulnerability detection.

Loads a trained model and makes predictions on source-code snippets.
"""
| |
|
| | import torch |
| | from transformers import RobertaTokenizer |
| | from pathlib import Path |
| | import sys |
| | sys.path.append(str(Path(__file__).parent.parent.parent)) |
| |
|
| | from src.model import VulnerabilityCodeT5 |
| |
|
class VulnerabilityDetector:
    """Classify source-code snippets as safe or vulnerable.

    Wraps a ``VulnerabilityCodeT5`` classifier (2 labels) together with its
    tokenizer. Weights are loaded once at construction time; ``predict`` and
    ``analyze_batch`` then run inference without gradient tracking.
    """

    def __init__(self, model_path="models/best_model_clean.pt",
                 model_name="Salesforce/codet5-base", max_length=256,
                 device="cpu"):
        """Load the tokenizer and the trained classifier weights.

        Args:
            model_path: Path to the saved ``state_dict`` for
                ``VulnerabilityCodeT5`` (relative paths resolve against the
                current working directory).
            model_name: Hugging Face model id used for both the tokenizer and
                the model architecture; it must match the architecture the
                checkpoint was trained with, or ``load_state_dict`` will fail.
            max_length: Token budget per snippet; longer inputs are truncated
                and shorter ones padded to exactly this length.
            device: Torch device spec to run inference on (default ``"cpu"``).
        """
        self.device = torch.device(device)
        self.max_length = max_length

        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)

        self.model = VulnerabilityCodeT5(model_name=model_name, num_labels=2)

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources (newer
        # torch versions support weights_only=True to restrict this).
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # Switch to evaluation mode (e.g. dropout disabled) for inference.
        self.model.eval()

        print("Model Loaded Successfully")

        # Class index -> human-readable label. Presumably mirrors the label
        # encoding used during training (0 = safe, 1 = vulnerable) — confirm.
        self.labels = {
            0: "Safe Code",
            1: "Vulnerable Code",
        }

    def predict(self, code_snippet):
        """Predict whether a single code snippet is vulnerable.

        Args:
            code_snippet: String containing the source code to classify.

        Returns:
            dict with keys:
                ``prediction``: predicted class index (int).
                ``label``: human-readable name of the predicted class.
                ``confidence``: score ``model.predict`` assigned to the
                    predicted class.
                ``probabilities``: ``{'safe': float, 'vulnerable': float}``
                    per-class scores from ``model.predict``.
        """
        # Honor the configured token budget rather than a fixed 256, so the
        # constructor's max_length actually controls truncation/padding.
        inputs = self.tokenizer(
            code_snippet,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)

        with torch.no_grad():
            predictions, probs = self.model.predict(input_ids, attention_mask)

        # The batch contains exactly one snippet, so read row 0.
        pred_label = predictions[0].item()
        row = probs[0]

        return {
            'prediction': pred_label,
            'label': self.labels[pred_label],
            'confidence': row[pred_label].item(),
            'probabilities': {
                'safe': row[0].item(),
                'vulnerable': row[1].item(),
            },
        }

    def analyze_batch(self, code_snippets):
        """Analyze multiple code snippets, one forward pass per snippet.

        Args:
            code_snippets: Iterable of source-code strings.

        Returns:
            list of result dicts (same shape as ``predict``'s), in input order.
        """
        return [self.predict(code) for code in code_snippets]
| | |
def test_inference(detector=None):
    """Run the detector on a few hand-written C snippets and print the results.

    This is a manual smoke test / demo: it prints each snippet's prediction
    next to the label the snippet was written to illustrate, followed by a
    tally of agreements. It asserts nothing and returns ``None``.

    Args:
        detector: Optional object with a ``predict(code)`` method returning a
            dict shaped like ``VulnerabilityDetector.predict``'s result. When
            omitted, a ``VulnerabilityDetector`` is built with its default
            arguments, which loads the model checkpoint from disk.
    """
    if detector is None:
        detector = VulnerabilityDetector()

    # "expected" is the label each snippet was written to exhibit. It is only
    # reported, never asserted, since this is a demo of model behavior.
    test_cases = [
        {
            "name": "Safe Bounded Copy",
            "expected": "Safe Code",
            "code": """void copy_input(const char *input) {
    char buffer[32];
    strncpy(buffer, input, sizeof(buffer) - 1);
    buffer[sizeof(buffer) - 1] = '\\0';
}""",
        },
        {
            "name": "Safe fgets Input",
            "expected": "Safe Code",
            "code": """void read_input() {
    char buffer[64];
    if (fgets(buffer, sizeof(buffer), stdin) != NULL) {
        printf("%s", buffer);
    }
}""",
        },
        {
            "name": "Safe malloc usage",
            "expected": "Safe Code",
            "code": """void allocate() {
    char *buf = (char *)malloc(128);
    if (buf == NULL) {
        return;
    }
    strcpy(buf, "safe");
    free(buf);
}""",
        },
        {
            "name": "Stack Buffer Overflow",
            "expected": "Vulnerable Code",
            "code": """void copy_input(char *input) {
    char buffer[8];
    strcpy(buffer, input);
}""",
        },
        {
            # memset writes size + 10 bytes into a size-byte allocation, so
            # the defect shown here is a heap buffer overflow, not an
            # integer overflow.
            "name": "Heap Buffer Overflow",
            "expected": "Vulnerable Code",
            "code": """void allocate(int size) {
    char *buf = (char *)malloc(size * sizeof(char));
    if (buf == NULL) return;
    memset(buf, 'A', size + 10);
}""",
        },
        {
            "name": "Use After Free",
            "expected": "Vulnerable Code",
            "code": """void uaf() {
    char *buf = (char *)malloc(16);
    free(buf);
    strcpy(buf, "UAF");
}""",
        },
    ]

    preview_len = 60
    matches = 0

    print("\n" + "=" * 60)
    print("Testing Vulnerability Detection")
    print("=" * 60)

    for test in test_cases:
        # Collapse newlines/indentation so the preview stays on one line, and
        # only append an ellipsis when the preview was actually truncated.
        preview = " ".join(test["code"].split())
        if len(preview) > preview_len:
            preview = preview[:preview_len] + "..."

        print(f"\nTest: {test['name']}")
        print(f"Code: {preview}")

        result = detector.predict(test["code"])
        matched = result["label"] == test["expected"]
        matches += int(matched)

        print(f"Prediction: {result['label']}")
        print(f"Expected: {test['expected']} ({'match' if matched else 'MISMATCH'})")
        print(f"Confidence: {result['confidence']:.2%}")
        print(f" - Safe: {result['probabilities']['safe']:.2%}")
        print(f" - Vulnerable: {result['probabilities']['vulnerable']:.2%}")

    print("\n" + "=" * 60)
    print(f"Matched expected label on {matches}/{len(test_cases)} snippets")
| |
|
if __name__ == "__main__":
    # Only when executed as a script (not on import): run the sample demo.
    test_inference()