szuweifu commited on
Commit
d78f08c
·
verified ·
1 Parent(s): 6380bc7

Upload 11 files

Browse files
exp/30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_peak_GAN_tel_mic/g_01134000.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e27db9e1de904eb59fc627dea72c69da7ca25650a3e704b4096f89812b395fe5
3
+ size 38982886
inference.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import argparse
11
+ import torch
12
+ import torchaudio
13
+ import librosa
14
+ from models.stfts import mag_phase_stft, mag_phase_istft
15
+ from models.generator_SEMamba_time_d4 import SEMamba
16
+ from utils.util import load_config, pad_or_trim_to_match
17
+
18
def get_filepaths(directory, file_type=None):
    """Recursively collect the paths of all files under *directory*.

    Args:
        directory: Root directory to walk.
        file_type: Optional extension WITHOUT the leading dot (e.g. "wav").
            When given, only files with exactly that extension are returned.

    Returns:
        List of full file paths (walk order, not sorted).
    """
    file_paths = []
    for root, _directories, files in os.walk(directory):
        for filename in files:
            filepath = os.path.join(root, filename)
            # os.path.splitext is robust against dots in directory names and
            # extensionless files, unlike the naive filepath.split('.')[-1].
            if file_type is None or os.path.splitext(filename)[1] == '.' + file_type:
                file_paths.append(filepath)
    return file_paths
31
+
32
def make_even(value):
    """Round *value* to the nearest integer, bumping odd results up by one.

    Used to keep scaled STFT parameters (n_fft / hop / win) even.
    """
    rounded = int(round(value))
    if rounded % 2:
        rounded += 1
    return rounded
