├── DPCRN_base.py ├── DPCRN_skip.py ├── README.md ├── configuration ├── DPCRN-base.yaml └── DPCRN-skip.yaml ├── data_loader.py ├── enhance_s.wav ├── evaluations └── mir_eval.py ├── loss.py ├── main.py ├── networks ├── modules.py ├── pruning_gru.py ├── pruning_methods.py └── skip_gru.py ├── pretrained_weights └── DPCRN_base │ └── models_experiment_new_base_nomap_phasenloss_retrain_WSJmodel_84_0.022068.h5 ├── signal_processing.py └── test_audio ├── enhanced └── 440C020A_mix.wav └── noisy └── 440C020A_mix.wav /DPCRN_base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Nov 20 22:16:58 2020 4 | 5 | @author: Xiaohuai Le 6 | """ 7 | import os 8 | import tensorflow as tf 9 | import tensorflow.keras as keras 10 | from tensorflow.keras import backend as K 11 | from tensorflow.keras.models import Model 12 | from tensorflow.keras.layers import Activation, Lambda, Input, LayerNormalization, Conv2D, BatchNormalization, Conv2DTranspose, Concatenate, PReLU 13 | 14 | import soundfile as sf 15 | from random import seed 16 | import numpy as np 17 | import librosa 18 | 19 | from loss import Loss 20 | from signal_processing import Signal_Pro 21 | 22 | from networks.modules import DprnnBlock 23 | 24 | seed(42) 25 | np.random.seed(42) 26 | 27 | class DPCRN_model(Loss, Signal_Pro): 28 | ''' 29 | Class to create the DPCRN-base model 30 | ''' 31 | 32 | def __init__(self, batch_size, config, length_in_s = 8, lr = 1e-3): 33 | ''' 34 | Constructor 35 | ''' 36 | Signal_Pro.__init__(self, config) 37 | 38 | self.network_config = config['network'] 39 | self.filter_size = self.network_config['filter_size'] 40 | self.kernel_size = self.network_config['kernel_size'] 41 | self.strides = self.network_config['strides'] 42 | self.encoder_padding = self.network_config['encoder_padding'] 43 | self.decoder_padding = self.network_config['decoder_padding'] 44 | self.output_cut_off = self.network_config['output_cut'] 45 
| self.N_DPRNN = self.network_config['N_DPRNN'] 46 | self.use_CuDNNGRU = self.network_config['use_CuDNNGRU'] 47 | self.activation = self.network_config['activation'] 48 | self.input_norm = self.network_config['input_norm'] 49 | self.intra_hidden_size = self.network_config['DPRNN']['intra_hidden_size'] 50 | self.inter_hidden_size = self.network_config['DPRNN']['inter_hidden_size'] 51 | # empty property for the model 52 | self.model = None 53 | # defining default parameters 54 | self.length_in_s = length_in_s 55 | self.batch_size = batch_size 56 | self.lr = lr 57 | self.eps = 1e-9 58 | 59 | self.L = (16000 * length_in_s - self.block_len) // self.block_shift + 1 60 | 61 | def metricsWrapper(self): 62 | ''' 63 | A wrapper function which returns the metrics used during training 64 | ''' 65 | return [self.sisnr_cost] 66 | 67 | def lossWrapper(self): 68 | ''' 69 | A wrapper function which returns the loss function. This is done to 70 | to enable additional arguments to the loss function if necessary. 
71 | ''' 72 | def spectrum_loss_SD(s_hat, s, c = 0.3, Lam = 0.1): 73 | # The complex compressed spectrum MSE loss 74 | s = tf.truediv(s,self.batch_gain + 1e-9) 75 | s_hat= tf.truediv(s_hat,self.batch_gain + 1e-9) 76 | 77 | true_real,true_imag = self.stftLayer(s, mode='real_imag') 78 | hat_real,hat_imag = self.stftLayer(s_hat, mode='real_imag') 79 | 80 | true_mag = tf.sqrt(true_real**2 + true_imag**2 + 1e-9) 81 | hat_mag = tf.sqrt(hat_real**2 + hat_imag**2 + 1e-9) 82 | 83 | true_real_cprs = (true_real / true_mag )*true_mag**c 84 | true_imag_cprs = (true_imag / true_mag )*true_mag**c 85 | hat_real_cprs = (hat_real / hat_mag )* hat_mag**c 86 | hat_imag_cprs = (hat_imag / hat_mag )* hat_mag**c 87 | 88 | loss_mag = tf.reduce_mean((hat_mag**c - true_mag**c)**2,) 89 | loss_real = tf.reduce_mean((hat_real_cprs - true_real_cprs)**2,) 90 | loss_imag = tf.reduce_mean((hat_imag_cprs - true_imag_cprs)**2,) 91 | 92 | return (1 - Lam) * loss_mag + Lam * ( loss_imag + loss_real ) 93 | 94 | return spectrum_loss_SD 95 | 96 | def build_DPCRN_model(self, name = 'model0'): 97 | 98 | # input layer for time signal 99 | time_data = Input(batch_shape=(self.batch_size, None)) 100 | self.batch_gain = Input(batch_shape=(self.batch_size, 1)) 101 | 102 | # calculate STFT 103 | real,imag = Lambda(self.stftLayer,arguments = {'mode':'real_imag'})(time_data) 104 | 105 | real = tf.reshape(real,[self.batch_size,-1,self.block_len // 2 + 1,1]) 106 | imag = tf.reshape(imag,[self.batch_size,-1,self.block_len // 2 + 1,1]) 107 | 108 | input_mag = tf.math.sqrt(real**2 + imag**2 +1e-9) 109 | input_log_spec = 2 * tf.math.log(input_mag) 110 | # input feature 111 | input_complex_spec = Concatenate(axis = -1)([real,imag,input_log_spec]) 112 | 113 | '''encoder''' 114 | 115 | if self.input_norm == 'batchnorm': 116 | input_complex_spec = BatchNormalization(axis = [-1,-2], epsilon = self.eps)(input_complex_spec) 117 | elif self.input_norm == 'instantlayernorm': 118 | input_complex_spec = LayerNormalization(axis = 
[-1,-2], epsilon = self.eps)(input_complex_spec) 119 | 120 | conv_1 = Conv2D(self.filter_size[0], self.kernel_size[0], self.strides[0], name = name+'_conv_1', padding = [[0,0],[0,0],self.encoder_padding[0],[0,0]])(input_complex_spec) 121 | bn_1 = BatchNormalization(name = name+'_bn_1')(conv_1) 122 | out_1 = PReLU(shared_axes=[1,2])(bn_1) 123 | 124 | conv_2 = Conv2D(self.filter_size[1], self.kernel_size[1], self.strides[1], name = name+'_conv_2', padding = [[0,0],[0,0],self.encoder_padding[1],[0,0]])(out_1) 125 | bn_2 = BatchNormalization(name = name+'_bn_2')(conv_2) 126 | out_2 = PReLU(shared_axes=[1,2])(bn_2) 127 | 128 | conv_3 = Conv2D(self.filter_size[2], self.kernel_size[2], self.strides[2], name = name+'_conv_3', padding = [[0,0],[0,0],self.encoder_padding[2],[0,0]])(out_2) 129 | bn_3 = BatchNormalization(name = name+'_bn_3')(conv_3) 130 | out_3 = PReLU(shared_axes=[1,2])(bn_3) 131 | 132 | conv_4 = Conv2D(self.filter_size[3], self.kernel_size[3], self.strides[3], name = name+'_conv_4', padding = [[0,0],[0,0],self.encoder_padding[3],[0,0]])(out_3) 133 | bn_4 = BatchNormalization(name = name+'_bn_4')(conv_4) 134 | out_4 = PReLU(shared_axes=[1,2])(bn_4) 135 | 136 | conv_5 = Conv2D(self.filter_size[4], self.kernel_size[4], self.strides[4], name = name+'_conv_5', padding = [[0,0],[0,0],self.encoder_padding[4],[0,0]])(out_4) 137 | bn_5 = BatchNormalization(name = name+'_bn_5')(conv_5) 138 | out_5 = PReLU(shared_axes=[1,2])(bn_5) 139 | 140 | dp_in = out_5 141 | 142 | for i in range(self.N_DPRNN): 143 | dp_in = DprnnBlock(intra_hidden = self.intra_hidden_size, 144 | inter_hidden=self.inter_hidden_size, 145 | batch_size = self.batch_size, 146 | L = -1, 147 | width = self.block_len //2 //8, 148 | channel = self.filter_size[4], 149 | causal= True, 150 | CUDNN = self.use_CuDNNGRU)(dp_in) 151 | 152 | dp_out = dp_in 153 | '''decoder''' 154 | skipcon_1 = Concatenate(axis = -1)([out_5, dp_out]) 155 | 156 | deconv_1 = Conv2DTranspose(self.filter_size[3], self.kernel_size[4], 
self.strides[4], name = name+'_dconv_1', padding = self.decoder_padding[0])(skipcon_1) 157 | dbn_1 = BatchNormalization(name = name+'_dbn_1')(deconv_1) 158 | dout_1 = PReLU(shared_axes=[1,2])(dbn_1) 159 | 160 | skipcon_2 = Concatenate(axis = -1)([out_4, dout_1]) 161 | 162 | deconv_2 = Conv2DTranspose(self.filter_size[2], self.kernel_size[3], self.strides[3], name = name+'_dconv_2', padding = self.decoder_padding[1])(skipcon_2) 163 | dbn_2 = BatchNormalization(name = name+'_dbn_2')(deconv_2) 164 | dout_2 = PReLU(shared_axes=[1,2])(dbn_2) 165 | 166 | skipcon_3 = Concatenate(axis = -1)([out_3, dout_2]) 167 | 168 | deconv_3 = Conv2DTranspose(self.filter_size[1], self.kernel_size[2], self.strides[2], name = name+'_dconv_3', padding = self.decoder_padding[2])(skipcon_3) 169 | dbn_3 = BatchNormalization(name = name+'_dbn_3')(deconv_3) 170 | dout_3 = PReLU(shared_axes=[1,2])(dbn_3) 171 | 172 | skipcon_4 = Concatenate(axis = -1)([out_2, dout_3]) 173 | 174 | deconv_4 = Conv2DTranspose(self.filter_size[0], self.kernel_size[1], self.strides[1], name = name+'_dconv_4', padding = self.decoder_padding[3])(skipcon_4) 175 | dbn_4 = BatchNormalization(name = name+'_dbn_4')(deconv_4) 176 | dout_4 = PReLU(shared_axes=[1,2])(dbn_4) 177 | 178 | skipcon_5 = Concatenate(axis = -1)([out_1, dout_4]) 179 | 180 | deconv_5 = Conv2DTranspose(2, self.kernel_size[0], self.strides[0], name = name+'_dconv_5', padding = self.decoder_padding[4])(skipcon_5) 181 | 182 | deconv_5 = deconv_5[:,:,:-self.output_cut_off] 183 | 184 | dbn_5 = BatchNormalization(name = name+'_dbn_5')(deconv_5) 185 | 186 | mag_mask = Conv2DTranspose(1, self.kernel_size[0], self.strides[0], name = name+'mag_mask', padding = self.decoder_padding[4])(skipcon_5)[:,:,:-self.output_cut_off,0] 187 | 188 | # get magnitude mask 189 | if self.activation == 'sigmoid': 190 | self.mag_mask = Activation('sigmoid')(BatchNormalization()(mag_mask))*1.2 191 | elif self.activation == 'softplus': 192 | self.mag_mask = 
Activation('softplus')(BatchNormalization()(mag_mask)) 193 | # get phase mask 194 | phase_square = tf.math.sqrt(dbn_5[:,:,:,0]**2 + dbn_5[:,:,:,1]**2 + self.eps) 195 | 196 | self.phase_sin = dbn_5[:,:,:,1] / phase_square 197 | self.phase_cos = dbn_5[:,:,:,0] / phase_square 198 | 199 | self.enh_mag_real,self.enh_mag_imag = Lambda(self.mk_mask_mag)([real,imag,self.mag_mask]) 200 | 201 | enh_spec = Lambda(self.mk_mask_pha)([self.enh_mag_real,self.enh_mag_imag,self.phase_cos,self.phase_sin]) 202 | 203 | enh_frame = Lambda(self.ifftLayer,arguments = {'mode':'real_imag'})(enh_spec) 204 | enh_frame = enh_frame * self.win 205 | enh_time = Lambda(self.overlapAddLayer, name = 'enhanced_time')(enh_frame) 206 | 207 | self.model = Model([time_data, self.batch_gain], enh_time) 208 | self.model.summary() 209 | 210 | self.model_inference = Model(time_data, enh_time) 211 | 212 | return self.model 213 | 214 | def compile_model(self): 215 | ''' 216 | Method to compile the model for training 217 | ''' 218 | # use the Adam optimizer with a clipnorm of 3 219 | optimizerAdam = keras.optimizers.Adam(lr=self.lr, clipnorm=3.0) 220 | # compile model with loss function 221 | self.model.compile(loss=self.lossWrapper(), optimizer = optimizerAdam, metrics = self.metricsWrapper()) 222 | 223 | def enhancement(self, noisy_f, output_f = './enhance_s.wav', plot = True, gain =1): 224 | 225 | noisy_s = sf.read(noisy_f,dtype = 'float32')[0]#[:400] 226 | 227 | enh_s = self.model_inference.predict(np.array([noisy_s])*gain) 228 | 229 | enh_s = enh_s[0] 230 | 231 | if plot: 232 | spec_n = librosa.stft(noisy_s,512,256,center = False) 233 | spec_e = librosa.stft(enh_s, 512,256,center = False) 234 | plt.figure(0) 235 | plt.plot(noisy_s) 236 | plt.plot(enh_s) 237 | plt.figure(1) 238 | plt.subplot(211) 239 | plt.imshow(np.log(abs(spec_n)+1e-8),cmap= 'jet',origin ='lower') 240 | plt.subplot(212) 241 | plt.imshow(np.log(abs(spec_e)+1e-8),cmap= 'jet',origin ='lower') 242 | sf.write(output_f,enh_s,16000) 243 | 244 | 
return noisy_s,enh_s 245 | 246 | def test_on_dataset(self, noisy_path, target_path): 247 | import tqdm 248 | f_list = os.listdir(noisy_path) 249 | for f in tqdm.tqdm(f_list): 250 | self.enhancement(noisy_f = os.path.join(noisy_path,f),output_f = os.path.join(target_path,f),plot = False) 251 | 252 | if __name__ == '__main__': 253 | 254 | import matplotlib.pyplot as plt 255 | import yaml 256 | 257 | f = open('./configuration/DPCRN-base.yaml','r',encoding='utf-8') 258 | result = f.read() 259 | print(result) 260 | 261 | config_dict = yaml.load(result) 262 | model = DPCRN_model(batch_size = 1, length_in_s =5, lr = 1e-3, config = config_dict) 263 | 264 | model.build_DPCRN_model() 265 | 266 | model.model.load_weights('./pretrained_weights/DPCRN_base/models_experiment_new_base_nomap_phasenloss_retrain_WSJmodel_84_0.022068.h5') 267 | model.enhancement('D:/codes/test_audio/mix/444C020b_mix.wav') -------------------------------------------------------------------------------- /DPCRN_skip.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Nov 20 22:16:58 2020 4 | 5 | @author: Xiaohuai Le 6 | """ 7 | import os 8 | import tensorflow as tf 9 | import tensorflow.keras as keras 10 | from tensorflow.keras import backend as K 11 | from tensorflow.keras.models import Model 12 | from tensorflow.keras.layers import Activation, Lambda, Input, LayerNormalization, Conv2D, BatchNormalization, Conv2DTranspose, Concatenate, PReLU 13 | 14 | import soundfile as sf 15 | from random import seed 16 | import numpy as np 17 | import librosa 18 | 19 | from loss import Loss 20 | from signal_processing import Signal_Pro 21 | 22 | from networks.modules import DprnnBlock_skip 23 | 24 | seed(42) 25 | np.random.seed(42) 26 | 27 | class DPCRN_skip_model(Loss, Signal_Pro): 28 | ''' 29 | Class to create the DPCRN-skip model 30 | ''' 31 | 32 | def __init__(self, batch_size, config, length_in_s = 8, lr = 1e-3): 33 | ''' 34 
| Constructor 35 | ''' 36 | Signal_Pro.__init__(self, config) 37 | 38 | self.network_config = config['network'] 39 | self.filter_size = self.network_config['filter_size'] 40 | self.kernel_size = self.network_config['kernel_size'] 41 | self.strides = self.network_config['strides'] 42 | self.encoder_padding = self.network_config['encoder_padding'] 43 | self.decoder_padding = self.network_config['decoder_padding'] 44 | self.output_cut_off = self.network_config['output_cut'] 45 | self.N_DPRNN = self.network_config['N_DPRNN'] 46 | self.activation = self.network_config['activation'] 47 | self.input_norm = self.network_config['input_norm'] 48 | self.intra_hidden_size = self.network_config['DPRNN']['intra_hidden_size'] 49 | self.inter_hidden_size = self.network_config['DPRNN']['inter_hidden_size'] 50 | self.skip = self.network_config['DPRNN']['skip'] 51 | # optimizer and loss 52 | self.loss_type = config['trainer']['loss'] 53 | self.target_rate = config['trainer']['target'] 54 | self.alpha = config['trainer']['alpha'] 55 | # empty property for the model 56 | self.model = None 57 | # defining default parameters 58 | self.length_in_s = length_in_s 59 | self.batch_size = batch_size 60 | 61 | self.lr = lr 62 | self.eps = 1e-9 63 | 64 | self.L = (16000 * length_in_s - self.block_len) // self.block_shift + 1 65 | 66 | def metricsWrapper(self): 67 | ''' 68 | A wrapper function which returns the metrics used during training 69 | ''' 70 | # the average update rates of intra-RNN 71 | 72 | def intra_update_rate(x, y): 73 | return tf.reduce_mean(self.update_gates_intra) 74 | 75 | def inter_update_rate(x, y): 76 | return tf.reduce_mean(self.update_gates_inter) 77 | 78 | return [self.sisnr_cost,intra_update_rate,inter_update_rate] 79 | 80 | def lossWrapper(self): 81 | ''' 82 | A wrapper function which returns the loss function. This is done to 83 | to enable additional arguments to the loss function if necessary. 
84 | ''' 85 | def spectrum_loss_SD(s_hat, s, c = 0.3, Lam = 0.1): 86 | 87 | # The complex compressed spectrum MSE loss 88 | s = tf.truediv(s,self.batch_gain + 1e-9) 89 | s_hat= tf.truediv(s_hat,self.batch_gain + 1e-9) 90 | 91 | true_real,true_imag = self.stftLayer(s, mode='real_imag') 92 | hat_real,hat_imag = self.stftLayer(s_hat, mode='real_imag') 93 | 94 | true_mag = tf.sqrt(true_real**2 + true_imag**2 + 1e-9) 95 | hat_mag = tf.sqrt(hat_real**2 + hat_imag**2 + 1e-9) 96 | 97 | true_real_cprs = (true_real / true_mag )*true_mag**c 98 | true_imag_cprs = (true_imag / true_mag )*true_mag**c 99 | hat_real_cprs = (hat_real / hat_mag )* hat_mag**c 100 | hat_imag_cprs = (hat_imag / hat_mag )* hat_mag**c 101 | 102 | loss_mag = tf.reduce_mean((hat_mag**c - true_mag**c)**2,) 103 | loss_real = tf.reduce_mean((hat_real_cprs - true_real_cprs)**2,) 104 | loss_imag = tf.reduce_mean((hat_imag_cprs - true_imag_cprs)**2,) 105 | 106 | if self.loss_type == 'MIN': 107 | intra_update_rates = [tf.reduce_mean(gate) for gate in self.update_gates_intra] 108 | inter_update_rates = [tf.reduce_mean(gate) for gate in self.update_gates_inter] 109 | 110 | elif self.loss_type == 'MAE': 111 | intra_update_rates = [self.skip_regular_MAE(gate, miu = self.target_rate) for gate in self.update_gates_intra] 112 | inter_update_rates = [self.skip_regular_MAE(gate, miu = self.target_rate) for gate in self.update_gates_inter] 113 | 114 | elif self.loss_type == 'MSE': 115 | intra_update_rates = [self.skip_regular_MSE(gate, miu = self.target_rate) for gate in self.update_gates_intra] 116 | inter_update_rates = [self.skip_regular_MSE(gate, miu = self.target_rate) for gate in self.update_gates_inter] 117 | 118 | Loss_skip = tf.reduce_sum(intra_update_rates) + tf.reduce_sum(inter_update_rates) 119 | return (1 - Lam) * loss_mag + Lam * ( loss_imag + loss_real ) + Loss_skip * self.alpha 120 | 121 | return spectrum_loss_SD 122 | 123 | def build_DPCRN_model(self, name = 'model0'): 124 | 125 | # input layer for time 
signal 126 | time_data = Input(batch_shape=(self.batch_size, None)) 127 | self.batch_gain = Input(batch_shape=(self.batch_size, 1)) 128 | # the update rate rescale factor gamma 129 | self.batch_scale = Input(batch_shape=(self.batch_size,None,1,1)) 130 | scale = tf.repeat(self.batch_scale, repeats=self.block_len //2 //8,axis=2) 131 | 132 | # calculate STFT 133 | real,imag = Lambda(self.stftLayer,arguments = {'mode':'real_imag'})(time_data) 134 | 135 | real = tf.reshape(real,[self.batch_size,-1,self.block_len // 2 + 1,1]) 136 | imag = tf.reshape(imag,[self.batch_size,-1,self.block_len // 2 + 1,1]) 137 | 138 | input_mag = tf.math.sqrt(real**2 + imag**2 +1e-9) 139 | input_log_spec = 2 * tf.math.log(input_mag) 140 | # input feature 141 | input_complex_spec = Concatenate(axis = -1)([real,imag,input_log_spec]) 142 | 143 | '''encoder''' 144 | 145 | if self.input_norm == 'batchnorm': 146 | input_complex_spec = BatchNormalization(axis = [-1,-2], epsilon = self.eps)(input_complex_spec) 147 | elif self.input_norm == 'instantlayernorm': 148 | input_complex_spec = LayerNormalization(axis = [-1,-2], epsilon = self.eps)(input_complex_spec) 149 | 150 | conv_1 = Conv2D(self.filter_size[0], self.kernel_size[0], self.strides[0], name = name+'_conv_1', padding = [[0,0],[0,0],self.encoder_padding[0],[0,0]])(input_complex_spec) 151 | bn_1 = BatchNormalization(name = name+'_bn_1')(conv_1) 152 | out_1 = PReLU(shared_axes=[1,2])(bn_1) 153 | 154 | conv_2 = Conv2D(self.filter_size[1], self.kernel_size[1], self.strides[1], name = name+'_conv_2', padding = [[0,0],[0,0],self.encoder_padding[1],[0,0]])(out_1) 155 | bn_2 = BatchNormalization(name = name+'_bn_2')(conv_2) 156 | out_2 = PReLU(shared_axes=[1,2])(bn_2) 157 | 158 | conv_3 = Conv2D(self.filter_size[2], self.kernel_size[2], self.strides[2], name = name+'_conv_3', padding = [[0,0],[0,0],self.encoder_padding[2],[0,0]])(out_2) 159 | bn_3 = BatchNormalization(name = name+'_bn_3')(conv_3) 160 | out_3 = PReLU(shared_axes=[1,2])(bn_3) 161 | 162 
| conv_4 = Conv2D(self.filter_size[3], self.kernel_size[3], self.strides[3], name = name+'_conv_4', padding = [[0,0],[0,0],self.encoder_padding[3],[0,0]])(out_3) 163 | bn_4 = BatchNormalization(name = name+'_bn_4')(conv_4) 164 | out_4 = PReLU(shared_axes=[1,2])(bn_4) 165 | 166 | conv_5 = Conv2D(self.filter_size[4], self.kernel_size[4], self.strides[4], name = name+'_conv_5', padding = [[0,0],[0,0],self.encoder_padding[4],[0,0]])(out_4) 167 | bn_5 = BatchNormalization(name = name+'_bn_5')(conv_5) 168 | out_5 = PReLU(shared_axes=[1,2])(bn_5) 169 | 170 | dp_in = out_5 171 | self.update_gates_intra = [] 172 | self.update_gates_inter = [] 173 | 174 | for i in range(self.N_DPRNN): 175 | dp_in, update_gate_intra, update_gate_inter = DprnnBlock_skip(intra_hidden = self.intra_hidden_size, 176 | inter_hidden=self.inter_hidden_size, 177 | batch_size = self.batch_size, 178 | L = -1, 179 | width = self.block_len //2 //8, 180 | channel = self.filter_size[4], 181 | skip = self.skip)(dp_in, scale) 182 | self.update_gates_intra.append(update_gate_intra) 183 | self.update_gates_inter.append(update_gate_inter) 184 | 185 | dp_out = dp_in 186 | '''decoder''' 187 | skipcon_1 = Concatenate(axis = -1)([out_5, dp_out]) 188 | 189 | deconv_1 = Conv2DTranspose(self.filter_size[3], self.kernel_size[4], self.strides[4], name = name+'_dconv_1', padding = self.decoder_padding[0])(skipcon_1) 190 | dbn_1 = BatchNormalization(name = name+'_dbn_1')(deconv_1) 191 | dout_1 = PReLU(shared_axes=[1,2])(dbn_1) 192 | 193 | skipcon_2 = Concatenate(axis = -1)([out_4, dout_1]) 194 | 195 | deconv_2 = Conv2DTranspose(self.filter_size[2], self.kernel_size[3], self.strides[3], name = name+'_dconv_2', padding = self.decoder_padding[1])(skipcon_2) 196 | dbn_2 = BatchNormalization(name = name+'_dbn_2')(deconv_2) 197 | dout_2 = PReLU(shared_axes=[1,2])(dbn_2) 198 | 199 | skipcon_3 = Concatenate(axis = -1)([out_3, dout_2]) 200 | 201 | deconv_3 = Conv2DTranspose(self.filter_size[1], self.kernel_size[2], self.strides[2], 
name = name+'_dconv_3', padding = self.decoder_padding[2])(skipcon_3) 202 | dbn_3 = BatchNormalization(name = name+'_dbn_3')(deconv_3) 203 | dout_3 = PReLU(shared_axes=[1,2])(dbn_3) 204 | 205 | skipcon_4 = Concatenate(axis = -1)([out_2, dout_3]) 206 | 207 | deconv_4 = Conv2DTranspose(self.filter_size[0], self.kernel_size[1], self.strides[1], name = name+'_dconv_4', padding = self.decoder_padding[3])(skipcon_4) 208 | dbn_4 = BatchNormalization(name = name+'_dbn_4')(deconv_4) 209 | dout_4 = PReLU(shared_axes=[1,2])(dbn_4) 210 | 211 | skipcon_5 = Concatenate(axis = -1)([out_1, dout_4]) 212 | 213 | deconv_5 = Conv2DTranspose(2, self.kernel_size[0], self.strides[0], name = name+'_dconv_5', padding = self.decoder_padding[4])(skipcon_5) 214 | 215 | deconv_5 = deconv_5[:,:,:-self.output_cut_off] 216 | 217 | dbn_5 = BatchNormalization(name = name+'_dbn_5')(deconv_5) 218 | 219 | mag_mask = Conv2DTranspose(1, self.kernel_size[0], self.strides[0], name = name+'mag_mask', padding = self.decoder_padding[4])(skipcon_5)[:,:,:-self.output_cut_off,0] 220 | 221 | # get magnitude mask 222 | if self.activation == 'sigmoid': 223 | self.mag_mask = Activation('sigmoid')(BatchNormalization()(mag_mask))*1.2 224 | elif self.activation == 'softplus': 225 | self.mag_mask = Activation('softplus')(BatchNormalization()(mag_mask)) 226 | 227 | # get phase mask 228 | phase_square = tf.math.sqrt(dbn_5[:,:,:,0]**2 + dbn_5[:,:,:,1]**2 + self.eps) 229 | 230 | self.phase_sin = dbn_5[:,:,:,1] / phase_square 231 | self.phase_cos = dbn_5[:,:,:,0] / phase_square 232 | 233 | self.enh_mag_real,self.enh_mag_imag = Lambda(self.mk_mask_mag)([real,imag,self.mag_mask]) 234 | 235 | enh_spec = Lambda(self.mk_mask_pha)([self.enh_mag_real,self.enh_mag_imag,self.phase_cos,self.phase_sin]) 236 | 237 | enh_frame = Lambda(self.ifftLayer,arguments = {'mode':'real_imag'})(enh_spec) 238 | enh_frame = enh_frame * self.win 239 | enh_time = Lambda(self.overlapAddLayer, name = 'enhanced_time')(enh_frame) 240 | 241 | self.model = 
Model([time_data, self.batch_gain, self.batch_scale], enh_time) 242 | self.model.summary() 243 | 244 | outputs = [enh_time] 245 | for update_gates in self.update_gates_intra: 246 | outputs.append(update_gates[None]) 247 | for update_gates in self.update_gates_inter: 248 | outputs.append(update_gates[None]) 249 | 250 | self.model_inference = Model([time_data, self.batch_scale], outputs) 251 | 252 | return self.model 253 | 254 | def compile_model(self): 255 | ''' 256 | Method to compile the model for training 257 | ''' 258 | # use the Adam optimizer with a clipnorm of 3 259 | optimizerAdam = keras.optimizers.Adam(lr=self.lr, clipnorm=3.0) 260 | # compile model with loss function 261 | self.model.compile(loss=self.lossWrapper(), optimizer = optimizerAdam, metrics = self.metricsWrapper()) 262 | 263 | def enhancement(self, noisy_f, output_f = './enhance_s.wav', plot = True, gain =1, gamma = 1, vad = None): 264 | ''' 265 | processing on a single wav 266 | noisy_f: noisy path 267 | output_f: output path 268 | plot: visualization 269 | gain: the level rescaling gain 270 | gamma: update rate scaling factor 271 | vad: the VAD label 272 | ''' 273 | noisy_s = sf.read(noisy_f,dtype = 'float32')[0]#[:400] 274 | 275 | N = librosa.util.frame(noisy_s,512,256).shape[-1] 276 | 277 | if vad is not None: 278 | # VAD guided skipping 279 | scale = np.ones([1,N,1,1]) 280 | scale[0,:,0,0] = scale[0,:,0,0] * (vad + (1-vad) * gamma) 281 | else: 282 | scale = np.ones([1,N,1,1]) * gamma 283 | 284 | enh_s, update_gate1_intra, update_gate2_intra, update_gate1_inter, update_gate2_inter = self.model_inference.predict([np.array([noisy_s])*gain,scale]) 285 | 286 | enh_s = enh_s[0] 287 | # visualization 288 | if plot: 289 | spec_n = librosa.stft(noisy_s,512,256,center = False) 290 | spec_e = librosa.stft(enh_s, 512,256,center = False) 291 | plt.figure(0) 292 | plt.plot(noisy_s) 293 | plt.plot(enh_s) 294 | plt.figure(1) 295 | plt.subplot(211) 296 | plt.imshow(np.log(abs(spec_n)+1e-8),cmap= 
'jet',origin ='lower') 297 | plt.subplot(212) 298 | plt.imshow(np.log(abs(spec_e)+1e-8),cmap= 'jet',origin ='lower') 299 | plt.figure(2) 300 | plt.subplot(211) 301 | plt.title('dprnn1-intra-chunk') 302 | plt.imshow(update_gate1_intra[0],origin ='lower',aspect='auto') 303 | plt.subplot(212) 304 | plt.title('dprnn1-inter-chunk') 305 | plt.imshow(update_gate1_inter[0],origin ='lower',aspect='auto') 306 | plt.figure(3) 307 | plt.subplot(211) 308 | plt.title('dprnn2-intra-chunk') 309 | plt.imshow(update_gate2_intra[0],origin ='lower',aspect='auto') 310 | plt.subplot(212) 311 | plt.title('dprnn2-inter-chunk') 312 | plt.imshow(update_gate2_inter[0],origin ='lower',aspect='auto') 313 | 314 | sf.write(output_f,enh_s,16000) 315 | 316 | return noisy_s,enh_s 317 | 318 | def test_on_dataset(self, noisy_path, target_path, gamma = 1): 319 | import tqdm 320 | f_list = os.listdir(noisy_path) 321 | for f in tqdm.tqdm(f_list): 322 | self.enhancement(noisy_f = os.path.join(noisy_path,f),output_f = os.path.join(target_path,f), plot = False, gamma = 1) 323 | 324 | 325 | if __name__ == '__main__': 326 | 327 | import matplotlib.pyplot as plt 328 | import yaml 329 | 330 | f = open('./configuration/DPCRN-skip.yaml','r',encoding='utf-8') 331 | result = f.read() 332 | print(result) 333 | 334 | config_dict = yaml.load(result) 335 | model = DPCRN_skip_model(batch_size = 1, length_in_s =5, lr = 1e-3, config = config_dict) 336 | model.build_DPCRN_model() 337 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SKIP-DPCRN 2 | Implementation of the DPCRN model with skipping strategy in the accepted manuscript for IEEE/ACM TASLP. 3 | 4 | # Requirements 5 | tensorflow>=1.14, 6 | numpy, 7 | matplotlib, 8 | librosa, 9 | sondfile. 
10 | # Run 11 | ```shell 12 | python main.py --mode train --cuda 0 --experimentName experiment_1 13 | ``` 14 | ```shell 15 | python main.py --mode test --cuda 0 --ckpt PATH_OF_PRETRAINED_MODEL --test_dir PATH_OF_NOISY_AUDIO --output_dir PATH_OF_ENHANCED_AUDIO 16 | ``` 17 | # Reference 18 | 1. X. Le, H. Chen, K. Chen, and J. Lu, “DPCRN: Dual-Path Convolution Recurrent Network for Single Channel Speech Enhancement,” Proc. Interspeech 2021, pp. 2811–2815, 2021. (https://github.com/Le-Xiaohuai-speech/DPCRN_DNS3) 19 | -------------------------------------------------------------------------------- /configuration/DPCRN-base.yaml: -------------------------------------------------------------------------------- 1 | name: DPCRN-base 2 | trainer: 3 | init_lr: 1e-3 4 | reduce_patience: 10 5 | early_stop: 20 6 | optimizer: adam 7 | max_epochs: 200 8 | seed: 42 9 | network: 10 | filter_size: [32,32,32,64,128] 11 | kernel_size: [[1,5],[1,3],[1,3],[1,3],[1,3]] 12 | strides: [[1,2],[1,2],[1,2],[1,1],[1,1]] 13 | encoder_padding: [[0,2],[0,1],[0,1],[1,1],[1,1]] 14 | decoder_padding: ['same', 'same', 'same', 'same', 'valid'] 15 | output_cut: 2 16 | N_DPRNN: 2 17 | DPRNN: 18 | intra_hidden_size: 128 19 | inter_hidden_size: 128 20 | use_CuDNNGRU: False 21 | activation: softplus # or sigmoid 22 | input_norm: batchnorm # or instantlayernorm 23 | test: 24 | test_data_dir: '' 25 | stft: 26 | fs: 16000 27 | block_len: 512 28 | block_shift: 256 29 | window: sine 30 | N_FFT: 512 31 | database: 32 | DNS_path: '' # the path of the DNS data 33 | WSJ_path: '' # the path of the WSJ data 34 | RIRs_path: ''# the path of the RIR data 35 | SNR: [-5,5] 36 | reverb_rate: 0.5 37 | spec_aug_rate: 0.3 38 | data_path: './temp' -------------------------------------------------------------------------------- /configuration/DPCRN-skip.yaml: -------------------------------------------------------------------------------- 1 | name: DPCRN-skip 2 | trainer: 3 | init_lr: 1e-3 4 | reduce_patience: 10 5 | early_stop: 
20 6 | optimizer: adam 7 | max_epochs: 200 8 | seed: 42 9 | loss: MSE # or MAE or MIN 10 | target: 0.5 # target update rate 11 | alpha: 0.01 # the weighting factor of the regularization 12 | network: 13 | filter_size: [32,32,32,64,128] 14 | kernel_size: [[1,5],[1,3],[1,3],[1,3],[1,3]] 15 | strides: [[1,2],[1,2],[1,2],[1,1],[1,1]] 16 | encoder_padding: [[0,2],[0,1],[0,1],[1,1],[1,1]] 17 | decoder_padding: ['same', 'same', 'same', 'same', 'valid'] 18 | output_cut: 2 19 | N_DPRNN: 2 20 | DPRNN: 21 | intra_hidden_size: 128 22 | inter_hidden_size: 128 23 | skip: 2 # 0 inter-skip, 1 intra-skip, 2 all-skip 24 | use_CuDNNGRU: False 25 | activation: softplus # or sigmoid 26 | input_norm: batchnorm # or instantlayernorm 27 | test: 28 | test_data_dir: '' 29 | stft: 30 | fs: 16000 31 | block_len: 512 32 | block_shift: 256 33 | window: sine 34 | N_FFT: 512 35 | database: 36 | noise_path: '' 37 | clean_path: '' 38 | RIRs_path: '' 39 | SNR: [-5,5] 40 | reverb_rate: 0.5 41 | spec_aug_rate: 0.3 42 | data_path: './temp' -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Jan 12 14:57:00 2021 4 | 5 | @author: xiaohuaile 6 | """ 7 | import soundfile as sf 8 | #from wavinfo import WavInfoReader 9 | from random import shuffle, seed 10 | import numpy as np 11 | import librosa 12 | import os 13 | from scipy import signal 14 | import scipy 15 | import tqdm 16 | 17 | #FIR, frequencies below 60Hz will be filtered 18 | fir = signal.firls(1025,[0,40,50,60,70,8000],[0,0,0.1,0.5,1,1],fs = 16000) 19 | 20 | # add the reverberation 21 | def add_pyreverb(clean_speech, rir): 22 | l = len(rir)//2 23 | reverb_speech = signal.fftconvolve(clean_speech, rir, mode="full") 24 | # make reverb_speech same length as clean_speech 25 | reverb_speech = reverb_speech[l : clean_speech.shape[0]+l] 26 | 27 | return reverb_speech 28 | 29 | 
# mix the signal with SNR 30 | def mk_mixture(s1,s2,snr,eps = 1e-8): 31 | 32 | norm_sig1 = s1 / np.sqrt(np.sum(s1 ** 2) + eps) 33 | norm_sig2 = s2 / np.sqrt(np.sum(s2 ** 2) + eps) 34 | alpha = 10**(snr/20) 35 | mix = norm_sig2 + alpha*norm_sig1 36 | M = max(np.max(abs(mix)),np.max(abs(norm_sig2)),np.max(abs(alpha*norm_sig1))) + eps 37 | mix = mix / M 38 | norm_sig1 = norm_sig1 * alpha/ M 39 | norm_sig2 = norm_sig2 / M 40 | #print('alp',alpha/ M) 41 | return norm_sig1,norm_sig2,mix,snr 42 | 43 | def get_energy(s,frame_length = 512, hop_length = 256): 44 | frames = librosa.util.frame(s,frame_length = frame_length,hop_length = hop_length) # keyword args: positional framing params are rejected by librosa>=0.10 45 | energy = np.sum(frames**2,axis = 0) 46 | return energy 47 | 48 | def get_VAD(s,frame_length = 512, hop_length = 256): 49 | s = s/np.max(abs(s)) 50 | energy = get_energy(s,frame_length,hop_length) 51 | thd = -4 52 | vad = np.zeros_like(energy) 53 | vad[np.log(energy)>thd]=1 54 | 55 | energy_1 = vad * energy 56 | thd1 = np.log((np.sum(energy_1)/sum(vad))/100+1e-8) 57 | vad = np.zeros_like(energy) 58 | vad[np.log(energy)>thd1]=1 59 | 60 | return energy,vad 61 | 62 | # random 2nd-order IIR filter for spectrum augmentation: all four coefficients are redrawn from U[-0.375, 0.375] on every call 63 | def spec_augment(s): 64 | r = np.random.uniform(-0.375,0.375,4) 65 | s_aug = signal.lfilter(b = [1,r[0],r[1]],a = [1,r[2],r[3]],x = s) 66 | return s_aug 67 | 68 | class data_generator(): 69 | 70 | def __init__(self, 71 | DNS_dir, 72 | WSJ_dir, 73 | RIR_dir, 74 | temp_data_dir, 75 | length_per_sample = 8, 76 | SNR_range = [-5,5], 77 | fs = 16000, 78 | n_fft = 512, 79 | n_hop = 256, 80 | batch_size = 16, 81 | sd = 42, 82 | add_reverb = True, 83 | reverb_rate = 0.5, 84 | spec_aug_rate = 0.3, 85 | ): 86 | ''' 87 | keras data generator 88 | Para.: 89 | DNS_dir: the folder of the DNS data, including DNS_dir/clean, DNS_dir/noise 90 | WSJ_dir: the folder of the WSJ data, including train_dir/clean, train_dir/noise 91 | RIR_dir: the folder of RIRs, from OpenSLR26 and OpenSLR28 92 | temp_data_dir: the folder for temporary data storing 93 | length_per_sample:
speech sample length in second 94 | SNR_range: the upper and lower bound of the SNR 95 | fs: sample rate of the speech 96 | n_fft: FFT length and window length in STFT 97 | n_hop: hop length in STFT 98 | batch_size: batch size 99 | sample_num: how many samples are used for training and validation 100 | add_reverb: adding reverbrantion or not 101 | reverb_rate: how much data is reverbrant 102 | ''' 103 | seed(sd) 104 | np.random.seed(sd) 105 | 106 | self.fs = fs 107 | self.batch_size = batch_size 108 | self.length_per_sample = length_per_sample 109 | self.L = length_per_sample * self.fs 110 | # calculate the length of each sample after iSTFT 111 | self.points_per_sample = ((self.L - n_fft) // n_hop) * n_hop + n_fft 112 | 113 | self.SNR_range = SNR_range 114 | self.add_reverb = add_reverb 115 | self.reverb_rate = reverb_rate 116 | self.spec_aug_rate = spec_aug_rate 117 | 118 | self.DNS_dir = DNS_dir 119 | self.WSJ_dir = WSJ_dir 120 | self.RIR_dir = RIR_dir 121 | 122 | self.noise_dir = os.path.join(self.DNS_dir,'noise') 123 | self.noise_file_list = os.listdir(self.noise_dir) 124 | if not os.path.exists(temp_data_dir): 125 | self.train_wsj_dir, self.valid_wsj_dir = self.preproccess(self.WSJ_dir, temp_data_dir) 126 | else: 127 | self.train_wsj_dir = os.path.join(temp_data_dir,'si_tr_s') 128 | self.valid_wsj_dir = os.path.join(temp_data_dir,'si_dt_05') 129 | 130 | self.train_wsj_data = librosa.util.find_files(self.train_wsj_dir,ext='npy') 131 | self.valid_wsj_data = librosa.util.find_files(self.valid_wsj_dir,ext='npy') 132 | np.random.shuffle(self.train_wsj_data) 133 | np.random.shuffle(self.valid_wsj_data) 134 | 135 | if RIR_dir is not None: 136 | self.rir_dir = RIR_dir 137 | self.rir_list = librosa.util.find_files(self.rir_dir,ext = 'wav') 138 | np.random.shuffle(self.rir_list) 139 | print('there are {} rir clips\n'.format(len(self.rir_list))) 140 | 141 | self.train_length = len(self.train_wsj_data) 142 | self.valid_length = len(self.valid_wsj_data) 143 | print('there 
are {} clips for training, and {} clips for validation.'.format(self.train_length, self.valid_length)) 144 | 145 | def preproccess(self, WSJ_dir, data_dir): 146 | ''' 147 | concatenate the clean speech and split them into 8s clips 148 | ''' 149 | if not os.path.exists(data_dir): 150 | os.mkdir(data_dir) 151 | 152 | train_dir = os.path.join(self.WSJ_dir,'si_tr_s') 153 | valid_dir = os.path.join(self.WSJ_dir,'si_dt_05') 154 | 155 | os.mkdir(os.path.join(data_dir,'si_tr_s')) 156 | os.mkdir(os.path.join(data_dir,'si_dt_05')) 157 | 158 | train_wavs = librosa.util.find_files(train_dir,ext='wav') 159 | valid_wavs = librosa.util.find_files(valid_dir,ext='wav') 160 | 161 | train_N_samples = 0 162 | valid_N_samples = 0 163 | 164 | for wav in train_wavs: 165 | train_N_samples += round(sf.info(wav).duration * self.fs) 166 | for wav in valid_wavs: 167 | valid_N_samples += round(sf.info(wav).duration * self.fs) 168 | 169 | temp_train = np.zeros(train_N_samples, dtype = 'int16') 170 | N_samples = train_N_samples // self.L 171 | begin = 0 172 | print('prepare clean data...\n') 173 | for wav in tqdm.tqdm(train_wavs): 174 | s = sf.read(wav)[0] 175 | s = s / np.max(abs(s)) 176 | temp_train[begin:begin+len(s)] = s 177 | begin += len(s) 178 | 179 | for i in tqdm.tqdm(range(N_samples)): 180 | np.save(os.path.join(data_dir,'si_tr_s','{}.npy'.format(i)),temp_train[self.L*i:self.L*(i+1)]) 181 | 182 | del temp_train 183 | 184 | temp_valid = np.zeros(valid_N_samples, dtype = 'int16') 185 | N_samples = valid_N_samples // self.L 186 | 187 | begin = 0 188 | for wav in tqdm.tqdm(valid_wavs): 189 | s = sf.read(wav)[0] 190 | s = s / np.max(abs(s)) 191 | temp_valid[begin:begin+len(s)] = s 192 | begin += len(s) 193 | 194 | for i in tqdm.tqdm(range(N_samples)): 195 | np.save(os.path.join(data_dir,'si_dt_05','{}.npy'.format(i)),temp_valid[self.L*i:self.L*(i+1)]) 196 | 197 | del temp_valid 198 | return os.path.join(data_dir,'si_tr_s'),os.path.join(data_dir,'si_dt_05') 199 | 200 | def generator(self, 
batch_size, validation = False): 201 | 202 | if validation: 203 | train_data = self.valid_wsj_data 204 | else: 205 | train_data = self.train_wsj_data 206 | 207 | N_batch = len(train_data) // batch_size 208 | batch_num = 0 209 | 210 | while (True): 211 | 212 | batch_clean = np.zeros([batch_size,self.L],dtype = np.float32) 213 | batch_noisy = np.zeros([batch_size,self.L],dtype = np.float32) 214 | batch_gain = np.zeros([batch_size,1],dtype = np.float32) 215 | 216 | rir_f_list = np.random.choice(self.rir_list, batch_size) 217 | noise_f_list = np.random.choice(self.noise_file_list,batch_size) 218 | 219 | for i in range(batch_size): 220 | 221 | SNR = np.random.uniform(self.SNR_range[0],self.SNR_range[1]) 222 | # level rescaling gain 223 | gain = np.random.normal(loc=-5,scale=10) 224 | gain = 10**(gain/10) 225 | gain = min(gain,5) 226 | gain = max(gain,0.01) 227 | 228 | sample_num = batch_num * batch_size + i 229 | clean_f = train_data[sample_num] 230 | 231 | noise_f = noise_f_list[i] 232 | Begin_N = int(np.random.uniform(0, 30 - self.length_per_sample)) * self.fs 233 | # read clean speech and noises 234 | clean_s = np.load(clean_f) / 32768.0 235 | noise_s = sf.read(os.path.join(self.noise_dir,noise_f), dtype = 'float32',start= Begin_N,stop = Begin_N + self.L)[0] 236 | 237 | # high pass filtering 238 | clean_s = add_pyreverb(clean_s, fir) 239 | # spectrum augmentation 240 | if np.random.rand() < self.spec_aug_rate: 241 | clean_s = spec_augment(clean_s) 242 | # add reverberation 243 | if self.add_reverb: 244 | if np.random.rand() < self.reverb_rate: 245 | rir_s = sf.read(rir_f_list[i],dtype = 'float32')[0] 246 | if len(rir_s.shape)>1: 247 | rir_s = rir_s[:,0] 248 | if clean_f.split('_')[0] == 'clean': 249 | clean_s = add_pyreverb(clean_s, rir_s) 250 | # mix the clean speech and the noise 251 | clean_s,noise_s,noisy_s,_ = mk_mixture(clean_s, noise_s, SNR, eps = 1e-8) 252 | # rescaling 253 | batch_clean[i,:] = clean_s *gain 254 | batch_noisy[i,:] = noisy_s *gain 255 | 
batch_gain[i] = gain 256 | 257 | batch_num += 1 258 | if batch_num == N_batch: 259 | batch_num = 0 260 | 261 | if self.use_cross_valid: 262 | self.train_list, self.validation_list = self.generating_train_validation(self.train_length) 263 | if validation: 264 | train_data = self.valid_wsj_data 265 | else: 266 | train_data = self.train_wsj_data 267 | 268 | np.random.shuffle(train_data) 269 | np.random.shuffle(self.noise_file_list) 270 | 271 | N_batch = len(train_data) // batch_size 272 | 273 | yield [batch_noisy,batch_gain], batch_clean 274 | 275 | if __name__ == '__main__': 276 | 277 | dg = data_generator(DNS_dir = '/data/ssd1/xiaohuai.le/DNS_data1/DNS_data', 278 | WSJ_dir = '/data/hdd0/xiaohuaile/g6_data/Speech_database/WSJ/wsj0/', 279 | RIR_dir = '/data/ssd1/xiaohuai.le/RIR_database/impulse_responses/', 280 | temp_data_dir = './temp_data',) 281 | 282 | 283 | -------------------------------------------------------------------------------- /enhance_s.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Xiaohuai-speech/SKIP-DPCRN/40e9bb6468bbfecff6ddd698eb63dac7b1089472/enhance_s.wav -------------------------------------------------------------------------------- /evaluations/mir_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Source separation algorithms attempt to extract recordings of individual 4 | sources from a recording of a mixture of sources. Evaluation methods for 5 | source separation compare the extracted sources from reference sources and 6 | attempt to measure the perceptual quality of the separation. 7 | 8 | See also the bss_eval MATLAB toolbox: 9 | http://bass-db.gforge.inria.fr/bss_eval/ 10 | 11 | Conventions 12 | ----------- 13 | 14 | An audio signal is expected to be in the format of a 1-dimensional array where 15 | the entries are the samples of the audio signal. 
When providing a group of 16 | estimated or reference sources, they should be provided in a 2-dimensional 17 | array, where the first dimension corresponds to the source number and the 18 | second corresponds to the samples. 19 | 20 | Metrics 21 | ------- 22 | 23 | * :func:`mir_eval.separation.bss_eval_sources`: Computes the bss_eval_sources 24 | metrics from bss_eval, which optionally optimally match the estimated sources 25 | to the reference sources and measure the distortion and artifacts present in 26 | the estimated sources as well as the interference between them. 27 | 28 | * :func:`mir_eval.separation.bss_eval_sources_framewise`: Computes the 29 | bss_eval_sources metrics on a frame-by-frame basis. 30 | 31 | * :func:`mir_eval.separation.bss_eval_images`: Computes the bss_eval_images 32 | metrics from bss_eval, which includes the metrics in 33 | :func:`mir_eval.separation.bss_eval_sources` plus the image to spatial 34 | distortion ratio. 35 | 36 | * :func:`mir_eval.separation.bss_eval_images_framewise`: Computes the 37 | bss_eval_images metrics on a frame-by-frame basis. 38 | 39 | References 40 | ---------- 41 | .. [#vincent2006performance] Emmanuel Vincent, Rémi Gribonval, and Cédric 42 | Févotte, "Performance measurement in blind audio source separation," IEEE 43 | Trans. on Audio, Speech and Language Processing, 14(4):1462-1469, 2006. 44 | 这个是盲源分离算法的评价指标 45 | 46 | ''' 47 | 48 | import numpy as np 49 | import scipy.fftpack 50 | from scipy.linalg import toeplitz 51 | from scipy.signal import fftconvolve 52 | import collections 53 | import itertools 54 | import warnings 55 | #from utils import mir_util as util 56 | 57 | 58 | # The maximum allowable number of sources (prevents insane computational load) 59 | MAX_SOURCES = 100 60 | 61 | 62 | def validate(reference_sources, estimated_sources): 63 | """Checks that the input data to a metric are valid, and throws helpful 64 | errors if not. 
65 | 66 | Parameters 67 | ---------- 68 | reference_sources : np.ndarray, shape=(nsrc, nsampl) 69 | matrix containing true sources 70 | estimated_sources : np.ndarray, shape=(nsrc, nsampl) 71 | matrix containing estimated sources 72 | 73 | """ 74 | 75 | if reference_sources.shape != estimated_sources.shape: 76 | raise ValueError('The shape of estimated sources and the true ' 77 | 'sources should match. reference_sources.shape ' 78 | '= {}, estimated_sources.shape ' 79 | '= {}'.format(reference_sources.shape, 80 | estimated_sources.shape)) 81 | 82 | if reference_sources.ndim > 3 or estimated_sources.ndim > 3: 83 | raise ValueError('The number of dimensions is too high (must be less ' 84 | 'than 3). reference_sources.ndim = {}, ' 85 | 'estimated_sources.ndim ' 86 | '= {}'.format(reference_sources.ndim, 87 | estimated_sources.ndim)) 88 | 89 | if reference_sources.size == 0: 90 | warnings.warn("reference_sources is empty, should be of size " 91 | "(nsrc, nsample). sdr, sir, sar, and perm will all " 92 | "be empty np.ndarrays") 93 | elif _any_source_silent(reference_sources): 94 | raise ValueError('All the reference sources should be non-silent (not ' 95 | 'all-zeros), but at least one of the reference ' 96 | 'sources is all 0s, which introduces ambiguity to the' 97 | ' evaluation. (Otherwise we can add infinitely many ' 98 | 'all-zero sources.)') 99 | 100 | if estimated_sources.size == 0: 101 | warnings.warn("estimated_sources is empty, should be of size " 102 | "(nsrc, nsample). sdr, sir, sar, and perm will all " 103 | "be empty np.ndarrays") 104 | elif _any_source_silent(estimated_sources): 105 | raise ValueError('All the estimated sources should be non-silent (not ' 106 | 'all-zeros), but at least one of the estimated ' 107 | 'sources is all 0s. 
Since we require each reference ' 108 | 'source to be non-silent, having a silent estimated ' 109 | 'source will result in an underdetermined system.') 110 | 111 | if (estimated_sources.shape[0] > MAX_SOURCES or 112 | reference_sources.shape[0] > MAX_SOURCES): 113 | raise ValueError('The supplied matrices should be of shape (nsrc,' 114 | ' nsampl) but reference_sources.shape[0] = {} and ' 115 | 'estimated_sources.shape[0] = {} which is greater ' 116 | 'than mir_eval.separation.MAX_SOURCES = {}. To ' 117 | 'override this check, set ' 118 | 'mir_eval.separation.MAX_SOURCES to a ' 119 | 'larger value.'.format(reference_sources.shape[0], 120 | estimated_sources.shape[0], 121 | MAX_SOURCES)) 122 | 123 | 124 | def _any_source_silent(sources): 125 | """Returns true if the parameter sources has any silent first dimensions""" 126 | return np.any(np.all(np.sum( 127 | sources, axis=tuple(range(2, sources.ndim))) == 0, axis=1)) 128 | 129 | 130 | def bss_eval_sources(reference_sources, estimated_sources, 131 | compute_permutation=True): 132 | """ 133 | 对源分离指标进行测量与排序 134 | Ordering and measurement of the separation quality for estimated source 135 | signals in terms of filtered true source, interference and artifacts. 136 | 允许5126点的线性时不变滤波器失真 137 | The decomposition allows a time-invariant filter distortion of length 138 | 512, as described in Section III.B of [#vincent2006performance]_. 139 | 这个函数会估计出参考源(1维或是二维的数组)与估计源之间的sdr sir sar 与排序属性值 140 | Passing ``False`` for ``compute_permutation`` will improve the computation 141 | performance of the evaluation; however, it is not always appropriate and 142 | is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_sources. 143 | 144 | Examples 145 | -------- 146 | >>> # reference_sources[n] should be an ndarray of samples of the 147 | >>> # n'th reference source 148 | >>> # estimated_sources[n] should be the same for the n'th estimated 149 | >>> # source 150 | >>> (sdr, sir, sar, 151 | ... 
perm) = mir_eval.separation.bss_eval_sources(reference_sources, 152 | ... estimated_sources) 153 | 154 | Parameters 155 | ---------- 156 | reference_sources : np.ndarray, shape=(nsrc, nsampl) 157 | matrix containing true sources (must have same shape as 158 | estimated_sources) 159 | estimated_sources : np.ndarray, shape=(nsrc, nsampl) 160 | matrix containing estimated sources (must have same shape as 161 | reference_sources) 162 | compute_permutation : bool, optional 163 | compute permutation of estimate/source combinations (True by default) 164 | 165 | Returns 166 | ------- 167 | sdr : np.ndarray, shape=(nsrc,) 168 | vector of Signal to Distortion Ratios (SDR) 169 | sir : np.ndarray, shape=(nsrc,) 170 | vector of Source to Interference Ratios (SIR) 171 | sar : np.ndarray, shape=(nsrc,) 172 | vector of Sources to Artifacts Ratios (SAR) 173 | perm : np.ndarray, shape=(nsrc,) 174 | vector containing the best ordering of estimated sources in 175 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to 176 | true source number ``j``). Note: ``perm`` will be ``[0, 1, ..., 177 | nsrc-1]`` if ``compute_permutation`` is ``False``. 178 | 179 | References 180 | ---------- 181 | .. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau 182 | Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik 183 | Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign 184 | (2007-2010): Achievements and remaining challenges", Signal Processing, 185 | 92, pp. 1928-1936, 2012. 
186 | 187 | """ 188 | 189 | # make sure the input is of shape (nsrc, nsampl) 190 | if estimated_sources.ndim == 1: 191 | estimated_sources = estimated_sources[np.newaxis, :] 192 | if reference_sources.ndim == 1: 193 | reference_sources = reference_sources[np.newaxis, :] 194 | 195 | validate(reference_sources, estimated_sources) 196 | # If empty matrices were supplied, return empty lists (special case) 197 | if reference_sources.size == 0 or estimated_sources.size == 0: 198 | return np.array([]), np.array([]), np.array([]), np.array([]) 199 | 200 | nsrc = estimated_sources.shape[0] 201 | 202 | # does user desire permutations? 203 | if compute_permutation: 204 | # compute criteria for all possible pair matches 205 | sdr = np.empty((nsrc, nsrc)) 206 | sir = np.empty((nsrc, nsrc)) 207 | sar = np.empty((nsrc, nsrc)) 208 | for jest in range(nsrc): 209 | for jtrue in range(nsrc): 210 | s_true, e_spat, e_interf, e_artif = \ 211 | _bss_decomp_mtifilt(reference_sources, 212 | estimated_sources[jest], 213 | jtrue, 512) 214 | sdr[jest, jtrue], sir[jest, jtrue], sar[jest, jtrue] = \ 215 | _bss_source_crit(s_true, e_spat, e_interf, e_artif) 216 | 217 | # select the best ordering 218 | perms = list(itertools.permutations(list(range(nsrc)))) 219 | mean_sir = np.empty(len(perms)) 220 | dum = np.arange(nsrc) 221 | for (i, perm) in enumerate(perms): 222 | mean_sir[i] = np.mean(sir[perm, dum]) 223 | popt = perms[np.argmax(mean_sir)] 224 | idx = (popt, dum) 225 | return (sdr[idx], sir[idx], sar[idx], np.asarray(popt)) 226 | else: 227 | # compute criteria for only the simple correspondence 228 | # (estimate 1 is estimate corresponding to reference source 1, etc.) 
229 | sdr = np.empty(nsrc) 230 | sir = np.empty(nsrc) 231 | sar = np.empty(nsrc) 232 | for j in range(nsrc): 233 | s_true, e_spat, e_interf, e_artif = \ 234 | _bss_decomp_mtifilt(reference_sources, 235 | estimated_sources[j], 236 | j, 512) 237 | sdr[j], sir[j], sar[j] = \ 238 | _bss_source_crit(s_true, e_spat, e_interf, e_artif) 239 | 240 | # return the default permutation for compatibility 241 | popt = np.arange(nsrc) 242 | return (sdr, sir, sar, popt) 243 | 244 | 245 | def bss_eval_sources_framewise(reference_sources, estimated_sources, 246 | window=30*44100, hop=15*44100, 247 | compute_permutation=False): 248 | """Framewise computation of bss_eval_sources 249 | 250 | Please be aware that this function does not compute permutations (by 251 | default) on the possible relations between reference_sources and 252 | estimated_sources due to the dangers of a changing permutation. Therefore 253 | (by default), it assumes that ``reference_sources[i]`` corresponds to 254 | ``estimated_sources[i]``. To enable computing permutations please set 255 | ``compute_permutation`` to be ``True`` and check that the returned ``perm`` 256 | is identical for all windows. 257 | 258 | NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated 259 | using only a single window or are shorter than the window length, the 260 | result of :func:`mir_eval.separation.bss_eval_sources` called on 261 | ``reference_sources`` and ``estimated_sources`` (with the 262 | ``compute_permutation`` parameter passed to 263 | :func:`mir_eval.separation.bss_eval_sources`) is returned. 264 | 265 | Examples 266 | -------- 267 | >>> # reference_sources[n] should be an ndarray of samples of the 268 | >>> # n'th reference source 269 | >>> # estimated_sources[n] should be the same for the n'th estimated 270 | >>> # source 271 | >>> (sdr, sir, sar, 272 | ... perm) = mir_eval.separation.bss_eval_sources_framewise( 273 | reference_sources, 274 | ... 
estimated_sources) 275 | 276 | Parameters 277 | ---------- 278 | reference_sources : np.ndarray, shape=(nsrc, nsampl) 279 | matrix containing true sources (must have the same shape as 280 | ``estimated_sources``) 281 | estimated_sources : np.ndarray, shape=(nsrc, nsampl) 282 | matrix containing estimated sources (must have the same shape as 283 | ``reference_sources``) 284 | window : int, optional 285 | Window length for framewise evaluation (default value is 30s at a 286 | sample rate of 44.1kHz) 287 | hop : int, optional 288 | Hop size for framewise evaluation (default value is 15s at a 289 | sample rate of 44.1kHz) 290 | compute_permutation : bool, optional 291 | compute permutation of estimate/source combinations for all windows 292 | (False by default) 293 | 294 | Returns 295 | ------- 296 | sdr : np.ndarray, shape=(nsrc, nframes) 297 | vector of Signal to Distortion Ratios (SDR) 298 | sir : np.ndarray, shape=(nsrc, nframes) 299 | vector of Source to Interference Ratios (SIR) 300 | sar : np.ndarray, shape=(nsrc, nframes) 301 | vector of Sources to Artifacts Ratios (SAR) 302 | perm : np.ndarray, shape=(nsrc, nframes) 303 | vector containing the best ordering of estimated sources in 304 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to 305 | true source number ``j``). 
Note: ``perm`` will be ``range(nsrc)`` for 306 | all windows if ``compute_permutation`` is ``False`` 307 | 308 | """ 309 | 310 | # make sure the input is of shape (nsrc, nsampl) 311 | if estimated_sources.ndim == 1: 312 | estimated_sources = estimated_sources[np.newaxis, :] 313 | if reference_sources.ndim == 1: 314 | reference_sources = reference_sources[np.newaxis, :] 315 | 316 | validate(reference_sources, estimated_sources) 317 | # If empty matrices were supplied, return empty lists (special case) 318 | if reference_sources.size == 0 or estimated_sources.size == 0: 319 | return np.array([]), np.array([]), np.array([]), np.array([]) 320 | 321 | nsrc = reference_sources.shape[0] 322 | 323 | nwin = int( 324 | np.floor((reference_sources.shape[1] - window + hop) / hop) 325 | ) 326 | # if fewer than 2 windows would be evaluated, return the sources result 327 | if nwin < 2: 328 | result = bss_eval_sources(reference_sources, 329 | estimated_sources, 330 | compute_permutation) 331 | return [np.expand_dims(score, -1) for score in result] 332 | 333 | # compute the criteria across all windows 334 | sdr = np.empty((nsrc, nwin)) 335 | sir = np.empty((nsrc, nwin)) 336 | sar = np.empty((nsrc, nwin)) 337 | perm = np.empty((nsrc, nwin)) 338 | 339 | # k iterates across all the windows 340 | for k in range(nwin): 341 | win_slice = slice(k * hop, k * hop + window) 342 | ref_slice = reference_sources[:, win_slice] 343 | est_slice = estimated_sources[:, win_slice] 344 | # check for a silent frame 345 | if (not _any_source_silent(ref_slice) and 346 | not _any_source_silent(est_slice)): 347 | sdr[:, k], sir[:, k], sar[:, k], perm[:, k] = bss_eval_sources( 348 | ref_slice, est_slice, compute_permutation 349 | ) 350 | else: 351 | # if we have a silent frame set results as np.nan 352 | sdr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan 353 | 354 | return sdr, sir, sar, perm 355 | 356 | 357 | def bss_eval_images(reference_sources, estimated_sources, 358 | compute_permutation=True): 
359 | """Implementation of the bss_eval_images function from the 360 | BSS_EVAL Matlab toolbox. 361 | 362 | Ordering and measurement of the separation quality for estimated source 363 | signals in terms of filtered true source, interference and artifacts. 364 | This method also provides the ISR measure. 365 | 366 | The decomposition allows a time-invariant filter distortion of length 367 | 512, as described in Section III.B of [#vincent2006performance]_. 368 | 369 | Passing ``False`` for ``compute_permutation`` will improve the computation 370 | performance of the evaluation; however, it is not always appropriate and 371 | is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_images. 372 | 373 | Examples 374 | -------- 375 | >>> # reference_sources[n] should be an ndarray of samples of the 376 | >>> # n'th reference source 377 | >>> # estimated_sources[n] should be the same for the n'th estimated 378 | >>> # source 379 | >>> (sdr, isr, sir, sar, 380 | ... perm) = mir_eval.separation.bss_eval_images(reference_sources, 381 | ... 
estimated_sources) 382 | 383 | Parameters 384 | ---------- 385 | reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan) 386 | matrix containing true sources 387 | estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan) 388 | matrix containing estimated sources 389 | compute_permutation : bool, optional 390 | compute permutation of estimate/source combinations (True by default) 391 | 392 | Returns 393 | ------- 394 | sdr : np.ndarray, shape=(nsrc,) 395 | vector of Signal to Distortion Ratios (SDR) 396 | isr : np.ndarray, shape=(nsrc,) 397 | vector of source Image to Spatial distortion Ratios (ISR) 398 | sir : np.ndarray, shape=(nsrc,) 399 | vector of Source to Interference Ratios (SIR) 400 | sar : np.ndarray, shape=(nsrc,) 401 | vector of Sources to Artifacts Ratios (SAR) 402 | perm : np.ndarray, shape=(nsrc,) 403 | vector containing the best ordering of estimated sources in 404 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to 405 | true source number ``j``). Note: ``perm`` will be ``(1,2,...,nsrc)`` 406 | if ``compute_permutation`` is ``False``. 407 | 408 | References 409 | ---------- 410 | .. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau 411 | Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik 412 | Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign 413 | (2007-2010): Achievements and remaining challenges", Signal Processing, 414 | 92, pp. 1928-1936, 2012. 
415 | 416 | """ 417 | 418 | # make sure the input has 3 dimensions 419 | # assuming input is in shape (nsampl) or (nsrc, nsampl) 420 | estimated_sources = np.atleast_3d(estimated_sources) 421 | reference_sources = np.atleast_3d(reference_sources) 422 | # we will ensure input doesn't have more than 3 dimensions in validate 423 | 424 | validate(reference_sources, estimated_sources) 425 | # If empty matrices were supplied, return empty lists (special case) 426 | if reference_sources.size == 0 or estimated_sources.size == 0: 427 | return np.array([]), np.array([]), np.array([]), \ 428 | np.array([]), np.array([]) 429 | 430 | # determine size parameters 431 | nsrc = estimated_sources.shape[0] 432 | nsampl = estimated_sources.shape[1] 433 | nchan = estimated_sources.shape[2] 434 | 435 | # does the user desire permutation? 436 | if compute_permutation: 437 | # compute criteria for all possible pair matches 438 | sdr = np.empty((nsrc, nsrc)) 439 | isr = np.empty((nsrc, nsrc)) 440 | sir = np.empty((nsrc, nsrc)) 441 | sar = np.empty((nsrc, nsrc)) 442 | for jest in range(nsrc): 443 | for jtrue in range(nsrc): 444 | s_true, e_spat, e_interf, e_artif = \ 445 | _bss_decomp_mtifilt_images( 446 | reference_sources, 447 | np.reshape( 448 | estimated_sources[jest], 449 | (nsampl, nchan), 450 | order='F' 451 | ), 452 | jtrue, 453 | 512 454 | ) 455 | sdr[jest, jtrue], isr[jest, jtrue], \ 456 | sir[jest, jtrue], sar[jest, jtrue] = \ 457 | _bss_image_crit(s_true, e_spat, e_interf, e_artif) 458 | 459 | # select the best ordering 460 | perms = list(itertools.permutations(range(nsrc))) 461 | mean_sir = np.empty(len(perms)) 462 | dum = np.arange(nsrc) 463 | for (i, perm) in enumerate(perms): 464 | mean_sir[i] = np.mean(sir[perm, dum]) 465 | popt = perms[np.argmax(mean_sir)] 466 | idx = (popt, dum) 467 | return (sdr[idx], isr[idx], sir[idx], sar[idx], np.asarray(popt)) 468 | else: 469 | # compute criteria for only the simple correspondence 470 | # (estimate 1 is estimate corresponding to 
reference source 1, etc.) 471 | sdr = np.empty(nsrc) 472 | isr = np.empty(nsrc) 473 | sir = np.empty(nsrc) 474 | sar = np.empty(nsrc) 475 | Gj = [0] * nsrc # prepare G matrics with zeroes 476 | G = np.zeros(1) 477 | for j in range(nsrc): 478 | # save G matrix to avoid recomputing it every call 479 | s_true, e_spat, e_interf, e_artif, Gj_temp, G = \ 480 | _bss_decomp_mtifilt_images(reference_sources, 481 | np.reshape(estimated_sources[j], 482 | (nsampl, nchan), 483 | order='F'), 484 | j, 512, Gj[j], G) 485 | Gj[j] = Gj_temp 486 | sdr[j], isr[j], sir[j], sar[j] = \ 487 | _bss_image_crit(s_true, e_spat, e_interf, e_artif) 488 | 489 | # return the default permutation for compatibility 490 | popt = np.arange(nsrc) 491 | return (sdr, isr, sir, sar, popt) 492 | 493 | 494 | def bss_eval_images_framewise(reference_sources, estimated_sources, 495 | window=30*44100, hop=15*44100, 496 | compute_permutation=False): 497 | """Framewise computation of bss_eval_images 498 | 499 | Please be aware that this function does not compute permutations (by 500 | default) on the possible relations between ``reference_sources`` and 501 | ``estimated_sources`` due to the dangers of a changing permutation. 502 | Therefore (by default), it assumes that ``reference_sources[i]`` 503 | corresponds to ``estimated_sources[i]``. To enable computing permutations 504 | please set ``compute_permutation`` to be ``True`` and check that the 505 | returned ``perm`` is identical for all windows. 
506 | 507 | NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated 508 | using only a single window or are shorter than the window length, the 509 | result of ``bss_eval_images`` called on ``reference_sources`` and 510 | ``estimated_sources`` (with the ``compute_permutation`` parameter passed to 511 | ``bss_eval_images``) is returned 512 | 513 | Examples 514 | -------- 515 | >>> # reference_sources[n] should be an ndarray of samples of the 516 | >>> # n'th reference source 517 | >>> # estimated_sources[n] should be the same for the n'th estimated 518 | >>> # source 519 | >>> (sdr, isr, sir, sar, 520 | ... perm) = mir_eval.separation.bss_eval_images_framewise( 521 | reference_sources, 522 | ... estimated_sources, 523 | window, 524 | .... hop) 525 | 526 | Parameters 527 | ---------- 528 | reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan) 529 | matrix containing true sources (must have the same shape as 530 | ``estimated_sources``) 531 | estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan) 532 | matrix containing estimated sources (must have the same shape as 533 | ``reference_sources``) 534 | window : int 535 | Window length for framewise evaluation 536 | hop : int 537 | Hop size for framewise evaluation 538 | compute_permutation : bool, optional 539 | compute permutation of estimate/source combinations for all windows 540 | (False by default) 541 | 542 | Returns 543 | ------- 544 | sdr : np.ndarray, shape=(nsrc, nframes) 545 | vector of Signal to Distortion Ratios (SDR) 546 | isr : np.ndarray, shape=(nsrc, nframes) 547 | vector of source Image to Spatial distortion Ratios (ISR) 548 | sir : np.ndarray, shape=(nsrc, nframes) 549 | vector of Source to Interference Ratios (SIR) 550 | sar : np.ndarray, shape=(nsrc, nframes) 551 | vector of Sources to Artifacts Ratios (SAR) 552 | perm : np.ndarray, shape=(nsrc, nframes) 553 | vector containing the best ordering of estimated sources in 554 | the mean SIR sense (estimated source number 
perm[j] corresponds to 555 | true source number j) 556 | Note: perm will be range(nsrc) for all windows if compute_permutation 557 | is False 558 | 559 | """ 560 | 561 | # make sure the input has 3 dimensions 562 | # assuming input is in shape (nsampl) or (nsrc, nsampl) 563 | estimated_sources = np.atleast_3d(estimated_sources) 564 | reference_sources = np.atleast_3d(reference_sources) 565 | # we will ensure input doesn't have more than 3 dimensions in validate 566 | 567 | validate(reference_sources, estimated_sources) 568 | # If empty matrices were supplied, return empty lists (special case) 569 | if reference_sources.size == 0 or estimated_sources.size == 0: 570 | return np.array([]), np.array([]), np.array([]), np.array([]) 571 | 572 | nsrc = reference_sources.shape[0] 573 | 574 | nwin = int( 575 | np.floor((reference_sources.shape[1] - window + hop) / hop) 576 | ) 577 | # if fewer than 2 windows would be evaluated, return the images result 578 | if nwin < 2: 579 | result = bss_eval_images(reference_sources, 580 | estimated_sources, 581 | compute_permutation) 582 | return [np.expand_dims(score, -1) for score in result] 583 | 584 | # compute the criteria across all windows 585 | sdr = np.empty((nsrc, nwin)) 586 | isr = np.empty((nsrc, nwin)) 587 | sir = np.empty((nsrc, nwin)) 588 | sar = np.empty((nsrc, nwin)) 589 | perm = np.empty((nsrc, nwin)) 590 | 591 | # k iterates across all the windows 592 | for k in range(nwin): 593 | win_slice = slice(k * hop, k * hop + window) 594 | ref_slice = reference_sources[:, win_slice, :] 595 | est_slice = estimated_sources[:, win_slice, :] 596 | # check for a silent frame 597 | if (not _any_source_silent(ref_slice) and 598 | not _any_source_silent(est_slice)): 599 | sdr[:, k], isr[:, k], sir[:, k], sar[:, k], perm[:, k] = \ 600 | bss_eval_images( 601 | ref_slice, est_slice, compute_permutation 602 | ) 603 | else: 604 | # if we have a silent frame set results as np.nan 605 | sdr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan 
606 | 607 | return sdr, isr, sir, sar, perm 608 | 609 | 610 | def _bss_decomp_mtifilt(reference_sources, estimated_source, j, flen): 611 | """Decomposition of an estimated source image into four components 612 | representing respectively the true source image, spatial (or filtering) 613 | distortion, interference and artifacts, derived from the true source 614 | images using multichannel time-invariant filters. 615 | """ 616 | nsampl = estimated_source.size 617 | # decomposition 618 | # true source image 619 | s_true = np.hstack((reference_sources[j], np.zeros(flen - 1))) 620 | # spatial (or filtering) distortion 621 | e_spat = _project(reference_sources[j, np.newaxis, :], estimated_source, 622 | flen) - s_true 623 | # interference 624 | e_interf = _project(reference_sources, 625 | estimated_source, flen) - s_true - e_spat 626 | # artifacts 627 | e_artif = -s_true - e_spat - e_interf 628 | e_artif[:nsampl] += estimated_source 629 | return (s_true, e_spat, e_interf, e_artif) 630 | 631 | 632 | def _bss_decomp_mtifilt_images(reference_sources, estimated_source, j, flen, 633 | Gj=None, G=None): 634 | """Decomposition of an estimated source image into four components 635 | representing respectively the true source image, spatial (or filtering) 636 | distortion, interference and artifacts, derived from the true source 637 | images using multichannel time-invariant filters. 638 | Adapted version to work with multichannel sources. 639 | Improved performance can be gained by passing Gj and G parameters initially 640 | as all zeros. These parameters store the results from the computation of 641 | the G matrix in _project_images and then return them for subsequent calls 642 | to this function. This only works when not computing permuations. 643 | """ 644 | nsampl = np.shape(estimated_source)[0] 645 | nchan = np.shape(estimated_source)[1] 646 | # are we saving the Gj and G parameters? 
647 | saveg = Gj is not None and G is not None 648 | # decomposition 649 | # true source image 650 | s_true = np.hstack((np.reshape(reference_sources[j], 651 | (nsampl, nchan), 652 | order="F").transpose(), 653 | np.zeros((nchan, flen - 1)))) 654 | # spatial (or filtering) distortion 655 | if saveg: 656 | e_spat, Gj = _project_images(reference_sources[j, np.newaxis, :], 657 | estimated_source, flen, Gj) 658 | else: 659 | e_spat = _project_images(reference_sources[j, np.newaxis, :], 660 | estimated_source, flen) 661 | e_spat = e_spat - s_true 662 | # interference 663 | if saveg: 664 | e_interf, G = _project_images(reference_sources, 665 | estimated_source, flen, G) 666 | else: 667 | e_interf = _project_images(reference_sources, 668 | estimated_source, flen) 669 | e_interf = e_interf - s_true - e_spat 670 | # artifacts 671 | e_artif = -s_true - e_spat - e_interf 672 | e_artif[:, :nsampl] += estimated_source.transpose() 673 | # return Gj and G only if they were passed in 674 | if saveg: 675 | return (s_true, e_spat, e_interf, e_artif, Gj, G) 676 | else: 677 | return (s_true, e_spat, e_interf, e_artif) 678 | 679 | 680 | def _project(reference_sources, estimated_source, flen): 681 | """Least-squares projection of estimated source on the subspace spanned by 682 | delayed versions of reference sources, with delays between 0 and flen-1 683 | """ 684 | nsrc = reference_sources.shape[0] 685 | nsampl = reference_sources.shape[1] 686 | 687 | # computing coefficients of least squares problem via FFT ## 688 | # zero padding and FFT of input data 689 | reference_sources = np.hstack((reference_sources, 690 | np.zeros((nsrc, flen - 1)))) 691 | estimated_source = np.hstack((estimated_source, np.zeros(flen - 1))) 692 | n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.))) 693 | sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1) 694 | sef = scipy.fftpack.fft(estimated_source, n=n_fft) 695 | # inner products between delayed versions of reference_sources 696 | G = np.zeros((nsrc 
* flen, nsrc * flen)) 697 | for i in range(nsrc): 698 | for j in range(nsrc): 699 | ssf = sf[i] * np.conj(sf[j]) 700 | ssf = np.real(scipy.fftpack.ifft(ssf)) 701 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), 702 | r=ssf[:flen]) 703 | G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss 704 | G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T 705 | # inner products between estimated_source and delayed versions of 706 | # reference_sources 707 | D = np.zeros(nsrc * flen) 708 | for i in range(nsrc): 709 | ssef = sf[i] * np.conj(sef) 710 | ssef = np.real(scipy.fftpack.ifft(ssef)) 711 | D[i * flen: (i+1) * flen] = np.hstack((ssef[0], ssef[-1:-flen:-1])) 712 | 713 | # Computing projection 714 | # Distortion filters 715 | try: 716 | C = np.linalg.solve(G, D).reshape(flen, nsrc, order='F') 717 | except np.linalg.linalg.LinAlgError: 718 | C = np.linalg.lstsq(G, D)[0].reshape(flen, nsrc, order='F') 719 | # Filtering 720 | sproj = np.zeros(nsampl + flen - 1) 721 | for i in range(nsrc): 722 | sproj += fftconvolve(C[:, i], reference_sources[i])[:nsampl + flen - 1] 723 | return sproj 724 | 725 | 726 | def _project_images(reference_sources, estimated_source, flen, G=None): 727 | """Least-squares projection of estimated source on the subspace spanned by 728 | delayed versions of reference sources, with delays between 0 and flen-1. 729 | Passing G as all zeros will populate the G matrix and return it so it can 730 | be passed into the next call to avoid recomputing G (this will only works 731 | if not computing permutations). 
732 | """ 733 | nsrc = reference_sources.shape[0] 734 | nsampl = reference_sources.shape[1] 735 | nchan = reference_sources.shape[2] 736 | reference_sources = np.reshape(np.transpose(reference_sources, (2, 0, 1)), 737 | (nchan*nsrc, nsampl), order='F') 738 | 739 | # computing coefficients of least squares problem via FFT ## 740 | # zero padding and FFT of input data 741 | reference_sources = np.hstack((reference_sources, 742 | np.zeros((nchan*nsrc, flen - 1)))) 743 | estimated_source = \ 744 | np.hstack((estimated_source.transpose(), np.zeros((nchan, flen - 1)))) 745 | n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.))) 746 | sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1) 747 | sef = scipy.fftpack.fft(estimated_source, n=n_fft) 748 | 749 | # inner products between delayed versions of reference_sources 750 | if G is None: 751 | saveg = False 752 | G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen)) 753 | for i in range(nchan * nsrc): 754 | for j in range(i+1): 755 | ssf = sf[i] * np.conj(sf[j]) 756 | ssf = np.real(scipy.fftpack.ifft(ssf)) 757 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), 758 | r=ssf[:flen]) 759 | G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss 760 | G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T 761 | else: # avoid recomputing G (only works if no permutation is desired) 762 | saveg = True # return G 763 | if np.all(G == 0): # only compute G if passed as 0 764 | G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen)) 765 | for i in range(nchan * nsrc): 766 | for j in range(i+1): 767 | ssf = sf[i] * np.conj(sf[j]) 768 | ssf = np.real(scipy.fftpack.ifft(ssf)) 769 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), 770 | r=ssf[:flen]) 771 | G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss 772 | G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T 773 | 774 | # inner products between estimated_source and delayed versions of 775 | # reference_sources 776 | D = np.zeros((nchan * nsrc * flen, nchan)) 
777 | for k in range(nchan * nsrc): 778 | for i in range(nchan): 779 | ssef = sf[k] * np.conj(sef[i]) 780 | ssef = np.real(scipy.fftpack.ifft(ssef)) 781 | D[k * flen: (k+1) * flen, i] = \ 782 | np.hstack((ssef[0], ssef[-1:-flen:-1])).transpose() 783 | 784 | # Computing projection 785 | # Distortion filters 786 | try: 787 | C = np.linalg.solve(G, D).reshape(flen, nchan*nsrc, nchan, order='F') 788 | except np.linalg.linalg.LinAlgError: 789 | C = np.linalg.lstsq(G, D)[0].reshape(flen, nchan*nsrc, nchan, 790 | order='F') 791 | # Filtering 792 | sproj = np.zeros((nchan, nsampl + flen - 1)) 793 | for k in range(nchan * nsrc): 794 | for i in range(nchan): 795 | sproj[i] += fftconvolve(C[:, k, i].transpose(), 796 | reference_sources[k])[:nsampl + flen - 1] 797 | # return G only if it was passed in 798 | if saveg: 799 | return sproj, G 800 | else: 801 | return sproj 802 | 803 | 804 | def _bss_source_crit(s_true, e_spat, e_interf, e_artif): 805 | """Measurement of the separation quality for a given source in terms of 806 | filtered true source, interference and artifacts. 807 | """ 808 | # energy ratios 809 | s_filt = s_true + e_spat 810 | sdr = _safe_db(np.sum(s_filt**2), np.sum((e_interf + e_artif)**2)) 811 | sir = _safe_db(np.sum(s_filt**2), np.sum(e_interf**2)) 812 | sar = _safe_db(np.sum((s_filt + e_interf)**2), np.sum(e_artif**2)) 813 | return (sdr, sir, sar) 814 | 815 | 816 | def _bss_image_crit(s_true, e_spat, e_interf, e_artif): 817 | """Measurement of the separation quality for a given image in terms of 818 | filtered true source, spatial error, interference and artifacts. 
819 | """ 820 | # energy ratios 821 | sdr = _safe_db(np.sum(s_true**2), np.sum((e_spat+e_interf+e_artif)**2)) 822 | isr = _safe_db(np.sum(s_true**2), np.sum(e_spat**2)) 823 | sir = _safe_db(np.sum((s_true+e_spat)**2), np.sum(e_interf**2)) 824 | sar = _safe_db(np.sum((s_true+e_spat+e_interf)**2), np.sum(e_artif**2)) 825 | return (sdr, isr, sir, sar) 826 | 827 | 828 | def _safe_db(num, den): 829 | """Properly handle the potential +Inf db SIR, instead of raising a 830 | RuntimeWarning. Only denominator is checked because the numerator can never 831 | be 0. 832 | """ 833 | if den == 0: 834 | return np.Inf 835 | return 10 * np.log10(num / den) 836 | 837 | 838 | def evaluate(reference_sources, estimated_sources, **kwargs): 839 | """Compute all metrics for the given reference and estimated signals. 840 | 841 | NOTE: This will always compute :func:`mir_eval.separation.bss_eval_images` 842 | for any valid input and will additionally compute 843 | :func:`mir_eval.separation.bss_eval_sources` for valid input with fewer 844 | than 3 dimensions. 845 | 846 | Examples 847 | -------- 848 | >>> # reference_sources[n] should be an ndarray of samples of the 849 | >>> # n'th reference source 850 | >>> # estimated_sources[n] should be the same for the n'th estimated source 851 | >>> scores = mir_eval.separation.evaluate(reference_sources, 852 | ... estimated_sources) 853 | 854 | Parameters 855 | ---------- 856 | reference_sources : np.ndarray, shape=(nsrc, nsampl[, nchan]) 857 | matrix containing true sources 858 | estimated_sources : np.ndarray, shape=(nsrc, nsampl[, nchan]) 859 | matrix containing estimated sources 860 | kwargs 861 | Additional keyword arguments which will be passed to the 862 | appropriate metric or preprocessing functions. 863 | 864 | Returns 865 | ------- 866 | scores : dict 867 | Dictionary of scores, where the key is the metric name (str) and 868 | the value is the (float) score achieved. 
869 | 870 | """ 871 | # Compute all the metrics 872 | scores = collections.OrderedDict() 873 | 874 | sdr, isr, sir, sar, perm = util.filter_kwargs( 875 | bss_eval_images, 876 | reference_sources, 877 | estimated_sources, 878 | **kwargs 879 | ) 880 | scores['Images - Source to Distortion'] = sdr.tolist() 881 | scores['Images - Image to Spatial'] = isr.tolist() 882 | scores['Images - Source to Interference'] = sir.tolist() 883 | scores['Images - Source to Artifact'] = sar.tolist() 884 | scores['Images - Source permutation'] = perm.tolist() 885 | 886 | sdr, isr, sir, sar, perm = util.filter_kwargs( 887 | bss_eval_images_framewise, 888 | reference_sources, 889 | estimated_sources, 890 | **kwargs 891 | ) 892 | scores['Images Frames - Source to Distortion'] = sdr.tolist() 893 | scores['Images Frames - Image to Spatial'] = isr.tolist() 894 | scores['Images Frames - Source to Interference'] = sir.tolist() 895 | scores['Images Frames - Source to Artifact'] = sar.tolist() 896 | scores['Images Frames - Source permutation'] = perm.tolist() 897 | 898 | # Verify we can compute sources on this input 899 | if reference_sources.ndim < 3 and estimated_sources.ndim < 3: 900 | sdr, sir, sar, perm = util.filter_kwargs( 901 | bss_eval_sources_framewise, 902 | reference_sources, 903 | estimated_sources, 904 | **kwargs 905 | ) 906 | scores['Sources Frames - Source to Distortion'] = sdr.tolist() 907 | scores['Sources Frames - Source to Interference'] = sir.tolist() 908 | scores['Sources Frames - Source to Artifact'] = sar.tolist() 909 | scores['Sources Frames - Source permutation'] = perm.tolist() 910 | 911 | sdr, sir, sar, perm = util.filter_kwargs( 912 | bss_eval_sources, 913 | reference_sources, 914 | estimated_sources, 915 | **kwargs 916 | ) 917 | scores['Sources - Source to Distortion'] = sdr.tolist() 918 | scores['Sources - Source to Interference'] = sir.tolist() 919 | scores['Sources - Source to Artifact'] = sar.tolist() 920 | scores['Sources - Source permutation'] = perm.tolist() 
921 | 922 | return scores 923 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Mar 2 16:24:50 2022 4 | 5 | @author: Xiaohuai Le 6 | """ 7 | 8 | import tensorflow as tf 9 | 10 | class Loss(): 11 | 12 | def __init__(self,): 13 | pass 14 | 15 | @staticmethod 16 | def snr_cost(s_estimate, s_true): 17 | ''' 18 | SNR cost 19 | ''' 20 | # calculating the SNR 21 | snr = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True) / \ 22 | (tf.reduce_mean(tf.math.square(s_true-s_estimate), axis=-1, keepdims=True)+1e-7) 23 | num = tf.math.log(snr + 1e-7) 24 | denom = tf.math.log(tf.constant(10, dtype=num.dtype)) 25 | loss = -10*(num / (denom)) 26 | return loss 27 | 28 | @staticmethod 29 | def sisnr_cost(s_hat, s): 30 | ''' 31 | SISNR cost 32 | ''' 33 | def norm(x): 34 | return tf.reduce_sum(x**2, axis=-1, keepdims=True) 35 | s_target = tf.reduce_sum( 36 | s_hat * s, axis=-1, keepdims=True) * s / norm(s) 37 | upp = norm(s_target) 38 | low = norm(s_hat - s_target) 39 | return -10 * tf.log(upp /low) / tf.log(10.0) 40 | 41 | @staticmethod 42 | def skip_regular_MAE(update_gate, miu = 0.5): 43 | ''' 44 | MAE-based regularization 45 | ''' 46 | return tf.abs(tf.reduce_mean(update_gate) - miu) 47 | 48 | @staticmethod 49 | def skip_regular_MSE(update_gate, miu = 0.5): 50 | ''' 51 | MSE-based regularization 52 | ''' 53 | return (tf.reduce_mean(update_gate) - miu) ** 2 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Mar 2 15:52:28 2022 4 | 5 | @author: Xiaohuai Le 6 | """ 7 | import os 8 | import tensorflow as tf 9 | from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger, EarlyStopping, ModelCheckpoint 10 | import yaml 
from DPCRN_base import DPCRN_model
from DPCRN_skip import DPCRN_skip_model
from data_loader import data_generator


