|
|
|
|
|
|
| import typing
|
| from typing import Any, List
|
| import fvcore
|
| from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
|
| from torch import nn
|
|
|
| from detectron2.export import TracingAdapter
|
|
|
# Public API of this module. ``parameter_count`` and ``parameter_count_table``
# are re-exported from ``fvcore.nn``; the remaining names are defined below.
__all__ = [
    "activation_count_operators",
    "flop_count_operators",
    "parameter_count_table",
    "parameter_count",
    "FlopCountAnalysis",
    "find_unused_parameters",
]
|
|
|
# Values accepted by the ``mode`` argument of ``_wrapper_count_operators``;
# they select between ``flop_count`` and ``activation_count`` respectively.
FLOPS_MODE = "flops"
ACTIVATIONS_MODE = "activations"
|
|
|
|
|
|
|
# Operator names that are excluded from counting.
# ``FlopCountAnalysis`` registers a ``None`` handle for each of them, and
# ``_wrapper_count_operators`` maps each of them to a handler that returns an
# empty result.
_IGNORED_OPS = {
    "aten::add",
    "aten::add_",
    "aten::argmax",
    "aten::argsort",
    "aten::batch_norm",
    "aten::constant_pad_nd",
    "aten::div",
    "aten::div_",
    "aten::exp",
    "aten::log2",
    "aten::max_pool2d",
    "aten::meshgrid",
    "aten::mul",
    "aten::mul_",
    "aten::neg",
    "aten::nonzero_numpy",
    "aten::reciprocal",
    "aten::repeat_interleave",
    "aten::rsub",
    "aten::sigmoid",
    "aten::sigmoid_",
    "aten::softmax",
    "aten::sort",
    "aten::sqrt",
    "aten::sub",
    "torchvision::nms",
}
|
|
|
|
|
class FlopCountAnalysis(fvcore.nn.FlopCountAnalysis):
    """
    Same as :class:`fvcore.nn.FlopCountAnalysis`, but supports detectron2 models.

    The model is wrapped in a :class:`TracingAdapter` so that inputs which are
    not a plain tuple of tensors can still be handed to the fvcore analysis.
    """

    def __init__(self, model, inputs):
        """
        Args:
            model (nn.Module): the model to analyze.
            inputs (Any): inputs of the given model. Does not have to be tuple of tensors.
        """
        adapter = TracingAdapter(model, inputs, allow_non_tensor=True)
        # The base class is given the adapter and its flattened inputs
        # instead of the raw model/inputs.
        super().__init__(adapter, adapter.flattened_inputs)

        # Register ``None`` as the handle of every op listed in _IGNORED_OPS.
        # NOTE(review): fvcore presumably treats a ``None`` handle as
        # "skip this op" -- confirm against the fvcore documentation.
        ignored_handles = dict.fromkeys(_IGNORED_OPS)
        self.set_op_handle(**ignored_handles)
|
|
|
|
|
def flop_count_operators(model: nn.Module, inputs: list) -> typing.Dict[str, float]:
    """
    Implement operator-level flops counting using jit.
    This is a wrapper of :func:`fvcore.nn.flop_count` and adds supports for standard
    detection models in detectron2.
    Please use :class:`FlopCountAnalysis` for more advanced functionalities.

    Note:
        The function runs the input through the model to compute flops.
        The flops of a detection model is often input-dependent, for example,
        the flops of box & mask head depends on the number of proposals &
        the number of detected objects.
        Therefore, the flops counting using a single input may not accurately
        reflect the computation cost of a model. It's recommended to average
        across a number of inputs.

        The model is switched to eval mode for the duration of the analysis and
        its previous training mode is restored afterwards, even if the analysis
        raises.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
            They are passed unchanged to :class:`FlopCountAnalysis`.

    Returns:
        dict[str, float]: Gflop count per operator
    """
    old_train = model.training
    model.eval()
    try:
        ret = FlopCountAnalysis(model, inputs).by_operator()
    finally:
        # Never leave the caller's model stuck in eval mode when tracing fails.
        model.train(old_train)
    # Convert raw flop counts to Gflops.
    return {k: v / 1e9 for k, v in ret.items()}
