|
|
|
|
| from dataclasses import dataclass
|
| from typing import Union
|
| import torch
|
|
|
|
|
@dataclass
class DensePoseChartPredictorOutput:
    """
    Predictor output that contains segmentation and inner coordinates predictions
    for predefined body parts:

     * coarse segmentation, a tensor of shape [N, K, Hout, Wout]
     * fine segmentation, a tensor of shape [N, C, Hout, Wout]
     * U coordinates, a tensor of shape [N, C, Hout, Wout]
     * V coordinates, a tensor of shape [N, C, Hout, Wout]

    where
     - N is the number of instances
     - K is the number of coarse segmentation channels (
         2 = foreground / background,
         15 = one of 14 body parts / background)
     - C is the number of fine segmentation channels (
         24 fine body parts / background)
     - Hout and Wout are height and width of predictions

    The four tensors are always indexed and moved together, so they are expected
    to share the same leading (instance) dimension.
    """

    coarse_segm: torch.Tensor
    fine_segm: torch.Tensor
    u: torch.Tensor
    v: torch.Tensor

    def __len__(self) -> int:
        """
        Number of instances (N) in the output, taken from dim 0 of `coarse_segm`.
        """
        return self.coarse_segm.size(0)

    def __getitem__(
        self, item: Union[int, slice, torch.BoolTensor]
    ) -> "DensePoseChartPredictorOutput":
        """
        Get outputs for the selected instance(s).

        Args:
            item (int or slice or tensor): selected items. An integer scalar
                selects a single instance; any other index is applied to dim 0
                of every tensor as is.

        Returns:
            DensePoseChartPredictorOutput: a new output holding the selected
            instances. For an integer index the instance dimension is kept, so
            the result has length 1.
        """
        # Function-scope import keeps this class self-contained: `numbers` is
        # not among the module-level imports.
        import numbers

        # `numbers.Integral` rather than `int`: integer scalars such as
        # `numpy.int64` are not `int` instances, and without this branch they
        # would silently drop the instance dimension. Every value accepted by
        # an `int` check is also Integral, so existing callers are unaffected.
        if isinstance(item, numbers.Integral):
            # Scalar indexing removes dim 0; unsqueeze puts it back (N == 1).
            return DensePoseChartPredictorOutput(
                coarse_segm=self.coarse_segm[item].unsqueeze(0),
                fine_segm=self.fine_segm[item].unsqueeze(0),
                u=self.u[item].unsqueeze(0),
                v=self.v[item].unsqueeze(0),
            )
        # Slices, boolean masks, etc. already preserve the instance dimension.
        return DensePoseChartPredictorOutput(
            coarse_segm=self.coarse_segm[item],
            fine_segm=self.fine_segm[item],
            u=self.u[item],
            v=self.v[item],
        )

    def to(self, device: torch.device) -> "DensePoseChartPredictorOutput":
        """
        Transfers all tensors to the given device.

        Args:
            device (torch.device): target device, forwarded to `Tensor.to`.

        Returns:
            DensePoseChartPredictorOutput: a new output object wrapping the
            results of `Tensor.to(device)` for each of the four tensors; this
            object's fields are not reassigned.
        """
        return DensePoseChartPredictorOutput(
            coarse_segm=self.coarse_segm.to(device),
            fine_segm=self.fine_segm.to(device),
            u=self.u.to(device),
            v=self.v.to(device),
        )
|
|
|