# NOTE: extraction residue from the hosting page, not part of the program: "xjsc0's picture" / "1" / "64ec292"
# coding: utf-8
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
import glob
import os
import sys
import time
import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from tqdm.auto import tqdm
# Put this file's own directory on sys.path so the local `utils` package
# imported below can be resolved. Per the original author's note this is needed
# for embedded Python distributions (presumably because they do not add the
# script directory to sys.path on their own -- confirm if this matters to you).
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
import warnings
from utils.audio_utils import denormalize_audio, draw_spectrogram, normalize_audio
from utils.model_utils import (
apply_tta,
demix,
load_start_checkpoint,
prefer_target_instrument,
)
from utils.settings import get_model_from_config, parse_args_inference
# Silence every warning category for the rest of the process. This is installed
# after the imports above, so it only affects warnings raised from here on.
warnings.filterwarnings("ignore")
def run_folder(
    model: "torch.nn.Module",
    args: "argparse.Namespace",
    config: dict,
    device: "torch.device",
    verbose: bool = False,
) -> None:
    """
    Separate every audio file found under ``args.input_folder`` and write the stems.

    Files are discovered recursively (only names containing a dot are matched),
    loaded at the configured sample rate, optionally normalized, separated with
    ``demix`` (plus ``apply_tta`` when requested) and written below
    ``args.store_dir`` using ``args.filename_template``. Tracks that cannot be
    read are reported and skipped.

    Parameters:
    ----------
    model : torch.nn.Module
        Pre-trained model for source separation. It is switched to eval mode.
    args : argparse.Namespace
        Run options. The attributes read here are ``input_folder``,
        ``store_dir``, ``disable_detailed_pbar``, ``model_type``, ``use_tta``,
        ``extract_instrumental``, ``pcm_type``, ``filename_template``,
        ``start_check_point`` and ``draw_spectro``.
    config : dict
        Configuration object; ``config.audio`` and ``config.inference`` are
        accessed both by attribute and by key.
    device : torch.device
        Device handed to ``demix`` / ``apply_tta``.
    verbose : bool, optional
        When False the file list is wrapped in an overall tqdm progress bar.
        Default is False.

    Returns:
    -------
    None
        Results are written to disk; timing information is printed.
    """
    start_time = time.time()
    model.eval()

    # Recursively collect all regular files from the input directory.
    search_pattern = os.path.join(args.input_folder, "**/*.*")
    mixture_paths = [
        p for p in sorted(glob.glob(search_pattern, recursive=True)) if os.path.isfile(p)
    ]

    sample_rate: int = getattr(config.audio, "sample_rate", 44100)
    print(f"Total files found: {len(mixture_paths)}. Using sample rate: {sample_rate}")

    # Copy the list: "instrumental" may be appended to it further down.
    instruments: list[str] = prefer_target_instrument(config)[:]
    os.makedirs(args.store_dir, exist_ok=True)

    # Overall progress bar only when the caller did not ask for verbose output.
    if not verbose:
        mixture_paths = tqdm(mixture_paths, desc="Total progress")

    detailed_pbar = not args.disable_detailed_pbar

    # Loop-invariant: checkpoint stem used by the filename template. Falling
    # back to "" keeps a run without a checkpoint from crashing in basename().
    model_name: str = os.path.splitext(
        os.path.basename(args.start_check_point or "")
    )[0]

    for path in mixture_paths:
        # Location of the track relative to the input folder, so the template
        # can mirror the input directory layout.
        relative_path: str = os.path.relpath(path, args.input_folder)
        dir_name: str = os.path.dirname(relative_path)
        file_name: str = os.path.splitext(os.path.basename(path))[0]

        try:
            mix, sr = librosa.load(path, sr=sample_rate, mono=False)
        except Exception as e:
            # Best effort: report the unreadable file and move on.
            print(f"Cannot read track: {format(path)}")
            print(f"Error message: {str(e)}")
            continue

        # A 1-D result is mono: add a channel axis and, if the model is
        # configured for two channels, duplicate it.
        if mix.ndim == 1:
            mix = np.expand_dims(mix, axis=0)
            if "num_channels" in config.audio and config.audio["num_channels"] == 2:
                print("Convert mono track to stereo...")
                mix = np.concatenate([mix, mix], axis=0)

        # Untouched copy of the input, used for the instrumental residual.
        mix_orig = mix.copy()

        # Normalize the model input if enabled in the config.
        normalize: bool = (
            "normalize" in config.inference and config.inference["normalize"] is True
        )
        norm_params = None
        if normalize:
            mix, norm_params = normalize_audio(mix)

        # Perform source separation, then optional test-time augmentation.
        waveforms_orig = demix(
            config, model, mix, device, model_type=args.model_type, pbar=detailed_pbar
        )
        if args.use_tta:
            waveforms_orig = apply_tta(
                config, model, mix, waveforms_orig, device, args.model_type
            )

        # Estimates already mapped back to the scale of the original audio.
        restored = {}
        residual = None
        if args.extract_instrumental:
            reference_instr = "vocals" if "vocals" in instruments else instruments[0]
            reference = waveforms_orig[reference_instr]
            # The model saw the normalized mix, so its output has to be
            # denormalized BEFORE it is subtracted from the original mix.
            # Subtracting first and denormalizing the difference mixes two
            # different scales.
            if normalize:
                reference = denormalize_audio(reference, norm_params)
            restored[reference_instr] = reference
            residual = mix_orig - reference
            if "instrumental" not in instruments:
                instruments.append("instrumental")

        for instr in instruments:
            if residual is not None and instr == "instrumental":
                # Already expressed in the original scale; nothing to undo.
                estimates = residual
            elif instr in restored:
                estimates = restored[instr]
            else:
                estimates = waveforms_orig[instr]
                if normalize:
                    estimates = denormalize_audio(estimates, norm_params)

            # FLAC is only chosen for non-float output that stays within
            # [-1, 1]; everything else goes to WAV.
            peak: float = float(np.abs(estimates).max())
            codec = "flac" if peak <= 1.0 and args.pcm_type != "FLOAT" else "wav"
            subtype = args.pcm_type

            # Resolve the output location from the user-supplied template.
            dirnames, fname = format_filename(
                args.filename_template,
                instr=instr,
                start_time=int(start_time),
                file_name=file_name,
                dir_name=dir_name,
                model_type=args.model_type,
                model=model_name,
            )
            output_dir: str = os.path.join(args.store_dir, *dirnames)
            os.makedirs(output_dir, exist_ok=True)

            output_path: str = os.path.join(output_dir, f"{fname}.{codec}")
            sf.write(output_path, estimates.T, sr, subtype=subtype)

            # Draw and save a spectrogram if enabled.
            if args.draw_spectro > 0:
                output_img_path = os.path.join(output_dir, f"{fname}.jpg")
                draw_spectrogram(estimates.T, sr, args.draw_spectro, output_img_path)
                print("Wrote file:", output_img_path)

    print(f"Elapsed time: {time.time() - start_time:.2f} seconds.")
