├── README.md
├── Utterance_based_FCN_MSE.py
└── images
    ├── E2E.png
    ├── Fig1.png
    ├── Fig1_2.png
    ├── Fig1_3.png
    └── t2.gif

/README.md:
--------------------------------------------------------------------------------
# End-to-end waveform utterance enhancement for direct evaluation metrics optimization by fully convolutional neural networks (TASLP 2018)


### Introduction
This paper addresses the mismatch (illustrated in Fig. 1) between the training objective function and the evaluation metrics, which are usually highly correlated with human perception. Because of this inconsistency, there is no guarantee that the trained model provides optimal performance in applications. In this study, we propose an end-to-end utterance-based speech enhancement framework using fully convolutional neural networks (FCNs) to reduce the gap between model optimization and the evaluation criterion. Because optimization is performed on whole utterances, the temporal correlation of long speech segments, or even of the entire utterance, can be taken into account to directly optimize perception-based objective functions.

### Major Contributions
1) Utterance-based waveform enhancement
2) Direct short-time objective intelligibility (STOI) score optimization (without any approximation)


For more details and evaluation results, please check out our [paper](https://ieeexplore.ieee.org/document/8331910).

![teaser](https://github.com/JasonSWFu/End-to-end-waveform-utterance-enhancement/blob/master/images/Fig1_3.png)

### Waveform enhancement process:

![teaser](https://github.com/JasonSWFu/End-to-end-waveform-utterance-enhancement/blob/master/images/t2.gif)

### Dependencies:
* Python 2.7
* keras=1.1.0 (recommended)
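
### Usage
To train the MSE-based FCN on your own data, first change the training/testing data paths near the bottom of `Utterance_based_FCN_MSE.py`, then (assuming a Python 2.7 environment with the dependencies above) run:

    python Utterance_based_FCN_MSE.py

The script saves the model architecture to `FCNN_MSE.json`, the best weights (lowest validation loss) to `FCNN_MSE.hdf5`, and writes the enhanced test waveforms under `FCN_enhanced_MSE/`.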

### Note!
For the STOI loss function optimization, please e-mail me.

### Citation

If you find the code useful in your research, please cite:

    @article{fu2018end,
      title={End-to-end waveform utterance enhancement for direct evaluation metrics optimization by fully convolutional neural networks},
      author={Fu, Szu-Wei and Wang, Tao-Wei and Tsao, Yu and Lu, Xugang and Kawai, Hisashi},
      journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing (TASLP)},
      volume={26},
      number={9},
      pages={1570--1584},
      year={2018},
      publisher={IEEE Press}}

    @inproceedings{fu2017raw,
      title={Raw waveform-based speech enhancement by fully convolutional networks},
      author={Fu, Szu-Wei and Tsao, Yu and Lu, Xugang and Kawai, Hisashi},
      booktitle={2017 Asia-Pacific Signal and Information Processing Association Annual Summit and Conference (APSIPA ASC)},
      pages={006--012},
      year={2017},
      organization={IEEE}}

### Contact

e-mail: jasonfu@iis.sinica.edu.tw or d04922007@ntu.edu.tw

--------------------------------------------------------------------------------
/Utterance_based_FCN_MSE.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 25 10:24:51 2018

!!! This code is recommended to be run with Keras version 1.1.0, since the BatchNormalization API was modified in later versions. !!!

!!! If you find that the high-frequency regions of the enhanced speech are missing, please train the FCN for more epochs (even though the loss may not change much).
You may observe the high-frequency regions gradually appear, as shown in the .gif here: https://jasonswfu.github.io/JasonFu.github.io/ and https://jasonswfu.github.io/JasonFu.github.io/images/t2.gif

This code is used for FCN-based raw waveform denoising (utterance-wise, with MSE loss).

If you find this code useful in your research, please cite:
[1] S.-W. Fu, Y. Tsao, X. Lu, and H. Kawai, "Raw waveform-based speech enhancement by fully convolutional networks," in Proc. APSIPA, 2017.
[2] S.-W. Fu, T.-W. Wang, Y. Tsao, X. Lu, and H. Kawai, "End-to-end waveform utterance enhancement for direct evaluation metrics optimization by fully convolutional neural networks," IEEE/ACM Transactions on Audio, Speech, and Language Processing, 2018.

Contact:
Szu-Wei Fu
jasonfu@citi.sinica.edu.tw
Academia Sinica, Taipei, Taiwan

@author: Jason
"""

import matplotlib
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from keras.models import Sequential, model_from_json, load_model
from keras.layers.core import Dense, Dropout, Flatten, Activation, SpatialDropout2D, Reshape, Lambda
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import ELU, PReLU, LeakyReLU
from keras.layers.convolutional import Convolution1D
from keras.optimizers import SGD
from keras.callbacks import ModelCheckpoint
from scipy.io import wavfile

import scipy.io
import librosa
import os
import time
import numpy as np
import numpy.matlib
import random
import theano
import theano.tensor as T
random.seed(999)

Num_traindata=20000
epoch=40
batch_size=1
max_input_audio_length=7  # On a TITAN X GPU with 12 GB of memory, with the current FCN structure, the maximum input audio length without OOM is roughly 7 s.


def shuffle_list(x_old,index):
    x_new=[x_old[i] for i in index]
    return x_new

def get_filepaths(directory):
    """
    Generate the full file paths in a directory tree by walking it top-down.
    For each directory in the tree rooted at `directory` (including
    `directory` itself), os.walk yields a 3-tuple (dirpath, dirnames, filenames).
    """
    file_paths = []  # List which will store all of the full filepaths.

    # Walk the tree.
    for root, directories, files in os.walk(directory):
        for filename in files:
            # Join the two strings in order to form the full filepath.
            filepath = os.path.join(root, filename)
            file_paths.append(filepath)  # Add it to the list.

    return file_paths


def train_data_generator(noisy_list, clean_path):
    index=0
    while True:
        #noisy, rate = librosa.load(noisy_list[index],sr=16000)
        rate, noisy = wavfile.read(noisy_list[index])
        while noisy.shape[0]/16000.>max_input_audio_length:  # Skip utterances longer than max_input_audio_length to avoid OOM; read the next one (with wrap-around).
            index += 1
            if index == len(noisy_list):
                index = 0
            rate, noisy = wavfile.read(noisy_list[index])

        noisy=noisy.astype('float')
        if len(noisy.shape)==2:                           # Stereo: average the two channels to mono.
            noisy=(noisy[:,0]+noisy[:,1])/2
        noisy=noisy/np.max(abs(noisy))                    # Peak-normalize to [-1, 1].
        noisy=np.reshape(noisy,(1,np.shape(noisy)[0],1))  # Shape: (batch=1, samples, channels=1).

        # The clean target has the same file name as the noisy utterance.
        #clean, rate =librosa.load(clean_list[clean_wav_list.index(noisy_wav_list[index])],sr=16000)
        rate, clean = wavfile.read(clean_path+noisy_list[index].split('/')[-1])
        clean=clean.astype('float')
        if len(clean.shape)==2:
            clean=(clean[:,0]+clean[:,1])/2
        clean=clean/np.max(abs(clean))
        clean=np.reshape(clean,(1,np.shape(clean)[0],1))


        index += 1
        if index == len(noisy_list):
            index = 0

            random.shuffle(noisy_list)

        yield noisy, clean
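
# Note: batch_size is 1 and each generator step yields one whole utterance of
# shape (1, samples, 1). Utterances differ in length, so they cannot be stacked
# into larger batches; the FCN below accepts this because its input_shape is
# (None, 1) along the time axis.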

def val_data_generator(noisy_list, clean_path):
    index=0
    while True:
        #noisy, rate = librosa.load(noisy_list[index],sr=16000)
        rate, noisy = wavfile.read(noisy_list[index])
        noisy=noisy.astype('float')
        if len(noisy.shape)==2:
            noisy=(noisy[:,0]+noisy[:,1])/2
        noisy=noisy/np.max(abs(noisy))
        noisy=np.reshape(noisy,(1,np.shape(noisy)[0],1))

        # The clean file name is everything after the 'dB' folder in the noisy path.
        #clean, rate =librosa.load(clean_list[clean_wav_list.index(noisy_wav_list[index])],sr=16000)
        rate, clean = wavfile.read(clean_path+noisy_list[index][noisy_list[index].index('dB')+2:])
        clean=clean.astype('float')
        if len(clean.shape)==2:
            clean=(clean[:,0]+clean[:,1])/2
        clean=clean/np.max(abs(clean))
        clean=np.reshape(clean,(1,np.shape(clean)[0],1))


        index += 1
        if index == len(noisy_list):
            index = 0

        yield noisy, clean

# Data paths: change to your own!
######################### Training data #########################
Train_Noisy_lists = get_filepaths("/mnt/hd-02/avse/training/noisy")  # Please change to your path
Train_Clean_paths = "/mnt/hd-02/avse/training/clean/"                # Please change to your path


# Data shuffle
random.shuffle(Train_Noisy_lists)

Train_Noisy_lists=Train_Noisy_lists[0:Num_traindata]  # Only use a subset of the training data

steps_per_epoch = (Num_traindata)//batch_size
######################### Test set #########################
Test_Noisy_lists = get_filepaths("/mnt/hd-02/avse/testing/noisy")  # Please change to your path
Test_Clean_paths = "/mnt/hd-02/avse/testing/clean/"                # Please change to your path

Num_testdata=len(Test_Noisy_lists)


start_time = time.time()

print 'model building...'
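
# Architecture: a fully convolutional network (FCN) with eight Conv1D layers.
# The hidden layers use 30 filters of length 55 with 'same' padding and the
# output layer uses a single filter, so the time axis is never shortened and
# the enhanced waveform has exactly as many samples as the input. The final
# tanh keeps outputs in [-1, 1], matching the peak-normalized targets.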

model = Sequential()


model.add(Convolution1D(30, 55, border_mode='same', input_shape=(None,1)))
model.add(BatchNormalization(mode=2,axis=-1))  # Acts as instance normalization, because batch_size=1.
model.add(LeakyReLU())
#model.add(Dropout(0.06))

model.add(Convolution1D(30, 55, border_mode='same'))
model.add(BatchNormalization(mode=2,axis=-1))
model.add(LeakyReLU())
#model.add(Dropout(0.06))

model.add(Convolution1D(30, 55, border_mode='same'))
model.add(BatchNormalization(mode=2,axis=-1))
model.add(LeakyReLU())
#model.add(Dropout(0.06))

model.add(Convolution1D(30, 55, border_mode='same'))
model.add(BatchNormalization(mode=2,axis=-1))
model.add(LeakyReLU())
#model.add(Dropout(0.06))

model.add(Convolution1D(30, 55, border_mode='same'))
model.add(BatchNormalization(mode=2,axis=-1))
model.add(LeakyReLU())
#model.add(Dropout(0.06))


model.add(Convolution1D(30, 55, border_mode='same'))
model.add(BatchNormalization(mode=2,axis=-1))
model.add(LeakyReLU())
#model.add(Dropout(0.06))

model.add(Convolution1D(30, 55, border_mode='same'))
model.add(BatchNormalization(mode=2,axis=-1))
model.add(LeakyReLU())
#model.add(Dropout(0.06))

model.add(Convolution1D(1, 55, border_mode='same'))
model.add(Activation('tanh'))

model.compile(loss='mse', optimizer='adam')

with open('FCNN_MSE.json','w') as f:  # Save the model architecture.
    f.write(model.to_json())
checkpointer = ModelCheckpoint(filepath='FCNN_MSE.hdf5', verbose=1, save_best_only=True, mode='min')

print 'training...'
g1 = train_data_generator(Train_Noisy_lists, Train_Clean_paths)
g2 = val_data_generator(Test_Noisy_lists, Test_Clean_paths)

hist=model.fit_generator(g1,
                         samples_per_epoch=Num_traindata,
                         nb_epoch=epoch,
                         verbose=1,
                         validation_data=g2,
                         nb_val_samples=Num_testdata,
                         max_q_size=1,
                         nb_worker=1,
                         pickle_safe=True,
                         callbacks=[checkpointer]
                         )
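
# The de-noising loop below uses the final-epoch weights. To enhance with the
# best checkpoint (lowest validation loss) saved by ModelCheckpoint instead,
# uncomment the following line:
#model.load_weights('FCNN_MSE.hdf5')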

print 'De-noising...'
maxv = np.iinfo(np.int16).max
for path in Test_Noisy_lists:  # Ex: /mnt/hd-02/avse/testing/noisy/engine/1dB/1.wav
    S=path.split('/')
    noise=S[-3]
    dB=S[-2]
    wave_name=S[-1]

    rate, noisy = wavfile.read(path)
    noisy=noisy.astype('float')
    if len(noisy.shape)==2:
        noisy=(noisy[:,0]+noisy[:,1])/2
    noisy=noisy/np.max(abs(noisy))
    noisy=np.reshape(noisy,(1,np.shape(noisy)[0],1))
    enhanced=np.squeeze(model.predict(noisy, verbose=0, batch_size=batch_size))
    enhanced=enhanced/np.max(abs(enhanced))

    output_dir=os.path.join("FCN_enhanced_MSE", noise, dB)  # Create the output folder if it does not exist yet.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    librosa.output.write_wav(os.path.join(output_dir, wave_name), (enhanced*maxv).astype(np.int16), 16000)

# Plotting the learning curve
TrainERR=hist.history['loss']
ValidERR=hist.history['val_loss']
print ('Final validation loss: %f; minimum: %f at epoch %i' % (hist.history['val_loss'][epoch-1], np.min(np.asarray(ValidERR)), np.argmin(np.asarray(ValidERR))+1))
print 'drawing the training process...'
plt.figure(2)
plt.plot(range(1,epoch+1),TrainERR,'b',label='TrainERR')
plt.plot(range(1,epoch+1),ValidERR,'r',label='ValidERR')
plt.xlim([1,epoch])
plt.legend()
plt.xlabel('epoch')
plt.ylabel('error')
plt.grid(True)
plt.savefig('Learning_curve_FCN_MSE.png', dpi=150)  # Save before show(); with the Agg backend, show() is a no-op.
plt.show()


end_time = time.time()
print ('The code for this file ran for %.2fm' % ((end_time - start_time) / 60.))

--------------------------------------------------------------------------------
/images/E2E.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonSWFu/End-to-end-waveform-utterance-enhancement/9b6ab1a9ed43536b6834b617a1893e3a8d7d162a/images/E2E.png

--------------------------------------------------------------------------------
/images/Fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonSWFu/End-to-end-waveform-utterance-enhancement/9b6ab1a9ed43536b6834b617a1893e3a8d7d162a/images/Fig1.png

--------------------------------------------------------------------------------
/images/Fig1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonSWFu/End-to-end-waveform-utterance-enhancement/9b6ab1a9ed43536b6834b617a1893e3a8d7d162a/images/Fig1_2.png

--------------------------------------------------------------------------------
/images/Fig1_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonSWFu/End-to-end-waveform-utterance-enhancement/9b6ab1a9ed43536b6834b617a1893e3a8d7d162a/images/Fig1_3.png

--------------------------------------------------------------------------------
/images/t2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonSWFu/End-to-end-waveform-utterance-enhancement/9b6ab1a9ed43536b6834b617a1893e3a8d7d162a/images/t2.gif
--------------------------------------------------------------------------------