35
+
36
def inference(args, device):
    """Run single-pass speech enhancement over every file in args.input_folder.

    Enhanced audio is written to args.output_folder as FLAC, mirroring the
    input directory layout. Files that cannot be read are skipped with a
    warning. Requires a CUDA device (enforced by main()).

    Args:
        args: Parsed CLI namespace (input_folder, output_folder, config,
            checkpoint_file, BWE).
        device: torch.device to run the model on.
    """
    cfg = load_config(args.config)
    n_fft, hop_size, win_size = cfg['stft_cfg']['n_fft'], cfg['stft_cfg']['hop_size'], cfg['stft_cfg']['win_size']
    compress_factor = cfg['model_cfg']['compress_factor']
    sampling_rate = cfg['stft_cfg']['sampling_rate']

    SE_model = SEMamba(cfg).to(device)
    state_dict = torch.load(args.checkpoint_file, map_location=device)
    SE_model.load_state_dict(state_dict['generator'])
    SE_model.eval()

    os.makedirs(args.output_folder, exist_ok=True)
    with torch.no_grad():
        for fname in get_filepaths(args.input_folder):
            print(fname)
            try:
                # Mirror the input sub-directory layout under the output folder.
                # os.path.relpath/join replaces the fragile string concatenation
                # and '/'-rfind logic of the previous version.
                rel_path = os.path.relpath(fname, args.input_folder)
                out_dir = os.path.join(args.output_folder, os.path.dirname(rel_path))
                os.makedirs(out_dir, exist_ok=True)
                noisy_wav, noisy_sr = torchaudio.load(fname)
            except Exception as e:
                print(f"Warning: cannot read {fname}, skipping. ({e})")
                continue

            if args.BWE is not None:
                # Bandwidth extension: resample to the requested rate first.
                noisy_wav = librosa.resample(noisy_wav.cpu().numpy(), orig_sr=noisy_sr,
                                             target_sr=int(args.BWE), res_type="kaiser_best")
                noisy_sr = int(args.BWE)

            noisy_wav = torch.FloatTensor(noisy_wav).to(device)
            # Scale STFT parameters to the file's sample rate so the model
            # stays sampling-frequency independent; keep them even.
            n_fft_scaled = make_even(n_fft * noisy_sr // sampling_rate)
            hop_size_scaled = make_even(hop_size * noisy_sr // sampling_rate)
            win_size_scaled = make_even(win_size * noisy_sr // sampling_rate)

            noisy_mag, noisy_pha, noisy_com = mag_phase_stft(
                noisy_wav,
                n_fft=n_fft_scaled,
                hop_size=hop_size_scaled,
                win_size=win_size_scaled,
                compress_factor=compress_factor,
                center=True,
                addeps=False
            )
            amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)

            audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
            audio_g = pad_or_trim_to_match(noisy_wav.detach(), audio_g, pad_value=1e-8)  # align lengths using epsilon padding
            assert audio_g.shape == noisy_wav.shape, audio_g.shape

            # Swap only the file's extension for .flac. splitext fixes the old
            # split('.')[0], which truncated paths containing a dot anywhere.
            base, _ext = os.path.splitext(rel_path)
            output_file = os.path.join(args.output_folder, base + '.flac')
            torchaudio.save(output_file, audio_g.cpu(), noisy_sr)
85
+
86
def main():
    """CLI entry point: parse arguments, pick the GPU device, run inference."""
    print('Initializing Inference Process...')
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_folder')
    parser.add_argument('--output_folder')
    parser.add_argument('--config')
    parser.add_argument('--checkpoint_file', required=True)
    parser.add_argument('--BWE', default=None)
    args = parser.parse_args()

    global device
    # Guard clause: GPU is mandatory for this model.
    if not torch.cuda.is_available():
        raise RuntimeError("Currently, CPU mode is not supported.")
    device = torch.device('cuda')

    inference(args, device)


if __name__ == '__main__':
    main()
inference.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Single-pass enhancement: ./noisy_audio -> ./enhanced_audio (FLAC).
CUDA_VISIBLE_DEVICES='0' python ./inference.py \
    --input_folder ./noisy_audio \
    --output_folder ./enhanced_audio \
    --checkpoint_file ./exp/30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_peak_GAN_tel_mic/g_01134000.pth \
    --config ./recipes/USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml
# Optional bandwidth extension: add `--BWE 32000` as a final argument (and a
# trailing `\` to the --config line). The previous version left a dangling
# `\` pointing at this comment line, which is fragile if the comment moves.
inference_chunk.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import argparse
11
+ import torch
12
+ import torchaudio
13
+ import librosa
14
+ import math
15
+ from models.stfts import mag_phase_stft, mag_phase_istft
16
+ from models.generator_SEMamba_time_d4 import SEMamba
17
+ from utils.util import load_config, pad_or_trim_to_match
18
+
19
+
20
def get_filepaths(directory, file_type=None):
    """Walk *directory* recursively and return the full paths of all files.

    When *file_type* is given, keep only paths whose text after the last dot
    equals it (e.g. "wav").
    """
    collected = []
    for root, _subdirs, names in os.walk(directory):
        for name in names:
            full_path = os.path.join(root, name)
            keep = file_type is None or full_path.split('.')[-1] == file_type
            if keep:
                collected.append(full_path)
    return collected
33
+
34
def make_even(value):
    """Nearest integer to *value*, rounded up to the next even number if odd."""
    as_int = int(round(value))
    return as_int + (as_int % 2)
37
+
38
def inference(args, device):
    """Chunk-wise speech enhancement for long recordings.

    Each file is split into overlapping chunks of args.chunk_size_in_seconds,
    enhanced independently, and recombined with Hann-weighted overlap-add.
    Results are written to args.output_folder as FLAC, mirroring the input
    directory layout.

    Args:
        args: Parsed CLI namespace (input_folder, output_folder, config,
            checkpoint_file, chunk_size_in_seconds, hop_length_portion, BWE).
        device: torch.device to run the model on.
    """
    cfg = load_config(args.config)
    n_fft, hop_size, win_size = cfg['stft_cfg']['n_fft'], cfg['stft_cfg']['hop_size'], cfg['stft_cfg']['win_size']
    compress_factor = cfg['model_cfg']['compress_factor']
    sampling_rate = cfg['stft_cfg']['sampling_rate']

    SE_model = SEMamba(cfg).to(device)
    state_dict = torch.load(args.checkpoint_file, map_location=device)
    SE_model.load_state_dict(state_dict['generator'])
    SE_model.eval()

    os.makedirs(args.output_folder, exist_ok=True)
    with torch.no_grad():
        for fname in get_filepaths(args.input_folder):
            print(fname)
            try:
                # Mirror the input sub-directory layout under the output folder;
                # relpath/join replaces the old fragile string concatenation.
                rel_path = os.path.relpath(fname, args.input_folder)
                os.makedirs(os.path.join(args.output_folder, os.path.dirname(rel_path)), exist_ok=True)
                Noisy_wav, noisy_sr = torchaudio.load(fname)
            except Exception as e:
                print(f"Warning: cannot read {fname}, skipping. ({e})")
                continue

            if args.BWE is not None:
                # Bandwidth extension: resample to the requested rate first.
                Noisy_wav = librosa.resample(Noisy_wav.cpu().numpy(), orig_sr=noisy_sr,
                                             target_sr=int(args.BWE), res_type="kaiser_best")
                noisy_sr = int(args.BWE)

            chunk_size = int(args.chunk_size_in_seconds * noisy_sr)  # samples per chunk
            hop_length = int(args.hop_length_portion * chunk_size)   # samples between chunk starts
            window = torch.hann_window(chunk_size).to(device)

            # Scale STFT parameters to the file's sample rate so the model
            # stays sampling-frequency independent; keep them even.
            n_fft_scaled = make_even(n_fft * noisy_sr // sampling_rate)
            hop_size_scaled = make_even(hop_size * noisy_sr // sampling_rate)
            win_size_scaled = make_even(win_size * noisy_sr // sampling_rate)

            Noisy_wav = torch.FloatTensor(Noisy_wav).to(device)
            audio_enhanced = torch.zeros_like(Noisy_wav)
            window_sum = torch.zeros_like(Noisy_wav)

            for c in range(Noisy_wav.shape[0]):  # process each channel independently
                noisy_wav = Noisy_wav[c:c + 1, :]
                num_chunks = max(1, math.ceil((noisy_wav.shape[1] - chunk_size) / hop_length) + 1)
                for i in range(num_chunks):
                    start = i * hop_length
                    noisy_wav_chunk = noisy_wav[:, start:start + chunk_size]

                    noisy_mag, noisy_pha, noisy_com = mag_phase_stft(
                        noisy_wav_chunk,
                        n_fft=n_fft_scaled,
                        hop_size=hop_size_scaled,
                        win_size=win_size_scaled,
                        compress_factor=compress_factor,
                        center=True,
                        addeps=False
                    )
                    amp_g, pha_g, _ = SE_model(noisy_mag, noisy_pha)

                    audio_g = mag_phase_istft(amp_g, pha_g, n_fft_scaled, hop_size_scaled, win_size_scaled, compress_factor)
                    audio_g = pad_or_trim_to_match(noisy_wav_chunk.detach(), audio_g, pad_value=1e-8)  # align lengths

                    # Hann-weighted overlap-add; the final chunk may be shorter
                    # than the window, so clip the window to the chunk length.
                    n = audio_g.shape[1]
                    audio_enhanced[c:c + 1, start:start + n] += audio_g * window[:n]
                    window_sum[c:c + 1, start:start + n] += window[:n]

            # Normalize by the accumulated window weight wherever it is nonzero
            # (the Hann window is zero at its edges, so guard against /0).
            nonzero_indices = (window_sum > 1e-8)
            audio_enhanced[:, nonzero_indices[0]] = audio_enhanced[:, nonzero_indices[0]] / window_sum[:, nonzero_indices[0]]
            assert audio_enhanced.shape == Noisy_wav.shape, audio_enhanced.shape

            # Swap only the extension for .flac; splitext fixes the old
            # split('.')[0], which truncated paths containing a dot anywhere.
            base, _ext = os.path.splitext(rel_path)
            output_file = os.path.join(args.output_folder, base + '.flac')
            torchaudio.save(output_file, audio_enhanced.cpu(), noisy_sr)
104
+
105
def main():
    """CLI entry point: parse arguments, pick the GPU device, run chunked inference."""
    print('Initializing Inference Process..')
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_folder')
    parser.add_argument('--output_folder')
    parser.add_argument('--config')
    # Required, matching inference.py; previously a missing checkpoint only
    # surfaced later as an unhelpful torch.load error on None.
    parser.add_argument('--checkpoint_file', required=True)
    # Required: inference() multiplies these immediately, so omitting them
    # used to crash with a TypeError on None instead of a clear CLI error.
    parser.add_argument('--chunk_size_in_seconds', type=float, required=True)
    parser.add_argument('--hop_length_portion', type=float, required=True)
    parser.add_argument('--BWE', default=None)
    args = parser.parse_args()

    global device
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        raise RuntimeError("Currently, CPU mode is not supported.")

    inference(args, device)


if __name__ == '__main__':
    main()
127
+
inference_chunk.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
# Chunked (overlap-add) enhancement for long recordings:
# ./long_noisy_audio -> ./long_enhanced_audio (FLAC).
CUDA_VISIBLE_DEVICES='0' python ./inference_chunk.py \
    --input_folder ./long_noisy_audio \
    --output_folder ./long_enhanced_audio \
    --checkpoint_file ./exp/30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_peak_GAN_tel_mic/g_01134000.pth \
    --config ./recipes/USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml \
    --chunk_size_in_seconds 5 \
    --hop_length_portion 0.5
# Fixed: `5\` and `0.5\` previously had no space before the continuation
# backslash, gluing the value to the next line's text.
# Optional bandwidth extension: add `--BWE 32000` as a final argument (and a
# trailing `\` to the line above).
+
models/codec_module_time_d4.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ from einops import rearrange
13
+
14
def get_padding_2d(kernel_size, dilation=(1, 1)):
    """Return the (height, width) zero-padding for a dilated 2D convolution.

    For odd kernel sizes this keeps the spatial output size equal to the
    input size ("same" padding).

    Args:
        kernel_size: (height, width) of the convolution kernel.
        dilation: (height, width) dilation rates. Defaults to (1, 1).

    Returns:
        Tuple of (pad_height, pad_width).
    """
    pad_h = (kernel_size[0] - 1) * dilation[0] // 2
    pad_w = (kernel_size[1] - 1) * dilation[1] // 2
    return (pad_h, pad_w)
27
+
28
class SPConvTranspose2d(nn.Module):
    """Sub-pixel "transposed" 2D convolution.

    A plain Conv2d produces r * out_channels feature maps; the extra factor r
    is then folded into the last (width) axis, upsampling it by r. Width is
    padded by one column on each side before the convolution.
    """
    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        super(SPConvTranspose2d, self).__init__()
        self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))
        self.r = r

    def forward(self, x):
        feat = self.conv(self.pad1(x))
        batch_size, nchannels, height, width = feat.shape
        # Split channels into (r, out_channels), then interleave the r copies
        # along the width axis.
        feat = feat.view(batch_size, self.r, nchannels // self.r, height, width)
        feat = feat.permute(0, 2, 3, 4, 1).contiguous()
        return feat.view(batch_size, nchannels // self.r, height, -1)
44
+
45
class DenseBlock(nn.Module):
    """Densely connected stack of dilated 2D convolutions.

    Each layer consumes the channel-wise concatenation of the block input and
    every previous layer's output; the time-axis dilation doubles per layer
    (1, 2, 4, ...). The final layer's output (hid_feature channels) is returned.
    """
    def __init__(self, cfg, kernel_size=(3, 3), depth=4):
        super(DenseBlock, self).__init__()
        self.cfg = cfg
        self.depth = depth
        self.dense_block = nn.ModuleList()
        self.hid_feature = cfg['model_cfg']['hid_feature']

        for layer_idx in range(depth):
            dilation = 2 ** layer_idx
            self.dense_block.append(nn.Sequential(
                nn.Conv2d(self.hid_feature * (layer_idx + 1), self.hid_feature, kernel_size,
                          dilation=(dilation, 1),
                          padding=get_padding_2d(kernel_size, (dilation, 1))),
                nn.InstanceNorm2d(self.hid_feature, affine=True),
                nn.PReLU(self.hid_feature),
            ))

    def forward(self, x):
        out = x
        features = x
        for layer in self.dense_block:
            out = layer(features)
            # Grow the running feature stack for the next layer's input.
            features = torch.cat([out, features], dim=1)
        return out
72
+
73
class DenseEncoder(nn.Module):
    """Encoder: 1x1 input projection, a DenseBlock, then a strided convolution.

    The final conv has stride (4, 2) over (time, freq), downsampling time by 4
    and frequency by roughly 2.
    """
    def __init__(self, cfg):
        super(DenseEncoder, self).__init__()
        self.cfg = cfg
        self.input_channel = cfg['model_cfg']['input_channel']
        self.hid_feature = cfg['model_cfg']['hid_feature']

        # 1x1 conv lifting the input channels to the hidden width.
        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(self.input_channel, self.hid_feature, (1, 1)),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

        self.dense_block = DenseBlock(cfg, depth=4)

        # Strided (1, 3) conv that performs the (time, freq) downsampling.
        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(self.hid_feature, self.hid_feature, (1, 3), stride=(4, 2)),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

    def forward(self, x):
        """[B, input_channel, T, F] -> [B, hid_feature, ~T/4, ~F/2]."""
        hidden = self.dense_conv_1(x)
        hidden = self.dense_block(hidden)
        return self.dense_conv_2(hidden)
102
+
103
class MagDecoder(nn.Module):
    """Decoder producing the enhanced magnitude map.

    A DenseBlock is followed by two sub-pixel upsampling stages — the first
    widens the last (freq) axis by 2, the second (applied with time/freq
    swapped) widens time by 4 — and a final 1x1 projection to output_channel.
    """
    def __init__(self, cfg):
        super(MagDecoder, self).__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']
        self.n_fft = cfg['stft_cfg']['n_fft']
        self.beta = cfg['model_cfg']['beta']

        self.up_conv1 = nn.Sequential(
            SPConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), 2),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

        self.up_conv2 = nn.Sequential(
            SPConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), 4),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

        self.final_conv = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))

    def forward(self, x):
        feats = self.dense_block(x)
        feats = self.up_conv1(feats)  # upsample the last (freq) axis by 2
        # up_conv2 also widens the last axis; swap time/freq around it so the
        # x4 upsampling lands on the time axis instead.
        swapped = self.up_conv2(feats.permute(0, 1, 3, 2))
        feats = swapped.permute(0, 1, 3, 2)
        return self.final_conv(feats)
135
+
136
class PhaseDecoder(nn.Module):
    """Decoder producing the enhanced phase map.

    Mirrors MagDecoder's upsampling path, but ends with two parallel 1x1
    projections (real / imaginary parts) combined via atan2 into a phase in
    (-pi, pi].
    """
    def __init__(self, cfg):
        super(PhaseDecoder, self).__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']

        self.up_conv1 = nn.Sequential(
            SPConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), 2),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

        self.up_conv2 = nn.Sequential(
            SPConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), 4),
            nn.InstanceNorm2d(self.hid_feature, affine=True),
            nn.PReLU(self.hid_feature),
        )

        self.phase_conv_r = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))
        self.phase_conv_i = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))

    def forward(self, x):
        feats = self.dense_block(x)
        feats = self.up_conv1(feats)  # upsample the last (freq) axis by 2
        # Swap time/freq so up_conv2's x4 upsampling lands on the time axis.
        feats = self.up_conv2(feats.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
        real = self.phase_conv_r(feats)
        imag = self.phase_conv_i(feats)
        return torch.atan2(imag, real)
models/generator_SEMamba_time_d4.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from .mamba_block2_SEMamba import TFMambaBlock
13
+ from .codec_module_time_d4 import DenseEncoder, MagDecoder, PhaseDecoder
14
+
15
class SEMamba(nn.Module):
    """Speech-enhancement network built from a dense encoder, a stack of
    time-frequency Mamba blocks, and separate magnitude / phase decoders.
    """
    def __init__(self, cfg):
        """Build the model from the configuration dictionary *cfg*."""
        super(SEMamba, self).__init__()
        self.cfg = cfg
        # Fall back to 4 TF-Mamba blocks when the config leaves it unset.
        configured_blocks = cfg['model_cfg']['num_tfmamba']
        self.num_tscblocks = 4 if configured_blocks is None else configured_blocks

        self.dense_encoder = DenseEncoder(cfg)
        self.TSMamba = nn.ModuleList(TFMambaBlock(cfg) for _ in range(self.num_tscblocks))
        self.mask_decoder = MagDecoder(cfg)
        self.phase_decoder = PhaseDecoder(cfg)

    def forward(self, noisy_mag, noisy_pha):
        """Enhance a noisy magnitude/phase pair.

        Args:
            noisy_mag: Noisy magnitude tensor [B, F, T].
            noisy_pha: Noisy phase tensor [B, F, T].

        Returns:
            Tuple of (denoised_mag [B, F, T], denoised_pha [B, F, T],
            denoised_com [B, F, T, 2] — real/imag stacked on the last axis).
        """
        # [B, F, T] -> [B, 1, T, F]
        mag_in = noisy_mag.permute(0, 2, 1).unsqueeze(1)
        pha_in = noisy_pha.permute(0, 2, 1).unsqueeze(1)

        # Stack magnitude and phase as two input channels: [B, 2, T, F].
        x = torch.cat((mag_in, pha_in), dim=1)

        # Zero-pad 2 frames and 2 bins (original note: "prevent unpredictable
        # errors" — presumably around the encoder's strided conv); the padding
        # is cropped back off after decoding.
        B, C, T, F = x.shape
        x = torch.cat((x, torch.zeros(B, C, T, 2, device=x.device)), dim=-1)
        x = torch.cat((x, torch.zeros(B, C, 2, F + 2, device=x.device)), dim=-2)

        x = self.dense_encoder(x)

        for mamba_block in self.TSMamba:
            x = mamba_block(x)

        # [B, 1, T', F'] -> [B, F', T']
        denoised_mag = self.mask_decoder(x).permute(0, 3, 2, 1).squeeze(-1)
        denoised_pha = self.phase_decoder(x).permute(0, 3, 2, 1).squeeze(-1)

        # Crop back to the un-padded size.
        denoised_mag = denoised_mag[:, :F, :T]
        denoised_pha = denoised_pha[:, :F, :T]

        # Recombine magnitude and phase into a real/imag pair.
        denoised_com = torch.stack(
            (denoised_mag * torch.cos(denoised_pha), denoised_mag * torch.sin(denoised_pha)),
            dim=-1,
        )

        return denoised_mag, denoised_pha, denoised_com
models/mamba_block2_SEMamba.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.nn import init
13
+ from torch.nn.parameter import Parameter
14
+ from functools import partial
15
+ from einops import rearrange
16
+ from mamba_ssm import Mamba
17
+
18
class MambaBlock(nn.Module):
    """Bidirectional Mamba over a [B, T, D] sequence.

    One Mamba runs forward and one over the time-reversed sequence; each gets
    a residual connection, the two results are concatenated on the feature
    axis, projected back to d_model, and layer-normalized.
    """
    def __init__(self, d_model, cfg):
        super(MambaBlock, self).__init__()

        d_state = cfg['model_cfg']['d_state']    # e.g. 16
        d_conv = cfg['model_cfg']['d_conv']      # e.g. 4
        expand = cfg['model_cfg']['expand']      # e.g. 4

        self.forward_blocks = Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        self.backward_blocks = Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        self.output_proj = nn.Linear(2 * d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """[B, T, D] -> [B, T, D]."""
        # Forward direction with residual.
        fwd = x + self.forward_blocks(x)

        # Backward direction: flip time, process with residual, flip back.
        flipped = torch.flip(x, dims=[1])
        bwd = torch.flip(flipped + self.backward_blocks(flipped), dims=[1])

        combined = torch.cat([fwd, bwd], dim=-1)
        return self.norm(self.output_proj(combined))
43
+
44
+
45
class TFMambaBlock(nn.Module):
    """Temporal-Frequency Mamba block.

    Applies a bidirectional MambaBlock along the time axis (one sequence per
    frequency bin) and then along the frequency axis (one sequence per time
    frame), each with a residual connection.
    """
    def __init__(self, cfg):
        super(TFMambaBlock, self).__init__()
        self.cfg = cfg
        self.hid_feature = cfg['model_cfg']['hid_feature']

        self.time_mamba = MambaBlock(d_model=self.hid_feature, cfg=cfg)
        self.freq_mamba = MambaBlock(d_model=self.hid_feature, cfg=cfg)

    def forward(self, x):
        """(batch, channels, time, freq) -> same shape.

        Folds the non-sequence axis into the batch for each Mamba pass.
        """
        b, c, t, f = x.size()

        # Time pass: sequences of length t, one per (batch, freq) pair.
        seq = x.permute(0, 3, 2, 1).contiguous().view(b * f, t, c)
        seq = seq + self.time_mamba(seq)

        # Frequency pass: sequences of length f, one per (batch, time) pair.
        seq = seq.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b * t, f, c)
        seq = seq + self.freq_mamba(seq)

        return seq.view(b, t, f, c).permute(0, 3, 1, 2)
models/stfts.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
def decompress_signed_log1p(y):
    """Invert signed-log1p compression: y = sign(x) * log1p(|x|)  ->  x."""
    magnitude = torch.expm1(torch.abs(y))
    return torch.sign(y) * magnitude
14
+
15
+ RELU = nn.ReLU()
16
+
17
def mag_phase_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True, addeps=False):
    """
    Compute magnitude and phase using STFT.

    Args:
        y (torch.Tensor): Input audio signal.
        n_fft (int): FFT size.
        hop_size (int): Hop size.
        win_size (int): Window size (Hann).
        compress_factor (float or str, optional): Magnitude compression. A
            float applies mag ** compress_factor; the strings 'log1p',
            'relu_log1p' and 'signed_log1p' all apply log1p(mag) here (they
            differ only in how mag_phase_istft decompresses). Defaults to 1.0.
        center (bool, optional): Whether to center-pad the signal. Defaults to True.
        addeps (bool, optional): Whether to add a small epsilon when computing
            magnitude and phase (for numerical stability). Defaults to False.
            (Previously documented under the wrong name 'eps'.)

    Returns:
        tuple: (mag, pha, com) — magnitude, phase, and the compressed complex
        representation stacked as (real, imag) on the last axis.
    """
    eps = 1e-10
    hann_window = torch.hann_window(win_size).to(y.device)
    stft_spec = torch.stft(
        y, n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode='reflect',
        normalized=False,
        return_complex=True)

    if not addeps:
        mag = torch.abs(stft_spec)
        pha = torch.angle(stft_spec)
    else:
        # Epsilon keeps the gradients of sqrt/atan2 finite at zero.
        real_part = stft_spec.real
        imag_part = stft_spec.imag
        mag = torch.sqrt(real_part.pow(2) + imag_part.pow(2) + eps)
        pha = torch.atan2(imag_part + eps, real_part + eps)

    # Compress the magnitude.
    if compress_factor in ('log1p', 'relu_log1p', 'signed_log1p'):
        mag = torch.log1p(mag)
    else:
        mag = torch.pow(mag, compress_factor)

    com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
    return mag, pha, com
