├── README.md
├── config.py
├── data
│   ├── test_noise
│   │   ├── noise1.wav
│   │   └── noise2.wav
│   ├── test_speech
│   │   ├── TEST_DR2_MRCZ0_si2171.wav
│   │   ├── TEST_DR7_FCAU0_sx317.wav
│   │   └── TEST_DR8_MSLB0_sx203.wav
│   ├── train_noise
│   │   ├── noise1.wav
│   │   └── noise2.wav
│   └── train_speech
│       ├── TRAIN_DR1_MPGR0_si2040.wav
│       ├── TRAIN_DR5_FJXM0_sx401.wav
│       └── TRAIN_DR8_FNKL0_sx262.wav
├── data_generator.py
├── demo.sh
├── demo_data
│   ├── noise
│   │   ├── noise1.wav
│   │   └── noise2.wav
│   ├── noisy
│   │   ├── TEST_DR7_FCAU0_sx317.noise1.wav
│   │   ├── TEST_DR7_FCAU0_sx317.noise2.wav
│   │   ├── THCH_test_D8_770.noise1.wav
│   │   └── THCH_test_D8_770.noise2.wav
│   └── speech
│       ├── TEST_DR7_FCAU0_sx317.wav
│       └── THCH_test_D8_770.wav
├── evaluate.py
├── ffmpeg
├── main_dnn.py
├── models
│   └── pretrained
│       └── README.md
├── notes
│   ├── THCH_test_D8_770-.wav
│   ├── THCH_test_D8_770.noise1.ns_enh.wav
│   ├── THCH_test_D8_770.noise1.wav
│   ├── THCH_test_D8_770.noise2.ns_enh.wav
│   ├── THCH_test_D8_770.noise2.wav
│   ├── clear-d8-770.jpg
│   ├── denoised-noise1-d8-770.jpg
│   ├── denoised-noise2-d8-770.jpg
│   ├── noise1-d8-770.jpg
│   ├── noise2-d8-770.jpg
│   ├── paypal.jpg
│   └── wechat.jpg
├── ns
├── pesq
├── prepare_data.py
├── requirements.txt
├── runme.sh
└── spectrogram_to_wave.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# ClearWave
Denoise speech with deep learning (using Keras and TensorFlow)

------------------

This project is modified from the deep neural network (DNN) speech enhancement project by yongxuUSTC (https://github.com/yongxuUSTC/sednn).

The project also uses ffmpeg, WebRTC and PESQ to process the speech data.

Before trying the project, please download the base DNN model from https://pan.baidu.com/s/1eVnRkNb5xIn96aYOV8C-Gg and copy the .h5 file to ./models/pretrained/base_dnn_model.h5.

------------------

# Speech Samples
You can download and listen to the clear, noisy and denoised speech:

The clear speech ------- https://github.com/boozyguo/ClearWave/blob/master/notes/THCH_test_D8_770-.wav

The noisy speech ------- https://github.com/boozyguo/ClearWave/blob/master/notes/THCH_test_D8_770.noise1.wav

The denoised speech ---- https://github.com/boozyguo/ClearWave/blob/master/notes/THCH_test_D8_770.noise1.ns_enh.wav

The noisy speech ------- https://github.com/boozyguo/ClearWave/blob/master/notes/THCH_test_D8_770.noise2.wav

The denoised speech ---- https://github.com/boozyguo/ClearWave/blob/master/notes/THCH_test_D8_770.noise2.ns_enh.wav

------------------

## Inference Usage: Denoise noisy data.
If you have noisy speech, you can edit and run "./demo.sh" to denoise the noisy files:

1. Put the noisy files in "./demo_data/noisy/*.wav"

2. Edit demo.sh and set "INPUT_NOISY=1"

3. Run ./demo.sh

4. Check the denoised speech in "demo_workspace/ns_enh_wavs/test/1000db/*.wav" ("1000db" is the placeholder SNR tag that demo.sh assigns to pre-mixed noisy input)

------------------

## Inference Usage: Denoise speech data mixed with noise data.
If you have clear speech and noise:

1. Put the noise files in "./demo_data/noise/*.wav"

2. Put the clear speech files in "./demo_data/speech/*.wav"

3. Edit demo.sh and set "INPUT_NOISY=0". You can also change the SNR via the "TE_SNR" parameter; for example, "TE_SNR=5" means 5 dB (a sketch of the underlying scaling follows this list).

4. Run ./demo.sh

5. Check the denoised speech in "demo_workspace/ns_enh_wavs/test/5db/*.wav" (if "TE_SNR=5")
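For reference, the SNR here follows the definition used in prepare_data.py: the clean speech is scaled so that 20·log10(rms(speech)/rms(noise)) equals the requested value. Below is a minimal sketch of that scaling, mirroring `rms` and `get_amplitude_scaling_factor` from prepare_data.py; the two random signals are made-up stand-ins for real speech and noise.

```python
import numpy as np

def rms(y):
    """Root mean square of a signal."""
    return np.sqrt(np.mean(np.abs(y) ** 2))

def speech_scaler(speech, noise, snr_db):
    # After scaling: 20 * log10(rms(speech) / rms(noise)) == snr_db.
    target_ratio = 10.0 ** (snr_db / 20.0)
    return target_ratio / (rms(speech) / rms(noise))

speech = np.random.randn(16000)  # stand-in for 1 s of speech at 16 kHz
noise = np.random.randn(16000)   # stand-in for noise
mixed = speech * speech_scaler(speech, noise, snr_db=5.0) + noise
```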

------------------

## Training Usage: Train a model on speech data and noise data.
If you want to train your own model, just prepare your data and run "./runme.sh":

1. Put the training noise files in "./data/train_noise/*.wav"

2. Put the training clear speech files in "./data/train_speech/*.wav"

3. Put the validation noise files in "./data/test_noise/*.wav"

4. Put the validation speech files in "./data/test_speech/*.wav"

5. Edit the runme.sh file and set the parameters: TR_SNR, TE_SNR, EPOCHS, LEARNING_RATE

6. Run ./runme.sh

7. Check the new model in "./workspace/models/5db/*.h5" (if "TE_SNR=5")

------------------


## Models:

The ClearWave model is a simple DNN in Keras:

```python
n_concat = 7
n_freq = 257
n_hid = 2048
model = Sequential()
model.add(Flatten(input_shape=(n_concat, n_freq)))
model.add(Dropout(0.1))
model.add(Dense(n_hid, activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.2))
model.add(Dense(n_hid, activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(n_hid, activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(n_freq, activation='relu'))
model.summary()
```
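The input shape follows from config.py: frames are computed with a 512-sample FFT window and 256-sample overlap at 16 kHz, so each frame has 512/2 + 1 = 257 one-sided frequency bins, and n_concat = 7 stacks seven consecutive frames as temporal context. A quick sanity check of those numbers:

```python
sample_rate = 16000  # from config.py
n_window = 512       # FFT window size (config.py)
n_overlap = 256      # window overlap (config.py)

n_freq = n_window // 2 + 1  # bins in the one-sided spectrum
frame_ms = 1000.0 * n_window / sample_rate              # length of one analysis window
hop_ms = 1000.0 * (n_window - n_overlap) / sample_rate  # hop between frames
print("%d bins, %.1f ms window, %.1f ms hop" % (n_freq, frame_ms, hop_ms))
# -> 257 bins, 32.0 ms window, 16.0 ms hop
```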

------------------

## Run on THCHS30 and 5 noises

Training:

Speech: THCHS30 (http://cslt.org), 2178 training sentences (20% selected from the 10893 training sentences)

Noise: 5 kinds of noises

Testing:

Speech: THCHS30, 499 testing sentences (20% selected from the 2495 testing sentences)

Noise: same as training


The denoised PESQ (SNR = 5 dB):

| Noise                   | PESQ         |
|-------------------------|--------------|
| Cafeteria_Noise_16s_26s | 2.81 +- 0.09 |
| Fullsize_Car1_16s_26s   | 3.13 +- 0.09 |
| Pub_Noise_16s_26s       | 2.51 +- 0.09 |
| Outside_Traffic_2s_12s  | 2.89 +- 0.11 |
| RockMusic01m48k_16s_26s | 3.04 +- 0.09 |
| Avg.                    | 2.87 +- 0.10 |
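This table is produced by the get_stats step in evaluate.py: per-file PESQ scores are grouped by noise type (parsed from each file name), and the Avg. row is the mean of the per-noise means together with the mean of the per-noise standard deviations. A minimal sketch of that aggregation, using made-up scores:

```python
import numpy as np

# Made-up per-file PESQ scores grouped by noise type, in the shape
# that get_stats in evaluate.py builds from _pesq_results.txt.
pesq_dict = {
    "Cafeteria_Noise_16s_26s": [2.70, 2.90, 2.85],
    "Pub_Noise_16s_26s": [2.40, 2.60, 2.55],
}

avg_list, std_list = [], []
for noise_type, pesqs in pesq_dict.items():
    avg_list.append(np.mean(pesqs))
    std_list.append(np.std(pesqs))
    print("%-24s %.2f +- %.2f" % (noise_type, np.mean(pesqs), np.std(pesqs)))
print("%-24s %.2f +- %.2f" % ("Avg.", np.mean(avg_list), np.mean(std_list)))
```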

------------------

## Samples:
There are some speech files in "./notes".

The clear speech is "./notes/THCH_test_D8_770-.wav", shown in the figure below:

![Clear Speech](https://github.com/boozyguo/ClearWave/blob/master/notes/clear-d8-770.jpg)


The noisy files are "./notes/THCH_test_D8_770.noise1.wav" and "./notes/THCH_test_D8_770.noise2.wav", shown in the figures below:

![Noisy1](https://github.com/boozyguo/ClearWave/blob/master/notes/noise1-d8-770.jpg)
![Noisy2](https://github.com/boozyguo/ClearWave/blob/master/notes/noise2-d8-770.jpg)


The denoised files are "./notes/THCH_test_D8_770.noise1.ns_enh.wav" and "./notes/THCH_test_D8_770.noise2.ns_enh.wav", shown in the figures below:

![Denoised1](https://github.com/boozyguo/ClearWave/blob/master/notes/denoised-noise1-d8-770.jpg)
![Denoised2](https://github.com/boozyguo/ClearWave/blob/master/notes/denoised-noise2-d8-770.jpg)



------------------


## Donate:

If the project helps you, please star it and consider a donation. Donations will be used to fund development expenses (e.g. equipment and server maintenance costs) and to sponsor bug fixing and feature development.


WeChat Payment
![WeChat Payment](https://github.com/boozyguo/ClearWave/blob/master/notes/wechat.jpg)


[Paypal Payment](http://paypal.me/githubClearWave)
[![Paypal Payment](https://github.com/boozyguo/ClearWave/blob/master/notes/paypal.jpg)](http://paypal.me/githubClearWave)

------------------

## Ref:

https://github.com/yongxuUSTC/sednn

http://arxiv.org/abs/1512.01882

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------

"""
Summary: Config file.
Author: Qiuqiang Kong
Created: 2017.12.21
Modified: -
"""

sample_rate = 16000
n_window = 512   # window size for FFT
n_overlap = 256  # overlap between adjacent windows
--------------------------------------------------------------------------------
/data/test_noise/noise1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/data/test_noise/noise1.wav
--------------------------------------------------------------------------------
/data/test_noise/noise2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/data/test_noise/noise2.wav
--------------------------------------------------------------------------------
/data/test_speech/TEST_DR2_MRCZ0_si2171.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/data/test_speech/TEST_DR2_MRCZ0_si2171.wav
--------------------------------------------------------------------------------
/data/test_speech/TEST_DR7_FCAU0_sx317.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/data/test_speech/TEST_DR7_FCAU0_sx317.wav
--------------------------------------------------------------------------------
/data/test_speech/TEST_DR8_MSLB0_sx203.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/data/test_speech/TEST_DR8_MSLB0_sx203.wav
--------------------------------------------------------------------------------
/data/train_noise/noise1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/data/train_noise/noise1.wav
--------------------------------------------------------------------------------
/data/train_noise/noise2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/data/train_noise/noise2.wav
--------------------------------------------------------------------------------
/data/train_speech/TRAIN_DR1_MPGR0_si2040.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/data/train_speech/TRAIN_DR1_MPGR0_si2040.wav
--------------------------------------------------------------------------------
/data/train_speech/TRAIN_DR5_FJXM0_sx401.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/data/train_speech/TRAIN_DR5_FJXM0_sx401.wav
--------------------------------------------------------------------------------
/data/train_speech/TRAIN_DR8_FNKL0_sx262.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/data/train_speech/TRAIN_DR8_FNKL0_sx262.wav
-------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class DataGenerator(object): 4 | def __init__(self, batch_size, type, te_max_iter=None): 5 | assert type in ['train', 'test'] 6 | self._batch_size_ = batch_size 7 | self._type_ = type 8 | self._te_max_iter_ = te_max_iter 9 | 10 | def generate(self, xs, ys): 11 | x = xs[0] 12 | y = ys[0] 13 | batch_size = self._batch_size_ 14 | n_samples = len(x) 15 | 16 | index = np.arange(n_samples) 17 | np.random.shuffle(index) 18 | 19 | iter = 0 20 | epoch = 0 21 | pointer = 0 22 | while True: 23 | if (self._type_ == 'test') and (self._te_max_iter_ is not None): 24 | if iter == self._te_max_iter_: 25 | break 26 | iter += 1 27 | if pointer >= n_samples: 28 | epoch += 1 29 | if (self._type_) == 'test' and (epoch == 1): 30 | break 31 | pointer = 0 32 | np.random.shuffle(index) 33 | 34 | batch_idx = index[pointer : min(pointer + batch_size, n_samples)] 35 | pointer += batch_size 36 | yield x[batch_idx], y[batch_idx] -------------------------------------------------------------------------------- /demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CMD="main_dnn.py" 4 | 5 | 6 | MODEL_FILE="./models/pretrained/base_dnn_model.h5" 7 | INPUT_NOISY=1 8 | 9 | WORKSPACE="./demo_workspace" 10 | mkdir $WORKSPACE 11 | DEMO_SPEECH_DIR="./demo_data/speech" 12 | DEMO_NOISE_DIR="./demo_data/noise" 13 | DEMO_NOISY_DIR="./demo_data/noisy" 14 | echo "Denoise Demo. " 15 | 16 | 17 | TR_SNR=5 18 | TE_SNR=5 19 | N_CONCAT=7 20 | N_HOP=2 21 | CALC_LOG=0 22 | #EPOCHS=10000 23 | ITERATION=10000 24 | #LEARNING_RATE=1e-3 25 | 26 | CALC_DATA=1 27 | if [ $CALC_DATA -eq 1 ]; then 28 | 29 | if [ $INPUT_NOISY -eq 0 ]; then 30 | # Create mixture csv. 31 | echo "Go:Create mixture csv. " 32 | python prepare_data.py create_mixture_csv --workspace=$WORKSPACE --speech_dir=$DEMO_SPEECH_DIR --noise_dir=$DEMO_NOISE_DIR --data_type=test --speechratio=1 33 | 34 | # Calculate mixture features. 35 | echo "Go:Calculate mixture features. " 36 | python prepare_data.py calculate_mixture_features --workspace=$WORKSPACE --speech_dir=$DEMO_SPEECH_DIR --noise_dir=$DEMO_NOISE_DIR --data_type=test --snr=$TE_SNR 37 | else 38 | # Calculate noisy features. 39 | TE_SNR=1000 40 | echo "Go:Calculate noisy features. " 41 | python prepare_data.py calculate_noisy_features --workspace=$WORKSPACE --noisy_dir=$DEMO_NOISY_DIR --data_type=test --snr=$TE_SNR 42 | fi 43 | 44 | echo "Data finish!" 45 | #exit 46 | 47 | fi 48 | 49 | 50 | # Inference, enhanced wavs will be created. 51 | echo "Inference, enhanced wavs will be created. 
" 52 | CUDA_VISIBLE_DEVICES=0 python $CMD inference --workspace=$WORKSPACE --tr_snr=$TR_SNR --te_snr=$TE_SNR --n_concat=$N_CONCAT --iteration=$ITERATION --calc_log=$CALC_LOG --model_file=$MODEL_FILE 53 | 54 | 55 | -------------------------------------------------------------------------------- /demo_data/noise/noise1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/demo_data/noise/noise1.wav -------------------------------------------------------------------------------- /demo_data/noise/noise2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/demo_data/noise/noise2.wav -------------------------------------------------------------------------------- /demo_data/noisy/TEST_DR7_FCAU0_sx317.noise1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/demo_data/noisy/TEST_DR7_FCAU0_sx317.noise1.wav -------------------------------------------------------------------------------- /demo_data/noisy/TEST_DR7_FCAU0_sx317.noise2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/demo_data/noisy/TEST_DR7_FCAU0_sx317.noise2.wav -------------------------------------------------------------------------------- /demo_data/noisy/THCH_test_D8_770.noise1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/demo_data/noisy/THCH_test_D8_770.noise1.wav -------------------------------------------------------------------------------- /demo_data/noisy/THCH_test_D8_770.noise2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/demo_data/noisy/THCH_test_D8_770.noise2.wav -------------------------------------------------------------------------------- /demo_data/speech/TEST_DR7_FCAU0_sx317.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/demo_data/speech/TEST_DR7_FCAU0_sx317.wav -------------------------------------------------------------------------------- /demo_data/speech/THCH_test_D8_770.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/demo_data/speech/THCH_test_D8_770.wav -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Summary: Calculate PESQ and overal stats of enhanced speech. 3 | Author: Qiuqiang Kong 4 | Created: 2017.12.22 5 | Modified: - 6 | """ 7 | import argparse 8 | import os 9 | import csv 10 | import numpy as np 11 | import cPickle 12 | import matplotlib.pyplot as plt 13 | 14 | 15 | def plot_training_stat(args): 16 | """Plot training and testing loss. 17 | 18 | Args: 19 | workspace: str, path of workspace. 20 | tr_snr: float, training SNR. 
        bgn_iter: int, plot from bgn_iter.
        fin_iter: int, plot finishes at fin_iter.
        interval_iter: int, interval of files.
    """
    workspace = args.workspace
    tr_snr = args.tr_snr
    bgn_iter = args.bgn_iter
    fin_iter = args.fin_iter
    interval_iter = args.interval_iter

    tr_losses, te_losses, iters = [], [], []

    # Load stats.
    stats_dir = os.path.join(workspace, "training_stats", "%ddb" % int(tr_snr))
    for iter in xrange(bgn_iter, fin_iter, interval_iter):
        stats_path = os.path.join(stats_dir, "%diters.p" % iter)
        dict = cPickle.load(open(stats_path, 'rb'))
        tr_losses.append(dict['tr_loss'])
        te_losses.append(dict['te_loss'])
        iters.append(dict['iter'])

    # Plot.
    line_tr, = plt.plot(tr_losses, c='b', label="Train")
    line_te, = plt.plot(te_losses, c='r', label="Test")
    plt.axis([0, len(iters), 0, max(tr_losses)])
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend(handles=[line_tr, line_te])
    plt.xticks(np.arange(len(iters)), iters)
    plt.show()


def calculate_noisy_pesq(args):
    """Calculate PESQ of all noisy (mixed) speech.

    Args:
      workspace: str, path of workspace.
      speech_dir: str, path of clean speech.
      te_snr: float, testing SNR.
    """
    workspace = args.workspace
    speech_dir = args.speech_dir
    te_snr = args.te_snr

    # Remove previous result files if they exist.
    os.system('rm _pesq_itu_results.txt')
    os.system('rm _pesq_results.txt')

    # Calculate PESQ of all noisy speech.
    noisy_speech_dir = os.path.join(workspace, "mixed_audios", "spectrogram", "test", "%ddb" % int(te_snr))
    names = os.listdir(noisy_speech_dir)
    print(names)
    for (cnt, na) in enumerate(names):
        print(cnt, na)
        enh_path = os.path.join(noisy_speech_dir, na)

        speech_na = na.split('.')[0]
        speech_path = os.path.join(speech_dir, "%s.wav" % speech_na)

        # Call the executable PESQ tool.
        cmd = ' '.join(["./pesq", speech_path, enh_path, "+16000"])
        os.system(cmd)


def calculate_pesq(args):
    """Calculate PESQ of all enhanced speech.

    Args:
      workspace: str, path of workspace.
      speech_dir: str, path of clean speech.
      te_snr: float, testing SNR.
    """
    workspace = args.workspace
    speech_dir = args.speech_dir
    te_snr = args.te_snr

    # Remove previous result files if they exist.
    os.system('rm _pesq_itu_results.txt')
    os.system('rm _pesq_results.txt')

    # Calculate PESQ of all enhanced speech.
    enh_speech_dir = os.path.join(workspace, "enh_wavs", "test", "%ddb" % int(te_snr))
    names = os.listdir(enh_speech_dir)
    print(names)
    for (cnt, na) in enumerate(names):
        print(cnt, na)
        enh_path = os.path.join(enh_speech_dir, na)

        speech_na = na.split('.')[0]
        speech_path = os.path.join(speech_dir, "%s.wav" % speech_na)

        # Call the executable PESQ tool.
        cmd = ' '.join(["./pesq", speech_path, enh_path, "+16000"])
        os.system(cmd)


def get_stats(args):
    """Calculate stats of PESQ.
122 | """ 123 | pesq_path = "_pesq_results.txt" 124 | with open(pesq_path, 'rb') as f: 125 | reader = csv.reader(f, delimiter='\t') 126 | lis = list(reader) 127 | 128 | pesq_dict = {} 129 | for i1 in xrange(1, len(lis) - 1): 130 | li = lis[i1] 131 | na = li[0] 132 | pesq = float(li[1]) 133 | noise_type = na.split('.')[1] 134 | if noise_type not in pesq_dict.keys(): 135 | pesq_dict[noise_type] = [pesq] 136 | else: 137 | pesq_dict[noise_type].append(pesq) 138 | 139 | avg_list, std_list = [], [] 140 | f = "{0:<16} {1:<16}" 141 | print(f.format("Noise", "PESQ")) 142 | print("---------------------------------") 143 | for noise_type in pesq_dict.keys(): 144 | pesqs = pesq_dict[noise_type] 145 | avg_pesq = np.mean(pesqs) 146 | std_pesq = np.std(pesqs) 147 | avg_list.append(avg_pesq) 148 | std_list.append(std_pesq) 149 | print(f.format(noise_type, "%.2f +- %.2f" % (avg_pesq, std_pesq))) 150 | print("---------------------------------") 151 | print(f.format("Avg.", "%.2f +- %.2f" % (np.mean(avg_list), np.mean(std_list)))) 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser() 156 | subparsers = parser.add_subparsers(dest='mode') 157 | 158 | parser_plot_training_stat = subparsers.add_parser('plot_training_stat') 159 | parser_plot_training_stat.add_argument('--workspace', type=str, required=True) 160 | parser_plot_training_stat.add_argument('--tr_snr', type=float, required=True) 161 | parser_plot_training_stat.add_argument('--bgn_iter', type=int, required=True) 162 | parser_plot_training_stat.add_argument('--fin_iter', type=int, required=True) 163 | parser_plot_training_stat.add_argument('--interval_iter', type=int, required=True) 164 | 165 | parser_calculate_pesq = subparsers.add_parser('calculate_pesq') 166 | parser_calculate_pesq.add_argument('--workspace', type=str, required=True) 167 | parser_calculate_pesq.add_argument('--speech_dir', type=str, required=True) 168 | parser_calculate_pesq.add_argument('--te_snr', type=float, required=True) 169 | 170 | parser_calculate_pesq = subparsers.add_parser('calculate_noisy_pesq') 171 | parser_calculate_pesq.add_argument('--workspace', type=str, required=True) 172 | parser_calculate_pesq.add_argument('--speech_dir', type=str, required=True) 173 | parser_calculate_pesq.add_argument('--te_snr', type=float, required=True) 174 | 175 | parser_get_stats = subparsers.add_parser('get_stats') 176 | 177 | args = parser.parse_args() 178 | 179 | if args.mode == 'plot_training_stat': 180 | plot_training_stat(args) 181 | elif args.mode == 'calculate_pesq': 182 | calculate_pesq(args) 183 | elif args.mode == 'calculate_noisy_pesq': 184 | calculate_noisy_pesq(args) 185 | elif args.mode == 'get_stats': 186 | get_stats(args) 187 | else: 188 | raise Exception("Error!") 189 | -------------------------------------------------------------------------------- /ffmpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/ffmpeg -------------------------------------------------------------------------------- /main_dnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Summary: Train, inference and evaluate speech enhancement. 
3 | Author: Qiuqiang Kong 4 | Created: 2017.12.22 5 | Modified: - 6 | """ 7 | import numpy as np 8 | import os 9 | import pickle 10 | import cPickle 11 | import h5py 12 | import argparse 13 | import time 14 | import glob 15 | import matplotlib.pyplot as plt 16 | 17 | import prepare_data as pp_data 18 | import config as cfg 19 | from data_generator import DataGenerator 20 | from spectrogram_to_wave import recover_wav 21 | 22 | from keras.layers.normalization import BatchNormalization 23 | from keras.models import Sequential 24 | from keras.layers import Dense, Dropout, Flatten 25 | from keras.optimizers import Adam 26 | from keras.models import load_model 27 | 28 | 29 | def eval(model, gen, x, y): 30 | """Validation function. 31 | 32 | Args: 33 | model: keras model. 34 | gen: object, data generator. 35 | x: 3darray, input, (n_segs, n_concat, n_freq) 36 | y: 2darray, target, (n_segs, n_freq) 37 | """ 38 | pred_all, y_all = [], [] 39 | 40 | # Inference in mini batch. 41 | for (batch_x, batch_y) in gen.generate(xs=[x], ys=[y]): 42 | pred = model.predict(batch_x) 43 | pred_all.append(pred) 44 | y_all.append(batch_y) 45 | if False: 46 | print("pred") 47 | print(pred) 48 | 49 | # Concatenate mini batch prediction. 50 | pred_all = np.concatenate(pred_all, axis=0) 51 | y_all = np.concatenate(y_all, axis=0) 52 | 53 | 54 | 55 | # Compute loss. 56 | loss = pp_data.np_mean_absolute_error(y_all, pred_all) 57 | return loss 58 | 59 | 60 | def train(args): 61 | """Train the neural network. Write out model every several iterations. 62 | 63 | Args: 64 | workspace: str, path of workspace. 65 | tr_snr: float, training SNR. 66 | te_snr: float, testing SNR. 67 | lr: float, learning rate. 68 | """ 69 | print(args) 70 | workspace = args.workspace 71 | tr_snr = args.tr_snr 72 | te_snr = args.te_snr 73 | lr = args.lr 74 | calc_log = args.calc_log 75 | epoch = args.epoch 76 | 77 | scale = True 78 | 79 | # Load data. 80 | t1 = time.time() 81 | tr_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "%ddb" % int(tr_snr), "data.h5") 82 | te_hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", "test", "%ddb" % int(te_snr), "data.h5") 83 | (tr_x, tr_y) = pp_data.load_hdf5(tr_hdf5_path) 84 | (te_x, te_y) = pp_data.load_hdf5(te_hdf5_path) 85 | print(tr_x.shape, tr_y.shape) 86 | print(te_x.shape, te_y.shape) 87 | print("Load data time: %s s" % (time.time() - t1,)) 88 | 89 | batch_size = 128 #128 #500 90 | print("%d iterations / epoch" % int(tr_x.shape[0] / batch_size)) 91 | 92 | # Scale data. 93 | 94 | t1 = time.time() 95 | if calc_log: 96 | scaler_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "%ddb" % int(tr_snr), "scaler.p") 97 | scaler = pickle.load(open(scaler_path, 'rb')) 98 | tr_x = pp_data.scale_on_3d(tr_x, scaler) 99 | tr_y = pp_data.scale_on_2d(tr_y, scaler) 100 | te_x = pp_data.scale_on_3d(te_x, scaler) 101 | te_y = pp_data.scale_on_2d(te_y, scaler) 102 | else: 103 | print("max of tr_x:", np.max(tr_x)) 104 | print("max of tr_y:", np.max(tr_y)) 105 | print("max of te_x:", np.max(te_x)) 106 | print("max of te_y:", np.max(te_y)) 107 | tr_x = tr_x / np.max(tr_x) 108 | tr_y = tr_y / np.max(tr_y) 109 | te_x = te_x / np.max(te_x) 110 | te_y = te_y / np.max(te_y) 111 | 112 | print("Scale data time: %s s" % (time.time() - t1,)) 113 | 114 | # Debug plot. 
115 | if False: 116 | plt.matshow(tr_x[0 : 1000, 0, :].T, origin='lower', aspect='auto', cmap='jet') 117 | plt.show() 118 | pause 119 | 120 | # Build model 121 | (_, n_concat, n_freq) = tr_x.shape 122 | n_hid = 2048 123 | 124 | model = Sequential() 125 | model.add(Flatten(input_shape=(n_concat, n_freq))) 126 | model.add(Dropout(0.1)) 127 | model.add(Dense(n_hid, activation='relu')) 128 | model.add(Dense(n_hid, activation='relu')) 129 | model.add(Dense(n_hid, activation='relu')) 130 | model.add(BatchNormalization()) 131 | model.add(Dropout(0.2)) 132 | model.add(Dense(n_hid, activation='relu')) 133 | model.add(Dense(n_hid, activation='relu')) 134 | model.add(Dense(n_hid, activation='relu')) 135 | model.add(Dropout(0.2)) 136 | model.add(Dense(n_hid, activation='relu')) 137 | model.add(Dense(n_hid, activation='relu')) 138 | model.add(Dense(n_hid, activation='relu')) 139 | model.add(Dropout(0.2)) 140 | 141 | if calc_log: 142 | model.add(Dense(n_freq, activation='linear')) 143 | else: 144 | model.add(Dense(n_freq, activation='relu')) 145 | model.summary() 146 | 147 | model.compile(loss='mean_absolute_error', 148 | optimizer=Adam(lr=lr)) 149 | 150 | # Data generator. 151 | tr_gen = DataGenerator(batch_size=batch_size, type='train') 152 | eval_te_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 153 | eval_tr_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=100) 154 | 155 | # Directories for saving models and training stats 156 | model_dir = os.path.join(workspace, "models", "%ddb" % int(tr_snr)) 157 | pp_data.create_folder(model_dir) 158 | 159 | stats_dir = os.path.join(workspace, "training_stats", "%ddb" % int(tr_snr)) 160 | pp_data.create_folder(stats_dir) 161 | 162 | # Print loss before training. 163 | iter = 0 164 | tr_loss = eval(model, eval_tr_gen, tr_x, tr_y) 165 | te_loss = eval(model, eval_te_gen, te_x, te_y) 166 | if False: 167 | print("tr_x") 168 | print(tr_x) 169 | print("tr_y") 170 | print(tr_y) 171 | print("te_x") 172 | print(te_x) 173 | print("te_y") 174 | print(te_y) 175 | print("Iteration: %d, tr_loss: %2.20f, te_loss: %2.20f" % (iter, tr_loss, te_loss)) 176 | 177 | # Save out training stats. 178 | stat_dict = {'iter': iter, 179 | 'tr_loss': tr_loss, 180 | 'te_loss': te_loss, } 181 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 182 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 183 | 184 | # Train. 185 | t1 = time.time() 186 | for (batch_x, batch_y) in tr_gen.generate(xs=[tr_x], ys=[tr_y]): 187 | loss = model.train_on_batch(batch_x, batch_y) 188 | iter += 1 189 | 190 | # Validate and save training stats. 191 | if iter % 200 == 0: 192 | tr_loss = eval(model, eval_tr_gen, tr_x, tr_y) 193 | te_loss = eval(model, eval_te_gen, te_x, te_y) 194 | print("Iteration: %d, tr_loss: %2.20f, te_loss: %2.20f" % (iter, tr_loss, te_loss)) 195 | 196 | # Save out training stats. 197 | stat_dict = {'iter': iter, 198 | 'tr_loss': tr_loss, 199 | 'te_loss': te_loss, } 200 | stat_path = os.path.join(stats_dir, "%diters.p" % iter) 201 | cPickle.dump(stat_dict, open(stat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 202 | 203 | # Save model. 
204 | if iter % 200 == 0: 205 | model_path = os.path.join(model_dir, "md_%diters.h5" % iter) 206 | print model_path 207 | #model.save(model_path) 208 | model.save_weights(model_path) 209 | print("Saved model to %s" % model_path) 210 | 211 | 212 | 213 | if iter == (epoch+1): 214 | break 215 | 216 | print("Training time: %s s" % (time.time() - t1,)) 217 | 218 | def inference(args): 219 | """Inference all test data, write out recovered wavs to disk. 220 | 221 | Args: 222 | workspace: str, path of workspace. 223 | tr_snr: float, training SNR. 224 | te_snr: float, testing SNR. 225 | n_concat: int, number of frames to concatenta, should equal to n_concat 226 | in the training stage. 227 | iter: int, iteration of model to load. 228 | visualize: bool, plot enhanced spectrogram for debug. 229 | """ 230 | print(args) 231 | workspace = args.workspace 232 | tr_snr = args.tr_snr 233 | te_snr = args.te_snr 234 | n_concat = args.n_concat 235 | iter = args.iteration 236 | calc_log = args.calc_log 237 | model_file = args.model_file 238 | 239 | n_window = cfg.n_window 240 | n_overlap = cfg.n_overlap 241 | fs = cfg.sample_rate 242 | scale = True 243 | 244 | 245 | # Build model 246 | n_concat = 7 247 | n_freq = 257 248 | n_hid = 2048 249 | lr = 1e-3 250 | 251 | model = Sequential() 252 | model.add(Flatten(input_shape=(n_concat, n_freq))) 253 | model.add(Dropout(0.1)) 254 | model.add(Dense(n_hid, activation='relu')) 255 | model.add(Dense(n_hid, activation='relu')) 256 | model.add(Dense(n_hid, activation='relu')) 257 | model.add(BatchNormalization()) 258 | model.add(Dropout(0.2)) 259 | model.add(Dense(n_hid, activation='relu')) 260 | model.add(Dense(n_hid, activation='relu')) 261 | model.add(Dense(n_hid, activation='relu')) 262 | model.add(Dropout(0.2)) 263 | model.add(Dense(n_hid, activation='relu')) 264 | model.add(Dense(n_hid, activation='relu')) 265 | model.add(Dense(n_hid, activation='relu')) 266 | model.add(Dropout(0.2)) 267 | if calc_log: 268 | model.add(Dense(n_freq, activation='linear')) 269 | else: 270 | model.add(Dense(n_freq, activation='relu')) 271 | model.summary() 272 | 273 | model.compile(loss='mean_absolute_error', 274 | optimizer=Adam(lr=lr)) 275 | 276 | 277 | 278 | # Load model. 279 | if (model_file=="null"): 280 | model_path = os.path.join(workspace, "models", "%ddb" % int(tr_snr), "md_%diters.h5" % iter) 281 | #model = load_model(model_path) 282 | model.load_weights(model_path) 283 | else: 284 | model.load_weights(model_file) 285 | 286 | 287 | # Load scaler. 288 | if calc_log: 289 | scaler_path = os.path.join(workspace, "packed_features", "spectrogram", "train", "%ddb" % int(tr_snr), "scaler.p") 290 | scaler = pickle.load(open(scaler_path, 'rb')) 291 | 292 | # Load test data. 293 | feat_dir = os.path.join(workspace, "features", "spectrogram", "test", "%ddb" % int(te_snr)) 294 | names = os.listdir(feat_dir) 295 | 296 | for (cnt, na) in enumerate(names): 297 | # Load feature. 298 | feat_path = os.path.join(feat_dir, na) 299 | data = cPickle.load(open(feat_path, 'rb')) 300 | [mixed_cmplx_x, speech_x, noise_x, alpha, na] = data 301 | mixed_x = np.abs(mixed_cmplx_x) 302 | 303 | # Process data. 304 | n_pad = (n_concat - 1) / 2 305 | mixed_x = pp_data.pad_with_border(mixed_x, n_pad) 306 | if calc_log: 307 | mixed_x = pp_data.log_sp(mixed_x) 308 | #speech_x = pp_data.log_sp(speech_x) 309 | else: 310 | mixed_x = mixed_x 311 | #speech_x = speech_x 312 | 313 | # Scale data. 
314 | if calc_log: 315 | mixed_x = pp_data.scale_on_2d(mixed_x, scaler) 316 | #speech_x = pp_data.scale_on_2d(speech_x, scaler) 317 | else: 318 | mixed_x_max = np.max(mixed_x) 319 | print("max of tr_x:", mixed_x_max) 320 | mixed_x = mixed_x / mixed_x_max 321 | 322 | speech_x_max = np.max(speech_x) 323 | print("max of speech_x:", speech_x_max) 324 | speech_x = speech_x / speech_x_max 325 | 326 | 327 | # Cut input spectrogram to 3D segments with n_concat. 328 | mixed_x_3d = pp_data.mat_2d_to_3d(mixed_x, agg_num=n_concat, hop=1) 329 | 330 | # Predict. 331 | if False: 332 | print(mixed_x_3d) 333 | pred = model.predict(mixed_x_3d) 334 | print(cnt, na) 335 | if False: 336 | print("pred") 337 | print(pred) 338 | print("speech") 339 | print(speech_x) 340 | 341 | # Inverse scale. 342 | if calc_log: 343 | mixed_x = pp_data.inverse_scale_on_2d(mixed_x, scaler) 344 | #speech_x = pp_data.inverse_scale_on_2d(speech_x, scaler) 345 | pred = pp_data.inverse_scale_on_2d(pred, scaler) 346 | else: 347 | mixed_x = mixed_x * mixed_x_max 348 | #speech_x = speech_x * 16384 349 | pred = pred * mixed_x_max 350 | 351 | # Debug plot. 352 | if args.visualize: 353 | fig, axs = plt.subplots(3,1, sharex=False) 354 | axs[0].matshow(mixed_x.T, origin='lower', aspect='auto', cmap='jet') 355 | #axs[1].matshow(speech_x.T, origin='lower', aspect='auto', cmap='jet') 356 | axs[2].matshow(pred.T, origin='lower', aspect='auto', cmap='jet') 357 | axs[0].set_title("%ddb mixture log spectrogram" % int(te_snr)) 358 | axs[1].set_title("Clean speech log spectrogram") 359 | axs[2].set_title("Enhanced speech log spectrogram") 360 | for j1 in xrange(3): 361 | axs[j1].xaxis.tick_bottom() 362 | plt.tight_layout() 363 | plt.show() 364 | 365 | # Recover enhanced wav. 366 | if calc_log: 367 | pred_sp = np.exp(pred) 368 | else: 369 | #gv = 0.025 370 | #pred_sp = np.maximum(0,pred - gv) 371 | pred_sp = pred 372 | 373 | if False: 374 | pred_sp = mixed_x[3:-3] 375 | 376 | s = recover_wav(pred_sp, mixed_cmplx_x, n_overlap, np.hamming) 377 | s *= np.sqrt((np.hamming(n_window)**2).sum()) # Scaler for compensate the amplitude 378 | # change after spectrogram and IFFT. 379 | 380 | # Write out enhanced wav. 381 | out_path = os.path.join(workspace, "enh_wavs", "test", "%ddb" % int(te_snr), "%s.enh.wav" % na) 382 | pp_data.create_folder(os.path.dirname(out_path)) 383 | pp_data.write_audio(out_path, s, fs) 384 | # Write out enhanced pcm 8K pcm_s16le. 385 | out_pcm_path = os.path.join(workspace, "enh_wavs", "test", "%ddb" % int(te_snr), "%s.enh.pcm" % na) 386 | cmd = ' '.join(["./ffmpeg -y -i ", out_path, " -f s16le -ar 8000 -ac 1 -acodec pcm_s16le ", out_pcm_path]) 387 | os.system(cmd) 388 | 389 | # Write out webrtc-denoised enhanced pcm 8K pcm_s16le. 
390 | ns_out_pcm_path = os.path.join(workspace, "ns_enh_wavs", "test", "%ddb" % int(te_snr), "%s.ns_enh.pcm" % na) 391 | ns_out_wav_path = os.path.join(workspace, "ns_enh_wavs", "test", "%ddb" % int(te_snr), "%s.ns_enh.wav" % na) 392 | pp_data.create_folder(os.path.dirname(ns_out_pcm_path)) 393 | cmd = ' '.join(["./ns", out_pcm_path, ns_out_pcm_path]) 394 | os.system(cmd) 395 | cmd = ' '.join(["./ffmpeg -y -f s16le -ar 8000 -ac 1 -acodec pcm_s16le -i ", ns_out_pcm_path, " ",ns_out_wav_path]) 396 | os.system(cmd) 397 | 398 | cmd = ' '.join(["rm ", out_pcm_path]) 399 | os.system(cmd) 400 | cmd = ' '.join(["rm ", ns_out_pcm_path]) 401 | os.system(cmd) 402 | 403 | 404 | if __name__ == '__main__': 405 | parser = argparse.ArgumentParser() 406 | subparsers = parser.add_subparsers(dest='mode') 407 | 408 | parser_train = subparsers.add_parser('train') 409 | parser_train.add_argument('--workspace', type=str, required=True) 410 | parser_train.add_argument('--tr_snr', type=float, required=True) 411 | parser_train.add_argument('--te_snr', type=float, required=True) 412 | parser_train.add_argument('--lr', type=float, required=True) 413 | parser_train.add_argument('--calc_log', type=int, required=True) 414 | parser_train.add_argument('--epoch', type=int, required=True) 415 | 416 | parser_inference = subparsers.add_parser('inference') 417 | parser_inference.add_argument('--workspace', type=str, required=True) 418 | parser_inference.add_argument('--tr_snr', type=float, required=True) 419 | parser_inference.add_argument('--te_snr', type=float, required=True) 420 | parser_inference.add_argument('--n_concat', type=int, required=True) 421 | parser_inference.add_argument('--iteration', type=int, required=True) 422 | parser_inference.add_argument('--calc_log', type=int, required=True) 423 | parser_inference.add_argument('--model_file', type=str, required=True) 424 | parser_inference.add_argument('--visualize', action='store_true', default=False) 425 | 426 | 427 | parser_calculate_pesq = subparsers.add_parser('calculate_pesq') 428 | parser_calculate_pesq.add_argument('--workspace', type=str, required=True) 429 | parser_calculate_pesq.add_argument('--speech_dir', type=str, required=True) 430 | parser_calculate_pesq.add_argument('--te_snr', type=float, required=True) 431 | 432 | args = parser.parse_args() 433 | 434 | if args.mode == 'train': 435 | train(args) 436 | elif args.mode == 'inference': 437 | inference(args) 438 | elif args.mode == 'calculate_pesq': 439 | calculate_pesq(args) 440 | else: 441 | raise Exception("Error!") 442 | -------------------------------------------------------------------------------- /models/pretrained/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ClearWave Net: Simple DNN 4 | 5 | the Weights: base_dnn_model.h5 (Please download weights file from https://pan.baidu.com/s/1eVnRkNb5xIn96aYOV8C-Gg to ./models/pretrained/) 6 | -------------------------------------------------------------------------------- /notes/THCH_test_D8_770-.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/THCH_test_D8_770-.wav -------------------------------------------------------------------------------- /notes/THCH_test_D8_770.noise1.ns_enh.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/THCH_test_D8_770.noise1.ns_enh.wav -------------------------------------------------------------------------------- /notes/THCH_test_D8_770.noise1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/THCH_test_D8_770.noise1.wav -------------------------------------------------------------------------------- /notes/THCH_test_D8_770.noise2.ns_enh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/THCH_test_D8_770.noise2.ns_enh.wav -------------------------------------------------------------------------------- /notes/THCH_test_D8_770.noise2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/THCH_test_D8_770.noise2.wav -------------------------------------------------------------------------------- /notes/clear-d8-770.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/clear-d8-770.jpg -------------------------------------------------------------------------------- /notes/denoised-noise1-d8-770.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/denoised-noise1-d8-770.jpg -------------------------------------------------------------------------------- /notes/denoised-noise2-d8-770.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/denoised-noise2-d8-770.jpg -------------------------------------------------------------------------------- /notes/noise1-d8-770.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/noise1-d8-770.jpg -------------------------------------------------------------------------------- /notes/noise2-d8-770.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/noise2-d8-770.jpg -------------------------------------------------------------------------------- /notes/paypal.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/paypal.jpg -------------------------------------------------------------------------------- /notes/wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/notes/wechat.jpg -------------------------------------------------------------------------------- /ns: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/ns 
-------------------------------------------------------------------------------- /pesq: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boozyguo/ClearWave/3e652d50114bea66f817a6f2d1d057b3c4071f37/pesq -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Summary: Prepare data. 3 | Author: Qiuqiang Kong 4 | Created: 2017.12.22 5 | Modified: - 6 | """ 7 | import os 8 | import soundfile 9 | import numpy as np 10 | import argparse 11 | import csv 12 | import time 13 | import matplotlib.pyplot as plt 14 | from scipy import signal 15 | import pickle 16 | import cPickle 17 | import h5py 18 | import librosa 19 | from sklearn import preprocessing 20 | 21 | import prepare_data as pp_data 22 | import config as cfg 23 | 24 | 25 | def create_folder(fd): 26 | if not os.path.exists(fd): 27 | os.makedirs(fd) 28 | 29 | def read_audio(path, target_fs=None): 30 | print 'file is:' 31 | print path 32 | (audio, fs) = soundfile.read(path) 33 | if audio.ndim > 1: 34 | audio = np.mean(audio, axis=1) 35 | if target_fs is not None and fs != target_fs: 36 | audio = librosa.resample(audio, orig_sr=fs, target_sr=target_fs) 37 | fs = target_fs 38 | return audio, fs 39 | 40 | def write_audio(path, audio, sample_rate): 41 | soundfile.write(file=path, data=audio, samplerate=sample_rate) 42 | 43 | ### 44 | def create_mixture_csv(args): 45 | """Create csv containing mixture information. 46 | Each line in the .csv file contains [speech_name, noise_name, noise_onset, noise_offset] 47 | 48 | Args: 49 | workspace: str, path of workspace. 50 | speech_dir: str, path of speech data. 51 | noise_dir: str, path of noise data. 52 | data_type: str, 'train' | 'test'. 53 | magnification: int, only used when data_type='train', number of noise 54 | selected to mix with a speech. E.g., when magnication=3, then 4620 55 | speech with create 4620*3 mixtures. magnification should not larger 56 | than the species of noises. 57 | """ 58 | workspace = args.workspace 59 | speech_dir = args.speech_dir 60 | noise_dir = args.noise_dir 61 | data_type = args.data_type 62 | speechratio = args.speechratio 63 | magnification = args.magnification 64 | fs = cfg.sample_rate 65 | 66 | noise_onset = 0 67 | nosie_offset = 16384 68 | 69 | all_speech_names = [na for na in os.listdir(speech_dir) if na.lower().endswith(".wav")] 70 | noise_names = [na for na in os.listdir(noise_dir) if na.lower().endswith(".wav")] 71 | 72 | speech_rs = np.random.RandomState(10) 73 | print int(len(all_speech_names)/speechratio) 74 | speech_names = speech_rs.choice(all_speech_names, size=int(len(all_speech_names)/speechratio), replace=False) 75 | 76 | rs = np.random.RandomState(0) 77 | out_csv_path = os.path.join(workspace, "mixture_csvs", "%s.csv" % data_type) 78 | pp_data.create_folder(os.path.dirname(out_csv_path)) 79 | 80 | cnt = 0 81 | f = open(out_csv_path, 'w') 82 | f.write("%s\t%s\t%s\t%s\n" % ("speech_name", "noise_name", "noise_onset", "noise_offset")) 83 | for speech_na in speech_names: 84 | # Read speech. 85 | speech_path = os.path.join(speech_dir, speech_na) 86 | #by gm(speech_audio, _) = read_audio(speech_path,target_fs=fs) 87 | #by gmlen_speech = len(speech_audio) 88 | 89 | # For training data, mix each speech with randomly picked #magnification noises. 
90 | if data_type == 'train': 91 | selected_noise_names = rs.choice(noise_names, size=magnification, replace=False) 92 | # For test data, mix each speech with all noises. 93 | elif data_type == 'test': 94 | selected_noise_names = noise_names 95 | else: 96 | raise Exception("data_type must be train | test!") 97 | 98 | # Mix one speech with different noises many times. 99 | for noise_na in selected_noise_names: 100 | #by gm noise_path = os.path.join(noise_dir, noise_na) 101 | #by gm (noise_audio, _) = read_audio(noise_path,target_fs=fs) 102 | 103 | #by gm len_noise = len(noise_audio) 104 | 105 | #by gm 106 | ''' 107 | if len_noise <= len_speech: 108 | noise_onset = 0 109 | nosie_offset = len_speech 110 | # If noise longer than speech then randomly select a segment of noise. 111 | else: 112 | noise_onset = rs.randint(0, len_noise - len_speech, size=1)[0] 113 | nosie_offset = noise_onset + len_speech 114 | ''' 115 | if cnt % 100 == 0: 116 | print cnt 117 | 118 | cnt += 1 119 | f.write("%s\t%s\t%d\t%d\n" % (speech_na, noise_na, noise_onset, nosie_offset)) 120 | f.close() 121 | print(out_csv_path) 122 | print("Create %s mixture csv finished!" % data_type) 123 | 124 | ### 125 | def calculate_mixture_features(args): 126 | """Calculate spectrogram for mixed, speech and noise audio. Then write the 127 | features to disk. 128 | 129 | Args: 130 | workspace: str, path of workspace. 131 | speech_dir: str, path of speech data. 132 | noise_dir: str, path of noise data. 133 | data_type: str, 'train' | 'test'. 134 | snr: float, signal to noise ratio to be mixed. 135 | """ 136 | workspace = args.workspace 137 | speech_dir = args.speech_dir 138 | noise_dir = args.noise_dir 139 | data_type = args.data_type 140 | snr = args.snr 141 | fs = cfg.sample_rate 142 | 143 | # Open mixture csv. 144 | mixture_csv_path = os.path.join(workspace, "mixture_csvs", "%s.csv" % data_type) 145 | with open(mixture_csv_path, 'rb') as f: 146 | reader = csv.reader(f, delimiter='\t') 147 | lis = list(reader) 148 | 149 | t1 = time.time() 150 | cnt = 0 151 | for i1 in xrange(1, len(lis)): 152 | [speech_na, noise_na, noise_onset, noise_offset] = lis[i1] 153 | noise_onset = int(noise_onset) 154 | noise_offset = int(noise_offset) 155 | 156 | # Read speech audio. 157 | speech_path = os.path.join(speech_dir, speech_na) 158 | (speech_audio, _) = read_audio(speech_path, target_fs=fs) 159 | 160 | # Read noise audio. 161 | noise_path = os.path.join(noise_dir, noise_na) 162 | (noise_audio, _) = read_audio(noise_path, target_fs=fs) 163 | 164 | # Repeat noise to the same length as speech. 165 | if len(noise_audio) < len(speech_audio): 166 | n_repeat = int(np.ceil(float(len(speech_audio)) / float(len(noise_audio)))) 167 | noise_audio_ex = np.tile(noise_audio, n_repeat) 168 | noise_audio = noise_audio_ex[0 : len(speech_audio)] 169 | elif len(speech_audio) < len(noise_audio): 170 | n_repeat = int(np.ceil(float(len(noise_audio)) / float(len(speech_audio)))) 171 | speech_audio_ex = np.tile(speech_audio, n_repeat) 172 | speech_audio = speech_audio_ex[0 : len(noise_audio)] 173 | # Truncate noise to the same length as speech. 174 | else: 175 | #noise_audio = noise_audio[noise_onset : noise_offset] 176 | noise_audio = noise_audio 177 | speech_audio = speech_audio 178 | 179 | # Scale speech to given snr. 180 | scaler = get_amplitude_scaling_factor(speech_audio, noise_audio, snr=snr) 181 | speech_audio *= scaler 182 | 183 | # Get normalized mixture, speech, noise. 
184 | (mixed_audio, speech_audio, noise_audio, alpha) = additive_mixing(speech_audio, noise_audio) 185 | 186 | # Write out normalized clean audio. 187 | out_bare_clear = os.path.join("%s-" % 188 | (os.path.splitext(speech_na)[0])) 189 | out_audio_path_clear = os.path.join(workspace, "normalized_clear_audios", "spectrogram", 190 | data_type, "%ddb" % int(snr), "%s.wav" % out_bare_clear) 191 | create_folder(os.path.dirname(out_audio_path_clear)) 192 | write_audio(out_audio_path_clear, speech_audio, fs) 193 | 194 | # Write out mixed audio. 195 | out_bare_na = os.path.join("%s.%s" % 196 | (os.path.splitext(speech_na)[0], os.path.splitext(noise_na)[0])) 197 | out_audio_path = os.path.join(workspace, "mixed_audios", "spectrogram", 198 | data_type, "%ddb" % int(snr), "%s.wav" % out_bare_na) 199 | create_folder(os.path.dirname(out_audio_path)) 200 | write_audio(out_audio_path, mixed_audio, fs) 201 | 202 | # Extract spectrogram. 203 | mixed_complx_x = calc_sp(mixed_audio, mode='complex') 204 | speech_x = calc_sp(speech_audio, mode='magnitude') 205 | noise_x = calc_sp(noise_audio, mode='magnitude') 206 | 207 | # Write out features. 208 | out_feat_path = os.path.join(workspace, "features", "spectrogram", 209 | data_type, "%ddb" % int(snr), "%s.p" % out_bare_na) 210 | create_folder(os.path.dirname(out_feat_path)) 211 | data = [mixed_complx_x, speech_x, noise_x, alpha, out_bare_na] 212 | print (mixed_complx_x.shape) 213 | print (speech_x.shape) 214 | print (noise_x.shape) 215 | print ("alpha:") 216 | print (alpha) 217 | print ("out_bare_na:") 218 | print (out_bare_na) 219 | cPickle.dump(data, open(out_feat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 220 | 221 | # Print. 222 | if cnt % 100 == 0: 223 | print(cnt) 224 | 225 | cnt += 1 226 | 227 | print("Extracting feature time: %s" % (time.time() - t1)) 228 | 229 | 230 | 231 | ### 232 | def calculate_noisy_features(args): 233 | """Calculate spectrogram for mixed, speech and noise audio. Then write the 234 | features to disk. 235 | 236 | Args: 237 | workspace: str, path of workspace. 238 | speech_dir: str, path of speech data. 239 | noise_dir: str, path of noise data. 240 | data_type: str, 'train' | 'test'. 241 | snr: float, signal to noise ratio to be mixed. 242 | """ 243 | workspace = args.workspace 244 | noisy_dir = args.noisy_dir 245 | data_type = args.data_type 246 | snr = args.snr 247 | fs = cfg.sample_rate 248 | 249 | t1 = time.time() 250 | cnt = 0 251 | 252 | # Extract spectrogram of all noisy_speech. 253 | noisy_speech_dir = os.path.join(noisy_dir) 254 | names = os.listdir(noisy_speech_dir) 255 | print(names) 256 | for (cnt, na) in enumerate(names): 257 | print(cnt, na) 258 | noisy_file_path = os.path.join(noisy_speech_dir, na) 259 | 260 | print(noisy_file_path) 261 | 262 | # Read noisy audio. 263 | (noisy_audio, _) = read_audio(noisy_file_path, target_fs=fs) 264 | 265 | # Extract spectrogram. 266 | mixed_complx_x = calc_sp(noisy_audio, mode='complex') 267 | speech_x = 1e-8 268 | noise_x = 1e-8 269 | alpha = 1e-8 270 | out_bare_na = na 271 | 272 | # Write out features. 
273 | out_feat_path = os.path.join(workspace, "features", "spectrogram", 274 | data_type, "%ddb" % int(snr), "%s.p" % na) 275 | create_folder(os.path.dirname(out_feat_path)) 276 | data = [mixed_complx_x, speech_x, noise_x, alpha, out_bare_na] 277 | print (mixed_complx_x.shape) 278 | print ("alpha:") 279 | print (alpha) 280 | print ("out_bare_na:") 281 | print (out_bare_na) 282 | cPickle.dump(data, open(out_feat_path, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL) 283 | print(out_feat_path) 284 | # Print. 285 | if cnt % 100 == 0: 286 | print(cnt) 287 | 288 | cnt += 1 289 | 290 | print("Extracting feature time: %s" % (time.time() - t1)) 291 | 292 | 293 | 294 | 295 | def rms(y): 296 | """Root mean square. 297 | """ 298 | return np.sqrt(np.mean(np.abs(y) ** 2, axis=0, keepdims=False)) 299 | 300 | def get_amplitude_scaling_factor(s, n, snr, method='rms'): 301 | """Given s and n, return the scaler s according to the snr. 302 | 303 | Args: 304 | s: ndarray, source1. 305 | n: ndarray, source2. 306 | snr: float, SNR. 307 | method: 'rms'. 308 | 309 | Outputs: 310 | float, scaler. 311 | """ 312 | original_sn_rms_ratio = rms(s) / rms(n) 313 | target_sn_rms_ratio = 10. ** (float(snr) / 20.) # snr = 20 * lg(rms(s) / rms(n)) 314 | signal_scaling_factor = target_sn_rms_ratio / original_sn_rms_ratio 315 | return signal_scaling_factor 316 | 317 | def additive_mixing(s, n): 318 | """Mix normalized source1 and source2. 319 | 320 | Args: 321 | s: ndarray, source1. 322 | n: ndarray, source2. 323 | 324 | Returns: 325 | mix_audio: ndarray, mixed audio. 326 | s: ndarray, pad or truncated and scalered source1. 327 | n: ndarray, scaled source2. 328 | alpha: float, normalize coefficient. 329 | """ 330 | mixed_audio = s + n 331 | 332 | alpha = 1. / np.max(np.abs(mixed_audio)) 333 | mixed_audio *= alpha 334 | s *= alpha 335 | n *= alpha 336 | return mixed_audio, s, n, alpha 337 | 338 | def calc_sp(audio, mode): 339 | """Calculate spectrogram. 340 | 341 | Args: 342 | audio: 1darray. 343 | mode: string, 'magnitude' | 'complex' 344 | 345 | Returns: 346 | spectrogram: 2darray, (n_time, n_freq). 347 | """ 348 | n_window = cfg.n_window 349 | n_overlap = cfg.n_overlap 350 | ham_win = np.hamming(n_window) 351 | [f, t, x] = signal.spectral.spectrogram( 352 | audio, 353 | window=ham_win, 354 | nperseg=n_window, 355 | noverlap=n_overlap, 356 | detrend=False, 357 | return_onesided=True, 358 | mode=mode) 359 | x = x.T 360 | if mode == 'magnitude': 361 | x = x.astype(np.float32) 362 | elif mode == 'complex': 363 | x = x.astype(np.complex64) 364 | else: 365 | raise Exception("Incorrect mode!") 366 | return x 367 | 368 | ### 369 | def pack_features(args): 370 | """Load all features, apply log and conver to 3D tensor, write out to .h5 file. 371 | 372 | Args: 373 | workspace: str, path of workspace. 374 | data_type: str, 'train' | 'test'. 375 | snr: float, signal to noise ratio to be mixed. 376 | n_concat: int, number of frames to be concatenated. 377 | n_hop: int, hop frames. 378 | """ 379 | workspace = args.workspace 380 | data_type = args.data_type 381 | snr = args.snr 382 | n_concat = args.n_concat 383 | n_hop = args.n_hop 384 | calc_log = args.calc_log 385 | 386 | x_all = [] # (n_segs, n_concat, n_freq) 387 | y_all = [] # (n_segs, n_freq) 388 | 389 | cnt = 0 390 | t1 = time.time() 391 | 392 | # Load all features. 393 | feat_dir = os.path.join(workspace, "features", "spectrogram", data_type, "%ddb" % int(snr)) 394 | names = os.listdir(feat_dir) 395 | for na in names: 396 | # Load feature. 
        feat_path = os.path.join(feat_dir, na)
        data = cPickle.load(open(feat_path, 'rb'))
        [mixed_complx_x, speech_x, noise_x, alpha, na] = data
        mixed_x = np.abs(mixed_complx_x)

        # Pad start and end of the spectrogram with border values.
        n_pad = (n_concat - 1) // 2
        mixed_x = pad_with_border(mixed_x, n_pad)
        speech_x = pad_with_border(speech_x, n_pad)

        # Cut input spectrogram into 3D segments of n_concat frames.
        mixed_x_3d = mat_2d_to_3d(mixed_x, agg_num=n_concat, hop=n_hop)
        x_all.append(mixed_x_3d)
        print("shape:")
        print(mixed_x.shape)
        print(mixed_x_3d.shape)

        # Cut target spectrogram and take the center frame of each 3D segment.
        speech_x_3d = mat_2d_to_3d(speech_x, agg_num=n_concat, hop=n_hop)
        y = speech_x_3d[:, (n_concat - 1) // 2, :]
        y_all.append(y)
        print("shape:")
        print(speech_x.shape)
        print(speech_x_3d.shape)
        print(y.shape)

        # Print progress.
        if cnt % 100 == 0:
            print(cnt)

        # if cnt == 3: break
        cnt += 1

    x_all = np.concatenate(x_all, axis=0)    # (n_segs, n_concat, n_freq)
    y_all = np.concatenate(y_all, axis=0)    # (n_segs, n_freq)

    if calc_log:
        x_all = log_sp(x_all).astype(np.float32)
        y_all = log_sp(y_all).astype(np.float32)
    else:
        x_all = x_all.astype(np.float32)
        y_all = y_all.astype(np.float32)

    # Write out data to .h5 file.
    out_path = os.path.join(workspace, "packed_features", "spectrogram", data_type, "%ddb" % int(snr), "data.h5")
    create_folder(os.path.dirname(out_path))
    with h5py.File(out_path, 'w') as hf:
        hf.create_dataset('x', data=x_all)
        hf.create_dataset('y', data=y_all)

    print("Write out to %s" % out_path)
    print("Pack features finished! %s s" % (time.time() - t1,))

def log_sp(x):
    return np.log(x + 1e-08)

def mat_2d_to_3d(x, agg_num, hop):
    """Segment a 2D array into overlapping 3D segments.
    """
    # Pad to at least one block.
    len_x, n_in = x.shape
    if len_x < agg_num:
        x = np.concatenate((x, np.zeros((agg_num - len_x, n_in))))

    # Segment 2d to 3d.
    len_x = len(x)
    i1 = 0
    x3d = []
    while i1 + agg_num <= len_x:
        x3d.append(x[i1 : i1 + agg_num])
        i1 += hop
    return np.array(x3d)

def pad_with_border(x, n_pad):
    """Pad the beginning and end of a spectrogram with its border frames.
    """
    x_pad_list = [x[0:1]] * n_pad + [x] + [x[-1:]] * n_pad
    return np.concatenate(x_pad_list, axis=0)

###
def compute_scaler(args):
    """Compute and write out a scaler for the packed data.
    """
    workspace = args.workspace
    data_type = args.data_type
    snr = args.snr

    # Load data.
    t1 = time.time()
    hdf5_path = os.path.join(workspace, "packed_features", "spectrogram", data_type, "%ddb" % int(snr), "data.h5")
    with h5py.File(hdf5_path, 'r') as hf:
        x = hf.get('x')
        x = np.array(x)    # (n_segs, n_concat, n_freq)

    # Compute scaler.
    (n_segs, n_concat, n_freq) = x.shape
    x2d = x.reshape((n_segs * n_concat, n_freq))
    scaler = preprocessing.StandardScaler(with_mean=True, with_std=True).fit(x2d)
    print(scaler.mean_)
    print(scaler.scale_)

    # Write out scaler.
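    # The scaler stores one mean and standard deviation per frequency bin,
    # computed over all frames; scale_on_2d / scale_on_3d apply it to the
    # data, and inverse_scale_on_2d undoes it on model predictions.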
    out_path = os.path.join(workspace, "packed_features", "spectrogram", data_type, "%ddb" % int(snr), "scaler.p")
    create_folder(os.path.dirname(out_path))
    pickle.dump(scaler, open(out_path, 'wb'))

    print("Save scaler to %s" % out_path)
    print("Compute scaler finished! %s s" % (time.time() - t1,))

def scale_on_2d(x2d, scaler):
    """Scale 2D array data.
    """
    return scaler.transform(x2d)

def scale_on_3d(x3d, scaler):
    """Scale 3D array data.
    """
    (n_segs, n_concat, n_freq) = x3d.shape
    x2d = x3d.reshape((n_segs * n_concat, n_freq))
    x2d = scaler.transform(x2d)
    x3d = x2d.reshape((n_segs, n_concat, n_freq))
    return x3d

def inverse_scale_on_2d(x2d, scaler):
    """Inverse scale 2D array data.
    """
    return x2d * scaler.scale_[None, :] + scaler.mean_[None, :]

###
def load_hdf5(hdf5_path):
    """Load hdf5 data.
    """
    with h5py.File(hdf5_path, 'r') as hf:
        x = hf.get('x')
        y = hf.get('y')
        x = np.array(x)    # (n_segs, n_concat, n_freq)
        y = np.array(y)    # (n_segs, n_freq)
    return x, y

def np_mean_absolute_error(y_true, y_pred):
    return np.mean(np.abs(y_pred - y_true))

###
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='mode')

    parser_create_mixture_csv = subparsers.add_parser('create_mixture_csv')
    parser_create_mixture_csv.add_argument('--workspace', type=str, required=True)
    parser_create_mixture_csv.add_argument('--speech_dir', type=str, required=True)
    parser_create_mixture_csv.add_argument('--noise_dir', type=str, required=True)
    parser_create_mixture_csv.add_argument('--data_type', type=str, required=True)
    parser_create_mixture_csv.add_argument('--speechratio', type=int, default=1)
    parser_create_mixture_csv.add_argument('--magnification', type=int, default=1)

    parser_calculate_mixture_features = subparsers.add_parser('calculate_mixture_features')
    parser_calculate_mixture_features.add_argument('--workspace', type=str, required=True)
    parser_calculate_mixture_features.add_argument('--speech_dir', type=str, required=True)
    parser_calculate_mixture_features.add_argument('--noise_dir', type=str, required=True)
    parser_calculate_mixture_features.add_argument('--data_type', type=str, required=True)
    parser_calculate_mixture_features.add_argument('--snr', type=float, required=True)

    parser_pack_features = subparsers.add_parser('pack_features')
    parser_pack_features.add_argument('--workspace', type=str, required=True)
    parser_pack_features.add_argument('--data_type', type=str, required=True)
    parser_pack_features.add_argument('--snr', type=float, required=True)
    parser_pack_features.add_argument('--n_concat', type=int, required=True)
    parser_pack_features.add_argument('--n_hop', type=int, required=True)
    parser_pack_features.add_argument('--calc_log', type=int, required=True)

    parser_compute_scaler = subparsers.add_parser('compute_scaler')
    parser_compute_scaler.add_argument('--workspace', type=str, required=True)
    parser_compute_scaler.add_argument('--data_type', type=str, required=True)
    parser_compute_scaler.add_argument('--snr', type=float, required=True)

    parser_calculate_noisy_features = subparsers.add_parser('calculate_noisy_features')
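    # Note: for calculate_noisy_features the --snr value does not change the
    # audio; it only selects the "%ddb" sub-folder the features are written to.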
    parser_calculate_noisy_features.add_argument('--workspace', type=str, required=True)
    parser_calculate_noisy_features.add_argument('--noisy_dir', type=str, required=True)
    parser_calculate_noisy_features.add_argument('--data_type', type=str, required=True)
    parser_calculate_noisy_features.add_argument('--snr', type=float, required=True)

    args = parser.parse_args()
    if args.mode == 'create_mixture_csv':
        create_mixture_csv(args)
    elif args.mode == 'calculate_mixture_features':
        calculate_mixture_features(args)
    elif args.mode == 'calculate_noisy_features':
        calculate_noisy_features(args)
    elif args.mode == 'pack_features':
        pack_features(args)
    elif args.mode == 'compute_scaler':
        compute_scaler(args)
    else:
        raise Exception("Unknown mode: %s" % args.mode)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
soundfile>=0.9.0.post1
numpy>=1.13.3
matplotlib>=2.1.1
scipy>=1.0.0
h5py>=2.7.1
scikit-learn>=0.19.1
keras>=2.1.2
tensorflow-gpu>=1.4.1
--------------------------------------------------------------------------------
/runme.sh:
--------------------------------------------------------------------------------
#!/bin/bash

CMD="main_dnn.py"

WORKSPACE="./workspace"
mkdir -p $WORKSPACE
TR_SPEECH_DIR="./data/train_speech"
TR_NOISE_DIR="./data/train_noise"
TE_SPEECH_DIR="./data/test_speech"
TE_NOISE_DIR="./data/test_noise"

MODEL_FILE="null"

TR_SNR=5
TE_SNR=5
N_CONCAT=7
N_HOP=2
CALC_LOG=0
EPOCHS=100000
ITERATION=90000
LEARNING_RATE=1e-4

CALC_DATA=1
if [ $CALC_DATA -eq 1 ]; then
# Create mixture csv.
echo "Go: Create mixture csv."
python prepare_data.py create_mixture_csv --workspace=$WORKSPACE --speech_dir=$TR_SPEECH_DIR --noise_dir=$TR_NOISE_DIR --data_type=train --speechratio=1 --magnification=2
python prepare_data.py create_mixture_csv --workspace=$WORKSPACE --speech_dir=$TE_SPEECH_DIR --noise_dir=$TE_NOISE_DIR --data_type=test --speechratio=1

# Calculate mixture features.
echo "Calculate mixture features."
TR_SNR=5
TE_SNR=5
python prepare_data.py calculate_mixture_features --workspace=$WORKSPACE --speech_dir=$TR_SPEECH_DIR --noise_dir=$TR_NOISE_DIR --data_type=train --snr=$TR_SNR
python prepare_data.py calculate_mixture_features --workspace=$WORKSPACE --speech_dir=$TE_SPEECH_DIR --noise_dir=$TE_NOISE_DIR --data_type=test --snr=$TE_SNR

#echo "finish!"
#exit

# Calculate PESQ of all noisy speech.
echo "Calculate PESQ of all noisy speech."
python evaluate.py calculate_noisy_pesq --workspace=$WORKSPACE --speech_dir=$TE_SPEECH_DIR --te_snr=$TE_SNR

# Calculate noisy overall stats.
echo "Calculate noisy overall stats."
python evaluate.py get_stats

# Pack features.
echo "Pack features."
N_CONCAT=7
N_HOP=2
python prepare_data.py pack_features --workspace=$WORKSPACE --data_type=train --snr=$TR_SNR --n_concat=$N_CONCAT --n_hop=$N_HOP --calc_log=$CALC_LOG
python prepare_data.py pack_features --workspace=$WORKSPACE --data_type=test --snr=$TE_SNR --n_concat=$N_CONCAT --n_hop=$N_HOP --calc_log=$CALC_LOG

# Compute scaler.
echo "Compute scaler."
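# The scaler is fit on the training features only; the same scaler is then
# reused downstream for test and inference data.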
" 59 | python prepare_data.py compute_scaler --workspace=$WORKSPACE --data_type=train --snr=$TR_SNR 60 | 61 | fi 62 | 63 | 64 | 65 | # Train. 66 | echo "Train. " 67 | CUDA_VISIBLE_DEVICES=0 python $CMD train --workspace=$WORKSPACE --tr_snr=$TR_SNR --te_snr=$TE_SNR --lr=$LEARNING_RATE --epoch=$EPOCHS --calc_log=$CALC_LOG 68 | 69 | 70 | # Inference, enhanced wavs will be created. 71 | echo "Inference, enhanced wavs will be created. " 72 | CUDA_VISIBLE_DEVICES=0 python $CMD inference --workspace=$WORKSPACE --tr_snr=$TR_SNR --te_snr=$TE_SNR --n_concat=$N_CONCAT --iteration=$ITERATION --calc_log=$CALC_LOG --model_file=$MODEL_FILE 73 | 74 | # Calculate PESQ of all enhanced speech. 75 | echo "Calculate PESQ of all enhanced speech. " 76 | python evaluate.py calculate_pesq --workspace=$WORKSPACE --speech_dir=$TE_SPEECH_DIR --te_snr=$TE_SNR 77 | 78 | # Calculate overall stats. 79 | echo "Calculate overall stats. " 80 | python evaluate.py get_stats 81 | 82 | -------------------------------------------------------------------------------- /spectrogram_to_wave.py: -------------------------------------------------------------------------------- 1 | """ 2 | Summary: Recover spectrogram to wave. 3 | Author: Qiuqiang Kong 4 | Created: 2017.09 5 | Modified: - 6 | """ 7 | import numpy as np 8 | import numpy 9 | import decimal 10 | 11 | def recover_wav(pd_abs_x, gt_x, n_overlap, winfunc, wav_len=None): 12 | """Recover wave from spectrogram. 13 | If you are using scipy.signal.spectrogram, you may need to multipy a scaler 14 | to the recovered audio after using this function. For example, 15 | recover_scaler = np.sqrt((ham_win**2).sum()) 16 | 17 | Args: 18 | pd_abs_x: 2d array, (n_time, n_freq) 19 | gt_x: 2d complex array, (n_time, n_freq) 20 | n_overlap: integar. 21 | winfunc: func, the analysis window to apply to each frame. 22 | wav_len: integer. Pad or trunc to wav_len with zero. 23 | 24 | Returns: 25 | 1d array. 26 | """ 27 | x = real_to_complex(pd_abs_x, gt_x) 28 | x = half_to_whole(x) 29 | frames = ifft_to_wav(x) 30 | (n_frames, n_window) = frames.shape 31 | print ("pred shape:") 32 | print (pd_abs_x.shape) 33 | print ("frames shape:") 34 | print (frames.shape) 35 | s = deframesig(frames=frames, siglen=0, frame_len=n_window, 36 | frame_step=n_window-n_overlap, winfunc=winfunc) 37 | print ("s shape:") 38 | print (s.shape) 39 | if wav_len: 40 | s = pad_or_trunc(s, wav_len) 41 | return s 42 | 43 | def real_to_complex(pd_abs_x, gt_x): 44 | """Recover pred spectrogram's phase from ground truth's phase. 45 | 46 | Args: 47 | pd_abs_x: 2d array, (n_time, n_freq) 48 | gt_x: 2d complex array, (n_time, n_freq) 49 | 50 | Returns: 51 | 2d complex array, (n_time, n_freq) 52 | """ 53 | theta = np.angle(gt_x) 54 | cmplx = pd_abs_x * np.exp(1j * theta) 55 | return cmplx 56 | 57 | def half_to_whole(x): 58 | """Recover whole spectrogram from half spectrogram. 59 | """ 60 | return np.concatenate((x, np.fliplr(np.conj(x[:, 1:-1]))), axis=1) 61 | 62 | def ifft_to_wav(x): 63 | """Recover wav from whole spectrogram""" 64 | return np.real(np.fft.ifft(x)) 65 | 66 | def pad_or_trunc(s, wav_len): 67 | if len(s) >= wav_len: 68 | s = s[0 : wav_len] 69 | else: 70 | s = np.concatenate((s, np.zeros(wav_len - len(s)))) 71 | return s 72 | 73 | def recover_gt_wav(x, n_overlap, winfunc, wav_len=None): 74 | """Recover ground truth wav. 
75 | """ 76 | x = half_to_whole(x) 77 | frames = ifft_to_wav(x) 78 | (n_frames, n_window) = frames.shape 79 | s = deframesig(frames=frames, siglen=0, frame_len=n_window, 80 | frame_step=n_window-n_overlap, winfunc=winfunc) 81 | if wav_len: 82 | s = pad_or_trunc(s, wav_len) 83 | return s 84 | 85 | def deframesig(frames,siglen,frame_len,frame_step,winfunc=lambda x:numpy.ones((x,))): 86 | """Does overlap-add procedure to undo the action of framesig. 87 | Ref: From https://github.com/jameslyons/python_speech_features 88 | 89 | :param frames: the array of frames. 90 | :param siglen: the length of the desired signal, use 0 if unknown. Output will be truncated to siglen samples. 91 | :param frame_len: length of each frame measured in samples. 92 | :param frame_step: number of samples after the start of the previous frame that the next frame should begin. 93 | :param winfunc: the analysis window to apply to each frame. By default no window is applied. 94 | :returns: a 1-D signal. 95 | """ 96 | frame_len = round_half_up(frame_len) 97 | frame_step = round_half_up(frame_step) 98 | numframes = numpy.shape(frames)[0] 99 | assert numpy.shape(frames)[1] == frame_len, '"frames" matrix is wrong size, 2nd dim is not equal to frame_len' 100 | 101 | indices = numpy.tile(numpy.arange(0,frame_len),(numframes,1)) + numpy.tile(numpy.arange(0,numframes*frame_step,frame_step),(frame_len,1)).T 102 | indices = numpy.array(indices,dtype=numpy.int32) 103 | padlen = (numframes-1)*frame_step + frame_len 104 | 105 | if siglen <= 0: siglen = padlen 106 | 107 | rec_signal = numpy.zeros((padlen,)) 108 | window_correction = numpy.zeros((padlen,)) 109 | win = winfunc(frame_len) 110 | 111 | for i in range(0,numframes): 112 | window_correction[indices[i,:]] = window_correction[indices[i,:]] + win + 1e-15 #add a little bit so it is never zero 113 | rec_signal[indices[i,:]] = rec_signal[indices[i,:]] + frames[i,:] 114 | 115 | rec_signal = rec_signal/window_correction 116 | return rec_signal[0:siglen] 117 | 118 | def round_half_up(number): 119 | return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP)) --------------------------------------------------------------------------------