├── README.md ├── config.py ├── data_generator.py ├── evaluate.py ├── gv.py ├── main_crn.py ├── main_dnn.py ├── main_dsn.py ├── make_tfrecord.py ├── mini_data ├── test_noise │ ├── babble.wav │ └── white.wav ├── test_speech │ ├── TEST_DR3_MJMP0_SA2.WAV │ └── TEST_DR3_MJMP0_SI905.WAV ├── train_noise │ ├── n1.wav │ └── n43.wav └── train_speech │ ├── TRAIN_DR1_FCJF0_SA2.WAV │ └── TRAIN_DR1_FCJF0_SI648.WAV ├── prepare_data.py ├── spectrogram_to_wave.py └── timit_handler.py /README.md: -------------------------------------------------------------------------------- 1 | # speech_enhancement 2 | Speech enhancement using the following models: 3 | DNN: 4 | [1] Xu, Y., Du, J., Dai, L.R. and Lee, C.H., 2015. 5 | A regression approach to speech enhancement based on deep neural networks. 6 | IEEE/ACM Transactions on Audio, Speech and Language Processing (TASLP), 23(1), pp.7-19. 7 | https://github.com/yongxuUSTC/sednn/tree/master/mixture2clean_dnn 8 | 9 | 10 | 11 | CRN (note: the architecture in main_crn.py — a convolutional encoder, two LSTM layers and a deconvolutional decoder with skip connections — follows Tan, K. and Wang, D., 2018. A Convolutional Recurrent Neural Network for Real-Time Speech Enhancement. Proc. Interspeech 2018; the fully convolutional network below is a related reference): 12 | Park S R , Lee J . A Fully Convolutional Neural Network for Speech Enhancement[J]. 2016. 13 | 14 | DSN (work in progress): 15 | Nie S , Zhang H , Zhang X L , et al. DEEP STACKING NETWORKS WITH TIME SERIES FOR SPEECH SEPARATION[J]. 2014. 16 | 17 | 18 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Summary: Config file.
3 | Author: Qiuqiang Kong 4 | Created: 2017.12.21 5 | Modified: - 6 | """ 7 | 8 | sample_rate = 16000 9 | #n_window = 512 # windows size for FFT 10 | #n_overlap = 256 # overlap of window 11 | n_overlap = 128 12 | n_window = 256 -------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import prepare_data as pp_data 4 | class DataGenerator(object): 5 | def __init__(self, batch_size, type, te_max_iter=None): 6 | assert type in ['train', 'test'] 7 | self._batch_size_ = batch_size 8 | self._type_ = type 9 | self._te_max_iter_ = te_max_iter 10 | 11 | def generate(self, xs, ys): 12 | x = xs[0] 13 | y = ys[0] 14 | batch_size = self._batch_size_ 15 | n_samples = len(x) 16 | 17 | index = np.arange(n_samples) 18 | np.random.shuffle(index) 19 | 20 | iter = 0 21 | epoch = 0 22 | pointer = 0 23 | while True: 24 | if (self._type_ == 'test') and (self._te_max_iter_ is not None): 25 | if iter == self._te_max_iter_: 26 | break 27 | iter += 1 28 | if pointer >= n_samples: 29 | epoch += 1 30 | if (self._type_) == 'test' and (epoch == 1): 31 | break 32 | pointer = 0 33 | np.random.shuffle(index) 34 | 35 | batch_idx = index[pointer : min(pointer + batch_size, n_samples)] 36 | pointer += batch_size 37 | yield x[batch_idx], y[batch_idx] 38 | 39 | 40 | 41 | 42 | 43 | class DataGenerator_h5py(object): 44 | def __init__(self, batch_size, type, scaler, te_max_iter=None , ): 45 | assert type in ['train', 'test'] 46 | self._batch_size_ = batch_size 47 | self._type_ = type 48 | self._te_max_iter_ = te_max_iter 49 | self._scaler_ = scaler 50 | 51 | def generate(self, path_list): 52 | iter = 0 53 | epoch = 0 54 | pointer = 0 55 | path = path_list[epoch] 56 | n_file = len(path_list) 57 | data = h5py.File(path) 58 | x = data['x'] 59 | y = data['y'] 60 | batch_size = self._batch_size_ 61 | n_samples = len(x) 62 | index = 
np.arange(n_samples) 63 | np.random.shuffle(index) 64 | while True: 65 | if (self._type_ == 'test') and (self._te_max_iter_ is not None): 66 | if iter == self._te_max_iter_: 67 | break 68 | iter += 1 69 | if pointer >= n_samples: 70 | epoch += 1 71 | if epoch == n_file: 72 | epoch = 0 73 | path = path_list[epoch] 74 | print("start %s"%path) 75 | n_file = len(path_list) 76 | data = h5py.File(path) 77 | x = data['x'] 78 | y = data['y'] 79 | if (self._type_) == 'test' and (epoch == n_file - 1): 80 | break 81 | pointer = 0 82 | np.random.shuffle(index) 83 | 84 | batch_idx = index[pointer : min(pointer + batch_size, n_samples)] 85 | pointer += batch_size 86 | yield pp_data.scale_on_3d(x[sorted(batch_idx)], self._scaler_), pp_data.scale_on_2d(y[sorted(batch_idx)], self._scaler_) 87 | 88 | 89 | 90 | ''' 91 | count = 0 92 | tr_gen = DataGenerator_h5py(batch_size = 10, type = "train") 93 | for (batch_x, batch_y) in tr_gen.generate(path_list = ["data1.h5", "data2.h5"]): 94 | count+=1 95 | print(count) 96 | ''' 97 | 98 | 99 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Summary: Calculate PESQ and overal stats of enhanced speech. 3 | Author: Qiuqiang Kong 4 | Created: 2017.12.22 5 | Modified: - 6 | """ 7 | import argparse 8 | import os 9 | import csv 10 | import numpy as np 11 | import cPickle 12 | import soundfile 13 | from pypesq import pypesq 14 | from pystoi.stoi import stoi 15 | from prepare_data import create_folder 16 | #import matplotlib.pyplot as plt 17 | 18 | 19 | def plot_training_stat(args): 20 | """Plot training and testing loss. 21 | 22 | Args: 23 | workspace: str, path of workspace. 24 | tr_snr: float, training SNR. 25 | bgn_iter: int, plot from bgn_iter 26 | fin_iter: int, plot finish at fin_iter 27 | interval_iter: int, interval of files. 
28 | """ 29 | workspace = args.workspace 30 | tr_snr = args.tr_snr 31 | bgn_iter = args.bgn_iter 32 | fin_iter = args.fin_iter 33 | interval_iter = args.interval_iter 34 | 35 | tr_losses, te_losses, iters = [], [], [] 36 | 37 | # Load stats. 38 | stats_dir = os.path.join(workspace, "training_stats", "%ddb" % int(tr_snr)) 39 | for iter in xrange(bgn_iter, fin_iter, interval_iter): 40 | stats_path = os.path.join(stats_dir, "%diters.p" % iter) 41 | dict = cPickle.load(open(stats_path, 'rb')) 42 | tr_losses.append(dict['tr_loss']) 43 | te_losses.append(dict['te_loss']) 44 | iters.append(dict['iter']) 45 | 46 | # Plot 47 | # line_tr, = plt.plot(tr_losses, c='b', label="Train") 48 | # line_te, = plt.plot(te_losses, c='r', label="Test") 49 | # plt.axis([0, len(iters), 0, max(tr_losses)]) 50 | # plt.xlabel("Iterations") 51 | # plt.ylabel("Loss") 52 | # plt.legend(handles=[line_tr, line_te]) 53 | # plt.xticks(np.arange(len(iters)), iters) 54 | # plt.show() 55 | 56 | 57 | def calculate_pesq(args): 58 | """Calculate PESQ of all enhaced speech. 59 | 60 | Args: 61 | workspace: str, path of workspace. 62 | speech_dir: str, path of clean speech. 63 | te_snr: float, testing SNR. 64 | """ 65 | # Remove already existed file. 66 | data_type = args.data_type 67 | speech_dir = "mini_data/test_speech" 68 | f = "{0:<16} {1:<16} {2:<16}" 69 | print(f.format("0", "Noise", "PESQ")) 70 | f1 = open(data_type + '_pesq_results.csv', 'w') 71 | f1.write("%s\t%s\n"%("audio_id", "PESQ")) 72 | # Calculate PESQ of all enhaced speech. 
73 | if data_type=="DM": 74 | enh_speech_dir = os.path.join("workspace", "enh_wavs", "test", "mixdb") 75 | elif data_type=="IRM": 76 | enh_speech_dir = os.path.join("workspace", "enh_wavs", "test", "mask_mixdb") 77 | elif data_type=="CRN": 78 | enh_speech_dir = os.path.join("workspace", "enh_wavs", "test", "crn_mixdb") 79 | elif data_type=="PHASE": 80 | enh_speech_dir = os.path.join("workspace", "enh_wavs", "test", "phase_spec_clean_mixdb") 81 | elif data_type=="VOLUME": 82 | enh_speech_dir = os.path.join("workspace", "enh_wavs", "test", "volume_mixdb") 83 | elif data_type=="NOISE": 84 | enh_speech_dir = os.path.join("workspace" ,'mixed_audios','spectrogram','test','mixdb') 85 | names = os.listdir(enh_speech_dir) 86 | for (cnt, na) in enumerate(names): 87 | enh_path = os.path.join(enh_speech_dir, na) 88 | enh_audio, fs = soundfile.read(enh_path) 89 | speech_na = na.split('.')[0] 90 | speech_path = os.path.join(speech_dir, "%s.WAV" % speech_na) 91 | speech_audio, fs = soundfile.read(speech_path) 92 | #alpha = 1. / np.max(np.abs(speech_audio)) 93 | #speech_audio *=alpha 94 | pesq_ = pypesq(16000, speech_audio, enh_audio, 'wb') 95 | print(f.format(cnt, na, pesq_)) 96 | f1.write("%s\t%f\n"%(na, pesq_)) 97 | # Call executable PESQ tool. 98 | #cmd = ' '.join(["./pesq", speech_path, enh_path, "+16000"]) 99 | #os.system(cmd) 100 | os.system("mv %s_pesq_results.csv ./pesq_result/%s_pesq_results.csv"%(data_type, data_type)) 101 | 102 | 103 | def get_stats(args): 104 | """Calculate stats of PESQ. 
105 | """ 106 | data_type = args.data_type 107 | pesq_path = "./pesq_result/"+ data_type+ "_pesq_results.csv" 108 | with open(pesq_path, 'rb') as f: 109 | reader = csv.reader(f, delimiter='\t') 110 | lis = list(reader) 111 | 112 | pesq_dict = {} 113 | for i1 in xrange(1, len(lis) - 1): 114 | li = lis[i1] 115 | na = li[0] 116 | pesq = float(li[1]) 117 | noise_type = na.split('.')[1] 118 | if noise_type not in pesq_dict.keys(): 119 | pesq_dict[noise_type] = [pesq] 120 | else: 121 | pesq_dict[noise_type].append(pesq) 122 | out_csv_path ='./pesq_result/'+ data_type +'_pesq_differentnoise.csv' 123 | csv_file = open(out_csv_path, 'w') 124 | avg_list, std_list = [], [] 125 | f = "{0:<16} {1:<16}" 126 | print(f.format("Noise", "PESQ")) 127 | csv_file.write("%s\t%s\n"%("Noise", "PESQ")) 128 | print("---------------------------------") 129 | csv_file.write("----------------\t-----------------\n") 130 | for noise_type in pesq_dict.keys(): 131 | pesqs = pesq_dict[noise_type] 132 | avg_pesq = np.mean(pesqs) 133 | std_pesq = np.std(pesqs) 134 | avg_list.append(avg_pesq) 135 | std_list.append(std_pesq) 136 | print(f.format(noise_type, "%.2f +- %.2f" % (avg_pesq, std_pesq))) 137 | csv_file.write("%s\t%s\n"%(noise_type, "%.2f +- %.2f" % (avg_pesq, std_pesq))) 138 | print("---------------------------------") 139 | csv_file.write("----------------\t-----------------\n") 140 | print(f.format("Avg.", "%.2f +- %.2f" % (np.mean(avg_list), np.mean(std_list)))) 141 | csv_file.write("%s\t%s\n"%("Avg.", "%.2f +- %.2f" % (np.mean(avg_list), np.mean(std_list)))) 142 | csv_file.close() 143 | 144 | 145 | 146 | 147 | 148 | def get_snr_stats(args): 149 | 150 | data_type = args.data_type 151 | pesq_path = os.path.join("pesq_result", data_type + "_pesq_results.csv") 152 | with open(pesq_path, 'rb') as f: 153 | reader = csv.reader(f, delimiter='\t') 154 | pesq_lis = list(reader) 155 | 156 | pesq_lis[0].append("SNR") 157 | pesq_title = pesq_lis[0] 158 | pesq_lis = pesq_lis[:-1] 159 | csv_path = 
os.path.join("workspace", "mixture_csvs", "test_1hour_even.csv") 160 | with open(csv_path, 'rb') as f: 161 | reader = csv.reader(f, delimiter='\t') 162 | csv_lis = list(reader) 163 | 164 | count = 0 165 | for csv_name in csv_lis[1:]: 166 | if data_type=="NOISE": 167 | csv_na = csv_name[0].split(".")[0] + "." + csv_name[1].split(".")[0]+ "."+csv_name[-1] + "db.wav" 168 | else: 169 | csv_na = csv_name[0].split(".")[0] + "." + csv_name[1].split(".")[0]+ "."+csv_name[-1] + "db.enh.wav" 170 | for pesq_name in pesq_lis[1:]: 171 | if csv_na == pesq_name[0]: 172 | count+=1 173 | pesq_name.append(csv_name[-1]) 174 | break 175 | 176 | pesq_dict = {} 177 | for i1 in xrange(1, len(pesq_lis)): 178 | li = pesq_lis[i1] 179 | na = li[0] 180 | pesq = float(li[1][0:4]) 181 | snr = float(li[-1]) 182 | snr_key = snr 183 | if snr_key not in pesq_dict.keys(): 184 | pesq_dict[snr_key] = [pesq] 185 | else: 186 | pesq_dict[snr_key].append(pesq) 187 | 188 | out_csv_path = os.path.join( "pesq_result", data_type + "_snr_results.csv") 189 | create_folder(os.path.dirname(out_csv_path)) 190 | csv_file = open(out_csv_path, 'w') 191 | avg_list, std_list = [], [] 192 | sample_sum = 0 193 | f = "{0:<16} {1:<16} {2:<16}" 194 | print(f.format("SNR", "PESQ", "SAMPLE_NUM")) 195 | csv_file.write("%s\t%s\t%s\n"%("SNR", "PESQ", "SAMPLE_NUM")) 196 | csv_file.flush() 197 | print("---------------------------------") 198 | for snr_type in sorted(pesq_dict.keys()): 199 | pesqs = pesq_dict[snr_type] 200 | sample_num = len(pesqs) 201 | sample_sum+=sample_num 202 | avg_pesq = np.mean(pesqs) 203 | std_pesq = np.std(pesqs) 204 | avg_list.append(avg_pesq) 205 | std_list.append(std_pesq) 206 | print(f.format(snr_type, "%.2f +- %.2f" % (avg_pesq, std_pesq), sample_num)) 207 | csv_file.write("%s\t%s\t%s\n"%(snr_type, "%.2f +- %.2f" % (avg_pesq, std_pesq), sample_num)) 208 | csv_file.flush() 209 | 210 | print("---------------------------------") 211 | print(f.format("Avg.", "%.2f +- %.2f" % (np.mean(avg_list), 
np.mean(std_list)), sample_sum)) 212 | csv_file.write("%s\t%s\t%s\n"%("Avg.", "%.2f +- %.2f" % (np.mean(avg_list), np.mean(std_list)), "%d"%sample_sum)) 213 | csv_file.close() 214 | 215 | 216 | 217 | 218 | 219 | def calculate_stoi(args): 220 | workspace = "workspace" 221 | speech_dir = "mini_data/test_speech" 222 | # Calculate PESQ of all enhaced speech. 223 | enh_speech_dir = os.path.join(workspace, "enh_wavs", "test", "mixdb") 224 | #enh_speech_dir = "/data00/wangjinchao/sednn-master/mixture2clean_dnn/workspace/mixed_audios/spectrogram/test/mixdb" 225 | # enh_speech_dir = os.path.join(workspace ,'mixed_audios','spectrogram','test','mixdb') 226 | names = os.listdir(enh_speech_dir) 227 | f = open("IRM_stoi.txt", "w") 228 | f.write("%s\t%s\n"%("speech_id", "stoi")) 229 | f.flush() 230 | for (cnt, na) in enumerate(names): 231 | print(cnt, na) 232 | enh_path = os.path.join(enh_speech_dir, na) 233 | speech_na = na.split('.')[0] 234 | speech_path = os.path.join(speech_dir, "%s.WAV" % speech_na) 235 | speech_audio, fs = read_audio(speech_path, 16000) 236 | enhance_audio, fs = read_audio(enh_path, 16000) 237 | if len(speech_audio)>len(enhance_audio): 238 | speech_audio = speech_audio[:len(enhance_audio)] 239 | else: 240 | enhance_audio = enhance_audio[:len(speech_audio)] 241 | stoi_value = stoi(speech_audio, enhance_audio, fs, extended = False) 242 | f.write("%s\t%f\n"%(na, stoi_value)) 243 | f.flush() 244 | f.close() 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | def get_stoi_stats(args): 253 | stoi_path = "./stoi_result/IRM_stoi.txt" 254 | with open(stoi_path, 'rb') as f: 255 | reader = csv.reader(f, delimiter='\t') 256 | lis = list(reader) 257 | 258 | stoi_dict = {} 259 | for i1 in xrange(1, len(lis) - 1): 260 | li = lis[i1] 261 | na = li[0] 262 | stoi = float(li[1]) 263 | noise_type = na.split('.')[1] 264 | if noise_type not in stoi_dict.keys(): 265 | stoi_dict[noise_type] = [stoi] 266 | else: 267 | stoi_dict[noise_type].append(stoi) 268 | #out_csv_path 
='./stoi_result/gvdm_enhance.csv' 269 | #csv_file = open(out_csv_path, 'w') 270 | avg_list, std_list = [], [] 271 | f = "{0:<16} {1:<16}" 272 | print(f.format("Noise", "STOI")) 273 | #csv_file.write("%s\t%s\n"%("Noise", "stoi")) 274 | print("---------------------------------") 275 | #csv_file.write("----------------\t-----------------\n") 276 | for noise_type in stoi_dict.keys(): 277 | stois = stoi_dict[noise_type] 278 | avg_stoi = np.mean(stois) 279 | std_stoi = np.std(stois) 280 | avg_list.append(avg_stoi) 281 | std_list.append(std_stoi) 282 | print(f.format(noise_type, "%.5f +- %.5f" % (avg_stoi, std_stoi))) 283 | #csv_file.write("%s\t%s\n"%(noise_type, "%.2f +- %.2f" % (avg_stoi, std_stoi))) 284 | print("---------------------------------") 285 | #csv_file.write("----------------\t-----------------\n") 286 | print(f.format("Avg.", "%.2f +- %.2f" % (np.mean(avg_list), np.mean(std_list)))) 287 | 288 | 289 | 290 | if __name__ == '__main__': 291 | parser = argparse.ArgumentParser() 292 | subparsers = parser.add_subparsers(dest='mode') 293 | 294 | parser_plot_training_stat = subparsers.add_parser('plot_training_stat') 295 | parser_plot_training_stat.add_argument('--workspace', type=str, required=True) 296 | parser_plot_training_stat.add_argument('--tr_snr', type=float, required=True) 297 | parser_plot_training_stat.add_argument('--bgn_iter', type=int, required=True) 298 | parser_plot_training_stat.add_argument('--fin_iter', type=int, required=True) 299 | parser_plot_training_stat.add_argument('--interval_iter', type=int, required=True) 300 | 301 | parser_calculate_pesq = subparsers.add_parser('calculate_pesq') 302 | parser_calculate_pesq.add_argument('--data_type', type=str, required=True) 303 | 304 | parser_get_stats = subparsers.add_parser('get_stats') 305 | parser_get_stats.add_argument('--data_type', type=str, required=True) 306 | 307 | parser_get_snr_stats = subparsers.add_parser('get_snr_stats') 308 | parser_get_snr_stats.add_argument('--data_type', type=str, 
required=True) 309 | 310 | 311 | 312 | args = parser.parse_args() 313 | 314 | if args.mode == 'plot_training_stat': 315 | plot_training_stat(args) 316 | elif args.mode == 'calculate_pesq': 317 | calculate_pesq(args) 318 | elif args.mode == 'get_stats': 319 | get_stats(args) 320 | elif args.mode == 'get_snr_stats': 321 | get_snr_stats(args) 322 | else: 323 | raise Exception("Error!") 324 | -------------------------------------------------------------------------------- /gv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | gv_ref_dependent = np.array([0.8188001 , 0.7655154 , 1.0443362 , 1.291341 , 1.3091451 , 4 | 1.3863533 , 1.3579204 , 1.3562269 , 1.3332679 , 1.3779601 , 5 | 1.3638331 , 1.4141914 , 1.4132004 , 1.414608 , 1.4191983 , 6 | 1.3779674 , 1.3744586 , 1.4065691 , 1.405165 , 1.3589902 , 7 | 1.3695769 , 1.3828198 , 1.399761 , 1.4110584 , 1.4237365 , 8 | 1.4146262 , 1.4150872 , 1.4020557 , 1.4402065 , 1.4247525 , 9 | 1.4152107 , 1.3777906 , 1.4041775 , 1.411573 , 1.4258041 , 10 | 1.4248255 , 1.4343295 , 1.4284252 , 1.3958426 , 1.3800949 , 11 | 1.394799 , 1.402656 , 1.3995781 , 1.3867273 , 1.4019246 , 12 | 1.394503 , 1.3876013 , 1.3906093 , 1.3923353 , 1.3908792 , 13 | 1.3651602 , 1.3789821 , 1.3817782 , 1.3878808 , 1.3868887 , 14 | 1.389586 , 1.3882133 , 1.3948598 , 1.3833323 , 1.3911697 , 15 | 1.3947376 , 1.3785598 , 1.3657677 , 1.3754646 , 1.3744026 , 16 | 1.36841 , 1.3738396 , 1.375986 , 1.3782787 , 1.3705876 , 17 | 1.3561313 , 1.363172 , 1.3721641 , 1.3663605 , 1.3701444 , 18 | 1.3718685 , 1.3587731 , 1.3583094 , 1.3632051 , 1.3683681 , 19 | 1.3819396 , 1.3825235 , 1.378892 , 1.3761448 , 1.3808253 , 20 | 1.3743024 , 1.367832 , 1.3641973 , 1.3663458 , 1.369809 , 21 | 1.371535 , 1.3641069 , 1.363354 , 1.3653663 , 1.3578664 , 22 | 1.3501805 , 1.3377979 , 1.3453208 , 1.3447514 , 1.3466262 , 23 | 1.3516669 , 1.3419527 , 1.3322309 , 1.3304617 , 1.3314892 , 24 | 1.3222749 , 1.3076648 , 1.3175845 , 
1.3237734 , 1.3146265 , 25 | 1.3085129 , 1.3097675 , 1.3060361 , 1.299763 , 1.2958938 , 26 | 1.2963424 , 1.2883214 , 1.2881285 , 1.2870046 , 1.2888812 , 27 | 1.2778481 , 1.2760473 , 1.2680486 , 1.2644651 , 1.2633371 , 28 | 1.2600574 , 1.2674776 , 1.2619113 , 1.25404 , 1.2484775 , 29 | 1.2528795 , 1.2445921 , 1.2449573 , 1.2370106 , 1.240662 , 30 | 1.2343256 , 1.2296497 , 1.2207483 , 1.2245104 , 1.212012 , 31 | 1.2099534 , 1.2040404 , 1.2014705 , 1.2012196 , 1.1975276 , 32 | 1.1931353 , 1.1944716 , 1.1941463 , 1.1930957 , 1.1830707 , 33 | 1.1817104 , 1.1773063 , 1.1705128 , 1.1806594 , 1.1794373 , 34 | 1.175316 , 1.1757798 , 1.1782918 , 1.1770912 , 1.1753559 , 35 | 1.1691241 , 1.1691626 , 1.1616837 , 1.1592903 , 1.1525471 , 36 | 1.148833 , 1.1445248 , 1.1463698 , 1.1432943 , 1.1372362 , 37 | 1.1345378 , 1.1331203 , 1.1327978 , 1.1356372 , 1.1281763 , 38 | 1.117315 , 1.1229038 , 1.1331227 , 1.129955 , 1.1205344 , 39 | 1.1168914 , 1.1162447 , 1.1205385 , 1.1221027 , 1.1183283 , 40 | 1.1176765 , 1.1073152 , 1.1065495 , 1.1066844 , 1.1020577 , 41 | 1.0956546 , 1.0937659 , 1.0824373 , 1.0914868 , 1.0957388 , 42 | 1.0990036 , 1.0980628 , 1.1037108 , 1.0973698 , 1.0961391 , 43 | 1.0953025 , 1.09513 , 1.093008 , 1.0896668 , 1.0927784 , 44 | 1.0900792 , 1.0936061 , 1.0935822 , 1.0972129 , 1.0939381 , 45 | 1.0888202 , 1.0845745 , 1.0836582 , 1.0842501 , 1.0809636 , 46 | 1.0757244 , 1.076439 , 1.0760363 , 1.0668286 , 1.0531492 , 47 | 1.0502294 , 1.0589144 , 1.0721456 , 1.0731709 , 1.0684367 , 48 | 1.0632014 , 1.0599935 , 1.0586678 , 1.0569472 , 1.0625534 , 49 | 1.0626838 , 1.0648353 , 1.0663067 , 1.06597 , 1.0638473 , 50 | 1.0639621 , 1.0637795 , 1.0606909 , 1.0582322 , 1.0517532 , 51 | 1.0480362 , 1.0479565 , 1.0435289 , 1.0371186 , 1.0334498 , 52 | 1.0291612 , 1.0249708 , 1.0198319 , 1.0156595 , 1.0108268 , 53 | 1.0073904 , 1.0044657 , 1.0021349 , 1.0046784 , 1.0016526 , 54 | 0.99999756, 0.9964546 ]) 55 | 56 | 57 | 58 | gv_est_dependent = np.array([0.18696505, 0.34280822, 
0.72116977, 1.0536233 , 1.0966913 , 59 | 1.1637543 , 1.1548996 , 1.1676413 , 1.1474363 , 1.171958 , 60 | 1.1660068 , 1.2174896 , 1.2271285 , 1.2325336 , 1.2390236 , 61 | 1.2042515 , 1.2004554 , 1.2231534 , 1.2235614 , 1.1827441 , 62 | 1.1899538 , 1.1991223 , 1.2111914 , 1.2175452 , 1.2235477 , 63 | 1.213116 , 1.2148174 , 1.1969019 , 1.221351 , 1.2065829 , 64 | 1.2010144 , 1.1644497 , 1.1922417 , 1.1955268 , 1.2043128 , 65 | 1.1993589 , 1.2122489 , 1.2041621 , 1.1712927 , 1.1558433 , 66 | 1.1670535 , 1.1739376 , 1.171306 , 1.1575273 , 1.1754738 , 67 | 1.1698086 , 1.1619811 , 1.1717362 , 1.1703048 , 1.1663839 , 68 | 1.146894 , 1.1596279 , 1.1625977 , 1.1698763 , 1.1666222 , 69 | 1.1692064 , 1.1689862 , 1.1767648 , 1.1620858 , 1.1627572 , 70 | 1.167898 , 1.160154 , 1.1409459 , 1.1539472 , 1.1525996 , 71 | 1.1464965 , 1.144994 , 1.147802 , 1.1467297 , 1.1443332 , 72 | 1.1305172 , 1.1382952 , 1.1508311 , 1.1470501 , 1.1399091 , 73 | 1.1442221 , 1.1377715 , 1.1442518 , 1.1429965 , 1.1489736 , 74 | 1.1564287 , 1.157903 , 1.1531717 , 1.1507055 , 1.1558954 , 75 | 1.148154 , 1.1464262 , 1.140695 , 1.1445204 , 1.1501266 , 76 | 1.1467679 , 1.1315886 , 1.1395706 , 1.138759 , 1.1322058 , 77 | 1.1181237 , 1.11835 , 1.11936 , 1.1205423 , 1.1255524 , 78 | 1.1306192 , 1.1268992 , 1.113448 , 1.1130816 , 1.1138029 , 79 | 1.1083379 , 1.0993805 , 1.1055454 , 1.1122226 , 1.1078554 , 80 | 1.1074036 , 1.100761 , 1.1006709 , 1.1015122 , 1.09618 , 81 | 1.0973667 , 1.0900112 , 1.085913 , 1.081041 , 1.0914049 , 82 | 1.0785384 , 1.0716164 , 1.0673122 , 1.0675014 , 1.0644028 , 83 | 1.0625845 , 1.0691746 , 1.0601567 , 1.0490003 , 1.0468317 , 84 | 1.0543352 , 1.0443738 , 1.0391475 , 1.0349274 , 1.043713 , 85 | 1.0387459 , 1.0371169 , 1.0293257 , 1.0260344 , 1.0175334 , 86 | 1.0159734 , 1.0076747 , 1.0029306 , 1.0055224 , 0.9981165 , 87 | 0.9971783 , 0.99703634, 0.9997423 , 0.9974254 , 0.98594177, 88 | 0.9810835 , 0.9804084 , 0.9802749 , 0.9853521 , 0.98604727, 89 | 0.9816301 , 0.97491074, 
0.9801819 , 0.9790348 , 0.98185617, 90 | 0.97771597, 0.9796749 , 0.9697705 , 0.96665776, 0.9637393 , 91 | 0.9561936 , 0.94966286, 0.9454328 , 0.94490314, 0.94937634, 92 | 0.9478694 , 0.9447655 , 0.94176424, 0.9403081 , 0.94139105, 93 | 0.9293294 , 0.93296844, 0.9450051 , 0.94217265, 0.9318225 , 94 | 0.9312917 , 0.93434596, 0.9384418 , 0.93287736, 0.9366608 , 95 | 0.9325247 , 0.9249665 , 0.9251077 , 0.92365205, 0.920737 , 96 | 0.9119053 , 0.90590674, 0.9011289 , 0.915063 , 0.9160822 , 97 | 0.92288536, 0.91581845, 0.92356586, 0.9192137 , 0.918826 , 98 | 0.9247057 , 0.9222291 , 0.91558635, 0.90998685, 0.9195295 , 99 | 0.9112585 , 0.9219293 , 0.91975945, 0.918718 , 0.91701937, 100 | 0.917412 , 0.90922654, 0.90436333, 0.9087872 , 0.9071029 , 101 | 0.90249467, 0.9022592 , 0.90647477, 0.89404655, 0.8869796 , 102 | 0.8787871 , 0.8882476 , 0.90208715, 0.9047103 , 0.9043998 , 103 | 0.90141785, 0.8983854 , 0.89462155, 0.90119785, 0.9028099 , 104 | 0.90848804, 0.91339225, 0.9102626 , 0.9124032 , 0.9146076 , 105 | 0.9150387 , 0.92100835, 0.9140667 , 0.91219926, 0.91785675, 106 | 0.92215157, 0.9197757 , 0.9274899 , 0.9233789 , 0.9288496 , 107 | 0.92639154, 0.9255196 , 0.9201056 , 0.9202775 , 0.9202095 , 108 | 0.92024976, 0.9177347 , 0.9155735 , 0.921838 , 0.91567504, 109 | 0.9140105 , 0.85075736] ) -------------------------------------------------------------------------------- /main_crn.py: -------------------------------------------------------------------------------- 1 | from main_dnn import * 2 | from keras.layers import Reshape, Conv2D, BatchNormalization, ZeroPadding2D, Lambda 3 | from keras.layers import Input, Concatenate, LSTM, Conv2DTranspose, Cropping2D, ELU 4 | import keras 5 | import tensorflow as tf 6 | import time 7 | import os 8 | import config as cfg 9 | import prepare_data as pp_data 10 | from spectrogram_to_wave import * 11 | 12 | 13 | def parser_function(serialized_example): 14 | features = tf.parse_single_example(serialized_example, 15 | features={ 16 | 
'x': tf.FixedLenFeature([], tf.string), 17 | 'y': tf.FixedLenFeature([], tf.string) 18 | }) 19 | x = tf.reshape(tf.decode_raw(features['x'], tf.float32), [11, 161]) 20 | y = tf.reshape(tf.decode_raw(features['y'], tf.float32), [11, 161]) 21 | return x, y 22 | 23 | 24 | def load_tfrecord(batch, repeat, data_path): 25 | dataset = tf.data.TFRecordDataset(data_path) 26 | dataset = dataset.map(parser_function) 27 | dataset = dataset.shuffle(10240) 28 | dataset = dataset.batch(batch) 29 | dataset = dataset.repeat(repeat) 30 | iterator = dataset.make_one_shot_iterator() 31 | tr_x, tr_y = iterator.get_next() 32 | return tr_x, tr_y 33 | 34 | 35 | 36 | def pad_with_border(x, n_pad): 37 | """Pad the begin and finish of spectrogram with border frame value. 38 | """ 39 | x_pad_list = [x[0:1]] * n_pad + [x] 40 | return np.concatenate(x_pad_list, axis=0) 41 | 42 | 43 | def mat_2d_to_3d(x, agg_num, hop): 44 | """Segment 2D array to 3D segments. 45 | """ 46 | # Pad to at least one block. 47 | len_x, n_in = x.shape 48 | if (len_x < agg_num): 49 | x = np.concatenate((x, np.zeros((agg_num - len_x, n_in)))) 50 | # Segment 2d to 3d. 51 | len_x = len(x) 52 | i1 = 0 53 | x3d = [] 54 | while (i1 + agg_num <= len_x): 55 | x3d.append(x[i1 : i1 + agg_num]) 56 | i1 += hop 57 | return np.array(x3d) 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | def inference(args): 66 | workspace = "workspace" 67 | n_concat = 11 68 | iter = 50000 69 | n_window = 320 70 | n_overlap = 160 71 | fs = 16000 72 | # Load model. 73 | model_path = os.path.join(workspace, "models", "crn_mixdb", "md_%diters.h5" % iter) 74 | model = load_model(model_path, custom_objects={'keras': keras}) 75 | # Load test data. 76 | feat_dir = os.path.join(workspace, "features", "spectrogram", "test", "crn_mixdb") 77 | #feat_dir = os.path.join(workspace, "features", "spectrogram", "train", "office_mixdb") 78 | names = os.listdir(feat_dir) 79 | for (cnt, na) in enumerate(names): 80 | # Load feature. 
81 | feat_path = os.path.join(feat_dir, na) 82 | data = cPickle.load(open(feat_path, 'rb')) 83 | [mixed_cmplx_x, speech_x, noise_x, alpha, na] = data 84 | mixed_x = np.abs(mixed_cmplx_x) 85 | # Process data. 86 | n_pad = (n_concat - 1) 87 | #mixed_x = pad_with_border(mixed_x, n_pad) 88 | # Cut input spectrogram to 3D segments with n_concat. 89 | mixed_x_3d = pp_data.mat_2d_to_3d(mixed_x, agg_num=n_concat, hop=11)#[100, 7, 257] 90 | #mixed_x = pad_with_border(mixed_x, n_pad) 91 | #mixed_x_3d = mat_2d_to_3d(mixed_x, agg_num=n_concat, hop=1) 92 | # Predict. 93 | w, h, l = mixed_x_3d.shape 94 | pred = model.predict(mixed_x_3d) 95 | pred_sp = np.reshape(pred, [w*h, l]) 96 | mixed_cmplx_x = mixed_cmplx_x[:w*h, :] 97 | #pred_sp = pred[:, -1, :] 98 | print(cnt, na) 99 | if False: 100 | fig, axs = plt.subplots(3,1, sharex=False) 101 | axs[0].matshow(mixed_x.T, origin='lower', aspect='auto', cmap='jet') 102 | axs[1].matshow(speech_x.T, origin='lower', aspect='auto', cmap='jet') 103 | axs[2].matshow(pred_sp.T, origin='lower', aspect='auto', cmap='jet') 104 | axs[0].set_title("%ddb mixture log spectrogram" % int(1)) 105 | axs[1].set_title("Clean speech log spectrogram") 106 | axs[2].set_title("Enhanced speech log spectrogram") 107 | for j1 in range(3): 108 | axs[j1].xaxis.tick_bottom() 109 | plt.tight_layout() 110 | plt.show() 111 | # Recover enhanced wav. 112 | #pred_sp = np.exp(pred) 113 | #pred_sp = pred 114 | s = recover_wav(pred_sp, mixed_cmplx_x, n_overlap, np.hamming) 115 | s *= np.sqrt((np.hamming(n_window)**2).sum()) # Scaler for compensate the amplitude 116 | # Write out enhanced wav. 117 | out_path = os.path.join(workspace, "enh_wavs", "test", "crn_mixdb", "%s.enh.wav" % na) 118 | pp_data.create_folder(os.path.dirname(out_path)) 119 | pp_data.write_audio(out_path, s, fs) 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | def train_tfrecords(args): 129 | lr = args.lr 130 | # Load data. 
131 | t1 = time.time() 132 | tr_hdf5_dir = os.path.join("workspace", "tfrecords", "train", "crn_mixdb") 133 | tr_hdf5_names = os.listdir(tr_hdf5_dir) 134 | tr_path_list = [os.path.join(tr_hdf5_dir, i) for i in tr_hdf5_names] 135 | te_hdf5_path = os.path.join("workspace", "packed_features", "spectrogram", "test", "crn_mixdb" , "data.h5") 136 | (te_x, te_y) = pp_data.load_hdf5(te_hdf5_path) 137 | print("test.h5 loaded ! ! !") 138 | train_path = os.path.join("workspace", "packed_features", "spectrogram", "train", "crn_mixdb" , "data.h5") 139 | (tr_x, tr_y) = pp_data.load_hdf5(train_path) 140 | print("train.h5 loaded ! ! !") 141 | batch_size = 1024 142 | 143 | # Scale data. 144 | t1 = time.time() 145 | 146 | input_x = Input(shape = (11, 161)) 147 | reshape_x = Reshape((1, 11, 161), input_shape = (11, 161))(input_x) 148 | l1_input = ZeroPadding2D(padding = ((1, 0), (0, 0)), data_format = "channels_first")(reshape_x) 149 | l1 = Conv2D(filters=16,kernel_size=(2,3),strides=(1,2), activation=None 150 | , data_format="channels_first", padding = "valid")(l1_input) 151 | l1 = BatchNormalization()(l1) 152 | l1 = ELU()(l1) 153 | 154 | l2_input = ZeroPadding2D(padding = ((1, 0), (0, 0)), data_format = "channels_first")(l1) 155 | l2 = Conv2D(filters=32,kernel_size=(2,3),strides=(1,2), activation=None 156 | , data_format="channels_first" , padding = "valid")(l2_input) 157 | l2 = BatchNormalization()(l2) 158 | l2 = ELU()(l2) 159 | 160 | l3_input = ZeroPadding2D(padding = ((1, 0), (0, 0)), data_format = "channels_first")(l2) 161 | l3 = Conv2D(filters=64,kernel_size=(2,3),strides=(1,2), activation=None 162 | , data_format="channels_first", padding = "valid")(l3_input) 163 | l3 = BatchNormalization()(l3) 164 | l3 = ELU()(l3) 165 | 166 | l4_input = ZeroPadding2D(padding = ((1, 0), (0, 0)), data_format = "channels_first")(l3) 167 | l4 = Conv2D(filters=128,kernel_size=(2,3),strides=(1,2), activation=None 168 | , data_format="channels_first", padding = "valid")(l4_input) 169 | l4 = 
BatchNormalization()(l4) 170 | l4 = ELU()(l4) 171 | 172 | l5_input = ZeroPadding2D(padding = ((1, 0), (0, 0)), data_format = "channels_first")(l4) 173 | l5 = Conv2D(filters=256,kernel_size=(2,3),strides=(1,2), activation=None 174 | , data_format="channels_first", padding = "valid")(l5_input) 175 | l5 = BatchNormalization()(l5) 176 | l5 = ELU()(l5) 177 | 178 | reshape_x2 = Reshape((11, 4*256), input_shape = (11, 4, 256))(l5) 179 | lstm1 = LSTM(units = 4*256, activation = 'tanh', return_sequences = True)(reshape_x2) 180 | lstm2 = LSTM(units = 4*256, activation = 'tanh', return_sequences = True)(lstm1) 181 | reshape_x3 = Reshape((256, 11, 4), input_shape = (11, 4*256))(lstm2) 182 | 183 | 184 | l8_input = Concatenate(axis = 1)([reshape_x3, l5]) 185 | l8 = Conv2DTranspose(filters=128,kernel_size=(2,3),strides=(1,2), activation=None 186 | , data_format="channels_first", padding = "valid")(l8_input) 187 | l8 = Cropping2D(cropping = ((1, 0), (0, 0)), data_format = "channels_first")(l8) 188 | l8 = BatchNormalization()(l8) 189 | l8 = ELU()(l8) 190 | 191 | 192 | l9_input = Concatenate(axis = 1)([l8, l4]) 193 | l9 = Conv2DTranspose(filters=64,kernel_size=(2,3),strides=(1,2), activation=None 194 | , data_format="channels_first", padding = "valid")(l9_input) 195 | l9 = Cropping2D(cropping = ((1, 0), (0, 0)), data_format = "channels_first")(l9) 196 | l9 = BatchNormalization()(l9) 197 | l9 = ELU()(l9) 198 | 199 | 200 | l10_input = Concatenate(axis = 1)([l9, l3]) 201 | l10 = Conv2DTranspose(filters=32,kernel_size=(2,3),strides=(1,2), activation=None 202 | , data_format="channels_first", padding = "valid")(l10_input) 203 | l10 = Cropping2D(cropping = ((1, 0), (0, 0)), data_format = "channels_first")(l10) 204 | l10 = BatchNormalization()(l10) 205 | l10 = ELU()(l10) 206 | 207 | l11_input = Concatenate(axis = 1)([l10, l2]) 208 | l11_input = ZeroPadding2D(padding = ((0, 0), (1, 0)), data_format = "channels_first")(l11_input) 209 | l11 = 
Conv2DTranspose(filters=16,kernel_size=(2,3),strides=(1,2), activation=None 210 | , data_format="channels_first", padding = "valid")(l11_input) 211 | l11 = Cropping2D(cropping = ((1, 0), (1, 0)), data_format = "channels_first")(l11) 212 | l11 = BatchNormalization()(l11) 213 | l11 = ELU()(l11) 214 | 215 | l12_input = Concatenate(axis = 1)([l11, l1]) 216 | l12 = Conv2DTranspose(filters=1,kernel_size=(2,3),strides=(1,2), activation=None 217 | , data_format="channels_first", padding = "valid")(l12_input) 218 | l12 = Cropping2D(cropping = ((1, 0), (0, 0)), data_format = "channels_first")(l12) 219 | l12 = Reshape((11, 161), input_shape = (11, 161, 1))(l12) 220 | l12 = Lambda(lambda x: keras.activations.softplus(x))(l12) 221 | #l12 = keras.layers.Lambda(lambda x:keras.activations.softplus(x))(l12) 222 | model = keras.models.Model(inputs = [input_x], outputs = l8) 223 | model.summary() 224 | #lr = 5e-5 225 | #model_path = os.path.join(workspace, "models", "crn_mixdb", "md_%diters.h5" % 3935) 226 | #model = load_model(model_path, custom_objects={'tf': tf}) 227 | #model = multi_gpu_model(model, 4) 228 | 229 | model.compile(loss='mean_absolute_error', 230 | optimizer=Adam(lr=lr, beta_1 = 0.9)) 231 | print("model is built ! ! !") 232 | # Data generator. 233 | eval_tr_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 234 | eval_te_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 235 | 236 | # Directories for saving models and training stats 237 | model_dir = os.path.join("workspace", "models", "crn_mixdb") 238 | pp_data.create_folder(model_dir) 239 | 240 | stats_dir = os.path.join("workspace", "training_stats", "crn_mixdb") 241 | pp_data.create_folder(stats_dir) 242 | 243 | # Print loss before training. 
244 | iter = 0 245 | print("start calculating initial loss.......") 246 | tr_loss = eval(model, eval_tr_gen, tr_x, tr_y) 247 | te_loss = eval(model, eval_te_gen, te_x, te_y) 248 | print("Iteration: %d, tr_loss: %f, te_loss: %f" % (iter, tr_loss, te_loss)) 249 | # Save out training stats. 250 | stat_dict = {'iter': iter, 251 | 'tr_loss': tr_loss, 252 | 'te_loss': te_loss, } 253 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 254 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 255 | # Train. 256 | sess = tf.Session() 257 | x, y = load_tfrecord(batch = batch_size, repeat = 100000, data_path = tr_path_list) 258 | t1 = time.time() 259 | for count in range(1000000000): 260 | [tr_x, tr_y] = sess.run([x, y]) 261 | loss = model.train_on_batch(tr_x, tr_y) 262 | iter += 1 263 | # Validate and save training stats. 264 | if iter % 1000 == 0: 265 | tr_loss = eval(model, eval_tr_gen, tr_x, tr_y) 266 | te_loss = eval(model, eval_te_gen, te_x, te_y) 267 | #te_loss = tr_loss 268 | print("Iteration: %d, tr_loss: %f, te_loss: %f" % (iter, tr_loss, te_loss)) 269 | # Save out training stats. 270 | stat_dict = {'iter': iter, 271 | 'tr_loss': tr_loss, 272 | 'te_loss': te_loss, } 273 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 274 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 275 | # Save model. 
276 | if iter % 5000 == 0: 277 | model_path = os.path.join(model_dir, "md_%diters.h5" % iter) 278 | model.save(model_path) 279 | print("Saved model to %s" % model_path) 280 | print("Training time: %s s" % (time.time() - t1,)) 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | if __name__ == '__main__': 289 | parser = argparse.ArgumentParser() 290 | subparsers = parser.add_subparsers(dest='mode') 291 | 292 | parser_train = subparsers.add_parser('train') 293 | parser_train.add_argument('--lr', default = 1e-4, type=float, required=False) 294 | 295 | parser_inference = subparsers.add_parser('inference') 296 | parser_inference.add_argument('--lr', default = 1e-4, type=float, required=False) 297 | parser_inference.add_argument('--iteration', type=int, default=50000) 298 | args = parser.parse_args() 299 | 300 | if args.mode=="inference": 301 | inference(args) 302 | else: 303 | train(args) 304 | 305 | 306 | 307 | 308 | 309 | -------------------------------------------------------------------------------- /main_dnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Summary: Train, inference and evaluate speech enhancement. 3 | Author: Qiuqiang Kong 4 | Created: 2017.12.22 5 | Modified: - 6 | """ 7 | import numpy as np 8 | import os 9 | import pickle 10 | import cPickle 11 | import h5py 12 | import argparse 13 | import time 14 | import glob 15 | #import matplotlib.pyplot as #plt 16 | import tensorflow as tf 17 | import prepare_data as pp_data 18 | import config as cfg 19 | from data_generator import DataGenerator 20 | from data_generator import DataGenerator_h5py 21 | from spectrogram_to_wave import recover_wav 22 | from keras.utils import multi_gpu_model 23 | from keras.models import Sequential 24 | from keras.layers import Dense, Dropout, Flatten 25 | from keras.optimizers import Adam 26 | from keras.models import load_model 27 | 28 | 29 | def eval(model, gen, x, y): 30 | """Validation function. 
31 | 32 | Args: 33 | model: keras model. 34 | gen: object, data generator. 35 | x: 3darray, input, (n_segs, n_concat, n_freq) 36 | y: 2darray, target, (n_segs, n_freq) 37 | """ 38 | pred_all, y_all = [], [] 39 | 40 | # Inference in mini batch. 41 | for (batch_x, batch_y) in gen.generate(xs=[x], ys=[y]): 42 | pred = model.predict(batch_x) 43 | pred_all.append(pred) 44 | y_all.append(batch_y) 45 | 46 | # Concatenate mini batch prediction. 47 | pred_all = np.concatenate(pred_all, axis=0) 48 | y_all = np.concatenate(y_all, axis=0) 49 | 50 | # Compute loss. 51 | loss = pp_data.np_mean_absolute_error(y_all, pred_all) 52 | return loss 53 | 54 | 55 | 56 | 57 | 58 | def eval_h5py(model, gen, path_list): 59 | """Validation function. 60 | 61 | Args: 62 | model: keras model. 63 | gen: object, data generator. 64 | x: 3darray, input, (n_segs, n_concat, n_freq) 65 | y: 2darray, target, (n_segs, n_freq) 66 | """ 67 | pred_all, y_all = [], [] 68 | # Inference in mini batch. 69 | for (batch_x, batch_y) in gen.generate(path_list): 70 | pred = model.predict(batch_x) 71 | pred_all.append(pred) 72 | y_all.append(batch_y) 73 | # Concatenate mini batch prediction. 74 | pred_all = np.concatenate(pred_all, axis=0) 75 | y_all = np.concatenate(y_all, axis=0) 76 | # Compute loss. 77 | loss = pp_data.np_mean_absolute_error(y_all, pred_all) 78 | return loss 79 | 80 | 81 | 82 | 83 | 84 | 85 | def train(args): 86 | 87 | """Train the neural network. Write out model every several iterations. 88 | 89 | Args: 90 | workspace: str, path of workspace. 91 | tr_snr: float, training SNR. 92 | te_snr: float, testing SNR. 93 | lr: float, learning rate. 94 | """ 95 | print(args) 96 | workspace = args.workspace 97 | tr_snr = args.tr_snr 98 | te_snr = args.te_snr 99 | lr = args.lr 100 | data_type = "IRM" 101 | # Load data. 
102 | t1 = time.time() 103 | # tr_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb", "data.h5") 104 | if data_type=="DM": 105 | tr_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb", "data.h5") 106 | te_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "test", "mixdb" , "data.h5") 107 | else: 108 | tr_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mask_mixdb", "data.h5") 109 | te_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "test", "mask_mixdb" , "data.h5") 110 | (tr_x, tr_y) = pp_data.load_hdf5(tr_hdf5_path) 111 | (te_x, te_y) = pp_data.load_hdf5(te_hdf5_path) 112 | print(tr_x.shape, tr_y.shape) 113 | print(te_x.shape, te_y.shape) 114 | print("Load data time: %s s" % (time.time() - t1,)) 115 | 116 | batch_size = 2048 117 | print("%d iterations / epoch" % int(tr_x.shape[0] / batch_size)) 118 | 119 | # Scale data. 120 | if True: 121 | t1 = time.time() 122 | scaler_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb", "scaler.p") 123 | scaler = pickle.load(open(scaler_path, 'rb')) 124 | tr_x = pp_data.scale_on_3d(tr_x, scaler) 125 | te_x = pp_data.scale_on_3d(te_x, scaler) 126 | if data_type=="DM": 127 | tr_y = pp_data.scale_on_2d(tr_y, scaler) 128 | te_y = pp_data.scale_on_2d(te_y, scaler) 129 | print("Scale data time: %s s" % (time.time() - t1,)) 130 | 131 | # Debug plot. 
132 | if False: 133 | #plt.matshow(tr_x[0 : 1000, 0, :].T, origin='lower', aspect='auto', cmap='jet') 134 | #plt.show() 135 | pause 136 | 137 | # Build model 138 | (_, n_concat, n_freq) = tr_x.shape 139 | n_hid = 2048 140 | 141 | model = Sequential() 142 | model.add(Flatten(input_shape=(n_concat, n_freq))) 143 | model.add(Dense(n_hid, activation='elu')) 144 | model.add(Dropout(0.2)) 145 | model.add(Dense(n_hid, activation='elu')) 146 | model.add(Dropout(0.2)) 147 | model.add(Dense(n_hid, activation='elu')) 148 | model.add(Dropout(0.2)) 149 | model.add(Dense(n_freq, activation='linear')) 150 | model.summary() 151 | 152 | model.compile(loss='mean_absolute_error', 153 | optimizer=Adam(lr=lr, beta_1 = 0.9)) 154 | 155 | 156 | 157 | 158 | 159 | # Data generator. 160 | tr_gen = DataGenerator(batch_size=batch_size, type='train') 161 | eval_te_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 162 | eval_tr_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 163 | 164 | # Directories for saving models and training stats 165 | if data_type =="DM": 166 | model_dir = os.path.join(workspace, "models", "mixdb") 167 | stats_dir = os.path.join(workspace, "training_stats", "mixdb") 168 | else: 169 | model_dir = os.path.join(workspace, "models", "mask_mixdb") 170 | stats_dir = os.path.join(workspace, "training_stats", "mask_mixdb") 171 | pp_data.create_folder(model_dir) 172 | pp_data.create_folder(stats_dir) 173 | 174 | # Print loss before training. 175 | iter = 0 176 | tr_loss = eval(model, eval_tr_gen, tr_x, tr_y) 177 | te_loss = eval(model, eval_te_gen, te_x, te_y) 178 | print("Iteration: %d, tr_loss: %f, te_loss: %f" % (iter, tr_loss, te_loss)) 179 | 180 | # Save out training stats. 181 | stat_dict = {'iter': iter, 182 | 'tr_loss': tr_loss, 183 | 'te_loss': te_loss, } 184 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 185 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 186 | 187 | # Train. 
188 | t1 = time.time() 189 | 190 | for (batch_x, batch_y) in tr_gen.generate(xs=[tr_x], ys=[tr_y]): 191 | loss = model.train_on_batch(batch_x, batch_y) 192 | iter += 1 193 | # Validate and save training stats. 194 | if iter % 1000 == 0: 195 | tr_loss = eval(model, eval_tr_gen, tr_x, tr_y) 196 | te_loss = eval(model, eval_te_gen, te_x, te_y) 197 | print("Iteration: %d, tr_loss: %f, te_loss: %f" % (iter, tr_loss, te_loss)) 198 | # Save out training stats. 199 | stat_dict = {'iter': iter, 200 | 'tr_loss': tr_loss, 201 | 'te_loss': te_loss, } 202 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 203 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 204 | # Save model. 205 | if iter % 5000 == 0: 206 | model_path = os.path.join(model_dir, "md_%diters.h5" % iter) 207 | model.save(model_path) 208 | print("Saved model to %s" % model_path) 209 | if iter == 70001: 210 | break 211 | print("Training time: %s s" % (time.time() - t1,)) 212 | 213 | def inference(args): 214 | """Inference all test data, write out recovered wavs to disk. 215 | 216 | Args: 217 | workspace: str, path of workspace. 218 | tr_snr: float, training SNR. 219 | te_snr: float, testing SNR. 220 | n_concat: int, number of frames to concatenta, should equal to n_concat 221 | in the training stage. 222 | iter: int, iteration of model to load. 223 | visualize: bool, plot enhanced spectrogram for debug. 224 | """ 225 | print(args) 226 | workspace = args.workspace 227 | tr_snr = args.tr_snr 228 | te_snr = args.te_snr 229 | n_concat = args.n_concat 230 | iter = args.iteration 231 | data_type = 'IRM' 232 | 233 | n_window = cfg.n_window 234 | n_overlap = cfg.n_overlap 235 | fs = cfg.sample_rate 236 | scale = True 237 | 238 | # Load model. 
239 | if data_type=="DM": 240 | model_path = os.path.join(workspace, "models", "mixdb", "md_%diters.h5" % 120000) 241 | else: 242 | model_path = os.path.join(workspace, "models", "mask_mixdb", "md_%diters.h5" % 265000) 243 | model = load_model(model_path) 244 | 245 | # Load scaler. 246 | scaler_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb", "scaler.p") 247 | scaler = pickle.load(open(scaler_path, 'rb')) 248 | 249 | # Load test data. 250 | feat_dir = os.path.join(workspace, "features", "spectrogram", "test", "mixdb") 251 | names = os.listdir(feat_dir) 252 | 253 | for (cnt, na) in enumerate(names): 254 | # Load feature. 255 | feat_path = os.path.join(feat_dir, na) 256 | data = cPickle.load(open(feat_path, 'rb')) 257 | [mixed_cmplx_x, speech_x, noise_x, alpha, na] = data 258 | mixed_x = np.abs(mixed_cmplx_x) 259 | if data_type == "IRM": 260 | mixed_x = speech_x + noise_x 261 | mixed_x1 = speech_x + noise_x 262 | # Process data. 263 | n_pad = (n_concat - 1) / 2 264 | mixed_x = pp_data.pad_with_border(mixed_x, n_pad) 265 | mixed_x = pp_data.log_sp(mixed_x) 266 | 267 | # Scale data. 268 | if scale: 269 | mixed_x = pp_data.scale_on_2d(mixed_x, scaler) 270 | 271 | # Cut input spectrogram to 3D segments with n_concat. 272 | mixed_x_3d = pp_data.mat_2d_to_3d(mixed_x, agg_num=n_concat, hop=1) 273 | 274 | # Predict. 275 | pred = model.predict(mixed_x_3d) 276 | if data_type =="IRM": 277 | pred_sp = pred * mixed_x1 278 | print(cnt, na) 279 | 280 | # Inverse scale. 281 | if data_type =="DM": 282 | pred = pp_data.inverse_scale_on_2d(pred, scaler) 283 | pred_sp = np.exp(pred) 284 | # Debug plot. 285 | # Recover enhanced wav. 286 | s = recover_wav(pred_sp, mixed_cmplx_x, n_overlap, np.hamming) 287 | s *= np.sqrt((np.hamming(n_window)**2).sum()) # Scaler for compensate the amplitude 288 | # change after spectrogram and IFFT. 289 | # Write out enhanced wav. 
290 | if data_type=="DM": 291 | out_path = os.path.join(workspace, "enh_wavs", "test", "mixdb", "%s.enh.wav" % na) 292 | else: 293 | out_path = os.path.join(workspace, "enh_wavs", "test", "mask_mixdb", "%s.enh.wav" % na) 294 | pp_data.create_folder(os.path.dirname(out_path)) 295 | pp_data.write_audio(out_path, s, fs) 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | def continue_train(args): 304 | workspace = args.workspace 305 | lr = args.lr 306 | iter = args.iteration 307 | data_type = "IRM" 308 | # Load model. 309 | if data_type =="DM": 310 | model_path = os.path.join(workspace, "models", "mixdb", "md_%diters.h5" % iter) 311 | else: 312 | model_path = os.path.join(workspace, "models", "mask_mixdb", "md_%diters.h5" % iter) 313 | model = load_model(model_path) 314 | #model = multi_gpu_model(model, 4) 315 | model.compile(loss='mean_absolute_error', 316 | optimizer=Adam(lr=lr, beta_1 = 0.2)) 317 | # Load data. 318 | t1 = time.time() 319 | if data_type=="DM": 320 | tr_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb", "data.h5") 321 | te_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "test", "mixdb" , "data.h5") 322 | else: 323 | tr_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mask_mixdb", "data.h5") 324 | te_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "test", "mask_mixdb" , "data.h5") 325 | tr_hdf5_dir = os.path.join(workspace, "packed_features", "spectrogram", "train", "mask_mixdb") 326 | tr_hdf5_names = os.listdir(tr_hdf5_dir) 327 | tr_hdf5_names = [i for i in tr_hdf5_names if i.endswith(".h5")] 328 | tr_path_list = [os.path.join(tr_hdf5_dir, i) for i in tr_hdf5_names] 329 | (tr_x, tr_y) = pp_data.load_hdf5(tr_hdf5_path) 330 | (te_x, te_y) = pp_data.load_hdf5(te_hdf5_path) 331 | print(tr_x.shape, tr_y.shape) 332 | print(te_x.shape, te_y.shape) 333 | print("Load data time: %s s" % (time.time() - t1,)) 334 | batch_size = 2048 335 | print("%d 
iterations / epoch" % int(tr_x.shape[0] / batch_size)) 336 | # Scale data. 337 | if True: 338 | t1 = time.time() 339 | scaler_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb", "scaler.p") 340 | scaler = pickle.load(open(scaler_path, 'rb')) 341 | tr_x = pp_data.scale_on_3d(tr_x, scaler) 342 | te_x = pp_data.scale_on_3d(te_x, scaler) 343 | if data_type=="DM": 344 | tr_y = pp_data.scale_on_2d(tr_y, scaler) 345 | te_y = pp_data.scale_on_2d(te_y, scaler) 346 | print("Scale data time: %s s" % (time.time() - t1,)) 347 | #scaler_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb", "scaler.p") 348 | #scaler = pickle.load(open(scaler_path, 'rb')) 349 | tr_gen = DataGenerator(batch_size=batch_size, type='train') 350 | eval_te_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 351 | eval_tr_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 352 | #tr_gen = DataGenerator_h5py(batch_size=batch_size, type='train', scaler = scaler) 353 | #eval_te_gen = DataGenerator_h5py(batch_size=batch_size, type='test', te_max_iter=100, scaler =scaler) 354 | #eval_tr_gen = DataGenerator_h5py(batch_size=batch_size, type='test', te_max_iter=100, scaler =scaler) 355 | # Directories for saving models and training stats 356 | if data_type=="DM": 357 | model_dir = os.path.join(workspace, "models", "chinese_mixdb", "continue") 358 | stats_dir = os.path.join(workspace, "training_stats", "chinese_mixdb", "continue") 359 | else: 360 | model_dir = os.path.join(workspace, "models", "mask_mixdb", "continue") 361 | stats_dir = os.path.join(workspace, "training_stats", "mask_mixdb", "continue") 362 | pp_data.create_folder(model_dir) 363 | pp_data.create_folder(stats_dir) 364 | # Print loss before training. 
365 | iter = 0 366 | tr_loss = eval(model, eval_tr_gen, tr_x, tr_y) 367 | te_loss = eval(model, eval_te_gen, te_x, te_y) 368 | #tr_loss = eval_h5py(model, eval_tr_gen, tr_path_list) 369 | #te_loss = eval_h5py(model, eval_te_gen, [te_hdf5_path]) 370 | print("Iteration: %d, tr_loss: %f, te_loss: %f" % (iter, tr_loss, te_loss)) 371 | # Save out training stats. 372 | stat_dict = {'iter': iter, 373 | 'tr_loss': tr_loss, 374 | 'te_loss': te_loss, } 375 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 376 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 377 | # Train. 378 | t1 = time.time() 379 | for (batch_x, batch_y) in tr_gen.generate(xs=[tr_x], ys=[tr_y]): 380 | #for (batch_x, batch_y) in tr_gen.generate(tr_path_list): 381 | loss = model.train_on_batch(batch_x, batch_y) 382 | iter += 1 383 | # Validate and save training stats. 384 | if iter % 500 == 0: 385 | tr_loss = eval(model, eval_tr_gen, tr_x, tr_y) 386 | te_loss = eval(model, eval_te_gen, te_x, te_y) 387 | #tr_loss = eval_h5py(model, eval_tr_gen, tr_path_list) 388 | #te_loss = eval_h5py(model, eval_te_gen, [te_hdf5_path]) 389 | print("Iteration: %d, tr_loss: %f, te_loss: %f" % (iter, tr_loss, te_loss)) 390 | # Save out training stats. 391 | stat_dict = {'iter': iter, 392 | 'tr_loss': tr_loss, 393 | 'te_loss': te_loss, } 394 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 395 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 396 | # Save model. 
397 | if iter % 5000 == 0: 398 | model_path = os.path.join(model_dir, "md_%diters.h5" % iter) 399 | model.save(model_path) 400 | print("Saved model to %s" % model_path) 401 | if iter == 100001: 402 | break 403 | print("Training time: %s s" % (time.time() - t1,)) 404 | 405 | 406 | 407 | 408 | def parser_function(serialized_example): 409 | features = tf.parse_single_example(serialized_example, 410 | features={ 411 | 'x': tf.FixedLenFeature([], tf.string), 412 | 'y': tf.FixedLenFeature([], tf.string) 413 | }) 414 | x = tf.reshape(tf.decode_raw(features['x'], tf.float32), [7, 257]) 415 | y = tf.reshape(tf.decode_raw(features['y'], tf.float32), [257,]) 416 | return x, y 417 | 418 | 419 | def load_tfrecord(batch, repeat, data_path): 420 | dataset = tf.data.TFRecordDataset(data_path) 421 | dataset = dataset.map(parser_function) 422 | dataset = dataset.shuffle(buffer_size = 1024*10, seed = 10) 423 | dataset = dataset.batch(batch) 424 | dataset = dataset.repeat(repeat) 425 | iterator = dataset.make_one_shot_iterator() 426 | tr_x, tr_y = iterator.get_next() 427 | return tr_x, tr_y 428 | 429 | 430 | 431 | 432 | 433 | def continue_train_tfrecord(): 434 | workspace = "workspace" 435 | lr = 1e-5 436 | iter = 220000 437 | data_type = "IRM" 438 | # Load model. 439 | if data_type =="DM": 440 | model_path = os.path.join(workspace, "models", "elu_mixdb", "md_%diters.h5" % iter) 441 | else: 442 | model_path = os.path.join(workspace, "models", "mask_mixdb", "md_%diters.h5" % iter) 443 | 444 | model = load_model(model_path) 445 | #model = multi_gpu_model(model, 4) 446 | model.compile(loss='mean_absolute_error', 447 | optimizer=Adam(lr=lr, beta_1 = 0.2)) 448 | # Load data. 
449 | if data_type=="DM": 450 | tr_hdf5_dir = os.path.join(workspace, "tfrecords", "train", "mixdb") 451 | tr_hdf5_names = os.listdir(tr_hdf5_dir) 452 | tr_path_list = [os.path.join(tr_hdf5_dir, i) for i in tr_hdf5_names] 453 | te_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "test", "mixdb", "data.h5") 454 | else: 455 | tr_hdf5_dir = os.path.join(workspace, "tfrecords", "train", "mask_mixdb") 456 | tr_hdf5_names = os.listdir(tr_hdf5_dir) 457 | tr_path_list = [os.path.join(tr_hdf5_dir, i) for i in tr_hdf5_names] 458 | te_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "test", "mask_mixdb", "data.h5") 459 | 460 | #(tr_x1, tr_y1) = pp_data.load_hdf5("workspace/packed_features/spectrogram/train/mixdb/data100000.h5") 461 | (te_x, te_y) = pp_data.load_hdf5(te_hdf5_path) 462 | t1 = time.time() 463 | scaler_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb", "scaler.p") 464 | scaler = pickle.load(open(scaler_path, 'rb')) 465 | te_x = pp_data.scale_on_3d(te_x, scaler) 466 | #tr_x1 = pp_data.scale_on_3d(tr_x1, scaler) 467 | if data_type=="DM": 468 | te_y = pp_data.scale_on_2d(te_y, scaler) 469 | tr_y1 = pp_data.scale_on_2d(tr_y1, scaler) 470 | print("Scale data time: %s s" % (time.time() - t1,)) 471 | # Directories for saving models and training stats 472 | if data_type=="DM": 473 | model_dir = os.path.join(workspace, "models", "elu_mixdb", "continue") 474 | stats_dir = os.path.join(workspace, "training_stats", "elu_mixdb", "continue") 475 | else: 476 | model_dir = os.path.join(workspace, "models", "mask_mixdb", "continue") 477 | stats_dir = os.path.join(workspace, "training_stats", "mask_mixdb", "continue") 478 | 479 | pp_data.create_folder(model_dir) 480 | pp_data.create_folder(stats_dir) 481 | # Print loss before training. 
482 | 483 | batch_size = 1024*4 484 | #eval_tr_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 485 | eval_te_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 486 | #tr_loss = eval(model, eval_tr_gen, tr_x1, tr_y1) 487 | tr_loss = 0 488 | te_loss = eval(model, eval_te_gen, te_x, te_y) 489 | print("Iteration: %d, tr_loss: %f, te_loss: %f" % (iter, tr_loss, te_loss)) 490 | # Save out training stats. 491 | stat_dict = {'iter': iter, 492 | 'tr_loss': tr_loss, 493 | 'te_loss': te_loss, } 494 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 495 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 496 | # Train. 497 | sess = tf.Session() 498 | x, y = load_tfrecord(batch = batch_size, repeat = 100000, data_path = tr_path_list) 499 | t1 = time.time() 500 | for count in range(1000000000): 501 | [tr_x, tr_y] = sess.run([x, y]) 502 | loss = model.train_on_batch(tr_x, tr_y) 503 | iter += 1 504 | # Validate and save training stats. 505 | if iter % 1000 == 0: 506 | #tr_loss = eval(model, eval_tr_gen, tr_x1, tr_y1) 507 | te_loss = eval(model, eval_te_gen, te_x, te_y) 508 | print("Iteration: %d, tr_loss: %f, te_loss: %f" % (iter, tr_loss, te_loss)) 509 | # Save out training stats. 510 | stat_dict = {'iter': iter, 511 | 'tr_loss': tr_loss, 512 | 'te_loss': te_loss, } 513 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 514 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 515 | # Save model. 
516 | if iter % 5000 == 0: 517 | model_path = os.path.join(model_dir, "md_%diters.h5" % iter) 518 | model.save(model_path) 519 | print("Saved model to %s" % model_path) 520 | if iter == 100001: 521 | break 522 | print("Training time: %s s" % (time.time() - t1,)) 523 | 524 | 525 | 526 | 527 | 528 | 529 | 530 | 531 | if __name__ == '__main__': 532 | parser = argparse.ArgumentParser() 533 | subparsers = parser.add_subparsers(dest='mode') 534 | 535 | parser_train = subparsers.add_parser('train') 536 | parser_train.add_argument('--workspace', type=str, required=True) 537 | parser_train.add_argument('--tr_snr', type=float, required=True) 538 | parser_train.add_argument('--te_snr', type=float, required=True) 539 | parser_train.add_argument('--lr', type=float, required=True) 540 | 541 | parser_inference = subparsers.add_parser('inference') 542 | parser_inference.add_argument('--workspace', type=str, required=True) 543 | parser_inference.add_argument('--tr_snr', type=float, required=True) 544 | parser_inference.add_argument('--te_snr', type=float, required=True) 545 | parser_inference.add_argument('--n_concat', type=int, required=True) 546 | parser_inference.add_argument('--iteration', type=int, required=True) 547 | parser_inference.add_argument('--visualize', action='store_true', default=False) 548 | 549 | parser_calculate_pesq = subparsers.add_parser('calculate_pesq') 550 | parser_calculate_pesq.add_argument('--workspace', type=str, required=True) 551 | parser_calculate_pesq.add_argument('--speech_dir', type=str, required=True) 552 | parser_calculate_pesq.add_argument('--te_snr', type=float, required=True) 553 | 554 | args = parser.parse_args() 555 | 556 | if args.mode == 'train': 557 | train(args) 558 | elif args.mode == 'inference': 559 | inference(args) 560 | elif args.mode == 'calculate_pesq': 561 | calculate_pesq(args) 562 | else: 563 | raise Exception("Error!") 564 | 565 | 566 | 567 | 568 | 569 | 570 | 571 | 572 | 573 | 574 | 575 | 576 | 
-------------------------------------------------------------------------------- /main_dsn.py: -------------------------------------------------------------------------------- 1 | from main_dnn import * 2 | from keras.layers import Reshape, Input, Concatenate, Dense, Lambda 3 | import keras 4 | 5 | 6 | def lambda_slice(x, n): 7 | return x[:, n, :] 8 | 9 | def continue_train_tfrecord(): 10 | 11 | data_type = "IRM" 12 | workspace = "workspace" 13 | lr = 1e-4 14 | input_x = Input(shape = (7, 257)) 15 | l1_0 = Lambda(lambda_slice, arguments = {"n": 0})(input_x) 16 | l1 = Dense(1024, activation=None)(l1_0) 17 | l1_1 = Dense(10, activation='elu')(l1) 18 | l2_0 = Lambda(lambda_slice, arguments = {"n": 1})(input_x) 19 | l2_input = Concatenate(axis = -1)([l1_1, l2_0]) 20 | l2 = Dense(1024, activation=None)(l2_input) 21 | l2_1 = Dense(10, activation='elu')(l2) 22 | l3_0 = Lambda(lambda_slice, arguments = {"n": 2})(input_x) 23 | l3_input = Concatenate(axis = -1)([l1_1, l2_1, l3_0]) 24 | l3 = Dense(1024, activation=None)(l3_input) 25 | l3_1 = Dense(10, activation='elu')(l3) 26 | l4_0 = Lambda(lambda_slice, arguments = {"n": 3})(input_x) 27 | l4_input = Concatenate(axis = -1)([l1_1, l2_1, l3_1, l4_0]) 28 | l4 = Dense(1024, activation=None)(l4_input) 29 | l4_1 = Dense(10, activation='elu')(l4) 30 | l5_0 = Lambda(lambda_slice, arguments = {"n": 4})(input_x) 31 | l5_input = Concatenate(axis = -1)([l1_1, l2_1, l3_1, l4_1, l5_0]) 32 | l5 = Dense(1024, activation=None)(l5_input) 33 | l5_1 = Dense(10, activation='elu')(l5) 34 | l6_0 = Lambda(lambda_slice, arguments = {"n": 5})(input_x) 35 | l6_input = Concatenate(axis = -1)([l1_1, l2_1, l3_1, l4_1, l5_1, l6_0]) 36 | l6 = Dense(1024, activation=None)(l6_input) 37 | l6_1 = Dense(10, activation='elu')(l6) 38 | l7_0 = Lambda(lambda_slice, arguments = {"n": 6})(input_x) 39 | l7_input = Concatenate(axis = -1)([l1_1, l2_1, l3_1, l4_1, l5_1, l7_0]) 40 | outputs = Dense(257, activation=None)(l7_input) 41 | 42 | 43 | model = 
keras.models.Model(inputs = [input_x], outputs = outputs) 44 | model.compile(loss='mean_absolute_error', 45 | optimizer=Adam(lr=lr, beta_1 = 0.9)) 46 | # Load data. 47 | tr_hdf5_dir = os.path.join(workspace, "tfrecords", "train", "mixdb") 48 | tr_hdf5_names = os.listdir(tr_hdf5_dir) 49 | tr_path_list = [os.path.join(tr_hdf5_dir, i) for i in tr_hdf5_names] 50 | te_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "test", "mixdb", "data.h5") 51 | 52 | (tr_x1, tr_y1) = pp_data.load_hdf5("workspace/packed_features/spectrogram/train/mixdb/data100000.h5") 53 | (te_x, te_y) = pp_data.load_hdf5(te_hdf5_path) 54 | t1 = time.time() 55 | scaler_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb", "scaler.p") 56 | scaler = pickle.load(open(scaler_path, 'rb')) 57 | te_x = pp_data.scale_on_3d(te_x, scaler) 58 | tr_x1 = pp_data.scale_on_3d(tr_x1, scaler) 59 | te_y = pp_data.scale_on_2d(te_y, scaler) 60 | tr_y1 = pp_data.scale_on_2d(tr_y1, scaler) 61 | print("Scale data time: %s s" % (time.time() - t1,)) 62 | # Directories for saving models and training stats 63 | model_dir = os.path.join(workspace, "models", "dsn_mixdb", "continue") 64 | stats_dir = os.path.join(workspace, "training_stats", "elu_mixdb", "continue") 65 | 66 | pp_data.create_folder(model_dir) 67 | pp_data.create_folder(stats_dir) 68 | # Print loss before training. 69 | iter = 0 70 | batch_size = 1024 71 | eval_tr_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 72 | eval_te_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 73 | tr_loss = eval(model, eval_tr_gen, tr_x1, tr_y1) 74 | tr_loss = 0 75 | te_loss = eval(model, eval_te_gen, te_x, te_y) 76 | print("Iteration: %d, tr_loss: %f, te_loss: %f" % (iter, tr_loss, te_loss)) 77 | # Save out training stats. 
78 | stat_dict = {'iter': iter, 79 | 'tr_loss': tr_loss, 80 | 'te_loss': te_loss, } 81 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 82 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 83 | # Train. 84 | sess = tf.Session() 85 | x, y = load_tfrecord(batch = batch_size, repeat = 100000, data_path = tr_path_list) 86 | t1 = time.time() 87 | for count in range(1000000000): 88 | [tr_x, tr_y] = sess.run([x, y]) 89 | loss = model.train_on_batch(tr_x, tr_y) 90 | iter += 1 91 | # Validate and save training stats. 92 | if iter % 1000 == 0: 93 | tr_loss = eval(model, eval_tr_gen, tr_x1, tr_y1) 94 | te_loss = eval(model, eval_te_gen, te_x, te_y) 95 | print("Iteration: %d, tr_loss: %f, te_loss: %f" % (iter, tr_loss, te_loss)) 96 | # Save out training stats. 97 | stat_dict = {'iter': iter, 98 | 'tr_loss': tr_loss, 99 | 'te_loss': te_loss, } 100 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 101 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 102 | # Save model. 103 | if iter % 5000 == 0: 104 | model_path = os.path.join(model_dir, "md_%diters.h5" % iter) 105 | model.save(model_path) 106 | print("Saved model to %s" % model_path) 107 | if iter == 100001: 108 | break 109 | print("Training time: %s s" % (time.time() - t1,)) 110 | 111 | 112 | def inference(args): 113 | workspace = "workspace" 114 | n_concat = 7 115 | iter = args.iteration 116 | n_window = 512 117 | n_overlap = 256 118 | fs = 16000 119 | # Load model. 120 | model_path = os.path.join(workspace, "models", "dsn_mixdb", "md_%diters.h5" % iter) 121 | model = load_model(model_path) 122 | # Load test data. 123 | feat_dir = os.path.join(workspace, "features", "spectrogram", "test", "dsn_mixdb") 124 | #feat_dir = os.path.join(workspace, "features", "spectrogram", "train", "office_mixdb") 125 | names = os.listdir(feat_dir) 126 | for (cnt, na) in enumerate(names): 127 | # Load feature. 
128 | feat_path = os.path.join(feat_dir, na) 129 | data = cPickle.load(open(feat_path, 'rb')) 130 | [mixed_cmplx_x, speech_x, noise_x, alpha, na] = data 131 | mixed_x = np.abs(mixed_cmplx_x) 132 | # Process data. 133 | n_pad = (n_concat - 1) 134 | #mixed_x = pad_with_border(mixed_x, n_pad) 135 | # Cut input spectrogram to 3D segments with n_concat. 136 | mixed_x_3d = pp_data.mat_2d_to_3d(mixed_x, agg_num=n_concat, hop=1)#[100, 7, 257] 137 | #mixed_x = pad_with_border(mixed_x, n_pad) 138 | #mixed_x_3d = mat_2d_to_3d(mixed_x, agg_num=n_concat, hop=1) 139 | # Predict. 140 | w, h, l = mixed_x_3d.shape 141 | pred = model.predict(mixed_x_3d) 142 | #pred_sp = pred[:, -1, :] 143 | print(cnt, na) 144 | if False: 145 | pred_sp = np.load("pred_sp.npy") 146 | speech_x = np.load("speech_x.npy") 147 | mixed_x = np.load("mixed_x.npy") 148 | fig, axs = plt.subplots(3,1, sharex=False) 149 | axs[0].matshow(mixed_x.T, origin='lower', aspect='auto', cmap='jet') 150 | axs[1].matshow(speech_x.T, origin='lower', aspect='auto', cmap='jet') 151 | axs[2].matshow(pred_sp.T, origin='lower', aspect='auto', cmap='jet') 152 | axs[0].set_title("%ddb mixture log spectrogram" % int(1)) 153 | axs[1].set_title("Clean speech log spectrogram") 154 | axs[2].set_title("Enhanced speech log spectrogram") 155 | for j1 in range(3): 156 | axs[j1].xaxis.tick_bottom() 157 | plt.tight_layout() 158 | plt.show() 159 | # Recover enhanced wav. 160 | #pred_sp = np.exp(pred) 161 | #pred_sp = pred 162 | s = recover_wav(pred_sp, mixed_cmplx_x, n_overlap, np.hamming) 163 | s *= np.sqrt((np.hamming(n_window)**2).sum()) # Scaler for compensate the amplitude 164 | # Write out enhanced wav. 
165 | out_path = os.path.join(workspace, "enh_wavs", "test", "dsn_mixdb", "%s.enh.wav" % na) 166 | pp_data.create_folder(os.path.dirname(out_path)) 167 | pp_data.write_audio(out_path, s, fs) 168 | -------------------------------------------------------------------------------- /make_tfrecord.py: -------------------------------------------------------------------------------- 1 | """ 2 | Summary: Prepare data. 3 | Author: Qiuqiang Kong 4 | Created: 2017.12.22 5 | Modified: - 6 | """ 7 | import tensorflow as tf 8 | import librosa 9 | import os 10 | import soundfile 11 | import numpy as np 12 | import argparse 13 | import csv 14 | import time 15 | #import matplotlib.pyplot as plt 16 | from scipy import signal 17 | import pickle 18 | #import pickle as cPickle 19 | import cPickle 20 | import h5py 21 | from sklearn import preprocessing 22 | import random 23 | import prepare_data as pp_data 24 | import config as cfg 25 | 26 | 27 | def create_folder(fd): 28 | if not os.path.exists(fd): 29 | os.makedirs(fd) 30 | 31 | 32 | 33 | 34 | 35 | def parser_function(serialized_example): 36 | features = tf.parse_single_example(serialized_example, 37 | features={ 38 | 'x': tf.FixedLenFeature([], tf.string), 39 | 'y': tf.FixedLenFeature([], tf.string) 40 | }) 41 | x = tf.reshape(tf.decode_raw(features['x'], tf.float32), [7, 257]) 42 | y = tf.reshape(tf.decode_raw(features['y'], tf.float32), [257,]) 43 | return x, y 44 | 45 | 46 | def load_tfrecord(batch, repeat, data_path): 47 | dataset = tf.data.TFRecordDataset(data_path) 48 | dataset = dataset.map(parser_function) 49 | dataset = dataset.shuffle(random.randint(1, 100)) 50 | dataset = dataset.batch(batch) 51 | dataset = dataset.repeat(repeat) 52 | iterator = dataset.make_one_shot_iterator() 53 | tr_x, tr_y = iterator.get_next() 54 | return tr_x, tr_y 55 | 56 | 57 | 58 | 59 | def scale_on_2d(x2d, scaler): 60 | """Scale 2D array data. 
61 | """ 62 | return scaler.transform(x2d) 63 | 64 | def scale_on_3d(x3d, scaler): 65 | """Scale 3D array data. 66 | """ 67 | (n_segs, n_concat, n_freq) = x3d.shape 68 | x2d = x3d.reshape((n_segs * n_concat, n_freq)) 69 | x2d = scaler.transform(x2d) 70 | x3d = x2d.reshape((n_segs, n_concat, n_freq)) 71 | return x3d 72 | 73 | def inverse_scale_on_2d(x2d, scaler): 74 | """Inverse scale 2D array data. 75 | """ 76 | return x2d * scaler.scale_[None, :] + scaler.mean_[None, :] 77 | 78 | 79 | 80 | def load_hdf5(hdf5_path): 81 | """Load hdf5 data. 82 | """ 83 | with h5py.File(hdf5_path, 'r') as hf: 84 | x = hf.get('x') 85 | y = hf.get('y') 86 | x = np.array(x) # (n_segs, n_concat, n_freq) 87 | y = np.array(y) # (n_segs, n_freq) 88 | return x, y 89 | 90 | 91 | 92 | 93 | 94 | 95 | def tfrecord_handler(): 96 | workspace = "workspace" 97 | data_type = "IRM" 98 | if data_type=="DM": 99 | tr_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb", "data.h5") 100 | else: 101 | tr_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mask_mixdb", "data100000.h5") 102 | (tr_x, tr_y) = load_hdf5(tr_hdf5_path) 103 | scaler_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb", "scaler.p") 104 | scaler = pickle.load(open(scaler_path, 'rb')) 105 | tr_x = scale_on_3d(tr_x, scaler) 106 | if data_type=="DM": 107 | tr_y = scale_on_2d(tr_y, scaler) 108 | tfrecords_train_filename = 'workspace/tfrecords/train/mask_mixdb/data_chinese.tfrecords' 109 | create_folder(os.path.dirname(tfrecords_train_filename)) 110 | writer_train = tf.python_io.TFRecordWriter(tfrecords_train_filename) 111 | for i in range(tr_x.shape[0]): 112 | mixed_input = tr_x[i, :, :].astype(np.float32).tostring() 113 | label = tr_y[i, :].astype(np.float32).tostring() 114 | example = tf.train.Example(features=tf.train.Features( 115 | feature={ 116 | 'x': tf.train.Feature(bytes_list = tf.train.BytesList(value=[mixed_input])), 117 | 'y': 
tf.train.Feature(bytes_list = tf.train.BytesList(value=[label])) 118 | })) 119 | writer_train.write(example.SerializeToString()) 120 | if i % 100000 == 0: 121 | print(i) 122 | 123 | writer_train.close() 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | def mix_tfrecord(): 137 | tr_hdf5_dir = os.path.join("workspace", "tfrecords", "train", "crn_mixdb") 138 | tr_hdf5_names = os.listdir(tr_hdf5_dir) 139 | tr_path_list = [os.path.join(tr_hdf5_dir, i) for i in tr_hdf5_names] 140 | sess = tf.Session() 141 | x, y = load_tfrecord(batch = 1, repeat = 1, data_path = tr_path_list) 142 | tfrecords_train_filename = '/data00/wangjinchao/sednn-master/mixture2clean_dnn/workspace/tfrecords/train/mixdb/data_office.tfrecords' 143 | create_folder(os.path.dirname(tfrecords_train_filename)) 144 | writer_train = tf.python_io.TFRecordWriter(tfrecords_train_filename) 145 | try: 146 | while True: 147 | [tr_x, tr_y] = sess.run([x, y]) 148 | mixed_input = tr_x.astype(np.float32).tostring() 149 | label = tr_y.astype(np.float32).tostring() 150 | example = tf.train.Example(features=tf.train.Features( 151 | feature={ 152 | 'x': tf.train.Feature(bytes_list = tf.train.BytesList(value=[mixed_input])), 153 | 'y': tf.train.Feature(bytes_list = tf.train.BytesList(value=[label])) 154 | })) 155 | writer_train.write(example.SerializeToString()) 156 | except tf.errors.OutOfRangeError: 157 | writer_train.close() 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | def compute_gv(): 172 | mean_y = np.mean(tr_y) 173 | tmp_y= np.power((tr_y - mean_y), 2) 174 | gv_ref_independent = np.mean(tmp_y) 175 | 176 | mean_y = np.mean(tr_y, axis = 0) 177 | tmp_y = np.power((tr_y - mean_y), 2) 178 | gv_ref_dependent = np.mean(tmp_y, axis = 0) 179 | 180 | mean_pred_y = np.mean(pred_y) 181 | tmp_pred_y= np.power((pred_y - mean_pred_y), 2) 182 | gv_est_independent = np.mean(tmp_pred_y) 183 | 184 | mean_pred_y = np.mean(pred_y, axis = 0) 185 | tmp_pred_y = 
np.power((pred_y - mean_pred_y), 2) 186 | gv_est_dependent = np.mean(tmp_pred_y, axis = 0) 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | gv_ref_independent = 1.2445884 196 | gv_est_independent = 1.047566 197 | 198 | gv_ref_dependent = np.array([0.8188001 , 0.7655154 , 1.0443362 , 1.291341 , 1.3091451 , 199 | 1.3863533 , 1.3579204 , 1.3562269 , 1.3332679 , 1.3779601 , 200 | 1.3638331 , 1.4141914 , 1.4132004 , 1.414608 , 1.4191983 , 201 | 1.3779674 , 1.3744586 , 1.4065691 , 1.405165 , 1.3589902 , 202 | 1.3695769 , 1.3828198 , 1.399761 , 1.4110584 , 1.4237365 , 203 | 1.4146262 , 1.4150872 , 1.4020557 , 1.4402065 , 1.4247525 , 204 | 1.4152107 , 1.3777906 , 1.4041775 , 1.411573 , 1.4258041 , 205 | 1.4248255 , 1.4343295 , 1.4284252 , 1.3958426 , 1.3800949 , 206 | 1.394799 , 1.402656 , 1.3995781 , 1.3867273 , 1.4019246 , 207 | 1.394503 , 1.3876013 , 1.3906093 , 1.3923353 , 1.3908792 , 208 | 1.3651602 , 1.3789821 , 1.3817782 , 1.3878808 , 1.3868887 , 209 | 1.389586 , 1.3882133 , 1.3948598 , 1.3833323 , 1.3911697 , 210 | 1.3947376 , 1.3785598 , 1.3657677 , 1.3754646 , 1.3744026 , 211 | 1.36841 , 1.3738396 , 1.375986 , 1.3782787 , 1.3705876 , 212 | 1.3561313 , 1.363172 , 1.3721641 , 1.3663605 , 1.3701444 , 213 | 1.3718685 , 1.3587731 , 1.3583094 , 1.3632051 , 1.3683681 , 214 | 1.3819396 , 1.3825235 , 1.378892 , 1.3761448 , 1.3808253 , 215 | 1.3743024 , 1.367832 , 1.3641973 , 1.3663458 , 1.369809 , 216 | 1.371535 , 1.3641069 , 1.363354 , 1.3653663 , 1.3578664 , 217 | 1.3501805 , 1.3377979 , 1.3453208 , 1.3447514 , 1.3466262 , 218 | 1.3516669 , 1.3419527 , 1.3322309 , 1.3304617 , 1.3314892 , 219 | 1.3222749 , 1.3076648 , 1.3175845 , 1.3237734 , 1.3146265 , 220 | 1.3085129 , 1.3097675 , 1.3060361 , 1.299763 , 1.2958938 , 221 | 1.2963424 , 1.2883214 , 1.2881285 , 1.2870046 , 1.2888812 , 222 | 1.2778481 , 1.2760473 , 1.2680486 , 1.2644651 , 1.2633371 , 223 | 1.2600574 , 1.2674776 , 1.2619113 , 1.25404 , 1.2484775 , 224 | 1.2528795 , 1.2445921 , 1.2449573 , 1.2370106 , 
1.240662 , 225 | 1.2343256 , 1.2296497 , 1.2207483 , 1.2245104 , 1.212012 , 226 | 1.2099534 , 1.2040404 , 1.2014705 , 1.2012196 , 1.1975276 , 227 | 1.1931353 , 1.1944716 , 1.1941463 , 1.1930957 , 1.1830707 , 228 | 1.1817104 , 1.1773063 , 1.1705128 , 1.1806594 , 1.1794373 , 229 | 1.175316 , 1.1757798 , 1.1782918 , 1.1770912 , 1.1753559 , 230 | 1.1691241 , 1.1691626 , 1.1616837 , 1.1592903 , 1.1525471 , 231 | 1.148833 , 1.1445248 , 1.1463698 , 1.1432943 , 1.1372362 , 232 | 1.1345378 , 1.1331203 , 1.1327978 , 1.1356372 , 1.1281763 , 233 | 1.117315 , 1.1229038 , 1.1331227 , 1.129955 , 1.1205344 , 234 | 1.1168914 , 1.1162447 , 1.1205385 , 1.1221027 , 1.1183283 , 235 | 1.1176765 , 1.1073152 , 1.1065495 , 1.1066844 , 1.1020577 , 236 | 1.0956546 , 1.0937659 , 1.0824373 , 1.0914868 , 1.0957388 , 237 | 1.0990036 , 1.0980628 , 1.1037108 , 1.0973698 , 1.0961391 , 238 | 1.0953025 , 1.09513 , 1.093008 , 1.0896668 , 1.0927784 , 239 | 1.0900792 , 1.0936061 , 1.0935822 , 1.0972129 , 1.0939381 , 240 | 1.0888202 , 1.0845745 , 1.0836582 , 1.0842501 , 1.0809636 , 241 | 1.0757244 , 1.076439 , 1.0760363 , 1.0668286 , 1.0531492 , 242 | 1.0502294 , 1.0589144 , 1.0721456 , 1.0731709 , 1.0684367 , 243 | 1.0632014 , 1.0599935 , 1.0586678 , 1.0569472 , 1.0625534 , 244 | 1.0626838 , 1.0648353 , 1.0663067 , 1.06597 , 1.0638473 , 245 | 1.0639621 , 1.0637795 , 1.0606909 , 1.0582322 , 1.0517532 , 246 | 1.0480362 , 1.0479565 , 1.0435289 , 1.0371186 , 1.0334498 , 247 | 1.0291612 , 1.0249708 , 1.0198319 , 1.0156595 , 1.0108268 , 248 | 1.0073904 , 1.0044657 , 1.0021349 , 1.0046784 , 1.0016526 , 249 | 0.99999756, 0.9964546 ]) 250 | 251 | 252 | 253 | gv_est_dependent = np.array([0.18696505, 0.34280822, 0.72116977, 1.0536233 , 1.0966913 , 254 | 1.1637543 , 1.1548996 , 1.1676413 , 1.1474363 , 1.171958 , 255 | 1.1660068 , 1.2174896 , 1.2271285 , 1.2325336 , 1.2390236 , 256 | 1.2042515 , 1.2004554 , 1.2231534 , 1.2235614 , 1.1827441 , 257 | 1.1899538 , 1.1991223 , 1.2111914 , 1.2175452 , 1.2235477 , 258 | 
1.213116 , 1.2148174 , 1.1969019 , 1.221351 , 1.2065829 , 259 | 1.2010144 , 1.1644497 , 1.1922417 , 1.1955268 , 1.2043128 , 260 | 1.1993589 , 1.2122489 , 1.2041621 , 1.1712927 , 1.1558433 , 261 | 1.1670535 , 1.1739376 , 1.171306 , 1.1575273 , 1.1754738 , 262 | 1.1698086 , 1.1619811 , 1.1717362 , 1.1703048 , 1.1663839 , 263 | 1.146894 , 1.1596279 , 1.1625977 , 1.1698763 , 1.1666222 , 264 | 1.1692064 , 1.1689862 , 1.1767648 , 1.1620858 , 1.1627572 , 265 | 1.167898 , 1.160154 , 1.1409459 , 1.1539472 , 1.1525996 , 266 | 1.1464965 , 1.144994 , 1.147802 , 1.1467297 , 1.1443332 , 267 | 1.1305172 , 1.1382952 , 1.1508311 , 1.1470501 , 1.1399091 , 268 | 1.1442221 , 1.1377715 , 1.1442518 , 1.1429965 , 1.1489736 , 269 | 1.1564287 , 1.157903 , 1.1531717 , 1.1507055 , 1.1558954 , 270 | 1.148154 , 1.1464262 , 1.140695 , 1.1445204 , 1.1501266 , 271 | 1.1467679 , 1.1315886 , 1.1395706 , 1.138759 , 1.1322058 , 272 | 1.1181237 , 1.11835 , 1.11936 , 1.1205423 , 1.1255524 , 273 | 1.1306192 , 1.1268992 , 1.113448 , 1.1130816 , 1.1138029 , 274 | 1.1083379 , 1.0993805 , 1.1055454 , 1.1122226 , 1.1078554 , 275 | 1.1074036 , 1.100761 , 1.1006709 , 1.1015122 , 1.09618 , 276 | 1.0973667 , 1.0900112 , 1.085913 , 1.081041 , 1.0914049 , 277 | 1.0785384 , 1.0716164 , 1.0673122 , 1.0675014 , 1.0644028 , 278 | 1.0625845 , 1.0691746 , 1.0601567 , 1.0490003 , 1.0468317 , 279 | 1.0543352 , 1.0443738 , 1.0391475 , 1.0349274 , 1.043713 , 280 | 1.0387459 , 1.0371169 , 1.0293257 , 1.0260344 , 1.0175334 , 281 | 1.0159734 , 1.0076747 , 1.0029306 , 1.0055224 , 0.9981165 , 282 | 0.9971783 , 0.99703634, 0.9997423 , 0.9974254 , 0.98594177, 283 | 0.9810835 , 0.9804084 , 0.9802749 , 0.9853521 , 0.98604727, 284 | 0.9816301 , 0.97491074, 0.9801819 , 0.9790348 , 0.98185617, 285 | 0.97771597, 0.9796749 , 0.9697705 , 0.96665776, 0.9637393 , 286 | 0.9561936 , 0.94966286, 0.9454328 , 0.94490314, 0.94937634, 287 | 0.9478694 , 0.9447655 , 0.94176424, 0.9403081 , 0.94139105, 288 | 0.9293294 , 0.93296844, 0.9450051 , 
0.94217265, 0.9318225 , 289 | 0.9312917 , 0.93434596, 0.9384418 , 0.93287736, 0.9366608 , 290 | 0.9325247 , 0.9249665 , 0.9251077 , 0.92365205, 0.920737 , 291 | 0.9119053 , 0.90590674, 0.9011289 , 0.915063 , 0.9160822 , 292 | 0.92288536, 0.91581845, 0.92356586, 0.9192137 , 0.918826 , 293 | 0.9247057 , 0.9222291 , 0.91558635, 0.90998685, 0.9195295 , 294 | 0.9112585 , 0.9219293 , 0.91975945, 0.918718 , 0.91701937, 295 | 0.917412 , 0.90922654, 0.90436333, 0.9087872 , 0.9071029 , 296 | 0.90249467, 0.9022592 , 0.90647477, 0.89404655, 0.8869796 , 297 | 0.8787871 , 0.8882476 , 0.90208715, 0.9047103 , 0.9043998 , 298 | 0.90141785, 0.8983854 , 0.89462155, 0.90119785, 0.9028099 , 299 | 0.90848804, 0.91339225, 0.9102626 , 0.9124032 , 0.9146076 , 300 | 0.9150387 , 0.92100835, 0.9140667 , 0.91219926, 0.91785675, 301 | 0.92215157, 0.9197757 , 0.9274899 , 0.9233789 , 0.9288496 , 302 | 0.92639154, 0.9255196 , 0.9201056 , 0.9202775 , 0.9202095 , 303 | 0.92024976, 0.9177347 , 0.9155735 , 0.921838 , 0.91567504, 304 | 0.9140105 , 0.85075736] ) 305 | 306 | 307 | 308 | 309 | -------------------------------------------------------------------------------- /mini_data/test_noise/babble.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vanka0051/speech_enhancement/815c7f3a8f78344d206a58d48ce49d7e4ba657d9/mini_data/test_noise/babble.wav -------------------------------------------------------------------------------- /mini_data/test_noise/white.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vanka0051/speech_enhancement/815c7f3a8f78344d206a58d48ce49d7e4ba657d9/mini_data/test_noise/white.wav -------------------------------------------------------------------------------- /mini_data/test_speech/TEST_DR3_MJMP0_SA2.WAV: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Vanka0051/speech_enhancement/815c7f3a8f78344d206a58d48ce49d7e4ba657d9/mini_data/test_speech/TEST_DR3_MJMP0_SA2.WAV -------------------------------------------------------------------------------- /mini_data/test_speech/TEST_DR3_MJMP0_SI905.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vanka0051/speech_enhancement/815c7f3a8f78344d206a58d48ce49d7e4ba657d9/mini_data/test_speech/TEST_DR3_MJMP0_SI905.WAV -------------------------------------------------------------------------------- /mini_data/train_noise/n1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vanka0051/speech_enhancement/815c7f3a8f78344d206a58d48ce49d7e4ba657d9/mini_data/train_noise/n1.wav -------------------------------------------------------------------------------- /mini_data/train_noise/n43.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vanka0051/speech_enhancement/815c7f3a8f78344d206a58d48ce49d7e4ba657d9/mini_data/train_noise/n43.wav -------------------------------------------------------------------------------- /mini_data/train_speech/TRAIN_DR1_FCJF0_SA2.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vanka0051/speech_enhancement/815c7f3a8f78344d206a58d48ce49d7e4ba657d9/mini_data/train_speech/TRAIN_DR1_FCJF0_SA2.WAV -------------------------------------------------------------------------------- /mini_data/train_speech/TRAIN_DR1_FCJF0_SI648.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vanka0051/speech_enhancement/815c7f3a8f78344d206a58d48ce49d7e4ba657d9/mini_data/train_speech/TRAIN_DR1_FCJF0_SI648.WAV -------------------------------------------------------------------------------- /prepare_data.py: 
# --------------------------------------------------------------------------------
"""
Summary: Prepare data.
Author: Qiuqiang Kong
Created: 2017.12.22
Modified: -
"""
import librosa
import os
import soundfile
import numpy as np
import argparse
import csv
import time
#import matplotlib.pyplot as plt
from scipy import signal
import pickle
#import pickle as cPickle
import cPickle
import h5py
from sklearn import preprocessing
import random
import prepare_data as pp_data
import config as cfg


def create_folder(fd):
    """Create directory *fd* (including parents) if it does not exist yet."""
    if not os.path.exists(fd):
        os.makedirs(fd)


def read_audio(path, target_fs=None):
    """Read an audio file, down-mix to mono and optionally resample.

    Args:
      path: str, audio file path.
      target_fs: int | None, resample to this rate when it differs from the
        file's native rate.

    Returns:
      (audio, fs): 1darray samples and the sample rate they are at.
    """
    (audio, fs) = soundfile.read(path)
    if audio.ndim > 1:
        # Average channels to get a mono signal.
        audio = np.mean(audio, axis=1)
    if target_fs is not None and fs != target_fs:
        audio = librosa.resample(audio, orig_sr=fs, target_sr=target_fs)
        fs = target_fs
    return audio, fs


def write_audio(path, audio, sample_rate):
    """Write *audio* samples to *path* at *sample_rate*."""
    soundfile.write(file=path, data=audio, samplerate=sample_rate)


###
def create_mixture_csv(args):
    """Create a csv describing how each speech file is mixed with noise.

    A header row is written first; every following row is tab separated:
    [speech_name, noise_name, noise_onset, noise_offset, snr].

    Args (attributes of *args*):
      workspace: str, path of workspace.
      speech_dir: str, path of speech data.
      noise_dir: str, path of noise data.
      data_type: str, 'train' | 'test'.
      magnification: int, only used when data_type='train': number of noises
        selected to mix with each speech file. E.g., when magnification=3,
        4620 speech files will create 4620*3 mixtures. magnification must not
        be larger than the number of noise files.

    Raises:
      Exception: if data_type is neither 'train' nor 'test'.
    """
    workspace = args.workspace
    speech_dir = args.speech_dir
    noise_dir = args.noise_dir
    data_type = args.data_type
    magnification = args.magnification
    fs = cfg.sample_rate
    speech_names = [na for na in os.listdir(speech_dir) if na.lower().endswith(".wav")]
    noise_names = [na for na in os.listdir(noise_dir) if na.lower().endswith(".wav")]

    # Train: each speech is mixed with `magnification` random noises.
    # Test: each speech is mixed with every noise.
    if data_type == 'train':
        n_noise_per_speech = magnification
    elif data_type == 'test':
        n_noise_per_speech = len(noise_names)
    else:
        raise Exception("data_type must be train | test!")

    # SNRs (dB) are spread evenly over a fixed set of levels. Build at least
    # one value per mixture (ceiling division): truncating n / 8 produced a
    # list shorter than the number of mixtures, so snr[cnt] overran it.
    #snr = np.random.normal(25, 10, len(speech_names)*magnification)
    snr_levels = [-5, 0, 5, 10, 15, 20, 25, 30]
    n_mixtures = len(speech_names) * n_noise_per_speech
    n_repeat = (n_mixtures + len(snr_levels) - 1) // len(snr_levels)
    snr = snr_levels * n_repeat
    random.shuffle(snr)

    rs = np.random.RandomState(0)
    out_csv_path = os.path.join(workspace, "mixture_csvs", "%s_98hour_even.csv" % data_type)
    pp_data.create_folder(os.path.dirname(out_csv_path))
    cnt = 0
    with open(out_csv_path, 'w') as f:
        f.write("%s\t%s\t%s\t%s\t%s\n" % ("speech_name", "noise_name", "noise_onset", "noise_offset", "snr"))
        for speech_na in speech_names:
            # Read speech.
            speech_path = os.path.join(speech_dir, speech_na)
            (speech_audio, _) = read_audio(speech_path, fs)
            len_speech = len(speech_audio)
            if data_type == 'train':
                selected_noise_names = rs.choice(noise_names, size=magnification, replace=False)
            else:
                selected_noise_names = noise_names
            # Mix one speech with different noises many times.
            for noise_na in selected_noise_names:
                noise_path = os.path.join(noise_dir, noise_na)
                (noise_audio, _) = read_audio(noise_path, fs)
                len_noise = len(noise_audio)
                if len_noise <= len_speech:
                    noise_onset = 0
                    noise_offset = len_speech
                else:
                    # Noise longer than speech: pick a random segment of it.
                    noise_onset = rs.randint(0, len_noise - len_speech, size=1)[0]
                    noise_offset = noise_onset + len_speech
                f.write("%s\t%s\t%d\t%d\t%f\n" % (speech_na, noise_na, noise_onset, noise_offset, snr[cnt]))
                if cnt % 100 == 0:
                    print(cnt)
                cnt += 1
    print(out_csv_path)
    print("Create %s mixture csv finished!" % data_type)


###
def calculate_mixture_features(args):
    """Mix speech and noise as listed in a mixture csv and save spectrograms.

    For every csv row the mixed wav is written under
    workspace/mixed_audios/spectrogram/<data_type>/crn_mixdb/ and a pickle
    [mixed_complx_x, speech_x, noise_x, alpha, out_bare_na] is written under
    workspace/features/spectrogram/<data_type>/crn_mixdb/.

    Args (attributes of *args*):
      workspace: str, path of workspace.
      speech_dir: str, path of speech data.
      noise_dir: str, path of noise data.
      data_type: str, 'train' | 'test' (only used in output paths).
    The SNR of each mixture is taken from the csv row.
    """
    workspace = args.workspace
    speech_dir = args.speech_dir
    noise_dir = args.noise_dir
    data_type = args.data_type
    fs = cfg.sample_rate
    # Open mixture csv.
    # NOTE(review): this name is hard-coded and does not follow the
    # "<data_type>_98hour_even.csv" name written by create_mixture_csv;
    # confirm which csv is intended.
    mixture_csv_path = os.path.join(workspace, "mixture_csvs", "test_1hour_even.csv")
    with open(mixture_csv_path, 'rb') as f:
        reader = csv.reader(f, delimiter='\t')
        lis = list(reader)

    t1 = time.time()
    # lis[0] is the header row.
    for cnt, row in enumerate(lis[1:]):
        [speech_na, noise_na, noise_onset, noise_offset, snr] = row
        noise_onset = int(noise_onset)
        noise_offset = int(noise_offset)
        snr = float(snr)
        # Read speech audio.
        speech_path = os.path.join(speech_dir, speech_na)
        (speech_audio, _) = read_audio(speech_path, target_fs=fs)
        # Read noise audio.
        noise_path = os.path.join(noise_dir, noise_na)
        (noise_audio, _) = read_audio(noise_path, target_fs=fs)
        if len(noise_audio) < len(speech_audio):
            # Repeat noise to the same length as speech.
            n_repeat = int(np.ceil(float(len(speech_audio)) / float(len(noise_audio))))
            noise_audio_ex = np.tile(noise_audio, n_repeat)
            noise_audio = noise_audio_ex[0 : len(speech_audio)]
        else:
            # Truncate noise to the segment chosen in the csv.
            noise_audio = noise_audio[noise_onset : noise_offset]
        # Scale speech to the given snr.
        scaler = get_amplitude_scaling_factor(speech_audio, noise_audio, snr=snr)
        speech_audio *= scaler
        # Get normalized mixture, speech, noise.
        (mixed_audio, speech_audio, noise_audio, alpha) = additive_mixing(speech_audio, noise_audio)
        # Write out mixed audio.
        out_bare_na = "%s.%s.%s" % (os.path.splitext(speech_na)[0],
                                    os.path.splitext(noise_na)[0],
                                    str(int(snr)) + "db")
        out_audio_path = os.path.join(workspace, "mixed_audios", "spectrogram",
                                      data_type, "crn_mixdb", "%s.wav" % out_bare_na)
        create_folder(os.path.dirname(out_audio_path))
        write_audio(out_audio_path, mixed_audio, fs)
        # Extract spectrogram.
        mixed_complx_x = calc_sp(mixed_audio, mode='complex', n_window=320, n_overlap=160)
        speech_x = calc_sp(speech_audio, mode='magnitude', n_window=320, n_overlap=160)
        noise_x = calc_sp(noise_audio, mode='magnitude', n_window=320, n_overlap=160)
        # Write out features.
        out_feat_path = os.path.join(workspace, "features", "spectrogram",
                                     data_type, "crn_mixdb", "%s.p" % out_bare_na)
        create_folder(os.path.dirname(out_feat_path))
        data = [mixed_complx_x, speech_x, noise_x, alpha, out_bare_na]
        with open(out_feat_path, 'wb') as f:
            cPickle.dump(data, f, protocol=cPickle.HIGHEST_PROTOCOL)
        if cnt % 100 == 0:
            print(cnt)
    print("Extracting feature time: %s" % (time.time() - t1))


def rms(y):
    """Root mean square of *y* along axis 0."""
    return np.sqrt(np.mean(np.abs(y) ** 2, axis=0, keepdims=False))


def get_amplitude_scaling_factor(s, n, snr, method='rms'):
    """Return the factor to multiply *s* by so that s/n reaches *snr* dB.

    The SNR is defined as snr = 20 * log10(rms(s) / rms(n)).

    Args:
      s: ndarray, source1 (speech).
      n: ndarray, source2 (noise).
      snr: float, target SNR in dB.
      method: 'rms' (the only method implemented; the argument is unused).

    Returns:
      float, scaler.
    """
    original_sn_rms_ratio = rms(s) / rms(n)
    target_sn_rms_ratio = 10. ** (float(snr) / 20.)
    signal_scaling_factor = target_sn_rms_ratio / original_sn_rms_ratio
    return signal_scaling_factor


def additive_mixing(s, n):
    """Mix two sources and normalize so the mixture peaks at 1.

    Note: *s* and *n* are scaled in place (``*=``), so the caller's arrays are
    modified as well.

    Args:
      s: ndarray, source1.
      n: ndarray, source2.

    Returns:
      mix_audio: ndarray, mixed audio.
      s: ndarray, scaled source1.
      n: ndarray, scaled source2.
      alpha: float, normalization coefficient 1 / max(|s + n|).
    """
    mixed_audio = s + n

    alpha = 1. / np.max(np.abs(mixed_audio))
    mixed_audio *= alpha
    s *= alpha
    n *= alpha
    return mixed_audio, s, n, alpha


def calc_sp(audio, mode, n_window, n_overlap):
    """Calculate a Hamming-windowed spectrogram.

    Args:
      audio: 1darray.
      mode: string, 'magnitude' | 'complex'.
      n_window: int, frame length in samples.
      n_overlap: int, overlap between frames in samples.

    Returns:
      spectrogram: 2darray, (n_time, n_freq); float32 for 'magnitude',
        complex64 for 'complex'.

    Raises:
      Exception: for any other mode.
    """
    ham_win = np.hamming(n_window)
    # Public scipy API; scipy.signal.spectral is a deprecated private module.
    [f, t, x] = signal.spectrogram(
        audio,
        window=ham_win,
        nperseg=n_window,
        noverlap=n_overlap,
        detrend=False,
        return_onesided=True,
        mode=mode)
    x = x.T
    if mode == 'magnitude':
        x = x.astype(np.float32)
    elif mode == 'complex':
        x = x.astype(np.complex64)
    else:
        raise Exception("Incorrect mode!")
    return x


def _write_packed_h5(out_path, x_all, y_all, log_target):
    """Concatenate accumulated segments, apply log and write 'x'/'y' to an .h5 file."""
    x_all = np.concatenate(x_all, axis=0)  # (n_segs, n_concat, n_freq)
    y_all = np.concatenate(y_all, axis=0)  # (n_segs, n_freq)
    x_all = log_sp(x_all).astype(np.float32)
    if log_target:
        y_all = log_sp(y_all).astype(np.float32)
    create_folder(os.path.dirname(out_path))
    with h5py.File(out_path, 'w') as hf:
        hf.create_dataset('x', data=x_all)
        hf.create_dataset('y', data=y_all)


###
def pack_features(args):
    """Load all features, apply log, convert to 3D tensors and write .h5 files.

    Output goes to workspace/packed_features/spectrogram/<data_type>/mask_mixdb/
    as "chinese_data<cnt>.h5". A file is flushed after every 10000 inputs, plus
    a final file for the remainder. With train_type "IRM" the target is the
    ideal ratio mask speech / (speech + noise) of the centre frame (not logged);
    otherwise it is the log magnitude of the clean centre frame.

    Args (attributes of *args*):
      workspace: str, path of workspace.
      data_type: str, 'train' | 'test'.
      n_concat: int, number of frames to be concatenated.
      n_hop: int, hop frames.
    """
    workspace = args.workspace
    data_type = args.data_type
    n_concat = args.n_concat
    n_hop = args.n_hop

    x_all = []  # list of (n_segs, n_concat, n_freq)
    y_all = []  # list of (n_segs, n_freq)

    t1 = time.time()
    train_type = "IRM"
    log_target = train_type != "IRM"
    # Integer division: the value is used for list repetition and indexing.
    n_pad = (n_concat - 1) // 2
    out_dir = os.path.join(workspace, "packed_features", "spectrogram", data_type, "mask_mixdb")
    # Load all features.
    feat_dir = os.path.join(workspace, "features", "spectrogram", data_type, "chinese_mixdb")
    names = os.listdir(feat_dir)
    for cnt, na in enumerate(names):
        feat_path = os.path.join(feat_dir, na)
        with open(feat_path, 'rb') as f:
            data = cPickle.load(f)
        [mixed_complx_x, speech_x, noise_x, alpha, bare_na] = data
        if train_type == "IRM":
            # Mixture magnitude approximated by the sum of the magnitudes.
            mixed_x = speech_x + noise_x
        else:
            mixed_x = np.abs(mixed_complx_x)
        # Pad start and finish of the spectrogram with border values.
        mixed_x = pad_with_border(mixed_x, n_pad)
        speech_x = pad_with_border(speech_x, n_pad)
        # Cut input spectrogram to 3D segments with n_concat.
        mixed_x_3d = mat_2d_to_3d(mixed_x, agg_num=n_concat, hop=n_hop)
        x_all.append(mixed_x_3d)
        # Cut target spectrogram and take the center frame of each 3D segment.
        speech_x_3d = mat_2d_to_3d(speech_x, agg_num=n_concat, hop=n_hop)
        y = speech_x_3d[:, n_pad, :]
        if train_type == "IRM":
            y = y / mixed_x_3d[:, n_pad, :]
        y_all.append(y)
        if cnt % 100 == 0:
            print(cnt)
        if cnt % 10000 == 0 and cnt != 0:
            _write_packed_h5(os.path.join(out_dir, "chinese_data%d.h5" % cnt),
                             x_all, y_all, log_target)
            x_all = []
            y_all = []
    # Write out the remainder; skip when the last flush already took everything
    # (np.concatenate of an empty list would raise).
    out_path = os.path.join(out_dir, "chinese_data%d.h5" % len(names))
    if x_all:
        _write_packed_h5(out_path, x_all, y_all, log_target)
        print("Write out to %s" % out_path)
    print("Pack features finished! %s s" % (time.time() - t1,))


def pack_crn_features(args):
    """Cut CRN features into n_concat-frame segments and write them to a TFRecord.

    All settings are hard-coded (the *args* attributes are not read): features
    come from workspace/features/spectrogram/train/crn_mixdb and each record
    stores float32 bytes of a mixture magnitude segment ('x') and the matching
    clean speech segment ('y').
    """
    # tensorflow is not imported at module level in this file.
    import tensorflow as tf
    workspace = "workspace"
    data_type = "train"
    n_concat = 11
    n_hop = 4
    feat_dir = os.path.join(workspace, "features", "spectrogram", data_type, "crn_mixdb")
    names = os.listdir(feat_dir)
    tfrecords_train_filename = '/data00/wangjinchao/sednn-master/mixture2clean_dnn/workspace/tfrecords/train/crn_mixdb/data_office.tfrecords'
    create_folder(os.path.dirname(tfrecords_train_filename))
    writer_train = tf.python_io.TFRecordWriter(tfrecords_train_filename)
    try:
        for cnt, na in enumerate(names):
            feat_path = os.path.join(feat_dir, na)
            with open(feat_path, 'rb') as f:
                [mixed_complx_x, speech_x, _, _, _] = cPickle.load(f)
            mixed_x = np.abs(mixed_complx_x)
            # Cut input and target spectrograms to 3D segments with n_concat.
            mixed_x_3d = mat_2d_to_3d(mixed_x, agg_num=n_concat, hop=n_hop)
            speech_x_3d = mat_2d_to_3d(speech_x, agg_num=n_concat, hop=n_hop)
            for mixed_seg, speech_seg in zip(mixed_x_3d, speech_x_3d):
                mixed_input = mixed_seg.astype(np.float32).tobytes()
                label = speech_seg.astype(np.float32).tobytes()
                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        'x': tf.train.Feature(bytes_list=tf.train.BytesList(value=[mixed_input])),
                        'y': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label]))
                    }))
                writer_train.write(example.SerializeToString())
            if cnt % 100 == 0:
                print(cnt)
    finally:
        # Flush and close the record file even if a feature fails to load.
        writer_train.close()


def log_sp(x):
    """Natural log with a 1e-8 floor offset to avoid log(0)."""
    return np.log(x + 1e-08)


def mat_2d_to_3d(x, agg_num, hop):
    """Segment a 2D array (n_frames, n_in) into (n_segs, agg_num, n_in).

    Inputs shorter than agg_num frames are zero-padded to one full segment.
    Segments start every *hop* frames; a trailing partial segment is dropped.
    """
    len_x, n_in = x.shape
    if len_x < agg_num:
        x = np.concatenate((x, np.zeros((agg_num - len_x, n_in))))
    len_x = len(x)
    x3d = [x[i1 : i1 + agg_num] for i1 in range(0, len_x - agg_num + 1, hop)]
    return np.array(x3d)


def pad_with_border(x, n_pad):
    """Pad the start and end of a spectrogram with copies of its border frames.

    Args:
      x: 2darray, (n_frames, n_freq).
      n_pad: int, number of frames added at each end.
    """
    x_pad_list = [x[0:1]] * n_pad + [x] + [x[-1:]] * n_pad
    return np.concatenate(x_pad_list, axis=0)


###
def compute_scaler(args):
    """Fit a StandardScaler on packed training inputs and pickle it.

    Only the first two .h5 files (sorted by name) in
    workspace/packed_features/spectrogram/train/mixdb are used. The scaler is
    written to workspace/packed_features/spectrogram/<data_type>/mixdb/scaler.p.

    Raises:
      IOError: if no .h5 file is found.
    """
    workspace = args.workspace
    data_type = args.data_type

    t1 = time.time()
    tr_hdf5_dir = os.path.join(workspace, "packed_features", "spectrogram", "train", "mixdb")
    # Keep only .h5 files: scaler.p is written into this same directory when
    # data_type == 'train', and h5py cannot open it on a re-run.
    tr_hdf5_names = sorted(na for na in os.listdir(tr_hdf5_dir) if na.endswith(".h5"))
    if not tr_hdf5_names:
        raise IOError("No .h5 files found in %s" % tr_hdf5_dir)
    print('start loading h5')
    # NOTE(review): only two files are loaded, presumably to bound memory.
    x_parts = []
    for i, na in enumerate(tr_hdf5_names[:2]):
        with h5py.File(os.path.join(tr_hdf5_dir, na), 'r') as hf:
            x_parts.append(np.array(hf.get('x')))
        if i > 0:
            print(na + 'is loaded')
    x = np.concatenate(x_parts, axis=0)

    # Compute scaler over individual frames.
    (n_segs, n_concat, n_freq) = x.shape
    x2d = x.reshape((n_segs * n_concat, n_freq))
    scaler = preprocessing.StandardScaler(with_mean=True, with_std=True).fit(x2d)
    print(scaler.mean_)
    print(scaler.scale_)

    # Write out scaler.
    out_path = os.path.join(workspace, "packed_features", "spectrogram", data_type, "mixdb", "scaler.p")
    create_folder(os.path.dirname(out_path))
    with open(out_path, 'wb') as f:
        pickle.dump(scaler, f)

    print("Save scaler to %s" % out_path)
    print("Compute scaler finished! %s s" % (time.time() - t1,))


def scale_on_2d(x2d, scaler):
    """Scale 2D array data (n_samples, n_freq) with a fitted scaler."""
    return scaler.transform(x2d)


def scale_on_3d(x3d, scaler):
    """Scale 3D array data (n_segs, n_concat, n_freq) frame by frame."""
    (n_segs, n_concat, n_freq) = x3d.shape
    x2d = x3d.reshape((n_segs * n_concat, n_freq))
    x2d = scaler.transform(x2d)
    x3d = x2d.reshape((n_segs, n_concat, n_freq))
    return x3d


def inverse_scale_on_2d(x2d, scaler):
    """Undo standard scaling of 2D array data: x * scale_ + mean_."""
    return x2d * scaler.scale_[None, :] + scaler.mean_[None, :]


###
def load_hdf5(hdf5_path):
    """Load the 'x' and 'y' datasets of an .h5 file as numpy arrays."""
    with h5py.File(hdf5_path, 'r') as hf:
        x = hf.get('x')
        y = hf.get('y')
        x = np.array(x)  # (n_segs, n_concat, n_freq)
        y = np.array(y)  # (n_segs, n_freq)
    return x, y


def np_mean_absolute_error(y_true, y_pred):
    """Mean absolute error between two arrays."""
    return np.mean(np.abs(y_pred - y_true))


###
def istft(spec, fs=16000, window=None, nfft=320, nperseg=256, noverlap=128):
    """Invert a complex STFT back to a time-domain waveform.

    Thin wrapper around ``scipy.signal.istft`` using this project's settings as
    defaults.

    Args:
      spec: complex 2darray in scipy's layout (n_freq, n_time); note calc_sp
        returns the transpose, (n_time, n_freq).
      fs: int, sample rate.
      window: window array of length nperseg; None means a Hamming window.
      nfft, nperseg, noverlap: must match the forward transform.

    Returns:
      1darray, reconstructed audio.
    """
    if window is None:
        window = np.hamming(nperseg)
    _, audio_output = signal.istft(
        spec,
        fs=fs,
        window=window,
        nfft=nfft,
        nperseg=nperseg,
        noverlap=noverlap)
    return audio_output


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='mode')

    parser_create_mixture_csv = subparsers.add_parser('create_mixture_csv')
    parser_create_mixture_csv.add_argument('--workspace', type=str, required=True)
    parser_create_mixture_csv.add_argument('--speech_dir', type=str, required=True)
    parser_create_mixture_csv.add_argument('--noise_dir', type=str, required=True)
    parser_create_mixture_csv.add_argument('--data_type', type=str, required=True)
    parser_create_mixture_csv.add_argument('--magnification', type=int, default=1)

    parser_calculate_mixture_features = subparsers.add_parser('calculate_mixture_features')
    parser_calculate_mixture_features.add_argument('--workspace', type=str, required=True)
parser_calculate_mixture_features.add_argument('--speech_dir', type=str, required=True) 522 | parser_calculate_mixture_features.add_argument('--noise_dir', type=str, required=True) 523 | parser_calculate_mixture_features.add_argument('--data_type', type=str, required=True) 524 | parser_calculate_mixture_features.add_argument('--snr', type=float, required=True) 525 | 526 | parser_pack_features = subparsers.add_parser('pack_features') 527 | parser_pack_features.add_argument('--workspace', type=str, required=True) 528 | parser_pack_features.add_argument('--data_type', type=str, required=True) 529 | parser_pack_features.add_argument('--snr', type=float, required=True) 530 | parser_pack_features.add_argument('--n_concat', type=int, required=True) 531 | parser_pack_features.add_argument('--n_hop', type=int, required=True) 532 | 533 | parser_compute_scaler = subparsers.add_parser('compute_scaler') 534 | parser_compute_scaler.add_argument('--workspace', type=str, required=True) 535 | parser_compute_scaler.add_argument('--data_type', type=str, required=True) 536 | parser_compute_scaler.add_argument('--snr', type=float, required=True) 537 | 538 | args = parser.parse_args() 539 | if args.mode == 'create_mixture_csv': 540 | create_mixture_csv(args) 541 | elif args.mode == 'calculate_mixture_features': 542 | calculate_mixture_features(args) 543 | elif args.mode == 'pack_features': 544 | pack_features(args) 545 | elif args.mode == 'compute_scaler': 546 | compute_scaler(args) 547 | else: 548 | raise Exception("Error!") 549 | 550 | 551 | 552 | 553 | 554 | 555 | 556 | 557 | 558 | 559 | -------------------------------------------------------------------------------- /spectrogram_to_wave.py: -------------------------------------------------------------------------------- 1 | """ 2 | Summary: Recover spectrogram to wave. 
3 | Author: Qiuqiang Kong 4 | Created: 2017.09 5 | Modified: - 6 | """ 7 | import numpy as np 8 | import numpy 9 | import decimal 10 | 11 | def recover_wav(pd_abs_x, gt_x, n_overlap, winfunc, wav_len=None): 12 | """Recover wave from spectrogram. 13 | If you are using scipy.signal.spectrogram, you may need to multipy a scaler 14 | to the recovered audio after using this function. For example, 15 | recover_scaler = np.sqrt((ham_win**2).sum()) 16 | 17 | Args: 18 | pd_abs_x: 2d array, (n_time, n_freq) 19 | gt_x: 2d complex array, (n_time, n_freq) 20 | n_overlap: integar. 21 | winfunc: func, the analysis window to apply to each frame. 22 | wav_len: integer. Pad or trunc to wav_len with zero. 23 | 24 | Returns: 25 | 1d array. 26 | """ 27 | x = real_to_complex(pd_abs_x, gt_x) 28 | x = half_to_whole(x) 29 | frames = ifft_to_wav(x) 30 | (n_frames, n_window) = frames.shape 31 | s = deframesig(frames=frames, siglen=0, frame_len=n_window, 32 | frame_step=n_window-n_overlap, winfunc=winfunc) 33 | if wav_len: 34 | s = pad_or_trunc(s, wav_len) 35 | return s 36 | 37 | def real_to_complex(pd_abs_x, gt_x): 38 | """Recover pred spectrogram's phase from ground truth's phase. 39 | 40 | Args: 41 | pd_abs_x: 2d array, (n_time, n_freq) 42 | gt_x: 2d complex array, (n_time, n_freq) 43 | 44 | Returns: 45 | 2d complex array, (n_time, n_freq) 46 | """ 47 | theta = np.angle(gt_x) 48 | cmplx = pd_abs_x * np.exp(1j * theta) 49 | return cmplx 50 | 51 | def half_to_whole(x): 52 | """Recover whole spectrogram from half spectrogram. 53 | """ 54 | return np.concatenate((x, np.fliplr(np.conj(x[:, 1:-1]))), axis=1) 55 | 56 | def ifft_to_wav(x): 57 | """Recover wav from whole spectrogram""" 58 | return np.real(np.fft.ifft(x)) 59 | 60 | def pad_or_trunc(s, wav_len): 61 | if len(s) >= wav_len: 62 | s = s[0 : wav_len] 63 | else: 64 | s = np.concatenate((s, np.zeros(wav_len - len(s)))) 65 | return s 66 | 67 | def recover_gt_wav(x, n_overlap, winfunc, wav_len=None): 68 | """Recover ground truth wav. 
69 | """ 70 | x = half_to_whole(x) 71 | frames = ifft_to_wav(x) 72 | (n_frames, n_window) = frames.shape 73 | s = deframesig(frames=frames, siglen=0, frame_len=n_window, 74 | frame_step=n_window-n_overlap, winfunc=winfunc) 75 | if wav_len: 76 | s = pad_or_trunc(s, wav_len) 77 | return s 78 | 79 | def deframesig(frames,siglen,frame_len,frame_step,winfunc=lambda x:numpy.ones((x,))): 80 | """Does overlap-add procedure to undo the action of framesig. 81 | Ref: From https://github.com/jameslyons/python_speech_features 82 | 83 | :param frames: the array of frames. 84 | :param siglen: the length of the desired signal, use 0 if unknown. Output will be truncated to siglen samples. 85 | :param frame_len: length of each frame measured in samples. 86 | :param frame_step: number of samples after the start of the previous frame that the next frame should begin. 87 | :param winfunc: the analysis window to apply to each frame. By default no window is applied. 88 | :returns: a 1-D signal. 89 | """ 90 | frame_len = round_half_up(frame_len) 91 | frame_step = round_half_up(frame_step) 92 | numframes = numpy.shape(frames)[0] 93 | assert numpy.shape(frames)[1] == frame_len, '"frames" matrix is wrong size, 2nd dim is not equal to frame_len' 94 | 95 | indices = numpy.tile(numpy.arange(0,frame_len),(numframes,1)) + numpy.tile(numpy.arange(0,numframes*frame_step,frame_step),(frame_len,1)).T 96 | indices = numpy.array(indices,dtype=numpy.int32) 97 | padlen = (numframes-1)*frame_step + frame_len 98 | 99 | if siglen <= 0: siglen = padlen 100 | 101 | rec_signal = numpy.zeros((padlen,)) 102 | window_correction = numpy.zeros((padlen,)) 103 | win = winfunc(frame_len) 104 | 105 | for i in range(0,numframes): 106 | window_correction[indices[i,:]] = window_correction[indices[i,:]] + win + 1e-15 #add a little bit so it is never zero 107 | rec_signal[indices[i,:]] = rec_signal[indices[i,:]] + frames[i,:] 108 | 109 | rec_signal = rec_signal/window_correction 110 | return rec_signal[0:siglen] 111 | 112 | 
def round_half_up(number):
    """Round ``number`` to the nearest integer, sending exact ties away from zero.

    Python's built-in ``round`` uses banker's rounding (``round(2.5) == 2``).
    This helper instead gives 2.5 -> 3 and -2.5 -> -3, because it quantizes
    through ``decimal`` with ``ROUND_HALF_UP``.

    Args:
        number: int, float, or anything ``decimal.Decimal`` accepts.

    Returns:
        int.
    """
    return int(decimal.Decimal(number).quantize(decimal.Decimal('1'),
                                                rounding=decimal.ROUND_HALF_UP))
