| |
| |
| |
| |
| |
|
|
|
|
| from pathlib import Path |
|
|
| import torch |
|
|
|
|
# The message-processor weight carries no "generator." prefix in the input
# state dict, yet it is routed to the generator and renamed (dropping ".0").
_MSG_PROCESSOR_KEY = "msg_processor.msg_processor.0.weight"
_MSG_PROCESSOR_NEW_KEY = "msg_processor.msg_processor.weight"

_GENERATOR_PREFIX = "generator."
_DETECTOR_PREFIX = "detector."


def _split_state_dict(state_dict: dict) -> tuple:
    """Route state-dict entries into (generator, detector) dicts, stripping prefixes.

    Raises:
        ValueError: If a key belongs to neither sub-model.
    """
    generator_state = {}
    detector_state = {}
    for key, value in state_dict.items():
        if key.startswith(_DETECTOR_PREFIX):
            detector_state[key[len(_DETECTOR_PREFIX):]] = value
        elif key == _MSG_PROCESSOR_KEY:
            generator_state[_MSG_PROCESSOR_NEW_KEY] = value
        elif key.startswith(_GENERATOR_PREFIX):
            generator_state[key[len(_GENERATOR_PREFIX):]] = value
        else:
            # Explicit raise rather than `assert`: asserts are stripped under
            # `python -O`, which would let unknown keys be silently mangled.
            raise ValueError(f"Invalid layer: {key}")
    return generator_state, detector_state


def convert(checkpoint: str, outdir: str, suffix: str = "base"):
    """Convert the checkpoint to generator and detector.

    Splits a combined checkpoint (a dict with an ``"xp.cfg"`` config entry and
    a ``"model"`` state dict) into two standalone checkpoints. Each output
    keeps only the ``seanet``, ``channels``, ``dtype`` and ``sample_rate``
    config entries. Each state-dict key has its ``generator.`` / ``detector.``
    prefix stripped and is routed to the matching output.

    Two files are written to ``outdir``, which is created if missing:
    ``checkpoint_generator_{suffix}.pth`` and
    ``checkpoint_detector_{suffix}.pth``.

    Args:
        checkpoint: Path to the combined checkpoint file.
        outdir: Directory that receives the two converted checkpoints.
        suffix: Tag appended to both output file names.

    Raises:
        ValueError: If a state-dict key belongs to neither the generator nor
            the detector (and is not the special-cased message-processor key).
    """
    outdir_path = Path(outdir)
    # map_location="cpu" lets GPU-saved checkpoints be converted on machines
    # without CUDA and keeps the written checkpoints device-agnostic.
    # NOTE(review): torch>=2.6 defaults to weights_only=True; if "xp.cfg" holds
    # non-primitive objects (e.g. an OmegaConf config) loading may fail - confirm.
    ckpt = torch.load(checkpoint, map_location="cpu")

    # Carry over only the config entries needed at inference time.
    train_cfg = ckpt["xp.cfg"]
    infer_cfg = {
        key: train_cfg[key] for key in ("seanet", "channels", "dtype", "sample_rate")
    }

    generator_state, detector_state = _split_state_dict(ckpt["model"])

    # torch.save does not create parent directories itself.
    outdir_path.mkdir(parents=True, exist_ok=True)
    torch.save(
        {"xp.cfg": infer_cfg, "model": generator_state},
        outdir_path / f"checkpoint_generator_{suffix}.pth",
    )
    torch.save(
        {"xp.cfg": infer_cfg, "model": detector_state},
        outdir_path / f"checkpoint_detector_{suffix}.pth",
    )
|
|
|
|
if __name__ == "__main__":
    # Expose `convert` as a command-line interface via Python Fire. The import
    # sits inside the guard, so merely importing this module does not need `fire`.
    from fire import Fire

    Fire(convert)
|
|