├── src
    ├── __init__.py
    ├── transform.py
    ├── loss.py
    ├── reconstruction.py
    ├── dataset.py
    ├── checkpoint.py
    ├── sampler.py
    ├── midi_functions.py
    ├── model.py
    ├── trainer.py
    ├── layers.py
    └── preprocess.py
├── README.md
├── requirements.txt
├── scripts
    ├── preprocess.py
    ├── train.py
    └── evaluate.py
├── conf.yml
└── .gitignore

/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # magenta-torch
2 | PyTorch implementation of MusicVAE with LSTM and GRU architectures
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pretty_midi
2 | numpy
3 | torch
4 | pandas
5 | matplotlib
6 | librosa
7 | PyYAML
--------------------------------------------------------------------------------
/src/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | 
4 | class Transform:
5 |     def __init__(self, bars=1, events=4, note_count=61):
6 |         self.split_size = bars*events
7 |         self.note_count = note_count
8 | 
9 |     def get_sections(self, sample_length):
10 |         return math.ceil(sample_length / self.split_size)
11 | 
12 |     def __call__(self, sample, section):
13 |         start = section*self.split_size
14 |         end = min(section*self.split_size + self.split_size, sample.shape[0])
15 |         sample = sample[start : end]
16 | 
17 |         sample_length = len(sample)
18 |         leftover = sample_length % self.split_size
19 |         if leftover != 0:
20 |             padding_size = self.split_size - leftover
21 |             padding = np.zeros((padding_size, sample.shape[1], sample.shape[2]), dtype=float)
22 |             sample = np.concatenate((sample, padding), axis=0)
23 |         return sample.reshape(-1, self.note_count)
--------------------------------------------------------------------------------
/scripts/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import sys
5 | import yaml
6 | 
7 | sys.path.append(".")
8 | 
9 | from src.preprocess import MidiPreprocessor
10 | 
11 | # General settings
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--config', type=str, default='conf.yml')
14 | parser.add_argument('--import_dir', type=str)
15 | parser.add_argument('--save_imported_midi_as_pickle', type=bool, default=True)
16 | parser.add_argument('--save_preprocessed_midi', type=bool, default=True)
17 | 
18 | def main(args):
19 |     conf = None
20 | 
21 |     with open(args.config, 'r') as config_file:
22 |         config = yaml.safe_load(config_file)
23 |         conf = config['preprocessor']
24 | 
25 |     processor = MidiPreprocessor(**conf)
26 |     processor.import_midi_from_folder(args.import_dir,
27 |                                       args.save_imported_midi_as_pickle,
28 |                                       args.save_preprocessed_midi)
29 | 
30 | if __name__ == '__main__':
31 |     args = parser.parse_args()
32 |     main(args)
33 | 
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.functional import binary_cross_entropy
3 | from torch.distributions.normal import Normal
4 | from torch.distributions.kl import kl_divergence
5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6 | 
7 | def 
ELBO(pred, target, mu, sigma, free_bits): 8 | """ 9 | Evidence Lower Bound 10 | Return KL Divergence and KL Regularization using free bits 11 | """ 12 | device = pred.device 13 | # Reconstruction error 14 | # Pytorch cross_entropy combines LogSoftmax and NLLLoss 15 | likelihood = -binary_cross_entropy(pred, target, reduction='sum') 16 | # Regularization error 17 | sigma_prior = torch.tensor([1], dtype=torch.float, device=device) 18 | mu_prior = torch.tensor([0], dtype=torch.float, device=device) 19 | p = Normal(mu_prior, sigma_prior) 20 | q = Normal(mu, sigma) 21 | kl_div = kl_divergence(q, p) 22 | elbo = torch.mean(likelihood) - torch.max(torch.mean(kl_div)-free_bits, torch.tensor([0], dtype=torch.float, device=device)) 23 | 24 | return -elbo, kl_div.mean() 25 | -------------------------------------------------------------------------------- /conf.yml: -------------------------------------------------------------------------------- 1 | preprocessor: 2 | classes: 3 | - Jazz 4 | - Pop 5 | pickle_store_folder: 'pickles' 6 | include_unknown: False 7 | only_unknown: False 8 | low_crop: 24 9 | high_crop: 84 10 | num_notes: 128 11 | smallest_note: 16 12 | max_velocity: 127 13 | include_only_monophonic_instruments: False 14 | max_voices_per_track: 1 15 | max_voices: 4 16 | include_silent_note: True 17 | velocity_threshold: 0.5 18 | instrument_attach_method: '1hot-category' 19 | attach_instruments: False 20 | input_length: &T 256 # Number of beats. We want to split songs into 16 bars of 16th notes 21 | output_length: *T 22 | test_fraction: 0.1 23 | 24 | trainer: 25 | learning_rate: 0.001 26 | KL_rate: 0.9999 27 | free_bits: 256 28 | sampling_rate: 2000 29 | batch_size: &batchsize 2 30 | print_every: 1 31 | checkpoint_every: 10000 32 | checkpoint_dir: 'checkpoint' 33 | output_dir: 'outputs' 34 | 35 | sampler: 36 | free_bits: 256 37 | output_dir: 'outputs' 38 | 39 | model: 40 | num_subsequences: 16 41 | max_sequence_length: *T 42 | sequence_length: 16 43 | encoder_input_size: 61 44 | decoder_input_size: 61 45 | encoder_hidden_size: 2048 46 | decoder_hidden_size: 1024 47 | latent_dim: 512 48 | encoder_num_layers: 2 49 | decoder_num_layers: 2 50 | 51 | data: # Specify paths to data 52 | train_data: 'pickles/X_train_1.pickle' 53 | val_data: '' 54 | train_instrument_data: '' 55 | val_instrument_data: '' 56 | train_tempo_data: '' 57 | val_tempo_data: '' 58 | train_song_paths: '' 59 | val_song_paths: '' 60 | 61 | evaluation: 62 | model_path: 'outputs/checkpoints/gru_small/model.pt' 63 | test_data: 'pickles/X_train_1.pickle' 64 | test_instruments: 'pickles/I_train_1.pickle' 65 | test_songs: 'pickles/train_paths_1.pickle' 66 | test_tempos: 'pickles/T_train_1.pickle' 67 | batch_size: *batchsize 68 | temperature: 1.0 69 | reconstruction: 70 | attach_method: '1hot-category' 71 | song_name: '98_Degrees_-_The_Hardest_Thing' 72 | reconstruction_path: 'midi_reconstruction' 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import sys 5 | import yaml 6 | 7 | sys.path.append(".") 8 | 9 | from torch.utils.data import DataLoader 10 | 11 | from src.model import * 12 | from src.trainer import * 13 | from src.dataset import MidiDataset 14 | 15 | # General settings 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--config', type=str, default='conf.yml') 18 | parser.add_argument('--model_type', type=str, 
default='lstm') 19 | parser.add_argument('--epochs', type=int, default=50) 20 | parser.add_argument('--resume', type=bool, default=False) 21 | 22 | 23 | def load_model(model_type, params): 24 | if model_type == 'lstm': 25 | model = MusicLSTMVAE(**params) 26 | elif model_type == 'gru': 27 | model = MusicGRUVAE(**params) 28 | else: 29 | raise Exception("Invalid model type. Expected lstm or gru") 30 | return model 31 | 32 | 33 | def load_data(train_data, val_data, batch_size, validation_split=0.2, random_seed=874): 34 | train_loader = None 35 | val_loader = None 36 | if train_data != '': 37 | X_train = pickle.load(open(train_data, 'rb')) 38 | train_data = MidiDataset(X_train) 39 | train_loader = DataLoader(train_data, batch_size=batch_size) 40 | if val_data != '': 41 | X_val = pickle.load(open(val_data, 'rb')) 42 | val_data = MidiDataset(X_val) 43 | val_loader = DataLoader(val_data, batch_size=batch_size) 44 | 45 | return train_loader, val_loader 46 | 47 | 48 | def train(model, trainer, train_data, val_data, epochs, resume): 49 | """ 50 | Train a model 51 | """ 52 | trainer.train(model, train_data, None, epochs, resume, val_data) 53 | 54 | 55 | def main(args): 56 | model_params = None 57 | trainer_params = None 58 | data_params = None 59 | with open(args.config, 'r') as config_file: 60 | config = yaml.load(config_file) 61 | model_params = config['model'] 62 | trainer_params = config['trainer'] 63 | data_params = config['data'] 64 | 65 | train_data, val_data = load_data(data_params['train_data'], 66 | data_params['val_data'], 67 | trainer_params['batch_size']) 68 | 69 | model = load_model(args.model_type, model_params) 70 | 71 | trainer = Trainer(**trainer_params) 72 | 73 | train(model, trainer, train_data, val_data, args.epochs, args.resume) 74 | 75 | 76 | if __name__ == '__main__': 77 | args = parser.parse_args() 78 | main(args) 79 | -------------------------------------------------------------------------------- /src/reconstruction.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pretty_midi 3 | 4 | import matplotlib.pyplot as plt 5 | import librosa.display as display 6 | 7 | class MidiBuilder(): 8 | """Build a MIDI from a piano roll sample""" 9 | 10 | def __init__(self, midi_start=24, midi_end=85): #TODO should be 84 11 | """ 12 | Args: 13 | midi_start (int): The first midi note in the dataset 14 | midi_end (int): The last midi note in the dataset 15 | """ 16 | self.dtypes = {'piano_roll_name': 'object', 'timestep': 'uint32'} 17 | self.column_names = [pretty_midi.note_number_to_name(n) for n in range(midi_start, midi_end)] 18 | for column in self.column_names: 19 | self.dtypes[column] = 'uint8' 20 | 21 | 22 | def midi_from_piano_roll(self, sample, tempo = 120): 23 | """ 24 | We're taking some assumptions here to reconstruct the midi. 
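        Assumptions (made explicit from the code below): each active cell in the
        piano roll is written out as its own sixteenth note at the given tempo
        (no lookahead or merging of held notes), every note gets a fixed
        velocity of 100, and the whole roll is assigned to a single piano
        instrument (program 0).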
25 | """ 26 | piano_roll = pd.DataFrame(sample, columns=self.column_names, dtype='uint8') 27 | 28 | program = 0 29 | velocity = int(100) 30 | bps = tempo / 60 31 | sps = bps * 4 # sixteenth notes per second 32 | 33 | # Create a PrettyMIDI object 34 | piano_midi = pretty_midi.PrettyMIDI() 35 | 36 | piano = pretty_midi.Instrument(program=program) 37 | # Iterate over note names, which will be converted to note number later 38 | for idx in piano_roll.index: 39 | for note_name in piano_roll.columns: 40 | #print(note_name) 41 | 42 | # Check if the note is activated at this timestep 43 | if piano_roll.iloc[idx][note_name] == 1.: 44 | # Retrieve the MIDI note number for this note name 45 | note_number = pretty_midi.note_name_to_number(note_name) 46 | 47 | note_start = idx/sps # 0 if tempo = 60 48 | note_end = (idx+1)/sps # 0.25 49 | 50 | # Create a Note instance, starting according to the timestep * 16ths, ending one sixteenth later 51 | # TODO: Smooth this a bit by using lookahead 52 | note = pretty_midi.Note( 53 | velocity=velocity, pitch=note_number, start=note_start, end=note_end) 54 | # Add it to our instrument 55 | piano.notes.append(note) 56 | # Add the instrument to the PrettyMIDI object 57 | piano_midi.instruments.append(piano) 58 | return piano_midi 59 | 60 | # Write out the MIDI data 61 | #piano_midi.write('name.mid') 62 | 63 | def plot_midi(self, midi_sample): 64 | display.specshow(midi_sample.get_piano_roll(), y_axis='cqt_note', cmap=plt.cm.hot) 65 | 66 | def play_midi(self, midi_sample): 67 | fs = 44100 68 | synth = midi_sample.synthesize(fs=fs) 69 | return [synth], fs 70 | 71 | def decode_song(enc): 72 | piano_midi = pretty_midi.PrettyMIDI() 73 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | class MidiDataset(Dataset): 9 | """ 10 | Inputs: 11 | - song_tensor: input pitches of shape (num_samples, input_length, different_pitches) 12 | - song_paths: list of paths to input midi files 13 | 14 | Attributes: 15 | - midi_paths: list of paths to midi files 16 | - song_names: List of song names in dataset 17 | - song_name_to_idx: Dictionary mapping song name to idx in song_names and midi paths 18 | - index_mapper: List of tuple (song_idx, bar_idx) for each song 19 | """ 20 | def __init__(self, input_tensor, song_paths=None, instruments=None, tempos=None): 21 | self.song_tensor = [x.astype(float) for x in input_tensor] 22 | self.midi_paths = song_paths 23 | self.song_names = None 24 | if song_paths is not None: 25 | self.song_names = [os.path.basename(x).split('.')[0] for x in song_paths] 26 | self.index_mapper, self.song_to_bar_idx = self._initialize() 27 | if self.song_names is not None: 28 | self.song_to_idx = {v:k for (k,v) in enumerate(self.song_names)} 29 | self.instruments=instruments 30 | self.tempos = tempos 31 | 32 | def _initialize(self): 33 | index_mapper = [] 34 | song_to_idx = dict() 35 | 36 | for song_idx in range(len(self.song_tensor)): 37 | song_tuples = [] 38 | split_count = self.song_tensor[song_idx].shape[0] 39 | for bar_idx in range(0, split_count): 40 | song_tuples.append((song_idx, bar_idx)) 41 | index_mapper.append((song_idx, bar_idx)) 42 | 43 | if self.song_names is not None: 44 | song_name = self.song_names[song_idx] 45 | song_to_idx[song_name] = song_tuples 46 | 
return index_mapper, song_to_idx 47 | 48 | def __len__(self): 49 | return len(self.index_mapper) 50 | 51 | def __getitem__(self, idx): 52 | """ 53 | A sample is B consecutive bars. In MusicVAE this would be 16 consecutive 54 | bars. For MidiVAE a sample consists of a single bar. 55 | """ 56 | song_idx, section_idx = self.index_mapper[idx] 57 | 58 | sample = self.song_tensor[song_idx] 59 | sample = sample[section_idx,:,:] 60 | x = torch.tensor(sample, dtype=torch.float) 61 | return x.to(device) 62 | 63 | def get_tensor_by_name(self, song_name): 64 | """ 65 | Return tensor for specified song partitioned by bars 66 | """ 67 | song_idx = self.song_to_idx[song_name] 68 | # song_idx, bar_idx = self.song_to_bar_idx[song_name] 69 | # print(bar_idx) 70 | samples = self.song_tensor[song_idx] 71 | samples = torch.tensor(samples, dtype=torch.float) 72 | return samples 73 | 74 | def get_aux_by_names(self, song_name): 75 | """ 76 | Return aux information such as instruments and tempo 77 | """ 78 | if self.song_to_idx is not None: 79 | idx = self.song_to_idx[song_name] 80 | if self.instruments is not None: 81 | return self.instruments[idx], self.tempos[idx] 82 | return None -------------------------------------------------------------------------------- /src/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import os 4 | import shutil 5 | 6 | class Checkpoint: 7 | """ 8 | ---------------------------------------------------------------------------- 9 | A Checkpoint class. 10 | 11 | The purpose is to be organized with our experiments, as well to save and 12 | load progress in training. 13 | ---------------------------------------------------------------------------- 14 | """ 15 | 16 | CHECKPOINT_DIR_NAME = 'checkpoints' 17 | TRAINER_STATE_NAME = 'trainer_states.pt' 18 | MODEL_NAME = 'model.pt' 19 | 20 | def __init__(self, model, epoch, step, optimizer, scheduler, 21 | samp_rate, KL_rate, free_bits, path=None): 22 | self.model = model 23 | self.epoch = epoch 24 | self.step = step 25 | self.optimizer = optimizer 26 | self.scheduler = scheduler 27 | self.samp_rate = samp_rate 28 | self.KL_rate = KL_rate 29 | self.free_bits = free_bits 30 | self._path = path 31 | 32 | @property 33 | def path(self): 34 | if self._path is None: 35 | raise LookupError("This checkpoint has not been saved.") 36 | return self._path 37 | 38 | def save(self, experiment_dir): 39 | 40 | date_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime()) 41 | self._path = os.path.join(experiment_dir, self.CHECKPOINT_DIR_NAME, date_time) 42 | path = self._path 43 | 44 | #If path exists, erase the whole thing nad make a new one. 45 | if os.path.exists(path): 46 | shutil.rmtree(path) 47 | os.makedirs(path) 48 | 49 | torch.save({'epoch': self.epoch, 50 | 'step': self.step, 51 | 'optimizer': self.optimizer, 52 | 'scheduler': self.scheduler, 53 | 'samp_rate': self.samp_rate, 54 | 'KL_rate': self.KL_rate, 55 | 'free_bits': self.free_bits 56 | }, 57 | os.path.join(path, self.TRAINER_STATE_NAME)) 58 | torch.save(self.model, os.path.join(path, self.MODEL_NAME)) 59 | 60 | return path 61 | 62 | @classmethod 63 | def load(cls, path): 64 | 65 | #Check if GPU is available. 
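        #If it is not, map_location is used below to remap tensors that were
        #saved from a CUDA device onto the CPU so the checkpoint still loads.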
66 | if torch.cuda.is_available(): 67 | resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME)) 68 | model = torch.load(os.path.join(path, cls.MODEL_NAME)) 69 | else: 70 | resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME),map_location=lambda storage, loc: storage) 71 | model = torch.load(os.path.join(path, cls.MODEL_NAME),map_location=lambda storage, loc: storage) 72 | 73 | #Make RNN parameters contiguous. 74 | # encoder.flatten_parameters() 75 | # decoder.flatten_parameters() 76 | return Checkpoint(model= model, 77 | epoch= resume_checkpoint['epoch'], 78 | step=resume_checkpoint['step'], 79 | optimizer= resume_checkpoint['optimizer'], 80 | scheduler = resume_checkpoint['scheduler'], 81 | samp_rate = resume_checkpoint['samp_rate'], 82 | KL_rate = resume_checkpoint['KL_rate'], 83 | free_bits = resume_checkpoint['free_bits'], 84 | path= path) 85 | @classmethod 86 | def get_latest_checkpoint(cls, exp_path): 87 | checkpoints_path = os.path.join(exp_path, cls.CHECKPOINT_DIR_NAME) 88 | all_times = sorted(os.listdir(checkpoints_path), reverse=True) 89 | return os.path.join(checkpoints_path, all_times[0]) -------------------------------------------------------------------------------- /src/sampler.py: -------------------------------------------------------------------------------- 1 | from src.checkpoint import Checkpoint 2 | 3 | from src.loss import ELBO 4 | 5 | import torch 6 | import torch.optim as optim 7 | from torch.optim.lr_scheduler import ExponentialLR, LambdaLR 8 | 9 | from math import exp 10 | import numpy as np 11 | 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | def spherical_interpolation(p0, p1, t): 16 | omega = np.arccos(np.dot(np.squeeze(p0/np.linalg.norm(p0)), 17 | np.squeeze(p1/np.linalg.norm(p1)))) 18 | so = np.sin(omega) 19 | return np.sin((1.0-t)*omega) / so * p0 + np.sin(t*omega)/so * p1 20 | 21 | class Sampler: 22 | def __init__(self, 23 | free_bits=256, 24 | output_dir='samples'): 25 | self.free_bits = free_bits 26 | self.output_dir = output_dir 27 | 28 | def reconstruction_loss(self, model, batch): 29 | """ 30 | Return reconstruction loss with and witout teacher forcing 31 | """ 32 | pred_tf, _, _, _ = model(batch, True) 33 | pred, _, _, _ = model(batch, False) 34 | loss_tf = torch.nn.functional.binary_cross_entropy_with_logits(pred_tf, batch, reduction='mean') 35 | loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, batch, reduction='mean') 36 | return loss_tf, loss 37 | 38 | def evaluate(self, model, data): 39 | """ 40 | Evaluate test test data directly using logits 41 | """ 42 | model.eval() 43 | loss_acc_tf = 0 44 | loss_acc = 0 45 | with torch.no_grad(): 46 | for idx, batch in enumerate(data): 47 | batch = batch.transpose(0, 1) 48 | batch_size = batch.size(1) 49 | batch = batch.view(model.max_sequence_length, batch_size, model.decoder.input_size) 50 | batch.to(device) 51 | batch_loss_tf, batch_loss = self.reconstruction_loss(model, batch) 52 | loss_acc_tf += batch_loss_tf 53 | loss_acc += batch_loss 54 | print('idx: %d, loss_tf: %.4f, loss: %.4f' % (idx, batch_loss_tf, batch_loss)) 55 | return loss_acc_tf / len(data), loss_acc / len(data) 56 | 57 | def reconstruct(self, model, song, temperature): 58 | """ 59 | Reconstruct song 60 | """ 61 | model.eval() 62 | with torch.no_grad(): 63 | song = song.transpose(0, 1) 64 | batch_size = song.size(1) 65 | song.view(model.max_sequence_length, batch_size, model.decoder.input_size) 66 | song.to(device) 67 | sample = 
model.reconstruct(song, temperature) 68 | # Samples are currently (seq_len, batch_size, num_notes) where 69 | # batch_size is the number of segments of 16 bars. These belong to the 70 | # same song so we want to return the concatenation of the entire song 71 | sample = sample.view(-1, model.input_size) 72 | return sample 73 | 74 | # def interpolate(self, model, song_A, song_B, num_steps, 75 | # length=None, temperature=1.0, assert_same_length=True) 76 | # """ 77 | # Args: 78 | # model: Trained model 79 | # start_sequence: The NoteSequence to interpolate from. 80 | # end_sequence: The NoteSequence to interpolate to. 81 | # num_steps: Number of NoteSequences to be generated, including the 82 | # reconstructions of the start and end sequences. 83 | # length: The maximum length of a sample in decoder iterations. Required 84 | # if end tokens are not being used. 85 | # temperature: The softmax temperature to use (if applicable). 86 | # assert_same_length: Whether to raise an AssertionError if all of the 87 | # extracted sequences are not the same length. 88 | # Returns: 89 | # A list of interpolated NoteSequences. 90 | # """ 91 | # model.eval() 92 | # model.to(device) 93 | # batch_size = 2 94 | 95 | # # Load songs 96 | # input = torch.randn(256, batch_size, 61) 97 | # reconstructed, _, _, latent = model(input, use_teacher_forcing=True) 98 | # # Interpolate between latent spaces 99 | # z_interpolated = np.array([spherical_interpolation(latent[0], latent[1], t) 100 | # for t in np.linspace(0, 1, num_steps)]) 101 | # # Decode interpolations 102 | # decoded = [model.decoder(step, use_teacher_forcing=True, temperature=temperature) 103 | # for step in z_interpolated] 104 | # # Reconstruct decoded interpolations 105 | -------------------------------------------------------------------------------- /src/midi_functions.py: -------------------------------------------------------------------------------- 1 | # Taken from MidiVAE (Brunner et al. 
2018) 2 | import numpy as np 3 | import _pickle as pickle 4 | import os 5 | import sys 6 | import pretty_midi as pm 7 | import mido 8 | import operator 9 | 10 | 11 | def programs_to_instrument_matrix(programs, instrument_attach_method, max_voices): 12 | 13 | if instrument_attach_method == '1hot-instrument': 14 | #very large, not recommended 15 | instrument_feature_matrix = np.zeros((max_voices, 128)) 16 | for i, program in enumerate(programs): 17 | instrument_feature_matrix[i, program] = 1 18 | 19 | elif instrument_attach_method == '1hot-category': 20 | #categories according to midi declaration, https://en.wikipedia.org/wiki/General_MIDI 21 | #8 consecutive instruments make 1 category 22 | instrument_feature_matrix = np.zeros((max_voices, 16)) 23 | for i, program in enumerate(programs): 24 | instrument_feature_matrix[i, program//8] = 1 25 | 26 | elif instrument_attach_method == 'khot-instrument': 27 | #make a khot vector in log2 base for the instrument 28 | #log2(128) = 7 29 | instrument_feature_matrix = np.zeros((max_voices, 7)) 30 | for i, program in enumerate(programs): 31 | p = program 32 | for exponent in range(7): 33 | if p % 2 == 0: 34 | instrument_feature_matrix[i, exponent] = 1 35 | p = p // 2 36 | elif instrument_attach_method == 'khot-category': 37 | #categories according to midi declaration, https://en.wikipedia.org/wiki/General_MIDI 38 | #8 consecutive instruments make 1 category 39 | #make a khot vector in log2 base for the category 40 | #log2(16) = 4 41 | instrument_feature_matrix = np.zeros((max_voices, 4)) 42 | for i, program in enumerate(programs): 43 | p = program//8 44 | for exponent in range(4): 45 | if p % 2 == 1: 46 | instrument_feature_matrix[i, exponent] = 1 47 | p = p // 2 48 | else: 49 | print("Not implemented!") 50 | 51 | return instrument_feature_matrix 52 | 53 | 54 | def rolls_to_midi(pianoroll, 55 | programs, 56 | save_folder, 57 | filename, 58 | bpm, 59 | low_crop, 60 | high_crop, 61 | num_notes, 62 | velocity_threshold, 63 | velocity_roll=None, 64 | held_notes_roll=None, 65 | smallest_note=16, 66 | max_velocity=127): 67 | 68 | #bpm is in quarter notes, so scale accordingly 69 | # bpm = bpm * (smallest_note / 4) 70 | 71 | pianoroll = np.pad(np.copy(pianoroll), ((0,0),(low_crop,num_notes-high_crop)), mode='constant', constant_values=0) 72 | 73 | if not os.path.exists(save_folder): 74 | os.makedirs(save_folder) 75 | midi = pm.PrettyMIDI(initial_tempo=bpm, resolution=1000) 76 | midi.time_signature_changes.append(pm.TimeSignature(4, 4, 0)) 77 | 78 | for voice, program in enumerate(programs): 79 | 80 | current_instrument = pm.Instrument(program=program) 81 | current_pianoroll = pianoroll[voice::len(programs),:] 82 | 83 | if velocity_roll is not None: 84 | current_velocity_roll = np.copy(velocity_roll[voice::len(programs)]) 85 | #during the training, the velocities were scaled to be in the range 0,1 86 | #scale it back to the actual velocity numbers 87 | current_velocity_roll[np.where(current_velocity_roll < velocity_threshold)] = 0 88 | current_velocity_roll[np.where(current_velocity_roll >= velocity_threshold)] -= 0.5 89 | current_velocity_roll /= (1.0 - velocity_threshold) 90 | current_velocity_roll *= max_velocity 91 | 92 | if held_notes_roll is not None: 93 | current_held_notes_roll = np.copy(held_notes_roll[voice::len(programs)]) 94 | 95 | 96 | 97 | tracker = [] 98 | start_times = dict() 99 | velocities = dict() 100 | for i, note_vector in enumerate(current_pianoroll): 101 | notes = list(note_vector.nonzero()[0]) 102 | # 103 | #notes that were just played 
and need to be removed from the tracker
104 |             removal_list = []
105 |             for note in tracker:
106 | 
107 |                 #determine if you still hold this note or not
108 |                 hold_this_note = True
109 |                 if held_notes_roll is not None:
110 |                     hold_this_note = current_held_notes_roll[i] > 0.5
111 | 
112 |                     #it may happen that a note seems to be held but has switched to another channel
113 |                     #in that case, play the note anyways
114 |                     if note not in notes:
115 |                         hold_this_note = False
116 | 
117 |                 else:
118 |                     hold_this_note = note in notes and i % smallest_note != 0
119 | 
120 |                 if hold_this_note:
121 |                     #held note, don't play a new note
122 |                     notes.remove(note)
123 |                 else:
124 |                     if velocity_roll is not None:
125 |                         velocity = velocities[note]
126 |                         if velocity > max_velocity:
127 |                             velocity = int(max_velocity)
128 |                     else:
129 |                         velocity = 80
130 | 
131 |                     midi_note = pm.Note(velocity=velocity, pitch=note, start=(60/bpm)*start_times[note], end=(60/bpm)*i)
132 |                     current_instrument.notes.append(midi_note)
133 | 
134 |                     removal_list.append(note)
135 |             for note in removal_list:
136 |                 tracker.remove(note)
137 | 
138 | 
139 |             for note in notes:
140 |                 tracker.append(note)
141 |                 start_times[note]=i
142 |                 if velocity_roll is not None:
143 |                     velocities[note] = int(current_velocity_roll[i])
144 | 
145 |         midi.instruments.append(current_instrument)
146 |     midi.write(os.path.join(save_folder,filename+'.mid'))
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Core latex/pdflatex auxiliary files:
2 | *.aux
3 | *.lof
4 | *.log
5 | *.lot
6 | *.fls
7 | *.out
8 | *.toc
9 | *.fmt
10 | *.fot
11 | *.cb
12 | *.cb2
13 | .*.lb
14 | 
15 | ## Intermediate documents:
16 | *.dvi
17 | *.xdv
18 | *-converted-to.*
19 | # these rules might exclude image files for figures etc.
20 | # *.ps 21 | # *.eps 22 | # *.pdf 23 | 24 | ## Generated if empty string is given at "Please type another file name for output:" 25 | .pdf 26 | 27 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 28 | *.bbl 29 | *.bcf 30 | *.blg 31 | *-blx.aux 32 | *-blx.bib 33 | *.run.xml 34 | 35 | ## Build tool auxiliary files: 36 | *.fdb_latexmk 37 | *.synctex 38 | *.synctex(busy) 39 | *.synctex.gz 40 | *.synctex.gz(busy) 41 | *.pdfsync 42 | 43 | ## Build tool directories for auxiliary files 44 | # latexrun 45 | latex.out/ 46 | 47 | ## Auxiliary and intermediate files from other packages: 48 | # algorithms 49 | *.alg 50 | *.loa 51 | 52 | # achemso 53 | acs-*.bib 54 | 55 | # amsthm 56 | *.thm 57 | 58 | # beamer 59 | *.nav 60 | *.pre 61 | *.snm 62 | *.vrb 63 | 64 | # changes 65 | *.soc 66 | 67 | # comment 68 | *.cut 69 | 70 | # cprotect 71 | *.cpt 72 | 73 | # elsarticle (documentclass of Elsevier journals) 74 | *.spl 75 | 76 | # endnotes 77 | *.ent 78 | 79 | # fixme 80 | *.lox 81 | 82 | # feynmf/feynmp 83 | *.mf 84 | *.mp 85 | *.t[1-9] 86 | *.t[1-9][0-9] 87 | *.tfm 88 | 89 | #(r)(e)ledmac/(r)(e)ledpar 90 | *.end 91 | *.?end 92 | *.[1-9] 93 | *.[1-9][0-9] 94 | *.[1-9][0-9][0-9] 95 | *.[1-9]R 96 | *.[1-9][0-9]R 97 | *.[1-9][0-9][0-9]R 98 | *.eledsec[1-9] 99 | *.eledsec[1-9]R 100 | *.eledsec[1-9][0-9] 101 | *.eledsec[1-9][0-9]R 102 | *.eledsec[1-9][0-9][0-9] 103 | *.eledsec[1-9][0-9][0-9]R 104 | 105 | # glossaries 106 | *.acn 107 | *.acr 108 | *.glg 109 | *.glo 110 | *.gls 111 | *.glsdefs 112 | 113 | # gnuplottex 114 | *-gnuplottex-* 115 | 116 | # gregoriotex 117 | *.gaux 118 | *.gtex 119 | 120 | # htlatex 121 | *.4ct 122 | *.4tc 123 | *.idv 124 | *.lg 125 | *.trc 126 | *.xref 127 | 128 | # hyperref 129 | *.brf 130 | 131 | # knitr 132 | *-concordance.tex 133 | # TODO Comment the next line if you want to keep your tikz graphics files 134 | *.tikz 135 | *-tikzDictionary 136 | 137 | # listings 138 | *.lol 139 | 140 | # luatexja-ruby 141 | *.ltjruby 142 | 143 | # makeidx 144 | *.idx 145 | *.ilg 146 | *.ind 147 | *.ist 148 | 149 | # minitoc 150 | *.maf 151 | *.mlf 152 | *.mlt 153 | *.mtc[0-9]* 154 | *.slf[0-9]* 155 | *.slt[0-9]* 156 | *.stc[0-9]* 157 | 158 | # minted 159 | _minted* 160 | *.pyg 161 | 162 | # morewrites 163 | *.mw 164 | 165 | # nomencl 166 | *.nlg 167 | *.nlo 168 | *.nls 169 | 170 | # pax 171 | *.pax 172 | 173 | # pdfpcnotes 174 | *.pdfpc 175 | 176 | # sagetex 177 | *.sagetex.sage 178 | *.sagetex.py 179 | *.sagetex.scmd 180 | 181 | # scrwfile 182 | *.wrt 183 | 184 | # sympy 185 | *.sout 186 | *.sympy 187 | sympy-plots-for-*.tex/ 188 | 189 | # pdfcomment 190 | *.upa 191 | *.upb 192 | 193 | # pythontex 194 | *.pytxcode 195 | pythontex-files-*/ 196 | 197 | # tcolorbox 198 | *.listing 199 | 200 | # thmtools 201 | *.loe 202 | 203 | # TikZ & PGF 204 | *.dpth 205 | *.md5 206 | *.auxlock 207 | 208 | # todonotes 209 | *.tdo 210 | 211 | # vhistory 212 | *.hst 213 | *.ver 214 | 215 | # easy-todo 216 | *.lod 217 | 218 | # xcolor 219 | *.xcp 220 | 221 | # xmpincl 222 | *.xmpi 223 | 224 | # xindy 225 | *.xdy 226 | 227 | # xypic precompiled matrices 228 | *.xyc 229 | 230 | # endfloat 231 | *.ttt 232 | *.fff 233 | 234 | # Latexian 235 | TSWLatexianTemp* 236 | 237 | ## Editors: 238 | # WinEdt 239 | *.bak 240 | *.sav 241 | 242 | # Texpad 243 | .texpadtmp 244 | 245 | # LyX 246 | *.lyx~ 247 | 248 | # Kile 249 | *.backup 250 | 251 | # KBibTeX 252 | *~[0-9]* 253 | 254 | # auto folder when using emacs and auctex 255 | ./auto/* 256 | *.el 257 | 258 | # expex forward references with \gathertags 259 | 
*-tags.tex 260 | 261 | # standalone packages 262 | *.sta 263 | 264 | # Byte-compiled / optimized / DLL files 265 | __pycache__/ 266 | *.py[cod] 267 | *$py.class 268 | 269 | # C extensions 270 | *.so 271 | 272 | # Distribution / packaging 273 | .Python 274 | build/ 275 | develop-eggs/ 276 | dist/ 277 | downloads/ 278 | eggs/ 279 | .eggs/ 280 | lib/ 281 | lib64/ 282 | parts/ 283 | sdist/ 284 | var/ 285 | wheels/ 286 | pip-wheel-metadata/ 287 | share/python-wheels/ 288 | *.egg-info/ 289 | .installed.cfg 290 | *.egg 291 | MANIFEST 292 | 293 | # PyInstaller 294 | # Usually these files are written by a python script from a template 295 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 296 | *.manifest 297 | *.spec 298 | 299 | # Installer logs 300 | pip-log.txt 301 | pip-delete-this-directory.txt 302 | 303 | # Unit test / coverage reports 304 | htmlcov/ 305 | .tox/ 306 | .nox/ 307 | .coverage 308 | .coverage.* 309 | .cache 310 | nosetests.xml 311 | coverage.xml 312 | *.cover 313 | .hypothesis/ 314 | .pytest_cache/ 315 | 316 | # Translations 317 | *.mo 318 | *.pot 319 | 320 | # Django stuff: 321 | *.log 322 | local_settings.py 323 | db.sqlite3 324 | 325 | # Flask stuff: 326 | instance/ 327 | .webassets-cache 328 | 329 | # Scrapy stuff: 330 | .scrapy 331 | 332 | # Sphinx documentation 333 | docs/_build/ 334 | 335 | # PyBuilder 336 | target/ 337 | 338 | # Jupyter Notebook 339 | .ipynb_checkpoints 340 | 341 | # IPython 342 | profile_default/ 343 | ipython_config.py 344 | 345 | # pyenv 346 | .python-version 347 | 348 | # pipenv 349 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 350 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 351 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 352 | # install all needed dependencies. 
353 | #Pipfile.lock 354 | 355 | # celery beat schedule file 356 | celerybeat-schedule 357 | 358 | # SageMath parsed files 359 | *.sage.py 360 | 361 | # Environments 362 | .env 363 | .venv 364 | env/ 365 | venv/ 366 | ENV/ 367 | env.bak/ 368 | venv.bak/ 369 | 370 | # Spyder project settings 371 | .spyderproject 372 | .spyproject 373 | 374 | # Rope project settings 375 | .ropeproject 376 | 377 | # mkdocs documentation 378 | /site 379 | 380 | # mypy 381 | .mypy_cache/ 382 | .dmypy.json 383 | dmypy.json 384 | 385 | # Pyre type checker 386 | .pyre/ 387 | 388 | 389 | /data 390 | 391 | .DS_Store 392 | outputs 393 | 394 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import sys 5 | import yaml 6 | 7 | sys.path.append(".") 8 | 9 | from torch.utils.data import DataLoader 10 | 11 | from src.model import * 12 | from src.sampler import * 13 | from src.dataset import MidiDataset 14 | from src.midi_functions import rolls_to_midi 15 | 16 | # General settings 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--config', type=str, default='conf.yml') 19 | parser.add_argument('--model_type', type=str, default='lstm') 20 | parser.add_argument('--mode', type=str, choices=['eval', 'interpolate', 'reconstruct']) 21 | 22 | 23 | def load_model(model_type, params): 24 | if model_type == 'lstm': 25 | model = MusicLSTMVAE(**params) 26 | elif model_type == 'gru': 27 | model = MusicGRUVAE(**params) 28 | else: 29 | raise Exception("Invalid model type. Expected lstm or gru") 30 | return model 31 | 32 | def load_data(test_data, batch_size, song_paths='', instrument_path='', tempo_path=''): 33 | X_test = pickle.load(open(test_data, 'rb')) 34 | 35 | song_names = None 36 | if song_paths != '': 37 | song_names = [os.path.basename(x) for x in pickle.load(open(song_paths, 'rb'))] 38 | 39 | instruments = None 40 | if instrument_path != '': 41 | instruments = pickle.load(open(instrument_path, 'rb')) 42 | 43 | tempos = None 44 | if tempo_path != '': 45 | tempos = pickle.load(open(tempo_path, 'rb')) 46 | 47 | test_data = MidiDataset(X_test, song_paths=song_names, instruments=instruments, tempos=tempos) 48 | test_loader = DataLoader(test_data, batch_size=batch_size) 49 | return test_loader 50 | 51 | def load_tempo(tempo_path, song_id): 52 | if temp_path is None: 53 | raise ValueError('Tempo file unspecified') 54 | else: 55 | tempos = pickle.load(open(tempo_path, 'rb')) 56 | return tempos[song_id] 57 | 58 | def evaluate(sampler, model, args): 59 | data_path = args['test_data'] 60 | song_names = args['test_songs'] 61 | batch_size = args['batch_size'] 62 | data = load_data(test_data=data_path, batch_size=batch_size, instrument_path='', song_paths=song_names) 63 | loss_tf, loss = sampler.evaluate(model, data) 64 | print("Loss with teacher forcing: %.4f, loss without teacher forcing: %.4f" % (loss_tf, loss)) 65 | 66 | def instrument_representation_to_programs(I, instrument_attach_method='1hot-category'): 67 | programs = [] 68 | for instrument_vector in I: 69 | if instrument_attach_method == '1hot-category': 70 | index = np.argmax(instrument_vector) 71 | programs.append(index * 8) 72 | elif instrument_attach_method == 'khot-category': 73 | nz = np.nonzero(instrument_vector)[0] 74 | index = 0 75 | for exponent in nz: 76 | index += 2^exponent 77 | programs.append(index * 8) 78 | elif instrument_attach_method == '1hot-instrument': 79 | 
index = np.argmax(instrument_vector) 80 | programs.append(index) 81 | elif instrument_attach_method == 'khot-instrument': 82 | nz = np.nonzero(instrument_vector)[0] 83 | index = 0 84 | for exponent in nz: 85 | index += 2^exponent 86 | programs.append(index) 87 | return programs 88 | 89 | def reconstruct(sampler, model, evaluation_params): 90 | # Load data 91 | data_path = evaluation_params['test_data'] 92 | song_names = evaluation_params['test_songs'] 93 | tempos = evaluation_params['test_tempos'] 94 | instruments = evaluation_params['test_instruments'] 95 | batch_size = evaluation_params['batch_size'] 96 | data = load_data(data_path, batch_size, song_names, instruments, tempos) 97 | 98 | # Reconstruct specified song 99 | reconstruction_params = evaluation_params['reconstruction'] 100 | song_id = reconstruction_params['song_name'] 101 | temperature = evaluation_params['temperature'] 102 | attach_method = reconstruction_params['attach_method'] 103 | reconstruction_path = reconstruction_params['reconstruction_path'] 104 | song = data.dataset.get_tensor_by_name(song_id) 105 | # Generate reconstruction from the samples 106 | reconstructed = sampler.reconstruct(model, song, temperature) 107 | # Reconstruct into midi form 108 | I, tempo = data.dataset.get_aux_by_names(song_id) 109 | programs = instrument_representation_to_programs(I, attach_method) 110 | 111 | rolls_to_midi(reconstructed, 112 | programs, 113 | reconstruction_path, 114 | song_id, 115 | tempo, 116 | 24, 117 | 84, 118 | 128, 119 | 0.5) 120 | 121 | print('Saved reconstruction for %s' % song_id) 122 | 123 | def interpolate(): 124 | raise ValueError("Not implemented") 125 | 126 | def main(args): 127 | model_params = None 128 | sampler = None 129 | data_params = None 130 | evaluation_params = None 131 | with open(args.config, 'r') as config_file: 132 | config = yaml.load(config_file) 133 | model_params = config['model'] 134 | sampler_params = { 135 | 'free_bits': config['sampler']['free_bits'], 136 | 'output_dir': config['sampler']['output_dir'] 137 | } 138 | data_params = config['data'] 139 | evaluation_params = config['evaluation'] 140 | 141 | model = load_model(args.model_type, model_params) 142 | sampler = Sampler(**sampler_params) 143 | 144 | model.load_state_dict(torch.load(evaluation_params['model_path'], 145 | map_location='cpu').state_dict(), strict=False) 146 | print(model) 147 | model.eval() 148 | 149 | mode = args.mode 150 | if mode == 'eval': 151 | evaluate(sampler, model, evaluation_params) 152 | elif mode == 'reconstruct': 153 | reconstruct(sampler, model, evaluation_params) 154 | 155 | # elif mode == 'interpolate': 156 | # song_id_A = args.song_id_a 157 | # song_id_B = args.song_id_b 158 | # data = load_data(data_path, args.batch_size) 159 | # song_a = data.get_tensor_by_name(song_id_a) 160 | # song_b = data.get_tensor_by_name(song_id_b) 161 | # interpolated = sampler.interpolate(model, song_a, song_b) 162 | # # TODO save interpolated 163 | 164 | if __name__ == '__main__': 165 | args = parser.parse_args() 166 | main(args) 167 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from src.layers import * 5 | 6 | """ 7 | VAE models 8 | """ 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | class MusicLSTMVAE(nn.Module): 11 | """ 12 | Inputs 13 | - encoder_input_size: Size of input representation (i.e 60 music pitches + 1 
silence) 14 | - num_subsequences: Number of subsequences to partition input (corresponds to U) 15 | - sequence_length: Length of sequences (T=32 for 2-bar and T=256 for 16-bar data) 16 | - encoder_hidden_size: dimension of encoder hidden state 17 | - decoder_hidden_size: dimension of decoder hidden state 18 | - latent_dim: dimension of latent variable z 19 | - encoder_num_layers: Number of encoder lstm layers 20 | - decoder_num_layers: Number of decoder lstm layers 21 | """ 22 | def __init__(self, 23 | num_subsequences=16, 24 | max_sequence_length=256, 25 | sequence_length=16, 26 | encoder_input_size=61, 27 | decoder_input_size=61, 28 | encoder_hidden_size=2048, 29 | decoder_hidden_size=1024, 30 | latent_dim=512, 31 | encoder_num_layers=2, 32 | decoder_num_layers=2): 33 | super(MusicLSTMVAE, self).__init__() 34 | self.input_size = decoder_input_size 35 | self.encoder = BiLSTMEncoder(encoder_input_size, 36 | encoder_hidden_size, 37 | latent_dim, 38 | encoder_num_layers) 39 | self.z_embedding = nn.Sequential( 40 | nn.Linear(in_features=latent_dim, out_features=latent_dim), 41 | nn.Tanh() 42 | ) 43 | self.decoder = HierarchicalLSTMDecoder(num_embeddings=num_subsequences, 44 | input_size=decoder_input_size, 45 | hidden_size=decoder_hidden_size, 46 | latent_size=latent_dim, 47 | num_layers=decoder_num_layers, 48 | max_seq_length=max_sequence_length, 49 | seq_length=sequence_length) 50 | 51 | def forward(self, x, use_teacher_forcing): 52 | """ 53 | Input 54 | - x: input sequence x = x_1, ... ,x_T 55 | """ 56 | batch_size = x.size(1) 57 | h_enc, c_enc = self.encoder.init_hidden(batch_size) 58 | mu, sigma = self.encoder(x, h_enc, c_enc) 59 | 60 | # Sample latent variable 61 | with torch.no_grad(): 62 | epsilon = torch.randn_like(mu, device=device) 63 | 64 | z = self.z_embedding(mu + sigma*epsilon) 65 | h_dec, c_dec = self.decoder.init_hidden(batch_size) 66 | out = self.decoder(x, z, h_dec, c_dec, use_teacher_forcing) 67 | return out, mu, sigma, z 68 | 69 | def reconstruct(self, x, temperature): 70 | batch_size = x.size(1) 71 | h_enc, c_enc = self.encoder.init_hidden(batch_size) 72 | mu, sigma = self.encoder(x, h_enc, c_enc) 73 | with torch.no_grad(): 74 | epsilon = torch.randn_like(mu, device=device) 75 | z = self.z_embedding(mu + sigma*epsilon) 76 | h_dec, c_dec = self.decoder.init_hidden(batch_size) 77 | out = self.decoder.reconstruct(z, h_dec, c_dec, temperature) 78 | return out 79 | 80 | class MusicGRUVAE(nn.Module): 81 | """ 82 | Inputs 83 | - encoder_input_size: Size of input representation (i.e 60 music pitches + 1 silence) 84 | - num_subsequences: Number of subsequences to partition input (corresponds to U) 85 | - sequence_length: Length of sequences (T=32 for 2-bar and T=256 for 16-bar data) 86 | - encoder_hidden_size: dimension of encoder hidden state 87 | - decoder_hidden_size: dimension of decoder hidden state 88 | - latent_dim: dimension of latent variable z 89 | - encoder_num_layers: Number of encoder lstm layers 90 | - decoder_num_layers: Numnber of decoder lstm layers 91 | """ 92 | def __init__(self, 93 | num_subsequences=16, 94 | max_sequence_length=256, 95 | sequence_length=16, 96 | encoder_input_size=61, 97 | decoder_input_size=61, 98 | encoder_hidden_size=2048, 99 | decoder_hidden_size=1024, 100 | latent_dim=512, 101 | encoder_num_layers=2, 102 | decoder_num_layers=2): 103 | super(MusicGRUVAE, self).__init__() 104 | self.input_size = decoder_input_size 105 | self.max_sequence_length = max_sequence_length 106 | self.encoder = BiGRUEncoder(encoder_input_size, 107 | 
encoder_hidden_size, 108 | latent_dim, 109 | encoder_num_layers) 110 | self.z_embedding = nn.Sequential( 111 | nn.Linear(in_features=latent_dim, out_features=latent_dim), 112 | nn.Tanh() 113 | ) 114 | self.decoder = HierarchicalGRUDecoder(num_embeddings=num_subsequences, 115 | input_size=decoder_input_size, 116 | hidden_size=decoder_hidden_size, 117 | latent_size=latent_dim, 118 | num_layers=decoder_num_layers, 119 | max_seq_length=max_sequence_length, 120 | seq_length=sequence_length) 121 | 122 | def forward(self, x, use_teacher_forcing): 123 | """ 124 | Input 125 | - x: input sequence x = x_1, ... ,x_T 126 | """ 127 | batch_size = x.size(1) 128 | h_enc = self.encoder.init_hidden(batch_size) 129 | mu, sigma = self.encoder(x, h_enc) 130 | 131 | # Sample latent variable 132 | with torch.no_grad(): 133 | epsilon = torch.randn_like(mu, device=device) 134 | 135 | z = self.z_embedding(mu + sigma*epsilon) 136 | h_dec = self.decoder.init_hidden(batch_size) 137 | out = self.decoder(x, z, h_dec, use_teacher_forcing) 138 | return out, mu, sigma, z 139 | 140 | def reconstruct(self, x, temperature): 141 | batch_size = x.size(1) 142 | h_enc = self.encoder.init_hidden(batch_size) 143 | mu, sigma = self.encoder(x, h_enc) 144 | with torch.no_grad(): 145 | epsilon = torch.randn_like(mu, device=device) 146 | z = self.z_embedding(mu + sigma*epsilon) 147 | h_dec = self.decoder.init_hidden(batch_size) 148 | out = self.decoder.reconstruct(z, h_dec, temperature) 149 | return out -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | from src.checkpoint import Checkpoint 2 | 3 | from src.loss import ELBO 4 | 5 | import torch 6 | import torch.optim as optim 7 | from torch.optim.lr_scheduler import ExponentialLR, LambdaLR 8 | 9 | from math import exp 10 | import numpy as np 11 | 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | def decay(x): 16 | return 0.01 + (0.99)*(0.9999)**x 17 | 18 | class Trainer: 19 | def __init__(self, 20 | learning_rate=1e-3, 21 | KL_rate=0.9999, 22 | free_bits=256, 23 | sampling_rate=2000, 24 | batch_size=512, 25 | print_every=1000, 26 | checkpoint_every=10000, 27 | checkpoint_dir='checkpoint', 28 | output_dir='outputs'): 29 | self.learning_rate = learning_rate 30 | self.KL_rate = KL_rate 31 | self.free_bits = free_bits 32 | self.optimizer=None 33 | self.scheduler=None 34 | self.sampling_rate = sampling_rate 35 | self.batch_size = batch_size 36 | self.print_every = print_every 37 | self.checkpoint_every = checkpoint_every 38 | self.output_dir = output_dir 39 | 40 | def inverse_sigmoid(self,step): 41 | """ 42 | Compute teacher forcing probability with inverse sigmoid 43 | """ 44 | k = self.sampling_rate 45 | if k == None: 46 | return 0 47 | if k == 1.0: 48 | return 1 49 | return k/(k + exp(step/k)) 50 | 51 | def KL_annealing(self, step, start, end): 52 | return end + (start - end)*(self.KL_rate)**step 53 | 54 | def compute_loss(self, step, model, batch, use_teacher_forcing=True): 55 | batch.to(device) 56 | pred, mu, sigma, z = model(batch, use_teacher_forcing) 57 | elbo, kl = ELBO(pred, batch, mu, sigma, self.free_bits) 58 | kl_weight = self.KL_annealing(step, 0, 0.2) 59 | return kl_weight*elbo, kl 60 | 61 | def train_batch(self, iter, model, batch): 62 | self.optimizer.zero_grad() 63 | use_teacher_forcing = self.inverse_sigmoid(iter) 64 | elbo, kl = self.compute_loss(iter, model, batch, use_teacher_forcing) 65 | 
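        # Note: inverse_sigmoid returns a probability, so use_teacher_forcing is
        # always truthy when passed straight to the model; the LR scheduler is
        # also stepped once per batch and before optimizer.step(), an ordering
        # that newer PyTorch versions flag with a warning.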
self.scheduler.step() 66 | elbo.backward() 67 | self.optimizer.step() 68 | return elbo.item(), kl.item() 69 | 70 | def train_epochs(self, model, start_epoch, iter, end_epoch, train_data, val_data=None): 71 | train_loss, val_loss = [], [] 72 | train_kl, val_kl = [], [] 73 | for epoch in range(start_epoch, end_epoch): 74 | batch_loss, batch_kl = [], [] 75 | model.train() 76 | 77 | for idx, batch in enumerate(train_data): 78 | batch = batch.transpose(0, 1).squeeze() 79 | batch.to(device) 80 | elbo, kl = self.train_batch(iter, model, batch) 81 | batch_loss.append(elbo) 82 | batch_kl.append(kl) 83 | iter += 1 84 | 85 | if iter%self.print_every == 0: 86 | loss_avg = torch.mean(torch.tensor(batch_loss)) 87 | div = torch.mean(torch.tensor(batch_kl)) 88 | print('Epoch: %d, iteration: %d, Average loss: %.4f, KL Divergence: %.4f' % (epoch, iter, loss_avg, div)) 89 | 90 | if iter%self.checkpoint_every == 0: 91 | self.save_checkpoint(model, epoch, iter) 92 | 93 | train_loss.append(torch.mean(torch.tensor(batch_loss))) 94 | train_kl.append(torch.mean(torch.tensor(batch_kl))) 95 | 96 | self.save_checkpoint(model, epoch, iter) 97 | 98 | if val_data is not None: 99 | batch_loss, batch_kl = [], [] 100 | with torch.no_grad(): 101 | model.eval() 102 | for idx, batch in enumerate(val_data): 103 | batch.to(device) 104 | batch = batch.transpose(0, 1).squeeze() 105 | elbo, kl = self.compute_loss(iter, model, batch, False) 106 | batch_loss.append(elbo) 107 | batch_kl.append(kl) 108 | val_loss.append(torch.mean(torch.tensor(batch_loss))) 109 | val_kl.append(torch.mean(torch.tensor(batch_kl))) 110 | loss_avg = torch.mean(torch.tensor(val_loss)) 111 | div = torch.mean(torch.tensor(val_kl)) 112 | print('Validation') 113 | print('Epoch: %d, iteration: %d, Average loss: %.4f, KL Divergence: %.4f' % (epoch, iter, loss_avg, div)) 114 | 115 | torch.save(open('outputs/train_loss_musicvae_batch', 'wb'), torch.tensor(train_loss)) 116 | torch.save(open('outputs/val_loss_musicvae_batch', 'wb'), torch.tensor(val_loss)) 117 | torch.save(open('outputs/train_kl_musicvae_batch', 'wb'), torch.tensor(train_kl)) 118 | torch.save(open('outputs/val_kl_musicvae_batch', 'wb'), torch.tensor(val_kl)) 119 | 120 | def save_checkpoint(self, model, epoch, iter): 121 | print('Saving checkpoint') 122 | Checkpoint(model=model, 123 | epoch=epoch, 124 | step=iter, 125 | optimizer=self.optimizer, 126 | scheduler=self.scheduler, 127 | samp_rate=self.sampling_rate, 128 | KL_rate=self.KL_rate, 129 | free_bits=self.free_bits).save(self.output_dir) 130 | print('Checkpoint Successful') 131 | 132 | def load_checkpoint(self): 133 | latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.output_dir) 134 | resume_checkpoint = Checkpoint.load(latest_checkpoint_path) 135 | model = resume_checkpoint.model 136 | epoch = resume_checkpoint.epoch 137 | iter = resume_checkpoint.step 138 | self.scheduler = resume_checkpoint.scheduler 139 | self.optimizer = resume_checkpoint.optimizer 140 | self.sampling_rate = resume_checkpoint.samp_rate 141 | self.KL_rate = resume_checkpoint.KL_rate 142 | self.free_bits = resume_checkpoint.free_bits 143 | return model, epoch, iter 144 | 145 | def train(self, model, train_data, optimizer, epochs, resume=False, val_data=None): 146 | if resume: 147 | model, epoch, iter = self.load_checkpoint() 148 | else: 149 | if optimizer is None: 150 | self.optimizer = torch.optim.Adam(model.parameters(), self.learning_rate) 151 | self.scheduler = LambdaLR(self.optimizer, decay) 152 | 153 | epoch = 1 154 | iter = 0 155 | print(model) 156 | 
print(self.optimizer) 157 | print(self.scheduler) 158 | print('Starting epoch %d' % epoch) 159 | 160 | model.to(device) 161 | self.train_epochs(model, epoch, iter, epoch+epochs, train_data, val_data) 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions.categorical import Categorical 4 | 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | class BiLSTMEncoder(nn.Module): 9 | """ 10 | Bi-directional LSTM encoder from MusicVAE 11 | Inputs: 12 | - input_size: Dimension of one-hot representation of input notes 13 | - hidden_size: hidden size of bidirectional lstm 14 | - num_layers: Number of layers for bidirectional lstm 15 | """ 16 | def __init__(self, 17 | input_size=61, 18 | hidden_size=2048, 19 | latent_size=512, 20 | num_layers=2): 21 | super(BiLSTMEncoder, self).__init__() 22 | self.input_size = input_size 23 | self.hidden_size = hidden_size 24 | self.latent_size = latent_size 25 | self.num_layers = num_layers 26 | 27 | self.bilstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=True) 28 | self.mu = nn.Linear(in_features=2 * hidden_size, out_features=latent_size) 29 | self.sigma = nn.Linear(in_features=2 * hidden_size, out_features=latent_size) 30 | self.softplus = nn.Softplus() 31 | 32 | def forward(self, input, h0, c0): 33 | batch_size = input.size(1) 34 | _, (h_n, c_n) = self.bilstm(input, (h0, c0)) 35 | h_n = h_n.view(self.num_layers, 2, batch_size, -1)[-1].view(batch_size, -1) 36 | mu = self.mu(h_n) 37 | sigma = self.softplus(self.sigma(h_n)) 38 | return mu, sigma 39 | 40 | def init_hidden(self, batch_size=1): 41 | # Bidirectional lstm so num_layers*2 42 | return (torch.zeros(self.num_layers * 2, batch_size, self.hidden_size, dtype=torch.float, device=device), 43 | torch.zeros(self.num_layers * 2, batch_size, self.hidden_size, dtype=torch.float, device=device)) 44 | 45 | 46 | class HierarchicalLSTMDecoder(nn.Module): 47 | """ 48 | Hierarchical decoder from MusicVAE 49 | """ 50 | 51 | def __init__(self, 52 | num_embeddings, 53 | input_size=61, 54 | hidden_size=1024, 55 | latent_size=512, 56 | num_layers=2, 57 | max_seq_length=256, 58 | seq_length=16): 59 | super(HierarchicalLSTMDecoder, self).__init__() 60 | self.input_size = input_size 61 | self.hidden_size = hidden_size 62 | self.latent_size = latent_size 63 | self.num_embeddings = num_embeddings 64 | self.max_seq_length = max_seq_length 65 | self.seq_length = seq_length 66 | self.num_layers = num_layers 67 | 68 | self.tanh = nn.Tanh() 69 | self.conductor = nn.LSTM(input_size=latent_size, hidden_size=hidden_size, num_layers=num_layers) 70 | self.conductor_embeddings = nn.Sequential( 71 | nn.Linear(in_features=hidden_size, out_features=latent_size), 72 | nn.Tanh()) 73 | self.lstm = nn.LSTM(input_size=input_size + latent_size, hidden_size=hidden_size, num_layers=num_layers) 74 | self.out = nn.Sequential( 75 | nn.Linear(in_features=hidden_size, out_features=input_size), 76 | nn.Softmax(dim=2) 77 | ) 78 | 79 | def forward(self, target, latent, h0, c0, use_teacher_forcing=True, temperature=1.0): 80 | batch_size = target.size(1) 81 | 82 | out = torch.zeros(self.max_seq_length, batch_size, self.input_size, dtype=torch.float, device=device) 83 | # Initialie start note 84 | prev_note = torch.zeros(1, 
batch_size, self.input_size, dtype=torch.float, device=device) 85 | 86 | # Conductor produces an embedding vector for each subsequence 87 | for embedding_idx in range(self.num_embeddings): 88 | embedding, (h0, c0) = self.conductor(latent.unsqueeze(0), (h0, c0)) 89 | embedding = self.conductor_embeddings(embedding) 90 | 91 | # Initialize lower decoder hidden state 92 | h0_dec = (torch.randn(self.num_layers, batch_size, self.hidden_size, dtype=torch.float, device=device), 93 | torch.randn(self.num_layers, batch_size, self.hidden_size, dtype=torch.float, device=device)) 94 | 95 | # Decoder produces sequence of distributions over output tokens 96 | # for each subsequence where at each step the current 97 | # conductor embedding is concatenated with the previous output 98 | # token to be used as input 99 | if use_teacher_forcing: 100 | embedding = embedding.expand(self.seq_length, batch_size, embedding.size(2)) 101 | idx = range(embedding_idx * self.seq_length, embedding_idx * self.seq_length + self.seq_length) 102 | e = torch.cat((target[idx, :, :], embedding), dim=2).to(device) 103 | prev_note, h0_dec = self.lstm(e, h0_dec) 104 | prev_note = self.out(prev_note) 105 | out[idx, :, :] = prev_note 106 | prev_note = prev_note[-1, ::].unsqueeze(0) 107 | else: 108 | for note_idx in range(self.seq_length): 109 | e = torch.cat((prev_note, embedding), -1) 110 | prev_note, h0_dec = self.lstm(e, h0_dec) 111 | prev_note = self.out(prev_note) 112 | 113 | idx = embedding_idx * self.seq_length + note_idx 114 | out[idx, :, :] = prev_note.squeeze() 115 | return out 116 | 117 | def reconstruct(self, latent, h0, c0, temperature): 118 | """ 119 | Reconstruct the actual midi using categorical distribution 120 | """ 121 | one_hot = torch.eye(self.input_size).to(device) 122 | batch_size = 1 123 | out = torch.zeros(self.max_seq_length, batch_size, self.input_size, dtype=torch.float, device=device) 124 | prev_note = torch.zeros(1, batch_size, self.input_size, dtype=torch.float, device=device) 125 | for embedding_idx in range(self.num_embeddings): 126 | embedding, (h0, c0) = self.conductor(latent.unsqueeze(0), (h0, c0)) 127 | embedding = self.conductor_embeddings(embedding) 128 | h0_dec = (torch.randn(self.num_layers, batch_size, self.hidden_size, dtype=torch.float, device=device), 129 | torch.randn(self.num_layers, batch_size, self.hidden_size, dtype=torch.float, device=device)) 130 | for note_idx in range(self.seq_length): 131 | e = torch.cat((prev_note, embedding), -1) 132 | prev_note, h0_dec = self.lstm(e, h0_dec) 133 | prev_note = self.out(prev_note) 134 | prev_note = Categorical(prev_note / temperature).sample() 135 | prev_note = self.one_hot(prev_note) 136 | out[idx, :, :] = prev_note.squeeze() 137 | return out 138 | 139 | 140 | def init_hidden(self, batch_size=1): 141 | return (torch.zeros(self.num_layers, batch_size, self.hidden_size, dtype=torch.float, device=device), 142 | torch.zeros(self.num_layers, batch_size, self.hidden_size, dtype=torch.float, device=device)) 143 | 144 | 145 | class BiGRUEncoder(nn.Module): 146 | """ 147 | Bi-directional GRU encoder from MusicVAE 148 | Inputs: 149 | - input_size: 150 | - hidden_size: hidden size of bidirectional gru 151 | - num_layers: Number of layers for bidirectional gru 152 | """ 153 | 154 | def __init__(self, 155 | input_size=61, 156 | hidden_size=2048, 157 | latent_size=512, 158 | num_layers=2): 159 | super(BiGRUEncoder, self).__init__() 160 | self.input_size = input_size 161 | self.hidden_size = hidden_size 162 | self.latent_size = latent_size 163 | 
self.num_layers = num_layers 164 | 165 | self.bigru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=True) 166 | self.mu = nn.Linear(in_features=2 * hidden_size, out_features=latent_size) 167 | self.sigma = nn.Linear(in_features=2 * hidden_size, out_features=latent_size) 168 | self.softplus = nn.Softplus() 169 | 170 | def forward(self, input, h0): 171 | batch_size = input.size(1) 172 | _, h_n = self.bigru(input, h0) 173 | h_n = h_n.view(self.num_layers, 2, batch_size, -1)[-1].view(batch_size, -1) 174 | mu = self.mu(h_n) 175 | sigma = self.softplus(self.sigma(h_n)) 176 | return mu, sigma 177 | 178 | def init_hidden(self, batch_size=1): 179 | # Bidirectional gru so num_layers*2 180 | return torch.zeros(self.num_layers * 2, batch_size, self.hidden_size, dtype=torch.float, device=device) 181 | 182 | 183 | class HierarchicalGRUDecoder(nn.Module): 184 | """ 185 | Hierarchical decoder from MusicVAE 186 | """ 187 | 188 | def __init__(self, 189 | num_embeddings, 190 | input_size=61, 191 | hidden_size=1024, 192 | latent_size=512, 193 | num_layers=2, 194 | max_seq_length=256, 195 | seq_length=16): 196 | super(HierarchicalGRUDecoder, self).__init__() 197 | self.input_size = input_size 198 | self.hidden_size = hidden_size 199 | self.latent_size = latent_size 200 | self.num_embeddings = num_embeddings 201 | self.max_seq_length = max_seq_length 202 | self.seq_length = seq_length 203 | self.num_layers = num_layers 204 | 205 | self.tanh = nn.Tanh() 206 | self.conductor = nn.GRU(input_size=latent_size, hidden_size=hidden_size, num_layers=num_layers) 207 | self.conductor_embeddings = nn.Sequential( 208 | nn.Linear(in_features=hidden_size, out_features=latent_size), 209 | nn.Tanh()) 210 | self.gru = nn.GRU(input_size=input_size + latent_size, hidden_size=hidden_size, num_layers=num_layers) 211 | self.out = nn.Sequential( 212 | nn.Linear(in_features=hidden_size, out_features=input_size), 213 | nn.Softmax(dim=2) 214 | ) 215 | 216 | def forward(self, target, latent, h0, use_teacher_forcing=True, temperature=1.0): 217 | batch_size = target.size(1) 218 | 219 | out = torch.zeros(self.max_seq_length, batch_size, self.input_size, dtype=torch.float, device=device) 220 | # Initialize start note 221 | prev_note = torch.zeros(1, batch_size, self.input_size, dtype=torch.float, device=device) 222 | 223 | # Conductor produces an embedding vector for each subsequence, where each 224 | # subsequence consists of a bar of 16th notes 225 | for embedding_idx in range(self.num_embeddings): 226 | embedding, h0 = self.conductor(latent.unsqueeze(0), h0) 227 | embedding = self.conductor_embeddings(embedding) 228 | 229 | # Initialize lower decoder hidden state 230 | h0_dec = torch.randn(self.num_layers, batch_size, self.hidden_size, dtype=torch.float, device=device) 231 | 232 | # Decoder produces sequence of distributions over output tokens 233 | # for each subsequence where at each step the current 234 | # conductor embedding is concatenated with the previous output 235 | # token to be used as input 236 | if use_teacher_forcing: 237 | embedding = embedding.expand(self.seq_length, batch_size, embedding.size(2)).to(device) 238 | idx = range(embedding_idx * self.seq_length, embedding_idx * self.seq_length + self.seq_length) 239 | e = torch.cat((target[idx, :, :], embedding), dim=2) 240 | prev_note, h0_dec = self.gru(e, h0_dec) 241 | prev_note = self.out(prev_note) 242 | out[idx, :, :] = prev_note 243 | prev_note = prev_note[-1, :, :].unsqueeze(0) 244 | else: 245 | for note_idx in
range(self.seq_length): 246 | e = torch.cat((prev_note, embedding), -1) 247 | prev_note, h0_dec = self.gru(e, h0_dec) 248 | prev_note = self.out(prev_note) 249 | 250 | idx = embedding_idx * self.seq_length + note_idx 251 | out[idx, :, :] = prev_note.squeeze() 252 | return out 253 | 254 | def reconstruct(self, latent, h0, temperature): 255 | """ 256 | Reconstruct the actual midi using categorical distribution 257 | """ 258 | one_hot = torch.eye(self.input_size).to(device) 259 | batch_size = h0.size(1) 260 | out = torch.zeros(self.max_seq_length, batch_size, self.input_size, dtype=torch.float, device=device) 261 | prev_note = torch.zeros(1, batch_size, self.input_size, dtype=torch.float, device=device) 262 | for embedding_idx in range(self.num_embeddings): 263 | embedding, h0 = self.conductor(latent.unsqueeze(0), h0) 264 | embedding = self.conductor_embeddings(embedding) 265 | h0_dec = torch.randn(self.num_layers, batch_size, self.hidden_size, dtype=torch.float, device=device) 266 | for note_idx in range(self.seq_length): 267 | e = torch.cat((prev_note, embedding), -1) 268 | prev_note, h0_dec = self.gru(e, h0_dec) 269 | prev_note = self.out(prev_note) 270 | prev_note = Categorical(prev_note / temperature).sample() 271 | prev_note = one_hot[prev_note] 272 | out[embedding_idx * self.seq_length + note_idx, :, :] = prev_note.squeeze() 273 | return out 274 | 275 | def init_hidden(self, batch_size=1): 276 | return torch.zeros(self.num_layers, batch_size, self.hidden_size, dtype=torch.float, device=device) 277 | -------------------------------------------------------------------------------- /src/preprocess.py: -------------------------------------------------------------------------------- 1 | import pretty_midi as pretty_midi 2 | import src.midi_functions as mf 3 | import os 4 | import sys 5 | import numpy as np 6 | import pickle 7 | import math 8 | from sklearn.model_selection import train_test_split 9 | import time 10 | 11 | t = str(int(round(time.time()))) 12 | 13 | class MidiPreprocessor: 14 | """ 15 | Modifying MidiVAE (Brunner et al.) preprocessing to remove global variables 16 | 17 | Class to preprocess midi files 18 | - low_crop: Low note cutoff 19 | - high_crop: High note cutoff 20 | - num_notes: Number of midi notes represented 21 | - smallest_note: Smallest note representation (i.e.
16th note) 22 | - max_velocity: Midi velocity is represented in range (0, 127) 23 | """ 24 | def __init__(self, 25 | classes, 26 | pickle_store_folder, 27 | include_unknown=False, 28 | only_unknown=False, 29 | low_crop=24, 30 | high_crop=84, 31 | num_notes=128, 32 | smallest_note=16, 33 | max_velocity=127, 34 | include_only_monophonic_instruments=False, 35 | max_voices_per_track=1, 36 | max_voices=4, 37 | include_silent_note=True, 38 | velocity_threshold=0.5, 39 | instrument_attach_method='1hot-category', 40 | attach_instruments=False, 41 | input_length=16, 42 | output_length=16, 43 | test_fraction=0.1): 44 | self.classes = classes 45 | self.pickle_store_folder = pickle_store_folder 46 | self.include_unknown = include_unknown 47 | self.only_unknown=only_unknown 48 | self.low_crop = low_crop 49 | self.high_crop = high_crop 50 | self.num_notes = num_notes 51 | self.smallest_note = smallest_note 52 | self.max_velocity = max_velocity 53 | self.note_columns = [pretty_midi.note_number_to_name(n) for n in range(0, num_notes)] 54 | self.include_only_monophonic_instruments = include_only_monophonic_instruments 55 | self.max_voices_per_track = max_voices_per_track 56 | self.max_voices = max_voices 57 | self.velocity_threshold = velocity_threshold 58 | self.instrument_attach_method = instrument_attach_method 59 | self.attach_instruments = attach_instruments 60 | self.input_length = input_length 61 | self.output_length = output_length 62 | self.test_fraction = test_fraction 63 | 64 | if include_unknown: 65 | self.num_classes = len(classes) + 1 66 | else: 67 | self.num_classes = len(classes) 68 | 69 | self.include_silent_note = include_silent_note 70 | if include_silent_note: 71 | self.silent_dim = 1 72 | else: 73 | self.silent_dim = 0 74 | 75 | if instrument_attach_method == '1hot-category': 76 | self.instrument_dim = 16 77 | elif instrument_attach_method == 'khot-category': 78 | self.instrument_dim = 4 79 | elif instrument_attach_method == '1hot-instrument': 80 | self.instrument_dim = 128 81 | elif instrument_attach_method == 'khot-instrument': 82 | self.instrument_dim = 7 83 | 84 | def load_rolls(self, path, name, save_preprocessed_midi): 85 | 86 | #try loading the midi file 87 | #if it fails, return all None objects 88 | try: 89 | mid = pretty_midi.PrettyMIDI(path + name) 90 | except (ValueError, EOFError, IndexError, OSError, KeyError, ZeroDivisionError, AttributeError) as e: 91 | exception_str = 'Unexpected error in ' + name + ':\n', e, sys.exc_info()[0] 92 | print(exception_str) 93 | return None, None, None, None, None, None 94 | 95 | #determine start and end of the song 96 | #if there are tempo changes in the song, only take the longest part where the tempo is steady 97 | #this cuts of silent starts and extended ends 98 | #this also makes sure that the start of the bars are aligned through the song 99 | tempo_change_times, tempo_change_bpm = mid.get_tempo_changes() 100 | song_start = 0 101 | song_end = mid.get_end_time() 102 | #there will always be at least one tempo change to set the first tempo 103 | #but if there are more than one tempo changes, that means that the tempos are changed 104 | if len(tempo_change_times) > 1: 105 | longest_part = 0 106 | longest_part_start_time = 0 107 | longest_part_end_time = song_end 108 | longest_part_tempo = 0 109 | for i, tempo_change_time in enumerate(tempo_change_times): 110 | if i == len(tempo_change_times) - 1: 111 | end_time = song_end 112 | else: 113 | end_time = tempo_change_times[i+1] 114 | current_part_length = end_time - tempo_change_time 115 | if 
current_part_length > longest_part: 116 | longest_part = current_part_length 117 | longest_part_start_time = tempo_change_time 118 | longest_part_end_time = end_time 119 | longest_part_tempo = tempo_change_bpm[i] 120 | song_start = longest_part_start_time 121 | song_end = longest_part_end_time 122 | tempo = longest_part_tempo 123 | else: 124 | tempo = tempo_change_bpm[0] 125 | 126 | #cut off the notes that are not in the longest part where the tempo is steady 127 | for instrument in mid.instruments: 128 | new_notes = [] #list for the notes that survive the cutting 129 | for note in instrument.notes: 130 | #check if it is in the given range of the longest part where the tempo is steady 131 | if note.start >= song_start and note.end <= song_end: 132 | #adjust to new times 133 | note.start -= song_start 134 | note.end -= song_start 135 | new_notes.append(note) 136 | instrument.notes = new_notes 137 | 138 | #(descending) order the piano_rolls according to the number of notes per track 139 | number_of_notes = [] 140 | piano_rolls = [i.get_piano_roll(fs=100) for i in mid.instruments] 141 | for piano_roll in piano_rolls: 142 | number_of_notes.append(np.count_nonzero(piano_roll)) 143 | permutation = np.argsort(number_of_notes)[::-1] 144 | mid.instruments = [mid.instruments[i] for i in permutation] 145 | 146 | quarter_note_length = 1. / (tempo/60.) 147 | #fs is is the frequency for the song at what rate notes are picked 148 | #the song will by sampled by (0, song_length_in_seconds, 1./fs) 149 | #fs should be the inverse of the length of the note, that is to be sampled 150 | #the value should be in beats per seconds, where beats can be quarter notes or whatever... 151 | fs = 1. / (quarter_note_length * 4. / self.smallest_note) 152 | 153 | total_ticks = math.ceil(song_end * fs) 154 | 155 | #assemble piano_rolls, velocity_rolls and held_note_rolls 156 | piano_rolls = [] 157 | velocity_rolls = [] 158 | held_note_rolls = [] 159 | max_concurrent_notes_per_track_list = [] 160 | for instrument in mid.instruments: 161 | piano_roll = np.zeros((total_ticks, 128)) 162 | 163 | #counts how many notes are played at maximum for this instrument at any given tick 164 | #this is used to determine the depth of the velocity_roll and held_note_roll 165 | concurrent_notes_count = np.zeros((total_ticks,)) 166 | 167 | #keys is a tuple of the form (tick_start_of_the_note, pitch) 168 | #this uniquely identifies a note since there can be no two notes 169 | # playing on the same pitch for the same instrument 170 | note_to_velocity_dict = dict() 171 | 172 | #keys is a tuple of the form (tick_start_of_the_note, pitch) 173 | #this uniquely identifies a note since there can be no two notes playing 174 | # on the same pitch for the same instrument 175 | note_to_duration_dict = dict() 176 | 177 | for note in instrument.notes: 178 | note_tick_start = note.start * fs 179 | note_tick_end = note.end * fs 180 | absolute_start = int(round(note_tick_start)) 181 | absolute_end = int(round(note_tick_end)) 182 | decimal = note_tick_start - absolute_start 183 | #see if it starts at a tick or not 184 | #if it doesn't start at a tick (decimal > 10e-3) but is longer than one tick, include it anyways 185 | if decimal < 10e-3 or absolute_end-absolute_start >= 1: 186 | piano_roll[absolute_start:absolute_end, note.pitch] = 1 187 | concurrent_notes_count[absolute_start:absolute_end] += 1 188 | 189 | #save information of velocity and duration for later use 190 | #this can not be done right now because there might be no ordering in the notes 191 | 
note_to_velocity_dict[(absolute_start, note.pitch)] = note.velocity 192 | note_to_duration_dict[(absolute_start, note.pitch)] = absolute_end - absolute_start 193 | 194 | max_concurrent_notes = int(np.max(concurrent_notes_count)) 195 | max_concurrent_notes_per_track_list.append(max_concurrent_notes) 196 | 197 | velocity_roll = np.zeros((total_ticks, max_concurrent_notes)) 198 | held_note_roll = np.zeros((total_ticks, max_concurrent_notes)) 199 | 200 | for step, note_vector in enumerate(piano_roll): 201 | pitches = list(note_vector.nonzero()[0]) 202 | sorted_pitches_from_highest_to_lowest = sorted(pitches)[::-1] 203 | for voice_number, pitch in enumerate(sorted_pitches_from_highest_to_lowest): 204 | if (step, pitch) in note_to_velocity_dict.keys(): 205 | velocity_roll[step, voice_number] = note_to_velocity_dict[(step, pitch)] 206 | if (step, pitch) not in note_to_duration_dict.keys(): 207 | #if the note is in the dictionary, it means that it is the start of the note 208 | #if its not the start of a note, it means it is held 209 | held_note_roll[step, voice_number] = 1 210 | 211 | piano_rolls.append(piano_roll) 212 | velocity_rolls.append(velocity_roll) 213 | held_note_rolls.append(held_note_roll) 214 | 215 | #get the program numbers for each instrument 216 | #program numbers are between 0 and 127 and have a 1:1 mapping to the instruments described in settings file 217 | programs = [i.program for i in mid.instruments] 218 | 219 | #we may want to override the maximal_number_of_voices_per_track 220 | # if the following tracks are all silent it makes no sense to exclude 221 | # voices from the first instrument and then just have a song with 1 voice 222 | override_max_notes_per_track_list = [self.max_voices_per_track 223 | for _ in max_concurrent_notes_per_track_list] 224 | silent_tracks_if_we_dont_override = self.max_voices - \ 225 | sum([min(self.max_voices_per_track, x) if x > 0 else 0 226 | for x in max_concurrent_notes_per_track_list[:self.max_voices]]) 227 | 228 | for voice in range(min(self.max_voices, len(max_concurrent_notes_per_track_list))): 229 | if silent_tracks_if_we_dont_override > 0 and \ 230 | max_concurrent_notes_per_track_list[voice] > self.max_voices: 231 | additional_voices = min(silent_tracks_if_we_dont_override, 232 | max_concurrent_notes_per_track_list[voice] - \ 233 | self.max_voices) 234 | override_max_notes_per_track_list[voice] += additional_voices 235 | silent_tracks_if_we_dont_override -= additional_voices 236 | 237 | #chose the most important piano_rolls 238 | #each of them will be monophonic 239 | chosen_piano_rolls = [] 240 | chosen_velocity_rolls = [] 241 | chosen_held_note_rolls = [] 242 | chosen_programs = [] 243 | max_song_length = 0 244 | 245 | #go through all pianorolls in the descending order of the total notes they have 246 | for batch in zip(piano_rolls, 247 | velocity_rolls, 248 | held_note_rolls, 249 | programs, 250 | max_concurrent_notes_per_track_list, 251 | override_max_notes_per_track_list): 252 | piano_roll = batch[0] 253 | velocity_roll = batch[1] 254 | held_note_roll = batch[2] 255 | program = batch[3] 256 | max_concurrent_notes = batch[4] 257 | override_max_notes_per_track = batch[5] 258 | #see if there is actually a note played in that pianoroll 259 | if max_concurrent_notes > 0: 260 | 261 | #skip if you only want monophonic instruments and there are more than 1 notes played at the same time 262 | if self.include_only_monophonic_instruments: 263 | if max_concurrent_notes > 1: 264 | continue 265 | monophonic_piano_roll = piano_roll 266 | #append 
them to the chosen ones 267 | if len(chosen_piano_rolls) < self.max_voices: 268 | chosen_piano_rolls.append(monophonic_piano_roll) 269 | chosen_velocity_rolls.append(velocity_roll) 270 | chosen_held_note_rolls.append(held_note_roll) 271 | chosen_programs.append(program) 272 | if monophonic_piano_roll.shape[0] > max_song_length: 273 | max_song_length = monophonic_piano_roll.shape[0] 274 | else: 275 | break 276 | 277 | else: 278 | #limit the number of voices per track by the minimum of the actual 279 | # concurrent voices per track or the maximal allowed in the settings file 280 | for voice in range(min(max_concurrent_notes, max(self.max_voices_per_track, 281 | override_max_notes_per_track))): 282 | #Take the highest note for voice 0, second highest for voice 1 and so on... 283 | monophonic_piano_roll = np.zeros(piano_roll.shape) 284 | for step in range(piano_roll.shape[0]): 285 | #sort all the notes from highest to lowest 286 | notes = np.nonzero(piano_roll[step,:])[0][::-1] 287 | if len(notes) > voice: 288 | monophonic_piano_roll[step, notes[voice]] = 1 289 | 290 | #append them to the chosen ones 291 | if len(chosen_piano_rolls) < self.max_voices: 292 | chosen_piano_rolls.append(monophonic_piano_roll) 293 | chosen_velocity_rolls.append(velocity_roll[:, voice]) 294 | chosen_held_note_rolls.append(held_note_roll[:, voice]) 295 | chosen_programs.append(program) 296 | if monophonic_piano_roll.shape[0] > max_song_length: 297 | max_song_length = monophonic_piano_roll.shape[0] 298 | else: 299 | break 300 | if len(chosen_piano_rolls) == self.max_voices: 301 | break 302 | 303 | assert(len(chosen_piano_rolls) == len(chosen_velocity_rolls)) 304 | assert(len(chosen_piano_rolls) == len(chosen_held_note_rolls)) 305 | assert(len(chosen_piano_rolls) == len(chosen_programs)) 306 | 307 | #do the unrolling and prepare for model input 308 | if len(chosen_piano_rolls) > 0: 309 | 310 | song_length = max_song_length * self.max_voices 311 | 312 | #prepare Y 313 | #Y will be the target notes 314 | Y = np.zeros((song_length, chosen_piano_rolls[0].shape[1])) 315 | #unroll the pianoroll into one matrix 316 | for i, piano_roll in enumerate(chosen_piano_rolls): 317 | for step in range(piano_roll.shape[0]): 318 | Y[i + step*self.max_voices,:] += piano_roll[step,:] 319 | #assert that there is always at most one note played 320 | for step in range(Y.shape[0]): 321 | assert(np.sum(Y[step,:]) <= 1) 322 | #cut off pitch values which are very uncommon 323 | #this reduces the feature space significantly 324 | Y = Y[:,self.low_crop:self.high_crop] 325 | #append silent note if desired 326 | #the silent note will always be the last column 327 | if self.include_silent_note: 328 | Y = np.append(Y, np.zeros((Y.shape[0], 1)), axis=1) 329 | for step in range(Y.shape[0]): 330 | if np.sum(Y[step]) == 0: 331 | Y[step, -1] = 1 332 | #assert that there is now a 1 at every step 333 | for step in range(Y.shape[0]): 334 | assert(np.sum(Y[step,:]) == 1) 335 | 336 | #unroll the velocity roll 337 | #V will only have shape (song_length,) and its values will be between 0 and 1 (divided by max_velocity) 338 | V = np.zeros((song_length,)) 339 | for i, velocity_roll in enumerate(chosen_velocity_rolls): 340 | for step in range(velocity_roll.shape[0]): 341 | if velocity_roll[step] > 0: 342 | velocity = self.velocity_threshold + \ 343 | (velocity_roll[step] / self.max_velocity) * (1.0 - self.velocity_threshold) 344 | # a played note is therefore at least velocity_threshold loud 345 | # but this is good, since we can now more clearly distinguish between silent or played notes
346 | assert(velocity <= 1.0) 347 | V[i + step*self.max_voices] = velocity 348 | 349 | 350 | #unroll the held_note_rolls 351 | #D will only have shape (song_length,) and its values will be 0 or 1 (1 if held) 352 | #its name is D for Duration to not have a name clash with the history (H) 353 | D = np.zeros((song_length,)) 354 | for i, held_note_roll in enumerate(chosen_held_note_rolls): 355 | for step in range(held_note_roll.shape[0]): 356 | D[i + step*self.max_voices] = held_note_roll[step] 357 | 358 | instrument_feature_matrix = mf.programs_to_instrument_matrix(chosen_programs, 359 | self.instrument_attach_method, 360 | self.max_voices) 361 | 362 | if self.attach_instruments: 363 | instrument_feature_matrix = np.transpose(np.tile(np.transpose(instrument_feature_matrix), song_length//self.max_voices)) 364 | Y = np.append(Y, instrument_feature_matrix, axis=1) 365 | X = Y 366 | 367 | if save_preprocessed_midi: mf.rolls_to_midi(Y, 368 | chosen_programs, 369 | 'preprocess_midi_data/' + t + '/', 370 | name, 371 | tempo, 372 | self.low_crop, 373 | self.high_crop, 374 | self.num_notes, 375 | self.velocity_threshold, 376 | V, 377 | D) 378 | 379 | 380 | #split the song into chunks of size output_length or input_length 381 | #pad them with silent notes if necessary 382 | if self.input_length > 0: 383 | 384 | #split X 385 | padding_length = self.input_length - (X.shape[0] % self.input_length) 386 | if self.input_length == padding_length: 387 | padding_length = 0 388 | #pad to the right.. 389 | X = np.pad(X, ((0,padding_length),(0, 0)), 'constant', constant_values=(0, 0)) 390 | if self.include_silent_note and padding_length > 0: 391 | X[-padding_length:,-1] = 1 392 | number_of_splits = X.shape[0] // self.input_length 393 | X = np.split(X, number_of_splits) 394 | X = np.asarray(X) 395 | 396 | if self.output_length > 0: 397 | #split Y 398 | padding_length = self.output_length - (Y.shape[0] % self.output_length) 399 | if self.output_length == padding_length: 400 | padding_length = 0 401 | 402 | #pad to the right.. 403 | Y = np.pad(Y, ((0,padding_length),(0, 0)), 'constant', constant_values=(0, 0)) 404 | if self.include_silent_note and padding_length > 0: 405 | Y[-padding_length:,-1] = 1 406 | number_of_splits = Y.shape[0] // self.output_length 407 | Y = np.split(Y, number_of_splits) 408 | Y = np.asarray(Y) 409 | 410 | #split V 411 | #pad to the right with zeros.. 412 | V = np.pad(V, (0,padding_length), 'constant', constant_values=0) 413 | number_of_splits = V.shape[0] // self.output_length 414 | V = np.split(V, number_of_splits) 415 | V = np.asarray(V) 416 | 417 | #split D 418 | #pad to the right with zeros..
419 | D = np.pad(D, (0,padding_length), 'constant', constant_values=0) 420 | number_of_splits = D.shape[0] // self.output_length 421 | D = np.split(D, number_of_splits) 422 | D = np.asarray(D) 423 | 424 | 425 | return X, Y, instrument_feature_matrix, tempo, V, D 426 | else: 427 | return None, None, None, None, None, None 428 | 429 | def import_midi_from_folder(self, 430 | folder, 431 | save_imported_midi_as_pickle, 432 | save_preprocessed_midi): 433 | X_list = [] 434 | Y_list = [] 435 | paths = [] 436 | c_classes = [] 437 | I_list = [] 438 | T_list = [] 439 | V_list = [] 440 | D_list = [] 441 | no_imported = 0 442 | for path, subdirs, files in os.walk(folder): 443 | for name in files: 444 | _path = path.replace('\\', '/') + '/' 445 | _name = name.replace('\\', '/') 446 | 447 | if _name.endswith('.mid') or _name.endswith('.midi'): 448 | 449 | shortpath = _path[len(folder):] 450 | found = False 451 | for i, c in enumerate(self.classes): 452 | if c.lower() in shortpath.lower(): 453 | found = True 454 | print("Importing " + c + " song called " + _name) 455 | C = i 456 | if not self.only_unknown: 457 | 458 | X, Y, I, T, V, D = self.load_rolls(_path, _name, save_preprocessed_midi) 459 | 460 | if X is not None and Y is not None: 461 | X_list.append(X) 462 | Y_list.append(Y) 463 | I_list.append(I) 464 | T_list.append(T) 465 | V_list.append(V) 466 | D_list.append(D) 467 | paths.append(_path + _name) 468 | c_classes.append(C) 469 | no_imported += 1 470 | break 471 | if not found: 472 | #assign new category for all the files with no proper title 473 | if self.include_unknown: 474 | C = self.num_classes - 1 475 | print("Importing unknown song ", _name) 476 | 477 | X, Y, I, T, V, D = self.load_rolls(_path, _name, save_preprocessed_midi) 478 | 479 | if X is not None and Y is not None: 480 | X_list.append(X) 481 | Y_list.append(Y) 482 | I_list.append(I) 483 | T_list.append(T) 484 | V_list.append(V) 485 | D_list.append(D) 486 | paths.append(_path + _name) 487 | c_classes.append(C) 488 | no_imported += 1 489 | 490 | 491 | assert(len(X_list) == len(paths)) 492 | assert(len(X_list) == len(c_classes)) 493 | assert(len(X_list) == len(I_list)) 494 | assert(len(X_list) == len(T_list)) 495 | assert(len(X_list) == len(D_list)) 496 | assert(len(X_list) == len(V_list)) 497 | 498 | unique, counts = np.unique(c_classes, return_counts=True) 499 | 500 | data = train_test_split(V_list, 501 | D_list, 502 | T_list, 503 | I_list, 504 | Y_list, 505 | X_list, 506 | c_classes, 507 | paths, 508 | test_size=self.test_fraction, 509 | random_state=42, 510 | stratify=c_classes) 511 | 512 | V_train = data[0] 513 | V_test = data[1] 514 | D_train = data[2] 515 | D_test = data[3] 516 | T_train = data[4] 517 | T_test = data[5] 518 | I_train = data[6] 519 | I_test = data[7] 520 | Y_train = data[8] 521 | Y_test = data[9] 522 | X_train = data[10] 523 | X_test = data[11] 524 | c_train = data[12] 525 | c_test = data[13] 526 | train_paths = data[14] 527 | test_paths = data[15] 528 | 529 | train_set_size = len(X_train) 530 | test_set_size = len(X_test) 531 | 532 | if save_imported_midi_as_pickle: 533 | if not os.path.exists(self.pickle_store_folder): 534 | os.makedirs(self.pickle_store_folder) 535 | 536 | pickle.dump(V_train,open(self.pickle_store_folder+'/V_train.pickle', 'wb')) 537 | pickle.dump(V_test,open(self.pickle_store_folder+'/V_test.pickle', 'wb')) 538 | 539 | pickle.dump(D_train,open(self.pickle_store_folder+'/D_train.pickle', 'wb')) 540 | pickle.dump(D_test,open(self.pickle_store_folder+'/D_test.pickle', 'wb')) 541 | 542 |
pickle.dump(T_train,open(self.pickle_store_folder+'/T_train.pickle', 'wb')) 543 | pickle.dump(T_test,open(self.pickle_store_folder+'/T_test.pickle', 'wb')) 544 | 545 | pickle.dump(I_train,open(self.pickle_store_folder+'/I_train.pickle', 'wb')) 546 | pickle.dump(I_test,open(self.pickle_store_folder+'/I_test.pickle', 'wb')) 547 | 548 | pickle.dump(Y_train,open(self.pickle_store_folder+'/Y_train.pickle', 'wb')) 549 | pickle.dump(Y_test,open(self.pickle_store_folder+'/Y_test.pickle', 'wb')) 550 | 551 | pickle.dump(X_train,open(self.pickle_store_folder+'/X_train.pickle', 'wb')) 552 | pickle.dump(X_test,open(self.pickle_store_folder+'/X_test.pickle', 'wb')) 553 | 554 | pickle.dump(c_train,open(self.pickle_store_folder+'/c_train.pickle', 'wb')) 555 | pickle.dump(c_test,open(self.pickle_store_folder+'/c_test.pickle', 'wb')) 556 | 557 | pickle.dump(train_paths,open(self.pickle_store_folder+'/train_paths.pickle', 'wb')) 558 | pickle.dump(test_paths,open(self.pickle_store_folder+'/test_paths.pickle', 'wb')) 559 | 560 | return data 561 | 562 | --------------------------------------------------------------------------------
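Usage sketch (illustrative only, not a file in the repository): a minimal example of how the BiLSTMEncoder and HierarchicalLSTMDecoder defined in src/layers.py above could be wired together for a single teacher-forced reconstruction pass, assuming the default layer sizes. The batch size and the all-zero input tensor are arbitrary placeholder values, and the actual wiring used in src/model.py and src/trainer.py may differ.

import torch
from src.layers import BiLSTMEncoder, HierarchicalLSTMDecoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size, max_seq_length, note_dim = 8, 256, 61

encoder = BiLSTMEncoder(input_size=note_dim).to(device)
decoder = HierarchicalLSTMDecoder(num_embeddings=16, input_size=note_dim,
                                  max_seq_length=max_seq_length, seq_length=16).to(device)

# Placeholder piano-roll batch with shape (time, batch, notes)
x = torch.zeros(max_seq_length, batch_size, note_dim, device=device)

# Encode to the parameters of the approximate posterior
h0, c0 = encoder.init_hidden(batch_size)
mu, sigma = encoder(x, h0, c0)

# Reparameterisation trick: sample a latent code from N(mu, sigma)
z = mu + sigma * torch.randn_like(sigma)

# Decode with teacher forcing; output has shape (max_seq_length, batch, notes)
h0_dec, c0_dec = decoder.init_hidden(batch_size)
reconstruction = decoder(x, z, h0_dec, c0_dec, use_teacher_forcing=True)
print(reconstruction.shape)  # torch.Size([256, 8, 61])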