def format_filename(template, **kwargs):
    """
    Format an output path from a template, e.g. "{file_name}/{instr}".

    Every ``{key}`` placeholder whose ``key`` was passed as a keyword argument
    is replaced by ``str(value)``; placeholders without a matching keyword are
    left untouched. Substitution happens in a single pass, so text that comes
    from a value (for example a track whose name literally contains
    "{instr}") is never expanded a second time.

    Using slashes ('/') in the template results in directories being created
    by the caller: the formatted string is split on '/'.

    Parameters:
    ----------
    template : str
        Template containing ``{key}`` placeholders.
    **kwargs
        Placeholder values, converted with ``str``.

    Returns:
    -------
    tuple
        ``(dirnames, fname)`` where ``dirnames`` is a list of directory name
        components (possibly empty) and ``fname`` is the final component.
    """
    # Function-scope import keeps this block self-contained.
    import re

    def _substitute(match):
        key = match.group(1)
        # Unknown placeholders are kept verbatim.
        return str(kwargs[key]) if key in kwargs else match.group(0)

    result = re.sub(r"\{(\w+)\}", _substitute, template)
    *dirnames, fname = result.split("/")
    return dirnames, fname
def proc_folder(dict_args):
    """
    Build the model from the parsed arguments and run separation on a folder.

    Steps: parse the arguments, pick a device, construct the model from its
    config, optionally load a checkpoint, optionally wrap the model in
    ``nn.DataParallel``, move it to the device and call ``run_folder`` in
    verbose mode.

    Parameters:
    ----------
    dict_args
        Forwarded unchanged to ``parse_args_inference``.
    """
    args = parse_args_inference(dict_args)

    # Device priority: forced CPU, then CUDA, then Apple MPS, otherwise CPU.
    device = "cpu"
    if not args.force_cpu:
        if torch.cuda.is_available():
            print("CUDA is available, use --force_cpu to disable it.")
            gpu_ids = args.device_ids
            # With a list of ids the first one is the primary device.
            primary_gpu = gpu_ids[0] if isinstance(gpu_ids, list) else gpu_ids
            device = f"cuda:{primary_gpu}"
        elif torch.backends.mps.is_available():
            device = "mps"
    print("Using device: ", device)

    load_started = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(args.model_type, args.config_path)
    # The model type declared in the config overrides the one from the arguments.
    if "model_type" in config.training:
        args.model_type = config.training.model_type

    checkpoint_path = args.start_check_point
    if checkpoint_path:
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, so only load checkpoints from sources you trust.
        checkpoint = torch.load(
            checkpoint_path, map_location="cpu", weights_only=False
        )
        load_start_checkpoint(args, model, checkpoint, type_="inference")

    print("Instruments: {}".format(config.training.instruments))

    # Several CUDA ids passed via --device_ids and CPU not forced: replicate
    # the model across those devices.
    wants_multi_gpu = (
        isinstance(args.device_ids, list)
        and len(args.device_ids) > 1
        and not args.force_cpu
    )
    if wants_multi_gpu:
        model = nn.DataParallel(model, device_ids=args.device_ids)

    model = model.to(device)
    print("Model load time: {:.2f} sec".format(time.time() - load_started))

    run_folder(model, args, config, device, verbose=True)
# Script entry point. None is forwarded to parse_args_inference
# (presumably making it read the options from the command line -- confirm
# against utils.settings).
if __name__ == "__main__":
    proc_folder(None)