File size: 415 Bytes
4b82ab5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | import torch
from src.model import VulnerabilityCodeT5
def export_clean_weights(
    src_path: str = "models/best_model.pt",
    dst_path: str = "models/best_model_clean.pt",
    num_labels: int = 2,
) -> None:
    """Load a training checkpoint and re-save only the model's state dict.

    The weights are loaded into a freshly built ``VulnerabilityCodeT5`` with
    strict key matching (the ``load_state_dict`` default), so a checkpoint
    that does not fit the architecture fails here rather than later.

    Args:
        src_path: Checkpoint to read. Either a dict that holds the weights
            under ``"model_state_dict"`` (presumably next to optimizer /
            training state — confirm against the training script) or a bare
            state dict.
        dst_path: Where to write the weights-only file. Missing parent
            directories are created.
        num_labels: Forwarded to ``VulnerabilityCodeT5``. If it changes any
            parameter shape relative to the checkpoint, strict loading fails.

    Raises:
        RuntimeError: From ``load_state_dict`` on missing/unexpected keys or
            shape mismatches.
    """
    from pathlib import Path

    # NOTE: torch.load unpickles the file, so only run this on checkpoints you
    # trust. Its ``weights_only`` default became True in PyTorch 2.6, which
    # rejects non-allowlisted objects; if a full training checkpoint fails to
    # load on a newer torch, pass ``weights_only`` explicitly here.
    checkpoint = torch.load(src_path, map_location="cpu")

    # Accept both full training checkpoints and already-bare state dicts.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    model = VulnerabilityCodeT5(num_labels=num_labels)
    model.load_state_dict(state_dict)
    # load_state_dict copies the values into the model's own parameters, so
    # the (possibly large) checkpoint can be released before saving.
    del checkpoint, state_dict

    dst = Path(dst_path)
    dst.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), dst)
    print("Saved clean model.")


if __name__ == "__main__":
    export_clean_weights()
|