# xjsc0's picture
# 1
# 64ec292
# coding: utf-8
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
__version__ = "1.0.3"
# Read more here:
# https://huggingface.co/docs/accelerate/index
import argparse
import glob
import os
import time
import warnings
import auraloss
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from accelerate import Accelerator
from torch.optim import SGD, Adam, AdamW, RAdam, RMSprop
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from utils.dataset import MSSDataset
from utils.losses import masked_loss
from utils.metrics import sdr
from utils.model_utils import (
demix,
load_not_compatible_weights,
prefer_target_instrument,
)
from utils.settings import get_model_from_config, manual_seed
# Install a global "ignore" filter so that no warning of any category is
# shown for the rest of the process, including warnings raised by the
# third-party libraries imported above.
warnings.filterwarnings("ignore")
def valid(model, valid_loader, args, config, device, verbose=False):
    """Score ``model`` on every validation mixture and collect SDR values.

    For each item yielded by ``valid_loader`` the first element is taken as
    the path of a mixture file. The mixture is read with ``soundfile``,
    separated with ``demix`` and every separated stem is scored with ``sdr``
    against the reference ``<instrument>.wav`` stored in the same folder.

    Args:
        model: Network handed to ``demix``. If it is a ``torch.nn.Module`` it
            is put in eval mode for the duration of the call and returned to
            its previous mode afterwards.
        valid_loader: Iterable yielding sequences whose first element is the
            path to a mixture file.
        args: Parsed arguments; only ``args.model_type`` is read here.
        config: Model/training config. Passed to ``prefer_target_instrument``
            and ``demix``; ``config.training.other_fix`` is read directly.
        device: Device passed to ``demix`` and used to hold the SDR tensors.
        verbose: When true, wrap the loader in a tqdm bar that shows the most
            recent SDR per instrument.

    Returns:
        dict mapping each instrument name to a list with one single-element
        tensor (the SDR value) per processed mixture, located on ``device``.
    """
    instruments = prefer_target_instrument(config)
    all_sdr = {instr: [] for instr in instruments}
    all_mixtures_path = tqdm(valid_loader) if verbose else valid_loader

    # Validation must not run with training-only behaviour (dropout,
    # batch-norm statistic updates). Remember the current mode so the caller
    # gets the model back exactly as it passed it in.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    pbar_dict = {}
    try:
        # No gradients are needed for scoring; this avoids building autograd
        # graphs in case ``demix`` does not disable them itself.
        with torch.no_grad():
            for path_list in all_mixtures_path:
                path = path_list[0]
                mix, _ = sf.read(path)
                folder = os.path.dirname(path)
                # demix receives the transposed array returned by sf.read
                res = demix(
                    config, model, mix.T, device, model_type=args.model_type
                )
                for instr in instruments:
                    if instr != "other" or config.training.other_fix is False:
                        track, _ = sf.read(folder + "/{}.wav".format(instr))
                    else:
                        # other is actually instrumental: mixture minus vocals
                        track, _ = sf.read(folder + "/{}.wav".format("vocals"))
                        track = mix - track
                    # sdr() is called with a leading axis of size one and its
                    # first result is used.
                    references = np.expand_dims(track, axis=0)
                    estimates = np.expand_dims(res[instr].T, axis=0)
                    sdr_val = sdr(references, estimates)[0]
                    single_val = torch.from_numpy(np.array([sdr_val])).to(device)
                    all_sdr[instr].append(single_val)
                    pbar_dict["sdr_{}".format(instr)] = sdr_val
                if verbose:
                    all_mixtures_path.set_postfix(pbar_dict)
    finally:
        if was_training:
            model.train()
    return all_sdr
