├── .gitignore
├── License
├── README.md
├── commu
│   ├── __init__.py
│   ├── midi_generator
│   │   ├── __init__.py
│   │   ├── container.py
│   │   ├── generate_pipeline.py
│   │   ├── info_preprocessor.py
│   │   ├── midi_inferrer.py
│   │   ├── model_initializer.py
│   │   └── sequence_postprocessor.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── config_helper.py
│   │   ├── dataset.py
│   │   ├── exp_utils.py
│   │   └── model.py
│   └── preprocessor
│       ├── __init__.py
│       ├── augment.py
│       ├── encoder
│       │   ├── __init__.py
│       │   ├── encoder.py
│       │   ├── encoder_utils.py
│       │   ├── event_tokens.py
│       │   └── meta.py
│       ├── parser
│       │   ├── __init__.py
│       │   └── meta.py
│       ├── pipeline.py
│       ├── preprocessor.py
│       └── utils
│           ├── __init__.py
│           ├── constants.py
│           ├── container.py
│           ├── exceptions.py
│           └── utils.py
├── dataset
│   ├── commu_meta.csv
│   └── commu_midi.tar
├── generate.py
├── logger.py
├── preprocess.py
├── requirements.txt
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | .venv_3.7.6
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | .idea
133 | .vscode
134 | .DS_Store
135 | tmp*.py
136 |
137 | # shell script
138 | *.sh
--------------------------------------------------------------------------------
/License:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 POZAlabs
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComMU: Dataset for Combinatorial Music Generation
2 |
3 | ![](https://velog.velcdn.com/images/crosstar1228/post/0d2ed81f-06df-46fe-bfcb-8e5729eab6dc/image.png)
4 |
5 | This is the repository of ComMU: Dataset for Combinatorial Music Generation. It contains the MIDI dataset together with the code for training and for generation with the autoregressive music generation model. The dataset contains 11,144 MIDI samples written by professional composers.
6 | The samples are short note sequences (4, 8, or 16 bars), each annotated with 12 kinds of metadata: BPM, Genre, Key, Track-instrument, Track-role, Time signature, Pitch range, Number of measures, Chord progression, Min velocity, Max velocity, and Rhythm.
7 | Additional documents and the dataset are linked below.
8 | - [Paper](https://arxiv.org/pdf/2211.09385.pdf) (NeurIPS 2022)
9 | - [Demo Page](https://pozalabs.github.io/ComMU/)
10 | - [Dataset](https://github.com/POZAlabs/ComMU-code/tree/master/dataset)
11 |
12 |
13 | ## Getting Started
14 | - Note: this project requires Python `3.8.12`. Set up a virtual environment if needed.
15 | ### Setup
16 | 1. Clone this repository
17 | 2. Install the required packages
18 | ```
19 | pip install -r requirements.txt
20 | ```
21 | ### Download the Data
22 | 1. Download the csv file with the meta information and the archived raw MIDI files.
23 |    - The csv file contains the meta information for each MIDI file.
24 | 2. Extract the MIDI files (`commu_midi.tar`).
25 | ```
26 | $ cd ComMU-code
27 | $ tar -xvf ./dataset/commu_midi.tar -C ./dataset/
28 | ```
29 | If the project tree now looks like this, it is ready for preprocessing.
30 | ```
31 | .
32 | ├── commu_meta.csv
33 | └── commu_midi
34 |     └── train
35 |         └── raw
36 |             └── midifiles(.mid)
37 |     └── val
38 |         └── raw
39 |             └── midifiles(.mid)
40 | ```
41 |
42 | ## Preprocessing
43 | - The ComMU dataset can be preprocessed by specifying the root directory and the path of the csv file containing the metadata.
44 | ```
45 | $ python3 preprocess.py --root_dir ./dataset/commu_midi --csv_path ./dataset/commu_meta.csv
46 | ```
47 |
48 | - After successful preprocessing, the project tree will look like this:
49 | ```
50 | .
51 | ├── commu_meta.csv
52 | └── commu_midi
53 |     ├── train
54 |     │   ├── raw
55 |     │   ├── augmented_tmp
56 |     │   ├── augmented
57 |     │   └── npy_tmp
58 |     ├── val
59 |     │   ├── raw
60 |     │   ├── augmented_tmp
61 |     │   ├── augmented
62 |     │   └── npy_tmp
63 |     └── output_npy
64 |         ├── input_train.npy
65 |         ├── input_val.npy
66 |         ├── target_train.npy
67 |         └── target_val.npy
68 | ```
69 | - The training input comes from the `output_npy` directory. It contains the input/target arrays split into training and validation sets (see the loading sketch at the end of this README).
70 | - Additional explanation of the train/val directories:
71 |   - `raw` : the split raw MIDI files
72 |   - `augmented` : data augmented by key switching and BPM change.
73 |     - file names encode the audio_key and BPM, e.g. `commu11144_gmajor_70.mid`
74 |   - `augmented_tmp` : temporary augmented data.
75 |   - `npy_tmp` : temporary numpy arrays of the encoded output, grouped into numbered subdirectories (e.g. 0000~0015), each holding the numpy array for one MIDI file.
76 |
77 |
78 | ## Training
79 | ```
80 | $ python3 -m torch.distributed.launch --nproc_per_node=4 ./train.py --data_dir ./dataset/commu_midi/output_npy --work_dir {./working_directory}
81 | ```
82 |
83 | ## Generating
84 | - Generation requires choosing the metadata that describe the kind of music (MIDI file) to generate. An example command is shown below.
85 | ```
86 | $ python3 generate.py \
87 | --checkpoint_dir {./working_directory/checkpoint_best.pt} \
88 | --output_dir {./output_dir} \
89 | --bpm 70 \
90 | --audio_key aminor \
91 | --time_signature 4/4 \
92 | --pitch_range mid_high \
93 | --num_measures 8 \
94 | --inst acoustic_piano \
95 | --genre newage \
96 | --min_velocity 60 \
97 | --max_velocity 80 \
98 | --track_role main_melody \
99 | --rhythm standard \
100 | --chord_progression Am-Am-Am-Am-Am-Am-Am-Am-G-G-G-G-G-G-G-G-F-F-F-F-F-F-F-F-E-E-E-E-E-E-E-E-Am-Am-Am-Am-Am-Am-Am-Am-G-G-G-G-G-G-G-G-F-F-F-F-F-F-F-F-E-E-E-E-E-E-E-E \
101 | --num_generate 3
102 | ```
103 |
104 | ## Checkpoint File
105 | [Download](https://drive.google.com/file/d/1y0wl9JO8od3pLOMSxN8NwLy1PCJCyTGL/view?usp=share_link)
106 |
107 | ## License
108 | The ComMU dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License (CC BY-NC-SA 4.0). It is provided primarily for research purposes and may not be used for commercial purposes.
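## Inspecting the Preprocessed Arrays
- A minimal sketch (not part of the original codebase) for checking the `output_npy` files produced by `preprocess.py`. It mirrors what `ComMUDataset.load_cache_data` in `commu/model/dataset.py` does internally; the path below assumes the default layout shown in the Preprocessing section.
```
import numpy as np

data_dir = "./dataset/commu_midi/output_npy"

# Each .npy file is an object array holding one encoded token sequence per MIDI sample.
input_train = np.load(data_dir + "/input_train.npy", allow_pickle=True)
target_train = np.load(data_dir + "/target_train.npy", allow_pickle=True)

# Training consumes the input (meta) tokens followed by the target (event) tokens,
# so the two arrays are concatenated per sample, as in ComMUDataset.load_cache_data.
sample = np.concatenate((np.array(input_train[0], dtype=int), target_train[0]))
print(len(input_train), "training samples; first sequence length:", len(sample))
```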
109 | -------------------------------------------------------------------------------- /commu/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POZAlabs/ComMU-code/3949a5b5a1a54e2bb0fb9d600ecc00cd55660408/commu/__init__.py -------------------------------------------------------------------------------- /commu/midi_generator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POZAlabs/ComMU-code/3949a5b5a1a54e2bb0fb9d600ecc00cd55660408/commu/midi_generator/__init__.py -------------------------------------------------------------------------------- /commu/midi_generator/container.py: -------------------------------------------------------------------------------- 1 | import json 2 | from fractions import Fraction 3 | from pathlib import Path 4 | from typing import Dict, Any, List 5 | 6 | from pydantic import BaseModel, validator 7 | 8 | from commu.preprocessor.encoder import encoder_utils, TOKEN_OFFSET 9 | from commu.preprocessor.utils import constants 10 | from commu.preprocessor.utils.container import MidiMeta 11 | 12 | 13 | class ModelArguments(BaseModel): 14 | checkpoint_dir: str 15 | 16 | 17 | class TransXlInputData(MidiMeta): 18 | output_dir: Path 19 | 20 | num_generate: int 21 | top_k: int 22 | temperature: float 23 | chord_progression: List[str] 24 | 25 | @validator("chord_progression") 26 | def validate_chord_progression_length(cls, value: List[str], values: Dict[str, Any]) -> List[str]: 27 | num_measures = values.get("num_measures") 28 | time_signature = values.get("time_signature") 29 | expected_result = (num_measures - (num_measures % 4)) * Fraction(time_signature) * 8 30 | result = len(value) 31 | if expected_result != result: 32 | raise ValueError("num_measures not matched with chord progression length") 33 | return value 34 | 35 | 36 | @property 37 | def chord_token_components(self) -> Dict[str, list]: 38 | event2word, _ = encoder_utils.mk_remi_map() 39 | event2word = encoder_utils.add_flat_chord2map(event2word) 40 | event2word = encoder_utils.abstract_chord_types(event2word) 41 | 42 | beats_per_bar = int(Fraction(self.time_signature) * 4) 43 | chord_idx_lst, unique_cp = encoder_utils.detect_chord(self.chord_progression, beats_per_bar) 44 | resolution = constants.DEFAULT_POSITION_RESOLUTION 45 | chord_position = [] 46 | for i in chord_idx_lst: 47 | if isinstance(i, int): 48 | chord_position.append(TOKEN_OFFSET.POSITION.value) 49 | else: 50 | bit_offset = (float(str(i).split(".")[-1]) * resolution) / ( 51 | 10 ** len(str(i).split(".")[-1]) 52 | ) # 10진수 소수점으로 표현된 position index를 32bit 표현으로 변환 53 | chord_position.append(int(TOKEN_OFFSET.POSITION.value + bit_offset)) 54 | 55 | chord_token = [] 56 | for chord in unique_cp: 57 | chord = "Chord_" + chord.split("/")[0].split("(")[0] 58 | chord_token.append(event2word[chord]) 59 | 60 | chord_token_components = { 61 | "chord_token": chord_token, 62 | "chord_position": chord_position 63 | } 64 | return chord_token_components 65 | 66 | def to_dict(self) -> Dict[str, Any]: 67 | return json.loads(self.json()) -------------------------------------------------------------------------------- /commu/midi_generator/generate_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from commu.midi_generator.container import ModelArguments 4 | from commu.midi_generator.model_initializer import ModelInitializeTask 5 | from 
commu.midi_generator.info_preprocessor import PreprocessTask 6 | from commu.midi_generator.midi_inferrer import InferenceTask 7 | from commu.midi_generator.sequence_postprocessor import PostprocessTask 8 | 9 | 10 | class MidiGenerationPipeline: 11 | def __init__(self): 12 | self.map_location = None 13 | self.device = None 14 | self.model_args = None 15 | self.model_initialize_task = None 16 | 17 | self.preprocess_task = None 18 | self.inference_task = None 19 | self.postprocess_task = None 20 | 21 | def initialize_model(self, model_arguments: dict): 22 | self.map_location = "cuda" if torch.cuda.is_available() else "cpu" 23 | self.device = torch.device(self.map_location) 24 | self.model_args = ModelArguments(**model_arguments) 25 | 26 | self.model_initialize_task = ModelInitializeTask( 27 | model_args=self.model_args, 28 | map_location=self.map_location, 29 | device=self.device 30 | ) 31 | 32 | def initialize_generation(self): 33 | self.preprocess_task = PreprocessTask() 34 | self.inference_task = InferenceTask(self.device) 35 | self.postprocess_task = PostprocessTask() -------------------------------------------------------------------------------- /commu/midi_generator/info_preprocessor.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any 2 | 3 | from commu.midi_generator.container import TransXlInputData 4 | from commu.preprocessor.encoder import MetaEncoder 5 | from commu.preprocessor.utils.container import MidiMeta 6 | 7 | 8 | def parse_meta(**kwargs: Any) -> MidiMeta: 9 | return MidiMeta(**kwargs) 10 | 11 | 12 | def encode_meta(meta_encoder: MetaEncoder, midi_meta: MidiMeta) -> List[int]: 13 | return meta_encoder.encode(midi_meta) 14 | 15 | 16 | def normalize_chord_progression(chord_progression: str) -> List[str]: 17 | return chord_progression.split("-") 18 | 19 | 20 | class PreprocessTask: 21 | def __init__(self): 22 | self.input_data = None 23 | self.midi_meta = None 24 | 25 | def get_meta_info_length(self): 26 | return len(self.midi_meta.__fields__) 27 | 28 | def normalize_input_data(self, input_data: dict): 29 | input_data["chord_progression"] = normalize_chord_progression(input_data["chord_progression"]) 30 | self.input_data = TransXlInputData(**input_data) 31 | 32 | def preprocess(self) -> List[int]: 33 | self.midi_meta = parse_meta(**self.input_data.dict()) 34 | meta_encoder = MetaEncoder() 35 | encoded_meta = encode_meta( 36 | meta_encoder=meta_encoder, midi_meta=self.midi_meta 37 | ) 38 | return encoded_meta 39 | 40 | def execute(self, input_data: dict) -> List[int]: 41 | if self.input_data is None: 42 | self.normalize_input_data(input_data) 43 | 44 | encoded_meta = self.preprocess() 45 | return encoded_meta -------------------------------------------------------------------------------- /commu/midi_generator/midi_inferrer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple, List 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import yacs.config 8 | 9 | from logger import logger 10 | from commu.midi_generator.container import TransXlInputData 11 | from commu.model.model import MemTransformerLM 12 | from commu.preprocessor.encoder import TOKEN_OFFSET 13 | from commu.preprocessor.utils.constants import DEFAULT_POSITION_RESOLUTION 14 | 15 | 16 | class TeacherForceTask: 17 | def __init__(self, input_data): 18 | self.input_data = input_data 19 | self.next_tokens_forced = [] 20 | self.wrong_tokens = [] 21 
| self.no_sequence_appended = False 22 | self.is_incomplete = input_data.num_measures % 4 != 0 23 | self.incomplete_filled = not self.is_incomplete 24 | 25 | self.chord_token, self.chord_position = input_data.chord_token_components.values() 26 | assert len(self.chord_token) == len(self.chord_position), "Wrong Chord Length" 27 | self.chord_length = len(self.chord_token) 28 | self.inter_chord_flags = [] 29 | for i in self.chord_position: 30 | if i == TOKEN_OFFSET.POSITION.value: 31 | self.inter_chord_flags.append(False) 32 | else: 33 | self.inter_chord_flags.append(True) 34 | 35 | def check_first_position(self, seq): 36 | """ 37 | check if it's a token following a bar token 38 | """ 39 | return self.incomplete_filled and seq[-1] == TOKEN_OFFSET.BAR.value 40 | 41 | def check_remnant_chord(self): 42 | """ 43 | check if there any more chords to write 44 | if not, return False 45 | """ 46 | return bool(len(self.chord_token) * len(self.chord_position)) 47 | 48 | def check_length_fit(self): 49 | """ 50 | check if one chord per bar needed 51 | """ 52 | return self.chord_length == int(self.input_data.num_measures // 4 * 4) 53 | 54 | def check_position_fit(self, seq): 55 | """ 56 | check if a chord token needs to be filled next 57 | """ 58 | return seq[-2] == TOKEN_OFFSET.BAR.value and seq[-1] == TOKEN_OFFSET.POSITION.value 59 | 60 | def check_one_chord_per_bar_case(self, seq): 61 | """ 62 | case: one chord per bar 63 | """ 64 | return ( 65 | self.check_remnant_chord() 66 | and self.incomplete_filled 67 | and self.check_length_fit() 68 | and self.check_position_fit(seq) 69 | ) 70 | 71 | def check_mul_chord_per_bar_case(self, seq): 72 | """ 73 | case: multiple chords per bar 74 | """ 75 | is_first_position_chord = ( 76 | self.check_remnant_chord() 77 | and self.incomplete_filled 78 | and not self.check_length_fit() 79 | and self.check_position_fit(seq) 80 | ) 81 | 82 | is_inter_position_chord = ( 83 | self.check_remnant_chord() 84 | and self.incomplete_filled 85 | and not self.check_length_fit() 86 | and not self.check_position_fit(seq) 87 | and seq[-1] == self.chord_position[0] 88 | and self.inter_chord_flags[0] 89 | ) 90 | return is_first_position_chord or is_inter_position_chord 91 | 92 | def check_chord_position_passed(self, token): 93 | """ 94 | in case a generated token skipped necessary position 95 | """ 96 | if not self.check_remnant_chord(): 97 | return False 98 | is_position_passed = ( 99 | self.chord_position[0] < token < TOKEN_OFFSET.POSITION.value + DEFAULT_POSITION_RESOLUTION 100 | or token == TOKEN_OFFSET.BAR.value 101 | ) 102 | return self.inter_chord_flags[0] and is_position_passed 103 | 104 | def check_wrong_chord_token_generated(self, token): 105 | """ 106 | all chord tokens should be teacher forced 107 | """ 108 | return TOKEN_OFFSET.CHORD_START.value <= token <= TOKEN_OFFSET.CHORD_END.value 109 | 110 | def check_wrong_eos_generated(self, token): 111 | return self.check_remnant_chord() and token == TOKEN_OFFSET.EOS.value 112 | 113 | def check_wrong_bar_token_generated(self, token): 114 | return not self.check_remnant_chord() and token == TOKEN_OFFSET.BAR.value 115 | 116 | def teach_first_position(self) -> None: 117 | """ 118 | teach 1/128 position right after a bar token 119 | """ 120 | self.next_tokens_forced.append(int(TOKEN_OFFSET.POSITION.value)) 121 | 122 | def teach_chord_token(self): 123 | next_chord_tokens = self.chord_token.pop(0) 124 | self.next_tokens_forced.append(next_chord_tokens) 125 | self.chord_position.pop(0) 126 | self.inter_chord_flags.pop(0) 127 | 
self.wrong_tokens = [] 128 | 129 | def teach_chord_position(self): 130 | next_position_token = self.chord_position[0] 131 | self.next_tokens_forced.append(next_position_token) 132 | self.wrong_tokens = [] 133 | 134 | def teach_wrong_chord_token(self, wrong_token): 135 | self.no_sequence_appended = True 136 | self.wrong_tokens.append(wrong_token) 137 | 138 | def teach_remnant_chord(self): 139 | token = self.chord_position[0] if self.inter_chord_flags[0] else TOKEN_OFFSET.BAR.value 140 | self.next_tokens_forced.append(token) 141 | 142 | def teach_eos(self): 143 | token = TOKEN_OFFSET.EOS.value 144 | self.next_tokens_forced.append(token) 145 | 146 | def validate_teacher_forced_sequence(self, seq) -> None: 147 | def _count_num_chord(seq): 148 | chord_counter = 0 149 | for token in seq: 150 | if TOKEN_OFFSET.CHORD_START.value <= token <= TOKEN_OFFSET.CHORD_END.value: 151 | chord_counter += 1 152 | return chord_counter 153 | 154 | num_bars = seq.count(TOKEN_OFFSET.BAR.value) 155 | num_chord = _count_num_chord(seq) 156 | 157 | if len(self.chord_token) != 0: 158 | raise Exception( 159 | f"remnant chord length: {len(self.chord_token)} \n" "error in teacher forcing" 160 | ) 161 | elif num_bars != int(math.ceil(self.input_data.num_measures)): 162 | raise Exception(f"bar length: {num_bars} \n" "error in bar length") 163 | elif num_chord != self.chord_length: 164 | raise Exception( 165 | f"num_chord: {num_chord} vs {self.chord_length} \n" "error in chord length" 166 | ) 167 | else: 168 | logger.info(f"correct_length: {num_bars}") 169 | logger.info(seq) 170 | 171 | 172 | class InferenceTask: 173 | def __init__(self, device: torch.device): 174 | self.device = device 175 | 176 | def __call__( 177 | self, 178 | model: MemTransformerLM, 179 | input_data: TransXlInputData, 180 | inference_cfg: yacs.config.CfgNode, 181 | ): 182 | self.model = model 183 | self.input_data = input_data 184 | self.inference_cfg = inference_cfg 185 | 186 | def init_seq_and_mems( 187 | self, encoded_meta: List[int], num_conditional_tokens: int 188 | ) -> Tuple[List[int], torch.Tensor]: 189 | 190 | seq = [0] 191 | ctx = np.array(seq + encoded_meta[: num_conditional_tokens - 1], dtype=np.int32)[ 192 | :, np.newaxis 193 | ] 194 | context = torch.from_numpy(ctx).to(self.device).type(torch.long) 195 | _, init_mems = self.model.forward_generate(context, mems=None) 196 | init_seq = seq + encoded_meta[:num_conditional_tokens] 197 | return init_seq, init_mems 198 | 199 | def calc_logits_and_mems( 200 | self, seq: List[int], mems: torch.Tensor 201 | ) -> Tuple[torch.Tensor, torch.Tensor]: 202 | inp = np.array([seq[-1]], dtype=np.int32)[:, np.newaxis] 203 | input_token = torch.from_numpy(inp).to(self.device).type(torch.long) 204 | ret = self.model.forward_generate(input_token, mems) 205 | all_logits, mems = ret 206 | logits = all_logits[-1, 0][1:] 207 | return logits, mems 208 | 209 | def calc_probs(self, logits): 210 | # Handle temp 0 (argmax) case 211 | if self.input_data.temperature == 0: 212 | probs = torch.zeros_like(logits) 213 | probs[logits.argmax()] = 1.0 214 | else: 215 | # Apply temperature spec 216 | logits /= self.input_data.temperature 217 | # Compute softmax 218 | probs = F.softmax(logits, dim=-1) 219 | 220 | probs = F.pad(probs, [1, 0]) 221 | return probs 222 | 223 | def apply_sampling(self, probs, wrong_tokens): 224 | _, top_idx = torch.topk(probs, self.input_data.top_k) 225 | mask = torch.zeros_like(probs) 226 | mask[top_idx] = 1.0 227 | if wrong_tokens: 228 | for w in wrong_tokens: 229 | mask[w] = 0.0 230 | probs *= mask 
231 | probs /= torch.sum(probs) 232 | return probs 233 | 234 | def infer_token(self, probs): 235 | token = torch.multinomial(probs, 1) 236 | token = int(token.item()) 237 | return token 238 | 239 | def generate_sequence(self, seq, mems): 240 | logits = None 241 | teacher = TeacherForceTask(self.input_data) 242 | first_loop = True 243 | for _ in range(self.inference_cfg.GENERATION.generation_length): 244 | if seq[-1] == 1: 245 | break 246 | 247 | if teacher.next_tokens_forced: 248 | next_token = teacher.next_tokens_forced.pop(0) 249 | seq.append(next_token) 250 | logits, mems = self.calc_logits_and_mems(seq, mems) 251 | continue 252 | 253 | if teacher.no_sequence_appended: 254 | assert logits is not None 255 | teacher.no_sequence_appended = False 256 | elif first_loop: 257 | logits, _ = self.calc_logits_and_mems(seq, mems) 258 | first_loop = False 259 | else: 260 | logits, mems = self.calc_logits_and_mems(seq, mems) 261 | 262 | probs = self.calc_probs(logits) 263 | probs = self.apply_sampling(probs, teacher.wrong_tokens) 264 | 265 | # teacher forcing 266 | # in case with incomplete measure, trigger a flag after second bar token 267 | if not teacher.incomplete_filled: 268 | teacher.incomplete_filled = True if seq.count(TOKEN_OFFSET.BAR.value) > 1 else False 269 | 270 | # forcefully assign position 1/128 right after bar token 271 | if teacher.check_first_position(seq): 272 | teacher.teach_first_position() 273 | continue 274 | 275 | # in case there is one chord per bar 276 | if teacher.check_one_chord_per_bar_case(seq): 277 | teacher.teach_chord_token() 278 | continue 279 | 280 | # in case the chord changes within a bar 281 | if teacher.check_mul_chord_per_bar_case(seq): 282 | teacher.teach_chord_token() 283 | continue 284 | 285 | # teacher forcing followed by token inference so that we can check if the wrong token was generated 286 | try: 287 | token = self.infer_token(probs) 288 | except RuntimeError as e: 289 | logger.error(f"Sampling Error: {e}") 290 | seq = None 291 | break 292 | 293 | # generated token skipped necessary position 294 | if teacher.check_chord_position_passed(token): 295 | teacher.teach_chord_position() 296 | continue 297 | 298 | # wrong chord token generated 299 | if teacher.check_wrong_chord_token_generated(token): 300 | teacher.teach_wrong_chord_token(token) 301 | continue 302 | 303 | # eos generated but we got more chords to write 304 | if teacher.check_wrong_eos_generated(token): 305 | teacher.teach_remnant_chord() 306 | continue 307 | 308 | # bar token generated but num measures exceed 309 | if teacher.check_wrong_bar_token_generated(token): 310 | teacher.teach_eos() 311 | continue 312 | 313 | seq.append(token) 314 | 315 | try: 316 | teacher.validate_teacher_forced_sequence(seq) 317 | except Exception as error_message: 318 | logger.error(error_message) 319 | seq = None 320 | return seq 321 | 322 | def validate_generated_sequence(self, seq: List[int]) -> bool: 323 | num_note = 0 324 | for idx, token in enumerate(seq): 325 | if idx + 2 > len(seq) - 1: 326 | break 327 | if token in range(TOKEN_OFFSET.NOTE_VELOCITY.value, TOKEN_OFFSET.CHORD_START.value): 328 | if ( 329 | seq[idx - 1] in range(TOKEN_OFFSET.POSITION.value, TOKEN_OFFSET.BPM.value) 330 | and seq[idx + 1] 331 | in range(TOKEN_OFFSET.PITCH.value, TOKEN_OFFSET.NOTE_VELOCITY.value) 332 | and seq[idx + 2] 333 | in range(TOKEN_OFFSET.NOTE_DURATION.value, TOKEN_OFFSET.POSITION.value) 334 | ): 335 | num_note += 1 336 | return num_note > 0 337 | 338 | def execute(self, encoded_meta) -> List[List[int]]: 339 | 
num_conditional_tokens = len(encoded_meta) 340 | idx = 0 341 | sequences = [] 342 | while idx != self.input_data.num_generate: 343 | with torch.no_grad(): 344 | logger.info("Generating the idx: " + str(idx + 1)) 345 | seq, mems = self.init_seq_and_mems(encoded_meta, num_conditional_tokens) 346 | seq = self.generate_sequence(seq, mems) 347 | if seq is None: 348 | continue 349 | if not self.validate_generated_sequence(seq): 350 | logger.error("Empty sequence generated") 351 | continue 352 | sequences.append(seq) 353 | idx += 1 354 | return sequences 355 | -------------------------------------------------------------------------------- /commu/midi_generator/model_initializer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Tuple 3 | 4 | import torch 5 | import yacs.config 6 | 7 | from commu.midi_generator.container import ModelArguments 8 | from commu.model.config_helper import get_default_cfg_inference, get_default_cfg_training 9 | from commu.model.dataset import BaseVocab 10 | from commu.model.model import MemTransformerLM 11 | 12 | 13 | class ModelInitializeTask: 14 | def __init__(self, model_args: ModelArguments, map_location: str, device: torch.device): 15 | self.model_args = model_args 16 | self.map_location = map_location 17 | self.device = device 18 | self.inference_cfg = self.initialize_inference_config() 19 | 20 | def initialize_inference_config(self) -> yacs.config.CfgNode: 21 | inference_cfg = get_default_cfg_inference() 22 | inference_cfg.freeze() 23 | return inference_cfg 24 | 25 | def load_checkpoint_fp(self) -> Tuple[Path, Path]: 26 | checkpoint_dir = self.model_args.checkpoint_dir 27 | if checkpoint_dir: 28 | model_fp = Path(checkpoint_dir) 29 | training_cfg_fp = model_fp.parent / "config.yml" 30 | else: 31 | model_parent = Path(self.inference_cfg.MODEL.model_directory) 32 | model_fp = model_parent / self.inference_cfg.MODEL.checkpoint_name 33 | training_cfg_fp = model_parent / "config.yml" 34 | return model_fp, training_cfg_fp 35 | 36 | def initialize_training_cfg(self) -> yacs.config.CfgNode: 37 | cfg = get_default_cfg_training() 38 | cfg.defrost() 39 | cfg.MODEL.same_length = True # Needed for same_length =True during evaluation 40 | cfg.freeze() 41 | return cfg 42 | 43 | def initialize_model(self, training_cfg, model_fp): 44 | perform_vocab = BaseVocab() 45 | model = MemTransformerLM(training_cfg, perform_vocab) 46 | checkpoint = torch.load(model_fp) 47 | model.load_state_dict(checkpoint["model"], strict=False) 48 | model = model.to(self.device) 49 | model.eval() 50 | model.reset_length(1, self.inference_cfg.MODEL.memory_length) 51 | return model 52 | 53 | def execute(self): 54 | model_fp, training_cfg_fp = self.load_checkpoint_fp() 55 | training_cfg = self.initialize_training_cfg() 56 | model = self.initialize_model(training_cfg, model_fp) 57 | return model -------------------------------------------------------------------------------- /commu/midi_generator/sequence_postprocessor.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | 4 | from miditoolkit import MidiFile 5 | 6 | from commu.midi_generator.container import TransXlInputData 7 | from commu.preprocessor.encoder import EventSequenceEncoder 8 | from commu.preprocessor.utils.container import MidiInfo 9 | 10 | 11 | class PostprocessTask: 12 | def __init__(self): 13 | pass 14 | 15 | def __call__(self, input_data: TransXlInputData): 16 | 
self.input_data = input_data 17 | 18 | def get_output_dir(self) -> Path: 19 | return self.input_data.output_dir 20 | 21 | def set_output_file_path(self, index: int) -> Path: 22 | track_role = self.input_data.track_role 23 | inst = self.input_data.inst 24 | pitch_range = self.input_data.pitch_range 25 | 26 | output_dir = Path(self.input_data.output_dir).joinpath( 27 | f"{track_role}_{inst}_{pitch_range}") 28 | output_dir.mkdir(exist_ok=True, parents=True) 29 | 30 | file_name = f"{track_role}_{inst}_{pitch_range}_{index:03d}.mid" 31 | 32 | return output_dir.joinpath(file_name) 33 | 34 | def decode_event_sequence( 35 | self, 36 | generation_result: List[int], 37 | num_meta: int = 11 38 | ) -> MidiFile: 39 | encoded_meta = generation_result[1: num_meta + 1] 40 | event_sequence = generation_result[num_meta + 2:] 41 | decoder = EventSequenceEncoder() 42 | decoded_midi = decoder.decode( 43 | midi_info=MidiInfo(*encoded_meta, event_seq=event_sequence), 44 | ) 45 | 46 | return decoded_midi 47 | 48 | def execute(self, sequences: List[List[int]]) -> Path: 49 | for idx, seq in enumerate(sequences): 50 | decoded_midi = self.decode_event_sequence( 51 | generation_result=seq, 52 | ) 53 | output_file_path = self.set_output_file_path(idx) 54 | decoded_midi.dump(output_file_path) 55 | 56 | return self.get_output_dir() 57 | -------------------------------------------------------------------------------- /commu/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POZAlabs/ComMU-code/3949a5b5a1a54e2bb0fb9d600ecc00cd55660408/commu/model/__init__.py -------------------------------------------------------------------------------- /commu/model/config_helper.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | 4 | def model(cfg): 5 | # For model 6 | cfg.MODEL = CN() 7 | cfg.MODEL.num_layers = 6 8 | cfg.MODEL.num_heads = 10 9 | cfg.MODEL.units = 500 10 | cfg.MODEL.inner_size = 1000 11 | cfg.MODEL.dropout = 0.1 12 | cfg.MODEL.attention_dropout = 0.1 13 | cfg.MODEL.clamp_len = -1 14 | cfg.MODEL.same_length = False 15 | return cfg 16 | 17 | 18 | def train(cfg): 19 | # For training 20 | cfg.TRAIN = CN() 21 | cfg.TRAIN.batch_size = 256 22 | cfg.TRAIN.batch_chunk = 4 23 | cfg.TRAIN.tgt_length = 128 24 | cfg.TRAIN.mem_length = 1024 25 | cfg.TRAIN.seed = 1111 26 | cfg.TRAIN.lr = 0.004 27 | cfg.TRAIN.lr_min = 0.0001 28 | cfg.TRAIN.warmup_step = 100 29 | cfg.TRAIN.clip = 1.0 30 | cfg.TRAIN.max_step = 20000 31 | cfg.TRAIN.log_interval = 100 32 | cfg.TRAIN.eval_interval = 1000 33 | cfg.TRAIN.weight_decay = 0.0 34 | return cfg 35 | 36 | 37 | def init(cfg): 38 | # For initialization 39 | cfg.INITIALIZER = CN() 40 | cfg.INITIALIZER.base_init = 0.01 41 | cfg.INITIALIZER.embed_init = 0.01 42 | 43 | # For evaluation 44 | cfg.EVALUATE = CN() 45 | cfg.EVALUATE.batch_size = 10 46 | cfg.EVALUATE.tgt_length = 128 47 | cfg.EVALUATE.mem_length = 2048 48 | 49 | return cfg 50 | 51 | 52 | def get_default_cfg_training(): 53 | cfg = CN() 54 | cfg = init(cfg) 55 | cfg = model(cfg) 56 | cfg = train(cfg) 57 | cfg.freeze() 58 | return cfg 59 | 60 | 61 | def get_default_cfg_inference(): 62 | """Get a yacs CfgNode object with default values.""" 63 | cfg = CN() 64 | 65 | # # Model related parameters 66 | cfg.MODEL = CN() 67 | cfg.MODEL.memory_length = 4146 68 | cfg.MODEL.device = "gpu" 69 | # Sampling related parameters 70 | cfg.SAMPLING = CN() 71 | cfg.SAMPLING.threshold = 32.0 72 | 
cfg.SAMPLING.temperature = 0.95 73 | 74 | # Model related parameters 75 | cfg.GENERATION = CN() 76 | cfg.GENERATION.generation_length = 4096 77 | 78 | 79 | cfg.freeze() 80 | return cfg 81 | -------------------------------------------------------------------------------- /commu/model/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from commu.preprocessor.encoder.event_tokens import TOKEN_OFFSET 4 | 5 | 6 | class BaseVocab: 7 | def __init__(self): 8 | self.vec_len = 0 9 | 10 | @property 11 | def pad_id(self): 12 | return 0 13 | 14 | def __len__(self): 15 | return TOKEN_OFFSET.VOCAB_SIZE.value 16 | 17 | 18 | class ComMUDataset: 19 | def __init__(self, data_dir, cfg): 20 | """Load the music corpus 21 | Args: 22 | data_dir: The base folder of the preprocessed music dataset 23 | """ 24 | self._vocab = BaseVocab() 25 | 26 | self._train_data = self.load_cache_data(data_dir, "train") 27 | self._valid_data = self.load_cache_data(data_dir, "valid") 28 | self._test_data = self.load_cache_data(data_dir, "test") 29 | self.cfg = cfg 30 | 31 | # Insert start tokens 32 | print("USING PAD TOKEN AS START!") 33 | insert_token = self._vocab.pad_id # pad as a start token 34 | self._train_data = [ 35 | torch.from_numpy(np.insert(arr, 0, insert_token)) 36 | for arr in self._train_data 37 | ] 38 | self._valid_data = [ 39 | torch.from_numpy(np.insert(arr, 0, insert_token)) 40 | for arr in self._valid_data 41 | ] 42 | self._test_data = [ 43 | torch.from_numpy(np.insert(arr, 0, insert_token)) 44 | for arr in self._test_data 45 | ] 46 | 47 | self._train_seq_length = np.array( 48 | [ele.shape[0] for ele in self._train_data], dtype=np.int32 49 | ) 50 | self._valid_seq_length = np.array( 51 | [ele.shape[0] for ele in self._valid_data], dtype=np.int32 52 | ) 53 | self._test_seq_length = np.array( 54 | [ele.shape[0] for ele in self._test_data], dtype=np.int32 55 | ) 56 | print( 57 | "Loaded Data, #Samples Train/Val/Test:{}/{}/{}".format( 58 | len(self._train_data), len(self._valid_data), len(self._test_data) 59 | ) 60 | ) 61 | print( 62 | " #Avg Length:{}/{}/{}".format( 63 | np.mean([len(ele) for ele in self._train_data]), 64 | np.mean([len(ele) for ele in self._valid_data]), 65 | np.mean([len(ele) for ele in self._test_data]), 66 | ) 67 | ) 68 | print( 69 | " #Total Number of Valid/Test Tokens: {}/{}".format( 70 | (self._valid_seq_length - 1).sum(), (self._test_seq_length - 1).sum() 71 | ) 72 | ) 73 | 74 | def load_cache_data(self, dir_name, mode): 75 | if mode == "train": 76 | data_input = np.load(dir_name + '/input_train.npy', allow_pickle=True) 77 | data_target = np.load(dir_name + '/target_train.npy', allow_pickle=True) 78 | dat = [] 79 | for i in range(len(data_input)): 80 | dat.append(np.concatenate((np.array(data_input[i], dtype=int), data_target[i]))) 81 | else: 82 | data_input = np.load(dir_name + '/input_val.npy', allow_pickle=True) 83 | data_target = np.load(dir_name + '/target_val.npy', allow_pickle=True) 84 | dat = [] 85 | for i in range(len(data_input)): 86 | dat.append(np.concatenate((np.array(data_input[i], dtype=int), data_target[i]))) 87 | return np.array(dat, dtype=object) 88 | 89 | @property 90 | def vocab(self): 91 | return self._vocab 92 | 93 | @property 94 | def train_data(self): 95 | return self._train_data 96 | 97 | @property 98 | def valid_data(self): 99 | return self._valid_data 100 | 101 | @property 102 | def test_data(self): 103 | return self._test_data 104 | 105 | @property 106 | def 
train_seq_length(self): 107 | return self._train_seq_length 108 | 109 | @property 110 | def valid_seq_length(self): 111 | return self._valid_seq_length 112 | 113 | @property 114 | def test_seq_length(self): 115 | return self._test_seq_length 116 | 117 | def get_iterator( 118 | self, batch_size, bptt, device, split="train", do_shuffle=True, seed=None 119 | ): 120 | if split == "train": 121 | split_data = self.train_data 122 | split_seq_lengths = self.train_seq_length 123 | elif split == "valid": 124 | split_data = self.valid_data 125 | split_seq_lengths = self.valid_seq_length 126 | elif split == "test": 127 | split_data = self.test_data 128 | split_seq_lengths = self.test_seq_length 129 | else: 130 | raise NotImplementedError 131 | total_sample_num = len(split_data) 132 | 133 | def iterator(): 134 | perm = np.arange(total_sample_num) 135 | if do_shuffle: 136 | rng = np.random.RandomState(seed) 137 | rng.shuffle(perm) 138 | assert batch_size < total_sample_num 139 | tracker_list = [(i, 0) for i in range(batch_size)] 140 | next_idx = batch_size 141 | data = torch.LongTensor(bptt, batch_size) 142 | target = torch.LongTensor(bptt, batch_size) 143 | reset_mem = torch.BoolTensor(batch_size) 144 | 145 | while True: 146 | # Generate the samples 147 | # Fill with pad_id 148 | data[:] = self.vocab.pad_id 149 | target[:] = self.vocab.pad_id 150 | reset_mem[:] = False 151 | batch_token_num = 0 152 | for i in range(batch_size): 153 | idx, pos = tracker_list[i] 154 | while idx < total_sample_num: 155 | seq_id = perm[idx] 156 | seq_length = split_seq_lengths[seq_id] 157 | if pos + 1 >= seq_length: 158 | idx, pos = next_idx, 0 159 | tracker_list[i] = (idx, pos) 160 | next_idx += 1 161 | reset_mem[i] = True 162 | continue 163 | else: 164 | n_new = min(seq_length - 1 - pos, bptt) 165 | data[:n_new, i] = split_data[seq_id][pos: pos + n_new] 166 | target[:n_new, i] = split_data[seq_id][ 167 | (pos + 1): (pos + 1 + n_new)] 168 | batch_token_num += n_new 169 | tracker_list[i] = (idx, pos + n_new) 170 | break 171 | if batch_token_num == 0: 172 | # Haven't found anything to fill. 
This indicates we have reached the end 173 | if do_shuffle: 174 | rng.shuffle(perm) 175 | else: 176 | return # One pass dataloader when do_shuffle is False 177 | tracker_list = [(i, 0) for i in range(batch_size)] 178 | next_idx = batch_size 179 | continue 180 | 181 | yield data.to(device), target.to(device), reset_mem.to(device), batch_token_num 182 | 183 | return iterator 184 | 185 | def eval_iterator( 186 | self, batch_size, bptt, device, split="valid", local_rank=0, world_size=0 187 | ): 188 | if split == "valid": 189 | split_data = self.valid_data 190 | split_seq_lengths = self.valid_seq_length 191 | elif split == "test": 192 | split_data = self.test_data 193 | split_seq_lengths = self.test_seq_length 194 | else: 195 | raise NotImplementedError 196 | if world_size > 0: 197 | all_sample_num = len(split_data) 198 | if local_rank == world_size - 1: 199 | begin_idx = all_sample_num // world_size * local_rank 200 | end_idx = all_sample_num 201 | else: 202 | begin_idx = all_sample_num // world_size * local_rank 203 | end_idx = all_sample_num // world_size * (local_rank + 1) 204 | split_data = split_data[begin_idx:end_idx] 205 | split_seq_lengths = split_seq_lengths[begin_idx:end_idx] 206 | total_sample_num = len(split_data) 207 | 208 | def iterator(): 209 | data = torch.LongTensor(bptt, batch_size) 210 | target = torch.LongTensor(bptt, batch_size) 211 | for batch_begin in range(0, total_sample_num, batch_size): 212 | reset_all_mem = True 213 | batch_end = min(batch_begin + batch_size, total_sample_num) 214 | max_seq_length = max(split_seq_lengths[batch_begin:batch_end]) 215 | for seq_begin in range(0, max_seq_length - 1, bptt): 216 | data[:] = self.vocab.pad_id 217 | target[:] = self.vocab.pad_id 218 | batch_token_num = 0 219 | for i in range(batch_begin, batch_end): 220 | if split_seq_lengths[i] > seq_begin + 1: 221 | n_new = ( 222 | min(seq_begin + bptt, split_seq_lengths[i] - 1) 223 | - seq_begin 224 | ) 225 | data[:n_new, i - batch_begin] = split_data[i][ 226 | seq_begin: seq_begin + n_new 227 | ] 228 | target[:n_new, i - batch_begin] = split_data[i][ 229 | (seq_begin + 1): (seq_begin + n_new + 1) 230 | ] 231 | batch_token_num += n_new 232 | 233 | yield data.to(device), target.to(device), reset_all_mem, batch_token_num 234 | 235 | reset_all_mem = False 236 | 237 | return iterator 238 | -------------------------------------------------------------------------------- /commu/model/exp_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import logging 3 | import os 4 | from typing import Optional 5 | 6 | 7 | def logging_config(folder: Optional[str] = None, 8 | name: Optional[str] = None, 9 | level: int = logging.INFO, 10 | console_level: int = logging.INFO, 11 | console: bool = True) -> str: 12 | """Config the logging module""" 13 | if name is None: 14 | name = inspect.stack()[1][1].split('.')[0] 15 | if folder is None: 16 | folder = os.path.join(os.getcwd(), name) 17 | if not os.path.exists(folder): 18 | os.makedirs(folder, exist_ok=True) 19 | # Remove all the current handlers 20 | for handler in logging.root.handlers: 21 | logging.root.removeHandler(handler) 22 | logging.root.handlers = [] 23 | logpath = os.path.join(folder, name + ".log") 24 | print("All Logs will be saved to {}".format(logpath)) 25 | logging.root.setLevel(level) 26 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 27 | logfile = logging.FileHandler(logpath) 28 | logfile.setLevel(level) 29 | logfile.setFormatter(formatter) 
30 | logging.root.addHandler(logfile) 31 | if console: 32 | # Initialze the console logging 33 | logconsole = logging.StreamHandler() 34 | logconsole.setLevel(console_level) 35 | logconsole.setFormatter(formatter) 36 | logging.root.addHandler(logconsole) 37 | return folder 38 | -------------------------------------------------------------------------------- /commu/model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ProjectedAdaptiveLogSoftmax(nn.Module): 7 | def __init__(self, n_token, d_embed, d_proj, cutoffs=None, keep_order=False): 8 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 9 | 10 | if cutoffs is None: 11 | cutoffs = [] 12 | 13 | self.n_token = n_token 14 | self.d_embed = d_embed 15 | self.d_proj = d_proj 16 | 17 | self.cutoffs = cutoffs + [n_token] 18 | self.cutoff_ends = [0] + self.cutoffs 19 | 20 | self.shortlist_size = self.cutoffs[0] 21 | self.n_clusters = len(self.cutoffs) - 1 22 | self.head_size = self.shortlist_size + self.n_clusters 23 | 24 | if self.n_clusters > 0: 25 | self.cluster_weight = nn.Parameter( 26 | torch.zeros(self.n_clusters, self.d_embed) 27 | ) 28 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 29 | 30 | self.out_layers = nn.ModuleList() 31 | self.out_projs = nn.ParameterList() 32 | 33 | 34 | for i in range(len(self.cutoffs)): 35 | if d_proj != d_embed: 36 | self.out_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed))) 37 | else: 38 | self.out_projs.append(None) 39 | 40 | self.out_layers.append(nn.Linear(d_embed, n_token)) 41 | 42 | self.keep_order = keep_order 43 | 44 | def _compute_logit(self, hidden, weight, bias, proj): 45 | if proj is None: 46 | logit = F.linear(hidden, weight, bias=bias) 47 | else: 48 | proj_hid = F.linear(hidden, proj.t().contiguous()) 49 | logit = F.linear(proj_hid, weight, bias=bias) 50 | 51 | return logit 52 | 53 | def forward(self, hidden, target, keep_order=False): 54 | """ 55 | hidden :: [len*bsz x d_proj] 56 | target :: [len*bsz] 57 | """ 58 | 59 | if hidden.size(0) != target.size(0): 60 | raise RuntimeError( 61 | "Input and target should have the same size " "in the batch dimension." 
62 | ) 63 | 64 | if self.n_clusters == 0: 65 | logit = self._compute_logit( 66 | hidden, 67 | self.out_layers[0].weight, 68 | self.out_layers[0].bias, 69 | self.out_projs[0], 70 | ) 71 | nll = ( 72 | -F.log_softmax(logit, dim=-1).gather(1, target.unsqueeze(1)).squeeze(1) 73 | ) 74 | else: 75 | # construct weights and biases 76 | weights, biases = [], [] 77 | for i in range(len(self.cutoffs)): 78 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 79 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 80 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 81 | 82 | if i == 0: 83 | weight_i = torch.cat([weight_i, self.cluster_weight], dim=0) 84 | bias_i = torch.cat([bias_i, self.cluster_bias], dim=0) 85 | 86 | weights.append(weight_i) 87 | biases.append(bias_i) 88 | 89 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 90 | 91 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 92 | head_logprob = F.log_softmax(head_logit, dim=1) 93 | 94 | nll = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device) 95 | 96 | offset = 0 97 | cutoff_values = [0] + self.cutoffs 98 | for i in range(len(cutoff_values) - 1): 99 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 100 | 101 | mask_i = (target >= l_idx) & (target < r_idx) 102 | indices_i = mask_i.nonzero().squeeze() 103 | 104 | if indices_i.numel() == 0: 105 | continue 106 | 107 | target_i = target.index_select(0, indices_i) - l_idx 108 | head_logprob_i = head_logprob.index_select(0, indices_i) 109 | 110 | if i == 0: 111 | logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) 112 | else: 113 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 114 | 115 | hidden_i = hidden.index_select(0, indices_i) 116 | 117 | tail_logit_i = self._compute_logit( 118 | hidden_i, weight_i, bias_i, proj_i 119 | ) 120 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 121 | 122 | logprob_i = head_logprob_i[:, -i] + tail_logprob_i.gather( 123 | 1, target_i[:, None] 124 | ).squeeze(1) 125 | 126 | if (hasattr(self, "keep_order") and self.keep_order) or keep_order: 127 | nll.index_copy_(0, indices_i, -logprob_i) 128 | else: 129 | nll[offset: offset + logprob_i.size(0)].copy_(-logprob_i) 130 | 131 | offset += logprob_i.size(0) 132 | 133 | return nll 134 | 135 | 136 | class PositionalEmbedding(nn.Module): 137 | def __init__(self, demb): 138 | super(PositionalEmbedding, self).__init__() 139 | 140 | self.demb = demb 141 | 142 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 143 | self.register_buffer("inv_freq", inv_freq) 144 | 145 | def forward(self, pos_seq, bsz=None): 146 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) # Outer product 147 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 148 | 149 | if bsz is not None: 150 | return pos_emb[:, None, :].expand(-1, bsz, -1) 151 | else: 152 | return pos_emb[:, None, :] 153 | 154 | 155 | class PositionwiseFF(nn.Module): 156 | def __init__(self, d_model, d_inner, dropout): 157 | super(PositionwiseFF, self).__init__() 158 | 159 | self.d_model = d_model 160 | self.d_inner = d_inner 161 | self.dropout = dropout 162 | 163 | self.CoreNet = nn.Sequential( 164 | nn.Linear(d_model, d_inner), 165 | nn.ReLU(inplace=True), 166 | nn.Dropout(dropout), 167 | nn.Linear(d_inner, d_model), 168 | nn.Dropout(dropout), 169 | ) 170 | 171 | self.layer_norm = nn.LayerNorm(d_model) 172 | 173 | 174 | def forward(self, inp): 175 | ##### positionwise feed-forward 176 | core_out = self.CoreNet(inp) 177 | 178 
| ##### residual connection + layer normalization 179 | output = self.layer_norm(inp + core_out) 180 | 181 | return output 182 | 183 | 184 | # Main attention class with all lengths 185 | class RelMultiHeadAttn(nn.Module): 186 | def __init__( 187 | self, 188 | n_head, 189 | d_model, 190 | d_head, 191 | dropout, 192 | dropatt=0, 193 | tgt_len=None, 194 | mem_len=None, 195 | use_qkv=True, 196 | ): 197 | super(RelMultiHeadAttn, self).__init__() 198 | 199 | self.n_head = n_head 200 | self.d_model = d_model 201 | self.d_head = d_head 202 | self.dropout = dropout 203 | 204 | if use_qkv: 205 | self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False) 206 | else: 207 | self.q_net = nn.Linear(d_model, n_head * d_head, bias=False) 208 | self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False) # Split into k and v later 209 | 210 | self.drop = nn.Dropout(dropout) 211 | self.dropatt = nn.Dropout(dropatt) 212 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) 213 | 214 | self.layer_norm = nn.LayerNorm(d_model) 215 | 216 | self.scale = 1 / (d_head ** 0.5) 217 | 218 | def _parallelogram_mask(self, h, w, left=False): 219 | mask = torch.ones((h, w)).byte() 220 | m = min(h, w) 221 | mask[:m, :m] = torch.triu(mask[:m, :m]) 222 | mask[-m:, -m:] = torch.tril(mask[-m:, -m:]) 223 | 224 | if left: 225 | return mask 226 | else: 227 | return mask.flip(0) 228 | 229 | def _shift(self, x, qlen, klen, mask, left=False): 230 | if qlen > 1: 231 | zero_pad = torch.zeros( 232 | (x.size(0), qlen - 1, x.size(2), x.size(3)), 233 | device=x.device, 234 | dtype=x.dtype, 235 | ) 236 | else: 237 | zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype) 238 | 239 | if left: 240 | mask = mask.flip(1) 241 | x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1) 242 | else: 243 | x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1) 244 | 245 | x = x_padded.masked_select(mask[:, :, None, None]).view( 246 | qlen, klen, x.size(2), x.size(3) 247 | ) 248 | 249 | return x 250 | 251 | def _rel_shift(self, x, zero_triu=False): 252 | zero_pad = torch.zeros( 253 | (x.size(0), x.size(1), x.size(2), 1), device=x.device, dtype=x.dtype 254 | ) 255 | x_padded = torch.cat([zero_pad, x], dim=3) 256 | 257 | x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) 258 | 259 | x = x_padded[:, :, 1:].view_as(x) 260 | 261 | if zero_triu: 262 | ones = torch.ones((x.size(2), x.size(3))) 263 | x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] 264 | 265 | return x 266 | 267 | def forward(self, w, r, attn_mask=None, mems=None): 268 | raise NotImplementedError 269 | 270 | 271 | # Default attention layer used 272 | class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): 273 | def __init__(self, *args, **kwargs): 274 | super(RelPartialLearnableMultiHeadAttn, self).__init__( 275 | *args, **kwargs 276 | ) # ext_len passed here 277 | 278 | self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False) 279 | 280 | def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None): 281 | qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) 282 | 283 | if mems is not None: 284 | cat = torch.cat([mems, w], 0) 285 | w_heads = self.qkv_net(cat) 286 | r_head_k = self.r_net(r) 287 | 288 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 289 | w_head_q = w_head_q[-qlen:] 290 | else: 291 | w_heads = self.qkv_net(w) 292 | r_head_k = self.r_net(r) 293 | 294 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 295 | 296 | klen = w_head_k.size(0) 
297 | 298 | w_head_q = w_head_q.view( 299 | qlen, bsz, self.n_head, self.d_head 300 | ) # qlen x bsz x n_head x d_head 301 | w_head_k = w_head_k.view( 302 | klen, bsz, self.n_head, self.d_head 303 | ) # qlen x bsz x n_head x d_head 304 | w_head_v = w_head_v.view( 305 | klen, bsz, self.n_head, self.d_head 306 | ) # qlen x bsz x n_head x d_head 307 | 308 | r_head_k = r_head_k.view( 309 | rlen, self.n_head, self.d_head 310 | ) # qlen x n_head x d_head 311 | 312 | #### compute attention score 313 | rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head 314 | AC = torch.einsum( 315 | "ibnd,jbnd->bnij", (rw_head_q, w_head_k) 316 | ) # qlen x klen x bsz x n_head 317 | 318 | rr_head_q = w_head_q + r_r_bias 319 | BD = torch.einsum( 320 | "ibnd,jnd->bnij", (rr_head_q, r_head_k) 321 | ) # qlen x klen x bsz x n_head 322 | BD = self._rel_shift(BD) 323 | 324 | # [bsz x n_head x qlen x klen] 325 | attn_score = AC + BD 326 | attn_score.mul_(self.scale) 327 | 328 | #### compute attention probability 329 | if attn_mask is not None: 330 | if attn_mask.dim() == 2: 331 | attn_score.masked_fill_(attn_mask[None, None, :, :], -float("inf")) 332 | elif attn_mask.dim() == 3: 333 | attn_score.masked_fill_(attn_mask[:, None, :, :], -float("inf")) 334 | 335 | # [bsz x n_head x qlen x klen] 336 | attn_prob = F.softmax(attn_score, dim=3) 337 | attn_prob = self.dropatt(attn_prob) 338 | 339 | #### compute attention vector 340 | attn_vec = torch.einsum("bnij,jbnd->ibnd", (attn_prob, w_head_v)) 341 | 342 | # [qlen x bsz x n_head x d_head] 343 | attn_vec = attn_vec.contiguous().view( 344 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head 345 | ) 346 | 347 | ##### linear projection 348 | attn_out = self.o_net(attn_vec) 349 | attn_out = self.drop(attn_out) 350 | 351 | ##### residual connection + layer normalization 352 | output = self.layer_norm(w + attn_out) 353 | 354 | return output 355 | 356 | 357 | # Default attention layer used 358 | class RelPartialLearnableDecoderLayer(nn.Module): 359 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): 360 | super(RelPartialLearnableDecoderLayer, self).__init__() 361 | 362 | self.dec_attn = RelPartialLearnableMultiHeadAttn( 363 | n_head, d_model, d_head, dropout, **kwargs 364 | ) 365 | 366 | self.pos_ff = PositionwiseFF( 367 | d_model, d_inner, dropout 368 | ) 369 | 370 | def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None): 371 | output = self.dec_attn( 372 | dec_inp, r, r_w_bias, r_r_bias, attn_mask=dec_attn_mask, mems=mems 373 | ) 374 | 375 | output = self.pos_ff(output) 376 | 377 | return output 378 | 379 | 380 | class AdaptiveEmbedding(nn.Module): 381 | def __init__(self, n_token, d_embed, d_proj): 382 | """ 383 | 384 | :param n_token: number of tokens in vocab 385 | :param d_embed: dimension of embedding 386 | :param d_proj: dimension of embedding projection (unused here since d_proj = d_model) 387 | """ 388 | 389 | super(AdaptiveEmbedding, self).__init__() 390 | 391 | self.n_token = n_token 392 | self.d_embed = d_embed 393 | 394 | self.cutoffs = [n_token] 395 | self.d_proj = d_proj 396 | 397 | self.emb_scale = d_proj ** 0.5 398 | 399 | self.emb_layers = nn.ModuleList() 400 | self.emb_projs = nn.ParameterList() 401 | 402 | self.emb_layers.append( 403 | nn.Embedding(n_token, d_embed, sparse=False) 404 | ) 405 | 406 | if d_proj != d_embed: 407 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed))) 408 | 409 | def forward(self, inp): 410 | if len(inp.shape) == 2: 411 | embed = 
self.emb_layers[0](inp) 412 | else: 413 | embed = torch.matmul(inp, self.emb_layers[0].weight) 414 | 415 | if self.d_proj != self.d_embed: 416 | embed = F.linear(embed, self.emb_projs[0]) 417 | 418 | embed.mul_(self.emb_scale) 419 | 420 | return embed 421 | 422 | 423 | class MemTransformerLM(nn.Module): 424 | def __init__( 425 | self, 426 | cfg, 427 | vocab 428 | ): 429 | n_layer = cfg.MODEL.num_layers 430 | n_head = cfg.MODEL.num_heads 431 | d_model = cfg.MODEL.units 432 | d_head = cfg.MODEL.units // cfg.MODEL.num_heads 433 | d_inner = cfg.MODEL.inner_size 434 | dropout = cfg.MODEL.dropout 435 | dropatt = cfg.MODEL.attention_dropout 436 | d_embed = cfg.MODEL.units 437 | tgt_len = cfg.TRAIN.tgt_length 438 | mem_len = cfg.TRAIN.mem_length 439 | same_length = cfg.MODEL.same_length 440 | clamp_len = cfg.MODEL.clamp_len 441 | 442 | super(MemTransformerLM, self).__init__() 443 | self.cfg = cfg 444 | self.n_token = len(vocab) 445 | d_embed = d_model if d_embed is None else d_embed 446 | self.d_embed = d_embed 447 | self.d_model = d_model 448 | self.n_head = n_head 449 | self.d_head = d_head 450 | 451 | self.word_emb = AdaptiveEmbedding( 452 | self.n_token, d_embed, d_model) 453 | 454 | self.drop = nn.Dropout(dropout) 455 | self.n_layer = n_layer 456 | 457 | self.tgt_len = tgt_len 458 | self.mem_len = mem_len 459 | self.max_klen = tgt_len + mem_len 460 | 461 | self.layers = nn.ModuleList() 462 | 463 | for i in range(n_layer): 464 | self.layers.append( 465 | RelPartialLearnableDecoderLayer( 466 | n_head, 467 | d_model, 468 | d_head, 469 | d_inner, 470 | dropout, 471 | tgt_len=tgt_len, 472 | mem_len=mem_len, 473 | dropatt=dropatt, 474 | ) 475 | ) 476 | self.crit = ProjectedAdaptiveLogSoftmax( 477 | self.n_token, d_embed, d_model 478 | ) 479 | 480 | for i in range(len(self.crit.out_layers)): 481 | self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight 482 | 483 | self.same_length = same_length 484 | self.clamp_len = clamp_len 485 | 486 | self.detach_mems_grad = True 487 | self._create_params() 488 | 489 | def _create_params(self): 490 | self.pos_emb = PositionalEmbedding(self.d_model) 491 | self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 492 | self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 493 | 494 | def reset_length(self, tgt_len, mem_len): 495 | self.tgt_len = tgt_len 496 | self.mem_len = mem_len 497 | 498 | def init_mems(self, n_layers): 499 | if self.mem_len > 0: 500 | param = next(self.parameters()) 501 | mems = torch.empty(n_layers + 1, 0, dtype=param.dtype, 502 | device=param.device) 503 | return mems 504 | else: 505 | return None 506 | 507 | def _update_mems(self, hids, mems, qlen, mlen, reset_mems=None): 508 | 509 | # The idea is that randomization from shuffling will have be equivalent to memory resetting 510 | 511 | if mems is None: 512 | return None 513 | 514 | assert len(hids) == len(mems) 515 | # mems is not the same as self.mem_len 516 | 517 | # There are `mlen + qlen` steps that can be cached into mems 518 | # For the next step, the last `ext_len` of the `qlen` tokens 519 | # will be used as the extended context. Hence, we only cache 520 | # the tokens from `mlen + qlen - self.ext_len - self.mem_len` 521 | # to `mlen + qlen - self.ext_len`. 
522 | with torch.no_grad(): 523 | new_mems = [] 524 | end_idx = mlen + max(0, qlen) 525 | beg_idx = max(0, end_idx - self.mem_len) 526 | stacked = torch.stack(hids) 527 | 528 | if mems.numel(): 529 | cat = torch.cat([mems, stacked], dim=1) 530 | else: 531 | cat = stacked 532 | 533 | if self.detach_mems_grad: 534 | new_mems = cat[:, beg_idx:end_idx].detach() 535 | else: 536 | new_mems = cat[:, beg_idx:end_idx].detach() 537 | 538 | return new_mems 539 | 540 | def _forward(self, dec_inp, reset_mems, mems=None): 541 | 542 | qlen, bsz = dec_inp.size()[0], dec_inp.size()[1] 543 | word_emb = self.word_emb(dec_inp) 544 | 545 | mlen = mems[0].size(0) if mems is not None else 0 546 | klen = mlen + qlen 547 | 548 | # Generate the mask between query and all the keys 549 | if self.same_length: 550 | all_ones = word_emb.new_ones(qlen, klen) 551 | mask_len = klen - self.mem_len 552 | if mask_len > 0: 553 | mask_shift_len = qlen - mask_len 554 | else: 555 | mask_shift_len = qlen 556 | 557 | if reset_mems is None: 558 | indices = torch.BoolTensor(dec_inp.shape[1]).fill_(False) 559 | else: 560 | indices = reset_mems 561 | 562 | if self.same_length: 563 | dec_attn_mask = (( 564 | torch.triu(all_ones, 1 + mlen) 565 | + torch.tril(all_ones, -mask_shift_len) 566 | ).bool()[ 567 | :, : 568 | ]).repeat(len(indices), 1, 1) # -1 569 | else: 570 | dec_attn_mask = (torch.triu( 571 | word_emb.new_ones(qlen, klen), diagonal=1 + mlen 572 | ).bool()[:, :]).repeat(len(indices), 1, 1) 573 | 574 | dec_attn_mask[indices, :, :mlen] = 1 575 | 576 | 577 | hids = [] 578 | pos_seq = torch.arange( 579 | klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype 580 | ) 581 | if self.clamp_len > 0: 582 | pos_seq.clamp_(max=self.clamp_len) 583 | pos_emb = self.pos_emb(pos_seq) 584 | 585 | core_out = self.drop(word_emb) 586 | pos_emb = self.drop(pos_emb) 587 | 588 | hids.append(core_out) 589 | 590 | for i, layer in enumerate(self.layers): 591 | mems_i = None if mems is None else mems[i] 592 | core_out = layer( 593 | core_out, 594 | pos_emb, 595 | self.r_w_bias, 596 | self.r_r_bias, 597 | dec_attn_mask=dec_attn_mask, 598 | mems=mems_i, 599 | ) 600 | hids.append(core_out) 601 | core_out = self.drop(core_out) 602 | 603 | new_mems = self._update_mems(hids, mems, mlen, qlen, reset_mems) 604 | return core_out, new_mems 605 | 606 | def forward_generate(self, data, mems): 607 | 608 | if mems is None: 609 | mems = self.init_mems(self.n_layer) 610 | 611 | tgt_len = data.size(0) 612 | batch_size = data.size(1) 613 | 614 | hidden, new_mems = self._forward(data, None, mems=mems) 615 | 616 | pred_hid = hidden[-tgt_len:] 617 | 618 | assert self.crit.n_clusters == 0 619 | 620 | logits = self.crit._compute_logit( 621 | pred_hid.view(-1, pred_hid.size(-1)), 622 | self.crit.out_layers[0].weight, 623 | self.crit.out_layers[0].bias, 624 | self.crit.out_projs[0], 625 | ) 626 | logits = logits.view(tgt_len, batch_size, -1) 627 | 628 | return (logits, new_mems) 629 | 630 | def forward_generate_gumbel(self, data, temperature, mems): 631 | 632 | from torch.autograd import Variable 633 | 634 | def sample_gumbel(shape, eps=1e-20): 635 | U = torch.rand(shape).cuda() 636 | return -Variable(torch.log(-torch.log(U + eps) + eps)) 637 | 638 | def gumbel_softmax_sample(logits, temperature): 639 | y = logits + sample_gumbel(logits.size()) 640 | return F.softmax(y / temperature, dim=-1) 641 | 642 | def gumbel_softmax(logits, temperature): 643 | """ 644 | input: [*, n_class] 645 | return: [*, n_class] an one-hot vector 646 | """ 647 | y = 
gumbel_softmax_sample(logits, temperature) 648 | shape = y.size() 649 | _, ind = y.max(dim=-1) 650 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 651 | y_hard.scatter_(1, ind.view(-1, 1), 1) 652 | y_hard = y_hard.view(*shape) 653 | return (y_hard - y).detach() + y 654 | 655 | if mems is None: 656 | mems = self.init_mems(self.n_layer) 657 | 658 | tgt_len = data.size(0) 659 | batch_size = data.size(1) 660 | hidden, new_mems = self._forward(data, None, mems=mems) 661 | 662 | pred_hid = hidden[-tgt_len:] 663 | 664 | assert self.crit.n_clusters == 0 665 | 666 | logits = self.crit._compute_logit( 667 | pred_hid.view(-1, pred_hid.size(-1)), 668 | self.crit.out_layers[0].weight, 669 | self.crit.out_layers[0].bias, 670 | self.crit.out_projs[0], 671 | ) 672 | logits = gumbel_softmax( 673 | logits.view(tgt_len, batch_size, -1), temperature=temperature 674 | ) 675 | 676 | return (logits, new_mems) 677 | 678 | def forward(self, data, target, reset_mems, mems): 679 | # nn.DataParallel does not allow size(0) tensors to be broadcasted. 680 | # So, have to initialize size(0) mems inside the model forward. 681 | # Moreover, have to return new_mems to allow nn.DataParallel to piece 682 | # them together. 683 | if mems is None: 684 | mems = self.init_mems(self.n_layer) 685 | 686 | tgt_len = target.size(0) 687 | hidden, new_mems = self._forward(data, reset_mems, mems=mems) 688 | 689 | pred_hid = hidden[-tgt_len:] 690 | loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1)) 691 | loss = loss.view(tgt_len, -1) 692 | 693 | return (loss, new_mems) 694 | -------------------------------------------------------------------------------- /commu/preprocessor/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import PreprocessPipeline -------------------------------------------------------------------------------- /commu/preprocessor/augment.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import List, Union 4 | 5 | import miditoolkit 6 | import numpy as np 7 | import parmap 8 | import pretty_midi 9 | 10 | from .utils.constants import ( 11 | BPM_INTERVAL, 12 | KEY_NUM_MAP, 13 | NUM_BPM_AUGMENT, 14 | NUM_KEY_AUGMENT, 15 | MAJOR_KEY, 16 | MINOR_KEY, 17 | ) 18 | 19 | def get_avg_bpm(event_times: np.ndarray, tempo_infos: np.ndarray, end_time: float) -> int: 20 | def _normalize(_avg_bpm): 21 | return _avg_bpm - _avg_bpm % BPM_INTERVAL 22 | 23 | if len(tempo_infos) == 1: 24 | return _normalize(tempo_infos[-1]) 25 | 26 | event_times_with_end_time = np.concatenate([event_times, [end_time]]) 27 | bpm_durations = np.diff(event_times_with_end_time) 28 | total_bpm = 0 29 | for duration, bpm in zip(bpm_durations, tempo_infos): 30 | total_bpm += duration * bpm 31 | 32 | avg_bpm = int(total_bpm / end_time) 33 | return _normalize(avg_bpm) 34 | 35 | def augment_by_key(midi_path: str, augmented_tmp_dir: str, key_change: int) -> Union[Path, str]: 36 | midi = miditoolkit.MidiFile(midi_path) 37 | midi_id = Path(midi_path).stem 38 | pitch_track_notes = [midi.instruments[0].notes] 39 | 40 | for idx, key in enumerate(midi.key_signature_changes): 41 | origin_key = int(key.key_number) 42 | if origin_key < MINOR_KEY[0]: 43 | try: 44 | midi.key_signature_changes[idx].key_number = MAJOR_KEY[origin_key + key_change] 45 | except IndexError: 46 | midi.key_signature_changes[idx].key_number = MAJOR_KEY[ 47 | origin_key + key_change - len(MAJOR_KEY) 48 | ] 49 | else: 50 | 
origin_key = origin_key - MINOR_KEY[0] 51 | try: 52 | midi.key_signature_changes[idx].key_number = MINOR_KEY[origin_key + key_change] 53 | except IndexError: 54 | midi.key_signature_changes[idx].key_number = MINOR_KEY[ 55 | origin_key + key_change - len(MINOR_KEY) 56 | ] 57 | 58 | new_key_number = midi.key_signature_changes[0].key_number 59 | new_key = KEY_NUM_MAP[new_key_number] 60 | 61 | for track in pitch_track_notes: 62 | for note in track: 63 | note.pitch = note.pitch + key_change 64 | try: 65 | midi.dump(os.path.join(augmented_tmp_dir, midi_id + f"_{new_key}.mid")) 66 | except ValueError as e: 67 | print(e, midi_id) 68 | # exceeds note pitch range 69 | return None 70 | return os.path.join(augmented_tmp_dir, midi_id + f"_{new_key}.mid") 71 | 72 | 73 | def augment_by_bpm(augment_tmp_midi_pth, augmented_dir, bpm_change) -> None: 74 | midi = pretty_midi.PrettyMIDI(augment_tmp_midi_pth) 75 | event_times, origin_bpm = midi.get_tempo_changes() 76 | 77 | if len(origin_bpm) > 1: 78 | origin_bpm = get_avg_bpm(event_times, origin_bpm, midi.get_end_time()) 79 | 80 | midi_object = miditoolkit.MidiFile(augment_tmp_midi_pth) 81 | augment_midi_name = Path(augment_tmp_midi_pth).parts[-1].split(".")[0] 82 | 83 | new_bpm = int(origin_bpm) + bpm_change * BPM_INTERVAL 84 | midi_object.tempo_changes = [miditoolkit.TempoChange(tempo=new_bpm, time=0)] 85 | midi_object.dump(os.path.join(augmented_dir, augment_midi_name + f"_{round(new_bpm)}.mid")) 86 | 87 | 88 | def augment_data_map( 89 | midi_list: List, 90 | augmented_dir: str, 91 | augmented_tmp_dir: str, 92 | ) -> None: 93 | for midi_path in midi_list: 94 | for key_change in range(-NUM_KEY_AUGMENT, NUM_KEY_AUGMENT): 95 | augment_tmp_midi_pth = augment_by_key(midi_path, augmented_tmp_dir, key_change) 96 | if augment_tmp_midi_pth is not None: 97 | for bpm_change in range(-NUM_BPM_AUGMENT, NUM_BPM_AUGMENT + 1): 98 | augment_by_bpm(augment_tmp_midi_pth, augmented_dir, bpm_change) 99 | 100 | 101 | def augment_data( 102 | midi_path: Union[str, Path], 103 | augmented_dir: Union[str, Path], 104 | augmented_tmp_dir: Union[str, Path], 105 | num_cores: int, 106 | ) -> None: 107 | 108 | midifiles = [] 109 | 110 | for _, (dirpath, _, filenames) in enumerate(os.walk(midi_path)): 111 | midi_extensions = [".mid", ".MID", ".MIDI", ".midi"] 112 | for ext in midi_extensions: 113 | tem = [os.path.join(dirpath, _) for _ in filenames if _.endswith(ext)] 114 | if tem: 115 | midifiles += tem 116 | 117 | split_midi = np.array_split(np.array(midifiles), num_cores) 118 | split_midi = [x.tolist() for x in split_midi] 119 | parmap.map( 120 | augment_data_map, 121 | split_midi, 122 | augmented_dir, 123 | augmented_tmp_dir, 124 | pm_pbar=True, 125 | pm_processes=num_cores, 126 | ) -------------------------------------------------------------------------------- /commu/preprocessor/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import * 2 | from .meta import MetaEncoder 3 | from . import event_tokens -------------------------------------------------------------------------------- /commu/preprocessor/encoder/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import miditoolkit 4 | import numpy as np 5 | 6 | from . 
import encoder_utils 7 | from .event_tokens import TOKEN_OFFSET 8 | from ..utils.constants import ( 9 | DEFAULT_POSITION_RESOLUTION, 10 | DEFAULT_TICKS_PER_BEAT, 11 | SIG_TIME_MAP 12 | ) 13 | 14 | class EventSequenceEncoder: 15 | def __init__(self): 16 | self.event2word, self.word2event = encoder_utils.mk_remi_map() 17 | self.event2word = encoder_utils.add_flat_chord2map(self.event2word) 18 | self.event2word = encoder_utils.abstract_chord_types(self.event2word) 19 | self.position_resolution = DEFAULT_POSITION_RESOLUTION 20 | 21 | def encode(self, midi_paths, sample_info=None, for_cp=False): 22 | midi_file = miditoolkit.MidiFile(midi_paths) 23 | ticks_per_beat = midi_file.ticks_per_beat 24 | chord_progression = sample_info["chord_progressions"] 25 | num_measures = math.ceil(sample_info["num_measures"]) 26 | numerator = int(sample_info["time_signature"].split("/")[0]) 27 | denominator = int(sample_info["time_signature"].split("/")[1]) 28 | is_incomplete_measure = sample_info["is_incomplete_measure"] 29 | 30 | beats_per_bar = numerator / denominator * 4 31 | ticks_per_bar = int(ticks_per_beat * beats_per_bar) 32 | duration_bins = np.arange( 33 | int(ticks_per_bar / self.position_resolution), 34 | ticks_per_bar + 1, 35 | int(ticks_per_bar / self.position_resolution), 36 | dtype=int, 37 | ) 38 | 39 | events = encoder_utils.extract_events( 40 | midi_paths, 41 | duration_bins, 42 | ticks_per_bar=ticks_per_bar, 43 | ticks_per_beat=ticks_per_beat, 44 | chord_progression=chord_progression, 45 | num_measures=num_measures, 46 | is_incomplete_measure=is_incomplete_measure, 47 | ) 48 | if for_cp: 49 | return events 50 | 51 | words = [] 52 | for event in events: 53 | e = "{}_{}".format(event.name, event.value) 54 | if e in self.event2word: 55 | words.append(self.event2word[e]) 56 | else: 57 | # OOV 58 | if event.name == "Note Velocity": 59 | # replace with max velocity based on our training data 60 | words.append(self.event2word["Note Velocity_63"]) 61 | if event.name == "Note Duration": 62 | # replace with max duration 63 | words.append(self.event2word[f"Note Duration_{self.position_resolution-1}"]) 64 | else: 65 | # something is wrong 66 | # you should handle it for your own purpose 67 | print("OOV {}".format(e)) 68 | words.append(TOKEN_OFFSET.EOS.value) # eos token 69 | return np.array(words) 70 | 71 | def decode( 72 | self, 73 | midi_info, 74 | ): 75 | time_sig_word = midi_info.time_signature 76 | time_sig = SIG_TIME_MAP[time_sig_word - TOKEN_OFFSET.TS.value - 1] 77 | numerator = int(time_sig.split("/")[0]) 78 | denominator = int(time_sig.split("/")[1]) 79 | beats_per_bar = int(numerator/denominator * 4) 80 | 81 | ticks_per_bar = DEFAULT_TICKS_PER_BEAT * beats_per_bar 82 | 83 | duration_bins = np.arange( 84 | int(ticks_per_bar / self.position_resolution), 85 | ticks_per_bar + 1, 86 | int(ticks_per_bar / self.position_resolution), 87 | dtype=int, 88 | ) 89 | 90 | decoded_midi = encoder_utils.write_midi( 91 | midi_info, 92 | self.word2event, 93 | duration_bins=duration_bins, 94 | beats_per_bar=beats_per_bar, 95 | ) 96 | 97 | return decoded_midi -------------------------------------------------------------------------------- /commu/preprocessor/encoder/encoder_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Dict 3 | 4 | import miditoolkit 5 | import numpy as np 6 | 7 | from .event_tokens import base_event, TOKEN_OFFSET 8 | from ..utils.constants import ( 9 | BPM_INTERVAL, 10 | DEFAULT_POSITION_RESOLUTION, 11 | 
DEFAULT_TICKS_PER_BEAT, 12 | VELOCITY_INTERVAL, 13 | SIG_TIME_MAP, 14 | KEY_NUM_MAP 15 | ) 16 | 17 | NUM_VELOCITY_BINS = int(128 / VELOCITY_INTERVAL) 18 | DEFAULT_VELOCITY_BINS = np.linspace(2, 127, NUM_VELOCITY_BINS, dtype=int) 19 | 20 | class Item(object): 21 | def __init__(self, name, start, end, velocity, pitch): 22 | self.name = name 23 | self.start = start 24 | self.end = end 25 | self.velocity = velocity 26 | self.pitch = pitch 27 | 28 | def __repr__(self): 29 | return "Item(name={}, start={}, end={}, velocity={}, pitch={})".format( 30 | self.name, self.start, self.end, self.velocity, self.pitch 31 | ) 32 | 33 | 34 | class Event(object): 35 | def __init__(self, name, time, value, text): 36 | self.name = name 37 | self.time = time 38 | self.value = value 39 | self.text = text 40 | 41 | def __repr__(self): 42 | return "Event(name={}, time={}, value={}, text={})".format( 43 | self.name, self.time, self.value, self.text 44 | ) 45 | 46 | 47 | def mk_remi_map(): 48 | event = copy.deepcopy(base_event) 49 | for i in range(DEFAULT_POSITION_RESOLUTION): 50 | event.append(f"Note Duration_{i}") 51 | for i in range(1, DEFAULT_POSITION_RESOLUTION + 1): 52 | event.append(f"Position_{i}/{DEFAULT_POSITION_RESOLUTION}") 53 | 54 | event2word = {k: v for k, v in zip(event, range(2, len(event) + 2))} 55 | word2event = {v: k for k, v in zip(event, range(2, len(event) + 2))} 56 | 57 | return event2word, word2event 58 | 59 | def add_flat_chord2map(event2word: Dict): 60 | flat_chord = ["Chord_ab:", "Chord_bb:", "Chord_db:", "Chord_eb:", "Chord_gb:"] 61 | scale = [ 62 | "", 63 | "maj", 64 | "maj7", 65 | "7", 66 | "dim", 67 | "dim7", 68 | "+", 69 | "m", 70 | "m7", 71 | "sus4", 72 | "7sus4", 73 | "m6", 74 | "m7b5", 75 | "sus2", 76 | "add2", 77 | "6", 78 | "madd2", 79 | "mM7", 80 | ] 81 | 82 | flat_chords = [] 83 | for c in flat_chord: 84 | for s in scale: 85 | flat_chords.append(c + s) 86 | 87 | for c in flat_chords: 88 | scale = c.split(":")[1] 89 | key = c.split(":")[0].split("_")[1][0] 90 | c = c.replace(":", "") 91 | if c.startswith("Chord_ab"): 92 | if scale == "" or scale == "maj" or scale == "6": 93 | event2word[c] = event2word["Chord_g#"] 94 | elif scale == "maj7" or scale == "add2" or scale == "sus2": 95 | event2word[c] = event2word["Chord_g#maj7"] 96 | elif scale == "7": 97 | event2word[c] = event2word["Chord_g#7"] 98 | elif scale == "dim" or scale == "dim7": 99 | event2word[c] = event2word["Chord_g#dim"] 100 | elif scale == "+": 101 | event2word[c] = event2word["Chord_g#+"] 102 | elif scale == "m" or scale == "m6" or scale == "mM7": 103 | event2word[c] = event2word["Chord_g#m"] 104 | elif scale == "m7" or scale == "madd2": 105 | event2word[c] = event2word["Chord_g#m7"] 106 | elif scale == "sus4" or scale == "7sus4": 107 | event2word[c] = event2word["Chord_g#sus4"] 108 | elif scale == "m7b5": 109 | event2word[c] = event2word["Chord_g#m7b5"] 110 | else: 111 | if scale == "" or scale == "maj" or scale == "6": 112 | new_key = chr(ord(key) - 1) 113 | word = "Chord_" + new_key + "#" 114 | event2word[c] = event2word[word] 115 | elif scale == "maj7" or scale == "add2" or scale == "sus2": 116 | new_key = chr(ord(key) - 1) 117 | word = "Chord_" + new_key + "#maj7" 118 | event2word[c] = event2word[word] 119 | elif scale == "7": 120 | new_key = chr(ord(key) - 1) 121 | word = "Chord_" + new_key + "#7" 122 | event2word[c] = event2word[word] 123 | elif scale == "dim" or scale == "dim7": 124 | new_key = chr(ord(key) - 1) 125 | word = "Chord_" + new_key + "#dim" 126 | event2word[c] = event2word[word] 127 | elif 
scale == "+": 128 | new_key = chr(ord(key) - 1) 129 | word = "Chord_" + new_key + "#+" 130 | event2word[c] = event2word[word] 131 | elif scale == "m" or scale == "m6" or scale == "mM7": 132 | new_key = chr(ord(key) - 1) 133 | word = "Chord_" + new_key + "#m" 134 | event2word[c] = event2word[word] 135 | elif scale == "m7" or scale == "madd2": 136 | new_key = chr(ord(key) - 1) 137 | word = "Chord_" + new_key + "#m7" 138 | event2word[c] = event2word[word] 139 | elif scale == "sus4" or scale == "7sus4": 140 | new_key = chr(ord(key) - 1) 141 | word = "Chord_" + new_key + "#sus4" 142 | event2word[c] = event2word[word] 143 | elif scale == "m7b5": 144 | new_key = chr(ord(key) - 1) 145 | word = "Chord_" + new_key + "#m7b5" 146 | event2word[c] = event2word[word] 147 | 148 | return event2word 149 | 150 | def abstract_chord_types(event2word): 151 | chord = ["Chord_a:", "Chord_b:", "Chord_c:", "Chord_d:", "Chord_e:", "Chord_f:", "Chord_g:"] 152 | scale = ["7sus4", "m6", "sus2", "add2", "dim7", "6", "madd2", "mM7", ] 153 | 154 | chords = [] 155 | for c in chord: 156 | for s in scale: 157 | chords.append(c + s) 158 | 159 | for c in chords: 160 | scale = c.split(":")[1] 161 | key = c.split(":")[0].split("_")[1][0] 162 | c = c.replace(":", "") 163 | if scale == "7sus4": 164 | word = "Chord_" + key + "sus4" 165 | event2word[c] = event2word[word] 166 | if scale == "m6": 167 | word = "Chord_" + key + "m" 168 | event2word[c] = event2word[word] 169 | if scale == "sus2" or scale == "add2": 170 | word = "Chord_" + key + "maj7" 171 | event2word[c] = event2word[word] 172 | if scale == "6": 173 | word = "Chord_" + key 174 | event2word[c] = event2word[word] 175 | if scale == "dim7": 176 | word = "Chord_" + key + "dim" 177 | event2word[c] = event2word[word] 178 | if scale == "madd2" or scale == "mM7": 179 | word = "Chord_" + key + "m7" 180 | event2word[c] = event2word[word] 181 | 182 | return event2word 183 | 184 | def extract_events( 185 | input_path, 186 | duration_bins, 187 | ticks_per_bar=None, 188 | ticks_per_beat=None, 189 | chord_progression=None, 190 | num_measures=None, 191 | is_incomplete_measure=None, 192 | ): 193 | note_items = read_items(input_path) 194 | max_time = note_items[-1].end 195 | if not chord_progression[0]: 196 | return None 197 | else: 198 | items = note_items 199 | groups = group_items(items, max_time, ticks_per_bar) 200 | events = item2event(groups, duration_bins) 201 | beats_per_bar = int(ticks_per_bar/ticks_per_beat) 202 | 203 | if chord_progression: 204 | new_chords = chord_progression[0] 205 | events = insert_chord_on_event( 206 | events, 207 | new_chords, 208 | ticks_per_bar, 209 | num_measures, 210 | is_incomplete_measure, 211 | beats_per_bar, 212 | ) 213 | 214 | return events 215 | 216 | def read_items(file_path): 217 | midi_obj = miditoolkit.midi.parser.MidiFile(file_path) 218 | note_items = [] 219 | notes = midi_obj.instruments[0].notes 220 | notes.sort(key=lambda x: (x.start, x.pitch)) 221 | for note in notes: 222 | note_items.append( 223 | Item( 224 | name="Note", 225 | start=note.start, 226 | end=note.end, 227 | velocity=note.velocity, 228 | pitch=note.pitch, 229 | ) 230 | ) 231 | note_items.sort(key=lambda x: x.start) 232 | return note_items 233 | 234 | def group_items(items, max_time, ticks_per_bar): 235 | items.sort(key=lambda x: x.start) 236 | downbeats = np.arange(0, max_time + ticks_per_bar, ticks_per_bar) 237 | groups = [] 238 | for db1, db2 in zip(downbeats[:-1], downbeats[1:]): 239 | insiders = [] 240 | for item in items: 241 | if (item.start >= db1) and (item.start < 
db2): 242 | insiders.append(item) 243 | if not insiders: 244 | insiders.append(Item(name="None", start=None, end=None, velocity=None, pitch="NN")) 245 | overall = [db1] + insiders + [db2] 246 | groups.append(overall) 247 | return groups 248 | 249 | def item2event(groups, duration_bins): 250 | events = [] 251 | n_downbeat = 0 252 | for i in range(len(groups)): 253 | if "NN" in [item.pitch for item in groups[i][1:-1]]: 254 | continue 255 | bar_st, bar_et = groups[i][0], groups[i][-1] 256 | n_downbeat += 1 257 | if groups[i][1].name == "Chord": 258 | events.append(Event(name="Bar", time=bar_st, value=None, text="{}".format(n_downbeat))) 259 | for item in groups[i][1:-1]: 260 | # position 261 | flags = np.linspace(bar_st, bar_et, DEFAULT_POSITION_RESOLUTION, endpoint=False) 262 | index = np.argmin(abs(flags - item.start)) 263 | events.append( 264 | Event( 265 | name="Position", 266 | time=item.start, 267 | value="{}/{}".format(index + 1, DEFAULT_POSITION_RESOLUTION), 268 | text="{}".format(item.start), 269 | ) 270 | ) 271 | if item.name == "Note": 272 | # velocity 273 | velocity_index = ( 274 | np.searchsorted(DEFAULT_VELOCITY_BINS, item.velocity, side="right") - 1 275 | ) 276 | events.append( 277 | Event( 278 | name="Note Velocity", 279 | time=item.start, 280 | value=velocity_index, 281 | text="{}/{}".format(item.velocity, DEFAULT_VELOCITY_BINS[velocity_index]), 282 | ) 283 | ) 284 | # pitch 285 | events.append( 286 | Event( 287 | name="Note On", 288 | time=item.start, 289 | value=item.pitch, 290 | text="{}".format(item.pitch), 291 | ) 292 | ) 293 | # duration 294 | duration = item.end - item.start 295 | index = np.argmin(abs(duration_bins - duration)) 296 | events.append( 297 | Event( 298 | name="Note Duration", 299 | time=item.start, 300 | value=index, 301 | text="{}/{}".format(duration, duration_bins[index]), 302 | ) 303 | ) 304 | elif item.name == "Chord": 305 | events.append( 306 | Event( 307 | name="Chord", 308 | time=item.start, 309 | value=item.pitch, 310 | text="{}".format(item.pitch), 311 | ) 312 | ) 313 | return events 314 | 315 | def insert_chord_on_event( 316 | events, 317 | chord_progression, 318 | tick_per_bar, 319 | num_measures, 320 | is_incomplete_measure, 321 | beats_per_bar, 322 | ): 323 | chord_idx_lst, chords = detect_chord(chord_progression, beats_per_bar) 324 | start_time = tick_per_bar * is_incomplete_measure 325 | chord_events = [] 326 | for i in range(num_measures): 327 | chord_events.append( 328 | Event(name="Bar", time=i * tick_per_bar, value=None, text="{}".format(i + 1)) 329 | ) 330 | while chord_idx_lst and chord_idx_lst[0] < i + 1 - is_incomplete_measure: 331 | chord_position = chord_idx_lst.pop(0) 332 | chord_time = int(chord_position * tick_per_bar + start_time) 333 | chord = chords.pop(0) 334 | chord_events.append( 335 | Event( 336 | name="Position", 337 | time=chord_time, 338 | value="{}/{}".format( 339 | int((chord_position - i + is_incomplete_measure) * DEFAULT_POSITION_RESOLUTION) + 1, 340 | DEFAULT_POSITION_RESOLUTION 341 | ), 342 | text=chord_time, 343 | ) 344 | ) 345 | chord_events.append( 346 | Event(name="Chord", 347 | time=chord_time, 348 | value=chord.split("/")[0].split("(")[0], 349 | text=chord.split("/")[0].split("(")[0]) 350 | ) 351 | 352 | inserted_events = chord_events + events 353 | inserted_events.sort(key=lambda x: x.time) 354 | return inserted_events 355 | 356 | def detect_chord(chord_progression, beats_per_bar): 357 | chords_per_bar = beats_per_bar * 2 358 | num_measures = int(len(chord_progression)/chords_per_bar) 359 | split_by_bar = 
np.array_split(np.array(chord_progression), num_measures) 360 | chord_idx = [] 361 | chord_name = [] 362 | for bar_idx, bar in enumerate(split_by_bar): 363 | for c_idx, chord in enumerate(bar): 364 | chord = chord.lower() 365 | if c_idx == 0 or chord != chord_name[-1]: 366 | chord_idx.append(bar_idx + c_idx / chords_per_bar) 367 | chord_name.append(chord) 368 | return chord_idx, chord_name 369 | 370 | def word_to_event(words, word2event): 371 | events = [] 372 | for word in words: 373 | try: 374 | event_name, event_value = word2event[word].split("_") 375 | except KeyError: 376 | if word == 1: 377 | # 따로 디코딩 되지 않는 EOS 378 | continue 379 | else: 380 | print(f"OOV: {word}") 381 | continue 382 | events.append(Event(event_name, None, event_value, None)) 383 | return events 384 | 385 | def write_midi( 386 | midi_info, 387 | word2event, 388 | duration_bins, 389 | beats_per_bar, 390 | ): 391 | events = word_to_event(midi_info.event_seq, word2event) 392 | # get downbeat and note (no time) 393 | temp_notes = [] 394 | temp_chords = [] 395 | for i in range(len(events) - 3): 396 | if events[i].name == "Bar" and i > 0: 397 | temp_notes.append("Bar") 398 | temp_chords.append("Bar") 399 | elif ( 400 | events[i].name == "Position" 401 | and events[i + 1].name == "Note Velocity" 402 | and events[i + 2].name == "Note On" 403 | and events[i + 3].name == "Note Duration" 404 | ): 405 | # start time and end time from position 406 | position = int(events[i].value.split("/")[0]) - 1 407 | # velocity 408 | index = int(events[i + 1].value) 409 | velocity = int(DEFAULT_VELOCITY_BINS[index]) 410 | # pitch 411 | pitch = int(events[i + 2].value) 412 | # duration 413 | index = int(events[i + 3].value) 414 | duration = duration_bins[index] 415 | # adding 416 | temp_notes.append([position, velocity, pitch, duration]) 417 | elif events[i].name == "Position" and events[i + 1].name == "Chord": 418 | position = int(events[i].value.split("/")[0]) - 1 419 | temp_chords.append([position, events[i + 1].value]) 420 | # get specific time for notes 421 | ticks_per_beat = DEFAULT_TICKS_PER_BEAT 422 | ticks_per_bar = ticks_per_beat * beats_per_bar 423 | notes = [] 424 | current_bar = 0 425 | for note in temp_notes: 426 | if note == "Bar": 427 | current_bar += 1 428 | else: 429 | position, velocity, pitch, duration = note 430 | # position (start time) 431 | current_bar_st = current_bar * ticks_per_bar 432 | current_bar_et = (current_bar + 1) * ticks_per_bar 433 | flags = np.linspace( 434 | int(current_bar_st), 435 | int(current_bar_et), 436 | int(DEFAULT_POSITION_RESOLUTION), 437 | endpoint=False, 438 | dtype=int, 439 | ) 440 | st = flags[position] 441 | # duration (end time) 442 | et = st + duration 443 | notes.append(miditoolkit.Note(velocity, pitch, st, et)) 444 | # get specific time for chords 445 | if len(temp_chords) > 0: 446 | chords = [] 447 | current_bar = 0 448 | for chord in temp_chords: 449 | if chord == "Bar": 450 | current_bar += 1 451 | else: 452 | position, value = chord 453 | # position (start time) 454 | current_bar_st = current_bar * ticks_per_bar 455 | current_bar_et = (current_bar + 1) * ticks_per_bar 456 | flags = np.linspace( 457 | current_bar_st, current_bar_et, DEFAULT_POSITION_RESOLUTION, endpoint=False, dtype=int 458 | ) 459 | st = flags[position] 460 | chords.append([st, value]) 461 | 462 | midi = miditoolkit.midi.parser.MidiFile() 463 | numerator, denominator = SIG_TIME_MAP[ 464 | midi_info.time_signature 465 | - (TOKEN_OFFSET.TS.value + 1) 466 | ].split("/") 467 | ts = 
miditoolkit.midi.containers.TimeSignature( 468 | numerator=int(numerator), denominator=int(denominator), time=0 469 | ) 470 | key_num = midi_info.audio_key - (TOKEN_OFFSET.KEY.value + 1) 471 | ks = miditoolkit.KeySignature( 472 | key_name=KEY_NUM_MAP[key_num], 473 | time=0) 474 | midi.time_signature_changes.append(ts) 475 | midi.key_signature_changes.append(ks) 476 | midi.ticks_per_beat = DEFAULT_TICKS_PER_BEAT 477 | # write instrument 478 | inst = miditoolkit.midi.containers.Instrument(0, is_drum=False) 479 | inst.notes = notes 480 | midi.instruments.append(inst) 481 | # write bpm info 482 | tempo_changes = [] 483 | tempo_changes.append( 484 | miditoolkit.midi.containers.TempoChange( 485 | (midi_info.bpm - TOKEN_OFFSET.BPM.value) 486 | * BPM_INTERVAL, 487 | 0, 488 | ) 489 | ) 490 | midi.tempo_changes = tempo_changes 491 | 492 | # write chord into marker 493 | if len(temp_chords) > 0: 494 | for c in chords: 495 | midi.markers.append(miditoolkit.midi.containers.Marker(text=c[1], time=c[0])) 496 | 497 | return midi 498 | -------------------------------------------------------------------------------- /commu/preprocessor/encoder/event_tokens.py: -------------------------------------------------------------------------------- 1 | base_event = [ 2 | "Bar_None", 3 | "Note On_0", 4 | "Note On_1", 5 | "Note On_2", 6 | "Note On_3", 7 | "Note On_4", 8 | "Note On_5", 9 | "Note On_6", 10 | "Note On_7", 11 | "Note On_8", 12 | "Note On_9", 13 | "Note On_10", 14 | "Note On_11", 15 | "Note On_12", 16 | "Note On_13", 17 | "Note On_14", 18 | "Note On_15", 19 | "Note On_16", 20 | "Note On_17", 21 | "Note On_18", 22 | "Note On_19", 23 | "Note On_20", 24 | "Note On_21", 25 | "Note On_22", 26 | "Note On_23", 27 | "Note On_24", 28 | "Note On_25", 29 | "Note On_26", 30 | "Note On_27", 31 | "Note On_28", 32 | "Note On_29", 33 | "Note On_30", 34 | "Note On_31", 35 | "Note On_32", 36 | "Note On_33", 37 | "Note On_34", 38 | "Note On_35", 39 | "Note On_36", 40 | "Note On_37", 41 | "Note On_38", 42 | "Note On_39", 43 | "Note On_40", 44 | "Note On_41", 45 | "Note On_42", 46 | "Note On_43", 47 | "Note On_44", 48 | "Note On_45", 49 | "Note On_46", 50 | "Note On_47", 51 | "Note On_48", 52 | "Note On_49", 53 | "Note On_50", 54 | "Note On_51", 55 | "Note On_52", 56 | "Note On_53", 57 | "Note On_54", 58 | "Note On_55", 59 | "Note On_56", 60 | "Note On_57", 61 | "Note On_58", 62 | "Note On_59", 63 | "Note On_60", 64 | "Note On_61", 65 | "Note On_62", 66 | "Note On_63", 67 | "Note On_64", 68 | "Note On_65", 69 | "Note On_66", 70 | "Note On_67", 71 | "Note On_68", 72 | "Note On_69", 73 | "Note On_70", 74 | "Note On_71", 75 | "Note On_72", 76 | "Note On_73", 77 | "Note On_74", 78 | "Note On_75", 79 | "Note On_76", 80 | "Note On_77", 81 | "Note On_78", 82 | "Note On_79", 83 | "Note On_80", 84 | "Note On_81", 85 | "Note On_82", 86 | "Note On_83", 87 | "Note On_84", 88 | "Note On_85", 89 | "Note On_86", 90 | "Note On_87", 91 | "Note On_88", 92 | "Note On_89", 93 | "Note On_90", 94 | "Note On_91", 95 | "Note On_92", 96 | "Note On_93", 97 | "Note On_94", 98 | "Note On_95", 99 | "Note On_96", 100 | "Note On_97", 101 | "Note On_98", 102 | "Note On_99", 103 | "Note On_100", 104 | "Note On_101", 105 | "Note On_102", 106 | "Note On_103", 107 | "Note On_104", 108 | "Note On_105", 109 | "Note On_106", 110 | "Note On_107", 111 | "Note On_108", 112 | "Note On_109", 113 | "Note On_110", 114 | "Note On_111", 115 | "Note On_112", 116 | "Note On_113", 117 | "Note On_114", 118 | "Note On_115", 119 | "Note On_116", 120 | "Note On_117", 121 | "Note 
On_118", 122 | "Note On_119", 123 | "Note On_120", 124 | "Note On_121", 125 | "Note On_122", 126 | "Note On_123", 127 | "Note On_124", 128 | "Note On_125", 129 | "Note On_126", 130 | "Note On_127", 131 | "Note Velocity_0", 132 | "Note Velocity_1", 133 | "Note Velocity_2", 134 | "Note Velocity_3", 135 | "Note Velocity_4", 136 | "Note Velocity_5", 137 | "Note Velocity_6", 138 | "Note Velocity_7", 139 | "Note Velocity_8", 140 | "Note Velocity_9", 141 | "Note Velocity_10", 142 | "Note Velocity_11", 143 | "Note Velocity_12", 144 | "Note Velocity_13", 145 | "Note Velocity_14", 146 | "Note Velocity_15", 147 | "Note Velocity_16", 148 | "Note Velocity_17", 149 | "Note Velocity_18", 150 | "Note Velocity_19", 151 | "Note Velocity_20", 152 | "Note Velocity_21", 153 | "Note Velocity_22", 154 | "Note Velocity_23", 155 | "Note Velocity_24", 156 | "Note Velocity_25", 157 | "Note Velocity_26", 158 | "Note Velocity_27", 159 | "Note Velocity_28", 160 | "Note Velocity_29", 161 | "Note Velocity_30", 162 | "Note Velocity_31", 163 | "Note Velocity_32", 164 | "Note Velocity_33", 165 | "Note Velocity_34", 166 | "Note Velocity_35", 167 | "Note Velocity_36", 168 | "Note Velocity_37", 169 | "Note Velocity_38", 170 | "Note Velocity_39", 171 | "Note Velocity_40", 172 | "Note Velocity_41", 173 | "Note Velocity_42", 174 | "Note Velocity_43", 175 | "Note Velocity_44", 176 | "Note Velocity_45", 177 | "Note Velocity_46", 178 | "Note Velocity_47", 179 | "Note Velocity_48", 180 | "Note Velocity_49", 181 | "Note Velocity_50", 182 | "Note Velocity_51", 183 | "Note Velocity_52", 184 | "Note Velocity_53", 185 | "Note Velocity_54", 186 | "Note Velocity_55", 187 | "Note Velocity_56", 188 | "Note Velocity_57", 189 | "Note Velocity_58", 190 | "Note Velocity_59", 191 | "Note Velocity_60", 192 | "Note Velocity_61", 193 | "Note Velocity_62", 194 | "Note Velocity_63", 195 | 'Chord_a', 196 | 'Chord_a7', 197 | 'Chord_a+', 198 | 'Chord_adim', 199 | 'Chord_am', 200 | 'Chord_am7', 201 | 'Chord_am7b5', 202 | 'Chord_amaj7', 203 | 'Chord_asus4', 204 | 'Chord_a#', 205 | 'Chord_a#7', 206 | 'Chord_a#+', 207 | 'Chord_a#dim', 208 | 'Chord_a#m', 209 | 'Chord_a#m7', 210 | 'Chord_a#m7b5', 211 | 'Chord_a#maj7', 212 | 'Chord_a#sus4', 213 | 'Chord_b', 214 | 'Chord_b7', 215 | 'Chord_b+', 216 | 'Chord_bdim', 217 | 'Chord_bm', 218 | 'Chord_bm7', 219 | 'Chord_bm7b5', 220 | 'Chord_bmaj7', 221 | 'Chord_bsus4', 222 | 'Chord_c', 223 | 'Chord_c7', 224 | 'Chord_c+', 225 | 'Chord_cdim', 226 | 'Chord_cm', 227 | 'Chord_cm7', 228 | 'Chord_cm7b5', 229 | 'Chord_cmaj7', 230 | 'Chord_csus4', 231 | 'Chord_c#', 232 | 'Chord_c#7', 233 | 'Chord_c#+', 234 | 'Chord_c#dim', 235 | 'Chord_c#m', 236 | 'Chord_c#m7', 237 | 'Chord_c#m7b5', 238 | 'Chord_c#maj7', 239 | 'Chord_c#sus4', 240 | 'Chord_d', 241 | 'Chord_d7', 242 | 'Chord_d+', 243 | 'Chord_ddim', 244 | 'Chord_dm', 245 | 'Chord_dm7', 246 | 'Chord_dm7b5', 247 | 'Chord_dmaj7', 248 | 'Chord_dsus4', 249 | 'Chord_d#', 250 | 'Chord_d#7', 251 | 'Chord_d#+', 252 | 'Chord_d#dim', 253 | 'Chord_d#m', 254 | 'Chord_d#m7', 255 | 'Chord_d#m7b5', 256 | 'Chord_d#maj7', 257 | 'Chord_d#sus4', 258 | 'Chord_e', 259 | 'Chord_e7', 260 | 'Chord_e+', 261 | 'Chord_edim', 262 | 'Chord_em', 263 | 'Chord_em7', 264 | 'Chord_em7b5', 265 | 'Chord_emaj7', 266 | 'Chord_esus4', 267 | 'Chord_f', 268 | 'Chord_f7', 269 | 'Chord_f+', 270 | 'Chord_fdim', 271 | 'Chord_fm', 272 | 'Chord_fm7', 273 | 'Chord_fm7b5', 274 | 'Chord_fmaj7', 275 | 'Chord_fsus4', 276 | 'Chord_f#', 277 | 'Chord_f#7', 278 | 'Chord_f#+', 279 | 'Chord_f#dim', 280 | 'Chord_f#m', 281 | 'Chord_f#m7', 
282 | 'Chord_f#m7b5', 283 | 'Chord_f#maj7', 284 | 'Chord_f#sus4', 285 | 'Chord_g', 286 | 'Chord_g7', 287 | 'Chord_g+', 288 | 'Chord_gdim', 289 | 'Chord_gm', 290 | 'Chord_gm7', 291 | 'Chord_gm7b5', 292 | 'Chord_gmaj7', 293 | 'Chord_gsus4', 294 | 'Chord_g#', 295 | 'Chord_g#7', 296 | 'Chord_g#+', 297 | 'Chord_g#dim', 298 | 'Chord_g#m', 299 | 'Chord_g#m7', 300 | 'Chord_g#m7b5', 301 | 'Chord_g#maj7', 302 | 'Chord_g#sus4', 303 | "Chord_NN", 304 | ] 305 | 306 | import enum 307 | 308 | class TOKEN_OFFSET(enum.Enum): 309 | EOS = 1 310 | BAR = 2 311 | PITCH = 3 312 | NOTE_VELOCITY = 131 313 | CHORD_START = 195 314 | CHORD_END = 303 315 | NOTE_DURATION = 304 316 | POSITION = 432 317 | BPM = 560 318 | KEY = 601 319 | TS = 626 320 | PITCH_RANGE = 630 321 | NUM_MEASURES = 638 322 | INST = 641 323 | GENRE = 650 324 | VELOCITY = 653 325 | TRACK_ROLE = 719 326 | RHYTHM = 726 327 | REMI_META_OFFSET = 138 328 | META_CC_OFFSET = 7 329 | VOCAB_SIZE = 729 330 | -------------------------------------------------------------------------------- /commu/preprocessor/encoder/meta.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import functools 3 | import inspect 4 | import math 5 | from typing import Any, Callable, Dict, List, Union 6 | 7 | from commu.preprocessor.utils.exceptions import ErrorMessage, UnprocessableMidiError 8 | from ..utils import constants 9 | from ..utils.container import MidiMeta 10 | from .event_tokens import TOKEN_OFFSET 11 | 12 | EncodeFunc = Union[Callable[[Any], int], Callable[[Any, Dict[Any, int]], int]] 13 | META_ENCODING_ORDER = tuple(MidiMeta.__fields__.keys()) 14 | DEFAULT_ENCODING_MAPS = { 15 | "audio_key": constants.KEY_MAP, 16 | "time_signature": constants.TIME_SIG_MAP, 17 | "pitch_range": constants.PITCH_RANGE_MAP, 18 | "inst": constants.INST_MAP, 19 | "genre": constants.GENRE_MAP, 20 | "track_role": constants.TRACK_ROLE_MAP, 21 | "rhythm": constants.RHYTHM_MAP, 22 | } 23 | ATTR_ALIAS = { 24 | "min_velocity": "velocity", 25 | "max_velocity": "velocity", 26 | } 27 | 28 | 29 | class AliasMixin: 30 | @classmethod 31 | def get(cls, key: str): 32 | key = key.lower() 33 | if key in ATTR_ALIAS: 34 | return getattr(cls, ATTR_ALIAS[key].upper()) 35 | return getattr(cls, key.upper()) 36 | 37 | 38 | class Unknown(AliasMixin, int, enum.Enum): 39 | BPM = TOKEN_OFFSET.BPM.value 40 | AUDIO_KEY = TOKEN_OFFSET.KEY.value 41 | TIME_SIGNATURE = TOKEN_OFFSET.TS.value 42 | PITCH_RANGE = TOKEN_OFFSET.PITCH_RANGE.value 43 | INST = TOKEN_OFFSET.INST.value 44 | GENRE = TOKEN_OFFSET.GENRE.value 45 | VELOCITY = TOKEN_OFFSET.VELOCITY.value 46 | TRACK_ROLE = TOKEN_OFFSET.TRACK_ROLE.value 47 | RHYTHM = TOKEN_OFFSET.RHYTHM.value 48 | 49 | 50 | class Offset(AliasMixin, int, enum.Enum): 51 | BPM = TOKEN_OFFSET.BPM.value 52 | AUDIO_KEY = TOKEN_OFFSET.KEY.value + 1 53 | TIME_SIGNATURE = TOKEN_OFFSET.TS.value + 1 54 | PITCH_RANGE = TOKEN_OFFSET.PITCH_RANGE.value + 1 55 | MEASURES_4 = TOKEN_OFFSET.NUM_MEASURES.value 56 | MEASURES_8 = TOKEN_OFFSET.NUM_MEASURES.value + 1 57 | MEASURES_16 = TOKEN_OFFSET.NUM_MEASURES.value + 2 58 | INST = TOKEN_OFFSET.INST.value + 1 59 | GENRE = TOKEN_OFFSET.GENRE.value + 1 60 | VELOCITY = TOKEN_OFFSET.VELOCITY.value + 1 61 | TRACK_ROLE = TOKEN_OFFSET.TRACK_ROLE.value + 1 62 | RHYTHM = TOKEN_OFFSET.RHYTHM.value + 1 63 | 64 | 65 | ENCODERS: Dict[str, EncodeFunc] = dict() 66 | 67 | 68 | def _get_meta_name(func_name: str) -> str: 69 | return "_".join(func_name.split("_")[1:]) 70 | 71 | 72 | def register_encoder(func): 73 | 
ENCODERS[_get_meta_name(func.__name__)] = func 74 | return func 75 | 76 | 77 | def inject_args_to_encode_func(encode_func, *args, **kwargs) -> int: 78 | num_args = len(inspect.getfullargspec(encode_func).args) 79 | if num_args == 1: 80 | return encode_func(args[0]) 81 | return encode_func(*args, **kwargs) 82 | 83 | 84 | def encode_unknown( 85 | raise_error: bool = False, error_message: str = ErrorMessage.UNPROCESSABLE_MIDI_ERROR.value 86 | ): 87 | def decorator(func: EncodeFunc): 88 | @functools.wraps(func) 89 | def wrapper(*args, **kwargs): 90 | meta_name = _get_meta_name(func.__name__) 91 | if args[0] == constants.UNKNOWN: 92 | if raise_error: 93 | raise UnprocessableMidiError(error_message) 94 | return Unknown.get(meta_name).value 95 | return inject_args_to_encode_func(func, *args, **kwargs) 96 | 97 | return wrapper 98 | 99 | return decorator 100 | 101 | 102 | def add_offset(func): 103 | @functools.wraps(func) 104 | def wrapper(*args, **kwargs): 105 | meta_name = _get_meta_name(func.__name__).upper() 106 | offset_value = Offset.get(meta_name).value 107 | unknown_value = Unknown.get(meta_name).value 108 | result = inject_args_to_encode_func(func, *args, **kwargs) 109 | if result == unknown_value: 110 | return result 111 | return result + offset_value 112 | 113 | return wrapper 114 | 115 | 116 | @register_encoder 117 | @add_offset 118 | @encode_unknown() 119 | def encode_bpm(bpm: Union[int, str]) -> int: 120 | bpm_meta = min(bpm, constants.MAX_BPM) // constants.BPM_INTERVAL 121 | if bpm_meta == 0: 122 | bpm_meta = 1 123 | return bpm_meta 124 | 125 | 126 | @register_encoder 127 | @add_offset 128 | @encode_unknown() 129 | def encode_audio_key(audio_key: str, encoding_map: Dict[str, int]) -> int: 130 | try: 131 | return encoding_map[audio_key] 132 | except KeyError: 133 | raise UnprocessableMidiError(f"audio key KeyError: {audio_key}") 134 | 135 | 136 | @register_encoder 137 | @add_offset 138 | @encode_unknown() 139 | def encode_time_signature(time_signature: str, encoding_map: Dict[str, int]) -> int: 140 | try: 141 | return encoding_map[time_signature] 142 | except KeyError: 143 | raise UnprocessableMidiError(f"ts KeyError: {time_signature}") 144 | 145 | 146 | @register_encoder 147 | @add_offset 148 | @encode_unknown() 149 | def encode_pitch_range(pitch_range: str, encoding_map: Dict[str, int]) -> int: 150 | try: 151 | return encoding_map[pitch_range] 152 | except KeyError: 153 | raise UnprocessableMidiError(f"pitch range KeyError: {pitch_range}") 154 | 155 | 156 | @register_encoder 157 | @encode_unknown(raise_error=True) 158 | def encode_num_measures(num_measures: Union[float, str]) -> int: 159 | num_measures = math.floor(num_measures) 160 | if num_measures == 4: 161 | return Offset.MEASURES_4.value 162 | elif num_measures == 5: 163 | return Offset.MEASURES_4.value 164 | elif num_measures == 8: 165 | return Offset.MEASURES_8.value 166 | elif num_measures == 9: 167 | return Offset.MEASURES_8.value 168 | elif num_measures == 16: 169 | return Offset.MEASURES_16.value 170 | elif num_measures == 17: 171 | return Offset.MEASURES_16.value 172 | else: 173 | raise UnprocessableMidiError(f"num measures ValueError: {num_measures}") 174 | 175 | 176 | @register_encoder 177 | @add_offset 178 | @encode_unknown() 179 | def encode_inst(inst: Union[int, str], encoding_map: Dict[str, int]) -> int: 180 | try: 181 | return encoding_map[inst] 182 | except KeyError: 183 | raise UnprocessableMidiError(f"inst KeyError: {inst}") 184 | 185 | 186 | @register_encoder 187 | @add_offset 188 | @encode_unknown() 189 | def 
encode_genre(genre: str, encoding_map: Dict[str, int]) -> int: 190 | try: 191 | return encoding_map[genre] 192 | except KeyError: 193 | raise UnprocessableMidiError(f"genre KeyError: {genre}") 194 | 195 | 196 | @register_encoder 197 | @add_offset 198 | @encode_unknown() 199 | def encode_min_velocity(velocity: Union[int, str]): 200 | return math.floor(velocity / constants.VELOCITY_INTERVAL) 201 | 202 | 203 | @register_encoder 204 | @add_offset 205 | @encode_unknown() 206 | def encode_max_velocity(velocity: Union[int, str]): 207 | return math.ceil(velocity / constants.VELOCITY_INTERVAL) 208 | 209 | 210 | @register_encoder 211 | @add_offset 212 | @encode_unknown() 213 | def encode_track_role(track_role: str, encoding_map: Dict[str, int]) -> int: 214 | try: 215 | return encoding_map[track_role] 216 | except KeyError: 217 | raise UnprocessableMidiError(f"track role KeyError: {track_role}") 218 | 219 | 220 | @register_encoder 221 | @add_offset 222 | @encode_unknown() 223 | def encode_rhythm(rhythm: str, encoding_map: Dict[str, int]) -> int: 224 | try: 225 | return encoding_map[rhythm] 226 | except KeyError: 227 | raise UnprocessableMidiError(f"rhythm KeyError: {rhythm}") 228 | 229 | 230 | def encode_meta( 231 | midi_meta: MidiMeta, 232 | ) -> List[int]: 233 | encoding_maps = DEFAULT_ENCODING_MAPS 234 | result = [] 235 | for meta_name in META_ENCODING_ORDER: 236 | encoded_meta = inject_args_to_encode_func( 237 | ENCODERS[meta_name], 238 | getattr(midi_meta, meta_name), 239 | encoding_maps.get(meta_name), 240 | ) 241 | result.append(encoded_meta) 242 | return result 243 | 244 | 245 | class MetaEncoder: 246 | def __init__(self): 247 | pass 248 | 249 | def encode(self, midi_meta: MidiMeta) -> List[int]: 250 | return encode_meta(midi_meta) 251 | -------------------------------------------------------------------------------- /commu/preprocessor/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .meta import MetaParser 2 | -------------------------------------------------------------------------------- /commu/preprocessor/parser/meta.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import re 3 | from typing import Any, Dict 4 | from ..utils.container import MidiMeta 5 | 6 | class MetaParser: 7 | def __init__(self): 8 | pass 9 | 10 | def parse(self, meta_dict: Dict[str, Any]) -> MidiMeta: 11 | copied_meta_dict = copy.deepcopy(meta_dict) 12 | copied_meta_dict["inst"] = remove_number_from_inst(copied_meta_dict["inst"]) 13 | 14 | copied_meta_dict["chord_progression"] = copied_meta_dict.pop("chord_progressions")[0] 15 | 16 | midi_meta = MidiMeta( 17 | **copied_meta_dict, 18 | ) 19 | return midi_meta 20 | 21 | def remove_number_from_inst(inst: str) -> str: 22 | """`{inst}-[0-9]` => `{inst}`""" 23 | inst_number_pattern = re.compile("-[0-9]+") 24 | return inst_number_pattern.sub("", inst) 25 | -------------------------------------------------------------------------------- /commu/preprocessor/pipeline.py: -------------------------------------------------------------------------------- 1 | import time 2 | from multiprocessing import cpu_count 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | from logger import logger 7 | from .encoder import EventSequenceEncoder, MetaEncoder 8 | from .parser import MetaParser 9 | from .preprocessor import Preprocessor 10 | 11 | 12 | class PreprocessPipeline: 13 | def __init__(self): 14 | pass 15 | 16 | def __call__( 17 | self, 18 | root_dir: 
Union[str, Path], 19 | csv_path: Union[str, Path], 20 | num_cores: int = max(4, cpu_count() - 2), 21 | ): 22 | meta_parser = MetaParser() 23 | meta_encoder = MetaEncoder() 24 | event_sequence_encoder = EventSequenceEncoder() 25 | preprocessor = Preprocessor( 26 | meta_parser=meta_parser, 27 | meta_encoder=meta_encoder, 28 | event_sequence_encoder=event_sequence_encoder, 29 | csv_path=csv_path, 30 | ) 31 | logger.info(f"Initialized preprocessor") 32 | logger.info("Start preprocessing") 33 | start_time = time.perf_counter() 34 | preprocessor.preprocess( 35 | root_dir=root_dir, 36 | num_cores=num_cores, 37 | ) 38 | end_time = time.perf_counter() 39 | logger.info(f"Finished preprocessing in {end_time - start_time:.3f}s") 40 | -------------------------------------------------------------------------------- /commu/preprocessor/preprocessor.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import enum 3 | import os 4 | import shutil 5 | import tempfile 6 | from ast import literal_eval 7 | from dataclasses import dataclass, field, fields 8 | from pathlib import Path 9 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union 10 | 11 | import miditoolkit 12 | import numpy as np 13 | import pandas as pd 14 | import parmap 15 | 16 | from . import augment 17 | from .utils import sync_key_augment 18 | from .utils.exceptions import UnprocessableMidiError 19 | from .encoder import MetaEncoder, EventSequenceEncoder 20 | from .parser import MetaParser 21 | 22 | MIDI_EXTENSIONS = (".mid", ".MID", ".midi", ".MIDI") 23 | 24 | 25 | class OutputSubDirName(str, enum.Enum): 26 | RAW = "raw" 27 | ENCODE_NPY = "output_npy" 28 | ENCODE_TMP = "npy_tmp" 29 | 30 | 31 | class SubDirName(str, enum.Enum): 32 | RAW = "raw" 33 | ENCODE_NPY = "output_npy" 34 | ENCODE_TMP = "npy_tmp" 35 | AUGMENTED_TMP = "augmented_tmp" 36 | AUGMENTED = "augmented" 37 | 38 | 39 | @dataclass 40 | class OutputSubDirectory: 41 | encode_npy: Union[str, Path] 42 | encode_tmp: Union[str, Path] 43 | 44 | 45 | @dataclass 46 | class SubDirectory: 47 | raw: Union[str, Path] 48 | encode_npy: Union[str, Path] 49 | encode_tmp: Union[str, Path] 50 | augmented_tmp: Optional[Union[str, Path]] = field(default=None) 51 | augmented: Optional[Union[str, Path]] = field(default=None) 52 | 53 | 54 | def get_output_sub_dir(root_dir: Union[str, Path]) -> OutputSubDirectory: 55 | result = dict() 56 | for name, member in OutputSubDirName.__members__.items(): 57 | output_dir = root_dir.joinpath(member.value) 58 | output_dir.mkdir(exist_ok=True, parents=True) 59 | result[name.lower()] = output_dir 60 | return OutputSubDirectory(**result) 61 | 62 | 63 | def get_sub_dir( 64 | root_dir: Union[str, Path], split: Optional[str]) -> SubDirectory: 65 | result = dict() 66 | for name, member in SubDirName.__members__.items(): 67 | if split is None: 68 | sub_dir = root_dir.joinpath(member.value) 69 | else: 70 | sub_dir = root_dir.joinpath(split).joinpath(member.value) 71 | sub_dir.mkdir(exist_ok=True, parents=True) 72 | result[name.lower()] = sub_dir 73 | return SubDirectory(**result) 74 | 75 | 76 | @dataclass 77 | class EncodingOutput: 78 | meta: np.ndarray 79 | event_sequence: np.ndarray 80 | 81 | 82 | class Preprocessor: 83 | def __init__( 84 | self, 85 | meta_parser: MetaParser, 86 | meta_encoder: MetaEncoder, 87 | event_sequence_encoder: EventSequenceEncoder, 88 | csv_path: str, 89 | ): 90 | self.meta_parser = meta_parser 91 | self.meta_encoder = meta_encoder 92 | self.event_sequence_encoder = 
event_sequence_encoder 93 | self.csv_path = csv_path 94 | 95 | def augment_data( 96 | self, 97 | source_dir: Union[str, Path], 98 | augmented_dir: Union[str, Path], 99 | augmented_tmp_dir: Union[str, Path], 100 | num_cores: int, 101 | ): 102 | augment.augment_data( 103 | midi_path=str(source_dir), 104 | augmented_dir=str(augmented_dir), 105 | augmented_tmp_dir=str(augmented_tmp_dir), 106 | num_cores=num_cores, 107 | ) 108 | 109 | def encode_event_sequence(self, midi_path: Union[str, Path], sample_info: Dict) -> np.ndarray: 110 | with tempfile.NamedTemporaryFile(suffix=Path(midi_path).suffix) as f: 111 | midi_obj = miditoolkit.MidiFile(midi_path) 112 | for idx in range(len(midi_obj.instruments)): 113 | try: 114 | if midi_obj.instruments[idx].name == "chord": 115 | midi_obj.instruments.pop(idx) 116 | except IndexError: 117 | continue 118 | midi_obj.dump(f.name) 119 | event_sequence = np.array(self.event_sequence_encoder.encode(midi_path, sample_info=sample_info)) 120 | return event_sequence 121 | 122 | def preprocess( 123 | self, 124 | root_dir: Union[str, Path], 125 | num_cores: int, 126 | data_split: Tuple[str] = ("train", "val",), 127 | ): 128 | default_sub_dir = get_sub_dir(root_dir, split=None) 129 | fetched_samples = pd.read_csv(self.csv_path, 130 | converters={"chord_progressions": literal_eval}) 131 | 132 | for empty_dir in fields(default_sub_dir): 133 | if empty_dir.name in ("encode_npy",): 134 | continue 135 | else: 136 | try: 137 | os.rmdir(root_dir.joinpath(getattr(SubDirName, empty_dir.name.upper()).value)) 138 | except FileNotFoundError: 139 | continue 140 | 141 | for split in data_split: 142 | split_sub_dir = get_sub_dir(root_dir, split=split) 143 | self.augment_data( 144 | source_dir=split_sub_dir.raw, 145 | augmented_dir=split_sub_dir.augmented, 146 | augmented_tmp_dir=split_sub_dir.augmented_tmp, 147 | num_cores=num_cores, 148 | ) 149 | 150 | sample_id_to_path = self._gather_sample_files( 151 | *(split_sub_dir.raw, split_sub_dir.augmented)) 152 | 153 | self.export_encoded_midi( 154 | fetched_samples=fetched_samples, 155 | encoded_tmp_dir=split_sub_dir.encode_tmp, 156 | sample_id_to_path=sample_id_to_path, 157 | num_cores=num_cores, 158 | ) 159 | 160 | input_npy, target_npy = self.concat_npy(split_sub_dir.encode_tmp) 161 | np.save(str(default_sub_dir.encode_npy.joinpath(f"input_{split}.npy")), input_npy, allow_pickle=True) 162 | np.save(str(default_sub_dir.encode_npy.joinpath(f"target_{split}.npy")), target_npy, allow_pickle=True) 163 | 164 | for empty_dir in os.listdir(root_dir.joinpath(split)): 165 | if empty_dir in ("raw", "npy_tmp", "augmented", "augmented_tmp"): 166 | continue 167 | else: 168 | shutil.rmtree(root_dir.joinpath(split).joinpath(empty_dir)) 169 | 170 | def export_encoded_midi( 171 | self, 172 | fetched_samples: Union[pd.DataFrame, List[Dict[str, Any]]], 173 | sample_id_to_path: Dict[str, str], 174 | encoded_tmp_dir: Union[str, Path], 175 | num_cores: int, 176 | ) -> None: 177 | sample_infos_chunk = [ 178 | (idx, arr.tolist()) 179 | for idx, arr in enumerate(np.array_split(np.array(fetched_samples.to_dict('records')), num_cores)) 180 | ] 181 | parmap.map( 182 | self._preprocess_midi_chunk, 183 | sample_infos_chunk, 184 | sample_id_to_path=sample_id_to_path, 185 | encode_tmp_dir=encoded_tmp_dir, 186 | pm_pbar=True, 187 | pm_processes=num_cores, 188 | ) 189 | 190 | def _preprocess_midi_chunk( 191 | self, 192 | idx_sample_infos_chunk: Tuple[int, Iterable[Dict[str, Any]]], 193 | sample_id_to_path: Dict[str, str], 194 | encode_tmp_dir: Union[str, Path], 195 | ): 
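# idx_sample_infos_chunk is a (chunk index, list of sample-info dicts) pair built from the
# metadata CSV; the chunk index names the npy_tmp subdirectory below. Augmented samples are
# appended to the chunk here, with ids assumed to follow the "{parent_id}_{audio_key}_{bpm}"
# stem produced by the augmentation step; their metadata is rebuilt from the matching parent
# sample before encoding.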
196 | idx, sample_infos_chunk = idx_sample_infos_chunk 197 | copied_sample_infos_chunk = copy.deepcopy(list(sample_infos_chunk)) 198 | parent_sample_ids_to_info = { 199 | sample_info["id"]: sample_info for sample_info in copied_sample_infos_chunk 200 | } 201 | parent_sample_ids = set(parent_sample_ids_to_info.keys()) 202 | 203 | copied_sample_infos_chunk.extend( 204 | [ 205 | {"id": sample_id, "augmented": True} 206 | for sample_id in sample_id_to_path.keys() 207 | if sample_id.split("_")[0] in parent_sample_ids 208 | ] 209 | ) 210 | 211 | encode_tmp_dir = Path(encode_tmp_dir) 212 | for sample_info_idx, sample_info in enumerate(copied_sample_infos_chunk): 213 | copied_sample_info = sample_info 214 | if sample_info.get("augmented", False): 215 | id_split = copied_sample_info["id"].split("_") 216 | bpm = copied_sample_info.get("bpm") 217 | audio_key = copied_sample_info.get("audio_key") 218 | if len(id_split) > 1: 219 | parent_sample_id, audio_key, bpm = id_split 220 | else: 221 | parent_sample_id = id_split[0] 222 | 223 | if bpm is None or audio_key is None: 224 | continue 225 | 226 | augmented_midi_path = sample_id_to_path[copied_sample_info["id"]] 227 | copied_sample_info = copy.deepcopy(parent_sample_ids_to_info[parent_sample_id]) 228 | copied_sample_info["bpm"] = int(bpm) 229 | # key_origin = copied_sample_info["audio_key"] + copied_sample_info["chord_type"] in ["cmajor", "aminor"] 230 | # revised key_origin check 231 | key_origin = copied_sample_info["audio_key"] in ["cmajor", "aminor"] 232 | 233 | if not key_origin: 234 | continue 235 | try: 236 | copied_sample_info["chord_progressions"] = sync_key_augment( 237 | copied_sample_info["chord_progressions"][0], 238 | audio_key.replace("minor", "").replace("major", ""), 239 | copied_sample_info["audio_key"][0], # use the first character of audio_key 240 | ) 241 | except IndexError: 242 | print(f"chord progression info is unknown: {augmented_midi_path}") 243 | continue 244 | copied_sample_info["audio_key"] = audio_key 245 | copied_sample_info["rhythm"] = copied_sample_info.get("sample_rhythm") 246 | # add the is_incomplete_measure column 247 | if copied_sample_info["num_measures"] % 4 == 0: 248 | copied_sample_info["is_incomplete_measure"] = False 249 | else: 250 | copied_sample_info["is_incomplete_measure"] = True 251 | 252 | midi_path = sample_id_to_path.get(copied_sample_info["id"]) 253 | if midi_path is None: 254 | continue 255 | try: 256 | encoding_output = self._preprocess_midi( 257 | sample_info=copied_sample_info, midi_path=augmented_midi_path 258 | ) 259 | except (IndexError, TypeError) as e: 260 | print(f"{e}: {augmented_midi_path}") 261 | continue 262 | except ValueError: 263 | print(f"num measures not allowed: {augmented_midi_path}") 264 | continue 265 | output_dir = encode_tmp_dir.joinpath(f"{idx:04d}") 266 | output_dir.mkdir(exist_ok=True, parents=True) 267 | try: 268 | np.save( 269 | os.path.join(output_dir, f"input_{sample_info_idx}"), encoding_output.meta 270 | ) 271 | np.save( 272 | os.path.join(output_dir, f"target_{sample_info_idx}"), encoding_output.event_sequence 273 | ) 274 | except AttributeError: 275 | continue 276 | 277 | def _preprocess_midi( 278 | self, sample_info: Dict[str, Any], midi_path: Union[str, Path] 279 | ) -> Optional[EncodingOutput]: 280 | midi_meta = self.meta_parser.parse(meta_dict=sample_info) 281 | try: 282 | encoded_meta: List[Union[int, str]] = self.meta_encoder.encode(midi_meta) 283 | except UnprocessableMidiError as e: 284 | print(f"{e}: {midi_path}") 285 | return None 286 | encoded_meta: np.ndarray = np.array(encoded_meta,
dtype=object) 287 | encoded_event_sequence = np.array( 288 | self.encode_event_sequence(midi_path, sample_info), dtype=np.int16 289 | ) 290 | return EncodingOutput(meta=encoded_meta, event_sequence=encoded_event_sequence) 291 | 292 | @staticmethod 293 | def _gather_sample_files(*source_dirs: Union[str, Path]) -> Dict[str, str]: 294 | def _gather(_source_dir): 295 | return { 296 | filename.stem: str(filename) 297 | for filename in Path(_source_dir).rglob("**/*") 298 | if filename.suffix in MIDI_EXTENSIONS 299 | } 300 | 301 | result = dict() 302 | for source_dir in source_dirs: 303 | result.update(_gather(source_dir)) 304 | return result 305 | 306 | @staticmethod 307 | def concat_npy(source_dir: Union[str, Path]) -> Tuple[List[np.ndarray], List[np.ndarray]]: 308 | def _gather(_prefix) -> List[str]: 309 | npy_suffix = ".npy" 310 | return sorted( 311 | str(f) 312 | for f in Path(source_dir).rglob("**/*") 313 | if f.suffix == npy_suffix and f.stem.startswith(_prefix) 314 | ) 315 | 316 | def _concat(_npy_list: List[str]) -> List[np.ndarray]: 317 | return [np.load(_p, allow_pickle=True) for _p in _npy_list] 318 | 319 | return _concat(_gather("input")), _concat(_gather("target")) 320 | -------------------------------------------------------------------------------- /commu/preprocessor/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | -------------------------------------------------------------------------------- /commu/preprocessor/utils/constants.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from typing import List, Optional 3 | 4 | class KeySwitchVelocity(int, enum.Enum): 5 | DEFAULT = 1 6 | 7 | @classmethod 8 | def get_value(cls, key: Optional[str]) -> int: 9 | key = key or "DEFAULT" 10 | if hasattr(cls, key): 11 | return getattr(cls, key).value 12 | return cls.DEFAULT.value 13 | 14 | class ChordType(str, enum.Enum): 15 | MAJOR = "major" 16 | MINOR = "minor" 17 | 18 | @classmethod 19 | def values(cls) -> List[str]: 20 | return list(cls.__members__.values()) 21 | 22 | BPM_INTERVAL = 5 23 | CHORD_TRACK_NAME = "chord" 24 | DEFAULT_NUM_BEATS = 4 25 | DEFAULT_POSITION_RESOLUTION = 128 26 | DEFAULT_TICKS_PER_BEAT = 480 27 | MAX_BPM = 200 28 | NUM_BPM_AUGMENT = 2 29 | NUM_KEY_AUGMENT = 6 30 | UNKNOWN = "unknown" 31 | VELOCITY_INTERVAL = 2 32 | 33 | MAJOR_KEY = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] 34 | MINOR_KEY = [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] 35 | 36 | KEY_MAP = { 37 | "cmajor": 0, 38 | "c#major": 1, 39 | "dbmajor": 1, 40 | "dmajor": 2, 41 | "d#major": 3, 42 | "ebmajor": 3, 43 | "emajor": 4, 44 | "fmajor": 5, 45 | "f#major": 6, 46 | "gbmajor": 6, 47 | "gmajor": 7, 48 | "g#major": 8, 49 | "abmajor": 8, 50 | "amajor": 9, 51 | "a#major": 10, 52 | "bbmajor": 10, 53 | "bmajor": 11, 54 | "cminor": 12, 55 | "c#minor": 13, 56 | "dbminor": 13, 57 | "dminor": 14, 58 | "d#minor": 15, 59 | "ebminor": 15, 60 | "eminor": 16, 61 | "fminor": 17, 62 | "f#minor": 18, 63 | "gbminor": 18, 64 | "gminor": 19, 65 | "g#minor": 20, 66 | "abminor": 20, 67 | "aminor": 21, 68 | "a#minor": 22, 69 | "bbminor": 22, 70 | "bminor": 23, 71 | } 72 | 73 | KEY_NUM_MAP = {v: k for k, v in KEY_MAP.items()} 74 | 75 | TIME_SIG_MAP = { 76 | "4/4": 0, 77 | "3/4": 1, 78 | "6/8": 2, 79 | "12/8": 3, 80 | } 81 | 82 | SIG_TIME_MAP = {v: k for k, v in TIME_SIG_MAP.items()} 83 | 84 | PITCH_RANGE_MAP = { 85 | "very_low": 0, 86 | "low": 1, 87 | "mid_low": 2, 88 | "mid": 3, 89 | "mid_high": 4, 90 
| "high": 5, 91 | "very_high": 6, 92 | } 93 | 94 | INST_MAP = { 95 | "accordion": 1, 96 | "acoustic_bass": 3, 97 | "acoustic_guitar": 3, 98 | "acoustic_piano": 0, 99 | "banjo": 3, 100 | "bassoon": 5, 101 | "bell": 2, 102 | "brass_ensemble": 5, 103 | "celesta": 2, 104 | "choir": 7, 105 | "clarinet": 5, 106 | "drums_full": 6, 107 | "drums_tops": 6, 108 | "electric_bass": 3, 109 | "electric_guitar_clean": 3, 110 | "electric_guitar_distortion": 3, 111 | "electric_piano": 0, 112 | "fiddle": 4, 113 | "flute": 5, 114 | "glockenspiel": 2, 115 | "harp": 3, 116 | "harpsichord": 0, 117 | "horn": 5, 118 | "keyboard": 0, 119 | "mandolin": 3, 120 | "marimba": 2, 121 | "nylon_guitar": 3, 122 | "oboe": 5, 123 | "organ": 0, 124 | "oud": 3, 125 | "pad_synth": 4, 126 | "percussion": 6, 127 | "recorder": 5, 128 | "sitar": 3, 129 | "string_cello": 4, 130 | "string_double_bass": 4, 131 | "string_ensemble": 4, 132 | "string_viola": 4, 133 | "string_violin": 4, 134 | "synth_bass": 3, 135 | "synth_bass_808": 3, 136 | "synth_bass_wobble": 3, 137 | "synth_bell": 2, 138 | "synth_lead": 1, 139 | "synth_pad": 4, 140 | "synth_pluck": 7, 141 | "synth_voice": 7, 142 | "timpani": 6, 143 | "trombone": 5, 144 | "trumpet": 5, 145 | "tuba": 5, 146 | "ukulele": 3, 147 | "vibraphone": 2, 148 | "whistle": 7, 149 | "xylophone": 2, 150 | "zither": 3, 151 | "orgel": 2, 152 | "synth_brass": 5, 153 | "sax": 5, 154 | "bamboo_flute": 5, 155 | "yanggeum": 3, 156 | "vocal": 8, 157 | } 158 | 159 | GENRE_MAP = { 160 | "newage": 0, 161 | "cinematic": 1, 162 | } 163 | 164 | TRACK_ROLE_MAP = { 165 | "main_melody": 0, 166 | "sub_melody": 1, 167 | "accompaniment": 2, 168 | "bass": 3, 169 | "pad": 4, 170 | "riff": 5, 171 | } 172 | 173 | RHYTHM_MAP = { 174 | "standard": 0, 175 | "triplet": 1, 176 | } 177 | 178 | -------------------------------------------------------------------------------- /commu/preprocessor/utils/container.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Union 3 | 4 | from pydantic import BaseModel 5 | 6 | @dataclass 7 | class MidiInfo: 8 | # meta 9 | bpm: int 10 | audio_key: int 11 | time_signature: int 12 | pitch_range: int 13 | num_measures: int 14 | inst: int 15 | genre: str 16 | min_velocity: int 17 | max_velocity: int 18 | track_role: int 19 | rhythm: int 20 | # event 21 | event_seq: List[int] 22 | 23 | class MidiMeta(BaseModel): 24 | bpm: int 25 | audio_key: str 26 | time_signature: str 27 | pitch_range: str 28 | num_measures: float 29 | inst: str 30 | genre: str 31 | min_velocity: int 32 | max_velocity: int 33 | track_role: str 34 | rhythm: str 35 | 36 | -------------------------------------------------------------------------------- /commu/preprocessor/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | 4 | class ErrorMessage(str, enum.Enum): 5 | UNPROCESSABLE_MIDI_ERROR = "Unprocessable midi" 6 | 7 | 8 | class DioaiError(Exception): 9 | pass 10 | 11 | 12 | class UnprocessableMidiError(DioaiError): 13 | """Unprocessable Midi""" 14 | -------------------------------------------------------------------------------- /commu/preprocessor/utils/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | from typing import Optional, Tuple, Union 4 | 5 | import miditoolkit 6 | 7 | from .constants import ( 8 | CHORD_TRACK_NAME, 9 | UNKNOWN, 10 | ) 11 | 12 | def 
get_velocity_range( 13 | midi_path: Union[str, Path], keyswitch_velocity: Optional[int] = None 14 | ) -> Tuple[Union[int, str], Union[int, str]]: 15 | midi = miditoolkit.MidiFile(str(midi_path)) 16 | velocities = [] 17 | for track in midi.instruments: 18 | if track.name == CHORD_TRACK_NAME: 19 | continue 20 | for note in track.notes: 21 | if keyswitch_velocity is not None: 22 | if note.velocity != keyswitch_velocity: 23 | velocities.append(note.velocity) 24 | else: 25 | velocities.append(note.velocity) 26 | 27 | if not velocities or max(velocities) == 0: 28 | return UNKNOWN, UNKNOWN 29 | return min(velocities), max(velocities) 30 | 31 | def get_time_signature(midi_path): 32 | time_signature = miditoolkit.MidiFile(midi_path).time_signature_changes[0] 33 | numerator = time_signature.numerator 34 | denominator = time_signature.denominator 35 | return numerator, denominator 36 | 37 | def sync_key_augment(chords, aug_key, origin_key): 38 | chord_lst = [ 39 | "a", 40 | "a#", 41 | "b", 42 | "c", 43 | "c#", 44 | "d", 45 | "d#", 46 | "e", 47 | "f", 48 | "f#", 49 | "g", 50 | "g#", 51 | "ab", 52 | "bb", 53 | "db", 54 | "eb", 55 | "gb", 56 | ] 57 | chord2symbol = {k: v for k, v in zip(chord_lst, range(12))} 58 | chord2symbol["ab"] = 11 59 | chord2symbol["bb"] = 1 60 | chord2symbol["db"] = 4 61 | chord2symbol["eb"] = 6 62 | chord2symbol["gb"] = 9 63 | symbol2chord = {v: k for k, v in chord2symbol.items()} 64 | 65 | basic_chord = [] 66 | for c in chords: 67 | match = re.match(r"[A-G](#|b|)", c) 68 | basic_chord.append(match[0]) 69 | 70 | chord_type = [c.replace(b, "") for c, b in zip(chords, basic_chord)] 71 | symbol_lst = [chord2symbol[c.lower()] for c in basic_chord] 72 | 73 | origin_key_symbol = chord2symbol[origin_key] 74 | 75 | augment_key_symbol = chord2symbol[aug_key] 76 | 77 | key_diff = origin_key_symbol - augment_key_symbol 78 | key_change = abs(key_diff) 79 | if key_diff < 0: 80 | new_symbol_lst = [] 81 | for s in symbol_lst: 82 | new_s = s + key_change 83 | if new_s >= 12: 84 | new_s = new_s - 12 85 | new_symbol_lst.append(new_s) 86 | else: 87 | new_symbol_lst = [] 88 | for s in symbol_lst: 89 | new_s = s - key_change 90 | if new_s < 0: 91 | new_s = new_s + 12 92 | new_symbol_lst.append(new_s) 93 | 94 | new_chord_lst = [symbol2chord[s] for s in new_symbol_lst] 95 | new_chord_lst = [c + t for c, t in zip(new_chord_lst, chord_type)] 96 | return [new_chord_lst] -------------------------------------------------------------------------------- /dataset/commu_midi.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POZAlabs/ComMU-code/3949a5b5a1a54e2bb0fb9d600ecc00cd55660408/dataset/commu_midi.tar -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Dict 3 | 4 | from commu.midi_generator.generate_pipeline import MidiGenerationPipeline 5 | from commu.preprocessor.utils import constants 6 | 7 | 8 | def parse_args() -> Dict[str, argparse.ArgumentParser]: 9 | model_arg_parser = argparse.ArgumentParser(description="Model Arguments") 10 | input_arg_parser = argparse.ArgumentParser(description="Input Arguments") 11 | 12 | # Model Arguments 13 | model_arg_parser.add_argument("--checkpoint_dir", type=str) 14 | 15 | # Input Arguments 16 | input_arg_parser.add_argument("--output_dir", type=str, required=True) 17 | 18 | ## Input meta 19 | input_arg_parser.add_argument("--bpm", 
type=int) 20 | input_arg_parser.add_argument("--audio_key", type=str, choices=list(constants.KEY_MAP.keys())) 21 | input_arg_parser.add_argument("--time_signature", type=str, choices=list(constants.TIME_SIG_MAP.keys())) 22 | input_arg_parser.add_argument("--pitch_range", type=str, choices=list(constants.PITCH_RANGE_MAP.keys())) 23 | input_arg_parser.add_argument("--num_measures", type=float) 24 | input_arg_parser.add_argument( 25 | "--inst", type=str, choices=list(constants.INST_MAP.keys()), 26 | ) 27 | input_arg_parser.add_argument( 28 | "--genre", type=str, default="cinematic", choices=list(constants.GENRE_MAP.keys()) 29 | ) 30 | input_arg_parser.add_argument( 31 | "--track_role", type=str, choices=list(constants.TRACK_ROLE_MAP.keys()) 32 | ) 33 | input_arg_parser.add_argument( 34 | "--rhythm", type=str, default="standard", choices=list(constants.RHYTHM_MAP.keys()) 35 | ) 36 | input_arg_parser.add_argument("--min_velocity", type=int, choices=range(1, 128)) 37 | input_arg_parser.add_argument("--max_velocity", type=int, choices=range(1, 128)) 38 | input_arg_parser.add_argument( 39 | "--chord_progression", type=str, help='Chord progression, e.g. C-C-E-E-G-G ...' 40 | ) 41 | # Information required for inference 42 | input_arg_parser.add_argument("--num_generate", type=int) 43 | input_arg_parser.add_argument("--top_k", type=int, default=32) 44 | input_arg_parser.add_argument("--temperature", type=float, default=0.95) 45 | 46 | arg_dict = { 47 | "model_args": model_arg_parser, 48 | "input_args": input_arg_parser 49 | } 50 | return arg_dict 51 | 52 | 53 | def main(model_args: argparse.Namespace, input_args: argparse.Namespace): 54 | pipeline = MidiGenerationPipeline() 55 | pipeline.initialize_model(vars(model_args)) 56 | pipeline.initialize_generation() 57 | 58 | inference_cfg = pipeline.model_initialize_task.inference_cfg 59 | model = pipeline.model_initialize_task.execute() 60 | 61 | encoded_meta = pipeline.preprocess_task.execute(vars(input_args)) 62 | input_data = pipeline.preprocess_task.input_data 63 | 64 | pipeline.inference_task( 65 | model=model, 66 | input_data=input_data, 67 | inference_cfg=inference_cfg 68 | ) 69 | sequences = pipeline.inference_task.execute(encoded_meta) 70 | 71 | pipeline.postprocess_task(input_data=input_data) 72 | pipeline.postprocess_task.execute( 73 | sequences=sequences, 74 | ) 75 | 76 | 77 | if __name__ == "__main__": 78 | model_args, _ = parse_args()["model_args"].parse_known_args() 79 | input_args, _ = parse_args()["input_args"].parse_known_args() 80 | main(model_args, input_args)
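For reference, the metadata flags defined above map onto an invocation roughly like the sketch below. Every value and path here is a placeholder chosen for illustration (in particular, the checkpoint path is only an assumption about where train.py saved `checkpoint_best.pt`), not something prescribed by the repository:
```
python generate.py \
  --checkpoint_dir ./workdir/checkpoint_best.pt \
  --output_dir ./output \
  --bpm 70 \
  --audio_key aminor \
  --time_signature 4/4 \
  --pitch_range mid_high \
  --num_measures 8 \
  --inst acoustic_piano \
  --genre newage \
  --track_role main_melody \
  --rhythm standard \
  --min_velocity 60 \
  --max_velocity 80 \
  --chord_progression Am-Am-F-F-C-C-G-G-Am-Am-F-F-C-C-G-G \
  --num_generate 3 \
  --top_k 32 \
  --temperature 0.95
```
Each categorical value must be one of the keys of the corresponding map in commu/preprocessor/utils/constants.py, which is exactly what the `choices=` arguments above enforce.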
root_parser.add_argument("--root_dir", type=str, required=True, help="root directory containing 'raw' directory") 10 | root_parser.add_argument("--csv_path", type=str, required=True, help="csv file path containing meta info") 11 | root_parser.add_argument("--num_cores", type=int, default=max(1, cpu_count() - 4)) 12 | return root_parser 13 | 14 | 15 | def main(args: argparse.Namespace) -> None: 16 | root_dir = Path(args.root_dir).expanduser() 17 | pipeline = PreprocessPipeline() 18 | pipeline( 19 | root_dir=root_dir, 20 | csv_path=args.csv_path, 21 | num_cores=args.num_cores, 22 | ) 23 | 24 | 25 | if __name__ == "__main__": 26 | import warnings 27 | warnings.filterwarnings("ignore") 28 | 29 | parser = get_root_parser() 30 | known_args, _ = parser.parse_known_args() 31 | main(known_args) 32 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | miditoolkit 3 | mido 4 | music21 5 | pandas 6 | parmap 7 | pretty-midi 8 | pydantic 9 | scipy 10 | split-folders 11 | torch 12 | yacs -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import contextlib 3 | import math 4 | import os 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | 14 | from logger import logger 15 | from commu.model.config_helper import get_default_cfg_training 16 | from commu.model.dataset import ComMUDataset 17 | from commu.model.exp_utils import logging_config 18 | from commu.model.model import MemTransformerLM 19 | 20 | @contextlib.contextmanager 21 | def sync_workers(args): 22 | """ 23 | Yields distributed rank and synchronizes all workers on exit. 24 | """ 25 | yield args.local_rank 26 | dist.barrier() 27 | 28 | 29 | def save_checkpoint( 30 | args, 31 | model, 32 | optimizer, 33 | vocab, 34 | train_step, 35 | best_val_loss, 36 | scheduler, 37 | name="checkpoint.pt", 38 | ): 39 | checkpoint = { 40 | "model": model.module.state_dict(), 41 | "optimizer": optimizer.state_dict(), 42 | "train_step": train_step, 43 | "scheduler": scheduler.state_dict(), 44 | "best_val_loss": best_val_loss, 45 | "vocab": vocab, 46 | } 47 | 48 | checkpoint["amp"] = None 49 | 50 | with sync_workers(args) as rank: 51 | path = os.path.join(args.work_dir, name) 52 | logger.info(f"Saving checkpoint to {path}") 53 | if rank == 0: 54 | torch.save(checkpoint, path) 55 | 56 | 57 | def parse_args(): 58 | parser = argparse.ArgumentParser(description="PyTorch Transformer Language Model") 59 | parser.add_argument( 60 | "--data_dir", type=str, required=True, help="location of the data corpus" 61 | ) 62 | parser.add_argument("--local_rank", type=int, default=0) 63 | parser.add_argument( 64 | "--work_dir", 65 | type=str, 66 | required=True, 67 | help="Base directory to save the trained model.", 68 | ) 69 | args = parser.parse_args() 70 | return args 71 | 72 | 73 | 74 | def evaluate(eval_iter): 75 | # Turn on evaluation mode def disables dropout. 
76 | model.eval() 77 | 78 | if isinstance(model, DDP): 79 | eval_model = model.module 80 | else: 81 | eval_model = model 82 | 83 | eval_model.reset_length( 84 | tgt_len=cfg.EVALUATE.tgt_length, mem_len=cfg.EVALUATE.mem_length) 85 | eval_model.same_length = True 86 | # Evaluation 87 | total_token_num = 0 88 | total_nll = 0.0 89 | 90 | with torch.no_grad(): 91 | mems = None 92 | 93 | for i, (data, target, all_reset_mem, batch_token_num) in enumerate(eval_iter()): 94 | 95 | if all_reset_mem: 96 | mems = None 97 | 98 | ret = model(data, target, None, mems) 99 | loss, mems = ret 100 | loss = loss[target != dataset.vocab.pad_id] 101 | loss = loss.mean() 102 | total_nll += batch_token_num * loss.float().item() 103 | total_token_num += batch_token_num 104 | 105 | eval_model.reset_length(cfg.TRAIN.tgt_length, cfg.TRAIN.mem_length) 106 | eval_model.same_length = cfg.MODEL.same_length 107 | 108 | model.train() 109 | 110 | return total_token_num, total_nll 111 | 112 | 113 | def train(): 114 | global train_step 115 | global best_val_nll 116 | 117 | log_train_loss = torch.tensor(0.0).float().to(device) 118 | log_grad_norm = torch.tensor(0.0).float().to(device) 119 | log_token_num = torch.tensor(0).to(device) 120 | 121 | log_start_time = time.time() 122 | 123 | mems = [None for _ in range(cfg.TRAIN.batch_chunk)] 124 | 125 | assert batch_size % cfg.TRAIN.batch_chunk == 0 126 | train_real_iter = train_iter() 127 | 128 | for batch, (data, target, reset_mems, batch_token_num) in enumerate( 129 | train_real_iter 130 | ): 131 | model.module.temperature = 1.0 132 | 133 | model.zero_grad() 134 | 135 | 136 | data_chunks = torch.chunk(data, cfg.TRAIN.batch_chunk, 1) 137 | target_chunks = torch.chunk(target, cfg.TRAIN.batch_chunk, 1) 138 | reset_mems_chunks = torch.chunk(reset_mems, cfg.TRAIN.batch_chunk, 0) 139 | for i in range(cfg.TRAIN.batch_chunk): 140 | 141 | data = data_chunks[i].contiguous() 142 | target = target_chunks[i].contiguous() 143 | reset_mems = reset_mems_chunks[i].contiguous() 144 | 145 | ret = model(data, target, reset_mems, mems[i]) 146 | loss, mems[i] = ret 147 | 148 | loss = loss[target != dataset.vocab.pad_id] 149 | loss = loss.float().mean() / cfg.TRAIN.batch_chunk 150 | log_train_loss += ( 151 | loss.item() 152 | * (target != dataset.vocab.pad_id).sum() 153 | * cfg.TRAIN.batch_chunk 154 | ) 155 | loss.backward() 156 | 157 | log_token_num += int(batch_token_num) 158 | 159 | grad_norm = torch.nn.utils.clip_grad_norm_( 160 | model.module.parameters(), cfg.TRAIN.clip 161 | ) 162 | 163 | log_grad_norm += grad_norm 164 | optimizer.step() 165 | optimizer.zero_grad() 166 | 167 | # step-wise learning rate annealing 168 | train_step += 1 169 | scheduler.step() 170 | 171 | if train_step % cfg.TRAIN.log_interval == 0: 172 | torch.distributed.all_reduce(log_train_loss) 173 | torch.distributed.all_reduce(log_grad_norm) 174 | torch.distributed.all_reduce(log_token_num) 175 | 176 | log_train_loss /= log_token_num 177 | log_grad_norm /= cfg.TRAIN.log_interval * num_gpus 178 | if args.local_rank == 0: 179 | elapsed = time.time() - log_start_time 180 | logger.info( 181 | "Train Step {}/{}, lr={:f}, tokens/s={:.1f}," 182 | " nll={:.4f}, ppl={:.2f}, grad norm={}, ".format( 183 | train_step, 184 | cfg.TRAIN.max_step, 185 | optimizer.param_groups[0]["lr"], 186 | log_token_num.item() / elapsed, 187 | log_train_loss.item(), 188 | math.exp(log_train_loss.item()), 189 | log_grad_norm.item(), 190 | ) 191 | ) 192 | 193 | log_train_loss[()] = 0 194 | log_grad_norm[()] = 0 195 | log_token_num[()] = 0 196 | 197 | 
log_start_time = time.time() 198 | 199 | if train_step % cfg.TRAIN.eval_interval == 0: 200 | eval_start_time = time.time() 201 | 202 | val_token_num, val_total_nll = evaluate( 203 | eval_iter=val_iter 204 | ) 205 | 206 | val_token_num_pt = torch.tensor(val_token_num).to(device) 207 | val_total_nll_pt = torch.tensor(val_total_nll / 10000.0).to(device) 208 | 209 | torch.distributed.all_reduce(val_token_num_pt) 210 | torch.distributed.all_reduce(val_total_nll_pt) 211 | 212 | val_token_num = val_token_num_pt.item() 213 | val_total_nll = val_total_nll_pt.item() 214 | 215 | val_nll = val_total_nll / (val_token_num / 10000.0) 216 | 217 | if args.local_rank == 0: 218 | logger.info( 219 | "Eval step {}, time={}s, val nll={}, val ppl={},".format( 220 | train_step, 221 | time.time() - eval_start_time, 222 | val_nll, 223 | math.exp(val_nll), 224 | val_token_num, 225 | ) 226 | ) 227 | 228 | name = "checkpoint_last.pt" 229 | save_checkpoint( 230 | args, 231 | model, 232 | optimizer, 233 | dataset.vocab, 234 | train_step, 235 | val_nll, 236 | scheduler, 237 | name, 238 | ) 239 | 240 | if not best_val_nll or val_nll < best_val_nll: 241 | best_val_nll = val_nll 242 | 243 | name = "checkpoint_best.pt" 244 | save_checkpoint( 245 | args, 246 | model, 247 | optimizer, 248 | dataset.vocab, 249 | train_step, 250 | best_val_nll, 251 | scheduler, 252 | name, 253 | ) 254 | 255 | test_start_time = time.time() 256 | 257 | def calculate_test_nll_during_training(test_iter): 258 | 259 | test_token_num, test_total_nll = evaluate( 260 | eval_iter=test_iter 261 | ) 262 | test_token_num_pt = torch.tensor(test_token_num).to(device) 263 | test_total_nll_pt = torch.tensor(test_total_nll / 10000.0).to(device) 264 | torch.distributed.all_reduce(test_token_num_pt) 265 | torch.distributed.all_reduce(test_total_nll_pt) 266 | 267 | test_token_num = test_token_num_pt.item() 268 | test_nll = test_total_nll_pt.item() / (test_token_num / 10000.0) 269 | 270 | return test_token_num, test_nll 271 | 272 | test_token_num, test_nll = calculate_test_nll_during_training(test_iter) 273 | 274 | if args.local_rank == 0: 275 | logger.info( 276 | "Test step {}, time={}s, test nll={}, test ppl={}, #evaluated tokens={}".format( 277 | train_step, 278 | time.time() - test_start_time, 279 | test_nll, 280 | math.exp(test_nll), 281 | test_token_num, 282 | ) 283 | ) 284 | 285 | if train_step == cfg.TRAIN.max_step: 286 | logger.info("-" * 100) 287 | logger.info("End of training") 288 | break 289 | 290 | 291 | def init_weight(weight): 292 | init_std = cfg.INITIALIZER.base_init 293 | nn.init.normal_(weight, 0.0, init_std) 294 | 295 | 296 | def init_embed(weight): 297 | init_std = cfg.INITIALIZER.embed_init 298 | nn.init.normal_(weight, 0.0, init_std) 299 | 300 | 301 | def init_bias(bias): 302 | nn.init.constant_(bias, 0.0) 303 | 304 | 305 | def weights_init(m): 306 | classname = m.__class__.__name__ 307 | if classname.find("Linear") != -1: 308 | if hasattr(m, "weight") and m.weight is not None: 309 | init_weight(m.weight) 310 | if hasattr(m, "bias") and m.bias is not None: 311 | init_bias(m.bias) 312 | elif classname.find("AdaptiveEmbedding") != -1: 313 | if hasattr(m, "emb_projs"): 314 | for i in range(len(m.emb_projs)): 315 | if m.emb_projs[i] is not None: 316 | init_embed(m.emb_projs[i]) 317 | elif classname.find("Embedding") != -1: 318 | if hasattr(m, "weight"): 319 | init_weight(m.weight) 320 | elif classname.find("ProjectedAdaptiveLogSoftmax") != -1: 321 | if hasattr(m, "cluster_weight") and m.cluster_weight is not None: 322 | 
init_weight(m.cluster_weight) 323 | if hasattr(m, "cluster_bias") and m.cluster_bias is not None: 324 | init_bias(m.cluster_bias) 325 | if hasattr(m, "out_projs"): 326 | for i in range(len(m.out_projs)): 327 | if m.out_projs[i] is not None: 328 | init_embed(m.out_projs[i]) 329 | elif classname.find("LayerNorm") != -1: 330 | if hasattr(m, "weight"): 331 | nn.init.normal_(m.weight, 1.0, cfg.INITIALIZER.base_init) 332 | if hasattr(m, "bias") and m.bias is not None: 333 | init_bias(m.bias) 334 | elif classname.find("TransformerLM") != -1: 335 | if hasattr(m, "r_emb"): 336 | init_weight(m.r_emb) 337 | if hasattr(m, "r_w_bias"): 338 | init_weight(m.r_w_bias) 339 | if hasattr(m, "r_r_bias"): 340 | init_weight(m.r_r_bias) 341 | if hasattr(m, "r_bias"): 342 | init_bias(m.r_bias) 343 | 344 | 345 | def update_dropout(m): 346 | classname = m.__class__.__name__ 347 | if classname.find("Dropout") != -1: 348 | if hasattr(m, "p"): 349 | m.p = cfg.MODEL.dropout 350 | 351 | 352 | def update_dropatt(m): 353 | if hasattr(m, "dropatt"): 354 | m.dropatt.p = cfg.MODEL.attention_dropout 355 | 356 | 357 | args = parse_args() 358 | cfg = get_default_cfg_training() 359 | torch.cuda.set_device(args.local_rank) 360 | device = torch.device("cuda", args.local_rank) 361 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 362 | 363 | exp_time = torch.tensor(time.time(), dtype=torch.float64).to(device) 364 | torch.distributed.broadcast(exp_time, 0) 365 | exp_time = float(exp_time.cpu().numpy()) 366 | 367 | args.work_dir = os.path.join( 368 | args.work_dir, time.strftime("%Y%m%d-%H%M%S", time.localtime(exp_time)) 369 | ) 370 | 371 | os.makedirs(args.work_dir, exist_ok=True) 372 | 373 | if args.local_rank == 0: 374 | with open(os.path.join(args.work_dir, "config.yml"), "w") as f: 375 | f.write(str(cfg)) 376 | 377 | if args.local_rank == 0: 378 | logging_config(args.work_dir, "train_rank{}".format(args.local_rank), console=True) 379 | else: 380 | logging_config(args.work_dir, "train_rank{}".format(args.local_rank), console=False) 381 | 382 | seed = cfg.TRAIN.seed 383 | np.random.seed(seed) 384 | torch.manual_seed(seed) 385 | torch.cuda.manual_seed_all(seed) 386 | 387 | ############################################################################### 388 | # Load data 389 | ############################################################################### 390 | logger.info("Loading data") 391 | dataset = ComMUDataset(args.data_dir, cfg) 392 | vocab = dataset.vocab 393 | 394 | local_seed = cfg.TRAIN.seed + args.local_rank * 1000 395 | num_gpus = torch.cuda.device_count() 396 | assert cfg.TRAIN.batch_size % num_gpus == 0 397 | batch_size = cfg.TRAIN.batch_size // num_gpus 398 | 399 | train_iter = dataset.get_iterator( 400 | batch_size, cfg.TRAIN.tgt_length, device, "train", True, seed=local_seed 401 | ) 402 | val_iter = dataset.eval_iterator( 403 | cfg.EVALUATE.batch_size, 404 | cfg.EVALUATE.tgt_length, 405 | device, 406 | "valid", 407 | local_rank=args.local_rank, 408 | world_size=num_gpus, 409 | ) 410 | test_iter = dataset.eval_iterator( 411 | cfg.EVALUATE.batch_size, 412 | cfg.EVALUATE.tgt_length, 413 | device, 414 | "test", 415 | local_rank=args.local_rank, 416 | world_size=num_gpus, 417 | ) 418 | 419 | 420 | ############################################################################### 421 | # Build the model 422 | ############################################################################### 423 | 424 | logger.info("Build the model") 425 | 426 | assert cfg.MODEL.units % cfg.MODEL.num_heads == 0 427 | 
model = MemTransformerLM(cfg, vocab) 428 | model.apply(weights_init) 429 | model.word_emb.apply( 430 | weights_init 431 | ) # ensure embedding init is not overridden by out_layer in case of weight sharing 432 | 433 | args.n_all_param = sum([p.nelement() for p in model.parameters()]) 434 | args.n_nonemb_param_gen = sum( 435 | [p.nelement() for p in model.layers.parameters()] 436 | ) 437 | 438 | model = model.to(device) 439 | 440 | # MLE optimizer 441 | local_lr = cfg.TRAIN.lr / num_gpus 442 | optimizer = optim.Adam(model.parameters(), lr=local_lr, 443 | weight_decay=cfg.TRAIN.weight_decay) 444 | 445 | #### scheduler 446 | 447 | # originally used for Transformer (in Attention is all you need) 448 | def lr_lambda(step): 449 | # return a multiplier instead of a learning rate 450 | if step == 0 and cfg.TRAIN.warmup_step == 0: 451 | return 1.0 452 | else: 453 | return ( 454 | max( 455 | (cfg.TRAIN.warmup_step ** 0.5) / (step ** 0.5), 456 | cfg.TRAIN.lr_min / cfg.TRAIN.lr, 457 | ) 458 | if step > cfg.TRAIN.warmup_step 459 | else step / cfg.TRAIN.warmup_step 460 | ) 461 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 462 | 463 | 464 | train_step = 0 465 | best_val_nll = np.inf 466 | 467 | model = DDP( 468 | model, 469 | device_ids=[args.local_rank], 470 | output_device=args.local_rank, 471 | broadcast_buffers=False, 472 | find_unused_parameters=False, 473 | ) 474 | 475 | logger.info("=" * 100) 476 | logger.info(args) 477 | logger.info("=" * 100) 478 | logger.info("#total params = {}".format(args.n_all_param)) 479 | logger.info("#non emb params in generator = {}".format(args.n_nonemb_param_gen)) 480 | 481 | ############################################################################### 482 | # Training code 483 | ############################################################################### 484 | logger.info("Start training") 485 | 486 | if __name__ == "__main__": 487 | train() 488 | # Load the best saved model. 489 | cfg.defrost() 490 | cfg.MODEL.same_length = True 491 | cfg.freeze() 492 | model = MemTransformerLM(cfg, dataset._vocab) 493 | checkpoint = torch.load(os.path.join(args.work_dir, "checkpoint_best.pt")) 494 | 495 | model.load_state_dict(checkpoint["model"]) 496 | # Do the evaluation of the best model 497 | model = model.to(device) 498 | 499 | test_token_num, test_total_nll = evaluate( 500 | eval_iter=test_iter 501 | ) 502 | test_token_num_pt = torch.tensor(test_token_num).to(device) 503 | test_total_nll_pt = torch.tensor(test_total_nll / 10000.0).to(device) 504 | torch.distributed.all_reduce(test_token_num_pt) 505 | torch.distributed.all_reduce(test_total_nll_pt) 506 | test_token_num = test_token_num_pt.item() 507 | test_nll = test_total_nll_pt.item() / (test_token_num / 10000.0) 508 | logger.info("=" * 100) 509 | logger.info( 510 | "| End of training | test nll {:5.2f} | test ppl {:9.3f}".format( 511 | test_nll, math.exp(test_nll) 512 | ) 513 | ) 514 | logger.info("=" * 100) 515 | --------------------------------------------------------------------------------
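Because train.py calls torch.distributed.init_process_group(backend="nccl", init_method="env://") and reads a --local_rank argument, it is meant to be started through a PyTorch distributed launcher rather than run directly. A minimal single-GPU launch might look like the sketch below; the --data_dir and --work_dir paths are assumptions about where preprocess.py wrote its .npy output and where checkpoints should go, and on recent PyTorch versions torchrun would replace torch.distributed.launch:
```
python -m torch.distributed.launch --nproc_per_node=1 train.py \
  --data_dir ./dataset/output_npy \
  --work_dir ./workdir
```
The script writes checkpoint_last.pt and checkpoint_best.pt into a timestamped subdirectory of --work_dir; the best checkpoint is presumably what generate.py's --checkpoint_dir should point at.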