├── .gitignore ├── LICENSE ├── README.md ├── _config.yml ├── cog.yaml ├── dataset ├── README.md ├── compile.py ├── corpus2events.py ├── event2words.py ├── get_cls_idx.ipynb └── midi2corpus.py ├── docs ├── _config.yml ├── _layouts │ └── default.html ├── assets │ ├── audio_samples │ │ ├── 1_lstm+GA │ │ │ ├── Q1 │ │ │ │ ├── gen_Q1_1.mp3 │ │ │ │ ├── gen_Q1_2.mp3 │ │ │ │ └── gen_Q1_3.mp3 │ │ │ ├── Q2 │ │ │ │ ├── gen_Q2_1.mp3 │ │ │ │ ├── gen_Q2_2.mp3 │ │ │ │ └── gen_Q2_3.mp3 │ │ │ ├── Q3 │ │ │ │ ├── gen_Q3_1.mp3 │ │ │ │ ├── gen_Q3_2.mp3 │ │ │ │ └── gen_Q3_3.mp3 │ │ │ └── Q4 │ │ │ │ ├── gen_Q4_1.mp3 │ │ │ │ ├── gen_Q4_2.mp3 │ │ │ │ └── gen_Q4_3.mp3 │ │ ├── 2_Transformer │ │ │ ├── Q1 │ │ │ │ ├── Q1_1.mp3 │ │ │ │ ├── Q1_2.mp3 │ │ │ │ └── Q1_3.mp3 │ │ │ ├── Q2 │ │ │ │ ├── Q2_1.mp3 │ │ │ │ ├── Q2_2.mp3 │ │ │ │ └── Q2_3.mp3 │ │ │ ├── Q3 │ │ │ │ ├── Q3_1.mp3 │ │ │ │ ├── Q3_2.mp3 │ │ │ │ └── Q3_3.mp3 │ │ │ └── Q4 │ │ │ │ ├── Q4_1.mp3 │ │ │ │ ├── Q4_2.mp3 │ │ │ │ └── Q4_3.mp3 │ │ └── 3_Pre-trained_Transformer │ │ │ ├── Q1 │ │ │ ├── Q1_1.mp3 │ │ │ ├── Q1_2.mp3 │ │ │ └── Q1_3.mp3 │ │ │ ├── Q2 │ │ │ ├── Q2_1.mp3 │ │ │ ├── Q2_2.mp3 │ │ │ └── Q2_3.mp3 │ │ │ ├── Q3 │ │ │ ├── Q3_1.mp3 │ │ │ ├── Q3_2.mp3 │ │ │ └── Q3_3.mp3 │ │ │ └── Q4 │ │ │ ├── Q4_1.mp3 │ │ │ ├── Q4_2.mp3 │ │ │ └── Q4_3.mp3 │ └── css │ │ └── style.css ├── img │ ├── VA.png │ ├── VA_Q.png │ ├── background.jpg │ ├── classification.png │ ├── cp.png │ ├── emopia.png │ ├── example.png │ ├── feature.png │ ├── key.png │ ├── number.png │ ├── pipeline.png │ ├── representation.png │ └── results.png ├── index.md └── init ├── predict.py ├── requirements.txt └── workspace ├── baseline ├── evolve_generative_base.py ├── midi_encoder.py ├── midi_generator.py ├── plot_results.py ├── plot_results_base.py ├── train_classifier.py └── train_generative.py └── transformer ├── generate.ipynb ├── main_cp.py ├── models.py ├── saver.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/* 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | 135 | # data 136 | .mid 137 | .npy 138 | .npz 139 | .pkl 140 | .zip 141 | .json 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Hsiao-Tzu Hung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 |
5 | 6 | This is the official repository of **EMOPIA: A Multi-Modal Pop Piano Dataset For Emotion Recognition and Emotion-based Music Generation**. The paper has been accepted by International Society for Music Information Retrieval Conference 2021. 7 | 8 | - [Paper on Arxiv](https://arxiv.org/abs/2108.01374) 9 | - [Demo Page](https://annahung31.github.io/EMOPIA/) 10 | - [Interactive demo and Docker image on Replicate](https://replicate.ai/annahung31/emopia) 11 | - [Dataset at Zenodo](https://zenodo.org/record/5090631#.YPPo-JMzZz8) 12 | 13 | * Note: We release the transcribed MIDI files. As for the audio part, due to the copyright issue, we will only release the YouTube ID of the tracks and the timestamp of them. You might use [open source crawler](https://github.com/ytdl-org/youtube-dl) to get the audio file. 14 | 15 | 16 | ## Use EMOPIA by MusPy 17 | 1. install muspy 18 | ``` 19 | pip install muspy 20 | ``` 21 | 2. Use it in your script 22 | 23 | ``` 24 | import muspy 25 | 26 | emopia = muspy.EMOPIADataset("data/emopia/", download_and_extract=True) 27 | emopia.convert() 28 | music = emopia[0] 29 | print(music.annotations[0].annotation) 30 | ``` 31 | You can get the label of the piece of music: 32 | 33 | ``` 34 | {'emo_class': '1', 'YouTube_ID': '0vLPYiPN7qY', 'seg_id': '0'} 35 | ``` 36 | * `emo_class`: ['1', '2', '3', '4'] 37 | * `YouTube_ID`: the YouTube ID of this piece of music 38 | * `seg_id`: means this piece of music is the `i`th piece we take from this song. (zero-based). 39 | 40 | For more usage please refer to [MusPy](https://github.com/salu133445/muspy). 41 | 42 | 43 | # Emotion Classification 44 | 45 | For the classification models and codes, please refer to [this repo](https://github.com/SeungHeonDoh/EMOPIA_cls). 46 | 47 | 48 | # Conditional Generation 49 | 50 | ## Environment 51 | 52 | 1. Install PyTorch and fast transformer: 53 | - torch==1.7.0 (Please install it according to your [CUDA version](https://pytorch.org/get-started/previous-versions/#linux-and-windows-4).) 54 | - fast transformer : 55 | 56 | ``` 57 | pip install --user pytorch-fast-transformers 58 | ``` 59 | or refer to the original [repository](https://github.com/idiap/fast-transformers) 60 | 61 | 2. Other requirements: 62 | 63 | pip install -r requirements.txt 64 | 65 | 66 | ## Usage 67 | 68 | ### Inference 69 | 70 | **Option 1:** 71 | 72 | You can directly run the [generate script](./workspace/transformer/generate.ipynb) to generate pieces of musics and listen to them. 73 | 74 | 75 | **Option 2:** 76 | Or you might follow the steps as below. 77 | 78 | 1. Download the checkpoints and put them into `exp/` 79 | * Manually: 80 | - [Baseline](https://drive.google.com/file/d/1Q9vQYnNJ0hXBFwcxdWQgDNmzoW3MLl3h/view?usp=sharing) 81 | - [no-pretrained transformer](https://drive.google.com/file/d/1ZULJgBRu2Wb3jxFmGfAHP1v_tjoryFM7/view?usp=sharing) 82 | - [pretrained transformer](https://drive.google.com/file/d/19Seq18b2JNzOamEQMG1uarKjj27HJkHu/view?usp=sharing) 83 | 84 | * By commend: (install gdown: `pip install gdown`) 85 | ``` 86 | #baseline: 87 | gdown --id 1Q9vQYnNJ0hXBFwcxdWQgDNmzoW3MLl3h --output exp/baseline.zip 88 | 89 | # no-pretrained transformer 90 | gdown --id 1ZULJgBRu2Wb3jxFmGfAHP1v_tjoryFM7 --output exp/no-pretrained_transformer.zip 91 | 92 | # pretrained transformer 93 | gdown --id 19Seq18b2JNzOamEQMG1uarKjj27HJkHu --output exp/pretrained_transformer.zip 94 | ``` 95 | 96 | 97 | 98 | 2. Inference options: 99 | 100 | * `num_songs`: number of midis you want to generate. 
101 | * `out_dir`: the folder where the generated midi will be saved. If not specified, midi files will be saved to `exp/MODEL_YOU_USED/gen_midis/`. 102 | * `task_type`: the task_type needs to be the same as the task specified during training. 103 | - '4-cls' for 4 class conditioning 104 | - 'Arousal' for only conditioning on arousal 105 | - 'Valence' for only conditioning on Valence 106 | - 'ignore' for not conditioning 107 | 108 | * `emo_tag`: the target class of emotion you want to assign. 109 | - If the task_type is '4-cls', emo_tag can be: 1,2,3,4, which refers to Q1, Q2, Q3, Q4. 110 | - If the task_type is 'Arousal', emo_tag can be: `1`, `2`. `1` for High arousal, `2` for Low arousal. 111 | - If the task_type is 'Valence', emo_tag can be: `1`, `2`. `1` for High Valence, `2` for Low Valence. 112 | 113 | 114 | 3. Inference 115 | 116 | ``` 117 | python main_cp.py --mode inference --task_type 4-cls --load_ckt CHECKPOINT_FOLDER --load_ckt_loss 25 --num_songs 10 --emo_tag 1 118 | ``` 119 | 120 | ### Train the model by yourself 121 | 1. Prepare the data follow the [steps](https://github.com/annahung31/EMOPIA/tree/main/dataset). 122 | 123 | 124 | 2. training options: 125 | 126 | * `exp_name`: the folder name that the checkpoints will be saved. 127 | * `data_parallel`: use data_parallel to let the training process faster. (0: not use, 1: use) 128 | * `task_type`: the conditioning task: 129 | - '4-cls' for 4 class conditioning 130 | - 'Arousal' for only conditioning on arousal 131 | - 'Valence' for only conditioning on Valence 132 | - 'ignore' for not conditioning 133 | 134 | a. Only train on EMOPIA: (`no-pretrained transformer` in the paper) 135 | 136 | python main_cp.py --path_train_data emopia --exp_name YOUR_EXP_NAME --load_ckt none 137 | 138 | b. Pre-train the transformer on `AILabs17k`: 139 | 140 | python main_cp.py --path_train_data ailabs --exp_name YOUR_EXP_NAME --load_ckt none --task_type ignore 141 | 142 | c. fine-tune the transformer on `EMOPIA`: 143 | For example, you want to use the pre-trained model stored in `0309-1857` with loss= `30` to fine-tune: 144 | 145 | python main_cp.py --path_train_data emopia --exp_name YOUR_EXP_NAME --load_ckt 0309-1857 --load_ckt_loss 30 146 | 147 | ### Baseline 148 | 1. The baseline code is based on the work of [Learning to Generate Music with Sentiment](https://github.com/lucasnfe/music-sentneuron) 149 | 150 | 2. According to the author, the model works best when it is trained with 4096 neurons of LSTM, but takes 12 days for training. Therefore, due to the limit of computational resource, we used the size of 512 neurons instead of 4096. 151 | 152 | 3. In order to use this as evaluation against our model, the target emotion classes is expanded to 4Q instead of just positive/negative. 153 | 154 | ## Authors 155 | 156 | The paper is a co-working project with [Joann](https://github.com/joann8512), [SeungHeon](https://github.com/SeungHeonDoh) and Nabin. This repository is mentained by [Joann](https://github.com/joann8512) and me. 157 | 158 | 159 | ## License 160 | The EMOPIA dataset is released under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). It is provided primarily for research purposes and is prohibited to be used for commercial purposes. When sharing your result based on EMOPIA, any act that defames the original music owner is strictly prohibited. 
161 | 162 | 163 | The hand drawn piano in the logo comes from [Adobe stock](https://stock.adobe.com/tw/images/one-line-piano-instrument-design-hand-drawn-minimalistic-style-vector-illustration/327942843). The author is [Burak](https://stock.adobe.com/tw/contributor/206697762/burak?load_type=author&prev_url=detail). I purchased it under [standard](https://stock.adobe.com/tw/license-terms) license. 164 | 165 | ## Cite the dataset 166 | 167 | ``` 168 | @inproceedings{{EMOPIA}, 169 | author = {Hung, Hsiao-Tzu and Ching, Joann and Doh, Seungheon and Kim, Nabin and Nam, Juhan and Yang, Yi-Hsuan}, 170 | title = {{MOPIA}: A Multi-Modal Pop Piano Dataset For Emotion Recognition and Emotion-based Music Generation}, 171 | booktitle = {Proc. Int. Society for Music Information Retrieval Conf.}, 172 | year = {2021} 173 | } 174 | ``` 175 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | predict: "predict.py:Predictor" 2 | build: 3 | gpu: true 4 | system_packages: 5 | - "ffmpeg" 6 | - "fluidsynth" 7 | python_packages: 8 | - "torch==1.7.0" 9 | - "scikit-learn==0.24.1" 10 | - "seaborn==0.11.1" 11 | - "numpy==1.19.5" 12 | - "miditoolkit==0.1.14" 13 | - "pandas==1.1.5" 14 | - "tqdm==4.62.2" 15 | - "matplotlib==3.4.3" 16 | - "scipy==1.7.1" 17 | - "midiSynth==0.3" 18 | - "wheel==0.37.0" 19 | - "ipdb===0.13.9" 20 | - "pyfluidsynth==1.3.0" 21 | pre_install: 22 | - "pip install pytorch-fast-transformers==0.4.0" # needs to be installed after the main pip install 23 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Dataset processing 3 | 4 | ## If you want to use the processed data used in our paper... 5 | Download the npz files from [here](https://drive.google.com/file/d/17dKUf33ZsDbHC5Z6rkQclge3ppDTVCMP/view?usp=sharing), or download via gdown: 6 | 7 | ``` 8 | gdown --id 17dKUf33ZsDbHC5Z6rkQclge3ppDTVCMP 9 | unzip co-representation.zip 10 | ``` 11 | 12 | 13 | ## If you want to prepare data from scratch... 14 | The data pre-processing used in our paper is basically the same as [Compound-word-transformer](https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/Dataset.md). The difference is the emotion token part. 15 | 16 | 17 | 1. Run step 1-3 of [Compound-word-transformer dataset processing](https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/Dataset.md). 18 | 2. Change the path in the following scripts and run: 19 | 20 | a. quantize everything, add EOS, and prepare the emotion label from the filename. 21 | 22 | ``` 23 | python midi2corpus.py 24 | ``` 25 | 26 | b. transfer the corpus file to CP events. 27 | 28 | ``` 29 | python corpus2events.py 30 | ``` 31 | 32 | c. transfer the events to CP words, and build the dictionary. 33 | ``` 34 | python event2words.py 35 | ``` 36 | d. Compile the words file to npz file for training. 37 | ``` 38 | python compile.py 39 | ``` 40 | * Please note that I don't split data into train/test set, I use all the data for training bacause the task is to generate music from scratch and no need for validation. 
41 | * The broken list in the compile.py is samples that encountered some issues during preprocessing and I just skip them. -------------------------------------------------------------------------------- /dataset/compile.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is from 3 | https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/representations/uncond/cp/compile.py 4 | ''' 5 | 6 | 7 | import os 8 | import json 9 | import pickle 10 | import numpy as np 11 | import ipdb 12 | 13 | 14 | TEST_AMOUNT = 50 15 | WINDOW_SIZE = 512 16 | GROUP_SIZE = 2 #7 17 | MAX_LEN = WINDOW_SIZE * GROUP_SIZE 18 | COMPILE_TARGET = 'linear' # 'linear', 'XL' 19 | print('[config] MAX_LEN:', MAX_LEN) 20 | 21 | broken_list = [ 22 | 'Q1_8izVTDgBQPc_0.mid.pkl.npy', 'Q1_uqRLEByE6pU_1.mid.pkl.npy', 'Q3_nBIls0laAAU_0.mid.pkl.npy', 23 | 'Q1_8rupdevqfuI_0.mid.pkl.npy', 'Q2_9v2WSpn4FCw_1.mid.pkl.npy', 'Q3_REq37pDAm3A_3.mid.pkl.npy', 24 | 'Q1_aYe-2Glruu4_3.mid.pkl.npy', 'Q2_BzqX-9TA-GY_2.mid.pkl.npy', 'Q3_RL_cmmNVLfs_0.mid.pkl.npy', 25 | 'Q1_FfwKrQyQ7WU_2.mid.pkl.npy', 'Q2_ivCNV47tsRw_1.mid.pkl.npy', 'Q3_wfXSdMsd4q8_4.mid.pkl.npy', 26 | 'Q1_GY3f6ckBVkA_1.mid.pkl.npy', 'Q2_RiQMuhk_SuQ_1.mid.pkl.npy', 'Q3_wqc8iqbDsGM_0.mid.pkl.npy', 27 | 'Q1_Jn9r0avp0fY_1.mid.pkl.npy', 'Q3_bbU31JLtlug_1.mid.pkl.npy', 'Q4_OUb9uaOlWAM_0.mid.pkl.npy', 28 | 'Q1_NGE9ynTJABg_0.mid.pkl.npy', 'Q3_c6CwY8Gbw0c_2.mid.pkl.npy', 'Q4_V3Y9L4UOcpk_1.mid.pkl.npy', 29 | 'Q1_QwsQ8ejbMKg_1.mid.pkl.npy', 'Q3_kDGmND1BgmA_1.mid.pkl.npy'] 30 | 31 | 32 | 33 | def traverse_dir( 34 | root_dir, 35 | extension=('mid', 'MID'), 36 | amount=None, 37 | str_=None, 38 | is_pure=False, 39 | verbose=False, 40 | is_sort=False, 41 | is_ext=True): 42 | if verbose: 43 | print('[*] Scanning...') 44 | file_list = [] 45 | cnt = 0 46 | for root, _, files in os.walk(root_dir): 47 | for file in files: 48 | if file in broken_list: 49 | continue 50 | if file.endswith(extension): 51 | if (amount is not None) and (cnt == amount): 52 | break 53 | if str_ is not None: 54 | if str_ not in file: 55 | continue 56 | mix_path = os.path.join(root, file) 57 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 58 | if not is_ext: 59 | ext = pure_path.split('.')[-1] 60 | pure_path = pure_path[:-(len(ext)+1)] 61 | if verbose: 62 | print(pure_path) 63 | file_list.append(pure_path) 64 | cnt += 1 65 | if verbose: 66 | print('Total: %d files' % len(file_list)) 67 | print('Done!!!') 68 | if is_sort: 69 | file_list.sort() 70 | return file_list 71 | 72 | 73 | if __name__ == '__main__': 74 | # paths 75 | path_root = '.' 76 | path_indir = os.path.join( path_root, 'words') 77 | 78 | # load all words 79 | wordfiles = traverse_dir( 80 | path_indir, 81 | extension=('npy')) 82 | n_files = len(wordfiles) 83 | 84 | # init 85 | x_list = [] 86 | y_list = [] 87 | mask_list = [] 88 | seq_len_list = [] 89 | num_groups_list = [] 90 | name_list = [] 91 | 92 | # process 93 | for fidx in range(n_files): 94 | print('--[{}/{}]-----'.format(fidx+1, n_files)) 95 | file = wordfiles[fidx] 96 | print(file) 97 | try: 98 | words = np.load(file) 99 | except: 100 | print(fidx) 101 | 102 | import ipdb 103 | ipdb.set_trace() 104 | num_words = len(words) 105 | 106 | eos_arr = words[-1][None, ...] 
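        # `eos_arr` is the file's final compound word (the EOS event appended by
        # corpus2events.py), kept as a single row so it can be tiled as padding below:
        # sequences longer than MAX_LEN-2 are truncated, shorter ones are right-padded
        # with EOS rows up to shape (MAX_LEN, 8), and `mask` flags the real tokens.
        # For example, a 700-token file keeps 699 tokens in x/y, gets 325 EOS pad rows,
        # and mask = [1]*699 + [0]*325.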
107 | 108 | if num_words >= MAX_LEN - 2: # 2 for room 109 | words = words[:MAX_LEN-2] 110 | 111 | # arrange IO 112 | x = words[:-1].copy() #without EOS 113 | y = words[1:].copy() 114 | seq_len = len(x) 115 | print(' > seq_len:', seq_len) 116 | 117 | # pad with eos 118 | pad = np.tile( 119 | eos_arr, 120 | (MAX_LEN-seq_len, 1)) 121 | 122 | x = np.concatenate([x, pad], axis=0) 123 | y = np.concatenate([y, pad], axis=0) 124 | mask = np.concatenate( 125 | [np.ones(seq_len), np.zeros(MAX_LEN-seq_len)]) 126 | 127 | # collect 128 | if x.shape != (1024, 8): 129 | print(x.shape) 130 | exit() 131 | 132 | x_list.append(x) 133 | y_list.append(y) 134 | mask_list.append(mask) 135 | seq_len_list.append(seq_len) 136 | num_groups_list.append(int(np.ceil(seq_len/WINDOW_SIZE))) 137 | name_list.append(file) 138 | 139 | # sort by length (descending) 140 | zipped = zip(seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list) 141 | seq_len_list, x_list, y_list, mask_list, num_groups_list, name_list = zip( 142 | *sorted(zipped, key=lambda x: -x[0])) 143 | 144 | print('\n\n[Finished]') 145 | print(' compile target:', COMPILE_TARGET) 146 | if COMPILE_TARGET == 'XL': 147 | # reshape 148 | x_final = np.array(x_list).reshape(len(x_list), GROUP_SIZE, WINDOW_SIZE, -1) 149 | y_final = np.array(y_list).reshape(len(x_list), GROUP_SIZE, WINDOW_SIZE, -1) 150 | mask_final = np.array(mask_list).reshape(-1, GROUP_SIZE, WINDOW_SIZE) 151 | elif COMPILE_TARGET == 'linear': 152 | 153 | x_final = np.array(x_list) 154 | y_final = np.array(y_list) 155 | mask_final = np.array(mask_list) 156 | else: 157 | raise ValueError('Unknown target:', COMPILE_TARGET) 158 | 159 | # check 160 | num_samples = len(seq_len_list) 161 | print(' > count:', ) 162 | print(' > x_final:', x_final.shape) 163 | print(' > y_final:', y_final.shape) 164 | print(' > mask_final:', mask_final.shape) 165 | 166 | train_idx = [] 167 | 168 | # validation filename map 169 | fn2idx_map = { 170 | 'fn2idx': dict(), 171 | 'idx2fn': dict(), 172 | } 173 | 174 | # training filename map 175 | train_fn2idx_map = { 176 | 'fn2idx': dict(), 177 | 'idx2fn': dict(), 178 | } 179 | 180 | name_list = [x.split('/')[-1].split('.')[0] for x in name_list] 181 | # run split 182 | train_cnt = 0 183 | for nidx, n in enumerate(name_list): 184 | 185 | train_idx.append(nidx) 186 | train_fn2idx_map['fn2idx'][n] = train_cnt 187 | train_fn2idx_map['idx2fn'][train_cnt] = n 188 | train_cnt += 1 189 | 190 | train_idx = np.array(train_idx) 191 | 192 | # save train map 193 | path_train_fn2idx_map = os.path.join(path_root, 'train_fn2idx_map.json') 194 | with open(path_train_fn2idx_map, 'w') as f: 195 | json.dump(train_fn2idx_map, f) 196 | 197 | # save train 198 | path_train = os.path.join(path_root, 'train_data_{}'.format(COMPILE_TARGET)) 199 | path_train += '.npz' 200 | print('save to', path_train) 201 | np.savez( 202 | path_train, 203 | x=x_final[train_idx], 204 | y=y_final[train_idx], 205 | mask=mask_final[train_idx], 206 | seq_len=np.array(seq_len_list)[train_idx], 207 | num_groups=np.array(num_groups_list)[train_idx] 208 | ) 209 | 210 | print('---') 211 | print(' > train x:', x_final[train_idx].shape) 212 | 213 | -------------------------------------------------------------------------------- /dataset/corpus2events.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | # config 9 | BEAT_RESOL = 480 10 | BAR_RESOL = BEAT_RESOL * 4 11 | 
TICK_RESOL = BEAT_RESOL // 4 12 | 13 | 14 | # utilities 15 | def plot_hist(data, path_outfile): 16 | print('[Fig] >> {}'.format(path_outfile)) 17 | data_mean = np.mean(data) 18 | data_std = np.std(data) 19 | 20 | print('mean:', data_mean) 21 | print(' std:', data_std) 22 | 23 | plt.figure(dpi=100) 24 | plt.hist(data, bins=50) 25 | plt.title('mean: {:.3f}_std: {:.3f}'.format(data_mean, data_std)) 26 | plt.savefig(path_outfile) 27 | plt.close() 28 | 29 | def traverse_dir( 30 | root_dir, 31 | extension=('mid', 'MID'), 32 | amount=None, 33 | str_=None, 34 | is_pure=False, 35 | verbose=False, 36 | is_sort=False, 37 | is_ext=True): 38 | if verbose: 39 | print('[*] Scanning...') 40 | file_list = [] 41 | cnt = 0 42 | for root, _, files in os.walk(root_dir): 43 | for file in files: 44 | if file.endswith(extension): 45 | if (amount is not None) and (cnt == amount): 46 | break 47 | if str_ is not None: 48 | if str_ not in file: 49 | continue 50 | mix_path = os.path.join(root, file) 51 | pure_path = mix_path[len(root_dir):] if is_pure else mix_path 52 | if not is_ext: 53 | ext = pure_path.split('.')[-1] 54 | pure_path = pure_path[:-(len(ext)+1)] 55 | if verbose: 56 | print(pure_path) 57 | 58 | 59 | file_list.append(pure_path) 60 | cnt += 1 61 | if verbose: 62 | print('Total: %d files' % len(file_list)) 63 | print('Done!!!') 64 | if is_sort: 65 | file_list.sort() 66 | return file_list 67 | 68 | # ---- define event ---- # #todo 69 | ''' 8 kinds: 70 | tempo: 0: IGN 71 | 1: no change 72 | int: tempo 73 | chord: 0: IGN 74 | 1: no change 75 | str: chord types 76 | bar-beat: 0: IGN 77 | int: beat position (1...16) 78 | int: bar (bar) 79 | type: 0: eos 80 | 1: metrical 81 | 2: note 82 | 3: emotion 83 | duration: 0: IGN 84 | int: length 85 | pitch: 0: IGN 86 | int: pitch 87 | velocity: 0: IGN 88 | int: velocity 89 | emotion: 0: IGN 90 | 1: Q1 91 | 2: Q2 92 | 3: Q3 93 | 4: Q4 94 | ''' 95 | 96 | # emotion map 97 | emo_map = { 98 | 'Q1': 1, 99 | 'Q2': 2, 100 | 'Q3': 3, 101 | 'Q4': 4, 102 | 103 | } 104 | 105 | 106 | # event template 107 | compound_event = { 108 | 'tempo': 0, 109 | 'chord': 0, 110 | 'bar-beat': 0, 111 | 'type': 0, 112 | 'pitch': 0, 113 | 'duration': 0, 114 | 'velocity': 0, 115 | 'emotion': 0 116 | } 117 | 118 | def create_emo_event(emo_tag): 119 | emo_event = compound_event.copy() 120 | emo_event['emotion'] = emo_tag 121 | emo_event['type'] = 'Emotion' 122 | return emo_event 123 | 124 | def create_bar_event(): 125 | meter_event = compound_event.copy() 126 | meter_event['bar-beat'] = 'Bar' 127 | meter_event['type'] = 'Metrical' 128 | return meter_event 129 | 130 | 131 | def create_piano_metrical_event(tempo, chord, pos): 132 | meter_event = compound_event.copy() 133 | meter_event['tempo'] = tempo 134 | meter_event['chord'] = chord 135 | meter_event['bar-beat'] = pos 136 | #todo 137 | meter_event['type'] = 'Metrical' 138 | return meter_event 139 | 140 | 141 | def create_piano_note_event(pitch, duration, velocity): 142 | note_event = compound_event.copy() 143 | note_event['pitch'] = pitch 144 | note_event['duration'] = duration 145 | note_event['velocity'] = velocity 146 | note_event['type'] = 'Note' 147 | return note_event 148 | 149 | 150 | def create_eos_event(): 151 | eos_event = compound_event.copy() 152 | eos_event['type'] = 'EOS' 153 | return eos_event 154 | 155 | 156 | # ----------------------------------------------- # 157 | # core functions 158 | def corpus2event_cp(path_infile, path_outfile): 159 | ''' 160 | task: 2 track 161 | 1: piano (note + tempo) 162 | --- 163 | remove duplicate 
position tokens 164 | ''' 165 | 166 | data = pickle.load(open(path_infile, 'rb')) 167 | 168 | 169 | # global tag 170 | global_end = data['metadata']['last_bar'] * BAR_RESOL 171 | emo_tag = emo_map[data['metadata']['emotion']] 172 | 173 | # process 174 | final_sequence = [] 175 | final_sequence.append(create_emo_event(emo_tag)) 176 | for bar_step in range(0, global_end, BAR_RESOL): 177 | final_sequence.append(create_bar_event()) 178 | 179 | # --- piano track --- # 180 | for timing in range(bar_step, bar_step + BAR_RESOL, TICK_RESOL): 181 | pos_on = False 182 | pos_events = [] 183 | pos_text = 'Beat_' + str((timing-bar_step)//TICK_RESOL) 184 | 185 | # unpack 186 | t_chords = data['chords'][timing] 187 | t_tempos = data['tempos'][timing] 188 | t_notes = data['notes'][0][timing] # piano track 189 | 190 | # metrical 191 | #todo 192 | if len(t_tempos) or len(t_chords): 193 | # chord 194 | if len(t_chords): 195 | 196 | root, quality, bass = t_chords[-1].text.split('_') 197 | chord_text = root+'_'+quality 198 | else: 199 | chord_text = 'CONTI' 200 | 201 | # tempo 202 | if len(t_tempos): 203 | tempo_text = 'Tempo_' + str(t_tempos[-1].tempo) 204 | else: 205 | tempo_text = 'CONTI' 206 | 207 | # create 208 | pos_events.append( 209 | create_piano_metrical_event( 210 | tempo_text, chord_text, pos_text)) 211 | pos_on = True 212 | 213 | # note 214 | if len(t_notes): 215 | if not pos_on: 216 | pos_events.append( 217 | create_piano_metrical_event( 218 | 'CONTI', 'CONTI', pos_text)) 219 | 220 | for note in t_notes: 221 | note_pitch_text = 'Note_Pitch_' + str(note.pitch) 222 | note_duration_text = 'Note_Duration_' + str(note.duration) 223 | note_velocity_text = 'Note_Velocity_' + str(note.velocity) 224 | 225 | pos_events.append( 226 | create_piano_note_event( 227 | note_pitch_text, 228 | note_duration_text, 229 | note_velocity_text)) 230 | 231 | # collect & beat 232 | if len(pos_events): 233 | final_sequence.extend(pos_events) 234 | 235 | # BAR ending 236 | final_sequence.append(create_bar_event()) 237 | 238 | # EOS 239 | final_sequence.append(create_eos_event()) 240 | 241 | # save 242 | fn = os.path.basename(path_outfile) 243 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True) 244 | pickle.dump(final_sequence, open(path_outfile, 'wb')) 245 | 246 | return len(final_sequence) 247 | 248 | 249 | if __name__ == '__main__': 250 | # paths 251 | path_root = '.' 
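    # 'input_dir' below is a placeholder: before running, point it at the folder of
    # corpus .pkl files produced by midi2corpus.py (see step 2a in dataset/README.md).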
252 | path_indir = 'input_dir' 253 | path_outdir = 'events' 254 | os.makedirs(path_outdir, exist_ok=True) 255 | 256 | # list files 257 | midifiles = traverse_dir( 258 | path_indir, 259 | extension=('pkl'), 260 | is_pure=True, 261 | is_sort=True) 262 | n_files = len(midifiles) 263 | print('num files:', n_files) 264 | 265 | # run all 266 | len_list = [] 267 | paths = [] 268 | for fidx in range(n_files): 269 | path_midi = midifiles[fidx] 270 | print('{}/{}'.format(fidx+1, n_files)) 271 | 272 | # paths 273 | path_infile = os.path.join(path_indir, path_midi) 274 | path_outfile = os.path.join(path_outdir, path_midi) 275 | 276 | # proc 277 | num_tokens = corpus2event_cp(path_infile, path_outfile) 278 | print(' > num_token:', num_tokens) 279 | len_list.append(num_tokens) 280 | paths.append(path_midi) 281 | 282 | 283 | 284 | plot_hist( 285 | len_list, 286 | os.path.join(path_root, 'num_tokens.png') 287 | ) 288 | 289 | 290 | d = {'filename': paths, 'num_tokens': len_list} 291 | df = pd.DataFrame(data=d) 292 | df.to_csv('len_token.csv', index=False) -------------------------------------------------------------------------------- /dataset/event2words.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | This code is modify from: 4 | https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/representations/uncond/cp/events2words.py 5 | 6 | ''' 7 | 8 | import os 9 | import pickle 10 | import numpy as np 11 | import collections 12 | 13 | 14 | def traverse_dir( 15 | root_dir, 16 | extension=('mid', 'MID'), 17 | amount=None, 18 | str_=None, 19 | is_pure=False, 20 | verbose=False, 21 | is_sort=False, 22 | is_ext=True): 23 | if verbose: 24 | print('[*] Scanning...') 25 | file_list = [] 26 | cnt = 0 27 | for root, _, files in os.walk(root_dir): 28 | for file in files: 29 | if file.endswith(extension): 30 | if (amount is not None) and (cnt == amount): 31 | break 32 | if str_ is not None: 33 | if str_ not in file: 34 | continue 35 | mix_path = os.path.join(root, file) 36 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 37 | if not is_ext: 38 | ext = pure_path.split('.')[-1] 39 | pure_path = pure_path[:-(len(ext)+1)] 40 | if verbose: 41 | print(pure_path) 42 | file_list.append(pure_path) 43 | cnt += 1 44 | if verbose: 45 | print('Total: %d files' % len(file_list)) 46 | print('Done!!!') 47 | if is_sort: 48 | file_list.sort() 49 | return file_list 50 | 51 | 52 | def build_dict(path_root, path_indir, eventfiles, path_dict): 53 | class_keys = pickle.load( 54 | open(os.path.join(path_indir, eventfiles[0]), 'rb'))[0].keys() 55 | print('class keys:', class_keys) 56 | 57 | # define dictionary 58 | event2word = {} 59 | word2event = {} 60 | 61 | corpus_kv = collections.defaultdict(list) 62 | for file in eventfiles: 63 | for event in pickle.load(open( 64 | os.path.join(path_indir, file), 'rb')): 65 | for key in class_keys: 66 | corpus_kv[key].append(event[key]) 67 | 68 | for ckey in class_keys: 69 | class_unique_vals = sorted( 70 | set(corpus_kv[ckey]), key=lambda x: (not isinstance(x, int), x)) 71 | event2word[ckey] = {key: i for i, key in enumerate(class_unique_vals)} 72 | word2event[ckey] = {i: key for i, key in enumerate(class_unique_vals)} 73 | 74 | # print 75 | print('[class size]') 76 | for key in class_keys: 77 | print(' > {:10s}: {}'.format(key, len(event2word[key]))) 78 | 79 | # save 80 | pickle.dump((event2word, word2event), open(path_dict, 'wb')) 81 | 82 | 83 | if __name__ == '__main__': 84 | # paths 85 | path_root = '.' 
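    # 'events' is the output folder written by corpus2events.py; this script converts
    # those CP events to integer words in 'wordstemp' and builds dictionary.pkl
    # from the whole corpus.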
86 | path_indir = os.path.join(path_root, 'events') 87 | path_outdir = os.path.join(path_root, 'wordstemp') 88 | path_dict = os.path.join(path_root, 'dictionary.pkl') 89 | os.makedirs(path_outdir, exist_ok=True) 90 | 91 | # list files 92 | eventfiles = traverse_dir( 93 | path_indir, 94 | is_pure=True, 95 | is_sort=True, 96 | extension=('pkl')) 97 | n_files = len(eventfiles) 98 | print('num fiels:', n_files) 99 | 100 | class_keys = pickle.load( 101 | open(os.path.join(path_indir, eventfiles[0]), 'rb'))[0].keys() 102 | print('class keys:', class_keys) 103 | 104 | 105 | # --- build dictionary --- # 106 | # all files 107 | if not os.path.exists(path_dict): 108 | build_dict(path_root, path_indir, eventfiles, path_dict) 109 | 110 | 111 | # --- compile each --- # 112 | # reload 113 | event2word, word2event = pickle.load(open(path_dict, 'rb')) 114 | for fidx in range(len(eventfiles)): 115 | file = eventfiles[fidx] 116 | events_list = pickle.load(open( 117 | os.path.join(path_indir, file), 'rb')) 118 | fn = os.path.basename(file) 119 | path_outfile = os.path.join(path_outdir, fn) 120 | 121 | print('({}/{})'.format(fidx, len(eventfiles))) 122 | print(' > from:', file) 123 | print(' > to:', path_outfile) 124 | 125 | words = [] 126 | for eidx, e in enumerate(events_list): 127 | words_tmp = [ 128 | event2word[k][e[k]] for k in class_keys 129 | ] 130 | words.append(words_tmp) 131 | 132 | # save 133 | path_outfile = os.path.join(path_outdir, file + '.npy') 134 | fn = os.path.basename(path_outfile) 135 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True) 136 | np.save(path_outfile, words) -------------------------------------------------------------------------------- /dataset/get_cls_idx.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "True" 12 | ] 13 | }, 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "import os\n", 21 | "import numpy as np\n", 22 | "path_data_root = 'where you put the data'\n", 23 | "train_data_name = 'name of your processed train data'\n", 24 | "path_train_data = os.path.join(path_data_root, train_data_name + '.npz')\n", 25 | "os.path.exists(path_train_data)" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "train_data = np.load(path_train_data)\n", 35 | "train_x = train_data['x']" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "class_1_idx = []\n", 45 | "class_2_idx = []\n", 46 | "class_3_idx = []\n", 47 | "class_4_idx = []\n", 48 | "idxs = [class_1_idx, class_2_idx, class_3_idx, class_4_idx]\n", 49 | "for i, sample in enumerate(train_x):\n", 50 | " idxs[sample[0][-1] - 1].append(i)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "data": { 60 | "text/plain": [ 61 | "[2, 7, 13, 19, 32, 41, 50, 52, 64, 66]" 62 | ] 63 | }, 64 | "execution_count": 4, 65 | "metadata": {}, 66 | "output_type": "execute_result" 67 | } 68 | ], 69 | "source": [ 70 | "class_1_idx[:10]" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 5, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "cls_1 = train_x[class_1_idx]" 80 | ] 81 | }, 82 | { 83 | 
"cell_type": "code", 84 | "execution_count": 6, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "1" 91 | ] 92 | }, 93 | "execution_count": 6, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "cls_1[0][0][-1]" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 7, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "path_train = os.path.join(path_data_root, train_data_name + '_data_idx.npz')\n", 109 | "np.savez(\n", 110 | " path_train, \n", 111 | " cls_1_idx=class_1_idx,\n", 112 | " cls_2_idx=class_2_idx,\n", 113 | " cls_3_idx=class_3_idx,\n", 114 | " cls_4_idx=class_4_idx\n", 115 | " )\n", 116 | " " 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [] 125 | } 126 | ], 127 | "metadata": { 128 | "kernelspec": { 129 | "display_name": "Python 3 (ipykernel)", 130 | "language": "python", 131 | "name": "python3" 132 | }, 133 | "language_info": { 134 | "codemirror_mode": { 135 | "name": "ipython", 136 | "version": 3 137 | }, 138 | "file_extension": ".py", 139 | "mimetype": "text/x-python", 140 | "name": "python", 141 | "nbconvert_exporter": "python", 142 | "pygments_lexer": "ipython3", 143 | "version": "3.7.3" 144 | } 145 | }, 146 | "nbformat": 4, 147 | "nbformat_minor": 5 148 | } 149 | -------------------------------------------------------------------------------- /dataset/midi2corpus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import pickle 4 | import numpy as np 5 | import miditoolkit 6 | import collections 7 | 8 | import pandas as pd 9 | import matplotlib.pyplot as plt 10 | import seaborn as sns 11 | 12 | # ================================================== # 13 | # Configuration # 14 | # ================================================== # 15 | BEAT_RESOL = 480 16 | BAR_RESOL = BEAT_RESOL * 4 17 | TICK_RESOL = BEAT_RESOL // 4 18 | INSTR_NAME_MAP = {'piano': 0} 19 | MIN_BPM = 40 20 | MIN_VELOCITY = 40 21 | NOTE_SORTING = 1 # 0: ascending / 1: descending 22 | 23 | DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 64+1, dtype=np.int) 24 | DEFAULT_BPM_BINS = np.linspace(32, 224, 64+1, dtype=np.int) 25 | DEFAULT_SHIFT_BINS = np.linspace(-60, 60, 60+1, dtype=np.int) 26 | DEFAULT_DURATION_BINS = np.arange( 27 | BEAT_RESOL/8, BEAT_RESOL*8+1, BEAT_RESOL/8) 28 | 29 | # ================================================== # 30 | 31 | 32 | def traverse_dir( 33 | root_dir, 34 | extension=('mid', 'MID', 'midi'), 35 | amount=None, 36 | str_=None, 37 | is_pure=False, 38 | verbose=False, 39 | is_sort=False, 40 | is_ext=True): 41 | if verbose: 42 | print('[*] Scanning...') 43 | file_list = [] 44 | cnt = 0 45 | for root, _, files in os.walk(root_dir): 46 | for file in files: 47 | if file.endswith(extension): 48 | if (amount is not None) and (cnt == amount): 49 | break 50 | if str_ is not None: 51 | if str_ not in file: 52 | continue 53 | mix_path = os.path.join(root, file) 54 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 55 | if not is_ext: 56 | ext = pure_path.split('.')[-1] 57 | pure_path = pure_path[:-(len(ext)+1)] 58 | if verbose: 59 | print(pure_path) 60 | file_list.append(pure_path) 61 | cnt += 1 62 | if verbose: 63 | print('Total: %d files' % len(file_list)) 64 | print('Done!!!') 65 | if is_sort: 66 | file_list.sort() 67 | return file_list 68 | 69 | 70 | def proc_one(path_midi, path_outfile): 71 | # --- 
load --- # 72 | midi_obj = miditoolkit.midi.parser.MidiFile(path_midi) 73 | 74 | # collect emotion tag 75 | emo_tag = path_midi.split('/')[-1][:2] 76 | 77 | # load notes 78 | instr_notes = collections.defaultdict(list) 79 | for instr in midi_obj.instruments: 80 | # skip 81 | if instr.name not in INSTR_NAME_MAP.keys(): 82 | continue 83 | 84 | # process 85 | instr_idx = INSTR_NAME_MAP[instr.name] 86 | for note in instr.notes: 87 | note.instr_idx=instr_idx 88 | instr_notes[instr_idx].append(note) 89 | if NOTE_SORTING == 0: 90 | instr_notes[instr_idx].sort( 91 | key=lambda x: (x.start, x.pitch)) 92 | elif NOTE_SORTING == 1: 93 | instr_notes[instr_idx].sort( 94 | key=lambda x: (x.start, -x.pitch)) 95 | else: 96 | raise ValueError(' [x] Unknown type of sorting.') 97 | 98 | # load chords 99 | chords = [] 100 | for marker in midi_obj.markers: 101 | if marker.text.split('_')[0] != 'global' and \ 102 | 'Boundary' not in marker.text.split('_')[0]: 103 | chords.append(marker) 104 | chords.sort(key=lambda x: x.time) 105 | 106 | # load tempos 107 | tempos = midi_obj.tempo_changes 108 | tempos.sort(key=lambda x: x.time) 109 | 110 | # load labels 111 | labels = [] 112 | for marker in midi_obj.markers: 113 | if 'Boundary' in marker.text.split('_')[0]: 114 | labels.append(marker) 115 | labels.sort(key=lambda x: x.time) 116 | 117 | # load global bpm 118 | gobal_bpm = 120 119 | for marker in midi_obj.markers: 120 | if marker.text.split('_')[0] == 'global' and \ 121 | marker.text.split('_')[1] == 'bpm': 122 | gobal_bpm = int(marker.text.split('_')[2]) 123 | 124 | # --- process items to grid --- # 125 | # compute empty bar offset at head 126 | first_note_time = min([instr_notes[k][0].start for k in instr_notes.keys()]) 127 | last_note_time = max([instr_notes[k][-1].start for k in instr_notes.keys()]) 128 | 129 | quant_time_first = int(np.round(first_note_time / TICK_RESOL) * TICK_RESOL) 130 | offset = quant_time_first // BAR_RESOL # empty bar 131 | last_bar = int(np.ceil(last_note_time / BAR_RESOL)) - offset 132 | print(' > offset:', offset) 133 | print(' > last_bar:', last_bar) 134 | 135 | # process notes 136 | intsr_gird = dict() 137 | for key in instr_notes.keys(): 138 | notes = instr_notes[key] 139 | note_grid = collections.defaultdict(list) 140 | for note in notes: 141 | note.start = note.start - offset * BAR_RESOL 142 | note.end = note.end - offset * BAR_RESOL 143 | 144 | # quantize start 145 | quant_time = int(np.round(note.start / TICK_RESOL) * TICK_RESOL) 146 | 147 | # velocity 148 | note.velocity = DEFAULT_VELOCITY_BINS[ 149 | np.argmin(abs(DEFAULT_VELOCITY_BINS-note.velocity))] 150 | note.velocity = max(MIN_VELOCITY, note.velocity) 151 | 152 | # shift of start 153 | note.shift = note.start - quant_time 154 | note.shift = DEFAULT_SHIFT_BINS[np.argmin(abs(DEFAULT_SHIFT_BINS-note.shift))] 155 | 156 | # duration 157 | note_duration = note.end - note.start 158 | if note_duration > BAR_RESOL: 159 | note_duration = BAR_RESOL 160 | ntick_duration = int(np.round(note_duration / TICK_RESOL) * TICK_RESOL) 161 | note.duration = ntick_duration 162 | 163 | # append 164 | note_grid[quant_time].append(note) 165 | 166 | # set to track 167 | intsr_gird[key] = note_grid.copy() 168 | 169 | # process chords 170 | chord_grid = collections.defaultdict(list) 171 | for chord in chords: 172 | # quantize 173 | chord.time = chord.time - offset * BAR_RESOL 174 | chord.time = 0 if chord.time < 0 else chord.time 175 | quant_time = int(np.round(chord.time / TICK_RESOL) * TICK_RESOL) 176 | 177 | # append 178 | 
chord_grid[quant_time].append(chord) 179 | 180 | # process tempo 181 | tempo_grid = collections.defaultdict(list) 182 | for tempo in tempos: 183 | # quantize 184 | tempo.time = tempo.time - offset * BAR_RESOL 185 | tempo.time = 0 if tempo.time < 0 else tempo.time 186 | quant_time = int(np.round(tempo.time / TICK_RESOL) * TICK_RESOL) 187 | tempo.tempo = DEFAULT_BPM_BINS[np.argmin(abs(DEFAULT_BPM_BINS-tempo.tempo))] 188 | 189 | # append 190 | tempo_grid[quant_time].append(tempo) 191 | 192 | # process boundary 193 | label_grid = collections.defaultdict(list) 194 | for label in labels: 195 | # quantize 196 | label.time = label.time - offset * BAR_RESOL 197 | label.time = 0 if label.time < 0 else label.time 198 | quant_time = int(np.round(label.time / TICK_RESOL) * TICK_RESOL) 199 | 200 | # append 201 | label_grid[quant_time] = [label] 202 | 203 | # process global bpm 204 | gobal_bpm = DEFAULT_BPM_BINS[np.argmin(abs(DEFAULT_BPM_BINS-gobal_bpm))] 205 | 206 | # collect 207 | song_data = { 208 | 'notes': intsr_gird, 209 | 'chords': chord_grid, 210 | 'tempos': tempo_grid, 211 | 'labels': label_grid, 212 | 'metadata': { 213 | 'global_bpm': gobal_bpm, 214 | 'last_bar': last_bar, 215 | 'emotion': emo_tag 216 | } 217 | } 218 | 219 | # save 220 | fn = os.path.basename(path_outfile) 221 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True) 222 | pickle.dump(song_data, open(path_outfile, 'wb')) 223 | 224 | return song_data 225 | 226 | if __name__ == '__main__': 227 | # paths 228 | path_indir = './midi_analyzed/fixed' 229 | path_outdir = './corpus/fixed' 230 | os.makedirs(path_outdir, exist_ok=True) 231 | 232 | # list files 233 | midifiles = traverse_dir( 234 | path_indir, 235 | is_pure=True, 236 | is_sort=True) 237 | n_files = len(midifiles) 238 | print('num fiels:', n_files) 239 | 240 | # run all 241 | for fidx in range(n_files): 242 | path_midi = midifiles[fidx] 243 | print('{}/{}'.format(fidx, n_files)) 244 | 245 | # paths 246 | path_infile = os.path.join(path_indir, path_midi) 247 | path_outfile = os.path.join(path_outdir, path_midi+'.pkl') 248 | 249 | # proc 250 | _ = proc_one(path_infile, path_outfile) -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 2 | 3 | title: EMOPIA 4 | description: A Multi-Modal Pop Piano Dataset For Emotion Recognition and Emotion-based Music Generation 5 | 6 | -------------------------------------------------------------------------------- /docs/_layouts/default.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | {% if site.google_analytics %} 6 | 7 | 13 | {% endif %} 14 | 15 | 16 | {% seo %} 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | Skip to the content. 26 | 27 | 38 | 39 |
40 | {{ content }} 41 | 42 | 49 |
50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q1/gen_Q1_3.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q2/gen_Q2_3.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q3/gen_Q3_3.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_1.mp3: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/1_lstm+GA/Q4/gen_Q4_3.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q1/Q1_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q1/Q1_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q1/Q1_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q1/Q1_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q1/Q1_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q1/Q1_3.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q2/Q2_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q2/Q2_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q2/Q2_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q2/Q2_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q2/Q2_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q2/Q2_3.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q3/Q3_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q3/Q3_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q3/Q3_2.mp3: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q3/Q3_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q3/Q3_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q3/Q3_3.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q4/Q4_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q4/Q4_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q4/Q4_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q4/Q4_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/2_Transformer/Q4/Q4_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/2_Transformer/Q4/Q4_3.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q1/Q1_3.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_2.mp3: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q2/Q2_3.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q3/Q3_3.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_1.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_2.mp3 -------------------------------------------------------------------------------- /docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/assets/audio_samples/3_Pre-trained_Transformer/Q4/Q4_3.mp3 -------------------------------------------------------------------------------- /docs/assets/css/style.css: -------------------------------------------------------------------------------- 1 | /*! 
normalize.css v3.0.2 | MIT License | git.io/normalize */@import url("https://fonts.googleapis.com/css?family=Open+Sans:400,700&display=swap");@import url("https://fonts.googleapis.com/css2?family=Montserrat:ital,wght@0,400;0,700;1,400;1,700&family=Lato:ital,wght@0,400;0,700;1,400;1,700&display=swap");html{font-family:sans-serif;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}article,aside,details,figcaption,figure,footer,header,hgroup,main,menu,nav,section,summary{display:block}audio,canvas,progress,video{display:inline-block;vertical-align:baseline}audio:not([controls]){display:none;height:0}[hidden],template{display:none}a{background-color:transparent}a:active,a:hover{outline:0}abbr[title]{border-bottom:1px dotted}b,strong{font-weight:bold}dfn{font-style:italic}h1{font-size:2em;margin:0.67em 0}mark{background:#ff0;color:#000}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sup{top:-0.5em}sub{bottom:-0.25em}img{border:0}svg:not(:root){overflow:hidden}figure{margin:1em 40px}hr{box-sizing:content-box;height:0}pre{overflow:auto}code,kbd,pre,samp{font-family:monospace, monospace;font-size:1em}button,input,optgroup,select,textarea{color:inherit;font:inherit;margin:0}button{overflow:visible}button,select{text-transform:none}button,html input[type="button"],input[type="reset"],input[type="submit"]{-webkit-appearance:button;cursor:pointer}button[disabled],html input[disabled]{cursor:default}button::-moz-focus-inner,input::-moz-focus-inner{border:0;padding:0}input{line-height:normal}input[type="checkbox"],input[type="radio"]{box-sizing:border-box;padding:0}input[type="number"]::-webkit-inner-spin-button,input[type="number"]::-webkit-outer-spin-button{height:auto}input[type="search"]{-webkit-appearance:textfield;box-sizing:content-box}input[type="search"]::-webkit-search-cancel-button,input[type="search"]::-webkit-search-decoration{-webkit-appearance:none}fieldset{border:1px solid #c0c0c0;margin:0 2px;padding:0.35em 0.625em 0.75em}legend{border:0;padding:0}textarea{overflow:auto}optgroup{font-weight:bold}table{border-collapse:collapse;border-spacing:0}td,th{padding:0}.highlight table td{padding:5px}.highlight table pre{margin:0}.highlight .cm{color:#999988;font-style:italic}.highlight .cp{color:#999999;font-weight:bold}.highlight .c1{color:#999988;font-style:italic}.highlight .cs{color:#999999;font-weight:bold;font-style:italic}.highlight .c,.highlight .cd{color:#999988;font-style:italic}.highlight .err{color:#a61717;background-color:#e3d2d2}.highlight .gd{color:#000000;background-color:#ffdddd}.highlight .ge{color:#000000;font-style:italic}.highlight .gr{color:#aa0000}.highlight .gh{color:#999999}.highlight .gi{color:#000000;background-color:#ddffdd}.highlight .go{color:#888888}.highlight .gp{color:#555555}.highlight .gs{font-weight:bold}.highlight .gu{color:#aaaaaa}.highlight .gt{color:#aa0000}.highlight .kc{color:#000000;font-weight:bold}.highlight .kd{color:#000000;font-weight:bold}.highlight .kn{color:#000000;font-weight:bold}.highlight .kp{color:#000000;font-weight:bold}.highlight .kr{color:#000000;font-weight:bold}.highlight .kt{color:#445588;font-weight:bold}.highlight .k,.highlight .kv{color:#000000;font-weight:bold}.highlight .mf{color:#009999}.highlight .mh{color:#009999}.highlight .il{color:#009999}.highlight .mi{color:#009999}.highlight .mo{color:#009999}.highlight .m,.highlight .mb,.highlight .mx{color:#009999}.highlight .sb{color:#d14}.highlight .sc{color:#d14}.highlight .sd{color:#d14}.highlight 
.s2{color:#d14}.highlight .se{color:#d14}.highlight .sh{color:#d14}.highlight .si{color:#d14}.highlight .sx{color:#d14}.highlight .sr{color:#009926}.highlight .s1{color:#d14}.highlight .ss{color:#4b001d}.highlight .s{color:#d14}.highlight .na{color:#008080}.highlight .bp{color:#999999}.highlight .nb{color:#0086B3}.highlight .nc{color:#445588;font-weight:bold}.highlight .no{color:#008080}.highlight .nd{color:#3c5d5d;font-weight:bold}.highlight .ni{color:#800080}.highlight .ne{color:#990000;font-weight:bold}.highlight .nf{color:#990000;font-weight:bold}.highlight .nl{color:#990000;font-weight:bold}.highlight .nn{color:#555555}.highlight .nt{color:#000080}.highlight .vc{color:#008080}.highlight .vg{color:#008080}.highlight .vi{color:#008080}.highlight .nv{color:#008080}.highlight .ow{color:#000000;font-weight:bold}.highlight .o{color:#000000;font-weight:bold}.highlight .w{color:#bbbbbb}.highlight{background-color:#f8f8f8}*{box-sizing:border-box}body{padding:0;margin:0;font-family:"Lato", "Helvetica Neue", Helvetica, Arial, sans-serif;font-size:16px;line-height:1.5;color:#424b4f}#skip-to-content{height:1px;width:1px;position:absolute;overflow:hidden;top:-10px}#skip-to-content:focus{position:fixed;top:10px;left:10px;height:auto;width:auto;background:#e19447;outline:thick solid #e19447}a{color:#1e6bb8;text-decoration:none}a:hover{text-decoration:underline}.btn{display:inline-block;margin-bottom:1rem;color:rgba(255,255,255,0.7);background-color:rgba(255,255,255,0.08);border-color:rgba(255,255,255,0.2);border-style:solid;border-width:1px;border-radius:0.3rem;transition:color 0.2s, background-color 0.2s, border-color 0.2s}.btn:hover{color:rgba(255,255,255,0.8);text-decoration:none;background-color:rgba(255,255,255,0.2);border-color:rgba(255,255,255,0.3)}.btn+.btn{margin-left:1rem}@media screen and (min-width: 64em){.btn{padding:0.75rem 1rem}}@media screen and (min-width: 42em) and (max-width: 64em){.btn{padding:0.6rem 0.9rem;font-size:0.9rem}}@media screen and (max-width: 42em){.btn{display:block;width:100%;padding:0.75rem;font-size:0.9rem}.btn+.btn{margin-top:1rem;margin-left:0}}.page-header{font-family:"Montserrat", "Helvetica Neue", Helvetica, Arial, sans-serif;color:#fff;text-align:center;background-color:#006;background-image:url("https://raw.githubusercontent.com/annahung31/EMOPIA/main/docs/img/background.jpg"),linear-gradient(160deg, #333, #006);background-position:center;-webkit-background-size:cover;-moz-background-size:cover;-o-background-size:cover;background-size:cover}@media screen and (min-width: 64em){.page-header{padding:5rem 6rem}}@media screen and (min-width: 42em) and (max-width: 64em){.page-header{padding:3rem 4rem}}@media screen and (max-width: 42em){.page-header{padding:2rem 1rem}}.project-name{margin-top:0;margin-bottom:0.1rem;text-shadow:2px 2px 2px rgba(15,10,86,0.8)}@media screen and (min-width: 64em){.project-name{font-size:3.25rem}}@media screen and (min-width: 42em) and (max-width: 64em){.project-name{font-size:2.25rem}}@media screen and (max-width: 42em){.project-name{font-size:1.75rem}}.project-tagline{margin-bottom:2rem;font-weight:normal;opacity:0.9;text-shadow:2px 2px 2px rgba(15,10,86,0.8)}@media screen and (min-width: 64em){.project-tagline{font-size:1.5rem}}@media screen and (min-width: 42em) and (max-width: 64em){.project-tagline{font-size:1.25rem}}@media screen and (max-width: 42em){.project-tagline{font-size:1.15rem}}.main-content{word-wrap:break-word}.main-content :first-child{margin-top:0}@media screen and (min-width: 
64em){.main-content{max-width:64rem;padding:2rem 6rem;margin:0 auto;font-size:1.1rem}}@media screen and (min-width: 42em) and (max-width: 64em){.main-content{padding:2rem 4rem;font-size:1.1rem}}@media screen and (max-width: 42em){.main-content{padding:2rem 1rem;font-size:1rem}}.main-content kbd{background-color:#fafbfc;border:1px solid #c6cbd1;border-bottom-color:#959da5;border-radius:3px;box-shadow:inset 0 -1px 0 #959da5;color:#444d56;display:inline-block;font-size:11px;line-height:10px;padding:3px 5px;vertical-align:middle}.main-content img{max-width:100%}.main-content h1,.main-content h2,.main-content h3,.main-content h4,.main-content h5,.main-content h6{margin-top:2rem;margin-bottom:1rem;font-weight:normal;color:#EC7063}.main-content p{margin-bottom:1em}.main-content code{padding:2px 4px;font-family:Consolas, "Liberation Mono", Menlo, Courier, monospace;font-size:0.9rem;color:#567482;background-color:#f3f6fa;border-radius:0.3rem}.main-content pre{padding:0.8rem;margin-top:0;margin-bottom:1rem;font:1rem Consolas, "Liberation Mono", Menlo, Courier, monospace;color:#567482;word-wrap:normal;background-color:#f3f6fa;border:solid 1px #dce6f0;border-radius:0.3rem}.main-content pre>code{padding:0;margin:0;font-size:0.9rem;color:#567482;word-break:normal;white-space:pre;background:transparent;border:0}.main-content .highlight{margin-bottom:1rem}.main-content .highlight pre{margin-bottom:0;word-break:normal}.main-content .highlight pre,.main-content pre{padding:0.8rem;overflow:auto;font-size:0.9rem;line-height:1.45;border-radius:0.3rem;-webkit-overflow-scrolling:touch}.main-content pre code,.main-content pre tt{display:inline;max-width:initial;padding:0;margin:0;overflow:initial;line-height:inherit;word-wrap:normal;background-color:transparent;border:0}.main-content pre code:before,.main-content pre code:after,.main-content pre tt:before,.main-content pre tt:after{content:normal}.main-content ul,.main-content ol{margin-top:0}.main-content blockquote{padding:0 1rem;margin-left:0;color:#819198;border-left:0.3rem solid #dce6f0}.main-content blockquote>:first-child{margin-top:0}.main-content blockquote>:last-child{margin-bottom:0}.main-content table{display:block;width:100%;overflow:auto;word-break:normal;word-break:keep-all;-webkit-overflow-scrolling:touch}.main-content table th{font-weight:normal;text-align:left}.main-content table th,.main-content table td{padding:0.5rem 1rem;text-align:left;border:1px solid rgba(255,255,255,0)}.main-content dl{padding:0}.main-content dl dt{padding:0;margin-top:1rem;font-size:1rem;font-weight:bold}.main-content dl dd{padding:0;margin-bottom:1rem}.main-content hr{height:2px;padding:0;margin:1rem 0;background-color:#eff0f1;border:0}.site-footer{padding-top:2rem;margin-top:2rem;border-top:solid 1px #eff0f1}@media screen and (min-width: 64em){.site-footer{font-size:1rem}}@media screen and (min-width: 42em) and (max-width: 64em){.site-footer{font-size:1rem}}@media screen and (max-width: 42em){.site-footer{font-size:0.9rem}}.site-footer-owner{display:block;font-weight:bold}.site-footer-credits{color:#819198} 2 | {"mode":"full","isActive":false} 3 | table.audio-table td { 4 | vertical-align: top; } 5 | 6 | table.audio-table audio { 7 | height: 2rem; 8 | max-width: 9rem; } 9 | 10 | @media screen and (min-width: 200rem) { 11 | table.audio-table audio { 12 | display: block; 13 | width: 100%; 14 | max-width: 100%; } } 15 | 16 | @media screen and (min-width: 106rem) { 17 | table.audio-table { 18 | table-layout: fixed; 19 | overflow-y: visible; 20 | overflow-x: visible; } 21 
| table.audio-table .metric-col { 22 | width: 7em; } } 23 | @media screen and (max-width: 106rem) { 24 | table.audio-table { 25 | display: block; 26 | max-height: 90vh; 27 | overflow-y: visible; 28 | overflow-x: visible; } 29 | table.audio-table th { 30 | white-space: nowrap; } } 31 | 32 | figcaption { 33 | color: gray; 34 | padding: 2px; 35 | text-align: center; 36 | } 37 | 38 | table.VA-example{ 39 | /* border: none; */ 40 | position: relative; 41 | left: 20px; 42 | top: 10px; 43 | overflow:visible; 44 | 45 | 46 | } 47 | 48 | table.VA-example td, table.VA-example th { 49 | border: 1px dashed gray; 50 | text-align: center; 51 | font-weight:bold; 52 | color: #566573; 53 | } 54 | 55 | 56 | table.num-table{ 57 | /* border: none; */ 58 | /* margin: auto; */ 59 | width: 100%; 60 | margin-left:auto; 61 | margin-right:auto; 62 | 63 | } 64 | 65 | table.num-table td, table.num-table th { 66 | border: 1px dashed gray; 67 | text-align: center; 68 | color: #566573; 69 | } 70 | 71 | 72 | 73 | 74 | .resp-iframe { 75 | position: relative; 76 | top: 0; 77 | left: 0; 78 | width: 255; 79 | height: 142; 80 | border-radius: 10px; 81 | } 82 | 83 | .VA-image { 84 | background-image: url("https://raw.githubusercontent.com/annahung31/EMOPIA/main/docs/img/VA_Q.png"); 85 | background-color: #FFFFFF; 86 | height: 600px; 87 | background-position: center; 88 | background-repeat: no-repeat; 89 | background-size: 950px 550px; 90 | position: relative; 91 | } 92 | -------------------------------------------------------------------------------- /docs/img/VA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/VA.png -------------------------------------------------------------------------------- /docs/img/VA_Q.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/VA_Q.png -------------------------------------------------------------------------------- /docs/img/background.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/background.jpg -------------------------------------------------------------------------------- /docs/img/classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/classification.png -------------------------------------------------------------------------------- /docs/img/cp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/cp.png -------------------------------------------------------------------------------- /docs/img/emopia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/emopia.png -------------------------------------------------------------------------------- /docs/img/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/example.png 
-------------------------------------------------------------------------------- /docs/img/feature.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/feature.png -------------------------------------------------------------------------------- /docs/img/key.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/key.png -------------------------------------------------------------------------------- /docs/img/number.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/number.png -------------------------------------------------------------------------------- /docs/img/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/pipeline.png -------------------------------------------------------------------------------- /docs/img/representation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/representation.png -------------------------------------------------------------------------------- /docs/img/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/annahung31/EMOPIA/8905a168d5aefdb03d8f7cadf6d20f2fee80d810/docs/img/results.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # EMOPIA 2 | 3 | EMOPIA (pronounced ‘yee-mò-pi-uh’) dataset is a shared multi-modal (audio and MIDI) database focusing on perceived emotion in **pop piano music**, to facilitate research on various tasks related to music emotion. The dataset contains **1,087** music clips from 387 songs and **clip-level** emotion labels annotated by four dedicated annotators. Since the clips are not restricted to one clip per song, they can also be used for song-level analysis. 4 | 5 | The detail of the methodology for building the dataset please refer to our paper. 6 | 7 | * [Paper on Arxiv](https://arxiv.org/abs/2108.01374) 8 | * [Dataset on Zenodo](https://zenodo.org/record/5090631#.YPPo-JMzZz8) 9 | * [Code for classification](https://github.com/SeungHeonDoh/EMOPIA_cls) 10 | * [Code for generation](https://github.com/annahung31/EMOPIA) 11 | 12 | 13 | ### Example of the dataset 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 30 | 31 | 32 | 37 | 38 | 39 | 40 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 56 | 57 | 62 | 63 | 64 | 65 |
Low ValenceHigh Valence
High Arousal 27 | 28 |
Q2
29 |
33 | 35 |
Q1
36 |
Low Arousal 52 | 54 |
Q3
55 |
58 | 60 |
Q4
61 |
66 | 67 | 68 | ### Number of clips 69 | The following table shows the number of clips in EMOPIA and their average length for each quadrant of Russell’s valence-arousal emotion space. 70 |
71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 |
Quadrant# clipsAvg. length (in sec / #tokens)
Q125031.9 / 1,065
Q226535.6 / 1,368
Q325340.6 / 771
Q431038.2 / 729
101 |
102 | 103 | ### Pipeline of data collection 104 |
105 | 106 |
Fig.1 Pipeline of data collection.
107 |
108 | 109 | 110 | 111 | ### Dataset Analysis 112 | 113 |
114 | 115 |
Fig.2 Violin plots of the distributions of (a) note density, (b) length, and (c) velocity for clips from different classes.
116 |
117 | 118 |
119 | 120 |
Fig.3 Histogram of the keys (left / right: major / minor keys) for clips from different emotion classes.
121 |
122 | 123 | 124 | 125 | ## Emotion Classification 126 | 127 | * We performed the emotion classification task in the audio domain and the symbolic domain, respectively. 128 | 129 | * Our baseline approach is logistic regression using handcrafted features. For the symbolic domain, note length, velocity, beat note density, and key were used; for the audio domain, the average of 20-dimensional mel-frequency cepstral coefficient (MFCC) vectors was used (a minimal sketch follows Fig. 4). 130 | 131 |
132 | 133 |
Fig.4 Symbolic-domain and audio-domain representations.
134 |
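To make the baseline concrete, the following is a minimal, illustrative sketch of how such a clip-level logistic-regression classifier could be assembled with scikit-learn and librosa. The feature extraction, clip paths, and hyperparameters here are assumptions for illustration only; the actual training code lives in the classification repository linked below.

```python
# Minimal sketch of the handcrafted-feature baseline (illustrative only, not
# the exact EMOPIA_cls implementation). Audio side: mean of 20-dim MFCCs.
import numpy as np
import librosa
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

def audio_features(wav_path, n_mfcc=20):
    """Average the 20-dimensional MFCC vectors over time for one clip."""
    y, sr = librosa.load(wav_path, sr=22050, mono=True)
    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=n_mfcc)  # shape: (20, n_frames)
    return mfcc.mean(axis=1)                                # shape: (20,)

# Hypothetical clip paths and quadrant labels (1-4)
clips = ["Q1_clip.wav", "Q3_clip.wav"]
labels = np.array([1, 3])

X = np.stack([audio_features(p) for p in clips])
clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
clf.fit(X, labels)
print(clf.predict(X))  # predicted quadrants
```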
135 | 136 | * The input representation of the deep learning approach uses the mel-spectrogram for the audio domain, and MIDI-like and REMI token sequences for the symbolic domain (see the short extraction sketch after Fig. 5). 137 | 138 |
139 | 140 |
Fig.5 Results of emotion classification.
141 |
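For reference, a log-scaled mel-spectrogram input for the audio branch can be computed as below; the STFT and mel parameters shown are assumptions, not necessarily those used in the paper.

```python
# Illustrative extraction of the audio-domain input (log mel-spectrogram).
# n_fft / hop_length / n_mels values are assumptions.
import numpy as np
import librosa

y, sr = librosa.load("clip.wav", sr=22050, mono=True)        # hypothetical clip
mel = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=1024,
                                     hop_length=512, n_mels=128)
log_mel = librosa.power_to_db(mel, ref=np.max)               # (128, n_frames), CNN input
```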
142 | 143 | * We adopt [A Structured Self-attentive Sentence Embedding](https://arxiv.org/abs/1703.03130) as the symbolic-domain classification network, and [Short-chunk CNN + Residual](https://arxiv.org/abs/2006.00751) as the audio-domain classification network (the attention pooling is sketched below). 144 | 145 | 146 |
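The core of the symbolic-domain network is attention pooling over the encoded token sequence, in the spirit of Lin et al.'s structured self-attention. The PyTorch sketch below shows only that pooling step; the layer sizes, number of attention heads, and upstream token encoder are assumptions rather than the exact EMOPIA configuration.

```python
# Sketch of structured self-attentive pooling (Lin et al., 2017) for
# 4-quadrant classification. Dimensions are illustrative assumptions.
import torch
import torch.nn as nn

class SelfAttentivePooling(nn.Module):
    def __init__(self, hidden=256, att=64, heads=4, n_classes=4):
        super().__init__()
        self.w1 = nn.Linear(hidden, att, bias=False)   # W_s1
        self.w2 = nn.Linear(att, heads, bias=False)    # W_s2
        self.out = nn.Linear(hidden * heads, n_classes)

    def forward(self, h):
        # h: (batch, seq_len, hidden) from any token encoder (e.g. a BiLSTM)
        a = torch.softmax(self.w2(torch.tanh(self.w1(h))), dim=1)  # attention over time
        m = torch.einsum("bsh,bsd->bhd", a, h)                     # (batch, heads, hidden)
        return self.out(m.flatten(1))                              # quadrant logits

logits = SelfAttentivePooling()(torch.randn(2, 128, 256))  # -> shape (2, 4)
```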
147 | 148 |
Fig.6 Inference example of Sakamoto: Merry Christmas Mr. Lawrence.
149 |
150 | 151 | * An inference example is [Sakamoto: Merry Christmas Mr. Lawrence](https://www.youtube.com/watch?v=zOUlPrCRlVM&ab_channel=RyuichiSakamoto-Topic). The emotional change between the first and second halves of the piece is striking: the first half is clearly low arousal, and the second half turns into high arousal. Notably, the audio and MIDI classifiers return different inference results. 152 | 153 | For the classification code, please refer to [SeungHeon's repository](https://github.com/SeungHeonDoh/EMOPIA_cls). 154 | 155 | The pre-trained model weights are also in that repository. 156 | 157 | ## Conditional Generation 158 | 159 | * We adopt the [Compound Word Transformer](https://github.com/YatingMusic/compound-word-transformer) for emotion-conditioned symbolic music generation using EMOPIA. The CP+emotion representation is used as the data representation. 160 | 161 | * In the data representation, we additionally introduce “emotion” tokens and treat them as a new token family. The prepending approach is motivated by [CTRL](https://arxiv.org/abs/1909.05858) (see the sketch after Fig. 7). 162 | 163 |
164 | 165 |
Fig.7 Compound word with emotion token 166 |
167 |
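A schematic sketch of the CTRL-style conditioning is shown below: the chosen emotion token is simply placed in front of the event sequence before training and generation. The token names and IDs are hypothetical placeholders; in the actual CP+emotion representation the emotion is an additional token family inside the compound words (see `workspace/transformer`).

```python
# Schematic, simplified sketch of prepending an emotion condition token.
# Token IDs below are hypothetical, not the real vocabulary.
EMOTION_TOKENS = {"Q1": 1, "Q2": 2, "Q3": 3, "Q4": 4}

def prepend_emotion(event_ids, quadrant):
    """Return the conditioned sequence: [emotion token] + original events."""
    return [EMOTION_TOKENS[quadrant]] + list(event_ids)

seq = prepend_emotion([17, 42, 42, 91], "Q2")   # -> [2, 17, 42, 42, 91]
```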

168 | 169 | 170 | * As EMOPIA alone might not be large enough, we additionally use the [AILabs1k7 dataset](https://github.com/YatingMusic/compound-word-transformer) compiled by Hsiao et al. to pre-train the Transformer. 171 | 172 | * You can download the model weights of the pre-trained Transformer from [here](https://drive.google.com/file/d/19Seq18b2JNzOamEQMG1uarKjj27HJkHu/view?usp=sharing). 173 | 174 | 175 | * The following are some generated examples for each quadrant: 176 | 177 |

Q1 (High valence, high arousal)

178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 |
Baseline
Transformer w/o pre-training
Transformer w/ pre-training
203 | 204 | 205 | 206 |

Q2 (Low valence, high arousal)

207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 229 | 230 | 231 | 232 | 233 |
Baseline
Transformer w/o pre-training
Transformer w/ pre-training
234 | 235 | 236 | 237 |

Q3 (Low valence, low arousal)

238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 |
Baseline
Transformer w/o pre-training
Transformer w/ pre-training
263 | 264 | 265 | 266 |

Q4 (High valence, low arousal)

267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 |
Baseline
Transformer w/o pre-training
Transformer w/ pre-training
292 | 293 | 294 | ## Authors and Affiliations 295 | 296 | * Hsiao-Tzu (Anna) Hung 297 | Research Assistant @ Academia Sinica / MS CSIE student @National Taiwan University 298 | r08922a20@csie.ntu.edu.tw 299 | [website](https://annahung31.github.io/), [LinkedIn](https://www.linkedin.com/in/hsiao-tzu-%EF%BC%88anna%EF%BC%89-hung-09829513a/). 300 | 301 | * Joann Ching 302 | Research Assistant @ Academia Sinica 303 | joann8512@gmail.com 304 | [website](https://joann8512.github.io/JoannChing.github.io/) 305 | 306 | * Seungheon Doh 307 | Ph.D Student @ Music and Audio Computing Lab, KAIST 308 | seungheondoh@kaist.ac.kr 309 | [website](https://seungheondoh.github.io/), [LinkedIn](https://www.linkedin.com/in/dohppak/) 310 | 311 | * Juhan Nam 312 | associate professor @ KAIST 313 | juhan.nam@kaist.ac.kr 314 | [Website](https://mac.kaist.ac.kr/~juhan/) 315 | 316 | * Yi-Hsuan Yang 317 | Chief Music Scientist @ Taiwan AI Labs / Associate Research Fellow @ Academia Sinica 318 | affige@gmail.com, yhyang@ailabs.tw 319 | [website](http://mac.citi.sinica.edu.tw/~yang/) 320 | 321 | 322 | ## Cite this dataset 323 | 324 | {% raw %} 325 | ``` 326 | @inproceedings{ 327 | {EMOPIA}, 328 | author = {Hung, Hsiao-Tzu and Ching, Joann and Doh, Seungheon and Kim, Nabin and Nam, Juhan and Yang, Yi-Hsuan}, 329 | title = {{EMOPIA}: A Multi-Modal Pop Piano Dataset For Emotion Recognition and Emotion-based Music Generation}, 330 | booktitle = {Proc. Int. Society for Music Information Retrieval Conf.}, 331 | year = {2021} 332 | } 333 | ``` 334 | {% endraw %} -------------------------------------------------------------------------------- /docs/init: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # based on workspace/transformer/generate.ipynb 2 | 3 | import subprocess 4 | from pathlib import Path 5 | import tempfile 6 | import os 7 | import pickle 8 | import sys 9 | import torch 10 | import numpy as np 11 | from midiSynth.synth import MidiSynth 12 | import cog 13 | 14 | sys.path.insert(0, "workspace/transformer") 15 | from utils import write_midi 16 | from models import TransformerModel 17 | 18 | 19 | EMOTIONS = { 20 | "High valence, high arousal": 1, 21 | "Low valence, high arousal": 2, 22 | "Low valence, low arousal": 3, 23 | "High valence, low arousal": 4, 24 | } 25 | 26 | 27 | class Predictor(cog.Predictor): 28 | def setup(self): 29 | print("Loading dictionary...") 30 | path_dictionary = "dataset/co-representation/dictionary.pkl" 31 | with open(path_dictionary, "rb") as f: 32 | self.dictionary = pickle.load(f) 33 | event2word, self.word2event = self.dictionary 34 | 35 | n_class = [] # num of classes for each token 36 | for key in event2word.keys(): 37 | n_class.append(len(event2word[key])) 38 | n_token = len(n_class) 39 | 40 | print("Loading model...") 41 | path_saved_ckpt = "exp/pretrained_transformer/loss_25_params.pt" 42 | self.net = TransformerModel(n_class, is_training=False) 43 | self.net.cuda() 44 | self.net.eval() 45 | 46 | self.net.load_state_dict(torch.load(path_saved_ckpt)) 47 | 48 | self.midi_synth = MidiSynth() 49 | 50 | @cog.input( 51 | "emotion", 52 | type=str, 53 | default="High valence, high arousal", 54 | options=EMOTIONS.keys(), 55 | help="Emotion to generate for", 56 | ) 57 | @cog.input("seed", type=int, default=-1, help="Random seed, -1 for random") 58 | def predict(self, emotion, 
seed): 59 | if seed < 0: 60 | seed = int.from_bytes(os.urandom(2), "big") 61 | torch.manual_seed(seed) 62 | np.random.seed(seed) 63 | print(f"Prediction seed: {seed}") 64 | 65 | out_dir = Path(tempfile.mkdtemp()) 66 | midi_path = out_dir / "out.midi" 67 | wav_path = out_dir / "out.wav" 68 | mp3_path = out_dir / "out.mp3" 69 | 70 | emotion_tag = EMOTIONS[emotion] 71 | res, _ = self.net.inference_from_scratch( 72 | self.dictionary, emotion_tag, n_token=8 73 | ) 74 | try: 75 | write_midi(res, str(midi_path), self.word2event) 76 | self.midi_synth.midi2audio(str(midi_path), str(wav_path)) 77 | subprocess.check_output( 78 | [ 79 | "ffmpeg", 80 | "-i", 81 | str(wav_path), 82 | "-af", 83 | "silenceremove=1:0:-50dB,aformat=dblp,areverse,silenceremove=1:0:-50dB,aformat=dblp,areverse", # strip silence 84 | str(mp3_path), 85 | ], 86 | ) 87 | return mp3_path 88 | finally: 89 | midi_path.unlink(missing_ok=True) 90 | wav_path.unlink(missing_ok=True) 91 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn==0.24.1 2 | seaborn==0.11.1 3 | numpy==1.19.5 4 | miditoolkit==0.1.14 5 | pandas==1.1.5 6 | ipdb 7 | tqdm 8 | matplotlib 9 | scipy 10 | pickle 11 | 12 | -------------------------------------------------------------------------------- /workspace/baseline/evolve_generative_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "4" 3 | 4 | import json 5 | import pickle 6 | import argparse 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from midi_generator import generate_midi 11 | from train_classifier import encode_sentence 12 | from train_classifier import get_activated_neurons 13 | from train_generative import build_generative_model 14 | 15 | GEN_MIN = -1 16 | GEN_MAX = 1 17 | 18 | # Directory where trained model will be saved 19 | TRAIN_DIR = "./trained" 20 | 21 | def mutation(individual, mutation_rate): 22 | for i in range(len(individual)): 23 | if np.random.uniform(0, 1) < mutation_rate: 24 | individual[i] = np.random.uniform(GEN_MIN, GEN_MAX) 25 | 26 | def crossover(parent_a, parent_b, ind_size): 27 | # Averaging crossover 28 | return (parent_a + parent_b)/2 29 | 30 | def reproduce(mating_pool, new_population_size, ind_size, mutation_rate): 31 | new_population = np.zeros((new_population_size, ind_size)) 32 | 33 | for i in range(new_population_size): 34 | a = np.random.randint(len(mating_pool)); 35 | b = np.random.randint(len(mating_pool)); 36 | 37 | new_population[i] = crossover(mating_pool[a], mating_pool[b], ind_size); 38 | 39 | # Mutate new children 40 | np.apply_along_axis(mutation, 1, new_population, mutation_rate) 41 | 42 | return new_population; 43 | 44 | def roulette_wheel(population, fitness_pop): 45 | # Normalize fitnesses 46 | norm_fitness_pop = fitness_pop/np.sum(fitness_pop) 47 | 48 | # Here all the fitnesses sum up to 1 49 | r = np.random.uniform(0, 1) 50 | 51 | fitness_so_far = 0 52 | for i in range(len(population)): 53 | fitness_so_far += norm_fitness_pop[i] 54 | 55 | if r < fitness_so_far: 56 | return population[i] 57 | 58 | return population[-1] 59 | 60 | def select(population, fitness_pop, mating_pool_size, ind_size, elite_rate): 61 | mating_pool = np.zeros((mating_pool_size, ind_size)) 62 | 63 | # Apply roulete wheel to select mating_pool_size individuals 64 | for i in range(mating_pool_size): 65 | mating_pool[i] = 
roulette_wheel(population, fitness_pop) 66 | 67 | # Apply elitism 68 | assert elite_rate >= 0 and elite_rate <= 1 69 | elite_size = int(np.ceil(elite_rate * len(population))) 70 | elite_idxs = np.argsort(-fitness_pop) 71 | 72 | for i in range(elite_size): 73 | r = np.random.randint(0, mating_pool_size) 74 | mating_pool[r] = elite_idxs[i] 75 | 76 | return mating_pool 77 | 78 | def calc_fitness(individual, gen_model, cls_model, char2idx, idx2char, layer_idx, sentiment, runs=30): 79 | encoding_size = gen_model.layers[layer_idx].units 80 | generated_midis = np.zeros((runs, encoding_size)) 81 | 82 | # Get activated neurons 83 | sentneuron_ixs = get_activated_neurons(cls_model) 84 | assert len(individual) == len(sentneuron_ixs) 85 | 86 | # Use individual gens to override model neurons 87 | override = {} 88 | for i, ix in enumerate(sentneuron_ixs): 89 | override[ix] = individual[i] 90 | 91 | # Generate pieces and encode them using the cell state of the generative model 92 | for i in range(runs): 93 | midi_text = generate_midi(gen_model, char2idx, idx2char, seq_len=64, layer_idx=layer_idx, override=override) 94 | generated_midis[i] = encode_sentence(gen_model, midi_text, char2idx, layer_idx) 95 | 96 | midis_sentiment = cls_model.predict(generated_midis).clip(min=0) 97 | return 1.0 - np.sum(np.abs(midis_sentiment - sentiment))/runs 98 | 99 | def evaluate(population, gen_model, cls_model, char2idx, idx2char, layer_idx, sentiment): 100 | fitness = np.zeros((len(population), 1)) 101 | 102 | for i in range(len(population)): 103 | fitness[i] = calc_fitness(population[i], gen_model, cls_model, char2idx, idx2char, layer_idx, sentiment) 104 | 105 | return fitness 106 | 107 | def evolve(pop_size, ind_size, mut_rate, elite_rate, epochs): 108 | # Create initial population 109 | population = np.random.uniform(GEN_MIN, GEN_MAX, (pop_size, ind_size)) 110 | 111 | # Evaluate initial population 112 | fitness_pop = evaluate(population, gen_model, cls_model, char2idx, idx2char, opt.cellix, sent) 113 | print("--> Fitness: \n", fitness_pop) 114 | 115 | for i in range(epochs): 116 | print("-> Epoch", i) 117 | 118 | # Select individuals via roulette wheel to form a mating pool 119 | mating_pool = select(population, fitness_pop, pop_size, ind_size, elite_rate) 120 | 121 | # Reproduce matin pool with crossover and mutation to form new population 122 | population = reproduce(mating_pool, pop_size, ind_size, mut_rate) 123 | 124 | # Calculate fitness of each individual of the population 125 | fitness_pop = evaluate(population, gen_model, cls_model, char2idx, idx2char, opt.cellix, sent) 126 | print("--> Fitness: \n", fitness_pop) 127 | 128 | return population, fitness_pop 129 | 130 | if __name__ == "__main__": 131 | 132 | # Parse arguments 133 | parser = argparse.ArgumentParser(description='evolve_generative.py') 134 | parser.add_argument('--genmodel', type=str, default='./trained', help="Generative model to evolve.") 135 | parser.add_argument('--clsmodel', type=str, default='./trained/classifier_ckpt.p', help="Classifier model to calculate fitness.") 136 | parser.add_argument('--ch2ix', type=str, default='./trained/char2idx.json', help="JSON file with char2idx encoding.") 137 | parser.add_argument('--embed', type=int, default=256, help="Embedding size.") 138 | parser.add_argument('--units', type=int, default=512, help="LSTM units.") 139 | parser.add_argument('--layers', type=int, default=4, help="LSTM layers.") 140 | parser.add_argument('--cellix', type=int, default=4, help="LSTM layer to use as encoder.") 141 | 
#parser.add_argument('--sent', type=int, default=2, help="Desired sentiment.") 142 | parser.add_argument('--popsize', type=int, default=10, help="Population size.") 143 | parser.add_argument('--epochs', type=int, default=10, help="Epochs to run.") 144 | parser.add_argument('--mrate', type=float, default=0.1, help="Mutation rate.") 145 | parser.add_argument('--elitism', type=float, default=0.1, help="Elitism in percentage.") 146 | 147 | opt = parser.parse_args() 148 | 149 | # Load char2idx dict from json file 150 | with open(opt.ch2ix) as f: 151 | char2idx = json.load(f) 152 | 153 | # Create idx2char from char2idx dict 154 | idx2char = {idx:char for char,idx in char2idx.items()} 155 | 156 | # Calculate vocab_size from char2idx dict 157 | vocab_size = len(char2idx) 158 | 159 | # Rebuild generative model from checkpoint 160 | gen_model = build_generative_model(vocab_size, opt.embed, opt.units, opt.layers, batch_size=1) 161 | gen_model.load_weights(tf.train.latest_checkpoint(opt.genmodel)) 162 | gen_model.build(tf.TensorShape([1, None])) 163 | 164 | # Load classifier model 165 | with open(opt.clsmodel, "rb") as f: 166 | cls_model = pickle.load(f) 167 | 168 | # Set individual size to the number of activated neurons 169 | sentneuron_ixs = get_activated_neurons(cls_model) 170 | ind_size = len(sentneuron_ixs) 171 | 172 | # Evolve for Positive(1)/Negative(1) 173 | sents = [1, 2, 3, 4] 174 | for sent in sents: 175 | print('evolving for {}'.format(sent)) 176 | population, fitness_pop = evolve(opt.popsize, ind_size, opt.mrate, opt.elitism, opt.epochs) 177 | 178 | # Get best individual 179 | best_idx = np.argmax(fitness_pop) 180 | best_individual = population[best_idx] 181 | 182 | # Use best individual gens to create a dictionary with cell values 183 | neurons = {} 184 | for i, ix in enumerate(sentneuron_ixs): 185 | neurons[str(ix)] = best_individual[i] 186 | 187 | 188 | # Persist dictionary with cell values 189 | if sent == 1: 190 | sent = 'Q1' 191 | elif sent == 2: 192 | sent = 'Q2' 193 | elif sent == 3: 194 | sent = 'Q3' 195 | else: 196 | sent = 'Q4' 197 | 198 | with open(os.path.join(TRAIN_DIR, "neurons_" + sent + ".json"), "w") as f: 199 | json.dump(neurons, f) -------------------------------------------------------------------------------- /workspace/baseline/midi_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import math as ma 5 | import music21 as m21 6 | 7 | THREE_DOTTED_BREVE = 15 8 | THREE_DOTTED_32ND = 0.21875 9 | 10 | MIN_VELOCITY = 0 11 | MAX_VELOCITY = 128 12 | 13 | MIN_TEMPO = 24 14 | MAX_TEMPO = 160 15 | 16 | MAX_PITCH = 128 17 | 18 | def load(datapath, sample_freq=4, piano_range=(33, 93), transpose_range=10, stretching_range=10): 19 | text = "" 20 | vocab = set() 21 | 22 | if os.path.isfile(datapath): 23 | # Path is an individual midi file 24 | file_extension = os.path.splitext(datapath)[1] 25 | 26 | if file_extension == ".midi" or file_extension == ".mid": 27 | text = parse_midi(datapath, sample_freq, piano_range, transpose_range, stretching_range) 28 | vocab = set(text.split(" ")) 29 | else: 30 | # Read every file in the given directory 31 | for file in os.listdir(datapath): 32 | file_path = os.path.join(datapath, file) 33 | file_extension = os.path.splitext(file_path)[1] 34 | 35 | # Check if it is not a directory and if it has either .midi or .mid extentions 36 | if os.path.isfile(file_path) and (file_extension == ".midi" or file_extension == ".mid"): 37 | encoded_midi = 
parse_midi(file_path, sample_freq, piano_range, transpose_range, stretching_range) 38 | 39 | if len(encoded_midi) > 0: 40 | words = set(encoded_midi.split(" ")) 41 | vocab = vocab | words 42 | 43 | text += encoded_midi + " " 44 | 45 | # Remove last space 46 | text = text[:-1] 47 | 48 | return text, vocab 49 | 50 | def parse_midi(file_path, sample_freq, piano_range, transpose_range, stretching_range): 51 | 52 | # Split datapath into dir and filename 53 | midi_dir = os.path.dirname(file_path) 54 | midi_name = os.path.basename(file_path).split(".")[0] 55 | 56 | # If txt version of the midi already exists, load data from it 57 | midi_txt_name = os.path.join(midi_dir, midi_name + ".txt") 58 | 59 | if(os.path.isfile(midi_txt_name)): 60 | midi_fp = open(midi_txt_name, "r") 61 | encoded_midi = midi_fp.read() 62 | else: 63 | # Create a music21 stream and open the midi file 64 | midi = m21.midi.MidiFile() 65 | midi.open(file_path) 66 | midi.read() 67 | midi.close() 68 | 69 | # Translate midi to stream of notes and chords 70 | encoded_midi = midi2encoding(midi, sample_freq, piano_range, transpose_range, stretching_range) 71 | 72 | if len(encoded_midi) > 0: 73 | midi_fp = open(midi_txt_name, "w+") 74 | midi_fp.write(encoded_midi) 75 | midi_fp.flush() 76 | 77 | midi_fp.close() 78 | return encoded_midi 79 | 80 | def midi2encoding(midi, sample_freq, piano_range, transpose_range, stretching_range): 81 | try: 82 | midi_stream = m21.midi.translate.midiFileToStream(midi) 83 | except: 84 | return [] 85 | 86 | # Get piano roll from midi stream 87 | piano_roll = midi2piano_roll(midi_stream, sample_freq, piano_range, transpose_range, stretching_range) 88 | 89 | # Get encoded midi from piano roll 90 | encoded_midi = piano_roll2encoding(piano_roll) 91 | 92 | return " ".join(encoded_midi) 93 | 94 | def piano_roll2encoding(piano_roll): 95 | # Transform piano roll into a list of notes in string format 96 | final_encoding = {} 97 | 98 | perform_i = 0 99 | for version in piano_roll: 100 | lastTempo = -1 101 | lastVelocity = -1 102 | lastDuration = -1.0 103 | 104 | version_encoding = [] 105 | 106 | for i in range(len(version)): 107 | # Time events are stored at the last row 108 | tempo = version[i,-1][0] 109 | if tempo != 0 and tempo != lastTempo: 110 | version_encoding.append("t_" + str(int(tempo))) 111 | lastTempo = tempo 112 | 113 | # Process current time step of the piano_roll 114 | for j in range(len(version[i]) - 1): 115 | duration = version[i,j][0] 116 | velocity = int(version[i,j][1]) 117 | 118 | if velocity != 0 and velocity != lastVelocity: 119 | version_encoding.append("v_" + str(velocity)) 120 | lastVelocity = velocity 121 | 122 | if duration != 0 and duration != lastDuration: 123 | duration_tuple = m21.duration.durationTupleFromQuarterLength(duration) 124 | version_encoding.append("d_" + duration_tuple.type + "_" + str(duration_tuple.dots)) 125 | lastDuration = duration 126 | 127 | if duration != 0 and velocity != 0: 128 | version_encoding.append("n_" + str(j)) 129 | 130 | # End of time step 131 | if len(version_encoding) > 0 and version_encoding[-1][0] == "w": 132 | # Increase wait by one 133 | version_encoding[-1] = "w_" + str(int(version_encoding[-1].split("_")[1]) + 1) 134 | else: 135 | version_encoding.append("w_1") 136 | 137 | # End of piece 138 | version_encoding.append("\n") 139 | 140 | # Check if this version of the MIDI is already added 141 | version_encoding_str = " ".join(version_encoding) 142 | if version_encoding_str not in final_encoding: 143 | final_encoding[version_encoding_str] = perform_i 
144 | 145 | perform_i += 1 146 | 147 | return final_encoding.keys() 148 | 149 | def write(encoded_midi, path): 150 | # Base class checks if output path exists 151 | midi = encoding2midi(encoded_midi) 152 | midi.open(path, "wb") 153 | midi.write() 154 | midi.close() 155 | 156 | def encoding2midi(note_encoding, ts_duration=0.25): 157 | notes = [] 158 | 159 | velocity = 100 160 | duration = "16th" 161 | dots = 0 162 | 163 | ts = 0 164 | for note in note_encoding.split(" "): 165 | if len(note) == 0: 166 | continue 167 | 168 | elif note[0] == "w": 169 | wait_count = int(note.split("_")[1]) 170 | ts += wait_count 171 | 172 | elif note[0] == "n": 173 | pitch = int(note.split("_")[1]) 174 | note = m21.note.Note(pitch) 175 | note.duration = m21.duration.Duration(type=duration, dots=dots) 176 | note.offset = ts * ts_duration 177 | note.volume.velocity = velocity 178 | notes.append(note) 179 | 180 | elif note[0] == "d": 181 | duration = note.split("_")[1] 182 | dots = int(note.split("_")[2]) 183 | 184 | elif note[0] == "v": 185 | velocity = int(note.split("_")[1]) 186 | 187 | elif note[0] == "t": 188 | tempo = int(note.split("_")[1]) 189 | 190 | if tempo > 0: 191 | mark = m21.tempo.MetronomeMark(number=tempo) 192 | mark.offset = ts * ts_duration 193 | notes.append(mark) 194 | 195 | piano = m21.instrument.fromString("Piano") 196 | notes.insert(0, piano) 197 | 198 | piano_stream = m21.stream.Stream(notes) 199 | main_stream = m21.stream.Stream([piano_stream]) 200 | 201 | return m21.midi.translate.streamToMidiFile(main_stream) 202 | 203 | def midi_parse_notes(midi_stream, sample_freq): 204 | note_filter = m21.stream.filters.ClassFilter('Note') 205 | 206 | note_events = [] 207 | for note in midi_stream.recurse().addFilter(note_filter): 208 | pitch = note.pitch.midi 209 | duration = note.duration.quarterLength 210 | velocity = note.volume.velocity 211 | offset = ma.floor(note.offset * sample_freq) 212 | 213 | note_events.append((pitch, duration, velocity, offset)) 214 | 215 | return note_events 216 | 217 | def midi_parse_chords(midi_stream, sample_freq): 218 | chord_filter = m21.stream.filters.ClassFilter('Chord') 219 | 220 | note_events = [] 221 | for chord in midi_stream.recurse().addFilter(chord_filter): 222 | pitches_in_chord = chord.pitches 223 | for pitch in pitches_in_chord: 224 | pitch = pitch.midi 225 | duration = chord.duration.quarterLength 226 | velocity = chord.volume.velocity 227 | offset = ma.floor(chord.offset * sample_freq) 228 | 229 | note_events.append((pitch, duration, velocity, offset)) 230 | 231 | return note_events 232 | 233 | def midi_parse_metronome(midi_stream, sample_freq): 234 | metronome_filter = m21.stream.filters.ClassFilter('MetronomeMark') 235 | 236 | time_events = [] 237 | for metro in midi_stream.recurse().addFilter(metronome_filter): 238 | time = int(metro.number) 239 | offset = ma.floor(metro.offset * sample_freq) 240 | time_events.append((time, offset)) 241 | 242 | return time_events 243 | 244 | def midi2notes(midi_stream, sample_freq, transpose_range): 245 | notes = [] 246 | notes += midi_parse_notes(midi_stream, sample_freq) 247 | notes += midi_parse_chords(midi_stream, sample_freq) 248 | 249 | # Transpose the notes to all the keys in transpose_range 250 | return transpose_notes(notes, transpose_range) 251 | 252 | def midi2piano_roll(midi_stream, sample_freq, piano_range, transpose_range, stretching_range): 253 | # Calculate the amount of time steps in the piano roll 254 | time_steps = ma.floor(midi_stream.duration.quarterLength * sample_freq) + 1 255 | 256 | # Parse 
the midi file into a list of notes (pitch, duration, velocity, offset) 257 | transpositions = midi2notes(midi_stream, sample_freq, transpose_range) 258 | 259 | time_events = midi_parse_metronome(midi_stream, sample_freq) 260 | time_streches = strech_time(time_events, stretching_range) 261 | 262 | return notes2piano_roll(transpositions, time_streches, time_steps, piano_range) 263 | 264 | def notes2piano_roll(transpositions, time_streches, time_steps, piano_range): 265 | performances = [] 266 | 267 | min_pitch, max_pitch = piano_range 268 | for t_ix in range(len(transpositions)): 269 | for s_ix in range(len(time_streches)): 270 | # Create piano roll with calcualted size. 271 | # Add one dimension to very entry to store velocity and duration. 272 | piano_roll = np.zeros((time_steps, MAX_PITCH + 1, 2)) 273 | 274 | for note in transpositions[t_ix]: 275 | pitch, duration, velocity, offset = note 276 | if duration == 0.0: 277 | continue 278 | 279 | # Force notes to be inside the specified piano_range 280 | pitch = clamp_pitch(pitch, max_pitch, min_pitch) 281 | 282 | piano_roll[offset, pitch][0] = clamp_duration(duration) 283 | piano_roll[offset, pitch][1] = discretize_value(velocity, bins=32, range=(MIN_VELOCITY, MAX_VELOCITY)) 284 | 285 | for time_event in time_streches[s_ix]: 286 | time, offset = time_event 287 | piano_roll[offset, -1][0] = discretize_value(time, bins=100, range=(MIN_TEMPO, MAX_TEMPO)) 288 | 289 | performances.append(piano_roll) 290 | 291 | return performances 292 | 293 | def transpose_notes(notes, transpose_range): 294 | transpositions = [] 295 | 296 | # Modulate the piano_roll for other keys 297 | first_key = -ma.floor(transpose_range/2) 298 | last_key = ma.ceil(transpose_range/2) 299 | 300 | for key in range(first_key, last_key): 301 | notes_in_key = [] 302 | for n in notes: 303 | pitch, duration, velocity, offset = n 304 | t_pitch = pitch + key 305 | notes_in_key.append((t_pitch, duration, velocity, offset)) 306 | transpositions.append(notes_in_key) 307 | 308 | return transpositions 309 | 310 | def strech_time(time_events, stretching_range): 311 | streches = [] 312 | 313 | # Modulate the piano_roll for other keys 314 | slower_time = -ma.floor(stretching_range/2) 315 | faster_time = ma.ceil(stretching_range/2) 316 | 317 | # Modulate the piano_roll for other keys 318 | for t_strech in range(slower_time, faster_time): 319 | time_events_in_strech = [] 320 | for t_ev in time_events: 321 | time, offset = t_ev 322 | s_time = time + 0.05 * t_strech * MAX_TEMPO 323 | time_events_in_strech.append((s_time, offset)) 324 | streches.append(time_events_in_strech) 325 | 326 | return streches 327 | 328 | def discretize_value(val, bins, range): 329 | min_val, max_val = range 330 | 331 | val = int(max(min_val, val)) 332 | val = int(min(val, max_val)) 333 | 334 | bin_size = (max_val/bins) 335 | return ma.floor(val/bin_size) * bin_size 336 | 337 | def clamp_pitch(pitch, max, min): 338 | while pitch < min: 339 | pitch += 12 340 | while pitch >= max: 341 | pitch -= 12 342 | return pitch 343 | 344 | def clamp_duration(duration, max=THREE_DOTTED_BREVE, min=THREE_DOTTED_32ND): 345 | # Max duration is 3-dotted breve 346 | if duration > max: 347 | duration = max 348 | 349 | # min duration is 3-dotted breve 350 | if duration < min: 351 | duration = min 352 | 353 | duration_tuple = m21.duration.durationTupleFromQuarterLength(duration) 354 | if duration_tuple.type == "inexpressible": 355 | duration_clossest_type = m21.duration.quarterLengthToClosestType(duration)[0] 356 | duration = 
m21.duration.typeToDuration[duration_clossest_type] 357 | 358 | return duration 359 | 360 | if __name__ == "__main__": 361 | 362 | # Parse arguments 363 | parser = argparse.ArgumentParser(description='midi_encoder.py') 364 | parser.add_argument('--path', type=str, required=True, help="Path to midi data.") 365 | parser.add_argument('--transp', type=int, default=1, help="Transpose range.") 366 | parser.add_argument('--strech', type=int, default=1, help="Time stretching range.") 367 | opt = parser.parse_args() 368 | 369 | # Load data and encoded it 370 | text, vocab = load(opt.path, transpose_range=opt.transp, stretching_range=opt.strech) 371 | print(text) 372 | 373 | # Write all data to midi file 374 | write(text, "encoded.mid") 375 | -------------------------------------------------------------------------------- /workspace/baseline/midi_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "4" 3 | 4 | import json 5 | import argparse 6 | import numpy as np 7 | import tensorflow as tf 8 | import midi_encoder as me 9 | 10 | from train_generative import build_generative_model 11 | from train_classifier import preprocess_sentence 12 | 13 | GENERATED_DIR = './generated' 14 | 15 | def override_neurons(model, layer_idx, override): 16 | h_state, c_state = model.get_layer(index=layer_idx).states 17 | 18 | c_state = c_state.numpy() 19 | for neuron, value in override.items(): 20 | c_state[:,int(neuron)] = int(value) 21 | 22 | model.get_layer(index=layer_idx).states = (h_state, tf.Variable(c_state)) 23 | 24 | def sample_next(predictions, k): 25 | # Sample using a categorical distribution over the top k midi chars 26 | top_k = tf.math.top_k(predictions, k) 27 | top_k_choices = top_k[1].numpy().squeeze() 28 | top_k_values = top_k[0].numpy().squeeze() 29 | 30 | if np.random.uniform(0, 1) < .5: 31 | predicted_id = top_k_choices[0] 32 | else: 33 | p_choices = tf.math.softmax(top_k_values[1:]).numpy() 34 | predicted_id = np.random.choice(top_k_choices[1:], 1, p=p_choices)[0] 35 | 36 | return predicted_id 37 | 38 | def process_init_text(model, init_text, char2idx, layer_idx, override): 39 | model.reset_states() 40 | 41 | for c in init_text.split(" "): 42 | # Run a forward pass 43 | try: 44 | input_eval = tf.expand_dims([char2idx[c]], 0) 45 | 46 | # override sentiment neurons 47 | override_neurons(model, layer_idx, override) 48 | 49 | predictions = model(input_eval) 50 | except KeyError: 51 | if c != "": 52 | print("Can't process char", s) 53 | 54 | return predictions 55 | 56 | def generate_midi(model, char2idx, idx2char, init_text="", seq_len=256, k=3, layer_idx=-2, override={}): 57 | # Add front and end pad to the initial text 58 | init_text = preprocess_sentence(init_text) 59 | 60 | # Empty midi to store our results 61 | midi_generated = [] 62 | 63 | # Process initial text 64 | predictions = process_init_text(model, init_text, char2idx, layer_idx, override) 65 | 66 | # Here batch size == 1 67 | model.reset_states() 68 | for i in range(seq_len): 69 | # remove the batch dimension 70 | predictions = tf.squeeze(predictions, 0).numpy() 71 | 72 | # Sample using a categorical distribution over the top k midi chars 73 | predicted_id = sample_next(predictions, k) 74 | 75 | # Append it to generated midi 76 | midi_generated.append(idx2char[predicted_id]) 77 | 78 | # override sentiment neurons 79 | override_neurons(model, layer_idx, override) 80 | 81 | #Run a new forward pass 82 | input_eval = tf.expand_dims([predicted_id], 
0) 83 | predictions = model(input_eval) 84 | 85 | return init_text + " " + " ".join(midi_generated) 86 | 87 | if __name__ == "__main__": 88 | 89 | # Parse arguments 90 | parser = argparse.ArgumentParser(description='midi_generator.py') 91 | parser.add_argument('--model', type=str, default='./trained', help="Checkpoint dir.") 92 | parser.add_argument('--ch2ix', type=str, default='./trained/char2idx.json', help="JSON file with char2idx encoding.") 93 | parser.add_argument('--embed', type=int, default=256, help="Embedding size.") 94 | parser.add_argument('--units', type=int, default=512, help="LSTM units.") 95 | parser.add_argument('--layers', type=int, default=4, help="LSTM layers.") 96 | parser.add_argument('--seqinit', type=str, default="\n", help="Sequence init.") 97 | parser.add_argument('--seqlen', type=int, default=512, help="Sequence lenght.") 98 | parser.add_argument('--cellix', type=int, default=4, help="LSTM layer to use as encoder.") 99 | parser.add_argument('--override', type=str, default="./trained/neurons_Q1.json", help="JSON file with neuron values to override.") 100 | opt = parser.parse_args() 101 | 102 | # Load char2idx dict from json file 103 | with open(opt.ch2ix) as f: 104 | char2idx = json.load(f) 105 | 106 | # Load override dict from json file 107 | override = {} 108 | 109 | try: 110 | with open(opt.override) as f: 111 | override = json.load(f) 112 | except FileNotFoundError: 113 | print("Override JSON file not provided.") 114 | 115 | # Create idx2char from char2idx dict 116 | idx2char = {idx:char for char,idx in char2idx.items()} 117 | 118 | # Calculate vocab_size from char2idx dict 119 | vocab_size = len(char2idx) 120 | 121 | # Rebuild model from checkpoint 122 | model = build_generative_model(vocab_size, opt.embed, opt.units, opt.layers, batch_size=1) 123 | model.load_weights(tf.train.latest_checkpoint(opt.model)) 124 | model.build(tf.TensorShape([1, None])) 125 | 126 | if not os.path.exists(GENERATED_DIR): 127 | os.makedirs(GENERATED_DIR) 128 | # Generate 5 midis 129 | for i in range(100): 130 | # Generate a midi as text 131 | print("Generate midi {}".format(i)) 132 | midi_txt = generate_midi(model, char2idx, idx2char, opt.seqinit, opt.seqlen, layer_idx=opt.cellix, override=override) 133 | 134 | 135 | me.write(midi_txt, os.path.join(GENERATED_DIR, "generated_Q1_{}.mid".format(i))) 136 | -------------------------------------------------------------------------------- /workspace/baseline/plot_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | # Directory where plots will be saved 9 | PLOTS_DIR = './results' 10 | 11 | def plot_logits(xs, ys, top_neurons): 12 | for n in top_neurons: 13 | plot_logit_and_save(xs, ys, n) 14 | 15 | def plot_logit_and_save(xs, ys, neuron_index): 16 | sentiment_unit = xs[:,neuron_index] 17 | 18 | # plt.title('Distribution of Logit Values') 19 | plt.ylabel('Number of Phrases') 20 | plt.xlabel('Value of the Sentiment Neuron') 21 | plt.hist(sentiment_unit[ys == -1], bins=50, alpha=0.5, label='Negative Phrases') 22 | plt.hist(sentiment_unit[ys == 1], bins=50, alpha=0.5, label='Positive Phrases') 23 | plt.legend() 24 | plt.savefig(os.path.join(PLOTS_DIR, "neuron_" + str(neuron_index) + '.png')) 25 | plt.clf() 26 | 27 | def plot_weight_contribs(coef): 28 | plt.title('Values of Resulting L1 Penalized Weights') 29 | plt.tick_params(axis='both', which='major') 30 | 31 | # 
Normalize weight contributions 32 | norm = np.linalg.norm(coef) 33 | coef = coef/norm 34 | 35 | plt.plot(range(len(coef[0])), coef.T) 36 | plt.xlabel('Neuron (Feature) Index') 37 | plt.ylabel('Neuron (Feature) weight') 38 | plt.savefig(os.path.join(PLOTS_DIR, "weight_contribs.png")) 39 | plt.clf() 40 | -------------------------------------------------------------------------------- /workspace/baseline/plot_results_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | # Directory where plots will be saved 9 | PLOTS_DIR = './results_base' 10 | 11 | def plot_logits(xs, ys, top_neurons): 12 | for n in top_neurons: 13 | plot_logit_and_save(xs, ys, n) 14 | 15 | def plot_logit_and_save(xs, ys, neuron_index): 16 | sentiment_unit = xs[:,neuron_index] 17 | 18 | # plt.title('Distribution of Logit Values') 19 | plt.ylabel('Number of Phrases') 20 | plt.xlabel('Value of the Sentiment Neuron') 21 | plt.hist(sentiment_unit[ys == -1], bins=50, alpha=0.5, label='Negative Phrases') 22 | plt.hist(sentiment_unit[ys == 1], bins=50, alpha=0.5, label='Positive Phrases') 23 | plt.legend() 24 | plt.savefig(os.path.join(PLOTS_DIR, "neuron_" + str(neuron_index) + '.png')) 25 | plt.clf() 26 | 27 | def plot_weight_contribs(coef): 28 | plt.title('Values of Resulting L1 Penalized Weights') 29 | plt.tick_params(axis='both', which='major') 30 | 31 | # Normalize weight contributions 32 | norm = np.linalg.norm(coef) 33 | coef = coef/norm 34 | 35 | plt.plot(range(len(coef[0])), coef.T) 36 | plt.xlabel('Neuron (Feature) Index') 37 | plt.ylabel('Neuron (Feature) weight') 38 | plt.savefig(os.path.join(PLOTS_DIR, "weight_contribs.png")) 39 | plt.clf() 40 | -------------------------------------------------------------------------------- /workspace/baseline/train_classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "4" 3 | 4 | import csv 5 | import json 6 | import pickle 7 | import argparse 8 | import numpy as np 9 | import tensorflow as tf 10 | import midi_encoder as me 11 | import plot_results_base as pr 12 | 13 | from train_generative import build_generative_model 14 | from sklearn.linear_model import LogisticRegression 15 | from sklearn.preprocessing import LabelBinarizer 16 | 17 | # Directory where trained model will be saved 18 | TRAIN_DIR = "./trained" 19 | DATA_TRAIN = '../data/train_split.csv' 20 | DATA_TEST = '../data/test_split.csv' 21 | DATASET_PATH = '../' 22 | 23 | def preprocess_sentence(text, front_pad='\n ', end_pad=''): 24 | text = text.replace('\n', ' ').strip() 25 | text = front_pad+text+end_pad 26 | return text 27 | 28 | def encode_sentence(model, text, char2idx, layer_idx): 29 | text = preprocess_sentence(text) 30 | 31 | # Reset LSTMs hidden and cell states 32 | model.reset_states() 33 | 34 | for c in text.split(" "): 35 | # Add the batch dimension 36 | try: 37 | input_eval = tf.expand_dims([char2idx[c]], 0) 38 | predictions = model(input_eval) 39 | except KeyError: 40 | if c != "": 41 | print("Can't process char", c) 42 | 43 | h_state, c_state = model.get_layer(index=layer_idx).states 44 | 45 | # remove the batch dimension 46 | #h_state = tf.squeeze(h_state, 0) 47 | c_state = tf.squeeze(c_state, 0) 48 | 49 | return tf.math.tanh(c_state).numpy() 50 | 51 | def build_dataset(datapath, generative_model, char2idx, layer_idx): 52 | xs, ys = [], [] 53 | 
54 | csv_file = open(datapath, "r") 55 | data = csv.DictReader(csv_file) 56 | 57 | for row in data: 58 | label = int(row["label"]) 59 | filepath = row["filepath"] 60 | 61 | data_dir = os.path.dirname(datapath) 62 | phrase_path = filepath.replace('./', '') 63 | phrase_path = os.path.join(DATASET_PATH, phrase_path) 64 | encoded_path = os.path.splitext(filepath)[0]+'.npy' 65 | encoded_path = encoded_path.replace('./', '') 66 | encoded_path = os.path.join(DATASET_PATH, encoded_path) 67 | 68 | # Load midi file as text 69 | if os.path.isfile(encoded_path): 70 | encoding = np.load(encoded_path) 71 | else: 72 | text, vocab = me.load(phrase_path, transpose_range=1, stretching_range=1) 73 | 74 | # Encode midi text using generative lstm 75 | encoding = encode_sentence(generative_model, text, char2idx, layer_idx) 76 | 77 | # Save encoding in file to make it faster to load next time 78 | np.save(encoded_path, encoding) 79 | 80 | xs.append(encoding) 81 | ys.append(label) 82 | 83 | return np.array(xs), np.array(ys) 84 | 85 | def train_classifier_model(train_dataset, test_dataset, C=2**np.arange(-8, 1).astype(np.float), seed=42, penalty="l1"): 86 | trX, trY = train_dataset 87 | teX, teY = test_dataset 88 | 89 | scores = [] 90 | 91 | # Hyper-parameter optimization 92 | for i, c in enumerate(C): 93 | logreg_model = LogisticRegression(C=c, penalty=penalty, random_state=seed+i, solver="liblinear") 94 | logreg_model.fit(trX, trY) 95 | 96 | score = logreg_model.score(teX, teY) 97 | scores.append(score) 98 | 99 | c = C[np.argmax(scores)] 100 | 101 | sent_classfier = LogisticRegression(C=c, penalty=penalty, random_state=seed+len(C), solver="liblinear") 102 | sent_classfier.fit(trX, trY) 103 | 104 | score = sent_classfier.score(teX, teY) * 100. 105 | 106 | # Persist sentiment classifier 107 | with open(os.path.join(TRAIN_DIR, "classifier_ckpt.p"), "wb") as f: 108 | pickle.dump(sent_classfier, f) 109 | 110 | # Get activated neurons 111 | sentneuron_ixs = get_activated_neurons(sent_classfier) 112 | 113 | # Plot results 114 | pr.plot_weight_contribs(sent_classfier.coef_) 115 | pr.plot_logits(trX, trY, sentneuron_ixs) 116 | 117 | return sentneuron_ixs, score 118 | 119 | def get_activated_neurons(sent_classfier): 120 | neurons_not_zero = len(np.argwhere(sent_classfier.coef_)) 121 | 122 | weights = sent_classfier.coef_.T 123 | weight_penalties = np.squeeze(np.linalg.norm(weights, ord=1, axis=1)) 124 | 125 | if neurons_not_zero == 1: 126 | neuron_ixs = np.array([np.argmax(weight_penalties)]) 127 | elif neurons_not_zero >= np.log(len(weight_penalties)): 128 | neuron_ixs = np.argsort(weight_penalties)[-neurons_not_zero:][::-1] 129 | else: 130 | neuron_ixs = np.argpartition(weight_penalties, -neurons_not_zero)[-neurons_not_zero:] 131 | neuron_ixs = (neuron_ixs[np.argsort(weight_penalties[neuron_ixs])])[::-1] 132 | 133 | return neuron_ixs 134 | 135 | if __name__ == "__main__": 136 | 137 | # Parse arguments 138 | parser = argparse.ArgumentParser(description='train_classifier.py') 139 | parser.add_argument('--train', type=str, default=DATA_TRAIN, help="Train dataset.") 140 | parser.add_argument('--test' , type=str, default=DATA_TEST, help="Test dataset.") 141 | parser.add_argument('--model', type=str, default='./trained/', help="Checkpoint dir.") 142 | parser.add_argument('--ch2ix', type=str, default='./trained/char2idx.json', help="JSON file with char2idx encoding.") 143 | parser.add_argument('--embed', type=int, default=256, help="Embedding size.") 144 | parser.add_argument('--units', type=int, default=512, help="LSTM 
units.") 145 | parser.add_argument('--layers', type=int, default=4, help="LSTM layers.") 146 | parser.add_argument('--cellix', type=int, default=4, help="LSTM layer to use as encoder.") 147 | opt = parser.parse_args() 148 | 149 | # Load char2idx dict from json file 150 | with open(opt.ch2ix) as f: 151 | char2idx = json.load(f) 152 | 153 | # Calculate vocab_size from char2idx dict 154 | vocab_size = len(char2idx) 155 | 156 | # Rebuild generative model from checkpoint 157 | generative_model = build_generative_model(vocab_size, opt.embed, opt.units, opt.layers, batch_size=1) 158 | generative_model.load_weights(tf.train.latest_checkpoint(opt.model)) 159 | generative_model.build(tf.TensorShape([1, None])) 160 | 161 | # Build dataset from encoded labelled midis 162 | train_dataset = build_dataset(opt.train, generative_model, char2idx, opt.cellix) 163 | test_dataset = build_dataset(opt.test, generative_model, char2idx, opt.cellix) 164 | 165 | # Train model 166 | sentneuron_ixs, score = train_classifier_model(train_dataset, test_dataset) 167 | 168 | print("Total Neurons Used:", len(sentneuron_ixs), "\n", sentneuron_ixs) 169 | print("Test Accuracy:", score) 170 | -------------------------------------------------------------------------------- /workspace/baseline/train_generative.py: -------------------------------------------------------------------------------- 1 | import os 2 | #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"\ 3 | os.environ["CUDA_VISIBLE_DEVICES"] = '4' 4 | import json 5 | import argparse 6 | import numpy as np 7 | import tensorflow as tf 8 | import numpy as np 9 | 10 | 11 | import midi_encoder as me 12 | 13 | # Directory where the checkpoints will be saved 14 | TRAIN_DIR = "./trained/" 15 | 16 | #print('Version: ', tf.__version__) 17 | #print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU'))) 18 | #print("Check: ", tf.test.is_gpu_available(cuda_only=False, min_cuda_compute_capability=None)) 19 | 20 | 21 | def generative_loss(labels, logits): 22 | return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True) 23 | 24 | def build_generative_model(vocab_size, embed_dim, lstm_units, lstm_layers, batch_size, dropout=0): 25 | model = tf.keras.Sequential() 26 | 27 | model.add(tf.keras.layers.Embedding(vocab_size, embed_dim, batch_input_shape=[batch_size, None])) 28 | 29 | for i in range(max(1, lstm_layers)): 30 | model.add(tf.keras.layers.LSTM(lstm_units, return_sequences=True, stateful=True, dropout=dropout, recurrent_dropout=dropout)) 31 | 32 | model.add(tf.keras.layers.Dense(vocab_size)) 33 | 34 | return model 35 | 36 | def build_char2idx(train_vocab, test_vocab): 37 | # Merge train and test vocabulary 38 | vocab = list(train_vocab | test_vocab) 39 | vocab.sort() 40 | 41 | # Calculate vocab size 42 | vocab_size = len(vocab) 43 | 44 | # Create dict to support char to index conversion 45 | char2idx = { char:i for i,char in enumerate(vocab) } 46 | 47 | # Save char2idx encoding as a json file for generate midi later 48 | with open(os.path.join(TRAIN_DIR, "char2idx.json"), "w") as f: 49 | json.dump(char2idx, f) 50 | 51 | return char2idx, vocab_size 52 | 53 | def build_dataset(text, char2idx, seq_length, batch_size, buffer_size=10000): 54 | text_as_int = np.array([char2idx[c] for c in text.split(" ")]) 55 | char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int) 56 | 57 | sequences = char_dataset.batch(seq_length+1, drop_remainder=True) 58 | 59 | dataset = sequences.map(__split_input_target) 60 | dataset = 
dataset.shuffle(buffer_size).batch(batch_size, drop_remainder=True) 61 | 62 | return dataset 63 | 64 | def train_generative_model(model, train_dataset, test_dataset, epochs, learning_rate): 65 | # Compile model with given optimizer and defined loss 66 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) 67 | model.compile(optimizer=optimizer, loss=generative_loss) 68 | 69 | # Name of the checkpoint files 70 | checkpoint_prefix = os.path.join(TRAIN_DIR, "generative_ckpt_{epoch}") 71 | my_callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, save_weights_only=True), 72 | tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=3, restore_best_weights=True)] 73 | 74 | 75 | return model.fit(train_dataset, epochs=epochs, validation_data=test_dataset, callbacks=my_callbacks) 76 | 77 | def __split_input_target(chunk): 78 | input_text = chunk[:-1] 79 | target_text = chunk[1:] 80 | return input_text, target_text 81 | 82 | if __name__ == "__main__": 83 | 84 | # Parse arguments 85 | parser = argparse.ArgumentParser(description='train_generative.py') 86 | parser.add_argument('--train', type=str, default='../data/train/', help="Train dataset.") 87 | parser.add_argument('--test' , type=str, default='../data/test/', help="Test dataset.") 88 | parser.add_argument('--model', type=str, required=False, help="Checkpoint dir.") 89 | parser.add_argument('--embed', type=int, default=256, help="Embedding size.") 90 | parser.add_argument('--units', type=int, default=512, help="LSTM units.") 91 | parser.add_argument('--layers', type=int, default=4, help="LSTM layers.") 92 | parser.add_argument('--batch', type=int, default=64, help="Batch size.") 93 | parser.add_argument('--epochs', type=int, default=30, help="Epochs.") 94 | parser.add_argument('--seqlen', type=int, default=256, help="Sequence lenght.") 95 | parser.add_argument('--lrate', type=float, default=0.00001, help="Learning rate.") 96 | parser.add_argument('--drop', type=float, default=0.05, help="Dropout.") 97 | opt = parser.parse_args() 98 | 99 | if not os.path.exists(TRAIN_DIR): 100 | os.makedirs(TRAIN_DIR) 101 | 102 | # Encode midi files as text with vocab 103 | train_text, train_vocab = me.load(opt.train) 104 | test_text, test_vocab = me.load(opt.test) 105 | 106 | # Build dictionary to map from char to integers 107 | char2idx, vocab_size = build_char2idx(train_vocab, test_vocab) 108 | 109 | # Build dataset from encoded unlabelled midis 110 | train_dataset = build_dataset(train_text, char2idx, opt.seqlen, opt.batch) 111 | test_dataset = build_dataset(test_text, char2idx, opt.seqlen, opt.batch) 112 | 113 | # Build generative model 114 | generative_model = build_generative_model(vocab_size, opt.embed, opt.units, opt.layers, opt.batch, opt.drop) 115 | 116 | if opt.model: 117 | # If pre-trained model was given as argument, load weights from disk 118 | print("Loading weights from {}...".format(opt.model)) 119 | generative_model.load_weights(tf.train.latest_checkpoint(opt.model)) 120 | 121 | # Train model 122 | history = train_generative_model(generative_model, train_dataset, test_dataset, opt.epochs, opt.lrate) 123 | print("Total of {} epochs used for training.".format(len(history.history['loss']))) 124 | loss_hist = history.history['loss'] 125 | print("Best loss from history: ", np.min(loss_hist)) 126 | -------------------------------------------------------------------------------- /workspace/transformer/main_cp.py: -------------------------------------------------------------------------------- 1 | 
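# Emotion-conditioned compound-word (CP) Transformer: training / inference entry point.
# Illustrative invocations only (flag names come from the argparse setup below; the
# experiment name, checkpoint loss tag and output directory are placeholders):
#
#   python main_cp.py --mode train --task_type 4-cls --exp_name my_run \
#       --data_root ../dataset/co-representation/ --path_train_data emopia
#
#   python main_cp.py --mode inference --load_ckt my_run --load_ckt_loss 25 \
#       --num_songs 5 --emo_tag 1 --out_dir ./gen_midis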
import sys 2 | import os 3 | import ipdb 4 | from tqdm import tqdm 5 | import math 6 | import time 7 | import glob 8 | import datetime 9 | import random 10 | import pickle 11 | import json 12 | import numpy as np 13 | from collections import OrderedDict 14 | from argparse import ArgumentParser 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.optim as optim 20 | from torch.nn.utils import clip_grad_norm_ 21 | from torch.utils.data import Dataset, DataLoader 22 | 23 | 24 | import saver 25 | from models import TransformerModel, network_paras 26 | from utils import write_midi, get_random_string 27 | 28 | 29 | ################################################################################ 30 | # config 31 | ################################################################################ 32 | 33 | parser = ArgumentParser() 34 | parser.add_argument("--mode",default="train",type=str,choices=["train", "inference"]) 35 | parser.add_argument("--task_type",default="4-cls",type=str,choices=['4-cls', 'Arousal', 'Valence', 'ignore']) 36 | parser.add_argument("--gid", default= 0, type=int) 37 | parser.add_argument("--data_parallel", default= 0, type=int) 38 | 39 | parser.add_argument("--exp_name", default='output' , type=str) 40 | parser.add_argument("--load_ckt", default="none", type=str) #pre-train model 41 | parser.add_argument("--load_ckt_loss", default="25", type=str) #pre-train model 42 | parser.add_argument("--path_train_data", default='emopia', type=str) 43 | parser.add_argument("--data_root", default='../dataset/co-representation/', type=str) 44 | parser.add_argument("--load_dict", default="more_dictionary.pkl", type=str) 45 | parser.add_argument("--init_lr", default= 0.00001, type=float) 46 | # inference config 47 | 48 | parser.add_argument("--num_songs", default=5, type=int) 49 | parser.add_argument("--emo_tag", default=1, type=int) 50 | parser.add_argument("--out_dir", default='none', type=str) 51 | args = parser.parse_args() 52 | 53 | print('=== args ===') 54 | for arg in args.__dict__: 55 | print(arg, args.__dict__[arg]) 56 | print('=== args ===') 57 | # time.sleep(10) #sleep to check again if args are right 58 | 59 | 60 | MODE = args.mode 61 | task_type = args.task_type 62 | 63 | 64 | ###--- data ---### 65 | path_data_root = args.data_root 66 | 67 | path_train_data = os.path.join(path_data_root, args.path_train_data + '_data.npz') 68 | path_dictionary = os.path.join(path_data_root, args.load_dict) 69 | path_train_idx = os.path.join(path_data_root, args.path_train_data + '_fn2idx_map.json') 70 | path_train_data_cls_idx = os.path.join(path_data_root, args.path_train_data + '_data_idx.npz') 71 | 72 | assert os.path.exists(path_train_data) 73 | assert os.path.exists(path_dictionary) 74 | assert os.path.exists(path_train_idx) 75 | 76 | # if the dataset has the emotion label, get the cls_idx for the dataloader 77 | if args.path_train_data == 'emopia': 78 | assert os.path.exists(path_train_data_cls_idx) 79 | 80 | ###--- training config ---### 81 | 82 | if MODE == 'train': 83 | path_exp = 'exp/' + args.exp_name 84 | 85 | if args.data_parallel > 0: 86 | batch_size = 8 87 | else: 88 | batch_size = 4 #4 89 | 90 | gid = args.gid 91 | init_lr = args.init_lr #0.0001 92 | 93 | 94 | ###--- fine-tuning & inference config ---### 95 | if args.load_ckt == 'none': 96 | info_load_model = None 97 | print('NO pre-trained model used') 98 | 99 | else: 100 | info_load_model = ( 101 | # path to ckpt for loading 102 | 'exp/' + args.load_ckt, 103 | # loss 104 | 
args.load_ckt_loss 105 | ) 106 | 107 | 108 | if args.out_dir == 'none': 109 | path_gendir = os.path.join('exp/' + args.load_ckt, 'gen_midis', 'loss_'+ args.load_ckt_loss) 110 | else: 111 | path_gendir = args.out_dir 112 | 113 | num_songs = args.num_songs 114 | emotion_tag = args.emo_tag 115 | 116 | 117 | ################################################################################ 118 | # File IO 119 | ################################################################################ 120 | 121 | if args.data_parallel == 0: 122 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gid) 123 | 124 | 125 | ########################################################################################################################## 126 | # Script 127 | ########################################################################################################################## 128 | 129 | 130 | class PEmoDataset(Dataset): 131 | def __init__(self, 132 | 133 | task_type): 134 | 135 | self.train_data = np.load(path_train_data) 136 | self.train_x = self.train_data['x'] 137 | self.train_y = self.train_data['y'] 138 | self.train_mask = self.train_data['mask'] 139 | 140 | if task_type != 'ignore': 141 | 142 | self.cls_idx = np.load(path_train_data_cls_idx) 143 | self.cls_1_idx = self.cls_idx['cls_1_idx'] 144 | self.cls_2_idx = self.cls_idx['cls_2_idx'] 145 | self.cls_3_idx = self.cls_idx['cls_3_idx'] 146 | self.cls_4_idx = self.cls_idx['cls_4_idx'] 147 | 148 | if task_type == 'Arousal': 149 | print('preparing data for training "Arousal"') 150 | self.label_transfer('Arousal') 151 | 152 | elif task_type == 'Valence': 153 | print('preparing data for training "Valence"') 154 | self.label_transfer('Valence') 155 | 156 | 157 | self.train_x = torch.from_numpy(self.train_x).long() 158 | self.train_y = torch.from_numpy(self.train_y).long() 159 | self.train_mask = torch.from_numpy(self.train_mask).float() 160 | 161 | 162 | self.seq_len = self.train_x.shape[1] 163 | self.dim = self.train_x.shape[2] 164 | 165 | print('train_x: ', self.train_x.shape) 166 | 167 | def label_transfer(self, TYPE): 168 | if TYPE == 'Arousal': 169 | for i in range(self.train_x.shape[0]): 170 | if self.train_x[i][0][-1] in [1,2]: 171 | self.train_x[i][0][-1] = 1 172 | elif self.train_x[i][0][-1] in [3,4]: 173 | self.train_x[i][0][-1] = 2 174 | 175 | elif TYPE == 'Valence': 176 | for i in range(self.train_x.shape[0]): 177 | if self.train_x[i][0][-1] in [1,4]: 178 | self.train_x[i][0][-1] = 1 179 | elif self.train_x[i][0][-1] in [2,3]: 180 | self.train_x[i][0][-1] = 2 181 | 182 | 183 | 184 | def __getitem__(self, index): 185 | return self.train_x[index], self.train_y[index], self.train_mask[index] 186 | 187 | 188 | def __len__(self): 189 | return len(self.train_x) 190 | 191 | 192 | def prep_dataloader(task_type, batch_size, n_jobs=0): 193 | 194 | dataset = PEmoDataset(task_type) 195 | 196 | dataloader = DataLoader( 197 | dataset, batch_size, 198 | shuffle=False, drop_last=False, 199 | num_workers=n_jobs, pin_memory=True) 200 | return dataloader 201 | 202 | 203 | 204 | 205 | def train(): 206 | 207 | myseed = 42069 208 | np.random.seed(myseed) 209 | torch.manual_seed(myseed) 210 | if torch.cuda.is_available(): 211 | torch.cuda.manual_seed_all(myseed) 212 | 213 | 214 | # hyper params 215 | n_epoch = 4000 216 | max_grad_norm = 3 217 | 218 | # load 219 | dictionary = pickle.load(open(path_dictionary, 'rb')) 220 | event2word, word2event = dictionary 221 | 222 | train_loader = prep_dataloader(args.task_type, batch_size) 223 | 224 | # create saver 225 | saver_agent 
= saver.Saver(path_exp) 226 | 227 | # config 228 | n_class = [] # number of classes of each token. [56, 127, 18, 4, 85, 18, 41, 5] with key: [... , 25] 229 | for key in event2word.keys(): 230 | n_class.append(len(dictionary[0][key])) 231 | 232 | 233 | 234 | n_token = len(n_class) 235 | # log 236 | print('num of classes:', n_class) 237 | 238 | # init 239 | 240 | 241 | if args.data_parallel > 0 and torch.cuda.device_count() > 1: 242 | print("Let's use", torch.cuda.device_count(), "GPUs!") 243 | net = TransformerModel(n_class, data_parallel=True) 244 | net = nn.DataParallel(net) 245 | 246 | else: 247 | net = TransformerModel(n_class) 248 | 249 | net.cuda() 250 | net.train() 251 | n_parameters = network_paras(net) 252 | print('n_parameters: {:,}'.format(n_parameters)) 253 | saver_agent.add_summary_msg( 254 | ' > params amount: {:,d}'.format(n_parameters)) 255 | 256 | 257 | # load model 258 | if info_load_model: 259 | path_ckpt = info_load_model[0] # path to ckpt dir 260 | loss = info_load_model[1] # loss 261 | name = 'loss_' + str(loss) 262 | path_saved_ckpt = os.path.join(path_ckpt, name + '_params.pt') 263 | print('[*] load model from:', path_saved_ckpt) 264 | 265 | try: 266 | net.load_state_dict(torch.load(path_saved_ckpt)) 267 | except: 268 | # print('WARNING!!!!! Not the whole pre-train model is loaded, only load partial') 269 | # print('WARNING!!!!! Not the whole pre-train model is loaded, only load partial') 270 | # print('WARNING!!!!! Not the whole pre-train model is loaded, only load partial') 271 | # net.load_state_dict(torch.load(path_saved_ckpt), strict=False) 272 | 273 | state_dict = torch.load(path_saved_ckpt) 274 | new_state_dict = OrderedDict() 275 | for k, v in state_dict.items(): 276 | name = k[7:] 277 | new_state_dict[name] = v 278 | 279 | net.load_state_dict(new_state_dict) 280 | 281 | 282 | # optimizers 283 | optimizer = optim.Adam(net.parameters(), lr=init_lr) 284 | 285 | 286 | # run 287 | start_time = time.time() 288 | for epoch in range(n_epoch): 289 | acc_loss = 0 290 | acc_losses = np.zeros(n_token) 291 | 292 | 293 | num_batch = len(train_loader) 294 | print(' num_batch:', num_batch) 295 | 296 | for bidx, (batch_x, batch_y, batch_mask) in enumerate(train_loader): # num_batch 297 | saver_agent.global_step_increment() 298 | 299 | batch_x = batch_x.cuda() 300 | batch_y = batch_y.cuda() 301 | batch_mask = batch_mask.cuda() 302 | 303 | 304 | losses = net(batch_x, batch_y, batch_mask) 305 | 306 | if args.data_parallel > 0: 307 | loss = 0 308 | calculated_loss = [] 309 | for i in range(n_token): 310 | 311 | loss += ((losses[i][0][0] + losses[i][0][1]) / (losses[i][1][0] + losses[i][1][1])) 312 | calculated_loss.append((losses[i][0][0] + losses[i][0][1]) / (losses[i][1][0] + losses[i][1][1])) 313 | loss = loss / n_token 314 | 315 | 316 | else: 317 | loss = (losses[0] + losses[1] + losses[2] + losses[3] + losses[4] + losses[5] + losses[6] + losses[7]) / 8 318 | 319 | 320 | # Update 321 | net.zero_grad() 322 | loss.backward() 323 | 324 | 325 | if max_grad_norm is not None: 326 | clip_grad_norm_(net.parameters(), max_grad_norm) 327 | optimizer.step() 328 | 329 | if args.data_parallel > 0: 330 | 331 | sys.stdout.write('{}/{} | Loss: {:06f} | {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}\r'.format( 332 | bidx, num_batch, loss, calculated_loss[0], calculated_loss[1], calculated_loss[2], calculated_loss[3], calculated_loss[4], calculated_loss[5], calculated_loss[6], calculated_loss[7])) 333 | sys.stdout.flush() 334 | 335 | 336 | # acc 337 | acc_losses += np.array([l.item()
for l in calculated_loss]) 338 | 339 | 340 | 341 | else: 342 | 343 | sys.stdout.write('{}/{} | Loss: {:06f} | {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}\r'.format( 344 | bidx, num_batch, loss, losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], losses[6], losses[7])) 345 | sys.stdout.flush() 346 | 347 | 348 | # acc 349 | acc_losses += np.array([l.item() for l in losses]) 350 | 351 | 352 | 353 | acc_loss += loss.item() 354 | 355 | # log 356 | saver_agent.add_summary('batch loss', loss.item()) 357 | 358 | 359 | # epoch loss 360 | runtime = time.time() - start_time 361 | epoch_loss = acc_loss / num_batch 362 | acc_losses = acc_losses / num_batch 363 | print('------------------------------------') 364 | print('epoch: {}/{} | Loss: {} | time: {}'.format( 365 | epoch, n_epoch, epoch_loss, str(datetime.timedelta(seconds=runtime)))) 366 | 367 | each_loss_str = '{:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}\r'.format( 368 | acc_losses[0], acc_losses[1], acc_losses[2], acc_losses[3], acc_losses[4], acc_losses[5], acc_losses[6], acc_losses[7]) 369 | 370 | print(' >', each_loss_str) 371 | 372 | saver_agent.add_summary('epoch loss', epoch_loss) 373 | saver_agent.add_summary('epoch each loss', each_loss_str) 374 | 375 | # save model, with policy 376 | loss = epoch_loss 377 | if 0.4 < loss <= 0.8: 378 | fn = int(loss * 10) * 10 379 | saver_agent.save_model(net, name='loss_' + str(fn)) 380 | elif 0.08 < loss <= 0.40: 381 | fn = int(loss * 100) 382 | saver_agent.save_model(net, name='loss_' + str(fn)) 383 | elif loss <= 0.08: 384 | print('Finished') 385 | return 386 | else: 387 | saver_agent.save_model(net, name='loss_high') 388 | 389 | 390 | def generate(): 391 | 392 | # path 393 | path_ckpt = info_load_model[0] # path to ckpt dir 394 | loss = info_load_model[1] # loss 395 | name = 'loss_' + str(loss) 396 | path_saved_ckpt = os.path.join(path_ckpt, name + '_params.pt') 397 | 398 | # load 399 | dictionary = pickle.load(open(path_dictionary, 'rb')) 400 | event2word, word2event = dictionary 401 | 402 | # outdir 403 | os.makedirs(path_gendir, exist_ok=True) 404 | 405 | # config 406 | n_class = [] # num of classes for each token 407 | for key in event2word.keys(): 408 | n_class.append(len(dictionary[0][key])) 409 | 410 | 411 | n_token = len(n_class) 412 | 413 | # init model 414 | net = TransformerModel(n_class, is_training=False) 415 | net.cuda() 416 | net.eval() 417 | 418 | # load model 419 | print('[*] load model from:', path_saved_ckpt) 420 | 421 | 422 | try: 423 | net.load_state_dict(torch.load(path_saved_ckpt)) 424 | except: 425 | state_dict = torch.load(path_saved_ckpt) 426 | new_state_dict = OrderedDict() 427 | for k, v in state_dict.items(): 428 | name = k[7:] 429 | new_state_dict[name] = v 430 | 431 | net.load_state_dict(new_state_dict) 432 | 433 | 434 | # gen 435 | start_time = time.time() 436 | song_time_list = [] 437 | words_len_list = [] 438 | 439 | cnt_tokens_all = 0 440 | sidx = 0 441 | while sidx < num_songs: 442 | # try: 443 | start_time = time.time() 444 | print('current idx:', sidx) 445 | 446 | if n_token == 8: 447 | path_outfile = os.path.join(path_gendir, 'emo_{}_{}'.format( str(emotion_tag), get_random_string(10))) 448 | res, _ = net.inference_from_scratch(dictionary, emotion_tag, n_token) 449 | 450 | 451 | if res is None: 452 | continue 453 | np.save(path_outfile + '.npy', res) 454 | write_midi(res, path_outfile + '.mid', word2event) 455 | 456 | song_time = time.time() - start_time 457 | word_len = len(res) 458 | print('song time:', 
song_time) 459 | print('word_len:', word_len) 460 | words_len_list.append(word_len) 461 | song_time_list.append(song_time) 462 | 463 | sidx += 1 464 | 465 | 466 | print('ave token time:', sum(words_len_list) / sum(song_time_list)) 467 | print('ave song time:', np.mean(song_time_list)) 468 | 469 | runtime_result = { 470 | 'song_time':song_time_list, 471 | 'words_len_list': words_len_list, 472 | 'ave token time:': sum(words_len_list) / sum(song_time_list), 473 | 'ave song time': float(np.mean(song_time_list)), 474 | } 475 | 476 | with open('runtime_stats.json', 'w') as f: 477 | json.dump(runtime_result, f) 478 | 479 | 480 | 481 | 482 | 483 | if __name__ == '__main__': 484 | # -- training -- # 485 | if MODE == 'train': 486 | train() 487 | # -- inference -- # 488 | elif MODE == 'inference': 489 | generate() 490 | else: 491 | pass 492 | -------------------------------------------------------------------------------- /workspace/transformer/models.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from utils import sampling 8 | 9 | # import fast_transformers_local 10 | from fast_transformers.builders import TransformerEncoderBuilder as TransformerEncoderBuilder_local 11 | from fast_transformers.builders import RecurrentEncoderBuilder as RecurrentEncoderBuilder_local 12 | from fast_transformers.masking import TriangularCausalMask as TriangularCausalMask_local 13 | 14 | 15 | D_MODEL = 512 16 | N_LAYER = 12 17 | N_HEAD = 8 18 | 19 | ################################################################################ 20 | # Model 21 | ################################################################################ 22 | 23 | def network_paras(model): 24 | # compute only trainable params 25 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 26 | params = sum([np.prod(p.size()) for p in model_parameters]) 27 | return params 28 | 29 | 30 | 31 | class Embeddings(nn.Module): 32 | def __init__(self, n_token, d_model): 33 | super(Embeddings, self).__init__() 34 | self.lut = nn.Embedding(n_token, d_model) 35 | self.d_model = d_model 36 | 37 | def forward(self, x): 38 | return self.lut(x) * math.sqrt(self.d_model) 39 | 40 | 41 | class PositionalEncoding(nn.Module): 42 | def __init__(self, d_model, dropout=0.1, max_len=20000): 43 | super(PositionalEncoding, self).__init__() 44 | self.dropout = nn.Dropout(p=dropout) 45 | 46 | pe = torch.zeros(max_len, d_model) 47 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 48 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 49 | pe[:, 0::2] = torch.sin(position * div_term) 50 | pe[:, 1::2] = torch.cos(position * div_term) 51 | pe = pe.unsqueeze(0) 52 | self.register_buffer('pe', pe) 53 | 54 | def forward(self, x): 55 | x = x + self.pe[:, :x.size(1), :] 56 | return self.dropout(x) 57 | 58 | 59 | 60 | class TransformerModel(nn.Module): 61 | def __init__(self, n_token, is_training=True, data_parallel=False): 62 | super(TransformerModel, self).__init__() 63 | self.data_parallel = data_parallel 64 | # --- params config --- # 65 | self.n_token = n_token # == n_class 66 | self.d_model = D_MODEL 67 | self.n_layer = N_LAYER # 68 | self.dropout = 0.1 69 | self.n_head = N_HEAD # 70 | self.d_head = D_MODEL // N_HEAD 71 | self.d_inner = 2048 72 | self.loss_func = nn.CrossEntropyLoss(reduction='none') 73 | if len(self.n_token) == 8: 74 | 
self.emb_sizes = [128, 256, 64, 32, 512, 128, 128, 128] 75 | elif len(self.n_token) == 9: 76 | self.emb_sizes = [128, 256, 64, 32, 512, 128, 128, 128, 128] #128 77 | 78 | # --- modules config --- # 79 | # embeddings 80 | print('>>>>>:', self.n_token) 81 | self.word_emb_tempo = Embeddings(self.n_token[0], self.emb_sizes[0]) 82 | self.word_emb_chord = Embeddings(self.n_token[1], self.emb_sizes[1]) 83 | self.word_emb_barbeat = Embeddings(self.n_token[2], self.emb_sizes[2]) 84 | self.word_emb_type = Embeddings(self.n_token[3], self.emb_sizes[3]) 85 | self.word_emb_pitch = Embeddings(self.n_token[4], self.emb_sizes[4]) 86 | self.word_emb_duration = Embeddings(self.n_token[5], self.emb_sizes[5]) 87 | self.word_emb_velocity = Embeddings(self.n_token[6], self.emb_sizes[6]) 88 | self.word_emb_emotion = Embeddings(self.n_token[7], self.emb_sizes[7]) 89 | if len(self.n_token) == 9: 90 | self.word_emb_key = Embeddings(self.n_token[8], self.emb_sizes[8]) 91 | self.pos_emb = PositionalEncoding(self.d_model, self.dropout) 92 | 93 | 94 | # linear 95 | self.in_linear = nn.Linear(np.sum(self.emb_sizes), self.d_model) 96 | 97 | # encoder 98 | if is_training: 99 | # encoder (training) 100 | self.get_encoder('encoder') 101 | 102 | 103 | else: 104 | # encoder (inference) 105 | print(' [o] using RNN backend.') 106 | self.get_encoder('autoregred') 107 | 108 | # blend with type 109 | self.project_concat_type = nn.Linear(self.d_model + 32, self.d_model) 110 | 111 | # individual output 112 | self.proj_tempo = nn.Linear(self.d_model, self.n_token[0]) 113 | self.proj_chord = nn.Linear(self.d_model, self.n_token[1]) 114 | self.proj_barbeat = nn.Linear(self.d_model, self.n_token[2]) 115 | self.proj_type = nn.Linear(self.d_model, self.n_token[3]) 116 | self.proj_pitch = nn.Linear(self.d_model, self.n_token[4]) 117 | self.proj_duration = nn.Linear(self.d_model, self.n_token[5]) 118 | self.proj_velocity = nn.Linear(self.d_model, self.n_token[6]) 119 | self.proj_emotion = nn.Linear(self.d_model, self.n_token[7]) 120 | if len(self.n_token) == 9: 121 | self.proj_key = nn.Linear( self.d_model, self.n_token[8]) 122 | 123 | 124 | def compute_loss(self, predict, target, loss_mask): 125 | if self.data_parallel: 126 | loss = self.loss_func(predict, target) 127 | loss = loss * loss_mask 128 | return torch.sum(loss), torch.sum(loss_mask) 129 | else: 130 | loss = self.loss_func(predict, target) 131 | loss = loss * loss_mask 132 | loss = torch.sum(loss) / torch.sum(loss_mask) 133 | return loss 134 | 135 | 136 | 137 | def forward(self, x, target, loss_mask): 138 | 139 | 140 | h, y_type = self.forward_hidden(x, is_training=True) 141 | 142 | 143 | if len(self.n_token) == 9: 144 | y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity, y_emotion, y_key, emo_embd = self.forward_output(h, target) 145 | else: 146 | y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity, y_emotion, emo_embd = self.forward_output(h, target) 147 | 148 | 149 | 150 | # reshape (b, s, f) -> (b, f, s) 151 | y_tempo = y_tempo[:, ...].permute(0, 2, 1) 152 | y_chord = y_chord[:, ...].permute(0, 2, 1) 153 | y_barbeat = y_barbeat[:, ...].permute(0, 2, 1) 154 | y_type = y_type[:, ...].permute(0, 2, 1) 155 | y_pitch = y_pitch[:, ...].permute(0, 2, 1) 156 | y_duration = y_duration[:, ...].permute(0, 2, 1) 157 | y_velocity = y_velocity[:, ...].permute(0, 2, 1) 158 | y_emotion = y_emotion[:, ...].permute(0, 2, 1) 159 | if len(self.n_token) == 9: 160 | y_key = y_key[:, ...].permute(0, 2, 1) 161 | 162 | # loss 163 | loss_tempo = self.compute_loss( 164 | y_tempo, 
target[..., 0], loss_mask) 165 | loss_chord = self.compute_loss( 166 | y_chord, target[..., 1], loss_mask) 167 | loss_barbeat = self.compute_loss( 168 | y_barbeat, target[..., 2], loss_mask) 169 | loss_type = self.compute_loss( 170 | y_type, target[..., 3], loss_mask) 171 | loss_pitch = self.compute_loss( 172 | y_pitch, target[..., 4], loss_mask) 173 | loss_duration = self.compute_loss( 174 | y_duration, target[..., 5], loss_mask) 175 | loss_velocity = self.compute_loss( 176 | y_velocity, target[..., 6], loss_mask) 177 | loss_emotion = self.compute_loss( 178 | y_emotion, target[..., 7], loss_mask) 179 | 180 | 181 | if len(self.n_token) == 9: 182 | loss_key = self.compute_loss( 183 | y_key, target[..., 8], loss_mask) 184 | 185 | return loss_tempo, loss_chord, loss_barbeat, loss_type, loss_pitch, loss_duration, loss_velocity, loss_emotion, loss_key 186 | 187 | else: 188 | return loss_tempo, loss_chord, loss_barbeat, loss_type, loss_pitch, loss_duration, loss_velocity, loss_emotion 189 | 190 | 191 | 192 | def get_encoder(self, TYPE): 193 | if TYPE == 'encoder': 194 | self.transformer_encoder = TransformerEncoderBuilder_local.from_kwargs( 195 | n_layers=self.n_layer, 196 | n_heads=self.n_head, 197 | query_dimensions=self.d_model//self.n_head, 198 | value_dimensions=self.d_model//self.n_head, 199 | feed_forward_dimensions=2048, 200 | activation='gelu', 201 | dropout=0.1, 202 | attention_type="causal-linear", 203 | ).get() 204 | 205 | 206 | elif TYPE == 'autoregred': 207 | self.transformer_encoder = RecurrentEncoderBuilder_local.from_kwargs( 208 | n_layers=self.n_layer, 209 | n_heads=self.n_head, 210 | query_dimensions=self.d_model//self.n_head, 211 | value_dimensions=self.d_model//self.n_head, 212 | feed_forward_dimensions=2048, 213 | activation='gelu', 214 | dropout=0.1, 215 | attention_type="causal-linear", 216 | ).get() 217 | 218 | 219 | 220 | def forward_hidden(self, x, memory=None, is_training=False): 221 | ''' 222 | linear transformer: b x s x f 223 | x.shape=(bs, nf) 224 | ''' 225 | 226 | # embeddings 227 | emb_tempo = self.word_emb_tempo(x[..., 0]) 228 | emb_chord = self.word_emb_chord(x[..., 1]) 229 | emb_barbeat = self.word_emb_barbeat(x[..., 2]) 230 | emb_type = self.word_emb_type(x[..., 3]) 231 | emb_pitch = self.word_emb_pitch(x[..., 4]) 232 | emb_duration = self.word_emb_duration(x[..., 5]) 233 | emb_velocity = self.word_emb_velocity(x[..., 6]) 234 | 235 | emb_emotion = self.word_emb_emotion(x[..., 7]) 236 | 237 | if len(self.n_token) == 9: 238 | emb_key = self.word_emb_key(x[..., 8]) 239 | 240 | # same emotion class have same emb_emotion 241 | 242 | embs = torch.cat( 243 | [ 244 | emb_tempo, 245 | emb_chord, 246 | emb_barbeat, 247 | emb_type, 248 | emb_pitch, 249 | emb_duration, 250 | emb_velocity, 251 | emb_emotion, 252 | emb_key 253 | ], dim=-1) 254 | 255 | else: 256 | embs = torch.cat( 257 | [ 258 | emb_tempo, 259 | emb_chord, 260 | emb_barbeat, 261 | emb_type, 262 | emb_pitch, 263 | emb_duration, 264 | emb_velocity, 265 | emb_emotion 266 | 267 | ], dim=-1) 268 | 269 | 270 | emb_linear = self.in_linear(embs) 271 | pos_emb = self.pos_emb(emb_linear) 272 | 273 | 274 | # assert False 275 | layer_outputs = [] 276 | # transformer 277 | if is_training: 278 | # mask 279 | attn_mask = TriangularCausalMask_local(pos_emb.size(1), device=x.device) 280 | 281 | h = self.transformer_encoder(pos_emb, attn_mask) # y: b x s x d_model 282 | 283 | 284 | # project type 285 | y_type = self.proj_type(h) 286 | 287 | 288 | return h, y_type 289 | 290 | else: 291 | pos_emb = pos_emb.squeeze(0) 292 | 
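            # Inference path: the recurrent encoder consumes one compound-word step per
            # call and threads its state through `memory`. A sketch of the calling
            # pattern (this is how inference_from_scratch below drives generation;
            # `token_steps` is a placeholder name):
            #
            #   memory = None
            #   for step_input in token_steps:   # each step_input has shape (1, 1, len(self.n_token))
            #       h, y_type, memory = self.forward_hidden(step_input, memory, is_training=False)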
293 | # self.get_encoder('autoregred') 294 | # self.transformer_encoder.cuda() 295 | h, memory = self.transformer_encoder(pos_emb, memory=memory) # y: s x d_model 296 | 297 | # project type 298 | y_type = self.proj_type(h) 299 | 300 | return h, y_type, memory 301 | 302 | 303 | def forward_output(self, h, y): 304 | ''' 305 | for training 306 | ''' 307 | # tf_skip_emption = self.word_emb_emotion(y[..., 7]) 308 | tf_skip_type = self.word_emb_type(y[..., 3]) 309 | 310 | emo_embd = h[:, 0] 311 | 312 | # project other 313 | y_concat_type = torch.cat([h, tf_skip_type], dim=-1) 314 | y_ = self.project_concat_type(y_concat_type) 315 | 316 | y_tempo = self.proj_tempo(y_) 317 | y_chord = self.proj_chord(y_) 318 | y_barbeat = self.proj_barbeat(y_) 319 | y_pitch = self.proj_pitch(y_) 320 | y_duration = self.proj_duration(y_) 321 | y_velocity = self.proj_velocity(y_) 322 | y_emotion = self.proj_emotion(y_) 323 | 324 | if len(self.n_token) == 9: 325 | y_key = self.proj_key(y_) 326 | 327 | return y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity, y_emotion, y_key, emo_embd 328 | 329 | else: 330 | return y_tempo, y_chord, y_barbeat, y_pitch, y_duration, y_velocity, y_emotion, emo_embd 331 | 332 | 333 | 334 | 335 | def froward_output_sampling(self, h, y_type, is_training=False): 336 | ''' 337 | for inference 338 | ''' 339 | 340 | # sample type 341 | y_type_logit = y_type[0, :] # token class size 342 | cur_word_type = sampling(y_type_logit, p=0.90, is_training=is_training) # int 343 | if cur_word_type is None: 344 | return None, None 345 | 346 | if is_training: 347 | type_word_t = cur_word_type.long().unsqueeze(0).unsqueeze(0) 348 | else: 349 | type_word_t = torch.from_numpy( 350 | np.array([cur_word_type])).long().cuda().unsqueeze(0) # shape = (1,1) 351 | 352 | tf_skip_type = self.word_emb_type(type_word_t).squeeze(0) # shape = (1, embd_size) 353 | 354 | 355 | # concat 356 | y_concat_type = torch.cat([h, tf_skip_type], dim=-1) 357 | y_ = self.project_concat_type(y_concat_type) 358 | 359 | # project other 360 | y_tempo = self.proj_tempo(y_) 361 | y_chord = self.proj_chord(y_) 362 | y_barbeat = self.proj_barbeat(y_) 363 | 364 | y_pitch = self.proj_pitch(y_) 365 | y_duration = self.proj_duration(y_) 366 | y_velocity = self.proj_velocity(y_) 367 | y_emotion = self.proj_emotion(y_) 368 | 369 | 370 | 371 | # sampling gen_cond 372 | cur_word_tempo = sampling(y_tempo, t=1.2, p=0.9, is_training=is_training) 373 | cur_word_barbeat = sampling(y_barbeat, t=1.2, is_training=is_training) 374 | cur_word_chord = sampling(y_chord, p=0.99, is_training=is_training) 375 | cur_word_pitch = sampling(y_pitch, p=0.9, is_training=is_training) 376 | cur_word_duration = sampling(y_duration, t=2, p=0.9, is_training=is_training) 377 | cur_word_velocity = sampling(y_velocity, t=5, is_training=is_training) 378 | 379 | if len(self.n_token) == 9: 380 | y_key = self.proj_key(y_) 381 | cur_word_key = sampling(y_key, t=1.2, is_training=is_training) 382 | 383 | curs = [ 384 | cur_word_tempo, 385 | cur_word_chord, 386 | cur_word_barbeat, 387 | cur_word_pitch, 388 | cur_word_duration, 389 | cur_word_velocity, 390 | cur_word_key 391 | ] 392 | 393 | else: 394 | curs = [ 395 | cur_word_tempo, 396 | cur_word_chord, 397 | cur_word_barbeat, 398 | cur_word_pitch, 399 | cur_word_duration, 400 | cur_word_velocity 401 | ] 402 | 403 | if None in curs: 404 | return None, None 405 | 406 | 407 | 408 | if is_training: 409 | cur_word_emotion = torch.from_numpy(np.array([0])).long().cuda().squeeze(0) 410 | # collect 411 | next_arr = torch.tensor([ 412 
| cur_word_tempo, 413 | cur_word_chord, 414 | cur_word_barbeat, 415 | cur_word_type, 416 | cur_word_pitch, 417 | cur_word_duration, 418 | cur_word_velocity, 419 | cur_word_emotion 420 | ]) 421 | 422 | else: 423 | cur_word_emotion = 0 424 | 425 | 426 | # collect 427 | if len(self.n_token) == 9: 428 | next_arr = np.array([ 429 | cur_word_tempo, 430 | cur_word_chord, 431 | cur_word_barbeat, 432 | cur_word_type, 433 | cur_word_pitch, 434 | cur_word_duration, 435 | cur_word_velocity, 436 | cur_word_emotion, 437 | cur_word_key 438 | ]) 439 | else: 440 | next_arr = np.array([ 441 | cur_word_tempo, 442 | cur_word_chord, 443 | cur_word_barbeat, 444 | cur_word_type, 445 | cur_word_pitch, 446 | cur_word_duration, 447 | cur_word_velocity, 448 | cur_word_emotion 449 | ]) 450 | 451 | return next_arr, y_emotion 452 | 453 | 454 | 455 | 456 | 457 | def inference_from_scratch(self, dictionary, emotion_tag, key_tag=None, n_token=8, display=True): 458 | event2word, word2event = dictionary 459 | 460 | 461 | classes = word2event.keys() 462 | 463 | 464 | def print_word_cp(cp): 465 | 466 | result = [word2event[k][cp[idx]] for idx, k in enumerate(classes)] 467 | 468 | for r in result: 469 | print('{:15s}'.format(str(r)), end=' | ') 470 | print('') 471 | 472 | generated_key = None 473 | 474 | 475 | 476 | if n_token == 9: 477 | 478 | if key_tag: 479 | 480 | target_emotion = [0, 0, 0, 1, 0, 0, 0, emotion_tag, 0] 481 | target_key = [0, 0, 0, 4, 0, 0, 0, 0, key_tag] 482 | 483 | init = np.array([ 484 | target_emotion, # emotion 485 | target_key, 486 | [0, 0, 1, 2, 0, 0, 0, 0, 0] # bar 487 | ]) 488 | 489 | else: 490 | target_emotion = [0, 0, 0, 1, 0, 0, 0, emotion_tag, 0] 491 | init = np.array([ 492 | target_emotion, # emotion 493 | [0, 0, 1, 2, 0, 0, 0, 0, 0] # bar 494 | ]) 495 | 496 | elif n_token == 8: 497 | target_emotion = [0, 0, 0, 1, 0, 0, 0, emotion_tag] 498 | 499 | init = np.array([ 500 | target_emotion, # emotion 501 | [0, 0, 1, 2, 0, 0, 0, 0] # bar 502 | ]) 503 | 504 | 505 | cnt_token = len(init) 506 | with torch.no_grad(): 507 | final_res = [] 508 | memory = None 509 | h = None 510 | 511 | cnt_bar = 1 512 | init_t = torch.from_numpy(init).long().cuda() 513 | print('------ initiate ------') 514 | 515 | if n_token == 9 and key_tag is None: 516 | # Emotion token 517 | step = 0 518 | if display: 519 | print_word_cp(init[step, :]) 520 | input_ = init_t[step, :].unsqueeze(0).unsqueeze(0) 521 | final_res.append(init[step, :][None, ...]) 522 | h, y_type, memory = self.forward_hidden( 523 | input_, memory, is_training=False) 524 | 525 | #generate KEY 526 | next_arr, y_emotion = self.froward_output_sampling(h, y_type) 527 | if next_arr is None: 528 | return None, None 529 | 530 | generated_key = next_arr[-1] 531 | final_res.append(next_arr[None, ...]) 532 | if display: 533 | print_word_cp(next_arr) 534 | input_ = torch.from_numpy(next_arr).long().cuda() 535 | input_ = input_.unsqueeze(0).unsqueeze(0) 536 | h, y_type, memory = self.forward_hidden( 537 | input_, memory, is_training=False) 538 | 539 | # init bar 540 | step = 1 541 | print_word_cp(init[step, :]) 542 | input_ = init_t[step, :].unsqueeze(0).unsqueeze(0) 543 | final_res.append(init[step, :][None, ...]) 544 | h, y_type, memory = self.forward_hidden( 545 | input_, memory, is_training=False) 546 | 547 | 548 | 549 | else: 550 | for step in range(init.shape[0]): 551 | 552 | print_word_cp(init[step, :]) 553 | input_ = init_t[step, :].unsqueeze(0).unsqueeze(0) 554 | final_res.append(init[step, :][None, ...]) 555 | 556 | h, y_type, memory = self.forward_hidden( 557 | 
input_, memory, is_training=False) 558 | 559 | 560 | 561 | 562 | print('------ generate ------') 563 | while(True): 564 | # sample others 565 | next_arr, y_emotion = self.froward_output_sampling(h, y_type) 566 | if next_arr is None: 567 | return None, None 568 | 569 | final_res.append(next_arr[None, ...]) 570 | 571 | if display: 572 | print('bar:', cnt_bar, end= ' ==') 573 | print_word_cp(next_arr) 574 | 575 | # forward 576 | input_ = torch.from_numpy(next_arr).long().cuda() 577 | input_ = input_.unsqueeze(0).unsqueeze(0) 578 | h, y_type, memory = self.forward_hidden( 579 | input_, memory, is_training=False) 580 | 581 | # end of sequence 582 | if word2event['type'][next_arr[3]] == 'EOS': 583 | break 584 | 585 | if word2event['bar-beat'][next_arr[2]] == 'Bar': 586 | cnt_bar += 1 587 | 588 | print('\n--------[Done]--------') 589 | final_res = np.concatenate(final_res) 590 | print(final_res.shape) 591 | 592 | 593 | return final_res, generated_key 594 | 595 | -------------------------------------------------------------------------------- /workspace/transformer/saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import logging 5 | import datetime 6 | import collections 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | class Saver(object): 12 | def __init__( 13 | self, 14 | exp_dir, 15 | mode='w'): 16 | 17 | self.exp_dir = exp_dir 18 | self.init_time = time.time() 19 | self.global_step = 0 20 | 21 | # makedirs 22 | os.makedirs(exp_dir, exist_ok=True) 23 | 24 | # logging config 25 | path_logger = os.path.join(exp_dir, 'log.txt') 26 | logging.basicConfig( 27 | level=logging.DEBUG, 28 | format='%(message)s', 29 | filename=path_logger, 30 | filemode=mode) 31 | self.logger = logging.getLogger('training monitor') 32 | 33 | def add_summary_msg(self, msg): 34 | self.logger.debug(msg) 35 | 36 | def add_summary( 37 | self, 38 | key, 39 | val, 40 | step=None, 41 | cur_time=None): 42 | 43 | if cur_time is None: 44 | cur_time = time.time() - self.init_time 45 | if step is None: 46 | step = self.global_step 47 | 48 | # write msg (key, val, step, time) 49 | if isinstance(val, float): 50 | msg_str = '{:10s} | {:.10f} | {:10d} | {}'.format( 51 | key, 52 | val, 53 | step, 54 | cur_time 55 | ) 56 | else: 57 | msg_str = '{:10s} | {} | {:10d} | {}'.format( 58 | key, 59 | val, 60 | step, 61 | cur_time 62 | ) 63 | 64 | self.logger.debug(msg_str) 65 | 66 | def save_model( 67 | self, 68 | model, 69 | optimizer=None, 70 | outdir=None, 71 | name='model'): 72 | 73 | if outdir is None: 74 | outdir = self.exp_dir 75 | print(' [*] saving model to {}, name: {}'.format(outdir, name)) 76 | torch.save(model, os.path.join(outdir, name+'.pt')) 77 | torch.save(model.state_dict(), os.path.join(outdir, name+'_params.pt')) 78 | 79 | if optimizer is not None: 80 | torch.save(optimizer.state_dict(), os.path.join(outdir, name+'_opt.pt')) 81 | 82 | def load_model( 83 | self, 84 | path_exp, 85 | device='cpu', 86 | name='model.pt'): 87 | 88 | path_pt = os.path.join(path_exp, name) 89 | print(' [*] restoring model from', path_pt) 90 | model = torch.load(path_pt, map_location=torch.device(device)) 91 | return model 92 | 93 | def global_step_increment(self): 94 | self.global_step += 1 95 | 96 | """ 97 | file modes 98 | 'a': 99 | Opens a file for appending. The file pointer is at the end of the file if the file exists. 100 | That is, the file is in the append mode. If the file does not exist, it creates a new file for writing. 
101 | 102 | 'w': 103 | Opens a file for writing only. Overwrites the file if the file exists. 104 | If the file does not exist, creates a new file for writing. 105 | """ 106 | 107 | def make_loss_report( 108 | path_log, 109 | path_figure='loss.png', 110 | dpi=100): 111 | 112 | # load logfile 113 | monitor_vals = collections.defaultdict(list) 114 | with open(path_logfile, 'r') as f: 115 | for line in f: 116 | try: 117 | line = line.strip() 118 | key, val, step, acc_time = line.split(' | ') 119 | monitor_vals[key].append((float(val), int(step), acc_time)) 120 | except: 121 | continue 122 | 123 | # collect 124 | step_train = [item[1] for item in monitor_vals['train loss']] 125 | vals_train = [item[0] for item in monitor_vals['train loss']] 126 | 127 | step_valid = [item[1] for item in monitor_vals['valid loss']] 128 | vals_valid = [item[0] for item in monitor_vals['valid loss']] 129 | 130 | x_min = step_valid[np.argmin(vals_valid)] 131 | y_min = min(vals_valid) 132 | 133 | # plot 134 | fig = plt.figure(dpi=dpi) 135 | plt.title('training process') 136 | plt.plot(step_train, vals_train, label='train') 137 | plt.plot(step_valid, vals_valid, label='valid') 138 | plt.yscale('log') 139 | plt.plot([x_min], [y_min], 'ro') 140 | plt.legend(loc='upper right') 141 | plt.tight_layout() 142 | plt.savefig(path_figure) 143 | 144 | ''' 145 | author: wayn391@mastertones 146 | ''' 147 | 148 | import os 149 | import time 150 | import torch 151 | import logging 152 | import datetime 153 | import collections 154 | import numpy as np 155 | import matplotlib.pyplot as plt 156 | 157 | 158 | class Saver(object): 159 | def __init__( 160 | self, 161 | exp_dir, 162 | mode='w'): 163 | 164 | self.exp_dir = exp_dir 165 | self.init_time = time.time() 166 | self.global_step = 0 167 | 168 | # makedirs 169 | os.makedirs(exp_dir, exist_ok=True) 170 | 171 | # logging config 172 | path_logger = os.path.join(exp_dir, 'log.txt') 173 | logging.basicConfig( 174 | level=logging.DEBUG, 175 | format='%(message)s', 176 | filename=path_logger, 177 | filemode=mode) 178 | self.logger = logging.getLogger('training monitor') 179 | 180 | def add_summary_msg(self, msg): 181 | self.logger.debug(msg) 182 | 183 | def add_summary( 184 | self, 185 | key, 186 | val, 187 | step=None, 188 | cur_time=None): 189 | 190 | if cur_time is None: 191 | cur_time = time.time() - self.init_time 192 | if step is None: 193 | step = self.global_step 194 | 195 | # write msg (key, val, step, time) 196 | if isinstance(val, float): 197 | msg_str = '{:10s} | {:.10f} | {:10d} | {}'.format( 198 | key, 199 | val, 200 | step, 201 | cur_time 202 | ) 203 | else: 204 | msg_str = '{:10s} | {} | {:10d} | {}'.format( 205 | key, 206 | val, 207 | step, 208 | cur_time 209 | ) 210 | 211 | self.logger.debug(msg_str) 212 | 213 | def save_model( 214 | self, 215 | model, 216 | optimizer=None, 217 | outdir=None, 218 | name='model'): 219 | 220 | if outdir is None: 221 | outdir = self.exp_dir 222 | print(' [*] saving model to {}, name: {}'.format(outdir, name)) 223 | # torch.save(model, os.path.join(outdir, name+'.pt')) 224 | torch.save(model.state_dict(), os.path.join(outdir, name+'_params.pt')) 225 | 226 | if optimizer is not None: 227 | torch.save(optimizer.state_dict(), os.path.join(outdir, name+'_opt.pt')) 228 | 229 | def load_model( 230 | self, 231 | path_exp, 232 | device='cpu', 233 | name='model.pt'): 234 | 235 | path_pt = os.path.join(path_exp, name) 236 | print(' [*] restoring model from', path_pt) 237 | model = torch.load(path_pt, map_location=torch.device(device)) 238 | return 
model 239 | 240 | def global_step_increment(self): 241 | self.global_step += 1 242 | 243 | """ 244 | file modes 245 | 'a': 246 | Opens a file for appending. The file pointer is at the end of the file if the file exists. 247 | That is, the file is in the append mode. If the file does not exist, it creates a new file for writing. 248 | 249 | 'w': 250 | Opens a file for writing only. Overwrites the file if the file exists. 251 | If the file does not exist, creates a new file for writing. 252 | """ 253 | 254 | def make_loss_report( 255 | path_log, 256 | path_figure='loss.png', 257 | dpi=100): 258 | 259 | # load logfile 260 | monitor_vals = collections.defaultdict(list) 261 | with open(path_logfile, 'r') as f: 262 | for line in f: 263 | try: 264 | line = line.strip() 265 | key, val, step, acc_time = line.split(' | ') 266 | monitor_vals[key].append((float(val), int(step), acc_time)) 267 | except: 268 | continue 269 | 270 | # collect 271 | step_train = [item[1] for item in monitor_vals['train loss']] 272 | vals_train = [item[0] for item in monitor_vals['train loss']] 273 | 274 | step_valid = [item[1] for item in monitor_vals['valid loss']] 275 | vals_valid = [item[0] for item in monitor_vals['valid loss']] 276 | 277 | x_min = step_valid[np.argmin(vals_valid)] 278 | y_min = min(vals_valid) 279 | 280 | # plot 281 | fig = plt.figure(dpi=dpi) 282 | plt.title('training process') 283 | plt.plot(step_train, vals_train, label='train') 284 | plt.plot(step_valid, vals_valid, label='valid') 285 | plt.yscale('log') 286 | plt.plot([x_min], [y_min], 'ro') 287 | plt.legend(loc='upper right') 288 | plt.tight_layout() 289 | plt.savefig(path_figure) 290 | 291 | -------------------------------------------------------------------------------- /workspace/transformer/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | import random 5 | import string 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from tqdm import tqdm 10 | import pickle 11 | import miditoolkit 12 | from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note 13 | 14 | BEAT_RESOL = 480 15 | BAR_RESOL = BEAT_RESOL * 4 16 | TICK_RESOL = BEAT_RESOL // 4 17 | 18 | 19 | def write_midi(words, path_outfile, word2event): 20 | 21 | class_keys = word2event.keys() 22 | # words = np.load(path_infile) 23 | midi_obj = miditoolkit.midi.parser.MidiFile() 24 | 25 | bar_cnt = 0 26 | cur_pos = 0 27 | 28 | all_notes = [] 29 | 30 | cnt_error = 0 31 | for i in range(len(words)): 32 | vals = [] 33 | for kidx, key in enumerate(class_keys): 34 | vals.append(word2event[key][words[i][kidx]]) 35 | # print(vals) 36 | 37 | if vals[3] == 'Metrical': 38 | if vals[2] == 'Bar': 39 | bar_cnt += 1 40 | elif 'Beat' in vals[2]: 41 | beat_pos = int(vals[2].split('_')[1]) 42 | cur_pos = bar_cnt * BAR_RESOL + beat_pos * TICK_RESOL 43 | 44 | # chord 45 | if vals[1] != 'CONTI' and vals[1] != 0: 46 | midi_obj.markers.append( 47 | Marker(text=str(vals[1]), time=cur_pos)) 48 | 49 | if vals[0] != 'CONTI' and vals[0] != 0: 50 | tempo = int(vals[0].split('_')[-1]) 51 | midi_obj.tempo_changes.append( 52 | TempoChange(tempo=tempo, time=cur_pos)) 53 | else: 54 | pass 55 | elif vals[3] == 'Note': 56 | 57 | try: 58 | pitch = vals[4].split('_')[-1] 59 | duration = vals[5].split('_')[-1] 60 | velocity = vals[6].split('_')[-1] 61 | 62 | if int(duration) == 0: 63 | duration = 60 64 | end = cur_pos + int(duration) 65 | 66 | all_notes.append( 67 | Note( 68 | pitch=int(pitch), 69 | 
start=cur_pos, 70 | end=end, 71 | velocity=int(velocity)) 72 | ) 73 | except: 74 | continue 75 | else: 76 | pass 77 | 78 | # save midi 79 | piano_track = Instrument(0, is_drum=False, name='piano') 80 | piano_track.notes = all_notes 81 | midi_obj.instruments = [piano_track] 82 | midi_obj.dump(path_outfile) 83 | 84 | 85 | ################################################################################ 86 | # Sampling 87 | ################################################################################ 88 | # -- temperature -- # 89 | def softmax_with_temperature(logits, temperature): 90 | probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature)) 91 | if np.isnan(probs).any(): 92 | return None 93 | else: 94 | return probs 95 | 96 | 97 | 98 | ## gumbel 99 | def gumbel_softmax(logits, temperature): 100 | return F.gumbel_softmax(logits, tau=temperature, hard=True) 101 | 102 | 103 | def weighted_sampling(probs): 104 | probs /= sum(probs) 105 | sorted_probs = np.sort(probs)[::-1] 106 | sorted_index = np.argsort(probs)[::-1] 107 | word = np.random.choice(sorted_index, size=1, p=sorted_probs)[0] 108 | return word 109 | 110 | 111 | # -- nucleus -- # 112 | def nucleus(probs, p): 113 | 114 | probs /= (sum(probs) + 1e-5) 115 | sorted_probs = np.sort(probs)[::-1] 116 | sorted_index = np.argsort(probs)[::-1] 117 | cusum_sorted_probs = np.cumsum(sorted_probs) 118 | after_threshold = cusum_sorted_probs > p 119 | if sum(after_threshold) > 0: 120 | last_index = np.where(after_threshold)[0][0] + 1 121 | candi_index = sorted_index[:last_index] 122 | else: 123 | candi_index = sorted_index[:] 124 | candi_probs = [probs[i] for i in candi_index] 125 | candi_probs /= sum(candi_probs) 126 | try: 127 | word = np.random.choice(candi_index, size=1, p=candi_probs)[0] 128 | except: 129 | import ipdb; ipdb.set_trace() 130 | return word 131 | 132 | 133 | 134 | def sampling(logit, p=None, t=1.0, is_training=False): 135 | 136 | 137 | if is_training: 138 | logit = logit.squeeze() 139 | probs = gumbel_softmax(logits=logit, temperature=t) 140 | 141 | return torch.argmax(probs) 142 | 143 | else: 144 | logit = logit.squeeze().cpu().numpy() 145 | probs = softmax_with_temperature(logits=logit, temperature=t) 146 | 147 | if probs is None: 148 | return None 149 | 150 | if p is not None: 151 | cur_word = nucleus(probs, p=p) 152 | 153 | else: 154 | cur_word = weighted_sampling(probs) 155 | return cur_word 156 | 157 | 158 | 159 | 160 | 161 | 162 | def get_random_string(length): 163 | # choose from all lowercase letters 164 | letters = string.ascii_lowercase 165 | result_str = ''.join(random.choice(letters) for i in range(length)) 166 | return result_str 167 | 168 | 169 | 170 | ''' 171 | If the classifier is pre-trained, 172 | the sampled tokens have to go through gumbel softmax (because the classifier sees real data). 173 | If not, the classifier directly consumes what the training phase produces, so feeding it the logits / softmax directly is also fine. 174 | 175 | KEY: does the classifier need to see real data? 176 | Yes: gumbel 177 | No: no gumbel needed 178 | 179 | There are 2 choices of what to feed the classifier: 180 | 1. logits 181 | 2. probs -> not workable, because it still has to be fed to forward_hidden... so a discrete word is needed 182 | ''' 183 | 184 | ''' 185 | def compile_data(test_folder): 186 | MAX_LEN = 1024 187 | wordfiles = glob.glob(os.path.join(test_folder, '*.npy')) 188 | n_files = len(wordfiles) 189 | 190 | x_list = [] 191 | y_list = [] 192 | f_name_list = [] 193 | for fidx in range(n_files): 194 | file = wordfiles[fidx] 195 | 196 | words = np.load(file) 197 | num_words = len(words) 198 | 199 | eos_arr = words[-1][None, ...]
200 | if num_words >= MAX_LEN - 2: 201 | print('too long!', num_words) 202 | continue 203 | 204 | x = words[:-1].copy() #without EOS 205 | y = words[1:].copy() 206 | seq_len = len(x) 207 | print(' > seq_len:', seq_len) 208 | 209 | # pad with eos 210 | pad = np.tile( 211 | eos_arr, 212 | (MAX_LEN-seq_len, 1)) 213 | 214 | x = np.concatenate([x, pad], axis=0) 215 | y = np.concatenate([y, pad], axis=0) 216 | 217 | # collect 218 | if x.shape != (1024, 8): 219 | print(x.shape) 220 | exit() 221 | x_list.append(x.reshape(1, 1024,8)) 222 | y_list.append(y.reshape(1, 1024,8)) 223 | f_name_list.append(file) 224 | 225 | # x_final = np.array(x_list) 226 | # y_final = np.array(y_list) 227 | # f_name_list = np.array(f_name_list) 228 | return x_list, y_list, f_name_list 229 | 230 | 231 | 232 | def take_embd_scratch(embd_net, test_folder): 233 | 234 | # unpack 235 | batch_size = 1 236 | train_x, train_y, f_name_list = compile_data(test_folder) 237 | num_batch = len(train_x) // batch_size 238 | 239 | 240 | 241 | # load model 242 | path_ckpt = 'exp/0309-1857' # path to ckpt dir 243 | loss = 30 # loss 244 | name = 'loss_' + str(loss) 245 | path_saved_ckpt = os.path.join(path_ckpt, name + '_params.pt') 246 | print('[*] load model from:', path_saved_ckpt) 247 | embd_net.load_state_dict(torch.load(path_saved_ckpt)) 248 | 249 | embd_filename = 'temp/embd_' 250 | 251 | 252 | while train_x: 253 | 254 | batch_x = train_x.pop() 255 | batch_y = train_y.pop() 256 | fname = f_name_list.pop() 257 | 258 | batch_x = torch.from_numpy(batch_x).long().cuda() 259 | batch_y = torch.from_numpy(batch_y).long().cuda() 260 | 261 | 262 | h, y_type, layer_outputs = embd_net.forward_hidden(batch_x) 263 | 264 | _, _, _, _, _, _, _, eight_y_ = embd_net.forward_output(layer_outputs[7], batch_y) 265 | 266 | np.save(embd_filename + fname.split('/')[-1], eight_y_.detach().cpu().numpy()) 267 | 268 | return embd_filename 269 | 270 | ''' 271 | --------------------------------------------------------------------------------
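A minimal usage sketch (not part of the repository files above) showing how the helpers in workspace/transformer/utils.py fit together after generation: generate() in main_cp.py saves each piece as a token array (.npy) and renders it with write_midi(), so an already-saved array can be re-rendered the same way. The dictionary and .npy paths below are placeholders.

    import pickle
    import numpy as np
    from utils import write_midi

    # the dictionary pickle holds an (event2word, word2event) pair; path is a placeholder
    with open('../dataset/co-representation/more_dictionary.pkl', 'rb') as f:
        event2word, word2event = pickle.load(f)

    words = np.load('emo_1_example.npy')   # token array previously saved by generate()
    write_midi(words, 'emo_1_example.mid', word2event)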