| | |
| | import io |
| | import numpy as np |
| | import torch |
| |
|
| | from detectron2 import model_zoo |
| | from detectron2.data import DatasetCatalog |
| | from detectron2.data.detection_utils import read_image |
| | from detectron2.modeling import build_model |
| | from detectron2.structures import Boxes, Instances, ROIMasks |
| | from detectron2.utils.file_io import PathManager |
| |
|
| |
|
| | """ |
| | Internal utilities for tests. Don't use except for writing tests. |
| | """ |
| |
|
| |
|
def get_model_no_weights(config_path):
    """
    Build the model described by a model-zoo config without loading any weights.

    This is similar to ``model_zoo.get``, except that no checkpoint is loaded,
    not even a pretrained one. The model is returned exactly as ``build_model``
    constructs it. When CUDA is not available, the config's device is switched
    to CPU before building.

    Args:
        config_path (str): config identifier, passed unchanged to
            ``model_zoo.get_config``.

    Returns:
        The model returned by ``build_model(cfg)``.
    """
    cfg = model_zoo.get_config(config_path)
    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        # Tests may run on CPU-only machines; don't try to place the model on a GPU.
        cfg.MODEL.DEVICE = "cpu"
    return build_model(cfg)
| |
|
| |
|
def random_boxes(num_boxes, max_coord=100, device="cpu"):
    """
    Create a random Nx4 boxes tensor, with coordinates < max_coord.

    Layout of the result:
    - Columns 0:2 hold a random corner.
    - Columns 2:4 hold that corner plus a random positive extent.

    Consequently ``out[:, 2] >= out[:, 0]`` and ``out[:, 3] >= out[:, 1]``
    always hold (XYXY-style boxes with non-negative width/height). Every
    value is at least 1.0.

    The "< max_coord" bound holds when ``max_coord > 2``. Each sampled value
    lies below ``max_coord / 2`` before clamping to 1.0, and a coordinate in
    columns 2:4 is the sum of two such values.

    Args:
        num_boxes (int): number of boxes N.
        max_coord (float): upper bound on the coordinates (see note above).
        device: device on which the tensor is created.

    Returns:
        Tensor of shape (N, 4).
    """
    half_range = max_coord * 0.5
    samples = torch.rand(num_boxes, 4, device=device) * half_range
    # Avoid tiny boxes: every corner coordinate and every extent is at least 1.0.
    samples.clamp_(min=1.0)
    # Adding the corner to the extent (instead of sampling the far corner
    # independently) guarantees non-negative widths and heights.
    corner, extent = samples[:, :2], samples[:, 2:]
    return torch.cat([corner, corner + extent], dim=1)
| |
|
| |
|
def get_sample_coco_image(tensor=True):
    """
    Load one sample COCO image for use in tests.

    Image source:
    - Preferred: the first image of the registered "coco_2017_val_100" dataset.
    - Fallback: a public COCO image URL, resolved to a local path by
      ``PathManager.get_local_path`` (presumably downloading/caching it).
      The fallback is used if looking up the dataset raises an IOError/OSError,
      or if the dataset's file does not exist.

    Args:
        tensor (bool): if True, returns 3xHxW tensor.
            else, returns a HxWx3 numpy array.

    Returns:
        an image, in BGR color.
    """
    try:
        first_record = DatasetCatalog.get("coco_2017_val_100")[0]
        file_name = first_record["file_name"]
        if not PathManager.exists(file_name):
            # Routed to the fallback below (FileNotFoundError is an IOError).
            raise FileNotFoundError()
    except IOError:
        # No local copy of the dataset: use a publicly reachable image instead.
        file_name = PathManager.get_local_path(
            "http://images.cocodataset.org/train2017/000000000009.jpg"
        )

    image = read_image(file_name, format="BGR")
    if not tensor:
        return image
    # HWC -> CHW; make it contiguous so torch.from_numpy can wrap it.
    chw = image.transpose(2, 0, 1)
    return torch.from_numpy(np.ascontiguousarray(chw))
| |
|
| |
|
def convert_scripted_instances(instances):
    """
    Convert a scripted Instances object to a regular :class:`Instances` object.

    The input must expose:
    - ``image_size``;
    - ``_field_names``, the list of field names;
    - one ``_<name>`` attribute per field.

    Fields whose ``_<name>`` attribute is missing or None are skipped.

    Args:
        instances: a scripted Instances-like object.

    Returns:
        Instances: a new object with the same image_size and all non-None fields.
    """
    assert hasattr(
        instances, "image_size"
    ), f"Expect an Instances object, but got {type(instances)}!"
    converted = Instances(instances.image_size)
    for field_name in instances._field_names:
        field_value = getattr(instances, "_" + field_name, None)
        if field_value is None:
            # Unset optional field in the scripted object: leave it out.
            continue
        converted.set(field_name, field_value)
    return converted
