# coding: utf-8
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
import argparse
import itertools
import multiprocessing
import os
import pickle
import random
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from glob import glob
from typing import Union
import audiomentations as AU
import numpy as np
import pedalboard as PB
import soundfile as sf
import torch
import torch.distributed as dist
from ml_collections import ConfigDict
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm

warnings.filterwarnings("ignore")
def prepare_data(
config: Union[ConfigDict, OmegaConf], args: argparse.Namespace, batch_size: int
) -> DataLoader:
"""
Build the training DataLoader. If torch.distributed.is_initialized() is True,
construct a DDP DataLoader with DistributedSampler; otherwise, construct a regular DataLoader.
Args:
config: Dataset configuration passed to MSSDataset.
args: Must provide data_path, results_path, dataset_type, and DataLoader settings.
batch_size: Per-process mini-batch size.
Returns:
Configured DataLoader for the training split.
"""
# DDP
if dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
if args.dataset_type != 5:
            ddp_batch = (
                batch_size * world_size
            )  # maintain "num_steps" semantics across all ranks
else:
ddp_batch = batch_size
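        # Worked example of the arithmetic above (illustrative numbers):
        # batch_size=4 on world_size=8 gives ddp_batch=32, so the dataset
        # advertises num_steps * 32 samples per epoch while each rank's
        # DataLoader still draws per-process batches of 4.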
trainset = MSSDataset(
config,
args.data_path,
batch_size=ddp_batch,
metadata_path=os.path.join(
args.results_path, f"metadata_{args.dataset_type}.pkl"
),
dataset_type=args.dataset_type,
)
sampler = DistributedSampler(
trainset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True
)
train_loader = DataLoader(
trainset,
batch_size=batch_size, # per-process batch size
sampler=sampler, # sampler handles shuffling in DDP
num_workers=args.num_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_workers,
prefetch_factor=args.prefetch_factor,
)
else:
trainset = MSSDataset(
config,
args.data_path,
batch_size=batch_size,
metadata_path=os.path.join(
args.results_path, f"metadata_{args.dataset_type}.pkl"
),
dataset_type=args.dataset_type,
)
train_loader = DataLoader(
trainset,
batch_size=batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_workers,
prefetch_factor=args.prefetch_factor,
)
return train_loader
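
# A minimal usage sketch for prepare_data (added; not part of the original
# module). The flags mirror the attributes the function reads from `args`;
# the concrete paths and defaults are hypothetical.
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--data_path", default="./dataset")
#     parser.add_argument("--results_path", default="./results")
#     parser.add_argument("--dataset_type", type=int, default=1)
#     parser.add_argument("--num_workers", type=int, default=4)
#     parser.add_argument("--pin_memory", action="store_true")
#     parser.add_argument("--persistent_workers", action="store_true")
#     parser.add_argument("--prefetch_factor", type=int, default=2)
#     args = parser.parse_args()
#     train_loader = prepare_data(config, args, batch_size=8)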
def load_chunk(path, length, chunk_size, offset=None, target_channels=2):
"""
Returns array with shape (target_channels, chunk_size)
"""
if chunk_size <= length:
if offset is None:
start = np.random.randint(length - chunk_size + 1)
else:
start = offset
x = sf.read(path, dtype="float32", start=start, frames=chunk_size)[0]
else:
if offset is None:
start = 0
else:
start = offset
frames_to_read = length
x = sf.read(path, dtype="float32", start=start, frames=frames_to_read)[0]
if x.ndim == 1:
x = x[:, None]
if x.shape[0] < chunk_size:
pad = np.zeros((chunk_size - x.shape[0], x.shape[1]), dtype=np.float32)
x = np.concatenate([x, pad], axis=0)
elif x.shape[0] > chunk_size:
x = x[:chunk_size]
ch = x.shape[1]
if ch == target_channels:
pass
elif ch > target_channels:
x = x[:, :target_channels]
    elif ch == 1:
        x = np.repeat(x, target_channels, axis=1)
else:
raise ValueError(f"Path: {path}, num_channels: {ch}")
return x.T
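
# Shape-contract sketch for load_chunk (added; "song/vocals.wav" is a
# hypothetical file). Whatever the file length, the result is always
# (target_channels, chunk_size): short files are zero-padded, long reads are
# cropped, and mono is tiled to the requested channel count.
#
#     info = sf.info("song/vocals.wav")
#     chunk = load_chunk("song/vocals.wav", length=info.frames, chunk_size=485100)
#     assert chunk.shape == (2, 485100)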
def get_track_set_length(params):
path, instruments, file_types, dataset_type = params
should_print = (
not dist.is_initialized() or dist.get_rank() == 0
) and dataset_type != 7
# Check lengths of all instruments (it can be different in some cases)
lengths_arr = []
for instr in instruments:
length = -1
for extension in file_types:
path_to_audio_file = path + "/{}.{}".format(instr, extension)
if os.path.isfile(path_to_audio_file):
length = sf.info(path_to_audio_file).frames
break
if length == -1:
if should_print:
                print('Cannot find file "{}" in folder {}'.format(instr, path))
continue
lengths_arr.append(length)
    lengths_arr = np.array(lengths_arr)
    if lengths_arr.size == 0:
        if should_print:
            print("No stems found for path: {}".format(path))
        return path, 0
    if lengths_arr.min() != lengths_arr.max() and should_print:
print(
f"Warning: lengths of stems are different for path: {path}. ({lengths_arr.min()} != {lengths_arr.max()})"
)
# We use minimum to allow overflow for soundfile read in non-equal length cases
return path, lengths_arr.min()
# For multiprocessing
def get_track_length(params):
path = params
length = sf.info(path).frames
return (path, length)
def process_chunk_worker(args):
    """Multiprocessing worker: check that every stem of a chunk is loud enough."""
    task, instruments, file_types, min_mean_abs, default_chunk_size = args
track_path, track_length, offset, chunk_size = task
try:
for instrument in instruments:
instrument_loud_enough = False
for extension in file_types:
path_to_audio_file = track_path + "/{}.{}".format(instrument, extension)
if os.path.isfile(path_to_audio_file):
try:
source = load_chunk(
path_to_audio_file,
length=track_length,
offset=offset,
chunk_size=chunk_size,
)
if np.abs(source).mean() >= min_mean_abs:
instrument_loud_enough = True
break
except Exception:
return (track_path, offset, False)
if not instrument_loud_enough:
return (track_path, offset, False)
return (track_path, offset, True)
except Exception:
return (track_path, offset, False)
class MSSDataset(torch.utils.data.Dataset):
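    """
    Training dataset for music source separation.

    Dataset types, as dispatched in __getitem__ and get_metadata
    (summary added for readability):
      1     - folder per track; stems are sampled randomly and summed to a mix;
      2     - folder per instrument with independent files per stem;
      3     - CSV file(s) with 'instrum' and 'path' columns;
      4, 6  - folder per track; time-aligned chunks (6 also loads mixture.*);
      5     - like 4, but all chunks are precomputed with 50% overlap;
      7     - folder per track; class-balanced aligned sampling.
    """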
def __init__(
self,
config,
data_path,
metadata_path="metadata.pkl",
dataset_type=1,
batch_size=None,
verbose=True,
):
self.verbose = verbose
self.config = config
        self.dataset_type = dataset_type  # 1, 2, 3, 4, 5, 6 or 7
self.data_path = data_path
self.instruments = instruments = config.training.instruments
if batch_size is None:
batch_size = config.training.batch_size
self.batch_size = batch_size
self.file_types = ["wav", "flac"]
self.metadata_path = metadata_path
should_print = not dist.is_initialized() or dist.get_rank() == 0
# Augmentation block
self.aug = False
if "augmentations" in config:
if config["augmentations"].enable is True:
if self.verbose and should_print:
print("Use augmentation for training")
self.aug = True
        else:
            if self.verbose and should_print:
                print(
                    "There is no 'augmentations' block in the config. Augmentations are disabled for training..."
                )
metadata = self.get_metadata()
if self.dataset_type in [1, 4, 5, 6, 7]:
if len(metadata) > 0:
if self.verbose and should_print:
print("Found tracks in dataset: {}".format(len(metadata)))
else:
if should_print:
print("No tracks found for training. Check paths you provided!")
exit()
else:
for instr in self.instruments:
if self.verbose and should_print:
print(
"Found tracks for {} in dataset: {}".format(
instr, len(metadata[instr])
)
)
self.metadata = metadata
self.chunk_size = config.audio.chunk_size
self.min_mean_abs = config.audio.min_mean_abs
self.do_chunks = (
config.training.get("precompute_chunks", False)
and float(self.min_mean_abs) > 0
)
        # For dataset_type 5 (and types 4/6 with precompute_chunks) - precompute all chunks
        if self.dataset_type == 5 or (
            self.dataset_type in (4, 6) and self.do_chunks
        ):
self._initialize_chunks_metadata()
if self.dataset_type == 7:
self._build_class_to_tracks()
def __len__(self):
if self.dataset_type == 5:
return len(self.chunks_metadata)
return self.config.training.num_steps * self.batch_size
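
    # Worked example of the length above (illustrative numbers): with
    # num_steps=1000 and batch_size=8, one "epoch" draws 1000 * 8 = 8000
    # random samples; for dataset_type 5 the length is instead the number
    # of precomputed chunks.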
def __getitem__(self, index):
if self.dataset_type == 7:
res, mix, active_stem_ids = self.load_class_balanced_aligned()
elif self.dataset_type == 5:
track_path, offset = self.chunks_metadata[index]
res = self._load_chunk_by_offset(track_path, offset)
elif self.dataset_type in [1, 2, 3]:
res = self.load_random_mix()
else: # type 4 or 6
if self.do_chunks:
track_path, offset = self.chunks_metadata[
np.random.randint(len(self.chunks_metadata))
]
res = self._load_chunk_by_offset(track_path, offset)
else:
if self.dataset_type == 6:
res, mix = self.load_aligned_data()
else:
res, _ = self.load_aligned_data()
# Randomly change loudness of each stem
if self.aug:
if "loudness" in self.config["augmentations"]:
if self.config["augmentations"]["loudness"]:
loud_values = np.random.uniform(
low=self.config["augmentations"]["loudness_min"],
high=self.config["augmentations"]["loudness_max"],
size=(len(res),),
)
loud_values = torch.tensor(loud_values, dtype=torch.float32)
res *= loud_values[:, None, None]
if self.dataset_type != 6 and self.dataset_type != 7:
mix = res.sum(0)
if self.aug:
if "mp3_compression_on_mixture" in self.config["augmentations"]:
apply_aug = AU.Mp3Compression(
min_bitrate=self.config["augmentations"][
"mp3_compression_on_mixture_bitrate_min"
],
max_bitrate=self.config["augmentations"][
"mp3_compression_on_mixture_bitrate_max"
],
backend=self.config["augmentations"][
"mp3_compression_on_mixture_backend"
],
p=self.config["augmentations"]["mp3_compression_on_mixture"],
)
mix_conv = mix.cpu().numpy().astype(np.float32)
required_shape = mix_conv.shape
mix = apply_aug(samples=mix_conv, sample_rate=44100)
# Sometimes it gives longer audio (so we cut)
if mix.shape != required_shape:
mix = mix[..., : required_shape[-1]]
mix = torch.tensor(mix, dtype=torch.float32)
# If we need to optimize only given stem
if self.config.training.target_instrument is not None:
index = self.config.training.instruments.index(
self.config.training.target_instrument
)
return res[index : index + 1], mix
if self.dataset_type == 7:
return res, mix, active_stem_ids
return res, mix
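
    # Shape note (added for readability): `res` is
    # (num_instruments, channels, chunk_size) and `mix` is
    # (channels, chunk_size); with target_instrument set, `res` keeps a
    # single stem along the first axis.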
def _build_class_to_tracks(self):
import json
should_print = not dist.is_initialized() or dist.get_rank() == 0
cache_path = "class_to_tracks_cache.json"
total_tracks = len(self.metadata)
max_ratio = self.config.training.get("max_class_presence_ratio", 0.4)
if os.path.isfile(cache_path):
if should_print:
print("[dataset_type=7] Loading class_to_tracks from cache")
with open(cache_path, "r", encoding="utf8") as f:
cache = json.load(f)
if (
cache.get("total_tracks") == total_tracks
and cache.get("max_ratio") == max_ratio
):
self.class_to_tracks = cache["class_to_tracks"]
self.available_classes = list(self.class_to_tracks.keys())
if should_print:
print(
f"[dataset_type=7] Loaded {len(self.available_classes)} classes from cache"
)
return
else:
if should_print:
print("[dataset_type=7] Cache invalid, rebuilding")
class_to_tracks = {instr: [] for instr in self.instruments}
track_iter = self.metadata
if should_print:
track_iter = tqdm(
self.metadata,
desc="[dataset_type=7] Building class_to_tracks",
total=total_tracks,
)
for track_path, _ in track_iter:
for instr in self.instruments:
for ext in self.file_types:
path = f"{track_path}/{instr}.{ext}"
if os.path.isfile(path):
class_to_tracks[instr].append(track_path)
break
filtered_class_to_tracks = {}
for instr, tracks in class_to_tracks.items():
count = len(tracks)
ratio = count / total_tracks
if count == 0:
continue
if ratio > max_ratio:
if should_print:
print(
f"[dataset_type=7] Skip frequent stem '{instr}': "
f"{count}/{total_tracks} ({ratio:.1%})"
)
continue
filtered_class_to_tracks[instr] = tracks
if len(filtered_class_to_tracks) == 0:
raise RuntimeError(
"dataset_type 7: all classes were filtered out by frequency threshold"
)
self.class_to_tracks = filtered_class_to_tracks
self.available_classes = list(filtered_class_to_tracks.keys())
if should_print:
print("[dataset_type=7] Saving class_to_tracks cache")
with open(cache_path, "w", encoding="utf8") as f:
json.dump(
{
"total_tracks": total_tracks,
"max_ratio": max_ratio,
"class_to_tracks": filtered_class_to_tracks,
},
f,
indent=2,
)
if should_print:
print(
f"[dataset_type=7] Using {len(self.available_classes)} balanced classes "
f"out of {len(self.instruments)} instruments"
)
def load_class_balanced_aligned(self):
"""
1) Randomly choose instrument (class)
2) Randomly choose track containing this instrument
3) Load aligned chunk from this track
"""
should_print = not dist.is_initialized() or dist.get_rank() == 0
instr = random.choice(self.available_classes)
track_path = random.choice(self.class_to_tracks[instr])
# Find track length
track_length = None
for path, length in self.metadata:
if path == track_path:
track_length = length
break
if track_length is None:
raise RuntimeError(f"Track length not found: {track_path}")
if track_length >= self.chunk_size:
offset = np.random.randint(track_length - self.chunk_size + 1)
else:
offset = None
mix = None
for extension in self.file_types:
path_to_mix_file = f"{track_path}/mixture.{extension}"
if os.path.isfile(path_to_mix_file):
try:
mix = load_chunk(
path_to_mix_file, track_length, self.chunk_size, offset=offset
)
break
                except Exception as e:
                    if should_print:
                        print(e)
res = []
active_stem_ids = []
for idx, instr in enumerate(self.instruments):
found = False
for extension in self.file_types:
path_to_audio_file = f"{track_path}/{instr}.{extension}"
if os.path.isfile(path_to_audio_file):
try:
source = load_chunk(
path_to_audio_file,
track_length,
self.chunk_size,
offset=offset,
)
active_stem_ids.append(idx)
found = True
break
                    except Exception as e:
                        if should_print:
                            print(e)
if not found:
source = np.zeros((2, self.chunk_size), dtype=np.float32)
res.append(source)
res = np.stack(res, axis=0)
if mix is None:
mix = np.sum(res, axis=0)
if self.aug:
for i, instr in enumerate(self.instruments):
res[i] = self.augm_data(res[i], instr)
return (
torch.tensor(res, dtype=torch.float32),
torch.tensor(mix, dtype=torch.float32),
active_stem_ids,
)
def _initialize_chunks_metadata(self):
should_print = not dist.is_initialized() or dist.get_rank() == 0
chunks_cache_path = self.metadata_path.replace(".pkl", "_chunks.pkl")
current_config = {
"chunk_size": self.chunk_size,
"min_mean_abs": self.min_mean_abs,
"instruments": sorted(self.instruments),
}
        if os.path.exists(chunks_cache_path):
            try:
                with open(chunks_cache_path, "rb") as f:
                    cached_chunks = pickle.load(f)
cached_config = cached_chunks.get("config", {})
config_matches = (
cached_config.get("chunk_size") == current_config["chunk_size"]
and cached_config.get("min_mean_abs")
== current_config["min_mean_abs"]
and cached_config.get("instruments")
== current_config["instruments"]
)
if config_matches:
self.chunks_metadata = cached_chunks["chunks_metadata"]
if self.verbose and should_print:
print(
f"Loaded {len(self.chunks_metadata)} cached chunks from {chunks_cache_path}"
)
else:
if self.verbose and should_print:
print("Config changed, recomputing chunks...")
print(f"Cached config: {cached_config}")
print(f"Current config: {current_config}")
self.chunks_metadata = self._precompute_and_cache_chunks(
chunks_cache_path, current_config
)
except Exception as e:
if self.verbose and should_print:
print(f"Chunks cache corrupted ({e}), recomputing...")
self.chunks_metadata = self._precompute_and_cache_chunks(
chunks_cache_path, current_config
)
else:
self.chunks_metadata = self._precompute_and_cache_chunks(
chunks_cache_path, current_config
)
if self.verbose and should_print:
print(f"Precomputed {len(self.chunks_metadata)} chunks")
def _precompute_and_cache_chunks(self, cache_path, config):
"""Precompute all chunks and save to cache with config"""
if self.dataset_type == 4 or self.dataset_type == 6:
chunks_metadata = self._precompute_random_chunks()
elif self.dataset_type == 5:
chunks_metadata = self._precompute_chunks()
else:
raise "Only dataset type 4, 5 can be precomputed"
cache_data = {"chunks_metadata": chunks_metadata, "config": config}
        with open(cache_path, "wb") as f:
            pickle.dump(cache_data, f)
return chunks_metadata
def _precompute_chunks(self):
"""Precompute all chunks for dataset_type 5 with overlap 2 using multiprocessing"""
should_print = not dist.is_initialized() or dist.get_rank() == 0
tasks = []
for track_path, track_length in self.metadata:
if track_length < self.chunk_size:
tasks.append((track_path, track_length, 0, track_length))
else:
step = self.chunk_size // 2
num_chunks = (track_length - self.chunk_size) // step + 1
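                # Worked example (illustrative numbers): chunk_size=485100 and
                # track_length=1_000_000 give step=242550 and num_chunks=3,
                # i.e. offsets 0, 242550 and 485100.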
for i in range(num_chunks):
offset = i * step
tasks.append((track_path, track_length, offset, self.chunk_size))
if should_print:
print(f"Total tasks to process: {len(tasks)}")
if multiprocessing.cpu_count() > 1:
chunks_metadata = self._process_tasks_parallel(tasks, should_print)
else:
chunks_metadata = self._process_tasks_sequential(tasks, should_print)
if self.verbose and should_print:
print(
f"Created {len(chunks_metadata)} good chunks from {len(self.metadata)} tracks"
)
return chunks_metadata
def _precompute_random_chunks(self):
"""Precompute exact number of good chunks"""
should_print = not dist.is_initialized() or dist.get_rank() == 0
target_count = self.config.training.get(
"num_precompute_chunks",
self.config.training.num_steps
* self.batch_size
* self.config.training.num_epochs,
)
chunks_metadata = []
if should_print:
print(f"Generating exactly {target_count} good chunks...")
with tqdm(total=target_count, desc="Progress good chunks") as pbar:
while len(chunks_metadata) < target_count:
batch_size = self.config.training.get(
"precompute_batch_for_chunks", 500
)
tasks = []
need = target_count - len(chunks_metadata)
for i in range(batch_size):
track_path, track_length = random.choice(self.metadata)
if track_length < self.chunk_size:
tasks.append((track_path, track_length, 0, track_length))
else:
offset = np.random.randint(track_length - self.chunk_size + 1)
tasks.append(
(track_path, track_length, offset, self.chunk_size)
)
if multiprocessing.cpu_count() > 1:
good_chunks = self._process_tasks_parallel(tasks, False)
else:
good_chunks = self._process_tasks_sequential(tasks, False)
chunks_metadata.extend(good_chunks)
pbar.update(min(len(good_chunks), need))
chunks_metadata = chunks_metadata[:target_count]
return chunks_metadata
def _process_tasks_sequential(self, tasks, should_print):
chunks_metadata = []
pbar = tqdm(tasks, desc="Processing chunks") if should_print else tasks
for task in pbar:
track_path, track_length, offset, chunk_size = task
if self._is_chunk_loud_enough(track_path, offset, chunk_size, track_length):
chunks_metadata.append((track_path, offset))
return chunks_metadata
def _process_tasks_parallel(self, tasks, should_print):
chunks_metadata = []
with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
worker_args = [
(
task,
self.instruments,
self.file_types,
self.min_mean_abs,
self.chunk_size,
)
for task in tasks
]
results = []
if should_print:
with tqdm(total=len(tasks), desc="Processing chunks") as pbar:
                    for result in pool.imap_unordered(
                        process_chunk_worker, worker_args
                    ):
                        results.append(result)
                        pbar.update(1)
else:
for result in pool.imap_unordered(process_chunk_worker, worker_args):
results.append(result)
for result in results:
track_path, offset, is_loud_enough = result
if is_loud_enough:
chunks_metadata.append((track_path, offset))
return chunks_metadata
def _is_chunk_loud_enough(self, track_path, offset, chunk_size, track_length):
try:
for instrument in self.instruments:
instrument_loud_enough = False
for extension in self.file_types:
path_to_audio_file = track_path + "/{}.{}".format(
instrument, extension
)
if os.path.isfile(path_to_audio_file):
try:
source = load_chunk(
path_to_audio_file,
length=track_length,
offset=offset,
chunk_size=chunk_size,
)
if np.abs(source).mean() >= self.min_mean_abs:
instrument_loud_enough = True
break
except Exception as e:
if not dist.is_initialized() or dist.get_rank() == 0:
print(
"Error loading: {} Path: {}".format(
e, path_to_audio_file
)
)
return False
if not instrument_loud_enough:
return False
return True
except Exception as e:
if not dist.is_initialized() or dist.get_rank() == 0:
print(
"Error checking chunk loudness: {} Path: {}".format(e, track_path)
)
return False
def read_from_metadata_cache(self, track_paths, instr=None):
should_print = not dist.is_initialized() or dist.get_rank() == 0
metadata = []
if os.path.isfile(self.metadata_path):
if self.verbose and should_print:
print("Found metadata cache file: {}".format(self.metadata_path))
            with open(self.metadata_path, "rb") as f:
                old_metadata = pickle.load(f)
else:
return track_paths, metadata
if instr:
old_metadata = old_metadata[instr]
        # Do not re-read tracks that already exist in the old metadata file
track_paths_set = set(track_paths)
for old_path, file_size in old_metadata:
if old_path in track_paths_set:
metadata.append([old_path, file_size])
track_paths_set.remove(old_path)
track_paths = list(track_paths_set)
if len(metadata) > 0 and should_print:
print("Old metadata was used for {} tracks.".format(len(metadata)))
return track_paths, metadata
def get_metadata(self):
        read_metadata_procs = max(1, multiprocessing.cpu_count() - 2)
should_print = not dist.is_initialized() or dist.get_rank() == 0
if "read_metadata_procs" in self.config["training"]:
read_metadata_procs = int(self.config["training"]["read_metadata_procs"])
if self.verbose and should_print:
print(
"Dataset type:",
self.dataset_type,
"Processes to use:",
read_metadata_procs,
"\nCollecting metadata for",
str(self.data_path),
)
if self.dataset_type in [1, 4, 5, 6, 7]: # Added type 7
track_paths = []
            if isinstance(self.data_path, list):
for tp in self.data_path:
tracks_for_folder = sorted(glob(tp + "/*"))
if len(tracks_for_folder) == 0 and should_print:
print(
"Warning: no tracks found in folder '{}'. Please check it!".format(
tp
)
)
track_paths += tracks_for_folder
else:
track_paths += sorted(glob(self.data_path + "/*"))
track_paths = [
path
for path in track_paths
if os.path.basename(path)[0] != "." and os.path.isdir(path)
]
track_paths, metadata = self.read_from_metadata_cache(track_paths, None)
if read_metadata_procs <= 1:
pbar = tqdm(track_paths) if should_print else track_paths
for path in pbar:
track_path, track_length = get_track_set_length(
(path, self.instruments, self.file_types, self.dataset_type)
)
metadata.append((track_path, track_length))
else:
with ThreadPoolExecutor(max_workers=read_metadata_procs) as executor:
futures = [
executor.submit(get_track_set_length, args)
for args in zip(
track_paths,
itertools.repeat(self.instruments),
itertools.repeat(self.file_types),
itertools.repeat(self.dataset_type),
)
]
if should_print:
for f in tqdm(as_completed(futures), total=len(futures)):
track_path, track_length = f.result()
metadata.append((track_path, track_length))
else:
for f in as_completed(futures):
metadata.append(f.result())
elif self.dataset_type == 2:
metadata = dict()
for instr in self.instruments:
metadata[instr] = []
track_paths = []
                if isinstance(self.data_path, list):
for tp in self.data_path:
track_paths += sorted(glob(tp + "/{}/*.wav".format(instr)))
track_paths += sorted(glob(tp + "/{}/*.flac".format(instr)))
else:
track_paths += sorted(
glob(self.data_path + "/{}/*.wav".format(instr))
)
track_paths += sorted(
glob(self.data_path + "/{}/*.flac".format(instr))
)
track_paths, metadata[instr] = self.read_from_metadata_cache(
track_paths, instr
)
if read_metadata_procs <= 1:
pbar = tqdm(track_paths) if should_print else track_paths
for path in pbar:
length = sf.info(path).frames
metadata[instr].append((path, length))
else:
p = multiprocessing.Pool(processes=read_metadata_procs)
track_iter = p.imap(get_track_length, track_paths)
if should_print:
track_iter = tqdm(track_iter, total=len(track_paths))
for out in track_iter:
metadata[instr].append(out)
p.close()
        elif self.dataset_type == 3:
            import pandas as pd

            data_paths = (
                self.data_path
                if isinstance(self.data_path, list)
                else [self.data_path]
            )
            metadata = dict()
            for i in range(len(data_paths)):
                if self.verbose and should_print:
                    print("Reading tracks from: {}".format(data_paths[i]))
                df = pd.read_csv(data_paths[i])
skipped = 0
                for instr in self.instruments:
                    part = df[df["instrum"] == instr].copy()
                    if should_print:
                        print("Tracks found for {}: {}".format(instr, len(part)))
metadata[instr] = []
track_paths = list(part["path"].values)
track_paths, metadata[instr] = self.read_from_metadata_cache(
track_paths, instr
)
pbar = tqdm(track_paths) if should_print else track_paths
for path in pbar:
if not os.path.isfile(path):
if should_print:
print("Cant find track: {}".format(path))
skipped += 1
continue
# print(path)
try:
length = sf.info(path).frames
                        except Exception:
if should_print:
print("Problem with path: {}".format(path))
skipped += 1
continue
metadata[instr].append((path, length))
if skipped > 0 and should_print:
print("Missing tracks: {} from {}".format(skipped, len(df)))
else:
if should_print:
print(
"Unknown dataset type: {}. Must be 1, 2, 3, 4, 5 or 6".format(
self.dataset_type
)
)
exit()
        # Save metadata
        with open(self.metadata_path, "wb") as f:
            pickle.dump(metadata, f)
return metadata
def load_source(self, metadata, instr):
should_print = not dist.is_initialized() or dist.get_rank() == 0
while True:
            if self.dataset_type in [1, 4, 5, 6, 7]:
                track_path, track_length = random.choice(metadata)
                source = None
                for extension in self.file_types:
                    path_to_audio_file = track_path + "/{}.{}".format(instr, extension)
                    if os.path.isfile(path_to_audio_file):
                        try:
                            source = load_chunk(
                                path_to_audio_file, track_length, self.chunk_size
                            )
                        except Exception as e:
                            # Sometimes error during FLAC reading, catch it and use zero stem
                            if should_print:
                                print(
                                    "Error: {} Path: {}".format(e, path_to_audio_file)
                                )
                            source = np.zeros((2, self.chunk_size), dtype=np.float32)
                        break
                if source is None:
                    # No stem file found for this track, sample another one
                    continue
else:
track_path, track_length = random.choice(metadata[instr])
try:
source = load_chunk(track_path, track_length, self.chunk_size)
except Exception as e:
# Sometimes error during FLAC reading, catch it and use zero stem
if should_print:
print("Error: {} Path: {}".format(e, track_path))
source = np.zeros((2, self.chunk_size), dtype=np.float32)
if np.abs(source).mean() >= self.min_mean_abs: # remove quiet chunks
break
if self.aug:
source = self.augm_data(source, instr)
return torch.tensor(source, dtype=torch.float32)
def load_random_mix(self):
res = []
for instr in self.instruments:
s1 = self.load_source(self.metadata, instr)
# Mixup augmentation. Multiple mix of same type of stems
if self.aug:
if "mixup" in self.config["augmentations"]:
if self.config["augmentations"].mixup:
mixup = [s1]
for prob in self.config.augmentations.mixup_probs:
if random.uniform(0, 1) < prob:
s2 = self.load_source(self.metadata, instr)
mixup.append(s2)
mixup = torch.stack(mixup, dim=0)
loud_values = np.random.uniform(
low=self.config.augmentations.loudness_min,
high=self.config.augmentations.loudness_max,
size=(len(mixup),),
)
loud_values = torch.tensor(loud_values, dtype=torch.float32)
mixup *= loud_values[:, None, None]
s1 = mixup.mean(dim=0, dtype=torch.float32)
res.append(s1)
res = torch.stack(res)
return res
def _load_chunk_by_offset(self, track_path, offset):
"""Load specific chunk by track path and offset"""
should_print = not dist.is_initialized() or dist.get_rank() == 0
        # Get track length from metadata once, instead of per stem file
        track_length = None
        for path, length in self.metadata:
            if path == track_path:
                track_length = length
                break
        res = []
        for instr in self.instruments:
            for extension in self.file_types:
                path_to_audio_file = track_path + "/{}.{}".format(instr, extension)
                if os.path.isfile(path_to_audio_file):
                    try:
                        if track_length is None:
                            source = np.zeros((2, self.chunk_size), dtype=np.float32)
                        else:
                            source = load_chunk(
                                path_to_audio_file,
                                track_length,
                                self.chunk_size,
                                offset=offset,
                            )
except Exception as e:
if should_print:
print("Error: {} Path: {}".format(e, path_to_audio_file))
source = np.zeros((2, self.chunk_size), dtype=np.float32)
break
else:
source = np.zeros((2, self.chunk_size), dtype=np.float32)
res.append(source)
res = np.stack(res, axis=0)
if self.aug:
for i, instr in enumerate(self.instruments):
res[i] = self.augm_data(res[i], instr)
return torch.tensor(res, dtype=torch.float32)
def load_aligned_data(self):
track_path, track_length = random.choice(self.metadata)
should_print = not dist.is_initialized() or dist.get_rank() == 0
attempts = 10
while attempts:
if track_length >= self.chunk_size:
common_offset = np.random.randint(track_length - self.chunk_size + 1)
else:
common_offset = None
res = []
silent_chunks = 0
            for instr in self.instruments:
                found = False
                for extension in self.file_types:
                    path_to_audio_file = f"{track_path}/{instr}.{extension}"
if os.path.isfile(path_to_audio_file):
found = True
try:
source = load_chunk(
path_to_audio_file,
track_length,
self.chunk_size,
offset=common_offset,
)
except Exception as e:
if should_print:
print(f"Error: {e} Path: {path_to_audio_file}")
source = np.zeros((2, self.chunk_size), dtype=np.float32)
break
if not found:
source = np.zeros((2, self.chunk_size), dtype=np.float32)
res.append(source)
if np.abs(source).mean() < self.min_mean_abs: # remove quiet chunks
silent_chunks += 1
mix = None
for extension in self.file_types:
path_to_mix_file = track_path + "/mixture.{}".format(extension)
if os.path.isfile(path_to_mix_file):
try:
mix = load_chunk(
path_to_mix_file,
track_length,
self.chunk_size,
offset=common_offset,
)
except Exception as e:
if should_print:
print(
"Error loading mix: {} Path: {}".format(
e, path_to_mix_file
)
)
break
if silent_chunks == 0:
break
attempts -= 1
            if attempts <= 0 and should_print:
                print("Max attempts reached for track:", track_path)
if common_offset is None:
break
try:
res = np.stack(res, axis=0)
except Exception as e:
print(
"Error during stacking stems: {} Track Length: {} Track path: {}".format(
str(e), track_length, track_path
)
)
res = np.zeros(
(len(self.instruments), 2, self.chunk_size), dtype=np.float32
)
if mix is None:
mix = res.sum(0)
if self.aug:
for i, instr in enumerate(self.instruments):
res[i] = self.augm_data(res[i], instr)
return torch.tensor(res, dtype=torch.float32), torch.tensor(
mix, dtype=torch.float32
)
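
    # Illustrative augmentation config sketch (added; key names are taken
    # from the lookups in augm_data and __getitem__, the values are
    # hypothetical):
    #
    #     augmentations:
    #       enable: true
    #       loudness: true
    #       loudness_min: 0.5
    #       loudness_max: 1.5
    #       mixup: true
    #       mixup_probs: [0.2, 0.02]
    #       all:
    #         channel_shuffle: 0.5
    #         random_inverse: 0.1
    #         random_polarity: 0.5
    #       vocals:
    #         pitch_shift: 0.1
    #         pitch_shift_min_semitones: -5
    #         pitch_shift_max_semitones: 5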
def augm_data(self, source, instr):
        # source.shape = (channels, length), e.g. (2, 261120)
source_shape = source.shape
applied_augs = []
if "all" in self.config["augmentations"]:
augs = self.config["augmentations"]["all"]
else:
augs = dict()
# We need to add to all augmentations specific augs for stem. And rewrite values if needed
if instr in self.config["augmentations"]:
for el in self.config["augmentations"][instr]:
augs[el] = self.config["augmentations"][instr][el]
# Channel shuffle
if "channel_shuffle" in augs:
if augs["channel_shuffle"] > 0:
if random.uniform(0, 1) < augs["channel_shuffle"]:
source = source[::-1].copy()
applied_augs.append("channel_shuffle")
# Random inverse
if "random_inverse" in augs:
if augs["random_inverse"] > 0:
if random.uniform(0, 1) < augs["random_inverse"]:
source = source[:, ::-1].copy()
applied_augs.append("random_inverse")
# Random polarity (multiply -1)
if "random_polarity" in augs:
if augs["random_polarity"] > 0:
if random.uniform(0, 1) < augs["random_polarity"]:
source = -source.copy()
applied_augs.append("random_polarity")
# Random pitch shift
if "pitch_shift" in augs:
if augs["pitch_shift"] > 0:
if random.uniform(0, 1) < augs["pitch_shift"]:
apply_aug = AU.PitchShift(
min_semitones=augs["pitch_shift_min_semitones"],
max_semitones=augs["pitch_shift_max_semitones"],
p=1.0,
)
source = apply_aug(samples=source, sample_rate=44100)
applied_augs.append("pitch_shift")
# Random seven band parametric eq
if "seven_band_parametric_eq" in augs:
if augs["seven_band_parametric_eq"] > 0:
if random.uniform(0, 1) < augs["seven_band_parametric_eq"]:
apply_aug = AU.SevenBandParametricEQ(
min_gain_db=augs["seven_band_parametric_eq_min_gain_db"],
max_gain_db=augs["seven_band_parametric_eq_max_gain_db"],
p=1.0,
)
source = apply_aug(samples=source, sample_rate=44100)
applied_augs.append("seven_band_parametric_eq")
# Random tanh distortion
if "tanh_distortion" in augs:
if augs["tanh_distortion"] > 0:
if random.uniform(0, 1) < augs["tanh_distortion"]:
apply_aug = AU.TanhDistortion(
min_distortion=augs["tanh_distortion_min"],
max_distortion=augs["tanh_distortion_max"],
p=1.0,
)
source = apply_aug(samples=source, sample_rate=44100)
applied_augs.append("tanh_distortion")
# Random MP3 Compression
if "mp3_compression" in augs:
if augs["mp3_compression"] > 0:
if random.uniform(0, 1) < augs["mp3_compression"]:
apply_aug = AU.Mp3Compression(
min_bitrate=augs["mp3_compression_min_bitrate"],
max_bitrate=augs["mp3_compression_max_bitrate"],
backend=augs["mp3_compression_backend"],
p=1.0,
)
source = apply_aug(samples=source, sample_rate=44100)
applied_augs.append("mp3_compression")
# Random AddGaussianNoise
if "gaussian_noise" in augs:
if augs["gaussian_noise"] > 0:
if random.uniform(0, 1) < augs["gaussian_noise"]:
apply_aug = AU.AddGaussianNoise(
min_amplitude=augs["gaussian_noise_min_amplitude"],
max_amplitude=augs["gaussian_noise_max_amplitude"],
p=1.0,
)
source = apply_aug(samples=source, sample_rate=44100)
applied_augs.append("gaussian_noise")
# Random TimeStretch
if "time_stretch" in augs:
if augs["time_stretch"] > 0:
if random.uniform(0, 1) < augs["time_stretch"]:
apply_aug = AU.TimeStretch(
min_rate=augs["time_stretch_min_rate"],
max_rate=augs["time_stretch_max_rate"],
leave_length_unchanged=True,
p=1.0,
)
source = apply_aug(samples=source, sample_rate=44100)
applied_augs.append("time_stretch")
# Possible fix of shape
if source_shape != source.shape:
source = source[..., : source_shape[-1]]
# Random Reverb
if "pedalboard_reverb" in augs:
if augs["pedalboard_reverb"] > 0:
if random.uniform(0, 1) < augs["pedalboard_reverb"]:
room_size = random.uniform(
augs["pedalboard_reverb_room_size_min"],
augs["pedalboard_reverb_room_size_max"],
)
damping = random.uniform(
augs["pedalboard_reverb_damping_min"],
augs["pedalboard_reverb_damping_max"],
)
wet_level = random.uniform(
augs["pedalboard_reverb_wet_level_min"],
augs["pedalboard_reverb_wet_level_max"],
)
dry_level = random.uniform(
augs["pedalboard_reverb_dry_level_min"],
augs["pedalboard_reverb_dry_level_max"],
)
width = random.uniform(
augs["pedalboard_reverb_width_min"],
augs["pedalboard_reverb_width_max"],
)
board = PB.Pedalboard(
[
PB.Reverb(
room_size=room_size, # 0.1 - 0.9
damping=damping, # 0.1 - 0.9
wet_level=wet_level, # 0.1 - 0.9
dry_level=dry_level, # 0.1 - 0.9
width=width, # 0.9 - 1.0
freeze_mode=0.0,
)
]
)
source = board(source, 44100)
applied_augs.append("pedalboard_reverb")
# Random Chorus
if "pedalboard_chorus" in augs:
if augs["pedalboard_chorus"] > 0:
if random.uniform(0, 1) < augs["pedalboard_chorus"]:
rate_hz = random.uniform(
augs["pedalboard_chorus_rate_hz_min"],
augs["pedalboard_chorus_rate_hz_max"],
)
depth = random.uniform(
augs["pedalboard_chorus_depth_min"],
augs["pedalboard_chorus_depth_max"],
)
centre_delay_ms = random.uniform(
augs["pedalboard_chorus_centre_delay_ms_min"],
augs["pedalboard_chorus_centre_delay_ms_max"],
)
feedback = random.uniform(
augs["pedalboard_chorus_feedback_min"],
augs["pedalboard_chorus_feedback_max"],
)
mix = random.uniform(
augs["pedalboard_chorus_mix_min"],
augs["pedalboard_chorus_mix_max"],
)
board = PB.Pedalboard(
[
PB.Chorus(
rate_hz=rate_hz,
depth=depth,
centre_delay_ms=centre_delay_ms,
feedback=feedback,
mix=mix,
)
]
)
source = board(source, 44100)
applied_augs.append("pedalboard_chorus")
        # Random Phaser (config keys keep the 'phazer' spelling)
if "pedalboard_phazer" in augs:
if augs["pedalboard_phazer"] > 0:
if random.uniform(0, 1) < augs["pedalboard_phazer"]:
rate_hz = random.uniform(
augs["pedalboard_phazer_rate_hz_min"],
augs["pedalboard_phazer_rate_hz_max"],
)
depth = random.uniform(
augs["pedalboard_phazer_depth_min"],
augs["pedalboard_phazer_depth_max"],
)
centre_frequency_hz = random.uniform(
augs["pedalboard_phazer_centre_frequency_hz_min"],
augs["pedalboard_phazer_centre_frequency_hz_max"],
)
feedback = random.uniform(
augs["pedalboard_phazer_feedback_min"],
augs["pedalboard_phazer_feedback_max"],
)
mix = random.uniform(
augs["pedalboard_phazer_mix_min"],
augs["pedalboard_phazer_mix_max"],
)
board = PB.Pedalboard(
[
PB.Phaser(
rate_hz=rate_hz,
depth=depth,
centre_frequency_hz=centre_frequency_hz,
feedback=feedback,
mix=mix,
)
]
)
source = board(source, 44100)
applied_augs.append("pedalboard_phazer")
# Random Distortion
if "pedalboard_distortion" in augs:
if augs["pedalboard_distortion"] > 0:
if random.uniform(0, 1) < augs["pedalboard_distortion"]:
drive_db = random.uniform(
augs["pedalboard_distortion_drive_db_min"],
augs["pedalboard_distortion_drive_db_max"],
)
board = PB.Pedalboard(
[
PB.Distortion(
drive_db=drive_db,
)
]
)
source = board(source, 44100)
applied_augs.append("pedalboard_distortion")
# Random PitchShift
if "pedalboard_pitch_shift" in augs:
if augs["pedalboard_pitch_shift"] > 0:
if random.uniform(0, 1) < augs["pedalboard_pitch_shift"]:
semitones = random.uniform(
augs["pedalboard_pitch_shift_semitones_min"],
augs["pedalboard_pitch_shift_semitones_max"],
)
board = PB.Pedalboard([PB.PitchShift(semitones=semitones)])
source = board(source, 44100)
applied_augs.append("pedalboard_pitch_shift")
# Random Resample
if "pedalboard_resample" in augs:
if augs["pedalboard_resample"] > 0:
if random.uniform(0, 1) < augs["pedalboard_resample"]:
target_sample_rate = random.uniform(
augs["pedalboard_resample_target_sample_rate_min"],
augs["pedalboard_resample_target_sample_rate_max"],
)
board = PB.Pedalboard(
[PB.Resample(target_sample_rate=target_sample_rate)]
)
source = board(source, 44100)
applied_augs.append("pedalboard_resample")
        # Random Bitcrush (config keys use the 'bitcrash' spelling)
if "pedalboard_bitcrash" in augs:
if augs["pedalboard_bitcrash"] > 0:
if random.uniform(0, 1) < augs["pedalboard_bitcrash"]:
bit_depth = random.uniform(
augs["pedalboard_bitcrash_bit_depth_min"],
augs["pedalboard_bitcrash_bit_depth_max"],
)
board = PB.Pedalboard([PB.Bitcrush(bit_depth=bit_depth)])
source = board(source, 44100)
applied_augs.append("pedalboard_bitcrash")
# Random MP3Compressor
if "pedalboard_mp3_compressor" in augs:
if augs["pedalboard_mp3_compressor"] > 0:
if random.uniform(0, 1) < augs["pedalboard_mp3_compressor"]:
vbr_quality = random.uniform(
augs["pedalboard_mp3_compressor_pedalboard_mp3_compressor_min"],
augs["pedalboard_mp3_compressor_pedalboard_mp3_compressor_max"],
)
board = PB.Pedalboard([PB.MP3Compressor(vbr_quality=vbr_quality)])
source = board(source, 44100)
applied_augs.append("pedalboard_mp3_compressor")
# print(applied_augs)
return source
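

if __name__ == "__main__":
    # Minimal smoke-test sketch (added; not part of the original module).
    # "./dataset" and "./results/metadata_1.pkl" are hypothetical paths, and
    # the data folder must contain track subfolders with vocals/other stems
    # for the demo to get past metadata collection. The config fields below
    # are exactly the ones MSSDataset reads for dataset_type 1.
    demo_config = ConfigDict(
        {
            "training": {
                "instruments": ["vocals", "other"],
                "batch_size": 4,
                "num_steps": 8,
                "target_instrument": None,
            },
            "audio": {"chunk_size": 485100, "min_mean_abs": 0.0},
        }
    )
    demo_dataset = MSSDataset(
        demo_config,
        data_path="./dataset",
        metadata_path="./results/metadata_1.pkl",
        dataset_type=1,
    )
    print("Samples per epoch:", len(demo_dataset))
    res, mix = demo_dataset[0]
    print("Stems shape:", tuple(res.shape), "Mix shape:", tuple(mix.shape))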