import os

import numpy as np
import soundfile
import torch
import torchaudio
class AudioLoader:
    """Load audio files as 1-D waveforms at a fixed target sample rate."""

    def __init__(self, sample_rate=16000):
        # Rate that every loaded waveform is resampled to.
        self.sample_rate = sample_rate

    @staticmethod
    def _to_mono(audio):
        """Average a (channels, samples) tensor down to a single channel.

        Tensors that already have one channel (or are 1-D) are returned
        unchanged, so mono files follow exactly the same path as before.
        """
        if audio.dim() > 1 and audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        return audio

    def load_audio(self, file_path):
        """Load ``file_path`` and return it as a 1-D tensor at ``self.sample_rate``.

        torchaudio.load returns a channels-first (channels, frames) tensor.
        Multichannel files are averaged to mono first; otherwise squeezing
        axis 0 would leave a 2-D tensor that later STFT/padding code
        misinterprets.
        """
        audio, sample_rate = torchaudio.load(file_path, backend='soundfile')
        # Downmix before resampling so only one channel has to be resampled.
        audio = self._to_mono(audio)
        if sample_rate != self.sample_rate:
            audio = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )(audio)
        return audio.squeeze(0)


class STFT:
    """Short-time Fourier transform with fixed framing and a Hamming window."""

    def __init__(self, n_fft=1024, hop_length=512, win_length=1024):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def compute_stft(self, signal):
        """Return the complex STFT of ``signal`` via ``torch.stft``.

        The window is created on the same device as ``signal`` (and with its
        dtype when it is floating point). A CPU float32 window would otherwise
        fail against CUDA input and not match non-float32 signals.
        """
        window_dtype = signal.dtype if signal.is_floating_point() else None
        window = torch.hamming_window(
            self.win_length, device=signal.device, dtype=window_dtype
        )
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            return_complex=True,
        )


class SpectrogramSaver:
    """Persist spectrogram tensors to disk."""

    @staticmethod
    def save_spectrogram(spectrogram, save_path):
        """Serialize ``spectrogram`` to ``save_path`` using ``torch.save``."""
        torch.save(obj=spectrogram, f=save_path)


class Preprocessing:
    """Turn a directory of ``.wav`` files into fixed-length STFT tensors.

    Each file is loaded, transformed with an STFT, split into a trailing
    real/imaginary axis, padded or truncated along axis 1 (frames) to
    ``fixed_length``, and saved as ``noisy_<stem>.pt`` under ``save_dir/noisy``.
    """

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=512, win_length=1024):
        self.loader = AudioLoader(sample_rate)
        self.stft = STFT(n_fft, hop_length, win_length)
        self.saver = SpectrogramSaver()
        # Target frame count for pad_spectrogram. Left as None until
        # determine_fixed_length runs (create_dataset calls it lazily).
        self.fixed_length = None

    @staticmethod
    def _list_wav_files(directory):
        """Return sorted full paths of ``.wav`` files in ``directory`` (any case)."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith('.wav')
        )

    def preprocess(self, signal):
        """Return the STFT of ``signal`` with real and imaginary parts stacked last.

        For a 1-D signal, torch.stft yields (freq_bins, frames), so the result
        is (freq_bins, frames, 2).
        """
        spectrogram = self.stft.compute_stft(signal)
        return torch.stack((spectrogram.real, spectrogram.imag), dim=-1)

    def determine_fixed_length(self, noisy_dir):
        """Set ``self.fixed_length`` to the median frame count of files in ``noisy_dir``.

        Every file is loaded and transformed, so this is a full pass over
        the data.

        Raises:
            ValueError: if ``noisy_dir`` contains no ``.wav`` files.
        """
        lengths = [
            self.preprocess(self.loader.load_audio(noisy_file)).shape[1]
            for noisy_file in self._list_wav_files(noisy_dir)
        ]
        if not lengths:
            # np.median([]) would warn and then fail with a cryptic NaN->int error.
            raise ValueError(f"No .wav files found in {noisy_dir!r}")

        self.fixed_length = int(np.median(lengths))
        print(f"Determined fixed length: {self.fixed_length}")

    def create_dataset(self, noisy_dir, save_dir):
        """Preprocess every ``.wav`` in ``noisy_dir`` and save it under ``save_dir/noisy``."""
        if self.fixed_length is None:
            self.determine_fixed_length(noisy_dir)

        noisy_save_dir = os.path.join(save_dir, 'noisy')
        os.makedirs(noisy_save_dir, exist_ok=True)

        for noisy_file in self._list_wav_files(noisy_dir):
            noisy_audio = self.loader.load_audio(noisy_file)
            noisy_spectrogram = self.pad_spectrogram(self.preprocess(noisy_audio))
            # splitext keeps dots inside the name ("a.b.wav" -> "a.b"), so
            # distinct inputs cannot collide on the same output file.
            stem = os.path.splitext(os.path.basename(noisy_file))[0]
            noisy_save_path = os.path.join(noisy_save_dir, f"noisy_{stem}.pt")
            self.saver.save_spectrogram(noisy_spectrogram, noisy_save_path)

    def pad_spectrogram(self, spectrogram):
        """Zero-pad or truncate axis 1 of ``spectrogram`` to ``self.fixed_length``.

        Padding is allocated with ``new_zeros`` so it matches the input's
        dtype and device (CPU float32 zeros would fail to concatenate with
        CUDA tensors).
        """
        pad_length = self.fixed_length - spectrogram.shape[1]
        if pad_length > 0:
            pad = spectrogram.new_zeros(
                (spectrogram.shape[0], pad_length, *spectrogram.shape[2:])
            )
            spectrogram = torch.cat((spectrogram, pad), dim=1)
        elif pad_length < 0:
            spectrogram = spectrogram[:, :self.fixed_length]
        return spectrogram