import os
import sys
import warnings
import numpy as np
import torch
import torchaudio
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from utils.audio_utils import denormalize_audio, normalize_audio
from utils.model_utils import demix, load_start_checkpoint
from utils.settings import get_model_from_config
# Silence all warnings process-wide (e.g. torchaudio/torch deprecation noise).
warnings.filterwarnings("ignore")
class Separator:
    """Vocal/instrumental source-separation wrapper around a roformer model.

    Loads the model and checkpoint once at construction; `separate` then
    splits waveforms into a vocal stem and an instrumental (mix - vocals).
    """

    def __init__(
        self,
        config_path: str,
        checkpoint_path: str,
        model_type: str = "mel_band_roformer",
        device: str = "auto",
    ):
        """Build the model and load its checkpoint.

        Args:
            config_path: Path to the model's YAML config file.
            checkpoint_path: Path to the model checkpoint (.ckpt).
            model_type: Architecture key understood by `get_model_from_config`;
                may be overridden by `config.training.model_type`.
            device: torch device string, or "auto" to pick CUDA when available.
        """
        from argparse import Namespace  # only needed for the args shim below

        if device == "auto":
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model_type = model_type
        # Global cudnn autotuning; beneficial for fixed-size inference chunks.
        torch.backends.cudnn.benchmark = True

        self.model, self.config = get_model_from_config(model_type, config_path)
        # The config may pin its own model_type; prefer it over the argument.
        if "model_type" in self.config.training:
            self.model_type = self.config.training.model_type

        # `load_start_checkpoint` expects an argparse-style namespace; build a
        # minimal shim containing the fields it reads.
        fake_args = Namespace(
            model_type=self.model_type,
            config_path=config_path,
            start_check_point=checkpoint_path,
            device="auto",
            output_dir="./output",
            use_tta=False,
            extract_instrumental=True,
            pcm_type="FLOAT",
            lora_checkpoint_loralib="",
            draw_spectro=False,
        )
        # NOTE(security): weights_only=False unpickles arbitrary objects and can
        # execute code — only load checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, weights_only=False, map_location="cpu")
        load_start_checkpoint(fake_args, self.model, ckpt, type_="inference")
        self.model = self.model.to(self.device)
        self.model.eval()
        self.sample_rate = getattr(self.config.audio, "sample_rate", 44100)
        # Resample transforms cached by source rate so repeated `separate`
        # calls don't rebuild the resampling kernel every time.
        self._resamplers: dict[int, torchaudio.transforms.Resample] = {}

    def separate(self, wav: torch.Tensor, sr: int):
        """Split a waveform into vocal and instrumental stems.

        Args:
            wav: Waveform returned by torchaudio.load, shape (channels, samples)
            sr: Sample rate
        Returns:
            vocal_wav: np.ndarray, shape (channels, samples)
            inst_wav: np.ndarray, shape (channels, samples)
            sr: int, output sample rate
        """
        # Resample to the model's rate if needed (cached per source rate).
        if sr != self.sample_rate:
            if sr not in self._resamplers:
                self._resamplers[sr] = torchaudio.transforms.Resample(
                    sr, self.sample_rate
                )
            wav = self._resamplers[sr](wav)
            sr = self.sample_rate
        mix = wav.numpy()
        # Duplicate mono to stereo when the model expects two channels.
        if mix.shape[0] == 1 and getattr(self.config.audio, "num_channels", 1) == 2:
            mix = np.concatenate([mix, mix], axis=0)
        mix_orig = mix.copy()
        # Optional normalization; params are kept so vocals can be inverted back.
        norm_params = None
        if getattr(self.config.inference, "normalize", False):
            mix, norm_params = normalize_audio(mix)
        # Run the model; `demix` returns a dict of stem name -> np.ndarray.
        waveforms = demix(
            self.config,
            self.model,
            mix,
            self.device,
            model_type=self.model_type,
            pbar=True,
        )
        # Prefer the "vocals" stem; otherwise fall back to the first stem
        # the model produced (single-stem checkpoints).
        vocal_wav = waveforms.get("vocals", list(waveforms.values())[0])
        if norm_params is not None:
            vocal_wav = denormalize_audio(vocal_wav, norm_params)
        # Instrumental = original mix - vocals (assumes an additive mixture).
        inst_wav = mix_orig - vocal_wav
        return vocal_wav, inst_wav, sr
# ---- Example Usage ----
if __name__ == "__main__":
    separator = Separator(
        config_path="ckpts/config_vocals_mel_band_roformer_kj.yaml",
        checkpoint_path="ckpts/MelBandRoformer.ckpt",
        device="cuda:0",
    )
    # Load the input track, run separation, and write both stems to disk.
    audio, rate = torchaudio.load("path/to/input.mp3")
    vocals, instrumental, rate = separator.separate(audio, rate)
    torchaudio.save("output_vocals.wav", torch.from_numpy(vocals), rate)
    torchaudio.save("output_instrumental.wav", torch.from_numpy(instrumental), rate)