class MSSValidationDataset(torch.utils.data.Dataset):
    """Dataset whose items are paths to validation ``mixture.wav`` files.

    Every folder listed in ``args.valid_path`` is searched one level deep
    (``<folder>/*/mixture.wav``). Matches are sorted per folder and the
    per-folder results are concatenated in the order the folders were given.
    """

    def __init__(self, args):
        """Collect the mixture paths; print a notice for folders without any."""
        found = []
        for root in args.valid_path:
            matches = glob.glob(root + "/*/mixture.wav")
            matches.sort()
            if not matches:
                print("No validation data found in: {}".format(root))
            found.extend(matches)
        self.list_of_files = found

    def __len__(self):
        """Return the number of collected mixture paths."""
        return len(self.list_of_files)

    def __getitem__(self, index):
        """Return the mixture path stored at ``index``."""
        return self.list_of_files[index]
def _str2bool(value):
    """Convert a command-line token such as ``true``/``0``/``no`` to a bool."""
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "Boolean value expected, got: {!r}".format(value)
    )


def _build_arg_parser():
    """Create the command-line parser used by ``train_model``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        type=str,
        default="mdx23c",
        help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit",
    )
    parser.add_argument("--config_path", type=str, help="path to config file")
    parser.add_argument(
        "--start_check_point",
        type=str,
        default="",
        help="Initial checkpoint to start training",
    )
    parser.add_argument(
        "--results_path",
        type=str,
        help="path to folder where results will be stored (weights, metadata)",
    )
    parser.add_argument(
        "--data_path",
        nargs="+",
        type=str,
        help="Dataset data paths. You can provide several folders.",
    )
    parser.add_argument(
        "--dataset_type",
        type=int,
        default=1,
        help="Dataset type. Must be one of: 1, 2, 3 or 4. Details here: https://github.com/ZFTurbo/Music-Source-Separation-Training/blob/main/docs/dataset_types.md",
    )
    parser.add_argument(
        "--valid_path",
        nargs="+",
        type=str,
        help="validation data paths. You can provide several folders.",
    )
    parser.add_argument(
        "--num_workers", type=int, default=0, help="dataloader num_workers"
    )
    # ``type=bool`` would turn every non-empty string (including "False")
    # into True, so the value is parsed explicitly. A bare ``--pin_memory``
    # without a value is accepted as True.
    parser.add_argument(
        "--pin_memory",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="dataloader pin_memory",
    )
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument(
        "--device_ids", nargs="+", type=int, default=[0], help="list of gpu ids"
    )
    parser.add_argument(
        "--use_multistft_loss",
        action="store_true",
        help="Use MultiSTFT Loss (from auraloss package)",
    )
    parser.add_argument(
        "--use_mse_loss", action="store_true", help="Use default MSE loss"
    )
    parser.add_argument("--use_l1_loss", action="store_true", help="Use L1 loss")
    parser.add_argument("--wandb_key", type=str, default="", help="wandb API Key")
    parser.add_argument(
        "--pre_valid", action="store_true", help="Run validation before training"
    )
    return parser


def _create_optimizer(model, config, optim_params, accelerator):
    """Instantiate the optimizer named by ``config.training.optimizer``."""
    name = config.training.optimizer
    lr = config.training.lr
    if name == "adam":
        return Adam(model.parameters(), lr=lr, **optim_params)
    if name == "adamw":
        return AdamW(model.parameters(), lr=lr, **optim_params)
    if name == "radam":
        return RAdam(model.parameters(), lr=lr, **optim_params)
    if name == "rmsprop":
        return RMSprop(model.parameters(), lr=lr, **optim_params)
    if name == "prodigy":
        # Optional dependency, imported only when requested.
        from prodigyopt import Prodigy

        # you can choose weight decay value based on your problem, 0 by default
        # We recommend using lr=1.0 (default) for all networks.
        return Prodigy(model.parameters(), lr=lr, **optim_params)
    if name == "adamw8bit":
        # Optional dependency, imported only when requested.
        import bitsandbytes as bnb

        return bnb.optim.AdamW8bit(model.parameters(), lr=lr, **optim_params)
    if name == "sgd":
        accelerator.print("Use SGD optimizer")
        return SGD(model.parameters(), lr=lr, **optim_params)
    accelerator.print("Unknown optimizer: {}".format(name))
    # Exit with a failure status so launchers notice the misconfiguration.
    raise SystemExit(1)


def _compute_loss(model, x, y, args, config, loss_in_forward, loss_multistft):
    """Return the training loss for one batch according to the chosen flags."""
    if loss_in_forward:
        # loss is computed in forward pass
        return model(x, y)

    y_ = model(x)
    if args.use_multistft_loss:
        # Merge dimensions 1 and 2 of the 4-D prediction/target before
        # passing them to the multi-resolution STFT loss.
        y1_ = torch.reshape(
            y_, (y_.shape[0], y_.shape[1] * y_.shape[2], y_.shape[3])
        )
        y1 = torch.reshape(y, (y.shape[0], y.shape[1] * y.shape[2], y.shape[3]))
        loss = loss_multistft(y1_, y1)
        # We can use many losses at the same time
        if args.use_mse_loss:
            loss += 1000 * nn.MSELoss()(y1_, y1)
        if args.use_l1_loss:
            loss += 1000 * F.l1_loss(y1_, y1)
        return loss
    if args.use_mse_loss:
        return nn.MSELoss()(y_, y)
    if args.use_l1_loss:
        return F.l1_loss(y_, y)
    return masked_loss(
        y_,
        y,
        q=config.training.q,
        coarse=config.training.coarse_loss_clip,
    )


def _save_weights(accelerator, model, ema_model, store_path):
    """Save the EMA weights when available, otherwise the unwrapped model."""
    if ema_model is not None:
        state = ema_model.module.state_dict()
    else:
        state = accelerator.unwrap_model(model).state_dict()
    accelerator.save(state, store_path)


def _gather_validation_sdr(
    accelerator, model, ema_model, valid_loader, args, config, device
):
    """Run ``valid`` on this process and gather the SDR values of all processes."""
    model_to_valid = ema_model if ema_model is not None else model
    sdr_list = valid(
        model_to_valid,
        valid_loader,
        args,
        config,
        device,
        verbose=accelerator.is_main_process,
    )
    sdr_list = accelerator.gather(sdr_list)
    accelerator.wait_for_everyone()
    return sdr_list


def train_model(args):
    """Train a source-separation model with Hugging Face Accelerate.

    The function parses the command line, builds the model, datasets,
    optimizer and LR scheduler, optionally validates once before training,
    and then alternates one training epoch with one validation pass. After
    every epoch the latest weights are written to ``last_<model_type>.ckpt``
    and, whenever the average SDR improves, to a checkpoint whose name
    contains the epoch and SDR. When ``config.training.ema_momentum`` is set
    and positive, an EMA copy of the model is maintained, validated and saved
    instead of the raw model.

    Args:
        args: List of command-line tokens to parse, or ``None`` to let
            ``argparse`` read ``sys.argv``.

    Raises:
        SystemExit: If ``config.training.optimizer`` names an unknown
            optimizer (exit status 1), or on argparse errors.
    """
    accelerator = Accelerator()
    device = accelerator.device

    parser = _build_arg_parser()
    args = parser.parse_args() if args is None else parser.parse_args(args)

    # NOTE(review): the seed is offset by the wall clock, so ``--seed`` alone
    # does not make a run reproducible.
    manual_seed(args.seed + int(time.time()))
    # torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = (
        False  # Fix possible slow down with dilation convolutions
    )
    torch.multiprocessing.set_start_method("spawn")

    model, config = get_model_from_config(args.model_type, args.config_path)
    # A model type stored in the config overrides the command-line value.
    if "model_type" in config.training:
        args.model_type = config.training.model_type
    accelerator.print("Instruments: {}".format(config.training.instruments))
    os.makedirs(args.results_path, exist_ok=True)
    device_ids = args.device_ids
    batch_size = config.training.batch_size

    # wandb: real logging on the main process when a key was supplied,
    # otherwise a disabled run.
    if (
        accelerator.is_main_process
        and args.wandb_key is not None
        and args.wandb_key.strip() != ""
    ):
        wandb.login(key=args.wandb_key)
        wandb.init(
            project="msst-accelerate",
            config={
                "config": config,
                "args": args,
                "device_ids": device_ids,
                "batch_size": batch_size,
            },
        )
    else:
        wandb.init(mode="disabled")

    # Fix for num of steps
    config.training.num_steps *= accelerator.num_processes

    trainset = MSSDataset(
        config,
        args.data_path,
        batch_size=batch_size,
        metadata_path=os.path.join(
            args.results_path, "metadata_{}.pkl".format(args.dataset_type)
        ),
        dataset_type=args.dataset_type,
        verbose=accelerator.is_main_process,
    )
    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
    )

    validset = MSSValidationDataset(args)
    # Remembered so gathered SDR arrays can be cut back to the dataset size.
    valid_dataset_length = len(validset)
    valid_loader = DataLoader(
        validset,
        batch_size=1,
        shuffle=False,
    )
    valid_loader = accelerator.prepare(valid_loader)

    if args.start_check_point != "":
        accelerator.print("Start from checkpoint: {}".format(args.start_check_point))
        load_not_compatible_weights(model, args.start_check_point, verbose=False)

    optim_params = dict()
    if "optimizer" in config:
        optim_params = dict(config["optimizer"])
        accelerator.print("Optimizer params from config:\n{}".format(optim_params))

    optimizer = _create_optimizer(model, config, optim_params, accelerator)

    if accelerator.is_main_process:
        print("Processes GPU: {}".format(accelerator.num_processes))
        print(
            "Patience: {} Reduce factor: {} Batch size: {} Optimizer: {}".format(
                config.training.patience,
                config.training.reduce_factor,
                batch_size,
                config.training.optimizer,
            )
        )

    # Reduce LR if no SDR improvements for several epochs
    scheduler = ReduceLROnPlateau(
        optimizer,
        "max",
        # patience=accelerator.num_processes * config.training.patience, # This is strange place...
        patience=config.training.patience,
        factor=config.training.reduce_factor,
    )

    loss_multistft = None
    if args.use_multistft_loss:
        # The config section is optional; fall back to the loss defaults when
        # it is missing or cannot be converted to a dict.
        try:
            loss_options = dict(config.loss_multistft)
        except Exception:
            loss_options = dict()
        accelerator.print("Loss options: {}".format(loss_options))
        loss_multistft = auraloss.freq.MultiResolutionSTFTLoss(**loss_options)

    model, optimizer, train_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, scheduler
    )

    ema_model = None
    if hasattr(config.training, "ema_momentum") and config.training.ema_momentum > 0:
        accelerator.print(
            f"Initializing EMA with decay: {config.training.ema_momentum}"
        )
        ema_model = AveragedModel(
            accelerator.unwrap_model(model),
            multi_avg_fn=get_ema_multi_avg_fn(config.training.ema_momentum),
        )
        ema_model.to(device)

    if args.pre_valid:
        sdr_list = _gather_validation_sdr(
            accelerator, model, ema_model, valid_loader, args, config, device
        )
        sdr_avg = 0.0
        instruments = prefer_target_instrument(config)
        for instr in instruments:
            sdr_data = torch.cat(sdr_list[instr], dim=0).cpu().numpy()
            # First report: mean over everything that was gathered.
            sdr_val = sdr_data.mean()
            accelerator.print("Valid length: {}".format(valid_dataset_length))
            accelerator.print(
                "Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data))
            )
            # Second report: mean over the first ``valid_dataset_length``
            # entries only; this value is the one that is averaged.
            sdr_val = sdr_data[:valid_dataset_length].mean()
            accelerator.print(
                "Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data))
            )
            sdr_avg += sdr_val
        sdr_avg /= len(instruments)
        if len(instruments) > 1:
            accelerator.print("SDR Avg: {:.4f}".format(sdr_avg))
        sdr_list = None

    # These model types return the loss directly from ``model(x, y)``.
    # The model type does not change during training, so test it once.
    loss_in_forward = args.model_type in (
        "mel_band_roformer",
        "bs_roformer",
        "bs_mamba2",
        "mel_band_conformer",
        "bs_conformer",
    )

    accelerator.print("Train for: {}".format(config.training.num_epochs))
    best_sdr = -100
    for epoch in range(config.training.num_epochs):
        model.train().to(device)
        accelerator.print(
            "Train epoch: {} Learning rate: {}".format(
                epoch, optimizer.param_groups[0]["lr"]
            )
        )
        loss_val = 0.0
        total = 0
        pbar = tqdm(train_loader, disable=not accelerator.is_main_process)
        for i, (batch, mixes) in enumerate(pbar):
            # The loader yields (targets, mixtures).
            y = batch
            x = mixes
            loss = _compute_loss(
                model, x, y, args, config, loss_in_forward, loss_multistft
            )

            accelerator.backward(loss)
            if config.training.grad_clip:
                accelerator.clip_grad_norm_(
                    model.parameters(), config.training.grad_clip
                )
            optimizer.step()
            optimizer.zero_grad()
            if ema_model is not None:
                ema_model.update_parameters(accelerator.unwrap_model(model))

            li = loss.item()
            loss_val += li
            total += 1
            if accelerator.is_main_process:
                wandb.log(
                    {
                        "loss": 100 * li,
                        "avg_loss": 100 * loss_val / (i + 1),
                        "total": total,
                        "loss_val": loss_val,
                        "i": i,
                    }
                )
                pbar.set_postfix(
                    {"loss": 100 * li, "avg_loss": 100 * loss_val / (i + 1)}
                )

        if accelerator.is_main_process:
            # Guard against an empty loader so the main process does not die
            # with ZeroDivisionError while the others wait at the barrier.
            mean_loss = loss_val / total if total else float("nan")
            print("Training loss: {:.6f}".format(mean_loss))
            wandb.log({"train_loss": mean_loss, "epoch": epoch})

        # Save last
        store_path = args.results_path + "/last_{}.ckpt".format(args.model_type)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            _save_weights(accelerator, model, ema_model, store_path)

        # Validation
        sdr_list = _gather_validation_sdr(
            accelerator, model, ema_model, valid_loader, args, config, device
        )
        sdr_avg = 0.0
        instruments = prefer_target_instrument(config)
        for instr in instruments:
            sdr_data = torch.cat(sdr_list[instr], dim=0).cpu().numpy()
            # Only the first ``valid_dataset_length`` gathered values count.
            sdr_val = sdr_data[:valid_dataset_length].mean()
            if accelerator.is_main_process:
                print(
                    "Instr SDR {}: {:.4f} Debug: {}".format(
                        instr, sdr_val, len(sdr_data)
                    )
                )
                wandb.log({f"{instr}_sdr": sdr_val})
            sdr_avg += sdr_val
        sdr_avg /= len(instruments)
        if len(instruments) > 1:
            if accelerator.is_main_process:
                print("SDR Avg: {:.4f}".format(sdr_avg))
                # ``best_sdr`` here is the best value before this epoch.
                wandb.log({"sdr_avg": sdr_avg, "best_sdr": best_sdr})

        if accelerator.is_main_process:
            if sdr_avg > best_sdr:
                store_path = (
                    args.results_path
                    + "/model_{}_ep_{}_sdr_{:.4f}.ckpt".format(
                        args.model_type, epoch, sdr_avg
                    )
                )
                print("Store weights: {}".format(store_path))
                _save_weights(accelerator, model, ema_model, store_path)
                best_sdr = sdr_avg

        # Every process steps the plateau scheduler with the same average.
        scheduler.step(sdr_avg)
        sdr_list = None
        accelerator.wait_for_everyone()
if __name__ == "__main__":
    # Script entry point: with args=None the parser reads sys.argv.
    train_model(args=None)