Pranav Pc committed on
Commit
4b82ab5
·
1 Parent(s): 2075aa2

Final Deploy

Browse files
Dockerfile CHANGED
@@ -1,20 +1,13 @@
1
- FROM python:3.13.5-slim
2
 
3
  WORKDIR /app
4
 
5
- RUN apt-get update && apt-get install -y \
6
- build-essential \
7
- curl \
8
- git \
9
- && rm -rf /var/lib/apt/lists/*
10
 
11
- COPY requirements.txt ./
12
- COPY src/ ./src/
13
 
14
- RUN pip3 install -r requirements.txt
15
 
16
- EXPOSE 8501
17
 
18
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
-
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
1
# Runtime image for the Streamlit vulnerability-detection app.
FROM python:3.10

WORKDIR /app

# Install Python dependencies first so this layer stays cached until
# requirements.txt itself changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application source last to keep rebuilds fast.
COPY . .

# The app is served on port 7860.
EXPOSE 7860

CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
 
 
README.md CHANGED
@@ -1,19 +0,0 @@
1
- ---
2
- title: Code Vulnerability Detection
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: AI-powered code vulnerability detection.
12
- ---
13
-
14
- # Welcome to Streamlit!
15
-
16
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
17
-
18
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
19
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Streamlit UI for Vulnerability Detection
3
+ Interactive web interface
4
+ """
5
+
6
+ import streamlit as st
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ sys.path.append(str(Path(__file__).parent))
11
+
12
+ from src.inference import VulnerabilityDetector
13
+
14
+ # Page config
15
+ st.set_page_config(
16
+ page_title="Code Vulnerability Detector",
17
+ page_icon="🔒",
18
+ layout="wide"
19
+ )
20
+
21
# Cached so the (large) model is loaded only once per server process.
@st.cache_resource
def load_detector():
    """Create and cache a single VulnerabilityDetector instance."""
    return VulnerabilityDetector()
25
+
26
# Main app
def main():
    """Render the two-column analysis UI and score code on demand."""
    st.title("🔒 AI-Powered Code Vulnerability Detection")
    st.markdown("### Detect security vulnerabilities in your code using fine-tuned CodeT5")

    # --- Sidebar: static info + model bootstrap -------------------------
    with st.sidebar:
        st.header("ℹ️ About")
        st.markdown("""
This tool uses a fine-tuned CodeT5 model to detect security vulnerabilities in source code.

**Supported Languages:**
- C/C++
- Python
- JavaScript

**Detection Types:**
- Buffer Overflow
- SQL Injection
- Command Injection
- Format String Bugs
- And more...
""")

        st.header("📊 Model Info")
        try:
            detector = load_detector()
            st.success("Model loaded successfully!")
        except Exception as e:
            st.error(f"Error loading model: {e}")
            st.stop()

    # --- Main area: input on the left, results on the right -------------
    left_col, right_col = st.columns([1, 1])

    with left_col:
        st.header("📝 Enter Code")

        # Canned snippets for quick demos.
        example = st.selectbox(
            "Or try an example:",
            ["Custom", "Buffer Overflow", "SQL Injection", "Safe Code"]
        )

        snippets = {
            "Buffer Overflow": '''void copy(char *input) {
char buffer[8];
strcpy(buffer, input);
}''',
            "SQL Injection": '''def get_user(user_id):
query = "SELECT * FROM users WHERE id=" + user_id
cursor.execute(query)
return cursor.fetchone()''',
            "Safe Code": '''def add_numbers(a, b):
return a + b''',
        }
        default_code = snippets.get(example, "")

        code_input = st.text_area(
            "Paste your code here:",
            value=default_code,
            height=300,
            placeholder="Enter source code to analyze..."
        )

        analyze_button = st.button("🔍 Analyze Code", type="primary", use_container_width=True)

    with right_col:
        st.header("📊 Analysis Results")

        if analyze_button and code_input.strip():
            with st.spinner("Analyzing code..."):
                try:
                    result = detector.predict(code_input)

                    # Headline verdict plus a probability bar.
                    if result['prediction'] == 1:
                        st.error(f"⚠️ {result['label']}")
                        st.progress(result['probabilities']['vulnerable'])
                    else:
                        st.success(f"✅ {result['label']}")
                        st.progress(result['probabilities']['safe'])

                    st.subheader("Confidence Breakdown")
                    col_a, col_b = st.columns(2)

                    with col_a:
                        st.metric(
                            "Safe Probability",
                            f"{result['probabilities']['safe']:.1%}",
                            delta=None
                        )

                    with col_b:
                        st.metric(
                            "Vulnerable Probability",
                            f"{result['probabilities']['vulnerable']:.1%}",
                            delta=None
                        )

                    # Advice tailored to the verdict.
                    if result['prediction'] == 1:
                        st.subheader("🛡️ Recommendations")
                        st.warning("""
**This code appears to have security vulnerabilities.**

Common fixes:
- Use bounds-checked functions (strncpy instead of strcpy)
- Use parameterized queries for SQL
- Validate and sanitize all user inputs
- Avoid eval() and system() with user input
""")
                    else:
                        st.subheader("Good Practices")
                        st.info("""
This code appears to follow security best practices!

Remember to:
- Keep dependencies updated
- Perform regular security audits
- Use static analysis tools
- Follow OWASP guidelines
""")

                except Exception as e:
                    st.error(f"Error during analysis: {e}")

        elif analyze_button:
            st.warning("Please enter some code to analyze.")

if __name__ == "__main__":
    main()
models/best_model_clean.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a1b98dd49c9eddf98d8e95f612f6467c10a9f98a2a4b76b0770c84ea88a674c
3
+ size 894029464
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
1
+ streamlit==1.28.2
2
+ torch==2.10.0
3
+ transformers==4.57.1
4
+ sentencepiece
5
+ numpy==1.26.2
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.10
save_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""One-off utility: strip optimizer state from the training checkpoint,
leaving a weights-only file suitable for deployment."""
import torch

from src.model import VulnerabilityCodeT5

# Load the full training checkpoint (weights + optimizer state) on CPU.
ckpt = torch.load("models/best_model.pt", map_location="cpu")

# Rebuild the network and restore only its weights.
net = VulnerabilityCodeT5(num_labels=2)
net.load_state_dict(ckpt['model_state_dict'])

# Persist the bare state dict — much smaller than the full checkpoint.
torch.save(net.state_dict(), "models/best_model_clean.pt")

print("Saved clean model.")
src/__pycache__/inference.cpython-312.pyc ADDED
Binary file (6 kB). View file
 
src/__pycache__/model.cpython-312.pyc ADDED
Binary file (3.66 kB). View file
 
src/data.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import RobertaTokenizer
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import torch
4
+ import json
5
+ from pathlib import Path
6
+
7
+
8
class VulnerabilityDataset(Dataset):
    """PyTorch dataset for vulnerability detection.

    Reads a JSONL file where each record carries the source code under
    "func" and a 0/1 label under "target".
    """

    def __init__(self, data_path, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        data_path = Path(data_path)
        if not data_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {data_path}")

        # One JSON object per non-empty line.
        self.data = []
        with open(data_path, "r", encoding="utf-8") as f:
            for raw in f:
                raw = raw.strip()
                if raw:
                    self.data.append(json.loads(raw))

        print(f"{data_path.name}: {len(self.data)} samples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]

        code = record["func"]      # source-code text
        label = record["target"]   # 0 = safe, 1 = vulnerable

        encoded = self.tokenizer(
            code,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Squeeze away the batch dimension added by return_tensors="pt".
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }
51
+
52
+
53
+ def load_tokenizer(model_name="Salesforce/codet5-base"):
54
+ print(f"Tokenizer: {model_name}")
55
+ return RobertaTokenizer.from_pretrained(model_name)
56
+
57
+
58
def create_dataloader(
    train_path,
    valid_path,
    test_path,
    tokenizer,
    batch_size=8,
    max_length=512,
    num_workers=2,
):
    """Build the train/valid/test DataLoaders.

    Args:
        train_path, valid_path, test_path: JSONL dataset files.
        tokenizer: tokenizer passed through to VulnerabilityDataset.
        batch_size: samples per batch for all three splits.
        max_length: tokenizer truncation/padding length.
        num_workers: DataLoader worker processes (0 = load in-process).

    Returns:
        (train_loader, valid_loader, test_loader)

    Raises:
        RuntimeError: if the training split is empty.
    """
    train_dataset = VulnerabilityDataset(train_path, tokenizer, max_length)
    valid_dataset = VulnerabilityDataset(valid_path, tokenizer, max_length)
    test_dataset = VulnerabilityDataset(test_path, tokenizer, max_length)

    if len(train_dataset) == 0:
        raise RuntimeError(f"No samples found in {train_path}")

    # persistent_workers is only legal when worker processes exist:
    # DataLoader raises ValueError if it is True while num_workers == 0.
    keep_workers = num_workers > 0

    def _loader(dataset, shuffle):
        # Shared DataLoader settings for all three splits.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=keep_workers,
        )

    train_loader = _loader(train_dataset, shuffle=True)
    valid_loader = _loader(valid_dataset, shuffle=False)
    test_loader = _loader(test_dataset, shuffle=False)

    return train_loader, valid_loader, test_loader
src/inference.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference module for vulnerability detection
2
+ Load trained models and make predictions"""
3
+
4
+ import torch
5
+ from transformers import RobertaTokenizer
6
+ from pathlib import Path
7
+ import sys
8
+ sys.path.append(str(Path(__file__).parent.parent.parent))
9
+
10
+ from src.model import VulnerabilityCodeT5
11
+
12
class VulnerabilityDetector:
    """Loads the fine-tuned CodeT5 classifier and scores code snippets.

    Args:
        model_path: path to the saved weights. Defaults to the
            weights-only file this repo actually ships
            (models/best_model_clean.pt); the old default
            models/best_model.pt is not committed and crashed on deploy.
        model_name: Hugging Face model id for tokenizer/backbone.
        max_length: tokenizer truncation/padding length.
    """

    def __init__(self, model_path="models/best_model_clean.pt",
                 model_name="Salesforce/codet5-base", max_length=256):

        # Deployment target is CPU-only hardware.
        self.device = torch.device('cpu')
        self.max_length = max_length

        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)

        self.model = VulnerabilityCodeT5(model_name=model_name, num_labels=2)

        state_dict = torch.load(model_path, map_location=self.device)
        # Accept both a bare state dict (best_model_clean.pt, written by
        # save_model.py) and a full training checkpoint that wraps the
        # weights under "model_state_dict" (best_model.pt from train.py).
        if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
            state_dict = state_dict["model_state_dict"]
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        print("Model Loaded Successfully")

        self.labels = {
            0: "Safe Code",
            1: "Vulnerable Code"
        }

    def predict(self, code_snippet):
        """Predict vulnerability of a code snippet.

        Args:
            code_snippet: string containing source code.

        Returns:
            dict with 'prediction' (0/1), 'label', 'confidence' and a
            'probabilities' dict with 'safe'/'vulnerable' entries.
        """
        inputs = self.tokenizer(
            code_snippet,
            # Was hard-coded to 256, silently ignoring the constructor's
            # max_length argument.
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)

        with torch.no_grad():
            predictions, probs = self.model.predict(input_ids, attention_mask)

        pred_label = predictions[0].item()
        confidence = probs[0][pred_label].item()

        return {
            'prediction': pred_label,
            'label': self.labels[pred_label],
            'confidence': confidence,
            'probabilities': {
                'safe': probs[0][0].item(),
                'vulnerable': probs[0][1].item()
            }
        }

    def analyze_batch(self, code_snippets):
        """Analyze multiple code snippets at once."""
        return [self.predict(code) for code in code_snippets]
78
+
79
def test_inference():
    """Smoke-test the detector on known-safe and known-vulnerable snippets."""
    det = VulnerabilityDetector()

    # Three snippets that should score safe, three that should not.
    cases = [
        {
            "name": "Safe Bounded Copy",
            "code": """void copy_input(const char *input) {
char buffer[32];
strncpy(buffer, input, sizeof(buffer) - 1);
buffer[sizeof(buffer) - 1] = '\\0';
}"""
        },
        {
            "name": "Safe fgets Input",
            "code": """void read_input() {
char buffer[64];
if (fgets(buffer, sizeof(buffer), stdin) != NULL) {
printf("%s", buffer);
}
}"""
        },
        {
            "name": "Safe malloc usage",
            "code": """void allocate() {
char *buf = (char *)malloc(128);
if (buf == NULL) {
return;
}
strcpy(buf, "safe");
free(buf);
}"""
        },
        {
            "name": "Stack Buffer Overflow",
            "code": """void copy_input(char *input) {
char buffer[8];
strcpy(buffer, input);
}"""
        },
        {
            "name": "Integer Overflow",
            "code": """void allocate(int size) {
char *buf = (char *)malloc(size * sizeof(char));
if (buf == NULL) return;
memset(buf, 'A', size + 10);
}"""
        },
        {
            "name": "Use After Free",
            "code": """void uaf() {
char *buf = (char *)malloc(16);
free(buf);
strcpy(buf, "UAF");
}"""
        }
    ]

    banner = "=" * 60
    print("\n" + banner)
    print("Testing Vulnerability Detection")
    print(banner)

    for case in cases:
        print(f"\nTest: {case['name']}")
        print(f"Code: {case['code'][:60]}...")

        result = det.predict(case['code'])

        print(f"Prediction: {result['label']}")
        print(f"Confidence: {result['confidence']:.2%}")
        print(f" - Safe: {result['probabilities']['safe']:.2%}")
        print(f" - Vulnerable: {result['probabilities']['vulnerable']:.2%}")

if __name__ == "__main__":
    test_inference()
src/model.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CodeT5 Vulnerability Detection model
2
+ Binary Classication Safe(0) vs Vulnerable(1)"""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import T5ForConditionalGeneration, RobertaTokenizer
7
+
8
class VulnerabilityCodeT5(nn.Module):
    """CodeT5 encoder plus an MLP head for binary vulnerability classification."""

    def __init__(self, model_name="Salesforce/codet5-base", num_labels=2):
        super().__init__()

        # Full seq2seq model; only its encoder is used in forward().
        self.encoder_decoder = T5ForConditionalGeneration.from_pretrained(model_name)

        width = self.encoder_decoder.config.d_model  # 768 for the base model

        # Two-layer classification head on top of the pooled encoder output.
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(width, num_labels)
        )

        self.num_labels = num_labels

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and classification head.

        Args:
            input_ids: tokenized code [batch_size, seq_len]
            attention_mask: attention mask [batch_size, seq_len]
            labels: optional ground-truth labels [batch_size]

        Returns:
            dict with 'loss' (None when labels is None), 'logits'
            [batch, num_labels] and 'hidden_states'.
        """
        enc_out = self.encoder_decoder.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Pool by taking the first token's hidden state.
        token_states = enc_out.last_hidden_state   # [batch, seq_len, hidden]
        cls_vec = token_states[:, 0, :]            # [batch, hidden]

        logits = self.classifier(cls_vec)          # [batch, num_labels]

        loss = None
        if labels is not None:
            criterion = nn.CrossEntropyLoss()
            loss = criterion(logits, labels)

        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': token_states
        }

    def predict(self, input_ids, attention_mask):
        """Return (argmax predictions, softmax probabilities), gradient-free."""
        self.eval()
        with torch.no_grad():
            scores = self.forward(input_ids, attention_mask)
            probs = torch.softmax(scores["logits"], dim=1)
            predictions = torch.argmax(probs, dim=1)

        return predictions, probs
74
+
75
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
src/train.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import gc
import json
import sys
from pathlib import Path

import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler
from torch.optim import AdamW
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

# Make the repository root importable when running `python src/train.py`.
sys.path.append(str(Path(__file__).parent.parent))

# The modules committed with this repo live at src/data.py and
# src/model.py and define exactly these names; the previous
# `src.v2.data_processor` / `src.v2.model` paths do not exist here and
# failed at import time.
from src.data import load_tokenizer, create_dataloader
from src.model import VulnerabilityCodeT5, count_parameters
17
+
18
+
19
class Trainer:
    """Training loop with AMP mixed precision and gradient accumulation."""

    def __init__(
        self,
        model,
        train_loader,
        valid_loader,
        device,
        learning_rate=2e-5,
        num_epochs=5,
        gradient_accumulation_steps=4,
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps

        # Mixed precision only makes sense on CUDA.
        self.use_amp = device.type == "cuda"
        self.scaler = GradScaler(enabled=self.use_amp)

        self.optimizer = AdamW(
            self.model.parameters(), lr=learning_rate, weight_decay=0.01
        )

        # One optimizer step per `gradient_accumulation_steps` batches.
        total_steps = (
            len(self.train_loader) * num_epochs
        ) // gradient_accumulation_steps

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=max(1, total_steps // 10),
            num_training_steps=total_steps,
        )

        self.best_val_acc = 0.0
        self.history = {
            "train_loss": [],
            "train_acc": [],
            "val_loss": [],
            "val_acc": [],
        }

    def clear_memory(self):
        """Release cached GPU memory and force a garbage-collection pass."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def train_epoch(self):
        """One pass over the training set; returns (mean loss, accuracy %)."""
        self.model.train()
        running_loss = 0.0
        hits = 0
        seen = 0

        self.optimizer.zero_grad(set_to_none=True)

        progress = tqdm(self.train_loader, desc="Training")

        for batch_idx, batch in enumerate(progress):
            input_ids = batch["input_ids"].to(self.device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
            labels = batch["labels"].to(self.device, non_blocking=True)

            with autocast(device_type="cuda", enabled=self.use_amp):
                outputs = self.model(input_ids, attention_mask, labels)
                # Scale down so accumulated gradients average correctly.
                loss = outputs["loss"] / self.gradient_accumulation_steps

            self.scaler.scale(loss).backward()

            if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
                # Unscale before clipping so the norm is in true units.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                self.optimizer.zero_grad(set_to_none=True)

            with torch.no_grad():
                preds = torch.argmax(outputs["logits"], dim=1)
                hits += (preds == labels).sum().item()
                seen += labels.size(0)

            running_loss += loss.item() * self.gradient_accumulation_steps

            gpu_gb = (
                torch.cuda.memory_allocated() / 1024 ** 3
                if torch.cuda.is_available()
                else 0
            )

            progress.set_postfix(
                {
                    "loss": f"{loss.item() * self.gradient_accumulation_steps:.4f}",
                    "acc": f"{100 * hits / max(1, seen):.2f}%",
                    "gpu": f"{gpu_gb:.2f}GB",
                }
            )

            # Drop per-batch tensors eagerly to keep peak memory flat.
            del input_ids, attention_mask, labels, outputs, loss

        self.clear_memory()

        return running_loss / len(self.train_loader), 100 * hits / seen

    def validate(self):
        """One pass over the validation set; returns (mean loss, accuracy %)."""
        self.model.eval()
        running_loss = 0.0
        hits = 0
        seen = 0

        with torch.no_grad():
            for batch in tqdm(self.valid_loader, desc="Validating"):
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                with autocast(device_type="cuda", enabled=self.use_amp):
                    outputs = self.model(input_ids, attention_mask, labels)
                    loss = outputs["loss"]

                preds = torch.argmax(outputs["logits"], dim=1)
                hits += (preds == labels).sum().item()
                seen += labels.size(0)
                running_loss += loss.item()

        self.clear_memory()
        return running_loss / len(self.valid_loader), 100 * hits / seen

    def train(self, save_dir="models/v2"):
        """Full training run; checkpoints the best model by validation accuracy."""
        print(f"Training samples: {len(self.train_loader.dataset)}")
        print(f"Validation samples: {len(self.valid_loader.dataset)}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name(0)}")

        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        rule = "=" * 60
        for epoch in range(self.num_epochs):
            print(f"\n{rule}")
            print(f"Epoch {epoch + 1}/{self.num_epochs}")
            print(f"{rule}")

            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()

            print(
                f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%"
            )
            print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

            self.history["train_loss"].append(train_loss)
            self.history["train_acc"].append(train_acc)
            self.history["val_loss"].append(val_loss)
            self.history["val_acc"].append(val_acc)

            # Keep the best-so-far checkpoint (by validation accuracy).
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                torch.save(
                    {
                        "model_state_dict": self.model.state_dict(),
                        "optimizer_state_dict": self.optimizer.state_dict(),
                        "val_acc": val_acc,
                    },
                    save_dir / "best_model.pt",
                )
                print("Saved best model")

        # Always persist the final weights plus the metric history.
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "history": self.history,
            },
            save_dir / "final_model.pt",
        )

        with open(save_dir / "training_history.json", "w") as f:
            json.dump(self.history, f, indent=2)

        print(f"\nTraining complete. Best Val Acc: {self.best_val_acc:.2f}%")
200
+
201
def main(args):
    """Wire up data, model and trainer from CLI arguments and run training."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The --use_sample flag switches to a small subset for quick runs.
    base = (
        Path("data/processed/sample") if args.use_sample else Path("data/processed")
    )

    tokenizer = load_tokenizer(args.model_name)

    train_loader, valid_loader, test_loader = create_dataloader(
        base / "train.jsonl",
        base / "valid.jsonl",
        base / "test.jsonl",
        tokenizer,
        batch_size=args.batch_size,
        max_length=args.max_length,
        num_workers=2,
    )

    model = VulnerabilityCodeT5(model_name=args.model_name, num_labels=2)
    print(f"Trainable parameters: {count_parameters(model):,}")

    trainer = Trainer(
        model,
        train_loader,
        valid_loader,
        device,
        learning_rate=args.learning_rate,
        num_epochs=args.epochs,
        gradient_accumulation_steps=args.gradient_accumulation,
    )

    trainer.train(args.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="Salesforce/codet5-base")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--gradient_accumulation", type=int, default=4)
    parser.add_argument("--output_dir", default="models/v2")
    parser.add_argument("--use_sample", action="store_true")

    main(parser.parse_args())