├── .gitignore ├── Beam.py ├── DataPrep.py ├── Embedding.py ├── Encode_Decode_Layers.py ├── MaskGen.py ├── Models.py ├── PositionalFeedForward.py ├── Process.py ├── README.md ├── Sublayers.py ├── datasets └── JSB-Chorales-dataset │ └── Jsb16thSeparated.npz ├── gen_sequence_audio.ipynb ├── gen_train_audio.ipynb ├── gen_training_plots.py ├── generate.py ├── outputs ├── gen_output │ ├── gen_output_baselineTF_300epoch.npy │ ├── gen_output_baselineTFconcat_1000epoch.npy │ ├── gen_output_baselineTFconcat_300epoch.npy │ └── gen_output_relativeTF_300epoch.npy └── loss │ ├── t_loss_baselineTF_300epoch.npy │ ├── t_loss_baselineTFconcat_300epoch.npy │ ├── t_loss_final_baselineTFconcat_1000epoch.npy │ ├── t_loss_relativeTF_300epoch.npy │ ├── v_loss_baselineTF_300epoch.npy │ ├── v_loss_baselineTFconcat_300epoch.npy │ ├── v_loss_final_baselineTFconcat_1000epoch.npy │ └── v_loss_relativeTF_300epoch.npy ├── plots ├── BaselineTF_300epochs.png ├── BaselineTFconcat_300epochs.png ├── relativeTF_300epochs.png ├── trainingLoss.png └── validationLoss.png ├── report.pdf └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # maestro dataset 107 | datasets/maestro 108 | -------------------------------------------------------------------------------- /Beam.py: -------------------------------------------------------------------------------- 1 | # Filename: Beam.py 2 | # Date Created: 15-Mar-2019 2:42:12 pm 3 | # Description: Functions used for beam search. 
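# Overview: init_vars runs the encoder/decoder once on the starting pitches and
# seeds k beams from the top-k first predictions; k_best_outputs then expands
# every beam at each step and keeps the k partial sequences with the highest
# accumulated log-probability; beam_search repeats this up to opt.max_seq_len
# and returns the highest-scoring beam.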
4 | import torch 5 | import torch.nn.functional as F 6 | from MaskGen import nopeak_mask 7 | import math 8 | 9 | def init_vars(src, model, opt): 10 | # outputs used for the decoder is the starting pitches 11 | outputs = src 12 | 13 | # encoder pass 14 | src_mask = (src != opt.pad_token).unsqueeze(-2).to(opt.device) 15 | e_output = model.encoder(src, src_mask) 16 | 17 | # decoder pass 18 | trg_mask = nopeak_mask(src.shape[1], opt) 19 | out = model.decoder(outputs, e_output, src_mask, trg_mask) 20 | out = model.linear(out) 21 | 22 | # final fc layer 23 | out = F.softmax(out, dim=-1) 24 | 25 | # calculate probablites for beam search 26 | # takes the last output from the model, hence out[:, -1] 27 | probs, ix = out[:, -1].data.topk(opt.k) 28 | log_scores = torch.Tensor([math.log(prob) for prob in probs.data[0]]).unsqueeze(0) 29 | 30 | # store the model outputs 31 | outputs = torch.zeros(opt.k, opt.max_seq_len).long().to(opt.device) 32 | outputs[:, 0:src.shape[1]] = src 33 | outputs[:, src.shape[1]] = ix[0] 34 | 35 | # store the encoder output to be used later 36 | e_outputs = torch.zeros(opt.k, e_output.size(-2),e_output.size(-1)).to(opt.device) 37 | e_outputs[:, :] = e_output[0] 38 | 39 | return outputs, e_outputs, log_scores 40 | 41 | def k_best_outputs(outputs, out, log_scores, i, k): 42 | # calculate probablities for each step in the sequence 43 | probs, ix = out[:, -1].data.topk(k) 44 | log_probs = torch.Tensor([math.log(p) for p in probs.data.view(-1)]).view(k, -1) + log_scores.transpose(0,1) 45 | k_probs, k_ix = log_probs.view(-1).topk(k) 46 | 47 | row = k_ix // k 48 | col = k_ix % k 49 | 50 | # update outputs 51 | outputs[:, :i] = outputs[row, :i] 52 | outputs[:, i] = ix[row, col] 53 | 54 | log_scores = k_probs.unsqueeze(0) 55 | 56 | return outputs, log_scores 57 | 58 | def beam_search(src, model, opt): 59 | outputs, e_outputs, log_scores = init_vars(src, model, opt) 60 | init_start_len = outputs.shape[0] 61 | src_mask = (src != opt.pad_token).unsqueeze(-2).to(opt.device) 62 | 63 | for i in range(init_start_len, opt.max_seq_len): 64 | # Just comment this block of code if only use encoder once at the start 65 | src_mask = (outputs[0,:i].unsqueeze(-2) != opt.pad_token).unsqueeze(-2).to(opt.device) 66 | e_output = model.encoder(outputs[0,:i].unsqueeze(-2), src_mask) 67 | e_outputs = torch.zeros(opt.k, e_output.size(-2),e_output.size(-1)).to(opt.device) 68 | e_outputs[:, :] = e_output[0] 69 | 70 | trg_mask = nopeak_mask(i, opt) 71 | out = model.linear(model.decoder(outputs[:,:i], 72 | e_outputs, src_mask, trg_mask)) 73 | out = F.softmax(out, dim=-1) 74 | 75 | outputs, log_scores = k_best_outputs(outputs, out, log_scores, i, opt.k) 76 | 77 | # return the one with the largest log_scores 78 | return outputs[0] 79 | -------------------------------------------------------------------------------- /DataPrep.py: -------------------------------------------------------------------------------- 1 | # Filename: DataPrep.py 2 | # Date Created: 08-Mar-2019 10:01:18 pm 3 | # Description: Functions for preparing the dataset for training and evaluation. 4 | 5 | import torch 6 | from torch.utils.data.dataset import Dataset 7 | import numpy as np 8 | from torch.autograd import Variable 9 | 10 | def tensorFromSequence(sequence): 11 | """ 12 | Generate tensors from the sequence in numpy. 13 | """ 14 | output = torch.tensor(sequence).long() 15 | 16 | return output 17 | 18 | 19 | def PrepareData(npz_file, split='train', L=1024): 20 | """ 21 | Function to prepare the data into pairs (input, target). 
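    Each chorale in the .npz file is a [timesteps, 4] array of pitches (four
    voices per time step, nominally S, A, T, B); it is flattened row-major into
    S_1 A_1 T_1 B_1 S_2 A_2 T_2 B_2 ..., and the target is the input shifted
    one step to the left.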
22 | Adds [PAD], [SOS] and [EOS] tokens into the data, 23 | where [PAD]=1, [SOS]=2, [EOS]=3. 24 | Limits the sequence to length of L. 25 | """ 26 | print("Preparing data for",split,"split...") 27 | # Load in the data 28 | full_data = np.load(npz_file, fix_imports=True, encoding="latin1", allow_pickle=True) 29 | data = full_data[split] 30 | 31 | # Extract the vocab from file 32 | vocab = GenerateVocab(npz_file) 33 | # Generate new vocab to map to later 34 | new_vocab = np.arange(len(vocab)) 35 | 36 | # Initialize the tokens 37 | pad_token = np.array([[1]]) 38 | 39 | # Repeat for all samples in data 40 | pairs = [] 41 | for samples in data: 42 | # Serialise the dataset so that the resulting sequence is 43 | # S_1 A_1 T_1, B_2 S_2 A_2 T_2 B_2, ... 44 | 45 | # Generate input 46 | input_seq = samples.flatten() 47 | 48 | # Cut off the samples so that it has length of 1024 49 | if(len(input_seq) >= L): 50 | # input_seq = input_seq[:L-1] 51 | input_seq = input_seq[:L] 52 | 53 | # Set the NaN values to 0 and reshape accordingly 54 | input_seq = np.nan_to_num(input_seq.reshape(1,input_seq.size)) 55 | 56 | # Generate target 57 | output_seq = input_seq[:,1:] 58 | 59 | # For both sequences, pad to sequence length L 60 | pad_array = pad_token * np.ones((1,L-input_seq.shape[1])) 61 | input_seq = np.append(input_seq, pad_array,axis=1) 62 | pad_array = pad_token * np.ones((1,L-output_seq.shape[1])) 63 | output_seq = np.append(output_seq, pad_array,axis=1) 64 | 65 | # Map the pitch value to int values below vocab size 66 | for i, val in enumerate(vocab): 67 | input_seq[input_seq==val] = new_vocab[i] 68 | output_seq[output_seq==val] = new_vocab[i] 69 | 70 | # Make it into a pair 71 | pair = [input_seq, output_seq] 72 | 73 | # Combine all pairs into one big list of pairs 74 | pairs.append(pair) 75 | 76 | print("Generated data pairs.") 77 | return np.array(pairs) 78 | 79 | def GenerateVocab(npz_file): 80 | """ 81 | Generate vocabulary for the dataset including the custom tokens. 82 | """ 83 | full_data = np.load(npz_file, fix_imports=True, encoding="latin1", allow_pickle=True) 84 | train_data = full_data['train'] 85 | validation_data = full_data['valid'] 86 | test_data = full_data['test'] 87 | 88 | combined_data = np.concatenate((train_data, validation_data, test_data)) 89 | 90 | vocab = np.nan 91 | for sequences in combined_data: 92 | vocab = np.append(vocab,np.unique(sequences)) 93 | 94 | vocab = np.unique(vocab) 95 | vocab = vocab[~np.isnan(vocab)] 96 | vocab = np.append([0,1],vocab) 97 | return vocab 98 | -------------------------------------------------------------------------------- /Embedding.py: -------------------------------------------------------------------------------- 1 | # Filename: Embedding.py 2 | # Date Created: 08-Mar-2019 4:38:59 pm 3 | # Description: Embedding method before input to encoder. 4 | # Includes basic embedding, positional encoding, 5 | # and concatenating positional encoding. 
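# The sinusoidal encoding built below follows Vaswani et al. (2017):
#   PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
# PositionalEncoder adds it to the (scaled) embeddings, while
# PositionalEncoderConcat concatenates it, so the effective model width
# downstream becomes 2 * d_model.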
6 | import torch 7 | import torch.nn as nn 8 | import math 9 | from torch.autograd import Variable 10 | 11 | class Embedder(nn.Module): 12 | """ 13 | vocab_size = size of the dictionary of embeddings 14 | d_model = size of each embedding vectors 15 | """ 16 | def __init__(self, vocab_size, d_model): 17 | super().__init__() 18 | self.d_model = d_model 19 | self.embed = nn.Embedding(vocab_size, d_model) 20 | 21 | def forward(self, x): 22 | # make embeddings relatively larger 23 | 24 | return self.embed(x) * math.sqrt(self.d_model) 25 | 26 | class PositionalEncoder(nn.Module): 27 | "Implement the PE function with addition." 28 | def __init__(self, d_model, dropout = 0.1, max_seq_len = 1024): 29 | super().__init__() 30 | self.d_model = d_model 31 | self.dropout = nn.Dropout(dropout) 32 | 33 | pe = torch.zeros(max_seq_len, d_model) 34 | position = torch.arange(0., max_seq_len).unsqueeze(1) 35 | div_term = torch.exp(torch.arange(0., d_model, 2) * 36 | -(math.log(10000.0) / d_model)) 37 | pe[:, 0::2] = torch.sin(position * div_term) 38 | pe[:, 1::2] = torch.cos(position * div_term) 39 | pe = pe.unsqueeze(0) 40 | self.register_buffer('pe', pe) 41 | 42 | def forward(self, x): 43 | 44 | pe = self.pe[:, :x.size(1)] 45 | pe = pe.repeat((x.shape[0],1,1)) 46 | pe = pe.unsqueeze(-2) 47 | 48 | x = x + pe 49 | 50 | return self.dropout(x) 51 | 52 | class PositionalEncoderConcat(nn.Module): 53 | "Implement the PE function with concatenation istead." 54 | "Output dimension will be (1,N,d_model*2)" 55 | def __init__(self, d_model, dropout = 0.0, max_seq_len = 1024): 56 | super().__init__() 57 | self.d_model = d_model 58 | self.dropout = nn.Dropout(dropout) 59 | 60 | pe = torch.zeros(max_seq_len, d_model) 61 | position = torch.arange(0., max_seq_len).unsqueeze(1) 62 | div_term = torch.exp(torch.arange(0., d_model, 2) * 63 | -(math.log(10000.0) / d_model)) 64 | pe[:, 0::2] = torch.sin(position * div_term) 65 | pe[:, 1::2] = torch.cos(position * div_term) 66 | pe = pe.unsqueeze(0) 67 | self.register_buffer('pe', pe) 68 | 69 | def forward(self, x): 70 | # Concatenate embeddings with positional sinusoid 71 | pe = Variable(self.pe[:, :x.size(1)], requires_grad=False) 72 | pe = pe.repeat((x.shape[0],1,1)) 73 | #print(pe.shape, x.shape) 74 | x = torch.cat((x,pe),2) 75 | 76 | return self.dropout(x) 77 | -------------------------------------------------------------------------------- /Encode_Decode_Layers.py: -------------------------------------------------------------------------------- 1 | # Filename: Encode_Decode_Layers.py 2 | # Date Created: 15-Mar-2019 2:42:12 pm 3 | # Description: Attention layers in encoder and decoder layer. 
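# Both layers use pre-layer-norm residual sub-blocks, i.e.
#   x = x + Dropout(SubLayer(Norm(x)))
# EncoderLayer: self-attention -> feed-forward.
# DecoderLayer: masked self-attention -> encoder-decoder attention -> feed-forward.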
4 | import torch 5 | import torch.nn as nn 6 | from Sublayers import MultiHeadAttention, Norm 7 | from PositionalFeedForward import * 8 | 9 | class EncoderLayer(nn.Module): 10 | def __init__(self, d_model, heads, d_ff = 1024, dropout = 0.1, attention_type = "Baseline", relative_time_pitch = False, max_relative_position = 512): 11 | super().__init__() 12 | self.norm_1 = Norm(d_model) # create normalisation sublayer with size d_model 13 | self.norm_2 = Norm(d_model) 14 | self.attention_type = attention_type 15 | self.relative_time_pitch = relative_time_pitch 16 | 17 | self.attn = MultiHeadAttention(heads, d_model, dropout = dropout, attention_type = self.attention_type, \ 18 | relative_time_pitch = self.relative_time_pitch, 19 | max_relative_position = max_relative_position) 20 | self.ff = FeedForward(d_model, d_ff, dropout) 21 | 22 | self.dropout_1 = nn.Dropout(dropout) 23 | self.dropout_2 = nn.Dropout(dropout) 24 | 25 | def forward(self, x, mask): 26 | x2 = self.norm_1(x) 27 | x = x + self.dropout_1(self.attn(x2,x2,x2,mask)) 28 | 29 | x2 = self.norm_2(x) 30 | x = x + self.dropout_2(self.ff(x2)) 31 | 32 | return x 33 | 34 | class DecoderLayer(nn.Module): 35 | def __init__(self, d_model, heads, d_ff = 1024, dropout=0.1, attention_type = "Baseline", relative_time_pitch = False, max_relative_position = 512): 36 | super().__init__() 37 | self.norm_1 = Norm(d_model) 38 | self.norm_2 = Norm(d_model) 39 | self.norm_3 = Norm(d_model) 40 | 41 | self.attention_type = attention_type 42 | self.relative_time_pitch = relative_time_pitch 43 | self.dropout_1 = nn.Dropout(dropout) 44 | self.dropout_2 = nn.Dropout(dropout) 45 | self.dropout_3 = nn.Dropout(dropout) 46 | 47 | self.attn_1 = MultiHeadAttention(heads, d_model, dropout = dropout, attention_type = self.attention_type, \ 48 | relative_time_pitch = self.relative_time_pitch, 49 | max_relative_position = max_relative_position) 50 | self.attn_2 = MultiHeadAttention(heads, d_model, dropout =dropout, attention_type = self.attention_type, \ 51 | relative_time_pitch = self.relative_time_pitch, 52 | max_relative_position = max_relative_position) 53 | self.ff = FeedForward(d_model, d_ff, dropout) 54 | def forward(self, x, e_outputs, src_mask, trg_mask): 55 | x2 = self.norm_1(x) 56 | x = x + self.dropout_1(self.attn_1(x2, x2, x2, trg_mask)) 57 | x2 = self.norm_2(x) 58 | x = x + self.dropout_2(self.attn_2(x2, e_outputs, e_outputs, 59 | src_mask)) 60 | x2 = self.norm_3(x) 61 | x = x + self.dropout_3(self.ff(x2)) 62 | 63 | return x 64 | 65 | # A convenient cloning function that can generate multiple layers: 66 | def get_clones(module, N): 67 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 68 | -------------------------------------------------------------------------------- /MaskGen.py: -------------------------------------------------------------------------------- 1 | # Filename: MaskGen.py 2 | # Date Created: 15-Mar-2019 2:42:12 pm 3 | # Description: Functions used to generate masks w.r.t. given inputs. 
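# Example: nopeak_mask(3, opt) returns a (1, 3, 3) boolean tensor
#   [[[ True, False, False],
#     [ True,  True, False],
#     [ True,  True,  True]]]
# where True means the position may be attended to; create_masks combines this
# causal mask with the target padding mask.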
4 | import torch 5 | import numpy as np 6 | from torch.autograd import Variable 7 | 8 | def nopeak_mask(size, opt): 9 | np_mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8') 10 | np_mask = Variable(torch.from_numpy(np_mask) == 0).to(opt.device) 11 | return np_mask 12 | 13 | def create_masks(src, trg, opt): 14 | src_mask = (src != opt.pad_token).unsqueeze(-2) 15 | 16 | if trg is not None: 17 | trg_mask = (trg != opt.pad_token).unsqueeze(-2) 18 | size = trg.size(1) # get seq_len for matrix 19 | np_mask = nopeak_mask(size, opt) 20 | trg_mask = trg_mask & np_mask 21 | else: 22 | trg_mask = None 23 | 24 | return src_mask, trg_mask 25 | -------------------------------------------------------------------------------- /Models.py: -------------------------------------------------------------------------------- 1 | # Filename: Models.py 2 | # Date Created: 16-Mar-2019 2:17:09 pm 3 | # Description: Combine all sublayers into one tranformer model. 4 | import torch 5 | import torch.nn as nn 6 | from Encode_Decode_Layers import EncoderLayer, DecoderLayer 7 | from Embedding import Embedder, PositionalEncoder, PositionalEncoderConcat 8 | from Sublayers import Norm 9 | import torch.nn.functional as F 10 | import copy 11 | 12 | def get_clones(module, N): 13 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 14 | 15 | class Encoder(nn.Module): 16 | def __init__(self, vocab_size, opt): 17 | super().__init__() 18 | self.N = opt.n_layers 19 | self.embed = Embedder(vocab_size, opt.d_model) 20 | if opt.concat_pos_sinusoid is True: 21 | self.pe = PositionalEncoderConcat(opt.d_model, opt.dropout, opt.max_seq_len) 22 | self.d_model = 2 * opt.d_model 23 | else: 24 | self.pe = PositionalEncoder(opt.d_model, opt.dropout, opt.max_seq_len) 25 | self.d_model = opt.d_model 26 | 27 | if opt.relative_time_pitch is True: 28 | self.layers = get_clones(EncoderLayer(self.d_model, opt.heads, opt.d_ff, \ 29 | opt.dropout, opt.attention_type, \ 30 | opt.relative_time_pitch, 31 | max_relative_position = opt.max_relative_position), 32 | opt.n_layers) 33 | self.layers.insert(0, copy.deepcopy(EncoderLayer(self.d_model, opt.heads, opt.d_ff, \ 34 | opt.dropout, opt.attention_type, \ 35 | relative_time_pitch = False, 36 | max_relative_position = opt.max_relative_position))) 37 | else: 38 | self.layers = get_clones(EncoderLayer(self.d_model, opt.heads, opt.d_ff, \ 39 | opt.dropout, opt.attention_type, \ 40 | opt.relative_time_pitch, 41 | max_relative_position = opt.max_relative_position), 42 | opt.n_layers) 43 | self.norm = Norm(self.d_model) 44 | 45 | def forward(self, src, mask): 46 | x = self.embed(src) 47 | x = self.pe(x) 48 | 49 | for i in range(self.N): 50 | x = self.layers[i](x.float(), mask) 51 | return self.norm(x) 52 | 53 | class Decoder(nn.Module): 54 | def __init__(self, vocab_size, opt): 55 | super().__init__() 56 | self.N = opt.n_layers 57 | self.embed = Embedder(vocab_size, opt.d_model) 58 | if opt.concat_pos_sinusoid is True: 59 | self.pe = PositionalEncoderConcat(opt.d_model, opt.dropout, opt.max_seq_len) 60 | self.d_model = 2 * opt.d_model 61 | else: 62 | self.pe = PositionalEncoder(opt.d_model, opt.dropout, opt.max_seq_len) 63 | self.d_model = opt.d_model 64 | 65 | if opt.relative_time_pitch is True: 66 | self.layers = get_clones(DecoderLayer(self.d_model, opt.heads, opt.d_ff, \ 67 | opt.dropout, opt.attention_type, \ 68 | opt.relative_time_pitch, 69 | max_relative_position = opt.max_relative_position), 70 | opt.n_layers-1) 71 | self.layers.insert(0, copy.deepcopy(DecoderLayer(self.d_model, 
opt.heads, opt.d_ff, \ 72 | opt.dropout, opt.attention_type, \ 73 | relative_time_pitch = False, 74 | max_relative_position = opt.max_relative_position))) 75 | 76 | else: 77 | self.layers = get_clones(DecoderLayer(self.d_model, opt.heads, opt.d_ff, \ 78 | opt.dropout, opt.attention_type, \ 79 | opt.relative_time_pitch, 80 | max_relative_position = opt.max_relative_position), 81 | opt.n_layers) 82 | self.norm = Norm(self.d_model) 83 | 84 | def forward(self, trg, e_outputs, src_mask, trg_mask): 85 | x = self.embed(trg) 86 | x = self.pe(x) 87 | # print(x.shape) 88 | for i in range(self.N): 89 | x = self.layers[i](x.float(), e_outputs, src_mask, trg_mask) 90 | 91 | return self.norm(x) 92 | 93 | class Transformer(nn.Module): 94 | def __init__(self, src_vocab_size, trg_vocab_size, opt): 95 | super().__init__() 96 | self.encoder = Encoder(src_vocab_size, opt) 97 | self.decoder = Decoder(trg_vocab_size, opt) 98 | if opt.concat_pos_sinusoid is True: 99 | self.d_model = 2 * opt.d_model 100 | else: 101 | self.d_model = opt.d_model 102 | 103 | self.linear = nn.Linear(self.d_model, trg_vocab_size) 104 | 105 | def forward(self, src, trg, src_mask, trg_mask): 106 | 107 | e_outputs = self.encoder(src, src_mask) 108 | d_output = self.decoder(trg, e_outputs, src_mask, trg_mask) 109 | output = self.linear(d_output) 110 | 111 | return output 112 | 113 | def get_model(opt, vocab_size): 114 | # Ensure the provided arguments are valid 115 | assert opt.d_model % opt.heads == 0 116 | assert opt.dropout < 1 117 | 118 | print('Attention type: ' + opt.attention_type) 119 | 120 | # Initailze the transformer model 121 | model = Transformer(vocab_size, vocab_size, opt) 122 | 123 | if opt.load_weights is not None: 124 | print("loading pretrained weights...") 125 | checkpoint = torch.load(f'{opt.load_weights}/' + opt.weights_name, map_location = 'cpu') 126 | model.load_state_dict(checkpoint['model_state_dict']) 127 | 128 | model = model.to(opt.device) 129 | 130 | return model 131 | -------------------------------------------------------------------------------- /PositionalFeedForward.py: -------------------------------------------------------------------------------- 1 | # Filename: PositionalFeedForward.py 2 | # Date Created: 10-Mar-2019 21:58:37 2019 3 | # Description: FeedForward layer used in encoder and decoder. 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class FeedForward(nn.Module): 8 | def __init__(self, d_model, d_ff=1024, dropout = 0.1): 9 | super().__init__() 10 | # d_ff defaulted to 2048 11 | self.linear_1 = nn.Linear(d_model, d_ff) 12 | self.dropout = nn.Dropout(dropout) 13 | self.linear_2 = nn.Linear(d_ff, d_model) 14 | 15 | def forward(self, x): 16 | x = self.dropout(F.relu(self.linear_1(x))) 17 | x = self.linear_2(x) 18 | return x 19 | -------------------------------------------------------------------------------- /Process.py: -------------------------------------------------------------------------------- 1 | # Filename: Process.py 2 | # Date Created: 17-Mar-2019 5:43:02 pm 3 | # Description: Functions used for basic processes. 4 | import numpy as np 5 | import copy 6 | 7 | def IndexToPitch(input, vocab): 8 | """ 9 | Converts the index values from model's output back to pitches from vocab. 
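    Inverse of the index mapping applied in PrepareData: every value equal to
    index i is replaced by vocab[i], so indices 0 and 1 stay as the rest/pad
    tokens and the remaining indices become the dataset's pitch values.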
10 | """ 11 | index_vocab = np.arange(len(vocab)) 12 | output = input.clone() 13 | 14 | for i, val in reversed(list(enumerate(index_vocab))): 15 | output[output==val] = vocab[i] 16 | 17 | return output 18 | 19 | def ProcessModelOutput(model_output): 20 | """ 21 | Remove custom tokens and set rest tokens to NaN values 22 | Converts the model's output into numpy format similar to JSB dataset. 23 | """ 24 | # Values for the custom tokens 25 | rest_token = 0 26 | pad_token = 1 27 | sos_token = 2 28 | eos_token = 3 29 | 30 | # Convert tensor to numpy 31 | output = model_output.cpu().detach().numpy() 32 | 33 | # Replace all pad tokens with rest tokens 34 | output[output==pad_token] = rest_token 35 | 36 | # Change rest tokens to NaN values 37 | output = np.where(output==rest_token, np.nan, output) 38 | 39 | # Reshape output to match JSB dataset 40 | output = output.reshape(round(output.shape[0]/4),4) 41 | 42 | return output 43 | 44 | def get_len(train): 45 | for i, b in enumerate(train): 46 | pass 47 | 48 | return i 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # music-transformer-comp6248 2 | **This re-implementation is meant for the COMP6248 reproducibility challenge.** 3 | 4 | Re-implementation of music transformer paper published for ICLR 2019. 5 | Link to paper: https://openreview.net/forum?id=rJe4ShAcF7 6 | 7 | ## Scope of re-implementation 8 | We focus on re-implementing a subset of the experiments described in the published paper, 9 | only using the JSB Chorales dataset for training three model variations listed below: 10 | - Baseline Transformer (TF) (5L, 256hs, 256att, 1024ff, 8h) 11 | - Baseline TF + concatenating positional sinusoids 12 | - TF with efficient relative attention (5L, 512hs, 512att, 512ff, 256r, 8h) 13 | 14 | ## Results 15 | Unfortunately, we were unable to successfully reproduce the results shown in the published paper. There are a number of reasons that contribute to this, but mainly due to insufficient 16 | details provided for the JSB dataset and how the data was processed for the transformer model. 17 | 18 | Nonetheless, these are the results obtained after training for **300** epochs: 19 | 20 | | Model Variations | Final Training Loss | Final Validation Loss | 21 | |:----------------------------------:|:-------------------:|:---------------------:| 22 | | Baseline TF | 1.731 | 3.398 | 23 | | Baseline TF + concat pos sinusoids | 2.953 | 3.575 | 24 | | TF with relative attention | 3.028 | 3.743 | 25 | 26 | 27 | Please read the [report](report.pdf) for more detailed information regarding the re-implementation. 28 | 29 | ## Usage 30 | **Note: Please create the following folders in this directory before running the scripts:** 31 | - `./weights/` - for storing the trained weights. 32 | - `./outputs/` - for storing training/validation loss values and generated outputs. 33 | - `./plots/` - for storing the plots. 
34 | 35 | To train the models, simply run `train.py` and add the arguments accordingly: 36 | 37 | ``` 38 | python train.py -src_data datasets/JSB-Chorales-dataset/Jsb16thSeparated.npz -epochs 300 -weights_name baselineTF_300epoch -device cuda:2 -checkpoint 10 39 | ``` 40 | 41 | To generate music using a trained model, run `generate.py` and add the arguments accordingly: 42 | ``` 43 | python generate.py -src_data datasets/JSB-Chorales-dataset/Jsb16thSeparated.npz -load_weights weights -weights_name baselineTF_300epoch -device cuda:1 -k 3 44 | ``` 45 | 46 | To generate the plots for training/validation loss, please use `gen_training_plots.py`: 47 | ``` 48 | python gen_training_plots.py -t_loss_file -v_loss_file 49 | ``` 50 | 51 | `` and `` are generated during training and are saved in the `outputs` directory. 52 | 53 | The generated sequences can be plotted and listened using this [IPython Notebook](gen_sequence_audio.ipynb). 54 | **Please note that this requires Magenta to be installed.** 55 | 56 | ## Datasets 57 | 1. JSB Chorales dataset 58 | - https://github.com/czhuang/JSB-Chorales-dataset 59 | 2. MAESTRO dataset **(Not used)** 60 | - https://magenta.tensorflow.org/datasets/maestro 61 | 62 | ## Environment Setup 63 | 1. PyTorch with CUDA-enabled GPUs. 64 | - Install CUDA 9.0 and CUDNN 7.4.1.5 65 | - Then follow these steps: 66 | ``` 67 | conda create -n torch python=3.6 68 | conda activate torch 69 | conda install pytorch torchvision cuda90 -c pytorch 70 | ``` 71 | 72 | 2. Magenta for plotting and playing the generated notes. 73 | Steps for installing on Ubuntu: 74 | ``` 75 | conda create -n magenta python=3.6 76 | conda activate magenta 77 | sudo apt-get update 78 | sudo apt-get install build-essential libasound2-dev libjack-dev libfluidsynth1 fluid-soundfont-gm 79 | pip install --pre python-rtmidi 80 | pip install jupyter magenta pyfluidsynth pretty_midi 81 | ``` 82 | -------------------------------------------------------------------------------- /Sublayers.py: -------------------------------------------------------------------------------- 1 | # Filename: Sublayers.py 2 | # Date Created: 15-Mar-2019 2:42:12 pm 3 | # Description: Sublayer functions used for attention mechanism. 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from torch.autograd import Variable 10 | 11 | def shape_list(x): 12 | """ 13 | Return list of dims. 14 | """ 15 | shape = list(x.shape) 16 | 17 | return shape 18 | 19 | 20 | def _relative_position_to_absolute_position_masked(x): 21 | """Helper function for dot_product_self_attention_relative 22 | 23 | Rearrange attention logits or weights tensor. 24 | 25 | Dimensions of input represents: 26 | [batch, heads, query_position, memory_position - query_position + length - 1] 27 | 28 | Dimensions of output represents: 29 | [batch, heads, query_position, memory_position] 30 | 31 | Only works with masked attention. 
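    For example, with length = 2 a per-(batch, head) slice
        [[r(-1), r(0)],
         [s(-1), s(0)]]
    (r/s are the relative logits of query 0/1, indexed by relative distance)
    is rearranged into
        [[r(0), pad],
         [s(-1), s(0)]]
    so that the last dimension now indexes absolute memory position; entries
    above the diagonal are meaningless and rely on the causal mask to be ignored.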
32 | 33 | Args: 34 | x: a Tensor with shape [batch, heads, length, length] 35 | 36 | Returns: 37 | a Tensor with shape [batch, heads, length, length] 38 | """ 39 | 40 | batch, heads, length, _ = shape_list(x) 41 | 42 | x = F.pad(x, (1, 0, 0, 0, 0, 0, 0, 0)) 43 | x = torch.reshape(x, (batch, heads, 1 + length, length)) 44 | x = x[0:x.shape[0] - 0, 0:x.shape[1] - 0, 1:x.shape[2], 0:x.shape[3] - 0] 45 | 46 | return x 47 | 48 | 49 | def matmul_with_relative_keys(x, y, heads_share_relative_embedding): 50 | if heads_share_relative_embedding: 51 | ret = torch.einsum("bhld,md -> bhlm", x, y) 52 | else: 53 | ret = torch.einsum("bhld,hmd -> bhlm", x, y) 54 | return ret 55 | 56 | def matmul_with_relative_time_pitch(x, y): 57 | ret = torch.einsum("bhld,mmd -> bhlm", x, y) 58 | 59 | return ret 60 | 61 | def get_relative_embeddings_pitch_time(max_relative_position, length, depth, 62 | relative_time_embeddings = None, 63 | relative_pitch_embeddings = None): 64 | """Instantiate or retrieve relative embeddings, sliced according to length 65 | 66 | Use for masked case where the relative attention is only looking left 67 | Args: 68 | max_relative_position: an Integer for the number of entries in the relative 69 | embedding, which corresponds to the max relative distance that is 70 | considered. 71 | length: an Integer, specifies the length of the input sequence for which 72 | this relative embedding is retrieved for. 73 | depth: an Integer, specifies the depth for relative embeddings. 74 | relative_time_embeddings: relative embeddings for time, if not present instantiates one 75 | relative_pitch_embeddings: relative embeddings for pitch, if not present instantiates one 76 | """ 77 | initializer_stddev = depth ** -0.5 78 | embedding_shape = (max_relative_position, max_relative_position, depth) 79 | 80 | if relative_time_embeddings is None: 81 | relative_time_embeddings = Variable(torch.from_numpy(np.random.normal\ 82 | (0.0, initializer_stddev, embedding_shape).astype('f'))) 83 | if relative_pitch_embeddings is None: 84 | relative_pitch_embeddings = Variable(torch.from_numpy(np.random.normal\ 85 | (0.0, initializer_stddev, embedding_shape).astype('f'))) 86 | 87 | pad_length = max(length - max_relative_position, 0) 88 | slice_start_position = max(max_relative_position - length, 0) 89 | 90 | padded_relative_time_embeddings = F.pad( 91 | relative_time_embeddings, 92 | (0, 0, pad_length, 0, pad_length, 0)) 93 | used_relative_time_embeddings = padded_relative_time_embeddings[ 94 | slice_start_position:length, 95 | slice_start_position:slice_start_position + length, 96 | 0:(padded_relative_time_embeddings.shape[2] - 0) 97 | ] 98 | padded_relative_pitch_embeddings = F.pad( 99 | relative_pitch_embeddings, 100 | (0, 0, pad_length, 0, pad_length, 0)) 101 | used_relative_pitch_embeddings = padded_relative_pitch_embeddings[ 102 | slice_start_position:slice_start_position + length, 103 | slice_start_position:slice_start_position + length, 104 | 0:(padded_relative_pitch_embeddings.shape[2] - 0) 105 | ] 106 | 107 | return used_relative_time_embeddings, used_relative_pitch_embeddings, relative_time_embeddings, relative_pitch_embeddings 108 | 109 | def get_relative_embeddings_left(max_relative_position, length, depth, 110 | num_heads, 111 | heads_share_relative_embedding, 112 | relative_embeddings = None): 113 | """Instantiate or retrieve relative embeddings, sliced according to length 114 | 115 | Use for masked case where the relative attention is only looking left 116 | Args: 117 | max_relative_position: an Integer for the 
number of entries in the relative 118 | embedding, which corresponds to the max relative distance that is 119 | considered. 120 | length: an Integer, specifies the length of the input sequence for which 121 | this relative embedding is retrieved for. 122 | depth: an Integer, specifies the depth for relative embeddings. 123 | num_heads: an Integer, specifies the number of heads. 124 | heads_share_relative_embedding: a Boolean specifying if the relative 125 | embedding is shared across heads. 126 | """ 127 | 128 | initializer_stddev = depth ** -0.5 129 | if heads_share_relative_embedding: 130 | embedding_shape = (max_relative_position, depth) 131 | else: 132 | embedding_shape = (num_heads, max_relative_position, depth) 133 | 134 | if relative_embeddings is None: 135 | relative_embeddings = Variable(torch.from_numpy(np.random.normal(0.0, initializer_stddev, embedding_shape).astype('f'))) 136 | 137 | pad_length = max(length - max_relative_position, 0) 138 | slice_start_position = max(max_relative_position - length, 0) 139 | 140 | if heads_share_relative_embedding: 141 | padded_relative_embeddings = F.pad( 142 | relative_embeddings, 143 | (0, 0, pad_length, 0)) 144 | 145 | used_relative_embeddings = padded_relative_embeddings[slice_start_position:slice_start_position + length, 146 | 0:(padded_relative_embeddings.shape[1] - 0)] 147 | else: 148 | padded_relative_embeddings = F.pad( 149 | relative_embeddings, 150 | (0, 0, pad_length, 0, 0, 0)) 151 | 152 | used_relative_embeddings = padded_relative_embeddings[ 153 | 0:(padded_relative_embeddings.shape[0] - 0), 154 | slice_start_position:slice_start_position + length, 155 | 0:(padded_relative_embeddings.shape[2] - 0) 156 | ] 157 | 158 | return used_relative_embeddings, relative_embeddings 159 | 160 | 161 | def dot_product_self_attention_relative(q, 162 | k, 163 | v, 164 | mask = None, 165 | bias = None, 166 | max_relative_position = None, 167 | dropout = None, 168 | heads_share_relative_embedding = False, 169 | relative_embeddings = None, 170 | relative_time_pitch = False, 171 | relative_time_embeddings = None, 172 | relative_pitch_embeddings = None): 173 | if not max_relative_position: 174 | raise ValueError("Max relative position (%s) should be > 0 when using " 175 | "relative self attention." % (max_relative_position)) 176 | 177 | # Use separate embeddings suitable for keys and values. 
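    # Sketch of the relative attention score (Shaw et al. 2018 / Huang et al. 2018):
    #   logits[b, h, i, j] = q_i . k_j + q_i . E_rel[j - i]
    # The first term is the usual content term; the second is computed against
    # learned relative-position embeddings in "relative" coordinates and mapped
    # to absolute coordinates by _relative_position_to_absolute_position_masked.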
178 | _, heads, length, depth_k = shape_list(k) 179 | 180 | logits = torch.matmul(q, k.transpose(-2, -1)) 181 | 182 | if mask is not None: 183 | mask = mask.unsqueeze(1) #shape of mask must be broadcastable with shape of underlying tensor 184 | logits = logits.masked_fill(mask == 0, -1e9) #masked_fill fills elements of scores with -1e9 where mask == 0 185 | 186 | key_relative_embeddings, relative_embeddings = get_relative_embeddings_left( 187 | max_relative_position, length, depth_k, heads, heads_share_relative_embedding, relative_embeddings) 188 | 189 | key_relative_embeddings = key_relative_embeddings.to(q.device) 190 | 191 | relative_logits = matmul_with_relative_keys(q, key_relative_embeddings, 192 | heads_share_relative_embedding) 193 | 194 | relative_logits = _relative_position_to_absolute_position_masked(relative_logits) #[1, 8, 1023, 1024] 195 | 196 | if relative_time_pitch == True: 197 | to_use_time_relative_embeddings, to_use_pitch_relative_embeddings,\ 198 | relative_time_embeddings, relative_pitch_embeddings \ 199 | = get_relative_embeddings_pitch_time(max_relative_position, length, 200 | depth_k, 201 | relative_time_embeddings, 202 | relative_pitch_embeddings) 203 | 204 | relative_time_pitch_sum = (to_use_time_relative_embeddings + to_use_pitch_relative_embeddings).to(q.device) 205 | relative_time_pitch_term = matmul_with_relative_time_pitch(q, relative_time_pitch_sum) 206 | relative_logits = relative_logits + relative_time_pitch_term 207 | 208 | logits += relative_logits 209 | 210 | if bias is not None: 211 | logits += bias 212 | 213 | weights = F.softmax(logits, dim = -1) 214 | # Dropping out the attention links for each of the heads. 215 | if dropout is not None: 216 | weights = dropout(weights) 217 | 218 | output = torch.matmul(weights, v) 219 | 220 | return output, relative_embeddings, relative_time_embeddings, relative_pitch_embeddings 221 | 222 | else: 223 | logits += relative_logits 224 | 225 | if bias is not None: 226 | logits += bias 227 | 228 | weights = F.softmax(logits, dim = -1) 229 | # Dropping out the attention links for each of the heads. 
230 | if dropout is not None: 231 | weights = dropout(weights) 232 | 233 | output = torch.matmul(weights, v) 234 | 235 | return output, relative_embeddings 236 | 237 | 238 | def attention(q, v, k, d_k, mask = None, dropout = None): 239 | 240 | scores = torch.matmul(q, k.transpose(-2, -1))/ math.sqrt(d_k) 241 | if mask is not None: 242 | #mask = mask.unsqueeze(1) #shape of mask must be broadcastable with shape of underlying tensor 243 | scores = scores.masked_fill(mask == 0, -1e9) #masked_fill fills elements of scores with -1e9 where mask == 0 244 | 245 | scores = F.softmax(scores, dim = -1) 246 | if dropout is not None: 247 | scores = dropout(scores) 248 | 249 | output = torch.matmul(scores, v) 250 | 251 | return output 252 | 253 | 254 | class MultiHeadAttention(nn.Module): 255 | def __init__(self, heads, d_model, dropout = 0.0, attention_type = "Baseline", 256 | bias = None, 257 | max_relative_position = 512, 258 | heads_share_relative_embedding = False, 259 | relative_time_pitch = False): 260 | super().__init__() 261 | 262 | self.d_model = d_model 263 | self.d_k = d_model // heads #final dimension = d_model/N as we split embedding vec into N heads 264 | self.h = heads #number of heads 265 | 266 | self.attention_type = attention_type 267 | self.bias = bias 268 | self.max_relative_position = max_relative_position 269 | self.heads_share_relative_embedding = heads_share_relative_embedding 270 | self.relative_time_pitch = relative_time_pitch 271 | self.relative_embeddings = None 272 | self.relative_time_embeddings = None 273 | self.relative_pitch_embeddings = None 274 | 275 | self.q_linear = nn.Linear(d_model, d_model) 276 | self.v_linear = nn.Linear(d_model, d_model) 277 | self.k_linear = nn.Linear(d_model, d_model) 278 | self.dropout = nn.Dropout(dropout) 279 | self.out = nn.Linear(d_model, d_model) 280 | 281 | def forward(self, q, k, v, mask=None): 282 | 283 | bs = q.size(0) #batch size 284 | #original size bs * seq_len * h * d_k 285 | k = self.k_linear(k).view(bs, -1, self.h, self.d_k) 286 | q = self.q_linear(q).view(bs, -1, self.h, self.d_k) 287 | v = self.v_linear(v).view(bs, -1, self.h, self.d_k) 288 | # transpose to get dimensions of bs * h * seq_len * d_k 289 | k = k.transpose(1,2) # torch.Size([512, 3, 8, 64]) transpose will result in torch.Size([512, 8, 3, 64]) 290 | q = q.transpose(1,2) 291 | v = v.transpose(1,2) 292 | 293 | 294 | # calculate attention using defined attention function 295 | if self.attention_type == "Baseline": 296 | scores = attention(q, k, v, self.d_k, mask, self.dropout) 297 | 298 | 299 | else: 300 | if self.relative_time_pitch: 301 | scores, self.relative_embeddings,\ 302 | self.relative_time_embeddings,\ 303 | self.relative_pitch_embeddings = dot_product_self_attention_relative(q, k, v, mask, 304 | self.bias, 305 | self.max_relative_position, 306 | self.dropout, 307 | self.heads_share_relative_embedding, 308 | self.relative_embeddings, 309 | self.relative_time_pitch, 310 | self.relative_time_embeddings, 311 | self.relative_pitch_embeddings) 312 | else: 313 | scores, self.relative_embeddings = dot_product_self_attention_relative(q, k, v, mask, 314 | self.bias, 315 | self.max_relative_position, 316 | self.dropout, 317 | self.heads_share_relative_embedding, 318 | self.relative_embeddings) 319 | 320 | #concatenate heads and put through final linear layer 321 | 322 | 323 | concat = scores.transpose(1,2).contiguous().view(bs, -1, self.d_model) 324 | output = self.out(concat) 325 | 326 | return output.unsqueeze(1) 327 | 328 | class Norm(nn.Module): 329 | def 
__init__(self, d_model, eps = 1e-6): 330 | super().__init__() 331 | 332 | self.size = d_model 333 | 334 | #create two learnable parameters to calibrate normalisation 335 | self.alpha = nn.Parameter(torch.ones(self.size)) 336 | self.bias = nn.Parameter(torch.ones(self.size)) 337 | 338 | self.eps = eps 339 | 340 | def forward(self, x): 341 | norm = self.alpha * (x - x.mean(dim = 2, keepdim = True)) / (x.std(dim = 2, keepdim = True) + self.eps) + self.bias 342 | return norm 343 | -------------------------------------------------------------------------------- /datasets/JSB-Chorales-dataset/Jsb16thSeparated.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/datasets/JSB-Chorales-dataset/Jsb16thSeparated.npz -------------------------------------------------------------------------------- /gen_training_plots.py: -------------------------------------------------------------------------------- 1 | # Filename: gen_training_plots.py 2 | # Date Created: 13-May-2019 10:44:46 pm 3 | # Description: Script to generate the training/validation loss plot against epochs. 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import argparse 7 | 8 | def main(): 9 | # Add parser to parse in the arguments 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('-t_loss_file', type=str, required=True) 12 | parser.add_argument('-v_loss_file', type=str, required=True) 13 | opt = parser.parse_args() 14 | 15 | print("Generating and saving plot...") 16 | 17 | # load in the training loss and validation loss 18 | t_loss = np.load(opt.t_loss_file) 19 | v_loss = np.load(opt.v_loss_file) 20 | 21 | # plot seperately 22 | plt.plot(np.arange(len(t_loss))+1,t_loss) 23 | plt.plot(np.arange(len(v_loss))+1,v_loss) 24 | plt.xlabel('No. of Epochs') 25 | plt.ylabel('NLL Loss') 26 | plt.grid() 27 | plt.legend(['Training','Validation']) 28 | plt.savefig('./plots/combined_plot.png',dpi=300, 29 | bbox_inches='tight', pad_inches=0) 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Filename: generate.py 2 | # Date Created: 08-May-2019 11:48:49 pm 3 | # Description: Run this script to generate music using the music-transformer model. 
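# Example invocation (illustrative; see the README and the argparse options in
# main() for the full set of flags):
#   python generate.py -src_data datasets/JSB-Chorales-dataset/Jsb16thSeparated.npz \
#       -load_weights weights -weights_name baselineTF_300epoch -output_name sample -k 3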
4 | import argparse 5 | import torch 6 | from DataPrep import GenerateVocab, PrepareData, tensorFromSequence 7 | from Models import get_model 8 | from MaskGen import create_masks 9 | from Process import IndexToPitch, ProcessModelOutput, get_len 10 | from Beam import beam_search 11 | import numpy as np 12 | import torch.nn.functional as F 13 | import time 14 | import os 15 | 16 | def generate(model,opt): 17 | print("generating music using beam search...") 18 | model.eval() 19 | 20 | # choose 2 random pitches within the vocab (except rest/pad token) to start the sequence 21 | starting_pitch = torch.randint(2, len(opt.vocab)-1, (2,)).unsqueeze(1).transpose(0,1).to(opt.device) 22 | 23 | # generate the sequence using beam search 24 | generated_seq = beam_search(starting_pitch, model, opt) 25 | 26 | # Make the index values back to original pitch 27 | output_seq = IndexToPitch(generated_seq, opt.vocab) 28 | 29 | # Process the output format such that it is the same as our dataset 30 | processed = ProcessModelOutput(output_seq) 31 | 32 | return processed 33 | 34 | def main(): 35 | # Add parser to parse in the arguments 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('-src_data', required=True) 38 | parser.add_argument('-load_weights', required=False) 39 | parser.add_argument('-output_name', type=str, required=True) 40 | parser.add_argument('-device', type=str, default="cuda:1" if torch.cuda.is_available() else "cpu") 41 | parser.add_argument('-k', type=int, default=3) 42 | parser.add_argument('-d_model', type=int, default=256) 43 | parser.add_argument('-d_ff', type=int, default=1024) 44 | parser.add_argument('-n_layers', type=int, default=5) 45 | parser.add_argument('-heads', type=int, default=8) 46 | parser.add_argument('-dropout', type=float, default=0.1) 47 | parser.add_argument('-batchsize', type=int, default=1) 48 | parser.add_argument('-max_seq_len', type=int, default=1024) 49 | parser.add_argument('-attention_type', type = str, default = 'Baseline') 50 | parser.add_argument('-weights_name', type = str, default = 'model_weights') 51 | parser.add_argument("-concat_pos_sinusoid", type=str2bool, default=False) 52 | parser.add_argument("-relative_time_pitch", type=str2bool, default=False) 53 | parser.add_argument("-max_relative_position", type=int, default=512) 54 | opt = parser.parse_args() 55 | 56 | # Generate the vocabulary from the data 57 | opt.vocab = GenerateVocab(opt.src_data) 58 | opt.pad_token = 1 59 | 60 | # Create the model using the arguments and the vocab size 61 | model = get_model(opt, len(opt.vocab)) 62 | 63 | # counter to keep track of how many outputs have been saved 64 | opt.save_counter = 0 65 | 66 | # Now lets generate some music 67 | generated_music = generate(model,opt) 68 | 69 | # Ask for next action 70 | promptNextAction(model, opt, generated_music) 71 | 72 | def yesno(response): 73 | while True: 74 | if response != 'y' and response != 'n': 75 | response = input('command not recognised, enter y or n : ') 76 | else: 77 | return response 78 | 79 | def promptNextAction(model, opt, processed): 80 | while True: 81 | save = yesno(input('generate complete, save music? [y/n] : ')) 82 | if save == 'y': 83 | print("saving music...") 84 | # Pickle the processed outputs for magenta later 85 | opt.save_counter += 1 86 | np.save('outputs/' + opt.output_name + str(opt.save_counter), processed) 87 | 88 | res = yesno(input("generate again? 
[y/n] : ")) 89 | if res == 'y': 90 | # Now lets generate some music 91 | processed = generate(model,opt) 92 | 93 | else: 94 | print("exiting program...") 95 | break 96 | 97 | def str2bool(v): 98 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 99 | return True 100 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 101 | return False 102 | else: 103 | raise argparse.ArgumentTypeError('Boolean value expected.') 104 | 105 | if __name__ == "__main__": 106 | # For reproducibility 107 | torch.manual_seed(0) 108 | main() 109 | -------------------------------------------------------------------------------- /outputs/gen_output/gen_output_baselineTF_300epoch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/gen_output/gen_output_baselineTF_300epoch.npy -------------------------------------------------------------------------------- /outputs/gen_output/gen_output_baselineTFconcat_1000epoch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/gen_output/gen_output_baselineTFconcat_1000epoch.npy -------------------------------------------------------------------------------- /outputs/gen_output/gen_output_baselineTFconcat_300epoch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/gen_output/gen_output_baselineTFconcat_300epoch.npy -------------------------------------------------------------------------------- /outputs/gen_output/gen_output_relativeTF_300epoch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/gen_output/gen_output_relativeTF_300epoch.npy -------------------------------------------------------------------------------- /outputs/loss/t_loss_baselineTF_300epoch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/loss/t_loss_baselineTF_300epoch.npy -------------------------------------------------------------------------------- /outputs/loss/t_loss_baselineTFconcat_300epoch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/loss/t_loss_baselineTFconcat_300epoch.npy -------------------------------------------------------------------------------- /outputs/loss/t_loss_final_baselineTFconcat_1000epoch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/loss/t_loss_final_baselineTFconcat_1000epoch.npy -------------------------------------------------------------------------------- /outputs/loss/t_loss_relativeTF_300epoch.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/loss/t_loss_relativeTF_300epoch.npy -------------------------------------------------------------------------------- /outputs/loss/v_loss_baselineTF_300epoch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/loss/v_loss_baselineTF_300epoch.npy -------------------------------------------------------------------------------- /outputs/loss/v_loss_baselineTFconcat_300epoch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/loss/v_loss_baselineTFconcat_300epoch.npy -------------------------------------------------------------------------------- /outputs/loss/v_loss_final_baselineTFconcat_1000epoch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/loss/v_loss_final_baselineTFconcat_1000epoch.npy -------------------------------------------------------------------------------- /outputs/loss/v_loss_relativeTF_300epoch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/outputs/loss/v_loss_relativeTF_300epoch.npy -------------------------------------------------------------------------------- /plots/BaselineTF_300epochs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/plots/BaselineTF_300epochs.png -------------------------------------------------------------------------------- /plots/BaselineTFconcat_300epochs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/plots/BaselineTFconcat_300epochs.png -------------------------------------------------------------------------------- /plots/relativeTF_300epochs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/plots/relativeTF_300epochs.png -------------------------------------------------------------------------------- /plots/trainingLoss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/plots/trainingLoss.png -------------------------------------------------------------------------------- /plots/validationLoss.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/plots/validationLoss.png -------------------------------------------------------------------------------- /report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/fa49d1179f1ba6d06519be314f64749849e69fb4/report.pdf -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Filename: train.py 2 | # Date Created: 08-Mar-2019 11:48:49 pm 3 | # Description: Run this script to train the music-transformer model. 4 | import argparse 5 | import torch 6 | from DataPrep import GenerateVocab, PrepareData, tensorFromSequence 7 | from Models import get_model 8 | from MaskGen import create_masks 9 | from Process import IndexToPitch, ProcessModelOutput, get_len 10 | import numpy as np 11 | import torch.nn.functional as F 12 | import time 13 | import os 14 | 15 | torch.set_default_dtype(torch.float) 16 | torch.set_default_tensor_type(torch.FloatTensor) 17 | #torch.autograd.detect_anomaly = True 18 | 19 | def count_nonpad_tokens(target, padding_index): 20 | nonpads = (target != padding_index).squeeze() 21 | ntokens = torch.sum(nonpads) 22 | return ntokens 23 | 24 | def LabelSmoothing(input, target, smoothing, padding_index): 25 | """ 26 | Args: 27 | input: input to loss function, size of [N, Class] 28 | target: target input to loss function 29 | smoothing: degree of smoothing 30 | padding_index: number used for padding, i.e. 1 31 | 32 | Returns: 33 | A smoothed target input to loss function 34 | """ 35 | confidence = 1.0 - smoothing 36 | true_dist = input.clone() 37 | true_dist.fill_(smoothing/ (input.size(1) - 2)) 38 | true_dist.scatter_(1, target.unsqueeze(1), confidence) 39 | mask = torch.nonzero(target == padding_index) 40 | if mask.dim() > 0: 41 | true_dist.index_fill_(0, mask.squeeze(), 0.0) 42 | 43 | return torch.autograd.Variable(true_dist, requires_grad = False) 44 | 45 | def batched_learning(train,batch_size): 46 | for i in range(0, len(train), batch_size): 47 | train1 = train[i:i + batch_size] 48 | yield train1[:,0],train1[:,1] 49 | 50 | 51 | def train(model, opt): 52 | print("training model...") 53 | start = time.time() 54 | warmup_steps = 4000 55 | step_num_load = 0 56 | step_num = 1 57 | epoch_load = 0 58 | 59 | t_loss_per_epoch = [] 60 | v_loss_per_epoch = [] 61 | 62 | if opt.checkpoint > 0: 63 | cptime = time.time() 64 | 65 | # if load_weights to resume training 66 | if opt.load_weights is not None: 67 | checkpoint = torch.load('weights/' + opt.weights_name, map_location = 'cpu') 68 | model.load_state_dict(checkpoint['model_state_dict']) 69 | opt.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 70 | step_num_load = checkpoint['step_num'] # to keep track of learning rate 71 | epoch_load = checkpoint['epoch'] 72 | 73 | if opt.resume is True: 74 | checkpoint = torch.load('weights/' + opt.weights_name) 75 | 76 | # No need to load weights, as it is the same model being trained 77 | # model.load_state_dict(checkpoint['model_state_dict']) 78 | 79 | opt.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 80 | step_num_load = checkpoint['step_num'] # to keep track of learning rate 81 | epoch_load = checkpoint['epoch'] 82 | step_num += step_num_load 83 | 84 | for epoch in range(opt.epochs): 85 | 
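        # Each epoch: shuffle the training pairs, update the Noam-style learning rate
        #   lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
        # (linear warm-up, then inverse square-root decay), train on mini-batches,
        # then run a single validation pass over opt.valid.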
model.train() 86 | 87 | np.random.shuffle(opt.train) 88 | # learning rate defined in Attention is All You Need Paper 89 | opt.lr = (opt.d_model ** (-0.5)) * (min(step_num ** (-0.5), step_num * warmup_steps ** (-1.5))) 90 | 91 | # Vary learning rate based on step numbers (each note consider a step) 92 | for param_group in opt.optimizer.param_groups: 93 | param_group['lr'] = opt.lr 94 | 95 | total_loss = [] 96 | print(" %dm: epoch %d [%s] %d%% training loss = %s" %\ 97 | ((time.time() - start)//60, (epoch + epoch_load) + 1, "".join(' '*20), 0, '...'), end='\r') 98 | 99 | 100 | for i, batch in enumerate(batched_learning(opt.train, batch_size=opt.batch_size)): 101 | input, target = batch 102 | #print(input.shape,target.shape) 103 | input = tensorFromSequence(input).to(opt.device) 104 | target = tensorFromSequence(target).to(opt.device) 105 | 106 | trg_input = target 107 | ys = target[:, 0:].contiguous().view(-1) 108 | 109 | # Create mask for both input and target sequences 110 | input_mask, target_mask = create_masks(input, trg_input, opt) 111 | 112 | preds_idx = model(input, trg_input, input_mask, target_mask) 113 | 114 | opt.optimizer.zero_grad() 115 | 116 | loss = F.cross_entropy(preds_idx.contiguous().view(preds_idx.size(-1), -1).transpose(0,1), ys, \ 117 | ignore_index = opt.pad_token, size_average = False) / (count_nonpad_tokens(ys,1)) 118 | 119 | loss.backward() 120 | 121 | opt.optimizer.step() 122 | 123 | step_num += 1 124 | 125 | 126 | total_loss.append(loss.item()) 127 | 128 | if (i + 1) % opt.printevery == 0: 129 | p = int(100 * (i + 1) / get_len(opt.train)) 130 | avg_loss = np.mean(total_loss) 131 | 132 | print(" %dm: epoch %d [%s%s] %d%% training loss = %.3f" %\ 133 | ((time.time() - start)//60, (epoch + epoch_load) + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end ='\r') 134 | 135 | 136 | avg_loss = np.mean(total_loss) 137 | 138 | # Validation step 139 | model.eval() 140 | total_validate_loss = [] 141 | with torch.no_grad(): 142 | pair = opt.valid 143 | input = tensorFromSequence(pair[0]).to(opt.device) 144 | 145 | target = tensorFromSequence(pair[1]).to(opt.device) 146 | trg_input = target 147 | ys = target[:, 0:].contiguous().view(-1) 148 | 149 | input_mask, target_mask = create_masks(input, trg_input, opt) 150 | preds_validate = model(input, trg_input, input_mask, target_mask) 151 | validate_loss = F.cross_entropy(preds_validate.contiguous().view(preds_validate.size(-1), -1).transpose(0,1), ys, \ 152 | ignore_index = opt.pad_token, size_average = False) / (count_nonpad_tokens(ys, 1)) 153 | total_validate_loss.append(validate_loss.item()) 154 | avg_validate_loss = np.mean(total_validate_loss) 155 | 156 | # Store the average training & validation loss for each epoch 157 | t_loss_per_epoch.append(avg_loss) 158 | v_loss_per_epoch.append(avg_validate_loss) 159 | 160 | # checkpoint in terms of minutes reached, then save weights, and other information 161 | if opt.checkpoint > 0 and ((time.time()-cptime)//60) // opt.checkpoint >= 1: 162 | print("checkpoint save...") 163 | torch.save({ 164 | 'epoch': epoch + epoch_load, 165 | 'model_state_dict': model.state_dict(), 166 | 'optimizer_state_dict': opt.optimizer.state_dict(), 167 | 'loss': avg_loss, 168 | 'step_num': step_num 169 | }, 'weights/' + opt.weights_name) 170 | cptime = time.time() 171 | 172 | # Convert list into numpy arrays 173 | t_loss_per_epoch_tmp = np.array(t_loss_per_epoch) 174 | v_loss_per_epoch_tmp = np.array(v_loss_per_epoch) 175 | 176 | # Save the arrays for plotting later 177 | 
185 | def main():
186 |     # Set up the argument parser
187 |     parser = argparse.ArgumentParser()
188 |     parser.add_argument('-src_data', required=False)
189 |     parser.add_argument('-device', type=str, default="cuda:1" if torch.cuda.is_available() else "cpu")
190 |     parser.add_argument('-epochs', type=int, default=1)
191 |     parser.add_argument('-d_model', type=int, default=256)
192 |     parser.add_argument('-d_ff', type=int, default=1024)
193 |     parser.add_argument('-n_layers', type=int, default=5)
194 |     parser.add_argument('-heads', type=int, default=8)
195 |     parser.add_argument('-dropout', type=float, default=0.1)
196 |     parser.add_argument('-batch_size', type=int, default=1)
197 |     parser.add_argument('-max_seq_len', type=int, default=1024)
198 |     parser.add_argument('-printevery', type=int, default=100)
199 |     parser.add_argument('-lr', type=float, default=0.0001)
200 |     parser.add_argument('-load_weights')
201 |     parser.add_argument('-checkpoint', type=int, default=0)
202 |     parser.add_argument('-attention_type', type=str, default='Baseline')
203 |     parser.add_argument('-weights_name', type=str, default='model_weights')
204 |     parser.add_argument("-concat_pos_sinusoid", type=str2bool, default=False)
205 |     parser.add_argument("-relative_time_pitch", type=str2bool, default=False)
206 |     parser.add_argument("-max_relative_position", type=int, default=512)
207 |     opt = parser.parse_args()
208 | 
209 |     # Initialize the resume option as False
210 |     opt.resume = False
211 | 
212 |     # Generate the vocabulary from the data
213 |     opt.vocab = GenerateVocab(opt.src_data)
214 |     opt.pad_token = 1
215 | 
216 |     # Set up the dataset for the training and validation splits
217 |     opt.train = PrepareData(opt.src_data, 'train', int(opt.max_seq_len))
218 |     opt.valid = PrepareData(opt.src_data, 'valid', int(opt.max_seq_len))
219 | 
220 |     # Create the model using the arguments and the vocab size
221 |     model = get_model(opt, len(opt.vocab))
222 | 
223 |     # Set up the optimizer for training
224 |     opt.optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9, weight_decay=1e-4)
225 | 
226 | 
227 |     # step_num counts optimizer updates and is used to calculate the learning rate
228 |     step_num = 0
229 | 
230 |     # Train the model (train() returns the epoch, loss, step count and the per-epoch loss histories)
231 |     epoch, avg_loss, step_num, t_loss_per_epoch, v_loss_per_epoch = train(model, opt)
232 | 
233 |     promptNextAction(model, opt, epoch, step_num, avg_loss, t_loss_per_epoch, v_loss_per_epoch)
234 | 
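# ---------------------------------------------------------------------------
# Example invocation (illustrative; the flags are the argparse options defined
# in main() above, the dataset path is the JSB chorales file shipped with the
# repository, and the epoch count / weights name / checkpoint interval are
# placeholders):
#
#     python train.py -src_data datasets/JSB-Chorales-dataset/Jsb16thSeparated.npz \
#                     -epochs 300 -batch_size 1 -attention_type Baseline \
#                     -weights_name baselineTF -checkpoint 30
#
# Checkpoints are written to 'weights/<weights_name>', so a 'weights/'
# directory is expected to exist before training starts.
# ---------------------------------------------------------------------------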
235 | def yesno(response):
236 |     while True:
237 |         if response != 'y' and response != 'n':
238 |             response = input('command not recognised, enter y or n : ')
239 |         else:
240 |             return response
241 | 
242 | def promptNextAction(model, opt, epoch, step_num, avg_loss, t_loss_per_epoch, v_loss_per_epoch):
243 | 
244 |     saved_once = 1 if opt.load_weights is not None or opt.checkpoint > 0 else 0
245 | 
246 |     if opt.load_weights is not None:
247 |         dst = opt.load_weights
248 |     if opt.checkpoint > 0:
249 |         dst = 'weights'
250 | 
251 |     while True:
252 |         save = yesno(input('training complete, save results? [y/n] : '))
253 |         if save == 'y':
254 |             print("saving weights...")
255 |             torch.save({
256 |                 'epoch': epoch,
257 |                 'model_state_dict': model.state_dict(),
258 |                 'optimizer_state_dict': opt.optimizer.state_dict(),
259 |                 'loss': avg_loss,
260 |                 'step_num': step_num
261 |             }, 'weights/' + opt.weights_name)
262 | 
263 |         res = yesno(input("train for more epochs? [y/n] : "))
264 |         if res == 'y':
265 |             while True:
266 |                 epochs = input("type number of epochs to train for : ")
267 |                 try:
268 |                     epochs = int(epochs)
269 |                 except ValueError:
270 |                     print("input not a number")
271 |                     continue
272 |                 if epochs < 1:
273 |                     print("epochs must be at least 1")
274 |                     continue
275 |                 else:
276 |                     break
277 |             opt.epochs = epochs
278 |             opt.resume = True
279 | 
280 |             _, _, _, extra_t_loss, extra_v_loss = train(model, opt)
281 |             t_loss_per_epoch.extend(extra_t_loss)
282 |             v_loss_per_epoch.extend(extra_v_loss)
283 |         else:
284 |             # Convert the lists into numpy arrays
285 |             t_loss_per_epoch = np.array(t_loss_per_epoch)
286 |             v_loss_per_epoch = np.array(v_loss_per_epoch)
287 | 
288 |             # Save the arrays for plotting later
289 |             np.save(('outputs/t_loss_final_') + opt.weights_name, t_loss_per_epoch)
290 |             np.save(('outputs/v_loss_final_') + opt.weights_name, v_loss_per_epoch)
291 | 
292 |             print("exiting program...")
293 |             break
294 | 
295 | def str2bool(v):
296 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
297 |         return True
298 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
299 |         return False
300 |     else:
301 |         raise argparse.ArgumentTypeError('Boolean value expected.')
302 | 
303 | if __name__ == "__main__":
304 |     # For reproducibility
305 |     torch.manual_seed(0)
306 |     main()
307 | 
--------------------------------------------------------------------------------