├── LICENSE
├── README.md
├── config.py
├── dataloader.py
├── model.py
├── samples
│   └── 0dB
│       ├── [15] SI-SNR+LMS.wav
│       ├── [15] SI-SNR.wav
│       ├── [15]CLEAN.wav
│       ├── [15]NOISY.wav
│       ├── [189] SI-SNR+LMS.wav
│       ├── [189] SI-SNR.wav
│       ├── [189]CLEAN.wav
│       ├── [189]NOISY.wav
│       ├── [1] SI-SNR+LMS.wav
│       ├── [1] SI-SNR.wav
│       ├── [1]CLEAN.wav
│       ├── [1]NOISY.wav
│       ├── [21] SI-SNR+LMS.wav
│       ├── [21] SI-SNR.wav
│       ├── [21]CLEAN.wav
│       ├── [21]NOISY.wav
│       ├── [78] SI-SNR+LMS.wav
│       ├── [78] SI-SNR.wav
│       ├── [78]CLEAN.wav
│       ├── [78]NOISY.wav
│       ├── [88] SI-SNR+LMS.wav
│       ├── [88] SI-SNR.wav
│       ├── [88]CLEAN.wav
│       └── [88]NOISY.wav
├── tester.py
├── tools_for_loss.py
├── tools_for_model.py
├── train.py
├── trainer.py
└── write_on_tensorboard.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 seorim0
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DCCRN with various loss functions
2 |
3 | DCCRN (Deep Complex Convolutional Recurrent Network) is a deep neural network proposed in [[1]](https://arxiv.org/abs/2008.00264). This repository applies DCCRN with various loss functions. Our original paper can be found [here](https://www.jask.or.kr/articles/xml/ABxn/), and you can listen to test samples [here](https://github.com/seorim0/DCCRN-with-various-loss-functions/tree/main/samples/0dB). The test samples were chosen at random; we uploaded samples for SI-SNR and SI-SNR+LMS.
4 |
5 |
6 | 
7 | > Source of the figure: [paper](https://www.jask.or.kr/articles/xml/ABxn/)
8 |
9 |
10 |
11 |
12 | # Loss functions
13 | We use two base loss functions and two perceptual loss functions.
14 |
15 | > Base loss
16 | 1. MSE: Mean Squared Error
17 | 
18 |
19 |
20 | 2. SI-SNR: Scale-Invariant Source-to-Noise Ratio
21 | 
22 |
23 |
24 | > Perceptual loss
25 | 1. LMS: Log Mel Spectra
26 | 
27 |
28 |
29 | 2. PMSQE: Perceptual Metric for Speech Quality Evaluation
30 | 
31 |
32 |
33 | We combined the two base loss functions with the two perceptual loss functions. The coupling constant ratio was determined experimentally. For example, MSE, a base loss function, has an initial magnitude of about 0.001 ~ 0.002, whereas LMS starts around 0.1 ~ 0.2 and PMSQE around 0.8 ~ 1.3. Therefore, to make the two terms comparable in size, a smaller coefficient was applied to the perceptual loss term. The coupling constant ratio reflects the dynamic range of the two terms rather than their sensitivity. During the experiments we found the base loss function to be the more important term, so we adjusted the coefficients so that the dynamic-range ratio, including the coupling constant, ranged from 1:1 to 10:1.
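
This mirrors the pattern in `model.py`'s `loss()`; a minimal sketch with illustrative weights:

```python
def combined_loss(base_loss, perceptual_loss, r1=1, r2=2):
    # e.g. the SI-SNR+LMS setting in model.py uses r1=1, r2=2;
    # dividing by r1 + r2 keeps the combined loss on a comparable scale
    return (r1 * base_loss + r2 * perceptual_loss) / (r1 + r2)
```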
34 |
35 |
36 | # Requirements
37 | > This repository is tested on Ubuntu 20.04.
38 | * Python 3.7+
39 | * CUDA 10.1+
40 | * cuDNN 7+
41 | * PyTorch 1.7+
42 |
43 |
44 | > Library
45 | * tqdm
46 | * asteroid
47 | * scipy
48 | * matplotlib
49 | * tensorboardX
50 | * pesq
51 | * pystoi
52 |
53 | # Prepare data
54 | The training and validation data are three-dimensional:
55 | ```[Batch size, 2(input & target), wav length]```
56 |
57 | The test data is five-dimensional:
58 | ```[noise type, dB classes, Batch size, 2(input & target), wav length]```
59 | We use two noise types (seen and unseen) and seven dB classes from -10 dB to 20 dB.
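
`Wave_Dataset_for_test` in `dataloader.py` selects one (noise type, SNR) slice before batching; a minimal sketch with illustrative indices:

```python
import numpy as np

# [noise type][dB class][batch][2 (input & target)][wav length]
test_set = np.load('./input/recon_test_dataset.npy', allow_pickle=True)
subset = test_set[0][2]   # noise type 0 (seen), dB index 2 (0 dB)
```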
60 |
61 |
62 | We truncated wav files longer than 3 seconds to 3 seconds and zero-padded wav files shorter than 3 seconds.
63 | The sampling frequency is 16 kHz.
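
A minimal sketch of this step (`to_fixed_length` is an illustrative helper, not part of this repository):

```python
import numpy as np

FS = 16000            # 16 kHz sampling rate
SEG_LEN = 3 * FS      # 3-second segments

def to_fixed_length(wav: np.ndarray, seg_len: int = SEG_LEN) -> np.ndarray:
    """Truncate to seg_len samples, or zero-pad on the right."""
    if len(wav) >= seg_len:
        return wav[:seg_len]
    return np.pad(wav, (0, seg_len - len(wav)))
```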
64 |
65 |
69 |
70 | # Performance comparative evaluation
71 | **Objective evaluation**
72 |
73 | We evaluate the outputs with PESQ (Perceptual Evaluation of Speech Quality) and STOI (Short-Time Objective Intelligibility measure).
74 | 
75 |
76 |
77 | 
78 |
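
Both metrics can be computed with the `pesq` and `pystoi` packages listed in the requirements; a minimal sketch (signal names illustrative):

```python
from pesq import pesq
from pystoi import stoi

# clean and enhanced: 1-D numpy arrays sampled at 16 kHz
pesq_score = pesq(16000, clean, enhanced, 'wb')            # wide-band PESQ
stoi_score = stoi(clean, enhanced, 16000, extended=False)  # STOI in [0, 1]
```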
79 |
80 | **Spectrogram**
81 |
82 | 
83 | > Source of the figure: [paper](https://www.jask.or.kr/articles/xml/ABxn/)
84 |
85 | The spectrograms of (a) clean speech, (b) noisy speech at 0 dB SNR, and estimated speech using (c) MSE and PMSQE, (d) SI-SNR, (e) SI-SNR and PMSQE, (f) SI-SNR and LMS.
86 |
87 | # References
88 | **DCCRN: Deep Complex Convolution Recurrent Network for Phase-Aware Speech Enhancement**
89 | Yanxin Hu, Yun Liu, Shubo Lv, Mengtao Xing, Shimin Zhang, Yihui Fu, Jian Wu, Bihong Zhang, Lei Xie
90 | [[arXiv]](https://arxiv.org/abs/2008.00264) [[code]](https://github.com/huyanxin/DeepComplexCRN)
91 |
92 |
93 | # Paper
94 | **Performance comparison evaluation of speech enhancement using various loss function.**
95 | Seo-Rim Hwang, Joon Byun, Young-Cheul Park
96 | [[paper]](https://www.jask.or.kr/articles/xml/ABxn/)
97 |
98 |
99 | # Note
100 | * ~~I'm trying to make the code cleaner.~~
101 | * ~~It's still in the editing phase. Please refer to the existing code.~~
102 | * [Cleaned-up and upgraded version of this code](https://github.com/seorim0/Speech_enhancement_with_Pytorch)
103 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration for program
3 | """
4 |
5 | # model
6 | mode = 'DCCRN' # DCUNET / DCCRN
7 | info = 'MODEL INFORMATION : IT IS USED FOR FILE NAME'
8 |
9 | test = True
10 |
11 | # path
12 | job_dir = './job/'
13 | logs_dir = './logs/'
14 | chkpt_path = None
15 | # chkpt_model = 'FILE NAME THAT YOU WANT TO LOAD'
16 | # chkpt_path = job_dir + chkpt_model + 'chkpt_88.pt'
17 |
18 | # model information
19 | fs = 16000
20 | win_len = 400
21 | win_inc = 100
22 | ola_ratio = win_inc / win_len
23 | fft_len = 512
24 | sam_sec = fft_len / fs  # seconds per FFT frame
25 | frm_samp = fs * (fft_len / fs)  # samples per FFT frame (= fft_len)
26 | window_type = 'hann'  # passed to scipy.signal.get_window
27 |
28 | rnn_layers = 2
29 | rnn_units = 256
30 | masking_mode = 'E'
31 | use_clstm = True
32 | kernel_num = [32, 64, 128, 256, 256, 256] # DCCRN
33 | #kernel_num = [72, 72, 144, 144, 144, 160, 160, 180] # DCUNET
34 | loss_mode = 'SDR+PMSQE'
35 |
36 | # hyperparameters for model train
37 | max_epochs = 100
38 | learning_rate = 0.0005
39 | batch = 15
40 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset, DataLoader
4 | import config as cfg
5 |
6 |
7 | # save np.load
8 | np_load_old = np.load
9 | # modify the default parameters of np.load
10 | np.load = lambda *a, **k: np_load_old(*a, allow_pickle=True, **k)
11 |
12 |
13 | def create_dataloader(mode):
14 | if mode == 'train':
15 | return DataLoader(
16 | dataset=Wave_Dataset(mode),
17 | batch_size=cfg.batch, # max 3696 * snr types
18 | shuffle=True,
19 | num_workers=0,
20 | pin_memory=True,
21 | drop_last=True,
22 | sampler=None
23 | )
24 | elif mode == 'valid':
25 | return DataLoader(
26 | dataset=Wave_Dataset(mode),
27 | batch_size=cfg.batch, shuffle=False, num_workers=0
28 | ) # max 1152
29 |
30 |
31 | def create_dataloader_for_test(mode, type, snr):
32 | if mode == 'test':
33 | return DataLoader(
34 | dataset=Wave_Dataset_for_test(mode, type, snr),
35 | batch_size=cfg.batch, shuffle=False, num_workers=0
36 | ) # max 192
37 |
38 |
39 | class Wave_Dataset(Dataset):
40 | def __init__(self, mode):
41 | # load data
42 | if mode == 'train':
43 | print('')
44 | print('Load the data...')
45 | self.input_path = './input/train_dataset.npy'
46 | elif mode == 'valid':
47 | print('')
48 | print('Load the data...')
49 | self.input_path = './input/validation_dataset.npy'
50 |
51 | self.input = np.load(self.input_path)
52 |
53 | def __len__(self):
54 | return len(self.input)
55 |
56 | def __getitem__(self, idx):
57 | inputs = self.input[idx][0]
58 | labels = self.input[idx][1]
59 |
60 | # transform to torch from numpy
61 | inputs = torch.from_numpy(inputs)
62 | labels = torch.from_numpy(labels)
63 |
64 | return inputs, labels
65 |
66 |
67 | class Wave_Dataset_for_test(Dataset):
68 | def __init__(self, mode, type, snr):
69 | # load data
70 | if mode == 'test':
71 | print('')
72 | print('Load the data...')
73 | self.input_path = './input/recon_test_dataset.npy'
74 |
75 | self.input = np.load(self.input_path)
76 | self.input = self.input[type][snr]
77 |
78 | def __len__(self):
79 | return len(self.input)
80 |
81 | def __getitem__(self, idx):
82 | inputs = self.input[idx][0]
83 | labels = self.input[idx][1]
84 |
85 | # transform to torch from numpy
86 | inputs = torch.from_numpy(inputs)
87 | labels = torch.from_numpy(labels)
88 |
89 | return inputs, labels
90 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | """
2 | DCCRN: Deep complex convolution recurrent network
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import config as cfg
8 | from tools_for_model import ConvSTFT, ConviSTFT, \
9 | ComplexConv2d, ComplexConvTranspose2d, NavieComplexLSTM, complex_cat, ComplexBatchNorm
10 | from tools_for_loss import si_snr, si_sdr, get_array_mel_loss, pmsqe_stft, pmsqe_loss, sdr
11 | from asteroid.filterbanks import transforms
12 |
13 |
14 | class DCCRN(nn.Module):
15 |
16 | def __init__(
17 | self,
18 | rnn_layers=cfg.rnn_layers,
19 | rnn_units=cfg.rnn_units,
20 | win_len=cfg.win_len,
21 | win_inc=cfg.win_inc,
22 | fft_len=cfg.fft_len,
23 | win_type=cfg.window_type,
24 | masking_mode='E',
25 | use_clstm=False,
26 | use_cbn=False,
27 | kernel_size=5,
28 | kernel_num=[16, 32, 64, 128, 256, 256]
29 | ):
30 | '''
31 |
32 | rnn_layers: the number of lstm layers in the crn,
33 | rnn_units: for clstm, rnn_units = real+imag
34 | '''
35 |
36 | super(DCCRN, self).__init__()
37 |
38 | # for fft
39 | self.win_len = win_len
40 | self.win_inc = win_inc
41 | self.fft_len = fft_len
42 | self.win_type = win_type
43 |
44 | input_dim = win_len
45 | output_dim = win_len
46 |
47 | self.rnn_units = rnn_units
48 | self.input_dim = input_dim
49 | self.output_dim = output_dim
50 | self.hidden_layers = rnn_layers
51 | self.kernel_size = kernel_size
52 | # self.kernel_num = [2, 8, 16, 32, 128, 128, 128]
53 | # self.kernel_num = [2, 16, 32, 64, 128, 256, 256]
54 | self.kernel_num = [2] + kernel_num
55 | self.masking_mode = masking_mode
56 | self.use_clstm = use_clstm
57 |
58 | # bidirectional=True
59 | bidirectional = False
60 | fac = 2 if bidirectional else 1
61 |
62 | fix = True
63 | self.fix = fix
64 | self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix)
65 | self.istft = ConviSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix)
66 |
67 | self.encoder = nn.ModuleList()
68 | self.decoder = nn.ModuleList()
69 | for idx in range(len(self.kernel_num) - 1):
70 | self.encoder.append(
71 | nn.Sequential(
72 | # nn.ConstantPad2d([0, 0, 0, 0], 0),
73 | ComplexConv2d(
74 | self.kernel_num[idx],
75 | self.kernel_num[idx + 1],
76 | kernel_size=(self.kernel_size, 2),
77 | stride=(2, 1),
78 | padding=(2, 1)
79 | ),
80 | nn.BatchNorm2d(self.kernel_num[idx + 1]) if not use_cbn else ComplexBatchNorm(
81 | self.kernel_num[idx + 1]),
82 | nn.PReLU()
83 | )
84 | )
85 | hidden_dim = self.fft_len // (2 ** (len(self.kernel_num)))
86 |
87 | if self.use_clstm:
88 | rnns = []
89 | for idx in range(rnn_layers):
90 | rnns.append(
91 | NavieComplexLSTM(
92 | input_size=hidden_dim * self.kernel_num[-1] if idx == 0 else self.rnn_units,
93 | hidden_size=self.rnn_units,
94 | bidirectional=bidirectional,
95 | batch_first=False,
96 | projection_dim=hidden_dim * self.kernel_num[-1] if idx == rnn_layers - 1 else None,
97 | )
98 | )
99 | self.enhance = nn.Sequential(*rnns)
100 | else:
101 | self.enhance = nn.LSTM(
102 | input_size=hidden_dim * self.kernel_num[-1],
103 | hidden_size=self.rnn_units,
104 | num_layers=2,
105 | dropout=0.0,
106 | bidirectional=bidirectional,
107 | batch_first=False
108 | )
109 | self.tranform = nn.Linear(self.rnn_units * fac, hidden_dim * self.kernel_num[-1])
110 |
111 | for idx in range(len(self.kernel_num) - 1, 0, -1):
112 | if idx != 1:
113 | self.decoder.append(
114 | nn.Sequential(
115 | ComplexConvTranspose2d(
116 | self.kernel_num[idx] * 2,
117 | self.kernel_num[idx - 1],
118 | kernel_size=(self.kernel_size, 2),
119 | stride=(2, 1),
120 | padding=(2, 0),
121 | output_padding=(1, 0)
122 | ),
123 | nn.BatchNorm2d(self.kernel_num[idx - 1]) if not use_cbn else ComplexBatchNorm(
124 | self.kernel_num[idx - 1]),
125 | # nn.ELU()
126 | nn.PReLU()
127 | )
128 | )
129 | else:
130 | self.decoder.append(
131 | nn.Sequential(
132 | ComplexConvTranspose2d(
133 | self.kernel_num[idx] * 2,
134 | self.kernel_num[idx - 1],
135 | kernel_size=(self.kernel_size, 2),
136 | stride=(2, 1),
137 | padding=(2, 0),
138 | output_padding=(1, 0)
139 | ),
140 | )
141 | )
142 |
143 | self.flatten_parameters()
144 |
145 | def flatten_parameters(self):
146 | if isinstance(self.enhance, nn.LSTM):
147 | self.enhance.flatten_parameters()
148 |
149 | def forward(self, inputs, lens=None):
150 | specs = self.stft(inputs)
151 | real = specs[:, :self.fft_len // 2 + 1]
152 | imag = specs[:, self.fft_len // 2 + 1:]
153 | spec_mags = torch.sqrt(real ** 2 + imag ** 2 + 1e-8)
159 | spec_phase = torch.atan2(imag, real)
161 | cspecs = torch.stack([real, imag], 1)
162 | cspecs = cspecs[:, :, 1:]
163 | '''
164 | means = torch.mean(cspecs, [1,2,3], keepdim=True)
165 | std = torch.std(cspecs, [1,2,3], keepdim=True )
166 | normed_cspecs = (cspecs-means)/(std+1e-8)
167 | out = normed_cspecs
168 | '''
169 |
170 | out = cspecs
171 | encoder_out = []
172 |
173 | for idx, layer in enumerate(self.encoder):
174 | out = layer(out)
175 | # print('encoder', out.size())
176 | encoder_out.append(out)
177 |
178 | batch_size, channels, dims, lengths = out.size()
179 | out = out.permute(3, 0, 1, 2)
180 | if self.use_clstm:
181 | r_rnn_in = out[:, :, :channels // 2]
182 | i_rnn_in = out[:, :, channels // 2:]
183 | r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2 * dims])
184 | i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, channels // 2 * dims])
185 |
186 | r_rnn_in, i_rnn_in = self.enhance([r_rnn_in, i_rnn_in])
187 |
188 | r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2, dims])
189 | i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, channels // 2, dims])
190 | out = torch.cat([r_rnn_in, i_rnn_in], 2)
191 |
192 | else:
193 | # to [L, B, C, D]
194 | out = torch.reshape(out, [lengths, batch_size, channels * dims])
195 | out, _ = self.enhance(out)
196 | out = self.tranform(out)
197 | out = torch.reshape(out, [lengths, batch_size, channels, dims])
198 |
199 | out = out.permute(1, 2, 3, 0)
200 |
201 | for idx in range(len(self.decoder)):
202 | out = complex_cat([out, encoder_out[-1 - idx]], 1)
203 | out = self.decoder[idx](out)
204 | out = out[..., 1:]
205 | # print('decoder', out.size())
206 | mask_real = out[:, 0]
207 | mask_imag = out[:, 1]
208 | mask_real = F.pad(mask_real, [0, 0, 1, 0])
209 | mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
210 |
211 | if self.masking_mode == 'E':
212 | mask_mags = (mask_real ** 2 + mask_imag ** 2) ** 0.5
213 | real_phase = mask_real / (mask_mags + 1e-8)
214 | imag_phase = mask_imag / (mask_mags + 1e-8)
215 | mask_phase = torch.atan2(
216 | imag_phase,
217 | real_phase
218 | )
219 |
220 | # mask_mags = torch.clamp_(mask_mags,0,100)
221 | mask_mags = torch.tanh(mask_mags)
222 | est_mags = mask_mags * spec_mags
223 | est_phase = spec_phase + mask_phase
224 | real = est_mags * torch.cos(est_phase)
225 | imag = est_mags * torch.sin(est_phase)
226 | elif self.masking_mode == 'C':
227 | real, imag = real * mask_real - imag * mask_imag, real * mask_imag + imag * mask_real
228 | elif self.masking_mode == 'R':
229 | real, imag = real * mask_real, imag * mask_imag
230 |
231 | out_spec = torch.cat([real, imag], 1)
232 | out_wav = self.istft(out_spec)
233 |
234 | out_wav = torch.squeeze(out_wav, 1)
235 | # out_wav = torch.tanh(out_wav)
236 | out_wav = torch.clamp_(out_wav, -1, 1)
237 | return mask_real, mask_imag, real, imag, out_wav # out_spec, out_wav
238 |
239 | def get_params(self, weight_decay=0.0):
240 | # add L2 penalty
241 | weights, biases = [], []
242 | for name, param in self.named_parameters():
243 | if 'bias' in name:
244 | biases += [param]
245 | else:
246 | weights += [param]
247 | params = [{
248 | 'params': weights,
249 | 'weight_decay': weight_decay,
250 | }, {
251 | 'params': biases,
252 | 'weight_decay': 0.0,
253 | }]
254 | return params
255 |
256 | def loss(self, inputs, labels, real_spec, img_spec, loss_mode=cfg.loss_mode):
257 | if loss_mode == 'MSE':
258 | return F.mse_loss(inputs, labels, reduction='mean')
259 |
260 | elif loss_mode == 'SDR':
261 | return -sdr(labels, inputs)
262 |
263 | elif loss_mode == 'SI-SNR':
264 | return -(si_snr(inputs, labels))
265 |
266 | elif loss_mode == 'SI-SDR':
267 | return -(si_sdr(labels, inputs))
268 |
269 | elif loss_mode == 'MSE+LMS':
270 |
271 | mse_loss = F.mse_loss(inputs, labels, reduction='mean')
272 |
273 | # for mel loss calculation
274 | clean_specs = self.stft(labels)
275 | clean_real = clean_specs[:, :self.fft_len // 2 + 1]
276 | clean_imag = clean_specs[:, self.fft_len // 2 + 1:]
277 | clean_mags = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-7)
278 |
279 | est_clean_mags = torch.sqrt(real_spec ** 2 + img_spec ** 2 + 1e-7)
280 | mel_loss = get_array_mel_loss(clean_mags, est_clean_mags)
281 |
282 | r1 = 1e+3
283 | r2 = 1
284 | r = r1 + r2
285 |
286 | loss = (r1 * mse_loss + r2 * mel_loss) / r
287 |
288 | return loss
289 |
290 | elif loss_mode == 'MSE+SI-SNR':
291 | snr_loss = -(si_snr(inputs, labels))
292 | mse_loss = F.mse_loss(inputs, labels, reduction='mean')
293 |
294 | r1 = 1
295 | r2 = 100
296 | r = r1 + r2
297 |
298 | loss = (r1 * snr_loss + r2 * mse_loss) / r
299 |
300 | return loss
301 |
302 | elif loss_mode == 'MSE+PMSQE':
303 | ref_wav = labels.reshape(-1, 3, 16000)
304 | est_wav = inputs.reshape(-1, 3, 16000)
305 | ref_wav = ref_wav.cpu()
306 | est_wav = est_wav.cpu()
307 |
308 | ref_spec = transforms.take_mag(pmsqe_stft(ref_wav))
309 | est_spec = transforms.take_mag(pmsqe_stft(est_wav))
310 |
311 | loss = pmsqe_loss(ref_spec, est_spec)
312 |
313 | loss = loss.cuda()
314 |
315 | return loss
316 |
317 | elif loss_mode == 'SI-SNR+SI-SDR':
318 | snr_loss = -(si_snr(inputs, labels))
319 | sdr_loss = -(si_sdr(inputs, labels))
320 |
321 | r1 = 1
322 | r2 = 1
323 | r = r1 + r2
324 |
325 | loss = (r1 * snr_loss + r2 * sdr_loss) / r
326 |
327 | return loss
328 |
329 | elif loss_mode == 'SDR+LMS':
330 | sdr_loss = -sdr(labels, inputs)
331 |
332 | # for mel loss calculation
333 | clean_specs = self.stft(labels)
334 | clean_real = clean_specs[:, :self.fft_len // 2 + 1]
335 | clean_imag = clean_specs[:, self.fft_len // 2 + 1:]
336 | clean_mags = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-7)
337 |
338 | est_clean_mags = torch.sqrt(real_spec ** 2 + img_spec ** 2 + 1e-7)
339 | mel_loss = get_array_mel_loss(clean_mags, est_clean_mags)
340 |
341 | r1 = 1
342 | r2 = 2
343 | r = r1 + r2
344 |
345 | loss = (r1 * sdr_loss + r2 * mel_loss) / r
346 | return loss
347 |
348 | elif loss_mode == 'SDR+PMSQE':
349 | sdr_loss = -sdr(labels, inputs)
350 |
351 | ref_wav = labels.reshape(-1, 3, 16000)
352 | est_wav = inputs.reshape(-1, 3, 16000)
353 | ref_wav = ref_wav.cpu()
354 | est_wav = est_wav.cpu()
355 |
356 | ref_spec = transforms.take_mag(pmsqe_stft(ref_wav))
357 | est_spec = transforms.take_mag(pmsqe_stft(est_wav))
358 |
359 | # p_loss = pmsqe_loss(ref_spec, est_spec)  # wrong argument order
360 | p_loss = pmsqe_loss(est_spec, ref_spec)
361 |
362 | r1 = 1
363 | r2 = 15
364 | r = r1 + r2
365 |
366 | loss = (r1 * sdr_loss + r2 * p_loss) / r
367 | return loss
368 |
369 | elif loss_mode == 'SI-SNR+LMS':
370 | snr_loss = -(si_snr(inputs, labels))
371 |
372 | # for mel loss calculation
373 | clean_specs = self.stft(labels)
374 | clean_real = clean_specs[:, :self.fft_len // 2 + 1]
375 | clean_imag = clean_specs[:, self.fft_len // 2 + 1:]
376 | clean_mags = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-7)
377 |
378 | est_clean_mags = torch.sqrt(real_spec ** 2 + img_spec ** 2 + 1e-7)
379 | mel_loss = get_array_mel_loss(clean_mags, est_clean_mags)
380 |
381 | r1 = 1
382 | r2 = 2
383 | r = r1 + r2
384 |
385 | loss = (r1 * snr_loss + r2 * mel_loss) / r
386 |
387 | return loss
388 |
389 | elif loss_mode == 'SI-SNR+PMSQE':
390 | ref_wav = labels.reshape(-1, 3, 16000)
391 | est_wav = inputs.reshape(-1, 3, 16000)
392 | ref_wav = ref_wav.cpu()
393 | est_wav = est_wav.cpu()
394 |
395 | ref_spec = transforms.take_mag(pmsqe_stft(ref_wav))
396 | est_spec = transforms.take_mag(pmsqe_stft(est_wav))
397 |
398 | p_loss = pmsqe_loss(est_spec, ref_spec)
399 |
400 | snr_loss = -(si_snr(est_wav, ref_wav))
401 |
402 | r1 = 8
403 | r2 = 1
404 | r = r1 + r2
405 |
406 | loss = (r1 * p_loss + r2 * snr_loss) / r
407 |
408 | return loss
409 |
410 |
--------------------------------------------------------------------------------
/samples/0dB/[15] SI-SNR+LMS.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[15] SI-SNR+LMS.wav
--------------------------------------------------------------------------------
/samples/0dB/[15] SI-SNR.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[15] SI-SNR.wav
--------------------------------------------------------------------------------
/samples/0dB/[15]CLEAN.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[15]CLEAN.wav
--------------------------------------------------------------------------------
/samples/0dB/[15]NOISY.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[15]NOISY.wav
--------------------------------------------------------------------------------
/samples/0dB/[189] SI-SNR+LMS.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[189] SI-SNR+LMS.wav
--------------------------------------------------------------------------------
/samples/0dB/[189] SI-SNR.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[189] SI-SNR.wav
--------------------------------------------------------------------------------
/samples/0dB/[189]CLEAN.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[189]CLEAN.wav
--------------------------------------------------------------------------------
/samples/0dB/[189]NOISY.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[189]NOISY.wav
--------------------------------------------------------------------------------
/samples/0dB/[1] SI-SNR+LMS.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[1] SI-SNR+LMS.wav
--------------------------------------------------------------------------------
/samples/0dB/[1] SI-SNR.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[1] SI-SNR.wav
--------------------------------------------------------------------------------
/samples/0dB/[1]CLEAN.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[1]CLEAN.wav
--------------------------------------------------------------------------------
/samples/0dB/[1]NOISY.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[1]NOISY.wav
--------------------------------------------------------------------------------
/samples/0dB/[21] SI-SNR+LMS.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[21] SI-SNR+LMS.wav
--------------------------------------------------------------------------------
/samples/0dB/[21] SI-SNR.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[21] SI-SNR.wav
--------------------------------------------------------------------------------
/samples/0dB/[21]CLEAN.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[21]CLEAN.wav
--------------------------------------------------------------------------------
/samples/0dB/[21]NOISY.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[21]NOISY.wav
--------------------------------------------------------------------------------
/samples/0dB/[78] SI-SNR+LMS.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[78] SI-SNR+LMS.wav
--------------------------------------------------------------------------------
/samples/0dB/[78] SI-SNR.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[78] SI-SNR.wav
--------------------------------------------------------------------------------
/samples/0dB/[78]CLEAN.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[78]CLEAN.wav
--------------------------------------------------------------------------------
/samples/0dB/[78]NOISY.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[78]NOISY.wav
--------------------------------------------------------------------------------
/samples/0dB/[88] SI-SNR+LMS.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[88] SI-SNR+LMS.wav
--------------------------------------------------------------------------------
/samples/0dB/[88] SI-SNR.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[88] SI-SNR.wav
--------------------------------------------------------------------------------
/samples/0dB/[88]CLEAN.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[88]CLEAN.wav
--------------------------------------------------------------------------------
/samples/0dB/[88]NOISY.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[88]NOISY.wav
--------------------------------------------------------------------------------
/tester.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import numpy as np
5 | import config as cfg
6 | from run import model_test
7 | from dataloader import create_dataloader_for_test
8 | from model import DCCRN
9 |
10 |
11 | ###############################################################################
12 | # Helper function definition #
13 | ###############################################################################
14 | # Write training related parameters into the log file.
15 | def write_status_to_log_file(fp, total_parameters):
16 | fp.write('--- Test log ---\n')
17 | fp.write('%d-%d-%d %d:%d:%d\n' %
18 | (time.localtime().tm_year, time.localtime().tm_mon,
19 | time.localtime().tm_mday, time.localtime().tm_hour,
20 | time.localtime().tm_min, time.localtime().tm_sec))
21 | fp.write('mode : %s_%s\n' % (cfg.mode, cfg.info))
22 | fp.write('learning rate : %g\n' % cfg.learning_rate)
23 | fp.write('total params : %d (%.2f M, %.2f MBytes)\n' %
24 | (total_parameters,
25 | total_parameters / 1000000.0,
26 | total_parameters * 4.0 / 1000000.0))
27 |
28 |
29 | # Calculate the size of total network.
30 | def calculate_total_params(our_model):
31 | total_parameters = 0
32 | for variable in our_model.parameters():
33 | shape = variable.size()
34 | variable_parameters = 1
35 | for dim in shape:
36 | variable_parameters *= dim
37 | total_parameters += variable_parameters
38 |
39 | return total_parameters
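# Equivalent one-liner: sum(p.numel() for p in our_model.parameters())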
40 |
41 |
42 | ###############################################################################
43 | # Parameter Initialization #
44 | ###############################################################################
45 | print('***********************************************************')
46 | print('* Python library for DNN-based speech enhancement *')
47 | print('* using Pytorch API *')
48 | print('***********************************************************')
49 |
50 | # Set device
51 | DEVICE = torch.device("cuda")
52 |
53 | # Set model
54 | if cfg.mode == 'DCCRN':
55 | model = DCCRN(rnn_units=cfg.rnn_units, masking_mode=cfg.masking_mode, use_clstm=cfg.use_clstm,
56 | kernel_num=cfg.kernel_num).to(DEVICE)
57 |
58 | ###############################################################################
59 | # Set optimizer and learning rate #
60 | ###############################################################################
61 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
62 | total_params = calculate_total_params(model)
63 |
64 | ###############################################################################
65 | # Confirm model information #
66 | ###############################################################################
67 | print('%d-%d-%d %d:%d:%d\n' %
68 | (time.localtime().tm_year, time.localtime().tm_mon,
69 | time.localtime().tm_mday, time.localtime().tm_hour,
70 | time.localtime().tm_min, time.localtime().tm_sec))
71 | print('mode : %s_%s\n' % (cfg.mode, cfg.info))
72 | print('learning rate : %g\n' % cfg.learning_rate)
73 | print('total params : %d (%.2f M, %.2f MBytes)\n' %
74 | (total_params,
75 | total_params / 1000000.0,
76 | total_params * 4.0 / 1000000.0))
77 |
78 |
79 | ###############################################################################
80 | # Set a log file to store progress. #
81 | # Set a hps file to store hyper-parameters information. #
82 | ###############################################################################
83 | # Load the checkpoint
84 | if cfg.chkpt_path is not None:
85 | print('Resuming from checkpoint: %s' % cfg.chkpt_path)
86 |
87 | # Set a log file to store progress.
88 | dir_to_save = cfg.job_dir + cfg.chkpt_model
89 | dir_to_logs = cfg.logs_dir + cfg.chkpt_model
90 |
91 | checkpoint = torch.load(cfg.chkpt_path)
92 | model.load_state_dict(checkpoint['model'])
93 | optimizer.load_state_dict(checkpoint['optimizer'])
94 | epoch_start_idx = checkpoint['epoch'] + 1
95 | mse_vali_total = np.load(str(dir_to_save + '/mse_vali_total.npy'))
96 | if len(mse_vali_total) < cfg.max_epochs:
97 | plus = cfg.max_epochs - len(mse_vali_total)
98 | mse_vali_total = np.concatenate((mse_vali_total, np.zeros(plus)), 0)
99 |
100 |
101 | if not os.path.exists(dir_to_save):
102 | os.mkdir(dir_to_save)
103 | os.mkdir(dir_to_logs)
104 |
105 | log_fname = str(dir_to_save + '/log.txt')
106 | if not os.path.exists(log_fname):
107 | fp = open(log_fname, 'w')
108 | write_status_to_log_file(fp, total_params)
109 | else:
110 | fp = open(log_fname, 'a')
111 |
112 | # Set a hps file to store hyper-parameters information.
113 | hps_fname = str(dir_to_save + '/hp_str.txt')
114 | fp_h = open(hps_fname, 'w')
115 |
116 | with open('config.py', 'r') as f:
117 | hp_str = ''.join(f.readlines())
118 | fp_h.write(hp_str)
119 | fp_h.close()
120 |
121 | min_index = np.argmin(mse_vali_total)
122 | print('Minimum validation loss is at epoch ' + str(min_index + 1) + '.')
123 |
124 | ###############################################################################
125 | # Test #
126 | ###############################################################################
127 | if cfg.test is True:
128 | print('Starting test run')
129 |
130 | # check the lowest validation loss epoch
131 | want_to_check = torch.load(dir_to_save + '/chkpt_opt.pt')
132 | model.load_state_dict(want_to_check['model'])
133 | optimizer.load_state_dict(want_to_check['optimizer'])
134 | epoch_start_idx = want_to_check['epoch'] + 1
135 | mse_vali_total = np.load(str(dir_to_save + '/mse_vali_total.npy'))
136 |
137 | # noise = [seen, unseen]
138 | noise_type = ['seen', 'unseen']
139 | # SNR = [-10, -5, 0, 5, 10, 15, 20] + Avg
140 | noisy_snr = ['-10', '-5', '0', '5', '10', '15', '20', 'Avg']
141 | for type in range(len(noise_type)):
142 | for snr in range(len(noisy_snr)):
143 | test_loader = create_dataloader_for_test(mode='test', type=type, snr=snr)
144 | test_loss, test_pesq, test_stoi = \
145 | model_test(noise_type[type], noisy_snr[snr], model,
146 | test_loader, dir_to_save, DEVICE)
147 |
148 | print('Noise type {} | snr {}'.format(noise_type[type], noisy_snr[snr]))
149 | fp.write('\n\nNoise type {} | snr {}'.format(noise_type[type], noisy_snr[snr]))
150 | print('Test loss {:.6f} | PESQ {:.6f} | STOI {:.6f}'
151 | .format(test_loss, test_pesq, test_stoi))
152 | fp.write('\nTest loss {:.6f} | PESQ {:.6f} | STOI {:.6f}'
153 | .format(test_loss, test_pesq, test_stoi))
154 |
155 | fp.close()
156 | else:
157 | fp.close()
158 |
--------------------------------------------------------------------------------
/tools_for_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 | import config as cfg
5 | from asteroid.losses import SingleSrcPMSQE, PITLossWrapper
6 | from asteroid.filterbanks import STFTFB, Encoder
7 |
8 |
9 | # Set training device
10 | DEVICE = torch.device("cuda")
11 |
12 |
13 | ############################################################################
14 | # for model structure & loss function #
15 | ############################################################################
16 | def remove_dc(data):
17 | mean = torch.mean(data, -1, keepdim=True)
18 | data = data - mean
19 | return data
20 |
21 |
22 | def l2_norm(s1, s2):
23 | # norm = torch.sqrt(torch.sum(s1*s2, 1, keepdim=True))
24 | # norm = torch.norm(s1*s2, 1, keepdim=True)
25 |
26 | norm = torch.sum(s1 * s2, -1, keepdim=True)
27 | return norm
28 |
29 |
30 | def sdr(s1, s2, eps=1e-8):
31 | sn = l2_norm(s1, s1)
32 | sn_m_shn = l2_norm(s1 - s2, s1 - s2)
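# note: sn and sn_m_shn are already energies, so the squares below give
# 20*log10(sn/sn_m_shn), i.e. twice the textbook 10*log10 energy-ratio SDR;
# as a relative training objective the constant factor is harmless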
33 | sdr_loss = 10 * torch.log10(sn**2 / (sn_m_shn**2 + eps))
34 | return torch.mean(sdr_loss)
35 |
36 |
37 | def si_snr(s1, s2, eps=1e-8):
38 | # s1 = remove_dc(s1)
39 | # s2 = remove_dc(s2)
40 | s1_s2_norm = l2_norm(s1, s2)
41 | s2_s2_norm = l2_norm(s2, s2)
42 | s_target = s1_s2_norm / (s2_s2_norm + eps) * s2
43 | e_noise = s1 - s_target
44 | target_norm = l2_norm(s_target, s_target)
45 | noise_norm = l2_norm(e_noise, e_noise)
46 | snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps)
47 | return torch.mean(snr)
48 |
49 |
50 | def si_sdr(reference, estimation, eps=1e-8):
51 | """
52 | Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)
53 | Args:
54 | reference: numpy.ndarray, [..., T]
55 | estimation: numpy.ndarray, [..., T]
56 | Returns:
57 | SI-SDR
58 | [1] SDR - Half-Baked or Well Done?
59 | http://www.merl.com/publications/docs/TR2019-013.pdf
60 | >>> np.random.seed(0)
61 | >>> reference = np.random.randn(100)
62 | >>> si_sdr(reference, reference)
63 | inf
64 | >>> si_sdr(reference, reference * 2)
65 | inf
66 | >>> si_sdr(reference, np.flip(reference))
67 | -25.127672346460717
68 | >>> si_sdr(reference, reference + np.flip(reference))
69 | 0.481070445785553
70 | >>> si_sdr(reference, reference + 0.5)
71 | 6.3704606032577304
72 | >>> si_sdr(reference, reference * 2 + 1)
73 | 6.3704606032577304
74 | >>> si_sdr([1., 0], [0., 0]) # never predict only zeros
75 | nan
76 | >>> si_sdr([reference, reference], [reference * 2 + 1, reference * 1 + 0.5])
77 | array([6.3704606, 6.3704606])
78 | :param reference:
79 | :param estimation:
80 | :param eps:
81 | """
82 |
83 | reference_energy = torch.sum(reference ** 2, dim=-1, keepdim=True)
84 |
85 | # This is $\alpha$ after Equation (3) in [1].
86 | optimal_scaling = torch.sum(reference * estimation, dim=-1, keepdim=True) / (reference_energy + eps)
87 |
88 | # This is $e_{\text{target}}$ in Equation (4) in [1].
89 | projection = optimal_scaling * reference
90 |
91 | # This is $e_{\text{res}}$ in Equation (4) in [1].
92 | noise = estimation - projection
93 |
94 | ratio = torch.sum(projection ** 2, dim=-1) / (torch.sum(noise ** 2, dim=-1) + eps)
95 |
96 | ratio = torch.mean(ratio)
97 | return 10 * torch.log10(ratio + eps)
98 |
99 |
100 | class rmse(torch.nn.Module):
101 | def __init__(self):
102 | super(rmse, self).__init__()
103 |
104 | def forward(self, y_true, y_pred):
105 | mse = torch.mean((y_pred - y_true) ** 2, axis=-1)
106 | rmse = torch.sqrt(mse + 1e-7)
107 |
108 | return torch.mean(rmse)
109 |
110 |
111 |
112 | ############################################################################
113 | # MFCC (Mel Frequency Cepstral Coefficients) #
114 | ############################################################################
115 |
116 | # based on a combination of this article:
117 | # http://practicalcryptography.com/miscellaneous/machine-learning/...
118 | # guide-mel-frequency-cepstral-coefficients-mfccs/
119 | # and some of this code:
120 | # http://stackoverflow.com/questions/5835568/...
121 | # how-to-get-mfcc-from-an-fft-on-a-signal
122 |
123 | # conversions between Mel scale and regular frequency scale
124 | def freqToMel(freq):
125 | return 1127.01048 * math.log(1 + freq / 700.0)
126 |
127 |
128 | def melToFreq(mel):
129 | return 700 * (math.exp(mel / 1127.01048) - 1)
130 |
131 |
132 | # generate Mel filter bank
133 | def melFilterBank(numCoeffs, fftSize=None):
134 | minHz = 0
135 | maxHz = cfg.fs / 2 # max Hz by Nyquist theorem
136 | if (fftSize is None):
137 | numFFTBins = cfg.win_len
138 | else:
139 | numFFTBins = int(fftSize / 2) + 1
140 |
141 | maxMel = freqToMel(maxHz)
142 | minMel = freqToMel(minHz)
143 |
144 | # we need (numCoeffs + 2) points to create (numCoeffs) filterbanks
145 | melRange = np.array(range(numCoeffs + 2))
146 | melRange = melRange.astype(np.float32)
147 |
148 | # create (numCoeffs + 2) points evenly spaced between minMel and maxMel
149 | melCenterFilters = melRange * (maxMel - minMel) / (numCoeffs + 1) + minMel
150 |
151 | for i in range(numCoeffs + 2):
152 | # mel domain => frequency domain
153 | melCenterFilters[i] = melToFreq(melCenterFilters[i])
154 |
155 | # frequency domain => FFT bins
156 | melCenterFilters[i] = math.floor(numFFTBins * melCenterFilters[i] / maxHz)
157 |
158 | # create matrix of filters (one row is one filter)
159 | filterMat = np.zeros((numCoeffs, numFFTBins))
160 |
161 | # generate triangular filters (in frequency domain)
162 | for i in range(1, numCoeffs + 1):
163 | filter = np.zeros(numFFTBins)
164 |
165 | startRange = int(melCenterFilters[i - 1])
166 | midRange = int(melCenterFilters[i])
167 | endRange = int(melCenterFilters[i + 1])
168 |
169 | for j in range(startRange, midRange):
170 | filter[j] = (float(j) - startRange) / (midRange - startRange)
171 | for j in range(midRange, endRange):
172 | filter[j] = 1 - ((float(j) - midRange) / (endRange - midRange))
173 |
174 | filterMat[i - 1] = filter
175 |
176 | # return filterbank as matrix
177 | return filterMat
178 |
179 |
180 |
181 | ############################################################################
182 | # Finally: a perceptual loss function (based on Mel scale) #
183 | ############################################################################
184 |
185 | FFT_SIZE = cfg.fft_len
186 |
187 | # multi-scale MFCC distance
188 | MEL_SCALES = [16, 32, 64] # for LMS
189 | # PAM : MEL_SCALES = [32, 64]
190 |
191 |
192 | # given a (symbolic Theano) array of size M x WINDOW_SIZE
193 | # this returns an array M x N where each window has been replaced
194 | # by some perceptual transform (in this case, MFCC coeffs)
195 | def perceptual_transform(x):
196 | # precompute Mel filterbank: [FFT_SIZE x NUM_MFCC_COEFFS]
197 | MEL_FILTERBANKS = []
198 | for scale in MEL_SCALES:
199 | filterbank_npy = melFilterBank(scale, FFT_SIZE).transpose()
200 | torch_filterbank_npy = torch.from_numpy(filterbank_npy).type(torch.FloatTensor)
201 | MEL_FILTERBANKS.append(torch_filterbank_npy.to(DEVICE))
202 |
203 | transforms = []
204 | # powerSpectrum = torch_dft_mag(x, DFT_REAL, DFT_IMAG)**2
205 |
206 | powerSpectrum = x.view(-1, FFT_SIZE // 2 + 1)
207 | powerSpectrum = 1.0 / FFT_SIZE * powerSpectrum
208 |
209 | for filterbank in MEL_FILTERBANKS:
210 | filteredSpectrum = torch.mm(powerSpectrum, filterbank)
211 | filteredSpectrum = torch.log(filteredSpectrum + 1e-7)
212 | transforms.append(filteredSpectrum)
213 |
214 | return transforms
215 |
216 |
217 | # perceptual loss function
218 | class perceptual_distance(torch.nn.Module):
219 |
220 | def __init__(self):
221 | super(perceptual_distance, self).__init__()
222 |
223 | def forward(self, y_true, y_pred):
224 | rmse_loss = rmse()
225 | # y_true = torch.reshape(y_true, (-1, WINDOW_SIZE))
226 | # y_pred = torch.reshape(y_pred, (-1, WINDOW_SIZE))
227 |
228 | pvec_true = perceptual_transform(y_true)
229 | pvec_pred = perceptual_transform(y_pred)
230 |
231 | distances = []
232 | for i in range(0, len(pvec_true)):
233 | error = rmse_loss(pvec_pred[i], pvec_true[i])
234 | error = error.unsqueeze(dim=-1)
235 | distances.append(error)
236 | distances = torch.cat(distances, axis=-1)
237 |
238 | loss = torch.mean(distances, axis=-1)
239 | return torch.mean(loss)
240 |
241 |
242 | get_mel_loss = perceptual_distance()
243 |
244 |
245 | def get_array_mel_loss(clean_array, est_array):
246 | array_mel_loss = 0
247 | for i in range(len(clean_array)):
248 | mel_loss = get_mel_loss(clean_array[i], est_array[i])
249 | array_mel_loss += mel_loss
250 |
251 | avg_mel_loss = array_mel_loss / len(clean_array)
252 | return avg_mel_loss
253 |
254 |
255 | ############################################################################
256 | # for pmsqe loss #
257 | ############################################################################
258 | pmsqe_stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
259 | pmsqe_loss = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')
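# Usage sketch (mirrors model.py): PMSQE is computed on magnitude spectra of
# 16 kHz waveforms, e.g.
#   ref_spec = transforms.take_mag(pmsqe_stft(ref_wav))  # transforms from asteroid.filterbanks
#   est_spec = transforms.take_mag(pmsqe_stft(est_wav))
#   loss = pmsqe_loss(est_spec, ref_spec)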
260 |
--------------------------------------------------------------------------------
/tools_for_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import time
5 | import torch.nn.functional as F
6 | from scipy.signal import get_window
7 | import matplotlib.pylab as plt
8 | from pesq import pesq
9 | from pystoi import stoi
10 |
11 |
12 | ############################################################################
13 | # for convolutional STFT #
14 | ############################################################################
15 | def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
16 | if win_type == 'None' or win_type is None:
17 | window = np.ones(win_len)
18 | else:
19 | window = get_window(win_type, win_len, fftbins=True) # **0.5
20 |
21 | N = fft_len
22 | fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
23 | real_kernel = np.real(fourier_basis)
24 | imag_kernel = np.imag(fourier_basis)
25 | kernel = np.concatenate([real_kernel, imag_kernel], 1).T
26 |
27 | if invers:
28 | kernel = np.linalg.pinv(kernel).T
29 |
30 | kernel = kernel * window
31 | kernel = kernel[:, None, :]
32 | return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32))
33 |
34 |
35 | class ConvSTFT(nn.Module):
36 |
37 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
38 | super(ConvSTFT, self).__init__()
39 |
40 | if fft_len == None:
41 | self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
42 | else:
43 | self.fft_len = fft_len
44 |
45 | kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
46 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
47 | self.register_buffer('weight', kernel)
48 | self.feature_type = feature_type
49 | self.stride = win_inc
50 | self.win_len = win_len
51 | self.dim = self.fft_len
52 |
53 | def forward(self, inputs):
54 | if inputs.dim() == 2:
55 | inputs = torch.unsqueeze(inputs, 1)
56 | inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride])
57 | outputs = F.conv1d(inputs, self.weight, stride=self.stride)
58 |
59 | if self.feature_type == 'complex':
60 | return outputs
61 | else:
62 | dim = self.dim // 2 + 1
63 | real = outputs[:, :dim, :]
64 | imag = outputs[:, dim:, :]
65 | mags = torch.sqrt(real ** 2 + imag ** 2)
66 | phase = torch.atan2(imag, real)
67 | return mags, phase
68 |
69 |
70 | class ConviSTFT(nn.Module):
71 |
72 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
73 | super(ConviSTFT, self).__init__()
74 | if fft_len == None:
75 | self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
76 | else:
77 | self.fft_len = fft_len
78 | kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
79 | #self.weight = nn.Parameter(kernel, requires_grad=(not fix))
80 | self.register_buffer('weight', kernel)
81 | self.feature_type = feature_type
82 | self.win_type = win_type
83 | self.win_len = win_len
84 | self.stride = win_inc
86 | self.dim = self.fft_len
87 | self.register_buffer('window', window)
88 | self.register_buffer('enframe', torch.eye(win_len)[:,None,:])
89 |
90 | def forward(self, inputs, phase=None):
91 | """
92 | inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
93 | phase: [B, N//2+1, T] (if not none)
94 | """
95 |
96 | if phase is not None:
97 | real = inputs * torch.cos(phase)
98 | imag = inputs * torch.sin(phase)
99 | inputs = torch.cat([real, imag], 1)
100 | outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
101 |
102 | # this is from torch-stft: https://github.com/pseeth/torch-stft
103 | t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
104 | coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
105 | outputs = outputs / (coff + 1e-8)
106 | # outputs = torch.where(coff == 0, outputs, outputs/coff)
107 | outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)]
108 |
109 | return outputs
110 |
111 |
112 | ############################################################################
113 | # for complex rnn #
114 | ############################################################################
115 | def get_causal_padding1d():
116 | pass
117 |
118 |
119 | def get_causal_padding2d():
120 | pass
121 |
122 |
123 | class cPReLU(nn.Module):
124 |
125 | def __init__(self, complex_axis=1):
126 | super(cPReLU, self).__init__()
127 | self.r_prelu = nn.PReLU()
128 | self.i_prelu = nn.PReLU()
129 | self.complex_axis = complex_axis
130 |
131 | def forward(self, inputs):
132 | real, imag = torch.chunk(inputs, 2, self.complex_axis)
133 | real = self.r_prelu(real)
134 | imag = self.i_prelu(imag)
135 | return torch.cat([real, imag], self.complex_axis)
136 |
137 |
138 | class NavieComplexLSTM(nn.Module):
139 | def __init__(self, input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False):
140 | super(NavieComplexLSTM, self).__init__()
141 |
142 | self.input_dim = input_size // 2
143 | self.rnn_units = hidden_size // 2
144 | self.real_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
145 | batch_first=False)
146 | self.imag_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
147 | batch_first=False)
148 | if bidirectional:
149 | bidirectional = 2
150 | else:
151 | bidirectional = 1
152 | if projection_dim is not None:
153 | self.projection_dim = projection_dim // 2
154 | self.r_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
155 | self.i_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
156 | else:
157 | self.projection_dim = None
158 |
159 | def forward(self, inputs):
160 | if isinstance(inputs, list):
161 | real, imag = inputs
162 | elif isinstance(inputs, torch.Tensor):
163 | real, imag = torch.chunk(inputs, 2, -1)
164 | r2r_out = self.real_lstm(real)[0]
165 | r2i_out = self.imag_lstm(real)[0]
166 | i2r_out = self.real_lstm(imag)[0]
167 | i2i_out = self.imag_lstm(imag)[0]
168 | real_out = r2r_out - i2i_out
169 | imag_out = i2r_out + r2i_out
170 | if self.projection_dim is not None:
171 | real_out = self.r_trans(real_out)
172 | imag_out = self.i_trans(imag_out)
173 | # print(real_out.shape,imag_out.shape)
174 | return [real_out, imag_out]
175 |
176 | def flatten_parameters(self):
177 | self.imag_lstm.flatten_parameters()
178 | self.real_lstm.flatten_parameters()
179 |
180 |
181 | def complex_cat(inputs, axis):
182 | real, imag = [], []
183 | for idx, data in enumerate(inputs):
184 | r, i = torch.chunk(data, 2, axis)  # torch.chunk(x, n, dim): split x into n chunks along dim
185 | real.append(r)
186 | imag.append(i)
187 | real = torch.cat(real, axis)  # concatenate the real parts along the complex axis
188 | imag = torch.cat(imag, axis)
189 | outputs = torch.cat([real, imag], axis)
190 | return outputs
191 |
192 |
193 | class ComplexConv2d(nn.Module):
194 |
195 | def __init__(
196 | self,
197 | in_channels,
198 | out_channels,
199 | kernel_size=(1, 1),
200 | stride=(1, 1),
201 | padding=(0, 0),
202 | dilation=1,
203 | groups=1,
204 | causal=True,
205 | complex_axis=1,
206 | ):
207 | '''
208 | in_channels: real+imag
209 | out_channels: real+imag
210 | kernel_size : input [B,C,D,T] kernel size in [D,T]
211 | padding : input [B,C,D,T] padding in [D,T]
212 | causal: if causal, will padding time dimension's left side,
213 | otherwise both
214 |
215 | '''
216 | super(ComplexConv2d, self).__init__()
217 | self.in_channels = in_channels // 2
218 | self.out_channels = out_channels // 2
219 | self.kernel_size = kernel_size
220 | self.stride = stride
221 | self.padding = padding
222 | self.causal = causal
223 | self.groups = groups
224 | self.dilation = dilation
225 | self.complex_axis = complex_axis
226 | self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
227 | padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
228 | self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
229 | padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
230 |
231 | nn.init.normal_(self.real_conv.weight.data, std=0.05)
232 | nn.init.normal_(self.imag_conv.weight.data, std=0.05)
233 | nn.init.constant_(self.real_conv.bias, 0.)
234 | nn.init.constant_(self.imag_conv.bias, 0.)
235 |
236 | def forward(self, inputs):
237 | if self.padding[1] != 0 and self.causal:
238 | inputs = F.pad(inputs, [self.padding[1], 0, 0, 0])
239 | else:
240 | inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0])
241 |
242 | if self.complex_axis == 0:
243 | real = self.real_conv(inputs)
244 | imag = self.imag_conv(inputs)
245 | real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
246 | real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
247 |
248 | else:
249 | if isinstance(inputs, torch.Tensor):
250 | real, imag = torch.chunk(inputs, 2, self.complex_axis)
251 |
252 | real2real = self.real_conv(real, )
253 | imag2imag = self.imag_conv(imag, )
254 |
255 | real2imag = self.imag_conv(real)
256 | imag2real = self.real_conv(imag)
257 |
258 | real = real2real - imag2imag
259 | imag = real2imag + imag2real
260 | out = torch.cat([real, imag], self.complex_axis)
261 |
262 | return out
263 |
264 |
265 | class ComplexConvTranspose2d(nn.Module):
266 |
267 | def __init__(
268 | self,
269 | in_channels,
270 | out_channels,
271 | kernel_size=(1, 1),
272 | stride=(1, 1),
273 | padding=(0, 0),
274 | output_padding=(0, 0),
275 | causal=False,
276 | complex_axis=1,
277 | groups=1
278 | ):
279 | '''
280 | in_channels: real+imag
281 | out_channels: real+imag
282 | '''
283 | super(ComplexConvTranspose2d, self).__init__()
284 | self.in_channels = in_channels // 2
285 | self.out_channels = out_channels // 2
286 | self.kernel_size = kernel_size
287 | self.stride = stride
288 | self.padding = padding
289 | self.output_padding = output_padding
290 | self.groups = groups
291 |
292 | self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
293 | padding=self.padding, output_padding=output_padding, groups=self.groups)
294 | self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
295 | padding=self.padding, output_padding=output_padding, groups=self.groups)
296 | self.complex_axis = complex_axis
297 |
298 | nn.init.normal_(self.real_conv.weight, std=0.05)
299 | nn.init.normal_(self.imag_conv.weight, std=0.05)
300 | nn.init.constant_(self.real_conv.bias, 0.)
301 | nn.init.constant_(self.imag_conv.bias, 0.)
302 |
303 | def forward(self, inputs):
304 |
305 | if isinstance(inputs, torch.Tensor):
306 | real, imag = torch.chunk(inputs, 2, self.complex_axis)
307 | elif isinstance(inputs, tuple) or isinstance(inputs, list):
308 | real = inputs[0]
309 | imag = inputs[1]
310 | if self.complex_axis == 0:
311 | real = self.real_conv(inputs)
312 | imag = self.imag_conv(inputs)
313 | real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
314 | real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
315 |
316 | else:
317 |             # `real` and `imag` were already split from `inputs`
318 |             # above, so no re-chunk is needed here.
319 |
320 |             real2real = self.real_conv(real)
321 |             imag2imag = self.imag_conv(imag)
322 |
323 | real2imag = self.imag_conv(real)
324 | imag2real = self.real_conv(imag)
325 |
326 | real = real2real - imag2imag
327 | imag = real2imag + imag2real
328 | out = torch.cat([real, imag], self.complex_axis)
329 |
330 | return out
331 |
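# A hedged aside: a decoder built from ComplexConvTranspose2d can undo an
# encoder stride on the frequency axis via the standard ConvTranspose2d
# size relation
#     D_out = (D_in - 1) * stride - 2 * padding + kernel + output_padding,
# e.g. D_in = 128, stride = 2, padding = 2, kernel = 5, output_padding = 1
# gives D_out = 127 * 2 - 4 + 5 + 1 = 256.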
332 |
333 | # Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch
334 | # from https://github.com/IMLHF/SE_DCUNet/blob/f28bf1661121c8901ad38149ea827693f1830715/models/layers/complexnn.py#L55
335 | class ComplexBatchNorm(torch.nn.Module):
336 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
337 | track_running_stats=True, complex_axis=1):
338 | super(ComplexBatchNorm, self).__init__()
339 | self.num_features = num_features // 2
340 | self.eps = eps
341 | self.momentum = momentum
342 | self.affine = affine
343 | self.track_running_stats = track_running_stats
344 |
345 | self.complex_axis = complex_axis
346 |
347 | if self.affine:
348 | self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features))
349 | self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features))
350 | self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features))
351 | self.Br = torch.nn.Parameter(torch.Tensor(self.num_features))
352 | self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features))
353 | else:
354 | self.register_parameter('Wrr', None)
355 | self.register_parameter('Wri', None)
356 | self.register_parameter('Wii', None)
357 | self.register_parameter('Br', None)
358 | self.register_parameter('Bi', None)
359 |
360 | if self.track_running_stats:
361 | self.register_buffer('RMr', torch.zeros(self.num_features))
362 | self.register_buffer('RMi', torch.zeros(self.num_features))
363 | self.register_buffer('RVrr', torch.ones(self.num_features))
364 | self.register_buffer('RVri', torch.zeros(self.num_features))
365 | self.register_buffer('RVii', torch.ones(self.num_features))
366 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
367 | else:
368 | self.register_parameter('RMr', None)
369 | self.register_parameter('RMi', None)
370 | self.register_parameter('RVrr', None)
371 | self.register_parameter('RVri', None)
372 | self.register_parameter('RVii', None)
373 | self.register_parameter('num_batches_tracked', None)
374 | self.reset_parameters()
375 |
376 | def reset_running_stats(self):
377 | if self.track_running_stats:
378 | self.RMr.zero_()
379 | self.RMi.zero_()
380 | self.RVrr.fill_(1)
381 | self.RVri.zero_()
382 | self.RVii.fill_(1)
383 | self.num_batches_tracked.zero_()
384 |
385 | def reset_parameters(self):
386 | self.reset_running_stats()
387 | if self.affine:
388 | self.Br.data.zero_()
389 | self.Bi.data.zero_()
390 | self.Wrr.data.fill_(1)
391 | self.Wri.data.uniform_(-.9, +.9) # W will be positive-definite
392 | self.Wii.data.fill_(1)
393 |
394 | def _check_input_dim(self, xr, xi):
395 | assert (xr.shape == xi.shape)
396 | assert (xr.size(1) == self.num_features)
397 |
398 | def forward(self, inputs):
399 | # self._check_input_dim(xr, xi)
400 |
401 |         xr, xi = torch.chunk(inputs, 2, dim=self.complex_axis)
402 | exponential_average_factor = 0.0
403 |
404 | if self.training and self.track_running_stats:
405 | self.num_batches_tracked += 1
406 | if self.momentum is None: # use cumulative moving average
407 | exponential_average_factor = 1.0 / self.num_batches_tracked.item()
408 | else: # use exponential moving average
409 | exponential_average_factor = self.momentum
410 |
411 | #
412 | # NOTE: The precise meaning of the "training flag" is:
413 | # True: Normalize using batch statistics, update running statistics
414 | # if they are being collected.
415 | # False: Normalize using running statistics, ignore batch statistics.
416 | #
417 | training = self.training or not self.track_running_stats
418 | redux = [i for i in reversed(range(xr.dim())) if i != 1]
419 | vdim = [1] * xr.dim()
420 | vdim[1] = xr.size(1)
421 |
422 | #
423 | # Mean M Computation and Centering
424 | #
425 | # Includes running mean update if training and running.
426 | #
427 | if training:
428 | Mr, Mi = xr, xi
429 | for d in redux:
430 | Mr = Mr.mean(d, keepdim=True)
431 | Mi = Mi.mean(d, keepdim=True)
432 | if self.track_running_stats:
433 | self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
434 | self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
435 | else:
436 | Mr = self.RMr.view(vdim)
437 | Mi = self.RMi.view(vdim)
438 | xr, xi = xr - Mr, xi - Mi
439 |
440 | #
441 | # Variance Matrix V Computation
442 | #
443 | # Includes epsilon numerical stabilizer/Tikhonov regularizer.
444 | # Includes running variance update if training and running.
445 | #
446 | if training:
447 | Vrr = xr * xr
448 | Vri = xr * xi
449 | Vii = xi * xi
450 | for d in redux:
451 | Vrr = Vrr.mean(d, keepdim=True)
452 | Vri = Vri.mean(d, keepdim=True)
453 | Vii = Vii.mean(d, keepdim=True)
454 | if self.track_running_stats:
455 | self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
456 | self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
457 | self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
458 | else:
459 | Vrr = self.RVrr.view(vdim)
460 | Vri = self.RVri.view(vdim)
461 | Vii = self.RVii.view(vdim)
462 | Vrr = Vrr + self.eps
463 |         # the cross term Vri needs no epsilon regularizer
464 | Vii = Vii + self.eps
465 |
466 | #
467 | # Matrix Inverse Square Root U = V^-0.5
468 | #
469 | # sqrt of a 2x2 matrix,
470 | # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
471 | tau = Vrr + Vii
472 |         delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)  # det(V) = Vrr*Vii - Vri**2
473 | s = delta.sqrt()
474 | t = (tau + 2 * s).sqrt()
475 |
476 | # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
477 | rst = (s * t).reciprocal()
478 | Urr = (s + Vii) * rst
479 | Uii = (s + Vrr) * rst
480 | Uri = (- Vri) * rst
481 |
482 | #
483 | # Optionally left-multiply U by affine weights W to produce combined
484 | # weights Z, left-multiply the inputs by Z, then optionally bias them.
485 | #
486 | # y = Zx + B
487 | # y = WUx + B
488 | # y = [Wrr Wri][Urr Uri] [xr] + [Br]
489 | # [Wir Wii][Uir Uii] [xi] [Bi]
490 | #
491 | if self.affine:
492 | Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
493 | Zrr = (Wrr * Urr) + (Wri * Uri)
494 | Zri = (Wrr * Uri) + (Wri * Uii)
495 | Zir = (Wri * Urr) + (Wii * Uri)
496 | Zii = (Wri * Uri) + (Wii * Uii)
497 | else:
498 | Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
499 |
500 | yr = (Zrr * xr) + (Zri * xi)
501 | yi = (Zir * xr) + (Zii * xi)
502 |
503 | if self.affine:
504 | yr = yr + self.Br.view(vdim)
505 | yi = yi + self.Bi.view(vdim)
506 |
507 | outputs = torch.cat([yr, yi], self.complex_axis)
508 | return outputs
509 |
510 | def extra_repr(self):
511 | return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
512 | 'track_running_stats={track_running_stats}'.format(**self.__dict__)
513 |
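# A hedged aside: Urr/Uri/Uii in forward() are the closed-form inverse
# square root of the 2x2 covariance V = [[Vrr, Vri], [Vri, Vii]], using
# sqrt(V) = (V + s*I) / t with s = sqrt(det V) and t = sqrt(tr V + 2s).
# A numpy sanity check of that identity:
#
#     import numpy as np
#     V = np.array([[2.0, 0.5], [0.5, 1.0]])
#     s = np.sqrt(np.linalg.det(V))
#     t = np.sqrt(np.trace(V) + 2 * s)
#     U = np.array([[s + V[1, 1], -V[0, 1]],
#                   [-V[0, 1], s + V[0, 0]]]) / (s * t)
#     assert np.allclose(U @ V @ U, np.eye(2))   # U = V^(-1/2)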
514 |
515 | def complex_cat(inputs, axis):
516 | real, imag = [], []
517 | for idx, data in enumerate(inputs):
518 | r, i = torch.chunk(data, 2, axis)
519 | real.append(r)
520 | imag.append(i)
521 | real = torch.cat(real, axis)
522 | imag = torch.cat(imag, axis)
523 | outputs = torch.cat([real, imag], axis)
524 | return outputs
525 |
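# A hedged aside: complex_cat concatenates complex tensors while keeping
# all real parts ahead of all imaginary parts on `axis`, which a plain
# torch.cat would not do. For two [B, 2, D, T] tensors a and b:
#
#     out = complex_cat([a, b], axis=1)
#     # channel order: [a_real, b_real, a_imag, b_imag]
#     # torch.cat([a, b], 1) would give [a_real, a_imag, b_real, b_imag]
#
# presumably so that encoder/decoder skip connections remain valid inputs
# for the complex layers above.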
526 |
527 | ############################################################################
528 | # for data normalization #
529 | ############################################################################
530 | # get mu and sig
531 | def get_mu_sig(data):
532 | """Compute mean and standard deviation vector of input data
533 |
534 | Returns:
535 | mu: mean vector (#dim by one)
536 | sig: standard deviation vector (#dim by one)
537 | """
538 | # Initialize array.
539 | data_num = len(data)
540 | mu_utt = []
541 | tmp_utt = []
542 | for n in range(data_num):
543 | dim = len(data[n])
544 | mu_utt_tmp = np.zeros(dim)
545 | mu_utt.append(mu_utt_tmp)
546 |
547 | tmp_utt_tmp = np.zeros(dim)
548 | tmp_utt.append(tmp_utt_tmp)
549 |
550 |
551 | # Get mean.
552 | for n in range(data_num):
553 | mu_utt[n] = np.mean(data[n], 0)
554 | mu = mu_utt
555 |
556 | # Get standard deviation.
557 | for n in range(data_num):
558 | tmp_utt[n] = np.mean(np.square(data[n] - mu[n]), 0)
559 | sig = np.sqrt(tmp_utt)
560 |
561 | # Assign unit variance.
562 | for n in range(len(sig)):
563 | if sig[n] < 1e-5:
564 | sig[n] = 1.0
565 | return np.float16(mu), np.float16(sig)
566 |
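# A hedged usage sketch: the statistics above support per-utterance
# z-scoring of the network input,
#
#     mu, sig = get_mu_sig(train_inputs)            # one vector per utterance
#     x_norm = (train_inputs[n] - mu[n]) / sig[n]   # zero mean, unit variance
#
# Components with sig < 1e-5 were clamped to 1.0 above, so near-constant
# dimensions pass through unscaled instead of blowing up.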
567 |
568 | def get_statistics_inp(inp):
569 | """Get statistical parameter of input data.
570 |
571 | Args:
572 | inp: input data
573 |
574 | Returns:
575 | mu_inp: mean vector of input data
576 | sig_inp: standard deviation vector of input data
577 | """
578 |
579 | mu_inp, sig_inp = get_mu_sig(inp)
580 |
581 | return mu_inp, sig_inp
582 |
583 |
584 | ############################################################################
585 | # for scores #
586 | ############################################################################
587 | def cal_pesq(dirty_wavs, clean_wavs):
588 | pesq_scores = []
589 | for i in range(len(dirty_wavs)):
590 | pesq_score = pesq(cfg.FS, clean_wavs[i], dirty_wavs[i], "wb")
591 | pesq_scores.append(pesq_score)
592 | return pesq_scores
593 |
594 |
595 | def cal_stoi(dirty_wavs, clean_wavs):
596 | stoi_scores = []
597 | for i in range(len(dirty_wavs)):
598 | stoi_score = stoi(clean_wavs[i], dirty_wavs[i], cfg.FS, extended=False)
599 | stoi_scores.append(stoi_score)
600 | return stoi_scores
601 |
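# A hedged aside: `pesq` and `stoi` here are presumably the PyPI packages
# `pesq` (ITU-T P.862) and `pystoi`. Both compare a degraded signal
# against its clean reference, e.g.:
#
#     from pesq import pesq
#     from pystoi import stoi
#     p = pesq(16000, clean, enhanced, 'wb')   # 'wb' requires fs = 16 kHz
#     s = stoi(clean, enhanced, 16000)         # ranges roughly 0 .. 1
#
# Higher is better for both metrics.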
602 |
603 | ############################################################################
604 | # for plotting the samples #
605 | ############################################################################
606 | def hann_window(win_samp):
607 | tmp = np.arange(1, win_samp + 1, 1.0, dtype=np.float64)
608 | window = 0.5 - 0.5 * np.cos((2.0 * np.pi * tmp) / (win_samp + 1))
609 | return np.float32(window)
610 |
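# A hedged aside: this is the symmetric Hann window without its zero
# endpoints (MATLAB's hanning(N)); equivalently
#
#     np.allclose(hann_window(N), np.hanning(N + 2)[1:-1])   # True
#
# since np.hanning(M)[n] = 0.5 - 0.5 * cos(2*pi*n / (M - 1)).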
611 |
612 | def fig2np(fig):
613 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
614 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
615 | return data
616 |
617 |
618 | def plot_spectrogram_to_numpy(input_wav, fs, n_fft, n_overlap, win, mode, clim, label):
619 | # cuda to cpu
620 | input_wav = input_wav.cpu().detach().numpy()
621 |
622 | fig, ax = plt.subplots(figsize=(12, 3))
623 |
624 | if mode == 'phase':
625 | pxx, freq, t, cax = plt.specgram(input_wav, NFFT=int(n_fft), Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet',
626 | mode=mode)
627 | else:
628 | pxx, freq, t, cax = plt.specgram(input_wav, NFFT=int(n_fft), Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet')
629 |
630 | plt.xlabel('Time (s)')
631 | plt.ylabel('Frequency (Hz)')
632 | plt.tight_layout()
633 | plt.clim(clim)
634 |
635 | if label is None:
636 | fig.colorbar(cax)
637 | else:
638 | fig.colorbar(cax, label=label)
639 |
640 | fig.canvas.draw()
641 | data = fig2np(fig)
642 | plt.close()
643 | return data
644 |
645 |
646 | def plot_mask_to_numpy(mask, fs, n_fft, n_overlap, win, clim1, clim2, cmap):
647 | frame_num = mask.shape[0]
648 | shift_length = n_overlap
649 | frame_length = n_fft
650 | signal_length = frame_num * shift_length + frame_length
651 |
652 | xt = np.arange(0, np.floor(10 * signal_length / fs) / 10, step=0.5) / (signal_length / fs) * frame_num + 1e-8
653 | yt = (n_fft / 2) / (fs / 1000 / 2) * np.arange(0, (fs / 1000 / 2) + 1)
654 |
655 | fig, ax = plt.subplots(figsize=(12, 3))
656 | im = ax.imshow(np.transpose(mask), aspect='auto', origin='lower', interpolation='none', cmap=cmap)
657 |
658 | plt.xlabel('Time (s)')
659 | plt.ylabel('Frequency (kHz)')
660 | plt.xticks(xt, np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5))
661 | plt.yticks(yt, np.int16(np.linspace(0, int((fs / 1000) / 2), len(yt))))
662 | plt.tight_layout()
663 | plt.colorbar(im, ax=ax)
664 | im.set_clim(clim1, clim2)
665 |
666 | fig.canvas.draw()
667 | data = fig2np(fig)
668 | plt.close()
669 | return data
670 |
671 |
672 | def plot_error_to_numpy(estimated, target, fs, n_fft, n_overlap, win, mode, clim1, clim2, label):
673 | fig, ax = plt.subplots(figsize=(12, 3))
674 |     if mode is None:
675 | pxx1, freq, t, cax = plt.specgram(estimated, NFFT=n_fft, Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet')
676 | pxx2, freq, t, cax = plt.specgram(target, NFFT=n_fft, Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet')
677 | im = ax.imshow(10 * np.log10(pxx1) - 10 * np.log10(pxx2), aspect='auto', origin='lower', interpolation='none',
678 | cmap='jet')
679 | else:
680 | pxx1, freq, t, cax = plt.specgram(estimated, NFFT=n_fft, Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet',
681 | mode=mode)
682 | pxx2, freq, t, cax = plt.specgram(target, NFFT=n_fft, Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet',
683 | mode=mode)
684 | im = ax.imshow(pxx1 - pxx2, aspect='auto', origin='lower', interpolation='none', cmap='jet')
685 |
686 | frame_num = pxx1.shape[1]
687 | shift_length = n_overlap
688 | frame_length = n_fft
689 | signal_length = frame_num * shift_length + frame_length
690 |
691 | xt = np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5) / (signal_length / fs) * frame_num
692 | yt = (n_fft / 2) / (fs / 1000 / 2) * np.arange(0, (fs / 1000 / 2) + 1)
693 |
694 | plt.xlabel('Time (s)')
695 | plt.ylabel('Frequency (kHz)')
696 | plt.xticks(xt, np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5))
697 | plt.yticks(yt, np.int16(np.linspace(0, int((fs / 1000) / 2), len(yt))))
698 | plt.tight_layout()
699 | plt.colorbar(im, ax=ax, label=label)
700 | im.set_clim(clim1, clim2)
701 |
702 | fig.canvas.draw()
703 | data = fig2np(fig)
704 | plt.close()
705 | return data
706 |
707 |
708 | ############################################################################
709 | # for run.py #
710 | ############################################################################
711 | def near_avg_index(array):
712 | array_mean = np.mean(array)
713 |
714 | distance_arr = []
715 | for i in range(len(array)):
716 | val = array[i]
717 | distance = abs(array_mean - val)
718 | distance_arr.append(distance)
719 |
720 | index = distance_arr.index(min(distance_arr))
721 | return index
722 |
723 |
724 | def max_index(array):
725 | array_max = np.max(array)
726 |
727 | for i in range(len(array)):
728 | val = array[i]
729 | if val == array_max:
730 | index = i
731 | return index
732 |
733 |
734 | def min_index(array):
735 | array_min = np.min(array)
736 |
737 | for i in range(len(array)):
738 | val = array[i]
739 | if val == array_min:
740 | index = i
741 | return index
742 |
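# A hedged aside: for a 1-D array a, minimal numpy equivalents of the
# three helpers above, kept here only as a cross-check:
#
#     near_avg_index(a) == int(np.argmin(np.abs(np.asarray(a) - np.mean(a))))
#     max_index(a)      == int(np.argmax(a))
#     min_index(a)      == int(np.argmin(a))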
743 |
744 | class Bar(object):
745 | def __init__(self, dataloader):
746 | if not hasattr(dataloader, 'dataset'):
747 |             raise ValueError('Attribute `dataset` does not exist in the dataloader.')
748 |         if not hasattr(dataloader, 'batch_size'):
749 |             raise ValueError('Attribute `batch_size` does not exist in the dataloader.')
750 |
751 | self.dataloader = dataloader
752 | self.iterator = iter(dataloader)
753 | self.dataset = dataloader.dataset
754 | self.batch_size = dataloader.batch_size
755 | self._idx = 0
756 | self._batch_idx = 0
757 | self._time = []
758 | self._DISPLAY_LENGTH = 50
759 |
760 | def __len__(self):
761 | return len(self.dataloader)
762 |
763 | def __iter__(self):
764 | return self
765 |
766 | def __next__(self):
767 | if len(self._time) < 2:
768 | self._time.append(time.time())
769 |
770 | self._batch_idx += self.batch_size
771 | if self._batch_idx > len(self.dataset):
772 | self._batch_idx = len(self.dataset)
773 |
774 | try:
775 | batch = next(self.iterator)
776 | self._display()
777 |         except StopIteration:
778 |             raise
779 |
780 | self._idx += 1
781 | if self._idx >= len(self.dataloader):
782 | self._reset()
783 |
784 | return batch
785 |
786 | def _display(self):
787 | if len(self._time) > 1:
788 | t = (self._time[-1] - self._time[-2])
789 | eta = t * (len(self.dataloader) - self._idx)
790 | else:
791 | eta = 0
792 |
793 | rate = self._idx / len(self.dataloader)
794 | len_bar = int(rate * self._DISPLAY_LENGTH)
795 | bar = ('=' * len_bar + '>').ljust(self._DISPLAY_LENGTH, '.')
796 | idx = str(self._batch_idx).rjust(len(str(len(self.dataset))), ' ')
797 |
798 | tmpl = '\r{}/{}: [{}] - ETA {:.1f}s'.format(
799 | idx,
800 | len(self.dataset),
801 | bar,
802 | eta
803 | )
804 | print(tmpl, end='')
805 | if self._batch_idx == len(self.dataset):
806 | print()
807 |
808 | def _reset(self):
809 | self._idx = 0
810 | self._batch_idx = 0
811 | self._time = []
812 |
813 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Run the trainer and tester
3 | """
4 | import torch
5 | import numpy as np
6 | from scipy.io.wavfile import write as wav_write
7 | from tools_for_model import near_avg_index, max_index, min_index, Bar, cal_pesq, cal_stoi
8 | from config import fs, info, mode
9 |
10 |
11 | def model_train(model, optimizer, train_loader, epoch, DEVICE):
12 | # initialization
13 | train_loss = 0
14 | batch_num = 0
15 |
16 | # train
17 | model.train()
18 | for inputs, labels in Bar(train_loader):
19 | batch_num += 1
20 |
21 | # to cuda
22 | inputs = inputs.float().to(DEVICE)
23 | labels = labels.float().to(DEVICE)
24 |
25 | _, _, real_spec, img_spec, outputs = model(inputs)
26 | loss = model.loss(outputs, labels, real_spec, img_spec)
27 | # loss = model.pmsqe_loss(labels, outputs)
28 |
29 | optimizer.zero_grad()
30 | loss.backward()
31 | optimizer.step()
32 |
33 |         train_loss += loss.item()  # .item() keeps the graph from being retained
34 |     train_loss /= batch_num
35 |
36 | return train_loss
37 |
38 |
39 | def model_validate(model, validation_loader, dir_to_save, writer, epoch, DEVICE):
40 | # initialization
41 | batch_num = 0
42 | validation_loss = 0
43 | avg_pesq = 0
44 | avg_stoi = 0
45 |
46 | all_batch_input = []
47 | all_batch_label = []
48 | all_batch_output = []
49 | all_batch_real_spec = []
50 | all_batch_img_spec = []
51 | all_batch_pesq = []
52 |
53 | f_pesq = open(dir_to_save + '/pesq_epoch_' + '%d' % epoch, 'a')
54 | f_stoi = open(dir_to_save + '/stoi_epoch_' + '%d' % epoch, 'a')
55 |
56 | model.eval()
57 |
58 | with torch.no_grad():
59 | for inputs, labels in Bar(validation_loader):
60 | batch_num += 1
61 |
62 | # to cuda
63 | inputs = inputs.float().to(DEVICE)
64 | labels = labels.float().to(DEVICE)
65 |
66 | mask_real, mask_imag, real_spec, img_spec, outputs = model(inputs)
67 | loss = model.loss(outputs, labels, real_spec, img_spec)
68 |
69 | # loss = model.pmsqe_loss(labels, outputs)
70 |
71 | # estimate the output speech with pesq and stoi
72 | # save pesq & stoi score at each epoch
73 | estimated_wavs = outputs.cpu().detach().numpy()
74 | clean_wavs = labels.cpu().detach().numpy()
75 |
76 |             pesq = cal_pesq(estimated_wavs, clean_wavs)
77 | stoi = cal_stoi(estimated_wavs, clean_wavs)
78 |
79 |             # log per-utterance scores (PESQ resolves ~0.1, STOI ~0.01)
80 | for i in range(len(pesq)):
81 | f_pesq.write('{:.6f}\n'.format(pesq[i]))
82 | f_stoi.write('{:.4f}\n'.format(stoi[i]))
83 |
84 | # reshape for sum
85 | pesq = np.reshape(pesq, (1, -1))
86 | stoi = np.reshape(stoi, (1, -1))
87 |
88 | avg_pesq += sum(pesq[0]) / len(inputs)
89 | avg_stoi += sum(stoi[0]) / len(inputs)
90 |
91 | if epoch % 10 == 0:
92 | # all batch data array
93 | all_batch_input.extend(inputs)
94 | all_batch_label.extend(labels)
95 | all_batch_output.extend(outputs)
96 | all_batch_real_spec.extend(mask_real)
97 | all_batch_img_spec.extend(mask_imag)
98 | all_batch_pesq.extend(pesq[0])
99 |
100 |             validation_loss += loss.item()
101 |
102 | # save the samples to tensorboard
103 | if epoch % 10 == 0:
104 | all_batch_pesq = np.reshape(all_batch_pesq, (-1, 1))
105 |
106 | # find the best & worst pesq model
107 | max_pesq_index = max_index(all_batch_pesq)
108 | min_pesq_index = min_index(all_batch_pesq)
109 |
110 | # find the avg pesq model
111 | avg_pesq_index = near_avg_index(all_batch_pesq)
112 |
113 | # save the samples to tensorboard
114 | # the best pesq
115 | writer.save_samples_we_want('max_pesq', all_batch_input[max_pesq_index], all_batch_label[max_pesq_index],
116 | all_batch_output[max_pesq_index], epoch)
117 | # the worst pesq
118 | writer.save_samples_we_want('min_pesq', all_batch_input[min_pesq_index], all_batch_label[min_pesq_index],
119 | all_batch_output[min_pesq_index], epoch)
120 | # the avg pesq
121 | writer.save_samples_we_want('avg_pesq', all_batch_input[avg_pesq_index], all_batch_label[avg_pesq_index],
122 | all_batch_output[avg_pesq_index], epoch)
123 |
124 | # save the same sample
125 | clip_num = 10
126 | writer.save_samples_we_want('n{}_sample'.format(clip_num), all_batch_input[clip_num], all_batch_label[clip_num],
127 | all_batch_output[clip_num], epoch)
128 |
129 | validation_loss /= batch_num
130 | avg_pesq /= batch_num
131 | avg_stoi /= batch_num
132 |
133 | # save average score
134 | f_pesq.write('Avg: {:.6f}\n'.format(avg_pesq))
135 | f_stoi.write('Avg: {:.4f}\n'.format(avg_stoi))
136 |
137 | f_pesq.close()
138 | f_stoi.close()
139 | return validation_loss, avg_pesq, avg_stoi
140 |
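# A hedged note on the score averaging in model_validate above (and in
# model_test below): avg_pesq and avg_stoi are means of per-batch means
# (sum(scores) / len(inputs), then / batch_num). This equals the true
# per-utterance mean only when every batch has the same size; a smaller
# final batch ends up weighted slightly more.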
141 |
142 | def model_test(noise_type, snr, model, test_loader, dir_to_save, DEVICE):
143 | model.eval()
144 | with torch.no_grad():
145 | # initialization
146 | batch_num = 0
147 | test_loss = 0
148 | avg_pesq = 0
149 | avg_stoi = 0
150 |
151 | all_batch_input = []
152 | all_batch_label = []
153 | all_batch_output = []
154 | all_batch_real_spec = []
155 | all_batch_img_spec = []
156 | all_batch_pesq = []
157 |
158 | # f_pesq = open(dir_to_save + '/test_pesq_epoch{}_{}_{}dB'
159 | # .format(min_index + 1, noise_type, snr), 'a')
160 | # f_stoi = open(dir_to_save + '/test_stoi_epoch{}_{}_{}dB'
161 | # .format(min_index + 1, noise_type, snr), 'a')
162 | for inputs, labels in Bar(test_loader):
163 | batch_num += 1
164 |
165 | # to cuda
166 | inputs = inputs.float().to(DEVICE)
167 | labels = labels.float().to(DEVICE)
168 |
169 | mask_real, mask_imag, real_spec, img_spec, outputs = model(inputs)
170 | loss = model.loss(outputs, labels, real_spec, img_spec)
171 | # loss = model.pmsqe_loss(labels, outputs)
172 | # estimate the output speech with pesq and stoi
173 | # save pesq & stoi score at each epoch
174 | # [18480, 1]
175 | estimated_wavs = outputs.cpu().detach().numpy()
176 | clean_wavs = labels.cpu().detach().numpy()
177 |
178 | pesq = cal_pesq(estimated_wavs, clean_wavs)
179 | stoi = cal_stoi(estimated_wavs, clean_wavs)
180 |
181 | # # pesq: 0.1 better / stoi: 0.01 better
182 | # for i in range(len(pesq)):
183 | # f_pesq.write('{:.6f}\n'.format(pesq[i]))
184 | # f_stoi.write('{:.4f}\n'.format(stoi[i]))
185 |
186 |             test_loss += loss.item()
187 |
188 | # reshape for sum
189 | pesq = np.reshape(pesq, (1, -1))
190 | stoi = np.reshape(stoi, (1, -1))
191 |
192 | avg_pesq += sum(pesq[0]) / len(inputs)
193 | avg_stoi += sum(stoi[0]) / len(inputs)
194 |
195 | # all batch data array
196 | all_batch_input.extend(inputs)
197 | all_batch_label.extend(labels)
198 | all_batch_output.extend(outputs)
199 | all_batch_real_spec.extend(mask_real)
200 | all_batch_img_spec.extend(mask_imag)
201 | all_batch_pesq.extend(pesq[0])
202 |
203 | # find the best & worst pesq model
204 | max_pesq_index = all_batch_pesq.index(max(all_batch_pesq))
205 | min_pesq_index = all_batch_pesq.index(min(all_batch_pesq))
206 |
207 | test_loss /= batch_num
208 | avg_pesq /= batch_num
209 | avg_stoi /= batch_num
210 |
211 | max_pesq = all_batch_pesq[max_pesq_index]
212 | min_pesq = all_batch_pesq[min_pesq_index]
213 |
214 | # save average score
215 | # f_pesq.write('Max: {:.6f} | Min: {:.6f} | Avg: {:.6f}\n'.format(max_pesq, min_pesq, avg_pesq))
216 | # f_stoi.write('Avg: {:.4f}\n'.format(avg_stoi))
217 | # f_pesq.close()
218 | # f_stoi.close()
219 | return test_loss, avg_pesq, avg_stoi
220 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Interface for training
3 | """
4 |
5 | import os
6 | import time
7 | import torch
8 | import shutil
9 | import numpy as np
10 | import config as cfg
11 | from train import model_train, model_validate, model_test
12 | from dataloader import create_dataloader, create_dataloader_for_test
13 | from model import DCCRN, DCUNET, DCCRN_direct, DCCRN_no_skip
14 | from write_on_tensorboard import Writer
15 |
16 |
17 | ###############################################################################
18 | # Helper function definition #
19 | ###############################################################################
20 | # Write training related parameters into the log file.
21 | def write_status_to_log_file(fp, total_parameters):
22 |     # Log timestamp, mode, learning rate, and model size.
23 | fp.write('%d-%d-%d %d:%d:%d\n' %
24 | (time.localtime().tm_year, time.localtime().tm_mon,
25 | time.localtime().tm_mday, time.localtime().tm_hour,
26 | time.localtime().tm_min, time.localtime().tm_sec))
27 | fp.write('mode : %s_%s\n' % (cfg.mode, cfg.info))
28 | fp.write('learning rate : %g\n' % cfg.learning_rate)
29 | fp.write('total params : %d (%.2f M, %.2f MBytes)\n' %
30 | (total_parameters,
31 | total_parameters / 1000000.0,
32 | total_parameters * 4.0 / 1000000.0))
33 |
34 |
35 | # Calculate the size of total network.
36 | def calculate_total_params(our_model):
37 | total_parameters = 0
38 | for variable in our_model.parameters():
39 | shape = variable.size()
40 | variable_parameters = 1
41 | for dim in shape:
42 | variable_parameters *= dim
43 | total_parameters += variable_parameters
44 |
45 | return total_parameters
46 |
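# A hedged aside: equivalent to the one-liner
#
#     total_parameters = sum(p.numel() for p in our_model.parameters())
#
# The "MBytes" figures logged and printed with this count assume 4 bytes
# per parameter, i.e. float32 weights.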
47 |
48 | ###############################################################################
49 | # Parameter Initialization #
50 | ###############################################################################
51 | print('***********************************************************')
52 | print('* Python library for DNN-based speech enhancement *')
53 | print('* using Pytorch API *')
54 | print('***********************************************************')
55 |
56 | # Set device
57 | DEVICE = torch.device("cuda")
58 |
59 | # Set model
60 | if cfg.mode == 'DCCRN':
61 | model = DCCRN(rnn_units=cfg.rnn_units, masking_mode=cfg.masking_mode, use_clstm=cfg.use_clstm,
62 | kernel_num=cfg.kernel_num).to(DEVICE)
63 | elif cfg.mode == 'DCUNET':
64 | model = DCUNET(masking_mode=cfg.masking_mode, kernel_num=cfg.kernel_num).to(DEVICE)
65 | elif cfg.mode == 'DCCRN_direct':
66 | model = DCCRN_direct(rnn_units=cfg.rnn_units, use_clstm=cfg.use_clstm, kernel_num=cfg.kernel_num).to(DEVICE)
67 |
68 | ###############################################################################
69 | # Set optimizer and learning rate #
70 | ###############################################################################
71 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
72 | total_params = calculate_total_params(model)
73 |
74 | ###############################################################################
75 | # Confirm model information #
76 | ###############################################################################
77 | print('%d-%d-%d %d:%d:%d\n' %
78 | (time.localtime().tm_year, time.localtime().tm_mon,
79 | time.localtime().tm_mday, time.localtime().tm_hour,
80 | time.localtime().tm_min, time.localtime().tm_sec))
81 | print('mode : %s_%s\n' % (cfg.mode, cfg.info))
82 | print('learning rate : %g\n' % cfg.learning_rate)
83 | print('total params : %d (%.2f M, %.2f MBytes)\n' %
84 | (total_params,
85 | total_params / 1000000.0,
86 | total_params * 4.0 / 1000000.0))
87 |
88 | ###############################################################################
89 | # Create Dataloader #
90 | ###############################################################################
93 |
94 | train_loader = create_dataloader(mode='train')
95 | validation_loader = create_dataloader(mode='valid')
96 |
97 | ###############################################################################
98 | # Set a log file to store progress. #
99 | # Set a hps file to store hyper-parameters information. #
100 | ###############################################################################
101 | # Load the checkpoint
102 | if cfg.chkpt_path is not None:
103 | print('Resuming from checkpoint: %s' % cfg.chkpt_path)
104 |
105 | # Set a log file to store progress.
106 | dir_to_save = cfg.job_dir + cfg.chkpt_model
107 | dir_to_logs = cfg.logs_dir + cfg.chkpt_model
108 |
109 | checkpoint = torch.load(cfg.chkpt_path)
110 | model.load_state_dict(checkpoint['model'])
111 | optimizer.load_state_dict(checkpoint['optimizer'])
112 | epoch_start_idx = checkpoint['epoch'] + 1
113 | mse_vali_total = np.load(str(dir_to_save + '/mse_vali_total.npy'))
114 | if len(mse_vali_total) < cfg.max_epochs:
115 | plus = cfg.max_epochs - len(mse_vali_total)
116 | mse_vali_total = np.concatenate((mse_vali_total, np.zeros(plus)), 0)
117 | else:
118 | print('Starting new training run')
119 | epoch_start_idx = 1
120 | mse_vali_total = np.zeros(cfg.max_epochs)
121 |
122 | # Set a log file to store progress.
123 | dir_to_save = str(cfg.job_dir) + '%d.%d' % (time.localtime().tm_mon, time.localtime().tm_mday) \
124 | + '_%s' % cfg.mode + '_%s' % cfg.info
125 | dir_to_logs = str(cfg.logs_dir) + '%d.%d' % (time.localtime().tm_mon, time.localtime().tm_mday) \
126 | + '_%s' % cfg.mode + '_%s' % cfg.info
127 |
128 | if not os.path.exists(dir_to_save):
129 | os.mkdir(dir_to_save)
130 | os.mkdir(dir_to_logs)
131 |
132 | log_fname = str(dir_to_save + '/log.txt')
133 | if not os.path.exists(log_fname):
134 | fp = open(log_fname, 'w')
135 | write_status_to_log_file(fp, total_params)
136 | else:
137 | fp = open(log_fname, 'a')
138 |
139 | # Set a hps file to store hyper-parameters information.
140 | hps_fname = str(dir_to_save + '/hp_str.txt')
141 | fp_h = open(hps_fname, 'w')
142 |
143 | with open('config.py', 'r') as f:
144 | hp_str = ''.join(f.readlines())
145 | fp_h.write(hp_str)
146 | fp_h.close()
147 |
148 | ###############################################################################
149 | ###############################################################################
150 | # Main program start !! #
151 | ###############################################################################
152 | ###############################################################################
153 |
154 | # Writer initialize
155 | writer = Writer(dir_to_logs)
156 |
157 | ###############################################################################
158 | # Train #
159 | ###############################################################################
160 | for epoch in range(epoch_start_idx, cfg.max_epochs + 1):
161 | start_time = time.time()
162 | train_loss = model_train(model, optimizer, train_loader, epoch, DEVICE)
163 | vali_loss, vali_pesq, vali_stoi = model_validate(model, validation_loader,
164 | dir_to_save, writer, epoch, DEVICE)
165 |
166 | mse_vali_total[epoch - 1] = vali_loss
167 | np.save(str(dir_to_save + '/mse_vali_total.npy'), mse_vali_total)
168 |
169 | # write the loss on tensorboard
170 | writer.log_loss(train_loss, vali_loss, epoch)
171 |
172 | # save checkpoint file to resume training
173 | save_path = str(dir_to_save + '/' + ('chkpt_%d.pt' % epoch))
174 | torch.save({
175 | 'model': model.state_dict(),
176 | 'optimizer': optimizer.state_dict(),
177 | 'epoch': epoch
178 | }, save_path)
179 |
180 |     print('Epoch [{}] | {:.6f} | {:.6f} | {:.6f} | {:.6f} takes {:.2f} seconds'
181 |           .format(epoch, train_loss, vali_loss, vali_pesq, vali_stoi, time.time() - start_time))
182 | fp.write('Epoch [{}] | {:.6f} | {:.6f} | {:.6f} | {:.6f} takes {:.2f} seconds\n'
183 | .format(epoch, train_loss, vali_loss, vali_pesq, vali_stoi, time.time() - start_time))
184 |
185 | print('Training has been finished.')
186 |
187 | # Copy the optimum model that has the minimum validation MSE.
188 | print('Save optimum models...')
189 | min_index = np.argmin(mse_vali_total)
190 | print('Minimum validation loss is at epoch ' + str(min_index + 1) + '.')
191 | # NOTE: the destination filename below is illustrative; the original
192 | # script stops after locating the optimum epoch.
193 | shutil.copy(str(dir_to_save + '/' + ('chkpt_%d.pt' % (min_index + 1))),
194 |             str(dir_to_save + '/chkpt_opt.pt'))
195 |
--------------------------------------------------------------------------------
/write_on_tensorboard.py:
--------------------------------------------------------------------------------
1 | """
2 | For observing the results using tensorboard
3 |
4 | 1. wav
5 | 2. loss
6 | """
7 | from tensorboardX import SummaryWriter
8 | import matplotlib
9 | import config as cfg
10 |
11 |
12 | class Writer(SummaryWriter):
13 | def __init__(self, logdir):
14 | super(Writer, self).__init__(logdir)
15 | # mask real/ imag
16 | cmap_custom = {
17 | 'red': ((0.0, 0.0, 0.0),
18 | (1 / 63, 0.0, 0.0),
19 | (2 / 63, 0.0, 0.0),
20 | (3 / 63, 0.0, 0.0),
21 | (4 / 63, 0.0, 0.0),
22 | (5 / 63, 0.0, 0.0),
23 | (6 / 63, 0.0, 0.0),
24 | (7 / 63, 0.0, 0.0),
25 | (8 / 63, 0.0, 0.0),
26 | (9 / 63, 0.0, 0.0),
27 | (10 / 63, 0.0, 0.0),
28 | (11 / 63, 0.0, 0.0),
29 | (12 / 63, 0.0, 0.0),
30 | (13 / 63, 0.0, 0.0),
31 | (14 / 63, 0.0, 0.0),
32 | (15 / 63, 0.0, 0.0),
33 | (16 / 63, 0.0, 0.0),
34 | (17 / 63, 0.0, 0.0),
35 | (18 / 63, 0.0, 0.0),
36 | (19 / 63, 0.0, 0.0),
37 | (20 / 63, 0.0, 0.0),
38 | (21 / 63, 0.0, 0.0),
39 | (22 / 63, 0.0, 0.0),
40 | (23 / 63, 0.0, 0.0),
41 | (24 / 63, 0.5625, 0.5625),
42 | (25 / 63, 0.6250, 0.6250),
43 | (26 / 63, 0.6875, 0.6875),
44 | (27 / 63, 0.7500, 0.7500),
45 | (28 / 63, 0.8125, 0.8125),
46 | (29 / 63, 0.8750, 0.8750),
47 | (30 / 63, 0.9375, 0.9375),
48 | (31 / 63, 1.0, 1.0),
49 | (32 / 63, 1.0, 1.0),
50 | (33 / 63, 1.0, 1.0),
51 | (34 / 63, 1.0, 1.0),
52 | (35 / 63, 1.0, 1.0),
53 | (36 / 63, 1.0, 1.0),
54 | (37 / 63, 1.0, 1.0),
55 | (38 / 63, 1.0, 1.0),
56 | (39 / 63, 1.0, 1.0),
57 | (40 / 63, 1.0, 1.0),
58 | (41 / 63, 1.0, 1.0),
59 | (42 / 63, 1.0, 1.0),
60 | (43 / 63, 1.0, 1.0),
61 | (44 / 63, 1.0, 1.0),
62 | (45 / 63, 1.0, 1.0),
63 | (46 / 63, 1.0, 1.0),
64 | (47 / 63, 1.0, 1.0),
65 | (48 / 63, 1.0, 1.0),
66 | (49 / 63, 1.0, 1.0),
67 | (50 / 63, 1.0, 1.0),
68 | (51 / 63, 1.0, 1.0),
69 | (52 / 63, 1.0, 1.0),
70 | (53 / 63, 1.0, 1.0),
71 | (54 / 63, 1.0, 1.0),
72 | (55 / 63, 1.0, 1.0),
73 | (56 / 63, 0.9375, 0.9375),
74 | (57 / 63, 0.8750, 0.8750),
75 | (58 / 63, 0.8125, 0.8125),
76 | (59 / 63, 0.7500, 0.7500),
77 | (60 / 63, 0.6875, 0.6875),
78 | (61 / 63, 0.6250, 0.6250),
79 | (62 / 63, 0.5625, 0.5625),
80 | (63 / 63, 0.5000, 0.5000)),
81 | 'green': ((0.0, 0.0, 0.0),
82 | (1 / 63, 0.0, 0.0),
83 | (2 / 63, 0.0, 0.0),
84 | (3 / 63, 0.0, 0.0),
85 | (4 / 63, 0.0, 0.0),
86 | (5 / 63, 0.0, 0.0),
87 | (6 / 63, 0.0, 0.0),
88 | (7 / 63, 0.0, 0.0),
89 | (8 / 63, 0.0625, 0.0625),
90 | (9 / 63, 0.1250, 0.1250),
91 | (10 / 63, 0.1875, 0.1875),
92 | (11 / 63, 0.2500, 0.2500),
93 | (12 / 63, 0.3125, 0.3125),
94 | (13 / 63, 0.3750, 0.3750),
95 | (14 / 63, 0.4375, 0.4375),
96 | (15 / 63, 0.5000, 0.5000),
97 | (16 / 63, 0.5625, 0.5625),
98 | (17 / 63, 0.6250, 0.6250),
99 | (18 / 63, 0.6875, 0.6875),
100 | (19 / 63, 0.7500, 0.7500),
101 | (20 / 63, 0.8125, 0.8125),
102 | (21 / 63, 0.8750, 0.8750),
103 | (22 / 63, 0.9375, 0.9375),
104 | (23 / 63, 1.0, 1.0),
105 | (24 / 63, 1.0, 1.0),
106 | (25 / 63, 1.0, 1.0),
107 | (26 / 63, 1.0, 1.0),
108 | (27 / 63, 1.0, 1.0),
109 | (28 / 63, 1.0, 1.0),
110 | (29 / 63, 1.0, 1.0),
111 | (30 / 63, 1.0, 1.0),
112 | (31 / 63, 1.0, 1.0),
113 | (32 / 63, 1.0, 1.0),
114 | (33 / 63, 1.0, 1.0),
115 | (34 / 63, 1.0, 1.0),
116 | (35 / 63, 1.0, 1.0),
117 | (36 / 63, 1.0, 1.0),
118 | (37 / 63, 1.0, 1.0),
119 | (38 / 63, 1.0, 1.0),
120 | (39 / 63, 1.0, 1.0),
121 | (40 / 63, 0.9375, 0.9375),
122 | (41 / 63, 0.8750, 0.8750),
123 | (42 / 63, 0.8125, 0.8125),
124 | (43 / 63, 0.7500, 0.7500),
125 | (44 / 63, 0.6875, 0.6875),
126 | (45 / 63, 0.6250, 0.6250),
127 | (46 / 63, 0.5625, 0.5625),
128 | (47 / 63, 0.5000, 0.5000),
129 | (48 / 63, 0.4375, 0.4375),
130 | (49 / 63, 0.3750, 0.3750),
131 | (50 / 63, 0.3125, 0.3125),
132 | (51 / 63, 0.2500, 0.2500),
133 | (52 / 63, 0.1875, 0.1875),
134 | (53 / 63, 0.1250, 0.1250),
135 | (54 / 63, 0.0625, 0.0625),
136 | (55 / 63, 0.0, 0.0),
137 | (56 / 63, 0.0, 0.0),
138 | (57 / 63, 0.0, 0.0),
139 | (58 / 63, 0.0, 0.0),
140 | (59 / 63, 0.0, 0.0),
141 | (60 / 63, 0.0, 0.0),
142 | (61 / 63, 0.0, 0.0),
143 | (62 / 63, 0.0, 0.0),
144 | (63 / 63, 0.0, 0.0)),
145 | 'blue': ((0.0, 0.5625, 0.5625),
146 | (1 / 63, 0.6250, 0.6250),
147 | (2 / 63, 0.6875, 0.6875),
148 | (3 / 63, 0.7500, 0.7500),
149 | (4 / 63, 0.8125, 0.8125),
150 | (5 / 63, 0.8750, 0.8750),
151 | (6 / 63, 0.9375, 0.9375),
152 | (7 / 63, 1.0, 1.0),
153 | (8 / 63, 1.0, 1.0),
154 | (9 / 63, 1.0, 1.0),
155 | (10 / 63, 1.0, 1.0),
156 | (11 / 63, 1.0, 1.0),
157 | (12 / 63, 1.0, 1.0),
158 | (13 / 63, 1.0, 1.0),
159 | (14 / 63, 1.0, 1.0),
160 | (15 / 63, 1.0, 1.0),
161 | (16 / 63, 1.0, 1.0),
162 | (17 / 63, 1.0, 1.0),
163 | (18 / 63, 1.0, 1.0),
164 | (19 / 63, 1.0, 1.0),
165 | (20 / 63, 1.0, 1.0),
166 | (21 / 63, 1.0, 1.0),
167 | (22 / 63, 1.0, 1.0),
168 | (23 / 63, 1.0, 1.0),
169 | (24 / 63, 1.0, 1.0),
170 | (25 / 63, 1.0, 1.0),
171 | (26 / 63, 1.0, 1.0),
172 | (27 / 63, 1.0, 1.0),
173 | (28 / 63, 1.0, 1.0),
174 | (29 / 63, 1.0, 1.0),
175 | (30 / 63, 1.0, 1.0),
176 | (31 / 63, 1.0, 1.0),
177 | (32 / 63, 0.9375, 0.9375),
178 | (33 / 63, 0.8750, 0.8750),
179 | (34 / 63, 0.8125, 0.8125),
180 | (35 / 63, 0.7500, 0.7500),
181 | (36 / 63, 0.6875, 0.6875),
182 | (37 / 63, 0.6250, 0.6250),
183 | (38 / 63, 0.5625, 0.5625),
184 | (39 / 63, 0.0, 0.0),
185 | (40 / 63, 0.0, 0.0),
186 | (41 / 63, 0.0, 0.0),
187 | (42 / 63, 0.0, 0.0),
188 | (43 / 63, 0.0, 0.0),
189 | (44 / 63, 0.0, 0.0),
190 | (45 / 63, 0.0, 0.0),
191 | (46 / 63, 0.0, 0.0),
192 | (47 / 63, 0.0, 0.0),
193 | (48 / 63, 0.0, 0.0),
194 | (49 / 63, 0.0, 0.0),
195 | (50 / 63, 0.0, 0.0),
196 | (51 / 63, 0.0, 0.0),
197 | (52 / 63, 0.0, 0.0),
198 | (53 / 63, 0.0, 0.0),
199 | (54 / 63, 0.0, 0.0),
200 | (55 / 63, 0.0, 0.0),
201 | (56 / 63, 0.0, 0.0),
202 | (57 / 63, 0.0, 0.0),
203 | (58 / 63, 0.0, 0.0),
204 | (59 / 63, 0.0, 0.0),
205 | (60 / 63, 0.0, 0.0),
206 | (61 / 63, 0.0, 0.0),
207 | (62 / 63, 0.0, 0.0),
208 | (63 / 63, 0.0, 0.0))
209 | }
210 |
211 | # mask magnitude
212 | cmap_custom2 = {
213 | 'red': ((0.0, 1.0, 1.0),
214 | (1 / 32, 1.0, 1.0),
215 | (2 / 32, 1.0, 1.0),
216 | (3 / 32, 1.0, 1.0),
217 | (4 / 32, 1.0, 1.0),
218 | (5 / 32, 1.0, 1.0),
219 | (6 / 32, 1.0, 1.0),
220 | (7 / 32, 1.0, 1.0),
221 | (8 / 32, 1.0, 1.0),
222 | (9 / 32, 1.0, 1.0),
223 | (10 / 32, 1.0, 1.0),
224 | (11 / 32, 1.0, 1.0),
225 | (12 / 32, 1.0, 1.0),
226 | (13 / 32, 1.0, 1.0),
227 | (14 / 32, 1.0, 1.0),
228 | (15 / 32, 1.0, 1.0),
229 | (16 / 32, 1.0, 1.0),
230 | (17 / 32, 1.0, 1.0),
231 | (18 / 32, 1.0, 1.0),
232 | (19 / 32, 1.0, 1.0),
233 | (20 / 32, 1.0, 1.0),
234 | (21 / 32, 1.0, 1.0),
235 | (22 / 32, 1.0, 1.0),
236 | (23 / 32, 1.0, 1.0),
237 | (24 / 32, 1.0, 1.0),
238 | (25 / 32, 0.9375, 0.9375),
239 | (26 / 32, 0.8750, 0.8750),
240 | (27 / 32, 0.8125, 0.8125),
241 | (28 / 32, 0.7500, 0.7500),
242 | (29 / 32, 0.6875, 0.6875),
243 | (30 / 32, 0.6250, 0.6250),
244 | (31 / 32, 0.5625, 0.5625),
245 | (32 / 32, 0.5000, 0.5000)),
246 | 'green': ((0.0, 1.0, 1.0),
247 | (1 / 32, 1.0, 1.0),
248 | (2 / 32, 1.0, 1.0),
249 | (3 / 32, 1.0, 1.0),
250 | (4 / 32, 1.0, 1.0),
251 | (5 / 32, 1.0, 1.0),
252 | (6 / 32, 1.0, 1.0),
253 | (7 / 32, 1.0, 1.0),
254 | (8 / 32, 1.0, 1.0),
255 | (9 / 32, 0.9375, 0.9375),
256 | (10 / 32, 0.8750, 0.8750),
257 | (11 / 32, 0.8125, 0.8125),
258 | (12 / 32, 0.7500, 0.7500),
259 | (13 / 32, 0.6875, 0.6875),
260 | (14 / 32, 0.6250, 0.6250),
261 | (15 / 32, 0.5625, 0.5625),
262 | (16 / 32, 0.5000, 0.5000),
263 | (17 / 32, 0.4375, 0.4375),
264 | (18 / 32, 0.3750, 0.3750),
265 | (19 / 32, 0.3125, 0.3125),
266 | (20 / 32, 0.2500, 0.2500),
267 | (21 / 32, 0.1875, 0.1875),
268 | (22 / 32, 0.1250, 0.1250),
269 | (23 / 32, 0.0625, 0.0625),
270 | (24 / 32, 0.0, 0.0),
271 | (25 / 32, 0.0, 0.0),
272 | (26 / 32, 0.0, 0.0),
273 | (27 / 32, 0.0, 0.0),
274 | (28 / 32, 0.0, 0.0),
275 | (29 / 32, 0.0, 0.0),
276 | (30 / 32, 0.0, 0.0),
277 | (31 / 32, 0.0, 0.0),
278 | (32 / 32, 0.0, 0.0)),
279 | 'blue': ((0.0, 1.0, 1.0),
280 | (1 / 32, 0.9375, 0.9375),
281 | (2 / 32, 0.8750, 0.8750),
282 | (3 / 32, 0.8125, 0.8125),
283 | (4 / 32, 0.7500, 0.7500),
284 | (5 / 32, 0.6875, 0.6875),
285 | (6 / 32, 0.6250, 0.6250),
286 | (7 / 32, 0.5625, 0.5625),
287 | (8 / 32, 0.0, 0.0),
288 | (9 / 32, 0.0, 0.0),
289 | (10 / 32, 0.0, 0.0),
290 | (11 / 32, 0.0, 0.0),
291 | (12 / 32, 0.0, 0.0),
292 | (13 / 32, 0.0, 0.0),
293 | (14 / 32, 0.0, 0.0),
294 | (15 / 32, 0.0, 0.0),
295 | (16 / 32, 0.0, 0.0),
296 | (17 / 32, 0.0, 0.0),
297 | (18 / 32, 0.0, 0.0),
298 | (19 / 32, 0.0, 0.0),
299 | (20 / 32, 0.0, 0.0),
300 | (21 / 32, 0.0, 0.0),
301 | (22 / 32, 0.0, 0.0),
302 | (23 / 32, 0.0, 0.0),
303 | (24 / 32, 0.0, 0.0),
304 | (25 / 32, 0.0, 0.0),
305 | (26 / 32, 0.0, 0.0),
306 | (27 / 32, 0.0, 0.0),
307 | (28 / 32, 0.0, 0.0),
308 | (29 / 32, 0.0, 0.0),
309 | (30 / 32, 0.0, 0.0),
310 | (31 / 32, 0.0, 0.0),
311 | (32 / 32, 0.0, 0.0))
312 | }
313 |
314 | self.cmap_custom = matplotlib.colors.LinearSegmentedColormap('testCmap', segmentdata=cmap_custom, N=256)
315 | self.cmap_custom2 = matplotlib.colors.LinearSegmentedColormap('testCmap2', segmentdata=cmap_custom2, N=256)
316 |
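    # A hedged aside on the segmentdata format above: each (x, y0, y1)
    # triple anchors a channel at normalized position x; between anchors,
    # matplotlib interpolates from the left anchor's y1 to the right
    # anchor's y0, so y0 != y1 at an anchor creates a hard step. A minimal
    # two-anchor example (a black-to-red ramp):
    #
    #     from matplotlib.colors import LinearSegmentedColormap
    #     ramp = LinearSegmentedColormap('ramp', {
    #         'red':   ((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
    #         'green': ((0.0, 0.0, 0.0), (1.0, 0.0, 0.0)),
    #         'blue':  ((0.0, 0.0, 0.0), (1.0, 0.0, 0.0)),
    #     }, N=256)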
317 | def log_loss(self, train_loss, vali_loss, step):
318 | self.add_scalar('train_loss', train_loss, step)
319 | self.add_scalar('vali_loss', vali_loss, step)
320 |
321 | def log_sub_loss(self, train_main_loss, train_sub_loss, vali_main_loss, vali_sub_loss, step):
322 | self.add_scalar('train_main_loss', train_main_loss, step)
323 | self.add_scalar('train_sub_loss', train_sub_loss, step)
324 | self.add_scalar('vali_main_loss', vali_main_loss, step)
325 | self.add_scalar('vali_sub_loss', vali_sub_loss, step)
326 |
327 | def log_score(self, vali_pesq, vali_stoi, step):
328 | self.add_scalar('vali_pesq', vali_pesq, step)
329 | self.add_scalar('vali_stoi', vali_stoi, step)
330 |
331 | def log_wav(self, mixed_wav, clean_wav, est_wav, step):
332 | #