├── .gitignore ├── LICENSE ├── README.md ├── adversarial.py ├── config.py ├── data.py ├── dataset │ ├── midi │ │ └── .keep │ ├── processed │ │ └── .keep │ └── scripts │ ├── .keep │ ├── classic_piano_downloader.sh │ ├── ecomp_piano_downloader.sh │ ├── midiworld_downloader.sh │ └── touhou_downloader.sh ├── generate.py ├── model.py ├── output │ └── .keep ├── play.py ├── preprocess.py ├── runs │ └── .keep ├── save │ └── .keep ├── sequence.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.pyc 3 | .DS_Store 4 | *.mid 5 | *.MID 6 | *.sess 7 | *.data 8 | runs/ 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yuankui Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![generated-sample-sheet-music](https://user-images.githubusercontent.com/17045050/42017029-3b4f7060-7ae0-11e8-829b-6d6b8b829759.png) 2 | 3 | # Performance RNN - PyTorch 4 | 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | 7 | PyTorch implementation of Performance RNN, inspired by *Ian Simon and Sageev Oore. "Performance RNN: Generating Music with Expressive 8 | Timing and Dynamics." Magenta Blog, 2017.* 9 | [https://magenta.tensorflow.org/performance-rnn](https://magenta.tensorflow.org/performance-rnn). 11 | 
12 | Note that this is not the official implementation, and the model does not exactly follow the architecture described in the blog post! 13 | 14 | ## Generated Samples 15 | 16 | - A sample on C Major Scale [[MIDI](https://drive.google.com/open?id=1mZtkpsu1yA8oOkE_1b2jyFsvCW70FiKU), [MP3](https://drive.google.com/open?id=1UqyJ9e58AOimFeY1xoCPyedTz-g2fUxv)] 17 | - control option: `-c '1,0,1,0,1,1,0,1,0,1,0,1;4'` 18 | - A sample on C Minor Scale [[MIDI](https://drive.google.com/open?id=1lIVCIT7INuTa-HKrgPzewrgCbgwCRRa1), [MP3](https://drive.google.com/open?id=1pVg3Mg2pSq8VHJRJrgNUZybpsErjzpjF)] 19 | - control option: `-c '1,0,1,1,0,1,0,1,1,0,0,1;4'` 20 | - A sample on C Major Pentatonic Scale [[MIDI](https://drive.google.com/open?id=16uRwyntgYTzSmaxhp06kUbThDm8W_vVE), [MP3](https://drive.google.com/open?id=1LSbeVqXKAPrNPCPcjy6FVwUuVo7FxYji)] 21 | - control option: `-c '5,0,4,0,4,1,0,5,0,4,0,1;3'` 22 | - A sample on C Minor Pentatonic Scale [[MIDI](https://drive.google.com/open?id=1zeMHNu37U6byhT-s63EIro8nL6VkUi8u), [MP3](https://drive.google.com/open?id=1asP1z6u1n3PRSysSnvkt-SabpTgT-_x5)] 23 | - control option: `-c '5,0,1,4,0,4,0,5,1,0,4,0;3'` 24 | 25 | ## Directory Structure 26 | 27 | ``` 28 | . 29 | ├── dataset/ 30 | │   ├── midi/ 31 | │   │ ├── dataset1/ 32 | │   │   │ └── *.mid 33 | │   │ └── dataset2/ 34 | │   │   └── *.mid 35 | │   ├── processed/ 36 | │   │ └── dataset1/ 37 | │   │   └── *.data (preprocess.py) 38 | │   └── scripts/ 39 | │   └── *.sh (dataset download scripts) 40 | ├── output/ 41 | │   └── *.mid (generate.py) 42 | ├── save/ 43 | │   └── *.sess (train.py) 44 | └── runs/ (tensorboard logdir) 45 | ``` 46 | 47 | ## Instructions 48 | 49 | - Download datasets 50 | 51 | ```shell 52 | cd dataset/ 53 | bash scripts/NAME_downloader.sh midi/NAME 54 | ``` 55 | 56 | - Preprocessing 57 | 58 | ```shell 59 | # Preprocess all MIDI files under dataset/midi/NAME with NUM_WORKERS worker processes 60 | python3 preprocess.py dataset/midi/NAME dataset/processed/NAME NUM_WORKERS 61 | ``` 62 | 63 | - Training 64 | 65 | ```shell 66 | # Train on .data files in dataset/processed/MYDATA, and save to save/myModel.sess every 10s 67 | python3 train.py -s save/myModel.sess -d dataset/processed/MYDATA -i 10 68 | 69 | # Or... 70 | python3 train.py -s save/myModel.sess -d dataset/processed/MYDATA -p hidden_dim=1024 71 | python3 train.py -s save/myModel.sess -d dataset/processed/MYDATA -b 128 -c 0.3 72 | python3 train.py -s save/myModel.sess -d dataset/processed/MYDATA -w 100 -S 10 73 | ``` 74 | 75 | ![training-figure](https://user-images.githubusercontent.com/17045050/42135712-7f6e25f4-7d81-11e8-845f-682bd26a3abb.png) 76 | 77 | 78 | - Generating 79 | 80 | ```shell 81 | # Generate with control sequence from test.data and model from save/test.sess 82 | python3 generate.py -s save/test.sess -c test.data 83 | 84 | # Generate with pitch histogram and note density (C major scale) 85 | python3 generate.py -s save/test.sess -l 1000 -c '1,0,1,0,1,1,0,1,0,1,0,1;3' 86 | 87 | # Or... 
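# Note on the '-c' format, as parsed in generate.py: 'p1,...,p12;d' gives
# twelve relative pitch-class weights (C, C#, ..., B), which are normalized
# into a pitch histogram, followed by a note-density bin index. An empty
# histogram such as ';3' falls back to a uniform one: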
87 | python3 generate.py -s save/test.sess -l 1000 -c ';3' # uniform pitch histogram 88 | python3 generate.py -s save/test.sess -l 1000 # no control 89 | 90 | # Use control sequence from processed data 91 | python3 generate.py -s save/test.sess -c dataset/processed/some/processed.data 92 | ``` 93 | 94 | ![generated-sample-1](https://user-images.githubusercontent.com/17045050/42017026-37dfd7b2-7ae0-11e8-99a9-75d27510f44b.png) 95 | 96 | ![generated-sample-2](https://user-images.githubusercontent.com/17045050/42017017-337ce0a2-7ae0-11e8-8193-12ea539af424.png) 97 | 98 | ## Pretrained Model 99 | 100 | - [ecomp.sess](https://drive.google.com/open?id=1daT6XRQUTS6AQ5jyRPqzowXia-zVqg6m) 101 | - default configuration 102 | - dataset: [International Piano-e-Competition, recorded MIDI files](http://www.piano-e-competition.com/) 103 | - [ecomp_w500.sess](https://drive.google.com/open?id=1jf5j2cWppXVeSXhTuiNfAFEyWFIaNZ6f) 104 | - window_size: 500 105 | - control_ratio: 0.7 106 | - dataset: [International Piano-e-Competition, recorded MIDI files](http://www.piano-e-competition.com/) 107 | 108 | ## Requirements 109 | 110 | - pretty_midi 111 | - numpy 112 | - pytorch >= 0.4 113 | - tensorboardX 114 | - progress 115 | -------------------------------------------------------------------------------- /adversarial.py: -------------------------------------------------------------------------------- 1 | # Adversarial learning for event-based music generation with SeqGAN 2 | # Reference: 3 | # "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient." 4 | # (Yu, Lantao, et al.). 5 | # ... Honestly, it's too hard to train ;( 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from torch.distributions import Categorical 12 | 13 | import numpy as np 14 | import os, sys, time, argparse 15 | from progress.bar import Bar 16 | 17 | import config, utils 18 | from config import device 19 | from data import Dataset 20 | from model import PerformanceRNN 21 | from sequence import EventSeq, ControlSeq 22 | 23 | # pylint: disable=E1101 24 | 25 | 26 | #======================================================================== 27 | # Discriminator 28 | #======================================================================== 29 | 30 | discriminator_config = { 31 | 'event_dim': EventSeq.dim(), 32 | 'hidden_dim': 512, 33 | 'gru_layers': 3, 34 | 'gru_dropout': 0.3 35 | } 36 | 37 | class EventSequenceEncoder(nn.Module): 38 | def __init__(self, event_dim=EventSeq.dim(), hidden_dim=512, 39 | gru_layers=3, gru_dropout=0.3): 40 | super().__init__() 41 | self.event_embedding = nn.Embedding(event_dim, hidden_dim) 42 | self.gru = nn.GRU(hidden_dim, hidden_dim, 43 | num_layers=gru_layers, dropout=gru_dropout) 44 | self.attn = nn.Parameter(torch.randn(hidden_dim), requires_grad=True) 45 | self.output_fc = nn.Linear(hidden_dim, 1) 46 | self.output_fc_activation = nn.Sigmoid() 47 | 48 | def forward(self, events, hidden=None, output_logits=False): 49 | # events: [steps, batch_size] 50 | events = self.event_embedding(events) 51 | outputs, _ = self.gru(events, hidden) # [t, b, h] 52 | weights = (outputs * self.attn).sum(-1, keepdim=True) 53 | output = (outputs * weights).mean(0) # [b, h] 54 | output = self.output_fc(output).squeeze(-1) # [b] 55 | if output_logits: 56 | return output 57 | output = self.output_fc_activation(output) 58 | return output 59 | 60 | 61 | #======================================================================== 62 | # Pretrain Discriminator 63 | 
#======================================================================== 64 | 65 | def pretrain_discriminator(model_sess_path, # load 66 | discriminator_sess_path, # load + save 67 | batch_data_generator, # Dataset(...).batches(...) 68 | discriminator_config_overwrite={}, 69 | gradient_clipping=False, 70 | control_ratio=1.0, 71 | num_iter=-1, 72 | save_interval=60.0, 73 | discriminator_lr=0.001, 74 | enable_logging=False, 75 | auto_sample_factor=False, 76 | sample_factor=1.0): 77 | 78 | print('-' * 70) 79 | print('model_sess_path:', model_sess_path) 80 | print('discriminator_sess_path:', discriminator_sess_path) 81 | print('discriminator_config_overwrite:', discriminator_config_overwrite) 82 | print('sample_factor:', sample_factor) 83 | print('auto_sample_factor:', auto_sample_factor) 84 | print('discriminator_lr:', discriminator_lr) 85 | print('gradient_clipping:', gradient_clipping) 86 | print('control_ratio:', control_ratio) 87 | print('num_iter:', num_iter) 88 | print('save_interval:', save_interval) 89 | print('enable_logging:', enable_logging) 90 | print('-' * 70) 91 | 92 | # Load generator 93 | model_sess = torch.load(model_sess_path) 94 | model_config = model_sess['model_config'] 95 | model = PerformanceRNN(**model_config).to(device) 96 | model.load_state_dict(model_sess['model_state']) 97 | 98 | print(f'Generator from "{model_sess_path}"') 99 | print(model) 100 | print('-' * 70) 101 | 102 | # Load discriminator and optimizer 103 | global discriminator_config 104 | try: 105 | discriminator_sess = torch.load(discriminator_sess_path) 106 | discriminator_config = discriminator_sess['discriminator_config'] 107 | discriminator_state = discriminator_sess['discriminator_state'] 108 | discriminator_optimizer_state = discriminator_sess['discriminator_optimizer_state'] 109 | print(f'Discriminator from "{discriminator_sess_path}"') 110 | discriminator_loaded = True 111 | except: 112 | print(f'New discriminator session at "{discriminator_sess_path}"') 113 | discriminator_config.update(discriminator_config_overwrite) 114 | discriminator_loaded = False 115 | 116 | discriminator = EventSequenceEncoder(**discriminator_config).to(device) 117 | optimizer = optim.Adam(discriminator.parameters(), lr=discriminator_lr) 118 | if discriminator_loaded: 119 | discriminator.load_state_dict(discriminator_state) 120 | optimizer.load_state_dict(discriminator_optimizer_state) 121 | 122 | print(discriminator) 123 | print(optimizer) 124 | print('-' * 70) 125 | 126 | def save_discriminator(): 127 | print(f'Saving to "{discriminator_sess_path}"') 128 | torch.save({ 129 | 'discriminator_config': discriminator_config, 130 | 'discriminator_state': discriminator.state_dict(), 131 | 'discriminator_optimizer_state': optimizer.state_dict() 132 | }, discriminator_sess_path) 133 | print('Done saving') 134 | 135 | # Disable gradient for generator 136 | for parameter in model.parameters(): 137 | parameter.requires_grad_(False) 138 | 139 | model.eval() 140 | discriminator.train() 141 | 142 | loss_func = nn.BCEWithLogitsLoss() 143 | last_save_time = time.time() 144 | 145 | if enable_logging: 146 | from tensorboardX import SummaryWriter 147 | writer = SummaryWriter() 148 | 149 | try: 150 | for i, (events, controls) in enumerate(batch_data_generator): 151 | if i == num_iter: 152 | break 153 | 154 | steps, batch_size = events.shape 155 | 156 | # Prepare inputs 157 | events = torch.LongTensor(events).to(device) 158 | if np.random.random() <= control_ratio: 159 | controls = torch.FloatTensor(controls).to(device) 160 | else: 161 | 
controls = None 162 | 163 | init = torch.randn(batch_size, model.init_dim).to(device) 164 | 165 | # Predict for real event sequence 166 | real_events = events 167 | real_logit = discriminator(real_events, output_logits=True) 168 | real_target = torch.ones_like(real_logit).to(device) 169 | 170 | if auto_sample_factor: 171 | sample_factor = np.random.choice([ 172 | 0.1, 0.4, 0.6, 0.7, 0.8, 0.9, 1.0, 173 | 1.1, 1.2, 1.4, 1.6, 2.0, 4.0, 10.0]) 174 | 175 | # Predict for fake event sequence from the generator 176 | fake_events = model.generate(init, steps, None, controls, 177 | greedy=0, output_type='index', 178 | temperature=sample_factor) 179 | fake_logit = discriminator(fake_events, output_logits=True) 180 | fake_target = torch.zeros_like(fake_logit).to(device) 181 | 182 | # Compute loss 183 | loss = (loss_func(real_logit, real_target) + 184 | loss_func(fake_logit, fake_target)) / 2 185 | 186 | # Backprop 187 | discriminator.zero_grad() 188 | loss.backward() 189 | 190 | # Gradient clipping 191 | norm = utils.compute_gradient_norm(discriminator.parameters()) 192 | if gradient_clipping: 193 | nn.utils.clip_grad_norm_(discriminator.parameters(), gradient_clipping) 194 | 195 | optimizer.step() 196 | 197 | # Logging 198 | loss = loss.item() 199 | norm = norm.item() 200 | print(f'{i} loss: {loss}, norm: {norm}, sf: {sample_factor}') 201 | if enable_logging: 202 | writer.add_scalar(f'pretrain/D/loss/all', loss, i) 203 | writer.add_scalar(f'pretrain/D/loss/{sample_factor}', loss, i) 204 | writer.add_scalar(f'pretrain/D/norm/{sample_factor}', norm, i) 205 | 206 | if last_save_time + save_interval < time.time(): 207 | last_save_time = time.time() 208 | save_discriminator() 209 | 210 | except KeyboardInterrupt: 211 | save_discriminator() 212 | 213 | 214 | #======================================================================== 215 | # Adversarial Learning 216 | #======================================================================== 217 | 218 | 219 | def train_adversarial(sess_path, batch_data_generator, 220 | model_load_path, model_optimizer_class, 221 | model_gradient_clipping, discriminator_gradient_clipping, 222 | model_learning_rate, reset_model_optimizer, 223 | discriminator_load_path, discriminator_optimizer_class, 224 | discriminator_learning_rate, reset_discriminator_optimizer, 225 | g_max_q_mean, g_min_q_mean, d_min_loss, g_max_steps, d_max_steps, 226 | mc_sample_size, mc_sample_factor, first_to_train, 227 | save_interval, control_ratio, enable_logging): 228 | 229 | if enable_logging: 230 | from tensorboardX import SummaryWriter 231 | writer = SummaryWriter() 232 | 233 | if os.path.isfile(sess_path): 234 | adv_state = torch.load(sess_path) 235 | model_config = adv_state['model_config'] 236 | model_state = adv_state['model_state'] 237 | model_optimizer_state = adv_state['model_optimizer_state'] 238 | discriminator_config = adv_state['discriminator_config'] 239 | discriminator_state = adv_state['discriminator_state'] 240 | discriminator_optimizer_state = adv_state['discriminator_optimizer_state'] 241 | print('-' * 70) 242 | print('Session is loaded from', sess_path) 243 | loaded_from_session = True 244 | 245 | else: 246 | model_sess = torch.load(model_load_path) 247 | model_config = model_sess['model_config'] 248 | model_state = model_sess['model_state'] 249 | discriminator_sess = torch.load(discriminator_load_path) 250 | discriminator_config = discriminator_sess['discriminator_config'] 251 | discriminator_state = discriminator_sess['discriminator_state'] 252 | loaded_from_session = False 253 
| 254 | model = PerformanceRNN(**model_config) 255 | model.load_state_dict(model_state) 256 | model.to(device).train() 257 | model_optimizer = model_optimizer_class(model.parameters(), lr=model_learning_rate) 258 | 259 | discriminator = EventSequenceEncoder(**discriminator_config) 260 | discriminator.load_state_dict(discriminator_state) 261 | discriminator.to(device).train() 262 | discriminator_optimizer = discriminator_optimizer_class(discriminator.parameters(), 263 | lr=discriminator_learning_rate) 264 | 265 | if loaded_from_session: 266 | if not reset_model_optimizer: 267 | model_optimizer.load_state_dict(model_optimizer_state) 268 | if not reset_discriminator_optimizer: 269 | discriminator_optimizer.load_state_dict(discriminator_optimizer_state) 270 | 271 | g_loss_func = nn.CrossEntropyLoss() 272 | d_loss_func = nn.BCEWithLogitsLoss(reduce=False) 273 | 274 | 275 | print('-' * 70) 276 | print('Options') 277 | print('sess_path:', sess_path) 278 | print('save_interval:', save_interval) 279 | print('batch_data_generator:', batch_data_generator) 280 | print('control_ratio:', control_ratio) 281 | print('g_max_q_mean:', g_max_q_mean) 282 | print('g_min_q_mean:', g_min_q_mean) 283 | print('d_min_loss:', d_min_loss) 284 | print('mc_sample_size:', mc_sample_size) 285 | print('mc_sample_factor:', mc_sample_factor) 286 | print('enable_logging:', enable_logging) 287 | print('model_load_path:', model_load_path) 288 | print('model_loss:', g_loss_func) 289 | print('model_optimizer_class:', model_optimizer_class) 290 | print('model_gradient_clipping:', model_gradient_clipping) 291 | print('model_learning_rate:', model_learning_rate) 292 | print('reset_model_optimizer:', reset_model_optimizer) 293 | print('discriminator_load_path:', discriminator_load_path) 294 | print('discriminator_loss:', d_loss_func) 295 | print('discriminator_optimizer_class:', discriminator_optimizer_class) 296 | print('discriminator_gradient_clipping:', discriminator_gradient_clipping) 297 | print('discriminator_learning_rate:', discriminator_learning_rate) 298 | print('reset_discriminator_optimizer:', reset_discriminator_optimizer) 299 | print('first_to_train:', first_to_train) 300 | print('-' * 70) 301 | print(f'Generator from "{sess_path if loaded_from_session else model_load_path}"') 302 | print(model) 303 | print(model_optimizer) 304 | print('-' * 70) 305 | print(f'Discriminator from "{sess_path if loaded_from_session else discriminator_load_path}"') 306 | print(discriminator) 307 | print(discriminator_optimizer) 308 | print('-' * 70) 309 | 310 | 311 | def save(): 312 | print(f'Saving to "{sess_path}"') 313 | torch.save({ 314 | 'model_config': model_config, 315 | 'model_state': model.state_dict(), 316 | 'model_optimizer_state': model_optimizer.state_dict(), 317 | 'discriminator_config': discriminator_config, 318 | 'discriminator_state': discriminator.state_dict(), 319 | 'discriminator_optimizer_state': discriminator_optimizer.state_dict() 320 | }, sess_path) 321 | print('Done saving') 322 | 323 | def mc_rollout(generated, hidden, total_steps, controls=None): 324 | # generated: [t, batch_size] 325 | # hidden: [n_layers, batch_size, hidden_dim] 326 | # controls: [total_steps - t, batch_size, control_dim] 327 | generated = torch.cat(generated, 0) 328 | generated_steps, batch_size = generated.shape # t, b 329 | steps = total_steps - generated_steps # s 330 | 331 | generated = generated.unsqueeze(1) # [t, 1, b] 332 | generated = generated.repeat(1, mc_sample_size, 1) # [t, mcs, b] 333 | generated = generated.view(generated_steps, 
-1) # [t, mcs * b] 334 | 335 | hidden = hidden.unsqueeze(1).repeat(1, mc_sample_size, 1, 1) 336 | hidden = hidden.view(model.gru_layers, -1, model.hidden_dim) 337 | 338 | if controls is not None: 339 | assert controls.shape == (steps, batch_size, model.control_dim) 340 | controls = controls.unsqueeze(1) # [s, 1, b, c] 341 | controls = controls.repeat(1, mc_sample_size, 1, 1) # [s, mcs, b, c] 342 | controls = controls.view(steps, -1, model.control_dim) # [s, mcs * b, c] 343 | 344 | event = generated[-1].unsqueeze(0) # [1, mcs * b] 345 | control = None # default when controls is None 346 | outputs = [] 347 | 348 | for i in range(steps): 349 | if controls is not None: 350 | control = controls[i].unsqueeze(0) # [1, mcs * b, c] 351 | 352 | output, hidden = model.forward(event, control=control, hidden=hidden) 353 | probs = model.output_fc_activation(output / mc_sample_factor) 354 | event = Categorical(probs).sample() # [1, mcs * b] 355 | outputs.append(event) 356 | 357 | sequences = torch.cat([generated, *outputs], 0) 358 | assert sequences.shape == (total_steps, mc_sample_size * batch_size) 359 | return sequences 360 | 361 | 362 | def train_generator(batch_size, init, events, controls): 363 | # Generator step 364 | hidden = model.init_to_hidden(init) 365 | event = model.get_primary_event(batch_size) 366 | outputs = [] 367 | generated = [] 368 | q_values = [] 369 | 370 | for step in Bar('MC Rollout').iter(range(steps)): 371 | control = controls[step].unsqueeze(0) if use_control else None 372 | output, hidden = model.forward(event, control=control, hidden=hidden) 373 | outputs.append(output) 374 | probs = model.output_fc_activation(output / mc_sample_factor) 375 | generated.append(Categorical(probs).sample()) 376 | 377 | with torch.no_grad(): 378 | if step < steps - 1: 379 | sequences = mc_rollout(generated, hidden, steps, controls[step+1:]) 380 | mc_score = discriminator(sequences) # [mcs * b] 381 | mc_score = mc_score.view(mc_sample_size, batch_size) # [mcs, b] 382 | q_value = mc_score.mean(0, keepdim=True) # [1, batch_size] 383 | 384 | else: 385 | q_value = discriminator(torch.cat(generated, 0)) 386 | q_value = q_value.unsqueeze(0) # [1, batch_size] 387 | 388 | q_values.append(q_value) 389 | 390 | # Compute loss 391 | q_values = torch.cat(q_values, 0) # [steps, batch_size] 392 | q_mean = q_values.mean().detach() 393 | q_values = q_values - q_mean 394 | generated = torch.cat(generated, 0) # [steps, batch_size] 395 | outputs = torch.cat(outputs, 0) # [steps, batch_size, event_dim] 396 | loss = F.cross_entropy(outputs.view(-1, model.event_dim), 397 | generated.view(-1), 398 | reduce=False) 399 | loss = (loss * q_values.view(-1)).mean() 400 | 401 | # Backprop 402 | model.zero_grad() 403 | loss.backward() 404 | 405 | # Gradient clipping 406 | norm = utils.compute_gradient_norm(model.parameters()) 407 | if model_gradient_clipping: 408 | nn.utils.clip_grad_norm_(model.parameters(), model_gradient_clipping) 409 | 410 | model_optimizer.step() 411 | 412 | q_mean = q_mean.item() 413 | norm = norm.item() 414 | return q_mean, norm 415 | 416 | def train_discriminator(batch_size, init, events, controls): 417 | # Discriminator step 418 | with torch.no_grad(): 419 | generated = model.generate(init, steps, None, controls, 420 | greedy=0, temperature=mc_sample_factor) 421 | 422 | fake_logit = discriminator(generated, output_logits=True) 423 | real_logit = discriminator(events, output_logits=True) 424 | fake_target = torch.zeros_like(fake_logit) 425 | real_target = torch.ones_like(real_logit) 426 | 427 | # 
Compute loss 428 | fake_loss = F.binary_cross_entropy_with_logits(fake_logit, fake_target) 429 | real_loss = F.binary_cross_entropy_with_logits(real_logit, real_target) 430 | loss = (real_loss + fake_loss) / 2 431 | 432 | # Backprop 433 | discriminator.zero_grad() 434 | loss.backward() 435 | 436 | # Gradient clipping 437 | norm = utils.compute_gradient_norm(discriminator.parameters()) 438 | if discriminator_gradient_clipping: 439 | nn.utils.clip_grad_norm_(discriminator.parameters(), discriminator_gradient_clipping) 440 | 441 | discriminator_optimizer.step() 442 | 443 | real_loss = real_loss.item() 444 | fake_loss = fake_loss.item() 445 | loss = loss.item() 446 | norm = norm.item() 447 | return loss, real_loss, fake_loss, norm 448 | 449 | try: 450 | last_save_time = time.time() 451 | step_for = first_to_train 452 | g_steps = 0 453 | d_steps = 0 454 | 455 | for i, (events, controls) in enumerate(batch_data_generator): 456 | steps, batch_size = events.shape 457 | init = torch.randn(batch_size, model.init_dim).to(device) 458 | events = torch.LongTensor(events).to(device) 459 | 460 | use_control = np.random.random() <= control_ratio 461 | controls = torch.FloatTensor(controls).to(device) if use_control else None 462 | 463 | if step_for == 'G': 464 | q_mean, norm = train_generator(batch_size, init, events, controls) 465 | g_steps += 1 466 | 467 | print(f'{i} (G-step) Q_mean: {q_mean}, norm: {norm}') 468 | if enable_logging: 469 | writer.add_scalar('adversarial/G/Q_mean', q_mean, i) 470 | writer.add_scalar('adversarial/G/norm', norm, i) 471 | 472 | if q_mean < g_min_q_mean: 473 | print(f'Q is too small: {q_mean}, exiting') 474 | raise KeyboardInterrupt 475 | 476 | if q_mean > g_max_q_mean or (g_max_steps and g_steps >= g_max_steps): 477 | step_for = 'D' 478 | d_steps = 0 479 | 480 | if step_for == 'D': 481 | loss, real_loss, fake_loss, norm = train_discriminator(batch_size, init, events, controls) 482 | d_steps += 1 483 | 484 | print(f'{i} (D-step) loss: {loss} (real: {real_loss}, fake: {fake_loss}), norm: {norm}') 485 | if enable_logging: 486 | writer.add_scalar('adversarial/D/loss', loss, i) 487 | writer.add_scalar('adversarial/D/norm', norm, i) 488 | 489 | if fake_loss <= real_loss < d_min_loss or (d_max_steps and d_steps >= d_max_steps): 490 | step_for = 'G' 491 | g_steps = 0 492 | 493 | if last_save_time + save_interval < time.time(): 494 | last_save_time = time.time() 495 | save() 496 | 497 | except KeyboardInterrupt: 498 | save() 499 | 500 | 501 | 502 | #======================================================================== 503 | # Script Arguments 504 | #======================================================================== 505 | 506 | def batch_generator(args): 507 | print('-' * 70) 508 | dataset = Dataset(args.dataset_path, verbose=True) 509 | print(dataset) 510 | return dataset.batches(args.batch_size, args.window_size, args.stride_size) 511 | 512 | def pretrain(args): 513 | pretrain_discriminator(model_sess_path=args.generator_session_path, 514 | discriminator_sess_path=args.discriminator_session_path, 515 | discriminator_config_overwrite=utils.params2dict(args.discriminator_parameters), 516 | batch_data_generator=args.batch_generator(args), 517 | gradient_clipping=args.gradient_clipping, 518 | sample_factor=args.sample_factor, 519 | auto_sample_factor=args.auto_sample_factor, 520 | control_ratio=args.control_ratio, 521 | num_iter=args.stop_iteration, 522 | save_interval=args.save_interval, 523 | discriminator_lr=args.discriminator_learning_rate, 524 | 
enable_logging=args.enable_logging) 525 | 526 | def adversarial(args): 527 | train_adversarial(sess_path=args.session_path, 528 | batch_data_generator=args.batch_generator(args), 529 | model_load_path=args.generator_load_path, 530 | discriminator_load_path=args.discriminator_load_path, 531 | model_optimizer_class=getattr(optim, args.generator_optimizer), 532 | discriminator_optimizer_class=getattr(optim, args.discriminator_optimizer), 533 | model_gradient_clipping=args.generator_gradient_clipping, 534 | discriminator_gradient_clipping=args.discriminator_gradient_clipping, 535 | model_learning_rate=args.generator_learning_rate, 536 | discriminator_learning_rate=args.discriminator_learning_rate, 537 | reset_model_optimizer=args.reset_generator_optimizer, 538 | reset_discriminator_optimizer=args.reset_discriminator_optimizer, 539 | g_max_q_mean=args.g_max_q_mean, 540 | g_min_q_mean=args.g_min_q_mean, 541 | d_min_loss=args.d_min_loss, 542 | g_max_steps=args.g_max_steps, 543 | d_max_steps=args.d_max_steps, 544 | mc_sample_size=args.monte_carlo_sample_size, 545 | mc_sample_factor=args.monte_carlo_sample_factor, 546 | first_to_train=args.first_to_train, 547 | control_ratio=args.control_ratio, 548 | save_interval=args.save_interval, 549 | enable_logging=args.enable_logging) 550 | 551 | def get_args(): 552 | parser = argparse.ArgumentParser() 553 | subparsers = parser.add_subparsers() 554 | parser.add_argument('-d', '--dataset-path', type=str, required=True) 555 | parser.add_argument('-b', '--batch-size', type=int, default=64) 556 | parser.add_argument('-w', '--window-size', type=int, default=200) 557 | parser.add_argument('-s', '--stride-size', type=int, default=10) 558 | parser.set_defaults(batch_generator=batch_generator) 559 | pre_parser = subparsers.add_parser('pretrain', aliases=['p', 'pre']) 560 | pre_parser.add_argument('-G', '--generator-session-path', type=str, required=True) 561 | pre_parser.add_argument('-D', '--discriminator-session-path', type=str, required=True) 562 | pre_parser.add_argument('-p', '--discriminator-parameters', type=str, default='') 563 | pre_parser.add_argument('-l', '--discriminator-learning-rate', type=float, default=0.001) 564 | pre_parser.add_argument('-g', '--gradient-clipping', type=float, default=False) 565 | pre_parser.add_argument('-f', '--sample-factor', type=float, default=1.0) 566 | pre_parser.add_argument('-af', '--auto-sample-factor', action='store_true', default=False) 567 | pre_parser.add_argument('-c', '--control-ratio', type=float, default=1.0) 568 | pre_parser.add_argument('-n', '--stop-iteration', type=int, default=-1) 569 | pre_parser.add_argument('-i', '--save-interval', type=float, default=60.0) 570 | pre_parser.add_argument('-L', '--enable-logging', action='store_true', default=False) 571 | pre_parser.set_defaults(main=pretrain) 572 | adv_parser = subparsers.add_parser('adversarial', aliases=['a', 'adv']) 573 | adv_parser.add_argument('-S', '--session-path', type=str, required=True) 574 | adv_parser.add_argument('-Gp', '--generator-load-path', type=str) 575 | adv_parser.add_argument('-Dp', '--discriminator-load-path', type=str) 576 | adv_parser.add_argument('-Go', '--generator-optimizer', type=str, default='Adam') 577 | adv_parser.add_argument('-Do', '--discriminator-optimizer', type=str, default='RMSprop') 578 | adv_parser.add_argument('-Gg', '--generator-gradient-clipping', type=float, default=False) 579 | adv_parser.add_argument('-Dg', '--discriminator-gradient-clipping', type=float, default=False) 580 | adv_parser.add_argument('-Gl', 
'--generator-learning-rate', type=float, default=0.001) 581 | adv_parser.add_argument('-Dl', '--discriminator-learning-rate', type=float, default=0.001) 582 | adv_parser.add_argument('-Gr', '--reset-generator-optimizer', action='store_true', default=False) 583 | adv_parser.add_argument('-Dr', '--reset-discriminator-optimizer', action='store_true', default=False) 584 | adv_parser.add_argument('-Gq', '--g-max-q-mean', type=float, default=0.5) 585 | adv_parser.add_argument('-Gm', '--g-min-q-mean', type=float, default=0.0) 586 | adv_parser.add_argument('-Dm', '--d-min-loss', type=float, default=0.5) 587 | adv_parser.add_argument('-Gs', '--g-max-steps', type=int, default=0) 588 | adv_parser.add_argument('-Ds', '--d-max-steps', type=int, default=0) 589 | adv_parser.add_argument('-f', '--first-to-train', type=str, default='G', choices=['G', 'D']) 590 | adv_parser.add_argument('-ms', '--monte-carlo-sample-size', type=int, default=8) 591 | adv_parser.add_argument('-mf', '--monte-carlo-sample-factor', type=float, default=1.0) 592 | adv_parser.add_argument('-c', '--control-ratio', type=float, default=1.0) 593 | adv_parser.add_argument('-i', '--save-interval', type=float, default=60.0) 594 | adv_parser.add_argument('-L', '--enable-logging', action='store_true', default=False) 595 | adv_parser.set_defaults(main=adversarial) 596 | return parser.parse_args() 597 | 598 | 599 | if __name__ == '__main__': 600 | args = get_args() 601 | args.main(args) 602 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sequence import EventSeq, ControlSeq 3 | 4 | #pylint: disable=E1101 5 | 6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 7 | 8 | model = { 9 | 'init_dim': 32, 10 | 'event_dim': EventSeq.dim(), 11 | 'control_dim': ControlSeq.dim(), 12 | 'hidden_dim': 512, 13 | 'gru_layers': 3, 14 | 'gru_dropout': 0.3, 15 | } 16 | 17 | train = { 18 | 'learning_rate': 0.001, 19 | 'batch_size': 64, 20 | 'window_size': 200, 21 | 'stride_size': 10, 22 | 'use_transposition': False, 23 | 'control_ratio': 1.0, 24 | 'teacher_forcing_ratio': 1.0 25 | } 26 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import itertools, os 4 | import numpy as np 5 | from progress.bar import Bar 6 | 7 | import config 8 | import utils 9 | from sequence import EventSeq, ControlSeq 10 | 11 | # pylint: disable=E1101 12 | # pylint: disable=W0101 13 | 14 | class Dataset: 15 | def __init__(self, root, verbose=False): 16 | assert os.path.isdir(root), root 17 | paths = utils.find_files_by_extensions(root, ['.data']) 18 | self.root = root 19 | self.samples = [] 20 | self.seqlens = [] 21 | if verbose: 22 | paths = Bar(root).iter(list(paths)) 23 | for path in paths: 24 | eventseq, controlseq = torch.load(path) 25 | controlseq = ControlSeq.recover_compressed_array(controlseq) 26 | assert len(eventseq) == len(controlseq) 27 | self.samples.append((eventseq, controlseq)) 28 | self.seqlens.append(len(eventseq)) 29 | self.avglen = np.mean(self.seqlens) 30 | 31 | def batches(self, batch_size, window_size, stride_size): 32 | indeces = [(i, range(j, j + window_size)) 33 | for i, seqlen in enumerate(self.seqlens) 34 | for j in range(0, seqlen - window_size, stride_size)] 35 | while True: 36 | eventseq_batch = [] 37 | controlseq_batch = [] 38 | 
n = 0 39 | for ii in np.random.permutation(len(indeces)): 40 | i, r = indeces[ii] 41 | eventseq, controlseq = self.samples[i] 42 | eventseq = eventseq[r.start:r.stop] 43 | controlseq = controlseq[r.start:r.stop] 44 | eventseq_batch.append(eventseq) 45 | controlseq_batch.append(controlseq) 46 | n += 1 47 | if n == batch_size: 48 | yield (np.stack(eventseq_batch, axis=1), 49 | np.stack(controlseq_batch, axis=1)) 50 | eventseq_batch.clear() 51 | controlseq_batch.clear() 52 | n = 0 53 | 54 | def __repr__(self): 55 | return (f'Dataset(root="{self.root}", ' 56 | f'samples={len(self.samples)}, ' 57 | f'avglen={self.avglen})') 58 | -------------------------------------------------------------------------------- /dataset/midi/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/dataset/midi/.keep -------------------------------------------------------------------------------- /dataset/processed/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/dataset/processed/.keep -------------------------------------------------------------------------------- /dataset/scripts/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/dataset/scripts/.keep -------------------------------------------------------------------------------- /dataset/scripts/classic_piano_downloader.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Scraper for Classical Piano Midi Page 3 | [ ! 
"$1" ] && echo 'Error: please specify output dir' && exit 4 | dir=$1 5 | base=http://www.piano-midi.de 6 | pages=$(curl -s --max-time 5 $base/midi_files.htm \ 7 | | grep ' 0 94 | stochastic_beam_search = opt.stochastic_beam_search 95 | beam_size = opt.beam_size 96 | temperature = opt.temperature 97 | init_zero = opt.init_zero 98 | 99 | if use_beam_search: 100 | greedy_ratio = 'DISABLED' 101 | else: 102 | beam_size = 'DISABLED' 103 | 104 | assert os.path.isfile(sess_path), f'"{sess_path}" is not a file' 105 | 106 | if control is not None: 107 | if os.path.isfile(control) or os.path.isdir(control): 108 | if os.path.isdir(control): 109 | files = list(utils.find_files_by_extensions(control)) 110 | assert len(files) > 0, f'no file in "{control}"' 111 | control = np.random.choice(files) 112 | _, compressed_controls = torch.load(control) 113 | controls = ControlSeq.recover_compressed_array(compressed_controls) 114 | if max_len == 0: 115 | max_len = controls.shape[0] 116 | controls = torch.tensor(controls, dtype=torch.float32) 117 | controls = controls.unsqueeze(1).repeat(1, batch_size, 1).to(device) 118 | control = f'control sequence from "{control}"' 119 | 120 | else: 121 | pitch_histogram, note_density = control.split(';') 122 | pitch_histogram = list(filter(len, pitch_histogram.split(','))) 123 | if len(pitch_histogram) == 0: 124 | pitch_histogram = np.ones(12) / 12 125 | else: 126 | pitch_histogram = np.array(list(map(float, pitch_histogram))) 127 | assert pitch_histogram.size == 12 128 | assert np.all(pitch_histogram >= 0) 129 | pitch_histogram = pitch_histogram / pitch_histogram.sum() \ 130 | if pitch_histogram.sum() else np.ones(12) / 12 131 | note_density = int(note_density) 132 | assert note_density in range(len(ControlSeq.note_density_bins)) 133 | control = Control(pitch_histogram, note_density) 134 | controls = torch.tensor(control.to_array(), dtype=torch.float32) 135 | controls = controls.repeat(1, batch_size, 1).to(device) 136 | control = repr(control) 137 | 138 | else: 139 | controls = None 140 | control = 'NONE' 141 | 142 | assert max_len > 0, 'either max length or control sequence length should be given' 143 | 144 | # ------------------------------------------------------------------------ 145 | 146 | print('-' * 70) 147 | print('Session:', sess_path) 148 | print('Batch size:', batch_size) 149 | print('Max length:', max_len) 150 | print('Greedy ratio:', greedy_ratio) 151 | print('Beam size:', beam_size) 152 | print('Beam search stochastic:', stochastic_beam_search) 153 | print('Output directory:', output_dir) 154 | print('Controls:', control) 155 | print('Temperature:', temperature) 156 | print('Init zero:', init_zero) 157 | print('-' * 70) 158 | 159 | 160 | # ======================================================================== 161 | # Generating 162 | # ======================================================================== 163 | 164 | state = torch.load(sess_path, map_location=device) 165 | model = PerformanceRNN(**state['model_config']).to(device) 166 | model.load_state_dict(state['model_state']) 167 | model.eval() 168 | print(model) 169 | print('-' * 70) 170 | 171 | if init_zero: 172 | init = torch.zeros(batch_size, model.init_dim).to(device) 173 | else: 174 | init = torch.randn(batch_size, model.init_dim).to(device) 175 | 176 | with torch.no_grad(): 177 | if use_beam_search: 178 | outputs = model.beam_search(init, max_len, beam_size, 179 | controls=controls, 180 | temperature=temperature, 181 | stochastic=stochastic_beam_search, 182 | verbose=True) 183 | else: 184 | 
outputs = model.generate(init, max_len, 185 | controls=controls, 186 | greedy=greedy_ratio, 187 | temperature=temperature, 188 | verbose=True) 189 | 190 | outputs = outputs.cpu().numpy().T # [batch, steps] 191 | 192 | 193 | # ======================================================================== 194 | # Saving 195 | # ======================================================================== 196 | 197 | os.makedirs(output_dir, exist_ok=True) 198 | 199 | for i, output in enumerate(outputs): 200 | name = f'output-{i:03d}.mid' 201 | path = os.path.join(output_dir, name) 202 | n_notes = utils.event_indeces_to_midi_file(output, path) 203 | print(f'===> {path} ({n_notes} notes)') 204 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Categorical, Gumbel 5 | 6 | from collections import namedtuple 7 | import numpy as np 8 | from progress.bar import Bar 9 | from config import device 10 | 11 | # pylint: disable=E1101,E1102 12 | 13 | 14 | class PerformanceRNN(nn.Module): 15 | def __init__(self, event_dim, control_dim, init_dim, hidden_dim, 16 | gru_layers=3, gru_dropout=0.3): 17 | super().__init__() 18 | 19 | self.event_dim = event_dim 20 | self.control_dim = control_dim 21 | self.init_dim = init_dim 22 | self.hidden_dim = hidden_dim 23 | self.gru_layers = gru_layers 24 | self.concat_dim = event_dim + 1 + control_dim 25 | self.input_dim = hidden_dim 26 | self.output_dim = event_dim 27 | 28 | self.primary_event = self.event_dim - 1 29 | 30 | self.inithid_fc = nn.Linear(init_dim, gru_layers * hidden_dim) 31 | self.inithid_fc_activation = nn.Tanh() 32 | 33 | self.event_embedding = nn.Embedding(event_dim, event_dim) 34 | self.concat_input_fc = nn.Linear(self.concat_dim, self.input_dim) 35 | self.concat_input_fc_activation = nn.LeakyReLU(0.1, inplace=True) 36 | 37 | self.gru = nn.GRU(self.input_dim, self.hidden_dim, 38 | num_layers=gru_layers, dropout=gru_dropout) 39 | self.output_fc = nn.Linear(hidden_dim * gru_layers, self.output_dim) 40 | self.output_fc_activation = nn.Softmax(dim=-1) 41 | 42 | self._initialize_weights() 43 | 44 | def _initialize_weights(self): 45 | nn.init.xavier_normal_(self.event_embedding.weight) 46 | nn.init.xavier_normal_(self.inithid_fc.weight) 47 | self.inithid_fc.bias.data.fill_(0.) 48 | nn.init.xavier_normal_(self.concat_input_fc.weight) 49 | nn.init.xavier_normal_(self.output_fc.weight) 50 | self.output_fc.bias.data.fill_(0.) 
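# A minimal sampling sketch, mirroring how generate.py drives this class
# (batch_size, steps and device are assumed to be defined by the caller):
#
#     init = torch.randn(batch_size, model.init_dim).to(device)
#     events = model.generate(init, steps, controls=None,
#                             greedy=0.8, temperature=1.0)  # [steps, batch_size]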
51 | 52 | def _sample_event(self, output, greedy=True, temperature=1.0): 53 | if greedy: 54 | return output.argmax(-1) 55 | else: 56 | output = output / temperature 57 | probs = self.output_fc_activation(output) 58 | return Categorical(probs).sample() 59 | 60 | def forward(self, event, control=None, hidden=None): 61 | # One step forward 62 | 63 | assert len(event.shape) == 2 64 | assert event.shape[0] == 1 65 | batch_size = event.shape[1] 66 | event = self.event_embedding(event) 67 | 68 | if control is None: 69 | default = torch.ones(1, batch_size, 1).to(device) 70 | control = torch.zeros(1, batch_size, self.control_dim).to(device) 71 | else: 72 | default = torch.zeros(1, batch_size, 1).to(device) 73 | assert control.shape == (1, batch_size, self.control_dim) 74 | 75 | concat = torch.cat([event, default, control], -1) 76 | input = self.concat_input_fc(concat) 77 | input = self.concat_input_fc_activation(input) 78 | 79 | _, hidden = self.gru(input, hidden) 80 | output = hidden.permute(1, 0, 2).contiguous() 81 | output = output.view(batch_size, -1).unsqueeze(0) 82 | output = self.output_fc(output) 83 | return output, hidden 84 | 85 | def get_primary_event(self, batch_size): 86 | return torch.LongTensor([[self.primary_event] * batch_size]).to(device) 87 | 88 | def init_to_hidden(self, init): 89 | # [batch_size, init_dim] 90 | batch_size = init.shape[0] 91 | out = self.inithid_fc(init) 92 | out = self.inithid_fc_activation(out) 93 | out = out.view(self.gru_layers, batch_size, self.hidden_dim) 94 | return out 95 | 96 | def expand_controls(self, controls, steps): 97 | # [1 or steps, batch_size, control_dim] 98 | assert len(controls.shape) == 3 99 | assert controls.shape[2] == self.control_dim 100 | if controls.shape[0] > 1: 101 | assert controls.shape[0] >= steps 102 | return controls[:steps] 103 | return controls.repeat(steps, 1, 1) 104 | 105 | def generate(self, init, steps, events=None, controls=None, greedy=1.0, 106 | temperature=1.0, teacher_forcing_ratio=1.0, output_type='index', verbose=False): 107 | # init [batch_size, init_dim] 108 | # events [steps, batch_size] indeces 109 | # controls [1 or steps, batch_size, control_dim] 110 | 111 | batch_size = init.shape[0] 112 | assert init.shape[1] == self.init_dim 113 | assert steps > 0 114 | 115 | use_teacher_forcing = events is not None 116 | if use_teacher_forcing: 117 | assert len(events.shape) == 2 118 | assert events.shape[0] >= steps - 1 119 | events = events[:steps-1] 120 | 121 | event = self.get_primary_event(batch_size) 122 | use_control = controls is not None 123 | if use_control: 124 | controls = self.expand_controls(controls, steps) 125 | hidden = self.init_to_hidden(init) 126 | 127 | outputs = [] 128 | step_iter = range(steps) 129 | if verbose: 130 | step_iter = Bar('Generating').iter(step_iter) 131 | 132 | for step in step_iter: 133 | control = controls[step].unsqueeze(0) if use_control else None 134 | output, hidden = self.forward(event, control, hidden) 135 | 136 | use_greedy = np.random.random() < greedy 137 | event = self._sample_event(output, greedy=use_greedy, 138 | temperature=temperature) 139 | 140 | if output_type == 'index': 141 | outputs.append(event) 142 | elif output_type == 'softmax': 143 | outputs.append(self.output_fc_activation(output)) 144 | elif output_type == 'logit': 145 | outputs.append(output) 146 | else: 147 | assert False 148 | 149 | if use_teacher_forcing and step < steps - 1: # avoid last one 150 | if np.random.random() <= teacher_forcing_ratio: 151 | event = events[step].unsqueeze(0) 152 | 153 | return 
torch.cat(outputs, 0) 154 | 155 | def beam_search(self, init, steps, beam_size, controls=None, 156 | temperature=1.0, stochastic=False, verbose=False): 157 | assert len(init.shape) == 2 and init.shape[1] == self.init_dim 158 | assert self.event_dim >= beam_size > 0 and steps > 0 159 | 160 | batch_size = init.shape[0] 161 | current_beam_size = 1 162 | 163 | if controls is not None: 164 | controls = self.expand_controls(controls, steps) # [steps, batch_size, control_dim] 165 | 166 | # Initial hidden weights 167 | hidden = self.init_to_hidden(init) # [gru_layers, batch_size, hidden_size] 168 | hidden = hidden[:, :, None, :] # [gru_layers, batch_size, 1, hidden_size] 169 | hidden = hidden.repeat(1, 1, current_beam_size, 1) # [gru_layers, batch_size, beam_size, hidden_dim] 170 | 171 | 172 | # Initial event 173 | event = self.get_primary_event(batch_size) # [1, batch] 174 | event = event[:, :, None].repeat(1, 1, current_beam_size) # [1, batch, 1] 175 | 176 | # [batch, beam, 1] event sequences of beams 177 | beam_events = event[0, :, None, :].repeat(1, current_beam_size, 1) 178 | 179 | # [batch, beam] log probs sum of beams 180 | beam_log_prob = torch.zeros(batch_size, current_beam_size).to(device) 181 | 182 | if stochastic: 183 | # [batch, beam] Gumbel perturbed log probs of beams 184 | beam_log_prob_perturbed = torch.zeros(batch_size, current_beam_size).to(device) 185 | beam_z = torch.full((batch_size, beam_size), float('inf')) 186 | gumbel_dist = Gumbel(0, 1) 187 | 188 | step_iter = range(steps) 189 | if verbose: 190 | step_iter = Bar(['', 'Stochastic '][stochastic] + 'Beam Search').iter(step_iter) 191 | 192 | for step in step_iter: 193 | if controls is not None: 194 | control = controls[step, None, :, None, :] # [1, batch, 1, control] 195 | control = control.repeat(1, 1, current_beam_size, 1) # [1, batch, beam, control] 196 | control = control.view(1, batch_size * current_beam_size, self.control_dim) # [1, batch*beam, control] 197 | else: 198 | control = None 199 | 200 | event = event.view(1, batch_size * current_beam_size) # [1, batch*beam0] 201 | hidden = hidden.view(self.gru_layers, batch_size * current_beam_size, self.hidden_dim) # [grus, batch*beam, hid] 202 | 203 | logits, hidden = self.forward(event, control, hidden) 204 | hidden = hidden.view(self.gru_layers, batch_size, current_beam_size, self.hidden_dim) # [grus, batch, cbeam, hid] 205 | logits = (logits / temperature).view(1, batch_size, current_beam_size, self.event_dim) # [1, batch, cbeam, out] 206 | 207 | beam_log_prob_expand = logits + beam_log_prob[None, :, :, None] # [1, batch, cbeam, out] 208 | beam_log_prob_expand_batch = beam_log_prob_expand.view(1, batch_size, -1) # [1, batch, cbeam*out] 209 | 210 | if stochastic: 211 | beam_log_prob_expand_perturbed = beam_log_prob_expand + gumbel_dist.sample(beam_log_prob_expand.shape).to(device) 212 | beam_log_prob_Z, _ = beam_log_prob_expand_perturbed.max(-1) # [1, batch, cbeam] 213 | # print(beam_log_prob_Z) 214 | beam_log_prob_expand_perturbed_normalized = beam_log_prob_expand_perturbed 215 | # beam_log_prob_expand_perturbed_normalized = -torch.log( 216 | # torch.exp(-beam_log_prob_perturbed[None, :, :, None]) 217 | # - torch.exp(-beam_log_prob_Z[:, :, :, None]) 218 | # + torch.exp(-beam_log_prob_expand_perturbed)) # [1, batch, cbeam, out] 219 | # beam_log_prob_expand_perturbed_normalized = beam_log_prob_perturbed[None, :, :, None] + beam_log_prob_expand_perturbed # [1, batch, cbeam, out] 220 | 221 | beam_log_prob_expand_perturbed_normalized_batch = \ 
beam_log_prob_expand_perturbed_normalized.view(1, batch_size, -1) # [1, batch, cbeam*out] 223 | _, top_indices = beam_log_prob_expand_perturbed_normalized_batch.topk(beam_size, -1) # [1, batch, cbeam] 224 | 225 | beam_log_prob_perturbed = \ 226 | torch.gather(beam_log_prob_expand_perturbed_normalized_batch, -1, top_indices)[0] # [batch, beam] 227 | 228 | else: 229 | _, top_indices = beam_log_prob_expand_batch.topk(beam_size, -1) 230 | 231 | beam_log_prob = torch.gather(beam_log_prob_expand_batch, -1, top_indices)[0] # [batch, beam] 232 | 233 | beam_index_old = torch.arange(current_beam_size)[None, None, :, None].to(device) # [1, 1, cbeam, 1] 234 | beam_index_old = beam_index_old.repeat(1, batch_size, 1, self.output_dim) # [1, batch, cbeam, out] 235 | beam_index_old = beam_index_old.view(1, batch_size, -1) # [1, batch, cbeam*out] 236 | beam_index_new = torch.gather(beam_index_old, -1, top_indices) 237 | 238 | hidden = torch.gather(hidden, 2, beam_index_new[:, :, :, None].repeat(self.gru_layers, 1, 1, self.hidden_dim)) 239 | 240 | event_index = torch.arange(self.output_dim)[None, None, None, :].to(device) # [1, 1, 1, out] 241 | event_index = event_index.repeat(1, batch_size, current_beam_size, 1) # [1, batch, cbeam, out] 242 | event_index = event_index.view(1, batch_size, -1) # [1, batch, cbeam*out] 243 | event = torch.gather(event_index, -1, top_indices) # [1, batch, beam] 244 | 245 | beam_events = torch.gather(beam_events[None], 2, beam_index_new.unsqueeze(-1).repeat(1, 1, 1, beam_events.shape[-1])) 246 | beam_events = torch.cat([beam_events, event.unsqueeze(-1)], -1)[0] 247 | 248 | current_beam_size = beam_size 249 | 250 | best = beam_events[torch.arange(batch_size).long(), beam_log_prob.argmax(-1)] 251 | best = best.contiguous().t() 252 | return best 253 | -------------------------------------------------------------------------------- /output/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/output/.keep -------------------------------------------------------------------------------- /play.py: -------------------------------------------------------------------------------- 1 | info = ''' 2 | ------------------------------------ 3 | Simple MIDI Player & Note Visualizer 4 | ------------------------------------ 5 | 6 | Please install pygame and fluidsynth for Python 3: 7 | 8 | pip3 install pygame 9 | pip3 install git+https://github.com/txomon/pyfluidsynth.git 10 | 11 | Usage: 12 | 13 | python3 play.py soundfont_path midi_files... 
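e.g. python3 play.py /path/to/soundfont.sf2 output/output-000.mid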
14 | 15 | '''.lstrip() 16 | 17 | 18 | import os, sys, time, threading 19 | import pygame, fluidsynth 20 | import sequence, pretty_midi 21 | import numpy as np 22 | 23 | # pylint: disable=E1101 24 | 25 | entities = [] 26 | lock = threading.Lock() 27 | done = False 28 | 29 | width = 1024 30 | height = 768 31 | 32 | #==================================================================== 33 | # Display 34 | #==================================================================== 35 | 36 | class NoteEntity: 37 | def __init__(self, key, velocity): 38 | self.done = False 39 | self.age = 255 40 | self.color = pygame.color.Color('black') 41 | self.radius = int(40 * velocity / 128) 42 | self.velocity = np.array([0., -velocity / 10], dtype=np.float32) 43 | pr = sequence.DEFAULT_PITCH_RANGE 44 | x = key * width / 128 45 | self.position = np.array([x, height], dtype=np.float32) 46 | 47 | def update(self): 48 | self.position += self.velocity 49 | self.velocity += np.random.randn(2) / 20 50 | self.velocity *= 0.99 51 | self.age -= 1 52 | if self.age == 0: 53 | self.done = True 54 | 55 | def get_color(self): 56 | return pygame.color.Color(self.age, self.age, self.age) 57 | 58 | def render(self, screen): 59 | x, y = self.position.astype(np.int32) 60 | color = self.get_color() 61 | pygame.draw.circle(screen, color, (x, y), self.radius) 62 | 63 | def add_entitiy(key, velocity): 64 | global entities, lock 65 | lock.acquire() 66 | entities.append(NoteEntity(key, velocity)) 67 | lock.release() 68 | 69 | def display(): 70 | global entities, lock, width, height, done 71 | pygame.init() 72 | pygame.display.set_caption('generator') 73 | screen = pygame.display.set_mode((width, height)) 74 | clock = pygame.time.Clock() 75 | 76 | while not done: 77 | for event in pygame.event.get(): 78 | if event.type == pygame.QUIT: 79 | done = True 80 | 81 | screen.fill(pygame.color.Color('black')) 82 | 83 | lock.acquire() 84 | entities = [entity for entity in entities if not entity.done] 85 | for i, entity in enumerate(entities): 86 | if i + 1 < len(entities): 87 | next_ent = entities[i + 1] 88 | pygame.draw.line(screen, entity.get_color(), entity.position, next_ent.position) 89 | for entity in entities: 90 | entity.render(screen) 91 | entity.update() 92 | lock.release() 93 | 94 | pygame.display.flip() 95 | clock.tick(60) 96 | 97 | pygame.quit() 98 | 99 | 100 | 101 | #==================================================================== 102 | # Synth 103 | #==================================================================== 104 | 105 | def note_repr(key, velocity): 106 | octave = key // 12 - 1 107 | name = ['C', 'C#', 'D', 'D#', 'E', 'F', 108 | 'F#', 'G', 'G#', 'A', 'A#', 'B'][key % 12] 109 | vel = int(10 * velocity / 128) 110 | return f'({name}{octave} {vel})' 111 | 112 | 113 | def play(midi_files, sound_font_path): 114 | global done 115 | 116 | fs = fluidsynth.Synth(gain=5) 117 | fs.start() 118 | 119 | try: 120 | sfid = fs.sfload(sound_font_path) 121 | fs.program_select(0, sfid, 0, 0) 122 | except: 123 | print('Failed to load', sound_font_path) 124 | return 125 | 126 | for midi_file in midi_files: 127 | 128 | print(f'Playing {midi_file}') 129 | 130 | try: 131 | note_seq = sequence.NoteSeq.from_midi_file(midi_file) 132 | event_seq = sequence.EventSeq.from_note_seq(note_seq) 133 | except: 134 | print('Failed to load', midi_file) 135 | continue 136 | 137 | velocity = sequence.DEFAULT_VELOCITY 138 | velocity_bins = sequence.EventSeq.get_velocity_bins() 139 | time_shift_bins = sequence.EventSeq.time_shift_bins 140 | pitch_start = 
sequence.EventSeq.pitch_range.start 141 | 142 | for event in event_seq.events: 143 | if event.type == 'note_on': 144 | key = int(event.value + pitch_start) 145 | fs.noteon(0, key, velocity) 146 | print(f' {note_repr(key, velocity)} ', end='', flush=True) 147 | add_entitiy(key, velocity) 148 | elif event.type == 'note_off': 149 | key = int(event.value + pitch_start) 150 | fs.noteoff(0, key) 151 | elif event.type == 'time_shift': 152 | print('.', end='', flush=True) 153 | time.sleep(time_shift_bins[event.value]) 154 | elif event.type == 'velocity': 155 | velocity = int(velocity_bins[ 156 | min(event.value, velocity_bins.size - 1)]) 157 | 158 | print('Done') 159 | 160 | done = True 161 | 162 | 163 | #==================================================================== 164 | # Main 165 | #==================================================================== 166 | 167 | 168 | if __name__ == '__main__': 169 | 170 | try: 171 | sound_font_path = sys.argv[1] 172 | midi_files = sys.argv[2:] 173 | except: 174 | print(info) 175 | sys.exit() 176 | 177 | assert os.path.isfile(sound_font_path), sound_font_path 178 | for midi_path in midi_files: 179 | assert os.path.isfile(midi_path), midi_path 180 | 181 | threading.Thread(target=play, 182 | args=(midi_files, sound_font_path), 183 | daemon=True).start() 184 | 185 | display() 186 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import torch 5 | import hashlib 6 | from progress.bar import Bar 7 | from concurrent.futures import ProcessPoolExecutor 8 | 9 | from sequence import NoteSeq, EventSeq, ControlSeq 10 | import utils 11 | import config 12 | 13 | def preprocess_midi(path): 14 | note_seq = NoteSeq.from_midi_file(path) 15 | note_seq.adjust_time(-note_seq.notes[0].start) 16 | event_seq = EventSeq.from_note_seq(note_seq) 17 | control_seq = ControlSeq.from_event_seq(event_seq) 18 | return event_seq.to_array(), control_seq.to_compressed_array() 19 | 20 | def preprocess_midi_files_under(midi_root, save_dir, num_workers): 21 | midi_paths = list(utils.find_files_by_extensions(midi_root, ['.mid', '.midi'])) 22 | os.makedirs(save_dir, exist_ok=True) 23 | out_fmt = '{}-{}.data' 24 | 25 | results = [] 26 | executor = ProcessPoolExecutor(num_workers) 27 | 28 | for path in midi_paths: 29 | try: 30 | results.append((path, executor.submit(preprocess_midi, path))) 31 | except KeyboardInterrupt: 32 | print(' Abort') 33 | return 34 | except: 35 | print(' Error') 36 | continue 37 | 38 | for path, future in Bar('Processing').iter(results): 39 | print(' ', end='[{}]'.format(path), flush=True) 40 | name = os.path.basename(path) 41 | code = hashlib.md5(path.encode()).hexdigest() 42 | save_path = os.path.join(save_dir, out_fmt.format(name, code)) 43 | torch.save(future.result(), save_path) 44 | 45 | print('Done') 46 | 47 | if __name__ == '__main__': 48 | preprocess_midi_files_under( 49 | midi_root=sys.argv[1], 50 | save_dir=sys.argv[2], 51 | num_workers=int(sys.argv[3])) 52 | -------------------------------------------------------------------------------- /runs/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/runs/.keep -------------------------------------------------------------------------------- /save/.keep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/save/.keep -------------------------------------------------------------------------------- /sequence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import itertools 4 | import collections 5 | from pretty_midi import PrettyMIDI, Note, Instrument 6 | 7 | 8 | # ================================================================================== 9 | # Parameters 10 | # ================================================================================== 11 | 12 | # NoteSeq ------------------------------------------------------------------------- 13 | 14 | DEFAULT_SAVING_PROGRAM = 1 15 | DEFAULT_LOADING_PROGRAMS = range(128) 16 | DEFAULT_RESOLUTION = 220 17 | DEFAULT_TEMPO = 120 18 | DEFAULT_VELOCITY = 64 19 | DEFAULT_PITCH_RANGE = range(21, 109) 20 | DEFAULT_VELOCITY_RANGE = range(21, 109) 21 | DEFAULT_NORMALIZATION_BASELINE = 60 # C4 22 | 23 | # EventSeq ------------------------------------------------------------------------ 24 | 25 | USE_VELOCITY = True 26 | BEAT_LENGTH = 60 / DEFAULT_TEMPO 27 | DEFAULT_TIME_SHIFT_BINS = 1.15 ** np.arange(32) / 65 28 | DEFAULT_VELOCITY_STEPS = 32 29 | DEFAULT_NOTE_LENGTH = BEAT_LENGTH * 2 30 | MIN_NOTE_LENGTH = BEAT_LENGTH / 2 31 | 32 | # ControlSeq ---------------------------------------------------------------------- 33 | 34 | DEFAULT_WINDOW_SIZE = BEAT_LENGTH * 4 35 | DEFAULT_NOTE_DENSITY_BINS = np.arange(12) * 3 + 1 36 | 37 | 38 | # ================================================================================== 39 | # Notes 40 | # ================================================================================== 41 | 42 | class NoteSeq: 43 | 44 | @staticmethod 45 | def from_midi(midi, programs=DEFAULT_LOADING_PROGRAMS): 46 | notes = itertools.chain(*[ 47 | inst.notes for inst in midi.instruments 48 | if inst.program in programs and not inst.is_drum]) 49 | return NoteSeq(list(notes)) 50 | 51 | @staticmethod 52 | def from_midi_file(path, *args, **kwargs): 53 | midi = PrettyMIDI(path) 54 | return NoteSeq.from_midi(midi, *args, **kwargs) 55 | 56 | @staticmethod 57 | def merge(*note_seqs): 58 | notes = itertools.chain(*[seq.notes for seq in note_seqs]) 59 | return NoteSeq(list(notes)) 60 | 61 | def __init__(self, notes=[]): 62 | self.notes = [] 63 | if notes: 64 | for note in notes: 65 | assert isinstance(note, Note) 66 | notes = filter(lambda note: note.end >= note.start, notes) 67 | self.add_notes(list(notes)) 68 | 69 | def copy(self): 70 | return copy.deepcopy(self) 71 | 72 | def to_midi(self, program=DEFAULT_SAVING_PROGRAM, 73 | resolution=DEFAULT_RESOLUTION, tempo=DEFAULT_TEMPO): 74 | midi = PrettyMIDI(resolution=resolution, initial_tempo=tempo) 75 | inst = Instrument(program, False, 'NoteSeq') 76 | inst.notes = copy.deepcopy(self.notes) 77 | midi.instruments.append(inst) 78 | return midi 79 | 80 | def to_midi_file(self, path, *args, **kwargs): 81 | self.to_midi(*args, **kwargs).write(path) 82 | 83 | def add_notes(self, notes): 84 | self.notes += notes 85 | self.notes.sort(key=lambda note: note.start) 86 | 87 | def adjust_pitches(self, offset): 88 | for note in self.notes: 89 | pitch = note.pitch + offset 90 | pitch = 0 if pitch < 0 else pitch 91 | pitch = 127 if pitch > 127 else pitch 92 | note.pitch = pitch 93 | 94 | def adjust_velocities(self, offset): 95 | for note in self.notes: 96 | 
            velocity = note.velocity + offset
97 |             velocity = 0 if velocity < 0 else velocity
98 |             velocity = 127 if velocity > 127 else velocity
99 |             note.velocity = velocity
100 | 
101 |     def adjust_time(self, offset):
102 |         for note in self.notes:
103 |             note.start += offset
104 |             note.end += offset
105 | 
106 |     def trim_overlapped_notes(self, min_interval=0):
107 |         last_notes, kept_notes = {}, []
108 |         for note in self.notes:
109 |             last_note = last_notes.get(note.pitch)
110 |             if last_note is not None and note.start - last_note.start <= min_interval:
111 |                 last_note.end = max(note.end, last_note.end)
112 |                 last_note.velocity = max(note.velocity, last_note.velocity)
113 |                 continue  # merge into the previous note on this pitch
114 |             if last_note is not None and note.start < last_note.end:
115 |                 last_note.end = note.start
116 |             last_notes[note.pitch] = note
117 |             kept_notes.append(note)
118 |         self.notes = kept_notes
119 | 
120 | 
121 | # ==================================================================================
122 | # Events
123 | # ==================================================================================
124 | 
125 | class Event:
126 | 
127 |     def __init__(self, type, time, value):
128 |         self.type = type
129 |         self.time = time
130 |         self.value = value
131 | 
132 |     def __repr__(self):
133 |         return 'Event(type={}, time={}, value={})'.format(
134 |             self.type, self.time, self.value)
135 | 
136 | 
137 | class EventSeq:
138 | 
139 |     pitch_range = DEFAULT_PITCH_RANGE
140 |     velocity_range = DEFAULT_VELOCITY_RANGE
141 |     velocity_steps = DEFAULT_VELOCITY_STEPS
142 |     time_shift_bins = DEFAULT_TIME_SHIFT_BINS
143 | 
144 |     @staticmethod
145 |     def from_note_seq(note_seq):
146 |         note_events = []
147 | 
148 |         if USE_VELOCITY:
149 |             velocity_bins = EventSeq.get_velocity_bins()
150 | 
151 |         for note in note_seq.notes:
152 |             if note.pitch in EventSeq.pitch_range:
153 |                 if USE_VELOCITY:
154 |                     velocity = note.velocity
155 |                     velocity = max(velocity, EventSeq.velocity_range.start)
156 |                     velocity = min(velocity, EventSeq.velocity_range.stop - 1)
157 |                     velocity_index = np.searchsorted(velocity_bins, velocity)
158 |                     note_events.append(Event('velocity', note.start, velocity_index))
159 | 
160 |                 pitch_index = note.pitch - EventSeq.pitch_range.start
161 |                 note_events.append(Event('note_on', note.start, pitch_index))
162 |                 note_events.append(Event('note_off', note.end, pitch_index))
163 | 
164 |         note_events.sort(key=lambda event: event.time)  # stable
165 |         events = []
166 | 
167 |         for i, event in enumerate(note_events):
168 |             events.append(event)
169 | 
170 |             if event is note_events[-1]:
171 |                 break
172 | 
173 |             interval = note_events[i + 1].time - event.time
174 |             shift = 0
175 | 
176 |             while interval - shift >= EventSeq.time_shift_bins[0]:
177 |                 index = np.searchsorted(EventSeq.time_shift_bins,
178 |                                         interval - shift, side='right') - 1
179 |                 events.append(Event('time_shift', event.time + shift, index))
180 |                 shift += EventSeq.time_shift_bins[index]
181 | 
182 |         return EventSeq(events)
183 | 
184 |     @staticmethod
185 |     def from_array(event_indeces):
186 |         time = 0
187 |         events = []
188 |         for event_index in event_indeces:
189 |             for event_type, feat_range in EventSeq.feat_ranges().items():
190 |                 if feat_range.start <= event_index < feat_range.stop:
191 |                     event_value = event_index - feat_range.start
192 |                     events.append(Event(event_type, time, event_value))
193 |                     if event_type == 'time_shift':
194 |                         time += EventSeq.time_shift_bins[event_value]
195 |                     break
196 | 
197 |         return EventSeq(events)
198 | 
199 |     @staticmethod
200 |     def dim():
201 |         return sum(EventSeq.feat_dims().values())
202 | 
203 |     @staticmethod
204 |     def feat_dims():
205 | 
feat_dims = collections.OrderedDict() 206 | feat_dims['note_on'] = len(EventSeq.pitch_range) 207 | feat_dims['note_off'] = len(EventSeq.pitch_range) 208 | if USE_VELOCITY: 209 | feat_dims['velocity'] = EventSeq.velocity_steps 210 | feat_dims['time_shift'] = len(EventSeq.time_shift_bins) 211 | return feat_dims 212 | 213 | @staticmethod 214 | def feat_ranges(): 215 | offset = 0 216 | feat_ranges = collections.OrderedDict() 217 | for feat_name, feat_dim in EventSeq.feat_dims().items(): 218 | feat_ranges[feat_name] = range(offset, offset + feat_dim) 219 | offset += feat_dim 220 | return feat_ranges 221 | 222 | @staticmethod 223 | def get_velocity_bins(): 224 | n = EventSeq.velocity_range.stop - EventSeq.velocity_range.start 225 | return np.arange(EventSeq.velocity_range.start, 226 | EventSeq.velocity_range.stop, 227 | n / (EventSeq.velocity_steps - 1)) 228 | 229 | def __init__(self, events=[]): 230 | for event in events: 231 | assert isinstance(event, Event) 232 | 233 | self.events = copy.deepcopy(events) 234 | 235 | # compute event times again 236 | time = 0 237 | for event in self.events: 238 | event.time = time 239 | if event.type == 'time_shift': 240 | time += EventSeq.time_shift_bins[event.value] 241 | 242 | def to_note_seq(self): 243 | time = 0 244 | notes = [] 245 | 246 | velocity = DEFAULT_VELOCITY 247 | velocity_bins = EventSeq.get_velocity_bins() 248 | 249 | last_notes = {} 250 | 251 | for event in self.events: 252 | if event.type == 'note_on': 253 | pitch = event.value + EventSeq.pitch_range.start 254 | note = Note(velocity, pitch, time, None) 255 | notes.append(note) 256 | last_notes[pitch] = note 257 | 258 | elif event.type == 'note_off': 259 | pitch = event.value + EventSeq.pitch_range.start 260 | 261 | if pitch in last_notes: 262 | note = last_notes[pitch] 263 | note.end = max(time, note.start + MIN_NOTE_LENGTH) 264 | del last_notes[pitch] 265 | 266 | elif event.type == 'velocity': 267 | index = min(event.value, velocity_bins.size - 1) 268 | velocity = velocity_bins[index] 269 | 270 | elif event.type == 'time_shift': 271 | time += EventSeq.time_shift_bins[event.value] 272 | 273 | for note in notes: 274 | if note.end is None: 275 | note.end = note.start + DEFAULT_NOTE_LENGTH 276 | 277 | note.velocity = int(note.velocity) 278 | 279 | return NoteSeq(notes) 280 | 281 | def to_array(self): 282 | feat_idxs = EventSeq.feat_ranges() 283 | idxs = [feat_idxs[event.type][event.value] for event in self.events] 284 | dtype = np.uint8 if EventSeq.dim() <= 256 else np.uint16 285 | return np.array(idxs, dtype=dtype) 286 | 287 | 288 | # ================================================================================== 289 | # Controls 290 | # ================================================================================== 291 | 292 | class Control: 293 | 294 | def __init__(self, pitch_histogram, note_density): 295 | self.pitch_histogram = pitch_histogram # list 296 | self.note_density = note_density # int 297 | 298 | def __repr__(self): 299 | return 'Control(pitch_histogram={}, note_density={})'.format( 300 | self.pitch_histogram, self.note_density) 301 | 302 | def to_array(self): 303 | feat_dims = ControlSeq.feat_dims() 304 | ndens = np.zeros([feat_dims['note_density']]) 305 | ndens[self.note_density] = 1. 
# [dens_dim] 306 | phist = np.array(self.pitch_histogram) # [hist_dim] 307 | return np.concatenate([ndens, phist], 0) # [dens_dim + hist_dim] 308 | 309 | 310 | class ControlSeq: 311 | 312 | note_density_bins = DEFAULT_NOTE_DENSITY_BINS 313 | window_size = DEFAULT_WINDOW_SIZE 314 | 315 | @staticmethod 316 | def from_event_seq(event_seq): 317 | events = list(event_seq.events) 318 | start, end = 0, 0 319 | 320 | pitch_count = np.zeros([12]) 321 | note_count = 0 322 | 323 | controls = [] 324 | 325 | def _rel_pitch(pitch): 326 | return (pitch - 24) % 12 327 | 328 | for i, event in enumerate(events): 329 | 330 | while start < i: 331 | if events[start].type == 'note_on': 332 | abs_pitch = events[start].value + EventSeq.pitch_range.start 333 | rel_pitch = _rel_pitch(abs_pitch) 334 | pitch_count[rel_pitch] -= 1. 335 | note_count -= 1. 336 | start += 1 337 | 338 | while end < len(events): 339 | if events[end].time - event.time > ControlSeq.window_size: 340 | break 341 | if events[end].type == 'note_on': 342 | abs_pitch = events[end].value + EventSeq.pitch_range.start 343 | rel_pitch = _rel_pitch(abs_pitch) 344 | pitch_count[rel_pitch] += 1. 345 | note_count += 1. 346 | end += 1 347 | 348 | pitch_histogram = ( 349 | pitch_count / note_count 350 | if note_count 351 | else np.ones([12]) / 12 352 | ).tolist() 353 | 354 | note_density = max(np.searchsorted( 355 | ControlSeq.note_density_bins, 356 | note_count, side='right') - 1, 0) 357 | 358 | controls.append(Control(pitch_histogram, note_density)) 359 | 360 | return ControlSeq(controls) 361 | 362 | @staticmethod 363 | def dim(): 364 | return sum(ControlSeq.feat_dims().values()) 365 | 366 | @staticmethod 367 | def feat_dims(): 368 | note_density_dim = len(ControlSeq.note_density_bins) 369 | return collections.OrderedDict([ 370 | ('pitch_histogram', 12), 371 | ('note_density', note_density_dim) 372 | ]) 373 | 374 | @staticmethod 375 | def feat_ranges(): 376 | offset = 0 377 | feat_ranges = collections.OrderedDict() 378 | for feat_name, feat_dim in ControlSeq.feat_dims().items(): 379 | feat_ranges[feat_name] = range(offset, offset + feat_dim) 380 | offset += feat_dim 381 | return feat_ranges 382 | 383 | @staticmethod 384 | def recover_compressed_array(array): 385 | feat_dims = ControlSeq.feat_dims() 386 | assert array.shape[1] == 1 + feat_dims['pitch_histogram'] 387 | ndens = np.zeros([array.shape[0], feat_dims['note_density']]) 388 | ndens[np.arange(array.shape[0]), array[:, 0]] = 1. 
# [steps, dens_dim]
389 |         phist = array[:, 1:].astype(np.float64) / 255  # [steps, hist_dim]
390 |         return np.concatenate([ndens, phist], 1)  # [steps, dens_dim + hist_dim]
391 | 
392 |     def __init__(self, controls):
393 |         for control in controls:
394 |             assert isinstance(control, Control)
395 |         self.controls = copy.deepcopy(controls)
396 | 
397 |     def to_compressed_array(self):
398 |         ndens = [control.note_density for control in self.controls]
399 |         ndens = np.array(ndens, dtype=np.uint8).reshape(-1, 1)
400 |         phist = [control.pitch_histogram for control in self.controls]
401 |         phist = (np.array(phist) * 255).astype(np.uint8)
402 |         return np.concatenate([
403 |             ndens,  # [steps, 1] density index
404 |             phist   # [steps, hist_dim] 0-255
405 |         ], 1)  # [steps, hist_dim + 1]
406 | 
407 | 
408 | if __name__ == '__main__':
409 |     import pickle
410 |     import sys
411 |     path = sys.argv[1] if len(sys.argv) > 1 else 'dataset/midi/ecomp/BLINOV02.mid'
412 | 
413 |     print('Converting MIDI to EventSeq')
414 |     es = EventSeq.from_note_seq(NoteSeq.from_midi_file(path))
415 | 
416 |     print('Converting EventSeq to MIDI')
417 |     EventSeq.from_array(es.to_array()).to_note_seq().to_midi_file('/tmp/test.mid')
418 | 
419 |     print('Converting EventSeq to ControlSeq')
420 |     cs = ControlSeq.from_event_seq(es)
421 | 
422 |     print('Saving compressed ControlSeq')
423 |     with open('/tmp/cs-compressed.data', 'wb') as f:
424 |         pickle.dump(cs.to_compressed_array(), f)
425 | 
426 |     print('Loading compressed ControlSeq')
427 |     with open('/tmp/cs-compressed.data', 'rb') as f:
428 |         cs_recovered = ControlSeq.recover_compressed_array(pickle.load(f))
429 | 
430 |     print('Done')
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import optim
4 | 
5 | 
6 | import numpy as np
7 | 
8 | import os
9 | import sys
10 | import time
11 | import optparse
12 | 
13 | 
14 | import utils
15 | import config
16 | from data import Dataset
17 | from model import PerformanceRNN
18 | from sequence import NoteSeq, EventSeq, ControlSeq
19 | 
20 | # pylint: disable=E1102
21 | # pylint: disable=E1101
22 | 
23 | #========================================================================
24 | # Settings
25 | #========================================================================
26 | 
27 | def get_options():
28 |     parser = optparse.OptionParser()
29 | 
30 |     parser.add_option('-s', '--session',
31 |                       dest='sess_path',
32 |                       type='string',
33 |                       default='save/train.sess')
34 | 
35 |     parser.add_option('-d', '--dataset',
36 |                       dest='data_path',
37 |                       type='string',
38 |                       default='dataset/processed/')
39 | 
40 |     parser.add_option('-i', '--saving-interval',
41 |                       dest='saving_interval',
42 |                       type='float',
43 |                       default=60.)
44 | 45 | parser.add_option('-b', '--batch-size', 46 | dest='batch_size', 47 | type='int', 48 | default=config.train['batch_size']) 49 | 50 | parser.add_option('-l', '--learning-rate', 51 | dest='learning_rate', 52 | type='float', 53 | default=config.train['learning_rate']) 54 | 55 | parser.add_option('-w', '--window-size', 56 | dest='window_size', 57 | type='int', 58 | default=config.train['window_size']) 59 | 60 | parser.add_option('-S', '--stride-size', 61 | dest='stride_size', 62 | type='int', 63 | default=config.train['stride_size']) 64 | 65 | parser.add_option('-c', '--control-ratio', 66 | dest='control_ratio', 67 | type='float', 68 | default=config.train['control_ratio']) 69 | 70 | parser.add_option('-T', '--teacher-forcing-ratio', 71 | dest='teacher_forcing_ratio', 72 | type='float', 73 | default=config.train['teacher_forcing_ratio']) 74 | 75 | parser.add_option('-t', '--use-transposition', 76 | dest='use_transposition', 77 | action='store_true', 78 | default=config.train['use_transposition']) 79 | 80 | parser.add_option('-p', '--model-params', 81 | dest='model_params', 82 | type='string', 83 | default='') 84 | 85 | parser.add_option('-r', '--reset-optimizer', 86 | dest='reset_optimizer', 87 | action='store_true', 88 | default=False) 89 | 90 | parser.add_option('-L', '--enable-logging', 91 | dest='enable_logging', 92 | action='store_true', 93 | default=False) 94 | 95 | return parser.parse_args()[0] 96 | 97 | options = get_options() 98 | 99 | #------------------------------------------------------------------------ 100 | 101 | sess_path = options.sess_path 102 | data_path = options.data_path 103 | saving_interval = options.saving_interval 104 | 105 | learning_rate = options.learning_rate 106 | batch_size = options.batch_size 107 | window_size = options.window_size 108 | stride_size = options.stride_size 109 | use_transposition = options.use_transposition 110 | control_ratio = options.control_ratio 111 | teacher_forcing_ratio = options.teacher_forcing_ratio 112 | reset_optimizer = options.reset_optimizer 113 | enable_logging = options.enable_logging 114 | 115 | event_dim = EventSeq.dim() 116 | control_dim = ControlSeq.dim() 117 | model_config = config.model 118 | model_params = utils.params2dict(options.model_params) 119 | model_config.update(model_params) 120 | device = config.device 121 | 122 | print('-' * 70) 123 | 124 | print('Session path:', sess_path) 125 | print('Dataset path:', data_path) 126 | print('Saving interval:', saving_interval) 127 | print('-' * 70) 128 | 129 | print('Hyperparameters:', utils.dict2params(model_config)) 130 | print('Learning rate:', learning_rate) 131 | print('Batch size:', batch_size) 132 | print('Window size:', window_size) 133 | print('Stride size:', stride_size) 134 | print('Control ratio:', control_ratio) 135 | print('Teacher forcing ratio:', teacher_forcing_ratio) 136 | print('Random transposition:', use_transposition) 137 | print('Reset optimizer:', reset_optimizer) 138 | print('Enabling logging:', enable_logging) 139 | print('Device:', device) 140 | print('-' * 70) 141 | 142 | 143 | #======================================================================== 144 | # Load session and dataset 145 | #======================================================================== 146 | 147 | def load_session(): 148 | global sess_path, model_config, device, learning_rate, reset_optimizer 149 | try: 150 | sess = torch.load(sess_path) 151 | if 'model_config' in sess and sess['model_config'] != model_config: 152 | model_config = sess['model_config'] 153 | 
            print('Using session config instead:')
154 |             print(utils.dict2params(model_config))
155 |         model_state = sess['model_state']
156 |         optimizer_state = sess['model_optimizer_state']
157 |         print('Session is loaded from', sess_path)
158 |         sess_loaded = True
159 |     except FileNotFoundError:
160 |         print('New session')
161 |         sess_loaded = False
162 |     model = PerformanceRNN(**model_config).to(device)
163 |     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
164 |     if sess_loaded:
165 |         model.load_state_dict(model_state)
166 |         if not reset_optimizer:
167 |             optimizer.load_state_dict(optimizer_state)
168 |     return model, optimizer
169 | 
170 | def load_dataset():
171 |     global data_path
172 |     dataset = Dataset(data_path, verbose=True)
173 |     dataset_size = len(dataset.samples)
174 |     assert dataset_size > 0
175 |     return dataset
176 | 
177 | 
178 | print('Loading session')
179 | model, optimizer = load_session()
180 | print(model)
181 | 
182 | print('-' * 70)
183 | 
184 | print('Loading dataset')
185 | dataset = load_dataset()
186 | print(dataset)
187 | 
188 | print('-' * 70)
189 | 
190 | #------------------------------------------------------------------------
191 | 
192 | def save_model():
193 |     global model, optimizer, model_config, sess_path
194 |     print('Saving to', sess_path)
195 |     torch.save({'model_config': model_config,
196 |                 'model_state': model.state_dict(),
197 |                 'model_optimizer_state': optimizer.state_dict()}, sess_path)
198 |     print('Done saving')
199 | 
200 | 
201 | #========================================================================
202 | # Training
203 | #========================================================================
204 | 
205 | if enable_logging:
206 |     from torch.utils.tensorboard import SummaryWriter
207 |     writer = SummaryWriter()
208 | 
209 | last_saving_time = time.time()
210 | loss_function = nn.CrossEntropyLoss()
211 | 
212 | try:
213 |     batch_gen = dataset.batches(batch_size, window_size, stride_size)
214 | 
215 |     for iteration, (events, controls) in enumerate(batch_gen):
216 |         if use_transposition:
217 |             offset = np.random.choice(np.arange(-6, 6))
218 |             events, controls = utils.transposition(events, controls, offset)
219 | 
220 |         events = torch.LongTensor(events).to(device)
221 |         assert events.shape[0] == window_size
222 | 
223 |         if np.random.random() < control_ratio:
224 |             controls = torch.FloatTensor(controls).to(device)
225 |             assert controls.shape[0] == window_size
226 |         else:
227 |             controls = None
228 | 
229 |         init = torch.randn(batch_size, model.init_dim).to(device)
230 |         outputs = model.generate(init, window_size, events=events[:-1], controls=controls,
231 |                                  teacher_forcing_ratio=teacher_forcing_ratio, output_type='logit')
232 |         assert outputs.shape[:2] == events.shape[:2]
233 | 
234 |         loss = loss_function(outputs.view(-1, event_dim), events.view(-1))
235 |         model.zero_grad()
236 |         loss.backward()
237 | 
238 |         norm = utils.compute_gradient_norm(model.parameters())
239 |         nn.utils.clip_grad_norm_(model.parameters(), 1.0)
240 | 
241 |         optimizer.step()
242 | 
243 |         if enable_logging:
244 |             writer.add_scalar('model/loss', loss.item(), iteration)
245 |             writer.add_scalar('model/norm', norm.item(), iteration)
246 | 
247 |         print(f'iter {iteration}, loss: {loss.item()}')
248 | 
249 |         if time.time() - last_saving_time > saving_interval:
250 |             save_model()
251 |             last_saving_time = time.time()
252 | 
253 | except KeyboardInterrupt:
254 |     save_model()
255 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ast
3 | import numpy as np
4 | from sequence import EventSeq, ControlSeq
5 | 
6 | 
7 | def find_files_by_extensions(root, exts=()):
8 |     def _has_ext(name):
9 |         if not exts:
10 |             return True
11 |         name = name.lower()
12 |         for ext in exts:
13 |             if name.endswith(ext):
14 |                 return True
15 |         return False
16 |     for path, _, files in os.walk(root):
17 |         for name in files:
18 |             if _has_ext(name):
19 |                 yield os.path.join(path, name)
20 | 
21 | def event_indeces_to_midi_file(event_indeces, midi_file_name, velocity_scale=0.8):
22 |     event_seq = EventSeq.from_array(event_indeces)
23 |     note_seq = event_seq.to_note_seq()
24 |     for note in note_seq.notes:
25 |         note.velocity = int((note.velocity - 64) * velocity_scale + 64)
26 |     note_seq.to_midi_file(midi_file_name)
27 |     return len(note_seq.notes)
28 | 
29 | def transposition(events, controls, offset=0):
30 |     # events [steps, batch_size, event_dim]
31 |     # return events, controls
32 | 
33 |     events = np.array(events, dtype=np.int64)
34 |     controls = np.array(controls, dtype=np.float32)
35 |     event_feat_ranges = EventSeq.feat_ranges()
36 | 
37 |     on = event_feat_ranges['note_on']
38 |     off = event_feat_ranges['note_off']
39 | 
40 |     if offset > 0:
41 |         # indices that can shift up by `offset` without leaving their range
42 |         indeces0 = (((on.start <= events) & (events < on.stop - offset)) |
43 |                     ((off.start <= events) & (events < off.stop - offset)))
44 |         # indices that would overflow: wrap them down an octave
45 |         indeces1 = (((on.stop - offset <= events) & (events < on.stop)) |
46 |                     ((off.stop - offset <= events) & (events < off.stop)))
47 |         events[indeces0] += offset
48 |         events[indeces1] += offset - 12
49 |     elif offset < 0:
50 |         indeces0 = (((on.start - offset <= events) & (events < on.stop)) |
51 |                     ((off.start - offset <= events) & (events < off.stop)))
52 |         indeces1 = (((on.start <= events) & (events < on.start - offset)) |
53 |                     ((off.start <= events) & (events < off.start - offset)))
54 |         events[indeces0] += offset
55 |         events[indeces1] += offset + 12
56 | 
57 |     assert ((0 <= events) & (events < EventSeq.dim())).all()
58 |     histr = ControlSeq.feat_ranges()['pitch_histogram']
59 |     controls[:, :, histr.start:histr.stop] = np.roll(
60 |         controls[:, :, histr.start:histr.stop], offset, -1)
61 | 
62 |     return events, controls
63 | 
64 | def dict2params(d, f=','):
65 |     return f.join(f'{k}={v}' for k, v in d.items())
66 | 
67 | def params2dict(p, f=',', e='='):
68 |     d = {}
69 |     for item in p.split(f):
70 |         item = item.split(e)
71 |         if len(item) < 2:
72 |             continue
73 |         k, *v = item
74 |         d[k] = ast.literal_eval(e.join(v))  # safer than eval for literal values
75 |     return d
76 | 
77 | def compute_gradient_norm(parameters, norm_type=2):
78 |     total_norm = 0
79 |     for p in parameters:
80 |         if p.grad is None:
81 |             continue
82 |         param_norm = p.grad.data.norm(norm_type)
83 |         total_norm += param_norm ** norm_type
84 |     total_norm = total_norm ** (1. / norm_type)
85 |     return total_norm
86 | 
--------------------------------------------------------------------------------
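
For orientation, here is a minimal sketch (not part of the repository) of how the preprocessed data format fits together: `preprocess.py` saves each MIDI file as a tuple of an event-index array and a compressed control array, which can be expanded with `ControlSeq.recover_compressed_array` and decoded back to MIDI with `utils.event_indeces_to_midi_file`. The `.data` filename below is hypothetical; substitute a file produced by your own preprocessing run.

```python
# Sketch: inspect one preprocessed sample and round-trip it back to MIDI.
# Assumes preprocess.py has been run; the exact filename is hypothetical.
import torch
from sequence import ControlSeq
import utils

# preprocess.py stores (EventSeq.to_array(), ControlSeq.to_compressed_array())
events, compressed_controls = torch.load(
    'dataset/processed/MYDATA/example.mid-0123abcd.data')  # hypothetical path

print('event indices:', events.shape, events.dtype)  # [steps], uint8 or uint16

# expand compressed controls back to one-hot note density + pitch histogram
controls = ControlSeq.recover_compressed_array(compressed_controls)
print('controls:', controls.shape)  # [steps, density_dim + 12]

# decode the event indices back into a playable MIDI file
n_notes = utils.event_indeces_to_midi_file(events, '/tmp/roundtrip.mid')
print(n_notes, 'notes written')
```

Note that the round trip is lossy by design: pitch histograms are quantized to 0-255 in the compressed array, and note timings are snapped to the `time_shift_bins` grid.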