|
|
|
|
| import contextlib
|
| from unittest import mock
|
| import torch
|
|
|
| from detectron2.modeling import poolers
|
| from detectron2.modeling.proposal_generator import rpn
|
| from detectron2.modeling.roi_heads import keypoint_head, mask_head
|
| from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
|
|
|
| from .c10 import (
|
| Caffe2Compatible,
|
| Caffe2FastRCNNOutputsInference,
|
| Caffe2KeypointRCNNInference,
|
| Caffe2MaskRCNNInference,
|
| Caffe2ROIPooler,
|
| Caffe2RPN,
|
| caffe2_fast_rcnn_outputs_inference,
|
| caffe2_keypoint_rcnn_inference,
|
| caffe2_mask_rcnn_inference,
|
| )
|
|
|
|
|
class GenericMixin:
    """
    Marker base class with no behavior of its own.

    `Caffe2CompatibleConverter.create_from` checks whether its replacement
    class derives from this marker to decide between mixing the replacement
    into the module's existing class and swapping the class outright.
    """
|
|
|
|
|
class Caffe2CompatibleConverter:
    """
    An updater implementing the `create_from` interface: it converts a module
    in place by reassigning the module's class, based on `replaceCls`.
    """

    def __init__(self, replaceCls):
        # Class that converted modules are switched to (or mixed with).
        self.replaceCls = replaceCls

    def create_from(self, module):
        """
        Reassign the class of `module` in place and return the same object.

        Args:
            module (torch.nn.Module): the module whose class is replaced.

        Returns:
            torch.nn.Module: `module` itself, now an instance of the new class.
        """
        assert isinstance(module, torch.nn.Module)

        original_cls = module.__class__
        new_cls = self.replaceCls
        if issubclass(new_cls, GenericMixin):
            # replaceCls acts as a mixin: build a class on the fly that puts it
            # in front of the module's current class in the MRO.
            mixed_name = "{}MixedWith{}".format(new_cls.__name__, original_cls.__name__)
            new_cls = type(mixed_name, (new_cls, original_cls), {})
        # Otherwise replaceCls is used as-is, i.e. an arbitrary class swap.
        module.__class__ = new_cls

        # Caffe2Compatible modules start out with tensor_mode disabled.
        if isinstance(module, Caffe2Compatible):
            module.tensor_mode = False

        return module
|
|
|
|
|
def patch(model, target, updater, *args, **kwargs):
    """
    Recursively (post-order) run `updater.create_from` on every module that is
    an instance of `target` (including subclasses of it).

    Children are processed first and each child slot in `model._modules` is
    overwritten with whatever the recursive call returns; only then is `model`
    itself considered.

    Args:
        model: the module to walk; must provide `named_children()` and
            `_modules`.
        target: a type (or tuple of types) accepted by `isinstance`.
        updater: object with a `create_from(module, *args, **kwargs)` method.
        *args, **kwargs: forwarded unchanged to `updater.create_from`.

    Returns:
        The result of `updater.create_from` if `model` matches `target`,
        otherwise `model` itself.
    """
    for child_name, child in model.named_children():
        patched_child = patch(child, target, updater, *args, **kwargs)
        model._modules[child_name] = patched_child

    if not isinstance(model, target):
        return model
    return updater.create_from(model, *args, **kwargs)
|
|
|
|
|
def patch_generalized_rcnn(model):
    """
    Convert the RPN and ROIPooler modules found in `model` to their Caffe2
    counterparts and return the resulting model.

    Each conversion is one `patch` pass using a `Caffe2CompatibleConverter`
    built for the matching Caffe2 class; the RPN pass runs first.
    """
    conversions = (
        (rpn.RPN, Caffe2RPN),
        (poolers.ROIPooler, Caffe2ROIPooler),
    )
    for target_type, caffe2_cls in conversions:
        model = patch(model, target_type, Caffe2CompatibleConverter(caffe2_cls))
    return model
|
|
|
|
|
@contextlib.contextmanager
def mock_fastrcnn_outputs_inference(
    tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
):
    """
    Context manager that mocks `box_predictor_type.inference` (autospec'd) so
    calls are routed to `Caffe2FastRCNNOutputsInference(tensor_mode)`.

    Args:
        tensor_mode: forwarded to `Caffe2FastRCNNOutputsInference`.
        check (bool): if True, assert after the managed body completes that the
            mocked `inference` was called at least once.
        box_predictor_type: the class whose `inference` attribute is mocked.
    """
    replacement = Caffe2FastRCNNOutputsInference(tensor_mode)
    patcher = mock.patch.object(
        box_predictor_type, "inference", autospec=True, side_effect=replacement
    )
    with patcher as mocked_func:
        yield
    if not check:
        return
    assert mocked_func.call_count > 0
|
|
|
|
|
@contextlib.contextmanager
def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):
    """
    Context manager that mocks `<patched_module>.mask_rcnn_inference` so calls
    are routed to a `Caffe2MaskRCNNInference()` instance.

    Args:
        tensor_mode: accepted but not used by this function.
        patched_module (str): dotted path of the module whose
            `mask_rcnn_inference` attribute is mocked.
        check (bool): if True, assert after the managed body completes that the
            mocked function was called at least once.
    """
    target = "{}.mask_rcnn_inference".format(patched_module)
    with mock.patch(target, side_effect=Caffe2MaskRCNNInference()) as mocked_func:
        yield
    if not check:
        return
    assert mocked_func.call_count > 0