|
|
|
|
|
def activation_count_operators(
    model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Count activations per operator using jit.

    Thin wrapper around fvcore.nn.activation_count that supports standard
    detection models in detectron2. All the work is delegated to
    ``_wrapper_count_operators`` with the activations mode selected.

    Note:
        The function runs the input through the model to compute activations.
        The activations of a detection model is often input-dependent, for example,
        the activations of box & mask head depends on the number of proposals &
        the number of detected objects.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
            Only "image" key will be used.
        **kwargs: forwarded unchanged to ``_wrapper_count_operators``.

    Returns:
        Counter: activation count per operator
    """
    activation_counts = _wrapper_count_operators(
        model=model,
        inputs=inputs,
        mode=ACTIVATIONS_MODE,
        **kwargs,
    )
    return activation_counts
|
|
|
|
|
def _wrapper_count_operators(
    model: nn.Module, inputs: list, mode: str, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Run an fvcore counting function on a detectron2 model for a single input.

    Args:
        model: a detectron2 model that takes `list[dict]` as input. A
            ``DistributedDataParallel`` / ``DataParallel`` wrapper is unwrapped
            to its ``.module`` before tracing.
        inputs (list[dict]): exactly one input dict; only its "image" entry is used.
        mode: ``FLOPS_MODE`` or ``ACTIVATIONS_MODE``.
        **kwargs: forwarded to the fvcore counting function. An optional
            ``supported_ops`` dict is merged on top of the handlers that
            ignore the ops in ``_IGNORED_OPS``.

    Returns:
        The per-operator counts produced by the fvcore function. If that
        function returns a tuple, only its first element is returned.

    Raises:
        AssertionError: if ``inputs`` does not contain exactly one element.
        NotImplementedError: if ``mode`` is not a supported mode.
    """
    # Ops in _IGNORED_OPS get a handler that contributes nothing; handlers
    # supplied by the caller take precedence over these defaults.
    supported_ops = {k: lambda *_args, **_kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]
    # Keep only the "image" key, dropping any other keys of the input dict.
    inputs = [{"image": tensor_input}]

    # Resolve the counting function *before* touching the model, so that an
    # unsupported mode cannot leave the model switched to eval mode.
    if mode == FLOPS_MODE:
        count_fn = flop_count
    elif mode == ACTIVATIONS_MODE:
        count_fn = activation_count
    else:
        raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))

    old_train = model.training
    if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
        model = model.module
    wrapper = TracingAdapter(model, inputs)
    # NOTE(review): eval() on the adapter is expected to propagate to the
    # wrapped model, which is why the mode is restored on ``model`` below.
    wrapper.eval()
    try:
        ret = count_fn(wrapper, (tensor_input,), **kwargs)
    finally:
        # Restore the original training mode even if counting raised.
        model.train(old_train)

    # The counting function may return a tuple; only its first element is used.
    if isinstance(ret, tuple):
        ret = ret[0]
    return ret
|
|
|
|
|
def find_unused_parameters(model: nn.Module, inputs: Any) -> List[str]:
    """
    Given a model, find parameters that do not contribute
    to the loss.

    One forward and one backward pass are run; a trainable parameter whose
    gradient is still ``None`` afterwards is reported as unused. Parameters
    with ``requires_grad=False`` never receive a gradient regardless of
    whether they are used, so they are not reported.

    All parameter gradients are reset to ``None`` before the passes and again
    before returning (also when the forward or backward pass raises).

    Args:
        model: a model in training mode that returns losses
            (a dict of losses, which are summed, or a single loss).
        inputs: argument or a tuple of arguments. Inputs of the model

    Returns:
        list[str]: the name of unused parameters

    Raises:
        AssertionError: if the model is not in training mode.
    """
    assert model.training
    # Start from a clean slate so that stale gradients cannot hide an unused parameter.
    for prm in model.parameters():
        prm.grad = None

    try:
        if isinstance(inputs, tuple):
            losses = model(*inputs)
        else:
            losses = model(inputs)

        if isinstance(losses, dict):
            losses = sum(losses.values())
        losses.backward()

        return [
            name
            for name, prm in model.named_parameters()
            # Frozen parameters have no gradient by construction; skip them
            # to avoid false positives.
            if prm.requires_grad and prm.grad is None
        ]
    finally:
        # Do not leave gradients of this probing pass on the model.
        for prm in model.parameters():
            prm.grad = None
|
|
|