60
+
61
+
62
def mag_phase_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
    """
    Reconstruct a waveform from (possibly compressed) magnitude and phase.

    Args:
        mag (torch.Tensor): Compressed magnitude of the STFT.
        pha (torch.Tensor): Phase of the STFT.
        n_fft (int): FFT size.
        hop_size (int): Hop size.
        win_size (int): Window size (Hann).
        compress_factor (float or str, optional): Must match the value given
            to mag_phase_stft; selects the matching decompression. Defaults to 1.0.
        center (bool, optional): Whether the STFT was center-padded. Defaults to True.

    Returns:
        torch.Tensor: Reconstructed audio signal.
    """
    # Undo the magnitude compression applied by mag_phase_stft.
    if compress_factor == 'log1p':
        mag = torch.expm1(mag)
    elif compress_factor == 'signed_log1p':
        mag = decompress_signed_log1p(mag)
    elif compress_factor == 'relu_log1p':
        mag = torch.expm1(RELU(mag))
    else:
        # Numeric exponent: clamp negatives to zero before the inverse power.
        mag = torch.pow(RELU(mag), 1.0 / compress_factor)

    real = mag * torch.cos(pha)
    imag = mag * torch.sin(pha)
    spec = torch.complex(real, imag)

    window = torch.hann_window(win_size).to(spec.device)
    return torch.istft(
        spec,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=window,
        center=center)
