├── .gitignore
├── LICENSE
├── README.md
├── adversarial.py
├── config.py
├── data.py
├── dataset
│   ├── midi
│   │   └── .keep
│   ├── processed
│   │   └── .keep
│   └── scripts
│       ├── .keep
│       ├── classic_piano_downloader.sh
│       ├── ecomp_piano_downloader.sh
│       ├── midiworld_downloader.sh
│       └── touhou_downloader.sh
├── generate.py
├── model.py
├── output
│   └── .keep
├── play.py
├── preprocess.py
├── runs
│   └── .keep
├── save
│   └── .keep
├── sequence.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | *.pyc
3 | .DS_Store
4 | *.mid
5 | *.MID
6 | *.sess
7 | *.data
8 | runs/
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Yuankui Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # Performance RNN - PyTorch
4 |
5 | [MIT License](https://opensource.org/licenses/MIT)
6 |
7 | PyTorch implementation of Performance RNN, inspired by *Ian Simon and Sageev Oore. "Performance RNN: Generating Music with Expressive
8 | Timing and Dynamics." Magenta Blog, 2017.*
9 | [https://magenta.tensorflow.org/performance-rnn](https://magenta.tensorflow.org/performance-rnn).
10 |
11 | Note: this is not the official implementation, and the model deviates from the one described by Magenta.
12 |
13 | ## Generated Samples
14 |
15 | - A sample on C Major Scale [[MIDI](https://drive.google.com/open?id=1mZtkpsu1yA8oOkE_1b2jyFsvCW70FiKU), [MP3](https://drive.google.com/open?id=1UqyJ9e58AOimFeY1xoCPyedTz-g2fUxv)]
16 | - control option: `-c '1,0,1,0,1,1,0,1,0,1,0,1;4'`
17 | - A sample on C Minor Scale [[MIDI](https://drive.google.com/open?id=1lIVCIT7INuTa-HKrgPzewrgCbgwCRRa1), [MP3](https://drive.google.com/open?id=1pVg3Mg2pSq8VHJRJrgNUZybpsErjzpjF)]
18 | - control option: `-c '1,0,1,1,0,1,0,1,1,0,0,1;4'`
19 | - A sample on C Major Pentatonic Scale [[MIDI](https://drive.google.com/open?id=16uRwyntgYTzSmaxhp06kUbThDm8W_vVE), [MP3](https://drive.google.com/open?id=1LSbeVqXKAPrNPCPcjy6FVwUuVo7FxYji)]
20 | - control option: `-c '5,0,4,0,4,1,0,5,0,4,0,1;3'`
21 | - A sample on C Minor Pentatonic Scale [[MIDI](https://drive.google.com/open?id=1zeMHNu37U6byhT-s63EIro8nL6VkUi8u), [MP3](https://drive.google.com/open?id=1asP1z6u1n3PRSysSnvkt-SabpTgT-_x5)]
22 | - control option: `-c '5,0,1,4,0,4,0,5,1,0,4,0;3'`
23 |
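Each control option has the form `PITCH_HISTOGRAM;NOTE_DENSITY`: twelve comma-separated weights for the pitch classes C, C#, ..., B (normalized to sum to 1), followed by an index into `ControlSeq.note_density_bins`. A minimal parsing sketch, adapted from the actual handling in `generate.py` (which additionally treats an empty histogram as uniform):

```python
import numpy as np

option = '1,0,1,0,1,1,0,1,0,1,0,1;4'     # C major scale, density bin 4
histogram, density = option.split(';')
histogram = np.array(list(map(float, histogram.split(','))))
histogram = histogram / histogram.sum()  # pitch-class weights, sum to 1
density = int(density)                   # index into note_density_bins
```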
24 | ## Directory Structure
25 |
26 | ```
27 | .
28 | ├── dataset/
29 | │   ├── midi/
30 | │   │   ├── dataset1/
31 | │   │   │   └── *.mid
32 | │   │   └── dataset2/
33 | │   │       └── *.mid
34 | │   ├── processed/
35 | │   │   └── dataset1/
36 | │   │       └── *.data (preprocess.py)
37 | │   └── scripts/
38 | │       └── *.sh (dataset download scripts)
39 | ├── output/
40 | │   └── *.mid (generate.py)
41 | ├── save/
42 | │   └── *.sess (train.py)
43 | └── runs/ (tensorboard logdir)
44 | ```
45 |
46 | ## Instructions
47 |
48 | - Download datasets
49 |
50 | ```shell
51 | cd dataset/
52 | bash scripts/NAME_downloader.sh midi/NAME
53 | ```
54 |
55 | - Preprocessing
56 |
57 | ```shell
58 | # Preprocess all MIDI files under dataset/midi/NAME (third argument: number of worker processes)
59 | python3 preprocess.py dataset/midi/NAME dataset/processed/NAME NUM_WORKERS
60 | ```
61 |
62 | - Training
63 |
64 | ```shell
65 | # Train on .data files in dataset/processed/MYDATA, and save to save/myModel.sess every 10s
66 | python3 train.py -s save/myModel.sess -d dataset/processed/MYDATA -i 10
67 |
68 | # Or...
69 | python3 train.py -s save/myModel.sess -d dataset/processed/MYDATA -p hidden_dim=1024
70 | python3 train.py -s save/myModel.sess -d dataset/processed/MYDATA -b 128 -c 0.3
71 | python3 train.py -s save/myModel.sess -d dataset/processed/MYDATA -w 100 -S 10
72 | ```
73 |
74 | 
75 |
76 |
77 | - Generating
78 |
79 | ```shell
80 | # Generate with control sequence from test.data and model from save/test.sess
81 | python3 generate.py -s save/test.sess -c test.data
82 |
83 | # Generate with pitch histogram and note density (C major scale)
84 | python3 generate.py -s save/test.sess -l 1000 -c '1,0,1,0,1,1,0,1,0,1,0,1;3'
85 |
86 | # Or...
87 | python3 generate.py -s save/test.sess -l 1000 -c ';3' # uniform pitch histogram
88 | python3 generate.py -s save/test.sess -l 1000 # no control
89 |
90 | # Use control sequence from processed data
91 | python3 generate.py -s save/test.sess -c dataset/processed/some/processed.data
92 | ```
93 |
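When `-c` points to a `.data` file and `-l` is omitted, `generate.py` uses the length of the loaded control sequence as the maximum length; either a max length or a control sequence must be given.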
94 | 
95 |
96 | 
97 |
98 | ## Pretrained Model
99 |
100 | - [ecomp.sess](https://drive.google.com/open?id=1daT6XRQUTS6AQ5jyRPqzowXia-zVqg6m)
101 | - default configuration
102 | - dataset: [International Piano-e-Competition, recorded MIDI files](http://www.piano-e-competition.com/)
103 | - [ecomp_w500.sess](https://drive.google.com/open?id=1jf5j2cWppXVeSXhTuiNfAFEyWFIaNZ6f)
104 | - window_size: 500
105 | - control_ratio: 0.7
106 | - dataset: [International Piano-e-Competition, recorded MIDI files](http://www.piano-e-competition.com/)
107 |
108 | ## Requirements
109 |
110 | - pretty_midi
111 | - numpy
112 | - pytorch >= 0.4
113 | - tensorboardX
114 | - progress
115 |
--------------------------------------------------------------------------------
/adversarial.py:
--------------------------------------------------------------------------------
1 | # Adversarial learning for event-based music generation with SeqGAN
2 | # Reference:
3 | # "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient."
4 | # (Yu, Lantao, et al.).
5 | # ... Honestly, it's too hard to train ;(
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | from torch.distributions import Categorical
12 |
13 | import numpy as np
14 | import os, sys, time, argparse
15 | from progress.bar import Bar
16 |
17 | import config, utils
18 | from config import device
19 | from data import Dataset
20 | from model import PerformanceRNN
21 | from sequence import EventSeq, ControlSeq
22 |
23 | # pylint: disable=E1101
24 |
25 |
26 | #========================================================================
27 | # Discriminator
28 | #========================================================================
29 |
30 | discriminator_config = {
31 | 'event_dim': EventSeq.dim(),
32 | 'hidden_dim': 512,
33 | 'gru_layers': 3,
34 | 'gru_dropout': 0.3
35 | }
36 |
37 | class EventSequenceEncoder(nn.Module):
38 | def __init__(self, event_dim=EventSeq.dim(), hidden_dim=512,
39 | gru_layers=3, gru_dropout=0.3):
40 | super().__init__()
41 | self.event_embedding = nn.Embedding(event_dim, hidden_dim)
42 | self.gru = nn.GRU(hidden_dim, hidden_dim,
43 | num_layers=gru_layers, dropout=gru_dropout)
44 | self.attn = nn.Parameter(torch.randn(hidden_dim), requires_grad=True)
45 | self.output_fc = nn.Linear(hidden_dim, 1)
46 | self.output_fc_activation = nn.Sigmoid()
47 |
48 | def forward(self, events, hidden=None, output_logits=False):
49 | # events: [steps, batch_size]
50 | events = self.event_embedding(events)
51 | outputs, _ = self.gru(events, hidden) # [t, b, h]
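# Attention-style pooling: score each timestep by its dot product with the
# learned query vector `attn`, then average the score-weighted GRU outputs
# over time (note the scores are not softmax-normalized here).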
52 | weights = (outputs * self.attn).sum(-1, keepdim=True)
53 | output = (outputs * weights).mean(0) # [b, h]
54 | output = self.output_fc(output).squeeze(-1) # [b]
55 | if output_logits:
56 | return output
57 | output = self.output_fc_activation(output)
58 | return output
59 |
60 |
61 | #========================================================================
62 | # Pretrain Discriminator
63 | #========================================================================
64 |
65 | def pretrain_discriminator(model_sess_path, # load
66 | discriminator_sess_path, # load + save
67 | batch_data_generator, # Dataset(...).batches(...)
68 | discriminator_config_overwrite={},
69 | gradient_clipping=False,
70 | control_ratio=1.0,
71 | num_iter=-1,
72 | save_interval=60.0,
73 | discriminator_lr=0.001,
74 | enable_logging=False,
75 | auto_sample_factor=False,
76 | sample_factor=1.0):
77 |
78 | print('-' * 70)
79 | print('model_sess_path:', model_sess_path)
80 | print('discriminator_sess_path:', discriminator_sess_path)
81 | print('discriminator_config_overwrite:', discriminator_config_overwrite)
82 | print('sample_factor:', sample_factor)
83 | print('auto_sample_factor:', auto_sample_factor)
84 | print('discriminator_lr:', discriminator_lr)
85 | print('gradient_clipping:', gradient_clipping)
86 | print('control_ratio:', control_ratio)
87 | print('num_iter:', num_iter)
88 | print('save_interval:', save_interval)
89 | print('enable_logging:', enable_logging)
90 | print('-' * 70)
91 |
92 | # Load generator
93 | model_sess = torch.load(model_sess_path)
94 | model_config = model_sess['model_config']
95 | model = PerformanceRNN(**model_config).to(device)
96 | model.load_state_dict(model_sess['model_state'])
97 |
98 | print(f'Generator from "{model_sess_path}"')
99 | print(model)
100 | print('-' * 70)
101 |
102 | # Load discriminator and optimizer
103 | global discriminator_config
104 | try:
105 | discriminator_sess = torch.load(discriminator_sess_path)
106 | discriminator_config = discriminator_sess['discriminator_config']
107 | discriminator_state = discriminator_sess['discriminator_state']
108 | discriminator_optimizer_state = discriminator_sess['discriminator_optimizer_state']
109 | print(f'Discriminator from "{discriminator_sess_path}"')
110 | discriminator_loaded = True
111 | except FileNotFoundError:
112 | print(f'New discriminator session at "{discriminator_sess_path}"')
113 | discriminator_config.update(discriminator_config_overwrite)
114 | discriminator_loaded = False
115 |
116 | discriminator = EventSequenceEncoder(**discriminator_config).to(device)
117 | optimizer = optim.Adam(discriminator.parameters(), lr=discriminator_lr)
118 | if discriminator_loaded:
119 | discriminator.load_state_dict(discriminator_state)
120 | optimizer.load_state_dict(discriminator_optimizer_state)
121 |
122 | print(discriminator)
123 | print(optimizer)
124 | print('-' * 70)
125 |
126 | def save_discriminator():
127 | print(f'Saving to "{discriminator_sess_path}"')
128 | torch.save({
129 | 'discriminator_config': discriminator_config,
130 | 'discriminator_state': discriminator.state_dict(),
131 | 'discriminator_optimizer_state': optimizer.state_dict()
132 | }, discriminator_sess_path)
133 | print('Done saving')
134 |
135 | # Disable gradient for generator
136 | for parameter in model.parameters():
137 | parameter.requires_grad_(False)
138 |
139 | model.eval()
140 | discriminator.train()
141 |
142 | loss_func = nn.BCEWithLogitsLoss()
143 | last_save_time = time.time()
144 |
145 | if enable_logging:
146 | from tensorboardX import SummaryWriter
147 | writer = SummaryWriter()
148 |
149 | try:
150 | for i, (events, controls) in enumerate(batch_data_generator):
151 | if i == num_iter:
152 | break
153 |
154 | steps, batch_size = events.shape
155 |
156 | # Prepare inputs
157 | events = torch.LongTensor(events).to(device)
158 | if np.random.random() <= control_ratio:
159 | controls = torch.FloatTensor(controls).to(device)
160 | else:
161 | controls = None
162 |
163 | init = torch.randn(batch_size, model.init_dim).to(device)
164 |
165 | # Predict for real event sequence
166 | real_events = events
167 | real_logit = discriminator(real_events, output_logits=True)
168 | real_target = torch.ones_like(real_logit).to(device)
169 |
170 | if auto_sample_factor:
171 | sample_factor = np.random.choice([
172 | 0.1, 0.4, 0.6, 0.7, 0.8, 0.9, 1.0,
173 | 1.1, 1.2, 1.4, 1.6, 2.0, 4.0, 10.0])
174 |
175 | # Predict for fake event sequence from the generator
176 | fake_events = model.generate(init, steps, None, controls,
177 | greedy=0, output_type='index',
178 | temperature=sample_factor)
179 | fake_logit = discriminator(fake_events, output_logits=True)
180 | fake_target = torch.zeros_like(fake_logit).to(device)
181 |
182 | # Compute loss
183 | loss = (loss_func(real_logit, real_target) +
184 | loss_func(fake_logit, fake_target)) / 2
185 |
186 | # Backprop
187 | discriminator.zero_grad()
188 | loss.backward()
189 |
190 | # Gradient clipping
191 | norm = utils.compute_gradient_norm(discriminator.parameters())
192 | if gradient_clipping:
193 | nn.utils.clip_grad_norm_(discriminator.parameters(), gradient_clipping)
194 |
195 | optimizer.step()
196 |
197 | # Logging
198 | loss = loss.item()
199 | norm = norm.item()
200 | print(f'{i} loss: {loss}, norm: {norm}, sf: {sample_factor}')
201 | if enable_logging:
202 | writer.add_scalar(f'pretrain/D/loss/all', loss, i)
203 | writer.add_scalar(f'pretrain/D/loss/{sample_factor}', loss, i)
204 | writer.add_scalar(f'pretrain/D/norm/{sample_factor}', norm, i)
205 |
206 | if last_save_time + save_interval < time.time():
207 | last_save_time = time.time()
208 | save_discriminator()
209 |
210 | except KeyboardInterrupt:
211 | save_discriminator()
212 |
213 |
214 | #========================================================================
215 | # Adversarial Learning
216 | #========================================================================
217 |
218 |
219 | def train_adversarial(sess_path, batch_data_generator,
220 | model_load_path, model_optimizer_class,
221 | model_gradient_clipping, discriminator_gradient_clipping,
222 | model_learning_rate, reset_model_optimizer,
223 | discriminator_load_path, discriminator_optimizer_class,
224 | discriminator_learning_rate, reset_discriminator_optimizer,
225 | g_max_q_mean, g_min_q_mean, d_min_loss, g_max_steps, d_max_steps,
226 | mc_sample_size, mc_sample_factor, first_to_train,
227 | save_interval, control_ratio, enable_logging):
228 |
229 | if enable_logging:
230 | from tensorboardX import SummaryWriter
231 | writer = SummaryWriter()
232 |
233 | if os.path.isfile(sess_path):
234 | adv_state = torch.load(sess_path)
235 | model_config = adv_state['model_config']
236 | model_state = adv_state['model_state']
237 | model_optimizer_state = adv_state['model_optimizer_state']
238 | discriminator_config = adv_state['discriminator_config']
239 | discriminator_state = adv_state['discriminator_state']
240 | discriminator_optimizer_state = adv_state['discriminator_optimizer_state']
241 | print('-' * 70)
242 | print('Session is loaded from', sess_path)
243 | loaded_from_session = True
244 |
245 | else:
246 | model_sess = torch.load(model_load_path)
247 | model_config = model_sess['model_config']
248 | model_state = model_sess['model_state']
249 | discriminator_sess = torch.load(discriminator_load_path)
250 | discriminator_config = discriminator_sess['discriminator_config']
251 | discriminator_state = discriminator_sess['discriminator_state']
252 | loaded_from_session = False
253 |
254 | model = PerformanceRNN(**model_config)
255 | model.load_state_dict(model_state)
256 | model.to(device).train()
257 | model_optimizer = model_optimizer_class(model.parameters(), lr=model_learning_rate)
258 |
259 | discriminator = EventSequenceEncoder(**discriminator_config)
260 | discriminator.load_state_dict(discriminator_state)
261 | discriminator.to(device).train()
262 | discriminator_optimizer = discriminator_optimizer_class(discriminator.parameters(),
263 | lr=discriminator_learning_rate)
264 |
265 | if loaded_from_session:
266 | if not reset_model_optimizer:
267 | model_optimizer.load_state_dict(model_optimizer_state)
268 | if not reset_discriminator_optimizer:
269 | discriminator_optimizer.load_state_dict(discriminator_optimizer_state)
270 |
271 | g_loss_func = nn.CrossEntropyLoss()
272 | d_loss_func = nn.BCEWithLogitsLoss(reduction='none')
273 |
274 |
275 | print('-' * 70)
276 | print('Options')
277 | print('sess_path:', sess_path)
278 | print('save_interval:', save_interval)
279 | print('batch_data_generator:', batch_data_generator)
280 | print('control_ratio:', control_ratio)
281 | print('g_max_q_mean:', g_max_q_mean)
282 | print('g_min_q_mean:', g_min_q_mean)
283 | print('d_min_loss:', d_min_loss)
284 | print('mc_sample_size:', mc_sample_size)
285 | print('mc_sample_factor:', mc_sample_factor)
286 | print('enable_logging:', enable_logging)
287 | print('model_load_path:', model_load_path)
288 | print('model_loss:', g_loss_func)
289 | print('model_optimizer_class:', model_optimizer_class)
290 | print('model_gradient_clipping:', model_gradient_clipping)
291 | print('model_learning_rate:', model_learning_rate)
292 | print('reset_model_optimizer:', reset_model_optimizer)
293 | print('discriminator_load_path:', discriminator_load_path)
294 | print('discriminator_loss:', d_loss_func)
295 | print('discriminator_optimizer_class:', discriminator_optimizer_class)
296 | print('discriminator_gradient_clipping:', discriminator_gradient_clipping)
297 | print('discriminator_learning_rate:', discriminator_learning_rate)
298 | print('reset_discriminator_optimizer:', reset_discriminator_optimizer)
299 | print('first_to_train:', first_to_train)
300 | print('-' * 70)
301 | print(f'Generator from "{sess_path if loaded_from_session else model_load_path}"')
302 | print(model)
303 | print(model_optimizer)
304 | print('-' * 70)
305 | print(f'Discriminator from "{sess_path if loaded_from_session else discriminator_load_path}"')
306 | print(discriminator)
307 | print(discriminator_optimizer)
308 | print('-' * 70)
309 |
310 |
311 | def save():
312 | print(f'Saving to "{sess_path}"')
313 | torch.save({
314 | 'model_config': model_config,
315 | 'model_state': model.state_dict(),
316 | 'model_optimizer_state': model_optimizer.state_dict(),
317 | 'discriminator_config': discriminator_config,
318 | 'discriminator_state': discriminator.state_dict(),
319 | 'discriminator_optimizer_state': discriminator_optimizer.state_dict()
320 | }, sess_path)
321 | print('Done saving')
322 |
323 | def mc_rollout(generated, hidden, total_steps, controls=None):
324 | # generated: list of [1, batch_size] event tensors (concatenated below to [t, batch_size])
325 | # hidden: [n_layers, batch_size, hidden_dim]
326 | # controls: [total_steps - t, batch_size, control_dim]
327 | generated = torch.cat(generated, 0)
328 | generated_steps, batch_size = generated.shape # t, b
329 | steps = total_steps - generated_steps # s
330 |
331 | generated = generated.unsqueeze(1) # [t, 1, b]
332 | generated = generated.repeat(1, mc_sample_size, 1) # [t, mcs, b]
333 | generated = generated.view(generated_steps, -1) # [t, mcs * b]
334 |
335 | hidden = hidden.unsqueeze(1).repeat(1, mc_sample_size, 1, 1)
336 | hidden = hidden.view(model.gru_layers, -1, model.hidden_dim)
337 |
338 | if controls is not None:
339 | assert controls.shape == (steps, batch_size, model.control_dim)
340 | controls = controls.unsqueeze(1) # [s, 1, b, c]
341 | controls = controls.repeat(1, mc_sample_size, 1, 1) # [s, mcs, b, c]
342 | controls = controls.view(steps, -1, model.control_dim) # [s, mcs * b, c]
343 |
344 | event = generated[-1].unsqueeze(0) # [1, mcs * b]
345 | control = None # default when controls is None
346 | outputs = []
347 |
348 | for i in range(steps):
349 | if controls is not None:
350 | control = controls[i].unsqueeze(0) # [1, mcs * b, c]
351 |
352 | output, hidden = model.forward(event, control=control, hidden=hidden)
353 | probs = model.output_fc_activation(output / mc_sample_factor)
354 | event = Categorical(probs).sample() # [1, mcs * b]
355 | outputs.append(event)
356 |
357 | sequences = torch.cat([generated, *outputs], 0)
358 | assert sequences.shape == (total_steps, mc_sample_size * batch_size)
359 | return sequences
360 |
361 |
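# SeqGAN generator update (sketch of what train_generator below computes):
# at step t, the reward Q_t is the discriminator's mean score over
# mc_sample_size Monte Carlo completions of the current prefix, and the
# REINFORCE-style loss is
#     mean_t[ (Q_t - mean(Q)) * cross_entropy(output_t, sampled_event_t) ]
# so sampled events whose rollouts beat the average score are reinforced.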
362 | def train_generator(batch_size, init, events, controls):
363 | # Generator step
364 | hidden = model.init_to_hidden(init)
365 | event = model.get_primary_event(batch_size)
366 | outputs = []
367 | generated = []
368 | q_values = []
369 |
370 | for step in Bar('MC Rollout').iter(range(steps)):
371 | control = controls[step].unsqueeze(0) if use_control else None
372 | output, hidden = model.forward(event, control=control, hidden=hidden)
373 | outputs.append(output)
374 | probs = model.output_fc_activation(output / mc_sample_factor)
375 | generated.append(Categorical(probs).sample())
376 |
377 | with torch.no_grad():
378 | if step < steps - 1:
379 | sequences = mc_rollout(generated, hidden, steps, controls[step+1:])
380 | mc_score = discriminator(sequences) # [mcs * b]
381 | mc_score = mc_score.view(mc_sample_size, batch_size) # [mcs, b]
382 | q_value = mc_score.mean(0, keepdim=True) # [1, batch_size]
383 |
384 | else:
385 | q_value = discriminator(torch.cat(generated, 0))
386 | q_value = q_value.unsqueeze(0) # [1, batch_size]
387 |
388 | q_values.append(q_value)
389 |
390 | # Compute loss
391 | q_values = torch.cat(q_values, 0) # [steps, batch_size]
392 | q_mean = q_values.mean().detach()
393 | q_values = q_values - q_mean
394 | generated = torch.cat(generated, 0) # [steps, batch_size]
395 | outputs = torch.cat(outputs, 0) # [steps, batch_size, event_dim]
396 | loss = F.cross_entropy(outputs.view(-1, model.event_dim),
397 | generated.view(-1),
398 | reduction='none')
399 | loss = (loss * q_values.view(-1)).mean()
400 |
401 | # Backprop
402 | model.zero_grad()
403 | loss.backward()
404 |
405 | # Gradient clipping
406 | norm = utils.compute_gradient_norm(model.parameters())
407 | if model_gradient_clipping:
408 | nn.utils.clip_grad_norm_(model.parameters(), model_gradient_clipping)
409 |
410 | model_optimizer.step()
411 |
412 | q_mean = q_mean.item()
413 | norm = norm.item()
414 | return q_mean, norm
415 |
416 | def train_discriminator(batch_size, init, events, controls):
417 | # Discriminator step
418 | with torch.no_grad():
419 | generated = model.generate(init, steps, None, controls,
420 | greedy=0, temperature=mc_sample_factor)
421 |
422 | fake_logit = discriminator(generated, output_logits=True)
423 | real_logit = discriminator(events, output_logits=True)
424 | fake_target = torch.zeros_like(fake_logit)
425 | real_target = torch.ones_like(real_logit)
426 |
427 | # Compute loss
428 | fake_loss = F.binary_cross_entropy_with_logits(fake_logit, fake_target)
429 | real_loss = F.binary_cross_entropy_with_logits(real_logit, real_target)
430 | loss = (real_loss + fake_loss) / 2
431 |
432 | # Backprop
433 | discriminator.zero_grad()
434 | loss.backward()
435 |
436 | # Gradient clipping
437 | norm = utils.compute_gradient_norm(discriminator.parameters())
438 | if discriminator_gradient_clipping:
439 | nn.utils.clip_grad_norm_(discriminator.parameters(), discriminator_gradient_clipping)
440 |
441 | discriminator_optimizer.step()
442 |
443 | real_loss = real_loss.item()
444 | fake_loss = fake_loss.item()
445 | loss = loss.item()
446 | norm = norm.item()
447 | return loss, real_loss, fake_loss, norm
448 |
449 | try:
450 | last_save_time = time.time()
451 | step_for = first_to_train
452 | g_steps = 0
453 | d_steps = 0
454 |
455 | for i, (events, controls) in enumerate(batch_data_generator):
456 | steps, batch_size = events.shape
457 | init = torch.randn(batch_size, model.init_dim).to(device)
458 | events = torch.LongTensor(events).to(device)
459 |
460 | use_control = np.random.random() <= control_ratio
461 | controls = torch.FloatTensor(controls).to(device) if use_control else None
462 |
463 | if step_for == 'G':
464 | q_mean, norm = train_generator(batch_size, init, events, controls)
465 | g_steps += 1
466 |
467 | print(f'{i} (G-step) Q_mean: {q_mean}, norm: {norm}')
468 | if enable_logging:
469 | writer.add_scalar('adversarial/G/Q_mean', q_mean, i)
470 | writer.add_scalar('adversarial/G/norm', norm, i)
471 |
472 | if q_mean < g_min_q_mean:
473 | print(f'Q is too small: {q_mean}, exiting')
474 | raise KeyboardInterrupt
475 |
476 | if q_mean > g_max_q_mean or (g_max_steps and g_steps >= g_max_steps):
477 | step_for = 'D'
478 | d_steps = 0
479 |
480 | if step_for == 'D':
481 | loss, real_loss, fake_loss, norm = train_discriminator(batch_size, init, events, controls)
482 | d_steps += 1
483 |
484 | print(f'{i} (D-step) loss: {loss} (real: {real_loss}, fake: {fake_loss}), norm: {norm}')
485 | if enable_logging:
486 | writer.add_scalar('adversarial/D/loss', loss, i)
487 | writer.add_scalar('adversarial/D/norm', norm, i)
488 |
489 | if fake_loss <= real_loss < d_min_loss or (d_max_steps and d_steps >= d_max_steps):
490 | step_for = 'G'
491 | g_steps = 0
492 |
493 | if last_save_time + save_interval < time.time():
494 | last_save_time = time.time()
495 | save()
496 |
497 | except KeyboardInterrupt:
498 | save()
499 |
500 |
501 |
502 | #========================================================================
503 | # Script Arguments
504 | #========================================================================
505 |
506 | def batch_generator(args):
507 | print('-' * 70)
508 | dataset = Dataset(args.dataset_path, verbose=True)
509 | print(dataset)
510 | return dataset.batches(args.batch_size, args.window_size, args.stride_size)
511 |
512 | def pretrain(args):
513 | pretrain_discriminator(model_sess_path=args.generator_session_path,
514 | discriminator_sess_path=args.discriminator_session_path,
515 | discriminator_config_overwrite=utils.params2dict(args.discriminator_parameters),
516 | batch_data_generator=args.batch_generator(args),
517 | gradient_clipping=args.gradient_clipping,
518 | sample_factor=args.sample_factor,
519 | auto_sample_factor=args.auto_sample_factor,
520 | control_ratio=args.control_ratio,
521 | num_iter=args.stop_iteration,
522 | save_interval=args.save_interval,
523 | discriminator_lr=args.discriminator_learning_rate,
524 | enable_logging=args.enable_logging)
525 |
526 | def adversarial(args):
527 | train_adversarial(sess_path=args.session_path,
528 | batch_data_generator=args.batch_generator(args),
529 | model_load_path=args.generator_load_path,
530 | discriminator_load_path=args.discriminator_load_path,
531 | model_optimizer_class=getattr(optim, args.generator_optimizer),
532 | discriminator_optimizer_class=getattr(optim, args.discriminator_optimizer),
533 | model_gradient_clipping=args.generator_gradient_clipping,
534 | discriminator_gradient_clipping=args.discriminator_gradient_clipping,
535 | model_learning_rate=args.generator_learning_rate,
536 | discriminator_learning_rate=args.discriminator_learning_rate,
537 | reset_model_optimizer=args.reset_generator_optimizer,
538 | reset_discriminator_optimizer=args.reset_discriminator_optimizer,
539 | g_max_q_mean=args.g_max_q_mean,
540 | g_min_q_mean=args.g_min_q_mean,
541 | d_min_loss=args.d_min_loss,
542 | g_max_steps=args.g_max_steps,
543 | d_max_steps=args.d_max_steps,
544 | mc_sample_size=args.monte_carlo_sample_size,
545 | mc_sample_factor=args.monte_carlo_sample_factor,
546 | first_to_train=args.first_to_train,
547 | control_ratio=args.control_ratio,
548 | save_interval=args.save_interval,
549 | enable_logging=args.enable_logging)
550 |
551 | def get_args():
552 | parser = argparse.ArgumentParser()
553 | subparsers = parser.add_subparsers()
554 | parser.add_argument('-d', '--dataset-path', type=str, required=True)
555 | parser.add_argument('-b', '--batch-size', type=int, default=64)
556 | parser.add_argument('-w', '--window-size', type=int, default=200)
557 | parser.add_argument('-s', '--stride-size', type=int, default=10)
558 | parser.set_defaults(batch_generator=batch_generator)
559 | pre_parser = subparsers.add_parser('pretrain', aliases=['p', 'pre'])
560 | pre_parser.add_argument('-G', '--generator-session-path', type=str, required=True)
561 | pre_parser.add_argument('-D', '--discriminator-session-path', type=str, required=True)
562 | pre_parser.add_argument('-p', '--discriminator-parameters', type=str, default='')
563 | pre_parser.add_argument('-l', '--discriminator-learning-rate', type=float, default=0.001)
564 | pre_parser.add_argument('-g', '--gradient-clipping', type=float, default=False)
565 | pre_parser.add_argument('-f', '--sample-factor', type=float, default=1.0)
566 | pre_parser.add_argument('-af', '--auto-sample-factor', action='store_true', default=False)
567 | pre_parser.add_argument('-c', '--control-ratio', type=float, default=1.0)
568 | pre_parser.add_argument('-n', '--stop-iteration', type=int, default=-1)
569 | pre_parser.add_argument('-i', '--save-interval', type=float, default=60.0)
570 | pre_parser.add_argument('-L', '--enable-logging', action='store_true', default=False)
571 | pre_parser.set_defaults(main=pretrain)
572 | adv_parser = subparsers.add_parser('adversarial', aliases=['a', 'adv'])
573 | adv_parser.add_argument('-S', '--session-path', type=str, required=True)
574 | adv_parser.add_argument('-Gp', '--generator-load-path', type=str)
575 | adv_parser.add_argument('-Dp', '--discriminator-load-path', type=str)
576 | adv_parser.add_argument('-Go', '--generator-optimizer', type=str, default='Adam')
577 | adv_parser.add_argument('-Do', '--discriminator-optimizer', type=str, default='RMSprop')
578 | adv_parser.add_argument('-Gg', '--generator-gradient-clipping', type=float, default=False)
579 | adv_parser.add_argument('-Dg', '--discriminator-gradient-clipping', type=float, default=False)
580 | adv_parser.add_argument('-Gl', '--generator-learning-rate', type=float, default=0.001)
581 | adv_parser.add_argument('-Dl', '--discriminator-learning-rate', type=float, default=0.001)
582 | adv_parser.add_argument('-Gr', '--reset-generator-optimizer', action='store_true', default=False)
583 | adv_parser.add_argument('-Dr', '--reset-discriminator-optimizer', action='store_true', default=False)
584 | adv_parser.add_argument('-Gq', '--g-max-q-mean', type=float, default=0.5)
585 | adv_parser.add_argument('-Gm', '--g-min-q-mean', type=float, default=0.0)
586 | adv_parser.add_argument('-Dm', '--d-min-loss', type=float, default=0.5)
587 | adv_parser.add_argument('-Gs', '--g-max-steps', type=int, default=0)
588 | adv_parser.add_argument('-Ds', '--d-max-steps', type=int, default=0)
589 | adv_parser.add_argument('-f', '--first-to-train', type=str, default='G', choices=['G', 'D'])
590 | adv_parser.add_argument('-ms', '--monte-carlo-sample-size', type=int, default=8)
591 | adv_parser.add_argument('-mf', '--monte-carlo-sample-factor', type=float, default=1.0)
592 | adv_parser.add_argument('-c', '--control-ratio', type=float, default=1.0)
593 | adv_parser.add_argument('-i', '--save-interval', type=float, default=60.0)
594 | adv_parser.add_argument('-L', '--enable-logging', action='store_true', default=False)
595 | adv_parser.set_defaults(main=adversarial)
596 | return parser.parse_args()
597 |
598 |
599 | if __name__ == '__main__':
600 | args = get_args()
601 | args.main(args)
602 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sequence import EventSeq, ControlSeq
3 |
4 | #pylint: disable=E1101
5 |
6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7 |
8 | model = {
9 | 'init_dim': 32,
10 | 'event_dim': EventSeq.dim(),
11 | 'control_dim': ControlSeq.dim(),
12 | 'hidden_dim': 512,
13 | 'gru_layers': 3,
14 | 'gru_dropout': 0.3,
15 | }
16 |
17 | train = {
18 | 'learning_rate': 0.001,
19 | 'batch_size': 64,
20 | 'window_size': 200,
21 | 'stride_size': 10,
22 | 'use_transposition': False,
23 | 'control_ratio': 1.0,
24 | 'teacher_forcing_ratio': 1.0
25 | }
26 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import itertools
4 | import numpy as np
5 | from progress.bar import Bar
6 |
7 | import config
8 | import utils
9 | from sequence import EventSeq, ControlSeq
10 |
11 | # pylint: disable=E1101
12 | # pylint: disable=W0101
13 |
14 | class Dataset:
15 | def __init__(self, root, verbose=False):
16 | assert os.path.isdir(root), root
17 | paths = utils.find_files_by_extensions(root, ['.data'])
18 | self.root = root
19 | self.samples = []
20 | self.seqlens = []
21 | if verbose:
22 | paths = Bar(root).iter(list(paths))
23 | for path in paths:
24 | eventseq, controlseq = torch.load(path)
25 | controlseq = ControlSeq.recover_compressed_array(controlseq)
26 | assert len(eventseq) == len(controlseq)
27 | self.samples.append((eventseq, controlseq))
28 | self.seqlens.append(len(eventseq))
29 | self.avglen = np.mean(self.seqlens)
30 |
31 | def batches(self, batch_size, window_size, stride_size):
32 | indices = [(i, range(j, j + window_size))
33 | for i, seqlen in enumerate(self.seqlens)
34 | for j in range(0, seqlen - window_size, stride_size)]
35 | while True:
36 | eventseq_batch = []
37 | controlseq_batch = []
38 | n = 0
39 | for ii in np.random.permutation(len(indices)):
40 | i, r = indices[ii]
41 | eventseq, controlseq = self.samples[i]
42 | eventseq = eventseq[r.start:r.stop]
43 | controlseq = controlseq[r.start:r.stop]
44 | eventseq_batch.append(eventseq)
45 | controlseq_batch.append(controlseq)
46 | n += 1
47 | if n == batch_size:
48 | yield (np.stack(eventseq_batch, axis=1),
49 | np.stack(controlseq_batch, axis=1))
50 | eventseq_batch.clear()
51 | controlseq_batch.clear()
52 | n = 0
53 |
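# Usage sketch (batch/window/stride defaults from config.train): each yielded
# pair is events of shape [window_size, batch_size] (integer event indices)
# and controls of shape [window_size, batch_size, control_dim], e.g.
#   for events, controls in Dataset('dataset/processed/NAME').batches(64, 200, 10):
#       ...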
54 | def __repr__(self):
55 | return (f'Dataset(root="{self.root}", '
56 | f'samples={len(self.samples)}, '
57 | f'avglen={self.avglen})')
58 |
--------------------------------------------------------------------------------
/dataset/midi/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/dataset/midi/.keep
--------------------------------------------------------------------------------
/dataset/processed/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/dataset/processed/.keep
--------------------------------------------------------------------------------
/dataset/scripts/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/dataset/scripts/.keep
--------------------------------------------------------------------------------
/dataset/scripts/classic_piano_downloader.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Scraper for Classical Piano Midi Page
3 | [ ! "$1" ] && echo 'Error: please specify output dir' && exit
4 | dir=$1
5 | base=http://www.piano-midi.de
6 | pages=$(curl -s --max-time 5 $base/midi_files.htm \
7 | | grep '
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
94 | stochastic_beam_search = opt.stochastic_beam_search
95 | beam_size = opt.beam_size
96 | temperature = opt.temperature
97 | init_zero = opt.init_zero
98 |
99 | if use_beam_search:
100 | greedy_ratio = 'DISABLED'
101 | else:
102 | beam_size = 'DISABLED'
103 |
104 | assert os.path.isfile(sess_path), f'"{sess_path}" is not a file'
105 |
106 | if control is not None:
107 | if os.path.isfile(control) or os.path.isdir(control):
108 | if os.path.isdir(control):
109 | files = list(utils.find_files_by_extensions(control))
110 | assert len(files) > 0, f'no file in "{control}"'
111 | control = np.random.choice(files)
112 | _, compressed_controls = torch.load(control)
113 | controls = ControlSeq.recover_compressed_array(compressed_controls)
114 | if max_len == 0:
115 | max_len = controls.shape[0]
116 | controls = torch.tensor(controls, dtype=torch.float32)
117 | controls = controls.unsqueeze(1).repeat(1, batch_size, 1).to(device)
118 | control = f'control sequence from "{control}"'
119 |
120 | else:
121 | pitch_histogram, note_density = control.split(';')
122 | pitch_histogram = list(filter(len, pitch_histogram.split(',')))
123 | if len(pitch_histogram) == 0:
124 | pitch_histogram = np.ones(12) / 12
125 | else:
126 | pitch_histogram = np.array(list(map(float, pitch_histogram)))
127 | assert pitch_histogram.size == 12
128 | assert np.all(pitch_histogram >= 0)
129 | pitch_histogram = pitch_histogram / pitch_histogram.sum() \
130 | if pitch_histogram.sum() else np.ones(12) / 12
131 | note_density = int(note_density)
132 | assert note_density in range(len(ControlSeq.note_density_bins))
133 | control = Control(pitch_histogram, note_density)
134 | controls = torch.tensor(control.to_array(), dtype=torch.float32)
135 | controls = controls.repeat(1, batch_size, 1).to(device)
136 | control = repr(control)
137 |
138 | else:
139 | controls = None
140 | control = 'NONE'
141 |
142 | assert max_len > 0, 'either max length or control sequence length should be given'
143 |
144 | # ------------------------------------------------------------------------
145 |
146 | print('-' * 70)
147 | print('Session:', sess_path)
148 | print('Batch size:', batch_size)
149 | print('Max length:', max_len)
150 | print('Greedy ratio:', greedy_ratio)
151 | print('Beam size:', beam_size)
152 | print('Beam search stochastic:', stochastic_beam_search)
153 | print('Output directory:', output_dir)
154 | print('Controls:', control)
155 | print('Temperature:', temperature)
156 | print('Init zero:', init_zero)
157 | print('-' * 70)
158 |
159 |
160 | # ========================================================================
161 | # Generating
162 | # ========================================================================
163 |
164 | state = torch.load(sess_path, map_location=device)
165 | model = PerformanceRNN(**state['model_config']).to(device)
166 | model.load_state_dict(state['model_state'])
167 | model.eval()
168 | print(model)
169 | print('-' * 70)
170 |
171 | if init_zero:
172 | init = torch.zeros(batch_size, model.init_dim).to(device)
173 | else:
174 | init = torch.randn(batch_size, model.init_dim).to(device)
175 |
176 | with torch.no_grad():
177 | if use_beam_search:
178 | outputs = model.beam_search(init, max_len, beam_size,
179 | controls=controls,
180 | temperature=temperature,
181 | stochastic=stochastic_beam_search,
182 | verbose=True)
183 | else:
184 | outputs = model.generate(init, max_len,
185 | controls=controls,
186 | greedy=greedy_ratio,
187 | temperature=temperature,
188 | verbose=True)
189 |
190 | outputs = outputs.cpu().numpy().T # [batch, steps]
191 |
192 |
193 | # ========================================================================
194 | # Saving
195 | # ========================================================================
196 |
197 | os.makedirs(output_dir, exist_ok=True)
198 |
199 | for i, output in enumerate(outputs):
200 | name = f'output-{i:03d}.mid'
201 | path = os.path.join(output_dir, name)
202 | n_notes = utils.event_indeces_to_midi_file(output, path)
203 | print(f'===> {path} ({n_notes} notes)')
204 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Categorical, Gumbel
5 |
6 | from collections import namedtuple
7 | import numpy as np
8 | from progress.bar import Bar
9 | from config import device
10 |
11 | # pylint: disable=E1101,E1102
12 |
13 |
14 | class PerformanceRNN(nn.Module):
15 | def __init__(self, event_dim, control_dim, init_dim, hidden_dim,
16 | gru_layers=3, gru_dropout=0.3):
17 | super().__init__()
18 |
19 | self.event_dim = event_dim
20 | self.control_dim = control_dim
21 | self.init_dim = init_dim
22 | self.hidden_dim = hidden_dim
23 | self.gru_layers = gru_layers
24 | self.concat_dim = event_dim + 1 + control_dim
25 | self.input_dim = hidden_dim
26 | self.output_dim = event_dim
27 |
28 | self.primary_event = self.event_dim - 1
29 |
30 | self.inithid_fc = nn.Linear(init_dim, gru_layers * hidden_dim)
31 | self.inithid_fc_activation = nn.Tanh()
32 |
33 | self.event_embedding = nn.Embedding(event_dim, event_dim)
34 | self.concat_input_fc = nn.Linear(self.concat_dim, self.input_dim)
35 | self.concat_input_fc_activation = nn.LeakyReLU(0.1, inplace=True)
36 |
37 | self.gru = nn.GRU(self.input_dim, self.hidden_dim,
38 | num_layers=gru_layers, dropout=gru_dropout)
39 | self.output_fc = nn.Linear(hidden_dim * gru_layers, self.output_dim)
40 | self.output_fc_activation = nn.Softmax(dim=-1)
41 |
42 | self._initialize_weights()
43 |
44 | def _initialize_weights(self):
45 | nn.init.xavier_normal_(self.event_embedding.weight)
46 | nn.init.xavier_normal_(self.inithid_fc.weight)
47 | self.inithid_fc.bias.data.fill_(0.)
48 | nn.init.xavier_normal_(self.concat_input_fc.weight)
49 | nn.init.xavier_normal_(self.output_fc.weight)
50 | self.output_fc.bias.data.fill_(0.)
51 |
52 | def _sample_event(self, output, greedy=True, temperature=1.0):
53 | if greedy:
54 | return output.argmax(-1)
55 | else:
56 | output = output / temperature
57 | probs = self.output_fc_activation(output)
58 | return Categorical(probs).sample()
59 |
60 | def forward(self, event, control=None, hidden=None):
61 | # One step forward
62 |
63 | assert len(event.shape) == 2
64 | assert event.shape[0] == 1
65 | batch_size = event.shape[1]
66 | event = self.event_embedding(event)
67 |
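# A scalar "default" flag is concatenated alongside the control vector:
# 1 with a zero control when no control is given, 0 with the real control
# otherwise, so the model can distinguish "unconditioned" from a zero control.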
68 | if control is None:
69 | default = torch.ones(1, batch_size, 1).to(device)
70 | control = torch.zeros(1, batch_size, self.control_dim).to(device)
71 | else:
72 | default = torch.zeros(1, batch_size, 1).to(device)
73 | assert control.shape == (1, batch_size, self.control_dim)
74 |
75 | concat = torch.cat([event, default, control], -1)
76 | input = self.concat_input_fc(concat)
77 | input = self.concat_input_fc_activation(input)
78 |
79 | _, hidden = self.gru(input, hidden)
80 | output = hidden.permute(1, 0, 2).contiguous()
81 | output = output.view(batch_size, -1).unsqueeze(0)
82 | output = self.output_fc(output)
83 | return output, hidden
84 |
85 | def get_primary_event(self, batch_size):
86 | return torch.LongTensor([[self.primary_event] * batch_size]).to(device)
87 |
88 | def init_to_hidden(self, init):
89 | # [batch_size, init_dim]
90 | batch_size = init.shape[0]
91 | out = self.inithid_fc(init)
92 | out = self.inithid_fc_activation(out)
93 | out = out.view(self.gru_layers, batch_size, self.hidden_dim)
94 | return out
95 |
96 | def expand_controls(self, controls, steps):
97 | # [1 or steps, batch_size, control_dim]
98 | assert len(controls.shape) == 3
99 | assert controls.shape[2] == self.control_dim
100 | if controls.shape[0] > 1:
101 | assert controls.shape[0] >= steps
102 | return controls[:steps]
103 | return controls.repeat(steps, 1, 1)
104 |
105 | def generate(self, init, steps, events=None, controls=None, greedy=1.0,
106 | temperature=1.0, teacher_forcing_ratio=1.0, output_type='index', verbose=False):
107 | # init [batch_size, init_dim]
108 | # events [steps, batch_size] indeces
109 | # controls [1 or steps, batch_size, control_dim]
110 |
111 | batch_size = init.shape[0]
112 | assert init.shape[1] == self.init_dim
113 | assert steps > 0
114 |
115 | use_teacher_forcing = events is not None
116 | if use_teacher_forcing:
117 | assert len(events.shape) == 2
118 | assert events.shape[0] >= steps - 1
119 | events = events[:steps-1]
120 |
121 | event = self.get_primary_event(batch_size)
122 | use_control = controls is not None
123 | if use_control:
124 | controls = self.expand_controls(controls, steps)
125 | hidden = self.init_to_hidden(init)
126 |
127 | outputs = []
128 | step_iter = range(steps)
129 | if verbose:
130 | step_iter = Bar('Generating').iter(step_iter)
131 |
132 | for step in step_iter:
133 | control = controls[step].unsqueeze(0) if use_control else None
134 | output, hidden = self.forward(event, control, hidden)
135 |
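# With probability `greedy`, decode the argmax event; otherwise sample
# from the temperature-scaled softmax (see _sample_event).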
136 | use_greedy = np.random.random() < greedy
137 | event = self._sample_event(output, greedy=use_greedy,
138 | temperature=temperature)
139 |
140 | if output_type == 'index':
141 | outputs.append(event)
142 | elif output_type == 'softmax':
143 | outputs.append(self.output_fc_activation(output))
144 | elif output_type == 'logit':
145 | outputs.append(output)
146 | else:
147 | assert False
148 |
149 | if use_teacher_forcing and step < steps - 1: # avoid last one
150 | if np.random.random() <= teacher_forcing_ratio:
151 | event = events[step].unsqueeze(0)
152 |
153 | return torch.cat(outputs, 0)
154 |
155 | def beam_search(self, init, steps, beam_size, controls=None,
156 | temperature=1.0, stochastic=False, verbose=False):
157 | assert len(init.shape) == 2 and init.shape[1] == self.init_dim
158 | assert self.event_dim >= beam_size > 0 and steps > 0
159 |
160 | batch_size = init.shape[0]
161 | current_beam_size = 1
162 |
163 | if controls is not None:
164 | controls = self.expand_controls(controls, steps) # [steps, batch_size, control_dim]
165 |
166 | # Initial hidden weights
167 | hidden = self.init_to_hidden(init) # [gru_layers, batch_size, hidden_size]
168 | hidden = hidden[:, :, None, :] # [gru_layers, batch_size, 1, hidden_size]
169 | hidden = hidden.repeat(1, 1, current_beam_size, 1) # [gru_layers, batch_size, beam_size, hidden_dim]
170 |
171 |
172 | # Initial event
173 | event = self.get_primary_event(batch_size) # [1, batch]
174 | event = event[:, :, None].repeat(1, 1, current_beam_size) # [1, batch, 1]
175 |
176 | # [batch, beam, 1] event sequences of beams
177 | beam_events = event[0, :, None, :].repeat(1, current_beam_size, 1)
178 |
179 | # [batch, beam] log probs sum of beams
180 | beam_log_prob = torch.zeros(batch_size, current_beam_size).to(device)
181 |
182 | if stochastic:
183 | # [batch, beam] Gumbel perturbed log probs of beams
184 | beam_log_prob_perturbed = torch.zeros(batch_size, current_beam_size).to(device)
185 | beam_z = torch.full((batch_size, beam_size), float('inf'))
186 | gumbel_dist = Gumbel(0, 1)
187 |
188 | step_iter = range(steps)
189 | if verbose:
190 | step_iter = Bar(['', 'Stochastic '][stochastic] + 'Beam Search').iter(step_iter)
191 |
192 | for step in step_iter:
193 | if controls is not None:
194 | control = controls[step, None, :, None, :] # [1, batch, 1, control]
195 | control = control.repeat(1, 1, current_beam_size, 1) # [1, batch, beam, control]
196 | control = control.view(1, batch_size * current_beam_size, self.control_dim) # [1, batch*beam, control]
197 | else:
198 | control = None
199 |
200 | event = event.view(1, batch_size * current_beam_size) # [1, batch*beam0]
201 | hidden = hidden.view(self.gru_layers, batch_size * current_beam_size, self.hidden_dim) # [grus, batch*beam, hid]
202 |
203 | logits, hidden = self.forward(event, control, hidden)
204 | hidden = hidden.view(self.gru_layers, batch_size, current_beam_size, self.hidden_dim) # [grus, batch, cbeam, hid]
205 | logits = (logits / temperature).view(1, batch_size, current_beam_size, self.event_dim) # [1, batch, cbeam, out]
206 |
207 | beam_log_prob_expand = logits + beam_log_prob[None, :, :, None] # [1, batch, cbeam, out]
208 | beam_log_prob_expand_batch = beam_log_prob_expand.view(1, batch_size, -1) # [1, batch, cbeam*out]
209 |
210 | if stochastic:
211 | beam_log_prob_expand_perturbed = beam_log_prob_expand + gumbel_dist.sample(beam_log_prob_expand.shape).to(device)
212 | beam_log_prob_Z, _ = beam_log_prob_expand_perturbed.max(-1) # [1, batch, cbeam]
213 | # print(beam_log_prob_Z)
214 | beam_log_prob_expand_perturbed_normalized = beam_log_prob_expand_perturbed
215 | # beam_log_prob_expand_perturbed_normalized = -torch.log(
216 | # torch.exp(-beam_log_prob_perturbed[None, :, :, None])
217 | # - torch.exp(-beam_log_prob_Z[:, :, :, None])
218 | # + torch.exp(-beam_log_prob_expand_perturbed)) # [1, batch, cbeam, out]
219 | # beam_log_prob_expand_perturbed_normalized = beam_log_prob_perturbed[None, :, :, None] + beam_log_prob_expand_perturbed # [1, batch, cbeam, out]
220 |
221 | beam_log_prob_expand_perturbed_normalized_batch = \
222 | beam_log_prob_expand_perturbed_normalized.view(1, batch_size, -1) # [1, batch, cbeam*out]
223 | _, top_indices = beam_log_prob_expand_perturbed_normalized_batch.topk(beam_size, -1) # [1, batch, cbeam]
224 |
225 | beam_log_prob_perturbed = \
226 | torch.gather(beam_log_prob_expand_perturbed_normalized_batch, -1, top_indices)[0] # [batch, beam]
227 |
228 | else:
229 | _, top_indices = beam_log_prob_expand_batch.topk(beam_size, -1)
230 |
231 | beam_log_prob = torch.gather(beam_log_prob_expand_batch, -1, top_indices)[0] # [batch, beam]
232 |
233 | beam_index_old = torch.arange(current_beam_size)[None, None, :, None].to(device) # [1, 1, cbeam, 1]
234 | beam_index_old = beam_index_old.repeat(1, batch_size, 1, self.output_dim) # [1, batch, cbeam, out]
235 | beam_index_old = beam_index_old.view(1, batch_size, -1) # [1, batch, cbeam*out]
236 | beam_index_new = torch.gather(beam_index_old, -1, top_indices)
237 |
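# Reorder hidden states so that each GRU state follows its surviving beam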
238 | hidden = torch.gather(hidden, 2, beam_index_new[:, :, :, None].repeat(self.gru_layers, 1, 1, self.hidden_dim))
239 |
240 | event_index = torch.arange(self.output_dim)[None, None, None, :].to(device) # [1, 1, 1, out]
241 | event_index = event_index.repeat(1, batch_size, current_beam_size, 1) # [1, batch, cbeam, out]
242 | event_index = event_index.view(1, batch_size, -1) # [1, batch, cbeam*out]
243 | event = torch.gather(event_index, -1, top_indices) # [1, batch, cbeam*out]
244 |
245 | beam_events = torch.gather(beam_events[None], 2, beam_index_new.unsqueeze(-1).repeat(1, 1, 1, beam_events.shape[-1]))
246 | beam_events = torch.cat([beam_events, event.unsqueeze(-1)], -1)[0]
247 |
248 | current_beam_size = beam_size
249 |
250 | best = beam_events[torch.arange(batch_size).long(), beam_log_prob.argmax(-1)]
251 | best = best.contiguous().t()
252 | return best
253 |
--------------------------------------------------------------------------------
/output/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/output/.keep
--------------------------------------------------------------------------------
/play.py:
--------------------------------------------------------------------------------
1 | info = '''
2 | ------------------------------------
3 | Simple MIDI Player & Note Visualizer
4 | ------------------------------------
5 |
6 | Please install pygame and fluidsynth for Python 3:
7 |
8 | pip3 install pygame
9 | pip3 install git+https://github.com/txomon/pyfluidsynth.git
10 |
11 | Usage:
12 |
13 | python3 play.py soundfont_path midi_files...
14 |
15 | '''.lstrip()
16 |
17 |
18 | import os, sys, time, threading
19 | import pygame, fluidsynth
20 | import sequence, pretty_midi
21 | import numpy as np
22 |
23 | # pylint: disable=E1101
24 |
25 | entities = []
26 | lock = threading.Lock()
27 | done = False
28 |
29 | width = 1024
30 | height = 768
31 |
32 | #====================================================================
33 | # Display
34 | #====================================================================
35 |
36 | class NoteEntity:
37 | def __init__(self, key, velocity):
38 | self.done = False
39 | self.age = 255
40 | self.color = pygame.color.Color('black')
41 | self.radius = int(40 * velocity / 128)
42 | self.velocity = np.array([0., -velocity / 10], dtype=np.float32)
43 | pr = sequence.DEFAULT_PITCH_RANGE
44 | x = key * width / 128
45 | self.position = np.array([x, height], dtype=np.float32)
46 |
47 | def update(self):
48 | self.position += self.velocity
49 | self.velocity += np.random.randn(2) / 20
50 | self.velocity *= 0.99
51 | self.age -= 1
52 | if self.age == 0:
53 | self.done = True
54 |
55 | def get_color(self):
56 | return pygame.color.Color(self.age, self.age, self.age)
57 |
58 | def render(self, screen):
59 | x, y = self.position.astype(np.int32)
60 | color = self.get_color()
61 | pygame.draw.circle(screen, color, (x, y), self.radius)
62 |
63 | def add_entity(key, velocity):
64 | global entities, lock
65 | lock.acquire()
66 | entities.append(NoteEntity(key, velocity))
67 | lock.release()
68 |
69 | def display():
70 | global entities, lock, width, height, done
71 | pygame.init()
72 | pygame.display.set_caption('generator')
73 | screen = pygame.display.set_mode((width, height))
74 | clock = pygame.time.Clock()
75 |
76 | while not done:
77 | for event in pygame.event.get():
78 | if event.type == pygame.QUIT:
79 | done = True
80 |
81 | screen.fill(pygame.color.Color('black'))
82 |
83 | lock.acquire()
84 | entities = [entity for entity in entities if not entity.done]
85 | for i, entity in enumerate(entities):
86 | if i + 1 < len(entities):
87 | next_ent = entities[i + 1]
88 | pygame.draw.line(screen, entity.get_color(), entity.position, next_ent.position)
89 | for entity in entities:
90 | entity.render(screen)
91 | entity.update()
92 | lock.release()
93 |
94 | pygame.display.flip()
95 | clock.tick(60)
96 |
97 | pygame.quit()
98 |
99 |
100 |
101 | #====================================================================
102 | # Synth
103 | #====================================================================
104 |
105 | def note_repr(key, velocity):
106 | octave = key // 12 - 1
107 | name = ['C', 'C#', 'D', 'D#', 'E', 'F',
108 | 'F#', 'G', 'G#', 'A', 'A#', 'B'][key % 12]
109 | vel = int(10 * velocity / 128)
110 | return f'({name}{octave} {vel})'
111 |
112 |
113 | def play(midi_files, sound_font_path):
114 | global done
115 |
116 | fs = fluidsynth.Synth(gain=5)
117 | fs.start()
118 |
119 | try:
120 | sfid = fs.sfload(sound_font_path)
121 | fs.program_select(0, sfid, 0, 0)
122 | except Exception:
123 | print('Failed to load', sound_font_path)
124 | return
125 |
126 | for midi_file in midi_files:
127 |
128 | print(f'Playing {midi_file}')
129 |
130 | try:
131 | note_seq = sequence.NoteSeq.from_midi_file(midi_file)
132 | event_seq = sequence.EventSeq.from_note_seq(note_seq)
133 | except Exception:
134 | print('Failed to load', midi_file)
135 | continue
136 |
137 | velocity = sequence.DEFAULT_VELOCITY
138 | velocity_bins = sequence.EventSeq.get_velocity_bins()
139 | time_shift_bins = sequence.EventSeq.time_shift_bins
140 | pitch_start = sequence.EventSeq.pitch_range.start
141 |
142 | for event in event_seq.events:
143 | if event.type == 'note_on':
144 | key = int(event.value + pitch_start)
145 | fs.noteon(0, key, velocity)
146 | print(f' {note_repr(key, velocity)} ', end='', flush=True)
147 | add_entity(key, velocity)
148 | elif event.type == 'note_off':
149 | key = int(event.value + pitch_start)
150 | fs.noteoff(0, key)
151 | elif event.type == 'time_shift':
152 | print('.', end='', flush=True)
153 | time.sleep(time_shift_bins[event.value])
154 | elif event.type == 'velocity':
155 | velocity = int(velocity_bins[
156 | min(event.value, velocity_bins.size - 1)])
157 |
158 | print('Done')
159 |
160 | done = True
161 |
162 |
163 | #====================================================================
164 | # Main
165 | #====================================================================
166 |
167 |
168 | if __name__ == '__main__':
169 |
170 | try:
171 | sound_font_path = sys.argv[1]
172 | midi_files = sys.argv[2:]
173 | except IndexError:
174 | print(info)
175 | sys.exit()
176 |
177 | assert os.path.isfile(sound_font_path), sound_font_path
178 | for midi_path in midi_files:
179 | assert os.path.isfile(midi_path), midi_path
180 |
181 | threading.Thread(target=play,
182 | args=(midi_files, sound_font_path),
183 | daemon=True).start()
184 |
185 | display()
186 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | import torch
5 | import hashlib
6 | from progress.bar import Bar
7 | from concurrent.futures import ProcessPoolExecutor
8 |
9 | from sequence import NoteSeq, EventSeq, ControlSeq
10 | import utils
11 | import config
12 |
13 | def preprocess_midi(path):
14 | note_seq = NoteSeq.from_midi_file(path)
15 | note_seq.adjust_time(-note_seq.notes[0].start)
16 | event_seq = EventSeq.from_note_seq(note_seq)
17 | control_seq = ControlSeq.from_event_seq(event_seq)
18 | return event_seq.to_array(), control_seq.to_compressed_array()
19 |
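# Each saved .data file holds the tuple returned above: an integer event-index
# array plus a compressed control array, written with torch.save and recovered
# in data.py via ControlSeq.recover_compressed_array.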
20 | def preprocess_midi_files_under(midi_root, save_dir, num_workers):
21 | midi_paths = list(utils.find_files_by_extensions(midi_root, ['.mid', '.midi']))
22 | os.makedirs(save_dir, exist_ok=True)
23 | out_fmt = '{}-{}.data'
24 |
25 | results = []
26 | executor = ProcessPoolExecutor(num_workers)
27 |
28 | for path in midi_paths:
29 | try:
30 | results.append((path, executor.submit(preprocess_midi, path)))
31 | except KeyboardInterrupt:
32 | print(' Abort')
33 | return
34 |         except Exception:
35 | print(' Error')
36 | continue
37 |
38 | for path, future in Bar('Processing').iter(results):
39 | print(' ', end='[{}]'.format(path), flush=True)
40 | name = os.path.basename(path)
41 | code = hashlib.md5(path.encode()).hexdigest()
42 | save_path = os.path.join(save_dir, out_fmt.format(name, code))
43 |         try:
44 |             torch.save(future.result(), save_path)
45 |         except Exception:
46 |             # a corrupt or empty MIDI file should not abort the whole run
47 |             print(' Failed:', path)
48 |
49 |     print('Done')
50 |
51 | if __name__ == '__main__':
52 |     preprocess_midi_files_under(
53 |         midi_root=sys.argv[1],
54 |         save_dir=sys.argv[2],
55 |         num_workers=int(sys.argv[3]))
56 |
--------------------------------------------------------------------------------
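Each .data file written above stores a pair: the event index array and the
compressed control array. A minimal sketch of reading one back (the file name
is a placeholder):

import torch
from sequence import ControlSeq

events, compressed = torch.load('dataset/processed/example.data')
controls = ControlSeq.recover_compressed_array(compressed)

# one control vector per event step
assert controls.shape == (len(events), ControlSeq.dim())
print(events.shape, controls.shape)
--------------------------------------------------------------------------------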
/runs/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/runs/.keep
--------------------------------------------------------------------------------
/save/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djosix/Performance-RNN-PyTorch/b47e6d3e5504c88394b0414e7ac175959369d3da/save/.keep
--------------------------------------------------------------------------------
/sequence.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | import itertools
4 | import collections
5 | from pretty_midi import PrettyMIDI, Note, Instrument
6 |
7 |
8 | # ==================================================================================
9 | # Parameters
10 | # ==================================================================================
11 |
12 | # NoteSeq -------------------------------------------------------------------------
13 |
14 | DEFAULT_SAVING_PROGRAM = 1
15 | DEFAULT_LOADING_PROGRAMS = range(128)
16 | DEFAULT_RESOLUTION = 220
17 | DEFAULT_TEMPO = 120
18 | DEFAULT_VELOCITY = 64
19 | DEFAULT_PITCH_RANGE = range(21, 109)
20 | DEFAULT_VELOCITY_RANGE = range(21, 109)
21 | DEFAULT_NORMALIZATION_BASELINE = 60 # C4
22 |
23 | # EventSeq ------------------------------------------------------------------------
24 |
25 | USE_VELOCITY = True
26 | BEAT_LENGTH = 60 / DEFAULT_TEMPO
27 | DEFAULT_TIME_SHIFT_BINS = 1.15 ** np.arange(32) / 65
28 | DEFAULT_VELOCITY_STEPS = 32
29 | DEFAULT_NOTE_LENGTH = BEAT_LENGTH * 2
30 | MIN_NOTE_LENGTH = BEAT_LENGTH / 2
31 |
32 | # ControlSeq ----------------------------------------------------------------------
33 |
34 | DEFAULT_WINDOW_SIZE = BEAT_LENGTH * 4
35 | DEFAULT_NOTE_DENSITY_BINS = np.arange(12) * 3 + 1
36 |
37 |
38 | # ==================================================================================
39 | # Notes
40 | # ==================================================================================
41 |
42 | class NoteSeq:
43 |
44 | @staticmethod
45 | def from_midi(midi, programs=DEFAULT_LOADING_PROGRAMS):
46 | notes = itertools.chain(*[
47 | inst.notes for inst in midi.instruments
48 | if inst.program in programs and not inst.is_drum])
49 | return NoteSeq(list(notes))
50 |
51 | @staticmethod
52 | def from_midi_file(path, *args, **kwargs):
53 | midi = PrettyMIDI(path)
54 | return NoteSeq.from_midi(midi, *args, **kwargs)
55 |
56 | @staticmethod
57 | def merge(*note_seqs):
58 | notes = itertools.chain(*[seq.notes for seq in note_seqs])
59 | return NoteSeq(list(notes))
60 |
61 |     def __init__(self, notes=None):
62 | self.notes = []
63 | if notes:
64 | for note in notes:
65 | assert isinstance(note, Note)
66 | notes = filter(lambda note: note.end >= note.start, notes)
67 | self.add_notes(list(notes))
68 |
69 | def copy(self):
70 | return copy.deepcopy(self)
71 |
72 | def to_midi(self, program=DEFAULT_SAVING_PROGRAM,
73 | resolution=DEFAULT_RESOLUTION, tempo=DEFAULT_TEMPO):
74 | midi = PrettyMIDI(resolution=resolution, initial_tempo=tempo)
75 | inst = Instrument(program, False, 'NoteSeq')
76 | inst.notes = copy.deepcopy(self.notes)
77 | midi.instruments.append(inst)
78 | return midi
79 |
80 | def to_midi_file(self, path, *args, **kwargs):
81 | self.to_midi(*args, **kwargs).write(path)
82 |
83 | def add_notes(self, notes):
84 | self.notes += notes
85 | self.notes.sort(key=lambda note: note.start)
86 |
87 | def adjust_pitches(self, offset):
88 | for note in self.notes:
89 | pitch = note.pitch + offset
90 | pitch = 0 if pitch < 0 else pitch
91 | pitch = 127 if pitch > 127 else pitch
92 | note.pitch = pitch
93 |
94 | def adjust_velocities(self, offset):
95 | for note in self.notes:
96 | velocity = note.velocity + offset
97 | velocity = 0 if velocity < 0 else velocity
98 | velocity = 127 if velocity > 127 else velocity
99 | note.velocity = velocity
100 |
101 | def adjust_time(self, offset):
102 | for note in self.notes:
103 | note.start += offset
104 | note.end += offset
105 |
106 |     def trim_overlapped_notes(self, min_interval=0):
107 |         # rebuild the list instead of deleting notes while iterating
108 |         last_notes, notes = {}, []
109 |         for note in self.notes:
110 |             last_note = last_notes.get(note.pitch)
111 |             if last_note and note.start - last_note.start <= min_interval:
112 |                 last_note.end = max(note.end, last_note.end)  # merge
113 |                 last_note.velocity = max(note.velocity, last_note.velocity)
114 |                 continue  # drop the merged note
115 |             if last_note and note.start < last_note.end:
116 |                 last_note.end = note.start  # trim the overlap
117 |             notes.append(note)
118 |             last_notes[note.pitch] = note
119 |         self.notes = notes
120 |
121 | # ==================================================================================
122 | # Events
123 | # ==================================================================================
124 |
125 | class Event:
126 |
127 | def __init__(self, type, time, value):
128 | self.type = type
129 | self.time = time
130 | self.value = value
131 |
132 | def __repr__(self):
133 | return 'Event(type={}, time={}, value={})'.format(
134 | self.type, self.time, self.value)
135 |
136 |
137 | class EventSeq:
138 |
139 | pitch_range = DEFAULT_PITCH_RANGE
140 | velocity_range = DEFAULT_VELOCITY_RANGE
141 | velocity_steps = DEFAULT_VELOCITY_STEPS
142 | time_shift_bins = DEFAULT_TIME_SHIFT_BINS
143 |
144 | @staticmethod
145 | def from_note_seq(note_seq):
146 | note_events = []
147 |
148 | if USE_VELOCITY:
149 | velocity_bins = EventSeq.get_velocity_bins()
150 |
151 | for note in note_seq.notes:
152 | if note.pitch in EventSeq.pitch_range:
153 | if USE_VELOCITY:
154 | velocity = note.velocity
155 | velocity = max(velocity, EventSeq.velocity_range.start)
156 | velocity = min(velocity, EventSeq.velocity_range.stop - 1)
157 | velocity_index = np.searchsorted(velocity_bins, velocity)
158 | note_events.append(Event('velocity', note.start, velocity_index))
159 |
160 | pitch_index = note.pitch - EventSeq.pitch_range.start
161 | note_events.append(Event('note_on', note.start, pitch_index))
162 | note_events.append(Event('note_off', note.end, pitch_index))
163 |
164 | note_events.sort(key=lambda event: event.time) # stable
165 | events = []
166 |
167 | for i, event in enumerate(note_events):
168 | events.append(event)
169 |
170 | if event is note_events[-1]:
171 | break
172 |
173 | interval = note_events[i + 1].time - event.time
174 | shift = 0
175 |
176 | while interval - shift >= EventSeq.time_shift_bins[0]:
177 | index = np.searchsorted(EventSeq.time_shift_bins,
178 | interval - shift, side='right') - 1
179 | events.append(Event('time_shift', event.time + shift, index))
180 | shift += EventSeq.time_shift_bins[index]
181 |
182 | return EventSeq(events)
183 |
184 | @staticmethod
185 |     def from_array(event_indices):
186 |         time = 0
187 |         events = []
188 |         for event_index in event_indices:
189 | for event_type, feat_range in EventSeq.feat_ranges().items():
190 | if feat_range.start <= event_index < feat_range.stop:
191 | event_value = event_index - feat_range.start
192 | events.append(Event(event_type, time, event_value))
193 | if event_type == 'time_shift':
194 | time += EventSeq.time_shift_bins[event_value]
195 | break
196 |
197 | return EventSeq(events)
198 |
199 | @staticmethod
200 | def dim():
201 | return sum(EventSeq.feat_dims().values())
202 |
203 | @staticmethod
204 | def feat_dims():
205 | feat_dims = collections.OrderedDict()
206 | feat_dims['note_on'] = len(EventSeq.pitch_range)
207 | feat_dims['note_off'] = len(EventSeq.pitch_range)
208 | if USE_VELOCITY:
209 | feat_dims['velocity'] = EventSeq.velocity_steps
210 | feat_dims['time_shift'] = len(EventSeq.time_shift_bins)
211 | return feat_dims
212 |
213 | @staticmethod
214 | def feat_ranges():
215 | offset = 0
216 | feat_ranges = collections.OrderedDict()
217 | for feat_name, feat_dim in EventSeq.feat_dims().items():
218 | feat_ranges[feat_name] = range(offset, offset + feat_dim)
219 | offset += feat_dim
220 | return feat_ranges
221 |
222 | @staticmethod
223 | def get_velocity_bins():
224 | n = EventSeq.velocity_range.stop - EventSeq.velocity_range.start
225 | return np.arange(EventSeq.velocity_range.start,
226 | EventSeq.velocity_range.stop,
227 | n / (EventSeq.velocity_steps - 1))
228 |
229 |     def __init__(self, events=None):
230 |         events = list(events) if events else []
231 |         for event in events:
232 |             assert isinstance(event, Event)
233 |         self.events = copy.deepcopy(events)
234 |
235 | # compute event times again
236 | time = 0
237 | for event in self.events:
238 | event.time = time
239 | if event.type == 'time_shift':
240 | time += EventSeq.time_shift_bins[event.value]
241 |
242 | def to_note_seq(self):
243 | time = 0
244 | notes = []
245 |
246 | velocity = DEFAULT_VELOCITY
247 | velocity_bins = EventSeq.get_velocity_bins()
248 |
249 | last_notes = {}
250 |
251 | for event in self.events:
252 | if event.type == 'note_on':
253 | pitch = event.value + EventSeq.pitch_range.start
254 | note = Note(velocity, pitch, time, None)
255 | notes.append(note)
256 | last_notes[pitch] = note
257 |
258 | elif event.type == 'note_off':
259 | pitch = event.value + EventSeq.pitch_range.start
260 |
261 | if pitch in last_notes:
262 | note = last_notes[pitch]
263 | note.end = max(time, note.start + MIN_NOTE_LENGTH)
264 | del last_notes[pitch]
265 |
266 | elif event.type == 'velocity':
267 | index = min(event.value, velocity_bins.size - 1)
268 | velocity = velocity_bins[index]
269 |
270 | elif event.type == 'time_shift':
271 | time += EventSeq.time_shift_bins[event.value]
272 |
273 | for note in notes:
274 | if note.end is None:
275 | note.end = note.start + DEFAULT_NOTE_LENGTH
276 |
277 | note.velocity = int(note.velocity)
278 |
279 | return NoteSeq(notes)
280 |
281 | def to_array(self):
282 | feat_idxs = EventSeq.feat_ranges()
283 | idxs = [feat_idxs[event.type][event.value] for event in self.events]
284 | dtype = np.uint8 if EventSeq.dim() <= 256 else np.uint16
285 | return np.array(idxs, dtype=dtype)
286 |
287 |
288 | # ==================================================================================
289 | # Controls
290 | # ==================================================================================
291 |
292 | class Control:
293 |
294 | def __init__(self, pitch_histogram, note_density):
295 | self.pitch_histogram = pitch_histogram # list
296 | self.note_density = note_density # int
297 |
298 | def __repr__(self):
299 | return 'Control(pitch_histogram={}, note_density={})'.format(
300 | self.pitch_histogram, self.note_density)
301 |
302 | def to_array(self):
303 | feat_dims = ControlSeq.feat_dims()
304 | ndens = np.zeros([feat_dims['note_density']])
305 | ndens[self.note_density] = 1. # [dens_dim]
306 | phist = np.array(self.pitch_histogram) # [hist_dim]
307 | return np.concatenate([ndens, phist], 0) # [dens_dim + hist_dim]
308 |
309 |
310 | class ControlSeq:
311 |
312 | note_density_bins = DEFAULT_NOTE_DENSITY_BINS
313 | window_size = DEFAULT_WINDOW_SIZE
314 |
315 | @staticmethod
316 | def from_event_seq(event_seq):
317 | events = list(event_seq.events)
318 | start, end = 0, 0
319 |
320 | pitch_count = np.zeros([12])
321 | note_count = 0
322 |
323 | controls = []
324 |
325 | def _rel_pitch(pitch):
326 | return (pitch - 24) % 12
327 |
328 | for i, event in enumerate(events):
329 |
330 | while start < i:
331 | if events[start].type == 'note_on':
332 | abs_pitch = events[start].value + EventSeq.pitch_range.start
333 | rel_pitch = _rel_pitch(abs_pitch)
334 | pitch_count[rel_pitch] -= 1.
335 | note_count -= 1.
336 | start += 1
337 |
338 | while end < len(events):
339 | if events[end].time - event.time > ControlSeq.window_size:
340 | break
341 | if events[end].type == 'note_on':
342 | abs_pitch = events[end].value + EventSeq.pitch_range.start
343 | rel_pitch = _rel_pitch(abs_pitch)
344 | pitch_count[rel_pitch] += 1.
345 | note_count += 1.
346 | end += 1
347 |
348 | pitch_histogram = (
349 | pitch_count / note_count
350 | if note_count
351 | else np.ones([12]) / 12
352 | ).tolist()
353 |
354 | note_density = max(np.searchsorted(
355 | ControlSeq.note_density_bins,
356 | note_count, side='right') - 1, 0)
357 |
358 | controls.append(Control(pitch_histogram, note_density))
359 |
360 | return ControlSeq(controls)
361 |
362 | @staticmethod
363 | def dim():
364 | return sum(ControlSeq.feat_dims().values())
365 |
366 |     @staticmethod
367 |     def feat_dims():
368 |         note_density_dim = len(ControlSeq.note_density_bins)
369 |         return collections.OrderedDict([
370 |             ('note_density', note_density_dim),  # density first, then histogram,
371 |             ('pitch_histogram', 12)              # matching recover_compressed_array
372 |         ])
373 |
374 | @staticmethod
375 | def feat_ranges():
376 | offset = 0
377 | feat_ranges = collections.OrderedDict()
378 | for feat_name, feat_dim in ControlSeq.feat_dims().items():
379 | feat_ranges[feat_name] = range(offset, offset + feat_dim)
380 | offset += feat_dim
381 | return feat_ranges
382 |
383 | @staticmethod
384 | def recover_compressed_array(array):
385 | feat_dims = ControlSeq.feat_dims()
386 | assert array.shape[1] == 1 + feat_dims['pitch_histogram']
387 | ndens = np.zeros([array.shape[0], feat_dims['note_density']])
388 | ndens[np.arange(array.shape[0]), array[:, 0]] = 1. # [steps, dens_dim]
389 | phist = array[:, 1:].astype(np.float64) / 255 # [steps, hist_dim]
390 | return np.concatenate([ndens, phist], 1) # [steps, dens_dim + hist_dim]
391 |
392 | def __init__(self, controls):
393 | for control in controls:
394 | assert isinstance(control, Control)
395 | self.controls = copy.deepcopy(controls)
396 |
397 | def to_compressed_array(self):
398 | ndens = [control.note_density for control in self.controls]
399 | ndens = np.array(ndens, dtype=np.uint8).reshape(-1, 1)
400 | phist = [control.pitch_histogram for control in self.controls]
401 | phist = (np.array(phist) * 255).astype(np.uint8)
402 | return np.concatenate([
403 | ndens, # [steps, 1] density index
404 | phist # [steps, hist_dim] 0-255
405 | ], 1) # [steps, hist_dim + 1]
406 |
407 |
408 | if __name__ == '__main__':
409 | import pickle
410 | import sys
411 | path = sys.argv[1] if len(sys.argv) > 1 else 'dataset/midi/ecomp/BLINOV02.mid'
412 |
413 | print('Converting MIDI to EventSeq')
414 | es = EventSeq.from_note_seq(NoteSeq.from_midi_file(path))
415 |
416 | print('Converting EventSeq to MIDI')
417 | EventSeq.from_array(es.to_array()).to_note_seq().to_midi_file('/tmp/test.mid')
418 |
419 | print('Converting EventSeq to ControlSeq')
420 | cs = ControlSeq.from_event_seq(es)
421 |
422 |     print('Saving compressed ControlSeq')
423 |     with open('/tmp/cs-compressed.data', 'wb') as fp:
424 |         pickle.dump(cs.to_compressed_array(), fp)
425 |
426 |     print('Loading compressed ControlSeq')
427 |     with open('/tmp/cs-compressed.data', 'rb') as fp:
428 |         c = ControlSeq.recover_compressed_array(pickle.load(fp))
429 |
430 |     print('Done')
431 |
--------------------------------------------------------------------------------
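With the defaults above, the event vocabulary concatenates four blocks:
note_on (88) + note_off (88) + velocity (32) + time_shift (32), so
EventSeq.dim() == 240 and to_array() fits in uint8. A round-trip sketch on a
hand-built note:

from pretty_midi import Note
from sequence import NoteSeq, EventSeq

note_seq = NoteSeq([Note(velocity=64, pitch=60, start=0.0, end=0.5)])  # one C4
event_seq = EventSeq.from_note_seq(note_seq)

print(EventSeq.feat_ranges())
# OrderedDict([('note_on', range(0, 88)), ('note_off', range(88, 176)), ...])

indices = event_seq.to_array()                          # uint8 index array
recovered = EventSeq.from_array(indices).to_note_seq()
print(len(recovered.notes))                             # 1
--------------------------------------------------------------------------------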
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import optim
4 |
5 |
6 | import numpy as np
7 |
8 | import os
9 | import sys
10 | import time
11 | import optparse
12 |
13 |
14 | import utils
15 | import config
16 | from data import Dataset
17 | from model import PerformanceRNN
18 | from sequence import NoteSeq, EventSeq, ControlSeq
19 |
20 | # pylint: disable=E1102
21 | # pylint: disable=E1101
22 |
23 | #========================================================================
24 | # Settings
25 | #========================================================================
26 |
27 | def get_options():
28 | parser = optparse.OptionParser()
29 |
30 | parser.add_option('-s', '--session',
31 | dest='sess_path',
32 | type='string',
33 | default='save/train.sess')
34 |
35 | parser.add_option('-d', '--dataset',
36 | dest='data_path',
37 | type='string',
38 | default='dataset/processed/')
39 |
40 | parser.add_option('-i', '--saving-interval',
41 | dest='saving_interval',
42 | type='float',
43 | default=60.)
44 |
45 | parser.add_option('-b', '--batch-size',
46 | dest='batch_size',
47 | type='int',
48 | default=config.train['batch_size'])
49 |
50 | parser.add_option('-l', '--learning-rate',
51 | dest='learning_rate',
52 | type='float',
53 | default=config.train['learning_rate'])
54 |
55 | parser.add_option('-w', '--window-size',
56 | dest='window_size',
57 | type='int',
58 | default=config.train['window_size'])
59 |
60 | parser.add_option('-S', '--stride-size',
61 | dest='stride_size',
62 | type='int',
63 | default=config.train['stride_size'])
64 |
65 | parser.add_option('-c', '--control-ratio',
66 | dest='control_ratio',
67 | type='float',
68 | default=config.train['control_ratio'])
69 |
70 | parser.add_option('-T', '--teacher-forcing-ratio',
71 | dest='teacher_forcing_ratio',
72 | type='float',
73 | default=config.train['teacher_forcing_ratio'])
74 |
75 | parser.add_option('-t', '--use-transposition',
76 | dest='use_transposition',
77 | action='store_true',
78 | default=config.train['use_transposition'])
79 |
80 | parser.add_option('-p', '--model-params',
81 | dest='model_params',
82 | type='string',
83 | default='')
84 |
85 | parser.add_option('-r', '--reset-optimizer',
86 | dest='reset_optimizer',
87 | action='store_true',
88 | default=False)
89 |
90 | parser.add_option('-L', '--enable-logging',
91 | dest='enable_logging',
92 | action='store_true',
93 | default=False)
94 |
95 | return parser.parse_args()[0]
96 |
97 | options = get_options()
98 |
99 | #------------------------------------------------------------------------
100 |
101 | sess_path = options.sess_path
102 | data_path = options.data_path
103 | saving_interval = options.saving_interval
104 |
105 | learning_rate = options.learning_rate
106 | batch_size = options.batch_size
107 | window_size = options.window_size
108 | stride_size = options.stride_size
109 | use_transposition = options.use_transposition
110 | control_ratio = options.control_ratio
111 | teacher_forcing_ratio = options.teacher_forcing_ratio
112 | reset_optimizer = options.reset_optimizer
113 | enable_logging = options.enable_logging
114 |
115 | event_dim = EventSeq.dim()
116 | control_dim = ControlSeq.dim()
117 | model_config = config.model
118 | model_params = utils.params2dict(options.model_params)
119 | model_config.update(model_params)
120 | device = config.device
121 |
122 | print('-' * 70)
123 |
124 | print('Session path:', sess_path)
125 | print('Dataset path:', data_path)
126 | print('Saving interval:', saving_interval)
127 | print('-' * 70)
128 |
129 | print('Hyperparameters:', utils.dict2params(model_config))
130 | print('Learning rate:', learning_rate)
131 | print('Batch size:', batch_size)
132 | print('Window size:', window_size)
133 | print('Stride size:', stride_size)
134 | print('Control ratio:', control_ratio)
135 | print('Teacher forcing ratio:', teacher_forcing_ratio)
136 | print('Random transposition:', use_transposition)
137 | print('Reset optimizer:', reset_optimizer)
138 | print('Enable logging:', enable_logging)
139 | print('Device:', device)
140 | print('-' * 70)
141 |
142 |
143 | #========================================================================
144 | # Load session and dataset
145 | #========================================================================
146 |
147 | def load_session():
148 | global sess_path, model_config, device, learning_rate, reset_optimizer
149 | try:
150 | sess = torch.load(sess_path)
151 | if 'model_config' in sess and sess['model_config'] != model_config:
152 | model_config = sess['model_config']
153 | print('Use session config instead:')
154 | print(utils.dict2params(model_config))
155 | model_state = sess['model_state']
156 | optimizer_state = sess['model_optimizer_state']
157 | print('Session is loaded from', sess_path)
158 | sess_loaded = True
159 |     except Exception:
160 | print('New session')
161 | sess_loaded = False
162 | model = PerformanceRNN(**model_config).to(device)
163 | optimizer = optim.Adam(model.parameters(), lr=learning_rate)
164 | if sess_loaded:
165 | model.load_state_dict(model_state)
166 | if not reset_optimizer:
167 | optimizer.load_state_dict(optimizer_state)
168 | return model, optimizer
169 |
170 | def load_dataset():
171 | global data_path
172 | dataset = Dataset(data_path, verbose=True)
173 | dataset_size = len(dataset.samples)
174 | assert dataset_size > 0
175 | return dataset
176 |
177 |
178 | print('Loading session')
179 | model, optimizer = load_session()
180 | print(model)
181 |
182 | print('-' * 70)
183 |
184 | print('Loading dataset')
185 | dataset = load_dataset()
186 | print(dataset)
187 |
188 | print('-' * 70)
189 |
190 | #------------------------------------------------------------------------
191 |
192 | def save_model():
193 | global model, optimizer, model_config, sess_path
194 | print('Saving to', sess_path)
195 | torch.save({'model_config': model_config,
196 | 'model_state': model.state_dict(),
197 | 'model_optimizer_state': optimizer.state_dict()}, sess_path)
198 | print('Done saving')
199 |
200 |
201 | #========================================================================
202 | # Training
203 | #========================================================================
204 |
205 | if enable_logging:
206 | from torch.utils.tensorboard import SummaryWriter
207 | writer = SummaryWriter()
208 |
209 | last_saving_time = time.time()
210 | loss_function = nn.CrossEntropyLoss()
211 |
212 | try:
213 | batch_gen = dataset.batches(batch_size, window_size, stride_size)
214 |
215 | for iteration, (events, controls) in enumerate(batch_gen):
216 | if use_transposition:
217 | offset = np.random.choice(np.arange(-6, 6))
218 | events, controls = utils.transposition(events, controls, offset)
219 |
220 | events = torch.LongTensor(events).to(device)
221 | assert events.shape[0] == window_size
222 |
223 | if np.random.random() < control_ratio:
224 | controls = torch.FloatTensor(controls).to(device)
225 | assert controls.shape[0] == window_size
226 | else:
227 | controls = None
228 |
229 | init = torch.randn(batch_size, model.init_dim).to(device)
230 | outputs = model.generate(init, window_size, events=events[:-1], controls=controls,
231 | teacher_forcing_ratio=teacher_forcing_ratio, output_type='logit')
232 | assert outputs.shape[:2] == events.shape[:2]
233 |
234 | loss = loss_function(outputs.view(-1, event_dim), events.view(-1))
235 | model.zero_grad()
236 | loss.backward()
237 |
238 | norm = utils.compute_gradient_norm(model.parameters())
239 | nn.utils.clip_grad_norm_(model.parameters(), 1.0)
240 |
241 | optimizer.step()
242 |
243 | if enable_logging:
244 | writer.add_scalar('model/loss', loss.item(), iteration)
245 | writer.add_scalar('model/norm', norm.item(), iteration)
246 |
247 | print(f'iter {iteration}, loss: {loss.item()}')
248 |
249 | if time.time() - last_saving_time > saving_interval:
250 | save_model()
251 | last_saving_time = time.time()
252 |
253 | except KeyboardInterrupt:
254 | save_model()
255 |
--------------------------------------------------------------------------------
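The loss above treats the whole [steps, batch] grid as one classification
batch: logits are flattened to [steps * batch, event_dim] and targets to
[steps * batch]. A self-contained sketch of the same shape manipulation on
dummy tensors (240 matches the default EventSeq.dim()):

import torch
from torch import nn

steps, batch, event_dim = 200, 64, 240
outputs = torch.randn(steps, batch, event_dim)        # stand-in model logits
events = torch.randint(0, event_dim, (steps, batch))  # target event indices

loss = nn.CrossEntropyLoss()(outputs.view(-1, event_dim), events.view(-1))
print(loss.item())  # about log(240) = 5.48 for random logits
--------------------------------------------------------------------------------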
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from sequence import EventSeq, ControlSeq
4 |
5 |
6 | def find_files_by_extensions(root, exts=None):
7 | def _has_ext(name):
8 | if not exts:
9 | return True
10 | name = name.lower()
11 | for ext in exts:
12 | if name.endswith(ext):
13 | return True
14 | return False
15 | for path, _, files in os.walk(root):
16 | for name in files:
17 | if _has_ext(name):
18 | yield os.path.join(path, name)
19 |
20 | def event_indeces_to_midi_file(event_indeces, midi_file_name, velocity_scale=0.8):
21 | event_seq = EventSeq.from_array(event_indeces)
22 | note_seq = event_seq.to_note_seq()
23 | for note in note_seq.notes:
24 | note.velocity = int((note.velocity - 64) * velocity_scale + 64)
25 | note_seq.to_midi_file(midi_file_name)
26 | return len(note_seq.notes)
27 |
28 | def transposition(events, controls, offset=0):
29 | # events [steps, batch_size, event_dim]
30 | # return events, controls
31 |
32 | events = np.array(events, dtype=np.int64)
33 | controls = np.array(controls, dtype=np.float32)
34 | event_feat_ranges = EventSeq.feat_ranges()
35 |
36 | on = event_feat_ranges['note_on']
37 | off = event_feat_ranges['note_off']
38 |
39 | if offset > 0:
40 |         indices0 = (((on.start <= events) & (events < on.stop - offset)) |
41 |                     ((off.start <= events) & (events < off.stop - offset)))
42 |         indices1 = (((on.stop - offset <= events) & (events < on.stop)) |
43 |                     ((off.stop - offset <= events) & (events < off.stop)))
44 |         events[indices0] += offset
45 |         events[indices1] += offset - 12  # wrap overflow down an octave
46 |     elif offset < 0:
47 |         indices0 = (((on.start - offset <= events) & (events < on.stop)) |
48 |                     ((off.start - offset <= events) & (events < off.stop)))
49 |         indices1 = (((on.start <= events) & (events < on.start - offset)) |
50 |                     ((off.start <= events) & (events < off.start - offset)))
51 |         events[indices0] += offset
52 |         events[indices1] += offset + 12  # wrap underflow up an octave
53 |
54 | assert ((0 <= events) & (events < EventSeq.dim())).all()
55 | histr = ControlSeq.feat_ranges()['pitch_histogram']
56 | controls[:, :, histr.start:histr.stop] = np.roll(
57 | controls[:, :, histr.start:histr.stop], offset, -1)
58 |
59 | return events, controls
60 |
61 | def dict2params(d, f=','):
62 | return f.join(f'{k}={v}' for k, v in d.items())
63 |
64 | def params2dict(p, f=',', e='='):
65 |     import ast  # literal_eval is safer than eval for CLI input
66 |     d = {}
67 |     for item in p.split(f):
68 |         item = item.split(e)
69 |         if len(item) >= 2:
70 |             k, *v = item
71 |             d[k] = ast.literal_eval(e.join(v))
72 |     return d
73 |
74 | def compute_gradient_norm(parameters, norm_type=2):
75 |     total_norm = 0
76 |     for p in parameters:
77 |         if p.grad is None:  # parameters unused in the graph have no grad
78 |             continue
79 |         total_norm += p.grad.data.norm(norm_type) ** norm_type
80 |     return total_norm ** (1. / norm_type)
81 |
--------------------------------------------------------------------------------
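A toy check of transposition(), assuming the default ranges in sequence.py
(note_on occupies indices 0-87, so index 87 is the highest pitch, 108; index
239 is the last time_shift event):

import numpy as np
from sequence import ControlSeq
from utils import transposition

events = np.array([[87], [239]])             # [steps, batch] grid of indices
controls = np.zeros([2, 1, ControlSeq.dim()])

events, controls = transposition(events, controls, offset=2)
print(events[:, 0])  # [77 239]: the note wrapped down an octave, 108+2-12=98
--------------------------------------------------------------------------------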