├── NeuralComposer.zip ├── README.md ├── cpu.theanorc ├── gpu.theanorc ├── live_edit.py ├── load_songs.py ├── midi.py ├── train.py └── util.py /NeuralComposer.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HackerPoet/Composer/0fab6d616962eab8f9fd31309441b4599253b71a/NeuralComposer.zip -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Composer 2 | Generates video game music using neural networks. 3 | https://youtu.be/UWxfnNXlVy8 4 | -------------------------------------------------------------------------------- /cpu.theanorc: -------------------------------------------------------------------------------- 1 | [global] 2 | floatX=float32 3 | device=cpu 4 | -------------------------------------------------------------------------------- /gpu.theanorc: -------------------------------------------------------------------------------- 1 | [global] 2 | floatX=float32 3 | device=cuda 4 | 5 | [nvcc] 6 | compiler_bindir=C:\Program Files (x86)\Microsoft Visual Studio 12.0\VC\bin 7 | 8 | [dnn] 9 | enabled=True 10 | include_path=C:\CUDA\v8.0\include 11 | library_path=C:\CUDA\v8.0\lib\x64 12 | -------------------------------------------------------------------------------- /live_edit.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import random, math 3 | import numpy as np 4 | import cv2 5 | import pyaudio 6 | import midi 7 | import wave 8 | 9 | #User constants 10 | device = "cpu" 11 | dir_name = 'History/' 12 | sub_dir_name = 'e2000/' 13 | sample_rate = 48000 14 | note_dt = 2000 #Num Samples 15 | note_duration = 20000 #Num Samples 16 | note_decay = 5.0 / sample_rate 17 | num_params = 120 18 | num_measures = 16 19 | num_sigmas = 5.0 20 | note_thresh = 32 21 | use_pca = True 22 | is_ae = True 23 | 24 | background_color = 
def audio_callback(in_data, frame_count, time_info, status):
    """Render the next ``frame_count`` mono 16-bit samples of the current song.

    This function is passed as ``stream_callback`` when the PyAudio output
    stream is opened, and the WAV export in the main loop also calls it
    directly with ``status=None`` to render the song offline.

    Each call:
      1. honours a pending ``audio_reset`` by rewinding the playback state,
      2. returns silence while ``audio_pause`` is set (unless ``status`` is
         ``None``, i.e. the offline export is driving the callback),
      3. schedules the notes of ``cur_notes`` whose onset falls inside this
         buffer -- one row of a measure every ``note_dt`` samples,
      4. mixes every still-sounding note using the waveform selected by
         ``instrument`` and an exponential decay,
      5. forgets notes older than ``note_duration`` and rewinds once the
         last measure has been scheduled.

    Args:
        in_data: Unused (the stream is output-only).
        frame_count: Number of samples to produce.
        time_info: Unused.
        status: Only tested against ``None`` to tell the offline export
            apart from live playback.
            NOTE(review): presumably PyAudio never passes ``None`` here --
            confirm against the PyAudio callback documentation.

    Returns:
        A ``(data, pyaudio.paContinue)`` tuple where ``data`` is the raw
        bytes of ``frame_count`` int16 samples.
    """
    global audio_time
    global audio_notes
    global audio_reset
    global note_time
    global note_time_dt

    #Check if needs restart
    if audio_reset:
        audio_notes = []
        audio_time = 0
        note_time = 0
        note_time_dt = 0
        audio_reset = False

    #Check if paused
    if audio_pause and status is not None:
        # Silence must use the same sample format as the normal path
        # (int16); float32 zeros would be twice as many bytes per frame.
        silence = np.zeros((frame_count,), dtype=np.int16)
        return (silence.tobytes(), pyaudio.paContinue)

    #Find and add any notes in this time window
    cur_dt = note_dt
    window_end = audio_time + frame_count
    while note_time_dt < window_end:
        # divmod keeps both indices integral on Python 2 and Python 3.
        measure_ix, note_ix = divmod(note_time, note_h)
        if measure_ix >= num_measures:
            break
        notes = np.where(cur_notes[measure_ix, note_ix] >= note_thresh)[0]
        for note in notes:
            freq = 2 * 38.89 * pow(2.0, note / 12.0) / sample_rate
            audio_notes.append((note_time_dt, freq))
        note_time += 1
        note_time_dt += cur_dt

    #Generate the tones
    data = np.zeros((frame_count,), dtype=np.float32)
    frame_offsets = np.arange(frame_count)
    for t, f in audio_notes:
        # Sample index relative to the note onset.  Building it from a
        # fixed-length integer range guarantees exactly frame_count values
        # even when t is fractional (note_dt becomes a float once the
        # tempo control is applied).
        x = np.maximum(frame_offsets + (audio_time - t), 0)
        phase = x * f

        if instrument == 0:
            w = np.sign(1 - np.mod(phase, 2)) #Square
        elif instrument == 1:
            w = np.mod(phase - 1, 2) - 1 #Sawtooth
        elif instrument == 2:
            w = 2*np.abs(np.mod(phase - 0.5, 2) - 1) - 1 #Triangle
        else:
            # Instrument 3, and the fallback for any unexpected value so
            # that w is always defined.
            w = np.sin(phase * math.pi) #Sine

        # Samples before (or exactly at) the onset were clamped to x == 0
        # above and must stay silent.
        w[x == 0] = 0
        w *= volume * np.exp(-x*note_decay)
        data += w
    data = np.clip(data, -32000, 32000).astype(np.int16)

    #Remove notes that are too old
    audio_time += frame_count
    audio_notes = [(t, f) for t, f in audio_notes if audio_time < t + note_duration]

    #Reset if loop occurs
    if note_time // note_h >= num_measures:
        audio_time = 0
        note_time = 0
        note_time_dt = 0
        audio_notes = []

    #Return the sound clip
    return (data.tobytes(), pyaudio.paContinue)
