├── GT_16channel_31tap.mat ├── __pycache__ ├── models.cpython-37.pyc ├── data_ops.cpython-37.pyc ├── file_ops.cpython-37.pyc ├── wgan_ops.cpython-37.pyc ├── normalizations.cpython-37.pyc └── keras_contrib_backend.cpython-37.pyc ├── wgan_ops.py ├── LICENSE ├── file_ops.py ├── download_dataset.sh ├── data_ops.py ├── README.md ├── prepare_data.py ├── models.py ├── keras_contrib_backend.py ├── run_aecnn.py ├── run_lsgan_se.py ├── test_wav.txt ├── run_rsgan-gp_se.py ├── run_wgan-gp_se.py └── normalizations.py /GT_16channel_31tap.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepakbaby/se_relativisticgan/HEAD/GT_16channel_31tap.mat -------------------------------------------------------------------------------- /__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepakbaby/se_relativisticgan/HEAD/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/data_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepakbaby/se_relativisticgan/HEAD/__pycache__/data_ops.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/file_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepakbaby/se_relativisticgan/HEAD/__pycache__/file_ops.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/wgan_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepakbaby/se_relativisticgan/HEAD/__pycache__/wgan_ops.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/normalizations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepakbaby/se_relativisticgan/HEAD/__pycache__/normalizations.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/keras_contrib_backend.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepakbaby/se_relativisticgan/HEAD/__pycache__/keras_contrib_backend.cpython-37.pyc -------------------------------------------------------------------------------- /wgan_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions used by Wasserstein GAN 3 | """ 4 | from keras import backend as K 5 | import numpy as np 6 | 7 | def wasserstein_loss(y_true, y_pred): 8 | """ 9 | Define the Wasserstein loss for compiling and training the model 10 | """ 11 | return K.mean(y_true * y_pred) 12 | 13 | 14 | def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight): 15 | """ 16 | This term is used for stabilizing the WGAN training. 
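It penalizes the L2 norm of the critic's gradients, evaluated at samples interpolated between real and generated data, for deviating from 1 (the two-sided penalty of Gulrajani et al., 2017, "Improved Training of Wasserstein GANs").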
17 | """ 18 | gradients = K.gradients(y_pred, averaged_samples)[0] 19 | gradients_sqr = K.square(gradients) 20 | axes_for_sum = tuple(np.arange(1, len(gradients_sqr.shape))) 21 | gradients_sqr_sum = K.sum(gradients_sqr, axis=axes_for_sum) 22 | gradient_norm = K.sqrt(gradients_sqr_sum) 23 | gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_norm) 24 | return K.mean(gradient_penalty) 25 | 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Deepak Baby 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /file_ops.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Get filenames 3 | ''' 4 | import tensorflow as tf 5 | 6 | 7 | def get_modeldirname(opts): 8 | modeldir = opts['dirhead'] 9 | 10 | # add tag if the noise input is omitted 11 | if opts ['z_off']: 12 | modeldir += "_noZ" 13 | 14 | # add normalization name 15 | if opts['applyinstancenorm']: 16 | modeldir += "_IN" 17 | elif opts['applybatchrenorm']: 18 | modeldir += "_BRN" 19 | elif opts['applybatchnorm']: 20 | modeldir += "_BN" 21 | elif opts['applygroupnorm']: 22 | modeldir += "_GN" 23 | elif opts['applyspectralnorm']: 24 | modeldir += "_SN" 25 | 26 | # add optimizer 27 | modeldir += "_Adam_D" 28 | modeldir += str(opts['d_lr']) 29 | modeldir += "_G" 30 | modeldir += str(opts['g_lr']) 31 | 32 | # add L1 loss weight 33 | modeldir += "_L1_" + str(opts ['g_l1loss']) 34 | return modeldir 35 | 36 | 37 | def write_log(callback, names, logs, batch_no): 38 | for name, value in zip(names, logs): 39 | summary = tf.Summary() 40 | summary_value = summary.value.add() 41 | summary_value.simple_value = value 42 | summary_value.tag = name 43 | callback.writer.add_summary(summary, batch_no) 44 | callback.writer.flush() 45 | 46 | -------------------------------------------------------------------------------- /download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # The dataset can be downloaded manually from 3 | # https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791 4 | 5 | # specify the location to which the database will be copied 6 | datadir=data 7 | 8 | 9 | # adapted from https://github.com/santi-pdp/segan 10 | datasets="clean_trainset_56spk_wav 
noisy_trainset_56spk_wav clean_testset_wav noisy_testset_wav" 11 | 12 | # DOWNLOAD THE DATASET 13 | mkdir -p $datadir 14 | pushd $datadir 15 | 16 | 17 | for dset in $datasets; do 18 | if [ ! -d ${dset}_16kHz ]; then 19 | # download the archive if it is not already present 20 | if [ ! -f ${dset}.zip ]; then 21 | echo "DOWNLOADING ${dset}..." 22 | wget http://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/${dset}.zip 23 | fi 24 | if [ ! -d ${dset} ]; then 25 | echo "INFLATING ${dset}..." 26 | unzip -q ${dset}.zip -d $dset 27 | fi 28 | if [ ! -d ${dset}_16kHz ]; then 29 | echo 'CONVERTING WAVS TO 16K...' 30 | mkdir -p ${dset}_16kHz 31 | pushd ${dset} 32 | ls *.wav > ../${dset}.flist 33 | ls *.wav | while read name; do 34 | sox $name -r 16k ../${dset}_16kHz/$name 35 | done 36 | popd 37 | fi 38 | fi 39 | done 40 | 41 | popd 42 | 43 | # store filenames in datadir 44 | cp $datadir/clean_trainset_56spk_wav.flist $datadir/train_wav.txt 45 | cp $datadir/clean_testset_wav.flist $datadir/test_wav.txt 46 | 47 | -------------------------------------------------------------------------------- /data_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data processing routines 3 | Deepak Baby, UGent, June 2018 4 | deepak.baby@ugent.be 5 | """ 6 | 7 | import numpy as np 8 | 9 | def reconstruct_wav(wavmat, stride_factor=0.5): 10 | """ 11 | Reconstructs the audiofile from sliced matrix wavmat 12 | """ 13 | window_length = wavmat.shape[1] 14 | window_stride = int(stride_factor * window_length) 15 | wav_length = (wavmat.shape[0] -1 ) * window_stride + window_length 16 | wav_recon = np.zeros((1,wav_length)) 17 | #print ("wav recon shape " + str(wav_recon.shape)) 18 | for k in range (wavmat.shape[0]): 19 | wav_beg = k * window_stride 20 | wav_end = wav_beg + window_length 21 | wav_recon[0, wav_beg:wav_end] += wavmat[k, :] 22 | 23 | # now compute the scaling factor for multiple instances 24 | noverlap = int(np.ceil(1/stride_factor)) 25 | scale_ = (1/float(noverlap)) * np.ones((1, wav_length)) 26 | for s in range(noverlap-1): 27 | s_beg = s * window_stride 28 | s_end = s_beg + window_stride 29 | scale_[0, s_beg:s_end] = 1/ (s+1) 30 | scale_[0, -s_beg - 1 : -s_end:-1] = 1/ (s+1) 31 | 32 | return wav_recon * scale_ 33 | 34 | def pre_emph(x, coeff=0.95): 35 | """ 36 | Apply pre_emph on 2d data (batch_size x window_length) 37 | """ 38 | #print ("x shape: " + str(x.shape)) 39 | x0 = x[:, 0] 40 | x0 = np.expand_dims(x0, axis=1) 41 | diff = x[:, 1:] - coeff * x[:, :-1] 42 | x_preemph = np.concatenate((x0, diff), axis=1) 43 | if not x.shape == x_preemph.shape: 44 | print ("ERROR: Pre-emphasis is wrong") 45 | #print ("x_preemph shape: " + str(x_preemph.shape)) 46 | return x_preemph 47 | 48 | def de_emph(y, coeff=0.95): 49 | """ 50 | Apply de_emphasis on test data: works only on 1d data 51 | """ 52 | if coeff <= 0: 53 | return y 54 | x = np.zeros((y.shape[0],), dtype=np.float32) 55 | #print("in_shape" + str(y.shape)) 56 | x[0] = y[0] 57 | for n in range(1, y.shape[0], 1): 58 | x[n] = coeff * x[n - 1] + y[n] 59 | return x 60 | 61 | def data_preprocess(wav, preemph=0.95): 62 | wav = (2./65535.) * (wav.astype('float32') - 32767) + 1. # map int16 samples to the range [-1, 1] 63 | if preemph > 0: 64 | wav = pre_emph(wav, coeff=preemph) 65 | return wav.astype('float32') 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keras framework for speech enhancement using relativistic GANs 
2 | Uses a fully convolutional end-to-end speech enhancement system. 3 | 4 | Implementation details of the paper accepted to ICASSP-2019: 5 | 6 | **Deepak Baby and Sarah Verhulst, _SERGAN: Speech enhancement using relativistic generative adversarial networks with gradient penalty_, IEEE-ICASSP, pp. 106-110, May 2019, Brighton, UK.** 7 | 8 | > This work was funded with support from the EU Horizon 2020 programme under grant agreement No 678120 (RobSpear). 9 | 10 | ---- 11 | ### Pre-requisites 12 | 1. Install [tensorflow](https://www.tensorflow.org/) (tested on Tensorflow v1.13.2) and [keras](https://keras.io/) (tested on Keras v2.3.1) 13 | 1. Install [tqdm](https://pypi.org/project/tqdm/) for profiling the training progress 14 | 1. The experiments are conducted on the dataset from Valentini et al., which can be downloaded from [here](https://datashare.is.ed.ac.uk/handle/10283/1942). The following script can be used to download the dataset. *Requires [sox](http://sox.sourceforge.net/) for converting to 16kHz*. 15 | ```bash 16 | $ ./download_dataset.sh 17 | ``` 18 | 19 | ### Running the model 20 | 1. **Prepare data for training and testing the various models**. The folder path may be edited if you keep the database in a different folder. This script needs to be executed only once; all the models read from the same location. 21 | ```bash 22 | python prepare_data.py 23 | ``` 24 | 2. **Running the models**. The models available in this repository are listed below. Every implementation offers several cGAN configurations. Edit the ```opts``` variable for choosing the configuration. The results will be automatically saved to different folders. The folder name is generated by ```file_ops.py``` and automatically encodes the chosen configuration options. 25 | 1. `run_aecnn.py` : Auto-encoder CNN model with L1 loss term (no discriminator) 26 | 2. `run_lsgan_se.py` : SEGAN with least-squares loss [1] 27 | 3. `run_wgan-gp_se.py` : GAN model with Wasserstein loss and Gradient Penalty 28 | 4. `run_rsgan-gp_se.py` : Relativistic standard GAN with Gradient Penalty 29 | 5. `run_rasgan-gp_se.py` : Relativistic average standard GAN with Gradient Penalty 30 | 6. `run_ralsgan-gp_se.py`: Relativistic average least-squares GAN with Gradient Penalty 31 | 32 | 3. **Evaluation on the test set is done together with training**. Set ```TEST_SEGAN = False``` to disable testing. 33 | 34 | ---- 35 | ### Misc 36 | * **This code loads all the data into memory to speed up training**. If you don't have enough memory, the mini-batches can be read from the disk via HDF5. In ```run_.py``` 37 | ```python 38 | clean_train_data = np.array(fclean['feat_data']) 39 | noisy_train_data = np.array(fnoisy['feat_data']) 40 | ``` 41 | change the above lines to 42 | ```python 43 | clean_train_data = fclean['feat_data'] 44 | noisy_train_data = fnoisy['feat_data'] 45 | ``` 46 | **But this can lead to a slow-down of about 20 times (on the test machine)** as the mini-batches are then read from the disk over several epochs. A minimal sketch of this batch-wise access pattern is given after the references below. 47 | 48 | ---- 49 | ### References 50 | [1] S. Pascual, A. Bonafonte, and J. Serra, _SEGAN: speech enhancement generative adversarial network_, in INTERSPEECH, ISCA, Aug 2017, pp. 3642–3646. 
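A minimal sketch of the batch-wise HDF5 access pattern mentioned in the Misc section (the `.mat` files written by `prepare_data.py` use the HDF5-based v7.3 format, so `h5py` can slice them lazily; the loop below is illustrative rather than copied from the run scripts):
```python
import h5py
import numpy as np

# Open lazily: nothing is read from disk until a slice is requested.
fclean = h5py.File("./data/clean_train_segan1d.mat", "r")
clean_train_data = fclean['feat_data']  # stays on disk as an HDF5 dataset

batch_size = 100
num_samples = clean_train_data.shape[1]
for idx_beg in range(0, num_samples - batch_size + 1, batch_size):
    # Only the requested columns are read from disk. With shuffled fancy
    # indices instead of a slice, h5py requires the indices to be sorted,
    # which is why the run scripts apply np.sort() to each batch of indices.
    batch = np.array(clean_train_data[:, idx_beg:idx_beg + batch_size]).T
    # ... preprocess and pass `batch` to train_on_batch() as in run_lsgan_se.py
```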
51 | 52 | ---- 53 | #### Credits 54 | The keras implementation of cGAN is based on the following repos 55 | * [SEGAN](https://github.com/santi-pdp/segan) 56 | * [DCGAN](https://github.com/carpedm20/DCGAN-tensorflow) 57 | * [pix2pix](https://github.com/phillipi/pix2pix) 58 | 59 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | # for preparing the segan training and test data 2 | 3 | import tensorflow as tf 4 | from data_ops import * 5 | import scipy.io.wavfile as wavfile 6 | import numpy as np 7 | import os 8 | import hdf5storage 9 | 10 | def slice_1dsignal(signal, window_size, minlength, stride=0.5): 11 | """ 12 | Return windows of the given signal by sweeping in stride fractions 13 | of window 14 | Slices that are less than minlength are omitted 15 | """ 16 | n_samples = signal.shape[0] 17 | offset = int(window_size * stride) 18 | num_slices = (n_samples) 19 | slices = np.array([]).reshape(0, window_size) # initialize empty array 20 | for beg_i in range(0, n_samples, offset): 21 | end_i = beg_i + window_size 22 | if n_samples - beg_i < minlength : 23 | break 24 | if end_i <= n_samples : 25 | slice_ = np.array([signal[beg_i:end_i]]) 26 | else : 27 | slice_ = np.concatenate((np.array([signal[beg_i:]]), np.zeros((1, end_i - n_samples))), axis=1) 28 | slices = np.concatenate((slices, slice_), axis=0) 29 | return slices.astype('float32') 30 | 31 | def read_and_slice1d(wavfilename, window_size, minlength, stride=0.5): 32 | """ 33 | Reads and slices the wavfile into windowed chunks 34 | """ 35 | fs, signal = wavfile.read(wavfilename) 36 | if fs != 16000: 37 | raise ValueError('Sampling rate is expected to be 16kHz!') 38 | sliced = slice_1dsignal(signal, window_size, minlength, stride=stride) 39 | return sliced 40 | 41 | def prepare_sliced_data1d(opts): 42 | wavfolder = opts['wavfolder'] 43 | window_size = opts['window_size'] 44 | stride = opts['stride'] 45 | minlength = opts['minlength'] 46 | filenames = opts['filenames'] 47 | 48 | full_sliced = [] # initialize empty list 49 | dfi = [] 50 | dfi_begin = 0 51 | with open(filenames) as f: 52 | wav_files = f.read().splitlines() # to get rid of the \n while using readlines() 53 | print ("**** Reading from " + wavfolder) 54 | print ("**** The folder has " + str(len(wav_files)) + " files.") 55 | for ind, wav_file in enumerate(wav_files): 56 | if ind % 10 == 0 : 57 | print("Processing " + str(ind) + " of " + str(len(wav_files)) + " files.") 58 | wavfilename = os.path.join(wavfolder, wav_file) 59 | sliced = read_and_slice1d(wavfilename, window_size, minlength, stride=stride) 60 | full_sliced.append(sliced) 61 | dfi.append(np.array([[dfi_begin, dfi_begin + sliced.shape[0]]])) 62 | dfi_begin += sliced.shape[0] 63 | 64 | full_slicedstack = np.vstack(full_sliced) 65 | dfistack = np.vstack(dfi) 66 | 67 | return full_slicedstack, dfistack.astype('int') 68 | 69 | if __name__ == '__main__': 70 | 71 | opts = {} 72 | opts ['datafolder'] = "data" 73 | opts ['window_size'] = 2**14 74 | opts['stride']= 0.5 75 | opts['minlength']= 0.5 * (2 ** 14) 76 | testfilenames = os.path.join(opts['datafolder'], "test_wav.txt") 77 | trainfilenames = os.path.join(opts['datafolder'], "train_wav.txt") 78 | 79 | # for test set 80 | opts['filenames'] = testfilenames 81 | # for clean set 82 | opts['wavfolder'] = os.path.join(opts['datafolder'], "clean_testset_wav_16kHz") 83 | clean_test_sliced, dfi = prepare_sliced_data1d(opts) 84 | # for noisy set 85 | 
opts['wavfolder'] = os.path.join(opts['datafolder'], "noisy_testset_wav_16kHz") 86 | noisy_test_sliced, dfi = prepare_sliced_data1d(opts) 87 | if clean_test_sliced.shape[0] != noisy_test_sliced.shape[0] : 88 | raise ValueError('Clean sliced and noisy sliced are not of the same size!') 89 | if clean_test_sliced.shape[0] != dfi[-1,1] : 90 | raise ValueError('Sliced matrices have a different size than mentioned in dfi !') 91 | 92 | matcontent={} 93 | matcontent[u'feat_data'] = clean_test_sliced 94 | matcontent[u'dfi'] = dfi 95 | destinationfilenameclean = "./data/clean_test_segan1d.mat" 96 | hdf5storage.savemat(destinationfilenameclean, matcontent) 97 | 98 | matcontent={} 99 | matcontent[u'feat_data'] = noisy_test_sliced 100 | matcontent[u'dfi'] = dfi 101 | destinationfilenamenoisy = "./data/noisy_test_segan1d.mat" 102 | hdf5storage.savemat(destinationfilenamenoisy, matcontent) 103 | 104 | 105 | # for train set 106 | opts['filenames'] = trainfilenames 107 | # for clean set 108 | opts['wavfolder'] = os.path.join(opts['datafolder'], "clean_trainset_56spk_wav_16kHz") 109 | clean_train_sliced, dfi = prepare_sliced_data1d(opts) 110 | # for noisy set 111 | opts['wavfolder'] = os.path.join(opts['datafolder'], "noisy_trainset_56spk_wav_16kHz") 112 | noisy_train_sliced, dfi = prepare_sliced_data1d(opts) 113 | if clean_train_sliced.shape[0] != noisy_train_sliced.shape[0] : 114 | raise ValueError('Clean sliced and noisy sliced are not of the same size!') 115 | if clean_train_sliced.shape[0] != dfi[-1,1] : 116 | raise ValueError('Sliced matrices have a different size than mentioned in dfi !') 117 | 118 | matcontent={} 119 | matcontent[u'feat_data'] = clean_train_sliced 120 | matcontent[u'dfi'] = dfi 121 | destinationfilenameclean = "./data/clean_train_segan1d.mat" 122 | hdf5storage.savemat(destinationfilenameclean, matcontent) 123 | 124 | matcontent={} 125 | matcontent[u'feat_data'] = noisy_train_sliced 126 | matcontent[u'dfi'] = dfi 127 | destinationfilenamenoisy = "./data/noisy_train_segan1d.mat" 128 | hdf5storage.savemat(destinationfilenamenoisy, matcontent) 129 | 130 | 131 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reimplementing segan paper as close as possible. 3 | Deepak Baby, UGent, June 2018. 
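Defines the fully convolutional encoder-decoder generator (with skip connections and an optional latent noise input z) and the convolutional discriminator that scores (clean/enhanced, noisy) waveform pairs.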
4 | """ 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | from tensorflow.contrib.layers import xavier_initializer, flatten, fully_connected 9 | import numpy as np 10 | import keras 11 | from keras.layers import Input, Dense, Conv1D, Conv2DTranspose, BatchNormalization 12 | from keras.layers import LeakyReLU, PReLU, Reshape, Concatenate, Flatten, Activation 13 | from keras.models import Sequential, Model 14 | from keras.optimizers import Adam 15 | from keras.callbacks import TensorBoard 16 | from normalizations import InstanceNormalization 17 | #from bnorm import VBN 18 | #Conv2DTranspose = tf.keras.layers.Conv2DTranspose 19 | keras_backend = tf.keras.backend 20 | keras_initializers = tf.keras.initializers 21 | from data_ops import * 22 | 23 | import h5py 24 | 25 | def generator(opts): 26 | kwidth=opts['filterlength'] 27 | strides= opts['strides'] 28 | pool = strides 29 | g_enc_numkernels = opts ['g_enc_numkernels'] 30 | g_dec_numkernels = opts ['g_dec_numkernels'] 31 | window_length = opts['window_length'] 32 | featdim = opts ['featdim'] 33 | batch_size = opts['batch_size'] 34 | 35 | use_bias = True 36 | skips = [] 37 | #kernel_init = keras.initializers.TruncatedNormal(stddev=0.02) 38 | kernel_init = 'glorot_uniform' 39 | 40 | wav_in = Input(shape=(window_length, featdim)) 41 | enc_out = wav_in 42 | 43 | # Defining the Encoder 44 | for layernum, numkernels in enumerate(g_enc_numkernels): 45 | enc_out = Conv1D(numkernels, kwidth, strides=pool, 46 | kernel_initializer=kernel_init, padding="same", 47 | use_bias=use_bias)(enc_out) 48 | 49 | # for skip connections 50 | if layernum < len(g_enc_numkernels) - 1: 51 | skips.append(enc_out) 52 | if opts['applyprelu']: 53 | enc_out = PReLU(alpha_initializer='zero', weights=None)(enc_out) 54 | else: 55 | enc_out = LeakyReLU(alpha=opts['leakyrelualpha'])(enc_out) 56 | 57 | num_enc_layers = len(g_enc_numkernels) 58 | z_rows = int(window_length/ (pool ** num_enc_layers)) 59 | z_cols = g_enc_numkernels[-1] 60 | 61 | # Adding the intermediate noise layer 62 | if not opts['z_off']: 63 | z = Input(shape=(z_rows,z_cols), name='noise_input') 64 | dec_out = keras.layers.concatenate([enc_out, z]) 65 | else : 66 | dec_out = enc_out 67 | 68 | # Now to the decoder part 69 | nrows = z_rows 70 | ncols = dec_out.get_shape().as_list()[-1] 71 | for declayernum, decnumkernels in enumerate(g_dec_numkernels): 72 | # reshape for the conv2dtranspose layer as it needs 3D input 73 | indim = dec_out.get_shape().as_list() 74 | newshape = (indim[1], 1 , indim[2]) 75 | dec_out = Reshape(newshape)(dec_out) 76 | # add the conv2dtranspose layer 77 | dec_out = Conv2DTranspose(decnumkernels, [kwidth,1], strides=[strides, 1], 78 | kernel_initializer=kernel_init, padding="same", use_bias=use_bias)(dec_out) 79 | # Reshape back to 2D 80 | nrows *= strides # number of rows get multiplied by strides 81 | ncols = decnumkernels # number of cols is the same as number of kernels 82 | dec_out.set_shape([None, nrows, 1 , ncols]) # for correcting shape issue with conv2dtranspose 83 | newshape = (nrows, ncols) 84 | if declayernum == len(g_dec_numkernels) -1: 85 | dec_out = Reshape(newshape, name="g_output")(dec_out) # name the final output as g_output 86 | else: 87 | dec_out = Reshape(newshape)(dec_out) 88 | 89 | # add skip and prelu until the second-last layer 90 | if declayernum < len(g_dec_numkernels) -1 : 91 | if opts['applyprelu']: 92 | dec_out = PReLU(alpha_initializer='zero', weights=None)(dec_out) 93 | else: 94 | dec_out = 
LeakyReLU(alpha=opts['leakyrelualpha'])(dec_out) 95 | # Now add the skip connection 96 | skip_ = skips[-(declayernum + 1)] 97 | dec_out = keras.layers.concatenate([dec_out, skip_]) 98 | 99 | # Create the model graph 100 | if opts ['z_off']: 101 | G = Model(inputs=[wav_in], outputs=[dec_out]) 102 | else : 103 | G = Model(inputs=[wav_in, z], outputs=[dec_out]) 104 | 105 | if opts ['show_summary'] : 106 | G.summary() 107 | 108 | return G 109 | 110 | 111 | 112 | def discriminator(opts): 113 | print('*** Building Discriminator ***') 114 | window_length = opts['window_length'] 115 | featdim = opts ['featdim'] 116 | batch_size = opts['batch_size'] 117 | d_fmaps = opts ['d_fmaps'] 118 | strides = opts['strides'] 119 | activation = opts['d_activation'] 120 | kwidth = opts['filterlength'] 121 | 122 | wav_in_clean = Input(shape=(window_length, featdim), name='disc_inputclean') 123 | wav_in_noisy = Input(shape=(window_length, featdim), name='disc_inputnoisy') 124 | 125 | use_bias= True 126 | #kernel_init = keras.initializers.TruncatedNormal(stddev=0.02) 127 | kernel_init = 'glorot_uniform' 128 | 129 | d_out = keras.layers.concatenate([wav_in_clean, wav_in_noisy]) 130 | 131 | for layer_num, numkernels in enumerate(d_fmaps): 132 | d_out = Conv1D(numkernels, kwidth, strides=strides, kernel_initializer=kernel_init, 133 | use_bias=use_bias, padding="same")(d_out) 134 | 135 | if opts['applybn']: 136 | d_out = BatchNormalization()(d_out) 137 | elif opts['applyinstancenorm'] : 138 | d_out = InstanceNormalization(axis=2)(d_out) 139 | 140 | if activation == 'leakyrelu': 141 | d_out = LeakyReLU(alpha=opts['leakyrelualpha'])(d_out) 142 | elif activation == 'relu': 143 | d_out = tf.nn.relu(d_out) 144 | 145 | d_out = Conv1D(1, 1, padding="same", use_bias=use_bias, kernel_initializer=kernel_init, 146 | name='logits_conv')(d_out) 147 | d_out = Flatten()(d_out) 148 | d_out = Dense(1, activation='linear', name='d_output')(d_out) 149 | D = Model([wav_in_clean, wav_in_noisy], d_out) 150 | 151 | if opts ['show_summary']: 152 | D.summary() 153 | return D 154 | 155 | -------------------------------------------------------------------------------- /keras_contrib_backend.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | try: 5 | from tensorflow.python.ops import ctc_ops as ctc 6 | except ImportError: 7 | import tensorflow.contrib.ctc as ctc 8 | from keras.backend import tensorflow_backend as KTF 9 | from keras.backend import dtype 10 | from keras.backend.common import floatx 11 | from keras.backend.common import image_data_format 12 | from keras.backend.tensorflow_backend import _to_tensor 13 | from keras.backend.tensorflow_backend import logsumexp 14 | 15 | py_all = all 16 | 17 | 18 | def _preprocess_conv2d_input(x, data_format): 19 | """Transpose and cast the input before the conv2d. 20 | # Arguments 21 | x: input tensor. 22 | data_format: string, `"channels_last"` or `"channels_first"`. 23 | # Returns 24 | A tensor. 25 | """ 26 | if dtype(x) == 'float64': 27 | x = tf.cast(x, 'float32') 28 | if data_format == 'channels_first': 29 | # TF uses the last dimension as channel dimension, 30 | # instead of the 2nd one. 31 | # TH input shape: (samples, input_depth, rows, cols) 32 | # TF input shape: (samples, rows, cols, input_depth) 33 | x = tf.transpose(x, (0, 2, 3, 1)) 34 | return x 35 | 36 | 37 | def _postprocess_conv2d_output(x, data_format): 38 | """Transpose and cast the output from conv2d if needed. 
39 | # Arguments 40 | x: A tensor. 41 | data_format: string, `"channels_last"` or `"channels_first"`. 42 | # Returns 43 | A tensor. 44 | """ 45 | 46 | if data_format == 'channels_first': 47 | x = tf.transpose(x, (0, 3, 1, 2)) 48 | 49 | if floatx() == 'float64': 50 | x = tf.cast(x, 'float64') 51 | return x 52 | 53 | 54 | def _preprocess_padding(padding): 55 | """Convert keras' padding to tensorflow's padding. 56 | # Arguments 57 | padding: string, `"same"` or `"valid"`. 58 | # Returns 59 | a string, `"SAME"` or `"VALID"`. 60 | # Raises 61 | ValueError: if `padding` is invalid. 62 | """ 63 | if padding == 'same': 64 | padding = 'SAME' 65 | elif padding == 'valid': 66 | padding = 'VALID' 67 | else: 68 | raise ValueError('Invalid padding:', padding) 69 | return padding 70 | 71 | 72 | def conv2d(x, kernel, strides=(1, 1), padding='valid', data_format='channels_first', 73 | image_shape=None, filter_shape=None): 74 | '''2D convolution. 75 | # Arguments 76 | kernel: kernel tensor. 77 | strides: strides tuple. 78 | padding: string, "same" or "valid". 79 | data_format: "tf" or "th". Whether to use Theano or TensorFlow dimension ordering 80 | in inputs/kernels/ouputs. 81 | ''' 82 | if padding == 'same': 83 | padding = 'SAME' 84 | elif padding == 'valid': 85 | padding = 'VALID' 86 | else: 87 | raise Exception('Invalid border mode: ' + str(padding)) 88 | 89 | strides = (1,) + strides + (1,) 90 | 91 | if floatx() == 'float64': 92 | # tf conv2d only supports float32 93 | x = tf.cast(x, 'float32') 94 | kernel = tf.cast(kernel, 'float32') 95 | 96 | if data_format == 'channels_first': 97 | # TF uses the last dimension as channel dimension, 98 | # instead of the 2nd one. 99 | # TH input shape: (samples, input_depth, rows, cols) 100 | # TF input shape: (samples, rows, cols, input_depth) 101 | # TH kernel shape: (depth, input_depth, rows, cols) 102 | # TF kernel shape: (rows, cols, input_depth, depth) 103 | x = tf.transpose(x, (0, 2, 3, 1)) 104 | kernel = tf.transpose(kernel, (2, 3, 1, 0)) 105 | x = tf.nn.conv2d(x, kernel, strides, padding=padding) 106 | x = tf.transpose(x, (0, 3, 1, 2)) 107 | elif data_format == 'channels_last': 108 | x = tf.nn.conv2d(x, kernel, strides, padding=padding) 109 | else: 110 | raise Exception('Unknown data_format: ' + str(data_format)) 111 | 112 | if floatx() == 'float64': 113 | x = tf.cast(x, 'float64') 114 | return x 115 | 116 | 117 | def extract_image_patches(x, ksizes, ssizes, padding='same', 118 | data_format='channels_last'): 119 | ''' 120 | Extract the patches from an image 121 | # Parameters 122 | x : The input image 123 | ksizes : 2-d tuple with the kernel size 124 | ssizes : 2-d tuple with the strides size 125 | padding : 'same' or 'valid' 126 | data_format : 'channels_last' or 'channels_first' 127 | # Returns 128 | The (k_w,k_h) patches extracted 129 | TF ==> (batch_size,w,h,k_w,k_h,c) 130 | TH ==> (batch_size,w,h,c,k_w,k_h) 131 | ''' 132 | kernel = [1, ksizes[0], ksizes[1], 1] 133 | strides = [1, ssizes[0], ssizes[1], 1] 134 | padding = _preprocess_padding(padding) 135 | if data_format == 'channels_first': 136 | x = KTF.permute_dimensions(x, (0, 2, 3, 1)) 137 | bs_i, w_i, h_i, ch_i = KTF.int_shape(x) 138 | patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1], 139 | padding) 140 | # Reshaping to fit Theano 141 | bs, w, h, ch = KTF.int_shape(patches) 142 | patches = tf.reshape(tf.transpose(tf.reshape(patches, [-1, w, h, tf.floordiv(ch, ch_i), ch_i]), [0, 1, 2, 4, 3]), 143 | [-1, w, h, ch_i, ksizes[0], ksizes[1]]) 144 | if data_format == 'channels_last': 
145 | patches = KTF.permute_dimensions(patches, [0, 1, 2, 4, 5, 3]) 146 | return patches 147 | 148 | 149 | def depth_to_space(input, scale, data_format=None): 150 | ''' Uses phase shift algorithm to convert channels/depth for spatial resolution ''' 151 | if data_format is None: 152 | data_format = image_data_format() 153 | data_format = data_format.lower() 154 | input = _preprocess_conv2d_input(input, data_format) 155 | out = tf.depth_to_space(input, scale) 156 | out = _postprocess_conv2d_output(out, data_format) 157 | return out 158 | 159 | 160 | def moments(x, axes, shift=None, keep_dims=False): 161 | ''' Wrapper over tensorflow backend call ''' 162 | 163 | return tf.nn.moments(x, axes, shift=shift, keep_dims=keep_dims) 164 | 165 | 166 | def clip(x, min_value, max_value): 167 | """Element-wise value clipping. 168 | If min_value > max_value, clipping range is [min_value,min_value]. 169 | # Arguments 170 | x: Tensor or variable. 171 | min_value: Tensor, float, int, or None. 172 | If min_value is None, defaults to -infinity. 173 | max_value: Tensor, float, int, or None. 174 | If max_value is None, defaults to infinity. 175 | # Returns 176 | A tensor. 177 | """ 178 | if max_value is None: 179 | max_value = np.inf 180 | if min_value is None: 181 | min_value = -np.inf 182 | min_value = _to_tensor(min_value, x.dtype.base_dtype) 183 | max_value = _to_tensor(max_value, x.dtype.base_dtype) 184 | max_value = tf.maximum(min_value, max_value) 185 | return tf.clip_by_value(x, min_value, max_value) 186 | -------------------------------------------------------------------------------- /run_aecnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reimplementing AECNN model minimizing L1 loss. 3 | No discriminator. 4 | 5 | 6 | Written by Deepak Baby, UGent, Oct 2018. 
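The generator network from models.py is trained directly on (noisy, clean) waveform pairs with a mean absolute error (L1) objective.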
7 | """ 8 | from __future__ import print_function 9 | 10 | import tensorflow as tf 11 | from tensorflow.contrib.layers import xavier_initializer, flatten, fully_connected 12 | import numpy as np 13 | import keras 14 | from keras.layers import Input, Dense, Conv1D, Conv2D, Conv2DTranspose, BatchNormalization 15 | from keras.layers import LeakyReLU, PReLU, Reshape, Concatenate, Flatten, Add, Lambda 16 | from keras.models import Sequential, Model, model_from_json 17 | from keras.optimizers import Adam 18 | from keras.callbacks import TensorBoard 19 | keras_backend = tf.keras.backend 20 | keras_initializers = tf.keras.initializers 21 | from data_ops import * 22 | from file_ops import * 23 | from models import * 24 | import keras.backend as K 25 | 26 | import time 27 | from tqdm import * 28 | import h5py 29 | import os,sys 30 | import scipy.io.wavfile as wavfile 31 | 32 | if __name__ == '__main__': 33 | 34 | # Various GAN options 35 | opts = {} 36 | opts ['dirhead'] = "AECNN_L1loss" 37 | opts ['z_off'] = True # set to True to omit the latent noise input 38 | # normalization 39 | ################################# 40 | # Only one of the following should be set to True 41 | opts ['applyinstancenorm'] = False 42 | opts ['applybatchrenorm'] = False 43 | opts ['applybatchnorm'] = False 44 | opts ['applygroupnorm'] = False 45 | opts ['applyspectralnorm'] = False 46 | ################################## 47 | # Show model summary 48 | opts ['show_summary'] = True 49 | 50 | ## Set the matfiles 51 | clean_train_matfile = "./data/clean_train_segan1d.mat" 52 | noisy_train_matfile = "./data/noisy_train_segan1d.mat" 53 | noisy_test_matfile = "./data/noisy_test_segan1d.mat" 54 | 55 | #################################################### 56 | # Other fixed options 57 | opts ['window_length'] = 2**14 58 | opts ['featdim'] = 1 # 1 since it is just 1d time samples 59 | opts ['filterlength'] = 31 60 | opts ['strides'] = 2 61 | opts ['padding'] = 'SAME' 62 | opts ['g_enc_numkernels'] = [16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024] 63 | opts ['d_fmaps'] = opts ['g_enc_numkernels'] # We use the same structure for discriminator 64 | opts['leakyrelualpha'] = 0.3 65 | opts ['batch_size'] = 100 66 | opts ['applyprelu'] = True 67 | opts ['preemph'] = 0.95 68 | 69 | 70 | opts ['d_activation'] = 'leakyrelu' 71 | g_enc_numkernels = opts ['g_enc_numkernels'] 72 | opts ['g_dec_numkernels'] = g_enc_numkernels[:-1][::-1] + [1] 73 | opts ['gt_stride'] = 2 74 | opts ['g_l1loss'] = 100. 
75 | opts ['d_lr'] = 0.0002 76 | opts ['g_lr'] = 0.0002 77 | opts ['random_seed'] = 111 78 | 79 | n_epochs = 81 80 | fs = 16000 81 | 82 | # set flags for training or testing 83 | TRAIN_SEGAN = True 84 | SAVE_MODEL = True 85 | LOAD_SAVED_MODEL = False 86 | TEST_SEGAN = True 87 | 88 | modeldir = get_modeldirname(opts) 89 | print ("The model directory is " + modeldir) 90 | print ("_____________________________________") 91 | 92 | if not os.path.exists(modeldir): 93 | os.makedirs(modeldir) 94 | 95 | # Obtain the generator and the discriminator 96 | G = generator(opts) 97 | 98 | # Define optimizers 99 | g_opt = keras.optimizers.Adam(lr=opts['g_lr']) 100 | 101 | # The G model has the wav and the noise inputs 102 | wav_shape = (opts['window_length'], opts['featdim']) 103 | wav_in_noisy = Input(shape=wav_shape, name="main_input_noisy") 104 | 105 | G_wav = G(wav_in_noisy) 106 | G = Model(wav_in_noisy, G_wav) 107 | G.summary() 108 | 109 | # compile individual models 110 | G.compile(loss='mean_absolute_error', optimizer=g_opt) 111 | 112 | 113 | if TEST_SEGAN: 114 | ftestnoisy = h5py.File(noisy_test_matfile) 115 | noisy_test_data = ftestnoisy['feat_data'] 116 | noisy_test_dfi = ftestnoisy['dfi'] 117 | print ("Number of test files: " + str(noisy_test_dfi.shape[1]) ) 118 | 119 | 120 | # Begin the training part 121 | if TRAIN_SEGAN: 122 | fclean = h5py.File(clean_train_matfile) 123 | clean_train_data = np.array(fclean['feat_data']).astype('float32') 124 | fnoisy = h5py.File(noisy_train_matfile) 125 | noisy_train_data = np.array(fnoisy['feat_data']).astype('float32') 126 | numtrainsamples = clean_train_data.shape[1] 127 | idx_all = np.arange(numtrainsamples) 128 | # set random seed 129 | np.random.seed(opts['random_seed']) 130 | batch_size = opts['batch_size'] 131 | 132 | print ("********************************************") 133 | print (" SEGAN TRAINING ") 134 | print ("********************************************") 135 | print ("Shape of clean feats mat " + str(clean_train_data.shape)) 136 | print ("Shape of noisy feats mat " + str(noisy_train_data.shape)) 137 | numtrainsamples = clean_train_data.shape[1] 138 | 139 | # Tensorboard stuff 140 | log_path = './logs/' + modeldir 141 | callback = TensorBoard(log_path) 142 | callback.set_model(G) 143 | train_names = ['G_loss'] 144 | 145 | idx_all = np.arange(numtrainsamples) 146 | # set random seed 147 | np.random.seed(opts['random_seed']) 148 | 149 | batch_size = opts['batch_size'] 150 | num_batches_per_epoch = int(np.floor(clean_train_data.shape[1]/batch_size)) 151 | for epoch in range(n_epochs): 152 | # train G on minibatches (there is no discriminator in AECNN) 153 | np.random.shuffle(idx_all) # shuffle the indices for the next epoch 154 | for batch_idx in range(num_batches_per_epoch): 155 | start_time = time.time() 156 | idx_beg = batch_idx * batch_size 157 | idx_end = idx_beg + batch_size 158 | idx = np.sort(np.array(idx_all[idx_beg:idx_end])) 159 | #print ("Batch idx " + str(idx[:5]) +" ... " + str(idx[-5:])) 160 | cleanwavs = np.array(clean_train_data[:,idx]).T 161 | cleanwavs = data_preprocess(cleanwavs, preemph=opts['preemph']) 162 | cleanwavs = np.expand_dims(cleanwavs, axis = 2) 163 | noisywavs = np.array(noisy_train_data[:,idx]).T 164 | noisywavs = data_preprocess(noisywavs, preemph=opts['preemph']) 165 | noisywavs = np.expand_dims(noisywavs, axis = 2) 166 | 167 | g_loss = G.train_on_batch(noisywavs, cleanwavs) 168 | 169 | time_taken = time.time() - start_time 170 | 171 | printlog = "E%d/%d:B%d/%d [G loss: %f] [Exec. 
time: %f]" % (epoch, n_epochs, batch_idx, num_batches_per_epoch, g_loss, time_taken) 172 | 173 | print (printlog) 174 | # Tensorboard stuff 175 | logs = [g_loss] 176 | write_log(callback, train_names, logs, epoch) 177 | 178 | if (TEST_SEGAN and epoch % 10 == 0) or epoch == n_epochs - 1: 179 | print ("********************************************") 180 | print (" SEGAN TESTING ") 181 | print ("********************************************") 182 | 183 | resultsdir = modeldir + "/test_results_epoch" + str(epoch) 184 | if not os.path.exists(resultsdir): 185 | os.makedirs(resultsdir) 186 | 187 | if LOAD_SAVED_MODEL: 188 | print ("Loading model from " + modeldir + "/Gmodel") 189 | json_file = open(modeldir + "/Gmodel.json", "r") 190 | loaded_model_json = json_file.read() 191 | json_file.close() 192 | G_loaded = model_from_json(loaded_model_json) 193 | G_loaded.compile(loss='mean_squared_error', optimizer=g_opt) 194 | G_loaded.load_weights(modeldir + "/Gmodel.h5") 195 | else: 196 | G_loaded = G 197 | 198 | print ("Saving Results to " + resultsdir) 199 | 200 | for test_num in tqdm(range(noisy_test_dfi.shape[1])) : 201 | test_beg = noisy_test_dfi[0, test_num] 202 | test_end = noisy_test_dfi[1, test_num] 203 | #print ("Reading indices " + str(test_beg) + " to " + str(test_end)) 204 | noisywavs = np.array(noisy_test_data[:,test_beg:test_end]).T 205 | noisywavs = data_preprocess(noisywavs, preemph=opts['preemph']) 206 | noisywavs = np.expand_dims(noisywavs, axis = 2) 207 | if not opts['z_off']: 208 | noiseinput = np.random.normal(0, 1, (noisywavs.shape[0], z_dim1, z_dim2)) # note: z_dim1/z_dim2 are only defined in the GAN scripts; keep z_off=True here 209 | cleaned_wavs = G_loaded.predict([noisywavs, noiseinput]) 210 | else : 211 | cleaned_wavs = G_loaded.predict(noisywavs) 212 | 213 | cleaned_wavs = np.reshape(cleaned_wavs, (noisywavs.shape[0], noisywavs.shape[1])) 214 | cleanwav = reconstruct_wav(cleaned_wavs) 215 | cleanwav = np.reshape(cleanwav, (-1,)) # make it to 1d by dropping the extra dimension 216 | 217 | if opts['preemph'] > 0: 218 | cleanwav = de_emph(cleanwav, coeff=opts['preemph']) 219 | 220 | destfilename = resultsdir + "/testwav_%d.wav" % (test_num) 221 | wavfile.write(destfilename, fs, cleanwav) 222 | 223 | 224 | 225 | # Finally, save the model 226 | if SAVE_MODEL: 227 | model_json = G.to_json() 228 | with open(modeldir + "/Gmodel.json", "w") as json_file: 229 | json_file.write(model_json) 230 | G.save_weights(modeldir + "/Gmodel.h5") 231 | print ("Model saved to " + modeldir) 232 | -------------------------------------------------------------------------------- /run_lsgan_se.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reimplementing SEGAN paper as close as possible in Keras. 3 | But use instance normalization instead of virtual batch normalization. 4 | Deepak Baby, UGent, June 2018. 
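Training follows the least-squares GAN recipe: the discriminator is fit with squared-error targets (opts['D_real_target'] for real pairs, 0 for enhanced pairs), while the generator minimizes the least-squares adversarial loss plus an L1 term weighted by opts['g_l1loss'].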
5 | """ 6 | from __future__ import print_function 7 | 8 | import tensorflow as tf 9 | from tensorflow.contrib.layers import xavier_initializer, flatten, fully_connected 10 | import numpy as np 11 | import keras 12 | from keras.layers import Input, Dense, Conv1D, Conv2D, Conv2DTranspose, BatchNormalization 13 | from keras.layers import LeakyReLU, PReLU, Reshape, Concatenate, Flatten 14 | from keras.models import Sequential, Model, model_from_json 15 | from keras.optimizers import Adam 16 | from keras.callbacks import TensorBoard 17 | keras_backend = tf.keras.backend 18 | keras_initializers = tf.keras.initializers 19 | from data_ops import * 20 | from file_ops import * 21 | from models import * 22 | 23 | import time 24 | from tqdm import * 25 | import h5py 26 | import os,sys 27 | import scipy.io.wavfile as wavfile 28 | 29 | if __name__ == '__main__': 30 | 31 | # Various GAN options 32 | opts = {} 33 | opts ['dirhead'] = "LSGAN" 34 | opts ['z_off'] = True # set to True to omit the latent noise input 35 | # normalization 36 | ################################# 37 | # Only one of the following should be set to True 38 | opts ['applyinstancenorm'] = True 39 | opts ['applybn'] = False 40 | ################################## 41 | # Show model summary 42 | opts ['show_summary'] = True 43 | 44 | ## Set the matfiles 45 | clean_train_matfile = "./data/clean_train_segan1d.mat" 46 | noisy_train_matfile = "./data/noisy_train_segan1d.mat" 47 | noisy_test_matfile = "./data/noisy_test_segan1d.mat" 48 | 49 | #################################################### 50 | # Other fixed options 51 | opts ['window_length'] = 2**14 52 | opts ['featdim'] = 1 # 1 since it is just 1d time samples 53 | opts ['filterlength'] = 31 54 | opts ['strides'] = 2 55 | opts ['padding'] = 'SAME' 56 | opts ['g_enc_numkernels'] = [16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024] 57 | opts ['d_fmaps'] = opts ['g_enc_numkernels'] # We use the same structure for discriminator 58 | opts['leakyrelualpha'] = 0.3 59 | opts ['batch_size'] = 100 60 | opts ['applyprelu'] = True 61 | opts ['preemph'] = 0.95 62 | opts ['D_real_target'] = 1. # Use 0.9 or 0.95 if you want to apply label smoothing 63 | 64 | opts ['d_activation'] = 'leakyrelu' 65 | g_enc_numkernels = opts ['g_enc_numkernels'] 66 | opts ['g_dec_numkernels'] = g_enc_numkernels[:-1][::-1] + [1] 67 | opts ['gt_stride'] = 2 68 | opts ['g_l1loss'] = 200. 
69 | opts ['d_lr'] = 0.0002 70 | opts ['g_lr'] = 0.0002 71 | opts ['random_seed'] = 111 72 | 73 | n_epochs = 81 74 | fs = 16000 75 | 76 | # set flags for training or testing 77 | TRAIN_SEGAN = True 78 | SAVE_MODEL = True 79 | LOAD_SAVED_MODEL = False 80 | TEST_SEGAN = True 81 | 82 | modeldir = get_modeldirname(opts) 83 | print ("The model directory is " + modeldir) 84 | print ("_____________________________________") 85 | 86 | if not os.path.exists(modeldir): 87 | os.makedirs(modeldir) 88 | 89 | # Obtain the generator and the discriminator 90 | D = discriminator(opts) 91 | G = generator(opts) 92 | 93 | # Define optimizers 94 | g_opt = keras.optimizers.Adam(lr=opts['g_lr']) 95 | d_opt = keras.optimizers.Adam(lr=opts['d_lr']) 96 | 97 | # The G model has the wav and the noise inputs 98 | wav_shape = (opts['window_length'], opts['featdim']) 99 | z_dim1 = int(opts['window_length']/ (opts ['strides'] ** len(opts ['g_enc_numkernels']))) 100 | z_dim2 = opts ['g_enc_numkernels'][-1] 101 | wav_in_clean = Input(shape=wav_shape, name="main_input_clean") 102 | wav_in_noisy = Input(shape=wav_shape, name="main_input_noisy") 103 | if not opts ['z_off']: 104 | z = Input (shape=(z_dim1, z_dim2), name="noise_input") 105 | G_wav = G([wav_in_noisy, z]) 106 | G = Model([wav_in_noisy, z], G_wav) 107 | else : 108 | G_wav = G(wav_in_noisy) 109 | G = Model(wav_in_noisy, G_wav) 110 | 111 | d_out = D([wav_in_clean, wav_in_noisy]) 112 | D = Model([wav_in_clean, wav_in_noisy], d_out) 113 | G.summary() 114 | D.summary() 115 | 116 | # compile individual models 117 | D.compile(loss='mean_squared_error', optimizer=d_opt) 118 | G.compile(loss='mean_absolute_error', optimizer=g_opt) 119 | 120 | # for the combined model, we set the discriminator to be not trainable 121 | D.trainable = False 122 | D_out = D([G_wav, wav_in_noisy]) 123 | if not opts ['z_off']: 124 | G_D = Model(inputs=[wav_in_clean, wav_in_noisy, z], outputs=[D_out, G_wav]) 125 | else : 126 | G_D = Model(inputs=[wav_in_clean, wav_in_noisy], outputs=[D_out, G_wav]) 127 | G_D.summary() 128 | 129 | G_D.compile(optimizer=g_opt, 130 | loss={'model_2': 'mean_absolute_error', 'model_4': 'mean_squared_error'}, 131 | loss_weights = {'model_2' : opts['g_l1loss'], 'model_4': 1} ) 132 | print (G_D.metrics_names) 133 | 134 | #exit () 135 | 136 | if TEST_SEGAN: 137 | ftestnoisy = h5py.File(noisy_test_matfile) 138 | noisy_test_data = ftestnoisy['feat_data'] 139 | noisy_test_dfi = ftestnoisy['dfi'] 140 | print ("Number of test files: " + str(noisy_test_dfi.shape[1]) ) 141 | 142 | 143 | # Begin the training part 144 | if TRAIN_SEGAN: 145 | fclean = h5py.File(clean_train_matfile) 146 | clean_train_data = np.array(fclean['feat_data']).astype('float32') 147 | fnoisy = h5py.File(noisy_train_matfile) 148 | noisy_train_data = np.array(fnoisy['feat_data']).astype('float32') 149 | numtrainsamples = clean_train_data.shape[1] 150 | idx_all = np.arange(numtrainsamples) 151 | # set random seed 152 | np.random.seed(opts['random_seed']) 153 | batch_size = opts['batch_size'] 154 | 155 | print ("********************************************") 156 | print (" SEGAN TRAINING ") 157 | print ("********************************************") 158 | print ("Shape of clean feats mat " + str(clean_train_data.shape)) 159 | print ("Shape of noisy feats mat " + str(noisy_train_data.shape)) 160 | numtrainsamples = clean_train_data.shape[1] 161 | 162 | # Tensorboard stuff 163 | log_path = './logs/' + modeldir 164 | callback = TensorBoard(log_path) 165 | callback.set_model(G_D) 166 | train_names = ['G_loss', 
'G_adv_loss', 'G_l1Loss'] 167 | 168 | idx_all = np.arange(numtrainsamples) 169 | # set random seed 170 | np.random.seed(opts['random_seed']) 171 | 172 | batch_size = opts['batch_size'] 173 | num_batches_per_epoch = int(np.floor(clean_train_data.shape[1]/batch_size)) 174 | for epoch in range(n_epochs): 175 | # train D with minibatch 176 | np.random.shuffle(idx_all) # shuffle the indices for the next epoch 177 | for batch_idx in range(num_batches_per_epoch): 178 | start_time = time.time() 179 | idx_beg = batch_idx * batch_size 180 | idx_end = idx_beg + batch_size 181 | idx = np.sort(np.array(idx_all[idx_beg:idx_end])) 182 | #print ("Batch idx " + str(idx[:5]) +" ... " + str(idx[-5:])) 183 | cleanwavs = np.array(clean_train_data[:,idx]).T 184 | cleanwavs = data_preprocess(cleanwavs, preemph=opts['preemph']) 185 | cleanwavs = np.expand_dims(cleanwavs, axis = 2) 186 | noisywavs = np.array(noisy_train_data[:,idx]).T 187 | noisywavs = data_preprocess(noisywavs, preemph=opts['preemph']) 188 | noisywavs = np.expand_dims(noisywavs, axis = 2) 189 | if not opts ['z_off']: 190 | noiseinput = np.random.normal(0, 1, (batch_size, z_dim1, z_dim2)) 191 | g_out = G.predict([noisywavs, noiseinput]) 192 | else : 193 | g_out = G.predict(noisywavs) 194 | 195 | # train D 196 | d_loss_real = D.train_on_batch ({'main_input_clean':cleanwavs, 'main_input_noisy':noisywavs}, 197 | opts ['D_real_target'] * np.ones((batch_size,1))) 198 | d_loss_fake = D.train_on_batch ({'main_input_clean':g_out, 'main_input_noisy':noisywavs}, 199 | np.zeros((batch_size,1))) 200 | d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) 201 | 202 | # Train the combined model next; here, only the generator part is updated 203 | valid_g = np.array([1]*batch_size) # generator wants discriminator to give 1 (identify fake as real) 204 | if not opts['z_off']: 205 | [g_loss, g_dLoss, g_l1loss] = G_D.train_on_batch({'main_input_clean': cleanwavs, 206 | 'main_input_noisy': noisywavs, 'noise_input': noiseinput}, 207 | {'model_2': cleanwavs, 'model_4': valid_g} ) 208 | else: 209 | [g_loss, g_dLoss, g_l1loss] = G_D.train_on_batch({'main_input_clean': cleanwavs, 210 | 'main_input_noisy': noisywavs},{'model_2': cleanwavs, 211 | 'model_4': valid_g} ) 212 | time_taken = time.time() - start_time 213 | 214 | printlog = "E%d/%d:B%d/%d [D loss: %f] [D real loss: %f] [D fake loss: %f] [G loss: %f] [G_D loss: %f] [G_L1 loss: %f] [Exec. 
time: %f]" % (epoch, n_epochs, batch_idx, num_batches_per_epoch, d_loss, d_loss_real, d_loss_fake, g_loss, g_dLoss, g_l1loss, time_taken) 215 | 216 | print (printlog) 217 | # Tensorboard stuff 218 | logs = [g_loss, g_dLoss, g_l1loss] 219 | write_log(callback, train_names, logs, epoch) 220 | 221 | if (TEST_SEGAN and epoch % 10 == 0) or epoch == n_epochs - 1: 222 | print ("********************************************") 223 | print (" SEGAN TESTING ") 224 | print ("********************************************") 225 | 226 | resultsdir = modeldir + "/test_results_epoch" + str(epoch) 227 | if not os.path.exists(resultsdir): 228 | os.makedirs(resultsdir) 229 | 230 | if LOAD_SAVED_MODEL: 231 | print ("Loading model from " + modeldir + "/Gmodel") 232 | json_file = open(modeldir + "/Gmodel.json", "r") 233 | loaded_model_json = json_file.read() 234 | json_file.close() 235 | G_loaded = model_from_json(loaded_model_json) 236 | G_loaded.compile(loss='mean_squared_error', optimizer=g_opt) 237 | G_loaded.load_weights(modeldir + "/Gmodel.h5") 238 | else: 239 | G_loaded = G 240 | 241 | print ("Saving Results to " + resultsdir) 242 | 243 | for test_num in tqdm(range(noisy_test_dfi.shape[1])) : 244 | test_beg = noisy_test_dfi[0, test_num] 245 | test_end = noisy_test_dfi[1, test_num] 246 | #print ("Reading indices " + str(test_beg) + " to " + str(test_end)) 247 | noisywavs = np.array(noisy_test_data[:,test_beg:test_end]).T 248 | noisywavs = data_preprocess(noisywavs, preemph=opts['preemph']) 249 | noisywavs = np.expand_dims(noisywavs, axis = 2) 250 | if not opts['z_off']: 251 | noiseinput = np.random.normal(0, 1, (noisywavs.shape[0], z_dim1, z_dim2)) 252 | cleaned_wavs = G_loaded.predict([noisywavs, noiseinput]) 253 | else : 254 | cleaned_wavs = G_loaded.predict(noisywavs) 255 | 256 | cleaned_wavs = np.reshape(cleaned_wavs, (noisywavs.shape[0], noisywavs.shape[1])) 257 | cleanwav = reconstruct_wav(cleaned_wavs) 258 | cleanwav = np.reshape(cleanwav, (-1,)) # make it to 1d by dropping the extra dimension 259 | 260 | if opts['preemph'] > 0: 261 | cleanwav = de_emph(cleanwav, coeff=opts['preemph']) 262 | 263 | destfilename = resultsdir + "/testwav_%d.wav" % (test_num) 264 | wavfile.write(destfilename, fs, cleanwav) 265 | 266 | 267 | 268 | # Finally, save the model 269 | if SAVE_MODEL: 270 | model_json = G.to_json() 271 | with open(modeldir + "/Gmodel.json", "w") as json_file: 272 | json_file.write(model_json) 273 | G.save_weights(modeldir + "/Gmodel.h5") 274 | print ("Model saved to " + modeldir) 275 | -------------------------------------------------------------------------------- /test_wav.txt: -------------------------------------------------------------------------------- 1 | p232_001.wav 2 | p232_002.wav 3 | p232_003.wav 4 | p232_005.wav 5 | p232_006.wav 6 | p232_007.wav 7 | p232_009.wav 8 | p232_010.wav 9 | p232_011.wav 10 | p232_012.wav 11 | p232_013.wav 12 | p232_014.wav 13 | p232_015.wav 14 | p232_016.wav 15 | p232_017.wav 16 | p232_019.wav 17 | p232_020.wav 18 | p232_021.wav 19 | p232_022.wav 20 | p232_023.wav 21 | p232_024.wav 22 | p232_025.wav 23 | p232_027.wav 24 | p232_028.wav 25 | p232_029.wav 26 | p232_030.wav 27 | p232_031.wav 28 | p232_032.wav 29 | p232_033.wav 30 | p232_034.wav 31 | p232_035.wav 32 | p232_036.wav 33 | p232_037.wav 34 | p232_038.wav 35 | p232_039.wav 36 | p232_040.wav 37 | p232_041.wav 38 | p232_042.wav 39 | p232_043.wav 40 | p232_044.wav 41 | p232_045.wav 42 | p232_046.wav 43 | p232_047.wav 44 | p232_048.wav 45 | p232_049.wav 46 | p232_050.wav 47 | p232_051.wav 48 | p232_052.wav 
49 | p232_053.wav 50 | p232_054.wav 51 | p232_055.wav 52 | p232_056.wav 53 | p232_057.wav 54 | p232_058.wav 55 | p232_059.wav 56 | p232_060.wav 57 | p232_061.wav 58 | p232_062.wav 59 | p232_063.wav 60 | p232_064.wav 61 | p232_065.wav 62 | p232_066.wav 63 | p232_067.wav 64 | p232_068.wav 65 | p232_069.wav 66 | p232_070.wav 67 | p232_071.wav 68 | p232_072.wav 69 | p232_073.wav 70 | p232_074.wav 71 | p232_075.wav 72 | p232_076.wav 73 | p232_077.wav 74 | p232_078.wav 75 | p232_079.wav 76 | p232_080.wav 77 | p232_081.wav 78 | p232_082.wav 79 | p232_083.wav 80 | p232_084.wav 81 | p232_085.wav 82 | p232_086.wav 83 | p232_087.wav 84 | p232_088.wav 85 | p232_089.wav 86 | p232_090.wav 87 | p232_091.wav 88 | p232_092.wav 89 | p232_093.wav 90 | p232_094.wav 91 | p232_095.wav 92 | p232_096.wav 93 | p232_097.wav 94 | p232_098.wav 95 | p232_099.wav 96 | p232_100.wav 97 | p232_101.wav 98 | p232_102.wav 99 | p232_103.wav 100 | p232_104.wav 101 | p232_105.wav 102 | p232_106.wav 103 | p232_107.wav 104 | p232_108.wav 105 | p232_109.wav 106 | p232_110.wav 107 | p232_112.wav 108 | p232_113.wav 109 | p232_114.wav 110 | p232_115.wav 111 | p232_116.wav 112 | p232_117.wav 113 | p232_118.wav 114 | p232_119.wav 115 | p232_120.wav 116 | p232_121.wav 117 | p232_123.wav 118 | p232_124.wav 119 | p232_125.wav 120 | p232_126.wav 121 | p232_127.wav 122 | p232_128.wav 123 | p232_129.wav 124 | p232_130.wav 125 | p232_131.wav 126 | p232_132.wav 127 | p232_133.wav 128 | p232_134.wav 129 | p232_135.wav 130 | p232_136.wav 131 | p232_137.wav 132 | p232_138.wav 133 | p232_139.wav 134 | p232_140.wav 135 | p232_141.wav 136 | p232_142.wav 137 | p232_143.wav 138 | p232_144.wav 139 | p232_145.wav 140 | p232_146.wav 141 | p232_147.wav 142 | p232_148.wav 143 | p232_150.wav 144 | p232_151.wav 145 | p232_152.wav 146 | p232_153.wav 147 | p232_154.wav 148 | p232_155.wav 149 | p232_156.wav 150 | p232_158.wav 151 | p232_159.wav 152 | p232_160.wav 153 | p232_161.wav 154 | p232_162.wav 155 | p232_163.wav 156 | p232_164.wav 157 | p232_165.wav 158 | p232_167.wav 159 | p232_169.wav 160 | p232_170.wav 161 | p232_171.wav 162 | p232_172.wav 163 | p232_173.wav 164 | p232_174.wav 165 | p232_175.wav 166 | p232_176.wav 167 | p232_177.wav 168 | p232_178.wav 169 | p232_179.wav 170 | p232_180.wav 171 | p232_181.wav 172 | p232_182.wav 173 | p232_183.wav 174 | p232_184.wav 175 | p232_185.wav 176 | p232_186.wav 177 | p232_187.wav 178 | p232_188.wav 179 | p232_189.wav 180 | p232_190.wav 181 | p232_191.wav 182 | p232_193.wav 183 | p232_194.wav 184 | p232_195.wav 185 | p232_196.wav 186 | p232_197.wav 187 | p232_198.wav 188 | p232_199.wav 189 | p232_200.wav 190 | p232_201.wav 191 | p232_202.wav 192 | p232_203.wav 193 | p232_204.wav 194 | p232_205.wav 195 | p232_206.wav 196 | p232_207.wav 197 | p232_208.wav 198 | p232_209.wav 199 | p232_210.wav 200 | p232_211.wav 201 | p232_213.wav 202 | p232_214.wav 203 | p232_215.wav 204 | p232_216.wav 205 | p232_217.wav 206 | p232_218.wav 207 | p232_219.wav 208 | p232_220.wav 209 | p232_221.wav 210 | p232_223.wav 211 | p232_224.wav 212 | p232_225.wav 213 | p232_226.wav 214 | p232_227.wav 215 | p232_228.wav 216 | p232_229.wav 217 | p232_230.wav 218 | p232_231.wav 219 | p232_232.wav 220 | p232_234.wav 221 | p232_235.wav 222 | p232_236.wav 223 | p232_237.wav 224 | p232_238.wav 225 | p232_239.wav 226 | p232_240.wav 227 | p232_241.wav 228 | p232_242.wav 229 | p232_243.wav 230 | p232_244.wav 231 | p232_245.wav 232 | p232_246.wav 233 | p232_247.wav 234 | p232_248.wav 235 | p232_249.wav 236 | p232_250.wav 237 | p232_251.wav 238 | 
p232_252.wav 239 | p232_253.wav 240 | p232_254.wav 241 | p232_255.wav 242 | p232_256.wav 243 | p232_257.wav 244 | p232_258.wav 245 | p232_259.wav 246 | p232_260.wav 247 | p232_261.wav 248 | p232_263.wav 249 | p232_264.wav 250 | p232_265.wav 251 | p232_266.wav 252 | p232_267.wav 253 | p232_268.wav 254 | p232_269.wav 255 | p232_270.wav 256 | p232_271.wav 257 | p232_272.wav 258 | p232_273.wav 259 | p232_274.wav 260 | p232_275.wav 261 | p232_276.wav 262 | p232_277.wav 263 | p232_278.wav 264 | p232_279.wav 265 | p232_280.wav 266 | p232_281.wav 267 | p232_282.wav 268 | p232_283.wav 269 | p232_284.wav 270 | p232_285.wav 271 | p232_286.wav 272 | p232_287.wav 273 | p232_288.wav 274 | p232_289.wav 275 | p232_290.wav 276 | p232_291.wav 277 | p232_292.wav 278 | p232_293.wav 279 | p232_294.wav 280 | p232_295.wav 281 | p232_296.wav 282 | p232_297.wav 283 | p232_298.wav 284 | p232_299.wav 285 | p232_300.wav 286 | p232_301.wav 287 | p232_302.wav 288 | p232_303.wav 289 | p232_305.wav 290 | p232_306.wav 291 | p232_307.wav 292 | p232_308.wav 293 | p232_309.wav 294 | p232_310.wav 295 | p232_311.wav 296 | p232_312.wav 297 | p232_313.wav 298 | p232_314.wav 299 | p232_315.wav 300 | p232_316.wav 301 | p232_317.wav 302 | p232_318.wav 303 | p232_319.wav 304 | p232_320.wav 305 | p232_321.wav 306 | p232_322.wav 307 | p232_323.wav 308 | p232_324.wav 309 | p232_325.wav 310 | p232_326.wav 311 | p232_327.wav 312 | p232_328.wav 313 | p232_329.wav 314 | p232_330.wav 315 | p232_331.wav 316 | p232_332.wav 317 | p232_333.wav 318 | p232_334.wav 319 | p232_335.wav 320 | p232_336.wav 321 | p232_337.wav 322 | p232_338.wav 323 | p232_339.wav 324 | p232_340.wav 325 | p232_341.wav 326 | p232_342.wav 327 | p232_343.wav 328 | p232_344.wav 329 | p232_346.wav 330 | p232_347.wav 331 | p232_348.wav 332 | p232_349.wav 333 | p232_350.wav 334 | p232_351.wav 335 | p232_352.wav 336 | p232_353.wav 337 | p232_354.wav 338 | p232_355.wav 339 | p232_356.wav 340 | p232_357.wav 341 | p232_358.wav 342 | p232_359.wav 343 | p232_360.wav 344 | p232_361.wav 345 | p232_362.wav 346 | p232_363.wav 347 | p232_364.wav 348 | p232_365.wav 349 | p232_366.wav 350 | p232_367.wav 351 | p232_368.wav 352 | p232_369.wav 353 | p232_370.wav 354 | p232_371.wav 355 | p232_372.wav 356 | p232_373.wav 357 | p232_374.wav 358 | p232_375.wav 359 | p232_377.wav 360 | p232_378.wav 361 | p232_379.wav 362 | p232_380.wav 363 | p232_381.wav 364 | p232_382.wav 365 | p232_383.wav 366 | p232_384.wav 367 | p232_385.wav 368 | p232_386.wav 369 | p232_387.wav 370 | p232_388.wav 371 | p232_389.wav 372 | p232_390.wav 373 | p232_391.wav 374 | p232_392.wav 375 | p232_393.wav 376 | p232_394.wav 377 | p232_396.wav 378 | p232_397.wav 379 | p232_398.wav 380 | p232_399.wav 381 | p232_400.wav 382 | p232_402.wav 383 | p232_403.wav 384 | p232_404.wav 385 | p232_405.wav 386 | p232_407.wav 387 | p232_409.wav 388 | p232_410.wav 389 | p232_411.wav 390 | p232_412.wav 391 | p232_413.wav 392 | p232_414.wav 393 | p232_415.wav 394 | p257_001.wav 395 | p257_002.wav 396 | p257_003.wav 397 | p257_004.wav 398 | p257_006.wav 399 | p257_007.wav 400 | p257_008.wav 401 | p257_009.wav 402 | p257_010.wav 403 | p257_011.wav 404 | p257_012.wav 405 | p257_013.wav 406 | p257_014.wav 407 | p257_015.wav 408 | p257_016.wav 409 | p257_017.wav 410 | p257_018.wav 411 | p257_019.wav 412 | p257_020.wav 413 | p257_022.wav 414 | p257_023.wav 415 | p257_024.wav 416 | p257_025.wav 417 | p257_026.wav 418 | p257_027.wav 419 | p257_028.wav 420 | p257_029.wav 421 | p257_030.wav 422 | p257_031.wav 423 | p257_032.wav 424 | p257_033.wav 425 | 
p257_034.wav 426 | p257_035.wav 427 | p257_036.wav 428 | p257_037.wav 429 | p257_038.wav 430 | p257_039.wav 431 | p257_040.wav 432 | p257_041.wav 433 | p257_042.wav 434 | p257_043.wav 435 | p257_044.wav 436 | p257_045.wav 437 | p257_046.wav 438 | p257_047.wav 439 | p257_048.wav 440 | p257_049.wav 441 | p257_050.wav 442 | p257_051.wav 443 | p257_052.wav 444 | p257_053.wav 445 | p257_054.wav 446 | p257_055.wav 447 | p257_056.wav 448 | p257_057.wav 449 | p257_058.wav 450 | p257_059.wav 451 | p257_060.wav 452 | p257_061.wav 453 | p257_062.wav 454 | p257_063.wav 455 | p257_064.wav 456 | p257_065.wav 457 | p257_066.wav 458 | p257_067.wav 459 | p257_068.wav 460 | p257_069.wav 461 | p257_070.wav 462 | p257_071.wav 463 | p257_072.wav 464 | p257_073.wav 465 | p257_074.wav 466 | p257_075.wav 467 | p257_076.wav 468 | p257_077.wav 469 | p257_078.wav 470 | p257_079.wav 471 | p257_080.wav 472 | p257_081.wav 473 | p257_082.wav 474 | p257_083.wav 475 | p257_084.wav 476 | p257_085.wav 477 | p257_086.wav 478 | p257_087.wav 479 | p257_088.wav 480 | p257_089.wav 481 | p257_090.wav 482 | p257_091.wav 483 | p257_092.wav 484 | p257_093.wav 485 | p257_094.wav 486 | p257_095.wav 487 | p257_096.wav 488 | p257_097.wav 489 | p257_098.wav 490 | p257_099.wav 491 | p257_100.wav 492 | p257_101.wav 493 | p257_102.wav 494 | p257_103.wav 495 | p257_104.wav 496 | p257_105.wav 497 | p257_106.wav 498 | p257_107.wav 499 | p257_108.wav 500 | p257_109.wav 501 | p257_110.wav 502 | p257_111.wav 503 | p257_112.wav 504 | p257_113.wav 505 | p257_114.wav 506 | p257_115.wav 507 | p257_116.wav 508 | p257_117.wav 509 | p257_118.wav 510 | p257_119.wav 511 | p257_120.wav 512 | p257_121.wav 513 | p257_122.wav 514 | p257_123.wav 515 | p257_124.wav 516 | p257_125.wav 517 | p257_126.wav 518 | p257_127.wav 519 | p257_128.wav 520 | p257_129.wav 521 | p257_130.wav 522 | p257_131.wav 523 | p257_132.wav 524 | p257_133.wav 525 | p257_135.wav 526 | p257_136.wav 527 | p257_137.wav 528 | p257_138.wav 529 | p257_139.wav 530 | p257_140.wav 531 | p257_141.wav 532 | p257_142.wav 533 | p257_143.wav 534 | p257_144.wav 535 | p257_145.wav 536 | p257_146.wav 537 | p257_147.wav 538 | p257_148.wav 539 | p257_149.wav 540 | p257_150.wav 541 | p257_151.wav 542 | p257_152.wav 543 | p257_153.wav 544 | p257_154.wav 545 | p257_155.wav 546 | p257_156.wav 547 | p257_157.wav 548 | p257_158.wav 549 | p257_159.wav 550 | p257_160.wav 551 | p257_161.wav 552 | p257_162.wav 553 | p257_163.wav 554 | p257_164.wav 555 | p257_165.wav 556 | p257_166.wav 557 | p257_167.wav 558 | p257_168.wav 559 | p257_169.wav 560 | p257_170.wav 561 | p257_171.wav 562 | p257_172.wav 563 | p257_173.wav 564 | p257_174.wav 565 | p257_175.wav 566 | p257_176.wav 567 | p257_177.wav 568 | p257_178.wav 569 | p257_179.wav 570 | p257_180.wav 571 | p257_181.wav 572 | p257_182.wav 573 | p257_183.wav 574 | p257_184.wav 575 | p257_185.wav 576 | p257_186.wav 577 | p257_187.wav 578 | p257_188.wav 579 | p257_189.wav 580 | p257_190.wav 581 | p257_191.wav 582 | p257_192.wav 583 | p257_193.wav 584 | p257_194.wav 585 | p257_195.wav 586 | p257_196.wav 587 | p257_197.wav 588 | p257_198.wav 589 | p257_199.wav 590 | p257_200.wav 591 | p257_201.wav 592 | p257_202.wav 593 | p257_203.wav 594 | p257_204.wav 595 | p257_205.wav 596 | p257_206.wav 597 | p257_207.wav 598 | p257_208.wav 599 | p257_209.wav 600 | p257_210.wav 601 | p257_211.wav 602 | p257_212.wav 603 | p257_213.wav 604 | p257_214.wav 605 | p257_215.wav 606 | p257_216.wav 607 | p257_217.wav 608 | p257_218.wav 609 | p257_219.wav 610 | p257_220.wav 611 | p257_221.wav 612 | 
p257_222.wav 613 | p257_223.wav 614 | p257_224.wav 615 | p257_225.wav 616 | p257_226.wav 617 | p257_227.wav 618 | p257_228.wav 619 | p257_229.wav 620 | p257_230.wav 621 | p257_231.wav 622 | p257_232.wav 623 | p257_233.wav 624 | p257_234.wav 625 | p257_235.wav 626 | p257_236.wav 627 | p257_237.wav 628 | p257_238.wav 629 | p257_239.wav 630 | p257_240.wav 631 | p257_241.wav 632 | p257_242.wav 633 | p257_243.wav 634 | p257_244.wav 635 | p257_245.wav 636 | p257_246.wav 637 | p257_247.wav 638 | p257_248.wav 639 | p257_249.wav 640 | p257_250.wav 641 | p257_251.wav 642 | p257_252.wav 643 | p257_253.wav 644 | p257_254.wav 645 | p257_255.wav 646 | p257_256.wav 647 | p257_257.wav 648 | p257_258.wav 649 | p257_259.wav 650 | p257_260.wav 651 | p257_261.wav 652 | p257_262.wav 653 | p257_263.wav 654 | p257_264.wav 655 | p257_265.wav 656 | p257_266.wav 657 | p257_267.wav 658 | p257_268.wav 659 | p257_269.wav 660 | p257_270.wav 661 | p257_271.wav 662 | p257_272.wav 663 | p257_273.wav 664 | p257_274.wav 665 | p257_275.wav 666 | p257_276.wav 667 | p257_277.wav 668 | p257_278.wav 669 | p257_279.wav 670 | p257_280.wav 671 | p257_281.wav 672 | p257_282.wav 673 | p257_283.wav 674 | p257_284.wav 675 | p257_285.wav 676 | p257_286.wav 677 | p257_287.wav 678 | p257_288.wav 679 | p257_289.wav 680 | p257_290.wav 681 | p257_291.wav 682 | p257_292.wav 683 | p257_293.wav 684 | p257_294.wav 685 | p257_295.wav 686 | p257_296.wav 687 | p257_297.wav 688 | p257_298.wav 689 | p257_299.wav 690 | p257_300.wav 691 | p257_301.wav 692 | p257_302.wav 693 | p257_303.wav 694 | p257_304.wav 695 | p257_305.wav 696 | p257_306.wav 697 | p257_307.wav 698 | p257_308.wav 699 | p257_309.wav 700 | p257_310.wav 701 | p257_311.wav 702 | p257_312.wav 703 | p257_313.wav 704 | p257_314.wav 705 | p257_315.wav 706 | p257_316.wav 707 | p257_317.wav 708 | p257_318.wav 709 | p257_319.wav 710 | p257_320.wav 711 | p257_321.wav 712 | p257_322.wav 713 | p257_323.wav 714 | p257_324.wav 715 | p257_325.wav 716 | p257_326.wav 717 | p257_327.wav 718 | p257_328.wav 719 | p257_329.wav 720 | p257_330.wav 721 | p257_331.wav 722 | p257_332.wav 723 | p257_333.wav 724 | p257_334.wav 725 | p257_335.wav 726 | p257_336.wav 727 | p257_337.wav 728 | p257_338.wav 729 | p257_339.wav 730 | p257_340.wav 731 | p257_341.wav 732 | p257_342.wav 733 | p257_343.wav 734 | p257_344.wav 735 | p257_345.wav 736 | p257_346.wav 737 | p257_347.wav 738 | p257_348.wav 739 | p257_349.wav 740 | p257_350.wav 741 | p257_351.wav 742 | p257_352.wav 743 | p257_353.wav 744 | p257_354.wav 745 | p257_355.wav 746 | p257_356.wav 747 | p257_357.wav 748 | p257_358.wav 749 | p257_359.wav 750 | p257_360.wav 751 | p257_361.wav 752 | p257_362.wav 753 | p257_363.wav 754 | p257_364.wav 755 | p257_365.wav 756 | p257_366.wav 757 | p257_367.wav 758 | p257_368.wav 759 | p257_369.wav 760 | p257_370.wav 761 | p257_371.wav 762 | p257_372.wav 763 | p257_373.wav 764 | p257_374.wav 765 | p257_375.wav 766 | p257_376.wav 767 | p257_377.wav 768 | p257_378.wav 769 | p257_379.wav 770 | p257_380.wav 771 | p257_381.wav 772 | p257_382.wav 773 | p257_383.wav 774 | p257_384.wav 775 | p257_385.wav 776 | p257_386.wav 777 | p257_387.wav 778 | p257_388.wav 779 | p257_389.wav 780 | p257_390.wav 781 | p257_391.wav 782 | p257_392.wav 783 | p257_393.wav 784 | p257_394.wav 785 | p257_395.wav 786 | p257_396.wav 787 | p257_397.wav 788 | p257_398.wav 789 | p257_399.wav 790 | p257_400.wav 791 | p257_401.wav 792 | p257_402.wav 793 | p257_403.wav 794 | p257_404.wav 795 | p257_405.wav 796 | p257_406.wav 797 | p257_407.wav 798 | p257_408.wav 799 | 
p257_409.wav 800 | p257_410.wav 801 | p257_411.wav 802 | p257_412.wav 803 | p257_413.wav 804 | p257_414.wav 805 | p257_415.wav 806 | p257_416.wav 807 | p257_417.wav 808 | p257_418.wav 809 | p257_419.wav 810 | p257_420.wav 811 | p257_421.wav 812 | p257_422.wav 813 | p257_423.wav 814 | p257_424.wav 815 | p257_425.wav 816 | p257_426.wav 817 | p257_427.wav 818 | p257_428.wav 819 | p257_429.wav 820 | p257_430.wav 821 | p257_431.wav 822 | p257_432.wav 823 | p257_433.wav 824 | p257_434.wav 825 | -------------------------------------------------------------------------------- /run_rsgan-gp_se.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reimplementing the SEGAN paper as closely as possible. 3 | Deepak Baby, UGent, June 2018. 4 | """ 5 | from __future__ import print_function 6 | 7 | import tensorflow as tf 8 | from tensorflow.contrib.layers import xavier_initializer, flatten, fully_connected 9 | import numpy as np 10 | from keras.layers import Subtract, Activation, Input 11 | from keras.models import Model, model_from_json 12 | from keras.optimizers import Adam 13 | from keras.layers.merge import _Merge 14 | from keras.callbacks import TensorBoard 15 | import keras.backend as K 16 | 17 | from data_ops import * 18 | from file_ops import * 19 | from models import * 20 | from wgan_ops import * 21 | from functools import partial 22 | import time 23 | from tqdm import * 24 | import h5py 25 | import os,sys 26 | import scipy.io.wavfile as wavfile 27 | 28 | BATCH_SIZE = 100 29 | GRADIENT_PENALTY_WEIGHT = 10 # need to tune 30 | 31 | class RandomWeightedAverage (_Merge): 32 | def _merge_function (self, inputs): 33 | weights = K.random_uniform((BATCH_SIZE, 1, 1)) 34 | return (weights * inputs[0]) + ((1 - weights) * inputs[1]) 35 | 36 | if __name__ == '__main__': 37 | 38 | # Various GAN options 39 | opts = {} 40 | opts ['dirhead'] = 'RSGAN_GP' + str(GRADIENT_PENALTY_WEIGHT) 41 | opts ['gp_weight'] = GRADIENT_PENALTY_WEIGHT 42 | ########################## 43 | opts ['z_off'] = True # set to True to omit the latent noise input 44 | # normalization 45 | ################################# 46 | # Only one of the following should be set to True, or all can be False 47 | opts ['applybn'] = False 48 | opts ['applyinstancenorm'] = True # Works even without any normalization 49 | ################################## 50 | # Show model summary 51 | opts ['show_summary'] = False 52 | 53 | ## Set the matfiles 54 | clean_train_matfile = "./data/clean_train_segan1d.mat" 55 | noisy_train_matfile = "./data/noisy_train_segan1d.mat" 56 | noisy_test_matfile = "./data/noisy_test_segan1d.mat" 57 | 58 | #################################################### 59 | # Other fixed options 60 | opts ['window_length'] = 2**14 61 | opts ['featdim'] = 1 # 1 since it is just 1d time samples 62 | opts ['filterlength'] = 31 63 | opts ['strides'] = 2 64 | opts ['padding'] = 'SAME'; opts ['preemph'] = 0.95 # pre-emphasis coefficient (assumed value: opts['preemph'] is used by data_preprocess/de_emph below but was never set in this script) 65 | opts ['g_enc_numkernels'] = [16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024] 66 | opts ['g_enc_lstm_cells'] = [1024] 67 | opts ['d_fmaps'] = opts ['g_enc_numkernels'] # We use the same structure for discriminator 68 | opts ['d_lstms'] = opts ['g_enc_lstm_cells'] 69 | opts['leakyrelualpha'] = 0.3 70 | opts ['batch_size'] = BATCH_SIZE 71 | opts ['applyprelu'] = True 72 | 73 | 74 | opts ['d_activation'] = 'leakyrelu' 75 | g_enc_numkernels = opts ['g_enc_numkernels'] 76 | opts ['g_dec_numkernels'] = g_enc_numkernels[:-1][::-1] + [1] 77 | opts ['gt_stride'] = 2 78 | opts ['g_l1loss'] = 200.
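# ----------------------------------------------------------------------
# Added sanity-check note (not part of the original script): with the
# options above, the derived settings work out as follows.
#   opts['g_dec_numkernels'] == [512, 256, 256, 128, 128, 64, 64, 32, 32, 16, 1]
#     (the encoder widths minus the last entry, reversed, plus a final
#      single-channel output layer)
#   The latent input z defined below (used only when z_off is False) has
#     shape (window_length / strides**num_enc_layers, last_enc_width)
#     = (2**14 / 2**11, 1024) = (8, 1024).
#   get_modeldirname(opts) should then yield a directory name like
#     RSGAN_GP10_noZ_IN_Adam_D0.0002_G0.0002_L1_200.0
# ----------------------------------------------------------------------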
79 | opts ['d_lr'] = 2e-4 80 | opts ['g_lr'] = 2e-4 81 | opts ['random_seed'] = 111 82 | 83 | n_epochs = 81 84 | fs = 16000 85 | 86 | # set flags for training or testing 87 | TRAIN_SEGAN = True 88 | SAVE_MODEL = True 89 | LOAD_SAVED_MODEL = False 90 | TEST_SEGAN = True 91 | 92 | modeldir = get_modeldirname(opts) 93 | print ("The model directory is " + modeldir) 94 | print ("_____________________________________") 95 | 96 | if not os.path.exists(modeldir): 97 | os.makedirs(modeldir) 98 | 99 | # Obtain the generator and the discriminator 100 | D = discriminator(opts) 101 | G = generator(opts) 102 | 103 | # Define optimizers 104 | g_opt = Adam(lr=opts['g_lr']) 105 | d_opt = Adam(lr=opts['d_lr']) 106 | 107 | # The G model has the wav and the noise inputs 108 | wav_shape = (opts['window_length'], opts['featdim']) 109 | z_dim1 = int(opts['window_length']/ (opts ['strides'] ** len(opts ['g_enc_numkernels']))) 110 | z_dim2 = opts ['g_enc_numkernels'][-1] 111 | wav_in_clean = Input(shape=wav_shape, name="main_input_clean") 112 | wav_in_noisy = Input(shape=wav_shape, name="main_input_noisy") 113 | if not opts ['z_off']: 114 | z = Input (shape=(z_dim1, z_dim2), name="noise_input") 115 | G_wav = G([wav_in_noisy, z]) 116 | G_model = Model([wav_in_noisy, z], G_wav) 117 | else : 118 | G_wav = G(wav_in_noisy) 119 | G_model = Model(wav_in_noisy, G_wav) 120 | 121 | d_out = D([wav_in_clean, wav_in_noisy]) 122 | D = Model([wav_in_clean, wav_in_noisy], d_out) 123 | G_model.summary() 124 | D.summary() 125 | 126 | # ADDING RELATIVISTIC LOSS AT OUTPUT 127 | for layer in D.layers : 128 | layer.trainable = False 129 | D.trainable = False 130 | if not opts ['z_off']: 131 | G_wav = G([wav_in_noisy, z]) 132 | else : 133 | G_wav = G(wav_in_noisy) 134 | D_out_for_G = D([G_wav, wav_in_noisy]) 135 | D_out_for_real = D([wav_in_clean, wav_in_noisy]) 136 | 137 | d_outG = Subtract()([D_out_for_G, D_out_for_real]) 138 | d_outG = Activation('sigmoid', name="DoutG")(d_outG) 139 | 140 | if not opts ['z_off']: 141 | G_D = Model(inputs=[wav_in_clean, wav_in_noisy, z], outputs = [d_outG, G_wav]) 142 | else : 143 | G_D = Model(inputs=[wav_in_clean, wav_in_noisy], outputs = [d_outG, G_wav]) 144 | 145 | G_D.summary() 146 | G_D.compile(optimizer=g_opt, 147 | loss={'model_2': 'mean_absolute_error', 'DoutG': 'binary_crossentropy'}, 148 | loss_weights = {'model_2' : opts['g_l1loss'], 'DoutG': 1} ) 149 | print (G_D.metrics_names) 150 | 151 | # Now we need D model so that gradient penalty can be incorporated 152 | for layer in D.layers : 153 | layer.trainable = True 154 | for layer in G.layers : 155 | layer.trainable = False 156 | D.trainable = True 157 | G.trainable = False 158 | if not opts ['z_off']: 159 | G_wav_for_D = G([wav_in_noisy, z]) 160 | else : 161 | G_wav_for_D = G(wav_in_noisy) 162 | 163 | d_out_for_G = D([G_wav_for_D, wav_in_noisy]) 164 | d_out_for_real = D([wav_in_clean, wav_in_noisy]) 165 | # for gradient penalty 166 | averaged_samples = RandomWeightedAverage()([wav_in_clean, G_wav_for_D]) 167 | # We also need to pass this through D, for computing the gradients 168 | d_out_for_averaged = D([averaged_samples, wav_in_noisy]) 169 | # compute the GP loss by means of partial function in keras 170 | partial_gp_loss = partial(gradient_penalty_loss, 171 | averaged_samples = averaged_samples, 172 | gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT) 173 | partial_gp_loss.__name__ = 'gradient_penalty' 174 | 175 | d_outD = Subtract()([d_out_for_real, d_out_for_G]) 176 | d_outD = Activation('sigmoid',
name="DoutD")(d_outD) 177 | 178 | if not opts ['z_off']: 179 | D_final = Model(inputs = [wav_in_clean, wav_in_noisy, z], 180 | outputs = [d_outD, d_out_for_averaged]) 181 | else : 182 | D_final = Model(inputs = [wav_in_clean, wav_in_noisy], 183 | outputs = [d_outD, d_out_for_averaged]) 184 | D_final.compile(optimizer = d_opt, 185 | loss = ['binary_crossentropy', partial_gp_loss ]) 186 | 187 | D_final.summary() 188 | print (D_final.metrics_names) 189 | 190 | # create label vectors for training 191 | positive_y = np.ones((BATCH_SIZE, 1), dtype=np.float32) 192 | negative_y = -1 * positive_y 193 | dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32) # for GP Loss 194 | 195 | if TEST_SEGAN: 196 | ftestnoisy = h5py.File(noisy_test_matfile) 197 | noisy_test_data = ftestnoisy['feat_data'] 198 | noisy_test_dfi = ftestnoisy['dfi'] 199 | print ("Number of test files: " + str(noisy_test_dfi.shape[1]) ) 200 | 201 | 202 | # Begin the training part 203 | if TRAIN_SEGAN: 204 | fclean = h5py.File(clean_train_matfile) 205 | clean_train_data = np.array(fclean['feat_data']) 206 | fnoisy = h5py.File(noisy_train_matfile) 207 | noisy_train_data = np.array(fnoisy['feat_data']) 208 | print ("********************************************") 209 | print (" SEGAN TRAINING ") 210 | print ("********************************************") 211 | print ("Shape of clean feats mat " + str(clean_train_data.shape)) 212 | print ("Shape of noisy feats mat " + str(noisy_train_data.shape)) 213 | numtrainsamples = clean_train_data.shape[1] 214 | 215 | # Tensorboard stuff 216 | log_path = './logs/' + modeldir 217 | callback = TensorBoard(log_path) 218 | callback.set_model(G_D) 219 | train_names = ['G_loss', 'G_adv_loss', 'G_l1Loss'] 220 | 221 | idx_all = np.arange(numtrainsamples) 222 | # set random seed 223 | np.random.seed(opts['random_seed']) 224 | 225 | batch_size = opts['batch_size'] 226 | num_batches_per_epoch = int(np.floor(clean_train_data.shape[1]/batch_size)) 227 | for epoch in range(n_epochs): 228 | # train D with minibatch 229 | np.random.shuffle(idx_all) # shuffle the indices for the next epoch 230 | for batch_idx in range(num_batches_per_epoch): 231 | start_time = time.time() 232 | idx_beg = batch_idx * batch_size 233 | idx_end = idx_beg + batch_size 234 | idx = np.sort(np.array(idx_all[idx_beg:idx_end])) 235 | #print ("Batch idx " + str(idx[:5]) +" ... 
" + str(idx[-5:])) 236 | cleanwavs = np.array(clean_train_data[:,idx]).T 237 | cleanwavs = data_preprocess(cleanwavs, preemph=opts['preemph']) 238 | cleanwavs = np.expand_dims(cleanwavs, axis = 2) 239 | noisywavs = np.array(noisy_train_data[:,idx]).T 240 | noisywavs = data_preprocess(noisywavs, preemph=opts['preemph']) 241 | noisywavs = np.expand_dims(noisywavs, axis = 2) 242 | if not opts ['z_off']: 243 | noiseinput = np.random.normal(0, 1, 244 | (batch_size, z_dim1, z_dim2)) 245 | [_, d_loss, d_gploss] = D_final.train_on_batch({'main_input_clean': cleanwavs, 246 | 'main_input_noisy': noisywavs, 'noise_input': noiseinput}, 247 | {'DoutD': positive_y, 'model_4': dummy_y} ) 248 | [g_loss, g_dLoss, g_l1loss] = G_D.train_on_batch({'main_input_clean': cleanwavs, 249 | 'main_input_noisy': noisywavs, 'noise_input': noiseinput}, 250 | {'model_2': cleanwavs, 'DoutG': positive_y} ) 251 | else: 252 | [_, d_loss, d_gploss] = D_final.train_on_batch({'main_input_clean': cleanwavs, 253 | 'main_input_noisy': noisywavs,}, 254 | {'DoutD': positive_y, 'model_4': dummy_y} ) 255 | [g_loss, g_dLoss, g_l1loss] = G_D.train_on_batch({'main_input_clean': cleanwavs, 256 | 'main_input_noisy': noisywavs}, 257 | {'model_2': cleanwavs, 258 | 'DoutG': positive_y} ) 259 | time_taken = time.time() - start_time 260 | 261 | printlog = "E%d/%d:B%d/%d [D loss: %f] [D_GP loss: %f] [G loss: %f] [G_D loss: %f] [G_L1 loss: %f] [Exec. time: %f]" % (epoch, n_epochs, batch_idx, num_batches_per_epoch, d_loss, d_gploss, g_loss, g_dLoss, g_l1loss, time_taken) 262 | 263 | print (printlog) 264 | # Tensorboard stuff 265 | logs = [g_loss, g_dLoss, g_l1loss] 266 | write_log(callback, train_names, logs, epoch) 267 | 268 | if (TEST_SEGAN and epoch % 10 == 0) or epoch == n_epochs - 1: 269 | print ("********************************************") 270 | print (" SEGAN TESTING ") 271 | print ("********************************************") 272 | 273 | resultsdir = modeldir + "/test_results_epoch" + str(epoch) 274 | if not os.path.exists(resultsdir): 275 | os.makedirs(resultsdir) 276 | 277 | if LOAD_SAVED_MODEL: 278 | print ("Loading model from " + modeldir + "/Gmodel") 279 | json_file = open(modeldir + "/Gmodel.json", "r") 280 | loaded_model_json = json_file.read() 281 | json_file.close() 282 | G_loaded = model_from_json(loaded_model_json) 283 | G_loaded.compile(loss='mean_squared_error', optimizer=g_opt) 284 | G_loaded.load_weights(modeldir + "/Gmodel.h5") 285 | else: 286 | G_loaded = G 287 | 288 | print ("Saving Results to " + resultsdir) 289 | 290 | for test_num in tqdm(range(noisy_test_dfi.shape[1])) : 291 | test_beg = noisy_test_dfi[0, test_num] 292 | test_end = noisy_test_dfi[1, test_num] 293 | #print ("Reading indices " + str(test_beg) + " to " + str(test_end)) 294 | noisywavs = np.array(noisy_test_data[:,test_beg:test_end]).T 295 | noisywavs = data_preprocess(noisywavs, preemph=opts['preemph']) 296 | noisywavs = np.expand_dims(noisywavs, axis = 2) 297 | if not opts['z_off']: 298 | noiseinput = np.random.normal(0, 1, (noisywavs.shape[0], z_dim1, z_dim2)) 299 | cleaned_wavs = G_loaded.predict([noisywavs, noiseinput]) 300 | else : 301 | cleaned_wavs = G_loaded.predict(noisywavs) 302 | 303 | cleaned_wavs = np.reshape(cleaned_wavs, (noisywavs.shape[0], noisywavs.shape[1])) 304 | cleanwav = reconstruct_wav(cleaned_wavs) 305 | cleanwav = np.reshape(cleanwav, (-1,)) # make it to 1d by dropping the extra dimension 306 | 307 | if opts['preemph'] > 0: 308 | cleanwav = de_emph(cleanwav, coeff=opts['preemph']) 309 | 310 | destfilename = resultsdir + 
"/testwav_%d.wav" % (test_num) 311 | wavfile.write(destfilename, fs, cleanwav) 312 | 313 | 314 | 315 | # Finally, save the model 316 | if SAVE_MODEL: 317 | model_json = G.to_json() 318 | with open(modeldir + "/Gmodel.json", "w") as json_file: 319 | json_file.write(model_json) 320 | G.save_weights(modeldir + "/Gmodel.h5") 321 | print ("Model saved to " + modeldir) 322 | -------------------------------------------------------------------------------- /run_wgan-gp_se.py: -------------------------------------------------------------------------------- 1 | """ 2 | Speech Enhancement with Wasserstein GAN 3 | Deepak Baby, UGent, June 2018. 4 | Currently at IDIAP, Martigny, Switzerland 5 | """ 6 | from __future__ import print_function 7 | 8 | import tensorflow as tf 9 | from tensorflow.contrib.layers import xavier_initializer, flatten, fully_connected 10 | import numpy as np 11 | from keras.layers import Activation, Input 12 | from keras.models import Sequential, Model 13 | from keras.optimizers import Adam 14 | from keras.layers.merge import _Merge 15 | from keras.callbacks import TensorBoard 16 | import keras.backend as K 17 | 18 | from wgan_ops import * 19 | from data_ops import * 20 | from file_ops import * 21 | from models import * 22 | from functools import partial 23 | import time 24 | from tqdm import * 25 | import h5py 26 | import os,sys 27 | import scipy.io.wavfile as wavfile 28 | 29 | BATCH_SIZE = 100 30 | GRADIENT_PENALTY_WEIGHT = 10 # need to tune 31 | 32 | def wasserstein_loss(y_true, y_pred): 33 | return K.mean(y_true * y_pred) 34 | 35 | class RandomWeightedAverage (_Merge): 36 | def _merge_function (self, inputs): 37 | weights = K.random_uniform((BATCH_SIZE, 1, 1)) 38 | return (weights * inputs[0]) + ((1 - weights) * inputs[1]) 39 | 40 | if __name__ == '__main__': 41 | 42 | # Various GAN options 43 | opts = {} 44 | opts ['dirhead'] = "WGAN_GP" + str(GRADIENT_PENALTY_WEIGHT) 45 | opts ['gp_weight'] = GRADIENT_PENALTY_WEIGHT 46 | ########################## 47 | opts ['z_off'] = True # set to True to omit the latent noise input 48 | # normalization 49 | ################################# 50 | # Only one of the follwoing should be set to True 51 | opts ['applybn'] = False 52 | opts ['applyinstancenorm'] = True 53 | opts ['applygroupnorm'] = False 54 | ################################## 55 | # Show model summary 56 | opts ['show_summary'] = False 57 | 58 | ## Set the matfiles 59 | clean_train_matfile = "./data/clean_train_segan1d.mat" 60 | noisy_train_matfile = "./data/noisy_train_segan1d.mat" 61 | noisy_test_matfile = "./data/noisy_test_segan1d.mat" 62 | 63 | #################################################### 64 | # Other fixed options 65 | opts ['window_length'] = 2**14 66 | opts ['featdim'] = 1 # 1 since it is just 1d time samples 67 | opts ['filterlength'] = 31 68 | opts ['strides'] = 2 69 | opts ['padding'] = 'SAME' 70 | opts ['g_enc_numkernels'] = [16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024] 71 | opts ['d_fmaps'] = opts ['g_enc_numkernels'] # We use the same structure for discriminator 72 | opts['leakyrelualpha'] = 0.3 73 | opts ['batch_size'] = BATCH_SIZE 74 | opts ['applyprelu'] = True 75 | 76 | 77 | opts ['d_activation'] = 'leakyrelu' 78 | g_enc_numkernels = opts ['g_enc_numkernels'] 79 | opts ['g_dec_numkernels'] = g_enc_numkernels[:-1][::-1] + [1] 80 | opts ['gt_stride'] = 2 81 | opts ['g_l1loss'] = 200. 
82 | opts ['d_lr'] = 2e-4 83 | opts ['g_lr'] = 2e-4 84 | opts ['random_seed'] = 111 85 | 86 | n_epochs = 81 87 | fs = 16000 88 | 89 | # set flags for training or testing 90 | TRAIN_SEGAN = True 91 | SAVE_MODEL = True 92 | LOAD_SAVED_MODEL = False 93 | TEST_SEGAN = True 94 | 95 | modeldir = get_modeldirname(opts) 96 | print ("The model directory is " + modeldir) 97 | print ("_____________________________________") 98 | 99 | if not os.path.exists(modeldir): 100 | os.makedirs(modeldir) 101 | 102 | # Obtain the generator and the discriminator 103 | D = discriminator(opts) 104 | G = generator(opts) 105 | 106 | # Define optimizers 107 | g_opt = Adam(lr=opts['g_lr']) 108 | d_opt = Adam(lr=opts['d_lr']) 109 | 110 | # The G model has the wav and the noise inputs 111 | wav_shape = (opts['window_length'], opts['featdim']) 112 | z_dim1 = int(opts['window_length']/ (opts ['strides'] ** len(opts ['g_enc_numkernels']))) 113 | z_dim2 = opts ['g_enc_numkernels'][-1] 114 | wav_in_clean = Input(shape=wav_shape, name="main_input_clean") 115 | wav_in_noisy = Input(shape=wav_shape, name="main_input_noisy") 116 | if not opts ['z_off']: 117 | z = Input (shape=(z_dim1, z_dim2), name="noise_input") 118 | G_wav = G([wav_in_noisy, z]) 119 | G_model = Model([wav_in_noisy, z], G_wav) 120 | else : 121 | G_wav = G(wav_in_noisy) 122 | G_model = Model(wav_in_noisy, G_wav) 123 | 124 | d_out = D([wav_in_clean, wav_in_noisy]) 125 | d_out = Activation('sigmoid', name='d_out')(d_out) 126 | D_model = Model([wav_in_clean, wav_in_noisy], d_out) 127 | G_model.summary() 128 | D_model.summary() 129 | 130 | # Incorporating Gradient Penalty 131 | for layer in D.layers : 132 | layer.trainable = False 133 | D.trainable = False 134 | if not opts ['z_off']: 135 | G_wav = G([wav_in_noisy, z]) 136 | else : 137 | G_wav = G(wav_in_noisy) 138 | D_out_for_G = D([G_wav, wav_in_noisy]) 139 | D_out_for_G = Activation('linear', name='DoutG')(D_out_for_G) 140 | if not opts ['z_off']: 141 | G_D = Model(inputs=[wav_in_clean, wav_in_noisy, z], outputs = [D_out_for_G, G_wav]) 142 | else : 143 | G_D = Model(inputs=[wav_in_clean, wav_in_noisy], outputs = [D_out_for_G, G_wav]) 144 | 145 | G_D.summary() 146 | G_D.compile(optimizer=g_opt, 147 | loss={'model_2': 'mean_absolute_error', 'DoutG': wasserstein_loss}, 148 | loss_weights = {'model_2' : opts['g_l1loss'], 'DoutG': 1} ) 149 | print (G_D.metrics_names) 150 | 151 | # Now we need D model so that gradient penalty can be incorporated 152 | for layer in D.layers : 153 | layer.trainable = True 154 | for layer in G.layers : 155 | layer.trainable = False 156 | D.trainable = True 157 | G.trainable = False 158 | if not opts ['z_off']: 159 | G_wav_for_D = G([wav_in_noisy, z]) 160 | else : 161 | G_wav_for_D = G(wav_in_noisy) 162 | 163 | d_out_for_G = D([G_wav_for_D, wav_in_noisy]) 164 | d_out_for_real = D([wav_in_clean, wav_in_noisy]) 165 | d_out_for_G = Activation('linear', name='Dout_fake')(d_out_for_G) 166 | d_out_for_real = Activation('linear', name='Dout_real')(d_out_for_real) 167 | 168 | # for gradient penalty 169 | averaged_samples = RandomWeightedAverage()([wav_in_clean, G_wav_for_D]) 170 | # We also need to pass this through D, for computing the gradients 171 | d_out_for_averaged = D([averaged_samples, wav_in_noisy]) 172 | d_out_for_averaged = Activation('linear', name='Dout_avg')(d_out_for_averaged) 173 | # compute the GP loss by means of partial function in keras 174 | partial_gp_loss = partial(gradient_penalty_loss, 175 | averaged_samples = averaged_samples, 176 |
gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT) 177 | partial_gp_loss.__name__ = 'gradient_penalty' 178 | if not opts ['z_off']: 179 | D_final = Model(inputs = [wav_in_clean, wav_in_noisy, z], 180 | outputs = [d_out_for_real, d_out_for_G, 181 | d_out_for_averaged]) 182 | else : 183 | D_final = Model(inputs = [wav_in_clean, wav_in_noisy], 184 | outputs = [d_out_for_real, d_out_for_G, 185 | d_out_for_averaged]) 186 | D_final.compile(optimizer = d_opt, 187 | loss = {'Dout_real' : wasserstein_loss, 'Dout_fake' : wasserstein_loss, 188 | 'Dout_avg' : partial_gp_loss}) 189 | D_final.summary() 190 | print (D_final.metrics_names) 191 | 192 | # create label vectors for training 193 | positive_y = np.ones((BATCH_SIZE, 1), dtype=np.float32) 194 | negative_y = -1 * positive_y 195 | dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32) # for GP Loss 196 | zeros_y = dummy_y 197 | 198 | if TEST_SEGAN: 199 | ftestnoisy = h5py.File(noisy_test_matfile) 200 | noisy_test_data = ftestnoisy['feat_data'] 201 | noisy_test_dfi = ftestnoisy['dfi'] 202 | print ("Number of test files: " + str(noisy_test_dfi.shape[1]) ) 203 | 204 | 205 | # Begin the training part 206 | if TRAIN_SEGAN: 207 | fclean = h5py.File(clean_train_matfile) 208 | clean_train_data = np.array(fclean['feat_data']) 209 | fnoisy = h5py.File(noisy_train_matfile) 210 | noisy_train_data = np.array(fnoisy['feat_data']) 211 | print ("********************************************") 212 | print (" SEGAN TRAINING ") 213 | print ("********************************************") 214 | print ("Shape of clean feats mat " + str(clean_train_data.shape)) 215 | print ("Shape of noisy feats mat " + str(noisy_train_data.shape)) 216 | numtrainsamples = clean_train_data.shape[1] 217 | 218 | # Tensorboard stuff 219 | log_path = './logs/' + modeldir 220 | callback = TensorBoard(log_path) 221 | callback.set_model(G_D) 222 | train_names = ['G_loss', 'G_adv_loss', 'G_l1Loss'] 223 | 224 | idx_all = np.arange(numtrainsamples) 225 | # set random seed 226 | np.random.seed(opts['random_seed']) 227 | 228 | batch_size = opts['batch_size'] 229 | num_batches_per_epoch = int(np.floor(clean_train_data.shape[1]/batch_size)) 230 | for epoch in range(n_epochs): 231 | # train D with minibatch 232 | np.random.shuffle(idx_all) # shuffle the indices for the next epoch 233 | for batch_idx in range(num_batches_per_epoch): 234 | start_time = time.time() 235 | idx_beg = batch_idx * batch_size 236 | idx_end = idx_beg + batch_size 237 | idx = np.sort(np.array(idx_all[idx_beg:idx_end])) 238 | #print ("Batch idx " + str(idx[:5]) +" ... 
" + str(idx[-5:])) 239 | cleanwavs = np.array(clean_train_data[:,idx]).T 240 | cleanwavs = data_preprocess(cleanwavs, preemph=opts['preemph']) 241 | cleanwavs = np.expand_dims(cleanwavs, axis = 2) 242 | noisywavs = np.array(noisy_train_data[:,idx]).T 243 | noisywavs = data_preprocess(noisywavs, preemph=opts['preemph']) 244 | noisywavs = np.expand_dims(noisywavs, axis = 2) 245 | if not opts ['z_off']: 246 | noiseinput = np.random.normal(0, 1, 247 | (batch_size, z_dim1, z_dim2)) 248 | [d_loss, d_loss_real, d_loss_fake, _] = D_final.train_on_batch({'main_input_clean': cleanwavs, 249 | 'main_input_noisy': noisywavs, 'noise_input': noiseinput}, 250 | {'Dout_real' : positive_y, 'Dout_fake': negative_y, 251 | 'Dout_avg' : dummy_y} ) 252 | [g_loss, g_dLoss, g_l1loss] = G_D.train_on_batch({'main_input_clean': cleanwavs, 253 | 'main_input_noisy': noisywavs, 'noise_input': noiseinput}, 254 | {'model_2': cleanwavs, 'DoutG': positive_y} ) 255 | else: 256 | [d_loss, d_loss_real, d_loss_fake, _] = D_final.train_on_batch({'main_input_clean': cleanwavs, 257 | 'main_input_noisy': noisywavs,}, 258 | {'Dout_real' : positive_y, 'Dout_fake': negative_y, 259 | 'Dout_avg' : dummy_y}) 260 | [g_loss, g_dLoss, g_l1loss] = G_D.train_on_batch({'main_input_clean': cleanwavs, 261 | 'main_input_noisy': noisywavs}, 262 | {'model_2': cleanwavs, 263 | 'DoutG': positive_y} ) 264 | time_taken = time.time() - start_time 265 | 266 | printlog = "E%d/%d:B%d/%d [D loss: %f] [D real loss: %f] [D fake loss: %f] [G loss: %f] [G_D loss: %f] [G_L1 loss: %f] [Exec. time: %f]" % (epoch, n_epochs, batch_idx, num_batches_per_epoch, d_loss, d_loss_real, d_loss_fake, g_loss, g_dLoss, g_l1loss, time_taken) 267 | 268 | print (printlog) 269 | # Tensorboard stuff 270 | logs = [g_loss, g_dLoss, g_l1loss] 271 | write_log(callback, train_names, logs, epoch) 272 | 273 | if (TEST_SEGAN and epoch % 10 == 0) or epoch == n_epochs - 1: 274 | print ("********************************************") 275 | print (" SEGAN TESTING ") 276 | print ("********************************************") 277 | 278 | resultsdir = modeldir + "/test_results_epoch" + str(epoch) 279 | if not os.path.exists(resultsdir): 280 | os.makedirs(resultsdir) 281 | 282 | if LOAD_SAVED_MODEL: 283 | print ("Loading model from " + modeldir + "/Gmodel") 284 | json_file = open(modeldir + "/Gmodel.json", "r") 285 | loaded_model_json = json_file.read() 286 | json_file.close() 287 | G_loaded = model_from_json(loaded_model_json) 288 | G_loaded.compile(loss='mean_squared_error', optimizer=g_opt) 289 | G_loaded.load_weights(modeldir + "/Gmodel.h5") 290 | else: 291 | G_loaded = G 292 | 293 | print ("Saving Results to " + resultsdir) 294 | 295 | for test_num in tqdm(range(noisy_test_dfi.shape[1])) : 296 | test_beg = noisy_test_dfi[0, test_num] 297 | test_end = noisy_test_dfi[1, test_num] 298 | #print ("Reading indices " + str(test_beg) + " to " + str(test_end)) 299 | noisywavs = np.array(noisy_test_data[:,test_beg:test_end]).T 300 | noisywavs = data_preprocess(noisywavs, preemph=opts['preemph']) 301 | noisywavs = np.expand_dims(noisywavs, axis = 2) 302 | if not opts['z_off']: 303 | noiseinput = np.random.normal(0, 1, (noisywavs.shape[0], z_dim1, z_dim2)) 304 | cleaned_wavs = G_loaded.predict([noisywavs, noiseinput]) 305 | else : 306 | cleaned_wavs = G_loaded.predict(noisywavs) 307 | 308 | cleaned_wavs = np.reshape(cleaned_wavs, (noisywavs.shape[0], noisywavs.shape[1])) 309 | cleanwav = reconstruct_wav(cleaned_wavs) 310 | cleanwav = np.reshape(cleanwav, (-1,)) # make it to 1d by dropping the extra 
dimension 311 | 312 | if opts['preemph'] > 0: 313 | cleanwav = de_emph(cleanwav, coeff=opts['preemph']) 314 | 315 | destfilename = resultsdir + "/testwav_%d.wav" % (test_num) 316 | wavfile.write(destfilename, fs, cleanwav) 317 | 318 | 319 | 320 | # Finally, save the model 321 | if SAVE_MODEL: 322 | model_json = G.to_json() 323 | with open(modeldir + "/Gmodel.json", "w") as json_file: 324 | json_file.write(model_json) 325 | G.save_weights(modeldir + "/Gmodel.h5") 326 | print ("Model saved to " + modeldir) 327 | -------------------------------------------------------------------------------- /normalizations.py: -------------------------------------------------------------------------------- 1 | from keras.engine import Layer, InputSpec 2 | from keras import initializers, regularizers, constraints 3 | import keras_contrib_backend as K_contrib 4 | from keras import backend as K 5 | from keras.utils.generic_utils import get_custom_objects 6 | 7 | import numpy as np 8 | 9 | 10 | class InstanceNormalization(Layer): 11 | """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016). 12 | Normalize the activations of the previous layer at each step, 13 | i.e. applies a transformation that maintains the mean activation 14 | close to 0 and the activation standard deviation close to 1. 15 | # Arguments 16 | axis: Integer, the axis that should be normalized 17 | (typically the features axis). 18 | For instance, after a `Conv2D` layer with 19 | `data_format="channels_first"`, 20 | set `axis=1` in `InstanceNormalization`. 21 | Setting `axis=None` will normalize all values in each instance of the batch. 22 | Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors. 23 | epsilon: Small float added to variance to avoid dividing by zero. 24 | center: If True, add offset of `beta` to normalized tensor. 25 | If False, `beta` is ignored. 26 | scale: If True, multiply by `gamma`. 27 | If False, `gamma` is not used. 28 | When the next layer is linear (also e.g. `nn.relu`), 29 | this can be disabled since the scaling 30 | will be done by the next layer. 31 | beta_initializer: Initializer for the beta weight. 32 | gamma_initializer: Initializer for the gamma weight. 33 | beta_regularizer: Optional regularizer for the beta weight. 34 | gamma_regularizer: Optional regularizer for the gamma weight. 35 | beta_constraint: Optional constraint for the beta weight. 36 | gamma_constraint: Optional constraint for the gamma weight. 37 | # Input shape 38 | Arbitrary. Use the keyword argument `input_shape` 39 | (tuple of integers, does not include the samples axis) 40 | when using this layer as the first layer in a model. 41 | # Output shape 42 | Same shape as input. 
43 | # References 44 | - [Layer Normalization](https://arxiv.org/abs/1607.06450) 45 | - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022) 46 | """ 47 | def __init__(self, 48 | axis=None, 49 | epsilon=1e-3, 50 | center=True, 51 | scale=True, 52 | beta_initializer='zeros', 53 | gamma_initializer='ones', 54 | beta_regularizer=None, 55 | gamma_regularizer=None, 56 | beta_constraint=None, 57 | gamma_constraint=None, 58 | **kwargs): 59 | super(InstanceNormalization, self).__init__(**kwargs) 60 | self.supports_masking = True 61 | self.axis = axis 62 | self.epsilon = epsilon 63 | self.center = center 64 | self.scale = scale 65 | self.beta_initializer = initializers.get(beta_initializer) 66 | self.gamma_initializer = initializers.get(gamma_initializer) 67 | self.beta_regularizer = regularizers.get(beta_regularizer) 68 | self.gamma_regularizer = regularizers.get(gamma_regularizer) 69 | self.beta_constraint = constraints.get(beta_constraint) 70 | self.gamma_constraint = constraints.get(gamma_constraint) 71 | 72 | def build(self, input_shape): 73 | ndim = len(input_shape) 74 | if self.axis == 0: 75 | raise ValueError('Axis cannot be zero') 76 | 77 | if (self.axis is not None) and (ndim == 2): 78 | raise ValueError('Cannot specify axis for rank 1 tensor') 79 | 80 | self.input_spec = InputSpec(ndim=ndim) 81 | 82 | if self.axis is None: 83 | shape = (1,) 84 | else: 85 | shape = (input_shape[self.axis],) 86 | 87 | if self.scale: 88 | self.gamma = self.add_weight(shape=shape, 89 | name='gamma', 90 | initializer=self.gamma_initializer, 91 | regularizer=self.gamma_regularizer, 92 | constraint=self.gamma_constraint) 93 | else: 94 | self.gamma = None 95 | if self.center: 96 | self.beta = self.add_weight(shape=shape, 97 | name='beta', 98 | initializer=self.beta_initializer, 99 | regularizer=self.beta_regularizer, 100 | constraint=self.beta_constraint) 101 | else: 102 | self.beta = None 103 | self.built = True 104 | 105 | def call(self, inputs, training=None): 106 | input_shape = K.int_shape(inputs) 107 | reduction_axes = list(range(0, len(input_shape))) 108 | 109 | if (self.axis is not None): 110 | del reduction_axes[self.axis] 111 | 112 | del reduction_axes[0] 113 | 114 | mean = K.mean(inputs, reduction_axes, keepdims=True) 115 | stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon 116 | normed = (inputs - mean) / stddev 117 | 118 | broadcast_shape = [1] * len(input_shape) 119 | if self.axis is not None: 120 | broadcast_shape[self.axis] = input_shape[self.axis] 121 | 122 | if self.scale: 123 | broadcast_gamma = K.reshape(self.gamma, broadcast_shape) 124 | normed = normed * broadcast_gamma 125 | if self.center: 126 | broadcast_beta = K.reshape(self.beta, broadcast_shape) 127 | normed = normed + broadcast_beta 128 | return normed 129 | 130 | def get_config(self): 131 | config = { 132 | 'axis': self.axis, 133 | 'epsilon': self.epsilon, 134 | 'center': self.center, 135 | 'scale': self.scale, 136 | 'beta_initializer': initializers.serialize(self.beta_initializer), 137 | 'gamma_initializer': initializers.serialize(self.gamma_initializer), 138 | 'beta_regularizer': regularizers.serialize(self.beta_regularizer), 139 | 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), 140 | 'beta_constraint': constraints.serialize(self.beta_constraint), 141 | 'gamma_constraint': constraints.serialize(self.gamma_constraint) 142 | } 143 | base_config = super(InstanceNormalization, self).get_config() 144 | return dict(list(base_config.items()) + 
list(config.items())) 145 | 146 | 147 | get_custom_objects().update({'InstanceNormalization': InstanceNormalization}) 148 | 149 | 150 | class BatchRenormalization(Layer): 151 | """Batch renormalization layer (Sergey Ioffe, 2017). 152 | Normalize the activations of the previous layer at each batch, 153 | i.e. applies a transformation that maintains the mean activation 154 | close to 0 and the activation standard deviation close to 1. 155 | # Arguments 156 | axis: Integer, the axis that should be normalized 157 | (typically the features axis). 158 | For instance, after a `Conv2D` layer with 159 | `data_format="channels_first"`, 160 | set `axis=1` in `BatchRenormalization`. 161 | momentum: momentum in the computation of the 162 | exponential average of the mean and standard deviation 163 | of the data, for feature-wise normalization. 164 | center: If True, add offset of `beta` to normalized tensor. 165 | If False, `beta` is ignored. 166 | scale: If True, multiply by `gamma`. 167 | If False, `gamma` is not used. 168 | epsilon: small float > 0. Fuzz parameter. 169 | Theano expects epsilon >= 1e-5. 170 | r_max_value: Upper limit of the value of r_max. 171 | d_max_value: Upper limit of the value of d_max. 172 | t_delta: At each iteration, increment the value of t by t_delta. 173 | weights: Initialization weights. 174 | List of 4 Numpy arrays, with shapes: 175 | `[(input_shape,), (input_shape,), (input_shape,), (input_shape,)]` 176 | Note that the order of this list is [gamma, beta, running_mean, running_variance] 177 | beta_initializer: name of initialization function for shift parameter 178 | (see [initializers](../initializers.md)), or alternatively, 179 | Theano/TensorFlow function to use for weights initialization. 180 | This parameter is only relevant if you don't pass a `weights` argument. 181 | gamma_initializer: name of initialization function for scale parameter (see 182 | [initializers](../initializers.md)), or alternatively, 183 | Theano/TensorFlow function to use for weights initialization. 184 | This parameter is only relevant if you don't pass a `weights` argument. 185 | moving_mean_initializer: Initializer for the moving mean. 186 | moving_variance_initializer: Initializer for the moving variance. 187 | gamma_regularizer: instance of [WeightRegularizer](../regularizers.md) 188 | (eg. L1 or L2 regularization), applied to the gamma vector. 189 | beta_regularizer: instance of [WeightRegularizer](../regularizers.md), 190 | applied to the beta vector. 191 | beta_constraint: Optional constraint for the beta weight. 192 | gamma_constraint: Optional constraint for the gamma weight. 193 | # Input shape 194 | Arbitrary. Use the keyword argument `input_shape` 195 | (tuple of integers, does not include the samples axis) 196 | when using this layer as the first layer in a model. 197 | # Output shape 198 | Same shape as input.
199 | # References 200 | - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167) 201 | """ 202 | 203 | def __init__(self, axis=-1, momentum=0.99, center=True, scale=True, epsilon=1e-3, 204 | r_max_value=3., d_max_value=5., t_delta=1e-3, weights=None, beta_initializer='zero', 205 | gamma_initializer='one', moving_mean_initializer='zeros', 206 | moving_variance_initializer='ones', gamma_regularizer=None, beta_regularizer=None, 207 | beta_constraint=None, gamma_constraint=None, **kwargs): 208 | self.supports_masking = True 209 | self.axis = axis 210 | self.epsilon = epsilon 211 | self.center = center 212 | self.scale = scale 213 | self.momentum = momentum 214 | self.gamma_regularizer = regularizers.get(gamma_regularizer) 215 | self.beta_regularizer = regularizers.get(beta_regularizer) 216 | self.initial_weights = weights 217 | self.r_max_value = r_max_value 218 | self.d_max_value = d_max_value 219 | self.t_delta = t_delta 220 | self.beta_initializer = initializers.get(beta_initializer) 221 | self.gamma_initializer = initializers.get(gamma_initializer) 222 | self.moving_mean_initializer = initializers.get(moving_mean_initializer) 223 | self.moving_variance_initializer = initializers.get(moving_variance_initializer) 224 | self.beta_constraint = constraints.get(beta_constraint) 225 | self.gamma_constraint = constraints.get(gamma_constraint) 226 | 227 | super(BatchRenormalization, self).__init__(**kwargs) 228 | 229 | def build(self, input_shape): 230 | dim = input_shape[self.axis] 231 | if dim is None: 232 | raise ValueError('Axis ' + str(self.axis) + ' of ' 233 | 'input tensor should have a defined dimension ' 234 | 'but the layer received an input with shape ' + 235 | str(input_shape) + '.') 236 | self.input_spec = InputSpec(ndim=len(input_shape), 237 | axes={self.axis: dim}) 238 | shape = (dim,) 239 | 240 | if self.scale: 241 | self.gamma = self.add_weight(shape, 242 | initializer=self.gamma_initializer, 243 | regularizer=self.gamma_regularizer, 244 | constraint=self.gamma_constraint, 245 | name='{}_gamma'.format(self.name)) 246 | else: 247 | self.gamma = None 248 | 249 | if self.center: 250 | self.beta = self.add_weight(shape, 251 | initializer=self.beta_initializer, 252 | regularizer=self.beta_regularizer, 253 | constraint=self.beta_constraint, 254 | name='{}_beta'.format(self.name)) 255 | else: 256 | self.beta = None 257 | 258 | self.running_mean = self.add_weight(shape, initializer=self.moving_mean_initializer, 259 | name='{}_running_mean'.format(self.name), 260 | trainable=False) 261 | 262 | self.running_variance = self.add_weight(shape, initializer=self.moving_variance_initializer, 263 | name='{}_running_std'.format(self.name), 264 | trainable=False) 265 | 266 | self.r_max = K.variable(1, name='{}_r_max'.format(self.name)) 267 | 268 | self.d_max = K.variable(0, name='{}_d_max'.format(self.name)) 269 | 270 | self.t = K.variable(0, name='{}_t'.format(self.name)) 271 | 272 | self.t_delta_tensor = K.constant(self.t_delta) 273 | 274 | if self.initial_weights is not None: 275 | self.set_weights(self.initial_weights) 276 | del self.initial_weights 277 | 278 | self.built = True 279 | 280 | def call(self, inputs, training=None): 281 | assert self.built, 'Layer must be built before being called' 282 | input_shape = K.int_shape(inputs) 283 | 284 | reduction_axes = list(range(len(input_shape))) 285 | del reduction_axes[self.axis] 286 | broadcast_shape = [1] * len(input_shape) 287 | broadcast_shape[self.axis] = 
input_shape[self.axis] 288 | 289 | mean_batch, var_batch = K_contrib.moments(inputs, reduction_axes, shift=None, keep_dims=False) 290 | std_batch = (K.sqrt(var_batch + self.epsilon)) 291 | 292 | r = std_batch / (K.sqrt(self.running_variance + self.epsilon)) 293 | r = K.stop_gradient(K_contrib.clip(r, 1 / self.r_max, self.r_max)) 294 | 295 | d = (mean_batch - self.running_mean) / K.sqrt(self.running_variance + self.epsilon) 296 | d = K.stop_gradient(K_contrib.clip(d, -self.d_max, self.d_max)) 297 | 298 | if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]: # list(...) needed on Python 3, where a list never equals a range 299 | x_normed_batch = (inputs - mean_batch) / std_batch 300 | x_normed = (x_normed_batch * r + d) * self.gamma + self.beta 301 | else: 302 | # need broadcasting 303 | broadcast_mean = K.reshape(mean_batch, broadcast_shape) 304 | broadcast_std = K.reshape(std_batch, broadcast_shape) 305 | broadcast_r = K.reshape(r, broadcast_shape) 306 | broadcast_d = K.reshape(d, broadcast_shape) 307 | broadcast_beta = K.reshape(self.beta, broadcast_shape) 308 | broadcast_gamma = K.reshape(self.gamma, broadcast_shape) 309 | 310 | x_normed_batch = (inputs - broadcast_mean) / broadcast_std 311 | x_normed = (x_normed_batch * broadcast_r + broadcast_d) * broadcast_gamma + broadcast_beta 312 | 313 | # explicit update to moving mean and standard deviation 314 | self.add_update([K.moving_average_update(self.running_mean, mean_batch, self.momentum), 315 | K.moving_average_update(self.running_variance, std_batch ** 2, self.momentum)], inputs) 316 | 317 | # update r_max and d_max 318 | r_val = self.r_max_value / (1 + (self.r_max_value - 1) * K.exp(-self.t)) 319 | d_val = self.d_max_value / (1 + ((self.d_max_value / 1e-3) - 1) * K.exp(-(2 * self.t))) 320 | 321 | self.add_update([K.update(self.r_max, r_val), 322 | K.update(self.d_max, d_val), 323 | K.update_add(self.t, self.t_delta_tensor)], inputs) 324 | 325 | if training in {0, False}: 326 | return x_normed 327 | else: 328 | def normalize_inference(): 329 | if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]: # list(...) needed on Python 3 330 | x_normed_running = K.batch_normalization( 331 | inputs, self.running_mean, self.running_variance, 332 | self.beta, self.gamma, 333 | epsilon=self.epsilon) 334 | 335 | return x_normed_running 336 | else: 337 | # need broadcasting 338 | broadcast_running_mean = K.reshape(self.running_mean, broadcast_shape) 339 | broadcast_running_std = K.reshape(self.running_variance, broadcast_shape) 340 | broadcast_beta = K.reshape(self.beta, broadcast_shape) 341 | broadcast_gamma = K.reshape(self.gamma, broadcast_shape) 342 | x_normed_running = K.batch_normalization( 343 | inputs, broadcast_running_mean, broadcast_running_std, 344 | broadcast_beta, broadcast_gamma, 345 | epsilon=self.epsilon) 346 | 347 | return x_normed_running 348 | 349 | # pick the normalized form of inputs corresponding to the training phase 350 | # for batch renormalization, inference time remains same as batchnorm 351 | x_normed = K.in_train_phase(x_normed, normalize_inference, training=training) 352 | 353 | return x_normed 354 | 355 | def get_config(self): 356 | config = {'epsilon': self.epsilon, 357 | 'axis': self.axis, 358 | 'center': self.center, 359 | 'scale': self.scale, 360 | 'momentum': self.momentum, 361 | 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), 362 | 'beta_regularizer': regularizers.serialize(self.beta_regularizer), 363 | 'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer), 364 | 'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer), 365 |
'beta_constraint': constraints.serialize(self.beta_constraint), 366 | 'gamma_constraint': constraints.serialize(self.gamma_constraint), 367 | 'r_max_value': self.r_max_value, 368 | 'd_max_value': self.d_max_value, 369 | 't_delta': self.t_delta} 370 | base_config = super(BatchRenormalization, self).get_config() 371 | return dict(list(base_config.items()) + list(config.items())) 372 | 373 | 374 | get_custom_objects().update({'BatchRenormalization': BatchRenormalization}) 375 | 376 | 377 | class GroupNormalization(Layer): 378 | """Group normalization layer 379 | Group Normalization divides the channels into groups and computes within each group 380 | the mean and variance for normalization. Group Normalization's computation is independent 381 | of batch sizes, and its accuracy is stable in a wide range of batch sizes. 382 | Relation to Layer Normalization: 383 | If the number of groups is set to 1, then this operation becomes identical to 384 | Layer Normalization. 385 | Relation to Instance Normalization: 386 | If the number of groups is set to the input dimension (number of groups is equal 387 | to number of channels), then this operation becomes identical to Instance Normalization. 388 | # Arguments 389 | groups: Integer, the number of groups for Group Normalization. 390 | Can be in the range [1, N] where N is the input dimension. 391 | The input dimension must be divisible by the number of groups. 392 | axis: Integer, the axis that should be normalized 393 | (typically the features axis). 394 | For instance, after a `Conv2D` layer with 395 | `data_format="channels_first"`, 396 | set `axis=1` in `BatchNormalization`. 397 | epsilon: Small float added to variance to avoid dividing by zero. 398 | center: If True, add offset of `beta` to normalized tensor. 399 | If False, `beta` is ignored. 400 | scale: If True, multiply by `gamma`. 401 | If False, `gamma` is not used. 402 | When the next layer is linear (also e.g. `nn.relu`), 403 | this can be disabled since the scaling 404 | will be done by the next layer. 405 | beta_initializer: Initializer for the beta weight. 406 | gamma_initializer: Initializer for the gamma weight. 407 | beta_regularizer: Optional regularizer for the beta weight. 408 | gamma_regularizer: Optional regularizer for the gamma weight. 409 | beta_constraint: Optional constraint for the beta weight. 410 | gamma_constraint: Optional constraint for the gamma weight. 411 | # Input shape 412 | Arbitrary. Use the keyword argument `input_shape` 413 | (tuple of integers, does not include the samples axis) 414 | when using this layer as the first layer in a model. 415 | # Output shape 416 | Same shape as input. 
417 | # References 418 | - [Group Normalization](https://arxiv.org/abs/1803.08494) 419 | """ 420 | 421 | def __init__(self, 422 | groups=32, 423 | axis=-1, 424 | epsilon=1e-5, 425 | center=True, 426 | scale=True, 427 | beta_initializer='zeros', 428 | gamma_initializer='ones', 429 | beta_regularizer=None, 430 | gamma_regularizer=None, 431 | beta_constraint=None, 432 | gamma_constraint=None, 433 | **kwargs): 434 | super(GroupNormalization, self).__init__(**kwargs) 435 | self.supports_masking = True 436 | self.groups = groups 437 | self.axis = axis 438 | self.epsilon = epsilon 439 | self.center = center 440 | self.scale = scale 441 | self.beta_initializer = initializers.get(beta_initializer) 442 | self.gamma_initializer = initializers.get(gamma_initializer) 443 | self.beta_regularizer = regularizers.get(beta_regularizer) 444 | self.gamma_regularizer = regularizers.get(gamma_regularizer) 445 | self.beta_constraint = constraints.get(beta_constraint) 446 | self.gamma_constraint = constraints.get(gamma_constraint) 447 | 448 | def build(self, input_shape): 449 | dim = input_shape[self.axis] 450 | 451 | if dim is None: 452 | raise ValueError('Axis ' + str(self.axis) + ' of ' 453 | 'input tensor should have a defined dimension ' 454 | 'but the layer received an input with shape ' + 455 | str(input_shape) + '.') 456 | 457 | if dim < self.groups: 458 | raise ValueError('Number of groups (' + str(self.groups) + ') cannot be ' 459 | 'more than the number of channels (' + 460 | str(dim) + ').') 461 | 462 | if dim % self.groups != 0: 463 | raise ValueError('Number of channels (' + str(dim) + ') must be a ' 464 | 'multiple of the number of groups (' + 465 | str(self.groups) + ').') 466 | 467 | self.input_spec = InputSpec(ndim=len(input_shape), 468 | axes={self.axis: dim}) 469 | shape = (dim,) 470 | 471 | if self.scale: 472 | self.gamma = self.add_weight(shape=shape, 473 | name='gamma', 474 | initializer=self.gamma_initializer, 475 | regularizer=self.gamma_regularizer, 476 | constraint=self.gamma_constraint) 477 | else: 478 | self.gamma = None 479 | if self.center: 480 | self.beta = self.add_weight(shape=shape, 481 | name='beta', 482 | initializer=self.beta_initializer, 483 | regularizer=self.beta_regularizer, 484 | constraint=self.beta_constraint) 485 | else: 486 | self.beta = None 487 | self.built = True 488 | 489 | def call(self, inputs, **kwargs): 490 | input_shape = K.int_shape(inputs) 491 | tensor_input_shape = K.shape(inputs) 492 | 493 | # Prepare broadcasting shape.
494 | reduction_axes = list(range(len(input_shape))) 495 | del reduction_axes[self.axis] 496 | broadcast_shape = [1] * len(input_shape) 497 | broadcast_shape[self.axis] = input_shape[self.axis] // self.groups 498 | broadcast_shape.insert(1, self.groups) 499 | 500 | reshape_group_shape = K.shape(inputs) 501 | group_axes = [reshape_group_shape[i] for i in range(len(input_shape))] 502 | group_axes[self.axis] = input_shape[self.axis] // self.groups 503 | group_axes.insert(1, self.groups) 504 | 505 | # reshape inputs to new group shape 506 | group_shape = [group_axes[0], self.groups] + group_axes[2:] 507 | group_shape = K.stack(group_shape) 508 | inputs = K.reshape(inputs, group_shape) 509 | 510 | group_reduction_axes = list(range(len(group_axes))) 511 | mean, variance = K_contrib.moments(inputs, group_reduction_axes[2:], keep_dims=True) 512 | inputs = (inputs - mean) / (K.sqrt(variance + self.epsilon)) 513 | 514 | # prepare broadcast shape 515 | inputs = K.reshape(inputs, group_shape) 516 | 517 | outputs = inputs 518 | 519 | # In this case we must explicitly broadcast all parameters. 520 | if self.scale: 521 | broadcast_gamma = K.reshape(self.gamma, broadcast_shape) 522 | outputs = outputs * broadcast_gamma 523 | 524 | if self.center: 525 | broadcast_beta = K.reshape(self.beta, broadcast_shape) 526 | outputs = outputs + broadcast_beta 527 | 528 | # finally we reshape the output back to the input shape 529 | outputs = K.reshape(outputs, tensor_input_shape) 530 | 531 | return outputs 532 | 533 | def get_config(self): 534 | config = { 535 | 'groups': self.groups, 536 | 'axis': self.axis, 537 | 'epsilon': self.epsilon, 538 | 'center': self.center, 539 | 'scale': self.scale, 540 | 'beta_initializer': initializers.serialize(self.beta_initializer), 541 | 'gamma_initializer': initializers.serialize(self.gamma_initializer), 542 | 'beta_regularizer': regularizers.serialize(self.beta_regularizer), 543 | 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), 544 | 'beta_constraint': constraints.serialize(self.beta_constraint), 545 | 'gamma_constraint': constraints.serialize(self.gamma_constraint) 546 | } 547 | base_config = super(GroupNormalization, self).get_config() 548 | return dict(list(base_config.items()) + list(config.items())) 549 | 550 | def compute_output_shape(self, input_shape): 551 | return input_shape 552 | 553 | 554 | get_custom_objects().update({'GroupNormalization': GroupNormalization}) 555 | --------------------------------------------------------------------------------
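Usage note (added): the three layers in normalizations.py register themselves with Keras via get_custom_objects(), so models containing them can be saved and reloaded without passing custom_objects explicitly. Below is a minimal sketch of how they slot into the 1-D waveform pipeline used by the run scripts; the filter counts, group count, and input shape are illustrative values only, not the repository's exact architecture.

import numpy as np
from keras.models import Sequential
from keras.layers import Conv1D
from normalizations import InstanceNormalization, GroupNormalization

model = Sequential()
# 1-D convolution over raw waveform windows (16384 samples, 1 channel)
model.add(Conv1D(16, 31, strides=2, padding='same', input_shape=(16384, 1)))
# normalize each sample independently over the time axis, per channel
model.add(InstanceNormalization(axis=-1))
model.add(Conv1D(32, 31, strides=2, padding='same'))
# 32 channels split into 8 groups of 4; the channel count must be divisible by groups
model.add(GroupNormalization(groups=8, axis=-1))
model.compile(optimizer='adam', loss='mse')

x = np.random.randn(4, 16384, 1).astype('float32')
print(model.predict(x).shape)   # expected: (4, 4096, 32)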