|
|
| from collections import OrderedDict
|
|
|
| from detectron2.checkpoint import DetectionCheckpointer
|
|
|
|
|
| def _rename_HRNet_weights(weights):
|
|
|
|
|
| if (
|
| len(weights["model"].keys()) == 1956
|
| and len([k for k in weights["model"].keys() if k.startswith("stage")]) == 1716
|
| ):
|
| hrnet_weights = OrderedDict()
|
| for k in weights["model"].keys():
|
| hrnet_weights["backbone.bottom_up." + str(k)] = weights["model"][k]
|
| return {"model": hrnet_weights}
|
| else:
|
| return weights
|
|
|
|
|
class DensePoseCheckpointer(DetectionCheckpointer):
    """
    A :class:`DetectionCheckpointer` that additionally understands raw HRNet weights.

    Files are read with the parent's loader; the loaded checkpoint is then passed
    through :func:`_rename_HRNet_weights`, which re-keys recognised HRNet state
    dicts under ``backbone.bottom_up.`` and returns any other checkpoint as is.
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        # Pure pass-through: the signature mirrors the parent's so this class can
        # be constructed exactly like a DetectionCheckpointer.
        super().__init__(
            model,
            save_dir,
            save_to_disk=save_to_disk,
            **checkpointables,
        )

    def _load_file(self, filename: str) -> object:
        """
        Load ``filename`` with the parent loader, then apply HRNet key renaming.

        Args:
            filename: the checkpoint file to read.

        Returns:
            The checkpoint as returned by ``_rename_HRNet_weights``.
        """
        loaded_checkpoint = super()._load_file(filename)
        renamed_checkpoint = _rename_HRNet_weights(loaded_checkpoint)
        return renamed_checkpoint
|
|
|