| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Unility functions for Transformer.""" |
|
|
| from typing import List |
|
|
| import torch |
|
|
# Label id reserved for positions that should be ignored. Presumably this is
# the value callers pass as ``ignore_label`` to ``th_accuracy`` and use to pad
# target sequences -- TODO(review): confirm against callers.
IGNORE_ID: int = -1
|
|
|
|
def pad_list(xs: List[torch.Tensor], pad_value: float) -> torch.Tensor:
    """Pad a list of variable-length tensors into a single batch tensor.

    Each tensor is padded along its first dimension up to the longest first
    dimension found in ``xs``. All trailing dimensions are taken from
    ``xs[0]`` and must match across the list. The result uses the dtype and
    device of ``xs[0]``. Tensors of any rank >= 1 are supported.

    Args:
        xs (List[torch.Tensor]): Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value written into every padded position.

    Returns:
        torch.Tensor: Padded tensor of shape (B, Tmax, `*`).

    Raises:
        ValueError: If ``xs`` is empty or its first tensor is 0-dimensional.

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    if not xs:
        # max() over an empty list would fail with an opaque message, and
        # there is no tensor to take dtype/device from.
        raise ValueError("pad_list() requires at least one tensor")
    first = xs[0]
    if first.ndim == 0:
        raise ValueError(f"Unsupported ndim: {first.ndim}")

    max_len = max(len(item) for item in xs)
    # Allocate the output already filled with pad_value. The trailing
    # dimensions are copied from the first tensor, so any rank works.
    pad_res = torch.full((len(xs), max_len, *first.shape[1:]),
                         pad_value,
                         dtype=first.dtype,
                         device=first.device)
    for i, item in enumerate(xs):
        pad_res[i, :len(item)] = item
    return pad_res
|
|
|
|
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
                ignore_label: int) -> torch.Tensor:
    """Calculate accuracy over the target positions that are not ignored.

    The predicted label at each position is the argmax over the last
    dimension of ``pad_outputs``. Positions where ``pad_targets`` equals
    ``ignore_label`` are excluded from both the numerator and the
    denominator.

    Args:
        pad_outputs (Tensor): Prediction scores, either flattened as
            (B * Lmax, D) or batched as (B, Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        torch.Tensor: Detached scalar accuracy value (0.0 - 1.0). The result
        is 0.0 when every target position equals ``ignore_label``.

    """
    batch_size, max_len = pad_targets.size(0), pad_targets.size(1)
    # size(-1) rather than size(1) is the class dimension for both the
    # (B * Lmax, D) and the (B, Lmax, D) layouts.
    pad_pred = pad_outputs.view(batch_size, max_len,
                                pad_outputs.size(-1)).argmax(2)
    mask = pad_targets != ignore_label
    # Count positions that are both valid and correctly predicted.
    numerator = torch.sum((pad_pred == pad_targets) & mask)
    # Clamp so a batch with no valid positions gives 0 / 1 = 0.0 instead of
    # the NaN produced by 0 / 0.
    denominator = torch.sum(mask).clamp(min=1)
    return (numerator / denominator).detach()
|
|