├── .gitattributes ├── README.md ├── audio_to_midi_paper.pdf ├── benchmark.py ├── check_robust.py ├── cqt.py ├── create_dataset.py ├── decode_midi.py ├── encode_midi_segments.py ├── exploratory_visualization.py ├── get_model_prediction.py ├── handle_complex_nums.py ├── model_and_visualizations.1363 ├── .1363.png ├── .1363mae.png ├── .1363r2.png └── weights-improvement-38-0.1363.hdf5 ├── models.py ├── non_assert_tests.py └── normalisation.py /.gitattributes: -------------------------------------------------------------------------------- 1 | cqt_segments_midi_segments.pkl filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # audio_to_midi 2 | A Convolutional Neural Network (CNN) which converts piano audio to a simplified MIDI format. 3 | The final model takes an audio file (mono- or polyphonic) as input 4 | and outputs a simplified MIDI representation with the corresponding notes and 5 | note durations. This output can then be reconstructed into a standard MIDI file. 6 | 7 | #### Main objective 8 | The automated conversion performed by the CNN is a step toward the larger 9 | goal of Automatic Music Transcription (AMT). AMT and Music Information Retrieval have many applications in industry, including Digital 10 | Audio Workstation software development and music recommendation systems. 11 | 12 | #### Setup 13 | 14 | To get started, download the data from the Saarland Music Dataset (SMD), putting the audio in one directory named "audio" and 15 | the MIDI in another directory called "midi": 16 | 17 | ``` 18 | mkdir audio 19 | wget "http://resources.mpi-inf.mpg.de/SMD/SMD_MIDI-Audio-Piano-Music.html" -e robots=off -r -l1 -nd --no-parent -A.mp3 -P audio 20 | mkdir midi 21 | wget "http://resources.mpi-inf.mpg.de/SMD/SMD_MIDI-Audio-Piano-Music.html" -e robots=off -r -l1 -nd --no-parent -A.mid -P midi 22 | ``` 23 | 24 | In the ```main``` function of create_dataset.py, follow the instructions to set directory_str to the path of the 25 | directory that contains the "audio" and "midi" directories. 26 | 27 | Run create_dataset.py. 28 | 29 | The file models.py contains the code for the final model.
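As a rough sketch of the end-to-end run (the audio_and_midi/ path is only an example; cqt_segments_midi_segments.pkl and model_checkpoints/ are the artifacts the scripts write):

```
audio_and_midi/            # example value for directory_str in create_dataset.py
├── audio/                 # SMD .mp3 files
└── midi/                  # SMD .mid files

python create_dataset.py   # segments the data and writes cqt_segments_midi_segments.pkl to the project directory
mkdir model_checkpoints    # models.py saves its best checkpoints here (create it if it does not already exist)
python models.py           # trains the CNN, evaluates on the held-out test split, and plots the training curves
```

Once a checkpoint exists, a prediction can be made along the lines of get_model_prediction.py. The snippet below is a minimal sketch using the checkpoint committed in model_and_visualizations.1363; any checkpoint written by models.py can be substituted:

```
from keras.models import load_model
from models import pickle_if_not_pickled, reshape_for_conv2d, root_mse, r2_coeff_determination

# the custom loss/metric functions have to be passed back in when loading the model
model = load_model("model_and_visualizations.1363/weights-improvement-38-0.1363.hdf5",
                   custom_objects={'root_mse': root_mse, 'r2_coeff_determination': r2_coeff_determination})

cqt_segments, midi_segments = pickle_if_not_pickled()
cqt, midi = reshape_for_conv2d(cqt_segments, midi_segments)
midi_pred = model.predict(cqt[:1])  # flattened piano-roll-style prediction for one segment
print(midi_pred.shape)
```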
30 | 31 | #### Libraries 32 | 33 | TensorFlow install instructions: 34 | https://www.tensorflow.org/install/ 35 | 36 | Keras install instructions: 37 | https://keras.io/#installation 38 | 39 | collections 40 | keras (see instructions above for dependencies) 41 | librosa 42 | math 43 | matplotlib 44 | mido 45 | ntpath 46 | numpy 47 | os 48 | pickle 49 | random 50 | scikit-learn 51 | shutil 52 | tensorflow (see instructions above for dependencies) 53 | time 54 | 55 | 56 | #### Issues 57 | On Windows, when loading in the audio files with librosa, the following error may arise: 58 | > raise NoBackendError() 59 | > audioread.NoBackendError 60 | 61 | If the above error is raised, try installing FFmpeg: 62 | https://www.wikihow.com/Install-FFmpeg-on-Windows 63 | 64 | 65 | On Linux, the equivalent fix is: 66 | ```sudo apt-get install libav-tools``` 67 | 68 | #### Copyright 69 | Copyright (c) 2018 Lillian Neff -------------------------------------------------------------------------------- /audio_to_midi_paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartmetzls/audio_to_midi/e28f5e67f49d2a0079b5dab31f861bc10994e4ce/audio_to_midi_paper.pdf -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import random 2 | random.seed(21) 3 | from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score 4 | import numpy as np 5 | np.random.seed(21) 6 | from models import reshape_for_conv2d 7 | from math import sqrt 8 | from models import pickle_if_not_pickled 9 | 10 | def benchmark(): 11 | cqt_segments, midi_segments = pickle_if_not_pickled() 12 | 13 | # flatten midi 14 | cqt_segments_reshaped, midi_segments_reshaped = reshape_for_conv2d(cqt_segments, midi_segments) 15 | midi_segments_reshaped_benchmark_pred = np.array(midi_segments_reshaped) 16 | np.random.shuffle(midi_segments_reshaped_benchmark_pred) 17 | 18 | mse_loss = mean_squared_error(midi_segments_reshaped, midi_segments_reshaped_benchmark_pred) 19 | rmse = sqrt(mse_loss) 20 | print("rmse:") 21 | print(rmse) 22 | 23 | mae = mean_absolute_error(midi_segments_reshaped, midi_segments_reshaped_benchmark_pred) 24 | print("mae:") 25 | print(mae) 26 | 27 | r2 = r2_score(midi_segments_reshaped, midi_segments_reshaped_benchmark_pred) 28 | print("r2:") 29 | print(r2) 30 | 31 | def main(): 32 | benchmark() 33 | 34 | if __name__ == '__main__': 35 | main() -------------------------------------------------------------------------------- /check_robust.py: -------------------------------------------------------------------------------- 1 | from keras.models import load_model 2 | from models import pickle_if_not_pickled, root_mse, reshape_for_conv2d, r2_coeff_determination 3 | from sklearn.model_selection import KFold, train_test_split 4 | # no GPU support for sklearn's cross_val_score 5 | from create_dataset import done_beep 6 | import matplotlib.pyplot as plt 7 | from models import create_model 8 | import numpy as np 9 | from keras.callbacks import ModelCheckpoint 10 | 11 | def k_fold_cv(): 12 | """ Check the robustness of the model using K-Fold CV """ 13 | cqt_segments, midi_segments = pickle_if_not_pickled() 14 | cqt_segments_reshaped, midi_segments_reshaped = reshape_for_conv2d(cqt_segments, midi_segments) 15 | k_folds = 5 16 | 17 | # The goal in this block is to set aside for testing a chunk of data which contains at least 18 | # one whole song 
the network will not have seen before. 19 | # Num data points to set aside for testing 20 | num_samples = len(cqt_segments_reshaped) 21 | num_testing = int(num_samples * .2) 22 | max_desired_test_start_index = num_samples - num_testing - 1 # - 1 is necessary bc of the way the remaining data gets sliced 23 | test_set_start_index = np.random.randint(0, max_desired_test_start_index) 24 | testing_end_index = test_set_start_index+num_testing 25 | cqt_test, midi_test = cqt_segments_reshaped[test_set_start_index:testing_end_index], \ 26 | midi_segments_reshaped[test_set_start_index:testing_end_index] 27 | 28 | # remaining data 29 | cqt_train_and_valid = np.concatenate( 30 | (cqt_segments_reshaped[:test_set_start_index], cqt_segments_reshaped[testing_end_index:]), axis=0) 31 | midi_train_and_valid = np.concatenate( 32 | (midi_segments_reshaped[:test_set_start_index], midi_segments_reshaped[testing_end_index:]), axis=0) 33 | 34 | # Generate a random order of elements 35 | # with np.random.permutation and index into the arrays data and classes with those elements 36 | 37 | # shuffle the remaining data: 38 | indices = np.random.permutation(len(cqt_train_and_valid)) 39 | data, labels = cqt_train_and_valid[indices], midi_train_and_valid[indices] 40 | 41 | k_fold = KFold(n_splits=k_folds) # Provides train/test indices to split data in train/test sets. 42 | for train_indices, valid_indices in k_fold.split(cqt_train_and_valid): 43 | print('Train: %s | valid: %s' % (train_indices, valid_indices)) 44 | example_cqt_segment = cqt_train_and_valid[0] 45 | input_height, input_width, input_depth = example_cqt_segment.shape 46 | example_midi_segment = midi_train_and_valid[0] 47 | one_d_array_len = len(example_midi_segment) 48 | 49 | for i, (train, valid) in enumerate(k_fold.split(cqt_train_and_valid)): 50 | print("Running Fold", i + 1, "/", k_folds) 51 | 52 | model = create_model(input_height, input_width, one_d_array_len) 53 | # saving time (best models have reached best val_score before epoch 40) 54 | epochs = 50 55 | filepath = "model_checkpoints/weights-improvement-{epoch:02d}-{val_loss:.4f}.hdf5" 56 | checkpointer = ModelCheckpoint(filepath=filepath, monitor='val_loss', 57 | verbose=1, save_best_only=True, save_weights_only=False) 58 | history_for_plotting = model.fit( 59 | data[train], labels[train], epochs=epochs, verbose=2, validation_data=(data[valid], labels[valid]), 60 | callbacks=[checkpointer]) 61 | test_score = model.evaluate(cqt_test, midi_test, verbose=0) 62 | print("test score:") 63 | print("[loss (rmse), root_mse, mae, r2_coeff_determination]") 64 | print(test_score) 65 | 66 | done_beep() 67 | 68 | # summarize history for loss 69 | # https://machinelearningmastery.com/display-deep-learning-model-training-history-in-keras/ 70 | plt.plot(history_for_plotting.history['loss']) 71 | plt.plot(history_for_plotting.history['val_loss']) 72 | plt.title('model loss') 73 | plt.ylabel('loss') 74 | plt.xlabel('epoch') 75 | plt.legend(['train rmse', 'validation rmse'], loc='upper right') 76 | plt.show() 77 | 78 | def main(): 79 | k_fold_cv() 80 | 81 | if __name__ == '__main__': 82 | main() -------------------------------------------------------------------------------- /cqt.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | 4 | def audio_segments_cqt(audio_segment_time_series, sr): 5 | cqt_of_segment = librosa.cqt(audio_segment_time_series, sr=sr) 6 | # Remove the imaginary number part of the CQT (ie We don't care where in the wave 
we are when we start reading it) 7 | cqt_of_segment_real = cqt_of_segment.real 8 | cqt_of_segment_copy_real = np.array(cqt_of_segment_real, dtype='float32') 9 | return cqt_of_segment_copy_real 10 | 11 | def main(): 12 | audio_segments_cqt() 13 | 14 | if __name__ == '__main__': 15 | main() -------------------------------------------------------------------------------- /create_dataset.py: -------------------------------------------------------------------------------- 1 | from mido import Message, MidiTrack, MidiFile 2 | import time 3 | from collections import defaultdict 4 | from cqt import * 5 | from encode_midi_segments import * 6 | from decode_midi import * 7 | from collections import namedtuple 8 | import pickle 9 | import random 10 | random.seed(21) 11 | import ntpath 12 | 13 | # data types 14 | AudioDecoded = namedtuple('AudioDecoded', ['time_series', 'sr']) 15 | 16 | def find_audio_files(directory_str): 17 | directory_str_audio = directory_str + "audio" 18 | audio_files = librosa.util.find_files(directory_str_audio, recurse=False, case_sensitive=True) #recurse False 19 | # means subfolders are not searched; case_sensitive True ultimately keeps songs from being 20 | # listed twice 21 | return audio_files 22 | 23 | def load_audio(audio_file): 24 | time_series, sr = librosa.core.load(audio_file, sr=22050*4) 25 | time_series_and_sr = AudioDecoded(time_series, sr) 26 | return time_series_and_sr 27 | 28 | def audio_timestamps(audio_file, time_series_and_sr): 29 | sr = time_series_and_sr[1] 30 | duration_song = librosa.core.get_duration(y=time_series_and_sr[0], sr=sr) 31 | audio_segment_length = 1 32 | midi_segment_length = 0.5 33 | padding = midi_segment_length / 2 34 | audio_start_times_missing_first = np.arange( 35 | padding,(duration_song-midi_segment_length), midi_segment_length) 36 | # For ex, 37 | # this will mean that for a song of len 157 seconds, the last needed MIDI file would be 156.5-157, 38 | # so the last audio start time should be 156.25. In this example 39 | # duration_song minus midi_segment_length is 156.5 40 | audio_start_times = np.concatenate((np.asarray([0]), audio_start_times_missing_first), 41 | axis=0) 42 | return audio_start_times, audio_segment_length, midi_segment_length 43 | 44 | def silent_audio(sr, padding): 45 | zeros = int(sr * padding) 46 | audio = np.zeros((zeros)) 47 | maxv = np.iinfo(np.int16).max #numpy.iinfo(type): Machine limits for integer types. 
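# Note: the array is all zeros, so multiplying by the int16 maximum below still yields silence;
# the cast just expresses the padding at 16-bit amplitude scale before it is concatenated with
# the audio loaded in load_segment_of_audio_and_save.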
48 | silent_audio_time_series = (audio * maxv).astype(np.int16) 49 | return silent_audio_time_series 50 | 51 | # padding_portion provides the option to set padding to zero, so that if a better temporally aligned dataset is used 52 | # in the future, the code can segmented without padding 53 | def load_segment_of_audio_and_save(audio_file, start, audio_segment_length, midi_segment_length, duration_song, sr, 54 | padding_portion=.5): 55 | padding = midi_segment_length * padding_portion 56 | # If you don't want to pad the audio segments, comment out the following code and move the 57 | # line inside the else condition out of the else condition 58 | if start == 0: 59 | segment_duration = (audio_segment_length - padding) 60 | audio_segment_time_series_og, sr = librosa.core.load(audio_file, offset=0, 61 | duration=segment_duration, sr=sr) 62 | silence = silent_audio(sr, padding) 63 | audio_segment_time_series = np.concatenate((silence, audio_segment_time_series_og), axis=0) 64 | elif (duration_song - start) < audio_segment_length: #if we're at the end of a song and need 65 | # padding 66 | segment_duration = (duration_song - start) 67 | audio_segment_time_series_og, sr = librosa.core.load(audio_file, offset=start, 68 | duration=segment_duration, sr=sr) 69 | audio_segment_time_series = np.concatenate( 70 | (audio_segment_time_series_og, silent_audio(sr, (audio_segment_length-segment_duration))),axis=0) 71 | else: 72 | audio_segment_time_series, sr = librosa.core.load(audio_file, offset=start, 73 | duration=audio_segment_length, sr=sr) 74 | filename_format = "C:/Users/Lilly/audio_and_midi/segments/audio/{0}_start_time_{1}.wav" 75 | filename = filename_format.format(ntpath.basename(audio_file)[:-4], str(start)) 76 | 77 | # for testing by listening to audio (currently written for windows) 78 | # librosa.output.write_wav(filename, audio_segment_time_series, sr) 79 | 80 | return audio_segment_time_series 81 | 82 | def load_midi(directory_str, audio_file): 83 | midi_base_str = directory_str + "midi/" 84 | midi_str = midi_base_str + ntpath.basename(audio_file)[:-4] + ".mid" 85 | midi_file = MidiFile(midi_str) 86 | return midi_file 87 | 88 | def create_simplified_midi(midi_file): 89 | simplified_midi = [] 90 | ticks_since_start = 0 91 | 92 | # debugging/for future use with a dataset other than Saarland Music Dataset: 93 | # check for note offs without note ons and check control changes in song 94 | print("midi file:", midi_file) 95 | count_ons = 0 96 | count_offs = 0 97 | control_changes = [] 98 | 99 | for message in midi_file.tracks[0]: 100 | if message.type == "set_tempo": 101 | tempo_in_microsecs_per_beat = message.tempo 102 | for message in midi_file.tracks[-1]: 103 | # convert delta time (delta ticks) to TICKS since start 104 | ticks_since_start += message.time 105 | if message.type == "note_on" or message.type == "note_off": 106 | simplified_midi.append([message.type, message.note, ticks_since_start]) 107 | 108 | # debugging/for future use with a dataset other than SMD 109 | if message.type == "note_on": 110 | count_ons += 1 111 | else: 112 | count_offs += 1 113 | 114 | # debugging/for future use with a dataset other than SMD: check control changes in song; check for 123 115 | if message.type == 'control_change': 116 | control_changes.append(message.control) 117 | if message.control == 123: 118 | print("CONTROL ALL NOTES OFF MESSAGE!") 119 | 120 | # debugging/for future use with a dataset other than SMD 121 | if count_ons != count_offs: 122 | print("INEQUAL NUM ONS AND OFFS (prior to simplified 
midi)") 123 | print("count ons:", count_ons) 124 | print("offs:", count_offs) 125 | 126 | # convert ticks since start to absolute seconds 127 | tempo_in_secs_per_beat = tempo_in_microsecs_per_beat / 1000000 128 | ticks_per_beat = midi_file.ticks_per_beat 129 | secs_per_tick = tempo_in_secs_per_beat / ticks_per_beat 130 | length_in_secs_full_song = ticks_since_start * secs_per_tick 131 | for message in simplified_midi: 132 | message[-1] = message[-1] / ticks_since_start * length_in_secs_full_song 133 | 134 | # For comparison 135 | midi_length = midi_file.length 136 | # print("midi_file.length:", midi_file.length) 137 | # print("length based on last note off:", length_in_secs_full_song) 138 | 139 | # debugging/for future use with a dataset other than SMD 140 | print(set(control_changes)) 141 | 142 | return simplified_midi, ticks_since_start, length_in_secs_full_song 143 | 144 | def chop_simplified_midi(midi_file, midi_segment_length, simplified_midi, absolute_ticks_last_note, midi_start_times): 145 | 146 | # debugging/for future use with a dataset other than SMD: check for note ons and offs equal 147 | count_note_on = 0 148 | count_note_off = 0 149 | for message in simplified_midi: 150 | message_type = message[0] 151 | if message_type == 'note_on': 152 | count_note_on += 1 153 | if message_type == 'note_off': 154 | count_note_off += 1 155 | if count_note_on != count_note_off: 156 | print("midi file:", midi_file) 157 | print("inequal num ons and offs") 158 | 159 | # erase a redundant note_on or a redundant note off of the same pitch 160 | pitches_on = [] 161 | messages_to_erase = [] 162 | for message in simplified_midi: 163 | message_type = message[0] 164 | pitch = message[1] 165 | if message_type == 'note_on': 166 | if pitch in pitches_on: 167 | messages_to_erase.append(message) 168 | else: 169 | pitches_on.append(pitch) 170 | elif message_type == 'note_off': 171 | if pitch in pitches_on: 172 | pitches_on.remove(pitch) 173 | else: 174 | print("redundant off message found") 175 | messages_to_erase.append(message) 176 | for message in messages_to_erase: 177 | simplified_midi.remove(message) 178 | 179 | time_so_far = 0 180 | midi_segments =[] 181 | for midi_start_time in midi_start_times: 182 | midi_segment = [] 183 | for message in simplified_midi: 184 | end_time = time_so_far + midi_segment_length 185 | if message[-1] >= time_so_far and message[-1] < end_time: 186 | midi_segment.append(message) 187 | midi_segments.append([midi_start_time, midi_segment]) 188 | time_so_far = end_time 189 | return midi_segments, absolute_ticks_last_note 190 | 191 | def add_note_onsets_to_beginning_when_needed(midi_segments, midi_segment_length): 192 | # account for notes which start before a segment and end after a segment. For ex, a note which is 193 | # from .25 to 1.25. 
Without this special case accounting, this note would not be 194 | # present in the simplified MIDI clip 195 | pitches_to_set_to_on_at_beginning_of_segment = [] 196 | for start_time_and_messages in midi_segments: 197 | start_time = start_time_and_messages[0] 198 | messages = start_time_and_messages[1] 199 | for pitch in pitches_to_set_to_on_at_beginning_of_segment: 200 | messages.insert(0, ["note_on", pitch, start_time]) 201 | pitches_to_set_to_on_at_beginning_of_segment = [] 202 | # goal is to build a dict of dicts like so: {pitch: {num_ons: 4, num_offs: 3}} 203 | pitch_on_and_off_counts = defaultdict(lambda: defaultdict(int)) 204 | 205 | for message in messages: 206 | pitch = message[1] 207 | message_type = message[0] 208 | pitch_on_and_off_counts[pitch][message_type] += 1 209 | 210 | for pitch, on_off_counts_dict in pitch_on_and_off_counts.items(): 211 | count_on = on_off_counts_dict["note_on"] 212 | count_off = on_off_counts_dict["note_off"] 213 | if count_on > count_off: 214 | # this includes the "exclusive ending" (ie This will insert an end time of 215 | # .5, when technically it should be just under .5) 216 | pitches_to_set_to_on_at_beginning_of_segment.append(pitch) 217 | end_time = start_time + midi_segment_length 218 | messages.append(["note_off", pitch, end_time]) 219 | return midi_segments 220 | 221 | def find_lowest_and_highest_midi_note_numbers(simplified_midi): 222 | lowest_note_number_so_far = 127 223 | highest_note_number_so_far = 0 224 | for message in simplified_midi: 225 | note_number = message[1] 226 | if note_number < lowest_note_number_so_far: 227 | lowest_note_number_so_far = note_number 228 | if note_number > highest_note_number_so_far: 229 | highest_note_number_so_far = note_number 230 | return lowest_note_number_so_far, highest_note_number_so_far 231 | 232 | def reconstruct_midi(midi_filename, midi_segments, absolute_ticks_last_note, length_in_secs_full_song): 233 | time_so_far = 0 234 | for midi_segment in midi_segments: 235 | 236 | # time in seconds to absolute ticks 237 | absolute_ticks_midi_segment = [] 238 | start_time = midi_segment[0] 239 | messages = midi_segment[1] 240 | for message in messages: 241 | note_on_or_off = message[0] 242 | pitch = message[1] 243 | scaled_time = message[-1] 244 | time = scaled_time + start_time 245 | absolute_ticks = time * absolute_ticks_last_note / length_in_secs_full_song 246 | absolute_ticks_midi_segment.append([note_on_or_off, pitch, absolute_ticks]) 247 | # time in absolute ticks to delta time 248 | delta_time_midi_segment = [] 249 | for message in absolute_ticks_midi_segment: 250 | note_on_or_off = message[0] 251 | pitch = message[1] 252 | time = message[-1] 253 | delta_time = int(time - time_so_far) 254 | delta_time_midi_segment.append([note_on_or_off, pitch, delta_time]) 255 | time_so_far = time 256 | 257 | mid = MidiFile() 258 | track = MidiTrack() 259 | mid.tracks.append(track) 260 | for message in delta_time_midi_segment: 261 | note_on_or_off = message[0] 262 | pitch = int(message[1]) 263 | delta_ticks = message[-1] 264 | 265 | # debugging/for future use with a dataset other than SMD 266 | if type(delta_ticks) != int or delta_ticks < 0: 267 | print("time issue") 268 | 269 | track.append(Message(note_on_or_off, note=pitch, time=delta_ticks)) 270 | 271 | # for testing by listening to midi (currently written for windows) 272 | # str_start_time = str(midi_segment[0]) 273 | # filename_format = "C:/Users/Lilly/audio_and_midi/segments/midi/{0}_start_time_{1}.mid" 274 | # filename = filename_format.format(midi_filename, 
str_start_time) 275 | # mid.save(filename) 276 | 277 | return 278 | 279 | def done_beep(): 280 | import os 281 | # Windows 282 | if os.name == 'nt': 283 | import winsound 284 | duration = 1500 # millisecond 285 | freq = 392 # Hz 286 | winsound.Beep(freq, duration) 287 | # Linux 288 | if os.name == 'posix': 289 | duration = 1 # second 290 | freq = 392 # Hz 291 | os.system('play --no-show-progress --null --channels 1 synth %s sine %f' % (duration, freq)) 292 | 293 | def preprocess_audio_and_midi(directory_str): 294 | audio_files = find_audio_files(directory_str) 295 | cqt_segments = [] 296 | all_songs_encoded_midi_segments = [] 297 | midi_segments_count = 0 298 | for audio_file in audio_files: 299 | audio_decoded = load_audio(audio_file) 300 | time_series, sr = audio_decoded 301 | audio_start_times, audio_segment_length, midi_segment_length = audio_timestamps( 302 | audio_file, audio_decoded) 303 | midi_file = load_midi(directory_str, audio_file) 304 | duration_song = librosa.core.get_duration(time_series, sr) 305 | 306 | # print time differences 307 | midi_len = midi_file.length 308 | if midi_len != duration_song: 309 | print(audio_file, "audio len - midi len:", duration_song-midi_len) 310 | 311 | simplified_midi, absolute_ticks_last_note, length_in_secs_full_song = create_simplified_midi( 312 | midi_file) 313 | lowest_midi_note, highest_midi_note = find_lowest_and_highest_midi_note_numbers(simplified_midi) 314 | midi_start_times = np.arange(0, length_in_secs_full_song, midi_segment_length) 315 | 316 | if len(audio_start_times) != len(midi_start_times): 317 | if len(audio_start_times) > len(midi_start_times): 318 | num_start_times = len(midi_start_times) 319 | audio_start_times_shortened = audio_start_times[:num_start_times] 320 | audio_start_times = audio_start_times_shortened 321 | else: 322 | num_start_times = len(audio_start_times) 323 | midi_start_times_shortened = midi_start_times[:num_start_times] 324 | midi_start_times = midi_start_times_shortened 325 | 326 | for start_time in audio_start_times: 327 | padding_portion = .5 # padding on each side of audio is midi_segment_length * padding portion 328 | audio_segment_time_series = load_segment_of_audio_and_save(audio_file, start_time, 329 | audio_segment_length, 330 | midi_segment_length, 331 | duration_song, sr, 332 | padding_portion) 333 | cqt_of_segment = audio_segments_cqt(audio_segment_time_series, sr) 334 | cqt_segments.append(cqt_of_segment) 335 | 336 | midi_segments, absolute_ticks_last_note = chop_simplified_midi(midi_file, midi_segment_length, simplified_midi, absolute_ticks_last_note, midi_start_times) 337 | midi_start_times_and_segments_incl_onsets = \ 338 | add_note_onsets_to_beginning_when_needed(midi_segments, midi_segment_length) 339 | 340 | for midi_segment in midi_start_times_and_segments_incl_onsets: 341 | midi_start_time = midi_segment[0] 342 | messages = midi_segment[1] 343 | encoded_segment, num_discrete_time_values, num_notes = encode_midi_segment(midi_start_time, messages, midi_segment_length, lowest_midi_note, highest_midi_note) 344 | all_songs_encoded_midi_segments.append(encoded_segment) 345 | 346 | # debugging equal num cqt and midi segment 347 | midi_segments_count += 1 348 | 349 | # # testing 350 | # # encoded_segment = all_songs_encoded_midi_segments[1] 351 | # decoded_segments = [] 352 | # for encoded_segment in all_songs_encoded_midi_segments: 353 | # decoded_midi = decode_midi_segment(encoded_segment, midi_segment_length, num_discrete_time_values, lowest_midi_note) 354 | # 
decoded_segments.append(decoded_midi) 355 | # decoded_midi_start_times_and_segments = [] 356 | # for i in range(len(midi_start_times)): 357 | # decoded_midi_start_times_and_segments.append([midi_start_times[i], decoded_segments[i]]) 358 | 359 | # code for midi reconstruction (for listening purposes) 360 | # midi_filename = midi_file.filename[35:-4] 361 | # reconstruct_midi(midi_filename, decoded_midi_start_times_and_segments, absolute_ticks_last_note, length_in_secs_full_song) 362 | 363 | # pickle time 364 | with open('cqt_segments_midi_segments.pkl', 'wb') as handle: 365 | pickle.dump(cqt_segments, handle) 366 | pickle.dump(all_songs_encoded_midi_segments, handle) 367 | return 368 | 369 | def main(): 370 | # You must set directory str to be the filepath to your directory containing 2 directories: one named "audio" 371 | # (containing the 372 | # audio files) and one named "midi" (containing the MIDI files). For example, "C:/Users/Lilly/audio_and_midi/" or 373 | # "/home/lilly/Downloads/audio_midi/" 374 | directory_str = "/home/lilly/Downloads/audio_midi/" 375 | preprocess_audio_and_midi(directory_str) 376 | 377 | if __name__ == '__main__': 378 | start_time = time.time() 379 | main() 380 | print("--- %s seconds ---" % (time.time() - start_time)) 381 | 382 | 383 | 384 | 385 | 386 | 387 | -------------------------------------------------------------------------------- /decode_midi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def decode_midi_segment(encoded_segment, midi_segment_length, num_discrete_time_values, lowest): 4 | lowest = 21 # value for the entire SMD dataset 5 | decoded_midi_segment = [] 6 | 7 | notes_on, times_on = np.where(encoded_segment == 1) #tuple of arrays of x and y indices where a given condition holds in an array.3.0 8 | bucket_length = midi_segment_length / num_discrete_time_values 9 | bucket_start_times = np.arange(0, midi_segment_length, bucket_length) 10 | for i in range(len(notes_on)): 11 | pitch_scaled = notes_on[i] 12 | column = times_on[i] 13 | pitch = pitch_scaled + lowest 14 | previous_pitch_scaled = notes_on[i-1] 15 | if pitch_scaled != previous_pitch_scaled: 16 | note_on_time = bucket_start_times[column] 17 | midi_message = ['note_on', pitch, note_on_time] 18 | decoded_midi_segment.append(midi_message) 19 | last_column = num_discrete_time_values - 1 20 | if column != last_column: 21 | #if not last note in notes on 22 | if pitch_scaled != notes_on[-1]: 23 | next_pitch_scaled = notes_on[i+1] 24 | if pitch_scaled != next_pitch_scaled: 25 | note_off_time = bucket_start_times[column + 1] 26 | midi_message = ['note_off', pitch, note_off_time] 27 | decoded_midi_segment.append(midi_message) 28 | if column == last_column: 29 | note_off_time = bucket_start_times[0] + midi_segment_length 30 | midi_message = ['note_off', pitch, note_off_time] 31 | decoded_midi_segment.append(midi_message) 32 | decoded_midi_segment_chronological = sorted(decoded_midi_segment, key=lambda x: float(x[-1])) 33 | return decoded_midi_segment_chronological 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | # scaled_notes_on = [] 52 | # len_discrete_time_column_in_seconds = midi_segment_length / num_discrete_time_values 53 | # for i in range(num_discrete_time_values): #for each column 54 | # column = encoded_segment[:,i] 55 | # column_start_time = i * len_discrete_time_column_in_seconds 56 | # scaled_note_num = 0 57 | # midi_message = [] 58 | # 59 | # for scaled_note in scaled_notes_on: 
60 | # if column[scaled_note] == 0: 61 | # midi_message.append('note_off') 62 | # pitch = scaled_note_num + lowest 63 | # midi_message.append(pitch) 64 | # midi_message.append(column_start_time) 65 | # scaled_notes_on.remove(scaled_note) 66 | # 67 | # for binary_note in column: #just one column 68 | # if binary_note == 1 and scaled_note_num not in scaled_notes_on: 69 | # midi_message.append('note_on') 70 | # pitch = scaled_note_num + lowest 71 | # midi_message.append(pitch) 72 | # midi_message.append(column_start_time) 73 | # scaled_notes_on.append(scaled_note_num) 74 | # scaled_note_num += 1 75 | # 76 | # if len(midi_message) > 0: 77 | # decoded_midi_segment.append(midi_message) 78 | return decoded_midi_segment_bin_time 79 | -------------------------------------------------------------------------------- /encode_midi_segments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def encode_midi_segment(midi_start_time, midi_segment, midi_segment_length, lowest, highest): 4 | lowest = 21 #value for the entire SMD dataset 5 | highest = 107 #value for the entire SMD dataset 6 | num_notes = highest - lowest + 1 7 | num_discrete_time_values = 6 8 | # instantiate shape 9 | encoded_segment = np.zeros(shape=(num_notes, num_discrete_time_values)) 10 | midi_segment_bin_time = [] 11 | times_aligned_to_closest_bucket_divide = [] 12 | bucket_length = midi_segment_length / num_discrete_time_values 13 | bucket_divides = np.arange(0, midi_segment_length, bucket_length) 14 | 15 | for message in midi_segment: 16 | time = message[-1] 17 | time_scaled = time - midi_start_time 18 | time_aligned_to_closest_bucket_divide = min(bucket_divides, key=lambda x:abs(x-time_scaled)) 19 | # add the time to the list of message times we'll have to handle 20 | times_aligned_to_closest_bucket_divide.append(time_aligned_to_closest_bucket_divide) 21 | 22 | # create a list of bin integers (corresponding to each message in the segment) 23 | bins = [] 24 | for time in times_aligned_to_closest_bucket_divide: 25 | nth_bucket = 0 26 | for bucket in bucket_divides: 27 | if time == bucket: 28 | bins.append(nth_bucket) 29 | nth_bucket += 1 30 | 31 | # build midi messages list with bin integers as time value 32 | i = 0 33 | for message in midi_segment: 34 | on_or_off = message[0] 35 | pitch = message[1] 36 | pitch_scaled = pitch - lowest 37 | bin = bins[i] 38 | midi_segment_bin_time.append([on_or_off, pitch_scaled, bin]) 39 | i += 1 40 | 41 | for message in midi_segment_bin_time: 42 | on_or_off = message[0] 43 | pitch_scaled = message[1] 44 | bin = message[-1] 45 | if on_or_off == 'note_on': 46 | encoded_segment[pitch_scaled, bin:] = 1 # turn note on for the rest of the segment 47 | else: # if note off 48 | if encoded_segment[pitch_scaled, bin - 1] == 1: # if note was not turned on inside this bin 49 | encoded_segment[pitch_scaled, bin:] = 0 50 | 51 | return encoded_segment, num_discrete_time_values, num_notes 52 | 53 | -------------------------------------------------------------------------------- /exploratory_visualization.py: -------------------------------------------------------------------------------- 1 | from models import pickle_if_not_pickled 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import librosa 5 | import librosa.display 6 | 7 | def visualize_cqt(example_cqt_segment, flipped=False): 8 | # visualize cqt power spectrum (for one segment) 9 | # librosa.display.specshow(cqt_low_notes_at_top, sr=22050 * 4, x_axis='time', y_axis='cqt_note') 10 | 
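# (the commented-out call above expects cqt_low_notes_at_top, i.e. the segment flipped with
# np.flipud so its y-axis matches the MIDI heatmap; see main() below and get_model_prediction.py
# for where that flip is made)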
11 | # alternate visualization 12 | CQT = librosa.amplitude_to_db(example_cqt_segment, ref=np.max) 13 | if not flipped: 14 | librosa.display.specshow(CQT, sr=22050 * 4, x_axis='time', y_axis='cqt_note') 15 | else: 16 | librosa.display.specshow(CQT, sr=22050 * 4, x_axis='time') 17 | plt.ylabel("High frequencies (pitches) ---> Low frequencies (pitches)") 18 | plt.colorbar(format='%+2.0f dB') 19 | plt.title("CQT Power Spectrum") 20 | plt.show() 21 | 22 | if not flipped: 23 | # visualize a cqt heatmap and it's corresponding midi heatmap (for one segment) 24 | plt.imshow(example_cqt_segment, cmap='hot', interpolation='nearest') 25 | plt.title("CQT Heatmap") 26 | plt.xlabel("Time") 27 | plt.ylabel("Frequency bins") 28 | plt.show() 29 | 30 | def visualize_midi(example_midi, title=None): 31 | plt.imshow(example_midi, cmap='hot', interpolation='nearest') 32 | if title == None: 33 | plt.title("MIDI Heatmap") 34 | else: 35 | plt.title(title) 36 | plt.xlabel("Time") 37 | plt.ylabel("MIDI pitch number (scaled)") 38 | plt.show() 39 | 40 | def main(): 41 | cqt_segments, midi_segments = pickle_if_not_pickled() 42 | random_index = np.random.randint(len(cqt_segments)) 43 | print("random index:") 44 | print(random_index) 45 | 46 | # index for sample in Exploratory Visualization section 47 | random_index = 8059 48 | 49 | example_cqt_segment = cqt_segments[random_index] 50 | 51 | # # flip cqt to match midi heatmap y axis 52 | # cqt_low_notes_at_top = np.flipud(example_cqt_segment) 53 | visualize_cqt(example_cqt_segment) 54 | 55 | example_midi = midi_segments[random_index] 56 | visualize_midi(example_midi) 57 | 58 | 59 | 60 | 61 | if __name__ == '__main__': 62 | main() -------------------------------------------------------------------------------- /get_model_prediction.py: -------------------------------------------------------------------------------- 1 | # from numpy.random import seed 2 | # seed(21) 3 | # from tensorflow import set_random_seed 4 | # set_random_seed(21) 5 | # import random 6 | # random.seed(21) 7 | 8 | from keras.models import load_model 9 | from models import pickle_if_not_pickled, reshape_for_dense, split, root_mse, reshape_for_conv2d, r2_coeff_determination 10 | from exploratory_visualization import visualize_cqt, visualize_midi 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from normalisation import * 14 | 15 | import librosa 16 | import librosa.display 17 | from sklearn.metrics import mean_squared_error 18 | from math import sqrt 19 | 20 | def get_model_pred(model, cqt_segment_reshaped, midi_true, midi_height, midi_width): 21 | midi_pred = model.predict(cqt_segment_reshaped) 22 | 23 | # regarding loss, where does this sample stand compared to the loss for the whole set 24 | remove_num_samples_dim = np.squeeze(midi_pred, axis=0) 25 | mse_loss = mean_squared_error(midi_true, remove_num_samples_dim) 26 | rmse = sqrt(mse_loss) 27 | print("rmse:") 28 | print(rmse) 29 | 30 | midi_pred_unflattened = np.reshape(midi_pred, (midi_height, midi_width)) 31 | midi_pred_rounded = np.rint(midi_pred_unflattened) 32 | midi_pred_rounded_unflattened = np.reshape(midi_pred_rounded, (midi_height, midi_width)) 33 | print("midi_pred_unflattened:") 34 | midi_true_unflattened = np.reshape(midi_true, (midi_height, midi_width)) 35 | visualize_midi(midi_pred_unflattened, title='MIDI Prediction') 36 | 37 | print("midi_pred_rounded_unflattened:") 38 | visualize_midi(midi_pred_rounded_unflattened, title='MIDI Prediction Rounded') 39 | 40 | print("midi_true:") 41 | 
visualize_midi(midi_true_unflattened, title='MIDI True') 42 | 43 | def main(): 44 | filepath = "model_and_visualizations.1363/weights-improvement-38-0.1363.hdf5" 45 | # filepath = "model_checkpoints/previous_run_architectures/weights-improvement-37-0.1344.hdf5" 46 | model = load_model(filepath, 47 | custom_objects={'root_mse': root_mse, 'r2_coeff_determination': r2_coeff_determination}) 48 | cqt_segments, midi_segments = pickle_if_not_pickled() 49 | example_midi = midi_segments[0] 50 | midi_height, midi_width = example_midi.shape 51 | # cqt_segments_reshaped, midi_segments_reshaped = reshape_for_dense(cqt_segments, midi_segments) 52 | cqt_segments_reshaped, midi_segments_reshaped = reshape_for_conv2d(cqt_segments, midi_segments) 53 | cqt_train, cqt_valid, cqt_test, midi_train, midi_valid, midi_test = split( 54 | cqt_segments_reshaped, midi_segments_reshaped) 55 | 56 | # look at one validation example 57 | num_validation_samples = len(cqt_valid) 58 | random_index = np.random.randint(num_validation_samples) 59 | example_cqt_segment = cqt_valid[random_index] 60 | midi_true = midi_valid[random_index] 61 | num_examples = 1 62 | input_height, input_width, input_depth = example_cqt_segment.shape 63 | example_cqt_segment_reshaped = example_cqt_segment.reshape(num_examples, input_height, input_width, input_depth) 64 | # get_model_pred(model, example_cqt_segment_reshaped, midi_true, midi_height, midi_width) 65 | 66 | valid_score = model.evaluate(cqt_valid, midi_valid) 67 | print("valid score:") 68 | print("[loss (rmse), root_mse, mae, r2_coeff_determination]") 69 | print(valid_score) 70 | 71 | # final test score 72 | score = model.evaluate(cqt_test, midi_test) 73 | print("[loss (rmse), root_mse, mae, r2_coeff_determination]") 74 | print(score) 75 | 76 | # look at one test example, including the cqt 77 | num_test_samples = len(cqt_test) 78 | random_index_test = np.random.randint(num_test_samples) 79 | # index for the sample referenced in the Free-form Visualization section 80 | random_index_test = 1498 81 | 82 | print("random index test:") 83 | print(random_index_test) 84 | example_cqt_segment_test = cqt_test[random_index_test] 85 | 86 | # visualize cqt power spectrum (for one segment) 87 | depth_removed = np.squeeze(example_cqt_segment_test, axis=2) 88 | visualize_cqt(depth_removed) 89 | 90 | # alternate visualization: flip cqt to match midi heatmap y axis 91 | cqt_low_notes_at_top = np.flipud(depth_removed) 92 | visualize_cqt(cqt_low_notes_at_top, flipped=True) 93 | 94 | midi_true_test = midi_test[random_index_test] 95 | example_test_cqt_segment_reshaped = example_cqt_segment_test.reshape( 96 | num_examples, input_height, input_width, input_depth) 97 | get_model_pred(model, example_test_cqt_segment_reshaped, midi_true_test, midi_height, midi_width) 98 | 99 | 100 | 101 | 102 | 103 | if __name__ == '__main__': 104 | main() -------------------------------------------------------------------------------- /handle_complex_nums.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | a = np.array([1+2j, 3+4j, 5+6j]) 4 | # a_real is a VIEW of a. It doesn't take up any more space in memory. It DOES have the base of 5 | # the imaginary numbers. 6 | a_real = a.real 7 | # a_copy_real creates a new object that does NOT have that base. 
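# (copying breaks the link to the complex-valued base array; cqt.py applies the same
# .real-then-np.array pattern to the output of librosa.cqt)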
8 | a_copy_real = np.array(a_real) 9 | 10 | print(a_real) 11 | 12 | print(type(a[0])) 13 | print(type(a_real[0])) 14 | 15 | -------------------------------------------------------------------------------- /model_and_visualizations.1363/.1363.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartmetzls/audio_to_midi/e28f5e67f49d2a0079b5dab31f861bc10994e4ce/model_and_visualizations.1363/.1363.png -------------------------------------------------------------------------------- /model_and_visualizations.1363/.1363mae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartmetzls/audio_to_midi/e28f5e67f49d2a0079b5dab31f861bc10994e4ce/model_and_visualizations.1363/.1363mae.png -------------------------------------------------------------------------------- /model_and_visualizations.1363/.1363r2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartmetzls/audio_to_midi/e28f5e67f49d2a0079b5dab31f861bc10994e4ce/model_and_visualizations.1363/.1363r2.png -------------------------------------------------------------------------------- /model_and_visualizations.1363/weights-improvement-38-0.1363.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hartmetzls/audio_to_midi/e28f5e67f49d2a0079b5dab31f861bc10994e4ce/model_and_visualizations.1363/weights-improvement-38-0.1363.hdf5 -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # from numpy.random import seed 2 | # seed(21) 3 | # from tensorflow import set_random_seed 4 | # set_random_seed(21) 5 | # import random 6 | # random.seed(21) 7 | 8 | from create_dataset import preprocess_audio_and_midi, done_beep 9 | import pickle 10 | from sklearn.model_selection import train_test_split 11 | import os 12 | import numpy as np 13 | import random 14 | import tensorflow as tf 15 | import matplotlib.pyplot as plt 16 | import time 17 | 18 | # env var to set GPU options 19 | # (this was necessary for my machine. comment out line below if it throws an error.) 
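# CUDA_VISIBLE_DEVICES="1" makes only the machine's second GPU visible to TensorFlow
# ("0" would select the first GPU; an empty string hides all GPUs and forces CPU execution).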
20 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 21 | 22 | from keras.layers import Conv2D 23 | from keras.layers import Flatten, Dense 24 | from keras.models import Sequential 25 | from keras.callbacks import ModelCheckpoint, TensorBoard 26 | from keras import optimizers 27 | from keras import backend as K 28 | 29 | # if the data is not already in a pkl file in the project directory, 30 | # save cqt_segments_midi_segments.pkl to project directory 31 | def pickle_if_not_pickled(): 32 | try: 33 | with open('cqt_segments_midi_segments.pkl', 'rb') as handle: 34 | cqt_segments = pickle.load(handle) 35 | midi_segments = pickle.load(handle) 36 | except (OSError, IOError) as err: 37 | # Windows 38 | if os.name == 'nt': 39 | directory_str = "C:/Users/Lilly/audio_and_midi/" 40 | # Linux 41 | if os.name == 'posix': 42 | directory_str = "/home/lilly/Downloads/audio_midi/" 43 | preprocess_audio_and_midi(directory_str) 44 | cqt_segments, midi_segments = pickle_if_not_pickled() 45 | return cqt_segments, midi_segments 46 | 47 | def reshape_for_conv2d(cqt_segments, midi_segments): 48 | # convert data to np array in order to pass data to keras functions 49 | cqt_segments_array = np.array(cqt_segments) 50 | midi_segments_array = np.array(midi_segments) 51 | 52 | # this is a convenient place to choose to run a portion of the dataset (for quick testing) 53 | cqt_segments_array = cqt_segments_array[:] 54 | midi_segments_array = midi_segments_array[:] 55 | 56 | # adds depth dimension to cqt segment (necessary for Conv2D) 57 | example_cqt_segment = cqt_segments_array[0] 58 | input_height, input_width = example_cqt_segment.shape 59 | however_many_there_are = -1 60 | cqt_segments_reshaped = cqt_segments_array.reshape(however_many_there_are, input_height, input_width, 1) 61 | 62 | # reshape output for Flatten layer 63 | example_midi_segment = midi_segments_array[0] 64 | output_height, output_width = example_midi_segment.shape 65 | one_d_array_len = output_height * output_width 66 | midi_segments_reshaped = midi_segments_array.reshape(however_many_there_are, one_d_array_len) 67 | 68 | return cqt_segments_reshaped, midi_segments_reshaped 69 | 70 | def reshape_for_dense(cqt_segments, midi_segments): 71 | cqt_segments_array = np.array(cqt_segments) 72 | midi_segments_array = np.array(midi_segments) 73 | 74 | # this is a convenient place to choose to run a portion of the dataset (for quick testing) 75 | cqt_segments_array = cqt_segments_array[:] 76 | midi_segments_array = midi_segments_array[:] 77 | 78 | # debugging nan loss (referenced in Implementation section) 79 | # check_cqt_infs = np.where(np.isinf(cqt_segments_array)) 80 | # check_midi_infs = np.where(np.isinf(midi_segments_array)) 81 | # print(check_cqt_infs) 82 | # print(check_midi_infs) 83 | # check_cqt_nans = np.where(np.isnan(cqt_segments_array)) 84 | # check_midi_nans = np.where(np.isnan(midi_segments_array)) 85 | # print(check_cqt_nans) 86 | # print(check_midi_nans) 87 | 88 | example_cqt_segment = cqt_segments_array[0] 89 | input_height, input_width = example_cqt_segment.shape 90 | 91 | however_many_there_are = -1 92 | 93 | # reshape output for Flatten layer 94 | example_midi_segment = midi_segments_array[0] 95 | output_height, output_width = example_midi_segment.shape 96 | one_d_array_len = output_height * output_width 97 | midi_segments_reshaped = midi_segments_array.reshape(however_many_there_are, one_d_array_len) 98 | 99 | return cqt_segments_array, midi_segments_reshaped 100 | 101 | def split(cqt_segments_reshaped, midi_segments_reshaped): 102 | # 
shuffles before splitting by default 103 | cqt_train_and_valid, cqt_test, midi_train_and_valid, midi_test = train_test_split( 104 | cqt_segments_reshaped, midi_segments_reshaped, test_size=0.2, random_state=21) 105 | 106 | cqt_train, cqt_valid, midi_train, midi_valid = train_test_split( 107 | cqt_train_and_valid, midi_train_and_valid, test_size=0.2, random_state=21) 108 | return cqt_train, cqt_valid, cqt_test, midi_train, midi_valid, midi_test 109 | 110 | def conv2d_model(cqt_train, cqt_valid, cqt_test, midi_train, midi_valid, midi_test): 111 | # this is a convenient point to confirm whether or not the full dataset is being run 112 | print("num training examples:") 113 | print(len(cqt_train)) 114 | 115 | example_cqt_segment = cqt_train[0] 116 | input_height, input_width, input_depth = example_cqt_segment.shape 117 | example_midi_segment = midi_train[0] 118 | one_d_array_len = len(example_midi_segment) 119 | 120 | model = create_model(input_height, input_width, one_d_array_len) 121 | 122 | epochs = 100 123 | filepath = "model_checkpoints/weights-improvement-{epoch:02d}-{val_loss:.4f}.hdf5" 124 | checkpointer = ModelCheckpoint(filepath=filepath, monitor='val_loss', 125 | verbose=1, save_best_only=True, save_weights_only=False) 126 | 127 | # create a callback tensorboard object: 128 | tensorboard = TensorBoard(log_dir='./tensorboard_logs', histogram_freq=0, batch_size=1, write_graph=True, 129 | write_grads=True, write_images=True, embeddings_freq=0, 130 | embeddings_layer_names=None, embeddings_metadata=None) 131 | 132 | history_for_plotting = model.fit(cqt_train, midi_train, 133 | validation_data=(cqt_valid, midi_valid), 134 | epochs=epochs, batch_size=32, callbacks=[checkpointer, tensorboard], verbose=1) 135 | score = model.evaluate(cqt_test, midi_test) 136 | 137 | # completely optional. plays a sound when the model finishes running 138 | done_beep() 139 | 140 | # also optional. 
times the runtime (thus far) and shows the time per epoch 141 | total_time = time.time() - start_time 142 | print("--- %s seconds ---" % (total_time)) 143 | print("each epoch:") 144 | print(total_time / epochs) 145 | 146 | # test run only 147 | print("test run score:") 148 | print("[loss (rmse), root_mse, mae, r2_coeff_determination]") 149 | print(score) 150 | 151 | #summarize history for loss 152 | #https://machinelearningmastery.com/display-deep-learning-model-training-history-in-keras/ 153 | plt.plot(history_for_plotting.history['loss']) 154 | plt.plot(history_for_plotting.history['val_loss']) 155 | plt.title('model loss') 156 | plt.ylabel('loss') 157 | plt.xlabel('epoch') 158 | plt.legend(['train rmse', 'validation rmse'], loc='upper right') 159 | plt.show() 160 | 161 | plt.plot(history_for_plotting.history['r2_coeff_determination']) 162 | plt.title('r2') 163 | plt.ylabel('r2_coeff_determination') 164 | plt.xlabel('epoch') 165 | plt.legend(['r2'], loc='upper left') 166 | plt.show() 167 | 168 | plt.plot(history_for_plotting.history['mean_absolute_error']) 169 | plt.title('mae') 170 | plt.ylabel('loss') 171 | plt.xlabel('epoch') 172 | plt.legend(['mae'], loc='upper right') 173 | plt.show() 174 | 175 | def create_model(input_height, input_width, one_d_array_len): 176 | """ Creates a model""" 177 | model = Sequential() 178 | model.add(Conv2D(filters=2, kernel_size=(1, 2), strides=(1), padding='same', activation='relu', input_shape=(input_height, input_width, 1))) 179 | model.add(Conv2D(filters=2, kernel_size=(7, 1), strides=(1), padding='same', activation='relu')) 180 | model.add(Conv2D(filters=3, kernel_size=(1, 2), strides=(1), padding='same', activation='relu')) 181 | model.add(Conv2D(filters=3, kernel_size=(7, 1), strides=(1), padding='same', activation='relu')) 182 | for i in range(2): 183 | model.add(Conv2D(filters=4, kernel_size=(1, 2), strides=(1, 2), padding='same', 184 | activation='relu')) 185 | for i in range(3): 186 | model.add(Conv2D(filters=5, kernel_size=(1, 2), strides=(1, 2), padding='same', 187 | activation='relu')) 188 | model.add(Conv2D(filters=6, kernel_size=(1, 2), strides=(1), padding='same', activation='relu')) 189 | model.add(Flatten()) 190 | model.add(Dense(one_d_array_len, activation='sigmoid')) 191 | model.summary() 192 | adam = optimizers.adam(lr=0.0001, decay=.00001) 193 | model.compile(loss=root_mse, 194 | optimizer=adam, 195 | metrics=[root_mse, 'mae', r2_coeff_determination]) 196 | return model 197 | 198 | def dense_model(cqt_train, cqt_valid, cqt_test, midi_train, midi_valid, midi_test): 199 | example_cqt_segment = cqt_train[0] 200 | input_height, input_width = example_cqt_segment.shape 201 | example_midi_segment = midi_train[0] 202 | one_D_array_len = len(example_midi_segment) 203 | model = Sequential() 204 | model.add(Dense(1044, input_shape=(input_height, input_width), activation='relu')) 205 | model.add(Flatten()) 206 | model.add(Dense(one_D_array_len, activation='sigmoid')) 207 | model.summary() 208 | model.compile(loss=root_mse, 209 | optimizer='adam') 210 | epochs = 100 211 | filepath = "model_checkpoints/weights-improvement-{epoch:02d}-{loss:.4f}.hdf5" 212 | checkpointer = ModelCheckpoint(filepath=filepath, monitor='loss', 213 | verbose=1, save_best_only=True, save_weights_only=False) 214 | history_for_plotting = model.fit(cqt_train, midi_train, 215 | validation_data=(cqt_valid, midi_valid), 216 | epochs=epochs, batch_size=1, callbacks=[checkpointer], verbose=1) 217 | score = model.evaluate(cqt_test, midi_test) 218 | 219 | # summarize history for 
loss 220 | # https://machinelearningmastery.com/display-deep-learning-model-training-history-in-keras/ 221 | plt.plot(history_for_plotting.history['loss']) 222 | plt.plot(history_for_plotting.history['val_loss']) 223 | plt.title('model loss') 224 | plt.ylabel('loss') 225 | plt.xlabel('epoch') 226 | plt.legend(['train', 'validation'], loc='upper left') 227 | plt.show() 228 | 229 | def root_mse(y_true, y_pred): 230 | # returns tensorflow.python.framework.ops.Tensor 231 | return tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(y_true, y_pred)))) 232 | 233 | # https://jmlb.github.io/ml/2017/03/20/CoeffDetermination_CustomMetric4Keras/ 234 | def r2_coeff_determination(y_true, y_pred): 235 | SS_res = K.sum(K.square(y_true - y_pred)) 236 | SS_tot = K.sum(K.square(y_true - K.mean(y_true))) 237 | # epsilon avoids division by zero 238 | return (1 - SS_res / (SS_tot + K.epsilon())) 239 | 240 | def depickle_and_model_architecture(): 241 | cqt_segments, midi_segments = pickle_if_not_pickled() 242 | cqt_segments_reshaped, midi_segments_reshaped = reshape_for_conv2d(cqt_segments, midi_segments) 243 | # cqt_segments_reshaped, midi_segments_reshaped = reshape_for_dense(cqt_segments, midi_segments) 244 | cqt_train, cqt_valid, cqt_test, midi_train, midi_valid, midi_test = split( 245 | cqt_segments_reshaped, midi_segments_reshaped) 246 | conv2d_model(cqt_train, cqt_valid, cqt_test, midi_train, midi_valid, midi_test) 247 | # dense_model(cqt_train, cqt_valid, cqt_test, midi_train, midi_valid, 248 | # midi_test) 249 | 250 | def main(): 251 | depickle_and_model_architecture() 252 | 253 | if __name__ == '__main__': 254 | # set start time here in order to clock runtime (incl. time per epoch) before metrics plots show 255 | start_time = time.time() 256 | main() -------------------------------------------------------------------------------- /non_assert_tests.py: -------------------------------------------------------------------------------- 1 | from create_dataset import * 2 | import os 3 | from os.path import exists 4 | import shutil 5 | from models import pickle_if_not_pickled 6 | 7 | #NON-ASSERT TESTS 8 | def difference_between_audio_first_note_onset_and_midi_first_note_on(): 9 | directory_str_audio = "C:/Users/Lilly/audio_and_midi/audio" 10 | audio_files_no_duplicates = find_audio_files(directory_str_audio) 11 | for audio_file in audio_files_no_duplicates: 12 | time_series_and_sr = load_audio(audio_file) 13 | first_onset = librosa.onset.onset_detect(time_series_and_sr[0], time_series_and_sr[1], 14 | units='time') 15 | midi_file = load_midi(audio_file) 16 | simplified_midi, ticks_since_start, length_in_secs = create_simplified_midi(midi_file) 17 | # assert (first_onset[0] - simplified_midi[0][-1]) == 0, "first onset and first dumbed " \ 18 | # "down " \ 19 | # "midi time are not the same" 20 | print("midi file:", midi_file) 21 | print("est'd first onset in audio - first onset in MIDI:",(first_onset[0] - 22 | simplified_midi[0][-1])) 23 | 24 | def find_song_with_greatest_diff_in_length(): 25 | directory_str_audio = "C:/Users/Lilly/audio_and_midi/audio" 26 | audio_files_no_duplicates = find_audio_files(directory_str_audio) 27 | greatest_diff_yet = [None, 0] 28 | for audio_file in audio_files_no_duplicates: 29 | time_series_and_sr = load_audio(audio_file) 30 | midi_file = load_midi(audio_file) 31 | duration = librosa.core.get_duration(time_series_and_sr[0], time_series_and_sr[1]) 32 | print("audio file:", audio_file) 33 | print("midi file length:", midi_file.length) 34 | print("duration:", duration) 35 | diff = 
abs(int(round(midi_file.length)) - duration) 36 | if diff > greatest_diff_yet[1]: 37 | greatest_diff_yet = [audio_file, diff] 38 | print("song with biggest diff:", greatest_diff_yet) 39 | 40 | def make_test_midi(): 41 | mid = MidiFile() 42 | track = MidiTrack() 43 | mid.tracks.append(track) 44 | #Test note off without on 45 | track.append(Message("note_off", note=64, time=640)) 46 | #Add control note 47 | track.append(Message("note_on", note=69, time=1000)) 48 | track.append(Message("note_off", note=69, time=1100)) 49 | #Test note on without off 50 | track.append(Message("note_on", note=73, time=1280)) 51 | filename_format = "C:/Users/Lilly/audio_and_midi/segments/midi/{0}_start_time_{1}.mid" 52 | filename = filename_format.format("testing_note_on_off", "test_1") 53 | mid.save(filename) 54 | return 55 | 56 | def check_if_there_is_a_midi_file_for_every_audio_file(): 57 | directory_str_audio = "C:/Users/Lilly/audio_and_midi/segments/audio" 58 | audio_files = find_audio_files(directory_str_audio) 59 | midi_file_folder = "C:/Users/Lilly/audio_and_midi/segments/midi/" 60 | count = 0 61 | to_listen_to = [] 62 | for audio_file in audio_files: 63 | basename = ntpath.basename(audio_file)[:-4] 64 | index_last_underscore = basename.rfind("_") 65 | audio_time = basename[index_last_underscore+1:] 66 | song_title = basename[:index_last_underscore] 67 | midi_time = float(audio_time) + 0.25 68 | midi_file_name = song_title + "_" + str(midi_time) + ".mid" 69 | if audio_time != "0.0": 70 | if not exists(midi_file_folder + midi_file_name): 71 | to_listen_to.append(audio_file) 72 | count += 1 73 | print(count) 74 | new_dir = "C:/Users/Lilly/audio_and_midi/segments/no_corresponding_midi" 75 | if not os.path.exists(new_dir): 76 | os.makedirs(new_dir) 77 | for file in to_listen_to: 78 | shutil.copy(file, new_dir) 79 | 80 | done_beep() 81 | 82 | def check_if_there_is_an_audio_file_for_every_midi_file(): 83 | directory_str_midi = "C:/Users/Lilly/audio_and_midi/segments/midi" 84 | midi_files = os.listdir(directory_str_midi) 85 | audio_file_folder = "C:/Users/Lilly/audio_and_midi/segments/audio/" 86 | count = 0 87 | to_listen_to = [] 88 | for midi_file in midi_files: 89 | basename = ntpath.basename(midi_file)[:-4] 90 | index_last_underscore = basename.rfind("_") 91 | midi_time = basename[index_last_underscore+1:] 92 | song_title = basename[:index_last_underscore] 93 | audio_time = float(midi_time) - 0.25 94 | audio_file_name = song_title + "_" + str(audio_time) + ".wav" 95 | if midi_time != "0.0": 96 | if not exists(audio_file_folder + audio_file_name): 97 | to_listen_to.append(midi_file) 98 | count += 1 99 | print(count) 100 | new_dir = "C:/Users/Lilly/audio_and_midi/segments/no_corresponding_audio" 101 | if not os.path.exists(new_dir): 102 | os.makedirs(new_dir) 103 | for file in to_listen_to: 104 | complete_file = directory_str_midi + "/" + file 105 | shutil.copy(complete_file, new_dir) 106 | 107 | done_beep() 108 | 109 | def check_reconstruct_midi_empty_begin(): 110 | filename = "test_case_empty_begin" 111 | midi_segments = [[0.0, []]] 112 | #[[0.0, []],[0.5, ['note_on', 64, 0.75], ['note_off', 64, 0.99]]] 113 | absolute_ticks_last_note = 3000 114 | length_in_secs_full_song = 1 115 | reconstruct_midi(filename, midi_segments, absolute_ticks_last_note, length_in_secs_full_song) 116 | 117 | def check_reconstruct_midi_start_after_song_end(): 118 | filename = "test_case_start_after_song_end" 119 | midi_segments = [[100.0, []]] 120 | #[[0.0, []],[0.5, ['note_on', 64, 0.75], ['note_off', 64, 0.99]]] 121 | 
absolute_ticks_last_note = 3000 122 | length_in_secs_full_song = 1.5 123 | reconstruct_midi(filename, midi_segments, absolute_ticks_last_note, length_in_secs_full_song) 124 | 125 | def check_reconstruct_midi_start_after_song_end_non_empty(): 126 | filename = "test_case_start_after_song_end_non_empty" 127 | # midi_segments = [[0.0, []],[0.5, ['note_on', 64, 0.75], ['note_off', 64, 0.99]]] 128 | midi_segments = [[100.0, [['note_on', 64, 101.75]]]] 129 | #[[0.0, []],[0.5, ['note_on', 64, 0.75], ['note_off', 64, 0.99]]] 130 | absolute_ticks_last_note = 3000 131 | length_in_secs_full_song = 1.5 132 | reconstruct_midi(filename, midi_segments, absolute_ticks_last_note, length_in_secs_full_song) 133 | 134 | def check_reconstruct_midi_start_after_song_end_non_empty2(): 135 | filename = "test_case_start_after_song_end_non_empty2" 136 | # midi_segments = [[0.0, []],[0.5, ['note_on', 64, 0.75], ['note_off', 64, 0.99]]] 137 | midi_segments = [[100.0, [['note_on', 64, 1.75]]]] 138 | #[[0.0, []],[0.5, ['note_on', 64, 0.75], ['note_off', 64, 0.99]]] 139 | absolute_ticks_last_note = 3000 140 | length_in_secs_full_song = 105.5 141 | reconstruct_midi(filename, midi_segments, absolute_ticks_last_note, length_in_secs_full_song) 142 | 143 | def catch_audio_no_midi_or_vice_versa(): 144 | directory_str_audio = "C:/Users/Lilly/audio_and_midi/audio" 145 | audio_files = find_audio_files(directory_str_audio) 146 | biggest_diff_in_num_start_times = 0 147 | num_songs_with_start_time_diff = 0 148 | for audio_file in audio_files: 149 | time_series_and_sr = load_audio(audio_file) 150 | audio_start_times, audio_segment_length, midi_segment_length = audio_timestamps( 151 | audio_file, time_series_and_sr) 152 | midi_file = load_midi(audio_file) 153 | 154 | # See time differences 155 | # duration = librosa.core.get_duration(time_series_and_sr[0], time_series_and_sr[1]) 156 | # midi_len = midi_file.length 157 | # if midi_len != duration: 158 | # print(audio_file, "audio len - midi len:", duration - midi_len) 159 | 160 | midi_segments, absolute_ticks_last_note, length_in_secs_full_song = chop_simplified_midi( 161 | midi_file, midi_segment_length) 162 | midi_start_timestamps = np.arange(0, length_in_secs_full_song, midi_segment_length) 163 | 164 | audio_starts_minus_midi_starts = len(audio_start_times) - len(midi_start_timestamps) 165 | if abs(audio_starts_minus_midi_starts) > 0: 166 | num_songs_with_start_time_diff += 1 167 | if abs(audio_starts_minus_midi_starts) > abs(biggest_diff_in_num_start_times): 168 | biggest_diff_in_num_start_times = audio_starts_minus_midi_starts 169 | print("last_biggest_diff_in_num_start_times:", audio_file) 170 | print("len audio timestamps:", len(audio_start_times)) 171 | print("len midi timestamps:", len(midi_start_timestamps)) 172 | print("num songs with diff num start times:", num_songs_with_start_time_diff) 173 | midi_segments_plus_onsets = \ 174 | add_note_onsets_to_beginning_when_needed(midi_segments, midi_segment_length) 175 | midi_filename = midi_file.filename[35:-4] 176 | reconstruct_midi(midi_filename, midi_segments_plus_onsets, absolute_ticks_last_note, 177 | length_in_secs_full_song) 178 | 179 | def run_without_errors(): 180 | directory_str_audio = "C:/Users/Lilly/audio_and_midi/audio" 181 | audio_files = find_audio_files(directory_str_audio) 182 | cqt_segments = [] 183 | all_songs_encoded_midi_segments = [] 184 | midi_segments_count = 0 185 | for audio_file in audio_files: 186 | audio_decoded = load_audio(audio_file) 187 | time_series, sr = audio_decoded 188 | audio_start_times, 
audio_segment_length, midi_segment_length = audio_timestamps( 189 | audio_file, audio_decoded) 190 | midi_file = load_midi(audio_file) 191 | 192 | duration_song = librosa.core.get_duration(time_series, sr) 193 | 194 | # See time differences 195 | midi_len = midi_file.length 196 | if midi_len != duration_song: 197 | print(audio_file, "audio len - midi len:", duration_song - midi_len) 198 | 199 | simplified_midi, absolute_ticks_last_note, length_in_secs_full_song = create_simplified_midi( 200 | midi_file) 201 | 202 | lowest_midi_note, highest_midi_note = find_lowest_and_highest_midi_note_numbers( 203 | simplified_midi) 204 | 205 | midi_start_times = np.arange(0, length_in_secs_full_song, midi_segment_length) 206 | 207 | if len(audio_start_times) != len(midi_start_times): 208 | if len(audio_start_times) > len(midi_start_times): 209 | num_start_times = len(midi_start_times) 210 | audio_start_times_shortened = audio_start_times[:num_start_times] 211 | audio_start_times = audio_start_times_shortened 212 | else: 213 | num_start_times = len(audio_start_times) 214 | midi_start_times_shortened = midi_start_times[:num_start_times] 215 | midi_start_times = midi_start_times_shortened 216 | 217 | for start_time in audio_start_times: 218 | padding_portion = .5 # padding on each side of audio is midi_segment_length * 219 | # padding portion 220 | audio_segment_time_series = load_segment_of_audio_and_save(audio_file, start_time, 221 | audio_segment_length, 222 | midi_segment_length, 223 | duration_song, sr, 224 | padding_portion) 225 | cqt_of_segment = audio_segments_cqt(audio_segment_time_series, sr) 226 | cqt_segments.append(cqt_of_segment) 227 | 228 | midi_segments, absolute_ticks_last_note = chop_simplified_midi(midi_file, 229 | midi_segment_length, 230 | simplified_midi, 231 | absolute_ticks_last_note, 232 | midi_start_times) 233 | midi_start_times_and_segments_incl_onsets = \ 234 | add_note_onsets_to_beginning_when_needed(midi_segments, midi_segment_length) 235 | 236 | for midi_segment in midi_start_times_and_segments_incl_onsets: 237 | midi_start_time = midi_segment[0] 238 | messages = midi_segment[1] 239 | encoded_segment, num_discrete_time_values, num_notes = encode_midi_segment( 240 | midi_start_time, messages, 241 | midi_segment_length, lowest_midi_note, highest_midi_note) 242 | all_songs_encoded_midi_segments.append(encoded_segment) # TODO: testcorrect num 243 | 244 | # debugging equal num cqt and midi segment 245 | midi_segments_count += 1 246 | 247 | # here for testing 248 | for midi in all_songs_encoded_midi_segments: 249 | decoded_midi = decode_midi_segment(midi, midi_segment_length, 250 | num_discrete_time_values, lowest_midi_note) 251 | 252 | def min_max_median_mean(): 253 | cqt_segments, midi_segments = pickle_if_not_pickled() 254 | cqt_segments_array = np.array(cqt_segments) 255 | cqt_min = np.min(cqt_segments_array) 256 | cqt_max = np.max(cqt_segments_array) 257 | cqt_median = np.median(cqt_segments_array) 258 | cqt_mean = np.mean(cqt_segments_array) 259 | midi_segments_array = np.array(midi_segments) 260 | midi_median = np.median(midi_segments_array) 261 | midi_mean = np.mean(midi_segments_array) 262 | print(cqt_mean, midi_mean) 263 | 264 | 265 | 266 | 267 | 268 | def main(): 269 | # difference_between_audio_first_note_onset_and_midi_first_note_on() 270 | # find_song_with_greatest_diff_in_length() 271 | # make_test_midi() 272 | # check_if_there_is_a_midi_file_for_every_audio_file() 273 | # check_if_there_is_an_audio_file_for_every_midi_file() 274 | # 
check_reconstruct_midi_empty_begin()
275 |     # check_reconstruct_midi_start_after_song_end()
276 |     # check_reconstruct_midi_start_after_song_end_non_empty()
277 |     # check_reconstruct_midi_start_after_song_end_non_empty2()
278 |     # catch_audio_no_midi_or_vice_versa()
279 |     # find_lowest_and_highest_midi_note_numbers()
280 |     # run_without_errors()
281 |     min_max_median_mean()
282 | 
283 | if __name__ == '__main__':
284 |     start_time = time.time()
285 |     main()
286 |     print("--- %s seconds ---" % (time.time() - start_time))
--------------------------------------------------------------------------------
/normalisation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn import preprocessing
3 | 
4 | def shape_for_scaler(cqt_segments_array_train_or_valid_or_test):
5 |     num_samples, height, width = cqt_segments_array_train_or_valid_or_test.shape
6 |     array_transposed = np.transpose(cqt_segments_array_train_or_valid_or_test)
7 |     print('transposed:')
8 |     print(array_transposed)
9 |     first_int = height * width
10 |     array_transposed_reshaped = array_transposed.reshape(first_int, num_samples)
11 |     print('array transposed reshaped:')
12 |     print(array_transposed_reshaped)
13 |     transposed_reshaped_transposed = np.transpose(array_transposed_reshaped)
14 |     print('transposed reshaped transposed:')
15 |     print(transposed_reshaped_transposed)
16 |     return transposed_reshaped_transposed, num_samples, height, width
17 | 
18 | def create_scaler(transposed_reshaped_transposed):
19 |     scaler = preprocessing.StandardScaler()
20 |     scaler.fit(transposed_reshaped_transposed)
21 |     return scaler
22 | 
23 | def feature_standardize_array(array_shaped_for_scaler, scaler, num_samples, height, width):
24 |     standardized = scaler.transform(array_shaped_for_scaler)
25 |     print('standardized:')
26 |     print(standardized)
27 |     untransposed = np.transpose(standardized)
28 |     print('untransposed:')
29 |     print(untransposed)
30 |     #TODO: check that this reshape is correct (use diff len vectors in diff dimens)
31 |     #replace 2, 2, 2, with input_width, height, depth?
32 |     reshaped = untransposed.reshape(width, height, num_samples)
33 |     print('reshaped:')
34 |     print(reshaped)
35 |     transposed = np.transpose(reshaped)
36 |     print('transposed:')
37 |     print(transposed)
38 |     return transposed
39 | 
40 | # working mini version:
41 | #
42 | # print('og:')
43 | #is it valid to normalize overall rather than by bin?
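# How the helpers above fit together: sklearn's StandardScaler standardizes each
# column of a 2D (n_samples, n_features) array, so shape_for_scaler flattens the
# 3D CQT array (num_samples, height, width) to (num_samples, height * width) via
# transpose -> reshape -> transpose, create_scaler fits the scaler on that 2D view,
# and feature_standardize_array applies the mirror-image reshape/transpose to
# restore the original 3D shape after scaling. The commented-out mini version
# below traces the same round trip on a tiny hand-made array.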
44 | # cqt_segments_array = [[[-11, 2], [1, 2]],
45 | #                       [[-8, -2], [1, 10]]
46 | #                       ]
47 | # cqt_segments_array = np.array(cqt_segments_array)
48 | # print(cqt_segments_array)
49 | #
50 | # print('----------')
51 | #
52 | # cqt_segments_array_transposed = np.transpose(cqt_segments_array)
53 | # print('transposed:')
54 | # print(cqt_segments_array_transposed)
55 | #
56 | # print('----------')
57 | #
58 | # cqt_segments_transposed_reshaped = \
59 | #     np.array(cqt_segments_array_transposed)
60 | # cqt_segments_transposed_reshaped = \
61 | #     cqt_segments_array_transposed.reshape(4, 2)
62 | #
63 | # print('cqt segments transposed reshaped:')
64 | # print(cqt_segments_transposed_reshaped)
65 | #
66 | # print('transposed:')
67 | # cqt_segments_transposed_reshaped_transposed = np.transpose(cqt_segments_transposed_reshaped)
68 | # #so no renaming
69 | # cqt_segments_transposed_reshaped = cqt_segments_transposed_reshaped_transposed
70 | # print(cqt_segments_transposed_reshaped)
71 | #
72 | # scaler = preprocessing.StandardScaler()
73 | # scaler.fit(cqt_segments_transposed_reshaped)
74 | # standardized = scaler.transform(cqt_segments_transposed_reshaped)
75 | # print('standardized:')
76 | # print(standardized)
77 | 
78 | def main():
79 |     # for testing
80 |     # three_d_array = [[[-11, 2, 1, 1, 1, 1, 1],
81 |     #                   [1, 2, 1, 1, 1, 1, 1],
82 |     #                   [1, 2, 1, 1, 1, 1, 1]],
83 |     #
84 |     #                  [[-8, -2, 1, 1, 1, 1, 1],
85 |     #                   [1, 10, 1, 1, 1, 1, 1],
86 |     #                   [1, 2, 1, 1, 1, 1, 1]]
87 |     #                  ]
88 |     # three_d_array = np.array(three_d_array)
89 | 
90 |     three_d_array = [[[-11, 2], [1, 2]],
91 |                      [[-8, -2], [1, 10]]
92 |                      ]
93 |     three_d_array = np.array(three_d_array)
94 | 
95 |     array_shaped_for_scaler, num_samples, height, width = shape_for_scaler(three_d_array)
96 |     scaler = create_scaler(array_shaped_for_scaler)
97 |     standardized = feature_standardize_array(array_shaped_for_scaler, scaler, num_samples, height, width)
98 | 
99 | if __name__ == '__main__':
100 |     main()
--------------------------------------------------------------------------------
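For reference, a minimal sketch of how these helpers could be chained: fit the scaler on the training split only, then reuse it to standardize the validation and test splits. The `standardize_splits` wrapper and the split variable names are illustrative (not part of the repository); the only assumption is that each split is a 3D array of shape `(num_samples, height, width)`, which is what `shape_for_scaler` expects.

```python
import numpy as np
from normalisation import shape_for_scaler, create_scaler, feature_standardize_array

def standardize_splits(cqt_train, cqt_valid, cqt_test):
    # Flatten each 3D split to 2D, fit the scaler on the training data only,
    # then standardize every split and restore the original 3D shape.
    train_2d, n_train, height, width = shape_for_scaler(np.array(cqt_train))
    scaler = create_scaler(train_2d)
    train_std = feature_standardize_array(train_2d, scaler, n_train, height, width)

    valid_2d, n_valid, _, _ = shape_for_scaler(np.array(cqt_valid))  # same height/width as train
    valid_std = feature_standardize_array(valid_2d, scaler, n_valid, height, width)

    test_2d, n_test, _, _ = shape_for_scaler(np.array(cqt_test))
    test_std = feature_standardize_array(test_2d, scaler, n_test, height, width)

    return train_std, valid_std, test_std
```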