|
|
|
|
|
@contextlib.contextmanager
def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
    """
    Context manager that mocks `<patched_module>.keypoint_rcnn_inference` so
    calls are routed to `Caffe2KeypointRCNNInference(use_heatmap_max_keypoint)`.

    Args:
        tensor_mode: accepted but not used by this function.
        patched_module (str): dotted path of the module whose
            `keypoint_rcnn_inference` attribute is mocked.
        use_heatmap_max_keypoint: forwarded to `Caffe2KeypointRCNNInference`.
        check (bool): if True, assert after the managed body completes that the
            mocked function was called at least once.
    """
    target = "{}.keypoint_rcnn_inference".format(patched_module)
    replacement = Caffe2KeypointRCNNInference(use_heatmap_max_keypoint)
    with mock.patch(target, side_effect=replacement) as mocked_func:
        yield
    if not check:
        return
    assert mocked_func.call_count > 0
|
|
|
|
|
class ROIHeadsPatcher:
    """
    Replaces the inference functions used by an ROI heads object with their
    caffe2 counterparts, either temporarily through a context manager
    (`mock_roi_heads`) or until explicitly reverted (`patch_roi_heads` /
    `unpatch_roi_heads`).
    """

    def __init__(self, heads, use_heatmap_max_keypoint):
        self.heads = heads
        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
        # Originals saved by patch_roi_heads; empty while nothing is patched.
        self.previous_patched = {}

    @contextlib.contextmanager
    def mock_roi_heads(self, tensor_mode=True):
        """
        Patching several inference functions inside ROIHeads and its subclasses

        Args:
            tensor_mode (bool): whether the inputs/outputs are caffe2's tensor
                format or not. Default to True.
        """
        # The keypoint/mask functions are mocked by dotted path inside the
        # modules that define the base head classes.
        kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__
        mask_head_mod = mask_head.BaseMaskRCNNHead.__module__

        mock_ctx_managers = [
            mock_fastrcnn_outputs_inference(
                tensor_mode=tensor_mode,
                check=True,
                box_predictor_type=type(self.heads.box_predictor),
            )
        ]
        if getattr(self.heads, "keypoint_on", False):
            mock_ctx_managers.append(
                mock_keypoint_rcnn_inference(
                    tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint
                )
            )
        if getattr(self.heads, "mask_on", False):
            mock_ctx_managers.append(mock_mask_rcnn_inference(tensor_mode, mask_head_mod))

        # ExitStack enters a variable number of managers and unwinds them all
        # when the managed body exits.
        with contextlib.ExitStack() as stack:
            for mgr in mock_ctx_managers:
                stack.enter_context(mgr)
            yield

    def patch_roi_heads(self, tensor_mode=True):
        """
        Install the caffe2 inference functions until `unpatch_roi_heads` is
        called.

        The box predictor's `inference` is always replaced; the module-level
        `keypoint_head.keypoint_rcnn_inference` and
        `mask_head.mask_rcnn_inference` are replaced only when `self.heads`
        has `keypoint_on` / `mask_on` set.

        Args:
            tensor_mode (bool): forwarded to
                `caffe2_fast_rcnn_outputs_inference`. Default to True.
        """
        # Save the originals only on the first call, so that patching twice
        # does not overwrite them with already-patched functions.
        if not self.previous_patched:
            self.previous_patched["box_predictor"] = self.heads.box_predictor.inference
            self.previous_patched["keypoint_rcnn"] = keypoint_head.keypoint_rcnn_inference
            self.previous_patched["mask_rcnn"] = mask_head.mask_rcnn_inference

        def patched_fastrcnn_outputs_inference(predictions, proposal):
            # Honor the caller's tensor_mode instead of a hard-coded True.
            return caffe2_fast_rcnn_outputs_inference(
                tensor_mode, self.heads.box_predictor, predictions, proposal
            )

        self.heads.box_predictor.inference = patched_fastrcnn_outputs_inference

        if getattr(self.heads, "keypoint_on", False):

            def patched_keypoint_rcnn_inference(pred_keypoint_logits, pred_instances):
                return caffe2_keypoint_rcnn_inference(
                    self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
                )

            keypoint_head.keypoint_rcnn_inference = patched_keypoint_rcnn_inference

        if getattr(self.heads, "mask_on", False):

            def patched_mask_rcnn_inference(pred_mask_logits, pred_instances):
                return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)

            mask_head.mask_rcnn_inference = patched_mask_rcnn_inference

    def unpatch_roi_heads(self):
        """
        Restore the functions saved by `patch_roi_heads`.

        Does nothing when no patch is active, so it is safe to call before
        `patch_roi_heads` or more than once.
        """
        if not self.previous_patched:
            return
        self.heads.box_predictor.inference = self.previous_patched["box_predictor"]
        keypoint_head.keypoint_rcnn_inference = self.previous_patched["keypoint_rcnn"]
        mask_head.mask_rcnn_inference = self.previous_patched["mask_rcnn"]
        # Forget the saved originals so the next patch call records fresh ones.
        self.previous_patched.clear()
|
|
|