├── .gitignore
├── LICENSE
├── README.md
├── archives
│   └── v1
│       ├── long_samples
│       │   ├── Baroque 1.mid
│       │   ├── Baroque 1.mp3
│       │   ├── Baroque 2.mid
│       │   ├── Baroque 2.mp3
│       │   ├── Baroque 3.mid
│       │   ├── Baroque 3.mp3
│       │   ├── Classical 1.mid
│       │   ├── Classical 1.mp3
│       │   ├── Classical 2.mid
│       │   ├── Classical 2.mp3
│       │   ├── Classical 3.mid
│       │   ├── Classical 3.mp3
│       │   ├── Romantic 1.mid
│       │   ├── Romantic 1.mp3
│       │   ├── Romantic 2.mid
│       │   ├── Romantic 2.mp3
│       │   ├── Romantic 3.mid
│       │   └── Romantic 3.mp3
│       ├── model.h5
│       └── short_samples
│           ├── Baroque 1.mp3
│           ├── Baroque 2.mp3
│           ├── Baroque 3.mp3
│           ├── Baroque 4.mp3
│           ├── Baroque 5.mp3
│           ├── Classical 1.mp3
│           ├── Classical 2.mp3
│           ├── Classical 3.mp3
│           ├── Classical 4.mp3
│           ├── Classical 5.mp3
│           ├── Romantic 1.mp3
│           ├── Romantic 2.mp3
│           ├── Romantic 3.mp3
│           ├── Romantic 4.mp3
│           └── Romantic 5.mp3
├── constants.py
├── dataset.py
├── distribution.py
├── generate.py
├── midi_util.py
├── model.py
├── requirements.txt
├── scripts
│   ├── cuda.sh
│   ├── load_data.sh
│   ├── mount.sh
│   └── python.sh
├── test.py
├── train.py
├── util.py
└── visualize.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
out
data
.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Calclavia

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DeepJ: A model for style-specific music generation
https://arxiv.org/abs/1801.00887

## Abstract
Recent advances in deep neural networks have enabled algorithms to compose music that is comparable to music composed by humans. However, few algorithms allow the user to generate music with tunable parameters. The ability to tune properties of generated music will yield more practical benefits for aiding artists, filmmakers, and composers in their creative tasks. In this paper, we introduce DeepJ - an end-to-end generative model that is capable of composing music conditioned on a specific mixture of composer styles. Our innovations include methods to learn musical style and music dynamics. We use our model to demonstrate a simple technique for controlling the style of generated music as a proof of concept.
Evaluation of our model using human raters shows that we have improved over the Biaxial LSTM approach.

## Requirements
- Python 3.5

Clone Python MIDI (https://github.com/vishnubob/python-midi) and check out its `feature/python3` branch (see `scripts/python.sh`). Then, from the `python-midi` directory, install it using
`python3 setup.py install`.

Then install the other dependencies of this project:
```
pip install -r requirements.txt
```

The dataset is not provided in this repository. To train a custom model, you will need to include a MIDI dataset in the `data/` folder, organized into the per-artist subfolders listed in `constants.py` (e.g. `data/baroque/bach`).

## Usage
To train a new model, run the following command:
```
python train.py
```

To generate music, run the following command:
```
python generate.py
```

Use the `--help` flag to see the CLI arguments:
```
python generate.py --help
```
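
To generate a specific style mixture, pass style indices to `--styles` (indices follow the flattened `styles` list in `constants.py`, so index 0 corresponds to `data/baroque/bach`). For example, the following generates 16 bars from an equal blend of two artist styles:
```
python generate.py --bars 16 --styles 0 3
```
When `--styles` is omitted, one piece is generated per genre.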
--------------------------------------------------------------------------------
/archives/v1/long_samples/Baroque 1.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Baroque 1.mid
--------------------------------------------------------------------------------
/archives/v1/long_samples/Baroque 1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Baroque 1.mp3
--------------------------------------------------------------------------------
/archives/v1/long_samples/Baroque 2.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Baroque 2.mid
--------------------------------------------------------------------------------
/archives/v1/long_samples/Baroque 2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Baroque 2.mp3
--------------------------------------------------------------------------------
/archives/v1/long_samples/Baroque 3.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Baroque 3.mid
--------------------------------------------------------------------------------
/archives/v1/long_samples/Baroque 3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Baroque 3.mp3
--------------------------------------------------------------------------------
/archives/v1/long_samples/Classical 1.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Classical 1.mid
--------------------------------------------------------------------------------
/archives/v1/long_samples/Classical 1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Classical 1.mp3
--------------------------------------------------------------------------------
/archives/v1/long_samples/Classical 2.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Classical 2.mid
--------------------------------------------------------------------------------
/archives/v1/long_samples/Classical 2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Classical 2.mp3
--------------------------------------------------------------------------------
/archives/v1/long_samples/Classical 3.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Classical 3.mid
--------------------------------------------------------------------------------
/archives/v1/long_samples/Classical 3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Classical 3.mp3
--------------------------------------------------------------------------------
/archives/v1/long_samples/Romantic 1.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Romantic 1.mid
--------------------------------------------------------------------------------
/archives/v1/long_samples/Romantic 1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Romantic 1.mp3
--------------------------------------------------------------------------------
/archives/v1/long_samples/Romantic 2.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Romantic 2.mid
--------------------------------------------------------------------------------
/archives/v1/long_samples/Romantic 2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Romantic 2.mp3
--------------------------------------------------------------------------------
/archives/v1/long_samples/Romantic 3.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Romantic 3.mid
--------------------------------------------------------------------------------
/archives/v1/long_samples/Romantic 3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/long_samples/Romantic 3.mp3
--------------------------------------------------------------------------------
/archives/v1/model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/model.h5
--------------------------------------------------------------------------------
/archives/v1/short_samples/Baroque 1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Baroque 1.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Baroque 2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Baroque 2.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Baroque 3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Baroque 3.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Baroque 4.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Baroque 4.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Baroque 5.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Baroque 5.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Classical 1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Classical 1.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Classical 2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Classical 2.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Classical 3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Classical 3.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Classical 4.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Classical 4.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Classical 5.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Classical 5.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Romantic 1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Romantic 1.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Romantic 2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Romantic 2.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Romantic 3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Romantic 3.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Romantic 4.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Romantic 4.mp3
--------------------------------------------------------------------------------
/archives/v1/short_samples/Romantic 5.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/DeepJ/e2058ca4c05d10e66a87c6c4205cd5495c846627/archives/v1/short_samples/Romantic 5.mp3
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
import os

# Define the musical styles
genre = [
    'baroque',
    'classical',
    'romantic'
]

styles = [
    [
        'data/baroque/bach',
        'data/baroque/handel',
        'data/baroque/pachelbel'
    ],
    [
        'data/classical/burgmueller',
        'data/classical/clementi',
        'data/classical/haydn',
        'data/classical/beethoven',
        'data/classical/brahms',
        'data/classical/mozart'
    ],
    [
        'data/romantic/balakirew',
        'data/romantic/borodin',
        'data/romantic/brahms',
        'data/romantic/chopin',
        'data/romantic/debussy',
        'data/romantic/liszt',
        'data/romantic/mendelssohn',
        'data/romantic/moszkowski',
        'data/romantic/mussorgsky',
        'data/romantic/rachmaninov',
        'data/romantic/schubert',
        'data/romantic/schumann',
        'data/romantic/tchaikovsky',
        'data/romantic/tschai'
    ]
]

NUM_STYLES = sum(len(s) for s in styles)

# MIDI Resolution
DEFAULT_RES = 96
MIDI_MAX_NOTES = 128
MAX_VELOCITY = 127

# Number of octaves supported
NUM_OCTAVES = 4
OCTAVE = 12

# Min and max note (in MIDI note number)
MIN_NOTE = 36
MAX_NOTE = MIN_NOTE + NUM_OCTAVES * OCTAVE
NUM_NOTES = MAX_NOTE - MIN_NOTE

# Number of beats in a bar
BEATS_PER_BAR = 4
# Notes per quarter note
NOTES_PER_BEAT = 4
# The quickest note is a sixteenth note
NOTES_PER_BAR = NOTES_PER_BEAT * BEATS_PER_BAR

# Training parameters
BATCH_SIZE = 16
SEQ_LEN = 8 * NOTES_PER_BAR

# Hyper Parameters
OCTAVE_UNITS = 64
STYLE_UNITS = 64
NOTE_UNITS = 3
TIME_AXIS_UNITS = 256
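# Derived values, for reference (assuming the defaults above):
#   NUM_NOTES = 48 (four octaves starting at MIDI note 36)
#   NOTES_PER_BAR = 16 time steps per bar; SEQ_LEN = 128 time steps (8 bars)
#   NUM_STYLES = 23 artist styles across the 3 genres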
NOTE_AXIS_UNITS = 128

TIME_AXIS_LAYERS = 2
NOTE_AXIS_LAYERS = 2

# Output file save locations
OUT_DIR = 'out'
MODEL_DIR = os.path.join(OUT_DIR, 'models')
MODEL_FILE = os.path.join(OUT_DIR, 'model.h5')
SAMPLES_DIR = os.path.join(OUT_DIR, 'samples')
CACHE_DIR = os.path.join(OUT_DIR, 'cache')
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
"""
Preprocesses MIDI files
"""
import numpy as np
import math
import random
from joblib import Parallel, delayed
import multiprocessing

from constants import *
from midi_util import load_midi
from util import *

def compute_beat(beat, notes_in_bar):
    return one_hot(beat % notes_in_bar, notes_in_bar)

def compute_completion(beat, len_melody):
    return np.array([beat / len_melody])

def compute_genre(genre_id):
    """ Computes a vector that represents a particular genre """
    genre_hot = np.zeros((NUM_STYLES,))
    start_index = sum(len(s) for i, s in enumerate(styles) if i < genre_id)
    styles_in_genre = len(styles[genre_id])
    genre_hot[start_index:start_index + styles_in_genre] = 1 / styles_in_genre
    return genre_hot

def stagger(data, time_steps):
    dataX, dataY = [], []
    # Buffer training for first event
    data = ([np.zeros_like(data[0])] * time_steps) + list(data)

    # Chop a sequence into measures
    for i in range(0, len(data) - time_steps, NOTES_PER_BAR):
        dataX.append(data[i:i + time_steps])
        dataY.append(data[i + 1:(i + time_steps + 1)])
    return dataX, dataY
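# Note (for intuition): stagger() prepends time_steps zero frames, then slides
# a window over the sequence in NOTES_PER_BAR hops, so each dataY window is its
# dataX window shifted one step ahead - the training target is simply the next
# time step.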
def load_all(styles, batch_size, time_steps):
    """
    Loads all MIDI files as a piano roll.
    (For Keras)
    """
    note_data = []
    beat_data = []
    style_data = []

    note_target = []

    # TODO: Can speed this up with better parallel loading. Order guarantee.
    styles = [y for x in styles for y in x]

    for style_id, style in enumerate(styles):
        style_hot = one_hot(style_id, NUM_STYLES)
        # Parallel process all files into a list of music sequences
        seqs = Parallel(n_jobs=multiprocessing.cpu_count(), backend='threading')(delayed(load_midi)(f) for f in get_all_files([style]))

        for seq in seqs:
            if len(seq) >= time_steps:
                # Clamp MIDI to note range
                seq = clamp_midi(seq)
                # Create training data and labels
                train_data, label_data = stagger(seq, time_steps)
                note_data += train_data
                note_target += label_data

                beats = [compute_beat(i, NOTES_PER_BAR) for i in range(len(seq))]
                beat_data += stagger(beats, time_steps)[0]

                style_data += stagger([style_hot for i in range(len(seq))], time_steps)[0]

    note_data = np.array(note_data)
    beat_data = np.array(beat_data)
    style_data = np.array(style_data)
    note_target = np.array(note_target)
    return [note_data, note_target, beat_data, style_data], [note_target]

def clamp_midi(sequence):
    """
    Clamps the MIDI sequence based on the MIN and MAX notes
    """
    return sequence[:, MIN_NOTE:MAX_NOTE, :]

def unclamp_midi(sequence):
    """
    Restore clamped MIDI sequence back to MIDI note values
    """
    return np.pad(sequence, ((0, 0), (MIN_NOTE, 0), (0, 0)), 'constant')
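# Shape reference (assuming the defaults in constants.py): load_all() returns
# note_data of shape [N, SEQ_LEN, NUM_NOTES, NOTE_UNITS] = [N, 128, 48, 3],
# with matching beat windows of shape [N, 128, 16] and style windows of shape
# [N, 128, 23].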
--------------------------------------------------------------------------------
/distribution.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np
import sys
import dataset
import ntpath

from music import autocorrelate, NUM_CLASSES, MIN_CLASS, NOTES_PER_BEAT, NOTE_OFF, NO_EVENT, MIN_NOTE

MIDI_NOTE_RANGE = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] * 4 + ['C']
NOTE_LEN_RANGE = ['0', ''] + MIDI_NOTE_RANGE

def plot_note_distribution(melody_list):
    for i, (name, melody) in enumerate(melody_list):
        fig = plt.figure(figsize=(14, 5))
        # Filter out 0's and 1's
        # Subtract min class from each note to 0 index the whole list
        notes = [x - MIN_CLASS for x in melody if x != 0 and x != 1]
        # Plot
        plt.hist(notes, bins=np.arange(len(MIDI_NOTE_RANGE) + 1))
        plt.ylabel('Note frequency')
        plt.xticks(range(len(MIDI_NOTE_RANGE)), MIDI_NOTE_RANGE)
        # plt.show()
        plt.savefig('out/' + ntpath.basename(name) + ' (note dist).png')

def plot_note_length(melody_list):
    for i, (name, melody) in enumerate(melody_list):
        # Dict that stores notes and their lengths
        note_len_dict = {}
        # Initialize keys/values in dict
        for i in range(len(NOTE_LEN_RANGE)):
            note_len_dict[i] = 0

        prev_note = 0
        for m in melody:
            # Note off
            if m == 0:
                note_len_dict[0] += 1
            # No event
            elif m == 1:
                note_len_dict[prev_note] += 1
            # Note
            else:
                note_len_dict[m] += 1
                prev_note = m
        # Convert dict into a list that can be put into histogram
        note_lens = []
        for k in note_len_dict.keys():
            for i in range(note_len_dict[k]):
                note_lens.append(k)

        # Plot
        fig = plt.figure(figsize=(14, 5))
        plt.hist(note_lens, bins=np.arange(len(NOTE_LEN_RANGE) + 1))
        plt.xlabel('0 represents a rest')
        plt.ylabel('Duration in eighth notes')
        plt.xticks(range(len(NOTE_LEN_RANGE)), NOTE_LEN_RANGE)
        # plt.show()
        plt.savefig('out/' + ntpath.basename(name) + ' (note length).png')

def calculate_correlation(melody_list):
    correlations = []
    for name, melody in melody_list:
        correlation = np.sum([autocorrelate(melody, i) ** 2 for i in range(1, 4)])
        correlations.append(correlation)
        print('Correlation Coefficient (r^2 for 1, 2, 3): ', name, correlation)

    print('Mean: ', np.mean(correlations))
    print('Std: ', np.std(correlations))

def distributions(paths):
    melody_list = dataset.load_melodies(paths, shuffle=False, named=True)
    plot_note_distribution(melody_list)
    plot_note_length(melody_list)
    calculate_correlation(melody_list)

# Skip the script name itself when passing input paths
distributions(sys.argv[1:])

"""
NOTES:
2 maps to midi note 36 (MIN_NOTE)
8 numbers in arr forms a bar
2 elements in arr are a quarter note
1 element is half a quarter note, i.e. an eighth note
output a png of the plot with matplotlib's save
"""
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from collections import deque
import midi
import argparse

from constants import *
from util import *
from dataset import *
from tqdm import tqdm
from midi_util import midi_encode

class MusicGeneration:
    """
    Represents a music generation
    """
    def __init__(self, style, default_temp=1):
        self.notes_memory = deque([np.zeros((NUM_NOTES, NOTE_UNITS)) for _ in range(SEQ_LEN)], maxlen=SEQ_LEN)
        self.beat_memory = deque([np.zeros(NOTES_PER_BAR) for _ in range(SEQ_LEN)], maxlen=SEQ_LEN)
        self.style_memory = deque([style for _ in range(SEQ_LEN)], maxlen=SEQ_LEN)

        # The next note being built
        self.next_note = np.zeros((NUM_NOTES, NOTE_UNITS))
        self.silent_time = NOTES_PER_BAR

        # The outputs
        self.results = []
        # The temperature
        self.default_temp = default_temp
        self.temperature = default_temp

    def build_time_inputs(self):
        return (
            np.array(self.notes_memory),
            np.array(self.beat_memory),
            np.array(self.style_memory)
        )

    def build_note_inputs(self, note_features):
        # Timesteps = 1 (No temporal dimension)
        return (
            np.array(note_features),
            np.array([self.next_note]),
            np.array(list(self.style_memory)[-1:])
        )

    def choose(self, prob, n):
        vol = prob[n, -1]
        prob = apply_temperature(prob[n, :-1], self.temperature)

        # Flip notes randomly
        if np.random.random() <= prob[0]:
            self.next_note[n, 0] = 1
            # Apply volume
            self.next_note[n, 2] = vol
            # Flip articulation
            if np.random.random() <= prob[1]:
                self.next_note[n, 1] = 1

    def end_time(self, t):
        """
        Finish generation for this time step.
        """
        # Increase temperature while silent.
        if np.count_nonzero(self.next_note) == 0:
            self.silent_time += 1
            if self.silent_time >= NOTES_PER_BAR:
                self.temperature += 0.1
        else:
            self.silent_time = 0
            self.temperature = self.default_temp

        self.notes_memory.append(self.next_note)
        # Consistent with dataset representation
        self.beat_memory.append(compute_beat(t, NOTES_PER_BAR))
        self.results.append(self.next_note)
        # Reset next note
        self.next_note = np.zeros((NUM_NOTES, NOTE_UNITS))
        return self.results[-1]

def apply_temperature(prob, temperature):
    """
    Applies temperature to a sigmoid vector.
    """
    # Apply temperature
    if temperature != 1:
        # Inverse sigmoid
        x = -np.log(1 / prob - 1)
        # Apply temperature to sigmoid function
        prob = 1 / (1 + np.exp(-x / temperature))
    return prob
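# Worked example: at temperature 2, a sigmoid probability of 0.9 maps through
# the inverse sigmoid to x ~= 2.2, and sigmoid(x / 2) = 0.75, pulling
# probabilities toward 0.5 (more random); temperatures below 1 sharpen them
# instead.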
def process_inputs(ins):
    ins = list(zip(*ins))
    ins = [np.array(i) for i in ins]
    return ins

def generate(models, num_bars, styles):
    print('Generating with styles:', styles)

    _, time_model, note_model = models
    generations = [MusicGeneration(style) for style in styles]

    for t in tqdm(range(NOTES_PER_BAR * num_bars)):
        # Produce note-invariant features
        ins = process_inputs([g.build_time_inputs() for g in generations])
        # Pick only the last time step
        note_features = time_model.predict(ins)
        note_features = np.array(note_features)[:, -1:, :]

        # Generate each note conditioned on previous
        for n in range(NUM_NOTES):
            ins = process_inputs([g.build_note_inputs(note_features[i, :, :, :]) for i, g in enumerate(generations)])
            predictions = np.array(note_model.predict(ins))

            for i, g in enumerate(generations):
                # Remove the temporal dimension
                g.choose(predictions[i][-1], n)

        # Move one time step
        yield [g.end_time(t) for g in generations]

def write_file(name, results):
    """
    Takes a list of all notes generated per track and writes it to file
    """
    results = zip(*list(results))

    for i, result in enumerate(results):
        fpath = os.path.join(SAMPLES_DIR, name + '_' + str(i) + '.mid')
        print('Writing file', fpath)
        os.makedirs(os.path.dirname(fpath), exist_ok=True)
        mf = midi_encode(unclamp_midi(result))
        midi.write_midifile(fpath, mf)

def main():
    parser = argparse.ArgumentParser(description='Generates music.')
    parser.add_argument('--bars', default=32, type=int, help='Number of bars to generate')
    parser.add_argument('--styles', default=None, type=int, nargs='+', help='Styles to mix together')
    args = parser.parse_args()

    models = build_or_load()

    styles = [compute_genre(i) for i in range(len(genre))]

    if args.styles:
        # Custom style
        styles = [np.mean([one_hot(i, NUM_STYLES) for i in args.styles], axis=0)]

    write_file('output', generate(models, args.bars, styles))

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/midi_util.py:
--------------------------------------------------------------------------------
"""
Handles MIDI file loading
"""
import midi
import numpy as np
import os
from constants import *

def midi_encode(note_seq, resolution=NOTES_PER_BEAT, step=1):
    """
    Takes a piano roll and encodes it into a MIDI pattern
    """
    # Instantiate a MIDI Pattern (contains a list of tracks)
    pattern = midi.Pattern()
    pattern.resolution = resolution
    # Instantiate a MIDI Track (contains a list of MIDI events)
    track = midi.Track()
    # Append the track to the pattern
    pattern.append(track)

    play = note_seq[:, :, 0]
    replay = note_seq[:, :, 1]
    volume = note_seq[:, :, 2]

    # The current pattern being played
    current = np.zeros_like(play[0])
    # Absolute tick of last event
    last_event_tick = 0
    # Amount of NOOP ticks
    noop_ticks = 0

    for tick, data in enumerate(play):
        data = np.array(data)

        if not np.array_equal(current, data):  # or np.any(replay[tick]):
            noop_ticks = 0

            for index, next_volume in np.ndenumerate(data):
                if next_volume > 0 and current[index] == 0:
                    # Was off, but now turned on
                    evt = midi.NoteOnEvent(
                        tick=(tick - last_event_tick) * step,
                        velocity=int(volume[tick][index[0]] * MAX_VELOCITY),
                        pitch=index[0]
                    )
                    track.append(evt)
                    last_event_tick = tick
                elif current[index] > 0 and next_volume == 0:
                    # Was on, but now turned off
                    evt = midi.NoteOffEvent(
                        tick=(tick - last_event_tick) * step,
                        pitch=index[0]
                    )
                    track.append(evt)
                    last_event_tick = tick
                elif current[index] > 0 and next_volume > 0 and replay[tick][index[0]] > 0:
                    # Handle replay
                    evt_off = midi.NoteOffEvent(
                        tick=(tick - last_event_tick) * step,
                        pitch=index[0]
                    )
                    track.append(evt_off)
                    evt_on = midi.NoteOnEvent(
                        tick=0,
                        velocity=int(volume[tick][index[0]] * MAX_VELOCITY),
                        pitch=index[0]
                    )
                    track.append(evt_on)
                    last_event_tick = tick
        else:
            noop_ticks += 1

        current = data

    tick += 1

    # Turn off all remaining on notes
    for index, vol in np.ndenumerate(current):
        if vol > 0:
            # Was on, but now turned off
            evt = midi.NoteOffEvent(
                tick=(tick - last_event_tick) * step,
                pitch=index[0]
            )
            track.append(evt)
            last_event_tick = tick
            noop_ticks = 0

    # Add the end of track event, append it to the track
    eot = midi.EndOfTrackEvent(tick=noop_ticks)
    track.append(eot)

    return pattern

def midi_decode(pattern,
                classes=MIDI_MAX_NOTES,
                step=None):
    """
    Takes a MIDI pattern and decodes it into a piano roll.
    """
    if step is None:
        step = pattern.resolution // NOTES_PER_BEAT

    # Extract all tracks at highest resolution
    merged_replay = None
    merged_volume = None

    for track in pattern:
        # The downsampled sequences
        replay_sequence = []
        volume_sequence = []

        # Raw sequences
        replay_buffer = [np.zeros((classes,))]
        volume_buffer = [np.zeros((classes,))]

        for i, event in enumerate(track):
            # Duplicate the last note pattern to wait for next event
            for _ in range(event.tick):
                replay_buffer.append(np.zeros(classes))
                volume_buffer.append(np.copy(volume_buffer[-1]))

                # Buffer & downscale sequence
                if len(volume_buffer) > step:
                    # Take the min
                    replay_any = np.minimum(np.sum(replay_buffer[:-1], axis=0), 1)
                    replay_sequence.append(replay_any)

                    # Determine volume by max
                    volume_sum = np.amax(volume_buffer[:-1], axis=0)
                    volume_sequence.append(volume_sum)

                    # Keep the last one (discard things in the middle)
                    replay_buffer = replay_buffer[-1:]
                    volume_buffer = volume_buffer[-1:]

            if isinstance(event, midi.EndOfTrackEvent):
                break

            # Modify the last note pattern
            if isinstance(event, midi.NoteOnEvent):
                pitch, velocity = event.data
                volume_buffer[-1][pitch] = velocity / MAX_VELOCITY

                # Check for replay, which is true if the current note was previously played and needs to be replayed
                if len(volume_buffer) > 1 and volume_buffer[-2][pitch] > 0 and volume_buffer[-1][pitch] > 0:
                    replay_buffer[-1][pitch] = 1
                    # Override current volume with previous volume
                    volume_buffer[-1][pitch] = volume_buffer[-2][pitch]

            if isinstance(event, midi.NoteOffEvent):
                pitch, velocity = event.data
                volume_buffer[-1][pitch] = 0

        # Add the remaining
        replay_any = np.minimum(np.sum(replay_buffer, axis=0), 1)
        replay_sequence.append(replay_any)
        volume_sequence.append(volume_buffer[0])

        replay_sequence = np.array(replay_sequence)
        volume_sequence = np.array(volume_sequence)
        assert len(volume_sequence) == len(replay_sequence)

        if merged_volume is None:
            merged_replay = replay_sequence
            merged_volume = volume_sequence
        else:
            # Merge into a single track, padding with zeros if needed
            if len(volume_sequence) > len(merged_volume):
                # Swap variables such that merged_notes is always at least
                # as large as play_sequence
                tmp = replay_sequence
                replay_sequence = merged_replay
                merged_replay = tmp

                tmp = volume_sequence
                volume_sequence = merged_volume
                merged_volume = tmp

            assert len(merged_volume) >= len(volume_sequence)

            diff = len(merged_volume) - len(volume_sequence)
            merged_replay += np.pad(replay_sequence, ((0, diff), (0, 0)), 'constant')
            merged_volume += np.pad(volume_sequence, ((0, diff), (0, 0)), 'constant')

    merged = np.stack([np.ceil(merged_volume), merged_replay, merged_volume], axis=2)
    # Prevent stacking duplicate notes to exceed one.
    merged = np.minimum(merged, 1)
    return merged

def load_midi(fname):
    p = midi.read_midifile(fname)
    cache_path = os.path.join(CACHE_DIR, fname + '.npy')
    try:
        note_seq = np.load(cache_path)
    except Exception as e:
        # Perform caching
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)

        note_seq = midi_decode(p)
        np.save(cache_path, note_seq)

    assert len(note_seq.shape) == 3, note_seq.shape
    assert note_seq.shape[1] == MIDI_MAX_NOTES, note_seq.shape
    assert note_seq.shape[2] == 3, note_seq.shape
    assert (note_seq >= 0).all()
    assert (note_seq <= 1).all()
    return note_seq

if __name__ == '__main__':
    # Round-trip test
    p = midi.read_midifile("out/test_in.mid")
    p = midi_encode(midi_decode(p))
    midi.write_midifile("out/test_out.mid", p)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from keras.layers import Input, LSTM, Dense, Dropout, Lambda, Reshape, Permute
from keras.layers import TimeDistributed, RepeatVector, Conv1D, Activation
from keras.layers import Embedding, Flatten
from keras.layers.merge import Concatenate, Add
from keras.models import Model
import keras.backend as K
from keras import losses

from util import *
from constants import *

def primary_loss(y_true, y_pred):
    # 3 separate loss calculations based on whether the note is played or not
    played = y_true[:, :, :, 0]
    bce_note = losses.binary_crossentropy(y_true[:, :, :, 0], y_pred[:, :, :, 0])
    bce_replay = losses.binary_crossentropy(y_true[:, :, :, 1], tf.multiply(played, y_pred[:, :, :, 1]) + tf.multiply(1 - played, y_true[:, :, :, 1]))
    mse = losses.mean_squared_error(y_true[:, :, :, 2], tf.multiply(played, y_pred[:, :, :, 2]) + tf.multiply(1 - played, y_true[:, :, :, 2]))
    return bce_note + bce_replay + mse
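# Note on the loss above: the replay and volume terms are masked by `played`.
# For notes that are off, the prediction is replaced by the target, so those
# terms contribute zero loss and no gradient - the model is only penalized for
# replay and volume on notes that are actually played.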
def pitch_pos_in_f(time_steps):
    """
    Returns a constant containing pitch position of each note
    """
    def f(x):
        note_ranges = tf.range(NUM_NOTES, dtype='float32') / NUM_NOTES
        repeated_ranges = tf.tile(note_ranges, [tf.shape(x)[0] * time_steps])
        return tf.reshape(repeated_ranges, [tf.shape(x)[0], time_steps, NUM_NOTES, 1])
    return f

def pitch_class_in_f(time_steps):
    """
    Returns a constant containing pitch class of each note
    """
    def f(x):
        pitch_class_matrix = np.array([one_hot(n % OCTAVE, OCTAVE) for n in range(NUM_NOTES)])
        pitch_class_matrix = tf.constant(pitch_class_matrix, dtype='float32')
        pitch_class_matrix = tf.reshape(pitch_class_matrix, [1, 1, NUM_NOTES, OCTAVE])
        return tf.tile(pitch_class_matrix, [tf.shape(x)[0], time_steps, 1, 1])
    return f

def pitch_bins_f(time_steps):
    def f(x):
        bins = tf.reduce_sum([x[:, :, i::OCTAVE, 0] for i in range(OCTAVE)], axis=3)
        bins = tf.tile(bins, [NUM_OCTAVES, 1, 1])
        bins = tf.reshape(bins, [tf.shape(x)[0], time_steps, NUM_NOTES, 1])
        return bins
    return f

def time_axis(dropout):
    def f(notes, beat, style):
        time_steps = int(notes.get_shape()[1])

        # TODO: Experiment with when to apply conv
        note_octave = TimeDistributed(Conv1D(OCTAVE_UNITS, 2 * OCTAVE, padding='same'))(notes)
        note_octave = Activation('tanh')(note_octave)
        note_octave = Dropout(dropout)(note_octave)

        # Create features for every single note.
        note_features = Concatenate()([
            Lambda(pitch_pos_in_f(time_steps))(notes),
            Lambda(pitch_class_in_f(time_steps))(notes),
            Lambda(pitch_bins_f(time_steps))(notes),
            note_octave,
            TimeDistributed(RepeatVector(NUM_NOTES))(beat)
        ])

        x = note_features

        # [batch, notes, time, features]
        x = Permute((2, 1, 3))(x)

        # Apply LSTMs
        for l in range(TIME_AXIS_LAYERS):
            # Integrate style
            style_proj = Dense(int(x.get_shape()[3]))(style)
            style_proj = TimeDistributed(RepeatVector(NUM_NOTES))(style_proj)
            style_proj = Activation('tanh')(style_proj)
            style_proj = Dropout(dropout)(style_proj)
            style_proj = Permute((2, 1, 3))(style_proj)
            x = Add()([x, style_proj])

            x = TimeDistributed(LSTM(TIME_AXIS_UNITS, return_sequences=True))(x)
            x = Dropout(dropout)(x)

        # [batch, time, notes, features]
        return Permute((2, 1, 3))(x)
    return f

def note_axis(dropout):
    dense_layer_cache = {}
    lstm_layer_cache = {}
    note_dense = Dense(2, activation='sigmoid', name='note_dense')
    volume_dense = Dense(1, name='volume_dense')

    def f(x, chosen, style):
        time_steps = int(x.get_shape()[1])

        # Shift target one note to the left.
        shift_chosen = Lambda(lambda x: tf.pad(x[:, :, :-1, :], [[0, 0], [0, 0], [1, 0], [0, 0]]))(chosen)

        # [batch, time, notes, 1]
        shift_chosen = Reshape((time_steps, NUM_NOTES, -1))(shift_chosen)
        # [batch, time, notes, features + 1]
        x = Concatenate(axis=3)([x, shift_chosen])

        for l in range(NOTE_AXIS_LAYERS):
            # Integrate style
            if l not in dense_layer_cache:
                dense_layer_cache[l] = Dense(int(x.get_shape()[3]))

            style_proj = dense_layer_cache[l](style)
            style_proj = TimeDistributed(RepeatVector(NUM_NOTES))(style_proj)
            style_proj = Activation('tanh')(style_proj)
            style_proj = Dropout(dropout)(style_proj)
            x = Add()([x, style_proj])

            if l not in lstm_layer_cache:
                lstm_layer_cache[l] = LSTM(NOTE_AXIS_UNITS, return_sequences=True)

            x = TimeDistributed(lstm_layer_cache[l])(x)
            x = Dropout(dropout)(x)

        return Concatenate()([note_dense(x), volume_dense(x)])
    return f

def build_models(time_steps=SEQ_LEN, input_dropout=0.2, dropout=0.5):
    notes_in = Input((time_steps, NUM_NOTES, NOTE_UNITS))
    beat_in = Input((time_steps, NOTES_PER_BAR))
    style_in = Input((time_steps, NUM_STYLES))
    # Target input for conditioning
    chosen_in = Input((time_steps, NUM_NOTES, NOTE_UNITS))

    # Dropout inputs
    notes = Dropout(input_dropout)(notes_in)
    beat = Dropout(input_dropout)(beat_in)
    chosen = Dropout(input_dropout)(chosen_in)

    # Distributed representations
    style_l = Dense(STYLE_UNITS, name='style')
    style = style_l(style_in)

    """ Time axis """
    time_out = time_axis(dropout)(notes, beat, style)

    """ Note Axis & Prediction Layer """
    naxis = note_axis(dropout)
    notes_out = naxis(time_out, chosen, style)

    model = Model([notes_in, chosen_in, beat_in, style_in], [notes_out])
    model.compile(optimizer='nadam', loss=[primary_loss])

    """ Generation Models """
    time_model = Model([notes_in, beat_in, style_in], [time_out])
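    # Generation decomposition (see generate.py): time_model runs the time
    # axis once per step to produce per-note features; note_model below then
    # predicts pitches one at a time, each conditioned on the notes already
    # chosen at this step.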
    note_features = Input((1, NUM_NOTES, TIME_AXIS_UNITS), name='note_features')
    chosen_gen_in = Input((1, NUM_NOTES, NOTE_UNITS), name='chosen_gen_in')
    style_gen_in = Input((1, NUM_STYLES), name='style_in')

    # Dropout inputs
    chosen_gen = Dropout(input_dropout)(chosen_gen_in)
    style_gen = style_l(style_gen_in)

    note_gen_out = naxis(note_features, chosen_gen, style_gen)

    note_model = Model([note_features, chosen_gen_in, style_gen_in], note_gen_out)

    return model, time_model, note_model
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
keras
tensorflow-gpu
joblib
tqdm
h5py
--------------------------------------------------------------------------------
/scripts/cuda.sh:
--------------------------------------------------------------------------------
# Tensorflow setup on Linux instance
# http://www.nvidia.com/object/gpu-accelerated-applications-tensorflow-installation.html
# https://www.tensorflow.org/install/install_linux
# Run this file from the parent dir of the repository

# Download CUDA (only if it is not already present)
if ! test -e "cuda.deb"
then
    # Ensure same name is used
    wget -O cuda.deb https://developer.nvidia.com/compute/cuda/8.0/Prod2/local_installers/cuda-repo-ubuntu1404-8-0-local-ga2_8.0.61-1_amd64-deb
fi

sudo dpkg -i cuda.deb
yes | sudo apt-get update
yes | sudo apt-get install cuda

# Install libcupti-dev library
yes | sudo apt-get install libcupti-dev

# Set CUDA library paths
echo 'export LD_LIBRARY_PATH=/silo/cuda/lib64:/usr/local/cuda-8.0/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}' >> ~/.bashrc
echo 'export PATH=/usr/local/cuda-8.0/bin${PATH:+:${PATH}}' >> ~/.bashrc
source ~/.bashrc
--------------------------------------------------------------------------------
/scripts/load_data.sh:
--------------------------------------------------------------------------------
# Run this file from the parent dir of the repository
yes | sudo apt-get install unzip

mkdir -p ./music-generator/data
unzip data.zip -d ./music-generator/data
--------------------------------------------------------------------------------
/scripts/mount.sh:
--------------------------------------------------------------------------------
# Mount AWS volume
# http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ebs-using-volumes.html
# TODO: Add auto-remount

lsblk # List volumes
sudo mkdir /silo # Create volume path
sudo mount /dev/xvdf /silo # Mount the volume
--------------------------------------------------------------------------------
/scripts/python.sh:
--------------------------------------------------------------------------------
# Install all Python related things
# Run this file from the parent dir of the repository

# Python files
yes | sudo apt-get install python3 python3-pip python3-dev

# For Python midi
yes | sudo apt-get install libasound2-dev python3-augeas swig

# Install Python MIDI
git clone https://github.com/vishnubob/python-midi/
cd python-midi
git checkout feature/python3
sudo python3 setup.py install
cd ../

# Install project requirements
cd music-generator
pip3 install -r requirements.txt
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from midi_util import *
from util import *
import unittest

class TestMIDIUtil(unittest.TestCase):

    def test_encode(self):
        composition = [
            [0, 1, 0, 0],
            [0, 1, 0, 0],
            [0, 1, 0, 1],
            [0, 1, 0, 1],
            [0, 0, 0, 1],
            [0, 0, 0, 0]
        ]

        replay = [
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0]
        ]

        volume = [
            [0, 0.5, 0, 0],
            [0, 0.5, 0, 0],
            [0, 0.5, 0, 0.5],
            [0, 0.5, 0, 0.5],
            [0, 0, 0, 0.5],
            [0, 0, 0, 0]
        ]

        pattern = midi_encode(np.stack([composition, replay, volume], 2), step=1)
        self.assertEqual(pattern.resolution, NOTES_PER_BEAT)
        self.assertEqual(len(pattern), 1)
        track = pattern[0]
        self.assertEqual(len(track), 4 + 1)
        on1, on2, off1, off2 = track[:-1]
        self.assertIsInstance(on1, midi.NoteOnEvent)
        self.assertIsInstance(on2, midi.NoteOnEvent)
        self.assertIsInstance(off1, midi.NoteOffEvent)
        self.assertIsInstance(off2, midi.NoteOffEvent)

        self.assertEqual(on1.tick, 0)
        self.assertEqual(on1.pitch, 1)
        self.assertEqual(on2.tick, 2)
        self.assertEqual(on2.pitch, 3)
        self.assertEqual(off1.tick, 2)
        self.assertEqual(off1.pitch, 1)
        self.assertEqual(off2.tick, 1)
        self.assertEqual(off2.pitch, 3)

    def test_decode(self):
        # Instantiate a MIDI Pattern (contains a list of tracks)
        pattern = midi.Pattern(resolution=96)
        # Instantiate a MIDI Track (contains a list of MIDI events)
        track = midi.Track()
        # Append the track to the pattern
        pattern.append(track)

        track.append(midi.NoteOnEvent(tick=0, velocity=127, pitch=0))
        track.append(midi.NoteOnEvent(tick=96, velocity=127, pitch=1))
        track.append(midi.NoteOffEvent(tick=0, velocity=127, pitch=0))
        track.append(midi.NoteOffEvent(tick=48, velocity=127, pitch=1))
        track.append(midi.EndOfTrackEvent(tick=1))

        note_sequence = midi_decode(pattern, 4, step=DEFAULT_RES // 2)
        composition = note_sequence[:, :, 0]

        np.testing.assert_array_equal(composition, [
            [1, 0, 0, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 0]
        ])

    def test_encode_decode(self):
        composition = [
            [0, 1, 0, 0],
            [0, 1, 0, 0],
            [0, 1, 0, 1],
            [0, 1, 0, 1],
            [0, 0, 0, 1],
            [0, 0, 0, 0]
        ]

        replay = [
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0]
        ]

        volume = [
            [0, 0.5, 0, 0],
            [0, 0.5, 0, 0],
            [0, 0.5, 0, 0.5],
            [0, 0.5, 0, 0.5],
            [0, 0, 0, 0.5],
            [0, 0, 0, 0]
        ]

        note_seq = midi_decode(midi_encode(np.stack([composition, replay, volume], 2), step=1), 4, step=1)
        np.testing.assert_array_equal(composition, note_seq[:, :, 0])

    def test_replay_decode(self):
        # Instantiate a MIDI Pattern (contains a list of tracks)
        pattern = midi.Pattern(resolution=96)
        # Instantiate a MIDI Track (contains a list of MIDI events)
        track = midi.Track()
        # Append the track to the pattern
        pattern.append(track)

        track.append(midi.NoteOnEvent(tick=0, velocity=127, pitch=1))
        track.append(midi.NoteOnEvent(tick=0, velocity=127, pitch=3))
        track.append(midi.NoteOffEvent(tick=1, velocity=127, pitch=1))
        track.append(midi.NoteOnEvent(tick=2, velocity=127, pitch=1))
        track.append(midi.NoteOnEvent(tick=2, velocity=127, pitch=3))
        track.append(midi.EndOfTrackEvent(tick=1))

        note_seq = midi_decode(pattern, 4, step=3)

        np.testing.assert_array_equal(note_seq[:, :, 1], [
            [0., 0., 0., 0.],
            [0., 0., 0., 1.],
            [0., 0., 0., 0.]
        ])

    def test_volume_decode(self):
        # Instantiate a MIDI Pattern (contains a list of tracks)
        pattern = midi.Pattern(resolution=96)
        # Instantiate a MIDI Track (contains a list of MIDI events)
        track = midi.Track()
        # Append the track to the pattern
        pattern.append(track)

        track.append(midi.NoteOnEvent(tick=0, velocity=24, pitch=0))
        track.append(midi.NoteOnEvent(tick=96, velocity=89, pitch=1))
        track.append(midi.NoteOffEvent(tick=0, pitch=0))
        track.append(midi.NoteOffEvent(tick=48, pitch=1))
        track.append(midi.EndOfTrackEvent(tick=1))

        note_seq = midi_decode(pattern, 4, step=DEFAULT_RES // 2)

        np.testing.assert_array_almost_equal(note_seq[:, :, 2], [
            [24/127, 0., 0., 0.],
            [24/127, 0., 0., 0.],
            [0., 89/127, 0., 0.],
            [0., 0., 0., 0.]
        ], decimal=5)

    def test_replay_encode_decode(self):
        # TODO: Fix this test
        composition = [
            [0, 1, 0, 1],
            [0, 0, 0, 1],
            [0, 0, 0, 1],
            [0, 1, 0, 1],
            [0, 1, 0, 1],
            [0, 1, 0, 1],
            [0, 0, 0, 0]
        ]

        replay = [
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 1],
            [0, 1, 0, 1],
            [0, 0, 0, 0]
        ]

        volume = [
            [0, 0.5, 0, 0.5],
            [0, 0, 0, 0.5],
            [0, 0, 0, 0.5],
            [0, 0.5, 0, 0.5],
            [0, 0.5, 0, 0.5],
            [0, 0.5, 0, 0.5],
            [0, 0, 0, 0]
        ]

        note_seq = midi_decode(midi_encode(np.stack([composition, replay, volume], 2), step=2), 4, step=2)
        np.testing.assert_array_equal(composition, note_seq[:, :, 0])
        # TODO: Downsampling might have caused loss of information
        # np.testing.assert_array_equal(replay, note_seq[:, :, 1])

unittest.main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, LambdaCallback
from keras.callbacks import EarlyStopping, TensorBoard
import argparse
import midi
import os

from constants import *
from dataset import *
from generate import *
from midi_util import midi_encode
from model import *

def main():
    models = build_or_load()
    train(models)

def train(models):
    print('Loading data')
    train_data, train_labels = load_all(styles, BATCH_SIZE, SEQ_LEN)

    cbs = [
        ModelCheckpoint(MODEL_FILE, monitor='loss', save_best_only=True, save_weights_only=True),
        EarlyStopping(monitor='loss', patience=5),
        TensorBoard(log_dir='out/logs', histogram_freq=1)
    ]

    print('Training')
    models[0].fit(train_data, train_labels, epochs=1000, callbacks=cbs, batch_size=BATCH_SIZE)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
import math
import os

from constants import *
from midi_util import *

def one_hot(i, nb_classes):
    arr = np.zeros((nb_classes,))
    arr[i] = 1
    return arr

def build_or_load(allow_load=True):
    from model import build_models
    models = build_models()
    models[0].summary()
    if allow_load:
        try:
            models[0].load_weights(MODEL_FILE)
            print('Loaded model from file.')
        except Exception:
            print('Unable to load model from file.')
    return models

def get_all_files(paths):
    potential_files = []
    for path in paths:
        for root, dirs, files in os.walk(path):
            for f in files:
                fname = os.path.join(root, f)
                if os.path.isfile(fname) and fname.endswith('.mid'):
                    potential_files.append(fname)
    return potential_files
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import os
from keras import backend as K

from util import *
from constants import *

# Visualize using:
# http://projector.tensorflow.org/
def main():
    models = build_or_load()
    style_layer = models[0].get_layer('style')

    print('Creating input')
    style_in = tf.placeholder(tf.float32, shape=(NUM_STYLES, NUM_STYLES))
    style_out = style_layer(style_in)

    # All possible styles
    all_styles = np.identity(NUM_STYLES)

    with K.get_session() as sess:
        embedding = sess.run(style_out, { style_in: all_styles })

    print('Writing to out directory')
    np.savetxt(os.path.join(OUT_DIR, 'style_embedding_vec.tsv'), embedding, delimiter='\t')

    labels = [[g] * len(styles[i]) for i, g in enumerate(genre)]
    # Flatten
    labels = [y for x in labels for y in x]

    # Retrieve specific artists
    styles_labels = [y for x in styles for y in x]

    styles_labels = np.reshape(styles_labels, [-1, 1])
    labels = np.reshape(labels, [-1, 1])
    labels = np.hstack([labels, styles_labels])

    # Add metadata header
    header = ['Genre', 'Artist']
    labels = np.vstack([header, labels])

    np.savetxt(os.path.join(OUT_DIR, 'style_embedding_labels.tsv'), labels, delimiter='\t', fmt='%s')

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------