├── Preprocessing and Experiments.ipynb ├── README.md ├── audio_params.json ├── encode_and_reconstruct.py ├── find_similar.py ├── generate.py ├── griffin_lim.py ├── model.py ├── model_iaf.py ├── params.json ├── spec_reader.py ├── train.py └── util.py /README.md: --------------------------------------------------------------------------------
1 | # Spectrogram VAE
2 | TensorFlow implementation of a Variational Autoencoder with Inverse Autoregressive Flows for encoding spectrograms.
3 | 
4 | This is the main model I used for my [NeuralFunk project](https://towardsdatascience.com/neuralfunk-combining-deep-learning-with-sound-design-91935759d628).
5 | 
6 | This code was not really intended to be shared and is quite messy. I might clean it up at some point in the future, but for now be aware that everything is hacky and badly documented.
7 | 
8 | ## Acknowledgments
9 | * The preprocessing as well as the encoder architecture were heavily inspired by [this iPython Notebook](https://gist.github.com/naotokui/a2b331dd206b13a70800e862cfe7da3c).
10 | * A lot of the data-feeding code and many other bits and pieces were adapted from [this Wavenet implementation](https://github.com/ibab/tensorflow-wavenet).
11 | * The Griffin-Lim algorithm was taken from the [Magenta NSynth utils](https://github.com/tensorflow/magenta/blob/master/magenta/models/nsynth/utils.py).
12 | 
13 | ## Overview
14 | Some random experiments, as well as the creation of the dataset for the VAE, can be found in [Preprocessing and Experiments.ipynb](https://github.com/maxfrenzel/SpectrogramVAE/blob/master/Preprocessing%20and%20Experiments.ipynb).
15 | 
16 | The dataset pickle file has to be a dictionary of the form
17 | ```
18 | {
19 |     'filenames' : list_of_filenames,
20 |     'melspecs' : list_of_spectrogram_arrays,
21 |     'actual_lengths' : list_of_audio_len_in_sec
22 | }
23 | ```
24 | and be stored as `dataset.pkl` in the root directory.
25 | 
26 | ### Training the VAE
27 | ```python train.py```
28 | 
29 | ### Generating samples
30 | Generation can be based on:
31 | * Sampling from latent space: `python generate.py`
32 | * Single input file: `python generate.py --file_in filename`
33 | * Multiple input files: `python generate.py --file_in list_of_filenames`
34 | 
35 | ### Encoding audio
36 | * Single file: `python encode_and_reconstruct.py --audio_file filename`
37 | * Full dataset: `python encode_and_reconstruct.py --encode_full true`
38 | 
39 | ### Finding similar files
40 | ```python find_similar.py --target target_audio_file --sample_dirs list_of_dirs_to_search```
41 | 
42 | All the above scripts have other options and uses as well; look at the code for more details.
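For reference, here is a minimal sketch of how such a `dataset.pkl` could be built. This helper is not part of the repository. It assumes the settings from `audio_params.json` (16 kHz audio, 1024-point FFT, hop length 256, 128 mel bands, 2-second clips) and the [-80, 0] dB spectrogram scaling that the training code expects; the exact preprocessing in the notebook may differ in detail:

```python
import glob
import joblib
import librosa
import numpy as np

SAMPLING_RATE = 16000
N_FFT = 1024
HOP_LENGTH = 256
MELSPEC_BANDS = 128
SAMPLE_SECS = 2

def melspec_for_file(path):
    # Load up to 2 s of mono audio, then zero-pad to a fixed length
    y, _ = librosa.load(path, sr=SAMPLING_RATE, mono=True, duration=SAMPLE_SECS)
    y_fixed = np.zeros(SAMPLE_SECS * SAMPLING_RATE)
    y_fixed[:len(y)] = y
    spec = librosa.feature.melspectrogram(y=y_fixed, sr=SAMPLING_RATE,
                                          n_fft=N_FFT, hop_length=HOP_LENGTH,
                                          n_mels=MELSPEC_BANDS)
    # dB relative to the peak, clipped to [-80, 0]
    spec_db = librosa.power_to_db(spec, ref=np.max, top_db=80.0)
    return np.float32(spec_db), len(y) / SAMPLING_RATE

filenames = sorted(glob.glob('samples/**/*.wav', recursive=True))
melspecs, lengths = zip(*(melspec_for_file(f) for f in filenames))
joblib.dump({'filenames': filenames,
             'melspecs': list(melspecs),
             'actual_lengths': list(lengths)},
            'dataset.pkl')
```

(`spec_reader.py` reads the pickle with `joblib.load`, which is why `joblib.dump` is used here rather than the plain `pickle` module.)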
43 | -------------------------------------------------------------------------------- /audio_params.json: --------------------------------------------------------------------------------
1 | {
2 |     "N_FFT": 1024,
3 |     "HOP_LENGTH": 256,
4 |     "SAMPLING_RATE": 16000,
5 |     "MELSPEC_BANDS": 128,
6 |     "sample_secs": 2
7 | }
8 | -------------------------------------------------------------------------------- /encode_and_reconstruct.py: --------------------------------------------------------------------------------
1 | import matplotlib as mpl
2 | mpl.use('TkAgg')
3 | 
4 | import tensorflow as tf
5 | import os
6 | import sys
7 | import time
8 | import joblib
9 | from random import shuffle
10 | import numpy as np
11 | import argparse
12 | import json
13 | 
14 | import librosa
15 | import librosa.display
16 | import matplotlib.pyplot as plt
17 | 
18 | from spec_reader import *
19 | from util import *
20 | from model_iaf import *
21 | from griffin_lim import griffin_lim
22 | 
23 | with open('audio_params.json', 'r') as f:
24 |     param = json.load(f)
25 | 
26 | N_FFT = param['N_FFT']
27 | HOP_LENGTH = param['HOP_LENGTH']
28 | SAMPLING_RATE = param['SAMPLING_RATE']
29 | MELSPEC_BANDS = param['MELSPEC_BANDS']
30 | sample_secs = param['sample_secs']
31 | num_samples_dataset = int(sample_secs * SAMPLING_RATE)
32 | 
33 | logdir = './test_iaf'
34 | max_checkpoints = 5
35 | num_steps = 10000
36 | checkpoint_every = 500
37 | batch_size = 1
38 | model_params = 'params.json'
39 | num_data = -1
40 | encode_batch_size = 128
41 | dataset_file = 'dataset.pkl'
42 | 
43 | def get_arguments():
44 |     def _str_to_bool(s):
45 |         """Convert string to bool (in argparse context)."""
46 |         if s.lower() not in ['true', 'false']:
47 |             raise ValueError('Argument needs to be a '
48 |                              'boolean, got {}'.format(s))
49 |         return {'true': True, 'false': False}[s.lower()]
50 | 
51 |     # TODO: Some of these parameters clash and, if not chosen correctly, crash the script
52 |     parser = argparse.ArgumentParser(description='Spectrogram VAE')
53 |     parser.add_argument('--num_data', type=int, default=num_data,
54 |                         help='How many data points to process. Default: ' + str(num_data) + '.')
55 |     parser.add_argument('--logdir', type=str, default=logdir,
56 |                         help='Directory of the trained model '
57 |                              '(checkpoints and params.json) '
58 |                              'whose encoder and decoder '
59 |                              'will be used. '
60 |                              'Default: ' + logdir + '.')
61 |     parser.add_argument('--model_params', type=str, default=model_params,
62 |                         help='JSON file with the network parameters. Default: ' + model_params + '.')
63 |     parser.add_argument('--audio_file', type=str, default=None,
64 |                         help='Audio file to encode and reconstruct. If not specified, will use the existing dataset instead.')
65 |     parser.add_argument('--dataset_file', type=str, default=dataset_file,
66 |                         help='Dataset pkl file. Default: ' + dataset_file + '.')
67 |     parser.add_argument('--encode_full', type=_str_to_bool, default=False,
68 |                         help='Encode and save entire dataset? Default: ' + str(False) + '.')
69 |     parser.add_argument('--encode_only_new', type=_str_to_bool, default=True,
70 |                         help='Encode only new data points? Default: ' + str(True) + '.')
71 |     parser.add_argument('--process_original_audio', type=_str_to_bool, default=False,
72 |                         help='Process/copy original audio when saving embeddings? Default: ' + str(False) + '.')
73 |     parser.add_argument('--plot_spec', type=_str_to_bool, default=True,
74 |                         help='Plot reconstructed spectrograms? Default: ' + str(True) + '.')
75 |     parser.add_argument('--rec_audio', type=_str_to_bool, default=True,
76 |                         help='Reconstruct and save audio? Default: ' + str(True) + '.')
77 |     return parser.parse_args()
78 | 
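# Added note (not part of the original source): argparse's `type=bool` is a
# pitfall here. bool('false') evaluates to True, since any non-empty string
# is truthy, so a flag like `--encode_full false` would silently enable the
# option. The boolean arguments above are therefore parsed with the
# _str_to_bool helper defined for exactly this purpose, e.g.:
#
#     parser.add_argument('--encode_full', type=_str_to_bool, default=False)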
79 | def save(saver, sess, logdir, step):
80 |     model_name = 'model.ckpt'
81 |     checkpoint_path = os.path.join(logdir, model_name)
82 |     print('Storing checkpoint to {} ...'.format(logdir), end="")
83 |     sys.stdout.flush()
84 | 
85 |     if not os.path.exists(logdir):
86 |         os.makedirs(logdir)
87 | 
88 |     saver.save(sess, checkpoint_path, global_step=step)
89 |     print(' Done.')
90 | 
91 | 
92 | def load(saver, sess, logdir):
93 |     print("Trying to restore saved checkpoints from {} ...".format(logdir),
94 |           end="")
95 | 
96 |     ckpt = tf.train.get_checkpoint_state(logdir)
97 |     if ckpt:
98 |         print(" Checkpoint found: {}".format(ckpt.model_checkpoint_path))
99 |         global_step = int(ckpt.model_checkpoint_path
100 |                           .split('/')[-1]
101 |                           .split('-')[-1])
102 |         print(" Global step was: {}".format(global_step))
103 |         print(" Restoring...", end="")
104 |         saver.restore(sess, ckpt.model_checkpoint_path)
105 |         print(" Done.")
106 |         return global_step
107 |     else:
108 |         print(" No checkpoint found.")
109 |         return None
110 | 
111 | def main():
112 | 
113 |     args = get_arguments()
114 | 
115 |     # Load data unless an input audio file is specified
116 |     if args.audio_file:
117 |         melspecs, _ = get_melspec(args.audio_file, as_tf_input=True)
118 |         filename = os.path.basename(args.audio_file)
119 |     else:
120 |         melspecs, filenames = load_specs(filename=args.dataset_file, return_filenames=True)
121 |         melspecs = (np.float32(melspecs) + 80.0) / 80.0
122 |         # melspecs = 80.0*(np.random.random((10000,128,126))-1.0)
123 | 
124 |     # print(melspecs[0].shape)
125 |     # print(np.expand_dims(np.expand_dims(melspecs[0], 0), 3).shape)
126 | 
127 |     if not os.path.exists(args.logdir):
128 |         os.makedirs(args.logdir)
129 | 
130 |     # Look for original parameters
131 |     print('Loading existing parameters.')
132 |     print(f'{args.logdir}/params.json')
133 |     with open(f'{args.logdir}/params.json', 'r') as f:
134 |         param = json.load(f)
135 | 
136 |     if args.encode_full:
137 |         batch_size = encode_batch_size
138 |         full_batches = len(filenames) // batch_size
139 |         filename_counter = 0
140 |     else:
141 |         batch_size = 1
142 | 
143 |     # Set correct batch size in deconvolution shapes
144 |     deconv_shape = param['deconv_shape']
145 |     for k, s in enumerate(deconv_shape):
146 |         actual_shape = s
147 |         actual_shape[0] = batch_size
148 |         deconv_shape[k] = actual_shape
149 |     param['deconv_shape'] = deconv_shape
150 | 
151 |     # Create coordinator.
152 |     coord = tf.train.Coordinator()
153 | 
154 |     with tf.name_scope('create_inputs'):
155 |         reader = SpectrogramReader(melspecs, coord)
156 |         spec_batch = reader.dequeue(batch_size)
157 | 
158 |     # Create network.
159 |     net = VAEModel(param,
160 |                    batch_size)
161 | 
162 |     # Set up session
163 |     sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
164 |     init = tf.global_variables_initializer()
165 |     sess.run(init)
166 | 
167 |     # Saver for loading checkpoints of the model.
168 |     saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=max_checkpoints)
169 | 
170 |     try:
171 |         saved_global_step = load(saver, sess, args.logdir)
172 | 
173 |     except Exception:
174 |         print("Something went wrong while restoring the checkpoint. "
175 |               "Terminating, since we cannot encode or reconstruct "
176 |               "without a trained model.")
177 |         raise
178 | 
179 |     # Check if directory for saving exists
180 |     out_dir = f'{args.logdir}/reconstructed-{saved_global_step}'
181 |     if not os.path.exists(out_dir):
182 |         os.makedirs(out_dir)
183 |     if args.encode_full:
184 |         out_dir_emb = f'{args.logdir}/embeddings-{saved_global_step}'
185 |         if not os.path.exists(out_dir_emb):
186 |             os.makedirs(out_dir_emb)
187 | 
188 |     if args.audio_file is None:
189 | 
190 |         if args.num_data == -1:
191 |             if args.encode_full:
192 |                 num_batches = full_batches
193 |             else:
194 |                 num_batches = len(filenames)
195 |         else:
196 |             num_batches = args.num_data
197 | 
198 |         for step in range(num_batches):
199 | 
200 |             if batch_size == 1:
201 |                 spec_in = np.expand_dims(np.expand_dims(melspecs[step], 0), 3)
202 |             else:
203 |                 if step < full_batches:
204 |                     spec_in = melspecs[step * batch_size:(step + 1) * batch_size]
205 |                     batch_filenames = filenames[step * batch_size:(step + 1) * batch_size]
206 |                 else:
207 |                     spec_in = melspecs[step * batch_size:]
208 |                     batch_size_discrep = batch_size - spec_in.shape[0]
209 |                     spec_in = np.concatenate([spec_in, np.zeros((batch_size_discrep,) + spec_in.shape[1:], dtype=spec_in.dtype)])
210 |                     batch_filenames = filenames[step * batch_size:]
211 |                 spec_in = np.expand_dims(spec_in, 3)
212 | 
213 |             if args.encode_full:
214 | 
215 |                 print(f'Batch {step} of {full_batches}.')
216 | 
217 |                 # Check if all files exist
218 |                 exists = []
219 |                 filename_counter_check = filename_counter
220 |                 for k, name in enumerate(batch_filenames):
221 |                     name_no_path = os.path.splitext(os.path.split(name)[1])[0]
222 | 
223 |                     dataset_filename_emb = f'{out_dir_emb}/{filename_counter_check} - {name_no_path}.npy'
224 | 
225 |                     # Skip if already exists
226 |                     if os.path.isfile(dataset_filename_emb):
227 |                         exists.append(True)
228 |                     else:
229 |                         exists.append(False)
230 |                     filename_counter_check += 1
231 | 
232 |                 if all(exists):
233 |                     filename_counter = filename_counter_check
234 |                     continue
235 | 
236 |                 emb, out = net.encode_and_reconstruct(spec_in)
237 |                 embedding, output = sess.run([emb, out])
238 | 
239 |                 del spec_in
240 |                 print(embedding.shape)
241 | 
242 |                 # Save
243 |                 for k, name in enumerate(batch_filenames):
244 | 
245 |                     if args.process_original_audio:
246 | 
247 |                         try:
248 | 
249 |                             name_no_path = os.path.splitext(os.path.split(name)[1])[0]
250 | 
251 |                             dataset_filename = f'{out_dir_emb}/{filename_counter} - {name_no_path}.wav'
252 |                             dataset_filename_emb = f'{out_dir_emb}/{filename_counter} - {name_no_path}.npy'
253 | 
254 |                             # Skip if already exists
255 |                             if args.encode_only_new and os.path.isfile(dataset_filename_emb):
256 |                                 filename_counter += 1
257 |                                 continue
258 | 
259 |                             # Load audio file
260 |                             y, sr = librosa.core.load(name, sr=SAMPLING_RATE, mono=True, duration=sample_secs)
261 |                             y_tmp = np.zeros(num_samples_dataset)
262 | 
263 |                             # Truncate or pad
264 |                             if len(y) >= num_samples_dataset:
265 |                                 y_tmp = y[:num_samples_dataset]
266 |                             else:
267 |                                 y_tmp[:len(y)] = y
268 | 
269 |                             # Write to file
270 |                             librosa.output.write_wav(dataset_filename, y_tmp, sr, norm=True)
271 |                             np.save(dataset_filename_emb, embedding[k])
272 | 
273 |                             # # Also plot reconstruction
274 |                             # melspec = (np.squeeze(output[k]) - 1.0) * 80.0
275 |                             # plt.figure()
276 |                             # ax1 = plt.subplot(2, 1, 1)
277 |                             #
278 |                             # librosa.display.specshow((melspecs[step * batch_size + k] - 1.0) * 80.0, sr=SAMPLING_RATE, y_axis='mel',
279 |                             #                          x_axis='time',
280 |                             #                          hop_length=HOP_LENGTH)
281 |                             # plt.title('Original: ' + name_no_path)
282 |                             # ax2 = plt.subplot(2, 1, 2, sharex=ax1)
283 |                             # librosa.display.specshow(melspec, sr=SAMPLING_RATE, y_axis='mel', x_axis='time',
284 |                             #                          hop_length=HOP_LENGTH)
285 |                             # plt.title('Reconstruction')
286 |                             # plt.tight_layout()
287 |                             # plt.savefig(f'{out_dir_emb}/{filename_counter} - {name_no_path}.png')
288 |                             # plt.close()
289 | 
290 |                             filename_counter += 1
291 | 
292 |                         except Exception:
293 |                             print(f'Skipping {name}: could not process audio.')
294 | 
295 |                     else:
296 |                         name_no_path = os.path.splitext(os.path.split(name)[1])[0]
297 |                         dataset_filename_emb = f'{out_dir_emb}/{filename_counter} - {name_no_path}.npy'
298 |                         # Write embedding to file
299 |                         np.save(dataset_filename_emb, embedding[k])
300 | 
301 |                         filename_counter += 1
302 | 
303 |                 del emb, out
304 |                 del embedding, output
305 | 
306 |             else:
307 | 
308 |                 emb, out = net.encode_and_reconstruct(spec_in)
309 |                 embedding, output = sess.run([emb, out])
310 | 
311 |                 melspec = (np.squeeze(output[0]) - 1.0) * 80.0
312 | 
313 |                 if args.plot_spec:
314 |                     plt.figure()
315 |                     ax1 = plt.subplot(2, 1, 1)
316 | 
317 |                     librosa.display.specshow((melspecs[step] - 1.0) * 80.0, sr=SAMPLING_RATE, y_axis='mel', x_axis='time',
318 |                                              hop_length=HOP_LENGTH)
319 |                     plt.title('Original: ' + os.path.basename(filenames[step]))
320 |                     ax2 = plt.subplot(2, 1, 2, sharex=ax1)
321 |                     librosa.display.specshow(melspec, sr=SAMPLING_RATE, y_axis='mel', x_axis='time',
322 |                                              hop_length=HOP_LENGTH)
323 |                     plt.title('Reconstruction')
324 |                     plt.tight_layout()
325 |                     plt.savefig(f'{out_dir}/reconstructed-{step}.png')
326 |                     plt.close()
327 | 
328 |                 if args.rec_audio:
329 |                     audio = griffin_lim(melspec)
330 |                     audio_file = f'{out_dir}/reconstructed-{step}.wav'
331 |                     librosa.output.write_wav(audio_file, audio / np.max(np.abs(audio)), sr=SAMPLING_RATE)
332 | 
333 |     else:
334 | 
335 |         spec_in = melspecs
336 | 
337 |         emb, out = net.encode_and_reconstruct(spec_in)
338 |         embedding, output = sess.run([emb, out])
339 | 
340 |         melspec = (np.squeeze(output[0]) - 1.0) * 80.0
341 |         melspec_in = (np.squeeze(melspecs) - 1.0) * 80.0
342 | 
343 |         # Save embedding
344 |         np.save(f'{out_dir}/reconstructed-{os.path.splitext(filename)[0]}.npy', embedding[0])
345 | 
346 |         # Plot (melspec_in is already in dB, so it is plotted directly)
347 |         plt.figure()
348 |         ax1 = plt.subplot(2, 1, 1)
349 | 
350 |         librosa.display.specshow(melspec_in, sr=SAMPLING_RATE, y_axis='mel', x_axis='time',
351 |                                  hop_length=HOP_LENGTH)
352 |         plt.title('Original: ' + os.path.basename(args.audio_file))
353 |         ax2 = plt.subplot(2, 1, 2, sharex=ax1)
354 |         librosa.display.specshow(melspec, sr=SAMPLING_RATE, y_axis='mel', x_axis='time',
355 |                                  hop_length=HOP_LENGTH)
356 |         plt.title('Reconstruction')
357 |         plt.tight_layout()
358 |         plt.savefig(f'{out_dir}/reconstructed-{os.path.splitext(filename)[0]}.png')
359 |         plt.close()
360 | 
361 |         # Reconstruct and save audio
362 |         audio = griffin_lim(melspec)
363 |         audio_file = f'{out_dir}/reconstructed-{os.path.splitext(filename)[0]}.wav'
364 |         librosa.output.write_wav(audio_file, audio / np.max(np.abs(audio)), sr=SAMPLING_RATE)
365 | 
366 | 
367 | if __name__ == '__main__':
368 |     main() --------------------------------------------------------------------------------
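Once `--encode_full true` has written one `.npy` embedding per file, nearest neighbours can also be searched offline, without loading the model. A minimal sketch (not part of the repository; it assumes the default `./test_iaf/embeddings-<step>/` layout produced by the script above, with `<step>` being the checkpoint step):

```python
import glob
import numpy as np

emb_files = sorted(glob.glob('test_iaf/embeddings-*/*.npy'))
embs = np.stack([np.load(f) for f in emb_files])

def nearest(query_file, k=5):
    query = np.load(query_file)
    # Euclidean distance in latent space, the same metric find_similar.py uses
    dists = np.linalg.norm(embs - query, axis=1)
    return [(emb_files[i], float(dists[i])) for i in np.argsort(dists)[:k]]

for path, dist in nearest(emb_files[0]):
    print(f'{dist:.3f}  {path}')
```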
/find_similar.py: --------------------------------------------------------------------------------
1 | # Give an input file and target folder or file, as well as the maximum number of most similar things to return
2 | 
3 | import os
4 | import sys
5 | import argparse
6 | from collections import deque
7 | from scipy.spatial import distance
8 | from tqdm import tqdm
9 | 
10 | from griffin_lim import *
11 | from model_iaf import *
12 | from util import *
13 | 
14 | # Print the most similar file whenever a new one is found
15 | # If multiple clips per file are enabled, also say which onset it was
16 | 
17 | with open('audio_params.json', 'r') as f:
18 |     param = json.load(f)
19 | 
20 | N_FFT = param['N_FFT']
21 | HOP_LENGTH = param['HOP_LENGTH']
22 | SAMPLING_RATE = param['SAMPLING_RATE']
23 | MELSPEC_BANDS = param['MELSPEC_BANDS']
24 | sample_secs = param['sample_secs']
25 | num_samples_dataset = int(sample_secs * SAMPLING_RATE)
26 | 
27 | logdir = './test_iaf'
28 | max_checkpoints = 5
29 | num_steps = 10000
30 | checkpoint_every = 500
31 | batch_size = 128
32 | model_params = 'params.json'
33 | 
34 | def get_arguments():
35 |     def _str_to_bool(s):
36 |         """Convert string to bool (in argparse context)."""
37 |         if s.lower() not in ['true', 'false']:
38 |             raise ValueError('Argument needs to be a '
39 |                              'boolean, got {}'.format(s))
40 |         return {'true': True, 'false': False}[s.lower()]
41 | 
42 |     parser = argparse.ArgumentParser(description='Spectrogram VAE')
43 |     parser.add_argument('--logdir', type=str, default=logdir,
44 |                         help='Directory of the trained model '
45 |                              '(checkpoints and params.json) '
46 |                              'used to compute the embeddings. '
47 |                              'Default: ' + logdir + '.')
48 |     parser.add_argument('--target', type=str, default=None,
49 |                         help='File for which similar sounds are to be found.')
50 |     parser.add_argument('--sample_dirs', type=str, nargs='+',
51 |                         help='Root directories in which to look for samples.')
52 |     parser.add_argument('--num_to_keep', type=int, default=5,
53 |                         help='Keep this many most similar files.')
54 |     parser.add_argument('--batch_size', type=int, default=batch_size,
55 |                         help='Batch Size.')
56 |     parser.add_argument('--detect_onset', type=_str_to_bool, default=False,
57 |                         help='Remove initial silence.')
58 |     parser.add_argument('--search_within_file', type=_str_to_bool, default=False,
59 |                         help='If true, not only encode the beginning, but detect transients and treat each separately.')
60 |     return parser.parse_args()
61 | 
62 | def save(saver, sess, logdir, step):
63 |     model_name = 'model.ckpt'
64 |     checkpoint_path = os.path.join(logdir, model_name)
65 |     print('Storing checkpoint to {} ...'.format(logdir), end="")
66 |     sys.stdout.flush()
67 | 
68 |     if not os.path.exists(logdir):
69 |         os.makedirs(logdir)
70 | 
71 |     saver.save(sess, checkpoint_path, global_step=step)
72 |     print(' Done.')
73 | 
74 | 
75 | def load(saver, sess, logdir):
76 |     print("Trying to restore saved checkpoints from {} ...".format(logdir),
77 |           end="")
78 | 
79 |     ckpt = tf.train.get_checkpoint_state(logdir)
80 |     if ckpt:
81 |         print(" Checkpoint found: {}".format(ckpt.model_checkpoint_path))
82 |         global_step = int(ckpt.model_checkpoint_path
83 |                           .split('/')[-1]
84 |                           .split('-')[-1])
85 |         print(" Global step was: {}".format(global_step))
86 |         print(" Restoring...", end="")
87 |         saver.restore(sess, ckpt.model_checkpoint_path)
88 |         print(" Done.")
89 |         return global_step
90 |     else:
91 |         print(" No checkpoint found.")
92 |         return None
93 | 
94 | 
95 | def main():
96 | 
97 |     args = get_arguments()
98 | 
99 |     # Look for original parameters
100 |     print('Loading existing parameters.')
101 |     print(f'{args.logdir}/params.json')
102 |     with open(f'{args.logdir}/params.json', 'r') as f:
103 |         param = json.load(f)
104 | 
105 |     batch_size = args.batch_size
106 | 
107 |     # Set correct batch size in deconvolution shapes
108 |     deconv_shape = param['deconv_shape']
109 |     for k, s in enumerate(deconv_shape):
110 |         actual_shape = s
111 |         actual_shape[0] = batch_size
112 |         deconv_shape[k] = actual_shape
113 |     param['deconv_shape'] = deconv_shape
114 | 
115 |     # Find all audio files in directories
116 |     audio_files = []
117 | 
118 |     for root_dir in args.sample_dirs:
119 | 
for dirName, subdirList, fileList in os.walk(root_dir, topdown=False): 120 | for fname in fileList: 121 | if os.path.splitext(fname)[1] in ['.wav', '.aiff', '.WAV']: 122 | audio_files.append('%s/%s' % (dirName, fname)) 123 | 124 | # Create network. 125 | net = VAEModel(param, 126 | batch_size) 127 | 128 | # Set up session 129 | sess = tf.Session(config=tf.ConfigProto(log_device_placement=False)) 130 | init = tf.global_variables_initializer() 131 | sess.run(init) 132 | 133 | # Saver for storing checkpoints of the model. 134 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=max_checkpoints) 135 | 136 | try: 137 | saved_global_step = load(saver, sess, args.logdir) 138 | 139 | except: 140 | print("Something went wrong while restoring checkpoint. " 141 | "We will terminate training to avoid accidentally overwriting " 142 | "the previous model.") 143 | raise 144 | 145 | # Get target embeddings 146 | target_spec_single, _ = get_melspec(args.target, as_tf_input=True) 147 | target_spec = np.float32(np.zeros((batch_size, 148 | target_spec_single.shape[1], 149 | target_spec_single.shape[2], 150 | 1))) 151 | target_spec[0] = target_spec_single[0] 152 | 153 | emb, _ = net.encode_and_reconstruct(target_spec) 154 | embedding_target = sess.run(emb)[0] 155 | 156 | # Go through all found files and compare distance 157 | similar_files = deque([None]) 158 | distances = deque([float('inf')]) 159 | 160 | full_batches = len(audio_files) // batch_size 161 | 162 | print(f'Starting to compare to {len(audio_files)} files.') 163 | 164 | try: 165 | for k in tqdm(range(full_batches+1)): 166 | 167 | # Prepare batch 168 | spec_list = [] 169 | for j in range(batch_size): 170 | # Exception for last batch where index will be out of range 171 | try: 172 | comparison_spec, _ = get_melspec(audio_files[k*batch_size+j], as_tf_input=True) 173 | except: 174 | comparison_spec = np.zeros_like(target_spec_single) 175 | spec_list.append(comparison_spec) 176 | comparison_specs = np.concatenate(spec_list) 177 | 178 | emb_comp, _ = net.encode_and_reconstruct(comparison_specs) 179 | embedding_comp = sess.run(emb_comp) 180 | 181 | # Compare each individually 182 | for j in range(batch_size): 183 | 184 | if k*batch_size+j >= len(audio_files): 185 | break 186 | 187 | # Get distance 188 | dist = distance.euclidean(embedding_comp[j], embedding_target) 189 | 190 | if dist >= max(distances): 191 | continue 192 | 193 | # Find position where it should go 194 | for m in range(len(similar_files)): 195 | if dist < distances[m]: 196 | # print(f'{k*batch_size+j},{k},{j},{m} Inserting Distance: {dist}; File: {audio_files[k*batch_size+j]}') 197 | similar_files.insert(m, audio_files[k * batch_size + j]) 198 | distances.insert(m, dist) 199 | if m == 0: 200 | print('New most similar file found.') 201 | print(f'Distance: {dist}; File: {audio_files[k*batch_size+j]}') 202 | break 203 | 204 | # Check if list grew beyond desired size 205 | if len(similar_files) > args.num_to_keep: 206 | similar_files.pop() 207 | distances.pop() 208 | 209 | except KeyboardInterrupt: 210 | print() 211 | finally: 212 | print('Search complete. 
Most similar files:')
213 |         for k, file in enumerate(similar_files):
214 |             # print()
215 |             print(f'{k} -- Distance: {distances[k]}; File: {file}')
216 | 
217 | 
218 | if __name__ == '__main__':
219 |     main() -------------------------------------------------------------------------------- /generate.py: --------------------------------------------------------------------------------
1 | import matplotlib as mpl
2 | mpl.use('TkAgg')
3 | 
4 | import os
5 | import sys
6 | import argparse
7 | 
8 | from griffin_lim import *
9 | from model_iaf import *
10 | from util import *
11 | 
12 | import librosa
13 | import librosa.display
14 | import matplotlib.pyplot as plt
15 | 
16 | with open('audio_params.json', 'r') as f:
17 |     param = json.load(f)
18 | 
19 | N_FFT = param['N_FFT']
20 | HOP_LENGTH = param['HOP_LENGTH']
21 | SAMPLING_RATE = param['SAMPLING_RATE']
22 | MELSPEC_BANDS = param['MELSPEC_BANDS']
23 | sample_secs = param['sample_secs']
24 | num_samples_dataset = int(sample_secs * SAMPLING_RATE)
25 | 
26 | logdir = './logdir'
27 | max_checkpoints = 5
28 | num_steps = 10000
29 | checkpoint_every = 500
30 | batch_size = 64
31 | learning_rate = 1e-3
32 | beta = 1.0
33 | model_params = 'params.json'
34 | 
35 | def get_arguments():
36 |     def _str_to_bool(s):
37 |         """Convert string to bool (in argparse context)."""
38 |         if s.lower() not in ['true', 'false']:
39 |             raise ValueError('Argument needs to be a '
40 |                              'boolean, got {}'.format(s))
41 |         return {'true': True, 'false': False}[s.lower()]
42 | 
43 |     parser = argparse.ArgumentParser(description='Spectrogram VAE')
44 |     parser.add_argument('--logdir', type=str, default=logdir,
45 |                         help='Directory of the trained model '
46 |                              '(checkpoints and params.json) '
47 |                              'used for generation. '
48 |                              'Default: ' + logdir + '.')
49 |     parser.add_argument('--file_in', type=str, nargs='*', default=[],
50 |                         help='Input file(s) from which to generate new audio. If none are given, sample a random point in latent space.')
51 |     parser.add_argument('--file_out', type=str, default='generated',
52 |                         help='Output file for storing new audio. 
') 53 | return parser.parse_args() 54 | 55 | def save(saver, sess, logdir, step): 56 | model_name = 'model.ckpt' 57 | checkpoint_path = os.path.join(logdir, model_name) 58 | print('Storing checkpoint to {} ...'.format(logdir), end="") 59 | sys.stdout.flush() 60 | 61 | if not os.path.exists(logdir): 62 | os.makedirs(logdir) 63 | 64 | saver.save(sess, checkpoint_path, global_step=step) 65 | print(' Done.') 66 | 67 | 68 | def load(saver, sess, logdir): 69 | print("Trying to restore saved checkpoints from {} ...".format(logdir), 70 | end="") 71 | 72 | ckpt = tf.train.get_checkpoint_state(logdir) 73 | if ckpt: 74 | print(" Checkpoint found: {}".format(ckpt.model_checkpoint_path)) 75 | global_step = int(ckpt.model_checkpoint_path 76 | .split('/')[-1] 77 | .split('-')[-1]) 78 | print(" Global step was: {}".format(global_step)) 79 | print(" Restoring...", end="") 80 | saver.restore(sess, ckpt.model_checkpoint_path) 81 | print(" Done.") 82 | return global_step 83 | else: 84 | print(" No checkpoint found.") 85 | return None 86 | 87 | 88 | def main(): 89 | 90 | args = get_arguments() 91 | 92 | num_files = len(args.file_in) 93 | 94 | if num_files > 0: 95 | # Convert audio files to spectrograms 96 | specs = [] 97 | for filename in args.file_in: 98 | spec, _ = get_melspec(filename) 99 | specs.append(np.expand_dims(spec, axis=0)) 100 | specs_in = np.concatenate(specs) 101 | specs_in = (np.float32(specs_in) + 80.0) / 80.0 102 | specs_in = np.expand_dims(specs_in, axis=3) 103 | 104 | batch_size = num_files 105 | else: 106 | batch_size = 1 107 | 108 | # Look for original parameters 109 | print('Loading existing parameters.') 110 | print(f'{args.logdir}/params.json') 111 | with open(f'{args.logdir}/params.json', 'r') as f: 112 | param = json.load(f) 113 | 114 | # Set correct batch size in deconvolution shapes 115 | deconv_shape = param['deconv_shape'] 116 | for k, s in enumerate(deconv_shape): 117 | actual_shape = s 118 | actual_shape[0] = batch_size 119 | deconv_shape[k] = actual_shape 120 | param['deconv_shape'] = deconv_shape 121 | 122 | # Create network. 123 | net = VAEModel(param, 124 | batch_size) 125 | 126 | # Set up session 127 | sess = tf.Session(config=tf.ConfigProto(log_device_placement=False)) 128 | init = tf.global_variables_initializer() 129 | sess.run(init) 130 | 131 | # Saver for storing checkpoints of the model. 132 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=max_checkpoints) 133 | 134 | try: 135 | saved_global_step = load(saver, sess, args.logdir) 136 | 137 | except: 138 | print("Something went wrong while restoring checkpoint. 
" 139 | "We will terminate training to avoid accidentally overwriting " 140 | "the previous model.") 141 | raise 142 | 143 | # Check if directory for saving exists 144 | out_dir = f'{args.logdir}/generated-{saved_global_step}' 145 | if not os.path.exists(out_dir): 146 | os.makedirs(out_dir) 147 | 148 | if num_files > 0: 149 | 150 | # Get embeddings 151 | emb, out = net.encode_and_reconstruct(specs_in) 152 | embedding, output = sess.run([emb, out]) 153 | 154 | # Average over embeddings 155 | embedding_mean = np.mean(embedding, axis=0) 156 | 157 | # Add zeros to send through same net with same batch size 158 | embedding_mean_batch = np.float32(np.zeros((batch_size,param['dim_latent']))) 159 | embedding_mean_batch[0] = embedding_mean 160 | 161 | else: 162 | embedding_mean_batch = np.float32(np.random.standard_normal((1, param['dim_latent']))) 163 | 164 | 165 | # Decode the mean embedding 166 | out_mean = net.decode(embedding_mean_batch) 167 | output_mean = sess.run(out_mean) 168 | 169 | spec_out = (np.squeeze(output_mean[0])-1.0)*80.0 170 | # spec_out1 = (np.squeeze(output[0])-1.0)*80.0 171 | 172 | # Plot 173 | plt.figure(figsize=(10, (num_files+1)*4)) 174 | 175 | if num_files > 0: 176 | ax1 = plt.subplot(num_files+1, 1, 1) 177 | librosa.display.specshow(np.squeeze(specs[0]), sr=SAMPLING_RATE, y_axis='mel', x_axis='time', 178 | hop_length=HOP_LENGTH) 179 | plt.title(f'Original 1: ' + os.path.basename(args.file_in[0])) 180 | for k in range(1,num_files): 181 | plt.subplot(num_files + 1, 1, k+1, sharex=ax1) 182 | librosa.display.specshow(np.squeeze(specs[k]), sr=SAMPLING_RATE, y_axis='mel', x_axis='time', 183 | hop_length=HOP_LENGTH) 184 | plt.title(f'Original {k+1}: ' + os.path.basename(args.file_in[k])) 185 | plt.subplot(num_files+1, 1, num_files+1, sharex=ax1) 186 | else: 187 | ax1 = plt.subplot(1, 1, 1) 188 | librosa.display.specshow(spec_out, sr=SAMPLING_RATE, y_axis='mel', x_axis='time', 189 | hop_length=HOP_LENGTH) 190 | plt.title('Combined Reconstruction') 191 | plt.tight_layout() 192 | plt.savefig(f'{out_dir}/{args.file_out}.png') 193 | plt.close() 194 | 195 | # Reconstruct audio 196 | audio = griffin_lim(spec_out) 197 | audio_file = f'{out_dir}/{args.file_out}.wav' 198 | librosa.output.write_wav(audio_file, audio / np.max(audio), sr=SAMPLING_RATE) 199 | 200 | 201 | if __name__ == '__main__': 202 | main() -------------------------------------------------------------------------------- /griffin_lim.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import json 4 | 5 | with open('audio_params.json', 'r') as f: 6 | param = json.load(f) 7 | 8 | N_FFT = param['N_FFT'] 9 | HOP_LENGTH = param['HOP_LENGTH'] 10 | SAMPLING_RATE = param['SAMPLING_RATE'] 11 | MELSPEC_BANDS = param['MELSPEC_BANDS'] 12 | 13 | _inv_mel_basis = None 14 | 15 | 16 | def _mel_to_linear(mel_spectrogram): 17 | global _inv_mel_basis 18 | if _inv_mel_basis is None: 19 | _inv_mel_basis = np.linalg.pinv(_build_mel_basis()) 20 | return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram)) 21 | 22 | 23 | def _db_to_amp(x): 24 | return np.power(10.0, x * 0.05) 25 | 26 | 27 | def _denormalize(S): 28 | return (np.clip(S, 0, 1) * -80.0) + 80.0 29 | 30 | 31 | def inv_magphase(mag, phase_angle): 32 | phase = np.cos(phase_angle) + 1.j * np.sin(phase_angle) 33 | return mag * phase 34 | 35 | 36 | def _build_mel_basis(): 37 | n_fft = N_FFT 38 | return librosa.filters.mel(SAMPLING_RATE, n_fft, n_mels=MELSPEC_BANDS) 39 | 40 | 41 | def griffin_lim(melspec, 
num_iters=10, phase_angle=0.0, n_fft=N_FFT, hop=HOP_LENGTH): 42 | """Iterative algorithm for phase retrival from a melspectrogram. 43 | 44 | Args: 45 | mag: Magnitude spectrogram. 46 | phase_angle: Initial condition for phase. 47 | n_fft: Size of the FFT. 48 | hop: Stride of FFT. Defaults to n_fft/2. 49 | num_iters: Griffin-Lim iterations to perform. 50 | 51 | Returns: 52 | audio: 1-D array of float32 sound samples. 53 | """ 54 | mag = _mel_to_linear(_db_to_amp(melspec)) 55 | 56 | fft_config = dict(n_fft=n_fft, win_length=n_fft, hop_length=hop, center=True) 57 | ifft_config = dict(win_length=n_fft, hop_length=hop, center=True) 58 | complex_specgram = inv_magphase(mag, phase_angle) 59 | for i in range(num_iters): 60 | audio = librosa.istft(complex_specgram, **ifft_config) 61 | if i != num_iters - 1: 62 | complex_specgram = librosa.stft(audio, **fft_config) 63 | _, phase = librosa.magphase(complex_specgram) 64 | phase_angle = np.angle(phase) 65 | complex_specgram = inv_magphase(mag, phase_angle) 66 | return audio -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def create_variable(name, shape): 5 | '''Create a convolution filter variable with the specified name and shape, 6 | and initialize it using Xavier initialition.''' 7 | initializer = tf.contrib.layers.xavier_initializer_conv2d() 8 | variable = tf.Variable(initializer(shape=shape), name=name) 9 | return variable 10 | 11 | 12 | def create_bias_variable(name, shape): 13 | '''Create a bias variable with the specified name and shape and initialize 14 | it to zero.''' 15 | initializer = tf.constant_initializer(value=0.001, dtype=tf.float32) 16 | return tf.Variable(initializer(shape=shape), name) 17 | 18 | 19 | # def upsample(net, name, stride, mode='ZEROS'): 20 | # """ 21 | # Imitate reverse operation of Max-Pooling by either placing original max values 22 | # into a fixed postion of upsampled cell: 23 | # [0.9] =>[[.9, 0], (stride=2) 24 | # [ 0, 0]] 25 | # or copying the value into each cell: 26 | # [0.9] =>[[.9, .9], (stride=2) 27 | # [ .9, .9]] 28 | # :param net: 4D input tensor with [batch_size, width, heights, channels] axis 29 | # :param stride: 30 | # :param mode: string 'ZEROS' or 'COPY' indicating which value to use for undefined cells 31 | # :return: 4D tensor of size [batch_size, width*stride, heights*stride, channels] 32 | # """ 33 | # assert mode in ['COPY', 'ZEROS'] 34 | # with tf.name_scope('Upsampling'): 35 | # net = _upsample_along_axis(net, 2, stride[1], mode=mode) 36 | # net = _upsample_along_axis(net, 1, stride[0], mode=mode) 37 | # return net 38 | 39 | 40 | # def _upsample_along_axis(volume, axis, stride, mode='ZEROS'): 41 | # shape = volume.get_shape().as_list() 42 | 43 | # assert mode in ['COPY', 'ZEROS'] 44 | # assert 0 <= axis < len(shape) 45 | 46 | # target_shape = shape[:] 47 | # target_shape[axis] *= stride 48 | 49 | # print(volume.dtype) 50 | # print(shape) 51 | 52 | # padding = tf.zeros(shape, dtype=volume.dtype) if mode == 'ZEROS' else volume 53 | # parts = [volume] + [padding for _ in range(stride - 1)] 54 | # volume = tf.concat(parts, min(axis+1, len(shape)-1)) 55 | 56 | # volume = tf.reshape(volume, target_shape) 57 | # return volume 58 | 59 | def upsample(value, name, factor=[2, 2]): 60 | size = [int(value.shape[1] * factor[0]), int(value.shape[2] * factor[1])] 61 | with tf.name_scope(name): 62 | out = tf.image.resize_bilinear(value, size=size, 
align_corners=None, name=None) 63 | return out 64 | 65 | 66 | def upsample2(value, name, output_shape): 67 | size = [int(output_shape[1]), int(output_shape[2])] 68 | with tf.name_scope(name): 69 | out = tf.image.resize_bilinear(value, size=size, align_corners=None, name=None) 70 | return out 71 | 72 | 73 | def two_d_conv(value, filter_, pool_kernel=[2, 2], name='two_d_conv'): 74 | out = tf.nn.conv2d(value, filter_, strides=[1, 1, 1, 1], padding='SAME') 75 | out = tf.contrib.layers.max_pool2d(out, pool_kernel) 76 | 77 | return out 78 | 79 | 80 | def two_d_deconv(value, filter_, deconv_shape, pool_kernel=[2, 2], name='two_d_conv'): 81 | out = upsample2(value, 'unpool', deconv_shape) 82 | # print(out) 83 | out = tf.nn.conv2d_transpose(out, filter_, output_shape=deconv_shape, strides=[1, 1, 1, 1], padding='SAME') 84 | # print(out) 85 | 86 | return out 87 | 88 | 89 | class VAEModel(object): 90 | 91 | def __init__(self, 92 | param, 93 | batch_size, 94 | activation=tf.nn.elu, 95 | activation_conv=tf.nn.elu, 96 | activation_nf=tf.nn.elu, 97 | encode=False): 98 | 99 | self.param = param 100 | self.batch_size = batch_size 101 | self.activation = activation 102 | self.activation_conv = activation_conv 103 | self.activation_nf = activation_nf 104 | self.encode = encode 105 | self.layers_enc = len(param['conv_channels']) 106 | self.layers_dec = self.layers_enc 107 | self.conv_out_shape = [7, 7] 108 | self.conv_out_units = self.conv_out_shape[0] * self.conv_out_shape[1] * param['conv_channels'][-1] 109 | self.cells_hidden = 512 110 | 111 | self.variables = self._create_variables() 112 | 113 | def _create_variables(self): 114 | '''This function creates all variables used by the network. 115 | This allows us to share them between multiple calls to the loss 116 | function and generation function.''' 117 | 118 | var = dict() 119 | 120 | with tf.variable_scope('VAE'): 121 | 122 | with tf.variable_scope("Encoder"): 123 | 124 | var['encoder_conv'] = list() 125 | with tf.variable_scope('conv_stack'): 126 | 127 | for l in range(self.layers_enc): 128 | 129 | with tf.variable_scope('layer{}'.format(l)): 130 | current = dict() 131 | 132 | if l == 0: 133 | channels_in = 1 134 | else: 135 | channels_in = self.param['conv_channels'][l - 1] 136 | channels_out = self.param['conv_channels'][l] 137 | 138 | current['filter'] = create_variable("filter", 139 | [3, 3, channels_in, channels_out]) 140 | # current['bias'] = create_bias_variable("bias", 141 | # [channels_out]) 142 | var['encoder_conv'].append(current) 143 | 144 | with tf.variable_scope('fully_connected'): 145 | 146 | layer = dict() 147 | 148 | layer['W_z0'] = create_variable("W_z0", 149 | shape=[self.conv_out_units, self.cells_hidden]) 150 | layer['b_z0'] = create_bias_variable("b_z0", 151 | shape=[1, self.cells_hidden]) 152 | 153 | layer['W_mu'] = create_variable("W_mu", 154 | shape=[self.cells_hidden, self.param['dim_latent']]) 155 | layer['W_logvar'] = create_variable("W_logvar", 156 | shape=[self.cells_hidden, self.param['dim_latent']]) 157 | layer['b_mu'] = create_bias_variable("b_mu", 158 | shape=[1, self.param['dim_latent']]) 159 | layer['b_logvar'] = create_bias_variable("b_logvar", 160 | shape=[1, self.param['dim_latent']]) 161 | 162 | var['encoder_fc'] = layer 163 | 164 | with tf.variable_scope("Decoder"): 165 | 166 | with tf.variable_scope('fully_connected'): 167 | layer = dict() 168 | 169 | layer['W_z'] = create_variable("W_z", 170 | shape=[self.param['dim_latent'], self.conv_out_units]) 171 | layer['b_z'] = create_bias_variable("b_z", 172 | shape=[1, 
self.conv_out_units]) 173 | 174 | var['decoder_fc'] = layer 175 | 176 | var['decoder_deconv'] = list() 177 | with tf.variable_scope('deconv_stack'): 178 | 179 | for l in range(self.layers_enc): 180 | with tf.variable_scope('layer{}'.format(l)): 181 | current = dict() 182 | 183 | channels_in = self.param['conv_channels'][-1 - l] 184 | if l == self.layers_enc - 1: 185 | channels_out = 1 186 | else: 187 | channels_out = self.param['conv_channels'][-l - 2] 188 | 189 | current['filter'] = create_variable("filter", 190 | [3, 3, channels_out, channels_in]) 191 | # current['bias'] = create_bias_variable("bias", 192 | # [channels_out]) 193 | var['decoder_deconv'].append(current) 194 | 195 | return var 196 | 197 | def _create_network(self, input_batch, keep_prob=1.0, encode=False): 198 | 199 | # Do encoder calculation 200 | encoder_hidden = input_batch 201 | for l in range(self.layers_enc): 202 | # print(encoder_hidden) 203 | encoder_hidden = two_d_conv(encoder_hidden, self.variables['encoder_conv'][l]['filter'], 204 | self.param['max_pooling'][l]) 205 | encoder_hidden = self.activation_conv(encoder_hidden) 206 | 207 | # print(encoder_hidden) 208 | 209 | encoder_hidden = tf.reshape(encoder_hidden, [-1, self.conv_out_units]) 210 | 211 | # print(encoder_hidden) 212 | 213 | # Additional non-linearity between encoder hidden state and prediction of mu_0,sigma_0 214 | mu_logvar_hidden = tf.nn.dropout(self.activation(tf.matmul(encoder_hidden, 215 | self.variables['encoder_fc']['W_z0']) 216 | + self.variables['encoder_fc']['b_z0']), 217 | keep_prob=keep_prob) 218 | 219 | # print(mu_logvar_hidden) 220 | 221 | encoder_mu = tf.add(tf.matmul(mu_logvar_hidden, self.variables['encoder_fc']['W_mu']), 222 | self.variables['encoder_fc']['b_mu'], name='ZMu') 223 | encoder_logvar = tf.add(tf.matmul(mu_logvar_hidden, self.variables['encoder_fc']['W_logvar']), 224 | self.variables['encoder_fc']['b_logvar'], name='ZLogVar') 225 | 226 | # print(encoder_mu) 227 | 228 | # Convert log variance into standard deviation 229 | encoder_std = tf.exp(0.5 * encoder_logvar) 230 | 231 | # Sample epsilon 232 | epsilon = tf.random_normal(tf.shape(encoder_std), name='epsilon') 233 | 234 | if encode: 235 | z0 = tf.identity(encoder_mu, name='LatentZ0') 236 | else: 237 | z0 = tf.identity(tf.add(encoder_mu, tf.multiply(encoder_std, epsilon), 238 | name='LatentZ0')) 239 | 240 | # print(z0) 241 | 242 | # Fully connected 243 | decoder_hidden = tf.nn.dropout(self.activation(tf.matmul(z0, self.variables['decoder_fc']['W_z']) 244 | + self.variables['decoder_fc']['b_z']), 245 | keep_prob=keep_prob) 246 | 247 | # print(decoder_hidden) 248 | 249 | # Reshape 250 | decoder_hidden = tf.reshape(decoder_hidden, [-1, self.conv_out_shape[0], self.conv_out_shape[1], 251 | self.param['conv_channels'][-1]]) 252 | 253 | for l in range(self.layers_enc): 254 | # print(decoder_hidden) 255 | 256 | pool_kernel = self.param['max_pooling'][-1 - l] 257 | decoder_hidden = two_d_deconv(decoder_hidden, self.variables['decoder_deconv'][l]['filter'], 258 | self.param['deconv_shape'][l], pool_kernel) 259 | if l < self.layers_enc - 1: 260 | decoder_hidden = self.activation_conv(decoder_hidden) 261 | 262 | decoder_output = tf.nn.sigmoid(decoder_hidden) 263 | 264 | # print(decoder_output) 265 | 266 | # return decoder_output, encoder_hidden, encoder_logvar, encoder_std 267 | return decoder_output, encoder_mu, encoder_logvar, encoder_std 268 | 269 | def loss(self, 270 | input_batch, 271 | name='vae', 272 | beta=1.0): 273 | 274 | with tf.name_scope(name): 275 | output, encoder_mu, 
encoder_logvar, encoder_std = self._create_network(input_batch) 276 | 277 | # loss=tf.reduce_min(encoder_std) 278 | 279 | loss_latent = tf.identity(-0.5 * tf.reduce_sum(1 + encoder_logvar 280 | - tf.square(encoder_mu) 281 | - tf.square(encoder_std), 1), name='LossLatent') 282 | 283 | loss_reconstruction = tf.identity(-tf.reduce_sum(input_batch * tf.log(1e-8 + output) 284 | + (1 - input_batch) * tf.log(1e-8 + 1 - output), 285 | [1, 2]), name='LossReconstruction') 286 | 287 | # loss_reconstruction = tf.reduce_mean(tf.pow(input_batch - output, 2)) 288 | 289 | loss = tf.reduce_mean(loss_reconstruction + beta*loss_latent, name='Loss') 290 | # loss = tf.reduce_mean(loss_reconstruction, name='Loss') 291 | 292 | tf.summary.scalar('loss', loss) 293 | tf.summary.scalar('loss_rec', tf.reduce_mean(loss_reconstruction)) 294 | tf.summary.scalar('loss_kl', tf.reduce_mean(loss_latent)) 295 | tf.summary.scalar('beta', beta) 296 | 297 | return loss 298 | 299 | def encode_and_reconstruct(self, 300 | input_batch): 301 | 302 | output, encoder_mu, _, _ = self._create_network(input_batch) 303 | 304 | return encoder_mu, output -------------------------------------------------------------------------------- /model_iaf.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def create_variable(name, shape): 5 | '''Create a convolution filter variable with the specified name and shape, 6 | and initialize it using Xavier initialition.''' 7 | initializer = tf.contrib.layers.xavier_initializer_conv2d() 8 | variable = tf.Variable(initializer(shape=shape), name=name) 9 | return variable 10 | 11 | 12 | def create_bias_variable(name, shape): 13 | '''Create a bias variable with the specified name and shape and initialize 14 | it to zero.''' 15 | initializer = tf.constant_initializer(value=0.001, dtype=tf.float32) 16 | return tf.Variable(initializer(shape=shape), name) 17 | 18 | 19 | def upsample(value, name, factor=[2, 2]): 20 | size = [int(value.shape[1] * factor[0]), int(value.shape[2] * factor[1])] 21 | with tf.name_scope(name): 22 | out = tf.image.resize_bilinear(value, size=size, align_corners=None, name=None) 23 | return out 24 | 25 | 26 | def upsample2(value, name, output_shape): 27 | size = [int(output_shape[1]), int(output_shape[2])] 28 | with tf.name_scope(name): 29 | out = tf.image.resize_bilinear(value, size=size, align_corners=None, name=None) 30 | return out 31 | 32 | 33 | def two_d_conv(value, filter_, pool_kernel=[2, 2], name='two_d_conv'): 34 | out = tf.nn.conv2d(value, filter_, strides=[1, 1, 1, 1], padding='SAME') 35 | out = tf.contrib.layers.max_pool2d(out, pool_kernel) 36 | 37 | return out 38 | 39 | 40 | def two_d_deconv(value, filter_, deconv_shape, pool_kernel=[2, 2], name='two_d_conv'): 41 | out = upsample2(value, 'unpool', deconv_shape) 42 | # print(out) 43 | out = tf.nn.conv2d_transpose(out, filter_, output_shape=deconv_shape, strides=[1, 1, 1, 1], padding='SAME') 44 | # print(out) 45 | 46 | return out 47 | 48 | # KL divergence between posterior with autoregressive flow and prior 49 | def kl_divergence(sigma, epsilon, z_K, param, batch_mean=True): 50 | # logprob of posterior 51 | log_q_z0 = -0.5 * tf.square(epsilon) 52 | 53 | # logprob of prior 54 | log_p_zK = 0.5 * tf.square(z_K) 55 | 56 | # Terms from each flow layer 57 | flow_loss = 0 58 | for l in range(param['iaf_flow_length'] + 1): 59 | # Make sure it can't take log(0) or log(neg) 60 | flow_loss -= tf.log(sigma[l] + 1e-10) 61 | 62 | kl_divs = tf.identity(log_q_z0 + flow_loss 
+ log_p_zK) 63 | kl_divs_reduced = tf.reduce_sum(kl_divs, axis=1) 64 | 65 | if batch_mean: 66 | return tf.reduce_mean(kl_divs, axis=0), tf.reduce_mean(kl_divs_reduced) 67 | else: 68 | return kl_divs, kl_divs_reduced 69 | 70 | 71 | class VAEModel(object): 72 | 73 | def __init__(self, 74 | param, 75 | batch_size, 76 | activation=tf.nn.elu, 77 | activation_conv=tf.nn.elu, 78 | activation_nf=tf.nn.elu, 79 | encode=False): 80 | 81 | self.param = param 82 | self.batch_size = batch_size 83 | self.activation = activation 84 | self.activation_conv = activation_conv 85 | self.activation_nf = activation_nf 86 | self.encode = encode 87 | self.layers_enc = len(param['conv_channels']) 88 | self.layers_dec = self.layers_enc 89 | self.conv_out_shape = [7, 7] 90 | self.conv_out_units = self.conv_out_shape[0] * self.conv_out_shape[1] * param['conv_channels'][-1] 91 | self.cells_hidden = param['cells_hidden'] 92 | 93 | self.variables = self._create_variables() 94 | 95 | def _create_variables(self): 96 | '''This function creates all variables used by the network. 97 | This allows us to share them between multiple calls to the loss 98 | function and generation function.''' 99 | 100 | var = dict() 101 | 102 | with tf.variable_scope('VAE'): 103 | 104 | with tf.variable_scope("Encoder"): 105 | 106 | var['encoder_conv'] = list() 107 | with tf.variable_scope('conv_stack'): 108 | 109 | for l in range(self.layers_enc): 110 | 111 | with tf.variable_scope('layer{}'.format(l)): 112 | current = dict() 113 | 114 | if l == 0: 115 | channels_in = 1 116 | else: 117 | channels_in = self.param['conv_channels'][l - 1] 118 | channels_out = self.param['conv_channels'][l] 119 | 120 | current['filter'] = create_variable("filter", 121 | [3, 3, channels_in, channels_out]) 122 | # current['bias'] = create_bias_variable("bias", 123 | # [channels_out]) 124 | var['encoder_conv'].append(current) 125 | 126 | with tf.variable_scope('fully_connected'): 127 | 128 | layer = dict() 129 | 130 | layer['W_z0'] = create_variable("W_z0", 131 | shape=[self.conv_out_units, self.cells_hidden]) 132 | layer['b_z0'] = create_bias_variable("b_z0", 133 | shape=[1, self.cells_hidden]) 134 | 135 | layer['W_mu'] = create_variable("W_mu", 136 | shape=[self.cells_hidden, self.param['dim_latent']]) 137 | layer['W_logvar'] = create_variable("W_logvar", 138 | shape=[self.cells_hidden, self.param['dim_latent']]) 139 | layer['b_mu'] = create_bias_variable("b_mu", 140 | shape=[1, self.param['dim_latent']]) 141 | layer['b_logvar'] = create_bias_variable("b_logvar", 142 | shape=[1, self.param['dim_latent']]) 143 | 144 | var['encoder_fc'] = layer 145 | 146 | with tf.variable_scope("IAF"): 147 | 148 | var['iaf_flows'] = list() 149 | for l in range(self.param['iaf_flow_length']): 150 | 151 | with tf.variable_scope('layer{}'.format(l)): 152 | 153 | layer = dict() 154 | 155 | # Hidden state 156 | layer['W_flow'] = create_variable("W_flow", 157 | shape=[self.conv_out_units, self.param['dim_latent']]) 158 | layer['b_flow'] = create_bias_variable("b_flow", 159 | shape=[1, self.param['dim_latent']]) 160 | 161 | flow_variables = list() 162 | # Flow parameters from hidden state (m and s parameters for IAF) 163 | for j in range(self.param['dim_latent']): 164 | with tf.variable_scope('flow_layer{}'.format(j)): 165 | 166 | flow_layer = dict() 167 | 168 | # Set correct dimensionality 169 | units_to_hidden_iaf = self.param['dim_autoregressive_nl'] 170 | 171 | flow_layer['W_flow_params_nl'] = create_variable("W_flow_params_nl", 172 | shape=[self.param['dim_latent'] + j, 
units_to_hidden_iaf]) 173 | flow_layer['b_flow_params_nl'] = create_bias_variable("b_flow_params_nl", 174 | shape=[1, units_to_hidden_iaf]) 175 | 176 | flow_layer['W_flow_params'] = create_variable("W_flow_params", 177 | shape=[units_to_hidden_iaf, 178 | 2]) 179 | flow_layer['b_flow_params'] = create_bias_variable("b_flow_params", 180 | shape=[1, 2]) 181 | 182 | flow_variables.append(flow_layer) 183 | 184 | layer['flow_vars'] = flow_variables 185 | 186 | var['iaf_flows'].append(layer) 187 | 188 | 189 | with tf.variable_scope("Decoder"): 190 | 191 | with tf.variable_scope('fully_connected'): 192 | layer = dict() 193 | 194 | layer['W_z'] = create_variable("W_z", 195 | shape=[self.param['dim_latent'], self.conv_out_units]) 196 | layer['b_z'] = create_bias_variable("b_z", 197 | shape=[1, self.conv_out_units]) 198 | 199 | var['decoder_fc'] = layer 200 | 201 | var['decoder_deconv'] = list() 202 | with tf.variable_scope('deconv_stack'): 203 | 204 | for l in range(self.layers_enc): 205 | with tf.variable_scope('layer{}'.format(l)): 206 | current = dict() 207 | 208 | channels_in = self.param['conv_channels'][-1 - l] 209 | if l == self.layers_enc - 1: 210 | channels_out = 1 211 | else: 212 | channels_out = self.param['conv_channels'][-l - 2] 213 | 214 | current['filter'] = create_variable("filter", 215 | [3, 3, channels_out, channels_in]) 216 | # current['bias'] = create_bias_variable("bias", 217 | # [channels_out]) 218 | var['decoder_deconv'].append(current) 219 | 220 | return var 221 | 222 | def _create_network(self, input_batch, keep_prob=1.0, encode=False): 223 | 224 | # ----------------------------------- 225 | # Encoder 226 | 227 | # Do encoder calculation 228 | encoder_hidden = input_batch 229 | for l in range(self.layers_enc): 230 | # print(encoder_hidden) 231 | encoder_hidden = two_d_conv(encoder_hidden, self.variables['encoder_conv'][l]['filter'], 232 | self.param['max_pooling'][l]) 233 | encoder_hidden = self.activation_conv(encoder_hidden) 234 | 235 | # print(encoder_hidden) 236 | 237 | encoder_hidden = tf.reshape(encoder_hidden, [-1, self.conv_out_units]) 238 | 239 | # print(encoder_hidden) 240 | 241 | # Additional non-linearity between encoder hidden state and prediction of mu_0,sigma_0 242 | mu_logvar_hidden = tf.nn.dropout(self.activation(tf.matmul(encoder_hidden, 243 | self.variables['encoder_fc']['W_z0']) 244 | + self.variables['encoder_fc']['b_z0']), 245 | keep_prob=keep_prob) 246 | 247 | # print(mu_logvar_hidden) 248 | 249 | encoder_mu = tf.add(tf.matmul(mu_logvar_hidden, self.variables['encoder_fc']['W_mu']), 250 | self.variables['encoder_fc']['b_mu'], name='ZMu') 251 | encoder_logvar = tf.add(tf.matmul(mu_logvar_hidden, self.variables['encoder_fc']['W_logvar']), 252 | self.variables['encoder_fc']['b_logvar'], name='ZLogVar') 253 | 254 | # print(encoder_mu) 255 | 256 | # Convert log variance into standard deviation 257 | encoder_std = tf.exp(0.5 * encoder_logvar) 258 | 259 | # Sample epsilon 260 | epsilon = tf.random_normal(tf.shape(encoder_std), name='epsilon') 261 | 262 | if encode: 263 | z0 = tf.identity(encoder_mu, name='LatentZ0') 264 | else: 265 | z0 = tf.identity(tf.add(encoder_mu, tf.multiply(encoder_std, epsilon), 266 | name='LatentZ0')) 267 | 268 | # ----------------------------------- 269 | # Latent flow 270 | 271 | # Lists to store the latent variables and the flow parameters 272 | nf_z = [z0] 273 | nf_sigma = [encoder_std] 274 | 275 | # Do calculations for each flow layer 276 | for l in range(self.param['iaf_flow_length']): 277 | 278 | W_flow = 
self.variables['iaf_flows'][l]['W_flow'] 279 | b_flow = self.variables['iaf_flows'][l]['b_flow'] 280 | 281 | nf_hidden = self.activation_nf(tf.matmul(encoder_hidden, W_flow) + b_flow) 282 | 283 | # Autoregressive calculation 284 | m_list = self.param['dim_latent'] * [None] 285 | s_list = self.param['dim_latent'] * [None] 286 | 287 | for j, flow_vars in enumerate(self.variables['iaf_flows'][l]['flow_vars']): 288 | 289 | # Go through computation one variable at a time 290 | if j == 0: 291 | hidden_autoregressive = nf_hidden 292 | else: 293 | z_slice = tf.slice(nf_z[-1], [0, 0], [-1, j]) 294 | hidden_autoregressive = tf.concat(axis=1, values=[nf_hidden, z_slice]) 295 | 296 | W_flow_params_nl = flow_vars['W_flow_params_nl'] 297 | b_flow_params_nl = flow_vars['b_flow_params_nl'] 298 | W_flow_params = flow_vars['W_flow_params'] 299 | b_flow_params = flow_vars['b_flow_params'] 300 | 301 | # Non-linearity at current autoregressive step 302 | nf_hidden_nl = self.activation_nf(tf.matmul(hidden_autoregressive, 303 | W_flow_params_nl) + b_flow_params_nl) 304 | 305 | # Calculate parameters for normalizing flow as linear transform 306 | ms = tf.matmul(nf_hidden_nl, W_flow_params) + b_flow_params 307 | 308 | # Split into individual components 309 | # m_list[j], s_list[j] = tf.split_v(value=ms, 310 | # size_splits=[1,1], 311 | # split_dim=1) 312 | m_list[j], s_list[j] = tf.split(value=ms, 313 | num_or_size_splits=[1, 1], 314 | axis=1) 315 | 316 | # Concatenate autoregressively computed variables 317 | # Add offset to s to make sure it starts out positive 318 | # (could have also initialised the bias term to 1) 319 | # Guarantees that flow initially small 320 | m = tf.concat(axis=1, values=m_list) 321 | s = self.param['initial_s_offset'] + tf.concat(axis=1, values=s_list) 322 | 323 | # Calculate sigma ("update gate value") from s 324 | sigma = tf.nn.sigmoid(s) 325 | nf_sigma.append(sigma) 326 | 327 | # Perform normalizing flow 328 | z_current = tf.multiply(sigma, nf_z[-1]) + tf.multiply((1 - sigma), m) 329 | 330 | # Invert order of variables to alternate dependence of autoregressive structure 331 | z_current = tf.reverse(z_current, axis=[1], name='LatentZ%d' % (l + 1)) 332 | 333 | # Add to list of latent variables 334 | nf_z.append(z_current) 335 | 336 | z = tf.identity(nf_z[-1], name="LatentZ") 337 | 338 | # ----------------------------------- 339 | # Decoder 340 | 341 | # Fully connected 342 | decoder_hidden = tf.nn.dropout(self.activation(tf.matmul(z, self.variables['decoder_fc']['W_z']) 343 | + self.variables['decoder_fc']['b_z']), 344 | keep_prob=keep_prob) 345 | 346 | # print(decoder_hidden) 347 | 348 | # Reshape 349 | decoder_hidden = tf.reshape(decoder_hidden, [-1, self.conv_out_shape[0], self.conv_out_shape[1], 350 | self.param['conv_channels'][-1]]) 351 | 352 | for l in range(self.layers_enc): 353 | # print(decoder_hidden) 354 | 355 | pool_kernel = self.param['max_pooling'][-1 - l] 356 | decoder_hidden = two_d_deconv(decoder_hidden, self.variables['decoder_deconv'][l]['filter'], 357 | self.param['deconv_shape'][l], pool_kernel) 358 | if l < self.layers_enc - 1: 359 | decoder_hidden = self.activation_conv(decoder_hidden) 360 | 361 | decoder_output = tf.nn.sigmoid(decoder_hidden) 362 | 363 | # print(decoder_output) 364 | 365 | # return decoder_output, encoder_hidden, encoder_logvar, encoder_std 366 | return decoder_output, encoder_mu, encoder_logvar, encoder_std, epsilon, z, nf_sigma 367 | 368 | def loss(self, 369 | input_batch, 370 | name='vae', 371 | beta=1.0): 372 | 373 | with 
tf.name_scope(name): 374 | output, encoder_mu, encoder_logvar, encoder_std, epsilon, z, nf_sigma = self._create_network(input_batch) 375 | 376 | _, div = kl_divergence(nf_sigma, epsilon, z, self.param, batch_mean=False) 377 | loss_latent = tf.identity(div, name='LossLatent') 378 | print(loss_latent) 379 | 380 | # loss_latent = tf.identity(-0.5 * tf.reduce_sum(1 + encoder_logvar 381 | # - tf.square(encoder_mu) 382 | # - tf.square(encoder_std), 1), name='LossLatent') 383 | 384 | print(input_batch) 385 | loss_reconstruction = tf.identity(-tf.reduce_sum(input_batch * tf.log(1e-8 + output) 386 | + (1 - input_batch) * tf.log(1e-8 + 1 - output), 387 | [1,2]), name='LossReconstruction') 388 | 389 | # loss_reconstruction = tf.reduce_mean(tf.pow(input_batch - output, 2)) 390 | 391 | loss = tf.reduce_mean(loss_reconstruction + beta*loss_latent, name='Loss') 392 | # loss = tf.reduce_mean(loss_reconstruction, name='Loss') 393 | 394 | tf.summary.scalar('loss', loss) 395 | tf.summary.scalar('loss_rec', tf.reduce_mean(loss_reconstruction)) 396 | tf.summary.scalar('loss_kl', tf.reduce_mean(loss_latent)) 397 | tf.summary.scalar('beta', beta) 398 | 399 | return loss 400 | 401 | def encode_and_reconstruct(self, input_batch): 402 | 403 | output, _, _, _, _, encoder_mu, _ = self._create_network(input_batch, encode=True) 404 | 405 | return encoder_mu, output 406 | 407 | def decode(self, input_batch): 408 | 409 | z = input_batch 410 | 411 | # Fully connected 412 | decoder_hidden = self.activation(tf.matmul(z, self.variables['decoder_fc']['W_z']) 413 | + self.variables['decoder_fc']['b_z']) 414 | 415 | # Reshape 416 | decoder_hidden = tf.reshape(decoder_hidden, [-1, self.conv_out_shape[0], self.conv_out_shape[1], 417 | self.param['conv_channels'][-1]]) 418 | 419 | for l in range(self.layers_enc): 420 | 421 | pool_kernel = self.param['max_pooling'][-1 - l] 422 | decoder_hidden = two_d_deconv(decoder_hidden, self.variables['decoder_deconv'][l]['filter'], 423 | self.param['deconv_shape'][l], pool_kernel) 424 | if l < self.layers_enc - 1: 425 | decoder_hidden = self.activation_conv(decoder_hidden) 426 | 427 | decoder_output = tf.nn.sigmoid(decoder_hidden) 428 | 429 | return decoder_output -------------------------------------------------------------------------------- /params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dim_latent": 64, 3 | "cells_hidden": 512, 4 | "conv_channels": [32, 64, 32, 32], 5 | "max_pooling": [[4, 4], [2, 2], [2, 4], [2, 2]], 6 | "deconv_shape": [[1, 15, 14, 32], 7 | [1, 31, 31, 64], 8 | [1, 63, 62, 32], 9 | [1, 128, 126, 1]], 10 | "iaf_flow_length": 5, 11 | "dim_autoregressive_nl": 64, 12 | "initial_s_offset": 1.0 13 | } 14 | -------------------------------------------------------------------------------- /spec_reader.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import random 3 | import tensorflow as tf 4 | import numpy as np 5 | import joblib 6 | 7 | def randomize_specs(specs): 8 | for k in range(specs.shape[0]): 9 | file_index = random.randint(0, (specs.shape[0] - 1)) 10 | yield specs[file_index] 11 | 12 | 13 | def return_spec(specs): 14 | randomized_specs = randomize_specs(specs) 15 | for spec in randomized_specs: 16 | # Convert from -80 to 0dB to range [0,1], and add channel dimension 17 | normalized_spec = np.expand_dims((spec + 80.0) / 80.0, 2) 18 | yield normalized_spec 19 | 20 | 21 | class SpectrogramReader(object): 22 | def __init__(self, 23 | specs, 24 | coord, 25 | 
class SpectrogramReader(object):
    def __init__(self,
                 specs,
                 coord,
                 queue_size=32):

        self.specs = specs
        self.coord = coord
        self.threads = []
        self.spec_placeholder = tf.placeholder(dtype=tf.float32, shape=None)
        self.queue = tf.PaddingFIFOQueue(queue_size,
                                         ['float32'],
                                         shapes=[(128, 126, 1)])
        self.enqueue = self.queue.enqueue([self.spec_placeholder])

    def dequeue(self, num_elements):
        output = self.queue.dequeue_many(num_elements)
        return output

    def thread_main(self, sess):
        stop = False
        # Go through the dataset multiple times
        while not stop:
            iterator = return_spec(self.specs)
            for spec in iterator:
                if self.coord.should_stop():
                    stop = True
                    break

                sess.run(self.enqueue,
                         feed_dict={self.spec_placeholder: spec})

    def start_threads(self, sess, n_threads=1):
        for _ in range(n_threads):
            thread = threading.Thread(target=self.thread_main, args=(sess,))
            thread.daemon = True  # Thread will close when parent quits.
            thread.start()
            self.threads.append(thread)
        return self.threads

def load_specs(filename='dataset.pkl', return_filenames=False):
    print('Loading dataset.')
    dataset = joblib.load(filename)
    print('Dataset loaded.')

    filenames = dataset['filenames']
    melspecs = dataset['melspecs']
    actual_lengths = dataset['actual_lengths']

    # Convert the list of spectrograms to a single array
    melspecs = np.array(melspecs)

    if return_filenames:
        return melspecs, filenames
    else:
        return melspecs
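
# Usage sketch (editorial illustration; assumes a TF1 session and the
# (N, 128, 126) dB spectrogram array produced by the preprocessing notebook):
#
#     coord = tf.train.Coordinator()
#     reader = SpectrogramReader(load_specs(), coord)
#     spec_batch = reader.dequeue(64)   # (64, 128, 126, 1) tensor in [0, 1]
#     sess = tf.Session()
#     reader.start_threads(sess)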
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import os
from shutil import copyfile
import sys
import time
import joblib
from random import shuffle
import numpy as np
import argparse
import json

from spec_reader import *
from model_iaf import *

logdir = './logdir'
max_checkpoints = 5
num_steps = 10000
checkpoint_every = 500
batch_size = 64
learning_rate = 1e-3
beta = 1.0
model_params = 'params.json'

def get_arguments():
    def _str_to_bool(s):
        """Convert string to bool (in argparse context)."""
        if s.lower() not in ['true', 'false']:
            raise ValueError('Argument needs to be a '
                             'boolean, got {}'.format(s))
        return {'true': True, 'false': False}[s.lower()]

    parser = argparse.ArgumentParser(description='Spectrogram VAE')
    parser.add_argument('--batch_size', type=int, default=batch_size,
                        help='Number of spectrograms to process at once. Default: ' + str(batch_size) + '.')
    parser.add_argument('--logdir', type=str, default=logdir,
                        help='Directory in which to store the logging '
                             'information for TensorBoard. '
                             'If the model already exists, it will restore '
                             'the state and will continue training. '
                             'Default: ' + logdir + '.')
    parser.add_argument('--checkpoint_every', type=int,
                        default=checkpoint_every,
                        help='Number of steps between checkpoints. Default: ' + str(checkpoint_every) + '.')
    parser.add_argument('--num_steps', type=int, default=num_steps,
                        help='Number of training steps. Default: ' + str(num_steps) + '.')
    parser.add_argument('--learning_rate', type=float, default=learning_rate,
                        help='Learning rate for training. Default: ' + str(learning_rate) + '.')
    parser.add_argument('--beta', type=float, default=beta,
                        help='Factor for the KL divergence term in the loss. Default: ' + str(beta) + '.')
    parser.add_argument('--model_params', type=str, default=model_params,
                        help='JSON file with the network parameters. Default: ' + model_params + '.')
    parser.add_argument('--max_checkpoints', type=int, default=max_checkpoints,
                        help='Maximum number of checkpoints to keep. Default: '
                             + str(max_checkpoints) + '.')
    return parser.parse_args()

def save(saver, sess, logdir, step):
    model_name = 'model.ckpt'
    checkpoint_path = os.path.join(logdir, model_name)
    print('Storing checkpoint to {} ...'.format(logdir), end="")
    sys.stdout.flush()

    if not os.path.exists(logdir):
        os.makedirs(logdir)

    saver.save(sess, checkpoint_path, global_step=step)
    print(' Done.')


def load(saver, sess, logdir):
    print("Trying to restore saved checkpoints from {} ...".format(logdir),
          end="")

    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt:
        print("  Checkpoint found: {}".format(ckpt.model_checkpoint_path))
        global_step = int(ckpt.model_checkpoint_path
                          .split('/')[-1]
                          .split('-')[-1])
        print("  Global step was: {}".format(global_step))
        print("  Restoring...", end="")
        saver.restore(sess, ckpt.model_checkpoint_path)
        print(" Done.")
        return global_step
    else:
        print(" No checkpoint found.")
        return None

def main():

    args = get_arguments()

    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    # If restarting an existing model, look for the original parameters
    if os.path.isfile(f'{args.logdir}/params.json'):
        print('Loading existing parameters.')
        print(f'{args.logdir}/params.json')
        with open(f'{args.logdir}/params.json', 'r') as f:
            param = json.load(f)
    # Otherwise load new ones and copy them to logdir
    else:
        print('Starting with new parameters.')
        # Load model parameters
        with open(args.model_params, 'r') as f:
            param = json.load(f)
        copyfile(args.model_params, f'{args.logdir}/params.json')

    # Set the correct batch size in the deconvolution shapes
    deconv_shape = param['deconv_shape']
    for k, s in enumerate(deconv_shape):
        actual_shape = s
        actual_shape[0] = args.batch_size
        deconv_shape[k] = actual_shape
    param['deconv_shape'] = deconv_shape
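
    # Editor's note (illustration): with the default params.json and
    # --batch_size 64, the loop above rewrites each static deconv output shape
    # [1, 15, 14, 32] -> [64, 15, 14, 32], and so on, presumably because the
    # decoder's transposed convolutions need a concrete batch dimension in
    # their output shapes.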
    # Load data
    melspecs = load_specs()

    # Create coordinator.
    coord = tf.train.Coordinator()

    with tf.name_scope('create_inputs'):
        reader = SpectrogramReader(melspecs, coord)
        spec_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = VAEModel(param,
                   args.batch_size)

    loss = net.loss(spec_batch, beta=args.beta)

    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                       epsilon=1e-4)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(args.logdir)
    writer.add_graph(tf.get_default_graph())
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, args.logdir)
        if saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except Exception:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)
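
    # Editor's note: resuming picks up from the step encoded in the checkpoint
    # filename (load() parses it from e.g. 'model.ckpt-1500'), so the loop
    # below continues at step 1501 rather than restarting from 0.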
    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()

            summary, loss_value, _ = sess.run([summaries, loss, optim])

            writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
                  .format(step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, args.logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so the save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, args.logdir, step)
        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import librosa
import numpy as np
import json

with open('audio_params.json', 'r') as f:
    param = json.load(f)

N_FFT = param['N_FFT']
HOP_LENGTH = param['HOP_LENGTH']
SAMPLING_RATE = param['SAMPLING_RATE']
MELSPEC_BANDS = param['MELSPEC_BANDS']
sample_secs = param['sample_secs']
num_samples_dataset = int(sample_secs * SAMPLING_RATE)


# Read in an audio file (or take a raw audio array) and return a mel spectrogram
def get_melspec(filepath_or_audio, hop_length=HOP_LENGTH, n_mels=MELSPEC_BANDS, n_samples=num_samples_dataset,
                sample_secs=sample_secs, as_tf_input=False):

    y_tmp = np.zeros(n_samples)

    # Load a little more than necessary as a buffer
    load_duration = None if sample_secs is None else 1.1 * sample_secs

    # Load audio file or take the given input
    if isinstance(filepath_or_audio, str):
        y, sr = librosa.core.load(filepath_or_audio, sr=SAMPLING_RATE, mono=True, duration=load_duration)
    else:
        y = filepath_or_audio
        sr = SAMPLING_RATE

    # Truncate or zero-pad to the fixed sample length
    if n_samples:
        if len(y) >= n_samples:
            y_tmp = y[:n_samples]
            length_ratio = 1.0
        else:
            y_tmp[:len(y)] = y
            length_ratio = len(y) / n_samples
    else:
        y_tmp = y
        length_ratio = 1.0

    # STFT -> mel conversion
    melspec = librosa.feature.melspectrogram(y=y_tmp, sr=sr,
                                             n_fft=N_FFT, hop_length=hop_length, n_mels=n_mels)
    S = librosa.power_to_db(melspec, ref=np.max)

    if as_tf_input:
        S = spec_to_input(S)

    return S, length_ratio

def spec_to_input(spec):
    # Map [-80, 0] dB to [0, 1] and add batch and channel dimensions
    specs_out = (spec + 80.0) / 80.0
    specs_out = np.expand_dims(np.expand_dims(specs_out, axis=0), axis=3)
    return np.float32(specs_out)
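
# Usage sketch (editorial; the file path is hypothetical):
#
#     S, length_ratio = get_melspec('sample.wav')  # S: (128, 126) dB spectrogram
#                                                  # at the default 2 s window
#     x = spec_to_input(S)                         # (1, 128, 126, 1) float32 in [0, 1]
--------------------------------------------------------------------------------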