├── ckpt └── .gitignore ├── logs └── .gitignore ├── generated └── .gitignore ├── .gitignore ├── test.mid ├── music_bpe_exec ├── src ├── fairseq │ ├── linear_transformer_inference │ │ ├── __init__.py │ │ └── linear_transformer_multi.py │ ├── linear_transformer │ │ ├── __init__.py │ │ ├── linear_transformer_lm.py │ │ └── linear_transformer_multi.py │ ├── gen_batch.py │ ├── make_data.py │ └── gen_utils.py ├── musicBPE │ ├── MANIFEST.in │ ├── fastBPE │ │ ├── main.cc │ │ ├── fastBPE.pyx │ │ └── fastBPE.hpp │ ├── LICENSE │ ├── README.md │ └── setup.py ├── encoding.py └── preprocess │ ├── get_bpe_data.py │ └── preprocess_midi.py ├── model_complete.jpg ├── vocab.sh ├── data ├── midis │ ├── beethoven_symphony_9_1_(c)cvikl.mid │ ├── beethoven_symphony_9_2_(c)cvikl.mid │ ├── beethoven_symphony_9_3_(c)cvikl.mid │ └── beethoven_symphony_9_4_(c)cvikl.mid ├── .gitignore ├── model_spec │ └── linear_4096_chord_bpe_hardloss1 │ │ ├── vocabs │ │ ├── vocab_1.json │ │ ├── vocab_2.json │ │ ├── vocab_3.json │ │ ├── ori_dict.json │ │ └── vocab_0.json │ │ └── bin │ │ └── dict.txt └── bpe_res │ └── codes.txt ├── requirements.txt ├── config.sh ├── LICENSE ├── train_linear_chord.sh └── README.md /ckpt/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /generated/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | 3 | .ipynb_checkpoints 4 | *.log 
-------------------------------------------------------------------------------- /test.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/symphonynet/SymphonyNet/HEAD/test.mid -------------------------------------------------------------------------------- /music_bpe_exec: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/symphonynet/SymphonyNet/HEAD/music_bpe_exec -------------------------------------------------------------------------------- /src/fairseq/linear_transformer_inference/__init__.py: -------------------------------------------------------------------------------- 1 | from . import linear_transformer_multi -------------------------------------------------------------------------------- /model_complete.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/symphonynet/SymphonyNet/HEAD/model_complete.jpg -------------------------------------------------------------------------------- /src/fairseq/linear_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import linear_transformer_lm, linear_transformer_multi -------------------------------------------------------------------------------- /src/musicBPE/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include fastBPE/*.cc 2 | include fastBPE/*.hpp 3 | include fastBPE/*.pyx 4 | -------------------------------------------------------------------------------- /vocab.sh: -------------------------------------------------------------------------------- 1 | SIZE_0=233 2 | SIZE_1=36 3 | SIZE_2=23 4 | SIZE_3=35 5 | MAX_REL_POS=70 6 | MAX_MEA_POS=533 7 | -------------------------------------------------------------------------------- /data/midis/beethoven_symphony_9_1_(c)cvikl.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/symphonynet/SymphonyNet/HEAD/data/midis/beethoven_symphony_9_1_(c)cvikl.mid -------------------------------------------------------------------------------- /data/midis/beethoven_symphony_9_2_(c)cvikl.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/symphonynet/SymphonyNet/HEAD/data/midis/beethoven_symphony_9_2_(c)cvikl.mid -------------------------------------------------------------------------------- /data/midis/beethoven_symphony_9_3_(c)cvikl.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/symphonynet/SymphonyNet/HEAD/data/midis/beethoven_symphony_9_3_(c)cvikl.mid -------------------------------------------------------------------------------- /data/midis/beethoven_symphony_9_4_(c)cvikl.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/symphonynet/SymphonyNet/HEAD/data/midis/beethoven_symphony_9_4_(c)cvikl.mid -------------------------------------------------------------------------------- /data/.gitignore: 
-------------------------------------------------------------------------------- 1 | !midis/ 2 | !model_spec/linear_4096_chord_bpe_hardloss1/vocabs/ 3 | !model_spec/linear_4096_chord_bpe_hardloss1/bin/dict.txt 4 | !.gitignore 5 | !bpe_res/codes.txt 6 | * 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | fairseq==0.10.2 3 | tensorboardX 4 | six 5 | more_itertools 6 | p_tqdm 7 | miditoolkit 8 | chorder 9 | scipy 10 | matplotlib 11 | pytorch-fast-transformers==0.4.0 12 | -------------------------------------------------------------------------------- /config.sh: -------------------------------------------------------------------------------- 1 | BPE=0 2 | IGNORE_META_LOSS=1 3 | PI_LEVEL=2 4 | RECOVER=0 5 | SEED=1998 6 | MAX_POS_LEN=4096 7 | MAXPIECES=10000000 8 | RATIO=4 9 | SOR=4 10 | TOTAL_UPDATES=210000 11 | WARMUP_UPDATES=5000 12 | PEAK_LR=0.0003 13 | BATCH_SIZE=128 14 | MAX_SENTENCES=2 15 | -------------------------------------------------------------------------------- /data/model_spec/linear_4096_chord_bpe_hardloss1/vocabs/vocab_1.json: -------------------------------------------------------------------------------- 1 | {"0": "", "1": "", "2": "", "3": "", "4": "r1", "5": "r2", "6": "r3", "7": "r4", "8": "r5", "9": "r6", "10": "r7", "11": "r8", "12": "r9", "13": "ra", "14": "rb", "15": "rc", "16": "rd", "17": "re", "18": "rf", "19": "rg", "20": "rh", "21": "ri", "22": "rj", "23": "rk", "24": "rl", "25": "rm", "26": "rn", "27": "ro", "28": "rp", "29": "rq", "30": "rr", "31": "rs", "32": "rt", "33": "ru", "34": "rv", "35": "rw"} -------------------------------------------------------------------------------- /data/model_spec/linear_4096_chord_bpe_hardloss1/vocabs/vocab_2.json: -------------------------------------------------------------------------------- 1 | {"0": "", "1": "", "2": "", "3": "", "4": "t0", 
"5": "t1", "6": "t2", "7": "t3", "8": "t4", "9": "t5", "10": "t6", "11": "t7", "12": "t8", "13": "t9", "14": "ta", "15": "tb", "16": "tc", "17": "td", "18": "te", "19": "tf", "20": "tg", "21": "th", "22": "ti", "23": "tj", "24": "tk", "25": "tl", "26": "tm", "27": "tn", "28": "to", "29": "tp", "30": "tq", "31": "tr", "32": "ts", "33": "tt", "34": "tu", "35": "tv", "36": "tw", "37": "tx", "38": "ty", "39": "tz", "40": "tA", "41": "tB", "42": "tC", "43": "tD"} -------------------------------------------------------------------------------- /src/musicBPE/fastBPE/main.cc: -------------------------------------------------------------------------------- 1 | #include "fastBPE.hpp" 2 | 3 | using namespace std; 4 | using namespace fastBPE; 5 | 6 | void printUsage() { 7 | cerr 8 | << "usage: music_bpe_exec \n\n" 9 | << "The commands supported by fastBPE is:\n\n" 10 | << "learnbpe nCodes input1 [input2] learn BPE codes from one or two " 11 | "text files\n" 12 | << endl; 13 | } 14 | 15 | 16 | int main(int argc, char **argv) { 17 | if (argc < 2) { 18 | printUsage(); 19 | exit(EXIT_FAILURE); 20 | } 21 | string command = argv[1]; 22 | if (command == "learnbpe") { 23 | assert(argc == 4 || argc == 5); 24 | learnbpe(stoi(argv[2]), argv[3], argc == 5 ? 
argv[4] : ""); 25 | } else { 26 | printUsage(); 27 | exit(EXIT_FAILURE); 28 | } 29 | return 0; 30 | } 31 | -------------------------------------------------------------------------------- /src/musicBPE/fastBPE/fastBPE.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language = c++ 3 | 4 | from libcpp.vector cimport vector 5 | from libcpp.string cimport string 6 | 7 | cdef extern from "fastBPE.hpp" namespace "fastBPE": 8 | cdef cppclass BPEApplyer: 9 | BPEApplyer(const string& codes_path, const string& vocab_path) 10 | vector[string] apply(vector[string]& sentences) 11 | 12 | cdef class fastBPE: 13 | cdef BPEApplyer* c_obj 14 | 15 | def __dealloc__(self): 16 | del self.c_obj 17 | 18 | def __init__(self, codes_path, vocab_path=""): 19 | self.c_obj = new BPEApplyer(codes_path.encode(), vocab_path.encode()) 20 | 21 | def apply(self, sentences): 22 | cdef vector[string] s = [x.encode() for x in sentences] 23 | cdef vector[string] res = self.c_obj.apply(s) 24 | return [x.decode() for x in res] 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 CCOM NLP4Music AI Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/musicBPE/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2019 Guillaume Lample,Timothée Lacroix 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/musicBPE/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Music BPE 3 | This submodule is forked from https://github.com/glample/fastBPE, and adapted for music BPE. 4 | See more details at the original repository. 5 | 6 | ## Installation 7 | 8 | Compile with: 9 | ``` 10 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o music_bpe_exec 11 | ``` 12 | 13 | ## Usage: 14 | 15 | ### List commands 16 | ``` 17 | ./music_bpe_exec 18 | usage: music_bpe_exec 19 | 20 | The commands supported by fastBPE are: 21 | 22 | learnbpe nCodes input1 [input2] learn BPE codes from one or two text files 23 | 24 | ``` 25 | 26 | 27 | ### Learn codes 28 | ``` 29 | ./music_bpe_exec learnbpe 40000 train.de train.en > codes 30 | ``` 31 | 32 | ### Learn codes in preprocess/get_bpe_data.py 33 | ``` 34 | # First copy the executable file 'music_bpe_exec' to project's root directory 35 | # !cp music_bpe_exec ../../ 36 | output_dir = 'data/bpe_res/' 37 | with open(output_dir+'codes.txt', 'w') as stdout: 38 | with open(output_dir+'merged_voc_list.txt', 'w') as stderr: 39 | subprocess.run(['./musicbpe', 'learnbpe', f'{MERGE_CNT}', output_dir+'ori_voc_cnt.txt'], stdout=stdout, stderr=stderr) 40 | ``` 41 | 42 | -------------------------------------------------------------------------------- /src/musicBPE/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages, Extension 2 | from distutils.command.sdist import sdist as _sdist 3 | 4 | 5 | try: 6 | from Cython.Build import cythonize 7 | except ImportError: 8 | use_cython = False 9 | else: 10 | use_cython = True 11 | 12 | 13 | if use_cython: 14 | extension = 'pyx' 15 | else: 16 | extension = 'cpp' 17 | 18 | 19 | extensions = [ 20 | Extension( 21 | 'fastBPE', 22 | [ "fastBPE/fastBPE." 
+ extension ], 23 | language='c++', 24 | extra_compile_args=[ 25 | "-std=c++11", "-Ofast", "-pthread" 26 | ], 27 | ), 28 | ] 29 | if use_cython: 30 | extensions = cythonize(extensions) 31 | 32 | 33 | with open('README.md') as f: 34 | readme = f.read() 35 | 36 | 37 | setup( 38 | name = 'fastBPE', 39 | version = '0.1.1', 40 | description = 'C++ implementation of Neural Machine Translation of Rare Words with Subword Units, with Python API.', 41 | url = 'https://github.com/glample/fastBPE', 42 | long_description = readme, 43 | long_description_content_type = 'text/markdown', 44 | ext_package = '', 45 | ext_modules = extensions, 46 | packages=[ 47 | 'fastBPE', 48 | ], 49 | ) 50 | -------------------------------------------------------------------------------- /data/model_spec/linear_4096_chord_bpe_hardloss1/vocabs/vocab_3.json: -------------------------------------------------------------------------------- 1 | {"0": "", "1": "", "2": "", "3": "", "4": "x0", "5": "x1", "6": "x2", "7": "x3", "8": "x4", "9": "x5", "10": "x6", "11": "x7", "12": "x8", "13": "x9", "14": "xa", "15": "xb", "16": "xc", "17": "xd", "18": "xe", "19": "xf", "20": "xg", "21": "xh", "22": "xi", "23": "xj", "24": "xk", "25": "xl", "26": "xm", "27": "xn", "28": "xo", "29": "xp", "30": "xq", "31": "xr", "32": "xs", "33": "xt", "34": "xu", "35": "xv", "36": "xw", "37": "xx", "38": "xy", "39": "xz", "40": "xA", "41": "xB", "42": "xC", "43": "xD", "44": "xE", "45": "xF", "46": "xG", "47": "xH", "48": "xI", "49": "xJ", "50": "xK", "51": "xL", "52": "xM", "53": "xN", "54": "xO", "55": "xP", "56": "xQ", "57": "xR", "58": "xS", "59": "xT", "60": "xU", "61": "xV", "62": "xW", "63": "xX", "64": "xY", "65": "xZ", "66": "X0", "67": "X1", "68": "X2", "69": "X3", "70": "X4", "71": "X5", "72": "X6", "73": "X7", "74": "X8", "75": "X9", "76": "Xa", "77": "Xb", "78": "Xc", "79": "Xd", "80": "Xe", "81": "Xf", "82": "Xg", "83": "Xh", "84": "Xi", "85": "Xj", "86": "Xk", "87": "Xl", "88": "Xm", "89": "Xn", "90": "Xo", 
"91": "Xp", "92": "Xq", "93": "Xr", "94": "Xs", "95": "Xt", "96": "Xu", "97": "Xv", "98": "Xw", "99": "Xx", "100": "Xy", "101": "Xz", "102": "XA", "103": "XB", "104": "XC", "105": "XD", "106": "XE", "107": "XF", "108": "XG", "109": "XH", "110": "XI", "111": "XJ", "112": "XK", "113": "XL", "114": "XM", "115": "XN", "116": "XO", "117": "XP", "118": "XQ", "119": "XR", "120": "XS", "121": "XT", "122": "XU", "123": "XV", "124": "XW", "125": "XX", "126": "XY", "127": "XZ", "128": "y0", "129": "y1", "130": "y2", "131": "y3", "132": "y4"} -------------------------------------------------------------------------------- /train_linear_chord.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | while read line;do 4 | eval "$line" 5 | done < config.sh 6 | 7 | while read line;do 8 | eval "$line" 9 | done < vocab.sh 10 | 11 | # for model training 12 | if [ $BPE -eq 0 ]; then 13 | DATA_BIN=linear_${MAX_POS_LEN}_chord_hardloss${IGNORE_META_LOSS} 14 | else 15 | DATA_BIN=linear_${MAX_POS_LEN}_chord_bpe_hardloss${IGNORE_META_LOSS} 16 | fi 17 | DATA_BIN_DIR=data/model_spec/${DATA_BIN}/bin 18 | 19 | 20 | N_GPU_LOCAL=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) 21 | UPDATE_FREQ=$((${BATCH_SIZE} / ${MAX_SENTENCES} / ${N_GPU_LOCAL})) 22 | NN_ARCH=linear_transformer_multi 23 | CHECKPOINT_SUFFIX=${DATA_BIN}_PI${PI_LEVEL} 24 | 25 | 26 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" PYTHONWARNINGS="ignore" fairseq-train ${DATA_BIN_DIR} \ 27 | --seed ${SEED} \ 28 | --user-dir src/fairseq/linear_transformer \ 29 | --task symphony_modeling --criterion multiple_loss \ 30 | --save-dir ckpt/ --restore-file ckpt/checkpoint_last_${CHECKPOINT_SUFFIX}.pt \ 31 | --arch ${NN_ARCH} --sample-break-mode complete_doc --tokens-per-sample ${MAX_POS_LEN} --sample-overlap-rate ${SOR}\ 32 | --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --clip-norm 0.0 \ 33 | --lr ${PEAK_LR} --lr-scheduler polynomial_decay --warmup-updates 
${WARMUP_UPDATES} --total-num-update ${TOTAL_UPDATES} \ 34 | --dropout 0.1 --weight-decay 0.01 \ 35 | --batch-size ${MAX_SENTENCES} --update-freq ${UPDATE_FREQ} \ 36 | --max-update ${TOTAL_UPDATES} --log-format simple --log-interval 100 \ 37 | --checkpoint-suffix _${CHECKPOINT_SUFFIX} \ 38 | --tensorboard-logdir logs/${CHECKPOINT_SUFFIX} \ 39 | --ratio ${RATIO} --evt-voc-size ${SIZE_0} --dur-voc-size ${SIZE_1} --trk-voc-size ${SIZE_2} --ins-voc-size ${SIZE_3} \ 40 | --max-rel-pos ${MAX_REL_POS} --max-mea-pos ${MAX_MEA_POS} --perm-inv ${PI_LEVEL} \ 41 | 2>&1 | tee ${CHECKPOINT_SUFFIX}_part${RECOVER}.log -------------------------------------------------------------------------------- /src/fairseq/gen_batch.py: -------------------------------------------------------------------------------- 1 | import os, sys, time 2 | 3 | MAX_POS_LEN = 4096 4 | PI_LEVEL = 2 5 | IGNORE_META_LOSS = 1 6 | 7 | BPE = "_bpe" 8 | # BPE = "" 9 | DATA_BIN=f"linear_{MAX_POS_LEN}_chord{BPE}_hardloss{IGNORE_META_LOSS}" 10 | CHECKPOINT_SUFFIX=f"{DATA_BIN}_PI{PI_LEVEL}" 11 | DATA_BIN_DIR=f"data/model_spec/{DATA_BIN}/bin/" 12 | DATA_VOC_DIR=f"data/model_spec/{DATA_BIN}/vocabs/" 13 | from gen_utils import process_prime_midi, gen_one, get_trk_ins_map, get_note_seq, note_seq_to_midi_file, music_dict 14 | music_dict.load_vocabs_bpe(DATA_VOC_DIR, 'data/bpe_res/' if BPE == '_bpe' else None) 15 | 16 | 17 | from fairseq.models import FairseqLanguageModel 18 | custom_lm = FairseqLanguageModel.from_pretrained('.', 19 | checkpoint_file=f'ckpt/checkpoint_last_{CHECKPOINT_SUFFIX}.pt', 20 | data_name_or_path=DATA_BIN_DIR, 21 | user_dir="src/fairseq/linear_transformer_inference") 22 | print(f'Generation using model: {CHECKPOINT_SUFFIX}') 23 | 24 | m = custom_lm.models[0] 25 | m.cuda() 26 | m.eval() 27 | 28 | 29 | GEN_DIR = f'generated/{CHECKPOINT_SUFFIX}/' 30 | os.makedirs(GEN_DIR, exist_ok=True) 31 | 32 | 33 | 34 | if __name__ == '__main__': 35 | if len(sys.argv) != 5: 36 | print('usage: python 
src/fairseq/gen_batch.py ') 37 | exit(0) 38 | midi_name = sys.argv[1].split('/')[-1][:-4] 39 | max_measure_cnt = int(sys.argv[2]) 40 | max_chord_measure_cnt = int(sys.argv[3]) 41 | prime, ins_label = process_prime_midi(sys.argv[1], max_measure_cnt, max_chord_measure_cnt) 42 | gen_cnt = int(sys.argv[4]) 43 | for i in range(gen_cnt): 44 | while(True): 45 | try: 46 | generated, ins_logits = gen_one(m, prime, MIN_LEN = 1024) 47 | break 48 | except Exception as e: 49 | print(e) 50 | continue 51 | trk_ins_map = get_trk_ins_map(generated, ins_logits) 52 | note_seq = get_note_seq(generated, trk_ins_map) 53 | #print(f'{len(note_seq)} notes generated.') 54 | #print(note_seq) 55 | timestamp = time.strftime("%m-%d_%H-%M-%S", time.localtime()) 56 | note_seq_to_midi_file(note_seq, f'{GEN_DIR}{midi_name}_prime{max_measure_cnt}_chord{max_chord_measure_cnt}_{timestamp}.mid') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SymphonyNet 2 | ## Introduction 3 | SymponyNet is an open-source project aiming to generate complex multi-track and multi-instrument music like symphony. 4 | Our method is fully compatible with other types of music like pop, piano, solo music..etc. 5 | 6 |

7 | Schema.

8 | Have fun with SymphonyNet! 9 | 10 | ## Installation guide 11 | We highly recommend running this project in a `conda` environment. 12 | 13 | #### Prepare the environment: 14 | ``` 15 | conda create -n your_env_name python=3.8 16 | conda activate your_env_name 17 | 18 | cd path_to_your_env 19 | git clone this_project 20 | 21 | cd SymphonyNet 22 | cat requirements.txt | xargs -n 1 -L 1 pip install 23 | ``` 24 | We install the requirements one line at a time (`cat requirements.txt | xargs ...`) because the `pytorch-fast-transformers` package must be built against an already-installed torch; installing all requirements in a single `pip install` may cause a `pytorch-fast-transformers` build error. 25 | 26 | Note: Building `pytorch-fast-transformers` takes a while; please be patient. 27 | 28 | ## Training pipeline 29 | ### Step 1: 30 | - Put your MIDI files into `data/midis/` 31 | 32 | ### Step 2: 33 | - Run `python3 src/preprocess/preprocess_midi.py` from the project root directory 34 | 35 | Quick note: `preprocess_midi.py` processes all the MIDI files in parallel and converts them into a `raw_corpus.txt` file. In this 36 | file, each line of encoded text represents a full song. 37 | 38 | ### Step 3 (optional): 39 | - Run `python3 src/preprocess/get_bpe_data.py` if you want to train the model with Music BPE. More details about the fast BPE 40 | implementation can be found in [`Music BPE`](src/musicBPE/README.md). 41 | - Set `BPE=1` in the `config.sh` file 42 | 43 | Note: The provided `music_bpe_exec` binary only works on Linux. If you are using macOS or Windows, please re-compile 44 | the `music_bpe_exec` file by following the instructions [`here`](src/musicBPE/README.md). 45 | 46 | ### Step 4: 47 | - Run `python3 src/fairseq/make_data.py` to convert `raw_corpus.txt` into binary files for fairseq and create the `four 48 | vocabularies` mentioned in the paper. 49 | 50 | ### Step 5: 51 | - Run `sh train_linear_chord.sh` to train your own model!
52 | 53 | ## Generation pipeline 54 | - Put your checkpoint file or [download our pretrained model](https://drive.google.com/file/d/1xpkj_qN4MdLRkBdCXmfGjuWWjnTN1Og0/view?usp=sharing) into `ckpt/` 55 | - Run `python3 src/fairseq/gen_batch.py test.mid 5 0 1` to generate one symphony MIDI conditioned on the first 5 measures of test.mid, with no constraints of chord progression. 56 | - Or replace `test.mid` with your own prime MIDI and set how many measures of chords from the prime MIDI you may want to keep. 57 | - We provide a [Google Colab file](https://colab.research.google.com/github/symphonynet/SymphonyNet/blob/main/play_symphonynet.ipynb) `play_symphonynet.ipynb`, where you could follow the generation guide. 58 | 59 | ## License 60 | SymphonyNet is released under the MIT license 61 | -------------------------------------------------------------------------------- /src/encoding.py: -------------------------------------------------------------------------------- 1 | pit2alphabet = ['C', 'd', 'D', 'e', 'E', 'F', 'g', 'G', 'a', 'A', 'b', 'B'] 2 | char2pit = {x: id for id, x in enumerate(pit2alphabet)} 3 | 4 | 5 | def pit2str(x): 6 | octave = x // 12 7 | octave = octave - 1 if octave > 0 else 'O' 8 | rel_pit = x % 12 9 | return pit2alphabet[rel_pit] + str(octave) 10 | 11 | 12 | def str2pit(x): 13 | rel_pit = char2pit[x[0]] 14 | octave = (int(x[1]) if x[1] != 'O' else -1) + 1 15 | return octave * 12 + rel_pit 16 | 17 | 18 | def int2char(x): 19 | if x <= 9: 20 | return str(x) 21 | if x <= 35: 22 | return chr(ord('a') + (x - 10)) 23 | if x < 62: 24 | return chr(ord('A') + (x - 36)) 25 | assert False, f'invalid number {x}' 26 | 27 | 28 | def char2int(c): 29 | num = ord(c) 30 | A, a, Z, z = ord('A'), ord('a'), ord('Z'), ord('z') 31 | if num >= a and num <= z: 32 | return 10 + num - a 33 | elif num >= A and num <= Z: 34 | return 36 + num - A 35 | elif num >= ord('0') and num <= ord('9'): 36 | return num - ord('0') 37 | assert False, f'invalid character {c}' 38 | 39 | 40 | 
def pos2str(ons): 41 | if ons < 62: 42 | return 'p' + int2char(ons) 43 | return 'P' + int2char(ons - 62) 44 | 45 | 46 | def bom2str(ons): 47 | if ons < 62: 48 | return 'm' + int2char(ons) 49 | return 'M' + int2char(ons - 62) 50 | 51 | 52 | def dur2str(ons): 53 | if ons < 62: 54 | return 'r' + int2char(ons) 55 | return 'R' + int2char(ons - 62) 56 | 57 | 58 | def trk2str(ons): 59 | if ons < 62: 60 | return 't' + int2char(ons) 61 | return 'T' + int2char(ons - 62) 62 | 63 | 64 | def ins2str(ons): # 0 - 128 65 | if ons < 62: 66 | return 'x' + int2char(ons) 67 | ons -= 62 68 | if ons < 62: 69 | return 'X' + int2char(ons) 70 | ons -= 62 71 | if ons < 62: 72 | return 'y' + int2char(ons) 73 | return 'Y' + int2char(ons - 62) 74 | 75 | 76 | def ispitch(x): # judge if a event str is a pitch (CO - B9) 77 | return len(x) == 2 and x[0] in char2pit and (x[1] == 'O' or x[1].isdigit()) 78 | 79 | 80 | def ison(x): # judge if a event str is a bpe token 81 | if len(x) % 2 != 0 or len(x) < 2: 82 | return False 83 | for i in range(0, len(x), 2): 84 | if not ispitch(x[i:i + 2]): 85 | return False 86 | 87 | return True 88 | 89 | 90 | def bpe_str2int(x): 91 | if len(x) == 2: 92 | return (0, str2pit(x)) 93 | res = [] 94 | for i in range(0, len(x), 2): 95 | res.append(str2pit(x[i:i + 2])) 96 | return (1,) + tuple(sorted(res)) 97 | 98 | 99 | def sort_tok_str(x): 100 | c = x[0].lower() 101 | if c in ('r', 't', 'x', 'y'): 102 | # if x in ('RZ', 'TZ', 'YZ'): 103 | # return (c if c != 'y' else 'x', False, -1) 104 | return (c, not x[0].islower(), char2int(x[1])) 105 | if c in ('m', 'p'): 106 | return (c, not x[0].islower(), char2int(x[1])) 107 | 108 | if c == 'h': 109 | return (c, char2pit[x[1]] if x[1] != 'N' else 12, x[2:]) 110 | if c == 'n': 111 | return ('w', x) 112 | if ison(x): 113 | return ('a',) + bpe_str2int(x) 114 | 115 | return ('A', x[1] != 'b', x[1] != 'p', x[1] != 'e') 116 | -------------------------------------------------------------------------------- 
# /src/fairseq/linear_transformer/linear_transformer_lm.py:
# --------------------------------------------------------------------------------
"""Fairseq language model backed by a causal linear-attention transformer.

Registers the ``linear_transformer_lm`` model and architecture. The decoder
adds learned token and absolute-position embeddings, runs them through a
``fast_transformers`` encoder stack built with ``attention_type="causal-linear"``,
applies a final LayerNorm and projects onto the target vocabulary.
"""
from fast_transformers.builders import TransformerEncoderBuilder, RecurrentEncoderBuilder
from fast_transformers.masking import TriangularCausalMask, LengthMask

import logging
import os
import sys
import numpy as np
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq.models import (
    FairseqDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)

# Sequence length used when args provides neither max_seq_len nor tokens_per_sample.
DEFAULT_MAX_TARGET_POSITIONS = 1024


@register_model("linear_transformer_lm")
class LinearTransformerLanguageModel(FairseqLanguageModel):
    """Thin ``FairseqLanguageModel`` wrapper around ``LinearTransformerDecoder``."""

    def __init__(self, decoder):
        super().__init__(decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        options = (
            ('--embed-dim', int, 'N', 'embedding dimension'),
            ('--num-attention-heads', int, 'N', 'num attention heads'),
            ('--num-layers', int, 'N', 'num layers'),
            ('--dropout', float, 'D',
             'dropout probability for all fully connected layers '
             'in the embeddings, encoder, and pooler'),
        )
        # fmt: on
        for flag, value_type, meta, help_text in options:
            parser.add_argument(flag, type=value_type, metavar=meta, help=help_text)

    @classmethod
    def build_model(cls, args, task):
        """Fill architecture defaults into *args*, then build a new model instance."""
        base_architecture(args)
        decoder = LinearTransformerDecoder(args, task)
        return cls(decoder)


class LinearTransformerDecoder(FairseqDecoder):
    """Decoder-only causal linear transformer producing next-token logits."""

    def __init__(self, args, task):
        vocab = task.target_dictionary
        super().__init__(vocab)
        width = args.embed_dim
        head_width = width // args.num_attention_heads

        self.embed_dim = width
        self.max_seq_len = args.max_seq_len
        # NOTE: submodules are created/registered in a fixed order; both the
        # random initialisation sequence and state_dict keys depend on it.
        self.wte = nn.Embedding(len(vocab), width)
        # One extra row: position index 0 is reserved for padding tokens
        # (see extract_features).
        self.wpe = nn.Embedding(self.max_seq_len + 1, width)
        self.drop = nn.Dropout(args.dropout)
        self.ln_f = nn.LayerNorm(width, eps=1e-6)
        builder = TransformerEncoderBuilder.from_kwargs(
            n_layers=args.num_layers,
            n_heads=args.num_attention_heads,
            query_dimensions=head_width,
            value_dimensions=head_width,
            feed_forward_dimensions=4 * width,
            activation='gelu',
            dropout=args.dropout,
            attention_type="causal-linear",
        )
        self.model = builder.get()
        self.lm_head = nn.Linear(width, len(vocab), bias=False)
        self.apply(self._init_weights)

        # Zero the padding-token embedding and the padding position row.
        self.pad_idx = vocab.pad()
        self.wte.weight.data[self.pad_idx].zero_()
        self.wpe.weight.data[0].zero_()

    def _init_weights(self, module):
        """Initialise one submodule (used via ``self.apply``).

        Linear/Embedding weights ~ N(0, embed_dim ** -0.5) with Linear biases
        zeroed; LayerNorm is reset to identity (weight 1, bias 0).
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.embed_dim ** -0.5)
            has_bias = isinstance(module, nn.Linear) and module.bias is not None
            if has_bias:
                module.bias.data.zero_()
            return
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, prev_output_tokens, src_lengths=None):
        """Return a 1-tuple holding vocabulary logits for every input position.

        ``src_lengths`` is accepted for fairseq compatibility but is not used.
        """
        hidden = self.extract_features(prev_output_tokens)
        return (self.lm_head(hidden),)

    def extract_features(self, prev_output_tokens):
        """Embed *prev_output_tokens* and run the causal linear transformer.

        Non-pad tokens get positions 1..seq_len, while pad tokens get position
        0, whose embedding row is zeroed in ``__init__``. The length mask counts
        non-pad tokens per row.
        """
        device = prev_output_tokens.device
        batch_size, length = prev_output_tokens.size()
        not_pad = prev_output_tokens.ne(self.pad_idx).long().to(device)
        positions = torch.arange(1, 1 + length).to(device).repeat(batch_size, 1)
        position_ids = not_pad * positions
        len_mask = LengthMask(torch.sum(not_pad, axis=1), max_len=length, device=device)

        x = self.drop(self.wte(prev_output_tokens) + self.wpe(position_ids))
        attn_mask = TriangularCausalMask(length, device=x.device)
        hidden = self.model(x, attn_mask, len_mask)
        return self.ln_f(hidden)

    def max_positions(self):
        """Maximum supported sequence length."""
        return self.max_seq_len


@register_model_architecture("linear_transformer_lm", "linear_transformer_lm")
def base_architecture(args):
    """Fill unset hyper-parameters on *args* with the default architecture values."""
    if getattr(args, "max_seq_len", None) is None:
        args.max_seq_len = getattr(
            args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
        )
    defaults = (
        ("embed_dim", 768),
        ("num_attention_heads", 12),
        ("num_layers", 12),
        ("dropout", 0.1),
    )
    for attr_name, fallback in defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))

# --------------------------------------------------------------------------------
/src/preprocess/get_bpe_data.py: -------------------------------------------------------------------------------- 1 | import time, os, json 2 | from collections import Counter 3 | from pprint import pprint 4 | from tqdm import tqdm 5 | import subprocess#, multiprocessing 6 | from functools import partial 7 | from p_tqdm import p_uimap 8 | RATIO = 4 9 | MERGE_CNT = 700 10 | CHAR_CNT = 128 11 | WORKERS = 32 12 | 13 | import sys, os 14 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 15 | from encoding import pit2str, str2pit, ispitch 16 | 17 | def resort(voc: str) -> str: 18 | assert(len(voc) % 2 == 0), voc 19 | pitch_set = list(set(voc[i:i+2] for i in range(0, len(voc), 2))) 20 | assert len(pitch_set) * 2 == len(voc), voc 21 | return ''.join(sorted(pitch_set, key=str2pit)) 22 | 23 | def gettokens(voc: set, merges): 24 | assert len(voc) > 1, voc 25 | last_idx = 0 26 | while(len(voc) > 1): 27 | flag = False 28 | for i in range(last_idx, len(merges)): 29 | t1, t2, t3 = merges[i] 30 | if t1 in voc and t2 in voc: 31 | voc.remove(t1) 32 | voc.remove(t2) 33 | voc.add(t3) 34 | flag = True 35 | last_idx = i+1 36 | break 37 | if not flag: 38 | break 39 | return voc 40 | 41 | def merge_mulpies(new_toks, mulpies, other, merges, merged_vocs, divide_res): 42 | assert other is not None, mulpies 43 | for dur, mulpi in mulpies.items(): 44 | if len(mulpi) > 1: # apply bpe (with saved tokenization method) 45 | mulpi_sorted = tuple(sorted(list(mulpi), key=str2pit)) 46 | if mulpi_sorted in divide_res: 47 | submulpies = divide_res[mulpi_sorted] 48 | else: 49 | submulpies = sorted(gettokens(set(str2pit(x) for x in mulpi_sorted), merges)) 50 | 51 | for submulpi_num in submulpies: 52 | new_toks.extend([merged_vocs[submulpi_num], dur]+other) 53 | else: 54 | new_toks.extend([list(mulpi)[0], dur]+other) 55 | 56 | def apply_bpe_for_sentence(toks, merges, merged_vocs, divide_res, ratio=RATIO): 57 | if isinstance(toks, str): 58 | toks = toks.split() 59 | new_toks = 
[] 60 | mulpies = dict() 61 | other = None 62 | 63 | for idx in range(0, len(toks), ratio): 64 | e, d = toks[idx:idx+2] 65 | if not ispitch(e): 66 | if len(mulpies) > 0: 67 | merge_mulpies(new_toks, mulpies, other, merges, merged_vocs, divide_res) 68 | mulpies = dict() 69 | new_toks.extend(toks[idx:idx+ratio]) 70 | else: 71 | mulpies.setdefault(d, set()).add(e) 72 | other = toks[idx+2:idx+ratio] 73 | 74 | if len(mulpies) > 0: 75 | merge_mulpies(new_toks, mulpies, other, merges, merged_vocs, divide_res) 76 | 77 | assert len(new_toks) % ratio == 0, f'error new token len {len(new_toks)}' 78 | 79 | return new_toks 80 | 81 | def load_before_apply_bpe(bpe_res_dir): 82 | merged_vocs = [pit2str(i) for i in range(CHAR_CNT)] 83 | merged_voc_to_int = {pit2str(i):i for i in range(CHAR_CNT)} 84 | merges = [] 85 | with open(bpe_res_dir+'codes.txt', 'r') as f: 86 | for line in f: 87 | a, b, _ = line.strip().split() 88 | a,b,ab = resort(a), resort(b), resort(a+b) 89 | 90 | a_ind, b_ind, ab_ind = merged_voc_to_int[a], merged_voc_to_int[b], len(merged_vocs) 91 | merges.append((a_ind, b_ind, ab_ind)) 92 | 93 | merged_voc_to_int[ab] = ab_ind 94 | merged_vocs.append(ab) 95 | 96 | return merges, merged_vocs 97 | 98 | def apply_bpe_for_word_dict(mulpi_list, merges): 99 | # apply bpe for vocabs 100 | bpe_freq = Counter() 101 | divided_bpe_total = Counter() 102 | divide_res = dict() 103 | for ori_voc, cnt in tqdm(mulpi_list): 104 | ret = sorted(gettokens(set(str2pit(x) for x in ori_voc), merges)) 105 | divide_res[ori_voc] = ret 106 | divided_bpe_total[len(ret)] += cnt 107 | for r in ret: 108 | bpe_freq[merged_vocs[r]] += cnt 109 | 110 | return divide_res, divided_bpe_total, bpe_freq 111 | 112 | def count_single_mulpies(toks, ratio=RATIO): 113 | if isinstance(toks, str): 114 | toks = toks.split() 115 | mulpies = dict() 116 | chord_dict = Counter() 117 | l_toks = len(toks) 118 | for idx in range(0, l_toks, ratio): 119 | e, d = toks[idx:idx+2] 120 | 121 | if not ispitch(e): 122 | if 
len(mulpies) > 0: 123 | for dur, mulpi in mulpies.items(): 124 | if len(mulpi) > 1: 125 | chord_dict[tuple(sorted(list(mulpi), key=str2pit))] += 1 126 | mulpies = dict() 127 | else: 128 | mulpies.setdefault(d, set()).add(e) 129 | 130 | if len(mulpies) > 0: 131 | for dur, mulpi in mulpies.items(): 132 | if len(mulpi) > 1: 133 | chord_dict[tuple(sorted(list(mulpi), key=str2pit))] += 1 134 | 135 | return chord_dict, l_toks // ratio 136 | 137 | 138 | if __name__ == '__main__': 139 | start_time = time.time() 140 | 141 | paragraphs = [] 142 | 143 | raw_data_path = 'data/preprocessed/raw_corpus.txt' 144 | merged_data_path = 'data/preprocessed/raw_corpus_bpe.txt' 145 | output_dir = 'data/bpe_res/' 146 | os.makedirs(output_dir, exist_ok=True) 147 | raw_data = [] 148 | with open(raw_data_path, 'r') as f: 149 | for line in tqdm(f, desc="reading original txt file..."): 150 | raw_data.append(line.strip()) 151 | 152 | chord_dict = Counter() 153 | before_total_tokens = 0 154 | for sub_chord_dict, l_toks in p_uimap(count_single_mulpies, raw_data, num_cpus=WORKERS): 155 | chord_dict += sub_chord_dict 156 | before_total_tokens += l_toks 157 | 158 | mulpi_list = sorted(chord_dict.most_common(), key=lambda x: (-x[1], x[0])) 159 | with open(output_dir+'ori_voc_cnt.txt', 'w') as f: 160 | f.write(str(len(mulpi_list)) + '\n') 161 | for k, v in mulpi_list: 162 | f.write(''.join(k) + ' ' + str(v) + '\n') 163 | with open(output_dir+'codes.txt', 'w') as stdout: 164 | with open(output_dir+'merged_voc_list.txt', 'w') as stderr: 165 | subprocess.run(['./music_bpe_exec', 'learnbpe', f'{MERGE_CNT}', output_dir+'ori_voc_cnt.txt'], stdout=stdout, stderr=stderr) 166 | print(f'learnBPE finished, time elapsed: {time.time() - start_time}') 167 | start_time = time.time() 168 | 169 | merges, merged_vocs = load_before_apply_bpe(output_dir) 170 | divide_res, divided_bpe_total, bpe_freq = apply_bpe_for_word_dict(mulpi_list, merges) 171 | with open(output_dir+'divide_res.json', 'w') as f: 172 | json.dump({' 
'.join(k):v for k, v in divide_res.items()}, f) 173 | with open(output_dir+'bpe_voc_cnt.txt', 'w') as f: 174 | for voc, cnt in bpe_freq.most_common(): 175 | f.write(voc + ' ' + str(cnt) + '\n') 176 | ave_len_bpe = sum(k*v for k, v in divided_bpe_total.items()) / sum(divided_bpe_total.values()) 177 | ave_len_ori = sum(len(k)*v for k, v in mulpi_list) / sum(v for k, v in mulpi_list) 178 | print(f'average mulpi length original: {ave_len_ori}, average mulpi length after bpe: {ave_len_bpe}') 179 | print(f'applyBPE for word finished, time elapsed: {time.time() - start_time}') 180 | start_time = time.time() 181 | 182 | # applyBPE for corpus 183 | 184 | after_total_tokens = 0 185 | with open(merged_data_path, 'w') as f: 186 | for x in tqdm(raw_data, desc="writing bpe data"): # unable to parallelize for out of memory 187 | new_toks = apply_bpe_for_sentence(x, merges, merged_vocs, divide_res) 188 | after_total_tokens += len(new_toks) // RATIO 189 | f.write(' '.join(new_toks) + '\n') 190 | print(f'applyBPE for corpus finished, time elapsed: {time.time() - start_time}') 191 | print(f'before tokens: {before_total_tokens}, after tokens: {after_total_tokens}, delta: {(before_total_tokens - after_total_tokens) / before_total_tokens}') -------------------------------------------------------------------------------- /src/fairseq/make_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from p_tqdm import p_uimap 3 | import random, json, time, os 4 | from tqdm import tqdm 5 | import torch 6 | import sys, os, multiprocessing 7 | from collections import Counter 8 | from pprint import pprint 9 | from fairseq.data.indexed_dataset import MMapIndexedDatasetBuilder 10 | from functools import partial 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | from encoding import sort_tok_str 13 | 14 | PAD = 1 15 | EOS = 2 16 | BOS = 0 17 | 18 | #RATIO = 4 19 | #SAMPLE_LEN_MAX = 4096 20 | #SOR = 4 21 | 
22 | WORKERS = 32 23 | 24 | # def get_mea_cnt(str_toks, ratio): 25 | # bom_idx = [] 26 | # for idx in range(0, len(str_toks), ratio): 27 | # if str_toks[idx][0].lower() == 'm': 28 | # bom_idx.append(idx) # extract all bom tokens idx 29 | # bom_idx.append(len(str_toks)) 30 | # ret = 0 31 | # for id, nid in zip(bom_idx[:-1], bom_idx[1:]): 32 | # ret += 1 33 | # return ret 34 | 35 | def process_single_piece(bundle_input, ratio, sample_len_max): 36 | line, str2int = bundle_input 37 | 38 | if isinstance(line, str): 39 | str_toks = line.split() 40 | else: 41 | str_toks = line 42 | 43 | measures = [] 44 | cur_mea = [] 45 | max_rel_pos = 0 46 | mea_tok_lengths = [] 47 | 48 | rel_pos = 0 49 | 50 | for idx in range(0, len(str_toks), ratio): 51 | c = str_toks[idx][0] 52 | 53 | if c.lower() == 'm': # BOM Token 54 | if len(cur_mea) > 0: # exlude first bom 55 | measures.append(cur_mea) 56 | mea_tok_lengths.append(len(cur_mea) // (ratio+1)) 57 | cur_mea = [] 58 | if rel_pos > max_rel_pos: 59 | max_rel_pos = rel_pos 60 | rel_pos = 0 61 | elif c.lower() == 'h': # chord token 62 | if rel_pos > max_rel_pos: 63 | max_rel_pos = rel_pos 64 | rel_pos = 0 65 | elif c.lower() == 'n': # CC/NT Token 66 | if rel_pos > max_rel_pos: 67 | max_rel_pos = rel_pos 68 | rel_pos = 1 69 | elif c.lower() == 'p': # pos token 70 | rel_pos += 2 71 | else: # on token 72 | pass 73 | cur_mea += [str2int[x] for x in str_toks[idx:idx+ratio]] + [rel_pos-1 if c.lower() == 'p' else rel_pos] 74 | #TODO: how to design rel_pos and measure pos? 
75 | if len(cur_mea) > 0: 76 | measures.append(cur_mea) 77 | mea_tok_lengths.append(len(cur_mea) // (ratio+1)) 78 | if rel_pos > max_rel_pos: 79 | max_rel_pos = rel_pos 80 | 81 | # tmp = get_mea_cnt(str_toks, ratio) 82 | # if get_mea_cnt(str_toks, ratio) != len(measures): 83 | # print(f'{tmp} {len(measures)} {len(mea_tok_lengths)}') 84 | 85 | len_cnter = Counter() 86 | for l in mea_tok_lengths: 87 | len_cnter[l // 10] += 1 88 | 89 | for idx in range(1, len(mea_tok_lengths)): 90 | mea_tok_lengths[idx] += mea_tok_lengths[idx-1] 91 | 92 | def get_cur_tokens(s, t): # return total cnt of tokens in measure [s, t] 93 | return mea_tok_lengths[t] - (mea_tok_lengths[s-1] if s > 0 else 0) 94 | 95 | maxl = 1 96 | for s in range(len(mea_tok_lengths)): 97 | t = s + maxl - 1 98 | 99 | while t < len(mea_tok_lengths) and get_cur_tokens(s, t) < sample_len_max: 100 | t += 1 101 | 102 | t = min(t, len(mea_tok_lengths) - 1) 103 | maxl = max(maxl, t - s + 1) 104 | 105 | return measures, len_cnter, max_rel_pos, maxl 106 | 107 | def myshuffle(l): 108 | ret = [] 109 | idx = list(range(len(l))) 110 | random.shuffle(idx) 111 | for id in idx: 112 | ret.append(l[id]) 113 | return ret 114 | 115 | 116 | def mp_handler(raw_data, str2int, output_file, ratio, sample_len_max, num_workers=WORKERS): 117 | begin_time = time.time() 118 | 119 | merged_sentences = [] 120 | mea_cnt_dis = Counter() 121 | mea_len_dis = Counter() 122 | max_rel_pos = 0 123 | maxl = 0 124 | with multiprocessing.Pool(num_workers) as p: 125 | for sentences, len_cnter, pos, l in p.imap_unordered(partial(process_single_piece, ratio=ratio, sample_len_max=sample_len_max), [(x, str2int) for x in raw_data]): 126 | merged_sentences.append(sentences) 127 | mea_len_dis += len_cnter 128 | max_rel_pos = max(max_rel_pos, pos) 129 | maxl = max(maxl, l) 130 | for sentences in merged_sentences: 131 | mea_cnt_dis[len(sentences) // 10] += 1 132 | 133 | 134 | print(f'measure collection finished, total {sum(len(x) for x in merged_sentences)} 
measures, time elapsed: {time.time()-begin_time} s') 135 | print(f'max cnt in a sample (rel_pos, measure): {max_rel_pos}, {maxl}') 136 | 137 | 138 | begin_time = time.time() 139 | 140 | if output_file.split('/')[-1] == 'train': 141 | with open('vocab.sh', 'a') as f: 142 | f.write(f'MAX_REL_POS={max_rel_pos+5}\n') 143 | f.write(f'MAX_MEA_POS={maxl*3+5}\n') 144 | with open('src/fairseq/mea_cnt_dis.txt', 'w') as f: 145 | for k, v in sorted(mea_cnt_dis.items()): 146 | f.write(f'{k*10} {v}\n') 147 | with open('src/fairseq/mea_len_dis.txt', 'w') as f: 148 | for k, v in sorted(mea_len_dis.items()): 149 | f.write(f'{k*10} {v}\n') 150 | 151 | ds = MMapIndexedDatasetBuilder(output_file+'.bin', dtype=np.uint16) 152 | for doc in tqdm(merged_sentences, desc='writing bin file'): 153 | for sentence in doc: 154 | ds.add_item(torch.IntTensor(sentence)) 155 | ds.add_item(torch.IntTensor([EOS])) 156 | 157 | ds.finalize(output_file+'.idx') 158 | 159 | print("write binary finished, write time elapsed {:.2f} s".format(time.time() - begin_time)) 160 | 161 | 162 | def makevocabs(line, ratio): 163 | toks = line.split() 164 | ret_sets = [] 165 | for i in range(ratio): 166 | sub_toks = toks[i::ratio] 167 | ret_sets.append(set(sub_toks)) 168 | return ret_sets 169 | 170 | 171 | if __name__ == '__main__': 172 | # --------- slice multi-track ---- 173 | SEED, SAMPLE_LEN_MAX, totpiece, RATIO, bpe, map_meta_to_pad = None, None, None, None, None, None 174 | print('config.sh: ') 175 | with open('config.sh', 'r') as f: 176 | for line in f: 177 | line = line.strip() 178 | if len(line) == 0: 179 | break 180 | print(line) 181 | line = line.split('=') 182 | assert len(line) == 2, f'invalid config {line}' 183 | if line[0] == 'SEED': 184 | SEED = int(line[1]) 185 | random.seed(SEED) 186 | elif line[0] == 'MAX_POS_LEN': 187 | SAMPLE_LEN_MAX = int(line[1]) 188 | elif line[0] == 'MAXPIECES': 189 | totpiece = int(line[1]) 190 | elif line[0] == 'RATIO': 191 | RATIO = int(line[1]) 192 | elif line[0] == 'BPE': 193 
| bpe = int(line[1]) 194 | elif line[0] == 'IGNORE_META_LOSS': 195 | map_meta_to_pad = int(line[1]) 196 | 197 | assert SEED is not None, "missing arg: SEED" 198 | assert SAMPLE_LEN_MAX is not None, "missing arg: MAX_POS_LEN" 199 | assert totpiece is not None, "missing arg: MAXPIECES" 200 | assert RATIO is not None, "missing arg: RATIO" 201 | assert bpe is not None, "missing arg: BPE" 202 | assert map_meta_to_pad is not None, "missing arg: IGNORE_META_LOSS" 203 | 204 | bpe = "" if bpe == 0 else "_bpe" 205 | raw_corpus = f'raw_corpus{bpe}' 206 | model_name = f"linear_{SAMPLE_LEN_MAX}_chord{bpe}" 207 | raw_data_path = f'data/preprocessed/{raw_corpus}.txt' 208 | output_dir = f'data/model_spec/{model_name}_hardloss{map_meta_to_pad}/' 209 | 210 | start_time = time.time() 211 | raw_data = [] 212 | with open(raw_data_path, 'r') as f: 213 | for line in tqdm(f, desc='reading...'): 214 | raw_data.append(line.strip()) 215 | if len(raw_data) >= totpiece: 216 | break 217 | 218 | sub_vocabs = dict() 219 | for i in range(RATIO): 220 | sub_vocabs[i] = set() 221 | 222 | for ret_sets in p_uimap(partial(makevocabs, ratio=RATIO), raw_data, num_cpus=WORKERS, desc='setting up vocabs'): 223 | for i in range(RATIO): 224 | sub_vocabs[i] |= ret_sets[i] 225 | 226 | voc_to_int = dict() 227 | for type in range(RATIO): 228 | sub_vocabs[type] |= set(('', '', '', '')) 229 | sub_vocabs[type] -= set(('RZ', 'TZ', 'YZ')) 230 | sub_vocabs[type] = sorted(list(sub_vocabs[type]), key=sort_tok_str) 231 | voc_to_int.update({v:i for i,v in enumerate(sub_vocabs[type]) }) 232 | output_dict = sorted(list(set(voc_to_int.values()))) 233 | max_voc_size = max(output_dict) 234 | print("max voc idx: ", max_voc_size) 235 | 236 | os.makedirs(output_dir + 'bin/', exist_ok=True) 237 | with open(output_dir + 'bin/dict.txt', 'w') as f: 238 | for i in range(4, max_voc_size+1): # [4, max_voc_size] 239 | f.write("%d 0\n"%i) 240 | 241 | 242 | os.makedirs(output_dir + 'vocabs/', exist_ok=True) 243 | for type in range(RATIO): 
244 | sub_vocab = sub_vocabs[type] 245 | with open(output_dir + 'vocabs/vocab_%d.json'%type, 'w') as f: 246 | json.dump({i:v for i,v in enumerate(sub_vocab)}, f) 247 | with open(output_dir + 'vocabs/ori_dict.json', 'w') as f: 248 | json.dump(voc_to_int, f) 249 | print('sub vocab size:', end = ' ') 250 | for type in range(RATIO): 251 | print(len(sub_vocabs[type]), end = ' ') 252 | print() 253 | with open(f'vocab.sh', 'w') as f: 254 | for type in range(RATIO): 255 | f.write(f'SIZE_{type}={len(sub_vocabs[type])}\n') 256 | 257 | totpiece = len(raw_data) 258 | print("total pieces: {:d}, create dict time: {:.2f} s".format(totpiece, time.time() - start_time)) 259 | 260 | raw_data = myshuffle(raw_data) 261 | os.makedirs(output_dir + 'bin/', exist_ok=True) 262 | train_size = min(int(totpiece*0.99), totpiece-2) 263 | splits = {'train': raw_data[:train_size], 'valid': raw_data[train_size:-1], 'test': raw_data[-1:]} 264 | 265 | 266 | voc_to_int.update({x:(PAD if map_meta_to_pad == 1 else BOS) for x in ('RZ', 'TZ', 'YZ')}) 267 | for mode in splits: 268 | print(mode) 269 | mp_handler(splits[mode], voc_to_int, output_dir + f'bin/{mode}', ratio=RATIO, sample_len_max=SAMPLE_LEN_MAX) 270 | -------------------------------------------------------------------------------- /data/bpe_res/codes.txt: -------------------------------------------------------------------------------- 1 | C4 E4 383842 2 | D4 F4 382729 3 | B1 D2 381451 4 | A4 C5 310842 5 | C4 e4 310833 6 | D2 g2 301570 7 | E4 G4 297587 8 | B3 D4 289873 9 | G4 b4 271198 10 | A3 C4 263329 11 | b3 D4 255889 12 | D4 D5 246267 13 | G4 G5 236171 14 | A3 D4 232491 15 | g4 A4 229265 16 | D2 D3 227889 17 | F4 A4 224217 18 | C2 C3 221293 19 | G3 G4 218738 20 | C2 D2 214198 21 | E3 G3 213825 22 | D5 F5 211749 23 | G4 B4 206410 24 | d4 E4 203640 25 | G3 b3 198399 26 | A4 A5 198334 27 | C2 g2 197255 28 | C5 E5 195800 29 | G2 G3 192667 30 | F4 F5 192367 31 | C4 C5 190217 32 | E4 E5 185984 33 | e4 G4 178841 34 | A2 A3 173548 35 | F2 F3 172860 
36 | B4 D5 171875 37 | C4 F4 170693 38 | B3 E4 166211 39 | C5 e5 160444 40 | C3 C4 159549 41 | F4 a4 157937 42 | E2 E3 156801 43 | b4 D5 154168 44 | A3 A4 153683 45 | E5 G5 151430 46 | b4 b5 149679 47 | F3 A3 148386 48 | B1 A3 145614 49 | e4 e5 140995 50 | g2 g3 140674 51 | C5 C6 140382 52 | D4 G4 140003 53 | D5 D6 135065 54 | D4 g4 132474 55 | g3 A3 131369 56 | G4 C5 131185 57 | B1 g2 130777 58 | b3 e4 129284 59 | F4 b4 127863 60 | E4 A4 126791 61 | D3 D4 126529 62 | F3 a3 123627 63 | A4 D5 123367 64 | G3 B3 121866 65 | b3 d4 121147 66 | e2 e3 120273 67 | d5 E5 119393 68 | b2 b3 119146 69 | G3 C4 116166 70 | B4 B5 114356 71 | g4 g5 112272 72 | D4 A4 111911 73 | D3 F3 111638 74 | a4 a5 110462 75 | e4 g4 108644 76 | G3 D4 108483 77 | a4 C5 107949 78 | b1 b2 107909 79 | a4 B4 106603 80 | C4 G4 106326 81 | G1 G2 104691 82 | a3 C4 104067 83 | D2 g3 103579 84 | A1 A2 103247 85 | b3 b4 100974 86 | F5 A5 100140 87 | e5 G5 100031 88 | B2 B3 99989 89 | D5 g5 98587 90 | e3 G3 98183 91 | d4 d5 97973 92 | G2 D3 97631 93 | F3 F4 97057 94 | D3 A3 95853 95 | A3 d4 94383 96 | B3 B4 93772 97 | d2 d3 88571 98 | D5 G5 87752 99 | b4 d5 86720 100 | d4 F4 85422 101 | C3 G3 84439 102 | B1 B2 84191 103 | a2 a3 83457 104 | A3 E4 83275 105 | E5 E6 82455 106 | A4 d5 82219 107 | g5 A5 80521 108 | E3 E4 80270 109 | a3 B3 79410 110 | B4 E5 79407 111 | C3 E3 78721 112 | C5 F5 78150 113 | F3 b3 78048 114 | G5 b5 75489 115 | D3 G3 74871 116 | G4 D5 74229 117 | b4 e5 73985 118 | E4 a4 73726 119 | C4 D4 73408 120 | e4 a4 72647 121 | E3 A3 72090 122 | e3 e4 71347 123 | b3 F4 70809 124 | D3 g3 70615 125 | F3 C4 69990 126 | G3 E4 69513 127 | F5 F6 69258 128 | d5 d6 68998 129 | F4 C5 68991 130 | C2 F2 68972 131 | d4 g4 67499 132 | D2 A3 66594 133 | F1 F2 66016 134 | A5 C6 65129 135 | a3 a4 64828 136 | D2 b2 63268 137 | G5 G6 62712 138 | e5 e6 62367 139 | D2 e3 62274 140 | g4 B4 62232 141 | F5 a5 61898 142 | g3 g4 61802 143 | E4 C5 61583 144 | e4 b4 61514 145 | A4 E5 61444 146 | A2 E3 61294 147 | a3 d4 
61026 148 | G5 B5 60425 149 | E2 g2 59861 150 | D2 A2 59504 151 | D4 b4 58819 152 | e3 g3 58517 153 | F2 C3 58174 154 | E4 B4 58170 155 | C2 e3 57211 156 | d2 g2 56839 157 | e5 g5 56316 158 | A3 F4 56197 159 | C3 D3 55962 160 | B3 e4 55759 161 | E5 A5 55364 162 | a1 a2 54730 163 | b2 F3 54571 164 | F4 D5 54250 165 | C5 G5 53926 166 | B3 G4 53779 167 | C3 F3 53763 168 | C3 e3 52672 169 | d5 F5 52586 170 | G3 e4 52219 171 | E3 B3 51801 172 | E2 B2 51353 173 | d3 d4 51179 174 | b3 G4 50958 175 | B1 C2 50494 176 | G4 E5 50367 177 | G5 C6 49770 178 | g4 b4 49260 179 | C2 G2 49077 180 | C4 A4 48987 181 | D4 B4 48670 182 | A3 g4 48472 183 | e3 b3 48278 184 | E3 a3 48111 185 | b4 F5 47050 186 | a4 d5 46890 187 | E1 E2 46761 188 | G4 C4E4 46736 189 | A2 D3 46700 190 | D2 F2 46526 191 | a3 e4 46521 192 | C2 a2 45985 193 | C2 d3 44438 194 | A4 F5 44433 195 | F5 b5 44010 196 | D1 D2 43923 197 | B4 G5 43864 198 | E5 a5 43440 199 | F4 A4C5 43152 200 | B3 g4 42937 201 | g3 B3 41928 202 | g4 D5 41628 203 | F2 A2 41002 204 | B5 D6 40936 205 | d3 E3 40704 206 | e3 a3 40466 207 | C2 b2 40418 208 | a2 e3 40395 209 | A3 B1D2 39815 210 | b3 D4F4 39731 211 | A4 g5 39518 212 | B1 d2 39463 213 | B4 e5 39340 214 | e4 C5 39264 215 | G4 e5 39108 216 | D5 A5 38823 217 | F3 A3C4 37855 218 | dO C1 37853 219 | b4 G5 37759 220 | g5 g6 37510 221 | G2 C3 37327 222 | A5 A6 37294 223 | a5 B5 36668 224 | C5 A5 36622 225 | D4 E4 36443 226 | b2 D3 35825 227 | g3 D4 35435 228 | A2 C3 35167 229 | a5 C6 35084 230 | a3 F4 34996 231 | b5 D6 34974 232 | G3 B3D4 34966 233 | d5 g5 34948 234 | B2 D3 34446 235 | C2 B3 34355 236 | B1 e3 34292 237 | D2 d3 34283 238 | A3 C4E4 34018 239 | g3 e4 33890 240 | F3 D4 33589 241 | C6 E6 33509 242 | g3 b3 33387 243 | g1 g2 33256 244 | B1 F2 33084 245 | B2 E3 33030 246 | C1 C2 32749 247 | C3 E3G3 32709 248 | d4 a4 32464 249 | g3 d4 32288 250 | e1 e2 32213 251 | g4 A3D4 32195 252 | F4 A3C4 32053 253 | G4 A4 31997 254 | B1 b2 31551 255 | A5 D6 31182 256 | F4 G4 30924 257 | D4 
g4A4 30871 258 | d4 A4 30686 259 | B1 d3 30399 260 | E4 d5 30015 261 | A4 D4F4 29814 262 | E3 C4 29442 263 | B2 g3 29432 264 | G4 B3D4 29345 265 | D2 G2 29335 266 | b2 e3 29278 267 | G3 C4E4 28957 268 | D5 b5 28902 269 | D6 F6 28890 270 | d2 D2 28733 271 | D5 B5 28686 272 | E5 C6 28671 273 | e5 a5 28665 274 | d3 g3 28498 275 | G4 C4e4 28324 276 | C5 D5 28104 277 | a4 e5 28061 278 | A5 d6 27853 279 | B4 g5 27775 280 | g2 d3 27727 281 | C4 F4A4 27550 282 | C4 a4 27229 283 | B1 g3 27197 284 | D5 G4B4 27036 285 | D2 A5 26995 286 | d4 b4 26958 287 | d3 a3 26949 288 | B3 a4 26741 289 | e2 b2 26659 290 | g4 d5 26623 291 | a5 a6 26536 292 | C5 E4G4 26501 293 | g5 B5 26137 294 | C6 e6 25974 295 | E3 A2A3 25920 296 | e4 A4 25882 297 | C2 E2 25577 298 | a3 E4 25565 299 | A4 b4 25485 300 | e5 b5 25471 301 | d4 e4 24744 302 | C2 g3 24651 303 | b4 C5 24580 304 | G3 b3D4 24579 305 | e4 G4b4 24363 306 | a4 F5 24288 307 | e4 F4 24191 308 | E5 B5 24051 309 | F5 C6 23958 310 | C3 F2F3 23844 311 | C4 d4 23796 312 | g5 b5 23708 313 | D3 B3 23688 314 | b5 b6 23641 315 | G2 B2 23630 316 | b3 C4 23573 317 | A3 d4E4 23509 318 | D5 E5 23485 319 | E5 A4C5 23348 320 | d3 F3 23162 321 | D2 a2 23125 322 | D4 G4B4 23052 323 | D3 C5 22971 324 | b4 D4F4 22939 325 | a3 C4e4 22932 326 | A4 dOC1 22815 327 | a3 D4 22406 328 | E6 G6 22401 329 | a4 C4e4 22303 330 | F3 b3D4 22187 331 | D3 G2G3 21944 332 | E4 b4 21746 333 | e3 C4 21731 334 | G2 b2 21613 335 | B4 E4G4 21420 336 | C6 F6 21219 337 | b5 e6 21002 338 | b3 e4G4 20921 339 | e5 F5 20904 340 | E2 b2 20866 341 | F5 D6 20670 342 | e4 G3b3 20649 343 | A4 B4 20570 344 | F5 G5 20550 345 | a4 E5 20516 346 | D3 b3 20485 347 | d5 A5 20468 348 | G2 A2 20345 349 | b5 d6 20311 350 | e5 C6 20227 351 | F2 b2 20225 352 | g2 e3 20204 353 | d2 a2 20201 354 | F4 B4 20094 355 | D2 E2 19849 356 | G5 D6 19751 357 | C6 C7 19495 358 | C3 A3 19453 359 | B4 a5 19329 360 | C4 g4 19321 361 | G4 b3D4 19180 362 | C2 d2 19151 363 | A4 C4E4 19135 364 | b3 E4 18948 365 | d3 A3 
18854 366 | B1 E2 18787 367 | B3 E3G3 18744 368 | B2 E2E3 18723 369 | C5 a5 18665 370 | d6 E6 18628 371 | G5 C5E5 18609 372 | G5 A5 18405 373 | a2 d3 18317 374 | C4 E3G3 18126 375 | B3 F4 18097 376 | d5 b5 18089 377 | E4 F4 18080 378 | C2 A3 18057 379 | A4 d4E4 17992 380 | D4 e4 17948 381 | D5 G4b4 17827 382 | E5 F5 17785 383 | F4 d5 17715 384 | B5 B6 17689 385 | A3 e4 17570 386 | a5 d6 17561 387 | G3 A3 17481 388 | e3 G3b3 17474 389 | B5 E6 17318 390 | b2 d3 17202 391 | e4 B4 17152 392 | F3 G3 17145 393 | A2 D2D3 17116 394 | A3 B3 17111 395 | g5 D6 17098 396 | d4 F4a4 17092 397 | e4 E4 17069 398 | g2 A2 17019 399 | d4 G4 16857 400 | A2 d3 16801 401 | A3 D4F4 16731 402 | G4 d5 16647 403 | D3 g3A3 16606 404 | G4 C5E5 16421 405 | D3 A3D4 16380 406 | E3 d4 16288 407 | B1 a2 16182 408 | D4 a4 16178 409 | g4 e5 16134 410 | D6 g6 16081 411 | B1 G2 16046 412 | b3 g4 16029 413 | b4 E5 16023 414 | D6 G6 16000 415 | C7 A4dOC1 15982 416 | D2 B3 15953 417 | C2 A2 15932 418 | D3 F3A3 15831 419 | a4 b4 15813 420 | D5 e5 15785 421 | E2 A2 15719 422 | F4 b4D5 15689 423 | g2 A4 15641 424 | G3 C3C4 15592 425 | B4 C5 15578 426 | g3 C4 15549 427 | D3 C7 15542 428 | E4 g4 15526 429 | D2 B2 15522 430 | F3 b2b3 15424 431 | F3 A3D4 15384 432 | E2 G2 15301 433 | F3 d4 15288 434 | a3 B3E4 15280 435 | d5 a5 15275 436 | A5 b5 15156 437 | g2 a2 15102 438 | a3 d4F4 15059 439 | G4 a4 14989 440 | F5 A4C5 14963 441 | E4 a4B4 14909 442 | A4 e5 14878 443 | E2 g3 14877 444 | C2 D2g2 14842 445 | g2 b4 14797 446 | B2 G3 14722 447 | g2 B2 14678 448 | D3 C4 14674 449 | a2 C3 14568 450 | D5 g4A4 14455 451 | D6 D7 14368 452 | e6 G6 14298 453 | F2 G2 14250 454 | g3 A3D4 14164 455 | g3 A5 14137 456 | A2 B2 14070 457 | G3 d4 14028 458 | a4 D5 14021 459 | E5 d6 14013 460 | b4 D5F5 13942 461 | A5 E6 13938 462 | g4 B3D4 13791 463 | D3 E3 13586 464 | d5 e5 13582 465 | d4 D4 13572 466 | A5 B5 13569 467 | F3 B3 13515 468 | g4 C5 13489 469 | G2 e3 13442 470 | C6 D6 13418 471 | D2 b4 13414 472 | D4 G4b4 13379 473 | 
E4 A4C5 13309 474 | d1 d2 13278 475 | e5 A5 13257 476 | D2 C3 13203 477 | g3 B2B3 13190 478 | A2 F3 13156 479 | g2 b2 13128 480 | A4 d5E5 13124 481 | g5 G5 13115 482 | D2 e2 13103 483 | D5 F4A4 13051 484 | G3 F4 13036 485 | a3 b3 13008 486 | F4 b3d4 12988 487 | g2 B1D2 12987 488 | C5 F4a4 12979 489 | C4 F3a3 12953 490 | G2 E3 12947 491 | g4 G4 12893 492 | a3 C4F4 12867 493 | A2 g3 12829 494 | d5 E4A4 12789 495 | B4 F5 12709 496 | b2 G3 12701 497 | C5 e4G4 12695 498 | a4 C5e5 12636 499 | a4 B3E4 12564 500 | G3 C4e4 12547 501 | C5 g5 12520 502 | E2 a2 12490 503 | B1 A2 12419 504 | B3 E4G4 12386 505 | A5 D5F5 12356 506 | F6 F7 12343 507 | g2 E4 12325 508 | C4 b4 12325 509 | G3 B3E4 12234 510 | e5 E5 12193 511 | A3 A5 12149 512 | C5 d5 12116 513 | G2 C2C3 12093 514 | e4 a4C5 12079 515 | E2 e3 12011 516 | d2 g3 11958 517 | E3 a3B3 11893 518 | b5 C6 11835 519 | B1 A5 11752 520 | B3 C4 11710 521 | F2 a2 11692 522 | d5 D5 11683 523 | G5 E6 11678 524 | g4 a4 11584 525 | e5 G4b4 11579 526 | B4 d5 11462 527 | E2 F2 11446 528 | b2 g3 11409 529 | b4 g5 11367 530 | F2 g2 11290 531 | g4 b3d4 11285 532 | B3 e4g4 11268 533 | g2 b5 11165 534 | D3 G3B3 11105 535 | d5 G5 11056 536 | b2 D3F3 11040 537 | d2 e3 11033 538 | G5 e6 11021 539 | G4 C4C5 11016 540 | g3 D2g2 11012 541 | C1 D3 11001 542 | C3 F3A3 10958 543 | e6 g6 10885 544 | E3G3 B3E4 10791 545 | G6 G7 10775 546 | E3 F3 10763 547 | F6 A6 10720 548 | E5 g5 10691 549 | C2 e2 10639 550 | D2 E5 10586 551 | e3 a3C4 10577 552 | e3 a2a3 10486 553 | A3 G4 10485 554 | F4 g4 10478 555 | e1 D2 10442 556 | a4 C4F4 10423 557 | d4 F3a3 10389 558 | D2 E4 10378 559 | E3 B3E4 10348 560 | C3 e3G3 10322 561 | a4 A4 10277 562 | C2 D3 10274 563 | e5 B5 10197 564 | E6 E7 10169 565 | A4 D4D5 10136 566 | a2 B2 10079 567 | B5 E5G5 10067 568 | d6 F6 10054 569 | A3 b3 10005 570 | B2 e3 9966 571 | F2 D3 9951 572 | C3 a3 9848 573 | d3 g2g3 9832 574 | A2 C3E3 9829 575 | D5 a5 9817 576 | B2 C3 9662 577 | B5 e6 9622 578 | E3 A3C4 9618 579 | g5 A4D5 9614 580 | 
G5 B4D5 9586 581 | G3 a3 9544 582 | e3 A3 9529 583 | G5 C5e5 9498 584 | G5 a5 9447 585 | g3 b3d4 9298 586 | D2 D4 9292 587 | E3 A3d4 9284 588 | G4 C5e5 9269 589 | D4 C5 9163 590 | D3 e3 9129 591 | C3 C7A4dOC1 9110 592 | D6 E6 9088 593 | b2 e2e3 9073 594 | g2 G2 8997 595 | B3 d4 8974 596 | e3 B3 8967 597 | E3 b3 8960 598 | b4 e4e5 8947 599 | d2 E2 8933 600 | d6 g6 8866 601 | B5 C6 8843 602 | d3 C2D2 8842 603 | C5 F4F5 8839 604 | F2 B2 8838 605 | F3 g3 8788 606 | E5 G4B4 8787 607 | e3 b3e4 8786 608 | g5 d6 8765 609 | F5 B5 8757 610 | b3 D3F3 8734 611 | D4 G3G4 8722 612 | E2 d3 8716 613 | g2 D3 8668 614 | b4 B4 8581 615 | A4C5 F4F5 8522 616 | B3D4 G3G4 8501 617 | e3 F3 8494 618 | B2 G2D3 8465 619 | B1 e2 8458 620 | d3 F3a3 8457 621 | d5 g4A4 8449 622 | a4 d4d5 8446 623 | B4 E4E5 8430 624 | D2 G3 8419 625 | B4 D4g4 8418 626 | g2 C4 8394 627 | b5 F6 8375 628 | a5 e6 8338 629 | B3 C2D2 8319 630 | F5 g5 8260 631 | D2 E3 8259 632 | e2 a2 8239 633 | D2 F3 8234 634 | d3 b3 8232 635 | B4 e4g4 8211 636 | d5 F4a4 8147 637 | A4 D5F5 8143 638 | E6 A6 8125 639 | d4 g3A3 8096 640 | D2 A4 8093 641 | a5 b5 8060 642 | g3 B1D2 8048 643 | d2 F2 8037 644 | C6 F5A5 8032 645 | A5 D5g5 8014 646 | g3 E4 8003 647 | g4 b4d5 7992 648 | d6 d7 7988 649 | E4 A3A4 7980 650 | E2 C3 7902 651 | a2 A2 7882 652 | C2 B2 7811 653 | a2 g3 7775 654 | D5 G4G5 7766 655 | e6 e7 7754 656 | a5 A5 7670 657 | F2 C2D2 7652 658 | B1 B3 7646 659 | C6 G6 7565 660 | A4 G5 7556 661 | g2 d4 7479 662 | e3 G3C4 7462 663 | b0 b1 7445 664 | d3 D3 7438 665 | E2B2 E3G3B3E4 7436 666 | A2 F2C3 7369 667 | g6 A6 7363 668 | G4b4 e4e5 7350 669 | d4 F4b4 7339 670 | F5 d6 7288 671 | g3 b4 7276 672 | D2 e4 7262 673 | A5 B1A3 7237 674 | A5 F6 7228 675 | F2 e3 7190 676 | e2 G2 7156 677 | D4D5 g4A4 7152 678 | C4E4 G4C5 7111 679 | d3 B1D2 7079 680 | g3 B3e4 7057 681 | E5 b5 7034 682 | b4 e5G5 7026 683 | b3 B3 6969 684 | b5 B5 6957 685 | E6 a6 6928 686 | E5 a4B4 6907 687 | F6 a6 6906 688 | D3 a3 6892 689 | C5 F5A5 6869 690 | g3 a3 6845 691 
| D6 e6 6823 692 | E3 g3 6820 693 | e1 C2 6812 694 | B1 E3 6772 695 | A3C4 F3F4 6770 696 | D6 A6 6769 697 | B1 D3 6735 698 | g4 A3d4 6734 699 | G3G4 G2D3 6726 700 | D3 g4A3D4 6710 701 | -------------------------------------------------------------------------------- /src/musicBPE/fastBPE/fastBPE.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include // ftruncate 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | namespace fastBPE { 26 | 27 | using namespace std; 28 | 29 | const size_t kMaxPairs = 1000 * 1000 * 1000; 30 | 31 | 32 | int safeOpen(const char *file_path, int flags, mode_t mode = 0) { 33 | int fd = open(file_path, flags, mode); 34 | if (fd < 0) { 35 | fprintf(stderr, "Cannot open text file %s\n", file_path); 36 | exit(EXIT_FAILURE); 37 | } 38 | return fd; 39 | } 40 | char pit2chr(int pit) { 41 | return char(pit); 42 | } 43 | void readText(const char *fp, unordered_map &word_count) { 44 | char cur_word[300]; 45 | 46 | uint64_t total = 0; 47 | 48 | FILE* fin = fopen(fp, "r"); 49 | int tot; 50 | fscanf(fin, "%d", &tot); 51 | fprintf(stderr, "total %d words...\n", tot); 52 | for (int i = 0; i < tot; i++) { 53 | fscanf(fin, "%s", cur_word); 54 | int cnt; 55 | fscanf(fin, "%d", &cnt); 56 | string s(cur_word); 57 | word_count[s] = cnt; 58 | total += cnt; 59 | } 60 | 61 | 62 | fprintf(stderr, "Read %lu words (%lu unique) from text file.\n", total, 63 | word_count.size()); 64 | } 65 | 66 | std::pair output_or_count( 67 | unordered_map &bpe, size_t size, char *f, char *fo 68 | ) { 69 | string cur_word; 70 | size_t charOut = 0; 71 | uint64_t total = 0; 72 | for (size_t i = 0; i < size; i++) { 73 | auto &cur_char = f[i]; 74 | if (cur_char == ' ' || cur_char == '\n') { 75 | if 
(cur_word.size() == 0) { 76 | if (fo != nullptr) fo[charOut] = cur_char; 77 | charOut++; 78 | continue; 79 | } 80 | // end of word : write bpe to output 81 | auto it = bpe.find(cur_word); 82 | assert(it != bpe.end()); 83 | for (auto x : it->second) { 84 | if (fo != nullptr) fo[charOut] = x; 85 | charOut++; 86 | } 87 | if (fo != nullptr) fo[charOut] = cur_char; 88 | charOut++; 89 | 90 | total++; 91 | cur_word.clear(); 92 | } else { 93 | cur_word.push_back(cur_char); 94 | } 95 | } 96 | return std::make_pair(charOut, total); 97 | } 98 | 99 | void outputText(const char *fpo, const char *fp, 100 | unordered_map &bpe) { 101 | 102 | int fd = safeOpen(fp, O_RDONLY); 103 | auto fdOut = safeOpen(fpo, O_RDWR | O_CREAT | O_TRUNC, 0666); 104 | 105 | struct stat s; 106 | fstat(fd, &s); 107 | 108 | fprintf(stderr, "Applying BPE to %s ...\n", fp); 109 | auto size = s.st_size; 110 | char *f = (char *)mmap(NULL, size, PROT_READ, MAP_PRIVATE, fd, 0); 111 | 112 | auto p = output_or_count(bpe, size, f, nullptr); 113 | size_t out_size = p.first; 114 | 115 | if (ftruncate(fdOut, out_size) < 0) { 116 | fprintf(stderr, "Couldn't truncate output file %s to size %lu\n", fpo, 117 | out_size); 118 | exit(EXIT_FAILURE); 119 | } 120 | 121 | 122 | char *fo = (char *)mmap(NULL, out_size, PROT_WRITE, MAP_SHARED, fdOut, 0); 123 | if (fo == MAP_FAILED) { 124 | fprintf(stderr, "Output memory map failed : %d.\n", errno); 125 | exit(EXIT_FAILURE); 126 | } 127 | p = output_or_count(bpe, size, f, fo); 128 | fprintf(stderr, "Modified %lu words from text file.\n", p.second); 129 | munmap(fo, out_size); 130 | munmap(f, size); 131 | close(fdOut); 132 | close(fd); 133 | } 134 | 135 | struct pair_hash { 136 | template size_t operator()(const pair &p) const { 137 | auto h1 = hash{}(p.first); 138 | auto h2 = hash{}(p.second); 139 | size_t seed = h1; 140 | // boost::hash_combine 141 | return h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2); 142 | } 143 | }; 144 | char pit2alphabet[] = {'C', 'd', 'D', 'e', 'E', 'F', 
'g', 'G', 'a', 'A', 'b', 'B'}; 145 | char oct2alphabet[] = {'O', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}; 146 | void build_init_vocabs(unordered_map &token_to_int, 147 | vector &int_to_token) { 148 | 149 | for(int i = 0; i < 128; i++) { 150 | int octave = i / 12; 151 | int relpit = i % 12; 152 | string oct(1, oct2alphabet[octave]); 153 | string rel(1, pit2alphabet[relpit]); 154 | string s = rel + oct; 155 | int_to_token.push_back(s); 156 | token_to_int[s] = i; 157 | } 158 | } 159 | void tokenize(const unordered_map &word_count, 160 | unordered_map &token_to_int, 161 | vector &int_to_token, vector> &words, 162 | vector &counts) { 163 | build_init_vocabs(token_to_int, int_to_token); 164 | 165 | for (auto &x : word_count) { 166 | auto &word = x.first; 167 | 168 | words.push_back(list()); 169 | auto ¤t_word = words.back(); 170 | counts.push_back(x.second); 171 | 172 | int tmplen = word.length(); 173 | uint32_t lasttoken = -1; 174 | for (int i = 0; i < tmplen; i += 2) { 175 | auto new_token = word.substr(i, 2); 176 | if (i + 2 == tmplen) 177 | new_token = new_token;// + kEndWord; 178 | 179 | if (token_to_int.count(new_token) == 0) { 180 | fprintf(stderr, "Error init vocab: %s\n", (char*)new_token.c_str()); 181 | exit(0); 182 | } 183 | uint32_t cur_pit = token_to_int[new_token]; 184 | if (lasttoken != -1 && cur_pit <= lasttoken) { 185 | fprintf(stderr, "Error init vocab order: in %s\n", (char*)word.c_str()); 186 | exit(0); 187 | } 188 | lasttoken = cur_pit; 189 | current_word.push_back(cur_pit); 190 | } 191 | } 192 | } 193 | 194 | using tp = pair; 195 | using tps = pair; 196 | using pc = unordered_map *, pair_hash>; 197 | // process every word 198 | void count_in_word( 199 | list &word, uint32_t wi, uint32_t count, pc &pair_counts, 200 | vector> &contiguous_counts, 201 | unordered_map, pair_hash> &where) { 202 | 203 | tp cur_pair; 204 | 205 | for (auto i = word.begin(); i != word.end(); ++i) { 206 | auto tmp = i; 207 | ++tmp; 208 | for (auto j = tmp; j != 
word.end(); ++j) { 209 | cur_pair.first = *i; 210 | cur_pair.second = *j; 211 | auto it = pair_counts.find(cur_pair); 212 | if (it == pair_counts.end()) { 213 | contiguous_counts.emplace_back(0, cur_pair); 214 | auto *added = &contiguous_counts.back(); 215 | pair_counts.emplace(piecewise_construct, forward_as_tuple(cur_pair), forward_as_tuple(added)); 216 | where[cur_pair] = unordered_set(); 217 | } 218 | if (count > 0) {where[cur_pair].insert(wi);} else {fprintf(stderr, "count in word init error!\n"); exit(0);} 219 | pair_counts[cur_pair]->first += count; 220 | } 221 | } 222 | } 223 | 224 | void find_maxp(vector> &contiguous_counts, tp &maxp, 225 | int32_t &max_c) { 226 | max_c = 0; 227 | for (auto &x : contiguous_counts) { 228 | if (x.first > max_c) { 229 | max_c = x.first; 230 | maxp = x.second; 231 | } else if (x.first == max_c and x.second < maxp) { 232 | maxp = x.second; 233 | } 234 | } 235 | } 236 | 237 | 238 | void learnbpe(const uint32_t kNPairs, const char *inputFile1, 239 | const char *inputFile2) { 240 | // get vocab 241 | unordered_map word_count; 242 | readText(inputFile1, word_count); 243 | if (strcmp(inputFile2, "") != 0) { 244 | readText(inputFile2, word_count); 245 | } 246 | 247 | // a token is an int, it represents a string 248 | unordered_map token_to_int; 249 | vector int_to_token; 250 | 251 | vector> words; 252 | vector counts; 253 | 254 | tokenize(word_count, token_to_int, int_to_token, words, counts); 255 | 256 | // tp: pair 257 | vector> contiguous_counts; 258 | contiguous_counts.reserve(kMaxPairs); 259 | 260 | pc pair_counts; 261 | unordered_map, pair_hash> where_to_update; 262 | 263 | tp cur_pair; 264 | int32_t max_c = 0; 265 | tp max_p; 266 | for (uint32_t wi = 0; wi < words.size(); wi++) { 267 | count_in_word(words[wi], wi, counts[wi], pair_counts, contiguous_counts, 268 | where_to_update); 269 | } 270 | 271 | 272 | for (size_t i = 0; i < kNPairs; i++) { 273 | // create new token for pair. 
replace 274 | find_maxp(contiguous_counts, max_p, max_c); 275 | 276 | auto new_token = int_to_token[max_p.first] + int_to_token[max_p.second]; 277 | cout << int_to_token[max_p.first] << " " << int_to_token[max_p.second] 278 | << " " << max_c << endl; 279 | 280 | uint32_t new_token_id = int_to_token.size(); 281 | int_to_token.push_back(new_token); 282 | token_to_int[new_token] = new_token_id; 283 | max_c = 0; 284 | auto change_count = [&](tp pair, int32_t v, uint32_t wi) { 285 | auto it = pair_counts.find(pair); 286 | if (it != pair_counts.end()) { 287 | it->second->first += v; 288 | } else { 289 | if (v > 0) { 290 | contiguous_counts.emplace_back(v, pair); 291 | pair_counts.emplace(piecewise_construct, forward_as_tuple(pair), 292 | forward_as_tuple(&(contiguous_counts.back()))); 293 | where_to_update[pair] = unordered_set(); 294 | } 295 | } 296 | if (v > 0) 297 | where_to_update[pair].insert(wi); 298 | else 299 | where_to_update[pair].erase(wi); 300 | }; 301 | 302 | for (auto wi : where_to_update[max_p]) { 303 | 304 | auto &cur_word = words[wi]; 305 | int cnt = 0; 306 | for (auto it = cur_word.begin(); it != cur_word.end(); ++it) { 307 | if (*it == max_p.first || *it == max_p.second) 308 | cnt++; 309 | } 310 | if (cnt != 2) { // where to update is not maintained 311 | fprintf(stderr, "where to update is not maintained! 
%d %d %d ", cnt, max_p.first, max_p.second); 312 | for (auto it = cur_word.begin(); it != cur_word.end(); ++it) 313 | fprintf(stderr, "%d ", *it); 314 | fprintf(stderr, "\n"); 315 | continue; 316 | } 317 | 318 | 319 | auto it = cur_word.begin(); 320 | while (it != cur_word.end()) { 321 | if (*it != max_p.first && *it != max_p.second) { 322 | uint32_t u, v; 323 | 324 | u = *it; 325 | v = max_p.first; 326 | if (u > v) swap(u, v); 327 | change_count(make_pair(u, v), -counts[wi], wi); 328 | 329 | u = *it; 330 | v = max_p.second; 331 | if (u > v) swap(u, v); 332 | change_count(make_pair(u, v), -counts[wi], wi); 333 | 334 | u = *it; 335 | v = new_token_id; 336 | if (u > v) swap(u, v); 337 | change_count(make_pair(u, v), counts[wi], wi); 338 | ++it; 339 | } else 340 | it = cur_word.erase(it); 341 | 342 | } 343 | cur_word.insert(cur_word.end(), new_token_id); 344 | } 345 | 346 | 347 | if (pair_counts.find(max_p) != pair_counts.end()){ 348 | pair_counts[max_p]->first = 0; 349 | } 350 | 351 | } 352 | for (int i = 0; i < int_to_token.size(); ++i) { 353 | fprintf(stderr, "%d %s\n", i, int_to_token[i].c_str()); 354 | } 355 | } 356 | 357 | void split(vector &splits, const string &text, char sep) { 358 | size_t start = 0, end = 0; 359 | while ((end = text.find(sep, start)) != string::npos) { 360 | if (end != start) 361 | splits.push_back(text.substr(start, end - start)); 362 | start = end + 1; 363 | } 364 | if (end != start && start < text.size()) 365 | splits.push_back(text.substr(start)); 366 | } 367 | 368 | void readVocab(const char *fp, unordered_map &vocab) { 369 | ifstream file(fp); 370 | if (!file) { 371 | fprintf(stderr, "Cannot open vocabulary file %s\n", fp); 372 | exit(EXIT_FAILURE); 373 | } 374 | fprintf(stderr, "Loading vocabulary from %s ...\n", fp); 375 | string line; 376 | uint64_t total = 0; 377 | while (getline(file, line)) { 378 | vector splits; 379 | split(splits, line, ' '); 380 | assert(splits.size() == 2); 381 | assert(vocab.find(splits[0]) == vocab.end()); 
382 | int count = stoi(splits[1]); 383 | vocab[splits[0]] = count; 384 | total += count; 385 | } 386 | fprintf(stderr, "Read %lu words (%lu unique) from vocabulary file.\n", total, 387 | vocab.size()); 388 | } 389 | 390 | 391 | }; -------------------------------------------------------------------------------- /src/fairseq/gen_utils.py: -------------------------------------------------------------------------------- 1 | import os, time 2 | 3 | from more_itertools.more import last 4 | 5 | import json 6 | import numpy as np 7 | import torch 8 | import copy 9 | from tqdm import tqdm 10 | from miditoolkit.midi.containers import Note as mtkNote 11 | from miditoolkit.midi.parser import MidiFile 12 | from miditoolkit.midi.containers import Instrument 13 | import sys, os 14 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 15 | from encoding import ison, char2int, str2pit, ispitch 16 | from preprocess.preprocess_midi import midi_to_event_seq_str 17 | from preprocess.get_bpe_data import apply_bpe_for_sentence, load_before_apply_bpe 18 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 19 | from make_data import process_single_piece 20 | 21 | 22 | 23 | PAD = 1 24 | EOS = 2 25 | BOS = 0 26 | 27 | RATIO = 4 28 | MAX_POS_LEN = 4096 29 | PI_LEVEL = 2 30 | IGNORE_META_LOSS = 1 31 | NOTON_PAD = BOS if IGNORE_META_LOSS == 0 else PAD 32 | NOTON_PAD_DUR = NOTON_PAD 33 | NOTON_PAD_TRK = NOTON_PAD 34 | 35 | class Dictionary(object): 36 | def __init__(self): 37 | self.vocabs = {} 38 | self.voc2int = {} 39 | self.str2int = {} 40 | self.merges = None 41 | self.merged_vocs = None 42 | 43 | def load_vocabs_bpe(self, DATA_VOC_DIR, BPE_DIR=None): 44 | for i in range(RATIO): 45 | with open(f'{DATA_VOC_DIR}vocab_{i}.json', 'r') as f: 46 | self.vocabs[i] = json.load(f) 47 | self.voc2int[i] = {v:int(k)for k, v in self.vocabs[i].items()} 48 | 49 | 50 | with open(f'{DATA_VOC_DIR}ori_dict.json', 'r') as f: 51 | self.str2int = json.load(f) 52 | 53 | 
self.str2int.update({x:(PAD if IGNORE_META_LOSS == 1 else BOS) for x in ('RZ', 'TZ', 'YZ')}) 54 | 55 | # for BPE 56 | if BPE_DIR is not None: 57 | self.merges, self.merged_vocs = load_before_apply_bpe(BPE_DIR) 58 | 59 | def index2word(self, typ, i): 60 | return self.vocabs[typ][str(i)] 61 | def word2index(self, typ, i): 62 | return self.voc2int[typ][i] 63 | 64 | def is_bom(self, idx): 65 | return self.index2word(0, idx)[0].lower() == 'm' 66 | 67 | music_dict = Dictionary() 68 | 69 | prime_chords = None 70 | prime_mea_idx = 0 71 | 72 | def process_prime_midi(prime_midi_path, max_measures, max_chord_measures, perm_inv = PI_LEVEL, ratio=RATIO, sample_len_max=MAX_POS_LEN): 73 | 74 | toks = midi_to_event_seq_str(prime_midi_path, readonly=True) 75 | if music_dict.merges is not None: 76 | toks = apply_bpe_for_sentence(toks, music_dict.merges, music_dict.merged_vocs, {}) 77 | 78 | 79 | measures, _, _, _ = process_single_piece((toks, music_dict.str2int), ratio, sample_len_max) 80 | 81 | prime_nums = [[EOS]*ratio + [0, 0]] 82 | prime_nums[0][3] = 1 # set instrument to vanilla pos 83 | ins_label = [EOS] 84 | 85 | 86 | trk_map = np.concatenate([np.arange(4), np.random.permutation(40) + 4]) 87 | 88 | global prime_chords 89 | prime_chords = [music_dict.index2word(0, x[ratio+1]) for x in measures[:max_chord_measures]] 90 | for mea_id, mea in enumerate(measures): 91 | if mea_id >= max_measures: 92 | break 93 | assert len(mea) % (ratio+1) == 0, ('Error: Invalid input prime.', mea) 94 | 95 | if perm_inv % 2 == 1: 96 | auged_measure = [] 97 | cc_list = [] 98 | cur_cc = [] 99 | for id in range(0, len(mea), ratio+1): 100 | cur_tok = mea[id:id+ratio+1] 101 | if id <= ratio + 1: 102 | auged_measure += cur_tok 103 | continue 104 | 105 | if cur_tok[0] == music_dict.str2int['NT'] and len(cur_cc) > 0: 106 | cc_list.append(cur_cc) 107 | cur_cc = [] 108 | cur_cc += cur_tok 109 | if len(cur_cc) > 0: 110 | cc_list.append(cur_cc) 111 | cur_cc = [] 112 | 113 | if len(cc_list) > 1: 114 | new_order 
= np.random.permutation(len(cc_list)) 115 | for i in new_order: 116 | auged_measure += cc_list[i] 117 | else: 118 | for cc in cc_list: 119 | auged_measure += cc 120 | 121 | assert len(auged_measure) == len(mea), ('Error: exception during permutaiton.', len(auged_measure), len(mea)) 122 | mea = auged_measure 123 | 124 | for id in range(0, len(mea), ratio+1): 125 | mea_pos = (mea_id+1) * 3 126 | if id == 0: 127 | mea_pos -= 2 128 | elif id == ratio+1: 129 | mea_pos -= 1 130 | cur_tok = mea[id:id+ratio+1] + [mea_pos] 131 | ins_label.append(cur_tok[3]) 132 | cur_tok[3] = len(prime_nums) + 1 133 | if perm_inv > 0: 134 | cur_tok[2] = trk_map[cur_tok[2]] 135 | prime_nums.append(cur_tok) 136 | 137 | return prime_nums, ins_label 138 | 139 | def get_next_chord(ori): 140 | global prime_chords 141 | assert prime_chords is not None, 'Error: empty prime chords.' 142 | global prime_mea_idx 143 | if prime_mea_idx < len(prime_chords): 144 | ret = prime_chords[prime_mea_idx] 145 | prime_mea_idx += 1 146 | return ret 147 | else: 148 | return ori 149 | 150 | def get_next(model, p, memory, has_prime = False): 151 | pr = torch.from_numpy(np.array(p))[None, None, :].cuda() 152 | 153 | (e,d,t,ins), memory = model(src_tokens=pr, src_lengths=memory) 154 | e, d, t, ins = e[0,:], d[0,:], t[0,:], ins[0,:] 155 | if has_prime: 156 | return (np.int64(EOS), np.int64(EOS), np.int64(EOS), ins), memory 157 | evt = sampling(e) 158 | while evt == EOS: 159 | return (evt, np.int64(EOS), np.int64(EOS), ins), memory 160 | # evt = sampling(e) 161 | evt_word = music_dict.index2word(0, evt) 162 | if evt_word.startswith('H'): 163 | rep = get_next_chord(evt_word) 164 | return (np.int64(music_dict.word2index(0, rep)), np.int64(NOTON_PAD_DUR), np.int64(NOTON_PAD_TRK), ins), memory 165 | if not ison(evt_word): 166 | return (evt, np.int64(NOTON_PAD_DUR), np.int64(NOTON_PAD_TRK), ins), memory 167 | 168 | dur = sampling(d) 169 | while dur == EOS: 170 | dur = sampling(d) 171 | while dur == NOTON_PAD_DUR: 172 | dur = 
sampling(d) 173 | 174 | trk = sampling(t, p=0) 175 | 176 | 177 | return (evt, dur, trk, ins), memory 178 | 179 | 180 | 181 | 182 | def calc_pos(evt_tok, last_rel_pos, last_mea_pos): 183 | assert evt_tok != EOS, 'Invalid generation: no eos pos' 184 | typ = music_dict.index2word(0, evt_tok)[0].lower() 185 | if typ == 'm': 186 | if (last_mea_pos+1) % 3 == 0: #empty measure 187 | last_mea_pos += 1 188 | assert (last_mea_pos+1) % 3 == 1, f'Invalid generation: error measure pos {last_mea_pos+1}' #TODO: empty measure 189 | return 0, last_mea_pos + 1 190 | elif typ == 'h': 191 | assert (last_mea_pos+1) % 3 == 2, f'Invalid generation: there must be a before a chord {last_mea_pos+1}' 192 | return 0, last_mea_pos + 1 193 | elif typ == 'n': 194 | if last_mea_pos % 3 == 2: 195 | last_mea_pos += 1 196 | assert last_mea_pos % 3 == 0, f'Invalid generation: mea pos of must be a multiple of 3 {last_mea_pos}' 197 | return 1, last_mea_pos 198 | elif typ == 'p': 199 | assert last_mea_pos % 3 == 0, f'Invalid generation: mea pos of must be a multiple of 3 {last_mea_pos}' 200 | assert (last_rel_pos+1) % 2 == 0, f'Invalid generation: rel pos of must be even {last_rel_pos+1}' 201 | return last_rel_pos+1, last_mea_pos 202 | 203 | assert last_mea_pos % 3 == 0, f'Invalid generation: mea pos of must be a multiple of 3 {last_mea_pos}' 204 | if last_rel_pos % 2 == 0: # last token is a 205 | last_rel_pos += 1 206 | 207 | return last_rel_pos, last_mea_pos # on 208 | 209 | def gen_one(model, prime_nums, MAX_LEN = 4090, MIN_LEN = 0): 210 | 211 | 212 | global prime_mea_idx 213 | prime_mea_idx = 0 214 | prime = copy.deepcopy(prime_nums) 215 | ins_list = [-1] 216 | 217 | with torch.no_grad(): 218 | memo = None 219 | cur_rel_pos = 0 220 | cur_mea = 0 221 | for item, next_item in zip(prime[:-1], prime[1:]): 222 | 223 | (e,d,t,ins), memo = get_next(model, item, memo, has_prime=True) 224 | cur_rel_pos, cur_mea = calc_pos(next_item[0], cur_rel_pos, cur_mea) 225 | ins_list.append(ins) 226 | 227 | 228 | 
(e,d,t,ins), memo = get_next(model, prime[-1], memo, has_prime=False) 229 | cur_rel_pos, cur_mea = calc_pos(e, cur_rel_pos, cur_mea) 230 | 231 | prime.append((e,d,t,len(prime)+1, cur_rel_pos, cur_mea)) 232 | ins_list.append(ins) 233 | 234 | for i in tqdm(range(MAX_LEN - len(prime))): 235 | (e,d,t,ins), memo = get_next(model, prime[-1], memo) 236 | if t == EOS: 237 | assert len(prime) > MIN_LEN, 'Invalid generation: Generated excerpt too short.' 238 | break 239 | cur_rel_pos, cur_mea = calc_pos(e, cur_rel_pos, cur_mea) 240 | 241 | prime.append((e,d,t,len(prime)+1, cur_rel_pos, cur_mea)) 242 | ins_list.append(ins) 243 | 244 | return prime, ins_list 245 | 246 | def get_trk_ins_map(prime, ins_list): 247 | track_map = {} 248 | 249 | idx = 0 250 | for (e,d,t,_, _, _),ins in zip(prime, ins_list): 251 | ee = music_dict.index2word(0, e) 252 | idx += 1 253 | 254 | if ison(ee): 255 | track_map.setdefault(t, []).append(ins) 256 | trk_ins_map = {} 257 | for k in track_map: 258 | v = torch.stack(track_map[k]) 259 | logits = torch.mean(v, axis=0) 260 | ins_word = sampling(logits,p=0.9) 261 | trk_ins_map[k] = ins_word 262 | return trk_ins_map 263 | 264 | def get_note_seq(prime, trk_ins_map): 265 | note_seq = [] 266 | measure_time = 0 267 | last_bom = 0 268 | error_note = 0 269 | for (e,d,t,_, _, _) in prime[1:]: 270 | 271 | ee = music_dict.index2word(0, e) 272 | if ee[0].lower() == 'm': 273 | 274 | measure_time += last_bom 275 | last_bom = char2int(ee[1])+(62 if ee[0] == 'M' else 0) 276 | last_pos = -1 277 | elif ee[0].lower() == 'p': 278 | last_pos = char2int(ee[1]) + (62 if ee[0] == 'P' else 0) 279 | elif ee == 'NT': 280 | last_pos = -1 281 | elif ee[0].lower() == 'h': 282 | pass 283 | elif ison(ee): 284 | if t != NOTON_PAD_TRK and d != NOTON_PAD_DUR: 285 | dd = music_dict.index2word(1, d) 286 | tt = music_dict.index2word(2, t) 287 | assert last_pos != -1, 'Invalid generation: there must be a between and ' 288 | start = measure_time + last_pos 289 | trk = char2int(tt[1])+(62 if 
tt[0] == 'T' else 0) 290 | dur = char2int(dd[1])+(62 if dd[0] == 'R' else 0) 291 | 292 | for i in range(0, len(ee), 2): 293 | eee = ee[i:i+2] 294 | note_seq.append((str2pit(eee), trk_ins_map[t]-4, start, start + dur, trk)) 295 | else: 296 | error_note += 1 297 | else: 298 | assert False, ('Invalid generation: unknown token: ', (ee, d, t)) 299 | # print(f'error note cnt: {error_note}') 300 | return note_seq 301 | 302 | def note_seq_to_midi_file(note_seq, filename, ticks_per_beat=480): 303 | 304 | tickes_per_32th = ticks_per_beat // 8 305 | tracks = {} 306 | for pitch, program, start, end, track_id in note_seq: 307 | 308 | tracks.setdefault((track_id, program), []).append(mtkNote(90, pitch, start * tickes_per_32th, end * tickes_per_32th)) 309 | 310 | midi_out = MidiFile(ticks_per_beat=ticks_per_beat) 311 | 312 | for tp, notes in tracks.items(): 313 | program = tp[1] 314 | instrument = Instrument(program % 128, is_drum=program >= 128) 315 | instrument.notes = notes 316 | instrument.remove_invalid_notes(verbose=False) 317 | midi_out.instruments.append(instrument) 318 | midi_out.dump(filename) 319 | 320 | 321 | def softmax_with_temperature(logits, temperature): 322 | probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature)) 323 | return probs 324 | 325 | def weighted_sampling(probs): 326 | probs /= sum(probs) 327 | sorted_probs = np.sort(probs)[::-1] 328 | sorted_index = np.argsort(probs)[::-1] 329 | word = np.random.choice(sorted_index, size=1, p=sorted_probs)[0] 330 | return word 331 | 332 | def nucleus(probs, p): 333 | probs /= (sum(probs) + 1e-5) 334 | sorted_probs = np.sort(probs)[::-1] 335 | sorted_index = np.argsort(probs)[::-1] 336 | cusum_sorted_probs = np.cumsum(sorted_probs) 337 | after_threshold = cusum_sorted_probs > p 338 | if sum(after_threshold) > 0: 339 | last_index = np.where(after_threshold)[0][0] + 1 340 | candi_index = sorted_index[:last_index] 341 | else: 342 | candi_index = sorted_index[:] 343 | candi_probs = [probs[i] for i in 
candi_index] 344 | candi_probs /= sum(candi_probs) 345 | word = np.random.choice(candi_index, size=1, p=candi_probs)[0] 346 | return word 347 | 348 | 349 | def sampling(logit, p=None, t=1.0): 350 | logit = logit.squeeze().cpu().numpy() 351 | probs = softmax_with_temperature(logits=logit, temperature=t) 352 | 353 | if p is not None: 354 | cur_word = nucleus(probs, p=p) 355 | else: 356 | cur_word = weighted_sampling(probs) 357 | return cur_word -------------------------------------------------------------------------------- /data/model_spec/linear_4096_chord_bpe_hardloss1/bin/dict.txt: -------------------------------------------------------------------------------- 1 | 4 0 2 | 5 0 3 | 6 0 4 | 7 0 5 | 8 0 6 | 9 0 7 | 10 0 8 | 11 0 9 | 12 0 10 | 13 0 11 | 14 0 12 | 15 0 13 | 16 0 14 | 17 0 15 | 18 0 16 | 19 0 17 | 20 0 18 | 21 0 19 | 22 0 20 | 23 0 21 | 24 0 22 | 25 0 23 | 26 0 24 | 27 0 25 | 28 0 26 | 29 0 27 | 30 0 28 | 31 0 29 | 32 0 30 | 33 0 31 | 34 0 32 | 35 0 33 | 36 0 34 | 37 0 35 | 38 0 36 | 39 0 37 | 40 0 38 | 41 0 39 | 42 0 40 | 43 0 41 | 44 0 42 | 45 0 43 | 46 0 44 | 47 0 45 | 48 0 46 | 49 0 47 | 50 0 48 | 51 0 49 | 52 0 50 | 53 0 51 | 54 0 52 | 55 0 53 | 56 0 54 | 57 0 55 | 58 0 56 | 59 0 57 | 60 0 58 | 61 0 59 | 62 0 60 | 63 0 61 | 64 0 62 | 65 0 63 | 66 0 64 | 67 0 65 | 68 0 66 | 69 0 67 | 70 0 68 | 71 0 69 | 72 0 70 | 73 0 71 | 74 0 72 | 75 0 73 | 76 0 74 | 77 0 75 | 78 0 76 | 79 0 77 | 80 0 78 | 81 0 79 | 82 0 80 | 83 0 81 | 84 0 82 | 85 0 83 | 86 0 84 | 87 0 85 | 88 0 86 | 89 0 87 | 90 0 88 | 91 0 89 | 92 0 90 | 93 0 91 | 94 0 92 | 95 0 93 | 96 0 94 | 97 0 95 | 98 0 96 | 99 0 97 | 100 0 98 | 101 0 99 | 102 0 100 | 103 0 101 | 104 0 102 | 105 0 103 | 106 0 104 | 107 0 105 | 108 0 106 | 109 0 107 | 110 0 108 | 111 0 109 | 112 0 110 | 113 0 111 | 114 0 112 | 115 0 113 | 116 0 114 | 117 0 115 | 118 0 116 | 119 0 117 | 120 0 118 | 121 0 119 | 122 0 120 | 123 0 121 | 124 0 122 | 125 0 123 | 126 0 124 | 127 0 125 | 128 0 126 | 129 0 127 | 130 0 128 | 131 0 
129 | 132 0 130 | 133 0 131 | 134 0 132 | 135 0 133 | 136 0 134 | 137 0 135 | 138 0 136 | 139 0 137 | 140 0 138 | 141 0 139 | 142 0 140 | 143 0 141 | 144 0 142 | 145 0 143 | 146 0 144 | 147 0 145 | 148 0 146 | 149 0 147 | 150 0 148 | 151 0 149 | 152 0 150 | 153 0 151 | 154 0 152 | 155 0 153 | 156 0 154 | 157 0 155 | 158 0 156 | 159 0 157 | 160 0 158 | 161 0 159 | 162 0 160 | 163 0 161 | 164 0 162 | 165 0 163 | 166 0 164 | 167 0 165 | 168 0 166 | 169 0 167 | 170 0 168 | 171 0 169 | 172 0 170 | 173 0 171 | 174 0 172 | 175 0 173 | 176 0 174 | 177 0 175 | 178 0 176 | 179 0 177 | 180 0 178 | 181 0 179 | 182 0 180 | 183 0 181 | 184 0 182 | 185 0 183 | 186 0 184 | 187 0 185 | 188 0 186 | 189 0 187 | 190 0 188 | 191 0 189 | 192 0 190 | 193 0 191 | 194 0 192 | 195 0 193 | 196 0 194 | 197 0 195 | 198 0 196 | 199 0 197 | 200 0 198 | 201 0 199 | 202 0 200 | 203 0 201 | 204 0 202 | 205 0 203 | 206 0 204 | 207 0 205 | 208 0 206 | 209 0 207 | 210 0 208 | 211 0 209 | 212 0 210 | 213 0 211 | 214 0 212 | 215 0 213 | 216 0 214 | 217 0 215 | 218 0 216 | 219 0 217 | 220 0 218 | 221 0 219 | 222 0 220 | 223 0 221 | 224 0 222 | 225 0 223 | 226 0 224 | 227 0 225 | 228 0 226 | 229 0 227 | 230 0 228 | 231 0 229 | 232 0 230 | 233 0 231 | 234 0 232 | 235 0 233 | 236 0 234 | 237 0 235 | 238 0 236 | 239 0 237 | 240 0 238 | 241 0 239 | 242 0 240 | 243 0 241 | 244 0 242 | 245 0 243 | 246 0 244 | 247 0 245 | 248 0 246 | 249 0 247 | 250 0 248 | 251 0 249 | 252 0 250 | 253 0 251 | 254 0 252 | 255 0 253 | 256 0 254 | 257 0 255 | 258 0 256 | 259 0 257 | 260 0 258 | 261 0 259 | 262 0 260 | 263 0 261 | 264 0 262 | 265 0 263 | 266 0 264 | 267 0 265 | 268 0 266 | 269 0 267 | 270 0 268 | 271 0 269 | 272 0 270 | 273 0 271 | 274 0 272 | 275 0 273 | 276 0 274 | 277 0 275 | 278 0 276 | 279 0 277 | 280 0 278 | 281 0 279 | 282 0 280 | 283 0 281 | 284 0 282 | 285 0 283 | 286 0 284 | 287 0 285 | 288 0 286 | 289 0 287 | 290 0 288 | 291 0 289 | 292 0 290 | 293 0 291 | 294 0 292 | 295 0 293 | 296 0 294 | 297 0 295 | 
298 0 296 | 299 0 297 | 300 0 298 | 301 0 299 | 302 0 300 | 303 0 301 | 304 0 302 | 305 0 303 | 306 0 304 | 307 0 305 | 308 0 306 | 309 0 307 | 310 0 308 | 311 0 309 | 312 0 310 | 313 0 311 | 314 0 312 | 315 0 313 | 316 0 314 | 317 0 315 | 318 0 316 | 319 0 317 | 320 0 318 | 321 0 319 | 322 0 320 | 323 0 321 | 324 0 322 | 325 0 323 | 326 0 324 | 327 0 325 | 328 0 326 | 329 0 327 | 330 0 328 | 331 0 329 | 332 0 330 | 333 0 331 | 334 0 332 | 335 0 333 | 336 0 334 | 337 0 335 | 338 0 336 | 339 0 337 | 340 0 338 | 341 0 339 | 342 0 340 | 343 0 341 | 344 0 342 | 345 0 343 | 346 0 344 | 347 0 345 | 348 0 346 | 349 0 347 | 350 0 348 | 351 0 349 | 352 0 350 | 353 0 351 | 354 0 352 | 355 0 353 | 356 0 354 | 357 0 355 | 358 0 356 | 359 0 357 | 360 0 358 | 361 0 359 | 362 0 360 | 363 0 361 | 364 0 362 | 365 0 363 | 366 0 364 | 367 0 365 | 368 0 366 | 369 0 367 | 370 0 368 | 371 0 369 | 372 0 370 | 373 0 371 | 374 0 372 | 375 0 373 | 376 0 374 | 377 0 375 | 378 0 376 | 379 0 377 | 380 0 378 | 381 0 379 | 382 0 380 | 383 0 381 | 384 0 382 | 385 0 383 | 386 0 384 | 387 0 385 | 388 0 386 | 389 0 387 | 390 0 388 | 391 0 389 | 392 0 390 | 393 0 391 | 394 0 392 | 395 0 393 | 396 0 394 | 397 0 395 | 398 0 396 | 399 0 397 | 400 0 398 | 401 0 399 | 402 0 400 | 403 0 401 | 404 0 402 | 405 0 403 | 406 0 404 | 407 0 405 | 408 0 406 | 409 0 407 | 410 0 408 | 411 0 409 | 412 0 410 | 413 0 411 | 414 0 412 | 415 0 413 | 416 0 414 | 417 0 415 | 418 0 416 | 419 0 417 | 420 0 418 | 421 0 419 | 422 0 420 | 423 0 421 | 424 0 422 | 425 0 423 | 426 0 424 | 427 0 425 | 428 0 426 | 429 0 427 | 430 0 428 | 431 0 429 | 432 0 430 | 433 0 431 | 434 0 432 | 435 0 433 | 436 0 434 | 437 0 435 | 438 0 436 | 439 0 437 | 440 0 438 | 441 0 439 | 442 0 440 | 443 0 441 | 444 0 442 | 445 0 443 | 446 0 444 | 447 0 445 | 448 0 446 | 449 0 447 | 450 0 448 | 451 0 449 | 452 0 450 | 453 0 451 | 454 0 452 | 455 0 453 | 456 0 454 | 457 0 455 | 458 0 456 | 459 0 457 | 460 0 458 | 461 0 459 | 462 0 460 | 463 0 461 | 464 0 
462 | 465 0 463 | 466 0 464 | 467 0 465 | 468 0 466 | 469 0 467 | 470 0 468 | 471 0 469 | 472 0 470 | 473 0 471 | 474 0 472 | 475 0 473 | 476 0 474 | 477 0 475 | 478 0 476 | 479 0 477 | 480 0 478 | 481 0 479 | 482 0 480 | 483 0 481 | 484 0 482 | 485 0 483 | 486 0 484 | 487 0 485 | 488 0 486 | 489 0 487 | 490 0 488 | 491 0 489 | 492 0 490 | 493 0 491 | 494 0 492 | 495 0 493 | 496 0 494 | 497 0 495 | 498 0 496 | 499 0 497 | 500 0 498 | 501 0 499 | 502 0 500 | 503 0 501 | 504 0 502 | 505 0 503 | 506 0 504 | 507 0 505 | 508 0 506 | 509 0 507 | 510 0 508 | 511 0 509 | 512 0 510 | 513 0 511 | 514 0 512 | 515 0 513 | 516 0 514 | 517 0 515 | 518 0 516 | 519 0 517 | 520 0 518 | 521 0 519 | 522 0 520 | 523 0 521 | 524 0 522 | 525 0 523 | 526 0 524 | 527 0 525 | 528 0 526 | 529 0 527 | 530 0 528 | 531 0 529 | 532 0 530 | 533 0 531 | 534 0 532 | 535 0 533 | 536 0 534 | 537 0 535 | 538 0 536 | 539 0 537 | 540 0 538 | 541 0 539 | 542 0 540 | 543 0 541 | 544 0 542 | 545 0 543 | 546 0 544 | 547 0 545 | 548 0 546 | 549 0 547 | 550 0 548 | 551 0 549 | 552 0 550 | 553 0 551 | 554 0 552 | 555 0 553 | 556 0 554 | 557 0 555 | 558 0 556 | 559 0 557 | 560 0 558 | 561 0 559 | 562 0 560 | 563 0 561 | 564 0 562 | 565 0 563 | 566 0 564 | 567 0 565 | 568 0 566 | 569 0 567 | 570 0 568 | 571 0 569 | 572 0 570 | 573 0 571 | 574 0 572 | 575 0 573 | 576 0 574 | 577 0 575 | 578 0 576 | 579 0 577 | 580 0 578 | 581 0 579 | 582 0 580 | 583 0 581 | 584 0 582 | 585 0 583 | 586 0 584 | 587 0 585 | 588 0 586 | 589 0 587 | 590 0 588 | 591 0 589 | 592 0 590 | 593 0 591 | 594 0 592 | 595 0 593 | 596 0 594 | 597 0 595 | 598 0 596 | 599 0 597 | 600 0 598 | 601 0 599 | 602 0 600 | 603 0 601 | 604 0 602 | 605 0 603 | 606 0 604 | 607 0 605 | 608 0 606 | 609 0 607 | 610 0 608 | 611 0 609 | 612 0 610 | 613 0 611 | 614 0 612 | 615 0 613 | 616 0 614 | 617 0 615 | 618 0 616 | 619 0 617 | 620 0 618 | 621 0 619 | 622 0 620 | 623 0 621 | 624 0 622 | 625 0 623 | 626 0 624 | 627 0 625 | 628 0 626 | 629 0 627 | 630 0 628 | 
631 0 629 | 632 0 630 | 633 0 631 | 634 0 632 | 635 0 633 | 636 0 634 | 637 0 635 | 638 0 636 | 639 0 637 | 640 0 638 | 641 0 639 | 642 0 640 | 643 0 641 | 644 0 642 | 645 0 643 | 646 0 644 | 647 0 645 | 648 0 646 | 649 0 647 | 650 0 648 | 651 0 649 | 652 0 650 | 653 0 651 | 654 0 652 | 655 0 653 | 656 0 654 | 657 0 655 | 658 0 656 | 659 0 657 | 660 0 658 | 661 0 659 | 662 0 660 | 663 0 661 | 664 0 662 | 665 0 663 | 666 0 664 | 667 0 665 | 668 0 666 | 669 0 667 | 670 0 668 | 671 0 669 | 672 0 670 | 673 0 671 | 674 0 672 | 675 0 673 | 676 0 674 | 677 0 675 | 678 0 676 | 679 0 677 | 680 0 678 | 681 0 679 | 682 0 680 | 683 0 681 | 684 0 682 | 685 0 683 | 686 0 684 | 687 0 685 | 688 0 686 | 689 0 687 | 690 0 688 | 691 0 689 | 692 0 690 | 693 0 691 | 694 0 692 | 695 0 693 | 696 0 694 | 697 0 695 | 698 0 696 | 699 0 697 | 700 0 698 | 701 0 699 | 702 0 700 | 703 0 701 | 704 0 702 | 705 0 703 | 706 0 704 | 707 0 705 | 708 0 706 | 709 0 707 | 710 0 708 | 711 0 709 | 712 0 710 | 713 0 711 | 714 0 712 | 715 0 713 | 716 0 714 | 717 0 715 | 718 0 716 | 719 0 717 | 720 0 718 | 721 0 719 | 722 0 720 | 723 0 721 | 724 0 722 | 725 0 723 | 726 0 724 | 727 0 725 | 728 0 726 | 729 0 727 | 730 0 728 | 731 0 729 | 732 0 730 | 733 0 731 | 734 0 732 | 735 0 733 | 736 0 734 | 737 0 735 | 738 0 736 | 739 0 737 | 740 0 738 | 741 0 739 | 742 0 740 | 743 0 741 | 744 0 742 | 745 0 743 | 746 0 744 | 747 0 745 | 748 0 746 | 749 0 747 | 750 0 748 | 751 0 749 | 752 0 750 | 753 0 751 | 754 0 752 | 755 0 753 | 756 0 754 | 757 0 755 | 758 0 756 | 759 0 757 | 760 0 758 | 761 0 759 | 762 0 760 | 763 0 761 | 764 0 762 | 765 0 763 | 766 0 764 | 767 0 765 | 768 0 766 | 769 0 767 | 770 0 768 | 771 0 769 | 772 0 770 | 773 0 771 | 774 0 772 | 775 0 773 | 776 0 774 | 777 0 775 | 778 0 776 | 779 0 777 | 780 0 778 | 781 0 779 | 782 0 780 | 783 0 781 | 784 0 782 | 785 0 783 | 786 0 784 | 787 0 785 | 788 0 786 | 789 0 787 | 790 0 788 | 791 0 789 | 792 0 790 | 793 0 791 | 794 0 792 | 795 0 793 | 796 0 794 | 797 0 
795 | 798 0 796 | 799 0 797 | 800 0 798 | 801 0 799 | 802 0 800 | 803 0 801 | 804 0 802 | 805 0 803 | 806 0 804 | 807 0 805 | 808 0 806 | 809 0 807 | 810 0 808 | 811 0 809 | 812 0 810 | 813 0 811 | 814 0 812 | 815 0 813 | 816 0 814 | 817 0 815 | 818 0 816 | 819 0 817 | 820 0 818 | 821 0 819 | 822 0 820 | 823 0 821 | 824 0 822 | 825 0 823 | 826 0 824 | 827 0 825 | 828 0 826 | 829 0 827 | 830 0 828 | 831 0 829 | 832 0 830 | 833 0 831 | 834 0 832 | 835 0 833 | 836 0 834 | 837 0 835 | 838 0 836 | 839 0 837 | 840 0 838 | 841 0 839 | 842 0 840 | 843 0 841 | 844 0 842 | 845 0 843 | 846 0 844 | 847 0 845 | 848 0 846 | 849 0 847 | 850 0 848 | 851 0 849 | 852 0 850 | 853 0 851 | 854 0 852 | 855 0 853 | 856 0 854 | 857 0 855 | 858 0 856 | 859 0 857 | 860 0 858 | 861 0 859 | 862 0 860 | 863 0 861 | 864 0 862 | 865 0 863 | 866 0 864 | 867 0 865 | 868 0 866 | 869 0 867 | 870 0 868 | 871 0 869 | 872 0 870 | 873 0 871 | 874 0 872 | 875 0 873 | 876 0 874 | 877 0 875 | 878 0 876 | 879 0 877 | 880 0 878 | 881 0 879 | 882 0 880 | 883 0 881 | 884 0 882 | 885 0 883 | 886 0 884 | 887 0 885 | 888 0 886 | 889 0 887 | 890 0 888 | 891 0 889 | 892 0 890 | 893 0 891 | 894 0 892 | 895 0 893 | 896 0 894 | 897 0 895 | 898 0 896 | 899 0 897 | 900 0 898 | 901 0 899 | 902 0 900 | 903 0 901 | 904 0 902 | 905 0 903 | 906 0 904 | 907 0 905 | 908 0 906 | 909 0 907 | 910 0 908 | 911 0 909 | 912 0 910 | 913 0 911 | 914 0 912 | 915 0 913 | 916 0 914 | 917 0 915 | 918 0 916 | 919 0 917 | 920 0 918 | 921 0 919 | 922 0 920 | 923 0 921 | 924 0 922 | 925 0 923 | 926 0 924 | 927 0 925 | 928 0 926 | 929 0 927 | 930 0 928 | 931 0 929 | 932 0 930 | 933 0 931 | 934 0 932 | 935 0 933 | 936 0 934 | 937 0 935 | 938 0 936 | 939 0 937 | 940 0 938 | 941 0 939 | 942 0 940 | 943 0 941 | 944 0 942 | 945 0 943 | 946 0 944 | 947 0 945 | 948 0 946 | 949 0 947 | 950 0 948 | 951 0 949 | 952 0 950 | 953 0 951 | 954 0 952 | 955 0 953 | 956 0 954 | 957 0 955 | 958 0 956 | 959 0 957 | 960 0 958 | 961 0 959 | 962 0 960 | 963 0 961 | 
964 0 962 | 965 0 963 | 966 0 964 | 967 0 965 | 968 0 966 | 969 0 967 | 970 0 968 | 971 0 969 | 972 0 970 | 973 0 971 | 974 0 972 | 975 0 973 | 976 0 974 | 977 0 975 | 978 0 976 | 979 0 977 | 980 0 978 | 981 0 979 | 982 0 980 | 983 0 981 | 984 0 982 | 985 0 983 | 986 0 984 | 987 0 985 | 988 0 986 | 989 0 987 | 990 0 988 | 991 0 989 | 992 0 990 | 993 0 991 | 994 0 992 | 995 0 993 | 996 0 994 | 997 0 995 | 998 0 996 | 999 0 997 | 1000 0 998 | 1001 0 999 | 1002 0 1000 | 1003 0 1001 | 1004 0 1002 | 1005 0 1003 | 1006 0 1004 | 1007 0 1005 | 1008 0 1006 | 1009 0 1007 | 1010 0 1008 | 1011 0 1009 | 1012 0 1010 | 1013 0 1011 | 1014 0 1012 | 1015 0 1013 | 1016 0 1014 | 1017 0 1015 | 1018 0 1016 | 1019 0 1017 | 1020 0 1018 | 1021 0 1019 | 1022 0 1020 | 1023 0 1021 | 1024 0 1022 | 1025 0 1023 | 1026 0 1024 | 1027 0 1025 | 1028 0 1026 | 1029 0 1027 | 1030 0 1028 | 1031 0 1029 | 1032 0 1030 | 1033 0 1031 | 1034 0 1032 | 1035 0 1033 | 1036 0 1034 | 1037 0 1035 | 1038 0 1036 | 1039 0 1037 | 1040 0 1038 | 1041 0 1039 | 1042 0 1040 | 1043 0 1041 | 1044 0 1042 | 1045 0 1043 | 1046 0 1044 | 1047 0 1045 | 1048 0 1046 | 1049 0 1047 | 1050 0 1048 | 1051 0 1049 | 1052 0 1050 | 1053 0 1051 | 1054 0 1052 | 1055 0 1053 | 1056 0 1054 | 1057 0 1055 | 1058 0 1056 | 1059 0 1057 | 1060 0 1058 | 1061 0 1059 | 1062 0 1060 | 1063 0 1061 | 1064 0 1062 | 1065 0 1063 | 1066 0 1064 | 1067 0 1065 | 1068 0 1066 | 1069 0 1067 | 1070 0 1068 | 1071 0 1069 | 1072 0 1070 | 1073 0 1071 | 1074 0 1072 | 1075 0 1073 | 1076 0 1074 | 1077 0 1075 | 1078 0 1076 | 1079 0 1077 | 1080 0 1078 | 1081 0 1079 | 1082 0 1080 | 1083 0 1081 | 1084 0 1082 | 1085 0 1083 | 1086 0 1084 | 1087 0 1085 | 1088 0 1086 | 1089 0 1087 | 1090 0 1088 | 1091 0 1089 | 1092 0 1090 | 1093 0 1091 | 1094 0 1092 | 1095 0 1093 | 1096 0 1094 | 1097 0 1095 | 1098 0 1096 | 1099 0 1097 | 1100 0 1098 | 1101 0 1099 | 1102 0 1100 | 1103 0 1101 | 1104 0 1102 | 1105 0 1103 | 1106 0 1104 | 1107 0 1105 | 1108 0 1106 | 1109 0 1107 | 1110 0 1108 | 1111 0 1109 | 
1112 0 1110 | 1113 0 1111 | 1114 0 1112 | 1115 0 1113 | 1116 0 1114 | 1117 0 1115 | 1118 0 1116 | 1119 0 1117 | 1120 0 1118 | 1121 0 1119 | 1122 0 1120 | 1123 0 1121 | 1124 0 1122 | -------------------------------------------------------------------------------- /data/model_spec/linear_4096_chord_bpe_hardloss1/vocabs/ori_dict.json: -------------------------------------------------------------------------------- 1 | {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "CO": 4, "dO": 5, "DO": 6, "eO": 7, "EO": 8, "FO": 9, "gO": 10, "GO": 11, "aO": 12, "AO": 13, "bO": 14, "BO": 15, "C0": 16, "d0": 17, "D0": 18, "e0": 19, "E0": 20, "F0": 21, "g0": 22, "G0": 23, "a0": 24, "A0": 25, "b0": 26, "B0": 27, "C1": 28, "d1": 29, "D1": 30, "e1": 31, "E1": 32, "F1": 33, "g1": 34, "G1": 35, "a1": 36, "A1": 37, "b1": 38, "B1": 39, "C2": 40, "d2": 41, "D2": 42, "e2": 43, "E2": 44, "F2": 45, "g2": 46, "G2": 47, "a2": 48, "A2": 49, "b2": 50, "B2": 51, "C3": 52, "d3": 53, "D3": 54, "e3": 55, "E3": 56, "F3": 57, "g3": 58, "G3": 59, "a3": 60, "A3": 61, "b3": 62, "B3": 63, "C4": 64, "d4": 65, "D4": 66, "e4": 67, "E4": 68, "F4": 69, "g4": 70, "G4": 71, "a4": 72, "A4": 73, "b4": 74, "B4": 75, "C5": 76, "d5": 77, "D5": 78, "e5": 79, "E5": 80, "F5": 81, "g5": 82, "G5": 83, "a5": 84, "A5": 85, "b5": 86, "B5": 87, "C6": 88, "d6": 89, "D6": 90, "e6": 91, "E6": 92, "F6": 93, "g6": 94, "G6": 95, "a6": 96, "A6": 97, "b6": 98, "B6": 99, "C7": 100, "d7": 101, "D7": 102, "e7": 103, "E7": 104, "F7": 105, "g7": 106, "G7": 107, "a7": 108, "A7": 109, "b7": 110, "B7": 111, "C8": 112, "d8": 113, "D8": 114, "e8": 115, "E8": 116, "F8": 117, "g8": 118, "G8": 119, "a8": 120, "A8": 121, "b8": 122, "B8": 123, "C9": 124, "d9": 125, "D9": 126, "e9": 127, "E9": 128, "F9": 129, "g9": 130, "G9": 131, "dOC1": 132, "dOC1C3A4C7": 133, "dOC1A4": 134, "dOC1A4C7": 135, "b0b1": 136, "C1C2": 137, "C1D3": 138, "d1d2": 139, "D1D2": 140, "e1C2": 141, "e1D2": 142, "e1e2": 143, "E1E2": 144, "F1F2": 145, "g1g2": 146, "G1G2": 147, "a1a2": 148, "A1A2": 149, 
"b1b2": 150, "B1C2": 151, "B1d2": 152, "B1D2": 153, "B1D2g2": 154, "B1D2d3": 155, "B1D2g3": 156, "B1D2A3": 157, "B1e2": 158, "B1E2": 159, "B1F2": 160, "B1g2": 161, "B1G2": 162, "B1a2": 163, "B1A2": 164, "B1b2": 165, "B1B2": 166, "B1d3": 167, "B1D3": 168, "B1e3": 169, "B1E3": 170, "B1g3": 171, "B1A3": 172, "B1A3A5": 173, "B1B3": 174, "B1A5": 175, "C2d2": 176, "C2D2": 177, "C2D2F2": 178, "C2D2g2": 179, "C2D2d3": 180, "C2D2B3": 181, "C2e2": 182, "C2E2": 183, "C2F2": 184, "C2g2": 185, "C2G2": 186, "C2G2C3": 187, "C2a2": 188, "C2A2": 189, "C2b2": 190, "C2B2": 191, "C2C3": 192, "C2d3": 193, "C2D3": 194, "C2e3": 195, "C2g3": 196, "C2A3": 197, "C2B3": 198, "d2D2": 199, "d2E2": 200, "d2F2": 201, "d2g2": 202, "d2a2": 203, "d2d3": 204, "d2e3": 205, "d2g3": 206, "D2e2": 207, "D2E2": 208, "D2F2": 209, "D2g2": 210, "D2g2g3": 211, "D2G2": 212, "D2a2": 213, "D2A2": 214, "D2A2D3": 215, "D2b2": 216, "D2B2": 217, "D2C3": 218, "D2d3": 219, "D2D3": 220, "D2e3": 221, "D2E3": 222, "D2F3": 223, "D2g3": 224, "D2G3": 225, "D2A3": 226, "D2B3": 227, "D2D4": 228, "D2e4": 229, "D2E4": 230, "D2A4": 231, "D2b4": 232, "D2E5": 233, "D2A5": 234, "e2G2": 235, "e2a2": 236, "e2b2": 237, "e2b2e3": 238, "e2e3": 239, "E2F2": 240, "E2g2": 241, "E2G2": 242, "E2a2": 243, "E2A2": 244, "E2b2": 245, "E2B2": 246, "E2B2E3": 247, "E2B2E3G3B3E4": 248, "E2C3": 249, "E2d3": 250, "E2e3": 251, "E2E3": 252, "E2g3": 253, "F2g2": 254, "F2G2": 255, "F2a2": 256, "F2A2": 257, "F2A2C3": 258, "F2b2": 259, "F2B2": 260, "F2C3": 261, "F2C3F3": 262, "F2D3": 263, "F2e3": 264, "F2F3": 265, "g2G2": 266, "g2a2": 267, "g2A2": 268, "g2b2": 269, "g2B2": 270, "g2d3": 271, "g2d3g3": 272, "g2D3": 273, "g2e3": 274, "g2g3": 275, "g2C4": 276, "g2d4": 277, "g2E4": 278, "g2A4": 279, "g2b4": 280, "g2b5": 281, "G2A2": 282, "G2b2": 283, "G2B2": 284, "G2B2D3": 285, "G2C3": 286, "G2D3": 287, "G2D3G3": 288, "G2D3G3G4": 289, "G2e3": 290, "G2E3": 291, "G2G3": 292, "a2A2": 293, "a2B2": 294, "a2C3": 295, "a2d3": 296, "a2e3": 297, "a2e3a3": 298, "a2g3": 
299, "a2a3": 300, "A2B2": 301, "A2C3": 302, "A2C3E3": 303, "A2d3": 304, "A2D3": 305, "A2E3": 306, "A2E3A3": 307, "A2F3": 308, "A2g3": 309, "A2A3": 310, "b2d3": 311, "b2D3": 312, "b2D3F3": 313, "b2e3": 314, "b2F3": 315, "b2F3b3": 316, "b2g3": 317, "b2G3": 318, "b2b3": 319, "B2C3": 320, "B2D3": 321, "B2e3": 322, "B2E3": 323, "B2g3": 324, "B2g3B3": 325, "B2G3": 326, "B2B3": 327, "C3D3": 328, "C3e3": 329, "C3e3G3": 330, "C3E3": 331, "C3E3G3": 332, "C3F3": 333, "C3F3A3": 334, "C3G3": 335, "C3G3C4": 336, "C3a3": 337, "C3A3": 338, "C3C4": 339, "d3D3": 340, "d3E3": 341, "d3F3": 342, "d3F3a3": 343, "d3g3": 344, "d3a3": 345, "d3A3": 346, "d3b3": 347, "d3d4": 348, "D3e3": 349, "D3E3": 350, "D3F3": 351, "D3F3A3": 352, "D3F3b3": 353, "D3g3": 354, "D3g3A3": 355, "D3G3": 356, "D3G3B3": 357, "D3a3": 358, "D3A3": 359, "D3A3D4": 360, "D3A3D4g4": 361, "D3b3": 362, "D3B3": 363, "D3C4": 364, "D3D4": 365, "D3C5": 366, "D3C7": 367, "e3F3": 368, "e3g3": 369, "e3G3": 370, "e3G3b3": 371, "e3G3C4": 372, "e3a3": 373, "e3a3C4": 374, "e3A3": 375, "e3b3": 376, "e3b3e4": 377, "e3B3": 378, "e3C4": 379, "e3e4": 380, "E3F3": 381, "E3g3": 382, "E3G3": 383, "E3G3B3": 384, "E3G3B3E4": 385, "E3G3C4": 386, "E3a3": 387, "E3a3B3": 388, "E3A3": 389, "E3A3C4": 390, "E3A3d4": 391, "E3b3": 392, "E3B3": 393, "E3B3E4": 394, "E3C4": 395, "E3d4": 396, "E3E4": 397, "F3g3": 398, "F3G3": 399, "F3a3": 400, "F3a3C4": 401, "F3a3d4": 402, "F3A3": 403, "F3A3C4": 404, "F3A3C4F4": 405, "F3A3D4": 406, "F3b3": 407, "F3b3D4": 408, "F3B3": 409, "F3C4": 410, "F3d4": 411, "F3D4": 412, "F3F4": 413, "g3a3": 414, "g3A3": 415, "g3A3d4": 416, "g3A3D4": 417, "g3b3": 418, "g3b3d4": 419, "g3B3": 420, "g3B3e4": 421, "g3C4": 422, "g3d4": 423, "g3D4": 424, "g3e4": 425, "g3E4": 426, "g3g4": 427, "g3b4": 428, "g3A5": 429, "G3a3": 430, "G3A3": 431, "G3b3": 432, "G3b3D4": 433, "G3b3e4": 434, "G3B3": 435, "G3B3D4": 436, "G3B3D4G4": 437, "G3B3E4": 438, "G3C4": 439, "G3C4e4": 440, "G3C4E4": 441, "G3d4": 442, "G3D4": 443, "G3D4G4": 444, "G3e4": 
445, "G3E4": 446, "G3F4": 447, "G3G4": 448, "a3b3": 449, "a3B3": 450, "a3B3E4": 451, "a3C4": 452, "a3C4e4": 453, "a3C4F4": 454, "a3d4": 455, "a3d4F4": 456, "a3D4": 457, "a3e4": 458, "a3E4": 459, "a3F4": 460, "a3a4": 461, "A3b3": 462, "A3B3": 463, "A3C4": 464, "A3C4E4": 465, "A3C4F4": 466, "A3d4": 467, "A3d4E4": 468, "A3d4g4": 469, "A3D4": 470, "A3D4F4": 471, "A3D4g4": 472, "A3e4": 473, "A3E4": 474, "A3E4A4": 475, "A3F4": 476, "A3g4": 477, "A3G4": 478, "A3A4": 479, "A3A5": 480, "b3B3": 481, "b3C4": 482, "b3d4": 483, "b3d4F4": 484, "b3d4g4": 485, "b3D4": 486, "b3D4F4": 487, "b3D4G4": 488, "b3e4": 489, "b3e4G4": 490, "b3E4": 491, "b3F4": 492, "b3g4": 493, "b3G4": 494, "b3b4": 495, "B3C4": 496, "B3d4": 497, "B3D4": 498, "B3D4g4": 499, "B3D4G4": 500, "B3e4": 501, "B3e4g4": 502, "B3E4": 503, "B3E4G4": 504, "B3E4a4": 505, "B3F4": 506, "B3g4": 507, "B3G4": 508, "B3a4": 509, "B3B4": 510, "C4d4": 511, "C4D4": 512, "C4e4": 513, "C4e4G4": 514, "C4e4a4": 515, "C4E4": 516, "C4E4G4": 517, "C4E4G4C5": 518, "C4E4A4": 519, "C4F4": 520, "C4F4a4": 521, "C4F4A4": 522, "C4g4": 523, "C4G4": 524, "C4G4C5": 525, "C4a4": 526, "C4A4": 527, "C4b4": 528, "C4C5": 529, "d4D4": 530, "d4e4": 531, "d4E4": 532, "d4E4A4": 533, "d4F4": 534, "d4F4a4": 535, "d4F4b4": 536, "d4g4": 537, "d4G4": 538, "d4a4": 539, "d4a4d5": 540, "d4A4": 541, "d4b4": 542, "d4d5": 543, "D4e4": 544, "D4E4": 545, "D4F4": 546, "D4F4A4": 547, "D4F4b4": 548, "D4g4": 549, "D4g4A4": 550, "D4g4A4D5": 551, "D4g4B4": 552, "D4G4": 553, "D4G4b4": 554, "D4G4B4": 555, "D4a4": 556, "D4A4": 557, "D4A4D5": 558, "D4b4": 559, "D4B4": 560, "D4C5": 561, "D4D5": 562, "e4E4": 563, "e4F4": 564, "e4g4": 565, "e4g4B4": 566, "e4G4": 567, "e4G4b4": 568, "e4G4b4e5": 569, "e4G4C5": 570, "e4a4": 571, "e4a4C5": 572, "e4A4": 573, "e4b4": 574, "e4b4e5": 575, "e4B4": 576, "e4C5": 577, "e4e5": 578, "E4F4": 579, "E4g4": 580, "E4G4": 581, "E4G4B4": 582, "E4G4C5": 583, "E4a4": 584, "E4a4B4": 585, "E4A4": 586, "E4A4C5": 587, "E4A4d5": 588, "E4b4": 589, "E4B4": 590, 
"E4B4E5": 591, "E4C5": 592, "E4d5": 593, "E4E5": 594, "F4g4": 595, "F4G4": 596, "F4a4": 597, "F4a4C5": 598, "F4a4d5": 599, "F4A4": 600, "F4A4C5": 601, "F4A4C5F5": 602, "F4A4D5": 603, "F4b4": 604, "F4b4D5": 605, "F4B4": 606, "F4C5": 607, "F4C5F5": 608, "F4d5": 609, "F4D5": 610, "F4F5": 611, "g4G4": 612, "g4a4": 613, "g4A4": 614, "g4A4d5": 615, "g4A4D5": 616, "g4b4": 617, "g4b4d5": 618, "g4B4": 619, "g4C5": 620, "g4d5": 621, "g4D5": 622, "g4e5": 623, "g4g5": 624, "G4a4": 625, "G4A4": 626, "G4b4": 627, "G4b4D5": 628, "G4b4e5": 629, "G4B4": 630, "G4B4D5": 631, "G4B4E5": 632, "G4C5": 633, "G4C5e5": 634, "G4C5E5": 635, "G4d5": 636, "G4D5": 637, "G4D5G5": 638, "G4e5": 639, "G4E5": 640, "G4G5": 641, "a4A4": 642, "a4b4": 643, "a4B4": 644, "a4B4E5": 645, "a4C5": 646, "a4C5e5": 647, "a4d5": 648, "a4D5": 649, "a4e5": 650, "a4E5": 651, "a4F5": 652, "a4a5": 653, "A4b4": 654, "A4B4": 655, "A4C5": 656, "A4C5E5": 657, "A4C5F5": 658, "A4d5": 659, "A4d5E5": 660, "A4D5": 661, "A4D5F5": 662, "A4D5g5": 663, "A4e5": 664, "A4E5": 665, "A4F5": 666, "A4g5": 667, "A4G5": 668, "A4A5": 669, "b4B4": 670, "b4C5": 671, "b4d5": 672, "b4D5": 673, "b4D5F5": 674, "b4e5": 675, "b4e5G5": 676, "b4E5": 677, "b4F5": 678, "b4g5": 679, "b4G5": 680, "b4b5": 681, "B4C5": 682, "B4d5": 683, "B4D5": 684, "B4D5G5": 685, "B4e5": 686, "B4E5": 687, "B4F5": 688, "B4g5": 689, "B4G5": 690, "B4a5": 691, "B4B5": 692, "C5d5": 693, "C5D5": 694, "C5e5": 695, "C5e5G5": 696, "C5E5": 697, "C5E5G5": 698, "C5F5": 699, "C5F5A5": 700, "C5g5": 701, "C5G5": 702, "C5a5": 703, "C5A5": 704, "C5C6": 705, "d5D5": 706, "d5e5": 707, "d5E5": 708, "d5F5": 709, "d5g5": 710, "d5G5": 711, "d5a5": 712, "d5A5": 713, "d5b5": 714, "d5d6": 715, "D5e5": 716, "D5E5": 717, "D5F5": 718, "D5F5A5": 719, "D5g5": 720, "D5g5A5": 721, "D5G5": 722, "D5a5": 723, "D5A5": 724, "D5b5": 725, "D5B5": 726, "D5D6": 727, "e5E5": 728, "e5F5": 729, "e5g5": 730, "e5G5": 731, "e5a5": 732, "e5A5": 733, "e5b5": 734, "e5B5": 735, "e5C6": 736, "e5e6": 737, "E5F5": 738, "E5g5": 
739, "E5G5": 740, "E5G5B5": 741, "E5a5": 742, "E5A5": 743, "E5b5": 744, "E5B5": 745, "E5C6": 746, "E5d6": 747, "E5E6": 748, "F5g5": 749, "F5G5": 750, "F5a5": 751, "F5A5": 752, "F5A5C6": 753, "F5b5": 754, "F5B5": 755, "F5C6": 756, "F5d6": 757, "F5D6": 758, "F5F6": 759, "g5G5": 760, "g5A5": 761, "g5b5": 762, "g5B5": 763, "g5d6": 764, "g5D6": 765, "g5g6": 766, "G5a5": 767, "G5A5": 768, "G5b5": 769, "G5B5": 770, "G5C6": 771, "G5D6": 772, "G5e6": 773, "G5E6": 774, "G5G6": 775, "a5A5": 776, "a5b5": 777, "a5B5": 778, "a5C6": 779, "a5d6": 780, "a5e6": 781, "a5a6": 782, "A5b5": 783, "A5B5": 784, "A5C6": 785, "A5d6": 786, "A5D6": 787, "A5E6": 788, "A5F6": 789, "A5A6": 790, "b5B5": 791, "b5C6": 792, "b5d6": 793, "b5D6": 794, "b5e6": 795, "b5F6": 796, "b5b6": 797, "B5C6": 798, "B5D6": 799, "B5e6": 800, "B5E6": 801, "B5B6": 802, "C6D6": 803, "C6e6": 804, "C6E6": 805, "C6F6": 806, "C6G6": 807, "C6C7": 808, "d6E6": 809, "d6F6": 810, "d6g6": 811, "d6d7": 812, "D6e6": 813, "D6E6": 814, "D6F6": 815, "D6g6": 816, "D6G6": 817, "D6A6": 818, "D6D7": 819, "e6g6": 820, "e6G6": 821, "e6e7": 822, "E6G6": 823, "E6a6": 824, "E6A6": 825, "E6E7": 826, "F6a6": 827, "F6A6": 828, "F6F7": 829, "g6A6": 830, "G6G7": 831, "HC+": 832, "HC/o7": 833, "HCD7": 834, "HCM": 835, "HCM7": 836, "HCm": 837, "HCm7": 838, "HCo": 839, "HCo7": 840, "HCsus2": 841, "HCsus4": 842, "Hd+": 843, "Hd/o7": 844, "HdD7": 845, "HdM": 846, "HdM7": 847, "Hdm": 848, "Hdm7": 849, "Hdo": 850, "Hdo7": 851, "Hdsus2": 852, "Hdsus4": 853, "HD+": 854, "HD/o7": 855, "HDD7": 856, "HDM": 857, "HDM7": 858, "HDm": 859, "HDm7": 860, "HDo": 861, "HDo7": 862, "HDsus2": 863, "HDsus4": 864, "He+": 865, "He/o7": 866, "HeD7": 867, "HeM": 868, "HeM7": 869, "Hem": 870, "Hem7": 871, "Heo": 872, "Heo7": 873, "Hesus2": 874, "Hesus4": 875, "HE+": 876, "HE/o7": 877, "HED7": 878, "HEM": 879, "HEM7": 880, "HEm": 881, "HEm7": 882, "HEo": 883, "HEo7": 884, "HEsus2": 885, "HEsus4": 886, "HF+": 887, "HF/o7": 888, "HFD7": 889, "HFM": 890, "HFM7": 891, "HFm": 
892, "HFm7": 893, "HFo": 894, "HFo7": 895, "HFsus2": 896, "HFsus4": 897, "Hg+": 898, "Hg/o7": 899, "HgD7": 900, "HgM": 901, "HgM7": 902, "Hgm": 903, "Hgm7": 904, "Hgo": 905, "Hgo7": 906, "Hgsus2": 907, "Hgsus4": 908, "HG+": 909, "HG/o7": 910, "HGD7": 911, "HGM": 912, "HGM7": 913, "HGm": 914, "HGm7": 915, "HGo": 916, "HGo7": 917, "HGsus2": 918, "HGsus4": 919, "Ha+": 920, "Ha/o7": 921, "HaD7": 922, "HaM": 923, "HaM7": 924, "Ham": 925, "Ham7": 926, "Hao": 927, "Hao7": 928, "Hasus2": 929, "Hasus4": 930, "HA+": 931, "HA/o7": 932, "HAD7": 933, "HAM": 934, "HAM7": 935, "HAm": 936, "HAm7": 937, "HAo": 938, "HAo7": 939, "HAsus2": 940, "HAsus4": 941, "Hb+": 942, "Hb/o7": 943, "HbD7": 944, "HbM": 945, "HbM7": 946, "Hbm": 947, "Hbm7": 948, "Hbo": 949, "Hbo7": 950, "Hbsus2": 951, "Hbsus4": 952, "HB+": 953, "HB/o7": 954, "HBD7": 955, "HBM": 956, "HBM7": 957, "HBm": 958, "HBm7": 959, "HBo": 960, "HBo7": 961, "HBsus2": 962, "HBsus4": 963, "HNA": 964, "m1": 965, "m2": 966, "m3": 967, "m4": 968, "m6": 969, "m7": 970, "m8": 971, "m9": 972, "ma": 973, "mb": 974, "mc": 975, "md": 976, "me": 977, "mf": 978, "mg": 979, "mh": 980, "mi": 981, "mj": 982, "mk": 983, "ml": 984, "mm": 985, "mn": 986, "mo": 987, "mp": 988, "mq": 989, "mr": 990, "ms": 991, "mt": 992, "mu": 993, "mw": 994, "mx": 995, "my": 996, "mz": 997, "mA": 998, "mB": 999, "mC": 1000, "mE": 1001, "mG": 1002, "mI": 1003, "mK": 1004, "mM": 1005, "mO": 1006, "mQ": 1007, "mU": 1008, "mV": 1009, "mY": 1010, "M1": 1011, "M2": 1012, "M5": 1013, "M6": 1014, "M8": 1015, "M9": 1016, "Ma": 1017, "Me": 1018, "Mg": 1019, "Mi": 1020, "Mm": 1021, "Mq": 1022, "Ms": 1023, "Mu": 1024, "My": 1025, "MC": 1026, "p0": 1027, "p1": 1028, "p2": 1029, "p3": 1030, "p4": 1031, "p5": 1032, "p6": 1033, "p7": 1034, "p8": 1035, "p9": 1036, "pa": 1037, "pb": 1038, "pc": 1039, "pd": 1040, "pe": 1041, "pf": 1042, "pg": 1043, "ph": 1044, "pi": 1045, "pj": 1046, "pk": 1047, "pl": 1048, "pm": 1049, "pn": 1050, "po": 1051, "pp": 1052, "pq": 1053, "pr": 1054, "ps": 
1055, "pt": 1056, "pu": 1057, "pv": 1058, "pw": 1059, "px": 1060, "py": 1061, "pz": 1062, "pA": 1063, "pB": 1064, "pC": 1065, "pD": 1066, "pE": 1067, "pF": 1068, "pG": 1069, "pH": 1070, "pI": 1071, "pJ": 1072, "pK": 1073, "pL": 1074, "pM": 1075, "pN": 1076, "pO": 1077, "pP": 1078, "pQ": 1079, "pR": 1080, "pS": 1081, "pT": 1082, "pU": 1083, "pV": 1084, "pW": 1085, "pX": 1086, "pY": 1087, "pZ": 1088, "P0": 1089, "P1": 1090, "P2": 1091, "P3": 1092, "P4": 1093, "P5": 1094, "P6": 1095, "P7": 1096, "P8": 1097, "P9": 1098, "Pa": 1099, "Pb": 1100, "Pc": 1101, "Pd": 1102, "Pe": 1103, "Pf": 1104, "Pg": 1105, "Ph": 1106, "Pi": 1107, "Pj": 1108, "Pk": 1109, "Pl": 1110, "Pm": 1111, "Pn": 1112, "Po": 1113, "Pp": 1114, "Pq": 1115, "Pr": 1116, "Ps": 1117, "Pt": 1118, "Pu": 1119, "Pv": 1120, "Pw": 1121, "Px": 1122, "PA": 1123, "NT": 1124, "r1": 4, "r2": 5, "r3": 6, "r4": 7, "r5": 8, "r6": 9, "r7": 10, "r8": 11, "r9": 12, "ra": 13, "rb": 14, "rc": 15, "rd": 16, "re": 17, "rf": 18, "rg": 19, "rh": 20, "ri": 21, "rj": 22, "rk": 23, "rl": 24, "rm": 25, "rn": 26, "ro": 27, "rp": 28, "rq": 29, "rr": 30, "rs": 31, "rt": 32, "ru": 33, "rv": 34, "rw": 35, "t0": 4, "t1": 5, "t2": 6, "t3": 7, "t4": 8, "t5": 9, "t6": 10, "t7": 11, "t8": 12, "t9": 13, "ta": 14, "tb": 15, "tc": 16, "td": 17, "te": 18, "tf": 19, "tg": 20, "th": 21, "ti": 22, "tj": 23, "tk": 24, "tl": 25, "tm": 26, "tn": 27, "to": 28, "tp": 29, "tq": 30, "tr": 31, "ts": 32, "tt": 33, "tu": 34, "tv": 35, "tw": 36, "tx": 37, "ty": 38, "tz": 39, "tA": 40, "tB": 41, "tC": 42, "tD": 43, "x0": 4, "x1": 5, "x2": 6, "x3": 7, "x4": 8, "x5": 9, "x6": 10, "x7": 11, "x8": 12, "x9": 13, "xa": 14, "xb": 15, "xc": 16, "xd": 17, "xe": 18, "xf": 19, "xg": 20, "xh": 21, "xi": 22, "xj": 23, "xk": 24, "xl": 25, "xm": 26, "xn": 27, "xo": 28, "xp": 29, "xq": 30, "xr": 31, "xs": 32, "xt": 33, "xu": 34, "xv": 35, "xw": 36, "xx": 37, "xy": 38, "xz": 39, "xA": 40, "xB": 41, "xC": 42, "xD": 43, "xE": 44, "xF": 45, "xG": 46, "xH": 47, "xI": 48, "xJ": 49, 
"xK": 50, "xL": 51, "xM": 52, "xN": 53, "xO": 54, "xP": 55, "xQ": 56, "xR": 57, "xS": 58, "xT": 59, "xU": 60, "xV": 61, "xW": 62, "xX": 63, "xY": 64, "xZ": 65, "X0": 66, "X1": 67, "X2": 68, "X3": 69, "X4": 70, "X5": 71, "X6": 72, "X7": 73, "X8": 74, "X9": 75, "Xa": 76, "Xb": 77, "Xc": 78, "Xd": 79, "Xe": 80, "Xf": 81, "Xg": 82, "Xh": 83, "Xi": 84, "Xj": 85, "Xk": 86, "Xl": 87, "Xm": 88, "Xn": 89, "Xo": 90, "Xp": 91, "Xq": 92, "Xr": 93, "Xs": 94, "Xt": 95, "Xu": 96, "Xv": 97, "Xw": 98, "Xx": 99, "Xy": 100, "Xz": 101, "XA": 102, "XB": 103, "XC": 104, "XD": 105, "XE": 106, "XF": 107, "XG": 108, "XH": 109, "XI": 110, "XJ": 111, "XK": 112, "XL": 113, "XM": 114, "XN": 115, "XO": 116, "XP": 117, "XQ": 118, "XR": 119, "XS": 120, "XT": 121, "XU": 122, "XV": 123, "XW": 124, "XX": 125, "XY": 126, "XZ": 127, "y0": 128, "y1": 129, "y2": 130, "y3": 131, "y4": 132} -------------------------------------------------------------------------------- /data/model_spec/linear_4096_chord_bpe_hardloss1/vocabs/vocab_0.json: -------------------------------------------------------------------------------- 1 | {"0": "", "1": "", "2": "", "3": "", "4": "CO", "5": "dO", "6": "DO", "7": "eO", "8": "EO", "9": "FO", "10": "gO", "11": "GO", "12": "aO", "13": "AO", "14": "bO", "15": "BO", "16": "C0", "17": "d0", "18": "D0", "19": "e0", "20": "E0", "21": "F0", "22": "g0", "23": "G0", "24": "a0", "25": "A0", "26": "b0", "27": "B0", "28": "C1", "29": "d1", "30": "D1", "31": "e1", "32": "E1", "33": "F1", "34": "g1", "35": "G1", "36": "a1", "37": "A1", "38": "b1", "39": "B1", "40": "C2", "41": "d2", "42": "D2", "43": "e2", "44": "E2", "45": "F2", "46": "g2", "47": "G2", "48": "a2", "49": "A2", "50": "b2", "51": "B2", "52": "C3", "53": "d3", "54": "D3", "55": "e3", "56": "E3", "57": "F3", "58": "g3", "59": "G3", "60": "a3", "61": "A3", "62": "b3", "63": "B3", "64": "C4", "65": "d4", "66": "D4", "67": "e4", "68": "E4", "69": "F4", "70": "g4", "71": "G4", "72": "a4", "73": "A4", "74": "b4", "75": "B4", "76": 
"C5", "77": "d5", "78": "D5", "79": "e5", "80": "E5", "81": "F5", "82": "g5", "83": "G5", "84": "a5", "85": "A5", "86": "b5", "87": "B5", "88": "C6", "89": "d6", "90": "D6", "91": "e6", "92": "E6", "93": "F6", "94": "g6", "95": "G6", "96": "a6", "97": "A6", "98": "b6", "99": "B6", "100": "C7", "101": "d7", "102": "D7", "103": "e7", "104": "E7", "105": "F7", "106": "g7", "107": "G7", "108": "a7", "109": "A7", "110": "b7", "111": "B7", "112": "C8", "113": "d8", "114": "D8", "115": "e8", "116": "E8", "117": "F8", "118": "g8", "119": "G8", "120": "a8", "121": "A8", "122": "b8", "123": "B8", "124": "C9", "125": "d9", "126": "D9", "127": "e9", "128": "E9", "129": "F9", "130": "g9", "131": "G9", "132": "dOC1", "133": "dOC1C3A4C7", "134": "dOC1A4", "135": "dOC1A4C7", "136": "b0b1", "137": "C1C2", "138": "C1D3", "139": "d1d2", "140": "D1D2", "141": "e1C2", "142": "e1D2", "143": "e1e2", "144": "E1E2", "145": "F1F2", "146": "g1g2", "147": "G1G2", "148": "a1a2", "149": "A1A2", "150": "b1b2", "151": "B1C2", "152": "B1d2", "153": "B1D2", "154": "B1D2g2", "155": "B1D2d3", "156": "B1D2g3", "157": "B1D2A3", "158": "B1e2", "159": "B1E2", "160": "B1F2", "161": "B1g2", "162": "B1G2", "163": "B1a2", "164": "B1A2", "165": "B1b2", "166": "B1B2", "167": "B1d3", "168": "B1D3", "169": "B1e3", "170": "B1E3", "171": "B1g3", "172": "B1A3", "173": "B1A3A5", "174": "B1B3", "175": "B1A5", "176": "C2d2", "177": "C2D2", "178": "C2D2F2", "179": "C2D2g2", "180": "C2D2d3", "181": "C2D2B3", "182": "C2e2", "183": "C2E2", "184": "C2F2", "185": "C2g2", "186": "C2G2", "187": "C2G2C3", "188": "C2a2", "189": "C2A2", "190": "C2b2", "191": "C2B2", "192": "C2C3", "193": "C2d3", "194": "C2D3", "195": "C2e3", "196": "C2g3", "197": "C2A3", "198": "C2B3", "199": "d2D2", "200": "d2E2", "201": "d2F2", "202": "d2g2", "203": "d2a2", "204": "d2d3", "205": "d2e3", "206": "d2g3", "207": "D2e2", "208": "D2E2", "209": "D2F2", "210": "D2g2", "211": "D2g2g3", "212": "D2G2", "213": "D2a2", "214": "D2A2", "215": "D2A2D3", 
"216": "D2b2", "217": "D2B2", "218": "D2C3", "219": "D2d3", "220": "D2D3", "221": "D2e3", "222": "D2E3", "223": "D2F3", "224": "D2g3", "225": "D2G3", "226": "D2A3", "227": "D2B3", "228": "D2D4", "229": "D2e4", "230": "D2E4", "231": "D2A4", "232": "D2b4", "233": "D2E5", "234": "D2A5", "235": "e2G2", "236": "e2a2", "237": "e2b2", "238": "e2b2e3", "239": "e2e3", "240": "E2F2", "241": "E2g2", "242": "E2G2", "243": "E2a2", "244": "E2A2", "245": "E2b2", "246": "E2B2", "247": "E2B2E3", "248": "E2B2E3G3B3E4", "249": "E2C3", "250": "E2d3", "251": "E2e3", "252": "E2E3", "253": "E2g3", "254": "F2g2", "255": "F2G2", "256": "F2a2", "257": "F2A2", "258": "F2A2C3", "259": "F2b2", "260": "F2B2", "261": "F2C3", "262": "F2C3F3", "263": "F2D3", "264": "F2e3", "265": "F2F3", "266": "g2G2", "267": "g2a2", "268": "g2A2", "269": "g2b2", "270": "g2B2", "271": "g2d3", "272": "g2d3g3", "273": "g2D3", "274": "g2e3", "275": "g2g3", "276": "g2C4", "277": "g2d4", "278": "g2E4", "279": "g2A4", "280": "g2b4", "281": "g2b5", "282": "G2A2", "283": "G2b2", "284": "G2B2", "285": "G2B2D3", "286": "G2C3", "287": "G2D3", "288": "G2D3G3", "289": "G2D3G3G4", "290": "G2e3", "291": "G2E3", "292": "G2G3", "293": "a2A2", "294": "a2B2", "295": "a2C3", "296": "a2d3", "297": "a2e3", "298": "a2e3a3", "299": "a2g3", "300": "a2a3", "301": "A2B2", "302": "A2C3", "303": "A2C3E3", "304": "A2d3", "305": "A2D3", "306": "A2E3", "307": "A2E3A3", "308": "A2F3", "309": "A2g3", "310": "A2A3", "311": "b2d3", "312": "b2D3", "313": "b2D3F3", "314": "b2e3", "315": "b2F3", "316": "b2F3b3", "317": "b2g3", "318": "b2G3", "319": "b2b3", "320": "B2C3", "321": "B2D3", "322": "B2e3", "323": "B2E3", "324": "B2g3", "325": "B2g3B3", "326": "B2G3", "327": "B2B3", "328": "C3D3", "329": "C3e3", "330": "C3e3G3", "331": "C3E3", "332": "C3E3G3", "333": "C3F3", "334": "C3F3A3", "335": "C3G3", "336": "C3G3C4", "337": "C3a3", "338": "C3A3", "339": "C3C4", "340": "d3D3", "341": "d3E3", "342": "d3F3", "343": "d3F3a3", "344": "d3g3", "345": "d3a3", 
"346": "d3A3", "347": "d3b3", "348": "d3d4", "349": "D3e3", "350": "D3E3", "351": "D3F3", "352": "D3F3A3", "353": "D3F3b3", "354": "D3g3", "355": "D3g3A3", "356": "D3G3", "357": "D3G3B3", "358": "D3a3", "359": "D3A3", "360": "D3A3D4", "361": "D3A3D4g4", "362": "D3b3", "363": "D3B3", "364": "D3C4", "365": "D3D4", "366": "D3C5", "367": "D3C7", "368": "e3F3", "369": "e3g3", "370": "e3G3", "371": "e3G3b3", "372": "e3G3C4", "373": "e3a3", "374": "e3a3C4", "375": "e3A3", "376": "e3b3", "377": "e3b3e4", "378": "e3B3", "379": "e3C4", "380": "e3e4", "381": "E3F3", "382": "E3g3", "383": "E3G3", "384": "E3G3B3", "385": "E3G3B3E4", "386": "E3G3C4", "387": "E3a3", "388": "E3a3B3", "389": "E3A3", "390": "E3A3C4", "391": "E3A3d4", "392": "E3b3", "393": "E3B3", "394": "E3B3E4", "395": "E3C4", "396": "E3d4", "397": "E3E4", "398": "F3g3", "399": "F3G3", "400": "F3a3", "401": "F3a3C4", "402": "F3a3d4", "403": "F3A3", "404": "F3A3C4", "405": "F3A3C4F4", "406": "F3A3D4", "407": "F3b3", "408": "F3b3D4", "409": "F3B3", "410": "F3C4", "411": "F3d4", "412": "F3D4", "413": "F3F4", "414": "g3a3", "415": "g3A3", "416": "g3A3d4", "417": "g3A3D4", "418": "g3b3", "419": "g3b3d4", "420": "g3B3", "421": "g3B3e4", "422": "g3C4", "423": "g3d4", "424": "g3D4", "425": "g3e4", "426": "g3E4", "427": "g3g4", "428": "g3b4", "429": "g3A5", "430": "G3a3", "431": "G3A3", "432": "G3b3", "433": "G3b3D4", "434": "G3b3e4", "435": "G3B3", "436": "G3B3D4", "437": "G3B3D4G4", "438": "G3B3E4", "439": "G3C4", "440": "G3C4e4", "441": "G3C4E4", "442": "G3d4", "443": "G3D4", "444": "G3D4G4", "445": "G3e4", "446": "G3E4", "447": "G3F4", "448": "G3G4", "449": "a3b3", "450": "a3B3", "451": "a3B3E4", "452": "a3C4", "453": "a3C4e4", "454": "a3C4F4", "455": "a3d4", "456": "a3d4F4", "457": "a3D4", "458": "a3e4", "459": "a3E4", "460": "a3F4", "461": "a3a4", "462": "A3b3", "463": "A3B3", "464": "A3C4", "465": "A3C4E4", "466": "A3C4F4", "467": "A3d4", "468": "A3d4E4", "469": "A3d4g4", "470": "A3D4", "471": "A3D4F4", "472": 
"A3D4g4", "473": "A3e4", "474": "A3E4", "475": "A3E4A4", "476": "A3F4", "477": "A3g4", "478": "A3G4", "479": "A3A4", "480": "A3A5", "481": "b3B3", "482": "b3C4", "483": "b3d4", "484": "b3d4F4", "485": "b3d4g4", "486": "b3D4", "487": "b3D4F4", "488": "b3D4G4", "489": "b3e4", "490": "b3e4G4", "491": "b3E4", "492": "b3F4", "493": "b3g4", "494": "b3G4", "495": "b3b4", "496": "B3C4", "497": "B3d4", "498": "B3D4", "499": "B3D4g4", "500": "B3D4G4", "501": "B3e4", "502": "B3e4g4", "503": "B3E4", "504": "B3E4G4", "505": "B3E4a4", "506": "B3F4", "507": "B3g4", "508": "B3G4", "509": "B3a4", "510": "B3B4", "511": "C4d4", "512": "C4D4", "513": "C4e4", "514": "C4e4G4", "515": "C4e4a4", "516": "C4E4", "517": "C4E4G4", "518": "C4E4G4C5", "519": "C4E4A4", "520": "C4F4", "521": "C4F4a4", "522": "C4F4A4", "523": "C4g4", "524": "C4G4", "525": "C4G4C5", "526": "C4a4", "527": "C4A4", "528": "C4b4", "529": "C4C5", "530": "d4D4", "531": "d4e4", "532": "d4E4", "533": "d4E4A4", "534": "d4F4", "535": "d4F4a4", "536": "d4F4b4", "537": "d4g4", "538": "d4G4", "539": "d4a4", "540": "d4a4d5", "541": "d4A4", "542": "d4b4", "543": "d4d5", "544": "D4e4", "545": "D4E4", "546": "D4F4", "547": "D4F4A4", "548": "D4F4b4", "549": "D4g4", "550": "D4g4A4", "551": "D4g4A4D5", "552": "D4g4B4", "553": "D4G4", "554": "D4G4b4", "555": "D4G4B4", "556": "D4a4", "557": "D4A4", "558": "D4A4D5", "559": "D4b4", "560": "D4B4", "561": "D4C5", "562": "D4D5", "563": "e4E4", "564": "e4F4", "565": "e4g4", "566": "e4g4B4", "567": "e4G4", "568": "e4G4b4", "569": "e4G4b4e5", "570": "e4G4C5", "571": "e4a4", "572": "e4a4C5", "573": "e4A4", "574": "e4b4", "575": "e4b4e5", "576": "e4B4", "577": "e4C5", "578": "e4e5", "579": "E4F4", "580": "E4g4", "581": "E4G4", "582": "E4G4B4", "583": "E4G4C5", "584": "E4a4", "585": "E4a4B4", "586": "E4A4", "587": "E4A4C5", "588": "E4A4d5", "589": "E4b4", "590": "E4B4", "591": "E4B4E5", "592": "E4C5", "593": "E4d5", "594": "E4E5", "595": "F4g4", "596": "F4G4", "597": "F4a4", "598": "F4a4C5", 
"599": "F4a4d5", "600": "F4A4", "601": "F4A4C5", "602": "F4A4C5F5", "603": "F4A4D5", "604": "F4b4", "605": "F4b4D5", "606": "F4B4", "607": "F4C5", "608": "F4C5F5", "609": "F4d5", "610": "F4D5", "611": "F4F5", "612": "g4G4", "613": "g4a4", "614": "g4A4", "615": "g4A4d5", "616": "g4A4D5", "617": "g4b4", "618": "g4b4d5", "619": "g4B4", "620": "g4C5", "621": "g4d5", "622": "g4D5", "623": "g4e5", "624": "g4g5", "625": "G4a4", "626": "G4A4", "627": "G4b4", "628": "G4b4D5", "629": "G4b4e5", "630": "G4B4", "631": "G4B4D5", "632": "G4B4E5", "633": "G4C5", "634": "G4C5e5", "635": "G4C5E5", "636": "G4d5", "637": "G4D5", "638": "G4D5G5", "639": "G4e5", "640": "G4E5", "641": "G4G5", "642": "a4A4", "643": "a4b4", "644": "a4B4", "645": "a4B4E5", "646": "a4C5", "647": "a4C5e5", "648": "a4d5", "649": "a4D5", "650": "a4e5", "651": "a4E5", "652": "a4F5", "653": "a4a5", "654": "A4b4", "655": "A4B4", "656": "A4C5", "657": "A4C5E5", "658": "A4C5F5", "659": "A4d5", "660": "A4d5E5", "661": "A4D5", "662": "A4D5F5", "663": "A4D5g5", "664": "A4e5", "665": "A4E5", "666": "A4F5", "667": "A4g5", "668": "A4G5", "669": "A4A5", "670": "b4B4", "671": "b4C5", "672": "b4d5", "673": "b4D5", "674": "b4D5F5", "675": "b4e5", "676": "b4e5G5", "677": "b4E5", "678": "b4F5", "679": "b4g5", "680": "b4G5", "681": "b4b5", "682": "B4C5", "683": "B4d5", "684": "B4D5", "685": "B4D5G5", "686": "B4e5", "687": "B4E5", "688": "B4F5", "689": "B4g5", "690": "B4G5", "691": "B4a5", "692": "B4B5", "693": "C5d5", "694": "C5D5", "695": "C5e5", "696": "C5e5G5", "697": "C5E5", "698": "C5E5G5", "699": "C5F5", "700": "C5F5A5", "701": "C5g5", "702": "C5G5", "703": "C5a5", "704": "C5A5", "705": "C5C6", "706": "d5D5", "707": "d5e5", "708": "d5E5", "709": "d5F5", "710": "d5g5", "711": "d5G5", "712": "d5a5", "713": "d5A5", "714": "d5b5", "715": "d5d6", "716": "D5e5", "717": "D5E5", "718": "D5F5", "719": "D5F5A5", "720": "D5g5", "721": "D5g5A5", "722": "D5G5", "723": "D5a5", "724": "D5A5", "725": "D5b5", "726": "D5B5", "727": "D5D6", 
"728": "e5E5", "729": "e5F5", "730": "e5g5", "731": "e5G5", "732": "e5a5", "733": "e5A5", "734": "e5b5", "735": "e5B5", "736": "e5C6", "737": "e5e6", "738": "E5F5", "739": "E5g5", "740": "E5G5", "741": "E5G5B5", "742": "E5a5", "743": "E5A5", "744": "E5b5", "745": "E5B5", "746": "E5C6", "747": "E5d6", "748": "E5E6", "749": "F5g5", "750": "F5G5", "751": "F5a5", "752": "F5A5", "753": "F5A5C6", "754": "F5b5", "755": "F5B5", "756": "F5C6", "757": "F5d6", "758": "F5D6", "759": "F5F6", "760": "g5G5", "761": "g5A5", "762": "g5b5", "763": "g5B5", "764": "g5d6", "765": "g5D6", "766": "g5g6", "767": "G5a5", "768": "G5A5", "769": "G5b5", "770": "G5B5", "771": "G5C6", "772": "G5D6", "773": "G5e6", "774": "G5E6", "775": "G5G6", "776": "a5A5", "777": "a5b5", "778": "a5B5", "779": "a5C6", "780": "a5d6", "781": "a5e6", "782": "a5a6", "783": "A5b5", "784": "A5B5", "785": "A5C6", "786": "A5d6", "787": "A5D6", "788": "A5E6", "789": "A5F6", "790": "A5A6", "791": "b5B5", "792": "b5C6", "793": "b5d6", "794": "b5D6", "795": "b5e6", "796": "b5F6", "797": "b5b6", "798": "B5C6", "799": "B5D6", "800": "B5e6", "801": "B5E6", "802": "B5B6", "803": "C6D6", "804": "C6e6", "805": "C6E6", "806": "C6F6", "807": "C6G6", "808": "C6C7", "809": "d6E6", "810": "d6F6", "811": "d6g6", "812": "d6d7", "813": "D6e6", "814": "D6E6", "815": "D6F6", "816": "D6g6", "817": "D6G6", "818": "D6A6", "819": "D6D7", "820": "e6g6", "821": "e6G6", "822": "e6e7", "823": "E6G6", "824": "E6a6", "825": "E6A6", "826": "E6E7", "827": "F6a6", "828": "F6A6", "829": "F6F7", "830": "g6A6", "831": "G6G7", "832": "HC+", "833": "HC/o7", "834": "HCD7", "835": "HCM", "836": "HCM7", "837": "HCm", "838": "HCm7", "839": "HCo", "840": "HCo7", "841": "HCsus2", "842": "HCsus4", "843": "Hd+", "844": "Hd/o7", "845": "HdD7", "846": "HdM", "847": "HdM7", "848": "Hdm", "849": "Hdm7", "850": "Hdo", "851": "Hdo7", "852": "Hdsus2", "853": "Hdsus4", "854": "HD+", "855": "HD/o7", "856": "HDD7", "857": "HDM", "858": "HDM7", "859": "HDm", "860": "HDm7", 
"861": "HDo", "862": "HDo7", "863": "HDsus2", "864": "HDsus4", "865": "He+", "866": "He/o7", "867": "HeD7", "868": "HeM", "869": "HeM7", "870": "Hem", "871": "Hem7", "872": "Heo", "873": "Heo7", "874": "Hesus2", "875": "Hesus4", "876": "HE+", "877": "HE/o7", "878": "HED7", "879": "HEM", "880": "HEM7", "881": "HEm", "882": "HEm7", "883": "HEo", "884": "HEo7", "885": "HEsus2", "886": "HEsus4", "887": "HF+", "888": "HF/o7", "889": "HFD7", "890": "HFM", "891": "HFM7", "892": "HFm", "893": "HFm7", "894": "HFo", "895": "HFo7", "896": "HFsus2", "897": "HFsus4", "898": "Hg+", "899": "Hg/o7", "900": "HgD7", "901": "HgM", "902": "HgM7", "903": "Hgm", "904": "Hgm7", "905": "Hgo", "906": "Hgo7", "907": "Hgsus2", "908": "Hgsus4", "909": "HG+", "910": "HG/o7", "911": "HGD7", "912": "HGM", "913": "HGM7", "914": "HGm", "915": "HGm7", "916": "HGo", "917": "HGo7", "918": "HGsus2", "919": "HGsus4", "920": "Ha+", "921": "Ha/o7", "922": "HaD7", "923": "HaM", "924": "HaM7", "925": "Ham", "926": "Ham7", "927": "Hao", "928": "Hao7", "929": "Hasus2", "930": "Hasus4", "931": "HA+", "932": "HA/o7", "933": "HAD7", "934": "HAM", "935": "HAM7", "936": "HAm", "937": "HAm7", "938": "HAo", "939": "HAo7", "940": "HAsus2", "941": "HAsus4", "942": "Hb+", "943": "Hb/o7", "944": "HbD7", "945": "HbM", "946": "HbM7", "947": "Hbm", "948": "Hbm7", "949": "Hbo", "950": "Hbo7", "951": "Hbsus2", "952": "Hbsus4", "953": "HB+", "954": "HB/o7", "955": "HBD7", "956": "HBM", "957": "HBM7", "958": "HBm", "959": "HBm7", "960": "HBo", "961": "HBo7", "962": "HBsus2", "963": "HBsus4", "964": "HNA", "965": "m1", "966": "m2", "967": "m3", "968": "m4", "969": "m6", "970": "m7", "971": "m8", "972": "m9", "973": "ma", "974": "mb", "975": "mc", "976": "md", "977": "me", "978": "mf", "979": "mg", "980": "mh", "981": "mi", "982": "mj", "983": "mk", "984": "ml", "985": "mm", "986": "mn", "987": "mo", "988": "mp", "989": "mq", "990": "mr", "991": "ms", "992": "mt", "993": "mu", "994": "mw", "995": "mx", "996": "my", "997": "mz", 
"998": "mA", "999": "mB", "1000": "mC", "1001": "mE", "1002": "mG", "1003": "mI", "1004": "mK", "1005": "mM", "1006": "mO", "1007": "mQ", "1008": "mU", "1009": "mV", "1010": "mY", "1011": "M1", "1012": "M2", "1013": "M5", "1014": "M6", "1015": "M8", "1016": "M9", "1017": "Ma", "1018": "Me", "1019": "Mg", "1020": "Mi", "1021": "Mm", "1022": "Mq", "1023": "Ms", "1024": "Mu", "1025": "My", "1026": "MC", "1027": "p0", "1028": "p1", "1029": "p2", "1030": "p3", "1031": "p4", "1032": "p5", "1033": "p6", "1034": "p7", "1035": "p8", "1036": "p9", "1037": "pa", "1038": "pb", "1039": "pc", "1040": "pd", "1041": "pe", "1042": "pf", "1043": "pg", "1044": "ph", "1045": "pi", "1046": "pj", "1047": "pk", "1048": "pl", "1049": "pm", "1050": "pn", "1051": "po", "1052": "pp", "1053": "pq", "1054": "pr", "1055": "ps", "1056": "pt", "1057": "pu", "1058": "pv", "1059": "pw", "1060": "px", "1061": "py", "1062": "pz", "1063": "pA", "1064": "pB", "1065": "pC", "1066": "pD", "1067": "pE", "1068": "pF", "1069": "pG", "1070": "pH", "1071": "pI", "1072": "pJ", "1073": "pK", "1074": "pL", "1075": "pM", "1076": "pN", "1077": "pO", "1078": "pP", "1079": "pQ", "1080": "pR", "1081": "pS", "1082": "pT", "1083": "pU", "1084": "pV", "1085": "pW", "1086": "pX", "1087": "pY", "1088": "pZ", "1089": "P0", "1090": "P1", "1091": "P2", "1092": "P3", "1093": "P4", "1094": "P5", "1095": "P6", "1096": "P7", "1097": "P8", "1098": "P9", "1099": "Pa", "1100": "Pb", "1101": "Pc", "1102": "Pd", "1103": "Pe", "1104": "Pf", "1105": "Pg", "1106": "Ph", "1107": "Pi", "1108": "Pj", "1109": "Pk", "1110": "Pl", "1111": "Pm", "1112": "Pn", "1113": "Po", "1114": "Pp", "1115": "Pq", "1116": "Pr", "1117": "Ps", "1118": "Pt", "1119": "Pu", "1120": "Pv", "1121": "Pw", "1122": "Px", "1123": "PA", "1124": "NT"} -------------------------------------------------------------------------------- /src/preprocess/preprocess_midi.py: -------------------------------------------------------------------------------- 1 | from collections 
import Counter 2 | import itertools, copy 3 | from more_itertools import split_before 4 | import os, traceback, time, warnings, sys 5 | import multiprocessing 6 | from miditoolkit.midi.parser import MidiFile 7 | from miditoolkit.midi.containers import Instrument 8 | from miditoolkit.midi.containers import Note as mtkNote 9 | from chorder import Dechorder 10 | 11 | import sys, os 12 | 13 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 14 | from encoding import pit2str, pos2str, bom2str, dur2str, trk2str, ins2str, pit2alphabet 15 | 16 | WORKERS = 32 17 | 18 | 19 | def measure_calc_chord(evt_seq): 20 | assert evt_seq[0][1] == 'BOM', "wrong measure for chord" 21 | bom_tick = evt_seq[0][0] 22 | ts = min(evt_seq[0][-1], 8) 23 | chroma = Counter() 24 | mtknotes = [] 25 | for evt in evt_seq[1:-1]: 26 | assert evt[1] == 'ON', "wrong measure for chord: " + evt[1] + evt_seq[-1][1] 27 | if evt[3] == 128: # exclude drums 28 | continue 29 | o, p, d = evt[0] - bom_tick, evt[2], evt[-1] 30 | if p < 21 or p > 108: # exclude unusual pitch 31 | continue 32 | if o < 8: 33 | note = mtkNote(60, p, o, o + d if o > 0 else 8) 34 | mtknotes.append(note) 35 | else: 36 | break 37 | 38 | chord, score = Dechorder.get_chord_quality(mtknotes, start=0, end=ts) 39 | if score < 0: 40 | return [bom_tick, 'CHR', None, None, None, None, 'NA'] 41 | return [bom_tick, 'CHR', None, None, None, None, 42 | pit2alphabet[chord.root_pc] + (chord.quality if chord.quality != '7' else 'D7')] 43 | 44 | 45 | def merge_drums(p_midi): # merge all percussions 46 | drum_0_lst = [] 47 | new_instruments = [] 48 | for instrument in p_midi.instruments: 49 | if not len(instrument.notes) == 0: 50 | # -------------------- 51 | if instrument.is_drum: 52 | for note in instrument.notes: 53 | drum_0_lst.append(note) 54 | else: 55 | new_instruments.append(instrument) 56 | if len(drum_0_lst) > 0: 57 | drum_0_lst.sort(key=lambda x: x.start) 58 | # remove duplicate 59 | drum_0_lst = list(k for k, _ in 
itertools.groupby(drum_0_lst)) 60 | 61 | drum_0_instrument = Instrument(program=0, is_drum=True, name="drum") 62 | drum_0_instrument.notes = drum_0_lst 63 | new_instruments.append(drum_0_instrument) 64 | 65 | p_midi.instruments = new_instruments 66 | 67 | 68 | def merge_sparse_track(p_midi, CANDI_THRES=50, MIN_THRES=5): # merge track has too less notes 69 | good_instruments = [] 70 | bad_instruments = [] 71 | good_instruments_idx = [] 72 | for instrument in p_midi.instruments: 73 | if len(instrument.notes) < CANDI_THRES: 74 | bad_instruments.append(instrument) 75 | else: 76 | good_instruments.append(instrument) 77 | good_instruments_idx.append((instrument.program, instrument.is_drum)) 78 | 79 | for bad_instrument in bad_instruments: 80 | if (bad_instrument.program, bad_instrument.is_drum) in good_instruments_idx: 81 | # find one track to merge 82 | for instrument in good_instruments: 83 | if bad_instrument.program == instrument.program and \ 84 | bad_instrument.is_drum == instrument.is_drum: 85 | instrument.notes.extend(bad_instrument.notes) 86 | break 87 | # no track to merge 88 | else: 89 | if len(bad_instrument.notes) > MIN_THRES: 90 | good_instruments.append(bad_instrument) 91 | p_midi.instruments = good_instruments 92 | 93 | 94 | def limit_max_track(p_midi, MAX_TRACK=40): # merge track with least notes and limit the maximum amount of track to 40 95 | 96 | good_instruments = p_midi.instruments 97 | good_instruments.sort( 98 | key=lambda x: (not x.is_drum, -len(x.notes))) # place drum track or the most note track at first 99 | assert good_instruments[0].is_drum == True or len(good_instruments[0].notes) >= len( 100 | good_instruments[1].notes), tuple(len(x.notes) for x in good_instruments[:3]) 101 | # assert good_instruments[0].is_drum == False, (, len(good_instruments[2])) 102 | track_idx_lst = list(range(len(good_instruments))) 103 | 104 | if len(good_instruments) > MAX_TRACK: 105 | new_good_instruments = copy.deepcopy(good_instruments[:MAX_TRACK]) 106 | 107 | 
# print(midi_file_path) 108 | for id in track_idx_lst[MAX_TRACK:]: 109 | cur_ins = good_instruments[id] 110 | merged = False 111 | new_good_instruments.sort(key=lambda x: len(x.notes)) 112 | for nid, ins in enumerate(new_good_instruments): 113 | if cur_ins.program == ins.program and cur_ins.is_drum == ins.is_drum: 114 | new_good_instruments[nid].notes.extend(cur_ins.notes) 115 | merged = True 116 | break 117 | if not merged: 118 | pass # print('Track {:d} deprecated, program {:d}, note count {:d}'.format(id, cur_ins.program, len(cur_ins.notes))) 119 | good_instruments = new_good_instruments 120 | # print(trks, probs, chosen) 121 | 122 | assert len(good_instruments) <= MAX_TRACK, len(good_instruments) 123 | for idx, good_instrument in enumerate(good_instruments): 124 | if good_instrument.is_drum: 125 | good_instruments[idx].program = 128 126 | good_instruments[idx].is_drum = False 127 | 128 | p_midi.instruments = good_instruments 129 | 130 | 131 | def get_init_note_events(p_midi): # extract all notes in midi file 132 | 133 | note_events, note_on_ticks, note_dur_lst = [], [], [] 134 | for track_idx, instrument in enumerate(p_midi.instruments): 135 | # track_idx_lst.append(track_idx) 136 | for note in instrument.notes: 137 | note_dur = note.end - note.start 138 | 139 | # special case: note_dur too long 140 | max_dur = 4 * p_midi.ticks_per_beat 141 | if note_dur / max_dur > 1: 142 | 143 | total_dur = note_dur 144 | start = note.start 145 | while total_dur != 0: 146 | if total_dur > max_dur: 147 | note_events.extend([[start, "ON", note.pitch, instrument.program, 148 | instrument.is_drum, track_idx, max_dur]]) 149 | 150 | note_on_ticks.append(start) 151 | note_dur_lst.append(max_dur) 152 | 153 | start += max_dur 154 | total_dur -= max_dur 155 | else: 156 | note_events.extend([[start, "ON", note.pitch, instrument.program, 157 | instrument.is_drum, track_idx, total_dur]]) 158 | note_on_ticks.append(start) 159 | note_dur_lst.append(total_dur) 160 | 161 | total_dur = 0 162 | 
163 | else: 164 | note_events.extend( 165 | [[note.start, "ON", note.pitch, instrument.program, instrument.is_drum, track_idx, note_dur]]) 166 | 167 | # for score analysis and beat estimating when score has no time signature 168 | note_on_ticks.append(note.start) 169 | note_dur_lst.append(note.end - note.start) 170 | 171 | note_events.sort(key=lambda x: (x[0], x[1] == "ON", x[5], x[4], x[3], x[2], x[-1])) 172 | note_events = list(k for k, _ in itertools.groupby(note_events)) 173 | return note_events, note_on_ticks, note_dur_lst 174 | 175 | 176 | def calculate_measure(p_midi, first_event_tick, 177 | last_event_tick): # calculate measures and append measure symbol to event_seq 178 | 179 | measure_events = [] 180 | time_signature_changes = p_midi.time_signature_changes 181 | 182 | if not time_signature_changes: # no time_signature_changes, estimate it 183 | raise AssertionError("No time_signature_changes") 184 | else: 185 | if time_signature_changes[0].time != 0 and \ 186 | time_signature_changes[0].time > first_event_tick: 187 | raise AssertionError("First time signature start with None zero tick") 188 | 189 | # clean duplicate time_signature_changes 190 | temp_sig = [] 191 | for idx, time_sig in enumerate(time_signature_changes): 192 | if idx == 0: 193 | temp_sig.append(time_sig) 194 | else: 195 | previous_timg_sig = time_signature_changes[idx - 1] 196 | if not (previous_timg_sig.numerator == time_sig.numerator 197 | and previous_timg_sig.denominator == time_sig.denominator): 198 | temp_sig.append(time_sig) 199 | time_signature_changes = temp_sig 200 | # print("time_signature_changes", time_signature_changes) 201 | for idx in range(len(time_signature_changes)): 202 | # calculate measures, eg: how many ticks per measure 203 | numerator = time_signature_changes[idx].numerator 204 | denominator = time_signature_changes[idx].denominator 205 | ticks_per_measure = p_midi.ticks_per_beat * (4 / denominator) * numerator 206 | 207 | cur_tick = time_signature_changes[idx].time 
208 | 209 | if idx < len(time_signature_changes) - 1: 210 | next_tick = time_signature_changes[idx + 1].time 211 | else: 212 | next_tick = last_event_tick + int(ticks_per_measure) 213 | 214 | if ticks_per_measure.is_integer(): 215 | for measure_start_tick in range(cur_tick, next_tick, int(ticks_per_measure)): 216 | if measure_start_tick + int(ticks_per_measure) > next_tick: 217 | measure_events.append([measure_start_tick, "BOM", None, None, None, None, 0]) 218 | measure_events.append([next_tick, "EOM", None, None, None, None, 0]) 219 | else: 220 | measure_events.append([measure_start_tick, "BOM", None, None, None, None, 0]) 221 | measure_events.append( 222 | [measure_start_tick + int(ticks_per_measure), "EOM", None, None, None, None, 0]) 223 | else: 224 | assert False, "ticks_per_measure Error" 225 | return measure_events 226 | 227 | 228 | def quantize_by_nth(nth_tick, note_events): 229 | # Eg. Quantize by 32th note 230 | 231 | half = nth_tick / 2 232 | split_score = list(split_before(note_events, lambda x: x[1] == "BOM")) 233 | measure_durs = [] 234 | eom_tick = 0 235 | for measure_id, measure in enumerate(split_score): 236 | bom_tick = measure[0][0] 237 | assert bom_tick == eom_tick, 'measure time error {bom_tick} {eom_tick}' 238 | eom_tick = measure[-1][0] 239 | mea_dur = eom_tick - bom_tick 240 | if mea_dur < nth_tick: # measure duration need to be quantized 241 | measure_durs.append(1) 242 | else: 243 | if mea_dur % nth_tick < half: # quantize to left 244 | measure_durs.append(mea_dur // nth_tick) 245 | else: 246 | measure_durs.append(mea_dur // nth_tick + 1) 247 | 248 | for evt in measure[1:-1]: 249 | assert evt[1] == 'ON', f'measure structure error {evt[1]}' 250 | rel_tick = evt[0] - bom_tick 251 | if rel_tick % nth_tick <= half: 252 | rel_tick = min(rel_tick // nth_tick, measure_durs[-1] - 1) 253 | else: 254 | rel_tick = min(rel_tick // nth_tick + 1, measure_durs[-1] - 1) 255 | evt[0] = rel_tick 256 | 257 | final_events = [] 258 | lasteom = 0 259 | for 
measure_id, measure in enumerate(split_score): 260 | measure[0][0] = lasteom 261 | measure[-1][0] = measure[0][0] + measure_durs[measure_id] 262 | lasteom = measure[-1][0] 263 | 264 | for event in measure[1:-1]: 265 | event[0] += measure[0][0] 266 | 267 | if event[-1] < nth_tick: # duration too short, quantize to 1 268 | event[-1] = 1 269 | else: 270 | if event[-1] % nth_tick <= half: 271 | event[-1] = event[-1] // nth_tick 272 | else: 273 | event[-1] = event[-1] // nth_tick + 1 274 | 275 | final_events.extend(measure) 276 | return final_events 277 | 278 | 279 | def prettify(note_events, ticks_per_beat): 280 | fist_event_idx = next(i for i in (range(len(note_events))) if note_events[i][1] == "ON") 281 | last_event_idx = next(i for i in reversed(range(len(note_events))) if note_events[i][1] == "ON") 282 | 283 | assert note_events[fist_event_idx - 1][1] == "BOM", "measure_start Error" 284 | assert note_events[last_event_idx + 1][1] == "EOM", "measure_end Error" 285 | 286 | # remove invalid measures on both sides 287 | note_events = note_events[fist_event_idx - 1: last_event_idx + 2] 288 | 289 | # check again 290 | assert note_events[0][1] == "BOM", "measure_start Error" 291 | assert note_events[-1][1] == "EOM", "measure_end Error" 292 | 293 | # -------------- zero start tick ----------------- 294 | start_tick = note_events[0][0] 295 | if start_tick != 0: 296 | for event in note_events: 297 | event[0] -= start_tick 298 | 299 | from fractions import Fraction 300 | ticks_32th = Fraction(ticks_per_beat, 8) 301 | 302 | note_events = quantize_by_nth(ticks_32th, note_events) 303 | 304 | note_events.sort(key=lambda x: (x[0], x[1] == "ON", x[1] == "BOM", x[1] == "EOM", 305 | x[5], x[4], x[3], x[2], x[-1])) 306 | note_events = list(k for k, _ in itertools.groupby(note_events)) 307 | 308 | # -------------------------check measure duration---------------------------------------------- 309 | note_events.sort(key=lambda x: (x[0], x[1] == "ON", x[1] == "BOM", x[1] == "EOM", 310 | 
x[5], x[4], x[3], x[2], x[-1])) 311 | split_score = list(split_before(note_events, lambda x: x[1] == "BOM")) 312 | 313 | check_measure_dur = [0] 314 | 315 | for measure_idx, measure in enumerate(split_score): 316 | first_tick = measure[0][0] 317 | last_tick = measure[-1][0] 318 | measure_dur = last_tick - first_tick 319 | if measure_dur > 100: 320 | raise AssertionError("Measure duration error") 321 | split_score[measure_idx][0][-1] = measure_dur 322 | 323 | if measure_dur in check_measure_dur: 324 | # print(measure_dur) 325 | raise AssertionError("Measure duration error") 326 | return split_score 327 | 328 | 329 | def get_pos_and_cc(split_score): 330 | new_event_seq = [] 331 | for measure_idx, measure in enumerate(split_score): 332 | measure.sort(key=lambda x: (x[1] == "EOM", x[1] == "ON", x[1] == 'CHR', x[1] == "BOM", x[-2])) 333 | bom_tick = measure[0][0] 334 | 335 | # split measure by track 336 | track_nmb = set(map(lambda x: x[-2], measure[2:-1])) 337 | tracks = [[y for y in measure if y[-2] == x] for x in track_nmb] 338 | 339 | # ---------- calculate POS for each track / add CC 340 | new_measure = [] 341 | for track_idx, track in enumerate(tracks): 342 | pos_lst = [] 343 | trk_abs_num = -1 344 | for event in track: 345 | if event[1] == "ON": 346 | assert trk_abs_num == -1 or trk_abs_num == event[ 347 | -2], "Error: found inconsistent trackid within same track" 348 | trk_abs_num = event[-2] 349 | mypos = event[0] - bom_tick 350 | pos_lst.append(mypos) 351 | pos_lst = list(set(pos_lst)) 352 | 353 | for pos in pos_lst: 354 | tracks[track_idx].append([pos + bom_tick, "POS", None, None, None, None, pos]) 355 | tracks[track_idx].insert(0, [bom_tick, "CC", None, None, None, None, trk_abs_num]) 356 | tracks[track_idx].sort( 357 | key=lambda x: (x[0], x[1] == "ON", x[1] == "POS", x[1] == "CC", x[5], x[4], x[3], x[2])) 358 | 359 | new_measure.append(measure[0]) 360 | new_measure.append(measure[1]) 361 | for track in tracks: 362 | for idx, event in enumerate(track): 363 
| new_measure.append(event) 364 | 365 | new_event_seq.extend(new_measure) 366 | 367 | return new_event_seq 368 | 369 | 370 | def event_seq_to_str(new_event_seq): 371 | char_events = [] 372 | 373 | for evt in new_event_seq: 374 | if evt[1] == 'ON': 375 | char_events.append(pit2str(evt[2])) # pitch 376 | char_events.append(dur2str(evt[-1])) # duration 377 | char_events.append(trk2str(evt[-2])) # track 378 | char_events.append(ins2str(evt[3])) # instrument 379 | elif evt[1] == 'POS': 380 | char_events.append(pos2str(evt[-1])) # type (time position) 381 | char_events.append('RZ') 382 | char_events.append('TZ') 383 | char_events.append('YZ') 384 | elif evt[1] == 'BOM': 385 | char_events.append(bom2str(evt[-1])) 386 | char_events.append('RZ') 387 | char_events.append('TZ') 388 | char_events.append('YZ') 389 | elif evt[1] == 'CC': 390 | char_events.append('NT') 391 | char_events.append('RZ') 392 | char_events.append('TZ') 393 | char_events.append('YZ') 394 | elif evt[1] == 'CHR': 395 | char_events.append('H' + evt[-1]) 396 | char_events.append('RZ') 397 | char_events.append('TZ') 398 | char_events.append('YZ') 399 | else: 400 | assert False, ("evt type error", evt[1]) 401 | return char_events 402 | 403 | 404 | # abs_pos type pitch program is_drum track_id duration/rela_pos 405 | def midi_to_event_seq_str(midi_file_path, readonly=False): 406 | p_midi = MidiFile(midi_file_path) 407 | for ins in p_midi.instruments: 408 | ins.remove_invalid_notes(verbose=False) 409 | 410 | merge_drums(p_midi) 411 | 412 | if not readonly: 413 | merge_sparse_track(p_midi) 414 | 415 | limit_max_track(p_midi) 416 | 417 | note_events, note_on_ticks, _ = get_init_note_events(p_midi) 418 | 419 | measure_events = calculate_measure(p_midi, min(note_on_ticks), max(note_on_ticks)) 420 | note_events.extend(measure_events) 421 | note_events.sort(key=lambda x: (x[0], x[1] == "ON", x[1] == "BOM", x[1] == "EOM", 422 | x[5], x[4], x[3], x[2])) 423 | 424 | split_score = prettify(note_events, 
p_midi.ticks_per_beat) 425 | 426 | for measure_idx, measure in enumerate(split_score): # calculate chord for every measure 427 | chord_evt = measure_calc_chord(measure) 428 | split_score[measure_idx].insert(1, chord_evt) 429 | 430 | new_event_seq = get_pos_and_cc(split_score) 431 | 432 | char_events = event_seq_to_str(new_event_seq) 433 | 434 | return char_events 435 | 436 | 437 | def mp_worker(file_path): 438 | try: 439 | event_seq = midi_to_event_seq_str(file_path) 440 | return event_seq 441 | except (OSError, EOFError, ValueError, KeyError) as e: 442 | print(file_path) 443 | traceback.print_exc(limit=0) 444 | print() 445 | return "error" 446 | 447 | except AssertionError as e: 448 | if str(e) == "No time_signature_changes": 449 | return "error" 450 | elif str(e) == "Measure duration error": 451 | # print("Measure duration error", file_path) 452 | return "error" 453 | else: 454 | print("Other Assertion Error", str(e), file_path) 455 | return "error" 456 | 457 | except Exception as e: 458 | print(file_path) 459 | traceback.print_exc(limit=0) 460 | print() 461 | return "error" 462 | 463 | 464 | def mp_handler(file_paths): 465 | start = time.time() 466 | 467 | broken_counter = 0 468 | good_counter = 0 469 | 470 | event_seq_res = [] 471 | chord_cnter = Counter() 472 | print(f'starts processing {len(file_paths)} midis with {WORKERS} processes') 473 | 474 | with multiprocessing.Pool(WORKERS) as p: 475 | for event_seq in p.imap(mp_worker, file_paths): 476 | if isinstance(event_seq, str): 477 | broken_counter += 1 478 | elif len(event_seq) > 0: 479 | event_seq_res.append(event_seq) 480 | good_counter += 1 481 | else: 482 | broken_counter += 1 483 | 484 | print( 485 | f"MIDI data preprocessing takes: {time.time() - start}s, {good_counter} samples collected, {broken_counter} broken.") 486 | 487 | # ---------------------------------------------------------------------------------- 488 | txt_start = time.time() 489 | if not os.path.exists('data/preprocessed/'): 490 | 
os.makedirs('data/preprocessed/') 491 | 492 | with open("data/preprocessed/raw_corpus.txt", "w", encoding="utf-8") as f: 493 | for idx, piece in enumerate(event_seq_res): 494 | f.write(' '.join(piece) + '\n') 495 | 496 | print("Create txt file takes: ", time.time() - txt_start) 497 | # ---------------------------------------------------------------------------------- 498 | 499 | 500 | if __name__ == '__main__': 501 | 502 | warnings.filterwarnings('ignore') 503 | 504 | folder_path = "data/midis" 505 | file_paths = [] 506 | for path, directories, files in os.walk(folder_path): 507 | for file in files: 508 | if file.endswith(".mid") or file.endswith(".MID"): 509 | file_path = path + "/" + file 510 | file_paths.append(file_path) 511 | 512 | # run multi-processing midi extractor 513 | mp_handler(file_paths) 514 | -------------------------------------------------------------------------------- /src/fairseq/linear_transformer_inference/linear_transformer_multi.py: -------------------------------------------------------------------------------- 1 | from fast_transformers.builders import TransformerEncoderBuilder, RecurrentEncoderBuilder 2 | from fast_transformers.masking import TriangularCausalMask, LengthMask 3 | 4 | import logging, math 5 | import os 6 | import sys 7 | 8 | import numpy as np 9 | from typing import Dict, List, Optional, Tuple 10 | from dataclasses import dataclass, field 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch import Tensor 16 | 17 | from fairseq.data.shorten_dataset import maybe_shorten_dataset 18 | from fairseq import utils, metrics 19 | from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig 20 | from fairseq.tasks import register_task 21 | from fairseq.models import ( 22 | FairseqDecoder, 23 | FairseqLanguageModel, 24 | register_model, 25 | register_model_architecture, 26 | ) 27 | from fairseq.data import ( 28 | MonolingualDataset, 29 | TokenBlockDataset, 
30 | plasma_utils, 31 | data_utils, 32 | ) 33 | from fairseq.criterions import register_criterion 34 | from fairseq.criterions.cross_entropy import CrossEntropyCriterion 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | DEFAULT_MAX_TARGET_POSITIONS = 1024 39 | # INF = 2147483647 40 | 41 | @register_criterion("multiple_loss") 42 | class MultiplelossCriterion(CrossEntropyCriterion): 43 | def forward(self, model, sample, reduce=True): 44 | """Compute the loss for the given sample. 45 | 46 | Returns a tuple with three elements: 47 | 1) the loss 48 | 2) the sample size, which is used as the denominator for the gradient 49 | 3) logging outputs to display while training 50 | """ 51 | net_output = model(**sample["net_input"]) 52 | losses = self.compute_loss(model, net_output, sample, reduce=reduce) # return a list 53 | assert not self.sentence_avg 54 | #TODO: adjust weight of evt losses and other losses by length (current strategy: simple average the losses) 55 | # weights = [sample["ntokens"]] + [sample["ontokens"]] * (len(losses) - 1) 56 | loss = torch.mean(torch.stack(losses)) 57 | logging_output = { 58 | "loss": loss.data, 59 | "evt_loss": losses[0].data, 60 | "dur_loss": losses[1].data, 61 | "trk_loss": losses[2].data, 62 | "ins_loss": losses[3].data, 63 | "ntokens": sample["ntokens"], 64 | "nsentences": sample["target"].size(0), 65 | "sample_size": sample["ntokens"], 66 | "on_sample_size": sample["ntokens"], 67 | } 68 | return loss, sample["ntokens"], logging_output 69 | 70 | def compute_loss(self, model, net_output, sample, reduce=True): 71 | lprobs_tuple = model.get_normalized_probs(net_output, log_probs=True) 72 | losses = [] 73 | for idx, lprobs in enumerate(lprobs_tuple): 74 | lprobs = lprobs.view(-1, lprobs.size(-1)) 75 | target = model.get_targets(sample, net_output)[..., idx].view(-1) 76 | 77 | loss = F.nll_loss( 78 | lprobs, 79 | target, 80 | ignore_index=self.padding_idx, 81 | reduction="sum" if reduce else "none", 82 | ) 83 | losses.append(loss) 84 
| return losses 85 | 86 | @staticmethod 87 | def reduce_metrics(logging_outputs) -> None: 88 | """Aggregate logging outputs from data parallel training.""" 89 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 90 | loss_evt = sum(log.get("evt_loss", 0) for log in logging_outputs) 91 | loss_dur = sum(log.get("dur_loss", 0) for log in logging_outputs) 92 | loss_trk = sum(log.get("trk_loss", 0) for log in logging_outputs) 93 | loss_ins = sum(log.get("ins_loss", 0) for log in logging_outputs) 94 | ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) 95 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 96 | on_sample_size = sum(log.get("on_sample_size", 0) for log in logging_outputs) 97 | # we divide by log(2) to convert the loss from base e to base 2 98 | # total_losses = 4 99 | # weighted_size = (sample_size + on_sample_size*(total_losses-1)) / total_losses 100 | metrics.log_scalar( 101 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 102 | ) 103 | metrics.log_scalar( 104 | "evt_loss", loss_evt / sample_size / math.log(2), sample_size, round=3 105 | ) 106 | metrics.log_scalar( 107 | "dur_loss", loss_dur / on_sample_size / math.log(2), on_sample_size, round=3 108 | ) 109 | metrics.log_scalar( 110 | "trk_loss", loss_trk / on_sample_size / math.log(2), on_sample_size, round=3 111 | ) 112 | metrics.log_scalar( 113 | "ins_loss", loss_ins / on_sample_size / math.log(2), on_sample_size, round=3 114 | ) 115 | 116 | if sample_size != ntokens: 117 | metrics.log_scalar( 118 | "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 119 | ) 120 | metrics.log_derived( 121 | "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) 122 | ) 123 | else: 124 | metrics.log_derived( 125 | "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) 126 | ) 127 | metrics.log_derived( 128 | "evt_ppl", lambda meters: utils.get_perplexity(meters["evt_loss"].avg) 129 | ) 130 | metrics.log_derived( 131 | 
"dur_ppl", lambda meters: utils.get_perplexity(meters["dur_loss"].avg) 132 | ) 133 | metrics.log_derived( 134 | "trk_ppl", lambda meters: utils.get_perplexity(meters["trk_loss"].avg) 135 | ) 136 | metrics.log_derived( 137 | "ins_ppl", lambda meters: utils.get_perplexity(meters["ins_loss"].avg) 138 | ) 139 | 140 | @staticmethod 141 | def logging_outputs_can_be_summed() -> bool: 142 | """ 143 | Whether the logging outputs returned by `forward` can be summed 144 | across workers prior to calling `reduce_metrics`. Setting this 145 | to True will improves distributed training speed. 146 | """ 147 | return True 148 | 149 | 150 | @register_model("linear_transformer_multi") 151 | class LinearTransformerMultiHeadLM(FairseqLanguageModel): 152 | def __init__(self, decoder): 153 | super().__init__(decoder) 154 | 155 | @staticmethod 156 | def add_args(parser): 157 | """Add model-specific arguments to the parser.""" 158 | # fmt: off 159 | parser.add_argument('--embed-dim', type=int, metavar='N', 160 | help='embedding dimension') 161 | parser.add_argument('--num-attention-heads', type=int, metavar='N', 162 | help='num attention heads') 163 | parser.add_argument('--num-layers', type=int, metavar='N', 164 | help='num layers') 165 | parser.add_argument('--dropout', type=float, metavar='D', 166 | help='dropout probability for all fully connected layers ' 167 | 'in the embeddings, encoder, and pooler') 168 | 169 | # parser.add_argument('--max-pos-len', type=int, metavar='N', 170 | # help='max positions in transformer') 171 | 172 | # parser.add_argument('--attention-dropout', type=float, metavar='D', 173 | # help='dropout probability for attention weights') 174 | # fmt: on 175 | 176 | @classmethod 177 | def build_model(cls, args, task): 178 | """Build a new model instance.""" 179 | base_architecture(args) 180 | return cls(LinearTransformerMultiHeadDecoder(args, task)) 181 | 182 | 183 | class LinearTransformerMultiHeadDecoder(FairseqDecoder): 184 | def __init__(self, args, task): 185 | 
super().__init__(task.target_dictionary) 186 | #print(task.target_dictionary) 187 | # for i in range(len(task.target_dictionary)): 188 | # print(i, task.target_dictionary[i]) 189 | self.embed_dim = args.embed_dim 190 | self.wEvte = nn.Embedding(args.evt_voc_size, args.embed_dim) 191 | self.wTrke = nn.Embedding(args.trk_voc_size, args.embed_dim) 192 | self.wDure = nn.Embedding(args.dur_voc_size, args.embed_dim) 193 | self.max_pos = args.tokens_per_sample 194 | #self.ratio = args.ratio 195 | #print("max positions:", self.max_pos) 196 | 197 | self.perm_inv = args.perm_inv 198 | if self.perm_inv > 1: 199 | self.wRpe = nn.Embedding(args.max_rel_pos+1, args.embed_dim) 200 | self.wMpe = nn.Embedding(args.max_mea_pos+1, args.embed_dim) 201 | else: 202 | self.wpe = nn.Embedding(self.max_pos+1, args.embed_dim) # max_pos_len = 4096 203 | self.drop = nn.Dropout(args.dropout) 204 | self.ln_f = nn.LayerNorm(args.embed_dim, eps=1e-6) 205 | 206 | 207 | self.model = RecurrentEncoderBuilder.from_kwargs( 208 | n_layers=args.num_layers, 209 | n_heads=args.num_attention_heads, 210 | query_dimensions=args.embed_dim // args.num_attention_heads, 211 | value_dimensions=args.embed_dim // args.num_attention_heads, 212 | feed_forward_dimensions=4 * args.embed_dim, 213 | activation='gelu', 214 | #final_normalization=True, 215 | dropout=args.dropout, 216 | attention_type="causal-linear", 217 | #feature_map=Favor.factory(n_dims=self.d_model) 218 | ).get() 219 | 220 | self.attn_mask = TriangularCausalMask(self.max_pos) 221 | self.proj_evt = nn.Linear(args.embed_dim, args.evt_voc_size, bias=False) 222 | self.proj_dur = nn.Linear(args.embed_dim, args.dur_voc_size, bias=False) 223 | self.proj_trk = nn.Linear(args.embed_dim, args.trk_voc_size, bias=False) 224 | self.proj_ins = nn.Linear(args.embed_dim, args.ins_voc_size, bias=False) 225 | 226 | self.apply(self._init_weights) 227 | # set zero embedding for padding symbol 228 | #TODO: check will the pad id be trained? 
(as TZ RZ YZ) 229 | self.pad_idx = task.target_dictionary.pad() 230 | self.wEvte.weight.data[self.pad_idx].zero_() 231 | self.wDure.weight.data[self.pad_idx].zero_() 232 | self.wTrke.weight.data[self.pad_idx].zero_() 233 | if self.perm_inv > 1: 234 | self.wRpe.weight.data[0].zero_() 235 | self.wMpe.weight.data[0].zero_() 236 | else: 237 | self.wpe.weight.data[0].zero_() 238 | 239 | def _init_weights(self, module): 240 | if isinstance(module, (nn.Linear, nn.Embedding)): 241 | module.weight.data.normal_(mean=0.0, std=self.embed_dim ** -0.5) 242 | if isinstance(module, nn.Linear) and module.bias is not None: 243 | module.bias.data.zero_() 244 | elif isinstance(module, nn.LayerNorm): 245 | module.bias.data.zero_() 246 | module.weight.data.fill_(1.0) 247 | 248 | def forward( 249 | self, 250 | x, 251 | src_lengths=None, 252 | ): 253 | features, memory = self.extract_features(x, src_lengths) 254 | evt_logits = self.proj_evt(features) 255 | dur_logits = self.proj_dur(features) 256 | trk_logits = self.proj_trk(features) 257 | ins_logits = self.proj_ins(features) 258 | 259 | return (evt_logits, dur_logits, trk_logits, ins_logits), memory 260 | 261 | def extract_features( 262 | self, 263 | x, 264 | src_lengths = None # put memory here 265 | ): 266 | 267 | bsz, seq_len, ratio = x.size() 268 | #print(bsz, seq_len, ratio) 269 | evt_emb = self.wEvte(x[..., 0]) 270 | evton_mask = x[..., 1].ne(self.pad_idx).float()[..., None].to(x.device) 271 | dur_emb = self.wDure(x[..., 1]) * evton_mask 272 | trk_emb = self.wTrke(x[..., 2]) * evton_mask 273 | 274 | if self.perm_inv > 1: 275 | rel_pos = x[..., 4] 276 | measure_ids = x[..., 5] 277 | pos_emb = self.wMpe(measure_ids) + self.wRpe(rel_pos) 278 | else: 279 | # set instrument id as position id 280 | position_ids = x[..., 3] 281 | pos_emb = self.wpe(position_ids) 282 | 283 | x = self.drop(evt_emb+dur_emb+trk_emb+pos_emb) 284 | 285 | outputs, memory = self.model(x.squeeze(0), src_lengths) 286 | outputs = self.ln_f(outputs) 287 | 288 | 
return outputs, memory 289 | 290 | 291 | @register_model_architecture("linear_transformer_multi", "linear_transformer_multi") 292 | def base_architecture(args): 293 | 294 | args.embed_dim = getattr(args, "embed_dim", 512) 295 | args.num_attention_heads = getattr(args, "num_attention_heads", 16) 296 | args.num_layers = getattr(args, "num_layers", 12) 297 | args.dropout = getattr(args, "dropout", 0.1) 298 | 299 | @register_model_architecture("linear_transformer_multi", "linear_transformer_multi_large") 300 | def base_architecture(args): 301 | args.embed_dim = getattr(args, "embed_dim", 768) 302 | args.num_attention_heads = getattr(args, "num_attention_heads", 12) 303 | args.num_layers = getattr(args, "num_layers", 12) 304 | args.dropout = getattr(args, "dropout", 0.1) 305 | 306 | 307 | class TupleMultiHeadDataset(TokenBlockDataset): 308 | def __init__( 309 | self, 310 | dataset, 311 | sizes, 312 | block_size, 313 | pad, 314 | eos, 315 | break_mode=None, 316 | include_targets=False, 317 | document_sep_len=1, 318 | ratio=4+1, 319 | sample_overlap_rate=4, 320 | permutation_invariant=3, 321 | trk_idx=2, # evt dur trk ins rel_pos mea_id 322 | spec_tok_cnt=4, # 323 | evt_vocab_size=425, 324 | trk_vocab_size=44, 325 | ): 326 | try: 327 | from fairseq.data.token_block_utils_fast import ( 328 | _get_slice_indices_fast, 329 | _get_block_to_dataset_index_fast, 330 | ) 331 | except ImportError: 332 | raise ImportError( 333 | "Please build Cython components with: `pip install --editable .` " 334 | "or `python setup.py build_ext --inplace`" 335 | ) 336 | 337 | super(TokenBlockDataset, self).__init__() 338 | self.dataset = dataset 339 | self.pad = pad 340 | self.eos = eos 341 | self.include_targets = include_targets 342 | 343 | 344 | self.ratio = ratio 345 | self.perm_inv = permutation_invariant 346 | self.sample_len_max = block_size 347 | 348 | self.trk_idx = trk_idx 349 | self.cc_idx = evt_vocab_size - 1 350 | self.spec_tok_cnt = spec_tok_cnt 351 | self.max_trk_cnt = 
trk_vocab_size - spec_tok_cnt 352 | 353 | assert len(dataset) == len(sizes) 354 | assert len(dataset) > 0 355 | 356 | if isinstance(sizes, list): 357 | sizes = np.array(sizes, dtype=np.int64) 358 | else: 359 | if torch.is_tensor(sizes): 360 | sizes = sizes.numpy() 361 | sizes = sizes.astype(np.int64) 362 | 363 | break_mode = break_mode if break_mode is not None else "complete_doc" 364 | assert break_mode == 'complete_doc', break_mode 365 | 366 | 367 | 368 | sizes_cs = np.cumsum(sizes) 369 | piece_sep_ids = np.where(sizes == document_sep_len)[0].tolist() 370 | totpieces = len(piece_sep_ids) 371 | slice_indices = np.zeros((totpieces,2), dtype=int) 372 | block_to_dataset_index = np.zeros((totpieces,3), dtype=int) 373 | 374 | for i in range(len(piece_sep_ids)): 375 | s = piece_sep_ids[i-1] if i > 0 else -1 376 | e = piece_sep_ids[i] 377 | slice_indices[i, :] = (sizes_cs[s] if s >= 0 else 0, sizes_cs[e-1]) 378 | block_to_dataset_index[i, :] = (s+1, 0, e-1) 379 | 380 | 381 | # slice_indices_std = _get_slice_indices_fast( 382 | # sizes, str(break_mode), INF, document_sep_len 383 | # ) 384 | # assert((slice_indices == slice_indices_std).all()) 385 | # block_to_dataset_index_std = _get_block_to_dataset_index_fast( 386 | # sizes, 387 | # slice_indices, 388 | # ) 389 | # assert((block_to_dataset_index == block_to_dataset_index_std).all()) 390 | 391 | 392 | #print(slice_indices.shape) 393 | sample_step = max(round(self.sample_len_max / sample_overlap_rate), 1) 394 | new_slice_indices = [] 395 | new_block_to_dataset_index = [] 396 | for line, line_piece in zip(slice_indices, block_to_dataset_index): 397 | l_piece_tot = line[1] - line[0] 398 | assert l_piece_tot % self.ratio == 0, (line[0], line[1]) 399 | l_toks = l_piece_tot // self.ratio 400 | chosen_cnt = math.ceil((l_toks + np.random.randint(sample_step)) / sample_step) 401 | #chosen_cnt = sum(1 for _ in range(0 - np.random.randint(sample_step), l_toks, sample_step)) 402 | 
new_slice_indices.append(np.stack([line]*chosen_cnt)) 403 | new_block_to_dataset_index.append(np.stack([line_piece]*chosen_cnt)) 404 | 405 | slice_indices = np.concatenate(new_slice_indices) 406 | block_to_dataset_index = np.concatenate(new_block_to_dataset_index) 407 | #print(slice_indices.shape) 408 | 409 | self._sizes = slice_indices[:, 1] - slice_indices[:, 0] 410 | self._sizes[:] = self.sample_len_max 411 | 412 | self._slice_indices = plasma_utils.PlasmaArray(slice_indices) 413 | self._sizes = plasma_utils.PlasmaArray(self._sizes) 414 | self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index) 415 | 416 | def __getitem__(self, index): 417 | # start_ds_idx means measure number 418 | # start_offset must be 0 419 | # end_ds_idx means after {sample_len_max} tokens, which measure the end token in 420 | 421 | start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index] 422 | assert start_offset == 0, (start_ds_idx, start_offset, end_ds_idx) 423 | 424 | st = np.random.randint(start_ds_idx, end_ds_idx+1) 425 | 426 | #print(start_ds_idx, end_ds_idx) 427 | buffer = [] 428 | cur_len = 0 429 | for idx in range(st, end_ds_idx+1): 430 | tmp = self.dataset[idx].view(-1, self.ratio) 431 | if self.perm_inv % 2 == 1: # swap cc, pos(data aug for auto-regressive) 432 | #TODO: swap pos 433 | all_cc_pos = torch.nonzero(tmp[..., 0] == self.cc_idx).view(-1).tolist() # find all cc indexs 434 | all_cc_pos.append(tmp.size(0)) 435 | to_swap = [] 436 | for pos, nexp in zip(all_cc_pos[:-1], all_cc_pos[1:]): # split to list 437 | to_swap.append(tmp[pos:nexp, ...]) 438 | # to_swap_idx = list(range(len(to_swap))) 439 | # random.shuffle(to_swap_idx) 440 | to_swap_idx = torch.randperm(len(to_swap)) 441 | tmp = torch.cat([tmp[:all_cc_pos[0], ...]] + [to_swap[x] for x in to_swap_idx]) 442 | #assert not (tmp == self.dataset[idx].view(-1, self.ratio)).all(), (to_swap, all_cc_pos) 443 | mea = (idx-st+1) * 3 444 | # mea_list = [[mea-2], [mea-1]] + 
[[mea]]*(tmp.size(0)-2) 445 | mea_num = torch.zeros((tmp.size(0),1), dtype=int) 446 | mea_num[2:, 0] = mea 447 | mea_num[1][0] = mea-1 448 | mea_num[0][0] = mea-2 449 | buffer.append(torch.cat((tmp, mea_num), dim=1)) 450 | cur_len += tmp.size(0) 451 | if cur_len >= self.sample_len_max: 452 | break 453 | 454 | 455 | buffer = torch.cat(buffer) 456 | if cur_len < self.sample_len_max: 457 | buffer = torch.cat([buffer, buffer.new([[self.eos]*(self.ratio+1)])]) 458 | 459 | 460 | item = buffer[:self.sample_len_max, ...] 461 | if self.perm_inv > 0: 462 | #TODO: should we assure drum track always be track 0? (give model some info) 463 | perm = torch.cat([torch.arange(self.spec_tok_cnt), torch.randperm(self.max_trk_cnt) + self.spec_tok_cnt]) 464 | item[..., self.trk_idx].apply_(lambda x: perm[x]) 465 | # cmp = self.dataset[st].view(-1, self.ratio)[..., self.trk_idx] 466 | # assert not (item[:cmp.size(0), self.trk_idx] == cmp).all() 467 | 468 | assert self.include_targets 469 | 470 | # *target* is the original sentence (=item) 471 | # *source* is shifted right by 1 (maybe left-padded with eos) 472 | # *past_target* is shifted right by 2 (left-padded as needed) 473 | # rel_pos is 0, mea_id is 0 474 | source = torch.cat([item.new([[self.eos]*(self.ratio-1) + [0, 0]]), item[:-1, ...]]) 475 | on = torch.sum(item[:, 1].ne(self.pad)).item() # if no mapping to pad, on will be item.size(0) 476 | #print(item.size(), on) 477 | # past_target = torch.cat( 478 | # [item.new([[self.pad]*(self.ratio+1), [self.eos]*(self.ratio+1)]), item[:-2, ...]] 479 | # ) 480 | 481 | return source, item, on 482 | 483 | def collate_tokens( 484 | values, 485 | pad_idx, 486 | eos_idx=None, 487 | left_pad=False, 488 | ): 489 | """Convert a list of 2d tensors into a padded 3d tensor.""" 490 | size = max(v.size(0) for v in values) # max batch size 491 | 492 | res = values[0].new(len(values), size, values[0].size(-1)).fill_(pad_idx) 493 | 494 | def copy_tensor(src, dst): 495 | assert dst.numel() == src.numel() 
496 | dst.copy_(src) 497 | 498 | for i, v in enumerate(values): 499 | copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) 500 | 501 | return res 502 | 503 | # pad = 1, eos = 2 504 | def collate(samples, pad_idx, eos_idx): 505 | if len(samples) == 0: 506 | return {} 507 | # print('raw length', end = ' ') 508 | # for s in samples: 509 | # print(len(s['source']), end = ' ') 510 | # print() 511 | def merge(key, is_list=False): 512 | if is_list: 513 | res = [] 514 | for i in range(len(samples[0][key])): 515 | res.append( 516 | collate_tokens( 517 | [s[key][i] for s in samples], 518 | pad_idx, 519 | eos_idx, 520 | left_pad=False, 521 | ) 522 | ) 523 | return res 524 | else: 525 | return collate_tokens( 526 | [s[key] for s in samples], 527 | pad_idx, 528 | eos_idx, 529 | left_pad=False, 530 | ) 531 | 532 | src_tokens = merge("source") 533 | if samples[0]["target"] is not None: 534 | is_target_list = isinstance(samples[0]["target"], list) 535 | target = merge("target", is_target_list) 536 | else: 537 | target = src_tokens 538 | 539 | #print(torch.LongTensor([s["source"].numel() // ratio for s in samples])) 540 | return { 541 | "id": torch.LongTensor([s["id"] for s in samples]), 542 | "nsentences": len(samples), 543 | "ntokens": sum(s["source"].size(0) for s in samples), 544 | "net_input": { 545 | "src_tokens": src_tokens, 546 | "src_lengths": torch.LongTensor([s["source"].size(0) for s in samples]), 547 | }, 548 | "target": target, 549 | "ontokens": sum(s["on"] for s in samples) 550 | } 551 | 552 | class MultiheadDataset(MonolingualDataset): 553 | def __init__( 554 | self, 555 | dataset, 556 | sizes, 557 | src_vocab, 558 | tgt_vocab, 559 | add_eos_for_other_targets, 560 | shuffle, 561 | targets=None, 562 | add_bos_token=False, 563 | ): 564 | # print(len(sizes)) 565 | # print(type(dataset)) 566 | # print(len(dataset)) 567 | self.dataset = dataset 568 | self.sizes = np.array(sizes) 569 | self.vocab = src_vocab 570 | self.tgt_vocab = tgt_vocab 571 | 
self.add_eos_for_other_targets = add_eos_for_other_targets 572 | self.shuffle = shuffle 573 | self.add_bos_token = add_bos_token 574 | assert not self.add_bos_token, " is occupied" 575 | 576 | assert targets is None or all( 577 | t in {"self", "future", "past"} for t in targets 578 | ), "targets must be none or one of 'self', 'future', 'past'" 579 | if targets is not None and len(targets) == 0: 580 | targets = None 581 | assert len(targets) == 1 and targets[0] == 'future' 582 | self.targets = targets 583 | def collater(self, samples): 584 | return collate(samples, self.vocab.pad(), self.vocab.eos()) 585 | 586 | def __getitem__(self, index): 587 | assert self.targets is not None 588 | source, target, on = self.dataset[index] 589 | source, target = self._make_source_target( 590 | source, target, None 591 | ) 592 | 593 | source, target = self._maybe_add_bos(source, target) 594 | return {"id": index, "source": source, "target": target, "on": on} 595 | 596 | 597 | @dataclass 598 | class SymphonyModelingConfig(LanguageModelingConfig): 599 | 600 | ratio: int = field( 601 | default=4, metadata={"help": "note/metadata attribute amount: default (evt, dur, trk, ins)"} 602 | ) 603 | evt_voc_size: int = field( 604 | default=-1, metadata={"help": "event vocab size"} 605 | ) 606 | dur_voc_size: int = field( 607 | default=-1, metadata={"help": "duration vocab size"} 608 | ) 609 | trk_voc_size: int = field( 610 | default=-1, metadata={"help": "track vocab size"} 611 | ) 612 | ins_voc_size: int = field( 613 | default=-1, metadata={"help": "instrument vocab size"} 614 | ) 615 | max_rel_pos: int = field( 616 | default=-1, metadata={"help": "maximum relative position index, calculated by make_data.py"} 617 | ) 618 | max_mea_pos: int = field( 619 | default=-1, metadata={"help": "maximum measure cnt within a sample, calculated by make_data.py"} 620 | ) 621 | perm_inv: int = field( 622 | default=3, metadata={"help": "consider permutation invariance for music, 0: without PI, 1: data 
augmentation only, 2: positional encoding only, 3: all considered"} 623 | ) 624 | sample_overlap_rate: int = field( 625 | default=4, metadata={"help": "sample overlap rate, default is 4 (stride 1024), also needed in make_data.py"} 626 | ) 627 | 628 | @register_task("symphony_modeling", dataclass=SymphonyModelingConfig) 629 | class SymphonyModelingTask(LanguageModelingTask): 630 | def load_dataset(self, split, epoch=1, combine=False, **kwargs): 631 | """Load a given dataset split. 632 | 633 | Args: 634 | split (str): name of the split (e.g., train, valid, test) 635 | """ 636 | paths = utils.split_paths(self.args.data) 637 | assert len(paths) > 0 638 | 639 | data_path = paths[(epoch - 1) % len(paths)] 640 | split_path = os.path.join(data_path, split) 641 | 642 | dataset = data_utils.load_indexed_dataset( 643 | split_path, self.dictionary, self.args.dataset_impl, combine=combine 644 | ) 645 | if dataset is None: 646 | raise FileNotFoundError( 647 | "Dataset not found: {} ({})".format(split, split_path) 648 | ) 649 | #print('load indexed dataset finished') 650 | dataset = maybe_shorten_dataset( 651 | dataset, 652 | split, 653 | self.args.shorten_data_split_list, 654 | self.args.shorten_method, 655 | self.args.tokens_per_sample, 656 | self.args.seed, 657 | ) 658 | #print('maybe_shorten_dataset finished') 659 | dataset = TupleMultiHeadDataset( 660 | dataset, 661 | dataset.sizes, 662 | self.args.tokens_per_sample, 663 | pad=self.dictionary.pad(), 664 | eos=self.dictionary.eos(), 665 | break_mode=self.args.sample_break_mode, 666 | include_targets=True, 667 | ratio=self.args.ratio + 1, 668 | sample_overlap_rate=self.args.sample_overlap_rate, 669 | permutation_invariant=self.args.perm_inv, 670 | #trk_idx=self.args.trk_idx, 671 | #spec_tok_cnt=self.args.spec_tok_cnt, 672 | evt_vocab_size=self.args.evt_voc_size, 673 | trk_vocab_size=self.args.trk_voc_size, 674 | ) 675 | #print('TupleMultiHeadDataset init finished') 676 | add_eos_for_other_targets = ( 677 | 
self.args.sample_break_mode is not None 678 | and self.args.sample_break_mode != "none" 679 | ) 680 | 681 | self.datasets[split] = self._initialize_dataset( 682 | dataset=dataset, 683 | sizes=dataset.sizes, 684 | src_vocab=self.dictionary, 685 | tgt_vocab=self.output_dictionary, 686 | add_eos_for_other_targets=add_eos_for_other_targets, 687 | shuffle=True, 688 | targets=self.targets, 689 | add_bos_token=self.args.add_bos_token, 690 | ) 691 | #print('_initialize_dataset finished') 692 | 693 | def _initialize_dataset(self, **kwargs): 694 | return MultiheadDataset(**kwargs) 695 | def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): 696 | assert False, "inference not implemented" -------------------------------------------------------------------------------- /src/fairseq/linear_transformer/linear_transformer_multi.py: -------------------------------------------------------------------------------- 1 | from fast_transformers.builders import TransformerEncoderBuilder, RecurrentEncoderBuilder 2 | from fast_transformers.masking import TriangularCausalMask, LengthMask 3 | 4 | import logging, math 5 | import os 6 | import sys 7 | 8 | import numpy as np 9 | from typing import Dict, List, Optional, Tuple 10 | from dataclasses import dataclass, field 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch import Tensor 16 | 17 | from fairseq.data.shorten_dataset import maybe_shorten_dataset 18 | from fairseq import utils, metrics 19 | from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig 20 | from fairseq.tasks import register_task 21 | from fairseq.models import ( 22 | FairseqDecoder, 23 | FairseqLanguageModel, 24 | register_model, 25 | register_model_architecture, 26 | ) 27 | from fairseq.data import ( 28 | MonolingualDataset, 29 | TokenBlockDataset, 30 | plasma_utils, 31 | data_utils, 32 | ) 33 | from fairseq.criterions import register_criterion 34 | from 
fairseq.criterions.cross_entropy import CrossEntropyCriterion 35 | 36 | 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | DEFAULT_MAX_TARGET_POSITIONS = 1024 42 | # INF = 2147483647 43 | 44 | @register_criterion("multiple_loss") 45 | class MultiplelossCriterion(CrossEntropyCriterion): 46 | def forward(self, model, sample, reduce=True): 47 | """Compute the loss for the given sample. 48 | 49 | Returns a tuple with three elements: 50 | 1) the loss 51 | 2) the sample size, which is used as the denominator for the gradient 52 | 3) logging outputs to display while training 53 | """ 54 | net_output = model(**sample["net_input"]) 55 | losses = self.compute_loss(model, net_output, sample, reduce=reduce) # return a list 56 | assert not self.sentence_avg 57 | #TODO: adjust weight of evt losses and other losses by length (current strategy: simple average the losses) 58 | # weights = [sample["ntokens"]] + [sample["ontokens"]] * (len(losses) - 1) 59 | loss = torch.mean(torch.stack(losses)) 60 | logging_output = { 61 | "loss": loss.data, 62 | "evt_loss": losses[0].data, 63 | "dur_loss": losses[1].data, 64 | "trk_loss": losses[2].data, 65 | "ins_loss": losses[3].data, 66 | "ntokens": sample["ntokens"], 67 | "nsentences": sample["target"].size(0), 68 | "sample_size": sample["ntokens"], 69 | "on_sample_size": sample["ntokens"], 70 | } 71 | return loss, sample["ntokens"], logging_output 72 | 73 | def compute_loss(self, model, net_output, sample, reduce=True): 74 | lprobs_tuple = model.get_normalized_probs(net_output, log_probs=True) 75 | losses = [] 76 | for idx, lprobs in enumerate(lprobs_tuple): 77 | lprobs = lprobs.view(-1, lprobs.size(-1)) 78 | target = model.get_targets(sample, net_output)[..., idx].view(-1) 79 | 80 | loss = F.nll_loss( 81 | lprobs, 82 | target, 83 | ignore_index=self.padding_idx, 84 | reduction="sum" if reduce else "none", 85 | ) 86 | losses.append(loss) 87 | return losses 88 | 89 | @staticmethod 90 | def reduce_metrics(logging_outputs) -> None: 91 
| """Aggregate logging outputs from data parallel training.""" 92 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 93 | loss_evt = sum(log.get("evt_loss", 0) for log in logging_outputs) 94 | loss_dur = sum(log.get("dur_loss", 0) for log in logging_outputs) 95 | loss_trk = sum(log.get("trk_loss", 0) for log in logging_outputs) 96 | loss_ins = sum(log.get("ins_loss", 0) for log in logging_outputs) 97 | ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) 98 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 99 | on_sample_size = sum(log.get("on_sample_size", 0) for log in logging_outputs) 100 | # we divide by log(2) to convert the loss from base e to base 2 101 | # total_losses = 4 102 | # weighted_size = (sample_size + on_sample_size*(total_losses-1)) / total_losses 103 | metrics.log_scalar( 104 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 105 | ) 106 | metrics.log_scalar( 107 | "evt_loss", loss_evt / sample_size / math.log(2), sample_size, round=3 108 | ) 109 | metrics.log_scalar( 110 | "dur_loss", loss_dur / on_sample_size / math.log(2), on_sample_size, round=3 111 | ) 112 | metrics.log_scalar( 113 | "trk_loss", loss_trk / on_sample_size / math.log(2), on_sample_size, round=3 114 | ) 115 | metrics.log_scalar( 116 | "ins_loss", loss_ins / on_sample_size / math.log(2), on_sample_size, round=3 117 | ) 118 | 119 | if sample_size != ntokens: 120 | metrics.log_scalar( 121 | "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 122 | ) 123 | metrics.log_derived( 124 | "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) 125 | ) 126 | else: 127 | metrics.log_derived( 128 | "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) 129 | ) 130 | metrics.log_derived( 131 | "evt_ppl", lambda meters: utils.get_perplexity(meters["evt_loss"].avg) 132 | ) 133 | metrics.log_derived( 134 | "dur_ppl", lambda meters: utils.get_perplexity(meters["dur_loss"].avg) 135 | ) 136 | 
metrics.log_derived( 137 | "trk_ppl", lambda meters: utils.get_perplexity(meters["trk_loss"].avg) 138 | ) 139 | metrics.log_derived( 140 | "ins_ppl", lambda meters: utils.get_perplexity(meters["ins_loss"].avg) 141 | ) 142 | 143 | @staticmethod 144 | def logging_outputs_can_be_summed() -> bool: 145 | """ 146 | Whether the logging outputs returned by `forward` can be summed 147 | across workers prior to calling `reduce_metrics`. Setting this 148 | to True will improves distributed training speed. 149 | """ 150 | return True 151 | 152 | @register_model("linear_transformer_multi") 153 | class LinearTransformerMultiHeadLM(FairseqLanguageModel): 154 | def __init__(self, decoder): 155 | super().__init__(decoder) 156 | 157 | @staticmethod 158 | def add_args(parser): 159 | """Add model-specific arguments to the parser.""" 160 | # fmt: off 161 | parser.add_argument('--embed-dim', type=int, metavar='N', 162 | help='embedding dimension') 163 | parser.add_argument('--num-attention-heads', type=int, metavar='N', 164 | help='num attention heads') 165 | parser.add_argument('--num-layers', type=int, metavar='N', 166 | help='num layers') 167 | parser.add_argument('--dropout', type=float, metavar='D', 168 | help='dropout probability for all fully connected layers ' 169 | 'in the embeddings, encoder, and pooler') 170 | 171 | # parser.add_argument('--max-pos-len', type=int, metavar='N', 172 | # help='max positions in transformer') 173 | 174 | # parser.add_argument('--attention-dropout', type=float, metavar='D', 175 | # help='dropout probability for attention weights') 176 | # fmt: on 177 | 178 | @classmethod 179 | def build_model(cls, args, task): 180 | """Build a new model instance.""" 181 | base_architecture(args) 182 | return cls(LinearTransformerMultiHeadDecoder(args, task)) 183 | 184 | 185 | class LinearTransformerMultiHeadDecoder(FairseqDecoder): 186 | def __init__(self, args, task): 187 | 188 | super().__init__(task.target_dictionary) 189 | #print(task.target_dictionary) 190 | # 
for i in range(len(task.target_dictionary)): 191 | # print(i, task.target_dictionary[i]) 192 | self.embed_dim = args.embed_dim 193 | self.wEvte = nn.Embedding(args.evt_voc_size, args.embed_dim) 194 | self.wTrke = nn.Embedding(args.trk_voc_size, args.embed_dim) 195 | self.wDure = nn.Embedding(args.dur_voc_size, args.embed_dim) 196 | self.max_pos = args.tokens_per_sample 197 | #self.ratio = args.ratio 198 | #print("max positions:", self.max_pos) 199 | 200 | self.perm_inv = args.perm_inv 201 | if self.perm_inv > 1: 202 | self.wRpe = nn.Embedding(args.max_rel_pos+1, args.embed_dim) 203 | self.wMpe = nn.Embedding(args.max_mea_pos+1, args.embed_dim) 204 | else: 205 | self.wpe = nn.Embedding(self.max_pos+1, args.embed_dim) # max_pos_len = 4096 206 | self.drop = nn.Dropout(args.dropout) 207 | self.ln_f = nn.LayerNorm(args.embed_dim, eps=1e-6) 208 | 209 | 210 | self.model = TransformerEncoderBuilder.from_kwargs( 211 | n_layers=args.num_layers, 212 | n_heads=args.num_attention_heads, 213 | query_dimensions=args.embed_dim // args.num_attention_heads, 214 | value_dimensions=args.embed_dim // args.num_attention_heads, 215 | feed_forward_dimensions=4 * args.embed_dim, 216 | activation='gelu', 217 | #final_normalization=True, 218 | dropout=args.dropout, 219 | attention_type="causal-linear", 220 | #feature_map=Favor.factory(n_dims=self.d_model) 221 | ).get() 222 | 223 | self.attn_mask = TriangularCausalMask(self.max_pos) 224 | self.proj_evt = nn.Linear(args.embed_dim, args.evt_voc_size, bias=False) 225 | self.proj_dur = nn.Linear(args.embed_dim, args.dur_voc_size, bias=False) 226 | self.proj_trk = nn.Linear(args.embed_dim, args.trk_voc_size, bias=False) 227 | self.proj_ins = nn.Linear(args.embed_dim, args.ins_voc_size, bias=False) 228 | 229 | self.apply(self._init_weights) 230 | # set zero embedding for padding symbol 231 | #TODO: check will the pad id be trained? 
(as TZ RZ YZ) 232 | self.pad_idx = task.target_dictionary.pad() 233 | self.wEvte.weight.data[self.pad_idx].zero_() 234 | self.wDure.weight.data[self.pad_idx].zero_() 235 | self.wTrke.weight.data[self.pad_idx].zero_() 236 | if self.perm_inv > 1: 237 | self.wRpe.weight.data[0].zero_() 238 | self.wMpe.weight.data[0].zero_() 239 | else: 240 | self.wpe.weight.data[0].zero_() 241 | 242 | def _init_weights(self, module): 243 | if isinstance(module, (nn.Linear, nn.Embedding)): 244 | module.weight.data.normal_(mean=0.0, std=self.embed_dim ** -0.5) 245 | if isinstance(module, nn.Linear) and module.bias is not None: 246 | module.bias.data.zero_() 247 | elif isinstance(module, nn.LayerNorm): 248 | module.bias.data.zero_() 249 | module.weight.data.fill_(1.0) 250 | 251 | def forward( 252 | self, 253 | x, 254 | src_lengths=None, 255 | ): 256 | features = self.extract_features(x, src_lengths) 257 | evt_logits = self.proj_evt(features) 258 | dur_logits = self.proj_dur(features) 259 | trk_logits = self.proj_trk(features) 260 | ins_logits = self.proj_ins(features) 261 | 262 | return (evt_logits, dur_logits, trk_logits, ins_logits) 263 | 264 | def extract_features( 265 | self, 266 | x, 267 | src_lengths = None 268 | ): 269 | 270 | bsz, seq_len, ratio = x.size() 271 | evt_emb = self.wEvte(x[..., 0]) 272 | 273 | # if not mapping to pad, padding idx will only occer at last 274 | evton_mask = x[..., 1].ne(self.pad_idx).float()[..., None].to(x.device) 275 | 276 | tmp = self.wDure(x[..., 1]) 277 | dur_emb = tmp * evton_mask 278 | # assert ((tmp==dur_emb).all()) 279 | tmp = self.wTrke(x[..., 2]) 280 | trk_emb = tmp * evton_mask 281 | # assert ((tmp==trk_emb).all()) 282 | 283 | pad_mask = x[..., 0].ne(self.pad_idx).long().to(x.device) 284 | if src_lengths is not None: 285 | len_mask = LengthMask(src_lengths, max_len=seq_len, device=x.device) 286 | else: 287 | len_mask = LengthMask(torch.sum(pad_mask, axis=1), max_len=seq_len, device=x.device) 288 | 289 | 290 | if self.perm_inv > 1: 291 | 
rel_pos = pad_mask * x[..., 4] 292 | rel_pos_mask = rel_pos.ne(0).float()[..., None].to(x.device) # ignore bom, chord, eos 293 | 294 | measure_ids = pad_mask * x[..., 5] 295 | mea_mask = measure_ids.ne(0).float()[..., None].to(x.device) # ignore eos 296 | 297 | pos_emb = rel_pos_mask * self.wRpe(rel_pos) + mea_mask * self.wMpe(measure_ids) 298 | 299 | else: 300 | # set position ids to exclude padding symbols 301 | position_ids = pad_mask * ( 302 | torch.arange(1, 1 + seq_len) 303 | .to(x.device) 304 | .repeat(bsz, 1) 305 | ) 306 | pos_emb = self.wpe(position_ids) 307 | 308 | x = self.drop(evt_emb+dur_emb+trk_emb+pos_emb) 309 | 310 | 311 | outputs = self.model(x, self.attn_mask, len_mask) 312 | outputs = self.ln_f(outputs) 313 | 314 | return outputs 315 | 316 | def get_normalized_probs( 317 | self, 318 | net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], 319 | log_probs: bool, 320 | sample: Optional[Dict[str, Tensor]] = None, 321 | ): 322 | """Get normalized probabilities (or log probs) from a net's output.""" 323 | 324 | if log_probs: 325 | return tuple(utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) for logits in net_output) 326 | else: 327 | return tuple(utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace) for logits in net_output) 328 | 329 | def max_positions(self): 330 | return None 331 | 332 | 333 | 334 | @register_model_architecture("linear_transformer_multi", "linear_transformer_multi") 335 | def base_architecture(args): 336 | 337 | args.embed_dim = getattr(args, "embed_dim", 512) 338 | args.num_attention_heads = getattr(args, "num_attention_heads", 16) 339 | args.num_layers = getattr(args, "num_layers", 12) 340 | args.dropout = getattr(args, "dropout", 0.1) 341 | 342 | @register_model_architecture("linear_transformer_multi", "linear_transformer_multi_large") 343 | def base_architecture(args): 344 | args.embed_dim = getattr(args, "embed_dim", 768) 345 | args.num_attention_heads = getattr(args, "num_attention_heads", 
12) 346 | args.num_layers = getattr(args, "num_layers", 12) 347 | args.dropout = getattr(args, "dropout", 0.1) 348 | 349 | 350 | class TupleMultiHeadDataset(TokenBlockDataset): 351 | def __init__( 352 | self, 353 | dataset, 354 | sizes, 355 | block_size, 356 | pad, 357 | eos, 358 | break_mode=None, 359 | include_targets=False, 360 | document_sep_len=1, 361 | ratio=4+1, 362 | sample_overlap_rate=4, 363 | permutation_invariant=3, 364 | trk_idx=2, # evt dur trk ins rel_pos mea_id 365 | spec_tok_cnt=4, # 366 | evt_vocab_size=425, 367 | trk_vocab_size=44, 368 | ): 369 | try: 370 | from fairseq.data.token_block_utils_fast import ( 371 | _get_slice_indices_fast, 372 | _get_block_to_dataset_index_fast, 373 | ) 374 | except ImportError: 375 | raise ImportError( 376 | "Please build Cython components with: `pip install --editable .` " 377 | "or `python setup.py build_ext --inplace`" 378 | ) 379 | 380 | super(TokenBlockDataset, self).__init__() 381 | self.dataset = dataset 382 | self.pad = pad 383 | self.eos = eos 384 | self.include_targets = include_targets 385 | 386 | 387 | self.ratio = ratio 388 | self.perm_inv = permutation_invariant 389 | self.sample_len_max = block_size 390 | 391 | self.trk_idx = trk_idx 392 | self.cc_idx = evt_vocab_size - 1 393 | self.spec_tok_cnt = spec_tok_cnt 394 | self.max_trk_cnt = trk_vocab_size - spec_tok_cnt 395 | 396 | assert len(dataset) == len(sizes) 397 | assert len(dataset) > 0 398 | 399 | if isinstance(sizes, list): 400 | sizes = np.array(sizes, dtype=np.int64) 401 | else: 402 | if torch.is_tensor(sizes): 403 | sizes = sizes.numpy() 404 | sizes = sizes.astype(np.int64) 405 | 406 | break_mode = break_mode if break_mode is not None else "complete_doc" 407 | assert break_mode == 'complete_doc', break_mode 408 | 409 | 410 | 411 | sizes_cs = np.cumsum(sizes) 412 | piece_sep_ids = np.where(sizes == document_sep_len)[0].tolist() 413 | totpieces = len(piece_sep_ids) 414 | slice_indices = np.zeros((totpieces,2), dtype=int) 415 | 
block_to_dataset_index = np.zeros((totpieces,3), dtype=int) 416 | 417 | for i in range(len(piece_sep_ids)): 418 | s = piece_sep_ids[i-1] if i > 0 else -1 419 | e = piece_sep_ids[i] 420 | slice_indices[i, :] = (sizes_cs[s] if s >= 0 else 0, sizes_cs[e-1]) 421 | block_to_dataset_index[i, :] = (s+1, 0, e-1) 422 | 423 | 424 | # slice_indices_std = _get_slice_indices_fast( 425 | # sizes, str(break_mode), INF, document_sep_len 426 | # ) 427 | # assert((slice_indices == slice_indices_std).all()) 428 | # block_to_dataset_index_std = _get_block_to_dataset_index_fast( 429 | # sizes, 430 | # slice_indices, 431 | # ) 432 | # assert((block_to_dataset_index == block_to_dataset_index_std).all()) 433 | 434 | 435 | #print(slice_indices.shape) 436 | sample_step = max(round(self.sample_len_max / sample_overlap_rate), 1) 437 | new_slice_indices = [] 438 | new_block_to_dataset_index = [] 439 | for line, line_piece in zip(slice_indices, block_to_dataset_index): 440 | l_piece_tot = line[1] - line[0] 441 | assert l_piece_tot % self.ratio == 0, (line[0], line[1]) 442 | l_toks = l_piece_tot // self.ratio 443 | chosen_cnt = math.ceil((l_toks + np.random.randint(sample_step)) / sample_step) 444 | #chosen_cnt = sum(1 for _ in range(0 - np.random.randint(sample_step), l_toks, sample_step)) 445 | new_slice_indices.append(np.stack([line]*chosen_cnt)) 446 | new_block_to_dataset_index.append(np.stack([line_piece]*chosen_cnt)) 447 | 448 | slice_indices = np.concatenate(new_slice_indices) 449 | block_to_dataset_index = np.concatenate(new_block_to_dataset_index) 450 | #print(slice_indices.shape) 451 | 452 | self._sizes = slice_indices[:, 1] - slice_indices[:, 0] 453 | self._sizes[:] = self.sample_len_max 454 | 455 | self._slice_indices = plasma_utils.PlasmaArray(slice_indices) 456 | self._sizes = plasma_utils.PlasmaArray(self._sizes) 457 | self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index) 458 | 459 | def __getitem__(self, index): 460 | # start_ds_idx means measure number 
461 | # start_offset must be 0 462 | # end_ds_idx means after {sample_len_max} tokens, which measure the end token in 463 | 464 | start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index] 465 | assert start_offset == 0, (start_ds_idx, start_offset, end_ds_idx) 466 | 467 | st = np.random.randint(start_ds_idx, end_ds_idx+1) 468 | 469 | #print(start_ds_idx, end_ds_idx) 470 | buffer = [] 471 | cur_len = 0 472 | for idx in range(st, end_ds_idx+1): 473 | tmp = self.dataset[idx].view(-1, self.ratio) 474 | if self.perm_inv % 2 == 1: # swap cc, pos(data aug for auto-regressive) 475 | #TODO: swap pos 476 | all_cc_pos = torch.nonzero(tmp[..., 0] == self.cc_idx).view(-1).tolist() # find all cc indexs 477 | all_cc_pos.append(tmp.size(0)) 478 | to_swap = [] 479 | for pos, nexp in zip(all_cc_pos[:-1], all_cc_pos[1:]): # split to list 480 | to_swap.append(tmp[pos:nexp, ...]) 481 | # to_swap_idx = list(range(len(to_swap))) 482 | # random.shuffle(to_swap_idx) 483 | to_swap_idx = torch.randperm(len(to_swap)) 484 | tmp = torch.cat([tmp[:all_cc_pos[0], ...]] + [to_swap[x] for x in to_swap_idx]) 485 | #assert not (tmp == self.dataset[idx].view(-1, self.ratio)).all(), (to_swap, all_cc_pos) 486 | mea = (idx-st+1) * 3 487 | # mea_list = [[mea-2], [mea-1]] + [[mea]]*(tmp.size(0)-2) 488 | mea_num = torch.zeros((tmp.size(0),1), dtype=int) 489 | mea_num[2:, 0] = mea 490 | mea_num[1][0] = mea-1 491 | mea_num[0][0] = mea-2 492 | buffer.append(torch.cat((tmp, mea_num), dim=1)) 493 | cur_len += tmp.size(0) 494 | if cur_len >= self.sample_len_max: 495 | break 496 | 497 | 498 | buffer = torch.cat(buffer) 499 | if cur_len < self.sample_len_max: 500 | buffer = torch.cat([buffer, buffer.new([[self.eos]*(self.ratio+1)])]) 501 | 502 | 503 | item = buffer[:self.sample_len_max, ...] 504 | if self.perm_inv > 0: 505 | #TODO: should we assure drum track always be track 0? 
(give model some info) 506 | perm = torch.cat([torch.arange(self.spec_tok_cnt), torch.randperm(self.max_trk_cnt) + self.spec_tok_cnt]) 507 | item[..., self.trk_idx].apply_(lambda x: perm[x]) 508 | # cmp = self.dataset[st].view(-1, self.ratio)[..., self.trk_idx] 509 | # assert not (item[:cmp.size(0), self.trk_idx] == cmp).all() 510 | 511 | assert self.include_targets 512 | 513 | # *target* is the original sentence (=item) 514 | # *source* is shifted right by 1 (maybe left-padded with eos) 515 | # *past_target* is shifted right by 2 (left-padded as needed) 516 | # rel_pos is 0, mea_id is 0 517 | source = torch.cat([item.new([[self.eos]*(self.ratio-1) + [0, 0]]), item[:-1, ...]]) 518 | on = torch.sum(item[:, 1].ne(self.pad)).item() # if no mapping to pad, on will be item.size(0) 519 | #print(item.size(), on) 520 | # past_target = torch.cat( 521 | # [item.new([[self.pad]*(self.ratio+1), [self.eos]*(self.ratio+1)]), item[:-2, ...]] 522 | # ) 523 | 524 | return source, item, on 525 | 526 | def collate_tokens( 527 | values, 528 | pad_idx, 529 | eos_idx=None, 530 | left_pad=False, 531 | ): 532 | """Convert a list of 2d tensors into a padded 3d tensor.""" 533 | size = max(v.size(0) for v in values) # max batch size 534 | 535 | res = values[0].new(len(values), size, values[0].size(-1)).fill_(pad_idx) 536 | 537 | def copy_tensor(src, dst): 538 | assert dst.numel() == src.numel() 539 | dst.copy_(src) 540 | 541 | for i, v in enumerate(values): 542 | copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) 543 | 544 | return res 545 | 546 | # pad = 1, eos = 2 547 | def collate(samples, pad_idx, eos_idx): 548 | if len(samples) == 0: 549 | return {} 550 | # print('raw length', end = ' ') 551 | # for s in samples: 552 | # print(len(s['source']), end = ' ') 553 | # print() 554 | def merge(key, is_list=False): 555 | if is_list: 556 | res = [] 557 | for i in range(len(samples[0][key])): 558 | res.append( 559 | collate_tokens( 560 | [s[key][i] for s in samples], 561 | 
pad_idx, 562 | eos_idx, 563 | left_pad=False, 564 | ) 565 | ) 566 | return res 567 | else: 568 | return collate_tokens( 569 | [s[key] for s in samples], 570 | pad_idx, 571 | eos_idx, 572 | left_pad=False, 573 | ) 574 | 575 | src_tokens = merge("source") 576 | if samples[0]["target"] is not None: 577 | is_target_list = isinstance(samples[0]["target"], list) 578 | target = merge("target", is_target_list) 579 | else: 580 | target = src_tokens 581 | 582 | #print(torch.LongTensor([s["source"].numel() // ratio for s in samples])) 583 | return { 584 | "id": torch.LongTensor([s["id"] for s in samples]), 585 | "nsentences": len(samples), 586 | "ntokens": sum(s["source"].size(0) for s in samples), 587 | "net_input": { 588 | "src_tokens": src_tokens, 589 | "src_lengths": torch.LongTensor([s["source"].size(0) for s in samples]), 590 | }, 591 | "target": target, 592 | "ontokens": sum(s["on"] for s in samples) 593 | } 594 | 595 | class MultiheadDataset(MonolingualDataset): 596 | def __init__( 597 | self, 598 | dataset, 599 | sizes, 600 | src_vocab, 601 | tgt_vocab, 602 | add_eos_for_other_targets, 603 | shuffle, 604 | targets=None, 605 | add_bos_token=False, 606 | ): 607 | # print(len(sizes)) 608 | # print(type(dataset)) 609 | # print(len(dataset)) 610 | self.dataset = dataset 611 | self.sizes = np.array(sizes) 612 | self.vocab = src_vocab 613 | self.tgt_vocab = tgt_vocab 614 | self.add_eos_for_other_targets = add_eos_for_other_targets 615 | self.shuffle = shuffle 616 | self.add_bos_token = add_bos_token 617 | assert not self.add_bos_token, " is occupied" 618 | 619 | assert targets is None or all( 620 | t in {"self", "future", "past"} for t in targets 621 | ), "targets must be none or one of 'self', 'future', 'past'" 622 | if targets is not None and len(targets) == 0: 623 | targets = None 624 | assert len(targets) == 1 and targets[0] == 'future' 625 | self.targets = targets 626 | def collater(self, samples): 627 | return collate(samples, self.vocab.pad(), self.vocab.eos()) 628 | 
629 | def __getitem__(self, index): 630 | assert self.targets is not None 631 | source, target, on = self.dataset[index] 632 | source, target = self._make_source_target( 633 | source, target, None 634 | ) 635 | 636 | source, target = self._maybe_add_bos(source, target) 637 | return {"id": index, "source": source, "target": target, "on": on} 638 | 639 | 640 | 641 | @dataclass 642 | class SymphonyModelingConfig(LanguageModelingConfig): 643 | 644 | ratio: int = field( 645 | default=4, metadata={"help": "note/metadata attribute amount: default (evt, dur, trk, ins)"} 646 | ) 647 | evt_voc_size: int = field( 648 | default=-1, metadata={"help": "event vocab size"} 649 | ) 650 | dur_voc_size: int = field( 651 | default=-1, metadata={"help": "duration vocab size"} 652 | ) 653 | trk_voc_size: int = field( 654 | default=-1, metadata={"help": "track vocab size"} 655 | ) 656 | ins_voc_size: int = field( 657 | default=-1, metadata={"help": "instrument vocab size"} 658 | ) 659 | max_rel_pos: int = field( 660 | default=-1, metadata={"help": "maximum relative position index, calculated by make_data.py"} 661 | ) 662 | max_mea_pos: int = field( 663 | default=-1, metadata={"help": "maximum measure cnt within a sample, calculated by make_data.py"} 664 | ) 665 | perm_inv: int = field( 666 | default=3, metadata={"help": "consider permutation invariance for music, 0: without PI, 1: data augmentation only, 2: positional encoding only, 3: all considered"} 667 | ) 668 | sample_overlap_rate: int = field( 669 | default=4, metadata={"help": "sample overlap rate, default is 4 (stride 1024), also needed in make_data.py"} 670 | ) 671 | 672 | @register_task("symphony_modeling", dataclass=SymphonyModelingConfig) 673 | class SymphonyModelingTask(LanguageModelingTask): 674 | def load_dataset(self, split, epoch=1, combine=False, **kwargs): 675 | """Load a given dataset split. 
676 | 677 | Args: 678 | split (str): name of the split (e.g., train, valid, test) 679 | """ 680 | paths = utils.split_paths(self.args.data) 681 | assert len(paths) > 0 682 | 683 | data_path = paths[(epoch - 1) % len(paths)] 684 | split_path = os.path.join(data_path, split) 685 | 686 | dataset = data_utils.load_indexed_dataset( 687 | split_path, self.dictionary, self.args.dataset_impl, combine=combine 688 | ) 689 | if dataset is None: 690 | raise FileNotFoundError( 691 | "Dataset not found: {} ({})".format(split, split_path) 692 | ) 693 | #print('load indexed dataset finished') 694 | dataset = maybe_shorten_dataset( 695 | dataset, 696 | split, 697 | self.args.shorten_data_split_list, 698 | self.args.shorten_method, 699 | self.args.tokens_per_sample, 700 | self.args.seed, 701 | ) 702 | #print('maybe_shorten_dataset finished') 703 | dataset = TupleMultiHeadDataset( 704 | dataset, 705 | dataset.sizes, 706 | self.args.tokens_per_sample, 707 | pad=self.dictionary.pad(), 708 | eos=self.dictionary.eos(), 709 | break_mode=self.args.sample_break_mode, 710 | include_targets=True, 711 | ratio=self.args.ratio + 1, 712 | sample_overlap_rate=self.args.sample_overlap_rate, 713 | permutation_invariant=self.args.perm_inv, 714 | #trk_idx=self.args.trk_idx, 715 | #spec_tok_cnt=self.args.spec_tok_cnt, 716 | evt_vocab_size=self.args.evt_voc_size, 717 | trk_vocab_size=self.args.trk_voc_size, 718 | ) 719 | #print('TupleMultiHeadDataset init finished') 720 | add_eos_for_other_targets = ( 721 | self.args.sample_break_mode is not None 722 | and self.args.sample_break_mode != "none" 723 | ) 724 | 725 | self.datasets[split] = self._initialize_dataset( 726 | dataset=dataset, 727 | sizes=dataset.sizes, 728 | src_vocab=self.dictionary, 729 | tgt_vocab=self.output_dictionary, 730 | add_eos_for_other_targets=add_eos_for_other_targets, 731 | shuffle=True, 732 | targets=self.targets, 733 | add_bos_token=self.args.add_bos_token, 734 | ) 735 | #print('_initialize_dataset finished') 736 | 737 | def 
_initialize_dataset(self, **kwargs): 738 | return MultiheadDataset(**kwargs) 739 | def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): 740 | assert False, "inference not implemented" 741 | # fairseq.tasks.language_modeling.TokenBlockDataset = TupleMultiHeadDataset 742 | # fairseq.tasks.language_modeling.MonolingualDataset = MultiheadDataset 743 | --------------------------------------------------------------------------------