| from typing import List, Union |
|
|
| import torch |
| from torch.nn.functional import cross_entropy |
|
|
| from .constants import IGNORE_INDEX |
|
|
# Public API of this module: the single loss function defined below.
__all__ = ["soft_cross_entropy"]
|
|
|
|
def soft_cross_entropy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    soft_tokens: Union[torch.Tensor, List[int]],
    std: float = 1,
    ignore_index: int = IGNORE_INDEX,
) -> torch.Tensor:
    """Cross-entropy with Gaussian-smoothed targets for selected token ids.

    Positions whose (shifted) target is in ``soft_tokens`` are trained against
    a soft distribution over the soft-token ids, weighted by a Gaussian of the
    token-id distance to the true target; all other positions use standard
    hard-label cross-entropy. The two sums are averaged over all non-ignored
    positions.

    Args:
        outputs: Logits of shape ``(..., seq_len, vocab_size)``.
        targets: Token ids of shape ``(..., seq_len)``; positions equal to
            ``ignore_index`` are excluded from the loss.
        soft_tokens: Token ids (tensor or list of ints) that receive soft
            targets. Assumed to lie in token-id space where numeric distance
            is meaningful (e.g. digit tokens) — TODO confirm with callers.
        std: Standard deviation of the Gaussian in token-id units; must be
            nonzero.
        ignore_index: Target value to exclude from the loss.

    Returns:
        Scalar loss tensor (mean over non-ignored positions).
    """
    # Shift for next-token prediction: logits at position t predict token t+1.
    outputs = outputs[..., :-1, :].contiguous()
    targets = targets[..., 1:].contiguous()

    # Flatten to (num_positions, vocab_size) and (num_positions,).
    targets = targets.view(-1)
    outputs = outputs.view(targets.size(0), -1)

    # Drop ignored positions.
    keep = targets != ignore_index
    outputs = outputs[keep]
    targets = targets[keep]

    # Robustness: every position was ignored — return a zero loss instead of
    # dividing by zero (which would yield NaN). Keeps the graph connected.
    if targets.numel() == 0:
        return outputs.sum() * 0.0

    if isinstance(soft_tokens, list):
        # Match device and dtype of the (integer) targets.
        soft_tokens = torch.tensor(soft_tokens).to(targets)

    # Hard-label part: standard cross-entropy over non-soft positions.
    hard = torch.isin(targets, soft_tokens, invert=True)
    loss = cross_entropy(outputs[hard], targets[hard], reduction="sum")

    # Soft-label part, vectorized (replaces a per-position Python loop that
    # built one full-vocab row at a time). Broadcast to a
    # (num_soft_positions, num_soft_tokens) matrix of Gaussian weights over
    # token-id distance, normalize each row, and scatter the rows into the
    # soft-token columns of a zero (num_soft_positions, vocab_size) tensor.
    soft = ~hard
    if soft.any():
        dist = torch.exp(
            -((targets[soft].unsqueeze(1) - soft_tokens) ** 2) / (2 * std**2)
        )
        soft_targets = torch.zeros_like(outputs[soft])
        soft_targets[:, soft_tokens] = (
            dist / dist.sum(dim=1, keepdim=True)
        ).to(soft_targets.dtype)
        # cross_entropy with class-probability targets (PyTorch >= 1.10).
        loss = loss + cross_entropy(outputs[soft], soft_targets, reduction="sum")

    return loss / targets.size(0)
|
|