class Trainer():
    '''
    Build a DPCRN model from a YAML configuration, then either train it
    (mode 'train') or enhance a folder of noisy files (mode 'test').

    All work is done in the constructor.
    '''

    def __init__(self, args):
        '''
        args: parsed command-line namespace (see the argparse block below).

        Raises ValueError for an unsupported mode or model name.
        '''
        print(args)
        self.mode = args.mode
        if self.mode == 'train':
            self.batch_size = args.bs
        elif self.mode == 'test':
            self.batch_size = 1
        else:
            # fail early: without a batch size the model cannot be built
            raise ValueError("mode must be 'train' or 'test', got %r" % (self.mode,))
        self.lr = args.lr
        self.second = args.second
        self.ckpt = args.ckpt
        self.test_dir = args.test_dir
        self.output_dir = args.output_dir
        self.experiment_name = args.experiment_name
        self.config_dict = self.read_yaml(args.config)
        self.max_epochs = self.config_dict['trainer']['max_epochs']

        # config 'name' -> model class
        model_classes = {'DPCRN-base': DPCRN_model, 'DPCRN-skip': DPCRN_skip_model}
        model_name = self.config_dict['name']
        if model_name not in model_classes:
            raise ValueError('unsupported model name %r in %s' % (model_name, args.config))
        self.dpcrn_model = model_classes[model_name](batch_size = self.batch_size, length_in_s = self.second, lr = self.lr, config = self.config_dict)
        self.dpcrn_model.build_DPCRN_model()

        if self.mode == 'train':
            database = self.config_dict['database']
            stft = self.config_dict['stft']
            self.data_generator = data_generator(DNS_dir = database['DNS_path'],
                                                 WSJ_dir = database['WSJ_path'],
                                                 RIR_dir = database['RIRs_path'],
                                                 temp_data_dir = database['data_path'],
                                                 length_per_sample = self.second,
                                                 SNR_range = database['SNR'],
                                                 fs = stft['fs'],
                                                 n_fft = stft['N_FFT'],
                                                 n_hop = stft['block_shift'],
                                                 batch_size = self.batch_size,
                                                 sd = self.config_dict['trainer']['seed'],
                                                 add_reverb = True,
                                                 reverb_rate = database['reverb_rate'],
                                                 spec_aug_rate = database['spec_aug_rate'])

            self.train_model(runName = args.experiment_name, data_generator = self.data_generator)

        else:  # 'test'
            if self.ckpt:
                self.dpcrn_model.model_inference.load_weights(self.ckpt)
            if model_name == 'DPCRN-skip':
                # the skip model additionally takes the update-rate scale gamma
                self.dpcrn_model.test_on_dataset(args.test_dir, args.output_dir, args.gamma)
            else:
                self.dpcrn_model.test_on_dataset(args.test_dir, args.output_dir)

    def read_yaml(self, file):
        '''
        Read the configuration file, echo it to stdout and return it parsed as a dict.
        '''
        with open(file, 'r', encoding='utf-8') as f:
            result = f.read()
        print(result)
        # NOTE(review): yaml.Loader can construct arbitrary Python objects, so only
        # load trusted config files. yaml.safe_load would be safer if the configs
        # are plain YAML — confirm before switching.
        config_dict = yaml.load(result, Loader = yaml.Loader)
        return config_dict

    def train_model(self, runName, data_generator):
        '''
        Compile and fit the model on `data_generator`, writing a CSV log and
        per-epoch weight checkpoints to ./models_<runName>/.
        '''
        self.dpcrn_model.compile_model()

        # create save path if not existent (exist_ok avoids a check-then-create race)
        savePath = './models_'+ runName+'/'
        os.makedirs(savePath, exist_ok=True)
        # create log file writer
        csv_logger = CSVLogger(savePath+ 'training_' +runName+ '.log')
        # halve the learning rate after 10 epochs without val_loss improvement
        reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                      patience=10, min_lr=10**(-10), cooldown=1)
        # stop after 20 epochs without val_loss improvement
        early_stopping = EarlyStopping(monitor='val_loss', min_delta=0,
                                       patience=20, mode='auto', baseline=None)
        # save the weights after every epoch (not only the best one)
        checkpointer = ModelCheckpoint(savePath+runName+'model_{epoch:02d}_{val_loss:02f}_{val_sisnr_metrics:02f}.h5',
                                       monitor='val_loss',
                                       save_best_only=False,
                                       save_weights_only=True,
                                       mode='auto',
                                       save_freq='epoch'
                                       )

        # NOTE(review): fit_generator is deprecated in TF 2.x (model.fit accepts
        # generators there); it is kept because the TF version targeted by this
        # repo is not pinned. validation_steps = batch_size looks unintended — confirm.
        self.dpcrn_model.model.fit_generator(data_generator.generator(batch_size = self.batch_size,validation = False),
                                             validation_data = data_generator.generator(batch_size =self.batch_size,validation = True),
                                             epochs = self.max_epochs,
                                             steps_per_epoch = data_generator.train_length//self.batch_size,
                                             validation_steps = self.batch_size,
                                             #use_multiprocessing=True,
                                             callbacks=[checkpointer, reduce_lr, csv_logger, early_stopping])
        # clear out garbage
        tf.keras.backend.clear_session()

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='manual to this script')
    parser.add_argument("--config", type = str, default = './configuration/DPCRN-base.yaml', help = 'the configuration files')
    parser.add_argument("--cuda", type = int, default = 0, help = 'which GPU to use')
    parser.add_argument("--mode", type = str, default = 'test', help = 'train or test')
    parser.add_argument("--bs", type = int, default = 16, help = 'batch size')
    parser.add_argument("--lr", type = float, default = 1e-3, help = 'learning rate')
    parser.add_argument("--experiment_name", type = str, default = 'experiment_1', help = 'the experiment name')
    parser.add_argument("--second", type = int, default = 8, help = 'length in second of every sample')
    parser.add_argument("--ckpt", type=str, default = './pretrained_weights/DPCRN_base/models_experiment_new_base_nomap_phasenloss_retrain_WSJmodel_84_0.022068.h5', help = 'the location of the weights')
    parser.add_argument("--test_dir", type=str, default = './test_audio/noisy', help = 'the floder of noisy speech')
    parser.add_argument("--output_dir", type=str, default = './test_audio/enhanced', help = 'the floder of enhanced speech')
    parser.add_argument("--gamma", type=float, default = 1, help = 'the scaling factor of the state update rate')

    args = parser.parse_args()

    # must be set before TensorFlow initialises the GPUs
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)

    trainer = Trainer(args)