recipes/USEMamba_30x1_lr_00002_norm_05_vq_065_nfft_320_hop_40_NRIR_012_pha_0005_com_04_early_001.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environment Settings
2
+ # These settings specify the hardware and distributed setup for the model training.
3
+ # Adjust `num_gpus` and `dist_config` according to your distributed training environment.
4
+ env_setting:
5
+ num_gpus: 8 # Number of GPUs. Now we don't support CPU mode.
6
+ num_workers: 20 # 0 Number of worker threads for data loading.
7
+ persistent_workers: True # False If you have large RAM, turn this to be True
8
+ prefetch_factor: 8 # null
9
+ seed: 1234 # Seed for random number generators to ensure reproducibility.
10
+ stdout_interval: 5000
11
+ checkpoint_interval: 5000 # save model to ckpt every N steps
12
+ validation_interval: 5000
13
+ dist_cfg:
14
+ dist_backend: nccl # Distributed training backend, 'nccl' for NVIDIA GPUs.
15
+ dist_url: tcp://localhost:19478 # URL for initializing distributed training.
16
+ world_size: 1 # Total number of processes in the distributed training.
17
+ pin_memory: True # If you have large RAM, turn this to be True
18
+
19
+
20
+ # STFT Configuration
21
+ # Configuration for Short-Time Fourier Transform (STFT), crucial for audio processing models.
22
+ stft_cfg:
23
+ sampling_rate: 8000 # Audio sampling rate in Hz.
24
+ n_fft: 320 # FFT components for transforming audio signals.
25
+ hop_size: 40 # Samples between successive frames.
26
+ win_size: 320 # Window size used in FFT.
27
+ sfi: True # Sampling Frequency Independent
28
+
29
+ # Model Configuration
30
+ # Defines the architecture specifics of the model, including layer configurations and feature compression.
31
+ model_cfg:
32
+ hid_feature: 64 # Channels in dense layers.
33
+ compress_factor: relu_log1p # Compression factor applied to extracted features.
34
+ num_tfmamba: 30 # Number of Time-Frequency Mamba (TFMamba) blocks in the model.
35
+ d_state: 16 # Dimensionality of the state vector in Mamba blocks.
36
+ d_conv: 4 # Convolutional layer dimensionality within Mamba blocks.
37
+ expand: 4 # Expansion factor for the layers within the Mamba blocks.
38
+ norm_epsilon: 0.00001 # Numerical stability in normalization layers within the Mamba blocks.
39
+ beta: 2.0 # Hyperparameter for the Learnable Sigmoid function.
40
+ input_channel: 2 # Magnitude and Phase
41
+ output_channel: 1 # Single Channel Speech Enhancement
42
+ inner_mamba_nlayer: 1 # Number of layer of Mamba in Bidirectional Mamba
43
+ nonlinear: None # last activation function for the mag encoder. 'softplus' or 'relu'
44
+ mapping: True # Otherwise, this should be masking model
utils/util.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import yaml
10
+ import torch
11
+ import os
12
+ import shutil
13
+ import torch.nn.functional as F
14
+
15
def load_config(config_path):
    """Read a YAML configuration file and return its parsed contents."""
    with open(config_path, 'r') as fh:
        raw_text = fh.read()
    return yaml.safe_load(raw_text)
19
+
20
def pad_or_trim_to_match(reference: torch.Tensor, target: torch.Tensor, pad_value: float = 1e-6) -> torch.Tensor:
    """
    Make *target* the same length as *reference* along dim=1, either by
    trimming or by right-padding with *pad_value*. Padding goes through
    F.pad, which keeps autograd intact.
    """
    ref_len = reference.shape[1]
    tgt_len = target.shape[1]

    if tgt_len == ref_len:
        return target
    if tgt_len > ref_len:
        return target[:, :ref_len]

    # Constant-pad only on the right of dim=1.
    return F.pad(target, (0, ref_len - tgt_len), mode='constant', value=pad_value)