Spaces:
Running on Zero
| # coding: utf-8 | |
| __author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/" | |
| __version__ = "1.0.3" | |
| # Read more here: | |
| # https://huggingface.co/docs/accelerate/index | |
| import argparse | |
| import glob | |
| import os | |
| import time | |
| import warnings | |
| import auraloss | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import wandb | |
| from accelerate import Accelerator | |
| from torch.optim import SGD, Adam, AdamW, RAdam, RMSprop | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau | |
| from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn | |
| from torch.utils.data import DataLoader | |
| from tqdm.auto import tqdm | |
| from utils.dataset import MSSDataset | |
| from utils.losses import masked_loss | |
| from utils.metrics import sdr | |
| from utils.model_utils import ( | |
| demix, | |
| load_not_compatible_weights, | |
| prefer_target_instrument, | |
| ) | |
| from utils.settings import get_model_from_config, manual_seed | |
| warnings.filterwarnings("ignore") | |
| def valid(model, valid_loader, args, config, device, verbose=False): | |
| instruments = prefer_target_instrument(config) | |
| all_sdr = dict() | |
| for instr in instruments: | |
| all_sdr[instr] = [] | |
| all_mixtures_path = valid_loader | |
| if verbose: | |
| all_mixtures_path = tqdm(valid_loader) | |
| pbar_dict = {} | |
| for path_list in all_mixtures_path: | |
| path = path_list[0] | |
| mix, sr = sf.read(path) | |
| folder = os.path.dirname(path) | |
| res = demix(config, model, mix.T, device, model_type=args.model_type) # mix.T | |
| for instr in instruments: | |
| if instr != "other" or config.training.other_fix is False: | |
| track, sr1 = sf.read(folder + "/{}.wav".format(instr)) | |
| else: | |
| # other is actually instrumental | |
| track, sr1 = sf.read(folder + "/{}.wav".format("vocals")) | |
| track = mix - track | |
| # sf.write("{}.wav".format(instr), res[instr].T, sr, subtype='FLOAT') | |
| references = np.expand_dims(track, axis=0) | |
| estimates = np.expand_dims(res[instr].T, axis=0) | |
| sdr_val = sdr(references, estimates)[0] | |
| single_val = torch.from_numpy(np.array([sdr_val])).to(device) | |
| all_sdr[instr].append(single_val) | |
| pbar_dict["sdr_{}".format(instr)] = sdr_val | |
| if verbose: | |
| all_mixtures_path.set_postfix(pbar_dict) | |
| return all_sdr | |
class MSSValidationDataset(torch.utils.data.Dataset):
    """Dataset of paths to ``mixture.wav`` files under the validation folders.

    Each folder in ``args.valid_path`` is expected to contain one
    subdirectory per track, each with a ``mixture.wav`` inside. Items are
    plain path strings, collected in sorted order per folder.
    """

    def __init__(self, args):
        collected = []
        for root in args.valid_path:
            found = sorted(glob.glob(root + "/*/mixture.wav"))
            if not found:
                print("No validation data found in: {}".format(root))
            collected.extend(found)
        self.list_of_files = collected

    def __len__(self):
        return len(self.list_of_files)

    def __getitem__(self, index):
        return self.list_of_files[index]