# --------------------------------------------------------------------------------
# /networks/modules.py:
# --------------------------------------------------------------------------------
# -*-
coding: utf-8 -*- 2 | """ 3 | Created on Tue Jan 12 14:53:48 2021 4 | 5 | @author: Xiaohuai Le 6 | """ 7 | import tensorflow.keras as keras 8 | from tensorflow.keras.layers import Layer, Conv2D, Conv2DTranspose, Add, RNN 9 | import tensorflow as tf 10 | from tensorflow.keras import backend as K 11 | 12 | from tensorflow.python.keras import activations 13 | from tensorflow.python.keras import constraints 14 | from tensorflow.python.keras import initializers 15 | from tensorflow.python.keras import regularizers 16 | import logging 17 | 18 | from networks.skip_gru import SkipGRU 19 | 20 | #%% 21 | ''' 22 | dual path rnn block 23 | ''' 24 | class DprnnBlock(keras.layers.Layer): 25 | 26 | def __init__(self, intra_hidden, inter_hidden, batch_size, L, width, channel, causal = False, CUDNN = False, **kwargs): 27 | super(DprnnBlock, self).__init__(**kwargs) 28 | ''' 29 | intra_hidden hidden size of the intra-chunk RNN 30 | inter_hidden hidden size of the inter-chunk RNN 31 | batch_size 32 | L number of frames, -1 for undefined length 33 | width width size output from encoder 34 | channel channel size output from encoder 35 | causal instant Layer Norm or global Layer Norm 36 | ''' 37 | self.batch_size = batch_size 38 | self.causal = causal 39 | self.L = L 40 | self.width = width 41 | self.channel = channel 42 | 43 | if CUDNN: 44 | self.intra_rnn = keras.layers.Bidirectional(keras.layers.CuDNNGRU(units=intra_hidden//2, return_sequences=True)) 45 | else: 46 | self.intra_rnn = keras.layers.Bidirectional(keras.layers.GRU(units=intra_hidden//2, return_sequences=True,implementation = 1,recurrent_activation = 'sigmoid', unroll = True,reset_after = False)) 47 | 48 | self.intra_fc = keras.layers.Dense(units = self.channel,) 49 | 50 | self.intra_ln = keras.layers.LayerNormalization(center=True, scale=True,epsilon = 1e-8) 51 | 52 | if CUDNN: 53 | self.inter_rnn = keras.layers.CuDNNGRU(units=inter_hidden, return_sequences=True) 54 | else: 55 | self.inter_rnn = 
keras.layers.GRU(units=inter_hidden, return_sequences=True,implementation = 1,recurrent_activation = 'sigmoid',reset_after = False) 56 | 57 | self.inter_fc = keras.layers.Dense(units = self.channel,) 58 | 59 | self.inter_ln = keras.layers.LayerNormalization(center=True, scale=True,epsilon = 1e-8) 60 | 61 | def call(self, x): 62 | # Intra-Chunk Processing 63 | batch_size = self.batch_size 64 | L = self.L 65 | width = self.width 66 | 67 | intra_rnn = self.intra_rnn 68 | intra_fc = self.intra_fc 69 | intra_ln = self.intra_ln 70 | inter_rnn = self.inter_rnn 71 | inter_fc = self.inter_fc 72 | inter_ln = self.inter_ln 73 | channel = self.channel 74 | causal = self.causal 75 | # input shape (bs,T,F,C) --> (bs*T,F,C) 76 | intra_GRU_input = tf.reshape(x,[-1,width,channel]) 77 | # (bs*T,F,C) 78 | intra_GRU_out = intra_rnn(intra_GRU_input) 79 | 80 | # (bs*T,F,C) channel axis dense 81 | intra_dense_out = intra_fc(intra_GRU_out) 82 | 83 | if causal: 84 | # (bs*T,F,C) --> (bs,T,F,C) Freq and channel norm 85 | intra_ln_input = tf.reshape(intra_dense_out,[batch_size,-1,width,channel]) 86 | intra_out = intra_ln(intra_ln_input) 87 | else: 88 | # (bs*T,F,C) --> (bs,T*F*C) global norm 89 | intra_ln_input = tf.reshape(intra_dense_out,[batch_size,-1]) 90 | intra_ln_out = intra_ln(intra_ln_input) 91 | intra_out = tf.reshape(intra_ln_out,[batch_size,L,width,channel]) 92 | 93 | # (bs,T,F,C) 94 | intra_out = Add()([x,intra_out]) 95 | #%% 96 | # (bs,T,F,C) --> (bs,F,T,C) 97 | inter_GRU_input = tf.transpose(intra_out,[0,2,1,3]) 98 | # (bs,F,T,C) --> (bs*F,T,C) 99 | inter_GRU_input = tf.reshape(inter_GRU_input,[batch_size*width,L,channel]) 100 | 101 | inter_GRU_out = inter_rnn(inter_GRU_input) 102 | 103 | # (bs,F,T,C) Channel axis dense 104 | inter_dense_out = inter_fc(inter_GRU_out) 105 | 106 | inter_dense_out = tf.reshape(inter_dense_out,[batch_size,width,L,channel]) 107 | 108 | if causal: 109 | # (bs,F,T,C) --> (bs,T,F,C) 110 | inter_ln_input = tf.transpose(inter_dense_out,[0,2,1,3]) 111 | 
inter_out = inter_ln(inter_ln_input) 112 | else: 113 | # (bs,F,T,C) --> (bs,F*T*C) 114 | inter_ln_input = tf.reshape(inter_dense_out,[batch_size,-1]) 115 | inter_ln_out = inter_ln(inter_ln_input) 116 | inter_out = tf.reshape(inter_ln_out,[batch_size,width,L,channel]) 117 | # (bs,F,T,C) --> (bs,T,F,C) 118 | inter_out = tf.transpose(inter_out,[0,2,1,3]) 119 | 120 | inter_out = Add()([intra_out,inter_out]) 121 | 122 | return inter_out 123 | 124 | class DprnnBlock_skip(keras.layers.Layer): 125 | 126 | def __init__(self, intra_hidden, inter_hidden, batch_size, L, width, channel, skip = 0, **kwargs): 127 | super(DprnnBlock_skip, self).__init__(**kwargs) 128 | ''' 129 | skip: 0 inter-skip, 1 intra-skip, 2 all-skip 130 | ''' 131 | self.batch_size = batch_size 132 | self.L = L 133 | self.width = width 134 | self.channel = channel 135 | self.skip = skip 136 | 137 | if skip == 0: 138 | self.intra_rnn = keras.layers.Bidirectional(keras.layers.GRU(units=intra_hidden//2, return_sequences=True,implementation = 1,recurrent_activation = 'sigmoid', unroll = True,reset_after = False)) 139 | self.intra_skip = 0 140 | elif skip == 1 or skip == 2: 141 | self.intra_rnn = keras.layers.Bidirectional(SkipGRU(units=intra_hidden//2, return_sequences=True,return_state = True, implementation = 1,recurrent_activation = 'sigmoid',reset_after = False)) 142 | self.intra_skip = 1 143 | else: 144 | raise ValueError('the value of skip mode only support 0, 1, 2!') 145 | 146 | self.intra_fc = keras.layers.Dense(units = self.channel) 147 | 148 | self.intra_ln = keras.layers.LayerNormalization(center=True, scale=True, epsilon = 1e-8) 149 | 150 | if skip == 1: 151 | self.inter_rnn = keras.layers.GRU(units=inter_hidden, return_sequences=True,implementation = 1,recurrent_activation = 'sigmoid',reset_after = False) 152 | self.inter_skip = 0 153 | elif skip == 0 or skip == 2: 154 | self.inter_rnn = SkipGRU(units=inter_hidden, return_sequences=True,implementation = 1,recurrent_activation = 'sigmoid',reset_after 
= False) 155 | self.inter_skip = 1 156 | else: 157 | raise ValueError('the value of skip mode only support 0, 1, 2!') 158 | 159 | self.inter_fc = keras.layers.Dense(units = self.channel) 160 | 161 | self.inter_ln = keras.layers.LayerNormalization(center=True, scale=True, epsilon = 1e-8) 162 | 163 | def call(self, x, scale): 164 | # Intra-Chunk Processing 165 | batch_size = self.batch_size 166 | L = self.L 167 | width = self.width 168 | 169 | intra_rnn = self.intra_rnn 170 | intra_fc = self.intra_fc 171 | intra_ln = self.intra_ln 172 | inter_rnn = self.inter_rnn 173 | inter_fc = self.inter_fc 174 | inter_ln = self.inter_ln 175 | channel = self.channel 176 | # input shape (bs,T,F,C) --> (bs*T,F,C) 177 | intra_LSTM_input = tf.reshape(x,[-1,width,channel]) 178 | # (bs*T,F,C) 179 | if self.intra_skip: 180 | # get the output of intra-chunk Skip-RNN 181 | scale1 = tf.reshape(scale,[-1,width,1]) 182 | intra_LSTM_input = tf.concat([intra_LSTM_input,scale1],axis = -1) 183 | [intra_LSTM_out, gate_forward,_,_,_,gate_backward,_,_,_] = intra_rnn(intra_LSTM_input) 184 | # we concatenate the output of two sub-RNNs in each direction 185 | update_gate_intra = tf.transpose(tf.concat([gate_forward[:,:,0],gate_backward[:,:,0]],axis = -1),[1,0]) 186 | else: 187 | intra_LSTM_out = intra_rnn(intra_LSTM_input) 188 | update_gate_intra = tf.ones([64,tf.shape(x)[1]]) 189 | # (bs*T,F,C) channel axis dense 190 | intra_dense_out = intra_fc(intra_LSTM_out) 191 | 192 | # (bs*T,F,C) --> (bs,T,F,C) Freq and channel norm 193 | intra_ln_input = tf.reshape(intra_dense_out,[batch_size,-1,width,channel]) 194 | intra_out = intra_ln(intra_ln_input) 195 | 196 | # (bs,T,F,C) 197 | intra_out = Add()([x,intra_out]) 198 | #%% 199 | # (bs,T,F,C) --> (bs,F,T,C) 200 | inter_LSTM_input = tf.transpose(intra_out,[0,2,1,3]) 201 | # (bs,F,T,C) --> (bs*F,T,C) 202 | inter_LSTM_input = tf.reshape(inter_LSTM_input,[batch_size*width,L,channel]) 203 | 204 | if self.inter_skip: 205 | #get the output of inter-chunk Skip-RNN 
206 | scale2 = tf.reshape(tf.transpose(scale,[0,2,1,3]),[batch_size*width,L,1]) 207 | inter_LSTM_input = tf.concat([inter_LSTM_input,scale2],axis = -1) 208 | inter_LSTM_out, update_gate_inter = inter_rnn(inter_LSTM_input) 209 | update_gate_inter = update_gate_inter[:,:,0] 210 | else: 211 | inter_LSTM_out = inter_rnn(inter_LSTM_input) 212 | update_gate_inter = tf.ones([32,tf.shape(x)[1]]) 213 | # (bs,F,T,C) Channel axis dense 214 | inter_dense_out = inter_fc(inter_LSTM_out) 215 | 216 | inter_dense_out = tf.reshape(inter_dense_out,[batch_size,width,L,channel]) 217 | 218 | # (bs,F,T,C) --> (bs,T,F,C) 219 | inter_ln_input = tf.transpose(inter_dense_out,[0,2,1,3]) 220 | inter_out = inter_ln(inter_ln_input) 221 | # (bs,T,F,C) 222 | inter_out = Add()([intra_out,inter_out]) 223 | 224 | return inter_out,update_gate_intra,update_gate_inter 225 | -------------------------------------------------------------------------------- /networks/pruning_gru.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Dec 26 16:34:03 2021 4 | 5 | @author: xiaohuai le 6 | """ 7 | 8 | from tensorflow.keras.layers import Layer,RNN,GRUCell,Dense 9 | import tensorflow as tf 10 | from tensorflow.python.framework import dtypes 11 | from tensorflow.python.framework import tensor_shape 12 | from tensorflow.python.ops import math_ops 13 | from tensorflow.python.ops import gen_math_ops 14 | from tensorflow.python.ops import nn 15 | from tensorflow.python.ops import standard_ops 16 | from tensorflow.python.eager import context 17 | from tensorflow.python.keras import activations 18 | from tensorflow.python.keras import backend as K 19 | from tensorflow.python.keras import constraints 20 | from tensorflow.python.keras import initializers 21 | from tensorflow.python.keras import regularizers 22 | from tensorflow.python.keras.engine.input_spec import InputSpec 23 | 24 | ''' 25 | A weight mask is added to the FC layer and 
the RNN layer for pruning. 26 | We override Dense and GRU of TensorFlow 27 | ''' 28 | class Dense_mask(Dense): 29 | 30 | def __init__(self, 31 | units, 32 | activation=None, 33 | use_bias=True, 34 | kernel_initializer='glorot_uniform', 35 | bias_initializer='zeros', 36 | kernel_regularizer=None, 37 | bias_regularizer=None, 38 | activity_regularizer=None, 39 | kernel_constraint=None, 40 | bias_constraint=None, 41 | **kwargs): 42 | 43 | super(Dense_mask, self).__init__(units, 44 | activation, 45 | use_bias, 46 | kernel_initializer, 47 | bias_initializer, 48 | kernel_regularizer, 49 | bias_regularizer, 50 | activity_regularizer, 51 | kernel_constraint, 52 | bias_constraint, 53 | **kwargs) 54 | 55 | def build(self, input_shape): 56 | 57 | dtype = dtypes.as_dtype(self.dtype or K.floatx()) 58 | if not (dtype.is_floating or dtype.is_complex): 59 | raise TypeError('Unable to build `Dense` layer with non-floating point ' 60 | 'dtype %s' % (dtype,)) 61 | input_shape = tensor_shape.TensorShape(input_shape) 62 | if tensor_shape.dimension_value(input_shape[-1]) is None: 63 | raise ValueError('The last dimension of the inputs to `Dense` ' 64 | 'should be defined. 
Found `None`.') 65 | last_dim = tensor_shape.dimension_value(input_shape[-1]) 66 | self.input_spec = InputSpec(min_ndim=2, 67 | axes={-1: last_dim}) 68 | self.kernel = self.add_weight( 69 | 'kernel', 70 | shape=[last_dim, self.units], 71 | initializer=self.kernel_initializer, 72 | regularizer=self.kernel_regularizer, 73 | constraint=self.kernel_constraint, 74 | dtype=self.dtype, 75 | trainable=True) 76 | 77 | self.mask = self.add_weight( 78 | shape=[last_dim, 1], 79 | name='kernel_mask', 80 | initializer='ones', 81 | trainable = False) 82 | 83 | self.masked_kernel = self.kernel * self.mask 84 | 85 | if self.use_bias: 86 | self.bias = self.add_weight( 87 | 'bias', 88 | shape=[self.units,], 89 | initializer=self.bias_initializer, 90 | regularizer=self.bias_regularizer, 91 | constraint=self.bias_constraint, 92 | dtype=self.dtype, 93 | trainable=True) 94 | else: 95 | self.bias = None 96 | self.built = True 97 | 98 | def call(self, inputs): 99 | rank = len(inputs.shape) 100 | if rank > 2: 101 | # Broadcasting is required for the inputs. 102 | outputs = standard_ops.tensordot(inputs, self.masked_kernel, [[rank - 1], [0]]) 103 | # Reshape the output back to the original ndim of the input. 
104 | if not context.executing_eagerly(): 105 | shape = inputs.shape.as_list() 106 | output_shape = shape[:-1] + [self.units] 107 | outputs.set_shape(output_shape) 108 | else: 109 | inputs = math_ops.cast(inputs, self._compute_dtype) 110 | if K.is_sparse(inputs): 111 | outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, self.masked_kernel) 112 | else: 113 | outputs = gen_math_ops.mat_mul(inputs, self.masked_kernel) 114 | if self.use_bias: 115 | outputs = nn.bias_add(outputs, self.bias) 116 | if self.activation is not None: 117 | return self.activation(outputs) # pylint: disable=not-callable 118 | return outputs 119 | 120 | class GRUCell_mask(GRUCell): 121 | 122 | def __init__(self, 123 | units, 124 | activation='tanh', 125 | recurrent_activation='hard_sigmoid', 126 | use_bias=True, 127 | kernel_initializer='glorot_uniform', 128 | recurrent_initializer='orthogonal', 129 | bias_initializer='zeros', 130 | kernel_regularizer=None, 131 | recurrent_regularizer=None, 132 | bias_regularizer=None, 133 | kernel_constraint=None, 134 | recurrent_constraint=None, 135 | bias_constraint=None, 136 | dropout=0., 137 | recurrent_dropout=0., 138 | implementation=1, 139 | reset_after=False, 140 | **kwargs): 141 | super(GRUCell_mask, self).__init__(units, 142 | activation, 143 | recurrent_activation, 144 | use_bias, 145 | kernel_initializer, 146 | recurrent_initializer, 147 | bias_initializer, 148 | kernel_regularizer, 149 | recurrent_regularizer, 150 | bias_regularizer, 151 | kernel_constraint, 152 | recurrent_constraint, 153 | bias_constraint, 154 | dropout, 155 | recurrent_dropout, 156 | implementation, 157 | reset_after, 158 | **kwargs) 159 | 160 | def build(self, input_shape): 161 | input_dim = input_shape[-1] 162 | 163 | self.kernel = self.add_weight( 164 | shape=(input_dim, self.units * 3), 165 | name='kernel', 166 | initializer=self.kernel_initializer, 167 | regularizer=self.kernel_regularizer, 168 | constraint=self.kernel_constraint) 169 | 170 | self.recurrent_kernel = 
self.add_weight( 171 | shape=(self.units, self.units * 3), 172 | name='recurrent_kernel', 173 | initializer=self.recurrent_initializer, 174 | regularizer=self.recurrent_regularizer, 175 | constraint=self.recurrent_constraint) 176 | 177 | self.kernel_mask = self.add_weight( 178 | shape=(1, self.units * 3), 179 | name='kernel_mask', 180 | initializer='ones', 181 | trainable = False) 182 | 183 | self.recurrent_mask_column = self.add_weight( 184 | shape=(1, self.units * 3), 185 | name='recurrent_mask_column', 186 | initializer='ones', 187 | trainable = False) 188 | 189 | self.recurrent_mask_row = self.add_weight( 190 | shape=(self.units, 1), 191 | name='recurrent_mask_row', 192 | initializer='ones', 193 | trainable = False) 194 | 195 | self.masked_kernel = self.kernel * self.kernel_mask 196 | self.masked_recurrent_kernel = self.recurrent_kernel * self.recurrent_mask_column * self.recurrent_mask_row 197 | 198 | if self.use_bias: 199 | if not self.reset_after: 200 | bias_shape = (3 * self.units,) 201 | else: 202 | # separate biases for input and recurrent kernels 203 | # Note: the shape is intentionally different from CuDNNGRU biases 204 | # `(2 * 3 * self.units,)`, so that we can distinguish the classes 205 | # when loading and converting saved weights. 
206 | bias_shape = (2, 3 * self.units) 207 | self.bias = self.add_weight(shape=bias_shape, 208 | name='bias', 209 | initializer=self.bias_initializer, 210 | regularizer=self.bias_regularizer, 211 | constraint=self.bias_constraint) 212 | 213 | else: 214 | self.bias = None 215 | self.built = True 216 | 217 | def call(self, inputs, states, training=None): 218 | h_tm1 = states[0] # previous memory 219 | 220 | dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3) 221 | rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 222 | h_tm1, training, count=3) 223 | 224 | if self.use_bias: 225 | if not self.reset_after: 226 | input_bias, recurrent_bias = self.bias, None 227 | else: 228 | input_bias, recurrent_bias = array_ops.unstack(self.bias) 229 | 230 | if self.implementation == 1: 231 | if 0. < self.dropout < 1.: 232 | inputs_z = inputs * dp_mask[0] 233 | inputs_r = inputs * dp_mask[1] 234 | inputs_h = inputs * dp_mask[2] 235 | else: 236 | inputs_z = inputs 237 | inputs_r = inputs 238 | inputs_h = inputs 239 | 240 | x_z = K.dot(inputs_z, self.masked_kernel[:, :self.units]) 241 | x_r = K.dot(inputs_r, self.masked_kernel[:, self.units:self.units * 2]) 242 | x_h = K.dot(inputs_h, self.masked_kernel[:, self.units * 2:]) 243 | 244 | if self.use_bias: 245 | x_z = K.bias_add(x_z, input_bias[:self.units]) 246 | x_r = K.bias_add(x_r, input_bias[self.units: self.units * 2]) 247 | x_h = K.bias_add(x_h, input_bias[self.units * 2:]) 248 | 249 | if 0. 
< self.recurrent_dropout < 1.: 250 | h_tm1_z = h_tm1 * rec_dp_mask[0] 251 | h_tm1_r = h_tm1 * rec_dp_mask[1] 252 | h_tm1_h = h_tm1 * rec_dp_mask[2] 253 | else: 254 | h_tm1_z = h_tm1 255 | h_tm1_r = h_tm1 256 | h_tm1_h = h_tm1 257 | 258 | recurrent_z = K.dot(h_tm1_z, self.masked_recurrent_kernel[:, :self.units]) 259 | recurrent_r = K.dot(h_tm1_r, 260 | self.masked_recurrent_kernel[:, self.units:self.units * 2]) 261 | if self.reset_after and self.use_bias: 262 | recurrent_z = K.bias_add(recurrent_z, recurrent_bias[:self.units]) 263 | recurrent_r = K.bias_add(recurrent_r, 264 | recurrent_bias[self.units:self.units * 2]) 265 | 266 | z = self.recurrent_activation(x_z + recurrent_z) 267 | r = self.recurrent_activation(x_r + recurrent_r) 268 | 269 | # reset gate applied after/before matrix multiplication 270 | if self.reset_after: 271 | recurrent_h = K.dot(h_tm1_h, self.masked_recurrent_kernel[:, self.units * 2:]) 272 | if self.use_bias: 273 | recurrent_h = K.bias_add(recurrent_h, recurrent_bias[self.units * 2:]) 274 | recurrent_h = r * recurrent_h 275 | else: 276 | recurrent_h = K.dot(r * h_tm1_h, 277 | self.masked_recurrent_kernel[:, self.units * 2:]) 278 | 279 | hh = self.activation(x_h + recurrent_h) 280 | else: 281 | if 0. < self.dropout < 1.: 282 | inputs = inputs * dp_mask[0] 283 | 284 | # inputs projected by all gate matrices at once 285 | matrix_x = K.dot(inputs, self.masked_kernel) 286 | if self.use_bias: 287 | # biases: bias_z_i, bias_r_i, bias_h_i 288 | matrix_x = K.bias_add(matrix_x, input_bias) 289 | 290 | x_z = matrix_x[:, :self.units] 291 | x_r = matrix_x[:, self.units: 2 * self.units] 292 | x_h = matrix_x[:, 2 * self.units:] 293 | 294 | if 0. 
< self.recurrent_dropout < 1.: 295 | h_tm1 = h_tm1 * rec_dp_mask[0] 296 | 297 | if self.reset_after: 298 | # hidden state projected by all gate matrices at once 299 | matrix_inner = K.dot(h_tm1, self.masked_recurrent_kernel) 300 | if self.use_bias: 301 | matrix_inner = K.bias_add(matrix_inner, recurrent_bias) 302 | else: 303 | # hidden state projected separately for update/reset and new 304 | matrix_inner = K.dot(h_tm1, self.masked_recurrent_kernel[:, :2 * self.units]) 305 | 306 | recurrent_z = matrix_inner[:, :self.units] 307 | recurrent_r = matrix_inner[:, self.units:2 * self.units] 308 | 309 | z = self.recurrent_activation(x_z + recurrent_z) 310 | r = self.recurrent_activation(x_r + recurrent_r) 311 | 312 | if self.reset_after: 313 | recurrent_h = r * matrix_inner[:, 2 * self.units:] 314 | else: 315 | recurrent_h = K.dot(r * h_tm1, 316 | self.masked_recurrent_kernel[:, 2 * self.units:]) 317 | 318 | hh = self.activation(x_h + recurrent_h) 319 | # previous and candidate state mixed by update gate 320 | h = z * h_tm1 + (1 - z) * hh 321 | return h, [h] 322 | 323 | 324 | class GRU_mask(RNN): 325 | """Gated Recurrent Unit - Cho et al. 2014. 326 | 327 | There are two variants. The default one is based on 1406.1078v3 and 328 | has reset gate applied to hidden state before matrix multiplication. The 329 | other one is based on original 1406.1078v1 and has the order reversed. 330 | 331 | The second variant is compatible with CuDNNGRU (GPU-only) and allows 332 | inference on CPU. Thus it has separate biases for `kernel` and 333 | `recurrent_kernel`. Use `'reset_after'=True` and 334 | `recurrent_activation='sigmoid'`. 335 | 336 | Arguments: 337 | units: Positive integer, dimensionality of the output space. 338 | activation: Activation function to use. 339 | Default: hyperbolic tangent (`tanh`). 340 | If you pass `None`, no activation is applied 341 | (ie. "linear" activation: `a(x) = x`). 342 | recurrent_activation: Activation function to use 343 | for the recurrent step. 
344 | Default: hard sigmoid (`hard_sigmoid`). 345 | If you pass `None`, no activation is applied 346 | (ie. "linear" activation: `a(x) = x`). 347 | use_bias: Boolean, whether the layer uses a bias vector. 348 | kernel_initializer: Initializer for the `kernel` weights matrix, 349 | used for the linear transformation of the inputs. 350 | recurrent_initializer: Initializer for the `recurrent_kernel` 351 | weights matrix, used for the linear transformation of the recurrent state. 352 | bias_initializer: Initializer for the bias vector. 353 | kernel_regularizer: Regularizer function applied to 354 | the `kernel` weights matrix. 355 | recurrent_regularizer: Regularizer function applied to 356 | the `recurrent_kernel` weights matrix. 357 | bias_regularizer: Regularizer function applied to the bias vector. 358 | activity_regularizer: Regularizer function applied to 359 | the output of the layer (its "activation").. 360 | kernel_constraint: Constraint function applied to 361 | the `kernel` weights matrix. 362 | recurrent_constraint: Constraint function applied to 363 | the `recurrent_kernel` weights matrix. 364 | bias_constraint: Constraint function applied to the bias vector. 365 | dropout: Float between 0 and 1. 366 | Fraction of the units to drop for 367 | the linear transformation of the inputs. 368 | recurrent_dropout: Float between 0 and 1. 369 | Fraction of the units to drop for 370 | the linear transformation of the recurrent state. 371 | implementation: Implementation mode, either 1 or 2. 372 | Mode 1 will structure its operations as a larger number of 373 | smaller dot products and additions, whereas mode 2 will 374 | batch them into fewer, larger operations. These modes will 375 | have different performance profiles on different hardware and 376 | for different applications. 377 | return_sequences: Boolean. Whether to return the last output 378 | in the output sequence, or the full sequence. 379 | return_state: Boolean. 
Whether to return the last state 380 | in addition to the output. 381 | go_backwards: Boolean (default False). 382 | If True, process the input sequence backwards and return the 383 | reversed sequence. 384 | stateful: Boolean (default False). If True, the last state 385 | for each sample at index i in a batch will be used as initial 386 | state for the sample of index i in the following batch. 387 | unroll: Boolean (default False). 388 | If True, the network will be unrolled, 389 | else a symbolic loop will be used. 390 | Unrolling can speed-up a RNN, 391 | although it tends to be more memory-intensive. 392 | Unrolling is only suitable for short sequences. 393 | time_major: The shape format of the `inputs` and `outputs` tensors. 394 | If True, the inputs and outputs will be in shape 395 | `(timesteps, batch, ...)`, whereas in the False case, it will be 396 | `(batch, timesteps, ...)`. Using `time_major = True` is a bit more 397 | efficient because it avoids transposes at the beginning and end of the 398 | RNN calculation. However, most TensorFlow data is batch-major, so by 399 | default this function accepts input and emits output in batch-major 400 | form. 401 | reset_after: GRU convention (whether to apply reset gate after or 402 | before matrix multiplication). False = "before" (default), 403 | True = "after" (CuDNN compatible). 404 | 405 | Call arguments: 406 | inputs: A 3D tensor. 407 | mask: Binary tensor of shape `(samples, timesteps)` indicating whether 408 | a given timestep should be masked. 409 | training: Python boolean indicating whether the layer should behave in 410 | training mode or in inference mode. This argument is passed to the cell 411 | when calling it. This is only relevant if `dropout` or 412 | `recurrent_dropout` is used. 413 | initial_state: List of initial state tensors to be passed to the first 414 | call of the cell. 
415 | """ 416 | 417 | def __init__(self, 418 | units, 419 | activation='tanh', 420 | recurrent_activation='hard_sigmoid', 421 | use_bias=True, 422 | kernel_initializer='glorot_uniform', 423 | recurrent_initializer='orthogonal', 424 | bias_initializer='zeros', 425 | kernel_regularizer=None, 426 | recurrent_regularizer=None, 427 | bias_regularizer=None, 428 | activity_regularizer=None, 429 | kernel_constraint=None, 430 | recurrent_constraint=None, 431 | bias_constraint=None, 432 | dropout=0., 433 | recurrent_dropout=0., 434 | implementation=1, 435 | return_sequences=False, 436 | return_state=False, 437 | go_backwards=False, 438 | stateful=False, 439 | unroll=False, 440 | reset_after=False, 441 | **kwargs): 442 | if implementation == 0: 443 | logging.warning('`implementation=0` has been deprecated, ' 444 | 'and now defaults to `implementation=1`.' 445 | 'Please update your layer call.') 446 | cell = GRUCell_mask( 447 | units, 448 | activation=activation, 449 | recurrent_activation=recurrent_activation, 450 | use_bias=use_bias, 451 | kernel_initializer=kernel_initializer, 452 | recurrent_initializer=recurrent_initializer, 453 | bias_initializer=bias_initializer, 454 | kernel_regularizer=kernel_regularizer, 455 | recurrent_regularizer=recurrent_regularizer, 456 | bias_regularizer=bias_regularizer, 457 | kernel_constraint=kernel_constraint, 458 | recurrent_constraint=recurrent_constraint, 459 | bias_constraint=bias_constraint, 460 | dropout=dropout, 461 | recurrent_dropout=recurrent_dropout, 462 | implementation=implementation, 463 | reset_after=reset_after, 464 | dtype=kwargs.get('dtype')) 465 | super(GRU_mask, self).__init__( 466 | cell, 467 | return_sequences=return_sequences, 468 | return_state=return_state, 469 | go_backwards=go_backwards, 470 | stateful=stateful, 471 | unroll=unroll, 472 | **kwargs) 473 | self.activity_regularizer = regularizers.get(activity_regularizer) 474 | self.input_spec = [InputSpec(ndim=3)] 475 | 476 | def call(self, inputs, mask=None, 
training=None, initial_state=None): 477 | self.cell.reset_dropout_mask() 478 | self.cell.reset_recurrent_dropout_mask() 479 | return super(GRU_mask, self).call( 480 | inputs, mask=mask, training=training, initial_state=initial_state) 481 | 482 | @property 483 | def units(self): 484 | return self.cell.units 485 | 486 | @property 487 | def activation(self): 488 | return self.cell.activation 489 | 490 | @property 491 | def recurrent_activation(self): 492 | return self.cell.recurrent_activation 493 | 494 | @property 495 | def use_bias(self): 496 | return self.cell.use_bias 497 | 498 | @property 499 | def kernel_initializer(self): 500 | return self.cell.kernel_initializer 501 | 502 | @property 503 | def recurrent_initializer(self): 504 | return self.cell.recurrent_initializer 505 | 506 | @property 507 | def bias_initializer(self): 508 | return self.cell.bias_initializer 509 | 510 | @property 511 | def kernel_regularizer(self): 512 | return self.cell.kernel_regularizer 513 | 514 | @property 515 | def recurrent_regularizer(self): 516 | return self.cell.recurrent_regularizer 517 | 518 | @property 519 | def bias_regularizer(self): 520 | return self.cell.bias_regularizer 521 | 522 | @property 523 | def kernel_constraint(self): 524 | return self.cell.kernel_constraint 525 | 526 | @property 527 | def recurrent_constraint(self): 528 | return self.cell.recurrent_constraint 529 | 530 | @property 531 | def bias_constraint(self): 532 | return self.cell.bias_constraint 533 | 534 | @property 535 | def dropout(self): 536 | return self.cell.dropout 537 | 538 | @property 539 | def recurrent_dropout(self): 540 | return self.cell.recurrent_dropout 541 | 542 | @property 543 | def implementation(self): 544 | return self.cell.implementation 545 | 546 | @property 547 | def reset_after(self): 548 | return self.cell.reset_after 549 | 550 | def get_config(self): 551 | config = { 552 | 'units': 553 | self.units, 554 | 'activation': 555 | activations.serialize(self.activation), 556 | 
'recurrent_activation': 557 | activations.serialize(self.recurrent_activation), 558 | 'use_bias': 559 | self.use_bias, 560 | 'kernel_initializer': 561 | initializers.serialize(self.kernel_initializer), 562 | 'recurrent_initializer': 563 | initializers.serialize(self.recurrent_initializer), 564 | 'bias_initializer': 565 | initializers.serialize(self.bias_initializer), 566 | 'kernel_regularizer': 567 | regularizers.serialize(self.kernel_regularizer), 568 | 'recurrent_regularizer': 569 | regularizers.serialize(self.recurrent_regularizer), 570 | 'bias_regularizer': 571 | regularizers.serialize(self.bias_regularizer), 572 | 'activity_regularizer': 573 | regularizers.serialize(self.activity_regularizer), 574 | 'kernel_constraint': 575 | constraints.serialize(self.kernel_constraint), 576 | 'recurrent_constraint': 577 | constraints.serialize(self.recurrent_constraint), 578 | 'bias_constraint': 579 | constraints.serialize(self.bias_constraint), 580 | 'dropout': 581 | self.dropout, 582 | 'recurrent_dropout': 583 | self.recurrent_dropout, 584 | 'implementation': 585 | self.implementation, 586 | 'reset_after': 587 | self.reset_after 588 | } 589 | base_config = super(GRU_mask, self).get_config() 590 | del base_config['cell'] 591 | return dict(list(base_config.items()) + list(config.items())) 592 | 593 | @classmethod 594 | def from_config(cls, config): 595 | if 'implementation' in config and config['implementation'] == 0: 596 | config['implementation'] = 1 597 | return cls(**config) 598 | -------------------------------------------------------------------------------- /networks/pruning_methods.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jan 5 23:24:26 2022 4 | 5 | @author: xiaohuai le 6 | """ 7 | import tensorflow as tf 8 | import numpy as np 9 | ''' 10 | A structured pruning method with the intrinsic sparse structures-based regularization 11 | ''' 12 | def get_regular_ISS(rnn, fc, 
bidirectional = False, *args): 13 | ''' 14 | get the ISS of the RNN and the following FC 15 | ''' 16 | if not bidirectional: 17 | input_weights = rnn.weights[0] 18 | hidden_weights = rnn.weights[1] 19 | linear_weights = fc.weights[0] 20 | # rnn 21 | t1_inp = tf.square(input_weights) 22 | t1_hid = tf.square(hidden_weights) 23 | 24 | t1_col_sum = tf.reduce_sum(t1_inp, axis = 0) + tf.reduce_sum(t1_hid, axis = 0) 25 | t1_col_sum1, t1_col_sum2, t1_col_sum3 = tf.split(t1_col_sum, 3) 26 | t1_row_sum = tf.reduce_sum(t1_hid, axis = 1) 27 | # linear 28 | t2 = tf.square(linear_weights) 29 | t2_row_sum = tf.reduce_sum(t2, axis = 1) 30 | 31 | reg_sum = t1_row_sum + \ 32 | t1_col_sum1 + t1_col_sum2 + t1_col_sum3 + \ 33 | t2_row_sum+ \ 34 | tf.constant(1.0e-8) 35 | reg_sqrt = tf.sqrt(reg_sum) 36 | reg = tf.reduce_sum(reg_sqrt) 37 | else: 38 | forward_input_weights = rnn.weights[0] 39 | forward_hidden_weights = rnn.weights[1] 40 | backward_input_weights = rnn.weights[3] 41 | backward_hidden_weights = rnn.weights[4] 42 | linear_weights = fc.weights[0] 43 | # forward 44 | forward_t1_inp = tf.square(forward_input_weights) 45 | forward_t1_hid = tf.square(forward_hidden_weights) 46 | 47 | forward_t1_col_sum = tf.reduce_sum(forward_t1_inp, axis = 0) + tf.reduce_sum(forward_t1_hid, axis = 0) 48 | forward_t1_col_sum1, forward_t1_col_sum2, forward_t1_col_sum3 = tf.split(forward_t1_col_sum, 3) 49 | forward_t1_row_sum = tf.reduce_sum(forward_t1_hid, axis = 1) 50 | # backward 51 | backward_t1_inp = tf.square(backward_input_weights) 52 | backward_t1_hid = tf.square(backward_hidden_weights) 53 | 54 | backward_t1_col_sum = tf.reduce_sum(backward_t1_inp, axis = 0) + tf.reduce_sum(backward_t1_hid, axis = 0) 55 | backward_t1_col_sum1, backward_t1_col_sum2, backward_t1_col_sum3 = tf.split(backward_t1_col_sum, 3) 56 | backward_t1_row_sum = tf.reduce_sum(backward_t1_hid, axis = 1) 57 | # linear 58 | t2 = tf.square(linear_weights) 59 | t2_row_sum = tf.reduce_sum(t2, axis = 1) 60 | 
t2_row_sum_forward,t2_row_sum_backward = tf.split(t2_row_sum,2) 61 | 62 | reg_sum = forward_t1_row_sum + \ 63 | forward_t1_col_sum1 + forward_t1_col_sum2 + forward_t1_col_sum3 + \ 64 | backward_t1_row_sum + \ 65 | backward_t1_col_sum1 + backward_t1_col_sum2 + backward_t1_col_sum3 + \ 66 | t2_row_sum_forward + t2_row_sum_backward + \ 67 | tf.constant(1.0e-8) 68 | reg_sqrt = tf.sqrt(reg_sum) 69 | 70 | reg = tf.reduce_sum(reg_sqrt) 71 | 72 | return reg 73 | 74 | def make_mask(rnn,fc,n=1,bidirectional=False): 75 | 76 | rnn_weights = rnn.get_weights() 77 | fc_weights = fc.get_weights() 78 | h_dim = rnn_weights[1].shape[0] 79 | if not bidirectional: 80 | # unidirectional RNN is pruned n dimension per step 81 | input_weights = rnn_weights[0]**2 82 | hidden_weights = rnn_weights[1]**2 83 | dense_weights = fc_weights[0]**2 84 | 85 | mag_list = [] 86 | for i in range(h_dim): 87 | if rnn_weights[5][i,0] == 0: 88 | mag_list.append(np.inf) 89 | else: 90 | t1_row_sum = np.sum(hidden_weights[i,:]) 91 | t1_col_sum1 = np.sum(input_weights[:,i]) + np.sum(hidden_weights[:,i]) 92 | t1_col_sum2 = np.sum(input_weights[:,i+h_dim]) + np.sum(hidden_weights[:,i+h_dim]) 93 | t1_col_sum3 = np.sum(input_weights[:,i+h_dim*2]) + np.sum(hidden_weights[:,i+h_dim*2]) 94 | t2_row_sum = np.sum(dense_weights[i,:]) 95 | reg_sum = t1_row_sum + \ 96 | t1_col_sum1 + t1_col_sum2 + t1_col_sum3 + \ 97 | t2_row_sum+ \ 98 | 1.0e-8 99 | mag_list.append(np.sqrt(reg_sum)) 100 | 101 | top_n_idx = np.array(mag_list).argsort()[0:n] 102 | print(top_n_idx) 103 | for i in top_n_idx: 104 | rnn_weights[3][0,i] = 0 105 | rnn_weights[3][0,i+h_dim] = 0 106 | rnn_weights[3][0,i+h_dim*2] = 0 107 | rnn_weights[4][0,i] = 0 108 | rnn_weights[4][0,i+h_dim] = 0 109 | rnn_weights[4][0,i+h_dim*2] = 0 110 | 111 | rnn_weights[5][i,0] = 0 112 | fc_weights[2][i,0] = 0 113 | 114 | rnn.set_weights(rnn_weights) 115 | fc.set_weights(fc_weights) 116 | else: 117 | forward_input_weights = rnn_weights[0]**2 118 | forward_hidden_weights = 
rnn_weights[1]**2 119 | backward_input_weights = rnn_weights[3]**2 120 | backward_hidden_weights = rnn_weights[4]**2 121 | 122 | dense_weights = fc_weights[0]**2 123 | #get forward mask 124 | mag_list = [] 125 | for i in range(h_dim): 126 | if rnn_weights[8][i,0] == 0: 127 | mag_list.append(np.inf) 128 | else: 129 | t1_row_sum = np.sum(forward_hidden_weights[i,:]) 130 | t1_col_sum1 = np.sum(forward_input_weights[:,i]) + np.sum(forward_hidden_weights[:,i]) 131 | t1_col_sum2 = np.sum(forward_input_weights[:,i+h_dim]) + np.sum(forward_hidden_weights[:,i+h_dim]) 132 | t1_col_sum3 = np.sum(forward_input_weights[:,i+h_dim*2]) + np.sum(forward_hidden_weights[:,i+h_dim*2]) 133 | t2_row_sum = np.sum(dense_weights[i,:]) 134 | reg_sum = t1_row_sum + \ 135 | t1_col_sum1 + t1_col_sum2 + t1_col_sum3 + \ 136 | t2_row_sum+ \ 137 | 1.0e-8 138 | mag_list.append(np.sqrt(reg_sum)) 139 | 140 | top_n_idx = np.array(mag_list).argsort()[0:n] 141 | print('foward:',top_n_idx) 142 | for i in top_n_idx: 143 | rnn_weights[6][0,i] = 0 144 | rnn_weights[6][0,i+h_dim] = 0 145 | rnn_weights[6][0,i+h_dim*2] = 0 146 | rnn_weights[7][0,i] = 0 147 | rnn_weights[7][0,i+h_dim] = 0 148 | rnn_weights[7][0,i+h_dim*2] = 0 149 | 150 | rnn_weights[8][i,0] = 0 151 | fc_weights[2][i,0] = 0 152 | 153 | #get backward mask 154 | mag_list = [] 155 | for i in range(h_dim): 156 | if rnn_weights[11][i,0] == 0: 157 | mag_list.append(np.inf) 158 | else: 159 | t1_row_sum = np.sum(backward_hidden_weights[i,:]) 160 | t1_col_sum1 = np.sum(backward_input_weights[:,i]) + np.sum(backward_hidden_weights[:,i]) 161 | t1_col_sum2 = np.sum(backward_input_weights[:,i+h_dim]) + np.sum(backward_hidden_weights[:,i+h_dim]) 162 | t1_col_sum3 = np.sum(backward_input_weights[:,i+h_dim*2]) + np.sum(backward_hidden_weights[:,i+h_dim*2]) 163 | t2_row_sum = np.sum(dense_weights[i+h_dim,:]) 164 | reg_sum = t1_row_sum + \ 165 | t1_col_sum1 + t1_col_sum2 + t1_col_sum3 + \ 166 | t2_row_sum+ \ 167 | 1.0e-8 168 | mag_list.append(np.sqrt(reg_sum)) 
169 | 170 | top_n_idx = np.array(mag_list).argsort()[0:n] 171 | print('backward:',top_n_idx) 172 | for i in top_n_idx: 173 | rnn_weights[9][0,i] = 0 174 | rnn_weights[9][0,i+h_dim] = 0 175 | rnn_weights[9][0,i+h_dim*2] = 0 176 | rnn_weights[10][0,i] = 0 177 | rnn_weights[10][0,i+h_dim] = 0 178 | rnn_weights[10][0,i+h_dim*2] = 0 179 | 180 | rnn_weights[11][i,0] = 0 181 | fc_weights[2][i+h_dim,0] = 0 182 | rnn.set_weights(rnn_weights) 183 | fc.set_weights(fc_weights) -------------------------------------------------------------------------------- /networks/skip_gru.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Nov 15 21:46:43 2021 4 | 5 | @author: Xiaohuai Le 6 | """ 7 | import tensorflow as tf 8 | from tensorflow.keras.layers import Layer 9 | import tensorflow.keras.backend as K 10 | from tensorflow.python.keras import activations 11 | from tensorflow.python.keras import constraints 12 | from tensorflow.python.keras import initializers 13 | from tensorflow.python.keras import regularizers 14 | from tensorflow.keras.layers import RNN 15 | from tensorflow.python.framework import tensor_shape 16 | from tensorflow.python.util import nest 17 | import logging 18 | from tensorflow.python.training.tracking import data_structures 19 | 20 | def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype): 21 | if inputs is not None: 22 | batch_size = tf.shape(inputs)[0] 23 | dtype = inputs.dtype 24 | return _generate_zero_filled_state(batch_size, cell.state_size, dtype) 25 | 26 | 27 | def _generate_zero_filled_state(batch_size_tensor, state_size, dtype): 28 | """Generate a zero filled tensor with shape [batch_size, state_size].""" 29 | if batch_size_tensor is None or dtype is None: 30 | raise ValueError( 31 | 'batch_size and dtype cannot be None while constructing initial state: ' 32 | 'batch_size={}, dtype={}'.format(batch_size_tensor, dtype)) 33 | 34 | def 
create_zeros(unnested_state_size): 35 | flat_dims = tensor_shape.as_shape(unnested_state_size).as_list() 36 | init_state_size = [batch_size_tensor] + flat_dims 37 | return tf.zeros(init_state_size, dtype=dtype) 38 | 39 | if nest.is_sequence(state_size): 40 | return nest.map_structure(create_zeros, state_size) 41 | else: 42 | return create_zeros(state_size) 43 | 44 | def _binary_round(x): 45 | """ 46 | Rounds a tensor whose values are in [0,1] to a tensor with values in {0, 1}, 47 | using the straight through estimator for the gradient. 48 | 49 | Based on http://r2rt.com/binary-stochastic-neurons-in-tensorflow.html 50 | 51 | :param x: input tensor 52 | :return: y=round(x) with gradients defined by the identity mapping (y=x) 53 | """ 54 | 55 | g = tf.get_default_graph() 56 | 57 | #with tf.name_scope("BinaryRound") as name: 58 | with g.gradient_override_map({"Round": "Identity"}): 59 | return tf.round(x) 60 | 61 | 62 | class SkipGRUCell(Layer): 63 | 64 | def __init__(self, 65 | units, 66 | activation='tanh', 67 | recurrent_activation='hard_sigmoid', 68 | use_bias=True, 69 | kernel_initializer='glorot_uniform', 70 | recurrent_initializer='orthogonal', 71 | bias_initializer='zeros', 72 | kernel_regularizer=None, 73 | recurrent_regularizer=None, 74 | bias_regularizer=None, 75 | kernel_constraint=None, 76 | recurrent_constraint=None, 77 | bias_constraint=None, 78 | implementation=1, 79 | reset_after=False, 80 | moving_ave=False, 81 | **kwargs): 82 | super(SkipGRUCell, self).__init__(**kwargs) 83 | self.units = units 84 | self.activation = activations.get(activation) 85 | self.recurrent_activation = activations.get(recurrent_activation) 86 | self.use_bias = use_bias 87 | self.moving_ave = moving_ave 88 | self.kernel_initializer = initializers.get(kernel_initializer) 89 | self.recurrent_initializer = initializers.get(recurrent_initializer) 90 | self.bias_initializer = initializers.get(bias_initializer) 91 | 92 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 
93 | self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 94 | self.bias_regularizer = regularizers.get(bias_regularizer) 95 | 96 | self.kernel_constraint = constraints.get(kernel_constraint) 97 | self.recurrent_constraint = constraints.get(recurrent_constraint) 98 | self.bias_constraint = constraints.get(bias_constraint) 99 | 100 | self.implementation = implementation 101 | self.reset_after = reset_after 102 | self.state_size = self.units 103 | self.output_size = self.units 104 | self.linear = tf.keras.layers.Dense(1,bias_initializer='ones',activation='sigmoid') 105 | self.state_size = data_structures.NoDependency([self.units, 1, 1]) 106 | 107 | def build(self, input_shape): 108 | input_dim = input_shape[-1]-1 109 | self.kernel = self.add_weight( 110 | shape=(input_dim, self.units * 3), 111 | name='kernel', 112 | initializer=self.kernel_initializer, 113 | regularizer=self.kernel_regularizer, 114 | constraint=self.kernel_constraint) 115 | self.recurrent_kernel = self.add_weight( 116 | shape=(self.units, self.units * 3), 117 | name='recurrent_kernel', 118 | initializer=self.recurrent_initializer, 119 | regularizer=self.recurrent_regularizer, 120 | constraint=self.recurrent_constraint) 121 | if self.use_bias: 122 | if not self.reset_after: 123 | bias_shape = (3 * self.units,) 124 | else: 125 | # separate biases for input and recurrent kernels 126 | # Note: the shape is intentionally different from CuDNNGRU biases 127 | # `(2 * 3 * self.units,)`, so that we can distinguish the classes 128 | # when loading and converting saved weights. 
129 | bias_shape = (2, 3 * self.units) 130 | self.bias = self.add_weight(shape=bias_shape, 131 | name='bias', 132 | initializer=self.bias_initializer, 133 | regularizer=self.bias_regularizer, 134 | constraint=self.bias_constraint) 135 | else: 136 | self.bias = None 137 | self.built = True 138 | 139 | def call(self, inputs, states, training=None): 140 | ''' 141 | Skip-GRU Cell based on the GRU implement of tensorflow 142 | 143 | inputs: the input of this time step 144 | the scale (gamma) of the update rate 145 | 146 | states: the hidden state, 147 | the update probability and the cumulative update probability of the last time step 148 | ''' 149 | # GRU Cell 150 | h_tm1, update_prob_prev, cum_update_prob_prev = states[0],states[1],states[2] 151 | if self.use_bias: 152 | if not self.reset_after: 153 | input_bias, recurrent_bias = self.bias, None 154 | else: 155 | input_bias, recurrent_bias = tf.unstack(self.bias) 156 | # scale is the gamma which used to control the update rate 157 | scale = inputs[:,-1:] 158 | if self.implementation == 1: 159 | 160 | inputs_z = inputs[:,:-1] 161 | inputs_r = inputs[:,:-1] 162 | inputs_h = inputs[:,:-1] 163 | 164 | x_z = K.dot(inputs_z, self.kernel[:, :self.units]) 165 | x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2]) 166 | x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:]) 167 | 168 | if self.use_bias: 169 | x_z = K.bias_add(x_z, input_bias[:self.units]) 170 | x_r = K.bias_add(x_r, input_bias[self.units: self.units * 2]) 171 | x_h = K.bias_add(x_h, input_bias[self.units * 2:]) 172 | 173 | h_tm1_z = h_tm1 174 | h_tm1_r = h_tm1 175 | h_tm1_h = h_tm1 176 | 177 | recurrent_z = K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units]) 178 | recurrent_r = K.dot(h_tm1_r, 179 | self.recurrent_kernel[:, self.units:self.units * 2]) 180 | if self.reset_after and self.use_bias: 181 | recurrent_z = K.bias_add(recurrent_z, recurrent_bias[:self.units]) 182 | recurrent_r = K.bias_add(recurrent_r, 183 | 
recurrent_bias[self.units:self.units * 2]) 184 | 185 | z = self.recurrent_activation(x_z + recurrent_z) 186 | r = self.recurrent_activation(x_r + recurrent_r) 187 | 188 | # reset gate applied after/before matrix multiplication 189 | if self.reset_after: 190 | recurrent_h = K.dot(h_tm1_h, self.recurrent_kernel[:, self.units * 2:]) 191 | if self.use_bias: 192 | recurrent_h = K.bias_add(recurrent_h, recurrent_bias[self.units * 2:]) 193 | recurrent_h = r * recurrent_h 194 | else: 195 | recurrent_h = K.dot(r * h_tm1_h, 196 | self.recurrent_kernel[:, self.units * 2:]) 197 | 198 | hh = self.activation(x_h + recurrent_h) 199 | else: 200 | 201 | # inputs projected by all gate matrices at once 202 | matrix_x = K.dot(inputs[:,:-1], self.kernel) 203 | if self.use_bias: 204 | # biases: bias_z_i, bias_r_i, bias_h_i 205 | matrix_x = K.bias_add(matrix_x, input_bias) 206 | 207 | x_z = matrix_x[:, :self.units] 208 | x_r = matrix_x[:, self.units: 2 * self.units] 209 | x_h = matrix_x[:, 2 * self.units:] 210 | 211 | if self.reset_after: 212 | # hidden state projected by all gate matrices at once 213 | matrix_inner = K.dot(h_tm1, self.recurrent_kernel) 214 | if self.use_bias: 215 | matrix_inner = K.bias_add(matrix_inner, recurrent_bias) 216 | else: 217 | # hidden state projected separately for update/reset and new 218 | matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units]) 219 | 220 | recurrent_z = matrix_inner[:, :self.units] 221 | recurrent_r = matrix_inner[:, self.units:2 * self.units] 222 | 223 | z = self.recurrent_activation(x_z + recurrent_z) 224 | r = self.recurrent_activation(x_r + recurrent_r) 225 | 226 | if self.reset_after: 227 | recurrent_h = r * matrix_inner[:, 2 * self.units:] 228 | else: 229 | recurrent_h = K.dot(r * h_tm1, 230 | self.recurrent_kernel[:, 2 * self.units:]) 231 | 232 | hh = self.activation(x_h + recurrent_h) 233 | 234 | # previous and candidate state mixed by update gate 235 | h = z * h_tm1 + (1 - z) * hh 236 | 237 | # SKIP RNN 238 | 
new_update_prob_tilde = self.linear(h) * scale 239 | cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev) 240 | update_gate = _binary_round(cum_update_prob) 241 | 242 | # Apply update gate 243 | if self.moving_ave: 244 | new_h = update_gate * h + (1. - update_gate) * (h_tm1 * 0.9 + self.activation(x_h) * 0.1) 245 | else: 246 | new_h = update_gate * h + (1. - update_gate) * h_tm1 247 | new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev 248 | new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob 249 | 250 | return [new_h,update_gate], [new_h, new_update_prob, new_cum_update_prob] #tf.concat([h,new_update_prob,new_cum_update_prob],axis=-1) 251 | 252 | def get_config(self): 253 | config = { 254 | 'units': self.units, 255 | 'activation': activations.serialize(self.activation), 256 | 'recurrent_activation': 257 | activations.serialize(self.recurrent_activation), 258 | 'use_bias': self.use_bias, 259 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 260 | 'recurrent_initializer': 261 | initializers.serialize(self.recurrent_initializer), 262 | 'bias_initializer': initializers.serialize(self.bias_initializer), 263 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 264 | 'recurrent_regularizer': 265 | regularizers.serialize(self.recurrent_regularizer), 266 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 267 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 268 | 'recurrent_constraint': 269 | constraints.serialize(self.recurrent_constraint), 270 | 'bias_constraint': constraints.serialize(self.bias_constraint), 271 | 'implementation': self.implementation, 272 | 'reset_after': self.reset_after 273 | } 274 | base_config = super(SkipGRUCell, self).get_config() 275 | return dict(list(base_config.items()) + list(config.items())) 276 | ''' 277 | def get_initial_state(self, inputs=None, 
batch_size=None, dtype=None): 278 | return list(_generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)) 279 | ''' 280 | class SkipGRU(RNN): 281 | 282 | def __init__(self, 283 | units, 284 | activation='tanh', 285 | recurrent_activation='hard_sigmoid', 286 | use_bias=True, 287 | kernel_initializer='glorot_uniform', 288 | recurrent_initializer='orthogonal', 289 | bias_initializer='zeros', 290 | kernel_regularizer=None, 291 | recurrent_regularizer=None, 292 | bias_regularizer=None, 293 | activity_regularizer=None, 294 | kernel_constraint=None, 295 | recurrent_constraint=None, 296 | bias_constraint=None, 297 | implementation=1, 298 | return_sequences=False, 299 | return_state=False, 300 | go_backwards=False, 301 | stateful=False, 302 | unroll=False, 303 | reset_after=False, 304 | moving_ave=False, 305 | **kwargs): 306 | if implementation == 0: 307 | logging.warning('`implementation=0` has been deprecated, ' 308 | 'and now defaults to `implementation=1`.' 309 | 'Please update your layer call.') 310 | cell = SkipGRUCell( 311 | units, 312 | activation=activation, 313 | recurrent_activation=recurrent_activation, 314 | use_bias=use_bias, 315 | kernel_initializer=kernel_initializer, 316 | recurrent_initializer=recurrent_initializer, 317 | bias_initializer=bias_initializer, 318 | kernel_regularizer=kernel_regularizer, 319 | recurrent_regularizer=recurrent_regularizer, 320 | bias_regularizer=bias_regularizer, 321 | kernel_constraint=kernel_constraint, 322 | recurrent_constraint=recurrent_constraint, 323 | bias_constraint=bias_constraint, 324 | implementation=implementation, 325 | reset_after=reset_after, 326 | moving_ave=moving_ave, 327 | dtype=kwargs.get('dtype')) 328 | super(SkipGRU, self).__init__( 329 | cell, 330 | return_sequences=return_sequences, 331 | return_state=return_state, 332 | go_backwards=go_backwards, 333 | stateful=stateful, 334 | unroll=unroll, 335 | **kwargs) 336 | self.activity_regularizer = regularizers.get(activity_regularizer) 337 | 338 | 
def call(self, inputs, mask=None, training=None, initial_state=None): 339 | if initial_state is None: 340 | initial_state=[tf.zeros([tf.shape(inputs)[0],self.units]),tf.ones([tf.shape(inputs)[0],1]),tf.zeros([tf.shape(inputs)[0],1])] 341 | return super(SkipGRU, self).call( 342 | inputs, mask=mask, training=training, initial_state=initial_state) 343 | 344 | @property 345 | def units(self): 346 | return self.cell.units 347 | 348 | @property 349 | def activation(self): 350 | return self.cell.activation 351 | 352 | @property 353 | def recurrent_activation(self): 354 | return self.cell.recurrent_activation 355 | 356 | @property 357 | def use_bias(self): 358 | return self.cell.use_bias 359 | 360 | @property 361 | def kernel_initializer(self): 362 | return self.cell.kernel_initializer 363 | 364 | @property 365 | def recurrent_initializer(self): 366 | return self.cell.recurrent_initializer 367 | 368 | @property 369 | def bias_initializer(self): 370 | return self.cell.bias_initializer 371 | 372 | @property 373 | def kernel_regularizer(self): 374 | return self.cell.kernel_regularizer 375 | 376 | @property 377 | def recurrent_regularizer(self): 378 | return self.cell.recurrent_regularizer 379 | 380 | @property 381 | def bias_regularizer(self): 382 | return self.cell.bias_regularizer 383 | 384 | @property 385 | def kernel_constraint(self): 386 | return self.cell.kernel_constraint 387 | 388 | @property 389 | def recurrent_constraint(self): 390 | return self.cell.recurrent_constraint 391 | 392 | @property 393 | def bias_constraint(self): 394 | return self.cell.bias_constraint 395 | 396 | @property 397 | def implementation(self): 398 | return self.cell.implementation 399 | 400 | @property 401 | def reset_after(self): 402 | return self.cell.reset_after 403 | 404 | def get_config(self): 405 | config = { 406 | 'units': 407 | self.units, 408 | 'activation': 409 | activations.serialize(self.activation), 410 | 'recurrent_activation': 411 | 
activations.serialize(self.recurrent_activation), 412 | 'use_bias': 413 | self.use_bias, 414 | 'kernel_initializer': 415 | initializers.serialize(self.kernel_initializer), 416 | 'recurrent_initializer': 417 | initializers.serialize(self.recurrent_initializer), 418 | 'bias_initializer': 419 | initializers.serialize(self.bias_initializer), 420 | 'kernel_regularizer': 421 | regularizers.serialize(self.kernel_regularizer), 422 | 'recurrent_regularizer': 423 | regularizers.serialize(self.recurrent_regularizer), 424 | 'bias_regularizer': 425 | regularizers.serialize(self.bias_regularizer), 426 | 'activity_regularizer': 427 | regularizers.serialize(self.activity_regularizer), 428 | 'kernel_constraint': 429 | constraints.serialize(self.kernel_constraint), 430 | 'recurrent_constraint': 431 | constraints.serialize(self.recurrent_constraint), 432 | 'bias_constraint': 433 | constraints.serialize(self.bias_constraint), 434 | 'implementation': 435 | self.implementation, 436 | 'reset_after': 437 | self.reset_after 438 | } 439 | base_config = super(SkipGRU, self).get_config() 440 | del base_config['cell'] 441 | return dict(list(base_config.items()) + list(config.items())) 442 | 443 | @classmethod 444 | def from_config(cls, config): 445 | if 'implementation' in config and config['implementation'] == 0: 446 | config['implementation'] = 1 447 | return cls(**config) 448 | 449 | 450 | if __name__ == '__main__': 451 | import numpy as np 452 | 453 | ''' 454 | Test the Skip-GRU module on the MNIST data set 455 | ''' 456 | inp = tf.keras.layers.Input(batch_shape = [100,None,29]) 457 | rnn1,gate,h,p,cp = SkipGRU(units = 64, 458 | return_sequences = True,return_state = True)(inp) 459 | results = tf.keras.layers.Dense(10,activation ='softmax')(h) 460 | 461 | gate_regular = tf.reduce_mean(gate) * 0.5 462 | model = tf.keras.models.Model(inp,results) 463 | 464 | 465 | def update_rate(y_true, y_pred): 466 | return tf.reduce_mean(gate) 467 | 468 | from tensorflow.keras.datasets import mnist 469 | 
470 | (x_train, y_train_), (x_test, y_test_) = mnist.load_data() 471 | x_train = x_train.astype('float32') / 255. 472 | x_test = x_test.astype('float32') / 255. 473 | x_train = x_train.reshape((-1, 28, 28)) 474 | x_test = x_test.reshape((-1, 28,28)) 475 | # inputs are padded with gamma = 1 476 | x_train = np.concatenate([x_train,np.ones([60000,28,1])],-1) 477 | x_test = np.concatenate([x_test,np.ones([10000,28,1])],-1) 478 | 479 | y_train = tf.keras.utils.to_categorical(y_train_) 480 | y_test = tf.keras.utils.to_categorical(y_test_) 481 | model.add_loss(gate_regular) 482 | model.compile(optimizer='adam', loss='categorical_crossentropy',metrics=['acc', update_rate]) 483 | 484 | model.fit(x_train,y_train,batch_size = 100,epochs =20) 485 | 486 | -------------------------------------------------------------------------------- /pretrained_weights/DPCRN_base/models_experiment_new_base_nomap_phasenloss_retrain_WSJmodel_84_0.022068.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Xiaohuai-speech/SKIP-DPCRN/40e9bb6468bbfecff6ddd698eb63dac7b1089472/pretrained_weights/DPCRN_base/models_experiment_new_base_nomap_phasenloss_retrain_WSJmodel_84_0.022068.h5 -------------------------------------------------------------------------------- /signal_processing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Mar 2 16:27:34 2022 4 | 5 | @author: Xiaohuai Le 6 | """ 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | class Signal_Pro(): 12 | def __init__(self, config): 13 | 14 | self.fs = config['stft']['fs'] 15 | self.block_len = config['stft']['block_len'] 16 | self.block_shift = config['stft']['block_shift'] 17 | self.window = config['stft']['window'] 18 | self.N_FFT = config['stft']['N_FFT'] 19 | self.win = None 20 | if self.window == 'sine': 21 | win = np.sin(np.arange(.5, self.block_len-.5+1) / 
self.block_len * np.pi) 22 | self.win = tf.constant(win, dtype = 'float32') 23 | else: 24 | pass 25 | 26 | def sep2frame(self, x): 27 | ''' 28 | generate frames from time-domain signal 29 | ''' 30 | frames = tf.signal.frame(x, self.block_len, self.block_shift) 31 | frames = self.win*frames 32 | return frames 33 | 34 | def stftLayer(self, x, mode ='mag_pha'): 35 | ''' 36 | Method for an STFT helper layer used with a Lambda layer. The layer 37 | calculates the STFT on the last dimension and returns the magnitude and 38 | phase of the STFT. 39 | ''' 40 | # creating frames from the continuous waveform 41 | frames = tf.signal.frame(x, self.block_len, self.block_shift) 42 | frames = self.win * frames 43 | # calculating the fft over the time frames. rfft returns NFFT/2+1 bins. 44 | stft_dat = tf.signal.rfft(frames) 45 | # calculating magnitude and phase from the complex signal 46 | output_list = [] 47 | if mode == 'mag_pha': 48 | mag = tf.math.abs(stft_dat) 49 | phase = tf.math.angle(stft_dat) 50 | output_list = [mag, phase] 51 | elif mode == 'real_imag': 52 | real = tf.math.real(stft_dat) 53 | imag = tf.math.imag(stft_dat) 54 | output_list = [real, imag] 55 | # returning magnitude and phase as list 56 | return output_list 57 | 58 | def fftLayer(self, x): 59 | ''' 60 | Method for an fft helper layer used with a Lambda layer. The layer 61 | calculates the rFFT on the last dimension and returns the magnitude and 62 | phase of the STFT. 63 | ''' 64 | 65 | # calculating the fft over the time frames. rfft returns NFFT/2+1 bins. 66 | stft_dat = tf.signal.rfft(x) 67 | # calculating magnitude and phase from the complex signal 68 | mag = tf.abs(stft_dat) 69 | phase = tf.math.angle(stft_dat) 70 | # returning magnitude and phase as list 71 | return [mag, phase] 72 | 73 | def ifftLayer(self, x, mode = 'mag_pha'): 74 | ''' 75 | Method for an inverse FFT layer used with an Lambda layer. This layer 76 | calculates time domain frames from magnitude and phase information. 
77 | As input x a list with [mag,phase] is required. 78 | ''' 79 | if mode == 'mag_pha': 80 | # calculating the complex representation 81 | s1_stft = (tf.cast(x[0], tf.complex64) * 82 | tf.exp( (1j * tf.cast(x[1], tf.complex64)))) 83 | elif mode == 'real_imag': 84 | s1_stft = tf.cast(x[0], tf.complex64) + 1j * tf.cast(x[1], tf.complex64) 85 | # returning the time domain frames 86 | return tf.signal.irfft(s1_stft) 87 | 88 | def overlapAddLayer(self, x): 89 | ''' 90 | Method for an overlap and add helper layer used with a Lambda layer. 91 | This layer reconstructs the waveform from a framed signal. 92 | ''' 93 | 94 | # calculating and returning the reconstructed waveform 95 | ''' 96 | if self.move_dc: 97 | x = x - tf.expand_dims(tf.reduce_mean(x,axis = -1),2) 98 | ''' 99 | return tf.signal.overlap_and_add(x, self.block_shift) 100 | 101 | def mk_mask_complex(self, x): 102 | ''' 103 | complex ratio mask 104 | ''' 105 | [noisy_real,noisy_imag,mask] = x 106 | noisy_real = noisy_real[:,:,:,0] 107 | noisy_imag = noisy_imag[:,:,:,0] 108 | 109 | mask_real = mask[:,:,:,0] 110 | mask_imag = mask[:,:,:,1] 111 | 112 | enh_real = noisy_real * mask_real - noisy_imag * mask_imag 113 | enh_imag = noisy_real * mask_imag + noisy_imag * mask_real 114 | 115 | return [enh_real,enh_imag] 116 | 117 | def mk_mask_mag(self, x): 118 | ''' 119 | magnitude mask 120 | ''' 121 | [noisy_real,noisy_imag,mag_mask] = x 122 | noisy_real = noisy_real[:,:,:,0] 123 | noisy_imag = noisy_imag[:,:,:,0] 124 | 125 | enh_mag_real = noisy_real * mag_mask 126 | enh_mag_imag = noisy_imag * mag_mask 127 | return [enh_mag_real,enh_mag_imag] 128 | 129 | def mk_mask_pha(self, x): 130 | ''' 131 | phase mask 132 | ''' 133 | [enh_mag_real,enh_mag_imag,pha_cos,pha_sin] = x 134 | 135 | enh_real = enh_mag_real * pha_cos - enh_mag_imag * pha_sin 136 | enh_imag = enh_mag_real * pha_sin + enh_mag_imag * pha_cos 137 | 138 | return [enh_real,enh_imag] 139 | 140 | def mk_mask_mag_pha(self, x): 141 | 142 | 
[noisy_real,noisy_imag,mag_mask,pha_cos,pha_sin] = x 143 | noisy_real = noisy_real[:,:,:,0] 144 | noisy_imag = noisy_imag[:,:,:,0] 145 | 146 | enh_mag_real = noisy_real * mag_mask 147 | enh_mag_imag = noisy_imag * mag_mask 148 | 149 | enh_real = enh_mag_real * pha_cos - enh_mag_imag * pha_sin 150 | enh_imag = enh_mag_real * pha_sin + enh_mag_imag * pha_cos 151 | 152 | return [enh_real,enh_imag] 153 | 154 | 155 | -------------------------------------------------------------------------------- /test_audio/enhanced/440C020A_mix.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Xiaohuai-speech/SKIP-DPCRN/40e9bb6468bbfecff6ddd698eb63dac7b1089472/test_audio/enhanced/440C020A_mix.wav -------------------------------------------------------------------------------- /test_audio/noisy/440C020A_mix.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Xiaohuai-speech/SKIP-DPCRN/40e9bb6468bbfecff6ddd698eb63dac7b1089472/test_audio/noisy/440C020A_mix.wav --------------------------------------------------------------------------------