├── .gitignore
├── README.md
├── convtest.py
├── data
│   ├── SOURCE.txt
│   ├── test
│   │   └── processed_4410_ulaw.npy
│   └── train
│       └── processed_4410_ulaw.npy
├── dataset.py
├── requirements.txt
├── vctk
│   └── download_vctk.sh
├── wavenet.py
└── wavenet_utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.wav
2 | .idea
3 | *.pyc
4 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WaveNet implementation in Keras
2 | Based on https://deepmind.com/blog/wavenet-generative-model-raw-audio/ and https://arxiv.org/pdf/1609.03499.pdf.
3 | 
4 | 
5 | [Listen to a sample 🎶!](https://soundcloud.com/basveeling/wavenet-sample)
6 | 
7 | ~~Generate your own samples:
8 | 
9 | ```$ KERAS_BACKEND=theano python2 wavenet.py predict with models/run_20160920_120916/config.json predict_seconds=1```~~
10 | EDIT: The pretrained model had to be removed from the repository, as it was not compatible with recent changes.
11 | 
12 | ## Installation:
13 | Activate a new python2 virtualenv (recommended):
14 | ```bash
15 | pip install virtualenv
16 | mkdir ~/virtualenvs && cd ~/virtualenvs
17 | virtualenv wavenet
18 | source wavenet/bin/activate
19 | ```
20 | Clone and install the requirements:
21 | ```bash
22 | cd ~
23 | git clone https://github.com/basveeling/wavenet.git
24 | cd wavenet
25 | pip install -r requirements.txt
26 | ```
27 | 
28 | Using the tensorflow backend is not recommended at this time; see [this issue](https://github.com/basveeling/wavenet/issues/7).
29 | 
30 | ## Dependencies:
31 | - [Sacred](https://github.com/IDSIA/sacred) is used for managing training and sampling. Take a look at the [documentation](http://sacred.readthedocs.io/en/latest/) for more information.
32 | 
33 | - This implementation does not support python3 at the moment.
34 | 
35 | ## Sampling:
36 | Once the first model checkpoint is created, you can start sampling.
37 | 
38 | Run:
39 | ```$ KERAS_BACKEND=theano python2 wavenet.py predict with models/[run_folder]/config.json predict_seconds=1```
40 | 
41 | The latest model checkpoint will be retrieved and used to sample. The sample will be streamed to `[run_folder]/samples`; you can start listening as soon as the first sample is generated.
42 | 
43 | ### Sampling options:
44 | - `predict_seconds`: float. Number of seconds to sample.
45 | - `sample_argmax`: `True` or `False`. If `True`, always take the argmax of the output distribution instead of sampling.
46 | - `sample_temperature`: `None` or float. Controls the sampling temperature: 1.0 for the original distribution, < 1.0 for more conservative samples, > 1.0 for more exploration.
47 | - `seed`: int. Controls the seed for the sampling procedure.
48 | - `predict_initial_input`: string. Path to a wav file, of which the first `fragment_length` samples are used as initial input.
49 | 
50 | e.g.:
51 | ```$ KERAS_BACKEND=theano python2 wavenet.py predict with models/[run_folder]/config.json predict_seconds=1```
52 | 
53 | ## Training:
54 | ```$ KERAS_BACKEND=theano python2 wavenet.py```
55 | 
56 | Or, for a smaller network (fewer channels per layer):
57 | ```$ KERAS_BACKEND=theano python2 wavenet.py with small```
58 | 
59 | ### VCTK:
60 | In order to use the VCTK dataset, first download it by running `vctk/download_vctk.sh`. The training code expects the extracted corpus at `vctk/VCTK-Corpus` (see the `vctkdata` named config in `wavenet.py`).
61 | 
62 | Training is done with:
63 | ```$ KERAS_BACKEND=theano python2 wavenet.py with vctkdata```
64 | 
65 | For a smaller network:
66 | ```$ KERAS_BACKEND=theano python2 wavenet.py with vctkdata small```
67 | 
68 | ### Options:
69 | Train with different configurations:
70 | ```$ KERAS_BACKEND=theano python2 wavenet.py with 'option=value' 'option2=value'```
71 | Available options:
72 | ```
73 | batch_size = 16
74 | data_dir = 'data'
75 | data_dir_structure = 'flat'
76 | debug = False
77 | desired_sample_rate = 4410
78 | dilation_depth = 9
79 | early_stopping_patience = 20
80 | fragment_length = 1152
81 | fragment_stride = 128
82 | keras_verbose = 1
83 | learn_all_outputs = True
84 | nb_epoch = 1000
85 | nb_filters = 256
86 | nb_output_bins = 256
87 | nb_stacks = 1
88 | predict_initial_input = ''
89 | predict_seconds = 1
90 | predict_use_softmax_as_input = False
91 | random_train_batches = False
92 | randomize_batch_order = True
93 | run_dir = None
94 | sample_argmax = False
95 | sample_temperature = 1
96 | seed = 173213366
97 | test_factor = 0.1
98 | train_only_in_receptive_field = True
99 | use_bias = False
100 | use_skip_connections = True
101 | use_ulaw = True
102 | optimizer:
103 |   decay = 0.0
104 |   epsilon = None
105 |   lr = 0.001
106 |   momentum = 0.9
107 |   nesterov = True
108 |   optimizer = 'sgd'
109 | ```
110 | 
111 | ## Using your own training data:
112 | - Create a new data directory with a `train` and a `test` folder in it. All wav files in these folders will be used as data.
113 |   - Caveat: make sure your wav files are supported by `scipy.io.wavfile.read()`: e.g. don't use 24-bit wav files, and remove meta info.
114 | - Run with: `$ python2 wavenet.py with 'data_dir=your_data_dir_name'`
115 | - Test the preprocessing results with: `$ python2 wavenet.py test_preprocess with 'data_dir=your_data_dir_name'`
116 | 
117 | ## Todo:
118 | - [ ] Local conditioning
119 | - [ ] Global conditioning
120 | - [x] Training on CSTR VCTK Corpus
121 | - [x] CLI option to pick a wav file as the initial input for sample generation. Done: see `predict_initial_input`.
122 | - [x] Fully randomized training batches
123 | - [x] Soft targets: by convolving a Gaussian kernel over the one-hot targets, the network trains faster.
124 | - [ ] Decaying soft targets: the stdev of the Gaussian kernel should slowly decay.
125 | 
126 | 
127 | ## Uncertainties from paper:
128 | - It's unclear whether the model is trained to predict the sample at `t+1` for every input sample, or only for outputs where the full receptive field (samples `t - receptive_field` up to `t`) was in the input. Right now the code does the latter.
129 | - There is no mention of weight decay or batch normalization in the paper. Perhaps these are not needed given enough data?
130 | 
131 | ## Note on computational cost:
132 | The WaveNet model is quite expensive to train and sample from. We can, however, trade off computational cost against accuracy and fidelity by lowering the sampling rate, the number of stacks, and the number of channels per layer.
133 | 
134 | For a downsized model (4000 Hz vs. 16000 Hz sampling rate, 16 filters vs. 256, 2 stacks vs. ??):
135 | - A Tesla K80 needs around 4 minutes to generate one second of audio.
136 | - A recent MacBook Pro needs around 15 minutes.
137 | DeepMind has reported that generating one second of audio with their model takes about 90 minutes.
138 | 
139 | ## Disclaimer
140 | This is a re-implementation of the model described in the WaveNet paper by Google DeepMind. This repository is not associated with Google DeepMind.
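
## Note on the receptive field
The default `fragment_length` above follows directly from the network's receptive field. A minimal sketch of the arithmetic, mirroring `compute_receptive_field_()` in `wavenet.py`:

```python
def receptive_field(dilation_depth, nb_stacks):
    # One stack of dilations 1, 2, ..., 2**dilation_depth covers
    # 2**dilation_depth * 2 samples; each extra stack adds that again minus one.
    return nb_stacks * (2 ** dilation_depth * 2) - (nb_stacks - 1)

print(receptive_field(9, 1))  # 1024 samples, i.e. ~232 ms at 4410 Hz
# Default fragment_length = 128 + 1024 = 1152, matching the options listed above.
```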
141 | 
--------------------------------------------------------------------------------
/convtest.py:
--------------------------------------------------------------------------------
1 | import theano
2 | import theano.tensor as T
3 | import numpy as np
4 | import keras.backend as K
5 | 
6 | # A test script to validate causal dilated convolutions.
7 | dilation = 2
8 | input = T.fvector()
9 | filters = T.fvector()  # Reshaped below to (output channels, input channels, filter rows, filter columns).
10 | input_reshaped = T.reshape(input, (1, -1, 1))
11 | input_reshaped = K.asymmetric_temporal_padding(input_reshaped, left_pad=dilation, right_pad=0)
12 | input_reshaped = T.reshape(input_reshaped, (1, 1, -1, 1))
13 | filters_reshaped = T.reshape(filters, (1, 1, -1, 1))
14 | out = T.nnet.conv2d(input_reshaped, filters_reshaped, border_mode='valid', filter_dilation=(dilation, 1))
15 | out = T.reshape(out, (1, -1, 1))
16 | out = K.asymmetric_temporal_padding(out, left_pad=dilation, right_pad=0)
17 | out = T.reshape(out, (1, 1, -1, 1))
18 | out = T.nnet.conv2d(out, filters_reshaped, border_mode='valid', filter_dilation=(dilation, 1))
19 | out = T.flatten(out)
20 | 
21 | in_input = np.arange(8, dtype='float32')
22 | in_filters = np.array([1, 1], dtype='float32')
23 | f = theano.function([input, filters], out)
24 | print "".join(["%3.0f" % i for i in in_input])
25 | print "".join(["%3.0f" % i for i in f(in_input, in_filters)])
--------------------------------------------------------------------------------
/data/SOURCE.txt:
--------------------------------------------------------------------------------
1 | Dataset of two Chopin pieces played by 22 pianists in a fixed recording setting.
2 | Download from: http://iwk.mdw.ac.at/goebl/mp3.html
--------------------------------------------------------------------------------
/data/test/processed_4410_ulaw.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basveeling/wavenet/bf8ef958372692ecb32e8540f7c81f69a186eb8d/data/test/processed_4410_ulaw.npy
--------------------------------------------------------------------------------
/data/train/processed_4410_ulaw.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basveeling/wavenet/bf8ef958372692ecb32e8540f7c81f69a186eb8d/data/train/processed_4410_ulaw.npy
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset loading and wav preprocessing utilities for WaveNet training.
2 | """
3 | from __future__ import division
4 | 
5 | import math
6 | import os
7 | import warnings
8 | 
9 | import numpy as np
10 | import scipy.io.wavfile
11 | import scipy.signal
12 | from picklable_itertools import cycle
13 | from picklable_itertools.extras import partition_all
14 | from tqdm import tqdm
15 | 
16 | 
17 | # TODO: make SACRED ingredient.
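# A sketch of the preprocessing pipeline implemented below (see process_wav;
# one_hot is applied later, in the batch generators):
#   scipy.io.wavfile.read -> ensure_mono -> wav_to_float   (float in [-1, 1])
#   -> ulaw (companding, u=255) -> ensure_sample_rate (polyphase resampling)
#   -> float_to_uint8   (uint8 in [0, 255]) -> one_hot   ((timesteps, 256) uint8)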
18 | def one_hot(x): 19 | return np.eye(256, dtype='uint8')[x.astype('uint8')] 20 | 21 | 22 | def fragment_indices(full_sequences, fragment_length, batch_size, fragment_stride, nb_output_bins): 23 | for seq_i, sequence in enumerate(full_sequences): 24 | # range_values = np.linspace(np.iinfo(sequence.dtype).min, np.iinfo(sequence.dtype).max, nb_output_bins) 25 | # digitized = np.digitize(sequence, range_values).astype('uint8') 26 | for i in range(0, sequence.shape[0] - fragment_length, fragment_stride): 27 | yield seq_i, i 28 | 29 | 30 | def select_generator(set_name, random_train_batches, full_sequences, fragment_length, batch_size, fragment_stride, 31 | nb_output_bins, randomize_batch_order, _rnd): 32 | if random_train_batches and set_name == 'train': 33 | bg = random_batch_generator 34 | else: 35 | bg = batch_generator 36 | return bg(full_sequences, fragment_length, batch_size, fragment_stride, nb_output_bins, randomize_batch_order, _rnd) 37 | 38 | 39 | def batch_generator(full_sequences, fragment_length, batch_size, fragment_stride, nb_output_bins, randomize_batch_order, _rnd): 40 | indices = list(fragment_indices(full_sequences, fragment_length, batch_size, fragment_stride, nb_output_bins)) 41 | if randomize_batch_order: 42 | _rnd.shuffle(indices) 43 | 44 | batches = cycle(partition_all(batch_size, indices)) 45 | for batch in batches: 46 | if len(batch) < batch_size: 47 | continue 48 | yield np.array( 49 | [one_hot(full_sequences[e[0]][e[1]:e[1] + fragment_length]) for e in batch], dtype='uint8'), np.array( 50 | [one_hot(full_sequences[e[0]][e[1] + 1:e[1] + fragment_length + 1]) for e in batch], dtype='uint8') 51 | 52 | 53 | def random_batch_generator(full_sequences, fragment_length, batch_size, fragment_stride, nb_output_bins, 54 | randomize_batch_order, _rnd): 55 | lengths = [x.shape[0] for x in full_sequences] 56 | nb_sequences = len(full_sequences) 57 | while True: 58 | sequence_indices = _rnd.randint(0, nb_sequences, batch_size) 59 | batch_inputs = [] 60 | batch_outputs = [] 61 | for i, seq_i in enumerate(sequence_indices): 62 | l = lengths[seq_i] 63 | offset = np.squeeze(_rnd.randint(0, l - fragment_length, 1)) 64 | batch_inputs.append(full_sequences[seq_i][offset:offset + fragment_length]) 65 | batch_outputs.append(full_sequences[seq_i][offset + 1:offset + fragment_length + 1]) 66 | yield one_hot(np.array(batch_inputs, dtype='uint8')), one_hot(np.array(batch_outputs, dtype='uint8')) 67 | 68 | 69 | def generators(dirname, desired_sample_rate, fragment_length, batch_size, fragment_stride, nb_output_bins, 70 | learn_all_outputs, use_ulaw, randomize_batch_order, _rnd, random_train_batches): 71 | fragment_generators = {} 72 | nb_examples = {} 73 | for set_name in ['train', 'test']: 74 | set_dirname = os.path.join(dirname, set_name) 75 | full_sequences = load_set(desired_sample_rate, set_dirname, use_ulaw) 76 | fragment_generators[set_name] = select_generator(set_name, random_train_batches, full_sequences, 77 | fragment_length, 78 | batch_size, fragment_stride, nb_output_bins, 79 | randomize_batch_order, _rnd) 80 | nb_examples[set_name] = int(sum( 81 | [len(range(0, x.shape[0] - fragment_length, fragment_stride)) for x in 82 | full_sequences]) / batch_size) * batch_size 83 | 84 | return fragment_generators, nb_examples 85 | 86 | 87 | def generators_vctk(dirname, desired_sample_rate, fragment_length, batch_size, fragment_stride, nb_output_bins, 88 | learn_all_outputs, use_ulaw, test_factor, randomize_batch_order, _rnd, random_train_batches): 89 | fragment_generators = {} 90 | 
    nb_examples = {}
91 |     speaker_dirs = os.listdir(dirname)
92 |     train_full_sequences = []
93 |     test_full_sequences = []
94 |     for speaker_dir in speaker_dirs:
95 |         full_sequences = load_set(desired_sample_rate, os.path.join(dirname, speaker_dir), use_ulaw)
96 |         nb_examples_train = int(math.ceil(len(full_sequences) * (1 - test_factor)))
97 |         train_full_sequences.extend(full_sequences[0:nb_examples_train])
98 |         test_full_sequences.extend(full_sequences[nb_examples_train:])
99 | 
100 |     for set_name, set_sequences in zip(['train', 'test'], [train_full_sequences, test_full_sequences]):
101 |         fragment_generators[set_name] = select_generator(set_name, random_train_batches, set_sequences,
102 |                                                          fragment_length,
103 |                                                          batch_size, fragment_stride, nb_output_bins,
104 |                                                          randomize_batch_order, _rnd)
105 |         nb_examples[set_name] = int(sum(
106 |             [len(range(0, x.shape[0] - fragment_length, fragment_stride)) for x in
107 |              set_sequences]) / batch_size) * batch_size
108 | 
109 |     return fragment_generators, nb_examples
110 | 
111 | 
112 | def load_set(desired_sample_rate, set_dirname, use_ulaw):
113 |     ulaw_str = '_ulaw' if use_ulaw else ''
114 |     cache_fn = os.path.join(set_dirname, 'processed_%d%s.npy' % (desired_sample_rate, ulaw_str))
115 |     if os.path.isfile(cache_fn):
116 |         full_sequences = np.load(cache_fn)
117 |     else:
118 |         file_names = [fn for fn in os.listdir(set_dirname) if fn.endswith('.wav')]
119 |         full_sequences = []
120 |         for fn in tqdm(file_names):
121 |             sequence = process_wav(desired_sample_rate, os.path.join(set_dirname, fn), use_ulaw)
122 |             full_sequences.append(sequence)
123 |         np.save(cache_fn, full_sequences)
124 | 
125 |     return full_sequences
126 | 
127 | 
128 | def process_wav(desired_sample_rate, filename, use_ulaw):
129 |     with warnings.catch_warnings():
130 |         warnings.simplefilter("error")
131 |         channels = scipy.io.wavfile.read(filename)
132 |     file_sample_rate, audio = channels
133 |     audio = ensure_mono(audio)
134 |     audio = wav_to_float(audio)
135 |     if use_ulaw:
136 |         audio = ulaw(audio)
137 |     audio = ensure_sample_rate(desired_sample_rate, file_sample_rate, audio)
138 |     audio = float_to_uint8(audio)
139 |     return audio
140 | 
141 | 
142 | def ulaw(x, u=255):
143 |     x = np.sign(x) * (np.log(1 + u * np.abs(x)) / np.log(1 + u))
144 |     return x
145 | 
146 | 
147 | def float_to_uint8(x):
148 |     x += 1.
149 |     x /= 2.
150 |     uint8_max_value = np.iinfo('uint8').max
151 |     x *= uint8_max_value
152 |     x = x.astype('uint8')
153 |     return x
154 | 
155 | 
156 | def wav_to_float(x):
157 |     try:
158 |         max_value = np.iinfo(x.dtype).max
159 |         min_value = np.iinfo(x.dtype).min
160 |     except ValueError:  # Not an integer dtype: use the float dtype info instead.
161 |         max_value = np.finfo(x.dtype).max
162 |         min_value = np.finfo(x.dtype).min
163 |     x = x.astype('float64', casting='safe')
164 |     x -= min_value
165 |     x /= ((max_value - min_value) / 2.)
166 |     x -= 1.
167 |     return x
168 | 
169 | 
170 | def ulaw2lin(x, u=255.):
171 |     max_value = np.iinfo('uint8').max
172 |     min_value = np.iinfo('uint8').min
173 |     x = x.astype('float64', casting='safe')
174 |     x -= min_value
175 |     x /= ((max_value - min_value) / 2.)
176 |     x -= 1.
177 |     x = np.sign(x) * (1 / u) * (((1 + u) ** np.abs(x)) - 1)
178 |     x = float_to_uint8(x)
179 |     return x
180 | 
181 | def ensure_sample_rate(desired_sample_rate, file_sample_rate, mono_audio):
182 |     if file_sample_rate != desired_sample_rate:
183 |         mono_audio = scipy.signal.resample_poly(mono_audio, desired_sample_rate, file_sample_rate)
184 |     return mono_audio
185 | 
186 | 
187 | def ensure_mono(raw_audio):
188 |     """
189 |     Just use the first channel.
190 |     """
191 |     if raw_audio.ndim == 2:
192 |         raw_audio = raw_audio[:, 0]
193 |     return raw_audio
194 | 
195 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | picklable_itertools~=0.1.1
2 | sacred~=0.6.10
3 | tqdm~=4.8.4
4 | q~=2.6
5 | keras==2.1.2
6 | tensorflow-gpu==1.8.0
7 | h5py==2.7.1
8 | scipy==1.0.0
9 | matplotlib==2.1.1
--------------------------------------------------------------------------------
/vctk/download_vctk.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Available disk space in $(pwd):"
3 | df -h .
4 | read -p "This will download the VCTK corpus (11 GB) and extract it (14.9 GB), are you sure (y/n): " -n 1 -r
5 | echo
6 | if [[ $REPLY =~ ^[Yy]$ ]]
7 | then
8 |     wget http://homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz
9 |     if [[ "$OSTYPE" == "darwin"* ]]; then
10 |         open VCTK-Corpus.tar.gz
11 |     else
12 |         tar -xvf VCTK-Corpus.tar.gz
13 |     fi
14 | fi
15 | 
--------------------------------------------------------------------------------
/wavenet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | 
3 | import datetime
4 | import json
5 | import os
6 | import re
7 | import wave
8 | 
9 | import keras.backend as K
10 | import numpy as np
11 | import scipy.io.wavfile
12 | import scipy.signal
13 | from keras import layers
14 | from keras import metrics
15 | from keras import objectives
16 | from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, CSVLogger
17 | from keras.engine import Input
18 | from keras.engine import Model
19 | from keras.optimizers import Adam, SGD
20 | from keras.regularizers import l2
21 | from sacred import Experiment
22 | from sacred.commands import print_config
23 | from tqdm import tqdm
24 | from time import gmtime, strftime
25 | from keras.callbacks import TensorBoard
26 | 
27 | import dataset
28 | from wavenet_utils import CausalAtrousConvolution1D, categorical_mean_squared_error
29 | 
30 | ex = Experiment('wavenet')
31 | 
32 | 
33 | @ex.config
34 | def config():
35 |     data_dir = 'data'
36 |     data_dir_structure = 'flat'  # Or 'vctk' for a speaker-directory structure.
37 |     test_factor = 0.1  # For the 'vctk' structure, use this fraction of the sequences for the test set.
38 |     nb_epoch = 1000
39 |     run_dir = None
40 |     early_stopping_patience = 20
41 |     desired_sample_rate = 4410
42 |     batch_size = 16
43 |     nb_output_bins = 256
44 |     nb_filters = 256
45 |     dilation_depth = 9
46 |     nb_stacks = 1
47 |     use_bias = False
48 |     use_ulaw = True
49 |     res_l2 = 0
50 |     final_l2 = 0
51 |     fragment_length = 128 + compute_receptive_field_(desired_sample_rate, dilation_depth, nb_stacks)[0]
52 |     fragment_stride = 128
53 |     use_skip_connections = True
54 |     optimizer = {
55 |         'optimizer': 'sgd',
56 |         'lr': 0.001,
57 |         'momentum': 0.9,
58 |         'decay': 0.,
59 |         'nesterov': True,
60 |         'epsilon': None
61 |     }
62 |     learn_all_outputs = True
63 |     random_train_batches = False
64 |     randomize_batch_order = True  # Only effective if not using random train batches.
65 |     train_with_soft_target_stdev = None  # Float: makes the targets a Gaussian with this stdev.
66 | 
67 |     # The first outputs in time are computed from zero-padding.
Setting below to True ignores these inputs: 68 | train_only_in_receptive_field = True 69 | 70 | keras_verbose = 1 71 | debug = False 72 | 73 | 74 | @ex.named_config 75 | def book(): 76 | desired_sample_rate = 4000 77 | data_dir = 'data_book' 78 | dilation_depth = 8 79 | nb_stacks = 1 80 | fragment_length = 2 ** 10 81 | nb_filters = 256 82 | batch_size = 16 83 | fragment_stride = compute_receptive_field_(desired_sample_rate, dilation_depth, nb_stacks)[0] 84 | 85 | 86 | @ex.named_config 87 | def small(): 88 | desired_sample_rate = 4410 89 | nb_filters = 16 90 | dilation_depth = 8 91 | nb_stacks = 1 92 | fragment_length = 128 + (compute_receptive_field_(desired_sample_rate, dilation_depth, nb_stacks)[0]) 93 | fragment_stride = int(desired_sample_rate / 10) 94 | 95 | 96 | @ex.named_config 97 | def soft_targets(): 98 | train_with_soft_target_stdev = 0.5 99 | # TODO: smooth decay of stdev per epoch. 100 | 101 | 102 | @ex.named_config 103 | def vctkdata(): 104 | assert os.path.isdir(os.path.join('vctk', 'VCTK-Corpus')), "Please download vctk by running vctk/download_vctk.sh." 105 | desired_sample_rate = 4000 106 | data_dir = 'vctk/VCTK-Corpus/wav48' 107 | data_dir_structure = 'vctk' 108 | test_factor = 0.01 109 | 110 | 111 | @ex.named_config 112 | def vctkmod(desired_sample_rate): 113 | nb_filters = 32 114 | dilation_depth = 7 115 | nb_stacks = 4 116 | fragment_length = 1 + (compute_receptive_field_(desired_sample_rate, dilation_depth, nb_stacks)[0]) 117 | fragment_stride = int(desired_sample_rate / 10) 118 | random_train_batches = True 119 | 120 | 121 | @ex.named_config 122 | def length32(desired_sample_rate, dilation_depth, nb_stacks): 123 | fragment_length = 32 + (compute_receptive_field_(desired_sample_rate, dilation_depth, nb_stacks)[0]) 124 | 125 | 126 | @ex.named_config 127 | def adam(): 128 | optimizer = { 129 | 'optimizer': 'adam', 130 | 'lr': 0.001, 131 | 'decay': 0., 132 | 'epsilon': 1e-8 133 | } 134 | 135 | 136 | @ex.named_config 137 | def adam2(): 138 | optimizer = { 139 | 'optimizer': 'adam', 140 | 'lr': 0.01, 141 | 'decay': 0., 142 | 'epsilon': 1e-10 143 | } 144 | 145 | 146 | @ex.config 147 | def predict_config(): 148 | predict_seconds = 1 149 | sample_argmax = False 150 | sample_temperature = 1.0 # Temperature for sampling. > 1.0 for more exploring, < 1.0 for conservative samples. 151 | predict_use_softmax_as_input = False # Uses the softmax rather than the argmax as in input for the next step. 152 | predict_initial_input = None 153 | 154 | 155 | @ex.named_config 156 | def batch_run(): 157 | keras_verbose = 2 158 | 159 | 160 | def skip_out_of_receptive_field(func): 161 | # TODO: consider using keras masking for this? 162 | receptive_field, _ = compute_receptive_field() 163 | 164 | def wrapper(y_true, y_pred): 165 | y_true = y_true[:, receptive_field - 1:, :] 166 | y_pred = y_pred[:, receptive_field - 1:, :] 167 | return func(y_true, y_pred) 168 | 169 | wrapper.__name__ = func.__name__ 170 | 171 | return wrapper 172 | 173 | 174 | def print_t(tensor, label): 175 | tensor.name = label 176 | # tensor = theano.printing.Print(tensor.name, attrs=('__str__', 'shape'))(tensor) 177 | return tensor 178 | 179 | 180 | @ex.capture 181 | def make_soft(y_true, fragment_length, nb_output_bins, train_with_soft_target_stdev, with_prints=False): 182 | receptive_field, _ = compute_receptive_field() 183 | n_outputs = fragment_length - receptive_field + 1 184 | 185 | # Make a gaussian kernel. 
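    # scipy.signal.gaussian(9, std) is a symmetric 9-tap window peaking at 1.
    # Convolving it over the one-hot targets spreads each target over
    # neighbouring bins; the K.sum(...) below renormalizes each row into a
    # distribution. E.g. std=0.5 gives roughly
    # [0, 0, 0.0003, 0.1353, 1.0, 0.1353, 0.0003, 0, 0].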
186 |     kernel_v = scipy.signal.gaussian(9, std=train_with_soft_target_stdev)
187 |     print(kernel_v)
188 |     kernel_v = np.reshape(kernel_v, [1, 1, -1, 1])
189 |     kernel = K.variable(kernel_v)
190 | 
191 |     if with_prints:
192 |         y_true = print_t(y_true, 'y_true initial')
193 | 
194 |     # y_true: [batch, timesteps, input_dim]
195 |     y_true = K.reshape(y_true, (-1, 1, nb_output_bins, 1))  # Same filter for all outputs; combine with batch.
196 |     # y_true: [batch*timesteps, n_channels=1, input_dim, dummy]
197 |     y_true = K.conv2d(y_true, kernel, padding='same')
198 |     y_true = K.reshape(y_true, (-1, n_outputs, nb_output_bins))  # Reshape back to [batch, timesteps, input_dim].
199 |     # y_true: [batch, timesteps, input_dim]
200 |     y_true /= K.sum(y_true, axis=-1, keepdims=True)
201 | 
202 |     if with_prints:
203 |         y_true = print_t(y_true, 'y_true after')
204 |     return y_true
205 | 
206 | 
207 | def make_targets_soft(func):
208 |     """Turns one-hot targets into a Gaussian-smoothed distribution."""
209 | 
210 |     def wrapper(y_true, y_pred):
211 |         y_true = make_soft(y_true)
212 |         y_pred = y_pred
213 |         return func(y_true, y_pred)
214 | 
215 |     wrapper.__name__ = func.__name__
216 | 
217 |     return wrapper
218 | 
219 | 
220 | @ex.capture()
221 | def build_model(fragment_length, nb_filters, nb_output_bins, dilation_depth, nb_stacks, use_skip_connections,
222 |                 learn_all_outputs, _log, desired_sample_rate, use_bias, res_l2, final_l2):
223 |     def residual_block(x):
224 |         original_x = x
225 |         # TODO: initialization, regularization?
226 |         # Note: The AtrousConvolution1D with the 'causal' flag is implemented in github.com/basveeling/keras#@wavenet.
227 |         tanh_out = CausalAtrousConvolution1D(nb_filters, 2, dilation_rate=2 ** i, padding='valid', causal=True,
228 |                                              use_bias=use_bias,
229 |                                              name='dilated_conv_%d_tanh_s%d' % (2 ** i, s), activation='tanh',
230 |                                              kernel_regularizer=l2(res_l2))(x)
231 |         sigm_out = CausalAtrousConvolution1D(nb_filters, 2, dilation_rate=2 ** i, padding='valid', causal=True,
232 |                                              use_bias=use_bias,
233 |                                              name='dilated_conv_%d_sigm_s%d' % (2 ** i, s), activation='sigmoid',
234 |                                              kernel_regularizer=l2(res_l2))(x)
235 |         x = layers.Multiply(name='gated_activation_%d_s%d' % (i, s))([tanh_out, sigm_out])
236 | 
237 |         res_x = layers.Convolution1D(nb_filters, 1, padding='same', use_bias=use_bias,
238 |                                      kernel_regularizer=l2(res_l2))(x)
239 |         skip_x = layers.Convolution1D(nb_filters, 1, padding='same', use_bias=use_bias,
240 |                                       kernel_regularizer=l2(res_l2))(x)
241 |         res_x = layers.Add()([original_x, res_x])
242 |         return res_x, skip_x
243 | 
244 |     input = Input(shape=(fragment_length, nb_output_bins), name='input_part')
245 |     out = input
246 |     skip_connections = []
247 |     out = CausalAtrousConvolution1D(nb_filters, 2,
248 |                                     dilation_rate=1,
249 |                                     padding='valid',
250 |                                     causal=True,
251 |                                     name='initial_causal_conv'
252 |                                     )(out)
253 |     for s in range(nb_stacks):
254 |         for i in range(0, dilation_depth + 1):
255 |             out, skip_out = residual_block(out)
256 |             skip_connections.append(skip_out)
257 | 
258 |     if use_skip_connections:
259 |         out = layers.Add()(skip_connections)
260 |     out = layers.Activation('relu')(out)
261 |     out = layers.Convolution1D(nb_output_bins, 1, padding='same',
262 |                                kernel_regularizer=l2(final_l2))(out)
263 |     out = layers.Activation('relu')(out)
264 |     out = layers.Convolution1D(nb_output_bins, 1, padding='same')(out)
265 | 
266 |     if not learn_all_outputs:
267 |         raise DeprecationWarning('Learning on just the last output is wasteful; the model now learns only on outputs inside the receptive field.')
268 |         out = layers.Lambda(lambda x: x[:, -1, :],
output_shape=(out._keras_shape[-1],))( 269 | out) # Based on gif in deepmind blog: take last output? 270 | 271 | out = layers.Activation('softmax', name="output_softmax")(out) 272 | model = Model(input, out) 273 | 274 | receptive_field, receptive_field_ms = compute_receptive_field() 275 | 276 | _log.info('Receptive Field: %d (%dms)' % (receptive_field, int(receptive_field_ms))) 277 | return model 278 | 279 | 280 | @ex.capture 281 | def compute_receptive_field(desired_sample_rate, dilation_depth, nb_stacks): 282 | return compute_receptive_field_(desired_sample_rate, dilation_depth, nb_stacks) 283 | 284 | 285 | def compute_receptive_field_(desired_sample_rate, dilation_depth, nb_stacks): 286 | receptive_field = nb_stacks * (2 ** dilation_depth * 2) - (nb_stacks - 1) 287 | receptive_field_ms = (receptive_field * 1000) / desired_sample_rate 288 | return receptive_field, receptive_field_ms 289 | 290 | 291 | @ex.capture(prefix='optimizer') 292 | def make_optimizer(optimizer, lr, momentum, decay, nesterov, epsilon): 293 | if optimizer == 'sgd': 294 | optim = SGD(lr, momentum, decay, nesterov) 295 | elif optimizer == 'adam': 296 | optim = Adam(lr=lr, decay=decay, epsilon=epsilon) 297 | else: 298 | raise ValueError('Invalid config for optimizer.optimizer: ' + optimizer) 299 | return optim 300 | 301 | 302 | @ex.command 303 | def predict(desired_sample_rate, fragment_length, _log, seed, _seed, _config, predict_seconds, data_dir, batch_size, 304 | fragment_stride, nb_output_bins, learn_all_outputs, run_dir, predict_use_softmax_as_input, use_ulaw, 305 | predict_initial_input, 306 | **kwargs): 307 | fragment_length = compute_receptive_field()[0] 308 | _config['fragment_length'] = fragment_length 309 | 310 | checkpoint_dir = os.path.join(run_dir, 'checkpoints') 311 | last_checkpoint = sorted(os.listdir(checkpoint_dir))[-1] 312 | epoch = int(re.match(r'checkpoint\.(\d+?)-.*', last_checkpoint).group(1)) 313 | _log.info('Using checkpoint from epoch: %s' % epoch) 314 | 315 | sample_dir = os.path.join(run_dir, 'samples') 316 | if not os.path.exists(sample_dir): 317 | os.mkdir(sample_dir) 318 | 319 | sample_name = make_sample_name(epoch) 320 | sample_filename = os.path.join(sample_dir, sample_name) 321 | 322 | _log.info('Saving to "%s"' % sample_filename) 323 | 324 | sample_stream = make_sample_stream(desired_sample_rate, sample_filename) 325 | 326 | model = build_model() 327 | model.load_weights(os.path.join(checkpoint_dir, last_checkpoint)) 328 | model.summary() 329 | 330 | if predict_initial_input is None: 331 | outputs = list(dataset.one_hot(np.zeros(fragment_length) + nb_output_bins / 2)) 332 | elif predict_initial_input != '': 333 | _log.info('Taking first %d (%.2fs) from \'%s\' as initial input.' 
% (
334 |             fragment_length, fragment_length / desired_sample_rate, predict_initial_input))
335 |         wav = dataset.process_wav(desired_sample_rate, predict_initial_input, use_ulaw)
336 |         outputs = list(dataset.one_hot(wav[0:fragment_length]))
337 |     else:
338 |         _log.info('Taking sample from test dataset as initial input.')
339 |         data_generators, _ = get_generators()
340 |         outputs = list(data_generators['test'].next()[0][-1])
341 | 
342 |     # write_samples(sample_stream, outputs)
343 |     warned_repetition = False
344 |     for i in tqdm(range(int(desired_sample_rate * predict_seconds))):
345 |         if not warned_repetition:
346 |             if np.argmax(outputs[-1]) == np.argmax(outputs[-2]) and np.argmax(outputs[-2]) == np.argmax(outputs[-3]):
347 |                 warned_repetition = True
348 |                 _log.warning('Last three predicted outputs were %d' % np.argmax(outputs[-1]))
349 |         else:
350 |             warned_repetition = False
351 |         prediction_seed = np.expand_dims(np.array(outputs[i:i + fragment_length]), 0)
352 |         output = model.predict(prediction_seed)
353 |         output_dist = output[0][-1]
354 |         output_val = draw_sample(output_dist)
355 |         if predict_use_softmax_as_input:
356 |             outputs.append(output_dist)
357 |         else:
358 |             outputs.append(output_val)
359 |         write_samples(sample_stream, [output_val])
360 | 
361 |     sample_stream.close()
362 | 
363 |     _log.info("Done!")
364 | 
365 | 
366 | @ex.capture
367 | def make_sample_name(epoch, predict_seconds, predict_use_softmax_as_input, sample_argmax, sample_temperature, seed):
368 |     sample_str = ''
369 |     if predict_use_softmax_as_input:
370 |         sample_str += '_soft-in'
371 |     if sample_argmax:
372 |         sample_str += '_argmax'
373 |     else:
374 |         sample_str += '_sample'
375 |         if sample_temperature:
376 |             sample_str += '-temp-%s' % sample_temperature
377 |     sample_name = 'sample_epoch-%05d_%02ds_%s_seed-%d.wav' % (epoch, int(predict_seconds), sample_str, seed)
378 |     return sample_name
379 | 
380 | 
381 | @ex.capture
382 | def write_samples(sample_file, out_val, use_ulaw):
383 |     s = np.argmax(out_val, axis=-1).astype('uint8')
384 |     # print out_val,
385 |     if use_ulaw:
386 |         s = dataset.ulaw2lin(s)
387 |         # print s,
388 |     s = bytearray(list(s))
389 |     # print s[0]
390 |     sample_file.writeframes(s)
391 |     sample_file._file.flush()
392 | 
393 | 
394 | @ex.capture
395 | def get_generators(batch_size, data_dir, desired_sample_rate, fragment_length, fragment_stride, learn_all_outputs,
396 |                    nb_output_bins, use_ulaw, test_factor, data_dir_structure, randomize_batch_order, _rnd,
397 |                    random_train_batches):
398 |     if data_dir_structure == 'flat':
399 |         return dataset.generators(data_dir, desired_sample_rate, fragment_length, batch_size,
400 |                                   fragment_stride, nb_output_bins, learn_all_outputs, use_ulaw, randomize_batch_order,
401 |                                   _rnd, random_train_batches)
402 | 
403 |     elif data_dir_structure == 'vctk':
404 |         return dataset.generators_vctk(data_dir, desired_sample_rate, fragment_length, batch_size,
405 |                                        fragment_stride, nb_output_bins, learn_all_outputs, use_ulaw, test_factor,
406 |                                        randomize_batch_order, _rnd, random_train_batches)
407 |     else:
408 |         raise ValueError('data_dir_structure must be "flat" or "vctk", is %s' % data_dir_structure)
409 | 
410 | 
411 | @ex.command
412 | def test_make_soft(_log, train_with_soft_target_stdev, _config):
413 |     if train_with_soft_target_stdev is None:
414 |         _config['train_with_soft_target_stdev'] = 1
415 |     y_true = K.reshape(K.eye(512)[:129, :256], (2, 129, 256))
416 |     y_soft = make_soft(y_true)
417 |     f = K.function([], y_soft)
418 |     _log.info('Output of soft:')
419 |     f1 = f([])
420 | 
421 |     _log.info(f1[0,
0]) 422 | _log.info(f1[-1, -1]) 423 | 424 | 425 | @ex.command 426 | def test_preprocess(desired_sample_rate, batch_size, use_ulaw): 427 | sample_dir = os.path.join('preprocess_test') 428 | if not os.path.exists(sample_dir): 429 | os.mkdir(sample_dir) 430 | 431 | ulaw_str = '_ulaw' if use_ulaw else '' 432 | sample_filename = os.path.join(sample_dir, 'test1%s.wav' % ulaw_str) 433 | sample_stream = make_sample_stream(desired_sample_rate, sample_filename) 434 | 435 | data_generators, _ = get_generators() 436 | outputs = data_generators['test'].next()[0][1].astype('uint8') 437 | 438 | write_samples(sample_stream, outputs) 439 | scipy.io.wavfile.write(os.path.join(sample_dir, 'test2%s.wav' % ulaw_str), desired_sample_rate, 440 | np.argmax(outputs, axis=-1).astype('uint8')) 441 | 442 | 443 | def make_sample_stream(desired_sample_rate, sample_filename): 444 | sample_file = wave.open(sample_filename, mode='w') 445 | sample_file.setnchannels(1) 446 | sample_file.setframerate(desired_sample_rate) 447 | sample_file.setsampwidth(1) 448 | return sample_file 449 | 450 | 451 | def softmax(x, temp, mod=np): 452 | x = mod.log(x) / temp 453 | e_x = mod.exp(x - mod.max(x, axis=-1)) 454 | return e_x / mod.sum(e_x, axis=-1) 455 | 456 | 457 | @ex.capture 458 | def draw_sample(output_dist, sample_temperature, sample_argmax, _rnd): 459 | if sample_argmax: 460 | output_dist = np.eye(256)[np.argmax(output_dist, axis=-1)] 461 | else: 462 | if sample_temperature is not None: 463 | output_dist = softmax(output_dist, sample_temperature) 464 | output_dist = output_dist / np.sum(output_dist + 1e-7) 465 | output_dist = _rnd.multinomial(1, output_dist) 466 | return output_dist 467 | 468 | 469 | @ex.automain 470 | def main(run_dir, data_dir, nb_epoch, early_stopping_patience, desired_sample_rate, fragment_length, batch_size, 471 | fragment_stride, nb_output_bins, keras_verbose, _log, seed, _config, debug, learn_all_outputs, 472 | train_only_in_receptive_field, _run, use_ulaw, train_with_soft_target_stdev): 473 | if run_dir is None: 474 | if not os.path.exists("models"): 475 | os.mkdir("models") 476 | run_dir = os.path.join('models', datetime.datetime.now().strftime('run_%Y%m%d_%H%M%S')) 477 | _config['run_dir'] = run_dir 478 | 479 | print_config(_run) 480 | 481 | _log.info('Running with seed %d' % seed) 482 | 483 | if not debug: 484 | if os.path.exists(run_dir): 485 | raise EnvironmentError('Run with seed %d already exists' % seed) 486 | os.mkdir(run_dir) 487 | checkpoint_dir = os.path.join(run_dir, 'checkpoints') 488 | json.dump(_config, open(os.path.join(run_dir, 'config.json'), 'w')) 489 | 490 | _log.info('Loading data...') 491 | data_generators, nb_examples = get_generators() 492 | 493 | _log.info('Building model...') 494 | model = build_model(fragment_length) 495 | _log.info(model.summary()) 496 | 497 | optim = make_optimizer() 498 | _log.info('Compiling Model...') 499 | 500 | loss = objectives.categorical_crossentropy 501 | all_metrics = [ 502 | metrics.categorical_accuracy, 503 | categorical_mean_squared_error 504 | ] 505 | if train_with_soft_target_stdev: 506 | loss = make_targets_soft(loss) 507 | if train_only_in_receptive_field: 508 | loss = skip_out_of_receptive_field(loss) 509 | all_metrics = [skip_out_of_receptive_field(m) for m in all_metrics] 510 | 511 | model.compile(optimizer=optim, loss=loss, metrics=all_metrics) 512 | # TODO: Consider gradient weighting making last outputs more important. 
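    # Note: the TensorBoard log dir below ('wavenet_' + timestamp) is created
    # in the current working directory, separate from run_dir and its checkpoints.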
513 | 514 | tictoc = strftime("%a_%d_%b_%Y_%H_%M_%S", gmtime()) 515 | directory_name = tictoc 516 | log_dir = 'wavenet_' + directory_name 517 | os.mkdir(log_dir) 518 | tensorboard = TensorBoard(log_dir=log_dir) 519 | 520 | callbacks = [ 521 | tensorboard, 522 | ReduceLROnPlateau(patience=early_stopping_patience / 2, cooldown=early_stopping_patience / 4, verbose=1), 523 | EarlyStopping(patience=early_stopping_patience, verbose=1), 524 | ] 525 | if not debug: 526 | callbacks.extend([ 527 | ModelCheckpoint(os.path.join(checkpoint_dir, 'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5'), 528 | save_best_only=True), 529 | CSVLogger(os.path.join(run_dir, 'history.csv')), 530 | ]) 531 | 532 | if not debug: 533 | os.mkdir(checkpoint_dir) 534 | _log.info('Starting Training...') 535 | 536 | print("nb_examples['train'] {0}".format(nb_examples['train'])) 537 | print("nb_examples['test'] {0}".format(nb_examples['test'])) 538 | 539 | model.fit_generator(data_generators['train'], 540 | steps_per_epoch=nb_examples['train'] // batch_size, 541 | epochs=nb_epoch, 542 | validation_data=data_generators['test'], 543 | validation_steps=nb_examples['test'] // batch_size, 544 | callbacks=callbacks, 545 | verbose=keras_verbose) 546 | -------------------------------------------------------------------------------- /wavenet_utils.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | from keras.layers.convolutional import Conv1D 3 | from keras.utils.conv_utils import conv_output_length 4 | import tensorflow as tf 5 | 6 | 7 | def asymmetric_temporal_padding(x, left_pad=1, right_pad=1): 8 | '''Pad the middle dimension of a 3D tensor 9 | with "left_pad" zeros left and "right_pad" right. 10 | ''' 11 | pattern = [[0, 0], [left_pad, right_pad], [0, 0]] 12 | return tf.pad(x, pattern) 13 | 14 | 15 | def categorical_mean_squared_error(y_true, y_pred): 16 | """MSE for categorical variables.""" 17 | return K.mean(K.square(K.argmax(y_true, axis=-1) - 18 | K.argmax(y_pred, axis=-1))) 19 | 20 | 21 | class CausalAtrousConvolution1D(Conv1D): 22 | def __init__(self, filters, kernel_size, init='glorot_uniform', activation=None, 23 | padding='valid', strides=1, dilation_rate=1, bias_regularizer=None, 24 | activity_regularizer=None, kernel_constraint=None, bias_constraint=None, use_bias=True, causal=False, **kwargs): 25 | super(CausalAtrousConvolution1D, self).__init__(filters, 26 | kernel_size=kernel_size, 27 | strides=strides, 28 | padding=padding, 29 | dilation_rate=dilation_rate, 30 | activation=activation, 31 | use_bias=use_bias, 32 | kernel_initializer=init, 33 | activity_regularizer=activity_regularizer, 34 | bias_regularizer=bias_regularizer, 35 | kernel_constraint=kernel_constraint, 36 | bias_constraint=bias_constraint, 37 | **kwargs) 38 | 39 | self.causal = causal 40 | if self.causal and padding != 'valid': 41 | raise ValueError("Causal mode dictates border_mode=valid.") 42 | 43 | def compute_output_shape(self, input_shape): 44 | input_length = input_shape[1] 45 | 46 | if self.causal: 47 | input_length += self.dilation_rate[0] * (self.kernel_size[0] - 1) 48 | 49 | length = conv_output_length(input_length, 50 | self.kernel_size[0], 51 | self.padding, 52 | self.strides[0], 53 | dilation=self.dilation_rate[0]) 54 | 55 | return (input_shape[0], length, self.filters) 56 | 57 | def call(self, x): 58 | if self.causal: 59 | x = asymmetric_temporal_padding(x, self.dilation_rate[0] * (self.kernel_size[0] - 1), 0) 60 | return super(CausalAtrousConvolution1D, self).call(x) 61 | 
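
# A minimal shape check for CausalAtrousConvolution1D (an illustrative sketch,
# not part of the training pipeline; guarded so importing this module has no
# side effects):
if __name__ == '__main__':
    import numpy as np
    from keras.engine import Input, Model

    inp = Input(shape=(16, 4))
    out = CausalAtrousConvolution1D(8, 2, dilation_rate=4, causal=True)(inp)
    model = Model(inp, out)
    # Causal left-padding by dilation_rate * (kernel_size - 1) = 4 keeps the
    # temporal dimension intact: expect shape (1, 16, 8).
    print(model.predict(np.zeros((1, 16, 4), dtype='float32')).shape)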
--------------------------------------------------------------------------------