| |
|
| |
|
def assert_instances_allclose(input, other, *, rtol=1e-5, msg="", size_as_tensor=False):
    """
    Assert that two Instances objects are the same up to a numerical tolerance.

    Args:
        input, other (Instances): the objects to compare. Anything that is not an
            :class:`Instances` (e.g. a scripted Instances) is first converted with
            :func:`convert_scripted_instances`.
        rtol (float): scales the absolute tolerance ``atol`` passed to
            ``torch.allclose``:
            - Boxes / ROIMasks fields use ``atol = 100 * rtol``.
            - Other floating-point tensor fields use
              ``atol = rtol * max(abs(input_field))``.
            ``torch.allclose`` also applies its own default relative tolerance.
        msg (str): prefix for the assertion messages.
        size_as_tensor: compare image_size of the Instances as tensors (instead of tuples).
            Useful for comparing outputs of tracing.

    Raises:
        AssertionError: if image sizes, field names, field shapes or field values differ.
        ValueError: if a field has a type this function cannot compare.
    """
    if not isinstance(input, Instances):
        input = convert_scripted_instances(input)
    if not isinstance(other, Instances):
        other = convert_scripted_instances(other)

    if not msg:
        msg = "Two Instances are different! "
    else:
        msg = msg.rstrip() + " "

    size_error_msg = msg + f"image_size is {input.image_size} vs. {other.image_size}!"
    if size_as_tensor:
        assert torch.equal(
            torch.tensor(input.image_size), torch.tensor(other.image_size)
        ), size_error_msg
    else:
        assert input.image_size == other.image_size, size_error_msg
    fields = sorted(input.get_fields().keys())
    fields_other = sorted(other.get_fields().keys())
    assert fields == fields_other, msg + f"Fields are {fields} vs {fields_other}!"

    for f in fields:
        val1, val2 = input.get(f), other.get(f)
        if isinstance(val1, (Boxes, ROIMasks)):
            t1, t2 = val1.tensor, val2.tensor
            # torch.allclose broadcasts, so without this check e.g. 1 box vs. 3
            # identical boxes would compare as "close".
            assert t1.shape == t2.shape, (
                msg + f"Field {f} has shape {tuple(t1.shape)} vs {tuple(t2.shape)}!"
            )
            # Fixed absolute tolerance for box-like fields (presumably because
            # coordinates are pixel-scale, O(100)).
            assert torch.allclose(t1, t2, atol=100 * rtol), (
                msg + f"Field {f} differs too much!"
            )
        elif isinstance(val1, torch.Tensor):
            if val1.dtype.is_floating_point:
                # Same broadcasting concern as above.
                assert val1.shape == val2.shape, (
                    msg + f"Field {f} has shape {tuple(val1.shape)} vs {tuple(val2.shape)}!"
                )
                # .max() raises on an empty tensor (e.g. zero detections); an empty
                # field has no magnitude, so use 0 there.
                mag = torch.abs(val1).max().cpu().item() if val1.numel() > 0 else 0.0
                assert torch.allclose(val1, val2, atol=mag * rtol), (
                    msg + f"Field {f} differs too much!"
                )
            else:
                # torch.equal also requires identical shapes.
                assert torch.equal(val1, val2), msg + f"Field {f} is different!"
        else:
            raise ValueError(f"Don't know how to compare type {type(val1)}")
| |
|
| |
|
def reload_script_model(module):
    """
    Round-trip a jit module through serialization and return the reloaded copy.

    The module is written with ``torch.jit.save`` into an in-memory buffer,
    then read back with ``torch.jit.load``. This is similar to the
    `getExportImportCopy` function in torch/testing/.

    Args:
        module: a TorchScript module accepted by ``torch.jit.save``.

    Returns:
        The module produced by ``torch.jit.load``.
    """
    sink = io.BytesIO()
    torch.jit.save(module, sink)
    # A fresh buffer over the serialized bytes starts at position 0, ready to read.
    source = io.BytesIO(sink.getvalue())
    return torch.jit.load(source)
| |
|