├── LICENSE
├── README.md
├── config.py
├── dataloader.py
├── model.py
├── samples
│   └── 0dB
│       ├── [15] SI-SNR+LMS.wav
│       ├── [15] SI-SNR.wav
│       ├── [15]CLEAN.wav
│       ├── [15]NOISY.wav
│       ├── [189] SI-SNR+LMS.wav
│       ├── [189] SI-SNR.wav
│       ├── [189]CLEAN.wav
│       ├── [189]NOISY.wav
│       ├── [1] SI-SNR+LMS.wav
│       ├── [1] SI-SNR.wav
│       ├── [1]CLEAN.wav
│       ├── [1]NOISY.wav
│       ├── [21] SI-SNR+LMS.wav
│       ├── [21] SI-SNR.wav
│       ├── [21]CLEAN.wav
│       ├── [21]NOISY.wav
│       ├── [78] SI-SNR+LMS.wav
│       ├── [78] SI-SNR.wav
│       ├── [78]CLEAN.wav
│       ├── [78]NOISY.wav
│       ├── [88] SI-SNR+LMS.wav
│       ├── [88] SI-SNR.wav
│       ├── [88]CLEAN.wav
│       └── [88]NOISY.wav
├── tester.py
├── tools_for_loss.py
├── tools_for_model.py
├── train.py
├── trainer.py
└── write_on_tensorboard.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 seorim0
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DCCRN with various loss functions
2 | 
3 | DCCRN (Deep Complex Convolutional Recurrent Network) is one of the deep neural networks proposed in [[1]](https://arxiv.org/abs/2008.00264). This repository applies DCCRN with various loss functions. Our original paper can be found [here](https://www.jask.or.kr/articles/xml/ABxn/), and you can check test samples [here](https://github.com/seorim0/DCCRN-with-various-loss-functions/tree/main/samples/0dB). The test samples were chosen at random; we uploaded samples for SI-SNR and SI-SNR+LMS.
4 | <br>
5 | 
6 | ![DCCRN architecture](https://user-images.githubusercontent.com/55497506/105969652-d39f6b80-60cb-11eb-805c-0f204405ef37.png)
7 | > Source of the figure: [paper](https://www.jask.or.kr/articles/xml/ABxn/)
8 | <br>
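For orientation, here is a minimal usage sketch that mirrors how tester.py instantiates the model (assuming the repository's dependencies are installed); the random tensor is only a stand-in for a batch of noisy waveforms:

```python
# Minimal usage sketch, mirroring the instantiation in tester.py.
import torch
import config as cfg
from model import DCCRN

model = DCCRN(rnn_units=cfg.rnn_units, masking_mode=cfg.masking_mode,
              use_clstm=cfg.use_clstm, kernel_num=cfg.kernel_num)
model.eval()

noisy = torch.randn(cfg.batch, 3 * cfg.fs)  # stand-in for 3-second noisy wavs at 16 kHz
with torch.no_grad():
    mask_real, mask_imag, real, imag, enhanced = model(noisy)
# `enhanced` is the estimated clean waveform; the masks and the real/imag
# spectra are also returned because the spectral loss terms need them.
```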
9 | 
10 | 
11 | 
12 | # Loss functions
13 | We use two base loss functions and two perceptual loss functions.
14 | 
15 | > Base loss
16 | 1. MSE: Mean Squared Error (a short code sketch of both base losses follows this list)
17 | ![image](https://user-images.githubusercontent.com/55497506/106714015-97758900-663e-11eb-9593-6ecfd4266a41.png)
18 | <br>
19 | 
20 | 2. SI-SNR: Scale-Invariant Signal-to-Noise Ratio
21 | ![image](https://user-images.githubusercontent.com/55497506/106714206-da376100-663e-11eb-94c6-77f6588616b9.png)
22 | <br>
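As referenced above, a condensed sketch of the two base losses on time-domain tensors (the repository's full versions are `F.mse_loss` in model.py and `si_snr` in tools_for_loss.py):

```python
import torch
import torch.nn.functional as F

def si_snr_loss(est, ref, eps=1e-8):
    """Negative SI-SNR: project the estimate onto the reference to get the
    scaled target, then compare the target energy with the residual noise."""
    dot = torch.sum(est * ref, -1, keepdim=True)
    ref_energy = torch.sum(ref ** 2, -1, keepdim=True)
    s_target = dot / (ref_energy + eps) * ref
    e_noise = est - s_target
    snr = 10 * torch.log10(
        torch.sum(s_target ** 2, -1) / (torch.sum(e_noise ** 2, -1) + eps) + eps)
    return -snr.mean()  # negated, since higher SI-SNR is better

est, ref = torch.randn(4, 48000), torch.randn(4, 48000)  # stand-in waveforms
mse = F.mse_loss(est, ref)       # base loss 1
sisnr = si_snr_loss(est, ref)    # base loss 2
```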
23 | 
24 | > Perceptual loss
25 | 1. LMS: Log Mel Spectra
26 | ![image](https://user-images.githubusercontent.com/55497506/106714238-e58a8c80-663e-11eb-8601-58bb020a2d3b.png)
27 | <br>
28 | 
29 | 2. PMSQE: Perceptual Metric for Speech Quality Evaluation
30 | ![image](https://user-images.githubusercontent.com/55497506/106714147-c855be00-663e-11eb-8a8d-a9d5aba1325d.png)
31 | <br>
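The PMSQE term reuses asteroid's implementation. Below is a condensed sketch of how this repository wires it up (taken from tools_for_loss.py and the `loss()` branches in model.py, which operate on 3-second segments at 16 kHz; the random tensors are stand-ins). The LMS term is computed analogously by `get_array_mel_loss()` on STFT magnitudes.

```python
import torch
from asteroid.losses import SingleSrcPMSQE, PITLossWrapper
from asteroid.filterbanks import STFTFB, Encoder, transforms

pmsqe_stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
pmsqe_loss = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')

ref_wav = torch.randn(2, 3, 16000)  # stand-ins: groups of 3-second segments
est_wav = torch.randn(2, 3, 16000)
ref_spec = transforms.take_mag(pmsqe_stft(ref_wav))
est_spec = transforms.take_mag(pmsqe_stft(est_wav))
p_loss = pmsqe_loss(est_spec, ref_spec)  # estimate first, then reference
```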
32 | 
33 | We combined the two base loss functions with the two perceptual loss functions. The coupling-constant ratio was determined experimentally, and it reflects the dynamic ranges of the two terms rather than their sensitivities. For example, early in training the MSE base loss is roughly 0.001 ~ 0.002, whereas LMS is roughly 0.1 ~ 0.2 and PMSQE roughly 0.8 ~ 1.3; to bring the two terms to a similar size, the perceptual loss term therefore receives the smaller coefficient. In the course of the experiments we also found the base loss to be the more important term, so we varied the coefficients such that the dynamic-range ratio, coupling constants included, ranged from 1:1 up to 10:1 in favor of the base term.
34 | <br>
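A sketch of the coupling itself, shaped after the 'SI-SNR+LMS' branch of `DCCRN.loss()` in model.py (the imported helpers are the repository's own; the magnitude arguments are the STFT magnitudes computed in model.py):

```python
from tools_for_loss import si_snr, get_array_mel_loss  # repository helpers

def si_snr_plus_lms(est_wav, ref_wav, est_mags, ref_mags, r1=1, r2=2):
    """Couple a base term and a perceptual term with constants r1:r2,
    normalized by r1 + r2 (1:2 mirrors model.py's 'SI-SNR+LMS' branch)."""
    base = -si_snr(est_wav, ref_wav)                  # negated: larger SI-SNR is better
    perceptual = get_array_mel_loss(ref_mags, est_mags)
    return (r1 * base + r2 * perceptual) / (r1 + r2)
```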
35 | 
36 | # Requirements
37 | > This repository is tested on Ubuntu 20.04.
38 | * Python 3.7+
39 | * CUDA 10.1+
40 | * cuDNN 7+
41 | * PyTorch 1.7+
42 | <br>
43 | 
44 | > Library
45 | * tqdm
46 | * asteroid
47 | * scipy
48 | * matplotlib
49 | * tensorboardX
50 | * pesq
51 | * pystoi
52 | 
53 | # Prepare data
54 | The training and validation data are stored as arrays with the following three dimensions.
55 | ```[Batch size, 2(input & target), wav length]```
56 | <br>
57 | The test data has the following dimensions.
58 | ```[noise type, dB classes, Batch size, 2(input & target), wav length]```
59 | We use two types of noise (seen and unseen) and seven SNR classes from -10 dB to 20 dB.
60 | 
61 | <br>
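The repository itself loads pre-built `.npy` arrays (see dataloader.py). As a sketch, one (noisy, clean) pair could be shaped and stacked like this, using the 3-second framing described just below; the file contents and single-pair stacking are illustrative only:

```python
import numpy as np

FS = 16000
WAV_LEN = 3 * FS  # fixed 3-second clips at 16 kHz

def fix_length(wav, length=WAV_LEN):
    """Truncate to `length` samples, or zero-pad on the right."""
    return wav[:length] if len(wav) >= length else np.pad(wav, (0, length - len(wav)))

# stand-ins for real recordings
noisy = np.random.randn(50000).astype(np.float32)
clean = np.random.randn(50000).astype(np.float32)

pair = np.stack([fix_length(noisy), fix_length(clean)])  # [2 (input & target), wav length]
dataset = np.stack([pair])                               # [Batch size, 2, wav length]
np.save('./input/train_dataset.npy', dataset)            # path expected by dataloader.py
```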
62 | Wav files longer than 3 seconds are truncated to 3 seconds, and wav files shorter than 3 seconds are zero-padded.
63 | The sampling frequency is 16 kHz.
64 | 
65 | 
69 | 
70 | # Comparative performance evaluation
71 | **Objective evaluation**
72 | <br>
73 | We evaluate the outputs with PESQ (Perceptual Evaluation of Speech Quality) and STOI (Short-Time Objective Intelligibility).
74 | ![t1](https://user-images.githubusercontent.com/55497506/108797149-e1aeb200-75cd-11eb-8ea4-3db00da21991.png)
75 | <br>
76 | 
77 | ![t2](https://user-images.githubusercontent.com/55497506/108797168-eb381a00-75cd-11eb-94ba-1d3a1016fb6e.png)
78 | <br>
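Both metrics are computed with the pesq and pystoi packages listed under Requirements (the same libraries imported in tools_for_model.py). A per-utterance sketch, with random arrays standing in for real signals:

```python
import numpy as np
from pesq import pesq
from pystoi import stoi

fs = 16000
clean = np.random.randn(fs * 3)     # stand-in for a clean reference
enhanced = np.random.randn(fs * 3)  # stand-in for an enhanced output

pesq_score = pesq(fs, clean, enhanced, 'wb')            # wide-band PESQ
stoi_score = stoi(clean, enhanced, fs, extended=False)  # STOI in [0, 1]
```

tester.py loops such evaluations over the two noise types and the seven SNR classes and logs the averages.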
79 | 
80 | **Spectrogram**
81 | 
82 | ![image](https://user-images.githubusercontent.com/55497506/108705017-1a0fab00-7550-11eb-962a-9f0b218371a8.png)
83 | > Source of the figure: [paper](https://www.jask.or.kr/articles/xml/ABxn/)
84 | 
85 | The spectrograms of (a) clean speech, (b) noisy speech at 0 dB SNR, and speech estimated using (c) MSE and PMSQE, (d) SI-SNR, (e) SI-SNR and PMSQE, and (f) SI-SNR and LMS.
86 | 
87 | # References
88 | **DCCRN: Deep Complex Convolution Recurrent Network for Phase-Aware Speech Enhancement**
89 | Yanxin Hu, Yun Liu, Shubo Lv, Mengtao Xing, Shimin Zhang, Yihui Fu, Jian Wu, Bihong Zhang, Lei Xie
90 | [[arXiv]](https://arxiv.org/abs/2008.00264) [[code]](https://github.com/huyanxin/DeepComplexCRN)
91 | 
92 | 
93 | # Paper
94 | **Performance comparison evaluation of speech enhancement using various loss functions.**
95 | Seo-Rim Hwang, Joon Byun, Young-Cheul Park
96 | [[paper]](https://www.jask.or.kr/articles/xml/ABxn/)
97 | 
98 | 
99 | # Note
100 | * ~~I'm trying to make the code cleaner.~~
101 | * ~~It's still in the editing phase. Please refer to the existing code.~~
102 | * [Cleaned-up and upgraded version of the code](https://github.com/seorim0/Speech_enhancement_with_Pytorch)
103 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration for program
3 | """
4 | 
5 | # model
6 | mode = 'DCCRN'  # DCUNET / DCCRN
7 | info = 'MODEL INFORMATION : IT IS USED FOR FILE NAME'
8 | 
9 | test = True
10 | 
11 | # path
12 | job_dir = './job/'
13 | logs_dir = './logs/'
14 | chkpt_path = None
15 | # chkpt_model = 'FILE NAME THAT YOU WANT TO LOAD'
16 | # chkpt_path = job_dir + chkpt_model + 'chkpt_88.pt'
17 | 
18 | # model information
19 | fs = 16000
20 | win_len = 400
21 | win_inc = 100
22 | ola_ratio = win_inc / win_len
23 | fft_len = 512
24 | sam_sec = fft_len / fs
25 | frm_samp = fs * (fft_len / fs)  # = fft_len (frame length in samples)
26 | window_type = 'hanning'
27 | 
28 | rnn_layers = 2
29 | rnn_units = 256
30 | masking_mode = 'E'
31 | use_clstm = True
32 | kernel_num = [32, 64, 128, 256, 256, 256]  # DCCRN
33 | # kernel_num = [72, 72, 144, 144, 144, 160, 160, 180]  # DCUNET
34 | loss_mode = 'SDR+PMSQE'
35 | 
36 | # hyperparameters for model train
37 | max_epochs = 100
38 | learning_rate = 0.0005
39 | batch = 15
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset, DataLoader
4 | import config as cfg
5 | 
6 | 
7 | # save np.load
8 | np_load_old = np.load
9 | # modify the default parameters of np.load
10 | np.load = lambda *a, **k: np_load_old(*a, allow_pickle=True, **k)  # force allow_pickle=True for the saved object arrays
11 | 
12 | 
13 | def create_dataloader(mode):
14 |     if mode == 'train':
15 |         return DataLoader(
16 |             dataset=Wave_Dataset(mode),
17 |             batch_size=cfg.batch,  # max 3696 * snr types
18 |             shuffle=True,
19 |             num_workers=0,
20 |             pin_memory=True,
21 |             drop_last=True,
22 |             sampler=None
23 |         )
24 |     elif mode == 'valid':
25 |         return DataLoader(
26 |             dataset=Wave_Dataset(mode),
27 |             batch_size=cfg.batch, shuffle=False, num_workers=0
28 |         )  # max 1152
29 | 
30 | 
31 | def create_dataloader_for_test(mode, type, snr):
32 |     if mode == 'test':
33 |         return DataLoader(
34 |             dataset=Wave_Dataset_for_test(mode, type, snr),
35 |             batch_size=cfg.batch, shuffle=False, num_workers=0
36 |         )  # max 192
37 | 
38 | 
39 | class Wave_Dataset(Dataset):
40 |     def __init__(self, mode):
41 |         # load data
42
| if mode == 'train': 43 | print('') 44 | print('Load the data...') 45 | self.input_path = './input/train_dataset.npy' 46 | elif mode == 'valid': 47 | print('') 48 | print('Load the data...') 49 | self.input_path = './input/validation_dataset.npy' 50 | 51 | self.input = np.load(self.input_path) 52 | 53 | def __len__(self): 54 | return len(self.input) 55 | 56 | def __getitem__(self, idx): 57 | inputs = self.input[idx][0] 58 | labels = self.input[idx][1] 59 | 60 | # transform to torch from numpy 61 | inputs = torch.from_numpy(inputs) 62 | labels = torch.from_numpy(labels) 63 | 64 | return inputs, labels 65 | 66 | 67 | class Wave_Dataset_for_test(Dataset): 68 | def __init__(self, mode, type, snr): 69 | # load data 70 | if mode == 'test': 71 | print('') 72 | print('Load the data...') 73 | self.input_path = './input/recon_test_dataset.npy' 74 | 75 | self.input = np.load(self.input_path) 76 | self.input = self.input[type][snr] 77 | 78 | def __len__(self): 79 | return len(self.input) 80 | 81 | def __getitem__(self, idx): 82 | inputs = self.input[idx][0] 83 | labels = self.input[idx][1] 84 | 85 | # transform to torch from numpy 86 | inputs = torch.from_numpy(inputs) 87 | labels = torch.from_numpy(labels) 88 | 89 | return inputs, labels 90 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | DCCRN: Deep complex convolution recurrent network 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import config as cfg 8 | from tools_for_model import ConvSTFT, ConviSTFT, \ 9 | ComplexConv2d, ComplexConvTranspose2d, NavieComplexLSTM, complex_cat, ComplexBatchNorm 10 | from tools_for_loss import si_snr, si_sdr, get_array_mel_loss, pmsqe_stft, pmsqe_loss, sdr 11 | from asteroid.filterbanks import transforms 12 | 13 | 14 | class DCCRN(nn.Module): 15 | 16 | def __init__( 17 | self, 18 | rnn_layers=cfg.rnn_layers, 19 | rnn_units=cfg.rnn_units, 20 | win_len=cfg.win_len, 21 | win_inc=cfg.win_inc, 22 | fft_len=cfg.fft_len, 23 | win_type=cfg.window_type, 24 | masking_mode='E', 25 | use_clstm=False, 26 | use_cbn=False, 27 | kernel_size=5, 28 | kernel_num=[16, 32, 64, 128, 256, 256] 29 | ): 30 | ''' 31 | 32 | rnn_layers: the number of lstm layers in the crn, 33 | rnn_units: for clstm, rnn_units = real+imag 34 | ''' 35 | 36 | super(DCCRN, self).__init__() 37 | 38 | # for fft 39 | self.win_len = win_len 40 | self.win_inc = win_inc 41 | self.fft_len = fft_len 42 | self.win_type = win_type 43 | 44 | input_dim = win_len 45 | output_dim = win_len 46 | 47 | self.rnn_units = rnn_units 48 | self.input_dim = input_dim 49 | self.output_dim = output_dim 50 | self.hidden_layers = rnn_layers 51 | self.kernel_size = kernel_size 52 | # self.kernel_num = [2, 8, 16, 32, 128, 128, 128] 53 | # self.kernel_num = [2, 16, 32, 64, 128, 256, 256] 54 | self.kernel_num = [2] + kernel_num 55 | self.masking_mode = masking_mode 56 | self.use_clstm = use_clstm 57 | 58 | # bidirectional=True 59 | bidirectional = False 60 | fac = 2 if bidirectional else 1 61 | 62 | fix = True 63 | self.fix = fix 64 | self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix) 65 | self.istft = ConviSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix) 66 | 67 | self.encoder = nn.ModuleList() 68 | self.decoder = nn.ModuleList() 69 | for idx in range(len(self.kernel_num) - 1): 70 | self.encoder.append( 71 | nn.Sequential( 72 | # 
nn.ConstantPad2d([0, 0, 0, 0], 0), 73 | ComplexConv2d( 74 | self.kernel_num[idx], 75 | self.kernel_num[idx + 1], 76 | kernel_size=(self.kernel_size, 2), 77 | stride=(2, 1), 78 | padding=(2, 1) 79 | ), 80 | nn.BatchNorm2d(self.kernel_num[idx + 1]) if not use_cbn else ComplexBatchNorm( 81 | self.kernel_num[idx + 1]), 82 | nn.PReLU() 83 | ) 84 | ) 85 | hidden_dim = self.fft_len // (2 ** (len(self.kernel_num))) 86 | 87 | if self.use_clstm: 88 | rnns = [] 89 | for idx in range(rnn_layers): 90 | rnns.append( 91 | NavieComplexLSTM( 92 | input_size=hidden_dim * self.kernel_num[-1] if idx == 0 else self.rnn_units, 93 | hidden_size=self.rnn_units, 94 | bidirectional=bidirectional, 95 | batch_first=False, 96 | projection_dim=hidden_dim * self.kernel_num[-1] if idx == rnn_layers - 1 else None, 97 | ) 98 | ) 99 | self.enhance = nn.Sequential(*rnns) 100 | else: 101 | self.enhance = nn.LSTM( 102 | input_size=hidden_dim * self.kernel_num[-1], 103 | hidden_size=self.rnn_units, 104 | num_layers=2, 105 | dropout=0.0, 106 | bidirectional=bidirectional, 107 | batch_first=False 108 | ) 109 | self.tranform = nn.Linear(self.rnn_units * fac, hidden_dim * self.kernel_num[-1]) 110 | 111 | for idx in range(len(self.kernel_num) - 1, 0, -1): 112 | if idx != 1: 113 | self.decoder.append( 114 | nn.Sequential( 115 | ComplexConvTranspose2d( 116 | self.kernel_num[idx] * 2, 117 | self.kernel_num[idx - 1], 118 | kernel_size=(self.kernel_size, 2), 119 | stride=(2, 1), 120 | padding=(2, 0), 121 | output_padding=(1, 0) 122 | ), 123 | nn.BatchNorm2d(self.kernel_num[idx - 1]) if not use_cbn else ComplexBatchNorm( 124 | self.kernel_num[idx - 1]), 125 | # nn.ELU() 126 | nn.PReLU() 127 | ) 128 | ) 129 | else: 130 | self.decoder.append( 131 | nn.Sequential( 132 | ComplexConvTranspose2d( 133 | self.kernel_num[idx] * 2, 134 | self.kernel_num[idx - 1], 135 | kernel_size=(self.kernel_size, 2), 136 | stride=(2, 1), 137 | padding=(2, 0), 138 | output_padding=(1, 0) 139 | ), 140 | ) 141 | ) 142 | 143 | self.flatten_parameters() 144 | 145 | def flatten_parameters(self): 146 | if isinstance(self.enhance, nn.LSTM): 147 | self.enhance.flatten_parameters() 148 | 149 | def forward(self, inputs, lens=None): 150 | specs = self.stft(inputs) 151 | real = specs[:, :self.fft_len // 2 + 1] 152 | imag = specs[:, self.fft_len // 2 + 1:] 153 | spec_mags = torch.sqrt(real ** 2 + imag ** 2 + 1e-8) 154 | spec_mags = spec_mags 155 | 156 | ## 157 | 158 | ## 159 | spec_phase = torch.atan2(imag, real) 160 | spec_phase = spec_phase 161 | cspecs = torch.stack([real, imag], 1) 162 | cspecs = cspecs[:, :, 1:] 163 | ''' 164 | means = torch.mean(cspecs, [1,2,3], keepdim=True) 165 | std = torch.std(cspecs, [1,2,3], keepdim=True ) 166 | normed_cspecs = (cspecs-means)/(std+1e-8) 167 | out = normed_cspecs 168 | ''' 169 | 170 | out = cspecs 171 | encoder_out = [] 172 | 173 | for idx, layer in enumerate(self.encoder): 174 | out = layer(out) 175 | # print('encoder', out.size()) 176 | encoder_out.append(out) 177 | 178 | batch_size, channels, dims, lengths = out.size() 179 | out = out.permute(3, 0, 1, 2) 180 | if self.use_clstm: 181 | r_rnn_in = out[:, :, :channels // 2] 182 | i_rnn_in = out[:, :, channels // 2:] 183 | r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2 * dims]) 184 | i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, channels // 2 * dims]) 185 | 186 | r_rnn_in, i_rnn_in = self.enhance([r_rnn_in, i_rnn_in]) 187 | 188 | r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2, dims]) 189 | i_rnn_in = torch.reshape(i_rnn_in, 
[lengths, batch_size, channels // 2, dims]) 190 | out = torch.cat([r_rnn_in, i_rnn_in], 2) 191 | 192 | else: 193 | # to [L, B, C, D] 194 | out = torch.reshape(out, [lengths, batch_size, channels * dims]) 195 | out, _ = self.enhance(out) 196 | out = self.tranform(out) 197 | out = torch.reshape(out, [lengths, batch_size, channels, dims]) 198 | 199 | out = out.permute(1, 2, 3, 0) 200 | 201 | for idx in range(len(self.decoder)): 202 | out = complex_cat([out, encoder_out[-1 - idx]], 1) 203 | out = self.decoder[idx](out) 204 | out = out[..., 1:] 205 | # print('decoder', out.size()) 206 | mask_real = out[:, 0] 207 | mask_imag = out[:, 1] 208 | mask_real = F.pad(mask_real, [0, 0, 1, 0]) 209 | mask_imag = F.pad(mask_imag, [0, 0, 1, 0]) 210 | 211 | if self.masking_mode == 'E': 212 | mask_mags = (mask_real ** 2 + mask_imag ** 2) ** 0.5 213 | real_phase = mask_real / (mask_mags + 1e-8) 214 | imag_phase = mask_imag / (mask_mags + 1e-8) 215 | mask_phase = torch.atan2( 216 | imag_phase, 217 | real_phase 218 | ) 219 | 220 | # mask_mags = torch.clamp_(mask_mags,0,100) 221 | mask_mags = torch.tanh(mask_mags) 222 | est_mags = mask_mags * spec_mags 223 | est_phase = spec_phase + mask_phase 224 | real = est_mags * torch.cos(est_phase) 225 | imag = est_mags * torch.sin(est_phase) 226 | elif self.masking_mode == 'C': 227 | real, imag = real * mask_real - imag * mask_imag, real * mask_imag + imag * mask_real 228 | elif self.masking_mode == 'R': 229 | real, imag = real * mask_real, imag * mask_imag 230 | 231 | out_spec = torch.cat([real, imag], 1) 232 | out_wav = self.istft(out_spec) 233 | 234 | out_wav = torch.squeeze(out_wav, 1) 235 | # out_wav = torch.tanh(out_wav) 236 | out_wav = torch.clamp_(out_wav, -1, 1) 237 | return mask_real, mask_imag, real, imag, out_wav # out_spec, out_wav 238 | 239 | def get_params(self, weight_decay=0.0): 240 | # add L2 penalty 241 | weights, biases = [], [] 242 | for name, param in self.named_parameters(): 243 | if 'bias' in name: 244 | biases += [param] 245 | else: 246 | weights += [param] 247 | params = [{ 248 | 'params': weights, 249 | 'weight_decay': weight_decay, 250 | }, { 251 | 'params': biases, 252 | 'weight_decay': 0.0, 253 | }] 254 | return params 255 | 256 | def loss(self, inputs, labels, real_spec, img_spec, loss_mode=cfg.loss_mode): 257 | if loss_mode == 'MSE': 258 | return F.mse_loss(inputs, labels, reduction='mean') 259 | 260 | elif loss_mode == 'SDR': 261 | return -sdr(labels, inputs) 262 | 263 | elif loss_mode == 'SI-SNR': 264 | return -(si_snr(inputs, labels)) 265 | 266 | elif loss_mode == 'SI-SDR': 267 | return -(si_sdr(labels, inputs)) 268 | 269 | elif loss_mode == 'MSE+LMS': 270 | 271 | mse_loss = F.mse_loss(inputs, labels, reduction='mean') 272 | 273 | # for mel loss calculation 274 | clean_specs = self.stft(labels) 275 | clean_real = clean_specs[:, :self.fft_len // 2 + 1] 276 | clean_imag = clean_specs[:, self.fft_len // 2 + 1:] 277 | clean_mags = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-7) 278 | 279 | est_clean_mags = torch.sqrt(real_spec ** 2 + img_spec ** 2 + 1e-7) 280 | mel_loss = get_array_mel_loss(clean_mags, est_clean_mags) 281 | 282 | r1 = 1e+3 283 | r2 = 1 284 | r = r1 + r2 285 | 286 | loss = (r1 * mse_loss + r2 * mel_loss) / r 287 | 288 | return loss 289 | 290 | elif loss_mode == 'MSE+SI-SNR': 291 | snr_loss = -(si_snr(inputs, labels)) 292 | mse_loss = F.mse_loss(inputs, labels, reduction='mean') 293 | 294 | r1 = 1 295 | r2 = 100 296 | r = r1 + r2 297 | 298 | loss = (r1 * snr_loss + r2 * mse_loss) / r 299 | 300 | return loss 301 | 302 | 
        elif loss_mode == 'MSE+PMSQE':
303 |             ref_wav = labels.reshape(-1, 3, 16000)  # groups of 3-second segments at 16 kHz (cfg.fs)
304 |             est_wav = inputs.reshape(-1, 3, 16000)
305 |             ref_wav = ref_wav.cpu()
306 |             est_wav = est_wav.cpu()
307 | 
308 |             ref_spec = transforms.take_mag(pmsqe_stft(ref_wav))
309 |             est_spec = transforms.take_mag(pmsqe_stft(est_wav))
310 | 
311 |             loss = pmsqe_loss(est_spec, ref_spec)  # estimate first, then reference (cf. the 'wrong' note in the SDR+PMSQE branch)
312 | 
313 |             loss = loss.cuda()
314 | 
315 |             return loss
316 | 
317 |         elif loss_mode == 'SI-SNR+SI-SDR':
318 |             snr_loss = -(si_snr(inputs, labels))
319 |             sdr_loss = -(si_sdr(labels, inputs))  # si_sdr takes (reference, estimation)
320 | 
321 |             r1 = 1
322 |             r2 = 1
323 |             r = r1 + r2
324 | 
325 |             loss = (r1 * snr_loss + r2 * sdr_loss) / r
326 | 
327 |             return loss
328 | 
329 |         elif loss_mode == 'SDR+LMS':
330 |             sdr_loss = -sdr(labels, inputs)
331 | 
332 |             # for mel loss calculation
333 |             clean_specs = self.stft(labels)
334 |             clean_real = clean_specs[:, :self.fft_len // 2 + 1]
335 |             clean_imag = clean_specs[:, self.fft_len // 2 + 1:]
336 |             clean_mags = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-7)
337 | 
338 |             est_clean_mags = torch.sqrt(real_spec ** 2 + img_spec ** 2 + 1e-7)
339 |             mel_loss = get_array_mel_loss(clean_mags, est_clean_mags)
340 | 
341 |             r1 = 1
342 |             r2 = 2
343 |             r = r1 + r2
344 | 
345 |             loss = (r1 * sdr_loss + r2 * mel_loss) / r
346 |             return loss
347 | 
348 |         elif loss_mode == 'SDR+PMSQE':
349 |             sdr_loss = -sdr(labels, inputs)
350 | 
351 |             ref_wav = labels.reshape(-1, 3, 16000)
352 |             est_wav = inputs.reshape(-1, 3, 16000)
353 |             ref_wav = ref_wav.cpu()
354 |             est_wav = est_wav.cpu()
355 | 
356 |             ref_spec = transforms.take_mag(pmsqe_stft(ref_wav))
357 |             est_spec = transforms.take_mag(pmsqe_stft(est_wav))
358 | 
359 |             # p_loss = pmsqe_loss(ref_spec, est_spec) wrong
360 |             p_loss = pmsqe_loss(est_spec, ref_spec)
361 | 
362 |             r1 = 1
363 |             r2 = 15
364 |             r = r1 + r2
365 | 
366 |             loss = (r1 * sdr_loss + r2 * p_loss) / r
367 |             return loss
368 | 
369 |         elif loss_mode == 'SI-SNR+LMS':
370 |             snr_loss = -(si_snr(inputs, labels))
371 | 
372 |             # for mel loss calculation
373 |             clean_specs = self.stft(labels)
374 |             clean_real = clean_specs[:, :self.fft_len // 2 + 1]
375 |             clean_imag = clean_specs[:, self.fft_len // 2 + 1:]
376 |             clean_mags = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-7)
377 | 
378 |             est_clean_mags = torch.sqrt(real_spec ** 2 + img_spec ** 2 + 1e-7)
379 |             mel_loss = get_array_mel_loss(clean_mags, est_clean_mags)
380 | 
381 |             r1 = 1
382 |             r2 = 2
383 |             r = r1 + r2
384 | 
385 |             loss = (r1 * snr_loss + r2 * mel_loss) / r
386 | 
387 |             return loss
388 | 
389 |         elif loss_mode == 'SI-SNR+PMSQE':
390 |             ref_wav = labels.reshape(-1, 3, 16000)
391 |             est_wav = inputs.reshape(-1, 3, 16000)
392 |             ref_wav = ref_wav.cpu()
393 |             est_wav = est_wav.cpu()
394 | 
395 |             ref_spec = transforms.take_mag(pmsqe_stft(ref_wav))
396 |             est_spec = transforms.take_mag(pmsqe_stft(est_wav))
397 | 
398 |             p_loss = pmsqe_loss(est_spec, ref_spec)
399 | 
400 |             snr_loss = -(si_snr(est_wav, ref_wav))
401 | 
402 |             r1 = 8
403 |             r2 = 1
404 |             r = r1 + r2
405 | 
406 |             loss = (r1 * p_loss + r2 * snr_loss) / r
407 | 
408 |             return loss
409 | 
410 | 
--------------------------------------------------------------------------------
/samples/0dB/[15] SI-SNR+LMS.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[15] SI-SNR+LMS.wav
--------------------------------------------------------------------------------
/samples/0dB/[15] SI-SNR.wav:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[15] SI-SNR.wav -------------------------------------------------------------------------------- /samples/0dB/[15]CLEAN.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[15]CLEAN.wav -------------------------------------------------------------------------------- /samples/0dB/[15]NOISY.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[15]NOISY.wav -------------------------------------------------------------------------------- /samples/0dB/[189] SI-SNR+LMS.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[189] SI-SNR+LMS.wav -------------------------------------------------------------------------------- /samples/0dB/[189] SI-SNR.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[189] SI-SNR.wav -------------------------------------------------------------------------------- /samples/0dB/[189]CLEAN.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[189]CLEAN.wav -------------------------------------------------------------------------------- /samples/0dB/[189]NOISY.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[189]NOISY.wav -------------------------------------------------------------------------------- /samples/0dB/[1] SI-SNR+LMS.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[1] SI-SNR+LMS.wav -------------------------------------------------------------------------------- /samples/0dB/[1] SI-SNR.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[1] SI-SNR.wav -------------------------------------------------------------------------------- /samples/0dB/[1]CLEAN.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[1]CLEAN.wav -------------------------------------------------------------------------------- /samples/0dB/[1]NOISY.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[1]NOISY.wav -------------------------------------------------------------------------------- /samples/0dB/[21] SI-SNR+LMS.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[21] SI-SNR+LMS.wav -------------------------------------------------------------------------------- /samples/0dB/[21] SI-SNR.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[21] SI-SNR.wav -------------------------------------------------------------------------------- /samples/0dB/[21]CLEAN.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[21]CLEAN.wav -------------------------------------------------------------------------------- /samples/0dB/[21]NOISY.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[21]NOISY.wav -------------------------------------------------------------------------------- /samples/0dB/[78] SI-SNR+LMS.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[78] SI-SNR+LMS.wav -------------------------------------------------------------------------------- /samples/0dB/[78] SI-SNR.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[78] SI-SNR.wav -------------------------------------------------------------------------------- /samples/0dB/[78]CLEAN.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[78]CLEAN.wav -------------------------------------------------------------------------------- /samples/0dB/[78]NOISY.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[78]NOISY.wav -------------------------------------------------------------------------------- /samples/0dB/[88] SI-SNR+LMS.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[88] SI-SNR+LMS.wav -------------------------------------------------------------------------------- /samples/0dB/[88] SI-SNR.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[88] SI-SNR.wav 
--------------------------------------------------------------------------------
/samples/0dB/[88]CLEAN.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[88]CLEAN.wav
--------------------------------------------------------------------------------
/samples/0dB/[88]NOISY.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DCCRN-with-various-loss-functions/2e244605464452bd50404dfeb66356d278d7f1e3/samples/0dB/[88]NOISY.wav
--------------------------------------------------------------------------------
/tester.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import numpy as np
5 | import config as cfg
6 | from run import model_test
7 | from dataloader import create_dataloader_for_test
8 | from model import DCCRN
9 | 
10 | 
11 | ###############################################################################
12 | #                         Helper function definition                          #
13 | ###############################################################################
14 | # Write training-related parameters into the log file.
15 | def write_status_to_log_file(fp, total_parameters):
16 |     # log the run timestamp and the training configuration
17 |     fp.write('%d-%d-%d %d:%d:%d\n' %
18 |              (time.localtime().tm_year, time.localtime().tm_mon,
19 |               time.localtime().tm_mday, time.localtime().tm_hour,
20 |               time.localtime().tm_min, time.localtime().tm_sec))
21 |     fp.write('mode : %s_%s\n' % (cfg.mode, cfg.info))
22 |     fp.write('learning rate : %g\n' % cfg.learning_rate)
23 |     fp.write('total params : %d (%.2f M, %.2f MBytes)\n' %
24 |              (total_parameters,
25 |               total_parameters / 1000000.0,
26 |               total_parameters * 4.0 / 1000000.0))
27 | 
28 | 
29 | # Calculate the size of total network.
30 | def calculate_total_params(our_model): 31 | total_parameters = 0 32 | for variable in our_model.parameters(): 33 | shape = variable.size() 34 | variable_parameters = 1 35 | for dim in shape: 36 | variable_parameters *= dim 37 | total_parameters += variable_parameters 38 | 39 | return total_parameters 40 | 41 | 42 | ############################################################################### 43 | # Parameter Initialization # 44 | ############################################################################### 45 | print('***********************************************************') 46 | print('* Python library for DNN-based speech enhancement *') 47 | print('* using Pytorch API *') 48 | print('***********************************************************') 49 | 50 | # Set device 51 | DEVICE = torch.device("cuda") 52 | 53 | # Set model 54 | if cfg.mode == 'DCCRN': 55 | model = DCCRN(rnn_units=cfg.rnn_units, masking_mode=cfg.masking_mode, use_clstm=cfg.use_clstm, 56 | kernel_num=cfg.kernel_num).to(DEVICE) 57 | 58 | ############################################################################### 59 | # Set optimizer and learning rate # 60 | ############################################################################### 61 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate) 62 | total_params = calculate_total_params(model) 63 | 64 | ############################################################################### 65 | # Confirm model information # 66 | ############################################################################### 67 | print('%d-%d-%d %d:%d:%d\n' % 68 | (time.localtime().tm_year, time.localtime().tm_mon, 69 | time.localtime().tm_mday, time.localtime().tm_hour, 70 | time.localtime().tm_min, time.localtime().tm_sec)) 71 | print('mode : %s_%s\n' % (cfg.mode, cfg.info)) 72 | print('learning rate : %g\n' % cfg.learning_rate) 73 | print('total params : %d (%.2f M, %.2f MBytes)\n' % 74 | (total_params, 75 | total_params / 1000000.0, 76 | total_params * 4.0 / 1000000.0)) 77 | 78 | 79 | ############################################################################### 80 | # Set a log file to store progress. # 81 | # Set a hps file to store hyper-parameters information. # 82 | ############################################################################### 83 | # Load the checkpoint 84 | if cfg.chkpt_path is not None: 85 | print('Resuming from checkpoint: %s' % cfg.chkpt_path) 86 | 87 | # Set a log file to store progress. 88 | dir_to_save = cfg.job_dir + cfg.chkpt_model 89 | dir_to_logs = cfg.logs_dir + cfg.chkpt_model 90 | 91 | checkpoint = torch.load(cfg.chkpt_path) 92 | model.load_state_dict(checkpoint['model']) 93 | optimizer.load_state_dict(checkpoint['optimizer']) 94 | epoch_start_idx = checkpoint['epoch'] + 1 95 | mse_vali_total = np.load(str(dir_to_save + '/mse_vali_total.npy')) 96 | if len(mse_vali_total) < cfg.max_epochs: 97 | plus = cfg.max_epochs - len(mse_vali_total) 98 | mse_vali_total = np.concatenate((mse_vali_total, np.zeros(plus)), 0) 99 | 100 | 101 | if not os.path.exists(dir_to_save): 102 | os.mkdir(dir_to_save) 103 | os.mkdir(dir_to_logs) 104 | 105 | log_fname = str(dir_to_save + '/log.txt') 106 | if not os.path.exists(log_fname): 107 | fp = open(log_fname, 'w') 108 | write_status_to_log_file(fp, total_params) 109 | else: 110 | fp = open(log_fname, 'a') 111 | 112 | # Set a hps file to store hyper-parameters information. 
113 | hps_fname = str(dir_to_save + '/hp_str.txt')
114 | fp_h = open(hps_fname, 'w')
115 | 
116 | with open('config.py', 'r') as f:
117 |     hp_str = ''.join(f.readlines())
118 | fp_h.write(hp_str)
119 | fp_h.close()
120 | 
121 | min_index = np.argmin(mse_vali_total)
122 | print('Minimum validation loss is at ' + str(min_index + 1) + '.')
123 | 
124 | ###############################################################################
125 | #                                    Test                                     #
126 | ###############################################################################
127 | if cfg.test is True:
128 |     print('Starting test run')
129 | 
130 |     # check the lowest validation loss epoch
131 |     want_to_check = torch.load(dir_to_save + '/chkpt_opt.pt')
132 |     model.load_state_dict(want_to_check['model'])
133 |     optimizer.load_state_dict(want_to_check['optimizer'])
134 |     epoch_start_idx = want_to_check['epoch'] + 1
135 |     mse_vali_total = np.load(str(dir_to_save + '/mse_vali_total.npy'))
136 | 
137 |     # noise = [seen, unseen]
138 |     noise_type = ['seen', 'unseen']
139 |     # SNR = [-10, -5, 0, 5, 10, 15, 20, Avg]
140 |     noisy_snr = ['-10', '-5', '0', '5', '10', '15', '20', 'Avg']
141 |     for type in range(len(noise_type)):
142 |         for snr in range(len(noisy_snr)):
143 |             test_loader = create_dataloader_for_test(mode='test', type=type, snr=snr)
144 |             test_loss, test_pesq, test_stoi = \
145 |                 model_test(noise_type[type], noisy_snr[snr], model,
146 |                            test_loader, dir_to_save, DEVICE)
147 | 
148 |             print('Noise type {} | snr {}'.format(noise_type[type], noisy_snr[snr]))
149 |             fp.write('\n\nNoise type {} | snr {}'.format(noise_type[type], noisy_snr[snr]))
150 |             print('Test loss {:.6} | PESQ {:.6} | STOI {:.6}'
151 |                   .format(test_loss, test_pesq, test_stoi))
152 |             fp.write('\nTest loss {:.6f} | PESQ {:.6f} | STOI {:.6f}'
153 |                      .format(test_loss, test_pesq, test_stoi))
154 | 
155 |     fp.close()
156 | else:
157 |     fp.close()
--------------------------------------------------------------------------------
/tools_for_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 | import config as cfg
5 | from asteroid.losses import SingleSrcPMSQE, PITLossWrapper
6 | from asteroid.filterbanks import STFTFB, Encoder
7 | 
8 | 
9 | # Set training device
10 | DEVICE = torch.device("cuda")
11 | 
12 | 
13 | ############################################################################
14 | #                  for model structure & loss function                     #
15 | ############################################################################
16 | def remove_dc(data):
17 |     mean = torch.mean(data, -1, keepdim=True)
18 |     data = data - mean
19 |     return data
20 | 
21 | 
22 | def l2_norm(s1, s2):
23 |     # norm = torch.sqrt(torch.sum(s1*s2, 1, keepdim=True))
24 |     # norm = torch.norm(s1*s2, 1, keepdim=True)
25 | 
26 |     norm = torch.sum(s1 * s2, -1, keepdim=True)
27 |     return norm
28 | 
29 | 
30 | def sdr(s1, s2, eps=1e-8):
31 |     sn = l2_norm(s1, s1)
32 |     sn_m_shn = l2_norm(s1 - s2, s1 - s2)
33 |     sdr_loss = 10 * torch.log10(sn**2 / (sn_m_shn**2 + eps))  # sn terms are already energies, so the extra squares double the dB scale; the coupling constants absorb this
34 |     return torch.mean(sdr_loss)
35 | 
36 | 
37 | def si_snr(s1, s2, eps=1e-8):
38 |     # s1 = remove_dc(s1)
39 |     # s2 = remove_dc(s2)
40 |     s1_s2_norm = l2_norm(s1, s2)
41 |     s2_s2_norm = l2_norm(s2, s2)
42 |     s_target = s1_s2_norm / (s2_s2_norm + eps) * s2
43 |     e_noise = s1 - s_target
44 |     target_norm = l2_norm(s_target, s_target)
45 |     noise_norm = l2_norm(e_noise, e_noise)
46 |     snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps)
47 |     return torch.mean(snr)
48 | 
49 | 
50 | def si_sdr(reference, estimation, eps=1e-8):
51 | """ 52 | Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) 53 | Args: 54 | reference: numpy.ndarray, [..., T] 55 | estimation: numpy.ndarray, [..., T] 56 | Returns: 57 | SI-SDR 58 | [1] SDR– Half- Baked or Well Done? 59 | http://www.merl.com/publications/docs/TR2019-013.pdf 60 | >>> np.random.seed(0) 61 | >>> reference = np.random.randn(100) 62 | >>> si_sdr(reference, reference) 63 | inf 64 | >>> si_sdr(reference, reference * 2) 65 | inf 66 | >>> si_sdr(reference, np.flip(reference)) 67 | -25.127672346460717 68 | >>> si_sdr(reference, reference + np.flip(reference)) 69 | 0.481070445785553 70 | >>> si_sdr(reference, reference + 0.5) 71 | 6.3704606032577304 72 | >>> si_sdr(reference, reference * 2 + 1) 73 | 6.3704606032577304 74 | >>> si_sdr([1., 0], [0., 0]) # never predict only zeros 75 | nan 76 | >>> si_sdr([reference, reference], [reference * 2 + 1, reference * 1 + 0.5]) 77 | array([6.3704606, 6.3704606]) 78 | :param reference: 79 | :param estimation: 80 | :param eps: 81 | """ 82 | 83 | reference_energy = torch.sum(reference ** 2, axis=-1, keepdims=True) 84 | 85 | # This is $\alpha$ after Equation (3) in [1]. 86 | optimal_scaling = torch.sum(reference * estimation, axis=-1, keepdims=True) / reference_energy + eps 87 | 88 | # This is $e_{\text{target}}$ in Equation (4) in [1]. 89 | projection = optimal_scaling * reference 90 | 91 | # This is $e_{\text{res}}$ in Equation (4) in [1]. 92 | noise = estimation - projection 93 | 94 | ratio = torch.sum(projection ** 2, axis=-1) / torch.sum(noise ** 2, axis=-1) + eps 95 | 96 | ratio = torch.mean(ratio) 97 | return 10 * torch.log10(ratio + eps) 98 | 99 | 100 | class rmse(torch.nn.Module): 101 | def __init__(self): 102 | super(rmse, self).__init__() 103 | 104 | def forward(self, y_true, y_pred): 105 | mse = torch.mean((y_pred - y_true) ** 2, axis=-1) 106 | rmse = torch.sqrt(mse + 1e-7) 107 | 108 | return torch.mean(rmse) 109 | 110 | 111 | 112 | ############################################################################ 113 | # MFCC (Mel Frequency Cepstral Coefficients) # 114 | ############################################################################ 115 | 116 | # based on a combination of this article: 117 | # http://practicalcryptography.com/miscellaneous/machine-learning/... 118 | # guide-mel-frequency-cepstral-coefficients-mfccs/ 119 | # and some of this code: 120 | # http://stackoverflow.com/questions/5835568/... 
121 | # how-to-get-mfcc-from-an-fft-on-a-signal 122 | 123 | # conversions between Mel scale and regular frequency scale 124 | def freqToMel(freq): 125 | return 1127.01048 * math.log(1 + freq / 700.0) 126 | 127 | 128 | def melToFreq(mel): 129 | return 700 * (math.exp(mel / 1127.01048) - 1) 130 | 131 | 132 | # generate Mel filter bank 133 | def melFilterBank(numCoeffs, fftSize=None): 134 | minHz = 0 135 | maxHz = cfg.fs / 2 # max Hz by Nyquist theorem 136 | if (fftSize is None): 137 | numFFTBins = cfg.win_len 138 | else: 139 | numFFTBins = int(fftSize / 2) + 1 140 | 141 | maxMel = freqToMel(maxHz) 142 | minMel = freqToMel(minHz) 143 | 144 | # we need (numCoeffs + 2) points to create (numCoeffs) filterbanks 145 | melRange = np.array(range(numCoeffs + 2)) 146 | melRange = melRange.astype(np.float32) 147 | 148 | # create (numCoeffs + 2) points evenly spaced between minMel and maxMel 149 | melCenterFilters = melRange * (maxMel - minMel) / (numCoeffs + 1) + minMel 150 | 151 | for i in range(numCoeffs + 2): 152 | # mel domain => frequency domain 153 | melCenterFilters[i] = melToFreq(melCenterFilters[i]) 154 | 155 | # frequency domain => FFT bins 156 | melCenterFilters[i] = math.floor(numFFTBins * melCenterFilters[i] / maxHz) 157 | 158 | # create matrix of filters (one row is one filter) 159 | filterMat = np.zeros((numCoeffs, numFFTBins)) 160 | 161 | # generate triangular filters (in frequency domain) 162 | for i in range(1, numCoeffs + 1): 163 | filter = np.zeros(numFFTBins) 164 | 165 | startRange = int(melCenterFilters[i - 1]) 166 | midRange = int(melCenterFilters[i]) 167 | endRange = int(melCenterFilters[i + 1]) 168 | 169 | for j in range(startRange, midRange): 170 | filter[j] = (float(j) - startRange) / (midRange - startRange) 171 | for j in range(midRange, endRange): 172 | filter[j] = 1 - ((float(j) - midRange) / (endRange - midRange)) 173 | 174 | filterMat[i - 1] = filter 175 | 176 | # return filterbank as matrix 177 | return filterMat 178 | 179 | 180 | 181 | ############################################################################ 182 | # Finally: a perceptual loss function (based on Mel scale) # 183 | ############################################################################ 184 | 185 | FFT_SIZE = cfg.fft_len 186 | 187 | # multi-scale MFCC distance 188 | MEL_SCALES = [16, 32, 64] # for LMS 189 | # PAM : MEL_SCALES = [32, 64] 190 | 191 | 192 | # given a (symbolic Theano) array of size M x WINDOW_SIZE 193 | # this returns an array M x N where each window has been replaced 194 | # by some perceptual transform (in this case, MFCC coeffs) 195 | def perceptual_transform(x): 196 | # precompute Mel filterbank: [FFT_SIZE x NUM_MFCC_COEFFS] 197 | MEL_FILTERBANKS = [] 198 | for scale in MEL_SCALES: 199 | filterbank_npy = melFilterBank(scale, FFT_SIZE).transpose() 200 | torch_filterbank_npy = torch.from_numpy(filterbank_npy).type(torch.FloatTensor) 201 | MEL_FILTERBANKS.append(torch_filterbank_npy.to(DEVICE)) 202 | 203 | transforms = [] 204 | # powerSpectrum = torch_dft_mag(x, DFT_REAL, DFT_IMAG)**2 205 | 206 | powerSpectrum = x.view(-1, FFT_SIZE // 2 + 1) 207 | powerSpectrum = 1.0 / FFT_SIZE * powerSpectrum 208 | 209 | for filterbank in MEL_FILTERBANKS: 210 | filteredSpectrum = torch.mm(powerSpectrum, filterbank) 211 | filteredSpectrum = torch.log(filteredSpectrum + 1e-7) 212 | transforms.append(filteredSpectrum) 213 | 214 | return transforms 215 | 216 | 217 | # perceptual loss function 218 | class perceptual_distance(torch.nn.Module): 219 | 220 | def __init__(self): 221 | 
super(perceptual_distance, self).__init__() 222 | 223 | def forward(self, y_true, y_pred): 224 | rmse_loss = rmse() 225 | # y_true = torch.reshape(y_true, (-1, WINDOW_SIZE)) 226 | # y_pred = torch.reshape(y_pred, (-1, WINDOW_SIZE)) 227 | 228 | pvec_true = perceptual_transform(y_true) 229 | pvec_pred = perceptual_transform(y_pred) 230 | 231 | distances = [] 232 | for i in range(0, len(pvec_true)): 233 | error = rmse_loss(pvec_pred[i], pvec_true[i]) 234 | error = error.unsqueeze(dim=-1) 235 | distances.append(error) 236 | distances = torch.cat(distances, axis=-1) 237 | 238 | loss = torch.mean(distances, axis=-1) 239 | return torch.mean(loss) 240 | 241 | 242 | get_mel_loss = perceptual_distance() 243 | 244 | 245 | def get_array_mel_loss(clean_array, est_array): 246 | array_mel_loss = 0 247 | for i in range(len(clean_array)): 248 | mel_loss = get_mel_loss(clean_array[i], est_array[i]) 249 | array_mel_loss += mel_loss 250 | 251 | avg_mel_loss = array_mel_loss / len(clean_array) 252 | return avg_mel_loss 253 | 254 | 255 | ############################################################################ 256 | # for pmsqe loss # 257 | ############################################################################ 258 | pmsqe_stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256)) 259 | pmsqe_loss = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt') 260 | -------------------------------------------------------------------------------- /tools_for_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import time 5 | import torch.nn.functional as F 6 | from scipy.signal import get_window 7 | import matplotlib.pylab as plt 8 | from pesq import pesq 9 | from pystoi import stoi 10 | 11 | 12 | ############################################################################ 13 | # for convolutional STFT # 14 | ############################################################################ 15 | def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): 16 | if win_type == 'None' or win_type is None: 17 | window = np.ones(win_len) 18 | else: 19 | window = get_window(win_type, win_len, fftbins=True) # **0.5 20 | 21 | N = fft_len 22 | fourier_basis = np.fft.rfft(np.eye(N))[:win_len] 23 | real_kernel = np.real(fourier_basis) 24 | imag_kernel = np.imag(fourier_basis) 25 | kernel = np.concatenate([real_kernel, imag_kernel], 1).T 26 | 27 | if invers: 28 | kernel = np.linalg.pinv(kernel).T 29 | 30 | kernel = kernel * window 31 | kernel = kernel[:, None, :] 32 | return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32)) 33 | 34 | 35 | class ConvSTFT(nn.Module): 36 | 37 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True): 38 | super(ConvSTFT, self).__init__() 39 | 40 | if fft_len == None: 41 | self.fft_len = np.int(2 ** np.ceil(np.log2(win_len))) 42 | else: 43 | self.fft_len = fft_len 44 | 45 | kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type) 46 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 47 | self.register_buffer('weight', kernel) 48 | self.feature_type = feature_type 49 | self.stride = win_inc 50 | self.win_len = win_len 51 | self.dim = self.fft_len 52 | 53 | def forward(self, inputs): 54 | if inputs.dim() == 2: 55 | inputs = torch.unsqueeze(inputs, 1) 56 | inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride]) 57 | outputs = 
F.conv1d(inputs, self.weight, stride=self.stride) 58 | 59 | if self.feature_type == 'complex': 60 | return outputs 61 | else: 62 | dim = self.dim // 2 + 1 63 | real = outputs[:, :dim, :] 64 | imag = outputs[:, dim:, :] 65 | mags = torch.sqrt(real ** 2 + imag ** 2) 66 | phase = torch.atan2(imag, real) 67 | return mags, phase 68 | 69 | 70 | class ConviSTFT(nn.Module): 71 | 72 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True): 73 | super(ConviSTFT, self).__init__() 74 | if fft_len == None: 75 | self.fft_len = np.int(2**np.ceil(np.log2(win_len))) 76 | else: 77 | self.fft_len = fft_len 78 | kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True) 79 | #self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 80 | self.register_buffer('weight', kernel) 81 | self.feature_type = feature_type 82 | self.win_type = win_type 83 | self.win_len = win_len 84 | self.stride = win_inc 85 | self.stride = win_inc 86 | self.dim = self.fft_len 87 | self.register_buffer('window', window) 88 | self.register_buffer('enframe', torch.eye(win_len)[:,None,:]) 89 | 90 | def forward(self, inputs, phase=None): 91 | """ 92 | inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags) 93 | phase: [B, N//2+1, T] (if not none) 94 | """ 95 | 96 | if phase is not None: 97 | real = inputs * torch.cos(phase) 98 | imag = inputs * torch.sin(phase) 99 | inputs = torch.cat([real, imag], 1) 100 | outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) 101 | 102 | # this is from torch-stft: https://github.com/pseeth/torch-stft 103 | t = self.window.repeat(1, 1, inputs.size(-1)) ** 2 104 | coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) 105 | outputs = outputs / (coff + 1e-8) 106 | # outputs = torch.where(coff == 0, outputs, outputs/coff) 107 | outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)] 108 | 109 | return outputs 110 | 111 | 112 | ############################################################################ 113 | # for complex rnn # 114 | ############################################################################ 115 | def get_casual_padding1d(): 116 | pass 117 | 118 | 119 | def get_casual_padding2d(): 120 | pass 121 | 122 | 123 | class cPReLU(nn.Module): 124 | 125 | def __init__(self, complex_axis=1): 126 | super(cPReLU, self).__init__() 127 | self.r_prelu = nn.PReLU() 128 | self.i_prelu = nn.PReLU() 129 | self.complex_axis = complex_axis 130 | 131 | def forward(self, inputs): 132 | real, imag = torch.chunk(inputs, 2, self.complex_axis) 133 | real = self.r_prelu(real) 134 | imag = self.i_prelu(imag) 135 | return torch.cat([real, imag], self.complex_axis) 136 | 137 | 138 | class NavieComplexLSTM(nn.Module): 139 | def __init__(self, input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False): 140 | super(NavieComplexLSTM, self).__init__() 141 | 142 | self.input_dim = input_size // 2 143 | self.rnn_units = hidden_size // 2 144 | self.real_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional, 145 | batch_first=False) 146 | self.imag_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional, 147 | batch_first=False) 148 | if bidirectional: 149 | bidirectional = 2 150 | else: 151 | bidirectional = 1 152 | if projection_dim is not None: 153 | self.projection_dim = projection_dim // 2 154 | self.r_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim) 155 | self.i_trans = 
nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
156 |         else:
157 |             self.projection_dim = None
158 | 
159 |     def forward(self, inputs):
160 |         if isinstance(inputs, list):
161 |             real, imag = inputs
162 |         elif isinstance(inputs, torch.Tensor):
163 |             real, imag = torch.chunk(inputs, 2, -1)  # split into real/imag halves along the last dim
164 |         r2r_out = self.real_lstm(real)[0]
165 |         r2i_out = self.imag_lstm(real)[0]
166 |         i2r_out = self.real_lstm(imag)[0]
167 |         i2i_out = self.imag_lstm(imag)[0]
168 |         real_out = r2r_out - i2i_out
169 |         imag_out = i2r_out + r2i_out
170 |         if self.projection_dim is not None:
171 |             real_out = self.r_trans(real_out)
172 |             imag_out = self.i_trans(imag_out)
173 |         # print(real_out.shape,imag_out.shape)
174 |         return [real_out, imag_out]
175 | 
176 |     def flatten_parameters(self):
177 |         self.imag_lstm.flatten_parameters()
178 |         self.real_lstm.flatten_parameters()
179 | 
180 | 
181 | def complex_cat(inputs, axis):
182 |     real, imag = [], []
183 |     for idx, data in enumerate(inputs):
184 |         r, i = torch.chunk(data, 2, axis)  # torch.chunk(x, n, dim): split x into n chunks along dimension dim
185 |         real.append(r)
186 |         imag.append(i)
187 |     real = torch.cat(real, axis)  # torch.cat: concatenate along the given axis
188 |     imag = torch.cat(imag, axis)
189 |     outputs = torch.cat([real, imag], axis)
190 |     return outputs
191 | 
192 | 
193 | class ComplexConv2d(nn.Module):
194 | 
195 |     def __init__(
196 |             self,
197 |             in_channels,
198 |             out_channels,
199 |             kernel_size=(1, 1),
200 |             stride=(1, 1),
201 |             padding=(0, 0),
202 |             dilation=1,
203 |             groups=1,
204 |             causal=True,
205 |             complex_axis=1,
206 |     ):
207 |         '''
208 |         in_channels: real+imag
209 |         out_channels: real+imag
210 |         kernel_size : input [B,C,D,T] kernel size in [D,T]
211 |         padding : input [B,C,D,T] padding in [D,T]
212 |         causal: if causal, pad only the left side of the time dimension;
213 |                 otherwise pad both sides
214 | 
215 |         '''
216 |         super(ComplexConv2d, self).__init__()
217 |         self.in_channels = in_channels // 2
218 |         self.out_channels = out_channels // 2
219 |         self.kernel_size = kernel_size
220 |         self.stride = stride
221 |         self.padding = padding
222 |         self.causal = causal
223 |         self.groups = groups
224 |         self.dilation = dilation
225 |         self.complex_axis = complex_axis
226 |         self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
227 |                                    padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
228 |         self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
229 |                                    padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
230 | 
231 |         nn.init.normal_(self.real_conv.weight.data, std=0.05)
232 |         nn.init.normal_(self.imag_conv.weight.data, std=0.05)
233 |         nn.init.constant_(self.real_conv.bias, 0.)
234 |         nn.init.constant_(self.imag_conv.bias, 0.)
235 | 236 | def forward(self, inputs): 237 | if self.padding[1] != 0 and self.causal: 238 | inputs = F.pad(inputs, [self.padding[1], 0, 0, 0]) 239 | else: 240 | inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0]) 241 | 242 | if self.complex_axis == 0: 243 | real = self.real_conv(inputs) 244 | imag = self.imag_conv(inputs) 245 | real2real, imag2real = torch.chunk(real, 2, self.complex_axis) 246 | real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis) 247 | 248 | else: 249 | if isinstance(inputs, torch.Tensor): 250 | real, imag = torch.chunk(inputs, 2, self.complex_axis) 251 | 252 | real2real = self.real_conv(real, ) 253 | imag2imag = self.imag_conv(imag, ) 254 | 255 | real2imag = self.imag_conv(real) 256 | imag2real = self.real_conv(imag) 257 | 258 | real = real2real - imag2imag 259 | imag = real2imag + imag2real 260 | out = torch.cat([real, imag], self.complex_axis) 261 | 262 | return out 263 | 264 | 265 | class ComplexConvTranspose2d(nn.Module): 266 | 267 | def __init__( 268 | self, 269 | in_channels, 270 | out_channels, 271 | kernel_size=(1, 1), 272 | stride=(1, 1), 273 | padding=(0, 0), 274 | output_padding=(0, 0), 275 | causal=False, 276 | complex_axis=1, 277 | groups=1 278 | ): 279 | ''' 280 | in_channels: real+imag 281 | out_channels: real+imag 282 | ''' 283 | super(ComplexConvTranspose2d, self).__init__() 284 | self.in_channels = in_channels // 2 285 | self.out_channels = out_channels // 2 286 | self.kernel_size = kernel_size 287 | self.stride = stride 288 | self.padding = padding 289 | self.output_padding = output_padding 290 | self.groups = groups 291 | 292 | self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride, 293 | padding=self.padding, output_padding=output_padding, groups=self.groups) 294 | self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride, 295 | padding=self.padding, output_padding=output_padding, groups=self.groups) 296 | self.complex_axis = complex_axis 297 | 298 | nn.init.normal_(self.real_conv.weight, std=0.05) 299 | nn.init.normal_(self.imag_conv.weight, std=0.05) 300 | nn.init.constant_(self.real_conv.bias, 0.) 301 | nn.init.constant_(self.imag_conv.bias, 0.) 
302 | 303 | def forward(self, inputs): 304 | 305 | if isinstance(inputs, torch.Tensor): 306 | real, imag = torch.chunk(inputs, 2, self.complex_axis) 307 | elif isinstance(inputs, tuple) or isinstance(inputs, list): 308 | real = inputs[0] 309 | imag = inputs[1] 310 | if self.complex_axis == 0: 311 | real = self.real_conv(inputs) 312 | imag = self.imag_conv(inputs) 313 | real2real, imag2real = torch.chunk(real, 2, self.complex_axis) 314 | real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis) 315 | 316 | else: 317 | if isinstance(inputs, torch.Tensor): 318 | real, imag = torch.chunk(inputs, 2, self.complex_axis) 319 | 320 | real2real = self.real_conv(real, ) 321 | imag2imag = self.imag_conv(imag, ) 322 | 323 | real2imag = self.imag_conv(real) 324 | imag2real = self.real_conv(imag) 325 | 326 | real = real2real - imag2imag 327 | imag = real2imag + imag2real 328 | out = torch.cat([real, imag], self.complex_axis) 329 | 330 | return out 331 | 332 | 333 | # Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch 334 | # from https://github.com/IMLHF/SE_DCUNet/blob/f28bf1661121c8901ad38149ea827693f1830715/models/layers/complexnn.py#L55 335 | class ComplexBatchNorm(torch.nn.Module): 336 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 337 | track_running_stats=True, complex_axis=1): 338 | super(ComplexBatchNorm, self).__init__() 339 | self.num_features = num_features // 2 340 | self.eps = eps 341 | self.momentum = momentum 342 | self.affine = affine 343 | self.track_running_stats = track_running_stats 344 | 345 | self.complex_axis = complex_axis 346 | 347 | if self.affine: 348 | self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features)) 349 | self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features)) 350 | self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features)) 351 | self.Br = torch.nn.Parameter(torch.Tensor(self.num_features)) 352 | self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features)) 353 | else: 354 | self.register_parameter('Wrr', None) 355 | self.register_parameter('Wri', None) 356 | self.register_parameter('Wii', None) 357 | self.register_parameter('Br', None) 358 | self.register_parameter('Bi', None) 359 | 360 | if self.track_running_stats: 361 | self.register_buffer('RMr', torch.zeros(self.num_features)) 362 | self.register_buffer('RMi', torch.zeros(self.num_features)) 363 | self.register_buffer('RVrr', torch.ones(self.num_features)) 364 | self.register_buffer('RVri', torch.zeros(self.num_features)) 365 | self.register_buffer('RVii', torch.ones(self.num_features)) 366 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 367 | else: 368 | self.register_parameter('RMr', None) 369 | self.register_parameter('RMi', None) 370 | self.register_parameter('RVrr', None) 371 | self.register_parameter('RVri', None) 372 | self.register_parameter('RVii', None) 373 | self.register_parameter('num_batches_tracked', None) 374 | self.reset_parameters() 375 | 376 | def reset_running_stats(self): 377 | if self.track_running_stats: 378 | self.RMr.zero_() 379 | self.RMi.zero_() 380 | self.RVrr.fill_(1) 381 | self.RVri.zero_() 382 | self.RVii.fill_(1) 383 | self.num_batches_tracked.zero_() 384 | 385 | def reset_parameters(self): 386 | self.reset_running_stats() 387 | if self.affine: 388 | self.Br.data.zero_() 389 | self.Bi.data.zero_() 390 | self.Wrr.data.fill_(1) 391 | self.Wri.data.uniform_(-.9, +.9) # W will be positive-definite 392 | self.Wii.data.fill_(1) 393 | 394 | def _check_input_dim(self, xr, xi): 395 | 
assert (xr.shape == xi.shape)
396 |         assert (xr.size(1) == self.num_features)
397 | 
398 |     def forward(self, inputs):
399 |         # self._check_input_dim(xr, xi)
400 | 
401 |         xr, xi = torch.chunk(inputs, 2, dim=self.complex_axis)
402 |         exponential_average_factor = 0.0
403 | 
404 |         if self.training and self.track_running_stats:
405 |             self.num_batches_tracked += 1
406 |             if self.momentum is None:  # use cumulative moving average
407 |                 exponential_average_factor = 1.0 / self.num_batches_tracked.item()
408 |             else:  # use exponential moving average
409 |                 exponential_average_factor = self.momentum
410 | 
411 |         #
412 |         # NOTE: The precise meaning of the "training flag" is:
413 |         #       True:  Normalize using batch statistics, update running statistics
414 |         #              if they are being collected.
415 |         #       False: Normalize using running statistics, ignore batch statistics.
416 |         #
417 |         training = self.training or not self.track_running_stats
418 |         redux = [i for i in reversed(range(xr.dim())) if i != 1]
419 |         vdim = [1] * xr.dim()
420 |         vdim[1] = xr.size(1)
421 | 
422 |         #
423 |         # Mean M Computation and Centering
424 |         #
425 |         # Includes running mean update if training and running.
426 |         #
427 |         if training:
428 |             Mr, Mi = xr, xi
429 |             for d in redux:
430 |                 Mr = Mr.mean(d, keepdim=True)
431 |                 Mi = Mi.mean(d, keepdim=True)
432 |             if self.track_running_stats:
433 |                 self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
434 |                 self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
435 |         else:
436 |             Mr = self.RMr.view(vdim)
437 |             Mi = self.RMi.view(vdim)
438 |         xr, xi = xr - Mr, xi - Mi
439 | 
440 |         #
441 |         # Variance Matrix V Computation
442 |         #
443 |         # Includes epsilon numerical stabilizer/Tikhonov regularizer.
444 |         # Includes running variance update if training and running.
445 |         #
446 |         if training:
447 |             Vrr = xr * xr
448 |             Vri = xr * xi
449 |             Vii = xi * xi
450 |             for d in redux:
451 |                 Vrr = Vrr.mean(d, keepdim=True)
452 |                 Vri = Vri.mean(d, keepdim=True)
453 |                 Vii = Vii.mean(d, keepdim=True)
454 |             if self.track_running_stats:
455 |                 self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
456 |                 self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
457 |                 self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
458 |         else:
459 |             Vrr = self.RVrr.view(vdim)
460 |             Vri = self.RVri.view(vdim)
461 |             Vii = self.RVii.view(vdim)
462 |         Vrr = Vrr + self.eps
463 |         Vri = Vri
464 |         Vii = Vii + self.eps
465 | 
466 |         #
467 |         # Matrix Inverse Square Root U = V^-0.5
468 |         #
469 |         # sqrt of a 2x2 matrix,
470 |         # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
471 |         tau = Vrr + Vii  # tau = tr(V)
472 |         delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)  # delta = det(V) = Vrr*Vii - Vri**2
473 |         s = delta.sqrt()
474 |         t = (tau + 2 * s).sqrt()
475 | 
476 |         # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
477 |         rst = (s * t).reciprocal()
478 |         Urr = (s + Vii) * rst
479 |         Uii = (s + Vrr) * rst
480 |         Uri = (- Vri) * rst
481 | 
482 |         #
483 |         # Optionally left-multiply U by affine weights W to produce combined
484 |         # weights Z, left-multiply the inputs by Z, then optionally bias them.
485 |         #
486 |         # y = Zx + B
487 |         # y = WUx + B
488 |         # y = [Wrr Wri][Urr Uri] [xr] + [Br]
489 |         #     [Wri Wii][Uri Uii] [xi]   [Bi]
490 |         #
491 |         if self.affine:
492 |             Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
493 |             Zrr = (Wrr * Urr) + (Wri * Uri)
494 |             Zri = (Wrr * Uri) + (Wri * Uii)
495 |             Zir = (Wri * Urr) + (Wii * Uri)
496 |             Zii = (Wri * Uri) + (Wii * Uii)
497 |         else:
498 |             Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
499 | 
500 |         yr = (Zrr * xr) + (Zri * xi)
501 |         yi = (Zir * xr) + (Zii * xi)
502 | 
503 |         if self.affine:
504 |             yr = yr + self.Br.view(vdim)
505 |             yi = yi + self.Bi.view(vdim)
506 | 
507 |         outputs = torch.cat([yr, yi], self.complex_axis)
508 |         return outputs
509 | 
510 |     def extra_repr(self):
511 |         return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
512 |                'track_running_stats={track_running_stats}'.format(**self.__dict__)
513 | 
514 | 
527 | ############################################################################
528 | #                          for data normalization                          #
529 | ############################################################################
530 | # get mu and sig
531 | def get_mu_sig(data):
532 |     """Compute the mean and standard deviation vectors of the input data.
533 | 
534 |     Returns:
535 |         mu: mean vector (#dim by one)
536 |         sig: standard deviation vector (#dim by one)
537 |     """
538 |     # Initialize array.
539 |     data_num = len(data)
540 |     mu_utt = []
541 |     tmp_utt = []
542 |     for n in range(data_num):
543 |         dim = len(data[n])
544 |         mu_utt_tmp = np.zeros(dim)
545 |         mu_utt.append(mu_utt_tmp)
546 | 
547 |         tmp_utt_tmp = np.zeros(dim)
548 |         tmp_utt.append(tmp_utt_tmp)
549 | 
550 | 
551 |     # Get mean.
552 |     for n in range(data_num):
553 |         mu_utt[n] = np.mean(data[n], 0)
554 |     mu = mu_utt
555 | 
556 |     # Get standard deviation.
557 |     for n in range(data_num):
558 |         tmp_utt[n] = np.mean(np.square(data[n] - mu[n]), 0)
559 |     sig = np.sqrt(tmp_utt)
560 | 
561 |     # Assign unit variance (elementwise: sig[n] is a per-dimension vector).
562 |     for n in range(len(sig)):
563 |         sig[n] = np.where(sig[n] < 1e-5, 1.0, sig[n])
564 | 
565 |     return np.float16(mu), np.float16(sig)
566 | 
567 | 
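# Usage sketch (illustrative only; the feature shapes are assumptions).
# get_mu_sig expects a list of [frames, dim] arrays and returns per-utterance,
# per-dimension statistics for zero-mean/unit-variance scaling:
#   feats = [np.random.randn(120, 257), np.random.randn(98, 257)]
#   mu, sig = get_mu_sig(feats)
#   normed = [(f - m) / s for f, m, s in zip(feats, mu, sig)]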
568 | def get_statistics_inp(inp):
569 |     """Get statistical parameters of input data.
570 | 
571 |     Args:
572 |         inp: input data
573 | 
574 |     Returns:
575 |         mu_inp: mean vector of input data
576 |         sig_inp: standard deviation vector of input data
577 |     """
578 | 
579 |     mu_inp, sig_inp = get_mu_sig(inp)
580 | 
581 |     return mu_inp, sig_inp
582 | 
583 | 
584 | ############################################################################
585 | #                               for scores                                 #
586 | ############################################################################
587 | def cal_pesq(dirty_wavs, clean_wavs):
588 |     pesq_scores = []
589 |     for i in range(len(dirty_wavs)):
590 |         pesq_score = pesq(cfg.FS, clean_wavs[i], dirty_wavs[i], "wb")
591 |         pesq_scores.append(pesq_score)
592 |     return pesq_scores
593 | 
594 | 
595 | def cal_stoi(dirty_wavs, clean_wavs):
596 |     stoi_scores = []
597 |     for i in range(len(dirty_wavs)):
598 |         stoi_score = stoi(clean_wavs[i], dirty_wavs[i], cfg.FS, extended=False)
599 |         stoi_scores.append(stoi_score)
600 |     return stoi_scores
601 | 
602 | 
603 | ############################################################################
604 | #                        for plotting the samples                          #
605 | ############################################################################
606 | def hann_window(win_samp):
607 |     tmp = np.arange(1, win_samp + 1, 1.0, dtype=np.float64)
608 |     window = 0.5 - 0.5 * np.cos((2.0 * np.pi * tmp) / (win_samp + 1))
609 |     return np.float32(window)
610 | 
611 | 
612 | def fig2np(fig):
613 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated for binary data
614 |     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
615 |     return data
616 | 
617 | 
618 | def plot_spectrogram_to_numpy(input_wav, fs, n_fft, n_overlap, win, mode, clim, label):
619 |     # cuda to cpu
620 |     input_wav = input_wav.cpu().detach().numpy()
621 | 
622 |     fig, ax = plt.subplots(figsize=(12, 3))
623 | 
624 |     if mode == 'phase':
625 |         pxx, freq, t, cax = plt.specgram(input_wav, NFFT=int(n_fft), Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet',
626 |                                          mode=mode)
627 |     else:
628 |         pxx, freq, t, cax = plt.specgram(input_wav, NFFT=int(n_fft), Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet')
629 | 
630 |     plt.xlabel('Time (s)')
631 |     plt.ylabel('Frequency (Hz)')
632 |     plt.tight_layout()
633 |     plt.clim(clim)
634 | 
635 |     if label is None:
636 |         fig.colorbar(cax)
637 |     else:
638 |         fig.colorbar(cax, label=label)
639 | 
640 |     fig.canvas.draw()
641 |     data = fig2np(fig)
642 |     plt.close()
643 |     return data
644 | 
645 | 
646 | def plot_mask_to_numpy(mask, fs, n_fft, n_overlap, win, clim1, clim2, cmap):
647 |     frame_num = mask.shape[0]
648 |     shift_length = n_overlap
649 |     frame_length = n_fft
650 |     signal_length = frame_num * shift_length + frame_length
651 | 
652 |     xt = np.arange(0, np.floor(10 * signal_length / fs) / 10, step=0.5) / (signal_length / fs) * frame_num + 1e-8
653 |     yt = (n_fft / 2) / (fs / 1000 / 2) * np.arange(0, (fs / 1000 / 2) + 1)
654 | 
655 |     fig, ax = plt.subplots(figsize=(12, 3))
656 |     im = ax.imshow(np.transpose(mask), aspect='auto', origin='lower', interpolation='none', cmap=cmap)
657 | 
658 |     plt.xlabel('Time (s)')
659 |     plt.ylabel('Frequency (kHz)')
660 |     plt.xticks(xt, np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5))
661 |     plt.yticks(yt, np.int16(np.linspace(0, int((fs / 1000) / 2), len(yt))))
662 |     plt.tight_layout()
663 |     plt.colorbar(im, ax=ax)
664 |     im.set_clim(clim1, clim2)
665 | 
666 |     fig.canvas.draw()
667 |     data = fig2np(fig)
668 |     plt.close()
669 |     return data
670 | 
671 | 
672 | def plot_error_to_numpy(estimated, target, fs, n_fft, n_overlap, win, mode, clim1, clim2, label):
673 |     fig, ax = plt.subplots(figsize=(12, 3))
674 |     if mode is None:
675 |         pxx1, freq, t, cax = plt.specgram(estimated, NFFT=n_fft, Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet')
676 |         pxx2, freq, t, cax = plt.specgram(target, NFFT=n_fft, Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet')
677 |         im = ax.imshow(10 * np.log10(pxx1) - 10 * np.log10(pxx2), aspect='auto', origin='lower', interpolation='none',
678 |                        cmap='jet')
679 |     else:
680 |         pxx1, freq, t, cax = plt.specgram(estimated, NFFT=n_fft, Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet',
681 |                                           mode=mode)
682 |         pxx2, freq, t, cax = plt.specgram(target, NFFT=n_fft, Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet',
683 |                                           mode=mode)
684 |         im = ax.imshow(pxx1 - pxx2, aspect='auto', origin='lower', interpolation='none', cmap='jet')
685 | 
686 |     frame_num = pxx1.shape[1]
687 |     shift_length = n_overlap
688 |     frame_length = n_fft
689 |     signal_length = frame_num * shift_length + frame_length
690 | 
691 |     xt = np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5) / (signal_length / fs) * frame_num
692 |     yt = (n_fft / 2) / (fs / 1000 / 2) * np.arange(0, (fs / 1000 / 2) + 1)
693 | 
694 |     plt.xlabel('Time (s)')
695 |     plt.ylabel('Frequency (kHz)')
696 |     plt.xticks(xt, np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5))
697 |     plt.yticks(yt, np.int16(np.linspace(0, int((fs / 1000) / 2), len(yt))))
698 |     plt.tight_layout()
699 |     plt.colorbar(im, ax=ax, label=label)
700 |     im.set_clim(clim1, clim2)
701 | 
702 |     fig.canvas.draw()
703 |     data = fig2np(fig)
704 |     plt.close()
705 |     return data
706 | 
707 | 
708 | ############################################################################
709 | #                               for run.py                                 #
710 | ############################################################################
711 | def near_avg_index(array):
712 |     array_mean = np.mean(array)
713 | 
714 |     distance_arr = []
715 |     for i in range(len(array)):
716 |         val = array[i]
717 |         distance = abs(array_mean - val)
718 |         distance_arr.append(distance)
719 | 
720 |     index = distance_arr.index(min(distance_arr))
721 |     return index
722 | 
723 | 
724 | def max_index(array):
725 |     array_max = np.max(array)
726 | 
727 |     for i in range(len(array)):
728 |         val = array[i]
729 |         if val == array_max:
730 |             index = i
731 |     return index
732 | 
733 | 
734 | def min_index(array):
735 |     array_min = np.min(array)
736 | 
737 |     for i in range(len(array)):
738 |         val = array[i]
739 |         if val == array_min:
740 |             index = i
741 |     return index
742 | 
743 | 
744 | class Bar(object):
745 |     def __init__(self, dataloader):
746 |         if not hasattr(dataloader, 'dataset'):
747 |             raise ValueError('Attribute `dataset` does not exist in the dataloader.')
748 |         if not hasattr(dataloader, 'batch_size'):
749 |             raise ValueError('Attribute `batch_size` does not exist in the dataloader.')
750 | 
751 |         self.dataloader = dataloader
752 |         self.iterator = iter(dataloader)
753 |         self.dataset = dataloader.dataset
754 |         self.batch_size = dataloader.batch_size
755 |         self._idx = 0
756 |         self._batch_idx = 0
757 |         self._time = []
758 |         self._DISPLAY_LENGTH = 50
759 | 
760 |     def __len__(self):
761 |         return len(self.dataloader)
762 | 
763 |     def __iter__(self):
764 |         return self
765 | 
766 |     def __next__(self):
767 |         if len(self._time) < 2:
768 |             self._time.append(time.time())
769 | 
770 |         self._batch_idx += self.batch_size
771 |         if self._batch_idx > len(self.dataset):
772 |             self._batch_idx = len(self.dataset)
773 | 
774 |         try:
775 |             batch = next(self.iterator)
776 |             self._display()
777 |         except StopIteration:
778 |             raise StopIteration()
779 | 
780 |         self._idx += 1
781 |         if self._idx >= len(self.dataloader):
782 |             self._reset()
783 | 
784 |         return batch
785 | 
786 |     def _display(self):
787 |         if len(self._time) > 1:
788 |             t = (self._time[-1] - self._time[-2])
789 |             eta = t * (len(self.dataloader) - self._idx)
790 |         else:
791 |             eta = 0
792 | 
793 |         rate = self._idx / len(self.dataloader)
794 |         len_bar = int(rate * self._DISPLAY_LENGTH)
795 |         bar = ('=' * len_bar + '>').ljust(self._DISPLAY_LENGTH, '.')
796 |         idx = str(self._batch_idx).rjust(len(str(len(self.dataset))), ' ')
797 | 
798 |         tmpl = '\r{}/{}: [{}] - ETA {:.1f}s'.format(
799 |             idx,
800 |             len(self.dataset),
801 |             bar,
802 |             eta
803 |         )
804 |         print(tmpl, end='')
805 |         if self._batch_idx == len(self.dataset):
806 |             print()
807 | 
808 |     def _reset(self):
809 |         self._idx = 0
810 |         self._batch_idx = 0
811 |         self._time = []
812 | 
813 | 
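# Usage sketch (illustrative only): Bar wraps a torch DataLoader and prints an
# inline progress bar with an ETA while iterating, exactly as train.py uses it below:
#   loader = torch.utils.data.DataLoader(dataset, batch_size=16)
#   for inputs, labels in Bar(loader):
#       ...  # one training step per batch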
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Run the trainer and tester
3 | """
4 | import torch
5 | import numpy as np
6 | from scipy.io.wavfile import write as wav_write
7 | from tools_for_model import near_avg_index, max_index, min_index, Bar, cal_pesq, cal_stoi
8 | from config import fs, info, mode
9 | 
10 | 
11 | def model_train(model, optimizer, train_loader, epoch, DEVICE):
12 |     # initialization
13 |     train_loss = 0
14 |     batch_num = 0
15 | 
16 |     # train
17 |     model.train()
18 |     for inputs, labels in Bar(train_loader):
19 |         batch_num += 1
20 | 
21 |         # to cuda
22 |         inputs = inputs.float().to(DEVICE)
23 |         labels = labels.float().to(DEVICE)
24 | 
25 |         _, _, real_spec, img_spec, outputs = model(inputs)
26 |         loss = model.loss(outputs, labels, real_spec, img_spec)
27 |         # loss = model.pmsqe_loss(labels, outputs)
28 | 
29 |         optimizer.zero_grad()
30 |         loss.backward()
31 |         optimizer.step()
32 | 
33 |         train_loss += loss.item()  # accumulate as a float so no autograd graph is retained
34 |     train_loss /= batch_num
35 | 
36 |     return train_loss
37 | 
38 | 
39 | def model_validate(model, validation_loader, dir_to_save, writer, epoch, DEVICE):
40 |     # initialization
41 |     batch_num = 0
42 |     validation_loss = 0
43 |     avg_pesq = 0
44 |     avg_stoi = 0
45 | 
46 |     all_batch_input = []
47 |     all_batch_label = []
48 |     all_batch_output = []
49 |     all_batch_real_spec = []
50 |     all_batch_img_spec = []
51 |     all_batch_pesq = []
52 | 
53 |     f_pesq = open(dir_to_save + '/pesq_epoch_' + '%d' % epoch, 'a')
54 |     f_stoi = open(dir_to_save + '/stoi_epoch_' + '%d' % epoch, 'a')
55 | 
56 |     model.eval()
57 | 
58 |     with torch.no_grad():
59 |         for inputs, labels in Bar(validation_loader):
60 |             batch_num += 1
61 | 
62 |             # to cuda
63 |             inputs = inputs.float().to(DEVICE)
64 |             labels = labels.float().to(DEVICE)
65 | 
66 |             mask_real, mask_imag, real_spec, img_spec, outputs = model(inputs)
67 |             loss = model.loss(outputs, labels, real_spec, img_spec)
68 | 
69 |             # loss = model.pmsqe_loss(labels, outputs)
70 | 
71 |             # estimate the output speech with pesq and stoi
72 |             # save pesq & stoi score at each epoch
73 |             estimated_wavs = outputs.cpu().detach().numpy()
74 |             clean_wavs = labels.cpu().detach().numpy()
75 | 
76 |             pesq = cal_pesq(estimated_wavs, clean_wavs)
77 |             stoi = cal_stoi(estimated_wavs, clean_wavs)
78 | 
79 |             # pesq: 0.1 better / stoi: 0.01 better
80 |             for i in range(len(pesq)):
81 |                 f_pesq.write('{:.6f}\n'.format(pesq[i]))
82 |                 f_stoi.write('{:.4f}\n'.format(stoi[i]))
83 | 
84 |             # reshape for sum
85 |             pesq = np.reshape(pesq, (1, -1))
86 |             stoi = np.reshape(stoi, (1, -1))
87 | 
88 |             avg_pesq += sum(pesq[0]) / len(inputs)
89 |             avg_stoi += sum(stoi[0]) / len(inputs)
90 | 
91 |             if epoch % 10 == 0:
92 |                 # all batch data array
93 |                 all_batch_input.extend(inputs)
94 |                 all_batch_label.extend(labels)
95 |                 all_batch_output.extend(outputs)
96 |                 all_batch_real_spec.extend(mask_real)
97 |                 all_batch_img_spec.extend(mask_imag)
98 |                 all_batch_pesq.extend(pesq[0])
99 | 
100 |             validation_loss += loss.item()
101 | 
102 |         # save the samples to tensorboard
103 |         if epoch % 10 == 0:
104 |             all_batch_pesq = np.reshape(all_batch_pesq, (-1, 1))
105 | 
106 |             # find the best & worst pesq model
107 |             max_pesq_index = max_index(all_batch_pesq)
108 |             min_pesq_index = min_index(all_batch_pesq)
109 | 
110 |             # find the avg pesq model
111 |             avg_pesq_index = near_avg_index(all_batch_pesq)
112 | 
113 |             # save the samples to tensorboard
114 |             # the best pesq
115 |             writer.save_samples_we_want('max_pesq', all_batch_input[max_pesq_index], all_batch_label[max_pesq_index],
116 |                                         all_batch_output[max_pesq_index], epoch)
117 |             # the worst pesq
118 |             writer.save_samples_we_want('min_pesq', all_batch_input[min_pesq_index], all_batch_label[min_pesq_index],
119 |                                         all_batch_output[min_pesq_index], epoch)
120 |             # the avg pesq
121 |             writer.save_samples_we_want('avg_pesq', all_batch_input[avg_pesq_index], all_batch_label[avg_pesq_index],
122 |                                         all_batch_output[avg_pesq_index], epoch)
123 | 
124 |             # save the same sample
125 |             clip_num = 10
126 |             writer.save_samples_we_want('n{}_sample'.format(clip_num), all_batch_input[clip_num], all_batch_label[clip_num],
127 |                                         all_batch_output[clip_num], epoch)
128 | 
129 |         validation_loss /= batch_num
130 |         avg_pesq /= batch_num
131 |         avg_stoi /= batch_num
132 | 
133 |         # save average score
134 |         f_pesq.write('Avg: {:.6f}\n'.format(avg_pesq))
135 |         f_stoi.write('Avg: {:.4f}\n'.format(avg_stoi))
136 | 
137 |     f_pesq.close()
138 |     f_stoi.close()
139 |     return validation_loss, avg_pesq, avg_stoi
140 | 
141 | 
142 | def model_test(noise_type, snr, model, test_loader, dir_to_save, DEVICE):
143 |     model.eval()
144 |     with torch.no_grad():
145 |         # initialization
146 |         batch_num = 0
147 |         test_loss = 0
148 |         avg_pesq = 0
149 |         avg_stoi = 0
150 | 
151 |         all_batch_input = []
152 |         all_batch_label = []
153 |         all_batch_output = []
154 |         all_batch_real_spec = []
155 |         all_batch_img_spec = []
156 |         all_batch_pesq = []
157 | 
158 |         # f_pesq = open(dir_to_save + '/test_pesq_epoch{}_{}_{}dB'
159 |         #               .format(min_index + 1, noise_type, snr), 'a')
160 |         # f_stoi = open(dir_to_save + '/test_stoi_epoch{}_{}_{}dB'
161 |         #               .format(min_index + 1, noise_type, snr), 'a')
162 |         for inputs, labels in Bar(test_loader):
163 |             batch_num += 1
164 | 
165 |             # to cuda
166 |             inputs = inputs.float().to(DEVICE)
167 |             labels = labels.float().to(DEVICE)
168 | 
169 |             mask_real, mask_imag, real_spec, img_spec, outputs = model(inputs)
170 |             loss = model.loss(outputs, labels, real_spec, img_spec)
171 |             # loss = model.pmsqe_loss(labels, outputs)
172 |             # estimate the output speech with pesq and stoi
173 |             # save pesq & stoi score at each epoch
174 |             # [18480, 1]
175 |             estimated_wavs = outputs.cpu().detach().numpy()
176 |             clean_wavs = labels.cpu().detach().numpy()
177 | 
178 |             pesq = cal_pesq(estimated_wavs, clean_wavs)
179 |             stoi = cal_stoi(estimated_wavs, clean_wavs)
180 | 
181 |             # # pesq: 0.1 better / stoi: 0.01 better
182 |             # for i in range(len(pesq)):
183 |             #     f_pesq.write('{:.6f}\n'.format(pesq[i]))
184 |             #     f_stoi.write('{:.4f}\n'.format(stoi[i]))
185 | 
186 |             test_loss += loss.item()
187 | 
188 |             # reshape for sum
189 |             pesq = np.reshape(pesq, (1, -1))
190 |             stoi = np.reshape(stoi, (1, -1))
191 | 
192 |             avg_pesq += sum(pesq[0]) / len(inputs)
193 |             avg_stoi += sum(stoi[0]) / len(inputs)
194 | 
195 |             # all batch data array
196 |             all_batch_input.extend(inputs)
197 |             all_batch_label.extend(labels)
198 |             all_batch_output.extend(outputs)
199 |             all_batch_real_spec.extend(mask_real)
200 |             all_batch_img_spec.extend(mask_imag)
201 |             all_batch_pesq.extend(pesq[0])
202 | 
203 |         # find the best & worst pesq model
204 |         max_pesq_index = all_batch_pesq.index(max(all_batch_pesq))
205 |         min_pesq_index = all_batch_pesq.index(min(all_batch_pesq))
206 | 
207 |         test_loss /= batch_num
208 |         avg_pesq /= batch_num
209 |         avg_stoi /= batch_num
210 | 
211 |         max_pesq = all_batch_pesq[max_pesq_index]
212 |         min_pesq = all_batch_pesq[min_pesq_index]
213 | 
214 |         # save average score
215 |         # f_pesq.write('Max: {:.6f} | Min: {:.6f} | Avg: {:.6f}\n'.format(max_pesq, min_pesq, avg_pesq))
216 |         # f_stoi.write('Avg: {:.4f}\n'.format(avg_stoi))
217 |         # f_pesq.close()
218 |         # f_stoi.close()
219 |         return test_loss, avg_pesq, avg_stoi
220 | 
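# Interface note (a summary of the calls above; the model classes live in model.py):
# model_train/model_validate/model_test assume model(inputs) returns
# (mask_real, mask_imag, real_spec, img_spec, outputs) and that the model
# exposes loss(outputs, labels, real_spec, img_spec).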
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Interface for train
3 | """
4 | 
5 | import os
6 | import time
7 | import torch
8 | import shutil
9 | import numpy as np
10 | import config as cfg
11 | from train import model_train, model_validate, model_test
12 | from dataloader import create_dataloader, create_dataloader_for_test
13 | from model import DCCRN, DCUNET, DCCRN_direct, DCCRN_no_skip
14 | from write_on_tensorboard import Writer
15 | 
16 | 
17 | ###############################################################################
18 | #                         Helper function definition                          #
19 | ###############################################################################
20 | # Write training-related parameters into the log file.
21 | def write_status_to_log_file(fp, total_parameters):
23 |     fp.write('%d-%d-%d %d:%d:%d\n' %
24 |              (time.localtime().tm_year, time.localtime().tm_mon,
25 |               time.localtime().tm_mday, time.localtime().tm_hour,
26 |               time.localtime().tm_min, time.localtime().tm_sec))
27 |     fp.write('mode : %s_%s\n' % (cfg.mode, cfg.info))
28 |     fp.write('learning rate : %g\n' % cfg.learning_rate)
29 |     fp.write('total params : %d (%.2f M, %.2f MBytes)\n' %
30 |              (total_parameters,
31 |               total_parameters / 1000000.0,
32 |               total_parameters * 4.0 / 1000000.0))
33 | 
34 | 
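# Note: calculate_total_params below is equivalent to the one-liner
#   total_parameters = sum(p.numel() for p in our_model.parameters())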
35 | # Calculate the total number of network parameters.
36 | def calculate_total_params(our_model):
37 |     total_parameters = 0
38 |     for variable in our_model.parameters():
39 |         shape = variable.size()
40 |         variable_parameters = 1
41 |         for dim in shape:
42 |             variable_parameters *= dim
43 |         total_parameters += variable_parameters
44 | 
45 |     return total_parameters
46 | 
47 | 
48 | ###############################################################################
49 | #                          Parameter Initialization                           #
50 | ###############################################################################
51 | print('***********************************************************')
52 | print('*    Python library for DNN-based speech enhancement     *')
53 | print('*                   using Pytorch API                    *')
54 | print('***********************************************************')
55 | 
56 | # Set device (fall back to CPU when CUDA is unavailable)
57 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58 | 
59 | # Set model
60 | if cfg.mode == 'DCCRN':
61 |     model = DCCRN(rnn_units=cfg.rnn_units, masking_mode=cfg.masking_mode, use_clstm=cfg.use_clstm,
62 |                   kernel_num=cfg.kernel_num).to(DEVICE)
63 | elif cfg.mode == 'DCUNET':
64 |     model = DCUNET(masking_mode=cfg.masking_mode, kernel_num=cfg.kernel_num).to(DEVICE)
65 | elif cfg.mode == 'DCCRN_direct':
66 |     model = DCCRN_direct(rnn_units=cfg.rnn_units, use_clstm=cfg.use_clstm, kernel_num=cfg.kernel_num).to(DEVICE)
67 | 
68 | ###############################################################################
69 | #                     Set optimizer and learning rate                         #
70 | ###############################################################################
71 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
72 | total_params = calculate_total_params(model)
73 | 
74 | ###############################################################################
75 | #                          Confirm model information                          #
76 | ###############################################################################
77 | print('%d-%d-%d %d:%d:%d\n' %
78 |       (time.localtime().tm_year, time.localtime().tm_mon,
79 |        time.localtime().tm_mday, time.localtime().tm_hour,
80 |        time.localtime().tm_min, time.localtime().tm_sec))
81 | print('mode : %s_%s\n' % (cfg.mode, cfg.info))
82 | print('learning rate : %g\n' % cfg.learning_rate)
83 | print('total params : %d (%.2f M, %.2f MBytes)\n' %
84 |       (total_params,
85 |        total_params / 1000000.0,
86 |        total_params * 4.0 / 1000000.0))
87 | 
88 | ###############################################################################
89 | #                              Create Dataloader                              #
90 | ###############################################################################
93 | 
94 | train_loader = create_dataloader(mode='train')
95 | validation_loader = create_dataloader(mode='valid')
96 | 
97 | ###############################################################################
98 | #                    Set a log file to store progress.                        #
99 | #             Set a hps file to store hyper-parameters information.          #
100 | ###############################################################################
101 | # Load the checkpoint
102 | if cfg.chkpt_path is not None:
103 |     print('Resuming from checkpoint: %s' % cfg.chkpt_path)
104 | 
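    # Note (assumes cfg.chkpt_path and cfg.chkpt_model are set in config.py):
    # the checkpoint loaded here is the dict saved by the training loop below,
    # holding the 'model' and 'optimizer' state_dicts plus the last finished 'epoch'.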
105 |     # Set a log file to store progress.
106 |     dir_to_save = cfg.job_dir + cfg.chkpt_model
107 |     dir_to_logs = cfg.logs_dir + cfg.chkpt_model
108 | 
109 |     checkpoint = torch.load(cfg.chkpt_path)
110 |     model.load_state_dict(checkpoint['model'])
111 |     optimizer.load_state_dict(checkpoint['optimizer'])
112 |     epoch_start_idx = checkpoint['epoch'] + 1
113 |     mse_vali_total = np.load(str(dir_to_save + '/mse_vali_total.npy'))
114 |     if len(mse_vali_total) < cfg.max_epochs:
115 |         plus = cfg.max_epochs - len(mse_vali_total)
116 |         mse_vali_total = np.concatenate((mse_vali_total, np.zeros(plus)), 0)
117 | else:
118 |     print('Starting new training run')
119 |     epoch_start_idx = 1
120 |     mse_vali_total = np.zeros(cfg.max_epochs)
121 | 
122 |     # Set a log file to store progress (only a fresh run gets a new dated directory; resuming keeps the checkpoint's directory above).
123 |     dir_to_save = str(cfg.job_dir) + '%d.%d' % (time.localtime().tm_mon, time.localtime().tm_mday) \
124 |                   + '_%s' % cfg.mode + '_%s' % cfg.info
125 |     dir_to_logs = str(cfg.logs_dir) + '%d.%d' % (time.localtime().tm_mon, time.localtime().tm_mday) \
126 |                   + '_%s' % cfg.mode + '_%s' % cfg.info
127 | 
128 | os.makedirs(dir_to_save, exist_ok=True)
129 | os.makedirs(dir_to_logs, exist_ok=True)
130 | 
131 | 
132 | log_fname = str(dir_to_save + '/log.txt')
133 | if not os.path.exists(log_fname):
134 |     fp = open(log_fname, 'w')
135 |     write_status_to_log_file(fp, total_params)
136 | else:
137 |     fp = open(log_fname, 'a')
138 | 
139 | # Set a hps file to store hyper-parameters information.
140 | hps_fname = str(dir_to_save + '/hp_str.txt')
141 | fp_h = open(hps_fname, 'w')
142 | 
143 | with open('config.py', 'r') as f:
144 |     hp_str = ''.join(f.readlines())
145 | fp_h.write(hp_str)
146 | fp_h.close()
147 | 
148 | ###############################################################################
149 | ###############################################################################
150 | #                           Main program start !!                             #
151 | ###############################################################################
152 | ###############################################################################
153 | 
154 | # Writer initialize
155 | writer = Writer(dir_to_logs)
156 | 
157 | ###############################################################################
158 | #                                    Train                                    #
159 | ###############################################################################
160 | for epoch in range(epoch_start_idx, cfg.max_epochs + 1):
161 |     start_time = time.time()
162 |     train_loss = model_train(model, optimizer, train_loader, epoch, DEVICE)
163 |     vali_loss, vali_pesq, vali_stoi = model_validate(model, validation_loader,
164 |                                                      dir_to_save, writer, epoch, DEVICE)
165 | 
166 |     mse_vali_total[epoch - 1] = vali_loss
167 |     np.save(str(dir_to_save + '/mse_vali_total.npy'), mse_vali_total)
168 | 
169 |     # write the loss on tensorboard
170 |     writer.log_loss(train_loss, vali_loss, epoch)
171 | 
172 |     # save checkpoint file to resume training
173 |     save_path = str(dir_to_save + '/' + ('chkpt_%d.pt' % epoch))
174 |     torch.save({
175 |         'model': model.state_dict(),
176 |         'optimizer': optimizer.state_dict(),
177 |         'epoch': epoch
178 |     }, save_path)
179 | 
180 |     print('Epoch [{}] | {:.6f} | {:.6f} | {:.6f} | {:.6f} takes {:.2f} seconds'
181 |           .format(epoch, train_loss, vali_loss, vali_pesq, vali_stoi, time.time() - start_time))
182 |     fp.write('Epoch [{}] | {:.6f} | {:.6f} | {:.6f} | {:.6f} takes {:.2f} seconds\n'
183 |              .format(epoch, train_loss, vali_loss, vali_pesq, vali_stoi, time.time() - start_time))
184 | 
185 | print('Training has been finished.')
186 | 
187 | # Copy the optimum model that has the minimum validation loss.
188 | print('Save optimum models...')
189 | min_index = np.argmin(mse_vali_total)
190 | print('Minimum validation loss is at epoch ' + str(min_index + 1) + '.')
191 | # Keep a copy of the best checkpoint (a minimal sketch: assumes the chkpt_%d.pt naming used in the training loop above).
192 | shutil.copy(str(dir_to_save + '/' + ('chkpt_%d.pt' % (min_index + 1))),
193 |             str(dir_to_save + '/chkpt_opt.pt'))
194 | 
--------------------------------------------------------------------------------
/write_on_tensorboard.py:
--------------------------------------------------------------------------------
1 | """
2 | For observing the results using tensorboard
3 | 
4 | 1. wav
5 | 2. loss
6 | """
7 | from tensorboardX import SummaryWriter
8 | import matplotlib
9 | import config as cfg
10 | 
11 | 
12 | class Writer(SummaryWriter):
13 |     def __init__(self, logdir):
14 |         super(Writer, self).__init__(logdir)
15 |         # mask real/ imag
16 |         cmap_custom = {
17 |             'red': ((0.0, 0.0, 0.0),
18 |                     (1 / 63, 0.0, 0.0),
19 |                     (2 / 63, 0.0, 0.0),
20 |                     (3 / 63, 0.0, 0.0),
21 |                     (4 / 63, 0.0, 0.0),
22 |                     (5 / 63, 0.0, 0.0),
23 |                     (6 / 63, 0.0, 0.0),
24 |                     (7 / 63, 0.0, 0.0),
25 |                     (8 / 63, 0.0, 0.0),
26 |                     (9 / 63, 0.0, 0.0),
27 |                     (10 / 63, 0.0, 0.0),
28 |                     (11 / 63, 0.0, 0.0),
29 |                     (12 / 63, 0.0, 0.0),
30 |                     (13 / 63, 0.0, 0.0),
31 |                     (14 / 63, 0.0, 0.0),
32 |                     (15 / 63, 0.0, 0.0),
33 |                     (16 / 63, 0.0, 0.0),
34 |                     (17 / 63, 0.0, 0.0),
35 |                     (18 / 63, 0.0, 0.0),
36 |                     (19 / 63, 0.0, 0.0),
37 |                     (20 / 63, 0.0, 0.0),
38 |                     (21 / 63, 0.0, 0.0),
39 |                     (22 / 63, 0.0, 0.0),
40 |                     (23 / 63, 0.0, 0.0),
41 |                     (24 / 63, 0.5625, 0.5625),
42 |                     (25 / 63, 0.6250, 0.6250),
43 |                     (26 / 63, 0.6875, 0.6875),
44 |                     (27 / 63, 0.7500, 0.7500),
45 |                     (28 / 63, 0.8125, 0.8125),
46 |                     (29 / 63, 0.8750, 0.8750),
47 |                     (30 / 63, 0.9375, 0.9375),
48 |                     (31 / 63, 1.0, 1.0),
49 |                     (32 / 63, 1.0, 1.0),
50 |                     (33 / 63, 1.0, 1.0),
51 |                     (34 / 63, 1.0, 1.0),
52 |                     (35 / 63, 1.0, 1.0),
53 |                     (36 / 63, 1.0, 1.0),
54 |                     (37 / 63, 1.0, 1.0),
55 |                     (38 / 63, 1.0, 1.0),
56 |                     (39 / 63, 1.0, 1.0),
57 |                     (40 / 63, 1.0, 1.0),
58 |                     (41 / 63, 1.0, 1.0),
59 |                     (42 / 63, 1.0, 1.0),
60 |                     (43 / 63, 1.0, 1.0),
61 |                     (44 / 63, 1.0, 1.0),
62 |                     (45 / 63, 1.0, 1.0),
63 |                     (46 / 63, 1.0, 1.0),
64 |                     (47 / 63, 1.0, 1.0),
65 |                     (48 / 63, 1.0, 1.0),
66 |                     (49 / 63, 1.0, 1.0),
67 |                     (50 / 63, 1.0, 1.0),
68 |                     (51 / 63, 1.0, 1.0),
69 |                     (52 / 63, 1.0, 1.0),
70 |                     (53 / 63, 1.0, 1.0),
71 |                     (54 / 63, 1.0, 1.0),
72 |                     (55 / 63, 1.0, 1.0),
73 |                     (56 / 63, 0.9375, 0.9375),
74 |                     (57 / 63, 0.8750, 0.8750),
75 |                     (58 / 63, 0.8125, 0.8125),
76 |                     (59 / 63, 0.7500, 0.7500),
77 |                     (60 / 63, 0.6875, 0.6875),
78 |                     (61 / 63, 0.6250, 0.6250),
79 |                     (62 / 63, 0.5625, 0.5625),
80 |                     (63 / 63, 0.5000, 0.5000)),
81 |             'green': ((0.0, 0.0, 0.0),
82 |                       (1 / 63, 0.0, 0.0),
83 |                       (2 / 63, 0.0, 0.0),
84 |                       (3 / 63, 0.0, 0.0),
85 |                       (4 / 63, 0.0, 0.0),
86 |                       (5 / 63, 0.0, 0.0),
87 |                       (6 / 63, 0.0, 0.0),
88 |                       (7 / 63, 0.0, 0.0),
89 |                       (8 / 63, 0.0625, 0.0625),
90 |                       (9 / 63, 0.1250, 0.1250),
91 |                       (10 / 63, 0.1875, 0.1875),
92 |                       (11 / 63, 0.2500, 0.2500),
93 |                       (12 / 63, 0.3125, 0.3125),
94 |                       (13 / 63, 0.3750, 0.3750),
95 |                       (14 / 63, 0.4375, 0.4375),
96 |                       (15 / 63, 0.5000, 0.5000),
97 |                       (16 / 63, 0.5625, 0.5625),
98 |                       (17 / 63, 0.6250, 0.6250),
99 |                       (18 / 63, 0.6875, 0.6875),
100 |                       (19 / 63, 0.7500, 0.7500),
101 |                       (20 / 63, 0.8125, 0.8125),
102 |                       (21 / 63, 0.8750, 0.8750),
103 |                       (22 / 63, 0.9375, 0.9375),
104 |                       (23 / 63, 1.0, 1.0),
105 |                       (24 / 63, 1.0, 1.0),
106 |                       (25 / 63, 1.0, 1.0),
107 |                       (26 / 63, 1.0, 1.0),
108 |                       (27 / 63, 1.0, 1.0),
109 |                       (28 / 63, 1.0, 1.0),
110 |                       (29 / 63, 1.0, 1.0),
111 |                       (30 / 63, 1.0, 1.0),
112 |                       (31 / 63, 1.0, 1.0),
113 |                       (32 / 63, 1.0, 1.0),
114 |                       (33 / 63, 1.0, 1.0),
115 |                       (34 / 63, 1.0, 1.0),
116 |                       (35 / 63, 1.0, 1.0),
117 |                       (36 / 63, 1.0, 1.0),
118 |                       (37 / 63, 1.0, 1.0),
119 |                       (38 / 63, 1.0,
1.0), 120 | (39 / 63, 1.0, 1.0), 121 | (40 / 63, 0.9375, 0.9375), 122 | (41 / 63, 0.8750, 0.8750), 123 | (42 / 63, 0.8125, 0.8125), 124 | (43 / 63, 0.7500, 0.7500), 125 | (44 / 63, 0.6875, 0.6875), 126 | (45 / 63, 0.6250, 0.6250), 127 | (46 / 63, 0.5625, 0.5625), 128 | (47 / 63, 0.5000, 0.5000), 129 | (48 / 63, 0.4375, 0.4375), 130 | (49 / 63, 0.3750, 0.3750), 131 | (50 / 63, 0.3125, 0.3125), 132 | (51 / 63, 0.2500, 0.2500), 133 | (52 / 63, 0.1875, 0.1875), 134 | (53 / 63, 0.1250, 0.1250), 135 | (54 / 63, 0.0625, 0.0625), 136 | (55 / 63, 0.0, 0.0), 137 | (56 / 63, 0.0, 0.0), 138 | (57 / 63, 0.0, 0.0), 139 | (58 / 63, 0.0, 0.0), 140 | (59 / 63, 0.0, 0.0), 141 | (60 / 63, 0.0, 0.0), 142 | (61 / 63, 0.0, 0.0), 143 | (62 / 63, 0.0, 0.0), 144 | (63 / 63, 0.0, 0.0)), 145 | 'blue': ((0.0, 0.5625, 0.5625), 146 | (1 / 63, 0.6250, 0.6250), 147 | (2 / 63, 0.6875, 0.6875), 148 | (3 / 63, 0.7500, 0.7500), 149 | (4 / 63, 0.8125, 0.8125), 150 | (5 / 63, 0.8750, 0.8750), 151 | (6 / 63, 0.9375, 0.9375), 152 | (7 / 63, 1.0, 1.0), 153 | (8 / 63, 1.0, 1.0), 154 | (9 / 63, 1.0, 1.0), 155 | (10 / 63, 1.0, 1.0), 156 | (11 / 63, 1.0, 1.0), 157 | (12 / 63, 1.0, 1.0), 158 | (13 / 63, 1.0, 1.0), 159 | (14 / 63, 1.0, 1.0), 160 | (15 / 63, 1.0, 1.0), 161 | (16 / 63, 1.0, 1.0), 162 | (17 / 63, 1.0, 1.0), 163 | (18 / 63, 1.0, 1.0), 164 | (19 / 63, 1.0, 1.0), 165 | (20 / 63, 1.0, 1.0), 166 | (21 / 63, 1.0, 1.0), 167 | (22 / 63, 1.0, 1.0), 168 | (23 / 63, 1.0, 1.0), 169 | (24 / 63, 1.0, 1.0), 170 | (25 / 63, 1.0, 1.0), 171 | (26 / 63, 1.0, 1.0), 172 | (27 / 63, 1.0, 1.0), 173 | (28 / 63, 1.0, 1.0), 174 | (29 / 63, 1.0, 1.0), 175 | (30 / 63, 1.0, 1.0), 176 | (31 / 63, 1.0, 1.0), 177 | (32 / 63, 0.9375, 0.9375), 178 | (33 / 63, 0.8750, 0.8750), 179 | (34 / 63, 0.8125, 0.8125), 180 | (35 / 63, 0.7500, 0.7500), 181 | (36 / 63, 0.6875, 0.6875), 182 | (37 / 63, 0.6250, 0.6250), 183 | (38 / 63, 0.5625, 0.5625), 184 | (39 / 63, 0.0, 0.0), 185 | (40 / 63, 0.0, 0.0), 186 | (41 / 63, 0.0, 0.0), 187 | (42 / 63, 0.0, 0.0), 188 | (43 / 63, 0.0, 0.0), 189 | (44 / 63, 0.0, 0.0), 190 | (45 / 63, 0.0, 0.0), 191 | (46 / 63, 0.0, 0.0), 192 | (47 / 63, 0.0, 0.0), 193 | (48 / 63, 0.0, 0.0), 194 | (49 / 63, 0.0, 0.0), 195 | (50 / 63, 0.0, 0.0), 196 | (51 / 63, 0.0, 0.0), 197 | (52 / 63, 0.0, 0.0), 198 | (53 / 63, 0.0, 0.0), 199 | (54 / 63, 0.0, 0.0), 200 | (55 / 63, 0.0, 0.0), 201 | (56 / 63, 0.0, 0.0), 202 | (57 / 63, 0.0, 0.0), 203 | (58 / 63, 0.0, 0.0), 204 | (59 / 63, 0.0, 0.0), 205 | (60 / 63, 0.0, 0.0), 206 | (61 / 63, 0.0, 0.0), 207 | (62 / 63, 0.0, 0.0), 208 | (63 / 63, 0.0, 0.0)) 209 | } 210 | 211 | # mask magnitude 212 | cmap_custom2 = { 213 | 'red': ((0.0, 1.0, 1.0), 214 | (1 / 32, 1.0, 1.0), 215 | (2 / 32, 1.0, 1.0), 216 | (3 / 32, 1.0, 1.0), 217 | (4 / 32, 1.0, 1.0), 218 | (5 / 32, 1.0, 1.0), 219 | (6 / 32, 1.0, 1.0), 220 | (7 / 32, 1.0, 1.0), 221 | (8 / 32, 1.0, 1.0), 222 | (9 / 32, 1.0, 1.0), 223 | (10 / 32, 1.0, 1.0), 224 | (11 / 32, 1.0, 1.0), 225 | (12 / 32, 1.0, 1.0), 226 | (13 / 32, 1.0, 1.0), 227 | (14 / 32, 1.0, 1.0), 228 | (15 / 32, 1.0, 1.0), 229 | (16 / 32, 1.0, 1.0), 230 | (17 / 32, 1.0, 1.0), 231 | (18 / 32, 1.0, 1.0), 232 | (19 / 32, 1.0, 1.0), 233 | (20 / 32, 1.0, 1.0), 234 | (21 / 32, 1.0, 1.0), 235 | (22 / 32, 1.0, 1.0), 236 | (23 / 32, 1.0, 1.0), 237 | (24 / 32, 1.0, 1.0), 238 | (25 / 32, 0.9375, 0.9375), 239 | (26 / 32, 0.8750, 0.8750), 240 | (27 / 32, 0.8125, 0.8125), 241 | (28 / 32, 0.7500, 0.7500), 242 | (29 / 32, 0.6875, 0.6875), 243 | (30 / 32, 0.6250, 0.6250), 244 | (31 / 32, 0.5625, 0.5625), 245 | (32 / 
32, 0.5000, 0.5000)), 246 | 'green': ((0.0, 1.0, 1.0), 247 | (1 / 32, 1.0, 1.0), 248 | (2 / 32, 1.0, 1.0), 249 | (3 / 32, 1.0, 1.0), 250 | (4 / 32, 1.0, 1.0), 251 | (5 / 32, 1.0, 1.0), 252 | (6 / 32, 1.0, 1.0), 253 | (7 / 32, 1.0, 1.0), 254 | (8 / 32, 1.0, 1.0), 255 | (9 / 32, 0.9375, 0.9375), 256 | (10 / 32, 0.8750, 0.8750), 257 | (11 / 32, 0.8125, 0.8125), 258 | (12 / 32, 0.7500, 0.7500), 259 | (13 / 32, 0.6875, 0.6875), 260 | (14 / 32, 0.6250, 0.6250), 261 | (15 / 32, 0.5625, 0.5625), 262 | (16 / 32, 0.5000, 0.5000), 263 | (17 / 32, 0.4375, 0.4375), 264 | (18 / 32, 0.3750, 0.3750), 265 | (19 / 32, 0.3125, 0.3125), 266 | (20 / 32, 0.2500, 0.2500), 267 | (21 / 32, 0.1875, 0.1875), 268 | (22 / 32, 0.1250, 0.1250), 269 | (23 / 32, 0.0625, 0.0625), 270 | (24 / 32, 0.0, 0.0), 271 | (25 / 32, 0.0, 0.0), 272 | (26 / 32, 0.0, 0.0), 273 | (27 / 32, 0.0, 0.0), 274 | (28 / 32, 0.0, 0.0), 275 | (29 / 32, 0.0, 0.0), 276 | (30 / 32, 0.0, 0.0), 277 | (31 / 32, 0.0, 0.0), 278 | (32 / 32, 0.0, 0.0)), 279 | 'blue': ((0.0, 1.0, 1.0), 280 | (1 / 32, 0.9375, 0.9375), 281 | (2 / 32, 0.8750, 0.8750), 282 | (3 / 32, 0.8125, 0.8125), 283 | (4 / 32, 0.7500, 0.7500), 284 | (5 / 32, 0.6875, 0.6875), 285 | (6 / 32, 0.6250, 0.6250), 286 | (7 / 32, 0.5625, 0.5625), 287 | (8 / 32, 0.0, 0.0), 288 | (9 / 32, 0.0, 0.0), 289 | (10 / 32, 0.0, 0.0), 290 | (11 / 32, 0.0, 0.0), 291 | (12 / 32, 0.0, 0.0), 292 | (13 / 32, 0.0, 0.0), 293 | (14 / 32, 0.0, 0.0), 294 | (15 / 32, 0.0, 0.0), 295 | (16 / 32, 0.0, 0.0), 296 | (17 / 32, 0.0, 0.0), 297 | (18 / 32, 0.0, 0.0), 298 | (19 / 32, 0.0, 0.0), 299 | (20 / 32, 0.0, 0.0), 300 | (21 / 32, 0.0, 0.0), 301 | (22 / 32, 0.0, 0.0), 302 | (23 / 32, 0.0, 0.0), 303 | (24 / 32, 0.0, 0.0), 304 | (25 / 32, 0.0, 0.0), 305 | (26 / 32, 0.0, 0.0), 306 | (27 / 32, 0.0, 0.0), 307 | (28 / 32, 0.0, 0.0), 308 | (29 / 32, 0.0, 0.0), 309 | (30 / 32, 0.0, 0.0), 310 | (31 / 32, 0.0, 0.0), 311 | (32 / 32, 0.0, 0.0)) 312 | } 313 | 314 | self.cmap_custom = matplotlib.colors.LinearSegmentedColormap('testCmap', segmentdata=cmap_custom, N=256) 315 | self.cmap_custom2 = matplotlib.colors.LinearSegmentedColormap('testCmap2', segmentdata=cmap_custom2, N=256) 316 | 317 | def log_loss(self, train_loss, vali_loss, step): 318 | self.add_scalar('train_loss', train_loss, step) 319 | self.add_scalar('vali_loss', vali_loss, step) 320 | 321 | def log_sub_loss(self, train_main_loss, train_sub_loss, vali_main_loss, vali_sub_loss, step): 322 | self.add_scalar('train_main_loss', train_main_loss, step) 323 | self.add_scalar('train_sub_loss', train_sub_loss, step) 324 | self.add_scalar('vali_main_loss', vali_main_loss, step) 325 | self.add_scalar('vali_sub_loss', vali_sub_loss, step) 326 | 327 | def log_score(self, vali_pesq, vali_stoi, step): 328 | self.add_scalar('vali_pesq', vali_pesq, step) 329 | self.add_scalar('vali_stoi', vali_stoi, step) 330 | 331 | def log_wav(self, mixed_wav, clean_wav, est_wav, step): 332 | #
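        # Log the mixed/clean/estimated waveforms (a minimal completion sketch;
        # assumes cfg.fs in config.py holds the sample rate)
        self.add_audio('mixed_wav', mixed_wav, step, cfg.fs)
        self.add_audio('clean_target_wav', clean_wav, step, cfg.fs)
        self.add_audio('estimated_wav', est_wav, step, cfg.fs)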