| | import os
|
| | import warnings
|
| |
|
| | import numpy as np
|
| | import torch
|
| | from torch import nn
|
| |
|
| | from ..masknn import activations
|
| | from ..utils.torch_utils import pad_x_to_y
|
| |
|
| |
|
def _unsqueeze_to_3d(x):
    """Bring a tensor up to rank 3 by adding leading/singleton axes.

    - rank 1 ``(time,)`` is reshaped to ``(1, 1, time)``;
    - rank 2 gets a singleton axis inserted at position 1;
    - any other rank is returned as-is (same object).
    """
    rank = x.ndim
    if rank == 2:
        return x.unsqueeze(1)
    if rank == 1:
        return x.reshape(1, 1, -1)
    return x
|
| |
|
| |
|
class BaseModel(nn.Module):
    """Base class for separation models.

    Provides the ``separate`` inference API (torch tensors, numpy arrays and
    filenames) and (de)serialization helpers. Subclasses implement ``forward``
    and ``get_model_args``; models whose ``forward`` doesn't return waveform
    tensors should override ``_separate``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    @torch.no_grad()
    def separate(self, wav, output_dir=None, force_overwrite=False, **kwargs):
        """Infer separated sources from input waveforms.
        Also supports filenames.

        Args:
            wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
                Shape: 1D, 2D or 3D tensor, time last.
            output_dir (str): path to save all the wav files. If None,
                estimated sources will be saved next to the original ones.
            force_overwrite (bool): whether to overwrite existing files.
            **kwargs: keyword arguments to be passed to `_separate`.

        Returns:
            Union[torch.Tensor, numpy.ndarray, None], the estimated sources.
                (batch, n_src, time) or (n_src, time) w/o batch dim.
                ``None`` when ``wav`` is a filename (results are written to disk).

        Raises:
            ValueError: if ``wav`` is not a str, numpy array or torch tensor.

        .. note::
            By default, `separate` calls `_separate` which calls `forward`.
            For models whose `forward` doesn't return waveform tensors,
            override `_separate` to return waveform tensors.
        """
        if isinstance(wav, str):
            self.file_separate(
                wav, output_dir=output_dir, force_overwrite=force_overwrite, **kwargs
            )
            return None
        if isinstance(wav, np.ndarray):
            return self.numpy_separate(wav, **kwargs)
        if isinstance(wav, torch.Tensor):
            return self.torch_separate(wav, **kwargs)
        raise ValueError(
            f"Only support filenames, numpy arrays and torch tensors, received {type(wav)}"
        )

    def torch_separate(self, wav: torch.Tensor, **kwargs) -> torch.Tensor:
        """Core logic of `separate`.

        Moves ``wav`` to the model's device, runs ``_separate``, rescales the
        estimates so their summed absolute value matches the input's, and
        moves the result back to the input's device.

        Args:
            wav (torch.Tensor): input waveform(s).
            **kwargs: forwarded to ``_separate``.

        Returns:
            torch.Tensor: the rescaled estimates on ``wav``'s original device.
        """
        input_device = wav.device
        # A model without parameters has no device of its own: keep the input's.
        first_param = next(self.parameters(), None)
        model_device = first_param.device if first_param is not None else input_device
        wav = wav.to(model_device)

        out_wavs = self._separate(wav, **kwargs)

        # Global amplitude renormalization (a single scalar for the whole batch).
        # Skipped when the estimates are all zero, which would otherwise give
        # 0/0 = NaN. Not done in place: `_separate` may return a view of `wav`,
        # and an in-place multiply would then modify the caller's tensor.
        out_magnitude = out_wavs.abs().sum()
        if out_magnitude > 0:
            out_wavs = out_wavs * (wav.abs().sum() / out_magnitude)

        return out_wavs.to(input_device)

    def numpy_separate(self, wav: np.ndarray, **kwargs) -> np.ndarray:
        """Numpy interface to `separate`.

        Wraps ``wav`` with ``torch.from_numpy`` (shares memory), runs
        ``torch_separate`` and returns the estimates as a numpy array.
        """
        out_wav = self.torch_separate(torch.from_numpy(wav), **kwargs)
        # detach(): allows conversion even outside `torch.no_grad()`.
        return out_wav.detach().cpu().numpy()

    def file_separate(
        self, filename: str, output_dir=None, force_overwrite=False, **kwargs
    ) -> None:
        """Filename interface to `separate`.

        Reads ``filename`` with soundfile (float32), separates its first
        channel, and writes one file per estimated source, named
        ``<base>_est<i><ext>``. The files go next to the input, or into
        ``output_dir`` when given.

        If a target file already exists and ``force_overwrite`` is False, a
        ``UserWarning`` is emitted and no further sources are written.
        """
        import soundfile as sf

        wav, fs = sf.read(filename, dtype="float32", always_2d=True)
        # Only the first channel is separated.
        to_save = self.numpy_separate(wav[:, 0], **kwargs)

        base, ext = os.path.splitext(filename)
        for src_idx, est_src in enumerate(to_save):
            save_name = f"{base}_est{src_idx + 1}{ext}"
            if output_dir is not None:
                save_name = os.path.join(output_dir, os.path.basename(save_name))
            # Check the path that is actually about to be written.
            if os.path.isfile(save_name) and not force_overwrite:
                warnings.warn(
                    f"File {save_name} already exists, pass `force_overwrite=True` to overwrite it",
                    UserWarning,
                )
                return
            sf.write(save_name, est_src, fs)

    def _separate(self, wav, *args, **kwargs):
        """Hidden separation method

        Args:
            wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
                Shape: 1D, 2D or 3D tensor, time last.

        Returns:
            The output of self(wav, *args, **kwargs).
        """
        return self(wav, *args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_conf_or_path, *args, **kwargs):
        """Instantiate separation model from a model config (file or dict).

        Args:
            pretrained_model_conf_or_path (Union[dict, str]): model conf as
                returned by `serialize`, or path to it. Need to contain
                `model_name`, `model_args` and `state_dict` keys.
            *args: Positional arguments to be passed to the model.
            **kwargs: Keyword arguments to be passed to the model.
                They overwrite the ones in the model package.

        Returns:
            nn.Module corresponding to the pretrained model conf/URL.

        Raises:
            ValueError if the path is not an existing file, or if the config
            doesn't contain the keys `model_name`, `model_args` or `state_dict`.
        """
        from . import get

        if isinstance(pretrained_model_conf_or_path, str):
            if not os.path.isfile(pretrained_model_conf_or_path):
                raise ValueError(
                    "Model {} is not a file or doesn't exist.".format(pretrained_model_conf_or_path)
                )
            # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
            conf = torch.load(pretrained_model_conf_or_path, map_location="cpu")
        else:
            conf = pretrained_model_conf_or_path

        for required_key in ("model_name", "state_dict", "model_args"):
            if required_key not in conf:
                raise ValueError(
                    "Expected config dictionary to have field "
                    "`{}`. Found only: {}".format(required_key, conf.keys())
                )

        # Merge into a new dict so the caller's config is not mutated.
        model_args = {**conf["model_args"], **kwargs}

        try:
            model_class = get(conf["model_name"])
        except ValueError:
            # Unknown registered name: fall back to the class this was called on.
            model_class = cls
        model = model_class(*args, **model_args)
        model.load_state_dict(conf["state_dict"])
        return model

    def serialize(self):
        """Serialize model and output dictionary.

        Returns:
            dict, serialized model with keys `model_name`, `model_args`,
            `state_dict` and `infos` (software versions).
        """
        import pytorch_lightning as pl

        from .. import __version__ as asteroid_version

        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
        )
        model_conf["infos"] = dict(
            software_versions=dict(
                torch_version=torch.__version__,
                pytorch_lightning_version=pl.__version__,
                asteroid_version=asteroid_version,
            )
        )
        return model_conf

    def get_state_dict(self):
        """In case the state dict needs to be modified before sharing the model."""
        return self.state_dict()

    def get_model_args(self):
        """Arguments needed to re-instantiate the model (implemented by subclasses)."""
        raise NotImplementedError

    def cached_download(self, filename_or_url):
        """Return ``filename_or_url`` if it is an existing file, else ``None``.

        No download is performed. A ``UserWarning`` is emitted when the path
        does not point to an existing file.
        """
        if os.path.isfile(filename_or_url):
            return filename_or_url
        warnings.warn(
            "Model {} is not a file or doesn't exist.".format(filename_or_url), UserWarning
        )
        return None
|
| |
|
| |
|
class BaseEncoderMaskerDecoder(BaseModel):
    """Base class for encoder-masker-decoder separation models.

    Args:
        encoder (Encoder): Encoder instance.
        masker (nn.Module): masker network.
        decoder (Decoder): Decoder instance.
        encoder_activation (Optional[str], optional): Activation to apply after encoder.
            See ``asteroid.masknn.activations`` for valid values.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None):
        super().__init__()
        self.encoder = encoder
        self.masker = masker
        self.decoder = decoder

        # Keep the raw name for `get_model_args`; "linear" is used when None.
        self.encoder_activation = encoder_activation
        self.enc_activation = activations.get(encoder_activation or "linear")()

    def forward(self, wav):
        """Enc/Mask/Dec model forward

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
        """
        # Remember whether the input was 1D so the leading axis added by
        # `_unsqueeze_to_3d` can be removed again at the end.
        was_one_d = wav.ndim == 1
        wav = _unsqueeze_to_3d(wav)

        # Encode, then apply the encoder activation.
        tf_rep = self.encoder(wav)
        tf_rep = self.postprocess_encoded(tf_rep)
        tf_rep = self.enc_activation(tf_rep)

        # Estimate one mask per source.
        est_masks = self.masker(tf_rep)
        est_masks = self.postprocess_masks(est_masks)

        # Broadcast the representation over the source axis (dim 1) and mask it.
        masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
        masked_tf_rep = self.postprocess_masked(masked_tf_rep)

        decoded = self.decoder(masked_tf_rep)
        decoded = self.postprocess_decoded(decoded)

        # NOTE(review): pad_x_to_y presumably matches `decoded`'s time length to
        # `wav`'s — confirm in ..utils.torch_utils.
        reconstructed = pad_x_to_y(decoded, wav)
        if was_one_d:
            return reconstructed.squeeze(0)
        return reconstructed

    def postprocess_encoded(self, tf_rep):
        """Hook to perform transformations on the encoded, time-frequency domain
        representation (output of the encoder) before encoder activation is applied.

        Args:
            tf_rep (Tensor of shape (batch, freq, time)):
                Output of the encoder, before encoder activation is applied.

        Returns:
            Transformed `tf_rep`
        """
        return tf_rep

    def postprocess_masks(self, masks):
        """Hook to perform transformations on the masks (output of the masker) before
        masks are applied.

        Args:
            masks (Tensor of shape (batch, n_src, freq, time)):
                Output of the masker

        Returns:
            Transformed `masks`
        """
        return masks

    def postprocess_masked(self, masked_tf_rep):
        """Hook to perform transformations on the masked time-frequency domain
        representation (result of masking in the time-frequency domain) before decoding.

        Args:
            masked_tf_rep (Tensor of shape (batch, n_src, freq, time)):
                Masked time-frequency representation, before decoding.

        Returns:
            Transformed `masked_tf_rep`
        """
        return masked_tf_rep

    def postprocess_decoded(self, decoded):
        """Hook to perform transformations on the decoded, time domain representation
        (output of the decoder) before original shape reconstruction.

        Args:
            decoded (Tensor of shape (batch, n_src, time)):
                Output of the decoder, before original shape reconstruction.

        Returns:
            Transformed `decoded`
        """
        return decoded

    def get_model_args(self):
        """Arguments needed to re-instantiate the model.

        Merges the filterbank config, the masker config and
        ``encoder_activation`` into one flat dict.

        Raises:
            AssertionError: if the filterbank and masker configs share keys,
                since merging them would silently drop values.
        """
        fb_config = self.encoder.filterbank.get_config()
        masknet_config = self.masker.get_config()

        shared_keys = [k for k in masknet_config if k in fb_config]
        if shared_keys:
            raise AssertionError(
                "Filterbank and Mask network config share common keys "
                f"{shared_keys}. Merging them is not safe."
            )

        model_args = {
            **fb_config,
            **masknet_config,
            "encoder_activation": self.encoder_activation,
        }
        return model_args
|
| |
|
| |
|
| |
|
# Alternative name for BaseEncoderMaskerDecoder (same class object).
# NOTE(review): presumably kept so code importing the older name keeps working — confirm.
BaseTasNet = BaseEncoderMaskerDecoder
|
| |
|