├── .gitignore
├── LICENSE
├── README.md
├── _config.yml
├── cog.yaml
├── dataset
│   ├── README.md
│   ├── compile.py
│   ├── corpus2events.py
│   ├── event2words.py
│   ├── get_cls_idx.ipynb
│   └── midi2corpus.py
├── docs
│   ├── _config.yml
│   ├── _layouts
│   │   └── default.html
│   ├── assets
│   │   ├── audio_samples
│   │   │   ├── 1_lstm+GA
│   │   │   │   ├── Q1
│   │   │   │   │   ├── gen_Q1_1.mp3
│   │   │   │   │   ├── gen_Q1_2.mp3
│   │   │   │   │   └── gen_Q1_3.mp3
│   │   │   │   ├── Q2
│   │   │   │   │   ├── gen_Q2_1.mp3
│   │   │   │   │   ├── gen_Q2_2.mp3
│   │   │   │   │   └── gen_Q2_3.mp3
│   │   │   │   ├── Q3
│   │   │   │   │   ├── gen_Q3_1.mp3
│   │   │   │   │   ├── gen_Q3_2.mp3
│   │   │   │   │   └── gen_Q3_3.mp3
│   │   │   │   └── Q4
│   │   │   │       ├── gen_Q4_1.mp3
│   │   │   │       ├── gen_Q4_2.mp3
│   │   │   │       └── gen_Q4_3.mp3
│   │   │   ├── 2_Transformer
│   │   │   │   ├── Q1
│   │   │   │   │   ├── Q1_1.mp3
│   │   │   │   │   ├── Q1_2.mp3
│   │   │   │   │   └── Q1_3.mp3
│   │   │   │   ├── Q2
│   │   │   │   │   ├── Q2_1.mp3
│   │   │   │   │   ├── Q2_2.mp3
│   │   │   │   │   └── Q2_3.mp3
│   │   │   │   ├── Q3
│   │   │   │   │   ├── Q3_1.mp3
│   │   │   │   │   ├── Q3_2.mp3
│   │   │   │   │   └── Q3_3.mp3
│   │   │   │   └── Q4
│   │   │   │       ├── Q4_1.mp3
│   │   │   │       ├── Q4_2.mp3
│   │   │   │       └── Q4_3.mp3
│   │   │   └── 3_Pre-trained_Transformer
│   │   │       ├── Q1
│   │   │       │   ├── Q1_1.mp3
│   │   │       │   ├── Q1_2.mp3
│   │   │       │   └── Q1_3.mp3
│   │   │       ├── Q2
│   │   │       │   ├── Q2_1.mp3
│   │   │       │   ├── Q2_2.mp3
│   │   │       │   └── Q2_3.mp3
│   │   │       ├── Q3
│   │   │       │   ├── Q3_1.mp3
│   │   │       │   ├── Q3_2.mp3
│   │   │       │   └── Q3_3.mp3
│   │   │       └── Q4
│   │   │           ├── Q4_1.mp3
│   │   │           ├── Q4_2.mp3
│   │   │           └── Q4_3.mp3
│   │   └── css
│   │       └── style.css
│   ├── img
│   │   ├── VA.png
│   │   ├── VA_Q.png
│   │   ├── background.jpg
│   │   ├── classification.png
│   │   ├── cp.png
│   │   ├── emopia.png
│   │   ├── example.png
│   │   ├── feature.png
│   │   ├── key.png
│   │   ├── number.png
│   │   ├── pipeline.png
│   │   ├── representation.png
│   │   └── results.png
│   ├── index.md
│   └── init
├── predict.py
├── requirements.txt
└── workspace
    ├── baseline
    │   ├── evolve_generative_base.py
    │   ├── midi_encoder.py
    │   ├── midi_generator.py
    │   ├── plot_results.py
    │   ├── plot_results_base.py
    │   ├── train_classifier.py
    │   └── train_generative.py
    └── transformer
        ├── generate.ipynb
        ├── main_cp.py
        ├── models.py
        ├── saver.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | node_modules/*
2 | .idea/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 |
135 | # data
136 | *.mid
137 | *.npy
138 | *.npz
139 | *.pkl
140 | *.zip
141 | *.json
142 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Hsiao-Tzu Hung
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | This is the official repository of **EMOPIA: A Multi-Modal Pop Piano Dataset For Emotion Recognition and Emotion-based Music Generation**. The paper has been accepted by the International Society for Music Information Retrieval Conference (ISMIR) 2021.
7 |
8 | - [Paper on Arxiv](https://arxiv.org/abs/2108.01374)
9 | - [Demo Page](https://annahung31.github.io/EMOPIA/)
10 | - [Interactive demo and Docker image on Replicate](https://replicate.ai/annahung31/emopia)
11 | - [Dataset at Zenodo](https://zenodo.org/record/5090631#.YPPo-JMzZz8)
12 |
13 | * Note: We release the transcribed MIDI files. As for the audio part, due to copyright issues, we only release the YouTube IDs of the tracks and their timestamps. You can use an [open-source downloader](https://github.com/ytdl-org/youtube-dl) to obtain the audio files.
14 |
15 |
16 | ## Use EMOPIA by MusPy
17 | 1. Install MusPy:
18 | ```
19 | pip install muspy
20 | ```
21 | 2. Use it in your script
22 |
23 | ```
24 | import muspy
25 |
26 | emopia = muspy.EMOPIADataset("data/emopia/", download_and_extract=True)
27 | emopia.convert()
28 | music = emopia[0]
29 | print(music.annotations[0].annotation)
30 | ```
31 | You can get the label of the piece of music:
32 |
33 | ```
34 | {'emo_class': '1', 'YouTube_ID': '0vLPYiPN7qY', 'seg_id': '0'}
35 | ```
36 | * `emo_class`: ['1', '2', '3', '4']
37 | * `YouTube_ID`: the YouTube ID of this piece of music
38 | * `seg_id`: the index of this clip within the song, i.e., it is the `i`-th segment taken from that song (zero-based).
39 |
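As a quick sanity check, the minimal sketch below (reusing the exact calls from the snippet above) tallies how many clips fall into each emotion class:

```python
import collections

import muspy

# Same setup as the snippet above.
emopia = muspy.EMOPIADataset("data/emopia/", download_and_extract=True)
emopia.convert()

# Count clips per emotion quadrant using the annotation dict shown above.
counts = collections.Counter(
    music.annotations[0].annotation["emo_class"] for music in emopia
)
for emo_class in sorted(counts):
    print(f"Q{emo_class}: {counts[emo_class]} clips")
```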
40 | For more usage please refer to [MusPy](https://github.com/salu133445/muspy).
41 |
42 |
43 | # Emotion Classification
44 |
45 | For the classification models and codes, please refer to [this repo](https://github.com/SeungHeonDoh/EMOPIA_cls).
46 |
47 |
48 | # Conditional Generation
49 |
50 | ## Environment
51 |
52 | 1. Install PyTorch and fast transformer:
53 | - torch==1.7.0 (Please install it according to your [CUDA version](https://pytorch.org/get-started/previous-versions/#linux-and-windows-4).)
54 | - fast transformer :
55 |
56 | ```
57 | pip install --user pytorch-fast-transformers
58 | ```
59 | or refer to the original [repository](https://github.com/idiap/fast-transformers)
60 |
61 | 2. Other requirements:
62 |
63 | pip install -r requirements.txt
64 |
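A minimal sanity check that the two key dependencies import correctly (the CUDA line assumes a GPU build of torch):

```python
import torch
from fast_transformers.builders import TransformerEncoderBuilder

# Both imports should succeed if the environment is set up correctly.
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("fast-transformers builder:", TransformerEncoderBuilder.__name__)
```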
65 |
66 | ## Usage
67 |
68 | ### Inference
69 |
70 | **Option 1:**
71 |
72 | You can directly run the [generate script](./workspace/transformer/generate.ipynb) to generate pieces of music and listen to them.
73 |
74 |
75 | **Option 2:**
76 | Or follow the steps below.
77 |
78 | 1. Download the checkpoints and put them into `exp/`
79 | * Manually:
80 | - [Baseline](https://drive.google.com/file/d/1Q9vQYnNJ0hXBFwcxdWQgDNmzoW3MLl3h/view?usp=sharing)
81 | - [no-pretrained transformer](https://drive.google.com/file/d/1ZULJgBRu2Wb3jxFmGfAHP1v_tjoryFM7/view?usp=sharing)
82 | - [pretrained transformer](https://drive.google.com/file/d/19Seq18b2JNzOamEQMG1uarKjj27HJkHu/view?usp=sharing)
83 |
84 | * By command (install gdown first: `pip install gdown`):
85 | ```
86 | #baseline:
87 | gdown --id 1Q9vQYnNJ0hXBFwcxdWQgDNmzoW3MLl3h --output exp/baseline.zip
88 |
89 | # no-pretrained transformer
90 | gdown --id 1ZULJgBRu2Wb3jxFmGfAHP1v_tjoryFM7 --output exp/no-pretrained_transformer.zip
91 |
92 | # pretrained transformer
93 | gdown --id 19Seq18b2JNzOamEQMG1uarKjj27HJkHu --output exp/pretrained_transformer.zip
94 | ```
95 |
96 |
97 |
98 | 2. Inference options:
99 |
100 | * `num_songs`: number of MIDI files you want to generate.
101 | * `out_dir`: the folder where the generated MIDI files will be saved. If not specified, they will be saved to `exp/MODEL_YOU_USED/gen_midis/`.
102 | * `task_type`: must be the same as the task specified during training.
103 | - '4-cls' for 4 class conditioning
104 | - 'Arousal' for only conditioning on arousal
105 | - 'Valence' for only conditioning on Valence
106 | - 'ignore' for not conditioning
107 |
108 | * `emo_tag`: the target class of emotion you want to assign.
109 | - If the task_type is '4-cls', emo_tag can be `1`, `2`, `3`, `4`, which refer to Q1, Q2, Q3, Q4 respectively.
110 | - If the task_type is 'Arousal', emo_tag can be: `1`, `2`. `1` for High arousal, `2` for Low arousal.
111 | - If the task_type is 'Valence', emo_tag can be: `1`, `2`. `1` for High Valence, `2` for Low Valence.
112 |
113 |
114 | 3. Inference
115 |
116 | ```
117 | python main_cp.py --mode inference --task_type 4-cls --load_ckt CHECKPOINT_FOLDER --load_ckt_loss 25 --num_songs 10 --emo_tag 1
118 | ```
119 |
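After generation, the MIDI files end up in `exp/MODEL_YOU_USED/gen_midis/` unless you set `out_dir`. A quick way to inspect one of them with `miditoolkit` (already a dependency of this repo) is sketched below; the checkpoint folder name is only illustrative:

```python
import glob

import miditoolkit

# Illustrative output folder; adjust it to the checkpoint/out_dir you actually used.
midi_paths = sorted(glob.glob("exp/pretrained_transformer/gen_midis/*.mid"))
midi_obj = miditoolkit.midi.parser.MidiFile(midi_paths[0])

notes = midi_obj.instruments[0].notes
print(len(midi_paths), "generated files")
print(len(notes), "notes in the first file, spanning", midi_obj.max_tick, "ticks")
```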
120 | ### Train the model by yourself
121 | 1. Prepare the data following the [steps](https://github.com/annahung31/EMOPIA/tree/main/dataset).
122 |
123 |
124 | 2. Training options:
125 |
126 | * `exp_name`: the folder name where the checkpoints will be saved.
127 | * `data_parallel`: use data parallelism to speed up training. (0: don't use, 1: use)
128 | * `task_type`: the conditioning task:
129 | - '4-cls' for 4 class conditioning
130 | - 'Arousal' for only conditioning on arousal
131 | - 'Valence' for only conditioning on Valence
132 | - 'ignore' for not conditioning
133 |
134 | a. Only train on EMOPIA: (`no-pretrained transformer` in the paper)
135 |
136 | python main_cp.py --path_train_data emopia --exp_name YOUR_EXP_NAME --load_ckt none
137 |
138 | b. Pre-train the transformer on `AILabs17k`:
139 |
140 | python main_cp.py --path_train_data ailabs --exp_name YOUR_EXP_NAME --load_ckt none --task_type ignore
141 |
142 | c. Fine-tune the transformer on `EMOPIA`:
143 | For example, to fine-tune from the pre-trained model stored in `0309-1857` with loss `30`:
144 |
145 | python main_cp.py --path_train_data emopia --exp_name YOUR_EXP_NAME --load_ckt 0309-1857 --load_ckt_loss 30
146 |
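Before launching any of these runs, it can help to double-check the vocabulary built in the dataset step. `dictionary.pkl` holds the `(event2word, word2event)` pair written by `dataset/event2words.py`; the path below assumes you run this from the folder where the dictionary was created:

```python
import pickle

# Path as written by dataset/event2words.py; adjust if the file lives elsewhere.
with open("dictionary.pkl", "rb") as f:
    event2word, word2event = pickle.load(f)

for token_class, mapping in event2word.items():
    print(f"{token_class:10s}: {len(mapping)} tokens")
```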
147 | ### Baseline
148 | 1. The baseline code is based on the work of [Learning to Generate Music with Sentiment](https://github.com/lucasnfe/music-sentneuron)
149 |
150 | 2. According to the author, the model works best when trained with an LSTM of 4096 neurons, but that takes 12 days to train. Due to limited computational resources, we used 512 neurons instead of 4096.
151 |
152 | 3. In order to use it as a baseline against our model, the target emotion classes are expanded to the four quadrants (4Q) instead of just positive/negative.
153 |
154 | ## Authors
155 |
156 | The paper is a joint project with [Joann](https://github.com/joann8512), [SeungHeon](https://github.com/SeungHeonDoh) and Nabin. This repository is maintained by [Joann](https://github.com/joann8512) and me.
157 |
158 |
159 | ## License
160 | The EMOPIA dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International license (CC BY-NC-SA 4.0). It is provided primarily for research purposes and must not be used for commercial purposes. When sharing results based on EMOPIA, any act that defames the original music owners is strictly prohibited.
161 |
162 |
163 | The hand-drawn piano in the logo comes from [Adobe Stock](https://stock.adobe.com/tw/images/one-line-piano-instrument-design-hand-drawn-minimalistic-style-vector-illustration/327942843). The author is [Burak](https://stock.adobe.com/tw/contributor/206697762/burak?load_type=author&prev_url=detail). I purchased it under a [standard](https://stock.adobe.com/tw/license-terms) license.
164 |
165 | ## Cite the dataset
166 |
167 | ```
168 | @inproceedings{{EMOPIA},
169 | author = {Hung, Hsiao-Tzu and Ching, Joann and Doh, Seungheon and Kim, Nabin and Nam, Juhan and Yang, Yi-Hsuan},
170 | title = {{EMOPIA}: A Multi-Modal Pop Piano Dataset For Emotion Recognition and Emotion-based Music Generation},
171 | booktitle = {Proc. Int. Society for Music Information Retrieval Conf.},
172 | year = {2021}
173 | }
174 | ```
175 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | predict: "predict.py:Predictor"
2 | build:
3 | gpu: true
4 | system_packages:
5 | - "ffmpeg"
6 | - "fluidsynth"
7 | python_packages:
8 | - "torch==1.7.0"
9 | - "scikit-learn==0.24.1"
10 | - "seaborn==0.11.1"
11 | - "numpy==1.19.5"
12 | - "miditoolkit==0.1.14"
13 | - "pandas==1.1.5"
14 | - "tqdm==4.62.2"
15 | - "matplotlib==3.4.3"
16 | - "scipy==1.7.1"
17 | - "midiSynth==0.3"
18 | - "wheel==0.37.0"
19 | - "ipdb===0.13.9"
20 | - "pyfluidsynth==1.3.0"
21 | pre_install:
22 | - "pip install pytorch-fast-transformers==0.4.0" # needs to be installed after the main pip install
23 |
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Dataset processing
3 |
4 | ## If you want to use the processed data used in our paper...
5 | Download the npz files from [here](https://drive.google.com/file/d/17dKUf33ZsDbHC5Z6rkQclge3ppDTVCMP/view?usp=sharing), or download via gdown:
6 |
7 | ```
8 | gdown --id 17dKUf33ZsDbHC5Z6rkQclge3ppDTVCMP
9 | unzip co-representation.zip
10 | ```
11 |
12 |
13 | ## If you want to prepare data from scratch...
14 | The data pre-processing used in our paper is basically the same as [Compound-word-transformer](https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/Dataset.md). The difference is the emotion token part.
15 |
16 |
17 | 1. Run step 1-3 of [Compound-word-transformer dataset processing](https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/Dataset.md).
18 | 2. Change the paths in the following scripts and run them in order:
19 |
20 | a. Quantize everything, add an EOS token, and prepare the emotion label from the filename.
21 |
22 | ```
23 | python midi2corpus.py
24 | ```
25 |
26 | b. Convert the corpus files to CP events.
27 |
28 | ```
29 | python corpus2events.py
30 | ```
31 |
32 | c. Convert the events to CP words, and build the dictionary.
33 | ```
34 | python event2words.py
35 | ```
36 | d. Compile the word files into an npz file for training.
37 | ```
38 | python compile.py
39 | ```
40 | * Please note that I don't split the data into train/test sets; I use all the data for training because the task is to generate music from scratch and no validation is needed.
41 | * The `broken_list` in `compile.py` contains samples that ran into issues during preprocessing; they are simply skipped.
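After step d, the compiled file (named `train_data_linear.npz` with the default `COMPILE_TARGET = 'linear'` in `compile.py`) can be sanity-checked with the short sketch below; the field names follow the `np.savez` call in `compile.py`:

```python
import numpy as np

data = np.load("train_data_linear.npz")
print("x:", data["x"].shape)            # (num_samples, 1024, 8) CP token sequences
print("y:", data["y"].shape)            # targets, shifted by one token
print("mask:", data["mask"].shape)      # (num_samples, 1024) padding mask
print("seq_len:", data["seq_len"][:5])  # unpadded sequence lengths
```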
--------------------------------------------------------------------------------
/dataset/compile.py:
--------------------------------------------------------------------------------
1 | '''
2 | This code is from
3 | https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/representations/uncond/cp/compile.py
4 | '''
5 |
6 |
7 | import os
8 | import json
9 | import pickle
10 | import numpy as np
11 | import ipdb
12 |
13 |
14 | TEST_AMOUNT = 50
15 | WINDOW_SIZE = 512
16 | GROUP_SIZE = 2 #7
17 | MAX_LEN = WINDOW_SIZE * GROUP_SIZE
18 | COMPILE_TARGET = 'linear' # 'linear', 'XL'
19 | print('[config] MAX_LEN:', MAX_LEN)
20 |
21 | broken_list = [
22 | 'Q1_8izVTDgBQPc_0.mid.pkl.npy', 'Q1_uqRLEByE6pU_1.mid.pkl.npy', 'Q3_nBIls0laAAU_0.mid.pkl.npy',
23 | 'Q1_8rupdevqfuI_0.mid.pkl.npy', 'Q2_9v2WSpn4FCw_1.mid.pkl.npy', 'Q3_REq37pDAm3A_3.mid.pkl.npy',
24 | 'Q1_aYe-2Glruu4_3.mid.pkl.npy', 'Q2_BzqX-9TA-GY_2.mid.pkl.npy', 'Q3_RL_cmmNVLfs_0.mid.pkl.npy',
25 | 'Q1_FfwKrQyQ7WU_2.mid.pkl.npy', 'Q2_ivCNV47tsRw_1.mid.pkl.npy', 'Q3_wfXSdMsd4q8_4.mid.pkl.npy',
26 | 'Q1_GY3f6ckBVkA_1.mid.pkl.npy', 'Q2_RiQMuhk_SuQ_1.mid.pkl.npy', 'Q3_wqc8iqbDsGM_0.mid.pkl.npy',
27 | 'Q1_Jn9r0avp0fY_1.mid.pkl.npy', 'Q3_bbU31JLtlug_1.mid.pkl.npy', 'Q4_OUb9uaOlWAM_0.mid.pkl.npy',
28 | 'Q1_NGE9ynTJABg_0.mid.pkl.npy', 'Q3_c6CwY8Gbw0c_2.mid.pkl.npy', 'Q4_V3Y9L4UOcpk_1.mid.pkl.npy',
29 | 'Q1_QwsQ8ejbMKg_1.mid.pkl.npy', 'Q3_kDGmND1BgmA_1.mid.pkl.npy']
30 |
31 |
32 |
33 | def traverse_dir(
34 | root_dir,
35 | extension=('mid', 'MID'),
36 | amount=None,
37 | str_=None,
38 | is_pure=False,
39 | verbose=False,
40 | is_sort=False,
41 | is_ext=True):
42 | if verbose:
43 | print('[*] Scanning...')
44 | file_list = []
45 | cnt = 0
46 | for root, _, files in os.walk(root_dir):
47 | for file in files:
48 | if file in broken_list:
49 | continue
50 | if file.endswith(extension):
51 | if (amount is not None) and (cnt == amount):
52 | break
53 | if str_ is not None:
54 | if str_ not in file:
55 | continue
56 | mix_path = os.path.join(root, file)
57 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
58 | if not is_ext:
59 | ext = pure_path.split('.')[-1]
60 | pure_path = pure_path[:-(len(ext)+1)]
61 | if verbose:
62 | print(pure_path)
63 | file_list.append(pure_path)
64 | cnt += 1
65 | if verbose:
66 | print('Total: %d files' % len(file_list))
67 | print('Done!!!')
68 | if is_sort:
69 | file_list.sort()
70 | return file_list
71 |
72 |
73 | if __name__ == '__main__':
74 | # paths
75 | path_root = '.'
76 | path_indir = os.path.join( path_root, 'words')
77 |
78 | # load all words
79 | wordfiles = traverse_dir(
80 | path_indir,
81 | extension=('npy'))
82 | n_files = len(wordfiles)
83 |
84 | # init
85 | x_list = []
86 | y_list = []
87 | mask_list = []
88 | seq_len_list = []
89 | num_groups_list = []
90 | name_list = []
91 |
92 | # process
93 | for fidx in range(n_files):
94 | print('--[{}/{}]-----'.format(fidx+1, n_files))
95 | file = wordfiles[fidx]
96 | print(file)
97 | try:
98 | words = np.load(file)
99 | except:
100 | print(fidx)
101 |
102 | import ipdb
103 | ipdb.set_trace()
104 | num_words = len(words)
105 |
106 | eos_arr = words[-1][None, ...]
107 |
108 | if num_words >= MAX_LEN - 2: # 2 for room
109 | words = words[:MAX_LEN-2]
110 |
111 | # arrange IO
112 | x = words[:-1].copy() #without EOS
113 | y = words[1:].copy()
114 | seq_len = len(x)
115 | print(' > seq_len:', seq_len)
116 |
117 | # pad with eos
118 | pad = np.tile(
119 | eos_arr,
120 | (MAX_LEN-seq_len, 1))
121 |
122 | x = np.concatenate([x, pad], axis=0)
123 | y = np.concatenate([y, pad], axis=0)
124 | mask = np.concatenate(
125 | [np.ones(seq_len), np.zeros(MAX_LEN-seq_len)])
126 |
127 | # collect
128 | if x.shape != (1024, 8):
129 | print(x.shape)
130 | exit()
131 |
132 | x_list.append(x)
133 | y_list.append(y)
134 | mask_list.append(mask)
135 | seq_len_list.append(seq_len)
136 | num_groups_list.append(int(np.ceil(seq_len/WINDOW_SIZE)))
137 | name_list.append(file)
138 |
139 | # sort by length (descending)
140 | zipped = zip(seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list)
141 | seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list = zip(
142 | *sorted(zipped, key=lambda x: -x[0]))
143 |
144 | print('\n\n[Finished]')
145 | print(' compile target:', COMPILE_TARGET)
146 | if COMPILE_TARGET == 'XL':
147 | # reshape
148 | x_final = np.array(x_list).reshape(len(x_list), GROUP_SIZE, WINDOW_SIZE, -1)
149 | y_final = np.array(y_list).reshape(len(x_list), GROUP_SIZE, WINDOW_SIZE, -1)
150 | mask_final = np.array(mask_list).reshape(-1, GROUP_SIZE, WINDOW_SIZE)
151 | elif COMPILE_TARGET == 'linear':
152 |
153 | x_final = np.array(x_list)
154 | y_final = np.array(y_list)
155 | mask_final = np.array(mask_list)
156 | else:
157 | raise ValueError('Unknown target:', COMPILE_TARGET)
158 |
159 | # check
160 | num_samples = len(seq_len_list)
161 | print(' > count:', num_samples)
162 | print(' > x_final:', x_final.shape)
163 | print(' > y_final:', y_final.shape)
164 | print(' > mask_final:', mask_final.shape)
165 |
166 | train_idx = []
167 |
168 | # validation filename map
169 | fn2idx_map = {
170 | 'fn2idx': dict(),
171 | 'idx2fn': dict(),
172 | }
173 |
174 | # training filename map
175 | train_fn2idx_map = {
176 | 'fn2idx': dict(),
177 | 'idx2fn': dict(),
178 | }
179 |
180 | name_list = [x.split('/')[-1].split('.')[0] for x in name_list]
181 | # run split
182 | train_cnt = 0
183 | for nidx, n in enumerate(name_list):
184 |
185 | train_idx.append(nidx)
186 | train_fn2idx_map['fn2idx'][n] = train_cnt
187 | train_fn2idx_map['idx2fn'][train_cnt] = n
188 | train_cnt += 1
189 |
190 | train_idx = np.array(train_idx)
191 |
192 | # save train map
193 | path_train_fn2idx_map = os.path.join(path_root, 'train_fn2idx_map.json')
194 | with open(path_train_fn2idx_map, 'w') as f:
195 | json.dump(train_fn2idx_map, f)
196 |
197 | # save train
198 | path_train = os.path.join(path_root, 'train_data_{}'.format(COMPILE_TARGET))
199 | path_train += '.npz'
200 | print('save to', path_train)
201 | np.savez(
202 | path_train,
203 | x=x_final[train_idx],
204 | y=y_final[train_idx],
205 | mask=mask_final[train_idx],
206 | seq_len=np.array(seq_len_list)[train_idx],
207 | num_groups=np.array(num_groups_list)[train_idx]
208 | )
209 |
210 | print('---')
211 | print(' > train x:', x_final[train_idx].shape)
212 |
213 |
--------------------------------------------------------------------------------
/dataset/corpus2events.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | # config
9 | BEAT_RESOL = 480
10 | BAR_RESOL = BEAT_RESOL * 4
11 | TICK_RESOL = BEAT_RESOL // 4
12 |
13 |
14 | # utilities
15 | def plot_hist(data, path_outfile):
16 | print('[Fig] >> {}'.format(path_outfile))
17 | data_mean = np.mean(data)
18 | data_std = np.std(data)
19 |
20 | print('mean:', data_mean)
21 | print(' std:', data_std)
22 |
23 | plt.figure(dpi=100)
24 | plt.hist(data, bins=50)
25 | plt.title('mean: {:.3f}_std: {:.3f}'.format(data_mean, data_std))
26 | plt.savefig(path_outfile)
27 | plt.close()
28 |
29 | def traverse_dir(
30 | root_dir,
31 | extension=('mid', 'MID'),
32 | amount=None,
33 | str_=None,
34 | is_pure=False,
35 | verbose=False,
36 | is_sort=False,
37 | is_ext=True):
38 | if verbose:
39 | print('[*] Scanning...')
40 | file_list = []
41 | cnt = 0
42 | for root, _, files in os.walk(root_dir):
43 | for file in files:
44 | if file.endswith(extension):
45 | if (amount is not None) and (cnt == amount):
46 | break
47 | if str_ is not None:
48 | if str_ not in file:
49 | continue
50 | mix_path = os.path.join(root, file)
51 | pure_path = mix_path[len(root_dir):] if is_pure else mix_path
52 | if not is_ext:
53 | ext = pure_path.split('.')[-1]
54 | pure_path = pure_path[:-(len(ext)+1)]
55 | if verbose:
56 | print(pure_path)
57 |
58 |
59 | file_list.append(pure_path)
60 | cnt += 1
61 | if verbose:
62 | print('Total: %d files' % len(file_list))
63 | print('Done!!!')
64 | if is_sort:
65 | file_list.sort()
66 | return file_list
67 |
68 | # ---- define event ---- # #todo
69 | ''' 8 kinds:
70 | tempo: 0: IGN
71 | 1: no change
72 | int: tempo
73 | chord: 0: IGN
74 | 1: no change
75 | str: chord types
76 | bar-beat: 0: IGN
77 | int: beat position (1...16)
78 | int: bar (bar)
79 | type: 0: eos
80 | 1: metrical
81 | 2: note
82 | 3: emotion
83 | duration: 0: IGN
84 | int: length
85 | pitch: 0: IGN
86 | int: pitch
87 | velocity: 0: IGN
88 | int: velocity
89 | emotion: 0: IGN
90 | 1: Q1
91 | 2: Q2
92 | 3: Q3
93 | 4: Q4
94 | '''
95 |
96 | # emotion map
97 | emo_map = {
98 | 'Q1': 1,
99 | 'Q2': 2,
100 | 'Q3': 3,
101 | 'Q4': 4,
102 |
103 | }
104 |
105 |
106 | # event template
107 | compound_event = {
108 | 'tempo': 0,
109 | 'chord': 0,
110 | 'bar-beat': 0,
111 | 'type': 0,
112 | 'pitch': 0,
113 | 'duration': 0,
114 | 'velocity': 0,
115 | 'emotion': 0
116 | }
117 |
118 | def create_emo_event(emo_tag):
119 | emo_event = compound_event.copy()
120 | emo_event['emotion'] = emo_tag
121 | emo_event['type'] = 'Emotion'
122 | return emo_event
123 |
124 | def create_bar_event():
125 | meter_event = compound_event.copy()
126 | meter_event['bar-beat'] = 'Bar'
127 | meter_event['type'] = 'Metrical'
128 | return meter_event
129 |
130 |
131 | def create_piano_metrical_event(tempo, chord, pos):
132 | meter_event = compound_event.copy()
133 | meter_event['tempo'] = tempo
134 | meter_event['chord'] = chord
135 | meter_event['bar-beat'] = pos
136 | #todo
137 | meter_event['type'] = 'Metrical'
138 | return meter_event
139 |
140 |
141 | def create_piano_note_event(pitch, duration, velocity):
142 | note_event = compound_event.copy()
143 | note_event['pitch'] = pitch
144 | note_event['duration'] = duration
145 | note_event['velocity'] = velocity
146 | note_event['type'] = 'Note'
147 | return note_event
148 |
149 |
150 | def create_eos_event():
151 | eos_event = compound_event.copy()
152 | eos_event['type'] = 'EOS'
153 | return eos_event
154 |
155 |
156 | # ----------------------------------------------- #
157 | # core functions
158 | def corpus2event_cp(path_infile, path_outfile):
159 | '''
160 | task: 2 track
161 | 1: piano (note + tempo)
162 | ---
163 | remove duplicate position tokens
164 | '''
165 |
166 | data = pickle.load(open(path_infile, 'rb'))
167 |
168 |
169 | # global tag
170 | global_end = data['metadata']['last_bar'] * BAR_RESOL
171 | emo_tag = emo_map[data['metadata']['emotion']]
172 |
173 | # process
174 | final_sequence = []
175 | final_sequence.append(create_emo_event(emo_tag))
176 | for bar_step in range(0, global_end, BAR_RESOL):
177 | final_sequence.append(create_bar_event())
178 |
179 | # --- piano track --- #
180 | for timing in range(bar_step, bar_step + BAR_RESOL, TICK_RESOL):
181 | pos_on = False
182 | pos_events = []
183 | pos_text = 'Beat_' + str((timing-bar_step)//TICK_RESOL)
184 |
185 | # unpack
186 | t_chords = data['chords'][timing]
187 | t_tempos = data['tempos'][timing]
188 | t_notes = data['notes'][0][timing] # piano track
189 |
190 | # metrical
191 | #todo
192 | if len(t_tempos) or len(t_chords):
193 | # chord
194 | if len(t_chords):
195 |
196 | root, quality, bass = t_chords[-1].text.split('_')
197 | chord_text = root+'_'+quality
198 | else:
199 | chord_text = 'CONTI'
200 |
201 | # tempo
202 | if len(t_tempos):
203 | tempo_text = 'Tempo_' + str(t_tempos[-1].tempo)
204 | else:
205 | tempo_text = 'CONTI'
206 |
207 | # create
208 | pos_events.append(
209 | create_piano_metrical_event(
210 | tempo_text, chord_text, pos_text))
211 | pos_on = True
212 |
213 | # note
214 | if len(t_notes):
215 | if not pos_on:
216 | pos_events.append(
217 | create_piano_metrical_event(
218 | 'CONTI', 'CONTI', pos_text))
219 |
220 | for note in t_notes:
221 | note_pitch_text = 'Note_Pitch_' + str(note.pitch)
222 | note_duration_text = 'Note_Duration_' + str(note.duration)
223 | note_velocity_text = 'Note_Velocity_' + str(note.velocity)
224 |
225 | pos_events.append(
226 | create_piano_note_event(
227 | note_pitch_text,
228 | note_duration_text,
229 | note_velocity_text))
230 |
231 | # collect & beat
232 | if len(pos_events):
233 | final_sequence.extend(pos_events)
234 |
235 | # BAR ending
236 | final_sequence.append(create_bar_event())
237 |
238 | # EOS
239 | final_sequence.append(create_eos_event())
240 |
241 | # save
242 | fn = os.path.basename(path_outfile)
243 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True)
244 | pickle.dump(final_sequence, open(path_outfile, 'wb'))
245 |
246 | return len(final_sequence)
247 |
248 |
249 | if __name__ == '__main__':
250 | # paths
251 | path_root = '.'
252 | path_indir = 'input_dir'
253 | path_outdir = 'events'
254 | os.makedirs(path_outdir, exist_ok=True)
255 |
256 | # list files
257 | midifiles = traverse_dir(
258 | path_indir,
259 | extension=('pkl'),
260 | is_pure=True,
261 | is_sort=True)
262 | n_files = len(midifiles)
263 | print('num files:', n_files)
264 |
265 | # run all
266 | len_list = []
267 | paths = []
268 | for fidx in range(n_files):
269 | path_midi = midifiles[fidx]
270 | print('{}/{}'.format(fidx+1, n_files))
271 |
272 | # paths
273 | path_infile = os.path.join(path_indir, path_midi)
274 | path_outfile = os.path.join(path_outdir, path_midi)
275 |
276 | # proc
277 | num_tokens = corpus2event_cp(path_infile, path_outfile)
278 | print(' > num_token:', num_tokens)
279 | len_list.append(num_tokens)
280 | paths.append(path_midi)
281 |
282 |
283 |
284 | plot_hist(
285 | len_list,
286 | os.path.join(path_root, 'num_tokens.png')
287 | )
288 |
289 |
290 | d = {'filename': paths, 'num_tokens': len_list}
291 | df = pd.DataFrame(data=d)
292 | df.to_csv('len_token.csv', index=False)
--------------------------------------------------------------------------------
/dataset/event2words.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | This code is modified from:
4 | https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/representations/uncond/cp/events2words.py
5 |
6 | '''
7 |
8 | import os
9 | import pickle
10 | import numpy as np
11 | import collections
12 |
13 |
14 | def traverse_dir(
15 | root_dir,
16 | extension=('mid', 'MID'),
17 | amount=None,
18 | str_=None,
19 | is_pure=False,
20 | verbose=False,
21 | is_sort=False,
22 | is_ext=True):
23 | if verbose:
24 | print('[*] Scanning...')
25 | file_list = []
26 | cnt = 0
27 | for root, _, files in os.walk(root_dir):
28 | for file in files:
29 | if file.endswith(extension):
30 | if (amount is not None) and (cnt == amount):
31 | break
32 | if str_ is not None:
33 | if str_ not in file:
34 | continue
35 | mix_path = os.path.join(root, file)
36 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
37 | if not is_ext:
38 | ext = pure_path.split('.')[-1]
39 | pure_path = pure_path[:-(len(ext)+1)]
40 | if verbose:
41 | print(pure_path)
42 | file_list.append(pure_path)
43 | cnt += 1
44 | if verbose:
45 | print('Total: %d files' % len(file_list))
46 | print('Done!!!')
47 | if is_sort:
48 | file_list.sort()
49 | return file_list
50 |
51 |
52 | def build_dict(path_root, path_indir, eventfiles, path_dict):
53 | class_keys = pickle.load(
54 | open(os.path.join(path_indir, eventfiles[0]), 'rb'))[0].keys()
55 | print('class keys:', class_keys)
56 |
57 | # define dictionary
58 | event2word = {}
59 | word2event = {}
60 |
61 | corpus_kv = collections.defaultdict(list)
62 | for file in eventfiles:
63 | for event in pickle.load(open(
64 | os.path.join(path_indir, file), 'rb')):
65 | for key in class_keys:
66 | corpus_kv[key].append(event[key])
67 |
68 | for ckey in class_keys:
69 | class_unique_vals = sorted(
70 | set(corpus_kv[ckey]), key=lambda x: (not isinstance(x, int), x))
71 | event2word[ckey] = {key: i for i, key in enumerate(class_unique_vals)}
72 | word2event[ckey] = {i: key for i, key in enumerate(class_unique_vals)}
73 |
74 | # print
75 | print('[class size]')
76 | for key in class_keys:
77 | print(' > {:10s}: {}'.format(key, len(event2word[key])))
78 |
79 | # save
80 | pickle.dump((event2word, word2event), open(path_dict, 'wb'))
81 |
82 |
83 | if __name__ == '__main__':
84 | # paths
85 | path_root = '.'
86 | path_indir = os.path.join(path_root, 'events')
87 | path_outdir = os.path.join(path_root, 'wordstemp')
88 | path_dict = os.path.join(path_root, 'dictionary.pkl')
89 | os.makedirs(path_outdir, exist_ok=True)
90 |
91 | # list files
92 | eventfiles = traverse_dir(
93 | path_indir,
94 | is_pure=True,
95 | is_sort=True,
96 | extension=('pkl'))
97 | n_files = len(eventfiles)
98 | print('num files:', n_files)
99 |
100 | class_keys = pickle.load(
101 | open(os.path.join(path_indir, eventfiles[0]), 'rb'))[0].keys()
102 | print('class keys:', class_keys)
103 |
104 |
105 | # --- build dictionary --- #
106 | # all files
107 | if not os.path.exists(path_dict):
108 | build_dict(path_root, path_indir, eventfiles, path_dict)
109 |
110 |
111 | # --- compile each --- #
112 | # reload
113 | event2word, word2event = pickle.load(open(path_dict, 'rb'))
114 | for fidx in range(len(eventfiles)):
115 | file = eventfiles[fidx]
116 | events_list = pickle.load(open(
117 | os.path.join(path_indir, file), 'rb'))
118 | fn = os.path.basename(file)
119 | path_outfile = os.path.join(path_outdir, fn)
120 |
121 | print('({}/{})'.format(fidx, len(eventfiles)))
122 | print(' > from:', file)
123 | print(' > to:', path_outfile)
124 |
125 | words = []
126 | for eidx, e in enumerate(events_list):
127 | words_tmp = [
128 | event2word[k][e[k]] for k in class_keys
129 | ]
130 | words.append(words_tmp)
131 |
132 | # save
133 | path_outfile = os.path.join(path_outdir, file + '.npy')
134 | fn = os.path.basename(path_outfile)
135 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True)
136 | np.save(path_outfile, words)
--------------------------------------------------------------------------------
/dataset/get_cls_idx.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "data": {
10 | "text/plain": [
11 | "True"
12 | ]
13 | },
14 | "execution_count": 2,
15 | "metadata": {},
16 | "output_type": "execute_result"
17 | }
18 | ],
19 | "source": [
20 | "import os\n",
21 | "import numpy as np\n",
22 | "path_data_root = 'where you put the data'\n",
23 | "train_data_name = 'name of your processed train data'\n",
24 | "path_train_data = os.path.join(path_data_root, train_data_name + '.npz')\n",
25 | "os.path.exists(path_train_data)"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "train_data = np.load(path_train_data)\n",
35 | "train_x = train_data['x']"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 3,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "class_1_idx = []\n",
45 | "class_2_idx = []\n",
46 | "class_3_idx = []\n",
47 | "class_4_idx = []\n",
48 | "idxs = [class_1_idx, class_2_idx, class_3_idx, class_4_idx]\n",
49 | "for i, sample in enumerate(train_x):\n",
50 | " idxs[sample[0][-1] - 1].append(i)"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 4,
56 | "metadata": {},
57 | "outputs": [
58 | {
59 | "data": {
60 | "text/plain": [
61 | "[2, 7, 13, 19, 32, 41, 50, 52, 64, 66]"
62 | ]
63 | },
64 | "execution_count": 4,
65 | "metadata": {},
66 | "output_type": "execute_result"
67 | }
68 | ],
69 | "source": [
70 | "class_1_idx[:10]"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 5,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "cls_1 = train_x[class_1_idx]"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 6,
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "data": {
89 | "text/plain": [
90 | "1"
91 | ]
92 | },
93 | "execution_count": 6,
94 | "metadata": {},
95 | "output_type": "execute_result"
96 | }
97 | ],
98 | "source": [
99 | "cls_1[0][0][-1]"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 7,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "path_train = os.path.join(path_data_root, train_data_name + '_data_idx.npz')\n",
109 | "np.savez(\n",
110 | " path_train, \n",
111 | " cls_1_idx=class_1_idx,\n",
112 | " cls_2_idx=class_2_idx,\n",
113 | " cls_3_idx=class_3_idx,\n",
114 | " cls_4_idx=class_4_idx\n",
115 | " )\n",
116 | " "
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {},
123 | "outputs": [],
124 | "source": []
125 | }
126 | ],
127 | "metadata": {
128 | "kernelspec": {
129 | "display_name": "Python 3 (ipykernel)",
130 | "language": "python",
131 | "name": "python3"
132 | },
133 | "language_info": {
134 | "codemirror_mode": {
135 | "name": "ipython",
136 | "version": 3
137 | },
138 | "file_extension": ".py",
139 | "mimetype": "text/x-python",
140 | "name": "python",
141 | "nbconvert_exporter": "python",
142 | "pygments_lexer": "ipython3",
143 | "version": "3.7.3"
144 | }
145 | },
146 | "nbformat": 4,
147 | "nbformat_minor": 5
148 | }
149 |
--------------------------------------------------------------------------------
/dataset/midi2corpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import pickle
4 | import numpy as np
5 | import miditoolkit
6 | import collections
7 |
8 | import pandas as pd
9 | import matplotlib.pyplot as plt
10 | import seaborn as sns
11 |
12 | # ================================================== #
13 | # Configuration #
14 | # ================================================== #
15 | BEAT_RESOL = 480
16 | BAR_RESOL = BEAT_RESOL * 4
17 | TICK_RESOL = BEAT_RESOL // 4
18 | INSTR_NAME_MAP = {'piano': 0}
19 | MIN_BPM = 40
20 | MIN_VELOCITY = 40
21 | NOTE_SORTING = 1 # 0: ascending / 1: descending
22 |
23 | DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 64+1, dtype=np.int)
24 | DEFAULT_BPM_BINS = np.linspace(32, 224, 64+1, dtype=np.int)
25 | DEFAULT_SHIFT_BINS = np.linspace(-60, 60, 60+1, dtype=np.int)
26 | DEFAULT_DURATION_BINS = np.arange(
27 | BEAT_RESOL/8, BEAT_RESOL*8+1, BEAT_RESOL/8)
28 |
29 | # ================================================== #
30 |
31 |
32 | def traverse_dir(
33 | root_dir,
34 | extension=('mid', 'MID', 'midi'),
35 | amount=None,
36 | str_=None,
37 | is_pure=False,
38 | verbose=False,
39 | is_sort=False,
40 | is_ext=True):
41 | if verbose:
42 | print('[*] Scanning...')
43 | file_list = []
44 | cnt = 0
45 | for root, _, files in os.walk(root_dir):
46 | for file in files:
47 | if file.endswith(extension):
48 | if (amount is not None) and (cnt == amount):
49 | break
50 | if str_ is not None:
51 | if str_ not in file:
52 | continue
53 | mix_path = os.path.join(root, file)
54 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
55 | if not is_ext:
56 | ext = pure_path.split('.')[-1]
57 | pure_path = pure_path[:-(len(ext)+1)]
58 | if verbose:
59 | print(pure_path)
60 | file_list.append(pure_path)
61 | cnt += 1
62 | if verbose:
63 | print('Total: %d files' % len(file_list))
64 | print('Done!!!')
65 | if is_sort:
66 | file_list.sort()
67 | return file_list
68 |
69 |
70 | def proc_one(path_midi, path_outfile):
71 | # --- load --- #
72 | midi_obj = miditoolkit.midi.parser.MidiFile(path_midi)
73 |
74 | # collect emotion tag
75 | emo_tag = path_midi.split('/')[-1][:2]
76 |
77 | # load notes
78 | instr_notes = collections.defaultdict(list)
79 | for instr in midi_obj.instruments:
80 | # skip
81 | if instr.name not in INSTR_NAME_MAP.keys():
82 | continue
83 |
84 | # process
85 | instr_idx = INSTR_NAME_MAP[instr.name]
86 | for note in instr.notes:
87 | note.instr_idx=instr_idx
88 | instr_notes[instr_idx].append(note)
89 | if NOTE_SORTING == 0:
90 | instr_notes[instr_idx].sort(
91 | key=lambda x: (x.start, x.pitch))
92 | elif NOTE_SORTING == 1:
93 | instr_notes[instr_idx].sort(
94 | key=lambda x: (x.start, -x.pitch))
95 | else:
96 | raise ValueError(' [x] Unknown type of sorting.')
97 |
98 | # load chords
99 | chords = []
100 | for marker in midi_obj.markers:
101 | if marker.text.split('_')[0] != 'global' and \
102 | 'Boundary' not in marker.text.split('_')[0]:
103 | chords.append(marker)
104 | chords.sort(key=lambda x: x.time)
105 |
106 | # load tempos
107 | tempos = midi_obj.tempo_changes
108 | tempos.sort(key=lambda x: x.time)
109 |
110 | # load labels
111 | labels = []
112 | for marker in midi_obj.markers:
113 | if 'Boundary' in marker.text.split('_')[0]:
114 | labels.append(marker)
115 | labels.sort(key=lambda x: x.time)
116 |
117 | # load global bpm
118 | global_bpm = 120
119 | for marker in midi_obj.markers:
120 | if marker.text.split('_')[0] == 'global' and \
121 | marker.text.split('_')[1] == 'bpm':
122 | global_bpm = int(marker.text.split('_')[2])
123 |
124 | # --- process items to grid --- #
125 | # compute empty bar offset at head
126 | first_note_time = min([instr_notes[k][0].start for k in instr_notes.keys()])
127 | last_note_time = max([instr_notes[k][-1].start for k in instr_notes.keys()])
128 |
129 | quant_time_first = int(np.round(first_note_time / TICK_RESOL) * TICK_RESOL)
130 | offset = quant_time_first // BAR_RESOL # empty bar
131 | last_bar = int(np.ceil(last_note_time / BAR_RESOL)) - offset
132 | print(' > offset:', offset)
133 | print(' > last_bar:', last_bar)
134 |
135 | # process notes
136 | instr_grid = dict()
137 | for key in instr_notes.keys():
138 | notes = instr_notes[key]
139 | note_grid = collections.defaultdict(list)
140 | for note in notes:
141 | note.start = note.start - offset * BAR_RESOL
142 | note.end = note.end - offset * BAR_RESOL
143 |
144 | # quantize start
145 | quant_time = int(np.round(note.start / TICK_RESOL) * TICK_RESOL)
146 |
147 | # velocity
148 | note.velocity = DEFAULT_VELOCITY_BINS[
149 | np.argmin(abs(DEFAULT_VELOCITY_BINS-note.velocity))]
150 | note.velocity = max(MIN_VELOCITY, note.velocity)
151 |
152 | # shift of start
153 | note.shift = note.start - quant_time
154 | note.shift = DEFAULT_SHIFT_BINS[np.argmin(abs(DEFAULT_SHIFT_BINS-note.shift))]
155 |
156 | # duration
157 | note_duration = note.end - note.start
158 | if note_duration > BAR_RESOL:
159 | note_duration = BAR_RESOL
160 | ntick_duration = int(np.round(note_duration / TICK_RESOL) * TICK_RESOL)
161 | note.duration = ntick_duration
162 |
163 | # append
164 | note_grid[quant_time].append(note)
165 |
166 | # set to track
167 | instr_grid[key] = note_grid.copy()
168 |
169 | # process chords
170 | chord_grid = collections.defaultdict(list)
171 | for chord in chords:
172 | # quantize
173 | chord.time = chord.time - offset * BAR_RESOL
174 | chord.time = 0 if chord.time < 0 else chord.time
175 | quant_time = int(np.round(chord.time / TICK_RESOL) * TICK_RESOL)
176 |
177 | # append
178 | chord_grid[quant_time].append(chord)
179 |
180 | # process tempo
181 | tempo_grid = collections.defaultdict(list)
182 | for tempo in tempos:
183 | # quantize
184 | tempo.time = tempo.time - offset * BAR_RESOL
185 | tempo.time = 0 if tempo.time < 0 else tempo.time
186 | quant_time = int(np.round(tempo.time / TICK_RESOL) * TICK_RESOL)
187 | tempo.tempo = DEFAULT_BPM_BINS[np.argmin(abs(DEFAULT_BPM_BINS-tempo.tempo))]
188 |
189 | # append
190 | tempo_grid[quant_time].append(tempo)
191 |
192 | # process boundary
193 | label_grid = collections.defaultdict(list)
194 | for label in labels:
195 | # quantize
196 | label.time = label.time - offset * BAR_RESOL
197 | label.time = 0 if label.time < 0 else label.time
198 | quant_time = int(np.round(label.time / TICK_RESOL) * TICK_RESOL)
199 |
200 | # append
201 | label_grid[quant_time] = [label]
202 |
203 | # process global bpm
204 | global_bpm = DEFAULT_BPM_BINS[np.argmin(abs(DEFAULT_BPM_BINS-global_bpm))]
205 |
206 | # collect
207 | song_data = {
208 | 'notes': instr_grid,
209 | 'chords': chord_grid,
210 | 'tempos': tempo_grid,
211 | 'labels': label_grid,
212 | 'metadata': {
213 | 'global_bpm': global_bpm,
214 | 'last_bar': last_bar,
215 | 'emotion': emo_tag
216 | }
217 | }
218 |
219 | # save
220 | fn = os.path.basename(path_outfile)
221 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True)
222 | pickle.dump(song_data, open(path_outfile, 'wb'))
223 |
224 | return song_data
225 |
226 | if __name__ == '__main__':
227 | # paths
228 | path_indir = './midi_analyzed/fixed'
229 | path_outdir = './corpus/fixed'
230 | os.makedirs(path_outdir, exist_ok=True)
231 |
232 | # list files
233 | midifiles = traverse_dir(
234 | path_indir,
235 | is_pure=True,
236 | is_sort=True)
237 | n_files = len(midifiles)
238 | print('num files:', n_files)
239 |
240 | # run all
241 | for fidx in range(n_files):
242 | path_midi = midifiles[fidx]
243 | print('{}/{}'.format(fidx, n_files))
244 |
245 | # paths
246 | path_infile = os.path.join(path_indir, path_midi)
247 | path_outfile = os.path.join(path_outdir, path_midi+'.pkl')
248 |
249 | # proc
250 | _ = proc_one(path_infile, path_outfile)
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
2 |
3 | title: EMOPIA
4 | description: A Multi-Modal Pop Piano Dataset For Emotion Recognition and Emotion-based Music Generation
5 |
6 |
--------------------------------------------------------------------------------
/docs/_layouts/default.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | {% if site.google_analytics %}
6 |
7 |
13 | {% endif %}
14 |
15 |
16 | {% seo %}
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | Skip to the content.
26 |
27 |
38 |
39 |
40 | {{ content }}
41 |
42 |
49 |
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_3.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_3.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_3.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_3.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q1/Q1_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q1/Q1_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q1/Q1_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q1/Q1_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q1/Q1_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q1/Q1_3.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q2/Q2_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q2/Q2_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q2/Q2_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q2/Q2_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q2/Q2_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q2/Q2_3.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q3/Q3_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q3/Q3_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q3/Q3_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q3/Q3_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q3/Q3_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q3/Q3_3.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q4/Q4_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q4/Q4_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q4/Q4_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q4/Q4_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/2_Transformer/Q4/Q4_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q4/Q4_3.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_3.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_3.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_3.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_1.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_2.mp3
--------------------------------------------------------------------------------
/docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_3.mp3
--------------------------------------------------------------------------------
/docs/assets/css/style.css:
--------------------------------------------------------------------------------
1 | /*! normalize.css v3.0.2 | MIT License | git.io/normalize */@import url("https://fonts.googleapis.com/css?family=Open+Sans:400,700&display=swap");@import url("https://fonts.googleapis.com/css2?family=Montserrat:ital,wght@0,400;0,700;1,400;1,700&family=Lato:ital,wght@0,400;0,700;1,400;1,700&display=swap");html{font-family:sans-serif;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}article,aside,details,figcaption,figure,footer,header,hgroup,main,menu,nav,section,summary{display:block}audio,canvas,progress,video{display:inline-block;vertical-align:baseline}audio:not([controls]){display:none;height:0}[hidden],template{display:none}a{background-color:transparent}a:active,a:hover{outline:0}abbr[title]{border-bottom:1px dotted}b,strong{font-weight:bold}dfn{font-style:italic}h1{font-size:2em;margin:0.67em 0}mark{background:#ff0;color:#000}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sup{top:-0.5em}sub{bottom:-0.25em}img{border:0}svg:not(:root){overflow:hidden}figure{margin:1em 40px}hr{box-sizing:content-box;height:0}pre{overflow:auto}code,kbd,pre,samp{font-family:monospace, monospace;font-size:1em}button,input,optgroup,select,textarea{color:inherit;font:inherit;margin:0}button{overflow:visible}button,select{text-transform:none}button,html input[type="button"],input[type="reset"],input[type="submit"]{-webkit-appearance:button;cursor:pointer}button[disabled],html input[disabled]{cursor:default}button::-moz-focus-inner,input::-moz-focus-inner{border:0;padding:0}input{line-height:normal}input[type="checkbox"],input[type="radio"]{box-sizing:border-box;padding:0}input[type="number"]::-webkit-inner-spin-button,input[type="number"]::-webkit-outer-spin-button{height:auto}input[type="search"]{-webkit-appearance:textfield;box-sizing:content-box}input[type="search"]::-webkit-search-cancel-button,input[type="search"]::-webkit-search-decoration{-webkit-appearance:none}fieldset{border:1px solid #c0c0c0;margin:0 2px;padding:0.35em 0.625em 0.75em}legend{border:0;padding:0}textarea{overflow:auto}optgroup{font-weight:bold}table{border-collapse:collapse;border-spacing:0}td,th{padding:0}.highlight table td{padding:5px}.highlight table pre{margin:0}.highlight .cm{color:#999988;font-style:italic}.highlight .cp{color:#999999;font-weight:bold}.highlight .c1{color:#999988;font-style:italic}.highlight .cs{color:#999999;font-weight:bold;font-style:italic}.highlight .c,.highlight .cd{color:#999988;font-style:italic}.highlight .err{color:#a61717;background-color:#e3d2d2}.highlight .gd{color:#000000;background-color:#ffdddd}.highlight .ge{color:#000000;font-style:italic}.highlight .gr{color:#aa0000}.highlight .gh{color:#999999}.highlight .gi{color:#000000;background-color:#ddffdd}.highlight .go{color:#888888}.highlight .gp{color:#555555}.highlight .gs{font-weight:bold}.highlight .gu{color:#aaaaaa}.highlight .gt{color:#aa0000}.highlight .kc{color:#000000;font-weight:bold}.highlight .kd{color:#000000;font-weight:bold}.highlight .kn{color:#000000;font-weight:bold}.highlight .kp{color:#000000;font-weight:bold}.highlight .kr{color:#000000;font-weight:bold}.highlight .kt{color:#445588;font-weight:bold}.highlight .k,.highlight .kv{color:#000000;font-weight:bold}.highlight .mf{color:#009999}.highlight .mh{color:#009999}.highlight .il{color:#009999}.highlight .mi{color:#009999}.highlight .mo{color:#009999}.highlight .m,.highlight .mb,.highlight .mx{color:#009999}.highlight .sb{color:#d14}.highlight .sc{color:#d14}.highlight .sd{color:#d14}.highlight 
.s2{color:#d14}.highlight .se{color:#d14}.highlight .sh{color:#d14}.highlight .si{color:#d14}.highlight .sx{color:#d14}.highlight .sr{color:#009926}.highlight .s1{color:#d14}.highlight .ss{color:#4b001d}.highlight .s{color:#d14}.highlight .na{color:#008080}.highlight .bp{color:#999999}.highlight .nb{color:#0086B3}.highlight .nc{color:#445588;font-weight:bold}.highlight .no{color:#008080}.highlight .nd{color:#3c5d5d;font-weight:bold}.highlight .ni{color:#800080}.highlight .ne{color:#990000;font-weight:bold}.highlight .nf{color:#990000;font-weight:bold}.highlight .nl{color:#990000;font-weight:bold}.highlight .nn{color:#555555}.highlight .nt{color:#000080}.highlight .vc{color:#008080}.highlight .vg{color:#008080}.highlight .vi{color:#008080}.highlight .nv{color:#008080}.highlight .ow{color:#000000;font-weight:bold}.highlight .o{color:#000000;font-weight:bold}.highlight .w{color:#bbbbbb}.highlight{background-color:#f8f8f8}*{box-sizing:border-box}body{padding:0;margin:0;font-family:"Lato", "Helvetica Neue", Helvetica, Arial, sans-serif;font-size:16px;line-height:1.5;color:#424b4f}#skip-to-content{height:1px;width:1px;position:absolute;overflow:hidden;top:-10px}#skip-to-content:focus{position:fixed;top:10px;left:10px;height:auto;width:auto;background:#e19447;outline:thick solid #e19447}a{color:#1e6bb8;text-decoration:none}a:hover{text-decoration:underline}.btn{display:inline-block;margin-bottom:1rem;color:rgba(255,255,255,0.7);background-color:rgba(255,255,255,0.08);border-color:rgba(255,255,255,0.2);border-style:solid;border-width:1px;border-radius:0.3rem;transition:color 0.2s, background-color 0.2s, border-color 0.2s}.btn:hover{color:rgba(255,255,255,0.8);text-decoration:none;background-color:rgba(255,255,255,0.2);border-color:rgba(255,255,255,0.3)}.btn+.btn{margin-left:1rem}@media screen and (min-width: 64em){.btn{padding:0.75rem 1rem}}@media screen and (min-width: 42em) and (max-width: 64em){.btn{padding:0.6rem 0.9rem;font-size:0.9rem}}@media screen and (max-width: 42em){.btn{display:block;width:100%;padding:0.75rem;font-size:0.9rem}.btn+.btn{margin-top:1rem;margin-left:0}}.page-header{font-family:"Montserrat", "Helvetica Neue", Helvetica, Arial, sans-serif;color:#fff;text-align:center;background-color:#006;background-image:url("https://raw.githubusercontent.com/annahung31/EMOPIA/main/docs/img/background.jpg"),linear-gradient(160deg, #333, #006);background-position:center;-webkit-background-size:cover;-moz-background-size:cover;-o-background-size:cover;background-size:cover}@media screen and (min-width: 64em){.page-header{padding:5rem 6rem}}@media screen and (min-width: 42em) and (max-width: 64em){.page-header{padding:3rem 4rem}}@media screen and (max-width: 42em){.page-header{padding:2rem 1rem}}.project-name{margin-top:0;margin-bottom:0.1rem;text-shadow:2px 2px 2px rgba(15,10,86,0.8)}@media screen and (min-width: 64em){.project-name{font-size:3.25rem}}@media screen and (min-width: 42em) and (max-width: 64em){.project-name{font-size:2.25rem}}@media screen and (max-width: 42em){.project-name{font-size:1.75rem}}.project-tagline{margin-bottom:2rem;font-weight:normal;opacity:0.9;text-shadow:2px 2px 2px rgba(15,10,86,0.8)}@media screen and (min-width: 64em){.project-tagline{font-size:1.5rem}}@media screen and (min-width: 42em) and (max-width: 64em){.project-tagline{font-size:1.25rem}}@media screen and (max-width: 42em){.project-tagline{font-size:1.15rem}}.main-content{word-wrap:break-word}.main-content :first-child{margin-top:0}@media screen and (min-width: 
64em){.main-content{max-width:64rem;padding:2rem 6rem;margin:0 auto;font-size:1.1rem}}@media screen and (min-width: 42em) and (max-width: 64em){.main-content{padding:2rem 4rem;font-size:1.1rem}}@media screen and (max-width: 42em){.main-content{padding:2rem 1rem;font-size:1rem}}.main-content kbd{background-color:#fafbfc;border:1px solid #c6cbd1;border-bottom-color:#959da5;border-radius:3px;box-shadow:inset 0 -1px 0 #959da5;color:#444d56;display:inline-block;font-size:11px;line-height:10px;padding:3px 5px;vertical-align:middle}.main-content img{max-width:100%}.main-content h1,.main-content h2,.main-content h3,.main-content h4,.main-content h5,.main-content h6{margin-top:2rem;margin-bottom:1rem;font-weight:normal;color:#EC7063}.main-content p{margin-bottom:1em}.main-content code{padding:2px 4px;font-family:Consolas, "Liberation Mono", Menlo, Courier, monospace;font-size:0.9rem;color:#567482;background-color:#f3f6fa;border-radius:0.3rem}.main-content pre{padding:0.8rem;margin-top:0;margin-bottom:1rem;font:1rem Consolas, "Liberation Mono", Menlo, Courier, monospace;color:#567482;word-wrap:normal;background-color:#f3f6fa;border:solid 1px #dce6f0;border-radius:0.3rem}.main-content pre>code{padding:0;margin:0;font-size:0.9rem;color:#567482;word-break:normal;white-space:pre;background:transparent;border:0}.main-content .highlight{margin-bottom:1rem}.main-content .highlight pre{margin-bottom:0;word-break:normal}.main-content .highlight pre,.main-content pre{padding:0.8rem;overflow:auto;font-size:0.9rem;line-height:1.45;border-radius:0.3rem;-webkit-overflow-scrolling:touch}.main-content pre code,.main-content pre tt{display:inline;max-width:initial;padding:0;margin:0;overflow:initial;line-height:inherit;word-wrap:normal;background-color:transparent;border:0}.main-content pre code:before,.main-content pre code:after,.main-content pre tt:before,.main-content pre tt:after{content:normal}.main-content ul,.main-content ol{margin-top:0}.main-content blockquote{padding:0 1rem;margin-left:0;color:#819198;border-left:0.3rem solid #dce6f0}.main-content blockquote>:first-child{margin-top:0}.main-content blockquote>:last-child{margin-bottom:0}.main-content table{display:block;width:100%;overflow:auto;word-break:normal;word-break:keep-all;-webkit-overflow-scrolling:touch}.main-content table th{font-weight:normal;text-align:left}.main-content table th,.main-content table td{padding:0.5rem 1rem;text-align:left;border:1px solid rgba(255,255,255,0)}.main-content dl{padding:0}.main-content dl dt{padding:0;margin-top:1rem;font-size:1rem;font-weight:bold}.main-content dl dd{padding:0;margin-bottom:1rem}.main-content hr{height:2px;padding:0;margin:1rem 0;background-color:#eff0f1;border:0}.site-footer{padding-top:2rem;margin-top:2rem;border-top:solid 1px #eff0f1}@media screen and (min-width: 64em){.site-footer{font-size:1rem}}@media screen and (min-width: 42em) and (max-width: 64em){.site-footer{font-size:1rem}}@media screen and (max-width: 42em){.site-footer{font-size:0.9rem}}.site-footer-owner{display:block;font-weight:bold}.site-footer-credits{color:#819198}
2 | {"mode":"full","isActive":false}
3 | table.audio-table td {
4 | vertical-align: top; }
5 |
6 | table.audio-table audio {
7 | height: 2rem;
8 | max-width: 9rem; }
9 |
10 | @media screen and (min-width: 200rem) {
11 | table.audio-table audio {
12 | display: block;
13 | width: 100%;
14 | max-width: 100%; } }
15 |
16 | @media screen and (min-width: 106rem) {
17 | table.audio-table {
18 | table-layout: fixed;
19 | overflow-y: visible;
20 | overflow-x: visible; }
21 | table.audio-table .metric-col {
22 | width: 7em; } }
23 | @media screen and (max-width: 106rem) {
24 | table.audio-table {
25 | display: block;
26 | max-height: 90vh;
27 | overflow-y: visible;
28 | overflow-x: visible; }
29 | table.audio-table th {
30 | white-space: nowrap; } }
31 |
32 | figcaption {
33 | color: gray;
34 | padding: 2px;
35 | text-align: center;
36 | }
37 |
38 | table.VA-example{
39 | /* border: none; */
40 | position: relative;
41 | left: 20px;
42 | top: 10px;
43 | overflow:visible;
44 |
45 |
46 | }
47 |
48 | table.VA-example td, table.VA-example th {
49 | border: 1px dashed gray;
50 | text-align: center;
51 | font-weight:bold;
52 | color: #566573;
53 | }
54 |
55 |
56 | table.num-table{
57 | /* border: none; */
58 | /* margin: auto; */
59 | width: 100%;
60 | margin-left:auto;
61 | margin-right:auto;
62 |
63 | }
64 |
65 | table.num-table td, table.num-table th {
66 | border: 1px dashed gray;
67 | text-align: center;
68 | color: #566573;
69 | }
70 |
71 |
72 |
73 |
74 | .resp-iframe {
75 | position: relative;
76 | top: 0;
77 | left: 0;
78 |     width: 255px;
79 |     height: 142px;
80 | border-radius: 10px;
81 | }
82 |
83 | .VA-image {
84 | background-image: url("https://raw.githubusercontent.com/annahung31/EMOPIA/main/docs/img/VA_Q.png");
85 | background-color: #FFFFFF;
86 | height: 600px;
87 | background-position: center;
88 | background-repeat: no-repeat;
89 | background-size: 950px 550px;
90 | position: relative;
91 | }
92 |
--------------------------------------------------------------------------------
/docs/img/VA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/VA.png
--------------------------------------------------------------------------------
/docs/img/VA_Q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/VA_Q.png
--------------------------------------------------------------------------------
/docs/img/background.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/background.jpg
--------------------------------------------------------------------------------
/docs/img/classification.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/classification.png
--------------------------------------------------------------------------------
/docs/img/cp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/cp.png
--------------------------------------------------------------------------------
/docs/img/emopia.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/emopia.png
--------------------------------------------------------------------------------
/docs/img/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/example.png
--------------------------------------------------------------------------------
/docs/img/feature.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/feature.png
--------------------------------------------------------------------------------
/docs/img/key.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/key.png
--------------------------------------------------------------------------------
/docs/img/number.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/number.png
--------------------------------------------------------------------------------
/docs/img/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/pipeline.png
--------------------------------------------------------------------------------
/docs/img/representation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/representation.png
--------------------------------------------------------------------------------
/docs/img/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/results.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # EMOPIA
2 |
3 | The EMOPIA (pronounced ‘yee-mò-pi-uh’) dataset is a shared multi-modal (audio and MIDI) database focusing on perceived emotion in **pop piano music**, built to facilitate research on various tasks related to music emotion. The dataset contains **1,087** music clips from 387 songs, with **clip-level** emotion labels annotated by four dedicated annotators. Since the clips are not restricted to one per song, they can also be used for song-level analysis.
4 |
5 | For details of the methodology used to build the dataset, please refer to our paper.
6 |
7 | * [Paper on Arxiv](https://arxiv.org/abs/2108.01374)
8 | * [Dataset on Zenodo](https://zenodo.org/record/5090631#.YPPo-JMzZz8)
9 | * [Code for classification](https://github.com/SeungHeonDoh/EMOPIA_cls)
10 | * [Code for generation](https://github.com/annahung31/EMOPIA)
11 |
12 |
13 | ### Example of the dataset
14 |
15 | The project page embeds one example video clip per emotion quadrant, arranged on the valence-arousal plane:
16 |
17 | |              | Low Valence | High Valence |
18 | | ------------ | ----------- | ------------ |
19 | | High Arousal | Q2 (video)  | Q1 (video)   |
20 | | Low Arousal  | Q3 (video)  | Q4 (video)   |
21 |
68 | ### Number of clips
69 | The following table shows the number of clips and their average length for each quadrant of Russell’s valence-arousal emotion space in EMOPIA.
70 |
71 |
72 | | Quadrant | # clips | Avg. length (in sec / #tokens) |
73 | | -------- | ------- | ------------------------------ |
74 | | Q1       | 250     | 31.9 / 1,065                   |
75 | | Q2       | 265     | 35.6 / 1,368                   |
76 | | Q3       | 253     | 40.6 / 771                     |
77 | | Q4       | 310     | 38.2 / 729                     |
78 |
103 | ### Pipeline of data collection
104 |
105 |
106 | *(Fig.1: Pipeline of data collection.)*
107 |
110 |
111 | ### Dataset Analysis
112 |
113 |
114 | *(Fig.2: Violin plots of the distribution in (a) note density, (b) length, and (c) velocity for clips from different classes.)*
115 |
116 | *(Fig.3: Histogram of the keys (left / right: major / minor keys) for clips from different emotion classes.)*
117 |
124 |
125 | ## Emotion Classification
126 |
127 | * We performed the emotion classification task in both the audio domain and the symbolic domain.
128 |
129 | * Our baseline approach is logistic regression using handcrafted features. For the symbolic domain, note length, velocity, beat note density, and key were used; for the audio domain, the average of 20-dimensional mel-frequency cepstral coefficient (MFCC) vectors was used. A minimal sketch of such a baseline is given below.
130 |
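The following sketch assumes scikit-learn and pre-computed per-clip feature matrices stored in hypothetical `.npy` files; the file names and the feature-extraction step are illustrative, not the authors' code.

```python
# Hypothetical baseline sketch: logistic regression over handcrafted clip features.
# File names and the feature pipeline are assumptions, not the EMOPIA_cls code.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# X_*: one row per clip, e.g. [note length, velocity, beat note density, key]
# for the symbolic domain, or a 20-dim MFCC mean for the audio domain.
# y_*: quadrant labels in {1, 2, 3, 4}.
X_train, y_train = np.load("features_train.npy"), np.load("labels_train.npy")
X_test, y_test = np.load("features_test.npy"), np.load("labels_test.npy")

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print("4-quadrant accuracy:", accuracy_score(y_test, clf.predict(X_test)))
```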
131 |
132 | *(Fig.4: Symbolic-domain and audio-domain representations.)*
133 |
135 |
136 | * The input representation of the deep-learning approach uses the mel-spectrogram for the audio domain and the MIDI-like and REMI representations for the symbolic domain (see the mel-spectrogram sketch below).
137 |
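For the audio branch, the mel-spectrogram input can be computed roughly as follows; this is a sketch using librosa, and the clip path and parameter values are placeholders rather than the exact settings used in the paper.

```python
# Illustrative mel-spectrogram extraction for one audio clip; the path and
# parameters are placeholders, not the exact configuration from the paper.
import librosa
import numpy as np

y, sr = librosa.load("clip_example.mp3", sr=22050)            # hypothetical clip
mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)  # power mel-spectrogram
log_mel = librosa.power_to_db(mel, ref=np.max)                # shape: (n_mels, n_frames)
print(log_mel.shape)
```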
138 |
139 | *(Fig.5: Results of emotion classification.)*
140 |
142 |
143 | * We adopt [A Structured Self-attentive Sentence Embedding](https://arxiv.org/abs/1703.03130) as the symbolic-domain classification network and [Short-chunk CNN + Residual](https://arxiv.org/abs/2006.00751) as the audio-domain classification network; a simplified sketch of the self-attentive pooling step follows below.
144 |
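As a rough illustration of the pooling step in that symbolic-domain network, a structured self-attentive layer can be sketched as follows. This is a simplified PyTorch sketch with arbitrary dimensions and head count; it is not the code from the classification repository.

```python
# Simplified structured self-attentive pooling (after Lin et al., 2017);
# dimensions and head count are arbitrary, not the EMOPIA_cls settings.
import torch
import torch.nn as nn

class SelfAttentivePooling(nn.Module):
    def __init__(self, hidden_dim=256, attn_dim=64, n_heads=4):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, attn_dim, bias=False)
        self.w2 = nn.Linear(attn_dim, n_heads, bias=False)

    def forward(self, h):  # h: (batch, seq_len, hidden_dim) sequence embeddings
        attn = torch.softmax(self.w2(torch.tanh(self.w1(h))), dim=1)  # weights over time
        return torch.bmm(attn.transpose(1, 2), h)  # (batch, n_heads, hidden_dim)

pooled = SelfAttentivePooling()(torch.randn(2, 100, 256))
print(pooled.shape)  # torch.Size([2, 4, 256])
```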
145 |
146 | *(Fig.6: Inference example on Sakamoto's "Merry Christmas Mr. Lawrence".)*
147 |
150 |
151 | * An inference example is [Sakamoto: Merry Christmas Mr. Lawrence](https://www.youtube.com/watch?v=zOUlPrCRlVM&ab_channel=RyuichiSakamoto-Topic). The emotional change between the first and second halves of the song is striking: the first part is clearly low arousal, while the second half turns into high arousal. Interestingly, the audio and MIDI classifiers return different inference results.
152 |
153 | For the classification codes, please refer to [SeungHeon's repository](https://github.com/SeungHeonDoh/EMOPIA_cls).
154 |
155 | The pre-trained model weights are also in the repository.
156 |
157 | ## Conditional Generation
158 |
159 | * We adopt the [Compound Word Transformer](https://github.com/YatingMusic/compound-word-transformer) for emotion-conditioned symbolic music generation using EMOPIA. The CP+emotion representation is used as the data representation.
160 |
161 | * In the data representation, we additionally introduce “emotion” tokens and make them a new token family. The prepending approach is motivated by [CTRL](https://arxiv.org/abs/1909.05858); a toy illustration is sketched below.
162 |
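The conditioning idea itself is easy to illustrate: one emotion token is prepended to each token sequence so that generation can be steered by it. The token names below are invented for illustration and do not match the repository's actual CP+emotion vocabulary.

```python
# Toy illustration of CTRL-style conditioning: prepend an emotion token to a
# token sequence. Token names are invented, not the repo's CP+emotion vocabulary.
EMOTION_TOKENS = {1: "Emotion_Q1", 2: "Emotion_Q2", 3: "Emotion_Q3", 4: "Emotion_Q4"}

def prepend_emotion(token_seq, quadrant):
    """Return the sequence conditioned on the given quadrant (1-4)."""
    return [EMOTION_TOKENS[quadrant]] + list(token_seq)

print(prepend_emotion(["Bar", "Beat_0", "Note_C4"], 1))
# ['Emotion_Q1', 'Bar', 'Beat_0', 'Note_C4']
```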
163 |
164 | *(Fig.7: Compound word representation with the added emotion token.)*
165 |
168 |
169 |
170 | * As the size of EMOPIA might not be big enough, we additionally use the [AILabs1k7 dataset](https://github.com/YatingMusic/compound-word-transformer) compiled by Hsiao et al. to pre-train the Transformer.
171 |
172 | * You can download the model weight of the pre-trained Transformer from [here](https://drive.google.com/file/d/19Seq18b2JNzOamEQMG1uarKjj27HJkHu/view?usp=sharing).
173 |
174 |
175 | * The following are some generated examples for each Quadrant:
176 |
177 | For each quadrant, samples from three systems are provided: the baseline, the Transformer w/o pre-training, and the Transformer w/ pre-training (the corresponding audio players are embedded on the project page).
178 |
179 | * Q1 (High valence, high arousal)
180 | * Q2 (Low valence, high arousal)
181 | * Q3 (Low valence, low arousal)
182 | * Q4 (High valence, low arousal)
183 |
294 | ## Authors and Affiliations
295 |
296 | * Hsiao-Tzu (Anna) Hung
297 | Research Assistant @ Academia Sinica / MS CSIE student @ National Taiwan University
298 | r08922a20@csie.ntu.edu.tw
299 | [website](https://annahung31.github.io/), [LinkedIn](https://www.linkedin.com/in/hsiao-tzu-%EF%BC%88anna%EF%BC%89-hung-09829513a/).
300 |
301 | * Joann Ching
302 | Research Assistant @ Academia Sinica
303 | joann8512@gmail.com
304 | [website](https://joann8512.github.io/JoannChing.github.io/)
305 |
306 | * Seungheon Doh
307 | Ph.D. Student @ Music and Audio Computing Lab, KAIST
308 | seungheondoh@kaist.ac.kr
309 | [website](https://seungheondoh.github.io/), [LinkedIn](https://www.linkedin.com/in/dohppak/)
310 |
311 | * Juhan Nam
312 | Associate Professor @ KAIST
313 | juhan.nam@kaist.ac.kr
314 | [Website](https://mac.kaist.ac.kr/~juhan/)
315 |
316 | * Yi-Hsuan Yang
317 | Chief Music Scientist @ Taiwan AI Labs / Associate Research Fellow @ Academia Sinica
318 | affige@gmail.com, yhyang@ailabs.tw
319 | [website](http://mac.citi.sinica.edu.tw/~yang/)
320 |
321 |
322 | ## Cite this dataset
323 |
324 | {% raw %}
325 | ```
326 | @inproceedings{
327 | {EMOPIA},
328 | author = {Hung, Hsiao-Tzu and Ching, Joann and Doh, Seungheon and Kim, Nabin and Nam, Juhan and Yang, Yi-Hsuan},
329 | title = {{EMOPIA}: A Multi-Modal Pop Piano Dataset For Emotion Recognition and Emotion-based Music Generation},
330 | booktitle = {Proc. Int. Society for Music Information Retrieval Conf.},
331 | year = {2021}
332 | }
333 | ```
334 | {% endraw %}
--------------------------------------------------------------------------------
/docs/init:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | # based on workspace/transformer/generate.ipynb
2 |
3 | import subprocess
4 | from pathlib import Path
5 | import tempfile
6 | import os
7 | import pickle
8 | import sys
9 | import torch
10 | import numpy as np
11 | from midiSynth.synth import MidiSynth
12 | import cog
13 |
14 | sys.path.insert(0, "workspace/transformer")
15 | from utils import write_midi
16 | from models import TransformerModel
17 |
18 |
19 | EMOTIONS = {
20 | "High valence, high arousal": 1,
21 | "Low valence, high arousal": 2,
22 | "Low valence, low arousal": 3,
23 | "High valence, low arousal": 4,
24 | }
25 |
26 |
27 | class Predictor(cog.Predictor):
28 | def setup(self):
29 | print("Loading dictionary...")
30 | path_dictionary = "dataset/co-representation/dictionary.pkl"
31 | with open(path_dictionary, "rb") as f:
32 | self.dictionary = pickle.load(f)
33 | event2word, self.word2event = self.dictionary
34 |
35 | n_class = [] # num of classes for each token
36 | for key in event2word.keys():
37 | n_class.append(len(event2word[key]))
38 | n_token = len(n_class)
39 |
40 | print("Loading model...")
41 | path_saved_ckpt = "exp/pretrained_transformer/loss_25_params.pt"
42 | self.net = TransformerModel(n_class, is_training=False)
43 | self.net.cuda()
44 | self.net.eval()
45 |
46 | self.net.load_state_dict(torch.load(path_saved_ckpt))
47 |
48 | self.midi_synth = MidiSynth()
49 |
50 | @cog.input(
51 | "emotion",
52 | type=str,
53 | default="High valence, high arousal",
54 | options=EMOTIONS.keys(),
55 | help="Emotion to generate for",
56 | )
57 | @cog.input("seed", type=int, default=-1, help="Random seed, -1 for random")
58 | def predict(self, emotion, seed):
59 | if seed < 0:
60 | seed = int.from_bytes(os.urandom(2), "big")
61 | torch.manual_seed(seed)
62 | np.random.seed(seed)
63 | print(f"Prediction seed: {seed}")
64 |
65 | out_dir = Path(tempfile.mkdtemp())
66 | midi_path = out_dir / "out.midi"
67 | wav_path = out_dir / "out.wav"
68 | mp3_path = out_dir / "out.mp3"
69 |
70 | emotion_tag = EMOTIONS[emotion]
71 | res, _ = self.net.inference_from_scratch(
72 | self.dictionary, emotion_tag, n_token=8
73 | )
74 | try:
75 | write_midi(res, str(midi_path), self.word2event)
76 | self.midi_synth.midi2audio(str(midi_path), str(wav_path))
77 | subprocess.check_output(
78 | [
79 | "ffmpeg",
80 | "-i",
81 | str(wav_path),
82 | "-af",
83 | "silenceremove=1:0:-50dB,aformat=dblp,areverse,silenceremove=1:0:-50dB,aformat=dblp,areverse", # strip silence
84 | str(mp3_path),
85 | ],
86 | )
87 | return mp3_path
88 | finally:
89 | midi_path.unlink(missing_ok=True)
90 | wav_path.unlink(missing_ok=True)
91 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn==0.24.1
2 | seaborn==0.11.1
3 | numpy==1.19.5
4 | miditoolkit==0.1.14
5 | pandas==1.1.5
6 | ipdb
7 | tqdm
8 | matplotlib
9 | scipy
10 | # pickle is part of the Python standard library; no pip package is needed
11 |
12 |
--------------------------------------------------------------------------------
/workspace/baseline/evolve_generative_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "4"
3 |
4 | import json
5 | import pickle
6 | import argparse
7 | import numpy as np
8 | import tensorflow as tf
9 |
10 | from midi_generator import generate_midi
11 | from train_classifier import encode_sentence
12 | from train_classifier import get_activated_neurons
13 | from train_generative import build_generative_model
14 |
15 | GEN_MIN = -1
16 | GEN_MAX = 1
17 |
18 | # Directory where trained model will be saved
19 | TRAIN_DIR = "./trained"
20 |
21 | def mutation(individual, mutation_rate):
22 | for i in range(len(individual)):
23 | if np.random.uniform(0, 1) < mutation_rate:
24 | individual[i] = np.random.uniform(GEN_MIN, GEN_MAX)
25 |
26 | def crossover(parent_a, parent_b, ind_size):
27 | # Averaging crossover
28 | return (parent_a + parent_b)/2
29 |
30 | def reproduce(mating_pool, new_population_size, ind_size, mutation_rate):
31 | new_population = np.zeros((new_population_size, ind_size))
32 |
33 | for i in range(new_population_size):
34 | a = np.random.randint(len(mating_pool));
35 | b = np.random.randint(len(mating_pool));
36 |
37 | new_population[i] = crossover(mating_pool[a], mating_pool[b], ind_size);
38 |
39 | # Mutate new children
40 | np.apply_along_axis(mutation, 1, new_population, mutation_rate)
41 |
42 | return new_population;
43 |
44 | def roulette_wheel(population, fitness_pop):
45 | # Normalize fitnesses
46 | norm_fitness_pop = fitness_pop/np.sum(fitness_pop)
47 |
48 | # Here all the fitnesses sum up to 1
49 | r = np.random.uniform(0, 1)
50 |
51 | fitness_so_far = 0
52 | for i in range(len(population)):
53 | fitness_so_far += norm_fitness_pop[i]
54 |
55 | if r < fitness_so_far:
56 | return population[i]
57 |
58 | return population[-1]
59 |
60 | def select(population, fitness_pop, mating_pool_size, ind_size, elite_rate):
61 | mating_pool = np.zeros((mating_pool_size, ind_size))
62 |
63 |     # Apply roulette wheel to select mating_pool_size individuals
64 | for i in range(mating_pool_size):
65 | mating_pool[i] = roulette_wheel(population, fitness_pop)
66 |
67 | # Apply elitism
68 | assert elite_rate >= 0 and elite_rate <= 1
69 | elite_size = int(np.ceil(elite_rate * len(population)))
70 | elite_idxs = np.argsort(-fitness_pop)
71 |
72 | for i in range(elite_size):
73 | r = np.random.randint(0, mating_pool_size)
74 |         mating_pool[r] = population[elite_idxs[i]]  # copy the elite individual itself, not its index
75 |
76 | return mating_pool
77 |
78 | def calc_fitness(individual, gen_model, cls_model, char2idx, idx2char, layer_idx, sentiment, runs=30):
79 | encoding_size = gen_model.layers[layer_idx].units
80 | generated_midis = np.zeros((runs, encoding_size))
81 |
82 | # Get activated neurons
83 | sentneuron_ixs = get_activated_neurons(cls_model)
84 | assert len(individual) == len(sentneuron_ixs)
85 |
86 | # Use individual gens to override model neurons
87 | override = {}
88 | for i, ix in enumerate(sentneuron_ixs):
89 | override[ix] = individual[i]
90 |
91 | # Generate pieces and encode them using the cell state of the generative model
92 | for i in range(runs):
93 | midi_text = generate_midi(gen_model, char2idx, idx2char, seq_len=64, layer_idx=layer_idx, override=override)
94 | generated_midis[i] = encode_sentence(gen_model, midi_text, char2idx, layer_idx)
95 |
96 | midis_sentiment = cls_model.predict(generated_midis).clip(min=0)
97 | return 1.0 - np.sum(np.abs(midis_sentiment - sentiment))/runs
98 |
99 | def evaluate(population, gen_model, cls_model, char2idx, idx2char, layer_idx, sentiment):
100 | fitness = np.zeros((len(population), 1))
101 |
102 | for i in range(len(population)):
103 | fitness[i] = calc_fitness(population[i], gen_model, cls_model, char2idx, idx2char, layer_idx, sentiment)
104 |
105 | return fitness
106 |
107 | def evolve(pop_size, ind_size, mut_rate, elite_rate, epochs):
108 | # Create initial population
109 | population = np.random.uniform(GEN_MIN, GEN_MAX, (pop_size, ind_size))
110 |
111 | # Evaluate initial population
112 | fitness_pop = evaluate(population, gen_model, cls_model, char2idx, idx2char, opt.cellix, sent)
113 | print("--> Fitness: \n", fitness_pop)
114 |
115 | for i in range(epochs):
116 | print("-> Epoch", i)
117 |
118 | # Select individuals via roulette wheel to form a mating pool
119 | mating_pool = select(population, fitness_pop, pop_size, ind_size, elite_rate)
120 |
121 |         # Reproduce mating pool with crossover and mutation to form new population
122 | population = reproduce(mating_pool, pop_size, ind_size, mut_rate)
123 |
124 | # Calculate fitness of each individual of the population
125 | fitness_pop = evaluate(population, gen_model, cls_model, char2idx, idx2char, opt.cellix, sent)
126 | print("--> Fitness: \n", fitness_pop)
127 |
128 | return population, fitness_pop
129 |
130 | if __name__ == "__main__":
131 |
132 | # Parse arguments
133 | parser = argparse.ArgumentParser(description='evolve_generative.py')
134 | parser.add_argument('--genmodel', type=str, default='./trained', help="Generative model to evolve.")
135 | parser.add_argument('--clsmodel', type=str, default='./trained/classifier_ckpt.p', help="Classifier model to calculate fitness.")
136 | parser.add_argument('--ch2ix', type=str, default='./trained/char2idx.json', help="JSON file with char2idx encoding.")
137 | parser.add_argument('--embed', type=int, default=256, help="Embedding size.")
138 | parser.add_argument('--units', type=int, default=512, help="LSTM units.")
139 | parser.add_argument('--layers', type=int, default=4, help="LSTM layers.")
140 | parser.add_argument('--cellix', type=int, default=4, help="LSTM layer to use as encoder.")
141 | #parser.add_argument('--sent', type=int, default=2, help="Desired sentiment.")
142 | parser.add_argument('--popsize', type=int, default=10, help="Population size.")
143 | parser.add_argument('--epochs', type=int, default=10, help="Epochs to run.")
144 | parser.add_argument('--mrate', type=float, default=0.1, help="Mutation rate.")
145 | parser.add_argument('--elitism', type=float, default=0.1, help="Elitism in percentage.")
146 |
147 | opt = parser.parse_args()
148 |
149 | # Load char2idx dict from json file
150 | with open(opt.ch2ix) as f:
151 | char2idx = json.load(f)
152 |
153 | # Create idx2char from char2idx dict
154 | idx2char = {idx:char for char,idx in char2idx.items()}
155 |
156 | # Calculate vocab_size from char2idx dict
157 | vocab_size = len(char2idx)
158 |
159 | # Rebuild generative model from checkpoint
160 | gen_model = build_generative_model(vocab_size, opt.embed, opt.units, opt.layers, batch_size=1)
161 | gen_model.load_weights(tf.train.latest_checkpoint(opt.genmodel))
162 | gen_model.build(tf.TensorShape([1, None]))
163 |
164 | # Load classifier model
165 | with open(opt.clsmodel, "rb") as f:
166 | cls_model = pickle.load(f)
167 |
168 | # Set individual size to the number of activated neurons
169 | sentneuron_ixs = get_activated_neurons(cls_model)
170 | ind_size = len(sentneuron_ixs)
171 |
172 |     # Evolve one set of neuron values for each emotion quadrant (Q1-Q4)
173 | sents = [1, 2, 3, 4]
174 | for sent in sents:
175 | print('evolving for {}'.format(sent))
176 | population, fitness_pop = evolve(opt.popsize, ind_size, opt.mrate, opt.elitism, opt.epochs)
177 |
178 | # Get best individual
179 | best_idx = np.argmax(fitness_pop)
180 | best_individual = population[best_idx]
181 |
182 | # Use best individual gens to create a dictionary with cell values
183 | neurons = {}
184 | for i, ix in enumerate(sentneuron_ixs):
185 | neurons[str(ix)] = best_individual[i]
186 |
187 |
188 | # Persist dictionary with cell values
189 | if sent == 1:
190 | sent = 'Q1'
191 | elif sent == 2:
192 | sent = 'Q2'
193 | elif sent == 3:
194 | sent = 'Q3'
195 | else:
196 | sent = 'Q4'
197 |
198 | with open(os.path.join(TRAIN_DIR, "neurons_" + sent + ".json"), "w") as f:
199 | json.dump(neurons, f)
--------------------------------------------------------------------------------
/workspace/baseline/midi_encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import math as ma
5 | import music21 as m21
6 |
7 | THREE_DOTTED_BREVE = 15
8 | THREE_DOTTED_32ND = 0.21875
9 |
10 | MIN_VELOCITY = 0
11 | MAX_VELOCITY = 128
12 |
13 | MIN_TEMPO = 24
14 | MAX_TEMPO = 160
15 |
16 | MAX_PITCH = 128
17 |
18 | def load(datapath, sample_freq=4, piano_range=(33, 93), transpose_range=10, stretching_range=10):
19 | text = ""
20 | vocab = set()
21 |
22 | if os.path.isfile(datapath):
23 | # Path is an individual midi file
24 | file_extension = os.path.splitext(datapath)[1]
25 |
26 | if file_extension == ".midi" or file_extension == ".mid":
27 | text = parse_midi(datapath, sample_freq, piano_range, transpose_range, stretching_range)
28 | vocab = set(text.split(" "))
29 | else:
30 | # Read every file in the given directory
31 | for file in os.listdir(datapath):
32 | file_path = os.path.join(datapath, file)
33 | file_extension = os.path.splitext(file_path)[1]
34 |
35 |             # Check if it is not a directory and if it has either .midi or .mid extensions
36 | if os.path.isfile(file_path) and (file_extension == ".midi" or file_extension == ".mid"):
37 | encoded_midi = parse_midi(file_path, sample_freq, piano_range, transpose_range, stretching_range)
38 |
39 | if len(encoded_midi) > 0:
40 | words = set(encoded_midi.split(" "))
41 | vocab = vocab | words
42 |
43 | text += encoded_midi + " "
44 |
45 | # Remove last space
46 | text = text[:-1]
47 |
48 | return text, vocab
49 |
50 | def parse_midi(file_path, sample_freq, piano_range, transpose_range, stretching_range):
51 |
52 | # Split datapath into dir and filename
53 | midi_dir = os.path.dirname(file_path)
54 | midi_name = os.path.basename(file_path).split(".")[0]
55 |
56 | # If txt version of the midi already exists, load data from it
57 | midi_txt_name = os.path.join(midi_dir, midi_name + ".txt")
58 |
59 | if(os.path.isfile(midi_txt_name)):
60 | midi_fp = open(midi_txt_name, "r")
61 | encoded_midi = midi_fp.read()
62 | else:
63 | # Create a music21 stream and open the midi file
64 | midi = m21.midi.MidiFile()
65 | midi.open(file_path)
66 | midi.read()
67 | midi.close()
68 |
69 | # Translate midi to stream of notes and chords
70 | encoded_midi = midi2encoding(midi, sample_freq, piano_range, transpose_range, stretching_range)
71 |
72 | if len(encoded_midi) > 0:
73 | midi_fp = open(midi_txt_name, "w+")
74 | midi_fp.write(encoded_midi)
75 | midi_fp.flush()
76 |
77 | midi_fp.close()
78 | return encoded_midi
79 |
80 | def midi2encoding(midi, sample_freq, piano_range, transpose_range, stretching_range):
81 | try:
82 | midi_stream = m21.midi.translate.midiFileToStream(midi)
83 | except:
84 |         return ""  # return an empty string so callers can still treat the result as text
85 |
86 | # Get piano roll from midi stream
87 | piano_roll = midi2piano_roll(midi_stream, sample_freq, piano_range, transpose_range, stretching_range)
88 |
89 | # Get encoded midi from piano roll
90 | encoded_midi = piano_roll2encoding(piano_roll)
91 |
92 | return " ".join(encoded_midi)
93 |
94 | def piano_roll2encoding(piano_roll):
95 | # Transform piano roll into a list of notes in string format
96 | final_encoding = {}
97 |
98 | perform_i = 0
99 | for version in piano_roll:
100 | lastTempo = -1
101 | lastVelocity = -1
102 | lastDuration = -1.0
103 |
104 | version_encoding = []
105 |
106 | for i in range(len(version)):
107 | # Time events are stored at the last row
108 | tempo = version[i,-1][0]
109 | if tempo != 0 and tempo != lastTempo:
110 | version_encoding.append("t_" + str(int(tempo)))
111 | lastTempo = tempo
112 |
113 | # Process current time step of the piano_roll
114 | for j in range(len(version[i]) - 1):
115 | duration = version[i,j][0]
116 | velocity = int(version[i,j][1])
117 |
118 | if velocity != 0 and velocity != lastVelocity:
119 | version_encoding.append("v_" + str(velocity))
120 | lastVelocity = velocity
121 |
122 | if duration != 0 and duration != lastDuration:
123 | duration_tuple = m21.duration.durationTupleFromQuarterLength(duration)
124 | version_encoding.append("d_" + duration_tuple.type + "_" + str(duration_tuple.dots))
125 | lastDuration = duration
126 |
127 | if duration != 0 and velocity != 0:
128 | version_encoding.append("n_" + str(j))
129 |
130 | # End of time step
131 | if len(version_encoding) > 0 and version_encoding[-1][0] == "w":
132 | # Increase wait by one
133 | version_encoding[-1] = "w_" + str(int(version_encoding[-1].split("_")[1]) + 1)
134 | else:
135 | version_encoding.append("w_1")
136 |
137 | # End of piece
138 | version_encoding.append("\n")
139 |
140 | # Check if this version of the MIDI is already added
141 | version_encoding_str = " ".join(version_encoding)
142 | if version_encoding_str not in final_encoding:
143 | final_encoding[version_encoding_str] = perform_i
144 |
145 | perform_i += 1
146 |
147 | return final_encoding.keys()
148 |
149 | def write(encoded_midi, path):
150 | # Base class checks if output path exists
151 | midi = encoding2midi(encoded_midi)
152 | midi.open(path, "wb")
153 | midi.write()
154 | midi.close()
155 |
156 | def encoding2midi(note_encoding, ts_duration=0.25):
157 | notes = []
158 |
159 | velocity = 100
160 | duration = "16th"
161 | dots = 0
162 |
163 | ts = 0
164 | for note in note_encoding.split(" "):
165 | if len(note) == 0:
166 | continue
167 |
168 | elif note[0] == "w":
169 | wait_count = int(note.split("_")[1])
170 | ts += wait_count
171 |
172 | elif note[0] == "n":
173 | pitch = int(note.split("_")[1])
174 | note = m21.note.Note(pitch)
175 | note.duration = m21.duration.Duration(type=duration, dots=dots)
176 | note.offset = ts * ts_duration
177 | note.volume.velocity = velocity
178 | notes.append(note)
179 |
180 | elif note[0] == "d":
181 | duration = note.split("_")[1]
182 | dots = int(note.split("_")[2])
183 |
184 | elif note[0] == "v":
185 | velocity = int(note.split("_")[1])
186 |
187 | elif note[0] == "t":
188 | tempo = int(note.split("_")[1])
189 |
190 | if tempo > 0:
191 | mark = m21.tempo.MetronomeMark(number=tempo)
192 | mark.offset = ts * ts_duration
193 | notes.append(mark)
194 |
195 | piano = m21.instrument.fromString("Piano")
196 | notes.insert(0, piano)
197 |
198 | piano_stream = m21.stream.Stream(notes)
199 | main_stream = m21.stream.Stream([piano_stream])
200 |
201 | return m21.midi.translate.streamToMidiFile(main_stream)
202 |
203 | def midi_parse_notes(midi_stream, sample_freq):
204 | note_filter = m21.stream.filters.ClassFilter('Note')
205 |
206 | note_events = []
207 | for note in midi_stream.recurse().addFilter(note_filter):
208 | pitch = note.pitch.midi
209 | duration = note.duration.quarterLength
210 | velocity = note.volume.velocity
211 | offset = ma.floor(note.offset * sample_freq)
212 |
213 | note_events.append((pitch, duration, velocity, offset))
214 |
215 | return note_events
216 |
217 | def midi_parse_chords(midi_stream, sample_freq):
218 | chord_filter = m21.stream.filters.ClassFilter('Chord')
219 |
220 | note_events = []
221 | for chord in midi_stream.recurse().addFilter(chord_filter):
222 | pitches_in_chord = chord.pitches
223 | for pitch in pitches_in_chord:
224 | pitch = pitch.midi
225 | duration = chord.duration.quarterLength
226 | velocity = chord.volume.velocity
227 | offset = ma.floor(chord.offset * sample_freq)
228 |
229 | note_events.append((pitch, duration, velocity, offset))
230 |
231 | return note_events
232 |
233 | def midi_parse_metronome(midi_stream, sample_freq):
234 | metronome_filter = m21.stream.filters.ClassFilter('MetronomeMark')
235 |
236 | time_events = []
237 | for metro in midi_stream.recurse().addFilter(metronome_filter):
238 | time = int(metro.number)
239 | offset = ma.floor(metro.offset * sample_freq)
240 | time_events.append((time, offset))
241 |
242 | return time_events
243 |
244 | def midi2notes(midi_stream, sample_freq, transpose_range):
245 | notes = []
246 | notes += midi_parse_notes(midi_stream, sample_freq)
247 | notes += midi_parse_chords(midi_stream, sample_freq)
248 |
249 | # Transpose the notes to all the keys in transpose_range
250 | return transpose_notes(notes, transpose_range)
251 |
252 | def midi2piano_roll(midi_stream, sample_freq, piano_range, transpose_range, stretching_range):
253 | # Calculate the amount of time steps in the piano roll
254 | time_steps = ma.floor(midi_stream.duration.quarterLength * sample_freq) + 1
255 |
256 | # Parse the midi file into a list of notes (pitch, duration, velocity, offset)
257 | transpositions = midi2notes(midi_stream, sample_freq, transpose_range)
258 |
259 | time_events = midi_parse_metronome(midi_stream, sample_freq)
260 | time_streches = strech_time(time_events, stretching_range)
261 |
262 | return notes2piano_roll(transpositions, time_streches, time_steps, piano_range)
263 |
264 | def notes2piano_roll(transpositions, time_streches, time_steps, piano_range):
265 | performances = []
266 |
267 | min_pitch, max_pitch = piano_range
268 | for t_ix in range(len(transpositions)):
269 | for s_ix in range(len(time_streches)):
270 |             # Create piano roll with the calculated size.
271 |             # Add one dimension to every entry to store velocity and duration.
272 | piano_roll = np.zeros((time_steps, MAX_PITCH + 1, 2))
273 |
274 | for note in transpositions[t_ix]:
275 | pitch, duration, velocity, offset = note
276 | if duration == 0.0:
277 | continue
278 |
279 | # Force notes to be inside the specified piano_range
280 | pitch = clamp_pitch(pitch, max_pitch, min_pitch)
281 |
282 | piano_roll[offset, pitch][0] = clamp_duration(duration)
283 | piano_roll[offset, pitch][1] = discretize_value(velocity, bins=32, range=(MIN_VELOCITY, MAX_VELOCITY))
284 |
285 | for time_event in time_streches[s_ix]:
286 | time, offset = time_event
287 | piano_roll[offset, -1][0] = discretize_value(time, bins=100, range=(MIN_TEMPO, MAX_TEMPO))
288 |
289 | performances.append(piano_roll)
290 |
291 | return performances
292 |
293 | def transpose_notes(notes, transpose_range):
294 | transpositions = []
295 |
296 | # Modulate the piano_roll for other keys
297 | first_key = -ma.floor(transpose_range/2)
298 | last_key = ma.ceil(transpose_range/2)
299 |
300 | for key in range(first_key, last_key):
301 | notes_in_key = []
302 | for n in notes:
303 | pitch, duration, velocity, offset = n
304 | t_pitch = pitch + key
305 | notes_in_key.append((t_pitch, duration, velocity, offset))
306 | transpositions.append(notes_in_key)
307 |
308 | return transpositions
309 |
310 | def strech_time(time_events, stretching_range):
311 | streches = []
312 |
313 |     # Compute the tempo-stretching range
314 | slower_time = -ma.floor(stretching_range/2)
315 | faster_time = ma.ceil(stretching_range/2)
316 |
317 |     # Build one stretched version of the tempo events per stretch step
318 | for t_strech in range(slower_time, faster_time):
319 | time_events_in_strech = []
320 | for t_ev in time_events:
321 | time, offset = t_ev
322 | s_time = time + 0.05 * t_strech * MAX_TEMPO
323 | time_events_in_strech.append((s_time, offset))
324 | streches.append(time_events_in_strech)
325 |
326 | return streches
327 |
328 | def discretize_value(val, bins, range):
329 | min_val, max_val = range
330 |
331 | val = int(max(min_val, val))
332 | val = int(min(val, max_val))
333 |
334 | bin_size = (max_val/bins)
335 | return ma.floor(val/bin_size) * bin_size
336 |
337 | def clamp_pitch(pitch, max, min):
338 | while pitch < min:
339 | pitch += 12
340 | while pitch >= max:
341 | pitch -= 12
342 | return pitch
343 |
344 | def clamp_duration(duration, max=THREE_DOTTED_BREVE, min=THREE_DOTTED_32ND):
345 | # Max duration is 3-dotted breve
346 | if duration > max:
347 | duration = max
348 |
349 |     # Min duration is a 3-dotted 32nd
350 | if duration < min:
351 | duration = min
352 |
353 | duration_tuple = m21.duration.durationTupleFromQuarterLength(duration)
354 | if duration_tuple.type == "inexpressible":
355 |         duration_closest_type = m21.duration.quarterLengthToClosestType(duration)[0]
356 |         duration = m21.duration.typeToDuration[duration_closest_type]
357 |
358 | return duration
359 |
360 | if __name__ == "__main__":
361 |
362 | # Parse arguments
363 | parser = argparse.ArgumentParser(description='midi_encoder.py')
364 | parser.add_argument('--path', type=str, required=True, help="Path to midi data.")
365 | parser.add_argument('--transp', type=int, default=1, help="Transpose range.")
366 | parser.add_argument('--strech', type=int, default=1, help="Time stretching range.")
367 | opt = parser.parse_args()
368 |
369 |     # Load data and encode it
370 | text, vocab = load(opt.path, transpose_range=opt.transp, stretching_range=opt.strech)
371 | print(text)
372 |
373 | # Write all data to midi file
374 | write(text, "encoded.mid")
375 |
--------------------------------------------------------------------------------
/workspace/baseline/midi_generator.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "4"
3 |
4 | import json
5 | import argparse
6 | import numpy as np
7 | import tensorflow as tf
8 | import midi_encoder as me
9 |
10 | from train_generative import build_generative_model
11 | from train_classifier import preprocess_sentence
12 |
13 | GENERATED_DIR = './generated'
14 |
15 | def override_neurons(model, layer_idx, override):
16 | h_state, c_state = model.get_layer(index=layer_idx).states
17 |
18 | c_state = c_state.numpy()
19 | for neuron, value in override.items():
20 | c_state[:,int(neuron)] = int(value)
21 |
22 | model.get_layer(index=layer_idx).states = (h_state, tf.Variable(c_state))
23 |
24 | def sample_next(predictions, k):
25 | # Sample using a categorical distribution over the top k midi chars
26 | top_k = tf.math.top_k(predictions, k)
27 | top_k_choices = top_k[1].numpy().squeeze()
28 | top_k_values = top_k[0].numpy().squeeze()
29 |
30 | if np.random.uniform(0, 1) < .5:
31 | predicted_id = top_k_choices[0]
32 | else:
33 | p_choices = tf.math.softmax(top_k_values[1:]).numpy()
34 | predicted_id = np.random.choice(top_k_choices[1:], 1, p=p_choices)[0]
35 |
36 | return predicted_id
37 |
38 | def process_init_text(model, init_text, char2idx, layer_idx, override):
39 | model.reset_states()
40 |
41 | for c in init_text.split(" "):
42 | # Run a forward pass
43 | try:
44 | input_eval = tf.expand_dims([char2idx[c]], 0)
45 |
46 | # override sentiment neurons
47 | override_neurons(model, layer_idx, override)
48 |
49 | predictions = model(input_eval)
50 | except KeyError:
51 | if c != "":
52 |                 print("Can't process char", c)
53 |
54 | return predictions
55 |
56 | def generate_midi(model, char2idx, idx2char, init_text="", seq_len=256, k=3, layer_idx=-2, override={}):
57 | # Add front and end pad to the initial text
58 | init_text = preprocess_sentence(init_text)
59 |
60 | # Empty midi to store our results
61 | midi_generated = []
62 |
63 | # Process initial text
64 | predictions = process_init_text(model, init_text, char2idx, layer_idx, override)
65 |
66 | # Here batch size == 1
67 | model.reset_states()
68 | for i in range(seq_len):
69 | # remove the batch dimension
70 | predictions = tf.squeeze(predictions, 0).numpy()
71 |
72 | # Sample using a categorical distribution over the top k midi chars
73 | predicted_id = sample_next(predictions, k)
74 |
75 | # Append it to generated midi
76 | midi_generated.append(idx2char[predicted_id])
77 |
78 | # override sentiment neurons
79 | override_neurons(model, layer_idx, override)
80 |
81 | #Run a new forward pass
82 | input_eval = tf.expand_dims([predicted_id], 0)
83 | predictions = model(input_eval)
84 |
85 | return init_text + " " + " ".join(midi_generated)
86 |
87 | if __name__ == "__main__":
88 |
89 | # Parse arguments
90 | parser = argparse.ArgumentParser(description='midi_generator.py')
91 | parser.add_argument('--model', type=str, default='./trained', help="Checkpoint dir.")
92 | parser.add_argument('--ch2ix', type=str, default='./trained/char2idx.json', help="JSON file with char2idx encoding.")
93 | parser.add_argument('--embed', type=int, default=256, help="Embedding size.")
94 | parser.add_argument('--units', type=int, default=512, help="LSTM units.")
95 | parser.add_argument('--layers', type=int, default=4, help="LSTM layers.")
96 | parser.add_argument('--seqinit', type=str, default="\n", help="Sequence init.")
97 |     parser.add_argument('--seqlen', type=int, default=512, help="Sequence length.")
98 | parser.add_argument('--cellix', type=int, default=4, help="LSTM layer to use as encoder.")
99 | parser.add_argument('--override', type=str, default="./trained/neurons_Q1.json", help="JSON file with neuron values to override.")
100 | opt = parser.parse_args()
101 |
102 | # Load char2idx dict from json file
103 | with open(opt.ch2ix) as f:
104 | char2idx = json.load(f)
105 |
106 | # Load override dict from json file
107 | override = {}
108 |
109 | try:
110 | with open(opt.override) as f:
111 | override = json.load(f)
112 | except FileNotFoundError:
113 | print("Override JSON file not provided.")
114 |
115 | # Create idx2char from char2idx dict
116 | idx2char = {idx:char for char,idx in char2idx.items()}
117 |
118 | # Calculate vocab_size from char2idx dict
119 | vocab_size = len(char2idx)
120 |
121 | # Rebuild model from checkpoint
122 | model = build_generative_model(vocab_size, opt.embed, opt.units, opt.layers, batch_size=1)
123 | model.load_weights(tf.train.latest_checkpoint(opt.model))
124 | model.build(tf.TensorShape([1, None]))
125 |
126 | if not os.path.exists(GENERATED_DIR):
127 | os.makedirs(GENERATED_DIR)
128 |     # Generate 100 MIDI files
129 | for i in range(100):
130 | # Generate a midi as text
131 | print("Generate midi {}".format(i))
132 | midi_txt = generate_midi(model, char2idx, idx2char, opt.seqinit, opt.seqlen, layer_idx=opt.cellix, override=override)
133 |
134 |
135 | me.write(midi_txt, os.path.join(GENERATED_DIR, "generated_Q1_{}.mid".format(i)))
136 |
--------------------------------------------------------------------------------
/workspace/baseline/plot_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | import matplotlib
5 | matplotlib.use('Agg')
6 | import matplotlib.pyplot as plt
7 |
8 | # Directory where plots will be saved
9 | PLOTS_DIR = './results'
10 |
11 | def plot_logits(xs, ys, top_neurons):
12 | for n in top_neurons:
13 | plot_logit_and_save(xs, ys, n)
14 |
15 | def plot_logit_and_save(xs, ys, neuron_index):
16 | sentiment_unit = xs[:,neuron_index]
17 |
18 | # plt.title('Distribution of Logit Values')
19 | plt.ylabel('Number of Phrases')
20 | plt.xlabel('Value of the Sentiment Neuron')
21 | plt.hist(sentiment_unit[ys == -1], bins=50, alpha=0.5, label='Negative Phrases')
22 | plt.hist(sentiment_unit[ys == 1], bins=50, alpha=0.5, label='Positive Phrases')
23 | plt.legend()
24 | plt.savefig(os.path.join(PLOTS_DIR, "neuron_" + str(neuron_index) + '.png'))
25 | plt.clf()
26 |
27 | def plot_weight_contribs(coef):
28 | plt.title('Values of Resulting L1 Penalized Weights')
29 | plt.tick_params(axis='both', which='major')
30 |
31 | # Normalize weight contributions
32 | norm = np.linalg.norm(coef)
33 | coef = coef/norm
34 |
35 | plt.plot(range(len(coef[0])), coef.T)
36 | plt.xlabel('Neuron (Feature) Index')
37 | plt.ylabel('Neuron (Feature) weight')
38 | plt.savefig(os.path.join(PLOTS_DIR, "weight_contribs.png"))
39 | plt.clf()
40 |
--------------------------------------------------------------------------------
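
For reference, here is a small self-contained sketch that exercises the two helpers above with synthetic data (the same applies to plot_results_base.py below, which only differs in PLOTS_DIR). The array shapes are assumptions read off the indexing: xs is (phrases x neurons) and ys holds labels in {-1, 1}. It assumes it is run from workspace/baseline so that plot_results is importable.

    import os
    import numpy as np
    import plot_results as pr

    os.makedirs(pr.PLOTS_DIR, exist_ok=True)

    rng = np.random.default_rng(0)
    xs = rng.normal(size=(200, 512))        # 200 phrases x 512 neuron activations
    ys = rng.choice([-1, 1], size=200)      # sentiment labels
    xs[:, 7] += ys                          # make neuron 7 weakly sentiment-aware

    pr.plot_logits(xs, ys, top_neurons=[7])             # writes results/neuron_7.png
    pr.plot_weight_contribs(rng.normal(size=(1, 512)))  # writes results/weight_contribs.png
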
/workspace/baseline/plot_results_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | import matplotlib
5 | matplotlib.use('Agg')
6 | import matplotlib.pyplot as plt
7 |
8 | # Directory where plots will be saved
9 | PLOTS_DIR = './results_base'
10 |
11 | def plot_logits(xs, ys, top_neurons):
12 | for n in top_neurons:
13 | plot_logit_and_save(xs, ys, n)
14 |
15 | def plot_logit_and_save(xs, ys, neuron_index):
16 | sentiment_unit = xs[:,neuron_index]
17 |
18 | # plt.title('Distribution of Logit Values')
19 | plt.ylabel('Number of Phrases')
20 | plt.xlabel('Value of the Sentiment Neuron')
21 | plt.hist(sentiment_unit[ys == -1], bins=50, alpha=0.5, label='Negative Phrases')
22 | plt.hist(sentiment_unit[ys == 1], bins=50, alpha=0.5, label='Positive Phrases')
23 | plt.legend()
24 | plt.savefig(os.path.join(PLOTS_DIR, "neuron_" + str(neuron_index) + '.png'))
25 | plt.clf()
26 |
27 | def plot_weight_contribs(coef):
28 | plt.title('Values of Resulting L1 Penalized Weights')
29 | plt.tick_params(axis='both', which='major')
30 |
31 | # Normalize weight contributions
32 | norm = np.linalg.norm(coef)
33 | coef = coef/norm
34 |
35 | plt.plot(range(len(coef[0])), coef.T)
36 | plt.xlabel('Neuron (Feature) Index')
37 | plt.ylabel('Neuron (Feature) weight')
38 | plt.savefig(os.path.join(PLOTS_DIR, "weight_contribs.png"))
39 | plt.clf()
40 |
--------------------------------------------------------------------------------
/workspace/baseline/train_classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "4"
3 |
4 | import csv
5 | import json
6 | import pickle
7 | import argparse
8 | import numpy as np
9 | import tensorflow as tf
10 | import midi_encoder as me
11 | import plot_results_base as pr
12 |
13 | from train_generative import build_generative_model
14 | from sklearn.linear_model import LogisticRegression
15 | from sklearn.preprocessing import LabelBinarizer
16 |
17 | # Directory where trained model will be saved
18 | TRAIN_DIR = "./trained"
19 | DATA_TRAIN = '../data/train_split.csv'
20 | DATA_TEST = '../data/test_split.csv'
21 | DATASET_PATH = '../'
22 |
23 | def preprocess_sentence(text, front_pad='\n ', end_pad=''):
24 | text = text.replace('\n', ' ').strip()
25 | text = front_pad+text+end_pad
26 | return text
27 |
28 | def encode_sentence(model, text, char2idx, layer_idx):
29 | text = preprocess_sentence(text)
30 |
31 | # Reset LSTMs hidden and cell states
32 | model.reset_states()
33 |
34 | for c in text.split(" "):
35 | # Add the batch dimension
36 | try:
37 | input_eval = tf.expand_dims([char2idx[c]], 0)
38 | predictions = model(input_eval)
39 | except KeyError:
40 | if c != "":
41 | print("Can't process char", c)
42 |
43 | h_state, c_state = model.get_layer(index=layer_idx).states
44 |
45 | # remove the batch dimension
46 | #h_state = tf.squeeze(h_state, 0)
47 | c_state = tf.squeeze(c_state, 0)
48 |
49 | return tf.math.tanh(c_state).numpy()
50 |
51 | def build_dataset(datapath, generative_model, char2idx, layer_idx):
52 | xs, ys = [], []
53 |
54 | csv_file = open(datapath, "r")
55 | data = csv.DictReader(csv_file)
56 |
57 | for row in data:
58 | label = int(row["label"])
59 | filepath = row["filepath"]
60 |
61 | data_dir = os.path.dirname(datapath)
62 | phrase_path = filepath.replace('./', '')
63 | phrase_path = os.path.join(DATASET_PATH, phrase_path)
64 | encoded_path = os.path.splitext(filepath)[0]+'.npy'
65 | encoded_path = encoded_path.replace('./', '')
66 | encoded_path = os.path.join(DATASET_PATH, encoded_path)
67 |
68 | # Load midi file as text
69 | if os.path.isfile(encoded_path):
70 | encoding = np.load(encoded_path)
71 | else:
72 | text, vocab = me.load(phrase_path, transpose_range=1, stretching_range=1)
73 |
74 | # Encode midi text using generative lstm
75 | encoding = encode_sentence(generative_model, text, char2idx, layer_idx)
76 |
77 | # Save encoding in file to make it faster to load next time
78 | np.save(encoded_path, encoding)
79 |
80 | xs.append(encoding)
81 | ys.append(label)
82 |
83 | return np.array(xs), np.array(ys)
84 |
85 | def train_classifier_model(train_dataset, test_dataset, C=2**np.arange(-8, 1).astype(float), seed=42, penalty="l1"):
86 | trX, trY = train_dataset
87 | teX, teY = test_dataset
88 |
89 | scores = []
90 |
91 | # Hyper-parameter optimization
92 | for i, c in enumerate(C):
93 | logreg_model = LogisticRegression(C=c, penalty=penalty, random_state=seed+i, solver="liblinear")
94 | logreg_model.fit(trX, trY)
95 |
96 | score = logreg_model.score(teX, teY)
97 | scores.append(score)
98 |
99 | c = C[np.argmax(scores)]
100 |
101 | sent_classfier = LogisticRegression(C=c, penalty=penalty, random_state=seed+len(C), solver="liblinear")
102 | sent_classfier.fit(trX, trY)
103 |
104 | score = sent_classfier.score(teX, teY) * 100.
105 |
106 | # Persist sentiment classifier
107 | with open(os.path.join(TRAIN_DIR, "classifier_ckpt.p"), "wb") as f:
108 | pickle.dump(sent_classfier, f)
109 |
110 | # Get activated neurons
111 | sentneuron_ixs = get_activated_neurons(sent_classfier)
112 |
113 | # Plot results
114 | pr.plot_weight_contribs(sent_classfier.coef_)
115 | pr.plot_logits(trX, trY, sentneuron_ixs)
116 |
117 | return sentneuron_ixs, score
118 |
119 | def get_activated_neurons(sent_classfier):
120 | neurons_not_zero = len(np.argwhere(sent_classfier.coef_))
121 |
122 | weights = sent_classfier.coef_.T
123 | weight_penalties = np.squeeze(np.linalg.norm(weights, ord=1, axis=1))
124 |
125 | if neurons_not_zero == 1:
126 | neuron_ixs = np.array([np.argmax(weight_penalties)])
127 | elif neurons_not_zero >= np.log(len(weight_penalties)):
128 | neuron_ixs = np.argsort(weight_penalties)[-neurons_not_zero:][::-1]
129 | else:
130 | neuron_ixs = np.argpartition(weight_penalties, -neurons_not_zero)[-neurons_not_zero:]
131 | neuron_ixs = (neuron_ixs[np.argsort(weight_penalties[neuron_ixs])])[::-1]
132 |
133 | return neuron_ixs
134 |
135 | if __name__ == "__main__":
136 |
137 | # Parse arguments
138 | parser = argparse.ArgumentParser(description='train_classifier.py')
139 | parser.add_argument('--train', type=str, default=DATA_TRAIN, help="Train dataset.")
140 | parser.add_argument('--test' , type=str, default=DATA_TEST, help="Test dataset.")
141 | parser.add_argument('--model', type=str, default='./trained/', help="Checkpoint dir.")
142 | parser.add_argument('--ch2ix', type=str, default='./trained/char2idx.json', help="JSON file with char2idx encoding.")
143 | parser.add_argument('--embed', type=int, default=256, help="Embedding size.")
144 | parser.add_argument('--units', type=int, default=512, help="LSTM units.")
145 | parser.add_argument('--layers', type=int, default=4, help="LSTM layers.")
146 | parser.add_argument('--cellix', type=int, default=4, help="LSTM layer to use as encoder.")
147 | opt = parser.parse_args()
148 |
149 | # Load char2idx dict from json file
150 | with open(opt.ch2ix) as f:
151 | char2idx = json.load(f)
152 |
153 | # Calculate vocab_size from char2idx dict
154 | vocab_size = len(char2idx)
155 |
156 | # Rebuild generative model from checkpoint
157 | generative_model = build_generative_model(vocab_size, opt.embed, opt.units, opt.layers, batch_size=1)
158 | generative_model.load_weights(tf.train.latest_checkpoint(opt.model))
159 | generative_model.build(tf.TensorShape([1, None]))
160 |
161 | # Build dataset from encoded labelled midis
162 | train_dataset = build_dataset(opt.train, generative_model, char2idx, opt.cellix)
163 | test_dataset = build_dataset(opt.test, generative_model, char2idx, opt.cellix)
164 |
165 | # Train model
166 | sentneuron_ixs, score = train_classifier_model(train_dataset, test_dataset)
167 |
168 | print("Total Neurons Used:", len(sentneuron_ixs), "\n", sentneuron_ixs)
169 | print("Test Accuracy:", score)
170 |
--------------------------------------------------------------------------------
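
A short usage sketch, assuming a run of the script above has already written ./trained/classifier_ckpt.p: the pickled logistic-regression classifier can be reloaded and applied to any new cell-state encoding (one row per phrase, 512 features with the default --units). The random vector below is only a placeholder for a real encode_sentence output.

    import pickle
    import numpy as np

    # Load the sentiment classifier persisted by train_classifier_model above.
    with open("./trained/classifier_ckpt.p", "rb") as f:
        clf = pickle.load(f)

    # encode_sentence returns tanh of the LSTM cell state, so values lie in [-1, 1].
    encoding = np.random.uniform(-1, 1, size=(1, 512))  # placeholder phrase encoding

    print("predicted label:", clf.predict(encoding)[0])
    print("class probabilities:", clf.predict_proba(encoding)[0])
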
/workspace/baseline/train_generative.py:
--------------------------------------------------------------------------------
1 | import os
2 | #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"\
3 | os.environ["CUDA_VISIBLE_DEVICES"] = '4'
4 | import json
5 | import argparse
6 | import numpy as np
7 | import tensorflow as tf
8 | import numpy as np
9 |
10 |
11 | import midi_encoder as me
12 |
13 | # Directory where the checkpoints will be saved
14 | TRAIN_DIR = "./trained/"
15 |
16 | #print('Version: ', tf.__version__)
17 | #print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
18 | #print("Check: ", tf.test.is_gpu_available(cuda_only=False, min_cuda_compute_capability=None))
19 |
20 |
21 | def generative_loss(labels, logits):
22 | return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
23 |
24 | def build_generative_model(vocab_size, embed_dim, lstm_units, lstm_layers, batch_size, dropout=0):
25 | model = tf.keras.Sequential()
26 |
27 | model.add(tf.keras.layers.Embedding(vocab_size, embed_dim, batch_input_shape=[batch_size, None]))
28 |
29 | for i in range(max(1, lstm_layers)):
30 | model.add(tf.keras.layers.LSTM(lstm_units, return_sequences=True, stateful=True, dropout=dropout, recurrent_dropout=dropout))
31 |
32 | model.add(tf.keras.layers.Dense(vocab_size))
33 |
34 | return model
35 |
36 | def build_char2idx(train_vocab, test_vocab):
37 | # Merge train and test vocabulary
38 | vocab = list(train_vocab | test_vocab)
39 | vocab.sort()
40 |
41 | # Calculate vocab size
42 | vocab_size = len(vocab)
43 |
44 | # Create dict to support char to index conversion
45 | char2idx = { char:i for i,char in enumerate(vocab) }
46 |
47 |     # Save char2idx encoding as a json file for generating midis later
48 | with open(os.path.join(TRAIN_DIR, "char2idx.json"), "w") as f:
49 | json.dump(char2idx, f)
50 |
51 | return char2idx, vocab_size
52 |
53 | def build_dataset(text, char2idx, seq_length, batch_size, buffer_size=10000):
54 | text_as_int = np.array([char2idx[c] for c in text.split(" ")])
55 | char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
56 |
57 | sequences = char_dataset.batch(seq_length+1, drop_remainder=True)
58 |
59 | dataset = sequences.map(__split_input_target)
60 | dataset = dataset.shuffle(buffer_size).batch(batch_size, drop_remainder=True)
61 |
62 | return dataset
63 |
64 | def train_generative_model(model, train_dataset, test_dataset, epochs, learning_rate):
65 | # Compile model with given optimizer and defined loss
66 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
67 | model.compile(optimizer=optimizer, loss=generative_loss)
68 |
69 | # Name of the checkpoint files
70 | checkpoint_prefix = os.path.join(TRAIN_DIR, "generative_ckpt_{epoch}")
71 | my_callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, save_weights_only=True),
72 | tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=3, restore_best_weights=True)]
73 |
74 |
75 | return model.fit(train_dataset, epochs=epochs, validation_data=test_dataset, callbacks=my_callbacks)
76 |
77 | def __split_input_target(chunk):
78 | input_text = chunk[:-1]
79 | target_text = chunk[1:]
80 | return input_text, target_text
81 |
82 | if __name__ == "__main__":
83 |
84 | # Parse arguments
85 | parser = argparse.ArgumentParser(description='train_generative.py')
86 | parser.add_argument('--train', type=str, default='../data/train/', help="Train dataset.")
87 | parser.add_argument('--test' , type=str, default='../data/test/', help="Test dataset.")
88 | parser.add_argument('--model', type=str, required=False, help="Checkpoint dir.")
89 | parser.add_argument('--embed', type=int, default=256, help="Embedding size.")
90 | parser.add_argument('--units', type=int, default=512, help="LSTM units.")
91 | parser.add_argument('--layers', type=int, default=4, help="LSTM layers.")
92 | parser.add_argument('--batch', type=int, default=64, help="Batch size.")
93 | parser.add_argument('--epochs', type=int, default=30, help="Epochs.")
94 |     parser.add_argument('--seqlen', type=int, default=256, help="Sequence length.")
95 | parser.add_argument('--lrate', type=float, default=0.00001, help="Learning rate.")
96 | parser.add_argument('--drop', type=float, default=0.05, help="Dropout.")
97 | opt = parser.parse_args()
98 |
99 | if not os.path.exists(TRAIN_DIR):
100 | os.makedirs(TRAIN_DIR)
101 |
102 | # Encode midi files as text with vocab
103 | train_text, train_vocab = me.load(opt.train)
104 | test_text, test_vocab = me.load(opt.test)
105 |
106 | # Build dictionary to map from char to integers
107 | char2idx, vocab_size = build_char2idx(train_vocab, test_vocab)
108 |
109 | # Build dataset from encoded unlabelled midis
110 | train_dataset = build_dataset(train_text, char2idx, opt.seqlen, opt.batch)
111 | test_dataset = build_dataset(test_text, char2idx, opt.seqlen, opt.batch)
112 |
113 | # Build generative model
114 | generative_model = build_generative_model(vocab_size, opt.embed, opt.units, opt.layers, opt.batch, opt.drop)
115 |
116 | if opt.model:
117 | # If pre-trained model was given as argument, load weights from disk
118 | print("Loading weights from {}...".format(opt.model))
119 | generative_model.load_weights(tf.train.latest_checkpoint(opt.model))
120 |
121 | # Train model
122 | history = train_generative_model(generative_model, train_dataset, test_dataset, opt.epochs, opt.lrate)
123 | print("Total of {} epochs used for training.".format(len(history.history['loss'])))
124 | loss_hist = history.history['loss']
125 | print("Best loss from history: ", np.min(loss_hist))
126 |
--------------------------------------------------------------------------------
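
To make the dataset construction above concrete, the tiny standalone sketch below reproduces the next-token shift performed by build_dataset and __split_input_target on a toy token sequence instead of encoded MIDI text (shuffling is left out for readability).

    import numpy as np
    import tensorflow as tf

    tokens = np.arange(20)                  # stand-in for char2idx-encoded MIDI events
    seq_length, batch_size = 4, 2

    ds = tf.data.Dataset.from_tensor_slices(tokens)
    ds = ds.batch(seq_length + 1, drop_remainder=True)   # chunks of length seq_length+1

    def split_input_target(chunk):
        return chunk[:-1], chunk[1:]                     # predict the next token

    ds = ds.map(split_input_target).batch(batch_size, drop_remainder=True)

    for x, y in ds.take(1):
        print("input :", x.numpy())   # [[0 1 2 3] [5 6 7 8]]
        print("target:", y.numpy())   # [[1 2 3 4] [6 7 8 9]]
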
/workspace/transformer/main_cp.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import ipdb
4 | from tqdm import tqdm
5 | import math
6 | import time
7 | import glob
8 | import datetime
9 | import random
10 | import pickle
11 | import json
12 | import numpy as np
13 | from collections import OrderedDict
14 | from argparse import ArgumentParser
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import torch.optim as optim
20 | from torch.nn.utils import clip_grad_norm_
21 | from torch.utils.data import Dataset, DataLoader
22 |
23 |
24 | import saver
25 | from models import TransformerModel, network_paras
26 | from utils import write_midi, get_random_string
27 |
28 |
29 | ################################################################################
30 | # config
31 | ################################################################################
32 |
33 | parser = ArgumentParser()
34 | parser.add_argument("--mode",default="train",type=str,choices=["train", "inference"])
35 | parser.add_argument("--task_type",default="4-cls",type=str,choices=['4-cls', 'Arousal', 'Valence', 'ignore'])
36 | parser.add_argument("--gid", default= 0, type=int)
37 | parser.add_argument("--data_parallel", default= 0, type=int)
38 |
39 | parser.add_argument("--exp_name", default='output' , type=str)
40 | parser.add_argument("--load_ckt", default="none", type=str) #pre-train model
41 | parser.add_argument("--load_ckt_loss", default="25", type=str) #pre-train model
42 | parser.add_argument("--path_train_data", default='emopia', type=str)
43 | parser.add_argument("--data_root", default='../dataset/co-representation/', type=str)
44 | parser.add_argument("--load_dict", default="more_dictionary.pkl", type=str)
45 | parser.add_argument("--init_lr", default= 0.00001, type=float)
46 | # inference config
47 |
48 | parser.add_argument("--num_songs", default=5, type=int)
49 | parser.add_argument("--emo_tag", default=1, type=int)
50 | parser.add_argument("--out_dir", default='none', type=str)
51 | args = parser.parse_args()
52 |
53 | print('=== args ===')
54 | for arg in args.__dict__:
55 | print(arg, args.__dict__[arg])
56 | print('=== args ===')
57 | # time.sleep(10) #sleep to check again if args are right
58 |
59 |
60 | MODE = args.mode
61 | task_type = args.task_type
62 |
63 |
64 | ###--- data ---###
65 | path_data_root = args.data_root
66 |
67 | path_train_data = os.path.join(path_data_root, args.path_train_data + '_data.npz')
68 | path_dictionary = os.path.join(path_data_root, args.load_dict)
69 | path_train_idx = os.path.join(path_data_root, args.path_train_data + '_fn2idx_map.json')
70 | path_train_data_cls_idx = os.path.join(path_data_root, args.path_train_data + '_data_idx.npz')
71 |
72 | assert os.path.exists(path_train_data)
73 | assert os.path.exists(path_dictionary)
74 | assert os.path.exists(path_train_idx)
75 |
76 | # if the dataset has the emotion label, get the cls_idx for the dataloader
77 | if args.path_train_data == 'emopia':
78 | assert os.path.exists(path_train_data_cls_idx)
79 |
80 | ###--- training config ---###
81 |
82 | if MODE == 'train':
83 | path_exp = 'exp/' + args.exp_name
84 |
85 | if args.data_parallel > 0:
86 | batch_size = 8
87 | else:
88 | batch_size = 4 #4
89 |
90 | gid = args.gid
91 | init_lr = args.init_lr #0.0001
92 |
93 |
94 | ###--- fine-tuning & inference config ---###
95 | if args.load_ckt == 'none':
96 | info_load_model = None
97 | print('NO pre-trained model used')
98 |
99 | else:
100 | info_load_model = (
101 | # path to ckpt for loading
102 | 'exp/' + args.load_ckt,
103 | # loss
104 | args.load_ckt_loss
105 | )
106 |
107 |
108 | if args.out_dir == 'none':
109 | path_gendir = os.path.join('exp/' + args.load_ckt, 'gen_midis', 'loss_'+ args.load_ckt_loss)
110 | else:
111 | path_gendir = args.out_dir
112 |
113 | num_songs = args.num_songs
114 | emotion_tag = args.emo_tag
115 |
116 |
117 | ################################################################################
118 | # File IO
119 | ################################################################################
120 |
121 | if args.data_parallel == 0:
122 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gid)
123 |
124 |
125 | ##########################################################################################################################
126 | # Script
127 | ##########################################################################################################################
128 |
129 |
130 | class PEmoDataset(Dataset):
131 |     def __init__(self,
132 |                  task_type):
133 | 
134 | 
135 | self.train_data = np.load(path_train_data)
136 | self.train_x = self.train_data['x']
137 | self.train_y = self.train_data['y']
138 | self.train_mask = self.train_data['mask']
139 |
140 | if task_type != 'ignore':
141 |
142 | self.cls_idx = np.load(path_train_data_cls_idx)
143 | self.cls_1_idx = self.cls_idx['cls_1_idx']
144 | self.cls_2_idx = self.cls_idx['cls_2_idx']
145 | self.cls_3_idx = self.cls_idx['cls_3_idx']
146 | self.cls_4_idx = self.cls_idx['cls_4_idx']
147 |
148 | if task_type == 'Arousal':
149 | print('preparing data for training "Arousal"')
150 | self.label_transfer('Arousal')
151 |
152 | elif task_type == 'Valence':
153 | print('preparing data for training "Valence"')
154 | self.label_transfer('Valence')
155 |
156 |
157 | self.train_x = torch.from_numpy(self.train_x).long()
158 | self.train_y = torch.from_numpy(self.train_y).long()
159 | self.train_mask = torch.from_numpy(self.train_mask).float()
160 |
161 |
162 | self.seq_len = self.train_x.shape[1]
163 | self.dim = self.train_x.shape[2]
164 |
165 | print('train_x: ', self.train_x.shape)
166 |
167 | def label_transfer(self, TYPE):
168 | if TYPE == 'Arousal':
169 | for i in range(self.train_x.shape[0]):
170 | if self.train_x[i][0][-1] in [1,2]:
171 | self.train_x[i][0][-1] = 1
172 | elif self.train_x[i][0][-1] in [3,4]:
173 | self.train_x[i][0][-1] = 2
174 |
175 | elif TYPE == 'Valence':
176 | for i in range(self.train_x.shape[0]):
177 | if self.train_x[i][0][-1] in [1,4]:
178 | self.train_x[i][0][-1] = 1
179 | elif self.train_x[i][0][-1] in [2,3]:
180 | self.train_x[i][0][-1] = 2
181 |
182 |
183 |
184 | def __getitem__(self, index):
185 | return self.train_x[index], self.train_y[index], self.train_mask[index]
186 |
187 |
188 | def __len__(self):
189 | return len(self.train_x)
190 |
191 |
192 | def prep_dataloader(task_type, batch_size, n_jobs=0):
193 |
194 | dataset = PEmoDataset(task_type)
195 |
196 | dataloader = DataLoader(
197 | dataset, batch_size,
198 | shuffle=False, drop_last=False,
199 | num_workers=n_jobs, pin_memory=True)
200 | return dataloader
201 |
202 |
203 |
204 |
205 | def train():
206 |
207 | myseed = 42069
208 | np.random.seed(myseed)
209 | torch.manual_seed(myseed)
210 | if torch.cuda.is_available():
211 | torch.cuda.manual_seed_all(myseed)
212 |
213 |
214 | # hyper params
215 | n_epoch = 4000
216 | max_grad_norm = 3
217 |
218 | # load
219 | dictionary = pickle.load(open(path_dictionary, 'rb'))
220 | event2word, word2event = dictionary
221 |
222 | train_loader = prep_dataloader(args.task_type, batch_size)
223 |
224 | # create saver
225 | saver_agent = saver.Saver(path_exp)
226 |
227 | # config
228 | n_class = [] # number of classes of each token. [56, 127, 18, 4, 85, 18, 41, 5] with key: [... , 25]
229 | for key in event2word.keys():
230 | n_class.append(len(dictionary[0][key]))
231 |
232 |
233 |
234 | n_token = len(n_class)
235 | # log
236 | print('num of classes:', n_class)
237 |
238 | # init
239 |
240 |
241 |     if args.data_parallel > 0 and torch.cuda.device_count() > 1:
242 | print("Let's use", torch.cuda.device_count(), "GPUs!")
243 | net = TransformerModel(n_class, data_parallel=True)
244 | net = nn.DataParallel(net)
245 |
246 | else:
247 | net = TransformerModel(n_class)
248 |
249 | net.cuda()
250 | net.train()
251 | n_parameters = network_paras(net)
252 | print('n_parameters: {:,}'.format(n_parameters))
253 | saver_agent.add_summary_msg(
254 | ' > params amount: {:,d}'.format(n_parameters))
255 |
256 |
257 | # load model
258 | if info_load_model:
259 | path_ckpt = info_load_model[0] # path to ckpt dir
260 | loss = info_load_model[1] # loss
261 | name = 'loss_' + str(loss)
262 | path_saved_ckpt = os.path.join(path_ckpt, name + '_params.pt')
263 | print('[*] load model from:', path_saved_ckpt)
264 |
265 | try:
266 | net.load_state_dict(torch.load(path_saved_ckpt))
267 | except:
268 | # print('WARNING!!!!! Not the whole pre-train model is loaded, only load partial')
269 | # print('WARNING!!!!! Not the whole pre-train model is loaded, only load partial')
270 | # print('WARNING!!!!! Not the whole pre-train model is loaded, only load partial')
271 | # net.load_state_dict(torch.load(path_saved_ckpt), strict=False)
272 |
273 | state_dict = torch.load(path_saved_ckpt)
274 | new_state_dict = OrderedDict()
275 | for k, v in state_dict.items():
276 |                 name = k[7:]  # strip the 'module.' prefix added by nn.DataParallel
277 | new_state_dict[name] = v
278 |
279 | net.load_state_dict(new_state_dict)
280 |
281 |
282 | # optimizers
283 | optimizer = optim.Adam(net.parameters(), lr=init_lr)
284 |
285 |
286 | # run
287 | start_time = time.time()
288 | for epoch in range(n_epoch):
289 | acc_loss = 0
290 | acc_losses = np.zeros(n_token)
291 |
292 |
293 | num_batch = len(train_loader)
294 | print(' num_batch:', num_batch)
295 |
296 | for bidx, (batch_x, batch_y, batch_mask) in enumerate(train_loader): # num_batch
297 | saver_agent.global_step_increment()
298 |
299 | batch_x = batch_x.cuda()
300 | batch_y = batch_y.cuda()
301 | batch_mask = batch_mask.cuda()
302 |
303 |
304 | losses = net(batch_x, batch_y, batch_mask)
305 |
306 | if args.data_parallel > 0:
307 | loss = 0
308 | calculated_loss = []
309 | for i in range(n_token):
310 |
311 | loss += ((losses[i][0][0] + losses[i][0][1]) / (losses[i][1][0] + losses[i][1][1]))
312 | calculated_loss.append((losses[i][0][0] + losses[i][0][1]) / (losses[i][1][0] + losses[i][1][1]))
313 | loss = loss / n_token
314 |
315 |
316 | else:
317 | loss = (losses[0] + losses[1] + losses[2] + losses[3] + losses[4] + losses[5] + losses[6] + losses[7]) / 8
318 |
319 |
320 | # Update
321 | net.zero_grad()
322 | loss.backward()
323 |
324 |
325 | if max_grad_norm is not None:
326 | clip_grad_norm_(net.parameters(), max_grad_norm)
327 | optimizer.step()
328 |
329 | if args.data_parallel > 0:
330 |
331 | sys.stdout.write('{}/{} | Loss: {:06f} | {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}\r'.format(
332 | bidx, num_batch, loss, calculated_loss[0], calculated_loss[1], calculated_loss[2], calculated_loss[3], calculated_loss[4], calculated_loss[5], calculated_loss[6], calculated_loss[7]))
333 | sys.stdout.flush()
334 |
335 |
336 | # acc
337 | acc_losses += np.array([l.item() for l in calculated_loss])
338 |
339 |
340 |
341 | else:
342 |
343 | sys.stdout.write('{}/{} | Loss: {:06f} | {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}\r'.format(
344 | bidx, num_batch, loss, losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], losses[6], losses[7]))
345 | sys.stdout.flush()
346 |
347 |
348 | # acc
349 | acc_losses += np.array([l.item() for l in losses])
350 |
351 |
352 |
353 | acc_loss += loss.item()
354 |
355 | # log
356 | saver_agent.add_summary('batch loss', loss.item())
357 |
358 |
359 | # epoch loss
360 | runtime = time.time() - start_time
361 | epoch_loss = acc_loss / num_batch
362 | acc_losses = acc_losses / num_batch
363 | print('------------------------------------')
364 | print('epoch: {}/{} | Loss: {} | time: {}'.format(
365 | epoch, n_epoch, epoch_loss, str(datetime.timedelta(seconds=runtime))))
366 |
367 | each_loss_str = '{:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}\r'.format(
368 | acc_losses[0], acc_losses[1], acc_losses[2], acc_losses[3], acc_losses[4], acc_losses[5], acc_losses[6], acc_losses[7])
369 |
370 | print(' >', each_loss_str)
371 |
372 | saver_agent.add_summary('epoch loss', epoch_loss)
373 | saver_agent.add_summary('epoch each loss', each_loss_str)
374 |
375 | # save model, with policy
376 | loss = epoch_loss
377 | if 0.4 < loss <= 0.8:
378 | fn = int(loss * 10) * 10
379 | saver_agent.save_model(net, name='loss_' + str(fn))
380 | elif 0.08 < loss <= 0.40:
381 | fn = int(loss * 100)
382 | saver_agent.save_model(net, name='loss_' + str(fn))
383 | elif loss <= 0.08:
384 | print('Finished')
385 | return
386 | else:
387 | saver_agent.save_model(net, name='loss_high')
388 |
389 |
390 | def generate():
391 |
392 | # path
393 | path_ckpt = info_load_model[0] # path to ckpt dir
394 | loss = info_load_model[1] # loss
395 | name = 'loss_' + str(loss)
396 | path_saved_ckpt = os.path.join(path_ckpt, name + '_params.pt')
397 |
398 | # load
399 | dictionary = pickle.load(open(path_dictionary, 'rb'))
400 | event2word, word2event = dictionary
401 |
402 | # outdir
403 | os.makedirs(path_gendir, exist_ok=True)
404 |
405 | # config
406 | n_class = [] # num of classes for each token
407 | for key in event2word.keys():
408 | n_class.append(len(dictionary[0][key]))
409 |
410 |
411 | n_token = len(n_class)
412 |
413 | # init model
414 | net = TransformerModel(n_class, is_training=False)
415 | net.cuda()
416 | net.eval()
417 |
418 | # load model
419 | print('[*] load model from:', path_saved_ckpt)
420 |
421 |
422 | try:
423 | net.load_state_dict(torch.load(path_saved_ckpt))
424 | except:
425 | state_dict = torch.load(path_saved_ckpt)
426 | new_state_dict = OrderedDict()
427 | for k, v in state_dict.items():
428 | name = k[7:]
429 | new_state_dict[name] = v
430 |
431 | net.load_state_dict(new_state_dict)
432 |
433 |
434 | # gen
435 | start_time = time.time()
436 | song_time_list = []
437 | words_len_list = []
438 |
439 | cnt_tokens_all = 0
440 | sidx = 0
441 | while sidx < num_songs:
442 | # try:
443 | start_time = time.time()
444 | print('current idx:', sidx)
445 |
446 | if n_token == 8:
447 | path_outfile = os.path.join(path_gendir, 'emo_{}_{}'.format( str(emotion_tag), get_random_string(10)))
448 |             res, _ = net.inference_from_scratch(dictionary, emotion_tag, n_token=n_token)
449 |
450 |
451 | if res is None:
452 | continue
453 | np.save(path_outfile + '.npy', res)
454 | write_midi(res, path_outfile + '.mid', word2event)
455 |
456 | song_time = time.time() - start_time
457 | word_len = len(res)
458 | print('song time:', song_time)
459 | print('word_len:', word_len)
460 | words_len_list.append(word_len)
461 | song_time_list.append(song_time)
462 |
463 | sidx += 1
464 |
465 |
466 | print('ave token time:', sum(words_len_list) / sum(song_time_list))
467 | print('ave song time:', np.mean(song_time_list))
468 |
469 | runtime_result = {
470 | 'song_time':song_time_list,
471 | 'words_len_list': words_len_list,
472 | 'ave token time:': sum(words_len_list) / sum(song_time_list),
473 | 'ave song time': float(np.mean(song_time_list)),
474 | }
475 |
476 | with open('runtime_stats.json', 'w') as f:
477 | json.dump(runtime_result, f)
478 |
479 |
480 |
481 |
482 |
483 | if __name__ == '__main__':
484 | # -- training -- #
485 | if MODE == 'train':
486 | train()
487 | # -- inference -- #
488 | elif MODE == 'inference':
489 | generate()
490 | else:
491 | pass
492 |
--------------------------------------------------------------------------------
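
Both train() and generate() above fall back to renaming checkpoint keys with k[7:] when a checkpoint saved from an nn.DataParallel-wrapped model is loaded into a plain module. As a small sketch of that fallback (not part of the repository), the helper below performs the same rename on any state dict; the nn.Linear is just a stand-in model.

    from collections import OrderedDict

    import torch.nn as nn

    def strip_module_prefix(state_dict):
        """Remove the 'module.' prefix that nn.DataParallel prepends to parameter names."""
        return OrderedDict(
            (k[7:] if k.startswith("module.") else k, v) for k, v in state_dict.items()
        )

    # Simulate keys as they appear after saving a DataParallel-wrapped model ...
    ckpt = OrderedDict(("module." + k, v) for k, v in nn.Linear(4, 2).state_dict().items())

    # ... and load them back into a plain module after stripping the prefix.
    plain = nn.Linear(4, 2)
    plain.load_state_dict(strip_module_prefix(ckpt))
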
/workspace/transformer/models.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from utils import sampling
8 |
9 | # import fast_transformers_local
10 | from fast_transformers.builders import TransformerEncoderBuilder as TransformerEncoderBuilder_local
11 | from fast_transformers.builders import RecurrentEncoderBuilder as RecurrentEncoderBuilder_local
12 | from fast_transformers.masking import TriangularCausalMask as TriangularCausalMask_local
13 |
14 |
15 | D_MODEL = 512
16 | N_LAYER = 12
17 | N_HEAD = 8
18 |
19 | ################################################################################
20 | # Model
21 | ################################################################################
22 |
23 | def network_paras(model):
24 | # compute only trainable params
25 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
26 | params = sum([np.prod(p.size()) for p in model_parameters])
27 | return params
28 |
29 |
30 |
31 | class Embeddings(nn.Module):
32 | def __init__(self, n_token, d_model):
33 | super(Embeddings, self).__init__()
34 | self.lut = nn.Embedding(n_token, d_model)
35 | self.d_model = d_model
36 |
37 | def forward(self, x):
38 | return self.lut(x) * math.sqrt(self.d_model)
39 |
40 |
41 | class PositionalEncoding(nn.Module):
42 | def __init__(self, d_model, dropout=0.1, max_len=20000):
43 | super(PositionalEncoding, self).__init__()
44 | self.dropout = nn.Dropout(p=dropout)
45 |
46 | pe = torch.zeros(max_len, d_model)
47 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
48 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
49 | pe[:, 0::2] = torch.sin(position * div_term)
50 | pe[:, 1::2] = torch.cos(position * div_term)
51 | pe = pe.unsqueeze(0)
52 | self.register_buffer('pe', pe)
53 |
54 | def forward(self, x):
55 | x = x + self.pe[:, :x.size(1), :]
56 | return self.dropout(x)
57 |
58 |
59 |
60 | class TransformerModel(nn.Module):
61 | def __init__(self, n_token, is_training=True, data_parallel=False):
62 | super(TransformerModel, self).__init__()
63 | self.data_parallel = data_parallel
64 | # --- params config --- #
65 | self.n_token = n_token # == n_class
66 | self.d_model = D_MODEL
67 | self.n_layer = N_LAYER #
68 | self.dropout = 0.1
69 | self.n_head = N_HEAD #
70 | self.d_head = D_MODEL // N_HEAD
71 | self.d_inner = 2048
72 | self.loss_func = nn.CrossEntropyLoss(reduction='none')
73 | if len(self.n_token) == 8:
74 | self.emb_sizes = [128, 256, 64, 32, 512, 128, 128, 128]
75 | elif len(self.n_token) == 9:
76 | self.emb_sizes = [128, 256, 64, 32, 512, 128, 128, 128, 128] #128
77 |
78 | # --- modules config --- #
79 | # embeddings
80 | print('>>>>>:', self.n_token)
81 | self.word_emb_tempo = Embeddings(self.n_token[0], self.emb_sizes[0])
82 | self.word_emb_chord = Embeddings(self.n_token[1], self.emb_sizes[1])
83 | self.word_emb_barbeat = Embeddings(self.n_token[2], self.emb_sizes[2])
84 | self.word_emb_type = Embeddings(self.n_token[3], self.emb_sizes[3])
85 | self.word_emb_pitch = Embeddings(self.n_token[4], self.emb_sizes[4])
86 | self.word_emb_duration = Embeddings(self.n_token[5], self.emb_sizes[5])
87 | self.word_emb_velocity = Embeddings(self.n_token[6], self.emb_sizes[6])
88 | self.word_emb_emotion = Embeddings(self.n_token[7], self.emb_sizes[7])
89 | if len(self.n_token) == 9:
90 | self.word_emb_key = Embeddings(self.n_token[8], self.emb_sizes[8])
91 | self.pos_emb = PositionalEncoding(self.d_model, self.dropout)
92 |
93 |
94 | # linear
95 | self.in_linear = nn.Linear(np.sum(self.emb_sizes), self.d_model)
96 |
97 | # encoder
98 | if is_training:
99 | # encoder (training)
100 | self.get_encoder('encoder')
101 |
102 |
103 | else:
104 | # encoder (inference)
105 | print(' [o] using RNN backend.')
106 | self.get_encoder('autoregred')
107 |
108 | # blend with type
109 | self.project_concat_type = nn.Linear(self.d_model + 32, self.d_model)
110 |
111 | # individual output
112 | self.proj_tempo = nn.Linear(self.d_model, self.n_token[0])
113 | self.proj_chord = nn.Linear(self.d_model, self.n_token[1])
114 | self.proj_barbeat = nn.Linear(self.d_model, self.n_token[2])
115 | self.proj_type = nn.Linear(self.d_model, self.n_token[3])
116 | self.proj_pitch = nn.Linear(self.d_model, self.n_token[4])
117 | self.proj_duration = nn.Linear(self.d_model, self.n_token[5])
118 | self.proj_velocity = nn.Linear(self.d_model, self.n_token[6])
119 | self.proj_emotion = nn.Linear(self.d_model, self.n_token[7])
120 | if len(self.n_token) == 9:
121 | self.proj_key = nn.Linear( self.d_model, self.n_token[8])
122 |
123 |
124 | def compute_loss(self, predict, target, loss_mask):
125 | if self.data_parallel:
126 | loss = self.loss_func(predict, target)
127 | loss = loss * loss_mask
128 | return torch.sum(loss), torch.sum(loss_mask)
129 | else:
130 | loss = self.loss_func(predict, target)
131 | loss = loss * loss_mask
132 | loss = torch.sum(loss) / torch.sum(loss_mask)
133 | return loss
134 |
135 |
136 |
137 | def forward(self, x, target, loss_mask):
138 |
139 |
140 | h, y_type = self.forward_hidden(x, is_training=True)
141 |
142 |
143 | if len(self.n_token) == 9:
144 | y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity, y_emotion, y_key, emo_embd = self.forward_output(h, target)
145 | else:
146 | y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity, y_emotion, emo_embd = self.forward_output(h, target)
147 |
148 |
149 |
150 | # reshape (b, s, f) -> (b, f, s)
151 | y_tempo = y_tempo[:, ...].permute(0, 2, 1)
152 | y_chord = y_chord[:, ...].permute(0, 2, 1)
153 | y_barbeat = y_barbeat[:, ...].permute(0, 2, 1)
154 | y_type = y_type[:, ...].permute(0, 2, 1)
155 | y_pitch = y_pitch[:, ...].permute(0, 2, 1)
156 | y_duration = y_duration[:, ...].permute(0, 2, 1)
157 | y_velocity = y_velocity[:, ...].permute(0, 2, 1)
158 | y_emotion = y_emotion[:, ...].permute(0, 2, 1)
159 | if len(self.n_token) == 9:
160 | y_key = y_key[:, ...].permute(0, 2, 1)
161 |
162 | # loss
163 | loss_tempo = self.compute_loss(
164 | y_tempo, target[..., 0], loss_mask)
165 | loss_chord = self.compute_loss(
166 | y_chord, target[..., 1], loss_mask)
167 | loss_barbeat = self.compute_loss(
168 | y_barbeat, target[..., 2], loss_mask)
169 | loss_type = self.compute_loss(
170 | y_type, target[..., 3], loss_mask)
171 | loss_pitch = self.compute_loss(
172 | y_pitch, target[..., 4], loss_mask)
173 | loss_duration = self.compute_loss(
174 | y_duration, target[..., 5], loss_mask)
175 | loss_velocity = self.compute_loss(
176 | y_velocity, target[..., 6], loss_mask)
177 | loss_emotion = self.compute_loss(
178 | y_emotion, target[..., 7], loss_mask)
179 |
180 |
181 | if len(self.n_token) == 9:
182 | loss_key = self.compute_loss(
183 | y_key, target[..., 8], loss_mask)
184 |
185 | return loss_tempo, loss_chord, loss_barbeat, loss_type, loss_pitch, loss_duration, loss_velocity, loss_emotion, loss_key
186 |
187 | else:
188 | return loss_tempo, loss_chord, loss_barbeat, loss_type, loss_pitch, loss_duration, loss_velocity, loss_emotion
189 |
190 |
191 |
192 | def get_encoder(self, TYPE):
193 | if TYPE == 'encoder':
194 | self.transformer_encoder = TransformerEncoderBuilder_local.from_kwargs(
195 | n_layers=self.n_layer,
196 | n_heads=self.n_head,
197 | query_dimensions=self.d_model//self.n_head,
198 | value_dimensions=self.d_model//self.n_head,
199 | feed_forward_dimensions=2048,
200 | activation='gelu',
201 | dropout=0.1,
202 | attention_type="causal-linear",
203 | ).get()
204 |
205 |
206 | elif TYPE == 'autoregred':
207 | self.transformer_encoder = RecurrentEncoderBuilder_local.from_kwargs(
208 | n_layers=self.n_layer,
209 | n_heads=self.n_head,
210 | query_dimensions=self.d_model//self.n_head,
211 | value_dimensions=self.d_model//self.n_head,
212 | feed_forward_dimensions=2048,
213 | activation='gelu',
214 | dropout=0.1,
215 | attention_type="causal-linear",
216 | ).get()
217 |
218 |
219 |
220 | def forward_hidden(self, x, memory=None, is_training=False):
221 | '''
222 | linear transformer: b x s x f
223 | x.shape=(bs, nf)
224 | '''
225 |
226 | # embeddings
227 | emb_tempo = self.word_emb_tempo(x[..., 0])
228 | emb_chord = self.word_emb_chord(x[..., 1])
229 | emb_barbeat = self.word_emb_barbeat(x[..., 2])
230 | emb_type = self.word_emb_type(x[..., 3])
231 | emb_pitch = self.word_emb_pitch(x[..., 4])
232 | emb_duration = self.word_emb_duration(x[..., 5])
233 | emb_velocity = self.word_emb_velocity(x[..., 6])
234 |
235 | emb_emotion = self.word_emb_emotion(x[..., 7])
236 |
237 | if len(self.n_token) == 9:
238 | emb_key = self.word_emb_key(x[..., 8])
239 |
240 | # same emotion class have same emb_emotion
241 |
242 | embs = torch.cat(
243 | [
244 | emb_tempo,
245 | emb_chord,
246 | emb_barbeat,
247 | emb_type,
248 | emb_pitch,
249 | emb_duration,
250 | emb_velocity,
251 | emb_emotion,
252 | emb_key
253 | ], dim=-1)
254 |
255 | else:
256 | embs = torch.cat(
257 | [
258 | emb_tempo,
259 | emb_chord,
260 | emb_barbeat,
261 | emb_type,
262 | emb_pitch,
263 | emb_duration,
264 | emb_velocity,
265 | emb_emotion
266 |
267 | ], dim=-1)
268 |
269 |
270 | emb_linear = self.in_linear(embs)
271 | pos_emb = self.pos_emb(emb_linear)
272 |
273 |
274 | # assert False
275 | layer_outputs = []
276 | # transformer
277 | if is_training:
278 | # mask
279 | attn_mask = TriangularCausalMask_local(pos_emb.size(1), device=x.device)
280 |
281 | h = self.transformer_encoder(pos_emb, attn_mask) # y: b x s x d_model
282 |
283 |
284 | # project type
285 | y_type = self.proj_type(h)
286 |
287 |
288 | return h, y_type
289 |
290 | else:
291 | pos_emb = pos_emb.squeeze(0)
292 |
293 | # self.get_encoder('autoregred')
294 | # self.transformer_encoder.cuda()
295 | h, memory = self.transformer_encoder(pos_emb, memory=memory) # y: s x d_model
296 |
297 | # project type
298 | y_type = self.proj_type(h)
299 |
300 | return h, y_type, memory
301 |
302 |
303 | def forward_output(self, h, y):
304 | '''
305 | for training
306 | '''
307 | # tf_skip_emption = self.word_emb_emotion(y[..., 7])
308 | tf_skip_type = self.word_emb_type(y[..., 3])
309 |
310 | emo_embd = h[:, 0]
311 |
312 | # project other
313 | y_concat_type = torch.cat([h, tf_skip_type], dim=-1)
314 | y_ = self.project_concat_type(y_concat_type)
315 |
316 | y_tempo = self.proj_tempo(y_)
317 | y_chord = self.proj_chord(y_)
318 | y_barbeat = self.proj_barbeat(y_)
319 | y_pitch = self.proj_pitch(y_)
320 | y_duration = self.proj_duration(y_)
321 | y_velocity = self.proj_velocity(y_)
322 | y_emotion = self.proj_emotion(y_)
323 |
324 | if len(self.n_token) == 9:
325 | y_key = self.proj_key(y_)
326 |
327 | return y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity, y_emotion, y_key, emo_embd
328 |
329 | else:
330 | return y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity, y_emotion, emo_embd
331 |
332 |
333 |
334 |
335 | def froward_output_sampling(self, h, y_type, is_training=False):
336 | '''
337 | for inference
338 | '''
339 |
340 | # sample type
341 | y_type_logit = y_type[0, :] # token class size
342 | cur_word_type = sampling(y_type_logit, p=0.90, is_training=is_training) # int
343 | if cur_word_type is None:
344 | return None, None
345 |
346 | if is_training:
347 | type_word_t = cur_word_type.long().unsqueeze(0).unsqueeze(0)
348 | else:
349 | type_word_t = torch.from_numpy(
350 | np.array([cur_word_type])).long().cuda().unsqueeze(0) # shape = (1,1)
351 |
352 | tf_skip_type = self.word_emb_type(type_word_t).squeeze(0) # shape = (1, embd_size)
353 |
354 |
355 | # concat
356 | y_concat_type = torch.cat([h, tf_skip_type], dim=-1)
357 | y_ = self.project_concat_type(y_concat_type)
358 |
359 | # project other
360 | y_tempo = self.proj_tempo(y_)
361 | y_chord = self.proj_chord(y_)
362 | y_barbeat = self.proj_barbeat(y_)
363 |
364 | y_pitch = self.proj_pitch(y_)
365 | y_duration = self.proj_duration(y_)
366 | y_velocity = self.proj_velocity(y_)
367 | y_emotion = self.proj_emotion(y_)
368 |
369 |
370 |
371 | # sampling gen_cond
372 | cur_word_tempo = sampling(y_tempo, t=1.2, p=0.9, is_training=is_training)
373 | cur_word_barbeat = sampling(y_barbeat, t=1.2, is_training=is_training)
374 | cur_word_chord = sampling(y_chord, p=0.99, is_training=is_training)
375 | cur_word_pitch = sampling(y_pitch, p=0.9, is_training=is_training)
376 | cur_word_duration = sampling(y_duration, t=2, p=0.9, is_training=is_training)
377 | cur_word_velocity = sampling(y_velocity, t=5, is_training=is_training)
378 |
379 | if len(self.n_token) == 9:
380 | y_key = self.proj_key(y_)
381 | cur_word_key = sampling(y_key, t=1.2, is_training=is_training)
382 |
383 | curs = [
384 | cur_word_tempo,
385 | cur_word_chord,
386 | cur_word_barbeat,
387 | cur_word_pitch,
388 | cur_word_duration,
389 | cur_word_velocity,
390 | cur_word_key
391 | ]
392 |
393 | else:
394 | curs = [
395 | cur_word_tempo,
396 | cur_word_chord,
397 | cur_word_barbeat,
398 | cur_word_pitch,
399 | cur_word_duration,
400 | cur_word_velocity
401 | ]
402 |
403 | if None in curs:
404 | return None, None
405 |
406 |
407 |
408 | if is_training:
409 | cur_word_emotion = torch.from_numpy(np.array([0])).long().cuda().squeeze(0)
410 | # collect
411 | next_arr = torch.tensor([
412 | cur_word_tempo,
413 | cur_word_chord,
414 | cur_word_barbeat,
415 | cur_word_type,
416 | cur_word_pitch,
417 | cur_word_duration,
418 | cur_word_velocity,
419 | cur_word_emotion
420 | ])
421 |
422 | else:
423 | cur_word_emotion = 0
424 |
425 |
426 | # collect
427 | if len(self.n_token) == 9:
428 | next_arr = np.array([
429 | cur_word_tempo,
430 | cur_word_chord,
431 | cur_word_barbeat,
432 | cur_word_type,
433 | cur_word_pitch,
434 | cur_word_duration,
435 | cur_word_velocity,
436 | cur_word_emotion,
437 | cur_word_key
438 | ])
439 | else:
440 | next_arr = np.array([
441 | cur_word_tempo,
442 | cur_word_chord,
443 | cur_word_barbeat,
444 | cur_word_type,
445 | cur_word_pitch,
446 | cur_word_duration,
447 | cur_word_velocity,
448 | cur_word_emotion
449 | ])
450 |
451 | return next_arr, y_emotion
452 |
453 |
454 |
455 |
456 |
457 | def inference_from_scratch(self, dictionary, emotion_tag, key_tag=None, n_token=8, display=True):
458 | event2word, word2event = dictionary
459 |
460 |
461 | classes = word2event.keys()
462 |
463 |
464 | def print_word_cp(cp):
465 |
466 | result = [word2event[k][cp[idx]] for idx, k in enumerate(classes)]
467 |
468 | for r in result:
469 | print('{:15s}'.format(str(r)), end=' | ')
470 | print('')
471 |
472 | generated_key = None
473 |
474 |
475 |
476 | if n_token == 9:
477 |
478 | if key_tag:
479 |
480 | target_emotion = [0, 0, 0, 1, 0, 0, 0, emotion_tag, 0]
481 | target_key = [0, 0, 0, 4, 0, 0, 0, 0, key_tag]
482 |
483 | init = np.array([
484 | target_emotion, # emotion
485 | target_key,
486 | [0, 0, 1, 2, 0, 0, 0, 0, 0] # bar
487 | ])
488 |
489 | else:
490 | target_emotion = [0, 0, 0, 1, 0, 0, 0, emotion_tag, 0]
491 | init = np.array([
492 | target_emotion, # emotion
493 | [0, 0, 1, 2, 0, 0, 0, 0, 0] # bar
494 | ])
495 |
496 | elif n_token == 8:
497 | target_emotion = [0, 0, 0, 1, 0, 0, 0, emotion_tag]
498 |
499 | init = np.array([
500 | target_emotion, # emotion
501 | [0, 0, 1, 2, 0, 0, 0, 0] # bar
502 | ])
503 |
504 |
505 | cnt_token = len(init)
506 | with torch.no_grad():
507 | final_res = []
508 | memory = None
509 | h = None
510 |
511 | cnt_bar = 1
512 | init_t = torch.from_numpy(init).long().cuda()
513 | print('------ initiate ------')
514 |
515 | if n_token == 9 and key_tag is None:
516 | # Emotion token
517 | step = 0
518 | if display:
519 | print_word_cp(init[step, :])
520 | input_ = init_t[step, :].unsqueeze(0).unsqueeze(0)
521 | final_res.append(init[step, :][None, ...])
522 | h, y_type, memory = self.forward_hidden(
523 | input_, memory, is_training=False)
524 |
525 | #generate KEY
526 | next_arr, y_emotion = self.froward_output_sampling(h, y_type)
527 | if next_arr is None:
528 | return None, None
529 |
530 | generated_key = next_arr[-1]
531 | final_res.append(next_arr[None, ...])
532 | if display:
533 | print_word_cp(next_arr)
534 | input_ = torch.from_numpy(next_arr).long().cuda()
535 | input_ = input_.unsqueeze(0).unsqueeze(0)
536 | h, y_type, memory = self.forward_hidden(
537 | input_, memory, is_training=False)
538 |
539 | # init bar
540 | step = 1
541 | print_word_cp(init[step, :])
542 | input_ = init_t[step, :].unsqueeze(0).unsqueeze(0)
543 | final_res.append(init[step, :][None, ...])
544 | h, y_type, memory = self.forward_hidden(
545 | input_, memory, is_training=False)
546 |
547 |
548 |
549 | else:
550 | for step in range(init.shape[0]):
551 |
552 | print_word_cp(init[step, :])
553 | input_ = init_t[step, :].unsqueeze(0).unsqueeze(0)
554 | final_res.append(init[step, :][None, ...])
555 |
556 | h, y_type, memory = self.forward_hidden(
557 | input_, memory, is_training=False)
558 |
559 |
560 |
561 |
562 | print('------ generate ------')
563 | while(True):
564 | # sample others
565 | next_arr, y_emotion = self.froward_output_sampling(h, y_type)
566 | if next_arr is None:
567 | return None, None
568 |
569 | final_res.append(next_arr[None, ...])
570 |
571 | if display:
572 | print('bar:', cnt_bar, end= ' ==')
573 | print_word_cp(next_arr)
574 |
575 | # forward
576 | input_ = torch.from_numpy(next_arr).long().cuda()
577 | input_ = input_.unsqueeze(0).unsqueeze(0)
578 | h, y_type, memory = self.forward_hidden(
579 | input_, memory, is_training=False)
580 |
581 | # end of sequence
582 | if word2event['type'][next_arr[3]] == 'EOS':
583 | break
584 |
585 | if word2event['bar-beat'][next_arr[2]] == 'Bar':
586 | cnt_bar += 1
587 |
588 | print('\n--------[Done]--------')
589 | final_res = np.concatenate(final_res)
590 | print(final_res.shape)
591 |
592 |
593 | return final_res, generated_key
594 |
595 |
--------------------------------------------------------------------------------
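
A compact sketch of the compound-word embedding fusion in TransformerModel.forward_hidden: each token stream is embedded separately, the embeddings are concatenated, and in_linear projects the concatenation back to d_model before positional encoding. The class counts below are the 8-token values quoted in the main_cp.py comment; importing from models assumes fast_transformers is installed.

    import torch
    import torch.nn as nn

    from models import Embeddings, PositionalEncoding  # defined above

    n_class   = [56, 127, 18, 4, 85, 18, 41, 5]          # tempo, chord, bar-beat, type, pitch, duration, velocity, emotion
    emb_sizes = [128, 256, 64, 32, 512, 128, 128, 128]   # as in TransformerModel for 8 tokens
    d_model   = 512

    embs   = nn.ModuleList(Embeddings(n, d) for n, d in zip(n_class, emb_sizes))
    fuse   = nn.Linear(sum(emb_sizes), d_model)          # 1376 -> 512, like in_linear
    posenc = PositionalEncoding(d_model, dropout=0.1)

    x = torch.stack([torch.randint(0, n, (2, 16)) for n in n_class], dim=-1)  # (batch=2, seq=16, 8)
    h = torch.cat([emb(x[..., i]) for i, emb in enumerate(embs)], dim=-1)     # (2, 16, 1376)
    h = posenc(fuse(h))                                                       # (2, 16, 512)
    print(h.shape)
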
/workspace/transformer/saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import logging
5 | import datetime
6 | import collections
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | class Saver(object):
12 | def __init__(
13 | self,
14 | exp_dir,
15 | mode='w'):
16 |
17 | self.exp_dir = exp_dir
18 | self.init_time = time.time()
19 | self.global_step = 0
20 |
21 | # makedirs
22 | os.makedirs(exp_dir, exist_ok=True)
23 |
24 | # logging config
25 | path_logger = os.path.join(exp_dir, 'log.txt')
26 | logging.basicConfig(
27 | level=logging.DEBUG,
28 | format='%(message)s',
29 | filename=path_logger,
30 | filemode=mode)
31 | self.logger = logging.getLogger('training monitor')
32 |
33 | def add_summary_msg(self, msg):
34 | self.logger.debug(msg)
35 |
36 | def add_summary(
37 | self,
38 | key,
39 | val,
40 | step=None,
41 | cur_time=None):
42 |
43 | if cur_time is None:
44 | cur_time = time.time() - self.init_time
45 | if step is None:
46 | step = self.global_step
47 |
48 | # write msg (key, val, step, time)
49 | if isinstance(val, float):
50 | msg_str = '{:10s} | {:.10f} | {:10d} | {}'.format(
51 | key,
52 | val,
53 | step,
54 | cur_time
55 | )
56 | else:
57 | msg_str = '{:10s} | {} | {:10d} | {}'.format(
58 | key,
59 | val,
60 | step,
61 | cur_time
62 | )
63 |
64 | self.logger.debug(msg_str)
65 |
66 | def save_model(
67 | self,
68 | model,
69 | optimizer=None,
70 | outdir=None,
71 | name='model'):
72 |
73 | if outdir is None:
74 | outdir = self.exp_dir
75 | print(' [*] saving model to {}, name: {}'.format(outdir, name))
76 | torch.save(model, os.path.join(outdir, name+'.pt'))
77 | torch.save(model.state_dict(), os.path.join(outdir, name+'_params.pt'))
78 |
79 | if optimizer is not None:
80 | torch.save(optimizer.state_dict(), os.path.join(outdir, name+'_opt.pt'))
81 |
82 | def load_model(
83 | self,
84 | path_exp,
85 | device='cpu',
86 | name='model.pt'):
87 |
88 | path_pt = os.path.join(path_exp, name)
89 | print(' [*] restoring model from', path_pt)
90 | model = torch.load(path_pt, map_location=torch.device(device))
91 | return model
92 |
93 | def global_step_increment(self):
94 | self.global_step += 1
95 |
96 | """
97 | file modes
98 | 'a':
99 | Opens a file for appending. The file pointer is at the end of the file if the file exists.
100 | That is, the file is in the append mode. If the file does not exist, it creates a new file for writing.
101 |
102 | 'w':
103 | Opens a file for writing only. Overwrites the file if the file exists.
104 | If the file does not exist, creates a new file for writing.
105 | """
106 |
107 | def make_loss_report(
108 | path_log,
109 | path_figure='loss.png',
110 | dpi=100):
111 |
112 | # load logfile
113 | monitor_vals = collections.defaultdict(list)
114 |     with open(path_log, 'r') as f:
115 | for line in f:
116 | try:
117 | line = line.strip()
118 | key, val, step, acc_time = line.split(' | ')
119 | monitor_vals[key].append((float(val), int(step), acc_time))
120 | except:
121 | continue
122 |
123 | # collect
124 | step_train = [item[1] for item in monitor_vals['train loss']]
125 | vals_train = [item[0] for item in monitor_vals['train loss']]
126 |
127 | step_valid = [item[1] for item in monitor_vals['valid loss']]
128 | vals_valid = [item[0] for item in monitor_vals['valid loss']]
129 |
130 | x_min = step_valid[np.argmin(vals_valid)]
131 | y_min = min(vals_valid)
132 |
133 | # plot
134 | fig = plt.figure(dpi=dpi)
135 | plt.title('training process')
136 | plt.plot(step_train, vals_train, label='train')
137 | plt.plot(step_valid, vals_valid, label='valid')
138 | plt.yscale('log')
139 | plt.plot([x_min], [y_min], 'ro')
140 | plt.legend(loc='upper right')
141 | plt.tight_layout()
142 | plt.savefig(path_figure)
143 |
144 | '''
145 | author: wayn391@mastertones
146 | '''
147 |
148 | import os
149 | import time
150 | import torch
151 | import logging
152 | import datetime
153 | import collections
154 | import numpy as np
155 | import matplotlib.pyplot as plt
156 |
157 |
158 | class Saver(object):
159 | def __init__(
160 | self,
161 | exp_dir,
162 | mode='w'):
163 |
164 | self.exp_dir = exp_dir
165 | self.init_time = time.time()
166 | self.global_step = 0
167 |
168 | # makedirs
169 | os.makedirs(exp_dir, exist_ok=True)
170 |
171 | # logging config
172 | path_logger = os.path.join(exp_dir, 'log.txt')
173 | logging.basicConfig(
174 | level=logging.DEBUG,
175 | format='%(message)s',
176 | filename=path_logger,
177 | filemode=mode)
178 | self.logger = logging.getLogger('training monitor')
179 |
180 | def add_summary_msg(self, msg):
181 | self.logger.debug(msg)
182 |
183 | def add_summary(
184 | self,
185 | key,
186 | val,
187 | step=None,
188 | cur_time=None):
189 |
190 | if cur_time is None:
191 | cur_time = time.time() - self.init_time
192 | if step is None:
193 | step = self.global_step
194 |
195 | # write msg (key, val, step, time)
196 | if isinstance(val, float):
197 | msg_str = '{:10s} | {:.10f} | {:10d} | {}'.format(
198 | key,
199 | val,
200 | step,
201 | cur_time
202 | )
203 | else:
204 | msg_str = '{:10s} | {} | {:10d} | {}'.format(
205 | key,
206 | val,
207 | step,
208 | cur_time
209 | )
210 |
211 | self.logger.debug(msg_str)
212 |
213 | def save_model(
214 | self,
215 | model,
216 | optimizer=None,
217 | outdir=None,
218 | name='model'):
219 |
220 | if outdir is None:
221 | outdir = self.exp_dir
222 | print(' [*] saving model to {}, name: {}'.format(outdir, name))
223 | # torch.save(model, os.path.join(outdir, name+'.pt'))
224 | torch.save(model.state_dict(), os.path.join(outdir, name+'_params.pt'))
225 |
226 | if optimizer is not None:
227 | torch.save(optimizer.state_dict(), os.path.join(outdir, name+'_opt.pt'))
228 |
229 | def load_model(
230 | self,
231 | path_exp,
232 | device='cpu',
233 | name='model.pt'):
234 |
235 | path_pt = os.path.join(path_exp, name)
236 | print(' [*] restoring model from', path_pt)
237 | model = torch.load(path_pt, map_location=torch.device(device))
238 | return model
239 |
240 | def global_step_increment(self):
241 | self.global_step += 1
242 |
243 | """
244 | file modes
245 | 'a':
246 | Opens a file for appending. The file pointer is at the end of the file if the file exists.
247 | That is, the file is in the append mode. If the file does not exist, it creates a new file for writing.
248 |
249 | 'w':
250 | Opens a file for writing only. Overwrites the file if the file exists.
251 | If the file does not exist, creates a new file for writing.
252 | """
253 |
254 | def make_loss_report(
255 | path_log,
256 | path_figure='loss.png',
257 | dpi=100):
258 |
259 | # load logfile
260 | monitor_vals = collections.defaultdict(list)
261 |     with open(path_log, 'r') as f:
262 | for line in f:
263 | try:
264 | line = line.strip()
265 | key, val, step, acc_time = line.split(' | ')
266 | monitor_vals[key].append((float(val), int(step), acc_time))
267 | except:
268 | continue
269 |
270 | # collect
271 | step_train = [item[1] for item in monitor_vals['train loss']]
272 | vals_train = [item[0] for item in monitor_vals['train loss']]
273 |
274 | step_valid = [item[1] for item in monitor_vals['valid loss']]
275 | vals_valid = [item[0] for item in monitor_vals['valid loss']]
276 |
277 | x_min = step_valid[np.argmin(vals_valid)]
278 | y_min = min(vals_valid)
279 |
280 | # plot
281 | fig = plt.figure(dpi=dpi)
282 | plt.title('training process')
283 | plt.plot(step_train, vals_train, label='train')
284 | plt.plot(step_valid, vals_valid, label='valid')
285 | plt.yscale('log')
286 | plt.plot([x_min], [y_min], 'ro')
287 | plt.legend(loc='upper right')
288 | plt.tight_layout()
289 | plt.savefig(path_figure)
290 |
291 |
--------------------------------------------------------------------------------
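
A minimal usage sketch for the Saver above: create one per experiment directory, bump the global step as you log scalars, and checkpoint a model. The nn.Linear is a stand-in for the real network; note that the second Saver definition in this file is the one that takes effect, so save_model only writes the *_params.pt state dict.

    import torch.nn as nn
    from saver import Saver

    saver_agent = Saver('exp/demo')          # creates exp/demo/ and exp/demo/log.txt

    model = nn.Linear(8, 2)                  # stand-in model
    for step in range(3):
        saver_agent.global_step_increment()
        saver_agent.add_summary('batch loss', 0.5 / (step + 1))

    saver_agent.save_model(model, name='loss_high')   # writes exp/demo/loss_high_params.pt
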
/workspace/transformer/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import glob
4 | import random
5 | import string
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from tqdm import tqdm
10 | import pickle
11 | import miditoolkit
12 | from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note
13 |
14 | BEAT_RESOL = 480
15 | BAR_RESOL = BEAT_RESOL * 4
16 | TICK_RESOL = BEAT_RESOL // 4
17 |
18 |
19 | def write_midi(words, path_outfile, word2event):
20 |
21 | class_keys = word2event.keys()
22 | # words = np.load(path_infile)
23 | midi_obj = miditoolkit.midi.parser.MidiFile()
24 |
25 | bar_cnt = 0
26 | cur_pos = 0
27 |
28 | all_notes = []
29 |
30 | cnt_error = 0
31 | for i in range(len(words)):
32 | vals = []
33 | for kidx, key in enumerate(class_keys):
34 | vals.append(word2event[key][words[i][kidx]])
35 | # print(vals)
36 |
37 | if vals[3] == 'Metrical':
38 | if vals[2] == 'Bar':
39 | bar_cnt += 1
40 | elif 'Beat' in vals[2]:
41 | beat_pos = int(vals[2].split('_')[1])
42 | cur_pos = bar_cnt * BAR_RESOL + beat_pos * TICK_RESOL
43 |
44 | # chord
45 | if vals[1] != 'CONTI' and vals[1] != 0:
46 | midi_obj.markers.append(
47 | Marker(text=str(vals[1]), time=cur_pos))
48 |
49 | if vals[0] != 'CONTI' and vals[0] != 0:
50 | tempo = int(vals[0].split('_')[-1])
51 | midi_obj.tempo_changes.append(
52 | TempoChange(tempo=tempo, time=cur_pos))
53 | else:
54 | pass
55 | elif vals[3] == 'Note':
56 |
57 | try:
58 | pitch = vals[4].split('_')[-1]
59 | duration = vals[5].split('_')[-1]
60 | velocity = vals[6].split('_')[-1]
61 |
62 | if int(duration) == 0:
63 | duration = 60
64 | end = cur_pos + int(duration)
65 |
66 | all_notes.append(
67 | Note(
68 | pitch=int(pitch),
69 | start=cur_pos,
70 | end=end,
71 | velocity=int(velocity))
72 | )
73 | except:
74 | continue
75 | else:
76 | pass
77 |
78 | # save midi
79 | piano_track = Instrument(0, is_drum=False, name='piano')
80 | piano_track.notes = all_notes
81 | midi_obj.instruments = [piano_track]
82 | midi_obj.dump(path_outfile)
83 |
84 |
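# write_midi assumes each compound word is a tuple of indices over the token
# families (tempo, chord, bar-beat, type, pitch, duration, velocity, ...) in the
# order used above, and that word2event maps each family to an index -> event-string
# dictionary. A rough, illustrative shape (the exact event names come from the
# dataset's dictionary and may differ):
#
#   word2event = {
#       'tempo':    {0: 0, 1: 'Tempo_110', ...},
#       'chord':    {0: 0, 1: 'C_M', ...},
#       'bar-beat': {0: 0, 1: 'Bar', 2: 'Beat_0', ...},
#       'type':     {0: 'EOS', 1: 'Metrical', 2: 'Note'},
#       'pitch':    {0: 0, 1: 'Note_Pitch_60', ...},
#       'duration': {0: 0, 1: 'Note_Duration_480', ...},
#       'velocity': {0: 0, 1: 'Note_Velocity_64', ...},
#   }
#   write_midi(generated_words, 'gen.mid', word2event)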
85 | ################################################################################
86 | # Sampling
87 | ################################################################################
88 | # -- temperature -- #
89 | def softmax_with_temperature(logits, temperature):
90 | probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature))
91 | if np.isnan(probs).any():
92 | return None
93 | else:
94 | return probs
95 |
96 |
97 |
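# A numerically stable variant is a common alternative here: subtracting the max
# logit before exponentiation yields the same distribution but avoids the overflow
# that makes softmax_with_temperature return None. A minimal sketch:
def _stable_softmax_with_temperature(logits, temperature):
    scaled = logits / temperature
    scaled = scaled - np.max(scaled)   # shift so the largest exponent is 0
    exps = np.exp(scaled)
    return exps / np.sum(exps)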
98 | # -- gumbel -- #
99 | def gumbel_softmax(logits, temperature):
100 | return F.gumbel_softmax(logits, tau=temperature, hard=True)
101 |
102 |
103 | def weighted_sampling(probs):
104 | probs /= sum(probs)
105 | sorted_probs = np.sort(probs)[::-1]
106 | sorted_index = np.argsort(probs)[::-1]
107 | word = np.random.choice(sorted_index, size=1, p=sorted_probs)[0]
108 | return word
109 |
110 |
111 | # -- nucleus -- #
112 | def nucleus(probs, p):
113 |
114 | probs /= (sum(probs) + 1e-5)
115 | sorted_probs = np.sort(probs)[::-1]
116 | sorted_index = np.argsort(probs)[::-1]
117 | cusum_sorted_probs = np.cumsum(sorted_probs)
118 | after_threshold = cusum_sorted_probs > p
119 | if sum(after_threshold) > 0:
120 | last_index = np.where(after_threshold)[0][0] + 1
121 | candi_index = sorted_index[:last_index]
122 | else:
123 | candi_index = sorted_index[:]
124 | candi_probs = [probs[i] for i in candi_index]
125 | candi_probs /= sum(candi_probs)
126 | try:
127 | word = np.random.choice(candi_index, size=1, p=candi_probs)[0]
128 |     except ValueError:
129 |         word = candi_index[0]  # fall back to the most probable candidate
130 | return word
131 |
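# A worked illustration of the cut-off above, with made-up numbers: for
# probs = [0.50, 0.30, 0.15, 0.05] and p = 0.9, the descending cumulative sums are
# [0.50, 0.80, 0.95, 1.00]; the first sum above p sits at position 2, so the top
# three tokens are kept and renormalised to roughly [0.526, 0.316, 0.158] before
# one index is drawn from them.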
132 |
133 |
134 | def sampling(logit, p=None, t=1.0, is_training=False):
135 |
136 |
137 | if is_training:
138 | logit = logit.squeeze()
139 | probs = gumbel_softmax(logits=logit, temperature=t)
140 |
141 | return torch.argmax(probs)
142 |
143 | else:
144 | logit = logit.squeeze().cpu().numpy()
145 | probs = softmax_with_temperature(logits=logit, temperature=t)
146 |
147 | if probs is None:
148 | return None
149 |
150 | if p is not None:
151 | cur_word = nucleus(probs, p=p)
152 |
153 | else:
154 | cur_word = weighted_sampling(probs)
155 | return cur_word
156 |
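# A typical call might look like the following (the tensor size and the p / t
# values are illustrative, not taken from the training configuration):
#
#   logit = model_output_for_one_family               # torch tensor of shape (1, vocab_size)
#   token = sampling(logit, p=0.9, t=1.2)             # nucleus sampling with temperature
#   token = sampling(logit, t=1.0, is_training=True)  # Gumbel-based path used during training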
157 |
158 |
159 |
160 |
161 |
162 | def get_random_string(length):
163 |     # choose from all lowercase letters
164 |     letters = string.ascii_lowercase
165 |     result_str = ''.join(random.choice(letters) for _ in range(length))
166 | return result_str
167 |
168 |
169 |
170 | '''
171 | If the classifier is pre-trained,
172 | the generator output has to go through gumbel softmax (because the classifier was trained on real data).
173 | If not, i.e. the classifier directly consumes what the generator produces during the training phase, feeding it logits / softmax also works.
174 | 
175 | KEY: does the classifier need to see real data?
176 | Yes: gumbel
177 | No: no gumbel needed
178 | 
179 | There are 2 options for what to feed the classifier:
180 | 1. logits
181 | 2. probs -> not viable, because the result still has to be fed to forward_hidden... so a word (token index) is needed
182 | '''
183 |
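# A minimal sketch of the gumbel option described above (the vocabulary size and
# embedding table below are hypothetical, not taken from this repo):
def _gumbel_straight_through_sketch():
    logits = torch.randn(1, 128, requires_grad=True)        # generator logits for one token family
    one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True)  # discrete forward pass, soft backward pass
    embed_table = torch.nn.Embedding(128, 512).weight       # hypothetical token-embedding table
    return one_hot @ embed_table                            # differentiable stand-in for a sampled token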
184 | '''
185 | def compile_data(test_folder):
186 | MAX_LEN = 1024
187 | wordfiles = glob.glob(os.path.join(test_folder, '*.npy'))
188 | n_files = len(wordfiles)
189 |
190 | x_list = []
191 | y_list = []
192 | f_name_list = []
193 | for fidx in range(n_files):
194 | file = wordfiles[fidx]
195 |
196 | words = np.load(file)
197 | num_words = len(words)
198 |
199 | eos_arr = words[-1][None, ...]
200 | if num_words >= MAX_LEN - 2:
201 | print('too long!', num_words)
202 | continue
203 |
204 | x = words[:-1].copy() #without EOS
205 | y = words[1:].copy()
206 | seq_len = len(x)
207 | print(' > seq_len:', seq_len)
208 |
209 | # pad with eos
210 | pad = np.tile(
211 | eos_arr,
212 | (MAX_LEN-seq_len, 1))
213 |
214 | x = np.concatenate([x, pad], axis=0)
215 | y = np.concatenate([y, pad], axis=0)
216 |
217 | # collect
218 | if x.shape != (1024, 8):
219 | print(x.shape)
220 | exit()
221 | x_list.append(x.reshape(1, 1024,8))
222 | y_list.append(y.reshape(1, 1024,8))
223 | f_name_list.append(file)
224 |
225 | # x_final = np.array(x_list)
226 | # y_final = np.array(y_list)
227 | # f_name_list = np.array(f_name_list)
228 | return x_list, y_list, f_name_list
229 |
230 |
231 |
232 | def take_embd_scratch(embd_net, test_folder):
233 |
234 | # unpack
235 | batch_size = 1
236 | train_x, train_y, f_name_list = compile_data(test_folder)
237 | num_batch = len(train_x) // batch_size
238 |
239 |
240 |
241 | # load model
242 | path_ckpt = 'exp/0309-1857' # path to ckpt dir
243 | loss = 30 # loss
244 | name = 'loss_' + str(loss)
245 | path_saved_ckpt = os.path.join(path_ckpt, name + '_params.pt')
246 | print('[*] load model from:', path_saved_ckpt)
247 | embd_net.load_state_dict(torch.load(path_saved_ckpt))
248 |
249 | embd_filename = 'temp/embd_'
250 |
251 |
252 | while train_x:
253 |
254 | batch_x = train_x.pop()
255 | batch_y = train_y.pop()
256 | fname = f_name_list.pop()
257 |
258 | batch_x = torch.from_numpy(batch_x).long().cuda()
259 | batch_y = torch.from_numpy(batch_y).long().cuda()
260 |
261 |
262 | h, y_type, layer_outputs = embd_net.forward_hidden(batch_x)
263 |
264 | _, _, _, _, _, _, _, eight_y_ = embd_net.forward_output(layer_outputs[7], batch_y)
265 |
266 | np.save(embd_filename + fname.split('/')[-1], eight_y_.detach().cpu().numpy())
267 |
268 | return embd_filename
269 |
270 | '''
271 |
--------------------------------------------------------------------------------