# NOTE(review): removed stray web-page artifact lines ("xjsc0's picture",
# "1", "64ec292") that were not valid Python and broke the module.
import os
from typing import Dict, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch.distributed as dist
def read_audio_transposed(
    path: str, instr: Optional[str] = None, skip_err: bool = False
) -> Tuple[Optional[np.ndarray], Optional[int]]:
    """
    Load an audio file and return it channels-first.

    The file at ``path`` is read with soundfile; mono signals are promoted
    to a 2-D array with a single channel, and the result is transposed to
    shape (channels, length). A read failure either raises or is skipped,
    depending on ``skip_err``.

    Args:
        path (str): Location of the audio file to load.
        instr (Optional[str], optional): Instrument label used only in the
            skip message when ``skip_err`` is True. Defaults to None.
        skip_err (bool, optional): When True, a read failure returns
            ``(None, None)`` instead of raising. Defaults to False.

    Returns:
        Tuple[Optional[np.ndarray], Optional[int]]: A tuple containing:
            - NumPy array of shape (channels, length), or None if skipped.
            - Sampling rate as an integer, or None if skipped.
    """
    # Only rank 0 prints when torch.distributed is initialized.
    verbose = (not dist.is_initialized()) or dist.get_rank() == 0
    try:
        data, sample_rate = sf.read(path)
    except Exception as e:
        if not skip_err:
            raise RuntimeError(f"Error reading the file at {path}: {e}")
        if verbose:
            print(f"No stem {instr}: skip!")
        return None, None
    if data.ndim == 1:
        # Mono: add a trailing channel axis before transposing.
        data = data[:, np.newaxis]
    return data.T, sample_rate
def normalize_audio(audio: np.ndarray) -> Tuple[np.ndarray, Dict[str, float]]:
    """
    Normalize an audio signal by the mean/std of its mono mix.

    Computes the mean and standard deviation from the channel-average
    (mono) of the input signal, then applies normalization to each channel.
    A zero standard deviation (e.g. digital silence) is replaced by 1.0 so
    the division never produces NaN/inf; the stored "std" is the value
    actually used, so `denormalize_audio` still round-trips correctly.

    Args:
        audio (np.ndarray): Input audio array of shape (channels, time).

    Returns:
        Tuple[np.ndarray, Dict[str, float]]: A tuple containing:
            - Normalized audio with the same shape as the input.
            - A dictionary with keys "mean" and "std" as applied.
    """
    mono = audio.mean(0)
    mean = mono.mean()
    std = mono.std()
    if std == 0:  # silent input: avoid division by zero
        std = 1.0
    return (audio - mean) / std, {"mean": mean, "std": std}
def denormalize_audio(audio: np.ndarray, norm_params: Dict[str, float]) -> np.ndarray:
    """
    Undo mean/std normalization of an audio signal.

    Scales the signal by the stored standard deviation and shifts it by the
    stored mean, restoring the scale of a previously normalized signal.

    Args:
        audio (np.ndarray): Normalized audio array to be denormalized.
        norm_params (Dict[str, float]): Dictionary containing the keys
            "mean" and "std" used during normalization.

    Returns:
        np.ndarray: Denormalized audio with the same shape as the input.
    """
    mean = norm_params["mean"]
    std = norm_params["std"]
    return audio * std + mean
def draw_spectrogram(
    waveform: np.ndarray, sample_rate: int, length: float, output_file: str
) -> None:
    """
    Generate and save a spectrogram image from an audio waveform.

    Cuts the waveform to `length` seconds, converts it to mono, computes its
    Short-Time Fourier Transform (STFT), converts the amplitude spectrogram
    to dB scale, and plots it using a plasma colormap.

    Args:
        waveform (np.ndarray): Input audio waveform array of shape
            (time, channels) or (time,).
        sample_rate (int): Sampling rate of the waveform in Hz.
        length (float): Duration (in seconds) of the waveform to include in
            the spectrogram.
        output_file (str): Path to save the resulting spectrogram image;
            if None, the figure is drawn but not saved.

    Returns:
        None
    """
    import librosa.display

    # Cut only the required part of the signal.
    # (Slicing on the first axis only, so 1-D mono input works too —
    # the old `[:n, :]` slice raised IndexError for 1-D arrays.)
    x = waveform[: int(length * sample_rate)]
    # Mono mix for the STFT; a 1-D input is already mono.
    mono = x.mean(axis=-1) if x.ndim > 1 else x
    X = librosa.stft(mono)  # short-time Fourier transform
    # Convert the amplitude spectrogram to a dB-scaled spectrogram.
    Xdb = librosa.amplitude_to_db(np.abs(X), ref=np.max)
    fig, ax = plt.subplots()
    try:
        img = librosa.display.specshow(
            Xdb, cmap="plasma", sr=sample_rate, x_axis="time", y_axis="linear", ax=ax
        )
        # output_file may be None (the save below is conditional); don't
        # crash building the title in that case.
        name = os.path.basename(output_file) if output_file is not None else ""
        ax.set(title="File: " + name)
        fig.colorbar(img, ax=ax, format="%+2.f dB")
        if output_file is not None:
            plt.savefig(output_file)
    finally:
        # Always release the figure to avoid leaking memory across calls
        # (consistent with draw_mel_spectrogram).
        plt.close(fig)
