├── assets
│   ├── sasegan.png
│   ├── pesq_stoi.png
│   └── attention_layer2.png
├── evaluate
│   ├── evaluate_sasegan.m
│   ├── comp_snr.m
│   ├── stoi.m
│   └── composite.m
├── sasegan
│   ├── clean_wav.sh
│   ├── clean_wav_dir.sh
│   ├── LICENCE
│   ├── data_loader.py
│   ├── run.sh
│   ├── bnorm.py
│   ├── discriminator.py
│   ├── main.py
│   ├── selfattention.py
│   ├── generator.py
│   ├── ops.py
│   └── model.py
├── cfg
│   └── e2e_maker.cfg
├── create_training_tfrecord.sh
├── download_audio.sh
├── README.md
└── make_tfrecords.py

--------------------------------------------------------------------------------
/assets/sasegan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pquochuy/sasegan/HEAD/assets/sasegan.png
--------------------------------------------------------------------------------
/assets/pesq_stoi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pquochuy/sasegan/HEAD/assets/pesq_stoi.png
--------------------------------------------------------------------------------
/assets/attention_layer2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pquochuy/sasegan/HEAD/assets/attention_layer2.png
--------------------------------------------------------------------------------
/cfg/e2e_maker.cfg:
--------------------------------------------------------------------------------
[segan]
noisy="data/noisy_trainset_wav_16k/"
clean="data/clean_trainset_wav_16k/"
--------------------------------------------------------------------------------
/create_training_tfrecord.sh:
--------------------------------------------------------------------------------
python make_tfrecords.py --force_gen --wav_dir "data/clean_trainset_wav_16k/" --noisy_dir "data/noisy_trainset_wav_16k/"
--------------------------------------------------------------------------------
/sasegan/clean_wav.sh:
--------------------------------------------------------------------------------
#!/bin/bash


# clean up a single noisy wav file with a trained model
if [ $# -lt 1 ]; then
    echo 'ERROR: at least wavname must be provided!'
    echo "Usage: $0 <noisy_wavname> [optional:save_path]"
    echo "If no save_path is specified, clean file is saved in current dir"
    exit 1
fi

NOISY_WAVNAME="$1"
SAVE_PATH="."
if [ $# -gt 1 ]; then
    SAVE_PATH="$2"
fi

echo "INPUT NOISY WAV: $NOISY_WAVNAME"
echo "SAVE PATH: $SAVE_PATH"
mkdir -p "$SAVE_PATH"
python main.py --init_noise_std 0. --save_path segan_allbiased_preemph \
    --batch_size 100 --g_nl prelu --weights SEGAN-41700 \
    --preemph 0.95 --bias_deconv True \
    --bias_downconv True --bias_D_conv True \
    --test_wav "$NOISY_WAVNAME" --save_clean_path "$SAVE_PATH"
--------------------------------------------------------------------------------
/sasegan/clean_wav_dir.sh:
--------------------------------------------------------------------------------
#!/bin/bash


# clean up all noisy wav files in a directory with a trained model
if [ $# -lt 1 ]; then
    echo 'ERROR: at least wavname must be provided!'
    echo "Usage: $0 <noisy_wav_dir> [optional:save_path]"
    echo "If no save_path is specified, clean files are saved in current dir"
    exit 1
fi

NOISY_WAVDIR="$1"
SAVE_PATH="."
if [ $# -gt 1 ]; then
    SAVE_PATH="$2"
fi

echo "INPUT NOISY WAV DIRECTORY: $NOISY_WAVDIR"
echo "SAVE PATH: $SAVE_PATH"
mkdir -p "$SAVE_PATH"

python main.py --init_noise_std 0. --save_path segan_allbiased_preemph \
    --batch_size 100 --g_nl prelu --weights SEGAN-41700 \
    --preemph 0.95 --bias_deconv True \
    --bias_downconv True --bias_D_conv True \
    --test_wav_dir "$NOISY_WAVDIR" --save_clean_path "$SAVE_PATH"
--------------------------------------------------------------------------------
/sasegan/LICENCE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Santi Dsp

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/evaluate/evaluate_sasegan.m:
--------------------------------------------------------------------------------
clear all
close all
clc

list_file = dir('../data/clean_testset_wav_16k/*.wav');

% different model checkpoints
cp = {'97000','97100','97200','97300','97400'};
Ncp = numel(cp);

fs_signal = 16000;
% attention layer index, you may want to change this to the value that you set when training the model
att_layer_index = 2;

ret = zeros(Ncp, 6);
for c = 1 : Ncp
    ret_c = zeros(numel(list_file),6);
    parfor f = 1 : numel(list_file)
        disp(list_file(f).name);
        clean_wav = ['../data/clean_testset_wav_16k/', list_file(f).name];
        noisy_wav = ['../sasegan/cleaned_testset_wav_16k_att',num2str(att_layer_index),'_', cp{c}, '/', list_file(f).name];
        spesq = pesq(clean_wav, noisy_wav);
        [~,ssnr] = comp_snr(clean_wav, noisy_wav);
        [Csig,Cbak,Covl] = composite(clean_wav,noisy_wav);

        [x, ~] = audioread(clean_wav);
        [y, ~] = audioread(noisy_wav);
        d_stoi = stoi(x, y, fs_signal);

        ret_c(f,:) = [spesq, Csig, Cbak, Covl, ssnr, d_stoi];
    end
    ret(c, :) = mean(ret_c);
end
disp('Average: PESQ, CSIG, CBAK, COVL, SSNR, STOI')
mean(ret)
--------------------------------------------------------------------------------
/sasegan/data_loader.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import tensorflow as tf
from ops import *
import numpy as np


def pre_emph(x, coeff=0.95):
    x0 = tf.reshape(x[0], [
        1,
    ])
    diff = x[1:] - coeff * x[:-1]
    concat = tf.concat([x0, diff], 0)
    return concat


def de_emph(y, coeff=0.95):
    if coeff <= 0:
        return y
    x = np.zeros(y.shape[0], dtype=np.float32)
    x[0] = y[0]
    for n in range(1, y.shape[0], 1):
        x[n] = coeff * x[n - 1] + y[n]
    return x


def read_and_decode(filename_queue, canvas_size, preemph=0.):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'wav_raw': tf.FixedLenFeature([], tf.string),
            'noisy_raw': tf.FixedLenFeature([], tf.string),
        })
    wave = tf.decode_raw(features['wav_raw'], tf.int32)
    wave.set_shape(canvas_size)
    wave = (2. / 65535.) * tf.cast((wave - 32767), tf.float32) + 1.
    noisy = tf.decode_raw(features['noisy_raw'], tf.int32)
    noisy.set_shape(canvas_size)
    noisy = (2. / 65535.) * tf.cast((noisy - 32767), tf.float32) + 1.

    if preemph > 0:
        wave = tf.cast(pre_emph(wave, preemph), tf.float32)
        noisy = tf.cast(pre_emph(noisy, preemph), tf.float32)

    return wave, noisy
--------------------------------------------------------------------------------
/download_audio.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# DOWNLOAD THE DATASET
mkdir -p data
pushd data
if [ ! -d clean_testset_wav_16k ]; then
    # Clean utterances
    if [ ! -f clean_testset_wav.zip ]; then
        echo 'DOWNLOADING CLEAN DATASET...'
        wget http://datashare.is.ed.ac.uk/bitstream/handle/10283/1942/clean_testset_wav.zip
    fi
    if [ ! -d clean_testset_wav ]; then
        echo 'INFLATING CLEAN TESTSET ZIP...'
        unzip -q clean_testset_wav.zip -d clean_testset_wav
    fi
    if [ ! -d clean_testset_wav_16k ]; then
        echo 'CONVERTING CLEAN WAVS TO 16K...'
        mkdir -p clean_testset_wav_16k
        pushd clean_testset_wav
        ls *.wav | while read name; do
            sox "$name" -r 16k ../clean_testset_wav_16k/"$name"
        done
        popd
    fi
fi
if [ ! -d noisy_testset_wav_16k ]; then
    # Noisy utterances
    if [ ! -f noisy_testset_wav.zip ]; then
        echo 'DOWNLOADING NOISY DATASET...'
        wget http://datashare.is.ed.ac.uk/bitstream/handle/10283/1942/noisy_testset_wav.zip
    fi
    if [ ! -d noisy_testset_wav ]; then
        echo 'INFLATING NOISY TESTSET ZIP...'
        unzip -q noisy_testset_wav.zip -d noisy_testset_wav
    fi
    if [ ! -d noisy_testset_wav_16k ]; then
        echo 'CONVERTING NOISY WAVS TO 16K...'
        mkdir -p noisy_testset_wav_16k
        pushd noisy_testset_wav
        ls *.wav | while read name; do
            sox "$name" -r 16k ../noisy_testset_wav_16k/"$name"
        done
        popd
    fi
fi
popd
--------------------------------------------------------------------------------
/sasegan/run.sh:
--------------------------------------------------------------------------------
# Model training
# --att_layer_ind "X", where X takes a value in {2, 3, 4, 5, 6, 7, 8, 9, 10}
CUDA_VISIBLE_DEVICES="0,-1" python main.py --init_noise_std 0. --save_path segan_allbiased_preemph_att1 --init_l1_weight 100. --batch_size 50 --g_nl prelu --save_freq 50 --preemph 0.95 --epoch 100 --bias_deconv True --bias_downconv True --bias_D_conv True --e2e_dataset '../data/segan.tfrecords' --att_layer_ind "2" --synthesis_path dwavegan_samples
# test the trained model with different checkpoints
mkdir cleaned_testset_wav_16k_att1_97000
CUDA_VISIBLE_DEVICES="0,-1" python main.py --init_noise_std 0. --save_path segan_allbiased_preemph_att1 --batch_size 50 --g_nl prelu --weights SEGAN-97000 --preemph 0.95 --bias_deconv True --bias_downconv True --bias_D_conv True --test_wav_dir '../data/noisy_testset_wav_16k/' --save_clean_path './cleaned_testset_wav_16k_att1_97000/' --att_layer_ind "1"
mkdir cleaned_testset_wav_16k_att1_97100
CUDA_VISIBLE_DEVICES="0,-1" python main.py --init_noise_std 0. --save_path segan_allbiased_preemph_att1 --batch_size 50 --g_nl prelu --weights SEGAN-97100 --preemph 0.95 --bias_deconv True --bias_downconv True --bias_D_conv True --test_wav_dir '../data/noisy_testset_wav_16k/' --save_clean_path './cleaned_testset_wav_16k_att1_97100/' --att_layer_ind "1"
mkdir cleaned_testset_wav_16k_att1_97200
CUDA_VISIBLE_DEVICES="0,-1" python main.py --init_noise_std 0. --save_path segan_allbiased_preemph_att1 --batch_size 50 --g_nl prelu --weights SEGAN-97200 --preemph 0.95 --bias_deconv True --bias_downconv True --bias_D_conv True --test_wav_dir '../data/noisy_testset_wav_16k/' --save_clean_path './cleaned_testset_wav_16k_att1_97200/' --att_layer_ind "1"
mkdir cleaned_testset_wav_16k_att1_97300
CUDA_VISIBLE_DEVICES="0,-1" python main.py --init_noise_std 0. --save_path segan_allbiased_preemph_att1 --batch_size 50 --g_nl prelu --weights SEGAN-97300 --preemph 0.95 --bias_deconv True --bias_downconv True --bias_D_conv True --test_wav_dir '../data/noisy_testset_wav_16k/' --save_clean_path './cleaned_testset_wav_16k_att1_97300/' --att_layer_ind "1"
mkdir cleaned_testset_wav_16k_att1_97400
CUDA_VISIBLE_DEVICES="0,-1" python main.py --init_noise_std 0. --save_path segan_allbiased_preemph_att1 --batch_size 50 --g_nl prelu --weights SEGAN-97400 --preemph 0.95 --bias_deconv True --bias_downconv True --bias_D_conv True --test_wav_dir '../data/noisy_testset_wav_16k/' --save_clean_path './cleaned_testset_wav_16k_att1_97400/' --att_layer_ind "1"
--------------------------------------------------------------------------------
/sasegan/bnorm.py:
--------------------------------------------------------------------------------
import tensorflow as tf


class VBN(object):
    """
    Virtual Batch Normalization
    (modified from https://github.com/openai/improved-gan/ definition)
    """

    def __init__(self, x, name, epsilon=1e-5):
        """
        x is the reference batch
        """
        assert isinstance(epsilon, float)

        shape = x.get_shape().as_list()
        assert len(shape) == 3, shape
        with tf.variable_scope(name) as scope:
            assert name.startswith("d_") or name.startswith("g_")
            self.epsilon = epsilon
            self.name = name
            self.mean = tf.reduce_mean(x, [0, 1], keep_dims=True)
            self.mean_sq = tf.reduce_mean(tf.square(x), [0, 1], keep_dims=True)
            self.batch_size = int(x.get_shape()[0])
            assert x is not None
            assert self.mean is not None
            assert self.mean_sq is not None
            out = self._normalize(x, self.mean, self.mean_sq, "reference")
            self.reference_output = out

    def __call__(self, x):
        shape = x.get_shape().as_list()
        with tf.variable_scope(self.name) as scope:
            new_coeff = 1. / (self.batch_size + 1.)
            old_coeff = 1. - new_coeff
            new_mean = tf.reduce_mean(x, [0, 1], keep_dims=True)
            new_mean_sq = tf.reduce_mean(tf.square(x), [0, 1], keep_dims=True)
            mean = new_coeff * new_mean + old_coeff * self.mean
            mean_sq = new_coeff * new_mean_sq + old_coeff * self.mean_sq
            out = self._normalize(x, mean, mean_sq, "live")
            return out

    def _normalize(self, x, mean, mean_sq, message):
        # make sure this is called with a variable scope
        shape = x.get_shape().as_list()
        assert len(shape) == 3
        self.gamma = tf.get_variable("gamma", [shape[-1]],
                                     initializer=tf.random_normal_initializer(1., 0.02))
        gamma = tf.reshape(self.gamma, [1, 1, -1])
        self.beta = tf.get_variable("beta", [shape[-1]],
                                    initializer=tf.constant_initializer(0.))
        beta = tf.reshape(self.beta, [1, 1, -1])
        assert self.epsilon is not None
        assert mean_sq is not None
        assert mean is not None
        std = tf.sqrt(self.epsilon + mean_sq - tf.square(mean))
        out = x - mean
        out = out / std
        out = out * gamma
        out = out + beta
        return out
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Self-Attention Generative Adversarial Network for Speech Enhancement


### Introduction

This is the repository of the self-attention GAN for speech enhancement (SASEGAN) in our original paper:

H. Phan, H. L. Nguyen, O. Y. Chén, P. Koch, N. Q. K. Duong, I. McLoughlin, and A. Mertins, "[_Self-Attention Generative Adversarial Network for Speech Enhancement_](https://arxiv.org/pdf/2010.09132)," Proc. ICASSP, 2021.

SASEGAN integrates non-local self-attention into convolutional layers of SEGAN [Pascual _et al._](https://arxiv.org/abs/1703.09452) to improve sequential modelling.
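
For orientation, below is a minimal NumPy sketch of the non-local attention operation implemented in `sasegan/selfattention.py` (`sn_non_local_block_sim`); the learned 1x1 convolutions are replaced here by random matrices, and the spectral normalization and the 4x max-pooling on the `phi`/`g` paths are omitted for brevity:

```
import numpy as np

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def non_local_attention(x, rng=np.random.default_rng(0)):
    """x: (L, C) feature map of one convolutional layer; returns (L, C)."""
    L, C = x.shape
    # stand-ins for the learned 1x1 convolutions theta, phi, g and the output projection
    W_theta = rng.normal(size=(C, C // 8))
    W_phi = rng.normal(size=(C, C // 8))
    W_g = rng.normal(size=(C, C // 2))
    W_o = rng.normal(size=(C // 2, C))
    theta, phi, g = x @ W_theta, x @ W_phi, x @ W_g
    attn = softmax(theta @ phi.T)        # (L, L): every time step attends to all others
    sigma = 0.0                          # learned gate, initialized to 0 at training start
    return x + sigma * (attn @ g @ W_o)  # residual connection around the attention path

out = non_local_attention(np.random.default_rng(1).normal(size=(64, 32)))
print(out.shape)  # (64, 32)
```

Because `sigma` starts at 0, the block acts as an identity mapping at initialization and gradually learns how much attention to mix into the convolutional features.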
![SASEGAN](assets/sasegan.png)


**The project is developed with TensorFlow 1**. ([Go to Tensorflow 2 Version](https://github.com/usimarit/sasegan))

### Dependencies

* tensorflow_gpu 1.9
* numpy==1.1.3
* scipy==1.0.0

### Data

The speech enhancement dataset used in this work can be found in [Edinburgh DataShare](http://datashare.is.ed.ac.uk/handle/10283/1942). **The following scripts download the data and prepare it in TensorFlow format**:

```
./download_audio.sh
./create_training_tfrecord.sh
```

Alternatively, download the dataset, convert the wav files to 16kHz sampling, and set the `noisy` and `clean` training file paths in the config file `e2e_maker.cfg` in `cfg/`. Then run the script:

```
python make_tfrecords.py --force_gen --cfg cfg/e2e_maker.cfg
```

### Training

Once you have the TFRecords file created in `data/segan.tfrecords`, you can simply run the following script:

```
# SASEGAN: run inside sasegan directory
./run.sh
```
The script consists of commands for training, and for testing 5 different checkpoints of the trained model on the test audio files. You may want to set the convolutional layer index (the `--att_layer_ind` parameter) at which the self-attention component is integrated.

The trained models can be downloaded [HERE](https://zenodo.org/record/4288589).

### Results

Enhancement results compared to the SEGAN baseline:

![results](assets/pesq_stoi.png)

Visualization of attention weights (at convolutional layer index 2) at two different time indices of the input:

![results](assets/attention_layer2.png)


### Reference

```
@inproceedings{phan2020sasegan,
  title={Self-Attention Generative Adversarial Network for Speech Enhancement},
  author={H. Phan, H. L. Nguyen, O. Y. Chén, P. Koch, N. Q. K. Duong, I. McLoughlin, and A. Mertins},
  booktitle={Proc. ICASSP},
  year={2021}
}
```

1. [Speech enhancement GAN](https://github.com/santi-pdp/segan)
2. [Improving GANs for speech enhancement](https://github.com/pquochuy/idsegan)
3. [Self-attention GAN](https://github.com/brain-research/self-attention-gan)

### Contact

Huy Phan

School of Electronic Engineering and Computer Science
Queen Mary University of London
Email: h.phan{at}qmul.ac.uk

### Notes

* If using this code, parts of it, or developments from it, please cite the above reference.
* We do not provide any support or assistance for the supplied code, nor do we offer any other compilation/variant of it.
* We assume no responsibility regarding the provided code.

### License

MIT © Huy Phan
--------------------------------------------------------------------------------
/sasegan/discriminator.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.layers import batch_norm, fully_connected, flatten
from tensorflow.contrib.layers import xavier_initializer
from ops import *
from selfattention import *
import numpy as np


def discriminator(self, wave_in, reuse=False):
    """
    wave_in: waveform input
    """
    # take the waveform as input "activation"
    in_dims = wave_in.get_shape().as_list()
    hi = wave_in
    if len(in_dims) == 2:
        hi = tf.expand_dims(wave_in, -1)
    elif len(in_dims) < 2 or len(in_dims) > 3:
        raise ValueError('Discriminator input must be 2-D or 3-D')

    batch_size = int(wave_in.get_shape()[0])

    # set up the disc_block function
    with tf.variable_scope('d_model') as scope:
        if reuse:
            scope.reuse_variables()

        def disc_block(block_idx, input_, kwidth, nfmaps, bnorm, activation,
                       pooling=2):
            with tf.variable_scope('d_block_{}'.format(block_idx)):
                if not reuse:
                    print('D block {} input shape: {}'
                          ''.format(block_idx, input_.get_shape()),
                          end=' *** ')
                bias_init = None
                if self.bias_D_conv:
                    if not reuse:
                        print('biasing D conv', end=' *** ')
                    bias_init = tf.constant_initializer(0.)
                downconv_init = tf.truncated_normal_initializer(stddev=0.02)
                hi_a = sn_downconv(input_, nfmaps, kwidth=kwidth, pool=pooling,
                                   init=downconv_init, bias_init=bias_init)
                if not reuse:
                    print('downconved shape: {} '.format(hi_a.get_shape()), end=' *** ')
                if bnorm:
                    if not reuse:
                        print('Applying VBN', end=' *** ')
                    hi_a = self.vbn(hi_a, 'd_vbn_{}'.format(block_idx))
                if activation == 'leakyrelu':
                    if not reuse:
                        print('Applying Lrelu', end=' *** ')
                    hi = leakyrelu(hi_a)
                elif activation == 'relu':
                    if not reuse:
                        print('Applying Relu', end=' *** ')
                    hi = tf.nn.relu(hi_a)
                else:
                    raise ValueError('Unrecognized activation {} '
                                     'in D'.format(activation))
                return hi

        beg_size = self.canvas_size
        # apply input noisy layer to real and fake samples
        hi = gaussian_noise_layer(hi, self.disc_noise_std)
        if not reuse:
            print('*** Discriminator summary ***')
        for block_idx, fmaps in enumerate(self.d_num_fmaps):
            hi = disc_block(block_idx, hi, 31, self.d_num_fmaps[block_idx], True, 'leakyrelu')
            # self-attention
            #if block_idx == len(self.d_num_fmaps) // 2:
            #if block_idx == self.att_layer_ind:
            if block_idx in self.enc_att_layer_ind:
                hi_2d = tf.expand_dims(hi, 2)
                hi_2d = sn_non_local_block_sim(hi_2d, None, 'discriminator_attention_layer{}'.format(block_idx))
                hi = tf.reshape(hi_2d, hi_2d.get_shape().as_list()[:2] + [hi_2d.get_shape().as_list()[-1]])
                print('Discriminator: self-attention')
            if not reuse:
                print()
        if not reuse:
            print('discriminator deconved shape: ', hi.get_shape())
        hi_f = flatten(hi)
        #hi_f = tf.nn.dropout(hi_f, self.keep_prob_var)
        d_logit_out = conv1d(hi, kwidth=1, num_kernels=1,
                             init=tf.truncated_normal_initializer(stddev=0.02),
                             name='logits_conv')
        d_logit_out = tf.squeeze(d_logit_out)
        d_logit_out = fully_connected(d_logit_out, 1, activation_fn=None)
        if not reuse:
            print('discriminator output shape: ', d_logit_out.get_shape())
            print('*****************************')
        return d_logit_out
--------------------------------------------------------------------------------
/evaluate/comp_snr.m:
--------------------------------------------------------------------------------
% use segmental SNR for evaluation
function [snr_mean, segsnr_mean]= comp_SNR(cleanFile, enhdFile);
%
% Segmental Signal-to-Noise Ratio Objective Speech Quality Measure
%
% This function implements the segmental signal-to-noise ratio
% as defined in [1, p. 45] (see Equation 2.12).
%
% Usage: [SNRovl, SNRseg]=comp_snr(cleanFile.wav, enhancedFile.wav)
%
%   cleanFile.wav - clean input file in .wav format
%   enhancedFile  - enhanced output file in .wav format
%   SNRovl        - overall SNR (dB)
%   SNRseg        - segmental SNR (dB)
%
% This function returns 2 parameters. The first item is the
% overall SNR for the two speech signals. The second value
% is the segmental signal-to-noise ratio (1 seg-snr per
% frame of input). The segmental SNR is clamped to range
% between 35dB and -10dB (see suggestions in [2]).
%
% Example call: [SNRovl,SNRseg]=comp_SNR('sp04.wav','enhanced.wav')
%
% References:
%
%   [1] S. R. Quackenbush, T. P. Barnwell, and M. A. Clements,
%       Objective Measures of Speech Quality. Prentice Hall
%       Advanced Reference Series, Englewood Cliffs, NJ, 1988,
%       ISBN: 0-13-629056-6.
%
%   [2] P. E. Papamichalis, Practical Approaches to Speech
%       Coding, Prentice-Hall, Englewood Cliffs, NJ, 1987.
%       ISBN: 0-13-689019-9. (see pages 179-181).
%
% Authors: Bryan L. Pellom and John H. L. Hansen (July 1998)
% Modified by: Philipos C. Loizou (Oct 2006)
%
% Copyright (c) 2006 by Philipos C. Loizou
% $Revision: 0.0 $ $Date: 10/09/2006 $
%-------------------------------------------------------------------------

if nargin ~=2
    fprintf('USAGE: [snr_mean, segsnr_mean]= comp_SNR(cleanFile, enhdFile) \n');
    return;
end

% [data1, Srate1, Nbits1]= wavread(cleanFile);
% [data2, Srate2, Nbits2]= wavread(enhdFile);
% if (( Srate1~= Srate2) | ( Nbits1~= Nbits2))
%     error( 'The two files do not match!\n');
% end

[data1, Srate1]= audioread(cleanFile);
[data2, Srate2]= audioread(enhdFile);
if ( Srate1~= Srate2)
    error( 'The two files do not match!\n');
end

len= min( length( data1), length( data2));
data1= data1( 1: len);
data2= data2( 1: len);

[snr_dist, segsnr_dist]= snr( data1, data2,Srate1);

snr_mean= snr_dist;
segsnr_mean= mean( segsnr_dist);


% =========================================================================
function [overall_snr, segmental_snr] = snr(clean_speech, processed_speech,sample_rate)

% ----------------------------------------------------------------------
% Check the length of the clean and processed speech. Must be the same.
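% (comp_SNR above already trims both signals to the same length before
%  calling this function, so this check should never trigger in practice.)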
% ----------------------------------------------------------------------

clean_length = length(clean_speech);
processed_length = length(processed_speech);

if (clean_length ~= processed_length)
    disp('Error: Both Speech Files must be same length.');
    return
end

% ----------------------------------------------------------------------
% Scale both clean speech and processed speech to have same dynamic
% range. Also remove DC component from each signal
% ----------------------------------------------------------------------

%clean_speech = clean_speech - mean(clean_speech);
%processed_speech = processed_speech - mean(processed_speech);

%processed_speech = processed_speech.*(max(abs(clean_speech))/ max(abs(processed_speech)));

overall_snr = 10* log10( sum(clean_speech.^2)/sum((clean_speech-processed_speech).^2));

% ----------------------------------------------------------------------
% Global Variables
% ----------------------------------------------------------------------


winlength = round(30*sample_rate/1000); %240; % window length in samples for 30-msecs
skiprate = floor(winlength/4); %60; % window skip in samples
MIN_SNR = -10; % minimum SNR in dB
MAX_SNR = 35; % maximum SNR in dB

% ----------------------------------------------------------------------
% For each frame of input speech, calculate the Segmental SNR
% ----------------------------------------------------------------------

num_frames = clean_length/skiprate-(winlength/skiprate); % number of frames
start = 1; % starting sample
window = 0.5*(1 - cos(2*pi*(1:winlength)'/(winlength+1)));

for frame_count = 1: num_frames

    % ----------------------------------------------------------
    % (1) Get the Frames for the test and reference speech.
    %     Multiply by Hanning Window.
    % ----------------------------------------------------------

    clean_frame = clean_speech(start:start+winlength-1);
    processed_frame = processed_speech(start:start+winlength-1);
    clean_frame = clean_frame.*window;
    processed_frame = processed_frame.*window;

    % ----------------------------------------------------------
    % (2) Compute the Segmental SNR
    % ----------------------------------------------------------

    signal_energy = sum(clean_frame.^2);
    noise_energy = sum((clean_frame-processed_frame).^2);
    segmental_snr(frame_count) = 10*log10(signal_energy/(noise_energy+eps)+eps);
    segmental_snr(frame_count) = max(segmental_snr(frame_count),MIN_SNR);
    segmental_snr(frame_count) = min(segmental_snr(frame_count),MAX_SNR);

    start = start + skiprate;

end

--------------------------------------------------------------------------------
/make_tfrecords.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import tensorflow as tf
import numpy as np
from collections import namedtuple, OrderedDict
from subprocess import call
import scipy.io.wavfile as wavfile
#import argparse
import codecs
import timeit
import struct
#import toml
import re
import sys
import os


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def slice_signal(signal, window_size, stride=0.5):
    """ Return windows of the given signal by sweeping in stride fractions
        of window
    """
    assert signal.ndim == 1, signal.ndim
    n_samples = signal.shape[0]
    offset = int(window_size * stride)
    slices = []
    for beg_i, end_i in zip(
            range(0, n_samples, offset),
            range(window_size, n_samples + offset, offset)):
        if end_i - beg_i < window_size:
            break
        slice_ = signal[beg_i:end_i]
        if slice_.shape[0] == window_size:
            slices.append(slice_)
    return np.array(slices, dtype=np.int32)


def read_and_slice(filename, wav_canvas_size, stride=0.5):
    fm, wav_data = wavfile.read(filename)
    if fm != 16000:
        raise ValueError('Sampling rate is expected to be 16kHz!')
    signals = slice_signal(wav_data, wav_canvas_size, stride)
    return signals


def encoder_proc(wav_filename, noisy_path, out_file, wav_canvas_size):
    """ Read and slice the wav and noisy files and write to TFRecords.
    out_file: TFRecordWriter.
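    wav_canvas_size: number of samples per slice window (2**14 in this script).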
    """
    ppath, wav_fullname = os.path.split(wav_filename)
    noisy_filename = os.path.join(noisy_path, wav_fullname)
    wav_signals = read_and_slice(wav_filename, wav_canvas_size)
    noisy_signals = read_and_slice(noisy_filename, wav_canvas_size)
    assert wav_signals.shape == noisy_signals.shape, noisy_signals.shape

    for (wav, noisy) in zip(wav_signals, noisy_signals):
        wav_raw = wav.tostring()
        noisy_raw = noisy.tostring()
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'wav_raw': _bytes_feature(wav_raw),
                    'noisy_raw': _bytes_feature(noisy_raw)
                }))
        out_file.write(example.SerializeToString())


def main(opts):
    if not os.path.exists(opts.save_path):
        # make save path if it does not exist
        os.makedirs(opts.save_path)
    # set up the output filepath
    out_filepath = os.path.join(opts.save_path, opts.out_file)
    if os.path.splitext(out_filepath)[1] != '.tfrecords':
        # if wrong extension or no extension appended, put .tfrecords
        out_filepath += '.tfrecords'
    else:
        out_filename, ext = os.path.splitext(out_filepath)
        out_filepath = out_filename + ext
    # check if out_file exists and if force flag is set
    if os.path.exists(out_filepath) and not opts.force_gen:
        raise ValueError(
            'ERROR: {} already exists. Set force flag (--force_gen) to '
            'overwrite. Skipping this speaker.'.format(out_filepath))
    elif os.path.exists(out_filepath) and opts.force_gen:
        print('Will overwrite previously existing tfrecords')
        os.unlink(out_filepath)

    beg_enc_t = timeit.default_timer()
    out_file = tf.python_io.TFRecordWriter(out_filepath)
    # process the acoustic and textual data now
    print('-' * 50)
    wav_dir = opts.wav_dir  # clean wav dir
    for wav in os.listdir(wav_dir):
        print(wav)
    wav_files = [
        os.path.join(wav_dir, wav) for wav in os.listdir(wav_dir)
        if wav.endswith('.wav')
    ]
    noisy_dir = opts.noisy_dir  # noisy wav dir
    nfiles = len(wav_files)
    for m, wav_file in enumerate(wav_files):
        print('Processing wav file {}/{} {}{}'.format(
            m + 1, nfiles, wav_file, ' ' * 10), end='\r')
        sys.stdout.flush()
        encoder_proc(wav_file, noisy_dir, out_file, 2**14)
    out_file.close()
    end_enc_t = timeit.default_timer() - beg_enc_t
    print('')
    print('*' * 50)
    print('Total processing and writing time: {} s'.format(end_enc_t))


if __name__ == '__main__':
    flags = tf.app.flags
    flags.DEFINE_string("wav_dir", "data/clean_trainset_wav_16k/", "Directory containing the wave files.")
    flags.DEFINE_string("noisy_dir", "data/noisy_trainset_wav_16k/", "Directory containing the noisy wave files.")
    flags.DEFINE_string("save_path", "data/", "Save path.")
    flags.DEFINE_string("out_file", "segan.tfrecords", "Output filename.")
    flags.DEFINE_boolean("force_gen", True, "Flag to force overwriting existing dataset")

    '''
    parser = argparse.ArgumentParser(description='Convert the set of txt and '
                                                 'wavs to TFRecords')
    parser.add_argument(
        '--wav_dir',
        type=str,
        default='data/clean_trainset_wav_16k/',
        help='Directory containing the wave files ')
    parser.add_argument(
        '--noisy_dir',
        type=str,
        default='data/noisy_trainset_wav_16k/',
        help='Directory containing the wave files ')
    parser.add_argument(
        '--save_path',
        type=str,
        default='data/',
        help='Path to save the dataset')
    parser.add_argument(
        '--out_file',
        type=str,
        default='segan.tfrecords',
        help='Output filename')
    parser.add_argument(
        '--force-gen',
        dest='force_gen',
        action='store_true',
        help='Flag to force overwriting existing dataset.')
    # parser.set_defaults(force_gen=False)
    parser.set_defaults(force_gen=True)
    opts = parser.parse_args()
    main(opts)
    '''
    main(flags.FLAGS)

--------------------------------------------------------------------------------
/sasegan/main.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import tensorflow as tf
import numpy as np
#from model import SEGAN, SEAE
from model import SEGAN
import os
from tensorflow.python.client import device_lib
from scipy.io import wavfile
from data_loader import pre_emph

import glob

devices = device_lib.list_local_devices()

flags = tf.app.flags
flags.DEFINE_integer("seed", 111, "Random seed (Def: 111).")
flags.DEFINE_integer("epoch", 150, "Epochs to train (Def: 150).")
flags.DEFINE_integer("batch_size", 150, "Batch size (Def: 150).")
flags.DEFINE_integer("save_freq", 50, "Batch save freq (Def: 50).")
flags.DEFINE_integer("canvas_size", 2**14, "Canvas size (Def: 2^14).")
flags.DEFINE_integer("denoise_epoch", 5, "Epoch where noise in disc is "
                                         "removed (Def: 5).")
flags.DEFINE_integer("l1_remove_epoch", 150, "Epoch where L1 in G is "
                                             "removed (Def: 150).")
flags.DEFINE_boolean("bias_deconv", False,
                     "Flag to specify if we bias deconvs (Def: False)")
flags.DEFINE_boolean("bias_downconv", False,
                     "flag to specify if we bias downconvs (def: false)")
flags.DEFINE_boolean("bias_D_conv", False,
                     "flag to specify if we bias D_convs (def: false)")
# TODO: noise decay is under check
flags.DEFINE_float("denoise_lbound", 0.01,
                   "Min noise std to be still alive (Def: 0.001)")
flags.DEFINE_float("noise_decay", 0.7, "Decay rate of noise std (Def: 0.7)")
flags.DEFINE_float("d_label_smooth", 0.25, "Smooth factor in D (Def: 0.25)")
flags.DEFINE_float("init_noise_std", 0.5, "Init noise std (Def: 0.5)")
flags.DEFINE_float("init_l1_weight", 100., "Init L1 lambda (Def: 100)")
flags.DEFINE_integer("z_dim", 256, "Dimension of input noise to G (Def: 256).")
flags.DEFINE_integer("z_depth", 256, "Depth of input noise to G (Def: 256).")
flags.DEFINE_string("save_path", "segan_results", "Path to save out model "
                                                  "files. (Def: dwavegan_model"
                                                  ").")
flags.DEFINE_string("g_nl", "leaky",
                    "Type of nonlinearity in G: leaky or prelu. (Def: leaky).")
flags.DEFINE_string("model", "gan",
                    "Type of model to train: gan or ae. (Def: gan).")
flags.DEFINE_string("deconv_type", "deconv",
                    "Type of deconv method: deconv or "
                    "nn_deconv (Def: deconv).")
flags.DEFINE_string("g_type", "ae",
                    "Type of G to use: ae or dwave. (Def: ae).")
flags.DEFINE_float("g_learning_rate", 0.0002, "G learning_rate (Def: 0.0002)")
flags.DEFINE_float("d_learning_rate", 0.0002, "D learning_rate (Def: 0.0002)")
flags.DEFINE_float("beta_1", 0.5, "Adam beta 1 (Def: 0.5)")
flags.DEFINE_float("preemph", 0.95, "Pre-emph factor (Def: 0.95)")
flags.DEFINE_string("synthesis_path", "dwavegan_samples", "Path to save output"
                                                          " generated samples."
                                                          " (Def: dwavegan_sam"
                                                          "ples).")
flags.DEFINE_string("e2e_dataset", "data/segan.tfrecords", "TFRecords"
                                                           " (Def: data/"
                                                           "segan.tfrecords.")
flags.DEFINE_string("save_clean_path", "test_clean_results",
                    "Path to save clean utts")
flags.DEFINE_string("test_wav", None, "name of test wav (it won't train)")
flags.DEFINE_string("test_wav_dir", None, "name of test wav directory (it won't train)")
flags.DEFINE_string("weights", None, "Weights file")
flags.DEFINE_string("att_layer_ind", "5", "Layer at which attention takes place (default: '5').")
FLAGS = flags.FLAGS


def pre_emph_test(coeff, canvas_size):
    x_ = tf.placeholder(
        tf.float32, shape=[
            canvas_size,
        ])
    x_preemph = pre_emph(x_, coeff)
    return x_, x_preemph


def main(_):
    print('Parsed arguments: ', FLAGS.__flags)

    # make save path if it is required
    if not os.path.exists(FLAGS.save_path):
        os.makedirs(FLAGS.save_path)
    if not os.path.exists(FLAGS.synthesis_path):
        os.makedirs(FLAGS.synthesis_path)
    np.random.seed(FLAGS.seed)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    udevices = []
    for device in devices:
        if len(devices) > 1 and 'CPU' in device.name:
            # Use cpu only when we don't have gpus
            continue
        print('Using device: ', device.name)
        udevices.append(device.name)
    # execute the session
    with tf.Session(config=config) as sess:
        if FLAGS.model == 'gan':
            print('Creating GAN model')
            se_model = SEGAN(sess, FLAGS, udevices)
        else:
            raise ValueError('{} model type not understood!'.format(
                FLAGS.model))
        if FLAGS.test_wav is None and FLAGS.test_wav_dir is None:
            se_model.train(FLAGS, udevices)
        elif FLAGS.test_wav is not None:  # test 1 file
            if FLAGS.weights is None:
                raise ValueError('weights must be specified!')
            print('Loading model weights...')
            se_model.load(FLAGS.save_path, FLAGS.weights)
            fm, wav_data = wavfile.read(FLAGS.test_wav)
            wavname = FLAGS.test_wav.split('/')[-1]
            if fm != 16000:
                raise ValueError('16kHz required! Test file is different')
            wave = (2. / 65535.) * (wav_data.astype(np.float32) - 32767) + 1.
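            # the scaling above maps int16 PCM to roughly [-1, 1]; the optional
            # pre-emphasis below mirrors pre_emph() in data_loader.py, which was
            # applied to the training data when the TFRecords were created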
            if FLAGS.preemph > 0:
                print('preemph test wave with {}'.format(FLAGS.preemph))
                x_pholder, preemph_op = pre_emph_test(FLAGS.preemph, wave.shape[0])
                wave = sess.run(preemph_op, feed_dict={x_pholder: wave})
            print('test wave shape: ', wave.shape)
            print('test wave min:{} max:{}'.format(np.min(wave), np.max(wave)))
            c_wave = se_model.clean(wave)
            print('c wave min:{} max:{}'.format(np.min(c_wave), np.max(c_wave)))
            wavfile.write(os.path.join(FLAGS.save_clean_path, wavname), int(16e3), c_wave)
            print('Done cleaning {} and saved to {}'.format(FLAGS.test_wav,
                  os.path.join(FLAGS.save_clean_path, wavname)))
        else:  # test 1 directory
            if FLAGS.weights is None:
                raise ValueError('weights must be specified!')
            print('Loading model weights...')
            se_model.load(FLAGS.save_path, FLAGS.weights)

            for test_wav in glob.glob(FLAGS.test_wav_dir + "*.wav"):
                print(test_wav)
                fm, wav_data = wavfile.read(test_wav)
                wavname = test_wav.split('/')[-1]
                if fm != 16000:
                    raise ValueError('16kHz required! Test file is different')
                wave = (2. / 65535.) * (wav_data.astype(np.float32) - 32767) + 1.
                if FLAGS.preemph > 0:
                    print('preemph test wave with {}'.format(FLAGS.preemph))
                    x_pholder, preemph_op = pre_emph_test(FLAGS.preemph, wave.shape[0])
                    wave = sess.run(preemph_op, feed_dict={x_pholder: wave})
                print('test wave shape: ', wave.shape)
                print('test wave min:{} max:{}'.format(np.min(wave), np.max(wave)))
                c_wave = se_model.clean(wave)
                print('c wave min:{} max:{}'.format(np.min(c_wave), np.max(c_wave)))
                wavfile.write(os.path.join(FLAGS.save_clean_path, wavname), int(16e3), c_wave)
                print('Done cleaning {} and saved to {}'.format(test_wav,
                      os.path.join(FLAGS.save_clean_path, wavname)))


if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/evaluate/stoi.m:
--------------------------------------------------------------------------------
function d = stoi(x, y, fs_signal)
% The Short-Time Objective Intelligibility measure
% d = stoi(x, y, fs_signal) returns the output of the short-time
% objective intelligibility (STOI) measure described in [1, 2], where x
% and y denote the clean and processed speech, respectively, with sample
% rate fs_signal in Hz. The output d is expected to have a monotonic
% relation with the subjective speech-intelligibility, where a higher d
% denotes better intelligible speech. See [1, 2] for more details.
%
% References:
%   [1] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
%   Objective Intelligibility Measure for Time-Frequency Weighted Noisy
%   Speech', ICASSP 2010, Texas, Dallas.
%
%   [2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
%   Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
%   IEEE Transactions on Audio, Speech, and Language Processing, 2011.
%
%
% Copyright 2009: Delft University of Technology, Signal & Information
% Processing Lab. The software is free for non-commercial use. This program
% comes WITHOUT ANY WARRANTY.
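%
% Example (16 kHz signals, as used in evaluate_sasegan.m):
%   d = stoi(clean, enhanced, 16000);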
%
%
%
% Updates:
% 2011-04-26 Using the more efficient 'taa_corr' instead of 'corr'

if length(x)~=length(y)
    error('x and y should have the same length');
end

% initialization
x = x(:); % clean speech column vector
y = y(:); % processed speech column vector

fs = 10000; % sample rate of proposed intelligibility measure
N_frame = 256; % window support
K = 512; % FFT size
J = 15; % Number of 1/3 octave bands
mn = 150; % Center frequency of first 1/3 octave band in Hz.
H = thirdoct(fs, K, J, mn); % Get 1/3 octave band matrix
N = 30; % Number of frames for intermediate intelligibility measure (Length analysis window)
Beta = -15; % lower SDR-bound
dyn_range = 40; % speech dynamic range

% resample signals if other samplerate is used than fs
if fs_signal ~= fs
    x = resample(x, fs, fs_signal);
    y = resample(y, fs, fs_signal);
end

% remove silent frames
[x y] = removeSilentFrames(x, y, dyn_range, N_frame, N_frame/2);

% apply 1/3 octave band TF-decomposition
x_hat = stdft(x, N_frame, N_frame/2, K); % apply short-time DFT to clean speech
y_hat = stdft(y, N_frame, N_frame/2, K); % apply short-time DFT to processed speech

x_hat = x_hat(:, 1:(K/2+1)).'; % take clean single-sided spectrum
y_hat = y_hat(:, 1:(K/2+1)).'; % take processed single-sided spectrum

X = zeros(J, size(x_hat, 2)); % init memory for clean speech 1/3 octave band TF-representation
Y = zeros(J, size(y_hat, 2)); % init memory for processed speech 1/3 octave band TF-representation

for i = 1:size(x_hat, 2)
    X(:, i) = sqrt(H*abs(x_hat(:, i)).^2); % apply 1/3 octave bands as described in Eq.(1) [1]
    Y(:, i) = sqrt(H*abs(y_hat(:, i)).^2);
end

% loop all segments of length N and obtain intermediate intelligibility measure for all TF-regions
d_interm = zeros(J, length(N:size(X, 2))); % init memory for intermediate intelligibility measure
c = 10^(-Beta/20); % constant for clipping procedure

for m = N:size(X, 2)
    X_seg = X(:, (m-N+1):m); % region with length N of clean TF-units for all j
    Y_seg = Y(:, (m-N+1):m); % region with length N of processed TF-units for all j
    alpha = sqrt(sum(X_seg.^2, 2)./sum(Y_seg.^2, 2)); % obtain scale factor for normalizing processed TF-region for all j
    aY_seg = Y_seg.*repmat(alpha, [1 N]); % obtain \alpha*Y_j(n) from Eq.(2) [1]
    for j = 1:J
        Y_prime = min(aY_seg(j, :), X_seg(j, :)+X_seg(j, :)*c); % apply clipping from Eq.(3)
        d_interm(j, m-N+1) = taa_corr(X_seg(j, :).', Y_prime(:)); % obtain correlation coefficient from Eq.(4) [1]
    end
end

d = mean(d_interm(:)); % combine all intermediate intelligibility measures as in Eq.(4) [1]

%%
function [A cf] = thirdoct(fs, N_fft, numBands, mn)
% [A CF] = THIRDOCT(FS, N_FFT, NUMBANDS, MN) returns 1/3 octave band matrix
% inputs:
%   FS:       samplerate
%   N_FFT:    FFT size
%   NUMBANDS: number of bands
%   MN:       center frequency of first 1/3 octave band
% outputs:
%   A:        octave band matrix
%   CF:       center frequencies

f = linspace(0, fs, N_fft+1);
f = f(1:(N_fft/2+1));
k = 0:(numBands-1);
cf = 2.^(k/3)*mn;
fl = sqrt((2.^(k/3)*mn).*2.^((k-1)/3)*mn);
fr = sqrt((2.^(k/3)*mn).*2.^((k+1)/3)*mn);
A = zeros(numBands, length(f));

for i = 1:(length(cf))
    [a b] = min((f-fl(i)).^2);
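    % snap the lower band edge fl(i) to the nearest DFT bin (index b)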
    fl(i) = f(b);
    fl_ii = b;

    [a b] = min((f-fr(i)).^2);
    fr(i) = f(b);
    fr_ii = b;
    A(i,fl_ii:(fr_ii-1)) = 1;
end

rnk = sum(A, 2);
numBands = find((rnk(2:end)>=rnk(1:(end-1))) & (rnk(2:end)~=0)~=0, 1, 'last' )+1;
A = A(1:numBands, :);
cf = cf(1:numBands);

%%
function x_stdft = stdft(x, N, K, N_fft)
% X_STDFT = X_STDFT(X, N, K, N_FFT) returns the short-time
% hanning-windowed dft of X with frame-size N, overlap K and DFT size
% N_FFT. The columns and rows of X_STDFT denote the frame-index and
% dft-bin index, respectively.

frames = 1:K:(length(x)-N);
x_stdft = zeros(length(frames), N_fft);

w = hanning(N);
x = x(:);

for i = 1:length(frames)
    ii = frames(i):(frames(i)+N-1);
    x_stdft(i, :) = fft(x(ii).*w, N_fft);
end

%%
function [x_sil y_sil] = removeSilentFrames(x, y, range, N, K)
% [X_SIL Y_SIL] = REMOVESILENTFRAMES(X, Y, RANGE, N, K) X and Y
% are segmented with frame-length N and overlap K, where the maximum energy
% of all frames of X is determined, say X_MAX. X_SIL and Y_SIL are the
% reconstructed signals, excluding the frames, where the energy of a frame
% of X is smaller than X_MAX-RANGE

x = x(:);
y = y(:);

frames = 1:K:(length(x)-N);
w = hanning(N);
msk = zeros(size(frames));

for j = 1:length(frames)
    jj = frames(j):(frames(j)+N-1);
    msk(j) = 20*log10(norm(x(jj).*w)./sqrt(N));
end

msk = (msk-max(msk)+range)>0;
count = 1;

x_sil = zeros(size(x));
y_sil = zeros(size(y));

for j = 1:length(frames)
    if msk(j)
        jj_i = frames(j):(frames(j)+N-1);
        jj_o = frames(count):(frames(count)+N-1);
        x_sil(jj_o) = x_sil(jj_o) + x(jj_i).*w;
        y_sil(jj_o) = y_sil(jj_o) + y(jj_i).*w;
        count = count+1;
    end
end

x_sil = x_sil(1:jj_o(end));
y_sil = y_sil(1:jj_o(end));

%%
function rho = taa_corr(x, y)
% RHO = TAA_CORR(X, Y) Returns correlation coefficient between column
% vectors x and y. Gives same results as 'corr' from statistics toolbox.
xn = x-mean(x);
xn = xn/sqrt(sum(xn.^2));
yn = y-mean(y);
yn = yn/sqrt(sum(yn.^2));
rho = sum(xn.*yn);
--------------------------------------------------------------------------------
/sasegan/selfattention.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.layers import batch_norm, fully_connected, flatten
from tensorflow.contrib.layers import xavier_initializer
from contextlib import contextmanager
import numpy as np

def _l2normalize(v, eps=1e-12):
    """l2 normalize the input vector."""
    return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)

def spectral_normed_weight(weights, num_iters=1, update_collection=None, with_sigma=False):
    """Performs Spectral Normalization on a weight tensor.
    Specifically it divides the weight tensor by its largest singular value. This
    is intended to stabilize GAN training, by making the discriminator satisfy a
    local 1-Lipschitz constraint.
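    The largest singular value is estimated with num_iters rounds of power
    iteration rather than an exact SVD.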
    Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan]
    [sn-gan] https://openreview.net/pdf?id=B1QRgziT-
    Args:
      weights: The weight tensor which requires spectral normalization
      num_iters: Number of SN iterations.
      update_collection: The update collection for assigning persisted variable u.
                         If None, the function will update u during the forward
                         pass. Else if the update_collection equals 'NO_OPS', the
                         function will not update the u during the forward. This
                         is useful for the discriminator, since it does not update
                         u in the second pass.
                         Else, it will put the assignment in a collection
                         defined by the user. Then the user need to run the
                         assignment explicitly.
      with_sigma: For debugging purpose. If True, the function returns
                  the estimated singular value for the weight tensor.
    Returns:
      w_bar: The normalized weight tensor
      sigma: The estimated singular value for the weight tensor.
    """
    w_shape = weights.shape.as_list()
    w_mat = tf.reshape(weights, [-1, w_shape[-1]])  # [-1, output_channel]
    u = tf.get_variable('u', [1, w_shape[-1]],
                        initializer=tf.truncated_normal_initializer(),
                        trainable=False)
    u_ = u
    for _ in range(num_iters):
        v_ = _l2normalize(tf.matmul(u_, w_mat, transpose_b=True))
        u_ = _l2normalize(tf.matmul(v_, w_mat))

    sigma = tf.squeeze(tf.matmul(tf.matmul(v_, w_mat), u_, transpose_b=True))
    w_mat /= sigma
    if update_collection is None:
        with tf.control_dependencies([u.assign(u_)]):
            w_bar = tf.reshape(w_mat, w_shape)
    else:
        w_bar = tf.reshape(w_mat, w_shape)
        if update_collection != 'NO_OPS':
            tf.add_to_collection(update_collection, u.assign(u_))
    if with_sigma:
        return w_bar, sigma
    else:
        return w_bar

def conv1x1(input_, output_dim, init=tf.contrib.layers.xavier_initializer(), name='conv1x1'):
    k_h = 1
    k_w = 1
    d_h = 1
    d_w = 1
    with tf.variable_scope(name):
        w = tf.get_variable(
            'w', [k_h, k_w, input_.get_shape()[-1], output_dim],
            initializer=init)
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
        return conv

def sn_conv1x1(input_, output_dim, update_collection,
               init=tf.contrib.layers.xavier_initializer(), name='sn_conv1x1'):
    with tf.variable_scope(name):
        k_h = 1
        k_w = 1
        d_h = 1
        d_w = 1
        w = tf.get_variable(
            'w', [k_h, k_w, input_.get_shape()[-1], output_dim],
            initializer=init)
        w_bar = spectral_normed_weight(w, num_iters=1, update_collection=update_collection)

        conv = tf.nn.conv2d(input_, w_bar, strides=[1, d_h, d_w, 1], padding='SAME')
        return conv

def sn_non_local_block_sim(x, update_collection, name, init=tf.contrib.layers.xavier_initializer()):
    with tf.variable_scope(name):
        batch_size, h, w, num_channels = x.get_shape().as_list()
        location_num = h * w
        downsampled_num = location_num // 4
        #downsampled_num = location_num

        # theta path
        theta = sn_conv1x1(x, num_channels // 8, update_collection, init, 'sn_conv_theta')
        theta = tf.reshape(theta, [batch_size, location_num, num_channels // 8])

        # phi path
        phi = sn_conv1x1(x, num_channels // 8, update_collection, init, 'sn_conv_phi')
        phi = tf.layers.max_pooling2d(inputs=phi, pool_size=[4, 1], strides=[4,1])
        phi = tf.reshape(phi, [batch_size, downsampled_num, num_channels // 8])

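        # attn[i, j]: similarity between full-resolution step i (theta path)
        # and downsampled step j (phi path), softmaxed into attention weights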
        attn = tf.matmul(theta, phi, transpose_b=True)
        attn = tf.nn.softmax(attn)
        print(tf.reduce_sum(attn, axis=-1))

        # g path
        g = sn_conv1x1(x, num_channels // 2, update_collection, init, 'sn_conv_g')
        g = tf.layers.max_pooling2d(inputs=g, pool_size=[4, 1], strides=[4,1])
        g = tf.reshape(g, [batch_size, downsampled_num, num_channels // 2])

        attn_g = tf.matmul(attn, g)
        attn_g = tf.reshape(attn_g, [batch_size, h, w, num_channels // 2])
        sigma = tf.get_variable('sigma_ratio', [], initializer=tf.constant_initializer(0.0))
        attn_g = sn_conv1x1(attn_g, num_channels, update_collection, init, 'sn_conv_attn')
        return x + sigma * attn_g


def snconv2d(input_, output_dim,
             k_h=3, k_w=3, d_h=1, d_w=1,
             sn_iters=1, update_collection=None, name='snconv2d'):
    """Creates a spectral normalized (SN) convolutional layer.
    Args:
      input_: 4D input tensor (batch size, height, width, channel).
      output_dim: Number of features in the output layer.
      k_h: The height of the convolutional kernel.
      k_w: The width of the convolutional kernel.
      d_h: The height stride of the convolutional kernel.
      d_w: The width stride of the convolutional kernel.
      sn_iters: The number of SN iterations.
      update_collection: The update collection used in spectral_normed_weight.
      name: The name of the variable scope.
    Returns:
      conv: The normalized tensor.
    """
    with tf.variable_scope(name):
        w = tf.get_variable(
            'w', [k_h, k_w, input_.get_shape()[-1], output_dim],
            initializer=tf.contrib.layers.xavier_initializer())
        w_bar = spectral_normed_weight(w, num_iters=sn_iters,
                                       update_collection=update_collection)

        conv = tf.nn.conv2d(input_, w_bar, strides=[1, d_h, d_w, 1], padding='SAME')
        biases = tf.get_variable('biases', [output_dim],
                                 initializer=tf.zeros_initializer())
        conv = tf.nn.bias_add(conv, biases)
        return conv

def sn_downconv(x,
                output_dim,
                kwidth=5,
                pool=2,
                init=None,
                uniform=False,
                bias_init=None,
                name='downconv'):
    """ Downsampled convolution 1d """
    x2d = tf.expand_dims(x, 2)
    w_init = init
    if w_init is None:
        w_init = xavier_initializer(uniform=uniform)
    with tf.variable_scope(name):
        W = tf.get_variable(
            'W', [kwidth, 1, x.get_shape()[-1], output_dim],
            initializer=w_init)
        W_bar = spectral_normed_weight(W)
        conv = tf.nn.conv2d(x2d, W_bar, strides=[1, pool, 1, 1], padding='SAME')
        if bias_init is not None:
            b = tf.get_variable('b', [output_dim], initializer=bias_init)
            conv = tf.reshape(tf.nn.bias_add(conv, b), conv.get_shape())
        else:
            conv = tf.reshape(conv, conv.get_shape())
        # reshape back to 1d
        conv = tf.reshape(
            conv,
            conv.get_shape().as_list()[:2] + [conv.get_shape().as_list()[-1]])
        return conv


def sn_deconv(x,
              output_shape,
              kwidth=5,
              dilation=2,
              init=None,
              uniform=False,
              bias_init=None,
              name='deconv1d'):
    input_shape = x.get_shape()
    in_channels = input_shape[-1]
    out_channels = output_shape[-1]
    assert len(input_shape) >= 3
    # reshape the tensor to use 2d operators
    x2d = tf.expand_dims(x, 2)
    o2d = output_shape[:2] + [1] + [output_shape[-1]]
    w_init = init
    if w_init is None:
        w_init = xavier_initializer(uniform=uniform)
tf.variable_scope(name): 201 | # filter shape: [kwidth, output_channels, in_channels] 202 | W = tf.get_variable('W', [kwidth, 1, out_channels, in_channels], initializer=w_init) 203 | W_bar = spectral_normed_weight(W) 204 | try: 205 | deconv = tf.nn.conv2d_transpose(x2d, W_bar, output_shape=o2d, strides=[1, dilation, 1, 1]) 206 | except AttributeError: 207 | # support for versions of TF before 0.7.0 208 | # based on https://github.com/carpedm20/DCGAN-tensorflow 209 | deconv = tf.nn.conv2d_transpose(x2d, W_bar, output_shape=o2d, strides=[1, dilation, 1, 1]) 210 | if bias_init is not None: 211 | b = tf.get_variable( 212 | 'b', [out_channels], initializer=tf.constant_initializer(0.)) 213 | deconv = tf.reshape(tf.nn.bias_add(deconv, b), deconv.get_shape()) 214 | else: 215 | deconv = tf.reshape(deconv, deconv.get_shape()) 216 | # reshape back to 1d 217 | deconv = tf.reshape(deconv, output_shape) 218 | return deconv -------------------------------------------------------------------------------- /sasegan/generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import tensorflow as tf 3 | from tensorflow.contrib.layers import batch_norm, fully_connected, flatten 4 | from tensorflow.contrib.layers import xavier_initializer 5 | from ops import * 6 | import numpy as np 7 | 8 | from selfattention import * 9 | 10 | 11 | class Generator(object): 12 | def __init__(self, segan): 13 | self.segan = segan 14 | 15 | def __call__(self, noisy_w, is_ref, spk=None): 16 | """ Build the graph propagating (noisy_w) --> x 17 | On first pass will make variables. 18 | """ 19 | segan = self.segan 20 | 21 | def make_z(shape, mean=0., std=1., name='z'): 22 | if is_ref: 23 | with tf.variable_scope(name) as scope: 24 | z_init = tf.random_normal_initializer( 25 | mean=mean, stddev=std) 26 | z = tf.get_variable( 27 | "z", shape, initializer=z_init, trainable=False) 28 | if z.device != "/device:GPU:0": 29 | # this has to be created into gpu0 30 | print('z.device is {}'.format(z.device)) 31 | assert False 32 | else: 33 | z = tf.random_normal( 34 | shape, mean=mean, stddev=std, name=name, dtype=tf.float32) 35 | return z 36 | 37 | if hasattr(segan, 'generator_built'): 38 | tf.get_variable_scope().reuse_variables() 39 | make_vars = False 40 | else: 41 | make_vars = True 42 | 43 | print('*** Building Generator ***') 44 | in_dims = noisy_w.get_shape().as_list() 45 | h_i = noisy_w 46 | if len(in_dims) == 2: 47 | h_i = tf.expand_dims(noisy_w, -1) 48 | elif len(in_dims) < 2 or len(in_dims) > 3: 49 | raise ValueError('Generator input must be 2-D or 3-D') 50 | kwidth = 3 51 | z = make_z([ 52 | segan.batch_size, 53 | h_i.get_shape().as_list()[1], segan.g_enc_depths[-1] 54 | ]) 55 | h_i = tf.concat(2, [h_i, z]) 56 | skip_out = True 57 | skips = [] 58 | for block_idx, dilation in enumerate(segan.g_dilated_blocks): 59 | name = 'g_residual_block_{}'.format(block_idx) 60 | if block_idx >= len(segan.g_dilated_blocks) - 1: 61 | skip_out = False 62 | if skip_out: 63 | res_i, skip_i = residual_block( 64 | h_i, 65 | dilation, 66 | kwidth, 67 | num_kernels=32, 68 | bias_init=None, 69 | stddev=0.02, 70 | do_skip=True, 71 | name=name) 72 | else: 73 | res_i = residual_block( 74 | h_i, 75 | dilation, 76 | kwidth, 77 | num_kernels=32, 78 | bias_init=None, 79 | stddev=0.02, 80 | do_skip=False, 81 | name=name) 82 | # feed the residual output to the next block 83 | h_i = res_i 84 | if segan.keep_prob < 1: 85 | print('Adding dropout w/ keep prob {} ' 86 | 'to 
G'.format(segan.keep_prob)) 87 | h_i = tf.nn.dropout(h_i, segan.keep_prob_var) 88 | if skip_out: 89 | # accumulate the skip connections 90 | skips.append(skip_i) 91 | else: 92 | # for last block, the residual output is appended 93 | skips.append(res_i) 94 | print('Amount of skip connections: ', len(skips)) 95 | # TODO: last pooling for actual wave 96 | with tf.variable_scope('g_wave_pooling'): 97 | skip_T = tf.stack(skips, axis=0) 98 | skips_sum = tf.reduce_sum(skip_T, axis=0) 99 | skips_sum = leakyrelu(skips_sum) 100 | wave_a = conv1d( 101 | skips_sum, 102 | kwidth=1, 103 | num_kernels=1, 104 | init=tf.truncated_normal_initializer(stddev=0.02)) 105 | wave = tf.tanh(wave_a) 106 | ''' 107 | segan.gen_wave_summ = histogram_summary('gen_wave', wave) 108 | ''' 109 | print('Last residual wave shape: ', res_i.get_shape()) 110 | print('*************************') 111 | segan.generator_built = True 112 | return wave, z 113 | 114 | 115 | class AEGenerator(object): 116 | def __init__(self, segan): 117 | self.segan = segan 118 | 119 | def __call__(self, noisy_w, is_ref, spk=None, z_on=True, do_prelu=False): 120 | # TODO: remove c_vec 121 | """ Build the graph propagating (noisy_w) --> x 122 | On first pass will make variables. 123 | """ 124 | segan = self.segan 125 | 126 | def make_z(shape, mean=0., std=1., name='z'): 127 | if is_ref: 128 | with tf.variable_scope(name) as scope: 129 | z_init = tf.random_normal_initializer( 130 | mean=mean, stddev=std) 131 | z = tf.get_variable( 132 | "z", shape, initializer=z_init, trainable=False) 133 | if z.device != "/device:GPU:0": 134 | # this has to be created into gpu0 135 | print('z.device is {}'.format(z.device)) 136 | assert False 137 | else: 138 | z = tf.random_normal( 139 | shape, mean=mean, stddev=std, name=name, dtype=tf.float32) 140 | return z 141 | 142 | if hasattr(segan, 'generator_built'): 143 | tf.get_variable_scope().reuse_variables() 144 | make_vars = False 145 | else: 146 | make_vars = True 147 | if is_ref: 148 | print('*** Building Generator ***') 149 | in_dims = noisy_w.get_shape().as_list() 150 | h_i = noisy_w 151 | if len(in_dims) == 2: 152 | h_i = tf.expand_dims(noisy_w, -1) 153 | elif len(in_dims) < 2 or len(in_dims) > 3: 154 | raise ValueError('Generator input must be 2-D or 3-D') 155 | kwidth = 31 156 | enc_layers = 7 157 | skips = [] 158 | if is_ref and do_prelu: 159 | #keep track of prelu activations 160 | alphas = [] 161 | with tf.variable_scope('g_ae'): 162 | #AE to be built is shaped: 163 | # enc ~ [16384x1, 8192x16, 4096x32, 2048x32, 1024x64, 512x64, 256x128, 128x128, 64x256, 32x256, 16x512, 8x1024] 164 | # dec ~ [8x2048, 16x1024, 32x512, 64x512, 8x256, 256x256, 512x128, 1024x128, 2048x64, 4096x64, 8192x32, 16384x1] 165 | #FIRST ENCODER 166 | for layer_idx, layer_depth in enumerate(segan.g_enc_depths): 167 | bias_init = None 168 | if segan.bias_downconv: 169 | if is_ref: 170 | print('Biasing downconv in G') 171 | bias_init = tf.constant_initializer(0.) 
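                # Each encoder stage below halves the time resolution with a
                # spectrally normalized strided convolution (sn_downconv, stride
                # pool=2) while the channel count follows segan.g_enc_depths,
                # yielding the enc shapes sketched at the top of this scope.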
172 | h_i_dwn = sn_downconv( 173 | h_i, 174 | layer_depth, 175 | kwidth=kwidth, 176 | init=tf.truncated_normal_initializer(stddev=0.02), 177 | bias_init=bias_init, 178 | name='enc_{}'.format(layer_idx)) 179 | if is_ref: 180 | print('Downconv {} -> {}'.format(h_i.get_shape(), h_i_dwn.get_shape())) 181 | h_i = h_i_dwn 182 | if layer_idx < len(segan.g_enc_depths) - 1: 183 | if is_ref: 184 | print('Adding skip connection downconv {}'.format(layer_idx)) 185 | # store skip connection 186 | # last one is not stored cause it's the code 187 | skips.append(h_i) 188 | if do_prelu: 189 | if is_ref: 190 | print('-- Enc: prelu activation --') 191 | h_i = prelu(h_i, ref=is_ref, name='enc_prelu_{}'.format(layer_idx)) 192 | if is_ref: 193 | # split h_i into its components 194 | alpha_i = h_i[1] 195 | h_i = h_i[0] 196 | alphas.append(alpha_i) 197 | else: 198 | if is_ref: 199 | print('-- Enc: leakyrelu activation --') 200 | h_i = leakyrelu(h_i) 201 | 202 | #if layer_idx == segan.att_layer_ind: 203 | if layer_idx in segan.enc_att_layer_ind: 204 | # self-attention 205 | # to 2d 206 | hi_2d = tf.expand_dims(h_i, 2) 207 | hi_2d = sn_non_local_block_sim(hi_2d, None, 'encoder_attention_layer{}'.format(layer_idx)) 208 | h_i = tf.reshape(hi_2d, hi_2d.get_shape().as_list()[:2] + [hi_2d.get_shape().as_list()[-1]]) 209 | print('Downconv: self-attention') 210 | 211 | if z_on: 212 | # random code is fused with intermediate representation 213 | z = make_z([ 214 | segan.batch_size, 215 | h_i.get_shape().as_list()[1], segan.g_enc_depths[-1] 216 | ]) 217 | h_i = tf.concat([z, h_i], 2) 218 | 219 | #SECOND DECODER (reverse order) 220 | g_dec_depths = segan.g_enc_depths[:-1][::-1] + [1] 221 | if is_ref: 222 | print('g_dec_depths: ', g_dec_depths) 223 | for layer_idx, layer_depth in enumerate(g_dec_depths): 224 | h_i_dim = h_i.get_shape().as_list() 225 | out_shape = [h_i_dim[0], h_i_dim[1] * 2, layer_depth] 226 | bias_init = None 227 | # deconv 228 | if segan.deconv_type == 'deconv': 229 | if is_ref: 230 | print('-- Transposed deconvolution type --') 231 | if segan.bias_deconv: 232 | print('Biasing deconv in G') 233 | if segan.bias_deconv: 234 | bias_init = tf.constant_initializer(0.) 
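                # Mirror image of the encoder: each decoder stage doubles the time
                # dimension (out_shape fixes h_i_dim[1] * 2) through a spectrally
                # normalized transposed convolution, and the matching encoder
                # feature map is fused back in as a skip connection further below.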
235 | h_i_dcv = sn_deconv( 236 | h_i, 237 | out_shape, 238 | kwidth=kwidth, 239 | dilation=2, 240 | init=tf.truncated_normal_initializer(stddev=0.02), 241 | bias_init=bias_init, 242 | name='dec_{}'.format(layer_idx)) 243 | else: 244 | raise ValueError('Unknown deconv type {}'.format( 245 | segan.deconv_type)) 246 | if is_ref: 247 | print('Deconv {} -> {}'.format(h_i.get_shape(), h_i_dcv.get_shape())) 248 | h_i = h_i_dcv 249 | if layer_idx < len(g_dec_depths) - 1: 250 | if do_prelu: 251 | if is_ref: 252 | print('-- Dec: prelu activation --') 253 | h_i = prelu(h_i, ref=is_ref, name='dec_prelu_{}'.format(layer_idx)) 254 | if is_ref: 255 | # split h_i into its components 256 | alpha_i = h_i[1] 257 | h_i = h_i[0] 258 | alphas.append(alpha_i) 259 | else: 260 | if is_ref: 261 | print('-- Dec: leakyrelu activation --') 262 | h_i = leakyrelu(h_i) 263 | # fuse skip connection 264 | skip_ = skips[-(layer_idx + 1)] 265 | if is_ref: 266 | print('Fusing skip connection of shape {}'.format(skip_.get_shape())) 267 | h_i = tf.concat([h_i, skip_], 2) 268 | 269 | else: 270 | if is_ref: 271 | print('-- Dec: tanh activation --') 272 | h_i = tf.tanh(h_i) 273 | 274 | # self-attention 275 | #if layer_idx == len(segan.g_enc_depths) - segan.att_layer_ind - 1: # the middle layer (5) 276 | if layer_idx in segan.dec_att_layer_ind: 277 | # self-attention 278 | # to 2d 279 | hi_2d = tf.expand_dims(h_i, 2) 280 | hi_2d = sn_non_local_block_sim(hi_2d, None, 'decoder_attention_layer{}'.format(layer_idx)) 281 | h_i = tf.reshape(hi_2d, hi_2d.get_shape().as_list()[:2] + [hi_2d.get_shape().as_list()[-1]]) 282 | print('Deconv: self-attention') 283 | 284 | wave = h_i 285 | if is_ref and do_prelu: 286 | print('Amount of alpha vectors: ', len(alphas)) 287 | if is_ref: 288 | print('Amount of skip connections: ', len(skips)) 289 | print('Last wave shape: ', wave.get_shape()) 290 | print('*************************') 291 | segan.generator_built = True 292 | # ret feats contains the features refs to be returned 293 | ret_feats = [wave] 294 | if z_on: 295 | ret_feats.append(z) 296 | if is_ref and do_prelu: 297 | ret_feats += alphas 298 | return ret_feats 299 | -------------------------------------------------------------------------------- /sasegan/ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import tensorflow as tf 3 | from tensorflow.contrib.layers import batch_norm, fully_connected, flatten 4 | from tensorflow.contrib.layers import xavier_initializer 5 | from contextlib import contextmanager 6 | import numpy as np 7 | 8 | 9 | def gaussian_noise_layer(input_layer, std): 10 | noise = tf.random_normal( 11 | shape=input_layer.get_shape().as_list(), 12 | mean=0.0, 13 | stddev=std, 14 | dtype=tf.float32) 15 | return input_layer + noise 16 | 17 | 18 | def sample_random_walk(batch_size, dim): 19 | rw = np.zeros((batch_size, dim)) 20 | rw[:, 0] = np.random.randn(batch_size) 21 | for b in range(batch_size): 22 | for di in range(1, dim): 23 | rw[b, di] = rw[b, di - 1] + np.random.randn(1) 24 | # normalize to m=0 std=1 25 | mean = np.mean(rw, axis=1).reshape((-1, 1)) 26 | std = np.std(rw, axis=1).reshape((-1, 1)) 27 | rw = (rw - mean) / std 28 | return rw 29 | 30 | ''' 31 | def scalar_summary(name, x): 32 | try: 33 | summ = tf.summary.scalar(name, x) 34 | except AttributeError: 35 | summ = tf.summary.scalar(name, x) 36 | return summ 37 | 38 | 39 | def histogram_summary(name, x): 40 | try: 41 | summ = tf.summary.histogram(name, x) 42 | except AttributeError: 
43 | summ = tf.summary.histogram(name, x) 44 | return summ 45 | 46 | 47 | def tensor_summary(name, x): 48 | try: 49 | summ = tf.summary.tensor_summary(name, x) 50 | except AttributeError: 51 | summ = tf.summary.tensor_summary(name, x) 52 | return summ 53 | 54 | 55 | def audio_summary(name, x, sampling_rate=16e3): 56 | try: 57 | summ = tf.summary.audio(name, x, sampling_rate) 58 | except AttributeError: 59 | summ = tf.summary.audio(name, x, sampling_rate) 60 | return summ 61 | ''' 62 | 63 | 64 | def minmax_normalize(x, x_min, x_max, o_min=-1., o_max=1.): 65 | return (o_max - o_min) / (x_max - x_min) * (x - x_max) + o_max 66 | 67 | 68 | def minmax_denormalize(x, x_min, x_max, o_min=-1., o_max=1.): 69 | return minmax_normalize(x, o_min, o_max, x_min, x_max) 70 | 71 | 72 | def downconv(x, 73 | output_dim, 74 | kwidth=5, 75 | pool=2, 76 | init=None, 77 | uniform=False, 78 | bias_init=None, 79 | name='downconv'): 80 | """ Downsampled convolution 1d """ 81 | x2d = tf.expand_dims(x, 2) 82 | w_init = init 83 | if w_init is None: 84 | w_init = xavier_initializer(uniform=uniform) 85 | with tf.variable_scope(name): 86 | W = tf.get_variable( 87 | 'W', [kwidth, 1, x.get_shape()[-1], output_dim], 88 | initializer=w_init) 89 | conv = tf.nn.conv2d(x2d, W, strides=[1, pool, 1, 1], padding='SAME') 90 | if bias_init is not None: 91 | b = tf.get_variable('b', [output_dim], initializer=bias_init) 92 | conv = tf.reshape(tf.nn.bias_add(conv, b), conv.get_shape()) 93 | else: 94 | conv = tf.reshape(conv, conv.get_shape()) 95 | # reshape back to 1d 96 | conv = tf.reshape( 97 | conv, 98 | conv.get_shape().as_list()[:2] + [conv.get_shape().as_list()[-1]]) 99 | return conv 100 | 101 | 102 | # https://github.com/carpedm20/lstm-char-cnn-tensorflow/blob/master/models/ops.py 103 | def highway(input_, size, layer_size=1, bias=-2, f=tf.nn.relu, name='hw'): 104 | """Highway Network (cf. http://arxiv.org/abs/1505.00387). 105 | t = sigmoid(Wy + b) 106 | z = t * g(Wy + b) + (1 - t) * y 107 | where g is nonlinearity, t is transform gate, and (1 - t) is carry gate. 108 | """ 109 | output = input_ 110 | for idx in range(layer_size): 111 | lin_scope = '{}_output_lin_{}'.format(name, idx) 112 | output = f(tf.contrib.rnn._linear(output, size, 0, scope=lin_scope)) 113 | transform_scope = '{}_transform_lin_{}'.format(name, idx) 114 | transform_gate = tf.sigmoid( 115 | tf.contrib.rnn._linear(input_, size, 0, scope=transform_scope) + 116 | bias) 117 | carry_gate = 1. 
- transform_gate
118 |
119 |         output = transform_gate * output + carry_gate * input_
120 |
121 |     return output
122 |
123 |
124 | def leakyrelu(x, alpha=0.3, name='lrelu'):
125 |     return tf.maximum(x, alpha * x, name=name)
126 |
127 |
128 | def prelu(x, name='prelu', ref=False):
129 |     in_shape = x.get_shape().as_list()
130 |     with tf.variable_scope(name):
131 |         # make one alpha per feature
132 |         alpha = tf.get_variable(
133 |             'alpha',
134 |             in_shape[-1],
135 |             initializer=tf.constant_initializer(0.),
136 |             dtype=tf.float32)
137 |         pos = tf.nn.relu(x)
138 |         neg = alpha * (x - tf.abs(x)) * .5
139 |         if ref:
140 |             # return ref to alpha vector
141 |             return pos + neg, alpha
142 |         else:
143 |             return pos + neg
144 |
145 |
146 | def conv1d(x,
147 |            kwidth=5,
148 |            num_kernels=1,
149 |            init=None,
150 |            uniform=False,
151 |            bias_init=None,
152 |            name='conv1d',
153 |            padding='SAME'):
154 |     input_shape = x.get_shape()
155 |     in_channels = input_shape[-1]
156 |     assert len(input_shape) >= 3
157 |     w_init = init
158 |     if w_init is None:
159 |         w_init = xavier_initializer(uniform=uniform)
160 |     with tf.variable_scope(name):
161 |         # filter shape: [kwidth, in_channels, num_kernels]
162 |         W = tf.get_variable(
163 |             'W', [kwidth, in_channels, num_kernels], initializer=w_init)
164 |         conv = tf.nn.conv1d(x, W, stride=1, padding=padding)
165 |         if bias_init is not None:
166 |             b = tf.get_variable(
167 |                 'b', [num_kernels],
168 |                 initializer=tf.constant_initializer(bias_init))
169 |             conv = conv + b
170 |         return conv
171 |
172 |
173 | def time_to_batch(value, dilation, name=None):
174 |     with tf.name_scope('time_to_batch'):
175 |         shape = tf.shape(value)
176 |         pad_elements = dilation - 1 - (shape[1] + dilation - 1) % dilation
177 |         padded = tf.pad(value, [[0, 0], [0, pad_elements], [0, 0]])
178 |         reshaped = tf.reshape(padded, [-1, dilation, shape[2]])
179 |         transposed = tf.transpose(reshaped, perm=[1, 0, 2])
180 |         return tf.reshape(transposed, [shape[0] * dilation, -1, shape[2]])
181 |
182 |
183 | # https://github.com/ibab/tensorflow-wavenet/blob/master/wavenet/ops.py
184 | def batch_to_time(value, dilation, name=None):
185 |     with tf.name_scope('batch_to_time'):
186 |         shape = tf.shape(value)
187 |         prepared = tf.reshape(value, [dilation, -1, shape[2]])
188 |         transposed = tf.transpose(prepared, perm=[1, 0, 2])
189 |         return tf.reshape(transposed,
190 |                           [tf.div(shape[0], dilation), -1, shape[2]])
191 |
192 |
193 | def atrous_conv1d(value,
194 |                   dilation,
195 |                   kwidth=3,
196 |                   num_kernels=1,
197 |                   name='atrous_conv1d',
198 |                   bias_init=None,
199 |                   stddev=0.02):
200 |     input_shape = value.get_shape().as_list()
201 |     in_channels = input_shape[-1]
202 |     assert len(input_shape) >= 3
203 |     with tf.variable_scope(name):
204 |         weights_init = tf.truncated_normal_initializer(stddev=stddev)
205 |         # filter shape: [kwidth, in_channels, output_channels]
206 |         filter_ = tf.get_variable(
207 |             'w',
208 |             [kwidth, in_channels, num_kernels],
209 |             initializer=weights_init,
210 |         )
211 |         padding = [[0, 0], [(kwidth // 2) * dilation, (kwidth // 2) * dilation],
212 |                    [0, 0]]
213 |         padded = tf.pad(value, padding, mode='SYMMETRIC')
214 |         if dilation > 1:
215 |             transformed = time_to_batch(padded, dilation)
216 |             conv = tf.nn.conv1d(transformed, filter_, stride=1, padding='SAME')
217 |             restored = batch_to_time(conv, dilation)
218 |         else:
219 |             restored = tf.nn.conv1d(padded, filter_, stride=1, padding='SAME')
220 |         # Remove excess elements at the end.
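        # (the SYMMETRIC tf.pad above added (kwidth // 2) * dilation samples on
        # each side, so the slice below trims the result back to input_shape[1] steps)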
221 | result = tf.slice(restored, [0, 0, 0], 222 | [-1, input_shape[1], num_kernels]) 223 | if bias_init is not None: 224 | b = tf.get_variable( 225 | 'b', [num_kernels], 226 | initializer=tf.constant_initializer(bias_init)) 227 | result = tf.add(result, b) 228 | return result 229 | 230 | 231 | def residual_block(input_, 232 | dilation, 233 | kwidth, 234 | num_kernels=1, 235 | bias_init=None, 236 | stddev=0.02, 237 | do_skip=True, 238 | name='residual_block'): 239 | print('input shape to residual block: ', input_.get_shape()) 240 | with tf.variable_scope(name): 241 | h_a = atrous_conv1d( 242 | input_, 243 | dilation, 244 | kwidth, 245 | num_kernels, 246 | bias_init=bias_init, 247 | stddev=stddev) 248 | h = tf.tanh(h_a) 249 | # apply gated activation 250 | z_a = atrous_conv1d( 251 | input_, 252 | dilation, 253 | kwidth, 254 | num_kernels, 255 | name='conv_gate', 256 | bias_init=bias_init, 257 | stddev=stddev) 258 | z = tf.nn.sigmoid(z_a) 259 | print('gate shape: ', z.get_shape()) 260 | # element-wise apply the gate 261 | gated_h = tf.multiply(z, h) 262 | print('gated h shape: ', gated_h.get_shape()) 263 | #make res connection 264 | h_ = conv1d( 265 | gated_h, 266 | kwidth=1, 267 | num_kernels=1, 268 | init=tf.truncated_normal_initializer(stddev=stddev), 269 | name='residual_conv1') 270 | res = h_ + input_ 271 | print('residual result: ', res.get_shape()) 272 | if do_skip: 273 | #make skip connection 274 | skip = conv1d( 275 | gated_h, 276 | kwidth=1, 277 | num_kernels=1, 278 | init=tf.truncated_normal_initializer(stddev=stddev), 279 | name='skip_conv1') 280 | return res, skip 281 | else: 282 | return res 283 | 284 | 285 | # Code from keras backend 286 | # https://github.com/fchollet/keras/blob/master/keras/backend/tensorflow_backend.py 287 | def repeat_elements(x, rep, axis): 288 | """Repeats the elements of a tensor along an axis, like `np.repeat`. 289 | If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output 290 | will have shape `(s1, s2 * rep, s3)`. 291 | # Arguments 292 | x: Tensor or variable. 293 | rep: Python integer, number of times to repeat. 294 | axis: Axis along which to repeat. 295 | # Raises 296 | ValueError: In case `x.shape[axis]` is undefined. 297 | # Returns 298 | A tensor. 299 | """ 300 | x_shape = x.get_shape().as_list() 301 | if x_shape[axis] is None: 302 | raise ValueError('Axis ' + str(axis) + ' of input tensor ' 303 | 'should have a defined dimension, but is None. ' 304 | 'Full tensor shape: ' + str(tuple(x_shape)) + '. 
'
305 |                          'Typically you need to pass a fully-defined '
306 |                          '`input_shape` argument to your first layer.')
307 |     # slices along the repeat axis
308 |     splits = tf.split(x, x_shape[axis], axis=axis)
309 |     # repeat each slice the given number of reps
310 |     x_rep = [s for s in splits for _ in range(rep)]
311 |     return tf.concat(x_rep, axis)
312 |
313 | def nn_deconv(x,
314 |               kwidth=5,
315 |               dilation=2,
316 |               init=None,
317 |               uniform=False,
318 |               bias_init=None,
319 |               name='nn_deconv1d'):
320 |     # first compute nearest neighbour interpolated x
321 |     interp_x = repeat_elements(x, dilation, 1)
322 |     # run a convolution over the interpolated fmap
323 |     dec = conv1d(
324 |         interp_x,
325 |         kwidth=kwidth,
326 |         num_kernels=1,
327 |         init=init,
328 |         uniform=uniform,
329 |         bias_init=bias_init,
330 |         name=name,
331 |         padding='SAME')
332 |     return dec
333 |
334 |
335 | def deconv(x,
336 |            output_shape,
337 |            kwidth=5,
338 |            dilation=2,
339 |            init=None,
340 |            uniform=False,
341 |            bias_init=None,
342 |            name='deconv1d'):
343 |     input_shape = x.get_shape()
344 |     in_channels = input_shape[-1]
345 |     out_channels = output_shape[-1]
346 |     assert len(input_shape) >= 3
347 |     # reshape the tensor to use 2d operators
348 |     x2d = tf.expand_dims(x, 2)
349 |     o2d = output_shape[:2] + [1] + [output_shape[-1]]
350 |     w_init = init
351 |     if w_init is None:
352 |         w_init = xavier_initializer(uniform=uniform)
353 |     with tf.variable_scope(name):
354 |         # filter shape: [kwidth, output_channels, in_channels]
355 |         W = tf.get_variable(
356 |             'W', [kwidth, 1, out_channels, in_channels], initializer=w_init)
357 |         try:
358 |             deconv = tf.nn.conv2d_transpose(
359 |                 x2d, W, output_shape=o2d, strides=[1, dilation, 1, 1])
360 |         except AttributeError:
361 |             # support for versions of TF before 0.7.0
362 |             # based on https://github.com/carpedm20/DCGAN-tensorflow
363 |             deconv = tf.nn.conv2d_transpose(
364 |                 x2d, W, output_shape=o2d, strides=[1, dilation, 1, 1])
365 |         if bias_init is not None:
366 |             b = tf.get_variable(
367 |                 'b', [out_channels], initializer=bias_init)
368 |             deconv = tf.reshape(tf.nn.bias_add(deconv, b), deconv.get_shape())
369 |         else:
370 |             deconv = tf.reshape(deconv, deconv.get_shape())
371 |         # reshape back to 1d
372 |         deconv = tf.reshape(deconv, output_shape)
373 |         return deconv
374 |
375 | def conv2d(input_,
376 |            output_dim,
377 |            k_h,
378 |            k_w,
379 |            stddev=0.05,
380 |            name="conv2d",
381 |            with_w=False):
382 |     with tf.variable_scope(name):
383 |         w = tf.get_variable(
384 |             'w', [k_h, k_w, input_.get_shape()[-1], output_dim],
385 |             initializer=tf.truncated_normal_initializer(stddev=stddev))
386 |         conv = tf.nn.conv2d(input_, w, strides=[1, 1, 1, 1], padding='VALID')
387 |         if with_w:
388 |             return conv, w
389 |         else:
390 |             return conv
391 |
392 |
393 | # https://github.com/openai/improved-gan/blob/master/imagenet/ops.py
394 | @contextmanager
395 | def variables_on_gpu0():
396 |     old_fn = tf.get_variable
397 |
398 |     def new_fn(*args, **kwargs):
399 |         with tf.device("/gpu:0"):
400 |             return old_fn(*args, **kwargs)
401 |
402 |     tf.get_variable = new_fn
403 |     yield
404 |     tf.get_variable = old_fn
405 |
406 |
407 | def average_gradients(tower_grads):
408 |     """ Calculate the average gradient for each shared variable across towers.
409 |
410 |     Note that this function provides a sync point across all towers.
411 |     Args:
412 |       tower_grads: List of lists of (gradient, variable) tuples. The outer
413 |         list is over individual gradients. The inner list is over the gradient
414 |         calculation for each tower.
415 | Returns: 416 | List of pairs of (gradient, variable) where the gradient has been 417 | averaged across all towers. 418 | """ 419 | 420 | average_grads = [] 421 | for grad_and_vars in zip(*tower_grads): 422 | # each grad is ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN)) 423 | grads = [] 424 | for g, _ in grad_and_vars: 425 | # Add 0 dim to gradients to represent tower 426 | expanded_g = tf.expand_dims(g, 0) 427 | 428 | # Append on a 'tower' dimension that we will average over below 429 | grads.append(expanded_g) 430 | 431 | # Build the tensor and average along tower dimension 432 | grad = tf.concat(grads, 0) 433 | grad = tf.reduce_mean(grad, 0) 434 | 435 | # The Variables are redundant because they are shared across towers 436 | # just return first tower's pointer to the Variable 437 | v = grad_and_vars[0][1] 438 | grad_and_var = (grad, v) 439 | average_grads.append(grad_and_var) 440 | return average_grads 441 | -------------------------------------------------------------------------------- /evaluate/composite.m: -------------------------------------------------------------------------------- 1 | function [Csig,Cbak,Covl]= composite(cleanFile, enhancedFile); 2 | % ---------------------------------------------------------------------- 3 | % Composite Objective Speech Quality Measure 4 | % 5 | % This function implements the composite objective measure proposed in 6 | % [1]. 7 | % 8 | % Usage: [sig,bak,ovl]=composite(cleanFile.wav, enhancedFile.wav) 9 | % 10 | % cleanFile.wav - clean input file in .wav format 11 | % enhancedFile - enhanced output file in .wav format 12 | % sig - predicted rating [1-5] of speech distortion 13 | % bak - predicted rating [1-5] of noise distortion 14 | % ovl - predicted rating [1-5] of overall quality 15 | % 16 | % In addition to the above ratings (sig, bak, & ovl) it returns 17 | % the individual values of the LLR, SNRseg, WSS and PESQ measures. 18 | % 19 | % Example call: [sig,bak,ovl] =composite('sp04.wav','enhanced.wav') 20 | % 21 | % 22 | % References: 23 | % 24 | % [1] Hu, Y. and Loizou, P. (2006). Evaluation of objective measures 25 | % for speech enhancement. Proc. Interspeech, Pittsburg, PA. 26 | % 27 | % Authors: Yi Hu and Philipos C. Loizou 28 | % (the LLR, SNRseg and WSS measures were based on Bryan Pellom and John 29 | % Hansen's implementations) 30 | % 31 | % Copyright (c) 2006 by Philipos C. 
Loizou 32 | % $Revision: 0.0 $ $Date: 10/09/2006 $ 33 | 34 | % ---------------------------------------------------------------------- 35 | 36 | if nargin~=2 37 | fprintf('USAGE: [sig,bak,ovl]=composite(cleanFile.wav, enhancedFile.wav)\n'); 38 | fprintf('For more help, type: help composite\n\n'); 39 | return; 40 | end 41 | 42 | alpha= 0.95; 43 | 44 | % [data1, Srate1, Nbits1]= wavread(cleanFile); 45 | % [data2, Srate2, Nbits2]= wavread(enhancedFile); 46 | % if ( Srate1~= Srate2) | ( Nbits1~= Nbits2) 47 | % error( 'The two files do not match!\n'); 48 | % end 49 | [data1, Srate1]= audioread(cleanFile); 50 | [data2, Srate2]= audioread(enhancedFile); 51 | if ( Srate1~= Srate2) 52 | error( 'The two files do not match!\n'); 53 | end 54 | 55 | len= min( length( data1), length( data2)); 56 | data1= data1( 1: len)+eps; 57 | data2= data2( 1: len)+eps; 58 | 59 | 60 | % -- compute the WSS measure --- 61 | % 62 | wss_dist_vec= wss( data1, data2,Srate1); 63 | wss_dist_vec= sort( wss_dist_vec); 64 | wss_dist= mean( wss_dist_vec( 1: round( length( wss_dist_vec)*alpha))); 65 | 66 | % --- compute the LLR measure --------- 67 | % 68 | LLR_dist= llr( data1, data2,Srate1); 69 | LLRs= sort(LLR_dist); 70 | LLR_len= round( length(LLR_dist)* alpha); 71 | llr_mean= mean( LLRs( 1: LLR_len)); 72 | 73 | % --- compute the SNRseg ---------------- 74 | % 75 | [snr_dist, segsnr_dist]= snr( data1, data2,Srate1); 76 | snr_mean= snr_dist; 77 | segSNR= mean( segsnr_dist); 78 | 79 | 80 | % -- compute the pesq ---- 81 | % 82 | % if Srate1==8000, mode='nb'; 83 | % elseif Srate1 == 16000, mode='wb'; 84 | % else, 85 | % error ('Sampling freq in PESQ needs to be 8 kHz or 16 kHz'); 86 | % end 87 | 88 | 89 | [pesq_mos_scores]= pesq(cleanFile, enhancedFile); 90 | 91 | if length(pesq_mos_scores)==2 92 | pesq_mos=pesq_mos_scores(1); % take the raw PESQ value instead of the 93 | % MOS-mapped value (this composite 94 | % measure was only validated with the raw 95 | % PESQ value) 96 | else 97 | pesq_mos=pesq_mos_scores; 98 | end 99 | 100 | % --- now compute the composite measures ------------------ 101 | % 102 | Csig = 3.093 - 1.029*llr_mean + 0.603*pesq_mos-0.009*wss_dist; 103 | Csig = max(1,Csig); Csig=min(5, Csig); % limit values to [1, 5] 104 | Cbak = 1.634 + 0.478 *pesq_mos - 0.007*wss_dist + 0.063*segSNR; 105 | Cbak = max(1, Cbak); Cbak=min(5,Cbak); % limit values to [1, 5] 106 | Covl = 1.594 + 0.805*pesq_mos - 0.512*llr_mean - 0.007*wss_dist; 107 | Covl = max(1, Covl); Covl=min(5, Covl); % limit values to [1, 5] 108 | 109 | fprintf('\n LLR=%f SNRseg=%f WSS=%f PESQ=%f\n',llr_mean,segSNR,wss_dist,pesq_mos); 110 | 111 | return; %================================================================= 112 | 113 | 114 | function distortion = wss(clean_speech, processed_speech,sample_rate) 115 | 116 | 117 | % ---------------------------------------------------------------------- 118 | % Check the length of the clean and processed speech. Must be the same. 
119 | % ----------------------------------------------------------------------
120 |
121 | clean_length      = length(clean_speech);
122 | processed_length  = length(processed_speech);
123 |
124 | if (clean_length ~= processed_length)
125 |   disp('Error: Files must have the same length.');
126 |   return
127 | end
128 |
129 |
130 |
131 | % ----------------------------------------------------------------------
132 | % Global Variables
133 | % ----------------------------------------------------------------------
134 |
135 | winlength = round(30*sample_rate/1000); %240;  % window length in samples
136 | skiprate  = floor(winlength/4);                % window skip in samples
137 | max_freq  = sample_rate/2;                     % maximum bandwidth
138 | num_crit  = 25;                                % number of critical bands
139 |
140 | USE_FFT_SPECTRUM = 1;       % if set to 0, a 10th order LP spectrum is used instead
141 | n_fft    = 2^nextpow2(2*winlength);
142 | n_fftby2 = n_fft/2;         % FFT size/2
143 | Kmax     = 20;              % value suggested by Klatt, pg 1280
144 | Klocmax  = 1;               % value suggested by Klatt, pg 1280
145 |
146 | % ----------------------------------------------------------------------
147 | % Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz)
148 | % ----------------------------------------------------------------------
149 |
150 | cent_freq(1)  = 50.0000;   bandwidth(1)  = 70.0000;
151 | cent_freq(2)  = 120.000;   bandwidth(2)  = 70.0000;
152 | cent_freq(3)  = 190.000;   bandwidth(3)  = 70.0000;
153 | cent_freq(4)  = 260.000;   bandwidth(4)  = 70.0000;
154 | cent_freq(5)  = 330.000;   bandwidth(5)  = 70.0000;
155 | cent_freq(6)  = 400.000;   bandwidth(6)  = 70.0000;
156 | cent_freq(7)  = 470.000;   bandwidth(7)  = 70.0000;
157 | cent_freq(8)  = 540.000;   bandwidth(8)  = 77.3724;
158 | cent_freq(9)  = 617.372;   bandwidth(9)  = 86.0056;
159 | cent_freq(10) = 703.378;   bandwidth(10) = 95.3398;
160 | cent_freq(11) = 798.717;   bandwidth(11) = 105.411;
161 | cent_freq(12) = 904.128;   bandwidth(12) = 116.256;
162 | cent_freq(13) = 1020.38;   bandwidth(13) = 127.914;
163 | cent_freq(14) = 1148.30;   bandwidth(14) = 140.423;
164 | cent_freq(15) = 1288.72;   bandwidth(15) = 153.823;
165 | cent_freq(16) = 1442.54;   bandwidth(16) = 168.154;
166 | cent_freq(17) = 1610.70;   bandwidth(17) = 183.457;
167 | cent_freq(18) = 1794.16;   bandwidth(18) = 199.776;
168 | cent_freq(19) = 1993.93;   bandwidth(19) = 217.153;
169 | cent_freq(20) = 2211.08;   bandwidth(20) = 235.631;
170 | cent_freq(21) = 2446.71;   bandwidth(21) = 255.255;
171 | cent_freq(22) = 2701.97;   bandwidth(22) = 276.072;
172 | cent_freq(23) = 2978.04;   bandwidth(23) = 298.126;
173 | cent_freq(24) = 3276.17;   bandwidth(24) = 321.465;
174 | cent_freq(25) = 3597.63;   bandwidth(25) = 346.136;
175 |
176 | bw_min = bandwidth (1);     % minimum critical bandwidth
177 |
178 | % ----------------------------------------------------------------------
179 | % Set up the critical band filters.  Note here that Gaussianly shaped
180 | % filters are used.  Also, the sum of the filter weights is equivalent
181 | % for each critical band filter.  Filter values less than -30 dB are set
182 | % to zero.
183 | % ---------------------------------------------------------------------- 184 | 185 | min_factor = exp (-30.0 / (2.0 * 2.303)); % -30 dB point of filter 186 | 187 | for i = 1:num_crit 188 | f0 = (cent_freq (i) / max_freq) * (n_fftby2); 189 | all_f0(i) = floor(f0); 190 | bw = (bandwidth (i) / max_freq) * (n_fftby2); 191 | norm_factor = log(bw_min) - log(bandwidth(i)); 192 | j = 0:1:n_fftby2-1; 193 | crit_filter(i,:) = exp (-11 *(((j - floor(f0)) ./bw).^2) + norm_factor); 194 | crit_filter(i,:) = crit_filter(i,:).*(crit_filter(i,:) > min_factor); 195 | end 196 | 197 | % ---------------------------------------------------------------------- 198 | % For each frame of input speech, calculate the Weighted Spectral 199 | % Slope Measure 200 | % ---------------------------------------------------------------------- 201 | 202 | num_frames = clean_length/skiprate-(winlength/skiprate); % number of frames 203 | start = 1; % starting sample 204 | window = 0.5*(1 - cos(2*pi*(1:winlength)'/(winlength+1))); 205 | 206 | for frame_count = 1:num_frames 207 | 208 | % ---------------------------------------------------------- 209 | % (1) Get the Frames for the test and reference speech. 210 | % Multiply by Hanning Window. 211 | % ---------------------------------------------------------- 212 | 213 | clean_frame = clean_speech(start:start+winlength-1); 214 | processed_frame = processed_speech(start:start+winlength-1); 215 | clean_frame = clean_frame.*window; 216 | processed_frame = processed_frame.*window; 217 | 218 | % ---------------------------------------------------------- 219 | % (2) Compute the Power Spectrum of Clean and Processed 220 | % ---------------------------------------------------------- 221 | 222 | if (USE_FFT_SPECTRUM) 223 | clean_spec = (abs(fft(clean_frame,n_fft)).^2); 224 | processed_spec = (abs(fft(processed_frame,n_fft)).^2); 225 | else 226 | a_vec = zeros(1,n_fft); 227 | a_vec(1:11) = lpc(clean_frame,10); 228 | clean_spec = 1.0/(abs(fft(a_vec,n_fft)).^2)'; 229 | 230 | a_vec = zeros(1,n_fft); 231 | a_vec(1:11) = lpc(processed_frame,10); 232 | processed_spec = 1.0/(abs(fft(a_vec,n_fft)).^2)'; 233 | end 234 | 235 | % ---------------------------------------------------------- 236 | % (3) Compute Filterbank Output Energies (in dB scale) 237 | % ---------------------------------------------------------- 238 | 239 | for i = 1:num_crit 240 | clean_energy(i) = sum(clean_spec(1:n_fftby2) ... 241 | .*crit_filter(i,:)'); 242 | processed_energy(i) = sum(processed_spec(1:n_fftby2) ... 243 | .*crit_filter(i,:)'); 244 | end 245 | clean_energy = 10*log10(max(clean_energy,1E-10)); 246 | processed_energy = 10*log10(max(processed_energy,1E-10)); 247 | 248 | % ---------------------------------------------------------- 249 | % (4) Compute Spectral Slope (dB[i+1]-dB[i]) 250 | % ---------------------------------------------------------- 251 | 252 | clean_slope = clean_energy(2:num_crit) - ... 253 | clean_energy(1:num_crit-1); 254 | processed_slope = processed_energy(2:num_crit) - ... 255 | processed_energy(1:num_crit-1); 256 | 257 | % ---------------------------------------------------------- 258 | % (5) Find the nearest peak locations in the spectra to 259 | % each critical band. If the slope is negative, we 260 | % search to the left. If positive, we search to the 261 | % right. 
262 |     % ----------------------------------------------------------
263 |
264 |     for i = 1:num_crit-1
265 |
266 |         % find the peaks in the clean speech signal
267 |
268 |         if (clean_slope(i)>0)       % search to the right
269 |             n = i;
270 |             while ((n<num_crit) & (clean_slope(n) > 0))
271 |                 n = n+1;
272 |             end
273 |             clean_loc_peak(i) = clean_energy(n-1);
274 |         else                        % search to the left
275 |             n = i;
276 |             while ((n>0) & (clean_slope(n) <= 0))
277 |                 n = n-1;
278 |             end
279 |             clean_loc_peak(i) = clean_energy(n+1);
280 |         end
281 |
282 |         % find the peaks in the processed speech signal
283 |
284 |         if (processed_slope(i)>0)   % search to the right
285 |             n = i;
286 |             while ((n<num_crit) & (processed_slope(n) > 0))
287 |                 n = n+1;
288 |             end
289 |             processed_loc_peak(i) = processed_energy(n-1);
290 |         else                        % search to the left
291 |             n = i;
292 |             while ((n>0) & (processed_slope(n) <= 0))
293 |                 n = n-1;
294 |             end
295 |             processed_loc_peak(i) = processed_energy(n+1);
296 |         end
297 |
298 |     end
299 |
300 |     % ----------------------------------------------------------
301 |     % (6) Compute the WSS Measure for this frame.  This
302 |     %     includes determination of the weighting function.
303 |     % ----------------------------------------------------------
304 |
305 |     dBMax_clean     = max(clean_energy);
306 |     dBMax_processed = max(processed_energy);
307 |
308 |     % The weights are calculated by averaging individual
309 |     % weighting factors from the clean and processed frame.
310 |     % These weights W_clean and W_processed should range
311 |     % from 0 to 1 and place more emphasis on spectral
312 |     % peaks and less emphasis on slope differences in spectral
313 |     % valleys.  This procedure is described on page 1280 of
314 |     % Klatt's 1982 ICASSP paper.
315 |
316 |     Wmax_clean        = Kmax ./ (Kmax + dBMax_clean - ...
317 |                         clean_energy(1:num_crit-1));
318 |     Wlocmax_clean     = Klocmax ./ ( Klocmax + clean_loc_peak - ...
319 |                         clean_energy(1:num_crit-1));
320 |     W_clean           = Wmax_clean .* Wlocmax_clean;
321 |
322 |     Wmax_processed    = Kmax ./ (Kmax + dBMax_processed - ...
323 |                         processed_energy(1:num_crit-1));
324 |     Wlocmax_processed = Klocmax ./ ( Klocmax + processed_loc_peak - ...
325 |                         processed_energy(1:num_crit-1));
326 |     W_processed       = Wmax_processed .* Wlocmax_processed;
327 |
328 |     W = (W_clean + W_processed)./2.0;
329 |
330 |     distortion(frame_count) = sum(W.*(clean_slope(1:num_crit-1) - ...
331 |                               processed_slope(1:num_crit-1)).^2);
332 |
333 |     % this normalization is not part of Klatt's paper, but helps
334 |     % to normalize the measure.  Here we scale the measure by the
335 |     % sum of the weights.
336 |
337 |     distortion(frame_count) = distortion(frame_count)/sum(W);
338 |
339 |     start = start + skiprate;
340 |
341 | end
342 |
343 | %-----------------------------------------------
344 | function distortion = llr(clean_speech, processed_speech, sample_rate)
345 |
346 |
347 | % ----------------------------------------------------------------------
348 | % Check the length of the clean and processed speech.  Must be the same.
349 | % ---------------------------------------------------------------------- 350 | 351 | clean_length = length(clean_speech); 352 | processed_length = length(processed_speech); 353 | 354 | if (clean_length ~= processed_length) 355 | disp('Error: Both Speech Files must be same length.'); 356 | return 357 | end 358 | 359 | % ---------------------------------------------------------------------- 360 | % Global Variables 361 | % ---------------------------------------------------------------------- 362 | 363 | winlength = round(30*sample_rate/1000); % window length in samples 364 | skiprate = floor(winlength/4); % window skip in samples 365 | if sample_rate<10000 366 | P = 10; % LPC Analysis Order 367 | else 368 | P=16; % this could vary depending on sampling frequency. 369 | end 370 | 371 | % ---------------------------------------------------------------------- 372 | % For each frame of input speech, calculate the Log Likelihood Ratio 373 | % ---------------------------------------------------------------------- 374 | 375 | num_frames = clean_length/skiprate-(winlength/skiprate); % number of frames 376 | start = 1; % starting sample 377 | window = 0.5*(1 - cos(2*pi*(1:winlength)'/(winlength+1))); 378 | 379 | for frame_count = 1:num_frames 380 | 381 | % ---------------------------------------------------------- 382 | % (1) Get the Frames for the test and reference speech. 383 | % Multiply by Hanning Window. 384 | % ---------------------------------------------------------- 385 | 386 | clean_frame = clean_speech(start:start+winlength-1); 387 | processed_frame = processed_speech(start:start+winlength-1); 388 | clean_frame = clean_frame.*window; 389 | processed_frame = processed_frame.*window; 390 | 391 | % ---------------------------------------------------------- 392 | % (2) Get the autocorrelation lags and LPC parameters used 393 | % to compute the LLR measure. 394 | % ---------------------------------------------------------- 395 | 396 | [R_clean, Ref_clean, A_clean] = ... 397 | lpcoeff(clean_frame, P); 398 | [R_processed, Ref_processed, A_processed] = ... 399 | lpcoeff(processed_frame, P); 400 | 401 | % ---------------------------------------------------------- 402 | % (3) Compute the LLR measure 403 | % ---------------------------------------------------------- 404 | 405 | numerator = A_processed*toeplitz(R_clean)*A_processed'; 406 | denominator = A_clean*toeplitz(R_clean)*A_clean'; 407 | distortion(frame_count) = log(numerator/denominator); 408 | start = start + skiprate; 409 | 410 | end 411 | 412 | %--------------------------------------------- 413 | function [acorr, refcoeff, lpparams] = lpcoeff(speech_frame, model_order) 414 | 415 | % ---------------------------------------------------------- 416 | % (1) Compute Autocorrelation Lags 417 | % ---------------------------------------------------------- 418 | 419 | winlength = max(size(speech_frame)); 420 | for k=1:model_order+1 421 | R(k) = sum(speech_frame(1:winlength-k+1) ... 
422 | .*speech_frame(k:winlength)); 423 | end 424 | 425 | % ---------------------------------------------------------- 426 | % (2) Levinson-Durbin 427 | % ---------------------------------------------------------- 428 | 429 | a = ones(1,model_order); 430 | E(1)=R(1); 431 | for i=1:model_order 432 | a_past(1:i-1) = a(1:i-1); 433 | sum_term = sum(a_past(1:i-1).*R(i:-1:2)); 434 | rcoeff(i)=(R(i+1) - sum_term) / E(i); 435 | a(i)=rcoeff(i); 436 | a(1:i-1) = a_past(1:i-1) - rcoeff(i).*a_past(i-1:-1:1); 437 | E(i+1)=(1-rcoeff(i)*rcoeff(i))*E(i); 438 | end 439 | 440 | acorr = R; 441 | refcoeff = rcoeff; 442 | lpparams = [1 -a]; 443 | 444 | 445 | % ---------------------------------------------------------------------- 446 | 447 | function [overall_snr, segmental_snr] = snr(clean_speech, processed_speech,sample_rate) 448 | 449 | % ---------------------------------------------------------------------- 450 | % Check the length of the clean and processed speech. Must be the same. 451 | % ---------------------------------------------------------------------- 452 | 453 | clean_length = length(clean_speech); 454 | processed_length = length(processed_speech); 455 | 456 | if (clean_length ~= processed_length) 457 | disp('Error: Both Speech Files must be same length.'); 458 | return 459 | end 460 | 461 | % ---------------------------------------------------------------------- 462 | % Scale both clean speech and processed speech to have same dynamic 463 | % range. Also remove DC component from each signal 464 | % ---------------------------------------------------------------------- 465 | 466 | %clean_speech = clean_speech - mean(clean_speech); 467 | %processed_speech = processed_speech - mean(processed_speech); 468 | 469 | %processed_speech = processed_speech.*(max(abs(clean_speech))/ max(abs(processed_speech))); 470 | 471 | overall_snr = 10* log10( sum(clean_speech.^2)/sum((clean_speech-processed_speech).^2)); 472 | 473 | % ---------------------------------------------------------------------- 474 | % Global Variables 475 | % ---------------------------------------------------------------------- 476 | 477 | winlength = round(30*sample_rate/1000); %240; % window length in samples 478 | skiprate = floor(winlength/4); % window skip in samples 479 | MIN_SNR = -10; % minimum SNR in dB 480 | MAX_SNR = 35; % maximum SNR in dB 481 | 482 | % ---------------------------------------------------------------------- 483 | % For each frame of input speech, calculate the Segmental SNR 484 | % ---------------------------------------------------------------------- 485 | 486 | num_frames = clean_length/skiprate-(winlength/skiprate); % number of frames 487 | start = 1; % starting sample 488 | window = 0.5*(1 - cos(2*pi*(1:winlength)'/(winlength+1))); 489 | 490 | for frame_count = 1: num_frames 491 | 492 | % ---------------------------------------------------------- 493 | % (1) Get the Frames for the test and reference speech. 494 | % Multiply by Hanning Window. 
495 | % ---------------------------------------------------------- 496 | 497 | clean_frame = clean_speech(start:start+winlength-1); 498 | processed_frame = processed_speech(start:start+winlength-1); 499 | clean_frame = clean_frame.*window; 500 | processed_frame = processed_frame.*window; 501 | 502 | % ---------------------------------------------------------- 503 | % (2) Compute the Segmental SNR 504 | % ---------------------------------------------------------- 505 | 506 | signal_energy = sum(clean_frame.^2); 507 | noise_energy = sum((clean_frame-processed_frame).^2); 508 | segmental_snr(frame_count) = 10*log10(signal_energy/(noise_energy+eps)+eps); 509 | segmental_snr(frame_count) = max(segmental_snr(frame_count),MIN_SNR); 510 | segmental_snr(frame_count) = min(segmental_snr(frame_count),MAX_SNR); 511 | 512 | start = start + skiprate; 513 | 514 | end 515 | 516 | 517 | 518 | -------------------------------------------------------------------------------- /sasegan/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import tensorflow as tf 3 | from tensorflow.contrib.layers import batch_norm, fully_connected, flatten 4 | from tensorflow.contrib.layers import xavier_initializer 5 | from scipy.io import wavfile 6 | from generator import * 7 | from discriminator import * 8 | import numpy as np 9 | from data_loader import read_and_decode, de_emph 10 | from bnorm import VBN 11 | from ops import * 12 | import timeit 13 | import os 14 | import shutil 15 | 16 | 17 | class Model(object): 18 | def __init__(self, name='BaseModel'): 19 | self.name = name 20 | 21 | def save(self, save_path, step): 22 | model_name = self.name 23 | if not os.path.exists(save_path): 24 | os.makedirs(save_path) 25 | if not hasattr(self, 'saver'): 26 | self.saver = tf.train.Saver() 27 | self.saver.save( 28 | self.sess, os.path.join(save_path, model_name), global_step=step) 29 | 30 | def load(self, save_path, model_file=None): 31 | if not os.path.exists(save_path): 32 | print('[!] Checkpoints path does not exist...') 33 | return False 34 | print('[*] Reading checkpoints...') 35 | if model_file is None: 36 | ckpt = tf.train.get_checkpoint_state(save_path) 37 | if ckpt and ckpt.model_checkpoint_path: 38 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 39 | else: 40 | return False 41 | else: 42 | ckpt_name = model_file 43 | if not hasattr(self, 'saver'): 44 | self.saver = tf.train.Saver() 45 | self.saver.restore(self.sess, os.path.join(save_path, ckpt_name)) 46 | print('[*] Read {}'.format(ckpt_name)) 47 | return True 48 | 49 | 50 | class SEGAN(Model): 51 | """ Speech Enhancement Generative Adversarial Network """ 52 | 53 | def __init__(self, sess, args, devices, infer=False, name='SEGAN'): 54 | super(SEGAN, self).__init__(name) 55 | self.args = args 56 | self.sess = sess 57 | self.keep_prob = 1. 
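        # dropout keep probability: 1.0 (dropout disabled) at inference time,
        # overridden to 0.5 below when training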
58 | if infer: 59 | self.keep_prob_var = tf.Variable(self.keep_prob, trainable=False) 60 | else: 61 | self.keep_prob = 0.5 62 | self.keep_prob_var = tf.Variable(self.keep_prob, trainable=False) 63 | self.batch_size = args.batch_size 64 | self.epoch = args.epoch 65 | self.d_label_smooth = args.d_label_smooth 66 | self.devices = devices 67 | self.z_dim = args.z_dim 68 | self.z_depth = args.z_depth 69 | # type of deconv 70 | self.deconv_type = args.deconv_type 71 | # specify if use biases or not 72 | self.bias_downconv = args.bias_downconv 73 | self.bias_deconv = args.bias_deconv 74 | self.bias_D_conv = args.bias_D_conv 75 | # clip D values 76 | self.d_clip_weights = False 77 | # apply VBN or regular BN? 78 | self.disable_vbn = False 79 | self.save_path = args.save_path 80 | # num of updates to be applied to D before G 81 | # this is k in original GAN paper (https://arxiv.org/abs/1406.2661) 82 | self.disc_updates = 1 83 | # set preemph factor 84 | self.preemph = args.preemph 85 | if self.preemph > 0: 86 | print('*** Applying pre-emphasis of {} ***'.format(self.preemph)) 87 | else: 88 | print('--- No pre-emphasis applied ---') 89 | # canvas size 90 | self.canvas_size = args.canvas_size 91 | self.deactivated_noise = False 92 | # dilation factors per layer (only in atrous conv G config) 93 | self.g_dilated_blocks = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] 94 | # num fmaps for AutoEncoder SEGAN (v1) 95 | self.g_enc_depths = [16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024] 96 | # Define D fmaps 97 | self.d_num_fmaps = [16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024] 98 | self.init_noise_std = args.init_noise_std 99 | self.disc_noise_std = tf.Variable(self.init_noise_std, trainable=False) 100 | 101 | # encoder: [16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024] (the last layer is the bottleneck) 102 | # decoder: [512, 256, 256, 128, 128, 64, 64, 32, 32, 16, 1] (the last layer is channel-wise convolution) 103 | # indices of encoder's self-attention conv layers 104 | self.enc_att_layer_ind = list(map(int, args.att_layer_ind.split(","))) 105 | # indices of decoder's self-attention conv layers 106 | self.dec_att_layer_ind = [(len(self.g_enc_depths) - ind - 2) for ind in self.enc_att_layer_ind if ind < (len(self.g_enc_depths) - 1)] 107 | 108 | self.e2e_dataset = args.e2e_dataset 109 | # G's supervised loss weight 110 | self.l1_weight = args.init_l1_weight 111 | self.l1_lambda = tf.Variable(self.l1_weight, trainable=False) 112 | self.deactivated_l1 = False 113 | # define the functions 114 | self.discriminator = discriminator 115 | # register G non linearity 116 | self.g_nl = args.g_nl 117 | if args.g_type == 'ae': 118 | self.generator = AEGenerator(self) 119 | elif args.g_type == 'dwave': 120 | self.generator = Generator(self) 121 | else: 122 | raise ValueError('Unrecognized G type {}'.format(args.g_type)) 123 | self.build_model(args) 124 | 125 | def build_model(self, config): 126 | all_d_grads = [] 127 | all_g_grads = [] 128 | d_opt = tf.train.RMSPropOptimizer(config.d_learning_rate) 129 | g_opt = tf.train.RMSPropOptimizer(config.g_learning_rate) 130 | # d_opt = tf.train.AdamOptimizer( 131 | # config.d_learning_rate, beta1=config.beta_1) 132 | # g_opt = tf.train.AdamOptimizer( 133 | # config.g_learning_rate, beta1=config.beta_1) 134 | 135 | with tf.variable_scope(tf.get_variable_scope()) as scope: 136 | for idx, device in enumerate(self.devices): 137 | with tf.device("/%s" % device): 138 | with tf.name_scope("device_%s" % idx): 139 | with variables_on_gpu0(): 140 | self.build_model_single_gpu(idx) 
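                        # gradients are computed per tower here; they are averaged
                        # across devices with average_gradients() before being applied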
141 | d_grads = d_opt.compute_gradients( 142 | self.d_losses[-1], var_list=list(self.d_vars)) 143 | g_grads = g_opt.compute_gradients( 144 | self.g_losses[-1], var_list=list(self.g_vars)) 145 | all_d_grads.append(d_grads) 146 | all_g_grads.append(g_grads) 147 | # tf.get_variable_scope().reuse_variables() 148 | avg_d_grads = average_gradients(all_d_grads) 149 | avg_g_grads = average_gradients(all_g_grads) 150 | self.d_opt = d_opt.apply_gradients(avg_d_grads) 151 | self.g_opt = g_opt.apply_gradients(avg_g_grads) 152 | 153 | def build_model_single_gpu(self, gpu_idx): 154 | if gpu_idx == 0: 155 | # create the nodes to load for input pipeline 156 | filename_queue = tf.train.string_input_producer([self.e2e_dataset]) 157 | self.get_wav, self.get_noisy = read_and_decode( 158 | filename_queue, self.canvas_size, self.preemph) 159 | # load the data to input pipeline 160 | wavbatch, \ 161 | noisybatch = tf.train.shuffle_batch([self.get_wav, 162 | self.get_noisy], 163 | batch_size=self.batch_size, 164 | num_threads=2, 165 | capacity=1000 + 3 * self.batch_size, 166 | min_after_dequeue=1000, 167 | name='wav_and_noisy') 168 | if gpu_idx == 0: 169 | self.Gs = [] 170 | self.zs = [] 171 | self.gtruth_wavs = [] 172 | self.gtruth_noisy = [] 173 | 174 | self.gtruth_wavs.append(wavbatch) 175 | self.gtruth_noisy.append(noisybatch) 176 | 177 | # add channels dimension to manipulate in D and G 178 | wavbatch = tf.expand_dims(wavbatch, -1) 179 | noisybatch = tf.expand_dims(noisybatch, -1) 180 | # by default leaky relu is used 181 | do_prelu = False 182 | if self.g_nl == 'prelu': 183 | do_prelu = True 184 | if gpu_idx == 0: 185 | #self.sample_wavs = tf.placeholder(tf.float32, [self.batch_size, 186 | # self.canvas_size], 187 | # name='sample_wavs') 188 | ref_Gs = self.generator( 189 | noisybatch, is_ref=True, spk=None, do_prelu=do_prelu) 190 | print('num of G returned: ', len(ref_Gs)) 191 | self.reference_G = ref_Gs[0] 192 | self.ref_z = ref_Gs[1] 193 | if do_prelu: 194 | self.ref_alpha = ref_Gs[2:] 195 | 196 | # make a dummy copy of discriminator to have variables and then 197 | # be able to set up the variable reuse for all other devices 198 | # merge along channels and this would be a real batch 199 | dummy_joint = tf.concat([wavbatch, noisybatch], 2) 200 | dummy = discriminator(self, dummy_joint, reuse=False) 201 | 202 | G, z = self.generator( 203 | noisybatch, is_ref=False, spk=None, do_prelu=do_prelu) 204 | self.Gs.append(G) 205 | self.zs.append(z) 206 | 207 | # add new dimension to merge with other pairs 208 | D_rl_joint = tf.concat([wavbatch, noisybatch], 2) 209 | D_fk_joint = tf.concat([G, noisybatch], 2) 210 | # build rl discriminator 211 | d_rl_logits = discriminator(self, D_rl_joint, reuse=True) 212 | # build fk G discriminator 213 | d_fk_logits = discriminator(self, D_fk_joint, reuse=True) 214 | 215 | 216 | if gpu_idx == 0: 217 | self.g_losses = [] 218 | self.g_l1_losses = [] 219 | self.g_adv_losses = [] 220 | self.d_rl_losses = [] 221 | self.d_fk_losses = [] 222 | #self.d_nfk_losses = [] 223 | self.d_losses = [] 224 | 225 | d_rl_loss = tf.reduce_mean(tf.squared_difference(d_rl_logits, 1.)) 226 | d_fk_loss = tf.reduce_mean(tf.squared_difference(d_fk_logits, 0.)) 227 | #d_nfk_loss = tf.reduce_mean(tf.squared_difference(d_nfk_logits, 0.)) 228 | g_adv_loss = tf.reduce_mean(tf.squared_difference(d_fk_logits, 1.)) 229 | 230 | d_loss = d_rl_loss + d_fk_loss 231 | 232 | # Add the L1 loss to G 233 | g_l1_loss = self.l1_lambda * tf.reduce_mean( 234 | tf.abs(tf.subtract(G, wavbatch))) 235 | 236 | g_loss = 
g_adv_loss + g_l1_loss 237 | 238 | self.g_l1_losses.append(g_l1_loss) 239 | self.g_adv_losses.append(g_adv_loss) 240 | self.g_losses.append(g_loss) 241 | self.d_rl_losses.append(d_rl_loss) 242 | self.d_fk_losses.append(d_fk_loss) 243 | #self.d_nfk_losses.append(d_nfk_loss) 244 | self.d_losses.append(d_loss) 245 | 246 | if gpu_idx == 0: 247 | self.get_vars() 248 | 249 | def get_vars(self): 250 | t_vars = tf.trainable_variables() 251 | self.d_vars_dict = {} 252 | self.g_vars_dict = {} 253 | for var in t_vars: 254 | if var.name.startswith('d_'): 255 | self.d_vars_dict[var.name] = var 256 | if var.name.startswith('g_'): 257 | self.g_vars_dict[var.name] = var 258 | self.d_vars = self.d_vars_dict.values() 259 | self.g_vars = self.g_vars_dict.values() 260 | for x in self.d_vars: 261 | assert x not in self.g_vars 262 | for x in self.g_vars: 263 | assert x not in self.d_vars 264 | for x in t_vars: 265 | assert x in self.g_vars or x in self.d_vars, x.name 266 | self.all_vars = t_vars 267 | if self.d_clip_weights: 268 | print('Clipping D weights') 269 | self.d_clip = [ 270 | v.assign(tf.clip_by_value(v, -0.05, 0.05)) for v in self.d_vars 271 | ] 272 | else: 273 | print('Not clipping D weights') 274 | 275 | def vbn(self, tensor, name): 276 | if self.disable_vbn: 277 | 278 | class Dummy(object): 279 | # Do nothing here, no bnorm 280 | def __init__(self, tensor, ignored): 281 | self.reference_output = tensor 282 | 283 | def __call__(self, x): 284 | return x 285 | 286 | VBN_cls = Dummy 287 | else: 288 | VBN_cls = VBN 289 | if not hasattr(self, name): 290 | vbn = VBN_cls(tensor, name) 291 | setattr(self, name, vbn) 292 | return vbn.reference_output 293 | vbn = getattr(self, name) 294 | return vbn(tensor) 295 | 296 | def train(self, config, devices): 297 | """ Train the SEGAN """ 298 | 299 | print('Initializing optimizers...') 300 | # init optimizers 301 | d_opt = self.d_opt 302 | g_opt = self.g_opt 303 | num_devices = len(devices) 304 | 305 | try: 306 | init = tf.global_variables_initializer() 307 | except AttributeError: 308 | # fall back to old implementation 309 | init = tf.initialize_all_variables() 310 | 311 | print('Initializing variables...') 312 | self.sess.run(init) 313 | 314 | coord = tf.train.Coordinator() 315 | threads = tf.train.start_queue_runners(coord=coord) 316 | 317 | print('Sampling some wavs to store sample references...') 318 | # Hang onto a copy of wavs so we can feed the same one every time 319 | # we store samples to disk for hearing 320 | # pick a single batch 321 | sample_noisy, sample_wav, \ 322 | sample_z = self.sess.run([self.gtruth_noisy[0], 323 | self.gtruth_wavs[0], 324 | self.zs[0]]) 325 | print('sample noisy shape: ', sample_noisy.shape) 326 | print('sample wav shape: ', sample_wav.shape) 327 | print('sample z shape: ', sample_z.shape) 328 | 329 | save_path = config.save_path 330 | synthesis_path = config.synthesis_path 331 | counter = 0 332 | # count number of samples 333 | num_examples = 0 334 | for record in tf.python_io.tf_record_iterator(self.e2e_dataset): 335 | num_examples += 1 336 | print('total examples in TFRecords {}: {}'.format( 337 | self.e2e_dataset, num_examples)) 338 | # last samples (those not filling a complete batch) are discarded 339 | num_batches = num_examples / self.batch_size 340 | 341 | print('Batches per epoch: ', num_batches) 342 | 343 | if self.load(self.save_path): 344 | print('[*] Load SUCCESS') 345 | else: 346 | print('[!] 
        batch_idx = 0
        curr_epoch = 0
        batch_timings = []
        d_fk_losses = []
        #d_nfk_losses = []
        d_rl_losses = []
        g_adv_losses = []
        g_l1_losses = []
        try:
            while not coord.should_stop():
                start = timeit.default_timer()

                for d_iter in range(self.disc_updates):
                    _d_opt, \
                    d_fk_loss, \
                    d_rl_loss = self.sess.run([d_opt,
                                               self.d_fk_losses[0],
                                               #self.d_nfk_losses[0],
                                               self.d_rl_losses[0]])
                    if self.d_clip_weights:
                        self.sess.run(self.d_clip)

                _g_opt, \
                g_adv_loss, \
                g_l1_loss = self.sess.run([g_opt, self.g_adv_losses[0],
                                           self.g_l1_losses[0]])

                end = timeit.default_timer()
                batch_timings.append(end - start)
                d_fk_losses.append(d_fk_loss)
                #d_nfk_losses.append(d_nfk_loss)
                d_rl_losses.append(d_rl_loss)
                g_adv_losses.append(g_adv_loss)
                g_l1_losses.append(g_l1_loss)
                print('{}/{} (epoch {}), d_rl_loss = {:.5f}, '
                      'd_fk_loss = {:.5f}, ' #d_nfk_loss = {:.5f}, '
                      'g_adv_loss = {:.5f}, g_l1_loss = {:.5f},'
                      ' time/batch = {:.5f}, '
                      'mtime/batch = {:.5f}'.format(
                          counter,
                          config.epoch * num_batches,
                          curr_epoch,
                          d_rl_loss,
                          d_fk_loss,
                          #d_nfk_loss,
                          g_adv_loss,
                          g_l1_loss,
                          end - start,
                          np.mean(batch_timings)))
                batch_idx += num_devices
                counter += num_devices
                if (counter // num_devices) % config.save_freq == 0:
                    self.save(config.save_path, counter)

                    fdict = {
                        self.gtruth_noisy[0]: sample_noisy,
                        self.zs[0]: sample_z
                    }
                    # avoid logging too many samples: only start writing
                    # wave files to disk after 200000 steps
                    if counter > 200000:
                        canvas_w = self.sess.run(self.Gs[0], feed_dict=fdict)
                        swaves = sample_wav
                        sample_dif = sample_wav - sample_noisy
                        for m in range(min(10, canvas_w.shape[0])):
                            print('w{} max: {} min: {}'.format(m, np.max(canvas_w[m]), np.min(canvas_w[m])))
                            wavfile.write(os.path.join(synthesis_path, 'sample_{}-{}.wav'.format(counter, m)),
                                          int(16e3), de_emph(canvas_w[m], self.preemph))
                            m_gtruth_path = os.path.join(synthesis_path, 'gtruth_{}.wav'.format(m))
                            if not os.path.exists(m_gtruth_path):
                                wavfile.write(os.path.join(synthesis_path, 'gtruth_{}.wav'.format(m)),
                                              int(16e3), de_emph(swaves[m], self.preemph))
                                wavfile.write(os.path.join(synthesis_path, 'noisy_{}.wav'.format(m)),
                                              int(16e3), de_emph(sample_noisy[m], self.preemph))
                                wavfile.write(os.path.join(synthesis_path, 'dif_{}.wav'.format(m)),
                                              int(16e3), de_emph(sample_dif[m], self.preemph))
                    np.savetxt(os.path.join(save_path, 'd_rl_losses.txt'), d_rl_losses)
                    np.savetxt(os.path.join(save_path, 'd_fk_losses.txt'), d_fk_losses)
                    np.savetxt(os.path.join(save_path, 'g_adv_losses.txt'), g_adv_losses)
                    np.savetxt(os.path.join(save_path, 'g_l1_losses.txt'), g_l1_losses)

                if batch_idx >= num_batches:
                    curr_epoch += 1
                    # reset batch idx
                    batch_idx = 0
                    # check if we have to deactivate L1
                    if curr_epoch >= config.l1_remove_epoch and not self.deactivated_l1:
                        print('** Deactivating L1 factor! **')
                        self.sess.run(tf.assign(self.l1_lambda, 0.))
                        self.deactivated_l1 = True
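                    # The discriminator input noise, if enabled, decays
                    # geometrically once denoise_epoch is reached: each epoch
                    # multiplies the current std by noise_decay until it drops
                    # below denoise_lbound and is zeroed out. For example
                    # (illustrative values, not defaults from this repo),
                    # init_noise_std=0.5 with noise_decay=0.7 gives stds of
                    # 0.5, 0.35, 0.245, ... on successive epochs.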
                    # check if we have to start decaying the noise (if any)
                    if curr_epoch >= config.denoise_epoch and not self.deactivated_noise:
                        # apply the noise std decay rate
                        decay = config.noise_decay
                        if not hasattr(self, 'curr_noise_std'):
                            self.curr_noise_std = self.init_noise_std
                        new_noise_std = decay * self.curr_noise_std
                        if new_noise_std < config.denoise_lbound:
                            print('New noise std {} < lbound {}, setting 0.'
                                  .format(new_noise_std, config.denoise_lbound))
                            print('** De-activating noise layer **')
                            # if it drops below the lower bound, cancel the
                            # noise out completely
                            new_noise_std = 0.
                            self.deactivated_noise = True
                        else:
                            print(
                                'Applying decay {} to noise std {}: {}'.format(
                                    decay, self.curr_noise_std, new_noise_std))
                        self.sess.run(
                            tf.assign(self.disc_noise_std, new_noise_std))
                        self.curr_noise_std = new_noise_std
                    if curr_epoch >= config.epoch:
                        # done training
                        print('Done training; epoch limit {} '
                              'reached.'.format(self.epoch))
                        print('Saving last model at iteration {}'.format(counter))
                        self.save(config.save_path, counter)
                        break
        except tf.errors.OutOfRangeError:
            print('Done training; epoch limit {} reached.'.format(self.epoch))
        finally:
            coord.request_stop()
            coord.join(threads)

    def clean(self, x):
        """ Clean an utterance x.
            x: numpy array containing the normalized noisy waveform
        """
        c_res = None
        for beg_i in range(0, x.shape[0], self.canvas_size):
            if x.shape[0] - beg_i < self.canvas_size:
                length = x.shape[0] - beg_i
                pad = self.canvas_size - length
            else:
                length = self.canvas_size
                pad = 0
            # only the first batch slot is used; the rest stays zero
            x_ = np.zeros((self.batch_size, self.canvas_size))
            if pad > 0:
                x_[0] = np.concatenate((x[beg_i:beg_i + length],
                                        np.zeros(pad)))
            else:
                x_[0] = x[beg_i:beg_i + length]
            print('Cleaning chunk {} -> {}'.format(beg_i, beg_i + length))
            fdict = {self.gtruth_noisy[0]: x_}
            canvas_w = self.sess.run(self.Gs[0], feed_dict=fdict)[0]
            canvas_w = canvas_w.reshape((self.canvas_size,))
            print('canvas w shape: ', canvas_w.shape)
            if pad > 0:
                print('Removing padding of {} samples'.format(pad))
                # get rid of the trailing padded samples
                canvas_w = canvas_w[:-pad]
            if c_res is None:
                c_res = canvas_w
            else:
                c_res = np.concatenate((c_res, canvas_w))
        # de-emphasize (invert the pre-emphasis filter)
        c_res = de_emph(c_res, self.preemph)
        return c_res
--------------------------------------------------------------------------------
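A note on clean() in model.py: the enhancement loop reduces to the following
self-contained NumPy sketch. The names clean_sketch, de_emph_sketch and
enhance are illustrative stand-ins rather than repo code; enhance represents
one generator forward pass over a single chunk, and de-emphasis is assumed to
be the usual first-order inverse of the pre-emphasis filter:

import numpy as np

def de_emph_sketch(y, coeff=0.95):
    # invert pre-emphasis: x[n] = y[n] + coeff * x[n-1]
    if coeff <= 0:
        return y
    x = np.zeros(y.shape[0])
    x[0] = y[0]
    for n in range(1, y.shape[0]):
        x[n] = y[n] + coeff * x[n - 1]
    return x

def clean_sketch(x, canvas_size, enhance, preemph=0.95):
    # chop x into canvas_size chunks, zero-pad the last one, enhance each
    # chunk, strip the padding again and de-emphasize the concatenation
    out = []
    for beg in range(0, x.shape[0], canvas_size):
        chunk = x[beg:beg + canvas_size]
        pad = canvas_size - chunk.shape[0]
        if pad > 0:
            chunk = np.concatenate((chunk, np.zeros(pad)))
        y = enhance(chunk)  # stands in for self.sess.run(self.Gs[0], ...)
        if pad > 0:
            y = y[:-pad]    # drop the zero-padded tail
        out.append(y)
    return de_emph_sketch(np.concatenate(out), preemph)

# e.g. an identity "enhancer" makes clean_sketch a pure de-emphasis pass:
# clean_sketch(np.random.randn(30000), 16384, lambda c: c)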