| import inspect |
| import typing |
| from functools import wraps |
|
|
| from . import util |
|
|
|
|
def format_figure(func):
    """Decorator for formatting figures produced by the code below.
    See :py:func:`audiotools.core.util.format_figure` for more.

    Keyword arguments whose names match parameters of
    :py:func:`audiotools.core.util.format_figure` are removed from the call
    to ``func`` and passed to ``util.format_figure`` after ``func`` has
    drawn. All other arguments go to ``func`` unchanged.

    Parameters
    ----------
    func : Callable
        Plotting function that is decorated by this function.

    Returns
    -------
    Callable
        Wrapped function that returns whatever ``func`` returns.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The signature is looked up at call time, so it follows whatever
        # util.format_figure currently is.
        f_keys = inspect.signature(util.format_figure).parameters.keys()
        f_kwargs = {k: v for k, v in kwargs.items() if k in f_keys}
        plot_kwargs = {k: v for k, v in kwargs.items() if k not in f_keys}

        result = func(*args, **plot_kwargs)
        # Format after drawing so the formatting applies to what func produced.
        util.format_figure(**f_kwargs)
        # Propagate the wrapped function's return value instead of dropping it.
        return result

    return wrapper
|
|
|
|
class DisplayMixin:
    """Plotting and TensorBoard/image-export helpers for a signal class.

    NOTE(review): this mixin relies on the host class providing ``clone``,
    ``stft_data``, ``preemphasis``, ``magnitude``, ``log_magnitude``,
    ``mel_spectrogram``, ``audio_data`` and ``sample_rate``.
    """

    @format_figure
    def specshow(
        self,
        preemphasis: bool = False,
        x_axis: str = "time",
        y_axis: str = "linear",
        n_mels: int = 128,
        **kwargs,
    ):
        """Displays a spectrogram, using ``librosa.display.specshow``.

        Parameters
        ----------
        preemphasis : bool, optional
            Whether or not to apply preemphasis, which makes high
            frequency detail easier to see, by default False
        x_axis : str, optional
            How to label the x axis, by default "time"
        y_axis : str, optional
            How to label the y axis, by default "linear"
        n_mels : int, optional
            If displaying a mel spectrogram with ``y_axis = "mel"``,
            this controls the number of mels, by default 128.
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.util.format_figure`
            (consumed by the ``format_figure`` decorator). Any other keyword
            arguments are forwarded to ``librosa.display.specshow``.
        """
        import librosa
        import librosa.display

        # Operate on a copy so preemphasis does not modify ``self``; clearing
        # ``stft_data`` presumably drops any cached STFT so it is recomputed.
        signal = self.clone()
        signal.stft_data = None

        if preemphasis:
            signal.preemphasis()

        # Log magnitude referenced to the loudest magnitude bin.
        ref = signal.magnitude.max()
        log_mag = signal.log_magnitude(ref_value=ref)

        if y_axis == "mel":
            # 20*log10 of the mel spectrogram, floored at 1e-5 to avoid
            # log10(0), then shifted so its maximum is 0.
            log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10()
            log_mag -= log_mag.max()

        # Detach and move to CPU before converting, so tensors on a GPU or
        # tensors that require grad can still be plotted (matches waveplot).
        # Takes the first batch item and averages over the next axis —
        # assumes (batch, channels, freq, time); TODO confirm.
        spec = log_mag.detach().cpu().numpy()[0].mean(axis=0)

        librosa.display.specshow(
            spec,
            x_axis=x_axis,
            y_axis=y_axis,
            sr=signal.sample_rate,
            **kwargs,
        )

    @format_figure
    def waveplot(self, x_axis: str = "time", **kwargs):
        """Displays a waveform plot, using ``librosa.display.waveshow``.

        Parameters
        ----------
        x_axis : str, optional
            How to label the x axis, by default "time"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.util.format_figure`
            (consumed by the ``format_figure`` decorator). Any other keyword
            arguments are forwarded to the librosa waveform plotting function.
        """
        import librosa
        import librosa.display

        # First batch item, averaged over the next axis — presumably
        # channels, assuming (batch, channels, samples); TODO confirm.
        audio_data = self.audio_data[0].mean(dim=0)
        # Detach so tensors that require grad can be converted to numpy.
        audio_data = audio_data.detach().cpu().numpy()

        # Use ``waveshow`` when this librosa provides it, else ``waveplot``.
        plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot"
        wave_plot_fn = getattr(librosa.display, plot_fn)
        wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)

    @format_figure
    def wavespec(self, x_axis: str = "time", **kwargs):
        """Displays a waveform plot above a spectrogram of the signal.

        The waveform (:py:meth:`waveplot`) takes the top row of a 6-row grid
        and the spectrogram (:py:meth:`specshow`) takes the remaining 5 rows.

        Parameters
        ----------
        x_axis : str, optional
            How to label the x axis, by default "time"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.util.format_figure`
            are applied to the whole figure by the ``format_figure``
            decorator. The rest are forwarded to
            :py:func:`audiotools.core.display.DisplayMixin.specshow`.
        """
        import matplotlib.pyplot as plt
        from matplotlib.gridspec import GridSpec

        gs = GridSpec(6, 1)
        plt.subplot(gs[0, :])
        self.waveplot(x_axis=x_axis)
        plt.subplot(gs[1:, :])
        self.specshow(x_axis=x_axis, **kwargs)

    def write_audio_to_tb(
        self,
        tag: str,
        writer,
        step: typing.Optional[int] = None,
        plot_fn: typing.Optional[typing.Union[typing.Callable, str]] = "specshow",
        **kwargs,
    ):
        """Writes a signal and its spectrogram to Tensorboard. Will show up
        under the Audio and Images tab in Tensorboard.

        Parameters
        ----------
        tag : str
            Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be
            written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
            Only the last occurrence of ``"wav"`` in the tag is replaced, so
            directory names containing ``"wav"`` are left intact.
        writer : SummaryWriter
            A SummaryWriter object from PyTorch library.
        step : int, optional
            The step to write the signal to, by default None
        plot_fn : typing.Union[typing.Callable, str], optional
            How to create the image. Set to ``None`` to avoid plotting, by default "specshow".
            A string is looked up as a method on ``self``.
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
            whatever ``plot_fn`` is set to.
        """
        import matplotlib.pyplot as plt

        # Only the first channel of the first batch item is logged as audio.
        audio_data = self.audio_data[0, 0].detach().cpu()
        sample_rate = self.sample_rate
        writer.add_audio(tag, audio_data, step, sample_rate)

        if plot_fn is None:
            return

        if isinstance(plot_fn, str):
            plot_fn = getattr(self, plot_fn)

        fig = plt.figure()
        plt.clf()
        try:
            plot_fn(**kwargs)
        except BaseException:
            # Don't leave an orphaned open figure behind if plotting fails.
            plt.close(fig)
            raise

        # Swap only the final "wav" (normally the file extension); replacing
        # every occurrence would mangle e.g. "waveforms/x.wav".
        image_tag = "png".join(tag.rsplit("wav", 1))
        writer.add_figure(image_tag, fig, step)

    def save_image(
        self,
        image_path: str,
        plot_fn: typing.Optional[typing.Union[typing.Callable, str]] = "specshow",
        **kwargs,
    ):
        """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
        a specified file.

        Parameters
        ----------
        image_path : str
            Where to save the file to.
        plot_fn : typing.Union[typing.Callable, str], optional
            How to create the image. Set to ``None`` to avoid plotting, by default "specshow".
            In that case nothing is drawn and no file is written. A string is
            looked up as a method on ``self``.
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
            whatever ``plot_fn`` is set to.
        """
        import matplotlib.pyplot as plt

        # Honour the documented "None means don't plot" contract instead of
        # failing with a TypeError when calling None.
        if plot_fn is None:
            return

        if isinstance(plot_fn, str):
            plot_fn = getattr(self, plot_fn)

        plt.clf()
        try:
            plot_fn(**kwargs)
            plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
        finally:
            # Always release the current figure, even if plotting/saving fails.
            plt.close()
|
|