| |
|
|
| import contextlib |
| from unittest import mock |
| import torch |
|
|
| from detectron2.modeling import poolers |
| from detectron2.modeling.proposal_generator import rpn |
| from detectron2.modeling.roi_heads import keypoint_head, mask_head |
| from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers |
|
|
| from .c10 import ( |
| Caffe2Compatible, |
| Caffe2FastRCNNOutputsInference, |
| Caffe2KeypointRCNNInference, |
| Caffe2MaskRCNNInference, |
| Caffe2ROIPooler, |
| Caffe2RPN, |
| caffe2_fast_rcnn_outputs_inference, |
| caffe2_keypoint_rcnn_inference, |
| caffe2_mask_rcnn_inference, |
| ) |
|
|
|
|
class GenericMixin:
    """Empty marker base class.

    ``Caffe2CompatibleConverter.create_from`` checks whether its replacement
    class derives from this marker; if so, the replacement is combined with
    the module's current class instead of overwriting it.
    """
|
|
|
|
class Caffe2CompatibleConverter:
    """
    Updater exposing the `create_from` interface expected by `patch`.

    Conversion happens in place: the module object is kept, and only its
    `__class__` is re-pointed at `replaceCls` (or at a class derived from it).
    """

    def __init__(self, replaceCls):
        # Class that converted modules will be re-typed to.
        self.replaceCls = replaceCls

    def create_from(self, module):
        """
        Re-type `module` in place and return the same object.

        Args:
            module (torch.nn.Module): the module to convert.

        Returns:
            the very same `module` object, now with a different `__class__`.
        """
        assert isinstance(module, torch.nn.Module)

        target_cls = self.replaceCls
        if issubclass(target_cls, GenericMixin):
            # A mixin replacement must not discard the module's own class, so
            # build a fresh class on the fly that inherits from both.
            current_cls = module.__class__
            mixed_name = "{}MixedWith{}".format(target_cls.__name__, current_cls.__name__)
            target_cls = type(mixed_name, (target_cls, current_cls), {})
        # Otherwise the replacement class is used directly.
        module.__class__ = target_cls

        # __init__ of the new class never runs, so a converted
        # Caffe2Compatible module gets its tensor_mode set explicitly here.
        if isinstance(module, Caffe2Compatible):
            module.tensor_mode = False

        return module
|
|
|
|
def patch(model, target, updater, *args, **kwargs):
    """
    Walk `model` depth-first and hand every module that is an instance of
    `target` (subclasses included) to `updater.create_from`.

    Children are processed before their parent (post-order), and each child
    slot is overwritten with whatever the recursive call returns.

    Args:
        model: module exposing `named_children()` and `_modules`.
        target: a type, or tuple of types, accepted by `isinstance`.
        updater: object providing `create_from(module, *args, **kwargs)`.
        *args, **kwargs: forwarded unchanged to `updater.create_from`.

    Returns:
        the result of `updater.create_from` when `model` matches `target`,
        otherwise `model` itself.
    """
    for child_name, child in model.named_children():
        # Re-assigning an existing key keeps the dict size unchanged, so this
        # is safe while named_children() is still iterating.
        model._modules[child_name] = patch(child, target, updater, *args, **kwargs)

    if not isinstance(model, target):
        return model
    return updater.create_from(model, *args, **kwargs)
|
|
|
|
def patch_generalized_rcnn(model):
    """
    Convert the RPN and ROIPooler modules inside `model` to their Caffe2
    counterparts and return the patched model.

    The conversions run in order: every `rpn.RPN` first, then every
    `poolers.ROIPooler`.
    """
    # (type to look for, class it is converted to)
    conversions = (
        (rpn.RPN, Caffe2RPN),
        (poolers.ROIPooler, Caffe2ROIPooler),
    )
    for target_cls, caffe2_cls in conversions:
        model = patch(model, target_cls, Caffe2CompatibleConverter(caffe2_cls))
    return model
|
|
|
|
@contextlib.contextmanager
def mock_fastrcnn_outputs_inference(
    tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
):
    """
    Context manager that temporarily replaces `box_predictor_type.inference`
    with a `Caffe2FastRCNNOutputsInference(tensor_mode)` callable.

    Args:
        tensor_mode: forwarded to `Caffe2FastRCNNOutputsInference`.
        check (bool): when true, assert on exit that the replacement was
            called at least once inside the `with` body.
        box_predictor_type: class whose `inference` attribute is patched.
    """
    replacement = Caffe2FastRCNNOutputsInference(tensor_mode)
    patcher = mock.patch.object(
        box_predictor_type, "inference", autospec=True, side_effect=replacement
    )
    with patcher as mocked_func:
        yield
    # Reached only when the body finished without raising.
    if not check:
        return
    assert mocked_func.call_count > 0
|
|
|
|
@contextlib.contextmanager
def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):
    """
    Context manager that temporarily replaces
    `<patched_module>.mask_rcnn_inference` with a `Caffe2MaskRCNNInference()`
    callable.

    Args:
        tensor_mode: not used by this function's body.
        patched_module (str): dotted path of the module whose
            `mask_rcnn_inference` name is patched.
        check (bool): when true, assert on exit that the replacement was
            called at least once inside the `with` body.
    """
    patch_target = "{}.mask_rcnn_inference".format(patched_module)
    replacement = Caffe2MaskRCNNInference()
    with mock.patch(patch_target, side_effect=replacement) as mocked_func:
        yield
    # Reached only when the body finished without raising.
    if not check:
        return
    assert mocked_func.call_count > 0
