├── .gitignore
├── DatasetManager
│   ├── __init__.py
│   ├── chorale_dataset.py
│   ├── dataset_manager.py
│   ├── helpers.py
│   ├── metadata.py
│   └── music_dataset.py
├── DeepBach
│   ├── __init__.py
│   ├── data_utils.py
│   ├── helpers.py
│   ├── metadata.py
│   ├── model_manager.py
│   └── voice_model.py
├── Dockerfile
├── LICENSE
├── README.md
├── cog.yaml
├── deepBach.py
├── deepBachMuseScore.qml
├── dl_dataset_and_models.sh
├── entrypoint.sh
├── environment.yml
├── flask_server.py
├── musescore_flask_server.py
└── predict.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | # generated score files and editor settings
6 | _recomposed_by_deepBach.xml
7 | deepbach.xml
8 | container.xml
9 | .vscode/
10 | settings.json
11 | launch.json
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | env/
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 | app.log.*
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # IPython Notebook
78 | .ipynb_checkpoints
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # dotenv
87 | .env
88 |
89 | # virtualenv
90 | venv/
91 | ENV/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 |
96 | # Rope project settings
97 | .ropeproject
98 |
99 | # local files
100 | .idea/
101 | stats/
102 | deepbach_ressources/
103 | DeepBach/models/*.yaml
104 | DeepBach/models/*.h5
105 | old_models/
106 | good_model/
107 | generated_examples/
108 | DeepBach/datasets/raw_dataset/*.pickle
109 | DeepBach/datasets/custom_dataset/*.pickle
110 |
111 | *.tar.gz
112 | !download_pretrained_data.sh
113 | DatasetManager/dataset_cache/
114 | models/
--------------------------------------------------------------------------------
/DatasetManager/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ghadjeres/DeepBach/6d75cb940f3aa53e02f9eade34d58e472e0c95d7/DatasetManager/__init__.py
--------------------------------------------------------------------------------
/DatasetManager/chorale_dataset.py:
--------------------------------------------------------------------------------
1 | import music21
2 | import torch
3 | import numpy as np
4 |
5 | from music21 import interval, stream
6 | from torch.utils.data import TensorDataset
7 | from tqdm import tqdm
8 |
9 | from DatasetManager.helpers import standard_name, SLUR_SYMBOL, START_SYMBOL, END_SYMBOL, \
10 | standard_note, OUT_OF_RANGE, REST_SYMBOL
11 | from DatasetManager.metadata import FermataMetadata
12 | from DatasetManager.music_dataset import MusicDataset
13 |
14 |
15 | class ChoraleDataset(MusicDataset):
16 | """
17 | Class for all chorale-like datasets
18 | """
19 |
20 | def __init__(self,
21 | corpus_it_gen,
22 | name,
23 | voice_ids,
24 | metadatas=None,
25 | sequences_size=8,
26 | subdivision=4,
27 | cache_dir=None):
28 | """
29 | :param corpus_it_gen: calling this function returns an iterator
30 | over chorales (as music21 scores)
31 | :param name: name of the dataset
32 | :param voice_ids: list of voice_indexes to be used
33 | :param metadatas: list[Metadata], the list of used metadatas
34 | :param sequences_size: in beats
35 | :param subdivision: number of sixteenth notes per beat
36 | :param cache_dir: directory where tensor_dataset is stored
37 | """
38 | super(ChoraleDataset, self).__init__(cache_dir=cache_dir)
39 | self.voice_ids = voice_ids
40 | # TODO WARNING voice_ids is never used!
41 | self.num_voices = len(voice_ids)
42 | self.name = name
43 | self.sequences_size = sequences_size
44 | self.index2note_dicts = None
45 | self.note2index_dicts = None
46 | self.corpus_it_gen = corpus_it_gen
47 | self.voice_ranges = None # in midi pitch
48 | self.metadatas = metadatas
49 | self.subdivision = subdivision
50 |
51 | def __repr__(self):
52 | return f'ChoraleDataset(' \
53 | f'{self.voice_ids},' \
54 | f'{self.name},' \
55 | f'{[metadata.name for metadata in self.metadatas]},' \
56 | f'{self.sequences_size},' \
57 | f'{self.subdivision})'
58 |
59 | def iterator_gen(self):
60 | return (chorale
61 | for chorale in self.corpus_it_gen()
62 | if self.is_valid(chorale)
63 | )
64 |
65 | def make_tensor_dataset(self):
66 | """
67 |         Implementation of the abstract method make_tensor_dataset
68 | """
69 | # todo check on chorale with Chord
70 | print('Making tensor dataset')
71 | self.compute_index_dicts()
72 | self.compute_voice_ranges()
73 | one_tick = 1 / self.subdivision
74 | chorale_tensor_dataset = []
75 | metadata_tensor_dataset = []
76 | for chorale_id, chorale in tqdm(enumerate(self.iterator_gen())):
77 |
78 | # precompute all possible transpositions and corresponding metadatas
79 | chorale_transpositions = {}
80 | metadatas_transpositions = {}
81 |
82 | # main loop
83 | for offsetStart in np.arange(
84 | chorale.flat.lowestOffset -
85 | (self.sequences_size - one_tick),
86 | chorale.flat.highestOffset,
87 | one_tick):
88 | offsetEnd = offsetStart + self.sequences_size
89 | current_subseq_ranges = self.voice_range_in_subsequence(
90 | chorale,
91 | offsetStart=offsetStart,
92 | offsetEnd=offsetEnd)
93 |
94 | transposition = self.min_max_transposition(current_subseq_ranges)
95 | min_transposition_subsequence, max_transposition_subsequence = transposition
96 |
97 | for semi_tone in range(min_transposition_subsequence,
98 | max_transposition_subsequence + 1):
99 | start_tick = int(offsetStart * self.subdivision)
100 | end_tick = int(offsetEnd * self.subdivision)
101 |
102 | try:
103 | # compute transpositions lazily
104 | if semi_tone not in chorale_transpositions:
105 | (chorale_tensor,
106 | metadata_tensor) = self.transposed_score_and_metadata_tensors(
107 | chorale,
108 | semi_tone=semi_tone)
109 | chorale_transpositions.update(
110 | {semi_tone:
111 | chorale_tensor})
112 | metadatas_transpositions.update(
113 | {semi_tone:
114 | metadata_tensor})
115 | else:
116 | chorale_tensor = chorale_transpositions[semi_tone]
117 | metadata_tensor = metadatas_transpositions[semi_tone]
118 |
119 | local_chorale_tensor = self.extract_score_tensor_with_padding(
120 | chorale_tensor,
121 | start_tick, end_tick)
122 | local_metadata_tensor = self.extract_metadata_with_padding(
123 | metadata_tensor,
124 | start_tick, end_tick)
125 |
126 | # append and add batch dimension
127 | # cast to int
128 | chorale_tensor_dataset.append(
129 | local_chorale_tensor[None, :, :].int())
130 | metadata_tensor_dataset.append(
131 | local_metadata_tensor[None, :, :, :].int())
132 | except KeyError:
133 | # some problems may occur with the key analyzer
134 | print(f'KeyError with chorale {chorale_id}')
135 |
136 | chorale_tensor_dataset = torch.cat(chorale_tensor_dataset, 0)
137 | metadata_tensor_dataset = torch.cat(metadata_tensor_dataset, 0)
138 |
139 | dataset = TensorDataset(chorale_tensor_dataset,
140 | metadata_tensor_dataset)
141 |
142 | print(f'Sizes: {chorale_tensor_dataset.size()}, {metadata_tensor_dataset.size()}')
143 | return dataset
144 |
145 | def transposed_score_and_metadata_tensors(self, score, semi_tone):
146 | """
147 |         Convert a chorale into a pair (chorale_tensor, metadata_tensor);
148 |         the original chorale is transposed by semi_tone semitones
149 |         :param score: music21 score object
150 |         :param semi_tone: number of semitones to transpose by
151 |         :return: pair of tensors
152 | """
153 | # transpose
154 | # compute the most "natural" interval given a number of semi-tones
155 | interval_type, interval_nature = interval.convertSemitoneToSpecifierGeneric(
156 | semi_tone)
157 | transposition_interval = interval.Interval(
158 | str(interval_nature) + str(interval_type))
159 |
160 |         chorale_transposed = score.transpose(transposition_interval)
161 |         chorale_tensor = self.get_score_tensor(
162 |             chorale_transposed,
163 |             offsetStart=0.,
164 |             offsetEnd=chorale_transposed.flat.highestTime)
165 |         metadatas_transposed = self.get_metadata_tensor(chorale_transposed)
166 | return chorale_tensor, metadatas_transposed
167 |
168 | def get_metadata_tensor(self, score):
169 | """
170 |         Also appends the voice index as an additional metadata
171 |         :param score: music21 stream
172 |         :return: tensor (num_voices, chorale_length, len(self.metadatas) + 1)
173 | """
174 | md = []
175 | if self.metadatas:
176 | for metadata in self.metadatas:
177 | sequence_metadata = torch.from_numpy(
178 | metadata.evaluate(score, self.subdivision)).long().clone()
179 | square_metadata = sequence_metadata.repeat(self.num_voices, 1)
180 | md.append(
181 | square_metadata[:, :, None]
182 | )
183 | chorale_length = int(score.duration.quarterLength * self.subdivision)
184 |
185 | # add voice indexes
186 | voice_id_metada = torch.from_numpy(np.arange(self.num_voices)).long().clone()
187 | square_metadata = torch.transpose(voice_id_metada.repeat(chorale_length, 1),
188 | 0, 1)
189 | md.append(square_metadata[:, :, None])
190 |
191 | all_metadata = torch.cat(md, 2)
192 | return all_metadata
193 |
194 | def set_fermatas(self, metadata_tensor, fermata_tensor):
195 | """
196 | Impose fermatas for all chorales in a batch
197 | :param metadata_tensor: a (batch_size, sequences_size, num_metadatas)
198 | tensor
199 | :param fermata_tensor: a (sequences_size) binary tensor
200 | """
201 | if self.metadatas:
202 | for metadata_index, metadata in enumerate(self.metadatas):
203 | if isinstance(metadata, FermataMetadata):
204 | # uses broadcasting
205 | metadata_tensor[:, :, metadata_index] = fermata_tensor
206 | break
207 | return metadata_tensor
208 |
209 | def add_fermata(self, metadata_tensor, time_index_start, time_index_stop):
210 | """
211 | Shorthand function to impose a fermata between two time indexes
212 | """
213 | fermata_tensor = torch.zeros(self.sequences_size)
214 | fermata_tensor[time_index_start:time_index_stop] = 1
215 | metadata_tensor = self.set_fermatas(metadata_tensor, fermata_tensor)
216 | return metadata_tensor
217 |
218 | def min_max_transposition(self, current_subseq_ranges):
219 | if current_subseq_ranges is None:
220 | # todo might be too restrictive
221 | # there is no note in one part
222 | transposition = (0, 0) # min and max transpositions
223 | else:
224 | transpositions = [
225 | (min_pitch_corpus - min_pitch_current,
226 | max_pitch_corpus - max_pitch_current)
227 | for ((min_pitch_corpus, max_pitch_corpus),
228 | (min_pitch_current, max_pitch_current))
229 | in zip(self.voice_ranges, current_subseq_ranges)
230 | ]
231 | transpositions = [min_or_max_transposition
232 | for min_or_max_transposition in zip(*transpositions)]
233 | transposition = [max(transpositions[0]),
234 | min(transpositions[1])]
235 | return transposition
236 |
237 | def get_score_tensor(self, score, offsetStart, offsetEnd):
238 | chorale_tensor = []
239 | for part_id, part in enumerate(score.parts[:self.num_voices]):
240 | part_tensor = self.part_to_tensor(part, part_id,
241 | offsetStart=offsetStart,
242 | offsetEnd=offsetEnd)
243 | chorale_tensor.append(part_tensor)
244 | return torch.cat(chorale_tensor, 0)
245 |
246 | def part_to_tensor(self, part, part_id, offsetStart, offsetEnd):
247 | """
248 | :param part:
249 | :param part_id:
250 | :param offsetStart:
251 | :param offsetEnd:
252 | :return: torch IntTensor (1, length)
253 | """
254 | list_notes_and_rests = list(part.flat.getElementsByOffset(
255 | offsetStart=offsetStart,
256 | offsetEnd=offsetEnd,
257 | classList=[music21.note.Note,
258 | music21.note.Rest]))
259 | list_note_strings_and_pitches = [(n.nameWithOctave, n.pitch.midi)
260 | for n in list_notes_and_rests
261 | if n.isNote]
262 | length = int((offsetEnd - offsetStart) * self.subdivision) # in ticks
263 |
264 | # add entries to dictionaries if not present
265 | # should only be called by make_dataset when transposing
266 | note2index = self.note2index_dicts[part_id]
267 | index2note = self.index2note_dicts[part_id]
268 | voice_range = self.voice_ranges[part_id]
269 | min_pitch, max_pitch = voice_range
270 | for note_name, pitch in list_note_strings_and_pitches:
271 | # if out of range
272 | if pitch < min_pitch or pitch > max_pitch:
273 | note_name = OUT_OF_RANGE
274 |
275 | if note_name not in note2index:
276 | new_index = len(note2index)
277 | index2note.update({new_index: note_name})
278 | note2index.update({note_name: new_index})
279 | print('Warning: Entry ' + str(
280 | {new_index: note_name}) + ' added to dictionaries')
281 |
282 | # construct sequence
283 | j = 0
284 | i = 0
285 | t = np.zeros((length, 2))
286 | is_articulated = True
287 | num_notes = len(list_notes_and_rests)
288 | while i < length:
289 | if j < num_notes - 1:
290 | if (list_notes_and_rests[j + 1].offset > i
291 | / self.subdivision + offsetStart):
292 | t[i, :] = [note2index[standard_name(list_notes_and_rests[j],
293 | voice_range=voice_range)],
294 | is_articulated]
295 | i += 1
296 | is_articulated = False
297 | else:
298 | j += 1
299 | is_articulated = True
300 | else:
301 | t[i, :] = [note2index[standard_name(list_notes_and_rests[j],
302 | voice_range=voice_range)],
303 | is_articulated]
304 | i += 1
305 | is_articulated = False
306 | seq = t[:, 0] * t[:, 1] + (1 - t[:, 1]) * note2index[SLUR_SYMBOL]
307 | tensor = torch.from_numpy(seq).long()[None, :]
308 | return tensor
309 |
310 | def voice_range_in_subsequence(self, chorale, offsetStart, offsetEnd):
311 | """
312 | returns None if no note present in one of the voices -> no transposition
313 | :param chorale:
314 | :param offsetStart:
315 | :param offsetEnd:
316 | :return:
317 | """
318 | voice_ranges = []
319 | for part in chorale.parts[:self.num_voices]:
320 | voice_range_part = self.voice_range_in_part(part,
321 | offsetStart=offsetStart,
322 | offsetEnd=offsetEnd)
323 | if voice_range_part is None:
324 | return None
325 | else:
326 | voice_ranges.append(voice_range_part)
327 | return voice_ranges
328 |
329 | def voice_range_in_part(self, part, offsetStart, offsetEnd):
330 | notes_in_subsequence = part.flat.getElementsByOffset(
331 | offsetStart,
332 | offsetEnd,
333 | includeEndBoundary=False,
334 | mustBeginInSpan=True,
335 | mustFinishInSpan=False,
336 | classList=[music21.note.Note,
337 | music21.note.Rest])
338 | midi_pitches_part = [
339 | n.pitch.midi
340 | for n in notes_in_subsequence
341 | if n.isNote
342 | ]
343 | if len(midi_pitches_part) > 0:
344 | return min(midi_pitches_part), max(midi_pitches_part)
345 | else:
346 | return None
347 |
348 | def compute_index_dicts(self):
349 | print('Computing index dicts')
350 | self.index2note_dicts = [
351 | {} for _ in range(self.num_voices)
352 | ]
353 | self.note2index_dicts = [
354 | {} for _ in range(self.num_voices)
355 | ]
356 |
357 | # create and add additional symbols
358 | note_sets = [set() for _ in range(self.num_voices)]
359 | for note_set in note_sets:
360 | note_set.add(SLUR_SYMBOL)
361 | note_set.add(START_SYMBOL)
362 | note_set.add(END_SYMBOL)
363 | note_set.add(REST_SYMBOL)
364 |
365 | # get all notes: used for computing pitch ranges
366 | for chorale in tqdm(self.iterator_gen()):
367 | for part_id, part in enumerate(chorale.parts[:self.num_voices]):
368 | for n in part.flat.notesAndRests:
369 | note_sets[part_id].add(standard_name(n))
370 |
371 | # create tables
372 | for note_set, index2note, note2index in zip(note_sets,
373 | self.index2note_dicts,
374 | self.note2index_dicts):
375 | for note_index, note in enumerate(note_set):
376 | index2note.update({note_index: note})
377 | note2index.update({note: note_index})
378 |
379 | def is_valid(self, chorale):
380 | # We only consider 4-part chorales
381 | if not len(chorale.parts) == 4:
382 | return False
383 | # todo contains chord
384 | return True
385 |
386 | def compute_voice_ranges(self):
387 | assert self.index2note_dicts is not None
388 | assert self.note2index_dicts is not None
389 | self.voice_ranges = []
390 | print('Computing voice ranges')
391 | for voice_index, note2index in tqdm(enumerate(self.note2index_dicts)):
392 | notes = [
393 | standard_note(note_string)
394 | for note_string in note2index
395 | ]
396 | midi_pitches = [
397 | n.pitch.midi
398 | for n in notes
399 | if n.isNote
400 | ]
401 | min_midi, max_midi = min(midi_pitches), max(midi_pitches)
402 | self.voice_ranges.append((min_midi, max_midi))
403 |
404 | def extract_score_tensor_with_padding(self, tensor_score, start_tick, end_tick):
405 | """
406 |         :param tensor_score: (num_voices, length in ticks)
407 |         :param start_tick:
408 |         :param end_tick:
409 |         :return: tensor_score[:, start_tick: end_tick]
410 |         with padding if necessary
411 |         i.e. if start_tick < 0 or end_tick > tensor_score length
412 | """
413 | assert start_tick < end_tick
414 | assert end_tick > 0
415 | length = tensor_score.size()[1]
416 |
417 | padded_chorale = []
418 | # todo add PAD_SYMBOL
419 | if start_tick < 0:
420 | start_symbols = np.array([note2index[START_SYMBOL]
421 | for note2index in self.note2index_dicts])
422 | start_symbols = torch.from_numpy(start_symbols).long().clone()
423 | start_symbols = start_symbols.repeat(-start_tick, 1).transpose(0, 1)
424 | padded_chorale.append(start_symbols)
425 |
426 | slice_start = start_tick if start_tick > 0 else 0
427 | slice_end = end_tick if end_tick < length else length
428 |
429 | padded_chorale.append(tensor_score[:, slice_start: slice_end])
430 |
431 | if end_tick > length:
432 | end_symbols = np.array([note2index[END_SYMBOL]
433 | for note2index in self.note2index_dicts])
434 | end_symbols = torch.from_numpy(end_symbols).long().clone()
435 | end_symbols = end_symbols.repeat(end_tick - length, 1).transpose(0, 1)
436 | padded_chorale.append(end_symbols)
437 |
438 | padded_chorale = torch.cat(padded_chorale, 1)
439 | return padded_chorale
440 |
441 | def extract_metadata_with_padding(self, tensor_metadata,
442 | start_tick, end_tick):
443 | """
444 | :param tensor_metadata: (num_voices, length, num_metadatas)
445 | last metadata is the voice_index
446 | :param start_tick:
447 | :param end_tick:
448 | :return:
449 | """
450 | assert start_tick < end_tick
451 | assert end_tick > 0
452 | num_voices, length, num_metadatas = tensor_metadata.size()
453 | padded_tensor_metadata = []
454 |
455 | if start_tick < 0:
456 | # TODO more subtle padding
457 | start_symbols = np.zeros((self.num_voices, -start_tick, num_metadatas))
458 | start_symbols = torch.from_numpy(start_symbols).long().clone()
459 | padded_tensor_metadata.append(start_symbols)
460 |
461 | slice_start = start_tick if start_tick > 0 else 0
462 | slice_end = end_tick if end_tick < length else length
463 | padded_tensor_metadata.append(tensor_metadata[:, slice_start: slice_end, :])
464 |
465 | if end_tick > length:
466 | end_symbols = np.zeros((self.num_voices, end_tick - length, num_metadatas))
467 | end_symbols = torch.from_numpy(end_symbols).long().clone()
468 | padded_tensor_metadata.append(end_symbols)
469 |
470 | padded_tensor_metadata = torch.cat(padded_tensor_metadata, 1)
471 | return padded_tensor_metadata
472 |
473 | def empty_score_tensor(self, score_length):
474 | start_symbols = np.array([note2index[START_SYMBOL]
475 | for note2index in self.note2index_dicts])
476 | start_symbols = torch.from_numpy(start_symbols).long().clone()
477 | start_symbols = start_symbols.repeat(score_length, 1).transpose(0, 1)
478 | return start_symbols
479 |
480 | def random_score_tensor(self, score_length):
481 | chorale_tensor = np.array(
482 | [np.random.randint(len(note2index),
483 | size=score_length)
484 | for note2index in self.note2index_dicts])
485 | chorale_tensor = torch.from_numpy(chorale_tensor).long().clone()
486 | return chorale_tensor
487 |
488 | def tensor_to_score(self, tensor_score,
489 | fermata_tensor=None):
490 | """
491 | :param tensor_score: (num_voices, length)
492 | :return: music21 score object
493 | """
494 | slur_indexes = [note2index[SLUR_SYMBOL]
495 | for note2index in self.note2index_dicts]
496 |
497 | score = music21.stream.Score()
498 | num_voices = tensor_score.size(0)
499 | name_parts = (num_voices == 4)
500 | part_names = ['Soprano', 'Alto', 'Tenor', 'Bass']
501 |
502 |
503 | for voice_index, (voice, index2note, slur_index) in enumerate(
504 | zip(tensor_score,
505 | self.index2note_dicts,
506 | slur_indexes)):
507 | add_fermata = False
508 | if name_parts:
509 | part = stream.Part(id=part_names[voice_index],
510 | partName=part_names[voice_index],
511 | partAbbreviation=part_names[voice_index],
512 | instrumentName=part_names[voice_index])
513 | else:
514 | part = stream.Part(id='part' + str(voice_index))
515 | dur = 0
516 | total_duration = 0
517 | f = music21.note.Rest()
518 | for note_index in [n.item() for n in voice]:
519 | # if it is a played note
520 | if not note_index == slur_indexes[voice_index]:
521 | # add previous note
522 | if dur > 0:
523 | f.duration = music21.duration.Duration(dur / self.subdivision)
524 |
525 | if add_fermata:
526 | f.expressions.append(music21.expressions.Fermata())
527 | add_fermata = False
528 |
529 | part.append(f)
530 |
531 |
532 | dur = 1
533 | f = standard_note(index2note[note_index])
534 | if fermata_tensor is not None and voice_index == 0:
535 | if fermata_tensor[0, total_duration] == 1:
536 | add_fermata = True
537 | else:
538 | add_fermata = False
539 | total_duration += 1
540 |
541 | else:
542 | dur += 1
543 | total_duration += 1
544 | # add last note
545 | f.duration = music21.duration.Duration(dur / self.subdivision)
546 | if add_fermata:
547 | f.expressions.append(music21.expressions.Fermata())
548 | add_fermata = False
549 |
550 | part.append(f)
551 | score.insert(part)
552 | return score
553 |
554 |
555 | # TODO should go in ChoraleDataset
556 | # TODO all subsequences start on a beat
557 | class ChoraleBeatsDataset(ChoraleDataset):
558 | def __repr__(self):
559 | return f'ChoraleBeatsDataset(' \
560 | f'{self.voice_ids},' \
561 | f'{self.name},' \
562 | f'{[metadata.name for metadata in self.metadatas]},' \
563 | f'{self.sequences_size},' \
564 | f'{self.subdivision})'
565 |
566 | def make_tensor_dataset(self):
567 | """
568 |         Implementation of the abstract method make_tensor_dataset
569 | """
570 | # todo check on chorale with Chord
571 | print('Making tensor dataset')
572 | self.compute_index_dicts()
573 | self.compute_voice_ranges()
574 | one_beat = 1.
575 | chorale_tensor_dataset = []
576 | metadata_tensor_dataset = []
577 | for chorale_id, chorale in tqdm(enumerate(self.iterator_gen())):
578 |
579 | # precompute all possible transpositions and corresponding metadatas
580 | chorale_transpositions = {}
581 | metadatas_transpositions = {}
582 |
583 | # main loop
584 | for offsetStart in np.arange(
585 | chorale.flat.lowestOffset -
586 | (self.sequences_size - one_beat),
587 | chorale.flat.highestOffset,
588 | one_beat):
589 | offsetEnd = offsetStart + self.sequences_size
590 | current_subseq_ranges = self.voice_range_in_subsequence(
591 | chorale,
592 | offsetStart=offsetStart,
593 | offsetEnd=offsetEnd)
594 |
595 | transposition = self.min_max_transposition(current_subseq_ranges)
596 | min_transposition_subsequence, max_transposition_subsequence = transposition
597 |
598 | for semi_tone in range(min_transposition_subsequence,
599 | max_transposition_subsequence + 1):
600 | start_tick = int(offsetStart * self.subdivision)
601 | end_tick = int(offsetEnd * self.subdivision)
602 |
603 | try:
604 | # compute transpositions lazily
605 | if semi_tone not in chorale_transpositions:
606 | (chorale_tensor,
607 | metadata_tensor) = self.transposed_score_and_metadata_tensors(
608 | chorale,
609 | semi_tone=semi_tone)
610 | chorale_transpositions.update(
611 | {semi_tone:
612 | chorale_tensor})
613 | metadatas_transpositions.update(
614 | {semi_tone:
615 | metadata_tensor})
616 | else:
617 | chorale_tensor = chorale_transpositions[semi_tone]
618 | metadata_tensor = metadatas_transpositions[semi_tone]
619 |
620 | local_chorale_tensor = self.extract_score_tensor_with_padding(
621 | chorale_tensor,
622 | start_tick, end_tick)
623 | local_metadata_tensor = self.extract_metadata_with_padding(
624 | metadata_tensor,
625 | start_tick, end_tick)
626 |
627 | # append and add batch dimension
628 | # cast to int
629 | chorale_tensor_dataset.append(
630 | local_chorale_tensor[None, :, :].int())
631 | metadata_tensor_dataset.append(
632 | local_metadata_tensor[None, :, :, :].int())
633 | except KeyError:
634 | # some problems may occur with the key analyzer
635 | print(f'KeyError with chorale {chorale_id}')
636 |
637 | chorale_tensor_dataset = torch.cat(chorale_tensor_dataset, 0)
638 | metadata_tensor_dataset = torch.cat(metadata_tensor_dataset, 0)
639 |
640 | dataset = TensorDataset(chorale_tensor_dataset,
641 | metadata_tensor_dataset)
642 |
643 | print(f'Sizes: {chorale_tensor_dataset.size()}, {metadata_tensor_dataset.size()}')
644 | return dataset
645 |
--------------------------------------------------------------------------------
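Illustrative note on the encoding used by ChoraleDataset.part_to_tensor above: each tick stores a note index plus an articulation flag, and ticks where the note is merely held are rewritten to the SLUR_SYMBOL index. The sketch below (the dictionary contents are made up for the example, not taken from the repository) reproduces that final step with plain numpy:

import numpy as np

note2index = {'C4': 0, 'D4': 1, '__': 2}  # '__' is SLUR_SYMBOL
slur_index = note2index['__']

# column 0: note index per tick, column 1: 1 if the note is attacked on that tick
t = np.array([[0, 1],   # C4 attacked
              [0, 0],   # C4 held
              [1, 1],   # D4 attacked
              [1, 0]])  # D4 held

seq = t[:, 0] * t[:, 1] + (1 - t[:, 1]) * slur_index
print(seq)  # [0 2 1 2] -> C4, __, D4, __
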
/DatasetManager/dataset_manager.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import music21
4 | import torch
5 | from DatasetManager.chorale_dataset import ChoraleDataset
6 | from DatasetManager.helpers import ShortChoraleIteratorGen
7 | from DatasetManager.metadata import TickMetadata, \
8 | FermataMetadata, \
9 | KeyMetadata
10 | from DatasetManager.music_dataset import MusicDataset
11 |
12 | # To make a dataset available, add an entry to the all_datasets variable:
13 | # specify the dataset class to instantiate and a corpus_it_gen callable
14 | # that returns an iterator over music21 scores
15 | # (a usage sketch follows this file)
16 |
17 | all_datasets = {
18 | 'bach_chorales':
19 | {
20 | 'dataset_class_name': ChoraleDataset,
21 | 'corpus_it_gen': music21.corpus.chorales.Iterator
22 | },
23 | 'bach_chorales_test':
24 | {
25 | 'dataset_class_name': ChoraleDataset,
26 | 'corpus_it_gen': ShortChoraleIteratorGen()
27 | },
28 | }
29 |
30 |
31 | class DatasetManager:
32 | def __init__(self):
33 | self.package_dir = os.path.dirname(os.path.realpath(__file__))
34 | self.cache_dir = os.path.join(self.package_dir,
35 | 'dataset_cache')
36 | # create cache dir if it doesn't exist
37 | if not os.path.exists(self.cache_dir):
38 | os.mkdir(self.cache_dir)
39 |
40 | def get_dataset(self, name: str, **dataset_kwargs) -> MusicDataset:
41 | if name in all_datasets:
42 | return self.load_if_exists_or_initialize_and_save(
43 | name=name,
44 | **all_datasets[name],
45 | **dataset_kwargs
46 | )
47 | else:
48 |             print(f'Dataset with name {name} is not registered in the all_datasets variable')
49 | raise ValueError
50 |
51 | def load_if_exists_or_initialize_and_save(self,
52 | dataset_class_name,
53 | corpus_it_gen,
54 | name,
55 | **kwargs):
56 | """
57 |
58 | :param dataset_class_name:
59 | :param corpus_it_gen:
60 | :param name:
61 | :param kwargs: parameters specific to an implementation
62 | of MusicDataset (ChoraleDataset for instance)
63 | :return:
64 | """
65 | kwargs.update(
66 | {'name': name,
67 | 'corpus_it_gen': corpus_it_gen,
68 | 'cache_dir': self.cache_dir
69 | })
70 | dataset = dataset_class_name(**kwargs)
71 | if os.path.exists(dataset.filepath):
72 | print(f'Loading {dataset.__repr__()} from {dataset.filepath}')
73 | dataset = torch.load(dataset.filepath)
74 | dataset.cache_dir = self.cache_dir
75 | print(f'(the corresponding TensorDataset is not loaded)')
76 | else:
77 | print(f'Creating {dataset.__repr__()}, '
78 | f'both tensor dataset and parameters')
79 | # initialize and force the computation of the tensor_dataset
80 | # first remove the cached data if it exists
81 | if os.path.exists(dataset.tensor_dataset_filepath):
82 | os.remove(dataset.tensor_dataset_filepath)
83 | # recompute dataset parameters and tensor_dataset
84 | # this saves the tensor_dataset in dataset.tensor_dataset_filepath
85 | tensor_dataset = dataset.tensor_dataset
86 | # save all dataset parameters EXCEPT the tensor dataset
87 | # which is stored elsewhere
88 | dataset.tensor_dataset = None
89 | torch.save(dataset, dataset.filepath)
90 | print(f'{dataset.__repr__()} saved in {dataset.filepath}')
91 | dataset.tensor_dataset = tensor_dataset
92 | return dataset
93 |
94 |
95 | if __name__ == '__main__':
96 | # Usage example
97 |
98 | dataset_manager = DatasetManager()
99 | subdivision = 4
100 | metadatas = [
101 | TickMetadata(subdivision=subdivision),
102 | FermataMetadata(),
103 | KeyMetadata()
104 | ]
105 |
106 | bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset(
107 | name='bach_chorales_test',
108 | voice_ids=[0, 1, 2, 3],
109 | metadatas=metadatas,
110 | sequences_size=8,
111 | subdivision=subdivision
112 | )
113 | (train_dataloader,
114 | val_dataloader,
115 | test_dataloader) = bach_chorales_dataset.data_loaders(
116 | batch_size=128,
117 | split=(0.85, 0.10)
118 | )
119 | print('Num Train Batches: ', len(train_dataloader))
120 | print('Num Valid Batches: ', len(val_dataloader))
121 | print('Num Test Batches: ', len(test_dataloader))
122 |
--------------------------------------------------------------------------------
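Usage sketch for the registration mechanism described at the top of dataset_manager.py. The corpus name 'my_chorales', the file paths, and the iterator below are hypothetical placeholders, not part of the repository:

import music21
from DatasetManager.chorale_dataset import ChoraleBeatsDataset
from DatasetManager.dataset_manager import DatasetManager, all_datasets

def my_corpus_iterator():
    # must return an iterator over music21 scores
    return (music21.converter.parse(path)
            for path in ['my_chorale_1.xml', 'my_chorale_2.xml'])

all_datasets['my_chorales'] = {
    'dataset_class_name': ChoraleBeatsDataset,
    'corpus_it_gen': my_corpus_iterator,
}

dataset = DatasetManager().get_dataset(
    name='my_chorales',
    voice_ids=[0, 1, 2, 3],
    metadatas=[],
    sequences_size=8,
    subdivision=4,
)
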
/DatasetManager/helpers.py:
--------------------------------------------------------------------------------
1 | import music21
2 | from itertools import islice
3 |
4 | from music21 import note, harmony, expressions
5 |
6 | # constants
7 | SLUR_SYMBOL = '__'
8 | START_SYMBOL = 'START'
9 | END_SYMBOL = 'END'
10 | REST_SYMBOL = 'rest'
11 | OUT_OF_RANGE = 'OOR'
12 | PAD_SYMBOL = 'XX'
13 |
14 |
15 | def standard_name(note_or_rest, voice_range=None):
16 | """
17 |     Convert a music21 object to its string name
18 |     :param note_or_rest: music21 Note, Rest, ChordSymbol, TextExpression or str
19 |     :return: str; notes outside voice_range (if given) map to OUT_OF_RANGE
20 | """
21 | if isinstance(note_or_rest, note.Note):
22 | if voice_range is not None:
23 | min_pitch, max_pitch = voice_range
24 | pitch = note_or_rest.pitch.midi
25 | if pitch < min_pitch or pitch > max_pitch:
26 | return OUT_OF_RANGE
27 | return note_or_rest.nameWithOctave
28 | if isinstance(note_or_rest, note.Rest):
29 | return note_or_rest.name # == 'rest' := REST_SYMBOL
30 | if isinstance(note_or_rest, str):
31 | return note_or_rest
32 |
33 | if isinstance(note_or_rest, harmony.ChordSymbol):
34 | return note_or_rest.figure
35 | if isinstance(note_or_rest, expressions.TextExpression):
36 | return note_or_rest.content
37 |
38 |
39 | def standard_note(note_or_rest_string):
40 | """
41 | Convert str representing a music21 object to this object
42 | :param note_or_rest_string:
43 | :return:
44 | """
45 | if note_or_rest_string == 'rest':
46 | return note.Rest()
47 | # treat other additional symbols as rests
48 | elif (note_or_rest_string == END_SYMBOL
49 | or
50 | note_or_rest_string == START_SYMBOL
51 | or
52 | note_or_rest_string == PAD_SYMBOL):
53 | # print('Warning: Special symbol is used in standard_note')
54 | return note.Rest()
55 | elif note_or_rest_string == SLUR_SYMBOL:
56 | # print('Warning: SLUR_SYMBOL used in standard_note')
57 | return note.Rest()
58 | elif note_or_rest_string == OUT_OF_RANGE:
59 | # print('Warning: OUT_OF_RANGE used in standard_note')
60 | return note.Rest()
61 | else:
62 | return note.Note(note_or_rest_string)
63 |
64 |
65 | class ShortChoraleIteratorGen:
66 | """
67 |     Class used for debugging:
68 |     when called, it returns an iterator over 3 Bach chorales,
69 | similar to music21.corpus.chorales.Iterator()
70 | """
71 |
72 | def __init__(self):
73 | pass
74 |
75 | def __call__(self):
76 | it = (
77 | chorale
78 | for chorale in
79 | islice(music21.corpus.chorales.Iterator(), 3)
80 | )
81 | return it.__iter__()
82 |
--------------------------------------------------------------------------------
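A small round-trip sketch for the helpers above (assumes music21 is installed; the expected outputs follow from the code in this file):

from music21 import note
from DatasetManager.helpers import standard_name, standard_note, SLUR_SYMBOL

n = note.Note('C#4')
print(standard_name(n))                     # 'C#4'
print(standard_note('C#4').nameWithOctave)  # 'C#4'

print(standard_name(note.Rest()))           # 'rest' (== REST_SYMBOL)
print(standard_note(SLUR_SYMBOL).isRest)    # True: special symbols become rests
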
/DatasetManager/metadata.py:
--------------------------------------------------------------------------------
1 | """
2 | Metadata classes
3 | """
4 | import numpy as np
5 | from music21 import analysis, stream, meter
6 | from DatasetManager.helpers import SLUR_SYMBOL, \
7 | PAD_SYMBOL
8 |
9 |
10 | class Metadata:
11 | def __init__(self):
12 | self.num_values = None
13 | self.is_global = None
14 | self.name = None
15 |
16 | def get_index(self, value):
17 | # trick with the 0 value
18 | raise NotImplementedError
19 |
20 | def get_value(self, index):
21 | raise NotImplementedError
22 |
23 | def evaluate(self, chorale, subdivision):
24 | """
25 | takes a music21 chorale as input and the number of subdivisions per beat
26 | """
27 | raise NotImplementedError
28 |
29 | def generate(self, length):
30 | raise NotImplementedError
31 |
32 |
33 | class IsPlayingMetadata(Metadata):
34 | def __init__(self, voice_index, min_num_ticks):
35 | """
36 | Metadata that indicates if a voice is playing
37 | Voice i is considered to be muted if more than 'min_num_ticks' contiguous
38 | ticks contain a rest.
39 |
40 |
41 | :param voice_index: index of the voice to take into account
42 | :param min_num_ticks: minimum length in ticks for a rest to be taken
43 | into account in the metadata
44 | """
45 | super(IsPlayingMetadata, self).__init__()
46 | self.min_num_ticks = min_num_ticks
47 | self.voice_index = voice_index
48 | self.is_global = False
49 | self.num_values = 2
50 | self.name = 'isplaying'
51 |
52 | def get_index(self, value):
53 | return int(value)
54 |
55 | def get_value(self, index):
56 | return bool(index)
57 |
58 | def evaluate(self, chorale, subdivision):
59 | """
60 | takes a music21 chorale as input
61 | """
62 | length = int(chorale.duration.quarterLength * subdivision)
63 | metadatas = np.ones(shape=(length,))
64 | part = chorale.parts[self.voice_index]
65 |
66 | for note_or_rest in part.notesAndRests:
67 | is_playing = True
68 | if note_or_rest.isRest:
69 | if note_or_rest.quarterLength * subdivision >= self.min_num_ticks:
70 | is_playing = False
71 |             # slice indices must be integers
72 |             start_tick = int(note_or_rest.offset * subdivision)
73 |             end_tick = int(start_tick + note_or_rest.quarterLength * subdivision)
74 | metadatas[start_tick:end_tick] = self.get_index(is_playing)
75 | return metadatas
76 |
77 | def generate(self, length):
78 | return np.ones(shape=(length,))
79 |
80 |
81 | class TickMetadata(Metadata):
82 | """
83 |     Metadata class that tracks which subdivision of the beat we are on
84 | """
85 |
86 | def __init__(self, subdivision):
87 | super(TickMetadata, self).__init__()
88 | self.is_global = False
89 | self.num_values = subdivision
90 | self.name = 'tick'
91 |
92 | def get_index(self, value):
93 | return value
94 |
95 | def get_value(self, index):
96 | return index
97 |
98 | def evaluate(self, chorale, subdivision):
99 | assert subdivision == self.num_values
100 | # suppose all pieces start on a beat
101 | length = int(chorale.duration.quarterLength * subdivision)
102 | return np.array(list(map(
103 | lambda x: x % self.num_values,
104 | range(length)
105 | )))
106 |
107 | def generate(self, length):
108 | return np.array(list(map(
109 | lambda x: x % self.num_values,
110 | range(length)
111 | )))
112 |
113 |
114 | class ModeMetadata(Metadata):
115 | """
116 |     Metadata class that indicates the current mode of the melody;
117 |     the mode can be major, minor or other
118 | """
119 |
120 | def __init__(self):
121 | super(ModeMetadata, self).__init__()
122 | self.is_global = False
123 | self.num_values = 3 # major, minor or other
124 | self.name = 'mode'
125 |
126 | def get_index(self, value):
127 | if value == 'major':
128 | return 1
129 | if value == 'minor':
130 | return 2
131 | return 0
132 |
133 | def get_value(self, index):
134 | if index == 1:
135 | return 'major'
136 | if index == 2:
137 | return 'minor'
138 | return 'other'
139 |
140 | def evaluate(self, chorale, subdivision):
141 | # todo add measures when in midi
142 | # init key analyzer
143 | ka = analysis.floatingKey.KeyAnalyzer(chorale)
144 | res = ka.run()
145 |
146 | measure_offset_map = chorale.parts[0].measureOffsetMap()
147 | length = int(chorale.duration.quarterLength * subdivision) # in 16th notes
148 |
149 | modes = np.zeros((length,))
150 |
151 | measure_index = -1
152 | for time_index in range(length):
153 | beat_index = time_index / subdivision
154 | if beat_index in measure_offset_map:
155 | measure_index += 1
156 | modes[time_index] = self.get_index(res[measure_index].mode)
157 |
158 | return np.array(modes, dtype=np.int32)
159 |
160 | def generate(self, length):
161 | return np.full((length,), self.get_index('major'))
162 |
163 |
164 | class KeyMetadata(Metadata):
165 | """
166 |     Metadata class that indicates in which key we are.
167 |     It only returns the number of sharps or flats
168 |     and does not distinguish a key from its relative key.
169 | """
170 |
171 | def __init__(self, window_size=4):
172 | super(KeyMetadata, self).__init__()
173 | self.window_size = window_size
174 | self.is_global = False
175 | self.num_max_sharps = 7
176 | self.num_values = 16
177 | self.name = 'key'
178 |
179 | def get_index(self, value):
180 | """
181 |
182 | :param value: number of sharps (between -7 and +7)
183 | :return: index in the representation
184 | """
185 | return value + self.num_max_sharps + 1
186 |
187 | def get_value(self, index):
188 | """
189 |
190 | :param index: index (between 0 and self.num_values); 0 is unused (no constraint)
191 | :return: true number of sharps (between -7 and 7)
192 | """
193 | return index - 1 - self.num_max_sharps
194 |
195 | def evaluate(self, chorale, subdivision):
196 | # init key analyzer
197 | # we must add measures by hand for the case when we are parsing midi files
198 | chorale_with_measures = stream.Score()
199 | for part in chorale.parts:
200 | chorale_with_measures.append(part.makeMeasures())
201 |
202 | ka = analysis.floatingKey.KeyAnalyzer(chorale_with_measures)
203 | ka.windowSize = self.window_size
204 | res = ka.run()
205 |
206 | measure_offset_map = chorale_with_measures.parts.measureOffsetMap()
207 | length = int(chorale.duration.quarterLength * subdivision) # in 16th notes
208 |
209 | key_signatures = np.zeros((length,))
210 |
211 | measure_index = -1
212 | for time_index in range(length):
213 | beat_index = time_index / subdivision
214 | if beat_index in measure_offset_map:
215 | measure_index += 1
216 | if measure_index == len(res):
217 | measure_index -= 1
218 |
219 | key_signatures[time_index] = self.get_index(res[measure_index].sharps)
220 | return np.array(key_signatures, dtype=np.int32)
221 |
222 | def generate(self, length):
223 | return np.full((length,), self.get_index(0))
224 |
225 |
226 | class FermataMetadata(Metadata):
227 | """
228 | Metadata class which indicates if a fermata is on the current note
229 | """
230 |
231 | def __init__(self):
232 | super(FermataMetadata, self).__init__()
233 | self.is_global = False
234 | self.num_values = 2
235 | self.name = 'fermata'
236 |
237 | def get_index(self, value):
238 | # possible values are 1 and 0, thus value = index
239 | return value
240 |
241 | def get_value(self, index):
242 | # possible values are 1 and 0, thus value = index
243 | return index
244 |
245 | def evaluate(self, chorale, subdivision):
246 | part = chorale.parts[0]
247 | length = int(part.duration.quarterLength * subdivision) # in 16th notes
248 | list_notes = part.flat.notes
249 | num_notes = len(list_notes)
250 | j = 0
251 | i = 0
252 | fermatas = np.zeros((length,))
253 | while i < length:
254 | if j < num_notes - 1:
255 | if list_notes[j + 1].offset > i / subdivision:
256 |
257 | if len(list_notes[j].expressions) == 1:
258 | fermata = True
259 | else:
260 | fermata = False
261 | fermatas[i] = fermata
262 | i += 1
263 | else:
264 | j += 1
265 | else:
266 | if len(list_notes[j].expressions) == 1:
267 | fermata = True
268 | else:
269 | fermata = False
270 |
271 | fermatas[i] = fermata
272 | i += 1
273 | return np.array(fermatas, dtype=np.int32)
274 |
275 | def generate(self, length):
276 | # fermata every 2 bars
277 | return np.array([1 if i % 32 >= 28 else 0
278 | for i in range(length)])
279 |
--------------------------------------------------------------------------------
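Illustrative check of the KeyMetadata index convention defined above: get_index shifts the number of sharps so that index 0 stays free for the "no constraint" case, and get_value inverts it.

from DatasetManager.metadata import KeyMetadata

km = KeyMetadata()
for sharps in (-7, 0, 7):
    index = km.get_index(sharps)         # -7 -> 1, 0 -> 8, +7 -> 15
    assert km.get_value(index) == sharps
    print(sharps, '->', index)
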
/DatasetManager/music_dataset.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import os
3 | from torch.utils.data import TensorDataset, DataLoader
4 | import torch
5 |
6 |
7 | class MusicDataset(ABC):
8 | """
9 | Abstract Base Class for music datasets
10 | """
11 |
12 | def __init__(self, cache_dir):
13 | self._tensor_dataset = None
14 | self.cache_dir = cache_dir
15 |
16 | @abstractmethod
17 | def iterator_gen(self):
18 | """
19 |
20 | return: Iterator over the dataset
21 | """
22 | pass
23 |
24 | @abstractmethod
25 | def make_tensor_dataset(self):
26 | """
27 |
28 | :return: TensorDataset
29 | """
30 | pass
31 |
32 | @abstractmethod
33 | def get_score_tensor(self, score):
34 | """
35 |
36 | :param score: music21 score object
37 | :return: torch tensor, with the score representation
38 | as a tensor
39 | """
40 | pass
41 |
42 | @abstractmethod
43 | def get_metadata_tensor(self, score):
44 | """
45 |
46 | :param score: music21 score object
47 | :return: torch tensor, with the metadata representation
48 | as a tensor
49 | """
50 | pass
51 |
52 | @abstractmethod
53 | def transposed_score_and_metadata_tensors(self, score, semi_tone):
54 | """
55 |
56 | :param score: music21 score object
57 |         :param semi_tone: int, -12 to +12, number of semitones to transpose by
58 |         :return: couple (score_tensor, metadata_tensor) of the transposed score
59 | """
60 | pass
61 |
62 | @abstractmethod
63 | def extract_score_tensor_with_padding(self,
64 | tensor_score,
65 | start_tick,
66 | end_tick):
67 | """
68 |
69 | :param tensor_score: torch tensor containing the score representation
70 | :param start_tick:
71 | :param end_tick:
72 | :return: tensor_score[:, start_tick: end_tick]
73 | with padding if necessary
74 | i.e. if start_tick < 0 or end_tick > tensor_score length
75 | """
76 | pass
77 |
78 | @abstractmethod
79 | def extract_metadata_with_padding(self,
80 | tensor_metadata,
81 | start_tick,
82 | end_tick):
83 | """
84 |
85 | :param tensor_metadata: torch tensor containing metadata
86 | :param start_tick:
87 | :param end_tick:
88 | :return:
89 | """
90 | pass
91 |
92 | @abstractmethod
93 | def empty_score_tensor(self, score_length):
94 | """
95 |
96 | :param score_length: int, length of the score in ticks
97 | :return: torch long tensor, initialized with start indices
98 | """
99 | pass
100 |
101 | @abstractmethod
102 | def random_score_tensor(self, score_length):
103 | """
104 |
105 | :param score_length: int, length of the score in ticks
106 | :return: torch long tensor, initialized with random indices
107 | """
108 | pass
109 |
110 | @abstractmethod
111 | def tensor_to_score(self, tensor_score):
112 | """
113 |
114 | :param tensor_score: torch tensor, tensor representation
115 | of the score
116 | :return: music21 score object
117 | """
118 | pass
119 |
120 | @property
121 | def tensor_dataset(self):
122 | """
123 | Loads or computes TensorDataset
124 | :return: TensorDataset
125 | """
126 | if self._tensor_dataset is None:
127 | if self.tensor_dataset_is_cached():
128 | print(f'Loading TensorDataset for {self.__repr__()}')
129 | self._tensor_dataset = torch.load(self.tensor_dataset_filepath)
130 | else:
131 | print(f'Creating {self.__repr__()} TensorDataset'
132 | f' since it is not cached')
133 | self._tensor_dataset = self.make_tensor_dataset()
134 | torch.save(self._tensor_dataset, self.tensor_dataset_filepath)
135 | print(f'TensorDataset for {self.__repr__()} '
136 | f'saved in {self.tensor_dataset_filepath}')
137 | return self._tensor_dataset
138 |
139 | @tensor_dataset.setter
140 | def tensor_dataset(self, value):
141 | self._tensor_dataset = value
142 |
143 | def tensor_dataset_is_cached(self):
144 | return os.path.exists(self.tensor_dataset_filepath)
145 |
146 | @property
147 | def tensor_dataset_filepath(self):
148 | tensor_datasets_cache_dir = os.path.join(
149 | self.cache_dir,
150 | 'tensor_datasets')
151 | if not os.path.exists(tensor_datasets_cache_dir):
152 | os.mkdir(tensor_datasets_cache_dir)
153 | fp = os.path.join(
154 | tensor_datasets_cache_dir,
155 | self.__repr__()
156 | )
157 | return fp
158 |
159 | @property
160 | def filepath(self):
161 | tensor_datasets_cache_dir = os.path.join(
162 | self.cache_dir,
163 | 'datasets')
164 | if not os.path.exists(tensor_datasets_cache_dir):
165 | os.mkdir(tensor_datasets_cache_dir)
166 | return os.path.join(
167 | self.cache_dir,
168 | 'datasets',
169 | self.__repr__()
170 | )
171 |
172 | def data_loaders(self, batch_size, split=(0.85, 0.10)):
173 | """
174 | Returns three data loaders obtained by splitting
175 | self.tensor_dataset according to split
176 | :param batch_size:
177 | :param split:
178 | :return:
179 | """
180 | assert sum(split) < 1
181 |
182 | dataset = self.tensor_dataset
183 | num_examples = len(dataset)
184 | a, b = split
185 | train_dataset = TensorDataset(*dataset[: int(a * num_examples)])
186 | val_dataset = TensorDataset(*dataset[int(a * num_examples):
187 | int((a + b) * num_examples)])
188 | eval_dataset = TensorDataset(*dataset[int((a + b) * num_examples):])
189 |
190 | train_dl = DataLoader(
191 | train_dataset,
192 | batch_size=batch_size,
193 | shuffle=True,
194 | num_workers=4,
195 | pin_memory=True,
196 | drop_last=True,
197 | )
198 |
199 | val_dl = DataLoader(
200 | val_dataset,
201 | batch_size=batch_size,
202 | shuffle=False,
203 | num_workers=0,
204 | pin_memory=False,
205 | drop_last=True,
206 | )
207 |
208 | eval_dl = DataLoader(
209 | eval_dataset,
210 | batch_size=batch_size,
211 | shuffle=False,
212 | num_workers=0,
213 | pin_memory=False,
214 | drop_last=True,
215 | )
216 | return train_dl, val_dl, eval_dl
217 |
--------------------------------------------------------------------------------
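A minimal sketch of the contiguous split performed by data_loaders above, using a toy TensorDataset instead of a real chorale dataset (the sizes are illustrative, not from the repository):

import torch
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(torch.arange(100)[:, None], torch.arange(100)[:, None])
a, b = 0.85, 0.10
n = len(dataset)

train_dataset = TensorDataset(*dataset[: int(a * n)])
val_dataset = TensorDataset(*dataset[int(a * n): int((a + b) * n)])
test_dataset = TensorDataset(*dataset[int((a + b) * n):])
print(len(train_dataset), len(val_dataset), len(test_dataset))  # 85 10 5

train_dl = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=True)
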
/DeepBach/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ghadjeres/DeepBach/6d75cb940f3aa53e02f9eade34d58e472e0c95d7/DeepBach/__init__.py
--------------------------------------------------------------------------------
/DeepBach/data_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | @author: Gaetan Hadjeres
5 | """
6 |
7 | import torch
8 | from DeepBach.helpers import cuda_variable
9 |
10 |
11 | def mask_entry(tensor, entry_index, dim):
12 | """
13 |     Removes entry entry_index along dimension dim,
14 |     similar to
15 |     torch.cat((tensor[:entry_index], tensor[entry_index + 1:]), 0)
16 |     but along an arbitrary dimension
17 | :param tensor:
18 | :param entry_index:
19 | :param dim:
20 | :return:
21 | """
22 | idx = [i for i in range(tensor.size(dim)) if not i == entry_index]
23 | idx = cuda_variable(torch.LongTensor(idx))
24 | tensor = tensor.index_select(dim, idx)
25 | return tensor
26 |
27 |
28 | def reverse_tensor(tensor, dim):
29 | """
30 | Do tensor[:, ... , -1::-1, :] along dim dim
31 | :param tensor:
32 | :param dim:
33 | :return:
34 | """
35 | idx = [i for i in range(tensor.size(dim) - 1, -1, -1)]
36 | idx = cuda_variable(torch.LongTensor(idx))
37 | tensor = tensor.index_select(dim, idx)
38 | return tensor
39 |
--------------------------------------------------------------------------------
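Usage sketch for the two utilities above (CPU only; cuda_variable simply wraps the index tensor when no GPU is present):

import torch
from DeepBach.data_utils import mask_entry, reverse_tensor

x = torch.arange(12).view(3, 4)
print(mask_entry(x, entry_index=1, dim=1))  # drops column 1 -> shape (3, 3)
print(reverse_tensor(x, dim=0))             # rows in reverse order
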
/DeepBach/helpers.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Gaetan Hadjeres
3 | """
4 |
5 | import torch
6 | from torch.autograd import Variable
7 |
8 |
9 | def cuda_variable(tensor, volatile=False):
10 | if torch.cuda.is_available():
11 | return Variable(tensor.cuda(), volatile=volatile)
12 | else:
13 | return Variable(tensor, volatile=volatile)
14 |
15 |
16 | def to_numpy(variable: Variable):
17 | if torch.cuda.is_available():
18 | return variable.data.cpu().numpy()
19 | else:
20 | return variable.data.numpy()
21 |
22 |
23 | def init_hidden(num_layers, batch_size, lstm_hidden_size,
24 | volatile=False):
25 | hidden = (
26 | cuda_variable(
27 | torch.randn(num_layers, batch_size, lstm_hidden_size),
28 | volatile=volatile),
29 | cuda_variable(
30 | torch.randn(num_layers, batch_size, lstm_hidden_size),
31 | volatile=volatile)
32 | )
33 | return hidden
34 |
--------------------------------------------------------------------------------
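Illustrative call to init_hidden (the shapes are made up for the example): it returns the (h_0, c_0) pair expected by an nn.LSTM.

from DeepBach.helpers import init_hidden

h0, c0 = init_hidden(num_layers=2, batch_size=8, lstm_hidden_size=256)
print(h0.size(), c0.size())  # torch.Size([2, 8, 256]) twice
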
/DeepBach/metadata.py:
--------------------------------------------------------------------------------
1 | """
2 | Metadata classes
3 | """
4 | import numpy as np
5 | from .data_utils import SUBDIVISION
6 | from music21 import analysis, stream
7 |
8 |
9 | class Metadata:
10 | def __init__(self):
11 | self.num_values = None
12 | self.is_global = None
13 | raise NotImplementedError
14 |
15 | def get_index(self, value):
16 | # trick with the 0 value
17 | raise NotImplementedError
18 |
19 | def get_value(self, index):
20 | raise NotImplementedError
21 |
22 | def evaluate(self, chorale):
23 | """
24 | takes a music21 chorale as input
25 | """
26 | raise NotImplementedError
27 |
28 | def generate(self, length):
29 | raise NotImplementedError
30 |
31 |
32 | # todo BeatMetadata class
33 | # todo add strong/weak beat metadata
34 | # todo add minor/major metadata
35 | # todo add voice_i_playing metadata
36 |
37 | class IsPlayingMetadata(Metadata):
38 | def __init__(self, voice_index, min_num_ticks=SUBDIVISION):
39 |         """ Initialize the IsPlaying metadata.
40 |         Voice voice_index is considered muted wherever a rest spans at least 'min_num_ticks' contiguous ticks.
41 |
42 | :param min_num_ticks: minimum length in ticks for a rest to be taken into account in the metadata
43 | """
44 | self.min_num_ticks = min_num_ticks
45 | self.voice_index = voice_index
46 | self.is_global = False
47 | self.num_values = 2
48 |
49 | def get_index(self, value):
50 | return int(value)
51 |
52 | def get_value(self, index):
53 | return bool(index)
54 |
55 | def evaluate(self, chorale):
56 | """
57 | takes a music21 chorale as input
58 | """
59 | length = int(chorale.duration.quarterLength * SUBDIVISION)
60 | metadatas = np.ones(shape=(length,))
61 | part = chorale.parts[self.voice_index]
62 |
63 | for note_or_rest in part.notesAndRests:
64 | is_playing = True
65 | if note_or_rest.isRest:
66 | if note_or_rest.quarterLength * SUBDIVISION >= self.min_num_ticks:
67 | is_playing = False
68 |             # slice indices must be integers
69 |             start_tick = int(note_or_rest.offset * SUBDIVISION)
70 |             end_tick = int(start_tick + note_or_rest.quarterLength * SUBDIVISION)
71 | metadatas[start_tick:end_tick] = self.get_index(is_playing)
72 | return metadatas
73 |
74 | def generate(self, length):
75 | return np.ones(shape=(length,))
76 |
77 |
78 | class TickMetadatas(Metadata):
79 | def __init__(self, num_subdivisions):
80 | self.is_global = False
81 | self.num_values = num_subdivisions
82 |
83 | def get_index(self, value):
84 | return value
85 |
86 | def get_value(self, index):
87 | return index
88 |
89 | def evaluate(self, chorale):
90 | # suppose all pieces start on a beat
91 | length = int(chorale.duration.quarterLength * SUBDIVISION)
92 | return np.array(list(map(
93 | lambda x: x % self.num_values,
94 | range(length)
95 | )))
96 |
97 | def generate(self, length):
98 | return np.array(list(map(
99 | lambda x: x % self.num_values,
100 | range(length)
101 | )))
102 |
103 |
104 | class ModeMetadatas(Metadata):
105 | def __init__(self):
106 | self.is_global = False
107 | self.num_values = 3 # major, minor or other
108 |
109 | def get_index(self, value):
110 | if value == 'major':
111 | return 1
112 | if value == 'minor':
113 | return 2
114 | return 0
115 |
116 | def get_value(self, index):
117 | if index == 1:
118 | return 'major'
119 | if index == 2:
120 | return 'minor'
121 | return 'other'
122 |
123 | def evaluate(self, chorale):
124 | # todo add measures when in midi
125 | # init key analyzer
126 | ka = analysis.floatingKey.KeyAnalyzer(chorale)
127 | res = ka.run()
128 |
129 | measure_offset_map = chorale.parts[0].measureOffsetMap()
130 | length = int(chorale.duration.quarterLength * SUBDIVISION) # in 16th notes
131 |
132 | modes = np.zeros((length,))
133 |
134 | measure_index = -1
135 | for time_index in range(length):
136 | beat_index = time_index / SUBDIVISION
137 | if beat_index in measure_offset_map:
138 | measure_index += 1
139 | modes[time_index] = self.get_index(res[measure_index].mode)
140 |
141 | return np.array(modes, dtype=np.int32)
142 |
143 | def generate(self, length):
144 | return np.full((length,), self.get_index('major'))
145 |
146 |
147 | class KeyMetadatas(Metadata):
148 | def __init__(self, window_size=4):
149 | self.window_size = window_size
150 | self.is_global = False
151 | self.num_max_sharps = 7
152 | self.num_values = 16
153 |
154 | def get_index(self, value):
155 | """
156 |
157 | :param value: number of sharps (between -7 and +7)
158 | :return: index in the representation
159 | """
160 | return value + self.num_max_sharps + 1
161 |
162 | def get_value(self, index):
163 | """
164 |
165 | :param index: index (between 0 and self.num_values); 0 is unused (no constraint)
166 | :return: true number of sharps (between -7 and 7)
167 | """
168 | return index - 1 - self.num_max_sharps
169 |
170 | # todo check if this method is correct for windowSize > 1
171 | def evaluate(self, chorale):
172 | # init key analyzer
173 | # we must add measures by hand for the case when we are parsing midi files
174 | chorale_with_measures = stream.Score()
175 | for part in chorale.parts:
176 | chorale_with_measures.append(part.makeMeasures())
177 |
178 | ka = analysis.floatingKey.KeyAnalyzer(chorale_with_measures)
179 | ka.windowSize = self.window_size
180 | res = ka.run()
181 |
182 | measure_offset_map = chorale_with_measures.parts.measureOffsetMap()
183 | length = int(chorale.duration.quarterLength * SUBDIVISION) # in 16th notes
184 |
185 | key_signatures = np.zeros((length,))
186 |
187 | measure_index = -1
188 | for time_index in range(length):
189 | beat_index = time_index / SUBDIVISION
190 | if beat_index in measure_offset_map:
191 | measure_index += 1
192 | # todo remove this trick: problem with the last measures...
193 | if measure_index == len(res):
194 | measure_index -= 1
195 |
196 | key_signatures[time_index] = self.get_index(res[measure_index].sharps)
197 | return np.array(key_signatures, dtype=np.int32)
198 |
199 | def generate(self, length):
200 | return np.full((length,), self.get_index(0))
201 |
202 |
203 | class FermataMetadatas(Metadata):
204 | def __init__(self):
205 | self.is_global = False
206 | self.num_values = 2
207 |
208 | def get_index(self, value):
209 | # values are 1 and 0
210 | return value
211 |
212 | def get_value(self, index):
213 | return index
214 |
215 | def evaluate(self, chorale):
216 | part = chorale.parts[0]
217 | length = int(part.duration.quarterLength * SUBDIVISION) # in 16th notes
218 | list_notes = part.flat.notes
219 | num_notes = len(list_notes)
220 | j = 0
221 | i = 0
222 | fermatas = np.zeros((length,))
223 | fermata = False
224 | while i < length:
225 | if j < num_notes - 1:
226 | if list_notes[j + 1].offset > i / SUBDIVISION:
227 |
228 | if len(list_notes[j].expressions) == 1:
229 | fermata = True
230 | else:
231 | fermata = False
232 | fermatas[i] = fermata
233 | i += 1
234 | else:
235 | j += 1
236 | else:
237 | if len(list_notes[j].expressions) == 1:
238 | fermata = True
239 | else:
240 | fermata = False
241 |
242 | fermatas[i] = fermata
243 | i += 1
244 | return np.array(fermatas, dtype=np.int32)
245 |
246 | def generate(self, length):
247 | # fermata every 2 bars
248 |         return np.array([1 if i % 32 >= 28 else 0
249 | for i in range(length)])
250 |
--------------------------------------------------------------------------------
/DeepBach/model_manager.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Gaetan Hadjeres
3 | """
4 |
5 | from DatasetManager.metadata import FermataMetadata
6 | import numpy as np
7 | import torch
8 | from DeepBach.helpers import cuda_variable, to_numpy
9 |
10 | from torch import optim, nn
11 | from tqdm import tqdm
12 |
13 | from DeepBach.voice_model import VoiceModel
14 |
15 |
16 | class DeepBach:
17 | def __init__(self,
18 | dataset,
19 | note_embedding_dim,
20 | meta_embedding_dim,
21 | num_layers,
22 | lstm_hidden_size,
23 | dropout_lstm,
24 | linear_hidden_size,
25 | ):
26 | self.dataset = dataset
27 | self.num_voices = self.dataset.num_voices
28 | self.num_metas = len(self.dataset.metadatas) + 1
29 | self.activate_cuda = torch.cuda.is_available()
30 |
31 | self.voice_models = [VoiceModel(
32 | dataset=self.dataset,
33 | main_voice_index=main_voice_index,
34 | note_embedding_dim=note_embedding_dim,
35 | meta_embedding_dim=meta_embedding_dim,
36 | num_layers=num_layers,
37 | lstm_hidden_size=lstm_hidden_size,
38 | dropout_lstm=dropout_lstm,
39 | hidden_size_linear=linear_hidden_size,
40 | )
41 | for main_voice_index in range(self.num_voices)
42 | ]
43 |
44 | def cuda(self, main_voice_index=None):
45 | if self.activate_cuda:
46 | if main_voice_index is None:
47 | for voice_index in range(self.num_voices):
48 | self.cuda(voice_index)
49 | else:
50 | self.voice_models[main_voice_index].cuda()
51 |
52 | # Utils
53 | def load(self, main_voice_index=None):
54 | if main_voice_index is None:
55 | for voice_index in range(self.num_voices):
56 | self.load(main_voice_index=voice_index)
57 | else:
58 | self.voice_models[main_voice_index].load()
59 |
60 | def save(self, main_voice_index=None):
61 | if main_voice_index is None:
62 | for voice_index in range(self.num_voices):
63 | self.save(main_voice_index=voice_index)
64 | else:
65 | self.voice_models[main_voice_index].save()
66 |
67 | def train(self, main_voice_index=None,
68 | **kwargs):
69 | if main_voice_index is None:
70 | for voice_index in range(self.num_voices):
71 | self.train(main_voice_index=voice_index, **kwargs)
72 | else:
73 | voice_model = self.voice_models[main_voice_index]
74 | if self.activate_cuda:
75 | voice_model.cuda()
76 | optimizer = optim.Adam(voice_model.parameters())
77 | voice_model.train_model(optimizer=optimizer, **kwargs)
78 |
79 | def eval_phase(self):
80 | for voice_model in self.voice_models:
81 | voice_model.eval()
82 |
83 | def train_phase(self):
84 | for voice_model in self.voice_models:
85 | voice_model.train()
86 |
87 | def generation(self,
88 | temperature=1.0,
89 | batch_size_per_voice=8,
90 | num_iterations=None,
91 | sequence_length_ticks=160,
92 | tensor_chorale=None,
93 | tensor_metadata=None,
94 | time_index_range_ticks=None,
95 | voice_index_range=None,
96 | fermatas=None,
97 | random_init=True
98 | ):
99 | """
100 |
101 |         :param temperature: final sampling temperature after annealing
102 |         :param batch_size_per_voice: number of simultaneous parallel updates per voice
103 |         :param num_iterations: number of parallel pseudo-Gibbs sampling iterations
104 |         :param sequence_length_ticks: length of the generated chorale (in ticks)
105 |         :param tensor_chorale: optional (num_voices, chorale_length) tensor to start from
106 |         :param tensor_metadata: optional (num_voices, chorale_length, num_metadata) tensor
107 | :param time_index_range_ticks: list of two integers [a, b] or None; can be used \
108 | to regenerate only the portion of the score between timesteps a and b
109 | :param voice_index_range: list of two integers [a, b] or None; can be used \
110 | to regenerate only the portion of the score between voice_index a and b
111 | :param fermatas: list[Fermata]
112 | :param random_init: boolean, whether or not to randomly initialize
113 | the portion of the score on which we apply the pseudo-Gibbs algorithm
114 | :return: tuple (
115 | generated_score [music21 Stream object],
116 | tensor_chorale (num_voices, chorale_length) torch.IntTensor,
117 | tensor_metadata (num_voices, chorale_length, num_metadata) torch.IntTensor
118 | )
119 | """
120 | self.eval_phase()
121 |
122 | # --Process arguments
123 | # initialize generated chorale
124 | # tensor_chorale = self.dataset.empty_chorale(sequence_length_ticks)
125 | if tensor_chorale is None:
126 | tensor_chorale = self.dataset.random_score_tensor(
127 | sequence_length_ticks)
128 | else:
129 | sequence_length_ticks = tensor_chorale.size(1)
130 |
131 | # initialize metadata
132 | if tensor_metadata is None:
133 | test_chorale = next(self.dataset.corpus_it_gen().__iter__())
134 | tensor_metadata = self.dataset.get_metadata_tensor(test_chorale)
135 |
136 | if tensor_metadata.size(1) < sequence_length_ticks:
137 | tensor_metadata = tensor_metadata.repeat(1, sequence_length_ticks // tensor_metadata.size(1) + 1, 1)
138 |
139 | # todo do not work if metadata_length_ticks > sequence_length_ticks
140 | tensor_metadata = tensor_metadata[:, :sequence_length_ticks, :]
141 | else:
142 | tensor_metadata_length = tensor_metadata.size(1)
143 | assert tensor_metadata_length == sequence_length_ticks
144 |
145 | if fermatas is not None:
146 | tensor_metadata = self.dataset.set_fermatas(tensor_metadata,
147 | fermatas)
148 |
149 | # timesteps_ticks is the number of ticks on which we unroll the LSTMs
150 | # it is also the padding size
151 | timesteps_ticks = self.dataset.sequences_size * self.dataset.subdivision // 2
152 | if time_index_range_ticks is None:
153 | time_index_range_ticks = [timesteps_ticks, sequence_length_ticks + timesteps_ticks]
154 | else:
155 | a_ticks, b_ticks = time_index_range_ticks
156 | assert 0 <= a_ticks < b_ticks <= sequence_length_ticks
157 | time_index_range_ticks = [a_ticks + timesteps_ticks, b_ticks + timesteps_ticks]
158 |
159 | if voice_index_range is None:
160 | voice_index_range = [0, self.dataset.num_voices]
161 |
162 | tensor_chorale = self.dataset.extract_score_tensor_with_padding(
163 | tensor_score=tensor_chorale,
164 | start_tick=-timesteps_ticks,
165 | end_tick=sequence_length_ticks + timesteps_ticks
166 | )
167 |
168 | tensor_metadata_padded = self.dataset.extract_metadata_with_padding(
169 | tensor_metadata=tensor_metadata,
170 | start_tick=-timesteps_ticks,
171 | end_tick=sequence_length_ticks + timesteps_ticks
172 | )
173 |
174 | # randomize regenerated part
175 | if random_init:
176 | a, b = time_index_range_ticks
177 | tensor_chorale[voice_index_range[0]:voice_index_range[1], a:b] = self.dataset.random_score_tensor(
178 | b - a)[voice_index_range[0]:voice_index_range[1], :]
179 |
180 | tensor_chorale = self.parallel_gibbs(
181 | tensor_chorale=tensor_chorale,
182 | tensor_metadata=tensor_metadata_padded,
183 | num_iterations=num_iterations,
184 | timesteps_ticks=timesteps_ticks,
185 | temperature=temperature,
186 | batch_size_per_voice=batch_size_per_voice,
187 | time_index_range_ticks=time_index_range_ticks,
188 | voice_index_range=voice_index_range,
189 | )
190 |
191 | # get fermata tensor
192 | for metadata_index, metadata in enumerate(self.dataset.metadatas):
193 | if isinstance(metadata, FermataMetadata):
194 | break
195 |
196 |
197 | score = self.dataset.tensor_to_score(
198 | tensor_score=tensor_chorale,
199 | fermata_tensor=tensor_metadata[:, :, metadata_index])
200 |
201 | return score, tensor_chorale, tensor_metadata
202 |
203 | def parallel_gibbs(self,
204 | tensor_chorale,
205 | tensor_metadata,
206 | timesteps_ticks,
207 | num_iterations=1000,
208 | batch_size_per_voice=16,
209 | temperature=1.,
210 | time_index_range_ticks=None,
211 | voice_index_range=None,
212 | ):
213 | """
214 | Parallel pseudo-Gibbs sampling
215 | tensor_chorale and tensor_metadata are padded with
216 | timesteps_ticks START_SYMBOLS before,
217 | timesteps_ticks END_SYMBOLS after
218 | :param tensor_chorale: (num_voices, chorale_length) tensor
219 |         :param tensor_metadata: (num_voices, chorale_length, num_metadatas) tensor
220 | :param timesteps_ticks:
221 | :param num_iterations: number of Gibbs sampling iterations
222 | :param batch_size_per_voice: number of simultaneous parallel updates
223 | :param temperature: final temperature after simulated annealing
224 | :param time_index_range_ticks: list of two integers [a, b] or None; can be used \
225 | to regenerate only the portion of the score between timesteps a and b
226 | :param voice_index_range: list of two integers [a, b] or None; can be used \
227 | to regenerate only the portion of the score between voice_index a and b
228 | :return: (num_voices, chorale_length) tensor
229 | """
230 | start_voice, end_voice = voice_index_range
231 | # add batch_dimension
232 | tensor_chorale = tensor_chorale.unsqueeze(0)
233 | tensor_chorale_no_cuda = tensor_chorale.clone()
234 | tensor_metadata = tensor_metadata.unsqueeze(0)
235 |
236 | # to variable
237 | tensor_chorale = cuda_variable(tensor_chorale, volatile=True)
238 | tensor_metadata = cuda_variable(tensor_metadata, volatile=True)
239 |
240 | min_temperature = temperature
241 | temperature = 1.1
242 |
243 | # Main loop
244 | for iteration in tqdm(range(num_iterations)):
245 | # annealing
246 | temperature = max(min_temperature, temperature * 0.9993)
247 | # print(temperature)
248 | time_indexes_ticks = {}
249 | probas = {}
250 |
251 | for voice_index in range(start_voice, end_voice):
252 | batch_notes = []
253 | batch_metas = []
254 |
255 | time_indexes_ticks[voice_index] = []
256 |
257 | # create batches of inputs
258 | for batch_index in range(batch_size_per_voice):
259 | time_index_ticks = np.random.randint(
260 | *time_index_range_ticks)
261 | time_indexes_ticks[voice_index].append(time_index_ticks)
262 |
263 | notes, label = (self.voice_models[voice_index]
264 | .preprocess_notes(
265 | tensor_chorale=tensor_chorale[
266 | :, :,
267 | time_index_ticks - timesteps_ticks:
268 | time_index_ticks + timesteps_ticks],
269 | time_index_ticks=timesteps_ticks
270 | )
271 | )
272 | metas = self.voice_models[voice_index].preprocess_metas(
273 | tensor_metadata=tensor_metadata[
274 | :, :,
275 | time_index_ticks - timesteps_ticks:
276 | time_index_ticks + timesteps_ticks,
277 | :],
278 | time_index_ticks=timesteps_ticks
279 | )
280 |
281 | batch_notes.append(notes)
282 | batch_metas.append(metas)
283 |
284 | # reshape batches
285 | batch_notes = list(map(list, zip(*batch_notes)))
286 | batch_notes = [torch.cat(lcr) if lcr[0] is not None else None
287 | for lcr in batch_notes]
288 | batch_metas = list(map(list, zip(*batch_metas)))
289 | batch_metas = [torch.cat(lcr)
290 | for lcr in batch_metas]
291 |
292 | # make all estimations
293 | probas[voice_index] = (self.voice_models[voice_index]
294 | .forward(batch_notes, batch_metas)
295 | )
296 | probas[voice_index] = nn.Softmax(dim=1)(probas[voice_index])
297 |
298 | # update all predictions
299 | for voice_index in range(start_voice, end_voice):
300 | for batch_index in range(batch_size_per_voice):
301 | probas_pitch = probas[voice_index][batch_index]
302 |
303 | probas_pitch = to_numpy(probas_pitch)
304 |
305 | # use temperature
306 | probas_pitch = np.log(probas_pitch) / temperature
307 | probas_pitch = np.exp(probas_pitch) / np.sum(
308 | np.exp(probas_pitch)) - 1e-7
309 |
310 | # avoid non-probabilities
311 | probas_pitch[probas_pitch < 0] = 0
312 |
313 | # pitch can include slur_symbol
314 | pitch = np.argmax(np.random.multinomial(1, probas_pitch))
315 |
316 | tensor_chorale_no_cuda[
317 | 0,
318 | voice_index,
319 | time_indexes_ticks[voice_index][batch_index]
320 | ] = int(pitch)
321 |
322 | tensor_chorale = cuda_variable(tensor_chorale_no_cuda.clone(),
323 | volatile=True)
324 |
325 | return tensor_chorale_no_cuda[0, :, timesteps_ticks:-timesteps_ticks]
326 |
--------------------------------------------------------------------------------
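
As a usage illustration of the `generation` method documented above: the following sketch builds the dataset and model exactly as `deepBach.py` does, loads the pretrained weights, and then regenerates only the second half of a 64-tick chorale while restricting the pseudo-Gibbs updates to the alto and tenor voices. The time-range and voice-range values are illustrative assumptions, not prescribed defaults:

```python
# Hypothetical call to DeepBach.generation; the regenerated range is illustrative only.
from DatasetManager.chorale_dataset import ChoraleDataset
from DatasetManager.dataset_manager import DatasetManager
from DatasetManager.metadata import FermataMetadata, TickMetadata, KeyMetadata
from DeepBach.model_manager import DeepBach

# build the Bach chorales dataset exactly as deepBach.py does
dataset_manager = DatasetManager()
dataset: ChoraleDataset = dataset_manager.get_dataset(
    name='bach_chorales',
    voice_ids=[0, 1, 2, 3],
    metadatas=[FermataMetadata(), TickMetadata(subdivision=4), KeyMetadata()],
    sequences_size=8,
    subdivision=4,
)

deepbach = DeepBach(
    dataset=dataset,
    note_embedding_dim=20,
    meta_embedding_dim=20,
    num_layers=2,
    lstm_hidden_size=256,
    dropout_lstm=0.5,
    linear_hidden_size=256,
)
deepbach.load()    # pretrained weights from models/
deepbach.cuda()    # no-op when CUDA is unavailable

score, tensor_chorale, tensor_metadata = deepbach.generation(
    num_iterations=500,
    sequence_length_ticks=64,
    time_index_range_ticks=[32, 64],  # resample only ticks 32..64
    voice_index_range=[1, 3],         # resample only alto and tenor
)
score.show()  # requires a configured music21 editor (see README)
```

Leaving `time_index_range_ticks` and `voice_index_range` at `None` regenerates the whole score for all voices, which is what `deepBach.py` does.
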
/DeepBach/voice_model.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Gaetan Hadjeres
3 | """
4 |
5 | import random
6 |
7 | import torch
8 | from DatasetManager.chorale_dataset import ChoraleDataset
9 | from DeepBach.helpers import cuda_variable, init_hidden
10 |
11 | from torch import nn
12 |
13 | from DeepBach.data_utils import reverse_tensor, mask_entry
14 |
15 |
16 | class VoiceModel(nn.Module):
17 | def __init__(self,
18 | dataset: ChoraleDataset,
19 | main_voice_index: int,
20 | note_embedding_dim: int,
21 | meta_embedding_dim: int,
22 | num_layers: int,
23 | lstm_hidden_size: int,
24 | dropout_lstm: float,
25 | hidden_size_linear=200
26 | ):
27 | super(VoiceModel, self).__init__()
28 | self.dataset = dataset
29 | self.main_voice_index = main_voice_index
30 | self.note_embedding_dim = note_embedding_dim
31 | self.meta_embedding_dim = meta_embedding_dim
32 | self.num_notes_per_voice = [len(d)
33 | for d in dataset.note2index_dicts]
34 | self.num_voices = self.dataset.num_voices
35 | self.num_metas_per_voice = [
36 | metadata.num_values
37 | for metadata in dataset.metadatas
38 | ] + [self.num_voices]
39 | self.num_metas = len(self.dataset.metadatas) + 1
40 | self.num_layers = num_layers
41 | self.lstm_hidden_size = lstm_hidden_size
42 | self.dropout_lstm = dropout_lstm
43 | self.hidden_size_linear = hidden_size_linear
44 |
45 | self.other_voices_indexes = [i
46 | for i
47 | in range(self.num_voices)
48 | if not i == main_voice_index]
49 |
50 | self.note_embeddings = nn.ModuleList(
51 | [nn.Embedding(num_notes, note_embedding_dim)
52 | for num_notes in self.num_notes_per_voice]
53 | )
54 | self.meta_embeddings = nn.ModuleList(
55 | [nn.Embedding(num_metas, meta_embedding_dim)
56 | for num_metas in self.num_metas_per_voice]
57 | )
58 |
59 | self.lstm_left = nn.LSTM(input_size=note_embedding_dim * self.num_voices +
60 | meta_embedding_dim * self.num_metas,
61 | hidden_size=lstm_hidden_size,
62 | num_layers=num_layers,
63 | dropout=dropout_lstm,
64 | batch_first=True)
65 | self.lstm_right = nn.LSTM(input_size=note_embedding_dim * self.num_voices +
66 | meta_embedding_dim * self.num_metas,
67 | hidden_size=lstm_hidden_size,
68 | num_layers=num_layers,
69 | dropout=dropout_lstm,
70 | batch_first=True)
71 |
72 | self.mlp_center = nn.Sequential(
73 | nn.Linear((note_embedding_dim * (self.num_voices - 1)
74 | + meta_embedding_dim * self.num_metas),
75 | hidden_size_linear),
76 | nn.ReLU(),
77 | nn.Linear(hidden_size_linear, lstm_hidden_size)
78 | )
79 |
80 | self.mlp_predictions = nn.Sequential(
81 | nn.Linear(self.lstm_hidden_size * 3,
82 | hidden_size_linear),
83 | nn.ReLU(),
84 | nn.Linear(hidden_size_linear, self.num_notes_per_voice[main_voice_index])
85 | )
86 |
87 | def forward(self, *input):
88 | notes, metas = input
89 | batch_size, num_voices, timesteps_ticks = notes[0].size()
90 |
91 | # put time first
92 | ln, cn, rn = notes
93 | ln, rn = [t.transpose(1, 2)
94 | for t in (ln, rn)]
95 | notes = ln, cn, rn
96 |
97 | # embedding
98 | notes_embedded = self.embed(notes, type='note')
99 | metas_embedded = self.embed(metas, type='meta')
100 | # lists of (N, timesteps_ticks, voices * dim_embedding)
101 | # where timesteps_ticks is 1 for central parts
102 |
103 | # concat notes and metas
104 | input_embedded = [torch.cat([notes, metas], 2) if notes is not None else None
105 | for notes, metas in zip(notes_embedded, metas_embedded)]
106 |
107 | left, center, right = input_embedded
108 |
109 | # main part
110 | hidden = init_hidden(
111 | num_layers=self.num_layers,
112 | batch_size=batch_size,
113 | lstm_hidden_size=self.lstm_hidden_size,
114 | )
115 | left, hidden = self.lstm_left(left, hidden)
116 | left = left[:, -1, :]
117 |
118 | if self.num_voices == 1:
119 | center = cuda_variable(torch.zeros(
120 | batch_size,
121 | self.lstm_hidden_size)
122 | )
123 | else:
124 | center = center[:, 0, :] # remove time dimension
125 | center = self.mlp_center(center)
126 |
127 | hidden = init_hidden(
128 | num_layers=self.num_layers,
129 | batch_size=batch_size,
130 | lstm_hidden_size=self.lstm_hidden_size,
131 | )
132 | right, hidden = self.lstm_right(right, hidden)
133 | right = right[:, -1, :]
134 |
135 | # concat and return prediction
136 | predictions = torch.cat([
137 | left, center, right
138 | ], 1)
139 |
140 | predictions = self.mlp_predictions(predictions)
141 |
142 | return predictions
143 |
144 | def embed(self, notes_or_metas, type):
145 | if type == 'note':
146 | embeddings = self.note_embeddings
147 | embedding_dim = self.note_embedding_dim
148 | other_voices_indexes = self.other_voices_indexes
149 | elif type == 'meta':
150 | embeddings = self.meta_embeddings
151 | embedding_dim = self.meta_embedding_dim
152 | other_voices_indexes = range(self.num_metas)
153 |
154 | batch_size, timesteps_left_ticks, num_voices = notes_or_metas[0].size()
155 | batch_size, timesteps_right_ticks, num_voices = notes_or_metas[2].size()
156 |
157 | left, center, right = notes_or_metas
158 | # center has self.num_voices - 1 voices
159 | left_embedded = torch.cat([
160 | embeddings[voice_id](left[:, :, voice_id])[:, :, None, :]
161 | for voice_id in range(num_voices)
162 | ], 2)
163 | right_embedded = torch.cat([
164 | embeddings[voice_id](right[:, :, voice_id])[:, :, None, :]
165 | for voice_id in range(num_voices)
166 | ], 2)
167 | if self.num_voices == 1 and type == 'note':
168 | center_embedded = None
169 | else:
170 | center_embedded = torch.cat([
171 | embeddings[voice_id](center[:, k].unsqueeze(1))
172 | for k, voice_id in enumerate(other_voices_indexes)
173 | ], 1)
174 | center_embedded = center_embedded.view(batch_size,
175 | 1,
176 | len(other_voices_indexes) * embedding_dim)
177 |
178 | # squeeze two last dimensions
179 | left_embedded = left_embedded.view(batch_size,
180 | timesteps_left_ticks,
181 | num_voices * embedding_dim)
182 | right_embedded = right_embedded.view(batch_size,
183 | timesteps_right_ticks,
184 | num_voices * embedding_dim)
185 |
186 | return left_embedded, center_embedded, right_embedded
187 |
188 | def save(self):
189 | torch.save(self.state_dict(), 'models/' + self.__repr__())
190 | print(f'Model {self.__repr__()} saved')
191 |
192 | def load(self):
193 | state_dict = torch.load('models/' + self.__repr__(),
194 | map_location=lambda storage, loc: storage)
195 | print(f'Loading {self.__repr__()}')
196 | self.load_state_dict(state_dict)
197 |
198 | def __repr__(self):
199 | return f'VoiceModel(' \
200 | f'{self.dataset.__repr__()},' \
201 | f'{self.main_voice_index},' \
202 | f'{self.note_embedding_dim},' \
203 | f'{self.meta_embedding_dim},' \
204 | f'{self.num_layers},' \
205 | f'{self.lstm_hidden_size},' \
206 | f'{self.dropout_lstm},' \
207 | f'{self.hidden_size_linear}' \
208 | f')'
209 |
210 | def train_model(self,
211 | batch_size=16,
212 | num_epochs=10,
213 | optimizer=None):
214 | for epoch in range(num_epochs):
215 | print(f'===Epoch {epoch}===')
216 | (dataloader_train,
217 | dataloader_val,
218 | dataloader_test) = self.dataset.data_loaders(
219 | batch_size=batch_size,
220 | )
221 |
222 | loss, acc = self.loss_and_acc(dataloader_train,
223 | optimizer=optimizer,
224 | phase='train')
225 | print(f'Training loss: {loss}')
226 | print(f'Training accuracy: {acc}')
227 | # writer.add_scalar('data/training_loss', loss, epoch)
228 | # writer.add_scalar('data/training_acc', acc, epoch)
229 |
230 | loss, acc = self.loss_and_acc(dataloader_val,
231 | optimizer=None,
232 | phase='test')
233 | print(f'Validation loss: {loss}')
234 | print(f'Validation accuracy: {acc}')
235 | self.save()
236 |
237 | def loss_and_acc(self, dataloader,
238 | optimizer=None,
239 | phase='train'):
240 |
241 | average_loss = 0
242 | average_acc = 0
243 | if phase == 'train':
244 | self.train()
245 | elif phase == 'eval' or phase == 'test':
246 | self.eval()
247 | else:
248 | raise NotImplementedError
249 | for tensor_chorale, tensor_metadata in dataloader:
250 |
251 | # to Variable
252 | tensor_chorale = cuda_variable(tensor_chorale).long()
253 | tensor_metadata = cuda_variable(tensor_metadata).long()
254 |
255 | # preprocessing to put in the DeepBach format
256 | # see Fig. 4 in DeepBach paper:
257 | # https://arxiv.org/pdf/1612.01010.pdf
258 | notes, metas, label = self.preprocess_input(tensor_chorale,
259 | tensor_metadata)
260 |
261 | weights = self.forward(notes, metas)
262 |
263 | loss_function = torch.nn.CrossEntropyLoss()
264 |
265 | loss = loss_function(weights, label)
266 |
267 | if phase == 'train':
268 | optimizer.zero_grad()
269 | loss.backward()
270 | optimizer.step()
271 |
272 | acc = self.accuracy(weights=weights,
273 | target=label)
274 |
275 | average_loss += loss.item()
276 | average_acc += acc.item()
277 |
278 | average_loss /= len(dataloader)
279 | average_acc /= len(dataloader)
280 | return average_loss, average_acc
281 |
282 | def accuracy(self, weights, target):
283 | batch_size, = target.size()
284 | softmax = nn.Softmax(dim=1)(weights)
285 | pred = softmax.max(1)[1].type_as(target)
286 | num_corrects = (pred == target).float().sum()
287 | return num_corrects / batch_size * 100
288 |
289 | def preprocess_input(self, tensor_chorale, tensor_metadata):
290 | """
291 | :param tensor_chorale: (batch_size, num_voices, chorale_length_ticks)
292 |         :param tensor_metadata: (batch_size, num_voices, chorale_length_ticks, num_metadatas)
293 | :return: (notes, metas, label) tuple
294 | where
295 | notes = (left_notes, central_notes, right_notes)
296 | metas = (left_metas, central_metas, right_metas)
297 | label = (batch_size)
298 | right_notes and right_metas are REVERSED (from right to left)
299 | """
300 | batch_size, num_voices, chorale_length_ticks = tensor_chorale.size()
301 |
302 | # random shift! Depends on the dataset
303 | offset = random.randint(0, self.dataset.subdivision)
304 | time_index_ticks = chorale_length_ticks // 2 + offset
305 |
306 | # split notes
307 | notes, label = self.preprocess_notes(tensor_chorale, time_index_ticks)
308 | metas = self.preprocess_metas(tensor_metadata, time_index_ticks)
309 | return notes, metas, label
310 |
311 | def preprocess_notes(self, tensor_chorale, time_index_ticks):
312 | """
313 |
314 | :param tensor_chorale: (batch_size, num_voices, chorale_length_ticks)
315 |         :param time_index_ticks: central tick around which the left/right contexts are split
316 |         :return: ((left_notes, central_notes, right_notes), label) tuple
317 | """
318 | batch_size, num_voices, _ = tensor_chorale.size()
319 | left_notes = tensor_chorale[:, :, :time_index_ticks]
320 | right_notes = reverse_tensor(
321 | tensor_chorale[:, :, time_index_ticks + 1:],
322 | dim=2)
323 | if self.num_voices == 1:
324 | central_notes = None
325 | else:
326 | central_notes = mask_entry(tensor_chorale[:, :, time_index_ticks],
327 | entry_index=self.main_voice_index,
328 | dim=1)
329 | label = tensor_chorale[:, self.main_voice_index, time_index_ticks]
330 | return (left_notes, central_notes, right_notes), label
331 |
332 | def preprocess_metas(self, tensor_metadata, time_index_ticks):
333 | """
334 |
335 |         :param tensor_metadata: (batch_size, num_voices, chorale_length_ticks, num_metadatas)
336 |         :param time_index_ticks: central tick around which the left/right contexts are split
337 |         :return: (left_metas, central_metas, right_metas) tuple
338 | """
339 |
340 | left_metas = tensor_metadata[:, self.main_voice_index, :time_index_ticks, :]
341 | right_metas = reverse_tensor(
342 | tensor_metadata[:, self.main_voice_index, time_index_ticks + 1:, :],
343 | dim=1)
344 | central_metas = tensor_metadata[:, self.main_voice_index, time_index_ticks, :]
345 | return left_metas, central_metas, right_metas
346 |
--------------------------------------------------------------------------------
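
To make the `preprocess_notes` layout concrete: the model predicts the main voice's note at a central tick from the left context (all voices, in time order), the right context (all voices, reversed), and the remaining voices at the central tick itself. The following self-contained sketch reproduces that split, using `torch.flip` and plain indexing in place of the repository's `reverse_tensor` and `mask_entry` helpers (whose implementations are not shown here):

```python
# Illustrative version of the left/center/right split in VoiceModel.preprocess_notes.
import torch

batch_size, num_voices, length_ticks = 2, 4, 32
main_voice_index = 0
tensor_chorale = torch.randint(0, 50, (batch_size, num_voices, length_ticks))

t = length_ticks // 2                                   # central tick
left_notes = tensor_chorale[:, :, :t]                   # all voices, past context
right_notes = torch.flip(tensor_chorale[:, :, t + 1:],  # all voices, future context,
                         dims=[2])                      # read from right to left
other_voices = [v for v in range(num_voices) if v != main_voice_index]
central_notes = tensor_chorale[:, other_voices, t]      # other voices at tick t
label = tensor_chorale[:, main_voice_index, t]          # the note to predict

print(left_notes.shape, central_notes.shape, right_notes.shape, label.shape)
# torch.Size([2, 4, 16]) torch.Size([2, 3]) torch.Size([2, 15]) torch.Size([2])
```
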
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.0-cuda10.0-cudnn7-runtime
2 |
3 | RUN git clone https://github.com/Ghadjeres/DeepBach.git
4 | WORKDIR DeepBach
5 | RUN conda env create --name deepbach_pytorch -f environment.yml
6 |
7 | RUN apt update && apt install -y wget
8 | RUN bash dl_dataset_and_models.sh
9 |
10 |
11 | COPY entrypoint.sh entrypoint.sh
12 | RUN chmod u+x entrypoint.sh
13 |
14 | EXPOSE 5000
15 | ENTRYPOINT ["./entrypoint.sh"]
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Gaetan Hadjeres, Laboratoire LIP6 - Departement ASIM Universite Pierre et Marie Curie
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepBach
2 | This repository contains implementations of the DeepBach model described in
3 |
4 | *DeepBach: a Steerable Model for Bach chorales generation*
5 | Gaëtan Hadjeres, François Pachet, Frank Nielsen
6 | *ICML 2017 [arXiv:1612.01010](http://proceedings.mlr.press/v70/hadjeres17a.html)*
7 |
8 |
9 | The code uses python 3.6 together with [PyTorch v1.0](https://pytorch.org/) and
10 | [music21](http://web.mit.edu/music21/) libraries.
11 |
12 | For the original Keras version, please checkout the `original_keras` branch.
13 |
14 | Examples of music generated by DeepBach are available on [this website](https://sites.google.com/site/deepbachexamples/)
15 |
16 | ## Installation
17 |
18 | You can clone this repository, install dependencies using Anaconda and download a pretrained
19 | model together with a dataset
20 | with the following commands:
21 | ```
22 | git clone git@github.com:Ghadjeres/DeepBach.git
23 | cd DeepBach
24 | conda env create --name deepbach_pytorch -f environment.yml
25 | bash dl_dataset_and_models.sh
26 | ```
27 | This will create a conda env named `deepbach_pytorch`.
28 |
29 | ### music21 editor
30 |
31 | You might need to
32 | [properly configure the music editor
33 | called by music21](http://web.mit.edu/music21/doc/moduleReference/moduleEnvironment.html).
34 | On Ubuntu you can e.g. use MuseScore:
35 |
36 | ```shell
37 | sudo apt install musescore
38 | python -c 'import music21; music21.environment.set("musicxmlPath", "/usr/bin/musescore")'
39 | ```
40 |
41 | For usage on a headless server (no X server), just set it to a dummy command:
42 |
43 | ```shell
44 | python -c 'import music21; music21.environment.set("musicxmlPath", "/bin/true")'
45 | ```
46 |
47 | ## Usage
48 | ```
49 | Usage: deepBach.py [OPTIONS]
50 |
51 | Options:
52 | --note_embedding_dim INTEGER size of the note embeddings
53 | --meta_embedding_dim INTEGER size of the metadata embeddings
54 | --num_layers INTEGER number of layers of the LSTMs
55 | --lstm_hidden_size INTEGER hidden size of the LSTMs
56 | --dropout_lstm FLOAT amount of dropout between LSTM layers
57 | --linear_hidden_size INTEGER hidden size of the Linear layers
58 | --batch_size INTEGER training batch size
59 | --num_epochs INTEGER number of training epochs
60 | --train train or retrain the specified model
61 | --num_iterations INTEGER number of parallel pseudo-Gibbs sampling
62 | iterations
63 | --sequence_length_ticks INTEGER
64 | length of the generated chorale (in ticks)
65 | --help Show this message and exit.
66 | ```
67 |
68 | ## Examples
69 | You can generate a four-bar chorale with the pretrained model and display it in MuseScore by
70 | simply running
71 | ```
72 | python deepBach.py
73 | ```
74 |
75 | You can train a new model from scratch by adding the `--train` flag.
76 |
77 |
78 | ## Usage with NONOTO
79 | The command
80 | ```
81 | python flask_server.py
82 | ```
83 | starts a Flask server listening on port 5000. You can then use
84 | [NONOTO](https://github.com/SonyCSLParis/NONOTO) to compose with DeepBach in an interactive way.
85 |
86 | This server can also be started using Docker with:
87 | ```
88 | docker run -p 5000:5000 -it --rm ghadjeres/deepbach
89 | ```
90 | (CPU version),
91 | or
92 | ```
93 | docker run --runtime=nvidia -p 5000:5000 -it --rm ghadjeres/deepbach
94 | ```
95 | (GPU version, requires [nvidia-docker](https://github.com/NVIDIA/nvidia-docker)).
96 |
97 |
98 | ## Usage within MuseScore
99 | *Deprecated*
100 |
101 | Put `deepBachMuseScore.qml` file in your MuseScore plugins directory, and run
102 | ```
103 | python musescore_flask_server.py
104 | ```
105 | Open MuseScore and activate the deepBachMuseScore plugin using the Plugin Manager.
106 | Open a four-part chorale and press Enter on the server address: a list of computed models should appear. Select and (re)load a model.
107 | You can then click on the Compose button without any selection to create a new chorale from
108 | scratch, or select a region in the chorale score and click on the Compose button to regenerate this region using DeepBach.
109 |
110 |
111 | ## Issues
112 |
113 | ### Music21 editor not set
114 |
115 | ```
116 | music21.converter.subConverters.SubConverterException: Cannot find a valid application path for format musicxml. Specify this in your Environment by calling environment.set(None, '/path/to/application')
117 | ```
118 |
119 | Either set it to MuseScore or similar (on a machine with a GUI) or to a dummy command (on a server). See the installation section.
120 |
121 | # Citing
122 |
123 | Please consider citing this work or emailing me if you use DeepBach in musical projects.
124 | ```
125 | @InProceedings{pmlr-v70-hadjeres17a,
126 | title = {{D}eep{B}ach: a Steerable Model for {B}ach Chorales Generation},
127 | author = {Ga{\"e}tan Hadjeres and Fran{\c{c}}ois Pachet and Frank Nielsen},
128 | booktitle = {Proceedings of the 34th International Conference on Machine Learning},
129 | pages = {1362--1371},
130 | year = {2017},
131 | editor = {Doina Precup and Yee Whye Teh},
132 | volume = {70},
133 | series = {Proceedings of Machine Learning Research},
134 | address = {International Convention Centre, Sydney, Australia},
135 | month = {06--11 Aug},
136 | publisher = {PMLR},
137 | pdf = {http://proceedings.mlr.press/v70/hadjeres17a/hadjeres17a.pdf},
138 | url = {http://proceedings.mlr.press/v70/hadjeres17a.html},
139 | }
140 | ```
141 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | python_version: "3.7"
3 | gpu: false
4 | python_packages:
5 | - decorator==4.3.0
6 | - flask==1.0.2
7 | - flask-cors==3.0.7
8 | - itsdangerous==1.1.0
9 | - jinja2==2.10
10 | - markupsafe==1.1.0
11 | - music21==6.7.1
12 | - torch==1.0.0 --no-cache-dir
13 | - tqdm==4.29.0
14 | - werkzeug==0.14.1
15 | - numpy==1.19.5
16 | - midi2audio==0.1.1
17 | system_packages:
18 | - musescore
19 | - fluidsynth --fix-missing
20 | - ffmpeg
21 | predict: "predict.py:Predictor"
22 |
--------------------------------------------------------------------------------
/deepBach.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Gaetan Hadjeres
3 | """
4 |
5 | import click
6 |
7 | from DatasetManager.chorale_dataset import ChoraleDataset
8 | from DatasetManager.dataset_manager import DatasetManager
9 | from DatasetManager.metadata import FermataMetadata, TickMetadata, KeyMetadata
10 |
11 | from DeepBach.model_manager import DeepBach
12 |
13 |
14 | @click.command()
15 | @click.option('--note_embedding_dim', default=20,
16 | help='size of the note embeddings')
17 | @click.option('--meta_embedding_dim', default=20,
18 | help='size of the metadata embeddings')
19 | @click.option('--num_layers', default=2,
20 | help='number of layers of the LSTMs')
21 | @click.option('--lstm_hidden_size', default=256,
22 | help='hidden size of the LSTMs')
23 | @click.option('--dropout_lstm', default=0.5,
24 | help='amount of dropout between LSTM layers')
25 | @click.option('--linear_hidden_size', default=256,
26 | help='hidden size of the Linear layers')
27 | @click.option('--batch_size', default=256,
28 | help='training batch size')
29 | @click.option('--num_epochs', default=5,
30 | help='number of training epochs')
31 | @click.option('--train', is_flag=True,
32 | help='train the specified model for num_epochs')
33 | @click.option('--num_iterations', default=500,
34 | help='number of parallel pseudo-Gibbs sampling iterations')
35 | @click.option('--sequence_length_ticks', default=64,
36 | help='length of the generated chorale (in ticks)')
37 | def main(note_embedding_dim,
38 | meta_embedding_dim,
39 | num_layers,
40 | lstm_hidden_size,
41 | dropout_lstm,
42 | linear_hidden_size,
43 | batch_size,
44 | num_epochs,
45 | train,
46 | num_iterations,
47 | sequence_length_ticks,
48 | ):
49 | dataset_manager = DatasetManager()
50 |
51 | metadatas = [
52 | FermataMetadata(),
53 | TickMetadata(subdivision=4),
54 | KeyMetadata()
55 | ]
56 | chorale_dataset_kwargs = {
57 | 'voice_ids': [0, 1, 2, 3],
58 | 'metadatas': metadatas,
59 | 'sequences_size': 8,
60 | 'subdivision': 4
61 | }
62 | bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset(
63 | name='bach_chorales',
64 | **chorale_dataset_kwargs
65 | )
66 | dataset = bach_chorales_dataset
67 |
68 | deepbach = DeepBach(
69 | dataset=dataset,
70 | note_embedding_dim=note_embedding_dim,
71 | meta_embedding_dim=meta_embedding_dim,
72 | num_layers=num_layers,
73 | lstm_hidden_size=lstm_hidden_size,
74 | dropout_lstm=dropout_lstm,
75 | linear_hidden_size=linear_hidden_size
76 | )
77 |
78 | if train:
79 | deepbach.train(batch_size=batch_size,
80 | num_epochs=num_epochs)
81 | else:
82 | deepbach.load()
83 | deepbach.cuda()
84 |
85 | print('Generation')
86 | score, tensor_chorale, tensor_metadata = deepbach.generation(
87 | num_iterations=num_iterations,
88 | sequence_length_ticks=sequence_length_ticks,
89 | )
90 | score.show('txt')
91 | score.show()
92 |
93 |
94 | if __name__ == '__main__':
95 | main()
96 |
--------------------------------------------------------------------------------
/deepBachMuseScore.qml:
--------------------------------------------------------------------------------
1 | import QtQuick 2.2
2 | import QtQuick.Dialogs 1.2
3 | import QtQuick.Controls 1.1
4 | import MuseScore 3.0
5 | import FileIO 3.0
6 |
7 | MuseScore {
8 | id: mainMuseScoreObj
9 | menuPath: "Plugins.deepBachMuseScore"
10 |     description: qsTr("This plugin calls the DeepBach server.")
11 | pluginType: "dock"
12 | dockArea: "left"
13 | property variant serverCalled: false
14 | property variant loading: false
15 | property variant linesLogged: 0
16 | property variant serverAddress: "http://localhost:5000/"
17 | FileIO {
18 | id: myFile
19 | source: "/tmp/deepbach.mxl"
20 | onError: console.log(msg)
21 | }
22 | FileIO {
23 | id: myFileXml
24 | source: "/tmp/deepbach.xml"
25 | onError: console.log(msg)
26 | }
27 | onRun: {
28 | console.log('on run called')
29 | if (mainMuseScoreObj.serverCalled !== serverAddress || modelSelector.model.length === 0) {
30 | var requestModels = getRequestObj('GET', 'models')
31 | if (call(
32 | requestModels,
33 | false,
34 | function(responseText){
35 | mainMuseScoreObj.serverCalled = serverAddress;
36 | try {
37 | modelSelector.model = JSON.parse(responseText)
38 | logStatus('Models list loaded')
39 | var requestLoadModel = getRequestObj("GET", 'current_model')
40 | call(
41 | requestLoadModel,
42 | false,
43 | function(response) {
44 | logStatus('currently loaded model is ' + response)
45 | for (var i in modelSelector.model) {
46 | if (modelSelector.model[i] === response) {
47 | console.log('set selected at ' + i)
48 | modelSelector.currentIndex = i
49 | }
50 | }
51 | }
52 | )
53 | } catch(error) {
54 | console.log(error)
55 | logStatus('No models found')
56 | }
57 |
58 | }
59 | )) {
60 | logStatus('Retrieving models list at ' + serverAddress)
61 | }
62 | }
63 | }
64 | Rectangle {
65 | id: wrapperPanel
66 | color: "white"
67 | Text {
68 | id: title
69 | text: "DeepBach"
70 | font.family: "Helvetica"
71 | font.pointSize: 20
72 | color: "black"
73 | anchors.top: wrapperPanel.top
74 | anchors.topMargin: 10
75 | font.underline: false
76 | }
77 | Button {
78 | id : loadModel
79 | anchors.top: title.bottom
80 | anchors.topMargin: 15
81 | text: qsTr("Load")
82 | onClicked: {
83 | if (modelSelector.model[modelSelector.currentIndex]) {
84 | var requestLoadModel = getRequestObj("POST", 'current_model')
85 | if (call(
86 | requestLoadModel,
87 | {
88 | model_name: modelSelector.model[modelSelector.currentIndex]
89 | },
90 | function(response) {
91 | logStatus(response)
92 | }
93 | )) {
94 | logStatus('Loading model ' + modelSelector.model[modelSelector.currentIndex])
95 | }
96 | }
97 | }
98 | }
99 | ComboBox {
100 | id: modelSelector
101 |             anchors.top: loadModel.top
102 |             anchors.left: loadModel.right
103 | anchors.leftMargin: 10
104 | model: []
105 | width: 100
106 | visible: false
107 | }
108 | Button {
109 | id : buttonOpenFile
110 | text: qsTr("Compose")
111 | anchors.top: loadModel.top
112 | anchors.topMargin: 30
113 | anchors.bottomMargin: 10
114 | onClicked: {
115 | var cursor = curScore.newCursor();
116 | cursor.rewind(1);
117 | var startStaff = cursor.staffIdx;
118 | var startTick = cursor.tick;
119 | cursor.rewind(2);
120 | var endStaff = cursor.staffIdx;
121 | var endTick = cursor.tick;
122 | var extension = 'mxl'
123 | // var targetFile = tempPath() + "/my_file." + extension
124 | myFile.remove();
125 | logStatus(myFile.source);
126 | var res = writeScore(mainMuseScoreObj.curScore, myFile.source, extension)
127 | logStatus(res);
128 | var request = getRequestObj("POST", 'compose')
129 | if (call(
130 | request,
131 | {
132 | start_staff: startStaff,
133 | end_staff: endStaff,
134 | start_tick: startTick,
135 | end_tick: endTick,
136 | // xml_string: myFile.read(),
137 | file_path: myFile.source
138 | },
139 | function(response) {
140 | if (response) {
141 | logStatus('Done composing')
142 | // myFileXml.write(response)
143 | readScore(myFileXml.source)
144 |
145 | } else {
146 | logStatus('Got Empty Response when composing')
147 | }
148 | }
149 | )) {
150 | logStatus('Composing...')
151 | }
152 | }
153 | }
154 | Label {
155 | id: statusLabel
156 | wrapMode: Text.WordWrap
157 | text: ''
158 | color: 'grey'
159 | font.pointSize:12
160 | anchors.left: wrapperPanel.left
161 | anchors.top: buttonOpenFile.top
162 | anchors.leftMargin: 10
163 | anchors.topMargin: 30
164 | visible: false
165 | }
166 | }
167 | function logStatus(text) {
168 | mainMuseScoreObj.linesLogged++;
169 | if (mainMuseScoreObj.linesLogged > 15) {
170 | // break the textblock into an array of lines
171 | var lines = statusLabel.text.split("\r\n");
172 | // remove one line, starting at the first position
173 | lines.splice(0,1);
174 | // join the array back into a single string
175 | statusLabel.text = lines.join("\r\n");
176 | }
177 | statusLabel.text += '- ' + text + "\r\n"
178 | }
179 | function getRequestObj(method, endpoint) {
180 | console.debug('calling endpoint ' + endpoint)
181 | var request = new XMLHttpRequest()
182 | endpoint = endpoint || ''
183 | request.open(method, serverAddress + endpoint, true)
184 | return request
185 | }
186 | function call(request, params, cb) {
187 | if (mainMuseScoreObj.loading) {
188 | logStatus('refusing to call server')
189 | return false
190 | }
191 | request.onreadystatechange = function() {
192 | if (request.readyState == XMLHttpRequest.DONE) {
193 | mainMuseScoreObj.loading = false;
194 | cb(request.responseText);
195 | }
196 | }
197 | if (params) {
198 | request.setRequestHeader("Content-Type", "application/x-www-form-urlencoded")
199 | var pairs = [];
200 | for (var prop in params) {
201 | if (params.hasOwnProperty(prop)) {
202 | var k = encodeURIComponent(prop),
203 | v = encodeURIComponent(params[prop]);
204 | pairs.push( k + "=" + v);
205 | }
206 | }
207 |
208 | const content = pairs.join('&')
209 | console.debug('params ' + content)
210 | mainMuseScoreObj.loading = true;
211 | request.send(content)
212 | } else {
213 | mainMuseScoreObj.loading = true;
214 | request.send()
215 | }
216 | return true
217 | }
218 | }
219 |
--------------------------------------------------------------------------------
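
The QML plugin above talks to the local server through form-encoded XMLHttpRequest calls: GET `models` and GET/POST `current_model` for model management, and POST `compose` with the selection bounds and the path of the exported MusicXML file. Below is a sketch of the same exchange from Python, assuming the companion `musescore_flask_server.py` is running on port 5000 and that `requests` is installed (it is not listed in `environment.yml`); the staff and tick values are illustrative:

```python
# Hypothetical client reproducing the requests issued by deepBachMuseScore.qml.
import requests

BASE = "http://localhost:5000/"

models = requests.get(BASE + "models").json()          # list of available models
requests.post(BASE + "current_model", data={"model_name": models[0]})

# ask the server to (re)compose a region of the exported score
response = requests.post(BASE + "compose", data={
    "start_staff": 0, "end_staff": 3,                  # staff (voice) selection
    "start_tick": 0, "end_tick": 64,                   # selection bounds in MuseScore ticks
    "file_path": "/tmp/deepbach.mxl",                  # file written by the plugin
})
print(response.status_code)  # the plugin then reloads /tmp/deepbach.xml written server-side
```
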
/dl_dataset_and_models.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget https://www.dropbox.com/s/iuz4ml857ycyyat/deepbach_pytorch_resources.tar.gz
3 | tar xvfz deepbach_pytorch_resources.tar.gz
4 | # move resources/{dataset_cache,models} to DatasetManager/dataset_cache/ and models/
5 | mv resources/dataset_cache DatasetManager/dataset_cache
6 | mv resources/models ./models
7 | rm -R resources
--------------------------------------------------------------------------------
/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source activate deepbach_pytorch
4 | LC_ALL=C.UTF-8 LANG=C.UTF-8 python flask_server.py "$@"
5 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: deepbach_pytorch
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - blas=1.0=mkl
7 | - ca-certificates=2018.03.07=0
8 | - certifi=2018.11.29=py36_0
9 | - cffi=1.11.5=py36he75722e_1
10 | - click=7.0=py36_0
11 | - freetype=2.9.1=h8a8886c_1
12 | - intel-openmp=2019.1=144
13 | - jpeg=9b=h024ee3a_2
14 | - libedit=3.1.20170329=h6b74fdf_2
15 | - libffi=3.2.1=hd88cf55_4
16 | - libgcc-ng=8.2.0=hdf63c60_1
17 | - libgfortran-ng=7.3.0=hdf63c60_0
18 | - libpng=1.6.35=hbc83047_0
19 | - libstdcxx-ng=8.2.0=hdf63c60_1
20 | - libtiff=4.0.9=he85c1e1_2
21 | - mkl=2019.1=144
22 | - mkl_fft=1.0.6=py36hd81dba3_0
23 | - mkl_random=1.0.2=py36hd81dba3_0
24 | - ncurses=6.1=he6710b0_1
25 | - ninja=1.8.2=py36h6bb024c_1
26 | - numpy=1.15.4=py36h7e9f1db_0
27 | - numpy-base=1.15.4=py36hde5b4d6_0
28 | - olefile=0.46=py36_0
29 | - openssl=1.1.1a=h7b6447c_0
30 | - pillow=5.3.0=py36h34e0f95_0
31 | - pip=18.1=py36_0
32 | - pycparser=2.19=py36_0
33 | - python=3.6.8=h0371630_0
34 | - readline=7.0=h7b6447c_5
35 | - setuptools=40.6.3=py36_0
36 | - six=1.12.0=py36_0
37 | - sqlite=3.26.0=h7b6447c_0
38 | - tk=8.6.8=hbc83047_0
39 | - wheel=0.32.3=py36_0
40 | - xz=5.2.4=h14c3975_4
41 | - zlib=1.2.11=h7b6447c_3
42 | - pytorch=1.0.0
43 | - cuda100
44 | - torchvision=0.2.1=py_2
45 | - pip:
46 | - decorator==4.3.0
47 | - flask==1.0.2
48 | - flask-cors==3.0.7
49 | - itsdangerous==1.1.0
50 | - jinja2==2.10
51 | - markupsafe==1.1.0
52 | - music21==5.5.0
53 | - torch==1.0.0
54 | - tqdm==4.29.0
55 | - werkzeug==0.14.1
56 | prefix: /home/gaetan/Public/Python/anaconda3/envs/deepbach_pytorch
57 |
58 |
--------------------------------------------------------------------------------
/flask_server.py:
--------------------------------------------------------------------------------
1 | from DatasetManager.chorale_dataset import ChoraleDataset
2 | from DatasetManager.dataset_manager import DatasetManager
3 | from DatasetManager.metadata import FermataMetadata, TickMetadata, KeyMetadata
4 | from DeepBach.model_manager import DeepBach
5 |
6 | from music21 import musicxml, metadata
7 | import music21
8 |
9 | import flask
10 | from flask import Flask, request, make_response
11 | from flask_cors import CORS
12 |
13 | import logging
14 | from logging import handlers as logging_handlers
15 | import sys
16 |
17 | import torch
18 | import math
19 | from typing import List, Optional
20 | import click
21 | import os
22 |
23 | app = Flask(__name__)
24 | CORS(app)
25 |
26 | app.config['UPLOAD_FOLDER'] = './uploads'
27 | ALLOWED_EXTENSIONS = {'midi'}
28 |
29 | # INITIALIZATION
30 | xml_response_headers = {"Content-Type": "text/xml",
31 | "charset": "utf-8"
32 | }
33 | mp3_response_headers = {"Content-Type": "audio/mpeg3"
34 | }
35 |
36 | deepbach = None
37 | _num_iterations = None
38 | _sequence_length_ticks = None
39 | _ticks_per_quarter = None
40 |
41 | # TODO use this parameter or extract it from the metadata somehow
42 | timesignature = music21.meter.TimeSignature('4/4')
43 |
44 | # generation parameters
45 | # todo put in click?
46 | batch_size_per_voice = 8
47 |
48 | metadatas = [
49 | FermataMetadata(),
50 | TickMetadata(subdivision=_ticks_per_quarter),
51 | KeyMetadata()
52 | ]
53 |
54 |
55 | def get_fermatas_tensor(metadata_tensor: torch.Tensor) -> torch.Tensor:
56 | """
57 | Extract the fermatas tensor from a metadata tensor
58 | """
59 | fermatas_index = [m.__class__ for m in metadatas].index(
60 | FermataMetadata().__class__)
61 | # fermatas are shared across all voices so we only consider the first voice
62 | soprano_voice_metadata = metadata_tensor[0]
63 |
64 | # `soprano_voice_metadata` has shape
65 |     # `(sequence_duration, len(metadatas) + 1)` (accounting for the voice
66 | # index metadata)
67 | # Extract fermatas for all steps
68 | return soprano_voice_metadata[:, fermatas_index]
69 |
70 |
71 | @click.command()
72 | @click.option('--note_embedding_dim', default=20,
73 | help='size of the note embeddings')
74 | @click.option('--meta_embedding_dim', default=20,
75 | help='size of the metadata embeddings')
76 | @click.option('--num_layers', default=2,
77 | help='number of layers of the LSTMs')
78 | @click.option('--lstm_hidden_size', default=256,
79 | help='hidden size of the LSTMs')
80 | @click.option('--dropout_lstm', default=0.5,
81 | help='amount of dropout between LSTM layers')
84 | @click.option('--linear_hidden_size', default=256,
85 | help='hidden size of the Linear layers')
86 | @click.option('--num_iterations', default=50,
87 | help='number of parallel pseudo-Gibbs sampling iterations (for a single update)')
88 | @click.option('--sequence_length_ticks', default=64,
89 | help='length of the generated chorale (in ticks)')
90 | @click.option('--ticks_per_quarter', default=4,
91 | help='number of ticks per quarter note')
92 | @click.option('--port', default=5000,
93 | help='port to serve on')
94 | def init_app(note_embedding_dim,
95 | meta_embedding_dim,
96 | num_layers,
97 | lstm_hidden_size,
98 | dropout_lstm,
99 | linear_hidden_size,
100 | num_iterations,
101 | sequence_length_ticks,
102 | ticks_per_quarter,
103 | port
104 | ):
105 | global metadatas
106 | global _sequence_length_ticks
107 | global _num_iterations
108 | global _ticks_per_quarter
109 |
110 | _ticks_per_quarter = ticks_per_quarter
111 | _sequence_length_ticks = sequence_length_ticks
112 | _num_iterations = num_iterations
113 |
114 | dataset_manager = DatasetManager()
115 | chorale_dataset_kwargs = {
116 | 'voice_ids': [0, 1, 2, 3],
117 | 'metadatas': metadatas,
118 | 'sequences_size': 8,
119 | 'subdivision': 4
120 | }
121 |
122 | bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset(
123 | name='bach_chorales',
124 | **chorale_dataset_kwargs
125 | )
126 | assert sequence_length_ticks % bach_chorales_dataset.subdivision == 0
127 |
128 | global deepbach
129 | deepbach = DeepBach(
130 | dataset=bach_chorales_dataset,
131 | note_embedding_dim=note_embedding_dim,
132 | meta_embedding_dim=meta_embedding_dim,
133 | num_layers=num_layers,
134 | lstm_hidden_size=lstm_hidden_size,
135 | dropout_lstm=dropout_lstm,
136 | linear_hidden_size=linear_hidden_size
137 | )
138 | deepbach.load()
139 | deepbach.cuda()
140 |
141 | # launch the script
142 | # use threaded=True to fix Chrome/Chromium engine hanging on requests
143 | # [https://stackoverflow.com/a/30670626]
144 | local_only = False
145 | if local_only:
146 | # accessible only locally:
147 | app.run(threaded=True)
148 | else:
149 | # accessible from outside:
150 | app.run(host='0.0.0.0', port=port, threaded=True)
151 |
152 |
153 | @app.route('/generate', methods=['GET', 'POST'])
154 | def compose():
155 | """
156 | Return a new, generated sheet
157 | Usage:
158 | - Request: empty, generation is done in an unconstrained fashion
159 | - Response: a sheet, MusicXML
160 | """
161 | global deepbach
162 | global _sequence_length_ticks
163 | global _num_iterations
164 |
165 | # Use more iterations for the initial generation step
166 | # FIXME hardcoded 4/4 time-signature
167 | num_measures_generation = math.floor(_sequence_length_ticks /
168 | deepbach.dataset.subdivision)
169 | initial_num_iterations = math.floor(_num_iterations * num_measures_generation
170 | / 3) # HACK hardcoded reduction
171 |
172 | (generated_sheet, _, generated_metadata_tensor) = (
173 | deepbach.generation(num_iterations=initial_num_iterations,
174 | sequence_length_ticks=_sequence_length_ticks)
175 | )
176 |
177 | generated_fermatas_tensor = get_fermatas_tensor(generated_metadata_tensor)
178 |
179 | # convert sheet to xml
180 | response = sheet_and_fermatas_to_json_response(
181 | generated_sheet, generated_fermatas_tensor)
182 | return response
183 |
184 |
185 | @app.route('/test-generate', methods=['GET'])
186 | def ex():
187 | _current_sheet = next(music21.corpus.chorales.Iterator())
188 | return sheet_to_xml_response(_current_sheet)
189 |
190 |
191 | @app.route('/musicxml-to-midi', methods=['POST'])
192 | def get_midi():
193 | """
194 | Convert the provided MusicXML sheet to MIDI and return it
195 | Usage:
196 | POST -d @sheet.mxml /musicxml-to-midi
197 | - Request: the payload is expected to contain the sheet to convert, in
198 | MusicXML format
199 | - Response: a MIDI file
200 | """
201 | sheetString = request.data
202 | sheet = music21.converter.parseData(sheetString, format="musicxml")
203 | insert_musicxml_metadata(sheet)
204 |
205 | return sheet_to_midi_response(sheet)
206 |
207 |
208 | @app.route('/timerange-change', methods=['POST'])
209 | def timerange_change():
210 | """
211 | Perform local re-generation on a sheet and return the updated sheet
212 | Usage:
213 | POST /timerange-change?time_range_start_beat=XXX&time_range_end_beat=XXX
214 | - Request:
215 | The payload is expected to be a JSON with the following keys:
216 | * 'sheet': a string containing the sheet to modify, in MusicXML
217 | format
218 | * 'fermatas': a list of integers describing the positions of
219 | fermatas in the sheet
220 | TODO: could store the fermatas in the MusicXML client-side
221 | The start and end positions (in beats) of the portion to regenerate
222 | are passed as arguments in the URL:
223 | * time_range_start_quarter, integer:
224 | - Response:
225 | A JSON document with same schema as the request containing the
226 | updated sheet and fermatas
227 | """
228 | global deepbach
229 | global _num_iterations
230 | global _sequence_length_ticks
231 | request_parameters = parse_timerange_request(request)
232 | time_range_start_quarter = request_parameters['time_range_start_quarter']
233 | time_range_end_quarter = request_parameters['time_range_end_quarter']
234 | fermatas_tensor = request_parameters['fermatas_tensor']
235 |
236 | input_sheet = request_parameters['sheet']
237 |
238 | time_index_range_ticks = [
239 | time_range_start_quarter * deepbach.dataset.subdivision,
240 | time_range_end_quarter * deepbach.dataset.subdivision]
241 |
242 | input_tensor_sheet, input_tensor_metadata = (
243 | deepbach.dataset.transposed_score_and_metadata_tensors(
244 | input_sheet, 0)
245 | )
246 |
247 | (output_sheet,
248 | output_tensor_sheet,
249 | output_tensor_metadata) = deepbach.generation(
250 | tensor_chorale=input_tensor_sheet,
251 | tensor_metadata=input_tensor_metadata,
252 | temperature=1.,
253 | batch_size_per_voice=batch_size_per_voice,
254 | num_iterations=_num_iterations,
255 | sequence_length_ticks=_sequence_length_ticks,
256 | time_index_range_ticks=time_index_range_ticks,
257 | fermatas=fermatas_tensor
258 | )
259 |
260 | output_fermatas_tensor = get_fermatas_tensor(output_tensor_metadata)
261 |
262 | # create JSON response
263 | response = sheet_and_fermatas_to_json_response(
264 | output_sheet, output_fermatas_tensor)
265 | return response
266 |
267 |
268 | @app.route('/analyze-notes', methods=['POST'])
269 | def dummy_read_audio_file():
270 | global deepbach
271 | import wave
272 | print(request.args)
273 | print(request.files)
274 | chunk = 1024
275 | audio_fp = wave.open(request.files['audio'], 'rb')
276 | data = audio_fp.readframes(chunk)
277 | print(data)
278 | notes = ['C', 'D', 'Toto', 'Tata']
279 |
280 | return flask.jsonify({'success': True, 'notes': notes})
281 |
282 |
283 | def insert_musicxml_metadata(sheet: music21.stream.Stream):
284 | """
285 | Insert various metadata into the provided XML document
286 | The timesignature in particular is required for proper MIDI conversion
287 | """
288 | global timesignature
289 |
290 | from music21.clef import TrebleClef, BassClef, Treble8vbClef
291 | for part, name, clef in zip(
292 | sheet.parts,
293 | ['soprano', 'alto', 'tenor', 'bass'],
294 | [TrebleClef(), TrebleClef(), Treble8vbClef(), BassClef()]
295 | ):
296 | # empty_part = part.template()
297 | part.insert(0, timesignature)
298 | part.insert(0, clef)
299 | part.id = name
300 | part.partName = name
301 |
302 | md = metadata.Metadata()
303 | sheet.insert(0, md)
304 |
305 | # required for proper musicXML formatting
306 | sheet.metadata.title = 'DeepBach'
307 | sheet.metadata.composer = 'DeepBach'
308 |
309 |
310 | def parse_fermatas(fermatas_list: List[int]) -> Optional[torch.Tensor]:
311 | """
312 | Parses fermata GET option, given at the quarter note level
313 | """
314 | global _sequence_length_ticks
315 | # the data is expected to be provided as a list in the request
316 | return fermatas_to_tensor(fermatas_list)
317 |
318 |
319 | def fermatas_to_tensor(fermatas: List[int]) -> torch.Tensor:
320 | """
321 |     Convert a list of fermata positions (in beats) into a subdivision-rate tensor
322 | """
323 | global _sequence_length_ticks
324 | global deepbach
325 | subdivision = deepbach.dataset.subdivision
326 | sequence_length_quarterNotes = math.floor(_sequence_length_ticks / subdivision)
327 |
328 | fermatas_tensor_quarterNotes = torch.zeros(sequence_length_quarterNotes)
329 | fermatas_tensor_quarterNotes[fermatas] = 1
330 | # expand the tensor to the subdivision level
331 | fermatas_tensor = (fermatas_tensor_quarterNotes
332 | .repeat((subdivision, 1))
333 | .t()
334 | .contiguous())
335 | return fermatas_tensor.view(_sequence_length_ticks)
336 |
337 |
338 | def fermatas_tensor_to_list(fermatas_tensor: torch.Tensor) -> List[int]:
339 | """
340 | Convert a binary fermatas tensor into a list of positions (in beats)
341 | """
342 | global _sequence_length_ticks
343 | global deepbach
344 |
345 | subdivision = deepbach.dataset.subdivision
346 |
347 | # subsample fermatas to beat rate
348 | beat_rate_fermatas_tensor = fermatas_tensor[::subdivision]
349 |
350 | # pick positions of active fermatas
351 | fermatas_positions_tensor = beat_rate_fermatas_tensor.nonzero().squeeze()
352 | fermatas = fermatas_positions_tensor.int().tolist()
353 |
354 | return fermatas
355 |
356 |
357 | def parse_timerange_request(request):
358 |     """
359 |     Parse a /timerange-change request: cast the URL arguments to int
360 |     and parse the JSON payload (sheet and fermatas).
361 |     :return: dict with keys 'sheet', 'time_range_start_quarter', 'time_range_end_quarter', 'fermatas_tensor'
362 |     """
363 | json_data = request.get_json(force=True)
364 | time_range_start_quarter = int(request.args.get('time_range_start_quarter'))
365 | time_range_end_quarter = int(request.args.get('time_range_end_quarter'))
366 | fermatas_tensor = parse_fermatas(json_data['fermatas'])
367 |
368 | sheet = music21.converter.parseData(json_data['sheet'], format="musicxml")
369 |
370 | return {
371 | 'sheet': sheet,
372 | 'time_range_start_quarter': time_range_start_quarter,
373 | 'time_range_end_quarter': time_range_end_quarter,
374 | 'fermatas_tensor': fermatas_tensor
375 | }
376 |
377 |
378 | def sheet_to_xml_bytes(sheet: music21.stream.Stream):
379 | """Convert a music21 sheet to a MusicXML document"""
380 | # first insert necessary MusicXML metadata
381 | insert_musicxml_metadata(sheet)
382 |
383 | sheet_to_xml_bytes = musicxml.m21ToXml.GeneralObjectExporter(sheet).parse()
384 |
385 | return sheet_to_xml_bytes
386 |
387 |
388 | def sheet_to_xml_response(sheet: music21.stream.Stream):
389 | """Generate and send XML sheet"""
390 | xml_sheet_bytes = sheet_to_xml_bytes(sheet)
391 |
392 | response = flask.make_response((xml_sheet_bytes, xml_response_headers))
393 | return response
394 |
395 |
396 | def sheet_and_fermatas_to_json_response(sheet: music21.stream.Stream,
397 | fermatas_tensor: torch.Tensor):
398 | sheet_xml_string = sheet_to_xml_bytes(sheet).decode('utf-8')
399 | fermatas_list = fermatas_tensor_to_list(fermatas_tensor)
400 |
401 | print(fermatas_list)
402 |
403 | return flask.jsonify({
404 | 'sheet': sheet_xml_string,
405 | 'fermatas': fermatas_list
406 | })
407 |
408 |
409 | def sheet_to_midi_response(sheet):
410 | """
411 | Convert the provided sheet to midi and send it as a file
412 | """
413 | midiFile = sheet.write('midi')
414 | return flask.send_file(midiFile, mimetype="audio/midi",
415 | cache_timeout=-1 # disable cache
416 | )
417 |
418 |
419 | def sheet_to_mp3_response(sheet):
420 | """Generate and send MP3 file
421 | Uses server-side `timidity`
422 | """
423 | sheet.write('midi', fp='./uploads/midi.mid')
424 | os.system(f'rm uploads/midi.mp3')
425 | os.system(f'timidity uploads/midi.mid -Ow -o - | '
426 | f'ffmpeg -i - -acodec libmp3lame -ab 64k '
427 | f'uploads/midi.mp3')
428 | return flask.send_file('uploads/midi.mp3')
429 |
430 |
431 | if __name__ == '__main__':
432 | file_handler = logging_handlers.RotatingFileHandler(
433 | 'app.log', maxBytes=10000, backupCount=5)
434 |
435 | app.logger.addHandler(file_handler)
436 | app.logger.setLevel(logging.INFO)
437 | init_app()
438 |
--------------------------------------------------------------------------------
/musescore_flask_server.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import pickle
4 | import click
5 | import tempfile
6 | from glob import glob
7 | import subprocess
8 |
9 | import music21
10 | import numpy as np
11 | from flask import Flask, request, make_response, jsonify
12 | from music21 import musicxml, converter
13 | from tqdm import tqdm
14 | import torch
15 | import logging
16 | from logging import handlers as logging_handlers
17 |
18 | from DatasetManager.chorale_dataset import ChoraleDataset
19 | from DatasetManager.dataset_manager import DatasetManager
20 | from DatasetManager.metadata import FermataMetadata, TickMetadata, KeyMetadata
21 | from DeepBach.model_manager import DeepBach
22 |
23 | UPLOAD_FOLDER = '/tmp'
24 | ALLOWED_EXTENSIONS = {'xml', 'mxl', 'mid', 'midi'}
25 |
26 | app = Flask(__name__)
27 |
28 | deepbach = None
29 | _tensor_metadata = None
30 | _num_iterations = None
31 | _sequence_length_ticks = None
32 | _ticks_per_quarter = None
33 | _tensor_sheet = None
34 |
35 | # TODO use this parameter or extract it from the metadata somehow
36 | timesignature = music21.meter.TimeSignature('4/4')
37 |
38 | # generation parameters
39 | # todo put in click?
40 | batch_size_per_voice = 8
41 |
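   | # NB: `_ticks_per_quarter` is still None at import time (it is only set in
   | # `init_app()` below), so this TickMetadata is constructed with subdivision=None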
42 | metadatas = [
43 | FermataMetadata(),
44 | TickMetadata(subdivision=_ticks_per_quarter),
45 | KeyMetadata()
46 | ]
47 |
48 | response_headers = {"Content-Type": "text/html",
49 | "charset": "utf-8"
50 | }
51 |
52 |
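   | # Example launch (requires a trained model and a CUDA device,
   | # cf. `deepbach.load()` / `deepbach.cuda()` below):
   | #   python musescore_flask_server.py --port 5000 --num_iterations 100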
53 | @click.command()
54 | @click.option('--note_embedding_dim', default=20,
55 | help='size of the note embeddings')
56 | @click.option('--meta_embedding_dim', default=20,
57 | help='size of the metadata embeddings')
58 | @click.option('--num_layers', default=2,
59 | help='number of layers of the LSTMs')
60 | @click.option('--lstm_hidden_size', default=256,
61 | help='hidden size of the LSTMs')
62 | @click.option('--dropout_lstm', default=0.5,
63 | help='amount of dropout between LSTM layers')
66 | @click.option('--linear_hidden_size', default=256,
67 | help='hidden size of the Linear layers')
68 | @click.option('--num_iterations', default=100,
69 | help='number of parallel pseudo-Gibbs sampling iterations (for a single update)')
70 | @click.option('--sequence_length_ticks', default=64,
71 | help='length of the generated chorale (in ticks)')
72 | @click.option('--ticks_per_quarter', default=4,
73 | help='number of ticks per quarter note')
74 | @click.option('--port', default=5000,
75 | help='port to serve on')
76 | def init_app(note_embedding_dim,
77 | meta_embedding_dim,
78 | num_layers,
79 | lstm_hidden_size,
80 | dropout_lstm,
81 | linear_hidden_size,
82 | num_iterations,
83 | sequence_length_ticks,
84 | ticks_per_quarter,
85 | port
86 | ):
87 | global metadatas
88 | global _sequence_length_ticks
89 | global _num_iterations
90 | global _ticks_per_quarter
91 | global bach_chorales_dataset
92 |
93 | _ticks_per_quarter = ticks_per_quarter
94 | _sequence_length_ticks = sequence_length_ticks
95 | _num_iterations = num_iterations
96 |
97 | dataset_manager = DatasetManager()
98 | chorale_dataset_kwargs = {
99 | 'voice_ids': [0, 1, 2, 3],
100 | 'metadatas': metadatas,
101 | 'sequences_size': 8,
102 | 'subdivision': 4
103 | }
104 |
105 | _bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset(
106 | name='bach_chorales',
107 | **chorale_dataset_kwargs
108 | )
109 | bach_chorales_dataset = _bach_chorales_dataset
110 |
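    |     # the generated sequence must span a whole number of quarter notes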
111 | assert sequence_length_ticks % bach_chorales_dataset.subdivision == 0
112 |
113 | global deepbach
114 | deepbach = DeepBach(
115 | dataset=bach_chorales_dataset,
116 | note_embedding_dim=note_embedding_dim,
117 | meta_embedding_dim=meta_embedding_dim,
118 | num_layers=num_layers,
119 | lstm_hidden_size=lstm_hidden_size,
120 | dropout_lstm=dropout_lstm,
121 | linear_hidden_size=linear_hidden_size
122 | )
123 | deepbach.load()
124 | deepbach.cuda()
125 |
126 | # launch the script
127 | # use threaded=True to fix Chrome/Chromium engine hanging on requests
128 | # [https://stackoverflow.com/a/30670626]
129 | local_only = False
130 | if local_only:
131 | # accessible only locally:
132 | app.run(threaded=True)
133 | else:
134 | # accessible from outside:
135 | app.run(host='0.0.0.0', port=port, threaded=True)
136 |
137 |
138 | def get_fermatas_tensor(metadata_tensor: torch.Tensor) -> torch.Tensor:
139 | """
140 | Extract the fermatas tensor from a metadata tensor
141 | """
142 | fermatas_index = [m.__class__ for m in metadatas].index(
143 | FermataMetadata().__class__)
144 | # fermatas are shared across all voices so we only consider the first voice
145 | soprano_voice_metadata = metadata_tensor[0]
146 |
147 | # `soprano_voice_metadata` has shape
148 |     # `(sequence_duration, len(metadatas) + 1)` (accounting for the voice
149 | # index metadata)
150 | # Extract fermatas for all steps
151 | return soprano_voice_metadata[:, fermatas_index]
152 |
153 |
154 | def allowed_file(filename):
155 | return '.' in filename and \
156 | filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
157 |
158 |
159 | def compose_from_scratch():
160 | """
161 | Return a new, generated sheet
162 | Usage:
163 | - Request: empty, generation is done in an unconstrained fashion
164 | - Response: a sheet, MusicXML
165 | """
166 | global deepbach
167 | global _sequence_length_ticks
168 | global _num_iterations
169 | global _tensor_sheet
170 | global _tensor_metadata
171 |
172 | # Use more iterations for the initial generation step
173 | # FIXME hardcoded 4/4 time-signature
174 | num_measures_generation = math.floor(_sequence_length_ticks /
175 | deepbach.dataset.subdivision)
176 | initial_num_iterations = math.floor(_num_iterations * num_measures_generation
177 | / 3) # HACK hardcoded reduction
178 |
179 | (generated_sheet, _tensor_sheet, _tensor_metadata) = (
180 | deepbach.generation(num_iterations=initial_num_iterations,
181 | sequence_length_ticks=_sequence_length_ticks)
182 | )
183 | return generated_sheet
184 |
185 |
186 | @app.route('/compose', methods=['POST'])
187 | def compose():
188 | global deepbach
189 | global _num_iterations
190 | global _sequence_length_ticks
191 | global _tensor_sheet
192 | global _tensor_metadata
193 | global bach_chorales_dataset
194 |
195 | # global models
196 | NUM_MIDI_TICKS_IN_SIXTEENTH_NOTE = 120
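    |     # the MuseScore plugin reports the selection in MIDI ticks
    |     # (480 per quarter note, i.e. 120 per sixteenth note); dividing converts
    |     # the selection boundaries to the dataset's sixteenth-note tick grid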
197 | start_tick_selection = int(float(
198 | request.form['start_tick']) / NUM_MIDI_TICKS_IN_SIXTEENTH_NOTE)
199 | end_tick_selection = int(
200 | float(request.form['end_tick']) / NUM_MIDI_TICKS_IN_SIXTEENTH_NOTE)
201 | file_path = request.form['file_path']
202 | root, ext = os.path.splitext(file_path)
203 | dir = os.path.dirname(file_path)
204 | assert ext == '.mxl'
205 | xml_file = f'{root}.xml'
206 |
207 | # if no selection REGENERATE and set chorale length
208 | if start_tick_selection == 0 and end_tick_selection == 0:
209 | generated_sheet = compose_from_scratch()
210 | generated_sheet.write('xml', xml_file)
211 | return sheet_to_response(generated_sheet)
212 | else:
213 | # --- Parse request---
214 | # Old method: does not work because the MuseScore plugin does not export to xml but only to compressed .mxl
215 | # with tempfile.NamedTemporaryFile(mode='wb', suffix='.xml') as file:
216 | # print(file.name)
217 | # xml_string = request.form['xml_string']
218 | # file.write(xml_string)
219 | # music21_parsed_chorale = converter.parse(file.name)
220 |
221 | # file_path points to an mxl file: we extract it
222 | subprocess.run(f'unzip -o {file_path} -d {dir}', shell=True)
223 | music21_parsed_chorale = converter.parse(xml_file)
224 |
225 |
226 | _tensor_sheet, _tensor_metadata = bach_chorales_dataset.transposed_score_and_metadata_tensors(music21_parsed_chorale, semi_tone=0)
227 |
228 | start_voice_index = int(request.form['start_staff'])
229 | end_voice_index = int(request.form['end_staff']) + 1
230 |
231 | time_index_range_ticks = [start_tick_selection, end_tick_selection]
232 |
233 | region_length = end_tick_selection - start_tick_selection
234 |
235 | # compute batch_size_per_voice:
236 | if region_length <= 8:
237 | batch_size_per_voice = 2
238 | elif region_length <= 16:
239 | batch_size_per_voice = 4
240 | else:
241 | batch_size_per_voice = 8
242 |
243 |
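    |         # scale the total number of pseudo-Gibbs updates with the size of the
    |         # selected region; larger per-voice batches resample more positions
    |         # per iteration, so proportionally fewer iterations are needed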
244 | num_total_iterations = int(_num_iterations * region_length / batch_size_per_voice)
245 |
246 | fermatas_tensor = get_fermatas_tensor(_tensor_metadata)
247 |
248 | # --- Generate---
249 | (output_sheet,
250 | _tensor_sheet,
251 | _tensor_metadata) = deepbach.generation(
252 | tensor_chorale=_tensor_sheet,
253 | tensor_metadata=_tensor_metadata,
254 | temperature=1.,
255 | batch_size_per_voice=batch_size_per_voice,
256 | num_iterations=num_total_iterations,
257 | sequence_length_ticks=_sequence_length_ticks,
258 | time_index_range_ticks=time_index_range_ticks,
259 | fermatas=fermatas_tensor,
260 | voice_index_range=[start_voice_index, end_voice_index],
261 | random_init=True
262 | )
263 |
264 |
265 |
266 | output_sheet.write('xml', xml_file)
267 | response = sheet_to_response(sheet=output_sheet)
268 | return response
269 |
270 |
287 | def sheet_to_response(sheet):
288 | # convert sheet to xml
289 | goe = musicxml.m21ToXml.GeneralObjectExporter(sheet)
290 | xml_chorale_string = goe.parse()
291 |
292 | response = make_response((xml_chorale_string, response_headers))
293 | return response
294 |
295 |
296 | @app.route('/test', methods=['POST', 'GET'])
297 | def test_generation():
298 | response = make_response(('TEST', response_headers))
299 |
300 | if request.method == 'POST':
301 | print(request)
302 |
303 | return response
304 |
305 |
306 | @app.route('/models', methods=['GET'])
307 | def get_models():
308 | models_list = ['Deepbach']
309 | return jsonify(models_list)
310 |
311 |
312 | @app.route('/current_model', methods=['POST', 'PUT'])
313 | def current_model_update():
314 | return 'Model is only loaded once'
315 |
316 |
317 | @app.route('/current_model', methods=['GET'])
318 | def current_model_get():
319 | return 'DeepBach'
320 |
321 |
322 | if __name__ == '__main__':
323 | file_handler = logging_handlers.RotatingFileHandler(
324 | 'app.log', maxBytes=10000, backupCount=5)
325 |
326 | app.logger.addHandler(file_handler)
327 | app.logger.setLevel(logging.INFO)
328 | init_app()
329 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import subprocess
4 | import tempfile
5 |
6 | import random
7 | import numpy as np
8 | import torch
9 | import click
10 | from cog import BasePredictor, Input, Path
11 | import music21
12 | from midi2audio import FluidSynth
13 |
14 | from DatasetManager.chorale_dataset import ChoraleDataset
15 | from DatasetManager.dataset_manager import DatasetManager
16 | from DatasetManager.metadata import FermataMetadata, KeyMetadata, TickMetadata
17 | from DeepBach.model_manager import DeepBach
18 |
19 |
20 | class Predictor(BasePredictor):
21 | def setup(self):
22 | """Load the model"""
23 |
24 | # music21.environment.set("musicxmlPath", "/bin/true")
25 |
26 | note_embedding_dim = 20
27 | meta_embedding_dim = 20
28 | num_layers = 2
29 | lstm_hidden_size = 256
30 | dropout_lstm = 0.5
31 | linear_hidden_size = 256
32 | batch_size = 256
33 | num_epochs = 5
34 | train = False
35 | num_iterations = 500
36 | sequence_length_ticks = 64
37 |
38 | dataset_manager = DatasetManager()
39 |
40 | metadatas = [FermataMetadata(), TickMetadata(subdivision=4), KeyMetadata()]
41 | chorale_dataset_kwargs = {
42 | "voice_ids": [0, 1, 2, 3],
43 | "metadatas": metadatas,
44 | "sequences_size": 8,
45 | "subdivision": 4,
46 | }
47 | bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset(
48 | name="bach_chorales", **chorale_dataset_kwargs
49 | )
50 | dataset = bach_chorales_dataset
51 |
52 | self.deepbach = DeepBach(
53 | dataset=dataset,
54 | note_embedding_dim=note_embedding_dim,
55 | meta_embedding_dim=meta_embedding_dim,
56 | num_layers=num_layers,
57 | lstm_hidden_size=lstm_hidden_size,
58 | dropout_lstm=dropout_lstm,
59 | linear_hidden_size=linear_hidden_size,
60 | )
61 |
62 | self.deepbach.load()
63 |
64 |         # load fluidsynth for MIDI-to-audio conversion
65 | self.fs = FluidSynth()
66 |
67 | # self.converter = music21.converter.parse('path_to_musicxml.xml')
68 |
69 | def predict(
70 | self,
71 | num_iterations: int = Input(
72 | default=500,
73 | description="Number of parallel pseudo-Gibbs sampling iterations",
74 | ),
75 | sequence_length_ticks: int = Input(
76 | default=64, ge=16, description="Length of the generated chorale (in ticks)"
77 | ),
78 | output_type: str = Input(
79 | default="audio",
80 | choices=["midi", "audio"],
81 | description="Output representation type: can be audio or midi",
82 | ),
83 | seed: int = Input(default=-1, description="Random seed, -1 for random"),
84 | ) -> Path:
85 | """Score Generation"""
86 | if seed >= 0:
87 | random.seed(seed)
88 | np.random.seed(seed)
89 | torch.use_deterministic_algorithms(True)
90 | torch.manual_seed(seed)
91 |
92 | score, tensor_chorale, tensor_metadata = self.deepbach.generation(
93 | num_iterations=num_iterations,
94 | sequence_length_ticks=sequence_length_ticks,
95 | )
96 |
97 | if output_type == "audio":
98 | output_path_wav = Path(tempfile.mkdtemp()) / "output.wav"
99 | output_path_mp3 = Path(tempfile.mkdtemp()) / "output.mp3"
100 |
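    |             # music21's `write` returns the path of the temporary MIDI file it creates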
101 | midi_score = score.write("midi")
102 | self.fs.midi_to_audio(midi_score, str(output_path_wav))
103 |
104 | subprocess.check_output(
105 | [
106 | "ffmpeg",
107 | "-i",
108 | str(output_path_wav),
109 | "-af",
110 |                     "silenceremove=1:0:-50dB,aformat=dblp,areverse,silenceremove=1:0:-50dB,aformat=dblp,areverse",  # strip leading and trailing silence
111 | str(output_path_mp3),
112 | ],
113 | )
114 |
115 | return output_path_mp3
116 |
117 | elif output_type == "midi":
118 | output_path_midi = Path(tempfile.mkdtemp()) / "output.mid"
119 | score.write("midi", fp=output_path_midi)
120 | return output_path_midi
121 |
--------------------------------------------------------------------------------