├── .gitignore
├── examples
│   ├── gan_example_0.mid
│   ├── gan_example_1.mid
│   ├── skip_connected_boosting_lstm.mid
│   ├── after 15 epochs, batch 2 example 0.mid
│   ├── after 20 epochs, batch 2 example 0.mid
│   └── after 30 epochs, batch 2 example 0.mid
├── LICENSE
├── README.md
└── rnn-cnn-gan-enhancer.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
--------------------------------------------------------------------------------
/examples/gan_example_0.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marekzidek/cnn-lstm-gan-music-generation/HEAD/examples/gan_example_0.mid
--------------------------------------------------------------------------------
/examples/gan_example_1.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marekzidek/cnn-lstm-gan-music-generation/HEAD/examples/gan_example_1.mid
--------------------------------------------------------------------------------
/examples/skip_connected_boosting_lstm.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marekzidek/cnn-lstm-gan-music-generation/HEAD/examples/skip_connected_boosting_lstm.mid
--------------------------------------------------------------------------------
/examples/after 15 epochs, batch 2 example 0.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marekzidek/cnn-lstm-gan-music-generation/HEAD/examples/after 15 epochs, batch 2 example 0.mid
--------------------------------------------------------------------------------
/examples/after 20 epochs, batch 2 example 0.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marekzidek/cnn-lstm-gan-music-generation/HEAD/examples/after 20 epochs, batch 2 example 0.mid
--------------------------------------------------------------------------------
/examples/after 30 epochs, batch 2 example 0.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marekzidek/cnn-lstm-gan-music-generation/HEAD/examples/after 30 epochs, batch 2 example 0.mid
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Marek Židek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# cnn-lstm-gan-music-generation

EDIT: An outline of this GAN scheme is included in my [Bachelor thesis](150999794678469.pdf), which experiments with LSTM skip-connections for music generation and includes a robust survey-based evaluation.

This program enhances the outputs of a recurrent neural music generator; it can also be used to tune human-composed music.

Work in progress: the repository doesn't contain the MIDI data, required directories, or trained weights, just up-to-date code as my backup.

![how it works](https://cloud.githubusercontent.com/assets/13591225/25017098/c981fd9e-2082-11e7-8574-aaea5a4174bc.gif)

Examples of enhancements can be found in the `examples` folder.

It takes music from this network (from my MusicNetwork repo):

![first model](https://cloud.githubusercontent.com/assets/13591225/25025151/6071e26e-20a1-11e7-870d-25f623b627b8.png)

The proposed model diagram (the code is more detailed; the picture may contain mistakes):

![EGAN](https://cloud.githubusercontent.com/assets/13591225/25025240/b8df0620-20a1-11e7-9e9c-f45dd9c91e19.png)
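
## Usage

A rough sketch of how the script is driven. The `music/`, `lstm_outputs/`, `outputs/` and `discriminator_memory/` directories it reads and writes, and the saved weight files it expects (`discriminator 2` for training, `generator` for generating), are assumptions read off the code and are not shipped with this repository:

```
# train the GAN enhancer (the discriminator compares generated songs in
# discriminator_memory/ against real MIDI files in music/)
python rnn-cnn-gan-enhancer.py --mode train --batch_size 2 --song_length 400

# enhance pre-generated latent music from lstm_outputs/ into outputs/
python rnn-cnn-gan-enhancer.py --mode generate --nb 3
```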
--------------------------------------------------------------------------------
/rnn-cnn-gan-enhancer.py:
--------------------------------------------------------------------------------
from keras.models import Sequential, Model
from keras.layers import Dense, Input
from keras.layers import Reshape
from keras.layers.core import Activation, Flatten, Dropout, Lambda
from keras.layers.convolutional import Convolution1D, Convolution2D, ZeroPadding1D
from keras.layers.pooling import GlobalMaxPooling1D, AveragePooling1D, MaxPooling1D
from keras.layers.wrappers import TimeDistributed, Bidirectional
from keras.layers.recurrent import GRU, LSTM, SimpleRNN
from keras.regularizers import l2, l1, l1_l2
from keras import backend
from keras.layers.advanced_activations import ELU, PReLU, LeakyReLU
from keras.layers import Layer
import keras.layers
import keras
import itertools
import sys
import random
import os

from keras.optimizers import SGD
import numpy as np
import argparse
import math
from midi_to_matrix import *
import train as trainLoadMusic

np.random.seed(26041994)

generating_size = 50 * 8  # not recommended to change here or via --song_length, as a matching change
                          # is also needed in the model that generates the latent music
note_span_with_ligatures = 156  # used as a "magic number" constant in convolution filters


def generator_model():

    input_song = Input(shape=(generating_size, 156))
    amplified = keras.layers.core.Lambda(lambda x: x * 12 - 6)(input_song)

    forget = LSTM(250, return_sequences=True)(input_song)
    forget = keras.layers.local.LocallyConnected1D(156, 1, activation='sigmoid', kernel_initializer='zeros',
                                                   bias_initializer=keras.initializers.Constant(-6.0))(forget)

    conservativity_sum_0 = Lambda(lambda x: backend.sum(x, axis=1), output_shape=lambda s: (s[0], s[2]))(forget)
    conservativity_sum_0 = Lambda(lambda x: backend.sum(x[:, ::2], axis=1), output_shape=lambda s: (s[0], 1))(conservativity_sum_0)
    thresh = keras.layers.advanced_activations.ThresholdedReLU(theta=400.0)(conservativity_sum_0)

    # this will be zero if the threshold was not reached,
    thresh = Lambda(lambda x: backend.log(x + 1))(thresh)

    # penalize the forget gate for forgetting too much; log-scaled to avoid a vanishing gradient in the next activation
    forget = Lambda(lambda x: x - thresh)(forget)

    forget = keras.layers.core.Lambda(lambda x: (5.5 * x))(forget)
    forget = Activation('sigmoid')(forget)

    # scale by 12 so it can outvote the amplified (x * 12 - 6) residual connection
    forget = keras.layers.core.Lambda(lambda x: -(12 * x))(forget)

    add = LSTM(250, kernel_initializer='zeros', return_sequences=True)(input_song)
    add = keras.layers.local.LocallyConnected1D(156, 1, activation='sigmoid', kernel_initializer='zeros',
                                                bias_initializer=keras.initializers.Constant(-6.0))(add)
    conservativity_sum_1 = Lambda(lambda x: backend.sum(x, axis=1), output_shape=lambda s: (s[0], s[2]))(add)
    conservativity_sum_1 = Lambda(lambda x: backend.sum(x[:, ::2], axis=1), output_shape=lambda s: (s[0], 1))(conservativity_sum_1)

    threshold = keras.layers.advanced_activations.ThresholdedReLU(theta=60.0)(conservativity_sum_1)

    # this will be zero if the threshold was not reached; log-scaled to avoid a vanishing gradient in the succeeding add gate
    threshold = Lambda(lambda x: backend.log(x + 1))(threshold)

    # penalize the add gate for adding too much
    add = Lambda(lambda x: x - threshold)(add)
    add = keras.layers.core.Lambda(lambda x: (5.5 * x))(add)
    add = Activation('sigmoid')(add)

    # scale by 12 so it can outvote the amplified residual connection
    add = keras.layers.core.Lambda(lambda x: (x * 12))(add)

    residual = keras.layers.add([amplified, forget, add])
    residual = Activation('sigmoid')(residual)

    # mask the ligatures/articulations so they are not learned where the corresponding note is not played
    mask_for_articulation = keras.layers.advanced_activations.ThresholdedReLU(theta=0.5)(residual)
    mask_for_articulation = Lambda(lambda x: x[:, :, 1::2], output_shape=lambda s: (s[0], s[1], 78))(mask_for_articulation)
    play = Lambda(lambda x: x[:, :, ::2], output_shape=lambda s: (s[0], s[1], 78))(residual)
    reshaped_play = Reshape((generating_size, 78, 1))(play)
    arti = Lambda(lambda x: x[:, :, 1::2], output_shape=lambda s: (s[0], s[1], 78))(residual)
    articulate = Lambda(lambda x: x * mask_for_articulation)(arti)
    reshaped_articulate = Reshape((generating_size, 78, 1))(articulate)
    final = keras.layers.concatenate([reshaped_play, reshaped_articulate])
    final = Reshape((generating_size, 156))(final)

    model = Model(inputs=input_song, outputs=final)
    return model
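

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not used by the pipeline): the
# 156-wide state vector interleaves 78 note-on entries with 78 articulation
# ("ligature") entries, so x[:, :, ::2] selects "play" and x[:, :, 1::2]
# selects "articulate", mirroring the Lambda slices in generator_model above.
def _demo_deinterleave(state_matrix):
    # state_matrix: np.ndarray of shape (batch, generating_size, 156)
    play = state_matrix[:, :, ::2]          # (batch, generating_size, 78) note-on
    articulate = state_matrix[:, :, 1::2]   # (batch, generating_size, 78) ligature
    return play, articulate
# ---------------------------------------------------------------------------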


def discriminator_model():

    ## check joint

    input_song = Input(shape=(generating_size, 156))

    joint = Reshape((generating_size, 156, 1))(input_song)
    joint = TimeDistributed(Convolution1D(filters=20, kernel_size=8, padding='valid', strides=2))(joint)  # minor third (3 semitones = watching 4 notes) * ligature (hence stride 2)
    # 39
    joint = LeakyReLU(0.3)(joint)
    joint = TimeDistributed(Convolution1D(filters=40, kernel_size=3, padding='valid', strides=1))(joint)  # major third (4 semitones) and perfect fourth (5 semitones)
    # 38
    joint = LeakyReLU(0.3)(joint)
    joint = TimeDistributed(Convolution1D(filters=200, kernel_size=3, padding='valid', strides=1))(joint)  # fifths (7 semitones = watching 8 notes)
    # 17
    joint = LeakyReLU(0.3)(joint)
    joint = TimeDistributed(MaxPooling1D(2))(joint)  # keep the dominant chord

    joint = TimeDistributed(Convolution1D(filters=300, kernel_size=3, padding='valid', strides=1))(joint)
    # 5
    print(joint.shape)
    joint = LeakyReLU(0.3)(joint)
    joint = TimeDistributed(MaxPooling1D(2))(joint)
    print(joint.shape)
    joint = TimeDistributed(Convolution1D(filters=400, kernel_size=3, padding='valid', strides=2))(joint)
    # 5
    print(joint.shape)
    joint = LeakyReLU(0.3)(joint)
    # (gen_size, 66, 20)
    cross_joint = Reshape((generating_size, 7 * 400))(joint)
    joint = TimeDistributed(Dense(50))(cross_joint)
    joint = Flatten()(joint)
    joint = Dropout(0.5)(joint)
    joint = Dense(1, kernel_regularizer=keras.regularizers.l2(0.1))(joint)
    joint = LeakyReLU(0.3)(joint)

    ## check rhythm

    rhythm = ZeroPadding1D(4)(input_song)  # pad 4 on both sides so the locally connected kernel can be 9 (LocallyConnected1D doesn't support 'same' padding yet)
    rhythm = Convolution1D(filters=20 * 20, kernel_size=24, strides=16, padding='valid')(rhythm)
    rhythm = LeakyReLU(0.3)(rhythm)
    rhythm = Reshape((generating_size // 16, 20, 20))(rhythm)
    rhythm = TimeDistributed(keras.layers.local.LocallyConnected1D(filters=100, kernel_size=9, padding='valid'))(rhythm)
    rhythm = LeakyReLU(0.3)(rhythm)
    rhythm = TimeDistributed(Dense(50))(rhythm)
    rhythm = Flatten()(rhythm)
    rhythm = Dropout(0.5)(rhythm)
    rhythm = Dense(1, kernel_regularizer=keras.regularizers.l2(0.1))(rhythm)
    rhythm = LeakyReLU(0.3)(rhythm)

    ## check structure

    structure = Reshape((generating_size, 156, 1))(input_song)
    structure = TimeDistributed(Convolution1D(filters=16, kernel_size=8, padding='same', strides=4))(structure)  # third * ligature
    # 78
    structure = LeakyReLU(0.3)(structure)
    structure = TimeDistributed(Convolution1D(filters=32, kernel_size=2, padding='valid', strides=2))(structure)  # fifths
    structure = TimeDistributed(MaxPooling1D(2))(structure)
    structure = Reshape((generating_size, 9 * 32))(structure)
    structure = Convolution1D(80, 2)(structure)
    structure = LeakyReLU(0.3)(structure)
    structure = Convolution1D(120, 2, dilation_rate=2)(structure)
    structure = LeakyReLU(0.3)(structure)
    structure = Convolution1D(160, 2, dilation_rate=4)(structure)
    structure = LeakyReLU(0.3)(structure)
    structure = Convolution1D(200, 2, dilation_rate=8)(structure)
    structure = LeakyReLU(0.3)(structure)
    structure = TimeDistributed(Dense(50))(structure)
    structure = Dropout(0.5)(structure)
    structure = Flatten()(structure)
    structure = Dense(1, kernel_regularizer=keras.regularizers.l2(0.1))(structure)
    structure = LeakyReLU(0.3)(structure)

    ## check consistency

    differences = Reshape((generating_size, 156, 1))(input_song)
    differences = TimeDistributed(Convolution1D(filters=1, kernel_size=2, padding='same', strides=2))(differences)  # third * ligature
    # 78
    differences = LeakyReLU(0.3)(differences)
    differences = Reshape((generating_size, 78))(differences)
    differences = Convolution1D(150, 2)(differences)
    differences = SimpleRNN(200, return_sequences=True)(differences)
    differences = TimeDistributed(Dense(1, kernel_regularizer=keras.regularizers.l2(0.1)))(differences)
    differences = LeakyReLU(0.3)(differences)
    differences = Flatten()(differences)
    differences = Dropout(0.5)(differences)
    differences = Dense(1, kernel_regularizer=keras.regularizers.l2(0.1))(differences)
    differences = LeakyReLU(0.3)(differences)

    ## check continuity

    continuity = GRU(150, return_sequences=True)(cross_joint)
    continuity = LeakyReLU(0.3)(continuity)
    continuity = TimeDistributed(Dense(1, kernel_regularizer=keras.regularizers.l2(0.1)))(continuity)
    continuity = Flatten()(continuity)
    continuity = Dropout(0.5)(continuity)
    continuity = Dense(1, kernel_regularizer=keras.regularizers.l2(0.1))(continuity)
    continuity = LeakyReLU(0.3)(continuity)

    final = keras.layers.concatenate([joint, rhythm, structure, continuity, differences])
    final = Dropout(0.35)(final)
    final = Dense(1)(final)
    # final = Activation('sigmoid')(final)  # do not use in a Wasserstein-style GAN (which also uses mean_squared_error)

    model = Model(inputs=input_song, outputs=final)
    return model


def generator_with_discriminator_model(generator, discriminator):

    model = Sequential()
    model.add(generator)

    discriminator.trainable = False
    # model.add(Reshape((generating_size * 156, 1)))
    model.add(discriminator)

    return model
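

# A minimal shape sanity check for the two models above (a sketch meant to be
# run by hand, not part of the training pipeline):
def _sanity_check_shapes():
    generator = generator_model()
    discriminator = discriminator_model()
    print(generator.output_shape)      # expect (None, generating_size, 156)
    print(discriminator.output_shape)  # expect (None, 1), one critic score per song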


def createBatches(music_list, SONG_LENGTH, BATCH_SIZE):

    if len(music_list) == 0:
        print("No music found in the music or lstm_outputs folder")
        sys.exit()

    batch_random_indices = list(range(len(music_list)))
    random.shuffle(batch_random_indices)

    train_X = []
    for i in range(len(music_list) // BATCH_SIZE):
        batch = []
        for j in range(BATCH_SIZE):
            prepart = music_list[batch_random_indices.pop(0)]
            part = []

            # flatten each timestep from (78, 2) note/articulation pairs into one 156-long vector
            for timestep in prepart:
                merged = list(itertools.chain.from_iterable(timestep))
                part.append(merged)

            batch.append(part)
        batch = np.array(batch)
        train_X.append(batch)

    return np.array(train_X)


## This method is ugly; rewrite it later so the same code isn't duplicated
## almost verbatim (a deduplicated sketch follows after the function).
def generate_from_midis(path_memory, path_train):

    batch_size = 3
    if len(os.listdir(path_memory)) < len(os.listdir(path_train)):
        memory_music_names = os.listdir(path_memory)
        memory_music_names = [midi for midi in memory_music_names if midi[-4:] in ('.mid', '.MID')]
        while 1:
            random_indices = list(range(len(os.listdir(path_train))))
            random.shuffle(random_indices)
            random_pos = 0
            batch_pos = 0
            for name in [midi for midi in os.listdir(path_train) if midi[-4:] in ('.mid', '.MID')]:
                x_memory = midiToMatrix(os.path.join(path_memory, memory_music_names[random_indices[random_pos] % len(memory_music_names)]))[:generating_size]
                x_train = midiToMatrix(os.path.join(path_train, name))[:generating_size]
                random_pos += 1
                if len(x_memory) < generating_size or len(x_train) < generating_size:
                    continue
                x_memory = np.array(x_memory).reshape((1, generating_size, note_span_with_ligatures))
                x_train = np.array(x_train).reshape((1, generating_size, note_span_with_ligatures))

                x = np.concatenate([x_memory, x_train])

                if batch_pos % batch_size == 0:
                    batch_x = x
                    batch_y = np.array([-0.5, 0.5]).reshape((2, 1))
                else:
                    batch_x = np.concatenate([batch_x, x])
                    batch_y = np.concatenate([batch_y, np.array([-0.5, 0.5]).reshape((2, 1))])

                batch_pos += 1
                if batch_pos % batch_size == 0:
                    yield (batch_x, batch_y)
    else:
        train_music_names = os.listdir(path_train)
        train_music_names = [midi for midi in train_music_names if midi[-4:] in ('.mid', '.MID')]

        while 1:
            random_indices = list(range(len(os.listdir(path_memory))))
            random.shuffle(random_indices)
            random_pos = 0
            batch_pos = 0
            for name in [midi for midi in os.listdir(path_memory) if midi[-4:] in ('.mid', '.MID')]:
                x_memory = midiToMatrix(os.path.join(path_memory, name))[:generating_size]
                x_train = midiToMatrix(os.path.join(path_train, train_music_names[random_indices[random_pos] % len(train_music_names)]))[:generating_size]
                random_pos += 1
                if len(x_memory) < generating_size or len(x_train) < generating_size:
                    continue
                x_memory = np.array(x_memory).reshape((1, generating_size, note_span_with_ligatures))
                x_train = np.array(x_train).reshape((1, generating_size, note_span_with_ligatures))

                x = np.concatenate([x_memory, x_train])

                if batch_pos % batch_size == 0:
                    batch_x = x
                    batch_y = np.array([-0.5, 0.5]).reshape((2, 1))
                else:
                    batch_x = np.concatenate([batch_x, x])
                    batch_y = np.concatenate([batch_y, np.array([-0.5, 0.5]).reshape((2, 1))])

                batch_pos += 1
                if batch_pos % batch_size == 0:
                    yield (batch_x, batch_y)
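

# Deduplicated sketch of generate_from_midis, as promised in the comment
# above (a hypothetical refactor; train() still calls the original). It keeps
# the labelling convention (-0.5 for generated "memory" songs, 0.5 for real
# training songs) but uses a single code path: iterate as many steps as the
# larger folder has files and sample both folders randomly.
def generate_from_midis_dedup(path_memory, path_train, batch_size=3):
    def midis_in(path):
        return [m for m in os.listdir(path) if m[-4:] in ('.mid', '.MID')]

    memory_names = midis_in(path_memory)
    train_names = midis_in(path_train)

    while True:
        batch_x, batch_y, batch_pos = None, None, 0
        for _ in range(max(len(memory_names), len(train_names))):
            x_memory = midiToMatrix(os.path.join(path_memory, random.choice(memory_names)))[:generating_size]
            x_train = midiToMatrix(os.path.join(path_train, random.choice(train_names)))[:generating_size]
            if len(x_memory) < generating_size or len(x_train) < generating_size:
                continue
            pair_x = np.array([x_memory, x_train]).reshape((2, generating_size, note_span_with_ligatures))
            pair_y = np.array([-0.5, 0.5]).reshape((2, 1))
            if batch_pos % batch_size == 0:
                batch_x, batch_y = pair_x, pair_y
            else:
                batch_x = np.concatenate([batch_x, pair_x])
                batch_y = np.concatenate([batch_y, pair_y])
            batch_pos += 1
            if batch_pos % batch_size == 0:
                yield (batch_x, batch_y)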


def train(BATCH_SIZE, SONG_LENGTH, EPOCH):

    sys.setrecursionlimit(100000)

    print("loading latent music")
    latent_music = trainLoadMusic.loadMusic("lstm_outputs", SONG_LENGTH)
    latent_music = list(latent_music.values())

    print("creating discriminator")
    discriminator = discriminator_model()
    print("created discriminator")
    generator = generator_model()
    generator_with_discriminator = generator_with_discriminator_model(generator, discriminator)

    # d_optim = keras.optimizers.RMSprop(lr=0.0001, rho=0.9, epsilon=1e-08, decay=0.0)
    d_optim = keras.optimizers.Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
    g_optim = keras.optimizers.RMSprop(lr=0.0001, rho=0.9, epsilon=1e-08, decay=0.0, clipnorm=0.5, clipvalue=0.5)
    # g_optim = keras.optimizers.Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
    # g_optim = keras.optimizers.Adadelta(lr=1.0, rho=0.95, epsilon=1e-8, decay=0.0)

    generator.compile(loss='binary_crossentropy', optimizer="adam")
    generator_with_discriminator.compile(loss='mean_squared_error', optimizer=g_optim)

    discriminator.trainable = True
    discriminator.compile(loss='mean_squared_error', optimizer=d_optim)

    print("loading weights")
    # generator.load_weights('generator_pretrain')
    print("loaded weights")

    latent_batches = createBatches(latent_music, SONG_LENGTH, BATCH_SIZE)

    # save memory for the discriminator so it doesn't forget that these are FAKE
    for j in range(len(latent_batches)):
        generated_music = generator.predict_on_batch(latent_batches[j])
        for k in range(BATCH_SIZE):
            song = generated_music[k].reshape((SONG_LENGTH, note_span_with_ligatures // 2, 2))
            song = generate_from_probabilities(song)
            matrixToMidi(song, 'discriminator_memory/after pre epochs {} example {}'.format(j, k))

    print("loading disc weights")
    discriminator.load_weights('discriminator 2')
    print("loaded disc weights")

    discriminator.fit_generator(generate_from_midis("discriminator_memory", "music"), steps_per_epoch=40, epochs=7)
    print(discriminator.layers[-1].get_weights())

    # note: the EPOCH argument is currently unused; the loop below is hard-coded
    for epoch in range(1, 100):

        # if epoch % 8 == 0:
        #     generator.load_weights('generator_identity')

        for indexer in range(3 * len(latent_batches)):
            # latent_batches has shape (num_batches, batch_size, song_length, notes)

            discriminator.trainable = False
            latent_batches = createBatches(latent_music, SONG_LENGTH, BATCH_SIZE)

            for i in range(3):

                what_to_train = [0.5 for j in range(BATCH_SIZE)]
                g_loss = generator_with_discriminator.train_on_batch(
                    latent_batches[indexer % len(latent_batches)], np.array([what_to_train]).reshape((BATCH_SIZE, 1)))
                evalu = generator_with_discriminator.predict_on_batch(latent_batches[indexer % len(latent_batches)])
                print(evalu)
                print("epoch %d, batch %d gen_loss : %f" % (epoch, indexer % len(latent_batches), g_loss))
                generated_music = generator.predict_on_batch(latent_batches[indexer % len(latent_batches)])
                song_0 = generated_music[0].reshape((SONG_LENGTH, note_span_with_ligatures // 2, 2))
                song_0 = generate_from_probabilities(song_0)
                matrixToMidi(song_0, 'outputs/test {} {}'.format(i, indexer))

            if indexer % 10 == 0:
                song_0 = generated_music[0].reshape((SONG_LENGTH, note_span_with_ligatures // 2, 2))
                song_0 = generate_from_probabilities(song_0, conservativity=1.0)
                matrixToMidi(song_0, 'outputs/after {} epochs {} example 0'.format(epoch, indexer))

            if indexer % len(latent_batches) == 0:
                generator.save_weights('generator {} indexer {} '.format(epoch, indexer), True)

        generator.save_weights('generator {}'.format(epoch), True)

        # clear the discriminator memory folder before regenerating it
        folder = 'discriminator_memory'
        for the_file in os.listdir(folder):
            file_path = os.path.join(folder, the_file)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
                # elif os.path.isdir(file_path): shutil.rmtree(file_path)
            except Exception as e:
                print(e)

        print("saving generated songs for discriminator")
        # save memory for the discriminator so it doesn't forget that these are FAKE
        for i in range(len(latent_batches)):
            generated_music = generator.predict_on_batch(latent_batches[i])
            for j in range(BATCH_SIZE):
                song = generated_music[j].reshape((SONG_LENGTH, note_span_with_ligatures // 2, 2))
                song = generate_from_probabilities(song)
                matrixToMidi(song, 'discriminator_memory/after {} epochs, batch {} example {}'.format(epoch, i, j))

        print("saved all songs")

        discriminator.trainable = True

        discriminator.fit_generator(generate_from_midis("discriminator_memory", "music"), steps_per_epoch=5, epochs=2)

        discriminator.save_weights('discriminator {}'.format(epoch))
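

# Note on the objective (an observation about the code above, not original
# documentation): the discriminator is trained with mean_squared_error against
# targets of -0.5 for generated songs and 0.5 for real ones, and the stacked
# generator_with_discriminator pushes its output toward 0.5. This makes it a
# least-squares-style critic rather than a sigmoid classifier, which is why
# the final sigmoid in discriminator_model is commented out. In miniature:
#
#   d_fake, d_real = -0.5, 0.5                                # a perfect critic's scores
#   mse = ((d_fake - (-0.5)) ** 2 + (d_real - 0.5) ** 2) / 2  # == 0.0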


def generate(SONG_LENGTH, nb):

    generator = generator_model()
    generator.compile(loss='binary_crossentropy', optimizer="SGD")
    generator.load_weights('generator')

    print("loading latent music")
    latent_music = trainLoadMusic.loadMusic("lstm_outputs", SONG_LENGTH)

    for i in range(nb):

        latent = random.choice(list(latent_music.values()))

        song = generator.predict(latent, verbose=1)

        song = song.reshape((SONG_LENGTH, note_span_with_ligatures // 2, 2))
        song = generate_from_probabilities(song)
        matrixToMidi(song, 'outputs/example {}'.format(i))


def generate_from_probabilities(song, conservativity=1):

    # independent Bernoulli draw per note-on and per articulation cell
    for i in range(len(song)):
        for j in range(len(song[i])):
            song[i][j][0] = np.random.sample(1) < song[i][j][0] * conservativity
            song[i][j][1] = np.random.sample(1) < song[i][j][1] * conservativity
    return song
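

# Tiny illustration of the sampler above (hypothetical helper, not called by
# the pipeline): each (note-on, articulation) cell is kept by an independent
# Bernoulli draw, and conservativity < 1 scales the keep-probability down.
def _demo_sampling():
    probs = np.full((4, 2, 2), 0.9)  # 4 timesteps, 2 notes, (play, articulate)
    sampled = generate_from_probabilities(probs.copy(), conservativity=0.5)
    print(sampled.mean())  # about 0.45 of the cells survive on average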


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, default="generate")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--song_length", type=int, default=50 * 8)
    parser.add_argument("--nb", type=int, default=1)
    parser.add_argument("--epoch", type=int, default=1)
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()
    generating_size = args.song_length
    if args.mode == "train":
        train(BATCH_SIZE=args.batch_size, SONG_LENGTH=args.song_length, EPOCH=args.epoch)
    elif args.mode == "generate":
        generate(SONG_LENGTH=args.song_length, nb=args.nb)
--------------------------------------------------------------------------------