├── .gitignore ├── LICENSE ├── README.md ├── assets └── data_proc_diagram.png ├── dataset ├── Dataset.md ├── analyzer.py ├── corpus │ ├── .DS_Store │ └── keep ├── midi2corpus.py ├── midi_analyzed │ ├── .DS_Store │ └── keep ├── midi_synchronized │ ├── .DS_Store │ └── keep ├── midi_transcribed │ ├── .DS_Store │ └── keep ├── representations │ ├── cond-ls2midi │ │ └── keep │ └── uncond │ │ ├── cp │ │ ├── .DS_Store │ │ ├── compile.py │ │ ├── corpus2events.py │ │ └── events2words.py │ │ ├── remi │ │ ├── .DS_Store │ │ ├── compile.py │ │ ├── corpus2events.py │ │ ├── events2words.py │ │ └── valid_fn_idx_map.json │ │ └── validation_songs.json └── synchronizer.py ├── docs └── aaai21-slides.pdf └── workspace ├── cond_ls2midi └── keep └── uncond ├── Experiments.md ├── cp-linear ├── gen_midis │ ├── get_0.mid │ ├── get_1.mid │ ├── get_10.mid │ ├── get_11.mid │ ├── get_12.mid │ ├── get_13.mid │ ├── get_14.mid │ ├── get_15.mid │ ├── get_16.mid │ ├── get_17.mid │ ├── get_18.mid │ ├── get_19.mid │ ├── get_2.mid │ ├── get_20.mid │ ├── get_21.mid │ ├── get_22.mid │ ├── get_23.mid │ ├── get_24.mid │ ├── get_25.mid │ ├── get_26.mid │ ├── get_27.mid │ ├── get_28.mid │ ├── get_29.mid │ ├── get_3.mid │ ├── get_30.mid │ ├── get_31.mid │ ├── get_32.mid │ ├── get_33.mid │ ├── get_34.mid │ ├── get_35.mid │ ├── get_36.mid │ ├── get_37.mid │ ├── get_38.mid │ ├── get_39.mid │ ├── get_4.mid │ ├── get_40.mid │ ├── get_41.mid │ ├── get_42.mid │ ├── get_43.mid │ ├── get_44.mid │ ├── get_45.mid │ ├── get_46.mid │ ├── get_47.mid │ ├── get_48.mid │ ├── get_49.mid │ ├── get_5.mid │ ├── get_6.mid │ ├── get_7.mid │ ├── get_8.mid │ └── get_9.mid ├── main-cp.py ├── runtime_stats.json └── saver.py └── remi-xl ├── config.yml ├── gen_midis ├── 170_0.mid ├── 170_1.mid ├── 170_10.mid ├── 170_11.mid ├── 170_12.mid ├── 170_13.mid ├── 170_14.mid ├── 170_15.mid ├── 170_16.mid ├── 170_17.mid ├── 170_18.mid ├── 170_19.mid ├── 170_2.mid ├── 170_3.mid ├── 170_4.mid ├── 170_5.mid ├── 170_6.mid ├── 170_7.mid ├── 170_8.mid └── 170_9.mid ├── inference.py ├── model.py ├── modules.py ├── runtime_stats.json ├── saver.py ├── songTime.csv └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | __pycache__/ 132 | .vscode/ 133 | .ipynb_checkpoints/ 134 | .DS_Store 135 | miditoolkit.egg-info/ 136 | @eaDir 137 | *.pyc 138 | *.pypirc 139 | Thumbs.db 140 | *.gz -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Compound Word Transformer 2 | 3 | 4 | Authors: [Wen-Yi Hsiao](https://github.com/wayne391), [Jen-Yu Liu](https://github.com/ciaua), [Yin-Cheng Yeh](https://github.com/yyeh26) and [Yi-Hsuan Yang](http://mac.citi.sinica.edu.tw/~yang/) 5 | 6 | [**Paper (arXiv)**](https://arxiv.org/abs/2101.02402) | [**Audio demo (Google Drive)**](https://drive.google.com/drive/folders/1G_tTpcAuVpYO-4IUGS8i8XdwoIsUix8o?usp=sharing) | [**Blog**](https://ailabs.tw/human-interaction/compound-word-transformer-generate-pop-piano-music-of-full-song-length/) | [**Colab notebook**](https://colab.research.google.com/drive/1AU8iMhy10WxHj7yt3j8S3FQvvKvgXrr0) 7 | 8 | Official PyTorch implementation of the AAAI 2021 paper "Compound Word Transformer: Learning to Compose Full-Song Music over Dynamic Directed Hypergraphs". 9 | 10 | We present a new variant of the Transformer that can process multiple consecutive tokens at a single time step. The proposed method greatly reduces the length of the resulting sequence and therefore enhances training and inference efficiency. We employ it to learn to compose expressive Pop piano music of full-song length (involving up to 10K individual tokens per song). In this repository, we open-source our **AILabs.tw Pop1K7** dataset and the code for unconditional generation.
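To see why grouping helps, here is a toy illustration (not the repo's actual tokenizer; the event names merely mirror the fields used by the scripts under `dataset/representations/uncond`): a REMI-style stream spends one time step per token, while a compound-word step packs the co-occurring fields of one event into a single group, so the same musical content takes far fewer steps.

```python
# Toy illustration of token grouping; names mirror the repo's event fields.
remi_stream = [  # one time step per token
    'Bar_None', 'Beat_0', 'Chord_C_M', 'Tempo_120',
    'Note_Pitch_60', 'Note_Velocity_64', 'Note_Duration_480',
]
cp_stream = [  # one step per group: (type, bar-beat, tempo, chord, pitch, duration, velocity)
    ('Metrical', 'Bar', 0, 0, 0, 0, 0),  # 0 marks an ignored field
    ('Metrical', 'Beat_0', 'Tempo_120', 'C_M', 0, 0, 0),
    ('Note', 0, 0, 0, 'Note_Pitch_60', 'Note_Duration_480', 'Note_Velocity_64'),
]
print(len(remi_stream), 'REMI steps vs.', len(cp_stream), 'CP steps')  # 7 vs. 3
```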
11 | 12 | 13 | ## Dependencies 14 | 15 | * python 3.6 16 | * Required packages: 17 | * madmom 18 | * miditoolkit 19 | * pytorch-fast-transformers 20 | 21 | 22 | ``chorder`` is our in-house rule-based symbolic chord recognition algorithm, developed by our former intern, [joshuachang2311](https://github.com/joshuachang2311/chorder). He is also a jazz pianist :musical_keyboard:. 23 | 24 | 25 | ## Model 26 | In this work, we conduct two scenarios of generation: 27 | * unconditional generation 28 | * To see the experimental results and discussion, please refer to [here](https://github.com/YatingMusic/compound-word-transformer/blob/main/workspace/uncond/Experiments.md). 29 | 30 | * conditional generation, lead sheet to full MIDI (ls2midi) 31 | * [**Work in progress**] We plan to open-source the code associated with this part in the future. 32 | * melody extraction (skyline) 33 | * objective metrics 34 | * model 35 | 36 | ## Dataset 37 | To prepare your own training data, please refer to the [documentation](https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/Dataset.md) for further details. 38 | Or, you can start with our **AILabs.tw Pop1K7**, which is available [here](https://drive.google.com/file/d/1qw_tVUntblIg4lW16vbpjLXVndkVtgDe/view?usp=sharing). 39 | 40 | ## Demo: Colab Notebook 41 | 42 | The Colab notebook is now available [here](https://colab.research.google.com/drive/1AU8iMhy10WxHj7yt3j8S3FQvvKvgXrr0). 43 | Thanks to our intern [AdarshKumar712](https://github.com/AdarshKumar712) for organizing the code. 44 | 45 | 46 | ## Acknowledgement 47 | - The PyTorch code for Transformer-XL is modified from [kimiyoung/transformer-xl](https://github.com/kimiyoung/transformer-xl). 48 | - Thanks to [Yu-Hua Chen](https://github.com/ss12f32v) and [Hsiao-Tzu Hung](https://github.com/annahung31) for helping organize the code. 49 | 50 | -------------------------------------------------------------------------------- /assets/data_proc_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/assets/data_proc_diagram.png -------------------------------------------------------------------------------- /dataset/Dataset.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | In this document, we demonstrate our team's standard data processing pipeline. By following the instructions and running the corresponding Python scripts, you can easily generate and customize your own dataset. 4 | 5 | 6 |

7 | 8 |

9 | 10 | 11 | ## 1. From `audio` to `midi_transcribed` 12 | We collect audio clips of piano performances from YouTube. 13 | 14 | * run Google Magenta's [onsets and frames](https://github.com/magenta/magenta/tree/master/magenta/models/onsets_frames_transcription) 15 | 16 | ## 2. From `midi_transcribed` to `midi_synchronized` 17 | In this step, we use [madmom](https://github.com/CPJKU/madmom) for beat/downbeat tracking. Next, we interpolate 480 ticks between every two adjacent beats, and map each absolute time to its corresponding tick. Lastly, we infer the tempo changes from the time intervals between adjacent beats. We choose a beat resolution of 480 because it is a common setting in modern DAWs. Note that we don't quantize any timing in this step, so the tiny timing offsets are preserved for later use. (A sketch of this beat-to-tick mapping is given at the end of this document.) 18 | 19 | * run `synchronizer.py` 20 | 21 | ## 3. From `midi_synchronized` to `midi_analyzed` 22 | In this step, we develop in-house rule-based symbolic melody extraction and chord recognition algorithms to obtain the desired information. The code for chord recognition is open-sourced [here](https://github.com/joshuachang2311/chorder). We plan to open-source the code for melody extraction in the future. 23 | 24 | * run `analyzer.py` 25 | 26 | ## 4. From `midi_analyzed` to `Corpus` 27 | We quantize everything (duration, velocity, bpm) in this step. We also append an EOS (end-of-sequence) token to each song. 28 | 29 | * run `midi2corpus.py` 30 | 31 | ## 5. From `Corpus` to `Representation` 32 | We have 2 kinds of representations - Compound Word (**CP**) and **REMI** - and 2 tasks - unconditional and conditional generation - resulting in 4 combinations. Go to the corresponding folder `/` and run the scripts. 33 | 34 | 35 | * run `corpus2events.py`: to generate human-readable tokens and re-arrange the data. 36 | * run `events2words.py`: to build the dictionary and renumber the tokens. 37 | * run `compile.py`: to discard disqualified songs that exceed the length limit, reshape the data for Transformer-XL, and generate masks for the variable lengths. 38 | 39 | --- 40 | 41 | ## AILabs.tw Pop1K7 dataset 42 | 43 | Alternatively, you can refer to [here](https://drive.google.com/drive/folders/1DY54sxeCcQfVXdGXps5lHwtRe7D_kBRV?usp=sharing) to obtain the entire workspace and the pre-processed training data, which were originally used in our paper.
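As a concrete illustration of step 2, here is a minimal, hypothetical sketch of the beat-to-tick synchronization (not the actual `synchronizer.py`): it assumes an array of beat times in seconds, as produced by madmom's tracker, interpolates 480 ticks between adjacent beats, and infers a tempo for each beat from the inter-beat interval.

```python
import numpy as np

BEAT_RESOL = 480  # ticks per beat, the resolution used throughout the pipeline

def seconds_to_ticks(times_sec, beat_times):
    """Map absolute times (seconds) to ticks by piecewise-linear interpolation,
    placing BEAT_RESOL ticks between every two adjacent beats.
    (Times outside the tracked beat range are clamped by np.interp.)"""
    beat_ticks = np.arange(len(beat_times)) * BEAT_RESOL
    return np.interp(times_sec, beat_times, beat_ticks)

def tempos_from_beats(beat_times):
    """Infer a tempo (BPM) per beat from the inter-beat intervals."""
    return 60.0 / np.diff(beat_times)

# toy usage: beats tracked at a steady 0.5 s apart (i.e. 120 BPM)
beats = np.array([0.0, 0.5, 1.0, 1.5, 2.0])
print(seconds_to_ticks([0.25, 1.1], beats))  # -> [ 240. 1056.]
print(tempos_from_beats(beats))              # -> [120. 120. 120. 120.]
```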
44 | -------------------------------------------------------------------------------- /dataset/analyzer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import numpy as np 4 | import multiprocessing as mp 5 | 6 | import miditoolkit 7 | from miditoolkit.midi import parser as mid_parser 8 | from miditoolkit.pianoroll import parser as pr_parser 9 | from miditoolkit.midi.containers import Marker, Instrument, TempoChange 10 | 11 | from chorder import Dechorder 12 | 13 | 14 | num2pitch = { 15 | 0: 'C', 16 | 1: 'C#', 17 | 2: 'D', 18 | 3: 'D#', 19 | 4: 'E', 20 | 5: 'F', 21 | 6: 'F#', 22 | 7: 'G', 23 | 8: 'G#', 24 | 9: 'A', 25 | 10: 'A#', 26 | 11: 'B', 27 | } 28 | 29 | 30 | def traverse_dir( 31 | root_dir, 32 | extension=('mid', 'MID', 'midi'), 33 | amount=None, 34 | str_=None, 35 | is_pure=False, 36 | verbose=False, 37 | is_sort=False, 38 | is_ext=True): 39 | if verbose: 40 | print('[*] Scanning...') 41 | file_list = [] 42 | cnt = 0 43 | for root, _, files in os.walk(root_dir): 44 | for file in files: 45 | if file.endswith(extension): 46 | if (amount is not None) and (cnt == amount): 47 | break 48 | if str_ is not None: 49 | if str_ not in file: 50 | continue 51 | mix_path = os.path.join(root, file) 52 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 53 | if not is_ext: 54 | ext = pure_path.split('.')[-1] 55 | pure_path = pure_path[:-(len(ext)+1)] 56 | if verbose: 57 | print(pure_path) 58 | file_list.append(pure_path) 59 | cnt += 1 60 | if verbose: 61 | print('Total: %d files' % len(file_list)) 62 | print('Done!!!') 63 | if is_sort: 64 | file_list.sort() 65 | return file_list 66 | 67 | 68 | def proc_one(path_infile, path_outfile): 69 | print('----') 70 | print(' >', path_infile) 71 | print(' >', path_outfile) 72 | 73 | # load 74 | midi_obj = miditoolkit.midi.parser.MidiFile(path_infile) 75 | midi_obj_out = copy.deepcopy(midi_obj) 76 | notes = midi_obj.instruments[0].notes 77 | notes = sorted(notes, key=lambda x: (x.start, x.pitch)) 78 | 79 | # --- chord --- # 80 | # extract chords 81 | chords = Dechorder.dechord(midi_obj) 82 | markers = [] 83 | for cidx, chord in enumerate(chords): 84 | if chord.is_complete(): 85 | chord_text = num2pitch[chord.root_pc] + '_' + chord.quality + '_' + num2pitch[chord.bass_pc] 86 | else: 87 | chord_text = 'N_N_N' 88 | markers.append(Marker(time=int(cidx*480), text=chord_text)) 89 | 90 | # dedup 91 | prev_chord = None 92 | dedup_chords = [] 93 | for m in markers: 94 | if m.text != prev_chord: 95 | prev_chord = m.text 96 | dedup_chords.append(m) 97 | 98 | # --- global properties --- # 99 | # global tempo: median of the first 40 tempo events 100 | tempos = [b.tempo for b in midi_obj.tempo_changes][:40] 101 | tempo_median = np.median(tempos) 102 | global_bpm = int(tempo_median) 103 | print(' > [global] bpm:', global_bpm) 104 | 105 | # === save === # 106 | # mkdir 107 | fn = os.path.basename(path_outfile) 108 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True) 109 | 110 | # markers 111 | midi_obj_out.markers = dedup_chords 112 | midi_obj_out.markers.insert(0, Marker(text='global_bpm_'+str(int(global_bpm)), time=0)) 113 | 114 | # save 115 | midi_obj_out.instruments[0].name = 'piano' 116 | midi_obj_out.dump(path_outfile) 117 | 118 | 119 | if __name__ == '__main__': 120 | # paths 121 | path_indir = './midi_synchronized' 122 | path_outdir = './midi_analyzed' 123 | os.makedirs(path_outdir, exist_ok=True) 124 | 125 | # list files 126 | midifiles = traverse_dir( 127 | path_indir, 128 | is_pure=True, 129 | is_sort=True) 130 |
n_files = len(midifiles) 131 | print('num files:', n_files) 132 | 133 | # collect 134 | data = [] 135 | for fidx in range(n_files): 136 | path_midi = midifiles[fidx] 137 | print('{}/{}'.format(fidx, n_files)) 138 | 139 | # paths 140 | path_infile = os.path.join(path_indir, path_midi) 141 | path_outfile = os.path.join(path_outdir, path_midi) 142 | 143 | # append 144 | data.append([path_infile, path_outfile]) 145 | 146 | # run, multi-thread 147 | pool = mp.Pool() 148 | pool.starmap(proc_one, data) -------------------------------------------------------------------------------- /dataset/corpus/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/corpus/.DS_Store -------------------------------------------------------------------------------- /dataset/corpus/keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/corpus/keep -------------------------------------------------------------------------------- /dataset/midi2corpus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import pickle 4 | import numpy as np 5 | import miditoolkit 6 | import collections 7 | 8 | import pandas as pd 9 | import matplotlib.pyplot as plt 10 | import seaborn as sns 11 | 12 | # ================================================== # 13 | # Configuration # 14 | # ================================================== # 15 | BEAT_RESOL = 480 16 | BAR_RESOL = BEAT_RESOL * 4 17 | TICK_RESOL = BEAT_RESOL // 4 18 | INSTR_NAME_MAP = {'piano': 0} 19 | MIN_BPM = 40 20 | MIN_VELOCITY = 40 21 | NOTE_SORTING = 1 # 0: ascending / 1: descending 22 | 23 | DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 64+1, dtype=int) # plain int; the np.int alias is removed in recent NumPy 24 | DEFAULT_BPM_BINS = np.linspace(32, 224, 64+1, dtype=int) 25 | DEFAULT_SHIFT_BINS = np.linspace(-60, 60, 60+1, dtype=int) 26 | DEFAULT_DURATION_BINS = np.arange( 27 | BEAT_RESOL/8, BEAT_RESOL*8+1, BEAT_RESOL/8) 28 | 29 | # ================================================== # 30 | 31 | 32 | def traverse_dir( 33 | root_dir, 34 | extension=('mid', 'MID', 'midi'), 35 | amount=None, 36 | str_=None, 37 | is_pure=False, 38 | verbose=False, 39 | is_sort=False, 40 | is_ext=True): 41 | if verbose: 42 | print('[*] Scanning...') 43 | file_list = [] 44 | cnt = 0 45 | for root, _, files in os.walk(root_dir): 46 | for file in files: 47 | if file.endswith(extension): 48 | if (amount is not None) and (cnt == amount): 49 | break 50 | if str_ is not None: 51 | if str_ not in file: 52 | continue 53 | mix_path = os.path.join(root, file) 54 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 55 | if not is_ext: 56 | ext = pure_path.split('.')[-1] 57 | pure_path = pure_path[:-(len(ext)+1)] 58 | if verbose: 59 | print(pure_path) 60 | file_list.append(pure_path) 61 | cnt += 1 62 | if verbose: 63 | print('Total: %d files' % len(file_list)) 64 | print('Done!!!') 65 | if is_sort: 66 | file_list.sort() 67 | return file_list 68 | 69 | 70 | def proc_one(path_midi, path_outfile): 71 | # --- load --- # 72 | midi_obj = miditoolkit.midi.parser.MidiFile(path_midi) 73 | 74 | # load notes 75 | instr_notes = collections.defaultdict(list) 76 | for instr in midi_obj.instruments: 77 | # skip 78 | if instr.name not in INSTR_NAME_MAP.keys(): 79 | continue 80 | 81 | # 
process 82 | instr_idx = INSTR_NAME_MAP[instr.name] 83 | for note in instr.notes: 84 | note.instr_idx=instr_idx 85 | instr_notes[instr_idx].append(note) 86 | if NOTE_SORTING == 0: 87 | instr_notes[instr_idx].sort( 88 | key=lambda x: (x.start, x.pitch)) 89 | elif NOTE_SORTING == 1: 90 | instr_notes[instr_idx].sort( 91 | key=lambda x: (x.start, -x.pitch)) 92 | else: 93 | raise ValueError(' [x] Unknown type of sorting.') 94 | 95 | # load chords 96 | chords = [] 97 | for marker in midi_obj.markers: 98 | if marker.text.split('_')[0] != 'global' and \ 99 | 'Boundary' not in marker.text.split('_')[0]: 100 | chords.append(marker) 101 | chords.sort(key=lambda x: x.time) 102 | 103 | # load tempos 104 | tempos = midi_obj.tempo_changes 105 | tempos.sort(key=lambda x: x.time) 106 | 107 | # load labels 108 | labels = [] 109 | for marker in midi_obj.markers: 110 | if 'Boundary' in marker.text.split('_')[0]: 111 | labels.append(marker) 112 | labels.sort(key=lambda x: x.time) 113 | 114 | # load global bpm 115 | gobal_bpm = 120 116 | for marker in midi_obj.markers: 117 | if marker.text.split('_')[0] == 'global' and \ 118 | marker.text.split('_')[1] == 'bpm': 119 | gobal_bpm = int(marker.text.split('_')[2]) 120 | 121 | # --- process items to grid --- # 122 | # compute empty bar offset at head 123 | first_note_time = min([instr_notes[k][0].start for k in instr_notes.keys()]) 124 | last_note_time = max([instr_notes[k][-1].start for k in instr_notes.keys()]) 125 | 126 | quant_time_first = int(np.round(first_note_time / TICK_RESOL) * TICK_RESOL) 127 | offset = quant_time_first // BAR_RESOL # empty bar 128 | last_bar = int(np.ceil(last_note_time / BAR_RESOL)) - offset 129 | print(' > offset:', offset) 130 | print(' > last_bar:', last_bar) 131 | 132 | # process notes 133 | intsr_gird = dict() 134 | for key in instr_notes.keys(): 135 | notes = instr_notes[key] 136 | note_grid = collections.defaultdict(list) 137 | for note in notes: 138 | note.start = note.start - offset * BAR_RESOL 139 | note.end = note.end - offset * BAR_RESOL 140 | 141 | # quantize start 142 | quant_time = int(np.round(note.start / TICK_RESOL) * TICK_RESOL) 143 | 144 | # velocity 145 | note.velocity = DEFAULT_VELOCITY_BINS[ 146 | np.argmin(abs(DEFAULT_VELOCITY_BINS-note.velocity))] 147 | note.velocity = max(MIN_VELOCITY, note.velocity) 148 | 149 | # shift of start 150 | note.shift = note.start - quant_time 151 | note.shift = DEFAULT_SHIFT_BINS[np.argmin(abs(DEFAULT_SHIFT_BINS-note.shift))] 152 | 153 | # duration 154 | note_duration = note.end - note.start 155 | if note_duration > BAR_RESOL: 156 | note_duration = BAR_RESOL 157 | ntick_duration = int(np.round(note_duration / TICK_RESOL) * TICK_RESOL) 158 | note.duration = ntick_duration 159 | 160 | # append 161 | note_grid[quant_time].append(note) 162 | 163 | # set to track 164 | intsr_gird[key] = note_grid.copy() 165 | 166 | # process chords 167 | chord_grid = collections.defaultdict(list) 168 | for chord in chords: 169 | # quantize 170 | chord.time = chord.time - offset * BAR_RESOL 171 | chord.time = 0 if chord.time < 0 else chord.time 172 | quant_time = int(np.round(chord.time / TICK_RESOL) * TICK_RESOL) 173 | 174 | # append 175 | chord_grid[quant_time].append(chord) 176 | 177 | # process tempo 178 | tempo_grid = collections.defaultdict(list) 179 | for tempo in tempos: 180 | # quantize 181 | tempo.time = tempo.time - offset * BAR_RESOL 182 | tempo.time = 0 if tempo.time < 0 else tempo.time 183 | quant_time = int(np.round(tempo.time / TICK_RESOL) * TICK_RESOL) 184 | tempo.tempo = 
DEFAULT_BPM_BINS[np.argmin(abs(DEFAULT_BPM_BINS-tempo.tempo))] 185 | 186 | # append 187 | tempo_grid[quant_time].append(tempo) 188 | 189 | # process boundary 190 | label_grid = collections.defaultdict(list) 191 | for label in labels: 192 | # quantize 193 | label.time = label.time - offset * BAR_RESOL 194 | label.time = 0 if label.time < 0 else label.time 195 | quant_time = int(np.round(label.time / TICK_RESOL) * TICK_RESOL) 196 | 197 | # append 198 | label_grid[quant_time] = [label] 199 | 200 | # process global bpm 201 | gobal_bpm = DEFAULT_BPM_BINS[np.argmin(abs(DEFAULT_BPM_BINS-gobal_bpm))] 202 | 203 | # collect 204 | song_data = { 205 | 'notes': intsr_gird, 206 | 'chords': chord_grid, 207 | 'tempos': tempo_grid, 208 | 'labels': label_grid, 209 | 'metadata': { 210 | 'global_bpm': gobal_bpm, 211 | 'last_bar': last_bar, 212 | } 213 | } 214 | 215 | # save 216 | fn = os.path.basename(path_outfile) 217 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True) 218 | pickle.dump(song_data, open(path_outfile, 'wb')) 219 | 220 | 221 | if __name__ == '__main__': 222 | # paths 223 | path_indir = './midi_analyzed' 224 | path_outdir = './corpus' 225 | os.makedirs(path_outdir, exist_ok=True) 226 | 227 | # list files 228 | midifiles = traverse_dir( 229 | path_indir, 230 | is_pure=True, 231 | is_sort=True) 232 | n_files = len(midifiles) 233 | print('num fiels:', n_files) 234 | 235 | # run all 236 | for fidx in range(n_files): 237 | path_midi = midifiles[fidx] 238 | print('{}/{}'.format(fidx, n_files)) 239 | 240 | # paths 241 | path_infile = os.path.join(path_indir, path_midi) 242 | path_outfile = os.path.join(path_outdir, path_midi+'.pkl') 243 | 244 | # proc 245 | proc_one(path_infile, path_outfile) -------------------------------------------------------------------------------- /dataset/midi_analyzed/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_analyzed/.DS_Store -------------------------------------------------------------------------------- /dataset/midi_analyzed/keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_analyzed/keep -------------------------------------------------------------------------------- /dataset/midi_synchronized/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_synchronized/.DS_Store -------------------------------------------------------------------------------- /dataset/midi_synchronized/keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_synchronized/keep -------------------------------------------------------------------------------- /dataset/midi_transcribed/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_transcribed/.DS_Store -------------------------------------------------------------------------------- /dataset/midi_transcribed/keep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_transcribed/keep -------------------------------------------------------------------------------- /dataset/representations/cond-ls2midi/keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/representations/cond-ls2midi/keep -------------------------------------------------------------------------------- /dataset/representations/uncond/cp/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/representations/uncond/cp/.DS_Store -------------------------------------------------------------------------------- /dataset/representations/uncond/cp/compile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import numpy as np 5 | 6 | 7 | TEST_AMOUNT = 50 8 | WINDOW_SIZE = 512 9 | GROUP_SIZE = 7 10 | MAX_LEN = WINDOW_SIZE * GROUP_SIZE 11 | COMPILE_TARGET = 'linear' # 'linear', 'XL' 12 | print('[config] MAX_LEN:', MAX_LEN) 13 | 14 | 15 | def traverse_dir( 16 | root_dir, 17 | extension=('mid', 'MID'), 18 | amount=None, 19 | str_=None, 20 | is_pure=False, 21 | verbose=False, 22 | is_sort=False, 23 | is_ext=True): 24 | if verbose: 25 | print('[*] Scanning...') 26 | file_list = [] 27 | cnt = 0 28 | for root, _, files in os.walk(root_dir): 29 | for file in files: 30 | if file.endswith(extension): 31 | if (amount is not None) and (cnt == amount): 32 | break 33 | if str_ is not None: 34 | if str_ not in file: 35 | continue 36 | mix_path = os.path.join(root, file) 37 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 38 | if not is_ext: 39 | ext = pure_path.split('.')[-1] 40 | pure_path = pure_path[:-(len(ext)+1)] 41 | if verbose: 42 | print(pure_path) 43 | file_list.append(pure_path) 44 | cnt += 1 45 | if verbose: 46 | print('Total: %d files' % len(file_list)) 47 | print('Done!!!') 48 | if is_sort: 49 | file_list.sort() 50 | return file_list 51 | 52 | 53 | if __name__ == '__main__': 54 | # paths 55 | path_root = 'ailab17k_from-scratch_cp' 56 | path_indir = os.path.join( path_root, 'words') 57 | 58 | # load dictionary 59 | path_dictionary = os.path.join(path_root, 'dictionary.pkl') 60 | event2word, word2event = pickle.load(open(path_dictionary, 'rb')) 61 | 62 | # load all words 63 | wordfiles = traverse_dir( 64 | path_indir, 65 | extension=('npy')) 66 | n_files = len(wordfiles) 67 | 68 | # init 69 | x_list = [] 70 | y_list = [] 71 | mask_list = [] 72 | seq_len_list = [] 73 | num_groups_list = [] 74 | name_list = [] 75 | 76 | # process 77 | for fidx in range(n_files): 78 | print('--[{}/{}]-----'.format(fidx, n_files)) 79 | file = wordfiles[fidx] 80 | words = np.load(file) 81 | num_words = len(words) 82 | eos_arr = words[-1][None, ...] 83 | 84 | if num_words >= MAX_LEN - 2: # 2 for room 85 | print(' [!] 
too long:', num_words) 86 | continue 87 | 88 | # arrange IO 89 | x = words[:-1].copy() 90 | y = words[1:].copy() 91 | seq_len = len(x) 92 | print(' > seq_len:', seq_len) 93 | 94 | # pad with eos 95 | pad = np.tile( 96 | eos_arr, 97 | (MAX_LEN-seq_len, 1)) 98 | 99 | x = np.concatenate([x, pad], axis=0) 100 | y = np.concatenate([y, pad], axis=0) 101 | mask = np.concatenate( 102 | [np.ones(seq_len), np.zeros(MAX_LEN-seq_len)]) 103 | 104 | # collect 105 | x_list.append(x) 106 | y_list.append(y) 107 | mask_list.append(mask) 108 | seq_len_list.append(seq_len) 109 | num_groups_list.append(int(np.ceil(seq_len/WINDOW_SIZE))) 110 | name_list.append(file) 111 | 112 | # sort by length (descending) 113 | zipped = zip(seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list) 114 | seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list = zip( 115 | *sorted(zipped, key=lambda x: -x[0])) 116 | 117 | print('\n\n[Finished]') 118 | print(' compile target:', COMPILE_TARGET) 119 | if COMPILE_TARGET == 'XL': 120 | # reshape 121 | x_final = np.array(x_list).reshape(len(x_list), GROUP_SIZE, WINDOW_SIZE, -1) 122 | y_final = np.array(y_list).reshape(len(x_list), GROUP_SIZE, WINDOW_SIZE, -1) 123 | mask_final = np.array(mask_list).reshape(-1, GROUP_SIZE, WINDOW_SIZE) 124 | elif COMPILE_TARGET == 'linear': 125 | x_final = np.array(x_list) 126 | y_final = np.array(y_list) 127 | mask_final = np.array(mask_list) 128 | else: 129 | raise ValueError('Unknown target:', COMPILE_TARGET) 130 | 131 | # check 132 | num_samples = len(seq_len_list) 133 | print(' > count:', num_samples) 134 | print(' > x_final:', x_final.shape) 135 | print(' > y_final:', y_final.shape) 136 | print(' > mask_final:', mask_final.shape) 137 | 138 | # split train/test 139 | validation_songs = json.load(open('../validation_songs.json', 'r')) 140 | train_idx = [] 141 | test_idx = [] 142 | 143 | # validation filename map 144 | fn2idx_map = { 145 | 'fn2idx': dict(), 146 | 'idx2fn': dict(), 147 | } 148 | 149 | # run split 150 | valid_cnt = 0 151 | for nidx, n in enumerate(name_list): 152 | flag = True 153 | for fn in validation_songs: 154 | if fn in n: 155 | test_idx.append(nidx) 156 | flag = False 157 | fn2idx_map['fn2idx'][fn] = valid_cnt 158 | fn2idx_map['idx2fn'][valid_cnt] = fn 159 | valid_cnt += 1 160 | break 161 | 162 | if flag: 163 | train_idx.append(nidx) 164 | test_idx = np.array(test_idx) 165 | train_idx = np.array(train_idx) 166 | 167 | # save validation map 168 | path_fn2idx_map = os.path.join(path_root, 'valid_fn2idx_map.json') 169 | with open(path_fn2idx_map, 'w') as f: 170 | json.dump(fn2idx_map, f) 171 | 172 | # save train 173 | path_train = os.path.join(path_root, 'train_data_{}'.format(COMPILE_TARGET)) 174 | path_train += '.npz' 175 | np.savez( 176 | path_train, 177 | x=x_final[train_idx], 178 | y=y_final[train_idx], 179 | mask=mask_final[train_idx], 180 | seq_len=np.array(seq_len_list)[train_idx], 181 | num_groups=np.array(num_groups_list)[train_idx] 182 | ) 183 | 184 | # save test 185 | path_test = os.path.join(path_root, 'test_data_{}'.format(COMPILE_TARGET)) 186 | path_test += '.npz' 187 | np.savez( 188 | path_test, 189 | x=x_final[test_idx], 190 | y=y_final[test_idx], 191 | mask=mask_final[test_idx], 192 | seq_len=np.array(seq_len_list)[test_idx], 193 | num_groups=np.array(num_groups_list)[test_idx] 194 | ) 195 | 196 | print('---') 197 | print(' > train x:', x_final[train_idx].shape) 198 | print(' > test x:', x_final[test_idx].shape) 199 | 200 |
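For reference, a minimal sketch of loading the arrays that `compile.py` just saved (the path assumes the CP `linear` target compiled under `ailab17k_from-scratch_cp`; the key names match the `np.savez` calls above):

```python
import numpy as np

# Load the compiled training set written by compile.py via np.savez.
data = np.load('ailab17k_from-scratch_cp/train_data_linear.npz')
x, y = data['x'], data['y']                # inputs and one-step-shifted targets
mask = data['mask']                        # 1 = real token, 0 = EOS padding
seq_len, num_groups = data['seq_len'], data['num_groups']
print(x.shape, y.shape, mask.shape)        # e.g. (n_songs, MAX_LEN, n_token_classes)
```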
-------------------------------------------------------------------------------- /dataset/representations/uncond/cp/corpus2events.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | # config 8 | BEAT_RESOL = 480 9 | BAR_RESOL = BEAT_RESOL * 4 10 | TICK_RESOL = BEAT_RESOL // 4 11 | 12 | 13 | # utilities 14 | def plot_hist(data, path_outfile): 15 | print('[Fig] >> {}'.format(path_outfile)) 16 | data_mean = np.mean(data) 17 | data_std = np.std(data) 18 | 19 | print('mean:', data_mean) 20 | print(' std:', data_std) 21 | 22 | plt.figure(dpi=100) 23 | plt.hist(data, bins=50) 24 | plt.title('mean: {:.3f}_std: {:.3f}'.format(data_mean, data_std)) 25 | plt.savefig(path_outfile) 26 | plt.close() 27 | 28 | def traverse_dir( 29 | root_dir, 30 | extension=('mid', 'MID'), 31 | amount=None, 32 | str_=None, 33 | is_pure=False, 34 | verbose=False, 35 | is_sort=False, 36 | is_ext=True): 37 | if verbose: 38 | print('[*] Scanning...') 39 | file_list = [] 40 | cnt = 0 41 | for root, _, files in os.walk(root_dir): 42 | for file in files: 43 | if file.endswith(extension): 44 | if (amount is not None) and (cnt == amount): 45 | break 46 | if str_ is not None: 47 | if str_ not in file: 48 | continue 49 | mix_path = os.path.join(root, file) 50 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 51 | if not is_ext: 52 | ext = pure_path.split('.')[-1] 53 | pure_path = pure_path[:-(len(ext)+1)] 54 | if verbose: 55 | print(pure_path) 56 | file_list.append(pure_path) 57 | cnt += 1 58 | if verbose: 59 | print('Total: %d files' % len(file_list)) 60 | print('Done!!!') 61 | if is_sort: 62 | file_list.sort() 63 | return file_list 64 | 65 | # ---- define event ---- # 66 | ''' 8 kinds: 67 | tempo: 0: IGN 68 | 1: no change 69 | int: tempo 70 | chord: 0: IGN 71 | 1: no change 72 | str: chord types 73 | bar-beat: 0: IGN 74 | int: beat position (1...16) 75 | int: bar (bar) 76 | type: 0: eos 77 | 1: metrical 78 | 2: note 79 | duration: 0: IGN 80 | int: length 81 | pitch: 0: IGN 82 | int: pitch 83 | velocity: 0: IGN 84 | int: velocity 85 | ''' 86 | 87 | # event template 88 | compound_event = { 89 | 'tempo': 0, 90 | 'chord': 0, 91 | 'bar-beat': 0, 92 | 'type': 0, 93 | 'pitch': 0, 94 | 'duration': 0, 95 | 'velocity': 0, 96 | } 97 | 98 | 99 | def create_bar_event(): 100 | meter_event = compound_event.copy() 101 | meter_event['bar-beat'] = 'Bar' 102 | meter_event['type'] = 'Metrical' 103 | return meter_event 104 | 105 | 106 | def create_piano_metrical_event(tempo, chord, pos): 107 | meter_event = compound_event.copy() 108 | meter_event['tempo'] = tempo 109 | meter_event['chord'] = chord 110 | meter_event['bar-beat'] = pos 111 | meter_event['type'] = 'Metrical' 112 | return meter_event 113 | 114 | 115 | def create_piano_note_event(pitch, duration, velocity): 116 | note_event = compound_event.copy() 117 | note_event['pitch'] = pitch 118 | note_event['duration'] = duration 119 | note_event['velocity'] = velocity 120 | note_event['type'] = 'Note' 121 | return note_event 122 | 123 | 124 | def create_eos_event(): 125 | eos_event = compound_event.copy() 126 | eos_event['type'] = 'EOS' 127 | return eos_event 128 | 129 | 130 | # ----------------------------------------------- # 131 | # core functions 132 | def corpus2event_cp(path_infile, path_outfile): 133 | ''' 134 | task: 2 track 135 | 1: piano (note + tempo) 136 | --- 137 | remove duplicate position tokens 138 | ''' 139 | data = 
pickle.load(open(path_infile, 'rb')) 140 | 141 | # global tag 142 | global_end = data['metadata']['last_bar'] * BAR_RESOL 143 | 144 | # process 145 | final_sequence = [] 146 | for bar_step in range(0, global_end, BAR_RESOL): 147 | final_sequence.append(create_bar_event()) 148 | 149 | # --- piano track --- # 150 | for timing in range(bar_step, bar_step + BAR_RESOL, TICK_RESOL): 151 | pos_on = False 152 | pos_events = [] 153 | pos_text = 'Beat_' + str((timing-bar_step)//TICK_RESOL) 154 | 155 | # unpack 156 | t_chords = data['chords'][timing] 157 | t_tempos = data['tempos'][timing] 158 | t_notes = data['notes'][0][timing] # piano track 159 | 160 | # metrical 161 | if len(t_tempos) or len(t_chords): 162 | # chord 163 | if len(t_chords): 164 | 165 | root, quality, bass = t_chords[-1].text.split('_') 166 | chord_text = root+'_'+quality 167 | else: 168 | chord_text = 'CONTI' 169 | 170 | # tempo 171 | if len(t_tempos): 172 | tempo_text = 'Tempo_' + str(t_tempos[-1].tempo) 173 | else: 174 | tempo_text = 'CONTI' 175 | 176 | # create 177 | pos_events.append( 178 | create_piano_metrical_event( 179 | tempo_text, chord_text, pos_text)) 180 | pos_on = True 181 | 182 | # note 183 | if len(t_notes): 184 | if not pos_on: 185 | pos_events.append( 186 | create_piano_metrical_event( 187 | 'CONTI', 'CONTI', pos_text)) 188 | 189 | for note in t_notes: 190 | note_pitch_text = 'Note_Pitch_' + str(note.pitch) 191 | note_duration_text = 'Note_Duration_' + str(note.duration) 192 | note_velocity_text = 'Note_Velocity_' + str(note.velocity) 193 | 194 | pos_events.append( 195 | create_piano_note_event( 196 | note_pitch_text, 197 | note_duration_text, 198 | note_velocity_text)) 199 | 200 | # collect & beat 201 | if len(pos_events): 202 | final_sequence.extend(pos_events) 203 | 204 | # BAR ending 205 | final_sequence.append(create_bar_event()) 206 | 207 | # EOS 208 | final_sequence.append(create_eos_event()) 209 | 210 | # save 211 | fn = os.path.basename(path_outfile) 212 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True) 213 | pickle.dump(final_sequence, open(path_outfile, 'wb')) 214 | 215 | return len(final_sequence) 216 | 217 | 218 | if __name__ == '__main__': 219 | # paths 220 | path_root = './ailab17k_from-scratch_cp' 221 | path_indir = '../../../corpus' 222 | path_outdir = os.path.join(path_root, 'events') 223 | os.makedirs(path_outdir, exist_ok=True) 224 | 225 | # list files 226 | midifiles = traverse_dir( 227 | path_indir, 228 | extension=('pkl'), 229 | is_pure=True, 230 | is_sort=True) 231 | n_files = len(midifiles) 232 | print('num fiels:', n_files) 233 | 234 | # run all 235 | len_list = [] 236 | for fidx in range(n_files): 237 | path_midi = midifiles[fidx] 238 | print('{}/{}'.format(fidx, n_files)) 239 | 240 | # paths 241 | path_infile = os.path.join(path_indir, path_midi) 242 | path_outfile = os.path.join(path_outdir, path_midi) 243 | 244 | # proc 245 | num_tokens = corpus2event_cp(path_infile, path_outfile) 246 | print(' > num_token:', num_tokens) 247 | len_list.append(num_tokens) 248 | 249 | # plot 250 | plot_hist( 251 | len_list, 252 | os.path.join(path_root, 'num_tokens.png') 253 | ) -------------------------------------------------------------------------------- /dataset/representations/uncond/cp/events2words.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import collections 5 | 6 | 7 | def traverse_dir( 8 | root_dir, 9 | extension=('mid', 'MID'), 10 | amount=None, 11 | str_=None, 12 | is_pure=False, 13 | 
verbose=False, 14 | is_sort=False, 15 | is_ext=True): 16 | if verbose: 17 | print('[*] Scanning...') 18 | file_list = [] 19 | cnt = 0 20 | for root, _, files in os.walk(root_dir): 21 | for file in files: 22 | if file.endswith(extension): 23 | if (amount is not None) and (cnt == amount): 24 | break 25 | if str_ is not None: 26 | if str_ not in file: 27 | continue 28 | mix_path = os.path.join(root, file) 29 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 30 | if not is_ext: 31 | ext = pure_path.split('.')[-1] 32 | pure_path = pure_path[:-(len(ext)+1)] 33 | if verbose: 34 | print(pure_path) 35 | file_list.append(pure_path) 36 | cnt += 1 37 | if verbose: 38 | print('Total: %d files' % len(file_list)) 39 | print('Done!!!') 40 | if is_sort: 41 | file_list.sort() 42 | return file_list 43 | 44 | 45 | if __name__ == '__main__': 46 | # paths 47 | path_root = 'ailab17k_from-scratch_cp' 48 | path_indir = os.path.join(path_root, 'events') 49 | path_outdir = os.path.join(path_root, 'words') 50 | path_dictionary = os.path.join(path_root, 'dictionary.pkl') 51 | os.makedirs(path_outdir, exist_ok=True) 52 | 53 | # list files 54 | eventfiles = traverse_dir( 55 | path_indir, 56 | is_pure=True, 57 | is_sort=True, 58 | extension=('pkl')) 59 | n_files = len(eventfiles) 60 | print('num fiels:', n_files) 61 | 62 | # --- build dictionary --- # 63 | # all files 64 | class_keys = pickle.load( 65 | open(os.path.join(path_indir, eventfiles[0]), 'rb'))[0].keys() 66 | print('class keys:', class_keys) 67 | 68 | # define dictionary 69 | event2word = {} 70 | word2event = {} 71 | 72 | corpus_kv = collections.defaultdict(list) 73 | for file in eventfiles: 74 | for event in pickle.load(open( 75 | os.path.join(path_indir, file), 'rb')): 76 | for key in class_keys: 77 | corpus_kv[key].append(event[key]) 78 | 79 | for ckey in class_keys: 80 | class_unique_vals = sorted( 81 | set(corpus_kv[ckey]), key=lambda x: (not isinstance(x, int), x)) 82 | event2word[ckey] = {key: i for i, key in enumerate(class_unique_vals)} 83 | word2event[ckey] = {i: key for i, key in enumerate(class_unique_vals)} 84 | 85 | # print 86 | print('[class size]') 87 | for key in class_keys: 88 | print(' > {:10s}: {}'.format(key, len(event2word[key]))) 89 | 90 | # save 91 | path_dict = os.path.join(path_root, 'dictionary.pkl') 92 | pickle.dump((event2word, word2event), open(path_dict, 'wb')) 93 | 94 | # --- compile each --- # 95 | # reload 96 | event2word, word2event = pickle.load(open(path_dict, 'rb')) 97 | for fidx in range(len(eventfiles)): 98 | file = eventfiles[fidx] 99 | events_list = pickle.load(open( 100 | os.path.join(path_indir, file), 'rb')) 101 | fn = os.path.basename(file) 102 | path_outfile = os.path.join(path_outdir, fn) 103 | 104 | print('({}/{})'.format(fidx, len(eventfiles))) 105 | print(' > from:', file) 106 | print(' > to:', path_outfile) 107 | 108 | words = [] 109 | for eidx, e in enumerate(events_list): 110 | words_tmp = [ 111 | event2word[k][e[k]] for k in class_keys 112 | ] 113 | words.append(words_tmp) 114 | 115 | # save 116 | path_outfile = os.path.join(path_outdir, file + '.npy') 117 | fn = os.path.basename(path_outfile) 118 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True) 119 | np.save(path_outfile, words) 120 | -------------------------------------------------------------------------------- /dataset/representations/uncond/remi/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/representations/uncond/remi/.DS_Store -------------------------------------------------------------------------------- /dataset/representations/uncond/remi/compile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import numpy as np 5 | 6 | 7 | TEST_AMOUNT = 50 8 | WINDOW_SIZE = 512 9 | GROUP_SIZE = 15 10 | MAX_LEN = WINDOW_SIZE * GROUP_SIZE 11 | COMPILE_TARGET = 'XL' # 'linear', 'XL' 12 | print('[config] MAX_LEN:', MAX_LEN) 13 | 14 | 15 | def traverse_dir( 16 | root_dir, 17 | extension=('mid', 'MID'), 18 | amount=None, 19 | str_=None, 20 | is_pure=False, 21 | verbose=False, 22 | is_sort=False, 23 | is_ext=True): 24 | if verbose: 25 | print('[*] Scanning...') 26 | file_list = [] 27 | cnt = 0 28 | for root, _, files in os.walk(root_dir): 29 | for file in files: 30 | if file.endswith(extension): 31 | if (amount is not None) and (cnt == amount): 32 | break 33 | if str_ is not None: 34 | if str_ not in file: 35 | continue 36 | mix_path = os.path.join(root, file) 37 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 38 | if not is_ext: 39 | ext = pure_path.split('.')[-1] 40 | pure_path = pure_path[:-(len(ext)+1)] 41 | if verbose: 42 | print(pure_path) 43 | file_list.append(pure_path) 44 | cnt += 1 45 | if verbose: 46 | print('Total: %d files' % len(file_list)) 47 | print('Done!!!') 48 | if is_sort: 49 | file_list.sort() 50 | return file_list 51 | 52 | 53 | if __name__ == '__main__': 54 | # paths 55 | path_root = 'ailab17k_from-scratch_remi' 56 | path_indir = os.path.join( path_root, 'words') 57 | 58 | # load dictionary 59 | path_dictionary = os.path.join(path_root, 'dictionary.pkl') 60 | event2word, word2event = pickle.load(open(path_dictionary, 'rb')) 61 | eos_id = event2word['EOS_None'] 62 | print(' > eos_id:', eos_id) 63 | 64 | # load all words 65 | wordfiles = traverse_dir( 66 | path_indir, 67 | extension=('npy')) 68 | 69 | # load dictionary 70 | path_dictionary = os.path.join(path_root, 'dictionary.pkl') 71 | event2word, word2event = pickle.load(open(path_dictionary, 'rb')) 72 | eos_id = event2word['EOS_None'] 73 | print(' > eos_id:', eos_id) 74 | 75 | # load all words 76 | wordfiles = traverse_dir( 77 | path_indir, 78 | extension=('npy')) 79 | n_files = len(wordfiles) 80 | 81 | # init 82 | x_list = [] 83 | y_list = [] 84 | mask_list = [] 85 | seq_len_list = [] 86 | num_groups_list = [] 87 | name_list = [] 88 | 89 | # process 90 | for fidx in range(n_files): 91 | print('--[{}/{}]-----'.format(fidx, n_files)) 92 | file = wordfiles[fidx] 93 | words = np.load(file) 94 | num_words = len(words) 95 | 96 | if num_words >= MAX_LEN - 2: # 2 for room 97 | print(' [!] 
too long:', num_words) 98 | continue 99 | 100 | # arrange IO 101 | x = words[:-1] 102 | y = words[1:] 103 | seq_len = len(x) 104 | print(' > seq_len:', seq_len) 105 | 106 | # pad with eos 107 | x = np.concatenate([x, np.ones(MAX_LEN-seq_len) * eos_id]) 108 | y = np.concatenate([y, np.ones(MAX_LEN-seq_len) * eos_id]) 109 | mask = np.concatenate( 110 | [np.ones(seq_len), np.zeros(MAX_LEN-seq_len)]) 111 | 112 | # collect 113 | x_list.append(x) 114 | y_list.append(y) 115 | mask_list.append(mask) 116 | seq_len_list.append(seq_len) 117 | num_groups_list.append(int(np.ceil(seq_len/WINDOW_SIZE))) 118 | name_list.append(file) 119 | 120 | # sort by length (descending) 121 | zipped = zip(seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list) 122 | seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list = zip( 123 | *sorted(zipped, key=lambda x: -x[0])) 124 | 125 | print('\n\n[Finished]') 126 | print(' compile target:', COMPILE_TARGET) 127 | if COMPILE_TARGET == 'XL': 128 | x_final = np.array(x_list).reshape(-1, GROUP_SIZE, WINDOW_SIZE) 129 | y_final = np.array(y_list).reshape(-1, GROUP_SIZE, WINDOW_SIZE) 130 | mask_final = np.array(mask_list).reshape(-1, GROUP_SIZE, WINDOW_SIZE) 131 | elif COMPILE_TARGET == 'linear': 132 | x_final = np.array(x_list) 133 | y_final = np.array(y_list) 134 | mask_final = np.array(mask_list) 135 | else: 136 | raise ValueError('Unknown target:', COMPILE_TARGET) 137 | num_samples = len(seq_len_list) 138 | print(' > count:', num_samples) 139 | print(' > x_final:', x_final.shape) 140 | print(' > y_final:', y_final.shape) 141 | print(' > mask_final:', mask_final.shape) 142 | 143 | # split train/test 144 | validation_songs = json.load(open('../validation_songs.json', 'r')) 145 | train_idx = [] 146 | test_idx = [] 147 | 148 | # validation filename map 149 | fn_idx_map = { 150 | 'fn2idx': dict(), 151 | 'idx2fn': dict(), 152 | } 153 | 154 | # run split 155 | valid_cnt = 0 156 | for nidx, n in enumerate(name_list): 157 | flag = True 158 | for fn in validation_songs: 159 | if fn in n: 160 | test_idx.append(nidx) 161 | flag = False 162 | fn_idx_map['fn2idx'][fn] = valid_cnt 163 | fn_idx_map['idx2fn'][valid_cnt] = fn 164 | valid_cnt += 1 165 | break 166 | if flag: 167 | train_idx.append(nidx) 168 | test_idx = np.array(test_idx) 169 | train_idx = np.array(train_idx) 170 | 171 | # save validation map 172 | with open('valid_fn_idx_map.json', 'w') as f: 173 | json.dump(fn_idx_map, f) 174 | 175 | # save train 176 | path_train = os.path.join(path_root, 'train_data_{}.npz'.format(COMPILE_TARGET)) 177 | np.savez( 178 | path_train, 179 | x=x_final[train_idx], 180 | y=y_final[train_idx], 181 | mask=mask_final[train_idx], 182 | seq_len=np.array(seq_len_list)[train_idx], 183 | num_groups=np.array(num_groups_list)[train_idx] 184 | ) 185 | 186 | # save test 187 | path_test = os.path.join(path_root, 'test_data_{}.npz'.format(COMPILE_TARGET)) 188 | np.savez( 189 | path_test, 190 | x=x_final[test_idx], 191 | y=y_final[test_idx], 192 | mask=mask_final[test_idx], 193 | seq_len=np.array(seq_len_list)[test_idx], 194 | num_groups=np.array(num_groups_list)[test_idx] 195 | ) 196 | 197 | print('---') 198 | print(' > train x:', x_final[train_idx].shape) 199 | print(' > test x:', x_final[test_idx].shape) -------------------------------------------------------------------------------- /dataset/representations/uncond/remi/corpus2events.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import 
matplotlib.pyplot as plt 5 | 6 | 7 | # config 8 | BEAT_RESOL = 480 9 | BAR_RESOL = BEAT_RESOL * 4 10 | TICK_RESOL = BEAT_RESOL // 4 11 | 12 | 13 | # utilities 14 | def plot_hist(data, path_outfile): 15 | print('[Fig] >> {}'.format(path_outfile)) 16 | data_mean = np.mean(data) 17 | data_std = np.std(data) 18 | 19 | print('mean:', data_mean) 20 | print(' std:', data_std) 21 | 22 | plt.figure(dpi=100) 23 | plt.hist(data, bins=50) 24 | plt.title('mean: {:.3f}_std: {:.3f}'.format(data_mean, data_std)) 25 | plt.savefig(path_outfile) 26 | plt.close() 27 | 28 | 29 | def traverse_dir( 30 | root_dir, 31 | extension=('mid', 'MID'), 32 | amount=None, 33 | str_=None, 34 | is_pure=False, 35 | verbose=False, 36 | is_sort=False, 37 | is_ext=True): 38 | 39 | if verbose: 40 | print('[*] Scanning...') 41 | file_list = [] 42 | cnt = 0 43 | for root, _, files in os.walk(root_dir): 44 | for file in files: 45 | if file.endswith(extension): 46 | if (amount is not None) and (cnt == amount): 47 | break 48 | if str_ is not None: 49 | if str_ not in file: 50 | continue 51 | mix_path = os.path.join(root, file) 52 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 53 | if not is_ext: 54 | ext = pure_path.split('.')[-1] 55 | pure_path = pure_path[:-(len(ext)+1)] 56 | if verbose: 57 | print(pure_path) 58 | file_list.append(pure_path) 59 | cnt += 1 60 | if verbose: 61 | print('Total: %d files' % len(file_list)) 62 | print('Done!!!') 63 | if is_sort: 64 | file_list.sort() 65 | return file_list 66 | 67 | 68 | # define event 69 | def create_event(name, value): 70 | event = dict() 71 | event['name'] = name 72 | event['value'] = value 73 | return event 74 | 75 | 76 | # core functions 77 | def corpus2event_remi_v2(path_infile, path_outfile): 78 | ''' 79 | <<< REMI v2 >>> 80 | task: 2 track 81 | 1: piano (note + tempo + chord) 82 | --- 83 | remove duplicate position tokens 84 | ''' 85 | data = pickle.load(open(path_infile, 'rb')) 86 | 87 | # global tag 88 | global_end = data['metadata']['last_bar'] * BAR_RESOL 89 | 90 | # process 91 | final_sequence = [] 92 | for bar_step in range(0, global_end, BAR_RESOL): 93 | final_sequence.append(create_event('Bar', None)) 94 | 95 | # --- piano track --- # 96 | for timing in range(bar_step, bar_step + BAR_RESOL, TICK_RESOL): 97 | pos_events = [] 98 | 99 | # unpack 100 | t_chords = data['chords'][timing] 101 | t_tempos = data['tempos'][timing] 102 | t_notes = data['notes'][0][timing] # piano track 103 | 104 | # chord 105 | if len(t_chords): 106 | root, quality, bass = t_chords[0].text.split('_') 107 | pos_events.append(create_event('Chord', root+'_'+quality)) 108 | 109 | # tempo 110 | if len(t_tempos): 111 | pos_events.append(create_event('Tempo', t_tempos[0].tempo)) 112 | 113 | # note 114 | if len(t_notes): 115 | for note in t_notes: 116 | pos_events.extend([ 117 | create_event('Note_Pitch', note.pitch), 118 | create_event('Note_Velocity', note.velocity), 119 | create_event('Note_Duration', note.duration), 120 | ]) 121 | 122 | # collect & beat 123 | if len(pos_events): 124 | final_sequence.append( 125 | create_event('Beat', (timing-bar_step)//TICK_RESOL)) 126 | final_sequence.extend(pos_events) 127 | 128 | # BAR ending 129 | final_sequence.append(create_event('Bar', None)) 130 | 131 | # EOS 132 | final_sequence.append(create_event('EOS', None)) 133 | 134 | # save 135 | fn = os.path.basename(path_outfile) 136 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True) 137 | pickle.dump(final_sequence, open(path_outfile, 'wb')) 138 | return len(final_sequence) 139 | 140 | 141 | if 
__name__ == '__main__': 142 | # paths 143 | path_root = './ailab17k_from-scratch_remi' 144 | path_indir = '../../../corpus' 145 | path_outdir = os.path.join(path_root, 'events') 146 | os.makedirs(path_outdir, exist_ok=True) 147 | 148 | # list files 149 | midifiles = traverse_dir( 150 | path_indir, 151 | extension=('pkl'), 152 | is_pure=True, 153 | is_sort=True) 154 | n_files = len(midifiles) 155 | print('num fiels:', n_files) 156 | 157 | # run all 158 | len_list = [] 159 | for fidx in range(n_files): 160 | path_midi = midifiles[fidx] 161 | print('{}/{}'.format(fidx, n_files)) 162 | 163 | # paths 164 | path_infile = os.path.join(path_indir, path_midi) 165 | path_outfile = os.path.join(path_outdir, path_midi) 166 | 167 | # proc 168 | num_tokens = corpus2event_remi_v2(path_infile, path_outfile) 169 | print(' > num_token:', num_tokens) 170 | len_list.append(num_tokens) 171 | 172 | # plot 173 | plot_hist( 174 | len_list, 175 | os.path.join(path_root, 'num_tokens.png') 176 | ) -------------------------------------------------------------------------------- /dataset/representations/uncond/remi/events2words.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | 5 | 6 | def traverse_dir( 7 | root_dir, 8 | extension=('mid', 'MID'), 9 | amount=None, 10 | str_=None, 11 | is_pure=False, 12 | verbose=False, 13 | is_sort=False, 14 | is_ext=True): 15 | if verbose: 16 | print('[*] Scanning...') 17 | file_list = [] 18 | cnt = 0 19 | for root, _, files in os.walk(root_dir): 20 | for file in files: 21 | if file.endswith(extension): 22 | if (amount is not None) and (cnt == amount): 23 | break 24 | if str_ is not None: 25 | if str_ not in file: 26 | continue 27 | mix_path = os.path.join(root, file) 28 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 29 | if not is_ext: 30 | ext = pure_path.split('.')[-1] 31 | pure_path = pure_path[:-(len(ext)+1)] 32 | if verbose: 33 | print(pure_path) 34 | file_list.append(pure_path) 35 | cnt += 1 36 | if verbose: 37 | print('Total: %d files' % len(file_list)) 38 | print('Done!!!') 39 | if is_sort: 40 | file_list.sort() 41 | return file_list 42 | 43 | 44 | if __name__ == '__main__': 45 | # paths 46 | path_root = 'ailab17k_from-scratch_remi' 47 | path_indir = os.path.join(path_root, 'events') 48 | path_outdir = os.path.join(path_root, 'words') 49 | path_dictionary = os.path.join(path_root, 'dictionary.pkl') 50 | os.makedirs(path_outdir, exist_ok=True) 51 | 52 | # list files 53 | eventfiles = traverse_dir( 54 | path_indir, 55 | is_pure=True, 56 | is_sort=True, 57 | extension=('pkl')) 58 | n_files = len(eventfiles) 59 | print('num fiels:', n_files) 60 | 61 | # --- generate dictionary --- # 62 | print(' [*] generating dictionary') 63 | all_events = [] 64 | for file in eventfiles: 65 | for event in pickle.load(open(os.path.join(path_indir, file), 'rb')): 66 | all_events.append('{}_{}'.format(event['name'], event['value'])) 67 | 68 | # build 69 | unique_events = sorted(set(all_events), key=lambda x: (not isinstance(x, int), x)) 70 | event2word = {key: i for i, key in enumerate(unique_events)} 71 | word2event = {i: key for i, key in enumerate(unique_events)} 72 | print(' > num classes:', len(word2event)) 73 | 74 | # save 75 | pickle.dump((event2word, word2event), open(path_dictionary, 'wb')) 76 | 77 | # --- converts to word --- # 78 | event2word, word2event = pickle.load(open(path_dictionary, 'rb')) 79 | for fidx, file in enumerate(eventfiles): 80 | print('{}/{}'.format(fidx, 
n_files)) 81 | 82 | # events to words 83 | path_infile = os.path.join(path_indir, file) 84 | events = pickle.load(open(path_infile, 'rb')) 85 | words = [] 86 | for event in events: 87 | word = event2word['{}_{}'.format(event['name'], event['value'])] 88 | words.append(word) 89 | 90 | # save 91 | path_outfile = os.path.join(path_outdir, file + '.npy') 92 | fn = os.path.basename(path_outfile) 93 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True) 94 | np.save(path_outfile, words) 95 | -------------------------------------------------------------------------------- /dataset/representations/uncond/remi/valid_fn_idx_map.json: -------------------------------------------------------------------------------- 1 | {"fn2idx": {"Animenzzz/Koibumi (Shizuru's Theme) - Rewrite Soundtrack": 0, "DooPiano/PRODUCE 101 _ \u1109\u1173\u11af\u1105\u1166\u110b\u1175\u1110\u1173 - Oh Little Girl (\u110b\u1169\u1105\u1175\u1110\u1173\u11af\u1100\u1165\u11af) Piano Cover": 1, "DooPiano/\u1105\u1166\u1103\u1173\u1107\u1166\u11af\u1107\u1166\u11ba (Red Velvet) - Peek A Boo (Happy_\u1112\u1162\u1111\u1175 Ver.) Piano Cover": 2, "DooPiano/BLACKPINK - \u1104\u116e\u1103\u116e\u1104\u116e\u1103\u116e (DDU-DU DDU-DU) Piano Cover": 3, "TheTheorist/Flume Ft. Chet Faker - Drop The Game": 4, "MusicBand-Guide/\u00b5\u201a\u00b1\u00cd\u2265\u00cf HANA - \u00df\u2014\u221eO\u00df\u2044\u00b6\u20ac\u00a7v - \u00ba@\u2202\u221e '\u00ae\u0153\u00c6{\u00b6\u00ca\u2122\u00c32' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 5, "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00a1\u00ac\u00a1\u00ac\u00a9p\u2211R\u00df\u2044 Thanks for Your Love - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 6, "MusicBand-Guide/\u2202\u00bf\u220f\u00d9\u00b1\u00cd\u00d8\u00d9 Lulu - \u2022\u02db\u2265\u00a3\u00b5\u03c0\u00dfA\u00a7F - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 7, "MusicBand-Guide/\u2202P\u00a7@\u00d8\u00cb - \u03a9\u2013\u2022\u02dd\u00aa\u00b0\u00dfA\u00b6n - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 8, "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u110c\u1169\u11c2\u1103\u1161\u1100\u1169 \u1106\u1161\u11af\u1112\u1162 (Tell Me You Love Me) Piano Cover": 9, "MusicBand-Guide/Taeyeon \u592a\u598d \ud0dc\uc5f0 - Four Seasons \uc0ac\uacc4 - Piano Tutorial \u92fc\u7434\u6559\u5b78 \ud53c\uc544\ub178 [HQ] Synthesia": 10, "DooPiano/\u1110\u1162\u110b\u1167\u11ab (TAEYEON) - \u1109\u1161\u1100\u1168 (Four Seasons) Piano Cover": 11, "TheTheorist/Billie Eilish - listen before i go": 12, "DooPiano/\u1109\u1166\u1107\u1173\u11ab\u1110\u1175\u11ab (SEVENTEEN) - \u110b\u116e\u11af\u1100\u1169 \u1109\u1175\u11c1\u110c\u1175 \u110b\u1161\u11ad\u110b\u1161 (Don't Wanna Cry) Piano Cover": 13, "TheTheorist/Rihanna ft. 
Kanye West & Paul McCartney - FourFiveSeconds": 14, "MusicBand-Guide/\u95dc\u5586 Grady - \u60f3\u4f60\u7684\u591c (\u672a\u7720\u7248) Miss You Tonight - Piano Tutorial \u92fc\u7434\u6559\u5b78 [HQ] Synthesia": 15, "MusicBand-Guide/\u00b5\u00ff\u00b1\u00b7\u00b6t & \u2211\u00ae\u00a9v\u03a9n - \u221e\u00cd\u00a7\u02dd\u00aaP\u00a7^\u00a7\u00a2 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 16, "MusicBand-Guide/\u2122L\u00b4T\u2265\u00ab JJ Lin & \u03a9\u2264\u00ae\u00d9\u00df\u221e A-Sa - \u00a7p\u221es\u222b\u20ac Dimples - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 17, "MusicBand-Guide/YNW Melly - Mama Cry - Piano Tutorial [HQ] Synthesia": 18, "MusicBand-Guide/\u03a9\u2264\u221e\u2211\u2202\u00c6 - \u2202V\u00ae\u201d\u2202V\u00a7\u00a3\u00bf\u00a5 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 19, "MusicBand-Guide/\u00bfF\u00b4\u2265\u00c6\u00ca Janice Yan - \u00bfu\u2202\u00c6\u03c0D\u00dfO Graceful Goodbye - \u03c0q\u00b5\u00af\u00ba@ '20\u00a7\u00df\u00b4\u00b7' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 20, "MusicBand-Guide/\u00aa\u00bb\u00a1{ - \u2211N\u221a\u00af\u2022\u2260 - \u03c0q\u00b5\u00af\u00ba@ '\u2265\u00d8\u00b1\u00b0\u2022O' \u00a5\u00b0\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 21, "MusicBand-Guide/\u00df\u0131\u2122v\u00df\u00a0 - I Know You Know - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2122\u222b\u00d8u\u2122B\u00a7\u00d5' \u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 22, "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00dfA\u00a8O\u00d8u\u2122\u222b\u00ac\u02dc\u2202}\u00df\u2044 You Are Leaving Me - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 23, "MusicBand-Guide/\u2264\u02c6\u00a7\u00c2\u03a9\u00b4 Karen Mok - \u00df\ufb02\u2211n - \u03c0q\u00b5\u00af\u00ba@ '\u00df\ufb02\u2211n' \u00b6P\u00b6W\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 24, "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a9P\u00f8\u2265\u2260\u0131 Eric Chou - \u00a7@\u00ba\u00c0\u00a8\u00b8\u0192R Forever Beautiful - \u00d8\u00aa\u00a8\u0131\u00b5\u2211\u00b1a\u00ae\u2248\u00bf\u02d8\u00ae\u00e6\u2122v\u00b4\u2248\u00e6\u2026\u00a8\u00b0\u221e\u00a0\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 25, "TheTheorist/Post Malone & Swae Lee - Sunflower": 26, "MusicBand-Guide/\u03c0p\u00b4B\u00a7\ufb02 Yuxin Lei - \u221eO\u00a9\u00bf After June [2014\u00b6~RAiNBOW\u2260p\u03c0\u222b '\u00e6\u00cc' \u00b1M\u00f8\u00cb\u2264\u00b6\u2211~\u00a9u\u2022D\u2022\u00a5] - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 27, "MusicBand-Guide/MappleZS - Star River In Your Eyes \u222b\u00b0\u2022\u00ff\u00a8P\u2122e - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 28, "MusicBand-Guide/\u9234\u6728\u5be6\u88cf - \u591c\u7a7a [\u6200\u611b\u5c0f\u884c\u661fED] - Piano Tutorial \u92fc\u7434\u6559\u5b78 \u30d4\u30a2\u30ce\u6307\u5c0e [HQ] Synthesia": 29, "DooPiano/\u110b\u1161\u110b\u1175\u110b\u1172 (IU) - \u1107\u1161\u11b7\u1111\u1167\u11ab\u110c\u1175 (Through the Night) Piano Cover": 30, "MusicBand-Guide/\u220ft\u00b5\u2264\u2022\u20ac Saint ft. 
\u2202\u00bf\u00a8\u00b8\u00a8\u221a - \u00aa\u00b0\u00a7\u00a3\u2022X\u00a7f\u2122\u222b\u2211Q\u00a9\u00bf - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 31, "MusicBand-Guide/\u00a9P\u2202\u00ab\u2202\u00d8 - \u00a7W\u2022j\u00a7\u00df\u00f8\u2019 - \u03c0q\u00b5\u00af\u00ba@ '\u00a7W\u2022j\u00b1\u00b0\u222bq' \u00a5\u00b0\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 32, "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u1111\u1173\u1105\u1175\u110c\u1175\u110b\u1161 (Freesia) Piano Cover": 33, "DooPiano/BTS JIN (\u1107\u1161\u11bc\u1110\u1161\u11ab\u1109\u1169\u1102\u1167\u11ab\u1103\u1161\u11ab \u110c\u1175\u11ab) - \u110b\u1175 \u1107\u1161\u11b7 (Tonight) Piano Cover": 34, "TheTheorist/The Weeknd - What You Need": 35, "MusicBand-Guide/\u2265\u00d8\u222b\u02c6\u2260s Cheer Chen - \u00df\u2044\u2265\ufb02\u2248w\u00a7W\u00dfA\u00c6\u2026\u2122\u222b\u00a7\u222b\u00a7\ufb02\u00a8\u00b0\u221e\u00a0 - \u03c0q\u00bav '\u2265\ufb02\u2248w\u00dfA' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 36, "MusicBand-Guide/\u00b6\u2202\u2022D\u2211R - \u00ac\u221a\u00a7\u00a3\u00b6\u00cc\u2122\u222b\u00a7\ufb02\u220f\u0131 - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2022u\u2265\ufb02\u2248w\u00dfA' \u00a7\u02d8\u00bfY\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 37, "MusicBand-Guide/\u2265\u00d8\u00b4\u2265\u00ae\u2265 Eason Chen - \u2248\u02dd\u00df\u2044\u00d8d\u00b6b\u00dfA\u00ae\u2260\u221a\u2030 - \u03c0q\u00bav '\u00ac\\\u00a5\u00c1\u00a7H' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 38, "MusicBand-Guide/Justin Hurwitz - Quarantine - from movie 'First Man' - Piano Tutorial [HQ] Synthesia": 39, "MusicBand-Guide/\u00a1\u02d9\u03a9U & \u00b6\u00f8\u00a8M\u00aaT - \u2022H\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq - \u03c0q\u00b5\u00af\u00ba@ '\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 40, "MusicBand-Guide/\u2265\u00d8\u03c0\u2248\u00e6\u00cf Ella Chen - 30\u221e\u2044 Age of 30 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 41, "MusicBand-Guide/\u00a1\u00df\u00a7\u00df\u00a1\u00e6 Joker Xue - \u221e\u25ca\u00a7\u2044 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 42, "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u2265\u03a9\u2022J He-R - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 43, "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a7p\u00c6\u2026\u00a9h\u00c6Q - \u2211R\u00ba\u2039 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 44, "TheTheorist/Drake & Majid Jordan - Summer's Over Interlude": 45, "MusicBand-Guide/\u00b6\u00c3\u00a8z\u2022\u00bb\u00c6v Kenshi Yonezu - Lemon - Piano Tutorial [HQ] Synthesia": 46, "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u00a5X\u00a7\u00bf\u00a7\u00df\u00a5X You Complete Me - '\u2122\u00b7\u2022\u201c\u00a7j\u00a7H\u00ac\u2021\u00aek\u00b4\u0192' \u03c0q\u00bav\u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 47, "MusicBand-Guide/La La Land - Late for the Date 'Mia & Sebastian\u00b0\u00b6s Theme' - Piano Tutorial [HQ] Synthesia": 48, 
"MusicBand-Guide/G.E.M. \u00e6H\u00b5\u00b5\u00a5\u2014 - \u00b5e - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 49}, "idx2fn": {"0": "Animenzzz/Koibumi (Shizuru's Theme) - Rewrite Soundtrack", "1": "DooPiano/PRODUCE 101 _ \u1109\u1173\u11af\u1105\u1166\u110b\u1175\u1110\u1173 - Oh Little Girl (\u110b\u1169\u1105\u1175\u1110\u1173\u11af\u1100\u1165\u11af) Piano Cover", "2": "DooPiano/\u1105\u1166\u1103\u1173\u1107\u1166\u11af\u1107\u1166\u11ba (Red Velvet) - Peek A Boo (Happy_\u1112\u1162\u1111\u1175 Ver.) Piano Cover", "3": "DooPiano/BLACKPINK - \u1104\u116e\u1103\u116e\u1104\u116e\u1103\u116e (DDU-DU DDU-DU) Piano Cover", "4": "TheTheorist/Flume Ft. Chet Faker - Drop The Game", "5": "MusicBand-Guide/\u00b5\u201a\u00b1\u00cd\u2265\u00cf HANA - \u00df\u2014\u221eO\u00df\u2044\u00b6\u20ac\u00a7v - \u00ba@\u2202\u221e '\u00ae\u0153\u00c6{\u00b6\u00ca\u2122\u00c32' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "6": "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00a1\u00ac\u00a1\u00ac\u00a9p\u2211R\u00df\u2044 Thanks for Your Love - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "7": "MusicBand-Guide/\u2202\u00bf\u220f\u00d9\u00b1\u00cd\u00d8\u00d9 Lulu - \u2022\u02db\u2265\u00a3\u00b5\u03c0\u00dfA\u00a7F - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "8": "MusicBand-Guide/\u2202P\u00a7@\u00d8\u00cb - \u03a9\u2013\u2022\u02dd\u00aa\u00b0\u00dfA\u00b6n - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "9": "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u110c\u1169\u11c2\u1103\u1161\u1100\u1169 \u1106\u1161\u11af\u1112\u1162 (Tell Me You Love Me) Piano Cover", "10": "MusicBand-Guide/Taeyeon \u592a\u598d \ud0dc\uc5f0 - Four Seasons \uc0ac\uacc4 - Piano Tutorial \u92fc\u7434\u6559\u5b78 \ud53c\uc544\ub178 [HQ] Synthesia", "11": "DooPiano/\u1110\u1162\u110b\u1167\u11ab (TAEYEON) - \u1109\u1161\u1100\u1168 (Four Seasons) Piano Cover", "12": "TheTheorist/Billie Eilish - listen before i go", "13": "DooPiano/\u1109\u1166\u1107\u1173\u11ab\u1110\u1175\u11ab (SEVENTEEN) - \u110b\u116e\u11af\u1100\u1169 \u1109\u1175\u11c1\u110c\u1175 \u110b\u1161\u11ad\u110b\u1161 (Don't Wanna Cry) Piano Cover", "14": "TheTheorist/Rihanna ft. 
Kanye West & Paul McCartney - FourFiveSeconds", "15": "MusicBand-Guide/\u95dc\u5586 Grady - \u60f3\u4f60\u7684\u591c (\u672a\u7720\u7248) Miss You Tonight - Piano Tutorial \u92fc\u7434\u6559\u5b78 [HQ] Synthesia", "16": "MusicBand-Guide/\u00b5\u00ff\u00b1\u00b7\u00b6t & \u2211\u00ae\u00a9v\u03a9n - \u221e\u00cd\u00a7\u02dd\u00aaP\u00a7^\u00a7\u00a2 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "17": "MusicBand-Guide/\u2122L\u00b4T\u2265\u00ab JJ Lin & \u03a9\u2264\u00ae\u00d9\u00df\u221e A-Sa - \u00a7p\u221es\u222b\u20ac Dimples - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "18": "MusicBand-Guide/YNW Melly - Mama Cry - Piano Tutorial [HQ] Synthesia", "19": "MusicBand-Guide/\u03a9\u2264\u221e\u2211\u2202\u00c6 - \u2202V\u00ae\u201d\u2202V\u00a7\u00a3\u00bf\u00a5 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "20": "MusicBand-Guide/\u00bfF\u00b4\u2265\u00c6\u00ca Janice Yan - \u00bfu\u2202\u00c6\u03c0D\u00dfO Graceful Goodbye - \u03c0q\u00b5\u00af\u00ba@ '20\u00a7\u00df\u00b4\u00b7' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "21": "MusicBand-Guide/\u00aa\u00bb\u00a1{ - \u2211N\u221a\u00af\u2022\u2260 - \u03c0q\u00b5\u00af\u00ba@ '\u2265\u00d8\u00b1\u00b0\u2022O' \u00a5\u00b0\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "22": "MusicBand-Guide/\u00df\u0131\u2122v\u00df\u00a0 - I Know You Know - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2122\u222b\u00d8u\u2122B\u00a7\u00d5' \u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "23": "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00dfA\u00a8O\u00d8u\u2122\u222b\u00ac\u02dc\u2202}\u00df\u2044 You Are Leaving Me - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "24": "MusicBand-Guide/\u2264\u02c6\u00a7\u00c2\u03a9\u00b4 Karen Mok - \u00df\ufb02\u2211n - \u03c0q\u00b5\u00af\u00ba@ '\u00df\ufb02\u2211n' \u00b6P\u00b6W\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "25": "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a9P\u00f8\u2265\u2260\u0131 Eric Chou - \u00a7@\u00ba\u00c0\u00a8\u00b8\u0192R Forever Beautiful - \u00d8\u00aa\u00a8\u0131\u00b5\u2211\u00b1a\u00ae\u2248\u00bf\u02d8\u00ae\u00e6\u2122v\u00b4\u2248\u00e6\u2026\u00a8\u00b0\u221e\u00a0\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "26": "TheTheorist/Post Malone & Swae Lee - Sunflower", "27": "MusicBand-Guide/\u03c0p\u00b4B\u00a7\ufb02 Yuxin Lei - \u221eO\u00a9\u00bf After June [2014\u00b6~RAiNBOW\u2260p\u03c0\u222b '\u00e6\u00cc' \u00b1M\u00f8\u00cb\u2264\u00b6\u2211~\u00a9u\u2022D\u2022\u00a5] - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "28": "MusicBand-Guide/MappleZS - Star River In Your Eyes \u222b\u00b0\u2022\u00ff\u00a8P\u2122e - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "29": "MusicBand-Guide/\u9234\u6728\u5be6\u88cf - \u591c\u7a7a [\u6200\u611b\u5c0f\u884c\u661fED] - Piano Tutorial \u92fc\u7434\u6559\u5b78 \u30d4\u30a2\u30ce\u6307\u5c0e [HQ] Synthesia", "30": "DooPiano/\u110b\u1161\u110b\u1175\u110b\u1172 (IU) - \u1107\u1161\u11b7\u1111\u1167\u11ab\u110c\u1175 (Through the Night) Piano Cover", "31": "MusicBand-Guide/\u220ft\u00b5\u2264\u2022\u20ac Saint ft. 
\u2202\u00bf\u00a8\u00b8\u00a8\u221a - \u00aa\u00b0\u00a7\u00a3\u2022X\u00a7f\u2122\u222b\u2211Q\u00a9\u00bf - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "32": "MusicBand-Guide/\u00a9P\u2202\u00ab\u2202\u00d8 - \u00a7W\u2022j\u00a7\u00df\u00f8\u2019 - \u03c0q\u00b5\u00af\u00ba@ '\u00a7W\u2022j\u00b1\u00b0\u222bq' \u00a5\u00b0\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "33": "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u1111\u1173\u1105\u1175\u110c\u1175\u110b\u1161 (Freesia) Piano Cover", "34": "DooPiano/BTS JIN (\u1107\u1161\u11bc\u1110\u1161\u11ab\u1109\u1169\u1102\u1167\u11ab\u1103\u1161\u11ab \u110c\u1175\u11ab) - \u110b\u1175 \u1107\u1161\u11b7 (Tonight) Piano Cover", "35": "TheTheorist/The Weeknd - What You Need", "36": "MusicBand-Guide/\u2265\u00d8\u222b\u02c6\u2260s Cheer Chen - \u00df\u2044\u2265\ufb02\u2248w\u00a7W\u00dfA\u00c6\u2026\u2122\u222b\u00a7\u222b\u00a7\ufb02\u00a8\u00b0\u221e\u00a0 - \u03c0q\u00bav '\u2265\ufb02\u2248w\u00dfA' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "37": "MusicBand-Guide/\u00b6\u2202\u2022D\u2211R - \u00ac\u221a\u00a7\u00a3\u00b6\u00cc\u2122\u222b\u00a7\ufb02\u220f\u0131 - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2022u\u2265\ufb02\u2248w\u00dfA' \u00a7\u02d8\u00bfY\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "38": "MusicBand-Guide/\u2265\u00d8\u00b4\u2265\u00ae\u2265 Eason Chen - \u2248\u02dd\u00df\u2044\u00d8d\u00b6b\u00dfA\u00ae\u2260\u221a\u2030 - \u03c0q\u00bav '\u00ac\\\u00a5\u00c1\u00a7H' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "39": "MusicBand-Guide/Justin Hurwitz - Quarantine - from movie 'First Man' - Piano Tutorial [HQ] Synthesia", "40": "MusicBand-Guide/\u00a1\u02d9\u03a9U & \u00b6\u00f8\u00a8M\u00aaT - \u2022H\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq - \u03c0q\u00b5\u00af\u00ba@ '\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "41": "MusicBand-Guide/\u2265\u00d8\u03c0\u2248\u00e6\u00cf Ella Chen - 30\u221e\u2044 Age of 30 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "42": "MusicBand-Guide/\u00a1\u00df\u00a7\u00df\u00a1\u00e6 Joker Xue - \u221e\u25ca\u00a7\u2044 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "43": "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u2265\u03a9\u2022J He-R - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "44": "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a7p\u00c6\u2026\u00a9h\u00c6Q - \u2211R\u00ba\u2039 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "45": "TheTheorist/Drake & Majid Jordan - Summer's Over Interlude", "46": "MusicBand-Guide/\u00b6\u00c3\u00a8z\u2022\u00bb\u00c6v Kenshi Yonezu - Lemon - Piano Tutorial [HQ] Synthesia", "47": "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u00a5X\u00a7\u00bf\u00a7\u00df\u00a5X You Complete Me - '\u2122\u00b7\u2022\u201c\u00a7j\u00a7H\u00ac\u2021\u00aek\u00b4\u0192' \u03c0q\u00bav\u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "48": "MusicBand-Guide/La La Land - Late for the Date 'Mia & Sebastian\u00b0\u00b6s Theme' - Piano 
Tutorial [HQ] Synthesia", "49": "MusicBand-Guide/G.E.M. \u00e6H\u00b5\u00b5\u00a5\u2014 - \u00b5e - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia"}} -------------------------------------------------------------------------------- /dataset/representations/uncond/validation_songs.json: -------------------------------------------------------------------------------- 1 | ["MusicBand-Guide/\u03c0p\u00b4B\u00a7\ufb02 Yuxin Lei - \u221eO\u00a9\u00bf After June [2014\u00b6~RAiNBOW\u2260p\u03c0\u222b '\u00e6\u00cc' \u00b1M\u00f8\u00cb\u2264\u00b6\u2211~\u00a9u\u2022D\u2022\u00a5] - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "DooPiano/\u110b\u1161\u110b\u1175\u110b\u1172 (IU) - \u1107\u1161\u11b7\u1111\u1167\u11ab\u110c\u1175 (Through the Night) Piano Cover", "TheTheorist/Post Malone & Swae Lee - Sunflower", "MusicBand-Guide/\u00b5\u00ff\u00b1\u00b7\u00b6t & \u2211\u00ae\u00a9v\u03a9n - \u221e\u00cd\u00a7\u02dd\u00aaP\u00a7^\u00a7\u00a2 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "Animenzzz/Koibumi (Shizuru's Theme) - Rewrite Soundtrack", "DooPiano/BTS JIN (\u1107\u1161\u11bc\u1110\u1161\u11ab\u1109\u1169\u1102\u1167\u11ab\u1103\u1161\u11ab \u110c\u1175\u11ab) - \u110b\u1175 \u1107\u1161\u11b7 (Tonight) Piano Cover", "MusicBand-Guide/MappleZS - Star River In Your Eyes \u222b\u00b0\u2022\u00ff\u00a8P\u2122e - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u03a9\u2264\u221e\u2211\u2202\u00c6 - \u2202V\u00ae\u201d\u2202V\u00a7\u00a3\u00bf\u00a5 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u00df\u0131\u2122v\u00df\u00a0 - I Know You Know - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2122\u222b\u00d8u\u2122B\u00a7\u00d5' \u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "TheTheorist/Rihanna ft. 
Kanye West & Paul McCartney - FourFiveSeconds", "MusicBand-Guide/\u2265\u00d8\u03c0\u2248\u00e6\u00cf Ella Chen - 30\u221e\u2044 Age of 30 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "DooPiano/BLACKPINK - \u1104\u116e\u1103\u116e\u1104\u116e\u1103\u116e (DDU-DU DDU-DU) Piano Cover", "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a9P\u00f8\u2265\u2260\u0131 Eric Chou - \u00a7@\u00ba\u00c0\u00a8\u00b8\u0192R Forever Beautiful - \u00d8\u00aa\u00a8\u0131\u00b5\u2211\u00b1a\u00ae\u2248\u00bf\u02d8\u00ae\u00e6\u2122v\u00b4\u2248\u00e6\u2026\u00a8\u00b0\u221e\u00a0\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u00a1\u02d9\u03a9U & \u00b6\u00f8\u00a8M\u00aaT - \u2022H\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq - \u03c0q\u00b5\u00af\u00ba@ '\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u00a5X\u00a7\u00bf\u00a7\u00df\u00a5X You Complete Me - '\u2122\u00b7\u2022\u201c\u00a7j\u00a7H\u00ac\u2021\u00aek\u00b4\u0192' \u03c0q\u00bav\u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u2202P\u00a7@\u00d8\u00cb - \u03a9\u2013\u2022\u02dd\u00aa\u00b0\u00dfA\u00b6n - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00a1\u00ac\u00a1\u00ac\u00a9p\u2211R\u00df\u2044 Thanks for Your Love - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/La La Land - Late for the Date 'Mia & Sebastian\u00b0\u00b6s Theme' - Piano Tutorial [HQ] Synthesia", "DooPiano/\u1110\u1162\u110b\u1167\u11ab (TAEYEON) - \u1109\u1161\u1100\u1168 (Four Seasons) Piano Cover", "MusicBand-Guide/YNW Melly - Mama Cry - Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/Taeyeon \u592a\u598d \ud0dc\uc5f0 - Four Seasons \uc0ac\uacc4 - Piano Tutorial \u92fc\u7434\u6559\u5b78 \ud53c\uc544\ub178 [HQ] Synthesia", "MusicBand-Guide/\u00a1\u00df\u00a7\u00df\u00a1\u00e6 Joker Xue - \u221e\u25ca\u00a7\u2044 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "DooPiano/PRODUCE 101 _ \u1109\u1173\u11af\u1105\u1166\u110b\u1175\u1110\u1173 - Oh Little Girl (\u110b\u1169\u1105\u1175\u1110\u1173\u11af\u1100\u1165\u11af) Piano Cover", "MusicBand-Guide/\u220ft\u00b5\u2264\u2022\u20ac Saint ft. 
\u2202\u00bf\u00a8\u00b8\u00a8\u221a - \u00aa\u00b0\u00a7\u00a3\u2022X\u00a7f\u2122\u222b\u2211Q\u00a9\u00bf - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u110c\u1169\u11c2\u1103\u1161\u1100\u1169 \u1106\u1161\u11af\u1112\u1162 (Tell Me You Love Me) Piano Cover", "MusicBand-Guide/\u00a9P\u2202\u00ab\u2202\u00d8 - \u00a7W\u2022j\u00a7\u00df\u00f8\u2019 - \u03c0q\u00b5\u00af\u00ba@ '\u00a7W\u2022j\u00b1\u00b0\u222bq' \u00a5\u00b0\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u00b6\u2202\u2022D\u2211R - \u00ac\u221a\u00a7\u00a3\u00b6\u00cc\u2122\u222b\u00a7\ufb02\u220f\u0131 - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2022u\u2265\ufb02\u2248w\u00dfA' \u00a7\u02d8\u00bfY\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u00bfF\u00b4\u2265\u00c6\u00ca Janice Yan - \u00bfu\u2202\u00c6\u03c0D\u00dfO Graceful Goodbye - \u03c0q\u00b5\u00af\u00ba@ '20\u00a7\u00df\u00b4\u00b7' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u2265\u00d8\u222b\u02c6\u2260s Cheer Chen - \u00df\u2044\u2265\ufb02\u2248w\u00a7W\u00dfA\u00c6\u2026\u2122\u222b\u00a7\u222b\u00a7\ufb02\u00a8\u00b0\u221e\u00a0 - \u03c0q\u00bav '\u2265\ufb02\u2248w\u00dfA' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a7p\u00c6\u2026\u00a9h\u00c6Q - \u2211R\u00ba\u2039 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "DooPiano/\u1105\u1166\u1103\u1173\u1107\u1166\u11af\u1107\u1166\u11ba (Red Velvet) - Peek A Boo (Happy_\u1112\u1162\u1111\u1175 Ver.) Piano Cover", "TheTheorist/Drake & Majid Jordan - Summer's Over Interlude", "MusicBand-Guide/\u00b6\u00c3\u00a8z\u2022\u00bb\u00c6v Kenshi Yonezu - Lemon - Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/G.E.M. \u00e6H\u00b5\u00b5\u00a5\u2014 - \u00b5e - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u1111\u1173\u1105\u1175\u110c\u1175\u110b\u1161 (Freesia) Piano Cover", "MusicBand-Guide/\u2264\u02c6\u00a7\u00c2\u03a9\u00b4 Karen Mok - \u00df\ufb02\u2211n - \u03c0q\u00b5\u00af\u00ba@ '\u00df\ufb02\u2211n' \u00b6P\u00b6W\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "DooPiano/\u1109\u1166\u1107\u1173\u11ab\u1110\u1175\u11ab (SEVENTEEN) - \u110b\u116e\u11af\u1100\u1169 \u1109\u1175\u11c1\u110c\u1175 \u110b\u1161\u11ad\u110b\u1161 (Don't Wanna Cry) Piano Cover", "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u2265\u03a9\u2022J He-R - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u2265\u00d8\u00b4\u2265\u00ae\u2265 Eason Chen - \u2248\u02dd\u00df\u2044\u00d8d\u00b6b\u00dfA\u00ae\u2260\u221a\u2030 - \u03c0q\u00bav '\u00ac\\\u00a5\u00c1\u00a7H' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "TheTheorist/Flume Ft. 
Chet Faker - Drop The Game", "TheTheorist/Billie Eilish - listen before i go", "TheTheorist/The Weeknd - What You Need", "MusicBand-Guide/\u2202\u00bf\u220f\u00d9\u00b1\u00cd\u00d8\u00d9 Lulu - \u2022\u02db\u2265\u00a3\u00b5\u03c0\u00dfA\u00a7F - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u00aa\u00bb\u00a1{ - \u2211N\u221a\u00af\u2022\u2260 - \u03c0q\u00b5\u00af\u00ba@ '\u2265\u00d8\u00b1\u00b0\u2022O' \u00a5\u00b0\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00dfA\u00a8O\u00d8u\u2122\u222b\u00ac\u02dc\u2202}\u00df\u2044 You Are Leaving Me - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u9234\u6728\u5be6\u88cf - \u591c\u7a7a [\u6200\u611b\u5c0f\u884c\u661fED] - Piano Tutorial \u92fc\u7434\u6559\u5b78 \u30d4\u30a2\u30ce\u6307\u5c0e [HQ] Synthesia", "MusicBand-Guide/\u2122L\u00b4T\u2265\u00ab JJ Lin & \u03a9\u2264\u00ae\u00d9\u00df\u221e A-Sa - \u00a7p\u221es\u222b\u20ac Dimples - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u00b5\u201a\u00b1\u00cd\u2265\u00cf HANA - \u00df\u2014\u221eO\u00df\u2044\u00b6\u20ac\u00a7v - \u00ba@\u2202\u221e '\u00ae\u0153\u00c6{\u00b6\u00ca\u2122\u00c32' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/Justin Hurwitz - Quarantine - from movie 'First Man' - Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u95dc\u5586 Grady - \u60f3\u4f60\u7684\u591c (\u672a\u7720\u7248) Miss You Tonight - Piano Tutorial \u92fc\u7434\u6559\u5b78 [HQ] Synthesia"] -------------------------------------------------------------------------------- /dataset/synchronizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import copy 4 | import librosa 5 | import numpy as np 6 | import multiprocessing as mp 7 | from madmom.features.downbeats import DBNDownBeatTrackingProcessor 8 | from madmom.features.downbeats import RNNDownBeatProcessor 9 | from miditoolkit.midi import parser 10 | from miditoolkit.midi.containers import TimeSignature, TempoChange 11 | 12 | 13 | def traverse_dir( 14 | root_dir, 15 | extension=('mid', 'MID', 'midi'), 16 | amount=None, 17 | str_=None, 18 | is_pure=False, 19 | verbose=False, 20 | is_sort=False, 21 | is_ext=True): 22 | if verbose: 23 | print('[*] Scanning...') 24 | file_list = [] 25 | cnt = 0 26 | for root, _, files in os.walk(root_dir): 27 | for file in files: 28 | if file.endswith(extension): 29 | if (amount is not None) and (cnt == amount): 30 | break 31 | if str_ is not None: 32 | if str_ not in file: 33 | continue 34 | mix_path = os.path.join(root, file) 35 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 36 | if not is_ext: 37 | ext = pure_path.split('.')[-1] 38 | pure_path = pure_path[:-(len(ext)+1)] 39 | if verbose: 40 | print(pure_path) 41 | file_list.append(pure_path) 42 | cnt += 1 43 | if verbose: 44 | print('Total: %d files' % len(file_list)) 45 | print('Done!!!') 46 | if is_sort: 47 | file_list.sort() 48 | return file_list 49 | 50 | 51 | def get_instruments_abs_timing(instruments, tick_to_time): 52 | return convert_instruments_timing_from_sym_to_abs(instruments, tick_to_time) 53 | 54 | 55 | def convert_instruments_timing_from_sym_to_abs(instruments, tick_to_time): 56 | proc_instrs = copy.deepcopy(instruments) 57 | for instr in 
proc_instrs: 58 |         for note in instr.notes: 59 |             note.start = float(tick_to_time[note.start]) 60 |             note.end = float(tick_to_time[note.end]) 61 |     return proc_instrs 62 | 63 | 64 | def convert_instruments_timing_from_abs_to_sym(instruments, time_to_tick): 65 |     proc_instrs = copy.deepcopy(instruments) 66 |     for instr in proc_instrs: 67 |         for note in instr.notes: 68 |             # find nearest 69 |             note.start = find_nearest_np(time_to_tick, note.start) 70 |             note.end = find_nearest_np(time_to_tick, note.end) 71 |     return proc_instrs 72 | 73 | 74 | def find_nearest_np(array, value): 75 |     return (np.abs(array - value)).argmin() 76 | 77 | 78 | def find_first_downbeat(proc_res): 79 |     rhythm = np.where(proc_res[:, 1] == 1)[0] 80 |     pos = proc_res[rhythm[0], 0] 81 |     return pos 82 | 83 | 84 | def interp_linear(src, target, num, tail=False): 85 |     src = float(src) 86 |     target = float(target) 87 |     step = (target - src) / float(num) 88 |     middles = [src + step * i for i in range(1, num)] 89 |     res = [src] + middles 90 |     if tail: 91 |         res += [target] 92 |     return res 93 | 94 | 95 | def estimate_beat(path_audio): 96 |     proc = DBNDownBeatTrackingProcessor(beats_per_bar=[3, 4], fps=100) 97 |     act = RNNDownBeatProcessor()(path_audio) 98 |     proc_res = proc(act) 99 |     return proc_res 100 | 101 | 102 | def export_audio_with_click(proc_res, path_audio, path_output, sr=44100): 103 |     # extract time 104 |     times_beat = proc_res[np.where(proc_res[:, 1]!=1)][:, 0] 105 |     times_downbeat = proc_res[np.where(proc_res[:, 1]==1)][:, 0] 106 | 107 |     # load 108 |     y, _ = librosa.core.load(path_audio, sr=sr) 109 | 110 |     # click audio 111 |     y_beat = librosa.clicks(times=times_beat, sr=sr, click_freq=1200, click_duration=0.5) * 0.6 112 |     y_downbeat = librosa.clicks(times=times_downbeat, sr=sr, click_freq=600, click_duration=0.5) 113 | 114 |     # merge 115 |     max_len = max(len(y), len(y_beat), len(y_downbeat)) 116 |     y_integrate = np.zeros(max_len) 117 |     y_integrate[:len(y_beat)] += y_beat 118 |     y_integrate[:len(y_downbeat)] += y_downbeat 119 |     y_integrate[:len(y)] += y 120 | 121 |     librosa.output.write_wav(path_output, y_integrate, sr) 122 | 123 | 124 | def align_midi(proc_res, path_midi_input, path_midi_output, ticks_per_beat=480): 125 |     midi_data = parser.MidiFile(path_midi_input) 126 | 127 |     # compute tempo 128 |     beats = np.array([0.0] + list(proc_res[:, 0])) 129 |     intervals = np.diff(beats) 130 |     bpms = 60 / intervals 131 |     tempo_info = list(zip(beats[:-1], bpms)) 132 | 133 |     # get absolute timing of instruments 134 |     tick_to_time = midi_data.get_tick_to_time_mapping() 135 |     abs_instr = get_instruments_abs_timing(midi_data.instruments, tick_to_time) 136 | 137 |     # get end time of file 138 |     end_time = midi_data.get_tick_to_time_mapping()[-1] 139 | 140 |     # compute time to tick mapping 141 |     resample_timing = [] 142 |     for i in range(len(beats)-1): 143 |         start_beat = beats[i] 144 |         end_beat = beats[i + 1] 145 |         resample_timing += interp_linear(start_beat, end_beat, ticks_per_beat) 146 | 147 |     # fill the empty tail (using the last tick interval) 148 |     last_tick_interval = resample_timing[-1] - resample_timing[-2] 149 |     cur_time = resample_timing[-1] 150 |     while cur_time < end_time: 151 |         cur_time += last_tick_interval 152 |         resample_timing.append(cur_time) 153 |     resample_timing = np.array(resample_timing) 154 | 155 |     # create a new MidiFile object 156 |     midi_res = parser.MidiFile() 157 | 158 |     # convert abs to sym 159 |     sym_instr = convert_instruments_timing_from_abs_to_sym(abs_instr, resample_timing) 160 | 161 |     # time signature 162 |     first_db_sec = 
find_first_downbeat(proc_res) 163 |     first_db_tick = find_nearest_np(resample_timing, first_db_sec) 164 |     time_signature_changes = [TimeSignature(numerator=4, denominator=4, time=int(first_db_tick))] 165 | 166 |     # tempo 167 |     tempo_changes = [] 168 |     for pos, bpm in tempo_info: 169 |         pos_tick = find_nearest_np(resample_timing, pos) 170 |         tempo_changes.append(TempoChange(tempo=float(bpm), time=int(pos_tick))) 171 | 172 |     # shift (pickup at the beginning) 173 |     shift_align = ticks_per_beat * 4 - first_db_tick 174 | 175 |     # apply shift to tempo 176 |     for msg in tempo_changes: 177 |         msg.time += shift_align 178 | 179 |     # apply shift to notes 180 |     for instr in sym_instr: 181 |         for note in instr.notes: 182 |             note.start += shift_align 183 |             note.end += shift_align 184 | 185 |     # set attributes 186 |     midi_res.ticks_per_beat = ticks_per_beat 187 |     midi_res.tempo_changes = tempo_changes 188 |     midi_res.time_signature_changes = time_signature_changes 189 |     midi_res.instruments = sym_instr 190 | 191 |     # saving 192 |     midi_res.dump(filename=path_midi_output) 193 | 194 | 195 | def analyze(path_midi_input, path_audio_input, path_midi_output, path_audio_output=None): 196 |     print(path_midi_input) 197 |     # beat tracking 198 |     proc_res = estimate_beat(path_audio_input) 199 |     # export audio with click 200 |     if path_audio_output is not None: 201 |         export_audio_with_click(proc_res, path_audio_input, path_audio_output) 202 |     # export midi file 203 |     align_midi(proc_res, path_midi_input, path_midi_output) 204 | 205 | 206 | if __name__ == '__main__': 207 |     # paths 208 |     path_audiodir = './mp3' 209 |     path_indir = './midi_transcribed' 210 |     path_outdir = './midi_synchronized' 211 |     os.makedirs(path_outdir, exist_ok=True) 212 | 213 |     # list files 214 |     midifiles = traverse_dir( 215 |         path_indir, 216 |         is_ext=False, 217 |         is_pure=True, 218 |         is_sort=True) 219 |     n_files = len(midifiles) 220 |     print('num files:', n_files) 221 | 222 |     # collect 223 |     data = [] 224 |     for fidx in range(n_files): 225 |         fn = midifiles[fidx] 226 |         print('{}/{}'.format(fidx, n_files)) 227 | 228 |         # paths 229 |         path_midi_input = os.path.join(path_indir, fn+'.mid') 230 |         path_midi_output = os.path.join(path_outdir, fn+'.mid') 231 |         path_audio_input = os.path.join(path_audiodir, fn+'.mp3') 232 | 233 |         # mkdir 234 |         fn = os.path.basename(path_midi_output) 235 |         os.makedirs(path_midi_output[:-len(fn)], exist_ok=True) 236 | 237 |         # append 238 |         data.append([path_midi_input, path_audio_input, path_midi_output, None]) 239 | 240 |     # run with multiprocessing 241 |     pool = mp.Pool() 242 |     pool.starmap(analyze, data) 243 | 
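# --- usage sketch (added note; not part of the original script) --- #
# align_midi() makes every tracked beat span exactly `ticks_per_beat`
# symbolic ticks: each beat interval is cut into `ticks_per_beat`
# linearly spaced timestamps, so symbolic tick k maps to
# resample_timing[k] seconds. A tiny example with made-up beat times
# and ticks_per_beat=4 (the real script uses 480):
#
#   beats = [0.0, 0.5, 1.02]
#   timing = []
#   for a, b in zip(beats[:-1], beats[1:]):
#       timing += interp_linear(a, b, 4)
#   # timing ~= [0.0, 0.125, 0.25, 0.375, 0.5, 0.63, 0.76, 0.89]
#   # (each beat contributes 4 points; the target beat itself starts
#   #  the next interval, since interp_linear() excludes it by default)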
-------------------------------------------------------------------------------- /docs/aaai21-slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/docs/aaai21-slides.pdf -------------------------------------------------------------------------------- /workspace/cond_ls2midi/keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/cond_ls2midi/keep -------------------------------------------------------------------------------- /workspace/uncond/Experiments.md: -------------------------------------------------------------------------------- 1 | # Experiments 2 | 3 | Experimental results of **unconditional** generation. We also elaborate on the detailed settings that were omitted from the paper due to space limitations. 4 | 5 | ## Run the Code 6 | **cp-linear** 7 | 8 | Edit the configuration section at the beginning of the `main-cp.py` file first. 9 | 10 | ```bash 11 | python main-cp.py 12 | ``` 13 | 14 | **remi-xl** 15 | 16 | Edit the `config.yml` file first. 17 | 18 | ```bash 19 | # train 20 | python train.py 21 | 22 | # inference 23 | python inference.py 24 | ``` 25 | 26 | 27 | ## Model Settings 28 | Backbone models: 29 | * linear transformer (Linear): ["Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"](https://arxiv.org/abs/2006.16236) 30 | * transformer-XL (XL): ["Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"](https://arxiv.org/abs/1901.02860) 31 | 32 | 33 | All transformers share the same settings: 34 | * number of layers: 12 35 | * number of heads: 8 36 | * model hidden size: 512 37 | * feed-forward layer size: 2,048 38 | 39 | | Settings  | REMI + XL  | CP + Linear  | 40 | |:---------------------------:|:----------------------:|:----------------------:| 41 | | learning rate  | 2e-4 | 1e-4 | 42 | | songs for training  | 1,612  | 1,675  | 43 | | sequence length  | 7,680 (512 x 15)  | 3,584  | 44 | | dictionary size  | 332  | 342  | 45 | | parameter amount  | 41,291,084  | 39,016,630  | 46 | | receptive field (train)  | 512 + 512 for memory  | 3,584  | 47 | 48 | Because the two representations differ in nature, it is hard to keep the training data for the two settings exactly equal. We adjust the `number of songs` and the `sequence length` to achieve a reasonable balance. 49 | 50 | Under the limitation of the hardware budget (a single 2080 Ti GPU), we report a comparison between the two settings: 51 | * REMI + XL: REMI representation, transformer-XL 52 | * CP + Linear: compound word (CP) representation, linear transformer 53 | 54 | 55 | ## Evaluation 56 | ### Memory Usage 57 | The hardware budget is a single GPU; we use an **Nvidia GeForce RTX 2080 Ti GPU**, which has 11 GB of memory. 58 | 59 | | # | Representation+model  | batch size  | Memory Usage (GB)  | Evaluation | 60 | |:-:|:----------------------:|:----------------------:|:----------------------:|:-----------:| 61 | | 1 | REMI + XL  | 4  | 4  |  | 62 | | 2 | REMI + XL  | 10  | 10  | O | 63 | | 3 | CP + Linear  | 4  | 10  | O | 64 | 65 | For a fair comparison, we let every model setting run at its maximum memory consumption. 66 | Notice that CP + Linear is heavier per batch: it already consumes about 10 GB at a batch size of 4, whereas REMI + XL fits a batch size of 10 within the same budget. 67 | 68 | 69 | ### Training Efficiency 70 | The relation between the quality of the generated samples and the loss may differ across settings. Here, we still measure training efficiency in terms of cross-entropy loss. 71 | 72 | [WIP] 73 | 74 | 75 | ### Inference Efficiency and Performance 76 | We let each model generate 50 songs and record the time consumed. 77 | 78 | Records (JSON file): 79 | * [cp-linear](./cp-linear/runtime_stats.json) 80 | * [remi-xl](./remi-xl/runtime_stats.json) 81 | 82 | | Representation+model  | Ave. Song Time (sec)  | EOS  | 83 | |:----------------------:|:----------------------:|:--------:| 84 | | REMI + XL  | 195.25446  | X  | 85 | | CP + Linear  | 20.91956  | O  | 86 | 87 | `EOS` indicates whether the model is able to stop generation automatically by emitting an EOS token. 88 | For the CP + Linear setting, it takes less than half a minute to generate a song, which reveals its potential for real-time applications. 
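To reproduce the average above, it suffices to average the per-song times stored in the JSON records. A minimal sketch, assuming `runtime_stats.json` holds a list of per-song generation times in seconds under a `song_time` key (the key name is an assumption; adapt it to the actual schema of the file):

```python
import json

# Hypothetical schema: {"song_time": [t_0, t_1, ..., t_49]}, seconds per song.
with open('./cp-linear/runtime_stats.json') as f:
    stats = json.load(f)

song_times = stats['song_time']  # assumed key name; check the actual JSON
print('num songs:', len(song_times))
print('ave. song time (sec): {:.5f}'.format(sum(song_times) / len(song_times)))
```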
89 | 90 | ## Results 91 | ### Generated MIDI Files 92 | * [cp-linear](./cp-linear/gen_midis) 93 | * [remi-xl](./remi-xl/gen_midis) 94 | 95 | 96 | ### Checkpoints 97 | * [cp-linear](https://drive.google.com/drive/folders/114uore7LHjAsM4eKXG9TfVZL5S3YY7nZ?usp=sharing) 98 | * [remi-xl](https://drive.google.com/drive/folders/1tCaWQisPp_bcXKH5J3Nxmv6kUzJXs6qw?usp=sharing) 99 | 100 | ## Discussion 101 | ### About the generated samples 102 | We find that the pieces generated by REMI + XL tend to stick to certain patterns and occasionally fall into looping them for quite a long time, sometimes even for the entire song. The material within the loop is surprisingly well organized (a clear arpeggio in the left hand, a melody line in the right hand), but also a bit tedious. The samples from CP + Linear have a rather different texture: the structural diversity is richer, but the model is also more aggressive in its pitch choices. 103 | 104 | ### About EOS 105 | It turns out that REMI + XL fails to generate the EOS token, which implies that the sequence length might exceed the model's "effective length" because of its RNN-like nature. In fact, we also tried REMI + Linear and CP + Linear, and both of them succeed on this criterion. 106 | -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_0.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_0.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_1.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_10.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_10.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_11.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_11.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_12.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_12.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_13.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_13.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_14.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_14.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_15.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_15.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_16.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_16.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_17.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_17.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_18.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_18.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_19.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_19.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_2.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_20.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_20.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_21.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_21.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_22.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_22.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_23.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_23.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_24.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_24.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_25.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_25.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_26.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_26.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_27.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_27.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_28.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_28.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_29.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_29.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_3.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_30.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_30.mid 
-------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_31.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_31.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_32.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_32.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_33.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_33.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_34.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_34.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_35.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_35.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_36.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_36.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_37.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_37.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_38.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_38.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_39.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_39.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_4.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_4.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_40.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_40.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_41.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_41.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_42.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_42.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_43.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_43.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_44.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_44.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_45.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_45.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_46.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_46.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_47.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_47.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_48.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_48.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_49.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_49.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_5.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_5.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_6.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_6.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_7.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_7.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_8.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_8.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/gen_midis/get_9.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_9.mid -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/main-cp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | 5 | import math 6 | import time 7 | import glob 8 | import datetime 9 | import random 10 | import pickle 11 | import json 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | import torch.optim as optim 19 | from torch.nn.utils import clip_grad_norm_ 20 | from torch.utils.data import Dataset, DataLoader 21 | 22 | from fast_transformers.builders import TransformerEncoderBuilder 23 | from fast_transformers.builders import RecurrentEncoderBuilder 24 | from fast_transformers.masking import TriangularCausalMask 25 | 26 | import miditoolkit 27 | from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note 28 | 29 | import saver 30 | 31 | 32 | ################################################################################ 33 | # config 34 | ################################################################################ 35 | 36 | MODE 
= 'train' 37 | # MODE = 'inference' 38 | 39 | ###--- data ---### 40 | path_data_root = '../../../dataset/representations/uncond/cp/ailab17k_from-scratch_cp' 41 | path_train_data = os.path.join(path_data_root, 'train_data_linear.npz') 42 | path_dictionary = os.path.join(path_data_root, 'dictionary.pkl') 43 | 44 | ###--- training config ---### 45 | D_MODEL = 512 46 | N_LAYER = 12 47 | N_HEAD = 8 48 | path_exp = 'exp' 49 | batch_size = 4 50 | gid = 0 51 | init_lr = 0.0001 52 | 53 | ###--- fine-tuning & inference config ---### 54 | # info_load_model = ( 55 | # # path to ckpt for loading 56 | # '/volume/ai-music-wayne/aaai/from-scratch/cp-linear/exp_base_fs', 57 | # # loss 58 | # 29 59 | # ) 60 | info_load_model = None 61 | path_gendir = 'gen_midis' 62 | num_songs = 50 63 | 64 | ################################################################################ 65 | # File IO 66 | ################################################################################ 67 | 68 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gid) 69 | BEAT_RESOL = 480 70 | BAR_RESOL = BEAT_RESOL * 4 71 | TICK_RESOL = BEAT_RESOL // 4 72 | 73 | 74 | def write_midi(words, path_outfile, word2event): 75 | 76 | class_keys = word2event.keys() 77 | # words = np.load(path_infile) 78 | midi_obj = miditoolkit.midi.parser.MidiFile() 79 | 80 | bar_cnt = 0 81 | cur_pos = 0 82 | 83 | all_notes = [] 84 | 85 | cnt_error = 0 86 | for i in range(len(words)): 87 | vals = [] 88 | for kidx, key in enumerate(class_keys): 89 | vals.append(word2event[key][words[i][kidx]]) 90 | # print(vals) 91 | 92 | if vals[3] == 'Metrical': 93 | if vals[2] == 'Bar': 94 | bar_cnt += 1 95 | elif 'Beat' in vals[2]: 96 | beat_pos = int(vals[2].split('_')[1]) 97 | cur_pos = bar_cnt * BAR_RESOL + beat_pos * TICK_RESOL 98 | 99 | # chord 100 | if vals[1] != 'CONTI' and vals[1] != 0: 101 | midi_obj.markers.append( 102 | Marker(text=str(vals[1]), time=cur_pos)) 103 | 104 | if vals[0] != 'CONTI' and vals[0] != 0: 105 | tempo = int(vals[0].split('_')[-1]) 106 | midi_obj.tempo_changes.append( 107 | TempoChange(tempo=tempo, time=cur_pos)) 108 | else: 109 | pass 110 | elif vals[3] == 'Note': 111 | 112 | try: 113 | pitch = vals[4].split('_')[-1] 114 | duration = vals[5].split('_')[-1] 115 | velocity = vals[6].split('_')[-1] 116 | 117 | if int(duration) == 0: 118 | duration = 60 119 | end = cur_pos + int(duration) 120 | 121 | all_notes.append( 122 | Note( 123 | pitch=int(pitch), 124 | start=cur_pos, 125 | end=end, 126 | velocity=int(velocity)) 127 | ) 128 | except: 129 | continue 130 | else: 131 | pass 132 | 133 | # save midi 134 | piano_track = Instrument(0, is_drum=False, name='piano') 135 | piano_track.notes = all_notes 136 | midi_obj.instruments = [piano_track] 137 | midi_obj.dump(path_outfile) 138 | 139 | 140 | ################################################################################ 141 | # Sampling 142 | ################################################################################ 143 | # -- temperature -- # 144 | def softmax_with_temperature(logits, temperature): 145 | probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature)) 146 | return probs 147 | 148 | 149 | def weighted_sampling(probs): 150 | probs /= sum(probs) 151 | sorted_probs = np.sort(probs)[::-1] 152 | sorted_index = np.argsort(probs)[::-1] 153 | word = np.random.choice(sorted_index, size=1, p=sorted_probs)[0] 154 | return word 155 | 156 | 157 | # -- nucleus -- # 158 | def nucleus(probs, p): 159 | probs /= (sum(probs) + 1e-5) 160 | sorted_probs = np.sort(probs)[::-1] 161 | sorted_index = 
np.argsort(probs)[::-1] 162 | cusum_sorted_probs = np.cumsum(sorted_probs) 163 | after_threshold = cusum_sorted_probs > p 164 | if sum(after_threshold) > 0: 165 | last_index = np.where(after_threshold)[0][0] + 1 166 | candi_index = sorted_index[:last_index] 167 | else: 168 | candi_index = sorted_index[:] 169 | candi_probs = [probs[i] for i in candi_index] 170 | candi_probs /= sum(candi_probs) 171 | word = np.random.choice(candi_index, size=1, p=candi_probs)[0] 172 | return word 173 | 174 | 175 | def sampling(logit, p=None, t=1.0): 176 | logit = logit.squeeze().cpu().numpy() 177 | probs = softmax_with_temperature(logits=logit, temperature=t) 178 | 179 | if p is not None: 180 | cur_word = nucleus(probs, p=p) 181 | else: 182 | cur_word = weighted_sampling(probs) 183 | return cur_word 184 | 185 | 186 | ################################################################################ 187 | # Model 188 | ################################################################################ 189 | 190 | 191 | def network_paras(model): 192 | # compute only trainable params 193 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 194 | params = sum([np.prod(p.size()) for p in model_parameters]) 195 | return params 196 | 197 | 198 | class Embeddings(nn.Module): 199 | def __init__(self, n_token, d_model): 200 | super(Embeddings, self).__init__() 201 | self.lut = nn.Embedding(n_token, d_model) 202 | self.d_model = d_model 203 | 204 | def forward(self, x): 205 | return self.lut(x) * math.sqrt(self.d_model) 206 | 207 | 208 | class PositionalEncoding(nn.Module): 209 | def __init__(self, d_model, dropout=0.1, max_len=20000): 210 | super(PositionalEncoding, self).__init__() 211 | self.dropout = nn.Dropout(p=dropout) 212 | 213 | pe = torch.zeros(max_len, d_model) 214 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 215 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 216 | pe[:, 0::2] = torch.sin(position * div_term) 217 | pe[:, 1::2] = torch.cos(position * div_term) 218 | pe = pe.unsqueeze(0) 219 | self.register_buffer('pe', pe) 220 | 221 | def forward(self, x): 222 | x = x + self.pe[:, :x.size(1), :] 223 | return self.dropout(x) 224 | 225 | 226 | class TransformerModel(nn.Module): 227 | def __init__(self, n_token, is_training=True): 228 | super(TransformerModel, self).__init__() 229 | 230 | # --- params config --- # 231 | self.n_token = n_token 232 | self.d_model = D_MODEL 233 | self.n_layer = N_LAYER # 234 | self.dropout = 0.1 235 | self.n_head = N_HEAD # 236 | self.d_head = D_MODEL // N_HEAD 237 | self.d_inner = 2048 238 | self.loss_func = nn.CrossEntropyLoss(reduction='none') 239 | self.emb_sizes = [128, 256, 64, 32, 512, 128, 128] 240 | 241 | # --- modules config --- # 242 | # embeddings 243 | print('>>>>>:', self.n_token) 244 | self.word_emb_tempo = Embeddings(self.n_token[0], self.emb_sizes[0]) 245 | self.word_emb_chord = Embeddings(self.n_token[1], self.emb_sizes[1]) 246 | self.word_emb_barbeat = Embeddings(self.n_token[2], self.emb_sizes[2]) 247 | self.word_emb_type = Embeddings(self.n_token[3], self.emb_sizes[3]) 248 | self.word_emb_pitch = Embeddings(self.n_token[4], self.emb_sizes[4]) 249 | self.word_emb_duration = Embeddings(self.n_token[5], self.emb_sizes[5]) 250 | self.word_emb_velocity = Embeddings(self.n_token[6], self.emb_sizes[6]) 251 | self.pos_emb = PositionalEncoding(self.d_model, self.dropout) 252 | 253 | # linear 254 | self.in_linear = nn.Linear(np.sum(self.emb_sizes), self.d_model) 255 | 256 | # encoder 
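# The two builders below construct the same causal-linear attention stack in two forms:
# TransformerEncoderBuilder yields the parallel form used for whole-sequence training,
# while RecurrentEncoderBuilder yields a stepwise form that threads an explicit `memory`
# through forward_hidden(), so each generated token costs O(1) attention work instead of
# re-attending to the full history. Both forms share one parameter layout, which is why
# generate() can load a state_dict trained with the parallel build into a model built
# with is_training=False.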
257 | if is_training: 258 | # encoder (training) 259 | self.transformer_encoder = TransformerEncoderBuilder.from_kwargs( 260 | n_layers=self.n_layer, 261 | n_heads=self.n_head, 262 | query_dimensions=self.d_model//self.n_head, 263 | value_dimensions=self.d_model//self.n_head, 264 | feed_forward_dimensions=2048, 265 | activation='gelu', 266 | dropout=0.1, 267 | attention_type="causal-linear", 268 | ).get() 269 | else: 270 | # encoder (inference) 271 | print(' [o] using RNN backend.') 272 | self.transformer_encoder = RecurrentEncoderBuilder.from_kwargs( 273 | n_layers=self.n_layer, 274 | n_heads=self.n_head, 275 | query_dimensions=self.d_model//self.n_head, 276 | value_dimensions=self.d_model//self.n_head, 277 | feed_forward_dimensions=2048, 278 | activation='gelu', 279 | dropout=0.1, 280 | attention_type="causal-linear", 281 | ).get() 282 | 283 | # blend with type 284 | self.project_concat_type = nn.Linear(self.d_model + 32, self.d_model) 285 | 286 | # individual output 287 | self.proj_tempo = nn.Linear(self.d_model, self.n_token[0]) 288 | self.proj_chord = nn.Linear(self.d_model, self.n_token[1]) 289 | self.proj_barbeat = nn.Linear(self.d_model, self.n_token[2]) 290 | self.proj_type = nn.Linear(self.d_model, self.n_token[3]) 291 | self.proj_pitch = nn.Linear(self.d_model, self.n_token[4]) 292 | self.proj_duration = nn.Linear(self.d_model, self.n_token[5]) 293 | self.proj_velocity = nn.Linear(self.d_model, self.n_token[6]) 294 | 295 | def compute_loss(self, predict, target, loss_mask): 296 | loss = self.loss_func(predict, target) 297 | loss = loss * loss_mask 298 | loss = torch.sum(loss) / torch.sum(loss_mask) 299 | return loss 300 | 301 | def train_step(self, x, target, loss_mask): 302 | h, y_type = self.forward_hidden(x) 303 | y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity = self.forward_output(h, target) 304 | 305 | # reshape (b, s, f) -> (b, f, s) 306 | y_tempo = y_tempo[:, ...].permute(0, 2, 1) 307 | y_chord = y_chord[:, ...].permute(0, 2, 1) 308 | y_barbeat = y_barbeat[:, ...].permute(0, 2, 1) 309 | y_type = y_type[:, ...].permute(0, 2, 1) 310 | y_pitch = y_pitch[:, ...].permute(0, 2, 1) 311 | y_duration = y_duration[:, ...].permute(0, 2, 1) 312 | y_velocity = y_velocity[:, ...].permute(0, 2, 1) 313 | 314 | # loss 315 | loss_tempo = self.compute_loss( 316 | y_tempo, target[..., 0], loss_mask) 317 | loss_chord = self.compute_loss( 318 | y_chord, target[..., 1], loss_mask) 319 | loss_barbeat = self.compute_loss( 320 | y_barbeat, target[..., 2], loss_mask) 321 | loss_type = self.compute_loss( 322 | y_type, target[..., 3], loss_mask) 323 | loss_pitch = self.compute_loss( 324 | y_pitch, target[..., 4], loss_mask) 325 | loss_duration = self.compute_loss( 326 | y_duration, target[..., 5], loss_mask) 327 | loss_velocity = self.compute_loss( 328 | y_velocity, target[..., 6], loss_mask) 329 | 330 | return loss_tempo, loss_chord, loss_barbeat, loss_type, loss_pitch, loss_duration, loss_velocity 331 | 332 | def forward_hidden(self, x, memory=None, is_training=True): 333 | ''' 334 | linear transformer: b x s x f 335 | x.shape=(bs, nf) 336 | ''' 337 | 338 | # embeddings 339 | emb_tempo = self.word_emb_tempo(x[..., 0]) 340 | emb_chord = self.word_emb_chord(x[..., 1]) 341 | emb_barbeat = self.word_emb_barbeat(x[..., 2]) 342 | emb_type = self.word_emb_type(x[..., 3]) 343 | emb_pitch = self.word_emb_pitch(x[..., 4]) 344 | emb_duration = self.word_emb_duration(x[..., 5]) 345 | emb_velocity = self.word_emb_velocity(x[..., 6]) 346 | 347 | embs = torch.cat( 348 | [ 349 | emb_tempo, 350 | 
emb_chord, 351 | emb_barbeat, 352 | emb_type, 353 | emb_pitch, 354 | emb_duration, 355 | emb_velocity, 356 | ], dim=-1) 357 | 358 | emb_linear = self.in_linear(embs) 359 | pos_emb = self.pos_emb(emb_linear) 360 | 361 | # assert False 362 | 363 | # transformer 364 | if is_training: 365 | # mask 366 | attn_mask = TriangularCausalMask(pos_emb.size(1), device=x.device) 367 | h = self.transformer_encoder(pos_emb, attn_mask) # y: b x s x d_model 368 | 369 | # project type 370 | y_type = self.proj_type(h) 371 | return h, y_type 372 | else: 373 | pos_emb = pos_emb.squeeze(0) 374 | h, memory = self.transformer_encoder(pos_emb, memory=memory) # y: s x d_model 375 | 376 | # project type 377 | y_type = self.proj_type(h) 378 | return h, y_type, memory 379 | 380 | def forward_output(self, h, y): 381 | ''' 382 | for training 383 | ''' 384 | tf_skip_type = self.word_emb_type(y[..., 3]) 385 | 386 | # project other 387 | y_concat_type = torch.cat([h, tf_skip_type], dim=-1) 388 | y_ = self.project_concat_type(y_concat_type) 389 | 390 | y_tempo = self.proj_tempo(y_) 391 | y_chord = self.proj_chord(y_) 392 | y_barbeat = self.proj_barbeat(y_) 393 | y_pitch = self.proj_pitch(y_) 394 | y_duration = self.proj_duration(y_) 395 | y_velocity = self.proj_velocity(y_) 396 | 397 | return y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity 398 | 399 | def forward_output_sampling(self, h, y_type): 400 | ''' 401 | for inference 402 | ''' 403 | # sample type 404 | y_type_logit = y_type[0, :] 405 | cur_word_type = sampling(y_type_logit, p=0.90) 406 | 407 | type_word_t = torch.from_numpy( 408 | np.array([cur_word_type])).long().cuda().unsqueeze(0) 409 | 410 | tf_skip_type = self.word_emb_type(type_word_t).squeeze(0) 411 | 412 | # concat 413 | y_concat_type = torch.cat([h, tf_skip_type], dim=-1) 414 | y_ = self.project_concat_type(y_concat_type) 415 | 416 | # project other 417 | y_tempo = self.proj_tempo(y_) 418 | y_chord = self.proj_chord(y_) 419 | y_barbeat = self.proj_barbeat(y_) 420 | 421 | y_pitch = self.proj_pitch(y_) 422 | y_duration = self.proj_duration(y_) 423 | y_velocity = self.proj_velocity(y_) 424 | 425 | # sampling gen_cond 426 | cur_word_tempo = sampling(y_tempo, t=1.2, p=0.9) 427 | cur_word_barbeat = sampling(y_barbeat, t=1.2) 428 | cur_word_chord = sampling(y_chord, p=0.99) 429 | cur_word_pitch = sampling(y_pitch, p=0.9) 430 | cur_word_duration = sampling(y_duration, t=2, p=0.9) 431 | cur_word_velocity = sampling(y_velocity, t=5) 432 | 433 | # collect 434 | next_arr = np.array([ 435 | cur_word_tempo, 436 | cur_word_chord, 437 | cur_word_barbeat, 438 | cur_word_type, 439 | cur_word_pitch, 440 | cur_word_duration, 441 | cur_word_velocity, 442 | ]) 443 | return next_arr 444 | 445 | def inference_from_scratch(self, dictionary): 446 | event2word, word2event = dictionary 447 | classes = word2event.keys() 448 | 449 | def print_word_cp(cp): 450 | result = [word2event[k][cp[idx]] for idx, k in enumerate(classes)] 451 | 452 | for r in result: 453 | print('{:15s}'.format(str(r)), end=' | ') 454 | print('') 455 | 456 | init = np.array([ 457 | [0, 0, 1, 1, 0, 0, 0], # bar 458 | ]) 459 | 460 | cnt_token = len(init) 461 | with torch.no_grad(): 462 | final_res = [] 463 | memory = None 464 | h = None 465 | 466 | cnt_bar = 1 467 | init_t = torch.from_numpy(init).long().cuda() 468 | print('------ initiate ------') 469 | for step in range(init.shape[0]): 470 | print_word_cp(init[step, :]) 471 | input_ = init_t[step, :].unsqueeze(0).unsqueeze(0) 472 | final_res.append(init[step, :][None, ...]) 473 | 474 | h, y_type, 
memory = self.forward_hidden( 475 | input_, memory, is_training=False) 476 | 477 | print('------ generate ------') 478 | while True: 479 | # sample others 480 | next_arr = self.forward_output_sampling(h, y_type) 481 | final_res.append(next_arr[None, ...]) 482 | print('bar:', cnt_bar, end=' ==') 483 | print_word_cp(next_arr) 484 | 485 | # forward 486 | input_ = torch.from_numpy(next_arr).long().cuda() 487 | input_ = input_.unsqueeze(0).unsqueeze(0) 488 | h, y_type, memory = self.forward_hidden( 489 | input_, memory, is_training=False) 490 | 491 | # end of sequence 492 | if word2event['type'][next_arr[3]] == 'EOS': 493 | break 494 | 495 | if word2event['bar-beat'][next_arr[2]] == 'Bar': 496 | cnt_bar += 1 497 | 498 | print('\n--------[Done]--------') 499 | final_res = np.concatenate(final_res) 500 | print(final_res.shape) 501 | return final_res 502 | 503 | 504 | ########################################################################################################################## 505 | # Script 506 | ########################################################################################################################## 507 | 508 | 509 | def train(): 510 | # hyper params 511 | n_epoch = 4000 512 | max_grad_norm = 3 513 | 514 | # load 515 | dictionary = pickle.load(open(path_dictionary, 'rb')) 516 | event2word, word2event = dictionary 517 | train_data = np.load(path_train_data) 518 | 519 | # create saver 520 | saver_agent = saver.Saver(path_exp) 521 | 522 | # config 523 | n_class = [] 524 | for key in event2word.keys(): 525 | n_class.append(len(dictionary[0][key])) 526 | 527 | # log 528 | print('num of classes:', n_class) 529 | 530 | # init 531 | net = TransformerModel(n_class) 532 | net.cuda() 533 | net.train() 534 | n_parameters = network_paras(net) 535 | print('n_parameters: {:,}'.format(n_parameters)) 536 | saver_agent.add_summary_msg( 537 | ' > params amount: {:,d}'.format(n_parameters)) 538 | 539 | # load model 540 | if info_load_model: 541 | path_ckpt = info_load_model[0] # path to ckpt dir 542 | loss = info_load_model[1] # loss 543 | name = 'loss_' + str(loss) 544 | path_saved_ckpt = os.path.join(path_ckpt, name + '_params.pt') 545 | print('[*] load model from:', path_saved_ckpt) 546 | net.load_state_dict(torch.load(path_saved_ckpt)) 547 | 548 | # optimizers 549 | optimizer = optim.Adam(net.parameters(), lr=init_lr) 550 | 551 | # unpack 552 | train_x = train_data['x'] 553 | train_y = train_data['y'] 554 | train_mask = train_data['mask'] 555 | num_batch = len(train_x) // batch_size 556 | 557 | print(' num_batch:', num_batch) 558 | print(' train_x:', train_x.shape) 559 | print(' train_y:', train_y.shape) 560 | print(' train_mask:', train_mask.shape) 561 | 562 | # run 563 | start_time = time.time() 564 | for epoch in range(n_epoch): 565 | acc_loss = 0 566 | acc_losses = np.zeros(7) 567 | 568 | for bidx in range(num_batch): # num_batch 569 | saver_agent.global_step_increment() 570 | 571 | # index 572 | bidx_st = batch_size * bidx 573 | bidx_ed = batch_size * (bidx + 1) 574 | 575 | # unpack batch data 576 | batch_x = train_x[bidx_st:bidx_ed] 577 | batch_y = train_y[bidx_st:bidx_ed] 578 | batch_mask = train_mask[bidx_st:bidx_ed] 579 | 580 | # to tensor 581 | batch_x = torch.from_numpy(batch_x).long().cuda() 582 | batch_y = torch.from_numpy(batch_y).long().cuda() 583 | batch_mask = torch.from_numpy(batch_mask).float().cuda() 584 | 585 | # run 586 | losses = net.train_step(batch_x, batch_y, batch_mask) 587 | loss = (losses[0] + losses[1] + losses[2] + losses[3] + losses[4] + losses[5] + 
losses[6]) / 7 588 | 589 | # Update 590 | net.zero_grad() 591 | loss.backward() 592 | if max_grad_norm is not None: 593 | clip_grad_norm_(net.parameters(), max_grad_norm) 594 | optimizer.step() 595 | 596 | # print 597 | sys.stdout.write('{}/{} | Loss: {:06f} | {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}\r'.format( 598 | bidx, num_batch, loss, losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], losses[6])) 599 | sys.stdout.flush() 600 | 601 | # acc 602 | acc_losses += np.array([l.item() for l in losses]) 603 | acc_loss += loss.item() 604 | 605 | # log 606 | saver_agent.add_summary('batch loss', loss.item()) 607 | 608 | # epoch loss 609 | runtime = time.time() - start_time 610 | epoch_loss = acc_loss / num_batch 611 | acc_losses = acc_losses / num_batch 612 | print('------------------------------------') 613 | print('epoch: {}/{} | Loss: {} | time: {}'.format( 614 | epoch, n_epoch, epoch_loss, str(datetime.timedelta(seconds=runtime)))) 615 | each_loss_str = '{:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}\r'.format( 616 | acc_losses[0], acc_losses[1], acc_losses[2], acc_losses[3], acc_losses[4], acc_losses[5], acc_losses[6]) 617 | print(' >', each_loss_str) 618 | 619 | saver_agent.add_summary('epoch loss', epoch_loss) 620 | saver_agent.add_summary('epoch each loss', each_loss_str) 621 | 622 | # save model, with policy 623 | loss = epoch_loss 624 | if 0.4 < loss <= 0.8: 625 | fn = int(loss * 10) * 10 626 | saver_agent.save_model(net, name='loss_' + str(fn)) 627 | elif 0.05 < loss <= 0.40: 628 | fn = int(loss * 100) 629 | saver_agent.save_model(net, name='loss_' + str(fn)) 630 | elif loss <= 0.05: 631 | print('Finished') 632 | return 633 | else: 634 | saver_agent.save_model(net, name='loss_high') 635 | 636 | 637 | def generate(): 638 | # path 639 | path_ckpt = info_load_model[0] # path to ckpt dir 640 | loss = info_load_model[1] # loss 641 | name = 'loss_' + str(loss) 642 | path_saved_ckpt = os.path.join(path_ckpt, name + '_params.pt') 643 | 644 | # load 645 | dictionary = pickle.load(open(path_dictionary, 'rb')) 646 | event2word, word2event = dictionary 647 | 648 | # outdir 649 | os.makedirs(path_gendir, exist_ok=True) 650 | 651 | # config 652 | n_class = [] 653 | for key in event2word.keys(): 654 | n_class.append(len(dictionary[0][key])) 655 | 656 | # init model 657 | net = TransformerModel(n_class, is_training=False) 658 | net.cuda() 659 | net.eval() 660 | 661 | # load model 662 | print('[*] load model from:', path_saved_ckpt) 663 | net.load_state_dict(torch.load(path_saved_ckpt)) 664 | 665 | # gen 666 | start_time = time.time() 667 | song_time_list = [] 668 | words_len_list = [] 669 | 670 | cnt_tokens_all = 0 671 | sidx = 0 672 | while sidx < num_songs: 673 | try: 674 | start_time = time.time() 675 | print('current idx:', sidx) 676 | path_outfile = os.path.join(path_gendir, 'get_{}.mid'.format(str(sidx))) 677 | 678 | res = net.inference_from_scratch(dictionary) 679 | write_midi(res, path_outfile, word2event) 680 | 681 | song_time = time.time() - start_time 682 | word_len = len(res) 683 | print('song time:', song_time) 684 | print('word_len:', word_len) 685 | words_len_list.append(word_len) 686 | song_time_list.append(song_time) 687 | 688 | sidx += 1 689 | except KeyboardInterrupt: 690 | raise ValueError(' [x] terminated.') 691 | except: 692 | continue 693 | 694 | print('ave token time:', sum(words_len_list) / sum(song_time_list)) 695 | print('ave song time:', np.mean(song_time_list)) 696 | 697 | runtime_result = { 698 | 'song_time':song_time_list, 699 | 
'words_len_list': words_len_list, 700 | 'ave token time:': sum(words_len_list) / sum(song_time_list), 701 | 'ave song time': float(np.mean(song_time_list)), 702 | } 703 | 704 | with open('runtime_stats.json', 'w') as f: 705 | json.dump(runtime_result, f) 706 | 707 | 708 | if __name__ == '__main__': 709 | # -- training -- # 710 | if MODE == 'train': 711 | train() 712 | 713 | # -- inference -- # 714 | elif MODE == 'inference': 715 | generate() 716 | 717 | else: 718 | pass 719 | -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/runtime_stats.json: -------------------------------------------------------------------------------- 1 | {"song_time": [22.15657639503479, 22.193533420562744, 31.92780828475952, 19.54995632171631, 16.9428653717041, 14.89795708656311, 23.36419653892517, 17.333919048309326, 21.232710123062134, 17.8750319480896, 20.08845090866089, 20.505202531814575, 21.045650482177734, 25.3824942111969, 19.22031259536743, 25.14116883277893, 26.373836994171143, 17.95224666595459, 24.632129192352295, 22.872940063476562, 18.909109115600586, 23.45732831954956, 17.696035146713257, 19.6449134349823, 21.48639941215515, 20.14349365234375, 18.929306983947754, 17.5115327835083, 20.62809991836548, 26.35112738609314, 20.624415159225464, 21.29854154586792, 19.991357803344727, 20.77635097503662, 15.483854293823242, 21.372031688690186, 21.510165691375732, 26.04141879081726, 18.135209560394287, 18.571840286254883, 23.6661434173584, 22.40087866783142, 23.71457076072693, 19.199822187423706, 14.656707763671875, 18.02313542366028, 23.298468351364136, 16.629225969314575, 21.779688596725464, 23.35808539390564], "words_len_list": [1500, 1466, 2122, 1273, 1150, 1075, 1595, 1141, 1483, 1272, 1461, 1434, 1464, 1719, 1339, 1702, 1874, 1241, 1691, 1619, 1302, 1605, 1169, 1335, 1458, 1436, 1330, 1128, 1466, 1842, 1373, 1399, 1365, 1350, 1051, 1376, 1455, 1825, 1188, 1175, 1585, 1490, 1679, 1349, 1039, 1304, 1666, 1153, 1447, 1545], "ave token time:": 68.362798469141, "ave song time": 20.919564909934998} -------------------------------------------------------------------------------- /workspace/uncond/cp-linear/saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import logging 5 | import datetime 6 | import collections 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | class Saver(object): 12 | def __init__( 13 | self, 14 | exp_dir, 15 | mode='w'): 16 | 17 | self.exp_dir = exp_dir 18 | self.init_time = time.time() 19 | self.global_step = 0 20 | 21 | # makedirs 22 | os.makedirs(exp_dir, exist_ok=True) 23 | 24 | # logging config 25 | path_logger = os.path.join(exp_dir, 'log.txt') 26 | logging.basicConfig( 27 | level=logging.DEBUG, 28 | format='%(message)s', 29 | filename=path_logger, 30 | filemode=mode) 31 | self.logger = logging.getLogger('training monitor') 32 | 33 | def add_summary_msg(self, msg): 34 | self.logger.debug(msg) 35 | 36 | def add_summary( 37 | self, 38 | key, 39 | val, 40 | step=None, 41 | cur_time=None): 42 | 43 | if cur_time is None: 44 | cur_time = time.time() - self.init_time 45 | if step is None: 46 | step = self.global_step 47 | 48 | # write msg (key, val, step, time) 49 | if isinstance(val, float): 50 | msg_str = '{:10s} | {:.10f} | {:10d} | {}'.format( 51 | key, 52 | val, 53 | step, 54 | cur_time 55 | ) 56 | else: 57 | msg_str = '{:10s} | {} | {:10d} | {}'.format( 58 | key, 59 | val, 60 | step, 61 | cur_time 62 | ) 63 | 64 | 
self.logger.debug(msg_str) 65 | 66 | def save_model( 67 | self, 68 | model, 69 | optimizer=None, 70 | outdir=None, 71 | name='model'): 72 | 73 | if outdir is None: 74 | outdir = self.exp_dir 75 | print(' [*] saving model to {}, name: {}'.format(outdir, name)) 76 | torch.save(model, os.path.join(outdir, name+'.pt')) 77 | torch.save(model.state_dict(), os.path.join(outdir, name+'_params.pt')) 78 | 79 | if optimizer is not None: 80 | torch.save(optimizer.state_dict(), os.path.join(outdir, name+'_opt.pt')) 81 | 82 | def load_model( 83 | self, 84 | path_exp, 85 | device='cpu', 86 | name='model.pt'): 87 | 88 | path_pt = os.path.join(path_exp, name) 89 | print(' [*] restoring model from', path_pt) 90 | model = torch.load(path_pt, map_location=torch.device(device)) 91 | return model 92 | 93 | def global_step_increment(self): 94 | self.global_step += 1 95 | 96 | """ 97 | file modes 98 | 'a': 99 | Opens a file for appending. The file pointer is at the end of the file if the file exists. 100 | That is, the file is in the append mode. If the file does not exist, it creates a new file for writing. 101 | 102 | 'w': 103 | Opens a file for writing only. Overwrites the file if the file exists. 104 | If the file does not exist, creates a new file for writing. 105 | """ 106 | 107 | def make_loss_report( 108 | path_log, 109 | path_figure='loss.png', 110 | dpi=100): 111 | 112 | # load logfile 113 | monitor_vals = collections.defaultdict(list) 114 | with open(path_log, 'r') as f: 115 | for line in f: 116 | try: 117 | line = line.strip() 118 | key, val, step, acc_time = line.split(' | ') 119 | monitor_vals[key].append((float(val), int(step), acc_time)) 120 | except: 121 | continue 122 | 123 | # collect 124 | step_train = [item[1] for item in monitor_vals['train loss']] 125 | vals_train = [item[0] for item in monitor_vals['train loss']] 126 | 127 | step_valid = [item[1] for item in monitor_vals['valid loss']] 128 | vals_valid = [item[0] for item in monitor_vals['valid loss']] 129 | 130 | x_min = step_valid[np.argmin(vals_valid)] 131 | y_min = min(vals_valid) 132 | 133 | # plot 134 | fig = plt.figure(dpi=dpi) 135 | plt.title('training process') 136 | plt.plot(step_train, vals_train, label='train') 137 | plt.plot(step_valid, vals_valid, label='valid') 138 | plt.yscale('log') 139 | plt.plot([x_min], [y_min], 'ro') 140 | plt.legend(loc='upper right') 141 | plt.tight_layout() 142 | plt.savefig(path_figure) 143 | 144 | ''' 145 | author: wayn391@mastertones 146 | '''
-------------------------------------------------------------------------------- /workspace/uncond/remi-xl/config.yml: -------------------------------------------------------------------------------- 1 | 2 | MODEL: 3 | n_head: 8 4 | n_layer: 12 5 | dropout: 0.1 6 | d_inner: 2048 #d_ff 7 | d_embed: 512 8 | d_model: 512 9 | dropatt: 0.0 #attention probability dropout rate 10 | query_dim: 16 #64 11 | seq_len: 512 #512 12 | n_token: 332 13 | mem_len: 512 14 | ext_len: 0 15 | tgt_len: 70 16 | eval_tgt_len: 50 17 | init: 'normal' #parameter initializer to use. 18 | emb_init: 'normal' #parameter initializer to use. 
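# note: the init fields below are consumed by TransformerXL.init_weight / init_bias in
# model.py: 'uniform' draws parameters from U(-init_range, init_range), while 'normal'
# draws from N(0, init_std); LayerNorm weights are drawn from N(1.0, init_std).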
19 | init_range: 0.1 20 | emb_init_range: 0.01 #parameters initialized by U(-init_range, init_range) 21 | init_std: 0.02 #parameters initialized by N(0, init_std) 22 | proj_init_std: 0.01 23 | clamp_len: -1 #use the same pos embeddings after clamp_len 24 | div_val: 1 25 | position_concat: False 26 | pre_lnorm: True #apply LayerNorm to the input instead of the output 27 | same_length: True #use the same attn length for all tokens 28 | 29 | 30 | TRAIN: 31 | ROOT: '../../../dataset/representations/uncond/remi/ailab17k_from-scratch_remi' 32 | gpuID: '1' 33 | output_dir: "./exp" 34 | batch_size: 10 #5 35 | lr: 0.0002 36 | num_epochs: 600 37 | save_freq: 10 38 | seed: 2222 39 | optim: 'adam' 40 | no_cuda: False 41 | resume_training_model: None 42 | # resume_training_model: '/volume/ai-music-wayne/aaai/from-scratch/remi-xl_review/result/20200901-064426/ep_170.pth.tar' 43 | 44 | 45 | INFERENCE: 46 | num_sample: 20 47 | gpuID: '1' 48 | dictionary_path: '../../../dataset/representations/uncond/remi/ailab17k_from-scratch_remi/dictionary.pkl' 49 | experiment_dir: '/volume/ai-music-wayne/aaai/from-scratch/remi-xl_review/result/20200901-064426' 50 | generated_dir: './gen_midis' 51 | checkpoint_type: epoch_idx # best_train, best_val, epoch_idx 52 | model_epoch: 170 53 | no_cuda: False 54 | -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_0.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_0.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_1.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_10.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_10.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_11.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_11.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_12.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_12.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_13.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_13.mid 
-------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_14.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_14.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_15.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_15.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_16.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_16.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_17.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_17.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_18.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_18.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_19.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_19.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_2.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_3.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_4.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_4.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_5.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_5.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_6.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_6.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_7.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_7.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_8.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_8.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/gen_midis/170_9.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_9.mid -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/inference.py: -------------------------------------------------------------------------------- 1 | from model import TransformerXL 2 | import pickle 3 | import random 4 | import os 5 | import time 6 | import torch 7 | 8 | import yaml 9 | import json 10 | 11 | import numpy as np 12 | 13 | 14 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' 15 | 16 | def main(): 17 | cfg = yaml.full_load(open("config.yml", 'r')) 18 | inferenceConfig = cfg['INFERENCE'] 19 | 20 | os.environ['CUDA_VISIBLE_DEVICES'] = inferenceConfig['gpuID'] 21 | 22 | print('='*2, 'Inference configs', '='*5) 23 | print(json.dumps(inferenceConfig, indent=1, sort_keys=True)) 24 | 25 | # checkpoint information 26 | CHECKPOINT_FOLDER = inferenceConfig['experiment_dir'] 27 | midi_folder = inferenceConfig["generated_dir"] 28 | 29 | checkpoint_type = inferenceConfig['checkpoint_type'] 30 | if checkpoint_type == 'best_train': 31 | model_path = os.path.join(CHECKPOINT_FOLDER, 'model_best.pth.tar') 32 | output_prefix = 'best_train_' 33 | elif checkpoint_type == 'best_val': 34 | model_path = os.path.join(CHECKPOINT_FOLDER, 'model_best_val.pth.tar') 35 | output_prefix = 'best_val_' 36 | elif checkpoint_type == 'epoch_idx': 37 | model_path = os.path.join(CHECKPOINT_FOLDER, 'ep_{}.pth.tar'.format(str(inferenceConfig['model_epoch']))) 38 | output_prefix = str(inferenceConfig['model_epoch'])+ '_' 39 | 40 | pretrainCfg = yaml.full_load(open(os.path.join(CHECKPOINT_FOLDER,"config.yml"), 'r')) 41 | modelConfig = pretrainCfg['MODEL'] 42 | 43 | # create result folder 44 | if not os.path.exists(midi_folder): 45 | os.mkdir(midi_folder) 46 | 47 | # load dictionary 48 | event2word, word2event = pickle.load(open(inferenceConfig['dictionary_path'], 'rb')) 49 | 50 | # select device 51 | device = torch.device("cuda" if not 
inferenceConfig["no_cuda"] and torch.cuda.is_available() else "cpu") 52 | print('Device to generate:', device) 53 | 54 | # declare model 55 | model = TransformerXL( 56 | modelConfig, 57 | device, 58 | event2word=event2word, 59 | word2event=word2event, 60 | is_training=False) 61 | 62 | # inference 63 | song_time_list = [] 64 | words_len_list = [] 65 | num_samples = inferenceConfig["num_sample"] 66 | for idx in range(num_samples): 67 | print(f'==={idx}/{num_samples}===') 68 | print(midi_folder, output_prefix + str(idx)) 69 | song_time, word_len = model.inference( 70 | model_path = model_path, 71 | token_lim=7680, 72 | strategies=['temperature', 'nucleus'], 73 | params={'t': 1.2, 'p': 0.9}, 74 | bpm=120, 75 | output_path='{}/{}.mid'.format(midi_folder, output_prefix + str(idx))) 76 | 77 | print('song time:', song_time) 78 | print('word_len:', word_len) 79 | words_len_list.append(word_len) 80 | song_time_list.append(song_time) 81 | 82 | 83 | print('ave token time:', sum(words_len_list) / sum(song_time_list)) 84 | print('ave song time:', np.mean(song_time_list)) 85 | 86 | runtime_result = { 87 | 'song_time':song_time_list, 88 | 'words_len_list': words_len_list, 89 | 'ave token time:': sum(words_len_list) / sum(song_time_list), 90 | 'ave song time': float(np.mean(song_time_list)), 91 | } 92 | 93 | 94 | with open('runtime_stats.json', 'w') as f: 95 | json.dump(runtime_result, f) 96 | 97 | if __name__ == '__main__': 98 | main() -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import math 6 | import numpy as np 7 | import pandas as pd 8 | import miditoolkit 9 | import shutil 10 | import copy 11 | import os 12 | import time 13 | import json 14 | from sklearn.model_selection import train_test_split 15 | from modules import MemTransformerLM 16 | from glob import glob 17 | 18 | import miditoolkit 19 | from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note 20 | import collections 21 | import pickle 22 | import numpy as np 23 | 24 | import saver 25 | 26 | # ================================ # 27 | BEAT_RESOL = 480 28 | BAR_RESOL = BEAT_RESOL * 4 29 | TICK_RESOL = BEAT_RESOL // 4 30 | INSTR_NAME_MAP = {'piano': 0, 'melody': 1} 31 | 32 | 33 | def wrtie_midi(words, path_midi, word2event): 34 | notes_all = [] 35 | 36 | events = [word2event[words[i]] for i in range(len(words))] 37 | 38 | bar_cnt = 0 39 | cur_beat = 0 40 | 41 | midi_obj = miditoolkit.midi.parser.MidiFile() 42 | cur_pos = 0 43 | 44 | for i in range(len(events)-3): 45 | cur_event = events[i] 46 | # print(cur_event) 47 | name = cur_event.split('_')[0] 48 | attr = cur_event.split('_') 49 | if name == 'Bar': 50 | bar_cnt += 1 51 | elif name == 'Beat': 52 | cur_beat = int(attr[1]) 53 | cur_pos = bar_cnt * BAR_RESOL + cur_beat * TICK_RESOL 54 | elif name == 'Chord': 55 | chord_text = attr[1] + '_' + attr[2] 56 | midi_obj.markers.append(Marker(text=chord_text, time=cur_pos)) 57 | elif name == 'Tempo': 58 | midi_obj.tempo_changes.append( 59 | TempoChange(tempo=int(attr[1]), time=cur_pos)) 60 | else: 61 | if 'Note_Pitch' in events[i] and \ 62 | 'Note_Velocity' in events[i+1] and \ 63 | 'Note_Duration' in events[i+2]: 64 | 65 | pitch = int(events[i].split('_')[-1]) 66 | duration = int(events[i+2].split('_')[-1]) 67 | 68 | if int(duration) == 0: 69 | duration = 60 70 | 71 | end = cur_pos + 
duration 72 | velocity = int(events[i+1].split('_')[-1]) 73 | notes_all.append( 74 | Note(pitch=pitch, start=cur_pos, end=end, velocity=velocity)) 75 | 76 | piano_track = Instrument(0, is_drum=False, name='piano') 77 | piano_track.notes = notes_all 78 | midi_obj.instruments = [piano_track] 79 | midi_obj.dump(path_midi) 80 | 81 | 82 | # ================================ # 83 | def network_paras(model): 84 | # compute only trainable params 85 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 86 | params = sum([np.prod(p.size()) for p in model_parameters]) 87 | return params 88 | 89 | 90 | class TransformerXL(object): 91 | def __init__(self, modelConfig, device, event2word, word2event, is_training=True): 92 | 93 | self.event2word = event2word 94 | self.word2event = word2event 95 | self.modelConfig = modelConfig 96 | 97 | # model settings 98 | self.n_layer = modelConfig['n_layer'] 99 | self.d_model = modelConfig['d_model'] 100 | self.seq_len = modelConfig['seq_len'] 101 | self.mem_len = modelConfig['mem_len'] 102 | 103 | self.tgt_len = modelConfig['tgt_len'] 104 | self.ext_len = modelConfig['ext_len'] 105 | self.eval_tgt_len = modelConfig['eval_tgt_len'] 106 | 107 | self.init = modelConfig['init'] 108 | self.init_range = modelConfig['init_range'] 109 | self.init_std = modelConfig['init_std'] 110 | self.proj_init_std = modelConfig['proj_init_std'] 111 | 112 | # mode 113 | self.is_training = is_training 114 | self.device = device 115 | 116 | 117 | def init_weight(self, weight): 118 | if self.init == 'uniform': 119 | nn.init.uniform_(weight, -self.init_range, self.init_range) 120 | elif self.init == 'normal': 121 | nn.init.normal_(weight, 0.0, self.init_std) 122 | 123 | def init_bias(self, bias): 124 | nn.init.constant_(bias, 0.0) 125 | 126 | def weights_init(self, m): 127 | classname = m.__class__.__name__ 128 | if classname.find('Linear') != -1: 129 | if hasattr(m, 'weight') and m.weight is not None: 130 | self.init_weight(m.weight) 131 | if hasattr(m, 'bias') and m.bias is not None: 132 | self.init_bias(m.bias) 133 | elif classname.find('Embedding') != -1: 134 | if hasattr(m, 'weight'): 135 | self.init_weight(m.weight) 136 | elif classname.find('LayerNorm') != -1: 137 | if hasattr(m, 'weight'): 138 | nn.init.normal_(m.weight, 1.0, self.init_std) 139 | if hasattr(m, 'bias') and m.bias is not None: 140 | self.init_bias(m.bias) 141 | elif classname.find('TransformerLM') != -1: 142 | if hasattr(m, 'r_emb'): 143 | self.init_weight(m.r_emb) 144 | if hasattr(m, 'r_w_bias'): 145 | self.init_weight(m.r_w_bias) 146 | if hasattr(m, 'r_r_bias'): 147 | self.init_weight(m.r_r_bias) 148 | if hasattr(m, 'r_bias'): 149 | self.init_bias(m.r_bias) 150 | 151 | 152 | def get_model(self, pretrain_model=None): 153 | model = MemTransformerLM(self.modelConfig, is_training=self.is_training) 154 | 155 | st_epoch = 0 156 | if pretrain_model: 157 | checkpoint = torch.load(pretrain_model, map_location='cuda:0') 158 | print('Pretrained model config:') 159 | print('epoch: ', checkpoint['epoch']) 160 | print('best_loss: ', checkpoint['best_loss']) 161 | print(json.dumps(checkpoint['model_setting'], indent=1, sort_keys=True)) 162 | print(json.dumps(checkpoint['train_setting'], indent=1, sort_keys=True)) 163 | 164 | try: 165 | model.load_state_dict(checkpoint['state_dict']) 166 | print('{} loaded.'.format(pretrain_model)) 167 | except: 168 | print('Loaded weights have different shapes from the model. 
Please check your model setting.') 169 | exit() 170 | st_epoch = checkpoint['epoch'] + 1 171 | 172 | else: 173 | model.apply(self.weights_init) 174 | model.word_emb.apply(self.weights_init) 175 | return st_epoch, model.to(self.device) 176 | 177 | 178 | def save_checkpoint(self, state, root, save_freq=10): 179 | if state['epoch'] % save_freq == 0: 180 | torch.save(state, os.path.join(root, 'ep_{}.pth.tar'.format(state['epoch']))) 181 | 182 | def train_loss_record(self, epoch, train_loss, checkpoint_dir, val_loss=None): 183 | 184 | if val_loss: 185 | df = pd.DataFrame({'epoch': [epoch+1], 186 | 'train_loss': ['%.3f'%train_loss], 187 | 'val_loss': ['%.3f'%val_loss]}) 188 | 189 | else: 190 | df = pd.DataFrame({'epoch': [epoch+1], 191 | 'train_loss': ['%.3f'%train_loss]}) 192 | 193 | csv_file = os.path.join(checkpoint_dir, 'loss.csv') 194 | 195 | if not os.path.exists(csv_file): 196 | df.to_csv(csv_file, index=False) 197 | else: 198 | df.to_csv(os.path.join(checkpoint_dir, 'loss.csv'), mode='a', header=False, index=False) 199 | 200 | def train(self, train_data, trainConfig, device, resume): 201 | checkpoint_dir = trainConfig['experiment_Dir'] 202 | batch_size = trainConfig['batch_size'] 203 | data_ROOT = trainConfig['ROOT'] 204 | torch.manual_seed(trainConfig["seed"]) 205 | 206 | # create saver 207 | saver_agent = saver.Saver(checkpoint_dir) 208 | 209 | # Prepare model 210 | if resume != 'None': 211 | st_epoch, model = self.get_model(resume) 212 | print('Continue training from epoch {}'.format(st_epoch)) 213 | else: 214 | st_epoch, model = self.get_model() 215 | 216 | optimizer = optim.Adam(model.parameters(), lr=trainConfig['lr']) 217 | train_step = 0 218 | epoch_train_loss = [] 219 | save_freq = trainConfig['save_freq'] 220 | 221 | n_parameters = network_paras(model) 222 | print('n_parameters: {:,}'.format(n_parameters)) 223 | saver_agent.add_summary_msg( 224 | ' > params amount: {:,d}'.format(n_parameters)) 225 | 226 | # unpack 227 | train_x = train_data['x'] 228 | train_y = train_data['y'] 229 | mask = train_data['mask'] 230 | num_groups = train_data['num_groups'] 231 | 232 | num_batches = len(train_x) // batch_size 233 | 234 | print('>>> Start training') 235 | for epoch in range(st_epoch, trainConfig['num_epochs']): 236 | saver_agent.global_step_increment() 237 | 238 | train_loss = [] 239 | st_time = time.time() 240 | model.train() 241 | 242 | for bidx in range(num_batches): 243 | 244 | model.zero_grad() 245 | 246 | # index 247 | bidx_st = batch_size * bidx 248 | bidx_ed = batch_size * (bidx + 1) 249 | 250 | # get batch 251 | batch_x = train_x[bidx_st:bidx_ed] 252 | batch_y = train_y[bidx_st:bidx_ed] 253 | batch_mask = mask[bidx_st:bidx_ed] 254 | n_group = np.max(num_groups[bidx_st:bidx_ed]) 255 | 256 | # proc groups 257 | mems = tuple() 258 | for gidx in range(n_group): 259 | group_x = batch_x[:, gidx, :] 260 | group_y = batch_y[:, gidx, :] 261 | group_mask = batch_mask[:, gidx, :] 262 | 263 | group_x = torch.from_numpy(group_x).permute(1, 0).contiguous().to(self.device).long() # (seq_len, bsz) 264 | group_y = torch.from_numpy(group_y).permute(1, 0).contiguous().to(self.device).long() 265 | group_mask = torch.from_numpy(group_mask).to(self.device).float() 266 | 267 | ret = model(group_x, group_y, group_mask, *mems) 268 | loss, mems = ret[0], ret[1:] 269 | train_loss.append(loss.item()) 270 | loss.backward() 271 | 272 | sys.stdout.write('epoch:{:3d}/{:3d}, batch: {:4d}/{:4d}, group: {:2d}/{:2d} | Loss: {:6f}\r'.format( 273 | epoch, 274 | trainConfig['num_epochs'], 275 | bidx, 276 | 
num_batches, 277 | gidx, 278 | n_group, 279 | loss.item() 280 | )) 281 | sys.stdout.flush() 282 | 283 | optimizer.step() 284 | 285 | #val_loss = self.validate(val_data, batch_size, model, trainConfig["seed"], trainConfig['max_eval_steps']) 286 | curr_train_loss = sum(train_loss) / len(train_loss) 287 | saver_agent.add_summary('epoch loss', curr_train_loss) 288 | 289 | #epoch_val_loss.append(val_loss) 290 | epoch_train_loss.append(curr_train_loss) 291 | # epoch_info = 'Train Loss: {:.5f} , Val Loss: {:.5f}, T: {:.3f}'.format(curr_train_loss, val_loss, time.time()-st_time) 292 | epoch_info = 'Epoch: {}, Train Loss: {:.5f} , T: {:.3f}'.format(epoch+1, curr_train_loss, time.time()-st_time) 293 | print(epoch_info) 294 | 295 | # self.train_loss_record(epoch, curr_train_loss, checkpoint_dir, val_loss) 296 | self.train_loss_record(epoch, curr_train_loss, checkpoint_dir) 297 | self.save_checkpoint({ 298 | 'epoch': epoch + 1, 299 | 'model_setting': self.modelConfig, 300 | 'train_setting': trainConfig, 301 | 'state_dict': model.state_dict(), 302 | 'best_loss': curr_train_loss, 303 | 'optimizer' : optimizer.state_dict(), 304 | }, 305 | checkpoint_dir, 306 | save_freq) 307 | 308 | if curr_train_loss < 0.01: 309 | print('Experiment [{}] finished at loss < 0.01.'.format(checkpoint_dir)) 310 | break 311 | 312 | def inference(self, model_path, token_lim, strategies, params, bpm, output_path): 313 | _, model = self.get_model(model_path) 314 | model.eval() 315 | 316 | # initial start 317 | words = [[]] 318 | 319 | # seed the sequence with an initial 'Bar_None' token 320 | words[-1].append(self.event2word['Bar_None']) 321 | 322 | # initialize mem 323 | mems = tuple() 324 | song_init_time = time.time() 325 | # generate 326 | initial_flag = True 327 | generate_n_bar = 0 328 | batch_size = 1 329 | n_tokens = len(words[0]) 330 | while len(words[0]) < token_lim: 331 | # prepare input 332 | if initial_flag: 333 | temp_x = np.zeros((len(words[0]), batch_size)) 334 | 335 | for b in range(batch_size): 336 | for z, t in enumerate(words[b]): 337 | temp_x[z][b] = t 338 | 339 | initial_flag = False 340 | else: 341 | temp_x = np.zeros((1, batch_size)) 342 | 343 | for b in range(batch_size): 344 | temp_x[0][b] = words[b][-1] ####?#### 345 | 346 | temp_x = torch.from_numpy(temp_x).long().to(self.device) 347 | st_time = time.time() 348 | 349 | _logits, mems = model.generate(temp_x, *mems) 350 | logits = _logits.cpu().squeeze().detach().numpy() 351 | 352 | # temperature or not 353 | if 'temperature' in strategies: 354 | probs = self.temperature(logits=logits, temperature=params['t']) 355 | 356 | else: 357 | probs = self.temperature(logits=logits, temperature=1.) 
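# Per-token sampling pipeline: temperature first re-shapes the distribution
# (softmax of logits / t; the t=1.2 passed in from inference.py flattens it slightly
# for extra diversity), then nucleus() below truncates to the smallest candidate set
# whose cumulative probability exceeds p (0.9 here) and renormalizes before drawing,
# so the low-probability tail is never sampled. Both helpers are defined at the
# bottom of this class.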
358 | # sampling 359 | word = self.nucleus(probs=probs, p=params['p']) 360 | words[0].append(word) 361 | 362 | print(len(words[0]), self.word2event[word]) 363 | # record n_bar 364 | if word == self.event2word['Bar_None']: 365 | generate_n_bar += 1 366 | 367 | 368 | write_midi(words[0], output_path, self.word2event) 369 | 370 | song_total_time = time.time() - song_init_time 371 | print('Total words generated: ', len(words[0])) 372 | return song_total_time, len(words[0]) 373 | 374 | ######################################## 375 | # search strategy: temperature (re-shape) 376 | ######################################## 377 | def temperature(self, logits, temperature): 378 | probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature)) 379 | return probs 380 | 381 | ######################################## 382 | # search strategy: topk (truncate) 383 | ######################################## 384 | def topk(self, probs, k): 385 | sorted_index = np.argsort(probs)[::-1] 386 | candi_index = sorted_index[:k] 387 | candi_probs = [probs[i] for i in candi_index] 388 | candi_probs /= sum(candi_probs) 389 | word = np.random.choice(candi_index, size=1, p=candi_probs)[0] 390 | return word 391 | 392 | ######################################## 393 | # search strategy: nucleus (truncate) 394 | ######################################## 395 | def nucleus(self, probs, p): 396 | probs /= sum(probs) 397 | sorted_probs = np.sort(probs)[::-1] 398 | sorted_index = np.argsort(probs)[::-1] 399 | cusum_sorted_probs = np.cumsum(sorted_probs) 400 | after_threshold = cusum_sorted_probs > p 401 | if sum(after_threshold) > 0: 402 | last_index = np.where(after_threshold)[0][0] + 1 403 | candi_index = sorted_index[:last_index] 404 | else: 405 | candi_index = sorted_index[:3] # just assign a value 406 | candi_probs = [probs[i] for i in candi_index] 407 | candi_probs /= sum(candi_probs) 408 | word = np.random.choice(candi_index, size=1, p=candi_probs)[0] 409 | return word -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/modules.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import functools 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class PositionalEmbedding(nn.Module): 12 | def __init__(self, demb): 13 | super(PositionalEmbedding, self).__init__() 14 | 15 | self.demb = demb 16 | 17 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 18 | self.register_buffer('inv_freq', inv_freq) 19 | 20 | def forward(self, pos_seq, bsz=None): 21 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 22 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 23 | 24 | if bsz is not None: 25 | return pos_emb[:,None,:].expand(-1, bsz, -1) 26 | else: 27 | return pos_emb[:,None,:] 28 | 29 | 30 | class PositionwiseFF(nn.Module): 31 | def __init__(self, d_model, d_inner, dropout, pre_lnorm=False): 32 | super(PositionwiseFF, self).__init__() 33 | 34 | self.d_model = d_model 35 | self.d_inner = d_inner 36 | self.dropout = dropout 37 | 38 | self.CoreNet = nn.Sequential( 39 | nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), 40 | nn.Dropout(dropout), 41 | nn.Linear(d_inner, d_model), 42 | nn.Dropout(dropout), 43 | ) 44 | 45 | self.layer_norm = nn.LayerNorm(d_model) 46 | 47 | self.pre_lnorm = pre_lnorm 48 | 49 | def forward(self, inp): 50 | if self.pre_lnorm: 51 | ##### layer normalization + positionwise 
/workspace/uncond/remi-xl/modules.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import math
3 | import functools
4 | 
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | 
10 | 
11 | class PositionalEmbedding(nn.Module):
12 |     def __init__(self, demb):
13 |         super(PositionalEmbedding, self).__init__()
14 | 
15 |         self.demb = demb
16 | 
17 |         inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
18 |         self.register_buffer('inv_freq', inv_freq)
19 | 
20 |     def forward(self, pos_seq, bsz=None):
21 |         sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
22 |         pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
23 | 
24 |         if bsz is not None:
25 |             return pos_emb[:,None,:].expand(-1, bsz, -1)
26 |         else:
27 |             return pos_emb[:,None,:]
28 | 
29 | 
30 | class PositionwiseFF(nn.Module):
31 |     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
32 |         super(PositionwiseFF, self).__init__()
33 | 
34 |         self.d_model = d_model
35 |         self.d_inner = d_inner
36 |         self.dropout = dropout
37 | 
38 |         self.CoreNet = nn.Sequential(
39 |             nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
40 |             nn.Dropout(dropout),
41 |             nn.Linear(d_inner, d_model),
42 |             nn.Dropout(dropout),
43 |         )
44 | 
45 |         self.layer_norm = nn.LayerNorm(d_model)
46 | 
47 |         self.pre_lnorm = pre_lnorm
48 | 
49 |     def forward(self, inp):
50 |         if self.pre_lnorm:
51 |             ##### layer normalization + positionwise feed-forward
52 |             core_out = self.CoreNet(self.layer_norm(inp))
53 | 
54 |             ##### residual connection
55 |             output = core_out + inp
56 |         else:
57 |             ##### positionwise feed-forward
58 |             core_out = self.CoreNet(inp)
59 | 
60 |             ##### residual connection + layer normalization
61 |             output = self.layer_norm(inp + core_out)
62 | 
63 |         return output
64 | 
65 | 
66 | class RelMultiHeadAttn(nn.Module):
67 |     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
68 |                  tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
69 |         super(RelMultiHeadAttn, self).__init__()
70 | 
71 |         self.n_head = n_head
72 |         self.d_model = d_model
73 |         self.d_head = d_head
74 |         self.dropout = dropout
75 | 
76 |         self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
77 | 
78 |         self.drop = nn.Dropout(dropout)
79 |         self.dropatt = nn.Dropout(dropatt)
80 |         self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
81 | 
82 |         self.layer_norm = nn.LayerNorm(d_model)
83 | 
84 |         self.scale = 1 / (d_head ** 0.5)
85 | 
86 |         self.pre_lnorm = pre_lnorm
87 | 
88 |     def _parallelogram_mask(self, h, w, left=False):
89 |         mask = torch.ones((h, w)).byte()
90 |         m = min(h, w)
91 |         mask[:m,:m] = torch.triu(mask[:m,:m])
92 |         mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
93 | 
94 |         if left:
95 |             return mask
96 |         else:
97 |             return mask.flip(0)
98 | 
99 |     def _shift(self, x, qlen, klen, mask, left=False):
100 |         if qlen > 1:
101 |             zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
102 |                                    device=x.device, dtype=x.dtype)
103 |         else:
104 |             zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
105 | 
106 |         if left:
107 |             mask = mask.flip(1)
108 |             x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
109 |         else:
110 |             x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
111 | 
112 |         x = x_padded.masked_select(mask[:,:,None,None]) \
113 |             .view(qlen, klen, x.size(2), x.size(3))
114 | 
115 |         return x
116 | 
117 |     def _rel_shift(self, x, zero_triu=False):
118 |         zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
119 |                                device=x.device, dtype=x.dtype)
120 |         x_padded = torch.cat([zero_pad, x], dim=1)
121 | 
122 |         x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
123 | 
124 |         x = x_padded[1:].view_as(x)
125 | 
126 |         if zero_triu:
127 |             ones = torch.ones((x.size(0), x.size(1)))
128 |             x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
129 | 
130 |         return x
131 | 
132 |     def forward(self, w, r, attn_mask=None, mems=None):
133 |         raise NotImplementedError
134 | 
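`_rel_shift` above is the Transformer-XL indexing trick: pad one zero column onto the (qlen x klen) score grid, reinterpret the buffer with the two leading axes swapped, and drop the first row; each query row ends up shifted by a different offset so its columns can be read as relative distances rather than absolute positions. A toy check of the reshape, with made-up values just to make the shift visible:

```python
import torch

x = torch.arange(12.).view(3, 4, 1, 1)              # (qlen=3, klen=4) toy scores
zero_pad = torch.zeros(3, 1, 1, 1)
x_padded = torch.cat([zero_pad, x], dim=1)          # (3, 5, 1, 1)
shifted = x_padded.view(5, 3, 1, 1)[1:].view_as(x)  # same steps as _rel_shift
print(x[:, :, 0, 0])        # rows [0..3], [4..7], [8..11]
print(shifted[:, :, 0, 0])  # each row shifted left by (qlen - 1 - row)
```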
135 | class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
136 |     def __init__(self, *args, **kwargs):
137 |         super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
138 | 
139 |         self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
140 | 
141 |     def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
142 |         qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
143 | 
144 |         if mems is not None:
145 |             # prepend the cached hidden states (mems: mlen x bsz x d_model)
146 |             # to the current segment (w: qlen x bsz x d_model)
147 |             cat = torch.cat([mems, w], 0)
148 |             if self.pre_lnorm:
149 |                 w_heads = self.qkv_net(self.layer_norm(cat))
150 |             else:
151 |                 w_heads = self.qkv_net(cat)
152 |             r_head_k = self.r_net(r)
153 | 
154 |             w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
155 |             w_head_q = w_head_q[-qlen:]
156 |         else:
157 |             if self.pre_lnorm:
158 |                 w_heads = self.qkv_net(self.layer_norm(w))
159 |             else:
160 |                 w_heads = self.qkv_net(w)
161 |             r_head_k = self.r_net(r)
162 | 
163 |             w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
164 | 
165 |         klen = w_head_k.size(0)
166 | 
167 |         w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
168 |         w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
169 |         w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
170 | 
171 |         r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)  # rlen x n_head x d_head
172 | 
173 |         #### compute attention score
174 |         rw_head_q = w_head_q + r_w_bias
175 |         AC = rw_head_q.permute(1, 2, 0, 3) @ w_head_k.permute(1, 2, 3, 0)
176 | 
177 |         rr_head_q = w_head_q + r_r_bias
178 |         BD = rr_head_q.permute(1, 2, 0, 3) @ r_head_k.permute(1, 2, 0)
179 |         BD = F.pad(BD, [1, 0]).view(BD.size(0), BD.size(
180 |             1), BD.size(3) + 1, BD.size(2))[:, :, 1:].view_as(BD)  # inline relative shift
181 | 
182 |         # [bsz x n_head x qlen x klen]
183 |         attn_score = AC + BD
184 |         attn_score.mul_(self.scale)
185 | 
186 |         #### compute attention probability
187 |         if attn_mask is not None and attn_mask.any().item():
188 |             if attn_mask.dim() == 2:
189 |                 attn_score = attn_score.float().masked_fill(
190 |                     attn_mask, -float('inf')).type_as(attn_score)
191 |             elif attn_mask.dim() == 3:
192 |                 attn_score = attn_score.float().masked_fill(
193 |                     attn_mask.permute(2, 0, 1)[:, None, :, :], -float('inf')).type_as(attn_score)
194 | 
195 |         # [bsz x n_head x qlen x klen]
196 |         attn_prob = F.softmax(attn_score, dim=-1)
197 |         attn_prob = self.dropatt(attn_prob)
198 | 
199 |         #### compute attention vector
200 |         attn_vec = attn_prob @ w_head_v.permute(1, 2, 0, 3)
201 |         attn_vec = attn_vec.permute(2, 0, 1, 3)
202 | 
203 |         # [qlen x bsz x n_head x d_head]
204 |         attn_vec = attn_vec.contiguous().view(
205 |             attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
206 | 
207 |         ##### linear projection
208 |         attn_out = self.o_net(attn_vec)
209 |         attn_out = self.drop(attn_out)
210 | 
211 |         if self.pre_lnorm:
212 |             ##### residual connection
213 |             output = w + attn_out
214 |         else:
215 |             ##### residual connection + layer normalization
216 |             output = self.layer_norm(w + attn_out)
217 | 
218 |         return output
219 | 
220 | 
221 | class RelPartialLearnableDecoderLayer(nn.Module):
222 |     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
223 |                  **kwargs):
224 |         super(RelPartialLearnableDecoderLayer, self).__init__()
225 | 
226 |         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
227 |                                                          d_head, dropout, **kwargs)
228 |         self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
229 |                                      pre_lnorm=kwargs.get('pre_lnorm'))
230 | 
231 |     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
232 | 
233 |         output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
234 |                                attn_mask=dec_attn_mask,
235 |                                mems=mems)
236 |         output = self.pos_ff(output)
237 | 
238 |         return output
239 | 
240 | 
241 | class Embeddings(nn.Module):
242 |     def __init__(self, n_token, d_model):
243 |         super(Embeddings, self).__init__()
244 |         self.lut = nn.Embedding(n_token, d_model)
245 |         self.d_model = d_model
246 | 
247 |     def forward(self, x):
248 |         return self.lut(x) * math.sqrt(self.d_model)
249 | 
250 | 
251 | 
252 | class MemTransformerLM(nn.Module):
253 |     def __init__(self, modelConfig,
254 |                  tie_projs=[False], cutoffs=[],
255 |                  is_training=True):
256 |         super(MemTransformerLM, self).__init__()
257 | 
258 |         self.n_token = modelConfig['n_token']
259 |         self.n_layer = modelConfig['n_layer']
260 |         self.n_head = modelConfig['n_head']
261 |         self.d_model = modelConfig['d_model']
262 |         self.d_embed = self.d_model if modelConfig['d_embed'] is None else
modelConfig['d_embed']
263 |         self.d_head = self.d_model // self.n_head
264 |         self.d_inner = modelConfig['d_inner']
265 | 
266 |         self.mem_len = modelConfig['mem_len']
267 |         self.tgt_len = modelConfig['tgt_len']
268 |         self.ext_len = modelConfig['ext_len']
269 |         self.max_klen = self.tgt_len + self.ext_len + self.mem_len  # e.g. 70 + 0 + 512
270 | 
271 |         self.dropout = modelConfig['dropout']
272 |         self.dropatt = modelConfig['dropatt']
273 | 
274 |         self.clamp_len = modelConfig['clamp_len']
275 |         self.div_val = modelConfig['div_val']
276 | 
277 |         # choices
278 |         self.pre_lnorm = modelConfig['pre_lnorm']
279 |         self.same_length = modelConfig['same_length']
280 |         self.is_training = is_training
281 | 
282 |         # building layers
283 |         self.drop = nn.Dropout(self.dropout)
284 |         self.word_emb = Embeddings(self.n_token, self.d_model)
285 | 
286 |         self.layers = nn.ModuleList()
287 |         for i in range(self.n_layer):
288 |             self.layers.append(
289 |                 RelPartialLearnableDecoderLayer(
290 |                     self.n_head, self.d_model, self.d_head, self.d_inner, self.dropout,
291 |                     tgt_len=self.tgt_len, ext_len=self.ext_len, mem_len=self.mem_len,
292 |                     dropatt=self.dropatt, pre_lnorm=self.pre_lnorm)
293 |             )
294 | 
295 |         # output layer
296 |         self.linear_proj = nn.Linear(self.d_model, self.n_token)
297 | 
298 |         # loss
299 |         self.loss_func = nn.CrossEntropyLoss(reduction='none')
300 |         self._create_params()
301 | 
302 |     def compute_loss(self, predict, target, loss_mask=None):
303 |         '''
304 |         masked cross entropy:
305 |         predict: (N, C, ...)
306 |         target:  (N, ...)
307 |         '''
308 |         loss = self.loss_func(predict, target)
309 |         loss = loss * loss_mask
310 |         loss = torch.sum(loss) / torch.sum(loss_mask)
311 |         return loss
312 | 
313 |     def _create_params(self):
314 |         self.pos_emb = PositionalEmbedding(self.d_model)
315 |         self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
316 |         self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
317 | 
318 |     def reset_length(self, tgt_len, ext_len, mem_len):
319 |         self.tgt_len = tgt_len
320 |         self.mem_len = mem_len
321 |         self.ext_len = ext_len
322 | 
323 |     def init_mems(self):
324 |         if self.mem_len > 0:
325 |             mems = []
326 |             param = next(self.parameters())
327 |             for i in range(self.n_layer+1):
328 |                 empty = torch.empty(0, dtype=param.dtype, device=param.device)
329 |                 mems.append(empty)
330 |             return mems
331 |         else:
332 |             return None
333 | 
334 |     def _update_mems(self, hids, mems, mlen, qlen):
335 | 
336 |         if mems is None: return None
337 |         # mems is not None
338 |         # assert len(hids) == len(mems), 'len(hids) != len(mems)'
339 | 
340 |         # There are `mlen + qlen` steps that can be cached into mems
341 |         # For the next step, the last `ext_len` of the `qlen` tokens
342 |         # will be used as the extended context. Hence, we only cache
343 |         # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
344 |         # to `mlen + qlen - self.ext_len`.
345 | with torch.no_grad(): 346 | new_mems = [] 347 | end_idx = mlen + max(0, qlen - 0 - self.ext_len) 348 | beg_idx = max(0, end_idx - self.mem_len) 349 | 350 | for i in range(len(hids)): 351 | cat = torch.cat([mems[i], hids[i]], dim=0) 352 | new_mems.append(cat[beg_idx:end_idx].detach()) 353 | 354 | return new_mems 355 | 356 | 357 | 358 | def _forward(self, dec_inp, mems=None): 359 | ''' 360 | output of _forward: step x batch x n_feat 361 | predict = self.linear_proj(hidden) 362 | ''' 363 | 364 | qlen, bsz = dec_inp.size() 365 | mlen = mems[0].size(0) if mems is not None else 0 366 | klen = mlen + qlen 367 | 368 | word_emb = self.word_emb(dec_inp) 369 | 370 | if self.same_length: 371 | all_ones = word_emb.new_ones(qlen, klen) 372 | mask_len = klen - self.mem_len 373 | 374 | if mask_len > 0: 375 | mask_shift_len = qlen - mask_len 376 | else: 377 | mask_shift_len = qlen 378 | dec_attn_mask = (torch.triu(all_ones, 1+mlen) 379 | + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1 380 | else: 381 | dec_attn_mask = torch.triu( 382 | word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None] 383 | 384 | 385 | hids = [] 386 | pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, 387 | dtype=word_emb.dtype) 388 | if self.clamp_len > 0: 389 | pos_seq.clamp_(max=self.clamp_len) 390 | pos_emb = self.pos_emb(pos_seq) 391 | core_out = self.drop(word_emb) 392 | pos_emb = self.drop(pos_emb) 393 | hids.append(core_out) 394 | 395 | for i, layer in enumerate(self.layers): 396 | mems_i = None if mems is None else mems[i] 397 | 398 | core_out = layer(core_out, pos_emb, self.r_w_bias, 399 | self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) 400 | hids.append(core_out) 401 | 402 | core_out = self.drop(core_out) 403 | new_mems = self._update_mems(hids, mems, mlen, qlen) 404 | 405 | return core_out, new_mems 406 | 407 | def generate(self, data, *mems): 408 | if not mems: mems = self.init_mems() 409 | hidden, new_mems = self._forward(data, mems=mems) 410 | predict = self.linear_proj(hidden[-1:]) 411 | return predict, new_mems 412 | 413 | def forward(self, data, target, mask, *mems): 414 | if not mems: mems = self.init_mems() 415 | 416 | tgt_len = target.size(0) 417 | hidden, new_mems = self._forward(data, mems=mems) 418 | 419 | pred_hid = hidden[-tgt_len:] 420 | predict = self.linear_proj(pred_hid) 421 | 422 | predict = predict.permute(1, 2, 0) 423 | target = target.permute(1, 0) 424 | 425 | loss = self.compute_loss(predict, target, mask) 426 | 427 | if new_mems is None: 428 | return [loss] 429 | else: 430 | return [loss] + new_mems 431 | -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/runtime_stats.json: -------------------------------------------------------------------------------- 1 | {"song_time": [400.5098271369934, 377.5159821510315, 141.5237319469452, 141.3457088470459, 140.7412805557251, 146.80897665023804, 149.34970378875732, 180.57227396965027, 187.8685564994812, 190.16119360923767, 186.2575488090515, 184.2108030319214, 182.1258454322815, 179.762943983078, 185.42160320281982, 185.23676824569702, 188.1896812915802, 191.02102065086365, 186.58602023124695, 179.87972950935364], "words_len_list": [7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680], "ave token time:": 39.33328847340423, "ave song time": 195.25445997714996} -------------------------------------------------------------------------------- /workspace/uncond/remi-xl/saver.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import logging 5 | import datetime 6 | import collections 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | class Saver(object): 12 | def __init__( 13 | self, 14 | exp_dir, 15 | mode='w'): 16 | 17 | self.exp_dir = exp_dir 18 | self.init_time = time.time() 19 | self.global_step = 0 20 | 21 | # makedirs 22 | os.makedirs(exp_dir, exist_ok=True) 23 | 24 | # logging config 25 | path_logger = os.path.join(exp_dir, 'log.txt') 26 | logging.basicConfig( 27 | level=logging.DEBUG, 28 | format='%(message)s', 29 | filename=path_logger, 30 | filemode=mode) 31 | self.logger = logging.getLogger('training monitor') 32 | 33 | def add_summary_msg(self, msg): 34 | self.logger.debug(msg) 35 | 36 | def add_summary( 37 | self, 38 | key, 39 | val, 40 | step=None, 41 | cur_time=None): 42 | 43 | if cur_time is None: 44 | cur_time = time.time() - self.init_time 45 | if step is None: 46 | step = self.global_step 47 | 48 | # write msg (key, val, step, time) 49 | if isinstance(val, float): 50 | msg_str = '{:10s} | {:.10f} | {:10d} | {}'.format( 51 | key, 52 | val, 53 | step, 54 | cur_time 55 | ) 56 | else: 57 | msg_str = '{:10s} | {} | {:10d} | {}'.format( 58 | key, 59 | val, 60 | step, 61 | cur_time 62 | ) 63 | 64 | self.logger.debug(msg_str) 65 | 66 | def save_model( 67 | self, 68 | model, 69 | optimizer=None, 70 | outdir=None, 71 | name='model'): 72 | 73 | if outdir is None: 74 | outdir = self.exp_dir 75 | print(' [*] saving model to {}, name: {}'.format(outdir, name)) 76 | torch.save(model, os.path.join(outdir, name+'.pt')) 77 | torch.save(model.state_dict(), os.path.join(outdir, name+'_params.pt')) 78 | 79 | if optimizer is not None: 80 | torch.save(optimizer.state_dict(), os.path.join(outdir, name+'_opt.pt')) 81 | 82 | def load_model( 83 | self, 84 | path_exp, 85 | device='cpu', 86 | name='model.pt'): 87 | 88 | path_pt = os.path.join(path_exp, name) 89 | print(' [*] restoring model from', path_pt) 90 | model = torch.load(path_pt, map_location=torch.device(device)) 91 | return model 92 | 93 | def global_step_increment(self): 94 | self.global_step += 1 95 | 96 | """ 97 | file modes 98 | 'a': 99 | Opens a file for appending. The file pointer is at the end of the file if the file exists. 100 | That is, the file is in the append mode. If the file does not exist, it creates a new file for writing. 101 | 102 | 'w': 103 | Opens a file for writing only. Overwrites the file if the file exists. 104 | If the file does not exist, creates a new file for writing. 
105 | """
106 | 
107 | def make_loss_report(
108 |         path_logfile,
109 |         path_figure='loss.png',
110 |         dpi=100):
111 | 
112 |     # load logfile
113 |     monitor_vals = collections.defaultdict(list)
114 |     with open(path_logfile, 'r') as f:
115 |         for line in f:
116 |             try:
117 |                 line = line.strip()
118 |                 key, val, step, acc_time = line.split(' | ')
119 |                 monitor_vals[key].append((float(val), int(step), acc_time))
120 |             except:
121 |                 continue
122 | 
123 |     # collect
124 |     step_train = [item[1] for item in monitor_vals['train loss']]
125 |     vals_train = [item[0] for item in monitor_vals['train loss']]
126 | 
127 |     step_valid = [item[1] for item in monitor_vals['valid loss']]
128 |     vals_valid = [item[0] for item in monitor_vals['valid loss']]
129 | 
130 |     x_min = step_valid[np.argmin(vals_valid)]
131 |     y_min = min(vals_valid)
132 | 
133 |     # plot
134 |     fig = plt.figure(dpi=dpi)
135 |     plt.title('training process')
136 |     plt.plot(step_train, vals_train, label='train')
137 |     plt.plot(step_valid, vals_valid, label='valid')
138 |     plt.yscale('log')
139 |     plt.plot([x_min], [y_min], 'ro')
140 |     plt.legend(loc='upper right')
141 |     plt.tight_layout()
142 |     plt.savefig(path_figure)
143 | 
144 | # author: wayn391@mastertones
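Typical use of `Saver` and `make_loss_report` during an experiment might look like the following; the experiment directory and the dummy model are placeholders, and `make_loss_report` assumes both 'train loss' and 'valid loss' entries exist in the log:

```python
import torch
import torch.nn as nn

saver = Saver('./exp/demo')          # creates ./exp/demo and ./exp/demo/log.txt
toy_model = nn.Linear(4, 2)          # stand-in for the real network
opt = torch.optim.Adam(toy_model.parameters())

for step in range(3):
    saver.add_summary('train loss', 1.0 / (step + 1))
    saver.add_summary('valid loss', 1.2 / (step + 1))
    saver.global_step_increment()

saver.save_model(toy_model, optimizer=opt, name='model')
make_loss_report('./exp/demo/log.txt', path_figure='./exp/demo/loss.png')
```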
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/songTime.csv:
--------------------------------------------------------------------------------
1 | ,song_name,song_time,num_word,ave_genTime_per_word
2 | 0,60_673cd223aa2c317002ac4c975a9c67df,70.56650042533875,3940,0.01676041873421256
3 | 1,60_c0793ed178d945ddd18bfb7cc7b190b8,35.46065306663513,2001,0.016596261803452316
4 | 
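`songTime.csv` records the wall-clock cost of each generated song, with an average generation time per word in the last column. Aggregating it needs only the standard library:

```python
import csv

with open('songTime.csv') as f:
    rows = list(csv.DictReader(f))

total_time = sum(float(r['song_time']) for r in rows)
total_words = sum(int(r['num_word']) for r in rows)
print('songs:', len(rows))
print('avg seconds per word: {:.6f}'.format(total_time / total_words))
```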
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/train.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import json
4 | import yaml
5 | import pickle
6 | import datetime
7 | import numpy as np
8 | from collections import OrderedDict
9 | 
10 | import torch
11 | from model import TransformerXL
12 | 
13 | 
14 | def main():
15 |     # gen config
16 |     modelConfig, trainConfig = get_configs()
17 | 
18 |     # load dictionary
19 |     event2word, word2event = pickle.load(open(os.path.join(trainConfig['ROOT'], 'dictionary.pkl'), 'rb'))
20 | 
21 |     # load train data
22 |     training_data = np.load(os.path.join(trainConfig['ROOT'], 'train_data_XL.npz'))
23 | 
24 |     # mask the visible GPUs before any CUDA call; after masking, the remaining device is cuda:0
25 |     os.environ['CUDA_VISIBLE_DEVICES'] = str(trainConfig['gpuID'])
26 |     device = torch.device("cuda" if not trainConfig["no_cuda"] and torch.cuda.is_available() else "cpu")
27 | 
28 |     print('Device to train:', device)
29 | 
30 |     resume = trainConfig['resume_training_model']
31 | 
32 |     # declare model
33 |     model = TransformerXL(
34 |         modelConfig,
35 |         device,
36 |         event2word=event2word,
37 |         word2event=word2event,
38 |         is_training=True)
39 | 
40 |     # train
41 |     model.train(training_data,
42 |                 trainConfig,
43 |                 device,
44 |                 resume)
45 | 
46 | 
47 | def get_configs():
48 |     with open("config.yml", 'r') as f:
49 |         cfg = yaml.full_load(f)
50 | 
51 |     modelConfig = cfg['MODEL']
52 |     trainConfig = cfg['TRAIN']
53 | 
54 |     cur_date = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
55 |     experiment_Dir = os.path.join(trainConfig['output_dir'], cur_date)
56 |     if not os.path.exists(experiment_Dir):
57 |         print('experiment_Dir:', experiment_Dir)
58 |         os.makedirs(experiment_Dir)
59 |     print('Experiment: ', experiment_Dir)
60 |     trainConfig.update({'experiment_Dir': experiment_Dir})
61 | 
62 |     # snapshot the config used for this run
63 |     with open(os.path.join(experiment_Dir, 'config.yml'), 'w') as f:
64 |         yaml.dump(cfg, f)
65 | 
66 |     print('='*5, 'Model configs', '='*5)
67 |     print(json.dumps(modelConfig, indent=1, sort_keys=True))
68 |     print('='*5, 'Training configs', '='*5)
69 |     print(json.dumps(trainConfig, indent=1, sort_keys=True))
70 |     return modelConfig, trainConfig
71 | 
72 | 
73 | if __name__ == '__main__':
74 |     main()
--------------------------------------------------------------------------------
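Finally, mirroring `train.py`, a generation run only needs the dictionary, a trained checkpoint, and the `inference` call shown in model.py. A hedged sketch: the experiment path and checkpoint filename are illustrative placeholders (not fixed by this repo), while the constructor and `inference` signatures and the checkpoint keys follow the code above:

```python
import os
import pickle
import torch
from model import TransformerXL

root = './exp/20210101-000000'                       # placeholder experiment dir
ckpt_path = os.path.join(root, 'ep_100.pth.tar')     # hypothetical checkpoint name
event2word, word2event = pickle.load(open('dictionary.pkl', 'rb'))

ckpt = torch.load(ckpt_path, map_location='cpu')     # keys written by save_checkpoint above
model = TransformerXL(ckpt['model_setting'], torch.device('cpu'),
                      event2word=event2word, word2event=word2event,
                      is_training=False)

song_time, n_words = model.inference(
    model_path=ckpt_path,
    token_lim=7680,                                  # matches runtime_stats.json above
    strategies=['temperature', 'nucleus'],
    params={'t': 1.2, 'p': 0.9}, bpm=120,            # illustrative sampling values
    output_path='gen.mid')
print('generated {} tokens in {:.1f}s'.format(n_words, song_time))
```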