| import argparse |
| import subprocess |
|
|
| import torch |
|
|
|
|
def parse_args():
    """Parse the command-line arguments of the publishing script.

    Two positional arguments are expected: ``in_file`` (the checkpoint to
    read) and ``out_file`` (the filename to write the processed checkpoint
    to).

    Returns:
        argparse.Namespace: Namespace with ``in_file`` and ``out_file``.
    """
    cli = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    positionals = (
        ('in_file', 'input checkpoint filename'),
        ('out_file', 'output checkpoint filename'),
    )
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, help=arg_help)
    return cli.parse_args()
|
|
|
|
def process_checkpoint(in_file, out_file):
    """Strip a checkpoint for publishing and tag its filename with a hash.

    The checkpoint is loaded onto the CPU, its ``'optimizer'`` entry (if
    any) is removed, and the result is saved to ``out_file``. The saved file
    is then renamed to ``<stem>-<hash>.pth``, where ``<stem>`` is
    ``out_file`` without a trailing ``.pth`` and ``<hash>`` is the first
    8 hex characters of the file's SHA-256 digest.

    Args:
        in_file: Path of the checkpoint to read.
        out_file: Path the processed checkpoint is written to before it is
            renamed to its final, hash-suffixed name.

    Returns:
        str: Path of the final, hash-suffixed checkpoint file.

    Raises:
        subprocess.CalledProcessError: If ``sha256sum`` or ``mv`` exits with
            a non-zero status.
    """
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only run this on checkpoints from a trusted source.
    checkpoint = torch.load(in_file, map_location='cpu')

    # Optimizer state is not needed in a published model; pop() with a
    # default is a no-op when the key is absent.
    checkpoint.pop('optimizer', None)

    torch.save(checkpoint, out_file)
    sha = subprocess.check_output(['sha256sum', out_file]).decode()

    # str.rstrip('.pth') strips any trailing run of the characters
    # '.', 'p', 't', 'h' (e.g. 'depth.pth' -> 'de'), so remove the suffix
    # explicitly instead.
    suffix = '.pth'
    stem = out_file[:-len(suffix)] if out_file.endswith(suffix) else out_file
    final_file = stem + '-{}.pth'.format(sha[:8])

    # check_call waits for the rename to finish and raises if it fails,
    # unlike a fire-and-forget Popen.
    subprocess.check_call(['mv', out_file, final_file])
    return final_file
|
|
|
|
def main():
    """Script entry point: parse the CLI and process the given checkpoint."""
    cli_args = parse_args()
    source_path, target_path = cli_args.in_file, cli_args.out_file
    process_checkpoint(source_path, target_path)
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|