def draw_2_mel_spectrogram(
    estimates_waveform: np.ndarray,
    track_waveform: np.ndarray,
    sample_rate: int,
    length: float,
    output_base: str,
) -> None:
    """
    Generate and save separate images for spectrograms and waveforms
    for both estimated and original audio.

    Creates two separate images:
        - `{output_base}_spectrograms.jpg` with mel-spectrograms
          (estimated vs original)
        - `{output_base}_waveforms.jpg` with waveforms (estimated vs original)

    Args:
        estimates_waveform (np.ndarray): Estimated audio waveform of shape
            (time, channels) or (time,).
        track_waveform (np.ndarray): Original audio waveform, same layout.
        sample_rate (int): Sampling rate in Hz.
        length (float): Duration in seconds to include; each waveform is
            trimmed to this length before plotting.
        output_base (str): Base path for output files (without extension).

    Returns:
        None
    """
    import librosa.display

    max_samples = int(length * sample_rate)
    waveforms = [estimates_waveform, track_waveform]
    titles = ["Estimated", "Original"]

    # Store processed (trimmed, mono, possibly decimated) waveforms.
    processed_waveforms: list[tuple[np.ndarray, int]] = []
    for waveform in waveforms:
        # Trim to the requested duration (previously `length` was accepted
        # and documented but never used).
        waveform = waveform[:max_samples]
        # Convert to mono if multi-channel.
        mono_signal = waveform.mean(axis=-1) if waveform.ndim > 1 else waveform
        # Decimate long signals (take every 2nd sample) to keep plots light;
        # the effective sample rate halves accordingly.
        if len(mono_signal) > 60 * sample_rate:
            mono_signal = mono_signal[::2]
            effective_sr = sample_rate // 2
        else:
            effective_sr = sample_rate
        processed_waveforms.append((mono_signal, effective_sr))

    # --- Mel-spectrograms figure ---
    fig_spec, axes_spec = plt.subplots(2, 1, figsize=(16, 10))
    for i, ((mono_signal, effective_sr), title) in enumerate(
        zip(processed_waveforms, titles)
    ):
        # Mel-spectrogram with a reduced number of mel bins.
        S = librosa.feature.melspectrogram(y=mono_signal, sr=effective_sr, n_mels=128)
        S_db = librosa.power_to_db(S, ref=np.max)
        librosa.display.specshow(
            S_db,
            cmap="plasma",
            sr=effective_sr,
            x_axis="time",
            y_axis="mel",
            ax=axes_spec[i],
        )
        axes_spec[i].set_title(
            f"Mel-spectrogram: {title}", fontsize=14, fontweight="bold"
        )
        axes_spec[i].set_xlabel("Time (seconds)", fontsize=12)
        axes_spec[i].set_ylabel("Frequency (Mel)", fontsize=12)
        # Colorbar intentionally disabled.
    fig_spec.suptitle(
        f"Mel-spectrograms: {os.path.basename(output_base)}",
        fontsize=16,
        fontweight="bold",
        y=0.98,
    )
    plt.tight_layout()
    plt.subplots_adjust(top=0.94, hspace=0.4, right=0.88)
    # Save spectrograms image with reduced DPI, then free the figure.
    spec_output = f"{output_base}_spectrograms.jpg"
    plt.savefig(spec_output, dpi=150, bbox_inches="tight")
    plt.close(fig_spec)

    # --- Waveforms figure ---
    fig_wave, axes_wave = plt.subplots(2, 1, figsize=(16, 8))
    for i, ((mono_signal, effective_sr), title) in enumerate(
        zip(processed_waveforms, titles)
    ):
        # Time axis in seconds for the (possibly decimated) signal.
        time = np.linspace(0, len(mono_signal) / effective_sr, len(mono_signal))
        if len(mono_signal) > 100000:
            # Plot every 10th sample for very long signals.
            plot_indices = np.arange(0, len(mono_signal), 10)
            axes_wave[i].plot(
                time[plot_indices],
                mono_signal[plot_indices],
                color="#00ff88",
                alpha=0.9,
                linewidth=0.5,
            )
        else:
            axes_wave[i].plot(
                time, mono_signal, color="#00ff88", alpha=0.9, linewidth=0.8
            )
        axes_wave[i].fill_between(time, mono_signal, alpha=0.3, color="#00ff8833")
        axes_wave[i].set_xlabel("Time (seconds)", fontsize=12)
        axes_wave[i].set_ylabel("Amplitude", fontsize=12)
        axes_wave[i].set_title(f"Waveform: {title}", fontsize=14, fontweight="bold")
        axes_wave[i].grid(True, alpha=0.3, color="gray")
        if time.size:  # guard: time[-1] would raise on an empty signal
            axes_wave[i].set_xlim(0, time[-1])
    fig_wave.suptitle(
        f"Waveforms: {os.path.basename(output_base)}",
        fontsize=16,
        fontweight="bold",
        y=0.98,
    )
    plt.tight_layout()
    plt.subplots_adjust(top=0.94, hspace=0.4)
    # Save waveforms image, then free the figure.
    wave_output = f"{output_base}_waveforms.jpg"
    plt.savefig(wave_output, dpi=150, bbox_inches="tight")
    plt.close(fig_wave)
