├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── data_proc_diagram.png
├── dataset
│   ├── Dataset.md
│   ├── analyzer.py
│   ├── corpus
│   │   ├── .DS_Store
│   │   └── keep
│   ├── midi2corpus.py
│   ├── midi_analyzed
│   │   ├── .DS_Store
│   │   └── keep
│   ├── midi_synchronized
│   │   ├── .DS_Store
│   │   └── keep
│   ├── midi_transcribed
│   │   ├── .DS_Store
│   │   └── keep
│   ├── representations
│   │   ├── cond-ls2midi
│   │   │   └── keep
│   │   └── uncond
│   │       ├── cp
│   │       │   ├── .DS_Store
│   │       │   ├── compile.py
│   │       │   ├── corpus2events.py
│   │       │   └── events2words.py
│   │       ├── remi
│   │       │   ├── .DS_Store
│   │       │   ├── compile.py
│   │       │   ├── corpus2events.py
│   │       │   ├── events2words.py
│   │       │   └── valid_fn_idx_map.json
│   │       └── validation_songs.json
│   └── synchronizer.py
├── docs
│   └── aaai21-slides.pdf
└── workspace
    ├── cond_ls2midi
    │   └── keep
    └── uncond
        ├── Experiments.md
        ├── cp-linear
        │   ├── gen_midis
        │   │   ├── get_0.mid
        │   │   ├── get_1.mid
        │   │   ├── get_10.mid
        │   │   ├── get_11.mid
        │   │   ├── get_12.mid
        │   │   ├── get_13.mid
        │   │   ├── get_14.mid
        │   │   ├── get_15.mid
        │   │   ├── get_16.mid
        │   │   ├── get_17.mid
        │   │   ├── get_18.mid
        │   │   ├── get_19.mid
        │   │   ├── get_2.mid
        │   │   ├── get_20.mid
        │   │   ├── get_21.mid
        │   │   ├── get_22.mid
        │   │   ├── get_23.mid
        │   │   ├── get_24.mid
        │   │   ├── get_25.mid
        │   │   ├── get_26.mid
        │   │   ├── get_27.mid
        │   │   ├── get_28.mid
        │   │   ├── get_29.mid
        │   │   ├── get_3.mid
        │   │   ├── get_30.mid
        │   │   ├── get_31.mid
        │   │   ├── get_32.mid
        │   │   ├── get_33.mid
        │   │   ├── get_34.mid
        │   │   ├── get_35.mid
        │   │   ├── get_36.mid
        │   │   ├── get_37.mid
        │   │   ├── get_38.mid
        │   │   ├── get_39.mid
        │   │   ├── get_4.mid
        │   │   ├── get_40.mid
        │   │   ├── get_41.mid
        │   │   ├── get_42.mid
        │   │   ├── get_43.mid
        │   │   ├── get_44.mid
        │   │   ├── get_45.mid
        │   │   ├── get_46.mid
        │   │   ├── get_47.mid
        │   │   ├── get_48.mid
        │   │   ├── get_49.mid
        │   │   ├── get_5.mid
        │   │   ├── get_6.mid
        │   │   ├── get_7.mid
        │   │   ├── get_8.mid
        │   │   └── get_9.mid
        │   ├── main-cp.py
        │   ├── runtime_stats.json
        │   └── saver.py
        └── remi-xl
            ├── config.yml
            ├── gen_midis
            │   ├── 170_0.mid
            │   ├── 170_1.mid
            │   ├── 170_10.mid
            │   ├── 170_11.mid
            │   ├── 170_12.mid
            │   ├── 170_13.mid
            │   ├── 170_14.mid
            │   ├── 170_15.mid
            │   ├── 170_16.mid
            │   ├── 170_17.mid
            │   ├── 170_18.mid
            │   ├── 170_19.mid
            │   ├── 170_2.mid
            │   ├── 170_3.mid
            │   ├── 170_4.mid
            │   ├── 170_5.mid
            │   ├── 170_6.mid
            │   ├── 170_7.mid
            │   ├── 170_8.mid
            │   └── 170_9.mid
            ├── inference.py
            ├── model.py
            ├── modules.py
            ├── runtime_stats.json
            ├── saver.py
            ├── songTime.csv
            └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | __pycache__/
132 | .vscode/
133 | .ipynb_checkpoints/
134 | .DS_Store
135 | miditoolkit.egg-info/
136 | @eaDir
137 | *.pyc
138 | *.pypirc
139 | Thumbs.db
140 | *.gz
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Compound Word Transformer
2 |
3 |
4 | Authors: [Wen-Yi Hsiao](https://github.com/wayne391), [Jen-Yu Liu](https://github.com/ciaua), [Yin-Cheng Yeh](https://github.com/yyeh26) and [Yi-Hsuan Yang](http://mac.citi.sinica.edu.tw/~yang/)
5 |
6 | [**Paper (arXiv)**](https://arxiv.org/abs/2101.02402) | [**Audio demo (Google Drive)**](https://drive.google.com/drive/folders/1G_tTpcAuVpYO-4IUGS8i8XdwoIsUix8o?usp=sharing) | [**Blog**](https://ailabs.tw/human-interaction/compound-word-transformer-generate-pop-piano-music-of-full-song-length/) | [**Colab notebook**](https://colab.research.google.com/drive/1AU8iMhy10WxHj7yt3j8S3FQvvKvgXrr0)
7 |
8 | Official PyTorch implementation of the AAAI 2021 paper "Compound Word Transformer: Learning to Compose Full-Song Music over Dynamic Directed Hypergraphs".
9 |
10 | We present a new variant of the Transformer that can process multiple consecutive tokens at once at a single time step. The proposed method greatly reduces the length of the resulting sequence and therefore enhances training and inference efficiency. We employ it to learn to compose expressive Pop piano music of full-song length (involving up to 10K individual tokens per song). In this repository, we open-source our **AILabs.tw Pop1K7** dataset and the code for unconditional generation.
11 |
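For a quick intuition: attributes that REMI spells out as separate tokens are grouped into one compound token per time step in CP. Below is a rough illustration with made-up values; the token names follow the scripts in `dataset/representations/uncond`.

```python
# REMI: one token per attribute, consumed one at a time
remi_tokens = ['Bar_None', 'Beat_0', 'Chord_C_M', 'Tempo_119',
               'Note_Pitch_60', 'Note_Velocity_40', 'Note_Duration_480']

# CP: the same information as two compound tokens (metrical + note),
# so the model advances several attributes per time step (0 = ignored field)
cp_tokens = [
    {'type': 'Metrical', 'tempo': 'Tempo_119', 'chord': 'C_M',
     'bar-beat': 'Beat_0', 'pitch': 0, 'duration': 0, 'velocity': 0},
    {'type': 'Note', 'tempo': 0, 'chord': 0, 'bar-beat': 0,
     'pitch': 'Note_Pitch_60', 'duration': 'Note_Duration_480',
     'velocity': 'Note_Velocity_40'},
]
```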
12 |
13 | ## Dependencies
14 |
15 | * python 3.6
16 | * Required packages:
17 | * madmom
18 | * miditoolkit
19 | * pytorch-fast-transformers
20 |
21 |
22 | ``chorder`` is our in-house rule-based symbolic chord recognition algorithm, developed by our former intern, [joshuachang2311](https://github.com/joshuachang2311/chorder). He is also a jazz pianist :musical_keyboard:.
23 |
24 |
25 | ## Model
26 | In this work, we conduct two scenarios of generation:
27 | * unconditional generation
28 | * To see the experimental results and discussion, please refer to [here](https://github.com/YatingMusic/compound-word-transformer/blob/main/workspace/uncond/Experiments.md).
29 |
30 | * conditional generation, leadsheet to full midi (ls2midi)
31 | * [**Work in progress**] We plan to open source the code associated with this part in the future.
32 |     * melody extraction (skyline)
33 | * objective metrics
34 | * model
35 |
36 | ## Dataset
37 | To prepare your own training data, please refer to the [documentation](https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/Dataset.md) for details.
38 | Or, you can start with our **AILabs.tw Pop1K7** dataset, which is available [here](https://drive.google.com/file/d/1qw_tVUntblIg4lW16vbpjLXVndkVtgDe/view?usp=sharing).
39 |
40 | ## Demo: Colab Notebook
41 |
42 | The Colab notebook is available [here](https://colab.research.google.com/drive/1AU8iMhy10WxHj7yt3j8S3FQvvKvgXrr0).
43 | Thanks to our intern [AdarshKumar712](https://github.com/AdarshKumar712) for organizing the code.
44 |
45 |
46 | ## Acknowledgement
47 | - The PyTorch code for Transformer-XL is modified from [kimiyoung/transformer-xl](https://github.com/kimiyoung/transformer-xl).
48 | - Thanks to [Yu-Hua Chen](https://github.com/ss12f32v) and [Hsiao-Tzu Hung](https://github.com/annahung31) for helping organize the code.
49 |
50 |
--------------------------------------------------------------------------------
/assets/data_proc_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/assets/data_proc_diagram.png
--------------------------------------------------------------------------------
/dataset/Dataset.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 |
3 | In this document, we describe our team's standard data processing pipeline. By following the instructions and running the corresponding Python scripts, you can easily generate and customize your own dataset.
4 |
5 |
6 | ![data processing pipeline](../assets/data_proc_diagram.png)
7 |
8 |
9 |
10 |
11 | ## 1. From `audio` to `midi_transcribed`
12 | We collect audio clips of piano performances from YouTube.
13 |
14 | * run Google Magenta's [Onsets and Frames](https://github.com/magenta/magenta/tree/master/magenta/models/onsets_frames_transcription) transcription model
15 |
16 | ## 2. From `midi_transcribed` to `midi_synchronized`
17 | In this step, we use [madmom](https://github.com/CPJKU/madmom) for beat/downbeat tracking. Next, we interpolate 480 ticks between each pair of adjacent beats and map absolute times (in seconds) onto the resulting tick grid. Lastly, we infer tempo changes from the time intervals between adjacent beats. We choose a beat resolution of 480 because it is a common setting in modern DAWs. Note that we do not quantize any timing in this step, so small timing offsets are preserved for later use (see the sketch below).
18 |
19 | * run `synchronizer.py`
20 |
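A minimal sketch of the idea, assuming madmom's RNN downbeat tracker (the full logic lives in `synchronizer.py`):

```python
import numpy as np
from madmom.features.downbeats import RNNDownBeatProcessor, DBNDownBeatTrackingProcessor

BEAT_RESOL = 480  # ticks per beat, a common DAW setting

# 1) track beats/downbeats: rows of (time in seconds, position in bar)
act = RNNDownBeatProcessor()('song.wav')
beats = DBNDownBeatTrackingProcessor(beats_per_bar=[3, 4], fps=100)(act)
beat_times = beats[:, 0]

# 2) interpolate 480 ticks between adjacent beats, then map any
#    absolute time onto the tick grid
beat_ticks = np.arange(len(beat_times)) * BEAT_RESOL
def sec2tick(t):
    return int(round(np.interp(t, beat_times, beat_ticks)))

# 3) infer tempo changes from inter-beat intervals
tempo_curve = 60.0 / np.diff(beat_times)  # bpm between adjacent beats
```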
21 | ## 3. From `midi_synchronized` to `midi_analyzed`
22 | In this step, we use our in-house rule-based symbolic melody extraction and chord recognition algorithms to obtain the desired information. The code for chord recognition is open-sourced [here](https://github.com/joshuachang2311/chorder); we plan to open-source the melody extraction code in the future. A usage sketch follows below.
23 |
24 | * run `analyzer.py`
25 |
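A minimal usage sketch of the chord-recognition call (the same `Dechorder.dechord` call used in `dataset/analyzer.py`):

```python
import miditoolkit
from chorder import Dechorder

midi_obj = miditoolkit.midi.parser.MidiFile('song.mid')
chords = Dechorder.dechord(midi_obj)  # one chord estimate per beat
for chord in chords:
    if chord.is_complete():
        print(chord.root_pc, chord.quality, chord.bass_pc)
```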
26 | ## 4. From `midi_analyzed` to `Corpus`
27 | We quantize everything (duration, velocity, bpm) in this step, and append an EOS (end-of-sequence) token to each song (see the sketch below).
28 |
29 | * run `midi2corpus.py`
30 |
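In brief, note onsets snap to a 16th-note grid and velocities snap to fixed bins. The constants below are copied from `midi2corpus.py`; the note values are made up for illustration:

```python
import numpy as np

BEAT_RESOL = 480
TICK_RESOL = BEAT_RESOL // 4  # 120 ticks = one 16th note
DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 64 + 1, dtype=int)  # step of 2

note_start, velocity = 1234, 77  # hypothetical raw values
quant_start = int(np.round(note_start / TICK_RESOL) * TICK_RESOL)  # -> 1200
velocity = DEFAULT_VELOCITY_BINS[
    np.argmin(np.abs(DEFAULT_VELOCITY_BINS - velocity))]           # -> 76
```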
31 | ## 5. From `Corpus` to `Representation`
32 | We have two kinds of representations - Compound Word (**CP**) and **REMI** - and two tasks - unconditional and conditional generation - resulting in four combinations. Go to the corresponding folder and run the scripts:
33 |
34 |
35 | * run `corpus2events.py`: generate human-readable tokens and re-arrange the data.
36 | * run `events2words.py`: build the dictionary and renumber the tokens.
37 | * run `compile.py`: discard disqualified songs that exceed the length limit, reshape the data for Transformer-XL, and generate masks for variable lengths (a toy example of the event-to-word step follows below).
38 |
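As a toy example of the event-to-word conversion (REMI-style events as produced by `corpus2events.py`; `event2word` is the dictionary built by `events2words.py`):

```python
# an event is a (name, value) pair; a word is its integer id
event = {'name': 'Note_Pitch', 'value': 60}
key = '{}_{}'.format(event['name'], event['value'])  # 'Note_Pitch_60'
word = event2word[key]                               # integer id fed to the model
```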
39 | ---
40 |
41 | ## AILabs.tw Pop1K7 dataset
42 |
43 | Alternatively, you can refer to [here](https://drive.google.com/drive/folders/1DY54sxeCcQfVXdGXps5lHwtRe7D_kBRV?usp=sharing) to obtain the entire workspace and the pre-processed training data originally used in our paper.
44 |
--------------------------------------------------------------------------------
/dataset/analyzer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import numpy as np
4 | import multiprocessing as mp
5 |
6 | import miditoolkit
7 | from miditoolkit.midi import parser as mid_parser
8 | from miditoolkit.pianoroll import parser as pr_parser
9 | from miditoolkit.midi.containers import Marker, Instrument, TempoChange
10 |
11 | from chorder import Dechorder
12 |
13 |
14 | num2pitch = {
15 | 0: 'C',
16 | 1: 'C#',
17 | 2: 'D',
18 | 3: 'D#',
19 | 4: 'E',
20 | 5: 'F',
21 | 6: 'F#',
22 | 7: 'G',
23 | 8: 'G#',
24 | 9: 'A',
25 | 10: 'A#',
26 | 11: 'B',
27 | }
28 |
29 |
30 | def traverse_dir(
31 | root_dir,
32 | extension=('mid', 'MID', 'midi'),
33 | amount=None,
34 | str_=None,
35 | is_pure=False,
36 | verbose=False,
37 | is_sort=False,
38 | is_ext=True):
39 | if verbose:
40 | print('[*] Scanning...')
41 | file_list = []
42 | cnt = 0
43 | for root, _, files in os.walk(root_dir):
44 | for file in files:
45 | if file.endswith(extension):
46 | if (amount is not None) and (cnt == amount):
47 | break
48 | if str_ is not None:
49 | if str_ not in file:
50 | continue
51 | mix_path = os.path.join(root, file)
52 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
53 | if not is_ext:
54 | ext = pure_path.split('.')[-1]
55 | pure_path = pure_path[:-(len(ext)+1)]
56 | if verbose:
57 | print(pure_path)
58 | file_list.append(pure_path)
59 | cnt += 1
60 | if verbose:
61 | print('Total: %d files' % len(file_list))
62 | print('Done!!!')
63 | if is_sort:
64 | file_list.sort()
65 | return file_list
66 |
67 |
68 | def proc_one(path_infile, path_outfile):
69 | print('----')
70 | print(' >', path_infile)
71 | print(' >', path_outfile)
72 |
73 | # load
74 | midi_obj = miditoolkit.midi.parser.MidiFile(path_infile)
75 | midi_obj_out = copy.deepcopy(midi_obj)
76 | notes = midi_obj.instruments[0].notes
77 | notes = sorted(notes, key=lambda x: (x.start, x.pitch))
78 |
79 | # --- chord --- #
80 |     # extract chords
81 | chords = Dechorder.dechord(midi_obj)
82 | markers = []
83 | for cidx, chord in enumerate(chords):
84 | if chord.is_complete():
85 | chord_text = num2pitch[chord.root_pc] + '_' + chord.quality + '_' + num2pitch[chord.bass_pc]
86 | else:
87 | chord_text = 'N_N_N'
88 | markers.append(Marker(time=int(cidx*480), text=chord_text))
89 |
90 | # dedup
91 | prev_chord = None
92 | dedup_chords = []
93 | for m in markers:
94 | if m.text != prev_chord:
95 | prev_chord = m.text
96 | dedup_chords.append(m)
97 |
98 | # --- global properties --- #
99 | # global tempo
100 |     tempos = [b.tempo for b in midi_obj.tempo_changes][:40]  # median over the first 40 tempo events
101 |     tempo_median = np.median(tempos)
102 |     global_bpm = int(tempo_median)
103 | print(' > [global] bpm:', global_bpm)
104 |
105 | # === save === #
106 | # mkdir
107 | fn = os.path.basename(path_outfile)
108 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True)
109 |
110 | # markers
111 | midi_obj_out.markers = dedup_chords
112 | midi_obj_out.markers.insert(0, Marker(text='global_bpm_'+str(int(global_bpm)), time=0))
113 |
114 | # save
115 | midi_obj_out.instruments[0].name = 'piano'
116 | midi_obj_out.dump(path_outfile)
117 |
118 |
119 | if __name__ == '__main__':
120 | # paths
121 | path_indir = './midi_synchronized'
122 | path_outdir = './midi_analyzed'
123 | os.makedirs(path_outdir, exist_ok=True)
124 |
125 | # list files
126 | midifiles = traverse_dir(
127 | path_indir,
128 | is_pure=True,
129 | is_sort=True)
130 | n_files = len(midifiles)
131 |     print('num files:', n_files)
132 |
133 | # collect
134 | data = []
135 | for fidx in range(n_files):
136 | path_midi = midifiles[fidx]
137 | print('{}/{}'.format(fidx, n_files))
138 |
139 | # paths
140 | path_infile = os.path.join(path_indir, path_midi)
141 | path_outfile = os.path.join(path_outdir, path_midi)
142 |
143 | # append
144 | data.append([path_infile, path_outfile])
145 |
146 |     # run with multiprocessing
147 | pool = mp.Pool()
148 | pool.starmap(proc_one, data)
--------------------------------------------------------------------------------
/dataset/corpus/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/corpus/.DS_Store
--------------------------------------------------------------------------------
/dataset/corpus/keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/corpus/keep
--------------------------------------------------------------------------------
/dataset/midi2corpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import pickle
4 | import numpy as np
5 | import miditoolkit
6 | import collections
7 |
8 | import pandas as pd
9 | import matplotlib.pyplot as plt
10 | import seaborn as sns
11 |
12 | # ================================================== #
13 | # Configuration #
14 | # ================================================== #
15 | BEAT_RESOL = 480
16 | BAR_RESOL = BEAT_RESOL * 4
17 | TICK_RESOL = BEAT_RESOL // 4
18 | INSTR_NAME_MAP = {'piano': 0}
19 | MIN_BPM = 40
20 | MIN_VELOCITY = 40
21 | NOTE_SORTING = 1 # 0: ascending / 1: descending
22 |
23 | DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 64+1, dtype=int)  # plain int: np.int was removed in NumPy >= 1.24
24 | DEFAULT_BPM_BINS = np.linspace(32, 224, 64+1, dtype=int)
25 | DEFAULT_SHIFT_BINS = np.linspace(-60, 60, 60+1, dtype=int)
26 | DEFAULT_DURATION_BINS = np.arange(
27 | BEAT_RESOL/8, BEAT_RESOL*8+1, BEAT_RESOL/8)
28 |
29 | # ================================================== #
30 |
31 |
32 | def traverse_dir(
33 | root_dir,
34 | extension=('mid', 'MID', 'midi'),
35 | amount=None,
36 | str_=None,
37 | is_pure=False,
38 | verbose=False,
39 | is_sort=False,
40 | is_ext=True):
41 | if verbose:
42 | print('[*] Scanning...')
43 | file_list = []
44 | cnt = 0
45 | for root, _, files in os.walk(root_dir):
46 | for file in files:
47 | if file.endswith(extension):
48 | if (amount is not None) and (cnt == amount):
49 | break
50 | if str_ is not None:
51 | if str_ not in file:
52 | continue
53 | mix_path = os.path.join(root, file)
54 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
55 | if not is_ext:
56 | ext = pure_path.split('.')[-1]
57 | pure_path = pure_path[:-(len(ext)+1)]
58 | if verbose:
59 | print(pure_path)
60 | file_list.append(pure_path)
61 | cnt += 1
62 | if verbose:
63 | print('Total: %d files' % len(file_list))
64 | print('Done!!!')
65 | if is_sort:
66 | file_list.sort()
67 | return file_list
68 |
69 |
70 | def proc_one(path_midi, path_outfile):
71 | # --- load --- #
72 | midi_obj = miditoolkit.midi.parser.MidiFile(path_midi)
73 |
74 | # load notes
75 | instr_notes = collections.defaultdict(list)
76 | for instr in midi_obj.instruments:
77 | # skip
78 | if instr.name not in INSTR_NAME_MAP.keys():
79 | continue
80 |
81 | # process
82 | instr_idx = INSTR_NAME_MAP[instr.name]
83 | for note in instr.notes:
84 | note.instr_idx=instr_idx
85 | instr_notes[instr_idx].append(note)
86 | if NOTE_SORTING == 0:
87 | instr_notes[instr_idx].sort(
88 | key=lambda x: (x.start, x.pitch))
89 | elif NOTE_SORTING == 1:
90 | instr_notes[instr_idx].sort(
91 | key=lambda x: (x.start, -x.pitch))
92 | else:
93 | raise ValueError(' [x] Unknown type of sorting.')
94 |
95 | # load chords
96 | chords = []
97 | for marker in midi_obj.markers:
98 | if marker.text.split('_')[0] != 'global' and \
99 | 'Boundary' not in marker.text.split('_')[0]:
100 | chords.append(marker)
101 | chords.sort(key=lambda x: x.time)
102 |
103 | # load tempos
104 | tempos = midi_obj.tempo_changes
105 | tempos.sort(key=lambda x: x.time)
106 |
107 | # load labels
108 | labels = []
109 | for marker in midi_obj.markers:
110 | if 'Boundary' in marker.text.split('_')[0]:
111 | labels.append(marker)
112 | labels.sort(key=lambda x: x.time)
113 |
114 | # load global bpm
115 |     global_bpm = 120
116 |     for marker in midi_obj.markers:
117 |         if marker.text.split('_')[0] == 'global' and \
118 |             marker.text.split('_')[1] == 'bpm':
119 |             global_bpm = int(marker.text.split('_')[2])
120 |
121 | # --- process items to grid --- #
122 | # compute empty bar offset at head
123 | first_note_time = min([instr_notes[k][0].start for k in instr_notes.keys()])
124 | last_note_time = max([instr_notes[k][-1].start for k in instr_notes.keys()])
125 |
126 | quant_time_first = int(np.round(first_note_time / TICK_RESOL) * TICK_RESOL)
127 | offset = quant_time_first // BAR_RESOL # empty bar
128 | last_bar = int(np.ceil(last_note_time / BAR_RESOL)) - offset
129 | print(' > offset:', offset)
130 | print(' > last_bar:', last_bar)
131 |
132 | # process notes
133 |     instr_grid = dict()
134 | for key in instr_notes.keys():
135 | notes = instr_notes[key]
136 | note_grid = collections.defaultdict(list)
137 | for note in notes:
138 | note.start = note.start - offset * BAR_RESOL
139 | note.end = note.end - offset * BAR_RESOL
140 |
141 | # quantize start
142 | quant_time = int(np.round(note.start / TICK_RESOL) * TICK_RESOL)
143 |
144 | # velocity
145 | note.velocity = DEFAULT_VELOCITY_BINS[
146 | np.argmin(abs(DEFAULT_VELOCITY_BINS-note.velocity))]
147 | note.velocity = max(MIN_VELOCITY, note.velocity)
148 |
149 | # shift of start
150 | note.shift = note.start - quant_time
151 | note.shift = DEFAULT_SHIFT_BINS[np.argmin(abs(DEFAULT_SHIFT_BINS-note.shift))]
152 |
153 | # duration
154 | note_duration = note.end - note.start
155 | if note_duration > BAR_RESOL:
156 | note_duration = BAR_RESOL
157 | ntick_duration = int(np.round(note_duration / TICK_RESOL) * TICK_RESOL)
158 | note.duration = ntick_duration
159 |
160 | # append
161 | note_grid[quant_time].append(note)
162 |
163 | # set to track
164 |         instr_grid[key] = note_grid.copy()
165 |
166 | # process chords
167 | chord_grid = collections.defaultdict(list)
168 | for chord in chords:
169 | # quantize
170 | chord.time = chord.time - offset * BAR_RESOL
171 | chord.time = 0 if chord.time < 0 else chord.time
172 | quant_time = int(np.round(chord.time / TICK_RESOL) * TICK_RESOL)
173 |
174 | # append
175 | chord_grid[quant_time].append(chord)
176 |
177 | # process tempo
178 | tempo_grid = collections.defaultdict(list)
179 | for tempo in tempos:
180 | # quantize
181 | tempo.time = tempo.time - offset * BAR_RESOL
182 | tempo.time = 0 if tempo.time < 0 else tempo.time
183 | quant_time = int(np.round(tempo.time / TICK_RESOL) * TICK_RESOL)
184 | tempo.tempo = DEFAULT_BPM_BINS[np.argmin(abs(DEFAULT_BPM_BINS-tempo.tempo))]
185 |
186 | # append
187 | tempo_grid[quant_time].append(tempo)
188 |
189 | # process boundary
190 | label_grid = collections.defaultdict(list)
191 | for label in labels:
192 | # quantize
193 | label.time = label.time - offset * BAR_RESOL
194 | label.time = 0 if label.time < 0 else label.time
195 | quant_time = int(np.round(label.time / TICK_RESOL) * TICK_RESOL)
196 |
197 | # append
198 | label_grid[quant_time] = [label]
199 |
200 | # process global bpm
201 |     global_bpm = DEFAULT_BPM_BINS[np.argmin(abs(DEFAULT_BPM_BINS-global_bpm))]
202 |
203 | # collect
204 | song_data = {
205 |         'notes': instr_grid,
206 | 'chords': chord_grid,
207 | 'tempos': tempo_grid,
208 | 'labels': label_grid,
209 | 'metadata': {
210 |             'global_bpm': global_bpm,
211 | 'last_bar': last_bar,
212 | }
213 | }
214 |
215 | # save
216 | fn = os.path.basename(path_outfile)
217 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True)
218 | pickle.dump(song_data, open(path_outfile, 'wb'))
219 |
220 |
221 | if __name__ == '__main__':
222 | # paths
223 | path_indir = './midi_analyzed'
224 | path_outdir = './corpus'
225 | os.makedirs(path_outdir, exist_ok=True)
226 |
227 | # list files
228 | midifiles = traverse_dir(
229 | path_indir,
230 | is_pure=True,
231 | is_sort=True)
232 | n_files = len(midifiles)
233 |     print('num files:', n_files)
234 |
235 | # run all
236 | for fidx in range(n_files):
237 | path_midi = midifiles[fidx]
238 | print('{}/{}'.format(fidx, n_files))
239 |
240 | # paths
241 | path_infile = os.path.join(path_indir, path_midi)
242 | path_outfile = os.path.join(path_outdir, path_midi+'.pkl')
243 |
244 | # proc
245 | proc_one(path_infile, path_outfile)
--------------------------------------------------------------------------------
/dataset/midi_analyzed/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_analyzed/.DS_Store
--------------------------------------------------------------------------------
/dataset/midi_analyzed/keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_analyzed/keep
--------------------------------------------------------------------------------
/dataset/midi_synchronized/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_synchronized/.DS_Store
--------------------------------------------------------------------------------
/dataset/midi_synchronized/keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_synchronized/keep
--------------------------------------------------------------------------------
/dataset/midi_transcribed/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_transcribed/.DS_Store
--------------------------------------------------------------------------------
/dataset/midi_transcribed/keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/midi_transcribed/keep
--------------------------------------------------------------------------------
/dataset/representations/cond-ls2midi/keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/representations/cond-ls2midi/keep
--------------------------------------------------------------------------------
/dataset/representations/uncond/cp/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/representations/uncond/cp/.DS_Store
--------------------------------------------------------------------------------
/dataset/representations/uncond/cp/compile.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import numpy as np
5 |
6 |
7 | TEST_AMOUNT = 50
8 | WINDOW_SIZE = 512
9 | GROUP_SIZE = 7
10 | MAX_LEN = WINDOW_SIZE * GROUP_SIZE
11 | COMPILE_TARGET = 'linear' # 'linear', 'XL'
12 | print('[config] MAX_LEN:', MAX_LEN)
13 |
14 |
15 | def traverse_dir(
16 | root_dir,
17 | extension=('mid', 'MID'),
18 | amount=None,
19 | str_=None,
20 | is_pure=False,
21 | verbose=False,
22 | is_sort=False,
23 | is_ext=True):
24 | if verbose:
25 | print('[*] Scanning...')
26 | file_list = []
27 | cnt = 0
28 | for root, _, files in os.walk(root_dir):
29 | for file in files:
30 | if file.endswith(extension):
31 | if (amount is not None) and (cnt == amount):
32 | break
33 | if str_ is not None:
34 | if str_ not in file:
35 | continue
36 | mix_path = os.path.join(root, file)
37 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
38 | if not is_ext:
39 | ext = pure_path.split('.')[-1]
40 | pure_path = pure_path[:-(len(ext)+1)]
41 | if verbose:
42 | print(pure_path)
43 | file_list.append(pure_path)
44 | cnt += 1
45 | if verbose:
46 | print('Total: %d files' % len(file_list))
47 | print('Done!!!')
48 | if is_sort:
49 | file_list.sort()
50 | return file_list
51 |
52 |
53 | if __name__ == '__main__':
54 | # paths
55 | path_root = 'ailab17k_from-scratch_cp'
56 | path_indir = os.path.join( path_root, 'words')
57 |
58 | # load dictionary
59 | path_dictionary = os.path.join(path_root, 'dictionary.pkl')
60 | event2word, word2event = pickle.load(open(path_dictionary, 'rb'))
61 |
62 | # load all words
63 | wordfiles = traverse_dir(
64 | path_indir,
65 | extension=('npy'))
66 | n_files = len(wordfiles)
67 |
68 | # init
69 | x_list = []
70 | y_list = []
71 | mask_list = []
72 | seq_len_list = []
73 | num_groups_list = []
74 | name_list = []
75 |
76 | # process
77 | for fidx in range(n_files):
78 | print('--[{}/{}]-----'.format(fidx, n_files))
79 | file = wordfiles[fidx]
80 | words = np.load(file)
81 | num_words = len(words)
82 | eos_arr = words[-1][None, ...]
83 |
84 | if num_words >= MAX_LEN - 2: # 2 for room
85 | print(' [!] too long:', num_words)
86 | continue
87 |
88 | # arrange IO
89 | x = words[:-1].copy()
90 | y = words[1:].copy()
91 | seq_len = len(x)
92 | print(' > seq_len:', seq_len)
93 |
94 | # pad with eos
95 | pad = np.tile(
96 | eos_arr,
97 | (MAX_LEN-seq_len, 1))
98 |
99 | x = np.concatenate([x, pad], axis=0)
100 | y = np.concatenate([y, pad], axis=0)
101 | mask = np.concatenate(
102 | [np.ones(seq_len), np.zeros(MAX_LEN-seq_len)])
103 |
104 | # collect
105 | x_list.append(x)
106 | y_list.append(y)
107 | mask_list.append(mask)
108 | seq_len_list.append(seq_len)
109 | num_groups_list.append(int(np.ceil(seq_len/WINDOW_SIZE)))
110 | name_list.append(file)
111 |
112 | # sort by length (descending)
113 | zipped = zip(seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list)
114 | seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list = zip(
115 | *sorted(zipped, key=lambda x: -x[0]))
116 |
117 | print('\n\n[Finished]')
118 | print(' compile target:', COMPILE_TARGET)
119 | if COMPILE_TARGET == 'XL':
120 |         # reshape into (num_songs, GROUP_SIZE, WINDOW_SIZE, num_token_classes) segments for Transformer-XL
121 | x_final = np.array(x_list).reshape(len(x_list), GROUP_SIZE, WINDOW_SIZE, -1)
122 | y_final = np.array(y_list).reshape(len(x_list), GROUP_SIZE, WINDOW_SIZE, -1)
123 | mask_final = np.array(mask_list).reshape(-1, GROUP_SIZE, WINDOW_SIZE)
124 | elif COMPILE_TARGET == 'linear':
125 | x_final = np.array(x_list)
126 | y_final = np.array(y_list)
127 | mask_final = np.array(mask_list)
128 | else:
129 | raise ValueError('Unknown target:', COMPILE_TARGET)
130 |
131 | # check
132 | num_samples = len(seq_len_list)
133 |     print(' > count:', num_samples)
134 | print(' > x_final:', x_final.shape)
135 | print(' > y_final:', y_final.shape)
136 | print(' > mask_final:', mask_final.shape)
137 |
138 | # split train/test
139 | validation_songs = json.load(open('../validation_songs.json', 'r'))
140 | train_idx = []
141 | test_idx = []
142 |
143 | # validation filename map
144 | fn2idx_map = {
145 | 'fn2idx': dict(),
146 | 'idx2fn': dict(),
147 | }
148 |
149 | # run split
150 | valid_cnt = 0
151 | for nidx, n in enumerate(name_list):
152 | flag = True
153 | for fn in validation_songs:
154 | if fn in n:
155 | test_idx.append(nidx)
156 | flag = False
157 | fn2idx_map['fn2idx'][fn] = valid_cnt
158 | fn2idx_map['idx2fn'][valid_cnt] = fn
159 | valid_cnt += 1
160 | break
161 |
162 | if flag:
163 | train_idx.append(nidx)
164 | test_idx = np.array(test_idx)
165 | train_idx = np.array(train_idx)
166 |
167 | # save validation map
168 | path_fn2idx_map = os.path.join(path_root, 'valid_fn2idx_map.json')
169 | with open(path_fn2idx_map, 'w') as f:
170 | json.dump(fn2idx_map, f)
171 |
172 | # save train
173 | path_train = os.path.join(path_root, 'train_data_{}'.format(COMPILE_TARGET))
174 | path_train += '.npz'
175 | np.savez(
176 | path_train,
177 | x=x_final[train_idx],
178 | y=y_final[train_idx],
179 | mask=mask_final[train_idx],
180 | seq_len=np.array(seq_len_list)[train_idx],
181 | num_groups=np.array(num_groups_list)[train_idx]
182 | )
183 |
184 | # save test
185 | path_test = os.path.join(path_root, 'test_data_{}'.format(COMPILE_TARGET))
186 | path_test += '.npz'
187 | np.savez(
188 | path_test,
189 | x=x_final[test_idx],
190 | y=y_final[test_idx],
191 | mask=mask_final[test_idx],
192 | seq_len=np.array(seq_len_list)[test_idx],
193 | num_groups=np.array(num_groups_list)[test_idx]
194 | )
195 |
196 | print('---')
197 | print(' > train x:', x_final[train_idx].shape)
198 | print(' > test x:', x_final[test_idx].shape)
199 |
200 |
--------------------------------------------------------------------------------
/dataset/representations/uncond/cp/corpus2events.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | # config
8 | BEAT_RESOL = 480
9 | BAR_RESOL = BEAT_RESOL * 4
10 | TICK_RESOL = BEAT_RESOL // 4
11 |
12 |
13 | # utilities
14 | def plot_hist(data, path_outfile):
15 | print('[Fig] >> {}'.format(path_outfile))
16 | data_mean = np.mean(data)
17 | data_std = np.std(data)
18 |
19 | print('mean:', data_mean)
20 | print(' std:', data_std)
21 |
22 | plt.figure(dpi=100)
23 | plt.hist(data, bins=50)
24 | plt.title('mean: {:.3f}_std: {:.3f}'.format(data_mean, data_std))
25 | plt.savefig(path_outfile)
26 | plt.close()
27 |
28 | def traverse_dir(
29 | root_dir,
30 | extension=('mid', 'MID'),
31 | amount=None,
32 | str_=None,
33 | is_pure=False,
34 | verbose=False,
35 | is_sort=False,
36 | is_ext=True):
37 | if verbose:
38 | print('[*] Scanning...')
39 | file_list = []
40 | cnt = 0
41 | for root, _, files in os.walk(root_dir):
42 | for file in files:
43 | if file.endswith(extension):
44 | if (amount is not None) and (cnt == amount):
45 | break
46 | if str_ is not None:
47 | if str_ not in file:
48 | continue
49 | mix_path = os.path.join(root, file)
50 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
51 | if not is_ext:
52 | ext = pure_path.split('.')[-1]
53 | pure_path = pure_path[:-(len(ext)+1)]
54 | if verbose:
55 | print(pure_path)
56 | file_list.append(pure_path)
57 | cnt += 1
58 | if verbose:
59 | print('Total: %d files' % len(file_list))
60 | print('Done!!!')
61 | if is_sort:
62 | file_list.sort()
63 | return file_list
64 |
65 | # ---- define event ---- #
66 | ''' 7 kinds:
67 | tempo: 0: IGN
68 | 1: no change
69 | int: tempo
70 | chord: 0: IGN
71 | 1: no change
72 | str: chord types
73 | bar-beat: 0: IGN
74 | int: beat position (1...16)
75 | int: bar (bar)
76 | type: 0: eos
77 | 1: metrical
78 | 2: note
79 | duration: 0: IGN
80 | int: length
81 | pitch: 0: IGN
82 | int: pitch
83 | velocity: 0: IGN
84 | int: velocity
85 | '''
86 |
87 | # event template
88 | compound_event = {
89 | 'tempo': 0,
90 | 'chord': 0,
91 | 'bar-beat': 0,
92 | 'type': 0,
93 | 'pitch': 0,
94 | 'duration': 0,
95 | 'velocity': 0,
96 | }
97 |
98 |
99 | def create_bar_event():
100 | meter_event = compound_event.copy()
101 | meter_event['bar-beat'] = 'Bar'
102 | meter_event['type'] = 'Metrical'
103 | return meter_event
104 |
105 |
106 | def create_piano_metrical_event(tempo, chord, pos):
107 | meter_event = compound_event.copy()
108 | meter_event['tempo'] = tempo
109 | meter_event['chord'] = chord
110 | meter_event['bar-beat'] = pos
111 | meter_event['type'] = 'Metrical'
112 | return meter_event
113 |
114 |
115 | def create_piano_note_event(pitch, duration, velocity):
116 | note_event = compound_event.copy()
117 | note_event['pitch'] = pitch
118 | note_event['duration'] = duration
119 | note_event['velocity'] = velocity
120 | note_event['type'] = 'Note'
121 | return note_event
122 |
123 |
124 | def create_eos_event():
125 | eos_event = compound_event.copy()
126 | eos_event['type'] = 'EOS'
127 | return eos_event
128 |
129 |
130 | # ----------------------------------------------- #
131 | # core functions
132 | def corpus2event_cp(path_infile, path_outfile):
133 | '''
134 | task: 2 track
135 | 1: piano (note + tempo)
136 | ---
137 | remove duplicate position tokens
138 | '''
139 | data = pickle.load(open(path_infile, 'rb'))
140 |
141 | # global tag
142 | global_end = data['metadata']['last_bar'] * BAR_RESOL
143 |
144 | # process
145 | final_sequence = []
146 | for bar_step in range(0, global_end, BAR_RESOL):
147 | final_sequence.append(create_bar_event())
148 |
149 | # --- piano track --- #
150 | for timing in range(bar_step, bar_step + BAR_RESOL, TICK_RESOL):
151 | pos_on = False
152 | pos_events = []
153 | pos_text = 'Beat_' + str((timing-bar_step)//TICK_RESOL)
154 |
155 | # unpack
156 | t_chords = data['chords'][timing]
157 | t_tempos = data['tempos'][timing]
158 | t_notes = data['notes'][0][timing] # piano track
159 |
160 | # metrical
161 | if len(t_tempos) or len(t_chords):
162 | # chord
163 | if len(t_chords):
164 |
165 | root, quality, bass = t_chords[-1].text.split('_')
166 | chord_text = root+'_'+quality
167 | else:
168 | chord_text = 'CONTI'
169 |
170 | # tempo
171 | if len(t_tempos):
172 | tempo_text = 'Tempo_' + str(t_tempos[-1].tempo)
173 | else:
174 | tempo_text = 'CONTI'
175 |
176 | # create
177 | pos_events.append(
178 | create_piano_metrical_event(
179 | tempo_text, chord_text, pos_text))
180 | pos_on = True
181 |
182 | # note
183 | if len(t_notes):
184 | if not pos_on:
185 | pos_events.append(
186 | create_piano_metrical_event(
187 | 'CONTI', 'CONTI', pos_text))
188 |
189 | for note in t_notes:
190 | note_pitch_text = 'Note_Pitch_' + str(note.pitch)
191 | note_duration_text = 'Note_Duration_' + str(note.duration)
192 | note_velocity_text = 'Note_Velocity_' + str(note.velocity)
193 |
194 | pos_events.append(
195 | create_piano_note_event(
196 | note_pitch_text,
197 | note_duration_text,
198 | note_velocity_text))
199 |
200 | # collect & beat
201 | if len(pos_events):
202 | final_sequence.extend(pos_events)
203 |
204 | # BAR ending
205 | final_sequence.append(create_bar_event())
206 |
207 | # EOS
208 | final_sequence.append(create_eos_event())
209 |
210 | # save
211 | fn = os.path.basename(path_outfile)
212 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True)
213 | pickle.dump(final_sequence, open(path_outfile, 'wb'))
214 |
215 | return len(final_sequence)
216 |
217 |
218 | if __name__ == '__main__':
219 | # paths
220 | path_root = './ailab17k_from-scratch_cp'
221 | path_indir = '../../../corpus'
222 | path_outdir = os.path.join(path_root, 'events')
223 | os.makedirs(path_outdir, exist_ok=True)
224 |
225 | # list files
226 | midifiles = traverse_dir(
227 | path_indir,
228 | extension=('pkl'),
229 | is_pure=True,
230 | is_sort=True)
231 | n_files = len(midifiles)
232 |     print('num files:', n_files)
233 |
234 | # run all
235 | len_list = []
236 | for fidx in range(n_files):
237 | path_midi = midifiles[fidx]
238 | print('{}/{}'.format(fidx, n_files))
239 |
240 | # paths
241 | path_infile = os.path.join(path_indir, path_midi)
242 | path_outfile = os.path.join(path_outdir, path_midi)
243 |
244 | # proc
245 | num_tokens = corpus2event_cp(path_infile, path_outfile)
246 | print(' > num_token:', num_tokens)
247 | len_list.append(num_tokens)
248 |
249 | # plot
250 | plot_hist(
251 | len_list,
252 | os.path.join(path_root, 'num_tokens.png')
253 | )
--------------------------------------------------------------------------------
/dataset/representations/uncond/cp/events2words.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | import collections
5 |
6 |
7 | def traverse_dir(
8 | root_dir,
9 | extension=('mid', 'MID'),
10 | amount=None,
11 | str_=None,
12 | is_pure=False,
13 | verbose=False,
14 | is_sort=False,
15 | is_ext=True):
16 | if verbose:
17 | print('[*] Scanning...')
18 | file_list = []
19 | cnt = 0
20 | for root, _, files in os.walk(root_dir):
21 | for file in files:
22 | if file.endswith(extension):
23 | if (amount is not None) and (cnt == amount):
24 | break
25 | if str_ is not None:
26 | if str_ not in file:
27 | continue
28 | mix_path = os.path.join(root, file)
29 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
30 | if not is_ext:
31 | ext = pure_path.split('.')[-1]
32 | pure_path = pure_path[:-(len(ext)+1)]
33 | if verbose:
34 | print(pure_path)
35 | file_list.append(pure_path)
36 | cnt += 1
37 | if verbose:
38 | print('Total: %d files' % len(file_list))
39 | print('Done!!!')
40 | if is_sort:
41 | file_list.sort()
42 | return file_list
43 |
44 |
45 | if __name__ == '__main__':
46 | # paths
47 | path_root = 'ailab17k_from-scratch_cp'
48 | path_indir = os.path.join(path_root, 'events')
49 | path_outdir = os.path.join(path_root, 'words')
50 | path_dictionary = os.path.join(path_root, 'dictionary.pkl')
51 | os.makedirs(path_outdir, exist_ok=True)
52 |
53 | # list files
54 | eventfiles = traverse_dir(
55 | path_indir,
56 | is_pure=True,
57 | is_sort=True,
58 | extension=('pkl'))
59 | n_files = len(eventfiles)
60 |     print('num files:', n_files)
61 |
62 | # --- build dictionary --- #
63 | # all files
64 | class_keys = pickle.load(
65 | open(os.path.join(path_indir, eventfiles[0]), 'rb'))[0].keys()
66 | print('class keys:', class_keys)
67 |
68 | # define dictionary
69 | event2word = {}
70 | word2event = {}
71 |
72 | corpus_kv = collections.defaultdict(list)
73 | for file in eventfiles:
74 | for event in pickle.load(open(
75 | os.path.join(path_indir, file), 'rb')):
76 | for key in class_keys:
77 | corpus_kv[key].append(event[key])
78 |
79 | for ckey in class_keys:
80 | class_unique_vals = sorted(
81 | set(corpus_kv[ckey]), key=lambda x: (not isinstance(x, int), x))
82 | event2word[ckey] = {key: i for i, key in enumerate(class_unique_vals)}
83 | word2event[ckey] = {i: key for i, key in enumerate(class_unique_vals)}
84 |
85 | # print
86 | print('[class size]')
87 | for key in class_keys:
88 | print(' > {:10s}: {}'.format(key, len(event2word[key])))
89 |
90 | # save
91 | path_dict = os.path.join(path_root, 'dictionary.pkl')
92 | pickle.dump((event2word, word2event), open(path_dict, 'wb'))
93 |
94 | # --- compile each --- #
95 | # reload
96 | event2word, word2event = pickle.load(open(path_dict, 'rb'))
97 | for fidx in range(len(eventfiles)):
98 | file = eventfiles[fidx]
99 | events_list = pickle.load(open(
100 | os.path.join(path_indir, file), 'rb'))
101 |         path_outfile = os.path.join(path_outdir, file + '.npy')
102 |         fn = os.path.basename(path_outfile)
103 |
104 | print('({}/{})'.format(fidx, len(eventfiles)))
105 | print(' > from:', file)
106 | print(' > to:', path_outfile)
107 |
108 | words = []
109 | for eidx, e in enumerate(events_list):
110 | words_tmp = [
111 | event2word[k][e[k]] for k in class_keys
112 | ]
113 | words.append(words_tmp)
114 |
115 | # save
116 | path_outfile = os.path.join(path_outdir, file + '.npy')
117 | fn = os.path.basename(path_outfile)
118 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True)
119 | np.save(path_outfile, words)
120 |
--------------------------------------------------------------------------------
/dataset/representations/uncond/remi/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/dataset/representations/uncond/remi/.DS_Store
--------------------------------------------------------------------------------
/dataset/representations/uncond/remi/compile.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import numpy as np
5 |
6 |
7 | TEST_AMOUNT = 50
8 | WINDOW_SIZE = 512
9 | GROUP_SIZE = 15
10 | MAX_LEN = WINDOW_SIZE * GROUP_SIZE
11 | COMPILE_TARGET = 'XL' # 'linear', 'XL'
12 | print('[config] MAX_LEN:', MAX_LEN)
13 |
14 |
15 | def traverse_dir(
16 | root_dir,
17 | extension=('mid', 'MID'),
18 | amount=None,
19 | str_=None,
20 | is_pure=False,
21 | verbose=False,
22 | is_sort=False,
23 | is_ext=True):
24 | if verbose:
25 | print('[*] Scanning...')
26 | file_list = []
27 | cnt = 0
28 | for root, _, files in os.walk(root_dir):
29 | for file in files:
30 | if file.endswith(extension):
31 | if (amount is not None) and (cnt == amount):
32 | break
33 | if str_ is not None:
34 | if str_ not in file:
35 | continue
36 | mix_path = os.path.join(root, file)
37 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
38 | if not is_ext:
39 | ext = pure_path.split('.')[-1]
40 | pure_path = pure_path[:-(len(ext)+1)]
41 | if verbose:
42 | print(pure_path)
43 | file_list.append(pure_path)
44 | cnt += 1
45 | if verbose:
46 | print('Total: %d files' % len(file_list))
47 | print('Done!!!')
48 | if is_sort:
49 | file_list.sort()
50 | return file_list
51 |
52 |
53 | if __name__ == '__main__':
54 | # paths
55 | path_root = 'ailab17k_from-scratch_remi'
56 | path_indir = os.path.join( path_root, 'words')
57 |
58 | # load dictionary
59 | path_dictionary = os.path.join(path_root, 'dictionary.pkl')
60 | event2word, word2event = pickle.load(open(path_dictionary, 'rb'))
61 | eos_id = event2word['EOS_None']
62 | print(' > eos_id:', eos_id)
63 |
64 | # load all words
65 | wordfiles = traverse_dir(
66 | path_indir,
67 | extension=('npy'))
68 |
79 | n_files = len(wordfiles)
80 |
81 | # init
82 | x_list = []
83 | y_list = []
84 | mask_list = []
85 | seq_len_list = []
86 | num_groups_list = []
87 | name_list = []
88 |
89 | # process
90 | for fidx in range(n_files):
91 | print('--[{}/{}]-----'.format(fidx, n_files))
92 | file = wordfiles[fidx]
93 | words = np.load(file)
94 | num_words = len(words)
95 |
96 | if num_words >= MAX_LEN - 2: # 2 for room
97 | print(' [!] too long:', num_words)
98 | continue
99 |
100 | # arrange IO
101 | x = words[:-1]
102 | y = words[1:]
103 | seq_len = len(x)
104 | print(' > seq_len:', seq_len)
105 |
106 | # pad with eos
107 | x = np.concatenate([x, np.ones(MAX_LEN-seq_len) * eos_id])
108 | y = np.concatenate([y, np.ones(MAX_LEN-seq_len) * eos_id])
109 | mask = np.concatenate(
110 | [np.ones(seq_len), np.zeros(MAX_LEN-seq_len)])
111 |
112 | # collect
113 | x_list.append(x)
114 | y_list.append(y)
115 | mask_list.append(mask)
116 | seq_len_list.append(seq_len)
117 | num_groups_list.append(int(np.ceil(seq_len/WINDOW_SIZE)))
118 | name_list.append(file)
119 |
120 | # sort by length (descending)
121 | zipped = zip(seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list)
122 | seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list = zip(
123 | *sorted(zipped, key=lambda x: -x[0]))
124 |
125 | print('\n\n[Finished]')
126 | print(' compile target:', COMPILE_TARGET)
127 | if COMPILE_TARGET == 'XL':
128 | x_final = np.array(x_list).reshape(-1, GROUP_SIZE, WINDOW_SIZE)
129 | y_final = np.array(y_list).reshape(-1, GROUP_SIZE, WINDOW_SIZE)
130 | mask_final = np.array(mask_list).reshape(-1, GROUP_SIZE, WINDOW_SIZE)
131 | elif COMPILE_TARGET == 'linear':
132 | x_final = np.array(x_list)
133 | y_final = np.array(y_list)
134 | mask_final = np.array(mask_list)
135 | else:
136 | raise ValueError('Unknown target:', COMPILE_TARGET)
137 | num_samples = len(seq_len_list)
138 |     print(' > count:', num_samples)
139 | print(' > x_final:', x_final.shape)
140 | print(' > y_final:', y_final.shape)
141 | print(' > mask_final:', mask_final.shape)
142 |
143 | # split train/test
144 | validation_songs = json.load(open('../validation_songs.json', 'r'))
145 | train_idx = []
146 | test_idx = []
147 |
148 | # validation filename map
149 | fn_idx_map = {
150 | 'fn2idx': dict(),
151 | 'idx2fn': dict(),
152 | }
153 |
154 | # run split
155 | valid_cnt = 0
156 | for nidx, n in enumerate(name_list):
157 | flag = True
158 | for fn in validation_songs:
159 | if fn in n:
160 | test_idx.append(nidx)
161 | flag = False
162 | fn_idx_map['fn2idx'][fn] = valid_cnt
163 | fn_idx_map['idx2fn'][valid_cnt] = fn
164 | valid_cnt += 1
165 | break
166 | if flag:
167 | train_idx.append(nidx)
168 | test_idx = np.array(test_idx)
169 | train_idx = np.array(train_idx)
170 |
171 | # save validation map
172 | with open('valid_fn_idx_map.json', 'w') as f:
173 | json.dump(fn_idx_map, f)
174 |
175 | # save train
176 | path_train = os.path.join(path_root, 'train_data_{}.npz'.format(COMPILE_TARGET))
177 | np.savez(
178 | path_train,
179 | x=x_final[train_idx],
180 | y=y_final[train_idx],
181 | mask=mask_final[train_idx],
182 | seq_len=np.array(seq_len_list)[train_idx],
183 | num_groups=np.array(num_groups_list)[train_idx]
184 | )
185 |
186 | # save test
187 | path_test = os.path.join(path_root, 'test_data_{}.npz'.format(COMPILE_TARGET))
188 | np.savez(
189 | path_test,
190 | x=x_final[test_idx],
191 | y=y_final[test_idx],
192 | mask=mask_final[test_idx],
193 | seq_len=np.array(seq_len_list)[test_idx],
194 | num_groups=np.array(num_groups_list)[test_idx]
195 | )
196 |
197 | print('---')
198 | print(' > train x:', x_final[train_idx].shape)
199 | print(' > test x:', x_final[test_idx].shape)
--------------------------------------------------------------------------------
/dataset/representations/uncond/remi/corpus2events.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | # config
8 | BEAT_RESOL = 480
9 | BAR_RESOL = BEAT_RESOL * 4
10 | TICK_RESOL = BEAT_RESOL // 4
11 |
12 |
13 | # utilities
14 | def plot_hist(data, path_outfile):
15 | print('[Fig] >> {}'.format(path_outfile))
16 | data_mean = np.mean(data)
17 | data_std = np.std(data)
18 |
19 | print('mean:', data_mean)
20 | print(' std:', data_std)
21 |
22 | plt.figure(dpi=100)
23 | plt.hist(data, bins=50)
24 | plt.title('mean: {:.3f}_std: {:.3f}'.format(data_mean, data_std))
25 | plt.savefig(path_outfile)
26 | plt.close()
27 |
28 |
29 | def traverse_dir(
30 | root_dir,
31 | extension=('mid', 'MID'),
32 | amount=None,
33 | str_=None,
34 | is_pure=False,
35 | verbose=False,
36 | is_sort=False,
37 | is_ext=True):
38 |
39 | if verbose:
40 | print('[*] Scanning...')
41 | file_list = []
42 | cnt = 0
43 | for root, _, files in os.walk(root_dir):
44 | for file in files:
45 | if file.endswith(extension):
46 | if (amount is not None) and (cnt == amount):
47 | break
48 | if str_ is not None:
49 | if str_ not in file:
50 | continue
51 | mix_path = os.path.join(root, file)
52 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
53 | if not is_ext:
54 | ext = pure_path.split('.')[-1]
55 | pure_path = pure_path[:-(len(ext)+1)]
56 | if verbose:
57 | print(pure_path)
58 | file_list.append(pure_path)
59 | cnt += 1
60 | if verbose:
61 | print('Total: %d files' % len(file_list))
62 | print('Done!!!')
63 | if is_sort:
64 | file_list.sort()
65 | return file_list
66 |
67 |
68 | # define event
69 | def create_event(name, value):
70 | event = dict()
71 | event['name'] = name
72 | event['value'] = value
73 | return event
74 |
75 |
76 | # core functions
77 | def corpus2event_remi_v2(path_infile, path_outfile):
78 | '''
79 | <<< REMI v2 >>>
80 | task: 2 track
81 | 1: piano (note + tempo + chord)
82 | ---
83 | remove duplicate position tokens
84 | '''
85 | data = pickle.load(open(path_infile, 'rb'))
86 |
87 | # global tag
88 | global_end = data['metadata']['last_bar'] * BAR_RESOL
89 |
90 | # process
91 | final_sequence = []
92 | for bar_step in range(0, global_end, BAR_RESOL):
93 | final_sequence.append(create_event('Bar', None))
94 |
95 | # --- piano track --- #
96 | for timing in range(bar_step, bar_step + BAR_RESOL, TICK_RESOL):
97 | pos_events = []
98 |
99 | # unpack
100 | t_chords = data['chords'][timing]
101 | t_tempos = data['tempos'][timing]
102 | t_notes = data['notes'][0][timing] # piano track
103 |
104 | # chord
105 | if len(t_chords):
106 | root, quality, bass = t_chords[0].text.split('_')
107 | pos_events.append(create_event('Chord', root+'_'+quality))
108 |
109 | # tempo
110 | if len(t_tempos):
111 | pos_events.append(create_event('Tempo', t_tempos[0].tempo))
112 |
113 | # note
114 | if len(t_notes):
115 | for note in t_notes:
116 | pos_events.extend([
117 | create_event('Note_Pitch', note.pitch),
118 | create_event('Note_Velocity', note.velocity),
119 | create_event('Note_Duration', note.duration),
120 | ])
121 |
122 | # collect & beat
123 | if len(pos_events):
124 | final_sequence.append(
125 | create_event('Beat', (timing-bar_step)//TICK_RESOL))
126 | final_sequence.extend(pos_events)
127 |
128 | # BAR ending
129 | final_sequence.append(create_event('Bar', None))
130 |
131 | # EOS
132 | final_sequence.append(create_event('EOS', None))
133 |
134 | # save
135 | fn = os.path.basename(path_outfile)
136 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True)
137 | pickle.dump(final_sequence, open(path_outfile, 'wb'))
138 | return len(final_sequence)
139 |
140 |
141 | if __name__ == '__main__':
142 | # paths
143 | path_root = './ailab17k_from-scratch_remi'
144 | path_indir = '../../../corpus'
145 | path_outdir = os.path.join(path_root, 'events')
146 | os.makedirs(path_outdir, exist_ok=True)
147 |
148 | # list files
149 | midifiles = traverse_dir(
150 | path_indir,
151 | extension=('pkl'),
152 | is_pure=True,
153 | is_sort=True)
154 | n_files = len(midifiles)
155 |     print('num files:', n_files)
156 |
157 | # run all
158 | len_list = []
159 | for fidx in range(n_files):
160 | path_midi = midifiles[fidx]
161 | print('{}/{}'.format(fidx, n_files))
162 |
163 | # paths
164 | path_infile = os.path.join(path_indir, path_midi)
165 | path_outfile = os.path.join(path_outdir, path_midi)
166 |
167 | # proc
168 | num_tokens = corpus2event_remi_v2(path_infile, path_outfile)
169 | print(' > num_token:', num_tokens)
170 | len_list.append(num_tokens)
171 |
172 | # plot
173 | plot_hist(
174 | len_list,
175 | os.path.join(path_root, 'num_tokens.png')
176 | )
--------------------------------------------------------------------------------
/dataset/representations/uncond/remi/events2words.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 |
5 |
6 | def traverse_dir(
7 | root_dir,
8 | extension=('mid', 'MID'),
9 | amount=None,
10 | str_=None,
11 | is_pure=False,
12 | verbose=False,
13 | is_sort=False,
14 | is_ext=True):
15 | if verbose:
16 | print('[*] Scanning...')
17 | file_list = []
18 | cnt = 0
19 | for root, _, files in os.walk(root_dir):
20 | for file in files:
21 | if file.endswith(extension):
22 | if (amount is not None) and (cnt == amount):
23 | break
24 | if str_ is not None:
25 | if str_ not in file:
26 | continue
27 | mix_path = os.path.join(root, file)
28 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
29 | if not is_ext:
30 | ext = pure_path.split('.')[-1]
31 | pure_path = pure_path[:-(len(ext)+1)]
32 | if verbose:
33 | print(pure_path)
34 | file_list.append(pure_path)
35 | cnt += 1
36 | if verbose:
37 | print('Total: %d files' % len(file_list))
38 | print('Done!!!')
39 | if is_sort:
40 | file_list.sort()
41 | return file_list
42 |
43 |
44 | if __name__ == '__main__':
45 | # paths
46 | path_root = 'ailab17k_from-scratch_remi'
47 | path_indir = os.path.join(path_root, 'events')
48 | path_outdir = os.path.join(path_root, 'words')
49 | path_dictionary = os.path.join(path_root, 'dictionary.pkl')
50 | os.makedirs(path_outdir, exist_ok=True)
51 |
52 | # list files
53 | eventfiles = traverse_dir(
54 | path_indir,
55 | is_pure=True,
56 | is_sort=True,
57 | extension=('pkl'))
58 | n_files = len(eventfiles)
59 |     print('num files:', n_files)
60 |
61 | # --- generate dictionary --- #
62 | print(' [*] generating dictionary')
63 | all_events = []
64 | for file in eventfiles:
65 | for event in pickle.load(open(os.path.join(path_indir, file), 'rb')):
66 | all_events.append('{}_{}'.format(event['name'], event['value']))
67 |
68 | # build
69 | unique_events = sorted(set(all_events), key=lambda x: (not isinstance(x, int), x))
70 | event2word = {key: i for i, key in enumerate(unique_events)}
71 | word2event = {i: key for i, key in enumerate(unique_events)}
72 | print(' > num classes:', len(word2event))
73 |
74 | # save
75 | pickle.dump((event2word, word2event), open(path_dictionary, 'wb'))
76 |
77 |     # --- convert events to words --- #
78 | event2word, word2event = pickle.load(open(path_dictionary, 'rb'))
79 | for fidx, file in enumerate(eventfiles):
80 | print('{}/{}'.format(fidx, n_files))
81 |
82 | # events to words
83 | path_infile = os.path.join(path_indir, file)
84 | events = pickle.load(open(path_infile, 'rb'))
85 | words = []
86 | for event in events:
87 | word = event2word['{}_{}'.format(event['name'], event['value'])]
88 | words.append(word)
89 |
90 | # save
91 | path_outfile = os.path.join(path_outdir, file + '.npy')
92 | fn = os.path.basename(path_outfile)
93 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True)
94 | np.save(path_outfile, words)
95 |
--------------------------------------------------------------------------------
/dataset/representations/uncond/remi/valid_fn_idx_map.json:
--------------------------------------------------------------------------------
1 | {"fn2idx": {"Animenzzz/Koibumi (Shizuru's Theme) - Rewrite Soundtrack": 0, "DooPiano/PRODUCE 101 _ \u1109\u1173\u11af\u1105\u1166\u110b\u1175\u1110\u1173 - Oh Little Girl (\u110b\u1169\u1105\u1175\u1110\u1173\u11af\u1100\u1165\u11af) Piano Cover": 1, "DooPiano/\u1105\u1166\u1103\u1173\u1107\u1166\u11af\u1107\u1166\u11ba (Red Velvet) - Peek A Boo (Happy_\u1112\u1162\u1111\u1175 Ver.) Piano Cover": 2, "DooPiano/BLACKPINK - \u1104\u116e\u1103\u116e\u1104\u116e\u1103\u116e (DDU-DU DDU-DU) Piano Cover": 3, "TheTheorist/Flume Ft. Chet Faker - Drop The Game": 4, "MusicBand-Guide/\u00b5\u201a\u00b1\u00cd\u2265\u00cf HANA - \u00df\u2014\u221eO\u00df\u2044\u00b6\u20ac\u00a7v - \u00ba@\u2202\u221e '\u00ae\u0153\u00c6{\u00b6\u00ca\u2122\u00c32' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 5, "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00a1\u00ac\u00a1\u00ac\u00a9p\u2211R\u00df\u2044 Thanks for Your Love - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 6, "MusicBand-Guide/\u2202\u00bf\u220f\u00d9\u00b1\u00cd\u00d8\u00d9 Lulu - \u2022\u02db\u2265\u00a3\u00b5\u03c0\u00dfA\u00a7F - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 7, "MusicBand-Guide/\u2202P\u00a7@\u00d8\u00cb - \u03a9\u2013\u2022\u02dd\u00aa\u00b0\u00dfA\u00b6n - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 8, "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u110c\u1169\u11c2\u1103\u1161\u1100\u1169 \u1106\u1161\u11af\u1112\u1162 (Tell Me You Love Me) Piano Cover": 9, "MusicBand-Guide/Taeyeon \u592a\u598d \ud0dc\uc5f0 - Four Seasons \uc0ac\uacc4 - Piano Tutorial \u92fc\u7434\u6559\u5b78 \ud53c\uc544\ub178 [HQ] Synthesia": 10, "DooPiano/\u1110\u1162\u110b\u1167\u11ab (TAEYEON) - \u1109\u1161\u1100\u1168 (Four Seasons) Piano Cover": 11, "TheTheorist/Billie Eilish - listen before i go": 12, "DooPiano/\u1109\u1166\u1107\u1173\u11ab\u1110\u1175\u11ab (SEVENTEEN) - \u110b\u116e\u11af\u1100\u1169 \u1109\u1175\u11c1\u110c\u1175 \u110b\u1161\u11ad\u110b\u1161 (Don't Wanna Cry) Piano Cover": 13, "TheTheorist/Rihanna ft. 
Kanye West & Paul McCartney - FourFiveSeconds": 14, "MusicBand-Guide/\u95dc\u5586 Grady - \u60f3\u4f60\u7684\u591c (\u672a\u7720\u7248) Miss You Tonight - Piano Tutorial \u92fc\u7434\u6559\u5b78 [HQ] Synthesia": 15, "MusicBand-Guide/\u00b5\u00ff\u00b1\u00b7\u00b6t & \u2211\u00ae\u00a9v\u03a9n - \u221e\u00cd\u00a7\u02dd\u00aaP\u00a7^\u00a7\u00a2 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 16, "MusicBand-Guide/\u2122L\u00b4T\u2265\u00ab JJ Lin & \u03a9\u2264\u00ae\u00d9\u00df\u221e A-Sa - \u00a7p\u221es\u222b\u20ac Dimples - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 17, "MusicBand-Guide/YNW Melly - Mama Cry - Piano Tutorial [HQ] Synthesia": 18, "MusicBand-Guide/\u03a9\u2264\u221e\u2211\u2202\u00c6 - \u2202V\u00ae\u201d\u2202V\u00a7\u00a3\u00bf\u00a5 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 19, "MusicBand-Guide/\u00bfF\u00b4\u2265\u00c6\u00ca Janice Yan - \u00bfu\u2202\u00c6\u03c0D\u00dfO Graceful Goodbye - \u03c0q\u00b5\u00af\u00ba@ '20\u00a7\u00df\u00b4\u00b7' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 20, "MusicBand-Guide/\u00aa\u00bb\u00a1{ - \u2211N\u221a\u00af\u2022\u2260 - \u03c0q\u00b5\u00af\u00ba@ '\u2265\u00d8\u00b1\u00b0\u2022O' \u00a5\u00b0\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 21, "MusicBand-Guide/\u00df\u0131\u2122v\u00df\u00a0 - I Know You Know - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2122\u222b\u00d8u\u2122B\u00a7\u00d5' \u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 22, "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00dfA\u00a8O\u00d8u\u2122\u222b\u00ac\u02dc\u2202}\u00df\u2044 You Are Leaving Me - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 23, "MusicBand-Guide/\u2264\u02c6\u00a7\u00c2\u03a9\u00b4 Karen Mok - \u00df\ufb02\u2211n - \u03c0q\u00b5\u00af\u00ba@ '\u00df\ufb02\u2211n' \u00b6P\u00b6W\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 24, "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a9P\u00f8\u2265\u2260\u0131 Eric Chou - \u00a7@\u00ba\u00c0\u00a8\u00b8\u0192R Forever Beautiful - \u00d8\u00aa\u00a8\u0131\u00b5\u2211\u00b1a\u00ae\u2248\u00bf\u02d8\u00ae\u00e6\u2122v\u00b4\u2248\u00e6\u2026\u00a8\u00b0\u221e\u00a0\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 25, "TheTheorist/Post Malone & Swae Lee - Sunflower": 26, "MusicBand-Guide/\u03c0p\u00b4B\u00a7\ufb02 Yuxin Lei - \u221eO\u00a9\u00bf After June [2014\u00b6~RAiNBOW\u2260p\u03c0\u222b '\u00e6\u00cc' \u00b1M\u00f8\u00cb\u2264\u00b6\u2211~\u00a9u\u2022D\u2022\u00a5] - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 27, "MusicBand-Guide/MappleZS - Star River In Your Eyes \u222b\u00b0\u2022\u00ff\u00a8P\u2122e - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 28, "MusicBand-Guide/\u9234\u6728\u5be6\u88cf - \u591c\u7a7a [\u6200\u611b\u5c0f\u884c\u661fED] - Piano Tutorial \u92fc\u7434\u6559\u5b78 \u30d4\u30a2\u30ce\u6307\u5c0e [HQ] Synthesia": 29, "DooPiano/\u110b\u1161\u110b\u1175\u110b\u1172 (IU) - \u1107\u1161\u11b7\u1111\u1167\u11ab\u110c\u1175 (Through the Night) Piano Cover": 30, "MusicBand-Guide/\u220ft\u00b5\u2264\u2022\u20ac Saint ft. 
\u2202\u00bf\u00a8\u00b8\u00a8\u221a - \u00aa\u00b0\u00a7\u00a3\u2022X\u00a7f\u2122\u222b\u2211Q\u00a9\u00bf - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 31, "MusicBand-Guide/\u00a9P\u2202\u00ab\u2202\u00d8 - \u00a7W\u2022j\u00a7\u00df\u00f8\u2019 - \u03c0q\u00b5\u00af\u00ba@ '\u00a7W\u2022j\u00b1\u00b0\u222bq' \u00a5\u00b0\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 32, "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u1111\u1173\u1105\u1175\u110c\u1175\u110b\u1161 (Freesia) Piano Cover": 33, "DooPiano/BTS JIN (\u1107\u1161\u11bc\u1110\u1161\u11ab\u1109\u1169\u1102\u1167\u11ab\u1103\u1161\u11ab \u110c\u1175\u11ab) - \u110b\u1175 \u1107\u1161\u11b7 (Tonight) Piano Cover": 34, "TheTheorist/The Weeknd - What You Need": 35, "MusicBand-Guide/\u2265\u00d8\u222b\u02c6\u2260s Cheer Chen - \u00df\u2044\u2265\ufb02\u2248w\u00a7W\u00dfA\u00c6\u2026\u2122\u222b\u00a7\u222b\u00a7\ufb02\u00a8\u00b0\u221e\u00a0 - \u03c0q\u00bav '\u2265\ufb02\u2248w\u00dfA' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 36, "MusicBand-Guide/\u00b6\u2202\u2022D\u2211R - \u00ac\u221a\u00a7\u00a3\u00b6\u00cc\u2122\u222b\u00a7\ufb02\u220f\u0131 - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2022u\u2265\ufb02\u2248w\u00dfA' \u00a7\u02d8\u00bfY\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 37, "MusicBand-Guide/\u2265\u00d8\u00b4\u2265\u00ae\u2265 Eason Chen - \u2248\u02dd\u00df\u2044\u00d8d\u00b6b\u00dfA\u00ae\u2260\u221a\u2030 - \u03c0q\u00bav '\u00ac\\\u00a5\u00c1\u00a7H' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 38, "MusicBand-Guide/Justin Hurwitz - Quarantine - from movie 'First Man' - Piano Tutorial [HQ] Synthesia": 39, "MusicBand-Guide/\u00a1\u02d9\u03a9U & \u00b6\u00f8\u00a8M\u00aaT - \u2022H\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq - \u03c0q\u00b5\u00af\u00ba@ '\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 40, "MusicBand-Guide/\u2265\u00d8\u03c0\u2248\u00e6\u00cf Ella Chen - 30\u221e\u2044 Age of 30 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 41, "MusicBand-Guide/\u00a1\u00df\u00a7\u00df\u00a1\u00e6 Joker Xue - \u221e\u25ca\u00a7\u2044 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 42, "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u2265\u03a9\u2022J He-R - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 43, "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a7p\u00c6\u2026\u00a9h\u00c6Q - \u2211R\u00ba\u2039 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia": 44, "TheTheorist/Drake & Majid Jordan - Summer's Over Interlude": 45, "MusicBand-Guide/\u00b6\u00c3\u00a8z\u2022\u00bb\u00c6v Kenshi Yonezu - Lemon - Piano Tutorial [HQ] Synthesia": 46, "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u00a5X\u00a7\u00bf\u00a7\u00df\u00a5X You Complete Me - '\u2122\u00b7\u2022\u201c\u00a7j\u00a7H\u00ac\u2021\u00aek\u00b4\u0192' \u03c0q\u00bav\u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 47, "MusicBand-Guide/La La Land - Late for the Date 'Mia & Sebastian\u00b0\u00b6s Theme' - Piano Tutorial [HQ] Synthesia": 48, 
"MusicBand-Guide/G.E.M. \u00e6H\u00b5\u00b5\u00a5\u2014 - \u00b5e - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia": 49}, "idx2fn": {"0": "Animenzzz/Koibumi (Shizuru's Theme) - Rewrite Soundtrack", "1": "DooPiano/PRODUCE 101 _ \u1109\u1173\u11af\u1105\u1166\u110b\u1175\u1110\u1173 - Oh Little Girl (\u110b\u1169\u1105\u1175\u1110\u1173\u11af\u1100\u1165\u11af) Piano Cover", "2": "DooPiano/\u1105\u1166\u1103\u1173\u1107\u1166\u11af\u1107\u1166\u11ba (Red Velvet) - Peek A Boo (Happy_\u1112\u1162\u1111\u1175 Ver.) Piano Cover", "3": "DooPiano/BLACKPINK - \u1104\u116e\u1103\u116e\u1104\u116e\u1103\u116e (DDU-DU DDU-DU) Piano Cover", "4": "TheTheorist/Flume Ft. Chet Faker - Drop The Game", "5": "MusicBand-Guide/\u00b5\u201a\u00b1\u00cd\u2265\u00cf HANA - \u00df\u2014\u221eO\u00df\u2044\u00b6\u20ac\u00a7v - \u00ba@\u2202\u221e '\u00ae\u0153\u00c6{\u00b6\u00ca\u2122\u00c32' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "6": "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00a1\u00ac\u00a1\u00ac\u00a9p\u2211R\u00df\u2044 Thanks for Your Love - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "7": "MusicBand-Guide/\u2202\u00bf\u220f\u00d9\u00b1\u00cd\u00d8\u00d9 Lulu - \u2022\u02db\u2265\u00a3\u00b5\u03c0\u00dfA\u00a7F - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "8": "MusicBand-Guide/\u2202P\u00a7@\u00d8\u00cb - \u03a9\u2013\u2022\u02dd\u00aa\u00b0\u00dfA\u00b6n - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "9": "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u110c\u1169\u11c2\u1103\u1161\u1100\u1169 \u1106\u1161\u11af\u1112\u1162 (Tell Me You Love Me) Piano Cover", "10": "MusicBand-Guide/Taeyeon \u592a\u598d \ud0dc\uc5f0 - Four Seasons \uc0ac\uacc4 - Piano Tutorial \u92fc\u7434\u6559\u5b78 \ud53c\uc544\ub178 [HQ] Synthesia", "11": "DooPiano/\u1110\u1162\u110b\u1167\u11ab (TAEYEON) - \u1109\u1161\u1100\u1168 (Four Seasons) Piano Cover", "12": "TheTheorist/Billie Eilish - listen before i go", "13": "DooPiano/\u1109\u1166\u1107\u1173\u11ab\u1110\u1175\u11ab (SEVENTEEN) - \u110b\u116e\u11af\u1100\u1169 \u1109\u1175\u11c1\u110c\u1175 \u110b\u1161\u11ad\u110b\u1161 (Don't Wanna Cry) Piano Cover", "14": "TheTheorist/Rihanna ft. 
Kanye West & Paul McCartney - FourFiveSeconds", "15": "MusicBand-Guide/\u95dc\u5586 Grady - \u60f3\u4f60\u7684\u591c (\u672a\u7720\u7248) Miss You Tonight - Piano Tutorial \u92fc\u7434\u6559\u5b78 [HQ] Synthesia", "16": "MusicBand-Guide/\u00b5\u00ff\u00b1\u00b7\u00b6t & \u2211\u00ae\u00a9v\u03a9n - \u221e\u00cd\u00a7\u02dd\u00aaP\u00a7^\u00a7\u00a2 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "17": "MusicBand-Guide/\u2122L\u00b4T\u2265\u00ab JJ Lin & \u03a9\u2264\u00ae\u00d9\u00df\u221e A-Sa - \u00a7p\u221es\u222b\u20ac Dimples - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "18": "MusicBand-Guide/YNW Melly - Mama Cry - Piano Tutorial [HQ] Synthesia", "19": "MusicBand-Guide/\u03a9\u2264\u221e\u2211\u2202\u00c6 - \u2202V\u00ae\u201d\u2202V\u00a7\u00a3\u00bf\u00a5 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "20": "MusicBand-Guide/\u00bfF\u00b4\u2265\u00c6\u00ca Janice Yan - \u00bfu\u2202\u00c6\u03c0D\u00dfO Graceful Goodbye - \u03c0q\u00b5\u00af\u00ba@ '20\u00a7\u00df\u00b4\u00b7' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "21": "MusicBand-Guide/\u00aa\u00bb\u00a1{ - \u2211N\u221a\u00af\u2022\u2260 - \u03c0q\u00b5\u00af\u00ba@ '\u2265\u00d8\u00b1\u00b0\u2022O' \u00a5\u00b0\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "22": "MusicBand-Guide/\u00df\u0131\u2122v\u00df\u00a0 - I Know You Know - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2122\u222b\u00d8u\u2122B\u00a7\u00d5' \u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "23": "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00dfA\u00a8O\u00d8u\u2122\u222b\u00ac\u02dc\u2202}\u00df\u2044 You Are Leaving Me - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "24": "MusicBand-Guide/\u2264\u02c6\u00a7\u00c2\u03a9\u00b4 Karen Mok - \u00df\ufb02\u2211n - \u03c0q\u00b5\u00af\u00ba@ '\u00df\ufb02\u2211n' \u00b6P\u00b6W\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "25": "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a9P\u00f8\u2265\u2260\u0131 Eric Chou - \u00a7@\u00ba\u00c0\u00a8\u00b8\u0192R Forever Beautiful - \u00d8\u00aa\u00a8\u0131\u00b5\u2211\u00b1a\u00ae\u2248\u00bf\u02d8\u00ae\u00e6\u2122v\u00b4\u2248\u00e6\u2026\u00a8\u00b0\u221e\u00a0\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "26": "TheTheorist/Post Malone & Swae Lee - Sunflower", "27": "MusicBand-Guide/\u03c0p\u00b4B\u00a7\ufb02 Yuxin Lei - \u221eO\u00a9\u00bf After June [2014\u00b6~RAiNBOW\u2260p\u03c0\u222b '\u00e6\u00cc' \u00b1M\u00f8\u00cb\u2264\u00b6\u2211~\u00a9u\u2022D\u2022\u00a5] - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "28": "MusicBand-Guide/MappleZS - Star River In Your Eyes \u222b\u00b0\u2022\u00ff\u00a8P\u2122e - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "29": "MusicBand-Guide/\u9234\u6728\u5be6\u88cf - \u591c\u7a7a [\u6200\u611b\u5c0f\u884c\u661fED] - Piano Tutorial \u92fc\u7434\u6559\u5b78 \u30d4\u30a2\u30ce\u6307\u5c0e [HQ] Synthesia", "30": "DooPiano/\u110b\u1161\u110b\u1175\u110b\u1172 (IU) - \u1107\u1161\u11b7\u1111\u1167\u11ab\u110c\u1175 (Through the Night) Piano Cover", "31": "MusicBand-Guide/\u220ft\u00b5\u2264\u2022\u20ac Saint ft. 
\u2202\u00bf\u00a8\u00b8\u00a8\u221a - \u00aa\u00b0\u00a7\u00a3\u2022X\u00a7f\u2122\u222b\u2211Q\u00a9\u00bf - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "32": "MusicBand-Guide/\u00a9P\u2202\u00ab\u2202\u00d8 - \u00a7W\u2022j\u00a7\u00df\u00f8\u2019 - \u03c0q\u00b5\u00af\u00ba@ '\u00a7W\u2022j\u00b1\u00b0\u222bq' \u00a5\u00b0\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "33": "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u1111\u1173\u1105\u1175\u110c\u1175\u110b\u1161 (Freesia) Piano Cover", "34": "DooPiano/BTS JIN (\u1107\u1161\u11bc\u1110\u1161\u11ab\u1109\u1169\u1102\u1167\u11ab\u1103\u1161\u11ab \u110c\u1175\u11ab) - \u110b\u1175 \u1107\u1161\u11b7 (Tonight) Piano Cover", "35": "TheTheorist/The Weeknd - What You Need", "36": "MusicBand-Guide/\u2265\u00d8\u222b\u02c6\u2260s Cheer Chen - \u00df\u2044\u2265\ufb02\u2248w\u00a7W\u00dfA\u00c6\u2026\u2122\u222b\u00a7\u222b\u00a7\ufb02\u00a8\u00b0\u221e\u00a0 - \u03c0q\u00bav '\u2265\ufb02\u2248w\u00dfA' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "37": "MusicBand-Guide/\u00b6\u2202\u2022D\u2211R - \u00ac\u221a\u00a7\u00a3\u00b6\u00cc\u2122\u222b\u00a7\ufb02\u220f\u0131 - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2022u\u2265\ufb02\u2248w\u00dfA' \u00a7\u02d8\u00bfY\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "38": "MusicBand-Guide/\u2265\u00d8\u00b4\u2265\u00ae\u2265 Eason Chen - \u2248\u02dd\u00df\u2044\u00d8d\u00b6b\u00dfA\u00ae\u2260\u221a\u2030 - \u03c0q\u00bav '\u00ac\\\u00a5\u00c1\u00a7H' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "39": "MusicBand-Guide/Justin Hurwitz - Quarantine - from movie 'First Man' - Piano Tutorial [HQ] Synthesia", "40": "MusicBand-Guide/\u00a1\u02d9\u03a9U & \u00b6\u00f8\u00a8M\u00aaT - \u2022H\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq - \u03c0q\u00b5\u00af\u00ba@ '\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "41": "MusicBand-Guide/\u2265\u00d8\u03c0\u2248\u00e6\u00cf Ella Chen - 30\u221e\u2044 Age of 30 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "42": "MusicBand-Guide/\u00a1\u00df\u00a7\u00df\u00a1\u00e6 Joker Xue - \u221e\u25ca\u00a7\u2044 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "43": "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u2265\u03a9\u2022J He-R - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "44": "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a7p\u00c6\u2026\u00a9h\u00c6Q - \u2211R\u00ba\u2039 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "45": "TheTheorist/Drake & Majid Jordan - Summer's Over Interlude", "46": "MusicBand-Guide/\u00b6\u00c3\u00a8z\u2022\u00bb\u00c6v Kenshi Yonezu - Lemon - Piano Tutorial [HQ] Synthesia", "47": "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u00a5X\u00a7\u00bf\u00a7\u00df\u00a5X You Complete Me - '\u2122\u00b7\u2022\u201c\u00a7j\u00a7H\u00ac\u2021\u00aek\u00b4\u0192' \u03c0q\u00bav\u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "48": "MusicBand-Guide/La La Land - Late for the Date 'Mia & Sebastian\u00b0\u00b6s Theme' - Piano 
Tutorial [HQ] Synthesia", "49": "MusicBand-Guide/G.E.M. \u00e6H\u00b5\u00b5\u00a5\u2014 - \u00b5e - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia"}}
--------------------------------------------------------------------------------
/dataset/representations/uncond/validation_songs.json:
--------------------------------------------------------------------------------
1 | ["MusicBand-Guide/\u03c0p\u00b4B\u00a7\ufb02 Yuxin Lei - \u221eO\u00a9\u00bf After June [2014\u00b6~RAiNBOW\u2260p\u03c0\u222b '\u00e6\u00cc' \u00b1M\u00f8\u00cb\u2264\u00b6\u2211~\u00a9u\u2022D\u2022\u00a5] - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "DooPiano/\u110b\u1161\u110b\u1175\u110b\u1172 (IU) - \u1107\u1161\u11b7\u1111\u1167\u11ab\u110c\u1175 (Through the Night) Piano Cover", "TheTheorist/Post Malone & Swae Lee - Sunflower", "MusicBand-Guide/\u00b5\u00ff\u00b1\u00b7\u00b6t & \u2211\u00ae\u00a9v\u03a9n - \u221e\u00cd\u00a7\u02dd\u00aaP\u00a7^\u00a7\u00a2 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "Animenzzz/Koibumi (Shizuru's Theme) - Rewrite Soundtrack", "DooPiano/BTS JIN (\u1107\u1161\u11bc\u1110\u1161\u11ab\u1109\u1169\u1102\u1167\u11ab\u1103\u1161\u11ab \u110c\u1175\u11ab) - \u110b\u1175 \u1107\u1161\u11b7 (Tonight) Piano Cover", "MusicBand-Guide/MappleZS - Star River In Your Eyes \u222b\u00b0\u2022\u00ff\u00a8P\u2122e - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u03a9\u2264\u221e\u2211\u2202\u00c6 - \u2202V\u00ae\u201d\u2202V\u00a7\u00a3\u00bf\u00a5 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u00df\u0131\u2122v\u00df\u00a0 - I Know You Know - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2122\u222b\u00d8u\u2122B\u00a7\u00d5' \u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "TheTheorist/Rihanna ft. Kanye West & Paul McCartney - FourFiveSeconds", "MusicBand-Guide/\u2265\u00d8\u03c0\u2248\u00e6\u00cf Ella Chen - 30\u221e\u2044 Age of 30 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "DooPiano/BLACKPINK - \u1104\u116e\u1103\u116e\u1104\u116e\u1103\u116e (DDU-DU DDU-DU) Piano Cover", "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a9P\u00f8\u2265\u2260\u0131 Eric Chou - \u00a7@\u00ba\u00c0\u00a8\u00b8\u0192R Forever Beautiful - \u00d8\u00aa\u00a8\u0131\u00b5\u2211\u00b1a\u00ae\u2248\u00bf\u02d8\u00ae\u00e6\u2122v\u00b4\u2248\u00e6\u2026\u00a8\u00b0\u221e\u00a0\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u00a1\u02d9\u03a9U & \u00b6\u00f8\u00a8M\u00aaT - \u2022H\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq - \u03c0q\u00b5\u00af\u00ba@ '\u00a7H\u2022\u00a1\u2122\u222b\u00b6W\u220fq' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u00a5X\u00a7\u00bf\u00a7\u00df\u00a5X You Complete Me - '\u2122\u00b7\u2022\u201c\u00a7j\u00a7H\u00ac\u2021\u00aek\u00b4\u0192' \u03c0q\u00bav\u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u2202P\u00a7@\u00d8\u00cb - \u03a9\u2013\u2022\u02dd\u00aa\u00b0\u00dfA\u00b6n - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00a1\u00ac\u00a1\u00ac\u00a9p\u2211R\u00df\u2044 Thanks for Your Love - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/La La Land - Late for the Date 'Mia & Sebastian\u00b0\u00b6s Theme' - Piano Tutorial [HQ] Synthesia", "DooPiano/\u1110\u1162\u110b\u1167\u11ab (TAEYEON) - \u1109\u1161\u1100\u1168 (Four Seasons) Piano Cover", "MusicBand-Guide/YNW Melly - Mama Cry - Piano Tutorial [HQ] 
Synthesia", "MusicBand-Guide/Taeyeon \u592a\u598d \ud0dc\uc5f0 - Four Seasons \uc0ac\uacc4 - Piano Tutorial \u92fc\u7434\u6559\u5b78 \ud53c\uc544\ub178 [HQ] Synthesia", "MusicBand-Guide/\u00a1\u00df\u00a7\u00df\u00a1\u00e6 Joker Xue - \u221e\u25ca\u00a7\u2044 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "DooPiano/PRODUCE 101 _ \u1109\u1173\u11af\u1105\u1166\u110b\u1175\u1110\u1173 - Oh Little Girl (\u110b\u1169\u1105\u1175\u1110\u1173\u11af\u1100\u1165\u11af) Piano Cover", "MusicBand-Guide/\u220ft\u00b5\u2264\u2022\u20ac Saint ft. \u2202\u00bf\u00a8\u00b8\u00a8\u221a - \u00aa\u00b0\u00a7\u00a3\u2022X\u00a7f\u2122\u222b\u2211Q\u00a9\u00bf - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u110c\u1169\u11c2\u1103\u1161\u1100\u1169 \u1106\u1161\u11af\u1112\u1162 (Tell Me You Love Me) Piano Cover", "MusicBand-Guide/\u00a9P\u2202\u00ab\u2202\u00d8 - \u00a7W\u2022j\u00a7\u00df\u00f8\u2019 - \u03c0q\u00b5\u00af\u00ba@ '\u00a7W\u2022j\u00b1\u00b0\u222bq' \u00a5\u00b0\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u00b6\u2202\u2022D\u2211R - \u00ac\u221a\u00a7\u00a3\u00b6\u00cc\u2122\u222b\u00a7\ufb02\u220f\u0131 - \u03c0q\u00b5\u00af\u00ba@ '\u00df\u2044\u2022u\u2265\ufb02\u2248w\u00dfA' \u00a7\u02d8\u00bfY\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u00bfF\u00b4\u2265\u00c6\u00ca Janice Yan - \u00bfu\u2202\u00c6\u03c0D\u00dfO Graceful Goodbye - \u03c0q\u00b5\u00af\u00ba@ '20\u00a7\u00df\u00b4\u00b7' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u2265\u00d8\u222b\u02c6\u2260s Cheer Chen - \u00df\u2044\u2265\ufb02\u2248w\u00a7W\u00dfA\u00c6\u2026\u2122\u222b\u00a7\u222b\u00a7\ufb02\u00a8\u00b0\u221e\u00a0 - \u03c0q\u00bav '\u2265\ufb02\u2248w\u00dfA' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/[\u00b5^\u221a\u2013\u2122\u00a9] \u00a7p\u00c6\u2026\u00a9h\u00c6Q - \u2211R\u00ba\u2039 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "DooPiano/\u1105\u1166\u1103\u1173\u1107\u1166\u11af\u1107\u1166\u11ba (Red Velvet) - Peek A Boo (Happy_\u1112\u1162\u1111\u1175 Ver.) Piano Cover", "TheTheorist/Drake & Majid Jordan - Summer's Over Interlude", "MusicBand-Guide/\u00b6\u00c3\u00a8z\u2022\u00bb\u00c6v Kenshi Yonezu - Lemon - Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/G.E.M. 
\u00e6H\u00b5\u00b5\u00a5\u2014 - \u00b5e - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "DooPiano/\u1107\u1169\u11af\u1108\u1161\u11af\u1100\u1161\u11ab\u1109\u1161\u110e\u116e\u11ab\u1100\u1175 (Bolbbalgan4) - \u1111\u1173\u1105\u1175\u110c\u1175\u110b\u1161 (Freesia) Piano Cover", "MusicBand-Guide/\u2264\u02c6\u00a7\u00c2\u03a9\u00b4 Karen Mok - \u00df\ufb02\u2211n - \u03c0q\u00b5\u00af\u00ba@ '\u00df\ufb02\u2211n' \u00b6P\u00b6W\u2022D\u221aD\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "DooPiano/\u1109\u1166\u1107\u1173\u11ab\u1110\u1175\u11ab (SEVENTEEN) - \u110b\u116e\u11af\u1100\u1169 \u1109\u1175\u11c1\u110c\u1175 \u110b\u1161\u11ad\u110b\u1161 (Don't Wanna Cry) Piano Cover", "MusicBand-Guide/\u00f8c\u00bas\u2022\u00da Crowd Lu - \u2265\u03a9\u2022J He-R - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u2265\u00d8\u00b4\u2265\u00ae\u2265 Eason Chen - \u2248\u02dd\u00df\u2044\u00d8d\u00b6b\u00dfA\u00ae\u2260\u221a\u2030 - \u03c0q\u00bav '\u00ac\\\u00a5\u00c1\u00a7H' \u2022D\u221aD\u00b6\u00b1 - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "TheTheorist/Flume Ft. Chet Faker - Drop The Game", "TheTheorist/Billie Eilish - listen before i go", "TheTheorist/The Weeknd - What You Need", "MusicBand-Guide/\u2202\u00bf\u220f\u00d9\u00b1\u00cd\u00d8\u00d9 Lulu - \u2022\u02db\u2265\u00a3\u00b5\u03c0\u00dfA\u00a7F - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u00aa\u00bb\u00a1{ - \u2211N\u221a\u00af\u2022\u2260 - \u03c0q\u00b5\u00af\u00ba@ '\u2265\u00d8\u00b1\u00b0\u2022O' \u00a5\u00b0\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u00a1\u00ac\u00a9M\u00a9\u2202 R-chord - \u00dfA\u00a8O\u00d8u\u2122\u222b\u00ac\u02dc\u2202}\u00df\u2044 You Are Leaving Me - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/\u9234\u6728\u5be6\u88cf - \u591c\u7a7a [\u6200\u611b\u5c0f\u884c\u661fED] - Piano Tutorial \u92fc\u7434\u6559\u5b78 \u30d4\u30a2\u30ce\u6307\u5c0e [HQ] Synthesia", "MusicBand-Guide/\u2122L\u00b4T\u2265\u00ab JJ Lin & \u03a9\u2264\u00ae\u00d9\u00df\u221e A-Sa - \u00a7p\u221es\u222b\u20ac Dimples - \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u00b5\u201a\u00b1\u00cd\u2265\u00cf HANA - \u00df\u2014\u221eO\u00df\u2044\u00b6\u20ac\u00a7v - \u00ba@\u2202\u221e '\u00ae\u0153\u00c6{\u00b6\u00ca\u2122\u00c32' \u00a7\u02d8\u00df\u00bf\u00b6\u00b1 - Piano Tutorial \u00f8\u02da\u00b5^\u00b1\u2013\u00e6\u00ab [HQ] Synthesia", "MusicBand-Guide/Justin Hurwitz - Quarantine - from movie 'First Man' - Piano Tutorial [HQ] Synthesia", "MusicBand-Guide/\u95dc\u5586 Grady - \u60f3\u4f60\u7684\u591c (\u672a\u7720\u7248) Miss You Tonight - Piano Tutorial \u92fc\u7434\u6559\u5b78 [HQ] Synthesia"]
--------------------------------------------------------------------------------
/dataset/synchronizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import copy
4 | import librosa
5 | import numpy as np
6 | import multiprocessing as mp
7 | from madmom.features.downbeats import DBNDownBeatTrackingProcessor
8 | from madmom.features.downbeats import RNNDownBeatProcessor
9 | from miditoolkit.midi import parser
10 | from miditoolkit.midi.containers import TimeSignature, TempoChange
11 |
12 |
13 | def traverse_dir(
14 | root_dir,
15 | extension=('mid', 'MID', 'midi'),
16 | amount=None,
17 | str_=None,
18 | is_pure=False,
19 | verbose=False,
20 | is_sort=False,
21 | is_ext=True):
22 | if verbose:
23 | print('[*] Scanning...')
24 | file_list = []
25 | cnt = 0
26 | for root, _, files in os.walk(root_dir):
27 | for file in files:
28 | if file.endswith(extension):
29 | if (amount is not None) and (cnt == amount):
30 | break
31 | if str_ is not None:
32 | if str_ not in file:
33 | continue
34 | mix_path = os.path.join(root, file)
35 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
36 | if not is_ext:
37 | ext = pure_path.split('.')[-1]
38 | pure_path = pure_path[:-(len(ext)+1)]
39 | if verbose:
40 | print(pure_path)
41 | file_list.append(pure_path)
42 | cnt += 1
43 | if verbose:
44 | print('Total: %d files' % len(file_list))
45 | print('Done!!!')
46 | if is_sort:
47 | file_list.sort()
48 | return file_list
49 |
50 |
51 | def get_instruments_abs_timing(instruments, tick_to_time):
52 | return convert_instruments_timing_from_sym_to_abs(instruments, tick_to_time)
53 |
54 |
55 | def convert_instruments_timing_from_sym_to_abs(instruments, tick_to_time):
56 | proc_instrs = copy.deepcopy(instruments)
57 | for instr in proc_instrs:
58 | for note in instr.notes:
59 | note.start = float(tick_to_time[note.start])
60 | note.end = float(tick_to_time[note.end])
61 | return proc_instrs
62 |
63 |
64 | def convert_instruments_timing_from_abs_to_sym(instruments, time_to_tick):
65 | proc_instrs = copy.deepcopy(instruments)
66 | for instr in proc_instrs:
67 | for note in instr.notes:
68 | # find nearest
69 | note.start = find_nearest_np(time_to_tick, note.start)
70 | note.end = find_nearest_np(time_to_tick, note.end)
71 | return proc_instrs
72 |
73 |
74 | def find_nearest_np(array, value):
75 | return (np.abs(array - value)).argmin()
76 |
77 |
78 | def find_first_downbeat(proc_res):
79 |     downbeats = np.where(proc_res[:, 1] == 1)[0]  # row indices labeled as downbeats
80 |     pos = proc_res[downbeats[0], 0]  # time (in seconds) of the first downbeat
81 | return pos
82 |
83 |
84 | def interp_linear(src, target, num, tail=False):
85 | src = float(src)
86 | target = float(target)
87 | step = (target - src) / float(num)
88 | middles = [src + step * i for i in range(1, num)]
89 | res = [src] + middles
90 | if tail:
91 | res += [target]
92 | return res
93 |
94 |
95 | def estimate_beat(path_audio):
96 | proc = DBNDownBeatTrackingProcessor(beats_per_bar=[3, 4], fps=100)
97 | act = RNNDownBeatProcessor()(path_audio)
98 | proc_res = proc(act)
99 | return proc_res
100 |
101 |
102 | def export_audio_with_click(proc_res, path_audio, path_output, sr=44100):
103 | # extract time
104 | times_beat = proc_res[np.where(proc_res[:, 1]!=1)][:, 0]
105 | times_downbeat = proc_res[np.where(proc_res[:, 1]==1)][:, 0]
106 |
107 | # load
108 | y, _ = librosa.core.load(path_audio, sr=sr)
109 |
110 | # click audio
111 | y_beat = librosa.clicks(times=times_beat, sr=sr, click_freq=1200, click_duration=0.5) * 0.6
112 | y_downbeat = librosa.clicks(times=times_downbeat, sr=sr, click_freq=600, click_duration=0.5)
113 |
114 | # merge
115 | max_len = max(len(y), len(y_beat), len(y_downbeat))
116 | y_integrate = np.zeros(max_len)
117 | y_integrate[:len(y_beat)] += y_beat
118 | y_integrate[:len(y_downbeat)] += y_downbeat
119 | y_integrate[:len(y)] += y
120 |
121 | librosa.output.write_wav(path_output, y_integrate, sr)
122 |
123 |
124 | def align_midi(proc_res, path_midi_input, path_midi_output, ticks_per_beat=480):
125 | midi_data = parser.MidiFile(path_midi_input)
126 |
127 | # compute tempo
128 | beats = np.array([0.0] + list(proc_res[:, 0]))
129 | intervals = np.diff(beats)
130 | bpms = 60 / intervals
131 | tempo_info = list(zip(beats[:-1], bpms))
132 |
133 | # get absolute timing of instruments
134 | tick_to_time = midi_data.get_tick_to_time_mapping()
135 | abs_instr = get_instruments_abs_timing(midi_data.instruments, tick_to_time)
136 |
137 | # get end time of file
138 | end_time = midi_data.get_tick_to_time_mapping()[-1]
139 |
140 | # compute time to tick mapping
141 | resample_timing = []
142 | for i in range(len(beats)-1):
143 | start_beat = beats[i]
144 | end_beat = beats[i + 1]
145 | resample_timing += interp_linear(start_beat, end_beat, ticks_per_beat)
146 |
147 | # fill the empty in the tail (using last tick interval)
148 | last_tick_interval = resample_timing[-1] - resample_timing[-2]
149 | cur_time = resample_timing[-1]
150 | while cur_time < end_time:
151 | cur_time += last_tick_interval
152 | resample_timing.append(cur_time)
153 | resample_timing = np.array(resample_timing)
154 |
155 | # new a midifile obj
156 | midi_res = parser.MidiFile()
157 |
158 | # convert abs to sym
159 | sym_instr = convert_instruments_timing_from_abs_to_sym(abs_instr, resample_timing)
160 |
161 | # time signature
162 | first_db_sec = find_first_downbeat(proc_res)
163 | first_db_tick = find_nearest_np(resample_timing, first_db_sec)
164 | time_signature_changes = [TimeSignature(numerator=4, denominator=4, time=int(first_db_tick))]
165 |
166 | # tempo
167 | tempo_changes = []
168 | for pos, bpm in tempo_info:
169 | pos_tick = find_nearest_np(resample_timing, pos)
170 | tempo_changes.append(TempoChange(tempo=float(bpm), time=int(pos_tick)))
171 |
172 | # shift (pickup at the beginning)
173 | shift_align = ticks_per_beat * 4 - first_db_tick
174 |
175 | # apply shift to tempo
176 | for msg in tempo_changes:
177 | msg.time += shift_align
178 |
179 | # apply shift to notes
180 | for instr in sym_instr:
181 | for note in instr.notes:
182 | note.start += shift_align
183 | note.end += shift_align
184 |
185 | # set attributes
186 | midi_res.ticks_per_beat = ticks_per_beat
187 | midi_res.tempo_changes = tempo_changes
188 | midi_res.time_signature_changes = time_signature_changes
189 | midi_res.instruments = sym_instr
190 |
191 | # saving
192 | midi_res.dump(filename=path_midi_output)
193 |
194 |
195 | def analyze(path_midi_input, path_audio_input, path_midi_output, path_audio_output=None):
196 | print(path_midi_input)
197 | # beat tracking
198 | proc_res = estimate_beat(path_audio_input)
199 | # export audio with click
200 | if path_audio_output is not None:
201 | export_audio_with_click(proc_res, path_audio_input, path_audio_output)
202 | # export midi file
203 | align_midi(proc_res, path_midi_input, path_midi_output)
204 |
205 |
206 | if __name__ == '__main__':
207 | # paths
208 | path_audiodir = './mp3'
209 | path_indir = './midi_transcribed'
210 | path_outdir = './midi_synchronized'
211 | os.makedirs(path_outdir, exist_ok=True)
212 |
213 | # list files
214 | midifiles = traverse_dir(
215 | path_indir,
216 | is_ext=False,
217 | is_pure=True,
218 | is_sort=True)
219 | n_files = len(midifiles)
220 |     print('num files:', n_files)
221 |
222 | # collect
223 | data = []
224 | for fidx in range(n_files):
225 | fn = midifiles[fidx]
226 | print('{}/{}'.format(fidx, n_files))
227 |
228 | # paths
229 | path_midi_input = os.path.join(path_indir, fn+'.mid')
230 | path_midi_output = os.path.join(path_outdir, fn+'.mid')
231 |         path_audio_input = os.path.join(path_audiodir, fn+'.mp3')
232 |
233 | # mkdir
234 | fn = os.path.basename(path_midi_output)
235 | os.makedirs(path_midi_output[:-len(fn)], exist_ok=True)
236 |
237 | # append
238 | data.append([path_midi_input, path_audio_input, path_midi_output, None])
239 |
240 |     # run in parallel with a multiprocessing pool
241 | pool = mp.Pool()
242 | pool.starmap(analyze, data)
243 |
--------------------------------------------------------------------------------
/docs/aaai21-slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/docs/aaai21-slides.pdf
--------------------------------------------------------------------------------
/workspace/cond_ls2midi/keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/cond_ls2midi/keep
--------------------------------------------------------------------------------
/workspace/uncond/Experiments.md:
--------------------------------------------------------------------------------
1 | # Experiments
2 |
3 | Experimental results of **unconditional** generation. We also elaborate on the detailed settings that were omitted from the paper due to space limitations.
4 |
5 | ## Run the Codes
6 | **cp-linear**
7 |
8 | Edit the configuration section at the beginning of the `main-cp.py` file first.
9 |
10 | ```bash
11 | python main-cp.py
12 | ```
13 |
14 | **remi-xl**
15 |
16 | Edit the `config.yml` file first.
17 |
18 | ```bash
19 | # train
20 | python train.py
21 |
22 | # inference
23 | python inference.py
24 | ```
25 |
26 |
27 | ## Model Settings
28 | Backbone model:
29 | * linear transformer (Linear): ["Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"](https://arxiv.org/abs/2006.16236)
30 | * transformer-XL (XL): ["Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"](https://arxiv.org/abs/1901.02860)
31 |
32 |
33 | All transformers share the same backbone settings (see the schematic config after this list):
34 | * number of layers: 12
35 | * number of heads: 8
36 | * model hidden size: 512
37 | * feed-forward layer size: 2,048
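
Expressed as a schematic config (a sketch with illustrative key names only; the actual keys live in `remi-xl/config.yml` and the configuration block of `main-cp.py`):

```python
# Shared transformer backbone settings for both REMI + XL and CP + Linear.
# Key names here are illustrative, not the repo's actual config keys.
SHARED_BACKBONE = {
    'n_layer': 12,    # number of layers
    'n_head': 8,      # number of attention heads
    'd_model': 512,   # model hidden size
    'd_inner': 2048,  # feed-forward layer size
}
```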
38 |
39 | | Settings | REMI + XL | CP + Linear |
40 | |:---------------------------:|:----------------------:|:----------------------:|
41 | | learning rate | 2e-4 | 1e-4 |
42 | | songs for training | 1,612 | 1,675 |
43 | | sequence length | 7,680 (512 x 15) | 3,584 |
44 | | dictionary size | 332 | 342 |
45 | | parameter amount | 41,291,084 | 39,016,630 |
46 | | receptive field (train) | 512 + 512 for memory | 3,584 |
47 |
48 | Because of the different nature of the two representations, it is hard to keep the training data for the two settings exactly equal. We adjust the `number of songs` and the `sequence length` to achieve a reasonable balance; a back-of-envelope check follows.
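
As a rough check (our own arithmetic from the table above, assuming one training sequence per song), the raw per-epoch token budgets differ by about 2x; the gap is absorbed by the fact that a single CP token groups several event fields that REMI spells out as separate tokens:

```python
# Raw token budget per setting: songs x sequence length.
remi_tokens = 1_612 * 7_680   # 12,380,160 tokens
cp_tokens   = 1_675 * 3_584   #  6,003,200 tokens
print('REMI/CP token ratio: %.2f' % (remi_tokens / cp_tokens))  # ~2.06
```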
49 |
50 | Under a limited hardware budget (a single RTX 2080 Ti GPU), we report a comparison between the following two settings:
51 | * REMI + XL: remi representation, transformer-XL
52 | * CP + Linear: compound word (CP) representation, linear transformer
53 |
54 |
55 | ## Evaluation
56 | ### Memory Usage
57 | The hardware budget is a single GPU; we use an **NVIDIA GeForce RTX 2080 Ti**, which has 11 GB of memory.
58 |
59 | | # | Representation+model | batch size | Memory Usage (GB) | Evaluated |
60 | |:-:|:----------------------:|:----------------------:|:----------------------:|:-----------:|
61 | | 1 | REMI + XL | 4 | 4 | |
62 | | 2 | REMI + XL | 10 | 10 | O |
63 | | 3 | CP + Linear | 4 | 10 | O |
64 |
65 | For a fair comparison, we let every model setting run at its maximum memory consumption.
66 | Notice that the REMI + XL setting with batch size 4 does not saturate the GPU memory and is therefore excluded from the evaluation.
67 |
68 |
69 | ### Training Efficiency
70 | The relation between the quality of generated samples and the loss can differ across settings. Here, we nevertheless measure training efficiency in terms of the cross-entropy loss.
71 |
72 | [WIP]
73 |
74 |
75 | ### Inference Efficiency and Performance
76 | We let each model generate 50 songs and record the time consumed.
77 |
78 | Records (JSON file):
79 | * [cp-linear](./cp-linear/runtime_stats.json)
80 | * [remi-xl](./remi-xl/runtime_stats.json)
81 |
82 | | Representation+model | Avg. Song Time (sec) | EOS |
83 | |:----------------------:|:----------------------:|:--------:|
84 | | REMI + XL | 195.25446 | X |
85 | | CP + Linear | 20.91956 | O |
86 |
87 | `EOS` indicates whether the model is able to stop generation automatically by emitting an EOS token.
88 | The CP + Linear setting takes less than half a minute to generate a song, which reveals its potential for real-time applications.
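
For reference, the averages above can be recomputed from the records; a minimal sketch (the `song_time` field name is an assumption, so inspect the JSON files for the actual schema):

```python
import json

# Assumed schema: per-song generation times in seconds stored as a list
# under a 'song_time' key; adjust to the actual field names in the file.
with open('cp-linear/runtime_stats.json') as f:
    stats = json.load(f)

times = stats['song_time']
print('num songs    :', len(times))
print('avg song time: %.5f sec' % (sum(times) / len(times)))
```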
89 |
90 | ## Results
91 | ### Generated MIDI Files
92 | * [cp-linear](./cp-linear/gen_midis)
93 | * [remi-xl](./remi-xl/gen_midis)
94 |
95 |
96 | ### Checkpoints
97 | * [cp-linear](https://drive.google.com/drive/folders/114uore7LHjAsM4eKXG9TfVZL5S3YY7nZ?usp=sharing)
98 | * [remi-xl](https://drive.google.com/drive/folders/1tCaWQisPp_bcXKH5J3Nxmv6kUzJXs6qw?usp=sharing)
99 |
100 | ## Discussion
101 | ### About the generated samples
102 | We find that the pieces generated by REMI + XL tend to stick to certain patterns and occasionally loop them for quite a long time, sometimes even for the entire song. The music within such a loop is surprisingly well organized (a clear arpeggio in the left hand and a melody line in the right hand), but also somewhat tedious. The samples from CP + Linear have a rather different texture: the "structural" diversity is richer, but the pitch selection is also more aggressive.
103 |
104 | ### About EOS
105 | It turns out that REMI + XL fails to generate the EOS token, which implies that the sequence length may exceed the model's "effective length" because of its RNN-like recurrence. In fact, we also tried REMI + Linear and CP + Linear, and both of them succeed on this criterion.
106 |
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_0.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_0.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_1.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_1.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_10.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_10.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_11.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_11.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_12.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_12.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_13.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_13.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_14.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_14.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_15.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_15.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_16.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_16.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_17.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_17.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_18.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_18.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_19.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_19.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_2.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_2.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_20.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_20.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_21.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_21.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_22.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_22.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_23.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_23.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_24.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_24.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_25.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_25.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_26.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_26.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_27.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_27.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_28.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_28.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_29.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_29.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_3.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_3.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_30.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_30.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_31.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_31.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_32.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_32.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_33.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_33.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_34.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_34.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_35.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_35.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_36.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_36.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_37.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_37.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_38.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_38.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_39.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_39.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_4.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_4.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_40.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_40.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_41.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_41.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_42.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_42.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_43.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_43.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_44.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_44.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_45.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_45.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_46.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_46.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_47.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_47.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_48.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_48.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_49.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_49.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_5.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_5.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_6.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_6.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_7.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_7.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_8.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_8.mid
--------------------------------------------------------------------------------
/workspace/uncond/cp-linear/gen_midis/get_9.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/gen_midis/get_9.mid
--------------------------------------------------------------------------------
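
Note: the `get_*.mid` entries above are binary artifacts, kept here as links to the raw files. As a quick way to audit one, the sketch below (the `urllib` download and local filename are illustrative, not part of the repo) fetches a sample and inspects it with `miditoolkit`, the same library the training code uses:

```python
import urllib.request
import miditoolkit

# fetch one generated sample from the raw URL listed above and inspect it
url = ('https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/'
       '1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/cp-linear/'
       'gen_midis/get_0.mid')
urllib.request.urlretrieve(url, 'get_0.mid')

midi = miditoolkit.midi.parser.MidiFile('get_0.mid')
print('notes:', len(midi.instruments[0].notes))
print('tempo changes:', len(midi.tempo_changes))
print('chord markers:', len(midi.markers))
```
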
/workspace/uncond/cp-linear/main-cp.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 |
5 | import math
6 | import time
7 | import glob
8 | import datetime
9 | import random
10 | import pickle
11 | import json
12 | import numpy as np
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 |
18 | import torch.optim as optim
19 | from torch.nn.utils import clip_grad_norm_
20 | from torch.utils.data import Dataset, DataLoader
21 |
22 | from fast_transformers.builders import TransformerEncoderBuilder
23 | from fast_transformers.builders import RecurrentEncoderBuilder
24 | from fast_transformers.masking import TriangularCausalMask
25 |
26 | import miditoolkit
27 | from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note
28 |
29 | import saver
30 |
31 |
32 | ################################################################################
33 | # config
34 | ################################################################################
35 |
36 | MODE = 'train'
37 | # MODE = 'inference'
38 |
39 | ###--- data ---###
40 | path_data_root = '../../../dataset/representations/uncond/cp/ailab17k_from-scratch_cp'
41 | path_train_data = os.path.join(path_data_root, 'train_data_linear.npz')
42 | path_dictionary = os.path.join(path_data_root, 'dictionary.pkl')
43 |
44 | ###--- training config ---###
45 | D_MODEL = 512
46 | N_LAYER = 12
47 | N_HEAD = 8
48 | path_exp = 'exp'
49 | batch_size = 4
50 | gid = 0
51 | init_lr = 0.0001
52 |
53 | ###--- fine-tuning & inference config ---###
54 | # info_load_model = (
55 | # # path to ckpt for loading
56 | # '/volume/ai-music-wayne/aaai/from-scratch/cp-linear/exp_base_fs',
57 | # # loss
58 | # 29
59 | # )
60 | info_load_model = None
61 | path_gendir = 'gen_midis'
62 | num_songs = 50
63 |
64 | ################################################################################
65 | # File IO
66 | ################################################################################
67 |
68 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gid)
69 | BEAT_RESOL = 480
70 | BAR_RESOL = BEAT_RESOL * 4
71 | TICK_RESOL = BEAT_RESOL // 4
72 |
73 |
74 | def write_midi(words, path_outfile, word2event):
75 |
76 | class_keys = word2event.keys()
77 | # words = np.load(path_infile)
78 | midi_obj = miditoolkit.midi.parser.MidiFile()
79 |
80 | bar_cnt = 0
81 | cur_pos = 0
82 |
83 | all_notes = []
84 |
85 | cnt_error = 0
86 | for i in range(len(words)):
87 | vals = []
88 | for kidx, key in enumerate(class_keys):
89 | vals.append(word2event[key][words[i][kidx]])
90 | # print(vals)
91 |
92 | if vals[3] == 'Metrical':
93 | if vals[2] == 'Bar':
94 | bar_cnt += 1
95 | elif 'Beat' in vals[2]:
96 | beat_pos = int(vals[2].split('_')[1])
97 | cur_pos = bar_cnt * BAR_RESOL + beat_pos * TICK_RESOL
98 |
99 | # chord
100 | if vals[1] != 'CONTI' and vals[1] != 0:
101 | midi_obj.markers.append(
102 | Marker(text=str(vals[1]), time=cur_pos))
103 |
104 | if vals[0] != 'CONTI' and vals[0] != 0:
105 | tempo = int(vals[0].split('_')[-1])
106 | midi_obj.tempo_changes.append(
107 | TempoChange(tempo=tempo, time=cur_pos))
108 | else:
109 | pass
110 | elif vals[3] == 'Note':
111 |
112 | try:
113 | pitch = vals[4].split('_')[-1]
114 | duration = vals[5].split('_')[-1]
115 | velocity = vals[6].split('_')[-1]
116 |
117 | if int(duration) == 0:
118 | duration = 60
119 | end = cur_pos + int(duration)
120 |
121 | all_notes.append(
122 | Note(
123 | pitch=int(pitch),
124 | start=cur_pos,
125 | end=end,
126 | velocity=int(velocity))
127 | )
128 |             except:  # skip malformed note events
129 | continue
130 | else:
131 | pass
132 |
133 | # save midi
134 | piano_track = Instrument(0, is_drum=False, name='piano')
135 | piano_track.notes = all_notes
136 | midi_obj.instruments = [piano_track]
137 | midi_obj.dump(path_outfile)
138 |
139 |
140 | ################################################################################
141 | # Sampling
142 | ################################################################################
143 | # -- temperature -- #
144 | def softmax_with_temperature(logits, temperature):
145 |     exps = np.exp((logits - np.max(logits)) / temperature)  # max-shift avoids overflow
146 |     return exps / np.sum(exps)
147 |
148 |
149 | def weighted_sampling(probs):
150 | probs /= sum(probs)
151 | sorted_probs = np.sort(probs)[::-1]
152 | sorted_index = np.argsort(probs)[::-1]
153 | word = np.random.choice(sorted_index, size=1, p=sorted_probs)[0]
154 | return word
155 |
156 |
157 | # -- nucleus -- #
158 | def nucleus(probs, p):
159 | probs /= (sum(probs) + 1e-5)
160 | sorted_probs = np.sort(probs)[::-1]
161 | sorted_index = np.argsort(probs)[::-1]
162 | cusum_sorted_probs = np.cumsum(sorted_probs)
163 | after_threshold = cusum_sorted_probs > p
164 | if sum(after_threshold) > 0:
165 | last_index = np.where(after_threshold)[0][0] + 1
166 | candi_index = sorted_index[:last_index]
167 | else:
168 | candi_index = sorted_index[:]
169 | candi_probs = [probs[i] for i in candi_index]
170 | candi_probs /= sum(candi_probs)
171 | word = np.random.choice(candi_index, size=1, p=candi_probs)[0]
172 | return word
173 |
174 |
175 | def sampling(logit, p=None, t=1.0):
176 | logit = logit.squeeze().cpu().numpy()
177 | probs = softmax_with_temperature(logits=logit, temperature=t)
178 |
179 | if p is not None:
180 | cur_word = nucleus(probs, p=p)
181 | else:
182 | cur_word = weighted_sampling(probs)
183 | return cur_word
184 |
185 |
186 | ################################################################################
187 | # Model
188 | ################################################################################
189 |
190 |
191 | def network_paras(model):
192 | # compute only trainable params
193 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
194 | params = sum([np.prod(p.size()) for p in model_parameters])
195 | return params
196 |
197 |
198 | class Embeddings(nn.Module):
199 | def __init__(self, n_token, d_model):
200 | super(Embeddings, self).__init__()
201 | self.lut = nn.Embedding(n_token, d_model)
202 | self.d_model = d_model
203 |
204 | def forward(self, x):
205 | return self.lut(x) * math.sqrt(self.d_model)
206 |
207 |
208 | class PositionalEncoding(nn.Module):
209 | def __init__(self, d_model, dropout=0.1, max_len=20000):
210 | super(PositionalEncoding, self).__init__()
211 | self.dropout = nn.Dropout(p=dropout)
212 |
213 | pe = torch.zeros(max_len, d_model)
214 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
215 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
216 | pe[:, 0::2] = torch.sin(position * div_term)
217 | pe[:, 1::2] = torch.cos(position * div_term)
218 | pe = pe.unsqueeze(0)
219 | self.register_buffer('pe', pe)
220 |
221 | def forward(self, x):
222 | x = x + self.pe[:, :x.size(1), :]
223 | return self.dropout(x)
224 |
225 |
226 | class TransformerModel(nn.Module):
227 | def __init__(self, n_token, is_training=True):
228 | super(TransformerModel, self).__init__()
229 |
230 | # --- params config --- #
231 | self.n_token = n_token
232 | self.d_model = D_MODEL
233 | self.n_layer = N_LAYER #
234 | self.dropout = 0.1
235 | self.n_head = N_HEAD #
236 | self.d_head = D_MODEL // N_HEAD
237 | self.d_inner = 2048
238 | self.loss_func = nn.CrossEntropyLoss(reduction='none')
239 | self.emb_sizes = [128, 256, 64, 32, 512, 128, 128]
240 |
241 | # --- modules config --- #
242 | # embeddings
243 |         print(' > n_token per field:', self.n_token)
244 | self.word_emb_tempo = Embeddings(self.n_token[0], self.emb_sizes[0])
245 | self.word_emb_chord = Embeddings(self.n_token[1], self.emb_sizes[1])
246 | self.word_emb_barbeat = Embeddings(self.n_token[2], self.emb_sizes[2])
247 | self.word_emb_type = Embeddings(self.n_token[3], self.emb_sizes[3])
248 | self.word_emb_pitch = Embeddings(self.n_token[4], self.emb_sizes[4])
249 | self.word_emb_duration = Embeddings(self.n_token[5], self.emb_sizes[5])
250 | self.word_emb_velocity = Embeddings(self.n_token[6], self.emb_sizes[6])
251 | self.pos_emb = PositionalEncoding(self.d_model, self.dropout)
252 |
253 | # linear
254 | self.in_linear = nn.Linear(np.sum(self.emb_sizes), self.d_model)
255 |
256 | # encoder
257 | if is_training:
258 | # encoder (training)
259 | self.transformer_encoder = TransformerEncoderBuilder.from_kwargs(
260 | n_layers=self.n_layer,
261 | n_heads=self.n_head,
262 | query_dimensions=self.d_model//self.n_head,
263 | value_dimensions=self.d_model//self.n_head,
264 | feed_forward_dimensions=2048,
265 | activation='gelu',
266 | dropout=0.1,
267 | attention_type="causal-linear",
268 | ).get()
269 | else:
270 | # encoder (inference)
271 | print(' [o] using RNN backend.')
272 | self.transformer_encoder = RecurrentEncoderBuilder.from_kwargs(
273 | n_layers=self.n_layer,
274 | n_heads=self.n_head,
275 | query_dimensions=self.d_model//self.n_head,
276 | value_dimensions=self.d_model//self.n_head,
277 | feed_forward_dimensions=2048,
278 | activation='gelu',
279 | dropout=0.1,
280 | attention_type="causal-linear",
281 | ).get()
282 |
283 | # blend with type
284 | self.project_concat_type = nn.Linear(self.d_model + 32, self.d_model)
285 |
286 | # individual output
287 | self.proj_tempo = nn.Linear(self.d_model, self.n_token[0])
288 | self.proj_chord = nn.Linear(self.d_model, self.n_token[1])
289 | self.proj_barbeat = nn.Linear(self.d_model, self.n_token[2])
290 | self.proj_type = nn.Linear(self.d_model, self.n_token[3])
291 | self.proj_pitch = nn.Linear(self.d_model, self.n_token[4])
292 | self.proj_duration = nn.Linear(self.d_model, self.n_token[5])
293 | self.proj_velocity = nn.Linear(self.d_model, self.n_token[6])
294 |
295 | def compute_loss(self, predict, target, loss_mask):
296 | loss = self.loss_func(predict, target)
297 | loss = loss * loss_mask
298 | loss = torch.sum(loss) / torch.sum(loss_mask)
299 | return loss
300 |
301 | def train_step(self, x, target, loss_mask):
302 | h, y_type = self.forward_hidden(x)
303 | y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity = self.forward_output(h, target)
304 |
305 | # reshape (b, s, f) -> (b, f, s)
306 | y_tempo = y_tempo[:, ...].permute(0, 2, 1)
307 | y_chord = y_chord[:, ...].permute(0, 2, 1)
308 | y_barbeat = y_barbeat[:, ...].permute(0, 2, 1)
309 | y_type = y_type[:, ...].permute(0, 2, 1)
310 | y_pitch = y_pitch[:, ...].permute(0, 2, 1)
311 | y_duration = y_duration[:, ...].permute(0, 2, 1)
312 | y_velocity = y_velocity[:, ...].permute(0, 2, 1)
313 |
314 | # loss
315 | loss_tempo = self.compute_loss(
316 | y_tempo, target[..., 0], loss_mask)
317 | loss_chord = self.compute_loss(
318 | y_chord, target[..., 1], loss_mask)
319 | loss_barbeat = self.compute_loss(
320 | y_barbeat, target[..., 2], loss_mask)
321 | loss_type = self.compute_loss(
322 | y_type, target[..., 3], loss_mask)
323 | loss_pitch = self.compute_loss(
324 | y_pitch, target[..., 4], loss_mask)
325 | loss_duration = self.compute_loss(
326 | y_duration, target[..., 5], loss_mask)
327 | loss_velocity = self.compute_loss(
328 | y_velocity, target[..., 6], loss_mask)
329 |
330 | return loss_tempo, loss_chord, loss_barbeat, loss_type, loss_pitch, loss_duration, loss_velocity
331 |
332 | def forward_hidden(self, x, memory=None, is_training=True):
333 | '''
334 | linear transformer: b x s x f
335 | x.shape=(bs, nf)
336 | '''
337 |
338 | # embeddings
339 | emb_tempo = self.word_emb_tempo(x[..., 0])
340 | emb_chord = self.word_emb_chord(x[..., 1])
341 | emb_barbeat = self.word_emb_barbeat(x[..., 2])
342 | emb_type = self.word_emb_type(x[..., 3])
343 | emb_pitch = self.word_emb_pitch(x[..., 4])
344 | emb_duration = self.word_emb_duration(x[..., 5])
345 | emb_velocity = self.word_emb_velocity(x[..., 6])
346 |
347 | embs = torch.cat(
348 | [
349 | emb_tempo,
350 | emb_chord,
351 | emb_barbeat,
352 | emb_type,
353 | emb_pitch,
354 | emb_duration,
355 | emb_velocity,
356 | ], dim=-1)
357 |
358 | emb_linear = self.in_linear(embs)
359 | pos_emb = self.pos_emb(emb_linear)
360 |
361 | # assert False
362 |
363 | # transformer
364 | if is_training:
365 | # mask
366 | attn_mask = TriangularCausalMask(pos_emb.size(1), device=x.device)
367 | h = self.transformer_encoder(pos_emb, attn_mask) # y: b x s x d_model
368 |
369 | # project type
370 | y_type = self.proj_type(h)
371 | return h, y_type
372 | else:
373 | pos_emb = pos_emb.squeeze(0)
374 | h, memory = self.transformer_encoder(pos_emb, memory=memory) # y: s x d_model
375 |
376 | # project type
377 | y_type = self.proj_type(h)
378 | return h, y_type, memory
379 |
380 | def forward_output(self, h, y):
381 | '''
382 | for training
383 | '''
384 | tf_skip_type = self.word_emb_type(y[..., 3])
385 |
386 | # project other
387 | y_concat_type = torch.cat([h, tf_skip_type], dim=-1)
388 | y_ = self.project_concat_type(y_concat_type)
389 |
390 | y_tempo = self.proj_tempo(y_)
391 | y_chord = self.proj_chord(y_)
392 | y_barbeat = self.proj_barbeat(y_)
393 | y_pitch = self.proj_pitch(y_)
394 | y_duration = self.proj_duration(y_)
395 | y_velocity = self.proj_velocity(y_)
396 |
397 | return y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity
398 |
399 |     def forward_output_sampling(self, h, y_type):
400 | '''
401 | for inference
402 | '''
403 | # sample type
404 | y_type_logit = y_type[0, :]
405 | cur_word_type = sampling(y_type_logit, p=0.90)
406 |
407 | type_word_t = torch.from_numpy(
408 | np.array([cur_word_type])).long().cuda().unsqueeze(0)
409 |
410 | tf_skip_type = self.word_emb_type(type_word_t).squeeze(0)
411 |
412 | # concat
413 | y_concat_type = torch.cat([h, tf_skip_type], dim=-1)
414 | y_ = self.project_concat_type(y_concat_type)
415 |
416 | # project other
417 | y_tempo = self.proj_tempo(y_)
418 | y_chord = self.proj_chord(y_)
419 | y_barbeat = self.proj_barbeat(y_)
420 |
421 | y_pitch = self.proj_pitch(y_)
422 | y_duration = self.proj_duration(y_)
423 | y_velocity = self.proj_velocity(y_)
424 |
425 | # sampling gen_cond
426 | cur_word_tempo = sampling(y_tempo, t=1.2, p=0.9)
427 | cur_word_barbeat = sampling(y_barbeat, t=1.2)
428 | cur_word_chord = sampling(y_chord, p=0.99)
429 | cur_word_pitch = sampling(y_pitch, p=0.9)
430 | cur_word_duration = sampling(y_duration, t=2, p=0.9)
431 | cur_word_velocity = sampling(y_velocity, t=5)
432 |
433 | # collect
434 | next_arr = np.array([
435 | cur_word_tempo,
436 | cur_word_chord,
437 | cur_word_barbeat,
438 | cur_word_type,
439 | cur_word_pitch,
440 | cur_word_duration,
441 | cur_word_velocity,
442 | ])
443 | return next_arr
444 |
445 | def inference_from_scratch(self, dictionary):
446 | event2word, word2event = dictionary
447 | classes = word2event.keys()
448 |
449 | def print_word_cp(cp):
450 | result = [word2event[k][cp[idx]] for idx, k in enumerate(classes)]
451 |
452 | for r in result:
453 | print('{:15s}'.format(str(r)), end=' | ')
454 | print('')
455 |
456 | init = np.array([
457 | [0, 0, 1, 1, 0, 0, 0], # bar
458 | ])
459 |
460 | cnt_token = len(init)
461 | with torch.no_grad():
462 | final_res = []
463 | memory = None
464 | h = None
465 |
466 | cnt_bar = 1
467 | init_t = torch.from_numpy(init).long().cuda()
468 | print('------ initiate ------')
469 | for step in range(init.shape[0]):
470 | print_word_cp(init[step, :])
471 | input_ = init_t[step, :].unsqueeze(0).unsqueeze(0)
472 | final_res.append(init[step, :][None, ...])
473 |
474 | h, y_type, memory = self.forward_hidden(
475 | input_, memory, is_training=False)
476 |
477 | print('------ generate ------')
478 |             while True:
479 |                 # sample others
480 |                 next_arr = self.forward_output_sampling(h, y_type)
481 |                 final_res.append(next_arr[None, ...])
482 |                 print('bar:', cnt_bar, end=' ==')
483 | print_word_cp(next_arr)
484 |
485 | # forward
486 | input_ = torch.from_numpy(next_arr).long().cuda()
487 | input_ = input_.unsqueeze(0).unsqueeze(0)
488 | h, y_type, memory = self.forward_hidden(
489 | input_, memory, is_training=False)
490 |
491 | # end of sequence
492 | if word2event['type'][next_arr[3]] == 'EOS':
493 | break
494 |
495 | if word2event['bar-beat'][next_arr[2]] == 'Bar':
496 | cnt_bar += 1
497 |
498 | print('\n--------[Done]--------')
499 | final_res = np.concatenate(final_res)
500 | print(final_res.shape)
501 | return final_res
502 |
503 |
504 | ##########################################################################################################################
505 | # Script
506 | ##########################################################################################################################
507 |
508 |
509 | def train():
510 | # hyper params
511 | n_epoch = 4000
512 | max_grad_norm = 3
513 |
514 | # load
515 | dictionary = pickle.load(open(path_dictionary, 'rb'))
516 | event2word, word2event = dictionary
517 | train_data = np.load(path_train_data)
518 |
519 | # create saver
520 | saver_agent = saver.Saver(path_exp)
521 |
522 | # config
523 | n_class = []
524 | for key in event2word.keys():
525 | n_class.append(len(dictionary[0][key]))
526 |
527 | # log
528 | print('num of classes:', n_class)
529 |
530 | # init
531 | net = TransformerModel(n_class)
532 | net.cuda()
533 | net.train()
534 | n_parameters = network_paras(net)
535 | print('n_parameters: {:,}'.format(n_parameters))
536 | saver_agent.add_summary_msg(
537 | ' > params amount: {:,d}'.format(n_parameters))
538 |
539 | # load model
540 | if info_load_model:
541 | path_ckpt = info_load_model[0] # path to ckpt dir
542 | loss = info_load_model[1] # loss
543 | name = 'loss_' + str(loss)
544 | path_saved_ckpt = os.path.join(path_ckpt, name + '_params.pt')
545 | print('[*] load model from:', path_saved_ckpt)
546 | net.load_state_dict(torch.load(path_saved_ckpt))
547 |
548 | # optimizers
549 | optimizer = optim.Adam(net.parameters(), lr=init_lr)
550 |
551 | # unpack
552 | train_x = train_data['x']
553 | train_y = train_data['y']
554 | train_mask = train_data['mask']
555 | num_batch = len(train_x) // batch_size
556 |
557 | print(' num_batch:', num_batch)
558 | print(' train_x:', train_x.shape)
559 | print(' train_y:', train_y.shape)
560 | print(' train_mask:', train_mask.shape)
561 |
562 | # run
563 | start_time = time.time()
564 | for epoch in range(n_epoch):
565 | acc_loss = 0
566 | acc_losses = np.zeros(7)
567 |
568 | for bidx in range(num_batch): # num_batch
569 | saver_agent.global_step_increment()
570 |
571 | # index
572 | bidx_st = batch_size * bidx
573 | bidx_ed = batch_size * (bidx + 1)
574 |
575 | # unpack batch data
576 | batch_x = train_x[bidx_st:bidx_ed]
577 | batch_y = train_y[bidx_st:bidx_ed]
578 | batch_mask = train_mask[bidx_st:bidx_ed]
579 |
580 | # to tensor
581 | batch_x = torch.from_numpy(batch_x).long().cuda()
582 | batch_y = torch.from_numpy(batch_y).long().cuda()
583 | batch_mask = torch.from_numpy(batch_mask).float().cuda()
584 |
585 | # run
586 | losses = net.train_step(batch_x, batch_y, batch_mask)
587 | loss = (losses[0] + losses[1] + losses[2] + losses[3] + losses[4] + losses[5] + losses[6]) / 7
588 |
589 | # Update
590 | net.zero_grad()
591 | loss.backward()
592 | if max_grad_norm is not None:
593 | clip_grad_norm_(net.parameters(), max_grad_norm)
594 | optimizer.step()
595 |
596 | # print
597 | sys.stdout.write('{}/{} | Loss: {:06f} | {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}\r'.format(
598 | bidx, num_batch, loss, losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], losses[6]))
599 | sys.stdout.flush()
600 |
601 | # acc
602 | acc_losses += np.array([l.item() for l in losses])
603 | acc_loss += loss.item()
604 |
605 | # log
606 | saver_agent.add_summary('batch loss', loss.item())
607 |
608 | # epoch loss
609 | runtime = time.time() - start_time
610 | epoch_loss = acc_loss / num_batch
611 | acc_losses = acc_losses / num_batch
612 | print('------------------------------------')
613 | print('epoch: {}/{} | Loss: {} | time: {}'.format(
614 | epoch, n_epoch, epoch_loss, str(datetime.timedelta(seconds=runtime))))
615 |         each_loss_str = '{:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}'.format(
616 | acc_losses[0], acc_losses[1], acc_losses[2], acc_losses[3], acc_losses[4], acc_losses[5], acc_losses[6])
617 | print(' >', each_loss_str)
618 |
619 | saver_agent.add_summary('epoch loss', epoch_loss)
620 | saver_agent.add_summary('epoch each loss', each_loss_str)
621 |
622 | # save model, with policy
623 | loss = epoch_loss
624 | if 0.4 < loss <= 0.8:
625 | fn = int(loss * 10) * 10
626 | saver_agent.save_model(net, name='loss_' + str(fn))
627 | elif 0.05 < loss <= 0.40:
628 | fn = int(loss * 100)
629 | saver_agent.save_model(net, name='loss_' + str(fn))
630 | elif loss <= 0.05:
631 | print('Finished')
632 | return
633 | else:
634 | saver_agent.save_model(net, name='loss_high')
635 |
636 |
637 | def generate():
638 | # path
639 | path_ckpt = info_load_model[0] # path to ckpt dir
640 | loss = info_load_model[1] # loss
641 | name = 'loss_' + str(loss)
642 | path_saved_ckpt = os.path.join(path_ckpt, name + '_params.pt')
643 |
644 | # load
645 | dictionary = pickle.load(open(path_dictionary, 'rb'))
646 | event2word, word2event = dictionary
647 |
648 | # outdir
649 | os.makedirs(path_gendir, exist_ok=True)
650 |
651 | # config
652 | n_class = []
653 | for key in event2word.keys():
654 | n_class.append(len(dictionary[0][key]))
655 |
656 | # init model
657 | net = TransformerModel(n_class, is_training=False)
658 | net.cuda()
659 | net.eval()
660 |
661 | # load model
662 | print('[*] load model from:', path_saved_ckpt)
663 | net.load_state_dict(torch.load(path_saved_ckpt))
664 |
665 | # gen
666 | start_time = time.time()
667 | song_time_list = []
668 | words_len_list = []
669 |
670 | cnt_tokens_all = 0
671 | sidx = 0
672 | while sidx < num_songs:
673 | try:
674 | start_time = time.time()
675 | print('current idx:', sidx)
676 | path_outfile = os.path.join(path_gendir, 'get_{}.mid'.format(str(sidx)))
677 |
678 | res = net.inference_from_scratch(dictionary)
679 | write_midi(res, path_outfile, word2event)
680 |
681 | song_time = time.time() - start_time
682 | word_len = len(res)
683 | print('song time:', song_time)
684 | print('word_len:', word_len)
685 | words_len_list.append(word_len)
686 | song_time_list.append(song_time)
687 |
688 | sidx += 1
689 | except KeyboardInterrupt:
690 | raise ValueError(' [x] terminated.')
691 |         except:  # generation failed; retry the same song index
692 | continue
693 |
694 | print('ave token time:', sum(words_len_list) / sum(song_time_list))
695 | print('ave song time:', np.mean(song_time_list))
696 |
697 | runtime_result = {
698 | 'song_time':song_time_list,
699 | 'words_len_list': words_len_list,
700 | 'ave token time:': sum(words_len_list) / sum(song_time_list),
701 | 'ave song time': float(np.mean(song_time_list)),
702 | }
703 |
704 | with open('runtime_stats.json', 'w') as f:
705 | json.dump(runtime_result, f)
706 |
707 |
708 | if __name__ == '__main__':
709 | # -- training -- #
710 | if MODE == 'train':
711 | train()
712 |
713 | # -- inference -- #
714 | elif MODE == 'inference':
715 | generate()
716 |
717 | else:
718 | pass
719 |
--------------------------------------------------------------------------------
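
For readers who want to poke at the sampling stage of `main-cp.py` in isolation, here is a minimal, self-contained sketch of the same temperature-plus-nucleus scheme (the toy 5-token vocabulary is made up, and the max-shift inside the softmax is the numerically stable variant used above):

```python
import numpy as np

def softmax_with_temperature(logits, temperature):
    # max-shift before exponentiating keeps the softmax numerically stable
    exps = np.exp((logits - np.max(logits)) / temperature)
    return exps / np.sum(exps)

def nucleus(probs, p):
    # keep the smallest head of the sorted distribution whose mass exceeds p
    sorted_index = np.argsort(probs)[::-1]
    cusum = np.cumsum(probs[sorted_index])
    last_index = int(np.argmax(cusum > p)) + 1 if np.any(cusum > p) else len(probs)
    candi_index = sorted_index[:last_index]
    candi_probs = probs[candi_index] / probs[candi_index].sum()
    return int(np.random.choice(candi_index, p=candi_probs))

# toy 5-token vocabulary: flatten with temperature, then truncate with nucleus
logits = np.array([2.0, 1.0, 0.5, 0.1, -1.0])
probs = softmax_with_temperature(logits, temperature=1.2)
print('sampled token id:', nucleus(probs, p=0.9))
```
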
/workspace/uncond/cp-linear/runtime_stats.json:
--------------------------------------------------------------------------------
1 | {"song_time": [22.15657639503479, 22.193533420562744, 31.92780828475952, 19.54995632171631, 16.9428653717041, 14.89795708656311, 23.36419653892517, 17.333919048309326, 21.232710123062134, 17.8750319480896, 20.08845090866089, 20.505202531814575, 21.045650482177734, 25.3824942111969, 19.22031259536743, 25.14116883277893, 26.373836994171143, 17.95224666595459, 24.632129192352295, 22.872940063476562, 18.909109115600586, 23.45732831954956, 17.696035146713257, 19.6449134349823, 21.48639941215515, 20.14349365234375, 18.929306983947754, 17.5115327835083, 20.62809991836548, 26.35112738609314, 20.624415159225464, 21.29854154586792, 19.991357803344727, 20.77635097503662, 15.483854293823242, 21.372031688690186, 21.510165691375732, 26.04141879081726, 18.135209560394287, 18.571840286254883, 23.6661434173584, 22.40087866783142, 23.71457076072693, 19.199822187423706, 14.656707763671875, 18.02313542366028, 23.298468351364136, 16.629225969314575, 21.779688596725464, 23.35808539390564], "words_len_list": [1500, 1466, 2122, 1273, 1150, 1075, 1595, 1141, 1483, 1272, 1461, 1434, 1464, 1719, 1339, 1702, 1874, 1241, 1691, 1619, 1302, 1605, 1169, 1335, 1458, 1436, 1330, 1128, 1466, 1842, 1373, 1399, 1365, 1350, 1051, 1376, 1455, 1825, 1188, 1175, 1585, 1490, 1679, 1349, 1039, 1304, 1666, 1153, 1447, 1545], "ave token time:": 68.362798469141, "ave song time": 20.919564909934998}
--------------------------------------------------------------------------------
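
One caveat when reading the stats above: despite its key name, `"ave token time:"` is a rate — total generated tokens divided by total generation seconds (about 68 tokens per second here) — not a per-token latency. A short sketch that recomputes both summary fields from the raw lists, assuming the JSON above sits in the working directory:

```python
import json

# recompute the summary fields of runtime_stats.json from the raw lists
with open('runtime_stats.json') as f:
    stats = json.load(f)

song_times = stats['song_time']
word_lens = stats['words_len_list']

tokens_per_second = sum(word_lens) / sum(song_times)  # ~68.4 for the run above
mean_song_time = sum(song_times) / len(song_times)    # ~20.9 s per song

print('tokens/sec: {:.2f}'.format(tokens_per_second))
print('mean song time: {:.2f} s'.format(mean_song_time))
```
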
/workspace/uncond/cp-linear/saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import logging
5 | import datetime
6 | import collections
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | class Saver(object):
12 | def __init__(
13 | self,
14 | exp_dir,
15 | mode='w'):
16 |
17 | self.exp_dir = exp_dir
18 | self.init_time = time.time()
19 | self.global_step = 0
20 |
21 | # makedirs
22 | os.makedirs(exp_dir, exist_ok=True)
23 |
24 | # logging config
25 | path_logger = os.path.join(exp_dir, 'log.txt')
26 | logging.basicConfig(
27 | level=logging.DEBUG,
28 | format='%(message)s',
29 | filename=path_logger,
30 | filemode=mode)
31 | self.logger = logging.getLogger('training monitor')
32 |
33 | def add_summary_msg(self, msg):
34 | self.logger.debug(msg)
35 |
36 | def add_summary(
37 | self,
38 | key,
39 | val,
40 | step=None,
41 | cur_time=None):
42 |
43 | if cur_time is None:
44 | cur_time = time.time() - self.init_time
45 | if step is None:
46 | step = self.global_step
47 |
48 | # write msg (key, val, step, time)
49 | if isinstance(val, float):
50 | msg_str = '{:10s} | {:.10f} | {:10d} | {}'.format(
51 | key,
52 | val,
53 | step,
54 | cur_time
55 | )
56 | else:
57 | msg_str = '{:10s} | {} | {:10d} | {}'.format(
58 | key,
59 | val,
60 | step,
61 | cur_time
62 | )
63 |
64 | self.logger.debug(msg_str)
65 |
66 | def save_model(
67 | self,
68 | model,
69 | optimizer=None,
70 | outdir=None,
71 | name='model'):
72 |
73 | if outdir is None:
74 | outdir = self.exp_dir
75 | print(' [*] saving model to {}, name: {}'.format(outdir, name))
76 | torch.save(model, os.path.join(outdir, name+'.pt'))
77 | torch.save(model.state_dict(), os.path.join(outdir, name+'_params.pt'))
78 |
79 | if optimizer is not None:
80 | torch.save(optimizer.state_dict(), os.path.join(outdir, name+'_opt.pt'))
81 |
82 | def load_model(
83 | self,
84 | path_exp,
85 | device='cpu',
86 | name='model.pt'):
87 |
88 | path_pt = os.path.join(path_exp, name)
89 | print(' [*] restoring model from', path_pt)
90 | model = torch.load(path_pt, map_location=torch.device(device))
91 | return model
92 |
93 | def global_step_increment(self):
94 | self.global_step += 1
95 |
96 | """
97 | file modes
98 | 'a':
99 | Opens a file for appending. The file pointer is at the end of the file if the file exists.
100 | That is, the file is in the append mode. If the file does not exist, it creates a new file for writing.
101 |
102 | 'w':
103 | Opens a file for writing only. Overwrites the file if the file exists.
104 | If the file does not exist, creates a new file for writing.
105 | """
106 |
107 | def make_loss_report(
108 | path_log,
109 | path_figure='loss.png',
110 | dpi=100):
111 |
112 | # load logfile
113 | monitor_vals = collections.defaultdict(list)
114 |     with open(path_log, 'r') as f:
115 | for line in f:
116 | try:
117 | line = line.strip()
118 | key, val, step, acc_time = line.split(' | ')
119 | monitor_vals[key].append((float(val), int(step), acc_time))
120 | except:
121 | continue
122 |
123 | # collect
124 | step_train = [item[1] for item in monitor_vals['train loss']]
125 | vals_train = [item[0] for item in monitor_vals['train loss']]
126 |
127 | step_valid = [item[1] for item in monitor_vals['valid loss']]
128 | vals_valid = [item[0] for item in monitor_vals['valid loss']]
129 |
130 | x_min = step_valid[np.argmin(vals_valid)]
131 | y_min = min(vals_valid)
132 |
133 | # plot
134 | fig = plt.figure(dpi=dpi)
135 | plt.title('training process')
136 | plt.plot(step_train, vals_train, label='train')
137 | plt.plot(step_valid, vals_valid, label='valid')
138 | plt.yscale('log')
139 | plt.plot([x_min], [y_min], 'ro')
140 | plt.legend(loc='upper right')
141 | plt.tight_layout()
142 | plt.savefig(path_figure)
143 |
144 | '''
145 | author: wayn391@mastertones
146 | '''
--------------------------------------------------------------------------------
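
A minimal usage sketch for the `Saver` class above, mirroring how `main-cp.py` drives it during training; the `demo_exp` directory and the toy `nn.Linear` are stand-ins for the real experiment directory and `TransformerModel`:

```python
import torch.nn as nn
from saver import Saver

# 'demo_exp' and nn.Linear are stand-ins for the real experiment dir and model
saver_agent = Saver('demo_exp')

model = nn.Linear(4, 2)
saver_agent.add_summary_msg(' > params amount: {:,d}'.format(
    sum(p.numel() for p in model.parameters())))

for step in range(3):
    saver_agent.global_step_increment()              # mirrors the training loop
    saver_agent.add_summary('batch loss', 0.5 / (step + 1))

# writes demo_exp/loss_demo.pt and demo_exp/loss_demo_params.pt;
# train() later reloads only the *_params.pt state dict
saver_agent.save_model(model, name='loss_demo')
```
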
/workspace/uncond/remi-xl/config.yml:
--------------------------------------------------------------------------------
1 |
2 | MODEL:
3 | n_head: 8
4 | n_layer: 12
5 | dropout: 0.1
6 | d_inner: 2048 #d_ff
7 | d_embed: 512
8 | d_model: 512
9 | dropatt: 0.0 #attention probability dropout rate
10 | query_dim: 16 #64
11 | seq_len: 512 #512
12 | n_token: 332
13 | mem_len: 512
14 | ext_len: 0
15 | tgt_len: 70
16 | eval_tgt_len: 50
17 | init: 'normal' #parameter initializer to use.
18 | emb_init: 'normal' #parameter initializer to use.
19 | init_range: 0.1
20 | emb_init_range: 0.01 #parameters initialized by U(-init_range, init_range)
21 | init_std: 0.02 #parameters initialized by N(0, init_std)
22 | proj_init_std: 0.01
23 | clamp_len: -1 #use the same pos embeddings after clamp_len
24 | div_val: 1
25 | position_concat: False
26 | pre_lnorm: True #apply LayerNorm to the input instead of the output
27 | same_length: True #use the same attn length for all tokens
28 |
29 |
30 | TRAIN:
31 | ROOT: '../../../dataset/representations/uncond/remi/ailab17k_from-scratch_remi'
32 | gpuID: '1'
33 | output_dir: "./exp"
34 | batch_size: 10 #5
35 | lr: 0.0002
36 | num_epochs: 600
37 | save_freq: 10
38 | seed: 2222
39 | optim: 'adam'
40 | no_cuda: False
41 | resume_training_model: None
42 | # resume_training_model: '/volume/ai-music-wayne/aaai/from-scratch/remi-xl_review/result/20200901-064426/ep_170.pth.tar'
43 |
44 |
45 | INFERENCE:
46 | num_sample: 20
47 | gpuID: '1'
48 | dictionary_path: '../../../dataset/representations/uncond/remi/ailab17k_from-scratch_remi/dictionary.pkl'
49 | experiment_dir: '/volume/ai-music-wayne/aaai/from-scratch/remi-xl_review/result/20200901-064426'
50 | generated_dir: './gen_midis'
51 | checkpoint_type: epoch_idx # best_train, best_val, epoch_idx
52 | model_epoch: 170
53 | no_cuda: False
54 |
--------------------------------------------------------------------------------
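
The inference script reads this file with `yaml.full_load` and indexes into the three top-level sections (train.py presumably does the same). A short sketch of that access pattern; the print statements are illustrative:

```python
import yaml

# load config.yml the same way inference.py does
cfg = yaml.full_load(open('config.yml', 'r'))

model_cfg = cfg['MODEL']
train_cfg = cfg['TRAIN']
infer_cfg = cfg['INFERENCE']

print('layers:', model_cfg['n_layer'])              # 12
print('memory length:', model_cfg['mem_len'])       # 512
print('epochs:', train_cfg['num_epochs'])           # 600
print('samples to draw:', infer_cfg['num_sample'])  # 20
```
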
/workspace/uncond/remi-xl/gen_midis/170_0.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_0.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_1.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_1.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_10.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_10.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_11.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_11.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_12.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_12.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_13.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_13.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_14.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_14.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_15.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_15.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_16.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_16.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_17.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_17.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_18.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_18.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_19.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_19.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_2.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_2.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_3.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_3.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_4.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_4.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_5.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_5.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_6.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_6.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_7.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_7.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_8.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_8.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/gen_midis/170_9.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YatingMusic/compound-word-transformer/1a56daa80d2381303572ee856f5fbc5a82ec59ea/workspace/uncond/remi-xl/gen_midis/170_9.mid
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/inference.py:
--------------------------------------------------------------------------------
1 | from model import TransformerXL
2 | import pickle
3 | import random
4 | import os
5 | import time
6 | import torch
7 |
8 | import yaml
9 | import json
10 |
11 | import numpy as np
12 |
13 |
14 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
15 |
16 | def main():
17 | cfg = yaml.full_load(open("config.yml", 'r'))
18 | inferenceConfig = cfg['INFERENCE']
19 |
20 | os.environ['CUDA_VISIBLE_DEVICES'] = inferenceConfig['gpuID']
21 |
22 |     print('='*2, 'Inference configs', '='*5)
23 | print(json.dumps(inferenceConfig, indent=1, sort_keys=True))
24 |
25 | # checkpoint information
26 | CHECKPOINT_FOLDER = inferenceConfig['experiment_dir']
27 | midi_folder = inferenceConfig["generated_dir"]
28 |
29 | checkpoint_type = inferenceConfig['checkpoint_type']
30 | if checkpoint_type == 'best_train':
31 | model_path = os.path.join(CHECKPOINT_FOLDER, 'model_best.pth.tar')
32 | output_prefix = 'best_train_'
33 | elif checkpoint_type == 'best_val':
34 | model_path = os.path.join(CHECKPOINT_FOLDER, 'model_best_val.pth.tar')
35 | output_prefix = 'best_val_'
36 | elif checkpoint_type == 'epoch_idx':
37 | model_path = os.path.join(CHECKPOINT_FOLDER, 'ep_{}.pth.tar'.format(str(inferenceConfig['model_epoch'])))
38 | output_prefix = str(inferenceConfig['model_epoch'])+ '_'
39 |
40 | pretrainCfg = yaml.full_load(open(os.path.join(CHECKPOINT_FOLDER,"config.yml"), 'r'))
41 | modelConfig = pretrainCfg['MODEL']
42 |
43 | # create result folder
44 | if not os.path.exists(midi_folder):
45 | os.mkdir(midi_folder)
46 |
47 | # load dictionary
48 | event2word, word2event = pickle.load(open(inferenceConfig['dictionary_path'], 'rb'))
49 |
50 | # declare model
51 | device = torch.device("cuda" if not inferenceConfig["no_cuda"] and torch.cuda.is_available() else "cpu")
52 | print('Device to generate:', device)
53 |
54 | # declare model
55 | model = TransformerXL(
56 | modelConfig,
57 | device,
58 | event2word=event2word,
59 | word2event=word2event,
60 | is_training=False)
61 |
62 | # inference
63 | song_time_list = []
64 | words_len_list = []
65 | num_samples = inferenceConfig["num_sample"]
66 | for idx in range(num_samples):
67 | print(f'==={idx}/{num_samples}===')
68 | print(midi_folder, output_prefix + str(idx))
69 | song_time, word_len = model.inference(
70 | model_path = model_path,
71 | token_lim=7680,
72 | strategies=['temperature', 'nucleus'],
73 | params={'t': 1.2, 'p': 0.9},
74 | bpm=120,
75 | output_path='{}/{}.mid'.format(midi_folder, output_prefix + str(idx)))
76 |
77 | print('song time:', song_time)
78 | print('word_len:', word_len)
79 | words_len_list.append(word_len)
80 | song_time_list.append(song_time)
81 |
82 |
83 | print('ave token time:', sum(words_len_list) / sum(song_time_list))
84 | print('ave song time:', np.mean(song_time_list))
85 |
86 | runtime_result = {
87 | 'song_time':song_time_list,
88 | 'words_len_list': words_len_list,
89 | 'ave token time:': sum(words_len_list) / sum(song_time_list),
90 | 'ave song time': float(np.mean(song_time_list)),
91 | }
92 |
93 |
94 | with open('runtime_stats.json', 'w') as f:
95 | json.dump(runtime_result, f)
96 |
97 | if __name__ == '__main__':
98 | main()
--------------------------------------------------------------------------------
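
The core of `model.inference` (defined in `model.py` below) is the Transformer-XL recurrence visible in the call `_logits, mems = model.generate(temp_x, *mems)`: after priming, only the newest token is fed forward while the memory tuple carries all earlier context. A schematic of that loop, with a dummy stand-in for `MemTransformerLM` (random logits, purely illustrative):

```python
import numpy as np

class DummyModel:
    """Stand-in for MemTransformerLM: random logits, token ids as 'memory'."""
    def __init__(self, n_token=332):
        self.n_token = n_token
    def generate(self, x, *mems):
        logits = np.random.randn(self.n_token)
        return logits, mems + (int(x[0, 0]),)   # append newest token to memory

def generate_schematic(model, first_word, token_lim):
    # mirrors model.inference below: prime with the first token, then feed
    # back only the newest token while mems carry all earlier context
    words = [first_word]
    mems = tuple()
    x = np.array([[first_word]])           # shape (seq_len=1, batch=1)
    while len(words) < token_lim:
        logits, mems = model.generate(x, *mems)
        word = int(np.argmax(logits))      # greedy here; the real loop samples
        words.append(word)
        x = np.array([[word]])
    return words

print(generate_schematic(DummyModel(), first_word=0, token_lim=8))
```
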
/workspace/uncond/remi-xl/model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import math
6 | import numpy as np
7 | import pandas as pd
8 | import miditoolkit
9 | import shutil
10 | import copy
11 | import os
12 | import time
13 | import json
14 | from sklearn.model_selection import train_test_split
15 | from modules import MemTransformerLM
16 | from glob import glob
17 |
18 |
19 | from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note
20 | import collections
21 | import pickle
22 |
23 |
24 | import saver
25 |
26 | # ================================ #
27 | BEAT_RESOL = 480
28 | BAR_RESOL = BEAT_RESOL * 4
29 | TICK_RESOL = BEAT_RESOL // 4
30 | INSTR_NAME_MAP = {'piano': 0, 'melody': 1}
31 |
32 |
33 | def wrtie_midi(words, path_midi, word2event):
34 | notes_all = []
35 |
36 | events = [word2event[words[i]] for i in range(len(words))]
37 |
38 | bar_cnt = 0
39 | cur_beat = 0
40 |
41 | midi_obj = miditoolkit.midi.parser.MidiFile()
42 | cur_pos = 0
43 |
44 | for i in range(len(events)-3):
45 | cur_event = events[i]
46 | # print(cur_event)
47 | name = cur_event.split('_')[0]
48 | attr = cur_event.split('_')
49 | if name == 'Bar':
50 | bar_cnt += 1
51 | elif name == 'Beat':
52 | cur_beat = int(attr[1])
53 | cur_pos = bar_cnt * BAR_RESOL + cur_beat * TICK_RESOL
54 | elif name == 'Chord':
55 | chord_text = attr[1] + '_' + attr[2]
56 | midi_obj.markers.append(Marker(text=chord_text, time=cur_pos))
57 | elif name == 'Tempo':
58 | midi_obj.tempo_changes.append(
59 | TempoChange(tempo=int(attr[1]), time=cur_pos))
60 | else:
61 | if 'Note_Pitch' in events[i] and \
62 | 'Note_Velocity' in events[i+1] and \
63 | 'Note_Duration' in events[i+2]:
64 |
65 | pitch = int(events[i].split('_')[-1])
66 | duration = int(events[i+2].split('_')[-1])
67 |
68 | if int(duration) == 0:
69 | duration = 60
70 |
71 | end = cur_pos + duration
72 | velocity = int(events[i+1].split('_')[-1])
73 | notes_all.append(
74 | Note(pitch=pitch, start=cur_pos, end=end, velocity=velocity))
75 |
76 | piano_track = Instrument(0, is_drum=False, name='piano')
77 | piano_track.notes = notes_all
78 | midi_obj.instruments = [piano_track]
79 | midi_obj.dump(path_midi)
80 |
81 |
82 | # ================================ #
83 | def network_paras(model):
84 | # compute only trainable params
85 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
86 | params = sum([np.prod(p.size()) for p in model_parameters])
87 | return params
88 |
89 |
90 | class TransformerXL(object):
91 | def __init__(self, modelConfig, device, event2word, word2event, is_training=True):
92 |
93 | self.event2word = event2word
94 | self.word2event = word2event
95 | self.modelConfig = modelConfig
96 |
97 | # model settings
98 |         self.n_layer = modelConfig['n_layer']
99 |         self.d_model = modelConfig['d_model']
100 |         self.seq_len = modelConfig['seq_len']
101 | self.mem_len = modelConfig['mem_len']
102 |
103 | self.tgt_len = modelConfig['tgt_len']
104 | self.ext_len = modelConfig['ext_len']
105 | self.eval_tgt_len = modelConfig['eval_tgt_len']
106 |
107 | self.init = modelConfig['init']
108 | self.init_range = modelConfig['init_range']
109 | self.init_std = modelConfig['init_std']
110 | self.proj_init_std = modelConfig['proj_init_std']
111 |
112 | #mode
113 | self.is_training = is_training
114 | self.device = device
115 |
116 |
117 | def init_weight(self, weight):
118 | if self.init == 'uniform':
119 | nn.init.uniform_(weight, -self.init_range, self.init_range)
120 | elif self.init == 'normal':
121 | nn.init.normal_(weight, 0.0, self.init_std)
122 |
123 | def init_bias(self, bias):
124 | nn.init.constant_(bias, 0.0)
125 |
126 | def weights_init(self,m):
127 | classname = m.__class__.__name__
128 | if classname.find('Linear') != -1:
129 | if hasattr(m, 'weight') and m.weight is not None:
130 | self.init_weight(m.weight)
131 | if hasattr(m, 'bias') and m.bias is not None:
132 | self.init_bias(m.bias)
133 | elif classname.find('Embedding') != -1:
134 | if hasattr(m, 'weight'):
135 | self.init_weight(m.weight)
136 | elif classname.find('LayerNorm') != -1:
137 | if hasattr(m, 'weight'):
138 | nn.init.normal_(m.weight, 1.0, self.init_std)
139 | if hasattr(m, 'bias') and m.bias is not None:
140 | self.init_bias(m.bias)
141 | elif classname.find('TransformerLM') != -1:
142 | if hasattr(m, 'r_emb'):
143 | self.init_weight(m.r_emb)
144 | if hasattr(m, 'r_w_bias'):
145 | self.init_weight(m.r_w_bias)
146 | if hasattr(m, 'r_r_bias'):
147 | self.init_weight(m.r_r_bias)
148 | if hasattr(m, 'r_bias'):
149 | self.init_bias(m.r_bias)
150 |
151 |
152 | def get_model(self, pretrain_model=None):
153 | model = MemTransformerLM(self.modelConfig, is_training=self.is_training)
154 |
155 |         st_epoch = 0
156 | if pretrain_model:
157 | checkpoint = torch.load(pretrain_model, map_location='cuda:0')
158 | print('Pretrained model config:')
159 | print('epoch: ', checkpoint['epoch'])
160 | print('best_loss: ', checkpoint['best_loss'])
161 | print(json.dumps(checkpoint['model_setting'], indent=1, sort_keys=True))
162 | print(json.dumps(checkpoint['train_setting'], indent=1, sort_keys=True))
163 |
164 | try:
165 | model.load_state_dict(checkpoint['state_dict'])
166 | print('{} loaded.'.format(pretrain_model))
167 | except:
168 |                 print('Loaded weights do not match the model architecture. Please check your model settings.')
169 | exit()
170 |             st_epoch = checkpoint['epoch'] + 1
171 |
172 | else:
173 | model.apply(self.weights_init)
174 | model.word_emb.apply(self.weights_init)
175 |         return st_epoch, model.to(self.device)
176 |
177 |
178 | def save_checkpoint(self, state, root, save_freq=10):
179 | if state['epoch'] % save_freq == 0:
180 | torch.save(state, os.path.join(root,'ep_{}.pth.tar'.format(state['epoch'])))
181 |
182 | def train_loss_record(self, epoch, train_loss,checkpoint_dir, val_loss=None):
183 |
184 | if val_loss:
185 | df = pd.DataFrame({'epoch': [epoch+1],
186 | 'train_loss': ['%.3f'%train_loss],
187 | 'val_loss': ['%.3f'%val_loss]})
188 |
189 | else:
190 | df = pd.DataFrame({'epoch': [epoch+1],
191 | 'train_loss': ['%.3f'%train_loss]})
192 |
193 | csv_file = os.path.join(checkpoint_dir, 'loss.csv')
194 |
195 | if not os.path.exists(csv_file):
196 | df.to_csv(csv_file, index=False)
197 | else:
198 | df.to_csv(os.path.join(checkpoint_dir, 'loss.csv'), mode='a', header=False, index=False)
199 |
200 | def train(self, train_data, trainConfig, device, resume):
201 | checkpoint_dir = trainConfig['experiment_Dir']
202 | batch_size = trainConfig['batch_size']
203 | data_ROOT = trainConfig['ROOT']
204 | torch.manual_seed(trainConfig["seed"])
205 |
206 | # create saver
207 | saver_agent = saver.Saver(checkpoint_dir)
208 |
209 | #Prepare model
210 | if resume != 'None':
211 | st_epoch, model = self.get_model(resume)
212 |             print('Resuming training from epoch {}'.format(st_epoch))
213 | else:
214 | st_epoch, model = self.get_model()
215 |
216 | optimizer = optim.Adam(model.parameters(), lr=trainConfig['lr'])
217 | train_step = 0
218 | epoch_train_loss = []
219 | save_freq = trainConfig['save_freq']
220 |
221 | n_parameters = network_paras(model)
222 | print('n_parameters: {:,}'.format(n_parameters))
223 | saver_agent.add_summary_msg(
224 | ' > params amount: {:,d}'.format(n_parameters))
225 |
226 | # unpack
227 | train_x = train_data['x']
228 | train_y = train_data['y']
229 | mask = train_data['mask']
230 | num_groups = train_data['num_groups']
231 |
232 |         num_batches = len(train_x) // batch_size
233 |
234 | print('>>> Start training')
235 | for epoch in range(st_epoch, trainConfig['num_epochs']):
236 | saver_agent.global_step_increment()
237 |
238 | train_loss = []
239 | st_time = time.time()
240 | model.train()
241 |
242 | for bidx in range(num_batches):
243 |
244 | model.zero_grad()
245 |
246 | # index
247 | bidx_st = batch_size * bidx
248 | bidx_ed = batch_size * (bidx + 1)
249 |
250 | # get batch
251 | batch_x = train_x[bidx_st:bidx_ed]
252 | batch_y = train_y[bidx_st:bidx_ed]
253 | batch_mask = mask[bidx_st:bidx_ed]
254 | n_group = np.max(num_groups[bidx_st:bidx_ed])
255 |
256 | # proc groups
257 | mems = tuple()
258 | for gidx in range(n_group):
259 | group_x = batch_x[:, gidx, :]
260 | group_y = batch_y[:, gidx, :]
261 | group_mask = batch_mask[:, gidx, :]
262 |
263 | group_x = torch.from_numpy(group_x).permute(1, 0).contiguous().to(self.device).long() # (seq_len, bsz)
264 | group_y = torch.from_numpy(group_y).permute(1, 0).contiguous().to(self.device).long()
265 | group_mask = torch.from_numpy(group_mask).to(self.device).float()
266 |
267 | ret = model(group_x, group_y, group_mask, *mems)
268 | loss, mems = ret[0], ret[1:]
269 | train_loss.append(loss.item())
270 | loss.backward()
271 |
272 | sys.stdout.write('epoch:{:3d}/{:3d}, batch: {:4d}/{:4d}, group: {:2d}/{:2d} | Loss: {:6f}\r'.format(
273 | epoch,
274 | trainConfig['num_epochs'],
275 | bidx,
276 | num_batches,
277 | gidx,
278 | n_group,
279 | loss.item()
280 | ))
281 | sys.stdout.flush()
282 |
283 | optimizer.step()
284 |
285 | #val_loss = self.validate(val_data, batch_size, model, trainConfig["seed"], trainConfig['max_eval_steps'])
286 | curr_train_loss = sum(train_loss) / len(train_loss)
287 | saver_agent.add_summary('epoch loss', curr_train_loss)
288 |
289 | #epoch_val_loss.append(val_loss)
290 | epoch_train_loss.append(curr_train_loss)
291 | # epoch_info = 'Train Loss: {:.5f} , Val Loss: {:.5f}, T: {:.3f}'.format(curr_train_loss, val_loss, time.time()-st_time)
292 | epoch_info = 'Epoch: {}, Train Loss: {:.5f}, T: {:.3f}'.format(epoch+1, curr_train_loss, time.time()-st_time)
293 | print(epoch_info)
294 |
295 | # self.train_loss_record(epoch, curr_train_loss, checkpoint_dir, val_loss)
296 | self.train_loss_record(epoch, curr_train_loss, checkpoint_dir)
297 | self.save_checkpoint({
298 | 'epoch': epoch + 1,
299 | 'model_setting': self.modelConfig,
300 | 'train_setting': trainConfig,
301 | 'state_dict': model.state_dict(),
302 | 'best_loss': curr_train_loss,
303 | 'optimizer': optimizer.state_dict(),
304 | },
305 | checkpoint_dir,
306 | save_freq)
307 |
308 | if curr_train_loss < 0.01:
309 | print('Experiment [{}] finished at loss < 0.01.'.format(checkpoint_dir))
310 | break
311 |
312 | def inference(self, model_path, token_lim, strategies, params, bpm, output_path):
313 | _, model = self.get_model(model_path)
314 | model.eval()
315 |
316 | # initial start
317 | words = [[]]
318 |
319 | # add beat
320 | words[-1].append(self.event2word['Bar_None'])
321 |
322 | # initialize mem
323 | mems = tuple()
324 | song_init_time = time.time()
325 | # generate
326 | initial_flag = True
327 | generate_n_bar = 0
328 | batch_size = 1
329 | n_tokens = len(words[0])
330 | while len(words[0]) < token_lim:
331 | # prepare input
332 | if initial_flag:
333 | temp_x = np.zeros((len(words[0]), batch_size))
334 |
335 | for b in range(batch_size):
336 | for z, t in enumerate(words[b]):
337 | temp_x[z][b] = t
338 |
339 | initial_flag = False
340 | else:
341 | temp_x = np.zeros((1, batch_size))
342 |
343 | for b in range(batch_size):
344 | temp_x[0][b] = words[b][-1]  # feed only the most recent token; mems carry the context
345 |
346 | temp_x = torch.from_numpy(temp_x).long().to(self.device)
347 | st_time = time.time()
348 |
349 | _logits, mems = model.generate(temp_x, *mems)
350 | logits = _logits.cpu().squeeze().detach().numpy()
351 |
352 | # temperature or not
353 | if 'temperature' in strategies:
354 | probs = self.temperature(logits=logits, temperature=params['t'])
355 |
356 | else:
357 | probs = self.temperature(logits=logits, temperature=1.)
358 | # sampling
359 | word = self.nucleus(probs=probs, p=params['p'])
360 | words[0].append(word)
361 |
362 | print(len(words[0]), self.word2event[word])
363 | # record n_bar
364 | if word == self.event2word['Bar_None']:
365 | generate_n_bar += 1
366 |
367 |
368 | wrtie_midi(words[0], output_path, self.word2event)
369 |
370 | song_total_time = time.time() - song_init_time
371 | print('Total words generated: ', len(words[0]))
372 | return song_total_time, len(words[0])
373 |
374 | ########################################
375 | # search strategy: temperature (re-shape)
376 | ########################################
377 | def temperature(self, logits, temperature):
378 | probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature))
379 | return probs
380 |
381 | ########################################
382 | # search strategy: topk (truncate)
383 | ########################################
384 | def topk(self, probs, k):
385 | sorted_index = np.argsort(probs)[::-1]
386 | candi_index = sorted_index[:k]
387 | candi_probs = [probs[i] for i in candi_index]
388 | candi_probs /= sum(candi_probs)
389 | word = np.random.choice(candi_index, size=1, p=candi_probs)[0]
390 | return word
391 |
392 | ########################################
393 | # search strategy: nucleus (truncate)
394 | ########################################
395 | def nucleus(self, probs, p):
396 | probs /= sum(probs)
397 | sorted_probs = np.sort(probs)[::-1]
398 | sorted_index = np.argsort(probs)[::-1]
399 | cusum_sorted_probs = np.cumsum(sorted_probs)
400 | after_threshold = cusum_sorted_probs > p
401 | if sum(after_threshold) > 0:
402 | last_index = np.where(after_threshold)[0][0] + 1
403 | candi_index = sorted_index[:last_index]
404 | else:
405 | candi_index = sorted_index[:3]  # fallback: no cumulative mass exceeds p, keep the top 3
406 | candi_probs = [probs[i] for i in candi_index]
407 | candi_probs /= sum(candi_probs)
408 | word = np.random.choice(candi_index, size=1, p=candi_probs)[0]
409 | return word
--------------------------------------------------------------------------------
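
Note: the two helpers above are what inference() combines at every step: temperature() re-shapes the logits into a softmax distribution, and nucleus() samples from the smallest set of tokens whose cumulative probability exceeds p. A self-contained sketch of the same two steps, assuming NumPy only (the toy logits and the function name are illustrative, not part of the repository):

    import numpy as np

    def sample_nucleus(logits, temperature=1.2, p=0.9):
        # softmax with temperature re-shaping, as in TransformerXL.temperature
        probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature))
        # keep the shortest prefix whose cumulative mass exceeds p; since the
        # cumulative sum ends at 1.0, a crossing index always exists for p < 1
        order = np.argsort(probs)[::-1]
        last = np.where(np.cumsum(probs[order]) > p)[0][0] + 1
        candidates = order[:last]
        candi_probs = probs[candidates] / np.sum(probs[candidates])
        return np.random.choice(candidates, p=candi_probs)

    next_word = sample_nucleus(np.array([2.0, 1.0, 0.5, 0.1]))
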
/workspace/uncond/remi-xl/modules.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import math
3 | import functools
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class PositionalEmbedding(nn.Module):
12 | def __init__(self, demb):
13 | super(PositionalEmbedding, self).__init__()
14 |
15 | self.demb = demb
16 |
17 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
18 | self.register_buffer('inv_freq', inv_freq)
19 |
20 | def forward(self, pos_seq, bsz=None):
21 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
22 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
23 |
24 | if bsz is not None:
25 | return pos_emb[:,None,:].expand(-1, bsz, -1)
26 | else:
27 | return pos_emb[:,None,:]
28 |
29 |
30 | class PositionwiseFF(nn.Module):
31 | def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
32 | super(PositionwiseFF, self).__init__()
33 |
34 | self.d_model = d_model
35 | self.d_inner = d_inner
36 | self.dropout = dropout
37 |
38 | self.CoreNet = nn.Sequential(
39 | nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
40 | nn.Dropout(dropout),
41 | nn.Linear(d_inner, d_model),
42 | nn.Dropout(dropout),
43 | )
44 |
45 | self.layer_norm = nn.LayerNorm(d_model)
46 |
47 | self.pre_lnorm = pre_lnorm
48 |
49 | def forward(self, inp):
50 | if self.pre_lnorm:
51 | ##### layer normalization + positionwise feed-forward
52 | core_out = self.CoreNet(self.layer_norm(inp))
53 |
54 | ##### residual connection
55 | output = core_out + inp
56 | else:
57 | ##### positionwise feed-forward
58 | core_out = self.CoreNet(inp)
59 |
60 | ##### residual connection + layer normalization
61 | output = self.layer_norm(inp + core_out)
62 |
63 | return output
64 |
65 |
66 | class RelMultiHeadAttn(nn.Module):
67 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
68 | tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
69 | super(RelMultiHeadAttn, self).__init__()
70 |
71 | self.n_head = n_head
72 | self.d_model = d_model
73 | self.d_head = d_head
74 | self.dropout = dropout
75 |
76 | self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
77 |
78 | self.drop = nn.Dropout(dropout)
79 | self.dropatt = nn.Dropout(dropatt)
80 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
81 |
82 | self.layer_norm = nn.LayerNorm(d_model)
83 |
84 | self.scale = 1 / (d_head ** 0.5)
85 |
86 | self.pre_lnorm = pre_lnorm
87 |
88 | def _parallelogram_mask(self, h, w, left=False):
89 | mask = torch.ones((h, w)).byte()
90 | m = min(h, w)
91 | mask[:m,:m] = torch.triu(mask[:m,:m])
92 | mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
93 |
94 | if left:
95 | return mask
96 | else:
97 | return mask.flip(0)
98 |
99 | def _shift(self, x, qlen, klen, mask, left=False):
100 | if qlen > 1:
101 | zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
102 | device=x.device, dtype=x.dtype)
103 | else:
104 | zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
105 |
106 | if left:
107 | mask = mask.flip(1)
108 | x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
109 | else:
110 | x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
111 |
112 | x = x_padded.masked_select(mask[:,:,None,None]) \
113 | .view(qlen, klen, x.size(2), x.size(3))
114 |
115 | return x
116 |
117 | def _rel_shift(self, x, zero_triu=False):
118 | zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
119 | device=x.device, dtype=x.dtype)
120 | x_padded = torch.cat([zero_pad, x], dim=1)
121 |
122 | x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
123 |
124 | x = x_padded[1:].view_as(x)
125 |
126 | if zero_triu:
127 | ones = torch.ones((x.size(0), x.size(1)))
128 | x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
129 |
130 | return x
131 |
132 | def forward(self, w, r, attn_mask=None, mems=None):
133 | raise NotImplementedError
134 |
135 | class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
136 | def __init__(self, *args, **kwargs):
137 | super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
138 |
139 | self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
140 |
141 | def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
142 | qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
143 |
144 | if mems is not None:
145 | # print("w",w.shape)
146 | # print("mems",mems.shape)
147 | cat = torch.cat([mems, w], 0)
148 | if self.pre_lnorm:
149 | w_heads = self.qkv_net(self.layer_norm(cat))
150 | else:
151 | w_heads = self.qkv_net(cat)
152 | r_head_k = self.r_net(r)
153 |
154 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
155 | w_head_q = w_head_q[-qlen:]
156 | else:
157 | if self.pre_lnorm:
158 | w_heads = self.qkv_net(self.layer_norm(w))
159 | else:
160 | w_heads = self.qkv_net(w)
161 | r_head_k = self.r_net(r)
162 |
163 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
164 |
165 | klen = w_head_k.size(0)
166 |
167 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
168 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
169 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
170 |
171 | r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # rlen x n_head x d_head
172 |
173 | #### compute attention score
174 | rw_head_q = w_head_q + r_w_bias
175 | AC = rw_head_q.permute(1, 2, 0, 3) @ w_head_k.permute(1, 2, 3, 0)
176 |
177 | rr_head_q = w_head_q + r_r_bias
178 | BD = rr_head_q.permute(1, 2, 0, 3) @ r_head_k.permute(1, 2, 0)
179 | BD = F.pad(BD, [1, 0]).view(BD.size(0), BD.size(
180 | 1), BD.size(3) + 1, BD.size(2))[:, :, 1:].view_as(BD)
181 |
182 | # [bsz x n_head x qlen x klen]
183 | attn_score = AC + BD
184 | attn_score.mul_(self.scale)
185 |
186 | #### compute attention probability
187 | if attn_mask is not None and attn_mask.any().item():
188 | if attn_mask.dim() == 2:
189 | attn_score = attn_score.float().masked_fill(
190 | attn_mask, -float('inf')).type_as(attn_score)
191 | elif attn_mask.dim() == 3:
192 | attn_score = attn_score.float().masked_fill(
193 | attn_mask.permute(2, 0, 1)[:, None, :, :], -float('inf')).type_as(attn_score)
194 |
195 | # [bsz x n_head x qlen x klen]
196 | attn_prob = F.softmax(attn_score, dim=-1)
197 | attn_prob = self.dropatt(attn_prob)
198 |
199 | #### compute attention vector
200 | attn_vec = attn_prob @ w_head_v.permute(1, 2, 0, 3)
201 | attn_vec = attn_vec.permute(2, 0, 1, 3)
202 |
203 | # [qlen x bsz x n_head x d_head]
204 | attn_vec = attn_vec.contiguous().view(
205 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
206 |
207 | ##### linear projection
208 | attn_out = self.o_net(attn_vec)
209 | attn_out = self.drop(attn_out)
210 |
211 | if self.pre_lnorm:
212 | ##### residual connection
213 | output = w + attn_out
214 | else:
215 | ##### residual connection + layer normalization
216 | output = self.layer_norm(w + attn_out)
217 |
218 | return output
219 |
220 |
221 | class RelPartialLearnableDecoderLayer(nn.Module):
222 | def __init__(self, n_head, d_model, d_head, d_inner, dropout,
223 | **kwargs):
224 | super(RelPartialLearnableDecoderLayer, self).__init__()
225 |
226 | self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
227 | d_head, dropout, **kwargs)
228 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
229 | pre_lnorm=kwargs.get('pre_lnorm'))
230 |
231 | def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
232 |
233 | output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
234 | attn_mask=dec_attn_mask,
235 | mems=mems)
236 | output = self.pos_ff(output)
237 |
238 | return output
239 |
240 |
241 | class Embeddings(nn.Module):
242 | def __init__(self, n_token, d_model):
243 | super(Embeddings, self).__init__()
244 | self.lut = nn.Embedding(n_token, d_model)
245 | self.d_model = d_model
246 |
247 | def forward(self, x):
248 | return self.lut(x) * math.sqrt(self.d_model)
249 |
250 |
251 |
252 | class MemTransformerLM(nn.Module):
253 | def __init__(self, modelConfig,
254 | tie_projs=[False], cutoffs=[],
255 | is_training=True):
256 | super(MemTransformerLM, self).__init__()
257 |
258 | self.n_token = modelConfig['n_token']
259 | self.n_layer = modelConfig['n_layer']
260 | self.n_head = modelConfig['n_head']
261 | self.d_model = modelConfig['d_model']
262 | self.d_embed = self.d_model if modelConfig['d_embed'] is None else modelConfig['d_embed']
263 | self.d_head = self.d_model // self.n_head
264 | self.d_inner = modelConfig['d_inner']
265 |
266 | self.mem_len = modelConfig['mem_len']
267 | self.tgt_len = modelConfig['tgt_len']
268 | self.ext_len = modelConfig['ext_len']
269 | self.max_klen = self.tgt_len + self.ext_len + self.mem_len #70+0+512
270 |
271 | self.dropout = modelConfig['dropout']
272 | self.dropatt = modelConfig['dropatt']
273 |
274 | self.clamp_len = modelConfig['clamp_len']
275 | self.div_val = modelConfig['div_val']
276 |
277 | #choice
278 | self.pre_lnorm = modelConfig['pre_lnorm']
279 | self.same_length = modelConfig['same_length']
280 | self.is_training = is_training
281 |
282 | #building layers
283 | self.drop = nn.Dropout(self.dropout)
284 | self.word_emb = Embeddings(self.n_token, self.d_model)
285 |
286 | self.layers = nn.ModuleList()
287 | for i in range(self.n_layer):
288 | self.layers.append(
289 | RelPartialLearnableDecoderLayer(
290 | self.n_head, self.d_model, self.d_head, self.d_inner, self.dropout,
291 | tgt_len=self.tgt_len, ext_len=self.ext_len, mem_len=self.mem_len,
292 | dropatt=self.dropatt, pre_lnorm=self.pre_lnorm)
293 | )
294 |
295 | # output layer
296 | self.linear_proj = nn.Linear(self.d_model, self.n_token)
297 |
298 | # loss
299 | self.loss_func = nn.CrossEntropyLoss(reduction='none')
300 | self._create_params()
301 |
302 | def compute_loss(self, predict, target, loss_mask=None):
303 | '''
304 | predict: (N, C, ...)
305 | target: (N, ...)
306 | loss_mask: (N, ...), same shape as target
307 | '''
308 | loss = self.loss_func(predict, target)
309 | loss = loss * loss_mask
310 | loss = torch.sum(loss) / torch.sum(loss_mask)
311 | return loss
312 |
313 | def _create_params(self):
314 | self.pos_emb = PositionalEmbedding(self.d_model)
315 | self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
316 | self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
317 |
318 | def reset_length(self, tgt_len, ext_len, mem_len):
319 | self.tgt_len = tgt_len
320 | self.mem_len = mem_len
321 | self.ext_len = ext_len
322 |
323 | def init_mems(self):
324 | if self.mem_len > 0:
325 | mems = []
326 | param = next(self.parameters())
327 | for i in range(self.n_layer+1):
328 | empty = torch.empty(0, dtype=param.dtype, device=param.device)
329 | mems.append(empty)
330 | return mems
331 | else:
332 | return None
333 |
334 | def _update_mems(self, hids, mems, mlen, qlen):
335 |
336 | if mems is None: return None
337 | # mems is not None
338 | # assert len(hids) == len(mems), 'len(hids) != len(mems)'
339 |
340 | # There are `mlen + qlen` steps that can be cached into mems
341 | # For the next step, the last `ext_len` of the `qlen` tokens
342 | # will be used as the extended context. Hence, we only cache
343 | # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
344 | # to `mlen + qlen - self.ext_len`.
345 | with torch.no_grad():
346 | new_mems = []
347 | end_idx = mlen + max(0, qlen - self.ext_len)
348 | beg_idx = max(0, end_idx - self.mem_len)
349 |
350 | for i in range(len(hids)):
351 | cat = torch.cat([mems[i], hids[i]], dim=0)
352 | new_mems.append(cat[beg_idx:end_idx].detach())
353 |
354 | return new_mems
355 |
356 |
357 |
358 | def _forward(self, dec_inp, mems=None):
359 | '''
360 | output of _forward: step x batch x n_feat
361 | predict = self.linear_proj(hidden)
362 | '''
363 |
364 | qlen, bsz = dec_inp.size()
365 | mlen = mems[0].size(0) if mems is not None else 0
366 | klen = mlen + qlen
367 |
368 | word_emb = self.word_emb(dec_inp)
369 |
370 | if self.same_length:
371 | all_ones = word_emb.new_ones(qlen, klen)
372 | mask_len = klen - self.mem_len
373 |
374 | if mask_len > 0:
375 | mask_shift_len = qlen - mask_len
376 | else:
377 | mask_shift_len = qlen
378 | dec_attn_mask = (torch.triu(all_ones, 1+mlen)
379 | + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
380 | else:
381 | dec_attn_mask = torch.triu(
382 | word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]
383 |
384 |
385 | hids = []
386 | pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
387 | dtype=word_emb.dtype)
388 | if self.clamp_len > 0:
389 | pos_seq.clamp_(max=self.clamp_len)
390 | pos_emb = self.pos_emb(pos_seq)
391 | core_out = self.drop(word_emb)
392 | pos_emb = self.drop(pos_emb)
393 | hids.append(core_out)
394 |
395 | for i, layer in enumerate(self.layers):
396 | mems_i = None if mems is None else mems[i]
397 |
398 | core_out = layer(core_out, pos_emb, self.r_w_bias,
399 | self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
400 | hids.append(core_out)
401 |
402 | core_out = self.drop(core_out)
403 | new_mems = self._update_mems(hids, mems, mlen, qlen)
404 |
405 | return core_out, new_mems
406 |
407 | def generate(self, data, *mems):
408 | if not mems: mems = self.init_mems()
409 | hidden, new_mems = self._forward(data, mems=mems)
410 | predict = self.linear_proj(hidden[-1:])
411 | return predict, new_mems
412 |
413 | def forward(self, data, target, mask, *mems):
414 | if not mems: mems = self.init_mems()
415 |
416 | tgt_len = target.size(0)
417 | hidden, new_mems = self._forward(data, mems=mems)
418 |
419 | pred_hid = hidden[-tgt_len:]
420 | predict = self.linear_proj(pred_hid)
421 |
422 | predict = predict.permute(1, 2, 0)
423 | target = target.permute(1, 0)
424 |
425 | loss = self.compute_loss(predict, target, mask)
426 |
427 | if new_mems is None:
428 | return [loss]
429 | else:
430 | return [loss] + new_mems
431 |
--------------------------------------------------------------------------------
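
Note: MemTransformerLM.generate() above returns the next-token logits together with the updated per-layer memories, which is what lets inference() feed a single token per step after priming. A minimal smoke test of that contract, assuming only the config keys read in __init__ (all sizes here are made up):

    import torch
    from modules import MemTransformerLM

    modelConfig = {
        'n_token': 100, 'n_layer': 2, 'n_head': 2, 'd_model': 32, 'd_embed': None,
        'd_inner': 64, 'mem_len': 16, 'tgt_len': 8, 'ext_len': 0,
        'dropout': 0.1, 'dropatt': 0.1, 'clamp_len': -1, 'div_val': 1,
        'pre_lnorm': False, 'same_length': False,
    }
    model = MemTransformerLM(modelConfig, is_training=False).eval()
    torch.nn.init.zeros_(model.r_w_bias)  # allocated uninitialized in _create_params
    torch.nn.init.zeros_(model.r_r_bias)

    x = torch.randint(0, 100, (8, 1))     # (seq_len, bsz), a made-up prompt
    with torch.no_grad():
        logits, mems = model.generate(x)            # prime with the whole prompt
        print(logits.shape, mems[0].shape)          # (1, 1, 100), (8, 1, 32)
        nxt = logits.argmax(-1)                     # greedy pick, shape (1, 1)
        logits, mems = model.generate(nxt, *mems)   # then one token per step
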
/workspace/uncond/remi-xl/runtime_stats.json:
--------------------------------------------------------------------------------
1 | {"song_time": [400.5098271369934, 377.5159821510315, 141.5237319469452, 141.3457088470459, 140.7412805557251, 146.80897665023804, 149.34970378875732, 180.57227396965027, 187.8685564994812, 190.16119360923767, 186.2575488090515, 184.2108030319214, 182.1258454322815, 179.762943983078, 185.42160320281982, 185.23676824569702, 188.1896812915802, 191.02102065086365, 186.58602023124695, 179.87972950935364], "words_len_list": [7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680, 7680], "ave token time:": 39.33328847340423, "ave song time": 195.25445997714996}
--------------------------------------------------------------------------------
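
Note: cross-checking the aggregates in this file against the raw lists shows that, despite its name, the "ave token time:" field records tokens generated per second: 20 songs of 7,680 tokens over roughly 3,905 seconds gives about 39.3. A small sketch to re-derive both aggregates, assuming it is run from this directory:

    import json

    with open('runtime_stats.json') as f:
        stats = json.load(f)

    total_time = sum(stats['song_time'])
    total_tokens = sum(stats['words_len_list'])
    print('avg song time (s):', total_time / len(stats['song_time']))  # ~195.25
    print('tokens per second:', total_tokens / total_time)             # ~39.33
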
/workspace/uncond/remi-xl/saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import logging
5 | import datetime
6 | import collections
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | class Saver(object):
12 | def __init__(
13 | self,
14 | exp_dir,
15 | mode='w'):
16 |
17 | self.exp_dir = exp_dir
18 | self.init_time = time.time()
19 | self.global_step = 0
20 |
21 | # makedirs
22 | os.makedirs(exp_dir, exist_ok=True)
23 |
24 | # logging config
25 | path_logger = os.path.join(exp_dir, 'log.txt')
26 | logging.basicConfig(
27 | level=logging.DEBUG,
28 | format='%(message)s',
29 | filename=path_logger,
30 | filemode=mode)
31 | self.logger = logging.getLogger('training monitor')
32 |
33 | def add_summary_msg(self, msg):
34 | self.logger.debug(msg)
35 |
36 | def add_summary(
37 | self,
38 | key,
39 | val,
40 | step=None,
41 | cur_time=None):
42 |
43 | if cur_time is None:
44 | cur_time = time.time() - self.init_time
45 | if step is None:
46 | step = self.global_step
47 |
48 | # write msg (key, val, step, time)
49 | if isinstance(val, float):
50 | msg_str = '{:10s} | {:.10f} | {:10d} | {}'.format(
51 | key,
52 | val,
53 | step,
54 | cur_time
55 | )
56 | else:
57 | msg_str = '{:10s} | {} | {:10d} | {}'.format(
58 | key,
59 | val,
60 | step,
61 | cur_time
62 | )
63 |
64 | self.logger.debug(msg_str)
65 |
66 | def save_model(
67 | self,
68 | model,
69 | optimizer=None,
70 | outdir=None,
71 | name='model'):
72 |
73 | if outdir is None:
74 | outdir = self.exp_dir
75 | print(' [*] saving model to {}, name: {}'.format(outdir, name))
76 | torch.save(model, os.path.join(outdir, name+'.pt'))
77 | torch.save(model.state_dict(), os.path.join(outdir, name+'_params.pt'))
78 |
79 | if optimizer is not None:
80 | torch.save(optimizer.state_dict(), os.path.join(outdir, name+'_opt.pt'))
81 |
82 | def load_model(
83 | self,
84 | path_exp,
85 | device='cpu',
86 | name='model.pt'):
87 |
88 | path_pt = os.path.join(path_exp, name)
89 | print(' [*] restoring model from', path_pt)
90 | model = torch.load(path_pt, map_location=torch.device(device))
91 | return model
92 |
93 | def global_step_increment(self):
94 | self.global_step += 1
95 |
96 | """
97 | file modes
98 | 'a':
99 | Opens a file for appending. The file pointer is at the end of the file if the file exists.
100 | That is, the file is in the append mode. If the file does not exist, it creates a new file for writing.
101 |
102 | 'w':
103 | Opens a file for writing only. Overwrites the file if the file exists.
104 | If the file does not exist, creates a new file for writing.
105 | """
106 |
107 | def make_loss_report(
108 | path_log,
109 | path_figure='loss.png',
110 | dpi=100):
111 |
112 | # load logfile
113 | monitor_vals = collections.defaultdict(list)
114 | with open(path_log, 'r') as f:
115 | for line in f:
116 | try:
117 | line = line.strip()
118 | key, val, step, acc_time = line.split(' | ')
119 | monitor_vals[key].append((float(val), int(step), acc_time))
120 | except ValueError:  # skip log lines that are not 'key | val | step | time'
121 | continue
122 |
123 | # collect
124 | step_train = [item[1] for item in monitor_vals['train loss']]
125 | vals_train = [item[0] for item in monitor_vals['train loss']]
126 |
127 | step_valid = [item[1] for item in monitor_vals['valid loss']]
128 | vals_valid = [item[0] for item in monitor_vals['valid loss']]
129 |
130 | x_min = step_valid[np.argmin(vals_valid)]
131 | y_min = min(vals_valid)
132 |
133 | # plot
134 | fig = plt.figure(dpi=dpi)
135 | plt.title('training process')
136 | plt.plot(step_train, vals_train, label='train')
137 | plt.plot(step_valid, vals_valid, label='valid')
138 | plt.yscale('log')
139 | plt.plot([x_min], [y_min], 'ro')
140 | plt.legend(loc='upper right')
141 | plt.tight_layout()
142 | plt.savefig(path_figure)
143 |
144 | '''
145 | author: wayn391@mastertones
146 | '''
--------------------------------------------------------------------------------
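
Note: a minimal usage sketch for the Saver above, mirroring how model.py drives it during training (the experiment directory and the logged values are illustrative):

    import saver

    saver_agent = saver.Saver('./exp/demo')  # creates ./exp/demo and its log.txt
    saver_agent.add_summary_msg(' > params amount: 1,000,000')
    for epoch in range(3):
        saver_agent.global_step_increment()
        saver_agent.add_summary('epoch loss', 1.0 / (epoch + 1))
    # make_loss_report() can then plot entries logged under the keys
    # 'train loss' and 'valid loss' from the same 'key | val | step | time' lines
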
/workspace/uncond/remi-xl/songTime.csv:
--------------------------------------------------------------------------------
1 | ,song_name,song_time,num_word,ave_genTime_per_word
2 | 0,60_673cd223aa2c317002ac4c975a9c67df,70.56650042533875,3940,0.01676041873421256
3 | 1,60_c0793ed178d945ddd18bfb7cc7b190b8,35.46065306663513,2001,0.016596261803452316
4 |
--------------------------------------------------------------------------------
/workspace/uncond/remi-xl/train.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import json
4 | import yaml
5 | import pickle
6 | import datetime
7 | import numpy as np
8 | from collections import OrderedDict
9 |
10 | import torch
11 | from model import TransformerXL
12 |
13 |
14 | def main():
15 | # gen config
16 | modelConfig, trainConfig = get_configs()
17 |
18 | # load dictionary
19 | event2word, word2event = pickle.load(open(os.path.join(trainConfig['ROOT'],'dictionary.pkl'), 'rb'))
20 |
21 | # load train data
22 | training_data = np.load(os.path.join(trainConfig['ROOT'],'train_data_XL.npz'))
23 |
24 | os.environ['CUDA_VISIBLE_DEVICES'] = str(trainConfig['gpuID'])  # set before CUDA is initialized
25 | device = torch.device("cuda:0" if not trainConfig["no_cuda"] and torch.cuda.is_available() else "cpu")
26 |
27 | print('Device to train:', device)
28 |
29 | resume = trainConfig['resume_training_model']
30 |
31 | # declare model
32 | model = TransformerXL(
33 | modelConfig,
34 | device,
35 | event2word=event2word,
36 | word2event=word2event,
37 | is_training=True)
38 |
39 | # train
40 | model.train(training_data,
41 | trainConfig,
42 | device,
43 | resume)
44 |
45 |
46 | def get_configs():
47 | cfg = yaml.full_load(open("config.yml", 'r'))
48 |
49 | modelConfig = cfg['MODEL']
50 | trainConfig = cfg['TRAIN']
51 |
52 | cur_date = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
53 | experiment_Dir = os.path.join(trainConfig['output_dir'], cur_date)
54 | if not os.path.exists(experiment_Dir):
55 | os.makedirs(experiment_Dir)
56 |
57 | print('Experiment dir:', experiment_Dir)
58 | trainConfig.update({'experiment_Dir': experiment_Dir})
59 |
60 |
61 | with open(os.path.join(experiment_Dir, 'config.yml'), 'w') as f:
62 | yaml.dump(cfg, f)
63 |
64 | print('='*5, 'Model configs', '='*5)
65 | print(json.dumps(modelConfig, indent=1, sort_keys=True))
66 | print('='*5, 'Training configs', '='*5)
67 | print(json.dumps(trainConfig, indent=1, sort_keys=True))
68 | return modelConfig, trainConfig
69 |
70 |
71 | if __name__ == '__main__':
72 | main()
73 |
74 |
75 |
--------------------------------------------------------------------------------
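
Note: get_configs() above expects a config.yml next to train.py with MODEL and TRAIN sections. The keys below are the ones actually read in train.py, model.py, and modules.py; every value shown is illustrative, not the repository's shipped configuration:

    MODEL:
      n_token: 332        # vocabulary size, matching dictionary.pkl
      n_layer: 12
      n_head: 8
      d_model: 512
      d_embed: null       # falls back to d_model when null
      d_inner: 2048
      dropout: 0.1
      dropatt: 0.0
      mem_len: 512
      tgt_len: 512
      ext_len: 0
      clamp_len: -1
      div_val: 1
      pre_lnorm: false
      same_length: false
    TRAIN:
      ROOT: ./data                    # holds dictionary.pkl and train_data_XL.npz
      output_dir: ./exp
      gpuID: '0'
      no_cuda: false
      resume_training_model: 'None'   # the literal string 'None' trains from scratch
      batch_size: 4
      num_epochs: 200
      lr: 0.0002
      seed: 2020
      save_freq: 10                   # checkpoint every N epochs (see save_checkpoint)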