155 | import os 156 | os.environ['THEANORC'] = "./" + device + ".theanorc" 157 | os.environ['KERAS_BACKEND'] = "theano" 158 | import theano 159 | print "Theano Version: " + theano.__version__ 160 | import keras 161 | print "Keras Version: " + keras.__version__ 162 | from keras.models import Model, Sequential, load_model 163 | from keras.layers import Dense, Activation, Dropout, Flatten, Reshape 164 | from keras.layers.convolutional import Conv2D, Conv2DTranspose, ZeroPadding2D 165 | from keras.layers.pooling import MaxPooling2D 166 | from keras.layers.noise import GaussianNoise 167 | from keras.layers.local import LocallyConnected2D 168 | from keras.optimizers import Adam, RMSprop, SGD 169 | from keras.regularizers import l2 170 | from keras.losses import binary_crossentropy 171 | from keras.layers.advanced_activations import ELU 172 | from keras.preprocessing.image import ImageDataGenerator 173 | from keras.utils import plot_model 174 | from keras import backend as K 175 | K.set_image_data_format('channels_first') 176 | 177 | print "Loading Encoder..." 178 | model = load_model(dir_name + 'model.h5') 179 | enc = K.function([model.get_layer('encoder').input, K.learning_phase()], 180 | [model.layers[-1].output]) 181 | enc_model = Model(inputs=model.input, outputs=model.get_layer('pre_encoder').output) 182 | 183 | print "Loading Statistics..." 184 | means = np.load(dir_name + sub_dir_name + 'means.npy') 185 | evals = np.load(dir_name + sub_dir_name + 'evals.npy') 186 | evecs = np.load(dir_name + sub_dir_name + 'evecs.npy') 187 | stds = np.load(dir_name + sub_dir_name + 'stds.npy') 188 | 189 | print "Loading Songs..." 
def update_mouse_click(mouse_pos):
    """Record which widget a mouse press at ``mouse_pos`` landed on.

    Sets ``mouse_pressed`` to 1 and ``cur_slider_ix`` when the press is on a
    parameter slider, or to 2 and ``cur_control_ix`` when it is on one of
    the control bars.  Presses elsewhere leave the state untouched.

    Args:
        mouse_pos: ``(x, y)`` window coordinates of the press.
    """
    global cur_slider_ix
    global cur_control_ix
    global mouse_pressed

    x = mouse_pos[0] - sliders_x
    y = mouse_pos[1] - sliders_y
    if 0 <= x < sliders_w and 0 <= y < sliders_h:
        slider_ix = int(x // slider_w)
        # slider_w is rounded down, so the slider strip can be wider than
        # slider_num * slider_w.  A press in that leftover gap used to
        # select a parameter that has no slider on screen.
        if slider_ix < slider_num:
            cur_slider_ix = slider_ix
            mouse_pressed = 1

    x = mouse_pos[0] - controls_x
    y = mouse_pos[1] - controls_y
    if 0 <= x < controls_w and 0 <= y < controls_h:
        cur_control_ix = int(x // control_w)
        mouse_pressed = 2


def apply_controls():
    """Derive the playback settings from the three control bars.

    ``cur_controls`` holds values in [0, 1]; they are mapped to the note
    threshold (10..210), the number of audio samples per note row
    (200..2000) and the volume (0..6000).
    """
    global note_thresh
    global note_dt
    global volume

    note_thresh = (1.0 - cur_controls[0]) * 200 + 10
    note_dt = (1.0 - cur_controls[1]) * 1800 + 200
    volume = cur_controls[2] * 6000


def update_mouse_move(mouse_pos):
    """Apply a drag at ``mouse_pos`` to the widget selected by the last press.

    For a slider (``mouse_pressed == 1``) the vertical position is mapped to
    a value in ``[-num_sigmas, +num_sigmas]`` and stored in ``cur_params``;
    ``needs_update`` is raised so the notes are regenerated.  For a control
    bar (``mouse_pressed == 2``) the horizontal position is mapped to [0, 1]
    and the playback settings are re-derived.

    Args:
        mouse_pos: ``(x, y)`` window coordinates of the pointer.
    """
    global needs_update

    if mouse_pressed == 1:
        y = mouse_pos[1] - sliders_y
        if 0 <= y <= slider_h:
            val = (float(y) / slider_h - 0.5) * (num_sigmas * 2)
            cur_params[cur_slider_ix] = val
            needs_update = True
    elif mouse_pressed == 2:
        x = mouse_pos[0] - (controls_x + cur_control_ix*control_w)
        if control_pad <= x <= control_w - control_pad:
            val = float(x - control_pad) / (control_w - control_pad*2)
            cur_controls[cur_control_ix] = val
            apply_controls()


def draw_controls():
    """Draw each control as a filled bar inside a black outline."""
    w = control_w - control_pad*2
    h = control_h - control_pad*2
    y = controls_y + control_pad
    for i in range(control_num):
        x = controls_x + i * control_w + control_pad
        col = control_colors[i]

        pygame.draw.rect(screen, col, (x, y, int(w*cur_controls[i]), h))
        pygame.draw.rect(screen, (0,0,0), (x, y, w, h), 1)


def draw_sliders():
    """Draw the first ``slider_num`` parameters as vertical sliders.

    Every slider has a centre line, one tick per sigma (the zero tick is
    black) and a circular knob at the current value of ``cur_params``.
    """
    y = sliders_y
    # The tick rows are identical for every slider, so compute them once.
    num_ticks = int(num_sigmas * 2 + 1)
    ticks = []
    for j in range(num_ticks):
        ly = y + slider_h/2.0 + (j-num_sigmas)*slider_h/(num_sigmas*2.0)
        ticks.append((int(ly), j - num_sigmas == 0))
    knob_radius = int((slider_w - tick_pad) // 2)

    for i in range(slider_num):
        slider_color = slider_colors[i % len(slider_colors)]
        x = sliders_x + i * slider_w

        # pygame wants integer pixel coordinates for the knob.
        cx = int(x + slider_w // 2)
        pygame.draw.line(screen, slider_color, (cx, y), (cx, y + slider_h))

        cx_1 = x + tick_pad
        cx_2 = x + slider_w - tick_pad
        for ly, is_zero in ticks:
            col = (0,0,0) if is_zero else slider_color
            pygame.draw.line(screen, col, (cx_1, ly), (cx_2, ly))

        py = y + int((cur_params[i] / (num_sigmas * 2) + 0.5) * slider_h)
        pygame.draw.circle(screen, slider_color, (cx, py), knob_radius)


def notes_to_img(notes):
    """Render all measures into one RGB image for the notes surface.

    Measures are laid out on a ``notes_rows`` x ``notes_cols`` grid with
    ``note_pad`` pixels of grey (64) padding around each one.  The red
    channel shows the raw intensity scaled so that ``note_thresh`` maps to
    255; green and blue are 255 only where the note is at or above the
    threshold, so played notes appear white and near misses dark red.

    Args:
        notes: Array indexed as ``notes[measure]`` yielding one
            ``note_h`` x ``note_w`` measure of intensities.

    Returns:
        ``uint8`` array of shape ``(notes_w, notes_h, 3)``, the axis order
        expected by ``pygame.surfarray.blit_array``.
    """
    # int() keeps the shape and loop bounds integral even if the layout
    # constants were produced by true division.
    output = np.full((3, int(notes_h), int(notes_w)), 64, dtype=np.uint8)
    scale = 255.0 / note_thresh
    cell_w = note_w + note_pad*2
    cell_h = note_h + note_pad*2

    for i in range(int(notes_rows)):
        y = note_pad + i*cell_h
        for j in range(notes_cols):
            x = note_pad + j*cell_w

            measure = np.rot90(notes[i*notes_cols + j])
            played_only = np.where(measure >= note_thresh, 255, 0)
            output[0,y:y+note_h,x:x+note_w] = np.minimum(measure * scale, 255.0)
            output[1,y:y+note_h,x:x+note_w] = played_only
            output[2,y:y+note_h,x:x+note_w] = played_only

    return np.transpose(output, (2, 1, 0))


def draw_notes():
    """Blit the note grid and draw the yellow playback cursor over it."""
    pygame.surfarray.blit_array(notes_surface, notes_to_img(cur_notes))

    # note_time is advanced by audio_callback; read it once so the measure
    # and the position inside the measure come from the same value.
    measure_ix, note_ix = divmod(note_time, note_h)
    grid_row, grid_col = divmod(measure_ix, notes_cols)
    x = notes_x + note_pad + grid_col * (note_w + note_pad*2) + note_ix
    y = notes_y + note_pad + grid_row * (note_h + note_pad*2)
    pygame.draw.rect(screen, (255,255,0), (x, y, 4, note_h), 0)
model.layers[0].input_dim 356 | 357 | if use_pca: 358 | cur_params = np.dot(x - means, evecs.T) / evals 359 | else: 360 | cur_params = (x - means) / stds 361 | 362 | needs_update = True 363 | audio_reset = True 364 | if event.key == pygame.K_g: 365 | audio_pause = True 366 | audio_reset = True 367 | midi.samples_to_midi(cur_notes, 'live.mid', 16, note_thresh) 368 | save_audio = '' 369 | while True: 370 | save_audio += audio_callback(None, 1024, None, None)[0] 371 | if audio_time == 0: 372 | break 373 | wave_output = wave.open('live.wav', 'w') 374 | wave_output.setparams((1, 2, sample_rate, 0, 'NONE', 'not compressed')) 375 | wave_output.writeframes(save_audio) 376 | wave_output.close() 377 | audio_pause = False 378 | if event.key == pygame.K_ESCAPE: 379 | running = False 380 | break 381 | if event.key == pygame.K_SPACE: 382 | audio_pause = not audio_pause 383 | if event.key == pygame.K_TAB: 384 | audio_reset = True 385 | if event.key == pygame.K_1: 386 | instrument = 0 387 | if event.key == pygame.K_2: 388 | instrument = 1 389 | if event.key == pygame.K_3: 390 | instrument = 2 391 | if event.key == pygame.K_4: 392 | instrument = 3 393 | if event.key == pygame.K_c: 394 | y = np.expand_dims(np.where(cur_notes > note_thresh, 1, 0), 0) 395 | x = enc_model.predict(y)[0] 396 | if use_pca: 397 | cur_params = np.dot(x - means, evecs.T) / evals 398 | else: 399 | cur_params = (x - means) / stds 400 | needs_update = True 401 | 402 | #Check if we need an update 403 | if needs_update: 404 | if use_pca: 405 | x = means + np.dot(cur_params * evals, evecs) 406 | else: 407 | x = means + stds * cur_params 408 | x = np.expand_dims(x, axis=0) 409 | y = enc([x, 0])[0][0] 410 | cur_notes = (y * 255.0).astype(np.uint8) 411 | needs_update = False 412 | 413 | #Draw to the screen 414 | screen.fill(background_color) 415 | draw_notes() 416 | draw_sliders() 417 | draw_controls() 418 | 419 | #Flip the screen buffer 420 | pygame.display.flip() 421 | pygame.time.wait(10) 422 | 423 | #Close the 
# Walk the music folders, convert every MIDI file into per-measure samples
# and collect them (plus a pitch-centred copy of each song) for training.
patterns = {}  # NOTE(review): not referenced anywhere else in this script
dirs = ["Music", "download", "rag", "pop", "misc"]
all_samples = []
all_lens = []
print("Loading Songs...")
for song_dir in dirs:
    for root, subdirs, files in os.walk(song_dir):
        for fname in files:
            # Accept the extension in any letter case (.mid, .MID, .midi).
            if not fname.lower().endswith(('.mid', '.midi')):
                continue
            # os.path.join instead of a hard-coded backslash, so the
            # paths are also valid outside Windows.
            path = os.path.join(root, fname)
            try:
                samples = midi.midi_to_samples(path)
            except Exception:
                # Unreadable or out-of-range files are reported and
                # skipped; Ctrl-C is no longer swallowed.
                print("ERROR  " + path)
                continue
            # Ignore songs shorter than 8 measures (this also skips files
            # rejected for mixed time signatures, which yield []).
            if len(samples) < 8:
                continue

            samples, lens = util.generate_add_centered_transpose(samples)
            all_samples += samples
            all_lens += lens

assert(sum(all_lens) == len(all_samples))
print("Saving " + str(len(all_samples)) + " samples...")
num_notes = 96
samples_per_measure = 96


def midi_to_samples(fname):
    """Convert a MIDI file into a list of per-measure note-onset grids.

    Every measure is a ``(samples_per_measure, num_notes)`` ``uint8`` array
    with a 1 at ``[time_step, pitch]`` for each note that starts there.
    Pitches are shifted so that the ``num_notes`` supported notes are
    centred inside the 128 MIDI note numbers.  Only onsets are written;
    note lengths are tracked but not rendered.

    Args:
        fname: Path of the MIDI file.

    Returns:
        A list of measures, or ``[]`` (after printing a warning) when the
        file contains more than one distinct time signature.

    Raises:
        AssertionError: If a note falls outside the supported pitch range.
    """
    has_time_sig = False
    flag_warning = False
    mid = MidiFile(fname)
    ticks_per_beat = mid.ticks_per_beat
    ticks_per_measure = 4 * ticks_per_beat

    for track in mid.tracks:
        for msg in track:
            if msg.type == 'time_signature':
                new_tpm = msg.numerator * ticks_per_beat * 4 // msg.denominator
                if has_time_sig and new_tpm != ticks_per_measure:
                    flag_warning = True
                ticks_per_measure = new_tpm
                has_time_sig = True
    if flag_warning:
        print(" ^^^^^^ WARNING ^^^^^^")
        print(" " + fname)
        print(" Detected multiple distinct time signatures.")
        print(" ^^^^^^ WARNING ^^^^^^")
        return []

    note_offset = (128 - num_notes) // 2
    # pitch -> list of [start] or [start, end] intervals, in sample steps
    all_notes = {}
    for track in mid.tracks:
        abs_time = 0
        for msg in track:
            abs_time += msg.time
            if msg.type not in ('note_on', 'note_off'):
                continue
            if msg.type == 'note_on' and msg.velocity == 0:
                continue
            # Resolve the pitch from *this* message; the note_off branch
            # previously reused the pitch of the last note_on.
            note = msg.note - note_offset
            step = abs_time * samples_per_measure // ticks_per_measure
            if msg.type == 'note_on':
                if not 0 <= note < num_notes:
                    # Raised explicitly so the check survives ``python -O``.
                    raise AssertionError(
                        'note %d is outside the supported range' % msg.note)
                intervals = all_notes.setdefault(note, [])
                if intervals and len(intervals[-1]) == 1:
                    # Retriggered while still open: close the previous one.
                    intervals[-1].append(intervals[-1][0] + 1)
                intervals.append([step])
            else:
                intervals = all_notes.get(note)
                if not intervals or len(intervals[-1]) != 1:
                    continue
                intervals[-1].append(step)

    # Close every note that never received a note_off.
    for intervals in all_notes.values():
        for start_end in intervals:
            if len(start_end) == 1:
                start_end.append(start_end[0] + 1)

    samples = []
    for note, intervals in all_notes.items():
        for start, _end in intervals:
            sample_ix = start // samples_per_measure
            while len(samples) <= sample_ix:
                samples.append(np.zeros((samples_per_measure, num_notes), dtype=np.uint8))
            start_ix = start - sample_ix * samples_per_measure
            samples[sample_ix][start_ix, note] = 1
    return samples


def samples_to_midi(samples, fname, ticks_per_sample, thresh=0.5):
    """Write per-measure note grids to a single-track MIDI file.

    A ``note_on`` is emitted where a cell reaches ``thresh`` and the cell
    one step earlier in the same measure did not; a ``note_off`` where the
    cell reaches ``thresh`` and the next step in the same measure does not.
    Both may be written at the same tick for a note that is one step long.

    Args:
        samples: Iterable of 2-D arrays indexed ``[time_step, pitch]``.
        fname: Output path.
        ticks_per_sample: Kept for backward compatibility but ignored: the
            step length is always derived from the file's ticks-per-beat
            so that one measure spans four beats.
        thresh: Minimum cell value that counts as a played note.
    """
    mid = MidiFile()
    track = MidiTrack()
    mid.tracks.append(track)
    ticks_per_beat = mid.ticks_per_beat
    ticks_per_measure = 4 * ticks_per_beat
    ticks_per_sample = ticks_per_measure // samples_per_measure
    note_offset = (128 - num_notes) // 2
    abs_time = 0
    last_time = 0
    for sample in samples:
        last_row = sample.shape[0] - 1
        for y in range(sample.shape[0]):
            abs_time += ticks_per_sample
            # Only cells at or above the threshold can start or end a
            # note, so skip the (typically far more numerous) silent ones.
            for x in np.flatnonzero(sample[y] >= thresh):
                note = int(x) + note_offset
                if y == 0 or sample[y-1, x] < thresh:
                    delta_time = abs_time - last_time
                    track.append(Message('note_on', note=note, velocity=127, time=delta_time))
                    last_time = abs_time
                if y == last_row or sample[y+1, x] < thresh:
                    delta_time = abs_time - last_time
                    track.append(Message('note_off', note=note, velocity=127, time=delta_time))
                    last_time = abs_time
    mid.save(fname)
49 | 50 | ################################### 51 | # Load Keras 52 | ################################### 53 | print "Loading Keras..." 54 | import os, math 55 | os.environ['THEANORC'] = "./gpu.theanorc" 56 | os.environ['KERAS_BACKEND'] = "theano" 57 | import theano 58 | print "Theano Version: " + theano.__version__ 59 | 60 | import keras 61 | print "Keras Version: " + keras.__version__ 62 | from keras.layers import Input, Dense, Activation, Dropout, Flatten, Reshape, Permute, RepeatVector, ActivityRegularization, TimeDistributed, Lambda, SpatialDropout1D 63 | from keras.layers.convolutional import Conv1D, Conv2D, Conv2DTranspose, UpSampling2D, ZeroPadding2D 64 | from keras.layers.embeddings import Embedding 65 | from keras.layers.local import LocallyConnected2D 66 | from keras.layers.pooling import MaxPooling2D, AveragePooling2D 67 | from keras.layers.noise import GaussianNoise 68 | from keras.layers.normalization import BatchNormalization 69 | from keras.layers.recurrent import LSTM, SimpleRNN 70 | from keras.initializers import RandomNormal 71 | from keras.losses import binary_crossentropy 72 | from keras.models import Model, Sequential, load_model 73 | from keras.optimizers import Adam, RMSprop, SGD 74 | from keras.preprocessing.image import ImageDataGenerator 75 | from keras.regularizers import l2 76 | from keras.utils import plot_model 77 | from keras import backend as K 78 | from keras import regularizers 79 | from keras.engine.topology import Layer 80 | K.set_image_data_format('channels_first') 81 | 82 | #Fix the random seed so that training comparisons are easier to make 83 | np.random.seed(0) 84 | random.seed(0) 85 | 86 | if WRITE_HISTORY: 87 | #Create folder to save models into 88 | if not os.path.exists('History'): 89 | os.makedirs('History') 90 | 91 | ################################### 92 | # Load Dataset 93 | ################################### 94 | print "Loading Data..." 
def to_song(encoded_output):
    """Decode a latent vector batch into note grids.

    The latent values are rounded and fed through ``func``, the backend
    function built from the 'encoder' layer input to the model output,
    with the learning phase set to 0 (inference).

    NOTE(review): the original called an undefined name ``decoder``;
    ``func`` is the only decoder function defined in this file -- confirm
    that it is the intended one.

    Args:
        encoded_output: Array of latent vectors, one per row.

    Returns:
        The decoded output with singleton dimensions removed.
    """
    return np.squeeze(func([np.round(encoded_output), 0])[0])


def reg_mean_std(x):
    """Return ``log(sum(x^2))`` squared, usable as an activity penalty."""
    s = K.log(K.sum(x * x))
    return s*s


def vae_sampling(args):
    """Sample a latent vector with the reparameterisation trick.

    Args:
        args: ``(z_mean, z_log_sigma_sq)`` tensors of equal shape.

    Returns:
        ``z_mean + exp(z_log_sigma_sq / 2) * epsilon`` where ``epsilon`` is
        Gaussian noise with standard deviation ``VAE_B1``.
    """
    z_mean, z_log_sigma_sq = args
    epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=VAE_B1)
    return z_mean + K.exp(z_log_sigma_sq * 0.5) * epsilon


def vae_loss(x, x_decoded_mean):
    """Binary cross-entropy plus a KL-style penalty weighted by ``VAE_B2``.

    Reads the module-level ``z_mean`` and ``z_log_sigma_sq`` tensors, which
    exist only when the model is built with ``USE_VAE`` enabled.
    """
    xent_loss = binary_crossentropy(x, x_decoded_mean)
    kl_loss = VAE_B2 * K.mean(1 + z_log_sigma_sq - K.square(z_mean) - K.exp(z_log_sigma_sq), axis=None)
    return xent_loss - kl_loss
Model..." 148 | model = load_model('model.h5', custom_objects=custom_objects) 149 | else: 150 | print "Building Model..." 151 | 152 | if USE_EMBEDDING: 153 | x_in = Input(shape=x_shape[1:]) 154 | print (None,) + x_shape[1:] 155 | x = Embedding(x_train.shape[0], PARAM_SIZE, input_length=1)(x_in) 156 | x = Flatten(name='pre_encoder')(x) 157 | else: 158 | x_in = Input(shape=y_shape[1:]) 159 | print (None,) + y_shape[1:] 160 | x = Reshape((y_shape[1], -1))(x_in) 161 | print K.int_shape(x) 162 | 163 | x = TimeDistributed(Dense(2000, activation='relu'))(x) 164 | print K.int_shape(x) 165 | 166 | x = TimeDistributed(Dense(200, activation='relu'))(x) 167 | print K.int_shape(x) 168 | 169 | x = Flatten()(x) 170 | print K.int_shape(x) 171 | 172 | x = Dense(1600, activation='relu')(x) 173 | print K.int_shape(x) 174 | 175 | if USE_VAE: 176 | z_mean = Dense(PARAM_SIZE)(x) 177 | z_log_sigma_sq = Dense(PARAM_SIZE)(x) 178 | x = Lambda(vae_sampling, output_shape=(PARAM_SIZE,), name='pre_encoder')([z_mean, z_log_sigma_sq]) 179 | else: 180 | x = Dense(PARAM_SIZE)(x) 181 | x = BatchNormalization(momentum=BN_M, name='pre_encoder')(x) 182 | print K.int_shape(x) 183 | 184 | x = Dense(1600, name='encoder')(x) 185 | x = BatchNormalization(momentum=BN_M)(x) 186 | x = Activation('relu')(x) 187 | if DO_RATE > 0: 188 | x = Dropout(DO_RATE)(x) 189 | print K.int_shape(x) 190 | 191 | x = Dense(MAX_LENGTH * 200)(x) 192 | print K.int_shape(x) 193 | x = Reshape((MAX_LENGTH, 200))(x) 194 | x = TimeDistributed(BatchNormalization(momentum=BN_M))(x) 195 | x = Activation('relu')(x) 196 | if DO_RATE > 0: 197 | x = Dropout(DO_RATE)(x) 198 | print K.int_shape(x) 199 | 200 | x = TimeDistributed(Dense(2000))(x) 201 | x = TimeDistributed(BatchNormalization(momentum=BN_M))(x) 202 | x = Activation('relu')(x) 203 | if DO_RATE > 0: 204 | x = Dropout(DO_RATE)(x) 205 | print K.int_shape(x) 206 | 207 | x = TimeDistributed(Dense(y_shape[2] * y_shape[3], activation='sigmoid'))(x) 208 | print K.int_shape(x) 209 | x = 
def make_rand_songs(write_dir, rand_vecs):
    """Decode each latent vector and save it as ``rand<i>.mid``.

    Args:
        write_dir: Prefix prepended to every file name (include the
            trailing '/' when it is a directory).
        rand_vecs: 2-D array with one latent vector per row.
    """
    for i in range(rand_vecs.shape[0]):
        # Keep the batch dimension: func expects a batch of size one.
        y_song = func([rand_vecs[i:i+1], 0])[0]
        midi.samples_to_midi(y_song[0], write_dir + 'rand' + str(i) + '.mid', 16, 0.25)


def _save_bar_plot(values, title, fname):
    """Save a bar chart of ``values`` (one bar per entry) to ``fname``."""
    plt.clf()
    plt.title(title)
    plt.bar(np.arange(len(values)), values, align='center')
    plt.draw()
    plt.savefig(fname)


def make_rand_songs_normalized(write_dir, rand_vecs):
    """Generate random songs that follow the latent distribution of the data.

    Encodes the whole training set, measures the mean, standard deviation
    and principal axes of the latent codes, saves those statistics as
    ``means.npy``, ``stds.npy``, ``evals.npy`` and ``evecs.npy``, maps the
    unit-normal ``rand_vecs`` into that distribution and writes the decoded
    songs.  Bar charts of the statistics are saved next to them.

    Args:
        write_dir: Prefix for every output file.  When it contains '/',
            the second-to-last path component minus its first character is
            used as the epoch number in the plot titles.
        rand_vecs: 2-D array of unit-normal latent vectors, one per row.
    """
    x_enc = enc.predict(x_orig if USE_EMBEDDING else y_orig)
    # Flatten everything but the sample axis.  A bare squeeze would also
    # drop the sample axis when there is a single song.
    x_enc = x_enc.reshape((x_enc.shape[0], -1))

    x_mean = np.mean(x_enc, axis=0)
    x_stds = np.std(x_enc, axis=0)
    x_cov = np.cov((x_enc - x_mean).T)
    u, s, v = np.linalg.svd(x_cov)
    e = np.sqrt(s)

    print("Means:  " + str(x_mean[:6]))
    print("Evals:  " + str(e[:6]))

    np.save(write_dir + 'means.npy', x_mean)
    np.save(write_dir + 'stds.npy', x_stds)
    np.save(write_dir + 'evals.npy', e)
    np.save(write_dir + 'evecs.npy', v)

    x_vecs = x_mean + np.dot(rand_vecs * e, v)
    make_rand_songs(write_dir, x_vecs)

    title = ''
    if '/' in write_dir:
        title = 'Epoch: ' + write_dir.split('/')[-2][1:]

    # Plot a descending copy rather than re-sorting ``e`` in place.
    _save_bar_plot(np.sort(e)[::-1], title, write_dir + 'evals.png')
    _save_bar_plot(x_mean, title, write_dir + 'means.png')
    _save_bar_plot(x_stds, title, write_dir + 'stds.png')
def transpose_range(samples):
    """Return the half-open pitch range ``[min_note, max_note)`` in use.

    Args:
        samples: Non-empty sequence of 2-D arrays indexed
            ``[time_step, pitch]``; a pitch counts as used when any cell in
            its column is greater than zero.

    Returns:
        ``(min_note, max_note)`` as Python ints.  When no note is used at
        all the full range ``(0, number_of_pitches)`` is returned.
    """
    used = np.zeros(samples[0].shape[1], dtype=bool)
    for sample in samples:
        used |= np.any(np.asarray(sample) > 0, axis=0)
    # argmax on a boolean array yields the first True (0 if there is none).
    min_note = int(np.argmax(used))
    max_note = int(used.shape[0] - np.argmax(used[::-1]))
    return min_note, max_note


def generate_add_centered_transpose(samples):
    """Return the song followed by a copy transposed to the pitch centre.

    The input list is left untouched (the original implementation appended
    the transposed measures to the caller's list).

    Args:
        samples: Non-empty list of 2-D measures indexed
            ``[time_step, pitch]``.

    Returns:
        ``(out_samples, out_lens)``: the original measures followed by the
        transposed ones, and ``[len(samples), len(samples)]``.
    """
    num_notes = samples[0].shape[1]
    min_note, max_note = transpose_range(samples)
    # Floor division keeps the shift usable as a slice bound.
    s = num_notes // 2 - (max_note + min_note) // 2
    out_samples = list(samples)
    out_lens = [len(samples), len(samples)]
    for sample in samples:
        out_sample = np.zeros_like(sample)
        out_sample[:, min_note+s:max_note+s] = sample[:, min_note:max_note]
        out_samples.append(out_sample)
    return out_samples, out_lens


def generate_all_transpose(samples, radius=6):
    """Return copies of the song transposed by every shift that fits.

    Shifts run from ``-min(radius, min_note)`` up to, but not including,
    ``min(radius, num_notes - max_note)``.

    Args:
        samples: Non-empty list of 2-D measures indexed
            ``[time_step, pitch]``.
        radius: Largest shift, in pitches, in either direction.

    Returns:
        ``(out_samples, out_lens)``: all transposed measures grouped by
        shift, and one ``len(samples)`` entry per shift.
    """
    num_notes = samples[0].shape[1]
    min_note, max_note = transpose_range(samples)
    min_shift = -min(radius, min_note)
    max_shift = min(radius, num_notes - max_note)
    out_samples = []
    out_lens = []
    for s in range(min_shift, max_shift):
        for sample in samples:
            out_sample = np.zeros_like(sample)
            out_sample[:, min_note+s:max_note+s] = sample[:, min_note:max_note]
            out_samples.append(out_sample)
        out_lens.append(len(samples))
    return out_samples, out_lens


def sample_to_pic(fname, sample, thresh=None):
    """Save one measure as an inverted greyscale image (notes are dark).

    Args:
        fname: Output image path.
        sample: 2-D array of note intensities.
        thresh: When given, cells above it become black and all others
            white; otherwise ``1.0 - sample`` is used as the brightness.
    """
    if thresh is not None:
        inverted = np.where(sample > thresh, 0, 1)
    else:
        inverted = 1.0 - sample
    cv2.imwrite(fname, inverted * 255)


def samples_to_pics(dir, samples, thresh=None):
    """Save every measure of ``samples`` as ``<dir>/s<i>.png``.

    The directory is created when it does not exist yet.
    """
    if not os.path.exists(dir):
        os.makedirs(dir)
    for i in range(samples.shape[0]):
        sample_to_pic(dir + '/s' + str(i) + '.png', samples[i], thresh)


def pad_songs(y, y_lens, max_len):
    """Cut or loop every song to exactly ``max_len`` measures.

    Args:
        y: Array holding the measures of all songs back to back.
        y_lens: 1-D array with the number of measures of each song.
        max_len: Number of measures per song in the result.

    Returns:
        ``float32`` array of shape ``(len(y_lens), max_len) + y.shape[1:]``
        where song ``i`` is repeated from its start until the row is full.

    Raises:
        ValueError: If a song has no measures.
        AssertionError: If the lengths do not add up to ``len(y)``.
    """
    y_shape = (y_lens.shape[0], max_len) + y.shape[1:]
    y_train = np.zeros(y_shape, dtype=np.float32)
    steps = np.arange(max_len)
    cur_ix = 0
    for i in range(y_lens.shape[0]):
        length = int(y_lens[i])
        if length <= 0:
            raise ValueError('song %d has no measures' % i)
        y_train[i] = y[cur_ix + steps % length]
        cur_ix += length
    # Compare the running offset so that an empty y_lens is handled too.
    if cur_ix != y.shape[0]:
        raise AssertionError('song lengths do not add up to len(y)')
    return y_train


def sample_to_pattern(sample, ix, size):
    """Describe the repetition structure of ``size`` consecutive measures.

    Starting at measure ``ix`` (wrapping around at the end), every distinct
    measure is numbered in order of first appearance.

    Args:
        sample: List or array of measures.
        ix: Index of the first measure.
        size: Number of measures to inspect.

    Returns:
        ``(pattern, pat_types)``: the string form of the list of numbers,
        and the dict mapping each measure's raw bytes to its number.
    """
    pat_types = {}
    pat_list = []
    num_samples = len(sample)
    for i in range(size):
        measure = sample[(ix + i) % num_samples].tobytes()
        pat_list.append(pat_types.setdefault(measure, len(pat_types)))
    return str(pat_list), pat_types


def embed_samples(samples):
    """Replace every distinct time step by an integer id.

    Args:
        samples: 3-D array ``(n, m, p)``; each length-``p`` row is one
            time step.

    Returns:
        ``(e_samples, note_dict, lookup)`` where ``e_samples`` is an
        ``(n, m)`` ``int32`` array of ids, ``note_dict`` maps the raw bytes
        of a row to its id, and ``lookup`` is a ``float32`` array whose row
        ``id`` is the original row.  Ids are assigned in order of first
        appearance.
    """
    n, m, p = samples.shape
    note_dict = {}
    e_samples = np.empty((n, m), dtype=np.int32)
    for i in range(n):
        for j in range(m):
            # bytes keys are hashable for every dtype and need no changes
            # to the input array's writeable flag.
            key = samples[i, j].tobytes()
            e_samples[i, j] = note_dict.setdefault(key, len(note_dict))
    lookup = np.empty((len(note_dict), p), dtype=np.float32)
    for key, index in note_dict.items():
        lookup[index] = np.frombuffer(key, dtype=samples.dtype)
    return e_samples, note_dict, lookup


def e_to_samples(e_samples, lookup):
    """Expand integer ids back into rows: the inverse of ``embed_samples``.

    Args:
        e_samples: Integer array of ids.
        lookup: 2-D array whose row ``id`` is the row to substitute.

    Returns:
        ``float32`` array of shape ``e_samples.shape + lookup.shape[-1:]``.
    """
    return np.asarray(lookup)[np.asarray(e_samples)].astype(np.float32)