├── .gitignore ├── README.md ├── convert_format.rb ├── data_util.py ├── file_util.py ├── generate_audio.ipynb ├── main.py ├── midi_util.py ├── midifile.rb ├── model.py └── pianoify.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | **/ 3 | .DS_Store 4 | .nfs* 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StyleNet 2 | 3 | A cute multi-layer LSTM network that can perform like a human 🎶 It learns the dynamics of music! The architecture was specifically designed to handle music of different genres. 4 | 5 | If you wish to learn more about my findings, then please read my [blog post](http://imanmalik.com/cs/2017/06/05/neural-style.html) and paper: 6 | 7 | > **Iman Malik, Carl Henrik Ek, [*"Neural Translation of Musical Style"*](https://arxiv.org/abs/1708.03535), 2017.** 8 | 9 | ![GitHub Logo](http://imanmalik.com/assets/img/stylenet.png) 10 | 11 | 12 | 13 | ## Prerequisites 14 | You will need a few things in order to get started. 15 | 16 | 1. TensorFlow 17 | 2. mido 18 | 3. pretty_midi 19 | 4. fluidsynth 20 | 21 | ## The Piano Dataset 22 | I created my own dataset for the model. If you wish to use the Piano Dataset 🎹 for academic purposes, you can download it from [here](http://imanmalik.com/assets/dataset/TPD.zip). The Piano Dataset is distributed with a [CC-BY 4.0 license](https://creativecommons.org/licenses/by/4.0/). If you use this dataset, please reference this [paper](https://arxiv.org/abs/1708.03535): 23 | > Iman Malik, Carl Henrik Ek, *"Neural Translation of Musical Style"*, 2017. 24 | 25 | 26 | ## How to Run 27 | ``` python main.py -current_run <run_name> -bi ``` 28 | 29 | Flags: 30 | `-load_last` : Loads and continues from last epoch. 31 | `-load_model`: Loads specified model. 32 | `-data_dir` : Directory of datasets. 33 | `-data_set` : Dataset name. 34 | `-runs_dir` : Directory of session files.
`-forward_only` : For making predictions (not training).
`-bi` : If you wish to use bi-directional LSTMs. (HIGHLY recommended)
`-current_run` : Name of this run (required); also used as the name of its session folder inside `-runs_dir`.

## Files
`pianoify.ipynb` : This was used to ensure the files across the dataset were consistent in their musical properties.
`generate_audio.ipynb` : This was used to make predictions using StyleNet and generate the audio.
`convert_format.rb` : This was used to convert format 1 MIDIs into format 0.
`file_util.py` : This contains folder/file-handling functions.
`midi_util.py` : This contains MIDI-handling functions.
`model.py` : StyleNet's Class.
`data_util.py` : For shuffling and batching data during training.

--------------------------------------------------------------------------------
/convert_format.rb:
--------------------------------------------------------------------------------
#! /usr/bin/env ruby

# Script to generate a Format 0 midifile from a Format 1 source.
#
# Usage: convert_format.rb <input_dir> <output_dir>
# Every "*.mid" file directly inside <input_dir> has all of its events moved
# onto track 0 and is written to <output_dir> under the same base name.

require_relative 'midifile.rb'

if ARGV.length != 2
  puts "ARGS: <input_dir> <output_dir>"
  exit 1
end

Dir.glob(File.join(ARGV[0], "*.mid")) do |item|
  puts item
  out = Midifile.new
  # MIDI files are binary: open them in binary mode. File.open (unlike
  # Kernel#open) never treats a path starting with "|" as a command.
  File.open(item, 'rb') do |f|
    mr = Midifile.new f
    out.format = 0
    mr.each do |ev|
      # Events that carry a track number are moved to track 0.
      ev.trkno = 0 if ev.trkno
      out.add(ev)
    end
  end
  # File.exists? was removed in Ruby 3.2; File.exist? is the supported name.
  Dir.mkdir(ARGV[1]) unless File.exist?(ARGV[1])
  out_name = File.join(ARGV[1], File.basename(item, '.mid') + ".mid")
  puts out_name
  File.open(out_name, "wb") do |fw|
    # The stream is only written when out.vet() is truthy.
    out.to_stream(fw) if out.vet()
  end
end

--------------------------------------------------------------------------------
/data_util.py:
--------------------------------------------------------------------------------
import numpy as np

class BatchGenerator(object):
    '''Generator for returning shuffled, zero-padded batches.

    data_x -- list of input matrices (one row per time step, input_size wide)
    data_y -- list of output matrices (one row per time step, output_size wide)
    batch_size -- number of sequences per batch
    input_size -- input width
    output_size -- output width
    mini -- split each padded sequence into subsequences for truncated backprop
    mini_len -- truncated backprop window (subsequence length)'''

    def __init__(self, data_x, data_y, batch_size, input_size, output_size, mini=True, mini_len=200):
        self.input_size = input_size
        self.output_size = output_size
        self.data_x = data_x
        self.data_y = data_y
        self.batch_size = batch_size
        # Number of batches per epoch (the last one may be smaller).
        self.batch_count = len(range(0, len(self.data_x), self.batch_size))
        self.batch_length = None
        self.mini = mini
        self.mini_len = mini_len

    def batch(self):
        '''Yield (inputs, outputs, seq_len) batches forever, reshuffling every epoch.

        Raises ValueError on the first next() when there is no data; without
        this check the generator would loop forever without yielding.'''
        if len(self.data_x) == 0:
            raise ValueError('BatchGenerator has no data to batch.')
        while True:
            num_seqs = len(self.data_x)
            idxs = np.arange(0, num_seqs)
            np.random.shuffle(idxs)
            shuff_x = [self.data_x[i] for i in idxs]
            shuff_y = [self.data_y[i] for i in idxs]

            for start in range(0, num_seqs, self.batch_size):
                stop = min(start + self.batch_size, num_seqs)
                # pad() returns (inputs, outputs, seq_len).
                yield self.pad(shuff_x[start:stop], shuff_y[start:stop])

    def pad(self, sequence_X, sequence_Y):
        '''Zero-pad a batch to a common length and stack it into 3-D arrays.

        Each sequence is padded with zero rows up to the longest input sequence
        in the batch (outputs get the same number of padding rows as their
        input). With mini=True the length is further rounded up to a multiple
        of mini_len and every sequence is cut into consecutive mini_len-long
        subsequences. The padded 2-D arrays are also written back into
        sequence_X and sequence_Y (both lists are modified in place).

        Returns (X, Y, seq_len): X has shape (num_subsequences, seq_len,
        input_size), Y has shape (num_subsequences, seq_len, output_size), and
        seq_len is mini_len when mini=True, otherwise the longest length.'''
        current_batch = len(sequence_X)
        padding_X = [0] * self.input_size
        padding_Y = [0] * self.output_size

        lens = [x.shape[0] for x in sequence_X]
        max_lens = max(lens)
        padded_len = max_lens
        if self.mini:
            # Round up to the next multiple of the backprop window.
            padded_len += (-max_lens) % self.mini_len

        for i, length in enumerate(lens):
            extra = padded_len - length
            sequence_X[i] = np.array(list(sequence_X[i]) + [padding_X] * extra)
            sequence_Y[i] = np.array(list(sequence_Y[i]) + [padding_Y] * extra)

        sequence_X = np.vstack([np.expand_dims(x, 1) for x in sequence_X])
        sequence_Y = np.vstack([np.expand_dims(y, 1) for y in sequence_Y])

        if self.mini:
            # Floor division: reshape() needs integer dimensions.
            mini_batches = padded_len // self.mini_len
            seq_len = self.mini_len
        else:
            mini_batches = 1
            seq_len = max_lens

        sequence_X = np.reshape(sequence_X, [current_batch * mini_batches, seq_len, self.input_size])
        sequence_Y = np.reshape(sequence_Y, [current_batch * mini_batches, seq_len, self.output_size])

        return sequence_X, sequence_Y, seq_len

--------------------------------------------------------------------------------
/file_util.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from mido import MidiFile
from midi_util import *


def _split_dir(path):
    '''Return (parent, name) of a directory path, ignoring a trailing separator.'''
    path_prefix, path_suffix = os.path.split(path)
    # Handle case where a trailing / requires two splits.
    if len(path_suffix) == 0:
        path_prefix, path_suffix = os.path.split(path_prefix)
    return path_prefix, path_suffix


def _is_midi(filename):
    '''True for file names with a "mid" or "MID" extension.'''
    return filename.split('.')[-1] in ('mid', 'MID')


def validate_data(path, quant):
    '''Creates a folder "<path>_valid" containing the valid MIDI files.

    A file is kept when it parses, its first track holds exactly one 4/4 time
    signature and it can be quantised at the given level. Kept files are saved
    unmodified (not quantised), all directly inside "<path>_valid" regardless
    of the sub-folder they came from.

    Arguments:
    path -- Original directory containing untouched midis.
    quant -- Level of quantisation'''
    path_prefix, path_suffix = _split_dir(path)

    total_file_count = 0
    processed_count = 0

    base_path_out = os.path.join(path_prefix, path_suffix + '_valid')

    for root, dirs, files in os.walk(path):
        for file in files:
            if not _is_midi(file):
                continue
            total_file_count += 1
            print('Processing ' + str(file))
            midi_path = os.path.join(root, file)
            try:
                midi_file = MidiFile(midi_path)
            except (KeyError, IOError, TypeError, IndexError, EOFError, ValueError):
                print("Bad MIDI.")
                continue
            time_sig_msgs = [msg for msg in midi_file.tracks[0] if msg.type == 'time_signature']

            if len(time_sig_msgs) == 1:
                time_sig = time_sig_msgs[0]
                if not (time_sig.numerator == 4 and time_sig.denominator == 4):
                    print('\tTime signature not 4/4. Skipping ...')
                    continue
            else:
                print('\tNo time signature. Skipping ...')
                continue

            # quantize() works on a deep copy, so midi_file is saved unchanged.
            try:
                mid = quantize(midi_file, quant)
            except ValueError:
                # get_note_track() raises ValueError when no track has notes.
                mid = None
            if mid is None:
                print('Invalid MIDI. Skipping...')
                continue

            if not os.path.exists(base_path_out):
                os.makedirs(base_path_out)

            out_file = os.path.join(base_path_out, file)

            print('\tSaving ' + out_file)
            midi_file.save(out_file)
            processed_count += 1

    print('\nProcessed {} files out of {}'.format(processed_count, total_file_count))


def quantize_data(path, quant):
    '''Creates a folder "<path>_quantized" containing the quantised MIDI files,
    mirroring the sub-folder layout of path.

    Arguments:
    path -- Validated directory containing midis.
    quant -- Level of quantisation'''
    path_prefix, path_suffix = _split_dir(path)

    total_file_count = 0
    processed_count = 0

    base_path_out = os.path.join(path_prefix, path_suffix + '_quantized')
    for root, dirs, files in os.walk(path):
        for file in files:
            if not _is_midi(file):
                continue
            total_file_count += 1
            mid = quantize(MidiFile(os.path.join(root, file)), quant)
            if mid is None:
                print('Invalid MIDI. Skipping...')
                continue
            # relpath is robust even when the text of `path` reappears inside root.
            out_dir = os.path.normpath(os.path.join(base_path_out, os.path.relpath(root, path)))
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            out_file = os.path.join(out_dir, file)

            print('Saving ' + out_file)
            mid.save(out_file)

            processed_count += 1

    print('Processed {} files out of {}'.format(processed_count, total_file_count))


def save_data(path, quant, one_hot=True):
    '''Converts every MIDI file under path into numpy arrays and saves them.

    The note array of "<file>" is saved as "<path>_inputs/<file>.npy" and its
    velocity array as "<path>_velocities/<file>.npy".

    Arguments:
    path -- Quantised directory containing midis.
    quant -- Level of quantisation
    one_hot -- convert with midi_to_array_one_hot (True) or midi_to_array (False)'''
    path_prefix, path_suffix = _split_dir(path)

    array_out = os.path.join(path_prefix, path_suffix + '_inputs')
    velocity_out = os.path.join(path_prefix, path_suffix + '_velocities')

    total_file_count = 0
    processed_count = 0

    for root, dirs, files in os.walk(path):
        for file in files:
            if not _is_midi(file):
                continue
            total_file_count += 1

            out_array = '{}.npy'.format(os.path.join(array_out, file))
            out_velocity = '{}.npy'.format(os.path.join(velocity_out, file))

            print('Processing ' + str(file))
            mid = MidiFile(os.path.join(root, file))

            if one_hot:
                try:
                    array, velocity_array = midi_to_array_one_hot(mid, quant)
                except (KeyError, TypeError, IOError, IndexError, EOFError, ValueError):
                    # TypeError also covers midi_to_array_one_hot returning None.
                    print("Out of bounds")
                    continue
            else:
                # NOTE(review): midi_to_array appears to be commented out in
                # midi_util, so this branch presumably raises NameError — confirm.
                array, velocity_array = midi_to_array(mid, quant)

            if not os.path.exists(array_out):
                os.makedirs(array_out)

            if not os.path.exists(velocity_out):
                os.makedirs(velocity_out)

            print('Saving ' + out_array)

            np.save(out_array, array)
            np.save(out_velocity, velocity_array)

            processed_count += 1

    print('\nProcessed {} files out of {}'.format(processed_count, total_file_count))


def _load_npy_dir(dir_path):
    '''Load every .npy file in dir_path, in sorted file-name order.'''
    return [np.array(np.load(os.path.join(dir_path, filename)))
            for filename in sorted(os.listdir(dir_path))
            if filename.split('.')[-1] == 'npy']


def load_data(path):
    '''Returns lists of input and output numpy matrices.

    Inputs are read from "<path>_inputs" and outputs from "<path>_labels".
    Both folders are read in sorted file-name order, so files with the same
    name in the two folders end up at the same list index.

    Arguments:
    path -- Quantised directory path.'''
    path_prefix, path_suffix = _split_dir(path)

    x_path = os.path.join(path_prefix, path_suffix + "_inputs")
    # NOTE(review): save_data writes its outputs to "_velocities", not
    # "_labels"; confirm which folder is meant here.
    y_path = os.path.join(path_prefix, path_suffix + "_labels")

    return _load_npy_dir(x_path), _load_npy_dir(y_path)

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import tensorflow as tf
import os
import numpy as np
from model import GenreLSTM


parser = argparse.ArgumentParser(description='How to run this')

parser.add_argument(
    "-current_run",
    type=str,
    required=True,
    help="The name of the model which will also be the name of the session's folder."
)

parser.add_argument(
    "-data_dir",
    type=str,
    default="./data",
    help="Directory of datasets"
)

parser.add_argument(
    "-data_set",
    type=str,
    default="test",
    help="The name of training dataset"
)

parser.add_argument(
    "-runs_dir",
    type=str,
    default="./runs",
    help="Directory in which the session folder of each run is created"
)

parser.add_argument(
    "-bi",
    help="True for bidirectional",
    action='store_true'
)

parser.add_argument(
    "-forward_only",
    action='store_true',
    help="True for forward only, False for training [False]"
)

parser.add_argument(
    "-load_model",
    type=str,
    default=None,
    help="Folder name of model to load"
)

parser.add_argument(
    "-load_last",
    action='store_true',
    help="Start from last epoch"
)

args = parser.parse_args()


def setup_dir():
    '''Create this run's session folders and return all relevant paths.

    Returns a dict with the keys main_path, current_run, model_path,
    logs_path, png_path, eval_path, pred_path, x_path and y_path. Only the
    run folders (current_run and its model/tmp/png/predictions sub-folders)
    are created; the dataset paths are only computed.'''
    print('[*] Setting up directory...')

    main_path = args.runs_dir
    current_run = os.path.join(main_path, args.current_run)

    files_path = os.path.join(args.data_dir, args.data_set)

    x_path = os.path.join(files_path, 'inputs')
    y_path = os.path.join(files_path, 'velocities')
    eval_path = os.path.join(files_path, 'eval')

    model_path = os.path.join(current_run, 'model')
    logs_path = os.path.join(current_run, 'tmp')
    png_path = os.path.join(current_run, 'png')
    pred_path = os.path.join(current_run, 'predictions')

    for folder in (current_run, logs_path, model_path, png_path, pred_path):
        if not os.path.exists(folder):
            os.makedirs(folder)

    return {
        'main_path': main_path,
        'current_run': current_run,
        'model_path': model_path,
        'logs_path': logs_path,
        'png_path': png_path,
        'eval_path': eval_path,
        'pred_path': pred_path,
        'x_path': x_path,
        'y_path': y_path
    }


def load_training_data(x_path, y_path, genre):
    '''Load the input and velocity arrays of one genre.

    Every .npy file in x_path/genre is paired with the file of the same name
    in y_path/genre; files are read in sorted name order. Outputs are divided
    by 127 (the maximum MIDI velocity).

    Returns (X_data, Y_data), two lists of the same length.
    Raises ValueError when an input and its output differ in length.'''
    X_data = []
    Y_data = []
    print('[*] Loading data...')

    x_path = os.path.join(x_path, genre)
    y_path = os.path.join(y_path, genre)

    names = sorted(f for f in os.listdir(x_path) if f.split('.')[-1] == 'npy')

    for filename in names:
        loaded_x = np.load(os.path.join(x_path, filename))
        # 127.0 keeps this a true division even for integer arrays on Python 2.
        loaded_y = np.load(os.path.join(y_path, filename)) / 127.0
        if loaded_x.shape[0] != loaded_y.shape[0]:
            raise ValueError('Length mismatch between input and output for ' + filename)
        X_data.append(loaded_x)
        Y_data.append(loaded_y)

    return X_data, Y_data


def prepare_data():
    '''Set up the run folders and load the classical and jazz training data.

    Returns (dirs, data), where data[genre] = {"X": [...], "Y": [...]} for
    genre in ("classical", "jazz").'''
    dirs = setup_dir()
    data = {}
    for genre in ("classical", "jazz"):
        train_X, train_Y = load_training_data(dirs['x_path'], dirs['y_path'], genre)
        data[genre] = {"X": train_X, "Y": train_Y}
    return dirs, data


def _checkpoint_epoch(name):
    '''Parse the epoch number out of a checkpoint file name.

    Takes the text before the first '.', keeps what follows its last '-',
    drops that part's first character and parses the rest as an int (so
    "<name>-e12.ckpt" gives 12). Raises ValueError for other names.'''
    return int(name.split('.')[0].split('-')[-1][1:])


def main():
    tf.logging.set_verbosity(tf.logging.ERROR)

    dirs, data = prepare_data()

    network = GenreLSTM(dirs, input_size=176, mini=True, bi=args.bi)
    network.prepare_model()

    if args.forward_only:
        network.load(args.load_model)
        return

    if args.load_model:
        loaded_epoch = _checkpoint_epoch(args.load_model)
        print("[*] Loading " + args.load_model + " and continuing from " + str(loaded_epoch) + ".")
        network.train(data, model=args.load_model, starting_epoch=loaded_epoch + 1)
    elif args.load_last:
        saved = []
        for filename in os.listdir(dirs["model_path"]):
            try:
                saved.append((_checkpoint_epoch(filename), filename.split('.')[0]))
            except ValueError:
                # The 'checkpoint' index file and any unrelated files.
                continue
        if not saved:
            raise SystemExit('[!] No saved model found in ' + dirs["model_path"] + '.')
        loaded_epoch, last_base = max(saved)
        last = last_base + ".ckpt"
        print("[*] Loading " + last + " and continuing from " + str(loaded_epoch) + ".")
        network.train(data, model=last, starting_epoch=loaded_epoch + 1)
    else:
        network.train(data)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/midi_util.py:
--------------------------------------------------------------------------------
from collections import defaultdict
import copy
from math import log, floor, ceil
import pprint
import matplotlib.pyplot as plt
import pretty_midi
import mido
from mido import MidiFile, MidiTrack, Message, MetaMessage
import numpy as np
import random

DEBUG = False

# The MIDI pitches we use: 21..108, the 88 keys of a piano (A0..C8).
PITCHES = range(21, 109, 1)
OFFSET = 109 - 21
PITCHES_MAP = {p: i for i, p in enumerate(PITCHES)}
print(len(PITCHES))


def nearest_pow2(x):
    '''Normalize input to nearest power of 2, or midpoints between
    consecutive powers of two. Round down when halfway between two
    possibilities.

    Raises ValueError for x <= 0 (the logarithm is undefined there).'''
    if x <= 0:
        raise ValueError('nearest_pow2 needs a positive number, got {}'.format(x))

    low = 2**int(floor(log(x, 2)))
    high = 2**int(ceil(log(x, 2)))
    mid = (low + high) // 2

    if x < mid:
        high = mid
    else:
        low = mid
    if high - x < x - low:
        nearest = high
    else:
        nearest = low
    return nearest


def midi_to_array_one_hot(mid, quantization):
    '''Return array representation of a 4/4 time signature, MIDI object.

    Normalize the number of time steps in track to a power of 2. Then
    construct a T x N*2 array A (T = number of time steps, N = number of
    MIDI pitches in PITCHES). Columns 2*i and 2*i+1 describe pitch PITCHES[i]
    at step t: [1, 1] = note struck at t, [0, 1] = note held (up to and
    including its note_off step), [0, 0] = silent. A T x N velocity array
    holds the note_on velocity on every struck or held step.

    Returns (midi_array, velocity_array), or None when the track cannot be
    quantised or a note_off has no open note_on. Raises AssertionError unless
    the first track has exactly one 4/4 time signature, and KeyError for
    pitches outside PITCHES.

    Arguments:
    mid -- MIDI object with a 4/4 time signature
    quantization -- The note duration, represented as 1/2**quantization.'''

    time_sig_msgs = [msg for msg in mid.tracks[0] if msg.type == 'time_signature']
    assert len(time_sig_msgs) == 1, 'No time signature found'
    time_sig = time_sig_msgs[0]
    assert time_sig.numerator == 4 and time_sig.denominator == 4, 'Not 4/4 time.'

    # Quantize the notes to a grid of time steps.
    mid = quantize(mid, quantization=quantization)
    if mid is None:
        # quantize() has already printed why.
        return None

    # Convert the note timing and velocity to an array.
    _, track = get_note_track(mid)
    ticks_per_quarter = mid.ticks_per_beat
    time_msgs = [msg for msg in track if hasattr(msg, 'time')]
    cum_times = np.cumsum([msg.time for msg in time_msgs])

    track_len_ticks = cum_times[-1]
    if DEBUG:
        print('Track len in ticks: {}'.format(track_len_ticks))
    # Floor division keeps grid positions integral (they index the arrays).
    steps_per_quarter = 2**quantization // 4
    notes = [
        (time * steps_per_quarter // ticks_per_quarter, msg.type, msg.note, msg.velocity)
        for (time, msg) in zip(cum_times, time_msgs)
        if msg.type == 'note_on' or msg.type == 'note_off']

    num_steps = int(round(track_len_ticks / float(ticks_per_quarter) * 2**quantization / 4))
    normalized_num_steps = nearest_pow2(num_steps)
    # By position, louder notes first within the same step.
    notes.sort(key=lambda note: (note[0], -note[3]))

    if DEBUG:
        print(num_steps)
        print(normalized_num_steps)

    midi_array = np.zeros((normalized_num_steps, len(PITCHES)*2))
    velocity_array = np.zeros((normalized_num_steps, len(PITCHES)))
    open_msgs = defaultdict(list)

    for (position, note_type, note_num, velocity) in notes:
        if position == normalized_num_steps:
            # Clamp events that land exactly on the end of the grid.
            position = normalized_num_steps - 1

        if position > normalized_num_steps:
            continue

        if note_type == "note_on" and velocity > 0:
            open_msgs[note_num].append((position, note_type, note_num, velocity))
            midi_array[position, 2*PITCHES_MAP[note_num]] = 1
            midi_array[position, 2*PITCHES_MAP[note_num]+1] = 1
            velocity_array[position, PITCHES_MAP[note_num]] = velocity
        elif note_type == 'note_off' or (note_type == 'note_on' and velocity == 0):

            note_on_open_msgs = open_msgs[note_num]

            if len(note_on_open_msgs) == 0:
                print('Bad MIDI, Note has no end time.')
                return

            # Note-offs close the oldest open note of the same pitch (FIFO).
            stack_pos, _, _, vel = note_on_open_msgs[0]
            open_msgs[note_num] = note_on_open_msgs[1:]
            current_pos = position
            while current_pos > stack_pos:
                midi_array[current_pos, 2*PITCHES_MAP[note_num]] = 0
                midi_array[current_pos, 2*PITCHES_MAP[note_num]+1] = 1
                velocity_array[current_pos, PITCHES_MAP[note_num]] = vel
                current_pos -= 1

    # Second pass: re-mark every onset as [1, 1], because the sustain fill
    # above can overwrite the onset of a re-struck note with [0, 1].
    for (position, note_type, note_num, velocity) in notes:
        if position == normalized_num_steps:
            print('Warning: truncating from position {} to {}'.format(position, normalized_num_steps - 1))
            position = normalized_num_steps - 1

        if position > normalized_num_steps:
            continue
        if note_type == "note_on" and velocity > 0:
            midi_array[position, 2*PITCHES_MAP[note_num]] = 1
            midi_array[position, 2*PITCHES_MAP[note_num]+1] = 1
            velocity_array[position, PITCHES_MAP[note_num]] = velocity

    assert len(midi_array) == len(velocity_array)
    return midi_array, velocity_array


def print_array(mid, array, quantization=4):
    '''Print a binary array representing midi notes.

    One output line per time step: '-' for zero entries and the integer value
    otherwise, with a running counter appended to steps that start a beat.'''
    bar = 1
    ticks_per_beat = mid.ticks_per_beat
    ticks_per_slice = ticks_per_beat // 2**quantization

    lines = []
    for i, row in enumerate(array):
        line = ''.join(str(int(value)) if value > 0 else '-' for value in row)
        if (i * ticks_per_slice) % ticks_per_beat == 0:
            line += str(bar)
            bar += 1
        lines.append(line)
    print('\n'.join(lines))

def get_note_track(mid):
    '''Given a MIDI object, return (index, track) of the first track with note events.

    Raises ValueError if no track contains a note_on message.'''
    for i, track in enumerate(mid.tracks):
        if any(msg.type == 'note_on' for msg in track):
            return i, track
    raise ValueError(
        'MIDI object does not contain any tracks with note messages.')


def _first_note_on_index(track):
    '''Index of the first note_on message in track, or None if there is none.'''
    for i, msg in enumerate(track):
        if msg.type == 'note_on':
            return i
    return None


def quantize_tick(tick, ticks_per_quarter, quantization):
    '''Quantize the timestamp or tick.

    Rounds tick to the nearest multiple of the quantum
    (4 * ticks_per_quarter / 2**quantization ticks). Exact halves round up,
    matching Python 2's round() for non-negative ticks on any interpreter.

    Arguments:
    tick -- An integer (non-negative) timestamp
    ticks_per_quarter -- The number of ticks per quarter note
    quantization -- The note duration, represented as 1/2**quantization
    '''
    assert (ticks_per_quarter * 4) % 2 ** quantization == 0, \
        'Quantization too fine. Ticks per quantum must be an integer.'
    ticks_per_quantum = (ticks_per_quarter * 4) / float(2 ** quantization)
    quantized_ticks = int(
        floor(tick / float(ticks_per_quantum) + 0.5) * ticks_per_quantum)
    return quantized_ticks


def unquantize(mid, style_mid):
    '''Return a copy of mid whose note_on velocities are taken from style_mid.

    The velocities are written into the deep copy's note track, so mid itself
    is left unmodified.'''
    unquantized_mid = copy.deepcopy(mid)
    note_track_idx, note_track = get_note_track(unquantized_mid)
    _, style_note_track = get_note_track(style_mid)

    unquantized_mid.tracks[note_track_idx] = unquantize_track(note_track, style_note_track)
    return unquantized_mid


def unquantize_track(orig_track, style_track):
    '''Returns the unquantised orig_track with encoded velocities from the style_track.

    orig_track is modified in place and returned. Notes are matched per pitch
    in order of appearance: the k-th sounding note_on of a pitch in
    style_track sets the velocity of the k-th sounding note_on of that pitch
    in orig_track. Style notes without a counterpart are ignored, and a track
    without any note_on leaves orig_track unchanged.

    Arguments:
    orig_track -- Non-quantised MIDI track
    style_track -- Quantised and stylised MIDI track '''
    orig_start = _first_note_on_index(orig_track)
    style_start = _first_note_on_index(style_track)
    if orig_start is None or style_start is None:
        return orig_track

    orig_msgs = orig_track[orig_start:]
    style_msgs = style_track[style_start:]
    orig_cum_msgs = sorted(
        zip(np.cumsum([msg.time for msg in orig_msgs]), orig_msgs),
        key=lambda pair: pair[0])
    style_cum_msgs = sorted(
        zip(np.cumsum([msg.time for msg in style_msgs]), style_msgs),
        key=lambda pair: pair[0])

    # note number -> iterator over that pitch's sounding note_ons in orig_track
    orig_notes = defaultdict(list)
    for _, msg in orig_cum_msgs:
        if msg.type == 'note_on' and msg.velocity > 0:
            orig_notes[msg.note].append(msg)
    remaining = {note: iter(msgs) for note, msgs in orig_notes.items()}

    for _, msg in style_cum_msgs:
        if msg.type == 'note_on' and msg.velocity > 0 and msg.note in remaining:
            target = next(remaining[msg.note], None)
            if target is not None:
                target.velocity = msg.velocity

    return orig_track


def quantize(mid, quantization=5):
    '''Return a midi object whose notes are quantized to
    1/2**quantization notes, or None if its note track is malformed.

    mid itself is not modified; the work is done on a deep copy.
    Raises ValueError if mid has no track containing note_on messages.

    Arguments:
    mid -- MIDI object
    quantization -- The note duration, represented as
    1/2**quantization.'''
    quantized_mid = copy.deepcopy(mid)
    note_track_idx, note_track = get_note_track(quantized_mid)
    new_track = quantize_track(note_track, mid.ticks_per_beat, quantization)
    if new_track is None:
        return None
    quantized_mid.tracks[note_track_idx] = new_track
    return quantized_mid


def quantize_track(track, ticks_per_quarter, quantization):
    '''Return the differential time stamps of the note_on, note_off, and
    end_of_track events, in order of appearance, with the note_on events
    quantized to the grid given by the quantization.

    Messages before the first note_on are copied unchanged. From there on
    only note_on/note_off pairs and end_of_track messages are kept: other
    messages, and note_ons that are never released, are dropped. Each
    note_off keeps its original distance from its (quantised) note_on, and
    note times are clamped to the cumulative time of the track's last message.

    Returns a new MidiTrack, or None (after printing a message) when a
    note_off has no open note_on.

    Arguments:
    track -- MIDI track containing note event and other messages
    ticks_per_quarter -- The number of ticks per quarter note
    quantization -- The note duration, represented as
    1/2**quantization.'''

    pp = pprint.PrettyPrinter()

    # Message timestamps are represented as differences between
    # consecutive events. Annotate messages with cumulative timestamps.
    # Assume the following structure:
    # [header meta messages] [note messages] [end_of_track message]
    first_note_msg_idx = _first_note_on_index(track)
    note_msgs = track[first_note_msg_idx:]
    cum_msgs = list(zip(np.cumsum([msg.time for msg in note_msgs]), note_msgs))
    end_of_track_cum_time = cum_msgs[-1][0]

    quantized_track = MidiTrack()
    quantized_track.extend(track[:first_note_msg_idx])
    # note number -> [(cum_time, note_on message), ...] not yet switched off
    open_msgs = defaultdict(list)
    quantized_msgs = []
    for cum_time, msg in cum_msgs:
        if DEBUG:
            print('Message: {}'.format(msg))
            print('Open messages:')
            pp.pprint(open_msgs)
        if msg.type == 'note_on' and msg.velocity > 0:
            # Store until note off event. There can be several open note_ons
            # for the same note; note_offs close them in FIFO order.
            open_msgs[msg.note].append((cum_time, msg))
        elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
            if msg.note not in open_msgs:
                print('Bad MIDI. Cannot have note off event before note on event')
                return

            note_on_open_msgs = open_msgs[msg.note]

            if len(note_on_open_msgs) == 0:
                print('Bad MIDI, Note has no end time.')
                return

            note_on_cum_time, note_on_msg = note_on_open_msgs.pop(0)

            quantized_note_on_cum_time = quantize_tick(
                note_on_cum_time, ticks_per_quarter, quantization)

            # The note_off keeps the original note length after the
            # quantised note_on.
            quantized_note_off_cum_time = quantized_note_on_cum_time + (cum_time - note_on_cum_time)
            quantized_msgs.append((min(end_of_track_cum_time, quantized_note_on_cum_time), note_on_msg))
            quantized_msgs.append((min(end_of_track_cum_time, quantized_note_off_cum_time), msg))

            if DEBUG:
                print('Appended {}'.format(quantized_msgs[-2:]))
        elif msg.type == 'end_of_track':
            quantized_msgs.append((cum_time, msg))

    if DEBUG:
        print('\n')

    if not quantized_msgs:
        return quantized_track

    # Sort by cumulative time; the +0.5 puts sounding note_ons before any
    # other message with the same cumulative time.
    quantized_msgs.sort(
        key=lambda item: item[0]
        if (item[1].type == 'note_on' and item[1].velocity > 0) else item[0] + 0.5)

    cum_times = [cum_time for cum_time, _ in quantized_msgs]
    diff_times = [cum_times[0]] + list(np.diff(cum_times))
    for diff_time, (_, msg) in zip(diff_times, quantized_msgs):
        quantized_track.append(msg.copy(time=diff_time))
    if DEBUG:
        print('Quantized messages:')
        pp.pprint(quantized_msgs)
        pp.pprint(diff_times)
    return quantized_track


def stylify(mid, velocity_array, quantization):
    '''Return a deep copy of mid whose note track is replaced by the result
    of stylify_track(mid, velocity_array, quantization).'''
    style_mid = copy.deepcopy(mid)
    note_track_idx, _ = get_note_track(mid)
    # NOTE(review): stylify_track is not defined in the visible part of this
    # file and is handed the whole MIDI object, not the note track — confirm.
    new_track = stylify_track(mid, velocity_array, quantization)
    style_mid.tracks[note_track_idx] = new_track
    return style_mid

# def midi_to_array(mid, quantization):
#     '''Return array representation of a 4/4 time signature, MIDI object.
#
#     Normalize the number of time steps in track to a power of 2. Then
#     construct a T x N array A (T = number of time steps, N = number of
#     MIDI note numbers) where A(t,n) is the velocity of the note number
#     n at time step t if the note is active, and 0 if it is not.
#
#     Arguments:
#     mid -- MIDI object with a 4/4 time signature
#     quantization -- The note duration, represented as 1/2**quantization.'''
#
#     time_sig_msgs = [ msg for msg in mid.tracks[0] if msg.type == 'time_signature' ]
#     assert len(time_sig_msgs) == 1, 'No time signature found'
#     time_sig = time_sig_msgs[0]
#     assert time_sig.numerator == 4 and time_sig.denominator == 4, 'Not 4/4 time.'
#
#     # Quantize the notes to a grid of time steps.
#     mid = quantize(mid, quantization=quantization)
#
#     # Convert the note timing and velocity to an array.
394 | # _, track = get_note_track(mid) 395 | # ticks_per_quarter = mid.ticks_per_beat 396 | # time_msgs = [msg for msg in track if hasattr(msg, 'time')] 397 | # cum_times = np.cumsum([msg.time for msg in time_msgs]) 398 | # 399 | # track_len_ticks = cum_times[-1] 400 | # if DEBUG: 401 | # print 'Track len in ticks:', track_len_ticks 402 | # notes = [ 403 | # (time * (2**quantization/4) / (ticks_per_quarter), msg.type, msg.note, msg.velocity) 404 | # for (time, msg) in zip(cum_times, time_msgs) 405 | # if msg.type == 'note_on' or msg.type == 'note_off'] 406 | # 407 | # num_steps = int(round(track_len_ticks / float(ticks_per_quarter)*2**quantization/4)) 408 | # normalized_num_steps = nearest_pow2(num_steps) 409 | # notes.sort(key=lambda (position, note_type, note_num, velocity):(position,-velocity)) 410 | # 411 | # if DEBUG: 412 | # # pp = pprint.PrettyPrinter() 413 | # print num_steps 414 | # print normalized_num_steps 415 | # # pp.pprint(notes) 416 | # 417 | # midi_array = np.zeros((normalized_num_steps, len(PITCHES))) 418 | # velocity_array = np.zeros((normalized_num_steps, len(PITCHES))) 419 | # open_msgs = defaultdict(list) 420 | # 421 | # for (position, note_type, note_num, velocity) in notes: 422 | # if position == normalized_num_steps: 423 | # # print 'Warning: truncating from position {} to {}'.format(position, normalized_num_steps - 1) 424 | # position = normalized_num_steps - 1 425 | # # continue 426 | # 427 | # if position > normalized_num_steps: 428 | # # print 'Warning: skipping note at position {} (greater than {})'.format(position, normalized_num_steps) 429 | # continue 430 | # 431 | # if note_type == "note_on" and velocity > 0: 432 | # open_msgs[note_num].append((position, note_type, note_num, velocity)) 433 | # midi_array[position, PITCHES_MAP[note_num]] = 1 434 | # velocity_array[position, PITCHES_MAP[note_num]] = velocity 435 | # elif note_type == 'note_off' or (note_type == 'note_on' and velocity == 0): 436 | # 437 | # note_on_open_msgs = 
#                 open_msgs[note_num]
#
#             if len(note_on_open_msgs) == 0:
#                 print 'Bad MIDI, Note has no end time.'
#                 return
#
#             stack_pos, _, _, vel = note_on_open_msgs[0]
#             open_msgs[note_num] = note_on_open_msgs[1:]
#             current_pos = position
#             while current_pos > stack_pos:
#                 # if midi_array[position, PITCHES_MAP[note_num]] != 1:
#                 midi_array[current_pos, PITCHES_MAP[note_num]] = 2
#                 velocity_array[current_pos, PITCHES_MAP[note_num]] = vel
#                 current_pos -= 1
#
#     for (position, note_type, note_num, velocity) in notes:
#         if position == normalized_num_steps:
#             print 'Warning: truncating from position {} to {}'.format(position, normalized_num_steps - 1)
#             position = normalized_num_steps - 1
#             # continue
#
#         if position > normalized_num_steps:
#             # print 'Warning: skipping note at position {} (greater than {})'.format(position, normalized_num_steps)
#             continue
#         if note_type == "note_on" and velocity > 0:
#             open_msgs[note_num].append((position, note_type, note_num, velocity))
#             midi_array[position, PITCHES_MAP[note_num]] = 1
#             velocity_array[position, PITCHES_MAP[note_num]] = velocity
#
#     return midi_array, velocity_array

def stylify_track(mid, velocity_array, quantization):
    '''Overwrite the note velocities of *mid*'s note track from *velocity_array*.

    Every note_on/note_off message with a positive velocity is mapped to a
    time step on the 1/2**quantization grid. Its new velocity is
    velocity_array[step, PITCHES_MAP[note]] * 127, clamped to at least 1
    and rounded. Messages beyond the last time step are left untouched.

    The note track is modified in place and also returned.

    Arguments:
    mid -- MIDI object; the note track is located with get_note_track
    velocity_array -- 2-D array indexed by [time step, pitch index]; values
      are presumably in [0, 1] since they are scaled by 127 -- TODO confirm
    quantization -- The note duration, represented as 1/2**quantization.'''
    _, track = get_note_track(mid)
    ticks_per_quarter = mid.ticks_per_beat

    # Cumulative tick time of every message that carries a time stamp.
    time_msgs = [msg for msg in track if hasattr(msg, 'time')]
    if not time_msgs:
        return track
    cum_times = np.cumsum([msg.time for msg in time_msgs])
    track_len_ticks = cum_times[-1]

    num_steps = int(round(track_len_ticks / float(ticks_per_quarter) * 2**quantization / 4))
    normalized_num_steps = nearest_pow2(num_steps)

    # Grid steps per quarter note. Floor division keeps the step index an
    # integer on Python 3, matching Python 2's integer '/'.
    steps_per_quarter = 2**quantization // 4

    # time_msgs holds the track's own message objects, so setting
    # msg.velocity edits the track in place.
    for msg, cum_time in zip(time_msgs, cum_times):
        if msg.type not in ('note_on', 'note_off') or msg.velocity <= 0:
            continue
        pos = cum_time * steps_per_quarter // ticks_per_quarter
        if pos == normalized_num_steps:
            # A note landing exactly on the end goes to the last step.
            pos = pos - 1
        if pos > normalized_num_steps:
            continue
        vel = velocity_array[pos, PITCHES_MAP[msg.note]] * 127
        msg.velocity = int(round(max(vel, 1)))

    return track

def scrub(mid, velocity=10, random=False):
    '''Returns a copy of a midi object with one global velocity.

    Sets the velocity of every note_on with a positive velocity in the note
    track to a constant, or to a random value when *random* is true. The
    input ``mid`` is not modified.

    Arguments:
    mid -- MIDI object with a 4/4 time signature
    velocity -- The global velocity
    random -- If true, give each note a random velocity instead'''
    scrubbed_mid = copy.deepcopy(mid)
    # By convention, Track 0 contains metadata and Track 1 contains
    # the note on and note off events. The track is looked up on the copy
    # because the scrub_track* helpers edit messages in place.
    note_track_idx, note_track = get_note_track(scrubbed_mid)
    if random:
        new_track = scrub_track_random(note_track)
    else:
        new_track = scrub_track(note_track, velocity=velocity)
    scrubbed_mid.tracks[note_track_idx] = new_track
    return scrubbed_mid

def scrub_track_random(track):
    '''Give every note_on with a positive velocity a random velocity in
    1..127, in place, and return the track.

    0 is excluded because a note_on with velocity 0 is a note off.'''
    for msg in track:
        if msg.type == 'note_on' and msg.velocity > 0:
            msg.velocity = random.randint(1, 127)

    return track

def velocity_range(mid):
    '''Returns a count of velocities.

    Counts the distinct velocities of note_on messages with a positive
    velocity in the note track. Returns that count when it is greater
    than 1, otherwise 0 (a single velocity means no dynamics).

    Arguments:
    mid -- MIDI object with a 4/4 time signature'''
    _, track = get_note_track(mid)
    velocities = {msg.velocity for msg in track
                  if msg.type == 'note_on' and msg.velocity > 0}
    dynamics = len(velocities)
    return dynamics if dynamics > 1 else 0

def scrub_track(track, velocity):
    '''Set every note_on with a positive velocity to *velocity*, in place,
    and return the track.'''
    for msg in track:
        if msg.type == 'note_on' and msg.velocity > 0:
            msg.velocity = velocity

    return track

# --------------------------------------------------------------------------------
# /midifile.rb:
# --------------------------------------------------------------------------------
#
# encoding: ASCII-8BIT
#############################################
### Midifile Input and Output Facilities ####
#
# Copyright (c) 2008-2014 by Pete Goodeve
#
# vers 2014/11/30 -- Ruby 1.9
#
#############################################


### Constant definitions etc. ###

HDR=0x00
END_OF_FILE=0x01

# Channel-message status bytes (high nibble; the channel is OR-ed in).
NOTE_OFF=0x80
NOTE_ON=0x90
POLY_TOUCH=0xa0
CONTROL_CHANGE=0xb0
PROGRAM_CHANGE=0xc0
CHANNEL_TOUCH=0xd0
PITCH_BEND=0xe0
SYSTEM=0xf8 # maybe... (not ever seen in midifile?)
SYSEX=0xf0
SYSEX_CONT=0xf7
META=0xff
TRK_START=0x100 ## not a byte code!
TRK_END=0x1ff

# Human-readable names for MidiItem#code values (used by MidiEvent#to_s).
EvType = {
  NOTE_OFF=>"NOTE OFF",
  NOTE_ON=>"NOTE ON",
  POLY_TOUCH=>"POLY TOUCH",
  CONTROL_CHANGE=>"CONTROL CHANGE",
  PROGRAM_CHANGE=>"PROGRAM CHANGE",
  CHANNEL_TOUCH=>"CHANNEL TOUCH",
  PITCH_BEND=>"PITCHBEND",
  SYSEX=>"SYSEX",
  SYSEX_CONT=>"SYSEX CONTINUATION",
  META=>"META",
  TRK_START=>"TRACK START",
  TRK_END=>"TRACK END",
  HDR=>"FILE HEADER",
  END_OF_FILE=>"FILE END"
}

## Meta Event types:

SEQ_NUM=0x00
TEXT=0x01
COPYRIGHT=0x02
TRACK_NAME=0x03
INSTR_NAME=0x04
LYRIC=0x05
MARKER=0x06
CUE_POINT=0x07
DEVICE_NAME=0x09
CHAN_PFX=0x20
MIDI_PORT=0x21
END_TRK=0x2f
TEMPO=0x51
SMPTE=0x54
TIME_SIG=0x58
KEY_SIG=0x59
SEQUENCER=0x7f

# Meta-event type byte -> name, with the expected payload noted per entry.
MetaType = {
  0x00=>'SEQ_NUM',    # 2-byte number
  0x01=>'TEXT',       # string
  0x02=>'COPYRIGHT',  # string
  0x03=>'TRACK_NAME', # string
  0x04=>'INSTR_NAME', # string
  0x05=>'LYRIC',      # string
  0x06=>'MARKER',     # string
  0x07=>'CUE_POINT',  # string
  0x09=>'DEVICE_NAME',# string
  0x20=>'CHAN_PFX',   # byte
  0x21=>'MIDI_PORT',  # byte
  0x2f=>'END_TRK',    # also in @code attr
  0x51=>'TEMPO',      # 3-byte usec/q-note
  0x54=>'SMPTE',      # 5-bytes: hr mn sc fr ff
  0x58=>'TIME_SIG',   # 4-bytes: nn dd cc bb
  0x59=>'KEY_SIG',    # 2-bytes: sf mi
  0x7f=>'SEQUENCER',  # Sequencer specific
}
##############################
# Compatibility shims. On Rubies whose String has #getbyte (1.9+),
# IO#getc and String#[] return characters, while the code below does
# arithmetic on byte values. So:
#  - IO#getc is redefined to return a byte value,
#  - MString#[] with a single Integer index returns a byte value,
#  - Array#nitems (removed in 1.9) is re-added; it counts truthy elements.

if String.instance_methods.include?(:getbyte) then

  class Array
    def nitems
      select{|x| x}.count
    end
  end

  class IO
    def getc
      getbyte
    end
  end

  class MString < String
    def [] i, *more
      if !i.is_a?(Integer) || !more.empty? then
        super(i, *more)
      else
        getbyte(i)
      end
    end
  end

else #1.8.x or earlier...

  class MString < String
  end

end

##############################


# Shared binary encode/decode helpers. Including classes must provide
# readByte for the read* methods.
module MidifileOps

  # readByte() should be defined in class appropriately

  # Read a big-endian sixteen bit value, interpreted as signed.
  def read16
    val = (readByte() << 8) + readByte()
    val = val - 0x10000 if (val & 0x8000).nonzero?
    return val
  end

  # Read a big-endian 32-bit value, interpreted as signed.
  def read32
    val = (readByte() << 24) + (readByte() << 16) +
      (readByte() << 8) + readByte()
    val = val - 0x100000000 if (val & 0x80000000).nonzero?
    return val
  end

  # Read a MIDI variable-length value: 7 data bits per byte, with the
  # high bit set on every byte except the last.
  def readVarlen
    c = readByte()
    val = 0
    p c if !c # debug: prints nil when the stream is exhausted
    until !c || (c & 0x80).zero?
      val = (val | (c & 0x7f)) << 7
      c = readByte()
      p c if !c
    end
    puts "Error: c was #{c} at #{@bytes_left}" if !c
    val |= c
    #puts "got VarLen #{val}"
    return val
  end

  # Generate big-endian bytes for a 16-bit value (negatives wrap).
  def bytes2(val)
    val = (val - 0x10000) & 0xffff if val < 0
    s = '' << ((val >> 8) & 0xff)
    s << (val & 0xff)
  end

  # Generate big-endian bytes for a 32-bit value (negatives wrap).
  def bytes4(val)
    val = (val - 0x100000000) & 0xffffffff if val < 0
    '' << ((val >> 24) & 0xff) << ((val >> 16) & 0xff) <<
      ((val >> 8) & 0xff) << (val & 0xff)
  end

  # Generate bytes for a variable length value. Bytes are produced
  # least-significant group first, then reversed.
  def bytesVarlen(val)
    return "\000" if val.zero?
    buf = Array.new() # NOTE(review): unused
    s = '' << (val & 0x7f)
    while (val >>= 7) > 0
      s << ((val & 0x7f) | 0x80)
    end
    s.reverse
  end

end ### module MidifileOps


# Base class for everything yielded while reading a file or added while
# writing one (header, track header, events, end-of-file marker).
class MidiItem
  include MidifileOps
  def initialize(code, trkno=nil, time=nil, delta=0)
    @code = code
    @trkno = trkno
    @time = time
    @delta = delta
  end
  # May need to adjust any of these:
  attr_accessor :code, :time, :delta, :trkno
  attr_accessor :listindx # used to restrict sorting's eagerness
  # Everything declared (readable) at this level
  # so we can check for existence without bombing...:
  attr_reader :format, :ntrks, :division
  attr_reader :chan, :data1, :data2, :running
  attr_reader :length, :data, :meta

  def to_s
    if @code == END_OF_FILE
      "EOF"
    else
      "#{@trkno ? @trkno : "--"}: #{@time? @time : "--"} #{@code}"
    end
  end
  def to_bytes
    ''
  end
  # The 'channel' accessors handle user-range 1..16 rather than 0..15:
  def channel
    @chan ? @chan+1 : nil # return 1..16 (to match 'gen...' methods)
  end
  def channel=(ch)
    @chan = ch - 1
  end
end ### MidiItem


# The MThd chunk: format, track count and time division.
class MidiHeader < MidiItem
  def initialize(format, ntrks, division)
    super(HDR)
    @format = format
    @ntrks = ntrks
    @division = division
  end
  def to_s
    "Format #{@format}: #{@ntrks} tracks -- division=#{@division}"
  end
  def to_bytes
    'MThd' << bytes4(6) << bytes2(@format) <<
      bytes2(@ntrks) << bytes2(@division)
  end
end ### MidiHeader


# The start of an MTrk chunk, with its byte length.
class TrackHeader < MidiItem
  def initialize(trkno, length)
    super(TRK_START, trkno)
    @length = length
  end
  def to_s
    "#{@trkno ? @trkno : "--"}: -- TRACK_START length #{@length} bytes"
  end
  def to_bytes
    'MTrk' << bytes4(@length)
  end
end ### TrackHeader


# A channel, system-exclusive or meta event.
class MidiEvent < MidiItem
  # For Channel Events:
  attr_accessor :chan, :data1, :data2, :running
  # For System & Meta Events:
  attr_accessor :length, :data, :meta

  def initialize(code, trkno=nil, time=nil, delta=0)
    # An event may be specified as an array of the actual MIDI bytes
    # (not SysEx (yet))
    if code.is_a?(Array)
      if code[0] < 0xf0 then
        super(code[0] & 0xf0, trkno, time, delta)
        @chan = code[0] & 0x0f
      else
        super(code[0], trkno, time, delta)
      end
      @data1 = code[1] if code[1]
      @data2 = code[2] if code[2]
    else
      super(code, trkno, time, delta)
    end
  end

  def to_s
    begin
      s = "#{@trkno}: #{@time} "
      if @chan
        s << "#{EvType[@code]} chan=#{@chan}"
        case code
        when NOTE_ON
          s << " note #{@data1} velocity #{@data2}"
        when NOTE_OFF, POLY_TOUCH
          s << " note #{@data1} value #{@data2}"
        when CONTROL_CHANGE
          s << " controller #{@data1} value #{@data2}"
        when PROGRAM_CHANGE
          s << " program #{@data1}"
        when PITCH_BEND, CHANNEL_TOUCH
          s << " value #{@data1}"
        end
        # NOTE(review): the next string is built and discarded; it has no effect.
        "#{@data1} #{@data2? @data2 : ''}"
        s << " [R]" if @running
      elsif @meta
        mcode = MetaType[@meta]
        s << (mcode ? " #{mcode}" : "#{@code} 0x%x"%@meta)
        if (0x01..0x07) === @meta
          s << ' "' << @data << '"'
        elsif mcode == 'TEMPO'
          tempo = (data[0]*256 + data[1])*256 + data[2]
          s << " #{tempo} microsec/quarter-note"
        elsif @length && @length > 0
          s << " ["
          @data.each_byte {|b| s << " %d"%b}
          s << " ]"
        end
      end
      return s
    rescue
      p self
      raise
    end
  end

  # Encode as it appears in a track chunk: delta time, status byte
  # (omitted under running status), then data.
  def to_bytes
    ## NOTE: it is assumed that structure is correct! (unneeded == nil)
    case @code
    when SYSTEM
      ### Not handled (yet?)
      return nil
    when SYSEX, SYSEX_CONT, META, TRK_END
      command = @code
    else ## may need sanity check here...
      command = @code | @chan
    end
    s = '' << bytesVarlen(@delta)
    s << (command & 0xff) if !@running # (must ensure is bytesize)
    if @length
      s << @meta if @meta
      s << bytesVarlen(@length) << @data
    elsif @code == TRK_END # allow incomplete event struct
      s << 0x2f << 0x00
    elsif @code == PITCH_BEND
      # data1 holds a signed bend; emit it as 14 bits, LSB first.
      val = @data1 + 8192
      s << (val & 0x7f)
      s << ((val >> 7) & 0x7f)
    elsif @data1
      s << @data1
      s << @data2 if @data2
    end
    return s
  end

  def to_midi
    # generate standard MIDI byte sequence
    return nil if @code >= META || @code < NOTE_OFF #should never hit this?
    s = '' << (@code < SYSEX ? (@code | @chan) : @code)
    if @code == SYSEX || @code == SYSEX_CONT
      s << @data
    elsif @code == PITCH_BEND
      val = @data1 + 8192
      s << (val & 0x7f)
      s << ((val >> 7) & 0x7f)
    elsif @data1
      s << @data1
      s << @data2 if @data2
    end
    return s
  end
end ### MidiEvent


# One MTrk chunk: either read from a source Midifile (src given) or built
# up with #add for output.
class MidiTrack
  include MidifileOps
  def initialize(trkno, src=nil)
    @src = src
    @trkno = trkno
    ## ... and set up start and end etc...
    if src
      id, @trklen = src.read_spec()
      # NOTE(review): new ignores initialize's return value, so this
      # does not reject a chunk whose id is not 'MTrk'.
      return nil if (id != 'MTrk') ## thought-holder... -- raise exception
    end
    @last_insert = 0
    @insert_time = 0
  end
  attr_reader :trkno, :trklen, :evlist

  # Read a single character from track
  def readByte
    byte = @src.instream.getc()
    @bytes_left -= 1 if byte
    return byte
  end

  def read_system_or_meta(code)
    length = 0
    data = MString.new
    case code
    when META
      # running status OK around this
      meta = readByte()
      length = readVarlen()
      code = TRK_END if meta == 0x2f
    when SYSEX, SYSEX_CONT
      @running = false # just in case
      @command = nil # maybe a little protection from bad values...
      length = readVarlen()
    else
      @running = false # just in case
      puts "unexpected system event #{code}"
      return nil ### TEMP for now...
    end
    ev = MidiEvent.new(code, @trkno, @elapsed, @delta) # (temp)
    ev.meta = meta if meta
    ev.data = MString.new @src.instream.read(length) if length
    ev.length = length # (excludes meta byte)
    @bytes_left -= length
    return ev
  end

  def read_event
    @delta = readVarlen() # Delta time
    @elapsed += @delta
    code = readByte() # Read first byte
    if (code & 0x80).zero? # Running status?
      puts 'unexpected running status' if !@command || @command.zero?
      @running = true
    elsif code >= 0xf0
      return read_system_or_meta(code)
    else
      @command = code
      @running = false
    end
    ##puts "Status %x, code=%x chan = %x"%[@command, (@command>>4)&7, @command & 0xf]
    ev = MidiEvent.new(@command&0xf0, @trkno, @elapsed, @delta)
    ev.chan = @command & 0xf
    ev.running = @running # recorded for possible convenience
    # Under running status the byte just read is already the first data byte.
    ev.data1 = @running? code : readByte()
    case @command & 0xf0
    when NOTE_OFF..CONTROL_CHANGE
      ev.data2 = readByte()
    when PROGRAM_CHANGE, CHANNEL_TOUCH
      #do nothing
    when PITCH_BEND
      msb = readByte()
      ev.data1 += msb*128 - 8192
    end
    return ev
  end

  # Read the track.
  def each
    c = c1 = type = needed = 0 # NOTE(review): these locals are unused
    @sysex_continue = false # True if last msg was unfinished
    @running = false # True when running status used
    @command = nil # (Possibly running) "status" byte

    @bytes_left = @trklen
    @elapsed = 0

    if @src
      yield read_event() while @bytes_left > 0
      return @elapsed
    elsif @evlist
      @evlist.each {|ev| yield ev}
    end
  end

  #######################
  ## Output section

  # Insert ev into @evlist, keeping it ordered by time where times are
  # known. Returns false for a source (read-only) track or a mismatched
  # track number.
  def add(ev)
    return true if not ev.is_a?(MidiEvent) # so we can pass with no trouble
    return false if @src || (ev.trkno && ev.trkno != @trkno)
    if !@evlist || ev.code == TRK_START
      @evlist = []
      @trklen = 0
      return true if ev.code == TRK_START
    end
    if !ev.time ||
        (@evlist.last && @evlist.last.time &&
         ev.time >= @evlist.last.time) then
      @evlist << ev
    else
      indx = -1 # append if no insertion point found
      @last_insert = 0 if ev.time < @insert_time
      for i in (@last_insert...@evlist.length) do
        if @evlist[i].time && @evlist[i].time > ev.time then
          indx = i
          break
        end
      end
      @evlist.insert(indx, ev)
      @last_insert = indx
      @insert_time = ev.time
    end
    @trklen += ev.to_bytes.length
    # puts "added event #{ev.code} making #{@trklen} bytes"
    return true
  end

  def empty?(end_alone_ok=nil)
    return true if !@evlist || @evlist.empty?
    return true if @evlist.length == 1 &&
      @evlist.last.code == TRK_END && !end_alone_ok
    return false
  end

  # Prepare the track for output: fill in missing times, stable-sort by
  # time, recompute deltas, running status and @trklen, drop extra
  # TRK_END events and make sure the track ends with one.
  def vet(use_running=true)
    return false if not @evlist
    time = 0
    # looks like three passes are needed:
    @evlist.each_with_index {|ev, indx|
      if !ev.time
        time = time + ev.delta if ev.delta
        ev.time = time
      elsif ev.time > time # seems reasonable...
        time = ev.time
      end
      ev.listindx = indx #prevent over-eager sorting
    }
    @evlist.sort!{|a,b|
      if a.time == b.time then
        res = a.listindx <=> b.listindx
      else
        res = a.time <=> b.time
      end
      # puts "sorting #{a} against #{b} -- res #{res}"
      res
    }
    time = 0
    @trklen = 0
    curr_code = 0
    curr_chan = nil
    to_delete = []
    @evlist.each {|ev|
      if ev.code == TRK_END && ev != @evlist.last
        to_delete << ev ## can't remove within the loop!!
      else
        ev.delta = ev.time - time
        time = ev.time
        if use_running && ev.code == curr_code &&
            ev.chan && ev.chan == curr_chan
          ev.running = true
        elsif ev.running
          ev.running = nil
        end
        curr_code = ev.code
        curr_chan = ev.chan
        @trklen += ev.to_bytes.length
      end
    }
    to_delete.each {|ev| @evlist.delete(ev)} # trklen shouldn't include these
    if @evlist.last.code != TRK_END
      ev = MidiEvent.new(TRK_END, @trkno, @evlist.last.time)
      ev.meta = 0x2f
      add(ev)
    end
    return true
  end

  def to_stream(stream)
    return false if !@evlist
    if @evlist.empty? || @evlist.last.code != TRK_END
      currtime = @evlist.last ? @evlist.last.time : 0
      ev = MidiEvent.new(TRK_END, @trkno, currtime)
      ev.meta = 0x2f
      add(ev)
    end
    stream << 'MTrk' << bytes4(@trklen)
    @evlist.each {|ev|
      stream << ev.to_bytes
    }
  end

end ### MidiTrack


# A whole MIDI file. Either parses an input stream (iterate with #each),
# or collects events via #add and writes them with #vet and #to_stream.
class Midifile

  include MidifileOps

  def initialize(stream=nil)
    @instream = stream
  end

  attr_reader :instream, :tracks

  # Read a single character
  def readByte
    return @instream.getc()
  end

  # Read chunk spec: a 4-character id followed by a 32-bit length.
  def read_spec
    id = @instream.read(4)
    length = read32()
    return id,length
  end

  # Read the header
  def read_header_chunk
    id, size = read_spec()
    return nil if (id != 'MThd' || size != 6) ## crude for now...
    @format = read16()
    @ntrks = read16()
    @division = read16()
  end

  def each
    read_header_chunk() if @instream
    return nil if !@format
    @ntrks = @tracks.nitems if !@instream && @tracks
    yield MidiHeader.new(@format, @ntrks, @division)
    if @instream
      (0...@ntrks).each {|n|
        @elapsed = 0
        trk = MidiTrack.new(n, self)
        yield TrackHeader.new(n, trk.trklen)
        @elapsed = trk.each {|ev| yield ev}
      }
    elsif @tracks
      @tracks.compact.each {|trk|
        yield TrackHeader.new(trk.trkno, trk.trklen)
        trk.each {|ev| yield ev}
      }
    end
    yield MidiItem.new(END_OF_FILE)
  end

  #######################
  ## Output section

  # Set only once; (0...2) accepts formats 0 and 1.
  def format=(format)
    # can't change existing value
    if !@format && (0...2) === format
      @format = format
    else
      return nil
    end
  end

  def division=(division)
    @division = division
  end

  def addTrack(n=nil)
    @tracks = @tracks || []
    if n then @tracks[n] = track = MidiTrack.new(n)
    else @tracks << track = MidiTrack.new(@tracks.length)
    end
    return track
  end

  def add(ev)
    if ev.code == HDR
      self.format=ev.format # use method to protect against change!
      @division=ev.division
      # number of tracks defined by array
    elsif ev.trkno # ignore END_OF_FILE (etc?)
      addTrack(ev.trkno) if !@tracks || !@tracks[ev.trkno]
      @tracks[ev.trkno].add(ev)
    end
  end

  def vet(use_running=true)
    return false if !@tracks || !@format
    @tracks.compact.each {|trk|
      res = trk.vet(use_running)
      @tracks.delete(trk) if trk.empty?
      return false if not res
    }
    return true
  end

  def to_stream(stream)
    return false if !@tracks || !@format
    stream << 'MThd' << bytes4(6) << bytes2(@format) <<
      bytes2(@tracks.nitems) << bytes2(@division)
    @tracks.compact.each {|trk|
      trk.to_stream(stream)
    }
    return true
  end

end ### Midifile

#####################################################################

### Utility Event Subclasses & Methods (to simplify event generation)

### Parameters in the following methods have these conventions:
### By default, 'ticks' should be delta-ticks.
### -- to use absolute elapsed ticks set MidiEvent.deltaTicks=false
### When a channel is supplied, it should be in the range 1..16
### (but 0 is also allowed as the first channel)
### Track numbering however, starts at zero.
### If there is neither a supplied track nor a default one, channel
### events will get a track according to their channel; metaevents
### will go to track 0.
class MidiEvent # extension from main definition above
  # For User-created event convenience: defaults shared by the gen*
  # helpers and the ChannelEvent/MetaEvent subclasses below.
  @@defaultTrack = 0
  @@defaultChannel = 0
  @@useDeltas = true # false for absolute ticks
  def MidiEvent.track=(trk)
    @@defaultTrack = trk # can be nil for track=chan
  end
  def MidiEvent.channel=(chan)
    chan -= 1 if chan && (chan > 0) # supplied range 1..16
    @@defaultChannel = (chan || 0) & 0xF
  end
  def MidiEvent.deltaTicks=(usedeltas)
    @@useDeltas = usedeltas
  end
  def MidiEvent.track()
    @@defaultTrack
  end
  def MidiEvent.channel()
    @@defaultChannel+1 # returned in user notation (1..16)!
  end
  def MidiEvent.deltaTicks()
    @@useDeltas
  end
end # MidiEvent extension


# Channel event built from user-level arguments. A channel 1..16 is
# converted to 0..15; omitted channel/track fall back to the defaults.
class ChannelEvent < MidiEvent # Convenience subclass
  def initialize(code, ticks=0, data1=nil, data2=nil, chan=nil, track=nil)
    chan -= 1 if chan && (chan > 0) # supplied range 1..16
    @chan = chan || @@defaultChannel
    track = track || @@defaultTrack
    track = track || @chan+1 # back to user convention for track number
    # ticks is stored as a delta or as an absolute time per deltaTicks.
    if @@useDeltas then
      super(code, track, nil, ticks)
    else
      super(code, track, ticks)
    end
    @data1 = data1
    @data2 = data2
  end
end ### ChannelEvent


def genNoteOff(ticks, note, vel=0, chan=nil, track=nil)
  ChannelEvent.new(NOTE_OFF, ticks, note, vel, chan, track)
end

def genNoteOn(ticks, note, vel, chan=nil, track=nil)
  ChannelEvent.new(NOTE_ON, ticks, note, vel, chan, track)
end

def genPolyTouch(ticks, note, pressure, chan=nil, track=nil)
  ChannelEvent.new(POLY_TOUCH, ticks, note, pressure, chan, track)
end

def genControlChange(ticks, controller, value, chan=nil, track=nil)
  # controller numbers start at 0!
  ChannelEvent.new(CONTROL_CHANGE, ticks, controller, value, chan, track)
end

def genProgramChange(ticks, program, chan=nil, track=nil)
  program -= 1 if program > 0 # numbers start at 1!
  ChannelEvent.new(PROGRAM_CHANGE, ticks, program, nil, chan, track)
end

def genChannelTouch(ticks, pressure, chan=nil, track=nil)
  ChannelEvent.new(CHANNEL_TOUCH, ticks, pressure, nil, chan, track)
end

def genPitchBend(ticks, bend, chan=nil, track=nil) # Signed 14-bit value! (0 = no bend)
  ChannelEvent.new(PITCH_BEND, ticks, bend, nil, chan, track) # kept in data1 only (until output)
end


# Meta event (status 0xFF). data may be a String or an array of byte
# values; it is stored as an MString.
class MetaEvent < MidiEvent # Convenience subclass
  def initialize(ticks, meta, length, data, track=nil)
    track = (track || @@defaultTrack) || 0 # use zero if no default
    if @@useDeltas then
      super(0xFF, track, nil, ticks)
    else
      super(0xFF, track, ticks)
    end
    @meta = meta
    @length = length
    if data.is_a?(String) then
      @data = MString.new data
    else # assume array of bytes
      @data = MString.new
      data.each {|b| @data << b}
    end
  end
end ### MetaEvent


# Text-type meta event; a type outside TEXT..CUE_POINT falls back to TEXT (1).
def genText(ticks, type, text, track=nil)
  meta = (TEXT..CUE_POINT) === type ? type : 1
  MetaEvent.new(ticks, meta, text.length, text, track)
end

# Tempo meta event: microseconds per quarter note as 3 big-endian bytes.
def genTempo(ticks, micros=500000, track=nil)
  tempo = [(micros>>16) & 0xFF, (micros>>8) & 0xFF, micros & 0xFF]
  MetaEvent.new(ticks, 0x51, 3, tempo, track)
end

def genTimeSignature(ticks, numer, denom, metronome=24, notat32=8, track=nil)
  ## see the midifile spec!! -- except that denom is the actual one (e.g. 3/"8")
  # dpow2 = smallest exponent with 2**dpow2 >= denom.
  dpow2 = 0
  dpow2 += 1 while 2**dpow2 < denom
  MetaEvent.new(ticks, 0x58, 4, [numer, dpow2, metronome, notat32], track)
end

def genKeySignature(ticks, sharpsflats, minor=0, track=nil)
  MetaEvent.new(ticks, 0x59, 2, [sharpsflats, minor], track)
end

## For metaevents not covered by the above...:
def genMeta(ticks, meta, data, track=nil)
  MetaEvent.new(ticks, meta, data.length, data, track)
end

## 'data' should be either an array of byte values, or a string
def genSysEx(ticks, data, track=nil)
  track = (track || MidiEvent.track) || 0 # use zero if no default
  if MidiEvent.deltaTicks then
    ev = MidiEvent.new(0xF0, track, nil, ticks)
  else
    ev = MidiEvent.new(0xF0, track, ticks)
  end
  if data.is_a?(String) then
    ev.data = MString.new data
  else # assume array of bytes
    ev.data = MString.new
    data.each {|b| ev.data << b}
  end
  ev.data << 0xF7 if ev.data[-1] != 0xF7 # terminate as per protocol
  ev.length = ev.data.length
  return ev
end


# --------------------------------------------------------------------------------
# /model.py:
# --------------------------------------------------------------------------------
import tensorflow as tf
from data_util import BatchGenerator
import os
import numpy as np
import matplotlib
# Select the non-interactive Agg backend before pyplot is imported.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# Turn off pyplot's interactive mode.
matplotlib.pyplot.ioff()

class GenreLSTM(object):
    '''LSTM model class (StyleNet's class, per the project README).'''
    def __init__(self, dirs, mini=False, bi=False, one_hot=True, input_size=176, output_size=88, num_layers=3, batch_count=8):
        '''Store the model configuration; no TensorFlow graph is built here.

        input_size, output_size, num_layers and batch_count are coerced with
        int(); the other arguments are stored as given.

        Arguments:
        dirs -- stored as self.dirs (presumably run/data directories -- TODO confirm)
        mini -- stored as self.mini
        bi -- stored as self.bi; presumably selects the bidirectional graph -- TODO confirm
        one_hot -- stored as self.one_hot
        input_size -- stored as self.input_size (used as the LSTM cell size)
        output_size -- stored as self.output_size
        num_layers -- stored as self.num_layers (stacked cell depth when > 1)
        batch_count -- stored as self.batch_count'''
        self.input_size = int(input_size)
        self.output_size = int(output_size)
        self.num_layers = int(num_layers)
        self.batch_count = int(batch_count)
        self.dirs = dirs
        self.bi = bi
        self.mini = mini
        self.one_hot = one_hot


    # NOTE(review): "bidiretional" is misspelled; renaming could break
    # callers elsewhere that use this name.
    def prepare_bidiretional(self, glorot=True):
23 | print("[*] Preparing bidirectional dynamic RNN...") 24 | self.input_cell = tf.contrib.rnn.LSTMCell(self.input_size, forget_bias=1.0) 25 | self.input_cell = tf.contrib.rnn.DropoutWrapper(self.input_cell, input_keep_prob=self.input_keep_prob, output_keep_prob=self.output_keep_prob) 26 | self.enc_outputs, self.enc_states = tf.nn.dynamic_rnn(self.input_cell, self.inputs, sequence_length=self.seq_len, dtype=tf.float32) 27 | 28 | 29 | with tf.variable_scope("encode") as scope: 30 | 31 | self.j_cell_fw = tf.contrib.rnn.LSTMBlockCell(self.input_size,forget_bias=1.0) 32 | self.j_cell_fw = tf.contrib.rnn.DropoutWrapper(self.j_cell_fw, input_keep_prob=self.input_keep_prob, output_keep_prob=self.output_keep_prob) 33 | 34 | self.j_cell_bw = tf.contrib.rnn.LSTMBlockCell(self.input_size,forget_bias=1.0) 35 | self.j_cell_bw = tf.contrib.rnn.DropoutWrapper(self.j_cell_bw, input_keep_prob=self.input_keep_prob, output_keep_prob=self.output_keep_prob) 36 | 37 | if self.num_layers > 1: 38 | self.j_cell_fw = tf.contrib.rnn.MultiRNNCell([self.j_cell_fw]*self.num_layers) 39 | self.j_cell_bw = tf.contrib.rnn.MultiRNNCell([self.j_cell_bw]*self.num_layers) 40 | 41 | 42 | 43 | # self.j_outputs, _ = tf.nn.bidirectional_dynamic_rnn( 44 | (self.j_fw, self.j_bw) , _ = tf.nn.bidirectional_dynamic_rnn( 45 | self.j_cell_fw, 46 | self.j_cell_bw, 47 | self.enc_outputs, 48 | sequence_length=self.seq_len, 49 | dtype=tf.float32) 50 | 51 | 52 | self.jazz_outputs = tf.concat([self.j_fw, self.j_bw],2) 53 | # self.jazz_outputs = tf.add(self.j_outputs[0], self.j_outputs[1]) 54 | 55 | scope.reuse_variables() 56 | 57 | 58 | self.c_cell_fw = tf.contrib.rnn.LSTMBlockCell(self.input_size,forget_bias=1.0) 59 | self.c_cell_fw = tf.contrib.rnn.DropoutWrapper(self.c_cell_fw, input_keep_prob=self.input_keep_prob, output_keep_prob=self.output_keep_prob) 60 | 61 | self.c_cell_bw = tf.contrib.rnn.LSTMBlockCell(self.input_size,forget_bias=1.0) 62 | self.c_cell_bw = tf.contrib.rnn.DropoutWrapper(self.c_cell_bw, 
input_keep_prob=self.input_keep_prob, output_keep_prob=self.output_keep_prob) 63 | 64 | if self.num_layers > 1: 65 | self.c_cell_fw = tf.contrib.rnn.MultiRNNCell([self.c_cell_fw ]*self.num_layers) 66 | self.c_cell_bw = tf.contrib.rnn.MultiRNNCell([self.c_cell_bw ]*self.num_layers) 67 | 68 | (self.c_fw, self.c_bw), _ = tf.nn.bidirectional_dynamic_rnn( 69 | # self.c_outputs, _ = tf.nn.bidirectional_dynamic_rnn( 70 | self.c_cell_fw, 71 | self.c_cell_bw, 72 | self.enc_outputs, 73 | sequence_length=self.seq_len, 74 | dtype=tf.float32) 75 | 76 | 77 | self.classical_outputs = tf.concat([self.c_fw, self.c_bw],2) 78 | 79 | # self.classical_outputs = tf.add(self.c_outputs[0], self.c_outputs[1]) 80 | 81 | 82 | self.jazz_B = tf.Variable(tf.random_normal([self.output_size], stddev=0.1)) 83 | self.classical_B = tf.Variable(tf.random_normal([self.output_size], stddev=0.1)) 84 | 85 | if glorot: 86 | self.jazz_W = tf.get_variable("jazz_W", shape=[self.input_size*2, self.output_size],initializer=tf.contrib.layers.xavier_initializer()) 87 | self.classical_W = tf.get_variable("classical_W", shape=[self.input_size*2, self.output_size],initializer=tf.contrib.layers.xavier_initializer()) 88 | else: 89 | self.jazz_W = tf.Variable(tf.random_normal([self.input_size*2,self.output_size], stddev=0.1)) 90 | self.classical_W = tf.Variable(tf.random_normal([self.input_size*2,self.output_size], stddev=0.1)) 91 | 92 | self.jazz_linear_out = tf.reshape(self.jazz_outputs, [tf.shape(self.true_jazz_outputs)[0]*self.seq_len[-1], 2*self.input_size]) 93 | self.jazz_linear_out = tf.matmul(self.jazz_linear_out, self.jazz_W) + self.jazz_B 94 | self.jazz_linear_out = tf.reshape(self.jazz_linear_out,[tf.shape(self.true_jazz_outputs)[0],tf.shape(self.true_jazz_outputs)[1], tf.shape(self.true_jazz_outputs)[2]]) 95 | 96 | self.classical_linear_out = tf.reshape(self.classical_outputs, [tf.shape(self.true_classical_outputs)[0]*self.seq_len[-1], 2*self.input_size]) 97 | self.classical_linear_out = 
tf.matmul(self.classical_linear_out, self.classical_W) + self.classical_B 98 | self.classical_linear_out = tf.reshape(self.classical_linear_out,[tf.shape(self.true_classical_outputs)[0],tf.shape(self.true_classical_outputs)[1], tf.shape(self.true_classical_outputs)[2]]) 99 | 100 | def prepare_unidiretional(self, glorot=True): 101 | print("[*] Preparing unidirectional dynamic RNN...") 102 | self.input_cell = tf.contrib.rnn.LSTMCell(self.input_size, forget_bias=1.0) 103 | self.input_cell = tf.contrib.rnn.DropoutWrapper(self.input_cell, input_keep_prob=self.input_keep_prob, output_keep_prob=self.output_keep_prob) 104 | self.enc_outputs, self.enc_states = tf.nn.dynamic_rnn(self.input_cell, self.inputs, sequence_length=self.seq_len, dtype=tf.float32) 105 | 106 | with tf.variable_scope("encode") as scope: 107 | 108 | self.jazz_cell = tf.contrib.rnn.LSTMCell(self.input_size, forget_bias=1.0) 109 | self.jazz_cell = tf.contrib.rnn.DropoutWrapper(self.jazz_cell, input_keep_prob=self.input_keep_prob, output_keep_prob=self.output_keep_prob) 110 | 111 | self.jazz_outputs, self.jazz_states = tf.nn.dynamic_rnn(self.jazz_cell, self.enc_outputs, sequence_length=self.seq_len, dtype=tf.float32) 112 | 113 | scope.reuse_variables() 114 | 115 | self.classical_cell = tf.contrib.rnn.LSTMCell(self.input_size, forget_bias=1.0) 116 | self.classical_cell = tf.contrib.rnn.DropoutWrapper(self.classical_cell, input_keep_prob=self.input_keep_prob, output_keep_prob=self.output_keep_prob) 117 | self.classical_outputs, self.classical_states = tf.nn.dynamic_rnn(self.classical_cell, self.enc_outputs, sequence_length=self.seq_len, dtype=tf.float32) 118 | 119 | # self.cell = tf.contrib.rnn.DropoutWrapper(self.cell, input_keep_prob=self.input_keep_prob, output_keep_prob=self.output_keep_prob) 120 | # self.stacked_cell = tf.contrib.rnn.MultiRNNCell([self.cell] * self.num_layers) 121 | 122 | self.jazz_B = tf.Variable(tf.random_normal([self.output_size], stddev=0.1)) 123 | self.classical_B = 
tf.Variable(tf.random_normal([self.output_size], stddev=0.1)) 124 | 125 | if glorot: 126 | self.jazz_W = tf.get_variable("jazz_W", shape=[self.input_size, self.output_size],initializer=tf.contrib.layers.xavier_initializer()) 127 | self.classical_W = tf.get_variable("classical_W", shape=[self.input_size, self.output_size],initializer=tf.contrib.layers.xavier_initializer()) 128 | else: 129 | self.jazz_W = tf.Variable(tf.random_normal([self.input_size,self.output_size], stddev=0.1)) 130 | self.classical_W = tf.Variable(tf.random_normal([self.input_size,self.output_size], stddev=0.1)) 131 | 132 | self.jazz_linear_out = tf.reshape(self.jazz_outputs, [tf.shape(self.true_jazz_outputs)[0]*self.seq_len[-1], self.input_size]) 133 | self.jazz_linear_out = tf.matmul(self.jazz_linear_out, self.jazz_W) + self.jazz_B 134 | self.jazz_linear_out = tf.reshape(self.jazz_linear_out,[tf.shape(self.true_jazz_outputs)[0],tf.shape(self.true_jazz_outputs)[1], tf.shape(self.true_jazz_outputs)[2]]) 135 | 136 | self.classical_linear_out = tf.reshape(self.classical_outputs, [tf.shape(self.true_classical_outputs)[0]*self.seq_len[-1], self.input_size]) 137 | self.classical_linear_out = tf.matmul(self.classical_linear_out, self.classical_W) + self.classical_B 138 | self.classical_linear_out = tf.reshape(self.classical_linear_out,[tf.shape(self.true_classical_outputs)[0],tf.shape(self.true_classical_outputs)[1], tf.shape(self.true_classical_outputs)[2]]) 139 | 140 | def prepare_model(self, bi=False): 141 | 142 | self.inputs = tf.placeholder(tf.float32, [None, None, self.input_size]) 143 | 144 | self.true_jazz_outputs = tf.placeholder(tf.float32, [None, None, self.output_size]) 145 | self.true_classical_outputs = tf.placeholder(tf.float32, [None, None, self.output_size]) 146 | 147 | self.seq_len = tf.placeholder(tf.int32, [None]) 148 | 149 | self.input_keep_prob = tf.placeholder(tf.float32, None) 150 | self.output_keep_prob = tf.placeholder(tf.float32, None) 151 | 152 | if self.bi: 153 | 
self.prepare_bidiretional() 154 | else: 155 | self.prepare_unidiretional() 156 | 157 | self.jazz_negation = tf.subtract(self.true_jazz_outputs, self.jazz_linear_out) 158 | self.classical_negation = tf.subtract(self.true_classical_outputs, self.classical_linear_out) 159 | 160 | self.jazz_loss = tf.reduce_mean(tf.square(tf.subtract(self.jazz_linear_out, self.true_jazz_outputs))) 161 | self.classical_loss = tf.reduce_mean(tf.square(tf.subtract(self.classical_linear_out, self.true_classical_outputs))) 162 | 163 | tf.summary.scalar("Jazz error", self.jazz_loss) 164 | tf.summary.scalar("Classical error", self.classical_loss) 165 | tf.summary.scalar("Average error", self.jazz_loss+self.classical_loss/2) 166 | 167 | tf.summary.histogram("Jazz negation", self.jazz_negation) 168 | tf.summary.histogram("Classical negation", self.classical_negation) 169 | 170 | def clip_optimizer(self, learning_rate, loss): 171 | opt = tf.train.AdamOptimizer(learning_rate) 172 | gradients = opt.compute_gradients(loss) 173 | 174 | for i, (grad, var) in enumerate(gradients): 175 | if grad is not None: 176 | gradients[i] = (tf.clip_by_norm(grad, 10), var) 177 | 178 | return opt.apply_gradients(gradients) 179 | 180 | def train(self, data, model=None, starting_epoch=0, clip_grad=True, epochs=1001, input_keep_prob=0.5, output_keep_prob=0.5, learning_rate=0.001 , eval_epoch=20,val_epoch=10, save_epoch=1): 181 | 182 | self.data = data 183 | 184 | if clip_grad: 185 | jazz_optimizer = self.clip_optimizer(learning_rate,self.jazz_loss) 186 | classical_optimizer = self.clip_optimizer(learning_rate,self.classical_loss) 187 | else: 188 | jazz_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.jazz_loss) 189 | classical_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.classical_loss) 190 | 191 | 192 | 193 | self.sess = tf.Session() 194 | 195 | self.c_in_list, self.c_out_list,self.c_input_lens, self.c_files = self.eval_set('classical') 196 | self.j_in_list, 
self.j_out_list,self.j_input_lens, self.j_files = self.eval_set('jazz') 197 | 198 | if model: 199 | self.load(model) 200 | else: 201 | self.sess.run(tf.global_variables_initializer()) 202 | 203 | self.summary_op = tf.summary.merge_all() 204 | 205 | self.train_writer = tf.summary.FileWriter(os.path.join(self.dirs['logs_path'], 'train'), graph=self.sess.graph_def) 206 | self.test_writer = tf.summary.FileWriter(os.path.join(self.dirs['logs_path'], 'test'), graph=self.sess.graph_def) 207 | 208 | classical_batcher = BatchGenerator(self.data["classical"]["X"], self.data["classical"]["Y"], self.batch_count, self.input_size, self.output_size, mini=self.mini) 209 | jazz_batcher = BatchGenerator(self.data["jazz"]["X"], self.data["jazz"]["Y"], self.batch_count, self.input_size, self.output_size, mini=self.mini) 210 | 211 | self.v_classical_batcher = self.validate("classical") 212 | self.v_classical_batcher = self.v_classical_batcher.batch() 213 | 214 | self.v_jazz_batcher = self.validate("jazz") 215 | self.v_jazz_batcher = self.v_jazz_batcher.batch() 216 | 217 | 218 | classical_generator = classical_batcher.batch() 219 | jazz_generator = jazz_batcher.batch() 220 | 221 | print("[*] Initiating training...") 222 | 223 | for epoch in xrange(starting_epoch, epochs): 224 | 225 | classical_epoch_avg = 0 226 | jazz_epoch_avg = 0 227 | 228 | print("[*] Epoch %d" % epoch) 229 | for batch in range(classical_batcher.batch_count): 230 | batch_X, batch_Y, batch_len = classical_generator.next() 231 | batch_len = [batch_len] * len(batch_X) 232 | epoch_error, classical_summary, _ = self.sess.run([self.classical_loss, 233 | self.summary_op, 234 | classical_optimizer, 235 | ], feed_dict={ self.inputs: batch_X, 236 | self.true_classical_outputs: batch_Y, 237 | self.true_jazz_outputs: batch_Y, 238 | self.seq_len: batch_len, 239 | self.input_keep_prob: input_keep_prob, 240 | self.output_keep_prob: output_keep_prob}) 241 | classical_epoch_avg += epoch_error 242 | print("\tBatch %d/%d, Training MSE 
for Classical batch: %.9f" % (batch+1, classical_batcher.batch_count, epoch_error)) 243 | self.train_writer.add_summary(classical_summary, epoch*classical_batcher.batch_count + epoch) 244 | 245 | for batch in range(jazz_batcher.batch_count): 246 | batch_X, batch_Y, batch_len = jazz_generator.next() 247 | batch_len = [batch_len] * len(batch_X) 248 | epoch_error, jazz_summary, _ = self.sess.run([self.jazz_loss, 249 | self.summary_op, 250 | jazz_optimizer, 251 | ], feed_dict={ self.inputs: batch_X, 252 | self.true_jazz_outputs: batch_Y, 253 | self.true_classical_outputs: batch_Y, 254 | self.seq_len: batch_len, 255 | self.input_keep_prob: input_keep_prob, 256 | self.output_keep_prob: output_keep_prob}) 257 | jazz_epoch_avg += epoch_error 258 | print("\tBatch %d/%d, Training MSE for Jazz batch: %.9f" % (batch+1, jazz_batcher.batch_count, epoch_error)) 259 | 260 | self.train_writer.add_summary(jazz_summary, epoch*jazz_batcher.batch_count + epoch) 261 | # self.validation(epoch) 262 | 263 | print("[*] Average Training MSE for Classical epoch %d: %.9f" % (epoch, classical_epoch_avg/classical_batcher.batch_count)) 264 | print("[*] Average Training MSE for Jazz epoch %d: %.9f" % (epoch, jazz_epoch_avg/jazz_batcher.batch_count)) 265 | 266 | if epoch % val_epoch == 0 : 267 | print("[*] Validating model...") 268 | self.validation(epoch) 269 | 270 | if epoch % save_epoch == 0 : 271 | self.save(epoch) 272 | 273 | if epoch % eval_epoch == 0 : 274 | print("[*] Evaluating model...") 275 | self.evaluate(epoch) 276 | 277 | print("[*] Training complete.") 278 | 279 | def load(self, model_name, path=None) : 280 | print(" [*] Loading checkpoint...") 281 | self.saver = tf.train.Saver(max_to_keep=0) 282 | if not path: 283 | self.saver.restore(self.sess, os.path.join(self.dirs['model_path'], model_name)) 284 | else: 285 | self.sess = tf.Session() 286 | self.saver.restore(self.sess, path) 287 | 288 | def save(self, epoch): 289 | print("[*] Saving checkpoint...") 290 | model_name = 
"model-e%d.ckpt" % (epoch) 291 | self.saver = tf.train.Saver(max_to_keep=0) 292 | save_path = self.saver.save(self.sess, os.path.join(self.dirs['model_path'], model_name)) 293 | print("[*] Model saved in file: %s" % save_path) 294 | 295 | def predict(self, input_path, output_path): 296 | in_list = [] 297 | out_list = [] 298 | filenames = [] 299 | input_lens = [] 300 | 301 | loaded = np.load(input_path) 302 | true_vel = np.load(output_path)/127 303 | 304 | in_list.append(loaded) 305 | out_list.append(true_vel) 306 | 307 | input_len = [len(loaded)] 308 | 309 | c_error, c_out, j_out, e_out = self.sess.run([self.classical_loss, self.classical_linear_out, self.jazz_linear_out, self.enc_outputs], feed_dict={self.inputs:in_list, 310 | self.seq_len:input_len, 311 | self.input_keep_prob:1.0, 312 | self.output_keep_prob:1.0, 313 | self.true_classical_outputs:out_list, 314 | self.true_jazz_outputs:out_list}) 315 | 316 | return c_error, c_out, j_out, e_out, out_list 317 | 318 | def validate(self, type): 319 | '''Handles validation set data''' 320 | input_eval_path = os.path.join(self.dirs['eval_path'], "inputs") 321 | vel_eval_path = os.path.join(self.dirs['eval_path'], "velocities") 322 | 323 | c_input_eval_path = os.path.join(input_eval_path, "classical") 324 | c_vel_eval_path = os.path.join(vel_eval_path, "classical") 325 | 326 | j_input_eval_path = os.path.join(input_eval_path, "jazz") 327 | j_vel_eval_path = os.path.join(vel_eval_path, "jazz") 328 | 329 | if type == "classical": 330 | input_folder = os.listdir(c_input_eval_path) 331 | file_count = len(input_folder) 332 | vel_eval_path = c_vel_eval_path 333 | input_eval_path = c_input_eval_path 334 | else: 335 | input_folder = os.listdir(j_input_eval_path) 336 | file_count = len(input_folder) 337 | vel_eval_path = j_vel_eval_path 338 | input_eval_path = j_input_eval_path 339 | #CLASSICS 340 | 341 | in_list = [] 342 | out_list = [] 343 | filenames = [] 344 | for i, filename in enumerate(input_folder): 345 | if 
filename.split('.')[-1] == 'npy': 346 | 347 | vel_path = os.path.join(vel_eval_path, filename) 348 | input_path = os.path.join(input_eval_path, filename) 349 | 350 | true_vel = np.load(vel_path)/127 351 | loaded = np.load(input_path) 352 | 353 | if not self.one_hot: 354 | loaded = loaded/2 355 | 356 | in_list.append(loaded) 357 | out_list.append(true_vel) 358 | filenames.append(filename) 359 | valid_generator = BatchGenerator(in_list, out_list, self.batch_count, self.input_size, self.output_size, mini=False) 360 | return valid_generator 361 | 362 | def validation(self, epoch, pred_save=False): 363 | '''Computes and logs loss of validation set''' 364 | in_list, out_list, input_len = self.v_classical_batcher.next() 365 | input_len = [input_len] * len(in_list) 366 | c_error, c_out, j_out, e_out, c_summary = self.sess.run([self.classical_loss, 367 | self.classical_linear_out, 368 | self.jazz_linear_out, 369 | self.enc_outputs, 370 | self.summary_op], 371 | 372 | feed_dict={self.inputs:in_list, 373 | self.seq_len:input_len, 374 | self.input_keep_prob:1.0, 375 | self.output_keep_prob:1.0, 376 | self.true_classical_outputs:out_list, 377 | self.true_jazz_outputs:out_list}) 378 | 379 | 380 | # for i, x in enumerate(c_out): 381 | # self.plot_evaluation(epoch, c_files[i], c_out[i], j_out[i], e_out[i], out_list[i]) 382 | 383 | in_list, out_list, input_len = self.v_jazz_batcher.next() 384 | input_len = [input_len] * len(in_list) 385 | 386 | j_error, j_out, c_out, e_out, j_summary = self.sess.run([self.jazz_loss, 387 | self.jazz_linear_out, 388 | self.classical_linear_out, 389 | self.enc_outputs, 390 | self.summary_op], 391 | 392 | feed_dict={self.inputs:in_list, 393 | self.seq_len:input_len, 394 | self.input_keep_prob:1.0, 395 | self.output_keep_prob:1.0, 396 | self.true_jazz_outputs:out_list, 397 | self.true_classical_outputs:out_list}) 398 | 399 | 400 | # for i, x in enumerate(c_out): 401 | # self.plot_evaluation(epoch, j_files[i], c_out[i], j_out[i], e_out[i], out_list[i]) 
402 | 403 | # print("[*] Validating Model...") 404 | 405 | print("[*] Average Test MSE for Classical epoch %d: %.9f" % (epoch, c_error)) 406 | print("[*] Average Test MSE for Jazz epoch %d: %.9f" % (epoch, j_error)) 407 | 408 | 409 | self.test_writer.add_summary(j_summary, epoch) 410 | self.test_writer.add_summary(c_summary, epoch) 411 | 412 | def eval_set(self, type): 413 | '''Loads validation set.''' 414 | input_eval_path = os.path.join(self.dirs['eval_path'], "inputs") 415 | vel_eval_path = os.path.join(self.dirs['eval_path'], "velocities") 416 | 417 | c_input_eval_path = os.path.join(input_eval_path, "classical") 418 | c_vel_eval_path = os.path.join(vel_eval_path, "classical") 419 | 420 | j_input_eval_path = os.path.join(input_eval_path, "jazz") 421 | j_vel_eval_path = os.path.join(vel_eval_path, "jazz") 422 | 423 | if type == "classical": 424 | input_folder = os.listdir(c_input_eval_path) 425 | file_count = len(input_folder) 426 | vel_eval_path = c_vel_eval_path 427 | input_eval_path = c_input_eval_path 428 | else: 429 | input_folder = os.listdir(j_input_eval_path) 430 | file_count = len(input_folder) 431 | vel_eval_path = j_vel_eval_path 432 | input_eval_path = j_input_eval_path 433 | #CLASSICS 434 | 435 | in_list = [] 436 | out_list = [] 437 | filenames = [] 438 | input_lens = [] 439 | 440 | for i, filename in enumerate(input_folder): 441 | if filename.split('.')[-1] == 'npy': 442 | 443 | vel_path = os.path.join(vel_eval_path, filename) 444 | input_path = os.path.join(input_eval_path, filename) 445 | 446 | true_vel = np.load(vel_path)/120 447 | loaded = np.load(input_path) 448 | 449 | if not self.one_hot: 450 | loaded = loaded/2 451 | 452 | in_list.append([loaded]) 453 | out_list.append([true_vel]) 454 | filenames.append(filename) 455 | input_len = [len(loaded)] 456 | input_lens.append(input_len) 457 | 458 | return in_list, out_list, input_lens, filenames 459 | 460 | def evaluate(self, epoch, pred_save=False): 461 | '''Performs prediciton and plots results 
on validation set.''' 462 | for i, filename in enumerate(self.c_files): 463 | c_error, c_out, j_out, e_out, summary = self.sess.run([self.classical_loss, 464 | self.classical_linear_out, 465 | self.jazz_linear_out, 466 | self.enc_outputs, 467 | self.summary_op], 468 | 469 | feed_dict={self.inputs:self.c_in_list[i], 470 | self.seq_len:self.c_input_lens[i], 471 | self.input_keep_prob:1.0, 472 | self.output_keep_prob:1.0, 473 | self.true_classical_outputs:self.c_out_list[i], 474 | self.true_jazz_outputs:self.c_out_list[i]}) 475 | 476 | 477 | self.plot_evaluation(epoch, filename, c_out, j_out, e_out, self.c_out_list[i]) 478 | # if pred_save: 479 | # predicted = os.path.join(self.dirs['pred_path'], filename.split('.')[0] + "-e%d" % (epoch)+".npy") 480 | # np.save(predicted, linear[-1]) 481 | 482 | for i, filename in enumerate(self.j_files): 483 | j_error, j_out, c_out, e_out, summary = self.sess.run([self.jazz_loss, 484 | self.jazz_linear_out, 485 | self.classical_linear_out, 486 | self.enc_outputs, 487 | self.summary_op], 488 | 489 | feed_dict={self.inputs:self.j_in_list[i], 490 | self.seq_len:self.j_input_lens[i], 491 | self.input_keep_prob:1.0, 492 | self.output_keep_prob:1.0, 493 | self.true_classical_outputs:self.j_out_list[i], 494 | self.true_jazz_outputs:self.j_out_list[i]}) 495 | 496 | self.plot_evaluation(epoch, filename, c_out, j_out, e_out, self.j_out_list[i]) 497 | # if pred_save: 498 | # predicted = os.path.join(self.dirs['pred_path'], filename.split('.')[0] + "-e%d" % (epoch)+".npy") 499 | # np.save(predicted, linear[-1]) 500 | 501 | 502 | def plot_evaluation(self, epoch, filename, c_out, j_out, e_out, out_list, path=None): 503 | '''Plotting/Saving training session graphs 504 | epoch -- epoch number 505 | c_out -- classical output 506 | j_out -- jazz output 507 | e_out -- interpretation layer output 508 | out_list -- actual output 509 | output_size -- output width 510 | path -- Save path''' 511 | 512 | fig = plt.figure(figsize=(14,11), dpi=120) 513 | 
fig.suptitle(filename, fontsize=10, fontweight='bold') 514 | 515 | graph_items = [out_list[-1]*127, c_out[-1]*127, j_out[-1]*127, (c_out[-1]-j_out[-1])*127 , e_out[-1]] 516 | plots = len(graph_items) 517 | cmap = ['jet', 'jet', 'jet', 'jet', 'bwr'] 518 | vmin = [0,0,0,-10,-1] 519 | vmax = [127,127,127,10,1] 520 | names = ["Actual", "Classical", "Jazz", "Difference", "Encoded"] 521 | 522 | 523 | for i in xrange(0, plots): 524 | fig.add_subplot(1,plots,i+1) 525 | plt.imshow(graph_items[i], vmin=vmin[i], vmax=vmax[i], cmap=cmap[i], aspect='auto') 526 | 527 | a = plt.colorbar(aspect=80) 528 | a.ax.tick_params(labelsize=7) 529 | ax = plt.gca() 530 | ax.xaxis.tick_top() 531 | 532 | if i == 0: 533 | ax.set_ylabel('Time Step') 534 | ax.xaxis.set_label_position('top') 535 | ax.tick_params(axis='both', labelsize=7) 536 | fig.subplots_adjust(top=0.85) 537 | ax.set_title(names[i], y=1.09) 538 | # plt.tight_layout() 539 | 540 | if self.one_hot: 541 | plt.xlim(0,88) 542 | else: 543 | plt.xlim(0,128) 544 | 545 | #Don't show the figure and save it 546 | if not path: 547 | out_png = os.path.join(self.dirs['png_path'], filename.split('.')[0] + "-e%d" % (epoch)+".png") 548 | plt.savefig(out_png, bbox_inches='tight') 549 | plt.close(fig) 550 | else: 551 | # out_png = os.path.join(self.dirs['png_path'], filename.split('.')[0] + "-e%d" % (epoch)+".png") 552 | # plt.savefig(out_png, bbox_inches='tight') 553 | # plt.close(fig) 554 | plt.show() 555 | plt.close(fig) 556 | -------------------------------------------------------------------------------- /pianoify.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 13, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pretty_midi\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import os\n", 14 | "from midi_util import quantize\n", 15 | "from mido import MidiFile\n", 16 | "from 
midi_util import velocity_range, quantize\n", 17 | "from random import shuffle" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 14, 23 | "metadata": { 24 | "collapsed": false 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "mid_path = '/Users/Iman/Desktop/jazz'\n", 29 | "out_path = '/Users/Iman/Desktop/jazz_out'\n", 30 | "\n", 31 | "if not os.path.exists(out_path):\n", 32 | " os.makedirs(out_path)\n", 33 | "\n", 34 | "total = len(os.listdir(mid_path))\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 15, 40 | "metadata": { 41 | "collapsed": false 42 | }, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "4thAvenueTheme.mid\n", 49 | "0 / 349\n", 50 | "4thAvenueTheme.mid\n", 51 | "A Sleepin' Bee.mid\n", 52 | "1 / 349\n", 53 | "A Sleepin' Bee.mid\n", 54 | "accustomed.mid\n", 55 | "2 / 349\n", 56 | "accustomed.mid\n", 57 | "afine-1.mid\n", 58 | "3 / 349\n", 59 | "afine-1.mid\n", 60 | "afine-2.mid\n", 61 | "4 / 349\n", 62 | "afine-2.mid\n", 63 | "Aghostofachance.mid\n", 64 | "5 / 349\n", 65 | "Aghostofachance.mid\n", 66 | "AHouseis.mid\n", 67 | "6 / 349\n", 68 | "AHouseis.mid\n", 69 | "Alabama.mid\n", 70 | "7 / 349\n", 71 | "Alabama.mid\n", 72 | "alfiepno.mid\n", 73 | "8 / 349\n", 74 | "alfiepno.mid\n", 75 | "Aliceinw.mid\n", 76 | "9 / 349\n", 77 | "Aliceinw.mid\n", 78 | "All The Things You Are.mid\n", 79 | "10 / 349\n", 80 | "All The Things You Are.mid\n", 81 | "AllOfMe2.mid\n", 82 | "11 / 349\n", 83 | "AllOfMe2.mid\n", 84 | "alltheth.mid\n", 85 | "12 / 349\n", 86 | "alltheth.mid\n" 87 | ] 88 | }, 89 | { 90 | "ename": "KeyboardInterrupt", 91 | "evalue": "", 92 | "output_type": "error", 93 | "traceback": [ 94 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 95 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 96 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 
22\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mmid_q\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquantize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mmid_q\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 97 | "\u001b[0;32m/Users/Iman/research/programs/genre-lstm/midi_util.py\u001b[0m in \u001b[0;36mquantize\u001b[0;34m(mid, quantization)\u001b[0m\n\u001b[1;32m 250\u001b[0m 1/2**quantization.'''\n\u001b[1;32m 251\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 252\u001b[0;31m \u001b[0mquantized_mid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdeepcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 253\u001b[0m \u001b[0;31m# By convention, Track 0 contains metadata and Track 1 contains\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0;31m# the note on and note off events.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 98 | "\u001b[0;32m/usr/local/Cellar/python/2.7.12/Frameworks/Python.framework/Versions/2.7/lib/python2.7/copy.pyc\u001b[0m in \u001b[0;36mdeepcopy\u001b[0;34m(x, memo, _nil)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0mcopier\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_deepcopy_dispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcopier\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 
163\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcopier\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 99 | "\u001b[0;32m/usr/local/Cellar/python/2.7.12/Frameworks/Python.framework/Versions/2.7/lib/python2.7/copy.pyc\u001b[0m in \u001b[0;36m_deepcopy_inst\u001b[0;34m(x, memo)\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 297\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dict__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 298\u001b[0;31m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeepcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 299\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'__setstate__'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 300\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__setstate__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 100 | "\u001b[0;32m/usr/local/Cellar/python/2.7.12/Frameworks/Python.framework/Versions/2.7/lib/python2.7/copy.pyc\u001b[0m in \u001b[0;36mdeepcopy\u001b[0;34m(x, memo, _nil)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0mcopier\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0m_deepcopy_dispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcopier\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcopier\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 101 | "\u001b[0;32m/usr/local/Cellar/python/2.7.12/Frameworks/Python.framework/Versions/2.7/lib/python2.7/copy.pyc\u001b[0m in \u001b[0;36m_deepcopy_dict\u001b[0;34m(x, memo)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 256\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miteritems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 257\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdeepcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeepcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 258\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 259\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_deepcopy_dict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 102 | "\u001b[0;32m/usr/local/Cellar/python/2.7.12/Frameworks/Python.framework/Versions/2.7/lib/python2.7/copy.pyc\u001b[0m in \u001b[0;36mdeepcopy\u001b[0;34m(x, memo, _nil)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0mcopier\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_deepcopy_dispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcopier\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcopier\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 103 | "\u001b[0;32m/usr/local/Cellar/python/2.7.12/Frameworks/Python.framework/Versions/2.7/lib/python2.7/copy.pyc\u001b[0m in \u001b[0;36m_deepcopy_list\u001b[0;34m(x, memo)\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 229\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 230\u001b[0;31m 
\u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdeepcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 231\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_deepcopy_list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 104 | "\u001b[0;32m/usr/local/Cellar/python/2.7.12/Frameworks/Python.framework/Versions/2.7/lib/python2.7/copy.pyc\u001b[0m in \u001b[0;36mdeepcopy\u001b[0;34m(x, memo, _nil)\u001b[0m\n\u001b[1;32m 188\u001b[0m raise Error(\n\u001b[1;32m 189\u001b[0m \"un(deep)copyable object of type %s\" % cls)\n\u001b[0;32m--> 190\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_reconstruct\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 191\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 105 | "\u001b[0;32m/usr/local/Cellar/python/2.7.12/Frameworks/Python.framework/Versions/2.7/lib/python2.7/copy.pyc\u001b[0m in \u001b[0;36m_reconstruct\u001b[0;34m(x, info, deep, memo)\u001b[0m\n\u001b[1;32m 349\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mitem\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlistiter\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdeep\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 
351\u001b[0;31m \u001b[0mitem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeepcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 352\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdictiter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 106 | "\u001b[0;32m/usr/local/Cellar/python/2.7.12/Frameworks/Python.framework/Versions/2.7/lib/python2.7/copy.pyc\u001b[0m in \u001b[0;36mdeepcopy\u001b[0;34m(x, memo, _nil)\u001b[0m\n\u001b[1;32m 188\u001b[0m raise Error(\n\u001b[1;32m 189\u001b[0m \"un(deep)copyable object of type %s\" % cls)\n\u001b[0;32m--> 190\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_reconstruct\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 191\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 107 | "\u001b[0;32m/usr/local/Cellar/python/2.7.12/Frameworks/Python.framework/Versions/2.7/lib/python2.7/copy.pyc\u001b[0m in \u001b[0;36m_reconstruct\u001b[0;34m(x, info, deep, memo)\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[0mdictiter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 327\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdeep\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 
328\u001b[0;31m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeepcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 329\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 108 | "\u001b[0;32m/usr/local/Cellar/python/2.7.12/Frameworks/Python.framework/Versions/2.7/lib/python2.7/copy.pyc\u001b[0m in \u001b[0;36mdeepcopy\u001b[0;34m(x, memo, _nil)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0mcopier\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_deepcopy_dispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcopier\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcopier\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 109 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "for i , filename in enumerate(os.listdir(mid_path)):\n", 115 | " print filename\n", 116 | " if filename.split('.')[-1] == 'mid' or filename.split('.')[-1] == 'MID' :\n", 
117 | " print \"%d / %d\" % (i,total)\n", 118 | " print filename \n", 119 | " try:\n", 120 | " midi_data = pretty_midi.PrettyMIDI(os.path.join(mid_path, filename))\n", 121 | " mid = MidiFile(os.path.join(mid_path, filename))\n", 122 | " except (KeyError, IOError, IndexError, EOFError, ValueError):\n", 123 | " print \"NAUGHTY\"\n", 124 | " continue\n", 125 | "\n", 126 | " time_sig_msgs = [ msg for msg in mid.tracks[0] if msg.type == 'time_signature' ]\n", 127 | " \n", 128 | " if len(time_sig_msgs) == 1:\n", 129 | " time_sig = time_sig_msgs[0]\n", 130 | " if not (time_sig.numerator == 4 and time_sig.denominator == 4):\n", 131 | " print '\\tTime signature not 4/4. Skipping ...'\n", 132 | " continue\n", 133 | " else:\n", 134 | " print '\\tNo time signature. Skipping ...'\n", 135 | " continue\n", 136 | " \n", 137 | " mid_q = quantize(mid, 4)\n", 138 | " \n", 139 | " if not mid_q:\n", 140 | " print 'Invalid MIDI. Skipping...'\n", 141 | " continue\n", 142 | "\n", 143 | "\n", 144 | " piano = [instrument for instrument in midi_data.instruments if instrument.program < 8 ]\n", 145 | " piano = [instrument for instrument in piano if not instrument.is_drum ]\n", 146 | "\n", 147 | " if len(piano) > 0 and len(piano) < 3:\n", 148 | " for x in piano:\n", 149 | " x.program = 0\n", 150 | "\n", 151 | " midi_data.instruments = piano\n", 152 | " midi_data.write(os.path.join(out_path, filename))\n", 153 | " else:\n", 154 | " print '\\tNO piano.'\n" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "collapsed": true 162 | }, 163 | "outputs": [], 164 | "source": [] 165 | } 166 | ], 167 | "metadata": { 168 | "kernelspec": { 169 | "display_name": "Python 2", 170 | "language": "python", 171 | "name": "python2" 172 | }, 173 | "language_info": { 174 | "codemirror_mode": { 175 | "name": "ipython", 176 | "version": 2 177 | }, 178 | "file_extension": ".py", 179 | "mimetype": "text/x-python", 180 | "name": "python", 181 | 
"nbconvert_exporter": "python", 182 | "pygments_lexer": "ipython2", 183 | "version": "2.7.12" 184 | } 185 | }, 186 | "nbformat": 4, 187 | "nbformat_minor": 2 188 | } 189 | --------------------------------------------------------------------------------