| |
| |
| ''' |
| @Project :Waveformer-main |
| @File :dataset_online.py |
| @IDE :PyCharm |
| @Author :Aisaka/Hao Ma @SDU |
| @Date :2023/11/1 下午6:47 |
| ''' |
| import os |
| import random |
|
|
| import torch |
| import torchaudio |
| import torchaudio.transforms as AT |
| import csv |
| import json |
| import numpy as np |
| import librosa |
|
|
|
|
def labels2caption(labels):
    """Turn a list of label strings into a caption sentence.

    A single label yields ``"The sound of <label>"``; any other number of
    labels yields ``"The sounds of <l1>, <l2>, ..."`` with the labels joined
    by ``", "``.
    """
    joined = ', '.join(labels)
    if len(labels) == 1:
        return "The sound of " + joined
    return "The sounds of " + joined
|
|
|
|
class CLAPSepDataSet(torch.utils.data.Dataset):
    """Dataset that mixes two labelled audio clips on the fly.

    ``data_list`` is a CSV file in which the first column of every row is an
    audio path and the remaining columns are that clip's labels. Each item is
    a target clip mixed (at 0 dB SNR) with a "noise" clip, together with text
    captions built from the labels and speed-perturbed positive/negative
    reference waveforms.

    Args:
        data_list: Path to the CSV file described above.
        dset: ``'train'`` (noise clip drawn randomly on every access) or
            ``'val'`` (noise clip fixed once at construction time).
        silence_rate: In ``'train'`` mode, probability of emitting an item
            whose target is all zeros and whose query caption belongs to a
            clip that is not part of the mixture.
        chunk_dur: Clip duration in seconds; clips are cut/zero-padded to it.
        sr: Sample rate used when loading audio. Must be provided for
            ``load_wav`` to work (it is multiplied by ``chunk_dur``).
        resample_rate: If given, waveforms are resampled from ``sr`` to this
            rate before augmentation; if ``None`` they are passed through
            unchanged.
    """

    # Sample rate handed to the speed-perturbation transform and used for the
    # fixed length that `pad_or_trim` enforces.
    AUG_SAMPLE_RATE = 48000

    def __init__(self, data_list, dset='', silence_rate=0.05, chunk_dur=10, sr=None, resample_rate=None):
        assert dset in ['train', 'val'], \
            "`dset` must be one of ['train', 'val']"
        self.dset = dset
        self.silence_rate = silence_rate
        self.chunk_dur = chunk_dur
        self.sr = sr
        self.resample_rate = resample_rate

        # path -> list of labels
        self.data_meta = dict()
        # "label1, label2" -> list of paths carrying exactly that label set
        self.text_dict = dict()
        with open(data_list, 'r', encoding='utf-8') as d:
            reader = csv.reader(d, skipinitialspace=True)
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; nothing to index.
                    continue
                assert os.path.exists(row[0]), row[0]
                self.data_meta[row[0]] = row[1:]
                label = ', '.join(row[1:])
                self.text_dict.setdefault(label, []).append(row[0])

        self.augmentation = torchaudio.transforms.SpeedPerturbation(self.AUG_SAMPLE_RATE, [0.9, 1.1])

        self.data_names = list(self.data_meta.keys())
        if dset == 'val':
            # Freeze one noise clip per target so validation items are stable
            # across epochs.
            self.noise_names = [
                self.choose_other_samples(', '.join(self.data_meta[name]), 1)[0]
                for name in self.data_names
            ]

        # Always define the resampler: __getitem__ calls it unconditionally.
        self.resampler = self._make_resampler(sr, resample_rate)

    @staticmethod
    def _make_resampler(sr, resample_rate):
        """Return a callable resampling ``sr`` -> ``resample_rate`` (identity if the latter is None)."""
        if resample_rate is None:
            return torch.nn.Identity()
        return AT.Resample(sr, resample_rate)

    def __len__(self):
        return len(self.data_names)

    def choose_other_samples(self, target_text, num):
        """Pick ``num`` paths whose joined-label caption differs from ``target_text``.

        ``num`` distinct captions are drawn first, then one random path per
        caption. Raises ``ValueError`` (from ``random.sample``) when fewer
        than ``num`` other captions exist.
        """
        candidates = [text for text in self.text_dict if text != target_text]
        chosen_text = random.sample(candidates, num)
        return [random.choice(self.text_dict[text]) for text in chosen_text]

    def _pick_disjoint_name(self, start_name, excluded_labels):
        """Return a path whose labels share nothing with ``excluded_labels``.

        Starts from ``start_name`` and keeps drawing uniformly from
        ``self.data_names`` until the label sets are disjoint.
        NOTE: this never terminates if every entry overlaps the excluded set.
        """
        excluded = set(excluded_labels)
        name = start_name
        while set(self.data_meta[name]) & excluded:
            name = random.choice(self.data_names)
        return name

    def load_wav(self, path):
        """Load ``path`` at ``self.sr`` and cut/zero-pad it to ``sr * chunk_dur`` samples."""
        # int() keeps slicing/padding valid when chunk_dur is a float.
        max_length = int(self.sr * self.chunk_dur)
        wav = librosa.core.load(path, sr=self.sr)[0]
        if len(wav) > max_length:
            wav = wav[0:max_length]
        if len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def __getitem__(self, idx):
        """Build one training/validation item.

        Returns:
            Tuple ``(mixed, mixed_resample, tgt_cap, neg_cap, tgt, pos_sample,
            neg_sample)`` where ``mixed`` is target + noise, ``mixed_resample``
            is ``mixed`` passed through the resampler, the two captions
            describe the target and the noise, ``tgt`` is the target waveform
            and the last two entries are speed-perturbed reference waveforms
            passed through ``pad_or_trim``.
        """
        tgt_name = self.data_names[idx]
        tgt_labels = self.data_meta[tgt_name]
        if self.dset == 'train':
            noise_name = self._pick_disjoint_name(tgt_name, tgt_labels)
        else:
            noise_name = self.noise_names[idx]
        noise_labels = self.data_meta[noise_name]

        snr = torch.zeros((1,))  # mix at 0 dB

        tgt = torch.tensor(self.load_wav(tgt_name)).unsqueeze(0)
        noise = torch.tensor(self.load_wav(noise_name)).unsqueeze(0)

        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)
        assert not torch.isnan(mixed).any(), f"tgt: {tgt_name}, noise: {noise_name}"
        pos_sample, _ = self.augmentation(self.resampler(tgt.squeeze()))
        neg_sample, _ = self.augmentation(self.resampler(noise.squeeze()))

        # Rescale target and mixture together so the mixture peaks at 0.9.
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value

        tgt = tgt.squeeze()
        mixed = mixed.squeeze()
        tgt_cap = labels2caption(tgt_labels)
        neg_cap = labels2caption(noise_labels)
        mixed_resample = self.resampler(mixed)

        if self.dset == 'train' and random.random() < self.silence_rate:
            # "Silent target" item: query a clip that is not in the mixture,
            # so the expected output is all zeros and the whole mixture is
            # the negative.
            other_name = self._pick_disjoint_name(tgt_name, tgt_labels + noise_labels)
            tgt = torch.zeros_like(mixed)
            neg_cap = labels2caption(tgt_labels + noise_labels)
            tgt_cap = labels2caption(self.data_meta[other_name])
            pos_sample, _ = self.augmentation(self.resampler(torch.tensor(self.load_wav(other_name))))
            neg_sample, _ = self.augmentation(mixed_resample)

        return mixed, mixed_resample, tgt_cap, neg_cap, tgt, self.pad_or_trim(pos_sample), self.pad_or_trim(neg_sample)

    def pad_or_trim(self, wav_in):
        """Force ``wav_in`` to ``AUG_SAMPLE_RATE * chunk_dur`` samples and cap its peak.

        Shorter inputs are zero-padded at the end, longer ones are truncated.
        If the absolute peak exceeds 1 the result is rescaled to peak at 0.9.
        The input tensor is never modified in place.
        """
        target_len = int(self.AUG_SAMPLE_RATE * self.chunk_dur)
        cur_len = wav_in.size(0)
        if cur_len < target_len:
            wav_in = torch.nn.functional.pad(wav_in, (0, target_len - cur_len))
        elif cur_len > target_len:
            wav_in = wav_in[:target_len]
        max_value = torch.max(torch.abs(wav_in))
        if max_value > 1:
            # Out-of-place so a caller-owned tensor (or a view of it) is untouched.
            wav_in = wav_in * (0.9 / max_value)
        return wav_in