def train_model(args):
    """Train a music source separation model using HuggingFace Accelerate.

    Parses command-line options (from ``args`` when given, otherwise from
    ``sys.argv``), builds the datasets/model/optimizer described by the
    config file, optionally restores weights from a checkpoint, then runs
    the train/validate loop. The latest weights are written every epoch and
    an extra checkpoint is stored whenever the average validation SDR
    improves.

    Parameters
    ----------
    args : list[str] | None
        Argument list for ``argparse``; ``None`` parses ``sys.argv``.
    """

    def _str2bool(value):
        # Fix: argparse's ``type=bool`` converts ANY non-empty string
        # (including "False") to True; parse textual booleans explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ("yes", "true", "t", "y", "1"):
            return True
        if lowered in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            "Boolean value expected, got: {}".format(value)
        )

    accelerator = Accelerator()
    device = accelerator.device
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        type=str,
        default="mdx23c",
        help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit",
    )
    parser.add_argument("--config_path", type=str, help="path to config file")
    parser.add_argument(
        "--start_check_point",
        type=str,
        default="",
        help="Initial checkpoint to start training",
    )
    parser.add_argument(
        "--results_path",
        type=str,
        help="path to folder where results will be stored (weights, metadata)",
    )
    parser.add_argument(
        "--data_path",
        nargs="+",
        type=str,
        help="Dataset data paths. You can provide several folders.",
    )
    parser.add_argument(
        "--dataset_type",
        type=int,
        default=1,
        help="Dataset type. Must be one of: 1, 2, 3 or 4. Details here: https://github.com/ZFTurbo/Music-Source-Separation-Training/blob/main/docs/dataset_types.md",
    )
    parser.add_argument(
        "--valid_path",
        nargs="+",
        type=str,
        help="validation data paths. You can provide several folders.",
    )
    parser.add_argument(
        "--num_workers", type=int, default=0, help="dataloader num_workers"
    )
    # Was ``type=bool``, which made "--pin_memory False" evaluate to True.
    parser.add_argument(
        "--pin_memory", type=_str2bool, default=False, help="dataloader pin_memory"
    )
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument(
        "--device_ids", nargs="+", type=int, default=[0], help="list of gpu ids"
    )
    parser.add_argument(
        "--use_multistft_loss",
        action="store_true",
        help="Use MultiSTFT Loss (from auraloss package)",
    )
    parser.add_argument(
        "--use_mse_loss", action="store_true", help="Use default MSE loss"
    )
    parser.add_argument("--use_l1_loss", action="store_true", help="Use L1 loss")
    parser.add_argument("--wandb_key", type=str, default="", help="wandb API Key")
    parser.add_argument(
        "--pre_valid", action="store_true", help="Run validation before training"
    )
    if args is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(args)

    # Time-dependent seed: each run gets different data augmentation order.
    manual_seed(args.seed + int(time.time()))
    # torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = (
        False  # Fix possible slow down with dilation convolutions
    )
    torch.multiprocessing.set_start_method("spawn")

    model, config = get_model_from_config(args.model_type, args.config_path)
    # Config may override the CLI-supplied model type.
    if "model_type" in config.training:
        args.model_type = config.training.model_type
    accelerator.print("Instruments: {}".format(config.training.instruments))
    os.makedirs(args.results_path, exist_ok=True)
    device_ids = args.device_ids
    batch_size = config.training.batch_size

    # wandb: only the main process with a key logs; everyone else gets a
    # disabled run so wandb.log() calls are safe no-ops.
    if (
        accelerator.is_main_process
        and args.wandb_key is not None
        and args.wandb_key.strip() != ""
    ):
        wandb.login(key=args.wandb_key)
        wandb.init(
            project="msst-accelerate",
            config={
                "config": config,
                "args": args,
                "device_ids": device_ids,
                "batch_size": batch_size,
            },
        )
    else:
        wandb.init(mode="disabled")

    # Fix for num of steps: keep the effective epoch length constant when
    # the dataset is sharded across processes.
    config.training.num_steps *= accelerator.num_processes

    trainset = MSSDataset(
        config,
        args.data_path,
        batch_size=batch_size,
        metadata_path=os.path.join(
            args.results_path, "metadata_{}.pkl".format(args.dataset_type)
        ),
        dataset_type=args.dataset_type,
        verbose=accelerator.is_main_process,
    )
    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
    )

    validset = MSSValidationDataset(args)
    # Remember the true dataset size: the sharded loader may pad batches, so
    # gathered SDR arrays are truncated to this length before averaging.
    valid_dataset_length = len(validset)
    valid_loader = DataLoader(
        validset,
        batch_size=1,
        shuffle=False,
    )
    valid_loader = accelerator.prepare(valid_loader)

    if args.start_check_point != "":
        accelerator.print("Start from checkpoint: {}".format(args.start_check_point))
        # Tolerant loader: copies tensors whose names/shapes match and skips
        # the rest, so checkpoints from slightly different configs still load.
        # (Strict alternative: model.load_state_dict(torch.load(path)).)
        load_not_compatible_weights(model, args.start_check_point, verbose=False)

    optim_params = dict()
    if "optimizer" in config:
        optim_params = dict(config["optimizer"])
        accelerator.print("Optimizer params from config:\n{}".format(optim_params))

    if config.training.optimizer == "adam":
        optimizer = Adam(model.parameters(), lr=config.training.lr, **optim_params)
    elif config.training.optimizer == "adamw":
        optimizer = AdamW(model.parameters(), lr=config.training.lr, **optim_params)
    elif config.training.optimizer == "radam":
        optimizer = RAdam(model.parameters(), lr=config.training.lr, **optim_params)
    elif config.training.optimizer == "rmsprop":
        optimizer = RMSprop(model.parameters(), lr=config.training.lr, **optim_params)
    elif config.training.optimizer == "prodigy":
        from prodigyopt import Prodigy

        # you can choose weight decay value based on your problem, 0 by default
        # We recommend using lr=1.0 (default) for all networks.
        optimizer = Prodigy(model.parameters(), lr=config.training.lr, **optim_params)
    elif config.training.optimizer == "adamw8bit":
        import bitsandbytes as bnb

        optimizer = bnb.optim.AdamW8bit(
            model.parameters(), lr=config.training.lr, **optim_params
        )
    elif config.training.optimizer == "sgd":
        accelerator.print("Use SGD optimizer")
        optimizer = SGD(model.parameters(), lr=config.training.lr, **optim_params)
    else:
        accelerator.print("Unknown optimizer: {}".format(config.training.optimizer))
        exit()

    if accelerator.is_main_process:
        print("Processes GPU: {}".format(accelerator.num_processes))
        print(
            "Patience: {} Reduce factor: {} Batch size: {} Optimizer: {}".format(
                config.training.patience,
                config.training.reduce_factor,
                batch_size,
                config.training.optimizer,
            )
        )

    # Reduce LR if no SDR improvements for several epochs
    scheduler = ReduceLROnPlateau(
        optimizer,
        "max",
        # patience=accelerator.num_processes * config.training.patience, # This is strange place...
        patience=config.training.patience,
        factor=config.training.reduce_factor,
    )

    if args.use_multistft_loss:
        try:
            loss_options = dict(config.loss_multistft)
        except Exception:
            # Config has no loss_multistft section: use auraloss defaults.
            loss_options = dict()
        accelerator.print("Loss options: {}".format(loss_options))
        loss_multistft = auraloss.freq.MultiResolutionSTFTLoss(**loss_options)

    model, optimizer, train_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, scheduler
    )

    # Optional exponential moving average of the weights; the EMA copy is
    # what gets validated and saved when enabled.
    ema_model = None
    if hasattr(config.training, "ema_momentum") and config.training.ema_momentum > 0:
        accelerator.print(
            f"Initializing EMA with decay: {config.training.ema_momentum}"
        )
        ema_model = AveragedModel(
            accelerator.unwrap_model(model),
            multi_avg_fn=get_ema_multi_avg_fn(config.training.ema_momentum),
        )
        ema_model.to(device)

    if args.pre_valid:
        # Sanity-check SDR of the starting weights before any training.
        model_to_valid = ema_model if ema_model is not None else model
        sdr_list = valid(
            model_to_valid,
            valid_loader,
            args,
            config,
            device,
            verbose=accelerator.is_main_process,
        )
        sdr_list = accelerator.gather(sdr_list)
        accelerator.wait_for_everyone()
        sdr_avg = 0.0
        instruments = prefer_target_instrument(config)
        for instr in instruments:
            sdr_data = torch.cat(sdr_list[instr], dim=0).cpu().numpy()
            # Debug print of the raw (possibly padded) gathered mean.
            sdr_val = sdr_data.mean()
            accelerator.print("Valid length: {}".format(valid_dataset_length))
            accelerator.print(
                "Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data))
            )
            # Real mean over the true dataset length (drops gather padding).
            sdr_val = sdr_data[:valid_dataset_length].mean()
            accelerator.print(
                "Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data))
            )
            sdr_avg += sdr_val
        sdr_avg /= len(instruments)
        if len(instruments) > 1:
            accelerator.print("SDR Avg: {:.4f}".format(sdr_avg))
        sdr_list = None

    accelerator.print("Train for: {}".format(config.training.num_epochs))
    best_sdr = -100
    for epoch in range(config.training.num_epochs):
        model.train().to(device)
        accelerator.print(
            "Train epoch: {} Learning rate: {}".format(
                epoch, optimizer.param_groups[0]["lr"]
            )
        )
        loss_val = 0.0
        total = 0
        pbar = tqdm(train_loader, disable=not accelerator.is_main_process)
        for i, (batch, mixes) in enumerate(pbar):
            y = batch  # target stems
            x = mixes  # input mixtures
            if args.model_type in [
                "mel_band_roformer",
                "bs_roformer",
                "bs_mamba2",
                "mel_band_conformer",
                "bs_conformer",
            ]:
                # These models compute the loss inside their forward pass.
                loss = model(x, y)
            else:
                y_ = model(x)
                if args.use_multistft_loss:
                    # Flatten (batch, stems, channels, time) so all stems are
                    # scored as extra channels by the multi-resolution STFT loss.
                    y1_ = torch.reshape(
                        y_, (y_.shape[0], y_.shape[1] * y_.shape[2], y_.shape[3])
                    )
                    y1 = torch.reshape(
                        y, (y.shape[0], y.shape[1] * y.shape[2], y.shape[3])
                    )
                    loss = loss_multistft(y1_, y1)
                    # We can use many losses at the same time
                    if args.use_mse_loss:
                        loss += 1000 * nn.MSELoss()(y1_, y1)
                    if args.use_l1_loss:
                        loss += 1000 * F.l1_loss(y1_, y1)
                elif args.use_mse_loss:
                    loss = nn.MSELoss()(y_, y)
                elif args.use_l1_loss:
                    loss = F.l1_loss(y_, y)
                else:
                    # Default: quantile-clipped masked loss from utils.losses.
                    loss = masked_loss(
                        y_,
                        y,
                        q=config.training.q,
                        coarse=config.training.coarse_loss_clip,
                    )

            accelerator.backward(loss)
            if config.training.grad_clip:
                accelerator.clip_grad_norm_(
                    model.parameters(), config.training.grad_clip
                )
            optimizer.step()
            optimizer.zero_grad()
            if ema_model is not None:
                ema_model.update_parameters(accelerator.unwrap_model(model))

            li = loss.item()
            loss_val += li
            total += 1
            if accelerator.is_main_process:
                wandb.log(
                    {
                        "loss": 100 * li,
                        "avg_loss": 100 * loss_val / (i + 1),
                        "total": total,
                        "loss_val": loss_val,
                        "i": i,
                    }
                )
            pbar.set_postfix(
                {"loss": 100 * li, "avg_loss": 100 * loss_val / (i + 1)}
            )

        if accelerator.is_main_process:
            print("Training loss: {:.6f}".format(loss_val / total))
            wandb.log({"train_loss": loss_val / total, "epoch": epoch})

        # Save last
        store_path = args.results_path + "/last_{}.ckpt".format(args.model_type)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            if ema_model is not None:
                accelerator.save(ema_model.module.state_dict(), store_path)
            else:
                unwrapped_model = accelerator.unwrap_model(model)
                accelerator.save(unwrapped_model.state_dict(), store_path)

        # Validation
        model_to_valid = ema_model if ema_model is not None else model
        sdr_list = valid(
            model_to_valid,
            valid_loader,
            args,
            config,
            device,
            verbose=accelerator.is_main_process,
        )
        sdr_list = accelerator.gather(sdr_list)
        accelerator.wait_for_everyone()
        sdr_avg = 0.0
        instruments = prefer_target_instrument(config)
        for instr in instruments:
            sdr_data = torch.cat(sdr_list[instr], dim=0).cpu().numpy()
            # Average only over the true dataset length (drops gather padding).
            sdr_val = sdr_data[:valid_dataset_length].mean()
            if accelerator.is_main_process:
                print(
                    "Instr SDR {}: {:.4f} Debug: {}".format(
                        instr, sdr_val, len(sdr_data)
                    )
                )
                wandb.log({f"{instr}_sdr": sdr_val})
            sdr_avg += sdr_val
        sdr_avg /= len(instruments)
        if len(instruments) > 1:
            if accelerator.is_main_process:
                print("SDR Avg: {:.4f}".format(sdr_avg))
        wandb.log({"sdr_avg": sdr_avg, "best_sdr": best_sdr})

        if accelerator.is_main_process:
            if sdr_avg > best_sdr:
                store_path = (
                    args.results_path
                    + "/model_{}_ep_{}_sdr_{:.4f}.ckpt".format(
                        args.model_type, epoch, sdr_avg
                    )
                )
                print("Store weights: {}".format(store_path))
                if ema_model is not None:
                    accelerator.save(ema_model.module.state_dict(), store_path)
                else:
                    unwrapped_model = accelerator.unwrap_model(model)
                    accelerator.save(unwrapped_model.state_dict(), store_path)
                best_sdr = sdr_avg

        # Step the plateau scheduler on validation SDR (mode="max").
        scheduler.step(sdr_avg)
        sdr_list = None
        accelerator.wait_for_everyone()
# Script entry point: args=None makes train_model parse sys.argv.
if __name__ == "__main__":
    train_model(None)