| | |
| |
|
| | import torch |
| |
|
| | from .base_data_element import BaseDataElement |
| |
|
| |
|
class LabelData(BaseDataElement):
    """Data structure for label-level annotations or predictions."""

    @staticmethod
    def onehot_to_label(onehot: torch.Tensor) -> torch.Tensor:
        """Convert a one-hot (or multi-hot) vector to label indices.

        Args:
            onehot (torch.Tensor): A 1-D tensor whose elements are all
                exactly 0 or 1. More than one element may be 1.

        Returns:
            torch.Tensor: 1-D tensor holding the indices of the non-zero
            elements of ``onehot``. It is empty when ``onehot`` has no
            non-zero element (including when ``onehot`` itself is empty).

        Raises:
            ValueError: If ``onehot`` is not 1-D or contains any value
                other than 0 or 1.
        """
        assert isinstance(onehot, torch.Tensor)
        # Check every element is exactly 0 or 1. A mere range check
        # (min >= 0 and max <= 1) would accept soft scores such as 0.5,
        # and ``max()``/``min()`` raise on empty tensors, whereas
        # ``.all()`` of an empty tensor is simply True.
        if onehot.ndim != 1 or not bool(
                ((onehot == 0) | (onehot == 1)).all()):
            raise ValueError(
                'input is not one-hot and can not convert to label')
        # nonzero() yields shape (k, 1) for a 1-D input; drop the last dim.
        return onehot.nonzero().squeeze(-1)

    @staticmethod
    def label_to_onehot(label: torch.Tensor, num_classes: int) -> torch.Tensor:
        """Convert label indices to a one-hot (or multi-hot) vector.

        Args:
            label (torch.Tensor): Integer class indices, given as a 0-d or
                1-D tensor. Every index must lie in ``[0, num_classes)``.
            num_classes (int): The number of classes, i.e. the length of the
                returned vector.

        Returns:
            torch.Tensor: A 1-D tensor of length ``num_classes``, created
            with ``label.new_zeros`` (so it shares ``label``'s dtype and
            device). Positions listed in ``label`` are set to 1 and all
            others are 0.
        """
        assert isinstance(label, torch.Tensor)
        onehot = label.new_zeros((num_classes, ))
        # Validate with tensor reductions instead of Python's builtin
        # ``max()``. Iterating a tensor element by element is slow and
        # fails on 0-d tensors. Checking only the upper bound would let
        # negative indices through, where they silently wrap around to
        # the end of ``onehot``.
        if label.numel() > 0:
            assert label.min().item() >= 0, \
                f'label indices must be >= 0, got {label.min().item()}'
            assert label.max().item() < num_classes, \
                (f'label indices must be < num_classes ({num_classes}), '
                 f'got {label.max().item()}')
        onehot[label] = 1
        return onehot
| |
|