import itertools
from collections import OrderedDict

import torch
from torch import nn
|
|
|
|
|
def to_cuda(tensors):
    """Transfer tensor, dict, list or tuple of tensors to GPU.

    Args:
        tensors (:class:`torch.Tensor`, list, tuple or dict): May be a
            single, a list/tuple or a dictionary of tensors.

    Returns:
        :class:`torch.Tensor`:
            Same as input but transferred to cuda. Goes through lists, tuples
            and dicts and transfers the torch.Tensor to cuda. Dicts are
            updated in place; lists and tuples come back as new lists.

    Raises:
        TypeError: If ``tensors`` (or any nested value) is not a tensor,
            list, tuple or dict.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.cuda()
    # Accept tuples as well, for consistency with ``tensors_to_device``.
    if isinstance(tensors, (list, tuple)):
        return [to_cuda(tens) for tens in tensors]
    if isinstance(tensors, dict):
        # In-place update: callers may rely on getting the same dict back.
        for key in tensors.keys():
            tensors[key] = to_cuda(tensors[key])
        return tensors
    raise TypeError(
        "tensors must be a tensor or a list or dict of tensors. "
        " Got tensors of type {}".format(type(tensors))
    )
|
|
|
|
|
def tensors_to_device(tensors, device):
    """Transfer tensor, dict or list of tensors to device.

    Args:
        tensors (:class:`torch.Tensor`): May be a single, a list or a
            dictionary of tensors.
        device (:class: `torch.device`): the device where to place the tensors.

    Returns:
        Union [:class:`torch.Tensor`, list, tuple, dict]:
            Same as input but transferred to device.
            Goes through lists and dicts and transfers the torch.Tensor to
            device. Leaves the rest untouched.
    """
    # Base case: a plain tensor moves directly.
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    # Sequences (lists and tuples) are rebuilt as lists, element by element.
    if isinstance(tensors, (list, tuple)):
        return [tensors_to_device(entry, device) for entry in tensors]
    # Dicts are updated in place so the caller gets the same mapping back.
    if isinstance(tensors, dict):
        for name in tensors:
            tensors[name] = tensors_to_device(tensors[name], device)
        return tensors
    # Anything else (ints, strings, None, ...) passes through unchanged.
    return tensors
|
|
|
|
|
def pad_x_to_y(x, y, axis=-1):
    """Right-pad first argument to have same size as second argument.

    Args:
        x (torch.Tensor): Tensor to be padded.
        y (torch.Tensor): Tensor to pad ``x`` to.
        axis (int): Axis to pad on. Defaults to the last axis; any axis is
            supported (previously only ``axis=-1`` was implemented).

    Returns:
        torch.Tensor, ``x`` padded with zeros at the end of ``axis`` so it
        matches ``y``'s size on that axis. If ``x`` is longer than ``y`` on
        ``axis``, it is truncated instead (negative pad).
    """
    # Normalize to a non-negative axis index.
    norm_axis = axis % x.ndim
    delta = y.size(norm_axis) - x.size(norm_axis)
    # F.pad's spec pads the LAST dim first: (left_n, right_n, left_n-1, ...),
    # so the entries for ``norm_axis`` sit after 2*(ndim-1-norm_axis) zeros.
    pad_spec = [0, 0] * (x.ndim - 1 - norm_axis) + [0, delta]
    return nn.functional.pad(x, pad_spec)
|
|
|
|
|
def load_state_dict_in(state_dict, model):
    """Strictly load ``state_dict`` into ``model``, retrying on the submodel.

    Useful to load a standalone model after training it with System.

    Args:
        state_dict (OrderedDict): the state_dict to load.
        model (torch.nn.Module): the model to load it into.

    Returns:
        torch.nn.Module: ``model`` with loaded weights.

    .. note:: Keys in a state_dict look like ``object1.object2.layer_name.weight``.
        We first try to load the model in the classic strict way. If that
        fails, we strip the left-most component of every key (giving
        ``object2.layer_name.weight``) and strictly load again. Blindly
        loading with ``strict=False`` instead should come with some logging
        of the keys missing from the state_dict and the model.
    """
    try:
        # Happy path: keys already match the model exactly.
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Drop everything up to (and including) the first dot of each key,
        # then retry a strict load with the renamed weights.
        stripped = OrderedDict(
            (key[key.find(".") + 1 :], weight) for key, weight in state_dict.items()
        )
        model.load_state_dict(stripped, strict=True)
    return model
|
|
|
|
|
def are_models_equal(model1, model2):
    """Check for weights equality between models.

    Args:
        model1 (nn.Module): model instance to be compared.
        model2 (nn.Module): second model instance to be compared.

    Returns:
        bool: Whether all model weights are equal — same number of
        parameters, same shapes and same values.
    """
    _missing = object()  # sentinel marking the exhausted (shorter) iterator
    # Plain zip would silently stop at the shorter parameter list and could
    # report two models with different parameter counts as equal.
    for p1, p2 in itertools.zip_longest(
        model1.parameters(), model2.parameters(), fillvalue=_missing
    ):
        if p1 is _missing or p2 is _missing:
            return False
        # A shape mismatch means "not equal", not a runtime error.
        if p1.data.shape != p2.data.shape:
            return False
        if not torch.equal(p1.data, p2.data):
            return False
    return True
|
|
|