| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
def pad_x_to_y(x, y, axis: int = -1):
    """Right-pad (or crop) ``x`` along ``axis`` so its length matches ``y``'s.

    Args:
        x (torch.Tensor): Tensor to pad.
        y (torch.Tensor): Reference tensor; only ``y.shape[axis]`` is read.
        axis (int): Dimension to pad along. Negative values count from the
            end. Defaults to ``-1`` (the last dimension).

    Returns:
        torch.Tensor: ``x`` zero-padded at the end of ``axis`` so that
        ``result.shape[axis] == y.shape[axis]``. If ``x`` is longer than
        ``y`` along ``axis``, the pad amount is negative and ``F.pad``
        crops the trailing surplus instead.

    Raises:
        IndexError: If ``axis`` is out of range for ``x`` or ``y``.
    """
    inp_len = y.shape[axis]
    output_len = x.shape[axis]
    # Normalise to a non-negative index; indexing x.shape above has already
    # rejected out-of-range axes (and 0-dim tensors).
    dim = axis % x.dim()
    # F.pad consumes (left, right) pairs starting from the LAST dimension,
    # so every dimension after ``dim`` needs a (0, 0) pair before ours.
    num_trailing = x.dim() - 1 - dim
    pad = [0, 0] * num_trailing + [0, inp_len - output_len]
    return nn.functional.pad(x, pad)
|
|
|
|
def shape_reconstructed(reconstructed, size):
    """Drop the leading dimension of ``reconstructed`` when ``size`` is 1-D.

    Args:
        reconstructed (torch.Tensor): Tensor to reshape.
        size: Sized object (e.g. ``torch.Size``) describing the target
            shape; only its length is inspected.

    Returns:
        torch.Tensor: ``reconstructed.squeeze(0)`` when ``len(size) == 1``
        (shape unchanged if dim 0 is not of size 1), otherwise the
        ``reconstructed`` object itself.
    """
    target_is_1d = len(size) == 1
    return reconstructed.squeeze(0) if target_is_1d else reconstructed
|
|
|
|
def tensors_to_device(tensors, device):
    """Transfer a tensor, or the tensors inside a list/tuple/dict, to ``device``.

    Args:
        tensors: A single :class:`torch.Tensor`, or a (possibly nested)
            list, tuple or dict containing tensors. Any other object is
            returned as-is.
        device (:class:`torch.device`): The device where to place the tensors.

    Returns:
        Union[:class:`torch.Tensor`, list, dict]: The input with every
        :class:`torch.Tensor` found moved via ``.to(device)``; all other
        leaves are left untouched. Containers are handled differently:

        - lists **and tuples** are rebuilt as new ``list`` objects, so a
          tuple input comes back as a list;
        - dicts are updated **in place** and the same dict object is
          returned.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    if isinstance(tensors, (list, tuple)):
        # Always a list, even for tuple input; existing callers receive lists.
        return [tensors_to_device(item, device) for item in tensors]
    if isinstance(tensors, dict):
        # Only values of existing keys are reassigned, so the dict's key set
        # never changes during iteration.
        for key, value in tensors.items():
            tensors[key] = tensors_to_device(value, device)
        return tensors
    return tensors
|
|