Spaces:
Runtime error
Runtime error
| # coding: utf-8 | |
| __author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/" | |
| import glob | |
| import os | |
| import sys | |
| import time | |
| import librosa | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import torch.nn as nn | |
| from tqdm.auto import tqdm | |
| # Add this file's directory to sys.path so the local `utils` package can be imported even when running under an embedded Python distribution. | |
| current_dir = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.append(current_dir) | |
| import warnings | |
| from utils.audio_utils import denormalize_audio, draw_spectrogram, normalize_audio | |
| from utils.model_utils import ( | |
| apply_tta, | |
| demix, | |
| load_start_checkpoint, | |
| prefer_target_instrument, | |
| ) | |
| from utils.settings import get_model_from_config, parse_args_inference | |
| warnings.filterwarnings("ignore") | |
def run_folder(
    model: "torch.nn.Module",
    args: "argparse.Namespace",
    config: dict,
    device: "torch.device",
    verbose: bool = False,
) -> None:
    """
    Process a folder of audio files for source separation.

    Files are collected recursively from ``args.input_folder``, loaded at the
    configured sample rate, separated with ``demix`` (plus ``apply_tta`` when
    ``args.use_tta`` is set) and every resulting stem is written below
    ``args.store_dir`` following ``args.filename_template``. Files that cannot
    be read are reported and skipped.

    Parameters:
    ----------
    model : torch.nn.Module
        Pre-trained model for source separation. It is put into eval mode.
    args : argparse.Namespace
        Arguments. Attributes read here: ``input_folder``, ``store_dir``,
        ``disable_detailed_pbar``, ``model_type``, ``use_tta``,
        ``extract_instrumental``, ``pcm_type``, ``filename_template``,
        ``start_check_point`` and ``draw_spectro``.
    config : dict
        Configuration object with audio and inference settings. It is read
        through attribute access (``config.audio``, ``config.inference``).
    device : torch.device
        Device for model inference; forwarded to ``demix`` / ``apply_tta``.
    verbose : bool, optional
        If False, the file loop is wrapped in a tqdm progress bar.
        Default is False.

    Returns:
    -------
    None
        Results are written to disk; timing information is printed.
    """
    start_time = time.time()
    model.eval()

    # Recursively collect all files from input directory
    candidates = sorted(
        glob.glob(os.path.join(args.input_folder, "**/*.*"), recursive=True)
    )
    mixture_paths = [p for p in candidates if os.path.isfile(p)]

    sample_rate: int = getattr(config.audio, "sample_rate", 44100)
    print(f"Total files found: {len(mixture_paths)}. Using sample rate: {sample_rate}")

    instruments: list[str] = prefer_target_instrument(config)[:]
    os.makedirs(args.store_dir, exist_ok=True)

    # Settings that do not change from one file to the next
    normalize: bool = (
        "normalize" in config.inference and config.inference["normalize"] is True
    )
    mono_to_stereo: bool = (
        "num_channels" in config.audio and config.audio["num_channels"] == 2
    )
    detailed_pbar: bool = not args.disable_detailed_pbar
    # start_check_point may be unset; basename(None) would raise TypeError
    model_name: str = os.path.splitext(
        os.path.basename(args.start_check_point or "")
    )[0]

    # Stems written per file: the model's instruments plus, when requested,
    # the derived "instrumental" (appended only if the model lacks it)
    output_instruments: list[str] = list(instruments)
    if args.extract_instrumental and "instrumental" not in output_instruments:
        output_instruments.append("instrumental")

    # Wrap paths with progress bar if not in verbose mode
    if not verbose:
        mixture_paths = tqdm(mixture_paths, desc="Total progress")

    for path in mixture_paths:
        # Path pieces relative to the input folder, used by the filename template
        relative_path: str = os.path.relpath(path, args.input_folder)
        dir_name: str = os.path.dirname(relative_path)
        file_name: str = os.path.splitext(os.path.basename(path))[0]

        try:
            mix, sr = librosa.load(path, sr=sample_rate, mono=False)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole batch
            print(f"Cannot read track: {path}")
            print(f"Error message: {e}")
            continue

        # Convert mono audio to expected channel format if needed
        if len(mix.shape) == 1:
            mix = np.expand_dims(mix, axis=0)
            if mono_to_stereo:
                print("Convert mono track to stereo...")
                mix = np.concatenate([mix, mix], axis=0)

        # Untouched copy of the input, used to derive the instrumental
        mix_orig = mix.copy()

        # Normalize input audio if enabled
        if normalize:
            mix, norm_params = normalize_audio(mix)

        # Perform source separation
        waveforms_orig = demix(
            config, model, mix, device, model_type=args.model_type, pbar=detailed_pbar
        )

        # Apply test-time augmentation if enabled
        if args.use_tta:
            waveforms_orig = apply_tta(
                config, model, mix, waveforms_orig, device, args.model_type
            )

        # Undo the normalization on the model outputs BEFORE deriving the
        # instrumental. Subtracting a still-normalized stem from the
        # un-normalized mix (and denormalizing the difference afterwards)
        # produces a wrongly scaled instrumental.
        if normalize:
            for name in instruments:
                waveforms_orig[name] = denormalize_audio(
                    waveforms_orig[name], norm_params
                )

        # Extract instrumental track if requested
        if args.extract_instrumental:
            reference = "vocals" if "vocals" in instruments else instruments[0]
            waveforms_orig["instrumental"] = mix_orig - waveforms_orig[reference]

        for instr in output_instruments:
            estimates = waveforms_orig[instr]

            # FLAC is only chosen when the signal peak is within [-1, 1] and a
            # non-float PCM subtype was requested; otherwise fall back to WAV
            peak: float = float(np.abs(estimates).max())
            codec = "flac" if peak <= 1.0 and args.pcm_type != "FLOAT" else "wav"
            subtype = args.pcm_type

            # Generate output directory structure using relative paths
            dirnames, fname = format_filename(
                args.filename_template,
                instr=instr,
                start_time=int(start_time),
                file_name=file_name,
                dir_name=dir_name,
                model_type=args.model_type,
                model=model_name,
            )

            # Create output directory
            output_dir: str = os.path.join(args.store_dir, *dirnames)
            os.makedirs(output_dir, exist_ok=True)

            output_path: str = os.path.join(output_dir, f"{fname}.{codec}")
            sf.write(output_path, estimates.T, sr, subtype=subtype)

            # Draw and save spectrogram if enabled
            if args.draw_spectro > 0:
                output_img_path = os.path.join(output_dir, f"{fname}.jpg")
                draw_spectrogram(estimates.T, sr, args.draw_spectro, output_img_path)
                print("Wrote file:", output_img_path)

    print(f"Elapsed time: {time.time() - start_time:.2f} seconds.")


