├── requirements.txt
├── pickles
│   ├── remi_vocab.pkl
│   ├── test_pieces.pkl
│   ├── train_pieces.pkl
│   └── val_pieces.pkl
├── .gitignore
├── utils.py
├── model
│   ├── transformer_encoder.py
│   ├── transformer_helpers.py
│   └── musemorphose.py
├── LICENSE
├── config
│   └── default.yaml
├── attributes.py
├── README.md
├── remi2midi.py
├── train.py
├── dataloader.py
└── generate.py
/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6.0 2 | miditoolkit 3 | PyYAML 4 | scipy 5 | numpy -------------------------------------------------------------------------------- /pickles/remi_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/MuseMorphose/HEAD/pickles/remi_vocab.pkl -------------------------------------------------------------------------------- /pickles/test_pieces.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/MuseMorphose/HEAD/pickles/test_pieces.pkl -------------------------------------------------------------------------------- /pickles/train_pieces.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/MuseMorphose/HEAD/pickles/train_pieces.pkl -------------------------------------------------------------------------------- /pickles/val_pieces.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/MuseMorphose/HEAD/pickles/val_pieces.pkl -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.tar.gz 2 | *dataset* 3 | **/__pycache__/ 4 | **/.vscode/ 5 | **/*ckpt*/ 6 | **/*.pt 7 | **/gens/ -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | from scipy.spatial import distance 4 | 5 | # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 6 | # device = torch.device('cpu') 7 | 8 | def numpy_to_tensor(arr, use_gpu=True, device='cuda:0'): 9 | if use_gpu: 10 | return torch.tensor(arr).to(device).float() 11 | else: 12 | return torch.tensor(arr).float() 13 | 14 | def tensor_to_numpy(tensor): 15 | return tensor.cpu().detach().numpy() 16 | 17 | def pickle_load(f): 18 | return pickle.load(open(f, 'rb')) 19 | 20 | def pickle_dump(obj, f): 21 | pickle.dump(obj, open(f, 'wb'), protocol=pickle.HIGHEST_PROTOCOL) -------------------------------------------------------------------------------- /model/transformer_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class VAETransformerEncoder(nn.Module): 6 | def __init__(self, n_layer, n_head, d_model, d_ff, d_vae_latent, dropout=0.1, activation='relu'): 7 | super(VAETransformerEncoder, self).__init__() 8 | self.n_layer = n_layer 9 | self.n_head = n_head 10 | self.d_model = d_model 11 | self.d_ff = d_ff 12 | self.d_vae_latent = d_vae_latent 13 | self.dropout = dropout 14 | self.activation = activation 15 | 16 | self.tr_encoder_layer = nn.TransformerEncoderLayer( 17 | d_model, n_head, d_ff, dropout, activation 18 | ) 19 | self.tr_encoder = nn.TransformerEncoder(
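# (stacks n_layer copies of the encoder layer defined above; the forward pass
#  below takes the first timestep's output as the bar-level summary that feeds
#  fc_mu / fc_logvar)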
self.tr_encoder_layer, n_layer 21 | ) 22 | 23 | self.fc_mu = nn.Linear(d_model, d_vae_latent) 24 | self.fc_logvar = nn.Linear(d_model, d_vae_latent) 25 | 26 | def forward(self, x, padding_mask=None): 27 | out = self.tr_encoder(x, src_key_padding_mask=padding_mask) 28 | hidden_out = out[0, :, :] 29 | mu, logvar = self.fc_mu(hidden_out), self.fc_logvar(hidden_out) 30 | 31 | return hidden_out, mu, logvar 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Shih-Lun Wu, Yi-Hsuan Yang, and Taiwan AI Labs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: ./remi_dataset 3 | train_split: ./pickles/train_pieces.pkl 4 | val_split: ./pickles/val_pieces.pkl 5 | test_split: ./pickles/test_pieces.pkl 6 | vocab_path: ./pickles/remi_vocab.pkl 7 | max_bars: 16 8 | enc_seqlen: 128 9 | dec_seqlen: 1280 10 | batch_size: 4 11 | 12 | model: 13 | enc_n_layer: 12 14 | enc_n_head: 8 15 | enc_d_model: 512 16 | enc_d_ff: 2048 17 | dec_n_layer: 12 18 | dec_n_head: 8 19 | dec_d_model: 512 20 | dec_d_ff: 2048 21 | d_embed: 512 22 | d_latent: 128 23 | d_polyph_emb: 64 24 | d_rfreq_emb: 64 25 | cond_mode: in-attn 26 | pretrained_params_path: null 27 | pretrained_optim_path: null 28 | 29 | training: 30 | device: cuda:0 31 | ckpt_dir: ./ckpt/enc_dec_12L-16_bars-seqlen_1280 32 | trained_steps: 0 33 | max_epochs: 1000 34 | max_lr: 1.0e-4 35 | min_lr: 5.0e-6 36 | lr_warmup_steps: 200 37 | lr_decay_steps: 150000 38 | no_kl_steps: 10000 39 | kl_cycle_steps: 5000 40 | kl_max_beta: 1.0 41 | free_bit_lambda: 0.25 42 | constant_kl: False 43 | ckpt_interval: 50 44 | log_interval: 10 45 | val_interval: 50 46 | 47 | generate: 48 | temperature: 1.2 49 | nucleus_p: 0.9 50 | use_latent_sampling: False 51 | latent_sampling_var: 0.0 52 | max_bars: 16 # could be set to match the longest input piece during generation (inference) 53 | dec_seqlen: 1280 # could be set to match the longest input piece during generation (inference) 54 | max_input_dec_seqlen: 1024 # should be set to equal to or less than `dec_seqlen` used during training -------------------------------------------------------------------------------- /attributes.py: 
-------------------------------------------------------------------------------- 1 | from utils import pickle_load 2 | 3 | import os, pickle 4 | import numpy as np 5 | from collections import Counter 6 | 7 | data_dir = 'remi_dataset' 8 | polyph_out_dir = 'remi_dataset/attr_cls/polyph' 9 | rhythm_out_dir = 'remi_dataset/attr_cls/rhythm' 10 | 11 | rhym_intensity_bounds = [0.2, 0.25, 0.32, 0.38, 0.44, 0.5, 0.63] 12 | polyphonicity_bounds = [2.63, 3.06, 3.50, 4.00, 4.63, 5.44, 6.44] 13 | 14 | def compute_polyphonicity(events, n_bars): 15 | poly_record = np.zeros( (n_bars * 16,) ) 16 | 17 | cur_bar, cur_pos = -1, -1 18 | for ev in events: 19 | if ev['name'] == 'Bar': 20 | cur_bar += 1 21 | elif ev['name'] == 'Beat': 22 | cur_pos = int(ev['value']) 23 | elif ev['name'] == 'Note_Duration': 24 | duration = int(ev['value']) // 120 25 | st = cur_bar * 16 + cur_pos 26 | poly_record[st:st + duration] += 1 27 | 28 | return poly_record 29 | 30 | def get_onsets_timing(events, n_bars): 31 | onset_record = np.zeros( (n_bars * 16,) ) 32 | 33 | cur_bar, cur_pos = -1, -1 34 | for ev in events: 35 | if ev['name'] == 'Bar': 36 | cur_bar += 1 37 | elif ev['name'] == 'Beat': 38 | cur_pos = int(ev['value']) 39 | elif ev['name'] == 'Note_Pitch': 40 | rec_idx = cur_bar * 16 + cur_pos 41 | onset_record[ rec_idx ] = 1 42 | 43 | return onset_record 44 | 45 | if __name__ == "__main__": 46 | pieces = [p for p in sorted(os.listdir(data_dir)) if '.pkl' in p] 47 | all_r_cls = [] 48 | all_p_cls = [] 49 | 50 | if not os.path.exists(polyph_out_dir): 51 | os.makedirs(polyph_out_dir) 52 | if not os.path.exists(rhythm_out_dir): 53 | os.makedirs(rhythm_out_dir) 54 | 55 | for p in pieces: 56 | bar_pos, events = pickle_load(os.path.join(data_dir, p)) 57 | events = events[ :bar_pos[-1] ] 58 | 59 | polyph_raw = np.reshape( 60 | compute_polyphonicity(events, n_bars=len(bar_pos)), (-1, 16) 61 | ) 62 | rhythm_raw = np.reshape( 63 | get_onsets_timing(events, n_bars=len(bar_pos)), (-1, 16) 64 | ) 65 | 66 | polyph_cls = np.searchsorted(polyphonicity_bounds, np.mean(polyph_raw, axis=-1)).tolist() 67 | rfreq_cls = np.searchsorted(rhym_intensity_bounds, np.mean(rhythm_raw, axis=-1)).tolist() 68 | 69 | pickle.dump(polyph_cls, open(os.path.join( 70 | polyph_out_dir, p), 'wb' 71 | )) 72 | pickle.dump(rfreq_cls, open(os.path.join( 73 | rhythm_out_dir, p), 'wb' 74 | )) 75 | 76 | all_r_cls.extend(rfreq_cls) 77 | all_p_cls.extend(polyph_cls) 78 | 79 | print ('[rhythm classes]', Counter(all_r_cls)) 80 | print ('[polyph classes]', Counter(all_p_cls)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MuseMorphose 2 | 3 | This repository contains the official implementation of the following paper: 4 | 5 | * Shih-Lun Wu, Yi-Hsuan Yang 6 | **_MuseMorphose_: Full-Song and Fine-Grained Piano Music Style Transfer with One Transformer VAE** 7 | accepted to _IEEE/ACM Trans. 
Audio, Speech, & Language Processing (TASLP)_, Dec 2022 [arXiv] [demo website] 8 | 9 | ## Prerequisites 10 | * Python >= 3.6 11 | * Install dependencies 12 | ```bash 13 | pip3 install -r requirements.txt 14 | ``` 15 | * GPU with >6GB RAM (optional, but recommended) 16 | 17 | ## Preprocessing 18 | ```bash 19 | # download REMI-pop-1.7K dataset 20 | wget -O remi_dataset.tar.gz https://zenodo.org/record/4782721/files/remi_dataset.tar.gz?download=1 21 | tar xzvf remi_dataset.tar.gz 22 | rm remi_dataset.tar.gz 23 | 24 | # compute attribute classes 25 | python3 attributes.py 26 | ``` 27 | 28 | ## Training 29 | ```bash 30 | python3 train.py [config file] 31 | ``` 32 | * e.g. 33 | ```bash 34 | python3 train.py config/default.yaml 35 | ``` 36 | * Or, you may download the pretrained weights straight away 37 | ```bash 38 | wget -O musemorphose_pretrained_weights.pt https://zenodo.org/record/5119525/files/musemorphose_pretrained_weights.pt?download=1 39 | ``` 40 | 41 | ## Generation 42 | ```bash 43 | python3 generate.py [config file] [ckpt path] [output dir] [num pieces] [num samples per piece] 44 | ``` 45 | * e.g. 46 | ```bash 47 | python3 generate.py config/default.yaml musemorphose_pretrained_weights.pt generations/ 10 5 48 | ``` 49 | 50 | This script will randomly draw the specified # of pieces from the test set. 51 | For each sample of a piece, the _rhythmic intensity_ and _polyphonicity_ are each shifted by a single random offset in \[-3, 3\] classes, applied uniformly to every bar, for the model to generate style-transferred music. 52 | You may modify `random_shift_attr_cls()` in `generate.py` or write your own function to set the attributes. 53 | 54 | ## Customized Generation (To Be Added) 55 | We welcome the community's suggestions and contributions for an interface on which users may 56 | * upload their own MIDIs, and 57 | * set their desired bar-level attributes easily 58 | 59 | ## Citation BibTeX 60 | If you find this work helpful and use our code in your research, please kindly cite our paper: 61 | ``` 62 | @article{wu2023musemorphose, 63 | title={{MuseMorphose}: Full-Song and Fine-Grained Piano Music Style Transfer with One {Transformer VAE}}, 64 | author={Shih-Lun Wu and Yi-Hsuan Yang}, 65 | year={2023}, 66 | journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing}, 67 | } 68 | ``` 69 | -------------------------------------------------------------------------------- /model/transformer_helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | def generate_causal_mask(seq_len): 8 | mask = (torch.triu(torch.ones(seq_len, seq_len)) == 1).transpose(0, 1) 9 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 10 | mask.requires_grad = False 11 | return mask 12 | 13 | def weight_init_normal(weight, normal_std): 14 | nn.init.normal_(weight, 0.0, normal_std) 15 | 16 | def weight_init_orthogonal(weight, gain): 17 | nn.init.orthogonal_(weight, gain) 18 | 19 | def bias_init(bias): 20 | nn.init.constant_(bias, 0.0) 21 | 22 | def weights_init(m): 23 | classname = m.__class__.__name__ 24 | # print ('[{}] initializing ...'.format(classname)) 25 | 26 | if classname.find('Linear') != -1: 27 | if hasattr(m, 'weight') and m.weight is not None: 28 | weight_init_normal(m.weight, 0.01) 29 | if hasattr(m, 'bias') and m.bias is not None: 30 | bias_init(m.bias) 31 | elif classname.find('Embedding') != -1: 32 | if hasattr(m, 'weight'):
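# (embedding matrices get the same N(0, 0.01**2) init as the Linear weights above)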
weight_init_normal(m.weight, 0.01) 34 | elif classname.find('LayerNorm') != -1: 35 | if hasattr(m, 'weight'): 36 | nn.init.normal_(m.weight, 1.0, 0.01) 37 | if hasattr(m, 'bias') and m.bias is not None: 38 | bias_init(m.bias) 39 | elif classname.find('GRU') != -1: 40 | for param in m.parameters(): 41 | if len(param.shape) >= 2: # weights 42 | weight_init_orthogonal(param, 0.01) 43 | else: # biases 44 | bias_init(param) 45 | # else: 46 | # print ('[{}] not initialized !!'.format(classname)) 47 | 48 | class PositionalEncoding(nn.Module): 49 | def __init__(self, d_embed, max_pos=20480): 50 | super(PositionalEncoding, self).__init__() 51 | self.d_embed = d_embed 52 | self.max_pos = max_pos 53 | 54 | pe = torch.zeros(max_pos, d_embed) 55 | position = torch.arange(0, max_pos, dtype=torch.float).unsqueeze(1) 56 | div_term = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed)) 57 | pe[:, 0::2] = torch.sin(position * div_term) 58 | pe[:, 1::2] = torch.cos(position * div_term) 59 | pe = pe.unsqueeze(0).transpose(0, 1) 60 | self.register_buffer('pe', pe) 61 | 62 | def forward(self, seq_len, bsz=None): 63 | pos_encoding = self.pe[:seq_len, :] 64 | 65 | if bsz is not None: 66 | pos_encoding = pos_encoding.expand(seq_len, bsz, -1) 67 | 68 | return pos_encoding 69 | 70 | class TokenEmbedding(nn.Module): 71 | def __init__(self, n_token, d_embed, d_proj): 72 | super(TokenEmbedding, self).__init__() 73 | 74 | self.n_token = n_token 75 | self.d_embed = d_embed 76 | self.d_proj = d_proj 77 | self.emb_scale = d_proj ** 0.5 78 | 79 | self.emb_lookup = nn.Embedding(n_token, d_embed) 80 | if d_proj != d_embed: 81 | self.emb_proj = nn.Linear(d_embed, d_proj, bias=False) 82 | else: 83 | self.emb_proj = None 84 | 85 | def forward(self, inp_tokens): 86 | inp_emb = self.emb_lookup(inp_tokens) 87 | 88 | if self.emb_proj is not None: 89 | inp_emb = self.emb_proj(inp_emb) 90 | 91 | return inp_emb.mul_(self.emb_scale) -------------------------------------------------------------------------------- /remi2midi.py: -------------------------------------------------------------------------------- 1 | import os, pickle, random, copy 2 | import numpy as np 3 | 4 | import miditoolkit 5 | 6 | ############################## 7 | # constants 8 | ############################## 9 | DEFAULT_BEAT_RESOL = 480 10 | DEFAULT_BAR_RESOL = 480 * 4 11 | DEFAULT_FRACTION = 16 12 | 13 | 14 | ############################## 15 | # containers for conversion 16 | ############################## 17 | class ConversionEvent(object): 18 | def __init__(self, event, is_full_event=False): 19 | if not is_full_event: 20 | if 'Note' in event: 21 | self.name, self.value = '_'.join(event.split('_')[:-1]), event.split('_')[-1] 22 | elif 'Chord' in event: 23 | self.name, self.value = event.split('_')[0], '_'.join(event.split('_')[1:]) 24 | else: 25 | self.name, self.value = event.split('_') 26 | else: 27 | self.name, self.value = event['name'], event['value'] 28 | def __repr__(self): 29 | return 'Event(name: {} | value: {})'.format(self.name, self.value) 30 | 31 | class NoteEvent(object): 32 | def __init__(self, pitch, bar, position, duration, velocity): 33 | self.pitch = pitch 34 | self.start_tick = bar * DEFAULT_BAR_RESOL + position * (DEFAULT_BAR_RESOL // DEFAULT_FRACTION) 35 | self.duration = duration 36 | self.velocity = velocity 37 | 38 | class TempoEvent(object): 39 | def __init__(self, tempo, bar, position): 40 | self.tempo = tempo 41 | self.start_tick = bar * DEFAULT_BAR_RESOL + position * (DEFAULT_BAR_RESOL // DEFAULT_FRACTION) 
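# Worked example of the tick arithmetic shared by the event classes above:
# with DEFAULT_BAR_RESOL = 480 * 4 = 1920 and DEFAULT_FRACTION = 16, each of
# the 16 positions in a bar spans 1920 // 16 = 120 ticks, so an event at
# bar 2, position 4 starts at tick 2 * 1920 + 4 * 120 = 4320.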
42 | 43 | class ChordEvent(object): 44 | def __init__(self, chord_val, bar, position): 45 | self.chord_val = chord_val 46 | self.start_tick = bar * DEFAULT_BAR_RESOL + position * (DEFAULT_BAR_RESOL // DEFAULT_FRACTION) 47 | 48 | ############################## 49 | # conversion functions 50 | ############################## 51 | def read_generated_txt(generated_path): 52 | f = open(generated_path, 'r') 53 | return f.read().splitlines() 54 | 55 | def remi2midi(events, output_midi_path=None, is_full_event=False, return_first_tempo=False, enforce_tempo=False, enforce_tempo_val=None): 56 | events = [ConversionEvent(ev, is_full_event=is_full_event) for ev in events] 57 | # print (events[:20]) 58 | 59 | assert events[0].name == 'Bar' 60 | temp_notes = [] 61 | temp_tempos = [] 62 | temp_chords = [] 63 | 64 | cur_bar = 0 65 | cur_position = 0 66 | 67 | for i in range(len(events)): 68 | if events[i].name == 'Bar': 69 | if i > 0: 70 | cur_bar += 1 71 | elif events[i].name == 'Beat': 72 | cur_position = int(events[i].value) 73 | assert cur_position >= 0 and cur_position < DEFAULT_FRACTION 74 | elif events[i].name == 'Tempo': 75 | temp_tempos.append(TempoEvent( 76 | int(events[i].value), cur_bar, cur_position 77 | )) 78 | elif 'Note_Pitch' in events[i].name and \ 79 | (i+1) < len(events) and 'Note_Velocity' in events[i+1].name and \ 80 | (i+2) < len(events) and 'Note_Duration' in events[i+2].name: 81 | # check if the 3 events are of the same instrument 82 | temp_notes.append( 83 | NoteEvent( 84 | pitch=int(events[i].value), 85 | bar=cur_bar, position=cur_position, 86 | duration=int(events[i+2].value), velocity=int(events[i+1].value) 87 | ) 88 | ) 89 | elif 'Chord' in events[i].name: 90 | temp_chords.append( 91 | ChordEvent(events[i].value, cur_bar, cur_position) 92 | ) 93 | elif events[i].name in ['EOS', 'PAD']: 94 | continue 95 | 96 | # print (len(temp_tempos), len(temp_notes)) 97 | midi_obj = miditoolkit.midi.parser.MidiFile() 98 | midi_obj.instruments = [ 99 | miditoolkit.Instrument(program=0, is_drum=False, name='Piano') 100 | ] 101 | 102 | for n in temp_notes: 103 | midi_obj.instruments[0].notes.append( 104 | miditoolkit.Note(int(n.velocity), n.pitch, int(n.start_tick), int(n.start_tick + n.duration)) 105 | ) 106 | 107 | if enforce_tempo is False: 108 | for t in temp_tempos: 109 | midi_obj.tempo_changes.append( 110 | miditoolkit.TempoChange(t.tempo, int(t.start_tick)) 111 | ) 112 | else: 113 | if enforce_tempo_val is None: 114 | enforce_tempo_val = temp_tempos[:1] # fall back to the piece's first tempo event; must be a list, as it is iterated below 115 | for t in enforce_tempo_val: 116 | midi_obj.tempo_changes.append( 117 | miditoolkit.TempoChange(t.tempo, int(t.start_tick)) 118 | ) 119 | 120 | 121 | for c in temp_chords: 122 | midi_obj.markers.append( 123 | miditoolkit.Marker('Chord-{}'.format(c.chord_val), int(c.start_tick)) 124 | ) 125 | for b in range(cur_bar): 126 | midi_obj.markers.append( 127 | miditoolkit.Marker('Bar-{}'.format(b+1), int(DEFAULT_BAR_RESOL * b)) 128 | ) 129 | 130 | if output_midi_path is not None: 131 | midi_obj.dump(output_midi_path) 132 | 133 | if not return_first_tempo: 134 | return midi_obj 135 | else: 136 | return midi_obj, temp_tempos -------------------------------------------------------------------------------- /model/musemorphose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from transformer_encoder import VAETransformerEncoder 5 | from transformer_helpers import ( 6 | weights_init, PositionalEncoding, TokenEmbedding,
generate_causal_mask 7 | ) 8 | 9 | class VAETransformerDecoder(nn.Module): 10 | def __init__(self, n_layer, n_head, d_model, d_ff, d_seg_emb, dropout=0.1, activation='relu', cond_mode='in-attn'): 11 | super(VAETransformerDecoder, self).__init__() 12 | self.n_layer = n_layer 13 | self.n_head = n_head 14 | self.d_model = d_model 15 | self.d_ff = d_ff 16 | self.d_seg_emb = d_seg_emb 17 | self.dropout = dropout 18 | self.activation = activation 19 | self.cond_mode = cond_mode 20 | 21 | if cond_mode == 'in-attn': 22 | self.seg_emb_proj = nn.Linear(d_seg_emb, d_model, bias=False) 23 | elif cond_mode == 'pre-attn': 24 | self.seg_emb_proj = nn.Linear(d_seg_emb + d_model, d_model, bias=False) 25 | 26 | self.decoder_layers = nn.ModuleList() 27 | for i in range(n_layer): 28 | self.decoder_layers.append( 29 | nn.TransformerEncoderLayer(d_model, n_head, d_ff, dropout, activation) 30 | ) 31 | 32 | def forward(self, x, seg_emb): 33 | if not hasattr(self, 'cond_mode'): 34 | self.cond_mode = 'in-attn' 35 | attn_mask = generate_causal_mask(x.size(0)).to(x.device) 36 | # print (attn_mask.size()) 37 | 38 | if self.cond_mode == 'in-attn': 39 | seg_emb = self.seg_emb_proj(seg_emb) 40 | elif self.cond_mode == 'pre-attn': 41 | x = torch.cat([x, seg_emb], dim=-1) 42 | x = self.seg_emb_proj(x) 43 | 44 | out = x 45 | for i in range(self.n_layer): 46 | if self.cond_mode == 'in-attn': 47 | out += seg_emb 48 | out = self.decoder_layers[i](out, src_mask=attn_mask) 49 | 50 | return out 51 | 52 | class MuseMorphose(nn.Module): 53 | def __init__(self, enc_n_layer, enc_n_head, enc_d_model, enc_d_ff, 54 | dec_n_layer, dec_n_head, dec_d_model, dec_d_ff, 55 | d_vae_latent, d_embed, n_token, 56 | enc_dropout=0.1, enc_activation='relu', 57 | dec_dropout=0.1, dec_activation='relu', 58 | d_rfreq_emb=32, d_polyph_emb=32, 59 | n_rfreq_cls=8, n_polyph_cls=8, 60 | is_training=True, use_attr_cls=True, 61 | cond_mode='in-attn' 62 | ): 63 | super(MuseMorphose, self).__init__() 64 | self.enc_n_layer = enc_n_layer 65 | self.enc_n_head = enc_n_head 66 | self.enc_d_model = enc_d_model 67 | self.enc_d_ff = enc_d_ff 68 | self.enc_dropout = enc_dropout 69 | self.enc_activation = enc_activation 70 | 71 | self.dec_n_layer = dec_n_layer 72 | self.dec_n_head = dec_n_head 73 | self.dec_d_model = dec_d_model 74 | self.dec_d_ff = dec_d_ff 75 | self.dec_dropout = dec_dropout 76 | self.dec_activation = dec_activation 77 | 78 | self.d_vae_latent = d_vae_latent 79 | self.n_token = n_token 80 | self.is_training = is_training 81 | 82 | self.cond_mode = cond_mode 83 | self.token_emb = TokenEmbedding(n_token, d_embed, enc_d_model) 84 | self.d_embed = d_embed 85 | self.pe = PositionalEncoding(d_embed) 86 | self.dec_out_proj = nn.Linear(dec_d_model, n_token) 87 | self.encoder = VAETransformerEncoder( 88 | enc_n_layer, enc_n_head, enc_d_model, enc_d_ff, d_vae_latent, enc_dropout, enc_activation 89 | ) 90 | 91 | self.use_attr_cls = use_attr_cls 92 | if use_attr_cls: 93 | self.decoder = VAETransformerDecoder( 94 | dec_n_layer, dec_n_head, dec_d_model, dec_d_ff, d_vae_latent + d_polyph_emb + d_rfreq_emb, 95 | dropout=dec_dropout, activation=dec_activation, 96 | cond_mode=cond_mode 97 | ) 98 | else: 99 | self.decoder = VAETransformerDecoder( 100 | dec_n_layer, dec_n_head, dec_d_model, dec_d_ff, d_vae_latent, 101 | dropout=dec_dropout, activation=dec_activation, 102 | cond_mode=cond_mode 103 | ) 104 | 105 | if use_attr_cls: 106 | self.d_rfreq_emb = d_rfreq_emb 107 | self.d_polyph_emb = d_polyph_emb 108 | self.rfreq_attr_emb = TokenEmbedding(n_rfreq_cls, d_rfreq_emb, 
d_rfreq_emb) 109 | self.polyph_attr_emb = TokenEmbedding(n_polyph_cls, d_polyph_emb, d_polyph_emb) 110 | else: 111 | self.rfreq_attr_emb = None 112 | self.polyph_attr_emb = None 113 | 114 | self.emb_dropout = nn.Dropout(self.enc_dropout) 115 | self.apply(weights_init) 116 | 117 | 118 | def reparameterize(self, mu, logvar, use_sampling=True, sampling_var=1.): 119 | std = torch.exp(0.5 * logvar).to(mu.device) 120 | if use_sampling: 121 | eps = torch.randn_like(std).to(mu.device) * sampling_var 122 | else: 123 | eps = torch.zeros_like(std).to(mu.device) 124 | 125 | return eps * std + mu 126 | 127 | def get_sampled_latent(self, inp, padding_mask=None, use_sampling=False, sampling_var=0.): 128 | token_emb = self.token_emb(inp) 129 | enc_inp = self.emb_dropout(token_emb) + self.pe(inp.size(0)) 130 | 131 | _, mu, logvar = self.encoder(enc_inp, padding_mask=padding_mask) 132 | mu, logvar = mu.reshape(-1, mu.size(-1)), logvar.reshape(-1, mu.size(-1)) 133 | vae_latent = self.reparameterize(mu, logvar, use_sampling=use_sampling, sampling_var=sampling_var) 134 | 135 | return vae_latent 136 | 137 | def generate(self, inp, dec_seg_emb, rfreq_cls=None, polyph_cls=None, keep_last_only=True): 138 | token_emb = self.token_emb(inp) 139 | dec_inp = self.emb_dropout(token_emb) + self.pe(inp.size(0)) 140 | 141 | if rfreq_cls is not None and polyph_cls is not None: 142 | dec_rfreq_emb = self.rfreq_attr_emb(rfreq_cls) 143 | dec_polyph_emb = self.polyph_attr_emb(polyph_cls) 144 | dec_seg_emb_cat = torch.cat([dec_seg_emb, dec_rfreq_emb, dec_polyph_emb], dim=-1) 145 | else: 146 | dec_seg_emb_cat = dec_seg_emb 147 | 148 | out = self.decoder(dec_inp, dec_seg_emb_cat) 149 | out = self.dec_out_proj(out) 150 | 151 | if keep_last_only: 152 | out = out[-1, ...] 153 | 154 | return out 155 | 156 | 157 | def forward(self, enc_inp, dec_inp, dec_inp_bar_pos, rfreq_cls=None, polyph_cls=None, padding_mask=None): 158 | # [shape of enc_inp] (seqlen_per_bar, bsize, n_bars_per_sample) 159 | enc_bt_size, enc_n_bars = enc_inp.size(1), enc_inp.size(2) 160 | enc_token_emb = self.token_emb(enc_inp) 161 | 162 | # [shape of dec_inp] (seqlen_per_sample, bsize) 163 | # [shape of rfreq_cls & polyph_cls] same as above 164 | # -- (should copy each bar's label to all corresponding indices) 165 | dec_token_emb = self.token_emb(dec_inp) 166 | 167 | enc_token_emb = enc_token_emb.reshape( 168 | enc_inp.size(0), -1, enc_token_emb.size(-1) 169 | ) 170 | enc_inp = self.emb_dropout(enc_token_emb) + self.pe(enc_inp.size(0)) 171 | dec_inp = self.emb_dropout(dec_token_emb) + self.pe(dec_inp.size(0)) 172 | 173 | # [shape of padding_mask] (bsize, n_bars_per_sample, seqlen_per_bar) 174 | # -- should be `True` for padded indices (i.e., those >= seqlen of the bar), `False` otherwise 175 | if padding_mask is not None: 176 | padding_mask = padding_mask.reshape(-1, padding_mask.size(-1)) 177 | 178 | _, mu, logvar = self.encoder(enc_inp, padding_mask=padding_mask) 179 | vae_latent = self.reparameterize(mu, logvar) 180 | vae_latent_reshaped = vae_latent.reshape(enc_bt_size, enc_n_bars, -1) 181 | 182 | dec_seg_emb = torch.zeros(dec_inp.size(0), dec_inp.size(1), self.d_vae_latent).to(vae_latent.device) 183 | for n in range(dec_inp.size(1)): 184 | # [shape of dec_inp_bar_pos] (bsize, n_bars_per_sample + 1) 185 | # -- stores [[start idx of bar #1, sample #1, ..., start idx of bar #K, sample #1, seqlen of sample #1], [same for another sample], ...] 
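# -- e.g., if a sample's bars start at token indices 0, 42, and 85, and the
#    sample is 120 tokens long, dec_inp_bar_pos[n] = [0, 42, 85, ..., 120];
#    the loop below then copies bar #k's latent onto decoder timesteps [start_k, start_{k+1})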
186 | for b, (st, ed) in enumerate(zip(dec_inp_bar_pos[n, :-1], dec_inp_bar_pos[n, 1:])): 187 | dec_seg_emb[st:ed, n, :] = vae_latent_reshaped[n, b, :] 188 | 189 | if rfreq_cls is not None and polyph_cls is not None and self.use_attr_cls: 190 | dec_rfreq_emb = self.rfreq_attr_emb(rfreq_cls) 191 | dec_polyph_emb = self.polyph_attr_emb(polyph_cls) 192 | dec_seg_emb_cat = torch.cat([dec_seg_emb, dec_rfreq_emb, dec_polyph_emb], dim=-1) 193 | else: 194 | dec_seg_emb_cat = dec_seg_emb 195 | 196 | dec_out = self.decoder(dec_inp, dec_seg_emb_cat) 197 | dec_logits = self.dec_out_proj(dec_out) 198 | 199 | return mu, logvar, dec_logits 200 | 201 | def compute_loss(self, mu, logvar, beta, fb_lambda, dec_logits, dec_tgt): 202 | recons_loss = F.cross_entropy( 203 | dec_logits.view(-1, dec_logits.size(-1)), dec_tgt.contiguous().view(-1), 204 | ignore_index=self.n_token - 1, reduction='mean' 205 | ).float() 206 | 207 | kl_raw = -0.5 * (1 + logvar - mu ** 2 - logvar.exp()).mean(dim=0) 208 | kl_before_free_bits = kl_raw.mean() 209 | kl_after_free_bits = kl_raw.clamp(min=fb_lambda) 210 | kldiv_loss = kl_after_free_bits.mean() 211 | 212 | return { 213 | 'beta': beta, 214 | 'total_loss': recons_loss + beta * kldiv_loss, 215 | 'kldiv_loss': kldiv_loss, 216 | 'kldiv_raw': kl_before_free_bits, 217 | 'recons_loss': recons_loss 218 | } -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys, os, time 2 | sys.path.append('./model') 3 | 4 | from model.musemorphose import MuseMorphose 5 | from dataloader import REMIFullSongTransformerDataset 6 | from torch.utils.data import DataLoader 7 | 8 | from utils import pickle_load 9 | from torch import nn, optim 10 | import torch 11 | import numpy as np 12 | 13 | import yaml 14 | config_path = sys.argv[1] 15 | config = yaml.load(open(config_path, 'r'), Loader=yaml.FullLoader) 16 | 17 | device = config['training']['device'] 18 | trained_steps = config['training']['trained_steps'] 19 | lr_decay_steps = config['training']['lr_decay_steps'] 20 | lr_warmup_steps = config['training']['lr_warmup_steps'] 21 | no_kl_steps = config['training']['no_kl_steps'] 22 | kl_cycle_steps = config['training']['kl_cycle_steps'] 23 | kl_max_beta = config['training']['kl_max_beta'] 24 | free_bit_lambda = config['training']['free_bit_lambda'] 25 | max_lr, min_lr = config['training']['max_lr'], config['training']['min_lr'] 26 | 27 | ckpt_dir = config['training']['ckpt_dir'] 28 | params_dir = os.path.join(ckpt_dir, 'params/') 29 | optim_dir = os.path.join(ckpt_dir, 'optim/') 30 | pretrained_params_path = config['model']['pretrained_params_path'] 31 | pretrained_optim_path = config['model']['pretrained_optim_path'] 32 | ckpt_interval = config['training']['ckpt_interval'] 33 | log_interval = config['training']['log_interval'] 34 | val_interval = config['training']['val_interval'] 35 | constant_kl = config['training']['constant_kl'] 36 | 37 | recons_loss_ema = 0. 38 | kl_loss_ema = 0. 39 | kl_raw_ema = 0. 
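# the three *_ema trackers above hold exponential moving averages of the
# per-batch losses; compute_loss_ema() below updates them as
#   ema <- (1 - decay) * batch_loss + decay * ema   (decay = 0.95),
# i.e., a smoothed value averaging over roughly the last 20 batches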
40 | 41 | def log_epoch(log_file, log_data, is_init=False): 42 | if is_init: 43 | with open(log_file, 'w') as f: 44 | f.write('{:4} {:8} {:12} {:12} {:12} {:12}\n'.format('ep', 'steps', 'recons_loss', 'kldiv_loss', 'kldiv_raw', 'ep_time')) 45 | 46 | with open(log_file, 'a') as f: 47 | f.write('{:<4} {:<8} {:<12} {:<12} {:<12} {:<12}\n'.format( 48 | log_data['ep'], log_data['steps'], round(log_data['recons_loss'], 5), round(log_data['kldiv_loss'], 5), round(log_data['kldiv_raw'], 5), round(log_data['time'], 2) 49 | )) 50 | 51 | def beta_cyclical_sched(step): 52 | step_in_cycle = (step - 1) % kl_cycle_steps 53 | cycle_progress = step_in_cycle / kl_cycle_steps 54 | 55 | if step < no_kl_steps: 56 | return 0. 57 | if cycle_progress < 0.5: 58 | return kl_max_beta * cycle_progress * 2. 59 | else: 60 | return kl_max_beta 61 | 62 | def compute_loss_ema(ema, batch_loss, decay=0.95): 63 | if ema == 0.: 64 | return batch_loss 65 | else: 66 | return batch_loss * (1 - decay) + ema * decay 67 | 68 | def train_model(epoch, model, dloader, dloader_val, optim, sched): 69 | model.train() 70 | 71 | print ('[epoch {:03d}] training ...'.format(epoch)) 72 | print ('[epoch {:03d}] # batches = {}'.format(epoch, len(dloader))) 73 | st = time.time() 74 | 75 | for batch_idx, batch_samples in enumerate(dloader): 76 | model.zero_grad() 77 | batch_enc_inp = batch_samples['enc_input'].permute(2, 0, 1).to(device) 78 | batch_dec_inp = batch_samples['dec_input'].permute(1, 0).to(device) 79 | batch_dec_tgt = batch_samples['dec_target'].permute(1, 0).to(device) 80 | batch_inp_bar_pos = batch_samples['bar_pos'].to(device) 81 | batch_inp_lens = batch_samples['length'] 82 | batch_padding_mask = batch_samples['enc_padding_mask'].to(device) 83 | batch_rfreq_cls = batch_samples['rhymfreq_cls'].permute(1, 0).to(device) 84 | batch_polyph_cls = batch_samples['polyph_cls'].permute(1, 0).to(device) 85 | 86 | global trained_steps 87 | trained_steps += 1 88 | 89 | mu, logvar, dec_logits = model( 90 | batch_enc_inp, batch_dec_inp, 91 | batch_inp_bar_pos, batch_rfreq_cls, batch_polyph_cls, 92 | padding_mask=batch_padding_mask 93 | ) 94 | 95 | if not constant_kl: 96 | kl_beta = beta_cyclical_sched(trained_steps) 97 | else: 98 | kl_beta = kl_max_beta 99 | losses = model.compute_loss(mu, logvar, kl_beta, free_bit_lambda, dec_logits, batch_dec_tgt) 100 | 101 | # anneal learning rate 102 | if trained_steps < lr_warmup_steps: 103 | curr_lr = max_lr * trained_steps / lr_warmup_steps 104 | optim.param_groups[0]['lr'] = curr_lr 105 | else: 106 | sched.step(trained_steps - lr_warmup_steps) 107 | 108 | # clip gradient & update model 109 | losses['total_loss'].backward() 110 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 111 | optim.step() 112 | 113 | global recons_loss_ema, kl_loss_ema, kl_raw_ema 114 | recons_loss_ema = compute_loss_ema(recons_loss_ema, losses['recons_loss'].item()) 115 | kl_loss_ema = compute_loss_ema(kl_loss_ema, losses['kldiv_loss'].item()) 116 | kl_raw_ema = compute_loss_ema(kl_raw_ema, losses['kldiv_raw'].item()) 117 | 118 | print (' -- epoch {:03d} | batch {:03d}: len: {}\n\t * loss = (RC: {:.4f} | KL: {:.4f} | KL_raw: {:.4f}), step = {}, beta: {:.4f} time_elapsed = {:.2f} secs'.format( 119 | epoch, batch_idx, batch_inp_lens, recons_loss_ema, kl_loss_ema, kl_raw_ema, trained_steps, kl_beta, time.time() - st 120 | )) 121 | 122 | if not trained_steps % log_interval: 123 | log_data = { 124 | 'ep': epoch, 125 | 'steps': trained_steps, 126 | 'recons_loss': recons_loss_ema, 127 | 'kldiv_loss': kl_loss_ema, 128 | 
'kldiv_raw': kl_raw_ema, 129 | 'time': time.time() - st 130 | } 131 | log_epoch( 132 | os.path.join(ckpt_dir, 'log.txt'), log_data, is_init=not os.path.exists(os.path.join(ckpt_dir, 'log.txt')) 133 | ) 134 | 135 | if not trained_steps % val_interval: 136 | vallosses = validate(model, dloader_val) 137 | with open(os.path.join(ckpt_dir, 'valloss.txt'), 'a') as f: 138 | f.write('[step {}] RC: {:.4f} | KL: {:.4f} | [val] | RC: {:.4f} | KL: {:.4f}\n'.format( 139 | trained_steps, 140 | recons_loss_ema, 141 | kl_raw_ema, 142 | np.mean(vallosses[0]), 143 | np.mean(vallosses[1]) 144 | )) 145 | model.train() 146 | 147 | if not trained_steps % ckpt_interval: 148 | torch.save(model.state_dict(), 149 | os.path.join(params_dir, 'step_{:d}-RC_{:.3f}-KL_{:.3f}-model.pt'.format( 150 | trained_steps, 151 | recons_loss_ema, 152 | kl_raw_ema 153 | )) 154 | ) 155 | torch.save(optim.state_dict(), 156 | os.path.join(optim_dir, 'step_{:d}-RC_{:.3f}-KL_{:.3f}-optim.pt'.format( 157 | trained_steps, 158 | recons_loss_ema, 159 | kl_raw_ema 160 | )) 161 | ) 162 | 163 | print ('[epoch {:03d}] training completed\n -- loss = (RC: {:.4f} | KL: {:.4f} | KL_raw: {:.4f})\n -- time elapsed = {:.2f} secs.'.format( 164 | epoch, recons_loss_ema, kl_loss_ema, kl_raw_ema, time.time() - st 165 | )) 166 | log_data = { 167 | 'ep': epoch, 168 | 'steps': trained_steps, 169 | 'recons_loss': recons_loss_ema, 170 | 'kldiv_loss': kl_loss_ema, 171 | 'kldiv_raw': kl_raw_ema, 172 | 'time': time.time() - st 173 | } 174 | log_epoch( 175 | os.path.join(ckpt_dir, 'log.txt'), log_data, is_init=not os.path.exists(os.path.join(ckpt_dir, 'log.txt')) 176 | ) 177 | 178 | def validate(model, dloader, n_rounds=8, use_attr_cls=True): 179 | model.eval() 180 | loss_rec = [] 181 | kl_loss_rec = [] 182 | 183 | print ('[info] validating ...') 184 | with torch.no_grad(): 185 | for i in range(n_rounds): 186 | print ('[round {}]'.format(i+1)) 187 | 188 | for batch_idx, batch_samples in enumerate(dloader): 189 | model.zero_grad() 190 | 191 | batch_enc_inp = batch_samples['enc_input'].permute(2, 0, 1).to(device) 192 | batch_dec_inp = batch_samples['dec_input'].permute(1, 0).to(device) 193 | batch_dec_tgt = batch_samples['dec_target'].permute(1, 0).to(device) 194 | batch_inp_bar_pos = batch_samples['bar_pos'].to(device) 195 | batch_padding_mask = batch_samples['enc_padding_mask'].to(device) 196 | if use_attr_cls: 197 | batch_rfreq_cls = batch_samples['rhymfreq_cls'].permute(1, 0).to(device) 198 | batch_polyph_cls = batch_samples['polyph_cls'].permute(1, 0).to(device) 199 | else: 200 | batch_rfreq_cls = None 201 | batch_polyph_cls = None 202 | 203 | mu, logvar, dec_logits = model( 204 | batch_enc_inp, batch_dec_inp, 205 | batch_inp_bar_pos, batch_rfreq_cls, batch_polyph_cls, 206 | padding_mask=batch_padding_mask 207 | ) 208 | 209 | losses = model.compute_loss(mu, logvar, 0.0, 0.0, dec_logits, batch_dec_tgt) 210 | if not (batch_idx + 1) % 10: 211 | print ('batch #{}:'.format(batch_idx + 1), round(losses['recons_loss'].item(), 3)) 212 | 213 | loss_rec.append(losses['recons_loss'].item()) 214 | kl_loss_rec.append(losses['kldiv_raw'].item()) 215 | 216 | return loss_rec, kl_loss_rec 217 | 218 | if __name__ == "__main__": 219 | dset = REMIFullSongTransformerDataset( 220 | config['data']['data_dir'], config['data']['vocab_path'], 221 | do_augment=True, 222 | model_enc_seqlen=config['data']['enc_seqlen'], 223 | model_dec_seqlen=config['data']['dec_seqlen'], 224 | model_max_bars=config['data']['max_bars'], 225 | pieces=pickle_load(config['data']['train_split']), 226 | 
pad_to_same=True 227 | ) 228 | dset_val = REMIFullSongTransformerDataset( 229 | config['data']['data_dir'], config['data']['vocab_path'], 230 | do_augment=False, 231 | model_enc_seqlen=config['data']['enc_seqlen'], 232 | model_dec_seqlen=config['data']['dec_seqlen'], 233 | model_max_bars=config['data']['max_bars'], 234 | pieces=pickle_load(config['data']['val_split']), 235 | pad_to_same=True 236 | ) 237 | print ('[info]', '# training samples:', len(dset.pieces)) 238 | 239 | dloader = DataLoader(dset, batch_size=config['data']['batch_size'], shuffle=True, num_workers=8) 240 | dloader_val = DataLoader(dset_val, batch_size=config['data']['batch_size'], shuffle=True, num_workers=8) 241 | 242 | mconf = config['model'] 243 | model = MuseMorphose( 244 | mconf['enc_n_layer'], mconf['enc_n_head'], mconf['enc_d_model'], mconf['enc_d_ff'], 245 | mconf['dec_n_layer'], mconf['dec_n_head'], mconf['dec_d_model'], mconf['dec_d_ff'], 246 | mconf['d_latent'], mconf['d_embed'], dset.vocab_size, 247 | d_polyph_emb=mconf['d_polyph_emb'], d_rfreq_emb=mconf['d_rfreq_emb'], 248 | cond_mode=mconf['cond_mode'] 249 | ).to(device) 250 | if pretrained_params_path: 251 | model.load_state_dict( torch.load(pretrained_params_path) ) 252 | 253 | model.train() 254 | n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 255 | print ('[info] model # params:', n_params) 256 | 257 | opt_params = filter(lambda p: p.requires_grad, model.parameters()) 258 | optimizer = optim.Adam(opt_params, lr=max_lr) 259 | if pretrained_optim_path: 260 | optimizer.load_state_dict( torch.load(pretrained_optim_path) ) 261 | scheduler = optim.lr_scheduler.CosineAnnealingLR( 262 | optimizer, lr_decay_steps, eta_min=min_lr 263 | ) 264 | 265 | if not os.path.exists(ckpt_dir): 266 | os.makedirs(ckpt_dir) 267 | if not os.path.exists(params_dir): 268 | os.makedirs(params_dir) 269 | if not os.path.exists(optim_dir): 270 | os.makedirs(optim_dir) 271 | 272 | for ep in range(config['training']['max_epochs']): 273 | train_model(ep+1, model, dloader, dloader_val, optimizer, scheduler) -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import os, pickle, random 2 | from glob import glob 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | IDX_TO_KEY = { 10 | 0: 'A', 11 | 1: 'A#', 12 | 2: 'B', 13 | 3: 'C', 14 | 4: 'C#', 15 | 5: 'D', 16 | 6: 'D#', 17 | 7: 'E', 18 | 8: 'F', 19 | 9: 'F#', 20 | 10: 'G', 21 | 11: 'G#' 22 | } 23 | KEY_TO_IDX = { 24 | v:k for k, v in IDX_TO_KEY.items() 25 | } 26 | 27 | def get_chord_tone(chord_event): 28 | tone = chord_event['value'].split('_')[0] 29 | return tone 30 | 31 | def transpose_chord(chord_event, n_keys): 32 | if chord_event['value'] == 'N_N': 33 | return chord_event 34 | 35 | orig_tone = get_chord_tone(chord_event) 36 | orig_tone_idx = KEY_TO_IDX[orig_tone] 37 | new_tone_idx = (orig_tone_idx + 12 + n_keys) % 12 38 | new_chord_value = chord_event['value'].replace( 39 | '{}_'.format(orig_tone), '{}_'.format(IDX_TO_KEY[new_tone_idx]) 40 | ) 41 | new_chord_event = {'name': chord_event['name'], 'value': new_chord_value} 42 | # print ('keys={}. 
{} --> {}'.format(n_keys, chord_event, new_chord_event)) 43 | 44 | return new_chord_event 45 | 46 | def check_extreme_pitch(raw_events): 47 | low, high = 128, 0 48 | for ev in raw_events: 49 | if ev['name'] == 'Note_Pitch': 50 | low = min(low, int(ev['value'])) 51 | high = max(high, int(ev['value'])) 52 | 53 | return low, high 54 | 55 | def transpose_events(raw_events, n_keys): 56 | transposed_raw_events = [] 57 | 58 | for ev in raw_events: 59 | if ev['name'] == 'Note_Pitch': 60 | transposed_raw_events.append( 61 | {'name': ev['name'], 'value': ev['value'] + n_keys} 62 | ) 63 | elif ev['name'] == 'Chord': 64 | transposed_raw_events.append( 65 | transpose_chord(ev, n_keys) 66 | ) 67 | else: 68 | transposed_raw_events.append(ev) 69 | 70 | assert len(transposed_raw_events) == len(raw_events) 71 | return transposed_raw_events 72 | 73 | def pickle_load(path): 74 | return pickle.load(open(path, 'rb')) 75 | 76 | def convert_event(event_seq, event2idx, to_ndarr=True): 77 | if isinstance(event_seq[0], dict): 78 | event_seq = [event2idx['{}_{}'.format(e['name'], e['value'])] for e in event_seq] 79 | else: 80 | event_seq = [event2idx[e] for e in event_seq] 81 | 82 | if to_ndarr: 83 | return np.array(event_seq) 84 | else: 85 | return event_seq 86 | 87 | class REMIFullSongTransformerDataset(Dataset): 88 | def __init__(self, data_dir, vocab_file, 89 | model_enc_seqlen=128, model_dec_seqlen=1280, model_max_bars=16, 90 | pieces=[], do_augment=True, augment_range=range(-6, 7), 91 | min_pitch=22, max_pitch=107, pad_to_same=True, use_attr_cls=True, 92 | appoint_st_bar=None, dec_end_pad_value=None): 93 | self.vocab_file = vocab_file 94 | self.read_vocab() 95 | 96 | self.data_dir = data_dir 97 | self.pieces = pieces 98 | self.build_dataset() 99 | 100 | self.model_enc_seqlen = model_enc_seqlen 101 | self.model_dec_seqlen = model_dec_seqlen 102 | self.model_max_bars = model_max_bars 103 | 104 | self.do_augment = do_augment 105 | self.augment_range = augment_range 106 | self.min_pitch, self.max_pitch = min_pitch, max_pitch 107 | self.pad_to_same = pad_to_same 108 | self.use_attr_cls = use_attr_cls 109 | 110 | self.appoint_st_bar = appoint_st_bar 111 | if dec_end_pad_value is None: 112 | self.dec_end_pad_value = self.pad_token 113 | elif dec_end_pad_value == 'EOS': 114 | self.dec_end_pad_value = self.eos_token 115 | else: 116 | self.dec_end_pad_value = self.pad_token 117 | 118 | def read_vocab(self): 119 | vocab = pickle_load(self.vocab_file)[0] 120 | self.idx2event = pickle_load(self.vocab_file)[1] 121 | orig_vocab_size = len(vocab) 122 | self.event2idx = vocab 123 | self.bar_token = self.event2idx['Bar_None'] 124 | self.eos_token = self.event2idx['EOS_None'] 125 | self.pad_token = orig_vocab_size 126 | self.vocab_size = self.pad_token + 1 127 | 128 | def build_dataset(self): 129 | if not self.pieces: 130 | self.pieces = sorted( glob(os.path.join(self.data_dir, '*.pkl')) ) 131 | else: 132 | self.pieces = sorted( [os.path.join(self.data_dir, p) for p in self.pieces] ) 133 | 134 | self.piece_bar_pos = [] 135 | 136 | for i, p in enumerate(self.pieces): 137 | bar_pos, p_evs = pickle_load(p) 138 | if not i % 200: 139 | print ('[preparing data] now at #{}'.format(i)) 140 | if bar_pos[-1] == len(p_evs): 141 | print ('piece {}, got appended bar markers'.format(p)) 142 | bar_pos = bar_pos[:-1] 143 | if len(p_evs) - bar_pos[-1] == 2: 144 | # got empty trailing bar 145 | bar_pos = bar_pos[:-1] 146 | 147 | bar_pos.append(len(p_evs)) 148 | 149 | self.piece_bar_pos.append(bar_pos) 150 | 151 | def get_sample_from_file(self, 
piece_idx): 152 | piece_evs = pickle_load(self.pieces[piece_idx])[1] 153 | if len(self.piece_bar_pos[piece_idx]) > self.model_max_bars and self.appoint_st_bar is None: 154 | picked_st_bar = random.choice( 155 | range(len(self.piece_bar_pos[piece_idx]) - self.model_max_bars) 156 | ) 157 | elif self.appoint_st_bar is not None and self.appoint_st_bar < len(self.piece_bar_pos[piece_idx]) - self.model_max_bars: 158 | picked_st_bar = self.appoint_st_bar 159 | else: 160 | picked_st_bar = 0 161 | 162 | piece_bar_pos = self.piece_bar_pos[piece_idx] 163 | 164 | if len(piece_bar_pos) > self.model_max_bars: 165 | piece_evs = piece_evs[ piece_bar_pos[picked_st_bar] : piece_bar_pos[picked_st_bar + self.model_max_bars] ] 166 | picked_bar_pos = np.array(piece_bar_pos[ picked_st_bar : picked_st_bar + self.model_max_bars ]) - piece_bar_pos[picked_st_bar] 167 | n_bars = self.model_max_bars 168 | else: 169 | picked_bar_pos = np.array(piece_bar_pos + [piece_bar_pos[-1]] * (self.model_max_bars - len(piece_bar_pos))) 170 | n_bars = len(piece_bar_pos) 171 | assert len(picked_bar_pos) == self.model_max_bars 172 | 173 | return piece_evs, picked_st_bar, picked_bar_pos, n_bars 174 | 175 | def pad_sequence(self, seq, maxlen, pad_value=None): 176 | if pad_value is None: 177 | pad_value = self.pad_token 178 | 179 | seq.extend( [pad_value for _ in range(maxlen- len(seq))] ) 180 | 181 | return seq 182 | 183 | def pitch_augment(self, bar_events): 184 | bar_min_pitch, bar_max_pitch = check_extreme_pitch(bar_events) 185 | 186 | n_keys = random.choice(self.augment_range) 187 | while bar_min_pitch + n_keys < self.min_pitch or bar_max_pitch + n_keys > self.max_pitch: 188 | n_keys = random.choice(self.augment_range) 189 | 190 | augmented_bar_events = transpose_events(bar_events, n_keys) 191 | return augmented_bar_events 192 | 193 | def get_attr_classes(self, piece, st_bar): 194 | polyph_cls = pickle_load(os.path.join(self.data_dir, 'attr_cls/polyph', piece))[st_bar : st_bar + self.model_max_bars] 195 | rfreq_cls = pickle_load(os.path.join(self.data_dir, 'attr_cls/rhythm', piece))[st_bar : st_bar + self.model_max_bars] 196 | 197 | polyph_cls.extend([0 for _ in range(self.model_max_bars - len(polyph_cls))]) 198 | rfreq_cls.extend([0 for _ in range(self.model_max_bars - len(rfreq_cls))]) 199 | 200 | assert len(polyph_cls) == self.model_max_bars 201 | assert len(rfreq_cls) == self.model_max_bars 202 | 203 | return polyph_cls, rfreq_cls 204 | 205 | def get_encoder_input_data(self, bar_positions, bar_events): 206 | assert len(bar_positions) == self.model_max_bars + 1 207 | enc_padding_mask = np.ones((self.model_max_bars, self.model_enc_seqlen), dtype=bool) 208 | enc_padding_mask[:, :2] = False 209 | padded_enc_input = np.full((self.model_max_bars, self.model_enc_seqlen), dtype=int, fill_value=self.pad_token) 210 | enc_lens = np.zeros((self.model_max_bars,)) 211 | 212 | for b, (st, ed) in enumerate(zip(bar_positions[:-1], bar_positions[1:])): 213 | enc_padding_mask[b, : (ed-st)] = False 214 | enc_lens[b] = ed - st 215 | within_bar_events = self.pad_sequence(bar_events[st : ed], self.model_enc_seqlen, self.pad_token) 216 | within_bar_events = np.array(within_bar_events) 217 | 218 | padded_enc_input[b, :] = within_bar_events[:self.model_enc_seqlen] 219 | 220 | return padded_enc_input, enc_padding_mask, enc_lens 221 | 222 | def __len__(self): 223 | return len(self.pieces) 224 | 225 | def __getitem__(self, idx): 226 | if torch.is_tensor(idx): 227 | idx = idx.tolist() 228 | 229 | bar_events, st_bar, bar_pos, enc_n_bars = 
self.get_sample_from_file(idx) 230 | if self.do_augment: 231 | bar_events = self.pitch_augment(bar_events) 232 | 233 | if self.use_attr_cls: 234 | polyph_cls, rfreq_cls = self.get_attr_classes(os.path.basename(self.pieces[idx]), st_bar) 235 | polyph_cls_expanded = np.zeros((self.model_dec_seqlen,), dtype=int) 236 | rfreq_cls_expanded = np.zeros((self.model_dec_seqlen,), dtype=int) 237 | for i, (b_st, b_ed) in enumerate(zip(bar_pos[:-1], bar_pos[1:])): 238 | polyph_cls_expanded[b_st:b_ed] = polyph_cls[i] 239 | rfreq_cls_expanded[b_st:b_ed] = rfreq_cls[i] 240 | else: 241 | polyph_cls, rfreq_cls = [0], [0] 242 | polyph_cls_expanded, rfreq_cls_expanded = [0], [0] 243 | 244 | bar_tokens = convert_event(bar_events, self.event2idx, to_ndarr=False) 245 | bar_pos = bar_pos.tolist() + [len(bar_tokens)] 246 | 247 | enc_inp, enc_padding_mask, enc_lens = self.get_encoder_input_data(bar_pos, bar_tokens) 248 | 249 | length = len(bar_tokens) 250 | if self.pad_to_same: 251 | inp = self.pad_sequence(bar_tokens, self.model_dec_seqlen + 1) 252 | else: 253 | inp = self.pad_sequence(bar_tokens, len(bar_tokens) + 1, pad_value=self.dec_end_pad_value) 254 | target = np.array(inp[1:], dtype=int) 255 | inp = np.array(inp[:-1], dtype=int) 256 | assert len(inp) == len(target) 257 | 258 | return { 259 | 'id': idx, 260 | 'piece_id': int(os.path.basename(self.pieces[idx]).replace('.pkl', '')), 261 | 'st_bar_id': st_bar, 262 | 'bar_pos': np.array(bar_pos, dtype=int), 263 | 'enc_input': enc_inp, 264 | 'dec_input': inp[:self.model_dec_seqlen], 265 | 'dec_target': target[:self.model_dec_seqlen], 266 | 'polyph_cls': polyph_cls_expanded, 267 | 'rhymfreq_cls': rfreq_cls_expanded, 268 | 'polyph_cls_bar': np.array(polyph_cls), 269 | 'rhymfreq_cls_bar': np.array(rfreq_cls), 270 | 'length': min(length, self.model_dec_seqlen), 271 | 'enc_padding_mask': enc_padding_mask, 272 | 'enc_length': enc_lens, 273 | 'enc_n_bars': enc_n_bars 274 | } 275 | 276 | if __name__ == "__main__": 277 | # codes below are for unit test 278 | dset = REMIFullSongTransformerDataset( 279 | './remi_dataset', './pickles/remi_vocab.pkl', do_augment=True, use_attr_cls=True, 280 | model_max_bars=16, model_dec_seqlen=1280, model_enc_seqlen=128, min_pitch=22, max_pitch=107 281 | ) 282 | print (dset.bar_token, dset.pad_token, dset.vocab_size) 283 | print ('length:', len(dset)) 284 | 285 | # for i in random.sample(range(len(dset)), 100): 286 | # for i in range(len(dset)): 287 | # sample = dset[i] 288 | # print (i, len(sample['bar_pos']), sample['bar_pos']) 289 | # print (i) 290 | # print ('******* ----------- *******') 291 | # print ('piece: {}, st_bar: {}'.format(sample['piece_id'], sample['st_bar_id'])) 292 | # print (sample['enc_input'][:8, :16]) 293 | # print (sample['dec_input'][:16]) 294 | # print (sample['dec_target'][:16]) 295 | # print (sample['enc_padding_mask'][:32, :16]) 296 | # print (sample['length']) 297 | 298 | dloader = DataLoader(dset, batch_size=4, shuffle=False, num_workers=24) 299 | for i, batch in enumerate(dloader): 300 | for k, v in batch.items(): 301 | if torch.is_tensor(v): 302 | print (k, ':', v.dtype, v.size()) 303 | print ('=====================================\n') 304 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import sys, os, random, time 2 | from copy import deepcopy 3 | sys.path.append('./model') 4 | 5 | from dataloader import REMIFullSongTransformerDataset 6 | from model.musemorphose import 
MuseMorphose 7 | 8 | from utils import pickle_load, numpy_to_tensor, tensor_to_numpy 9 | from remi2midi import remi2midi 10 | 11 | import torch 12 | import yaml 13 | import numpy as np 14 | from scipy.stats import entropy 15 | 16 | config_path = sys.argv[1] 17 | config = yaml.load(open(config_path, 'r'), Loader=yaml.FullLoader) 18 | 19 | device = config['training']['device'] 20 | data_dir = config['data']['data_dir'] 21 | vocab_path = config['data']['vocab_path'] 22 | data_split = 'pickles/test_pieces.pkl' 23 | 24 | ckpt_path = sys.argv[2] 25 | out_dir = sys.argv[3] 26 | n_pieces = int(sys.argv[4]) 27 | n_samples_per_piece = int(sys.argv[5]) 28 | 29 | ########################################### 30 | # little helpers 31 | ########################################### 32 | def word2event(word_seq, idx2event): 33 | return [ idx2event[w] for w in word_seq ] 34 | 35 | def get_beat_idx(event): 36 | return int(event.split('_')[-1]) 37 | 38 | ########################################### 39 | # sampling utilities 40 | ########################################### 41 | def temperatured_softmax(logits, temperature): 42 | try: 43 | probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature)) 44 | assert np.count_nonzero(np.isnan(probs)) == 0 45 | except: 46 | print ('overflow detected, use 128-bit') 47 | logits = logits.astype(np.float128) 48 | probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature)) 49 | probs = probs.astype(float) 50 | return probs 51 | 52 | def nucleus(probs, p): 53 | probs /= sum(probs) 54 | sorted_probs = np.sort(probs)[::-1] 55 | sorted_index = np.argsort(probs)[::-1] 56 | cusum_sorted_probs = np.cumsum(sorted_probs) 57 | after_threshold = cusum_sorted_probs > p 58 | if sum(after_threshold) > 0: 59 | last_index = np.where(after_threshold)[0][1] 60 | candi_index = sorted_index[:last_index] 61 | else: 62 | candi_index = sorted_index[:3] # just assign a value 63 | candi_probs = np.array([probs[i] for i in candi_index], dtype=np.float64) 64 | candi_probs /= sum(candi_probs) 65 | word = np.random.choice(candi_index, size=1, p=candi_probs)[0] 66 | return word 67 | 68 | ######################################## 69 | # generation 70 | ######################################## 71 | def get_latent_embedding_fast(model, piece_data, use_sampling=False, sampling_var=0.): 72 | # reshape 73 | batch_inp = piece_data['enc_input'].permute(1, 0).long().to(device) 74 | batch_padding_mask = piece_data['enc_padding_mask'].bool().to(device) 75 | 76 | # get latent conditioning vectors 77 | with torch.no_grad(): 78 | piece_latents = model.get_sampled_latent( 79 | batch_inp, padding_mask=batch_padding_mask, 80 | use_sampling=use_sampling, sampling_var=sampling_var 81 | ) 82 | 83 | return piece_latents 84 | 85 | def generate_on_latent_ctrl_vanilla_truncate( 86 | model, latents, rfreq_cls, polyph_cls, event2idx, idx2event, 87 | max_events=12800, primer=None, 88 | max_input_len=1280, truncate_len=512, 89 | nucleus_p=0.9, temperature=1.2 90 | ): 91 | latent_placeholder = torch.zeros(max_events, 1, latents.size(-1)).to(device) 92 | rfreq_placeholder = torch.zeros(max_events, 1, dtype=int).to(device) 93 | polyph_placeholder = torch.zeros(max_events, 1, dtype=int).to(device) 94 | print ('[info] rhythm cls: {} | polyph_cls: {}'.format(rfreq_cls, polyph_cls)) 95 | 96 | if primer is None: 97 | generated = [event2idx['Bar_None']] 98 | else: 99 | generated = [event2idx[e] for e in primer] 100 | latent_placeholder[:len(generated), 0, :] = latents[0].squeeze(0) 101 | 
rfreq_placeholder[:len(generated), 0] = rfreq_cls[0] 102 | polyph_placeholder[:len(generated), 0] = polyph_cls[0] 103 | 104 | target_bars, generated_bars = latents.size(0), 0 105 | 106 | steps = 0 107 | time_st = time.time() 108 | cur_pos = 0 109 | failed_cnt = 0 110 | 111 | cur_input_len = len(generated) 112 | generated_final = deepcopy(generated) 113 | entropies = [] 114 | 115 | while generated_bars < target_bars: 116 | if len(generated) == 1: 117 | dec_input = numpy_to_tensor([generated], device=device).long() 118 | else: 119 | dec_input = numpy_to_tensor([generated], device=device).permute(1, 0).long() 120 | 121 | latent_placeholder[len(generated)-1, 0, :] = latents[ generated_bars ] 122 | rfreq_placeholder[len(generated)-1, 0] = rfreq_cls[ generated_bars ] 123 | polyph_placeholder[len(generated)-1, 0] = polyph_cls[ generated_bars ] 124 | 125 | dec_seg_emb = latent_placeholder[:len(generated), :] 126 | dec_rfreq_cls = rfreq_placeholder[:len(generated), :] 127 | dec_polyph_cls = polyph_placeholder[:len(generated), :] 128 | 129 | # sampling 130 | with torch.no_grad(): 131 | logits = model.generate(dec_input, dec_seg_emb, dec_rfreq_cls, dec_polyph_cls) 132 | logits = tensor_to_numpy(logits[0]) 133 | probs = temperatured_softmax(logits, temperature) 134 | word = nucleus(probs, nucleus_p) 135 | word_event = idx2event[word] 136 | 137 | if 'Beat' in word_event: 138 | event_pos = get_beat_idx(word_event) 139 | if not event_pos >= cur_pos: 140 | failed_cnt += 1 141 | print ('[info] position not increasing, failed cnt:', failed_cnt) 142 | if failed_cnt >= 128: 143 | print ('[FATAL] model stuck, exiting ...') 144 | return generated 145 | continue 146 | else: 147 | cur_pos = event_pos 148 | failed_cnt = 0 149 | 150 | if 'Bar' in word_event: 151 | generated_bars += 1 152 | cur_pos = 0 153 | print ('[info] generated {} bars, #events = {}'.format(generated_bars, len(generated_final))) 154 | if word_event == 'PAD_None': 155 | continue 156 | 157 | if len(generated) > max_events or (word_event == 'EOS_None' and generated_bars == target_bars - 1): 158 | generated_bars += 1 159 | generated.append(event2idx['Bar_None']) 160 | print ('[info] gotten eos') 161 | break 162 | 163 | generated.append(word) 164 | generated_final.append(word) 165 | entropies.append(entropy(probs)) 166 | 167 | cur_input_len += 1 168 | steps += 1 169 | 170 | assert cur_input_len == len(generated) 171 | if cur_input_len == max_input_len: 172 | generated = generated[-truncate_len:] 173 | latent_placeholder[:len(generated)-1, 0, :] = latent_placeholder[cur_input_len-truncate_len:cur_input_len-1, 0, :] 174 | rfreq_placeholder[:len(generated)-1, 0] = rfreq_placeholder[cur_input_len-truncate_len:cur_input_len-1, 0] 175 | polyph_placeholder[:len(generated)-1, 0] = polyph_placeholder[cur_input_len-truncate_len:cur_input_len-1, 0] 176 | 177 | print ('[info] reset context length: cur_len: {}, accumulated_len: {}, truncate_range: {} ~ {}'.format( 178 | cur_input_len, len(generated_final), cur_input_len-truncate_len, cur_input_len-1 179 | )) 180 | cur_input_len = len(generated) 181 | 182 | assert generated_bars == target_bars 183 | print ('-- generated events:', len(generated_final)) 184 | print ('-- time elapsed: {:.2f} secs'.format(time.time() - time_st)) 185 | return generated_final[:-1], time.time() - time_st, np.array(entropies) 186 | 187 | 188 | ######################################## 189 | # change attribute classes 190 | ######################################## 191 | def random_shift_attr_cls(n_samples, upper=4, lower=-3): 192 | return 
np.random.randint(lower, upper, (n_samples,)) 193 | 194 | 195 | if __name__ == "__main__": 196 | dset = REMIFullSongTransformerDataset( 197 | data_dir, vocab_path, 198 | do_augment=False, 199 | model_enc_seqlen=config['data']['enc_seqlen'], 200 | model_dec_seqlen=config['generate']['dec_seqlen'], 201 | model_max_bars=config['generate']['max_bars'], 202 | pieces=pickle_load(data_split), 203 | pad_to_same=False 204 | ) 205 | pieces = random.sample(range(len(dset)), n_pieces) 206 | print ('[sampled pieces]', pieces) 207 | 208 | mconf = config['model'] 209 | model = MuseMorphose( 210 | mconf['enc_n_layer'], mconf['enc_n_head'], mconf['enc_d_model'], mconf['enc_d_ff'], 211 | mconf['dec_n_layer'], mconf['dec_n_head'], mconf['dec_d_model'], mconf['dec_d_ff'], 212 | mconf['d_latent'], mconf['d_embed'], dset.vocab_size, 213 | d_polyph_emb=mconf['d_polyph_emb'], d_rfreq_emb=mconf['d_rfreq_emb'], 214 | cond_mode=mconf['cond_mode'] 215 | ).to(device) 216 | model.eval() 217 | model.load_state_dict(torch.load(ckpt_path, map_location='cpu')) 218 | 219 | if not os.path.exists(out_dir): 220 | os.makedirs(out_dir) 221 | 222 | times = [] 223 | for p in pieces: 224 | # fetch test sample 225 | p_data = dset[p] 226 | p_id = p_data['piece_id'] 227 | p_bar_id = p_data['st_bar_id'] 228 | p_data['enc_input'] = p_data['enc_input'][ : p_data['enc_n_bars'] ] 229 | p_data['enc_padding_mask'] = p_data['enc_padding_mask'][ : p_data['enc_n_bars'] ] 230 | 231 | orig_p_cls_str = ''.join(str(c) for c in p_data['polyph_cls_bar']) 232 | orig_r_cls_str = ''.join(str(c) for c in p_data['rhymfreq_cls_bar']) 233 | 234 | orig_song = p_data['dec_input'].tolist()[:p_data['length']] 235 | orig_song = word2event(orig_song, dset.idx2event) 236 | orig_out_file = os.path.join(out_dir, 'id{}_bar{}_orig'.format( 237 | p, p_bar_id 238 | )) 239 | print ('[info] writing to ...', orig_out_file) 240 | # output reference song's MIDI 241 | _, orig_tempo = remi2midi(orig_song, orig_out_file + '.mid', return_first_tempo=True, enforce_tempo=False) 242 | 243 | # save metadata of reference song (events & attr classes) 244 | print (*orig_song, sep='\n', file=open(orig_out_file + '.txt', 'a')) 245 | np.save(orig_out_file + '-POLYCLS.npy', p_data['polyph_cls_bar']) 246 | np.save(orig_out_file + '-RHYMCLS.npy', p_data['rhymfreq_cls_bar']) 247 | 248 | 249 | for k in p_data.keys(): 250 | if not torch.is_tensor(p_data[k]): 251 | p_data[k] = numpy_to_tensor(p_data[k], device=device) 252 | else: 253 | p_data[k] = p_data[k].to(device) 254 | 255 | p_latents = get_latent_embedding_fast( 256 | model, p_data, 257 | use_sampling=config['generate']['use_latent_sampling'], 258 | sampling_var=config['generate']['latent_sampling_var'] 259 | ) 260 | p_cls_diff = random_shift_attr_cls(n_samples_per_piece) 261 | r_cls_diff = random_shift_attr_cls(n_samples_per_piece) 262 | 263 | piece_entropies = [] 264 | for samp in range(n_samples_per_piece): 265 | p_polyph_cls = (p_data['polyph_cls_bar'] + p_cls_diff[samp]).clamp(0, 7).long() 266 | p_rfreq_cls = (p_data['rhymfreq_cls_bar'] + r_cls_diff[samp]).clamp(0, 7).long() 267 | 268 | print ('[info] piece: {}, bar: {}'.format(p_id, p_bar_id)) 269 | out_file = os.path.join(out_dir, 'id{}_bar{}_sample{:02d}_poly{}_rhym{}'.format( 270 | p, p_bar_id, samp + 1, 271 | '+{}'.format(p_cls_diff[samp]) if p_cls_diff[samp] >= 0 else p_cls_diff[samp], 272 | '+{}'.format(r_cls_diff[samp]) if r_cls_diff[samp] >= 0 else r_cls_diff[samp] 273 | )) 274 | print ('[info] writing to ...', out_file) 275 | if os.path.exists(out_file + '.txt'): 276 | print 
('[info] file exists, skipping ...') 277 | continue 278 | 279 | # print (p_polyph_cls, p_rfreq_cls) 280 | 281 | # generate 282 | song, t_sec, entropies = generate_on_latent_ctrl_vanilla_truncate( 283 | model, p_latents, p_rfreq_cls, p_polyph_cls, dset.event2idx, dset.idx2event, 284 | max_input_len=config['generate']['max_input_dec_seqlen'], 285 | truncate_len=min(512, config['generate']['max_input_dec_seqlen'] - 32), 286 | nucleus_p=config['generate']['nucleus_p'], 287 | temperature=config['generate']['temperature'], 288 | 289 | ) 290 | times.append(t_sec) 291 | 292 | song = word2event(song, dset.idx2event) 293 | print (*song, sep='\n', file=open(out_file + '.txt', 'a')) 294 | remi2midi(song, out_file + '.mid', enforce_tempo=True, enforce_tempo_val=orig_tempo) 295 | 296 | # save metadata of the generation 297 | np.save(out_file + '-POLYCLS.npy', tensor_to_numpy(p_polyph_cls)) 298 | np.save(out_file + '-RHYMCLS.npy', tensor_to_numpy(p_rfreq_cls)) 299 | print ('[info] piece entropy: {:.4f} (+/- {:.4f})'.format( 300 | entropies.mean(), entropies.std() 301 | )) 302 | piece_entropies.append(entropies.mean()) 303 | 304 | print ('[time stats] {} songs, generation time: {:.2f} secs (+/- {:.2f})'.format( 305 | n_pieces * n_samples_per_piece, np.mean(times), np.std(times) 306 | )) 307 | print ('[entropy] {:.4f} (+/- {:.4f})'.format( 308 | np.mean(piece_entropies), np.std(piece_entropies) 309 | )) --------------------------------------------------------------------------------
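A minimal usage sketch of the `remi2midi()` converter defined above, for readers who want to turn an event list into MIDI without running the full generation pipeline. The event list and output filename here are illustrative, not drawn from the dataset; the event strings follow the `Name_Value` convention parsed by `ConversionEvent`:

```python
from remi2midi import remi2midi

# one bar holding a single quarter note (C4, pitch 60) on beat 0 at 110 BPM;
# 480 ticks = one beat, per DEFAULT_BEAT_RESOL in remi2midi.py
events = [
  'Bar_None', 'Beat_0', 'Tempo_110',
  'Note_Pitch_60', 'Note_Velocity_80', 'Note_Duration_480',
  'Bar_None',
]
remi2midi(events, output_midi_path='example.mid')  # writes a single-track piano MIDI file
```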