|
|
|
|
class CLAPSepDataEngineDataSet(torch.utils.data.Dataset):
    """On-the-fly mixture dataset with optional "data engine" annotations.

    ``data_list`` is a CSV file in which the first column of every row is an
    audio path and the remaining columns are that clip's labels. By default
    each item is a target clip mixed (at 0 dB SNR) with a clip whose labels
    do not overlap the target's.

    ``data_engine_json`` (optional) maps a clip's base file name without
    extension to a list of ``[audio_path, caption]`` entries. In ``'train'``
    mode, for clips that have such an entry, half of the accesses use the
    original clip itself as the mixture, one randomly chosen entry as the
    target and the mean of the remaining entries as the negative.

    Args:
        data_list: Path to the CSV file described above.
        dset: ``'train'`` or ``'val'`` (noise clip fixed at construction).
        data_engine_json: Path to the JSON file described above; ignored if
            it does not exist.
        silence_rate: In ``'train'`` mode, probability of emitting an item
            whose target is all zeros.
        chunk_dur: Clip duration in seconds; clips are cut/zero-padded to it.
        sr: Sample rate used when loading audio. Must be provided for
            ``load_wav`` to work (it is multiplied by ``chunk_dur``).
        resample_rate: If given, waveforms are resampled from ``sr`` to this
            rate before augmentation; if ``None`` they are passed through
            unchanged.
    """

    # Sample rate handed to the speed-perturbation transform and used for the
    # fixed length that `pad_or_trim` enforces.
    AUG_SAMPLE_RATE = 48000

    def __init__(self, data_list, dset='', data_engine_json='', silence_rate=0.05, chunk_dur=10, sr=None, resample_rate=None):
        assert dset in ['train', 'val'], \
            "`dset` must be one of ['train', 'val']"
        self.dset = dset
        self.silence_rate = silence_rate
        self.chunk_dur = chunk_dur
        self.sr = sr
        self.resample_rate = resample_rate

        # path -> list of labels
        self.data_meta = dict()
        with open(data_list, 'r', encoding='utf-8') as d:
            reader = csv.reader(d, skipinitialspace=True)
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; nothing to index.
                    continue
                assert os.path.exists(row[0]), row[0]
                self.data_meta[row[0]] = row[1:]

        self.augmentation = torchaudio.transforms.SpeedPerturbation(self.AUG_SAMPLE_RATE, [0.9, 1.1])

        self.data_names = list(self.data_meta.keys())
        if dset == 'val':
            # Freeze one label-disjoint noise clip per target so validation
            # items are stable across epochs.
            self.noise_names = [
                self._pick_disjoint_name(name, self.data_meta[name])
                for name in self.data_names
            ]

        self.data_engine_dict = {}
        if os.path.exists(data_engine_json):
            # Context manager so the file handle is closed deterministically.
            with open(data_engine_json, 'r') as f:
                self.data_engine_dict = json.load(f)

        # Always define the resampler: __getitem__ calls it unconditionally.
        self.resampler = self._make_resampler(sr, resample_rate)

    @staticmethod
    def _make_resampler(sr, resample_rate):
        """Return a callable resampling ``sr`` -> ``resample_rate`` (identity if the latter is None)."""
        if resample_rate is None:
            return torch.nn.Identity()
        return AT.Resample(sr, resample_rate)

    def __len__(self):
        return len(self.data_names)

    def _pick_disjoint_name(self, start_name, excluded_labels):
        """Return a path whose labels share nothing with ``excluded_labels``.

        Starts from ``start_name`` and keeps drawing uniformly from
        ``self.data_names`` until the label sets are disjoint.
        NOTE: this never terminates if every entry overlaps the excluded set.
        """
        excluded = set(excluded_labels)
        name = start_name
        while set(self.data_meta[name]) & excluded:
            name = random.choice(self.data_names)
        return name

    def _split_engine_items(self, video):
        """Randomly pick one data-engine entry for ``video`` as the target.

        Returns ``(target_item, remaining_items)``. Works on a copy so the
        list stored in ``self.data_engine_dict`` is left intact for later
        accesses.
        """
        items = list(self.data_engine_dict[video])
        tgt_item = items.pop(random.randrange(len(items)))
        return tgt_item, items

    def load_wav(self, path):
        """Load ``path`` at ``self.sr`` and cut/zero-pad it to ``sr * chunk_dur`` samples."""
        # int() keeps slicing/padding valid when chunk_dur is a float.
        max_length = int(self.sr * self.chunk_dur)
        wav = librosa.core.load(path, sr=self.sr)[0]
        if len(wav) > max_length:
            wav = wav[0:max_length]
        if len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def __getitem__(self, idx):
        """Build one training/validation item.

        Returns:
            Tuple ``(mixed, mixed_resample, tgt_cap, neg_cap, tgt, pos_sample,
            neg_sample)`` where ``mixed_resample`` is ``mixed`` passed through
            the resampler, the two captions describe the target and the
            negative, and the last two entries are speed-perturbed reference
            waveforms passed through ``pad_or_trim``.
        """
        tgt_name = self.data_names[idx]
        tgt_labels = self.data_meta[tgt_name]
        if self.dset == 'train':
            noise_name = self._pick_disjoint_name(tgt_name, tgt_labels)
        else:
            noise_name = self.noise_names[idx]
        noise_labels = self.data_meta[noise_name]

        snr = torch.zeros((1,))  # mix at 0 dB

        tgt = torch.tensor(self.load_wav(tgt_name)).unsqueeze(0)
        noise = torch.tensor(self.load_wav(noise_name)).unsqueeze(0)

        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)

        pos_sample, _ = self.augmentation(self.resampler(tgt.squeeze()))
        noise = noise.squeeze()

        # Rescale target and mixture together so the mixture peaks at 0.9.
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value

        tgt = tgt.squeeze()
        mixed = mixed.squeeze()
        tgt_cap = labels2caption(tgt_labels)
        neg_cap = labels2caption(noise_labels)
        mixed_resample = self.resampler(mixed)

        # Data-engine key: file name without directory and extension.
        video = os.path.splitext(os.path.basename(tgt_name))[0]
        engine_items = self.data_engine_dict.get(video) if self.dset == 'train' else None

        if engine_items and random.random() > 0.5:
            # Use the original clip itself as the mixture and one of its
            # data-engine entries as the separation target.
            # NOTE(review): pos_sample still comes from the original clip
            # here, not from the chosen entry - confirm this is intended.
            mixed = tgt
            mixed_resample = self.resampler(mixed)
            tgt_item, other_items = self._split_engine_items(video)
            tgt = torch.tensor(self.load_wav(tgt_item[0]))
            max_value = torch.max(torch.abs(tgt))
            if max_value > 1:
                tgt *= 0.9 / max_value
            tgt_cap = tgt_item[1]
            if other_items:
                # The remaining entries, averaged, form the negative.
                noises = [torch.tensor(self.load_wav(x[0])) for x in other_items]
                noise = torch.mean(torch.stack(noises, dim=0), dim=0)
                neg_cap = labels2caption([x[1] for x in other_items])
        elif self.dset == 'train' and random.random() < self.silence_rate:
            # "Silent target" item: query a clip that is not in the mixture,
            # so the expected output is all zeros and the whole mixture is
            # the negative.
            other_name = self._pick_disjoint_name(tgt_name, tgt_labels + noise_labels)
            tgt = torch.zeros_like(mixed)
            neg_cap = labels2caption(tgt_labels + noise_labels)
            tgt_cap = labels2caption(self.data_meta[other_name])
            pos_sample, _ = self.augmentation(self.resampler(torch.tensor(self.load_wav(other_name))))
            noise = mixed

        neg_sample, _ = self.augmentation(self.resampler(noise))

        return mixed, mixed_resample, tgt_cap, neg_cap, tgt, self.pad_or_trim(pos_sample), self.pad_or_trim(neg_sample)

    def pad_or_trim(self, wav_in):
        """Force ``wav_in`` to ``AUG_SAMPLE_RATE * chunk_dur`` samples and cap its peak.

        Shorter inputs are zero-padded at the end, longer ones are truncated.
        If the absolute peak exceeds 1 the result is rescaled to peak at 0.9.
        The input tensor is never modified in place.
        """
        target_len = int(self.AUG_SAMPLE_RATE * self.chunk_dur)
        cur_len = wav_in.size(0)
        if cur_len < target_len:
            wav_in = torch.nn.functional.pad(wav_in, (0, target_len - cur_len))
        elif cur_len > target_len:
            wav_in = wav_in[:target_len]
        max_value = torch.max(torch.abs(wav_in))
        if max_value > 1:
            # Out-of-place so a caller-owned tensor (or a view of it) is untouched.
            wav_in = wav_in * (0.9 / max_value)
        return wav_in
|
|
|
|