# --------------------------------------------------------------------------------
# /timit_handler.py:
# --------------------------------------------------------------------------------
import os
import shutil

timit_path = '../../database/TIMIT/TEST/'


def rename_and_move_wavfile(timit_path):
    """Copy every ``.WAV`` file under a TIMIT split into ``./mini_data`` with a flat name.

    The split is the last component of ``timit_path`` ("TEST" or "TRAIN"). A
    trailing separator is optional. Files go to ``./mini_data/test_speech/``
    or ``./mini_data/train_speech/``, which is created if missing. Each file is
    named after the last four components of its path, e.g.
    ``TEST/DR3/MJMP0/SA2.WAV`` -> ``TEST_DR3_MJMP0_SA2.WAV``.

    Args:
        timit_path: str, path to the TIMIT ``TEST`` or ``TRAIN`` directory.

    Returns:
        0 (after printing "input path error") when the last path component is
        neither "TEST" nor "TRAIN"; otherwise None.
    """
    # normpath drops a trailing separator, so ".../TEST" and ".../TEST/" agree;
    # the old split("/")[-2] only worked when the path ended with "/".
    split_name = os.path.basename(os.path.normpath(timit_path))
    if split_name == "TEST":
        target_path_base = './mini_data/test_speech/'
    elif split_name == "TRAIN":
        target_path_base = './mini_data/train_speech/'
    else:
        print("input path error")
        return 0
    # shutil.copy cannot create the destination directory itself.
    if not os.path.isdir(target_path_base):
        os.makedirs(target_path_base)
    for root, _dirs, files in os.walk(timit_path):
        for file_name in files:
            curr_path = os.path.join(root, file_name)
            if os.path.splitext(curr_path)[-1] != ".WAV":
                continue
            # Split on the platform separator rather than a literal "/".
            parts = os.path.normpath(curr_path).split(os.sep)
            train_type, district_type, speaker_id, sentence_id = (
                parts[-4], parts[-3], parts[-2], parts[-1])
            target_name = "_".join(
                (train_type, district_type, speaker_id, sentence_id))
            shutil.copy(curr_path, os.path.join(target_path_base, target_name))
# --------------------------------------------------------------------------------