File size: 415 Bytes
4b82ab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from src.model import VulnerabilityCodeT5

# Input training checkpoint and the weights-only file written from it.
CHECKPOINT_PATH = "models/best_model.pt"
CLEAN_WEIGHTS_PATH = "models/best_model_clean.pt"


def extract_state_dict(checkpoint):
    """Return the model weights contained in a loaded checkpoint.

    Two layouts are accepted:

    * a training checkpoint dict holding the weights under
      ``'model_state_dict'``;
    * a bare state dict, i.e. a non-empty dict whose values are all tensors
      (for example a file previously written by this script).

    Args:
        checkpoint: The object returned by ``torch.load``.

    Returns:
        The state dict to pass to ``Module.load_state_dict``.

    Raises:
        TypeError: If ``checkpoint`` is not a dict.
        KeyError: If ``checkpoint`` is a dict matching neither layout.
    """
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Expected checkpoint to be a dict, got {type(checkpoint).__name__}"
        )
    if "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    # Already weights-only: use it as-is instead of failing on the missing key.
    if checkpoint and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
        return checkpoint
    raise KeyError(
        "Checkpoint has no 'model_state_dict' entry and is not a plain state "
        f"dict; top-level keys: {sorted(map(str, checkpoint))}"
    )


def convert_checkpoint(src=CHECKPOINT_PATH, dst=CLEAN_WEIGHTS_PATH, num_labels=2):
    """Load the checkpoint at *src* and save only the model weights to *dst*.

    The weights are first loaded into a freshly constructed
    ``VulnerabilityCodeT5(num_labels=num_labels)``, so a mismatch with the
    current model definition surfaces here (``load_state_dict`` is strict by
    default) rather than later at inference time.

    Args:
        src: Path of the checkpoint to read.
        dst: Path the weights-only state dict is written to.
        num_labels: Passed to the ``VulnerabilityCodeT5`` constructor.
    """
    # NOTE(review): torch.load unpickles the file; only run this on trusted checkpoints.
    checkpoint = torch.load(src, map_location="cpu")

    model = VulnerabilityCodeT5(num_labels=num_labels)
    model.load_state_dict(extract_state_dict(checkpoint))
    # The weights were copied into the model; drop the full checkpoint before
    # serializing so it does not stay resident alongside the output.
    del checkpoint

    torch.save(model.state_dict(), dst)


def main():
    """Convert the default checkpoint and report success."""
    convert_checkpoint()
    print("Saved clean model.")


if __name__ == "__main__":
    main()