├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── data
│   ├── mid
│   │   ├── Avril14th.mid
│   │   └── Bounce.mid
│   └── wav
│       ├── Avril14th.wav
│       └── Bounce.wav
├── examples
│   ├── deep_mnist_with_summaries.py
│   └── one_hot.py
├── keras_train.py
├── models.csv
├── notebooks
│   ├── NMF.ipynb
│   └── wavmidi_preprocess.ipynb
├── preprocess.py
└── runs.py


/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Jon Sleep

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# wav2mid: Polyphonic Piano Music Transcription with Deep Neural Networks

### Thesis by Jonathan Sleep for MS in CSC @ CalPoly

## Abstract / Intro
There has been a multitude of recent research on using deep learning for music & audio generation and classification. In this paper, we plan to build on these works by implementing a novel system to automatically transcribe polyphonic music with an artificial neural network model.
We show that by treating the transcription problem as an image classification problem, we can use transformed audio data to predict the group of notes currently being played.

## Background
* Digital Signal Processing: Fourier Transform, STFT, Constant-Q, Onset/Beat Tracking, Auto-correlation
* Machine Learning: Artificial Neural Networks, Convolutional Neural Networks, Recurrent Neural Networks

## Related Work on AMT
* Pre-Deep Learning Research
  * [Non-negative matrix factorization for polyphonic music transcription](http://ieeexplore.ieee.org/abstract/document/1285860/)
    * Really cool paper for transcribing music using NMF, and very simple. I wish more results had been shown with metrics like accuracy, but the work seemed clear. It would be cool to see if/how I could extend this.
  * [YIN, a fundamental frequency estimator for speech and music](http://asa.scitation.org/doi/abs/10.1121/1.1458024) - builds on autocorrelation to produce an f0 estimator with lower error than autocorrelation alone.

* Research that uses Deep Learning
  * [Modeling Temporal Dependencies in High-Dimensional Sequences: Application to Polyphonic Music Generation and Transcription](https://arxiv.org/abs/1206.6392) - uses a sequential model to aid transcription.
  * [An End-to-End Neural Network for Polyphonic Piano Music Transcription](https://arxiv.org/abs/1508.01774) - AMT research that combines an acoustic model and a language model; ~75% accuracy on MAPS.
  * [On the Potential of Simple Framewise Approaches to Piano Transcription](https://arxiv.org/abs/1612.05153) - explains the current state of the art and which architectures and input representations are most effective for framewise transcription.
  * [An Experimental Analysis of the Entanglement Problem in Neural-Network-based Music Transcription Systems](https://arxiv.org/abs/1702.00025) - explains entanglement: the problem of generalizing to note combinations the network may not have been trained on. Entanglement is the current glass ceiling for framewise neural-network music transcription. The authors present essentially one possible solution that I could try to implement: a loss function that takes entanglement into account.

* Products
  * [Melodyne](http://www.celemony.com/en/melodyne/what-is-melodyne) - popular plugin for transcription + pitch correction; costs up to $500.
  * [AnthemScore](https://www.lunaverus.com/cnn) - a music transcription product that uses deep learning.

## Design
The design for the system is as follows:
* Pre-process our data into an ingestible format: a Fourier-like transform of the audio and piano-roll conversion of the MIDI files
* Design a neural network model to estimate the current notes from audio data
  * Use frame-wise (simpler) or onsets (faster)
* Train on a large corpus of audio/MIDI pairs
* Evaluate its performance on audio/MIDI pairs we have not trained on

## Implementation
### Libraries
* Python - due to the abundance of music and machine learning libraries developed for it
* librosa - for digital signal processing methods
* pretty_midi - for MIDI manipulation methods
* TensorFlow - for neural networks

## Data
* [MAPS dataset](http://www.tsi.telecom-paristech.fr/aao/en/2010/07/08/maps-database-a-piano-database-for-multipitch-estimation-and-automatic-transcription-of-music/)
* [Large MIDI collection](https://www.reddit.com/r/WeAreTheMusicMakers/comments/3ajwe4/the_largest_midi_collection_on_the_internet/)

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os
import argparse
import json

'''
This script creates and loads a JSON structure holding the parameters that are
kept constant across preprocessing the data for, training, and testing each model.
'''


def load_config(json_fn):
    with open(json_fn, 'r') as infile:
        config = json.load(infile)
    return config

def create_config(args):
    path = os.path.join('models',args['model_name'])
    if not os.path.exists(path):
        # makedirs also creates the parent 'models' directory if it is missing
        os.makedirs(path)
    with open(os.path.join(path,'config.json'), 'w') as outfile:
        json.dump(args, outfile)


if __name__ == '__main__':
    # Set up command-line argument parsing
    parser = argparse.ArgumentParser(
        description='Create a config JSON')

    #possible types/values
    #model_name,spec_type,init_lr,lr_decay,bin_multiple,residual,filter_shape
    #baseline,cqt,1e-2,linear,36,False,some
    #new,logstft,1e-1,geo,96,True,full

    parser.add_argument('model_name',
                        help='Model name; a directory models/<model_name> is created to hold its config, data, etc.')
    parser.add_argument('spec_type',
                        help='Spectrogram Type, cqt or logstft')
    parser.add_argument('init_lr', type=float,
                        help='Initial Learning Rate')
    parser.add_argument('lr_decay',
                        help='How the Learning Rate Will Decay')
    parser.add_argument('bin_multiple', type=int,
                        help='Used to calculate bins_per_octave')
    # type=bool would turn any non-empty string (even 'False') into True,
    # so parse the flag explicitly
    parser.add_argument('residual', type=lambda s: s.lower() in ('true', '1', 'yes'),
                        help='Use Residual Connections or not')
    parser.add_argument('full_window',
                        help='Whether or not the convolution window spans the full axis')

    ''' These are all constant.
    parser.add_argument('--sr', type=int, default=22050,
                        help='Sampling Rate')
    parser.add_argument('--hl', type=int, default=512,
                        help='Hop Length')
    parser.add_argument('--ws', type=int, default=7,
                        help='Window Size')
    parser.add_argument('--bm', type=int, default=3,
                        help='Bin Multiple')
    parser.add_argument('--min', type=int, default=21, #A0
                        help='Min MIDI value')
    parser.add_argument('--max', type=int, default=108, #C8
                        help='Max MIDI value')'''

    args = vars(parser.parse_args())

    create_config(args)

--------------------------------------------------------------------------------
/data/mid/Avril14th.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsleep/wav2mid/bfc4de0b9e7bccea390d30661e5b80f03f9c5ef2/data/mid/Avril14th.mid
--------------------------------------------------------------------------------
/data/mid/Bounce.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsleep/wav2mid/bfc4de0b9e7bccea390d30661e5b80f03f9c5ef2/data/mid/Bounce.mid
--------------------------------------------------------------------------------
/data/wav/Avril14th.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsleep/wav2mid/bfc4de0b9e7bccea390d30661e5b80f03f9c5ef2/data/wav/Avril14th.wav
--------------------------------------------------------------------------------
/data/wav/Bounce.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsleep/wav2mid/bfc4de0b9e7bccea390d30661e5b80f03f9c5ef2/data/wav/Bounce.wav
--------------------------------------------------------------------------------
/examples/deep_mnist_with_summaries.py:
--------------------------------------------------------------------------------
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple MNIST classifier which displays summaries in TensorBoard.

This is an unimpressive MNIST model, but it is a good example of using
tf.name_scope to make a graph legible in the TensorBoard graph explorer, and of
naming summary tags so that they are grouped meaningfully in TensorBoard.

It demonstrates the functionality of every TensorBoard dashboard.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import math

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

FLAGS = None


def train():
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir,
                                      one_hot=True,
                                      fake_data=FLAGS.fake_data)

    #config = tf.ConfigProto(log_device_placement=True)
    sess = tf.InteractiveSession()
    # Create a multilayer model.

    # Input placeholders
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, [None, 784], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')

    with tf.name_scope('input_reshape'):
        image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
        tf.summary.image('input', image_shaped_input, 10)

    # We can't initialize these variables to 0 - the network will get stuck.
    def weight_variable(shape):
        """Create a weight variable with appropriate initialization."""
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)

    def bias_variable(shape):
        """Create a bias variable with appropriate initialization."""
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)

    def variable_summaries(var):
        """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
        with tf.name_scope('summaries'):
            mean = tf.reduce_mean(var)
            tf.summary.scalar('mean', mean)
            with tf.name_scope('stddev'):
                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
            tf.summary.scalar('stddev', stddev)
            tf.summary.scalar('max', tf.reduce_max(var))
            tf.summary.scalar('min', tf.reduce_min(var))
            tf.summary.histogram('histogram', var)

    def fc_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
        """Reusable code for making a fully connected layer.

        It does a matrix multiply, bias add, and then uses relu to nonlinearize.
        It also sets up name scoping so that the resultant graph is easy to read,
        and adds a number of summary ops.
        """
        # Adding a name scope ensures logical grouping of the layers in the graph.
        with tf.name_scope(layer_name):
            # This Variable will hold the state of the weights for the layer
            with tf.name_scope('weights'):
                weights = weight_variable([input_dim, output_dim])
                variable_summaries(weights)
            with tf.name_scope('biases'):
                biases = bias_variable([output_dim])
                variable_summaries(biases)
            with tf.name_scope('Wx_plus_b'):
                preactivate = tf.matmul(input_tensor, weights) + biases
                tf.summary.histogram('pre_activations', preactivate)
            activations = act(preactivate, name='activation')
            tf.summary.histogram('activations', activations)
            return activations

    def conv2d(x, W):
        return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

    def max_pool_2x2(x):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1], padding='SAME')

    def convnet_layer(input_tensor, input_features, output_features, layer_name, act=tf.nn.relu):
        """Reusable code for making a conv net layer.

        It does a stride-1 convolution to find features with shared weights, then
        adds a bias and applies the activation function. Finally, it does 2x2 max
        pooling to halve the height and width of the feature maps.
        """
        with tf.name_scope(layer_name):
            # conv over 5x5 patches
            with tf.name_scope('weights'):
                weights = weight_variable([5, 5, input_features, output_features])
                variable_summaries(weights)
            with tf.name_scope('biases'):
                biases = bias_variable([output_features])
                variable_summaries(biases)
            with tf.name_scope('Wx_plus_b'):
                preactivate = conv2d(input_tensor, weights) + biases
                tf.summary.histogram('pre_activations', preactivate)
            activations = act(preactivate, name='activation')
            tf.summary.histogram('activations', activations)
            with tf.name_scope('pooling'):
                pooling = max_pool_2x2(activations)
                tf.summary.histogram('pooling', pooling)
            return pooling

    # reshape the 2-d input tensor to 4-d for the convolutions
    x_image = tf.reshape(x, [-1,28,28,1])
    # first layer computes 32 features
    conv_layer1 = convnet_layer(x_image,1,32,'convnet_layer1')
    # second layer computes 64 features
    conv_layer2 = convnet_layer(conv_layer1,32,64,'convnet_layer2')
    # flatten the 4-d tensor back to 2-d
    conv_layer2_flat = tf.reshape(conv_layer2,[-1,7*7*64])
    # feed the 7x7x64 intermediate features into a 1024-unit fully connected layer
    fc_layer1 = fc_layer(conv_layer2_flat, 7*7*64, 1024, 'fc_layer1')

    # to reduce overfitting, apply dropout before the readout
    with tf.name_scope('dropout'):
        keep_prob = tf.placeholder(tf.float32)
        tf.summary.scalar('dropout_keep_probability', keep_prob)
        fc1_drop = tf.nn.dropout(fc_layer1, keep_prob)

    # readout: 1024 -> 10 classes
    y = fc_layer(fc1_drop, 1024, 10, 'output', act=tf.identity)

    # Do not apply softmax activation yet, see below.

    with tf.name_scope('cross_entropy'):
        # The raw formulation of cross-entropy,
        #
        # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
        #                               reduction_indices=[1]))
        #
        # can be numerically unstable.
        #
        # So here we use tf.nn.softmax_cross_entropy_with_logits on the
        # raw outputs of the fc_layer above, and then average across
        # the batch.
        diff = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
        with tf.name_scope('total'):
            cross_entropy = tf.reduce_mean(diff)
    tf.summary.scalar('cross_entropy', cross_entropy)

    with tf.name_scope('optimizer'):
        train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(cross_entropy)

    with tf.name_scope('accuracy'):
        with tf.name_scope('correct_prediction'):
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', accuracy)

    # Merge all the summaries and write them out to FLAGS.log_dir
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
    test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
    tf.global_variables_initializer().run()

    # Train the model, and also write summaries.
    # Every 100th step, measure test-set accuracy, and write test summaries
    # All other steps, run train_step on training data, & add training summaries

    def feed_dict(train):
        """Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
        if train or FLAGS.fake_data:
            xs, ys = mnist.train.next_batch(FLAGS.batch_size, fake_data=FLAGS.fake_data)
            k = FLAGS.dropout
        else:
            xs, ys = mnist.test.next_batch(FLAGS.batch_size)
            k = 1.0

        return {x: xs, y_: ys, keep_prob: k}

    for i in range(FLAGS.max_steps):
        if i > 0 and i % 100 == 0:  # Record summaries and test-set accuracy
            # My GPU (Surface Book 1 dGPU) can't fit the full 10000-image test set
            # in one batch, so test accuracy is measured on a mini-batch here
            summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
            test_writer.add_summary(summary, i)
            print('Mini-batch Test Accuracy at step %s: %s' % (i, acc))
        else:  # Record train set summaries, and train
            if i % 1000 == 999:  # Record execution stats
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                summary, _ = sess.run([merged, train_step],
                                      feed_dict=feed_dict(True),
                                      options=run_options,
                                      run_metadata=run_metadata)
                train_writer.add_run_metadata(run_metadata, 'step%03d' % i)
                train_writer.add_summary(summary, i)
                print('Adding run metadata for', i)
            else:  # Record a summary
                summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
                train_writer.add_summary(summary, i)
    print('num test images: %s'%(len(mnist.test.images)))
    num_test_batches = int(len(mnist.test.images) / FLAGS.batch_size)
    print('num test batches: %s'%(num_test_batches))
    acc_sum, acc_mean = 0 , 0
    for i in range(num_test_batches):
        summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
        acc_sum += acc
    acc_mean = acc_sum / num_test_batches
    print('Final Test Accuracy: %s' % (acc_mean))
    '''for i in range(FLAGS.max_steps):
        batch = mnist.train.next_batch(50)
        if i%100 == 0:
            train_accuracy = accuracy.eval(feed_dict={
                x:batch[0], y_: batch[1], keep_prob: 1.0})
            print("step %d, training accuracy %g"%(i, train_accuracy))
        train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    print("test accuracy %g"%accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))'''


    train_writer.close()
    test_writer.close()


def main(_):
    if tf.gfile.Exists(FLAGS.log_dir):
        tf.gfile.DeleteRecursively(FLAGS.log_dir)
    tf.gfile.MakeDirs(FLAGS.log_dir)
    train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--fake_data', nargs='?', const=True, type=bool,
                        default=False,
                        help='If true, uses fake data for unit testing.')
    parser.add_argument('--max_steps', type=int, default=20000,
                        help='Number of steps to run trainer.')
    parser.add_argument('--learning_rate', type=float, default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--dropout', type=float, default=0.9,
                        help='Keep probability for training dropout.')
    parser.add_argument('--data_dir', type=str, default='./data/input_data',
                        help='Directory for storing input data')
    parser.add_argument('--log_dir', type=str, default='./logs/deep_mnist_with_summaries',
                        help='Summaries log directory')
    parser.add_argument('--batch_size', type=int, default=200,
                        help='How many images to feed into network at once')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

--------------------------------------------------------------------------------
/examples/one_hot.py:
--------------------------------------------------------------------------------
"""
Simple functions for converting a PrettyMIDI object into a one-hot /
piano-roll-like encoding to be used for machine learning, and back.
"""
from __future__ import division
import pretty_midi
import numpy as np
import sys
import argparse

def pretty_midi_to_one_hot(pm, fs=100):
    """Compute a one hot matrix of a pretty midi object

    Parameters
    ----------
    pm : pretty_midi.PrettyMIDI
        A pretty_midi.PrettyMIDI class instance describing
        the piano roll.
    fs : int
        Sampling frequency of the columns, i.e. each column is spaced apart
        by ``1./fs`` seconds.

    Returns
    -------
    one_hot : np.ndarray, shape=(128,times.shape[0])
        Piano roll of this instrument. 1 represents Note Ons,
        -1 represents Note offs, 0 represents constant/do-nothing
    """

    # Build one matrix per instrument, then sum them into a single matrix
    one_hots = []

    for instrument in pm.instruments:
        one_hot = np.zeros((128, int(fs*instrument.get_end_time())+1))
        for note in instrument.notes:
            # note on
            one_hot[note.pitch, int(note.start*fs)] = 1
            print('note on',note.pitch, int(note.start*fs))
            # note off (written as 0 here, although the docstring describes -1)
            one_hot[note.pitch, int(note.end*fs)] = 0
            print('note off',note.pitch, int(note.end*fs))
        one_hots.append(one_hot)

    one_hot = np.zeros((128, np.max([o.shape[1] for o in one_hots])))
    for o in one_hots:
        one_hot[:, :o.shape[1]] += o

    one_hot = np.clip(one_hot,-1,1)
    return one_hot

def one_hot_to_pretty_midi(one_hot, fs=100, program=1,bpm=120):
    '''Convert a Piano Roll array into a PrettyMidi object
     with a single instrument.

    Parameters
    ----------
    one_hot : np.ndarray, shape=(128,time)
        Piano roll of one instrument
    fs : int
        Sampling frequency of the columns, i.e.
        each column is spaced apart by ``1./fs`` seconds.
    program : int
        The program number of the instrument.
    bpm : int
        Beats per minute, used to decide when to re-emphasize notes left on.

    Returns
    -------
    midi_object : pretty_midi.PrettyMIDI
        A pretty_midi.PrettyMIDI class instance describing
        the piano roll.

    '''
    notes, frames = one_hot.shape
    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=program)

    # pad a column of zeros on each side so initial and final events show up in the diff
    piano_roll = np.hstack((np.zeros((notes, 1)),
                            one_hot,
                            np.zeros((notes, 1))))

    # use changes to find note on / note off events
    changes = np.nonzero(np.diff(piano_roll).T)

    # keep track of note on times and notes currently playing
    note_on_time = np.zeros(notes)
    current_notes = np.zeros(notes)

    bps = bpm / 60
    beat_interval = fs / bps
    strong_beats = beat_interval * 2 #(for 4/4 timing)

    last_beat_time = 0

    for time, note in zip(*changes):
        change = piano_roll[note, time + 1]

        if time >= last_beat_time + beat_interval:
            # placeholder: re-emphasizing notes still held on each beat
            # (see bpm in the docstring) is not implemented yet
            pass

        time = time / fs

        if change == 1:
            # note on
            if current_notes[note] == 0:
                # from note off
                note_on_time[note] = time
                current_notes[note] = 1
            else:
                #re-articulate (later in code)
                '''pm_note = pretty_midi.Note(
                    velocity=100, # fixed velocity for now
                    pitch=note,
                    start=note_on_time[note],
                    end=time)
                instrument.notes.append(pm_note)
                note_on_time[note] = time
                current_notes[note] = 1'''
        elif change == 0:
            #note off
            pm_note = pretty_midi.Note(
                velocity=100, # fixed velocity for now
                pitch=note,
                start=note_on_time[note],
                end=time)
            current_notes[note] = 0
            instrument.notes.append(pm_note)
    pm.instruments.append(instrument)
    return pm

if __name__ == '__main__':
    # Set up command-line argument parsing
    parser = argparse.ArgumentParser(
        description='Translate MIDI file to piano roll and back',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('input_midi', action='store',
                        help='Path to the input MIDI file')
    parser.add_argument('output_midi', action='store',
                        help='Path where the translated MIDI will be written')
    parser.add_argument('--fs', default=100, type=int, action='store',
                        help='Sampling rate to use between conversions')
    parser.add_argument('--program', default=1, type=int, action='store',
                        help='Program of the instrument')

    parameters = vars(parser.parse_args(sys.argv[1:]))
    pm = pretty_midi.PrettyMIDI(parameters['input_midi'])
    #print(pm.instruments[0].notes)
    oh = pretty_midi_to_one_hot(pm, fs=parameters['fs'])
    new_pm = one_hot_to_pretty_midi(oh, fs=parameters['fs'],
                                    program=parameters['program'])
    #print(new_pm.instruments[0].notes)
    new_pm.write(parameters['output_midi'])

--------------------------------------------------------------------------------
/keras_train.py:
--------------------------------------------------------------------------------
'''

keras: CNN Transcription model

'''
from __future__ import print_function
import argparse

import matplotlib.pyplot as plt

#keras utils
from keras.callbacks import Callback
from keras import metrics
from keras.models import Model, load_model
from keras.layers import Dense, Dropout, Flatten, Reshape, Input
from keras.layers import Conv2D, MaxPooling2D, add
from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, CSVLogger
from keras.layers.normalization import BatchNormalization
from keras.layers import Activation
from keras.optimizers import SGD
from keras import backend as K
from keras.utils import plot_model


import tensorflow as tf
import sklearn
from sklearn.metrics import precision_recall_fscore_support

#internal utils
from preprocess import DataGen
from config import load_config

import numpy as np

import os


def opt_thresholds(y_true,y_scores):
    '''
    For each label (note), pick the score threshold that maximizes F1
    on (y_true, y_scores); labels with no positive F1 keep 0.5.
    '''
    othresholds = np.zeros(y_scores.shape[1])
    print(othresholds.shape)
    for label, (label_scores, true_bin) in enumerate(zip(y_scores.T,y_true.T)):
        #print(label)
        precision, recall, thresholds = sklearn.metrics.precision_recall_curve(true_bin, label_scores)
        max_f1 = 0
        max_f1_threshold = .5
        for r, p, t in zip(recall, precision, thresholds):
            if p + r == 0: continue
            if (2*p*r)/(p + r) > max_f1:
                max_f1 = (2*p*r)/(p + r)
                max_f1_threshold = t
        #print(label, ": ", max_f1_threshold, "=>", max_f1)
        othresholds[label] = max_f1_threshold
    print(othresholds)
    return othresholds

class linear_decay(Callback):
    '''
    decay = decay value to subtract each epoch
    '''
    def __init__(self, initial_lr,epochs):
        super(linear_decay, self).__init__()
        self.initial_lr = initial_lr
        self.decay = initial_lr/epochs

    def on_epoch_begin(self, epoch, logs={}):
        new_lr = self.initial_lr - self.decay*epoch
        print("ld: learning rate is now "+str(new_lr))
        K.set_value(self.model.optimizer.lr, new_lr)

class half_decay(Callback):
    '''
    Halve the learning rate every `period` epochs.
    '''
    def __init__(self, initial_lr,period):
        super(half_decay, self).__init__()
        self.init_lr = initial_lr
        self.period = period

    def on_epoch_begin(self, epoch, logs={}):
        factor = epoch // self.period
        lr = self.init_lr / (2**factor)
        print("hd: learning rate is now "+str(lr))
        K.set_value(self.model.optimizer.lr, lr)

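# Editorial sketch, not part of the original training code: a side-effect-free
# preview of the learning rates that the linear_decay and half_decay callbacks
# above would set. The helper name preview_lr_schedule and its arguments are
# hypothetical; it only mirrors the arithmetic in their on_epoch_begin methods
# so a schedule can be checked without building or compiling a Keras model.
def preview_lr_schedule(init_lr, num_epochs, mode='linear', epochs=1000, period=5):
    rates = []
    for epoch in range(num_epochs):
        if mode == 'linear':
            # linear_decay: subtract init_lr / epochs once per elapsed epoch
            rates.append(init_lr - (init_lr / float(epochs)) * epoch)
        else:
            # half_decay: halve the rate after every `period` epochs
            rates.append(init_lr / (2 ** (epoch // period)))
    return rates

# Example: preview_lr_schedule(0.1, 11, mode='half') gives 0.1 for epochs 0-4,
# 0.05 for epochs 5-9 and 0.025 at epoch 10, matching half_decay(0.1, 5).
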
class Threshold(Callback):
    '''
    After each epoch, pick per-note thresholds that maximize F1 on val_data
    and print the micro-averaged validation precision, recall and F-score.
    '''
    def __init__(self, val_data):
        super(Threshold, self).__init__()
        self.val_data = val_data
        _,y = val_data
        self.othresholds = np.full(y.shape[1],0.5)

    def on_epoch_end(self, epoch, logs={}):
        #find optimal thresholds on validation data
        x,y_true = self.val_data
        y_scores = self.model.predict(x)
        self.othresholds = opt_thresholds(y_true,y_scores)
        y_pred = y_scores > self.othresholds
        p,r,f,s = sklearn.metrics.precision_recall_fscore_support(y_true,y_pred,average='micro')
        print("validation p,r,f,s:")
        print(p,r,f,s)

def baseline_model():
    inputs = Input(shape=input_shape)
    reshape = Reshape(input_shape_channels)(inputs)

    #first conv layer: 50 filters over 5x25 (frames x frequency bins) patches
    conv1 = Conv2D(50,(5,25),activation='tanh')(reshape)
    do1 = Dropout(0.5)(conv1)
    pool1 = MaxPooling2D(pool_size=(1,3))(do1)

    conv2 = Conv2D(50,(3,5),activation='tanh')(pool1)
    do2 = Dropout(0.5)(conv2)
    pool2 = MaxPooling2D(pool_size=(1,3))(do2)

    flattened = Flatten()(pool2)
    fc1 = Dense(1000, activation='sigmoid')(flattened)
    do3 = Dropout(0.5)(fc1)

    fc2 = Dense(200, activation='sigmoid')(do3)
    do4 = Dropout(0.5)(fc2)
    outputs = Dense(note_range, activation='sigmoid')(do4)

    model = Model(inputs=inputs, outputs=outputs)
    return model



def resnet_model(bin_multiple):

    #input and reshape
    inputs = Input(shape=input_shape)
    reshape = Reshape(input_shape_channels)(inputs)

    #normal convnet layer (have to do one initially to get 64 channels)
    conv = Conv2D(64,(1,bin_multiple*note_range),padding="same",activation='relu')(reshape)
    pool = MaxPooling2D(pool_size=(1,2))(conv)

    for i in range(int(np.log2(bin_multiple))-1):
        print(i)
        # residual block
        bn = BatchNormalization()(pool)
        re = Activation('relu')(bn)
        # integer division keeps the kernel width an int under Python 3 as well
        freq_range = (bin_multiple // (2**(i+1))) * note_range
        print(freq_range)
        conv = Conv2D(64, (1, freq_range), padding="same", activation='relu')(re)

        # add and downsample
        ad = add([pool, conv])
        pool = MaxPooling2D(pool_size=(1, 2))(ad)

    flattened = Flatten()(pool)
    fc = Dense(1024, activation='relu')(flattened)
    do = Dropout(0.5)(fc)
    fc = Dense(512, activation='relu')(do)
    do = Dropout(0.5)(fc)
    outputs = Dense(note_range, activation='sigmoid')(do)

    model = Model(inputs=inputs, outputs=outputs)

    return model

window_size = 7
min_midi = 21
max_midi = 108
note_range = max_midi - min_midi + 1


def train(args):
    path = os.path.join('models', args['model_name'])
    config = load_config(os.path.join(path, 'config.json'))

    global feature_bins
    global input_shape
    global input_shape_channels

    bin_multiple = int(args['bin_multiple'])
    print('bin multiple: {} (log2 = {})'.format(bin_multiple, np.log2(bin_multiple)))
    feature_bins = note_range * bin_multiple
    input_shape = (window_size, feature_bins)
    input_shape_channels = (window_size, feature_bins, 1)

    # filenames
    model_ckpt = os.path.join(path, 'ckpt.h5')

    # train params
    batch_size = 256
    epochs = 1000

    trainGen = DataGen(os.path.join(path, 'data', 'train'), batch_size, args)
    valGen = DataGen(os.path.join(path, 'data', 'val'), batch_size, args)
    #valData = load_data(os.path.join(path,'data','val'))


    if os.path.isfile(model_ckpt):
        print('loading model')
        model = load_model(model_ckpt)
    else:
        print('training new model from scratch')
        # args from models.csv are strings, and bool('False') is True, so compare the text
        if str(args['residual']).lower() == 'true':
            model = resnet_model(bin_multiple)
        else:
            model = baseline_model()

    init_lr = float(args['init_lr'])

    model.compile(loss='binary_crossentropy',
                  optimizer=SGD(lr=init_lr, momentum=0.9))
    model.summary()
    plot_model(model, to_file=os.path.join(path, 'model.png'))

    checkpoint = ModelCheckpoint(model_ckpt, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
    early_stop = EarlyStopping(patience=5, monitor='val_loss', verbose=1, mode='min')
    #tensorboard = TensorBoard(log_dir='./logs/baseline/', histogram_freq=250, batch_size=batch_size)
    if args['lr_decay'] == 'linear':
        decay = linear_decay(init_lr, epochs)
    else:
        decay = half_decay(init_lr, 5)
    csv_logger = CSVLogger(os.path.join(path, 'training.log'))
    #t = Threshold(valData)
    callbacks = [checkpoint, early_stop, decay, csv_logger]

    history = model.fit_generator(trainGen.next(), trainGen.steps(), epochs=epochs,
                                  verbose=1, validation_data=valGen.next(),
                                  validation_steps=valGen.steps(), callbacks=callbacks)

    # list all data in history
    print(history.history.keys())
    # summarize history for accuracy
    '''plt.plot(history.history['acc'])
    plt.plot(history.history['val_acc'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.savefig('baseline/acc.png')'''

    # summarize history for loss (saved in this model's directory)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.savefig(os.path.join(path, 'loss.png'))

    # test
    testGen = DataGen(os.path.join(path, 'data', 'test'), batch_size, args)

    res = model.evaluate_generator(testGen.next(), steps=testGen.steps())
    print(model.metrics_names)
    print(res)

def main():
    # train
    parser = argparse.ArgumentParser(
        description='Train a transcription model on preprocessed MIDI/Audio data')
    parser.add_argument('model_name',
                        help='Name of the model directory under models/ where the preprocessed data resides')
    # the options below mirror the columns of models.csv; defaults follow the 'baseline' row
    parser.add_argument('--bin_multiple', type=int, default=3)
    parser.add_argument('--residual', action='store_true')
    parser.add_argument('--init_lr', type=float, default=1e-2)
    parser.add_argument('--lr_decay', choices=['linear', 'half'], default='linear')

    args = vars(parser.parse_args())
    train(args)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/models.csv:
--------------------------------------------------------------------------------
model_name,spec_type,init_lr,lr_decay,bin_multiple,residual
new,cqt,1e-1,half,4,True
baseline,cqt,1e-2,linear,3,False
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
from __future__ import division, print_function

from collections import defaultdict
import sys, os
import argparse

import madmom
import numpy as np
import pandas as pd
import pretty_midi
import librosa
import h5py
import math

from config import load_config

# dtypes for the memory-mapped data files; np.memmap defaults to uint8,
# which would truncate the float CQT magnitudes, so both writer and reader use these
INPUT_DTYPE = np.float32
OUTPUT_DTYPE = np.uint8

def readmm(d, args):
    ipath = os.path.join(d, 'input.dat')
    note_range = 88
    n_bins = int(args['bin_multiple']) * note_range
    window_size = 7
    mmi = np.memmap(ipath, mode='r', dtype=INPUT_DTYPE)
    i = np.reshape(mmi, (-1, window_size, n_bins))
    opath = os.path.join(d, 'output.dat')
    mmo = np.memmap(opath, mode='r', dtype=OUTPUT_DTYPE)
    o = np.reshape(mmo, (-1, note_range))
    return i, o

class DataGen:
    def __init__(self, dirpath, batch_size, args, num_files=1):
        print('initializing gen for ' + dirpath)

        self.mmdirs = os.listdir(dirpath)
        self.spe = 0  # steps per epoch
        self.dir = dirpath
        self.args = args  # kept so next() can open further files

        for mmdir in self.mmdirs:
            print(mmdir)
            _, outputs = readmm(os.path.join(self.dir, mmdir), args)
            self.spe += len(outputs) // batch_size
            #print cnt
        self.num_files = num_files

        self.batch_size = batch_size
        self.current_file_idx = 0
        print('starting with', self.mmdirs[self.current_file_idx:self.current_file_idx+self.num_files])
        self._load_files()
        self.i = 0

    def _load_files(self):
        # load num_files memmap directories starting at current_file_idx (wrapping
        # around the list), then advance the index for the next switch
        for j in range(self.num_files):
            idx = (self.current_file_idx + j) % len(self.mmdirs)
            mmdir = os.path.join(self.dir, self.mmdirs[idx])
            i, o = readmm(mmdir, self.args)
            if j == 0:
                self.inputs, self.outputs = i, o
            else:
                self.inputs = np.concatenate((self.inputs, i))
                self.outputs = np.concatenate((self.outputs, o))
        self.current_file_idx = (self.current_file_idx + self.num_files) % len(self.mmdirs)

    def steps(self):
        return self.spe

    def next(self):
        while True:
            start = self.i * self.batch_size
            end = start + self.batch_size
            if end > self.inputs.shape[0]:
                # return the rest and then switch files
                x, y = self.inputs[start:], self.outputs[start:]
                self.i = 0
                if len(self.mmdirs) > 1:  # no need to open any new files if we only deal with one, like for validation
                    print('switching to', self.mmdirs[self.current_file_idx:self.current_file_idx+self.num_files])
                    self._load_files()
            else:
                x, y = self.inputs[start:end], self.outputs[start:end]
                self.i += 1
            if len(x) > 0:  # skip the empty tail when the data divides evenly into batches
                yield x, y

'''def load_data(dirpath):
    print('loading data from '+dirpath)
    hdf5_file = os.listdir(dirpath)[0]
    with h5py.File(os.path.join(dirpath,hdf5_file), 'r') as hf:
        inputs = hf['-inputs'][:]
        outputs = hf['-outputs'][:]
    return inputs,outputs'''


sr = 22050
hop_length = 512
window_size = 7
min_midi = 21
max_midi = 108


def wav2inputnp(audio_fn, spec_type='cqt', bin_multiple=3):
    print("wav2inputnp")
    bins_per_octave = 12 * bin_multiple  # should be a multiple of 12
    n_bins = (max_midi - min_midi + 1) * bin_multiple

    # down-sample, mono-channel
    y, _ = librosa.load(audio_fn, sr=sr)
    S = librosa.cqt(y, fmin=librosa.midi_to_hz(min_midi), sr=sr, hop_length=hop_length,
                    bins_per_octave=bins_per_octave, n_bins=n_bins)
    S = S.T

    #TODO: LogScaleSpectrogram?
    '''
    if spec_type == 'cqt':
        #down-sample,mono-channel
        y,_ = librosa.load(audio_fn,sr=sr)
        S = librosa.cqt(y,fmin=librosa.midi_to_hz(min_midi), sr=sr, hop_length=hop_length,
                          bins_per_octave=bins_per_octave, n_bins=n_bins)
        S = S.T
    else:
        #down-sample,mono-channel
        y = madmom.audio.signal.Signal(audio_fn, sample_rate=sr, num_channels=1)
        S = madmom.audio.spectrogram.LogarithmicFilteredSpectrogram(y,fmin=librosa.midi_to_hz(min_midi),
                                            hop_size=hop_length, num_bands=bins_per_octave, fft_size=4096)'''

    #S = librosa.amplitude_to_db(S)
    S = np.abs(S)

    # padding value: the smallest magnitude (not dB, since amplitude_to_db is disabled above)
    minDB = np.min(S)

    print(np.min(S), np.max(S), np.mean(S))

    S = np.pad(S, ((window_size//2, window_size//2), (0, 0)), 'constant', constant_values=minDB)

    windows = []

    # IMPORTANT NOTE:
    # Since we pad the spectrogram with window_size//2 frames on each side,
    # padded frame i corresponds to original frame i - window_size//2.
    # To obtain the window centred on each original frame i, take the slice
    # [i, i+window_size) of the padded spectrogram, starting at padded frame 0.
    for i in range(S.shape[0]-window_size+1):
        w = S[i:i+window_size, :]
        windows.append(w)


    #print inputs
    x = np.array(windows)
    return x

def mid2outputnp(pm_mid, times):
    # binary piano roll sampled at the given frame times, restricted to the 88 piano keys
    piano_roll = pm_mid.get_piano_roll(fs=sr, times=times)[min_midi:max_midi+1].T
    piano_roll[piano_roll > 0] = 1
    return piano_roll



def joinAndCreate(basePath, new):
    newPath = os.path.join(basePath, new)
    if not os.path.exists(newPath):
        os.mkdir(newPath)
    return newPath

def isSplitFolder(ddir):
    return ddir == 'train' or ddir == 'test' or ddir == 'val'

def organize(args):
    # data folders whose name starts with testPrefix go to test, the first
    # valCnt other folders (in sorted order) go to val, and the rest go to train
    valCnt = 1
    testPrefix = 'ENS'

    path = os.path.join('models', args['model_name'])
    dpath = os.path.join(path, 'data')

    train_path = joinAndCreate(dpath, 'train')
    test_path = joinAndCreate(dpath, 'test')
    val_path = joinAndCreate(dpath, 'val')

    for ddir in sorted(os.listdir(dpath)):
        if os.path.isdir(os.path.join(dpath, ddir)) and not isSplitFolder(ddir):
            if ddir.startswith(testPrefix):
                os.rename(os.path.join(dpath, ddir), os.path.join(test_path, ddir))
            elif valCnt > 0:
                os.rename(os.path.join(dpath, ddir), os.path.join(val_path, ddir))
                valCnt -= 1
            else:
                os.rename(os.path.join(dpath, ddir), os.path.join(train_path, ddir))


data_dir = '../maps/'
def preprocess(args):
    # params
    path = os.path.join('models', args['model_name'])
    config = load_config(os.path.join(path, 'config.json'))

    bin_multiple = int(args['bin_multiple'])
    spec_type = args['spec_type']
    # use the data_dir given on the command line; runs.py passes none, so fall back to the default
    root_dir = args.get('data_dir') or data_dir

    framecnt = 0

    # hack to deal with high PPQ from MAPS
    # https://github.com/craffel/pretty-midi/issues/112
    pretty_midi.pretty_midi.MAX_TICK = 1e10


    for s in os.listdir(root_dir):
        subdir = os.path.join(root_dir, s)
        if not os.path.isdir(subdir):
            continue
        # recursively search in subdir
        print(subdir)
        inputs, outputs = [], []
        addCnt, errCnt = 0, 0
        for dp, dn, filenames in os.walk(subdir):
            # in each level of the directory, look at filenames ending with .wav
            for f in filenames:
                # if there exists a .wav file and a .mid file with the same name
                if f.endswith('.wav'):
                    audio_fn = f
                    fprefix = audio_fn.split('.wav')[0]
                    mid_fn = fprefix + '.mid'
                    txt_fn = fprefix + '.txt'
                    if mid_fn in filenames:
                        # wav2inputnp
                        audio_fn = os.path.join(dp, audio_fn)
                        # mid2outputnp
                        mid_fn = os.path.join(dp, mid_fn)

                        pm_mid = pretty_midi.PrettyMIDI(mid_fn)

                        inputnp = wav2inputnp(audio_fn, spec_type=spec_type, bin_multiple=bin_multiple)
                        times = librosa.frames_to_time(np.arange(inputnp.shape[0]), sr=sr, hop_length=hop_length)
                        outputnp = mid2outputnp(pm_mid, times)

                        # check that the number of frames is equal
                        if inputnp.shape[0] == outputnp.shape[0]:
                            print("adding to dataset fprefix {}".format(fprefix))
                            addCnt += 1
                            framecnt += inputnp.shape[0]
                            print("framecnt is {}".format(framecnt))
                            inputs.append(inputnp)
                            outputs.append(outputnp)
                        else:
                            print("error for fprefix {}".format(fprefix))
                            errCnt += 1
                            print(inputnp.shape)
                            print(outputnp.shape)

        print("{} examples in dataset".format(addCnt))
        print("{} examples couldn't be processed".format(errCnt))


        # concatenate the per-file arrays into one array of examples
        if addCnt:
            inputs = np.concatenate(inputs)
            outputs = np.concatenate(outputs)

            fn = subdir.split('/')[-1]
            if not fn:
                fn = subdir.split('/')[-2]
            #fn += '.h5'
            # save inputs,outputs to memory-mapped .dat files
            datapath = joinAndCreate(path, 'data')
            fnpath = joinAndCreate(datapath, fn)

            mmi = np.memmap(filename=os.path.join(fnpath, 'input.dat'), mode='w+',
                            shape=inputs.shape, dtype=INPUT_DTYPE)
            mmi[:] = inputs[:]
            mmo = np.memmap(filename=os.path.join(fnpath, 'output.dat'), mode='w+',
                            shape=outputs.shape, dtype=OUTPUT_DTYPE)
            mmo[:] = outputs[:]
            # write to disk and release the memmaps
            mmi.flush()
            mmo.flush()
            del mmi
            del mmo

            '''with h5py.File(os.path.join(datapath,fn), 'w') as hf:
                hf.create_dataset("-inputs", data=inputs)
                hf.create_dataset("-outputs", data=outputs)

            without dB, i'm just going to not worry about feature scaling
            if args.zn:
                nppath = os.path.join(path,'xn')
                if os.path.isfile(nppath+'.npz'):
                    npzfile = np.load(nppath+'.npz')
                    x,x2,n = npzfile['x'],npzfile['x2'],npzfile['n']
                else:
                    x,x2,n = 0,0,0


                x += np.sum(inputs,axis=0)
                x2 += np.sum(inputs**2,axis=0)
                n += inputs.shape[0]

                print x,x2,n

                print 'mean={}'.format(x/n)

                print 'var={}'.format(x2/n-(x/n)**2)

                np.savez(nppath,x=x,x2=x2,n=n)'''


if __name__ == '__main__':
    # Set up command-line argument parsing
    parser = argparse.ArgumentParser(
        description='Preprocess MIDI/Audio file pairs into ingestible data')

    parser.add_argument('model_name',
                        help='Model name; the config is read from models/<model_name> and the preprocessed data is saved there')

    parser.add_argument('data_dir',
                        help='Path to the data dir; each subdirectory is searched recursively for .wav/.mid pairs and saved as one dataset named after it')

    # the options below mirror the columns of models.csv; defaults follow the 'baseline' row
    parser.add_argument('--bin_multiple', type=int, default=3)
    parser.add_argument('--spec_type', default='cqt')

    parser.add_argument('--no-zn', dest='zn', action='store_false')
    parser.set_defaults(zn=True)

    args = vars(parser.parse_args())

    preprocess(args)
--------------------------------------------------------------------------------
/runs.py:
--------------------------------------------------------------------------------
import csv
from config import create_config
from preprocess import preprocess, organize
from keras_train import train

# open in text mode, which the csv module requires on Python 3
with open('models.csv') as csvfile:
    reader = csv.reader(csvfile)
    header = next(reader)
    #print header
    for row in reader:
        args = dict(zip(header, row))
        print(args['model_name'])
        create_config(args)
        preprocess(args)
        organize(args)
        train(args)
--------------------------------------------------------------------------------
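`opt_thresholds` in keras_train.py picks, for every note, the cut-off on the network's sigmoid outputs that maximises F1 on validation data, using sklearn's precision-recall curve. The sketch below restates that idea in plain NumPy so the arithmetic is visible. `f1_optimal_thresholds` is a hypothetical helper, not part of the repository; it brute-forces every distinct score as a candidate threshold (a frame is positive when its score is >= the threshold, matching sklearn's convention) and uses F1 = 2TP / (2TP + FP + FN). This is quadratic in the number of frames, so treat it as an illustration rather than a replacement.

```python
import numpy as np


def f1_optimal_thresholds(y_true, y_scores, default=0.5):
    """Per-label threshold that maximises F1 (hypothetical helper, illustration only).

    y_true:   (n_frames, n_labels) array of 0/1 targets
    y_scores: (n_frames, n_labels) array of predicted probabilities
    Labels for which no threshold gives a positive F1 keep `default`.
    """
    n_labels = y_scores.shape[1]
    thresholds = np.full(n_labels, default, dtype=float)
    for label in range(n_labels):
        truth = y_true[:, label].astype(bool)
        scores = y_scores[:, label]
        best_f1 = 0.0
        # every distinct score is a candidate cut-off: predict positive if score >= t
        for t in np.unique(scores):
            pred = scores >= t
            tp = np.sum(pred & truth)
            fp = np.sum(pred & ~truth)
            fn = np.sum(~pred & truth)
            if tp == 0:
                continue  # F1 is 0 without true positives
            f1 = 2.0 * tp / (2 * tp + fp + fn)
            if f1 > best_f1:  # strict '>' keeps the lowest threshold among ties
                best_f1, thresholds[label] = f1, t
    return thresholds


if __name__ == "__main__":
    y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
    y_scores = np.array([[0.9, 0.2], [0.4, 0.7], [0.6, 0.8], [0.1, 0.3]])
    print(f1_optimal_thresholds(y_true, y_scores))  # -> [0.6 0.7]
```

On this toy input each column's chosen cut-off separates its positives from its negatives exactly, which is why F1 reaches 1.0 for both labels.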