├── .DS_Store
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── DeepComplexCRN.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   ├── vcs.xml
│   └── webServers.xml
├── README.md
├── config.py
├── dataloader
│   └── THCHS30.py
├── debug.py
├── main.py
├── models
│   ├── DCCRN.py
│   ├── complexnn.py
│   ├── conv_stft.py
│   └── loss.py
└── utils
    ├── show.py
    └── synthesizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoint
2 | samples
3 | models/pretrained-models
4 | .DS_Store
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DCCRN
2 |
3 | Deep Complex Convolution Recurrent Network for Phase-Aware Speech Enhancement
4 |
5 | __Authors__: Yanxin Hu, Yun Liu, Shubo Lv, Mengtao Xing, Shimin Zhang, Yihui Fu, Jian Wu, Bihong Zhang, Lei Xie
6 |
7 | Paper: https://arxiv.org/abs/2008.00264
8 |
9 | Official Sample: https://huyanxin.github.io/DeepComplexCRN/
10 |
11 |
--------------------------------------------------------------------------------
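
Quick start: a hedged sketch of the two entry points defined in main.py (the
file and checkpoint paths are illustrative):

    from main import train, denoise

    # train DCCRN-E on the synthesized THCHS-30 set; checkpoints go to ./checkpoint
    train('E')

    # enhance a single 16 kHz wav with a saved checkpoint
    denoise('E', 'noisy.wav', save_dir='samples', pth='checkpoint/DCCRN_E.pth')
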
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 |
4 | import torch
5 |
6 |
7 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
8 |
9 |
10 | class DefaultConfig(object):
11 | project_root = '/data1/zengziyun/Project/DeepComplexCRN'
12 |     data_root = '/data1/zengziyun/Project/Dataset'
13 | checkpoint_root = os.path.join(project_root, 'checkpoint')
14 | sample_root = os.path.join(project_root, 'samples')
15 | pretrained_models_root = os.path.join(project_root, 'models/pretrained-models')
16 |
17 |     use_gpu = torch.cuda.is_available()
18 | device = torch.device('cuda' if use_gpu else 'cpu')
19 | num_workers = 4
20 |
21 | # train params
22 | batch_size = 16
23 | max_epoch = 40
24 | lr = 1e-3
25 | lr_decay = 0.1
26 | weight_decay = 1e-5
27 |
28 | verbose_inter = 20
29 | save_inter = 5
30 |
31 | def _parse(self, kwargs):
32 | """
33 | update config params according to kwargs
34 | """
35 | for k, v in kwargs.items():
36 | if not hasattr(self, k):
37 | warnings.warn("Warning: opt does not have attribute %s" % k)
38 | setattr(self, k, v)
39 |
40 |         self.device = torch.device('cuda' if self.use_gpu else 'cpu')
41 |
42 | print('<===================current config===================>')
43 | for k, v in self.__class__.__dict__.items():
44 | if not k.startswith('_'):
45 | print(k, '=', getattr(self, k))
46 | print('<===================current config===================>')
47 |
48 |
49 | opt = DefaultConfig()
50 |
--------------------------------------------------------------------------------
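
Runtime overrides go through DefaultConfig._parse; a minimal sketch (the
override values are illustrative):

    from config import opt

    # unknown keys only raise a warning; the resulting config is printed afterwards
    opt._parse({'batch_size': 8, 'max_epoch': 20})
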
/dataloader/THCHS30.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import librosa
4 | import fnmatch
5 |
6 | import torch
7 | from torch.utils.data import Dataset
8 |
9 | from config import opt
10 |
11 | DATAPATH = os.path.join(opt.data_root, 'THCHS-30')
12 |
13 |
14 | class THCHS30(Dataset):
15 | def __init__(self, phase='train', sr=16000, dimension=72000):
16 | assert phase in ['train', 'test'], 'non-supported phase!'
17 |
18 | self.data_dir = None
19 | self.label_dir = None
20 |
21 | if phase == 'train':
22 | self.data_dir = os.path.join(DATAPATH, 'data_synthesized/train/noisy')
23 | self.label_dir = os.path.join(DATAPATH, 'data_synthesized/train/clean')
24 | elif phase == 'test':
25 | self.data_dir = os.path.join(DATAPATH, 'data_synthesized/test/noisy')
26 | self.label_dir = os.path.join(DATAPATH, 'data_synthesized/test/clean')
27 |
28 | self.sr = sr
29 | self.dim = dimension
30 |
31 | # use mapper in __getitem__
32 | # ensure each data find corresponding label
33 | self.mapper = {}
34 |
35 | # get label
36 | self.label_path = []
37 | for file in os.listdir(self.label_dir):
38 | if file.endswith('.wav'):
39 | self.mapper[file[:-4]] = len(self.label_path)
40 | self.label_path.append(os.path.join(self.label_dir, file))
41 |
42 | # get data path
43 | self.data_path = []
44 | for file in os.listdir(self.data_dir):
45 | if file.endswith('.wav'):
46 | self.data_path.append(os.path.join(self.data_dir, file))
47 |
48 | assert len(self.data_path) == len(self.label_path), 'data or label is corrupted!'
49 |
50 | def __getitem__(self, item):
51 | data, _ = librosa.load(self.data_path[item], sr=self.sr)
52 | data_name = os.path.basename(self.data_path[item])
53 | data_name = data_name[:data_name.rfind('_')]
54 | label, _ = librosa.load(self.label_path[self.mapper[data_name]], sr=self.sr)
55 | # 取 帧
56 | if len(data) > self.dim:
57 | max_audio_start = len(data) - self.dim
58 | audio_start = np.random.randint(0, max_audio_start)
59 | data = data[audio_start: audio_start + self.dim]
60 | label = label[audio_start:audio_start + self.dim]
61 | else:
62 | data = np.pad(data, (0, self.dim - len(data)), "constant")
63 | label = np.pad(label, (0, self.dim - len(label)), "constant")
64 |
65 | return data, label
66 |
67 | def __len__(self):
68 | return len(self.data_path)
69 |
70 |
71 | if __name__ == '__main__':
72 | ds = THCHS30(phase='train')
73 | min_dim = 1e8
74 | max_dim = 0
75 | for i in range(0, len(ds)):
76 | data, label = ds[i]
77 | min_dim = min(min_dim, len(data))
78 | max_dim = max(max_dim, len(data))
79 | print('min dim=', min_dim)
80 | print('max dim=', max_dim)
81 | print('mid dim=', int((min_dim + max_dim) / 2))
82 | pass
83 |
--------------------------------------------------------------------------------
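
For reference, the dataset plugs straight into a DataLoader; a minimal sketch
(the batch size is illustrative):

    from torch.utils.data import DataLoader
    from dataloader.THCHS30 import THCHS30

    ds = THCHS30(phase='train')                 # fixed 72000-sample clips at 16 kHz
    loader = DataLoader(ds, batch_size=4, shuffle=True)
    noisy, clean = next(iter(loader))           # float tensors, each [4, 72000]
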
/debug.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.DCCRN import dccrn
3 | from models.loss import SISNRLoss
4 | if __name__ == '__main__':
5 | torch.manual_seed(10)
6 | torch.autograd.set_detect_anomaly(True)
7 | inputs = torch.randn([10, 16000 * 4]).clamp_(-1, 1)
8 | labels = torch.randn([10, 16000 * 4]).clamp_(-1, 1)
9 |
10 |     print(inputs.shape)
11 |     print(labels.shape)
12 |
13 | # DCCRN-E
14 | # model = dccrn('E')
15 | # DCCRN-R
16 | # model = dccrn('R')
17 | # DCCRN-C
18 | # model = dccrn('C')
19 | # DCCRN-CL
20 | model = dccrn('CL')
21 |
22 | outputs = model(inputs)[1]
23 |     loss = SISNRLoss()(outputs, labels)  # SI-SNR loss from models/loss.py
24 | print(loss)
25 |
--------------------------------------------------------------------------------
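
The four variants differ only in masking mode and recurrent core: 'E', 'R' and
'C' pick the mask type with a real-valued LSTM, while 'CL' combines the 'E'
mask with the naive complex LSTM and a wider kernel_num (see dccrn() in
models/DCCRN.py). Since each constructor already prints its structure and
parameter count via utils/show.py, a quick size comparison is just:

    from models.DCCRN import dccrn

    for mode in ['E', 'R', 'C', 'CL']:
        dccrn(mode)   # show_model/show_params print layers and parameter size
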
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import librosa
3 | import soundfile as sf
4 | import time
5 |
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from torch.optim import Adam
9 | from torch.optim.lr_scheduler import MultiStepLR
10 | from torchnet.meter import AverageValueMeter
11 |
12 | from models.DCCRN import dccrn
13 | from models.loss import SISNRLoss
14 |
15 | from dataloader.THCHS30 import THCHS30
16 | from config import opt
17 |
18 |
19 | def train(mode='CL'):
20 | model = dccrn(mode)
21 | model.to(opt.device)
22 |
23 | train_data = THCHS30(phase='train')
24 | train_loader = DataLoader(train_data,
25 | batch_size=opt.batch_size,
26 | num_workers=opt.num_workers,
27 | shuffle=True)
28 |
29 |     optimizer = Adam(model.get_params(opt.weight_decay), lr=opt.lr)
30 | scheduler = MultiStepLR(optimizer,
31 | milestones=[int(opt.max_epoch * 0.5),
32 | int(opt.max_epoch * 0.7),
33 | int(opt.max_epoch * 0.9)],
34 | gamma=opt.lr_decay)
35 | criterion = SISNRLoss()
36 |
37 | loss_meter = AverageValueMeter()
38 |
39 | for epoch in range(0, opt.max_epoch):
40 | loss_meter.reset()
41 | for i, (data, label) in enumerate(train_loader):
42 | data = data.to(opt.device)
43 | label = label.to(opt.device)
44 |
45 | spec, wav = model(data)
46 |
47 | optimizer.zero_grad()
48 | loss = criterion(wav, label)
49 | loss.backward()
50 | optimizer.step()
51 |
52 | loss_meter.add(loss.item())
53 |
54 | if (i + 1) % opt.verbose_inter == 0:
55 | print('epoch', epoch + 1, 'batch', i + 1,
56 | 'SI-SNR', -loss_meter.value()[0])
57 | if (epoch + 1) % opt.save_inter == 0:
58 | print('save model at epoch {0} ...'.format(epoch + 1))
59 | save_path = os.path.join(opt.checkpoint_root,
60 | 'DCCRN_{0}_{1}.pth'.format(mode, epoch + 1))
61 | torch.save(model.state_dict(), save_path)
62 |
63 | scheduler.step()
64 |
65 | save_path = os.path.join(opt.checkpoint_root,
66 | 'DCCRN_{0}.pth'.format(mode))
67 | torch.save(model.state_dict(), save_path)
68 |
69 |
70 | # when denoising, use cpu
71 | def denoise(mode, speech_file, save_dir, pth=None):
72 | assert os.path.exists(speech_file), 'speech file does not exist!'
73 |
74 | assert speech_file.endswith('.wav'), 'non-supported speech format!'
75 |
76 | if not os.path.exists(save_dir):
77 | print('warning: save directory does not exist, it will be created automatically!')
78 | os.makedirs(save_dir)
79 |
80 | model = dccrn(mode)
81 | if pth is not None:
82 |         model.load_state_dict(torch.load(pth, map_location='cpu'), strict=True)
83 |
84 | noisy_wav, _ = librosa.load(speech_file, sr=16000)
85 |
86 | noisy_wav = torch.Tensor(noisy_wav).reshape(1, -1)
87 |
88 |     model.eval()
89 |     start = time.time()
90 | 
91 |     with torch.no_grad():
92 |         _, denoised_wav = model(noisy_wav)
93 | 
94 |     end = time.time()
95 |
96 | print('process time {0}s on device {1}'.format(end - start, 'cpu'))
97 |
98 | speech_name = os.path.basename(speech_file)[:-4]
99 |
100 | noisy_path = os.path.join(save_dir, speech_name + '_' + 'noisy' + '.wav')
101 | denoised_path = os.path.join(save_dir, speech_name + '_' + 'denoised' + '.wav')
102 |
103 |     noisy_wav = noisy_wav.numpy().flatten()
104 |     denoised_wav = denoised_wav.numpy().flatten()
105 |
106 | sf.write(noisy_path, noisy_wav, 16000)
107 | sf.write(denoised_path, denoised_wav, 16000)
108 |
109 |
110 | if __name__ == '__main__':
111 | # train('E')
112 |
113 | test_speech_base = os.path.join(opt.data_root, 'THCHS-30', 'data_synthesized/test/noisy')
114 | test_speech = os.path.join(test_speech_base, 'D11_752_car.wav')
115 |
116 | save_dir = os.path.join(opt.sample_root, 'THCHS-30')
117 | pth = os.path.join(opt.checkpoint_root, 'DCCRN_E.pth')
118 |
119 | denoise('E', test_speech, save_dir, pth=pth)
120 |
121 | pass
122 |
--------------------------------------------------------------------------------
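
A hedged evaluation sketch: comparing SI-SNR of the noisy and denoised signals
against the clean reference, using si_snr from models/loss.py (the file names
are illustrative):

    import librosa
    import torch
    from models.loss import si_snr

    clean, _ = librosa.load('clean.wav', sr=16000)
    noisy, _ = librosa.load('noisy.wav', sr=16000)
    denoised, _ = librosa.load('denoised.wav', sr=16000)

    n = min(len(clean), len(noisy), len(denoised))
    t = lambda w: torch.from_numpy(w[:n]).unsqueeze(0)

    print('SI-SNR noisy   :', si_snr(t(noisy), t(clean)).item())
    print('SI-SNR denoised:', si_snr(t(denoised), t(clean)).item())
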
/models/DCCRN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils.show import show_params, show_model
4 | import torch.nn.functional as F
5 | from models.conv_stft import ConvSTFT, ConviSTFT
6 |
7 | from models.complexnn import ComplexConv2d, ComplexConvTranspose2d, NavieComplexLSTM, complex_cat, ComplexBatchNorm
8 |
9 |
10 | class DCCRN(nn.Module):
11 |
12 | def __init__(
13 | self,
14 | rnn_layers=2,
15 | rnn_units=128,
16 | win_len=400,
17 | win_inc=100,
18 | fft_len=512,
19 | win_type='hanning',
20 | masking_mode='E',
21 | use_clstm=False,
22 | use_cbn=False,
23 | kernel_size=5,
24 | kernel_num=[16, 32, 64, 128, 256, 256]
25 | ):
26 | '''
27 |
28 | rnn_layers: the number of lstm layers in the crn,
29 | rnn_units: for clstm, rnn_units = real+imag
30 |
31 | '''
32 |
33 | super(DCCRN, self).__init__()
34 |
35 | # for fft
36 | self.win_len = win_len
37 | self.win_inc = win_inc
38 | self.fft_len = fft_len
39 | self.win_type = win_type
40 |
41 | input_dim = win_len
42 | output_dim = win_len
43 |
44 | self.rnn_units = rnn_units
45 | self.input_dim = input_dim
46 | self.output_dim = output_dim
47 | self.hidden_layers = rnn_layers
48 | self.kernel_size = kernel_size
49 | # self.kernel_num = [2, 8, 16, 32, 128, 128, 128]
50 | # self.kernel_num = [2, 16, 32, 64, 128, 256, 256]
51 | self.kernel_num = [2] + kernel_num
52 | self.masking_mode = masking_mode
53 | self.use_clstm = use_clstm
54 |
55 | # bidirectional=True
56 | bidirectional = False
57 | fac = 2 if bidirectional else 1
58 |
59 | fix = True
60 | self.fix = fix
61 | self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix)
62 | self.istft = ConviSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix)
63 |
64 | self.encoder = nn.ModuleList()
65 | self.decoder = nn.ModuleList()
66 | for idx in range(len(self.kernel_num) - 1):
67 | self.encoder.append(
68 | nn.Sequential(
69 | # nn.ConstantPad2d([0, 0, 0, 0], 0),
70 | ComplexConv2d(
71 | self.kernel_num[idx],
72 | self.kernel_num[idx + 1],
73 | kernel_size=(self.kernel_size, 2),
74 | stride=(2, 1),
75 | padding=(2, 1)
76 | ),
77 | nn.BatchNorm2d(self.kernel_num[idx + 1]) if not use_cbn else ComplexBatchNorm(
78 | self.kernel_num[idx + 1]),
79 | nn.PReLU()
80 | )
81 | )
82 | hidden_dim = self.fft_len // (2 ** (len(self.kernel_num)))
83 |
84 | if self.use_clstm:
85 | rnns = []
86 | for idx in range(rnn_layers):
87 | rnns.append(
88 | NavieComplexLSTM(
89 | input_size=hidden_dim * self.kernel_num[-1] if idx == 0 else self.rnn_units,
90 | hidden_size=self.rnn_units,
91 | bidirectional=bidirectional,
92 | batch_first=False,
93 | projection_dim=hidden_dim * self.kernel_num[-1] if idx == rnn_layers - 1 else None,
94 | )
95 | )
96 | self.enhance = nn.Sequential(*rnns)
97 | else:
98 | self.enhance = nn.LSTM(
99 | input_size=hidden_dim * self.kernel_num[-1],
100 | hidden_size=self.rnn_units,
101 | num_layers=2,
102 | dropout=0.0,
103 | bidirectional=bidirectional,
104 | batch_first=False
105 | )
106 | self.tranform = nn.Linear(self.rnn_units * fac, hidden_dim * self.kernel_num[-1])
107 |
108 | for idx in range(len(self.kernel_num) - 1, 0, -1):
109 | if idx != 1:
110 | self.decoder.append(
111 | nn.Sequential(
112 | ComplexConvTranspose2d(
113 | self.kernel_num[idx] * 2,
114 | self.kernel_num[idx - 1],
115 | kernel_size=(self.kernel_size, 2),
116 | stride=(2, 1),
117 | padding=(2, 0),
118 | output_padding=(1, 0)
119 | ),
120 | nn.BatchNorm2d(self.kernel_num[idx - 1]) if not use_cbn else ComplexBatchNorm(
121 | self.kernel_num[idx - 1]),
122 | # nn.ELU()
123 | nn.PReLU()
124 | )
125 | )
126 | else:
127 | self.decoder.append(
128 | nn.Sequential(
129 | ComplexConvTranspose2d(
130 | self.kernel_num[idx] * 2,
131 | self.kernel_num[idx - 1],
132 | kernel_size=(self.kernel_size, 2),
133 | stride=(2, 1),
134 | padding=(2, 0),
135 | output_padding=(1, 0)
136 | ),
137 | )
138 | )
139 |
140 | show_model(self)
141 | show_params(self)
142 | self.flatten_parameters()
143 |
144 | def flatten_parameters(self):
145 | if isinstance(self.enhance, nn.LSTM):
146 | self.enhance.flatten_parameters()
147 |
148 | def forward(self, inputs, lens=None):
149 | specs = self.stft(inputs)
150 | real = specs[:, :self.fft_len // 2 + 1]
151 | imag = specs[:, self.fft_len // 2 + 1:]
152 |         # noisy magnitude/phase, reused later by the 'E' masking branch
153 |         spec_mags = torch.sqrt(real ** 2 + imag ** 2 + 1e-8)
154 |         spec_phase = torch.atan2(imag, real)
155 |         cspecs = torch.stack([real, imag], 1)
156 |         # drop the DC bin; the mask re-pads it before being applied
157 |         cspecs = cspecs[:, :, 1:]
158 | '''
159 | means = torch.mean(cspecs, [1,2,3], keepdim=True)
160 | std = torch.std(cspecs, [1,2,3], keepdim=True )
161 | normed_cspecs = (cspecs-means)/(std+1e-8)
162 | out = normed_cspecs
163 | '''
164 |
165 | out = cspecs
166 | encoder_out = []
167 |
168 | for idx, layer in enumerate(self.encoder):
169 | out = layer(out)
170 | # print('encoder', out.size())
171 | encoder_out.append(out)
172 |
173 | batch_size, channels, dims, lengths = out.size()
174 | out = out.permute(3, 0, 1, 2)
175 | if self.use_clstm:
176 | r_rnn_in = out[:, :, :channels // 2]
177 | i_rnn_in = out[:, :, channels // 2:]
178 | r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2 * dims])
179 | i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, channels // 2 * dims])
180 |
181 | r_rnn_in, i_rnn_in = self.enhance([r_rnn_in, i_rnn_in])
182 |
183 | r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2, dims])
184 | i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, channels // 2, dims])
185 | out = torch.cat([r_rnn_in, i_rnn_in], 2)
186 |
187 | else:
188 | # to [L, B, C, D]
189 | out = torch.reshape(out, [lengths, batch_size, channels * dims])
190 | out, _ = self.enhance(out)
191 | out = self.tranform(out)
192 | out = torch.reshape(out, [lengths, batch_size, channels, dims])
193 |
194 | out = out.permute(1, 2, 3, 0)
195 |
196 | for idx in range(len(self.decoder)):
197 | out = complex_cat([out, encoder_out[-1 - idx]], 1)
198 | out = self.decoder[idx](out)
199 | out = out[..., 1:]
200 | # print('decoder', out.size())
201 | mask_real = out[:, 0]
202 | mask_imag = out[:, 1]
203 | mask_real = F.pad(mask_real, [0, 0, 1, 0])
204 | mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
205 |
206 | if self.masking_mode == 'E':
207 | mask_mags = (mask_real ** 2 + mask_imag ** 2) ** 0.5
208 | real_phase = mask_real / (mask_mags + 1e-8)
209 | imag_phase = mask_imag / (mask_mags + 1e-8)
210 | mask_phase = torch.atan2(
211 | imag_phase,
212 | real_phase
213 | )
214 |
215 | # mask_mags = torch.clamp_(mask_mags,0,100)
216 | mask_mags = torch.tanh(mask_mags)
217 | est_mags = mask_mags * spec_mags
218 | est_phase = spec_phase + mask_phase
219 | real = est_mags * torch.cos(est_phase)
220 | imag = est_mags * torch.sin(est_phase)
221 | elif self.masking_mode == 'C':
222 | real, imag = real * mask_real - imag * mask_imag, real * mask_imag + imag * mask_real
223 | elif self.masking_mode == 'R':
224 | real, imag = real * mask_real, imag * mask_imag
225 |
226 | out_spec = torch.cat([real, imag], 1)
227 | out_wav = self.istft(out_spec)
228 |
229 | out_wav = torch.squeeze(out_wav, 1)
230 | # out_wav = torch.tanh(out_wav)
231 |         # clamp_ is in-place; keeps the output waveform within [-1, 1]
232 | out_wav = torch.clamp_(out_wav, -1, 1)
233 | return out_spec, out_wav
234 |
235 | def get_params(self, weight_decay=0.0):
236 | # add L2 penalty
237 | weights, biases = [], []
238 | for name, param in self.named_parameters():
239 | if 'bias' in name:
240 | biases += [param]
241 | else:
242 | weights += [param]
243 | params = [{
244 | 'params': weights,
245 | 'weight_decay': weight_decay,
246 | }, {
247 | 'params': biases,
248 | 'weight_decay': 0.0,
249 | }]
250 | return params
251 |
252 |
253 | def dccrn(mode='CL'):
254 | if mode == 'E':
255 | model = DCCRN(rnn_units=256, masking_mode='E')
256 | elif mode == 'R':
257 | model = DCCRN(rnn_units=256, masking_mode='R')
258 | elif mode == 'C':
259 | model = DCCRN(rnn_units=256, masking_mode='C')
260 | elif mode == 'CL':
261 | model = DCCRN(rnn_units=256, masking_mode='E',
262 | use_clstm=True, kernel_num=[32, 64, 128, 256, 256, 256])
263 | else:
264 | raise Exception('non-supported mode!')
265 | return model
266 |
--------------------------------------------------------------------------------
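
For orientation, a shape walkthrough of the default configuration (win 400 /
hop 100 / FFT 512): the STFT yields 257 real plus 257 imaginary bins, the DC
bin is dropped, and six complex conv layers halve the 256 frequency bins down
to 4 before the (C)LSTM runs along time. A small smoke test:

    import torch
    from models.DCCRN import dccrn

    model = dccrn('E').eval()
    x = torch.randn(2, 16000 * 4).clamp_(-1, 1)   # two 4-second clips
    with torch.no_grad():
        spec, wav = model(x)
    print(spec.shape)   # [2, 514, 643]: real+imag halves of the masked spectrum
    print(wav.shape)    # [2, 64000]: reconstructed waveform, same length as input
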
/models/complexnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
7 | def get_causal_padding1d():
8 | pass
9 |
10 |
11 | def get_causal_padding2d():
12 | pass
13 |
14 |
15 | class cPReLU(nn.Module):
16 |
17 | def __init__(self, complex_axis=1):
18 | super(cPReLU, self).__init__()
19 | self.r_prelu = nn.PReLU()
20 | self.i_prelu = nn.PReLU()
21 | self.complex_axis = complex_axis
22 |
23 | def forward(self, inputs):
24 | real, imag = torch.chunk(inputs, 2, self.complex_axis)
25 | real = self.r_prelu(real)
26 | imag = self.i_prelu(imag)
27 | return torch.cat([real, imag], self.complex_axis)
28 |
29 |
30 | class NavieComplexLSTM(nn.Module):
31 | def __init__(self, input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False):
32 | super(NavieComplexLSTM, self).__init__()
33 |
34 | self.input_dim = input_size // 2
35 | self.rnn_units = hidden_size // 2
36 | self.real_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
37 | batch_first=False)
38 | self.imag_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
39 | batch_first=False)
40 | if bidirectional:
41 | bidirectional = 2
42 | else:
43 | bidirectional = 1
44 | if projection_dim is not None:
45 | self.projection_dim = projection_dim // 2
46 | self.r_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
47 | self.i_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
48 | else:
49 | self.projection_dim = None
50 |
51 | def forward(self, inputs):
52 | if isinstance(inputs, list):
53 | real, imag = inputs
54 | elif isinstance(inputs, torch.Tensor):
55 |             real, imag = torch.chunk(inputs, 2, -1)
56 | r2r_out = self.real_lstm(real)[0]
57 | r2i_out = self.imag_lstm(real)[0]
58 | i2r_out = self.real_lstm(imag)[0]
59 | i2i_out = self.imag_lstm(imag)[0]
60 | real_out = r2r_out - i2i_out
61 | imag_out = i2r_out + r2i_out
62 | if self.projection_dim is not None:
63 | real_out = self.r_trans(real_out)
64 | imag_out = self.i_trans(imag_out)
65 | # print(real_out.shape,imag_out.shape)
66 | return [real_out, imag_out]
67 |
68 | def flatten_parameters(self):
69 | self.imag_lstm.flatten_parameters()
70 | self.real_lstm.flatten_parameters()
71 |
72 |
73 | def complex_cat(inputs, axis):
74 | real, imag = [], []
75 | for idx, data in enumerate(inputs):
76 | r, i = torch.chunk(data, 2, axis)
77 | real.append(r)
78 | imag.append(i)
79 | real = torch.cat(real, axis)
80 | imag = torch.cat(imag, axis)
81 | outputs = torch.cat([real, imag], axis)
82 | return outputs
83 |
84 |
85 | class ComplexConv2d(nn.Module):
86 |
87 | def __init__(
88 | self,
89 | in_channels,
90 | out_channels,
91 | kernel_size=(1, 1),
92 | stride=(1, 1),
93 | padding=(0, 0),
94 | dilation=1,
95 | groups=1,
96 | causal=True,
97 | complex_axis=1,
98 | ):
99 | '''
100 | in_channels: real+imag
101 | out_channels: real+imag
102 | kernel_size : input [B,C,D,T] kernel size in [D,T]
103 | padding : input [B,C,D,T] padding in [D,T]
104 | causal: if causal, will padding time dimension's left side,
105 | otherwise both
106 |
107 | '''
108 | super(ComplexConv2d, self).__init__()
109 | self.in_channels = in_channels // 2
110 | self.out_channels = out_channels // 2
111 | self.kernel_size = kernel_size
112 | self.stride = stride
113 | self.padding = padding
114 | self.causal = causal
115 | self.groups = groups
116 | self.dilation = dilation
117 | self.complex_axis = complex_axis
118 | self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
119 | padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
120 | self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
121 | padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
122 |
123 | nn.init.normal_(self.real_conv.weight.data, std=0.05)
124 | nn.init.normal_(self.imag_conv.weight.data, std=0.05)
125 | nn.init.constant_(self.real_conv.bias, 0.)
126 | nn.init.constant_(self.imag_conv.bias, 0.)
127 |
128 | def forward(self, inputs):
129 | if self.padding[1] != 0 and self.causal:
130 | inputs = F.pad(inputs, [self.padding[1], 0, 0, 0])
131 | else:
132 | inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0])
133 |
134 | if self.complex_axis == 0:
135 | real = self.real_conv(inputs)
136 | imag = self.imag_conv(inputs)
137 | real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
138 | real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
139 |
140 | else:
141 | if isinstance(inputs, torch.Tensor):
142 | real, imag = torch.chunk(inputs, 2, self.complex_axis)
143 |
144 | real2real = self.real_conv(real, )
145 | imag2imag = self.imag_conv(imag, )
146 |
147 | real2imag = self.imag_conv(real)
148 | imag2real = self.real_conv(imag)
149 |
150 | real = real2real - imag2imag
151 | imag = real2imag + imag2real
152 | out = torch.cat([real, imag], self.complex_axis)
153 |
154 | return out
155 |
156 |
157 | class ComplexConvTranspose2d(nn.Module):
158 |
159 | def __init__(
160 | self,
161 | in_channels,
162 | out_channels,
163 | kernel_size=(1, 1),
164 | stride=(1, 1),
165 | padding=(0, 0),
166 | output_padding=(0, 0),
167 | causal=False,
168 | complex_axis=1,
169 | groups=1
170 | ):
171 | '''
172 | in_channels: real+imag
173 | out_channels: real+imag
174 | '''
175 | super(ComplexConvTranspose2d, self).__init__()
176 | self.in_channels = in_channels // 2
177 | self.out_channels = out_channels // 2
178 | self.kernel_size = kernel_size
179 | self.stride = stride
180 | self.padding = padding
181 | self.output_padding = output_padding
182 | self.groups = groups
183 |
184 | self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
185 | padding=self.padding, output_padding=output_padding, groups=self.groups)
186 | self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
187 | padding=self.padding, output_padding=output_padding, groups=self.groups)
188 | self.complex_axis = complex_axis
189 |
190 | nn.init.normal_(self.real_conv.weight, std=0.05)
191 | nn.init.normal_(self.imag_conv.weight, std=0.05)
192 | nn.init.constant_(self.real_conv.bias, 0.)
193 | nn.init.constant_(self.imag_conv.bias, 0.)
194 |
195 | def forward(self, inputs):
196 |
197 | if isinstance(inputs, torch.Tensor):
198 | real, imag = torch.chunk(inputs, 2, self.complex_axis)
199 | elif isinstance(inputs, tuple) or isinstance(inputs, list):
200 | real = inputs[0]
201 | imag = inputs[1]
202 | if self.complex_axis == 0:
203 | real = self.real_conv(inputs)
204 | imag = self.imag_conv(inputs)
205 | real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
206 | real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
207 |
208 | else:
209 | if isinstance(inputs, torch.Tensor):
210 | real, imag = torch.chunk(inputs, 2, self.complex_axis)
211 |
212 | real2real = self.real_conv(real, )
213 | imag2imag = self.imag_conv(imag, )
214 |
215 | real2imag = self.imag_conv(real)
216 | imag2real = self.real_conv(imag)
217 |
218 | real = real2real - imag2imag
219 | imag = real2imag + imag2real
220 | out = torch.cat([real, imag], self.complex_axis)
221 |
222 | return out
223 |
224 |
225 | # Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch
226 | # from https://github.com/IMLHF/SE_DCUNet/blob/f28bf1661121c8901ad38149ea827693f1830715/models/layers/complexnn.py#L55
227 |
228 | class ComplexBatchNorm(torch.nn.Module):
229 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
230 | track_running_stats=True, complex_axis=1):
231 | super(ComplexBatchNorm, self).__init__()
232 | self.num_features = num_features // 2
233 | self.eps = eps
234 | self.momentum = momentum
235 | self.affine = affine
236 | self.track_running_stats = track_running_stats
237 |
238 | self.complex_axis = complex_axis
239 |
240 | if self.affine:
241 | self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features))
242 | self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features))
243 | self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features))
244 | self.Br = torch.nn.Parameter(torch.Tensor(self.num_features))
245 | self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features))
246 | else:
247 | self.register_parameter('Wrr', None)
248 | self.register_parameter('Wri', None)
249 | self.register_parameter('Wii', None)
250 | self.register_parameter('Br', None)
251 | self.register_parameter('Bi', None)
252 |
253 | if self.track_running_stats:
254 | self.register_buffer('RMr', torch.zeros(self.num_features))
255 | self.register_buffer('RMi', torch.zeros(self.num_features))
256 | self.register_buffer('RVrr', torch.ones(self.num_features))
257 | self.register_buffer('RVri', torch.zeros(self.num_features))
258 | self.register_buffer('RVii', torch.ones(self.num_features))
259 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
260 | else:
261 | self.register_parameter('RMr', None)
262 | self.register_parameter('RMi', None)
263 | self.register_parameter('RVrr', None)
264 | self.register_parameter('RVri', None)
265 | self.register_parameter('RVii', None)
266 | self.register_parameter('num_batches_tracked', None)
267 | self.reset_parameters()
268 |
269 | def reset_running_stats(self):
270 | if self.track_running_stats:
271 | self.RMr.zero_()
272 | self.RMi.zero_()
273 | self.RVrr.fill_(1)
274 | self.RVri.zero_()
275 | self.RVii.fill_(1)
276 | self.num_batches_tracked.zero_()
277 |
278 | def reset_parameters(self):
279 | self.reset_running_stats()
280 | if self.affine:
281 | self.Br.data.zero_()
282 | self.Bi.data.zero_()
283 | self.Wrr.data.fill_(1)
284 | self.Wri.data.uniform_(-.9, +.9) # W will be positive-definite
285 | self.Wii.data.fill_(1)
286 |
287 | def _check_input_dim(self, xr, xi):
288 | assert (xr.shape == xi.shape)
289 | assert (xr.size(1) == self.num_features)
290 |
291 | def forward(self, inputs):
292 | # self._check_input_dim(xr, xi)
293 |
294 |         xr, xi = torch.chunk(inputs, 2, self.complex_axis)
295 | exponential_average_factor = 0.0
296 |
297 | if self.training and self.track_running_stats:
298 | self.num_batches_tracked += 1
299 | if self.momentum is None: # use cumulative moving average
300 | exponential_average_factor = 1.0 / self.num_batches_tracked.item()
301 | else: # use exponential moving average
302 | exponential_average_factor = self.momentum
303 |
304 | #
305 | # NOTE: The precise meaning of the "training flag" is:
306 | # True: Normalize using batch statistics, update running statistics
307 | # if they are being collected.
308 | # False: Normalize using running statistics, ignore batch statistics.
309 | #
310 | training = self.training or not self.track_running_stats
311 | redux = [i for i in reversed(range(xr.dim())) if i != 1]
312 | vdim = [1] * xr.dim()
313 | vdim[1] = xr.size(1)
314 |
315 | #
316 | # Mean M Computation and Centering
317 | #
318 | # Includes running mean update if training and running.
319 | #
320 | if training:
321 | Mr, Mi = xr, xi
322 | for d in redux:
323 | Mr = Mr.mean(d, keepdim=True)
324 | Mi = Mi.mean(d, keepdim=True)
325 | if self.track_running_stats:
326 | self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
327 | self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
328 | else:
329 | Mr = self.RMr.view(vdim)
330 | Mi = self.RMi.view(vdim)
331 | xr, xi = xr - Mr, xi - Mi
332 |
333 | #
334 | # Variance Matrix V Computation
335 | #
336 | # Includes epsilon numerical stabilizer/Tikhonov regularizer.
337 | # Includes running variance update if training and running.
338 | #
339 | if training:
340 | Vrr = xr * xr
341 | Vri = xr * xi
342 | Vii = xi * xi
343 | for d in redux:
344 | Vrr = Vrr.mean(d, keepdim=True)
345 | Vri = Vri.mean(d, keepdim=True)
346 | Vii = Vii.mean(d, keepdim=True)
347 | if self.track_running_stats:
348 | self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
349 | self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
350 | self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
351 | else:
352 | Vrr = self.RVrr.view(vdim)
353 | Vri = self.RVri.view(vdim)
354 | Vii = self.RVii.view(vdim)
355 | Vrr = Vrr + self.eps
356 | Vri = Vri
357 | Vii = Vii + self.eps
358 |
359 | #
360 | # Matrix Inverse Square Root U = V^-0.5
361 | #
362 | # sqrt of a 2x2 matrix,
363 | # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
364 | tau = Vrr + Vii
365 |         delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)
366 | s = delta.sqrt()
367 | t = (tau + 2 * s).sqrt()
368 |
369 | # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
370 | rst = (s * t).reciprocal()
371 | Urr = (s + Vii) * rst
372 | Uii = (s + Vrr) * rst
373 | Uri = (- Vri) * rst
374 |
375 | #
376 | # Optionally left-multiply U by affine weights W to produce combined
377 | # weights Z, left-multiply the inputs by Z, then optionally bias them.
378 | #
379 | # y = Zx + B
380 | # y = WUx + B
381 | # y = [Wrr Wri][Urr Uri] [xr] + [Br]
382 | # [Wir Wii][Uir Uii] [xi] [Bi]
383 | #
384 | if self.affine:
385 | Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
386 | Zrr = (Wrr * Urr) + (Wri * Uri)
387 | Zri = (Wrr * Uri) + (Wri * Uii)
388 | Zir = (Wri * Urr) + (Wii * Uri)
389 | Zii = (Wri * Uri) + (Wii * Uii)
390 | else:
391 | Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
392 |
393 | yr = (Zrr * xr) + (Zri * xi)
394 | yi = (Zir * xr) + (Zii * xi)
395 |
396 | if self.affine:
397 | yr = yr + self.Br.view(vdim)
398 | yi = yi + self.Bi.view(vdim)
399 |
400 | outputs = torch.cat([yr, yi], self.complex_axis)
401 | return outputs
402 |
403 | def extra_repr(self):
404 | return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
405 | 'track_running_stats={track_running_stats}'.format(**self.__dict__)
406 |
407 |
420 | if __name__ == '__main__':
421 |     # smoke test: complex conv / transposed conv on a [B, 2C, D, T] input
422 |     torch.manual_seed(20)
423 |     inputs = torch.randn([1, 12, 12, 10])
424 |     nnet1 = ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1), causal=True)
425 |     nnet2 = ComplexConvTranspose2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
426 |     print('conv out:', nnet1(inputs).shape)
427 |     print('deconv out:', nnet2(inputs).shape)
428 | 
--------------------------------------------------------------------------------
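
The complex layers implement (Wr + iWi)(xr + ixi) = (Wr*xr - Wi*xi) +
i(Wi*xr + Wr*xi) with two real convolutions; a tiny numeric check of that
identity (shapes are illustrative):

    import torch
    from models.complexnn import ComplexConv2d

    torch.manual_seed(0)
    conv = ComplexConv2d(2, 2, kernel_size=(1, 1))
    x = torch.randn(1, 2, 4, 5)                   # channel axis packs [real | imag]
    xr, xi = torch.chunk(x, 2, 1)

    yr = conv.real_conv(xr) - conv.imag_conv(xi)  # real part
    yi = conv.imag_conv(xr) + conv.real_conv(xi)  # imaginary part
    print(torch.allclose(conv(x), torch.cat([yr, yi], 1)))   # True
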
/models/conv_stft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from scipy.signal import get_window
6 |
7 |
8 | def init_kernels(win_len, win_inc, fft_len, win_type=None, inverse=False):
9 | if win_type == 'None' or win_type is None:
10 | window = np.ones(win_len)
11 | else:
12 | window = get_window(win_type, win_len, fftbins=True) # **0.5
13 |
14 | N = fft_len
15 | fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
16 | real_kernel = np.real(fourier_basis)
17 | imag_kernel = np.imag(fourier_basis)
18 | kernel = np.concatenate([real_kernel, imag_kernel], 1).T
19 |
20 |     if inverse:
21 | kernel = np.linalg.pinv(kernel).T
22 |
23 | kernel = kernel * window
24 | kernel = kernel[:, None, :]
25 | return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32))
26 |
27 |
28 | class ConvSTFT(nn.Module):
29 |
30 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
31 | super(ConvSTFT, self).__init__()
32 |
33 |         if fft_len is None:
34 |             self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
35 | else:
36 | self.fft_len = fft_len
37 |
38 | kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
39 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
40 | self.register_buffer('weight', kernel)
41 | self.feature_type = feature_type
42 | self.stride = win_inc
43 | self.win_len = win_len
44 | self.dim = self.fft_len
45 |
46 | def forward(self, inputs):
47 | if inputs.dim() == 2:
48 | inputs = torch.unsqueeze(inputs, 1)
49 | inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride])
50 | outputs = F.conv1d(inputs, self.weight, stride=self.stride)
51 |
52 | if self.feature_type == 'complex':
53 | return outputs
54 | else:
55 | dim = self.dim // 2 + 1
56 | real = outputs[:, :dim, :]
57 | imag = outputs[:, dim:, :]
58 | mags = torch.sqrt(real ** 2 + imag ** 2)
59 | phase = torch.atan2(imag, real)
60 | return mags, phase
61 |
62 |
63 | class ConviSTFT(nn.Module):
64 |
65 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
66 | super(ConviSTFT, self).__init__()
67 |         if fft_len is None:
68 |             self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
69 | else:
70 | self.fft_len = fft_len
71 |         kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, inverse=True)
72 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
73 | self.register_buffer('weight', kernel)
74 | self.feature_type = feature_type
75 | self.win_type = win_type
76 | self.win_len = win_len
77 |         self.stride = win_inc
79 | self.dim = self.fft_len
80 | self.register_buffer('window', window)
81 | self.register_buffer('enframe', torch.eye(win_len)[:, None, :])
82 |
83 | def forward(self, inputs, phase=None):
84 | """
85 | inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
86 | phase: [B, N//2+1, T] (if not none)
87 | """
88 |
89 | if phase is not None:
90 | real = inputs * torch.cos(phase)
91 | imag = inputs * torch.sin(phase)
92 | inputs = torch.cat([real, imag], 1)
93 | outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
94 |
95 | # this is from torch-stft: https://github.com/pseeth/torch-stft
96 | t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
97 | coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
98 | outputs = outputs / (coff + 1e-8)
99 | # outputs = torch.where(coff == 0, outputs, outputs/coff)
100 | outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)]
101 |
102 | return outputs
103 |
104 |
105 | def test_fft():
106 | torch.manual_seed(20)
107 | win_len = 320
108 | win_inc = 160
109 | fft_len = 512
110 | inputs = torch.randn([1, 1, 16000 * 4])
111 | fft = ConvSTFT(win_len, win_inc, fft_len, win_type='hanning', feature_type='real')
112 | import librosa
113 |
114 | outputs1 = fft(inputs)[0]
115 | outputs1 = outputs1.numpy()[0]
116 | np_inputs = inputs.numpy().reshape([-1])
117 | librosa_stft = librosa.stft(np_inputs, win_length=win_len, n_fft=fft_len, hop_length=win_inc, center=False)
118 | print(np.mean((outputs1 - np.abs(librosa_stft)) ** 2))
119 |
120 |
121 | def test_ifft1():
122 | import soundfile as sf
123 | N = 400
124 | inc = 100
125 | fft_len = 512
126 | torch.manual_seed(N)
127 | data = np.random.randn(16000 * 8)[None, None, :]
128 | # data = sf.read('../ori.wav')[0]
129 | inputs = data.reshape([1, 1, -1])
130 | fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
131 | ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
132 | inputs = torch.from_numpy(inputs.astype(np.float32))
133 | outputs1 = fft(inputs)
134 | print(outputs1.shape)
135 | outputs2 = ifft(outputs1)
136 | sf.write('conv_stft.wav', outputs2.numpy()[0, 0, :], 16000)
137 | print('wav MSE', torch.mean(torch.abs(inputs[..., :outputs2.size(2)] - outputs2) ** 2))
138 |
139 |
140 | def test_ifft2():
141 | N = 400
142 | inc = 100
143 | fft_len = 512
144 | np.random.seed(20)
145 | torch.manual_seed(20)
146 | t = np.random.randn(16000 * 4) * 0.001
147 | t = np.clip(t, -1, 1)
148 | # input = torch.randn([1,16000*4])
149 | input = torch.from_numpy(t[None, None, :].astype(np.float32))
150 |
151 | fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
152 | ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
153 |
154 | out1 = fft(input)
155 | output = ifft(out1)
156 | print('random MSE', torch.mean(torch.abs(input - output) ** 2))
157 | import soundfile as sf
158 | sf.write('zero.wav', output[0, 0].numpy(), 16000)
159 |
160 |
161 | if __name__ == '__main__':
162 | # test_fft()
163 | test_ifft1()
164 | # test_ifft2()
165 |
--------------------------------------------------------------------------------
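
Why init_kernels works: the rows of np.fft.rfft(np.eye(fft_len)) are the DFT
basis vectors, so a strided conv1d with the windowed rows computes a hopped,
windowed DFT. A single-frame sanity check against numpy (differences are
float32 rounding):

    import numpy as np
    import torch
    import torch.nn.functional as F
    from scipy.signal import get_window
    from models.conv_stft import init_kernels

    kernel, _ = init_kernels(400, 100, 512, win_type='hanning')
    frame = torch.randn(1, 1, 400)
    out = F.conv1d(frame, kernel)                 # one STFT frame: [1, 514, 1]
    re, im = out[0, :257, 0].numpy(), out[0, 257:, 0].numpy()

    w = get_window('hanning', 400, fftbins=True)
    ref = np.fft.rfft(frame.numpy().ravel() * w, n=512)
    print(np.abs(re - ref.real).max(), np.abs(im - ref.imag).max())
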
/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def remove_dc(data):
7 | mean = torch.mean(data, -1, keepdim=True)
8 | data = data - mean
9 | return data
10 |
11 |
12 | def l2_norm(s1, s2):
13 |     # despite the name, this is the inner product <s1, s2> along the last dim;
14 |     # l2_norm(s, s) is then the signal energy (squared L2 norm)
15 |
16 | norm = torch.sum(s1 * s2, -1, keepdim=True)
17 | return norm
18 |
19 |
20 | def si_snr(s1, s2, eps=1e-8):
21 | # s1 = remove_dc(s1)
22 | # s2 = remove_dc(s2)
23 | s1_s2_norm = l2_norm(s1, s2)
24 | s2_s2_norm = l2_norm(s2, s2)
25 | s_target = s1_s2_norm / (s2_s2_norm + eps) * s2
26 |     e_noise = s1 - s_target
27 |     target_norm = l2_norm(s_target, s_target)
28 |     noise_norm = l2_norm(e_noise, e_noise)
29 | snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps)
30 | return torch.mean(snr)
31 |
32 |
33 | # The larger the SI-SNR, the better the model
34 | class SISNRLoss(nn.Module):
35 | def __init__(self, eps=1e-8):
36 | super().__init__()
37 | self.eps = eps
38 |
39 | def forward(self, x, y):
40 | # return -torch.mean(si_snr(inputs, labels))
41 | return -(si_snr(x, y, eps=self.eps))
42 |
43 |
44 | class MSELoss(nn.Module):
45 | def __init__(self):
46 | super().__init__()
47 |
48 | def forward(self, x, y):
49 |         b, d, t = x.shape
50 |         y = y.clone()  # zero the DC rows on a copy, not on the caller's tensor
51 |         y[:, 0, :] = y[:, d // 2, :] = 0
52 |         return F.mse_loss(x, y, reduction='mean') * d
53 |
54 |
55 | class MAELoss(nn.Module):
56 | def __init__(self, stft):
57 | super().__init__()
58 | self.stft = stft
59 |
60 | def forward(self, x, y):
61 | gth_spec, gth_phase = self.stft(y)
62 | b, d, t = x.shape
63 | return torch.mean(torch.abs(x - gth_spec)) * d
64 |
--------------------------------------------------------------------------------
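
SI-SNR projects the estimate onto the reference and compares energies:
s_target = (<s, s_ref> / <s_ref, s_ref>) * s_ref, e = s - s_target, and
SI-SNR = 10 * log10(|s_target|^2 / |e|^2), which is invariant to rescaling the
estimate. A quick check (the signals are illustrative):

    import torch
    from models.loss import si_snr

    ref = torch.randn(1, 16000)
    est = ref + 0.01 * torch.randn(1, 16000)   # nearly clean estimate
    print(si_snr(est, ref).item())             # about 40 dB
    print(si_snr(3.0 * est, ref).item())       # essentially the same value
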
/utils/show.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python -u
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)
5 |
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 |
11 | def show_params(nnet):
12 | print("=" * 40, "Model Parameters", "=" * 40)
13 | num_params = 0
14 | for module_name, m in nnet.named_modules():
15 | if module_name == '':
16 | for name, params in m.named_parameters():
17 | print(name, params.size())
18 | i = 1
19 | for j in params.size():
20 | i = i * j
21 | num_params += i
22 | print('[*] Parameter Size: {}'.format(num_params))
23 | print("=" * 98)
24 |
25 |
26 | def show_model(nnet):
27 | print("=" * 40, "Model Structures", "=" * 40)
28 | for module_name, m in nnet.named_modules():
29 | if module_name == '':
30 | print(m)
31 | print("=" * 98)
32 |
--------------------------------------------------------------------------------
/utils/synthesizer.py:
--------------------------------------------------------------------------------
1 | from scipy.io import wavfile
2 | import numpy as np
3 | import soundfile as sf
4 | import librosa
5 | import random
6 | import os
7 | from config import opt
8 |
9 |
10 | # split each original noise file so train and test draw from disjoint segments
11 | def split_noise(noise_file, save_dir, prop=0.5):
12 | assert os.path.exists(noise_file), 'noise file does not exist!'
13 |
14 | assert noise_file.endswith('.wav'), 'non-supported noise format!'
15 |
16 | if not os.path.exists(save_dir):
17 | print('warning: save directory does not exist, it will be created automatically.')
18 | os.makedirs(save_dir)
19 |
20 | sample_rate, sig = wavfile.read(noise_file)
21 |
22 | train_len = sig.shape[0] * prop
23 |
24 | train_noise = sig[:int(train_len)]
25 | test_noise = sig[int(train_len):]
26 |
27 | # remove .wav
28 | noise_name = os.path.basename(noise_file)[:-4]
29 |
30 | train_noise_dir = os.path.join(save_dir, 'train')
31 | test_noise_dir = os.path.join(save_dir, 'test')
32 |
33 | if not os.path.exists(train_noise_dir):
34 | os.makedirs(train_noise_dir)
35 | if not os.path.exists(test_noise_dir):
36 | os.makedirs(test_noise_dir)
37 |
38 | train_noise_path = os.path.join(train_noise_dir, noise_name + '.wav')
39 | test_noise_path = os.path.join(test_noise_dir, noise_name + '.wav')
40 |
41 | sf.write(train_noise_path, train_noise, sample_rate)
42 | sf.write(test_noise_path, test_noise, sample_rate)
43 |
44 |
45 | def synthesize_noisy_speech(speech_file, noise_file, save_dir, snr=0):
46 | assert os.path.exists(speech_file), 'speech file does not exist!'
47 | assert os.path.exists(noise_file), 'noise file does not exist!'
48 |
49 | assert speech_file.endswith('.wav'), 'non-supported speech format!'
50 | assert noise_file.endswith('.wav'), 'non-supported noise format!'
51 |
52 | if not os.path.exists(save_dir):
53 | print('warning: save directory does not exist, it will be created automatically.')
54 | os.makedirs(save_dir)
55 |
56 | speech_name = os.path.basename(speech_file)[:-4]
57 | noise_name = os.path.basename(noise_file)[:-4]
58 |
59 |     # clean speech
60 |     a, a_sr = librosa.load(speech_file, sr=16000)
61 |     # noise
62 |     b, b_sr = librosa.load(noise_file, sr=16000)
63 |     # pick a random noise segment as long as the clean speech, within bounds
64 |     start = random.randint(0, b.shape[0] - a.shape[0])
65 |     # slice it out
66 |     n_b = b[int(start):int(start) + a.shape[0]]
67 | 
68 |     # signal and noise energies
69 |     sum_s = np.sum(a ** 2)
70 |     sum_n = np.sum(n_b ** 2)
71 |     # noise gain that yields the target SNR (snr is in dB)
72 |     x = np.sqrt(sum_s / (sum_n * 10 ** (snr / 10)))
73 |
74 | noise = x * n_b
75 | noisy_speech = a + noise
76 |
77 | noisy_dir = os.path.join(save_dir, '{0}dB'.format(snr), 'noisy')
78 | clean_dir = os.path.join(save_dir, '{0}dB'.format(snr), 'clean')
79 |
80 | if not os.path.exists(noisy_dir):
81 | os.makedirs(noisy_dir)
82 | if not os.path.exists(clean_dir):
83 | os.makedirs(clean_dir)
84 |
85 | noisy_speech_path = os.path.join(noisy_dir, speech_name + '_' + noise_name + '.wav')
86 | clean_speech_path = os.path.join(clean_dir, speech_name + '.wav')
87 |
88 | sf.write(noisy_speech_path, noisy_speech, 16000)
89 | sf.write(clean_speech_path, a, 16000)
90 |
91 |
92 | # split noise for train and test
93 | def generate_noise_dataset(noise_base, save_dir):
94 | print('noise base directory: ', noise_base)
95 | print('output directory: ', save_dir)
96 | # find all noise file and split them with custom proportion
97 |     for subdir in os.listdir(noise_base):
98 |         noise_dir = os.path.join(noise_base, subdir)
99 | for file in os.listdir(noise_dir):
100 | if file.endswith('.wav'):
101 | noise_file = os.path.join(noise_dir, file)
102 | split_noise(noise_file, save_dir, prop=0.5)
103 |     print('successfully generated noise dataset!')
104 |
105 |
106 | def generate_noisy_dataset(speech_base, noise_base, save_dir):
107 | print('speech base directory: ', speech_base)
108 | print('output directory: ', save_dir)
109 | noise_files = []
110 | for file in os.listdir(noise_base):
111 | if file.endswith('.wav'):
112 | noise_files.append(os.path.join(noise_base, file))
113 | for file in os.listdir(speech_base):
114 | if file.endswith('.wav'):
115 | speech_file = os.path.join(speech_base, file)
116 | noise_file = random.choice(noise_files)
117 | synthesize_noisy_speech(speech_file, noise_file, save_dir=save_dir, snr=0)
118 |     print('successfully generated noisy dataset!')
119 |
120 |
121 | if __name__ == '__main__':
122 | # origin speech data path
123 | noise_base = os.path.join(opt.data_root, 'THCHS-30', 'test-noise/noise')
124 | train_speech_base = os.path.join(opt.data_root, 'THCHS-30', 'data_thchs30/train')
125 | test_speech_base = os.path.join(opt.data_root, 'THCHS-30', 'data_thchs30/test')
126 |
127 | # synthesized speech data path
128 | noise_dir = os.path.join(opt.data_root, 'THCHS-30', 'data_synthesized/noise')
129 | train_dir = os.path.join(opt.data_root, 'THCHS-30', 'data_synthesized/train')
130 | test_dir = os.path.join(opt.data_root, 'THCHS-30', 'data_synthesized/test')
131 |
132 | # split origin noise for train and test
133 | # generate_noise_dataset(noise_base=noise_base, save_dir=noise_dir)
134 |
135 | # generate train noisy speech
136 | generate_noisy_dataset(speech_base=train_speech_base,
137 | noise_base=os.path.join(noise_dir, 'train'),
138 | save_dir=train_dir)
139 |
140 | # generate test noisy speech
141 | generate_noisy_dataset(speech_base=test_speech_base,
142 | noise_base=os.path.join(noise_dir, 'test'),
143 | save_dir=test_dir)
144 |
--------------------------------------------------------------------------------
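
With snr in decibels, the noise gain x above must satisfy
10 * log10(sum(s^2) / sum((x*n)^2)) = snr, giving
x = sqrt(Es / (En * 10 ** (snr / 10))). A quick numeric check with stand-in
signals:

    import numpy as np

    rng = np.random.default_rng(0)
    s = rng.standard_normal(16000)   # stand-in clean speech
    n = rng.standard_normal(16000)   # stand-in noise
    snr = 5.0
    x = np.sqrt((s ** 2).sum() / ((n ** 2).sum() * 10 ** (snr / 10)))
    print(10 * np.log10((s ** 2).sum() / ((x * n) ** 2).sum()))   # 5.0 dB
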