def draw_mel_spectrogram(
    waveform: np.ndarray, sample_rate: int, length: float, output_file: str
) -> None:
    """
    Generate and save a mel-spectrogram image plus a companion waveform plot.

    Cuts the waveform to `length` seconds, converts it to mono, computes a
    mel-spectrogram, converts it to dB scale, and plots it using a plasma
    colormap. A waveform image is also saved next to the spectrogram via
    `plot_waveform_basic`.

    Args:
        waveform (np.ndarray): Input audio waveform array of shape
            (time, channels) or (time,).
        sample_rate (int): Sampling rate of the waveform in Hz.
        length (float): Duration (in seconds) of the waveform to include.
        output_file (str): Path to save the spectrogram image; the waveform
            image path is derived from it. If None, nothing is saved.

    Returns:
        None
    """
    import librosa.display

    # Cut only the required part of the signal (previously the "cut" comment
    # was there but `length` was never actually applied).
    x = waveform[: int(length * sample_rate)]
    # Mono mix; a 1-D input is already mono (x.mean(axis=-1) on a 1-D array
    # would collapse the whole signal to a scalar).
    mono = x.mean(axis=-1) if x.ndim > 1 else x
    S = librosa.feature.melspectrogram(y=mono, sr=sample_rate)
    S_db = librosa.power_to_db(S, ref=np.max)  # power -> dB scale
    fig, ax = plt.subplots()
    try:
        img = librosa.display.specshow(
            S_db, cmap="plasma", sr=sample_rate, x_axis="time", y_axis="mel", ax=ax
        )
        name = os.path.basename(output_file) if output_file is not None else ""
        ax.set(title="Mel-spectrogram: " + name)
        fig.colorbar(img, ax=ax, format="%+2.f dB")
        if output_file is not None:
            plt.savefig(output_file)
    finally:
        plt.close(fig)
    # Companion waveform plot; skipped when there is nowhere to save it
    # (the unconditional `.replace(...)` call crashed on output_file=None).
    if output_file is not None:
        plot_waveform_basic(
            waveform, sample_rate, output_file.replace(".jpg", "_waveform.jpg")
        )
def plot_waveform_basic(waveform, samplerate, output_path=None, theme="dark"):
    """
    Plot a mono waveform with a themed style and optionally save it.

    Multi-channel input is averaged to mono. Three color themes are
    available ("dark", "light", "purple"); unknown names fall back to
    "dark". The figure is always closed, even if plotting fails.

    Args:
        waveform (np.ndarray): Audio of shape (time, channels) or (time,).
        samplerate (int): Sampling rate in Hz, used to build the time axis.
        output_path (str, optional): Where to save the image; if falsy, the
            figure is drawn but not saved. Defaults to None.
        theme (str, optional): One of "dark", "light", "purple".
            Defaults to "dark".

    Returns:
        None
    """
    data = waveform
    if len(data.shape) > 1:
        data = np.mean(data, axis=1)
    themes = {
        "dark": {"bg": "#0f0f0f", "wave": "#00ff88", "fill": "#00ff8833"},
        "light": {"bg": "white", "wave": "#2563eb", "fill": "#3b82f633"},
        "purple": {"bg": "#1a1a2e", "wave": "#e94560", "fill": "#e9456033"},
    }
    colors = themes.get(theme, themes["dark"])
    fig, ax = plt.subplots(figsize=(12, 3), facecolor=colors["bg"])
    try:
        time = np.linspace(0, len(data) / samplerate, len(data))
        ax.plot(time, data, color=colors["wave"], alpha=0.9, linewidth=0.8)
        ax.fill_between(time, data, alpha=0.3, color=colors["fill"])
        ax.set_facecolor(colors["bg"])
        # Dark backgrounds need light text; the light theme uses black.
        # (Previously the light theme skipped the axis labels entirely.)
        label_color = "white" if theme in ("dark", "purple") else "black"
        ax.tick_params(colors=label_color, labelsize=8)
        ax.set_xlabel("Time (seconds)", color=label_color, fontsize=10)
        ax.set_ylabel("Amplitude", color=label_color, fontsize=10)
        ax.grid(True, alpha=0.2, color="gray")
        if time.size:  # guard: time[-1] would raise IndexError on empty input
            ax.set_xlim(0, time[-1])
        plt.tight_layout()
        if output_path:
            plt.savefig(
                output_path,
                dpi=200,
                bbox_inches="tight",
                facecolor=colors["bg"],
                edgecolor="none",
            )
    finally:
        # Close this specific figure; a bare plt.close() targets whatever
        # figure happens to be "current", which may not be this one.
        plt.close(fig)