File size: 3,701 Bytes
64ec292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import sys
import warnings

import numpy as np
import torch
import torchaudio

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from utils.audio_utils import denormalize_audio, normalize_audio
from utils.model_utils import demix, load_start_checkpoint
from utils.settings import get_model_from_config

warnings.filterwarnings("ignore")


class Separator:
    """Vocal/instrumental source separator wrapping a roformer-style model.

    Loads the model and checkpoint once at construction; ``separate()`` can
    then be called repeatedly on waveforms.
    """

    def __init__(
        self,
        config_path: str,
        checkpoint_path: str,
        model_type: str = "mel_band_roformer",
        device: str = "auto",
    ):
        """
        Args:
            config_path: Path to the model YAML config.
            checkpoint_path: Path to the model checkpoint file.
            model_type: Architecture key passed to get_model_from_config;
                overridden by ``config.training.model_type`` when present.
            device: torch device string, or "auto" to prefer CUDA when available.
        """
        if device == "auto":
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model_type = model_type

        # Enable cudnn autotuning for fixed-size inputs (harmless no-op on CPU).
        torch.backends.cudnn.benchmark = True
        self.model, self.config = get_model_from_config(model_type, config_path)

        # The config's declared model type, if any, wins over the argument.
        if "model_type" in self.config.training:
            self.model_type = self.config.training.model_type

        from argparse import Namespace

        # load_start_checkpoint expects an argparse-style namespace; build a
        # minimal stand-in carrying only the fields it reads.
        fake_args = Namespace(
            model_type=self.model_type,
            config_path=config_path,
            start_check_point=checkpoint_path,
            device="auto",
            output_dir="./output",
            use_tta=False,
            extract_instrumental=True,
            pcm_type="FLOAT",
            lora_checkpoint_loralib="",
            draw_spectro=False,
        )
        # SECURITY: weights_only=False unpickles arbitrary objects from the
        # checkpoint — only load checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, weights_only=False, map_location="cpu")
        load_start_checkpoint(fake_args, self.model, ckpt, type_="inference")

        self.model = self.model.to(self.device)
        self.model.eval()
        self.sample_rate = getattr(self.config.audio, "sample_rate", 44100)

    def separate(self, wav: torch.Tensor, sr: int):
        """
        Args:
            wav: Waveform returned by torchaudio.load, shape (channels, samples)
            sr:  Sample rate
        Returns:
            vocal_wav:  np.ndarray, shape (channels, samples)
            inst_wav:   np.ndarray, shape (channels, samples)
            sr:         int, output sample rate
        """
        # Resample if needed
        if sr != self.sample_rate:
            wav = torchaudio.transforms.Resample(sr, self.sample_rate)(wav)
            sr = self.sample_rate

        # detach().cpu() makes this safe for CUDA tensors and tensors that
        # carry gradients; plain .numpy() would raise on either.
        mix = wav.detach().cpu().numpy()

        # Duplicate a mono channel when the model expects stereo input.
        if mix.shape[0] == 1 and getattr(self.config.audio, "num_channels", 1) == 2:
            mix = np.concatenate([mix, mix], axis=0)

        # Keep the unnormalized mix so the instrumental can be derived by
        # subtraction in the original amplitude domain.
        mix_orig = mix.copy()

        # Normalize only when the inference config asks for it.
        norm_params = None
        if getattr(self.config.inference, "normalize", False):
            mix, norm_params = normalize_audio(mix)

        # Run chunked model inference; returns a dict of stem -> waveform.
        waveforms = demix(
            self.config,
            self.model,
            mix,
            self.device,
            model_type=self.model_type,
            pbar=True,
        )

        # Prefer the "vocals" stem; fall back to the first stem the model emits.
        vocal_wav = waveforms.get("vocals", list(waveforms.values())[0])
        if norm_params is not None:
            vocal_wav = denormalize_audio(vocal_wav, norm_params)

        # Instrumental = original mix - vocals
        inst_wav = mix_orig - vocal_wav

        return vocal_wav, inst_wav, sr


# ---- Example Usage ----
if __name__ == "__main__":
    # Build the separator once; it holds the loaded model and config.
    separator = Separator(
        config_path="ckpts/config_vocals_mel_band_roformer_kj.yaml",
        checkpoint_path="ckpts/MelBandRoformer.ckpt",
        device="cuda:0",
    )

    # Load the mixture, split it, and write both stems back out.
    mixture, sample_rate = torchaudio.load("path/to/input.mp3")
    vocals, instrumental, sample_rate = separator.separate(mixture, sample_rate)

    torchaudio.save("output_vocals.wav", torch.from_numpy(vocals), sample_rate)
    torchaudio.save("output_instrumental.wav", torch.from_numpy(instrumental), sample_rate)