File size: 4,330 Bytes
4b82ab5
 
 
 
 
 
 
 
 
 
 
 
2e08b23
4b82ab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Inference module for vulnerability detection
Load trained models and make predictions"""

import torch
from transformers import RobertaTokenizer
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent.parent))

from src.model import VulnerabilityCodeT5

class VulnerabilityDetector:
    """Wrap a trained ``VulnerabilityCodeT5`` model for two-class inference.

    The tokenizer and the model weights are loaded once in ``__init__``;
    ``predict`` scores a single snippet and ``analyze_batch`` scores a list
    of snippets one after another. Inference runs on the CPU.
    """

    def __init__(self, model_path="models/best_model_clean.pt",
                 model_name="Salesforce/codet5-base", max_length=256):
        """Load the tokenizer and the trained weights.

        Args:
            model_path: Path of the saved ``state_dict`` checkpoint.
            model_name: Pretrained model identifier used both for the
                tokenizer and to build ``VulnerabilityCodeT5``.
            max_length: Token length every snippet is padded/truncated to
                before being handed to the model.
        """
        # Pinned to CPU for deployment.
        self.device = torch.device('cpu')
        self.max_length = max_length

        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)

        self.model = VulnerabilityCodeT5(model_name=model_name, num_labels=2)

        # NOTE(security): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be passed as ``model_path``.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # Switch to evaluation mode before serving predictions.
        self.model.eval()

        print("Model Loaded Successfully")

        # Human-readable name for each class index returned by the model.
        self.labels = {
            0: "Safe Code",
            1: "Vulnerable Code"
        }

    def predict(self, code_snippet):
        """Predict whether a single code snippet is vulnerable.

        Args:
            code_snippet: String containing source code.

        Returns:
            dict with the keys ``'prediction'`` (class index), ``'label'``
            (entry of ``self.labels`` for that index), ``'confidence'``
            (probability of the predicted class) and ``'probabilities'``
            (a dict with the ``'safe'`` and ``'vulnerable'`` probabilities).
        """
        inputs = self.tokenizer(
            code_snippet,
            # Use the length configured in __init__; a hard-coded 256 here
            # would silently ignore the constructor's ``max_length``.
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)

        # No gradients are needed for inference.
        with torch.no_grad():
            predictions, probs = self.model.predict(input_ids, attention_mask)

        # Only one snippet was encoded, so row 0 is the only result row.
        pred_label = predictions[0].item()
        confidence = probs[0][pred_label].item()

        return {
            'prediction': pred_label,
            'label': self.labels[pred_label],
            'confidence': confidence,
            'probabilities': {
                'safe': probs[0][0].item(),
                'vulnerable': probs[0][1].item()
            }
        }

    def analyze_batch(self, code_snippets):
        """Run ``predict`` on each snippet and return the results in order.

        Args:
            code_snippets: Iterable of source-code strings.

        Returns:
            list of the dicts produced by ``predict``, one per snippet.
        """
        return [self.predict(code) for code in code_snippets]
    
def test_inference():
    """Smoke-test the detector on six hard-coded C snippets.

    Builds a ``VulnerabilityDetector`` with its default arguments, runs
    ``predict`` on every snippet and prints the label, the confidence and
    both class probabilities for each one.
    """
    detector = VulnerabilityDetector()

    # (name, C source) pairs: the first three are named as safe, the last
    # three as vulnerable.
    samples = [
        (
            "Safe Bounded Copy",
            """void copy_input(const char *input) {
        char buffer[32];
        strncpy(buffer, input, sizeof(buffer) - 1);
        buffer[sizeof(buffer) - 1] = '\\0';
    }""",
        ),
        (
            "Safe fgets Input",
            """void read_input() {
        char buffer[64];
        if (fgets(buffer, sizeof(buffer), stdin) != NULL) {
            printf("%s", buffer);
        }
    }""",
        ),
        (
            "Safe malloc usage",
            """void allocate() {
        char *buf = (char *)malloc(128);
        if (buf == NULL) {
            return;
        }
        strcpy(buf, "safe");
        free(buf);
    }""",
        ),
        (
            "Stack Buffer Overflow",
            """void copy_input(char *input) {
        char buffer[8];
        strcpy(buffer, input);
    }""",
        ),
        (
            "Integer Overflow",
            """void allocate(int size) {
        char *buf = (char *)malloc(size * sizeof(char));
        if (buf == NULL) return;
        memset(buf, 'A', size + 10);
    }""",
        ),
        (
            "Use After Free",
            """void uaf() {
        char *buf = (char *)malloc(16);
        free(buf);
        strcpy(buf, "UAF");
    }""",
        ),
    ]

    # Banner framing the report.
    rule = "=" * 60
    print(f"\n{rule}")
    print("Testing Vulnerability Detection")
    print(rule)

    for title, snippet in samples:
        print(f"\nTest: {title}")
        # Show only the first 60 characters of the snippet as a preview.
        print(f"Code: {snippet[:60]}...")

        outcome = detector.predict(snippet)
        scores = outcome['probabilities']

        print(f"Prediction: {outcome['label']}")
        print(f"Confidence: {outcome['confidence']:.2%}")
        print(f"   - Safe: {scores['safe']:.2%}")
        print(f"   - Vulnerable: {scores['vulnerable']:.2%}")

# Run the built-in smoke test when this file is executed as a script.
if __name__ == "__main__":
    test_inference()