├── .gitignore ├── DatasetManager ├── __init__.py ├── chorale_dataset.py ├── dataset_manager.py ├── helpers.py ├── metadata.py └── music_dataset.py ├── DeepBach ├── __init__.py ├── data_utils.py ├── helpers.py ├── metadata.py ├── model_manager.py └── voice_model.py ├── Dockerfile ├── LICENSE ├── README.md ├── cog.yaml ├── deepBach.py ├── deepBachMuseScore.qml ├── dl_dataset_and_models.sh ├── entrypoint.sh ├── environment.yml ├── flask_server.py ├── musescore_flask_server.py └── predict.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | # some files 6 | _recomposed_by_deepBach.xml 7 | deepbach.xml 8 | container.xml 9 | .vscode/ 10 | settings.json 11 | launch.json 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *,cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | app.log.* 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # IPython Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # dotenv 87 | .env 88 | 89 | # virtualenv 90 | venv/ 91 | ENV/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # local files 100 | .idea/ 101 | stats/ 102 | deepbach_ressources/ 103 | DeepBach/models/*.yaml 104 | DeepBach/models/*.h5 105 | old_models/ 106 | good_model/ 107 | generated_examples/ 108 | DeepBach/datasets/raw_dataset/*.pickle 109 | DeepBach/datasets/custom_dataset/*.pickle 110 | 111 | *.tar.gz 112 | !download_pretrained_data.sh 113 | DatasetManager/dataset_cache/ 114 | models/ -------------------------------------------------------------------------------- /DatasetManager/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ghadjeres/DeepBach/6d75cb940f3aa53e02f9eade34d58e472e0c95d7/DatasetManager/__init__.py -------------------------------------------------------------------------------- /DatasetManager/chorale_dataset.py: -------------------------------------------------------------------------------- 1 | import music21 2 | import torch 3 | import numpy as np 4 | 5 | from music21 import interval, stream 6 | from torch.utils.data import TensorDataset 7 | from tqdm import tqdm 8 | 9 | from DatasetManager.helpers import standard_name, SLUR_SYMBOL, START_SYMBOL, END_SYMBOL, \ 10 | standard_note, OUT_OF_RANGE, REST_SYMBOL 11 | from DatasetManager.metadata import FermataMetadata 12 | from DatasetManager.music_dataset import MusicDataset 13 | 14 | 
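# ChoraleDataset encodes every chorale as an integer tensor of shape
# (num_voices, length_in_ticks) with length_in_ticks = quarterLength * subdivision;
# ticks that continue a held note are encoded with SLUR_SYMBOL and out-of-range
# pitches with OUT_OF_RANGE. The companion metadata tensor has shape
# (num_voices, length_in_ticks, num_metadatas + 1), the last channel being the
# voice index. A minimal construction sketch (the parameter values are only
# illustrative; in practice DatasetManager.get_dataset builds and caches the
# dataset for you, see dataset_manager.py):
#
#   dataset = ChoraleDataset(
#       corpus_it_gen=music21.corpus.chorales.Iterator,
#       name='bach_chorales',
#       voice_ids=[0, 1, 2, 3],
#       metadatas=[FermataMetadata()],
#       sequences_size=8,
#       subdivision=4)
#   tensor_dataset = dataset.make_tensor_dataset()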
15 | class ChoraleDataset(MusicDataset): 16 | """ 17 | Class for all chorale-like datasets 18 | """ 19 | 20 | def __init__(self, 21 | corpus_it_gen, 22 | name, 23 | voice_ids, 24 | metadatas=None, 25 | sequences_size=8, 26 | subdivision=4, 27 | cache_dir=None): 28 | """ 29 | :param corpus_it_gen: calling this function returns an iterator 30 | over chorales (as music21 scores) 31 | :param name: name of the dataset 32 | :param voice_ids: list of voice_indexes to be used 33 | :param metadatas: list[Metadata], the list of used metadatas 34 | :param sequences_size: in beats 35 | :param subdivision: number of sixteenth notes per beat 36 | :param cache_dir: directory where tensor_dataset is stored 37 | """ 38 | super(ChoraleDataset, self).__init__(cache_dir=cache_dir) 39 | self.voice_ids = voice_ids 40 | # TODO WARNING voice_ids is never used! 41 | self.num_voices = len(voice_ids) 42 | self.name = name 43 | self.sequences_size = sequences_size 44 | self.index2note_dicts = None 45 | self.note2index_dicts = None 46 | self.corpus_it_gen = corpus_it_gen 47 | self.voice_ranges = None # in midi pitch 48 | self.metadatas = metadatas 49 | self.subdivision = subdivision 50 | 51 | def __repr__(self): 52 | return f'ChoraleDataset(' \ 53 | f'{self.voice_ids},' \ 54 | f'{self.name},' \ 55 | f'{[metadata.name for metadata in self.metadatas]},' \ 56 | f'{self.sequences_size},' \ 57 | f'{self.subdivision})' 58 | 59 | def iterator_gen(self): 60 | return (chorale 61 | for chorale in self.corpus_it_gen() 62 | if self.is_valid(chorale) 63 | ) 64 | 65 | def make_tensor_dataset(self): 66 | """ 67 | Implementation of the make_tensor_dataset abstract base class 68 | """ 69 | # todo check on chorale with Chord 70 | print('Making tensor dataset') 71 | self.compute_index_dicts() 72 | self.compute_voice_ranges() 73 | one_tick = 1 / self.subdivision 74 | chorale_tensor_dataset = [] 75 | metadata_tensor_dataset = [] 76 | for chorale_id, chorale in tqdm(enumerate(self.iterator_gen())): 77 | 78 | # precompute all possible transpositions and corresponding metadatas 79 | chorale_transpositions = {} 80 | metadatas_transpositions = {} 81 | 82 | # main loop 83 | for offsetStart in np.arange( 84 | chorale.flat.lowestOffset - 85 | (self.sequences_size - one_tick), 86 | chorale.flat.highestOffset, 87 | one_tick): 88 | offsetEnd = offsetStart + self.sequences_size 89 | current_subseq_ranges = self.voice_range_in_subsequence( 90 | chorale, 91 | offsetStart=offsetStart, 92 | offsetEnd=offsetEnd) 93 | 94 | transposition = self.min_max_transposition(current_subseq_ranges) 95 | min_transposition_subsequence, max_transposition_subsequence = transposition 96 | 97 | for semi_tone in range(min_transposition_subsequence, 98 | max_transposition_subsequence + 1): 99 | start_tick = int(offsetStart * self.subdivision) 100 | end_tick = int(offsetEnd * self.subdivision) 101 | 102 | try: 103 | # compute transpositions lazily 104 | if semi_tone not in chorale_transpositions: 105 | (chorale_tensor, 106 | metadata_tensor) = self.transposed_score_and_metadata_tensors( 107 | chorale, 108 | semi_tone=semi_tone) 109 | chorale_transpositions.update( 110 | {semi_tone: 111 | chorale_tensor}) 112 | metadatas_transpositions.update( 113 | {semi_tone: 114 | metadata_tensor}) 115 | else: 116 | chorale_tensor = chorale_transpositions[semi_tone] 117 | metadata_tensor = metadatas_transpositions[semi_tone] 118 | 119 | local_chorale_tensor = self.extract_score_tensor_with_padding( 120 | chorale_tensor, 121 | start_tick, end_tick) 122 | local_metadata_tensor = 
self.extract_metadata_with_padding( 123 | metadata_tensor, 124 | start_tick, end_tick) 125 | 126 | # append and add batch dimension 127 | # cast to int 128 | chorale_tensor_dataset.append( 129 | local_chorale_tensor[None, :, :].int()) 130 | metadata_tensor_dataset.append( 131 | local_metadata_tensor[None, :, :, :].int()) 132 | except KeyError: 133 | # some problems may occur with the key analyzer 134 | print(f'KeyError with chorale {chorale_id}') 135 | 136 | chorale_tensor_dataset = torch.cat(chorale_tensor_dataset, 0) 137 | metadata_tensor_dataset = torch.cat(metadata_tensor_dataset, 0) 138 | 139 | dataset = TensorDataset(chorale_tensor_dataset, 140 | metadata_tensor_dataset) 141 | 142 | print(f'Sizes: {chorale_tensor_dataset.size()}, {metadata_tensor_dataset.size()}') 143 | return dataset 144 | 145 | def transposed_score_and_metadata_tensors(self, score, semi_tone): 146 | """ 147 | Convert chorale to a couple (chorale_tensor, metadata_tensor), 148 | the original chorale is transposed semi_tone number of semi-tones 149 | :param chorale: music21 object 150 | :param semi_tone: 151 | :return: couple of tensors 152 | """ 153 | # transpose 154 | # compute the most "natural" interval given a number of semi-tones 155 | interval_type, interval_nature = interval.convertSemitoneToSpecifierGeneric( 156 | semi_tone) 157 | transposition_interval = interval.Interval( 158 | str(interval_nature) + str(interval_type)) 159 | 160 | chorale_tranposed = score.transpose(transposition_interval) 161 | chorale_tensor = self.get_score_tensor( 162 | chorale_tranposed, 163 | offsetStart=0., 164 | offsetEnd=chorale_tranposed.flat.highestTime) 165 | metadatas_transposed = self.get_metadata_tensor(chorale_tranposed) 166 | return chorale_tensor, metadatas_transposed 167 | 168 | def get_metadata_tensor(self, score): 169 | """ 170 | Adds also the index of the voices 171 | :param score: music21 stream 172 | :return:tensor (num_voices, chorale_length, len(self.metadatas) + 1) 173 | """ 174 | md = [] 175 | if self.metadatas: 176 | for metadata in self.metadatas: 177 | sequence_metadata = torch.from_numpy( 178 | metadata.evaluate(score, self.subdivision)).long().clone() 179 | square_metadata = sequence_metadata.repeat(self.num_voices, 1) 180 | md.append( 181 | square_metadata[:, :, None] 182 | ) 183 | chorale_length = int(score.duration.quarterLength * self.subdivision) 184 | 185 | # add voice indexes 186 | voice_id_metada = torch.from_numpy(np.arange(self.num_voices)).long().clone() 187 | square_metadata = torch.transpose(voice_id_metada.repeat(chorale_length, 1), 188 | 0, 1) 189 | md.append(square_metadata[:, :, None]) 190 | 191 | all_metadata = torch.cat(md, 2) 192 | return all_metadata 193 | 194 | def set_fermatas(self, metadata_tensor, fermata_tensor): 195 | """ 196 | Impose fermatas for all chorales in a batch 197 | :param metadata_tensor: a (batch_size, sequences_size, num_metadatas) 198 | tensor 199 | :param fermata_tensor: a (sequences_size) binary tensor 200 | """ 201 | if self.metadatas: 202 | for metadata_index, metadata in enumerate(self.metadatas): 203 | if isinstance(metadata, FermataMetadata): 204 | # uses broadcasting 205 | metadata_tensor[:, :, metadata_index] = fermata_tensor 206 | break 207 | return metadata_tensor 208 | 209 | def add_fermata(self, metadata_tensor, time_index_start, time_index_stop): 210 | """ 211 | Shorthand function to impose a fermata between two time indexes 212 | """ 213 | fermata_tensor = torch.zeros(self.sequences_size) 214 | fermata_tensor[time_index_start:time_index_stop] = 1 215 
| metadata_tensor = self.set_fermatas(metadata_tensor, fermata_tensor) 216 | return metadata_tensor 217 | 218 | def min_max_transposition(self, current_subseq_ranges): 219 | if current_subseq_ranges is None: 220 | # todo might be too restrictive 221 | # there is no note in one part 222 | transposition = (0, 0) # min and max transpositions 223 | else: 224 | transpositions = [ 225 | (min_pitch_corpus - min_pitch_current, 226 | max_pitch_corpus - max_pitch_current) 227 | for ((min_pitch_corpus, max_pitch_corpus), 228 | (min_pitch_current, max_pitch_current)) 229 | in zip(self.voice_ranges, current_subseq_ranges) 230 | ] 231 | transpositions = [min_or_max_transposition 232 | for min_or_max_transposition in zip(*transpositions)] 233 | transposition = [max(transpositions[0]), 234 | min(transpositions[1])] 235 | return transposition 236 | 237 | def get_score_tensor(self, score, offsetStart, offsetEnd): 238 | chorale_tensor = [] 239 | for part_id, part in enumerate(score.parts[:self.num_voices]): 240 | part_tensor = self.part_to_tensor(part, part_id, 241 | offsetStart=offsetStart, 242 | offsetEnd=offsetEnd) 243 | chorale_tensor.append(part_tensor) 244 | return torch.cat(chorale_tensor, 0) 245 | 246 | def part_to_tensor(self, part, part_id, offsetStart, offsetEnd): 247 | """ 248 | :param part: 249 | :param part_id: 250 | :param offsetStart: 251 | :param offsetEnd: 252 | :return: torch IntTensor (1, length) 253 | """ 254 | list_notes_and_rests = list(part.flat.getElementsByOffset( 255 | offsetStart=offsetStart, 256 | offsetEnd=offsetEnd, 257 | classList=[music21.note.Note, 258 | music21.note.Rest])) 259 | list_note_strings_and_pitches = [(n.nameWithOctave, n.pitch.midi) 260 | for n in list_notes_and_rests 261 | if n.isNote] 262 | length = int((offsetEnd - offsetStart) * self.subdivision) # in ticks 263 | 264 | # add entries to dictionaries if not present 265 | # should only be called by make_dataset when transposing 266 | note2index = self.note2index_dicts[part_id] 267 | index2note = self.index2note_dicts[part_id] 268 | voice_range = self.voice_ranges[part_id] 269 | min_pitch, max_pitch = voice_range 270 | for note_name, pitch in list_note_strings_and_pitches: 271 | # if out of range 272 | if pitch < min_pitch or pitch > max_pitch: 273 | note_name = OUT_OF_RANGE 274 | 275 | if note_name not in note2index: 276 | new_index = len(note2index) 277 | index2note.update({new_index: note_name}) 278 | note2index.update({note_name: new_index}) 279 | print('Warning: Entry ' + str( 280 | {new_index: note_name}) + ' added to dictionaries') 281 | 282 | # construct sequence 283 | j = 0 284 | i = 0 285 | t = np.zeros((length, 2)) 286 | is_articulated = True 287 | num_notes = len(list_notes_and_rests) 288 | while i < length: 289 | if j < num_notes - 1: 290 | if (list_notes_and_rests[j + 1].offset > i 291 | / self.subdivision + offsetStart): 292 | t[i, :] = [note2index[standard_name(list_notes_and_rests[j], 293 | voice_range=voice_range)], 294 | is_articulated] 295 | i += 1 296 | is_articulated = False 297 | else: 298 | j += 1 299 | is_articulated = True 300 | else: 301 | t[i, :] = [note2index[standard_name(list_notes_and_rests[j], 302 | voice_range=voice_range)], 303 | is_articulated] 304 | i += 1 305 | is_articulated = False 306 | seq = t[:, 0] * t[:, 1] + (1 - t[:, 1]) * note2index[SLUR_SYMBOL] 307 | tensor = torch.from_numpy(seq).long()[None, :] 308 | return tensor 309 | 310 | def voice_range_in_subsequence(self, chorale, offsetStart, offsetEnd): 311 | """ 312 | returns None if no note present in one of the voices 
-> no transposition 313 | :param chorale: 314 | :param offsetStart: 315 | :param offsetEnd: 316 | :return: 317 | """ 318 | voice_ranges = [] 319 | for part in chorale.parts[:self.num_voices]: 320 | voice_range_part = self.voice_range_in_part(part, 321 | offsetStart=offsetStart, 322 | offsetEnd=offsetEnd) 323 | if voice_range_part is None: 324 | return None 325 | else: 326 | voice_ranges.append(voice_range_part) 327 | return voice_ranges 328 | 329 | def voice_range_in_part(self, part, offsetStart, offsetEnd): 330 | notes_in_subsequence = part.flat.getElementsByOffset( 331 | offsetStart, 332 | offsetEnd, 333 | includeEndBoundary=False, 334 | mustBeginInSpan=True, 335 | mustFinishInSpan=False, 336 | classList=[music21.note.Note, 337 | music21.note.Rest]) 338 | midi_pitches_part = [ 339 | n.pitch.midi 340 | for n in notes_in_subsequence 341 | if n.isNote 342 | ] 343 | if len(midi_pitches_part) > 0: 344 | return min(midi_pitches_part), max(midi_pitches_part) 345 | else: 346 | return None 347 | 348 | def compute_index_dicts(self): 349 | print('Computing index dicts') 350 | self.index2note_dicts = [ 351 | {} for _ in range(self.num_voices) 352 | ] 353 | self.note2index_dicts = [ 354 | {} for _ in range(self.num_voices) 355 | ] 356 | 357 | # create and add additional symbols 358 | note_sets = [set() for _ in range(self.num_voices)] 359 | for note_set in note_sets: 360 | note_set.add(SLUR_SYMBOL) 361 | note_set.add(START_SYMBOL) 362 | note_set.add(END_SYMBOL) 363 | note_set.add(REST_SYMBOL) 364 | 365 | # get all notes: used for computing pitch ranges 366 | for chorale in tqdm(self.iterator_gen()): 367 | for part_id, part in enumerate(chorale.parts[:self.num_voices]): 368 | for n in part.flat.notesAndRests: 369 | note_sets[part_id].add(standard_name(n)) 370 | 371 | # create tables 372 | for note_set, index2note, note2index in zip(note_sets, 373 | self.index2note_dicts, 374 | self.note2index_dicts): 375 | for note_index, note in enumerate(note_set): 376 | index2note.update({note_index: note}) 377 | note2index.update({note: note_index}) 378 | 379 | def is_valid(self, chorale): 380 | # We only consider 4-part chorales 381 | if not len(chorale.parts) == 4: 382 | return False 383 | # todo contains chord 384 | return True 385 | 386 | def compute_voice_ranges(self): 387 | assert self.index2note_dicts is not None 388 | assert self.note2index_dicts is not None 389 | self.voice_ranges = [] 390 | print('Computing voice ranges') 391 | for voice_index, note2index in tqdm(enumerate(self.note2index_dicts)): 392 | notes = [ 393 | standard_note(note_string) 394 | for note_string in note2index 395 | ] 396 | midi_pitches = [ 397 | n.pitch.midi 398 | for n in notes 399 | if n.isNote 400 | ] 401 | min_midi, max_midi = min(midi_pitches), max(midi_pitches) 402 | self.voice_ranges.append((min_midi, max_midi)) 403 | 404 | def extract_score_tensor_with_padding(self, tensor_score, start_tick, end_tick): 405 | """ 406 | :param tensor_chorale: (num_voices, length in ticks) 407 | :param start_tick: 408 | :param end_tick: 409 | :return: tensor_chorale[:, start_tick: end_tick] 410 | with padding if necessary 411 | i.e. 
if start_tick < 0 or end_tick > tensor_chorale length 412 | """ 413 | assert start_tick < end_tick 414 | assert end_tick > 0 415 | length = tensor_score.size()[1] 416 | 417 | padded_chorale = [] 418 | # todo add PAD_SYMBOL 419 | if start_tick < 0: 420 | start_symbols = np.array([note2index[START_SYMBOL] 421 | for note2index in self.note2index_dicts]) 422 | start_symbols = torch.from_numpy(start_symbols).long().clone() 423 | start_symbols = start_symbols.repeat(-start_tick, 1).transpose(0, 1) 424 | padded_chorale.append(start_symbols) 425 | 426 | slice_start = start_tick if start_tick > 0 else 0 427 | slice_end = end_tick if end_tick < length else length 428 | 429 | padded_chorale.append(tensor_score[:, slice_start: slice_end]) 430 | 431 | if end_tick > length: 432 | end_symbols = np.array([note2index[END_SYMBOL] 433 | for note2index in self.note2index_dicts]) 434 | end_symbols = torch.from_numpy(end_symbols).long().clone() 435 | end_symbols = end_symbols.repeat(end_tick - length, 1).transpose(0, 1) 436 | padded_chorale.append(end_symbols) 437 | 438 | padded_chorale = torch.cat(padded_chorale, 1) 439 | return padded_chorale 440 | 441 | def extract_metadata_with_padding(self, tensor_metadata, 442 | start_tick, end_tick): 443 | """ 444 | :param tensor_metadata: (num_voices, length, num_metadatas) 445 | last metadata is the voice_index 446 | :param start_tick: 447 | :param end_tick: 448 | :return: 449 | """ 450 | assert start_tick < end_tick 451 | assert end_tick > 0 452 | num_voices, length, num_metadatas = tensor_metadata.size() 453 | padded_tensor_metadata = [] 454 | 455 | if start_tick < 0: 456 | # TODO more subtle padding 457 | start_symbols = np.zeros((self.num_voices, -start_tick, num_metadatas)) 458 | start_symbols = torch.from_numpy(start_symbols).long().clone() 459 | padded_tensor_metadata.append(start_symbols) 460 | 461 | slice_start = start_tick if start_tick > 0 else 0 462 | slice_end = end_tick if end_tick < length else length 463 | padded_tensor_metadata.append(tensor_metadata[:, slice_start: slice_end, :]) 464 | 465 | if end_tick > length: 466 | end_symbols = np.zeros((self.num_voices, end_tick - length, num_metadatas)) 467 | end_symbols = torch.from_numpy(end_symbols).long().clone() 468 | padded_tensor_metadata.append(end_symbols) 469 | 470 | padded_tensor_metadata = torch.cat(padded_tensor_metadata, 1) 471 | return padded_tensor_metadata 472 | 473 | def empty_score_tensor(self, score_length): 474 | start_symbols = np.array([note2index[START_SYMBOL] 475 | for note2index in self.note2index_dicts]) 476 | start_symbols = torch.from_numpy(start_symbols).long().clone() 477 | start_symbols = start_symbols.repeat(score_length, 1).transpose(0, 1) 478 | return start_symbols 479 | 480 | def random_score_tensor(self, score_length): 481 | chorale_tensor = np.array( 482 | [np.random.randint(len(note2index), 483 | size=score_length) 484 | for note2index in self.note2index_dicts]) 485 | chorale_tensor = torch.from_numpy(chorale_tensor).long().clone() 486 | return chorale_tensor 487 | 488 | def tensor_to_score(self, tensor_score, 489 | fermata_tensor=None): 490 | """ 491 | :param tensor_score: (num_voices, length) 492 | :return: music21 score object 493 | """ 494 | slur_indexes = [note2index[SLUR_SYMBOL] 495 | for note2index in self.note2index_dicts] 496 | 497 | score = music21.stream.Score() 498 | num_voices = tensor_score.size(0) 499 | name_parts = (num_voices == 4) 500 | part_names = ['Soprano', 'Alto', 'Tenor', 'Bass'] 501 | 502 | 503 | for voice_index, (voice, index2note, slur_index) in 
enumerate( 504 | zip(tensor_score, 505 | self.index2note_dicts, 506 | slur_indexes)): 507 | add_fermata = False 508 | if name_parts: 509 | part = stream.Part(id=part_names[voice_index], 510 | partName=part_names[voice_index], 511 | partAbbreviation=part_names[voice_index], 512 | instrumentName=part_names[voice_index]) 513 | else: 514 | part = stream.Part(id='part' + str(voice_index)) 515 | dur = 0 516 | total_duration = 0 517 | f = music21.note.Rest() 518 | for note_index in [n.item() for n in voice]: 519 | # if it is a played note 520 | if not note_index == slur_indexes[voice_index]: 521 | # add previous note 522 | if dur > 0: 523 | f.duration = music21.duration.Duration(dur / self.subdivision) 524 | 525 | if add_fermata: 526 | f.expressions.append(music21.expressions.Fermata()) 527 | add_fermata = False 528 | 529 | part.append(f) 530 | 531 | 532 | dur = 1 533 | f = standard_note(index2note[note_index]) 534 | if fermata_tensor is not None and voice_index == 0: 535 | if fermata_tensor[0, total_duration] == 1: 536 | add_fermata = True 537 | else: 538 | add_fermata = False 539 | total_duration += 1 540 | 541 | else: 542 | dur += 1 543 | total_duration += 1 544 | # add last note 545 | f.duration = music21.duration.Duration(dur / self.subdivision) 546 | if add_fermata: 547 | f.expressions.append(music21.expressions.Fermata()) 548 | add_fermata = False 549 | 550 | part.append(f) 551 | score.insert(part) 552 | return score 553 | 554 | 555 | # TODO should go in ChoraleDataset 556 | # TODO all subsequences start on a beat 557 | class ChoraleBeatsDataset(ChoraleDataset): 558 | def __repr__(self): 559 | return f'ChoraleBeatsDataset(' \ 560 | f'{self.voice_ids},' \ 561 | f'{self.name},' \ 562 | f'{[metadata.name for metadata in self.metadatas]},' \ 563 | f'{self.sequences_size},' \ 564 | f'{self.subdivision})' 565 | 566 | def make_tensor_dataset(self): 567 | """ 568 | Implementation of the make_tensor_dataset abstract base class 569 | """ 570 | # todo check on chorale with Chord 571 | print('Making tensor dataset') 572 | self.compute_index_dicts() 573 | self.compute_voice_ranges() 574 | one_beat = 1. 
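        # NB: this method mirrors ChoraleDataset.make_tensor_dataset above; the only
        # difference is that the sliding window below advances by one beat
        # (1. quarter note) instead of one tick (1 / subdivision), so every
        # extracted subsequence starts on a beat.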
575 | chorale_tensor_dataset = [] 576 | metadata_tensor_dataset = [] 577 | for chorale_id, chorale in tqdm(enumerate(self.iterator_gen())): 578 | 579 | # precompute all possible transpositions and corresponding metadatas 580 | chorale_transpositions = {} 581 | metadatas_transpositions = {} 582 | 583 | # main loop 584 | for offsetStart in np.arange( 585 | chorale.flat.lowestOffset - 586 | (self.sequences_size - one_beat), 587 | chorale.flat.highestOffset, 588 | one_beat): 589 | offsetEnd = offsetStart + self.sequences_size 590 | current_subseq_ranges = self.voice_range_in_subsequence( 591 | chorale, 592 | offsetStart=offsetStart, 593 | offsetEnd=offsetEnd) 594 | 595 | transposition = self.min_max_transposition(current_subseq_ranges) 596 | min_transposition_subsequence, max_transposition_subsequence = transposition 597 | 598 | for semi_tone in range(min_transposition_subsequence, 599 | max_transposition_subsequence + 1): 600 | start_tick = int(offsetStart * self.subdivision) 601 | end_tick = int(offsetEnd * self.subdivision) 602 | 603 | try: 604 | # compute transpositions lazily 605 | if semi_tone not in chorale_transpositions: 606 | (chorale_tensor, 607 | metadata_tensor) = self.transposed_score_and_metadata_tensors( 608 | chorale, 609 | semi_tone=semi_tone) 610 | chorale_transpositions.update( 611 | {semi_tone: 612 | chorale_tensor}) 613 | metadatas_transpositions.update( 614 | {semi_tone: 615 | metadata_tensor}) 616 | else: 617 | chorale_tensor = chorale_transpositions[semi_tone] 618 | metadata_tensor = metadatas_transpositions[semi_tone] 619 | 620 | local_chorale_tensor = self.extract_score_tensor_with_padding( 621 | chorale_tensor, 622 | start_tick, end_tick) 623 | local_metadata_tensor = self.extract_metadata_with_padding( 624 | metadata_tensor, 625 | start_tick, end_tick) 626 | 627 | # append and add batch dimension 628 | # cast to int 629 | chorale_tensor_dataset.append( 630 | local_chorale_tensor[None, :, :].int()) 631 | metadata_tensor_dataset.append( 632 | local_metadata_tensor[None, :, :, :].int()) 633 | except KeyError: 634 | # some problems may occur with the key analyzer 635 | print(f'KeyError with chorale {chorale_id}') 636 | 637 | chorale_tensor_dataset = torch.cat(chorale_tensor_dataset, 0) 638 | metadata_tensor_dataset = torch.cat(metadata_tensor_dataset, 0) 639 | 640 | dataset = TensorDataset(chorale_tensor_dataset, 641 | metadata_tensor_dataset) 642 | 643 | print(f'Sizes: {chorale_tensor_dataset.size()}, {metadata_tensor_dataset.size()}') 644 | return dataset 645 | -------------------------------------------------------------------------------- /DatasetManager/dataset_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import music21 4 | import torch 5 | from DatasetManager.chorale_dataset import ChoraleDataset 6 | from DatasetManager.helpers import ShortChoraleIteratorGen 7 | from DatasetManager.metadata import TickMetadata, \ 8 | FermataMetadata, \ 9 | KeyMetadata 10 | from DatasetManager.music_dataset import MusicDataset 11 | 12 | # Basically, all you have to do to use an existing dataset is to 13 | # add an entry in the all_datasets variable 14 | # and specify its base class and which music21 objects it uses 15 | # by giving an iterator over music21 scores 16 | 17 | all_datasets = { 18 | 'bach_chorales': 19 | { 20 | 'dataset_class_name': ChoraleDataset, 21 | 'corpus_it_gen': music21.corpus.chorales.Iterator 22 | }, 23 | 'bach_chorales_test': 24 | { 25 | 'dataset_class_name': ChoraleDataset, 26 | 
'corpus_it_gen': ShortChoraleIteratorGen() 27 | }, 28 | } 29 | 30 | 31 | class DatasetManager: 32 | def __init__(self): 33 | self.package_dir = os.path.dirname(os.path.realpath(__file__)) 34 | self.cache_dir = os.path.join(self.package_dir, 35 | 'dataset_cache') 36 | # create cache dir if it doesn't exist 37 | if not os.path.exists(self.cache_dir): 38 | os.mkdir(self.cache_dir) 39 | 40 | def get_dataset(self, name: str, **dataset_kwargs) -> MusicDataset: 41 | if name in all_datasets: 42 | return self.load_if_exists_or_initialize_and_save( 43 | name=name, 44 | **all_datasets[name], 45 | **dataset_kwargs 46 | ) 47 | else: 48 | print('Dataset with name {name} is not registered in all_datasets variable') 49 | raise ValueError 50 | 51 | def load_if_exists_or_initialize_and_save(self, 52 | dataset_class_name, 53 | corpus_it_gen, 54 | name, 55 | **kwargs): 56 | """ 57 | 58 | :param dataset_class_name: 59 | :param corpus_it_gen: 60 | :param name: 61 | :param kwargs: parameters specific to an implementation 62 | of MusicDataset (ChoraleDataset for instance) 63 | :return: 64 | """ 65 | kwargs.update( 66 | {'name': name, 67 | 'corpus_it_gen': corpus_it_gen, 68 | 'cache_dir': self.cache_dir 69 | }) 70 | dataset = dataset_class_name(**kwargs) 71 | if os.path.exists(dataset.filepath): 72 | print(f'Loading {dataset.__repr__()} from {dataset.filepath}') 73 | dataset = torch.load(dataset.filepath) 74 | dataset.cache_dir = self.cache_dir 75 | print(f'(the corresponding TensorDataset is not loaded)') 76 | else: 77 | print(f'Creating {dataset.__repr__()}, ' 78 | f'both tensor dataset and parameters') 79 | # initialize and force the computation of the tensor_dataset 80 | # first remove the cached data if it exists 81 | if os.path.exists(dataset.tensor_dataset_filepath): 82 | os.remove(dataset.tensor_dataset_filepath) 83 | # recompute dataset parameters and tensor_dataset 84 | # this saves the tensor_dataset in dataset.tensor_dataset_filepath 85 | tensor_dataset = dataset.tensor_dataset 86 | # save all dataset parameters EXCEPT the tensor dataset 87 | # which is stored elsewhere 88 | dataset.tensor_dataset = None 89 | torch.save(dataset, dataset.filepath) 90 | print(f'{dataset.__repr__()} saved in {dataset.filepath}') 91 | dataset.tensor_dataset = tensor_dataset 92 | return dataset 93 | 94 | 95 | if __name__ == '__main__': 96 | # Usage example 97 | 98 | dataset_manager = DatasetManager() 99 | subdivision = 4 100 | metadatas = [ 101 | TickMetadata(subdivision=subdivision), 102 | FermataMetadata(), 103 | KeyMetadata() 104 | ] 105 | 106 | bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset( 107 | name='bach_chorales_test', 108 | voice_ids=[0, 1, 2, 3], 109 | metadatas=metadatas, 110 | sequences_size=8, 111 | subdivision=subdivision 112 | ) 113 | (train_dataloader, 114 | val_dataloader, 115 | test_dataloader) = bach_chorales_dataset.data_loaders( 116 | batch_size=128, 117 | split=(0.85, 0.10) 118 | ) 119 | print('Num Train Batches: ', len(train_dataloader)) 120 | print('Num Valid Batches: ', len(val_dataloader)) 121 | print('Num Test Batches: ', len(test_dataloader)) 122 | -------------------------------------------------------------------------------- /DatasetManager/helpers.py: -------------------------------------------------------------------------------- 1 | import music21 2 | from itertools import islice 3 | 4 | from music21 import note, harmony, expressions 5 | 6 | # constants 7 | SLUR_SYMBOL = '__' 8 | START_SYMBOL = 'START' 9 | END_SYMBOL = 'END' 10 | REST_SYMBOL = 'rest' 11 | OUT_OF_RANGE 
= 'OOR' 12 | PAD_SYMBOL = 'XX' 13 | 14 | 15 | def standard_name(note_or_rest, voice_range=None): 16 | """ 17 | Convert music21 objects to str 18 | :param note_or_rest: 19 | :return: 20 | """ 21 | if isinstance(note_or_rest, note.Note): 22 | if voice_range is not None: 23 | min_pitch, max_pitch = voice_range 24 | pitch = note_or_rest.pitch.midi 25 | if pitch < min_pitch or pitch > max_pitch: 26 | return OUT_OF_RANGE 27 | return note_or_rest.nameWithOctave 28 | if isinstance(note_or_rest, note.Rest): 29 | return note_or_rest.name # == 'rest' := REST_SYMBOL 30 | if isinstance(note_or_rest, str): 31 | return note_or_rest 32 | 33 | if isinstance(note_or_rest, harmony.ChordSymbol): 34 | return note_or_rest.figure 35 | if isinstance(note_or_rest, expressions.TextExpression): 36 | return note_or_rest.content 37 | 38 | 39 | def standard_note(note_or_rest_string): 40 | """ 41 | Convert str representing a music21 object to this object 42 | :param note_or_rest_string: 43 | :return: 44 | """ 45 | if note_or_rest_string == 'rest': 46 | return note.Rest() 47 | # treat other additional symbols as rests 48 | elif (note_or_rest_string == END_SYMBOL 49 | or 50 | note_or_rest_string == START_SYMBOL 51 | or 52 | note_or_rest_string == PAD_SYMBOL): 53 | # print('Warning: Special symbol is used in standard_note') 54 | return note.Rest() 55 | elif note_or_rest_string == SLUR_SYMBOL: 56 | # print('Warning: SLUR_SYMBOL used in standard_note') 57 | return note.Rest() 58 | elif note_or_rest_string == OUT_OF_RANGE: 59 | # print('Warning: OUT_OF_RANGE used in standard_note') 60 | return note.Rest() 61 | else: 62 | return note.Note(note_or_rest_string) 63 | 64 | 65 | class ShortChoraleIteratorGen: 66 | """ 67 | Class used for debugging 68 | when called, it returns an iterator over 3 Bach chorales, 69 | similar to music21.corpus.chorales.Iterator() 70 | """ 71 | 72 | def __init__(self): 73 | pass 74 | 75 | def __call__(self): 76 | it = ( 77 | chorale 78 | for chorale in 79 | islice(music21.corpus.chorales.Iterator(), 3) 80 | ) 81 | return it.__iter__() 82 | -------------------------------------------------------------------------------- /DatasetManager/metadata.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metadata classes 3 | """ 4 | import numpy as np 5 | from music21 import analysis, stream, meter 6 | from DatasetManager.helpers import SLUR_SYMBOL, \ 7 | PAD_SYMBOL 8 | 9 | 10 | class Metadata: 11 | def __init__(self): 12 | self.num_values = None 13 | self.is_global = None 14 | self.name = None 15 | 16 | def get_index(self, value): 17 | # trick with the 0 value 18 | raise NotImplementedError 19 | 20 | def get_value(self, index): 21 | raise NotImplementedError 22 | 23 | def evaluate(self, chorale, subdivision): 24 | """ 25 | takes a music21 chorale as input and the number of subdivisions per beat 26 | """ 27 | raise NotImplementedError 28 | 29 | def generate(self, length): 30 | raise NotImplementedError 31 | 32 | 33 | class IsPlayingMetadata(Metadata): 34 | def __init__(self, voice_index, min_num_ticks): 35 | """ 36 | Metadata that indicates if a voice is playing 37 | Voice i is considered to be muted if more than 'min_num_ticks' contiguous 38 | ticks contain a rest. 
39 | 40 | 41 | :param voice_index: index of the voice to take into account 42 | :param min_num_ticks: minimum length in ticks for a rest to be taken 43 | into account in the metadata 44 | """ 45 | super(IsPlayingMetadata, self).__init__() 46 | self.min_num_ticks = min_num_ticks 47 | self.voice_index = voice_index 48 | self.is_global = False 49 | self.num_values = 2 50 | self.name = 'isplaying' 51 | 52 | def get_index(self, value): 53 | return int(value) 54 | 55 | def get_value(self, index): 56 | return bool(index) 57 | 58 | def evaluate(self, chorale, subdivision): 59 | """ 60 | takes a music21 chorale as input 61 | """ 62 | length = int(chorale.duration.quarterLength * subdivision) 63 | metadatas = np.ones(shape=(length,)) 64 | part = chorale.parts[self.voice_index] 65 | 66 | for note_or_rest in part.notesAndRests: 67 | is_playing = True 68 | if note_or_rest.isRest: 69 | if note_or_rest.quarterLength * subdivision >= self.min_num_ticks: 70 | is_playing = False 71 | # these should be integer values 72 | start_tick = note_or_rest.offset * subdivision 73 | end_tick = start_tick + note_or_rest.quarterLength * subdivision 74 | metadatas[start_tick:end_tick] = self.get_index(is_playing) 75 | return metadatas 76 | 77 | def generate(self, length): 78 | return np.ones(shape=(length,)) 79 | 80 | 81 | class TickMetadata(Metadata): 82 | """ 83 | Metadata class that tracks on which subdivision of the beat we are on 84 | """ 85 | 86 | def __init__(self, subdivision): 87 | super(TickMetadata, self).__init__() 88 | self.is_global = False 89 | self.num_values = subdivision 90 | self.name = 'tick' 91 | 92 | def get_index(self, value): 93 | return value 94 | 95 | def get_value(self, index): 96 | return index 97 | 98 | def evaluate(self, chorale, subdivision): 99 | assert subdivision == self.num_values 100 | # suppose all pieces start on a beat 101 | length = int(chorale.duration.quarterLength * subdivision) 102 | return np.array(list(map( 103 | lambda x: x % self.num_values, 104 | range(length) 105 | ))) 106 | 107 | def generate(self, length): 108 | return np.array(list(map( 109 | lambda x: x % self.num_values, 110 | range(length) 111 | ))) 112 | 113 | 114 | class ModeMetadata(Metadata): 115 | """ 116 | Metadata class that indicates the current mode of the melody 117 | can be major, minor or other 118 | """ 119 | 120 | def __init__(self): 121 | super(ModeMetadata, self).__init__() 122 | self.is_global = False 123 | self.num_values = 3 # major, minor or other 124 | self.name = 'mode' 125 | 126 | def get_index(self, value): 127 | if value == 'major': 128 | return 1 129 | if value == 'minor': 130 | return 2 131 | return 0 132 | 133 | def get_value(self, index): 134 | if index == 1: 135 | return 'major' 136 | if index == 2: 137 | return 'minor' 138 | return 'other' 139 | 140 | def evaluate(self, chorale, subdivision): 141 | # todo add measures when in midi 142 | # init key analyzer 143 | ka = analysis.floatingKey.KeyAnalyzer(chorale) 144 | res = ka.run() 145 | 146 | measure_offset_map = chorale.parts[0].measureOffsetMap() 147 | length = int(chorale.duration.quarterLength * subdivision) # in 16th notes 148 | 149 | modes = np.zeros((length,)) 150 | 151 | measure_index = -1 152 | for time_index in range(length): 153 | beat_index = time_index / subdivision 154 | if beat_index in measure_offset_map: 155 | measure_index += 1 156 | modes[time_index] = self.get_index(res[measure_index].mode) 157 | 158 | return np.array(modes, dtype=np.int32) 159 | 160 | def generate(self, length): 161 | return np.full((length,), 
self.get_index('major')) 162 | 163 | 164 | class KeyMetadata(Metadata): 165 | """ 166 | Metadata class that indicates in which key we are 167 | Only returns the number of sharps or flats 168 | Does not distinguish a key from its relative key 169 | """ 170 | 171 | def __init__(self, window_size=4): 172 | super(KeyMetadata, self).__init__() 173 | self.window_size = window_size 174 | self.is_global = False 175 | self.num_max_sharps = 7 176 | self.num_values = 16 177 | self.name = 'key' 178 | 179 | def get_index(self, value): 180 | """ 181 | 182 | :param value: number of sharps (between -7 and +7) 183 | :return: index in the representation 184 | """ 185 | return value + self.num_max_sharps + 1 186 | 187 | def get_value(self, index): 188 | """ 189 | 190 | :param index: index (between 0 and self.num_values); 0 is unused (no constraint) 191 | :return: true number of sharps (between -7 and 7) 192 | """ 193 | return index - 1 - self.num_max_sharps 194 | 195 | def evaluate(self, chorale, subdivision): 196 | # init key analyzer 197 | # we must add measures by hand for the case when we are parsing midi files 198 | chorale_with_measures = stream.Score() 199 | for part in chorale.parts: 200 | chorale_with_measures.append(part.makeMeasures()) 201 | 202 | ka = analysis.floatingKey.KeyAnalyzer(chorale_with_measures) 203 | ka.windowSize = self.window_size 204 | res = ka.run() 205 | 206 | measure_offset_map = chorale_with_measures.parts.measureOffsetMap() 207 | length = int(chorale.duration.quarterLength * subdivision) # in 16th notes 208 | 209 | key_signatures = np.zeros((length,)) 210 | 211 | measure_index = -1 212 | for time_index in range(length): 213 | beat_index = time_index / subdivision 214 | if beat_index in measure_offset_map: 215 | measure_index += 1 216 | if measure_index == len(res): 217 | measure_index -= 1 218 | 219 | key_signatures[time_index] = self.get_index(res[measure_index].sharps) 220 | return np.array(key_signatures, dtype=np.int32) 221 | 222 | def generate(self, length): 223 | return np.full((length,), self.get_index(0)) 224 | 225 | 226 | class FermataMetadata(Metadata): 227 | """ 228 | Metadata class which indicates if a fermata is on the current note 229 | """ 230 | 231 | def __init__(self): 232 | super(FermataMetadata, self).__init__() 233 | self.is_global = False 234 | self.num_values = 2 235 | self.name = 'fermata' 236 | 237 | def get_index(self, value): 238 | # possible values are 1 and 0, thus value = index 239 | return value 240 | 241 | def get_value(self, index): 242 | # possible values are 1 and 0, thus value = index 243 | return index 244 | 245 | def evaluate(self, chorale, subdivision): 246 | part = chorale.parts[0] 247 | length = int(part.duration.quarterLength * subdivision) # in 16th notes 248 | list_notes = part.flat.notes 249 | num_notes = len(list_notes) 250 | j = 0 251 | i = 0 252 | fermatas = np.zeros((length,)) 253 | while i < length: 254 | if j < num_notes - 1: 255 | if list_notes[j + 1].offset > i / subdivision: 256 | 257 | if len(list_notes[j].expressions) == 1: 258 | fermata = True 259 | else: 260 | fermata = False 261 | fermatas[i] = fermata 262 | i += 1 263 | else: 264 | j += 1 265 | else: 266 | if len(list_notes[j].expressions) == 1: 267 | fermata = True 268 | else: 269 | fermata = False 270 | 271 | fermatas[i] = fermata 272 | i += 1 273 | return np.array(fermatas, dtype=np.int32) 274 | 275 | def generate(self, length): 276 | # fermata every 2 bars 277 | return np.array([1 if i % 32 >= 28 else 0 278 | for i in range(length)]) 279 | 
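All the metadata classes above implement the same small interface: evaluate(chorale, subdivision) returns one integer index per tick (with length = quarterLength * subdivision), so that ChoraleDataset.get_metadata_tensor can stack the result for each voice, and generate(length) supplies default values at generation time. The sketch below shows what a new metadata could look like; the ConstantMetadata class and its behaviour are invented for illustration and are not part of the repository:

import numpy as np

from DatasetManager.metadata import Metadata


class ConstantMetadata(Metadata):
    """
    Hypothetical example of the Metadata interface:
    tags every tick with the same index (here 0)
    """

    def __init__(self):
        super(ConstantMetadata, self).__init__()
        self.is_global = False
        self.num_values = 1
        self.name = 'constant'

    def get_index(self, value):
        return value

    def get_value(self, index):
        return index

    def evaluate(self, chorale, subdivision):
        length = int(chorale.duration.quarterLength * subdivision)
        return np.zeros((length,), dtype=np.int32)

    def generate(self, length):
        return np.zeros((length,), dtype=np.int32)

An instance of such a class only needs to be appended to the metadatas list passed to the dataset (see the usage example at the bottom of dataset_manager.py).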
-------------------------------------------------------------------------------- /DatasetManager/music_dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import os 3 | from torch.utils.data import TensorDataset, DataLoader 4 | import torch 5 | 6 | 7 | class MusicDataset(ABC): 8 | """ 9 | Abstract Base Class for music datasets 10 | """ 11 | 12 | def __init__(self, cache_dir): 13 | self._tensor_dataset = None 14 | self.cache_dir = cache_dir 15 | 16 | @abstractmethod 17 | def iterator_gen(self): 18 | """ 19 | 20 | return: Iterator over the dataset 21 | """ 22 | pass 23 | 24 | @abstractmethod 25 | def make_tensor_dataset(self): 26 | """ 27 | 28 | :return: TensorDataset 29 | """ 30 | pass 31 | 32 | @abstractmethod 33 | def get_score_tensor(self, score): 34 | """ 35 | 36 | :param score: music21 score object 37 | :return: torch tensor, with the score representation 38 | as a tensor 39 | """ 40 | pass 41 | 42 | @abstractmethod 43 | def get_metadata_tensor(self, score): 44 | """ 45 | 46 | :param score: music21 score object 47 | :return: torch tensor, with the metadata representation 48 | as a tensor 49 | """ 50 | pass 51 | 52 | @abstractmethod 53 | def transposed_score_and_metadata_tensors(self, score, semi_tone): 54 | """ 55 | 56 | :param score: music21 score object 57 | :param semi-tone: int, +12 to -12, semitones to transpose 58 | :return: Transposed score shifted by the semi-tone 59 | """ 60 | pass 61 | 62 | @abstractmethod 63 | def extract_score_tensor_with_padding(self, 64 | tensor_score, 65 | start_tick, 66 | end_tick): 67 | """ 68 | 69 | :param tensor_score: torch tensor containing the score representation 70 | :param start_tick: 71 | :param end_tick: 72 | :return: tensor_score[:, start_tick: end_tick] 73 | with padding if necessary 74 | i.e. 
if start_tick < 0 or end_tick > tensor_score length 75 | """ 76 | pass 77 | 78 | @abstractmethod 79 | def extract_metadata_with_padding(self, 80 | tensor_metadata, 81 | start_tick, 82 | end_tick): 83 | """ 84 | 85 | :param tensor_metadata: torch tensor containing metadata 86 | :param start_tick: 87 | :param end_tick: 88 | :return: 89 | """ 90 | pass 91 | 92 | @abstractmethod 93 | def empty_score_tensor(self, score_length): 94 | """ 95 | 96 | :param score_length: int, length of the score in ticks 97 | :return: torch long tensor, initialized with start indices 98 | """ 99 | pass 100 | 101 | @abstractmethod 102 | def random_score_tensor(self, score_length): 103 | """ 104 | 105 | :param score_length: int, length of the score in ticks 106 | :return: torch long tensor, initialized with random indices 107 | """ 108 | pass 109 | 110 | @abstractmethod 111 | def tensor_to_score(self, tensor_score): 112 | """ 113 | 114 | :param tensor_score: torch tensor, tensor representation 115 | of the score 116 | :return: music21 score object 117 | """ 118 | pass 119 | 120 | @property 121 | def tensor_dataset(self): 122 | """ 123 | Loads or computes TensorDataset 124 | :return: TensorDataset 125 | """ 126 | if self._tensor_dataset is None: 127 | if self.tensor_dataset_is_cached(): 128 | print(f'Loading TensorDataset for {self.__repr__()}') 129 | self._tensor_dataset = torch.load(self.tensor_dataset_filepath) 130 | else: 131 | print(f'Creating {self.__repr__()} TensorDataset' 132 | f' since it is not cached') 133 | self._tensor_dataset = self.make_tensor_dataset() 134 | torch.save(self._tensor_dataset, self.tensor_dataset_filepath) 135 | print(f'TensorDataset for {self.__repr__()} ' 136 | f'saved in {self.tensor_dataset_filepath}') 137 | return self._tensor_dataset 138 | 139 | @tensor_dataset.setter 140 | def tensor_dataset(self, value): 141 | self._tensor_dataset = value 142 | 143 | def tensor_dataset_is_cached(self): 144 | return os.path.exists(self.tensor_dataset_filepath) 145 | 146 | @property 147 | def tensor_dataset_filepath(self): 148 | tensor_datasets_cache_dir = os.path.join( 149 | self.cache_dir, 150 | 'tensor_datasets') 151 | if not os.path.exists(tensor_datasets_cache_dir): 152 | os.mkdir(tensor_datasets_cache_dir) 153 | fp = os.path.join( 154 | tensor_datasets_cache_dir, 155 | self.__repr__() 156 | ) 157 | return fp 158 | 159 | @property 160 | def filepath(self): 161 | tensor_datasets_cache_dir = os.path.join( 162 | self.cache_dir, 163 | 'datasets') 164 | if not os.path.exists(tensor_datasets_cache_dir): 165 | os.mkdir(tensor_datasets_cache_dir) 166 | return os.path.join( 167 | self.cache_dir, 168 | 'datasets', 169 | self.__repr__() 170 | ) 171 | 172 | def data_loaders(self, batch_size, split=(0.85, 0.10)): 173 | """ 174 | Returns three data loaders obtained by splitting 175 | self.tensor_dataset according to split 176 | :param batch_size: 177 | :param split: 178 | :return: 179 | """ 180 | assert sum(split) < 1 181 | 182 | dataset = self.tensor_dataset 183 | num_examples = len(dataset) 184 | a, b = split 185 | train_dataset = TensorDataset(*dataset[: int(a * num_examples)]) 186 | val_dataset = TensorDataset(*dataset[int(a * num_examples): 187 | int((a + b) * num_examples)]) 188 | eval_dataset = TensorDataset(*dataset[int((a + b) * num_examples):]) 189 | 190 | train_dl = DataLoader( 191 | train_dataset, 192 | batch_size=batch_size, 193 | shuffle=True, 194 | num_workers=4, 195 | pin_memory=True, 196 | drop_last=True, 197 | ) 198 | 199 | val_dl = DataLoader( 200 | val_dataset, 201 | 
batch_size=batch_size, 202 | shuffle=False, 203 | num_workers=0, 204 | pin_memory=False, 205 | drop_last=True, 206 | ) 207 | 208 | eval_dl = DataLoader( 209 | eval_dataset, 210 | batch_size=batch_size, 211 | shuffle=False, 212 | num_workers=0, 213 | pin_memory=False, 214 | drop_last=True, 215 | ) 216 | return train_dl, val_dl, eval_dl 217 | -------------------------------------------------------------------------------- /DeepBach/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ghadjeres/DeepBach/6d75cb940f3aa53e02f9eade34d58e472e0c95d7/DeepBach/__init__.py -------------------------------------------------------------------------------- /DeepBach/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @author: Gaetan Hadjeres 5 | """ 6 | 7 | import torch 8 | from DeepBach.helpers import cuda_variable 9 | 10 | 11 | def mask_entry(tensor, entry_index, dim): 12 | """ 13 | Masks entry entry_index on dim dim 14 | similar to 15 | torch.cat(( tensor[ :entry_index], tensor[ entry_index + 1 :], 0) 16 | but on another dimension 17 | :param tensor: 18 | :param entry_index: 19 | :param dim: 20 | :return: 21 | """ 22 | idx = [i for i in range(tensor.size(dim)) if not i == entry_index] 23 | idx = cuda_variable(torch.LongTensor(idx)) 24 | tensor = tensor.index_select(dim, idx) 25 | return tensor 26 | 27 | 28 | def reverse_tensor(tensor, dim): 29 | """ 30 | Do tensor[:, ... , -1::-1, :] along dim dim 31 | :param tensor: 32 | :param dim: 33 | :return: 34 | """ 35 | idx = [i for i in range(tensor.size(dim) - 1, -1, -1)] 36 | idx = cuda_variable(torch.LongTensor(idx)) 37 | tensor = tensor.index_select(dim, idx) 38 | return tensor 39 | -------------------------------------------------------------------------------- /DeepBach/helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Gaetan Hadjeres 3 | """ 4 | 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | 9 | def cuda_variable(tensor, volatile=False): 10 | if torch.cuda.is_available(): 11 | return Variable(tensor.cuda(), volatile=volatile) 12 | else: 13 | return Variable(tensor, volatile=volatile) 14 | 15 | 16 | def to_numpy(variable: Variable): 17 | if torch.cuda.is_available(): 18 | return variable.data.cpu().numpy() 19 | else: 20 | return variable.data.numpy() 21 | 22 | 23 | def init_hidden(num_layers, batch_size, lstm_hidden_size, 24 | volatile=False): 25 | hidden = ( 26 | cuda_variable( 27 | torch.randn(num_layers, batch_size, lstm_hidden_size), 28 | volatile=volatile), 29 | cuda_variable( 30 | torch.randn(num_layers, batch_size, lstm_hidden_size), 31 | volatile=volatile) 32 | ) 33 | return hidden 34 | -------------------------------------------------------------------------------- /DeepBach/metadata.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metadata classes 3 | """ 4 | import numpy as np 5 | from .data_utils import SUBDIVISION 6 | from music21 import analysis, stream 7 | 8 | 9 | class Metadata: 10 | def __init__(self): 11 | self.num_values = None 12 | self.is_global = None 13 | raise NotImplementedError 14 | 15 | def get_index(self, value): 16 | # trick with the 0 value 17 | raise NotImplementedError 18 | 19 | def get_value(self, index): 20 | raise NotImplementedError 21 | 22 | def evaluate(self, chorale): 23 | """ 24 | takes a music21 chorale 
as input 25 | """ 26 | raise NotImplementedError 27 | 28 | def generate(self, length): 29 | raise NotImplementedError 30 | 31 | 32 | # todo BeatMetadata class 33 | # todo add strong/weak beat metadata 34 | # todo add minor/major metadata 35 | # todo add voice_i_playing metadata 36 | 37 | class IsPlayingMetadata(Metadata): 38 | def __init__(self, voice_index, min_num_ticks=SUBDIVISION): 39 | """ Initiate the IsPlaying metadata. 40 | Voice i is considered to be muted if more than 'window_size' contiguous subdivisions that contains a rest. 41 | 42 | :param min_num_ticks: minimum length in ticks for a rest to be taken into account in the metadata 43 | """ 44 | self.min_num_ticks = min_num_ticks 45 | self.voice_index = voice_index 46 | self.is_global = False 47 | self.num_values = 2 48 | 49 | def get_index(self, value): 50 | return int(value) 51 | 52 | def get_value(self, index): 53 | return bool(index) 54 | 55 | def evaluate(self, chorale): 56 | """ 57 | takes a music21 chorale as input 58 | """ 59 | length = int(chorale.duration.quarterLength * SUBDIVISION) 60 | metadatas = np.ones(shape=(length,)) 61 | part = chorale.parts[self.voice_index] 62 | 63 | for note_or_rest in part.notesAndRests: 64 | is_playing = True 65 | if note_or_rest.isRest: 66 | if note_or_rest.quarterLength * SUBDIVISION >= self.min_num_ticks: 67 | is_playing = False 68 | # these should be integer values 69 | start_tick = note_or_rest.offset * SUBDIVISION 70 | end_tick = start_tick + note_or_rest.quarterLength * SUBDIVISION 71 | metadatas[start_tick:end_tick] = self.get_index(is_playing) 72 | return metadatas 73 | 74 | def generate(self, length): 75 | return np.ones(shape=(length,)) 76 | 77 | 78 | class TickMetadatas(Metadata): 79 | def __init__(self, num_subdivisions): 80 | self.is_global = False 81 | self.num_values = num_subdivisions 82 | 83 | def get_index(self, value): 84 | return value 85 | 86 | def get_value(self, index): 87 | return index 88 | 89 | def evaluate(self, chorale): 90 | # suppose all pieces start on a beat 91 | length = int(chorale.duration.quarterLength * SUBDIVISION) 92 | return np.array(list(map( 93 | lambda x: x % self.num_values, 94 | range(length) 95 | ))) 96 | 97 | def generate(self, length): 98 | return np.array(list(map( 99 | lambda x: x % self.num_values, 100 | range(length) 101 | ))) 102 | 103 | 104 | class ModeMetadatas(Metadata): 105 | def __init__(self): 106 | self.is_global = False 107 | self.num_values = 3 # major, minor or other 108 | 109 | def get_index(self, value): 110 | if value == 'major': 111 | return 1 112 | if value == 'minor': 113 | return 2 114 | return 0 115 | 116 | def get_value(self, index): 117 | if index == 1: 118 | return 'major' 119 | if index == 2: 120 | return 'minor' 121 | return 'other' 122 | 123 | def evaluate(self, chorale): 124 | # todo add measures when in midi 125 | # init key analyzer 126 | ka = analysis.floatingKey.KeyAnalyzer(chorale) 127 | res = ka.run() 128 | 129 | measure_offset_map = chorale.parts[0].measureOffsetMap() 130 | length = int(chorale.duration.quarterLength * SUBDIVISION) # in 16th notes 131 | 132 | modes = np.zeros((length,)) 133 | 134 | measure_index = -1 135 | for time_index in range(length): 136 | beat_index = time_index / SUBDIVISION 137 | if beat_index in measure_offset_map: 138 | measure_index += 1 139 | modes[time_index] = self.get_index(res[measure_index].mode) 140 | 141 | return np.array(modes, dtype=np.int32) 142 | 143 | def generate(self, length): 144 | return np.full((length,), self.get_index('major')) 145 | 146 | 147 | class 
KeyMetadatas(Metadata): 148 | def __init__(self, window_size=4): 149 | self.window_size = window_size 150 | self.is_global = False 151 | self.num_max_sharps = 7 152 | self.num_values = 16 153 | 154 | def get_index(self, value): 155 | """ 156 | 157 | :param value: number of sharps (between -7 and +7) 158 | :return: index in the representation 159 | """ 160 | return value + self.num_max_sharps + 1 161 | 162 | def get_value(self, index): 163 | """ 164 | 165 | :param index: index (between 0 and self.num_values); 0 is unused (no constraint) 166 | :return: true number of sharps (between -7 and 7) 167 | """ 168 | return index - 1 - self.num_max_sharps 169 | 170 | # todo check if this method is correct for windowSize > 1 171 | def evaluate(self, chorale): 172 | # init key analyzer 173 | # we must add measures by hand for the case when we are parsing midi files 174 | chorale_with_measures = stream.Score() 175 | for part in chorale.parts: 176 | chorale_with_measures.append(part.makeMeasures()) 177 | 178 | ka = analysis.floatingKey.KeyAnalyzer(chorale_with_measures) 179 | ka.windowSize = self.window_size 180 | res = ka.run() 181 | 182 | measure_offset_map = chorale_with_measures.parts.measureOffsetMap() 183 | length = int(chorale.duration.quarterLength * SUBDIVISION) # in 16th notes 184 | 185 | key_signatures = np.zeros((length,)) 186 | 187 | measure_index = -1 188 | for time_index in range(length): 189 | beat_index = time_index / SUBDIVISION 190 | if beat_index in measure_offset_map: 191 | measure_index += 1 192 | # todo remove this trick: problem with the last measures... 193 | if measure_index == len(res): 194 | measure_index -= 1 195 | 196 | key_signatures[time_index] = self.get_index(res[measure_index].sharps) 197 | return np.array(key_signatures, dtype=np.int32) 198 | 199 | def generate(self, length): 200 | return np.full((length,), self.get_index(0)) 201 | 202 | 203 | class FermataMetadatas(Metadata): 204 | def __init__(self): 205 | self.is_global = False 206 | self.num_values = 2 207 | 208 | def get_index(self, value): 209 | # values are 1 and 0 210 | return value 211 | 212 | def get_value(self, index): 213 | return index 214 | 215 | def evaluate(self, chorale): 216 | part = chorale.parts[0] 217 | length = int(part.duration.quarterLength * SUBDIVISION) # in 16th notes 218 | list_notes = part.flat.notes 219 | num_notes = len(list_notes) 220 | j = 0 221 | i = 0 222 | fermatas = np.zeros((length,)) 223 | fermata = False 224 | while i < length: 225 | if j < num_notes - 1: 226 | if list_notes[j + 1].offset > i / SUBDIVISION: 227 | 228 | if len(list_notes[j].expressions) == 1: 229 | fermata = True 230 | else: 231 | fermata = False 232 | fermatas[i] = fermata 233 | i += 1 234 | else: 235 | j += 1 236 | else: 237 | if len(list_notes[j].expressions) == 1: 238 | fermata = True 239 | else: 240 | fermata = False 241 | 242 | fermatas[i] = fermata 243 | i += 1 244 | return np.array(fermatas, dtype=np.int32) 245 | 246 | def generate(self, length): 247 | # fermata every 2 bars 248 | return np.array([1 if i % 32 > 28 else 0 249 | for i in range(length)]) 250 | -------------------------------------------------------------------------------- /DeepBach/model_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Gaetan Hadjeres 3 | """ 4 | 5 | from DatasetManager.metadata import FermataMetadata 6 | import numpy as np 7 | import torch 8 | from DeepBach.helpers import cuda_variable, to_numpy 9 | 10 | from torch import optim, nn 11 | from tqdm import tqdm 
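# The DeepBach class defined below wraps one VoiceModel per voice: train(), load()
# and save() dispatch to the individual voice models, while generation() runs the
# parallel pseudo-Gibbs sampling procedure implemented in parallel_gibbs().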
12 | 13 | from DeepBach.voice_model import VoiceModel 14 | 15 | 16 | class DeepBach: 17 | def __init__(self, 18 | dataset, 19 | note_embedding_dim, 20 | meta_embedding_dim, 21 | num_layers, 22 | lstm_hidden_size, 23 | dropout_lstm, 24 | linear_hidden_size, 25 | ): 26 | self.dataset = dataset 27 | self.num_voices = self.dataset.num_voices 28 | self.num_metas = len(self.dataset.metadatas) + 1 29 | self.activate_cuda = torch.cuda.is_available() 30 | 31 | self.voice_models = [VoiceModel( 32 | dataset=self.dataset, 33 | main_voice_index=main_voice_index, 34 | note_embedding_dim=note_embedding_dim, 35 | meta_embedding_dim=meta_embedding_dim, 36 | num_layers=num_layers, 37 | lstm_hidden_size=lstm_hidden_size, 38 | dropout_lstm=dropout_lstm, 39 | hidden_size_linear=linear_hidden_size, 40 | ) 41 | for main_voice_index in range(self.num_voices) 42 | ] 43 | 44 | def cuda(self, main_voice_index=None): 45 | if self.activate_cuda: 46 | if main_voice_index is None: 47 | for voice_index in range(self.num_voices): 48 | self.cuda(voice_index) 49 | else: 50 | self.voice_models[main_voice_index].cuda() 51 | 52 | # Utils 53 | def load(self, main_voice_index=None): 54 | if main_voice_index is None: 55 | for voice_index in range(self.num_voices): 56 | self.load(main_voice_index=voice_index) 57 | else: 58 | self.voice_models[main_voice_index].load() 59 | 60 | def save(self, main_voice_index=None): 61 | if main_voice_index is None: 62 | for voice_index in range(self.num_voices): 63 | self.save(main_voice_index=voice_index) 64 | else: 65 | self.voice_models[main_voice_index].save() 66 | 67 | def train(self, main_voice_index=None, 68 | **kwargs): 69 | if main_voice_index is None: 70 | for voice_index in range(self.num_voices): 71 | self.train(main_voice_index=voice_index, **kwargs) 72 | else: 73 | voice_model = self.voice_models[main_voice_index] 74 | if self.activate_cuda: 75 | voice_model.cuda() 76 | optimizer = optim.Adam(voice_model.parameters()) 77 | voice_model.train_model(optimizer=optimizer, **kwargs) 78 | 79 | def eval_phase(self): 80 | for voice_model in self.voice_models: 81 | voice_model.eval() 82 | 83 | def train_phase(self): 84 | for voice_model in self.voice_models: 85 | voice_model.train() 86 | 87 | def generation(self, 88 | temperature=1.0, 89 | batch_size_per_voice=8, 90 | num_iterations=None, 91 | sequence_length_ticks=160, 92 | tensor_chorale=None, 93 | tensor_metadata=None, 94 | time_index_range_ticks=None, 95 | voice_index_range=None, 96 | fermatas=None, 97 | random_init=True 98 | ): 99 | """ 100 | 101 | :param temperature: 102 | :param batch_size_per_voice: 103 | :param num_iterations: 104 | :param sequence_length_ticks: 105 | :param tensor_chorale: 106 | :param tensor_metadata: 107 | :param time_index_range_ticks: list of two integers [a, b] or None; can be used \ 108 | to regenerate only the portion of the score between timesteps a and b 109 | :param voice_index_range: list of two integers [a, b] or None; can be used \ 110 | to regenerate only the portion of the score between voice_index a and b 111 | :param fermatas: list[Fermata] 112 | :param random_init: boolean, whether or not to randomly initialize 113 | the portion of the score on which we apply the pseudo-Gibbs algorithm 114 | :return: tuple ( 115 | generated_score [music21 Stream object], 116 | tensor_chorale (num_voices, chorale_length) torch.IntTensor, 117 | tensor_metadata (num_voices, chorale_length, num_metadata) torch.IntTensor 118 | ) 119 | """ 120 | self.eval_phase() 121 | 122 | # --Process arguments 123 | # initialize 
generated chorale 124 | # tensor_chorale = self.dataset.empty_chorale(sequence_length_ticks) 125 | if tensor_chorale is None: 126 | tensor_chorale = self.dataset.random_score_tensor( 127 | sequence_length_ticks) 128 | else: 129 | sequence_length_ticks = tensor_chorale.size(1) 130 | 131 | # initialize metadata 132 | if tensor_metadata is None: 133 | test_chorale = next(self.dataset.corpus_it_gen().__iter__()) 134 | tensor_metadata = self.dataset.get_metadata_tensor(test_chorale) 135 | 136 | if tensor_metadata.size(1) < sequence_length_ticks: 137 | tensor_metadata = tensor_metadata.repeat(1, sequence_length_ticks // tensor_metadata.size(1) + 1, 1) 138 | 139 | # todo do not work if metadata_length_ticks > sequence_length_ticks 140 | tensor_metadata = tensor_metadata[:, :sequence_length_ticks, :] 141 | else: 142 | tensor_metadata_length = tensor_metadata.size(1) 143 | assert tensor_metadata_length == sequence_length_ticks 144 | 145 | if fermatas is not None: 146 | tensor_metadata = self.dataset.set_fermatas(tensor_metadata, 147 | fermatas) 148 | 149 | # timesteps_ticks is the number of ticks on which we unroll the LSTMs 150 | # it is also the padding size 151 | timesteps_ticks = self.dataset.sequences_size * self.dataset.subdivision // 2 152 | if time_index_range_ticks is None: 153 | time_index_range_ticks = [timesteps_ticks, sequence_length_ticks + timesteps_ticks] 154 | else: 155 | a_ticks, b_ticks = time_index_range_ticks 156 | assert 0 <= a_ticks < b_ticks <= sequence_length_ticks 157 | time_index_range_ticks = [a_ticks + timesteps_ticks, b_ticks + timesteps_ticks] 158 | 159 | if voice_index_range is None: 160 | voice_index_range = [0, self.dataset.num_voices] 161 | 162 | tensor_chorale = self.dataset.extract_score_tensor_with_padding( 163 | tensor_score=tensor_chorale, 164 | start_tick=-timesteps_ticks, 165 | end_tick=sequence_length_ticks + timesteps_ticks 166 | ) 167 | 168 | tensor_metadata_padded = self.dataset.extract_metadata_with_padding( 169 | tensor_metadata=tensor_metadata, 170 | start_tick=-timesteps_ticks, 171 | end_tick=sequence_length_ticks + timesteps_ticks 172 | ) 173 | 174 | # randomize regenerated part 175 | if random_init: 176 | a, b = time_index_range_ticks 177 | tensor_chorale[voice_index_range[0]:voice_index_range[1], a:b] = self.dataset.random_score_tensor( 178 | b - a)[voice_index_range[0]:voice_index_range[1], :] 179 | 180 | tensor_chorale = self.parallel_gibbs( 181 | tensor_chorale=tensor_chorale, 182 | tensor_metadata=tensor_metadata_padded, 183 | num_iterations=num_iterations, 184 | timesteps_ticks=timesteps_ticks, 185 | temperature=temperature, 186 | batch_size_per_voice=batch_size_per_voice, 187 | time_index_range_ticks=time_index_range_ticks, 188 | voice_index_range=voice_index_range, 189 | ) 190 | 191 | # get fermata tensor 192 | for metadata_index, metadata in enumerate(self.dataset.metadatas): 193 | if isinstance(metadata, FermataMetadata): 194 | break 195 | 196 | 197 | score = self.dataset.tensor_to_score( 198 | tensor_score=tensor_chorale, 199 | fermata_tensor=tensor_metadata[:, :, metadata_index]) 200 | 201 | return score, tensor_chorale, tensor_metadata 202 | 203 | def parallel_gibbs(self, 204 | tensor_chorale, 205 | tensor_metadata, 206 | timesteps_ticks, 207 | num_iterations=1000, 208 | batch_size_per_voice=16, 209 | temperature=1., 210 | time_index_range_ticks=None, 211 | voice_index_range=None, 212 | ): 213 | """ 214 | Parallel pseudo-Gibbs sampling 215 | tensor_chorale and tensor_metadata are padded with 216 | timesteps_ticks START_SYMBOLS before, 
217 | timesteps_ticks END_SYMBOLS after 218 | :param tensor_chorale: (num_voices, chorale_length) tensor 219 | :param tensor_metadata: (num_voices, chorale_length) tensor 220 | :param timesteps_ticks: 221 | :param num_iterations: number of Gibbs sampling iterations 222 | :param batch_size_per_voice: number of simultaneous parallel updates 223 | :param temperature: final temperature after simulated annealing 224 | :param time_index_range_ticks: list of two integers [a, b] or None; can be used \ 225 | to regenerate only the portion of the score between timesteps a and b 226 | :param voice_index_range: list of two integers [a, b] or None; can be used \ 227 | to regenerate only the portion of the score between voice_index a and b 228 | :return: (num_voices, chorale_length) tensor 229 | """ 230 | start_voice, end_voice = voice_index_range 231 | # add batch_dimension 232 | tensor_chorale = tensor_chorale.unsqueeze(0) 233 | tensor_chorale_no_cuda = tensor_chorale.clone() 234 | tensor_metadata = tensor_metadata.unsqueeze(0) 235 | 236 | # to variable 237 | tensor_chorale = cuda_variable(tensor_chorale, volatile=True) 238 | tensor_metadata = cuda_variable(tensor_metadata, volatile=True) 239 | 240 | min_temperature = temperature 241 | temperature = 1.1 242 | 243 | # Main loop 244 | for iteration in tqdm(range(num_iterations)): 245 | # annealing 246 | temperature = max(min_temperature, temperature * 0.9993) 247 | # print(temperature) 248 | time_indexes_ticks = {} 249 | probas = {} 250 | 251 | for voice_index in range(start_voice, end_voice): 252 | batch_notes = [] 253 | batch_metas = [] 254 | 255 | time_indexes_ticks[voice_index] = [] 256 | 257 | # create batches of inputs 258 | for batch_index in range(batch_size_per_voice): 259 | time_index_ticks = np.random.randint( 260 | *time_index_range_ticks) 261 | time_indexes_ticks[voice_index].append(time_index_ticks) 262 | 263 | notes, label = (self.voice_models[voice_index] 264 | .preprocess_notes( 265 | tensor_chorale=tensor_chorale[ 266 | :, :, 267 | time_index_ticks - timesteps_ticks: 268 | time_index_ticks + timesteps_ticks], 269 | time_index_ticks=timesteps_ticks 270 | ) 271 | ) 272 | metas = self.voice_models[voice_index].preprocess_metas( 273 | tensor_metadata=tensor_metadata[ 274 | :, :, 275 | time_index_ticks - timesteps_ticks: 276 | time_index_ticks + timesteps_ticks, 277 | :], 278 | time_index_ticks=timesteps_ticks 279 | ) 280 | 281 | batch_notes.append(notes) 282 | batch_metas.append(metas) 283 | 284 | # reshape batches 285 | batch_notes = list(map(list, zip(*batch_notes))) 286 | batch_notes = [torch.cat(lcr) if lcr[0] is not None else None 287 | for lcr in batch_notes] 288 | batch_metas = list(map(list, zip(*batch_metas))) 289 | batch_metas = [torch.cat(lcr) 290 | for lcr in batch_metas] 291 | 292 | # make all estimations 293 | probas[voice_index] = (self.voice_models[voice_index] 294 | .forward(batch_notes, batch_metas) 295 | ) 296 | probas[voice_index] = nn.Softmax(dim=1)(probas[voice_index]) 297 | 298 | # update all predictions 299 | for voice_index in range(start_voice, end_voice): 300 | for batch_index in range(batch_size_per_voice): 301 | probas_pitch = probas[voice_index][batch_index] 302 | 303 | probas_pitch = to_numpy(probas_pitch) 304 | 305 | # use temperature 306 | probas_pitch = np.log(probas_pitch) / temperature 307 | probas_pitch = np.exp(probas_pitch) / np.sum( 308 | np.exp(probas_pitch)) - 1e-7 309 | 310 | # avoid non-probabilities 311 | probas_pitch[probas_pitch < 0] = 0 312 | 313 | # pitch can include slur_symbol 314 | pitch 
= np.argmax(np.random.multinomial(1, probas_pitch)) 315 | 316 | tensor_chorale_no_cuda[ 317 | 0, 318 | voice_index, 319 | time_indexes_ticks[voice_index][batch_index] 320 | ] = int(pitch) 321 | 322 | tensor_chorale = cuda_variable(tensor_chorale_no_cuda.clone(), 323 | volatile=True) 324 | 325 | return tensor_chorale_no_cuda[0, :, timesteps_ticks:-timesteps_ticks] 326 | -------------------------------------------------------------------------------- /DeepBach/voice_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Gaetan Hadjeres 3 | """ 4 | 5 | import random 6 | 7 | import torch 8 | from DatasetManager.chorale_dataset import ChoraleDataset 9 | from DeepBach.helpers import cuda_variable, init_hidden 10 | 11 | from torch import nn 12 | 13 | from DeepBach.data_utils import reverse_tensor, mask_entry 14 | 15 | 16 | class VoiceModel(nn.Module): 17 | def __init__(self, 18 | dataset: ChoraleDataset, 19 | main_voice_index: int, 20 | note_embedding_dim: int, 21 | meta_embedding_dim: int, 22 | num_layers: int, 23 | lstm_hidden_size: int, 24 | dropout_lstm: float, 25 | hidden_size_linear=200 26 | ): 27 | super(VoiceModel, self).__init__() 28 | self.dataset = dataset 29 | self.main_voice_index = main_voice_index 30 | self.note_embedding_dim = note_embedding_dim 31 | self.meta_embedding_dim = meta_embedding_dim 32 | self.num_notes_per_voice = [len(d) 33 | for d in dataset.note2index_dicts] 34 | self.num_voices = self.dataset.num_voices 35 | self.num_metas_per_voice = [ 36 | metadata.num_values 37 | for metadata in dataset.metadatas 38 | ] + [self.num_voices] 39 | self.num_metas = len(self.dataset.metadatas) + 1 40 | self.num_layers = num_layers 41 | self.lstm_hidden_size = lstm_hidden_size 42 | self.dropout_lstm = dropout_lstm 43 | self.hidden_size_linear = hidden_size_linear 44 | 45 | self.other_voices_indexes = [i 46 | for i 47 | in range(self.num_voices) 48 | if not i == main_voice_index] 49 | 50 | self.note_embeddings = nn.ModuleList( 51 | [nn.Embedding(num_notes, note_embedding_dim) 52 | for num_notes in self.num_notes_per_voice] 53 | ) 54 | self.meta_embeddings = nn.ModuleList( 55 | [nn.Embedding(num_metas, meta_embedding_dim) 56 | for num_metas in self.num_metas_per_voice] 57 | ) 58 | 59 | self.lstm_left = nn.LSTM(input_size=note_embedding_dim * self.num_voices + 60 | meta_embedding_dim * self.num_metas, 61 | hidden_size=lstm_hidden_size, 62 | num_layers=num_layers, 63 | dropout=dropout_lstm, 64 | batch_first=True) 65 | self.lstm_right = nn.LSTM(input_size=note_embedding_dim * self.num_voices + 66 | meta_embedding_dim * self.num_metas, 67 | hidden_size=lstm_hidden_size, 68 | num_layers=num_layers, 69 | dropout=dropout_lstm, 70 | batch_first=True) 71 | 72 | self.mlp_center = nn.Sequential( 73 | nn.Linear((note_embedding_dim * (self.num_voices - 1) 74 | + meta_embedding_dim * self.num_metas), 75 | hidden_size_linear), 76 | nn.ReLU(), 77 | nn.Linear(hidden_size_linear, lstm_hidden_size) 78 | ) 79 | 80 | self.mlp_predictions = nn.Sequential( 81 | nn.Linear(self.lstm_hidden_size * 3, 82 | hidden_size_linear), 83 | nn.ReLU(), 84 | nn.Linear(hidden_size_linear, self.num_notes_per_voice[main_voice_index]) 85 | ) 86 | 87 | def forward(self, *input): 88 | notes, metas = input 89 | batch_size, num_voices, timesteps_ticks = notes[0].size() 90 | 91 | # put time first 92 | ln, cn, rn = notes 93 | ln, rn = [t.transpose(1, 2) 94 | for t in (ln, rn)] 95 | notes = ln, cn, rn 96 | 97 | # embedding 98 | notes_embedded = self.embed(notes, type='note') 99 
| metas_embedded = self.embed(metas, type='meta') 100 | # lists of (N, timesteps_ticks, voices * dim_embedding) 101 | # where timesteps_ticks is 1 for central parts 102 | 103 | # concat notes and metas 104 | input_embedded = [torch.cat([notes, metas], 2) if notes is not None else None 105 | for notes, metas in zip(notes_embedded, metas_embedded)] 106 | 107 | left, center, right = input_embedded 108 | 109 | # main part 110 | hidden = init_hidden( 111 | num_layers=self.num_layers, 112 | batch_size=batch_size, 113 | lstm_hidden_size=self.lstm_hidden_size, 114 | ) 115 | left, hidden = self.lstm_left(left, hidden) 116 | left = left[:, -1, :] 117 | 118 | if self.num_voices == 1: 119 | center = cuda_variable(torch.zeros( 120 | batch_size, 121 | self.lstm_hidden_size) 122 | ) 123 | else: 124 | center = center[:, 0, :] # remove time dimension 125 | center = self.mlp_center(center) 126 | 127 | hidden = init_hidden( 128 | num_layers=self.num_layers, 129 | batch_size=batch_size, 130 | lstm_hidden_size=self.lstm_hidden_size, 131 | ) 132 | right, hidden = self.lstm_right(right, hidden) 133 | right = right[:, -1, :] 134 | 135 | # concat and return prediction 136 | predictions = torch.cat([ 137 | left, center, right 138 | ], 1) 139 | 140 | predictions = self.mlp_predictions(predictions) 141 | 142 | return predictions 143 | 144 | def embed(self, notes_or_metas, type): 145 | if type == 'note': 146 | embeddings = self.note_embeddings 147 | embedding_dim = self.note_embedding_dim 148 | other_voices_indexes = self.other_voices_indexes 149 | elif type == 'meta': 150 | embeddings = self.meta_embeddings 151 | embedding_dim = self.meta_embedding_dim 152 | other_voices_indexes = range(self.num_metas) 153 | 154 | batch_size, timesteps_left_ticks, num_voices = notes_or_metas[0].size() 155 | batch_size, timesteps_right_ticks, num_voices = notes_or_metas[2].size() 156 | 157 | left, center, right = notes_or_metas 158 | # center has self.num_voices - 1 voices 159 | left_embedded = torch.cat([ 160 | embeddings[voice_id](left[:, :, voice_id])[:, :, None, :] 161 | for voice_id in range(num_voices) 162 | ], 2) 163 | right_embedded = torch.cat([ 164 | embeddings[voice_id](right[:, :, voice_id])[:, :, None, :] 165 | for voice_id in range(num_voices) 166 | ], 2) 167 | if self.num_voices == 1 and type == 'note': 168 | center_embedded = None 169 | else: 170 | center_embedded = torch.cat([ 171 | embeddings[voice_id](center[:, k].unsqueeze(1)) 172 | for k, voice_id in enumerate(other_voices_indexes) 173 | ], 1) 174 | center_embedded = center_embedded.view(batch_size, 175 | 1, 176 | len(other_voices_indexes) * embedding_dim) 177 | 178 | # squeeze two last dimensions 179 | left_embedded = left_embedded.view(batch_size, 180 | timesteps_left_ticks, 181 | num_voices * embedding_dim) 182 | right_embedded = right_embedded.view(batch_size, 183 | timesteps_right_ticks, 184 | num_voices * embedding_dim) 185 | 186 | return left_embedded, center_embedded, right_embedded 187 | 188 | def save(self): 189 | torch.save(self.state_dict(), 'models/' + self.__repr__()) 190 | print(f'Model {self.__repr__()} saved') 191 | 192 | def load(self): 193 | state_dict = torch.load('models/' + self.__repr__(), 194 | map_location=lambda storage, loc: storage) 195 | print(f'Loading {self.__repr__()}') 196 | self.load_state_dict(state_dict) 197 | 198 | def __repr__(self): 199 | return f'VoiceModel(' \ 200 | f'{self.dataset.__repr__()},' \ 201 | f'{self.main_voice_index},' \ 202 | f'{self.note_embedding_dim},' \ 203 | f'{self.meta_embedding_dim},' \ 204 | 
f'{self.num_layers},' \ 205 | f'{self.lstm_hidden_size},' \ 206 | f'{self.dropout_lstm},' \ 207 | f'{self.hidden_size_linear}' \ 208 | f')' 209 | 210 | def train_model(self, 211 | batch_size=16, 212 | num_epochs=10, 213 | optimizer=None): 214 | for epoch in range(num_epochs): 215 | print(f'===Epoch {epoch}===') 216 | (dataloader_train, 217 | dataloader_val, 218 | dataloader_test) = self.dataset.data_loaders( 219 | batch_size=batch_size, 220 | ) 221 | 222 | loss, acc = self.loss_and_acc(dataloader_train, 223 | optimizer=optimizer, 224 | phase='train') 225 | print(f'Training loss: {loss}') 226 | print(f'Training accuracy: {acc}') 227 | # writer.add_scalar('data/training_loss', loss, epoch) 228 | # writer.add_scalar('data/training_acc', acc, epoch) 229 | 230 | loss, acc = self.loss_and_acc(dataloader_val, 231 | optimizer=None, 232 | phase='test') 233 | print(f'Validation loss: {loss}') 234 | print(f'Validation accuracy: {acc}') 235 | self.save() 236 | 237 | def loss_and_acc(self, dataloader, 238 | optimizer=None, 239 | phase='train'): 240 | 241 | average_loss = 0 242 | average_acc = 0 243 | if phase == 'train': 244 | self.train() 245 | elif phase == 'eval' or phase == 'test': 246 | self.eval() 247 | else: 248 | raise NotImplementedError 249 | for tensor_chorale, tensor_metadata in dataloader: 250 | 251 | # to Variable 252 | tensor_chorale = cuda_variable(tensor_chorale).long() 253 | tensor_metadata = cuda_variable(tensor_metadata).long() 254 | 255 | # preprocessing to put in the DeepBach format 256 | # see Fig. 4 in DeepBach paper: 257 | # https://arxiv.org/pdf/1612.01010.pdf 258 | notes, metas, label = self.preprocess_input(tensor_chorale, 259 | tensor_metadata) 260 | 261 | weights = self.forward(notes, metas) 262 | 263 | loss_function = torch.nn.CrossEntropyLoss() 264 | 265 | loss = loss_function(weights, label) 266 | 267 | if phase == 'train': 268 | optimizer.zero_grad() 269 | loss.backward() 270 | optimizer.step() 271 | 272 | acc = self.accuracy(weights=weights, 273 | target=label) 274 | 275 | average_loss += loss.item() 276 | average_acc += acc.item() 277 | 278 | average_loss /= len(dataloader) 279 | average_acc /= len(dataloader) 280 | return average_loss, average_acc 281 | 282 | def accuracy(self, weights, target): 283 | batch_size, = target.size() 284 | softmax = nn.Softmax(dim=1)(weights) 285 | pred = softmax.max(1)[1].type_as(target) 286 | num_corrects = (pred == target).float().sum() 287 | return num_corrects / batch_size * 100 288 | 289 | def preprocess_input(self, tensor_chorale, tensor_metadata): 290 | """ 291 | :param tensor_chorale: (batch_size, num_voices, chorale_length_ticks) 292 | :param tensor_metadata: (batch_size, num_metadata, chorale_length_ticks) 293 | :return: (notes, metas, label) tuple 294 | where 295 | notes = (left_notes, central_notes, right_notes) 296 | metas = (left_metas, central_metas, right_metas) 297 | label = (batch_size) 298 | right_notes and right_metas are REVERSED (from right to left) 299 | """ 300 | batch_size, num_voices, chorale_length_ticks = tensor_chorale.size() 301 | 302 | # random shift! 
Depends on the dataset 303 | offset = random.randint(0, self.dataset.subdivision) 304 | time_index_ticks = chorale_length_ticks // 2 + offset 305 | 306 | # split notes 307 | notes, label = self.preprocess_notes(tensor_chorale, time_index_ticks) 308 | metas = self.preprocess_metas(tensor_metadata, time_index_ticks) 309 | return notes, metas, label 310 | 311 | def preprocess_notes(self, tensor_chorale, time_index_ticks): 312 | """ 313 | 314 | :param tensor_chorale: (batch_size, num_voices, chorale_length_ticks) 315 | :param time_index_ticks: 316 | :return: 317 | """ 318 | batch_size, num_voices, _ = tensor_chorale.size() 319 | left_notes = tensor_chorale[:, :, :time_index_ticks] 320 | right_notes = reverse_tensor( 321 | tensor_chorale[:, :, time_index_ticks + 1:], 322 | dim=2) 323 | if self.num_voices == 1: 324 | central_notes = None 325 | else: 326 | central_notes = mask_entry(tensor_chorale[:, :, time_index_ticks], 327 | entry_index=self.main_voice_index, 328 | dim=1) 329 | label = tensor_chorale[:, self.main_voice_index, time_index_ticks] 330 | return (left_notes, central_notes, right_notes), label 331 | 332 | def preprocess_metas(self, tensor_metadata, time_index_ticks): 333 | """ 334 | 335 | :param tensor_metadata: (batch_size, num_voices, chorale_length_ticks) 336 | :param time_index_ticks: 337 | :return: 338 | """ 339 | 340 | left_metas = tensor_metadata[:, self.main_voice_index, :time_index_ticks, :] 341 | right_metas = reverse_tensor( 342 | tensor_metadata[:, self.main_voice_index, time_index_ticks + 1:, :], 343 | dim=1) 344 | central_metas = tensor_metadata[:, self.main_voice_index, time_index_ticks, :] 345 | return left_metas, central_metas, right_metas 346 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.0-cuda10.0-cudnn7-runtime 2 | 3 | RUN git clone https://github.com/Ghadjeres/DeepBach.git 4 | WORKDIR DeepBach 5 | RUN conda env create --name deepbach_pytorch -f environment.yml 6 | 7 | RUN apt update && apt install wget 8 | RUN bash dl_dataset_and_models.sh 9 | 10 | 11 | COPY entrypoint.sh entrypoint.sh 12 | RUN chmod u+x entrypoint.sh 13 | 14 | EXPOSE 5000 15 | ENTRYPOINT ["./entrypoint.sh"] 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Gaetan Hadjeres, Laboratoire LIP6 - Departement ASIM Universite Pierre et Marie Curie 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepBach 2 | This repository contains implementations of the DeepBach model described in 3 | 4 | *DeepBach: a Steerable Model for Bach chorales generation*
5 | Gaëtan Hadjeres, François Pachet, Frank Nielsen
6 | *ICML 2017 [arXiv:1612.01010](http://proceedings.mlr.press/v70/hadjeres17a.html)*
7 | 
8 | 
9 | The code uses python 3.6 together with [PyTorch v1.0](https://pytorch.org/) and
10 | [music21](http://web.mit.edu/music21/) libraries.
11 | 
12 | For the original Keras version, please check out the `original_keras` branch.
13 | 
14 | Examples of music generated by DeepBach are available on [this website](https://sites.google.com/site/deepbachexamples/).
15 | 
16 | ## Installation
17 | 
18 | You can clone this repository, install the dependencies with Anaconda, and download a pretrained
19 | model together with a dataset
20 | with the following commands:
21 | ```
22 | git clone git@github.com:Ghadjeres/DeepBach.git
23 | cd DeepBach
24 | conda env create --name deepbach_pytorch -f environment.yml
25 | bash dl_dataset_and_models.sh
26 | ```
27 | This will create a conda env named `deepbach_pytorch`.
28 | 
29 | ### music21 editor
30 | 
31 | You might need to [configure properly the music editor
32 | called by music21](http://web.mit.edu/music21/doc/moduleReference/moduleEnvironment.html).
33 | On Ubuntu you can e.g. use MuseScore:
34 | 
35 | 
36 | ```shell
37 | sudo apt install musescore
38 | python -c 'import music21; music21.environment.set("musicxmlPath", "/usr/bin/musescore")'
39 | ```
40 | 
41 | For usage on a headless server (no X server), just set it to a dummy command:
42 | 
43 | ```shell
44 | python -c 'import music21; music21.environment.set("musicxmlPath", "/bin/true")'
45 | ```
46 | 
47 | ## Usage
48 | ```
49 | Usage: deepBach.py [OPTIONS]
50 | 
51 | Options:
52 |   --note_embedding_dim INTEGER    size of the note embeddings
53 |   --meta_embedding_dim INTEGER    size of the metadata embeddings
54 |   --num_layers INTEGER            number of layers of the LSTMs
55 |   --lstm_hidden_size INTEGER      hidden size of the LSTMs
56 |   --dropout_lstm FLOAT            amount of dropout between LSTM layers
57 |   --linear_hidden_size INTEGER    hidden size of the Linear layers
58 |   --batch_size INTEGER            training batch size
59 |   --num_epochs INTEGER            number of training epochs
60 |   --train                         train or retrain the specified model
61 |   --num_iterations INTEGER        number of parallel pseudo-Gibbs sampling
62 |                                   iterations
63 |   --sequence_length_ticks INTEGER
64 |                                   length of the generated chorale (in ticks)
65 |   --help                          Show this message and exit.
66 | ```
67 | 
68 | ## Examples
69 | You can generate a four-bar chorale with the pretrained model and display it in MuseScore by
70 | simply running
71 | ```
72 | python deepBach.py
73 | ```
74 | 
75 | You can train a new model from scratch by adding the `--train` flag.
76 | 
77 | 
78 | ## Usage with NONOTO
79 | The command
80 | ```
81 | python flask_server.py
82 | ```
83 | starts a Flask server listening on port 5000. You can then use
84 | [NONOTO](https://github.com/SonyCSLParis/NONOTO) to compose with DeepBach in an interactive way.
85 | 
86 | This server can also be started using Docker with:
87 | ```
88 | docker run -p 5000:5000 -it --rm ghadjeres/deepbach
89 | ```
90 | (CPU version), or
91 | 
92 | ```
93 | docker run --runtime=nvidia -p 5000:5000 -it --rm ghadjeres/deepbach
94 | ```
95 | (GPU version, requires [nvidia-docker](https://github.com/NVIDIA/nvidia-docker)).
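If you want to query the server directly (NONOTO normally does this for you), the `/generate` endpoint of `flask_server.py` returns a JSON document with a MusicXML string under `sheet` and the fermata positions under `fermatas`. The snippet below is only a minimal sketch using the Python standard library; it assumes the server is running locally on port 5000, and the output filename is arbitrary:

```python
import json
import urllib.request

# Request a freshly generated chorale from the running DeepBach server
with urllib.request.urlopen('http://localhost:5000/generate') as response:
    payload = json.loads(response.read().decode('utf-8'))

# 'sheet' is a MusicXML string, 'fermatas' a list of beat positions
with open('generated_chorale.xml', 'w', encoding='utf-8') as f:
    f.write(payload['sheet'])

print('Fermatas at beats:', payload['fermatas'])
```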
96 | 
97 | 
98 | ## Usage within MuseScore
99 | *Deprecated*
100 | 
101 | Put the `deepBachMuseScore.qml` file in your MuseScore plugins directory, and run
102 | ```
103 | python musescore_flask_server.py
104 | ```
105 | Open MuseScore and activate the deepBachMuseScore plugin using the Plugin manager. Open a four-part chorale. Press Enter on the server address; a list of computed models should appear. Select and (re)load a model.
106 | You can then click on the Compose button without any selection to create a new chorale from
107 | scratch. Alternatively, select a region in the chorale score and click on the Compose button to
108 | regenerate this region using DeepBach.
109 | 
110 | 
111 | ## Issues
112 | 
113 | ### Music21 editor not set
114 | 
115 | ```
116 | music21.converter.subConverters.SubConverterException: Cannot find a valid application path for format musicxml. Specify this in your Environment by calling environment.set(None, '/path/to/application')
117 | ```
118 | 
119 | Either set it to MuseScore or similar (on a machine with a GUI) or to a dummy command (on a server). See the installation section.
120 | 
121 | # Citing
122 | 
123 | Please consider citing this work or emailing me if you use DeepBach in musical projects.
124 | ```
125 | @InProceedings{pmlr-v70-hadjeres17a,
126 |   title = {{D}eep{B}ach: a Steerable Model for {B}ach Chorales Generation},
127 |   author = {Ga{\"e}tan Hadjeres and Fran{\c{c}}ois Pachet and Frank Nielsen},
128 |   booktitle = {Proceedings of the 34th International Conference on Machine Learning},
129 |   pages = {1362--1371},
130 |   year = {2017},
131 |   editor = {Doina Precup and Yee Whye Teh},
132 |   volume = {70},
133 |   series = {Proceedings of Machine Learning Research},
134 |   address = {International Convention Centre, Sydney, Australia},
135 |   month = {06--11 Aug},
136 |   publisher = {PMLR},
137 |   pdf = {http://proceedings.mlr.press/v70/hadjeres17a/hadjeres17a.pdf},
138 |   url = {http://proceedings.mlr.press/v70/hadjeres17a.html},
139 | }
140 | ```
141 | 
-------------------------------------------------------------------------------- /cog.yaml: --------------------------------------------------------------------------------
1 | build:
2 |   python_version: "3.7"
3 |   gpu: false
4 |   python_packages:
5 |     - decorator==4.3.0
6 |     - flask==1.0.2
7 |     - flask-cors==3.0.7
8 |     - itsdangerous==1.1.0
9 |     - jinja2==2.10
10 |     - markupsafe==1.1.0
11 |     - music21==6.7.1
12 |     - torch==1.0.0 --no-cache-dir
13 |     - tqdm==4.29.0
14 |     - werkzeug==0.14.1
15 |     - numpy==1.19.5
16 |     - midi2audio==0.1.1
17 |   system_packages:
18 |     - musescore
19 |     - fluidsynth --fix-missing
20 |     - ffmpeg
21 | predict: "predict.py:Predictor"
22 | 
-------------------------------------------------------------------------------- /deepBach.py: --------------------------------------------------------------------------------
1 | """
2 | @author: Gaetan Hadjeres
3 | """
4 | 
5 | import click
6 | 
7 | from DatasetManager.chorale_dataset import ChoraleDataset
8 | from DatasetManager.dataset_manager import DatasetManager
9 | from DatasetManager.metadata import FermataMetadata, TickMetadata, KeyMetadata
10 | 
11 | from DeepBach.model_manager import DeepBach
12 | 
13 | 
14 | @click.command()
15 | @click.option('--note_embedding_dim', default=20,
16 |               help='size of the note embeddings')
17 | @click.option('--meta_embedding_dim', default=20,
18 |               help='size of the metadata embeddings')
19 | @click.option('--num_layers', default=2,
20 |               help='number of layers of the LSTMs')
21 | @click.option('--lstm_hidden_size', default=256,
22 |               help='hidden size of the LSTMs')
23 | @click.option('--dropout_lstm', default=0.5,
24 |               help='amount of dropout between LSTM layers')
25 | 
@click.option('--linear_hidden_size', default=256, 26 | help='hidden size of the Linear layers') 27 | @click.option('--batch_size', default=256, 28 | help='training batch size') 29 | @click.option('--num_epochs', default=5, 30 | help='number of training epochs') 31 | @click.option('--train', is_flag=True, 32 | help='train the specified model for num_epochs') 33 | @click.option('--num_iterations', default=500, 34 | help='number of parallel pseudo-Gibbs sampling iterations') 35 | @click.option('--sequence_length_ticks', default=64, 36 | help='length of the generated chorale (in ticks)') 37 | def main(note_embedding_dim, 38 | meta_embedding_dim, 39 | num_layers, 40 | lstm_hidden_size, 41 | dropout_lstm, 42 | linear_hidden_size, 43 | batch_size, 44 | num_epochs, 45 | train, 46 | num_iterations, 47 | sequence_length_ticks, 48 | ): 49 | dataset_manager = DatasetManager() 50 | 51 | metadatas = [ 52 | FermataMetadata(), 53 | TickMetadata(subdivision=4), 54 | KeyMetadata() 55 | ] 56 | chorale_dataset_kwargs = { 57 | 'voice_ids': [0, 1, 2, 3], 58 | 'metadatas': metadatas, 59 | 'sequences_size': 8, 60 | 'subdivision': 4 61 | } 62 | bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset( 63 | name='bach_chorales', 64 | **chorale_dataset_kwargs 65 | ) 66 | dataset = bach_chorales_dataset 67 | 68 | deepbach = DeepBach( 69 | dataset=dataset, 70 | note_embedding_dim=note_embedding_dim, 71 | meta_embedding_dim=meta_embedding_dim, 72 | num_layers=num_layers, 73 | lstm_hidden_size=lstm_hidden_size, 74 | dropout_lstm=dropout_lstm, 75 | linear_hidden_size=linear_hidden_size 76 | ) 77 | 78 | if train: 79 | deepbach.train(batch_size=batch_size, 80 | num_epochs=num_epochs) 81 | else: 82 | deepbach.load() 83 | deepbach.cuda() 84 | 85 | print('Generation') 86 | score, tensor_chorale, tensor_metadata = deepbach.generation( 87 | num_iterations=num_iterations, 88 | sequence_length_ticks=sequence_length_ticks, 89 | ) 90 | score.show('txt') 91 | score.show() 92 | 93 | 94 | if __name__ == '__main__': 95 | main() 96 | -------------------------------------------------------------------------------- /deepBachMuseScore.qml: -------------------------------------------------------------------------------- 1 | import QtQuick 2.2 2 | import QtQuick.Dialogs 1.2 3 | import QtQuick.Controls 1.1 4 | import MuseScore 3.0 5 | import FileIO 3.0 6 | 7 | MuseScore { 8 | id: mainMuseScoreObj 9 | menuPath: "Plugins.deepBachMuseScore" 10 | description: qsTr("This plugin calls deepBach project.") 11 | pluginType: "dock" 12 | dockArea: "left" 13 | property variant serverCalled: false 14 | property variant loading: false 15 | property variant linesLogged: 0 16 | property variant serverAddress: "http://localhost:5000/" 17 | FileIO { 18 | id: myFile 19 | source: "/tmp/deepbach.mxl" 20 | onError: console.log(msg) 21 | } 22 | FileIO { 23 | id: myFileXml 24 | source: "/tmp/deepbach.xml" 25 | onError: console.log(msg) 26 | } 27 | onRun: { 28 | console.log('on run called') 29 | if (mainMuseScoreObj.serverCalled !== serverAddress || modelSelector.model.length === 0) { 30 | var requestModels = getRequestObj('GET', 'models') 31 | if (call( 32 | requestModels, 33 | false, 34 | function(responseText){ 35 | mainMuseScoreObj.serverCalled = serverAddress; 36 | try { 37 | modelSelector.model = JSON.parse(responseText) 38 | logStatus('Models list loaded') 39 | var requestLoadModel = getRequestObj("GET", 'current_model') 40 | call( 41 | requestLoadModel, 42 | false, 43 | function(response) { 44 | logStatus('currently loaded model is ' + response) 
45 | for (var i in modelSelector.model) { 46 | if (modelSelector.model[i] === response) { 47 | console.log('set selected at ' + i) 48 | modelSelector.currentIndex = i 49 | } 50 | } 51 | } 52 | ) 53 | } catch(error) { 54 | console.log(error) 55 | logStatus('No models found') 56 | } 57 | 58 | } 59 | )) { 60 | logStatus('Retrieving models list at ' + serverAddress) 61 | } 62 | } 63 | } 64 | Rectangle { 65 | id: wrapperPanel 66 | color: "white" 67 | Text { 68 | id: title 69 | text: "DeepBach" 70 | font.family: "Helvetica" 71 | font.pointSize: 20 72 | color: "black" 73 | anchors.top: wrapperPanel.top 74 | anchors.topMargin: 10 75 | font.underline: false 76 | } 77 | Button { 78 | id : loadModel 79 | anchors.top: title.bottom 80 | anchors.topMargin: 15 81 | text: qsTr("Load") 82 | onClicked: { 83 | if (modelSelector.model[modelSelector.currentIndex]) { 84 | var requestLoadModel = getRequestObj("POST", 'current_model') 85 | if (call( 86 | requestLoadModel, 87 | { 88 | model_name: modelSelector.model[modelSelector.currentIndex] 89 | }, 90 | function(response) { 91 | logStatus(response) 92 | } 93 | )) { 94 | logStatus('Loading model ' + modelSelector.model[modelSelector.currentIndex]) 95 | } 96 | } 97 | } 98 | } 99 | ComboBox { 100 | id: modelSelector 101 | anchors.top: modelSelector.top 102 | anchors.left: modelSelector.right 103 | anchors.leftMargin: 10 104 | model: [] 105 | width: 100 106 | visible: false 107 | } 108 | Button { 109 | id : buttonOpenFile 110 | text: qsTr("Compose") 111 | anchors.top: loadModel.top 112 | anchors.topMargin: 30 113 | anchors.bottomMargin: 10 114 | onClicked: { 115 | var cursor = curScore.newCursor(); 116 | cursor.rewind(1); 117 | var startStaff = cursor.staffIdx; 118 | var startTick = cursor.tick; 119 | cursor.rewind(2); 120 | var endStaff = cursor.staffIdx; 121 | var endTick = cursor.tick; 122 | var extension = 'mxl' 123 | // var targetFile = tempPath() + "/my_file." 
+ extension 124 | myFile.remove(); 125 | logStatus(myFile.source); 126 | var res = writeScore(mainMuseScoreObj.curScore, myFile.source, extension) 127 | logStatus(res); 128 | var request = getRequestObj("POST", 'compose') 129 | if (call( 130 | request, 131 | { 132 | start_staff: startStaff, 133 | end_staff: endStaff, 134 | start_tick: startTick, 135 | end_tick: endTick, 136 | // xml_string: myFile.read(), 137 | file_path: myFile.source 138 | }, 139 | function(response) { 140 | if (response) { 141 | logStatus('Done composing') 142 | // myFileXml.write(response) 143 | readScore(myFileXml.source) 144 | 145 | } else { 146 | logStatus('Got Empty Response when composing') 147 | } 148 | } 149 | )) { 150 | logStatus('Composing...') 151 | } 152 | } 153 | } 154 | Label { 155 | id: statusLabel 156 | wrapMode: Text.WordWrap 157 | text: '' 158 | color: 'grey' 159 | font.pointSize:12 160 | anchors.left: wrapperPanel.left 161 | anchors.top: buttonOpenFile.top 162 | anchors.leftMargin: 10 163 | anchors.topMargin: 30 164 | visible: false 165 | } 166 | } 167 | function logStatus(text) { 168 | mainMuseScoreObj.linesLogged++; 169 | if (mainMuseScoreObj.linesLogged > 15) { 170 | // break the textblock into an array of lines 171 | var lines = statusLabel.text.split("\r\n"); 172 | // remove one line, starting at the first position 173 | lines.splice(0,1); 174 | // join the array back into a single string 175 | statusLabel.text = lines.join("\r\n"); 176 | } 177 | statusLabel.text += '- ' + text + "\r\n" 178 | } 179 | function getRequestObj(method, endpoint) { 180 | console.debug('calling endpoint ' + endpoint) 181 | var request = new XMLHttpRequest() 182 | endpoint = endpoint || '' 183 | request.open(method, serverAddress + endpoint, true) 184 | return request 185 | } 186 | function call(request, params, cb) { 187 | if (mainMuseScoreObj.loading) { 188 | logStatus('refusing to call server') 189 | return false 190 | } 191 | request.onreadystatechange = function() { 192 | if (request.readyState == XMLHttpRequest.DONE) { 193 | mainMuseScoreObj.loading = false; 194 | cb(request.responseText); 195 | } 196 | } 197 | if (params) { 198 | request.setRequestHeader("Content-Type", "application/x-www-form-urlencoded") 199 | var pairs = []; 200 | for (var prop in params) { 201 | if (params.hasOwnProperty(prop)) { 202 | var k = encodeURIComponent(prop), 203 | v = encodeURIComponent(params[prop]); 204 | pairs.push( k + "=" + v); 205 | } 206 | } 207 | 208 | const content = pairs.join('&') 209 | console.debug('params ' + content) 210 | mainMuseScoreObj.loading = true; 211 | request.send(content) 212 | } else { 213 | mainMuseScoreObj.loading = true; 214 | request.send() 215 | } 216 | return true 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /dl_dataset_and_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget https://www.dropbox.com/s/iuz4ml857ycyyat/deepbach_pytorch_resources.tar.gz 3 | tar xvfz deepbach_pytorch_resources.tar.gz 4 | # move resources/{datasets,models} to datasets/ and models/ 5 | mv resources/dataset_cache DatasetManager/dataset_cache 6 | mv resources/models ./models 7 | rm -R resources -------------------------------------------------------------------------------- /entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate deepbach_pytorch 4 | LC_ALL=C.UTF-8 LANG=C.UTF-8 python flask_server.py "$@" 5 | 
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: deepbach_pytorch 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - blas=1.0=mkl 7 | - ca-certificates=2018.03.07=0 8 | - certifi=2018.11.29=py36_0 9 | - cffi=1.11.5=py36he75722e_1 10 | - click=7.0=py36_0 11 | - freetype=2.9.1=h8a8886c_1 12 | - intel-openmp=2019.1=144 13 | - jpeg=9b=h024ee3a_2 14 | - libedit=3.1.20170329=h6b74fdf_2 15 | - libffi=3.2.1=hd88cf55_4 16 | - libgcc-ng=8.2.0=hdf63c60_1 17 | - libgfortran-ng=7.3.0=hdf63c60_0 18 | - libpng=1.6.35=hbc83047_0 19 | - libstdcxx-ng=8.2.0=hdf63c60_1 20 | - libtiff=4.0.9=he85c1e1_2 21 | - mkl=2019.1=144 22 | - mkl_fft=1.0.6=py36hd81dba3_0 23 | - mkl_random=1.0.2=py36hd81dba3_0 24 | - ncurses=6.1=he6710b0_1 25 | - ninja=1.8.2=py36h6bb024c_1 26 | - numpy=1.15.4=py36h7e9f1db_0 27 | - numpy-base=1.15.4=py36hde5b4d6_0 28 | - olefile=0.46=py36_0 29 | - openssl=1.1.1a=h7b6447c_0 30 | - pillow=5.3.0=py36h34e0f95_0 31 | - pip=18.1=py36_0 32 | - pycparser=2.19=py36_0 33 | - python=3.6.8=h0371630_0 34 | - readline=7.0=h7b6447c_5 35 | - setuptools=40.6.3=py36_0 36 | - six=1.12.0=py36_0 37 | - sqlite=3.26.0=h7b6447c_0 38 | - tk=8.6.8=hbc83047_0 39 | - wheel=0.32.3=py36_0 40 | - xz=5.2.4=h14c3975_4 41 | - zlib=1.2.11=h7b6447c_3 42 | - pytorch=1.0.0 43 | - cuda100 44 | - torchvision=0.2.1=py_2 45 | - pip: 46 | - decorator==4.3.0 47 | - flask==1.0.2 48 | - flask-cors==3.0.7 49 | - itsdangerous==1.1.0 50 | - jinja2==2.10 51 | - markupsafe==1.1.0 52 | - music21==5.5.0 53 | - torch==1.0.0 54 | - tqdm==4.29.0 55 | - werkzeug==0.14.1 56 | prefix: /home/gaetan/Public/Python/anaconda3/envs/deepbach_pytorch 57 | 58 | -------------------------------------------------------------------------------- /flask_server.py: -------------------------------------------------------------------------------- 1 | from DatasetManager.chorale_dataset import ChoraleDataset 2 | from DatasetManager.dataset_manager import DatasetManager 3 | from DatasetManager.metadata import FermataMetadata, TickMetadata, KeyMetadata 4 | from DeepBach.model_manager import DeepBach 5 | 6 | from music21 import musicxml, metadata 7 | import music21 8 | 9 | import flask 10 | from flask import Flask, request, make_response 11 | from flask_cors import CORS 12 | 13 | import logging 14 | from logging import handlers as logging_handlers 15 | import sys 16 | 17 | import torch 18 | import math 19 | from typing import List, Optional 20 | import click 21 | import os 22 | 23 | app = Flask(__name__) 24 | CORS(app) 25 | 26 | app.config['UPLOAD_FOLDER'] = './uploads' 27 | ALLOWED_EXTENSIONS = {'midi'} 28 | 29 | # INITIALIZATION 30 | xml_response_headers = {"Content-Type": "text/xml", 31 | "charset": "utf-8" 32 | } 33 | mp3_response_headers = {"Content-Type": "audio/mpeg3" 34 | } 35 | 36 | deepbach = None 37 | _num_iterations = None 38 | _sequence_length_ticks = None 39 | _ticks_per_quarter = None 40 | 41 | # TODO use this parameter or extract it from the metadata somehow 42 | timesignature = music21.meter.TimeSignature('4/4') 43 | 44 | # generation parameters 45 | # todo put in click? 
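# batch_size_per_voice (added comment): number of time positions per voice that are
# resampled in parallel at each pseudo-Gibbs iteration (see DeepBach.parallel_gibbs)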
46 | batch_size_per_voice = 8 47 | 48 | metadatas = [ 49 | FermataMetadata(), 50 | TickMetadata(subdivision=_ticks_per_quarter), 51 | KeyMetadata() 52 | ] 53 | 54 | 55 | def get_fermatas_tensor(metadata_tensor: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Extract the fermatas tensor from a metadata tensor 58 | """ 59 | fermatas_index = [m.__class__ for m in metadatas].index( 60 | FermataMetadata().__class__) 61 | # fermatas are shared across all voices so we only consider the first voice 62 | soprano_voice_metadata = metadata_tensor[0] 63 | 64 | # `soprano_voice_metadata` has shape 65 | # `(sequence_duration, len(metadatas + 1))` (accouting for the voice 66 | # index metadata) 67 | # Extract fermatas for all steps 68 | return soprano_voice_metadata[:, fermatas_index] 69 | 70 | 71 | @click.command() 72 | @click.option('--note_embedding_dim', default=20, 73 | help='size of the note embeddings') 74 | @click.option('--meta_embedding_dim', default=20, 75 | help='size of the metadata embeddings') 76 | @click.option('--num_layers', default=2, 77 | help='number of layers of the LSTMs') 78 | @click.option('--lstm_hidden_size', default=256, 79 | help='hidden size of the LSTMs') 80 | @click.option('--dropout_lstm', default=0.5, 81 | help='amount of dropout between LSTM layers') 82 | @click.option('--dropout_lstm', default=0.5, 83 | help='amount of dropout between LSTM layers') 84 | @click.option('--linear_hidden_size', default=256, 85 | help='hidden size of the Linear layers') 86 | @click.option('--num_iterations', default=50, 87 | help='number of parallel pseudo-Gibbs sampling iterations (for a single update)') 88 | @click.option('--sequence_length_ticks', default=64, 89 | help='length of the generated chorale (in ticks)') 90 | @click.option('--ticks_per_quarter', default=4, 91 | help='number of ticks per quarter note') 92 | @click.option('--port', default=5000, 93 | help='port to serve on') 94 | def init_app(note_embedding_dim, 95 | meta_embedding_dim, 96 | num_layers, 97 | lstm_hidden_size, 98 | dropout_lstm, 99 | linear_hidden_size, 100 | num_iterations, 101 | sequence_length_ticks, 102 | ticks_per_quarter, 103 | port 104 | ): 105 | global metadatas 106 | global _sequence_length_ticks 107 | global _num_iterations 108 | global _ticks_per_quarter 109 | 110 | _ticks_per_quarter = ticks_per_quarter 111 | _sequence_length_ticks = sequence_length_ticks 112 | _num_iterations = num_iterations 113 | 114 | dataset_manager = DatasetManager() 115 | chorale_dataset_kwargs = { 116 | 'voice_ids': [0, 1, 2, 3], 117 | 'metadatas': metadatas, 118 | 'sequences_size': 8, 119 | 'subdivision': 4 120 | } 121 | 122 | bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset( 123 | name='bach_chorales', 124 | **chorale_dataset_kwargs 125 | ) 126 | assert sequence_length_ticks % bach_chorales_dataset.subdivision == 0 127 | 128 | global deepbach 129 | deepbach = DeepBach( 130 | dataset=bach_chorales_dataset, 131 | note_embedding_dim=note_embedding_dim, 132 | meta_embedding_dim=meta_embedding_dim, 133 | num_layers=num_layers, 134 | lstm_hidden_size=lstm_hidden_size, 135 | dropout_lstm=dropout_lstm, 136 | linear_hidden_size=linear_hidden_size 137 | ) 138 | deepbach.load() 139 | deepbach.cuda() 140 | 141 | # launch the script 142 | # use threaded=True to fix Chrome/Chromium engine hanging on requests 143 | # [https://stackoverflow.com/a/30670626] 144 | local_only = False 145 | if local_only: 146 | # accessible only locally: 147 | app.run(threaded=True) 148 | else: 149 | # accessible from outside: 150 | 
app.run(host='0.0.0.0', port=port, threaded=True) 151 | 152 | 153 | @app.route('/generate', methods=['GET', 'POST']) 154 | def compose(): 155 | """ 156 | Return a new, generated sheet 157 | Usage: 158 | - Request: empty, generation is done in an unconstrained fashion 159 | - Response: a sheet, MusicXML 160 | """ 161 | global deepbach 162 | global _sequence_length_ticks 163 | global _num_iterations 164 | 165 | # Use more iterations for the initial generation step 166 | # FIXME hardcoded 4/4 time-signature 167 | num_measures_generation = math.floor(_sequence_length_ticks / 168 | deepbach.dataset.subdivision) 169 | initial_num_iterations = math.floor(_num_iterations * num_measures_generation 170 | / 3) # HACK hardcoded reduction 171 | 172 | (generated_sheet, _, generated_metadata_tensor) = ( 173 | deepbach.generation(num_iterations=initial_num_iterations, 174 | sequence_length_ticks=_sequence_length_ticks) 175 | ) 176 | 177 | generated_fermatas_tensor = get_fermatas_tensor(generated_metadata_tensor) 178 | 179 | # convert sheet to xml 180 | response = sheet_and_fermatas_to_json_response( 181 | generated_sheet, generated_fermatas_tensor) 182 | return response 183 | 184 | 185 | @app.route('/test-generate', methods=['GET']) 186 | def ex(): 187 | _current_sheet = next(music21.corpus.chorales.Iterator()) 188 | return sheet_to_xml_response(_current_sheet) 189 | 190 | 191 | @app.route('/musicxml-to-midi', methods=['POST']) 192 | def get_midi(): 193 | """ 194 | Convert the provided MusicXML sheet to MIDI and return it 195 | Usage: 196 | POST -d @sheet.mxml /musicxml-to-midi 197 | - Request: the payload is expected to contain the sheet to convert, in 198 | MusicXML format 199 | - Response: a MIDI file 200 | """ 201 | sheetString = request.data 202 | sheet = music21.converter.parseData(sheetString, format="musicxml") 203 | insert_musicxml_metadata(sheet) 204 | 205 | return sheet_to_midi_response(sheet) 206 | 207 | 208 | @app.route('/timerange-change', methods=['POST']) 209 | def timerange_change(): 210 | """ 211 | Perform local re-generation on a sheet and return the updated sheet 212 | Usage: 213 | POST /timerange-change?time_range_start_beat=XXX&time_range_end_beat=XXX 214 | - Request: 215 | The payload is expected to be a JSON with the following keys: 216 | * 'sheet': a string containing the sheet to modify, in MusicXML 217 | format 218 | * 'fermatas': a list of integers describing the positions of 219 | fermatas in the sheet 220 | TODO: could store the fermatas in the MusicXML client-side 221 | The start and end positions (in beats) of the portion to regenerate 222 | are passed as arguments in the URL: 223 | * time_range_start_quarter, integer: 224 | - Response: 225 | A JSON document with same schema as the request containing the 226 | updated sheet and fermatas 227 | """ 228 | global deepbach 229 | global _num_iterations 230 | global _sequence_length_ticks 231 | request_parameters = parse_timerange_request(request) 232 | time_range_start_quarter = request_parameters['time_range_start_quarter'] 233 | time_range_end_quarter = request_parameters['time_range_end_quarter'] 234 | fermatas_tensor = request_parameters['fermatas_tensor'] 235 | 236 | input_sheet = request_parameters['sheet'] 237 | 238 | time_index_range_ticks = [ 239 | time_range_start_quarter * deepbach.dataset.subdivision, 240 | time_range_end_quarter * deepbach.dataset.subdivision] 241 | 242 | input_tensor_sheet, input_tensor_metadata = ( 243 | deepbach.dataset.transposed_score_and_metadata_tensors( 244 | input_sheet, 0) 245 | ) 246 | 247 
| (output_sheet, 248 | output_tensor_sheet, 249 | output_tensor_metadata) = deepbach.generation( 250 | tensor_chorale=input_tensor_sheet, 251 | tensor_metadata=input_tensor_metadata, 252 | temperature=1., 253 | batch_size_per_voice=batch_size_per_voice, 254 | num_iterations=_num_iterations, 255 | sequence_length_ticks=_sequence_length_ticks, 256 | time_index_range_ticks=time_index_range_ticks, 257 | fermatas=fermatas_tensor 258 | ) 259 | 260 | output_fermatas_tensor = get_fermatas_tensor(output_tensor_metadata) 261 | 262 | # create JSON response 263 | response = sheet_and_fermatas_to_json_response( 264 | output_sheet, output_fermatas_tensor) 265 | return response 266 | 267 | 268 | @app.route('/analyze-notes', methods=['POST']) 269 | def dummy_read_audio_file(): 270 | global deepbach 271 | import wave 272 | print(request.args) 273 | print(request.files) 274 | chunk = 1024 275 | audio_fp = wave.open(request.files['audio'], 'rb') 276 | data = audio_fp.readframes(chunk) 277 | print(data) 278 | notes = ['C', 'D', 'Toto', 'Tata'] 279 | 280 | return flask.jsonify({'success': True, 'notes': notes}) 281 | 282 | 283 | def insert_musicxml_metadata(sheet: music21.stream.Stream): 284 | """ 285 | Insert various metadata into the provided XML document 286 | The timesignature in particular is required for proper MIDI conversion 287 | """ 288 | global timesignature 289 | 290 | from music21.clef import TrebleClef, BassClef, Treble8vbClef 291 | for part, name, clef in zip( 292 | sheet.parts, 293 | ['soprano', 'alto', 'tenor', 'bass'], 294 | [TrebleClef(), TrebleClef(), Treble8vbClef(), BassClef()] 295 | ): 296 | # empty_part = part.template() 297 | part.insert(0, timesignature) 298 | part.insert(0, clef) 299 | part.id = name 300 | part.partName = name 301 | 302 | md = metadata.Metadata() 303 | sheet.insert(0, md) 304 | 305 | # required for proper musicXML formatting 306 | sheet.metadata.title = 'DeepBach' 307 | sheet.metadata.composer = 'DeepBach' 308 | 309 | 310 | def parse_fermatas(fermatas_list: List[int]) -> Optional[torch.Tensor]: 311 | """ 312 | Parses fermata GET option, given at the quarter note level 313 | """ 314 | global _sequence_length_ticks 315 | # the data is expected to be provided as a list in the request 316 | return fermatas_to_tensor(fermatas_list) 317 | 318 | 319 | def fermatas_to_tensor(fermatas: List[int]) -> torch.Tensor: 320 | """ 321 | Convert a list of fermata positions (in beats) into a subdivion-rate tensor 322 | """ 323 | global _sequence_length_ticks 324 | global deepbach 325 | subdivision = deepbach.dataset.subdivision 326 | sequence_length_quarterNotes = math.floor(_sequence_length_ticks / subdivision) 327 | 328 | fermatas_tensor_quarterNotes = torch.zeros(sequence_length_quarterNotes) 329 | fermatas_tensor_quarterNotes[fermatas] = 1 330 | # expand the tensor to the subdivision level 331 | fermatas_tensor = (fermatas_tensor_quarterNotes 332 | .repeat((subdivision, 1)) 333 | .t() 334 | .contiguous()) 335 | return fermatas_tensor.view(_sequence_length_ticks) 336 | 337 | 338 | def fermatas_tensor_to_list(fermatas_tensor: torch.Tensor) -> List[int]: 339 | """ 340 | Convert a binary fermatas tensor into a list of positions (in beats) 341 | """ 342 | global _sequence_length_ticks 343 | global deepbach 344 | 345 | subdivision = deepbach.dataset.subdivision 346 | 347 | # subsample fermatas to beat rate 348 | beat_rate_fermatas_tensor = fermatas_tensor[::subdivision] 349 | 350 | # pick positions of active fermatas 351 | fermatas_positions_tensor = 
beat_rate_fermatas_tensor.nonzero().squeeze() 352 | fermatas = fermatas_positions_tensor.int().tolist() 353 | 354 | return fermatas 355 | 356 | 357 | def parse_timerange_request(request): 358 | """ 359 | must cast 360 | :param req: 361 | :return: 362 | """ 363 | json_data = request.get_json(force=True) 364 | time_range_start_quarter = int(request.args.get('time_range_start_quarter')) 365 | time_range_end_quarter = int(request.args.get('time_range_end_quarter')) 366 | fermatas_tensor = parse_fermatas(json_data['fermatas']) 367 | 368 | sheet = music21.converter.parseData(json_data['sheet'], format="musicxml") 369 | 370 | return { 371 | 'sheet': sheet, 372 | 'time_range_start_quarter': time_range_start_quarter, 373 | 'time_range_end_quarter': time_range_end_quarter, 374 | 'fermatas_tensor': fermatas_tensor 375 | } 376 | 377 | 378 | def sheet_to_xml_bytes(sheet: music21.stream.Stream): 379 | """Convert a music21 sheet to a MusicXML document""" 380 | # first insert necessary MusicXML metadata 381 | insert_musicxml_metadata(sheet) 382 | 383 | sheet_to_xml_bytes = musicxml.m21ToXml.GeneralObjectExporter(sheet).parse() 384 | 385 | return sheet_to_xml_bytes 386 | 387 | 388 | def sheet_to_xml_response(sheet: music21.stream.Stream): 389 | """Generate and send XML sheet""" 390 | xml_sheet_bytes = sheet_to_xml_bytes(sheet) 391 | 392 | response = flask.make_response((xml_sheet_bytes, xml_response_headers)) 393 | return response 394 | 395 | 396 | def sheet_and_fermatas_to_json_response(sheet: music21.stream.Stream, 397 | fermatas_tensor: torch.Tensor): 398 | sheet_xml_string = sheet_to_xml_bytes(sheet).decode('utf-8') 399 | fermatas_list = fermatas_tensor_to_list(fermatas_tensor) 400 | 401 | print(fermatas_list) 402 | 403 | return flask.jsonify({ 404 | 'sheet': sheet_xml_string, 405 | 'fermatas': fermatas_list 406 | }) 407 | 408 | 409 | def sheet_to_midi_response(sheet): 410 | """ 411 | Convert the provided sheet to midi and send it as a file 412 | """ 413 | midiFile = sheet.write('midi') 414 | return flask.send_file(midiFile, mimetype="audio/midi", 415 | cache_timeout=-1 # disable cache 416 | ) 417 | 418 | 419 | def sheet_to_mp3_response(sheet): 420 | """Generate and send MP3 file 421 | Uses server-side `timidity` 422 | """ 423 | sheet.write('midi', fp='./uploads/midi.mid') 424 | os.system(f'rm uploads/midi.mp3') 425 | os.system(f'timidity uploads/midi.mid -Ow -o - | ' 426 | f'ffmpeg -i - -acodec libmp3lame -ab 64k ' 427 | f'uploads/midi.mp3') 428 | return flask.send_file('uploads/midi.mp3') 429 | 430 | 431 | if __name__ == '__main__': 432 | file_handler = logging_handlers.RotatingFileHandler( 433 | 'app.log', maxBytes=10000, backupCount=5) 434 | 435 | app.logger.addHandler(file_handler) 436 | app.logger.setLevel(logging.INFO) 437 | init_app() 438 | -------------------------------------------------------------------------------- /musescore_flask_server.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import pickle 4 | import click 5 | import tempfile 6 | from glob import glob 7 | import subprocess 8 | 9 | import music21 10 | import numpy as np 11 | from flask import Flask, request, make_response, jsonify 12 | from music21 import musicxml, converter 13 | from tqdm import tqdm 14 | import torch 15 | import logging 16 | from logging import handlers as logging_handlers 17 | 18 | from DatasetManager.chorale_dataset import ChoraleDataset 19 | from DatasetManager.dataset_manager import DatasetManager 20 | from DatasetManager.metadata import 
FermataMetadata, TickMetadata, KeyMetadata 21 | from DeepBach.model_manager import DeepBach 22 | 23 | UPLOAD_FOLDER = '/tmp' 24 | ALLOWED_EXTENSIONS = {'xml', 'mxl', 'mid', 'midi'} 25 | 26 | app = Flask(__name__) 27 | 28 | deepbach = None 29 | _tensor_metadata = None 30 | _num_iterations = None 31 | _sequence_length_ticks = None 32 | _ticks_per_quarter = None 33 | _tensor_sheet = None 34 | 35 | # TODO use this parameter or extract it from the metadata somehow 36 | timesignature = music21.meter.TimeSignature('4/4') 37 | 38 | # generation parameters 39 | # todo put in click? 40 | batch_size_per_voice = 8 41 | 42 | metadatas = [ 43 | FermataMetadata(), 44 | TickMetadata(subdivision=_ticks_per_quarter), 45 | KeyMetadata() 46 | ] 47 | 48 | response_headers = {"Content-Type": "text/html", 49 | "charset": "utf-8" 50 | } 51 | 52 | 53 | @click.command() 54 | @click.option('--note_embedding_dim', default=20, 55 | help='size of the note embeddings') 56 | @click.option('--meta_embedding_dim', default=20, 57 | help='size of the metadata embeddings') 58 | @click.option('--num_layers', default=2, 59 | help='number of layers of the LSTMs') 60 | @click.option('--lstm_hidden_size', default=256, 61 | help='hidden size of the LSTMs') 62 | @click.option('--dropout_lstm', default=0.5, 63 | help='amount of dropout between LSTM layers') 64 | @click.option('--dropout_lstm', default=0.5, 65 | help='amount of dropout between LSTM layers') 66 | @click.option('--linear_hidden_size', default=256, 67 | help='hidden size of the Linear layers') 68 | @click.option('--num_iterations', default=100, 69 | help='number of parallel pseudo-Gibbs sampling iterations (for a single update)') 70 | @click.option('--sequence_length_ticks', default=64, 71 | help='length of the generated chorale (in ticks)') 72 | @click.option('--ticks_per_quarter', default=4, 73 | help='number of ticks per quarter note') 74 | @click.option('--port', default=5000, 75 | help='port to serve on') 76 | def init_app(note_embedding_dim, 77 | meta_embedding_dim, 78 | num_layers, 79 | lstm_hidden_size, 80 | dropout_lstm, 81 | linear_hidden_size, 82 | num_iterations, 83 | sequence_length_ticks, 84 | ticks_per_quarter, 85 | port 86 | ): 87 | global metadatas 88 | global _sequence_length_ticks 89 | global _num_iterations 90 | global _ticks_per_quarter 91 | global bach_chorales_dataset 92 | 93 | _ticks_per_quarter = ticks_per_quarter 94 | _sequence_length_ticks = sequence_length_ticks 95 | _num_iterations = num_iterations 96 | 97 | dataset_manager = DatasetManager() 98 | chorale_dataset_kwargs = { 99 | 'voice_ids': [0, 1, 2, 3], 100 | 'metadatas': metadatas, 101 | 'sequences_size': 8, 102 | 'subdivision': 4 103 | } 104 | 105 | _bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset( 106 | name='bach_chorales', 107 | **chorale_dataset_kwargs 108 | ) 109 | bach_chorales_dataset = _bach_chorales_dataset 110 | 111 | assert sequence_length_ticks % bach_chorales_dataset.subdivision == 0 112 | 113 | global deepbach 114 | deepbach = DeepBach( 115 | dataset=bach_chorales_dataset, 116 | note_embedding_dim=note_embedding_dim, 117 | meta_embedding_dim=meta_embedding_dim, 118 | num_layers=num_layers, 119 | lstm_hidden_size=lstm_hidden_size, 120 | dropout_lstm=dropout_lstm, 121 | linear_hidden_size=linear_hidden_size 122 | ) 123 | deepbach.load() 124 | deepbach.cuda() 125 | 126 | # launch the script 127 | # use threaded=True to fix Chrome/Chromium engine hanging on requests 128 | # [https://stackoverflow.com/a/30670626] 129 | local_only = False 130 | if local_only: 
131 | # accessible only locally: 132 | app.run(threaded=True) 133 | else: 134 | # accessible from outside: 135 | app.run(host='0.0.0.0', port=port, threaded=True) 136 | 137 | 138 | def get_fermatas_tensor(metadata_tensor: torch.Tensor) -> torch.Tensor: 139 | """ 140 | Extract the fermatas tensor from a metadata tensor 141 | """ 142 | fermatas_index = [m.__class__ for m in metadatas].index( 143 | FermataMetadata().__class__) 144 | # fermatas are shared across all voices so we only consider the first voice 145 | soprano_voice_metadata = metadata_tensor[0] 146 | 147 | # `soprano_voice_metadata` has shape 148 | # `(sequence_duration, len(metadatas + 1))` (accouting for the voice 149 | # index metadata) 150 | # Extract fermatas for all steps 151 | return soprano_voice_metadata[:, fermatas_index] 152 | 153 | 154 | def allowed_file(filename): 155 | return '.' in filename and \ 156 | filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS 157 | 158 | 159 | def compose_from_scratch(): 160 | """ 161 | Return a new, generated sheet 162 | Usage: 163 | - Request: empty, generation is done in an unconstrained fashion 164 | - Response: a sheet, MusicXML 165 | """ 166 | global deepbach 167 | global _sequence_length_ticks 168 | global _num_iterations 169 | global _tensor_sheet 170 | global _tensor_metadata 171 | 172 | # Use more iterations for the initial generation step 173 | # FIXME hardcoded 4/4 time-signature 174 | num_measures_generation = math.floor(_sequence_length_ticks / 175 | deepbach.dataset.subdivision) 176 | initial_num_iterations = math.floor(_num_iterations * num_measures_generation 177 | / 3) # HACK hardcoded reduction 178 | 179 | (generated_sheet, _tensor_sheet, _tensor_metadata) = ( 180 | deepbach.generation(num_iterations=initial_num_iterations, 181 | sequence_length_ticks=_sequence_length_ticks) 182 | ) 183 | return generated_sheet 184 | 185 | 186 | @app.route('/compose', methods=['POST']) 187 | def compose(): 188 | global deepbach 189 | global _num_iterations 190 | global _sequence_length_ticks 191 | global _tensor_sheet 192 | global _tensor_metadata 193 | global bach_chorales_dataset 194 | 195 | # global models 196 | NUM_MIDI_TICKS_IN_SIXTEENTH_NOTE = 120 197 | start_tick_selection = int(float( 198 | request.form['start_tick']) / NUM_MIDI_TICKS_IN_SIXTEENTH_NOTE) 199 | end_tick_selection = int( 200 | float(request.form['end_tick']) / NUM_MIDI_TICKS_IN_SIXTEENTH_NOTE) 201 | file_path = request.form['file_path'] 202 | root, ext = os.path.splitext(file_path) 203 | dir = os.path.dirname(file_path) 204 | assert ext == '.mxl' 205 | xml_file = f'{root}.xml' 206 | 207 | # if no selection REGENERATE and set chorale length 208 | if start_tick_selection == 0 and end_tick_selection == 0: 209 | generated_sheet = compose_from_scratch() 210 | generated_sheet.write('xml', xml_file) 211 | return sheet_to_response(generated_sheet) 212 | else: 213 | # --- Parse request--- 214 | # Old method: does not work because the MuseScore plugin does not export to xml but only to compressed .mxl 215 | # with tempfile.NamedTemporaryFile(mode='wb', suffix='.xml') as file: 216 | # print(file.name) 217 | # xml_string = request.form['xml_string'] 218 | # file.write(xml_string) 219 | # music21_parsed_chorale = converter.parse(file.name) 220 | 221 | # file_path points to an mxl file: we extract it 222 | subprocess.run(f'unzip -o {file_path} -d {dir}', shell=True) 223 | music21_parsed_chorale = converter.parse(xml_file) 224 | 225 | 226 | _tensor_sheet, _tensor_metadata = 
bach_chorales_dataset.transposed_score_and_metadata_tensors(music21_parsed_chorale, semi_tone=0) 227 | 228 | start_voice_index = int(request.form['start_staff']) 229 | end_voice_index = int(request.form['end_staff']) + 1 230 | 231 | time_index_range_ticks = [start_tick_selection, end_tick_selection] 232 | 233 | region_length = end_tick_selection - start_tick_selection 234 | 235 | # compute batch_size_per_voice: 236 | if region_length <= 8: 237 | batch_size_per_voice = 2 238 | elif region_length <= 16: 239 | batch_size_per_voice = 4 240 | else: 241 | batch_size_per_voice = 8 242 | 243 | 244 | num_total_iterations = int(_num_iterations * region_length / batch_size_per_voice) 245 | 246 | fermatas_tensor = get_fermatas_tensor(_tensor_metadata) 247 | 248 | # --- Generate--- 249 | (output_sheet, 250 | _tensor_sheet, 251 | _tensor_metadata) = deepbach.generation( 252 | tensor_chorale=_tensor_sheet, 253 | tensor_metadata=_tensor_metadata, 254 | temperature=1., 255 | batch_size_per_voice=batch_size_per_voice, 256 | num_iterations=num_total_iterations, 257 | sequence_length_ticks=_sequence_length_ticks, 258 | time_index_range_ticks=time_index_range_ticks, 259 | fermatas=fermatas_tensor, 260 | voice_index_range=[start_voice_index, end_voice_index], 261 | random_init=True 262 | ) 263 | 264 | 265 | 266 | output_sheet.write('xml', xml_file) 267 | response = sheet_to_response(sheet=output_sheet) 268 | return response 269 | 270 | 271 | def get_fermatas_tensor(metadata_tensor: torch.Tensor) -> torch.Tensor: 272 | """ 273 | Extract the fermatas tensor from a metadata tensor 274 | """ 275 | fermatas_index = [m.__class__ for m in metadatas].index( 276 | FermataMetadata().__class__) 277 | # fermatas are shared across all voices so we only consider the first voice 278 | soprano_voice_metadata = metadata_tensor[0] 279 | 280 | # `soprano_voice_metadata` has shape 281 | # `(sequence_duration, len(metadatas + 1))` (accouting for the voice 282 | # index metadata) 283 | # Extract fermatas for all steps 284 | return soprano_voice_metadata[:, fermatas_index] 285 | 286 | 287 | def sheet_to_response(sheet): 288 | # convert sheet to xml 289 | goe = musicxml.m21ToXml.GeneralObjectExporter(sheet) 290 | xml_chorale_string = goe.parse() 291 | 292 | response = make_response((xml_chorale_string, response_headers)) 293 | return response 294 | 295 | 296 | @app.route('/test', methods=['POST', 'GET']) 297 | def test_generation(): 298 | response = make_response(('TEST', response_headers)) 299 | 300 | if request.method == 'POST': 301 | print(request) 302 | 303 | return response 304 | 305 | 306 | @app.route('/models', methods=['GET']) 307 | def get_models(): 308 | models_list = ['Deepbach'] 309 | return jsonify(models_list) 310 | 311 | 312 | @app.route('/current_model', methods=['POST', 'PUT']) 313 | def current_model_update(): 314 | return 'Model is only loaded once' 315 | 316 | 317 | @app.route('/current_model', methods=['GET']) 318 | def current_model_get(): 319 | return 'DeepBach' 320 | 321 | 322 | if __name__ == '__main__': 323 | file_handler = logging_handlers.RotatingFileHandler( 324 | 'app.log', maxBytes=10000, backupCount=5) 325 | 326 | app.logger.addHandler(file_handler) 327 | app.logger.setLevel(logging.INFO) 328 | init_app() 329 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | import tempfile 5 | 6 | import random 7 | import numpy as np 8 | 
import torch 9 | import click 10 | from cog import BasePredictor, Input, Path 11 | import music21 12 | from midi2audio import FluidSynth 13 | 14 | from DatasetManager.chorale_dataset import ChoraleDataset 15 | from DatasetManager.dataset_manager import DatasetManager 16 | from DatasetManager.metadata import FermataMetadata, KeyMetadata, TickMetadata 17 | from DeepBach.model_manager import DeepBach 18 | 19 | 20 | class Predictor(BasePredictor): 21 | def setup(self): 22 | """Load the model""" 23 | 24 | # music21.environment.set("musicxmlPath", "/bin/true") 25 | 26 | note_embedding_dim = 20 27 | meta_embedding_dim = 20 28 | num_layers = 2 29 | lstm_hidden_size = 256 30 | dropout_lstm = 0.5 31 | linear_hidden_size = 256 32 | batch_size = 256 33 | num_epochs = 5 34 | train = False 35 | num_iterations = 500 36 | sequence_length_ticks = 64 37 | 38 | dataset_manager = DatasetManager() 39 | 40 | metadatas = [FermataMetadata(), TickMetadata(subdivision=4), KeyMetadata()] 41 | chorale_dataset_kwargs = { 42 | "voice_ids": [0, 1, 2, 3], 43 | "metadatas": metadatas, 44 | "sequences_size": 8, 45 | "subdivision": 4, 46 | } 47 | bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset( 48 | name="bach_chorales", **chorale_dataset_kwargs 49 | ) 50 | dataset = bach_chorales_dataset 51 | 52 | self.deepbach = DeepBach( 53 | dataset=dataset, 54 | note_embedding_dim=note_embedding_dim, 55 | meta_embedding_dim=meta_embedding_dim, 56 | num_layers=num_layers, 57 | lstm_hidden_size=lstm_hidden_size, 58 | dropout_lstm=dropout_lstm, 59 | linear_hidden_size=linear_hidden_size, 60 | ) 61 | 62 | self.deepbach.load() 63 | 64 | # load fluidsynth fo rmidi 2 audio conversion 65 | self.fs = FluidSynth() 66 | 67 | # self.converter = music21.converter.parse('path_to_musicxml.xml') 68 | 69 | def predict( 70 | self, 71 | num_iterations: int = Input( 72 | default=500, 73 | description="Number of parallel pseudo-Gibbs sampling iterations", 74 | ), 75 | sequence_length_ticks: int = Input( 76 | default=64, ge=16, description="Length of the generated chorale (in ticks)" 77 | ), 78 | output_type: str = Input( 79 | default="audio", 80 | choices=["midi", "audio"], 81 | description="Output representation type: can be audio or midi", 82 | ), 83 | seed: int = Input(default=-1, description="Random seed, -1 for random"), 84 | ) -> Path: 85 | """Score Generation""" 86 | if seed >= 0: 87 | random.seed(seed) 88 | np.random.seed(seed) 89 | torch.use_deterministic_algorithms(True) 90 | torch.manual_seed(seed) 91 | 92 | score, tensor_chorale, tensor_metadata = self.deepbach.generation( 93 | num_iterations=num_iterations, 94 | sequence_length_ticks=sequence_length_ticks, 95 | ) 96 | 97 | if output_type == "audio": 98 | output_path_wav = Path(tempfile.mkdtemp()) / "output.wav" 99 | output_path_mp3 = Path(tempfile.mkdtemp()) / "output.mp3" 100 | 101 | midi_score = score.write("midi") 102 | self.fs.midi_to_audio(midi_score, str(output_path_wav)) 103 | 104 | subprocess.check_output( 105 | [ 106 | "ffmpeg", 107 | "-i", 108 | str(output_path_wav), 109 | "-af", 110 | "silenceremove=1:0:-50dB,aformat=dblp,areverse,silenceremove=1:0:-50dB,aformat=dblp,areverse", # strip silence 111 | str(output_path_mp3), 112 | ], 113 | ) 114 | 115 | return output_path_mp3 116 | 117 | elif output_type == "midi": 118 | output_path_midi = Path(tempfile.mkdtemp()) / "output.mid" 119 | score.write("midi", fp=output_path_midi) 120 | return output_path_midi 121 | --------------------------------------------------------------------------------
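As a quick way to exercise predict.py outside of Cog, a small driver along the following lines should work. It is a minimal sketch, not part of the repository: it assumes the Bach chorales dataset cache, the trained DeepBach weights expected by deepbach.load(), and (for audio output) a working FluidSynth install are already in place, and that it is run from the repository root so that `from predict import Predictor` resolves.

# local_generate.py -- hypothetical driver script, not part of the repository
from predict import Predictor

predictor = Predictor()
predictor.setup()  # builds the ChoraleDataset and loads the trained DeepBach model

# Generate a short chorale and write it as MIDI (this skips the FluidSynth/ffmpeg
# audio path that is used when output_type == "audio").
midi_path = predictor.predict(
    num_iterations=100,          # fewer pseudo-Gibbs iterations than the default 500
    sequence_length_ticks=64,    # 16 quarter notes at subdivision 4, i.e. four 4/4 measures
    output_type="midi",
    seed=42,                     # any value >= 0 seeds random/numpy/torch for a deterministic run
)
print(midi_path)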