├── requirements.txt
├── pickles
│   ├── remi_vocab.pkl
│   ├── test_pieces.pkl
│   ├── train_pieces.pkl
│   └── val_pieces.pkl
├── .gitignore
├── utils.py
├── model
│   ├── transformer_encoder.py
│   ├── transformer_helpers.py
│   └── musemorphose.py
├── LICENSE
├── config
│   └── default.yaml
├── attributes.py
├── README.md
├── remi2midi.py
├── train.py
├── dataloader.py
└── generate.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.6.0
2 | miditoolkit
3 | PyYAML
4 | scipy
5 | numpy
--------------------------------------------------------------------------------
/pickles/remi_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/MuseMorphose/HEAD/pickles/remi_vocab.pkl
--------------------------------------------------------------------------------
/pickles/test_pieces.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/MuseMorphose/HEAD/pickles/test_pieces.pkl
--------------------------------------------------------------------------------
/pickles/train_pieces.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/MuseMorphose/HEAD/pickles/train_pieces.pkl
--------------------------------------------------------------------------------
/pickles/val_pieces.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/MuseMorphose/HEAD/pickles/val_pieces.pkl
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/*.tar.gz
2 | *dataset*
3 | **/__pycache__/
4 | **/.vscode/
5 | **/*ckpt*/
6 | **/*.pt
7 | **/gens/
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pickle
3 | from scipy.spatial import distance
4 |
5 | # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
6 | # device = torch.device('cpu')
7 |
8 | def numpy_to_tensor(arr, use_gpu=True, device='cuda:0'):
9 | if use_gpu:
10 | return torch.tensor(arr).to(device).float()
11 | else:
12 | return torch.tensor(arr).float()
13 |
14 | def tensor_to_numpy(tensor):
15 | return tensor.cpu().detach().numpy()
16 |
17 | def pickle_load(f):
18 | return pickle.load(open(f, 'rb'))
19 |
20 | def pickle_dump(obj, f):
21 | pickle.dump(obj, open(f, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
--------------------------------------------------------------------------------
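
A quick usage sketch for the helpers above (illustrative, not part of the repo; `use_gpu=False` keeps it runnable without CUDA):

```python
import numpy as np
from utils import numpy_to_tensor, tensor_to_numpy, pickle_dump, pickle_load

arr = np.arange(6).reshape(2, 3)
t = numpy_to_tensor(arr, use_gpu=False)      # float32 tensor on CPU
assert (tensor_to_numpy(t) == arr).all()     # round-trip back to numpy

pickle_dump({'a': 1}, '/tmp/demo.pkl')       # round-trip through the pickle helpers
assert pickle_load('/tmp/demo.pkl') == {'a': 1}
```
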
/model/transformer_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | class VAETransformerEncoder(nn.Module):
6 | def __init__(self, n_layer, n_head, d_model, d_ff, d_vae_latent, dropout=0.1, activation='relu'):
7 | super(VAETransformerEncoder, self).__init__()
8 | self.n_layer = n_layer
9 | self.n_head = n_head
10 | self.d_model = d_model
11 | self.d_ff = d_ff
12 | self.d_vae_latent = d_vae_latent
13 | self.dropout = dropout
14 | self.activation = activation
15 |
16 | self.tr_encoder_layer = nn.TransformerEncoderLayer(
17 | d_model, n_head, d_ff, dropout, activation
18 | )
19 | self.tr_encoder = nn.TransformerEncoder(
20 | self.tr_encoder_layer, n_layer
21 | )
22 |
23 | self.fc_mu = nn.Linear(d_model, d_vae_latent)
24 | self.fc_logvar = nn.Linear(d_model, d_vae_latent)
25 |
26 | def forward(self, x, padding_mask=None):
27 | out = self.tr_encoder(x, src_key_padding_mask=padding_mask)
28 | hidden_out = out[0, :, :]
29 | mu, logvar = self.fc_mu(hidden_out), self.fc_logvar(hidden_out)
30 |
31 | return hidden_out, mu, logvar
32 |
--------------------------------------------------------------------------------
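
A shape sketch for the encoder above (illustrative, not repo code): the input is time-major `(seqlen, batch, d_model)`; the output at position 0 acts as the sequence summary, which `fc_mu`/`fc_logvar` project into the VAE latent space.

```python
import sys, torch
sys.path.append('./model')                    # mirrors the import setup used by train.py
from transformer_encoder import VAETransformerEncoder

enc = VAETransformerEncoder(n_layer=2, n_head=4, d_model=64, d_ff=128, d_vae_latent=16)
x = torch.randn(32, 8, 64)                    # (seqlen, batch, d_model)
hidden, mu, logvar = enc(x)
print(hidden.shape, mu.shape, logvar.shape)   # (8, 64) (8, 16) (8, 16)
```
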
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Shih-Lun Wu, Yi-Hsuan Yang, and Taiwan AI Labs
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/config/default.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: ./remi_dataset
3 | train_split: ./pickles/train_pieces.pkl
4 | val_split: ./pickles/val_pieces.pkl
5 | test_split: ./pickles/test_pieces.pkl
6 | vocab_path: ./pickles/remi_vocab.pkl
7 | max_bars: 16
8 | enc_seqlen: 128
9 | dec_seqlen: 1280
10 | batch_size: 4
11 |
12 | model:
13 | enc_n_layer: 12
14 | enc_n_head: 8
15 | enc_d_model: 512
16 | enc_d_ff: 2048
17 | dec_n_layer: 12
18 | dec_n_head: 8
19 | dec_d_model: 512
20 | dec_d_ff: 2048
21 | d_embed: 512
22 | d_latent: 128
23 | d_polyph_emb: 64
24 | d_rfreq_emb: 64
25 | cond_mode: in-attn
26 | pretrained_params_path: null
27 | pretrained_optim_path: null
28 |
29 | training:
30 | device: cuda:0
31 | ckpt_dir: ./ckpt/enc_dec_12L-16_bars-seqlen_1280
32 | trained_steps: 0
33 | max_epochs: 1000
34 | max_lr: 1.0e-4
35 | min_lr: 5.0e-6
36 | lr_warmup_steps: 200
37 | lr_decay_steps: 150000
38 | no_kl_steps: 10000
39 | kl_cycle_steps: 5000
40 | kl_max_beta: 1.0
41 | free_bit_lambda: 0.25
42 | constant_kl: False
43 | ckpt_interval: 50
44 | log_interval: 10
45 | val_interval: 50
46 |
47 | generate:
48 | temperature: 1.2
49 | nucleus_p: 0.9
50 | use_latent_sampling: False
51 | latent_sampling_var: 0.0
52 | max_bars: 16 # could be set to match the longest input piece during generation (inference)
53 | dec_seqlen: 1280 # could be set to match the longest input piece during generation (inference)
54 |   max_input_dec_seqlen: 1024 # should be set equal to, or less than, the `dec_seqlen` used during training
--------------------------------------------------------------------------------
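
The config is consumed with PyYAML exactly as `train.py` and `generate.py` do; a minimal sketch of reading the nested keys:

```python
import yaml

config = yaml.load(open('config/default.yaml', 'r'), Loader=yaml.FullLoader)
print(config['model']['enc_n_layer'])    # 12
print(config['training']['max_lr'])      # 0.0001 (PyYAML parses the signed-exponent form as a float)
print(config['generate']['nucleus_p'])   # 0.9
```
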
/attributes.py:
--------------------------------------------------------------------------------
1 | from utils import pickle_load
2 |
3 | import os, pickle
4 | import numpy as np
5 | from collections import Counter
6 |
7 | data_dir = 'remi_dataset'
8 | polyph_out_dir = 'remi_dataset/attr_cls/polyph'
9 | rhythm_out_dir = 'remi_dataset/attr_cls/rhythm'
10 |
11 | rhym_intensity_bounds = [0.2, 0.25, 0.32, 0.38, 0.44, 0.5, 0.63]
12 | polyphonicity_bounds = [2.63, 3.06, 3.50, 4.00, 4.63, 5.44, 6.44]
13 |
14 | def compute_polyphonicity(events, n_bars):
15 | poly_record = np.zeros( (n_bars * 16,) )
16 |
17 | cur_bar, cur_pos = -1, -1
18 | for ev in events:
19 | if ev['name'] == 'Bar':
20 | cur_bar += 1
21 | elif ev['name'] == 'Beat':
22 | cur_pos = int(ev['value'])
23 | elif ev['name'] == 'Note_Duration':
24 | duration = int(ev['value']) // 120
25 | st = cur_bar * 16 + cur_pos
26 | poly_record[st:st + duration] += 1
27 |
28 | return poly_record
29 |
30 | def get_onsets_timing(events, n_bars):
31 | onset_record = np.zeros( (n_bars * 16,) )
32 |
33 | cur_bar, cur_pos = -1, -1
34 | for ev in events:
35 | if ev['name'] == 'Bar':
36 | cur_bar += 1
37 | elif ev['name'] == 'Beat':
38 | cur_pos = int(ev['value'])
39 | elif ev['name'] == 'Note_Pitch':
40 | rec_idx = cur_bar * 16 + cur_pos
41 | onset_record[ rec_idx ] = 1
42 |
43 | return onset_record
44 |
45 | if __name__ == "__main__":
46 | pieces = [p for p in sorted(os.listdir(data_dir)) if '.pkl' in p]
47 | all_r_cls = []
48 | all_p_cls = []
49 |
50 | if not os.path.exists(polyph_out_dir):
51 | os.makedirs(polyph_out_dir)
52 | if not os.path.exists(rhythm_out_dir):
53 | os.makedirs(rhythm_out_dir)
54 |
55 | for p in pieces:
56 | bar_pos, events = pickle_load(os.path.join(data_dir, p))
57 | events = events[ :bar_pos[-1] ]
58 |
59 | polyph_raw = np.reshape(
60 | compute_polyphonicity(events, n_bars=len(bar_pos)), (-1, 16)
61 | )
62 | rhythm_raw = np.reshape(
63 | get_onsets_timing(events, n_bars=len(bar_pos)), (-1, 16)
64 | )
65 |
66 | polyph_cls = np.searchsorted(polyphonicity_bounds, np.mean(polyph_raw, axis=-1)).tolist()
67 | rfreq_cls = np.searchsorted(rhym_intensity_bounds, np.mean(rhythm_raw, axis=-1)).tolist()
68 |
69 | pickle.dump(polyph_cls, open(os.path.join(
70 | polyph_out_dir, p), 'wb'
71 | ))
72 | pickle.dump(rfreq_cls, open(os.path.join(
73 | rhythm_out_dir, p), 'wb'
74 | ))
75 |
76 | all_r_cls.extend(rfreq_cls)
77 | all_p_cls.extend(polyph_cls)
78 |
79 | print ('[rhythm classes]', Counter(all_r_cls))
80 | print ('[polyph classes]', Counter(all_p_cls))
--------------------------------------------------------------------------------
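
A toy walk-through of the two attribute computations above (illustrative event dicts mimicking the REMI format; durations are in ticks at 480 per beat, so `480 // 120 = 4` sixteenth-note steps):

```python
import numpy as np
from attributes import (compute_polyphonicity, get_onsets_timing,
                        polyphonicity_bounds, rhym_intensity_bounds)

# one bar: a two-note chord struck on beat 0, each note lasting a quarter note
events = [
  {'name': 'Bar', 'value': None},
  {'name': 'Beat', 'value': 0},
  {'name': 'Note_Pitch', 'value': 60}, {'name': 'Note_Duration', 'value': 480},
  {'name': 'Note_Pitch', 'value': 64}, {'name': 'Note_Duration', 'value': 480},
]
poly = compute_polyphonicity(events, n_bars=1)   # 2 active notes on steps 0-3, else 0
onset = get_onsets_timing(events, n_bars=1)      # onset flag only at step 0
print(np.searchsorted(polyphonicity_bounds, poly.reshape(-1, 16).mean(-1)))    # bar-level class: [0]
print(np.searchsorted(rhym_intensity_bounds, onset.reshape(-1, 16).mean(-1)))  # bar-level class: [0]
```
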
/README.md:
--------------------------------------------------------------------------------
1 | # MuseMorphose
2 |
3 | This repository contains the official implementation of the following paper:
4 |
5 | * Shih-Lun Wu, Yi-Hsuan Yang
6 | **_MuseMorphose_: Full-Song and Fine-Grained Piano Music Style Transfer with One Transformer VAE**
7 | accepted to _IEEE/ACM Trans. Audio, Speech, & Language Processing (TASLP)_, Dec 2022 [arXiv] [demo website]
8 |
9 | ## Prerequisites
10 | * Python >= 3.6
11 | * Install dependencies
12 | ```bash
13 | pip3 install -r requirements.txt
14 | ```
15 | * GPU with >6GB RAM (optional, but recommended)
16 |
17 | ## Preprocessing
18 | ```bash
19 | # download REMI-pop-1.7K dataset
20 | wget -O remi_dataset.tar.gz https://zenodo.org/record/4782721/files/remi_dataset.tar.gz?download=1
21 | tar xzvf remi_dataset.tar.gz
22 | rm remi_dataset.tar.gz
23 |
24 | # compute attribute classes
25 | python3 attributes.py
26 | ```
27 |
28 | ## Training
29 | ```bash
30 | python3 train.py [config file]
31 | ```
32 | * e.g.
33 | ```bash
34 | python3 train.py config/default.yaml
35 | ```
36 | * Or, you may download the pretrained weights straight away
37 | ```bash
38 | wget -O musemorphose_pretrained_weights.pt https://zenodo.org/record/5119525/files/musemorphose_pretrained_weights.pt?download=1
39 | ```
40 |
41 | ## Generation
42 | ```bash
43 | python3 generate.py [config file] [ckpt path] [output dir] [num pieces] [num samples per piece]
44 | ```
45 | * e.g.
46 | ```bash
47 | python3 generate.py config/default.yaml musemorphose_pretrained_weights.pt generations/ 10 5
48 | ```
49 |
50 | This script will randomly draw the specified # of pieces from the test set.
51 | For each sample of a piece, the _rhythmic intensity_ and _polyphonicity_ are each shifted by a single random offset within \[-3, 3\] classes, applied across the entire piece, for the model to generate style-transferred music.
52 | You may modify `random_shift_attr_cls()` in `generate.py` or write your own function to set the attributes.
53 |
54 | ## Customized Generation (To Be Added)
55 | We welcome the community's suggestions and contributions for an interface on which users may
56 | * upload their own MIDIs, and
57 | * set their desired bar-level attributes easily
58 |
59 | ## Citation (BibTeX)
60 | If you find this work helpful and use our code in your research, please kindly cite our paper:
61 | ```
62 | @article{wu2023musemorphose,
63 | title={{MuseMorphose}: Full-Song and Fine-Grained Piano Music Style Transfer with One {Transformer VAE}},
64 | author={Shih-Lun Wu and Yi-Hsuan Yang},
65 | year={2023},
66 | journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
67 | }
68 | ```
69 |
--------------------------------------------------------------------------------
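
Apropos the README's pointer to `random_shift_attr_cls()`: a minimal sketch of a custom replacement that applies one fixed shift instead of random ones (hypothetical helper, not in the repo; note that `generate.py` clamps the shifted classes to [0, 7] afterwards):

```python
import numpy as np

def fixed_shift_attr_cls(n_samples, shift=2):
    # same interface as random_shift_attr_cls() in generate.py:
    # one class offset per sample to be generated
    return np.full((n_samples,), shift, dtype=int)
```
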
/model/transformer_helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 | def generate_causal_mask(seq_len):
8 | mask = (torch.triu(torch.ones(seq_len, seq_len)) == 1).transpose(0, 1)
9 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
10 | mask.requires_grad = False
11 | return mask
12 |
13 | def weight_init_normal(weight, normal_std):
14 | nn.init.normal_(weight, 0.0, normal_std)
15 |
16 | def weight_init_orthogonal(weight, gain):
17 | nn.init.orthogonal_(weight, gain)
18 |
19 | def bias_init(bias):
20 | nn.init.constant_(bias, 0.0)
21 |
22 | def weights_init(m):
23 | classname = m.__class__.__name__
24 | # print ('[{}] initializing ...'.format(classname))
25 |
26 | if classname.find('Linear') != -1:
27 | if hasattr(m, 'weight') and m.weight is not None:
28 | weight_init_normal(m.weight, 0.01)
29 | if hasattr(m, 'bias') and m.bias is not None:
30 | bias_init(m.bias)
31 | elif classname.find('Embedding') != -1:
32 | if hasattr(m, 'weight'):
33 | weight_init_normal(m.weight, 0.01)
34 | elif classname.find('LayerNorm') != -1:
35 | if hasattr(m, 'weight'):
36 | nn.init.normal_(m.weight, 1.0, 0.01)
37 | if hasattr(m, 'bias') and m.bias is not None:
38 | bias_init(m.bias)
39 | elif classname.find('GRU') != -1:
40 | for param in m.parameters():
41 | if len(param.shape) >= 2: # weights
42 | weight_init_orthogonal(param, 0.01)
43 | else: # biases
44 | bias_init(param)
45 | # else:
46 | # print ('[{}] not initialized !!'.format(classname))
47 |
48 | class PositionalEncoding(nn.Module):
49 | def __init__(self, d_embed, max_pos=20480):
50 | super(PositionalEncoding, self).__init__()
51 | self.d_embed = d_embed
52 | self.max_pos = max_pos
53 |
54 | pe = torch.zeros(max_pos, d_embed)
55 | position = torch.arange(0, max_pos, dtype=torch.float).unsqueeze(1)
56 | div_term = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed))
57 | pe[:, 0::2] = torch.sin(position * div_term)
58 | pe[:, 1::2] = torch.cos(position * div_term)
59 | pe = pe.unsqueeze(0).transpose(0, 1)
60 | self.register_buffer('pe', pe)
61 |
62 | def forward(self, seq_len, bsz=None):
63 | pos_encoding = self.pe[:seq_len, :]
64 |
65 | if bsz is not None:
66 | pos_encoding = pos_encoding.expand(seq_len, bsz, -1)
67 |
68 | return pos_encoding
69 |
70 | class TokenEmbedding(nn.Module):
71 | def __init__(self, n_token, d_embed, d_proj):
72 | super(TokenEmbedding, self).__init__()
73 |
74 | self.n_token = n_token
75 | self.d_embed = d_embed
76 | self.d_proj = d_proj
77 | self.emb_scale = d_proj ** 0.5
78 |
79 | self.emb_lookup = nn.Embedding(n_token, d_embed)
80 | if d_proj != d_embed:
81 | self.emb_proj = nn.Linear(d_embed, d_proj, bias=False)
82 | else:
83 | self.emb_proj = None
84 |
85 | def forward(self, inp_tokens):
86 | inp_emb = self.emb_lookup(inp_tokens)
87 |
88 | if self.emb_proj is not None:
89 | inp_emb = self.emb_proj(inp_emb)
90 |
91 | return inp_emb.mul_(self.emb_scale)
--------------------------------------------------------------------------------
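
A peek at the two main helpers above (illustrative): `generate_causal_mask` lets position i attend only to positions <= i (`-inf` marks disallowed pairs), and `PositionalEncoding` returns time-major sinusoidal encodings.

```python
import sys, torch
sys.path.append('./model')
from transformer_helpers import generate_causal_mask, PositionalEncoding

print(generate_causal_mask(3))
# tensor([[0., -inf, -inf],
#         [0.,   0., -inf],
#         [0.,   0.,   0.]])

pe = PositionalEncoding(d_embed=8)
print(pe(seq_len=5).shape)           # (5, 1, 8): broadcasts over the batch dim
print(pe(seq_len=5, bsz=4).shape)    # (5, 4, 8)
```
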
/remi2midi.py:
--------------------------------------------------------------------------------
1 | import os, pickle, random, copy
2 | import numpy as np
3 |
4 | import miditoolkit
5 |
6 | ##############################
7 | # constants
8 | ##############################
9 | DEFAULT_BEAT_RESOL = 480
10 | DEFAULT_BAR_RESOL = 480 * 4
11 | DEFAULT_FRACTION = 16
12 |
13 |
14 | ##############################
15 | # containers for conversion
16 | ##############################
17 | class ConversionEvent(object):
18 | def __init__(self, event, is_full_event=False):
19 | if not is_full_event:
20 | if 'Note' in event:
21 | self.name, self.value = '_'.join(event.split('_')[:-1]), event.split('_')[-1]
22 | elif 'Chord' in event:
23 | self.name, self.value = event.split('_')[0], '_'.join(event.split('_')[1:])
24 | else:
25 | self.name, self.value = event.split('_')
26 | else:
27 | self.name, self.value = event['name'], event['value']
28 | def __repr__(self):
29 | return 'Event(name: {} | value: {})'.format(self.name, self.value)
30 |
31 | class NoteEvent(object):
32 | def __init__(self, pitch, bar, position, duration, velocity):
33 | self.pitch = pitch
34 | self.start_tick = bar * DEFAULT_BAR_RESOL + position * (DEFAULT_BAR_RESOL // DEFAULT_FRACTION)
35 | self.duration = duration
36 | self.velocity = velocity
37 |
38 | class TempoEvent(object):
39 | def __init__(self, tempo, bar, position):
40 | self.tempo = tempo
41 | self.start_tick = bar * DEFAULT_BAR_RESOL + position * (DEFAULT_BAR_RESOL // DEFAULT_FRACTION)
42 |
43 | class ChordEvent(object):
44 | def __init__(self, chord_val, bar, position):
45 | self.chord_val = chord_val
46 | self.start_tick = bar * DEFAULT_BAR_RESOL + position * (DEFAULT_BAR_RESOL // DEFAULT_FRACTION)
47 |
48 | ##############################
49 | # conversion functions
50 | ##############################
51 | def read_generated_txt(generated_path):
52 | f = open(generated_path, 'r')
53 | return f.read().splitlines()
54 |
55 | def remi2midi(events, output_midi_path=None, is_full_event=False, return_first_tempo=False, enforce_tempo=False, enforce_tempo_val=None):
56 | events = [ConversionEvent(ev, is_full_event=is_full_event) for ev in events]
57 | # print (events[:20])
58 |
59 | assert events[0].name == 'Bar'
60 | temp_notes = []
61 | temp_tempos = []
62 | temp_chords = []
63 |
64 | cur_bar = 0
65 | cur_position = 0
66 |
67 | for i in range(len(events)):
68 | if events[i].name == 'Bar':
69 | if i > 0:
70 | cur_bar += 1
71 | elif events[i].name == 'Beat':
72 | cur_position = int(events[i].value)
73 | assert cur_position >= 0 and cur_position < DEFAULT_FRACTION
74 | elif events[i].name == 'Tempo':
75 | temp_tempos.append(TempoEvent(
76 | int(events[i].value), cur_bar, cur_position
77 | ))
78 | elif 'Note_Pitch' in events[i].name and \
79 | (i+1) < len(events) and 'Note_Velocity' in events[i+1].name and \
80 | (i+2) < len(events) and 'Note_Duration' in events[i+2].name:
81 | # check if the 3 events are of the same instrument
82 | temp_notes.append(
83 | NoteEvent(
84 | pitch=int(events[i].value),
85 | bar=cur_bar, position=cur_position,
86 | duration=int(events[i+2].value), velocity=int(events[i+1].value)
87 | )
88 | )
89 | elif 'Chord' in events[i].name:
90 | temp_chords.append(
91 | ChordEvent(events[i].value, cur_bar, cur_position)
92 | )
93 | elif events[i].name in ['EOS', 'PAD']:
94 | continue
95 |
96 | # print (len(temp_tempos), len(temp_notes))
97 | midi_obj = miditoolkit.midi.parser.MidiFile()
98 | midi_obj.instruments = [
99 | miditoolkit.Instrument(program=0, is_drum=False, name='Piano')
100 | ]
101 |
102 | for n in temp_notes:
103 | midi_obj.instruments[0].notes.append(
104 | miditoolkit.Note(int(n.velocity), n.pitch, int(n.start_tick), int(n.start_tick + n.duration))
105 | )
106 |
107 | if enforce_tempo is False:
108 | for t in temp_tempos:
109 | midi_obj.tempo_changes.append(
110 | miditoolkit.TempoChange(t.tempo, int(t.start_tick))
111 | )
112 | else:
113 | if enforce_tempo_val is None:
114 |       enforce_tempo_val = temp_tempos[:1]  # keep a (single-element) list so the loop below can iterate
115 | for t in enforce_tempo_val:
116 | midi_obj.tempo_changes.append(
117 | miditoolkit.TempoChange(t.tempo, int(t.start_tick))
118 | )
119 |
120 |
121 | for c in temp_chords:
122 | midi_obj.markers.append(
123 | miditoolkit.Marker('Chord-{}'.format(c.chord_val), int(c.start_tick))
124 | )
125 | for b in range(cur_bar):
126 | midi_obj.markers.append(
127 | miditoolkit.Marker('Bar-{}'.format(b+1), int(DEFAULT_BAR_RESOL * b))
128 | )
129 |
130 | if output_midi_path is not None:
131 | midi_obj.dump(output_midi_path)
132 |
133 | if not return_first_tempo:
134 | return midi_obj
135 | else:
136 | return midi_obj, temp_tempos
--------------------------------------------------------------------------------
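
A minimal conversion sketch (the event strings are hand-written illustrations of the repo's REMI format; tick values follow the 480-ticks-per-beat constants above):

```python
from remi2midi import remi2midi

events = [
  'Bar_None', 'Beat_0', 'Tempo_110',
  'Note_Pitch_60', 'Note_Velocity_80', 'Note_Duration_480',
  'Beat_8',
  'Note_Pitch_64', 'Note_Velocity_80', 'Note_Duration_480',
  'Bar_None',
]
midi = remi2midi(events, output_midi_path='demo.mid')
print(len(midi.instruments[0].notes), 'notes written')   # 2 notes written
```
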
/model/musemorphose.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from transformer_encoder import VAETransformerEncoder
5 | from transformer_helpers import (
6 | weights_init, PositionalEncoding, TokenEmbedding, generate_causal_mask
7 | )
8 |
9 | class VAETransformerDecoder(nn.Module):
10 | def __init__(self, n_layer, n_head, d_model, d_ff, d_seg_emb, dropout=0.1, activation='relu', cond_mode='in-attn'):
11 | super(VAETransformerDecoder, self).__init__()
12 | self.n_layer = n_layer
13 | self.n_head = n_head
14 | self.d_model = d_model
15 | self.d_ff = d_ff
16 | self.d_seg_emb = d_seg_emb
17 | self.dropout = dropout
18 | self.activation = activation
19 | self.cond_mode = cond_mode
20 |
21 | if cond_mode == 'in-attn':
22 | self.seg_emb_proj = nn.Linear(d_seg_emb, d_model, bias=False)
23 | elif cond_mode == 'pre-attn':
24 | self.seg_emb_proj = nn.Linear(d_seg_emb + d_model, d_model, bias=False)
25 |
26 | self.decoder_layers = nn.ModuleList()
27 | for i in range(n_layer):
28 | self.decoder_layers.append(
29 | nn.TransformerEncoderLayer(d_model, n_head, d_ff, dropout, activation)
30 | )
31 |
32 | def forward(self, x, seg_emb):
33 | if not hasattr(self, 'cond_mode'):
34 | self.cond_mode = 'in-attn'
35 | attn_mask = generate_causal_mask(x.size(0)).to(x.device)
36 | # print (attn_mask.size())
37 |
38 | if self.cond_mode == 'in-attn':
39 | seg_emb = self.seg_emb_proj(seg_emb)
40 | elif self.cond_mode == 'pre-attn':
41 | x = torch.cat([x, seg_emb], dim=-1)
42 | x = self.seg_emb_proj(x)
43 |
44 | out = x
45 | for i in range(self.n_layer):
46 | if self.cond_mode == 'in-attn':
47 | out += seg_emb
48 | out = self.decoder_layers[i](out, src_mask=attn_mask)
49 |
50 | return out
51 |
52 | class MuseMorphose(nn.Module):
53 | def __init__(self, enc_n_layer, enc_n_head, enc_d_model, enc_d_ff,
54 | dec_n_layer, dec_n_head, dec_d_model, dec_d_ff,
55 | d_vae_latent, d_embed, n_token,
56 | enc_dropout=0.1, enc_activation='relu',
57 | dec_dropout=0.1, dec_activation='relu',
58 | d_rfreq_emb=32, d_polyph_emb=32,
59 | n_rfreq_cls=8, n_polyph_cls=8,
60 | is_training=True, use_attr_cls=True,
61 | cond_mode='in-attn'
62 | ):
63 | super(MuseMorphose, self).__init__()
64 | self.enc_n_layer = enc_n_layer
65 | self.enc_n_head = enc_n_head
66 | self.enc_d_model = enc_d_model
67 | self.enc_d_ff = enc_d_ff
68 | self.enc_dropout = enc_dropout
69 | self.enc_activation = enc_activation
70 |
71 | self.dec_n_layer = dec_n_layer
72 | self.dec_n_head = dec_n_head
73 | self.dec_d_model = dec_d_model
74 | self.dec_d_ff = dec_d_ff
75 | self.dec_dropout = dec_dropout
76 | self.dec_activation = dec_activation
77 |
78 | self.d_vae_latent = d_vae_latent
79 | self.n_token = n_token
80 | self.is_training = is_training
81 |
82 | self.cond_mode = cond_mode
83 | self.token_emb = TokenEmbedding(n_token, d_embed, enc_d_model)
84 | self.d_embed = d_embed
85 | self.pe = PositionalEncoding(d_embed)
86 | self.dec_out_proj = nn.Linear(dec_d_model, n_token)
87 | self.encoder = VAETransformerEncoder(
88 | enc_n_layer, enc_n_head, enc_d_model, enc_d_ff, d_vae_latent, enc_dropout, enc_activation
89 | )
90 |
91 | self.use_attr_cls = use_attr_cls
92 | if use_attr_cls:
93 | self.decoder = VAETransformerDecoder(
94 | dec_n_layer, dec_n_head, dec_d_model, dec_d_ff, d_vae_latent + d_polyph_emb + d_rfreq_emb,
95 | dropout=dec_dropout, activation=dec_activation,
96 | cond_mode=cond_mode
97 | )
98 | else:
99 | self.decoder = VAETransformerDecoder(
100 | dec_n_layer, dec_n_head, dec_d_model, dec_d_ff, d_vae_latent,
101 | dropout=dec_dropout, activation=dec_activation,
102 | cond_mode=cond_mode
103 | )
104 |
105 | if use_attr_cls:
106 | self.d_rfreq_emb = d_rfreq_emb
107 | self.d_polyph_emb = d_polyph_emb
108 | self.rfreq_attr_emb = TokenEmbedding(n_rfreq_cls, d_rfreq_emb, d_rfreq_emb)
109 | self.polyph_attr_emb = TokenEmbedding(n_polyph_cls, d_polyph_emb, d_polyph_emb)
110 | else:
111 | self.rfreq_attr_emb = None
112 | self.polyph_attr_emb = None
113 |
114 | self.emb_dropout = nn.Dropout(self.enc_dropout)
115 | self.apply(weights_init)
116 |
117 |
118 | def reparameterize(self, mu, logvar, use_sampling=True, sampling_var=1.):
119 | std = torch.exp(0.5 * logvar).to(mu.device)
120 | if use_sampling:
121 | eps = torch.randn_like(std).to(mu.device) * sampling_var
122 | else:
123 | eps = torch.zeros_like(std).to(mu.device)
124 |
125 | return eps * std + mu
126 |
127 | def get_sampled_latent(self, inp, padding_mask=None, use_sampling=False, sampling_var=0.):
128 | token_emb = self.token_emb(inp)
129 | enc_inp = self.emb_dropout(token_emb) + self.pe(inp.size(0))
130 |
131 | _, mu, logvar = self.encoder(enc_inp, padding_mask=padding_mask)
132 | mu, logvar = mu.reshape(-1, mu.size(-1)), logvar.reshape(-1, mu.size(-1))
133 | vae_latent = self.reparameterize(mu, logvar, use_sampling=use_sampling, sampling_var=sampling_var)
134 |
135 | return vae_latent
136 |
137 | def generate(self, inp, dec_seg_emb, rfreq_cls=None, polyph_cls=None, keep_last_only=True):
138 | token_emb = self.token_emb(inp)
139 | dec_inp = self.emb_dropout(token_emb) + self.pe(inp.size(0))
140 |
141 | if rfreq_cls is not None and polyph_cls is not None:
142 | dec_rfreq_emb = self.rfreq_attr_emb(rfreq_cls)
143 | dec_polyph_emb = self.polyph_attr_emb(polyph_cls)
144 | dec_seg_emb_cat = torch.cat([dec_seg_emb, dec_rfreq_emb, dec_polyph_emb], dim=-1)
145 | else:
146 | dec_seg_emb_cat = dec_seg_emb
147 |
148 | out = self.decoder(dec_inp, dec_seg_emb_cat)
149 | out = self.dec_out_proj(out)
150 |
151 | if keep_last_only:
152 | out = out[-1, ...]
153 |
154 | return out
155 |
156 |
157 | def forward(self, enc_inp, dec_inp, dec_inp_bar_pos, rfreq_cls=None, polyph_cls=None, padding_mask=None):
158 | # [shape of enc_inp] (seqlen_per_bar, bsize, n_bars_per_sample)
159 | enc_bt_size, enc_n_bars = enc_inp.size(1), enc_inp.size(2)
160 | enc_token_emb = self.token_emb(enc_inp)
161 |
162 | # [shape of dec_inp] (seqlen_per_sample, bsize)
163 | # [shape of rfreq_cls & polyph_cls] same as above
164 | # -- (should copy each bar's label to all corresponding indices)
165 | dec_token_emb = self.token_emb(dec_inp)
166 |
167 | enc_token_emb = enc_token_emb.reshape(
168 | enc_inp.size(0), -1, enc_token_emb.size(-1)
169 | )
170 | enc_inp = self.emb_dropout(enc_token_emb) + self.pe(enc_inp.size(0))
171 | dec_inp = self.emb_dropout(dec_token_emb) + self.pe(dec_inp.size(0))
172 |
173 | # [shape of padding_mask] (bsize, n_bars_per_sample, seqlen_per_bar)
174 | # -- should be `True` for padded indices (i.e., those >= seqlen of the bar), `False` otherwise
175 | if padding_mask is not None:
176 | padding_mask = padding_mask.reshape(-1, padding_mask.size(-1))
177 |
178 | _, mu, logvar = self.encoder(enc_inp, padding_mask=padding_mask)
179 | vae_latent = self.reparameterize(mu, logvar)
180 | vae_latent_reshaped = vae_latent.reshape(enc_bt_size, enc_n_bars, -1)
181 |
182 | dec_seg_emb = torch.zeros(dec_inp.size(0), dec_inp.size(1), self.d_vae_latent).to(vae_latent.device)
183 | for n in range(dec_inp.size(1)):
184 | # [shape of dec_inp_bar_pos] (bsize, n_bars_per_sample + 1)
185 | # -- stores [[start idx of bar #1, sample #1, ..., start idx of bar #K, sample #1, seqlen of sample #1], [same for another sample], ...]
186 | for b, (st, ed) in enumerate(zip(dec_inp_bar_pos[n, :-1], dec_inp_bar_pos[n, 1:])):
187 | dec_seg_emb[st:ed, n, :] = vae_latent_reshaped[n, b, :]
188 |
189 | if rfreq_cls is not None and polyph_cls is not None and self.use_attr_cls:
190 | dec_rfreq_emb = self.rfreq_attr_emb(rfreq_cls)
191 | dec_polyph_emb = self.polyph_attr_emb(polyph_cls)
192 | dec_seg_emb_cat = torch.cat([dec_seg_emb, dec_rfreq_emb, dec_polyph_emb], dim=-1)
193 | else:
194 | dec_seg_emb_cat = dec_seg_emb
195 |
196 | dec_out = self.decoder(dec_inp, dec_seg_emb_cat)
197 | dec_logits = self.dec_out_proj(dec_out)
198 |
199 | return mu, logvar, dec_logits
200 |
201 | def compute_loss(self, mu, logvar, beta, fb_lambda, dec_logits, dec_tgt):
202 | recons_loss = F.cross_entropy(
203 | dec_logits.view(-1, dec_logits.size(-1)), dec_tgt.contiguous().view(-1),
204 | ignore_index=self.n_token - 1, reduction='mean'
205 | ).float()
206 |
207 | kl_raw = -0.5 * (1 + logvar - mu ** 2 - logvar.exp()).mean(dim=0)
208 | kl_before_free_bits = kl_raw.mean()
209 | kl_after_free_bits = kl_raw.clamp(min=fb_lambda)
210 | kldiv_loss = kl_after_free_bits.mean()
211 |
212 | return {
213 | 'beta': beta,
214 | 'total_loss': recons_loss + beta * kldiv_loss,
215 | 'kldiv_loss': kldiv_loss,
216 | 'kldiv_raw': kl_before_free_bits,
217 | 'recons_loss': recons_loss
218 | }
--------------------------------------------------------------------------------
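
A smoke-test sketch for the model above (not from the repo): scaled-down hyperparameters, random tokens, and the tensor shapes documented in `forward()`'s comments; assumes it is run from the repo root, mirroring `train.py`'s import setup.

```python
import sys, torch
sys.path.append('./model')    # musemorphose.py imports its siblings without the package prefix
from model.musemorphose import MuseMorphose

n_token, bsz, n_bars, enc_len, dec_len = 100, 2, 4, 16, 64
model = MuseMorphose(
  2, 4, 128, 256,             # encoder: n_layer, n_head, d_model, d_ff
  2, 4, 128, 256,             # decoder: n_layer, n_head, d_model, d_ff
  32, 128, n_token            # d_vae_latent, d_embed, n_token
)

enc_inp = torch.randint(n_token, (enc_len, bsz, n_bars))   # (seqlen_per_bar, bsize, n_bars)
dec_inp = torch.randint(n_token, (dec_len, bsz))           # (seqlen_per_sample, bsize)
bar_pos = torch.arange(0, dec_len + 1, dec_len // n_bars).repeat(bsz, 1)  # (bsize, n_bars + 1)
rfreq   = torch.randint(8, (dec_len, bsz))                 # per-timestep attribute classes
polyph  = torch.randint(8, (dec_len, bsz))

with torch.no_grad():
  mu, logvar, logits = model(enc_inp, dec_inp, bar_pos, rfreq, polyph)
print(mu.shape, logvar.shape, logits.shape)   # (8, 32) (8, 32) (64, 2, 100)
```
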
/train.py:
--------------------------------------------------------------------------------
1 | import sys, os, time
2 | sys.path.append('./model')
3 |
4 | from model.musemorphose import MuseMorphose
5 | from dataloader import REMIFullSongTransformerDataset
6 | from torch.utils.data import DataLoader
7 |
8 | from utils import pickle_load
9 | from torch import nn, optim
10 | import torch
11 | import numpy as np
12 |
13 | import yaml
14 | config_path = sys.argv[1]
15 | config = yaml.load(open(config_path, 'r'), Loader=yaml.FullLoader)
16 |
17 | device = config['training']['device']
18 | trained_steps = config['training']['trained_steps']
19 | lr_decay_steps = config['training']['lr_decay_steps']
20 | lr_warmup_steps = config['training']['lr_warmup_steps']
21 | no_kl_steps = config['training']['no_kl_steps']
22 | kl_cycle_steps = config['training']['kl_cycle_steps']
23 | kl_max_beta = config['training']['kl_max_beta']
24 | free_bit_lambda = config['training']['free_bit_lambda']
25 | max_lr, min_lr = config['training']['max_lr'], config['training']['min_lr']
26 |
27 | ckpt_dir = config['training']['ckpt_dir']
28 | params_dir = os.path.join(ckpt_dir, 'params/')
29 | optim_dir = os.path.join(ckpt_dir, 'optim/')
30 | pretrained_params_path = config['model']['pretrained_params_path']
31 | pretrained_optim_path = config['model']['pretrained_optim_path']
32 | ckpt_interval = config['training']['ckpt_interval']
33 | log_interval = config['training']['log_interval']
34 | val_interval = config['training']['val_interval']
35 | constant_kl = config['training']['constant_kl']
36 |
37 | recons_loss_ema = 0.
38 | kl_loss_ema = 0.
39 | kl_raw_ema = 0.
40 |
41 | def log_epoch(log_file, log_data, is_init=False):
42 | if is_init:
43 | with open(log_file, 'w') as f:
44 | f.write('{:4} {:8} {:12} {:12} {:12} {:12}\n'.format('ep', 'steps', 'recons_loss', 'kldiv_loss', 'kldiv_raw', 'ep_time'))
45 |
46 | with open(log_file, 'a') as f:
47 | f.write('{:<4} {:<8} {:<12} {:<12} {:<12} {:<12}\n'.format(
48 | log_data['ep'], log_data['steps'], round(log_data['recons_loss'], 5), round(log_data['kldiv_loss'], 5), round(log_data['kldiv_raw'], 5), round(log_data['time'], 2)
49 | ))
50 |
51 | def beta_cyclical_sched(step):
52 | step_in_cycle = (step - 1) % kl_cycle_steps
53 | cycle_progress = step_in_cycle / kl_cycle_steps
54 |
55 | if step < no_kl_steps:
56 | return 0.
57 | if cycle_progress < 0.5:
58 | return kl_max_beta * cycle_progress * 2.
59 | else:
60 | return kl_max_beta
61 |
62 | def compute_loss_ema(ema, batch_loss, decay=0.95):
63 | if ema == 0.:
64 | return batch_loss
65 | else:
66 | return batch_loss * (1 - decay) + ema * decay
67 |
68 | def train_model(epoch, model, dloader, dloader_val, optim, sched):
69 | model.train()
70 |
71 | print ('[epoch {:03d}] training ...'.format(epoch))
72 | print ('[epoch {:03d}] # batches = {}'.format(epoch, len(dloader)))
73 | st = time.time()
74 |
75 | for batch_idx, batch_samples in enumerate(dloader):
76 | model.zero_grad()
77 | batch_enc_inp = batch_samples['enc_input'].permute(2, 0, 1).to(device)
78 | batch_dec_inp = batch_samples['dec_input'].permute(1, 0).to(device)
79 | batch_dec_tgt = batch_samples['dec_target'].permute(1, 0).to(device)
80 | batch_inp_bar_pos = batch_samples['bar_pos'].to(device)
81 | batch_inp_lens = batch_samples['length']
82 | batch_padding_mask = batch_samples['enc_padding_mask'].to(device)
83 | batch_rfreq_cls = batch_samples['rhymfreq_cls'].permute(1, 0).to(device)
84 | batch_polyph_cls = batch_samples['polyph_cls'].permute(1, 0).to(device)
85 |
86 | global trained_steps
87 | trained_steps += 1
88 |
89 | mu, logvar, dec_logits = model(
90 | batch_enc_inp, batch_dec_inp,
91 | batch_inp_bar_pos, batch_rfreq_cls, batch_polyph_cls,
92 | padding_mask=batch_padding_mask
93 | )
94 |
95 | if not constant_kl:
96 | kl_beta = beta_cyclical_sched(trained_steps)
97 | else:
98 | kl_beta = kl_max_beta
99 | losses = model.compute_loss(mu, logvar, kl_beta, free_bit_lambda, dec_logits, batch_dec_tgt)
100 |
101 | # anneal learning rate
102 | if trained_steps < lr_warmup_steps:
103 | curr_lr = max_lr * trained_steps / lr_warmup_steps
104 | optim.param_groups[0]['lr'] = curr_lr
105 | else:
106 | sched.step(trained_steps - lr_warmup_steps)
107 |
108 | # clip gradient & update model
109 | losses['total_loss'].backward()
110 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
111 | optim.step()
112 |
113 | global recons_loss_ema, kl_loss_ema, kl_raw_ema
114 | recons_loss_ema = compute_loss_ema(recons_loss_ema, losses['recons_loss'].item())
115 | kl_loss_ema = compute_loss_ema(kl_loss_ema, losses['kldiv_loss'].item())
116 | kl_raw_ema = compute_loss_ema(kl_raw_ema, losses['kldiv_raw'].item())
117 |
118 | print (' -- epoch {:03d} | batch {:03d}: len: {}\n\t * loss = (RC: {:.4f} | KL: {:.4f} | KL_raw: {:.4f}), step = {}, beta: {:.4f} time_elapsed = {:.2f} secs'.format(
119 | epoch, batch_idx, batch_inp_lens, recons_loss_ema, kl_loss_ema, kl_raw_ema, trained_steps, kl_beta, time.time() - st
120 | ))
121 |
122 | if not trained_steps % log_interval:
123 | log_data = {
124 | 'ep': epoch,
125 | 'steps': trained_steps,
126 | 'recons_loss': recons_loss_ema,
127 | 'kldiv_loss': kl_loss_ema,
128 | 'kldiv_raw': kl_raw_ema,
129 | 'time': time.time() - st
130 | }
131 | log_epoch(
132 | os.path.join(ckpt_dir, 'log.txt'), log_data, is_init=not os.path.exists(os.path.join(ckpt_dir, 'log.txt'))
133 | )
134 |
135 | if not trained_steps % val_interval:
136 | vallosses = validate(model, dloader_val)
137 | with open(os.path.join(ckpt_dir, 'valloss.txt'), 'a') as f:
138 | f.write('[step {}] RC: {:.4f} | KL: {:.4f} | [val] | RC: {:.4f} | KL: {:.4f}\n'.format(
139 | trained_steps,
140 | recons_loss_ema,
141 | kl_raw_ema,
142 | np.mean(vallosses[0]),
143 | np.mean(vallosses[1])
144 | ))
145 | model.train()
146 |
147 | if not trained_steps % ckpt_interval:
148 | torch.save(model.state_dict(),
149 | os.path.join(params_dir, 'step_{:d}-RC_{:.3f}-KL_{:.3f}-model.pt'.format(
150 | trained_steps,
151 | recons_loss_ema,
152 | kl_raw_ema
153 | ))
154 | )
155 | torch.save(optim.state_dict(),
156 | os.path.join(optim_dir, 'step_{:d}-RC_{:.3f}-KL_{:.3f}-optim.pt'.format(
157 | trained_steps,
158 | recons_loss_ema,
159 | kl_raw_ema
160 | ))
161 | )
162 |
163 | print ('[epoch {:03d}] training completed\n -- loss = (RC: {:.4f} | KL: {:.4f} | KL_raw: {:.4f})\n -- time elapsed = {:.2f} secs.'.format(
164 | epoch, recons_loss_ema, kl_loss_ema, kl_raw_ema, time.time() - st
165 | ))
166 | log_data = {
167 | 'ep': epoch,
168 | 'steps': trained_steps,
169 | 'recons_loss': recons_loss_ema,
170 | 'kldiv_loss': kl_loss_ema,
171 | 'kldiv_raw': kl_raw_ema,
172 | 'time': time.time() - st
173 | }
174 | log_epoch(
175 | os.path.join(ckpt_dir, 'log.txt'), log_data, is_init=not os.path.exists(os.path.join(ckpt_dir, 'log.txt'))
176 | )
177 |
178 | def validate(model, dloader, n_rounds=8, use_attr_cls=True):
179 | model.eval()
180 | loss_rec = []
181 | kl_loss_rec = []
182 |
183 | print ('[info] validating ...')
184 | with torch.no_grad():
185 | for i in range(n_rounds):
186 | print ('[round {}]'.format(i+1))
187 |
188 | for batch_idx, batch_samples in enumerate(dloader):
189 | model.zero_grad()
190 |
191 | batch_enc_inp = batch_samples['enc_input'].permute(2, 0, 1).to(device)
192 | batch_dec_inp = batch_samples['dec_input'].permute(1, 0).to(device)
193 | batch_dec_tgt = batch_samples['dec_target'].permute(1, 0).to(device)
194 | batch_inp_bar_pos = batch_samples['bar_pos'].to(device)
195 | batch_padding_mask = batch_samples['enc_padding_mask'].to(device)
196 | if use_attr_cls:
197 | batch_rfreq_cls = batch_samples['rhymfreq_cls'].permute(1, 0).to(device)
198 | batch_polyph_cls = batch_samples['polyph_cls'].permute(1, 0).to(device)
199 | else:
200 | batch_rfreq_cls = None
201 | batch_polyph_cls = None
202 |
203 | mu, logvar, dec_logits = model(
204 | batch_enc_inp, batch_dec_inp,
205 | batch_inp_bar_pos, batch_rfreq_cls, batch_polyph_cls,
206 | padding_mask=batch_padding_mask
207 | )
208 |
209 | losses = model.compute_loss(mu, logvar, 0.0, 0.0, dec_logits, batch_dec_tgt)
210 | if not (batch_idx + 1) % 10:
211 | print ('batch #{}:'.format(batch_idx + 1), round(losses['recons_loss'].item(), 3))
212 |
213 | loss_rec.append(losses['recons_loss'].item())
214 | kl_loss_rec.append(losses['kldiv_raw'].item())
215 |
216 | return loss_rec, kl_loss_rec
217 |
218 | if __name__ == "__main__":
219 | dset = REMIFullSongTransformerDataset(
220 | config['data']['data_dir'], config['data']['vocab_path'],
221 | do_augment=True,
222 | model_enc_seqlen=config['data']['enc_seqlen'],
223 | model_dec_seqlen=config['data']['dec_seqlen'],
224 | model_max_bars=config['data']['max_bars'],
225 | pieces=pickle_load(config['data']['train_split']),
226 | pad_to_same=True
227 | )
228 | dset_val = REMIFullSongTransformerDataset(
229 | config['data']['data_dir'], config['data']['vocab_path'],
230 | do_augment=False,
231 | model_enc_seqlen=config['data']['enc_seqlen'],
232 | model_dec_seqlen=config['data']['dec_seqlen'],
233 | model_max_bars=config['data']['max_bars'],
234 | pieces=pickle_load(config['data']['val_split']),
235 | pad_to_same=True
236 | )
237 | print ('[info]', '# training samples:', len(dset.pieces))
238 |
239 | dloader = DataLoader(dset, batch_size=config['data']['batch_size'], shuffle=True, num_workers=8)
240 | dloader_val = DataLoader(dset_val, batch_size=config['data']['batch_size'], shuffle=True, num_workers=8)
241 |
242 | mconf = config['model']
243 | model = MuseMorphose(
244 | mconf['enc_n_layer'], mconf['enc_n_head'], mconf['enc_d_model'], mconf['enc_d_ff'],
245 | mconf['dec_n_layer'], mconf['dec_n_head'], mconf['dec_d_model'], mconf['dec_d_ff'],
246 | mconf['d_latent'], mconf['d_embed'], dset.vocab_size,
247 | d_polyph_emb=mconf['d_polyph_emb'], d_rfreq_emb=mconf['d_rfreq_emb'],
248 | cond_mode=mconf['cond_mode']
249 | ).to(device)
250 | if pretrained_params_path:
251 | model.load_state_dict( torch.load(pretrained_params_path) )
252 |
253 | model.train()
254 | n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
255 | print ('[info] model # params:', n_params)
256 |
257 | opt_params = filter(lambda p: p.requires_grad, model.parameters())
258 | optimizer = optim.Adam(opt_params, lr=max_lr)
259 | if pretrained_optim_path:
260 | optimizer.load_state_dict( torch.load(pretrained_optim_path) )
261 | scheduler = optim.lr_scheduler.CosineAnnealingLR(
262 | optimizer, lr_decay_steps, eta_min=min_lr
263 | )
264 |
265 | if not os.path.exists(ckpt_dir):
266 | os.makedirs(ckpt_dir)
267 | if not os.path.exists(params_dir):
268 | os.makedirs(params_dir)
269 | if not os.path.exists(optim_dir):
270 | os.makedirs(optim_dir)
271 |
272 | for ep in range(config['training']['max_epochs']):
273 | train_model(ep+1, model, dloader, dloader_val, optimizer, scheduler)
--------------------------------------------------------------------------------
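
The cyclical KL-weight schedule in `train.py` can be sanity-checked in isolation; a sketch with the `default.yaml` constants (`no_kl_steps=10000`, `kl_cycle_steps=5000`, `kl_max_beta=1.0`):

```python
no_kl_steps, kl_cycle_steps, kl_max_beta = 10000, 5000, 1.0

def beta_cyclical_sched(step):
    # same logic as train.py: beta is held at 0 for the first no_kl_steps,
    # then ramps linearly over the first half of each cycle and plateaus
    step_in_cycle = (step - 1) % kl_cycle_steps
    cycle_progress = step_in_cycle / kl_cycle_steps
    if step < no_kl_steps:
        return 0.
    if cycle_progress < 0.5:
        return kl_max_beta * cycle_progress * 2.
    return kl_max_beta

for s in (1, 9999, 10001, 11251, 12501, 15001):
    print(s, round(beta_cyclical_sched(s), 3))   # 0.0, 0.0, 0.0, 0.5, 1.0, 0.0
```
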
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os, pickle, random
2 | from glob import glob
3 |
4 | import torch
5 | import numpy as np
6 |
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 | IDX_TO_KEY = {
10 | 0: 'A',
11 | 1: 'A#',
12 | 2: 'B',
13 | 3: 'C',
14 | 4: 'C#',
15 | 5: 'D',
16 | 6: 'D#',
17 | 7: 'E',
18 | 8: 'F',
19 | 9: 'F#',
20 | 10: 'G',
21 | 11: 'G#'
22 | }
23 | KEY_TO_IDX = {
24 | v:k for k, v in IDX_TO_KEY.items()
25 | }
26 |
27 | def get_chord_tone(chord_event):
28 | tone = chord_event['value'].split('_')[0]
29 | return tone
30 |
31 | def transpose_chord(chord_event, n_keys):
32 | if chord_event['value'] == 'N_N':
33 | return chord_event
34 |
35 | orig_tone = get_chord_tone(chord_event)
36 | orig_tone_idx = KEY_TO_IDX[orig_tone]
37 | new_tone_idx = (orig_tone_idx + 12 + n_keys) % 12
38 | new_chord_value = chord_event['value'].replace(
39 | '{}_'.format(orig_tone), '{}_'.format(IDX_TO_KEY[new_tone_idx])
40 | )
41 | new_chord_event = {'name': chord_event['name'], 'value': new_chord_value}
42 | # print ('keys={}. {} --> {}'.format(n_keys, chord_event, new_chord_event))
43 |
44 | return new_chord_event
45 |
46 | def check_extreme_pitch(raw_events):
47 | low, high = 128, 0
48 | for ev in raw_events:
49 | if ev['name'] == 'Note_Pitch':
50 | low = min(low, int(ev['value']))
51 | high = max(high, int(ev['value']))
52 |
53 | return low, high
54 |
55 | def transpose_events(raw_events, n_keys):
56 | transposed_raw_events = []
57 |
58 | for ev in raw_events:
59 | if ev['name'] == 'Note_Pitch':
60 | transposed_raw_events.append(
61 | {'name': ev['name'], 'value': ev['value'] + n_keys}
62 | )
63 | elif ev['name'] == 'Chord':
64 | transposed_raw_events.append(
65 | transpose_chord(ev, n_keys)
66 | )
67 | else:
68 | transposed_raw_events.append(ev)
69 |
70 | assert len(transposed_raw_events) == len(raw_events)
71 | return transposed_raw_events
72 |
73 | def pickle_load(path):
74 | return pickle.load(open(path, 'rb'))
75 |
76 | def convert_event(event_seq, event2idx, to_ndarr=True):
77 | if isinstance(event_seq[0], dict):
78 | event_seq = [event2idx['{}_{}'.format(e['name'], e['value'])] for e in event_seq]
79 | else:
80 | event_seq = [event2idx[e] for e in event_seq]
81 |
82 | if to_ndarr:
83 | return np.array(event_seq)
84 | else:
85 | return event_seq
86 |
87 | class REMIFullSongTransformerDataset(Dataset):
88 | def __init__(self, data_dir, vocab_file,
89 | model_enc_seqlen=128, model_dec_seqlen=1280, model_max_bars=16,
90 | pieces=[], do_augment=True, augment_range=range(-6, 7),
91 | min_pitch=22, max_pitch=107, pad_to_same=True, use_attr_cls=True,
92 | appoint_st_bar=None, dec_end_pad_value=None):
93 | self.vocab_file = vocab_file
94 | self.read_vocab()
95 |
96 | self.data_dir = data_dir
97 | self.pieces = pieces
98 | self.build_dataset()
99 |
100 | self.model_enc_seqlen = model_enc_seqlen
101 | self.model_dec_seqlen = model_dec_seqlen
102 | self.model_max_bars = model_max_bars
103 |
104 | self.do_augment = do_augment
105 | self.augment_range = augment_range
106 | self.min_pitch, self.max_pitch = min_pitch, max_pitch
107 | self.pad_to_same = pad_to_same
108 | self.use_attr_cls = use_attr_cls
109 |
110 | self.appoint_st_bar = appoint_st_bar
111 | if dec_end_pad_value is None:
112 | self.dec_end_pad_value = self.pad_token
113 | elif dec_end_pad_value == 'EOS':
114 | self.dec_end_pad_value = self.eos_token
115 | else:
116 | self.dec_end_pad_value = self.pad_token
117 |
118 | def read_vocab(self):
119 |     # the vocab file stores the tuple (event2idx, idx2event); load it once
120 |     vocab, self.idx2event = pickle_load(self.vocab_file)
121 | orig_vocab_size = len(vocab)
122 | self.event2idx = vocab
123 | self.bar_token = self.event2idx['Bar_None']
124 | self.eos_token = self.event2idx['EOS_None']
125 | self.pad_token = orig_vocab_size
126 | self.vocab_size = self.pad_token + 1
127 |
128 | def build_dataset(self):
129 | if not self.pieces:
130 | self.pieces = sorted( glob(os.path.join(self.data_dir, '*.pkl')) )
131 | else:
132 | self.pieces = sorted( [os.path.join(self.data_dir, p) for p in self.pieces] )
133 |
134 | self.piece_bar_pos = []
135 |
136 | for i, p in enumerate(self.pieces):
137 | bar_pos, p_evs = pickle_load(p)
138 | if not i % 200:
139 | print ('[preparing data] now at #{}'.format(i))
140 | if bar_pos[-1] == len(p_evs):
141 | print ('piece {}, got appended bar markers'.format(p))
142 | bar_pos = bar_pos[:-1]
143 | if len(p_evs) - bar_pos[-1] == 2:
144 | # got empty trailing bar
145 | bar_pos = bar_pos[:-1]
146 |
147 | bar_pos.append(len(p_evs))
148 |
149 | self.piece_bar_pos.append(bar_pos)
150 |
151 | def get_sample_from_file(self, piece_idx):
152 | piece_evs = pickle_load(self.pieces[piece_idx])[1]
153 | if len(self.piece_bar_pos[piece_idx]) > self.model_max_bars and self.appoint_st_bar is None:
154 | picked_st_bar = random.choice(
155 | range(len(self.piece_bar_pos[piece_idx]) - self.model_max_bars)
156 | )
157 | elif self.appoint_st_bar is not None and self.appoint_st_bar < len(self.piece_bar_pos[piece_idx]) - self.model_max_bars:
158 | picked_st_bar = self.appoint_st_bar
159 | else:
160 | picked_st_bar = 0
161 |
162 | piece_bar_pos = self.piece_bar_pos[piece_idx]
163 |
164 | if len(piece_bar_pos) > self.model_max_bars:
165 | piece_evs = piece_evs[ piece_bar_pos[picked_st_bar] : piece_bar_pos[picked_st_bar + self.model_max_bars] ]
166 | picked_bar_pos = np.array(piece_bar_pos[ picked_st_bar : picked_st_bar + self.model_max_bars ]) - piece_bar_pos[picked_st_bar]
167 | n_bars = self.model_max_bars
168 | else:
169 | picked_bar_pos = np.array(piece_bar_pos + [piece_bar_pos[-1]] * (self.model_max_bars - len(piece_bar_pos)))
170 | n_bars = len(piece_bar_pos)
171 | assert len(picked_bar_pos) == self.model_max_bars
172 |
173 | return piece_evs, picked_st_bar, picked_bar_pos, n_bars
174 |
175 | def pad_sequence(self, seq, maxlen, pad_value=None):
176 | if pad_value is None:
177 | pad_value = self.pad_token
178 |
179 |     seq.extend( [pad_value for _ in range(maxlen - len(seq))] )
180 |
181 | return seq
182 |
183 | def pitch_augment(self, bar_events):
184 | bar_min_pitch, bar_max_pitch = check_extreme_pitch(bar_events)
185 |
186 | n_keys = random.choice(self.augment_range)
187 | while bar_min_pitch + n_keys < self.min_pitch or bar_max_pitch + n_keys > self.max_pitch:
188 | n_keys = random.choice(self.augment_range)
189 |
190 | augmented_bar_events = transpose_events(bar_events, n_keys)
191 | return augmented_bar_events
192 |
193 | def get_attr_classes(self, piece, st_bar):
194 | polyph_cls = pickle_load(os.path.join(self.data_dir, 'attr_cls/polyph', piece))[st_bar : st_bar + self.model_max_bars]
195 | rfreq_cls = pickle_load(os.path.join(self.data_dir, 'attr_cls/rhythm', piece))[st_bar : st_bar + self.model_max_bars]
196 |
197 | polyph_cls.extend([0 for _ in range(self.model_max_bars - len(polyph_cls))])
198 | rfreq_cls.extend([0 for _ in range(self.model_max_bars - len(rfreq_cls))])
199 |
200 | assert len(polyph_cls) == self.model_max_bars
201 | assert len(rfreq_cls) == self.model_max_bars
202 |
203 | return polyph_cls, rfreq_cls
204 |
205 | def get_encoder_input_data(self, bar_positions, bar_events):
206 | assert len(bar_positions) == self.model_max_bars + 1
207 | enc_padding_mask = np.ones((self.model_max_bars, self.model_enc_seqlen), dtype=bool)
208 | enc_padding_mask[:, :2] = False
209 | padded_enc_input = np.full((self.model_max_bars, self.model_enc_seqlen), dtype=int, fill_value=self.pad_token)
210 | enc_lens = np.zeros((self.model_max_bars,))
211 |
212 | for b, (st, ed) in enumerate(zip(bar_positions[:-1], bar_positions[1:])):
213 | enc_padding_mask[b, : (ed-st)] = False
214 | enc_lens[b] = ed - st
215 | within_bar_events = self.pad_sequence(bar_events[st : ed], self.model_enc_seqlen, self.pad_token)
216 | within_bar_events = np.array(within_bar_events)
217 |
218 | padded_enc_input[b, :] = within_bar_events[:self.model_enc_seqlen]
219 |
220 | return padded_enc_input, enc_padding_mask, enc_lens
221 |
222 | def __len__(self):
223 | return len(self.pieces)
224 |
225 | def __getitem__(self, idx):
226 | if torch.is_tensor(idx):
227 | idx = idx.tolist()
228 |
229 | bar_events, st_bar, bar_pos, enc_n_bars = self.get_sample_from_file(idx)
230 | if self.do_augment:
231 | bar_events = self.pitch_augment(bar_events)
232 |
233 | if self.use_attr_cls:
234 | polyph_cls, rfreq_cls = self.get_attr_classes(os.path.basename(self.pieces[idx]), st_bar)
235 | polyph_cls_expanded = np.zeros((self.model_dec_seqlen,), dtype=int)
236 | rfreq_cls_expanded = np.zeros((self.model_dec_seqlen,), dtype=int)
237 | for i, (b_st, b_ed) in enumerate(zip(bar_pos[:-1], bar_pos[1:])):
238 | polyph_cls_expanded[b_st:b_ed] = polyph_cls[i]
239 | rfreq_cls_expanded[b_st:b_ed] = rfreq_cls[i]
240 | else:
241 | polyph_cls, rfreq_cls = [0], [0]
242 | polyph_cls_expanded, rfreq_cls_expanded = [0], [0]
243 |
244 | bar_tokens = convert_event(bar_events, self.event2idx, to_ndarr=False)
245 | bar_pos = bar_pos.tolist() + [len(bar_tokens)]
246 |
247 | enc_inp, enc_padding_mask, enc_lens = self.get_encoder_input_data(bar_pos, bar_tokens)
248 |
249 | length = len(bar_tokens)
250 | if self.pad_to_same:
251 | inp = self.pad_sequence(bar_tokens, self.model_dec_seqlen + 1)
252 | else:
253 | inp = self.pad_sequence(bar_tokens, len(bar_tokens) + 1, pad_value=self.dec_end_pad_value)
254 | target = np.array(inp[1:], dtype=int)
255 | inp = np.array(inp[:-1], dtype=int)
256 | assert len(inp) == len(target)
257 |
258 | return {
259 | 'id': idx,
260 | 'piece_id': int(os.path.basename(self.pieces[idx]).replace('.pkl', '')),
261 | 'st_bar_id': st_bar,
262 | 'bar_pos': np.array(bar_pos, dtype=int),
263 | 'enc_input': enc_inp,
264 | 'dec_input': inp[:self.model_dec_seqlen],
265 | 'dec_target': target[:self.model_dec_seqlen],
266 | 'polyph_cls': polyph_cls_expanded,
267 | 'rhymfreq_cls': rfreq_cls_expanded,
268 | 'polyph_cls_bar': np.array(polyph_cls),
269 | 'rhymfreq_cls_bar': np.array(rfreq_cls),
270 | 'length': min(length, self.model_dec_seqlen),
271 | 'enc_padding_mask': enc_padding_mask,
272 | 'enc_length': enc_lens,
273 | 'enc_n_bars': enc_n_bars
274 | }
275 |
276 | if __name__ == "__main__":
277 | # codes below are for unit test
278 | dset = REMIFullSongTransformerDataset(
279 | './remi_dataset', './pickles/remi_vocab.pkl', do_augment=True, use_attr_cls=True,
280 | model_max_bars=16, model_dec_seqlen=1280, model_enc_seqlen=128, min_pitch=22, max_pitch=107
281 | )
282 | print (dset.bar_token, dset.pad_token, dset.vocab_size)
283 | print ('length:', len(dset))
284 |
285 | # for i in random.sample(range(len(dset)), 100):
286 | # for i in range(len(dset)):
287 | # sample = dset[i]
288 | # print (i, len(sample['bar_pos']), sample['bar_pos'])
289 | # print (i)
290 | # print ('******* ----------- *******')
291 | # print ('piece: {}, st_bar: {}'.format(sample['piece_id'], sample['st_bar_id']))
292 | # print (sample['enc_input'][:8, :16])
293 | # print (sample['dec_input'][:16])
294 | # print (sample['dec_target'][:16])
295 | # print (sample['enc_padding_mask'][:32, :16])
296 | # print (sample['length'])
297 |
298 | dloader = DataLoader(dset, batch_size=4, shuffle=False, num_workers=24)
299 | for i, batch in enumerate(dloader):
300 | for k, v in batch.items():
301 | if torch.is_tensor(v):
302 | print (k, ':', v.dtype, v.size())
303 | print ('=====================================\n')
304 |
--------------------------------------------------------------------------------
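
A small illustration of the transposition helpers used for augmentation (toy events; the chord value `C_maj` is an illustrative stand-in for the dataset's chord vocabulary):

```python
from dataloader import transpose_events, check_extreme_pitch

events = [
  {'name': 'Note_Pitch', 'value': 60},
  {'name': 'Chord', 'value': 'C_maj'},
  {'name': 'Beat', 'value': 4},
]
print(check_extreme_pitch(events))   # (60, 60): lowest and highest pitch seen
print(transpose_events(events, 2))   # pitch 60 -> 62, chord root C -> D; Beat untouched
```
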
/generate.py:
--------------------------------------------------------------------------------
1 | import sys, os, random, time
2 | from copy import deepcopy
3 | sys.path.append('./model')
4 |
5 | from dataloader import REMIFullSongTransformerDataset
6 | from model.musemorphose import MuseMorphose
7 |
8 | from utils import pickle_load, numpy_to_tensor, tensor_to_numpy
9 | from remi2midi import remi2midi
10 |
11 | import torch
12 | import yaml
13 | import numpy as np
14 | from scipy.stats import entropy
15 |
16 | config_path = sys.argv[1]
17 | config = yaml.load(open(config_path, 'r'), Loader=yaml.FullLoader)
18 |
19 | device = config['training']['device']
20 | data_dir = config['data']['data_dir']
21 | vocab_path = config['data']['vocab_path']
22 | data_split = 'pickles/test_pieces.pkl'
23 |
24 | ckpt_path = sys.argv[2]
25 | out_dir = sys.argv[3]
26 | n_pieces = int(sys.argv[4])
27 | n_samples_per_piece = int(sys.argv[5])
28 |
29 | ###########################################
30 | # little helpers
31 | ###########################################
32 | def word2event(word_seq, idx2event):
33 | return [ idx2event[w] for w in word_seq ]
34 |
35 | def get_beat_idx(event):
36 | return int(event.split('_')[-1])
37 |
38 | ###########################################
39 | # sampling utilities
40 | ###########################################
41 | def temperatured_softmax(logits, temperature):
42 | try:
43 | probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature))
44 | assert np.count_nonzero(np.isnan(probs)) == 0
45 | except:
46 |     print ('overflow detected, retrying with float128')
47 | logits = logits.astype(np.float128)
48 | probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature))
49 | probs = probs.astype(float)
50 | return probs
51 |
52 | def nucleus(probs, p):
53 | probs /= sum(probs)
54 | sorted_probs = np.sort(probs)[::-1]
55 | sorted_index = np.argsort(probs)[::-1]
56 | cusum_sorted_probs = np.cumsum(sorted_probs)
57 | after_threshold = cusum_sorted_probs > p
58 | if sum(after_threshold) > 0:
59 | last_index = np.where(after_threshold)[0][1]
60 | candi_index = sorted_index[:last_index]
61 | else:
62 |     candi_index = sorted_index[:3] # fallback: just keep the top-3 tokens
63 | candi_probs = np.array([probs[i] for i in candi_index], dtype=np.float64)
64 | candi_probs /= sum(candi_probs)
65 | word = np.random.choice(candi_index, size=1, p=candi_probs)[0]
66 | return word
67 |
68 | ########################################
69 | # generation
70 | ########################################
71 | def get_latent_embedding_fast(model, piece_data, use_sampling=False, sampling_var=0.):
72 | # reshape
73 | batch_inp = piece_data['enc_input'].permute(1, 0).long().to(device)
74 | batch_padding_mask = piece_data['enc_padding_mask'].bool().to(device)
75 |
76 | # get latent conditioning vectors
77 | with torch.no_grad():
78 | piece_latents = model.get_sampled_latent(
79 | batch_inp, padding_mask=batch_padding_mask,
80 | use_sampling=use_sampling, sampling_var=sampling_var
81 | )
82 |
83 | return piece_latents
84 |
85 | def generate_on_latent_ctrl_vanilla_truncate(
86 | model, latents, rfreq_cls, polyph_cls, event2idx, idx2event,
87 | max_events=12800, primer=None,
88 | max_input_len=1280, truncate_len=512,
89 | nucleus_p=0.9, temperature=1.2
90 | ):
91 | latent_placeholder = torch.zeros(max_events, 1, latents.size(-1)).to(device)
92 | rfreq_placeholder = torch.zeros(max_events, 1, dtype=int).to(device)
93 | polyph_placeholder = torch.zeros(max_events, 1, dtype=int).to(device)
94 | print ('[info] rhythm cls: {} | polyph_cls: {}'.format(rfreq_cls, polyph_cls))
95 |
96 | if primer is None:
97 | generated = [event2idx['Bar_None']]
98 | else:
99 | generated = [event2idx[e] for e in primer]
100 | latent_placeholder[:len(generated), 0, :] = latents[0].squeeze(0)
101 | rfreq_placeholder[:len(generated), 0] = rfreq_cls[0]
102 | polyph_placeholder[:len(generated), 0] = polyph_cls[0]
103 |
104 | target_bars, generated_bars = latents.size(0), 0
105 |
106 | steps = 0
107 | time_st = time.time()
108 | cur_pos = 0
109 | failed_cnt = 0
110 |
111 | cur_input_len = len(generated)
112 | generated_final = deepcopy(generated)
113 | entropies = []
114 |
115 | while generated_bars < target_bars:
116 | if len(generated) == 1:
117 | dec_input = numpy_to_tensor([generated], device=device).long()
118 | else:
119 | dec_input = numpy_to_tensor([generated], device=device).permute(1, 0).long()
120 |
121 | latent_placeholder[len(generated)-1, 0, :] = latents[ generated_bars ]
122 | rfreq_placeholder[len(generated)-1, 0] = rfreq_cls[ generated_bars ]
123 | polyph_placeholder[len(generated)-1, 0] = polyph_cls[ generated_bars ]
124 |
125 | dec_seg_emb = latent_placeholder[:len(generated), :]
126 | dec_rfreq_cls = rfreq_placeholder[:len(generated), :]
127 | dec_polyph_cls = polyph_placeholder[:len(generated), :]
128 |
129 | # sampling
130 | with torch.no_grad():
131 | logits = model.generate(dec_input, dec_seg_emb, dec_rfreq_cls, dec_polyph_cls)
132 | logits = tensor_to_numpy(logits[0])
133 | probs = temperatured_softmax(logits, temperature)
134 | word = nucleus(probs, nucleus_p)
135 | word_event = idx2event[word]
136 |
137 | if 'Beat' in word_event:
138 | event_pos = get_beat_idx(word_event)
139 | if not event_pos >= cur_pos:
140 | failed_cnt += 1
141 | print ('[info] position not increasing, failed cnt:', failed_cnt)
142 | if failed_cnt >= 128:
143 | print ('[FATAL] model stuck, exiting ...')
144 |           return generated_final, time.time() - time_st, np.array(entropies)  # keep the caller's expected 3-tuple
145 | continue
146 | else:
147 | cur_pos = event_pos
148 | failed_cnt = 0
149 |
150 | if 'Bar' in word_event:
151 | generated_bars += 1
152 | cur_pos = 0
153 | print ('[info] generated {} bars, #events = {}'.format(generated_bars, len(generated_final)))
154 | if word_event == 'PAD_None':
155 | continue
156 |
157 | if len(generated) > max_events or (word_event == 'EOS_None' and generated_bars == target_bars - 1):
158 | generated_bars += 1
159 | generated.append(event2idx['Bar_None'])
160 | print ('[info] gotten eos')
161 | break
162 |
163 | generated.append(word)
164 | generated_final.append(word)
165 | entropies.append(entropy(probs))
166 |
167 | cur_input_len += 1
168 | steps += 1
169 |
170 | assert cur_input_len == len(generated)
171 | if cur_input_len == max_input_len:
172 | generated = generated[-truncate_len:]
173 | latent_placeholder[:len(generated)-1, 0, :] = latent_placeholder[cur_input_len-truncate_len:cur_input_len-1, 0, :]
174 | rfreq_placeholder[:len(generated)-1, 0] = rfreq_placeholder[cur_input_len-truncate_len:cur_input_len-1, 0]
175 | polyph_placeholder[:len(generated)-1, 0] = polyph_placeholder[cur_input_len-truncate_len:cur_input_len-1, 0]
176 |
177 | print ('[info] reset context length: cur_len: {}, accumulated_len: {}, truncate_range: {} ~ {}'.format(
178 | cur_input_len, len(generated_final), cur_input_len-truncate_len, cur_input_len-1
179 | ))
180 | cur_input_len = len(generated)
181 |
182 | assert generated_bars == target_bars
183 | print ('-- generated events:', len(generated_final))
184 | print ('-- time elapsed: {:.2f} secs'.format(time.time() - time_st))
185 | return generated_final[:-1], time.time() - time_st, np.array(entropies)
186 |
187 |
188 | ########################################
189 | # change attribute classes
190 | ########################################
191 | def random_shift_attr_cls(n_samples, upper=4, lower=-3):
192 | return np.random.randint(lower, upper, (n_samples,))
193 |
194 |
195 | if __name__ == "__main__":
196 | dset = REMIFullSongTransformerDataset(
197 | data_dir, vocab_path,
198 | do_augment=False,
199 | model_enc_seqlen=config['data']['enc_seqlen'],
200 | model_dec_seqlen=config['generate']['dec_seqlen'],
201 | model_max_bars=config['generate']['max_bars'],
202 | pieces=pickle_load(data_split),
203 | pad_to_same=False
204 | )
205 | pieces = random.sample(range(len(dset)), n_pieces)
206 | print ('[sampled pieces]', pieces)
207 |
208 | mconf = config['model']
209 | model = MuseMorphose(
210 | mconf['enc_n_layer'], mconf['enc_n_head'], mconf['enc_d_model'], mconf['enc_d_ff'],
211 | mconf['dec_n_layer'], mconf['dec_n_head'], mconf['dec_d_model'], mconf['dec_d_ff'],
212 | mconf['d_latent'], mconf['d_embed'], dset.vocab_size,
213 | d_polyph_emb=mconf['d_polyph_emb'], d_rfreq_emb=mconf['d_rfreq_emb'],
214 | cond_mode=mconf['cond_mode']
215 | ).to(device)
216 | model.eval()
217 | model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
218 |
219 | if not os.path.exists(out_dir):
220 | os.makedirs(out_dir)
221 |
222 | times = []
223 | for p in pieces:
224 | # fetch test sample
225 | p_data = dset[p]
226 | p_id = p_data['piece_id']
227 | p_bar_id = p_data['st_bar_id']
228 | p_data['enc_input'] = p_data['enc_input'][ : p_data['enc_n_bars'] ]
229 | p_data['enc_padding_mask'] = p_data['enc_padding_mask'][ : p_data['enc_n_bars'] ]
230 |
231 | orig_p_cls_str = ''.join(str(c) for c in p_data['polyph_cls_bar'])
232 | orig_r_cls_str = ''.join(str(c) for c in p_data['rhymfreq_cls_bar'])
233 |
234 | orig_song = p_data['dec_input'].tolist()[:p_data['length']]
235 | orig_song = word2event(orig_song, dset.idx2event)
236 | orig_out_file = os.path.join(out_dir, 'id{}_bar{}_orig'.format(
237 | p, p_bar_id
238 | ))
239 | print ('[info] writing to ...', orig_out_file)
240 | # output reference song's MIDI
241 | _, orig_tempo = remi2midi(orig_song, orig_out_file + '.mid', return_first_tempo=True, enforce_tempo=False)
242 |
243 | # save metadata of reference song (events & attr classes)
244 | print (*orig_song, sep='\n', file=open(orig_out_file + '.txt', 'a'))
245 | np.save(orig_out_file + '-POLYCLS.npy', p_data['polyph_cls_bar'])
246 | np.save(orig_out_file + '-RHYMCLS.npy', p_data['rhymfreq_cls_bar'])
247 |
248 |
249 | for k in p_data.keys():
250 | if not torch.is_tensor(p_data[k]):
251 | p_data[k] = numpy_to_tensor(p_data[k], device=device)
252 | else:
253 | p_data[k] = p_data[k].to(device)
254 |
255 | p_latents = get_latent_embedding_fast(
256 | model, p_data,
257 | use_sampling=config['generate']['use_latent_sampling'],
258 | sampling_var=config['generate']['latent_sampling_var']
259 | )
260 | p_cls_diff = random_shift_attr_cls(n_samples_per_piece)
261 | r_cls_diff = random_shift_attr_cls(n_samples_per_piece)
262 |
263 | piece_entropies = []
264 | for samp in range(n_samples_per_piece):
265 | p_polyph_cls = (p_data['polyph_cls_bar'] + p_cls_diff[samp]).clamp(0, 7).long()
266 | p_rfreq_cls = (p_data['rhymfreq_cls_bar'] + r_cls_diff[samp]).clamp(0, 7).long()
267 |
268 | print ('[info] piece: {}, bar: {}'.format(p_id, p_bar_id))
269 | out_file = os.path.join(out_dir, 'id{}_bar{}_sample{:02d}_poly{}_rhym{}'.format(
270 | p, p_bar_id, samp + 1,
271 | '+{}'.format(p_cls_diff[samp]) if p_cls_diff[samp] >= 0 else p_cls_diff[samp],
272 | '+{}'.format(r_cls_diff[samp]) if r_cls_diff[samp] >= 0 else r_cls_diff[samp]
273 | ))
274 | print ('[info] writing to ...', out_file)
275 | if os.path.exists(out_file + '.txt'):
276 | print ('[info] file exists, skipping ...')
277 | continue
278 |
279 | # print (p_polyph_cls, p_rfreq_cls)
280 |
281 | # generate
282 | song, t_sec, entropies = generate_on_latent_ctrl_vanilla_truncate(
283 | model, p_latents, p_rfreq_cls, p_polyph_cls, dset.event2idx, dset.idx2event,
284 | max_input_len=config['generate']['max_input_dec_seqlen'],
285 | truncate_len=min(512, config['generate']['max_input_dec_seqlen'] - 32),
286 | nucleus_p=config['generate']['nucleus_p'],
287 | temperature=config['generate']['temperature'],
288 |
289 | )
290 | times.append(t_sec)
291 |
292 | song = word2event(song, dset.idx2event)
293 | print (*song, sep='\n', file=open(out_file + '.txt', 'a'))
294 | remi2midi(song, out_file + '.mid', enforce_tempo=True, enforce_tempo_val=orig_tempo)
295 |
296 | # save metadata of the generation
297 | np.save(out_file + '-POLYCLS.npy', tensor_to_numpy(p_polyph_cls))
298 | np.save(out_file + '-RHYMCLS.npy', tensor_to_numpy(p_rfreq_cls))
299 | print ('[info] piece entropy: {:.4f} (+/- {:.4f})'.format(
300 | entropies.mean(), entropies.std()
301 | ))
302 | piece_entropies.append(entropies.mean())
303 |
304 | print ('[time stats] {} songs, generation time: {:.2f} secs (+/- {:.2f})'.format(
305 | n_pieces * n_samples_per_piece, np.mean(times), np.std(times)
306 | ))
307 | print ('[entropy] {:.4f} (+/- {:.4f})'.format(
308 | np.mean(piece_entropies), np.std(piece_entropies)
309 | ))
--------------------------------------------------------------------------------
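
The sampling utilities in `generate.py` can be exercised standalone; a sketch of the temperature effect (function body copied from above, since importing `generate.py` would run its CLI argument handling). Note that `nucleus()` then keeps the smallest set of top tokens whose cumulative probability exceeds `p`, renormalizes, and samples; its `np.where(after_threshold)[0][1]` indexing implicitly assumes at least two tokens lie past the threshold, which holds for realistic vocabulary-sized distributions.

```python
import numpy as np

def temperatured_softmax(logits, temperature):
    # same math as generate.py: scale logits by 1/T before normalizing
    x = logits / temperature
    return np.exp(x) / np.sum(np.exp(x))

logits = np.array([3.0, 2.0, 1.0, 0.0])
for T in (0.5, 1.0, 1.2):
    print(T, np.round(temperatured_softmax(logits, T), 3))
# lower T sharpens the distribution; T=1.2 (the default.yaml setting) flattens it slightly
```
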