├── .gitignore ├── README.md ├── dataset.zip ├── generated_music │   └── mp3 │       ├── 1.mp3 │       ├── 10.mp3 │       ├── 2.mp3 │       ├── 3.mp3 │       ├── 4.mp3 │       ├── 5.mp3 │       ├── 6.mp3 │       ├── 7.mp3 │       ├── 8.mp3 │       ├── 9.mp3 │       └── nottingham_sample.mp3 ├── main.py ├── midi_util.py ├── model.py ├── nottingham_util.py ├── rnn.py ├── rnn_sample.py ├── rnn_separate.py ├── rnn_test.py ├── sampling.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | 2012code 3 | models 4 | python-midi 5 | research 6 | tensorflow 7 | 8 | *.midi 9 | *.pyc 10 | img/.DS_Store 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Overview 2 | ============ 3 | A project that trains an LSTM recurrent neural network over a dataset of MIDI files. More information can be found in the [writeup about this project](http://yoavz.com/music_rnn/). This is the code for 'Build an AI Composer' on [YouTube](https://youtu.be/S_f2qV2_U00). 4 | 5 | Dependencies 6 | ============ 7 | 8 | * Numpy (http://www.numpy.org/) 9 | * Tensorflow (https://github.com/tensorflow/tensorflow) 10 | * Python Midi (https://github.com/vishnubob/python-midi.git) 11 | * Mingus (https://github.com/bspaans/python-mingus) 12 | 13 | Use [pip](https://pypi.python.org/pypi/pip) to install any missing dependencies. 14 | 15 | Installation (Tested on Ubuntu 16.04) 16 | ============ 17 | 18 | * Step 1: Tensorflow version 0.8.0 must be used. On [Tensorflow's download page here](https://www.tensorflow.org/versions/r0.10/get_started/os_setup.html), scroll down to "Pip Installation". Follow the first step normally. 19 | 20 | You will see "export TF_BINARY_URL" followed by a URL. Change the part of the URL that reads "tensorflow-0.10.0" to "tensorflow-0.8.0", so that it downloads version 0.8.0 instead of 0.10.0. 21 | 22 | Example of the modified URL, for the Python 2.7 CPU version of Tensorflow: 23 | 24 | ``` 25 | export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl 26 | 27 | sudo pip install --upgrade $TF_BINARY_URL 28 | ``` 29 | Follow the third step on the download page normally to install Tensorflow. 30 | 31 | * Step 2: After installing Tensorflow, install the remaining dependencies: 32 | 33 | `pip install matplotlib` 34 | 35 | `sudo apt-get install python-tk` 36 | 37 | `pip install numpy` 38 | 39 | * Step 3: 40 | 41 | ``` 42 | cd ~ 43 | git clone https://github.com/vishnubob/python-midi 44 | cd python-midi 45 | python setup.py install 46 | ``` 47 | 48 | 49 | 50 | 51 | ``` 52 | cd ~ 53 | git clone https://github.com/bspaans/python-mingus 54 | cd python-mingus 55 | python setup.py install 56 | ``` 57 | 58 | 59 | Basic Usage 60 | =========== 61 | 62 | 1. `mkdir data && mkdir models` 63 | 2. Run `python main.py`. This will download the data, create the chord mapping file in data/nottingham.pickle, and train the model. 64 | 3. Run `python rnn_sample.py --config_file new_config_file.config` to generate a new MIDI song. 65 | 66 | Give it 1-2 hours to train on your local machine, then generate the new song. You don't have to wait for it to finish; just wait until you see the 'saving model' message in the terminal. In a future video, I'll talk about how to easily set up cloud GPU training, likely using www.fomoro.com. 67 | 68 | Credits 69 | =========== 70 | Credit for the vast majority of code here goes to [Yoav Zimmerman](https://github.com/yoavz). I've merely created a wrapper around all of the important functions to get people started.
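
Sanity Check
===========
Not part of the original write-up, but a quick way to confirm that the pinned Tensorflow version is the one Python actually imports before you start training:

```
python -c "import tensorflow as tf; print tf.__version__"
```

This should print `0.8.0`; if it prints `0.10.0` or anything else, revisit Step 1 above.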
71 | -------------------------------------------------------------------------------- /dataset.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/dataset.zip -------------------------------------------------------------------------------- /generated_music/mp3/1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/generated_music/mp3/1.mp3 -------------------------------------------------------------------------------- /generated_music/mp3/10.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/generated_music/mp3/10.mp3 -------------------------------------------------------------------------------- /generated_music/mp3/2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/generated_music/mp3/2.mp3 -------------------------------------------------------------------------------- /generated_music/mp3/3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/generated_music/mp3/3.mp3 -------------------------------------------------------------------------------- /generated_music/mp3/4.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/generated_music/mp3/4.mp3 -------------------------------------------------------------------------------- /generated_music/mp3/5.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/generated_music/mp3/5.mp3 -------------------------------------------------------------------------------- /generated_music/mp3/6.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/generated_music/mp3/6.mp3 -------------------------------------------------------------------------------- /generated_music/mp3/7.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/generated_music/mp3/7.mp3 -------------------------------------------------------------------------------- /generated_music/mp3/8.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/generated_music/mp3/8.mp3 -------------------------------------------------------------------------------- /generated_music/mp3/9.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/generated_music/mp3/9.mp3
-------------------------------------------------------------------------------- /generated_music/mp3/nottingham_sample.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/AI_Composer/434817dcad9bf2e80a2c9ec9d9ef8308e71e855b/generated_music/mp3/nottingham_sample.mp3 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | import zipfile 3 | import nottingham_util 4 | import rnn 5 | 6 | # collect the data 7 | url = "http://www-etud.iro.umontreal.ca/~boulanni/Nottingham.zip" 8 | urllib.urlretrieve(url, "dataset.zip") 9 | 10 | zf = zipfile.ZipFile(r'dataset.zip')  # avoid shadowing the zip() builtin 11 | zf.extractall('data') 12 | 13 | # build the chord mapping and dataset pickle 14 | nottingham_util.create_model() 15 | 16 | # train the model 17 | rnn.train_model() 18 | -------------------------------------------------------------------------------- /midi_util.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | from collections import defaultdict 3 | import numpy as np 4 | import midi 5 | 6 | RANGE = 128 7 | 8 | def round_tick(tick, time_step): 9 | return int(round(tick/float(time_step)) * time_step) 10 | 11 | def ingest_notes(track, verbose=False): 12 | 13 | notes = { n: [] for n in range(RANGE) } 14 | current_tick = 0 15 | 16 | for msg in track: 17 | # ignore all end of track events 18 | if isinstance(msg, midi.EndOfTrackEvent): 19 | continue 20 | 21 | if msg.tick > 0: 22 | current_tick += msg.tick 23 | 24 | # velocity of 0 is equivalent to note off, so treat as such 25 | if isinstance(msg, midi.NoteOnEvent) and msg.get_velocity() != 0: 26 | if len(notes[msg.get_pitch()]) > 0 and \ 27 | len(notes[msg.get_pitch()][-1]) != 2: 28 | if verbose: 29 | print "Warning: double NoteOn encountered, ignoring the second" 30 | print msg 31 | else: 32 | notes[msg.get_pitch()] += [[current_tick]] 33 | elif isinstance(msg, midi.NoteOffEvent) or \ 34 | (isinstance(msg, midi.NoteOnEvent) and msg.get_velocity() == 0): 35 | # sanity check: no notes end without being started 36 | if len(notes[msg.get_pitch()][-1]) != 1: 37 | if verbose: 38 | print "Warning: skipping NoteOff Event with no corresponding NoteOn" 39 | print msg 40 | else: 41 | notes[msg.get_pitch()][-1] += [current_tick] 42 | 43 | return notes, current_tick 44 | 45 | def round_notes(notes, track_ticks, time_step, R=None, O=None): 46 | if not R: 47 | R = RANGE 48 | if not O: 49 | O = 0 50 | 51 | sequence = np.zeros((track_ticks/time_step, R)) 52 | disputed = { t: defaultdict(int) for t in range(track_ticks/time_step) } 53 | for note in notes: 54 | for (start, end) in notes[note]: 55 | start_t = round_tick(start, time_step) / time_step 56 | end_t = round_tick(end, time_step) / time_step 57 | # normal case where note is long enough 58 | if end - start > time_step/2 and start_t != end_t: 59 | sequence[start_t:end_t, note - O] = 1 60 | # cases where note is within bounds of time step 61 | elif start > start_t * time_step: 62 | disputed[start_t][note] += (end - start) 63 | elif end <= end_t * time_step: 64 | disputed[end_t-1][note] += (end - start) 65 | # case where a note is on the border 66 | else: 67 | before_border = start_t * time_step - start 68 | if before_border > 0: 69 | disputed[start_t-1][note] += before_border 70 | after_border = end - start_t * time_step 71 | if after_border > 0 and end < track_ticks: 72 | disputed[start_t][note] += 
after_border 73 | 74 | # solve disputed 75 | for seq_idx in range(sequence.shape[0]): 76 | if np.count_nonzero(sequence[seq_idx, :]) == 0 and len(disputed[seq_idx]) > 0: 77 | # print seq_idx, disputed[seq_idx] 78 | sorted_notes = sorted(disputed[seq_idx].items(), 79 | key=lambda x: x[1]) 80 | max_val = max(x[1] for x in sorted_notes) 81 | top_notes = filter(lambda x: x[1] >= max_val, sorted_notes) 82 | for note, _ in top_notes: 83 | sequence[seq_idx, note - O] = 1 84 | 85 | return sequence 86 | 87 | def parse_midi_to_sequence(input_filename, time_step, verbose=False): 88 | sequence = [] 89 | pattern = midi.read_midifile(input_filename) 90 | 91 | if len(pattern) < 1: 92 | raise Exception("No pattern found in midi file") 93 | 94 | if verbose: 95 | print "Track resolution: {}".format(pattern.resolution) 96 | print "Number of tracks: {}".format(len(pattern)) 97 | print "Time step: {}".format(time_step) 98 | 99 | # Track ingestion stage 100 | notes = { n: [] for n in range(RANGE) } 101 | track_ticks = 0 102 | for track in pattern: 103 | current_tick = 0 104 | for msg in track: 105 | # ignore all end of track events 106 | if isinstance(msg, midi.EndOfTrackEvent): 107 | continue 108 | 109 | if msg.tick > 0: 110 | current_tick += msg.tick 111 | 112 | # velocity of 0 is equivalent to note off, so treat as such 113 | if isinstance(msg, midi.NoteOnEvent) and msg.get_velocity() != 0: 114 | if len(notes[msg.get_pitch()]) > 0 and \ 115 | len(notes[msg.get_pitch()][-1]) != 2: 116 | if verbose: 117 | print "Warning: double NoteOn encountered, ignoring the second" 118 | print msg 119 | else: 120 | notes[msg.get_pitch()] += [[current_tick]] 121 | elif isinstance(msg, midi.NoteOffEvent) or \ 122 | (isinstance(msg, midi.NoteOnEvent) and msg.get_velocity() == 0): 123 | # sanity check: no notes end without being started 124 | if len(notes[msg.get_pitch()][-1]) != 1: 125 | if verbose: 126 | print "Warning: skipping NoteOff Event with no corresponding NoteOn" 127 | print msg 128 | else: 129 | notes[msg.get_pitch()][-1] += [current_tick] 130 | 131 | track_ticks = max(current_tick, track_ticks) 132 | 133 | track_ticks = round_tick(track_ticks, time_step) 134 | if verbose: 135 | print "Track ticks (rounded): {} ({} time steps)".format(track_ticks, track_ticks/time_step) 136 | 137 | sequence = round_notes(notes, track_ticks, time_step) 138 | 139 | return sequence 140 | 141 | class MidiWriter(object): 142 | 143 | def __init__(self, verbose=False): 144 | self.verbose = verbose 145 | self.note_range = RANGE 146 | 147 | def note_off(self, val, tick): 148 | self.track.append(midi.NoteOffEvent(tick=tick, pitch=val)) 149 | return 0 150 | 151 | def note_on(self, val, tick): 152 | self.track.append(midi.NoteOnEvent(tick=tick, pitch=val, velocity=70)) 153 | return 0 154 | 155 | def dump_sequence_to_midi(self, sequence, output_filename, time_step, 156 | resolution, metronome=24): 157 | if self.verbose: 158 | print "Dumping sequence to MIDI file: {}".format(output_filename) 159 | print "Resolution: {}".format(resolution) 160 | print "Time Step: {}".format(time_step) 161 | 162 | pattern = midi.Pattern(resolution=resolution) 163 | self.track = midi.Track() 164 | 165 | # metadata track 166 | meta_track = midi.Track() 167 | time_sig = midi.TimeSignatureEvent() 168 | time_sig.set_numerator(4) 169 | time_sig.set_denominator(4) 170 | time_sig.set_metronome(metronome) 171 | time_sig.set_thirtyseconds(8) 172 | meta_track.append(time_sig) 173 | pattern.append(meta_track) 174 | 175 | # reshape to (SEQ_LENGTH X NUM_DIMS) 176 | sequence = 
np.reshape(sequence, [-1, self.note_range]) 177 | 178 | time_steps = sequence.shape[0] 179 | if self.verbose: 180 | print "Total number of time steps: {}".format(time_steps) 181 | 182 | tick = time_step 183 | self.notes_on = { n: False for n in range(self.note_range) } 184 | # for seq_idx in range(188, 220): 185 | for seq_idx in range(time_steps): 186 | notes = np.nonzero(sequence[seq_idx, :])[0].tolist() 187 | 188 | # this tick will only be assigned to first NoteOn/NoteOff in 189 | # this time_step 190 | 191 | # NoteOffEvents come first so they'll have the tick value 192 | # go through all notes that are currently on and see if any 193 | # turned off 194 | for n in self.notes_on: 195 | if self.notes_on[n] and n not in notes: 196 | tick = self.note_off(n, tick) 197 | self.notes_on[n] = False 198 | 199 | # Turn on any notes that weren't previously on 200 | for note in notes: 201 | if not self.notes_on[note]: 202 | tick = self.note_on(note, tick) 203 | self.notes_on[note] = True 204 | 205 | tick += time_step 206 | 207 | # flush out notes 208 | for n in self.notes_on: 209 | if self.notes_on[n]: 210 | self.note_off(n, tick) 211 | tick = 0 212 | self.notes_on[n] = False 213 | 214 | pattern.append(self.track) 215 | midi.write_midifile(output_filename, pattern) 216 | 217 | if __name__ == '__main__': 218 | pass 219 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.models.rnn import rnn_cell 6 | from tensorflow.models.rnn import rnn, seq2seq 7 | 8 | import nottingham_util 9 | 10 | class Model(object): 11 | """ 12 | Cross-Entropy Naive Formulation 13 | A single time step may have multiple notes active, so a sigmoid cross entropy loss 14 | is used to match targets. 
15 | 16 | seq_input: a [ T x B x D ] matrix, where T is the time steps in the batch, B is the 17 | batch size, and D is the amount of dimensions 18 | """ 19 | 20 | def __init__(self, config, training=False): 21 | self.config = config 22 | self.time_batch_len = time_batch_len = config.time_batch_len 23 | self.input_dim = input_dim = config.input_dim 24 | hidden_size = config.hidden_size 25 | num_layers = config.num_layers 26 | dropout_prob = config.dropout_prob 27 | input_dropout_prob = config.input_dropout_prob 28 | cell_type = config.cell_type 29 | 30 | self.seq_input = \ 31 | tf.placeholder(tf.float32, shape=[self.time_batch_len, None, input_dim]) 32 | 33 | if (dropout_prob <= 0.0 or dropout_prob > 1.0): 34 | raise Exception("Invalid dropout probability: {}".format(dropout_prob)) 35 | 36 | if (input_dropout_prob <= 0.0 or input_dropout_prob > 1.0): 37 | raise Exception("Invalid input dropout probability: {}".format(input_dropout_prob)) 38 | 39 | # setup variables 40 | with tf.variable_scope("rnnlstm"): 41 | output_W = tf.get_variable("output_w", [hidden_size, input_dim]) 42 | output_b = tf.get_variable("output_b", [input_dim]) 43 | self.lr = tf.constant(config.learning_rate, name="learning_rate") 44 | self.lr_decay = tf.constant(config.learning_rate_decay, name="learning_rate_decay") 45 | 46 | def create_cell(input_size): 47 | if cell_type == "vanilla": 48 | cell_class = rnn_cell.BasicRNNCell 49 | elif cell_type == "gru": 50 | cell_class = rnn_cell.BasicGRUCell 51 | elif cell_type == "lstm": 52 | cell_class = rnn_cell.BasicLSTMCell 53 | else: 54 | raise Exception("Invalid cell type: {}".format(cell_type)) 55 | 56 | cell = cell_class(hidden_size, input_size = input_size) 57 | if training: 58 | return rnn_cell.DropoutWrapper(cell, output_keep_prob = dropout_prob) 59 | else: 60 | return cell 61 | 62 | if training: 63 | self.seq_input_dropout = tf.nn.dropout(self.seq_input, keep_prob = input_dropout_prob) 64 | else: 65 | self.seq_input_dropout = self.seq_input 66 | 67 | self.cell = rnn_cell.MultiRNNCell( 68 | [create_cell(input_dim)] + [create_cell(hidden_size) for i in range(1, num_layers)]) 69 | 70 | batch_size = tf.shape(self.seq_input_dropout)[1] # batch is axis 1 of [T x B x D] 71 | self.initial_state = self.cell.zero_state(batch_size, tf.float32) 72 | inputs_list = tf.unpack(self.seq_input_dropout) 73 | 74 | # rnn outputs a list of [batch_size x H] outputs 75 | outputs_list, self.final_state = rnn.rnn(self.cell, inputs_list, 76 | initial_state=self.initial_state) 77 | 78 | outputs = tf.pack(outputs_list) 79 | outputs_concat = tf.reshape(outputs, [-1, hidden_size]) 80 | logits_concat = tf.matmul(outputs_concat, output_W) + output_b 81 | logits = tf.reshape(logits_concat, [self.time_batch_len, -1, input_dim]) 82 | 83 | # probabilities of each note 84 | self.probs = self.calculate_probs(logits) 85 | self.loss = self.init_loss(logits, logits_concat) 86 | self.train_step = tf.train.RMSPropOptimizer(self.lr, decay = self.lr_decay) \ 87 | .minimize(self.loss) 88 | 89 | def init_loss(self, outputs, _): 90 | self.seq_targets = \ 91 | tf.placeholder(tf.float32, [self.time_batch_len, None, self.input_dim]) 92 | 93 | batch_size = tf.shape(self.seq_input_dropout)[1] # scalar batch dimension 94 | cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(outputs, self.seq_targets) 95 | return tf.reduce_sum(cross_ent) / self.time_batch_len / tf.to_float(batch_size) 96 | 97 | def calculate_probs(self, logits): 98 | return tf.sigmoid(logits) 99 | 100 | def get_cell_zero_state(self, session, batch_size): 101 | return self.cell.zero_state(batch_size, tf.float32).eval(session=session)
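# ---------------------------------------------------------------------------
# Editor's note (not part of the original file): a minimal sketch of how a
# trained Model is driven one step at a time, mirroring what rnn_sample.py
# does further below. `config` and `session` are assumed to already exist.
#
#   model = Model(config, training=False)
#   state = model.get_cell_zero_state(session, 1)
#   x = np.zeros([config.time_batch_len, 1, config.input_dim], dtype=np.float32)
#   probs, state = session.run([model.probs, model.final_state],
#                              feed_dict={model.seq_input: x,
#                                         model.initial_state: state})
#   # probs is [time_batch_len x 1 x input_dim]; each entry is an independent
#   # sigmoid probability that the corresponding note is active.
# ---------------------------------------------------------------------------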
102 | 103 | class NottinghamModel(Model): 104 | """ 105 | Dual softmax formulation 106 | 107 | A single time step should be a concatenation of two one-hot-encoding binary vectors. 108 | Loss function is a sum of two softmax loss functions over [:r] and [r:] respectively, 109 | where r is the number of melody classes 110 | """ 111 | 112 | def init_loss(self, outputs, outputs_concat): 113 | self.seq_targets = \ 114 | tf.placeholder(tf.int64, [self.time_batch_len, None, 2]) 115 | batch_size = tf.shape(self.seq_targets)[1] 116 | 117 | with tf.variable_scope("rnnlstm"): 118 | self.melody_coeff = tf.constant(self.config.melody_coeff) 119 | 120 | r = nottingham_util.NOTTINGHAM_MELODY_RANGE 121 | targets_concat = tf.reshape(self.seq_targets, [-1, 2]) 122 | 123 | melody_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( \ 124 | outputs_concat[:, :r], \ 125 | targets_concat[:, 0]) 126 | harmony_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( \ 127 | outputs_concat[:, r:], \ 128 | targets_concat[:, 1]) 129 | losses = tf.add(self.melody_coeff * melody_loss, (1 - self.melody_coeff) * harmony_loss) 130 | return tf.reduce_sum(losses) / self.time_batch_len / tf.to_float(batch_size) 131 | 132 | def calculate_probs(self, logits): 133 | steps = [] 134 | for t in range(self.time_batch_len): 135 | melody_softmax = tf.nn.softmax(logits[t, :, :nottingham_util.NOTTINGHAM_MELODY_RANGE]) 136 | harmony_softmax = tf.nn.softmax(logits[t, :, nottingham_util.NOTTINGHAM_MELODY_RANGE:]) 137 | steps.append(tf.concat(1, [melody_softmax, harmony_softmax])) 138 | return tf.pack(steps) 139 | 140 | def assign_melody_coeff(self, session, melody_coeff): 141 | if melody_coeff < 0.0 or melody_coeff > 1.0: 142 | raise Exception("Invalid melody coefficient") 143 | 144 | session.run(tf.assign(self.melody_coeff, melody_coeff)) 145 | 146 | class NottinghamSeparate(Model): 147 | """ 148 | Single softmax formulation 149 | 150 | Regular single classification formulation, used to train baseline models 151 | where the melody and harmony are trained separately 152 | """ 153 | 154 | def init_loss(self, outputs, outputs_concat): 155 | self.seq_targets = \ 156 | tf.placeholder(tf.int64, [self.time_batch_len, None]) 157 | batch_size = tf.shape(self.seq_targets)[1] 158 | 159 | with tf.variable_scope("rnnlstm"): 160 | self.melody_coeff = tf.constant(self.config.melody_coeff) 161 | 162 | targets_concat = tf.reshape(self.seq_targets, [-1]) 163 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits( \ 164 | outputs_concat, targets_concat) 165 | 166 | return tf.reduce_sum(losses) / self.time_batch_len / tf.to_float(batch_size) 167 | 168 | def calculate_probs(self, logits): 169 | steps = [] 170 | for t in range(self.time_batch_len): 171 | softmax = tf.nn.softmax(logits[t, :, :]) 172 | steps.append(softmax) 173 | return tf.pack(steps) 174 | --------------------------------------------------------------------------------
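Editor's note on the encoding the two Nottingham models above assume: each time step is a melody one-hot vector concatenated with a chord one-hot vector, which is exactly what `combine` in nottingham_util.py builds below. A self-contained numpy sketch (the chord count of 10 is an illustrative value, not taken from the dataset):

import numpy as np

r = 35              # NOTTINGHAM_MELODY_RANGE: pitches 55..88, plus one slot for silence
num_chords = 10     # hypothetical chord vocabulary size
step = np.zeros(r + num_chords)
step[5] = 1         # melody: class 5 is active at this time step
step[r + 3] = 1     # harmony: chord index 3 is active
# NottinghamModel softmaxes step[:r] and step[r:] separately and mixes the
# two cross-entropy losses with melody_coeff.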
/nottingham_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import midi 4 | import cPickle 5 | from pprint import pprint 6 | 7 | import midi_util 8 | import mingus 9 | import mingus.core.chords 10 | import sampling 11 | 12 | PICKLE_LOC = 'data/nottingham.pickle' 13 | NOTTINGHAM_MELODY_MAX = 88 14 | NOTTINGHAM_MELODY_MIN = 55 15 | # add one to the range for silence in melody 16 | NOTTINGHAM_MELODY_RANGE = NOTTINGHAM_MELODY_MAX - NOTTINGHAM_MELODY_MIN + 1 + 1 17 | CHORD_BASE = 48 18 | CHORD_BLACKLIST = ['major third', 'minor third', 'perfect fifth'] 19 | NO_CHORD = 'NONE' 20 | SHARPS_TO_FLATS = { 21 | "A#": "Bb", 22 | "B#": "C", 23 | "C#": "Db", 24 | "D#": "Eb", 25 | "E#": "F", 26 | "F#": "Gb", 27 | "G#": "Ab", 28 | } 29 | 30 | def resolve_chord(chord): 31 | """ 32 | Resolves rare chords to their closest common chord, to limit the total 33 | amount of chord classes. 34 | """ 35 | if chord in CHORD_BLACKLIST: 36 | return None 37 | # take the first of dual chords 38 | if "|" in chord: 39 | chord = chord.split("|")[0] 40 | # remove 7th, 9th, 11th and 6th extensions 41 | if chord.endswith("11"): 42 | chord = chord[:-2] 43 | if chord.endswith("7") or chord.endswith("9") or chord.endswith("6"): 44 | chord = chord[:-1] 45 | # replace 'dim' with minor 46 | if chord.endswith("dim"): 47 | chord = chord[:-3] + "m" 48 | return chord 49 | 50 | def prepare_nottingham_pickle(time_step, chord_cutoff=64, filename=PICKLE_LOC, verbose=False): 51 | """ 52 | time_step: the time step to discretize all notes into 53 | chord_cutoff: if chords are seen fewer times than this cutoff, they are ignored and marked 54 | as rests in the resulting dataset 55 | filename: the location where the pickle will be saved to 56 | """ 57 | 58 | data = {} 59 | store = {} 60 | chords = {} 61 | max_seq = 0 62 | seq_lens = [] 63 | 64 | for d in ["train", "test", "valid"]: 65 | print "Parsing {}...".format(d) 66 | parsed = parse_nottingham_directory("data/Nottingham/{}".format(d), time_step, verbose=False) 67 | metadata = [s[0] for s in parsed] 68 | seqs = [s[1] for s in parsed] 69 | data[d] = seqs 70 | data[d + '_metadata'] = metadata 71 | lens = [len(s[1]) for s in seqs] 72 | seq_lens += lens 73 | max_seq = max(max_seq, max(lens)) 74 | 75 | for _, harmony in seqs: 76 | for h in harmony: 77 | if h not in chords: 78 | chords[h] = 1 79 | else: 80 | chords[h] += 1 81 | 82 | avg_seq = float(sum(seq_lens)) / len(seq_lens) 83 | 84 | chords = { c: i for c, i in chords.iteritems() if chords[c] >= chord_cutoff } 85 | chord_mapping = { c: i for i, c in enumerate(chords.keys()) } 86 | num_chords = len(chord_mapping) 87 | store['chord_to_idx'] = chord_mapping 88 | if verbose: 89 | pprint(chords) 90 | print "Number of chords: {}".format(num_chords) 91 | print "Max Sequence length: {}".format(max_seq) 92 | print "Avg Sequence length: {}".format(avg_seq) 93 | print "Num Sequences: {}".format(len(seq_lens)) 94 | 95 | def combine(melody, harmony): 96 | full = np.zeros((melody.shape[0], NOTTINGHAM_MELODY_RANGE + num_chords)) 97 | 98 | assert melody.shape[0] == len(harmony) 99 | 100 | # for all melody sequences that don't have any notes, add the empty melody marker (last one) 101 | for i in range(melody.shape[0]): 102 | if np.count_nonzero(melody[i, :]) == 0: 103 | melody[i, NOTTINGHAM_MELODY_RANGE-1] = 1 104 | 105 | # all melody encodings should now have exactly one 1 106 | for i in range(melody.shape[0]): 107 | assert np.count_nonzero(melody[i, :]) == 1 108 | 109 | # add all the melodies 110 | full[:, :melody.shape[1]] += melody 111 | 112 | harmony_idxs = [ chord_mapping[h] if h in chord_mapping else chord_mapping[NO_CHORD] \ 113 | for h in harmony ] 114 | harmony_idxs = [ NOTTINGHAM_MELODY_RANGE + h for h in harmony_idxs ] 115 | full[np.arange(len(harmony)), harmony_idxs] = 1 116 | 117 | # all full encodings should have exactly two 1's 118 | for i in range(full.shape[0]): 119 | assert np.count_nonzero(full[i, :]) == 2 120 | 121 | return full 122 | 123 | for d in ["train", "test", "valid"]: 124 | print "Combining {}".format(d) 125 | store[d] = [ combine(m, h) for m, h in 
data[d] ] 126 | store[d + '_metadata'] = data[d + '_metadata'] 127 | 128 | with open(filename, 'w') as f: 129 | cPickle.dump(store, f, protocol=-1) 130 | 131 | return True 132 | 133 | def parse_nottingham_directory(input_dir, time_step, verbose=False): 134 | """ 135 | input_dir: a directory containing MIDI files 136 | 137 | returns a list of [T x D] matrices, where each matrix represents 138 | a sequence with T time steps over D dimensions 139 | """ 140 | 141 | files = [ os.path.join(input_dir, f) for f in os.listdir(input_dir) 142 | if os.path.isfile(os.path.join(input_dir, f)) ] 143 | sequences = [ \ 144 | parse_nottingham_to_sequence(f, time_step=time_step, verbose=verbose) \ 145 | for f in files ] 146 | 147 | if verbose: 148 | print "Total sequences: {}".format(len(sequences)) 149 | 150 | # filter out the files that couldn't be parsed (wrong track count or double notes) 151 | sequences = filter(lambda x: x[1] != None, sequences) 152 | 153 | if verbose: 154 | print "Total sequences left: {}".format(len(sequences)) 155 | 156 | return sequences 157 | 158 | def parse_nottingham_to_sequence(input_filename, time_step, verbose=False): 159 | """ 160 | input_filename: a MIDI filename 161 | 162 | returns a [T x D] matrix representing a sequence with T time steps over 163 | D dimensions 164 | """ 165 | sequence = [] 166 | pattern = midi.read_midifile(input_filename) 167 | 168 | metadata = { 169 | "path": input_filename, 170 | "name": input_filename.split("/")[-1].split(".")[0] 171 | } 172 | 173 | # Most Nottingham MIDI's have 3 tracks: metadata info, melody, harmony. 174 | # Throw away any files that don't fit this structure 175 | if len(pattern) != 3: 176 | if verbose: 177 | print "Skipping track with {} tracks".format(len(pattern)) 178 | return (metadata, None) 179 | 180 | ticks_per_quarter = -1 181 | for msg in pattern[0]: 182 | if isinstance(msg, midi.TimeSignatureEvent): 183 | metadata["ticks_per_quarter"] = msg.get_metronome() 184 | ticks_per_quarter = msg.get_metronome() 185 | 186 | if verbose: 187 | print "{}".format(input_filename) 188 | print "Track resolution: {}".format(pattern.resolution) 189 | print "Number of tracks: {}".format(len(pattern)) 190 | print "Time step: {}".format(time_step) 191 | print "Ticks per quarter: {}".format(ticks_per_quarter) 192 | 193 | # Track ingestion stage 194 | track_ticks = 0 195 | 196 | melody_notes, melody_ticks = midi_util.ingest_notes(pattern[1]) 197 | harmony_notes, harmony_ticks = midi_util.ingest_notes(pattern[2]) 198 | 199 | track_ticks = midi_util.round_tick(max(melody_ticks, harmony_ticks), time_step) 200 | if verbose: 201 | print "Track ticks (rounded): {} ({} time steps)".format(track_ticks, track_ticks/time_step) 202 | 203 | melody_sequence = midi_util.round_notes(melody_notes, track_ticks, time_step, 204 | R=NOTTINGHAM_MELODY_RANGE, O=NOTTINGHAM_MELODY_MIN) 205 | 206 | for i in range(melody_sequence.shape[0]): 207 | if np.count_nonzero(melody_sequence[i, :]) > 1: 208 | if verbose: 209 | print "Double note found: {}: {} ({})".format(i, np.nonzero(melody_sequence[i, :]), input_filename) 210 | return (metadata, None) 211 | 212 | harmony_sequence = midi_util.round_notes(harmony_notes, track_ticks, time_step) 213 | 214 | harmonies = [] 215 | for i in range(harmony_sequence.shape[0]): 216 | notes = np.where(harmony_sequence[i] == 1)[0] 217 | if len(notes) > 0: 218 | notes_shift = [ mingus.core.notes.int_to_note(h%12) for h in notes] 219 | chord = mingus.core.chords.determine(notes_shift, shorthand=True) 220 | if len(chord) == 0: 221 | # try flat combinations 222 | notes_shift = [ SHARPS_TO_FLATS[n] if n 
in SHARPS_TO_FLATS else n for n in notes_shift] 223 | chord = mingus.core.chords.determine(notes_shift, shorthand=True) 224 | if len(chord) == 0: 225 | if verbose: 226 | print "Could not determine chord: {} ({}, {}), defaulting to last steps chord" \ 227 | .format(notes_shift, input_filename, i) 228 | if len(harmonies) > 0: 229 | harmonies.append(harmonies[-1]) 230 | else: 231 | harmonies.append(NO_CHORD) 232 | else: 233 | resolved = resolve_chord(chord[0]) 234 | if resolved: 235 | harmonies.append(resolved) 236 | else: 237 | harmonies.append(NO_CHORD) 238 | else: 239 | harmonies.append(NO_CHORD) 240 | 241 | return (metadata, (melody_sequence, harmonies)) 242 | 243 | class NottinghamMidiWriter(midi_util.MidiWriter): 244 | 245 | def __init__(self, chord_to_idx, verbose=False): 246 | super(NottinghamMidiWriter, self).__init__(verbose) 247 | self.idx_to_chord = { i: c for c, i in chord_to_idx.items() } 248 | self.note_range = NOTTINGHAM_MELODY_RANGE + len(self.idx_to_chord) 249 | 250 | def dereference_chord(self, idx): 251 | if idx not in self.idx_to_chord: 252 | raise Exception("No chord index found: {}".format(idx)) 253 | shorthand = self.idx_to_chord[idx] 254 | if shorthand == NO_CHORD: 255 | return [] 256 | chord = mingus.core.chords.from_shorthand(shorthand) 257 | return [ CHORD_BASE + mingus.core.notes.note_to_int(n) for n in chord ] 258 | 259 | def note_on(self, val, tick): 260 | if val >= NOTTINGHAM_MELODY_RANGE: 261 | notes = self.dereference_chord(val - NOTTINGHAM_MELODY_RANGE) 262 | else: 263 | # if note is the top of the range, then it stands for gap in melody 264 | if val == NOTTINGHAM_MELODY_RANGE - 1: 265 | notes = [] 266 | else: 267 | notes = [NOTTINGHAM_MELODY_MIN + val] 268 | 269 | # print 'turning on {}'.format(notes) 270 | for note in notes: 271 | self.track.append(midi.NoteOnEvent(tick=tick, pitch=note, velocity=70)) 272 | tick = 0 # notes that come right after each other should have zero tick 273 | 274 | return tick 275 | 276 | def note_off(self, val, tick): 277 | if val >= NOTTINGHAM_MELODY_RANGE: 278 | notes = self.dereference_chord(val - NOTTINGHAM_MELODY_RANGE) 279 | else: 280 | notes = [NOTTINGHAM_MELODY_MIN + val] 281 | 282 | # print 'turning off {}'.format(notes) 283 | for note in notes: 284 | self.track.append(midi.NoteOffEvent(tick=tick, pitch=note)) 285 | tick = 0 286 | 287 | return tick 288 | 289 | class NottinghamSampler(object): 290 | 291 | def __init__(self, chord_to_idx, method = 'sample', harmony_repeat_max = 16, melody_repeat_max = 16, verbose=False): 292 | self.verbose = verbose 293 | self.idx_to_chord = { i: c for c, i in chord_to_idx.items() } 294 | self.method = method 295 | 296 | self.hlast = 0 297 | self.hcount = 0 298 | self.hrepeat = harmony_repeat_max 299 | 300 | self.mlast = 0 301 | self.mcount = 0 302 | self.mrepeat = melody_repeat_max 303 | 304 | def visualize_probs(self, probs): 305 | if not self.verbose: 306 | return 307 | 308 | melodies = sorted(list(enumerate(probs[:NOTTINGHAM_MELODY_RANGE])), 309 | key=lambda x: x[1], reverse=True)[:4] 310 | harmonies = sorted(list(enumerate(probs[NOTTINGHAM_MELODY_RANGE:])), 311 | key=lambda x: x[1], reverse=True)[:4] 312 | harmonies = [(self.idx_to_chord[i], j) for i, j in harmonies] 313 | print 'Top Melody Notes: ' 314 | pprint(melodies) 315 | print 'Top Harmony Notes: ' 316 | pprint(harmonies) 317 | 318 | def sample_notes_static(self, probs): 319 | top_m = probs[:NOTTINGHAM_MELODY_RANGE].argsort() 320 | if top_m[-1] == self.mlast and self.mcount >= self.mrepeat: 321 | top_m = top_m[:-1] 322 | 
self.mcount = 0 323 | elif top_m[-1] == self.mlast: 324 | self.mcount += 1 325 | else: 326 | self.mcount = 0 327 | self.mlast = top_m[-1] 328 | top_melody = top_m[-1] 329 | 330 | top_h = probs[NOTTINGHAM_MELODY_RANGE:].argsort() 331 | if top_h[-1] == self.hlast and self.hcount >= self.hrepeat: 332 | top_h = top_h[:-1] 333 | self.hcount = 0 334 | elif top_h[-1] == self.hlast: 335 | self.hcount += 1 336 | else: 337 | self.hcount = 0 338 | self.hlast = top_h[-1] 339 | top_chord = top_h[-1] + NOTTINGHAM_MELODY_RANGE 340 | 341 | chord = np.zeros([len(probs)], dtype=np.int32) 342 | chord[top_melody] = 1.0 343 | chord[top_chord] = 1.0 344 | return chord 345 | 346 | def sample_notes_dist(self, probs): 347 | idxed = [(i, p) for i, p in enumerate(probs)] 348 | 349 | notes = [n[0] for n in idxed] 350 | ps = np.array([n[1] for n in idxed]) 351 | r = NOTTINGHAM_MELODY_RANGE 352 | 353 | assert np.allclose(np.sum(ps[:r]), 1.0) 354 | assert np.allclose(np.sum(ps[r:]), 1.0) 355 | 356 | # renormalize so numpy doesn't complain 357 | ps[:r] = ps[:r] / ps[:r].sum() 358 | ps[r:] = ps[r:] / ps[r:].sum() 359 | 360 | melody = np.random.choice(notes[:r], p=ps[:r]) 361 | harmony = np.random.choice(notes[r:], p=ps[r:]) 362 | 363 | chord = np.zeros([len(probs)], dtype=np.int32) 364 | chord[melody] = 1.0 365 | chord[harmony] = 1.0 366 | return chord 367 | 368 | 369 | def sample_notes(self, probs): 370 | self.visualize_probs(probs) 371 | if self.method == 'static': 372 | return self.sample_notes_static(probs) 373 | elif self.method == 'sample': 374 | return self.sample_notes_dist(probs) 375 | 376 | def accuracy(batch_probs, data, num_samples=1): 377 | """ 378 | Batch Probs: { num_time_steps: [ time_step_1, time_step_2, ... ] } 379 | Data: [ 380 | [ [ data ], [ target ] ], # batch with one time step 381 | [ [ data1, data2 ], [ target1, target2 ] ], # batch with two time steps 382 | ... 
383 | ] 384 | """ 385 | 386 | def calc_accuracy(): 387 | total = 0 388 | melody_correct, harmony_correct = 0, 0 389 | melody_incorrect, harmony_incorrect = 0, 0 390 | for _, batch_targets in data: 391 | num_time_steps = len(batch_targets) 392 | for ts_targets, ts_probs in zip(batch_targets, batch_probs[num_time_steps]): 393 | 394 | assert ts_targets.shape[:2] == ts_probs.shape[:2] 395 | 396 | for seq_idx in range(ts_targets.shape[1]): 397 | for step_idx in range(ts_targets.shape[0]): 398 | idxed = [(n, p) for n, p in \ 399 | enumerate(ts_probs[step_idx, seq_idx, :])] 400 | notes = [n[0] for n in idxed] 401 | ps = np.array([n[1] for n in idxed]) 402 | r = NOTTINGHAM_MELODY_RANGE 403 | 404 | assert np.allclose(np.sum(ps[:r]), 1.0) 405 | assert np.allclose(np.sum(ps[r:]), 1.0) 406 | 407 | # renormalize so numpy doesn't complain 408 | ps[:r] = ps[:r] / ps[:r].sum() 409 | ps[r:] = ps[r:] / ps[r:].sum() 410 | 411 | melody = np.random.choice(notes[:r], p=ps[:r]) 412 | harmony = np.random.choice(notes[r:], p=ps[r:]) 413 | 414 | melody_target = ts_targets[step_idx, seq_idx, 0] 415 | if melody_target == melody: 416 | melody_correct += 1 417 | else: 418 | melody_incorrect += 1 419 | 420 | harmony_target = ts_targets[step_idx, seq_idx, 1] + r 421 | if harmony_target == harmony: 422 | harmony_correct += 1 423 | else: 424 | harmony_incorrect += 1 425 | 426 | return (melody_correct, melody_incorrect, harmony_correct, harmony_incorrect) 427 | 428 | maccs, haccs, taccs = [], [], [] 429 | for i in range(num_samples): 430 | print "Sample {}".format(i) 431 | m, mi, h, hi = calc_accuracy() 432 | maccs.append( float(m) / float(m + mi)) 433 | haccs.append( float(h) / float(h + hi)) 434 | taccs.append( float(m + h) / float(m + h + mi + hi) ) 435 | 436 | print "Melody Accuracy: {}".format(sum(maccs)/len(maccs)) 437 | print "Harmony Accuracy: {}".format(sum(haccs)/len(haccs)) 438 | print "Total Accuracy: {}".format(sum(taccs)/len(taccs)) 439 | 440 | def seperate_accuracy(batch_probs, data, num_samples=1): 441 | 442 | def calc_accuracy(): 443 | total = 0 444 | total_correct, total_incorrect = 0, 0 445 | for _, batch_targets in data: 446 | num_time_steps = len(batch_targets) 447 | for ts_targets, ts_probs in zip(batch_targets, batch_probs[num_time_steps]): 448 | 449 | assert ts_targets.shape[:2] == ts_probs.shape[:2] 450 | 451 | for seq_idx in range(ts_targets.shape[1]): 452 | for step_idx in range(ts_targets.shape[0]): 453 | 454 | idxed = [(n, p) for n, p in \ 455 | enumerate(ts_probs[step_idx, seq_idx, :])] 456 | notes = [n[0] for n in idxed] 457 | ps = np.array([n[1] for n in idxed]) 458 | r = NOTTINGHAM_MELODY_RANGE 459 | 460 | assert np.allclose(np.sum(ps), 1.0) 461 | ps = ps / ps.sum() 462 | note = np.random.choice(notes, p=ps) 463 | 464 | target = ts_targets[step_idx, seq_idx] 465 | if target == note: 466 | total_correct += 1 467 | else: 468 | total_incorrect += 1 469 | 470 | return (total_correct, total_incorrect) 471 | 472 | taccs = [] 473 | for i in range(num_samples): 474 | print "Sample {}".format(i) 475 | c, ic = calc_accuracy() 476 | taccs.append( float(c) / float(c + ic)) 477 | 478 | print "Accuracy: {}".format(sum(taccs)/len(taccs)) 479 | 480 | def i_vi_iv_v(chord_to_idx, repeats, input_dim): 481 | r = NOTTINGHAM_MELODY_RANGE 482 | 483 | i = np.zeros(input_dim) 484 | i[r + chord_to_idx['CM']] = 1 485 | 486 | vi = np.zeros(input_dim) 487 | vi[r + chord_to_idx['Am']] = 1 488 | 489 | iv = np.zeros(input_dim) 490 | iv[r + chord_to_idx['FM']] = 1 491 | 492 | v = np.zeros(input_dim) 493 | v[r + chord_to_idx['GM']] = 1
494 | 495 | full_seq = [i] * 16 + [vi] * 16 + [iv] * 16 + [v] * 16 496 | full_seq = full_seq * repeats 497 | 498 | return full_seq 499 | 500 | def create_model(): 501 | resolution = 480 502 | time_step = 120 503 | 504 | assert resolve_chord("GM7") == "GM" 505 | assert resolve_chord("G#dim|AM7") == "G#m" 506 | assert resolve_chord("Dm9") == "Dm" 507 | assert resolve_chord("AM11") == "AM" 508 | 509 | prepare_nottingham_pickle(time_step, verbose=True) 510 | print('Model created!') 511 | 512 | if __name__ == '__main__': 513 | 514 | resolution = 480 515 | time_step = 120 516 | 517 | assert resolve_chord("GM7") == "GM" 518 | assert resolve_chord("G#dim|AM7") == "G#m" 519 | assert resolve_chord("Dm9") == "Dm" 520 | assert resolve_chord("AM11") == "AM" 521 | 522 | prepare_nottingham_pickle(time_step, verbose=True) 523 | --------------------------------------------------------------------------------
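Editor's note: once main.py has run, the pickle written by prepare_nottingham_pickle above can be inspected directly. A small Python 2 sketch (assumes data/nottingham.pickle exists):

import cPickle

with open('data/nottingham.pickle', 'r') as f:
    store = cPickle.load(f)

print store['chord_to_idx']     # chord shorthand -> class index
print len(store['train'])       # number of training sequences
print store['train'][0].shape   # (T, NOTTINGHAM_MELODY_RANGE + number of chords)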
/rnn.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import time 4 | import itertools 5 | import cPickle 6 | import logging 7 | import random 8 | import string 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | import matplotlib.pyplot as plt 13 | 14 | import nottingham_util 15 | import util 16 | from model import Model, NottinghamModel 17 | 18 | def get_config_name(config): 19 | def replace_dot(s): return s.replace(".", "p") 20 | return "nl_" + str(config.num_layers) + "_hs_" + str(config.hidden_size) + \ 21 | replace_dot("_mc_{}".format(config.melody_coeff)) + \ 22 | replace_dot("_dp_{}".format(config.dropout_prob)) + \ 23 | replace_dot("_idp_{}".format(config.input_dropout_prob)) + \ 24 | replace_dot("_tb_{}".format(config.time_batch_len)) 25 | 26 | class DefaultConfig(object): 27 | # model parameters 28 | num_layers = 2 29 | hidden_size = 200 30 | melody_coeff = 0.5 31 | dropout_prob = 0.5 32 | input_dropout_prob = 0.8 33 | cell_type = 'lstm' 34 | 35 | # learning parameters 36 | max_time_batches = 9 37 | time_batch_len = 128 38 | learning_rate = 5e-3 39 | learning_rate_decay = 0.9 40 | num_epochs = 250 41 | 42 | # metadata 43 | dataset = 'softmax' 44 | model_file = '' 45 | 46 | def __repr__(self): 47 | return """Num Layers: {}, Hidden Size: {}, Melody Coeff: {}, Dropout Prob: {}, Input Dropout Prob: {}, Cell Type: {}, Time Batch Len: {}, Learning Rate: {}, Decay: {}""".format(self.num_layers, self.hidden_size, self.melody_coeff, self.dropout_prob, self.input_dropout_prob, self.cell_type, self.time_batch_len, self.learning_rate, self.learning_rate_decay) 48 | 49 | def train_model(): 50 | np.random.seed() 51 | 52 | parser = argparse.ArgumentParser(description='Script to train and save a model.') 53 | parser.add_argument('--dataset', type=str, default='softmax', 54 | # choices = ['bach', 'nottingham', 'softmax'], 55 | choices = ['softmax']) 56 | parser.add_argument('--model_dir', type=str, default='models') 57 | parser.add_argument('--run_name', type=str, default=time.strftime("%m%d_%H%M")) 58 | 59 | args = parser.parse_args() 60 | 61 | if args.dataset == 'softmax': 62 | resolution = 480 63 | time_step = 120 64 | model_class = NottinghamModel 65 | with open(nottingham_util.PICKLE_LOC, 'r') as f: 66 | pickle = cPickle.load(f) 67 | chord_to_idx = pickle['chord_to_idx'] 68 | 69 | input_dim = pickle["train"][0].shape[1] 70 | print 'Finished loading data, input dim: {}'.format(input_dim) 71 | else: 72 | raise Exception("Other datasets not yet implemented") 73 | 74 | initializer = tf.random_uniform_initializer(-0.1, 0.1) 75 | 76 | best_config = None 77 | best_valid_loss = None 78 | 79 | # set up run dir 80 | run_folder = os.path.join(args.model_dir, args.run_name) 81 | if os.path.exists(run_folder): 82 | raise Exception("Run name {} already exists, choose a different one".format(run_folder)) 83 | os.makedirs(run_folder) 84 | 85 | logger = logging.getLogger(__name__) 86 | logger.setLevel(logging.INFO) 87 | logger.addHandler(logging.StreamHandler()) 88 | logger.addHandler(logging.FileHandler(os.path.join(run_folder, "training.log"))) 89 | 90 | grid = { 91 | "dropout_prob": [0.5], 92 | "input_dropout_prob": [0.8], 93 | "melody_coeff": [0.5], 94 | "num_layers": [2], 95 | "hidden_size": [200], 96 | "num_epochs": [250], 97 | "learning_rate": [5e-3], 98 | "learning_rate_decay": [0.9], 99 | "time_batch_len": [128], 100 | } 101 | 102 | # Generate product of hyperparams 103 | runs = list(list(itertools.izip(grid, x)) for x in itertools.product(*grid.itervalues())) 104 | logger.info("{} runs detected".format(len(runs))) 105 | 106 | for combination in runs: 107 | 108 | config = DefaultConfig() 109 | config.dataset = args.dataset 110 | config.model_name = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(12)) + '.model' 111 | for attr, value in combination: 112 | setattr(config, attr, value) 113 | 114 | if config.dataset == 'softmax': 115 | data = util.load_data('', time_step, config.time_batch_len, config.max_time_batches, nottingham=pickle) 116 | config.input_dim = data["input_dim"] 117 | else: 118 | raise Exception("Other datasets not yet implemented") 119 | 120 | logger.info(config) 121 | config_file_path = os.path.join(run_folder, get_config_name(config) + '.config') 122 | with open(config_file_path, 'w') as f: 123 | cPickle.dump(config, f) 124 | 125 | with tf.Graph().as_default(), tf.Session() as session: 126 | with tf.variable_scope("model", reuse=None): 127 | train_model = model_class(config, training=True) 128 | with tf.variable_scope("model", reuse=True): 129 | valid_model = model_class(config, training=False) 130 | 131 | saver = tf.train.Saver(tf.all_variables(), max_to_keep=40) 132 | tf.initialize_all_variables().run() 133 | 134 | # training 135 | early_stop_best_loss = None 136 | start_saving = False 137 | saved_flag = False 138 | train_losses, valid_losses = [], [] 139 | start_time = time.time() 140 | for i in range(config.num_epochs): 141 | loss = util.run_epoch(session, train_model, 142 | data["train"]["data"], training=True, testing=False) 143 | train_losses.append((i, loss)) 144 | if i == 0: 145 | continue 146 | 147 | logger.info('Epoch: {}, Train Loss: {}, Time Per Epoch: {}'.format(\ 148 | i, loss, (time.time() - start_time)/i)) 149 | valid_loss = util.run_epoch(session, valid_model, data["valid"]["data"], training=False, testing=False) 150 | valid_losses.append((i, valid_loss)) 151 | logger.info('Valid Loss: {}'.format(valid_loss)) 152 | 153 | if early_stop_best_loss == None: 154 | early_stop_best_loss = valid_loss 155 | elif valid_loss < early_stop_best_loss: 156 | early_stop_best_loss = valid_loss 157 | if start_saving: 158 | logger.info('Best loss so far encountered, saving model.') 159 | saver.save(session, os.path.join(run_folder, config.model_name)) 160 | saved_flag = True 161 | elif not start_saving: 162 | start_saving = True 163 | logger.info('Valid loss increased for the first time, will start saving models') 164 | saver.save(session, os.path.join(run_folder, config.model_name)) 165 | saved_flag = True 166 | 
167 | if not saved_flag: 168 | saver.save(session, os.path.join(run_folder, config.model_name)) 169 | 170 | # cap the loss axis so the curves stay readable 171 | axes = plt.gca() 172 | if config.dataset == 'softmax': 173 | axes.set_ylim([0, 2]) 174 | else: 175 | axes.set_ylim([0, 100]) 176 | plt.plot([t[0] for t in train_losses], [t[1] for t in train_losses]) 177 | plt.plot([t[0] for t in valid_losses], [t[1] for t in valid_losses]) 178 | plt.legend(['Train Loss', 'Validation Loss']) 179 | chart_file_path = os.path.join(run_folder, get_config_name(config) + '.png') 180 | plt.savefig(chart_file_path) 181 | plt.clf() 182 | 183 | logger.info("Config {}, Loss: {}".format(config, early_stop_best_loss)) 184 | if best_valid_loss == None or early_stop_best_loss < best_valid_loss: 185 | logger.info("Found best new model!") 186 | best_valid_loss = early_stop_best_loss 187 | best_config = config 188 | 189 | logger.info("Best Config: {}, Loss: {}".format(best_config, best_valid_loss)) 190 | -------------------------------------------------------------------------------- /rnn_sample.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import time 4 | import itertools 5 | import cPickle 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | import util 11 | import nottingham_util 12 | from model import Model, NottinghamModel 13 | from rnn import DefaultConfig 14 | 15 | if __name__ == '__main__': 16 | np.random.seed() 17 | 18 | parser = argparse.ArgumentParser(description='Script to generate a MIDI file sample from a trained model.') 19 | parser.add_argument('--config_file', type=str, required=True) 20 | parser.add_argument('--sample_melody', action='store_true', default=False) 21 | parser.add_argument('--sample_harmony', action='store_true', default=False) 22 | parser.add_argument('--sample_seq', type=str, default='random', 23 | choices = ['random', 'chords']) 24 | parser.add_argument('--conditioning', type=int, default=-1) 25 | parser.add_argument('--sample_length', type=int, default=512) 26 | 27 | args = parser.parse_args() 28 | 29 | with open(args.config_file, 'r') as f: 30 | config = cPickle.load(f) 31 | 32 | if config.dataset == 'softmax': 33 | config.time_batch_len = 1 34 | config.max_time_batches = -1 35 | model_class = NottinghamModel 36 | with open(nottingham_util.PICKLE_LOC, 'r') as f: 37 | pickle = cPickle.load(f) 38 | chord_to_idx = pickle['chord_to_idx'] 39 | 40 | time_step = 120 41 | resolution = 480 42 | 43 | # use time batch len of 1 so that every target is covered 44 | test_data = util.batch_data(pickle['test'], time_batch_len = 1, 45 | max_time_batches = -1, softmax = True) 46 | else: 47 | raise Exception("Other datasets not yet implemented") 48 | 49 | print config 50 | 51 | with tf.Graph().as_default(), tf.Session() as session: 52 | with tf.variable_scope("model", reuse=None): 53 | sampling_model = model_class(config) 54 | 55 | saver = tf.train.Saver(tf.all_variables()) 56 | model_path = os.path.join(os.path.dirname(args.config_file), 57 | config.model_name) 58 | saver.restore(session, model_path) 59 | 60 | state = sampling_model.get_cell_zero_state(session, 1) 61 | if args.sample_seq == 'chords': 62 | # 16 - one measure, 64 - chord progression 63 | repeats = args.sample_length / 64 64 | sample_seq = nottingham_util.i_vi_iv_v(chord_to_idx, repeats, config.input_dim) 65 | print 'Sampling melody using a I, VI, IV, V progression' 66 | 67 | elif args.sample_seq == 'random': 68 | sample_index = 
np.random.choice(np.arange(len(pickle['test']))) 69 | sample_seq = [ pickle['test'][sample_index][i, :] 70 | for i in range(pickle['test'][sample_index].shape[0]) ] 71 | 72 | chord = sample_seq[0] 73 | seq = [chord] 74 | 75 | if args.conditioning > 0: 76 | for i in range(1, args.conditioning): 77 | seq_input = np.reshape(chord, [1, 1, config.input_dim]) 78 | feed = { 79 | sampling_model.seq_input: seq_input, 80 | sampling_model.initial_state: state, 81 | } 82 | state = session.run(sampling_model.final_state, feed_dict=feed) 83 | chord = sample_seq[i] 84 | seq.append(chord) 85 | 86 | if config.dataset == 'softmax': 87 | writer = nottingham_util.NottinghamMidiWriter(chord_to_idx, verbose=False) 88 | sampler = nottingham_util.NottinghamSampler(chord_to_idx, verbose=False) 89 | else: 90 | # writer = midi_util.MidiWriter() 91 | # sampler = sampling.Sampler(verbose=False) 92 | raise Exception("Other datasets not yet implemented") 93 | 94 | for i in range(max(args.sample_length - len(seq), 0)): 95 | seq_input = np.reshape(chord, [1, 1, config.input_dim]) 96 | feed = { 97 | sampling_model.seq_input: seq_input, 98 | sampling_model.initial_state: state, 99 | } 100 | [probs, state] = session.run( 101 | [sampling_model.probs, sampling_model.final_state], 102 | feed_dict=feed) 103 | probs = np.reshape(probs, [config.input_dim]) 104 | chord = sampler.sample_notes(probs) 105 | 106 | if config.dataset == 'softmax': 107 | r = nottingham_util.NOTTINGHAM_MELODY_RANGE 108 | if args.sample_melody: 109 | chord[r:] = 0 110 | chord[r:] = sample_seq[i][r:] 111 | elif args.sample_harmony: 112 | chord[:r] = 0 113 | chord[:r] = sample_seq[i][:r] 114 | 115 | seq.append(chord) 116 | 117 | writer.dump_sequence_to_midi(seq, "best.midi", 118 | time_step=time_step, resolution=resolution) 119 | -------------------------------------------------------------------------------- /rnn_separate.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import time 4 | import itertools 5 | import cPickle 6 | import logging 7 | import random 8 | import string 9 | import pprint 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | import matplotlib.pyplot as plt 14 | 15 | import midi_util 16 | import nottingham_util 17 | import sampling 18 | import util 19 | from rnn import get_config_name, DefaultConfig 20 | from model import Model, NottinghamSeparate 21 | 22 | if __name__ == '__main__': 23 | np.random.seed() 24 | 25 | parser = argparse.ArgumentParser(description='Music RNN') 26 | parser.add_argument('--choice', type=str, default='melody', 27 | choices = ['melody', 'harmony']) 28 | parser.add_argument('--dataset', type=str, default='softmax', 29 | choices = ['bach', 'nottingham', 'softmax']) 30 | parser.add_argument('--model_dir', type=str, default='models') 31 | parser.add_argument('--run_name', type=str, default=time.strftime("%m%d_%H%M")) 32 | 33 | args = parser.parse_args() 34 | 35 | if args.dataset == 'softmax': 36 | resolution = 480 37 | time_step = 120 38 | model_class = NottinghamSeparate 39 | with open(nottingham_util.PICKLE_LOC, 'r') as f: 40 | pickle = cPickle.load(f) 41 | chord_to_idx = pickle['chord_to_idx'] 42 | 43 | input_dim = pickle["train"][0].shape[1] 44 | print 'Finished loading data, input dim: {}'.format(input_dim) 45 | else: 46 | raise Exception("Other datasets not yet implemented") 47 | 48 | 49 | initializer = tf.random_uniform_initializer(-0.1, 0.1) 50 | 51 | best_config = None 52 | best_valid_loss = None 53 | 54 | # set up run dir 
55 | run_folder = os.path.join(args.model_dir, args.run_name) 56 | if os.path.exists(run_folder): 57 | raise Exception("Run name {} already exists, choose a different one".format(run_folder)) 58 | os.makedirs(run_folder) 59 | 60 | logger = logging.getLogger(__name__) 61 | logger.setLevel(logging.INFO) 62 | logger.addHandler(logging.StreamHandler()) 63 | logger.addHandler(logging.FileHandler(os.path.join(run_folder, "training.log"))) 64 | 65 | # grid 66 | grid = { 67 | "dropout_prob": [0.65], 68 | "input_dropout_prob": [0.9], 69 | "num_layers": [1], 70 | "hidden_size": [100] 71 | } 72 | 73 | # Generate product of hyperparams 74 | runs = list(list(itertools.izip(grid, x)) for x in itertools.product(*grid.itervalues())) 75 | logger.info("{} runs detected".format(len(runs))) 76 | 77 | for combination in runs: 78 | 79 | config = DefaultConfig() 80 | config.dataset = args.dataset 81 | config.model_name = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(12)) + '.model' 82 | for attr, value in combination: 83 | setattr(config, attr, value) 84 | 85 | if config.dataset == 'softmax': 86 | data = util.load_data('', time_step, config.time_batch_len, config.max_time_batches, nottingham=pickle) 87 | config.input_dim = data["input_dim"] 88 | else: 89 | raise Exception("Other datasets not yet implemented") 90 | 91 | # cut away unnecessary parts 92 | r = nottingham_util.NOTTINGHAM_MELODY_RANGE 93 | if args.choice == 'melody': 94 | print "Using only melody" 95 | for d in ['train', 'test', 'valid']: 96 | new_data = [] 97 | for batch_data, batch_targets in data[d]["data"]: 98 | new_data.append(([tb[:, :, :r] for tb in batch_data], 99 | [tb[:, :, 0] for tb in batch_targets])) 100 | data[d]["data"] = new_data 101 | else: 102 | print "Using only harmony" 103 | for d in ['train', 'test', 'valid']: 104 | new_data = [] 105 | for batch_data, batch_targets in data[d]["data"]: 106 | new_data.append(([tb[:, :, r:] for tb in batch_data], 107 | [tb[:, :, 1] for tb in batch_targets])) 108 | data[d]["data"] = new_data 109 | 110 | input_dim = data["input_dim"] = data["train"]["data"][0][0][0].shape[2] 111 | config.input_dim = input_dim 112 | print "New input dim: {}".format(input_dim) 113 | 114 | logger.info(config) 115 | config_file_path = os.path.join(run_folder, get_config_name(config) + '.config') 116 | with open(config_file_path, 'w') as f: 117 | cPickle.dump(config, f) 118 | 119 | with tf.Graph().as_default(), tf.Session() as session: 120 | with tf.variable_scope("model", reuse=None): 121 | train_model = model_class(config, training=True) 122 | with tf.variable_scope("model", reuse=True): 123 | valid_model = model_class(config, training=False) 124 | 125 | saver = tf.train.Saver(tf.all_variables()) 126 | tf.initialize_all_variables().run() 127 | 128 | # training 129 | early_stop_best_loss = None 130 | start_saving = False 131 | saved_flag = False 132 | train_losses, valid_losses = [], [] 133 | start_time = time.time() 134 | for i in range(config.num_epochs): 135 | loss = util.run_epoch(session, train_model, data["train"]["data"], training=True, testing=False) 136 | train_losses.append((i, loss)) 137 | if i == 0: 138 | continue 139 | 140 | valid_loss = util.run_epoch(session, valid_model, data["valid"]["data"], training=False, testing=False) 141 | valid_losses.append((i, valid_loss)) 142 | 143 | logger.info('Epoch: {}, Train Loss: {}, Valid Loss: {}, Time Per Epoch: {}'.format(\ 144 | i, loss, valid_loss, (time.time() - start_time)/i)) 145 | 146 | # if it's best validation loss so far, save it
147 |                 if early_stop_best_loss is None:
148 |                     early_stop_best_loss = valid_loss
149 |                 elif valid_loss < early_stop_best_loss:
150 |                     early_stop_best_loss = valid_loss
151 |                     if start_saving:
152 |                         logger.info('Best loss so far encountered, saving model.')
153 |                         saver.save(session, os.path.join(run_folder, config.model_name))
154 |                         saved_flag = True
155 |                 elif not start_saving:
156 |                     start_saving = True
157 |                     logger.info('Valid loss increased for the first time, will start saving models')
158 |                     saver.save(session, os.path.join(run_folder, config.model_name))
159 |                     saved_flag = True
160 | 
161 |             if not saved_flag:
162 |                 saver.save(session, os.path.join(run_folder, config.model_name))
163 | 
164 |             # cap the loss axis so the curves stay readable
165 |             axes = plt.gca()
166 |             if config.dataset == 'softmax':
167 |                 axes.set_ylim([0, 2])
168 |             else:
169 |                 axes.set_ylim([0, 100])
170 |             plt.plot([t[0] for t in train_losses], [t[1] for t in train_losses])
171 |             plt.plot([t[0] for t in valid_losses], [t[1] for t in valid_losses])
172 |             plt.legend(['Train Loss', 'Validation Loss'])
173 |             chart_file_path = os.path.join(run_folder, get_config_name(config) + '.png')
174 |             plt.savefig(chart_file_path)
175 |             plt.clf()
176 | 
177 |         logger.info("Config {}, Loss: {}".format(config, early_stop_best_loss))
178 |         if best_valid_loss is None or early_stop_best_loss < best_valid_loss:
179 |             logger.info("Found new best model!")
180 |             best_valid_loss = early_stop_best_loss
181 |             best_config = config
182 | 
183 |     logger.info("Best Config: {}, Loss: {}".format(best_config, best_valid_loss))
184 | 
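The `itertools` one-liner above expands the `grid` dict into one list of `(name, value)` pairs per run, which are then `setattr`'d onto the config. A toy expansion (Python 2, matching the rest of the repo; the two-value lists are illustrative, the real grid above has a single value per key):

```
import itertools

grid = {
    "dropout_prob": [0.5, 0.65],
    "hidden_size": [100, 200],
}

# Both izip arguments iterate the same dict, so the keys line up with the
# value combinations produced by product().
runs = list(list(itertools.izip(grid, x))
            for x in itertools.product(*grid.itervalues()))

print len(runs)         # 4 runs for a 2 x 2 grid
for combination in runs:
    print combination   # e.g. [('dropout_prob', 0.5), ('hidden_size', 100)]
```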
--------------------------------------------------------------------------------
/rnn_test.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | import cPickle
4 | 
5 | import numpy as np
6 | import tensorflow as tf
7 | 
8 | import util
9 | import nottingham_util
10 | from model import Model, NottinghamModel, NottinghamSeparate
11 | from rnn import DefaultConfig
12 | 
13 | if __name__ == '__main__':
14 |     np.random.seed()
15 | 
16 |     parser = argparse.ArgumentParser(description="Script to test a model's performance against the test set")
17 |     parser.add_argument('--config_file', type=str, required=True)
18 |     parser.add_argument('--num_samples', type=int, default=1)
19 |     parser.add_argument('--seperate', action='store_true', default=False)
20 |     parser.add_argument('--choice', type=str, default='melody',
21 |         choices = ['melody', 'harmony'])
22 |     args = parser.parse_args()
23 | 
24 |     with open(args.config_file, 'r') as f:
25 |         config = cPickle.load(f)
26 | 
27 |     if config.dataset == 'softmax':
28 |         config.time_batch_len = 1
29 |         config.max_time_batches = -1
30 |         with open(nottingham_util.PICKLE_LOC, 'r') as f:
31 |             pickle = cPickle.load(f)
32 |         if args.seperate:
33 |             model_class = NottinghamSeparate
34 |             test_data = util.batch_data(pickle['test'], time_batch_len = 1,
35 |                 max_time_batches = -1, softmax = True)
36 |             r = nottingham_util.NOTTINGHAM_MELODY_RANGE
37 |             if args.choice == 'melody':
38 |                 print "Using only melody"
39 |                 new_data = []
40 |                 for batch_data, batch_targets in test_data:
41 |                     new_data.append(([tb[:, :, :r] for tb in batch_data],
42 |                         [tb[:, :, 0] for tb in batch_targets]))
43 |                 test_data = new_data
44 |             else:
45 |                 print "Using only harmony"
46 |                 new_data = []
47 |                 for batch_data, batch_targets in test_data:
48 |                     new_data.append(([tb[:, :, r:] for tb in batch_data],
49 |                         [tb[:, :, 1] for tb in batch_targets]))
50 |                 test_data = new_data
51 |         else:
52 |             model_class = NottinghamModel
53 |             # use a time batch length of 1 so that every target is covered
54 |             test_data = util.batch_data(pickle['test'], time_batch_len = 1,
55 |                 max_time_batches = -1, softmax = True)
56 |     else:
57 |         raise Exception("Other datasets not yet implemented")
58 | 
59 |     print config
60 | 
61 |     with tf.Graph().as_default(), tf.Session() as session:
62 |         with tf.variable_scope("model", reuse=None):
63 |             test_model = model_class(config, training=False)
64 | 
65 |         saver = tf.train.Saver(tf.all_variables())
66 |         model_path = os.path.join(os.path.dirname(args.config_file),
67 |             config.model_name)
68 |         saver.restore(session, model_path)
69 | 
70 |         test_loss, test_probs = util.run_epoch(session, test_model, test_data,
71 |             training=False, testing=True)
72 |         print 'Testing Loss: {}'.format(test_loss)
73 | 
74 |         if config.dataset == 'softmax':
75 |             if args.seperate:
76 |                 nottingham_util.seperate_accuracy(test_probs, test_data, num_samples=args.num_samples)
77 |             else:
78 |                 nottingham_util.accuracy(test_probs, test_data, num_samples=args.num_samples)
79 | 
80 |         else:
81 |             util.accuracy(test_probs, test_data, num_samples=50)
82 | 
83 |     sys.exit(0)  # exit with a success status
84 | 
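The `tb[:, :, :r]` and `tb[:, :, r:]` slices above rely on the dual-softmax encoding: the first `NOTTINGHAM_MELODY_RANGE` input columns one-hot encode the melody note, the remaining columns one-hot encode the chord class, and each target row stores just the two class indices side by side, which is why melody runs read target column 0 and harmony runs read column 1. A toy illustration (the dimensions here are stand-ins, not the real constants):

```
import numpy as np

r = 34                 # stand-in for nottingham_util.NOTTINGHAM_MELODY_RANGE
num_chords = 31        # stand-in for the number of chord classes
step = np.zeros(r + num_chords, dtype=np.int32)
step[10] = 1           # melody: one-hot note index 10
step[r + 3] = 1        # harmony: one-hot chord index 3

melody, harmony = step[:r], step[r:]
assert melody.sum() == 1 and harmony.sum() == 1

# A target row keeps only the two indices.
target = [np.argmax(melody), np.argmax(harmony)]
print target           # [10, 3]
```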
--------------------------------------------------------------------------------
/sampling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pprint import pprint
3 | 
4 | import midi_util
5 | 
6 | 
7 | class Sampler(object):
8 | 
9 |     def __init__(self, min_prob=0.5, num_notes = 4, method = 'sample', verbose=False):
10 |         self.min_prob = min_prob
11 |         self.num_notes = num_notes
12 |         self.method = method
13 |         self.verbose = verbose
14 | 
15 |     def visualize_probs(self, probs):
16 |         if not self.verbose:
17 |             return
18 |         print 'Highest four probs: '
19 |         pprint(sorted(list(enumerate(probs)), key=lambda x: x[1],
20 |             reverse=True)[:4])
21 | 
22 |     def sample_notes_prob(self, probs, max_notes=-1):
23 |         """ Samples all notes that are over a certain probability """
24 |         self.visualize_probs(probs)
25 |         top_idxs = list()
26 |         for idx in probs.argsort()[::-1]:
27 |             if max_notes > 0 and len(top_idxs) >= max_notes:
28 |                 break
29 |             if probs[idx] < self.min_prob:
30 |                 break
31 |             top_idxs.append(idx)
32 |         chord = np.zeros([len(probs)], dtype=np.int32)
33 |         chord[top_idxs] = 1.0
34 |         return chord
35 | 
36 |     def sample_notes_static(self, probs):
37 |         """ Samples a fixed number of notes, taking the highest probabilities """
38 |         top_idxs = probs.argsort()[-self.num_notes:][::-1]
39 |         chord = np.zeros([len(probs)], dtype=np.int32)
40 |         chord[top_idxs] = 1.0
41 |         return chord
42 | 
43 |     def sample_notes_bernoulli(self, probs):
44 |         """ Draws each note independently from its own Bernoulli distribution """
45 |         chord = np.zeros([len(probs)], dtype=np.int32)
46 |         for note, prob in enumerate(probs):
47 |             if np.random.binomial(1, prob) > 0:
48 |                 chord[note] = 1
49 |         return chord
50 | 
51 |     def sample_notes(self, probs):
52 |         """ Samples notes from the probability vector using the configured method """
53 |         self.visualize_probs(probs)
54 |         if self.method == 'sample':
55 |             return self.sample_notes_bernoulli(probs)
56 |         elif self.method == 'static':
57 |             return self.sample_notes_static(probs)
58 |         elif self.method == 'min_prob':
59 |             return self.sample_notes_prob(probs)
60 |         else:
61 |             raise Exception("Unrecognized method: {}".format(self.method))
62 | 
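For the non-softmax datasets, the `Sampler` above turns a vector of per-note probabilities into a binary chord. A quick sketch of the three strategies (the probability vector is made up for illustration):

```
import numpy as np
from sampling import Sampler

probs = np.array([0.9, 0.1, 0.7, 0.2, 0.6])

print Sampler(method='static', num_notes=2).sample_notes(probs)
# always the two highest-probability notes -> [1 0 1 0 0]

print Sampler(method='min_prob', min_prob=0.5).sample_notes(probs)
# every note above min_prob -> [1 0 1 0 1]

print Sampler(method='sample').sample_notes(probs)
# one independent Bernoulli draw per note, different on every call
```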
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import cPickle
4 | from collections import defaultdict
5 | from random import shuffle
6 | 
7 | import numpy as np
8 | import tensorflow as tf
9 | 
10 | import midi_util
11 | import nottingham_util
12 | 
13 | def parse_midi_directory(input_dir, time_step):
14 |     """
15 |     input_dir: data directory full of midi files
16 |     time_step: the number of ticks to use as a time step for discretization
17 | 
18 |     Returns a list of [T x D] matrices, where T is the number of time steps
19 |     and D is the range of notes.
20 |     """
21 |     files = [ os.path.join(input_dir, f) for f in os.listdir(input_dir)
22 |         if os.path.isfile(os.path.join(input_dir, f)) ]
23 |     sequences = [ \
24 |         (f, midi_util.parse_midi_to_sequence(f, time_step=time_step)) \
25 |         for f in files ]
26 | 
27 |     return sequences
28 | 
29 | def batch_data(sequences, time_batch_len=128, max_time_batches=10,
30 |                softmax=False, verbose=False):
31 |     """
32 |     sequences: a list of [T x D] matrices, each matrix representing a sequence
33 |     time_batch_len: the unrolling length that will be used by BPTT
34 |     max_time_batches: the maximum number of time batches to consider. Any sequence
35 |         longer than max_time_batches * time_batch_len will be ignored.
36 |         Can be set to -1 to use as many time batches as needed.
37 |     softmax: should be set to True if using the dual-softmax formulation
38 | 
39 |     returns [
40 |         [ [ data ], [ target ] ],                   # batch with one time step
41 |         [ [ data1, data2 ], [ target1, target2 ] ], # batch with two time steps
42 |         ...
43 |     ]
44 |     """
45 | 
46 |     assert time_batch_len > 0
47 | 
48 |     dims = sequences[0].shape[1]
49 |     sequence_lens = [s.shape[0] for s in sequences]
50 | 
51 |     if verbose:
52 |         avg_seq_len = sum(sequence_lens) / len(sequences)
53 |         print "Average Sequence Length: {}".format(avg_seq_len)
54 |         print "Max Sequence Length: {}".format(max(sequence_lens))
55 |         print "Number of sequences: {}".format(len(sequences))
56 | 
57 |     batches = defaultdict(list)
58 |     for sequence in sequences:
59 |         # -1 because we can't predict the first step
60 |         num_time_steps = ((sequence.shape[0]-1) // time_batch_len)
61 |         if num_time_steps < 1:
62 |             continue
63 |         if max_time_batches > 0 and num_time_steps > max_time_batches:
64 |             continue
65 |         batches[num_time_steps].append(sequence)
66 | 
67 |     if verbose:
68 |         print "Batch distribution:"
69 |         print [(k, len(v)) for (k, v) in batches.iteritems()]
70 | 
71 |     def arrange_batch(sequences, num_time_steps):
72 |         sequences = [s[:(num_time_steps*time_batch_len)+1, :] for s in sequences]
73 |         stacked = np.dstack(sequences)
74 |         # swap axes so that shape is (SEQ_LENGTH X BATCH_SIZE X INPUT_DIM)
75 |         data = np.swapaxes(stacked, 1, 2)
76 |         targets = np.roll(data, -1, axis=0)
77 |         # cut off the final time step
78 |         data = data[:-1, :, :]
79 |         targets = targets[:-1, :, :]
80 |         assert data.shape == targets.shape
81 | 
82 |         if softmax:
83 |             r = nottingham_util.NOTTINGHAM_MELODY_RANGE
84 |             labels = np.ones((targets.shape[0], targets.shape[1], 2), dtype=np.int32)
85 |             assert np.all(np.sum(targets[:, :, :r], axis=2) == 1)
86 |             assert np.all(np.sum(targets[:, :, r:], axis=2) == 1)
87 |             labels[:, :, 0] = np.argmax(targets[:, :, :r], axis=2)
88 |             labels[:, :, 1] = np.argmax(targets[:, :, r:], axis=2)
89 |             targets = labels
90 |             assert targets.shape[:2] == data.shape[:2]
91 | 
92 |         assert data.shape[0] == num_time_steps * time_batch_len
93 | 
94 |         # split them up into time batches
95 |         tb_data = np.split(data, num_time_steps, axis=0)
96 |         tb_targets = np.split(targets, num_time_steps, axis=0)
97 | 
98 |         assert len(tb_data) == len(tb_targets) == num_time_steps
99 |         for i in range(len(tb_data)):
100 |             assert tb_data[i].shape[0] == time_batch_len
101 |             assert tb_targets[i].shape[0] == time_batch_len
102 |             if softmax:
103 |                 assert np.all(np.sum(tb_data[i], axis=2) == 2)
104 | 
105 |         return (tb_data, tb_targets)
106 | 
107 |     return [ arrange_batch(b, n) for n, b in batches.iteritems() ]
108 | 
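A quick way to see `batch_data`'s contract is to run it on toy sequences and check the shapes it promises (a sanity-check sketch, assuming this repo's `util.py` is importable; it is not part of the pipeline):

```
import numpy as np
import util

dim = 8
# three binary piano-roll sequences of 65 steps each
seqs = [np.random.randint(0, 2, size=(65, dim)) for _ in range(3)]

batched = util.batch_data(seqs, time_batch_len=32, max_time_batches=10,
                          softmax=False)
for tb_data, tb_targets in batched:
    # (65 - 1) // 32 = 2 time batches, each (32 steps, batch of 3, 8 notes)
    print len(tb_data), tb_data[0].shape, tb_targets[0].shape
    # -> 2 (32, 3, 8) (32, 3, 8)
```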
109 | def load_data(data_dir, time_step, time_batch_len, max_time_batches, nottingham=None):
110 |     """
111 |     nottingham: the sequences object as created in prepare_nottingham_pickle
112 |         (see nottingham_util for more). If None, parse all the MIDI
113 |         files from data_dir
114 |     time_step: the time_step used to parse midi files (only used if data_dir
115 |         is provided)
116 |     time_batch_len and max_time_batches: see batch_data()
117 | 
118 |     returns {
119 |         "train": {
120 |             "data": [ batch_data() ],
121 |             "metadata": { ... }
122 |         },
123 |         "valid": { ... },
124 |         "test": { ... }
125 |     }
126 |     """
127 | 
128 |     data = {}
129 |     for dataset in ['train', 'test', 'valid']:
130 | 
131 |         # for testing, use ALL the sequences
132 |         if dataset == 'test':
133 |             max_time_batches = -1
134 | 
135 |         # the softmax formulation comes pre-parsed into sequences
136 |         if nottingham:
137 |             sequences = nottingham[dataset]
138 |             metadata = nottingham[dataset + '_metadata']
139 |         # the cross-entropy formulation needs to be parsed from MIDI
140 |         else:
141 |             sf = parse_midi_directory(os.path.join(data_dir, dataset), time_step)
142 |             sequences = [s[1] for s in sf]
143 |             files = [s[0] for s in sf]
144 |             metadata = [{
145 |                 'path': f,
146 |                 'name': f.split("/")[-1].split(".")[0]
147 |             } for f in files]
148 | 
149 |         dataset_data = batch_data(sequences, time_batch_len, max_time_batches, softmax = True if nottingham else False)
150 | 
151 |         data[dataset] = {
152 |             "data": dataset_data,
153 |             "metadata": metadata,
154 |         }
155 | 
156 |         data["input_dim"] = dataset_data[0][0][0].shape[2]
157 | 
158 |     return data
159 | 
160 | 
161 | def run_epoch(session, model, batches, training=False, testing=False):
162 |     """
163 |     session: Tensorflow session object
164 |     model: model object (see model.py)
165 |     batches: data object loaded from load_data()
166 | 
167 |     training: a backpropagation iteration will be performed on the dataset
168 |         if this flag is active
169 | 
170 |     returns the average loss per time step over all batches.
171 |     if the testing flag is active: returns [ loss, probs ], where probs is
172 |         the probability values for each note
173 |     """
174 | 
175 |     # shuffle batches
176 |     shuffle(batches)
177 | 
178 |     target_tensors = [model.loss, model.final_state]
179 |     if testing:
180 |         target_tensors.append(model.probs)
181 |         batch_probs = defaultdict(list)
182 |     if training:
183 |         target_tensors.append(model.train_step)
184 | 
185 |     losses = []
186 |     for data, targets in batches:
187 |         # save state over unrolling time steps
188 |         batch_size = data[0].shape[1]
189 |         num_time_steps = len(data)
190 |         state = model.get_cell_zero_state(session, batch_size)
191 |         probs = list()
192 | 
193 |         for tb_data, tb_targets in zip(data, targets):
194 |             if testing:
195 |                 tbd = tb_data
196 |                 tbt = tb_targets
197 |             else:
198 |                 # shuffle the batch dimension of input, state, and target together
199 |                 # (every time batch shares the batch size computed above)
200 |                 permutations = np.random.permutation(batch_size)
201 |                 tbd = np.zeros_like(tb_data)
202 |                 tbd[:, np.arange(batch_size), :] = tb_data[:, permutations, :]
203 |                 tbt = np.zeros_like(tb_targets)
204 |                 tbt[:, np.arange(batch_size), :] = tb_targets[:, permutations, :]
205 |                 state[np.arange(batch_size)] = state[permutations]
206 | 
207 |             feed_dict = {
208 |                 model.initial_state: state,
209 |                 model.seq_input: tbd,
210 |                 model.seq_targets: tbt,
211 |             }
212 |             results = session.run(target_tensors, feed_dict=feed_dict)
213 | 
214 |             losses.append(results[0])
215 |             state = results[1]
216 |             if testing:
217 |                 batch_probs[num_time_steps].append(results[2])
218 | 
219 |     loss = sum(losses) / len(losses)
220 | 
221 |     if testing:
222 |         return [loss, batch_probs]
223 |     else:
224 |         return loss
225 | 
226 | def accuracy(batch_probs, data, num_samples=20):
227 |     """
228 |     batch_probs: probs object returned from run_epoch
229 |     data: data object passed into run_epoch
230 |     num_samples: the number of times to sample each note (an average over all
231 |         these samples will be used)
232 | 
233 |     returns the accuracy metric according to
234 |     http://ismir2009.ismir.net/proceedings/PS2-21.pdf
235 |     """
236 | 
237 |     false_positives, false_negatives, true_positives = 0, 0, 0
238 |     for batch_data, batch_targets in data:
239 |         num_time_steps = len(batch_data)
240 |         for ts_targets, ts_probs in zip(batch_targets, batch_probs[num_time_steps]):
241 | 
242 |             assert ts_targets.shape == ts_probs.shape
243 | 
244 |             for seq_idx in range(ts_targets.shape[1]):
245 |                 for step_idx in range(ts_targets.shape[0]):
246 |                     for note_idx, prob in enumerate(ts_probs[step_idx, seq_idx, :]):
247 |                         num_occurrences = np.random.binomial(num_samples, prob)
248 |                         if ts_targets[step_idx, seq_idx, note_idx] == 0.0:
249 |                             false_positives += num_occurrences
250 |                         else:
251 |                             false_negatives += (num_samples - num_occurrences)
252 |                             true_positives += num_occurrences
253 | 
254 |     accuracy = (float(true_positives) / float(true_positives + false_positives + false_negatives))
255 | 
256 |     print "Precision: {}".format(float(true_positives) / (float(true_positives + false_positives)))
257 |     print "Recall: {}".format(float(true_positives) / (float(true_positives + false_negatives)))
258 |     print "Accuracy: {}".format(accuracy)
259 | 
--------------------------------------------------------------------------------
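`accuracy()` above estimates the frame-level metric from the cited ISMIR paper, acc = TP / (TP + FP + FN), by sampling each note `num_samples` times from its predicted probability. Computed exactly for a single frame (toy vectors, for illustration only):

```
import numpy as np

target    = np.array([1, 0, 1, 0, 0, 1])   # notes actually on
predicted = np.array([1, 1, 0, 0, 0, 1])   # notes the model turned on

tp = int(np.sum((predicted == 1) & (target == 1)))   # 2
fp = int(np.sum((predicted == 1) & (target == 0)))   # 1
fn = int(np.sum((predicted == 0) & (target == 1)))   # 1
print float(tp) / (tp + fp + fn)                     # 0.5
```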