|
|
|
|
| import torch
|
|
|
|
|
class ImageResizeTransform:
    """
    Rescales batches of dataset images (BGR, NCHW layout, typically uint8)
    into the form expected by DensePose training: float32 BGR tensors in
    NCHW layout, resized so the shorter and longer image sides respect the
    configured bounds while the aspect ratio is preserved.
    """

    def __init__(self, min_size: int = 800, max_size: int = 1333):
        # Target bound for the shorter image side.
        self.min_size = min_size
        # Hard cap for the longer image side.
        self.max_size = max_size

    def __call__(self, images: torch.Tensor) -> torch.Tensor:
        """
        Resize a batch of images.

        Args:
            images (torch.Tensor): [N, 3, H, W] tensor holding BGR data
                (typically uint8)
        Returns:
            torch.Tensor: [N, 3, H1, W1] float32 tensor, where H1 and W1
                are chosen so the shorter side approaches `min_size` without
                the longer side exceeding `max_size`, keeping the original
                aspect ratio; channel order remains BGR
        """
        # Interpolation requires a floating-point input.
        images = images.float()
        height, width = images.shape[-2:]
        shorter = height if height <= width else width
        longer = width if height <= width else height
        # Scale so the shorter side reaches min_size, unless that would push
        # the longer side past max_size — then the max_size bound wins.
        scale = min(self.min_size / shorter, self.max_size / longer)
        resized = torch.nn.functional.interpolate(
            images,
            scale_factor=scale,
            mode="bilinear",
            align_corners=False,
        )
        return resized
|
|
|