|
|
|
|
@contextlib.contextmanager
def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
    """
    Context manager that temporarily replaces
    `<patched_module>.keypoint_rcnn_inference` with a
    `Caffe2KeypointRCNNInference(use_heatmap_max_keypoint)` callable.

    Args:
        tensor_mode: not used by this function's body.
        patched_module (str): dotted path of the module whose
            `keypoint_rcnn_inference` name is patched.
        use_heatmap_max_keypoint: forwarded to `Caffe2KeypointRCNNInference`.
        check (bool): when true, assert on exit that the replacement was
            called at least once inside the `with` body.
    """
    patch_target = "{}.keypoint_rcnn_inference".format(patched_module)
    replacement = Caffe2KeypointRCNNInference(use_heatmap_max_keypoint)
    with mock.patch(patch_target, side_effect=replacement) as mocked_func:
        yield
    # Reached only when the body finished without raising.
    if not check:
        return
    assert mocked_func.call_count > 0
|
|
|
|
class ROIHeadsPatcher:
    """
    Swaps the inference functions used by an ROI-heads object for their
    Caffe2 counterparts, either temporarily (`mock_roi_heads`) or until
    explicitly undone (`patch_roi_heads` / `unpatch_roi_heads`).
    """

    def __init__(self, heads, use_heatmap_max_keypoint):
        self.heads = heads
        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
        # Originals saved by patch_roi_heads, restored by unpatch_roi_heads.
        self.previous_patched = {}

    @contextlib.contextmanager
    def mock_roi_heads(self, tensor_mode=True):
        """
        Patching several inference functions inside ROIHeads and its subclasses

        The box-predictor mock is always installed; the keypoint and mask
        mocks are added only when `self.heads` has a truthy `keypoint_on` /
        `mask_on` attribute.

        Args:
            tensor_mode (bool): whether the inputs/outputs are caffe2's tensor
                format or not. Default to True.
        """
        # The mocks patch by dotted path, so resolve the modules in which the
        # base head classes are defined.
        keypoint_module = keypoint_head.BaseKeypointRCNNHead.__module__
        mask_module = mask_head.BaseMaskRCNNHead.__module__

        managers = [
            mock_fastrcnn_outputs_inference(
                tensor_mode=tensor_mode,
                check=True,
                box_predictor_type=type(self.heads.box_predictor),
            )
        ]
        if getattr(self.heads, "keypoint_on", False):
            managers.append(
                mock_keypoint_rcnn_inference(
                    tensor_mode, keypoint_module, self.use_heatmap_max_keypoint
                )
            )
        if getattr(self.heads, "mask_on", False):
            managers.append(mock_mask_rcnn_inference(tensor_mode, mask_module))

        # ExitStack unwinds every entered mock, in reverse order, on exit.
        with contextlib.ExitStack() as stack:
            for manager in managers:
                stack.enter_context(manager)
            yield

    def patch_roi_heads(self, tensor_mode=True):
        """
        Install the Caffe2 inference functions until `unpatch_roi_heads` runs.

        The current box-predictor, keypoint and mask inference callables are
        always saved first; the keypoint / mask replacements themselves are
        installed only when `self.heads` has a truthy `keypoint_on` /
        `mask_on` attribute.

        Args:
            tensor_mode (bool): accepted but not read by this method.
        """
        saved = self.previous_patched
        saved["box_predictor"] = self.heads.box_predictor.inference
        saved["keypoint_rcnn"] = keypoint_head.keypoint_rcnn_inference
        saved["mask_rcnn"] = mask_head.mask_rcnn_inference

        def patched_fastrcnn_outputs_inference(predictions, proposal):
            # NOTE(review): the first argument is hard-coded to True instead of
            # forwarding `tensor_mode` -- confirm whether that is intended.
            return caffe2_fast_rcnn_outputs_inference(
                True, self.heads.box_predictor, predictions, proposal
            )

        self.heads.box_predictor.inference = patched_fastrcnn_outputs_inference

        if getattr(self.heads, "keypoint_on", False):

            def patched_keypoint_rcnn_inference(pred_keypoint_logits, pred_instances):
                return caffe2_keypoint_rcnn_inference(
                    self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
                )

            # Replaces the module-level function, not a per-instance attribute.
            keypoint_head.keypoint_rcnn_inference = patched_keypoint_rcnn_inference

        if getattr(self.heads, "mask_on", False):

            def patched_mask_rcnn_inference(pred_mask_logits, pred_instances):
                return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)

            # Replaces the module-level function, not a per-instance attribute.
            mask_head.mask_rcnn_inference = patched_mask_rcnn_inference

    def unpatch_roi_heads(self):
        """
        Restore the three callables saved by `patch_roi_heads`.

        All three are written back unconditionally. The lookups raise
        `KeyError` if `previous_patched` lacks the entries, e.g. when this is
        called before `patch_roi_heads`.
        """
        saved = self.previous_patched
        self.heads.box_predictor.inference = saved["box_predictor"]
        keypoint_head.keypoint_rcnn_inference = saved["keypoint_rcnn"]
        mask_head.mask_rcnn_inference = saved["mask_rcnn"]
|
|