|
|
|
|
| import os
|
| import sys
|
| import tempfile
|
| from contextlib import ExitStack, contextmanager
|
| from copy import deepcopy
|
| from unittest import mock
|
| import torch
|
| from torch import nn
|
|
|
|
|
| import detectron2
|
| from detectron2.structures import Boxes, Instances
|
| from detectron2.utils.env import _import_file
|
|
|
| _counter = 0
|
|
|
|
|
| def _clear_jit_cache():
|
| from torch.jit._recursive import concrete_type_store
|
| from torch.jit._state import _jit_caching_layer
|
|
|
| concrete_type_store.type_store.clear()
|
| _jit_caching_layer.clear()
|
|
|
|
|
| def _add_instances_conversion_methods(newInstances):
|
| """
|
| Add from_instances methods to the scripted Instances class.
|
| """
|
| cls_name = newInstances.__name__
|
|
|
| @torch.jit.unused
|
| def from_instances(instances: Instances):
|
| """
|
| Create scripted Instances from original Instances
|
| """
|
| fields = instances.get_fields()
|
| image_size = instances.image_size
|
| ret = newInstances(image_size)
|
| for name, val in fields.items():
|
| assert hasattr(ret, f"_{name}"), f"No attribute named {name} in {cls_name}"
|
| setattr(ret, name, deepcopy(val))
|
| return ret
|
|
|
| newInstances.from_instances = from_instances
|
|
|
|
|
| @contextmanager
|
| def patch_instances(fields):
|
| """
|
| A contextmanager, under which the Instances class in detectron2 is replaced
|
| by a statically-typed scriptable class, defined by `fields`.
|
| See more in `scripting_with_instances`.
|
| """
|
|
|
| with tempfile.TemporaryDirectory(prefix="detectron2") as dir, tempfile.NamedTemporaryFile(
|
| mode="w", encoding="utf-8", suffix=".py", dir=dir, delete=False
|
| ) as f:
|
| try:
|
|
|
|
|
| _clear_jit_cache()
|
|
|
| cls_name, s = _gen_instance_module(fields)
|
| f.write(s)
|
| f.flush()
|
| f.close()
|
|
|
| module = _import(f.name)
|
| new_instances = getattr(module, cls_name)
|
| _ = torch.jit.script(new_instances)
|
|
|
| Instances.__torch_script_class__ = True
|
|
|
| Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)
|
|
|
| _add_instances_conversion_methods(new_instances)
|
| yield new_instances
|
| finally:
|
| try:
|
| del Instances.__torch_script_class__
|
| del Instances._jit_override_qualname
|
| except AttributeError:
|
| pass
|
| sys.modules.pop(module.__name__)
|
|
|
|
|
| def _gen_instance_class(fields):
|
| """
|
| Args:
|
| fields (dict[name: type])
|
| """
|
|
|
| class _FieldType:
|
| def __init__(self, name, type_):
|
| assert isinstance(name, str), f"Field name must be str, got {name}"
|
| self.name = name
|
| self.type_ = type_
|
| self.annotation = f"{type_.__module__}.{type_.__name__}"
|
|
|
| fields = [_FieldType(k, v) for k, v in fields.items()]
|
|
|
| def indent(level, s):
|
| return " " * 4 * level + s
|
|
|
| lines = []
|
|
|
| global _counter
|
| _counter += 1
|
|
|
| cls_name = "ScriptedInstances{}".format(_counter)
|
|
|
| field_names = tuple(x.name for x in fields)
|
| extra_args = ", ".join([f"{f.name}: Optional[{f.annotation}] = None" for f in fields])
|
| lines.append(
|
| f"""
|
| class {cls_name}:
|
| def __init__(self, image_size: Tuple[int, int], {extra_args}):
|
| self.image_size = image_size
|
| self._field_names = {field_names}
|
| """
|
| )
|
|
|
| for f in fields:
|
| lines.append(
|
| indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], {f.name})")
|
| )
|
|
|
| for f in fields:
|
| lines.append(
|
| f"""
|
| @property
|
| def {f.name}(self) -> {f.annotation}:
|
| # has to use a local for type refinement
|
| # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
|
| t = self._{f.name}
|
| assert t is not None, "{f.name} is None and cannot be accessed!"
|
| return t
|
|
|
| @{f.name}.setter
|
| def {f.name}(self, value: {f.annotation}) -> None:
|
| self._{f.name} = value
|
| """
|
| )
|
|
|
|
|
| lines.append(
|
| """
|
| def __len__(self) -> int:
|
| """
|
| )
|
| for f in fields:
|
| lines.append(
|
| f"""
|
| t = self._{f.name}
|
| if t is not None:
|
| return len(t)
|
| """
|
| )
|
| lines.append(
|
| """
|
| raise NotImplementedError("Empty Instances does not support __len__!")
|
| """
|
| )
|
|
|
|
|
| lines.append(
|
| """
|
| def has(self, name: str) -> bool:
|
| """
|
| )
|
| for f in fields:
|
| lines.append(
|
| f"""
|
| if name == "{f.name}":
|
| return self._{f.name} is not None
|
| """
|
| )
|
| lines.append(
|
| """
|
| return False
|
| """
|
| )
|
|
|
|
|
| none_args = ", None" * len(fields)
|
| lines.append(
|
| f"""
|
| def to(self, device: torch.device) -> "{cls_name}":
|
| ret = {cls_name}(self.image_size{none_args})
|
| """
|
| )
|
| for f in fields:
|
| if hasattr(f.type_, "to"):
|
| lines.append(
|
| f"""
|
| t = self._{f.name}
|
| if t is not None:
|
| ret._{f.name} = t.to(device)
|
| """
|
| )
|
| else:
|
|
|
|
|
| pass
|
| lines.append(
|
| """
|
| return ret
|
| """
|
| )
|
|
|
|
|
| none_args = ", None" * len(fields)
|
| lines.append(
|
| f"""
|
| def __getitem__(self, item) -> "{cls_name}":
|
| ret = {cls_name}(self.image_size{none_args})
|
| """
|
| )
|
| for f in fields:
|
| lines.append(
|
| f"""
|
| t = self._{f.name}
|
| if t is not None:
|
| ret._{f.name} = t[item]
|
| """
|
| )
|
| lines.append(
|
| """
|
| return ret
|
| """
|
| )
|
|
|
|
|
|
|
| none_args = ", None" * len(fields)
|
| lines.append(
|
| f"""
|
| def cat(self, instances: List["{cls_name}"]) -> "{cls_name}":
|
| ret = {cls_name}(self.image_size{none_args})
|
| """
|
| )
|
| for f in fields:
|
| lines.append(
|
| f"""
|
| t = self._{f.name}
|
| if t is not None:
|
| values: List[{f.annotation}] = [x.{f.name} for x in instances]
|
| if torch.jit.isinstance(t, torch.Tensor):
|
| ret._{f.name} = torch.cat(values, dim=0)
|
| else:
|
| ret._{f.name} = t.cat(values)
|
| """
|
| )
|
| lines.append(
|
| """
|
| return ret"""
|
| )
|
|
|
|
|
| lines.append(
|
| """
|
| def get_fields(self) -> Dict[str, Tensor]:
|
| ret = {}
|
| """
|
| )
|
| for f in fields:
|
| if f.type_ == Boxes:
|
| stmt = "t.tensor"
|
| elif f.type_ == torch.Tensor:
|
| stmt = "t"
|
| else:
|
| stmt = f'assert False, "unsupported type {str(f.type_)}"'
|
| lines.append(
|
| f"""
|
| t = self._{f.name}
|
| if t is not None:
|
| ret["{f.name}"] = {stmt}
|
| """
|
| )
|
| lines.append(
|
| """
|
| return ret"""
|
| )
|
| return cls_name, os.linesep.join(lines)
|
|
|
|
|
| def _gen_instance_module(fields):
|
|
|
| s = """
|
| from copy import deepcopy
|
| import torch
|
| from torch import Tensor
|
| import typing
|
| from typing import *
|
|
|
| import detectron2
|
| from detectron2.structures import Boxes, Instances
|
|
|
| """
|
|
|
| cls_name, cls_def = _gen_instance_class(fields)
|
| s += cls_def
|
| return cls_name, s
|
|
|
|
|
| def _import(path):
|
| return _import_file(
|
| "{}{}".format(sys.modules[__name__].__name__, _counter), path, make_importable=True
|
| )
|
|
|
|
|
| @contextmanager
|
| def patch_builtin_len(modules=()):
|
| """
|
| Patch the builtin len() function of a few detectron2 modules
|
| to use __len__ instead, because __len__ does not convert values to
|
| integers and therefore is friendly to tracing.
|
|
|
| Args:
|
| modules (list[stsr]): names of extra modules to patch len(), in
|
| addition to those in detectron2.
|
| """
|
|
|
| def _new_len(obj):
|
| return obj.__len__()
|
|
|
| with ExitStack() as stack:
|
| MODULES = [
|
| "detectron2.modeling.roi_heads.fast_rcnn",
|
| "detectron2.modeling.roi_heads.mask_head",
|
| "detectron2.modeling.roi_heads.keypoint_head",
|
| ] + list(modules)
|
| ctxs = [stack.enter_context(mock.patch(mod + ".len")) for mod in MODULES]
|
| for m in ctxs:
|
| m.side_effect = _new_len
|
| yield
|
|
|
|
|
| def patch_nonscriptable_classes():
|
| """
|
| Apply patches on a few nonscriptable detectron2 classes.
|
| Should not have side-effects on eager usage.
|
| """
|
|
|
|
|
|
|
| from detectron2.modeling.backbone import ResNet, FPN
|
|
|
|
|
|
|
|
|
|
|
| def prepare_resnet(self):
|
| ret = deepcopy(self)
|
| ret.stages = nn.ModuleList(ret.stages)
|
| for k in self.stage_names:
|
| delattr(ret, k)
|
| return ret
|
|
|
| ResNet.__prepare_scriptable__ = prepare_resnet
|
|
|
| def prepare_fpn(self):
|
| ret = deepcopy(self)
|
| ret.lateral_convs = nn.ModuleList(ret.lateral_convs)
|
| ret.output_convs = nn.ModuleList(ret.output_convs)
|
| for name, _ in self.named_children():
|
| if name.startswith("fpn_"):
|
| delattr(ret, name)
|
| return ret
|
|
|
| FPN.__prepare_scriptable__ = prepare_fpn
|
|
|
|
|
|
|
| from detectron2.modeling.roi_heads import StandardROIHeads
|
|
|
| if hasattr(StandardROIHeads, "__annotations__"):
|
|
|
| StandardROIHeads.__annotations__ = deepcopy(StandardROIHeads.__annotations__)
|
| StandardROIHeads.__annotations__["mask_on"] = torch.jit.Final[bool]
|
| StandardROIHeads.__annotations__["keypoint_on"] = torch.jit.Final[bool]
|
|
|
|
|
|
|
| patch_nonscriptable_classes()
|
|
|
|
|
| @contextmanager
|
| def freeze_training_mode(model):
|
| """
|
| A context manager that annotates the "training" attribute of every submodule
|
| to constant, so that the training codepath in these modules can be
|
| meta-compiled away. Upon exiting, the annotations are reverted.
|
| """
|
| classes = {type(x) for x in model.modules()}
|
|
|
|
|
| classes = {x for x in classes if not hasattr(x, "__constants__")}
|
| for cls in classes:
|
| cls.__annotations__["training"] = torch.jit.Final[bool]
|
| yield
|
| for cls in classes:
|
| cls.__annotations__["training"] = bool
|
|
|