"""Strip a training checkpoint down to the model weights only.

Loads a training checkpoint (default ``models/best_model.pt``) and takes the
weights stored under its ``"model_state_dict"`` entry. It loads them into a
fresh ``VulnerabilityCodeT5`` to check they match the architecture. Only the
model's ``state_dict()`` is saved (default ``models/best_model_clean.pt``).
All other checkpoint entries are dropped.
"""

import argparse
from collections.abc import Mapping
from pathlib import Path

import torch

DEFAULT_INPUT = Path("models/best_model.pt")
DEFAULT_OUTPUT = Path("models/best_model_clean.pt")
STATE_DICT_KEY = "model_state_dict"


def extract_state_dict(checkpoint: object) -> Mapping:
    """Return the model weights held in *checkpoint*.

    Args:
        checkpoint: Object returned by ``torch.load``. It may be a training
            checkpoint dict that stores the weights under ``"model_state_dict"``.
            It may also be a bare state dict whose values are all tensors,
            such as this script's own output, which lets the script be re-run.

    Returns:
        The state dict mapping parameter names to tensors.

    Raises:
        TypeError: If *checkpoint* is not a mapping.
        KeyError: If *checkpoint* is a mapping with no ``"model_state_dict"``
            entry and is not itself a bare state dict.
    """
    if not isinstance(checkpoint, Mapping):
        raise TypeError(
            f"Expected the checkpoint to be a dict, got {type(checkpoint).__name__}"
        )
    if STATE_DICT_KEY in checkpoint:
        return checkpoint[STATE_DICT_KEY]
    # Already a bare state dict: pass it through unchanged.
    if checkpoint and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
        return checkpoint
    raise KeyError(
        f"Checkpoint has no {STATE_DICT_KEY!r} entry; found keys: {list(checkpoint)}"
    )


def main(argv=None):
    """Load the checkpoint, validate its weights, and save them on their own.

    Args:
        argv: Command-line arguments. ``None`` means use ``sys.argv[1:]``.
            With no arguments the script uses the original hard-coded paths
            and ``num_labels=2``.
    """
    parser = argparse.ArgumentParser(
        description="Save only the model weights from a training checkpoint."
    )
    parser.add_argument(
        "--input", type=Path, default=DEFAULT_INPUT,
        help="training checkpoint to read (default: %(default)s)",
    )
    parser.add_argument(
        "--output", type=Path, default=DEFAULT_OUTPUT,
        help="where to write the weights-only file (default: %(default)s)",
    )
    parser.add_argument(
        "--num-labels", type=int, default=2,
        help="num_labels used to build VulnerabilityCodeT5 (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    # Imported here rather than at the top so that extract_state_dict() can be
    # imported without the project package on sys.path.
    from src.model import VulnerabilityCodeT5

    # NOTE(security): torch.load unpickles the file unless weights_only=True is
    # in effect (the default only from PyTorch 2.6). Run this only on
    # checkpoints you produced or trust.
    checkpoint = torch.load(args.input, map_location="cpu")

    model = VulnerabilityCodeT5(num_labels=args.num_labels)
    # load_state_dict is strict by default. It raises if the saved keys do not
    # match the model, so a bad checkpoint is never re-saved as "clean".
    model.load_state_dict(extract_state_dict(checkpoint))
    # Drop the (possibly large) full checkpoint before writing the output.
    del checkpoint

    args.output.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), args.output)
    print("Saved clean model.")


if __name__ == "__main__":
    main()