├── .gitignore ├── README.md ├── _config.yml ├── code ├── cl_vae │ ├── __init__.py │ ├── model.py │ ├── sample.py │ └── train.py ├── cl_vrnn │ ├── __init__.py │ ├── model.py │ ├── sample.py │ └── train.py └── utils │ ├── __init__.py │ ├── midi_utils.py │ ├── model_utils.py │ ├── pianoroll.py │ └── weightnorm.py ├── data ├── input │ ├── JSB Chorales_Cs.pickle │ ├── JSB Chorales_all.pickle │ ├── Piano-midi_Cs.pickle │ └── Piano-midi_all.pickle ├── models │ └── .gitkeep └── samples │ ├── JSB10_CL-VAE_infer.wav │ ├── JSB10_CL-VRNN_infer.wav │ ├── JSB10_Data.wav │ ├── JSB10_VAE.wav │ ├── JSB10_VRNN.wav │ ├── PMall_CL-VAE_infer.wav │ ├── PMall_CL-VAE_true.wav │ ├── PMall_Data.wav │ └── PMall_VAE.wav └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pyc 3 | 4 | data/models/* 5 | !data/models/.gitkeep 6 | 7 | *.mid 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### A Classifying Variational Autoencoder with Application to Polyphonic Music Generation 2 | 3 | This is the implementation of the Classifying VAE and Classifying VAE+LSTM models, as described in [_A Classifying Variational Autoencoder with Application to Polyphonic Music Generation_](https://arxiv.org/abs/1711.07050) by Jay A. Hennig, Akash Umakantha, and Ryan C. Williamson. 4 | 5 | These models extend the standard VAE and VAE+LSTM to the case where there is a latent discrete category. In the case of music generation, for example, we may wish to infer the key of a song, so that we can generate notes that are consistent with that key. These discrete latents are modeled as a Logistic Normal distribution, so that random samples from this distribution can use the reparameterization trick during training. 
6 | 7 | __Code for these models (in Keras) can be found [here](https://github.com/mobeets/classifying-vae-lstm).__ 8 | 9 | Training data for the JSB Chorales and Piano-midi corpuses can be found in `data/input`. Songs have been transposed into C major or C minor (`*_Cs.pickle`), for comparison to previous work, or kept in their original keys (`*_all.pickle`). 10 | 11 | ### Generated music samples 12 | 13 | Samples from the models trained on the JSB Chorales and Piano-midi corpuses, with songs in their original keys, can be found below, or in `data/samples`. 14 | 15 | __JSB Chorales (all keys)__: 16 | 17 | - VAE
18 | - Classifying VAE (inferred key)
19 | - VAE+LSTM
20 | - Classifying VAE+LSTM (inferred key)
21 | 22 | __Piano-midi (all keys)__: 23 | 24 | - VAE
25 | - Classifying VAE (inferred key)
26 | - Classifying VAE (given key)
27 | 28 | ### Training new models 29 | 30 | Example of training a Classifying VAE with 4 latent dimensions on JSB Chorales in two keys, and then generating a sample from this model: 31 | 32 | ```bash 33 | $ python cl_vae/train.py run1 --use_x_prev --latent_dim 4 --train_file '../data/input/JSB Chorales_Cs.pickle' 34 | $ python cl_vae/sample.py outfile --model_file ../data/models/run1.h5 --train_file '../data/input/JSB Chorales_Cs.pickle' 35 | ``` 36 | 37 | 38 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-minimal -------------------------------------------------------------------------------- /code/cl_vae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobeets/classifying-vae-lstm/042a898558a967a37d1246a51b4c978d08783b7b/code/cl_vae/__init__.py -------------------------------------------------------------------------------- /code/cl_vae/model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import scipy.stats 4 | from keras.layers import Input, Dense, Lambda, Reshape, concatenate 5 | from keras.models import Model 6 | from keras import backend as K 7 | from keras import losses 8 | 9 | def generate_sample(dec_model, w_enc_model, z_enc_model, x_seed, nsteps, w_val=None, use_z_prior=False, do_reset=True, w_sample=False, use_x_prev=False): 10 | """ 11 | for t = 1:nsteps 12 | 1. encode x_seed -> w_mean, w_log_var 13 | 2. sample w_t ~ logit-N(w_mean, exp(w_log_var/2)) 14 | 3. encode x_seed, w_t -> z_mean, z_log_var 15 | 4. sample z_t ~ N(z_mean, exp(z_log_var/2)) 16 | 3. decode w_t, z_t -> x_mean 17 | 4. sample x_t ~ Bern(x_mean) 18 | 5. 
update x_seed := x_t 19 | """ 20 | original_dim = x_seed.shape[0] 21 | Xs = np.zeros([nsteps, original_dim]) 22 | x_prev = np.expand_dims(x_seed, axis=0) 23 | x_prev_t = x_prev 24 | if w_val is None: 25 | w_t = sample_w(w_enc_model.predict(x_prev), add_noise=w_sample) 26 | else: 27 | w_t = w_val 28 | for t in xrange(nsteps): 29 | z_mean, z_log_var = z_enc_model.predict([x_prev, w_t]) 30 | if use_z_prior: 31 | z_t = sample_z((0*z_mean, 0*z_log_var)) 32 | else: 33 | z_t = sample_z((z_mean, z_log_var)) 34 | if use_x_prev: 35 | zc = [w_t, z_t, x_prev_t] 36 | else: 37 | zc = [w_t, z_t] 38 | x_t = sample_x(dec_model.predict(zc)) 39 | Xs[t] = x_t 40 | x_prev_t = x_prev 41 | x_prev = x_t 42 | return Xs 43 | 44 | def sample_x(x_mean): 45 | return 1.0*(np.random.rand(len(x_mean.squeeze())) <= x_mean) 46 | 47 | def sample_w(args, nsamps=1, nrm_samp=False, add_noise=True): 48 | w_mean, w_log_var = args 49 | if nsamps == 1: 50 | eps = np.random.randn(*((1, w_mean.flatten().shape[0]))) 51 | else: 52 | eps = np.random.randn(*((nsamps,) + w_mean.shape)) 53 | if eps.T.shape == w_mean.shape: 54 | eps = eps.T 55 | if add_noise: 56 | w_norm = w_mean + np.exp(w_log_var/2)*eps 57 | else: 58 | w_norm = w_mean + 0*np.exp(w_log_var/2)*eps 59 | if nrm_samp: 60 | return w_norm 61 | if nsamps == 1: 62 | w_norm = np.hstack([w_norm, np.zeros((w_norm.shape[0], 1))]) 63 | return np.exp(w_norm)/np.sum(np.exp(w_norm), axis=-1)[:,None] 64 | else: 65 | w_norm = np.dstack([w_norm, np.zeros(w_norm.shape[:-1]+ (1,))]) 66 | return np.exp(w_norm)/np.sum(np.exp(w_norm), axis=-1)[:,:,None] 67 | 68 | def sample_z(args, nsamps=1): 69 | Z_mean, Z_log_var = args 70 | if nsamps == 1: 71 | eps = np.random.randn(*Z_mean.squeeze().shape) 72 | else: 73 | eps = np.random.randn(*((nsamps,) + Z_mean.squeeze().shape)) 74 | return Z_mean + np.exp(Z_log_var/2) * eps 75 | 76 | def make_w_encoder(model, original_dim, batch_size=1): 77 | x = Input(batch_shape=(batch_size, original_dim), name='x') 78 | 79 | # build label 
encoder 80 | h_w = model.get_layer('h_w')(x) 81 | w_mean = model.get_layer('w_mean')(h_w) 82 | w_log_var = model.get_layer('w_log_var')(h_w) 83 | 84 | mdl = Model(x, [w_mean, w_log_var]) 85 | return mdl 86 | 87 | def make_z_encoder(model, original_dim, class_dim, (latent_dim_0, latent_dim), batch_size=1): 88 | x = Input(batch_shape=(batch_size, original_dim), name='x') 89 | w = Input(batch_shape=(batch_size, class_dim), name='w') 90 | xw = concatenate([x, w], axis=-1) 91 | 92 | # build latent encoder 93 | if latent_dim_0 > 0: 94 | h = model.get_layer('h')(xw) 95 | z_mean = model.get_layer('z_mean')(h) 96 | z_log_var = model.get_layer('z_log_var')(h) 97 | else: 98 | z_mean = model.get_layer('z_mean')(xw) 99 | z_log_var = model.get_layer('z_log_var')(xw) 100 | 101 | mdl = Model([x, w], [z_mean, z_log_var]) 102 | return mdl 103 | 104 | def make_decoder(model, (latent_dim_0, latent_dim), class_dim, original_dim=88, use_x_prev=False, batch_size=1): 105 | w = Input(batch_shape=(batch_size, class_dim), name='w') 106 | z = Input(batch_shape=(batch_size, latent_dim), name='z') 107 | if use_x_prev: 108 | xp = Input(batch_shape=(batch_size, original_dim), name='history') 109 | if use_x_prev: 110 | xpz = concatenate([xp, z], axis=-1) 111 | else: 112 | xpz = z 113 | wz = concatenate([w, xpz], axis=-1) 114 | 115 | # build x decoder 116 | decoder_mean = model.get_layer('x_decoded_mean') 117 | if latent_dim_0 > 0: 118 | decoder_h = model.get_layer('decoder_h') 119 | h_decoded = decoder_h(wz) 120 | x_decoded_mean = decoder_mean(h_decoded) 121 | else: 122 | x_decoded_mean = decoder_mean(wz) 123 | 124 | if use_x_prev: 125 | mdl = Model([w, z, xp], x_decoded_mean) 126 | else: 127 | mdl = Model([w, z], x_decoded_mean) 128 | return mdl 129 | 130 | def get_model(batch_size, original_dim, 131 | (latent_dim_0, latent_dim), 132 | (class_dim_0, class_dim), optimizer, 133 | class_weight=1.0, kl_weight=1.0, use_x_prev=False, 134 | w_kl_weight=1.0, w_log_var_prior=0.0): 135 | 136 | x = 
Input(batch_shape=(batch_size, original_dim), name='x') 137 | if use_x_prev: 138 | xp = Input(batch_shape=(batch_size, original_dim), name='history') 139 | 140 | # build label encoder 141 | h_w = Dense(class_dim_0, activation='relu', name='h_w')(x) 142 | w_mean = Dense(class_dim-1, name='w_mean')(h_w) 143 | w_log_var = Dense(class_dim-1, name='w_log_var')(h_w) 144 | 145 | # sample label 146 | def w_sampling(args): 147 | """ 148 | sample from a logit-normal with params w_mean and w_log_var 149 | (n.b. this is very similar to a logistic-normal distribution) 150 | """ 151 | w_mean, w_log_var = args 152 | eps = K.random_normal(shape=(batch_size, class_dim-1), mean=0., stddev=1.0) 153 | w_norm = w_mean + K.exp(w_log_var/2) * eps 154 | # need to add '0' so we can sum it all to 1 155 | w_norm = concatenate([w_norm, K.tf.zeros(batch_size, 1)[:,None]]) 156 | return K.exp(w_norm)/K.sum(K.exp(w_norm), axis=-1)[:,None] 157 | w = Lambda(w_sampling, name='w')([w_mean, w_log_var]) 158 | 159 | # build latent encoder 160 | xw = concatenate([x, w], axis=-1) 161 | if latent_dim_0 > 0: 162 | h = Dense(latent_dim_0, activation='relu', name='h')(xw) 163 | z_mean = Dense(latent_dim, name='z_mean')(h) 164 | z_log_var = Dense(latent_dim, name='z_log_var')(h) 165 | else: 166 | z_mean = Dense(latent_dim, name='z_mean')(xw) 167 | z_log_var = Dense(latent_dim, name='z_log_var')(xw) 168 | 169 | # sample latents 170 | def sampling(args): 171 | z_mean, z_log_var = args 172 | eps = K.random_normal(shape=(batch_size, latent_dim), mean=0., stddev=1.0) 173 | return z_mean + K.exp(z_log_var/2) * eps 174 | z = Lambda(sampling, name='z')([z_mean, z_log_var]) 175 | 176 | # build decoder 177 | if use_x_prev: 178 | xpz = concatenate([xp, z], axis=-1) 179 | else: 180 | xpz = z 181 | wz = concatenate([w, xpz], axis=-1) 182 | decoder_mean = Dense(original_dim, activation='sigmoid', name='x_decoded_mean') 183 | if latent_dim_0 > 0: 184 | decoder_h = Dense(latent_dim_0, activation='relu', name='decoder_h') 185 
| h_decoded = decoder_h(wz) 186 | x_decoded_mean = decoder_mean(h_decoded) 187 | else: 188 | x_decoded_mean = decoder_mean(wz) 189 | 190 | def vae_loss(x, x_decoded_mean): 191 | return original_dim * losses.binary_crossentropy(x, x_decoded_mean) 192 | 193 | def kl_loss(z_true, z_args): 194 | Z_mean = z_args[:,:latent_dim] 195 | Z_log_var = z_args[:,latent_dim:] 196 | return -0.5*K.sum(1 + Z_log_var - K.square(Z_mean) - K.exp(Z_log_var), axis=-1) 197 | 198 | def w_rec_loss(w_true, w): 199 | return (class_dim-1) * losses.categorical_crossentropy(w_true, w) 200 | 201 | # w_log_var_prior = 1.0 202 | def w_kl_loss(w_true, w): 203 | # w_log_var_prior 204 | # return -0.5 * K.sum(1 + w_log_var - K.square(w_mean) - K.exp(w_log_var), axis=-1) 205 | vs = 1 - w_log_var_prior + w_log_var - K.exp(w_log_var)/K.exp(w_log_var_prior) - K.square(w_mean)/K.exp(w_log_var_prior) 206 | return -0.5*K.sum(vs, axis=-1) 207 | 208 | w2 = Lambda(lambda x: x+1e-10, name='w2')(w) 209 | z_args = concatenate([z_mean, z_log_var], axis=-1, name='z_args') 210 | if use_x_prev: 211 | model = Model([x, xp], [x_decoded_mean, w, w2, z_args]) 212 | enc_model = Model([x, xp], [z_mean, w_mean]) 213 | else: 214 | model = Model(x, [x_decoded_mean, w, w2, z_args]) 215 | enc_model = Model(x, [z_mean, w_mean]) 216 | model.compile(optimizer=optimizer, 217 | loss={'x_decoded_mean': vae_loss, 'w': w_kl_loss, 'w2': w_rec_loss, 'z_args': kl_loss}, 218 | loss_weights={'x_decoded_mean': 1.0, 'w': w_kl_weight, 'w2': class_weight, 'z_args': kl_weight}, 219 | metrics={'w': 'accuracy'}) 220 | if use_x_prev: 221 | enc_model = Model([x, xp], [z_mean, w_mean]) 222 | else: 223 | enc_model = Model(x, [z_mean, w_mean]) 224 | return model, enc_model 225 | 226 | def load_model(model_file, optimizer='adam', batch_size=1, no_x_prev=False): 227 | """ 228 | there's currently a bug in the way keras loads models from .yaml 229 | that has to do with Lambdas 230 | so this is a hack for now... 
231 | """ 232 | margs = json.load(open(model_file.replace('.h5', '.json'))) 233 | # model = model_from_yaml(open(args.model_file)) 234 | batch_size = margs['batch_size'] if batch_size == None else batch_size 235 | if no_x_prev or 'use_x_prev' not in margs: 236 | margs['use_x_prev'] = False 237 | model, enc_model = get_model(batch_size, margs['original_dim'], (margs['intermediate_dim'], margs['latent_dim']), (margs['intermediate_class_dim'], margs['n_classes']), optimizer, margs['class_weight'], use_x_prev=margs['use_x_prev']) 238 | model.load_weights(model_file) 239 | return model, enc_model, margs 240 | -------------------------------------------------------------------------------- /code/cl_vae/sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from keras.utils import to_categorical 4 | from utils.pianoroll import PianoData 5 | from utils.midi_utils import write_sample 6 | from model import load_model, generate_sample, make_decoder, make_w_encoder, make_z_encoder, sample_z 7 | 8 | def make_sample(P, dec_model, w_enc_model, z_enc_model, args, margs): 9 | # generate and write sample 10 | seed_ind = np.random.choice(xrange(len(P.x_test))) 11 | x_seed = P.x_test[seed_ind][0] 12 | seed_key_ind = P.test_song_keys[seed_ind] 13 | w_val = None if args.infer_w else to_categorical(seed_key_ind, margs['n_classes']) 14 | sample = generate_sample(dec_model, w_enc_model, z_enc_model, x_seed, args.t, w_val=w_val, use_z_prior=args.use_z_prior, use_x_prev=margs['use_x_prev']) 15 | write_sample(sample, args.sample_dir, args.run_name, True) 16 | 17 | def sample(args): 18 | # load models 19 | train_model, enc_model, margs = load_model(args.model_file, no_x_prev=args.no_x_prev) 20 | w_enc_model = make_w_encoder(train_model, margs['original_dim']) 21 | z_enc_model = make_z_encoder(train_model, margs['original_dim'], margs['n_classes'], (margs['intermediate_dim'], margs['latent_dim'])) 22 | dec_model 
= make_decoder(train_model, (margs['intermediate_dim'], margs['latent_dim']), margs['n_classes'], use_x_prev=margs['use_x_prev']) 23 | 24 | # load data 25 | P = PianoData(args.train_file, 26 | batch_size=1, 27 | seq_length=args.t, 28 | squeeze_x=True) 29 | 30 | basenm = args.run_name 31 | for i in xrange(args.n): 32 | args.run_name = basenm + '_' + str(i) 33 | make_sample(P, dec_model, w_enc_model, z_enc_model, args, margs) 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('run_name', type=str, 38 | help='tag for current run') 39 | parser.add_argument("-n", type=int, default=1, 40 | help="number of samples") 41 | parser.add_argument("--use_z_prior", action="store_true", 42 | help="sample z from standard normal at each timestep") 43 | parser.add_argument('-t', type=int, default=32, 44 | help='number of timesteps per sample') 45 | parser.add_argument("--infer_w", action="store_true", 46 | help="infer w when generating") 47 | parser.add_argument("--no_x_prev", action="store_true", 48 | help="override use_x_prev") 49 | parser.add_argument('--sample_dir', type=str, 50 | default='../data/samples', 51 | help='basedir for saving output midi files') 52 | parser.add_argument('--model_dir', type=str, 53 | default='../data/models', 54 | help='basedir for saving model weights') 55 | parser.add_argument('-i', '--model_file', type=str, default='', 56 | help='preload model weights (no training)') 57 | parser.add_argument('--train_file', type=str, 58 | default='../data/input/JSB Chorales_Cs.pickle', 59 | help='file of training data (.pickle)') 60 | args = parser.parse_args() 61 | sample(args) 62 | # $ brew install timidity 63 | # $ timidity filename.mid 64 | -------------------------------------------------------------------------------- /code/cl_vae/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classifying variational autoencoders 3 | """ 4 | import argparse 5 | import 
numpy as np 6 | from keras import backend as K 7 | from keras.utils import to_categorical 8 | from utils.pianoroll import PianoData 9 | from utils.model_utils import get_callbacks, save_model_in_pieces, init_adam_wn, AnnealLossWeight 10 | from utils.weightnorm import data_based_init 11 | from model import get_model 12 | 13 | def train(args): 14 | P = PianoData(args.train_file, 15 | batch_size=args.batch_size, 16 | seq_length=args.seq_length, 17 | step_length=1, 18 | return_y_next=args.predict_next or args.use_x_prev, 19 | squeeze_x=True, 20 | squeeze_y=True) 21 | if args.seq_length > 1: 22 | X = np.vstack([P.x_train, P.x_valid, P.x_test, P.y_train, P.y_valid, P.y_test]) 23 | ix = X.sum(axis=0).sum(axis=0) > 0 24 | P.x_train = P.x_train[:,:,ix].reshape((len(P.x_train), -1)) 25 | P.x_valid = P.x_valid[:,:,ix].reshape((len(P.x_valid), -1)) 26 | P.x_test = P.x_test[:,:,ix].reshape((len(P.x_test), -1)) 27 | P.y_train = P.y_train[:,:,ix].reshape((len(P.y_train), -1)) 28 | P.y_valid = P.y_valid[:,:,ix].reshape((len(P.y_valid), -1)) 29 | P.y_test = P.y_test[:,:,ix].reshape((len(P.y_test), -1)) 30 | args.original_dim = ix.sum()*args.seq_length 31 | 32 | args.n_classes = len(np.unique(P.train_song_keys)) 33 | wtr = to_categorical(P.train_song_keys, args.n_classes) 34 | wva = to_categorical(P.valid_song_keys, args.n_classes) 35 | wte = to_categorical(P.test_song_keys, args.n_classes) 36 | 37 | assert not (args.predict_next and args.use_x_prev), "Can't use --predict_next if using --use_x_prev" 38 | callbacks = get_callbacks(args, patience=args.patience, 39 | min_epoch=max(args.kl_anneal, args.w_kl_anneal)+1, do_log=args.do_log) 40 | if args.kl_anneal > 0: 41 | assert args.kl_anneal <= args.num_epochs, "invalid kl_anneal" 42 | kl_weight = K.variable(value=0.1) 43 | callbacks += [AnnealLossWeight(kl_weight, name="kl_weight", final_value=1.0, n_epochs=args.kl_anneal)] 44 | else: 45 | kl_weight = 1.0 46 | if args.w_kl_anneal > 0: 47 | assert args.w_kl_anneal <= args.num_epochs, 
"invalid w_kl_anneal" 48 | w_kl_weight = K.variable(value=0.0) 49 | callbacks += [AnnealLossWeight(w_kl_weight, name="w_kl_weight", final_value=1.0, n_epochs=args.w_kl_anneal)] 50 | else: 51 | w_kl_weight = 1.0 52 | 53 | args.optimizer, was_adam_wn = init_adam_wn(args.optimizer) 54 | model, enc_model = get_model(args.batch_size, args.original_dim, (args.intermediate_dim, args.latent_dim), (args.intermediate_class_dim, args.n_classes), args.optimizer, args.class_weight, kl_weight, use_x_prev=args.use_x_prev, w_kl_weight=w_kl_weight, w_log_var_prior=args.w_log_var_prior) 55 | args.optimizer = 'adam-wn' if was_adam_wn else args.optimizer 56 | save_model_in_pieces(model, args) 57 | 58 | if args.use_x_prev: 59 | xtr = [P.y_train, P.x_train] 60 | xva = [P.y_valid, P.x_valid] 61 | else: 62 | xtr = P.x_train 63 | xva = P.x_valid 64 | 65 | data_based_init(model, P.x_train[:100]) 66 | history = model.fit(xtr, [P.y_train, wtr, wtr, P.y_train], 67 | shuffle=True, 68 | epochs=args.num_epochs, 69 | batch_size=args.batch_size, 70 | callbacks=callbacks, 71 | validation_data=(xva, [P.y_valid, wva, wva, P.y_valid])) 72 | best_ind = np.argmin([x if i >= max(args.kl_anneal, args.w_kl_anneal)+1 else np.inf for i,x in enumerate(history.history['val_loss'])]) 73 | best_loss = {k: history.history[k][best_ind] for k in history.history} 74 | return model, best_loss 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('run_name', type=str, 79 | help='tag for current run') 80 | parser.add_argument('--batch_size', type=int, default=100, 81 | help='batch size') 82 | parser.add_argument('--optimizer', type=str, default='adam-wn', 83 | help='optimizer name') # 'rmsprop' 84 | parser.add_argument('--num_epochs', type=int, default=200, 85 | help='number of epochs') 86 | parser.add_argument('--original_dim', type=int, default=88, 87 | help='input dim') 88 | parser.add_argument('--intermediate_dim', type=int, default=88, 89 | help='intermediate dim') 90 
| parser.add_argument('--latent_dim', type=int, default=2, 91 | help='latent dim') 92 | parser.add_argument('--seq_length', type=int, default=1, 93 | help='sequence length (concat)') 94 | parser.add_argument('--class_weight', type=float, default=1.0, 95 | help='relative weight on classifying key') 96 | parser.add_argument('--w_log_var_prior', type=float, default=0.0, 97 | help='w log var prior') 98 | parser.add_argument('--intermediate_class_dim', 99 | type=int, default=88, 100 | help='intermediate dims for classes') 101 | parser.add_argument("--do_log", action="store_true", 102 | help="save log files") 103 | parser.add_argument("--predict_next", action="store_true", 104 | help="use x_t to 'autoencode' x_{t+1}") 105 | parser.add_argument("--use_x_prev", action="store_true", 106 | help="use x_{t-1} to help z_t decode x_t") 107 | parser.add_argument('--patience', type=int, default=5, 108 | help='# of epochs, for early stopping') 109 | parser.add_argument("--kl_anneal", type=int, default=0, 110 | help="number of epochs before kl loss term is 1.0") 111 | parser.add_argument("--w_kl_anneal", type=int, default=0, 112 | help="number of epochs before w's kl loss term is 1.0") 113 | parser.add_argument('--log_dir', type=str, default='../data/logs', 114 | help='basedir for saving log files') 115 | parser.add_argument('--model_dir', type=str, 116 | default='../data/models', 117 | help='basedir for saving model weights') 118 | parser.add_argument('--train_file', type=str, 119 | default='../data/input/JSB Chorales_Cs.pickle', 120 | help='file of training data (.pickle)') 121 | args = parser.parse_args() 122 | train(args) 123 | -------------------------------------------------------------------------------- /code/cl_vrnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobeets/classifying-vae-lstm/042a898558a967a37d1246a51b4c978d08783b7b/code/cl_vrnn/__init__.py 
-------------------------------------------------------------------------------- /code/cl_vrnn/model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from keras import losses 4 | from keras import backend as K 5 | from keras import initializers 6 | from keras.layers import Input, Dense, LSTM, TimeDistributed, Lambda, concatenate, RepeatVector, Flatten 7 | from keras.models import Model 8 | 9 | def generate_sample(dec_model, w_enc_model, z_enc_model, x_seed, nsteps, use_x_prev, w_val=None, do_reset=True, seq_length=None, w_sample=False, w_discrete=False): 10 | """ 11 | for t = 1:nsteps 12 | 1. encode x_seed -> z_mean, z_log_var 13 | 2. sample z_t ~ N(z_mean, exp(z_log_var/2)) 14 | 3. decode z_t -> x_mean 15 | - note: may also use x_{t-1}, depending on the model 16 | 4. sample x_t ~ Bern(x_mean) 17 | 5. update x_seed := x_t 18 | 19 | NOTE: Looks like what's assumed by STORN 20 | """ 21 | if do_reset: 22 | dec_model.reset_states() 23 | w_enc_model.reset_states() 24 | z_enc_model.reset_states() 25 | # may need to seed model for multiple iters 26 | original_dim = x_seed.shape[-1] 27 | nseedsteps = x_seed.shape[0] if len(x_seed.shape) > 1 else 0 28 | Xs = np.zeros([nsteps+nseedsteps, original_dim]) 29 | Ws = np.tile(w_val, (len(Xs), 1)) 30 | if nseedsteps == 0: 31 | x_prev = x_seed[None,None,:] 32 | 33 | # decode w in the seed, or use provided value 34 | if w_val is None: 35 | ntms = x_seed.shape[1] 36 | w_ts = [] 37 | for i in np.arange(0, ntms, seq_length): 38 | xcs = x_seed[i:i+seq_length] 39 | if xcs.shape[0] == seq_length: 40 | w_ts.append(sample_w(w_enc_model.predict(xcs[None,:]), add_noise=w_sample)) 41 | w_t = np.vstack(w_ts).mean(axis=0)[None,:] 42 | # w_t = sample_w(w_enc_model.predict(x_seed[None,:seq_length]), add_noise=w_sample) 43 | if w_discrete: 44 | w_t = sample_w_discrete(w_t[0])[None,:] 45 | else: 46 | w_t = w_val 47 | for t in xrange(nsteps+nseedsteps): 48 | if t < 
nseedsteps: 49 | x_prev = x_seed[t][None,None,:] 50 | z_t = sample_z(z_enc_model.predict([x_prev, w_t])) 51 | 52 | # use previous X for decoding, if model requires this 53 | if use_x_prev: 54 | z_t = [z_t, x_prev, w_t] 55 | else: 56 | z_t = [z_t, w_t] 57 | x_t = sample_x(dec_model.predict(z_t)) 58 | x_prev = x_t 59 | Xs[t] = x_t 60 | return Xs[nseedsteps:] 61 | 62 | def sample_x(x_mean): 63 | return 1.0*(np.random.rand(*x_mean.squeeze().shape) <= x_mean) 64 | 65 | def sample_w_discrete(w): 66 | wn = np.zeros(w.shape) 67 | wn[np.random.choice(len(w), p=w/w.sum())] = 1. 68 | # wn[np.argmax(w)] = 1. 69 | return wn 70 | 71 | def sample_w(args, nsamps=1, nrm_samp=False, add_noise=True): 72 | w_mean, w_log_var = args 73 | if nsamps == 1: 74 | eps = np.random.randn(*((1, w_mean.flatten().shape[0]))) 75 | else: 76 | eps = np.random.randn(*((nsamps,) + w_mean.shape)) 77 | if add_noise: 78 | w_norm = w_mean + np.exp(w_log_var/2)*eps 79 | else: 80 | w_norm = w_mean + 0*eps 81 | if nrm_samp: 82 | return w_norm 83 | if nsamps == 1: 84 | w_norm = np.hstack([w_norm, np.zeros((w_norm.shape[0], 1))]) 85 | return np.exp(w_norm)/np.sum(np.exp(w_norm), axis=-1)[:,None] 86 | else: 87 | w_norm = np.dstack([w_norm, np.zeros(w_norm.shape[:-1]+ (1,))]) 88 | return np.exp(w_norm)/np.sum(np.exp(w_norm), axis=-1)[:,:,None] 89 | 90 | def sample_z(args, nsamps=1): 91 | Z_mean, Z_log_var = args 92 | if nsamps == 1: 93 | eps = np.random.randn(*Z_mean.squeeze().shape) 94 | else: 95 | eps = np.random.randn(*((nsamps,) + Z_mean.squeeze().shape)) 96 | return Z_mean + np.exp(Z_log_var/2) * eps 97 | 98 | def make_w_encoder(model, original_dim, n_classes, seq_length=1, batch_size=1): 99 | x = Input(batch_shape=(batch_size, seq_length, original_dim), name='x') 100 | 101 | # build label encoder 102 | hW = model.get_layer('hW') 103 | encoder_w_layer = model.get_layer('Wargs') 104 | # Wargs = encoder_w_layer(hW(x)) 105 | Wargs = encoder_w_layer(hW(Flatten()(x))) 106 | def get_w_mean(x): 107 | return 
x[:,:(n_classes-1)] 108 | def get_w_log_var(x): 109 | return x[:,(n_classes-1):] 110 | w_mean = Lambda(get_w_mean)(Wargs) 111 | w_log_var = Lambda(get_w_log_var)(Wargs) 112 | 113 | mdl = Model(x, [w_mean, w_log_var]) 114 | return mdl 115 | 116 | def make_z_encoder(model, original_dim, n_classes, (latent_dim_0, latent_dim), seq_length=1, batch_size=1, stateful=True): 117 | x = Input(batch_shape=(batch_size, seq_length, original_dim), name='x') 118 | w = Input(batch_shape=(batch_size, n_classes), name='w') 119 | xw = concatenate([x, RepeatVector(seq_length)(w)], axis=-1) 120 | 121 | # build latent encoder 122 | h = LSTM(latent_dim_0, 123 | # activation='relu', 124 | stateful=stateful, 125 | return_sequences=True, name='encoder_h')(xw) 126 | Zm = Dense(latent_dim, name='Z_mean_t') 127 | Zv = Dense(latent_dim, name='Z_log_var_t') 128 | z_mean = TimeDistributed(Zm, name='Z_mean')(h) 129 | z_log_var = TimeDistributed(Zv, name='Z_log_var')(h) 130 | zm = model.get_layer('Z_mean') 131 | zv = model.get_layer('Z_log_var') 132 | Zm.set_weights(zm.get_weights()) 133 | Zv.set_weights(zv.get_weights()) 134 | 135 | mdl = Model([x, w], [z_mean, z_log_var]) 136 | return mdl 137 | 138 | def make_decoder(model, original_dim, intermediate_dim, latent_dim, n_classes, use_x_prev, seq_length=1, batch_size=1, stateful=True): 139 | # build decoder 140 | Z = Input(batch_shape=(batch_size, seq_length, latent_dim), name='Z') 141 | if use_x_prev: 142 | Xp = Input(batch_shape=(batch_size, seq_length, original_dim), name='history') 143 | XpZ = concatenate([Xp, Z], axis=-1) 144 | else: 145 | XpZ = Z 146 | W = Input(batch_shape=(batch_size, n_classes), name='W') 147 | XpZ = concatenate([XpZ, RepeatVector(seq_length)(W)], axis=-1) 148 | 149 | decoder_h = LSTM(intermediate_dim, 150 | # activation='relu', 151 | return_sequences=True, 152 | stateful=stateful, name='decoder_h')(XpZ) 153 | X_mean_t = Dense(original_dim, activation='sigmoid', name='X_mean_t') 154 | X_decoded_mean = 
TimeDistributed(X_mean_t, name='X_decoded_mean')(decoder_h) 155 | 156 | if use_x_prev: 157 | decoder = Model([Z, Xp, W], X_decoded_mean) 158 | else: 159 | decoder = Model([Z, W], X_decoded_mean) 160 | decoder.get_layer('X_decoded_mean').set_weights(model.get_layer('X_decoded_mean').get_weights()) 161 | decoder.get_layer('decoder_h').set_weights(model.get_layer('decoder_h').get_weights()) 162 | return decoder 163 | 164 | def get_model(batch_size, original_dim, intermediate_dim, latent_dim, seq_length, n_classes, use_x_prev, optimizer, class_weight=1.0, kl_weight=1.0, dropout=0.0, w_kl_weight=1.0, w_log_var_prior = 0.0): 165 | """ 166 | if intermediate_dim == 0, uses the output of the lstms directly 167 | otherwise, adds dense layers 168 | """ 169 | X = Input(batch_shape=(batch_size, seq_length, original_dim), name='current') 170 | if use_x_prev: 171 | Xp = Input(batch_shape=(batch_size, seq_length, original_dim), name='history') 172 | 173 | # Sample w ~ logitNormal before continuing... 174 | hW = Dense(original_dim, activation='relu', name='hW')(Flatten()(X)) 175 | Wargs = Dense(2*(n_classes-1), name='Wargs')(hW) 176 | def get_w_mean(x): 177 | return x[:,:(n_classes-1)] 178 | def get_w_log_var(x): 179 | return x[:,(n_classes-1):] 180 | W_mean = Lambda(get_w_mean)(Wargs) 181 | W_log_var = Lambda(get_w_log_var)(Wargs) 182 | # sample latents, w 183 | def sampling_w(args): 184 | W_mean, W_log_var = args 185 | eps = K.random_normal(shape=(batch_size, (n_classes-1)), mean=0., stddev=1.0) 186 | W_samp = W_mean + K.exp(W_log_var/2) * eps 187 | W0 = concatenate([W_samp, K.zeros((batch_size,1))], axis=-1) 188 | num = K.exp(W0) 189 | denom = K.sum(num, axis=-1, keepdims=True) 190 | return num/denom 191 | W = Lambda(sampling_w, output_shape=(n_classes,), name='W')([W_mean, W_log_var]) 192 | 193 | XW = concatenate([X, RepeatVector(seq_length)(W)], axis=-1) 194 | 195 | # build encoder 196 | encoder_h = LSTM(intermediate_dim, 197 | # activation='relu', 198 | dropout=dropout, 199 | 
return_sequences=True, name='encoder_h')(XW) 200 | Z_mean_t = Dense(latent_dim, 201 | bias_initializer='zeros', 202 | kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.1), 203 | name='Z_mean_t') 204 | Z_log_var_t = Dense(latent_dim, 205 | bias_initializer='zeros', 206 | kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.1), 207 | name='Z_log_var_t') 208 | Z_mean = TimeDistributed(Z_mean_t, name='Z_mean')(encoder_h) 209 | Z_log_var = TimeDistributed(Z_log_var_t, name='Z_log_var')(encoder_h) 210 | 211 | # sample latents, z 212 | def sampling(args): 213 | Z_mean, Z_log_var = args 214 | eps = K.random_normal(shape=(batch_size, seq_length, latent_dim), mean=0., stddev=1.0) 215 | return Z_mean + K.exp(Z_log_var/2) * eps 216 | Z = Lambda(sampling, output_shape=(seq_length, latent_dim,))([Z_mean, Z_log_var]) 217 | 218 | if use_x_prev: 219 | XpZ = concatenate([Xp, Z], axis=-1) 220 | else: 221 | XpZ = Z 222 | XpZ = concatenate([XpZ, RepeatVector(seq_length)(W)], axis=-1) 223 | 224 | # build decoder 225 | decoder_h = LSTM(intermediate_dim, 226 | # activation='relu', 227 | dropout=dropout, 228 | return_sequences=True, name='decoder_h')(XpZ) 229 | X_mean_t = Dense(original_dim, 230 | bias_initializer='zeros', 231 | kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.1), 232 | activation='sigmoid', 233 | name='X_mean_t') 234 | X_decoded_mean = TimeDistributed(X_mean_t, name='X_decoded_mean')(decoder_h) 235 | 236 | def kl_loss(z_true, z_args): 237 | Z_mean = z_args[:,:,:latent_dim] 238 | Z_log_var = z_args[:,:,latent_dim:] 239 | return -0.5*K.sum(1 + Z_log_var - K.square(Z_mean) - K.exp(Z_log_var), axis=-1) 240 | 241 | def vae_loss(X, X_decoded_mean): 242 | return original_dim * losses.binary_crossentropy(X, X_decoded_mean) 243 | 244 | def w_rec_loss(w_true, w): 245 | return (n_classes-1) * losses.categorical_crossentropy(w_true, w) 246 | 247 | def w_kl_loss(w_true, w): 248 | # w_log_var_prior 249 | # return -0.5 * K.sum(1 + W_log_var - 
K.exp(W_log_var) - K.square(W_mean), axis=-1) 250 | # vs = 1 + W_log_var - K.exp(W_log_var) - K.square(W_mean) 251 | vs = 1 - w_log_var_prior + W_log_var - K.exp(W_log_var)/K.exp(w_log_var_prior) - K.square(W_mean)/K.exp(w_log_var_prior) 252 | return -0.5*K.sum(vs, axis=-1) 253 | 254 | # n.b. have to add very small amount to rename :( 255 | W2 = Lambda(lambda x: x+1e-10, name='W2')(W) 256 | Z_args = concatenate([Z_mean, Z_log_var], axis=-1, name='Z_args') 257 | if use_x_prev: 258 | model = Model([X, Xp], [X_decoded_mean, W, W2, Z_args]) 259 | else: 260 | model = Model(X, [X_decoded_mean, W, W2, Z_args]) 261 | model.compile(optimizer=optimizer, 262 | loss={'X_decoded_mean': vae_loss, 'W': w_kl_loss, 'W2': w_rec_loss, 'Z_args': kl_loss}, 263 | loss_weights={'X_decoded_mean': 1.0, 'W': w_kl_weight, 'W2': class_weight, 'Z_args': kl_weight}, 264 | metrics={'W': 'accuracy'}) 265 | 266 | encoder = Model(X, [Z_mean, Z_log_var, W]) 267 | return model, encoder 268 | 269 | def load_model(model_file, batch_size=None, seq_length=None, optimizer='adam'): 270 | """ 271 | there's a curently bug in the way keras loads models from .yaml 272 | that has to do with Lambdas 273 | so this is a hack for now... 
274 | """ 275 | margs = json.load(open(model_file.replace('.h5', '.json'))) 276 | # model = model_from_yaml(open(args.model_file)) 277 | optimizer = margs['optimizer'] if optimizer is None else optimizer 278 | batch_size = margs['batch_size'] if batch_size is None else batch_size 279 | seq_length = margs['seq_length'] if seq_length is None else seq_length 280 | model, enc_model = get_model(batch_size, margs['original_dim'], margs['intermediate_dim'], margs['latent_dim'], seq_length, margs['n_classes'], margs['use_x_prev'], optimizer, margs['class_weight']) 281 | model.load_weights(model_file) 282 | return model, enc_model, margs 283 | -------------------------------------------------------------------------------- /code/cl_vrnn/sample.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import argparse 3 | import numpy as np 4 | from keras.utils import to_categorical 5 | from utils.pianoroll import PianoData 6 | from utils.midi_utils import write_sample 7 | from model import load_model, generate_sample, make_decoder, make_w_encoder, make_z_encoder 8 | 9 | def gen_samples(P, dec_model, w_enc_model, z_enc_model, args, margs): 10 | key_map = {v: k for k, v in P.key_map.iteritems()} 11 | inds = np.arange(len(P.test_song_keys)) 12 | if args.c is not None: # user set key 13 | kys = np.array([key_map[k] for k in P.test_song_keys]) 14 | ix = (kys == args.c) 15 | inds = inds[ix] 16 | np.random.shuffle(inds) 17 | outfile = lambda j,i: args.run_name + '_' + str(j) 18 | outfile_seed = lambda j,i: args.run_name + str(j) + '_seed_' + str(i) 19 | for j, i in enumerate(inds[:args.n]): 20 | cur_key_ind = P.test_song_keys[i] 21 | w_val = None if args.infer_w else to_categorical(cur_key_ind, margs['n_classes']) 22 | x_seed = P.x_test[i] 23 | sample = generate_sample(dec_model, w_enc_model, z_enc_model, x_seed, args.t, margs['use_x_prev'], w_val=w_val, w_discrete=args.discrete_w, seq_length=margs['seq_length']) 24 | 25 | 
write_sample(sample, args.sample_dir, outfile(j,i), 26 | 'jsb' in args.train_file.lower()) 27 | write_sample(x_seed, args.sample_dir, outfile_seed(j,i), 28 | 'jsb' in args.train_file.lower()) 29 | 30 | def sample(args): 31 | # load models 32 | train_model, _, margs = load_model(args.model_file, optimizer='adam') 33 | w_enc_model = make_w_encoder(train_model, margs['original_dim'], 34 | margs['n_classes'], margs['seq_length']) 35 | z_enc_model = make_z_encoder(train_model, margs['original_dim'], 36 | margs['n_classes'], (margs['intermediate_dim'], margs['latent_dim'])) 37 | dec_model = make_decoder(train_model, margs['original_dim'], 38 | margs['intermediate_dim'], margs['latent_dim'], margs['n_classes'], 39 | margs['use_x_prev']) 40 | 41 | # load data 42 | P = PianoData(args.train_file, 43 | batch_size=1, 44 | seq_length=args.t, 45 | squeeze_x=False) 46 | 47 | gen_samples(P, dec_model, w_enc_model, z_enc_model, args, margs) 48 | 49 | if __name__ == '__main__': 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('run_name', type=str, 52 | help='tag for current run') 53 | parser.add_argument("--infer_w", action="store_true", 54 | help="infer w when generating") 55 | parser.add_argument("--discrete_w", action="store_true", 56 | help="sample discrete w when generating") 57 | parser.add_argument('-t', type=int, default=32, 58 | help='number of timesteps per sample') 59 | parser.add_argument('-n', type=int, default=1, 60 | help='number of samples') 61 | parser.add_argument('-c', type=str, 62 | help='set key of seed sample') 63 | parser.add_argument('--sample_dir', type=str, 64 | default='../data/samples', 65 | help='basedir for saving output midi files') 66 | parser.add_argument('-i', '--model_file', type=str, default='', 67 | help='preload model weights (no training)') 68 | parser.add_argument('--train_file', type=str, 69 | default='../data/input/JSB Chorales_Cs.pickle', 70 | help='file of training data (.pickle)') 71 | args = parser.parse_args() 72 | 
sample(args) 73 | # $ brew install timidity 74 | # $ timidity filename.mid 75 | -------------------------------------------------------------------------------- /code/cl_vrnn/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classifying VAE+LSTM (STORN) 3 | """ 4 | import argparse 5 | import numpy as np 6 | from keras import backend as K 7 | from keras.utils import to_categorical 8 | from utils.pianoroll import PianoData 9 | from utils.model_utils import get_callbacks, save_model_in_pieces, AnnealLossWeight, init_adam_wn 10 | from utils.weightnorm import data_based_init 11 | from model import get_model 12 | 13 | def train(args): 14 | P = PianoData(args.train_file, 15 | batch_size=args.batch_size, 16 | seq_length=args.seq_length, 17 | step_length=1, 18 | return_y_next=args.predict_next or args.use_x_prev, 19 | return_y_hist=True, 20 | squeeze_x=False, 21 | squeeze_y=False) 22 | 23 | args.n_classes = len(np.unique(P.train_song_keys)) 24 | w = to_categorical(P.train_song_keys, args.n_classes) 25 | wv = to_categorical(P.valid_song_keys, args.n_classes) 26 | 27 | print "Training with {} classes.".format(args.n_classes) 28 | assert not (args.predict_next and args.use_x_prev), "Can't use --predict_next if using --use_x_prev" 29 | 30 | callbacks = get_callbacks(args, patience=args.patience, 31 | min_epoch=max(args.kl_anneal, args.w_kl_anneal)+1, do_log=args.do_log) 32 | if args.kl_anneal > 0: 33 | assert args.kl_anneal <= args.num_epochs, "invalid kl_anneal" 34 | kl_weight = K.variable(value=0.1) 35 | callbacks += [AnnealLossWeight(kl_weight, name="kl_weight", final_value=1.0, n_epochs=args.kl_anneal)] 36 | else: 37 | kl_weight = 1.0 38 | if args.w_kl_anneal > 0: 39 | assert args.w_kl_anneal <= args.num_epochs, "invalid w_kl_anneal" 40 | w_kl_weight = K.variable(value=0.0) 41 | callbacks += [AnnealLossWeight(w_kl_weight, name="w_kl_weight", final_value=1.0, n_epochs=args.w_kl_anneal)] 42 | else: 43 | w_kl_weight = 1.0 44 
| 45 | args.optimizer, was_adam_wn = init_adam_wn(args.optimizer) 46 | model, _ = get_model(args.batch_size, args.original_dim, args.intermediate_dim, args.latent_dim, args.seq_length, args.n_classes, args.use_x_prev, args.optimizer, args.class_weight, kl_weight, w_kl_weight=w_kl_weight, w_log_var_prior=args.w_log_var_prior) 47 | args.optimizer = 'adam-wn' if was_adam_wn else args.optimizer 48 | save_model_in_pieces(model, args) 49 | 50 | print (P.x_train.shape, P.y_train.shape) 51 | if args.use_x_prev: 52 | x,y = [P.y_train, P.x_train], P.y_train 53 | xv,yv = [P.y_valid, P.x_valid], P.y_valid 54 | xt,yt = [P.y_test, P.x_test], P.y_test 55 | else: 56 | x,y = P.x_train, P.y_train 57 | xv,yv = P.x_valid, P.y_valid 58 | xt,yt = P.x_test, P.y_test 59 | xtr = x 60 | xva = xv 61 | xte = xt 62 | ytr = [y, w, w, y] 63 | yva = [yv, wv, wv, yv] 64 | 65 | data_based_init(model, x[:100]) 66 | history = model.fit(xtr, ytr, 67 | shuffle=True, 68 | epochs=args.num_epochs, 69 | batch_size=args.batch_size, 70 | callbacks=callbacks, 71 | validation_data=(xva, yva)) 72 | best_ind = np.argmin([x if i >= min(args.kl_anneal, args.w_kl_anneal) else np.inf for i,x in enumerate(history.history['val_loss'])]) 73 | best_loss = {k: history.history[k][best_ind] for k in history.history} 74 | return model, best_loss 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('run_name', type=str, 79 | help='tag for current run') 80 | parser.add_argument('--batch_size', type=int, default=200, 81 | help='batch size') 82 | parser.add_argument('--optimizer', type=str, default='adam-wn', 83 | help='optimizer name') 84 | parser.add_argument('--num_epochs', type=int, default=200, 85 | help='number of epochs') 86 | parser.add_argument('--original_dim', type=int, default=88, 87 | help='input dim') 88 | parser.add_argument('--latent_dim', type=int, default=2, 89 | help='latent dim') 90 | parser.add_argument('--intermediate_dim', type=int, default=88, 91 | 
help='intermediate dim') 92 | parser.add_argument('--seq_length', type=int, default=16, 93 | help='sequence length (to use as history)') 94 | parser.add_argument('--class_weight', type=float, default=1.0, 95 | help='relative weight on classifying key') 96 | parser.add_argument("--predict_next", action="store_true", 97 | help="use x_t to 'autoencode' x_{t+1}") 98 | parser.add_argument("--do_log", action="store_true", 99 | help="save log files") 100 | parser.add_argument("--w_log_var_prior", type=float, default=0.0, 101 | help="log variance prior on w") 102 | parser.add_argument("--kl_anneal", type=int, default=0, 103 | help="number of epochs before kl loss term is 1.0") 104 | parser.add_argument("--w_kl_anneal", type=int, default=0, 105 | help="number of epochs before w's kl loss term is 1.0") 106 | parser.add_argument('--patience', type=int, default=5, 107 | help='# of epochs, for early stopping') 108 | parser.add_argument("--use_x_prev", action="store_true", 109 | help="use x_{t-1} to help z_t decode x_t") 110 | parser.add_argument('--log_dir', type=str, default='../data/logs', 111 | help='basedir for saving log files') 112 | parser.add_argument('--model_dir', type=str, 113 | default='../data/models', 114 | help='basedir for saving model weights') 115 | parser.add_argument('--train_file', type=str, 116 | default='../data/input/JSB Chorales_Cs.pickle', 117 | help='file of training data (.pickle)') 118 | args = parser.parse_args() 119 | train(args) 120 | -------------------------------------------------------------------------------- /code/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobeets/classifying-vae-lstm/042a898558a967a37d1246a51b4c978d08783b7b/code/utils/__init__.py -------------------------------------------------------------------------------- /code/utils/midi_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
source: https://github.com/yoavz/music_rnn 3 | """ 4 | import sys, os 5 | from collections import defaultdict 6 | import numpy as np 7 | import midi 8 | 9 | RANGE = 128 10 | 11 | class MidiWriter(object): 12 | 13 | def __init__(self, verbose=False, default_vel=100): 14 | self.verbose = verbose 15 | self.note_range = RANGE 16 | self.default_velocity = default_vel 17 | 18 | def note_off(self, val, tick): 19 | self.track.append(midi.NoteOffEvent(tick=tick, pitch=val)) 20 | return 0 21 | 22 | def note_on(self, val, tick): 23 | self.track.append(midi.NoteOnEvent(tick=tick, pitch=val, velocity=self.default_velocity)) 24 | return 0 25 | 26 | def dump_sequence_to_midi(self, seq, output_filename, 27 | time_step=120, resolution=480, metronome=24, offset=21, 28 | format='final'): 29 | if self.verbose: 30 | print "Dumping sequence to MIDI file: {}".format(output_filename) 31 | print "Resolution: {}".format(resolution) 32 | print "Time Step: {}".format(time_step) 33 | 34 | pattern = midi.Pattern(resolution=resolution) 35 | self.track = midi.Track() 36 | 37 | # metadata track 38 | meta_track = midi.Track() 39 | time_sig = midi.TimeSignatureEvent() 40 | time_sig.set_numerator(4) 41 | time_sig.set_denominator(4) 42 | time_sig.set_metronome(metronome) 43 | time_sig.set_thirtyseconds(8) 44 | meta_track.append(time_sig) 45 | pattern.append(meta_track) 46 | 47 | # reshape to (SEQ_LENGTH X NUM_DIMS) 48 | if format == 'icml': 49 | # assumes seq is list of lists, where each inner list are all the midi notes that were non-zero at that given timestep 50 | sequence = np.zeros([len(seq), self.note_range]) 51 | sequence = [1 if i in tmstp else 0 for i in xrange(self.note_range) for tmstp in seq] 52 | sequence = np.reshape(sequence, [self.note_range,-1]).T 53 | elif format == 'flat': 54 | sequence = np.reshape(seq, [-1, self.note_range]) 55 | else: 56 | sequence = seq 57 | 58 | time_steps = sequence.shape[0] 59 | if self.verbose: 60 | print "Total number of time steps: {}".format(time_steps) 
61 | 62 | tick = time_step 63 | self.notes_on = { n: False for n in range(self.note_range) } 64 | # for seq_idx in range(188, 220): 65 | for seq_idx in range(time_steps): 66 | notes = np.nonzero(sequence[seq_idx, :])[0].tolist() 67 | # n.b. notes += 21 ?? 68 | # need to be in range 21,109 69 | notes = [n+offset for n in notes] 70 | 71 | # this tick will only be assigned to first NoteOn/NoteOff in 72 | # this time_step 73 | 74 | # NoteOffEvents come first so they'll have the tick value 75 | # go through all notes that are currently on and see if any 76 | # turned off 77 | for n in self.notes_on: 78 | if self.notes_on[n] and n not in notes: 79 | tick = self.note_off(n, tick) 80 | self.notes_on[n] = False 81 | 82 | # Turn on any notes that weren't previously on 83 | for note in notes: 84 | if not self.notes_on[note]: 85 | tick = self.note_on(note, tick) 86 | self.notes_on[note] = True 87 | 88 | tick += time_step 89 | 90 | # flush out notes 91 | for n in self.notes_on: 92 | if self.notes_on[n]: 93 | self.note_off(n, tick) 94 | tick = 0 95 | self.notes_on[n] = False 96 | 97 | pattern.append(self.track) 98 | midi.write_midifile(output_filename, pattern) 99 | 100 | def write_sample(sample, outdir, fnm, isHalfAsSlow=False): 101 | if isHalfAsSlow: 102 | sample = np.repeat(sample, 2, axis=0) 103 | fnm = os.path.join(outdir, fnm + '.mid') 104 | MidiWriter().dump_sequence_to_midi(sample, fnm) 105 | -------------------------------------------------------------------------------- /code/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path 3 | import numpy as np 4 | from keras import losses 5 | from keras import backend as K 6 | from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping, Callback 7 | from utils.weightnorm import AdamWithWeightnorm, data_based_init 8 | 9 | bincrossentropy = lambda x, xhat: (x*np.log(np.maximum(1e-15, xhat)) + (1-x)*np.log(np.maximum(1e-15, 1-xhat))) 10 
| 11 | def logmeanexp(vs, axis=0): 12 | m = np.amax(vs, axis=axis) 13 | return m + np.log(np.mean(np.exp(vs - m[None,:]), axis=axis)) 14 | 15 | def logsumexp(vs, axis=0): 16 | m = np.amax(vs, axis=axis) 17 | return m + np.log(np.sum(np.exp(vs - m[None,:]), axis=axis)) 18 | 19 | class AnnealLossWeight(Callback): 20 | """ 21 | increase the weight of a loss term by adjusting its value as a function of the epoch number 22 | """ 23 | def __init__(self, beta, name="beta", n_epochs=10, final_value=1.0, slope=0): 24 | super(AnnealLossWeight, self).__init__() 25 | self.beta = beta 26 | self.name = name 27 | self.slope = slope 28 | self.n_epochs = n_epochs 29 | self.start_value = K.eval(beta) 30 | self.final_value = final_value 31 | self.all_done = False 32 | 33 | def next_weight(self, x): 34 | if self.slope > 0: 35 | # sigmoid between 0.0 and 1.0 given x between 0.0 and 1.0 36 | return 1 / (1 + np.exp(-self.slope*(x-0.5))) 37 | else: 38 | # linear 39 | return 1.0*x 40 | 41 | def on_epoch_begin(self, epoch, logs={}): 42 | if self.all_done: 43 | return 44 | if epoch >= self.n_epochs: 45 | next_val = self.final_value 46 | self.all_done = True 47 | else: 48 | next_val = self.start_value + self.next_weight(1.0*epoch/self.n_epochs)*(self.final_value - self.start_value) 49 | K.set_value(self.beta, next_val) 50 | print "+++++ {}: {}".format(self.name, K.eval(self.beta)) 51 | 52 | def init_adam_wn(optimizer): 53 | if optimizer == 'adam-wn': 54 | adam_wn = AdamWithWeightnorm(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0) 55 | return adam_wn, True 56 | else: 57 | return optimizer, False 58 | 59 | class EarlyStoppingAfterEpoch(Callback): 60 | def __init__(self, monitor='val_loss', min_epoch=0, min_delta=0, patience=0, verbose=0, mode='auto'): 61 | super(EarlyStoppingAfterEpoch, self).__init__() 62 | 63 | self.monitor = monitor 64 | self.patience = patience 65 | self.verbose = verbose 66 | self.min_epoch = min_epoch 67 | self.min_delta = min_delta 68 | self.wait = 0 69 | 
self.stopped_epoch = 0 70 | assert mode in ['auto', 'min', 'max'] 71 | 72 | if mode == 'min': 73 | self.monitor_op = np.less 74 | elif mode == 'max': 75 | self.monitor_op = np.greater 76 | else: 77 | if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): 78 | self.monitor_op = np.greater 79 | else: 80 | self.monitor_op = np.less 81 | 82 | if self.monitor_op == np.greater: 83 | self.min_delta *= 1 84 | else: 85 | self.min_delta *= -1 86 | 87 | def on_train_begin(self, logs=None): 88 | # Allow instances to be re-used 89 | self.wait = 0 90 | self.stopped_epoch = 0 91 | self.best = np.Inf if self.monitor_op == np.less else -np.Inf 92 | 93 | def on_epoch_end(self, epoch, logs=None): 94 | if epoch < self.min_epoch: 95 | return 96 | current = logs.get(self.monitor) 97 | if self.monitor_op(current - self.min_delta, self.best): 98 | self.best = current 99 | self.wait = 0 100 | else: 101 | if self.wait >= self.patience: 102 | self.stopped_epoch = epoch 103 | self.model.stop_training = True 104 | self.wait += 1 105 | 106 | class ModelCheckpointAfterEpoch(Callback): 107 | def __init__(self, filepath, monitor, min_epoch=0, save_weights_only=True, save_best_only=True, mode='auto', verbose=False): 108 | super(ModelCheckpointAfterEpoch, self).__init__() 109 | assert save_best_only and not verbose 110 | assert mode in ['auto', 'min', 'max'] 111 | self.filepath = filepath 112 | self.monitor = monitor 113 | self.min_epoch = min_epoch 114 | self.save_weights_only = save_weights_only 115 | if mode == 'min': 116 | self.monitor_op = np.less 117 | self.best = np.Inf 118 | elif mode == 'max': 119 | self.monitor_op = np.greater 120 | self.best = -np.Inf 121 | else: 122 | if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): 123 | self.monitor_op = np.greater 124 | self.best = -np.Inf 125 | else: 126 | self.monitor_op = np.less 127 | self.best = np.Inf 128 | 129 | def on_epoch_end(self, epoch, logs=None): 130 | if epoch < self.min_epoch: 131 | return 132 | logs = logs or 
{} 133 | filepath = self.filepath.format(epoch=epoch, **logs) 134 | current = logs.get(self.monitor) 135 | if self.monitor_op(current, self.best): 136 | self.best = current 137 | if self.save_weights_only: 138 | self.model.save_weights(filepath, overwrite=True) 139 | else: 140 | self.model.save(filepath, overwrite=True) 141 | 142 | def get_callbacks(args, patience=5, min_epoch=0, do_log=False): 143 | # prepare to save model checkpoints 144 | chkpt_filename = os.path.join(args.model_dir, args.run_name + '.h5') 145 | checkpt = ModelCheckpointAfterEpoch(chkpt_filename, min_epoch=min_epoch, 146 | monitor='val_loss', save_weights_only=True, save_best_only=True) 147 | # checkpt = ModelCheckpoint(chkpt_filename, monitor='val_loss', save_weights_only=True, save_best_only=True) 148 | callbacks = [checkpt] 149 | if do_log: 150 | logging = TensorBoard(log_dir=os.path.join(args.log_dir, args.run_name)) 151 | callbacks.append(logging) 152 | if patience > 0: 153 | early_stop = EarlyStoppingAfterEpoch(monitor='val_loss', 154 | min_epoch=min_epoch, patience=patience, verbose=0) 155 | callbacks.append(early_stop) 156 | # early_stop = EarlyStopping(monitor='val_loss', patience=patience, verbose=0) 157 | callbacks.append(early_stop) 158 | return callbacks 159 | 160 | def save_model_in_pieces(model, args): 161 | # save model structure 162 | outfile = os.path.join(args.model_dir, args.run_name + '.yaml') 163 | with open(outfile, 'w') as f: 164 | f.write(model.to_yaml()) 165 | # save model args 166 | outfile = os.path.join(args.model_dir, args.run_name + '.json') 167 | json.dump(vars(args), open(outfile, 'w')) 168 | 169 | def LL_frame(y, yhat): 170 | return 88*losses.binary_crossentropy(y, yhat) 171 | -------------------------------------------------------------------------------- /code/utils/pianoroll.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code to load pianoroll data (.pickle) 3 | """ 4 | import numpy as np 5 | import cPickle 6 
# map each minor key (lowercase) to its relative major
rel_keys = {'a': 'C',
            'b-': 'D-',
            'b': 'D',
            'c': 'E-',
            'c#': 'E',
            'd-': 'F-',
            'd': 'F',
            'd#': 'F#',
            'e-': 'G-',
            'e': 'G',
            'f': 'A-',
            'f#': 'A',
            'g': 'B-',
            'g#': 'B',
            'a-': 'C-',
            }

def relative_major(k):
    """Return k if it is already a major key (uppercase), else the
    relative major of the minor key k."""
    return k if k.isupper() else rel_keys[k]

def pianoroll_to_song(roll, offset=21):
    """Convert a binary pianoroll [T x 88] into a list (per timestep) of
    sounding MIDI note numbers (88-key index + offset)."""
    f = lambda x: (np.where(x)[0]+offset).tolist()
    return [f(s) for s in roll]

def song_to_pianoroll(song, offset=21):
    """
    song = [(60, 72, 79, 88), (72, 79, 88), (67, 70, 76, 84), ...]

    Returns a binary [len(song) x 88] pianoroll; shifts by an octave if
    notes would fall outside the 88-key range after subtracting offset.
    """
    rolls = []
    all_notes = [y for x in song for y in x]
    if min(all_notes)-offset < 0:
        offset -= 12
    if max(all_notes)-offset > 87:
        offset += 12
    for notes in song:
        roll = np.zeros(88)
        roll[[n-offset for n in notes]] = 1.
        rolls.append(roll)
    return np.vstack(rolls)

def sliding_inds(n, seq_length, step_length):
    # start indices of every length-seq_length window in a length-n roll
    return np.arange(n-seq_length, step=step_length)

def sliding_window(roll, seq_length, step_length=1):
    """
    returns [n x seq_length x 88]
    if step_length == 1, then roll[i,1:] == roll[i+1,:-1]
    """
    rolls = []
    for i in sliding_inds(roll.shape[0], seq_length, step_length):
        rolls.append(roll[i:i+seq_length,:])
    if len(rolls) == 0:
        return np.array([])
    # dstack gives [seq_length x 88 x n]; swap to [n x seq_length x 88]
    return np.dstack(rolls).swapaxes(0,2).swapaxes(1,2)

def songs_to_pianoroll(songs, seq_length, step_length, inner_fcn=song_to_pianoroll):
    """
    songs = [song1, song2, ...]

    Returns (rolls, inds): stacked sliding windows from every song, and a
    parallel vector giving each window's source-song index.
    """
    rolls = [sliding_window(inner_fcn(s), seq_length, step_length) for s in songs]
    rolls = [r for r in rolls if len(r) > 0]  # drop songs shorter than seq_length
    inds = [i*np.ones((len(r),)) for i,r in enumerate(rolls)]
    return np.vstack(rolls), np.hstack(inds)

class PianoData:
    def __init__(self, train_file, batch_size=None, seq_length=1, step_length=1, return_y_next=True, return_y_hist=False, squeeze_x=True, squeeze_y=True, use_rel_major=True):
        """
        returns [n x seq_length x 88] where rows referring to the same song will overlap an amount determined by step_length

        specifying batch_size will ensure that mod(n, batch_size) == 0
        """
        D = cPickle.load(open(train_file))
        self.train_file = train_file  # .pickle source file
        self.batch_size = batch_size  # ensures that nsamples is divisible by this
        self.seq_length = seq_length  # returns [n x seq_length x 88]
        self.step_length = step_length  # controls overlap in rows of X
        self.return_y_next = return_y_next  # if True, y is next val of X; else y == X
        self.return_y_hist = return_y_hist  # if True, y is next val of X for each column of X; else y == [n x 1 x 88]
        self.squeeze_x = squeeze_x  # remove singleton dimensions in X?
        self.squeeze_y = squeeze_y  # remove singleton dimensions in y?
        self.use_rel_major = use_rel_major  # minor keys get mapped to their relative major, e.g.
'a' -> 'C' 90 | 91 | # sequences with song indices 92 | self.x_train, self.y_train, self.train_song_inds = self.make_xy(D['train']) 93 | self.x_test, self.y_test, self.test_song_inds = self.make_xy(D['test']) 94 | self.x_valid, self.y_valid, self.valid_song_inds = self.make_xy(D['valid']) 95 | 96 | # # song index per sequence 97 | # self.train_song_inds = self.song_inds(D['train']) 98 | # self.test_song_inds = self.song_inds(D['test']) 99 | # self.valid_song_inds = self.song_inds(D['valid']) 100 | 101 | # mode per sequence 102 | if 'train_mode' in D: 103 | self.train_song_modes = self.song_modes(D['train_mode'], self.train_song_inds) 104 | self.test_song_modes = self.song_modes(D['test_mode'], self.test_song_inds) 105 | self.valid_song_modes = self.song_modes(D['valid_mode'], self.valid_song_inds) 106 | if 'train_key' in D: 107 | D = self.update_keys(D) 108 | self.key_map = self.make_keymap(D) 109 | self.train_song_keys = self.song_keys(D['train_key'], self.train_song_inds) 110 | self.test_song_keys = self.song_keys(D['test_key'], self.test_song_inds) 111 | self.valid_song_keys = self.song_keys(D['valid_key'], self.valid_song_inds) 112 | 113 | def make_xy(self, songs): 114 | inner_fcn = song_to_pianoroll 115 | x_rolls, song_inds = songs_to_pianoroll(songs, self.seq_length + int(self.return_y_next), self.step_length, inner_fcn=inner_fcn) 116 | x_rolls = self.adjust_for_batch_size(x_rolls) 117 | song_inds = self.adjust_for_batch_size(song_inds) 118 | if self.return_y_next: # make Y the last col of X 119 | if self.return_y_hist: 120 | y_rolls = x_rolls[:,1:,:] 121 | else: 122 | y_rolls = x_rolls[:,-1,:] 123 | x_rolls = x_rolls[:,:-1,:] 124 | else: 125 | y_rolls = x_rolls 126 | if self.squeeze_x: # e.g., if X is [n x 1 x 88] 127 | x_rolls = x_rolls.squeeze() 128 | if self.squeeze_y: 129 | y_rolls = y_rolls.squeeze() 130 | return x_rolls, y_rolls, song_inds 131 | 132 | def song_modes(self, modes, song_inds): 133 | return np.array(modes)[song_inds.astype(int)] 134 | 135 
| def update_keys(self, D): 136 | if not self.use_rel_major: 137 | return 138 | D['train_key'] = [relative_major(k) for k in D['train_key']] 139 | D['test_key'] = [relative_major(k) for k in D['test_key']] 140 | D['valid_key'] = [relative_major(k) for k in D['valid_key']] 141 | return D 142 | 143 | def make_keymap(self, D): 144 | all_keys = np.unique(np.hstack([D['train_key'], D['test_key'], D['valid_key']])) 145 | return dict(zip(all_keys, xrange(len(all_keys)))) 146 | 147 | def song_keys(self, keys, song_inds): 148 | """ 149 | also converts keys to ints, e.g., ['A', 'B', 'C'] -> [0, 1, 2] 150 | """ 151 | key_inds = [self.key_map[k] for k in keys] 152 | return np.array(key_inds)[song_inds.astype(int)] 153 | 154 | def adjust_for_batch_size(self, items): 155 | if self.batch_size is None: 156 | return items 157 | mod = (items.shape[0] % self.batch_size) 158 | return items[:-mod] if mod > 0 else items 159 | 160 | if __name__ == '__main__': 161 | # train_file = '../data/input/Piano-midi_all.pickle' 162 | train_file = '../data/input/JSB Chorales_all.pickle' 163 | P = PianoData(train_file, seq_length=1, step_length=1) 164 | -------------------------------------------------------------------------------- /code/utils/weightnorm.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | from keras.optimizers import SGD,Adam 3 | import tensorflow as tf 4 | 5 | # adapted from keras.optimizers.SGD 6 | class SGDWithWeightnorm(SGD): 7 | def get_updates(self, params, constraints, loss): 8 | grads = self.get_gradients(loss, params) 9 | self.updates = [] 10 | 11 | lr = self.lr 12 | if self.initial_decay > 0: 13 | lr *= (1. / (1. 
+ self.decay * self.iterations)) 14 | self.updates .append(K.update_add(self.iterations, 1)) 15 | 16 | # momentum 17 | shapes = [K.get_variable_shape(p) for p in params] 18 | moments = [K.zeros(shape) for shape in shapes] 19 | self.weights = [self.iterations] + moments 20 | for p, g, m in zip(params, grads, moments): 21 | 22 | # if a weight tensor (len > 1) use weight normalized parameterization 23 | ps = K.get_variable_shape(p) 24 | if len(ps) > 1: 25 | 26 | # get weight normalization parameters 27 | V, V_norm, V_scaler, g_param, grad_g, grad_V = get_weightnorm_params_and_grads(p, g) 28 | 29 | # momentum container for the 'g' parameter 30 | V_scaler_shape = K.get_variable_shape(V_scaler) 31 | m_g = K.zeros(V_scaler_shape) 32 | 33 | # update g parameters 34 | v_g = self.momentum * m_g - lr * grad_g # velocity 35 | self.updates.append(K.update(m_g, v_g)) 36 | if self.nesterov: 37 | new_g_param = g_param + self.momentum * v_g - lr * grad_g 38 | else: 39 | new_g_param = g_param + v_g 40 | 41 | # update V parameters 42 | v_v = self.momentum * m - lr * grad_V # velocity 43 | self.updates.append(K.update(m, v_v)) 44 | if self.nesterov: 45 | new_V_param = V + self.momentum * v_v - lr * grad_V 46 | else: 47 | new_V_param = V + v_v 48 | 49 | # if there are constraints we apply them to V, not W 50 | if p in constraints: 51 | c = constraints[p] 52 | new_V_param = c(new_V_param) 53 | 54 | # wn param updates --> W updates 55 | add_weightnorm_param_updates(self.updates, new_V_param, new_g_param, p, V_scaler) 56 | 57 | else: # normal SGD with momentum 58 | v = self.momentum * m - lr * g # velocity 59 | self.updates.append(K.update(m, v)) 60 | 61 | if self.nesterov: 62 | new_p = p + self.momentum * v - lr * g 63 | else: 64 | new_p = p + v 65 | 66 | # apply constraints 67 | if p in constraints: 68 | c = constraints[p] 69 | new_p = c(new_p) 70 | 71 | self.updates.append(K.update(p, new_p)) 72 | return self.updates 73 | 74 | # adapted from keras.optimizers.Adam 75 | class 
class AdamWithWeightnorm(Adam):
    """Adam optimizer for weight-normalized parameterizations.

    Weight tensors of rank > 1 are optimized in the (V, g) weight-normalization
    parameterization of Salimans & Kingma (2016): W = g * V / ||V||, where the
    norm is taken over all axes but the last. Rank-1 parameters (biases etc.)
    fall back to plain Adam. `get_updates` is the only method changed with
    respect to `keras.optimizers.Adam`.
    """

    def get_updates(self, params, constraints, loss):
        """Build the list of Adam update ops for `params`.

        Parameters
        ----------
        params : list of weight variables to optimize.
        constraints : dict mapping a param to its constraint callable; for
            weight-normalized tensors the constraint is applied to V, not W.
        loss : scalar loss tensor.

        Returns
        -------
        list of update ops (also stored on `self.updates`).
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            lr *= (1. / (1. + self.decay * self.iterations))

        # Bias-corrected learning rate for step t (standard Adam).
        t = self.iterations + 1
        lr_t = lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))

        shapes = [K.get_variable_shape(p) for p in params]
        ms = [K.zeros(shape) for shape in shapes]
        vs = [K.zeros(shape) for shape in shapes]
        self.weights = [self.iterations] + ms + vs

        for p, g, m, v in zip(params, grads, ms, vs):

            # If a weight tensor (rank > 1), use the weight-normalized
            # parameterization. This is the only part changed w.r.t.
            # keras.optimizers.Adam.
            ps = K.get_variable_shape(p)
            if len(ps) > 1:

                # Get weight-normalization parameters and gradients.
                V, V_norm, V_scaler, g_param, grad_g, grad_V = \
                    get_weightnorm_params_and_grads(p, g)

                # Adam moment containers for the scalar 'g' parameter.
                V_scaler_shape = K.get_variable_shape(V_scaler)
                m_g = K.zeros(V_scaler_shape)
                v_g = K.zeros(V_scaler_shape)
                # FIX: register the g-moments as optimizer state; previously
                # they were omitted from self.weights, so saving/loading the
                # optimizer silently dropped them.
                self.weights += [m_g, v_g]

                # Adam update for the g parameter.
                m_g_t = (self.beta_1 * m_g) + (1. - self.beta_1) * grad_g
                v_g_t = (self.beta_2 * v_g) + (1. - self.beta_2) * K.square(grad_g)
                new_g_param = g_param - lr_t * m_g_t / (K.sqrt(v_g_t) + self.epsilon)
                self.updates.append(K.update(m_g, m_g_t))
                self.updates.append(K.update(v_g, v_g_t))

                # Adam update for the V parameters.
                m_t = (self.beta_1 * m) + (1. - self.beta_1) * grad_V
                v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(grad_V)
                new_V_param = V - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
                self.updates.append(K.update(m, m_t))
                self.updates.append(K.update(v, v_t))

                # If there are constraints we apply them to V, not W.
                if p in constraints:
                    c = constraints[p]
                    new_V_param = c(new_V_param)

                # Translate the (V, g) updates into updates of W and V_scaler.
                add_weightnorm_param_updates(self.updates, new_V_param,
                                             new_g_param, p, V_scaler)

            else:  # do optimization normally (plain Adam)
                m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
                v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
                p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

                self.updates.append(K.update(m, m_t))
                self.updates.append(K.update(v, v_t))

                new_p = p_t
                # Apply constraints.
                if p in constraints:
                    c = constraints[p]
                    new_p = c(new_p)
                self.updates.append(K.update(p, new_p))
        return self.updates


def get_weightnorm_params_and_grads(p, g):
    """Decompose weight tensor `p` and its gradient `g` into (V, g) form.

    Given the stored weight W = p and its gradient, recover the
    weight-normalization parameters V (direction) and g_param (scale), plus
    the gradients of the loss with respect to each.

    Parameters
    ----------
    p : weight variable of rank > 1; the norm is taken over all axes but the
        last (i.e. one scale per output unit — assumes TensorFlow's
        channels-last kernel layout).
    g : gradient tensor of the same shape as `p`.

    Returns
    -------
    (V, V_norm, V_scaler, g_param, grad_g, grad_V) where
    V_scaler = g_param / ||V|| is a persistent per-output-unit variable,
    initialized to ones so the effective parameters are unchanged.
    """
    ps = K.get_variable_shape(p)

    # Construct weight scaler: V_scaler = g/||V||.
    V_scaler_shape = (ps[-1],)  # assumes we're using tensorflow!
    V_scaler = K.ones(V_scaler_shape)  # init to ones, so effective parameters don't change

    # Get V parameters = ||V||/g * W.
    norm_axes = [i for i in range(len(ps) - 1)]
    V = p / tf.reshape(V_scaler, [1] * len(norm_axes) + [-1])

    # Split V_scaler into ||V|| and g parameters.
    V_norm = tf.sqrt(tf.reduce_sum(tf.square(V), norm_axes))
    g_param = V_scaler * V_norm

    # Get gradients in the (V, g) parameterization via the chain rule.
    grad_g = tf.reduce_sum(g * V, norm_axes) / V_norm
    grad_V = tf.reshape(V_scaler, [1] * len(norm_axes) + [-1]) * \
             (g - tf.reshape(grad_g / V_norm, [1] * len(norm_axes) + [-1]) * V)

    return V, V_norm, V_scaler, g_param, grad_g, grad_V


def add_weightnorm_param_updates(updates, new_V_param, new_g_param, W, V_scaler):
    """Append ops translating (V, g) updates into updates of W and V_scaler.

    Recombines the updated direction `new_V_param` and scale `new_g_param`
    into the stored weight W = g * V / ||V||, and refreshes the cached
    V_scaler = g / ||V||. The ops are appended to `updates` in place.
    """
    ps = K.get_variable_shape(new_V_param)
    norm_axes = [i for i in range(len(ps) - 1)]

    # Update W and V_scaler.
    new_V_norm = tf.sqrt(tf.reduce_sum(tf.square(new_V_param), norm_axes))
    new_V_scaler = new_g_param / new_V_norm
    new_W = tf.reshape(new_V_scaler, [1] * len(norm_axes) + [-1]) * new_V_param
    updates.append(K.update(W, new_W))
    updates.append(K.update(V_scaler, new_V_scaler))


# data based initialization for a given Keras model
def data_based_init(model, input):
    """Data-dependent initialization (Salimans & Kingma, 2016, sec. 3).

    Runs a forward pass on `input` and rescales each layer's weights and
    biases so that its pre-activations have zero mean and unit variance over
    the batch. Mutates the model's weights in place via the current session.

    Parameters
    ----------
    model : Keras model whose layers expose `.W` and `.b` attributes
        (Keras 1.x-style layers; layers without both attributes are skipped).
    input : dict feed_dict, a single numpy array, or a list of numpy arrays
        matching `model.inputs`.
    """
    # input can be dict, numpy array, or list of numpy arrays
    if type(input) is dict:
        feed_dict = input
    elif type(input) is list:
        feed_dict = {tf_inp: np_inp for tf_inp, np_inp in zip(model.inputs, input)}
    else:
        feed_dict = {model.inputs[0]: input}

    # add learning phase if required
    if model.uses_learning_phase and K.learning_phase() not in feed_dict:
        feed_dict.update({K.learning_phase(): 1})

    # get all layer name, output, weight, bias tuples
    layer_output_weight_bias = []
    for l in model.layers:
        if hasattr(l, 'W') and hasattr(l, 'b'):
            assert l.built
            # if more than one node, only use the first
            layer_output_weight_bias.append((l.name, l.get_output_at(0), l.W, l.b))

    # iterate over our list and do data dependent init
    sess = K.get_session()
    for l, o, W, b in layer_output_weight_bias:
        print('Performing data dependent initialization for layer ' + l)
        # Moments of the layer output over all but the last axis.
        m, v = tf.nn.moments(o, [i for i in range(len(o.get_shape()) - 1)])
        s = tf.sqrt(v + 1e-10)  # epsilon guards against zero variance
        updates = tf.group(
            W.assign(W / tf.reshape(s, [1] * (len(W.get_shape()) - 1) + [-1])),
            b.assign((b - m) / s))
        sess.run(updates, feed_dict)
/data/samples/JSB10_CL-VRNN_infer.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobeets/classifying-vae-lstm/042a898558a967a37d1246a51b4c978d08783b7b/data/samples/JSB10_CL-VRNN_infer.wav -------------------------------------------------------------------------------- /data/samples/JSB10_Data.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobeets/classifying-vae-lstm/042a898558a967a37d1246a51b4c978d08783b7b/data/samples/JSB10_Data.wav -------------------------------------------------------------------------------- /data/samples/JSB10_VAE.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobeets/classifying-vae-lstm/042a898558a967a37d1246a51b4c978d08783b7b/data/samples/JSB10_VAE.wav -------------------------------------------------------------------------------- /data/samples/JSB10_VRNN.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobeets/classifying-vae-lstm/042a898558a967a37d1246a51b4c978d08783b7b/data/samples/JSB10_VRNN.wav -------------------------------------------------------------------------------- /data/samples/PMall_CL-VAE_infer.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobeets/classifying-vae-lstm/042a898558a967a37d1246a51b4c978d08783b7b/data/samples/PMall_CL-VAE_infer.wav -------------------------------------------------------------------------------- /data/samples/PMall_CL-VAE_true.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobeets/classifying-vae-lstm/042a898558a967a37d1246a51b4c978d08783b7b/data/samples/PMall_CL-VAE_true.wav -------------------------------------------------------------------------------- 
/data/samples/PMall_Data.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobeets/classifying-vae-lstm/042a898558a967a37d1246a51b4c978d08783b7b/data/samples/PMall_Data.wav -------------------------------------------------------------------------------- /data/samples/PMall_VAE.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobeets/classifying-vae-lstm/042a898558a967a37d1246a51b4c978d08783b7b/data/samples/PMall_VAE.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Keras==2.0.0 2 | tensorflow==1.0.1 --------------------------------------------------------------------------------