def format_filename(template, **kwargs):
    """
    Expand ``{name}`` placeholders in a template and split it into path parts.

    Each keyword argument replaces the literal text ``{<keyword>}`` with
    ``str(value)``; placeholders without a matching keyword stay untouched.
    The expanded text is then split on '/': everything before the last slash
    becomes the list of directory names, the remainder is the file name.
    e.g. "{file_name}/{instr}" -> (["song"], "vocals")

    Returns a ``(dirnames, fname)`` pair: a list of directory names and a
    single file name.
    """
    expanded = template
    for key in kwargs:
        placeholder = "{" + key + "}"
        expanded = expanded.replace(placeholder, str(kwargs[key]))

    parts = expanded.split("/")
    return parts[:-1], parts[-1]


def proc_folder(dict_args):
    """
    Parse inference arguments, build and load the model, then run separation.

    Steps: parse ``dict_args`` with ``parse_args_inference``, pick the compute
    device (CPU when forced, otherwise CUDA, then MPS, then CPU), create the
    model from its config, optionally load a start checkpoint, wrap the model
    in ``nn.DataParallel`` when several device ids are given, and hand
    everything to ``run_folder``.

    Parameters:
    ----------
    dict_args :
        Passed unchanged to ``parse_args_inference``.
    """
    args = parse_args_inference(dict_args)

    # Device priority: forced CPU > CUDA > MPS > CPU fallback
    if args.force_cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        print("CUDA is available, use --force_cpu to disable it.")
        if isinstance(args.device_ids, list):
            device = f"cuda:{args.device_ids[0]}"
        else:
            device = f"cuda:{args.device_ids}"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print("Using device: ", device)

    load_started_at = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(args.model_type, args.config_path)

    # A model type declared in the config overrides the one from the arguments
    if "model_type" in config.training:
        args.model_type = config.training.model_type

    if args.start_check_point:
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects; only load checkpoints from trusted sources.
        checkpoint = torch.load(
            args.start_check_point, weights_only=False, map_location="cpu"
        )
        load_start_checkpoint(args, model, checkpoint, type_="inference")

    print(f"Instruments: {config.training.instruments}")

    # in case multiple CUDA GPUs are used and --device_ids arg is passed
    several_devices = isinstance(args.device_ids, list) and len(args.device_ids) > 1
    if several_devices and not args.force_cpu:
        model = nn.DataParallel(model, device_ids=args.device_ids)

    model = model.to(device)
    print(f"Model load time: {time.time() - load_started_at:.2f} sec")

    run_folder(model, args, config, device, verbose=True)


# Script entry point: run inference only when executed directly, not on import.
# None is forwarded to parse_args_inference — presumably it then reads the
# command line; TODO confirm against utils.settings.
if __name__ == "__main__":
    proc_folder(None)