├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── dataset └── e_piano.py ├── evaluate.py ├── generate.py ├── graph_results.py ├── model ├── loss.py ├── music_transformer.py ├── positional_encoding.py └── rpr.py ├── preprocess_midi.py ├── third_party └── references.txt ├── train.py └── utilities ├── argument_funcs.py ├── constants.py ├── device.py ├── lr_scheduling.py └── run_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/saved_models/* 2 | **/outputGraphs/* 3 | 4 | **.mid 5 | **.pickle 6 | **.pyc 7 | **/__pycache__/* 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/midi_processor"] 2 | path = third_party/midi_processor 3 | url = https://github.com/jason9693/midi-neural-processor 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Damon Gwinn, Ben Myrick, Ryan Marshall 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Music Transformer 2 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/asigalov61/SuperPiano/blob/master/Super_Piano_3.ipynb) 3 | 4 | Currently supports Pytorch >= 1.2.0 with Python >= 3.6 5 | 6 | There is now a much friendlier [Google Colab version](https://github.com/asigalov61/SuperPiano/blob/master/Super_Piano_3.ipynb) of this project courtesy of [Alex](https://github.com/asigalov61)! 7 | 8 | ## About 9 | This is a reproduction of the MusicTransformer (Huang et al., 2018) for Pytorch. This implementation utilizes the generic Transformer implementation introduced in Pytorch 1.2.0 (https://pytorch.org/docs/stable/nn.html#torch.nn.Transformer). 10 | 11 | ## Generated Music: 12 | Some various music results (midi and mp3) are in the following Google Drive folder: 13 | https://drive.google.com/drive/folders/1qS4z_7WV4LLgXZeVZU9IIjatK7dllKrc?usp=sharing 14 | 15 | See the results section for the model hyperparameters used for generation. 
16 | 17 | Mp3 results were played through a [Kawai MP11SE](https://kawaius.com/product/mp11se/). 18 | In order to play .mid files, we used [Midi Editor](https://www.midieditor.org/), which is free to use and open source. 19 | 20 | ## TODO 21 | * Write own midi pre-processor (sustain pedal errors with jason's) 22 | * Support any midi file beyond Maestro 23 | * Fixed length song generation 24 | * Midi augmentations from paper 25 | * Multi-GPU support 26 | 27 | ## How to run 28 | 1. Download the Maestro dataset (we used v2 but v1 should work as well). You can download the dataset [here](https://magenta.tensorflow.org/datasets/maestro). You only need the MIDI version if you're tight on space. 29 | 30 | 2. Run `git submodule update --init --recursive` to get the MIDI pre-processor provided by jason9693 et al. (https://github.com/jason9693/midi-neural-processor), which is used to convert the MIDI files into discrete ordered message types for training and evaluating. 31 | 32 | 3. Run `preprocess_midi.py -output_dir `, or run with `--help` for details. This will write pre-processed data into a folder split into `train`, `val`, and `test` as per Maestro's recommendation. 33 | 34 | 4. To train a model, run `train.py`. Use `--help` to see the tweakable parameters. See the results section for details on model performance. 35 | 36 | 5. After training models, you can evaluate them with `evaluate.py` and generate a MIDI piece with `generate.py`. To graph and compare results visually, use `graph_results.py`. 37 | 38 | For the most part, you can leave most arguments at their default values. If you are using a different dataset location or similar, you will need to specify that in the arguments. Beyond that, the average user does not have to worry about most of the arguments. 39 | 40 | ### Training 41 | As an example, to train a model using the parameters specified in results: 42 | 43 | ``` 44 | python train.py -output_dir rpr --rpr 45 | ``` 46 | You can additionally specify both a weight and print modulus that determine which epochs to save weights on and which batches to print. The weights that achieved the best loss and the best accuracy (separately) are always stored in results, regardless of the weight modulus input. 47 | 48 | ### Evaluation 49 | You can evaluate a model using: 50 | ``` 51 | python evaluate.py -model_weights rpr/results/best_acc_weights.pickle --rpr 52 | ``` 53 | 54 | Your model's results may vary because a random sequence start position is chosen for each evaluation piece. This may be changed in the future. 55 | 56 | ### Generation 57 | You can generate a piece with a trained model by using: 58 | ``` 59 | python generate.py -output_dir output -model_weights rpr/results/best_acc_weights.pickle --rpr 60 | ``` 61 | 62 | The default generation method samples from a probability distribution using the softmaxed output as the weights. You can also use beam search, but it simply does not work well and is not recommended. 63 | 64 | ## Pytorch Transformer 65 | We used the Transformer class provided since Pytorch 1.2.0 (https://pytorch.org/docs/stable/nn.html#torch.nn.Transformer). The provided Transformer assumes an encoder-decoder architecture. To make it decoder-only like the Music Transformer, we use stacked encoders with a custom dummy decoder. This decoder-only model can be found in model/music_transformer.py. 66 | 67 | At the time this reproduction was produced, there was no Relative Position Representation (RPR) (Shaw et al., 2018) support in the Pytorch Transformer code.
To account for the lack of RPR support, we modified Pytorch 1.2.0 Transformer code to support it. This is based on the Skew method proposed by Huang et al. which is more memory efficient. You can find the modified code in model/rpr.py. This modified Pytorch code will not be kept up to date and will be removed when Pytorch provides RPR support. 68 | 69 | ## Results 70 | We trained a base and RPR model with the following parameters (taken from the paper) for 300 epochs: 71 | * **learn_rate**: None 72 | * **ce_smoothing**: None 73 | * **batch_size**: 2 74 | * **max_sequence**: 2048 75 | * **n_layers**: 6 76 | * **num_heads**: 8 77 | * **d_model**: 512 78 | * **dim_feedforward**: 1024 79 | * **dropout**: 0.1 80 | 81 | The following graphs were generated with the command: 82 | ``` 83 | python graph_results.py -input_dirs base_model/results?rpr_model/results -model_names base?rpr 84 | ``` 85 | 86 | Note, multiple input models are separated with a '?' 87 | 88 | ![Loss Results Graph](https://lh3.googleusercontent.com/u6AL9vIXG7gBeKuLlVJGFeex7-q2NYLbMqYVZGFI3qxWlpa6hAXdVlOsD52i4jKjrVcf4YZCGBaMIVIagcu_z-7Sg5YhDcgsqcs-p4aR48C287c1QraG0tRnHnmimLd8jizk9afW8g=w2400 "Loss Results") 89 | 90 | ![Accuracy Results Graph](https://lh3.googleusercontent.com/ajbanROlOAM9YrNDaHrv1tWM8tZ4nrcrTehwoHsaftnPPZ4xEBLG0RmBa4awYXntBQF0RR_Uh3bsLZv4mdzmZM_TNisMnreKsB2jZIY7iSZjQiL4kRumypymuxIiHu-VdPB0kUkILQ=w2400 "Accuracy Results") 91 | 92 | ![Learn Rate Results Graph](https://lh3.googleusercontent.com/Gz8N8tgHN2qstvdq77GqQQiukWjwBUettMK8IYV0228il5NvRdrnoISS5HTrxd7xVOrRpSzTtLlRppT-UwWJ2ke1XnAsRMbJ0bCElSvCQAA_z08HSZjbJ4wQXBbg4lVzuGdikEN5Ug=w2400 "Learn Rate Results") 93 | 94 | Best loss for *base* model: 1.99 on epoch 250 95 | Best loss for *rpr* model: 1.92 on epoch 216 96 | 97 | ## Discussion 98 | The results were overall close to the results from the paper. Huang et al. reported a loss of around 1.8 for the base and rpr models on Maestro V1. We use Maestro V2 and perform no midi augmentations as they had discussed in their paper. Furthermore, [there are issues with how sustain is handled](https://github.com/jason9693/midi-neural-processor/pull/2) which can be observed by listening to some pre-processed midi files. More refinement with the addition of those augmentations and fixes may yield the loss results in line with the paper. 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /dataset/e_piano.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.data import Dataset 7 | 8 | from utilities.constants import * 9 | from utilities.device import cpu_device 10 | 11 | SEQUENCE_START = 0 12 | 13 | # EPianoDataset 14 | class EPianoDataset(Dataset): 15 | """ 16 | ---------- 17 | Author: Damon Gwinn 18 | ---------- 19 | Pytorch Dataset for the Maestro e-piano dataset (https://magenta.tensorflow.org/datasets/maestro). 20 | Recommended to use with Dataloader (https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) 21 | 22 | Uses all files found in the given root directory of pre-processed (preprocess_midi.py) 23 | Maestro midi files. 
24 | ---------- 25 | """ 26 | 27 | def __init__(self, root, max_seq=2048, random_seq=True): 28 | self.root = root 29 | self.max_seq = max_seq 30 | self.random_seq = random_seq 31 | 32 | fs = [os.path.join(root, f) for f in os.listdir(self.root)] 33 | self.data_files = [f for f in fs if os.path.isfile(f)] 34 | 35 | # __len__ 36 | def __len__(self): 37 | """ 38 | ---------- 39 | Author: Damon Gwinn 40 | ---------- 41 | How many data files exist in the given directory 42 | ---------- 43 | """ 44 | 45 | return len(self.data_files) 46 | 47 | # __getitem__ 48 | def __getitem__(self, idx): 49 | """ 50 | ---------- 51 | Author: Damon Gwinn 52 | ---------- 53 | Gets the indexed midi batch. Gets random sequence or from start depending on random_seq. 54 | 55 | Returns the input and the target. 56 | ---------- 57 | """ 58 | 59 | # All data on cpu to allow for the Dataloader to multithread 60 | i_stream = open(self.data_files[idx], "rb") 61 | # return pickle.load(i_stream), None 62 | raw_mid = torch.tensor(pickle.load(i_stream), dtype=TORCH_LABEL_TYPE, device=cpu_device()) 63 | i_stream.close() 64 | 65 | x, tgt = process_midi(raw_mid, self.max_seq, self.random_seq) 66 | 67 | return x, tgt 68 | 69 | # process_midi 70 | def process_midi(raw_mid, max_seq, random_seq): 71 | """ 72 | ---------- 73 | Author: Damon Gwinn 74 | ---------- 75 | Takes in pre-processed raw midi and returns the input and target. Can use a random sequence or 76 | go from the start based on random_seq. 77 | ---------- 78 | """ 79 | 80 | x = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=cpu_device()) 81 | tgt = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=cpu_device()) 82 | 83 | raw_len = len(raw_mid) 84 | full_seq = max_seq + 1 # Performing seq2seq 85 | 86 | if(raw_len == 0): 87 | return x, tgt 88 | 89 | if(raw_len < full_seq): 90 | x[:raw_len] = raw_mid 91 | tgt[:raw_len-1] = raw_mid[1:] 92 | tgt[raw_len-1] = TOKEN_END 93 | else: 94 | # Randomly selecting a range 95 | if(random_seq): 96 | end_range = raw_len - full_seq 97 | start = random.randint(SEQUENCE_START, end_range) 98 | 99 | # Always taking from the start to as far as we can 100 | else: 101 | start = SEQUENCE_START 102 | 103 | end = start + full_seq 104 | 105 | data = raw_mid[start:end] 106 | 107 | x = data[:max_seq] 108 | tgt = data[1:full_seq] 109 | 110 | 111 | # print("x:",x) 112 | # print("tgt:",tgt) 113 | 114 | return x, tgt 115 | 116 | 117 | # create_epiano_datasets 118 | def create_epiano_datasets(dataset_root, max_seq, random_seq=True): 119 | """ 120 | ---------- 121 | Author: Damon Gwinn 122 | ---------- 123 | Creates train, evaluation, and test EPianoDataset objects for a pre-processed (preprocess_midi.py) 124 | root containing train, val, and test folders. 125 | ---------- 126 | """ 127 | 128 | train_root = os.path.join(dataset_root, "train") 129 | val_root = os.path.join(dataset_root, "val") 130 | test_root = os.path.join(dataset_root, "test") 131 | 132 | train_dataset = EPianoDataset(train_root, max_seq, random_seq) 133 | val_dataset = EPianoDataset(val_root, max_seq, random_seq) 134 | test_dataset = EPianoDataset(test_root, max_seq, random_seq) 135 | 136 | return train_dataset, val_dataset, test_dataset 137 | 138 | # compute_epiano_accuracy 139 | def compute_epiano_accuracy(out, tgt): 140 | """ 141 | ---------- 142 | Author: Damon Gwinn 143 | ---------- 144 | Computes the average accuracy for the given input and output batches. Accuracy uses softmax 145 | of the output. 
146 | ---------- 147 | """ 148 | 149 | softmax = nn.Softmax(dim=-1) 150 | out = torch.argmax(softmax(out), dim=-1) 151 | 152 | out = out.flatten() 153 | tgt = tgt.flatten() 154 | 155 | mask = (tgt != TOKEN_PAD) 156 | 157 | out = out[mask] 158 | tgt = tgt[mask] 159 | 160 | # Empty 161 | if(len(tgt) == 0): 162 | return 1.0 163 | 164 | num_right = (out == tgt) 165 | num_right = torch.sum(num_right).type(TORCH_FLOAT) 166 | 167 | acc = num_right / len(tgt) 168 | 169 | return acc 170 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | 5 | from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy 6 | 7 | from model.music_transformer import MusicTransformer 8 | 9 | from utilities.constants import * 10 | from utilities.device import get_device, use_cuda 11 | from utilities.argument_funcs import parse_eval_args, print_eval_args 12 | from utilities.run_model import eval_model 13 | 14 | # main 15 | def main(): 16 | """ 17 | ---------- 18 | Author: Damon Gwinn 19 | ---------- 20 | Entry point. Evaluates a model specified by command line arguments 21 | ---------- 22 | """ 23 | 24 | args = parse_eval_args() 25 | print_eval_args(args) 26 | 27 | if(args.force_cpu): 28 | use_cuda(False) 29 | print("WARNING: Forced CPU usage, expect model to perform slower") 30 | print("") 31 | 32 | # Test dataset 33 | _, _, test_dataset = create_epiano_datasets(args.dataset_dir, args.max_sequence) 34 | 35 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.n_workers) 36 | 37 | model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads, 38 | d_model=args.d_model, dim_feedforward=args.dim_feedforward, 39 | max_sequence=args.max_sequence, rpr=args.rpr).to(get_device()) 40 | 41 | model.load_state_dict(torch.load(args.model_weights)) 42 | 43 | # No smoothed loss 44 | loss = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD) 45 | 46 | print("Evaluating:") 47 | model.eval() 48 | 49 | avg_loss, avg_acc = eval_model(model, test_loader, loss) 50 | 51 | print("Avg loss:", avg_loss) 52 | print("Avg acc:", avg_acc) 53 | print(SEPERATOR) 54 | print("") 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import random 5 | 6 | from third_party.midi_processor.processor import decode_midi, encode_midi 7 | 8 | from utilities.argument_funcs import parse_generate_args, print_generate_args 9 | from model.music_transformer import MusicTransformer 10 | from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy, process_midi 11 | from torch.utils.data import DataLoader 12 | from torch.optim import Adam 13 | 14 | from utilities.constants import * 15 | from utilities.device import get_device, use_cuda 16 | 17 | # main 18 | def main(): 19 | """ 20 | ---------- 21 | Author: Damon Gwinn 22 | ---------- 23 | Entry point. 
Generates music from a model specified by command line arguments 24 | ---------- 25 | """ 26 | 27 | args = parse_generate_args() 28 | print_generate_args(args) 29 | 30 | if(args.force_cpu): 31 | use_cuda(False) 32 | print("WARNING: Forced CPU usage, expect model to perform slower") 33 | print("") 34 | 35 | os.makedirs(args.output_dir, exist_ok=True) 36 | 37 | # Grabbing dataset if needed 38 | _, _, dataset = create_epiano_datasets(args.midi_root, args.num_prime, random_seq=False) 39 | 40 | # Can be None, an integer index to dataset, or a file path 41 | if(args.primer_file is None): 42 | f = str(random.randrange(len(dataset))) 43 | else: 44 | f = args.primer_file 45 | 46 | if(f.isdigit()): 47 | idx = int(f) 48 | primer, _ = dataset[idx] 49 | primer = primer.to(get_device()) 50 | 51 | print("Using primer index:", idx, "(", dataset.data_files[idx], ")") 52 | 53 | else: 54 | raw_mid = encode_midi(f) 55 | if(len(raw_mid) == 0): 56 | print("Error: No midi messages in primer file:", f) 57 | return 58 | 59 | primer, _ = process_midi(raw_mid, args.num_prime, random_seq=False) 60 | primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device()) 61 | 62 | print("Using primer file:", f) 63 | 64 | model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads, 65 | d_model=args.d_model, dim_feedforward=args.dim_feedforward, 66 | max_sequence=args.max_sequence, rpr=args.rpr).to(get_device()) 67 | 68 | model.load_state_dict(torch.load(args.model_weights)) 69 | 70 | # Saving primer first 71 | f_path = os.path.join(args.output_dir, "primer.mid") 72 | decode_midi(primer[:args.num_prime].cpu().numpy(), file_path=f_path) 73 | 74 | # GENERATION 75 | model.eval() 76 | with torch.set_grad_enabled(False): 77 | if(args.beam > 0): 78 | print("BEAM:", args.beam) 79 | beam_seq = model.generate(primer[:args.num_prime], args.target_seq_length, beam=args.beam) 80 | 81 | f_path = os.path.join(args.output_dir, "beam.mid") 82 | decode_midi(beam_seq[0].cpu().numpy(), file_path=f_path) 83 | else: 84 | print("RAND DIST") 85 | rand_seq = model.generate(primer[:args.num_prime], args.target_seq_length, beam=0) 86 | 87 | f_path = os.path.join(args.output_dir, "rand.mid") 88 | decode_midi(rand_seq[0].cpu().numpy(), file_path=f_path) 89 | 90 | 91 | 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /graph_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import csv 4 | import math 5 | import matplotlib.pyplot as plt 6 | 7 | RESULTS_FILE = "results.csv" 8 | EPOCH_IDX = 0 9 | LR_IDX = 1 10 | EVAL_LOSS_IDX = 4 11 | EVAL_ACC_IDX = 5 12 | 13 | SPLITTER = '?' 
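# The index constants above follow the results.csv layout written by train.py, whose
# CSV_HEADER is ["Epoch", "Learn rate", "Avg Train loss", "Train Accuracy",
# "Avg Eval loss", "Eval accuracy"]; EVAL_LOSS_IDX = 4 and EVAL_ACC_IDX = 5 therefore
# select the two evaluation columns. An illustrative row (values are made up):
#
#   0,0.000123,3.42,0.21,3.51,0.20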
14 | 15 | # graph_results 16 | def graph_results(input_dirs="./saved_models/results", output_dir=None, model_names=None, epoch_start=0, epoch_end=None): 17 | """ 18 | ---------- 19 | Author: Damon Gwinn 20 | ---------- 21 | Graphs model training and evaluation data 22 | ---------- 23 | """ 24 | 25 | input_dirs = input_dirs.split(SPLITTER) 26 | 27 | if(model_names is not None): 28 | model_names = model_names.split(SPLITTER) 29 | if(len(model_names) != len(input_dirs)): 30 | print("Error: len(model_names) != len(input_dirs)") 31 | return 32 | 33 | #Initialize Loss and Accuracy arrays 34 | loss_arrs = [] 35 | accuracy_arrs = [] 36 | epoch_counts = [] 37 | lrs = [] 38 | 39 | for input_dir in input_dirs: 40 | loss_arr = [] 41 | accuracy_arr = [] 42 | epoch_count = [] 43 | lr_arr = [] 44 | 45 | f = os.path.join(input_dir, RESULTS_FILE) 46 | with open(f, "r") as i_stream: 47 | reader = csv.reader(i_stream) 48 | next(reader) 49 | 50 | lines = [line for line in reader] 51 | 52 | if(epoch_end is None): 53 | epoch_end = math.inf 54 | 55 | epoch_start = max(epoch_start, 0) 56 | epoch_start = min(epoch_start, epoch_end) 57 | 58 | for line in lines: 59 | epoch = line[EPOCH_IDX] 60 | lr = line[LR_IDX] 61 | accuracy = line[EVAL_ACC_IDX] 62 | loss = line[EVAL_LOSS_IDX] 63 | 64 | if(int(epoch) >= epoch_start and int(epoch) < epoch_end): 65 | accuracy_arr.append(float(accuracy)) 66 | loss_arr.append(float(loss)) 67 | epoch_count.append(int(epoch)) 68 | lr_arr.append(float(lr)) 69 | 70 | loss_arrs.append(loss_arr) 71 | accuracy_arrs.append(accuracy_arr) 72 | epoch_counts.append(epoch_count) 73 | lrs.append(lr_arr) 74 | 75 | if(output_dir is not None): 76 | try: 77 | os.mkdir(output_dir) 78 | except OSError: 79 | print ("Creation of the directory %s failed" % output_dir) 80 | else: 81 | print ("Successfully created the directory %s" % output_dir) 82 | 83 | ##### LOSS ##### 84 | for i in range(len(loss_arrs)): 85 | if(model_names is None): 86 | name = None 87 | else: 88 | name = model_names[i] 89 | 90 | #Create and save plots to output folder 91 | plt.plot(epoch_counts[i], loss_arrs[i], label=name) 92 | plt.title("Loss Results") 93 | plt.ylabel('Loss (Cross Entropy)') 94 | plt.xlabel('Epochs') 95 | fig1 = plt.gcf() 96 | 97 | plt.legend(loc="upper left") 98 | 99 | if(output_dir is not None): 100 | fig1.savefig(os.path.join(output_dir, 'loss_graph.png')) 101 | 102 | plt.show() 103 | 104 | ##### ACCURACY ##### 105 | for i in range(len(accuracy_arrs)): 106 | if(model_names is None): 107 | name = None 108 | else: 109 | name = model_names[i] 110 | 111 | #Create and save plots to output folder 112 | plt.plot(epoch_counts[i], accuracy_arrs[i], label=name) 113 | plt.title("Accuracy Results") 114 | plt.ylabel('Accuracy') 115 | plt.xlabel('Epochs') 116 | fig2 = plt.gcf() 117 | 118 | plt.legend(loc="upper left") 119 | 120 | if(output_dir is not None): 121 | fig2.savefig(os.path.join(output_dir, 'accuracy_graph.png')) 122 | 123 | plt.show() 124 | 125 | ##### LR ##### 126 | for i in range(len(lrs)): 127 | if(model_names is None): 128 | name = None 129 | else: 130 | name = model_names[i] 131 | 132 | #Create and save plots to output folder 133 | plt.plot(epoch_counts[i], lrs[i], label=name) 134 | plt.title("Learn Rate Results") 135 | plt.ylabel('Learn Rate') 136 | plt.xlabel('Epochs') 137 | fig2 = plt.gcf() 138 | 139 | plt.legend(loc="upper left") 140 | 141 | if(output_dir is not None): 142 | fig2.savefig(os.path.join(output_dir, 'lr_graph.png')) 143 | 144 | plt.show() 145 | 146 | # graph_results_legacy 147 | def 
graph_results_legacy(input_dirs="./saved_models/results", output_dir=None, model_names=None, epoch_start=0, epoch_end=None): 148 | """ 149 | ---------- 150 | Author: Ben Myrick 151 | Modified: Damon Gwinn 152 | ---------- 153 | Graphs model training and evaluation data using the old results format (legacy) 154 | ---------- 155 | """ 156 | 157 | input_dirs = input_dirs.split(SPLITTER) 158 | 159 | if(model_names is not None): 160 | model_names = model_names.split(SPLITTER) 161 | if(len(model_names) != len(input_dirs)): 162 | print("Error: len(model_names) != len(input_dirs)") 163 | return 164 | 165 | #Initialize Loss and Accuracy arrays 166 | loss_arrs = [] 167 | accuracy_arrs = [] 168 | epoch_counts = [] 169 | 170 | for input_dir in input_dirs: 171 | loss_arr = [] 172 | accuracy_arr = [] 173 | epoch_count = [] 174 | 175 | fs = [os.path.join(input_dir, f) for f in sorted(os.listdir(input_dir))] 176 | fs = [f for f in fs if os.path.isfile(f)] 177 | 178 | if(epoch_end is None): 179 | epoch_end = len(fs) 180 | else: 181 | epoch_end = min(epoch_end, len(fs)) 182 | 183 | epoch_start = max(epoch_start, 0) 184 | epoch_start = min(epoch_start, epoch_end) 185 | 186 | for x in range(epoch_start, epoch_end): 187 | path = fs[x] 188 | 189 | #Read file and parse accuracy and loss values 190 | file = open(path, 'r') 191 | temp_average_accuracy = file.readline() 192 | temp_average_loss = file.readline() 193 | 194 | #Update accuracy and loss arrays for each epoch 195 | accuracy_arr.append(float(temp_average_accuracy)) 196 | loss_arr.append(float(temp_average_loss)) 197 | epoch_count.append(x) 198 | 199 | file.close() 200 | 201 | loss_arrs.append(loss_arr) 202 | accuracy_arrs.append(accuracy_arr) 203 | epoch_counts.append(epoch_count) 204 | 205 | if(output_dir is not None): 206 | try: 207 | os.mkdir(output_dir) 208 | except OSError: 209 | print ("Creation of the directory %s failed" % output_dir) 210 | else: 211 | print ("Successfully created the directory %s" % output_dir) 212 | 213 | for i in range(len(loss_arrs)): 214 | if(model_names is None): 215 | name = input_dirs[i] 216 | else: 217 | name = model_names[i] 218 | 219 | #Create and save plots to output folder 220 | plt.plot(epoch_counts[i], loss_arrs[i], label=name) 221 | plt.title("Loss Results") 222 | plt.ylabel('Loss (Cross Entropy)') 223 | plt.xlabel('Epochs') 224 | fig1 = plt.gcf() 225 | 226 | plt.legend(loc="upper left") 227 | 228 | if(output_dir is not None): 229 | fig1.savefig(os.path.join(output_dir, 'loss_graph.png')) 230 | 231 | plt.show() 232 | 233 | for i in range(len(loss_arrs)): 234 | if(model_names is None): 235 | name = input_dirs[i] 236 | else: 237 | name = model_names[i] 238 | 239 | #Create and save plots to output folder 240 | plt.plot(epoch_counts[i], accuracy_arrs[i], label=name) 241 | plt.title("Accuracy Results") 242 | plt.ylabel('Accuracy') 243 | plt.xlabel('Epochs') 244 | fig2 = plt.gcf() 245 | 246 | plt.legend(loc="upper left") 247 | 248 | if(output_dir is not None): 249 | fig2.savefig(os.path.join(output_dir, 'accuracy_graph.png')) 250 | 251 | plt.show() 252 | 253 | # parse_args 254 | def parse_args(): 255 | """ 256 | ---------- 257 | Author: Damon Gwinn 258 | ---------- 259 | Argparse arguments 260 | ---------- 261 | """ 262 | 263 | parser = argparse.ArgumentParser() 264 | 265 | parser.add_argument("-input_dirs", type=str, default="./saved_models/results", help="Input results folder from trained model ('results' folder). Seperate with '?' 
symbol for comparisons between models") 266 | parser.add_argument("-output_dir", type=str, default=None, help="Optional output folder to save graph pngs") 267 | parser.add_argument("-model_names", type=str, default=None, help="Names to display when color coding, seperate with ':'.") 268 | parser.add_argument("-epoch_start", type=int, default=0, help="Epoch start. Defaults to first file.") 269 | parser.add_argument("-epoch_end", type=int, default=None, help="Epoch end (non-inclusive). Defaults to None.") 270 | parser.add_argument("--legacy", action="store_true", help="Use legacy results output format (you likely don't need this)") 271 | 272 | return parser.parse_args() 273 | 274 | def main(): 275 | """ 276 | ---------- 277 | Author: Ben Myrick 278 | Modified: Damon Gwinn 279 | ---------- 280 | Entry point 281 | ---------- 282 | """ 283 | 284 | args = parse_args() 285 | 286 | if(not args.legacy): 287 | graph_results(args.input_dirs, args.output_dir, args.model_names, args.epoch_start, args.epoch_end) 288 | else: 289 | graph_results_legacy(args.input_dirs, args.output_dir, args.model_names, args.epoch_start, args.epoch_end) 290 | 291 | if __name__ == "__main__": 292 | main() 293 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.loss import _Loss 5 | 6 | # Borrowed from https://github.com/jason9693/MusicTransformer-pytorch/blob/5f183374833ff6b7e17f3a24e3594dedd93a5fe5/custom/criterion.py#L28 7 | class SmoothCrossEntropyLoss(_Loss): 8 | """ 9 | https://arxiv.org/abs/1512.00567 10 | """ 11 | __constants__ = ['label_smoothing', 'vocab_size', 'ignore_index', 'reduction'] 12 | 13 | def __init__(self, label_smoothing, vocab_size, ignore_index=-100, reduction='mean', is_logits=True): 14 | assert 0.0 <= label_smoothing <= 1.0 15 | super().__init__(reduction=reduction) 16 | 17 | self.label_smoothing = label_smoothing 18 | self.vocab_size = vocab_size 19 | self.ignore_index = ignore_index 20 | self.input_is_logits = is_logits 21 | 22 | def forward(self, input, target): 23 | """ 24 | Args: 25 | input: [B * T, V] 26 | target: [B * T] 27 | Returns: 28 | cross entropy: [1] 29 | """ 30 | mask = (target == self.ignore_index).unsqueeze(-1) 31 | q = F.one_hot(target.long(), self.vocab_size).type(torch.float32) 32 | u = 1.0 / self.vocab_size 33 | q_prime = (1.0 - self.label_smoothing) * q + self.label_smoothing * u 34 | q_prime = q_prime.masked_fill(mask, 0) 35 | 36 | ce = self.cross_entropy_with_logits(q_prime, input) 37 | if self.reduction == 'mean': 38 | lengths = torch.sum(target != self.ignore_index) 39 | return ce.sum() / lengths 40 | elif self.reduction == 'sum': 41 | return ce.sum() 42 | else: 43 | raise NotImplementedError 44 | 45 | def cross_entropy_with_logits(self, p, q): 46 | return -torch.sum(p * (q - q.logsumexp(dim=-1, keepdim=True)), dim=-1) 47 | -------------------------------------------------------------------------------- /model/music_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.normalization import LayerNorm 4 | import random 5 | 6 | from utilities.constants import * 7 | from utilities.device import get_device 8 | 9 | from .positional_encoding import PositionalEncoding 10 | from .rpr import TransformerEncoderRPR, TransformerEncoderLayerRPR 
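# Usage sketch (illustrative only -- the tensor names are hypothetical; the argument
# values mirror the constructor defaults below, with rpr enabled):
#
#   model = MusicTransformer(n_layers=6, num_heads=8, d_model=512,
#                            dim_feedforward=1024, max_sequence=2048, rpr=True)
#   y = model(x)   # x: (batch, seq_len) token ids, y: (batch, seq_len, VOCAB_SIZE) logits
#
# forward() embeds the tokens, permutes to (seq_len, batch, d_model) for nn.Transformer,
# runs the (optionally RPR) encoder stack under a causal mask, and projects back to
# vocabulary logits. Softmax is only applied inside generate(); training operates on the
# raw logits with the cross entropy losses.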
11 | 12 | 13 | # MusicTransformer 14 | class MusicTransformer(nn.Module): 15 | """ 16 | ---------- 17 | Author: Damon Gwinn 18 | ---------- 19 | Music Transformer reproduction from https://arxiv.org/abs/1809.04281. Arguments allow for 20 | tweaking the transformer architecture (https://arxiv.org/abs/1706.03762) and the rpr argument 21 | toggles Relative Position Representations (RPR - https://arxiv.org/abs/1803.02155). 22 | 23 | Supports training and generation using Pytorch's nn.Transformer class with dummy decoder to 24 | make a decoder-only transformer architecture 25 | 26 | For RPR support, there is modified Pytorch 1.2.0 code in rpr.py. Modified source will be 27 | kept up to date with Pytorch revisions only as necessary. 28 | ---------- 29 | """ 30 | 31 | def __init__(self, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024, 32 | dropout=0.1, max_sequence=2048, rpr=False): 33 | super(MusicTransformer, self).__init__() 34 | 35 | self.dummy = DummyDecoder() 36 | 37 | self.nlayers = n_layers 38 | self.nhead = num_heads 39 | self.d_model = d_model 40 | self.d_ff = dim_feedforward 41 | self.dropout = dropout 42 | self.max_seq = max_sequence 43 | self.rpr = rpr 44 | 45 | # Input embedding 46 | self.embedding = nn.Embedding(VOCAB_SIZE, self.d_model) 47 | 48 | # Positional encoding 49 | self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq) 50 | 51 | # Base transformer 52 | if(not self.rpr): 53 | # To make a decoder-only transformer we need to use masked encoder layers 54 | # Dummy decoder to essentially just return the encoder output 55 | self.transformer = nn.Transformer( 56 | d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers, 57 | num_decoder_layers=0, dropout=self.dropout, # activation=self.ff_activ, 58 | dim_feedforward=self.d_ff, custom_decoder=self.dummy 59 | ) 60 | # RPR Transformer 61 | else: 62 | encoder_norm = LayerNorm(self.d_model) 63 | encoder_layer = TransformerEncoderLayerRPR(self.d_model, self.nhead, self.d_ff, self.dropout, er_len=self.max_seq) 64 | encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm) 65 | self.transformer = nn.Transformer( 66 | d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers, 67 | num_decoder_layers=0, dropout=self.dropout, # activation=self.ff_activ, 68 | dim_feedforward=self.d_ff, custom_decoder=self.dummy, custom_encoder=encoder 69 | ) 70 | 71 | # Final output is a softmaxed linear layer 72 | self.Wout = nn.Linear(self.d_model, VOCAB_SIZE) 73 | self.softmax = nn.Softmax(dim=-1) 74 | 75 | # forward 76 | def forward(self, x, mask=True): 77 | """ 78 | ---------- 79 | Author: Damon Gwinn 80 | ---------- 81 | Takes an input sequence and outputs predictions using a sequence to sequence method. 82 | 83 | A prediction at one index is the "next" prediction given all information seen previously. 
84 | ---------- 85 | """ 86 | 87 | if(mask is True): 88 | mask = self.transformer.generate_square_subsequent_mask(x.shape[1]).to(get_device()) 89 | else: 90 | mask = None 91 | 92 | x = self.embedding(x) 93 | 94 | # Input shape is (max_seq, batch_size, d_model) 95 | x = x.permute(1,0,2) 96 | 97 | x = self.positional_encoding(x) 98 | 99 | # Since there are no true decoder layers, the tgt is unused 100 | # Pytorch wants src and tgt to have some equal dims however 101 | x_out = self.transformer(src=x, tgt=x, src_mask=mask) 102 | 103 | # Back to (batch_size, max_seq, d_model) 104 | x_out = x_out.permute(1,0,2) 105 | 106 | y = self.Wout(x_out) 107 | # y = self.softmax(y) 108 | 109 | del mask 110 | 111 | # They are trained to predict the next note in sequence (we don't need the last one) 112 | return y 113 | 114 | # generate 115 | def generate(self, primer=None, target_seq_length=1024, beam=0, beam_chance=1.0): 116 | """ 117 | ---------- 118 | Author: Damon Gwinn 119 | ---------- 120 | Generates midi given a primer sample. Music can be generated using a probability distribution over 121 | the softmax probabilities (recommended) or by using a beam search. 122 | ---------- 123 | """ 124 | 125 | assert (not self.training), "Cannot generate while in training mode" 126 | 127 | print("Generating sequence of max length:", target_seq_length) 128 | 129 | gen_seq = torch.full((1,target_seq_length), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device()) 130 | 131 | num_primer = len(primer) 132 | gen_seq[..., :num_primer] = primer.type(TORCH_LABEL_TYPE).to(get_device()) 133 | 134 | 135 | # print("primer:",primer) 136 | # print(gen_seq) 137 | cur_i = num_primer 138 | while(cur_i < target_seq_length): 139 | # gen_seq_batch = gen_seq.clone() 140 | y = self.softmax(self.forward(gen_seq[..., :cur_i]))[..., :TOKEN_END] 141 | token_probs = y[:, cur_i-1, :] 142 | 143 | if(beam == 0): 144 | beam_ran = 2.0 145 | else: 146 | beam_ran = random.uniform(0,1) 147 | 148 | if(beam_ran <= beam_chance): 149 | token_probs = token_probs.flatten() 150 | top_res, top_i = torch.topk(token_probs, beam) 151 | 152 | beam_rows = top_i // VOCAB_SIZE 153 | beam_cols = top_i % VOCAB_SIZE 154 | 155 | gen_seq = gen_seq[beam_rows, :] 156 | gen_seq[..., cur_i] = beam_cols 157 | 158 | else: 159 | distrib = torch.distributions.categorical.Categorical(probs=token_probs) 160 | next_token = distrib.sample() 161 | # print("next token:",next_token) 162 | gen_seq[:, cur_i] = next_token 163 | 164 | 165 | # Let the transformer decide to end if it wants to 166 | if(next_token == TOKEN_END): 167 | print("Model called end of sequence at:", cur_i, "/", target_seq_length) 168 | break 169 | 170 | cur_i += 1 171 | if(cur_i % 50 == 0): 172 | print(cur_i, "/", target_seq_length) 173 | 174 | return gen_seq[:, :cur_i] 175 | 176 | # Used as a dummy to nn.Transformer 177 | # DummyDecoder 178 | class DummyDecoder(nn.Module): 179 | """ 180 | ---------- 181 | Author: Damon Gwinn 182 | ---------- 183 | A dummy decoder that returns its input. 
Used to make the Pytorch transformer into a decoder-only 184 | architecture (stacked encoders with dummy decoder fits the bill) 185 | ---------- 186 | """ 187 | 188 | def __init__(self): 189 | super(DummyDecoder, self).__init__() 190 | 191 | def forward(self, tgt, memory, tgt_mask, memory_mask,tgt_key_padding_mask,memory_key_padding_mask): 192 | """ 193 | ---------- 194 | Author: Damon Gwinn 195 | ---------- 196 | Returns the input (memory) 197 | ---------- 198 | """ 199 | 200 | return memory 201 | -------------------------------------------------------------------------------- /model/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | # PositionalEncoding 6 | # Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html 7 | class PositionalEncoding(nn.Module): 8 | 9 | def __init__(self, d_model, dropout=0.1, max_len=5000): 10 | super(PositionalEncoding, self).__init__() 11 | self.dropout = nn.Dropout(p=dropout) 12 | 13 | pe = torch.zeros(max_len, d_model) 14 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 15 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | pe = pe.unsqueeze(0).transpose(0, 1) 19 | self.register_buffer('pe', pe) 20 | 21 | def forward(self, x): 22 | x = x + self.pe[:x.size(0), :] 23 | return self.dropout(x) 24 | -------------------------------------------------------------------------------- /model/rpr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch.nn import functional as F 5 | from torch.nn.parameter import Parameter 6 | from torch.nn import Module 7 | from torch.nn.modules.transformer import _get_clones 8 | from torch.nn.modules.linear import Linear 9 | from torch.nn.modules.dropout import Dropout 10 | from torch.nn.modules.normalization import LayerNorm 11 | from torch.nn.init import * 12 | 13 | from torch.nn.functional import linear, softmax, dropout 14 | 15 | # TransformerEncoderRPR 16 | class TransformerEncoderRPR(Module): 17 | """ 18 | ---------- 19 | Author: Pytorch 20 | ---------- 21 | For Relative Position Representation support (https://arxiv.org/abs/1803.02155) 22 | https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoder 23 | 24 | No modification. Copied here to ensure continued compatibility with other edits. 
25 | ---------- 26 | """ 27 | 28 | def __init__(self, encoder_layer, num_layers, norm=None): 29 | super(TransformerEncoderRPR, self).__init__() 30 | self.layers = _get_clones(encoder_layer, num_layers) 31 | self.num_layers = num_layers 32 | self.norm = norm 33 | 34 | def forward(self, src, mask=None, src_key_padding_mask=None): 35 | 36 | output = src 37 | 38 | for i in range(self.num_layers): 39 | output = self.layers[i](output, src_mask=mask, 40 | src_key_padding_mask=src_key_padding_mask) 41 | 42 | if self.norm: 43 | output = self.norm(output) 44 | 45 | return output 46 | 47 | # TransformerEncoderLayerRPR 48 | class TransformerEncoderLayerRPR(Module): 49 | """ 50 | ---------- 51 | Author: Pytorch 52 | Modified: Damon Gwinn 53 | ---------- 54 | For Relative Position Representation support (https://arxiv.org/abs/1803.02155) 55 | https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer 56 | 57 | Modification to create and call custom MultiheadAttentionRPR 58 | ---------- 59 | """ 60 | 61 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None): 62 | super(TransformerEncoderLayerRPR, self).__init__() 63 | self.self_attn = MultiheadAttentionRPR(d_model, nhead, dropout=dropout, er_len=er_len) 64 | # Implementation of Feedforward model 65 | self.linear1 = Linear(d_model, dim_feedforward) 66 | self.dropout = Dropout(dropout) 67 | self.linear2 = Linear(dim_feedforward, d_model) 68 | 69 | self.norm1 = LayerNorm(d_model) 70 | self.norm2 = LayerNorm(d_model) 71 | self.dropout1 = Dropout(dropout) 72 | self.dropout2 = Dropout(dropout) 73 | 74 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 75 | src2 = self.self_attn(src, src, src, attn_mask=src_mask, 76 | key_padding_mask=src_key_padding_mask)[0] 77 | src = src + self.dropout1(src2) 78 | src = self.norm1(src) 79 | src2 = self.linear2(self.dropout(F.relu(self.linear1(src)))) 80 | src = src + self.dropout2(src2) 81 | src = self.norm2(src) 82 | return src 83 | 84 | # MultiheadAttentionRPR 85 | class MultiheadAttentionRPR(Module): 86 | """ 87 | ---------- 88 | Author: Pytorch 89 | Modified: Damon Gwinn 90 | ---------- 91 | For Relative Position Representation support (https://arxiv.org/abs/1803.02155) 92 | https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/activation.html#MultiheadAttention 93 | 94 | Modification to add RPR embedding Er and call custom multi_head_attention_forward_rpr 95 | ---------- 96 | """ 97 | 98 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, er_len=None): 99 | super(MultiheadAttentionRPR, self).__init__() 100 | self.embed_dim = embed_dim 101 | self.kdim = kdim if kdim is not None else embed_dim 102 | self.vdim = vdim if vdim is not None else embed_dim 103 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 104 | 105 | self.num_heads = num_heads 106 | self.dropout = dropout 107 | self.head_dim = embed_dim // num_heads 108 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 109 | 110 | self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) 111 | 112 | if self._qkv_same_embed_dim is False: 113 | self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) 114 | self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) 115 | self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) 116 | 117 | if bias: 118 | self.in_proj_bias = Parameter(torch.empty(3 * 
embed_dim)) 119 | else: 120 | self.register_parameter('in_proj_bias', None) 121 | self.out_proj = Linear(embed_dim, embed_dim, bias=bias) 122 | 123 | if add_bias_kv: 124 | self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) 125 | self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) 126 | else: 127 | self.bias_k = self.bias_v = None 128 | 129 | self.add_zero_attn = add_zero_attn 130 | 131 | # Adding RPR embedding matrix 132 | if(er_len is not None): 133 | self.Er = Parameter(torch.rand((er_len, self.head_dim), dtype=torch.float32)) 134 | else: 135 | self.Er = None 136 | 137 | self._reset_parameters() 138 | 139 | def _reset_parameters(self): 140 | if self._qkv_same_embed_dim: 141 | xavier_uniform_(self.in_proj_weight) 142 | else: 143 | xavier_uniform_(self.q_proj_weight) 144 | xavier_uniform_(self.k_proj_weight) 145 | xavier_uniform_(self.v_proj_weight) 146 | 147 | if self.in_proj_bias is not None: 148 | constant_(self.in_proj_bias, 0.) 149 | constant_(self.out_proj.bias, 0.) 150 | if self.bias_k is not None: 151 | xavier_normal_(self.bias_k) 152 | if self.bias_v is not None: 153 | xavier_normal_(self.bias_v) 154 | 155 | def forward(self, query, key, value, key_padding_mask=None, 156 | need_weights=True, attn_mask=None): 157 | 158 | if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False: 159 | # return F.multi_head_attention_forward( 160 | # query, key, value, self.embed_dim, self.num_heads, 161 | # self.in_proj_weight, self.in_proj_bias, 162 | # self.bias_k, self.bias_v, self.add_zero_attn, 163 | # self.dropout, self.out_proj.weight, self.out_proj.bias, 164 | # training=self.training, 165 | # key_padding_mask=key_padding_mask, need_weights=need_weights, 166 | # attn_mask=attn_mask, use_separate_proj_weight=True, 167 | # q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 168 | # v_proj_weight=self.v_proj_weight) 169 | 170 | return multi_head_attention_forward_rpr( 171 | query, key, value, self.embed_dim, self.num_heads, 172 | self.in_proj_weight, self.in_proj_bias, 173 | self.bias_k, self.bias_v, self.add_zero_attn, 174 | self.dropout, self.out_proj.weight, self.out_proj.bias, 175 | training=self.training, 176 | key_padding_mask=key_padding_mask, need_weights=need_weights, 177 | attn_mask=attn_mask, use_separate_proj_weight=True, 178 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 179 | v_proj_weight=self.v_proj_weight, rpr_mat=self.Er) 180 | else: 181 | if not hasattr(self, '_qkv_same_embed_dim'): 182 | warnings.warn('A new version of MultiheadAttention module has been implemented. 
\ 183 | Please re-train your model with the new module', 184 | UserWarning) 185 | 186 | # return F.multi_head_attention_forward( 187 | # query, key, value, self.embed_dim, self.num_heads, 188 | # self.in_proj_weight, self.in_proj_bias, 189 | # self.bias_k, self.bias_v, self.add_zero_attn, 190 | # self.dropout, self.out_proj.weight, self.out_proj.bias, 191 | # training=self.training, 192 | # key_padding_mask=key_padding_mask, need_weights=need_weights, 193 | # attn_mask=attn_mask) 194 | 195 | return multi_head_attention_forward_rpr( 196 | query, key, value, self.embed_dim, self.num_heads, 197 | self.in_proj_weight, self.in_proj_bias, 198 | self.bias_k, self.bias_v, self.add_zero_attn, 199 | self.dropout, self.out_proj.weight, self.out_proj.bias, 200 | training=self.training, 201 | key_padding_mask=key_padding_mask, need_weights=need_weights, 202 | attn_mask=attn_mask, rpr_mat=self.Er) 203 | 204 | # multi_head_attention_forward_rpr 205 | def multi_head_attention_forward_rpr(query, # type: Tensor 206 | key, # type: Tensor 207 | value, # type: Tensor 208 | embed_dim_to_check, # type: int 209 | num_heads, # type: int 210 | in_proj_weight, # type: Tensor 211 | in_proj_bias, # type: Tensor 212 | bias_k, # type: Optional[Tensor] 213 | bias_v, # type: Optional[Tensor] 214 | add_zero_attn, # type: bool 215 | dropout_p, # type: float 216 | out_proj_weight, # type: Tensor 217 | out_proj_bias, # type: Tensor 218 | training=True, # type: bool 219 | key_padding_mask=None, # type: Optional[Tensor] 220 | need_weights=True, # type: bool 221 | attn_mask=None, # type: Optional[Tensor] 222 | use_separate_proj_weight=False, # type: bool 223 | q_proj_weight=None, # type: Optional[Tensor] 224 | k_proj_weight=None, # type: Optional[Tensor] 225 | v_proj_weight=None, # type: Optional[Tensor] 226 | static_k=None, # type: Optional[Tensor] 227 | static_v=None, # type: Optional[Tensor] 228 | rpr_mat=None 229 | ): 230 | """ 231 | ---------- 232 | Author: Pytorch 233 | Modified: Damon Gwinn 234 | ---------- 235 | For Relative Position Representation support (https://arxiv.org/abs/1803.02155) 236 | https://pytorch.org/docs/1.2.0/_modules/torch/nn/functional.html 237 | 238 | Modification to take RPR embedding matrix and perform skew optimized RPR (https://arxiv.org/abs/1809.04281) 239 | ---------- 240 | """ 241 | 242 | # type: (...) 
-> Tuple[Tensor, Optional[Tensor]] 243 | 244 | qkv_same = torch.equal(query, key) and torch.equal(key, value) 245 | kv_same = torch.equal(key, value) 246 | 247 | tgt_len, bsz, embed_dim = query.size() 248 | assert embed_dim == embed_dim_to_check 249 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 250 | assert key.size() == value.size() 251 | 252 | head_dim = embed_dim // num_heads 253 | assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" 254 | scaling = float(head_dim) ** -0.5 255 | 256 | if use_separate_proj_weight is not True: 257 | if qkv_same: 258 | # self-attention 259 | q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) 260 | 261 | elif kv_same: 262 | # encoder-decoder attention 263 | # This is inline in_proj function with in_proj_weight and in_proj_bias 264 | _b = in_proj_bias 265 | _start = 0 266 | _end = embed_dim 267 | _w = in_proj_weight[_start:_end, :] 268 | if _b is not None: 269 | _b = _b[_start:_end] 270 | q = linear(query, _w, _b) 271 | 272 | if key is None: 273 | assert value is None 274 | k = None 275 | v = None 276 | else: 277 | 278 | # This is inline in_proj function with in_proj_weight and in_proj_bias 279 | _b = in_proj_bias 280 | _start = embed_dim 281 | _end = None 282 | _w = in_proj_weight[_start:, :] 283 | if _b is not None: 284 | _b = _b[_start:] 285 | k, v = linear(key, _w, _b).chunk(2, dim=-1) 286 | 287 | else: 288 | # This is inline in_proj function with in_proj_weight and in_proj_bias 289 | _b = in_proj_bias 290 | _start = 0 291 | _end = embed_dim 292 | _w = in_proj_weight[_start:_end, :] 293 | if _b is not None: 294 | _b = _b[_start:_end] 295 | q = linear(query, _w, _b) 296 | 297 | # This is inline in_proj function with in_proj_weight and in_proj_bias 298 | _b = in_proj_bias 299 | _start = embed_dim 300 | _end = embed_dim * 2 301 | _w = in_proj_weight[_start:_end, :] 302 | if _b is not None: 303 | _b = _b[_start:_end] 304 | k = linear(key, _w, _b) 305 | 306 | # This is inline in_proj function with in_proj_weight and in_proj_bias 307 | _b = in_proj_bias 308 | _start = embed_dim * 2 309 | _end = None 310 | _w = in_proj_weight[_start:, :] 311 | if _b is not None: 312 | _b = _b[_start:] 313 | v = linear(value, _w, _b) 314 | else: 315 | q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) 316 | len1, len2 = q_proj_weight_non_opt.size() 317 | assert len1 == embed_dim and len2 == query.size(-1) 318 | 319 | k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) 320 | len1, len2 = k_proj_weight_non_opt.size() 321 | assert len1 == embed_dim and len2 == key.size(-1) 322 | 323 | v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) 324 | len1, len2 = v_proj_weight_non_opt.size() 325 | assert len1 == embed_dim and len2 == value.size(-1) 326 | 327 | if in_proj_bias is not None: 328 | q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) 329 | k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) 330 | v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) 331 | else: 332 | q = linear(query, q_proj_weight_non_opt, in_proj_bias) 333 | k = linear(key, k_proj_weight_non_opt, in_proj_bias) 334 | v = linear(value, v_proj_weight_non_opt, in_proj_bias) 335 | q = q * scaling 336 | 337 | if bias_k is not None and bias_v is not None: 338 | if static_k is None and static_v is None: 339 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 340 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 341 | if attn_mask is not None: 342 | attn_mask = 
torch.cat([attn_mask, 343 | torch.zeros((attn_mask.size(0), 1), 344 | dtype=attn_mask.dtype, 345 | device=attn_mask.device)], dim=1) 346 | if key_padding_mask is not None: 347 | key_padding_mask = torch.cat( 348 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 349 | dtype=key_padding_mask.dtype, 350 | device=key_padding_mask.device)], dim=1) 351 | else: 352 | assert static_k is None, "bias cannot be added to static key." 353 | assert static_v is None, "bias cannot be added to static value." 354 | else: 355 | assert bias_k is None 356 | assert bias_v is None 357 | 358 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 359 | if k is not None: 360 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 361 | if v is not None: 362 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 363 | 364 | if static_k is not None: 365 | assert static_k.size(0) == bsz * num_heads 366 | assert static_k.size(2) == head_dim 367 | k = static_k 368 | 369 | if static_v is not None: 370 | assert static_v.size(0) == bsz * num_heads 371 | assert static_v.size(2) == head_dim 372 | v = static_v 373 | 374 | src_len = k.size(1) 375 | 376 | if key_padding_mask is not None: 377 | assert key_padding_mask.size(0) == bsz 378 | assert key_padding_mask.size(1) == src_len 379 | 380 | if add_zero_attn: 381 | src_len += 1 382 | k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) 383 | v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) 384 | if attn_mask is not None: 385 | attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1), 386 | dtype=attn_mask.dtype, 387 | device=attn_mask.device)], dim=1) 388 | if key_padding_mask is not None: 389 | key_padding_mask = torch.cat( 390 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 391 | dtype=key_padding_mask.dtype, 392 | device=key_padding_mask.device)], dim=1) 393 | 394 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 395 | assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] 396 | 397 | ######### ADDITION OF RPR ########### 398 | if(rpr_mat is not None): 399 | rpr_mat = _get_valid_embedding(rpr_mat, q.shape[1], k.shape[1]) 400 | qe = torch.einsum("hld,md->hlm", q, rpr_mat) 401 | srel = _skew(qe) 402 | 403 | attn_output_weights += srel 404 | 405 | if attn_mask is not None: 406 | attn_mask = attn_mask.unsqueeze(0) 407 | attn_output_weights += attn_mask 408 | 409 | if key_padding_mask is not None: 410 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 411 | attn_output_weights = attn_output_weights.masked_fill( 412 | key_padding_mask.unsqueeze(1).unsqueeze(2), 413 | float('-inf'), 414 | ) 415 | attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) 416 | 417 | attn_output_weights = softmax( 418 | attn_output_weights, dim=-1) 419 | 420 | attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) 421 | 422 | attn_output = torch.bmm(attn_output_weights, v) 423 | assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] 424 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 425 | attn_output = linear(attn_output, out_proj_weight, out_proj_bias) 426 | 427 | if need_weights: 428 | # average attention weights over heads 429 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 430 | return 
attn_output, attn_output_weights.sum(dim=1) / num_heads 431 | else: 432 | return attn_output, None 433 | 434 | def _get_valid_embedding(Er, len_q, len_k): 435 | """ 436 | ---------- 437 | Author: Damon Gwinn 438 | ---------- 439 | Gets valid embeddings based on max length of RPR attention 440 | ---------- 441 | """ 442 | 443 | len_e = Er.shape[0] 444 | start = max(0, len_e - len_q) 445 | return Er[start:, :] 446 | 447 | def _skew(qe): 448 | """ 449 | ---------- 450 | Author: Damon Gwinn 451 | ---------- 452 | Performs the skew optimized RPR computation (https://arxiv.org/abs/1809.04281) 453 | ---------- 454 | """ 455 | 456 | sz = qe.shape[1] 457 | mask = (torch.triu(torch.ones(sz, sz).to(qe.device)) == 1).float().flip(0) 458 | 459 | qe = mask * qe 460 | qe = F.pad(qe, (1,0, 0,0, 0,0)) 461 | qe = torch.reshape(qe, (qe.shape[0], qe.shape[2], qe.shape[1])) 462 | 463 | srel = qe[:, 1:, :] 464 | return srel 465 | -------------------------------------------------------------------------------- /preprocess_midi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import json 5 | import random 6 | 7 | import third_party.midi_processor.processor as midi_processor 8 | 9 | JSON_FILE = "maestro-v2.0.0.json" 10 | 11 | # prep_midi 12 | def prep_maestro_midi(maestro_root, output_dir): 13 | """ 14 | ---------- 15 | Author: Damon Gwinn 16 | ---------- 17 | Pre-processes the maestro dataset, putting processed midi data (train, eval, test) into the 18 | given output folder 19 | ---------- 20 | """ 21 | 22 | train_dir = os.path.join(output_dir, "train") 23 | os.makedirs(train_dir, exist_ok=True) 24 | val_dir = os.path.join(output_dir, "val") 25 | os.makedirs(val_dir, exist_ok=True) 26 | test_dir = os.path.join(output_dir, "test") 27 | os.makedirs(test_dir, exist_ok=True) 28 | 29 | maestro_json_file = os.path.join(maestro_root, JSON_FILE) 30 | if(not os.path.isfile(maestro_json_file)): 31 | print("ERROR: Could not find file:", maestro_json_file) 32 | return False 33 | 34 | maestro_json = json.load(open(maestro_json_file, "r")) 35 | print("Found", len(maestro_json), "pieces") 36 | print("Preprocessing...") 37 | 38 | total_count = 0 39 | train_count = 0 40 | val_count = 0 41 | test_count = 0 42 | 43 | for piece in maestro_json: 44 | mid = os.path.join(maestro_root, piece["midi_filename"]) 45 | split_type = piece["split"] 46 | f_name = mid.split("/")[-1] + ".pickle" 47 | 48 | if(split_type == "train"): 49 | o_file = os.path.join(train_dir, f_name) 50 | train_count += 1 51 | elif(split_type == "validation"): 52 | o_file = os.path.join(val_dir, f_name) 53 | val_count += 1 54 | elif(split_type == "test"): 55 | o_file = os.path.join(test_dir, f_name) 56 | test_count += 1 57 | else: 58 | print("ERROR: Unrecognized split type:", split_type) 59 | return False 60 | 61 | prepped = midi_processor.encode_midi(mid) 62 | 63 | o_stream = open(o_file, "wb") 64 | pickle.dump(prepped, o_stream) 65 | o_stream.close() 66 | 67 | total_count += 1 68 | if(total_count % 50 == 0): 69 | print(total_count, "/", len(maestro_json)) 70 | 71 | print("Num Train:", train_count) 72 | print("Num Val:", val_count) 73 | print("Num Test:", test_count) 74 | return True 75 | 76 | def prep_custom_midi(custom_midi_root, output_dir, valid_p = 0.1, test_p = 0.2): 77 | """ 78 | ---------- 79 | Author: Corentin Nelias 80 | ---------- 81 | Pre-processes custom midi files that are not part of the maestro dataset, putting processed midi data (train, eval, test) into the 82 
| given output folder. 83 | ---------- 84 | """ 85 | train_dir = os.path.join(output_dir, "train") 86 | os.makedirs(train_dir, exist_ok=True) 87 | val_dir = os.path.join(output_dir, "val") 88 | os.makedirs(val_dir, exist_ok=True) 89 | test_dir = os.path.join(output_dir, "test") 90 | os.makedirs(test_dir, exist_ok=True) 91 | 92 | print("Found", len(os.listdir(custom_midi_root)), "pieces") 93 | print("Preprocessing custom data...") 94 | total_count = 0 95 | train_count = 0 96 | val_count = 0 97 | test_count = 0 98 | 99 | for piece in os.listdir(custom_midi_root): 100 | #deciding whether the data should be part of train, valid or test dataset 101 | is_train = True if random.random() > valid_p else False 102 | if not is_train: 103 | is_valid = True if random.random() > test_p else False 104 | if is_train: 105 | split_type = "train" 106 | elif is_valid: 107 | split_type = "validation" 108 | else: 109 | split_type = "test" 110 | 111 | mid = os.path.join(custom_midi_root, piece) 112 | f_name = piece.split(".")[0] + ".pickle" 113 | 114 | if(split_type == "train"): 115 | o_file = os.path.join(train_dir, f_name) 116 | train_count += 1 117 | elif(split_type == "validation"): 118 | o_file = os.path.join(val_dir, f_name) 119 | val_count += 1 120 | elif(split_type == "test"): 121 | o_file = os.path.join(test_dir, f_name) 122 | test_count += 1 123 | 124 | prepped = midi_processor.encode_midi(mid) 125 | 126 | o_stream = open(o_file, "wb") 127 | pickle.dump(prepped, o_stream) 128 | o_stream.close() 129 | 130 | total_count += 1 131 | if(total_count % 50 == 0): 132 | print(total_count, "/", len(os.listdir(custom_midi_root))) 133 | 134 | print("Num Train:", train_count) 135 | print("Num Val:", val_count) 136 | print("Num Test:", test_count) 137 | return True 138 | 139 | 140 | # parse_args 141 | def parse_args(): 142 | """ 143 | ---------- 144 | Author: Damon Gwinn 145 | ---------- 146 | Parses arguments for preprocess_midi using argparse 147 | ---------- 148 | """ 149 | 150 | parser = argparse.ArgumentParser() 151 | 152 | parser.add_argument("root", type=str, help="Root folder for the Maestro dataset or for custom data.") 153 | parser.add_argument("-output_dir", type=str, default="./dataset/e_piano", help="Output folder to put the preprocessed midi into.") 154 | parser.add_argument("--custom_dataset", action="store_true", help="Whether or not the specified root folder contains custom data.") 155 | 156 | return parser.parse_args() 157 | 158 | # main 159 | def main(): 160 | """ 161 | ---------- 162 | Author: Damon Gwinn 163 | ---------- 164 | Entry point. Preprocesses maestro and saved midi to specified output folder. 
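For reference, each file that prep_maestro_midi and prep_custom_midi write out is a pickled Python list of the integer event indices returned by midi_processor.encode_midi (note-on, note-off, velocity, and time-shift events; see the ranges imported in utilities/constants.py). A minimal sketch for inspecting one of these files follows; the path is a hypothetical example and should be replaced with a real output of preprocess_midi.py:

```
import pickle

# Hypothetical example path; use any file produced by preprocess_midi.py
piece_path = "./dataset/e_piano/train/example_piece.midi.pickle"

with open(piece_path, "rb") as f:
    tokens = pickle.load(f)  # flat list of integer event indices

print("Number of events:", len(tokens))
print("First 10 event ids:", tokens[:10])
```

These are the sequences that dataset/e_piano.py later loads (via create_epiano_datasets) for training and evaluation.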
165 | ---------- 166 | """ 167 | 168 | args = parse_args() 169 | root = args.root 170 | output_dir = args.output_dir 171 | 172 | print("Preprocessing midi files and saving to", output_dir) 173 | if args.custom_dataset: 174 | prep_custom_midi(root, output_dir) 175 | else: 176 | prep_maestro_midi(root, output_dir) 177 | print("Done!") 178 | print("") 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /third_party/references.txt: -------------------------------------------------------------------------------- 1 | jason9693/midi-neural-processor 2 | https://github.com/jason9693/midi-neural-processor 3 | 4 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import shutil 4 | import torch 5 | import torch.nn as nn 6 | from torch.optim.lr_scheduler import LambdaLR 7 | from torch.utils.data import DataLoader 8 | from torch.optim import Adam 9 | 10 | from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy 11 | 12 | from model.music_transformer import MusicTransformer 13 | from model.loss import SmoothCrossEntropyLoss 14 | 15 | from utilities.constants import * 16 | from utilities.device import get_device, use_cuda 17 | from utilities.lr_scheduling import LrStepTracker, get_lr 18 | from utilities.argument_funcs import parse_train_args, print_train_args, write_model_params 19 | from utilities.run_model import train_epoch, eval_model 20 | 21 | CSV_HEADER = ["Epoch", "Learn rate", "Avg Train loss", "Train Accuracy", "Avg Eval loss", "Eval accuracy"] 22 | 23 | # Baseline is an untrained epoch that we evaluate as a baseline loss and accuracy 24 | BASELINE_EPOCH = -1 25 | 26 | # main 27 | def main(): 28 | """ 29 | ---------- 30 | Author: Damon Gwinn 31 | ---------- 32 | Entry point. 
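train.py below appends one row per epoch to results/results.csv using CSV_HEADER, with the untrained baseline evaluation recorded as epoch 0, and graph_results.py plots these files for comparison. For a quick look without graphing, here is a small sketch that pulls the best evaluation accuracy out of a finished run; the path is only an example, matching the rpr output directory used in the README:

```
import csv

# Example path; substitute the -output_dir used for training
results_path = "rpr/results/results.csv"

with open(results_path, "r", newline="") as f:
    rows = list(csv.DictReader(f))

best = max(rows, key=lambda r: float(r["Eval accuracy"]))
print("Best eval accuracy:", best["Eval accuracy"], "at epoch", best["Epoch"])
```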
Trains a model specified by command line arguments 33 | ---------- 34 | """ 35 | 36 | args = parse_train_args() 37 | print_train_args(args) 38 | 39 | if(args.force_cpu): 40 | use_cuda(False) 41 | print("WARNING: Forced CPU usage, expect model to perform slower") 42 | print("") 43 | 44 | os.makedirs(args.output_dir, exist_ok=True) 45 | 46 | ##### Output prep ##### 47 | params_file = os.path.join(args.output_dir, "model_params.txt") 48 | write_model_params(args, params_file) 49 | 50 | weights_folder = os.path.join(args.output_dir, "weights") 51 | os.makedirs(weights_folder, exist_ok=True) 52 | 53 | results_folder = os.path.join(args.output_dir, "results") 54 | os.makedirs(results_folder, exist_ok=True) 55 | 56 | results_file = os.path.join(results_folder, "results.csv") 57 | best_loss_file = os.path.join(results_folder, "best_loss_weights.pickle") 58 | best_acc_file = os.path.join(results_folder, "best_acc_weights.pickle") 59 | best_text = os.path.join(results_folder, "best_epochs.txt") 60 | 61 | ##### Tensorboard ##### 62 | if(args.no_tensorboard): 63 | tensorboard_summary = None 64 | else: 65 | from torch.utils.tensorboard import SummaryWriter 66 | 67 | tensorboad_dir = os.path.join(args.output_dir, "tensorboard") 68 | tensorboard_summary = SummaryWriter(log_dir=tensorboad_dir) 69 | 70 | ##### Datasets ##### 71 | train_dataset, val_dataset, test_dataset = create_epiano_datasets(args.input_dir, args.max_sequence) 72 | 73 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.n_workers, shuffle=True) 74 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.n_workers) 75 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.n_workers) 76 | 77 | model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads, 78 | d_model=args.d_model, dim_feedforward=args.dim_feedforward, dropout=args.dropout, 79 | max_sequence=args.max_sequence, rpr=args.rpr).to(get_device()) 80 | 81 | ##### Continuing from previous training session ##### 82 | start_epoch = BASELINE_EPOCH 83 | if(args.continue_weights is not None): 84 | if(args.continue_epoch is None): 85 | print("ERROR: Need epoch number to continue from (-continue_epoch) when using continue_weights") 86 | return 87 | else: 88 | model.load_state_dict(torch.load(args.continue_weights)) 89 | start_epoch = args.continue_epoch 90 | elif(args.continue_epoch is not None): 91 | print("ERROR: Need continue weights (-continue_weights) when using continue_epoch") 92 | return 93 | 94 | ##### Lr Scheduler vs static lr ##### 95 | if(args.lr is None): 96 | if(args.continue_epoch is None): 97 | init_step = 0 98 | else: 99 | init_step = args.continue_epoch * len(train_loader) 100 | 101 | lr = LR_DEFAULT_START 102 | lr_stepper = LrStepTracker(args.d_model, SCHEDULER_WARMUP_STEPS, init_step) 103 | else: 104 | lr = args.lr 105 | 106 | ##### Not smoothing evaluation loss ##### 107 | eval_loss_func = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD) 108 | 109 | ##### SmoothCrossEntropyLoss or CrossEntropyLoss for training ##### 110 | if(args.ce_smoothing is None): 111 | train_loss_func = eval_loss_func 112 | else: 113 | train_loss_func = SmoothCrossEntropyLoss(args.ce_smoothing, VOCAB_SIZE, ignore_index=TOKEN_PAD) 114 | 115 | ##### Optimizer ##### 116 | opt = Adam(model.parameters(), lr=lr, betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON) 117 | 118 | if(args.lr is None): 119 | lr_scheduler = LambdaLR(opt, lr_stepper.step) 120 | else: 121 | lr_scheduler = None 122 | 123 | ##### 
Tracking best evaluation accuracy ##### 124 | best_eval_acc = 0.0 125 | best_eval_acc_epoch = -1 126 | best_eval_loss = float("inf") 127 | best_eval_loss_epoch = -1 128 | 129 | ##### Results reporting ##### 130 | if(not os.path.isfile(results_file)): 131 | with open(results_file, "w", newline="") as o_stream: 132 | writer = csv.writer(o_stream) 133 | writer.writerow(CSV_HEADER) 134 | 135 | 136 | ##### TRAIN LOOP ##### 137 | for epoch in range(start_epoch, args.epochs): 138 | # Baseline has no training and acts as a base loss and accuracy (epoch 0 in a sense) 139 | if(epoch > BASELINE_EPOCH): 140 | print(SEPERATOR) 141 | print("NEW EPOCH:", epoch+1) 142 | print(SEPERATOR) 143 | print("") 144 | 145 | # Train 146 | train_epoch(epoch+1, model, train_loader, train_loss_func, opt, lr_scheduler, args.print_modulus) 147 | 148 | print(SEPERATOR) 149 | print("Evaluating:") 150 | else: 151 | print(SEPERATOR) 152 | print("Baseline model evaluation (Epoch 0):") 153 | 154 | # Eval 155 | train_loss, train_acc = eval_model(model, train_loader, train_loss_func) 156 | eval_loss, eval_acc = eval_model(model, test_loader, eval_loss_func) 157 | 158 | # Learn rate 159 | lr = get_lr(opt) 160 | 161 | print("Epoch:", epoch+1) 162 | print("Avg train loss:", train_loss) 163 | print("Avg train acc:", train_acc) 164 | print("Avg eval loss:", eval_loss) 165 | print("Avg eval acc:", eval_acc) 166 | print(SEPERATOR) 167 | print("") 168 | 169 | new_best = False 170 | 171 | if(eval_acc > best_eval_acc): 172 | best_eval_acc = eval_acc 173 | best_eval_acc_epoch = epoch+1 174 | torch.save(model.state_dict(), best_acc_file) 175 | new_best = True 176 | 177 | if(eval_loss < best_eval_loss): 178 | best_eval_loss = eval_loss 179 | best_eval_loss_epoch = epoch+1 180 | torch.save(model.state_dict(), best_loss_file) 181 | new_best = True 182 | 183 | # Writing out new bests 184 | if(new_best): 185 | with open(best_text, "w") as o_stream: 186 | print("Best eval acc epoch:", best_eval_acc_epoch, file=o_stream) 187 | print("Best eval acc:", best_eval_acc, file=o_stream) 188 | print("") 189 | print("Best eval loss epoch:", best_eval_loss_epoch, file=o_stream) 190 | print("Best eval loss:", best_eval_loss, file=o_stream) 191 | 192 | 193 | if(not args.no_tensorboard): 194 | tensorboard_summary.add_scalar("Avg_CE_loss/train", train_loss, global_step=epoch+1) 195 | tensorboard_summary.add_scalar("Avg_CE_loss/eval", eval_loss, global_step=epoch+1) 196 | tensorboard_summary.add_scalar("Accuracy/train", train_acc, global_step=epoch+1) 197 | tensorboard_summary.add_scalar("Accuracy/eval", eval_acc, global_step=epoch+1) 198 | tensorboard_summary.add_scalar("Learn_rate/train", lr, global_step=epoch+1) 199 | tensorboard_summary.flush() 200 | 201 | if((epoch+1) % args.weight_modulus == 0): 202 | epoch_str = str(epoch+1).zfill(PREPEND_ZEROS_WIDTH) 203 | path = os.path.join(weights_folder, "epoch_" + epoch_str + ".pickle") 204 | torch.save(model.state_dict(), path) 205 | 206 | with open(results_file, "a", newline="") as o_stream: 207 | writer = csv.writer(o_stream) 208 | writer.writerow([epoch+1, lr, train_loss, train_acc, eval_loss, eval_acc]) 209 | 210 | # Sanity check just to make sure everything is gone 211 | if(not args.no_tensorboard): 212 | tensorboard_summary.flush() 213 | 214 | return 215 | 216 | 217 | if __name__ == "__main__": 218 | main() 219 | -------------------------------------------------------------------------------- /utilities/argument_funcs.py: -------------------------------------------------------------------------------- 1 | 
import argparse 2 | 3 | from .constants import SEPERATOR 4 | 5 | # parse_train_args 6 | def parse_train_args(): 7 | """ 8 | ---------- 9 | Author: Damon Gwinn 10 | ---------- 11 | Argparse arguments for training a model 12 | ---------- 13 | """ 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument("-input_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files") 18 | parser.add_argument("-output_dir", type=str, default="./saved_models", help="Folder to save model weights. Saves one every epoch") 19 | parser.add_argument("-weight_modulus", type=int, default=1, help="How often to save epoch weights (ex: value of 10 means save every 10 epochs)") 20 | parser.add_argument("-print_modulus", type=int, default=1, help="How often to print train results for a batch (batch loss, learn rate, etc.)") 21 | 22 | parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader") 23 | parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available") 24 | parser.add_argument("--no_tensorboard", action="store_true", help="Turns off tensorboard result reporting") 25 | 26 | parser.add_argument("-continue_weights", type=str, default=None, help="Model weights to continue training based on") 27 | parser.add_argument("-continue_epoch", type=int, default=None, help="Epoch the continue_weights model was at") 28 | 29 | parser.add_argument("-lr", type=float, default=None, help="Constant learn rate. Leave as None for a custom scheduler.") 30 | parser.add_argument("-ce_smoothing", type=float, default=None, help="Smoothing parameter for smoothed cross entropy loss (defaults to no smoothing)") 31 | parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use") 32 | parser.add_argument("-epochs", type=int, default=100, help="Number of epochs to use") 33 | 34 | parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations") 35 | parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider") 36 | parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use") 37 | parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention") 38 | parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)") 39 | 40 | parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer") 41 | 42 | parser.add_argument("-dropout", type=float, default=0.1, help="Dropout rate") 43 | 44 | return parser.parse_args() 45 | 46 | # print_train_args 47 | def print_train_args(args): 48 | """ 49 | ---------- 50 | Author: Damon Gwinn 51 | ---------- 52 | Prints training arguments 53 | ---------- 54 | """ 55 | 56 | print(SEPERATOR) 57 | print("input_dir:", args.input_dir) 58 | print("output_dir:", args.output_dir) 59 | print("weight_modulus:", args.weight_modulus) 60 | print("print_modulus:", args.print_modulus) 61 | print("") 62 | print("n_workers:", args.n_workers) 63 | print("force_cpu:", args.force_cpu) 64 | print("tensorboard:", not args.no_tensorboard) 65 | print("") 66 | print("continue_weights:", args.continue_weights) 67 | print("continue_epoch:", args.continue_epoch) 68 | print("") 69 | print("lr:", args.lr) 70 | print("ce_smoothing:", args.ce_smoothing) 71 | print("batch_size:", args.batch_size) 72 | 
print("epochs:", args.epochs) 73 | print("") 74 | print("rpr:", args.rpr) 75 | print("max_sequence:", args.max_sequence) 76 | print("n_layers:", args.n_layers) 77 | print("num_heads:", args.num_heads) 78 | print("d_model:", args.d_model) 79 | print("") 80 | print("dim_feedforward:", args.dim_feedforward) 81 | print("dropout:", args.dropout) 82 | print(SEPERATOR) 83 | print("") 84 | 85 | # parse_eval_args 86 | def parse_eval_args(): 87 | """ 88 | ---------- 89 | Author: Damon Gwinn 90 | ---------- 91 | Argparse arguments for evaluating a model 92 | ---------- 93 | """ 94 | 95 | parser = argparse.ArgumentParser() 96 | 97 | parser.add_argument("-dataset_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files") 98 | parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()") 99 | parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader") 100 | parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available") 101 | 102 | parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use") 103 | 104 | parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations") 105 | parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider in the model") 106 | parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use") 107 | parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention") 108 | parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)") 109 | 110 | parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer") 111 | 112 | return parser.parse_args() 113 | 114 | # print_eval_args 115 | def print_eval_args(args): 116 | """ 117 | ---------- 118 | Author: Damon Gwinn 119 | ---------- 120 | Prints evaluation arguments 121 | ---------- 122 | """ 123 | 124 | print(SEPERATOR) 125 | print("dataset_dir:", args.dataset_dir) 126 | print("model_weights:", args.model_weights) 127 | print("n_workers:", args.n_workers) 128 | print("force_cpu:", args.force_cpu) 129 | print("") 130 | print("batch_size:", args.batch_size) 131 | print("") 132 | print("rpr:", args.rpr) 133 | print("max_sequence:", args.max_sequence) 134 | print("n_layers:", args.n_layers) 135 | print("num_heads:", args.num_heads) 136 | print("d_model:", args.d_model) 137 | print("") 138 | print("dim_feedforward:", args.dim_feedforward) 139 | print(SEPERATOR) 140 | print("") 141 | 142 | # parse_generate_args 143 | def parse_generate_args(): 144 | """ 145 | ---------- 146 | Author: Damon Gwinn 147 | ---------- 148 | Argparse arguments for generation 149 | ---------- 150 | """ 151 | 152 | parser = argparse.ArgumentParser() 153 | 154 | parser.add_argument("-midi_root", type=str, default="./dataset/e_piano/", help="Midi file to prime the generator with") 155 | parser.add_argument("-output_dir", type=str, default="./gen", help="Folder to write generated midi to") 156 | parser.add_argument("-primer_file", type=str, default=None, help="File path or integer index to the evaluation dataset. 
Default is to select a random index.") 157 | parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available") 158 | 159 | parser.add_argument("-target_seq_length", type=int, default=1024, help="Target length you'd like the midi to be") 160 | parser.add_argument("-num_prime", type=int, default=256, help="Amount of messages to prime the generator with") 161 | parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()") 162 | parser.add_argument("-beam", type=int, default=0, help="Beam search k. 0 for random probability sample and 1 for greedy") 163 | 164 | parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations") 165 | parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider") 166 | parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use") 167 | parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention") 168 | parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)") 169 | 170 | parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer") 171 | 172 | return parser.parse_args() 173 | 174 | # print_generate_args 175 | def print_generate_args(args): 176 | """ 177 | ---------- 178 | Author: Damon Gwinn 179 | ---------- 180 | Prints generation arguments 181 | ---------- 182 | """ 183 | 184 | print(SEPERATOR) 185 | print("midi_root:", args.midi_root) 186 | print("output_dir:", args.output_dir) 187 | print("primer_file:", args.primer_file) 188 | print("force_cpu:", args.force_cpu) 189 | print("") 190 | print("target_seq_length:", args.target_seq_length) 191 | print("num_prime:", args.num_prime) 192 | print("model_weights:", args.model_weights) 193 | print("beam:", args.beam) 194 | print("") 195 | print("rpr:", args.rpr) 196 | print("max_sequence:", args.max_sequence) 197 | print("n_layers:", args.n_layers) 198 | print("num_heads:", args.num_heads) 199 | print("d_model:", args.d_model) 200 | print("") 201 | print("dim_feedforward:", args.dim_feedforward) 202 | print(SEPERATOR) 203 | print("") 204 | 205 | # write_model_params 206 | def write_model_params(args, output_file): 207 | """ 208 | ---------- 209 | Author: Damon Gwinn 210 | ---------- 211 | Writes given training parameters to text file 212 | ---------- 213 | """ 214 | 215 | o_stream = open(output_file, "w") 216 | 217 | o_stream.write("rpr: " + str(args.rpr) + "\n") 218 | o_stream.write("lr: " + str(args.lr) + "\n") 219 | o_stream.write("ce_smoothing: " + str(args.ce_smoothing) + "\n") 220 | o_stream.write("batch_size: " + str(args.batch_size) + "\n") 221 | o_stream.write("max_sequence: " + str(args.max_sequence) + "\n") 222 | o_stream.write("n_layers: " + str(args.n_layers) + "\n") 223 | o_stream.write("num_heads: " + str(args.num_heads) + "\n") 224 | o_stream.write("d_model: " + str(args.d_model) + "\n") 225 | o_stream.write("dim_feedforward: " + str(args.dim_feedforward) + "\n") 226 | o_stream.write("dropout: " + str(args.dropout) + "\n") 227 | 228 | o_stream.close() 229 | -------------------------------------------------------------------------------- /utilities/constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 
| from third_party.midi_processor.processor import RANGE_NOTE_ON, RANGE_NOTE_OFF, RANGE_VEL, RANGE_TIME_SHIFT 4 | 5 | SEPERATOR = "=========================" 6 | 7 | # Taken from the paper 8 | ADAM_BETA_1 = 0.9 9 | ADAM_BETA_2 = 0.98 10 | ADAM_EPSILON = 10e-9 11 | 12 | LR_DEFAULT_START = 1.0 13 | SCHEDULER_WARMUP_STEPS = 4000 14 | # LABEL_SMOOTHING_E = 0.1 15 | 16 | # DROPOUT_P = 0.1 17 | 18 | TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT 19 | TOKEN_PAD = TOKEN_END + 1 20 | 21 | VOCAB_SIZE = TOKEN_PAD + 1 22 | 23 | TORCH_FLOAT = torch.float32 24 | TORCH_INT = torch.int32 25 | 26 | TORCH_LABEL_TYPE = torch.long 27 | 28 | PREPEND_ZEROS_WIDTH = 4 29 | -------------------------------------------------------------------------------- /utilities/device.py: -------------------------------------------------------------------------------- 1 | # For all things related to devices 2 | #### ONLY USE PROVIDED FUNCTIONS, DO NOT USE GLOBAL CONSTANTS #### 3 | 4 | import torch 5 | 6 | TORCH_CPU_DEVICE = torch.device("cpu") 7 | 8 | if(torch.cuda.device_count() > 0): 9 | TORCH_CUDA_DEVICE = torch.device("cuda") 10 | else: 11 | print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----") 12 | print("") 13 | TORCH_CUDA_DEVICE = None 14 | 15 | USE_CUDA = True 16 | 17 | # use_cuda 18 | def use_cuda(cuda_bool): 19 | """ 20 | ---------- 21 | Author: Damon Gwinn 22 | ---------- 23 | Sets whether to use CUDA (if available), or use the CPU (not recommended) 24 | ---------- 25 | """ 26 | 27 | global USE_CUDA 28 | USE_CUDA = cuda_bool 29 | 30 | # get_device 31 | def get_device(): 32 | """ 33 | ---------- 34 | Author: Damon Gwinn 35 | ---------- 36 | Grabs the default device. Default device is CUDA if available and use_cuda is not False, CPU otherwise. 37 | ---------- 38 | """ 39 | 40 | if((not USE_CUDA) or (TORCH_CUDA_DEVICE is None)): 41 | return TORCH_CPU_DEVICE 42 | else: 43 | return TORCH_CUDA_DEVICE 44 | 45 | # cuda_device 46 | def cuda_device(): 47 | """ 48 | ---------- 49 | Author: Damon Gwinn 50 | ---------- 51 | Grabs the cuda device (may be None if CUDA is not available) 52 | ---------- 53 | """ 54 | 55 | return TORCH_CUDA_DEVICE 56 | 57 | # cpu_device 58 | def cpu_device(): 59 | """ 60 | ---------- 61 | Author: Damon Gwinn 62 | ---------- 63 | Grabs the cpu device 64 | ---------- 65 | """ 66 | 67 | return TORCH_CPU_DEVICE 68 | -------------------------------------------------------------------------------- /utilities/lr_scheduling.py: -------------------------------------------------------------------------------- 1 | #Library Imports 2 | import math 3 | 4 | #Using Adam optimizer with 5 | #Beta_1=0.9, Beta_2=0.98, and Epsilon=10^-9 6 | 7 | #Learning rate varies over course of training 8 | #lrate = sqrt(d_model)*min((1/sqrt(step_num)), step_num*(1/warmup_steps*sqrt(warmup_steps))) 9 | 10 | # LrStepTracker 11 | class LrStepTracker: 12 | """ 13 | ---------- 14 | Author: Ryan Marshall 15 | Modified: Damon Gwinn 16 | ---------- 17 | Class for custom learn rate scheduler (to be used by torch.optim.lr_scheduler.LambdaLR). 
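Concretely, LrStepTracker implements the warm-up schedule from Attention Is All You Need: lr = (1 / sqrt(d_model)) * min(1 / sqrt(step), step * warmup_steps^-1.5), which ramps up linearly for the first SCHEDULER_WARMUP_STEPS batches and then decays with the inverse square root of the step count. Because train.py sets the optimizer's base rate to LR_DEFAULT_START = 1.0 and hands LrStepTracker.step to LambdaLR, the value this formula returns is the effective learning rate. A standalone sketch of the same schedule (the helper name is illustrative only, not part of the repo):

```
import math

# Same schedule LrStepTracker.step computes (for step >= 1)
def warmup_lr(step, d_model=512, warmup_steps=4000):
    return (1 / math.sqrt(d_model)) * min(1 / math.sqrt(step),
                                          step * warmup_steps ** -1.5)

for step in (1, 1000, 4000, 10000, 40000):
    print(step, warmup_lr(step))
```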
18 | 19 | Learn rate for each step (batch) given the warmup steps is: 20 | lr = [ 1/sqrt(d_model) ] * min[ 1/sqrt(step) , step * (warmup_steps)^-1.5 ] 21 | 22 | This is from Attention is All you Need (https://arxiv.org/abs/1706.03762) 23 | ---------- 24 | """ 25 | 26 | def __init__(self, model_dim=512, warmup_steps=4000, init_steps=0): 27 | # Store Values 28 | self.warmup_steps = warmup_steps 29 | self.model_dim = model_dim 30 | self.init_steps = init_steps 31 | 32 | # Begin Calculations 33 | self.invsqrt_dim = (1 / math.sqrt(model_dim)) 34 | self.invsqrt_warmup = (1 / (warmup_steps * math.sqrt(warmup_steps))) 35 | 36 | # step 37 | def step(self, step): 38 | """ 39 | ---------- 40 | Author: Ryan Marshall 41 | Modified: Damon Gwinn 42 | ---------- 43 | Method to pass to LambdaLR. Increments the step and computes the new learn rate. 44 | ---------- 45 | """ 46 | 47 | step += self.init_steps 48 | if(step <= self.warmup_steps): 49 | return self.invsqrt_dim * self.invsqrt_warmup * step 50 | else: 51 | invsqrt_step = (1 / math.sqrt(step)) 52 | return self.invsqrt_dim * invsqrt_step 53 | 54 | # get_lr 55 | def get_lr(optimizer): 56 | """ 57 | ---------- 58 | Author: Damon Gwinn 59 | ---------- 60 | Hack to get the current learn rate of the model 61 | ---------- 62 | """ 63 | 64 | for param_group in optimizer.param_groups: 65 | return param_group['lr'] 66 | -------------------------------------------------------------------------------- /utilities/run_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | from .constants import * 5 | from utilities.device import get_device 6 | from .lr_scheduling import get_lr 7 | 8 | from dataset.e_piano import compute_epiano_accuracy 9 | 10 | 11 | # train_epoch 12 | def train_epoch(cur_epoch, model, dataloader, loss, opt, lr_scheduler=None, print_modulus=1): 13 | """ 14 | ---------- 15 | Author: Damon Gwinn 16 | ---------- 17 | Trains a single model epoch 18 | ---------- 19 | """ 20 | 21 | out = -1 22 | model.train() 23 | for batch_num, batch in enumerate(dataloader): 24 | time_before = time.time() 25 | 26 | opt.zero_grad() 27 | 28 | x = batch[0].to(get_device()) 29 | tgt = batch[1].to(get_device()) 30 | 31 | y = model(x) 32 | 33 | y = y.reshape(y.shape[0] * y.shape[1], -1) 34 | tgt = tgt.flatten() 35 | 36 | out = loss.forward(y, tgt) 37 | 38 | out.backward() 39 | opt.step() 40 | 41 | if(lr_scheduler is not None): 42 | lr_scheduler.step() 43 | 44 | time_after = time.time() 45 | time_took = time_after - time_before 46 | 47 | if((batch_num+1) % print_modulus == 0): 48 | print(SEPERATOR) 49 | print("Epoch", cur_epoch, " Batch", batch_num+1, "/", len(dataloader)) 50 | print("LR:", get_lr(opt)) 51 | print("Train loss:", float(out)) 52 | print("") 53 | print("Time (s):", time_took) 54 | print(SEPERATOR) 55 | print("") 56 | 57 | return 58 | 59 | # eval_model 60 | def eval_model(model, dataloader, loss): 61 | """ 62 | ---------- 63 | Author: Damon Gwinn 64 | ---------- 65 | Evaluates the model and prints the average loss and accuracy 66 | ---------- 67 | """ 68 | 69 | model.eval() 70 | 71 | avg_acc = -1 72 | avg_loss = -1 73 | with torch.set_grad_enabled(False): 74 | n_test = len(dataloader) 75 | sum_loss = 0.0 76 | sum_acc = 0.0 77 | for batch in dataloader: 78 | x = batch[0].to(get_device()) 79 | tgt = batch[1].to(get_device()) 80 | 81 | y = model(x) 82 | 83 | sum_acc += float(compute_epiano_accuracy(y, tgt)) 84 | 85 | y = y.reshape(y.shape[0] * y.shape[1], -1) 86 | tgt = tgt.flatten() 87 | 
88 | out = loss.forward(y, tgt) 89 | 90 | sum_loss += float(out) 91 | 92 | avg_loss = sum_loss / n_test 93 | avg_acc = sum_acc / n_test 94 | 95 | return avg_loss, avg_acc 96 | --------------------------------------------------------------------------------
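Putting the pieces together, evaluating a trained checkpoint amounts to rebuilding the datasets, constructing a MusicTransformer with the same hyperparameters used for training, loading the saved state dict, and calling eval_model. The sketch below is only an outline of what evaluate.py wires up through parse_eval_args; the weight path and hyperparameter values are examples and must match your own run:

```
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset.e_piano import create_epiano_datasets
from model.music_transformer import MusicTransformer
from utilities.constants import TOKEN_PAD
from utilities.device import get_device
from utilities.run_model import eval_model

# Example values; these must match the ones used when training the checkpoint
_, _, test_dataset = create_epiano_datasets("./dataset/e_piano", 2048)
test_loader = DataLoader(test_dataset, batch_size=2)

model = MusicTransformer(n_layers=6, num_heads=8, d_model=512,
                         dim_feedforward=1024, dropout=0.1,
                         max_sequence=2048, rpr=True).to(get_device())
model.load_state_dict(torch.load("rpr/results/best_acc_weights.pickle",
                                 map_location=get_device()))

eval_loss_func = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD)
avg_loss, avg_acc = eval_model(model, test_loader, eval_loss_func)
print("Avg eval loss:", avg_loss)
print("Avg eval acc:", avg_acc)
```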