├── .gitignore ├── LICENSE ├── README.md ├── data └── download_link.txt ├── data_utils ├── __init__.py ├── midi_output.py ├── pytorch_datasets │ ├── __init__.py │ ├── accompaniment_dataset.py │ ├── base_class.py │ ├── const.py │ ├── counterpoint_dataset.py │ ├── dataloaders.py │ ├── form_dataset.py │ └── leadsheet_dataset.py ├── read_pop909_data.py ├── tonal_reduction_algo │ ├── __init__.py │ ├── main.py │ ├── postprocess.py │ ├── preprocess.py │ └── shortest_path_algo.py ├── train_valid_split.py └── utils │ ├── __init__.py │ ├── chord_reduction.py │ ├── format_converter.py │ ├── key_analysis.py │ ├── phrase_analysis.py │ ├── read_file.py │ ├── song_analyzer.py │ └── song_data_structure.py ├── experiments ├── __init__.py └── whole_song_gen.py ├── inference ├── __init__.py ├── generation_canvases.py ├── generation_operations.py └── utils.py ├── inference_whole_song.py ├── model ├── __init__.py ├── model_sdf.py └── stable_diffusion │ ├── __init__.py │ ├── latent_diffusion.py │ ├── losses │ ├── __init__.py │ ├── contperceptual.py │ ├── discriminator.py │ ├── lpips.py │ ├── util.py │ └── vqperceptual.py │ ├── model │ ├── autoencoder.py │ ├── autoreg_cond_encoders.py │ ├── external_cond_encoders.py │ ├── pretrained_encoders.py │ ├── unet.py │ └── unet_attention.py │ ├── sampler │ ├── __init__.py │ ├── ddim.py │ ├── ddpm.py │ └── sampler_sdf.py │ └── util.py ├── params ├── __init__.py ├── attrdict.py └── params.py ├── pretrained_models └── download_link.txt ├── results_default └── download_link.txt ├── train ├── __init__.py ├── learner.py └── train_config.py └── train_main.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | !data/download_link.txt 3 | results*/* 4 | !results*/download_link.txt 5 | pretrained_models/* 6 | !pretrained_models/download_link.txt 7 | .idea 8 | __pycache__/ 9 | .DS_Store 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ziyu Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Whole-song Generation 2 | 3 | [Demo](https://wholesonggen.github.io/) | [Paper](https://openreview.net/forum?id=sn7CYWyavh&noteId=3X6BSBDIPB) 4 | 5 | This is the code repository of the paper: 6 | 7 | > Ziyu Wang, Lejun Min, and Gus Xia. Whole-Song Hierarchical Generation of Symbolic Music Using Cascaded Diffusion Models. ICLR 2024. 8 | 9 | 10 | # Status 11 | [Apr-03-2024] The current version provides three main usages: 12 | 1. Training all 4 levels of the cascaded models; 13 | 2. Whole-song generation with a specified Form (i.e., phrase and key); 14 | 3. Whole-song generation with a generated Form (i.e., phrase and key). 15 | 16 | Currently, generation given a prompt (e.g., the first several measures) or with external control is not released. These features will be properly reformatted and released in a future version. 17 | 18 | We only release a portion of the model checkpoints, sufficient for testing. The complete set of checkpoints will be released in future versions. 19 | 20 | # Downloading data and pre-trained checkpoints 21 | The data and pretrained checkpoints can be downloaded and added to the repository. They should be placed in `data/` (training data), `pretrained_models/` (pretrained VAEs), and `results_default/` (cascaded Diffusion Models). Download them using the corresponding links given in the `download_link.txt` files. 22 | 23 | 24 | # Training 25 | Here are the commands to train the four levels of the cascaded Diffusion Models. Use `--external` to control whether an external condition is used during training.
26 | ``` 27 | 28 | # form 29 | python train_main.py --mode frm 30 | # optional 31 | python train_main.py --mode frm --multi_label 32 | 33 | # counterpoint 34 | python train_main.py --mode ctp --autoreg --mask_bg 35 | # with external control 36 | python train_main.py --mode ctp --autoreg --external --mask_bg 37 | 38 | # lead sheet 39 | python train_main.py --mode lsh --autoreg --mask_bg 40 | # with external control 41 | python train_main.py --mode lsh --autoreg --external --mask_bg 42 | 43 | # accompaniment 44 | python train_main.py --mode acc --autoreg --mask_bg 45 | # with external control 46 | python train_main.py --mode acc --autoreg --external --mask_bg 47 | 48 | ``` 49 | 50 | For a more detailed usage, check 51 | ``` 52 | python train_main.py -h 53 | ``` 54 | 55 | 56 | # Inference 57 | 58 | We currently provide functions for whole-song generation with or without form specification. By default, if models are not specified, the models in `results_default` will be used; and the results will be shown in `demo/`. 59 | 60 | To generate `n` (e.g., `n=4`) pieces of a given form (e.g., `i4A4A4B8b4A4B8o4` and G major) 61 | ``` 62 | python inference_whole_song.py --nsample 4 --pstring i4A4A4B8b4A4B8o4 --key 7 63 | ``` 64 | 65 | To generate `n` (e.g., `n=4`) pieces using a generated form: 66 | 67 | ``` 68 | python inference_whole_song.py --nsample 4 69 | ``` 70 | 71 | For a more detailed usage, check 72 | 73 | ``` 74 | python inference_whole_song.py -h 75 | ``` 76 | -------------------------------------------------------------------------------- /data/download_link.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/file/d/1NLYfnWXjtDK8u72ABTquLgzBTVg9JaCa/view?usp=drive_link -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .read_pop909_data import load_train_and_valid_data, analyze_train_and_valid_datasets 2 | from .pytorch_datasets import create_form_datasets, create_counterpoint_datasets, create_leadsheet_datasets, \ 3 | create_accompaniment_datasets 4 | from .pytorch_datasets.dataloaders import create_train_valid_dataloaders 5 | from .pytorch_datasets.const import LANGUAGE_DATASET_PARAMS, AUTOREG_PARAMS 6 | 7 | 8 | def load_datasets(mode, multi_phrase_label, random_pitch_aug, use_autoreg_cond, use_external_cond, 9 | mask_background, load_first_n=None): 10 | train_data, valid_data = load_train_and_valid_data(multi_phrase_label, load_first_n) 11 | 12 | train_analyses, valid_analyses = analyze_train_and_valid_datasets(train_data, valid_data) 13 | 14 | if mode == 'frm': 15 | train_set, valid_set = create_form_datasets( 16 | train_analyses, valid_analyses, multi_phrase_label, random_pitch_aug 17 | ) 18 | elif mode == 'ctp': 19 | train_set, valid_set = create_counterpoint_datasets( 20 | train_analyses, valid_analyses, use_autoreg_cond, use_external_cond, 21 | multi_phrase_label, random_pitch_aug, mask_background 22 | ) 23 | elif mode == 'lsh': 24 | train_set, valid_set = create_leadsheet_datasets( 25 | train_analyses, valid_analyses, use_autoreg_cond, use_external_cond, 26 | multi_phrase_label, random_pitch_aug, mask_background 27 | ) 28 | elif mode == 'acc': 29 | train_set, valid_set = create_accompaniment_datasets( 30 | train_analyses, valid_analyses, use_autoreg_cond, use_external_cond, 31 | multi_phrase_label, random_pitch_aug, mask_background 32 | ) 33 | else: 34 | raise 
NotImplementedError 35 | return train_set, valid_set 36 | -------------------------------------------------------------------------------- /data_utils/midi_output.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pretty_midi as pm 3 | 4 | 5 | def default_quantization(v): 6 | return 1 if v > 0.5 else 0 7 | 8 | 9 | def piano_roll_to_note_mat(piano_roll: np.ndarray, raise_chord: bool, 10 | quantization_func=None, seperate_chord=False): 11 | """ 12 | piano_roll: (2, L, 128), onset and sustain channel. 13 | raise_chord: whether pitch below 48 (mel-chd boundary) will be raised an octave 14 | """ 15 | def convert_p(p_, note_list, raise_pitch=False): 16 | edit_note_flag = False 17 | for t in range(n_step): 18 | onset_state = quantization_func(piano_roll[0, t, p_]) 19 | sustain_state = quantization_func(piano_roll[1, t, p_]) 20 | 21 | is_onset = bool(onset_state) 22 | is_sustain = bool(sustain_state) and not is_onset 23 | 24 | pitch = p_ + 12 if raise_pitch else p_ 25 | 26 | if is_onset: 27 | edit_note_flag = True 28 | note_list.append([t, pitch, 1]) 29 | elif is_sustain: 30 | if edit_note_flag: 31 | note_list[-1][-1] += 1 32 | else: 33 | edit_note_flag = False 34 | return note_list 35 | 36 | quantization_func = default_quantization if quantization_func is None else quantization_func 37 | assert len(piano_roll.shape) == 3 and piano_roll.shape[0] == 2 and piano_roll.shape[2] == 128 38 | 39 | n_step = piano_roll.shape[1] 40 | 41 | notes = [] 42 | chord_notes = [] 43 | 44 | for p in range(128): 45 | if p < 48: 46 | convert_p(p, chord_notes if seperate_chord else notes, True if raise_chord else False) 47 | else: 48 | convert_p(p, notes, False) 49 | 50 | if seperate_chord: 51 | return notes, chord_notes 52 | else: 53 | return notes 54 | 55 | 56 | def note_mat_to_notes(note_mat, bpm, unit, shift_beat=0., shift_sec=0., vel=100): 57 | """Default use shift beat""" 58 | 59 | beat_alpha = 60 / bpm 60 | step_alpha = unit * beat_alpha 61 | 62 | notes = [] 63 | 64 | shift_sec = shift_sec if shift_beat is None else shift_beat * beat_alpha 65 | 66 | for note in note_mat: 67 | onset, pitch, dur = note 68 | start = onset * step_alpha + shift_sec 69 | end = (onset + dur) * step_alpha + shift_sec 70 | 71 | notes.append(pm.Note(vel, int(pitch), start, end)) 72 | 73 | return notes 74 | 75 | 76 | def create_pm_object(bpm, preset=0, instrument_names=None, notes_list=None): 77 | midi = pm.PrettyMIDI(initial_tempo=bpm) 78 | 79 | presets = { 80 | 1: ['red_mel+red_chd'], 81 | 2: ['red_mel+red_chd', 'mel+chd'], 82 | 3: ['red_mel+red_chd', 'mel+chd', 'acc'], 83 | 5: ['red_mel', 'red_chd', 'mel', 'chd', 'acc'] 84 | } 85 | 86 | if instrument_names is None: 87 | instrument_names = presets[preset] 88 | 89 | midi.instruments = [pm.Instrument(0, name=name) for name in instrument_names] 90 | 91 | if notes_list is not None: 92 | assert len(notes_list) == len(midi.instruments) 93 | for i in range(len(midi.instruments)): 94 | midi.instruments[i].notes += notes_list[i] 95 | 96 | return midi 97 | -------------------------------------------------------------------------------- /data_utils/pytorch_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .form_dataset import create_form_datasets 2 | from .counterpoint_dataset import create_counterpoint_datasets 3 | from .leadsheet_dataset import create_leadsheet_datasets 4 | from .accompaniment_dataset import create_accompaniment_datasets 5 | 
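The four dataset factories imported above are tied together by `load_datasets` (defined in `data_utils/__init__.py`) and `create_train_valid_dataloaders` (in `pytorch_datasets/dataloaders.py`). The following is a minimal usage sketch, not a file in this repository; it assumes the POP909-derived data from `data/download_link.txt` has been placed under `data/`, and the `load_first_n` value is only an illustrative way to limit loading for a quick check.

```
# Usage sketch (illustrative, not part of the repository): build the
# counterpoint-level datasets and wrap them in dataloaders.
from data_utils import load_datasets
from data_utils.pytorch_datasets.dataloaders import create_train_valid_dataloaders

train_set, valid_set = load_datasets(
    mode='ctp',               # 'frm' | 'ctp' | 'lsh' | 'acc'
    multi_phrase_label=False,
    random_pitch_aug=True,
    use_autoreg_cond=True,
    use_external_cond=False,
    mask_background=True,
    load_first_n=8,           # assumed here: load only a few songs for a quick check
)
train_dl, valid_dl = create_train_valid_dataloaders(16, train_set, valid_set)

img, autoreg_cond, external_cond = next(iter(train_dl))
# img: (batch, 10, 128, 128) at the counterpoint level (see pytorch_datasets/const.py);
# external_cond is None here because use_external_cond=False.
```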
-------------------------------------------------------------------------------- /data_utils/pytorch_datasets/accompaniment_dataset.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from .base_class import * 3 | from .const import LANGUAGE_DATASET_PARAMS, AUTOREG_PARAMS, SHIFT_HIGH_T, SHIFT_LOW_T, SHIFT_HIGH_V, SHIFT_LOW_V 4 | 5 | 6 | def acc_to_polydis_pr_mat(acc, length=None): 7 | length = acc.shape[1] if length is None else length 8 | pr_mat = np.zeros((length, 128), dtype=np.float32) 9 | 10 | dur = np.zeros(128, dtype=np.float32) 11 | 12 | for t in range(acc.shape[1] - 1, -1, -1): 13 | is_onset = acc[0, t] > 0 14 | pr_mat[t, is_onset] = acc[0, t, is_onset] + dur[is_onset] 15 | if t == 0: 16 | break 17 | dur[is_onset] = 0 18 | dur += acc[1, t] 19 | return pr_mat 20 | 21 | 22 | class AccompanimentDataset(HierarchicalDatasetBase): 23 | 24 | def __init__(self, analyses, shift_high=0, shift_low=0, max_l=128, h=128, n_channels=10, 25 | autoreg_seg_lgth=8, max_n_autoreg=3, n_autoreg_prob=np.array([0.1, 0.1, 0.2, 0.6]), 26 | seg_pad_unit=4, autoreg_max_l=108, 27 | use_autoreg_cond=True, use_external_cond=False, multi_phrase_label=False, 28 | random_pitch_aug=True, mask_background=False): 29 | 30 | super(AccompanimentDataset, self).__init__( 31 | analyses, shift_high, shift_low, max_l, h, n_channels, 32 | autoreg_seg_lgth, max_n_autoreg, n_autoreg_prob, seg_pad_unit, autoreg_max_l, 33 | use_autoreg_cond, use_external_cond, multi_phrase_label, random_pitch_aug, mask_background) 34 | 35 | form_langs = [analysis['languages']['form'] for analysis in analyses] 36 | ctpt_langs = [analysis['languages']['counterpoint'] for analysis in analyses] 37 | ldsht_langs = [analysis['languages']['lead_sheet'] for analysis in analyses] 38 | acc_langs = [analysis['languages']['accompaniment'] for analysis in analyses] 39 | 40 | self.key_rolls = [form_lang['key_roll'] for form_lang in form_langs] 41 | self.expand_key_rolls() 42 | 43 | self.phrase_rolls = [form_lang['phrase_roll'][:, :, np.newaxis] for form_lang in form_langs] 44 | self.expand_phrase_rolls() 45 | 46 | self.red_mel_rolls = [ctpt_lang['red_mel_roll'] for ctpt_lang in ctpt_langs] 47 | self.expand_red_mel_rolls() 48 | 49 | self.red_chd_rolls = [ctpt_lang['red_chd_roll'] for ctpt_lang in ctpt_langs] 50 | self.expand_red_chd_rolls() 51 | 52 | self.mel_rolls = [ldsht_lang['mel_roll'] for ldsht_lang in ldsht_langs] 53 | self.chd_rolls = [ldsht_lang['chd_roll'] for ldsht_lang in ldsht_langs] 54 | self.expand_chd_rolls() 55 | 56 | self.acc_rolls = [acc_lang['acc_roll'] for acc_lang in acc_langs] 57 | 58 | self.lengths = np.array([mel.shape[1] for mel in self.mel_rolls]) 59 | 60 | self.start_ids_per_song = [np.arange(0, lgth - self.max_l // 2, nspb * nbpm, dtype=np.int64) 61 | for lgth, nbpm, nspb in zip(self.lengths, self.nbpms, self.nspbs)] 62 | 63 | self.indices = self._song_id_to_indices() 64 | 65 | def expand_key_rolls(self): 66 | self.key_rolls = [expand_roll(roll, nbpm * nspb) 67 | for roll, nbpm, nspb in zip(self.key_rolls, self.nbpms, self.nspbs)] 68 | 69 | def expand_phrase_rolls(self): 70 | self.phrase_rolls = [expand_roll(roll, nbpm * nspb) 71 | for roll, nbpm, nspb in zip(self.phrase_rolls, self.nbpms, self.nspbs)] 72 | 73 | def expand_red_chd_rolls(self): 74 | self.red_chd_rolls = [expand_roll(roll, nspb, contain_onset=True) 75 | for roll, nspb in zip(self.red_chd_rolls, self.nspbs)] 76 | 77 | def expand_chd_rolls(self): 78 | self.chd_rolls = [expand_roll(roll, nspb, 
contain_onset=True) 79 | for roll, nspb in zip(self.chd_rolls, self.nspbs)] 80 | 81 | def expand_red_mel_rolls(self): 82 | self.red_mel_rolls = [expand_roll(roll, nspb, contain_onset=True) 83 | for roll, nspb in zip(self.red_mel_rolls, self.nspbs)] 84 | 85 | def get_data_sample(self, song_id, start_id, shift): 86 | nbpm, nspb = self.nbpms[song_id], self.nspbs[song_id] 87 | 88 | pitch_shift = compute_pitch_shift_value(shift, self.min_mel_pitches[song_id], self.max_mel_pitches[song_id]) 89 | 90 | self.store_key(song_id, pitch_shift) 91 | self.store_phrase(song_id) 92 | self.store_red_mel(song_id, pitch_shift) 93 | self.store_red_chd(song_id, pitch_shift) 94 | self.store_mel(song_id, pitch_shift) 95 | self.store_chd(song_id, pitch_shift) 96 | self.store_acc(song_id, shift) # note the difference 97 | 98 | img = self.lang_to_img(song_id, start_id, end_id=start_id + self.max_l, tgt_lgth=self.max_l) 99 | 100 | # prepare for the autoreg condition 101 | if self.use_autoreg_cond: 102 | autoreg_cond = self.get_autoreg_cond(song_id, start_id, nbpm * nspb) 103 | else: 104 | autoreg_cond = None 105 | 106 | # prepare for the external condition 107 | if self.use_external_cond: 108 | external_cond = self.get_external_cond(start_id) 109 | else: 110 | external_cond = None 111 | 112 | # randomly mask background 113 | if self.mask_background and np.random.random() > 0.8: 114 | img[2:] = -1 115 | 116 | return img, autoreg_cond, external_cond 117 | 118 | def lang_to_img(self, song_id, start_id, end_id, tgt_lgth=None): 119 | key_roll = self._key[:, start_id: end_id] # (2, L, 12) 120 | phrase_roll = self._phrase[:, start_id: end_id] # (6, L, 1) 121 | red_mel_roll = self._red_mel[:, start_id: end_id] # (2, L, 128) 122 | red_chd_roll = self._red_chd[:, start_id: end_id] # (6, L, 12) 123 | mel_roll = self._mel[:, start_id: end_id] 124 | chd_roll = self._chd[:, start_id: end_id] 125 | acc_roll = self._acc[:, start_id: end_id] 126 | 127 | actual_l = key_roll.shape[1] 128 | 129 | # to output image 130 | if tgt_lgth is None: 131 | tgt_lgth = end_id - start_id 132 | img = np.zeros((self.n_channels, tgt_lgth, 132), dtype=np.float32) 133 | img[0: 2, 0: actual_l, 0: 128] = acc_roll 134 | img[2: 4, 0: actual_l, 0: 128] = mel_roll 135 | img[2: 4, 0: actual_l, 36: 48] = chd_roll[2: 4] 136 | img[2: 4, 0: actual_l, 24: 36] = chd_roll[4: 6] 137 | 138 | img[4: 6, 0: actual_l, 0: 128] = red_mel_roll 139 | img[4: 6, 0: actual_l, 36: 48] = red_chd_roll[2: 4] 140 | img[4: 6, 0: actual_l, 24: 36] = red_chd_roll[4: 6] 141 | 142 | img[8: 14, 0: actual_l] = phrase_roll 143 | 144 | img = img.reshape((self.n_channels, tgt_lgth, 11, 12)) 145 | img[6: 8, 0: actual_l] = key_roll[:, :, np.newaxis] 146 | img = img.reshape((self.n_channels, tgt_lgth, 132)) 147 | return img[:, :, 0: self.h] 148 | 149 | def get_external_cond(self, start_id): 150 | acc = self._acc[:, start_id: start_id + self.max_l] 151 | return acc_to_polydis_pr_mat(acc, length=self.max_l) 152 | 153 | def show(self, item, show_img=True): 154 | data, autoreg, external = self[item] 155 | 156 | titles = ['acc', 'mel+chd', 'mel+rough_chd', 'key', 'phrase0-1', 'phrase2-3', 'phrase4-5'] 157 | print(data.shape) 158 | if show_img: 159 | if self.use_external_cond: 160 | fig, axs = plt.subplots(7, 3, figsize=(30, 40)) 161 | else: 162 | fig, axs = plt.subplots(7, 2, figsize=(20, 40)) 163 | for i in range(7): 164 | img = data[2 * i: 2 * i + 2] 165 | img = np.pad(img, pad_width=((0, 1), (0, 0), (0, 0)), mode='constant') 166 | img[2][img[0] < 0] = 1 167 | img[img < 0] = 0 168 | img = 
img.transpose((2, 1, 0)) 169 | axs[i, 0].imshow(img, origin='lower', aspect='auto') 170 | axs[i, 0].title.set_text(titles[i]) 171 | 172 | autoreg_img = autoreg[2 * i: 2 * i + 2] 173 | autoreg_img = np.pad(autoreg_img, pad_width=((0, 1), (0, 0), (0, 0)), mode='constant') 174 | autoreg_img[2][autoreg_img[0] < 0] = 1 175 | autoreg_img[autoreg_img < 0] = 0 176 | autoreg_img = autoreg_img.transpose((2, 1, 0)) 177 | 178 | axs[i, 1].imshow(autoreg_img, origin='lower', aspect='auto') 179 | 180 | if self.use_external_cond: 181 | axs[0, 2].imshow(external.T, origin='lower', aspect='auto') 182 | plt.show() 183 | 184 | 185 | def create_accompaniment_datasets(train_analyses, valid_analyses, use_autoreg_cond=True, use_external_cond=False, 186 | multi_phrase_label=False, random_pitch_aug=True, mask_background=True): 187 | lang_params = LANGUAGE_DATASET_PARAMS['accompaniment'] 188 | autoreg_params = AUTOREG_PARAMS['accompaniment'] 189 | 190 | train_dataset = AccompanimentDataset( 191 | train_analyses, SHIFT_HIGH_T, SHIFT_LOW_T, lang_params['max_l'], lang_params['h'], lang_params['n_channel'], 192 | autoreg_seg_lgth=autoreg_params['autoreg_seg_lgth'], max_n_autoreg=autoreg_params['max_n_autoreg'], 193 | n_autoreg_prob=autoreg_params['n_autoreg_prob'], seg_pad_unit=autoreg_params['seg_pad_unit'], 194 | autoreg_max_l=autoreg_params['autoreg_max_l'], use_autoreg_cond=use_autoreg_cond, 195 | use_external_cond=use_external_cond, multi_phrase_label=multi_phrase_label, random_pitch_aug=random_pitch_aug, 196 | mask_background=mask_background 197 | ) 198 | valid_dataset = AccompanimentDataset( 199 | valid_analyses, SHIFT_HIGH_V, SHIFT_LOW_V, lang_params['max_l'], lang_params['h'], lang_params['n_channel'], 200 | autoreg_seg_lgth=autoreg_params['autoreg_seg_lgth'], max_n_autoreg=autoreg_params['max_n_autoreg'], 201 | n_autoreg_prob=autoreg_params['n_autoreg_prob'], seg_pad_unit=autoreg_params['seg_pad_unit'], 202 | autoreg_max_l=autoreg_params['autoreg_max_l'], use_autoreg_cond=use_autoreg_cond, 203 | use_external_cond=use_external_cond, multi_phrase_label=multi_phrase_label, random_pitch_aug=random_pitch_aug, 204 | mask_background=mask_background 205 | ) 206 | return train_dataset, valid_dataset 207 | -------------------------------------------------------------------------------- /data_utils/pytorch_datasets/const.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | SHIFT_HIGH_T = 5 4 | SHIFT_LOW_T = -6 5 | SHIFT_HIGH_V = 0 6 | SHIFT_LOW_V = 0 7 | 8 | 9 | LANGUAGE_DATASET_PARAMS = { 10 | 'form': {'max_l': 256, 'h': 16, 'n_channel': 8, 'cur_channel': 8}, 11 | 'counterpoint': {'max_l': 128, 'h': 128, 'n_channel': 10, 'cur_channel': 2}, 12 | 'lead_sheet': {'max_l': 128, 'h': 128, 'n_channel': 12, 'cur_channel': 2}, 13 | 'accompaniment': {'max_l': 128, 'h': 128, 'n_channel': 14, 'cur_channel': 2}, 14 | } 15 | 16 | AUTOREG_PARAMS = { 17 | 'counterpoint': { 18 | 'autoreg_seg_lgth': 8, 'max_n_autoreg': 3, 'n_autoreg_prob': np.array([0.1, 0.1, 0.2, 0.6]), 19 | 'seg_pad_unit': 4, 'autoreg_max_l': 108, 20 | }, 21 | 'lead_sheet': { 22 | 'autoreg_seg_lgth': 4, 'max_n_autoreg': 2, 'n_autoreg_prob': np.array([0.1, 0.2, 0.7]), 23 | 'seg_pad_unit': 4, 'autoreg_max_l': 136 24 | }, 25 | 'accompaniment': { 26 | 'autoreg_seg_lgth': 4, 'max_n_autoreg': 2, 'n_autoreg_prob': np.array([0.1, 0.2, 0.7]), 27 | 'seg_pad_unit': 4, 'autoreg_max_l': 136 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- 
/data_utils/pytorch_datasets/counterpoint_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from .base_class import * 4 | from .const import LANGUAGE_DATASET_PARAMS, AUTOREG_PARAMS, SHIFT_HIGH_T, SHIFT_LOW_T, SHIFT_HIGH_V, SHIFT_LOW_V 5 | 6 | 7 | class CounterpointDataset(HierarchicalDatasetBase): 8 | 9 | def __init__(self, analyses, shift_high=0, shift_low=0, max_l=128, h=128, n_channels=10, 10 | autoreg_seg_lgth=8, max_n_autoreg=3, n_autoreg_prob=np.array([0.1, 0.1, 0.2, 0.6]), 11 | seg_pad_unit=4, autoreg_max_l=108, 12 | use_autoreg_cond=True, use_external_cond=False, multi_phrase_label=False, 13 | random_pitch_aug=True, mask_background=False): 14 | super(CounterpointDataset, self).__init__( 15 | analyses, shift_high, shift_low, max_l, h, n_channels, 16 | autoreg_seg_lgth, max_n_autoreg, n_autoreg_prob, seg_pad_unit, autoreg_max_l, 17 | use_autoreg_cond, use_external_cond, multi_phrase_label, random_pitch_aug, mask_background) 18 | 19 | form_langs = [analysis['languages']['form'] for analysis in analyses] 20 | ctpt_langs = [analysis['languages']['counterpoint'] for analysis in analyses] 21 | 22 | self.key_rolls = [form_lang['key_roll'] for form_lang in form_langs] 23 | self.expand_key_rolls() 24 | 25 | self.phrase_rolls = [form_lang['phrase_roll'][:, :, np.newaxis] for form_lang in form_langs] 26 | self.expand_phrase_rolls() 27 | 28 | self.red_mel_rolls = [ctpt_lang['red_mel_roll'] for ctpt_lang in ctpt_langs] 29 | self.red_chd_rolls = [ctpt_lang['red_chd_roll'] for ctpt_lang in ctpt_langs] 30 | 31 | self.lengths = np.array([red.shape[1] for red in self.red_mel_rolls]) 32 | 33 | self.start_ids_per_song = [np.arange(0, lgth - self.max_l // 2, nbpm, dtype=np.int64) 34 | for lgth, nbpm in zip(self.lengths, self.nbpms)] 35 | 36 | self.indices = self._song_id_to_indices() 37 | 38 | def expand_key_rolls(self): 39 | self.key_rolls = [expand_roll(roll, nbpm) for roll, nbpm in zip(self.key_rolls, self.nbpms)] 40 | 41 | def expand_phrase_rolls(self): 42 | self.phrase_rolls = [expand_roll(roll, nbpm) for roll, nbpm in zip(self.phrase_rolls, self.nbpms)] 43 | 44 | def get_data_sample(self, song_id, start_id, shift): 45 | nbpm = self.nbpms[song_id] 46 | 47 | pitch_shift = compute_pitch_shift_value(shift, self.min_mel_pitches[song_id], self.max_mel_pitches[song_id]) 48 | 49 | self.store_key(song_id, pitch_shift) 50 | self.store_phrase(song_id) 51 | self.store_red_mel(song_id, pitch_shift) 52 | self.store_red_chd(song_id, pitch_shift) 53 | 54 | img = self.lang_to_img(song_id, start_id, end_id=start_id + self.max_l, tgt_lgth=self.max_l) 55 | 56 | # prepare for the autoreg condition 57 | if self.use_autoreg_cond: 58 | autoreg_cond = self.get_autoreg_cond(song_id, start_id, nbpm) 59 | else: 60 | autoreg_cond = None 61 | 62 | # prepare for the external condition 63 | if self.use_external_cond: 64 | external_cond = self.get_external_cond(start_id) 65 | else: 66 | external_cond = None 67 | 68 | # randomly mask background 69 | if self.mask_background and np.random.random() > 0.8: 70 | img[2:] = -1 71 | 72 | return img, autoreg_cond, external_cond 73 | 74 | def lang_to_img(self, song_id, start_id, end_id, tgt_lgth=None): 75 | key_roll = self._key[:, start_id: end_id] # (2, L, 12) 76 | phrase_roll = self._phrase[:, start_id: end_id] # (6, L, 1) 77 | red_mel_roll = self._red_mel[:, start_id: end_id] # (2, L, 128) 78 | red_chd_roll = self._red_chd[:, start_id: end_id] # (6, L, 12) 79 | 80 | actual_l = 
key_roll.shape[1] 81 | 82 | # to output image 83 | if tgt_lgth is None: 84 | tgt_lgth = self._key.shape[1] - start_id 85 | img = np.zeros((self.n_channels, tgt_lgth, 132), dtype=np.float32) 86 | img[0: 2, 0: actual_l, 0: 128] = red_mel_roll 87 | img[0: 2, 0: actual_l, 36: 48] = red_chd_roll[2: 4] 88 | img[0: 2, 0: actual_l, 24: 36] = red_chd_roll[4: 6] 89 | 90 | img[4: 10, 0: actual_l] = phrase_roll 91 | 92 | img = img.reshape((self.n_channels, tgt_lgth, 11, 12)) 93 | img[2: 4, 0: actual_l] = key_roll[:, :, np.newaxis] 94 | img = img.reshape((self.n_channels, tgt_lgth, 132)) 95 | return img[:, :, 0: self.h] 96 | 97 | def get_external_cond(self, start_id): 98 | external_cond = np.zeros((self.max_l, 36), dtype=np.float32) 99 | 100 | red_chd = self._red_chd[:, start_id: start_id + self.max_l] 101 | actual_l = red_chd.shape[1] 102 | 103 | # root 104 | root = red_chd[0: 2].max(0) # (actual_l, 12) 105 | root_argmax = root.argmax(-1) # (actual_l) 106 | 107 | # chroma 108 | chroma = red_chd[2: 4].max(0) # (actual_l, 12) 109 | 110 | # bass 111 | bass = red_chd[4: 6].max(0) # (actual_l, 12) 112 | bass_argmax = bass.argmax(-1) # (actual_l) 113 | rel_bass_argmax = (bass_argmax - root_argmax) % 12 # (actual_l) 114 | 115 | external_cond[0: actual_l, 0: 12] = root 116 | external_cond[0: actual_l, 12: 24] = chroma 117 | external_cond[np.arange(0, actual_l), 24 + rel_bass_argmax] = 1. 118 | return external_cond 119 | 120 | def show(self, item, show_img=True): 121 | data, autoreg, external = self[item] 122 | 123 | titles = ['red_mel/red_chd', 'key', 'phrase0-1', 'phrase2-3', 'phrase4-5'] 124 | 125 | if show_img: 126 | if self.use_external_cond: 127 | fig, axs = plt.subplots(5, 3, figsize=(30, 30)) 128 | else: 129 | fig, axs = plt.subplots(5, 2, figsize=(20, 30)) 130 | for i in range(5): 131 | img = data[2 * i: 2 * i + 2] 132 | img = np.pad(img, pad_width=((0, 1), (0, 0), (0, 0)), mode='constant') 133 | img[2][img[0] < 0] = 1 134 | img[img < 0] = 0 135 | img = img.transpose((2, 1, 0)) 136 | axs[i, 0].imshow(img, origin='lower', aspect='auto') 137 | axs[i, 0].title.set_text(titles[i]) 138 | 139 | autoreg_img = autoreg[2 * i: 2 * i + 2] 140 | autoreg_img = np.pad(autoreg_img, pad_width=((0, 1), (0, 0), (0, 0)), mode='constant') 141 | autoreg_img[2][autoreg_img[0] < 0] = 1 142 | autoreg_img[autoreg_img < 0] = 0 143 | autoreg_img = autoreg_img.transpose((2, 1, 0)) 144 | 145 | axs[i, 1].imshow(autoreg_img, origin='lower', aspect='auto') 146 | 147 | if self.use_external_cond: 148 | axs[0, 2].imshow(external.T, origin='lower', aspect='auto') 149 | plt.show() 150 | 151 | 152 | def create_counterpoint_datasets(train_analyses, valid_analyses, use_autoreg_cond=True, use_external_cond=False, 153 | multi_phrase_label=False, random_pitch_aug=True, mask_background=True): 154 | 155 | lang_params = LANGUAGE_DATASET_PARAMS['counterpoint'] 156 | autoreg_params = AUTOREG_PARAMS['counterpoint'] 157 | 158 | train_dataset = CounterpointDataset( 159 | train_analyses, SHIFT_HIGH_T, SHIFT_LOW_T, lang_params['max_l'], lang_params['h'], lang_params['n_channel'], 160 | autoreg_seg_lgth=autoreg_params['autoreg_seg_lgth'], max_n_autoreg=autoreg_params['max_n_autoreg'], 161 | n_autoreg_prob=autoreg_params['n_autoreg_prob'], seg_pad_unit=autoreg_params['seg_pad_unit'], 162 | autoreg_max_l=autoreg_params['autoreg_max_l'], use_autoreg_cond=use_autoreg_cond, 163 | use_external_cond=use_external_cond, multi_phrase_label=multi_phrase_label, random_pitch_aug=random_pitch_aug, 164 | mask_background=mask_background 165 | ) 166 | valid_dataset = 
CounterpointDataset( 167 | valid_analyses, SHIFT_HIGH_V, SHIFT_LOW_V, lang_params['max_l'], lang_params['h'], lang_params['n_channel'], 168 | autoreg_seg_lgth=autoreg_params['autoreg_seg_lgth'], max_n_autoreg=autoreg_params['max_n_autoreg'], 169 | n_autoreg_prob=autoreg_params['n_autoreg_prob'], seg_pad_unit=autoreg_params['seg_pad_unit'], 170 | autoreg_max_l=autoreg_params['autoreg_max_l'], use_autoreg_cond=use_autoreg_cond, 171 | use_external_cond=use_external_cond, multi_phrase_label=multi_phrase_label, random_pitch_aug=random_pitch_aug, 172 | mask_background=mask_background 173 | ) 174 | return train_dataset, valid_dataset 175 | -------------------------------------------------------------------------------- /data_utils/pytorch_datasets/dataloaders.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import torch 3 | import re 4 | import collections 5 | 6 | 7 | string_classes = (str, bytes) 8 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 9 | 10 | 11 | default_collate_err_msg_format = ( 12 | "default_collate: batch must contain tensors, numpy arrays, numbers, " 13 | "dicts or lists; found {}") 14 | 15 | 16 | def my_collate_fn(batch): 17 | elem = batch[0] 18 | elem_type = type(elem) 19 | if isinstance(elem, torch.Tensor): 20 | out = None 21 | if torch.utils.data.get_worker_info() is not None: 22 | # If we're in a background process, concatenate directly into a 23 | # shared memory tensor to avoid an extra copy 24 | numel = sum(x.numel() for x in batch) 25 | storage = elem.storage()._new_shared(numel) 26 | out = elem.new(storage) 27 | return torch.stack(batch, 0, out=out) 28 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 29 | and elem_type.__name__ != 'string_': 30 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 31 | # array of string classes and object 32 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 33 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 34 | 35 | return my_collate_fn([torch.as_tensor(b) for b in batch]) 36 | elif elem.shape == (): # scalars 37 | return torch.as_tensor(batch) 38 | elif isinstance(elem, float): 39 | return torch.tensor(batch, dtype=torch.float64) 40 | elif isinstance(elem, int): 41 | return torch.tensor(batch) 42 | elif isinstance(elem, string_classes): 43 | return batch 44 | elif isinstance(elem, collections.abc.Mapping): 45 | return {key: my_collate_fn([d[key] for d in batch]) for key in elem} 46 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 47 | return elem_type(*(my_collate_fn(samples) 48 | for samples in zip(*batch))) 49 | elif isinstance(elem, collections.abc.Sequence): 50 | # check to make sure that the elements in batch have consistent size 51 | it = iter(batch) 52 | elem_size = len(next(it)) 53 | if not all(len(elem) == elem_size for elem in it): 54 | raise RuntimeError('each element in list of batch should ' 55 | 'be of equal size') 56 | transposed = zip(*batch) 57 | return [my_collate_fn(samples) for samples in transposed] 58 | elif elem is None: 59 | return None 60 | 61 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 62 | 63 | 64 | def create_train_valid_dataloaders(batch_size, train_set, valid_set): 65 | train_dl = DataLoader(train_set, batch_size, True, collate_fn=my_collate_fn) 66 | valid_dl = DataLoader(valid_set, batch_size, True, collate_fn=my_collate_fn) 67 | return train_dl, valid_dl 68 | 
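A note on `my_collate_fn` above: it mirrors `torch.utils.data.default_collate` but additionally passes `None` straight through, which matters because dataset items contain `None` placeholders whenever the autoregressive or external conditions are disabled. The snippet below is an illustrative sketch, not a file in this repository; the array shape follows the counterpoint-level settings in `const.py`.

```
# Illustration of why the custom collate function is needed: samples may
# contain None entries, which torch's default_collate would reject.
import numpy as np
from data_utils.pytorch_datasets.dataloaders import my_collate_fn

batch = [
    (np.zeros((10, 128, 128), dtype=np.float32), None, None),
    (np.ones((10, 128, 128), dtype=np.float32), None, None),
]
img, autoreg_cond, external_cond = my_collate_fn(batch)
print(img.shape)      # torch.Size([2, 10, 128, 128])
print(autoreg_cond)   # None is passed through instead of raising a TypeError
```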
-------------------------------------------------------------------------------- /data_utils/pytorch_datasets/form_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from .base_class import HierarchicalDatasetBase 4 | from .const import LANGUAGE_DATASET_PARAMS, AUTOREG_PARAMS, SHIFT_HIGH_T, SHIFT_LOW_T, SHIFT_HIGH_V, SHIFT_LOW_V 5 | 6 | 7 | class FormDataset(HierarchicalDatasetBase): 8 | 9 | def __init__(self, analyses, shift_high=0, shift_low=0, max_l=256, h=12, n_channels=8, 10 | multi_phrase_label=False, random_pitch_aug=True): 11 | super(FormDataset, self).__init__( 12 | analyses, shift_high, shift_low, max_l, h, n_channels, 13 | use_autoreg_cond=False, use_external_cond=False, 14 | multi_phrase_label=multi_phrase_label, random_pitch_aug=random_pitch_aug, mask_background=False) 15 | 16 | assert max_l >= 210, "Some pieces may be longer than the current max_l." 17 | 18 | form_langs = [analysis['languages']['form'] for analysis in analyses] 19 | 20 | self.key_rolls = [form_lang['key_roll'] for form_lang in form_langs] 21 | self.phrase_rolls = [form_lang['phrase_roll'][:, :, np.newaxis] for form_lang in form_langs] 22 | 23 | self.lengths = np.array([roll.shape[1] for roll in self.key_rolls]) 24 | 25 | self.start_ids_per_song = [np.zeros(1, dtype=np.int64) for _ in range(len(self.lengths))] 26 | 27 | self.indices = self._song_id_to_indices() 28 | 29 | def get_data_sample(self, song_id, start_id, shift): 30 | self.store_key(song_id, shift) 31 | self.store_phrase(song_id) 32 | 33 | img = self.lang_to_img(song_id, start_id, end_id=start_id + self.max_l, tgt_lgth=self.max_l) 34 | 35 | return img, None, None 36 | 37 | def lang_to_img(self, song_id, start_id, end_id, tgt_lgth=None): 38 | key_roll = self._key[:, start_id: end_id] # (2, L, 12) 39 | phrase_roll = self._phrase[:, start_id: end_id] # (6, L, 1) 40 | 41 | actual_l = self._key.shape[1] 42 | 43 | # to output image 44 | if tgt_lgth is None: 45 | tgt_lgth = end_id - start_id 46 | img = np.zeros((self.n_channels, tgt_lgth, self.h), dtype=np.float32) 47 | img[0: 2, 0: actual_l, 0: 12] = self._key 48 | img[2: 8, 0: actual_l] = self._phrase 49 | 50 | return img 51 | 52 | def show(self, item, show_img=True): 53 | sample = self[item][0] 54 | titles = ['key', 'phrase0-1', 'phrase2-3', 'phrase4-5'] 55 | 56 | if show_img: 57 | fig, axs = plt.subplots(4, 1, figsize=(10, 30)) 58 | for i in range(4): 59 | img = sample[2 * i: 2 * i + 2] 60 | img = np.pad(img, pad_width=((0, 1), (0, 0), (0, 0)), mode='constant') 61 | img = img.transpose((2, 1, 0)) 62 | axs[i].imshow(img, origin='lower', aspect='auto') 63 | axs[i].title.set_text(titles[i]) 64 | plt.show() 65 | 66 | 67 | def create_form_datasets(train_analyses, valid_analyses, multi_phrase_label=False, random_pitch_aug=True): 68 | 69 | lang_params = LANGUAGE_DATASET_PARAMS['form'] 70 | 71 | train_dataset = FormDataset( 72 | train_analyses, SHIFT_HIGH_T, SHIFT_LOW_T, lang_params['max_l'], lang_params['h'], lang_params['n_channel'], 73 | multi_phrase_label=multi_phrase_label, random_pitch_aug=random_pitch_aug 74 | ) 75 | valid_dataset = FormDataset( 76 | valid_analyses, SHIFT_HIGH_V, SHIFT_LOW_V, lang_params['max_l'], lang_params['h'], lang_params['n_channel'], 77 | multi_phrase_label=multi_phrase_label, random_pitch_aug=random_pitch_aug 78 | ) 79 | return train_dataset, valid_dataset 80 | -------------------------------------------------------------------------------- 
/data_utils/pytorch_datasets/leadsheet_dataset.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from .base_class import * 3 | from .const import LANGUAGE_DATASET_PARAMS, AUTOREG_PARAMS, SHIFT_HIGH_T, SHIFT_LOW_T, SHIFT_HIGH_V, SHIFT_LOW_V 4 | 5 | 6 | def mel_to_ec2vae_pr(mel): 7 | # mel: (2, max_l, 128) 8 | 9 | pr = np.zeros((mel.shape[1], 130), dtype=np.float32) 10 | 11 | pr[:, 0: 128] = mel[0] 12 | 13 | is_sustain = mel[1].sum(-1) > 0 # (max_l,) 14 | is_rest = mel.sum(0).sum(-1) == 0 # (max_l, ) 15 | 16 | pr[is_sustain, 128] = 1 17 | pr[is_rest, 129] = 1 18 | 19 | return pr 20 | 21 | 22 | class LeadSheetDataset(HierarchicalDatasetBase): 23 | def __init__(self, analyses, shift_high=0, shift_low=0, max_l=128, h=128, n_channels=10, 24 | autoreg_seg_lgth=8, max_n_autoreg=3, n_autoreg_prob=np.array([0.1, 0.1, 0.2, 0.6]), 25 | seg_pad_unit=4, autoreg_max_l=108, 26 | use_autoreg_cond=True, use_external_cond=False, multi_phrase_label=False, 27 | random_pitch_aug=True, mask_background=False): 28 | 29 | super(LeadSheetDataset, self).__init__( 30 | analyses, shift_high, shift_low, max_l, h, n_channels, 31 | autoreg_seg_lgth, max_n_autoreg, n_autoreg_prob, seg_pad_unit, autoreg_max_l, 32 | use_autoreg_cond, use_external_cond, multi_phrase_label, random_pitch_aug, mask_background) 33 | 34 | form_langs = [analysis['languages']['form'] for analysis in analyses] 35 | ctpt_langs = [analysis['languages']['counterpoint'] for analysis in analyses] 36 | ldsht_langs = [analysis['languages']['lead_sheet'] for analysis in analyses] 37 | 38 | self.key_rolls = [form_lang['key_roll'] for form_lang in form_langs] 39 | self.expand_key_rolls() 40 | 41 | self.phrase_rolls = [form_lang['phrase_roll'][:, :, np.newaxis] for form_lang in form_langs] 42 | self.expand_phrase_rolls() 43 | 44 | self.red_mel_rolls = [ctpt_lang['red_mel_roll'] for ctpt_lang in ctpt_langs] 45 | self.expand_red_mel_rolls() 46 | 47 | self.red_chd_rolls = [ctpt_lang['red_chd_roll'] for ctpt_lang in ctpt_langs] 48 | self.expand_red_chd_rolls() 49 | 50 | self.mel_rolls = [ldsht_lang['mel_roll'] for ldsht_lang in ldsht_langs] 51 | self.chd_rolls = [ldsht_lang['chd_roll'] for ldsht_lang in ldsht_langs] 52 | self.expand_chd_rolls() 53 | 54 | self.lengths = np.array([mel.shape[1] for mel in self.mel_rolls]) 55 | 56 | self.start_ids_per_song = [np.arange(0, lgth - self.max_l // 2, nspb * nbpm, dtype=np.int64) 57 | for lgth, nbpm, nspb in zip(self.lengths, self.nbpms, self.nspbs)] 58 | 59 | self.indices = self._song_id_to_indices() 60 | 61 | def expand_key_rolls(self): 62 | self.key_rolls = [expand_roll(roll, nbpm * nspb) 63 | for roll, nbpm, nspb in zip(self.key_rolls, self.nbpms, self.nspbs)] 64 | 65 | def expand_phrase_rolls(self): 66 | self.phrase_rolls = [expand_roll(roll, nbpm * nspb) 67 | for roll, nbpm, nspb in zip(self.phrase_rolls, self.nbpms, self.nspbs)] 68 | 69 | def expand_red_chd_rolls(self): 70 | self.red_chd_rolls = [expand_roll(roll, nspb, contain_onset=True) 71 | for roll, nspb in zip(self.red_chd_rolls, self.nspbs)] 72 | 73 | def expand_chd_rolls(self): 74 | self.chd_rolls = [expand_roll(roll, nspb, contain_onset=True) 75 | for roll, nspb in zip(self.chd_rolls, self.nspbs)] 76 | 77 | def expand_red_mel_rolls(self): 78 | self.red_mel_rolls = [expand_roll(roll, nspb, contain_onset=True) 79 | for roll, nspb in zip(self.red_mel_rolls, self.nspbs)] 80 | 81 | def get_data_sample(self, song_id, start_id, shift): 82 | nbpm, nspb = self.nbpms[song_id], 
self.nspbs[song_id] 83 | 84 | pitch_shift = compute_pitch_shift_value(shift, self.min_mel_pitches[song_id], self.max_mel_pitches[song_id]) 85 | 86 | self.store_key(song_id, pitch_shift) 87 | self.store_phrase(song_id) 88 | self.store_red_mel(song_id, pitch_shift) 89 | self.store_red_chd(song_id, pitch_shift) 90 | self.store_mel(song_id, pitch_shift) 91 | self.store_chd(song_id, pitch_shift) 92 | 93 | img = self.lang_to_img(song_id, start_id, end_id=start_id + self.max_l, tgt_lgth=self.max_l) 94 | 95 | # prepare for the autoreg condition 96 | if self.use_autoreg_cond: 97 | autoreg_cond = self.get_autoreg_cond(song_id, start_id, nbpm * nspb) 98 | else: 99 | autoreg_cond = None 100 | 101 | # prepare for the external condition 102 | if self.use_external_cond: 103 | external_cond = self.get_external_cond(start_id) 104 | else: 105 | external_cond = None 106 | 107 | # randomly mask background 108 | if self.mask_background and np.random.random() > 0.8: 109 | img[2:] = -1 110 | 111 | return img, autoreg_cond, external_cond 112 | 113 | def lang_to_img(self, song_id, start_id, end_id, tgt_lgth=None): 114 | key_roll = self._key[:, start_id: end_id] # (2, L, 12) 115 | phrase_roll = self._phrase[:, start_id: end_id] # (6, L, 1) 116 | red_mel_roll = self._red_mel[:, start_id: end_id] # (2, L, 128) 117 | red_chd_roll = self._red_chd[:, start_id: end_id] # (6, L, 12) 118 | mel_roll = self._mel[:, start_id: end_id] 119 | chd_roll = self._chd[:, start_id: end_id] 120 | 121 | actual_l = key_roll.shape[1] 122 | 123 | # to output image 124 | if tgt_lgth is None: 125 | tgt_lgth = end_id - start_id 126 | img = np.zeros((self.n_channels, tgt_lgth, 132), dtype=np.float32) 127 | img[0: 2, 0: actual_l, 0: 128] = mel_roll 128 | img[0: 2, 0: actual_l, 36: 48] = chd_roll[2: 4] 129 | img[0: 2, 0: actual_l, 24: 36] = chd_roll[4: 6] 130 | 131 | img[2: 4, 0: actual_l, 0: 128] = red_mel_roll 132 | img[2: 4, 0: actual_l, 36: 48] = red_chd_roll[2: 4] 133 | img[2: 4, 0: actual_l, 24: 36] = red_chd_roll[4: 6] 134 | 135 | img[6: 12, 0: actual_l] = phrase_roll 136 | 137 | img = img.reshape((self.n_channels, tgt_lgth, 11, 12)) 138 | img[4: 6, 0: actual_l] = key_roll[:, :, np.newaxis] 139 | img = img.reshape((self.n_channels, tgt_lgth, 132)) 140 | return img[:, :, 0: self.h] 141 | 142 | def get_external_cond(self, start_id): 143 | external_cond = np.zeros((self.max_l, 142), dtype=np.float32) 144 | 145 | chroma = self._chd[:, start_id: start_id + self.max_l][2: 4].sum(0) # (max_l, 12) 146 | mel = self._mel[:, start_id: start_id + self.max_l] 147 | 148 | actual_l = mel.shape[1] 149 | 150 | external_cond[0: actual_l, 130:] = chroma 151 | external_cond[0: actual_l, 0: 130] = mel_to_ec2vae_pr(mel) 152 | return external_cond 153 | 154 | def show(self, item, show_img=True): 155 | data, autoreg, external = self[item] 156 | 157 | titles = ['mel+chd', 'mel+rough_chd', 'key', 'phrase0-1', 'phrase2-3', 'phrase4-5'] 158 | 159 | if show_img: 160 | if self.use_external_cond: 161 | fig, axs = plt.subplots(6, 3, figsize=(30, 40)) 162 | else: 163 | fig, axs = plt.subplots(6, 2, figsize=(20, 40)) 164 | for i in range(6): 165 | img = data[2 * i: 2 * i + 2] 166 | img = np.pad(img, pad_width=((0, 1), (0, 0), (0, 0)), mode='constant') 167 | img[2][img[0] < 0] = 1 168 | img[img < 0] = 0 169 | img = img.transpose((2, 1, 0)) 170 | axs[i, 0].imshow(img, origin='lower', aspect='auto') 171 | axs[i, 0].title.set_text(titles[i]) 172 | 173 | autoreg_img = autoreg[2 * i: 2 * i + 2] 174 | autoreg_img = np.pad(autoreg_img, pad_width=((0, 1), (0, 0), (0, 0)), 
mode='constant') 175 | autoreg_img[2][autoreg_img[0] < 0] = 1 176 | autoreg_img[autoreg_img < 0] = 0 177 | autoreg_img = autoreg_img.transpose((2, 1, 0)) 178 | 179 | axs[i, 1].imshow(autoreg_img, origin='lower', aspect='auto') 180 | 181 | if self.use_external_cond: 182 | axs[0, 2].imshow(external.T, origin='lower', aspect='auto') 183 | plt.show() 184 | 185 | 186 | def create_leadsheet_datasets(train_analyses, valid_analyses, use_autoreg_cond=True, use_external_cond=False, 187 | multi_phrase_label=False, random_pitch_aug=True, mask_background=True): 188 | 189 | lang_params = LANGUAGE_DATASET_PARAMS['lead_sheet'] 190 | autoreg_params = AUTOREG_PARAMS['lead_sheet'] 191 | 192 | train_dataset = LeadSheetDataset( 193 | train_analyses, SHIFT_HIGH_T, SHIFT_LOW_T, lang_params['max_l'], lang_params['h'], lang_params['n_channel'], 194 | autoreg_seg_lgth=autoreg_params['autoreg_seg_lgth'], max_n_autoreg=autoreg_params['max_n_autoreg'], 195 | n_autoreg_prob=autoreg_params['n_autoreg_prob'], seg_pad_unit=autoreg_params['seg_pad_unit'], 196 | autoreg_max_l=autoreg_params['autoreg_max_l'], use_autoreg_cond=use_autoreg_cond, 197 | use_external_cond=use_external_cond, multi_phrase_label=multi_phrase_label, random_pitch_aug=random_pitch_aug, 198 | mask_background=mask_background 199 | ) 200 | valid_dataset = LeadSheetDataset( 201 | valid_analyses, SHIFT_HIGH_V, SHIFT_LOW_V, lang_params['max_l'], lang_params['h'], lang_params['n_channel'], 202 | autoreg_seg_lgth=autoreg_params['autoreg_seg_lgth'], max_n_autoreg=autoreg_params['max_n_autoreg'], 203 | n_autoreg_prob=autoreg_params['n_autoreg_prob'], seg_pad_unit=autoreg_params['seg_pad_unit'], 204 | autoreg_max_l=autoreg_params['autoreg_max_l'], use_autoreg_cond=use_autoreg_cond, 205 | use_external_cond=use_external_cond, multi_phrase_label=multi_phrase_label, random_pitch_aug=random_pitch_aug, 206 | mask_background=mask_background 207 | ) 208 | return train_dataset, valid_dataset 209 | -------------------------------------------------------------------------------- /data_utils/read_pop909_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from data_utils.utils.read_file import read_data 5 | from data_utils.utils.song_data_structure import McpaMusic 6 | from data_utils.utils.song_analyzer import LanguageExtractor 7 | from typing import List 8 | from .train_valid_split import load_split_file 9 | 10 | 11 | TRIPLE_METER_SONG = [ 12 | 34, 62, 102, 107, 152, 173, 176, 203, 215, 231, 254, 280, 307, 328, 369, 13 | 584, 592, 653, 654, 662, 744, 749, 756, 770, 799, 843, 869, 872, 887 14 | ] 15 | 16 | 17 | PROJECT_PATH = os.path.join(os.path.dirname(__file__), '..') 18 | 19 | DATASET_PATH = os.path.join(PROJECT_PATH, 'data', 'pop909_w_structure_label') 20 | ACC_DATASET_PATH = os.path.join(PROJECT_PATH, 'data', 'matched_pop909_acc') 21 | 22 | LABEL_SOURCE = np.load(os.path.join(PROJECT_PATH, 'data', 23 | 'pop909_w_structure_label', 24 | 'label_source.npy')) 25 | 26 | SPLIT_FILE_PATH = os.path.join(PROJECT_PATH, 'data', 'pop909_split', 'split.npz') 27 | 28 | 29 | def read_pop909_dataset(song_ids=None, label_fns=None, desc_dataset=None): 30 | """If label_fn is None, use default the selected label file in LABEL_SOURCE""" 31 | 32 | dataset = [] 33 | 34 | song_ids = [si for si in range(1, 910)] if song_ids is None else song_ids 35 | 36 | for idx, i in enumerate(tqdm(song_ids, desc=None if desc_dataset is None else f'Loading {desc_dataset}')): 37 | # which human label file 
to use 38 | label = LABEL_SOURCE[i - 1] if label_fns is None else label_fns[idx] 39 | 40 | num_beat_per_measure = 3 if i in TRIPLE_METER_SONG else 4 41 | 42 | song_name = str(i).zfill(3) # e.g., '001' 43 | 44 | data_fn = os.path.join(DATASET_PATH, song_name) # data folder of the song 45 | 46 | acc_fn = os.path.join(ACC_DATASET_PATH, song_name) 47 | 48 | song_data = read_data(data_fn, acc_fn, num_beat_per_measure=num_beat_per_measure, num_step_per_beat=4, 49 | clean_chord_unit=num_beat_per_measure, song_name=song_name, label=label) 50 | 51 | dataset.append(song_data) 52 | 53 | return dataset 54 | 55 | 56 | def read_pop909_dataset_with_multi_phrase_labels(song_ids=None, desc_dataset=None): 57 | dataset = [] 58 | 59 | song_ids = [si for si in range(1, 910)] if song_ids is None else song_ids 60 | 61 | for label in [1, 2]: 62 | label_fns = [label] * len(song_ids) 63 | dataset += read_pop909_dataset(song_ids, label_fns, desc_dataset=desc_dataset + f'-label-{label}') 64 | return dataset 65 | 66 | 67 | def analyze_pop909_dataset(dataset: List[McpaMusic], desc_dataset=None): 68 | hie_lang_dataset = [] 69 | for song in tqdm(dataset, desc=None if desc_dataset is None else f'Analyzing {desc_dataset}'): 70 | lang_extractor = LanguageExtractor(song) 71 | hie_lang = lang_extractor.analyze_for_training() 72 | hie_lang_dataset.append(hie_lang) 73 | return hie_lang_dataset 74 | 75 | 76 | def load_train_and_valid_data(use_multi_phrase_label=False, load_first_n=None): 77 | train_dataset = [] 78 | valid_dataset = [] 79 | 80 | train_ids, valid_ids = load_split_file(SPLIT_FILE_PATH) 81 | 82 | if use_multi_phrase_label: 83 | train_dataset = read_pop909_dataset_with_multi_phrase_labels( 84 | train_ids[0: load_first_n] + 1, desc_dataset='train set (multi-label)' 85 | ) 86 | 87 | valid_dataset = read_pop909_dataset_with_multi_phrase_labels( 88 | valid_ids[0: load_first_n] + 1, desc_dataset='valid set (multi-label)' 89 | ) 90 | else: 91 | train_dataset = read_pop909_dataset(train_ids[0: load_first_n] + 1, desc_dataset='train set') 92 | valid_dataset = read_pop909_dataset(valid_ids[0: load_first_n] + 1, desc_dataset='valid set') 93 | 94 | return train_dataset, valid_dataset 95 | 96 | 97 | def analyze_train_and_valid_datasets(train_set, valid_set): 98 | return analyze_pop909_dataset(train_set, 'train set'), analyze_pop909_dataset(valid_set, 'valid set') 99 | -------------------------------------------------------------------------------- /data_utils/tonal_reduction_algo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZWaang/whole-song-gen/11326220d2f032bdcdce60b8b8cf6891fb2ca308/data_utils/tonal_reduction_algo/__init__.py -------------------------------------------------------------------------------- /data_utils/tonal_reduction_algo/main.py: -------------------------------------------------------------------------------- 1 | from .shortest_path_algo import find_tonal_shortest_paths 2 | from .preprocess import preprocess_data 3 | from .postprocess import path_to_chord_bins, chord_bins_to_reduction_mat 4 | import matplotlib.pyplot as plt 5 | import networkx as nx 6 | import numpy as np 7 | 8 | 9 | class TrAlgo: 10 | 11 | def __init__(self, distance_factor=1.6, onset_factor=1.0, chord_factor=1.0, pitch_factor=1.0, 12 | duration_factor=1.0): 13 | 14 | self.distance_factor = distance_factor 15 | self.onset_factor = onset_factor 16 | self.chord_factor = chord_factor 17 | self.pitch_factor = pitch_factor 18 | self.duration_factor = 
duration_factor 19 | 20 | self._note_mat = None 21 | self._chord_mat = None 22 | 23 | self._num_beat_per_measure = None 24 | self._num_step_per_beat = None 25 | 26 | self._start_measure = None 27 | 28 | self._reduction_mats = None 29 | 30 | self._report = None 31 | 32 | def preprocess_data(self, note_mat, chord_mat, 33 | start_measure, num_beat_per_measure, num_step_per_beat): 34 | """ 35 | Analyze the phrase and add music attributes to note mat. The columns becomes: 36 | [onsets, pitches, durations, bar_id, tonal_bar_id, chord_ids, tonal_chord_ids, is_chord_tone, tonal_onsets] 37 | """ 38 | def measure_to_step_fn(measure): 39 | return measure * num_beat_per_measure * num_step_per_beat 40 | 41 | def measure_to_beat_fn(measure): 42 | return measure * num_beat_per_measure 43 | 44 | def step_to_beat_fn(step): 45 | return step // num_step_per_beat 46 | 47 | def step_to_measure_fn(step): 48 | return step // (num_step_per_beat * num_beat_per_measure) 49 | 50 | def beat_to_measure_fn(beat): 51 | return beat // num_beat_per_measure 52 | 53 | def beat_to_step_fn(beat): 54 | return beat * num_step_per_beat 55 | 56 | note_mat, chord_mat = preprocess_data( 57 | note_mat, chord_mat, start_measure, 58 | measure_to_step_fn, measure_to_beat_fn, step_to_beat_fn, 59 | step_to_measure_fn, beat_to_measure_fn, beat_to_step_fn 60 | ) 61 | 62 | self.fill_data(note_mat, chord_mat, start_measure, num_beat_per_measure, 63 | num_step_per_beat) 64 | 65 | def algo(self, num_path=1, plot_graph=False): 66 | # find the top-k shortest paths and compute distance 67 | paths, G = find_tonal_shortest_paths( 68 | self._note_mat, self._num_beat_per_measure, 69 | self._num_step_per_beat, num_path, 70 | self.distance_factor, self.onset_factor, 71 | self.chord_factor, self.pitch_factor, self.duration_factor) 72 | 73 | if plot_graph: 74 | print('The current version can only print one shortest path.') 75 | self.plot_graph(G, paths[0]) 76 | 77 | return paths 78 | 79 | def postprocess_paths(self, paths): 80 | # use fixed rhythm template to compose melody reduction 81 | chord_bins, reduction_report = \ 82 | zip(*[path_to_chord_bins(path, self._note_mat, self._chord_mat) for path in paths]) 83 | 84 | reduction_mats = [chord_bins_to_reduction_mat(self._chord_mat, cb, self._num_step_per_beat) 85 | for cb in chord_bins] 86 | 87 | self._reduction_mats = reduction_mats 88 | self._report = reduction_report 89 | 90 | def output(self, start_measure=None): 91 | start_measure = start_measure if start_measure is not None else \ 92 | self._start_measure 93 | start_beat = start_measure * self._num_beat_per_measure 94 | start_step = start_beat * self._num_step_per_beat 95 | 96 | note_mat = self._note_mat.copy() 97 | note_mat[:, 0] += start_step 98 | 99 | chord_mat = self._chord_mat.copy() 100 | chord_mat[:, 0] += start_beat 101 | 102 | reduction_mats = self._reduction_mats.copy() 103 | for red_mat in reduction_mats: 104 | red_mat[:, 0] += start_step 105 | return note_mat, chord_mat, reduction_mats 106 | 107 | def run(self, note_mat, chord_mat, start_measure, num_beat_per_measure, 108 | num_step_per_beat, num_path=1, plot_graph=False): 109 | 110 | # analyze input melody phrase and add music attributes to note_mat. Translate phrase start to zero. 111 | self.preprocess_data(note_mat, chord_mat, start_measure, 112 | num_beat_per_measure, num_step_per_beat) 113 | 114 | # run shortest-path algorithm. 115 | paths = self.algo(num_path, plot_graph=plot_graph) 116 | 117 | # use fixed rhythm template to compose reduced melody. 
118 | self.postprocess_paths(paths) 119 | 120 | # Translate phrase to actual phrase start. 121 | note_mat, chord_mat, reduction_mats = self.output() 122 | 123 | # clear stored data 124 | self.clear_data() 125 | 126 | return note_mat, chord_mat, reduction_mats 127 | 128 | def plot_graph(self, G, path): 129 | fig = plt.figure(figsize=(8, 6)) 130 | if len(G.nodes) == 0: 131 | return 132 | cmap = plt.cm.Greys 133 | pos_dict = {n: (self._note_mat[i, 0], self._note_mat[i, 1]) for i, n in enumerate(G.nodes)} 134 | edges, weights = zip(*nx.get_edge_attributes(G, 'weight').items()) 135 | weights = tuple(-w for w in weights) 136 | ax = plt.subplot() 137 | nx.draw_networkx(G, node_size=100, node_color='black', 138 | pos=pos_dict, edge_color=weights, edge_cmap=cmap) 139 | ax.set_axis_on() 140 | ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True) 141 | ax.set_xticks(np.arange(0, self._note_mat[:, 0].max(), self._num_step_per_beat)) 142 | 143 | new_G = nx.DiGraph() 144 | for i in range(len(self._note_mat)): 145 | G.add_node(i, data=self._note_mat[i]) 146 | 147 | path = path['path'] 148 | for i in range(len(path) - 1): 149 | new_G.add_edge(path[i], path[i + 1]) 150 | nx.draw_networkx_edges(new_G, pos_dict, edge_color='red') 151 | 152 | ax.grid() 153 | plt.show() 154 | 155 | def fill_data(self, note_mat, chord_mat, start_measure, 156 | num_beat_per_measure, num_step_per_beat): 157 | self._note_mat = note_mat 158 | self._chord_mat = chord_mat 159 | self._start_measure = start_measure 160 | self._num_beat_per_measure = num_beat_per_measure 161 | self._num_step_per_beat = num_step_per_beat 162 | 163 | def clear_data(self): 164 | self._note_mat = None 165 | self._chord_mat = None 166 | self._num_beat_per_measure = None 167 | self._num_step_per_beat = None 168 | self._start_measure = None 169 | -------------------------------------------------------------------------------- /data_utils/tonal_reduction_algo/postprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_chord_density(path, note_mat, chords): 5 | chord_density = [0 for _ in range(len(chords))] 6 | chord_bin = [[] for _ in range(len(chords))] 7 | chord_max_density = [chord[-1] for chord in chords] 8 | 9 | for pid, p in enumerate(path): 10 | note = note_mat[p] 11 | 12 | # use tonal chord id as the chord id if possible 13 | chord_id = int(note[6]) \ 14 | if int(note[6]) < len(chord_density) else int(note[5]) 15 | 16 | chord_density[chord_id] += 1 17 | chord_bin[chord_id].append(pid) 18 | 19 | is_overflow = [cd > cmd for cd, cmd in 20 | zip(chord_density, chord_max_density)] 21 | 22 | return chord_density, chord_max_density, chord_bin, is_overflow 23 | 24 | 25 | def path_to_chord_bins(path, note_mat, chord): 26 | if path is None: 27 | return None, None 28 | path, reduction_rate = path['path'], path['reduction_rate'] 29 | 30 | chord_density, chord_max_density, chord_bin, is_overflow = \ 31 | compute_chord_density(path, note_mat, chord) 32 | 33 | modify_list = [[] for _ in range(len(chord_density))] 34 | modify_chord_density = chord_density.copy() 35 | 36 | prev_cd = None 37 | for i in range(len(chord_density) - 1, -1, -1): 38 | cd, cmd, cb = chord_density[i], chord_max_density[i], chord_bin[i] 39 | 40 | current_cd = cd 41 | # check prolongation 42 | if not is_overflow[i]: 43 | prev_cd = current_cd 44 | continue 45 | 46 | # check prolongation 47 | for j, pid in enumerate(cb): 48 | p = path[pid] 49 | if pid != 0: 50 | prev_p = path[pid - 1] 51 | if 
note_mat[p, 2] == note_mat[prev_p, 2]: 52 | modify_list[i].append((j, 'r')) 53 | current_cd -= 1 54 | if current_cd <= cmd: 55 | break 56 | 57 | if current_cd > cmd: 58 | for j in range(len(cb) - 1, -1, -1): 59 | 60 | if prev_cd is None: 61 | break 62 | if j in [m[0] for m in modify_list[i]]: 63 | continue 64 | 65 | pid = cb[j] 66 | p = path[pid] 67 | note = note_mat[p] 68 | 69 | # check note is chord tone of chord[i + 1] 70 | if chord[i + 1, int(note[2] % 12) + 1] == 1 and prev_cd <= chord_max_density[i + 1] - 1: 71 | modify_list[i].append((j, 'm')) 72 | current_cd -= 1 73 | modify_chord_density[i + 1] += 1 74 | prev_cd += 1 75 | break 76 | 77 | if current_cd > cmd: 78 | # in this case drop note 79 | for j in range(len(cb) - 1, -1, -1): 80 | if j not in [m[0] for m in modify_list[i]]: 81 | modify_list[i].append((j, 'd')) 82 | current_cd -= 1 83 | if current_cd <= cmd: 84 | break 85 | 86 | prev_cd = current_cd 87 | modify_chord_density[i] = current_cd 88 | 89 | new_chord_bin = [[] for _ in range(len(chord_density))] 90 | num_prolonged, num_moved, num_dropped = 0, 0, 0 91 | num_note = 0 92 | 93 | for i in range(len(chord_density)): 94 | cb = chord_bin[i] 95 | modify = modify_list[i] 96 | num_note += 1 97 | for j, pid in enumerate(cb): 98 | if j in [m[0] for m in modify]: 99 | status = modify[[m[0] for m in modify].index(j)][1] 100 | if status == 'r': 101 | num_prolonged += 1 102 | elif status == 'd': 103 | num_dropped += 1 104 | else: 105 | num_moved += 1 106 | new_chord_bin[i + 1].append(note_mat[path[pid]]) 107 | else: 108 | new_chord_bin[i].append(note_mat[path[pid]]) 109 | postprocess_reduction_rate = 1 - num_dropped / num_note 110 | # compute duration 111 | report = {'num_prolonged': num_prolonged, 112 | 'num_moved': num_moved, 113 | 'num_dropped': num_dropped, 114 | 'red_rate_0': reduction_rate, 115 | 'red_rate_1': postprocess_reduction_rate, 116 | 'red_rate_final': postprocess_reduction_rate * reduction_rate} 117 | return new_chord_bin, report 118 | 119 | 120 | def chord_bins_to_reduction_mat(chord_mat, path_bin, nspb): 121 | if path_bin is None: 122 | return np.zeros((0, 3), dtype=np.int64) 123 | 124 | notes = [] 125 | # note_bin = [] 126 | for i, chord in enumerate(chord_mat): 127 | if len(path_bin[i]) == 0: 128 | continue 129 | chord_start = chord[0] 130 | chord_end = chord[-1] + chord_start 131 | n_note = len(path_bin[i]) 132 | len_chord = chord_end - chord_start 133 | 134 | if len_chord == 1: 135 | assert n_note <= 1 136 | durs = [1.] 137 | elif len_chord == 2: 138 | assert n_note <= 2 139 | durs = [2.] if n_note == 1 else [1., 1.] 140 | elif len_chord == 3: 141 | assert n_note <= 3 142 | if n_note == 1: 143 | durs = [3.] 144 | elif n_note == 2: 145 | durs = [2., 1.] 146 | else: 147 | durs = [1., 1., 1.] 148 | elif len_chord == 4: 149 | if n_note == 1: 150 | durs = [4.] 151 | elif n_note == 2: 152 | durs = [2., 2.] 153 | elif n_note == 3: 154 | durs = [2., 1., 1.] 155 | elif n_note == 4: 156 | durs = [1., 1., 1., 1.] 157 | else: 158 | raise AssertionError 159 | else: 160 | assert n_note <= len_chord 161 | durs = [len_chord - n_note + 1.] + [1.] 
* (n_note - 1) 162 | 163 | cur_start = chord_start * nspb 164 | 165 | for note_id, note in enumerate(path_bin[i]): 166 | ict = chord_mat[i, int(note[1] % 12) + 1] == 1 167 | dur = durs[note_id] * nspb 168 | notes.append((cur_start, note[1], dur, ict, i)) 169 | cur_start += dur 170 | 171 | new_notes = [list(notes[0])[0: 3]] 172 | for i in range(1, len(notes)): 173 | if notes[i][1] == notes[i - 1][1] and ( 174 | (not notes[i - 1][3] and notes[i][3]) or notes[i - 1][4] == 175 | notes[i][4]): 176 | new_notes[-1][2] = notes[i][0] - new_notes[-1][0] + notes[i][2] 177 | else: 178 | new_notes[-1][2] = notes[i][0] - new_notes[-1][0] 179 | new_notes.append(list(notes[i])[0: 3]) 180 | new_notes = np.array(new_notes).astype(np.int64) 181 | return new_notes 182 | -------------------------------------------------------------------------------- /data_utils/tonal_reduction_algo/preprocess.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | 4 | 5 | def remove_offset(note_mat, chord_mat, start_measure, measure_to_step_fn, 6 | measure_to_beat_fn): 7 | 8 | start_step = measure_to_step_fn(start_measure) 9 | start_beat = measure_to_beat_fn(start_measure) 10 | 11 | note_mat_ = note_mat.copy() 12 | note_mat_[:, 0] -= start_step 13 | 14 | chord_mat_ = chord_mat.copy() 15 | chord_mat_[:, 0] -= start_beat 16 | 17 | return note_mat_, chord_mat_ 18 | 19 | 20 | def chord_id_analysis(note_mat, chord_mat, step_to_beat_fn): 21 | """compute note to chord pointer""" 22 | chord_starts = chord_mat[:, 0] 23 | 24 | chord_ends = chord_mat[:, -1] + chord_starts # chord duration in beat + chord start 25 | 26 | note_onsets = note_mat[:, 0] 27 | 28 | onset_beats = step_to_beat_fn(note_onsets) 29 | 30 | # output: (note_ids, chord_ids) e.g., (0, 1, 2, 3, ...), (0, 0, 1, 1, ...) 
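# the (n_note, n_chord) boolean matrix below pairs every note onset beat with the
# chord whose [start, end) span contains it; e.g., with chord_starts = [0, 2, 4],
# chord_ends = [2, 4, 8] and onset_beats = [0, 1, 2, 3, 5], np.where returns
# (array([0, 1, 2, 3, 4]), array([0, 0, 1, 1, 2]))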
31 | chord_ids = np.where(np.logical_and( 32 | chord_starts <= onset_beats[:, np.newaxis], 33 | chord_ends > onset_beats[:, np.newaxis])) 34 | 35 | if not (chord_ids[0] == np.arange(0, len(note_onsets))).all(): 36 | raise ValueError("Input melody onsets cannot point to input chords.") 37 | 38 | return chord_ids[1] 39 | 40 | 41 | def chord_tone_analysis(note_mat, chord_mat, chord_ids): 42 | # output col 1: normal chord tone, col 2: anticipation 43 | n_note = note_mat.shape[0] 44 | 45 | chords = chord_mat[chord_ids] 46 | pitches = note_mat[:, 1].astype(np.int64) 47 | 48 | # find notes of regular chord tone: pitch in chord chroma 49 | is_reg_chord_tone = (chords[np.arange(0, n_note), 2 + pitches % 12] == 1).astype(np.int64) 50 | 51 | # find notes of anticipation 52 | # condition 1: next chord exists 53 | next_c_exist_rows = chord_ids < chord_mat.shape[0] - 1 54 | 55 | # condition 2: last note in the current chord 56 | last_note_in_chord_rows = chord_ids[0: -1] < chord_ids[1:] 57 | last_note_in_chord_rows = np.append(last_note_in_chord_rows, True) 58 | 59 | # anticipation_condition_rows = np.where(np.logical_and(next_c_exist_rows, 60 | # last_note_in_chord_rows))[0] 61 | anticipation_condition_rows = np.logical_and(next_c_exist_rows, last_note_in_chord_rows) 62 | 63 | # anticipation: is the chord tone of the next chord (and not a regular chord tone) 64 | is_anticiptation = np.zeros(n_note, dtype=np.int64) 65 | 66 | is_anticiptation[anticipation_condition_rows] = \ 67 | chord_mat[chord_ids[anticipation_condition_rows] + 1, 68 | 2 + pitches[anticipation_condition_rows] % 12] == 1 69 | 70 | is_anticiptation = \ 71 | np.logical_and(np.logical_not(is_reg_chord_tone), is_anticiptation) 72 | 73 | # chord tones are regular chord tones or anticipation 74 | is_chord_tone = np.logical_or(is_reg_chord_tone, is_anticiptation) 75 | 76 | # tonal chord ids are the actual chord that a note is in harmonic with 77 | tonal_chord_ids = chord_ids.copy() 78 | tonal_chord_ids[is_anticiptation] += 1 79 | 80 | return is_chord_tone, is_anticiptation, tonal_chord_ids 81 | 82 | 83 | def compute_tonal_note_onsets(onsets, chord_mat, chord_ids, is_anticipation, beat_to_step_fn): 84 | tonal_onset = onsets.copy() 85 | tonal_onset[is_anticipation] = beat_to_step_fn(chord_mat[chord_ids[is_anticipation] + 1, 0]) 86 | return tonal_onset 87 | 88 | 89 | def compute_bar_ids(onsets, tonal_onsets, step_to_measure_fn): 90 | bar_ids = step_to_measure_fn(onsets) 91 | tonal_bar_ids = step_to_measure_fn(tonal_onsets) 92 | return bar_ids, tonal_bar_ids 93 | 94 | 95 | def preprocess_data(note_mat, chord_mat, start_measure, 96 | measure_to_step_fn, measure_to_beat_fn, 97 | step_to_beat_fn, 98 | step_to_measure_fn, beat_to_measure_fn, 99 | beat_to_step_fn): 100 | 101 | # setting phrase start to 0 102 | note_mat, chord_mat = remove_offset(note_mat, chord_mat, start_measure, 103 | measure_to_step_fn, measure_to_beat_fn) 104 | 105 | onsets, pitches, durations = note_mat.T 106 | 107 | # harmony analysis 108 | chord_ids = chord_id_analysis(note_mat, chord_mat, step_to_beat_fn) 109 | 110 | is_chord_tone, is_anticipation, tonal_chord_ids = \ 111 | chord_tone_analysis(note_mat, chord_mat, chord_ids) 112 | 113 | # compute tonal onsets and bar ids 114 | tonal_onsets = compute_tonal_note_onsets(onsets, chord_mat, chord_ids, is_anticipation, beat_to_step_fn) 115 | bar_id, tonal_bar_id = compute_bar_ids(onsets, tonal_onsets, step_to_measure_fn) 116 | 117 | output_note_mat = \ 118 | np.stack([onsets, pitches, durations, bar_id, tonal_bar_id, 119 | chord_ids, 
tonal_chord_ids, is_chord_tone, tonal_onsets], -1) 120 | return output_note_mat, chord_mat 121 | -------------------------------------------------------------------------------- /data_utils/tonal_reduction_algo/shortest_path_algo.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | 4 | 5 | def compute_onset_type(onset, nbpm=4, nspb=4): 6 | output = np.zeros_like(onset).astype(np.int64) 7 | 8 | half = nspb * 2 if nbpm == 4 else nspb * 3 9 | quarter = nspb 10 | if nspb == 4: 11 | eighth = 2 12 | else: 13 | raise NotImplementedError 14 | 15 | is_half = onset % half == 0 16 | is_quarter = np.logical_and(onset % half != 0, onset % quarter == 0) 17 | is_eighth = np.logical_and(onset % quarter != 0, onset % eighth == 0) 18 | is_sixteenth = onset % eighth != 0 19 | 20 | output[is_half] = 0 21 | output[is_quarter] = 1 22 | output[is_eighth] = 2 23 | output[is_sixteenth] = 3 24 | return output 25 | 26 | 27 | def return_onset_score(note_matrix, nbpm, nspb): 28 | onset_type = compute_onset_type(note_matrix[:, -1], nbpm, nspb) 29 | rhythm_coef = np.zeros_like(onset_type).astype(np.float32) 30 | rhythm_coef[onset_type == 0] = 0.85 31 | rhythm_coef[onset_type == 1] = 0.95 32 | rhythm_coef[onset_type == 2] = 1.05 33 | rhythm_coef[onset_type == 3] = 1.15 34 | return rhythm_coef 35 | 36 | 37 | def return_chord_tone_score(note_matrix): 38 | chord_tone_coef = np.zeros(len(note_matrix), dtype=np.float32) 39 | chord_tone_coef[note_matrix[:, 7] == 1] = 0.6 40 | chord_tone_coef[note_matrix[:, 7] == 0] = 1.4 41 | return chord_tone_coef 42 | 43 | 44 | def return_pitch_score(note_matrix): 45 | pitches = note_matrix[:, 1] 46 | highest_pitch = pitches.max() 47 | lowest_pitch = pitches.min() 48 | mid_pitch = (highest_pitch + lowest_pitch) / 2 49 | if highest_pitch != lowest_pitch: 50 | pitch_coef = np.abs(pitches - mid_pitch) / (highest_pitch - mid_pitch) 51 | pitch_coef = (0.5 - pitch_coef) * 0.1 + 1 52 | else: 53 | pitch_coef = np.ones_like(pitches) 54 | return pitch_coef 55 | 56 | 57 | def compute_duration_type(duration, nbpm=4, nspb=4): 58 | half = nspb * 2 if nbpm == 4 else nspb * 3 59 | quarter = nspb 60 | if nspb == 4: 61 | eighth = 2 62 | else: 63 | raise NotImplementedError 64 | 65 | output = np.zeros_like(duration).astype(np.int64) 66 | 67 | is_half = duration >= eighth 68 | is_quarter = np.logical_and(duration < half, duration >= quarter) 69 | is_eighth = np.logical_and(duration < quarter, duration >= eighth) 70 | is_sixteenth = duration < eighth 71 | 72 | output[is_half] = 0 73 | output[is_quarter] = 1 74 | output[is_eighth] = 2 75 | output[is_sixteenth] = 3 76 | return output 77 | 78 | 79 | def return_duration_score(note_matrix, nbpm, nspb): 80 | duration_type = compute_duration_type(note_matrix[:, 2], nbpm, nspb) 81 | duration_coef = np.zeros_like(duration_type).astype(np.float32) 82 | duration_coef[duration_type == 0] = 0.9 83 | duration_coef[duration_type == 1] = 0.95 84 | duration_coef[duration_type == 2] = 1.05 85 | duration_coef[duration_type == 3] = 1.1 86 | return duration_coef 87 | 88 | 89 | def detect_rel_type(note_matrix, i, j): 90 | bar_thresh = 2 91 | relation = 5 92 | 93 | bar_id1, bar_id2 = note_matrix[i, 3], note_matrix[j, 3] 94 | pitch1, pitch2 = note_matrix[i, 1], note_matrix[j, 1] 95 | chord_id1, chord_id2 = note_matrix[i, 6], note_matrix[j, 6] 96 | 97 | if (bar_id1 - bar_id2) < bar_thresh: 98 | if pitch1 == pitch2: 99 | relation = 0 100 | elif (pitch1 - pitch2) % 12 == 0: 101 | relation = 1 102 | 103 | 
elif 1 <= np.abs(pitch1 - pitch2) <= 2: 104 | relation = 2 105 | elif np.abs(pitch1 - pitch2) % 12 in [1, 2, 10, 11]: 106 | relation = 3 107 | 108 | if relation == 5 and chord_id1 == chord_id2: 109 | relation = 4 110 | 111 | return relation 112 | 113 | 114 | def create_adj_matrix(note_matrix, nbpm, nspb, dist_param=1.6, 115 | rhy_param=1., chord_param=1., pitch_param=1., 116 | dur_param=1.): 117 | n_note = len(note_matrix) 118 | adj_matrix = -np.ones((n_note, n_note)) 119 | rel_score_map = {0: 0.1, 1: 1, 2: 0.3, 3: 1.3, 4: 1.5, 5: 3} 120 | 121 | rhythm_coef = return_onset_score(note_matrix, nbpm, nspb) ** rhy_param 122 | chord_tone_coef = return_chord_tone_score(note_matrix) ** chord_param 123 | pitch_coef = return_pitch_score(note_matrix) ** pitch_param 124 | dur_coef = return_duration_score(note_matrix, nbpm, nspb) ** dur_param 125 | 126 | for i in range(n_note): 127 | for j in range(i + 1, n_note): 128 | rel_type = detect_rel_type(note_matrix, i, j) 129 | rel_score = rel_score_map[rel_type] 130 | dist = (j - i) ** dist_param 131 | edge_weight = \ 132 | (rel_score + dist) * rhythm_coef[j] * chord_tone_coef[j] * \ 133 | pitch_coef[j] * dur_coef[j] 134 | adj_matrix[i, j] = edge_weight 135 | return adj_matrix 136 | 137 | 138 | def compute_path_length(G, p): 139 | length = 0 140 | for i in range(1, len(p)): 141 | length += G.edges[p[i - 1], p[i]]['weight'] 142 | return length 143 | 144 | 145 | def find_tonal_shortest_paths(note_matrix, nbpm, nspb, num_path=1, 146 | dist_param=1.6, rhy_param=1., chord_param=1., 147 | pitch_param=1., dur_param=1.): 148 | if note_matrix.shape[0] == 0: 149 | return [None for _ in range(num_path)], nx.DiGraph() 150 | n_node = len(note_matrix) 151 | 152 | adj_mat = create_adj_matrix(note_matrix, nbpm, nspb, 153 | dist_param, rhy_param, chord_param, 154 | pitch_param, dur_param) 155 | 156 | G = nx.DiGraph() 157 | for i in range(len(note_matrix)): 158 | G.add_node(i, data=note_matrix[i]) 159 | 160 | for i in range(len(G.nodes)): 161 | for j in range(i + 1, len(G.nodes)): 162 | if adj_mat[i, j] != -1: 163 | G.add_edge(i, j, weight=adj_mat[i, j]) 164 | 165 | all_paths = nx.shortest_simple_paths(G, source=0, target=n_node - 1, 166 | weight='weight') 167 | 168 | paths = [] 169 | for path_id, path in enumerate(all_paths): 170 | paths.append({'path': path, 'distance': compute_path_length(G, path), 171 | 'reduction_rate': len(path) / n_node}) 172 | if path_id == num_path - 1: 173 | break 174 | return paths, G 175 | -------------------------------------------------------------------------------- /data_utils/train_valid_split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | 5 | def create_train_valid_split(output_fn, n_sample, train_factor=9, seed=1234): 6 | np.random.seed(seed) 7 | 8 | data_ids = np.arange(0, n_sample) 9 | n_train = int(n_sample * train_factor / (train_factor + 1)) 10 | n_valid = n_sample - n_train 11 | train_inds = np.sort(np.random.choice(data_ids, n_train, replace=False)) 12 | valid_inds = np.sort(np.setdiff1d(data_ids, train_inds)) 13 | 14 | np.savez(output_fn, train_inds=train_inds, valid_inds=valid_inds) 15 | 16 | 17 | def load_split_file(split_fn): 18 | split_data = np.load(split_fn) 19 | train_inds = split_data['train_inds'] 20 | valid_inds = split_data['valid_inds'] 21 | return train_inds, valid_inds 22 | 23 | 24 | if __name__ == '__main__': 25 | output_dir = os.path.join('data', 'pop909_mel') 26 | os.makedirs(output_dir, exist_ok=True) 27 | 28 | output_fn = 
os.path.join(output_dir, 'split.npz') 29 | create_train_valid_split(output_fn, 909, train_factor=9, seed=1234) 30 | 31 | t_id, v_id = load_split_file(output_fn) 32 | print(f'n_train={len(t_id)}, n_valid={len(v_id)}') 33 | print(t_id) 34 | print(v_id) 35 | -------------------------------------------------------------------------------- /data_utils/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZWaang/whole-song-gen/11326220d2f032bdcdce60b8b8cf6891fb2ca308/data_utils/utils/__init__.py -------------------------------------------------------------------------------- /data_utils/utils/chord_reduction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def _parse_chord(c): 5 | """Returns onset, root, chroma, bass, duration.""" 6 | return c[0], c[1], c[2: 14], c[14], c[15] 7 | 8 | 9 | def _chroma1_in_chroma_2(chroma1, chroma2): 10 | return (chroma2[chroma1 != 0] == 1).all() 11 | 12 | 13 | def _share_chroma(chroma1, chroma2): 14 | return np.count_nonzero(np.logical_and(chroma1, chroma2)) >= 3 15 | 16 | 17 | def get_chord_reduction(chord_mat, time_unit=4): 18 | """ 19 | This function merge chords together. 9Two chords will be merged if: 20 | 1) They are within the same time_unit (usually 2 - 4 beats) 21 | 2) They have the same root or bass 22 | 3) One chord chroma includes the other or two chord chromas share more than three notes 23 | """ 24 | 25 | red_chord_mat = [] 26 | 27 | i = 0 28 | while i < len(chord_mat): 29 | c_i = chord_mat[i] 30 | onset_i, root_i, chroma_i, bass_i, duration_i = _parse_chord(chord_mat[i]) 31 | 32 | new_root, new_chroma, new_bass, acc_duration = root_i, chroma_i, bass_i, duration_i 33 | 34 | j = i + 1 35 | while acc_duration < time_unit and j < len(chord_mat): 36 | onset_j, root_j, chroma_j, bass_j, duration_j = _parse_chord(chord_mat[j]) 37 | 38 | if onset_j // time_unit != onset_i // time_unit: # not in the same time_unit 39 | break 40 | 41 | if root_i == root_j or bass_i == bass_j: 42 | if _chroma1_in_chroma_2(chroma_i, chroma_j): # chroma i in chroma j, use chord_j 43 | new_root, new_chroma, new_bass = root_j, chroma_j, bass_j 44 | acc_duration += duration_j 45 | j += 1 46 | 47 | elif _chroma1_in_chroma_2(chroma_j, chroma_i): # chroma j in chroma i, use chord_i 48 | acc_duration += duration_j 49 | j += 1 50 | 51 | elif _share_chroma(chroma_i, chroma_j): # share more than three chord tone, use chord_i 52 | acc_duration += duration_j 53 | j += 1 54 | else: 55 | break 56 | else: 57 | break 58 | 59 | red_chord_mat.append( 60 | np.concatenate([np.array([onset_i, new_root]), new_chroma, np.array([new_bass, acc_duration])])) 61 | i = j 62 | 63 | red_chord_mat = np.stack(red_chord_mat, 0).astype(np.int64) 64 | 65 | return red_chord_mat 66 | -------------------------------------------------------------------------------- /data_utils/utils/format_converter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.linear_model import LinearRegression 3 | import pretty_midi as pm 4 | 5 | 6 | def mono_note_matrix_to_pr_array(note_mat, total_length): 7 | total_length = total_length if total_length is not None else max(note_mat[:, 0] + note_mat[:, 2]) 8 | 9 | pr_array = np.ones(total_length, dtype=np.int64) * -1 10 | for note in note_mat: 11 | onset, pitch, duration = note 12 | pr_array[onset] = pitch 13 | pr_array[onset + 1: onset + duration] = pitch + 128 14 | 
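# encoding of pr_array per step: -1 = rest, 0-127 = note onset at that pitch,
# 128-255 = sustained pitch (stored as pitch + 128)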
return pr_array 15 | 16 | 17 | def note_matrix_to_piano_roll(note_mat, total_length=None): 18 | total_length = total_length if total_length is not None else max(note_mat[:, 0] + note_mat[:, 2]) 19 | 20 | piano_roll = np.zeros((2, total_length, 128), dtype=np.int64) 21 | for note in note_mat: 22 | onset, pitch, duration = note 23 | piano_roll[0, onset, pitch] = 1 24 | piano_roll[1, onset + 1: onset + duration, pitch] = 1 25 | return piano_roll 26 | 27 | 28 | def pr_array_to_piano_roll(pr_array): 29 | total_length = pr_array.shape[0] 30 | 31 | piano_roll = np.zeros((2, total_length, 128), dtype=np.int64) 32 | 33 | is_onset = np.logical_and(pr_array >= 0, pr_array < 128) 34 | is_sustain = pr_array >= 128 35 | 36 | piano_roll[0, is_onset, pr_array[is_onset]] = 1 37 | piano_roll[1, is_sustain, pr_array[is_sustain] - 128] = 1 38 | 39 | return piano_roll 40 | 41 | 42 | def pr_array_to_pr_contour(pr_array): 43 | pr_contour = pr_array.copy() 44 | 45 | onsets = np.where(np.logical_and(pr_array >= 0, pr_array < 128))[0] 46 | if len(onsets) == 0: 47 | pr_contour[:] = 60 48 | return pr_contour 49 | 50 | first_onset = onsets[0] 51 | first_pitch = pr_array[first_onset] 52 | 53 | pr_contour[0: first_onset] = first_pitch 54 | for i in range(first_onset, len(pr_contour)): 55 | pitch = pr_contour[i] 56 | if pitch >= 128: 57 | pr_contour[i] = pitch - 128 58 | elif pitch == -1: 59 | pr_contour[i] = pr_contour[i - 1] 60 | return pr_contour 61 | 62 | 63 | def extract_pitch_contour(pr_contour, nspb, stride=2): 64 | # see pivot point high low point. 65 | def direction_via_regression(contour, length=None): 66 | 67 | t = np.linspace(0, 1, length)[:, np.newaxis] 68 | reg = LinearRegression().fit(t[0: len(contour)], contour) 69 | a = reg.coef_[0] 70 | 71 | if a > 5: 72 | contour_type = 4 73 | elif 1 < a <= 5: 74 | contour_type = 3 75 | elif -1 <= a <= 1: 76 | contour_type = 2 77 | elif -5 <= a < -1: 78 | contour_type = 1 79 | else: 80 | contour_type = 0 81 | return contour_type 82 | 83 | contour = [] 84 | for i in range(0, len(pr_contour), int(nspb * stride)): 85 | segment = pr_contour[i: int(i + nspb * stride * 2)] 86 | direction = direction_via_regression(segment, int(nspb * stride * 2)) 87 | contour.append(direction) 88 | return np.array(contour, dtype=np.int64) 89 | 90 | 91 | def pr_array_to_rhythm(pr_array): 92 | rhythm_array = np.zeros_like(pr_array) 93 | 94 | sustain = pr_array >= 128 95 | rest = pr_array < 0 96 | 97 | rhythm_array[sustain] = 1 98 | rhythm_array[rest] = 2 99 | 100 | return rhythm_array 101 | 102 | 103 | def extract_rhythm_intensity(rhythm_array, nspb, stride=2, 104 | quantization_bin=4): 105 | def compute_intensity(rhy_segment): 106 | n_step = len(rhy_segment) 107 | 108 | onset = np.count_nonzero(rhy_segment == 0) / n_step 109 | rest = np.count_nonzero(rhy_segment == 2) / n_step 110 | 111 | # quantization 112 | onset_val = int(np.ceil(onset * (quantization_bin - 1))) 113 | rest_val = int(np.ceil(rest * (quantization_bin - 1))) 114 | 115 | return [onset_val, rest_val] 116 | 117 | intensity_array = [] 118 | for i in range(0, len(rhythm_array), int(nspb * stride)): 119 | segment = rhythm_array[i: int(i + nspb * stride * 2)] 120 | intensity = compute_intensity(segment) 121 | intensity_array.append(intensity) 122 | return np.array(intensity_array, dtype=np.int64) 123 | 124 | 125 | def chord_mat_to_chord_roll(chord, total_beat): 126 | chord_roll = np.zeros((6, total_beat, 12), dtype=np.int64) 127 | 128 | for c in chord: 129 | start_beat = c[0] 130 | chord_content = c[1: 15] 131 | dur_beat = c[15] 132 
| 133 | root = chord_content[0] 134 | bass = (chord_content[-1] + root) % 12 135 | chroma = chord_content[1: 13] 136 | # print(start_beat, total_beat) 137 | chord_roll[0, start_beat, root] = 1 138 | chord_roll[1, start_beat + 1: start_beat + dur_beat, root] = 1 139 | 140 | chord_roll[2, start_beat, :] = chroma 141 | chord_roll[3, start_beat + 1: start_beat + dur_beat, :] = chroma 142 | 143 | chord_roll[4, start_beat, bass] = 1 144 | chord_roll[5, start_beat + 1: start_beat + dur_beat, bass] = 1 145 | 146 | return chord_roll 147 | 148 | 149 | def chord_to_chord_roll(chord, total_beat): 150 | chord_roll = np.zeros((2, total_beat, 36), dtype=np.int64) 151 | 152 | for c in chord: 153 | start_beat = c[0] 154 | chord_content = c[1: 15] 155 | dur_beat = c[15] 156 | 157 | root = chord_content[0] 158 | bass = (chord_content[-1] + root) % 12 159 | chroma = chord_content[1: 13] 160 | 161 | chord_roll[0, start_beat, root] = 1 162 | chord_roll[0, start_beat, 12: 24] = chroma 163 | chord_roll[0, start_beat, 24 + bass] = 1 164 | chord_roll[1, start_beat + 1: start_beat + dur_beat, root] = 1 165 | chord_roll[1, start_beat + 1: start_beat + dur_beat, 12: 24] = chroma 166 | chord_roll[1, start_beat + 1: start_beat + dur_beat, 24 + bass] = 1 167 | 168 | return chord_roll 169 | 170 | 171 | def note_mat_to_notes(note_mat, bpm, factor=4, shift=0.): 172 | alpha = 60 / bpm / factor 173 | notes = [] 174 | for note in note_mat: 175 | onset, pitch, dur = note 176 | 177 | notes.append(pm.Note(100, int(pitch), onset * alpha + shift, (onset + dur) * alpha + shift)) 178 | return notes 179 | 180 | 181 | def chord_roll_to_chord_stack(chord_roll, nbpm, pad=True): 182 | # (6, T, 12) -> (6 * nbpm, T // nbpm, 12) 183 | n_channel, lgth, h = chord_roll.shape 184 | assert lgth % nbpm == 0 185 | lgth_ = lgth // nbpm 186 | chord_roll = chord_roll.copy().reshape((n_channel, lgth_, nbpm, h)) 187 | chord_roll = chord_roll.transpose((0, 2, 1, 3)) 188 | 189 | if pad and nbpm == 3: 190 | chord_roll = np.pad(chord_roll, pad_width=((0, 0), (0, 1), (0, 0), (0, 0)), 191 | mode='constant') 192 | chord_roll = chord_roll.reshape((n_channel // 2, 2, 4, lgth_, h)).sum(1) 193 | chord_roll = chord_roll.reshape((n_channel // 2 * 4, lgth_, h)) 194 | 195 | return chord_roll 196 | 197 | 198 | def reduction_roll_to_reduction_stack(reduction, nbpm, pad=True): 199 | # (2, T, 128) -> (nbpm, T // nbpm, 12) 200 | n_channel, lgth, h = reduction.shape 201 | assert lgth % nbpm == 0 202 | assert h == 128 203 | 204 | lgth_ = lgth // nbpm 205 | 206 | reduction = np.pad(reduction.copy(), pad_width=((0, 0), (0, 0), (0, 4)), 207 | mode='constant') 208 | 209 | reduction = reduction.reshape((n_channel, lgth_, nbpm, 11, 12)).sum(-2) 210 | reduction = reduction.transpose((0, 2, 1, 3)) 211 | if pad and nbpm == 3: 212 | reduction = np.pad(reduction, pad_width=((0, 0), (0, 1), (0, 0), (0, 0)), 213 | mode='constant') 214 | reduction = reduction.sum(0).reshape((4, lgth_, 12)) 215 | return reduction 216 | -------------------------------------------------------------------------------- /data_utils/utils/key_analysis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def key_estimation(mel_roll, chord_roll, 5 | phrase_starts, phrase_lengths, num_beat_per_measure, num_step_per_beat): 6 | # chord_roll = self.chord_to_compact_pianoroll() 7 | key_template = np.array([1., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 1.]) 8 | key_templates = \ 9 | np.stack([np.roll(key_template, step) for step in range(12)], 0) 10 | 11 | 
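# key_templates[k] is the major-scale pitch-class set rooted on pitch class k
# (e.g., key_templates[2] is D major); each phrase's chord-chroma histogram is
# scored against every template by dot product, near-ties are resolved in favor
# of the previous phrase's scale, and the tonic is then picked between the scale
# root and its relative minor by comparing melody-histogram counts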
scales = [] 12 | tonics = [] 13 | for i in range(len(phrase_lengths)): 14 | start_measure = phrase_starts[i] 15 | lgth = phrase_lengths[i] 16 | 17 | start_beat = start_measure * num_beat_per_measure 18 | end_beat = (start_beat + lgth) * num_beat_per_measure 19 | 20 | start_step = start_beat * num_step_per_beat 21 | end_step = end_beat * num_step_per_beat 22 | 23 | chroma_hist = chord_roll[2: 4, start_beat: end_beat].sum(0).sum(0) 24 | 25 | if not (chroma_hist == 0).all(): 26 | chroma_hist = chroma_hist / chroma_hist.sum() 27 | 28 | score = (chroma_hist[np.newaxis, :] @ key_templates.T)[0] 29 | max_val = score.max() 30 | cand_key = np.where(np.abs(score - max_val) < 1e-4)[0] 31 | if len(scales) > 0 and scales[-1] in cand_key: 32 | scale = scales[-1] 33 | else: 34 | scale = cand_key[0] 35 | scales.append(scale) 36 | 37 | mel_hist = mel_roll[:, start_step: end_step].sum(0).sum(0) 38 | major_score, minor_score = mel_hist[scale], mel_hist[(scale + 9) % 12] 39 | tonic = scale if major_score >= minor_score else (scale + 9) % 12 40 | tonics.append(tonic) 41 | 42 | scales = np.array(scales) 43 | tonics = np.array(tonics) 44 | keys = np.stack([tonics, scales], 0) 45 | return keys 46 | 47 | 48 | def key_to_key_roll(keys, total_measure, phrase_lengths, phrase_starts): 49 | key_template = np.array([1., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 1.]) 50 | key_templates = \ 51 | np.stack([np.roll(key_template, step) for step in range(12)], 0) 52 | 53 | key_roll = np.zeros((2, total_measure, 12), dtype=np.int64) 54 | 55 | for i in range(len(phrase_lengths)): 56 | start_measure = phrase_starts[i] 57 | lgth = phrase_lengths[i] 58 | key_roll[0, start_measure: start_measure + lgth, keys[0, i]] = 1 59 | key_roll[1, start_measure: start_measure + lgth] = key_templates[keys[1, i]] 60 | 61 | return key_roll 62 | 63 | 64 | def get_key_roll(mel_roll, chord_roll, phrase_starts, phrase_lengths, total_measure, 65 | num_beat_per_measure, num_step_per_beat): 66 | keys = key_estimation(mel_roll, chord_roll, phrase_starts, phrase_lengths, num_beat_per_measure, num_step_per_beat) 67 | key_roll = key_to_key_roll(keys, total_measure, phrase_lengths, phrase_starts) 68 | return key_roll 69 | -------------------------------------------------------------------------------- /data_utils/utils/phrase_analysis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def phrase_to_phrase_roll(phrase_starts, phrase_lengths, phrase_types, 5 | total_measure=None): 6 | def phrase_type_to_index(p_type): 7 | if p_type == 'A': 8 | return 0 9 | elif p_type == 'B': 10 | return 1 11 | elif p_type.isupper(): 12 | return 2 13 | elif p_type == 'i': 14 | return 3 15 | elif p_type in ['o', 'z']: 16 | return 4 17 | else: 18 | return 5 19 | 20 | total_measure = phrase_lengths.sum() if total_measure is None else total_measure 21 | 22 | measures = np.zeros((6, total_measure), dtype=np.float32) 23 | for start, length, ptype in zip(phrase_starts, phrase_lengths, phrase_types): 24 | phrase_index = phrase_type_to_index(ptype) 25 | measures[phrase_index, start: start + length] = \ 26 | np.linspace(1., 0., length, endpoint=False) 27 | return measures -------------------------------------------------------------------------------- /data_utils/utils/read_file.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mir_eval 3 | import os 4 | from .song_data_structure import McpaMusic 5 | 6 | 7 | def _cum_time_to_time_dur(d): 8 | d_cumsum = 
np.cumsum(d) 9 | starts = np.insert(d_cumsum, 0, 0)[0: -1] 10 | return starts 11 | 12 | 13 | def _parse_phrase_label(phrase_string): 14 | phrase_starts = [i for i, s in enumerate(phrase_string) if s.isalpha()] + [len(phrase_string)] 15 | phrase_names = [phrase_string[phrase_starts[i]: phrase_starts[i + 1]] for i in range(len(phrase_starts) - 1)] 16 | phrase_types = [pn[0] for pn in phrase_names] 17 | phrase_lgths = np.array([int(pn[1:]) for pn in phrase_names]) 18 | return phrase_names, phrase_types, phrase_lgths 19 | 20 | 21 | def read_label(label_fn): 22 | """label_fn is human_label1.txt or human_label2.txt in the song folder of the dataset.""" 23 | with open(label_fn) as f: 24 | phrase_label = f.readlines()[0].strip() 25 | phrase_names, phrase_types, phrase_lgths = _parse_phrase_label(phrase_label) 26 | 27 | phrase_starts = _cum_time_to_time_dur(phrase_lgths) 28 | 29 | phrases = [{'name': pn, 'type': pt, 'lgth': pl, 'start': ps} 30 | for pn, pt, pl, ps in zip(phrase_names, phrase_types, phrase_lgths, phrase_starts)] 31 | 32 | return phrases 33 | 34 | 35 | def read_melody(melody_fn): 36 | """melody_fn is melody.txt in the song folder of the dataset.""" 37 | 38 | # convert txt file to numpy array of (pitch, duration) 39 | with open(melody_fn) as f: 40 | melody = f.readlines() 41 | melody = [m.strip().split(' ') for m in melody] 42 | melody = np.array([[int(m[0]), int(m[1])] for m in melody]) 43 | 44 | # convert numpy array of (pitch, duration) to (onset, pitch, dur) 45 | starts = _cum_time_to_time_dur(melody[:, 1]) 46 | durs = melody[:, 1] 47 | pitches = melody[:, 0] 48 | is_pitch = melody[:, 0] != 0 49 | note_mat = np.stack([starts, pitches, durs], -1)[is_pitch] 50 | 51 | return note_mat 52 | 53 | 54 | def _read_chord_string(c): 55 | """ 56 | "\x01"s are replaced with " ". ")"s are added to "sus4(b7". 57 | (E.g., check with if c_name[-7:] == 'sus4(b7'). And there may be other annotation problems.) 58 | The files in this repo have been cleaned.
59 | """ 60 | c = c.strip().split(' ') 61 | c_name = c[0] 62 | c_dur = int(c[-1]) 63 | 64 | # cleaning the chord symbol 65 | c_name = c_name.replace('\x01', '') 66 | 67 | # convert to chroma representation 68 | root, chroma, bass = mir_eval.chord.encode(c_name) 69 | chroma = np.roll(chroma, shift=root) 70 | return np.concatenate([np.array([root]), chroma, np.array([bass]), np.array([c_dur])]) 71 | 72 | 73 | def read_chord(chord_fn): 74 | """chord_fn is finalized_chord.txt in the song folder of the dataset.""" 75 | with open(chord_fn) as f: 76 | chords = f.readlines() 77 | 78 | # convert chord text label to chroma representation 79 | chords = np.stack([_read_chord_string(c) for c in chords], 0) 80 | 81 | # convert chord to output chord matrix 82 | starts = _cum_time_to_time_dur(chords[:, -1]) 83 | chord_mat = np.concatenate([starts[:, np.newaxis], chords], -1) 84 | 85 | return chord_mat 86 | 87 | 88 | def read_data(data_fn, acc_fn, num_beat_per_measure=4, num_step_per_beat=4, 89 | clean_chord_unit=None, song_name=None, label=1): 90 | if label == 1: 91 | label_fn = os.path.join(data_fn, 'human_label1.txt') 92 | elif label == 2: 93 | label_fn = os.path.join(data_fn, 'human_label2.txt') 94 | else: 95 | raise NotImplementedError 96 | 97 | label = read_label(label_fn) 98 | 99 | melody_fn = os.path.join(data_fn, 'melody.txt') 100 | melody = read_melody(melody_fn) 101 | 102 | chord_fn = os.path.join(data_fn, 'finalized_chord.txt') 103 | chord = read_chord(chord_fn) 104 | 105 | clean_chord_unit = num_beat_per_measure if clean_chord_unit is None else clean_chord_unit 106 | 107 | acc_mats = np.load(os.path.join(acc_fn, 'acc_mat.npz')) 108 | bridge_track, piano_track = acc_mats['bridge'], acc_mats['piano'] 109 | acc = np.concatenate([bridge_track, piano_track], 0) 110 | acc = acc[acc[:, 0].argsort()] 111 | 112 | song = McpaMusic(melody, chord, acc, label, num_beat_per_measure, 113 | num_step_per_beat, song_name, clean_chord_unit) 114 | 115 | return song 116 | 117 | -------------------------------------------------------------------------------- /data_utils/utils/song_analyzer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ..tonal_reduction_algo.main import TrAlgo 3 | from .format_converter import note_matrix_to_piano_roll, chord_mat_to_chord_roll 4 | from .key_analysis import get_key_roll 5 | from .phrase_analysis import phrase_to_phrase_roll 6 | from .song_data_structure import McpaMusic 7 | from .chord_reduction import get_chord_reduction 8 | 9 | 10 | class LanguageExtractor: 11 | 12 | def __init__(self, song: McpaMusic): 13 | self.song = song 14 | 15 | self._mel_roll = None 16 | self._chd_roll = None 17 | self._song_dict = None 18 | 19 | def _phrase_to_beat(self, phrase): 20 | start_measure = self.song.phrase_starts[phrase] 21 | end_measure = self.song.phrase_lengths[phrase] + start_measure 22 | start_beat = start_measure * self.song.num_beat_per_measure 23 | end_beat = end_measure * self.song.num_beat_per_measure 24 | return start_beat, end_beat 25 | 26 | def _create_a_phrase_level_dict(self, phrase_id): 27 | start_measure = self.song.phrase_starts[phrase_id] 28 | phrase_length = self.song.phrase_lengths[phrase_id] 29 | end_measure = start_measure + phrase_length 30 | phrase_dict = { 31 | 'phrase_name': self.song.phrase_names[phrase_id], 32 | 'phrase_type': self.song.phrase_types[phrase_id], 33 | 'phrase_length': self.song.phrase_lengths[phrase_id], 34 | 'start_measure': start_measure, 35 | 'end_measure': end_measure, 36 | 
'length': phrase_length, 37 | 'mel_slice': None, 38 | 'chd_slice': None, 39 | } 40 | return phrase_dict 41 | 42 | def _create_song_level_dict(self, melody, chord): 43 | self._song_dict = { 44 | 'song_name': self.song.song_name, 45 | 'total_phrase': self.song.num_phrases, 46 | 'total_measure': self.song.total_measure, 47 | 'total_beat': self.song.total_beat, 48 | 'total_step': self.song.total_step, 49 | 'phrases': [self._create_a_phrase_level_dict(phrase_id) 50 | for phrase_id in range(self.song.num_phrases)] 51 | } 52 | self._fill_phrase_level_slices(melody, chord) 53 | 54 | def _fill_phrase_level_mel_slices(self, melody): 55 | n_note = melody.shape[0] 56 | 57 | onset_beats = melody[:, 0] // self.song.num_step_per_beat 58 | 59 | current_ind = 0 60 | for phrase_id, phrase in enumerate(self._song_dict['phrases']): 61 | start_beat, end_beat = self._phrase_to_beat(phrase_id) 62 | for i in range(current_ind, n_note): 63 | if onset_beats[i] >= end_beat: 64 | phrase[f'mel_slice'] = slice(current_ind, i) 65 | current_ind = i 66 | break 67 | else: 68 | phrase[f'mel_slice'] = slice(current_ind, n_note) 69 | current_ind = n_note 70 | 71 | def _fill_phrase_level_chd_slices(self, chord): 72 | n_chord = chord.shape[0] 73 | current_ind = 0 74 | for phrase_id, phrase in enumerate(self._song_dict['phrases']): 75 | start_beat, end_beat = self._phrase_to_beat(phrase_id) 76 | for i in range(current_ind, n_chord): 77 | if chord[i, 0] >= end_beat: 78 | phrase['chd_slice'] = slice(current_ind, i) 79 | current_ind = i 80 | break 81 | else: 82 | phrase['chd_slice'] = slice(current_ind, n_chord) 83 | current_ind = n_chord 84 | 85 | def _fill_phrase_level_slices(self, melody, chord): 86 | self._fill_phrase_level_mel_slices(melody) 87 | self._fill_phrase_level_chd_slices(chord) 88 | 89 | def extract_form(self): 90 | """Extract lang0: Form (key and phrase)""" 91 | 92 | key_roll = get_key_roll(self._mel_roll, self._chd_roll, 93 | self.song.phrase_starts, self.song.phrase_lengths, self.song.total_measure, 94 | self.song.num_beat_per_measure, self.song.num_step_per_beat) 95 | 96 | phrase_roll = phrase_to_phrase_roll(self.song.phrase_starts, self.song.phrase_lengths, 97 | self.song.phrase_types, self.song.total_measure) 98 | 99 | return {'key_roll': key_roll, 'phrase_roll': phrase_roll} 100 | 101 | def get_melody_reduction(self, num_reduction=1, melody=None, chord=None): 102 | melody = self.song.melody if melody is None else melody 103 | chord = self.song.chord if chord is None else chord 104 | 105 | self._create_song_level_dict(melody, chord) 106 | 107 | tr_algo = TrAlgo() 108 | 109 | nbpm = self.song.num_beat_per_measure 110 | nspb = self.song.num_step_per_beat 111 | 112 | reductions = [[] for _ in range(num_reduction)] 113 | 114 | for phrase in self._song_dict['phrases']: 115 | mel_slice = phrase['mel_slice'] 116 | chd_slice = phrase['chd_slice'] 117 | 118 | note_mat = melody[mel_slice].copy() 119 | chord_mat = chord[chd_slice].copy() 120 | 121 | start_measure = phrase['start_measure'] 122 | 123 | _, _, reduction_mats = \ 124 | tr_algo.run(note_mat, chord_mat, start_measure, nbpm, nspb, num_path=1, plot_graph=False) 125 | 126 | for i in range(num_reduction): 127 | red_mat = reduction_mats[i] 128 | red_mat[:, 0] = red_mat[:, 0] // self.song.num_step_per_beat 129 | red_mat[:, 2] = red_mat[:, 2] // self.song.num_step_per_beat 130 | reductions[i].append(red_mat) 131 | 132 | reductions = [np.concatenate(reductions[i], 0) for i in range(num_reduction)] 133 | 134 | return reductions 135 | 136 | def extract_counterpoint(self): 
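"""Extract lang1: Counterpoint (reduced melody and reduced chord)."""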
137 | rough_chord = get_chord_reduction(self.song.chord, self.song.clean_chord_unit) 138 | red_chd_roll = chord_mat_to_chord_roll(rough_chord, self.song.total_beat) 139 | 140 | reduction = self.get_melody_reduction(num_reduction=1, melody=self.song.melody, chord=rough_chord)[0] 141 | 142 | red_mel_roll = note_matrix_to_piano_roll(reduction, self.song.total_beat) 143 | 144 | return {'red_mel_roll': red_mel_roll, 'red_chd_roll': red_chd_roll} 145 | 146 | def extract_lead_sheet(self): 147 | mel_roll = note_matrix_to_piano_roll(self.song.melody, self.song.total_step) 148 | chd_roll = chord_mat_to_chord_roll(self.song.chord, self.song.total_beat) 149 | 150 | self._mel_roll = mel_roll 151 | self._chd_roll = chd_roll 152 | 153 | return {'mel_roll': mel_roll, 'chd_roll': chd_roll} 154 | 155 | def extract_accompaniment(self): 156 | acc_roll = note_matrix_to_piano_roll(self.song.acc, self.song.total_step) 157 | return {'acc_roll': acc_roll} 158 | 159 | def extract_all_hie_langs(self): 160 | accompaniment = self.extract_accompaniment() 161 | lead_sheet = self.extract_lead_sheet() 162 | counterpoint = self.extract_counterpoint() 163 | form = self.extract_form() 164 | 165 | return {'form': form, 'counterpoint': counterpoint, 'lead_sheet': lead_sheet, 'accompaniment': accompaniment} 166 | 167 | def analyze_for_training(self): 168 | # extract min and max melody pitch. In image representation, the lowest melody pitch should be higher than midi 169 | # pitch 48 after pitch augmentation. 170 | 171 | min_mel_pitch, max_mel_pitch = self.song.melody[:, 1].min(), self.song.melody[:, 1].max() 172 | 173 | languages = self.extract_all_hie_langs() 174 | 175 | return {'name': self.song.song_name, 'nbpm': self.song.num_beat_per_measure, 176 | 'nspb': self.song.num_step_per_beat, 177 | 'min_mel_pitch': min_mel_pitch, 'max_mel_pitch': max_mel_pitch, 178 | 'languages': languages} 179 | 180 | -------------------------------------------------------------------------------- /data_utils/utils/song_data_structure.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class McpaMusic: 5 | 6 | """ 7 | MCPA Music contains Melody, Chords, Phrase, and Accompaniment annotations. 
8 | """ 9 | 10 | def __init__(self, melody, chord, acc, phrase_label, 11 | num_beat_per_measure=4, num_step_per_beat=4, 12 | song_name=None, clean_chord_unit=4): 13 | self.song_name = song_name 14 | 15 | # structural attributes 16 | self.num_beat_per_measure = num_beat_per_measure 17 | self.num_step_per_beat = num_step_per_beat 18 | 19 | # phrase attributes 20 | self.phrase_names = \ 21 | np.array([pl['name'] for pl in phrase_label]) 22 | self.phrase_types = \ 23 | np.array([pl['type'] for pl in phrase_label]) 24 | self.phrase_starts = \ 25 | np.array([pl['start'] for pl in phrase_label]) 26 | self.phrase_lengths = \ 27 | np.array([pl['lgth'] for pl in phrase_label]) 28 | self.num_phrases = len(phrase_label) 29 | 30 | # melody, chord and accompaniment 31 | self.melody = melody 32 | self.chord = chord 33 | self.acc = acc 34 | 35 | # determine piece length from phrase label, melody, and chord input 36 | self.total_measure = self.compute_total_measure() 37 | self.total_beat = self.total_measure * self.num_beat_per_measure 38 | self.total_step = self.total_beat * self.num_step_per_beat 39 | 40 | # ensuring chord having a maximum duration of self.clearn_chord_unit 41 | self.clean_chord_unit = clean_chord_unit 42 | self.clean_chord() 43 | 44 | # pad chord (pad = last chord) to match self.total_beat 45 | self.regularize_chord() 46 | 47 | # pad phrase with label 'z' to match self.total_measure 48 | self.regularize_phrases() 49 | 50 | def compute_total_measure(self): 51 | # propose candidates from phrase, chord, melody and acc 52 | last_step_mel = (self.melody[:, 0] + self.melody[:, 2]).max() 53 | if self.acc is None: 54 | last_step = last_step_mel 55 | else: 56 | last_step_acc = (self.acc[:, 0] + self.acc[:, 2]).max() 57 | last_step = max(last_step_mel, last_step_acc) 58 | 59 | num_measure0 = int(np.ceil(last_step / self.num_step_per_beat / self.num_beat_per_measure)) 60 | 61 | last_beat = (self.chord[:, 0] + self.chord[:, -1]).max() 62 | num_measure1 = int(np.ceil(last_beat / self.num_beat_per_measure)) 63 | 64 | num_measure2 = sum(self.phrase_lengths) 65 | return max(num_measure0, num_measure1, num_measure2) 66 | 67 | def regularize_chord(self): 68 | chord = self.chord 69 | end_time = (self.chord[:, 0] + self.chord[:, -1]).max() 70 | fill_n_beat = self.total_beat - end_time 71 | if fill_n_beat == 0: 72 | return 73 | 74 | pad_durs = [self.clean_chord_unit] * (fill_n_beat // self.clean_chord_unit) 75 | if fill_n_beat - sum(pad_durs) > 0: 76 | pad_durs = [fill_n_beat - sum(pad_durs)] + pad_durs 77 | for d in pad_durs: 78 | stack_chord = chord[-1].copy() 79 | stack_chord[0] = chord[-1, 0] + chord[-1, -1] 80 | stack_chord[-1] = d 81 | 82 | chord = np.concatenate([chord, stack_chord[np.newaxis, :]], 0) 83 | self.chord = chord 84 | 85 | def regularize_phrases(self): 86 | original_phrase_length = sum(self.phrase_lengths) 87 | if self.total_measure == original_phrase_length: 88 | return 89 | 90 | extra_phrase_length = self.total_measure - original_phrase_length 91 | extra_phrase_name = 'z' + str(extra_phrase_length) 92 | 93 | self.phrase_names = np.append(self.phrase_names, extra_phrase_name) 94 | self.phrase_types = np.append(self.phrase_types, 'z') 95 | self.phrase_lengths = np.append(self.phrase_lengths, 96 | extra_phrase_length) 97 | self.phrase_starts = np.append(self.phrase_starts, 98 | original_phrase_length) 99 | 100 | def clean_chord(self): 101 | chord = self.chord 102 | unit = self.clean_chord_unit 103 | 104 | new_chords = [] 105 | n_chord = len(chord) 106 | for i in range(n_chord): 107 | 
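# split chord i at clean_chord_unit boundaries so that no resulting chord
# lasts longer than one unit or crosses a unit grid line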
chord_start = chord[i, 0] 108 | chord_dur = chord[i, -1] 109 | 110 | cum_dur = 0 111 | s = chord_start 112 | while cum_dur < chord_dur: 113 | d = min(unit - s % unit, chord_dur - cum_dur) 114 | c = chord[i].copy() 115 | c[0] = s 116 | c[-1] = d 117 | new_chords.append(c) 118 | 119 | s = s + d 120 | cum_dur += d 121 | 122 | new_chords = np.stack(new_chords, 0) 123 | self.chord = new_chords 124 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZWaang/whole-song-gen/11326220d2f032bdcdce60b8b8cf6891fb2ca308/experiments/__init__.py -------------------------------------------------------------------------------- /experiments/whole_song_gen.py: -------------------------------------------------------------------------------- 1 | # from data_utils import load_datasets 2 | import os 3 | from datetime import datetime 4 | from inference.generation_operations import FormGenOp, CounterpointGenOp, LeadSheetGenOp, AccompanimentGenOp 5 | from inference.utils import quantize_generated_form_batch, specify_form 6 | import numpy as np 7 | from params import params_frm, params_ctp, params_lsh, params_acc 8 | from model import get_model_path 9 | import torch 10 | from data_utils.midi_output import piano_roll_to_note_mat, note_mat_to_notes, create_pm_object 11 | 12 | 13 | class WholeSongGeneration: 14 | 15 | def __init__( 16 | self, 17 | frm_op: FormGenOp, 18 | ctp_op: CounterpointGenOp, 19 | lsh_op: LeadSheetGenOp, 20 | acc_op: AccompanimentGenOp, 21 | desc: str = None, 22 | random_n_autoreg: bool = False, 23 | ): 24 | self.frm_op = frm_op 25 | self.ctp_op = ctp_op 26 | self.lsh_op = lsh_op 27 | self.acc_op = acc_op 28 | self.random_n_autoreg = random_n_autoreg 29 | self.desc = desc 30 | print(f"Description of the experiment is: {self.desc}") 31 | 32 | def form_generation(self): 33 | print("Form generation...") 34 | frm_canvas, slices, gen_max_l = self.frm_op.create_canvas(n_sample=1, prompt=None) 35 | frm = self.frm_op.generation(frm_canvas, slices, gen_max_l, quantize=False, n_sample=1) 36 | frm, lengths, phrase_labels = quantize_generated_form_batch(frm) 37 | print(f"Length of the song: {lengths[0]}, phrase_label:\n{phrase_labels[0]}") 38 | return frm[:, :, 0: lengths[0]], lengths[0], phrase_labels[0] 39 | 40 | def counterpoint_generation(self, background_cond, n_sample, nbpm): 41 | print("Counterpoint generation...") 42 | background_cond = self.ctp_op.expand_background(background_cond, nbpm) 43 | ctp_canvas, slices, gen_max_l = \ 44 | self.ctp_op.create_canvas(background_cond, n_sample, nbpm, None, self.random_n_autoreg) 45 | print(f"Number of iterations: {len(slices)}") 46 | ctp = self.ctp_op.generation(ctp_canvas, slices, gen_max_l) 47 | ctp = np.stack(ctp, 0) 48 | return ctp 49 | 50 | def leadsheet_generation(self, background_cond, n_sample=1, nbpm=4, nspb=4): 51 | print("Lead Sheet generation...") 52 | background_cond = self.lsh_op.expand_background(background_cond, nspb) 53 | lsh_canvas, slices, gen_max_l = \ 54 | self.lsh_op.create_canvas(background_cond, n_sample, nbpm, nspb, None, self.random_n_autoreg) 55 | print(f"Number of iterations: {len(slices)}") 56 | lsh = self.lsh_op.generation(lsh_canvas, slices, gen_max_l) 57 | lsh = np.stack(lsh, 0) 58 | return lsh 59 | 60 | def accompaniment_generation(self, background_cond, n_sample=1, nbpm=4, nspb=4): 61 | print("Accompaniment generation...") 62 | acc_canvas, slices, gen_max_l = \ 
63 | self.acc_op.create_canvas(background_cond, n_sample, nbpm, nspb, None, self.random_n_autoreg) 64 | print(f"Number of iterations: {len(slices)}") 65 | lsh = self.acc_op.generation(acc_canvas, slices, gen_max_l) 66 | return lsh 67 | 68 | def main(self, n_sample, nbpm=4, nspb=4, phrase_string=None, key=0, is_major=True, demo_dir=None, bpm=90): 69 | if phrase_string is None: 70 | frm, _, phrase_string = self.form_generation() 71 | else: 72 | frm = np.expand_dims(specify_form(phrase_string, key, is_major), 0) 73 | 74 | ctp = self.counterpoint_generation(frm, n_sample, nbpm) 75 | lsh = self.leadsheet_generation(ctp, 1, nbpm, nspb) 76 | acc = self.accompaniment_generation(lsh, 1, nbpm, nspb) 77 | self.output(acc, phrase_string, key, is_major, demo_dir, bpm) 78 | 79 | def output(self, hie_langs, phrase_string, key, is_major, demo_dir, bpm=90): 80 | cur_time_str = f"{datetime.now().strftime('%m-%d_%H%M%S')}" 81 | exp_name = f"whole-song-gen-{cur_time_str}" 82 | exp_path = os.path.join(demo_dir, exp_name) 83 | 84 | os.makedirs(exp_path, exist_ok=True) 85 | 86 | # write description 87 | with open(os.path.join(exp_path, 'description.txt'), 'w') as file: 88 | file.write(self.desc) 89 | 90 | # write phrase_string 91 | if key is None: 92 | key = 'key: Key is generated (not specified). Visualization is not implemented.' 93 | is_major = 'is_major: Key is generated. Visualization is not implemented.' 94 | else: 95 | key = f"key: {key}" 96 | is_major = f"is_major: {is_major}" 97 | 98 | form = '\n'.join([phrase_string, key, is_major]) 99 | 100 | with open(os.path.join(exp_path, 'form.txt'), 'w') as file: 101 | file.write(form) 102 | 103 | # write midi 104 | for i in range(len(hie_langs)): 105 | acc = hie_langs[i][0: 2] 106 | lsh = hie_langs[i][2: 4] 107 | cpt = hie_langs[i][4: 6] 108 | 109 | # output 110 | nmat_red_mel, nmat_red_chd = piano_roll_to_note_mat(cpt, True, seperate_chord=True) 111 | notes_red_mel = note_mat_to_notes(nmat_red_mel, bpm, unit=0.25) 112 | notes_red_chd = note_mat_to_notes(nmat_red_chd, bpm, unit=0.25) 113 | 114 | nmat_mel, nmat_chd = piano_roll_to_note_mat(lsh, True, seperate_chord=True) 115 | notes_mel = note_mat_to_notes(nmat_mel, bpm, unit=0.25) 116 | notes_chd = note_mat_to_notes(nmat_chd, bpm, unit=0.25) 117 | 118 | notes_acc = note_mat_to_notes(piano_roll_to_note_mat(acc, False), bpm, unit=0.25) 119 | 120 | midi = create_pm_object(bpm, preset=5, 121 | notes_list=[notes_mel, notes_red_chd, notes_mel, notes_chd, notes_acc]) 122 | midi.write(os.path.join(exp_path, f'generation-{i}.mid')) 123 | 124 | @classmethod 125 | def init_pipeline(cls, frm_model_folder, ctp_model_folder, lsh_model_folder, acc_model_folder, 126 | frm_model_id='best', ctp_model_id='best', lsh_model_id='best', acc_model_id='best', 127 | use_autoreg_cond=True, use_external_cond=False, 128 | debug_mode=False, is_autocast_fp16=True, random_n_autoreg=False, device=None): 129 | if device is None: 130 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 131 | print(frm_model_id, ctp_model_id, lsh_model_id, acc_model_id) 132 | frm_model_path, frm_model_id, frm_desc = get_model_path(frm_model_folder, frm_model_id) 133 | ctp_model_path, ctp_model_id, ctp_desc = get_model_path(ctp_model_folder, ctp_model_id) 134 | lsh_model_path, lsh_model_id, lsh_desc = get_model_path(lsh_model_folder, lsh_model_id) 135 | acc_model_path, acc_model_id, acc_desc = get_model_path(acc_model_folder, acc_model_id) 136 | 137 | frm_op = FormGenOp(params_frm, frm_model_path, device, False, False, debug_mode, 
is_autocast_fp16) 138 | ctp_op = CounterpointGenOp(params_ctp, ctp_model_path, device, use_autoreg_cond, use_external_cond, 139 | debug_mode, is_autocast_fp16) 140 | lsh_op = LeadSheetGenOp(params_lsh, lsh_model_path, device, use_autoreg_cond, use_external_cond, 141 | debug_mode, is_autocast_fp16) 142 | acc_op = AccompanimentGenOp(params_acc, acc_model_path, device, use_autoreg_cond, use_external_cond, 143 | debug_mode, is_autocast_fp16) 144 | 145 | desc = f'm0-{frm_desc}-{frm_model_id}\nm1-{ctp_desc}-{ctp_model_id}\nm2-{lsh_desc}-{lsh_model_id}\n' \ 146 | f'm3-{acc_desc}-{acc_model_id}' 147 | 148 | return cls(frm_op, ctp_op, lsh_op, acc_op, desc, random_n_autoreg) 149 | 150 | -------------------------------------------------------------------------------- /inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZWaang/whole-song-gen/11326220d2f032bdcdce60b8b8cf6891fb2ca308/inference/__init__.py -------------------------------------------------------------------------------- /inference/generation_canvases.py: -------------------------------------------------------------------------------- 1 | from data_utils import LANGUAGE_DATASET_PARAMS, AUTOREG_PARAMS 2 | from data_utils.pytorch_datasets.base_class import select_prev_slices 3 | import numpy as np 4 | 5 | 6 | class GenerationCanvasBase: 7 | 8 | n_channels = None 9 | cur_channels = None 10 | max_l = None 11 | autoreg_max_l = None 12 | h = None 13 | phrase_channel = slice(-6, None) 14 | 15 | max_n_autoreg = None 16 | n_autoreg_prob = None 17 | autoreg_seg_lgth = None 18 | seg_pad_unit = None 19 | 20 | def __init__(self, generation, units, lengths, mask, random_n_autoreg=False): 21 | """generation area: """ 22 | self.generation = generation 23 | self.mask = mask 24 | self.units = units 25 | self.lengths = lengths 26 | 27 | self.batch_size = generation.shape[0] 28 | self.random_n_autoreg = random_n_autoreg 29 | self.external_cond = None 30 | 31 | def __len__(self): 32 | return self.batch_size 33 | 34 | def _get_phrase(self, item): 35 | return self.generation[item, self.phrase_channel] 36 | 37 | def select_autoreg_slices(self, song_id, start_id, scale_unit): 38 | # compute current measure_id of start_id 39 | t_m = start_id // scale_unit 40 | 41 | # retrieve phrase_roll in measure 42 | phrase_roll_m = self._get_phrase(song_id)[:, ::scale_unit, 0] 43 | 44 | # compute phrase type importance at start_id 45 | phrase_type_score = self._get_phrase(song_id)[:, start_id: start_id + self.max_l, 0].sum(-1) # (6, ) 46 | 47 | # sample number of segments 48 | if self.random_n_autoreg: 49 | n_seg = np.random.choice(np.arange(0, self.max_n_autoreg + 1), p=self.n_autoreg_prob) 50 | else: 51 | n_seg = self.max_n_autoreg 52 | 53 | autoreg_slices = select_prev_slices(t_m, phrase_roll_m, phrase_type_score, self.autoreg_seg_lgth, n_seg) 54 | 55 | return autoreg_slices 56 | 57 | def get_autoreg_cond(self, song_id, start_id, scale_unit): 58 | 59 | cond_img = -np.ones((self.n_channels, self.autoreg_max_l, self.h), dtype=np.float32) 60 | 61 | autoreg_slices = self.select_autoreg_slices(song_id, start_id, scale_unit) 62 | 63 | scale_unit_ = int(2 ** np.ceil(np.log2(scale_unit))) 64 | 65 | seg_lgth_unit = self.autoreg_seg_lgth * scale_unit_ + self.seg_pad_unit 66 | 67 | for i, slc in enumerate(autoreg_slices): 68 | seg_start, seg_end = slc[0] * scale_unit, slc[1] * scale_unit 69 | 70 | actual_seg_l = seg_end - seg_start 71 | 72 | tgt_l = self.autoreg_seg_lgth * scale_unit 73 | 74 | 
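# copy the selected past segment into slot i of the conditioning image; the
# unused tail of the slot (up to the nominal segment length) is zero-padded,
# and everything else keeps the fill value -1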
autoreg_img = self.generation[song_id, :, seg_start: seg_end] 75 | 76 | cond_img[:, i * seg_lgth_unit: i * seg_lgth_unit + actual_seg_l] = autoreg_img 77 | cond_img[:, i * seg_lgth_unit + actual_seg_l: i * seg_lgth_unit + tgt_l] = 0. 78 | 79 | return cond_img 80 | 81 | def get_batch_autoreg_cond(self, start_id, sel_song_ids=None): 82 | if self.cur_channels == self.n_channels: 83 | return None 84 | start_id = [start_id] * self.batch_size if isinstance(start_id, int) else start_id 85 | cond_img = -np.ones((self.batch_size, self.n_channels, self.autoreg_max_l, self.h), dtype=np.float32) 86 | for i in range(len(self)): 87 | scale_unit = self.units[i] 88 | cond_img[i] = self.get_autoreg_cond(i, start_id[i], scale_unit) 89 | return cond_img[sel_song_ids] 90 | 91 | def get_batch_background_cond(self, start_id, sel_song_ids=None): 92 | if self.cur_channels == self.n_channels: 93 | return None 94 | if isinstance(start_id, int): 95 | background_cond = self.generation[:, self.cur_channels:, start_id: start_id + self.max_l] 96 | else: 97 | background_cond = [ 98 | self.generation[i, self.cur_channels:, start_id[i]: start_id[i] + self.max_l] 99 | for i in range(len(self)) 100 | ] 101 | background_cond = np.stack(background_cond, axis=0) 102 | 103 | return background_cond[sel_song_ids] 104 | 105 | def get_batch_mask(self, start_id, sel_song_ids=None): 106 | if isinstance(start_id, int): 107 | mask = self.mask[:, :, start_id: start_id + self.max_l] 108 | else: 109 | mask = [ 110 | self.mask[i, :, start_id[i]: start_id[i] + self.max_l] 111 | for i in range(len(self)) 112 | ] 113 | mask = np.stack(mask, axis=0) 114 | return mask[sel_song_ids] 115 | 116 | def get_batch_cur_level(self, start_id, sel_song_ids=None): 117 | if isinstance(start_id, int): 118 | cur_level = self.generation[:, 0: self.cur_channels, start_id: start_id + self.max_l] 119 | else: 120 | cur_level = [ 121 | self.generation[i, 0: self.cur_channels, start_id[i]: start_id[i] + self.max_l] 122 | for i in range(len(self)) 123 | ] 124 | cur_level = np.stack(cur_level, axis=0) 125 | 126 | return cur_level[sel_song_ids] 127 | 128 | def get_batch_external_cond(self): 129 | """Not Implemented.""" 130 | if self.external_cond is None: 131 | return None 132 | else: 133 | raise NotImplementedError 134 | 135 | def check_is_generated(self, start_id, end_id): 136 | start_id = [start_id] * self.batch_size if isinstance(start_id, int) else start_id 137 | end_id = [end_id] * self.batch_size if isinstance(end_id, int) else end_id 138 | return np.array([(self.mask[i, :, s: e] == 1).all() for i, (s, e) in enumerate(zip(start_id, end_id))], 139 | dtype=np.bool) 140 | 141 | def write_generation(self, new_generation, start_id, end_id, sel_song_ids=None, quantize=True): 142 | if quantize: 143 | new_generation[:, 0: self.cur_channels][new_generation[:, 0: self.cur_channels] > 0.5] = 1. 144 | new_generation[:, 0: self.cur_channels][new_generation[:, 0: self.cur_channels] < 0.9] = 0. 
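# the two assignments above binarize the generated channels at 0.5: values
# above 0.5 are first promoted to 1., so the following < 0.9 check only zeroes
# out entries that were not promoted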
145 | 146 | if sel_song_ids is None: 147 | sel_song_ids = np.arange(self.batch_size) 148 | else: 149 | sel_song_ids = np.where(sel_song_ids)[0] 150 | 151 | start_id = [start_id] * len(sel_song_ids) if isinstance(start_id, int) else start_id 152 | end_id = [end_id] * len(sel_song_ids) if isinstance(end_id, int) else end_id 153 | for i, (s, e) in enumerate(zip(start_id, end_id)): 154 | song_id = sel_song_ids[i] 155 | self.generation[song_id, 0: self.cur_channels, s: e] = new_generation[i, :, 0: e - s] 156 | self.mask[song_id, 0: self.cur_channels, s: e] = 1 157 | 158 | def to_output(self): 159 | return [self.generation[i, :, 0: self.lengths[i]] for i in range(len(self))] 160 | 161 | 162 | class FormCanvas(GenerationCanvasBase): 163 | 164 | n_channels = LANGUAGE_DATASET_PARAMS['form']['n_channel'] 165 | cur_channels = LANGUAGE_DATASET_PARAMS['form']['cur_channel'] 166 | max_l = LANGUAGE_DATASET_PARAMS['form']['max_l'] 167 | h = LANGUAGE_DATASET_PARAMS['form']['h'] 168 | 169 | 170 | class CounterpointCanvas(GenerationCanvasBase): 171 | 172 | n_channels = LANGUAGE_DATASET_PARAMS['counterpoint']['n_channel'] 173 | cur_channels = LANGUAGE_DATASET_PARAMS['counterpoint']['cur_channel'] 174 | max_l = LANGUAGE_DATASET_PARAMS['counterpoint']['max_l'] 175 | h = LANGUAGE_DATASET_PARAMS['counterpoint']['h'] 176 | 177 | autoreg_max_l = AUTOREG_PARAMS['counterpoint']['autoreg_max_l'] 178 | max_n_autoreg = AUTOREG_PARAMS['counterpoint']['max_n_autoreg'] 179 | n_autoreg_prob = AUTOREG_PARAMS['counterpoint']['n_autoreg_prob'] 180 | autoreg_seg_lgth = AUTOREG_PARAMS['counterpoint']['autoreg_seg_lgth'] 181 | seg_pad_unit = AUTOREG_PARAMS['counterpoint']['seg_pad_unit'] 182 | 183 | 184 | class LeadSheetCanvas(GenerationCanvasBase): 185 | 186 | n_channels = LANGUAGE_DATASET_PARAMS['lead_sheet']['n_channel'] 187 | cur_channels = LANGUAGE_DATASET_PARAMS['lead_sheet']['cur_channel'] 188 | max_l = LANGUAGE_DATASET_PARAMS['lead_sheet']['max_l'] 189 | h = LANGUAGE_DATASET_PARAMS['lead_sheet']['h'] 190 | 191 | autoreg_max_l = AUTOREG_PARAMS['lead_sheet']['autoreg_max_l'] 192 | max_n_autoreg = AUTOREG_PARAMS['lead_sheet']['max_n_autoreg'] 193 | n_autoreg_prob = AUTOREG_PARAMS['lead_sheet']['n_autoreg_prob'] 194 | autoreg_seg_lgth = AUTOREG_PARAMS['lead_sheet']['autoreg_seg_lgth'] 195 | seg_pad_unit = AUTOREG_PARAMS['lead_sheet']['seg_pad_unit'] 196 | 197 | 198 | class AccompanimentCanvas(GenerationCanvasBase): 199 | 200 | n_channels = LANGUAGE_DATASET_PARAMS['accompaniment']['n_channel'] 201 | cur_channels = LANGUAGE_DATASET_PARAMS['accompaniment']['cur_channel'] 202 | max_l = LANGUAGE_DATASET_PARAMS['accompaniment']['max_l'] 203 | h = LANGUAGE_DATASET_PARAMS['accompaniment']['h'] 204 | 205 | autoreg_max_l = AUTOREG_PARAMS['accompaniment']['autoreg_max_l'] 206 | max_n_autoreg = AUTOREG_PARAMS['accompaniment']['max_n_autoreg'] 207 | n_autoreg_prob = AUTOREG_PARAMS['accompaniment']['n_autoreg_prob'] 208 | autoreg_seg_lgth = AUTOREG_PARAMS['accompaniment']['autoreg_seg_lgth'] 209 | seg_pad_unit = AUTOREG_PARAMS['accompaniment']['seg_pad_unit'] 210 | -------------------------------------------------------------------------------- /inference/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def quantize_generated_form(form, song_end_thresh=3, phrase_start_thresh=0.6): 5 | # form: (6, 256, 16) 6 | cleaned_form = np.zeros_like(form) 7 | 8 | key = form[0: 2, :, 0: 12] # (2, 256, 12) 9 | 10 | # determine song end 11 | song_end = 
np.where(key.sum(-1).sum(0) < song_end_thresh)[0] 12 | song_end = key.shape[1] if len(song_end) == 0 else song_end[0] 13 | 14 | # quantize key 15 | key[key > 0.5] = 1. 16 | key[key < 0.95] = 0. 17 | key[:, song_end:] = 0. 18 | 19 | # quantize phrase to discrete representation 20 | phrase_roll = form[2:].mean(-1) # (6, 256) 21 | phrase_value = phrase_roll.sum(0) # (256,) 22 | 23 | phrases = [] 24 | phrase_starts = [] 25 | phrase_types = [] 26 | phrase_lengths = [] 27 | 28 | for t in range(song_end): 29 | if t == 0 or (phrase_value[t] > phrase_value[t - 1] and phrase_value[t] > phrase_start_thresh): 30 | phrase_starts.append(t) 31 | phrase_types.append(phrase_roll[:, t].argmax()) 32 | phrase_lengths.append(1) 33 | else: 34 | phrase_lengths[-1] += 1 35 | phrases.append((phrase_starts, phrase_types, phrase_lengths)) 36 | phrase_label = phrase_type_to_string(phrase_starts, phrase_types, phrase_lengths) 37 | 38 | # create new continuous phrase roll 39 | new_phrase_roll = np.zeros_like(phrase_roll) 40 | for ps, pt, pl in zip(phrase_starts, phrase_types, phrase_lengths): 41 | new_phrase_roll[int(pt), int(ps): int(ps) + int(pl)] = np.linspace(1, 0, int(pl), endpoint=False) 42 | 43 | cleaned_form[0: 2, :, 0: 12] = key 44 | cleaned_form[2:] = new_phrase_roll[:, :, np.newaxis] 45 | 46 | return cleaned_form, song_end, phrase_label 47 | 48 | 49 | def phrase_type_to_string(phrase_starts, phrase_types, phrase_lengths): 50 | phrase_type_mapping = ['A', 'B', 'X', 'i', 'o', 'b'] 51 | phrase_label = '' 52 | for ps, pt, pl in zip(phrase_starts, phrase_types, phrase_lengths): 53 | phrase_type = phrase_type_mapping[int(pt)] 54 | phrase_label += f"{int(ps)}: {phrase_type}{int(pl)}\n" 55 | return phrase_label 56 | 57 | 58 | def quantize_generated_form_batch(forms, song_end_thresh=3, phrase_start_thresh=0.6): 59 | 60 | cleaned_forms = np.zeros_like(forms) 61 | phrase_labels = [] 62 | n_measures = [] 63 | 64 | for i, form in enumerate(forms): 65 | cleaned_form, song_end, phrase_label = quantize_generated_form(form, song_end_thresh, phrase_start_thresh) 66 | cleaned_forms[i] = cleaned_form 67 | phrase_labels.append(phrase_label) 68 | n_measures.append(song_end) 69 | return cleaned_forms, n_measures, phrase_labels 70 | 71 | 72 | def phrase_config_from_string(phrase_annotation): 73 | index = 0 74 | phrase_configuration = [] 75 | while index < len(phrase_annotation): 76 | label = phrase_annotation[index] 77 | index += 1 78 | n_bars = '' 79 | while index < len(phrase_annotation) and phrase_annotation[index].isdigit(): 80 | n_bars += phrase_annotation[index] 81 | index += 1 82 | phrase_configuration.append((label, int(n_bars))) 83 | return phrase_configuration 84 | 85 | 86 | def phrase_type_to_phrase_type_id(p_type): 87 | if p_type == 'A': 88 | return 0 89 | elif p_type == 'B': 90 | return 1 91 | elif p_type.isupper(): 92 | return 2 93 | elif p_type == 'i': 94 | return 3 95 | elif p_type in ['o', 'z']: 96 | return 4 97 | else: 98 | return 5 99 | 100 | 101 | def phrase_string_to_roll(phrase_string): 102 | phrase_config = phrase_config_from_string(phrase_string) 103 | total_measure = sum([p[1] for p in phrase_config]) 104 | 105 | phrase_mat = np.zeros((6, total_measure, 16)) 106 | cur_measure = 0 107 | for phrase_type, phrase_length in phrase_config: 108 | phrase_value = np.linspace(1, 0, phrase_length, endpoint=False) 109 | phrase_type_id = phrase_type_to_phrase_type_id(phrase_type) 110 | phrase_mat[phrase_type_id, cur_measure: cur_measure + phrase_length, :] = phrase_value[:, np.newaxis] 111 | cur_measure += phrase_length 
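# At this point phrase_mat is a (6, total_measure, 16) roll: each phrase occupies the channel of its
# type and counts down linearly from 1 toward 0 over its measures (an 8-bar phrase yields 1, 7/8, ..., 1/8),
# which is the same encoding that quantize_generated_form recovers from generated output.
# e.g. phrase_string_to_roll('i4A8B8o4') has shape (6, 24, 16).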
112 | return phrase_mat 113 | 114 | 115 | def specify_form(phrase_string, key, is_major=True): 116 | major_template = np.array([1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1]) 117 | phrase_mat = phrase_string_to_roll(phrase_string) 118 | form = np.zeros((8, phrase_mat.shape[1], 16)) 119 | form[0, :, key] = 1. 120 | shift = key if is_major else key - 3 121 | form[1, :, 0: 12] = np.roll(major_template, shift=key) 122 | form[2:, :, 0: 16] = phrase_mat 123 | return form 124 | -------------------------------------------------------------------------------- /inference_whole_song.py: -------------------------------------------------------------------------------- 1 | from experiments.whole_song_gen import WholeSongGeneration 2 | import torch 3 | from argparse import ArgumentParser 4 | 5 | 6 | DEFAULT_FRM_MODEL_FOLDER = 'results_default/frm---/v-default' 7 | DEFAULT_CTP_MODEL_FOLDER = 'results_default/ctp-a-b-/v-default' 8 | DEFAULT_LSH_MODEL_FOLDER = 'results_default/lsh-a-b-/v-default' 9 | DEFAULT_ACC_MODEL_FOLDER = 'results_default/acc-a-b-/v-default' 10 | 11 | DEFAULT_DEMO_DIR = 'demo' 12 | 13 | 14 | def init_parser(): 15 | parser = ArgumentParser(description='inference a whole-song generation experiment') 16 | parser.add_argument( 17 | "--demo_dir", 18 | default=DEFAULT_DEMO_DIR, 19 | help='directory in which to generated samples' 20 | ) 21 | parser.add_argument("--mpath0", default=DEFAULT_FRM_MODEL_FOLDER, help="Form generation model path") 22 | parser.add_argument("--mid0", default='default', help="Form generation model id") 23 | 24 | parser.add_argument("--mpath1", default=DEFAULT_CTP_MODEL_FOLDER, help="Counterpoint generation model path") 25 | parser.add_argument("--mid1", default='default', help="Counterpoint generation model id") 26 | 27 | parser.add_argument("--mpath2", default=DEFAULT_LSH_MODEL_FOLDER, help="Lead Sheet generation model path") 28 | parser.add_argument("--mid2", default='default', help="Lead Sheet generation model id") 29 | 30 | parser.add_argument("--mpath3", default=DEFAULT_ACC_MODEL_FOLDER, help="Accompaniment generation model path") 31 | parser.add_argument("--mid3", default='default', help="Accompaniment generation model id") 32 | 33 | parser.add_argument("--nsample", default=1, type=int, help="Number of generated samples") 34 | 35 | parser.add_argument("--pstring", help="Specify phrase structure. 
If specified, key must be specified.") 36 | 37 | parser.add_argument("--nbpm", default=4, type=int, help="Number of beats per measure") 38 | 39 | parser.add_argument("--key", default=0, type=int, help="Tonic of the key (0 - 11)") 40 | 41 | parser.add_argument('--minor', action='store_false', help="Whether to generated in minor key.") 42 | 43 | parser.add_argument('--debug', action='store_true', help="Whether to use a toy dataset") 44 | 45 | return parser 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = init_parser() 50 | args = parser.parse_args() 51 | 52 | if torch.cuda.is_available(): 53 | device = torch.device('cuda') 54 | elif torch.backends.mps.is_available(): 55 | device = torch.device('mps') 56 | else: 57 | device = torch.device('cpu') 58 | 59 | whole_song_expr = WholeSongGeneration.init_pipeline( 60 | frm_model_folder=args.mpath0, 61 | ctp_model_folder=args.mpath1, 62 | lsh_model_folder=args.mpath2, 63 | acc_model_folder=args.mpath3, 64 | frm_model_id=args.mid0, 65 | ctp_model_id=args.mid1, 66 | lsh_model_id=args.mid2, 67 | acc_model_id=args.mid3, 68 | debug_mode=args.debug, 69 | device=None 70 | ) 71 | 72 | whole_song_expr.main( 73 | n_sample=args.nsample, 74 | nbpm=4, 75 | nspb=4, 76 | phrase_string=args.pstring, 77 | key=args.key, 78 | is_major=args.minor, 79 | demo_dir=args.demo_dir 80 | ) 81 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .stable_diffusion.latent_diffusion import LatentDiffusion 2 | from .model_sdf import Diffpro_SDF 3 | from .stable_diffusion.model.unet import UNetModel 4 | from .stable_diffusion.model.autoreg_cond_encoders import * 5 | from .stable_diffusion.model.external_cond_encoders import * 6 | 7 | autoreg_enc_dict = {'frm': None, 'ctp': CtpAutoregEncoder, 'lsh': LshAutoregEncoder, 'acc': AccAutoregEncoder} 8 | external_enc_dict = {'frm': None, 'ctp': CtpExternalEncoder, 'lsh': LshExternalEncoder, 'acc': AccExternalEncoder} 9 | 10 | 11 | def init_ldm_model(mode, use_autoreg_cond, use_external_cond, params, debug_mode=False): 12 | unet_model = UNetModel( 13 | in_channels=params.in_channels, 14 | out_channels=params.out_channels, 15 | channels=params.channels, 16 | attention_levels=params.attention_levels, 17 | n_res_blocks=params.n_res_blocks, 18 | channel_multipliers=params.channel_multipliers, 19 | n_heads=params.n_heads, 20 | tf_layers=params.tf_layers, 21 | d_cond=params.d_cond, 22 | ) 23 | 24 | autoreg_enc_cls = autoreg_enc_dict[mode] if use_autoreg_cond else None 25 | external_enc_cls = external_enc_dict[mode] if use_external_cond else None 26 | 27 | autoreg_cond_enc = None if autoreg_enc_cls is None else autoreg_enc_cls() 28 | external_cond_enc = None if external_enc_cls is None else external_enc_cls.create_model() 29 | 30 | ldm_model = LatentDiffusion( 31 | unet_model=unet_model, 32 | autoencoder=None, 33 | autoreg_cond_enc=autoreg_cond_enc, 34 | external_cond_enc=external_cond_enc, 35 | latent_scaling_factor=params.latent_scaling_factor, 36 | n_steps=params.n_steps, 37 | linear_start=params.linear_start, 38 | linear_end=params.linear_end, 39 | debug_mode=debug_mode 40 | ) 41 | 42 | return ldm_model 43 | 44 | 45 | def init_diff_pro_sdf(ldm_model, params, device): 46 | return Diffpro_SDF(ldm_model).to(device) 47 | 48 | 49 | def get_model_path(model_dir, model_id=None): 50 | model_desc = os.path.basename(model_dir) 51 | if model_id is None: 52 | model_path = os.path.join(model_dir, 'chkpts', 
'weights.pt') 53 | 54 | # retrieve real model_id from the actual file weights.pt is pointing to 55 | model_id = os.path.basename(os.path.realpath(model_path)).split('-')[1].split('.')[0] 56 | 57 | elif model_id == 'best': 58 | model_path = os.path.join(model_dir, 'chkpts', 'weights_best.pt') 59 | # retrieve real model_id from the actual file weights.pt is pointing to 60 | model_id = os.path.basename(os.path.realpath(model_path)).split('-')[1].split('.')[0] 61 | elif model_id == 'default': 62 | model_path = os.path.join(model_dir, 'chkpts', 'weights_default.pt') 63 | if not os.path.exists(model_path): 64 | return get_model_path(model_dir, 'best') 65 | else: 66 | model_path = f"{model_dir}/chkpts/weights-{model_id}.pt" 67 | return model_path, model_id, model_desc 68 | -------------------------------------------------------------------------------- /model/model_sdf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .stable_diffusion.latent_diffusion import LatentDiffusion 4 | 5 | 6 | class Diffpro_SDF(nn.Module): 7 | 8 | def __init__( 9 | self, 10 | ldm: LatentDiffusion, 11 | ): 12 | """ 13 | cond_type: {chord, texture} 14 | cond_mode: {cond, mix, uncond} 15 | mix: use a special condition for unconditional learning with probability of 0.2 16 | use_enc: whether to use pretrained chord encoder to generate encoded condition 17 | """ 18 | super(Diffpro_SDF, self).__init__() 19 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 20 | self.ldm = ldm 21 | 22 | @classmethod 23 | def load_trained( 24 | cls, 25 | ldm, 26 | chkpt_fpath, 27 | ): 28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | model = cls(ldm) 30 | trained_leaner = torch.load(chkpt_fpath, map_location=device) 31 | try: 32 | model.load_state_dict(trained_leaner["model"]) 33 | except RuntimeError: 34 | model_dict = trained_leaner["model"] 35 | model_dict = {k.replace('cond_enc', 'autoreg_cond_enc'): v for k, v in model_dict.items()} 36 | model_dict = {k.replace('style_enc', 'external_cond_enc'): v for k, v in model_dict.items()} 37 | model.load_state_dict(model_dict) 38 | return model 39 | 40 | def p_sample(self, xt: torch.Tensor, t: torch.Tensor): 41 | return self.ldm.p_sample(xt, t) 42 | 43 | def q_sample(self, x0: torch.Tensor, t: torch.Tensor): 44 | return self.ldm.q_sample(x0, t) 45 | 46 | def get_loss_dict(self, batch, step): 47 | """ 48 | z_y is the stuff the diffusion model needs to learn 49 | """ 50 | # x = batch.float().to(self.device) 51 | 52 | x, autoreg_cond, external_cond = batch 53 | loss = self.ldm.loss(x, autoreg_cond, external_cond) 54 | return {"loss": loss} 55 | -------------------------------------------------------------------------------- /model/stable_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZWaang/whole-song-gen/11326220d2f032bdcdce60b8b8cf6891fb2ca308/model/stable_diffusion/__init__.py -------------------------------------------------------------------------------- /model/stable_diffusion/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .contperceptual import LPIPSWithDiscriminator 2 | -------------------------------------------------------------------------------- /model/stable_diffusion/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as 
nn 3 | 4 | # from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | from .vqperceptual import * 6 | 7 | 8 | class LPIPSWithDiscriminator(nn.Module): 9 | def __init__( 10 | self, 11 | disc_start, 12 | logvar_init=0.0, 13 | kl_weight=1.0, 14 | pixelloss_weight=1.0, 15 | disc_num_layers=3, 16 | disc_in_channels=3, 17 | disc_factor=1.0, 18 | disc_weight=1.0, 19 | perceptual_weight=1.0, 20 | use_actnorm=False, 21 | disc_conditional=False, 22 | disc_loss="hinge" 23 | ): 24 | 25 | super().__init__() 26 | assert disc_loss in ["hinge", "vanilla"] 27 | self.kl_weight = kl_weight 28 | self.pixel_weight = pixelloss_weight 29 | self.perceptual_loss = LPIPS().eval() 30 | self.perceptual_weight = perceptual_weight 31 | # output log variance 32 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 33 | 34 | self.discriminator = NLayerDiscriminator( 35 | input_nc=disc_in_channels, 36 | n_layers=disc_num_layers, 37 | use_actnorm=use_actnorm 38 | ).apply(weights_init) 39 | self.discriminator_iter_start = disc_start 40 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 41 | self.disc_factor = disc_factor 42 | self.discriminator_weight = disc_weight 43 | self.disc_conditional = disc_conditional 44 | 45 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 46 | if last_layer is not None: 47 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 48 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 49 | else: 50 | nll_grads = torch.autograd.grad( 51 | nll_loss, self.last_layer[0], retain_graph=True 52 | )[0] 53 | g_grads = torch.autograd.grad( 54 | g_loss, self.last_layer[0], retain_graph=True 55 | )[0] 56 | 57 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 58 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 59 | d_weight = d_weight * self.discriminator_weight 60 | return d_weight 61 | 62 | def forward( 63 | self, 64 | inputs, 65 | reconstructions, 66 | posteriors, 67 | optimizer_idx, 68 | global_step, 69 | last_layer=None, 70 | cond=None, 71 | split="train", 72 | weights=None 73 | ): 74 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 75 | if self.perceptual_weight > 0: 76 | p_loss = self.perceptual_loss( 77 | inputs.contiguous(), reconstructions.contiguous() 78 | ) 79 | rec_loss = rec_loss + self.perceptual_weight * p_loss 80 | 81 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 82 | weighted_nll_loss = nll_loss 83 | if weights is not None: 84 | weighted_nll_loss = weights * nll_loss 85 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 86 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 87 | kl_loss = posteriors.kl() 88 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 89 | 90 | # now the GAN part 91 | if optimizer_idx == 0: 92 | # generator update 93 | if cond is None: 94 | assert not self.disc_conditional 95 | logits_fake = self.discriminator(reconstructions.contiguous()) 96 | else: 97 | assert self.disc_conditional 98 | logits_fake = self.discriminator( 99 | torch.cat((reconstructions.contiguous(), cond), dim=1) 100 | ) 101 | g_loss = -torch.mean(logits_fake) 102 | 103 | if self.disc_factor > 0.0: 104 | try: 105 | d_weight = self.calculate_adaptive_weight( 106 | nll_loss, g_loss, last_layer=last_layer 107 | ) 108 | except RuntimeError: 109 | assert not self.training 110 | d_weight = torch.tensor(0.0) 111 | else: 112 | d_weight = torch.tensor(0.0) 113 | 114 | disc_factor = 
adopt_weight( 115 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 116 | ) 117 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 118 | 119 | log = { 120 | "{}/total_loss".format(split): loss.clone().detach().mean(), 121 | "{}/logvar".format(split): self.logvar.detach(), 122 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 123 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 124 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 125 | "{}/d_weight".format(split): d_weight.detach(), 126 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 127 | "{}/g_loss".format(split): g_loss.detach().mean(), 128 | } 129 | return loss, log 130 | 131 | if optimizer_idx == 1: 132 | # second pass for discriminator update 133 | if cond is None: 134 | logits_real = self.discriminator(inputs.contiguous().detach()) 135 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 136 | else: 137 | logits_real = self.discriminator( 138 | torch.cat((inputs.contiguous().detach(), cond), dim=1) 139 | ) 140 | logits_fake = self.discriminator( 141 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1) 142 | ) 143 | 144 | disc_factor = adopt_weight( 145 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 146 | ) 147 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 148 | 149 | log = { 150 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(), 151 | "{}/logits_real".format(split): logits_real.detach().mean(), 152 | "{}/logits_fake".format(split): logits_fake.detach().mean() 153 | } 154 | return d_loss, log 155 | -------------------------------------------------------------------------------- /model/stable_diffusion/losses/discriminator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | from .util import ActNorm 5 | 6 | 7 | def weights_init(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Conv') != -1: 10 | nn.init.normal_(m.weight.data, 0.0, 0.02) 11 | elif classname.find('BatchNorm') != -1: 12 | nn.init.normal_(m.weight.data, 1.0, 0.02) 13 | nn.init.constant_(m.bias.data, 0) 14 | 15 | 16 | class NLayerDiscriminator(nn.Module): 17 | """Defines a PatchGAN discriminator as in Pix2Pix 18 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 19 | """ 20 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 21 | """Construct a PatchGAN discriminator 22 | Parameters: 23 | input_nc (int) -- the number of channels in input images 24 | ndf (int) -- the number of filters in the last conv layer 25 | n_layers (int) -- the number of conv layers in the discriminator 26 | norm_layer -- normalization layer 27 | """ 28 | super(NLayerDiscriminator, self).__init__() 29 | if not use_actnorm: 30 | norm_layer = nn.BatchNorm2d 31 | else: 32 | norm_layer = ActNorm 33 | if type( 34 | norm_layer 35 | ) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 36 | use_bias = norm_layer.func != nn.BatchNorm2d 37 | else: 38 | use_bias = norm_layer != nn.BatchNorm2d 39 | 40 | kw = 4 41 | padw = 1 42 | sequence = [ 43 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 44 | nn.LeakyReLU(0.2, True) 45 | ] 46 | nf_mult = 1 47 | nf_mult_prev = 1 48 | for n in range(1, n_layers): # gradually increase the number of filters 49 | nf_mult_prev = nf_mult 50 | nf_mult = min(2**n, 8) 51 | sequence += [ 
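# intermediate PatchGAN blocks: each stride-2 conv halves the spatial resolution and widens the
# channels by nf_mult = min(2**n, 8), followed by the chosen norm layer and LeakyReLU(0.2)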
52 | nn.Conv2d( 53 | ndf * nf_mult_prev, 54 | ndf * nf_mult, 55 | kernel_size=kw, 56 | stride=2, 57 | padding=padw, 58 | bias=use_bias 59 | ), 60 | norm_layer(ndf * nf_mult), 61 | nn.LeakyReLU(0.2, True) 62 | ] 63 | 64 | nf_mult_prev = nf_mult 65 | nf_mult = min(2**n_layers, 8) 66 | sequence += [ 67 | nn.Conv2d( 68 | ndf * nf_mult_prev, 69 | ndf * nf_mult, 70 | kernel_size=kw, 71 | stride=1, 72 | padding=padw, 73 | bias=use_bias 74 | ), 75 | norm_layer(ndf * nf_mult), 76 | nn.LeakyReLU(0.2, True) 77 | ] 78 | 79 | sequence += [ 80 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 81 | ] # output 1 channel prediction map 82 | self.main = nn.Sequential(*sequence) 83 | 84 | def forward(self, input): 85 | """Standard forward.""" 86 | return self.main(input) 87 | -------------------------------------------------------------------------------- /model/stable_diffusion/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from .util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict( 30 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 31 | ) 32 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 33 | 34 | @classmethod 35 | def from_pretrained(cls, name="vgg_lpips"): 36 | if name != "vgg_lpips": 37 | raise NotImplementedError 38 | model = cls() 39 | ckpt = get_ckpt_path(name) 40 | model.load_state_dict( 41 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 42 | ) 43 | return model 44 | 45 | def forward(self, input, target): 46 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 47 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 48 | feats0, feats1, diffs = {}, {}, {} 49 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 50 | for kk in range(len(self.chns)): 51 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 52 | outs1[kk] 53 | ) 54 | diffs[kk] = (feats0[kk] - feats1[kk])**2 55 | 56 | res = [ 57 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 58 | for kk in range(len(self.chns)) 59 | ] 60 | val = res[0] 61 | for l in range(1, len(self.chns)): 62 | val += res[l] 63 | return val 64 | 65 | 66 | class ScalingLayer(nn.Module): 67 | def __init__(self): 68 | super(ScalingLayer, self).__init__() 69 | self.register_buffer( 70 | 'shift', 71 | torch.Tensor([-.030, -.088, -.188])[None, :, None, None] 72 | ) 73 | self.register_buffer( 74 | 'scale', 
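# fixed constants carried over from the original LPIPS ScalingLayer; presumably, together with
# 'shift' above, they map [-1, 1]-normalized inputs onto the channel statistics the pretrained
# VGG backbone expects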
75 | torch.Tensor([.458, .448, .450])[None, :, None, None] 76 | ) 77 | 78 | def forward(self, inp): 79 | return (inp - self.shift) / self.scale 80 | 81 | 82 | class NetLinLayer(nn.Module): 83 | """ A single linear layer which does a 1x1 conv """ 84 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 85 | super(NetLinLayer, self).__init__() 86 | layers = [ 87 | nn.Dropout(), 88 | ] if (use_dropout) else [] 89 | layers += [ 90 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 91 | ] 92 | self.model = nn.Sequential(*layers) 93 | 94 | 95 | class vgg16(torch.nn.Module): 96 | def __init__(self, requires_grad=False, pretrained=True): 97 | super(vgg16, self).__init__() 98 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 99 | self.slice1 = torch.nn.Sequential() 100 | self.slice2 = torch.nn.Sequential() 101 | self.slice3 = torch.nn.Sequential() 102 | self.slice4 = torch.nn.Sequential() 103 | self.slice5 = torch.nn.Sequential() 104 | self.N_slices = 5 105 | for x in range(4): 106 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 107 | for x in range(4, 9): 108 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(9, 16): 110 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(16, 23): 112 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(23, 30): 114 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 115 | if not requires_grad: 116 | for param in self.parameters(): 117 | param.requires_grad = False 118 | 119 | def forward(self, X): 120 | h = self.slice1(X) 121 | h_relu1_2 = h 122 | h = self.slice2(h) 123 | h_relu2_2 = h 124 | h = self.slice3(h) 125 | h_relu3_3 = h 126 | h = self.slice4(h) 127 | h_relu4_3 = h 128 | h = self.slice5(h) 129 | h_relu5_3 = h 130 | vgg_outputs = namedtuple( 131 | "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'] 132 | ) 133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 134 | return out 135 | 136 | 137 | def normalize_tensor(x, eps=1e-10): 138 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 139 | return x / (norm_factor + eps) 140 | 141 | 142 | def spatial_average(x, keepdim=True): 143 | return x.mean([2, 3], keepdim=keepdim) 144 | -------------------------------------------------------------------------------- /model/stable_diffusion/losses/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 6 | 7 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 8 | 9 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 10 | 11 | 12 | def download(url, local_path, chunk_size=1024): 13 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 14 | with requests.get(url, stream=True) as r: 15 | total_size = int(r.headers.get("content-length", 0)) 16 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 17 | with open(local_path, "wb") as f: 18 | for data in r.iter_content(chunk_size=chunk_size): 19 | if data: 20 | f.write(data) 21 | pbar.update(chunk_size) 22 | 23 | 24 | def md5_hash(path): 25 | with open(path, "rb") as f: 26 | content = f.read() 27 | return hashlib.md5(content).hexdigest() 28 | 29 | 30 | def get_ckpt_path(name, root="", check=False): 31 | assert name in URL_MAP 32 | path = os.path.join(root, CKPT_MAP[name]) 33 | if not os.path.exists(path) or 
(check and not md5_hash(path) == MD5_MAP[name]): 34 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 35 | download(URL_MAP[name], path) 36 | md5 = md5_hash(path) 37 | assert md5 == MD5_MAP[name], md5 38 | return path 39 | 40 | 41 | class KeyNotFoundError(Exception): 42 | def __init__(self, cause, keys=None, visited=None): 43 | self.cause = cause 44 | self.keys = keys 45 | self.visited = visited 46 | messages = list() 47 | if keys is not None: 48 | messages.append("Key not found: {}".format(keys)) 49 | if visited is not None: 50 | messages.append("Visited: {}".format(visited)) 51 | messages.append("Cause:\n{}".format(cause)) 52 | message = "\n".join(messages) 53 | super().__init__(message) 54 | 55 | 56 | def retrieve( 57 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 58 | ): 59 | """Given a nested list or dict return the desired value at key expanding 60 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 61 | is done in-place. 62 | 63 | Parameters 64 | ---------- 65 | list_or_dict : list or dict 66 | Possibly nested list or dictionary. 67 | key : str 68 | key/to/value, path like string describing all keys necessary to 69 | consider to get to the desired value. List indices can also be 70 | passed here. 71 | splitval : str 72 | String that defines the delimiter between keys of the 73 | different depth levels in `key`. 74 | default : obj 75 | Value returned if :attr:`key` is not found. 76 | expand : bool 77 | Whether to expand callable nodes on the path or not. 78 | 79 | Returns 80 | ------- 81 | The desired value or if :attr:`default` is not ``None`` and the 82 | :attr:`key` is not found returns ``default``. 83 | 84 | Raises 85 | ------ 86 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 87 | ``None``. 88 | """ 89 | 90 | keys = key.split(splitval) 91 | 92 | success = True 93 | try: 94 | visited = [] 95 | parent = None 96 | last_key = None 97 | for key in keys: 98 | if callable(list_or_dict): 99 | if not expand: 100 | raise KeyNotFoundError( 101 | ValueError( 102 | "Trying to get past callable node with expand=False." 
103 | ), 104 | keys=keys, 105 | visited=visited, 106 | ) 107 | list_or_dict = list_or_dict() 108 | parent[last_key] = list_or_dict 109 | 110 | last_key = key 111 | parent = list_or_dict 112 | 113 | try: 114 | if isinstance(list_or_dict, dict): 115 | list_or_dict = list_or_dict[key] 116 | else: 117 | list_or_dict = list_or_dict[int(key)] 118 | except (KeyError, IndexError, ValueError) as e: 119 | raise KeyNotFoundError(e, keys=keys, visited=visited) 120 | 121 | visited += [key] 122 | # final expansion of retrieved value 123 | if expand and callable(list_or_dict): 124 | list_or_dict = list_or_dict() 125 | parent[last_key] = list_or_dict 126 | except KeyNotFoundError as e: 127 | if default is None: 128 | raise e 129 | else: 130 | list_or_dict = default 131 | success = False 132 | 133 | if not pass_success: 134 | return list_or_dict 135 | else: 136 | return list_or_dict, success 137 | 138 | 139 | import torch 140 | import torch.nn as nn 141 | 142 | 143 | def count_params(model): 144 | total_params = sum(p.numel() for p in model.parameters()) 145 | return total_params 146 | 147 | 148 | class ActNorm(nn.Module): 149 | def __init__( 150 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 151 | ): 152 | assert affine 153 | super().__init__() 154 | self.logdet = logdet 155 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 156 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 157 | self.allow_reverse_init = allow_reverse_init 158 | 159 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 160 | 161 | def initialize(self, input): 162 | with torch.no_grad(): 163 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 164 | mean = ( 165 | flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute( 166 | 1, 0, 2, 3 167 | ) 168 | ) 169 | std = ( 170 | flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute( 171 | 1, 0, 2, 3 172 | ) 173 | ) 174 | 175 | self.loc.data.copy_(-mean) 176 | self.scale.data.copy_(1 / (std + 1e-6)) 177 | 178 | def forward(self, input, reverse=False): 179 | if reverse: 180 | return self.reverse(input) 181 | if len(input.shape) == 2: 182 | input = input[:, :, None, None] 183 | squeeze = True 184 | else: 185 | squeeze = False 186 | 187 | _, _, height, width = input.shape 188 | 189 | if self.training and self.initialized.item() == 0: 190 | self.initialize(input) 191 | self.initialized.fill_(1) 192 | 193 | h = self.scale * (input + self.loc) 194 | 195 | if squeeze: 196 | h = h.squeeze(-1).squeeze(-1) 197 | 198 | if self.logdet: 199 | log_abs = torch.log(torch.abs(self.scale)) 200 | logdet = height * width * torch.sum(log_abs) 201 | logdet = logdet * torch.ones(input.shape[0]).to(input) 202 | return h, logdet 203 | 204 | return h 205 | 206 | def reverse(self, output): 207 | if self.training and self.initialized.item() == 0: 208 | if not self.allow_reverse_init: 209 | raise RuntimeError( 210 | "Initializing ActNorm in reverse direction is " 211 | "disabled by default. Use allow_reverse_init=True to enable." 
212 | ) 213 | else: 214 | self.initialize(output) 215 | self.initialized.fill_(1) 216 | 217 | if len(output.shape) == 2: 218 | output = output[:, :, None, None] 219 | squeeze = True 220 | else: 221 | squeeze = False 222 | 223 | h = output / self.scale - self.loc 224 | 225 | if squeeze: 226 | h = h.squeeze(-1).squeeze(-1) 227 | return h 228 | 229 | 230 | class AbstractEncoder(nn.Module): 231 | def __init__(self): 232 | super().__init__() 233 | 234 | def encode(self, *args, **kwargs): 235 | raise NotImplementedError 236 | 237 | 238 | class Labelator(AbstractEncoder): 239 | """Net2Net Interface for Class-Conditional Model""" 240 | def __init__(self, n_classes, quantize_interface=True): 241 | super().__init__() 242 | self.n_classes = n_classes 243 | self.quantize_interface = quantize_interface 244 | 245 | def encode(self, c): 246 | c = c[:, None] 247 | if self.quantize_interface: 248 | return c, None, [None, None, c.long()] 249 | return c 250 | 251 | 252 | class SOSProvider(AbstractEncoder): 253 | # for unconditional training 254 | def __init__(self, sos_token, quantize_interface=True): 255 | super().__init__() 256 | self.sos_token = sos_token 257 | self.quantize_interface = quantize_interface 258 | 259 | def encode(self, x): 260 | # get batch size from data and replicate sos_token 261 | c = torch.ones(x.shape[0], 1) * self.sos_token 262 | c = c.long().to(x.device) 263 | if self.quantize_interface: 264 | return c, None, [None, None, c] 265 | return c 266 | 267 | 268 | if __name__ == "__main__": 269 | config = { 270 | "keya": "a", 271 | "keyb": "b", 272 | "keyc": { 273 | "cc1": 1, 274 | "cc2": 2, 275 | } 276 | } 277 | from omegaconf import OmegaConf 278 | config = OmegaConf.create(config) 279 | print(config) 280 | retrieve(config, "keya") 281 | -------------------------------------------------------------------------------- /model/stable_diffusion/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .lpips import LPIPS 6 | from .discriminator import NLayerDiscriminator, weights_init 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake)) 31 | ) 32 | return d_loss 33 | 34 | 35 | class VQLPIPSWithDiscriminator(nn.Module): 36 | def __init__( 37 | self, 38 | disc_start, 39 | codebook_weight=1.0, 40 | pixelloss_weight=1.0, 41 | disc_num_layers=3, 42 | disc_in_channels=3, 43 | disc_factor=1.0, 44 | disc_weight=1.0, 45 | perceptual_weight=1.0, 46 | use_actnorm=False, 47 | disc_conditional=False, 48 | disc_ndf=64, 49 | disc_loss="hinge" 50 | ): 51 | super().__init__() 52 | assert disc_loss in ["hinge", "vanilla"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | self.perceptual_loss = LPIPS().eval() 56 | self.perceptual_weight = perceptual_weight 57 | 58 | self.discriminator = NLayerDiscriminator( 59 | input_nc=disc_in_channels, 60 | n_layers=disc_num_layers, 61 | use_actnorm=use_actnorm, 62 | ndf=disc_ndf 63 | ).apply(weights_init) 64 | self.discriminator_iter_start = disc_start 65 | if disc_loss == "hinge": 66 | self.disc_loss = hinge_d_loss 67 | elif disc_loss == "vanilla": 68 | self.disc_loss = vanilla_d_loss 69 | else: 70 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 71 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 72 | self.disc_factor = disc_factor 73 | self.discriminator_weight = disc_weight 74 | self.disc_conditional = disc_conditional 75 | 76 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 77 | if last_layer is not None: 78 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 79 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 80 | else: 81 | nll_grads = torch.autograd.grad( 82 | nll_loss, self.last_layer[0], retain_graph=True 83 | )[0] 84 | g_grads = torch.autograd.grad( 85 | g_loss, self.last_layer[0], retain_graph=True 86 | )[0] 87 | 88 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 89 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 90 | d_weight = d_weight * self.discriminator_weight 91 | return d_weight 92 | 93 | def forward( 94 | self, 95 | codebook_loss, 96 | inputs, 97 | reconstructions, 98 | optimizer_idx, 99 | global_step, 100 | last_layer=None, 101 | cond=None, 102 | split="train" 103 | ): 104 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 105 | if self.perceptual_weight > 0: 106 | p_loss = self.perceptual_loss( 107 | inputs.contiguous(), reconstructions.contiguous() 108 | ) 109 | rec_loss = rec_loss + self.perceptual_weight * p_loss 110 | else: 111 | p_loss = torch.tensor([0.0]) 112 | 113 | nll_loss = rec_loss 114 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 115 | nll_loss = torch.mean(nll_loss) 116 | 117 | # now the GAN part 118 | if optimizer_idx == 0: 119 | # generator update 120 | if cond is None: 121 | assert not self.disc_conditional 122 | logits_fake = self.discriminator(reconstructions.contiguous()) 123 | else: 124 | assert self.disc_conditional 125 | logits_fake = self.discriminator( 126 | torch.cat((reconstructions.contiguous(), cond), dim=1) 127 | ) 128 | g_loss = -torch.mean(logits_fake) 129 | 130 | try: 131 | d_weight = self.calculate_adaptive_weight( 132 | nll_loss, g_loss, last_layer=last_layer 133 | ) 134 | except RuntimeError: 135 | assert not self.training 136 | d_weight = torch.tensor(0.0) 137 | 138 | 
disc_factor = adopt_weight( 139 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 140 | ) 141 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean( 142 | ) 143 | 144 | log = { 145 | "{}/total_loss".format(split): loss.clone().detach().mean(), 146 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 147 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 148 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 149 | "{}/p_loss".format(split): p_loss.detach().mean(), 150 | "{}/d_weight".format(split): d_weight.detach(), 151 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 152 | "{}/g_loss".format(split): g_loss.detach().mean(), 153 | } 154 | return loss, log 155 | 156 | if optimizer_idx == 1: 157 | # second pass for discriminator update 158 | if cond is None: 159 | logits_real = self.discriminator(inputs.contiguous().detach()) 160 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 161 | else: 162 | logits_real = self.discriminator( 163 | torch.cat((inputs.contiguous().detach(), cond), dim=1) 164 | ) 165 | logits_fake = self.discriminator( 166 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1) 167 | ) 168 | 169 | disc_factor = adopt_weight( 170 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 171 | ) 172 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 173 | 174 | log = { 175 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(), 176 | "{}/logits_real".format(split): logits_real.detach().mean(), 177 | "{}/logits_fake".format(split): logits_fake.detach().mean() 178 | } 179 | return d_loss, log 180 | -------------------------------------------------------------------------------- /model/stable_diffusion/model/autoreg_cond_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CtpAutoregEncoder(nn.Module): 6 | 7 | def __init__(self, img_h=108, img_w=128): 8 | super().__init__() 9 | input_channels = 10 10 | mid_channels = 20 11 | output_channels = 2 12 | self.layers = nn.Sequential( 13 | nn.Conv2d(input_channels, mid_channels, 3, padding=1), 14 | nn.SiLU(), 15 | nn.LayerNorm([mid_channels, img_h, img_w]), 16 | nn.Conv2d(mid_channels, mid_channels, 3, padding=1), 17 | nn.SiLU(), 18 | nn.LayerNorm([mid_channels, img_h, img_w]), 19 | nn.Conv2d(mid_channels, output_channels, 3, padding=1), 20 | ) 21 | self.img_h = img_h 22 | self.img_w = img_w 23 | self.squeeze = nn.Linear(1024, 128) 24 | self.output_channels = output_channels 25 | self.pos_enc = nn.Embedding(27, 128) 26 | 27 | def forward(self, x): 28 | bs = x.size(0) 29 | x = self.layers(x) 30 | x = x.reshape(bs, self.output_channels, 27, 4 * 128).permute(0, 2, 1, 3).reshape(bs, 27, -1) 31 | x = self.squeeze(x) 32 | 33 | pos = torch.arange(0, 27).to(x.device).long() 34 | pos_emb = self.pos_enc(pos).unsqueeze(0) 35 | x += pos_emb 36 | return x 37 | 38 | 39 | class LshAutoregEncoder(nn.Module): 40 | 41 | def __init__(self, img_h=136, img_w=128): 42 | super().__init__() 43 | input_channels = 12 44 | mid_channels = 20 45 | output_channels = 2 46 | self.layers = nn.Sequential( 47 | nn.Conv2d(input_channels, mid_channels, 3, padding=1), 48 | nn.SiLU(), 49 | nn.LayerNorm([mid_channels, img_h, img_w]), 50 | nn.Conv2d(mid_channels, mid_channels, 3, padding=1), 51 | nn.SiLU(), 52 | nn.LayerNorm([mid_channels, img_h, img_w]), 53 | nn.Conv2d(mid_channels, output_channels, 3, 
padding=1), 54 | ) 55 | self.img_h = img_h 56 | self.img_w = img_w 57 | self.squeeze = nn.Linear(2048, 256) 58 | self.output_channels = output_channels 59 | self.pos_enc = nn.Embedding(17, 256) 60 | 61 | def forward(self, x): 62 | bs = x.size(0) 63 | x = self.layers(x) 64 | x = x.reshape(bs, self.output_channels, 17, 8 * 128).permute(0, 2, 1, 3).reshape(bs, 17, -1) 65 | x = self.squeeze(x) 66 | 67 | pos = torch.arange(0, 17).to(x.device).long() 68 | pos_emb = self.pos_enc(pos).unsqueeze(0) 69 | x += pos_emb 70 | return x 71 | 72 | 73 | class AccAutoregEncoder(nn.Module): 74 | 75 | def __init__(self, img_h=136, img_w=128): 76 | super().__init__() 77 | input_channels = 14 78 | mid_channels = 20 79 | output_channels = 2 80 | self.layers = nn.Sequential( 81 | nn.Conv2d(input_channels, mid_channels, 3, padding=1), 82 | nn.SiLU(), 83 | nn.LayerNorm([mid_channels, img_h, img_w]), 84 | nn.Conv2d(mid_channels, mid_channels, 3, padding=1), 85 | nn.SiLU(), 86 | nn.LayerNorm([mid_channels, img_h, img_w]), 87 | nn.Conv2d(mid_channels, output_channels, 3, padding=1), 88 | ) 89 | self.img_h = img_h 90 | self.img_w = img_w 91 | self.squeeze = nn.Linear(2048, 256) 92 | self.output_channels = output_channels 93 | self.pos_enc = nn.Embedding(17, 256) 94 | 95 | def forward(self, x): 96 | bs = x.size(0) 97 | x = self.layers(x)[:, :, 0: 200] 98 | x = x.reshape(bs, self.output_channels, 17, 8 * 128).permute(0, 2, 1, 3).reshape(bs, 17, -1) 99 | x = self.squeeze(x) 100 | 101 | pos = torch.arange(0, 17).to(x.device).long() 102 | pos_emb = self.pos_enc(pos).unsqueeze(0) 103 | x += pos_emb 104 | return x 105 | -------------------------------------------------------------------------------- /model/stable_diffusion/model/external_cond_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .pretrained_encoders import load_chd_enc8, load_ec2vae_enc2, load_txt_enc2 4 | import os 5 | 6 | 7 | PROJECT_PATH = os.path.realpath(os.path.join(__file__, '../../../../')) 8 | 9 | 10 | class CtpExternalEncoder(nn.Module): 11 | 12 | def __init__(self, chd_enc8): 13 | super().__init__() 14 | self.enc = chd_enc8 15 | self.squeezer = nn.Linear(512, 128) 16 | self.pos_enc = nn.Embedding(4, 128) 17 | 18 | def forward(self, chd): 19 | pos = torch.arange(0, 4).to(chd.device).long() 20 | pos_emb = self.pos_enc(pos).unsqueeze(0) 21 | 22 | bs = chd.size(0) 23 | chd = chd.reshape(-1, 32, 36) 24 | 25 | self.enc.eval() 26 | with torch.no_grad(): 27 | z = self.enc(chd).reshape(bs, 4, 512) 28 | z = self.squeezer(z) 29 | z += pos_emb 30 | return z 31 | 32 | @classmethod 33 | def create_model(cls): 34 | chd_enc8 = load_chd_enc8() 35 | return cls(chd_enc8) 36 | 37 | 38 | class LshExternalEncoder(nn.Module): 39 | 40 | def __init__(self, rhy_enc2): 41 | super().__init__() 42 | self.enc = rhy_enc2 43 | self.squeezer = nn.Linear(128, 256) 44 | self.pos_enc = nn.Embedding(4, 256) 45 | 46 | def forward(self, mel_pr): 47 | pos = torch.arange(0, 4).to(mel_pr.device).long() 48 | pos_emb = self.pos_enc(pos).unsqueeze(0) 49 | 50 | bs = mel_pr.size(0) 51 | mel = mel_pr[:, :, 0: 130].reshape(-1, 32, 130) 52 | chd = mel_pr[:, :, 130:].reshape(-1, 32, 12) 53 | 54 | self.enc.eval() 55 | with torch.no_grad(): 56 | _, z = self.enc(mel, chd) 57 | z = z.mean.reshape(bs, 4, 128) 58 | z = self.squeezer(z) + pos_emb 59 | return z 60 | 61 | @classmethod 62 | def create_model(cls): 63 | enc = load_ec2vae_enc2() 64 | return cls(enc) 65 | 66 | 67 | class AccExternalEncoder(nn.Module): 68 | 
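# Wraps the frozen Polydis texture encoder (loaded via load_txt_enc2): the input piano-roll is
# reshaped into four 32-step (roughly 2-bar) chunks, each encoded to the 256-d mean of the texture
# posterior, and a learned positional embedding over the 4 positions is added; the result serves as
# the external (texture) condition for the accompaniment-level diffusion model.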
69 | def __init__(self, txt_enc2): 70 | super().__init__() 71 | self.enc = txt_enc2 72 | self.pos_enc = nn.Embedding(4, 256) 73 | 74 | def forward(self, pr_mat): 75 | pos = torch.arange(0, 4).to(pr_mat.device).long() 76 | pos_emb = self.pos_enc(pos).unsqueeze(0) 77 | 78 | bs = pr_mat.size(0) 79 | pr_mat = pr_mat.reshape(-1, 32, 128) 80 | self.enc.eval() 81 | with torch.no_grad(): 82 | z = self.enc(pr_mat) 83 | z = z.mean.reshape(bs, 4, 256) 84 | z = z + pos_emb 85 | return z 86 | 87 | @classmethod 88 | def create_model(cls): 89 | txt_enc2 = load_txt_enc2() 90 | return cls(txt_enc2) 91 | -------------------------------------------------------------------------------- /model/stable_diffusion/model/pretrained_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.distributions import Normal 4 | import os 5 | 6 | 7 | PROJECT_PATH = os.path.realpath(os.path.join(__file__, '../../../../')) 8 | 9 | 10 | class Ec2VaeEncoder(nn.Module): 11 | 12 | roll_dims = 130 13 | rhythm_dims = 3 14 | condition_dims = 12 15 | 16 | def __init__(self, hidden_dims, z1_dims, z2_dims, n_step): 17 | super(Ec2VaeEncoder, self).__init__() 18 | 19 | assert n_step in [16, 32, 64, 128] 20 | 21 | self.gru_0 = nn.GRU(self.roll_dims + self.condition_dims, 22 | hidden_dims, batch_first=True, bidirectional=True) 23 | self.linear_mu = nn.Linear(hidden_dims * 2, z1_dims + z2_dims) 24 | self.linear_var = nn.Linear(hidden_dims * 2, z1_dims + z2_dims) 25 | 26 | self.hidden_dims = hidden_dims 27 | self.z1_dims = z1_dims 28 | self.z2_dims = z2_dims 29 | self.n_step = n_step 30 | 31 | def encoder(self, x, condition): 32 | # self.gru_0.flatten_parameters() 33 | x = torch.cat((x, condition), -1) 34 | x = self.gru_0(x)[-1] 35 | x = x.transpose_(0, 1).contiguous() 36 | x = x.view(x.size(0), -1) 37 | mu = self.linear_mu(x) 38 | var = self.linear_var(x).exp_() 39 | distribution_1 = Normal(mu[:, :self.z1_dims], var[:, :self.z1_dims]) 40 | distribution_2 = Normal(mu[:, self.z1_dims:], var[:, self.z1_dims:]) 41 | return distribution_1, distribution_2 42 | 43 | def forward(self, x, condition): 44 | return self.encoder(x, condition) 45 | 46 | @classmethod 47 | def create_2bar_encoder(cls, hidden_dims=2048, zp_dims=128, zr_dims=128): 48 | return cls(hidden_dims, zp_dims, zr_dims, 32) 49 | 50 | @classmethod 51 | def create_4bar_encoder(cls, hidden_dims=2048, zp_dims=128, zr_dims=128): 52 | return cls(hidden_dims, zp_dims, zr_dims, 64) 53 | 54 | @classmethod 55 | def create_8bar_encoder(cls, hidden_dims=2048, zp_dims=128, zr_dims=128): 56 | return cls(hidden_dims, zp_dims, zr_dims, 128) 57 | 58 | 59 | class PolydisChordEncoder(nn.Module): 60 | 61 | def __init__(self, input_dim, hidden_dim, z_dim): 62 | super(PolydisChordEncoder, self).__init__() 63 | self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True, 64 | bidirectional=True) 65 | self.linear_mu = nn.Linear(hidden_dim * 2, z_dim) 66 | self.linear_var = nn.Linear(hidden_dim * 2, z_dim) 67 | self.input_dim = input_dim 68 | self.hidden_dim = hidden_dim 69 | self.z_dim = z_dim 70 | 71 | def forward(self, x): 72 | x = self.gru(x)[-1] 73 | x = x.transpose_(0, 1).contiguous() 74 | x = x.view(x.size(0), -1) 75 | mu = self.linear_mu(x) 76 | # var = self.linear_var(x).exp_() 77 | # dist = Normal(mu, var) 78 | return mu 79 | 80 | @classmethod 81 | def create_encoder(cls, hidden_dim=1024, z_dim=256): 82 | return cls(36, hidden_dim, z_dim) 83 | 84 | 85 | class PolydisTextureEncoder(nn.Module): 86 | 87 | def 
__init__(self, emb_size, hidden_dim, z_dim, num_channel=10): 88 | """input must be piano_mat: (B, 32, 128)""" 89 | super(PolydisTextureEncoder, self).__init__() 90 | self.cnn = nn.Sequential(nn.Conv2d(1, num_channel, kernel_size=(4, 12), 91 | stride=(4, 1), padding=0), 92 | nn.ReLU(), 93 | nn.MaxPool2d(kernel_size=(1, 4), 94 | stride=(1, 4))) 95 | self.fc1 = nn.Linear(num_channel * 29, 1000) 96 | self.fc2 = nn.Linear(1000, emb_size) 97 | self.gru = nn.GRU(emb_size, hidden_dim, batch_first=True, 98 | bidirectional=True) 99 | self.linear_mu = nn.Linear(hidden_dim * 2, z_dim) 100 | self.linear_var = nn.Linear(hidden_dim * 2, z_dim) 101 | self.emb_size = emb_size 102 | self.hidden_dim = hidden_dim 103 | self.z_dim = z_dim 104 | 105 | def forward(self, pr): 106 | # pr: (bs, 32, 128) 107 | bs = pr.size(0) 108 | pr = pr.unsqueeze(1) 109 | pr = self.cnn(pr).view(bs, 8, -1) 110 | pr = self.fc2(self.fc1(pr)) # (bs, 8, emb_size) 111 | pr = self.gru(pr)[-1] 112 | pr = pr.transpose_(0, 1).contiguous() 113 | pr = pr.view(pr.size(0), -1) 114 | mu = self.linear_mu(pr) 115 | var = self.linear_var(pr).exp_() 116 | dist = Normal(mu, var) 117 | return dist 118 | 119 | @classmethod 120 | def create_encoder(cls, emb_size=256, hidden_dim=1024, num_channel=10, z_dim=256): 121 | return cls(emb_size, hidden_dim, z_dim, num_channel) 122 | 123 | 124 | def load_ec2vae_enc8(): 125 | ec2vae_enc8 = Ec2VaeEncoder.create_8bar_encoder() 126 | model_path = os.path.join(PROJECT_PATH, 'pretrained_models', 'ec2vae_enc_8bar.pt') 127 | state_dict = torch.load(model_path) 128 | ec2vae_enc8.load_state_dict(state_dict) 129 | for param in ec2vae_enc8.parameters(): 130 | param.requires_grad = False 131 | return ec2vae_enc8 132 | 133 | 134 | def load_ec2vae_enc2(): 135 | ec2vae_enc2 = Ec2VaeEncoder.create_2bar_encoder() 136 | model_path = os.path.join(PROJECT_PATH, 'pretrained_models', 'ec2vae_enc_2bar.pt') 137 | state_dict = torch.load(model_path) 138 | ec2vae_enc2.load_state_dict(state_dict) 139 | for param in ec2vae_enc2.parameters(): 140 | param.requires_grad = False 141 | return ec2vae_enc2 142 | 143 | 144 | def load_chd_enc8(): 145 | chd_enc = PolydisChordEncoder.create_encoder(hidden_dim=512, z_dim=512) 146 | model_path = os.path.join(PROJECT_PATH, 'pretrained_models', 'chd_enc8.pt') 147 | state_dict = torch.load(model_path) 148 | chd_enc.load_state_dict(state_dict) 149 | for param in chd_enc.parameters(): 150 | param.requires_grad = False 151 | return chd_enc 152 | 153 | 154 | def load_ec2vae_enc4(): 155 | ec2vae_enc4 = Ec2VaeEncoder.create_4bar_encoder() 156 | model_path = os.path.join(PROJECT_PATH, 'pretrained_models', 'ec2vae_enc_4bar.pt') 157 | state_dict = torch.load(model_path) 158 | ec2vae_enc4.load_state_dict(state_dict) 159 | for param in ec2vae_enc4.parameters(): 160 | param.requires_grad = False 161 | return ec2vae_enc4 162 | 163 | 164 | def load_chd_enc2(): 165 | chd_enc = PolydisChordEncoder.create_encoder() 166 | model_path = os.path.join(PROJECT_PATH, 'pretrained_models', 'polydis_chd_enc.pt') 167 | state_dict = torch.load(model_path) 168 | chd_enc.load_state_dict(state_dict) 169 | for param in chd_enc.parameters(): 170 | param.requires_grad = False 171 | return chd_enc 172 | 173 | 174 | def load_txt_enc2(): 175 | txt_enc = PolydisTextureEncoder.create_encoder() 176 | model_path = os.path.join(PROJECT_PATH, 'pretrained_models', 'polydis_txt_enc.pt') 177 | state_dict = torch.load(model_path) 178 | txt_enc.load_state_dict(state_dict) 179 | for param in txt_enc.parameters(): 180 | param.requires_grad = False 181 | 
return txt_enc 182 | -------------------------------------------------------------------------------- /model/stable_diffusion/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | --- 3 | title: Sampling algorithms for stable diffusion 4 | summary: > 5 | Annotated PyTorch implementation/tutorial of 6 | sampling algorithms 7 | for stable diffusion model. 8 | --- 9 | 10 | # Sampling algorithms for [stable diffusion](../index.html) 11 | 12 | We have implemented the following [sampling algorithms](sampler/index.html): 13 | 14 | * [Denoising Diffusion Probabilistic Models (DDPM) Sampling](ddpm.html) 15 | * [Denoising Diffusion Implicit Models (DDIM) Sampling](ddim.html) 16 | """ 17 | 18 | from typing import Optional, List 19 | 20 | import torch 21 | 22 | from ..latent_diffusion import LatentDiffusion 23 | 24 | 25 | class DiffusionSampler: 26 | """ 27 | ## Base class for sampling algorithms 28 | """ 29 | model: LatentDiffusion 30 | 31 | def __init__(self, model: LatentDiffusion): 32 | """ 33 | :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$ 34 | """ 35 | super().__init__() 36 | # Set the model $\epsilon_\text{cond}(x_t, c)$ 37 | self.model = model 38 | # Get number of steps the model was trained with $T$ 39 | self.n_steps = model.n_steps 40 | 41 | # def get_eps( 42 | # self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *, uncond_scale: float, 43 | # uncond_cond: Optional[torch.Tensor] 44 | # ): 45 | # """ 46 | # ## Get $\epsilon(x_t, c)$ 47 | # 48 | # :param x: is $x_t$ of shape `[batch_size, channels, height, width]` 49 | # :param t: is $t$ of shape `[batch_size]` 50 | # :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]` 51 | # :param uncond_scale: is the unconditional guidance scale $s$. 
This is used for 52 | # $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$ 53 | # :param uncond_cond: is the conditional embedding for empty prompt $c_u$ 54 | # """ 55 | # # When the scale $s = 1$ 56 | # # $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$ 57 | # if uncond_cond is None or uncond_scale == 1.: 58 | # if self.model.cond_enc is not None and ((not hasattr(self.model, 'style_enc')) or self.model.style_enc is None): 59 | # c = self.model.cond_enc(c) 60 | # return self.model(x, t, c) 61 | # elif c is None: 62 | # return self.model(x, t, uncond_cond) 63 | # elif not isinstance(c, tuple): 64 | # c = self.model.cond_enc(c) 65 | # return self.model(x, t, c) 66 | # elif self.model.cond_enc is None: 67 | # return self.model(x, t, uncond_cond) 68 | # elif self.model.style_enc is not None: 69 | # # print('Using style encoder!') 70 | # c1, c2 = c 71 | # c1 = self.model.cond_enc(c1) 72 | # c2 = self.model.style_enc(c2) 73 | # c = torch.cat([c1, c2], 1) 74 | # 75 | # r1 = self.model(x, t, c) 76 | # c2[:] = -1 77 | # c = torch.cat([c1, c2], 1) 78 | # r2 = self.model(x, t, c) 79 | # return 0.1 * r2 + 0.9 * r1 80 | # # return self.model(x, t, c) 81 | # elif uncond_scale == 0.: # unconditional 82 | # return self.model(x, t, uncond_cond) 83 | # 84 | # # Duplicate $x_t$ and $t$ 85 | # x_in = torch.cat([x] * 2) 86 | # t_in = torch.cat([t] * 2) 87 | # # Concatenated $c$ and $c_u$ 88 | # c_in = torch.cat([uncond_cond, c]) 89 | # # Get $\epsilon_\text{cond}(x_t, c)$ and $\epsilon_\text{cond}(x_t, c_u)$ 90 | # e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2) 91 | # # Calculate 92 | # # $$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$$ 93 | # e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond) 94 | # 95 | # # 96 | # return e_t 97 | 98 | def sample( 99 | self, 100 | shape: List[int], 101 | cond: torch.Tensor, 102 | repeat_noise: bool = False, 103 | temperature: float = 1., 104 | x_last: Optional[torch.Tensor] = None, 105 | uncond_scale: float = 1., 106 | uncond_cond: Optional[torch.Tensor] = None, 107 | skip_steps: int = 0, 108 | ): 109 | """ 110 | ### Sampling Loop 111 | 112 | :param shape: is the shape of the generated images in the 113 | form `[batch_size, channels, height, width]` 114 | :param cond: is the conditional embeddings $c$ 115 | :param temperature: is the noise temperature (random noise gets multiplied by this) 116 | :param x_last: is $x_T$. If not provided random noise will be used. 117 | :param uncond_scale: is the unconditional guidance scale $s$. This is used for 118 | $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$ 119 | :param uncond_cond: is the conditional embedding for empty prompt $c_u$ 120 | :param skip_steps: is the number of time steps to skip. 
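A standalone sketch of the classifier-free guidance combination in the commented-out `get_eps` above: both noise estimates come from one batched forward pass and are blended as `e_uncond + s * (e_cond - e_uncond)`, i.e. `s * e_cond + (1 - s) * e_uncond`. The toy `model_fn` below is a stand-in, not the repository's model:

```python
import torch

def guided_eps(model_fn, x, t, c, c_uncond, s: float):
    """model_fn(x, t, c) -> predicted noise; c / c_uncond are conditioning tensors."""
    x_in = torch.cat([x, x])          # duplicate x_t
    t_in = torch.cat([t, t])          # duplicate t
    c_in = torch.cat([c_uncond, c])   # unconditional first, conditional second
    e_uncond, e_cond = model_fn(x_in, t_in, c_in).chunk(2)
    return e_uncond + s * (e_cond - e_uncond)

# Toy check: a "model" that just broadcasts its conditioning to x's shape.
x = torch.zeros(2, 1, 4, 4)
t = torch.zeros(2, dtype=torch.long)
c, c_u = torch.ones(2, 1, 1, 1), torch.zeros(2, 1, 1, 1)
eps = guided_eps(lambda x_, t_, c_: c_.expand_as(x_), x, t, c, c_u, s=3.0)
assert torch.allclose(eps, torch.full_like(x, 3.0))  # 0 + 3 * (1 - 0) = 3
```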
121 | """ 122 | raise NotImplementedError() 123 | 124 | def paint( 125 | self, 126 | x: torch.Tensor, 127 | cond: torch.Tensor, 128 | t_start: int, 129 | *, 130 | orig: Optional[torch.Tensor] = None, 131 | mask: Optional[torch.Tensor] = None, 132 | orig_noise: Optional[torch.Tensor] = None, 133 | uncond_scale: float = 1., 134 | uncond_cond: Optional[torch.Tensor] = None, 135 | ): 136 | """ 137 | ### Painting Loop 138 | 139 | :param x: is $x_{T'}$ of shape `[batch_size, channels, height, width]` 140 | :param cond: is the conditional embeddings $c$ 141 | :param t_start: is the sampling step to start from, $T'$ 142 | :param orig: is the original image in latent page which we are in paining. 143 | :param mask: is the mask to keep the original image. 144 | :param orig_noise: is fixed noise to be added to the original image. 145 | :param uncond_scale: is the unconditional guidance scale $s$. This is used for 146 | $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$ 147 | :param uncond_cond: is the conditional embedding for empty prompt $c_u$ 148 | """ 149 | raise NotImplementedError() 150 | 151 | def q_sample( 152 | self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None 153 | ): 154 | """ 155 | ### Sample from $q(x_t|x_0)$ 156 | 157 | :param x0: is $x_0$ of shape `[batch_size, channels, height, width]` 158 | :param index: is the time step $t$ index 159 | :param noise: is the noise, $\epsilon$ 160 | """ 161 | raise NotImplementedError() 162 | -------------------------------------------------------------------------------- /model/stable_diffusion/sampler/ddpm.py: -------------------------------------------------------------------------------- 1 | """ 2 | --- 3 | title: Denoising Diffusion Probabilistic Models (DDPM) Sampling 4 | summary: > 5 | Annotated PyTorch implementation/tutorial of 6 | Denoising Diffusion Probabilistic Models (DDPM) Sampling 7 | for stable diffusion model. 8 | --- 9 | 10 | # Denoising Diffusion Probabilistic Models (DDPM) Sampling 11 | 12 | For a simpler DDPM implementation refer to our [DDPM implementation](../../ddpm/index.html). 13 | We use same notations for $\alpha_t$, $\beta_t$ schedules, etc. 14 | """ 15 | 16 | from typing import Optional, List 17 | 18 | import numpy as np 19 | import torch 20 | 21 | from labml import monit 22 | from ..latent_diffusion import LatentDiffusion 23 | from . import DiffusionSampler 24 | 25 | 26 | class DDPMSampler(DiffusionSampler): 27 | """ 28 | ## DDPM Sampler 29 | 30 | This extends the [`DiffusionSampler` base class](index.html). 
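`q_sample` is left abstract in this base class; a minimal standalone version of what a subclass is expected to compute (see `DDPMSampler.q_sample` further below), namely a draw from $q(x_t|x_0) = \mathcal{N}\big(\sqrt{\bar\alpha_t}\, x_0, (1-\bar\alpha_t)\mathbf{I}\big)$:

```python
import torch

def q_sample(x0: torch.Tensor, alpha_bar_t: torch.Tensor, noise: torch.Tensor = None):
    """Forward-noise x0 to time t, given the cumulative product alpha_bar_t."""
    noise = torch.randn_like(x0) if noise is None else noise
    return alpha_bar_t.sqrt() * x0 + (1. - alpha_bar_t).sqrt() * noise

x0 = torch.randn(1, 2, 8, 8)
print(q_sample(x0, torch.tensor(0.5)).shape)  # torch.Size([1, 2, 8, 8])
```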
31 | 32 | DDPM samples images by repeatedly removing noise by sampling step by step from 33 | $p_\theta(x_{t-1} | x_t)$, 34 | 35 | \begin{align} 36 | 37 | p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big) \\ 38 | 39 | \mu_t(x_t, t) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0 40 | + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\ 41 | 42 | \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t \\ 43 | 44 | x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta \\ 45 | 46 | \end{align} 47 | """ 48 | 49 | model: LatentDiffusion 50 | 51 | def __init__(self, model: LatentDiffusion): 52 | """ 53 | :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$ 54 | """ 55 | super().__init__(model) 56 | 57 | # Sampling steps $1, 2, \dots, T$ 58 | self.time_steps = np.asarray(list(range(self.n_steps))) 59 | 60 | with torch.no_grad(): 61 | # $\bar\alpha_t$ 62 | alpha_bar = self.model.alpha_bar 63 | # $\beta_t$ schedule 64 | beta = self.model.beta 65 | # $\bar\alpha_{t-1}$ 66 | alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]]) 67 | 68 | # $\sqrt{\bar\alpha}$ 69 | self.sqrt_alpha_bar = alpha_bar**.5 70 | # $\sqrt{1 - \bar\alpha}$ 71 | self.sqrt_1m_alpha_bar = (1. - alpha_bar)**.5 72 | # $\frac{1}{\sqrt{\bar\alpha_t}}$ 73 | self.sqrt_recip_alpha_bar = alpha_bar**-.5 74 | # $\sqrt{\frac{1}{\bar\alpha_t} - 1}$ 75 | self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1)**.5 76 | 77 | # $\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$ 78 | variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar) 79 | # Clamped log of $\tilde\beta_t$ 80 | self.log_var = torch.log(torch.clamp(variance, min=1e-20)) 81 | # $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$ 82 | self.mean_x0_coef = beta * (alpha_bar_prev**.5) / (1. - alpha_bar) 83 | # $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$ 84 | self.mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta)** 85 | 0.5) / (1. - alpha_bar) 86 | 87 | @torch.no_grad() 88 | def sample( 89 | self, 90 | shape: List[int], 91 | cond: torch.Tensor, 92 | repeat_noise: bool = False, 93 | temperature: float = 1., 94 | x_last: Optional[torch.Tensor] = None, 95 | uncond_scale: float = 1., 96 | uncond_cond: Optional[torch.Tensor] = None, 97 | skip_steps: int = 0, 98 | ): 99 | """ 100 | ### Sampling Loop 101 | 102 | :param shape: is the shape of the generated images in the 103 | form `[batch_size, channels, height, width]` 104 | :param cond: is the conditional embeddings $c$ 105 | :param temperature: is the noise temperature (random noise gets multiplied by this) 106 | :param x_last: is $x_T$. If not provided random noise will be used. 107 | :param uncond_scale: is the unconditional guidance scale $s$. This is used for 108 | $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$ 109 | :param uncond_cond: is the conditional embedding for empty prompt $c_u$ 110 | :param skip_steps: is the number of time steps to skip $t'$. We start sampling from $T - t'$. 111 | And `x_last` is then $x_{T - t'}$.
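To make the precomputed schedule quantities above concrete, here is a toy example using a plain linear $\beta$ schedule (the actual schedule comes from `LatentDiffusion`, so the numbers are illustrative). It also shows why `log_var` is clamped: at $t = 1$ the posterior variance $\tilde\beta_1$ is exactly zero because $\bar\alpha_0 = 1$.

```python
import torch

n_steps = 1000
beta = torch.linspace(0.00085, 0.0120, n_steps)  # toy linear schedule
alpha_bar = torch.cumprod(1. - beta, dim=0)
alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])

posterior_var = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)
log_var = torch.log(torch.clamp(posterior_var, min=1e-20))

print(posterior_var[0].item())  # 0.0 -> log would be -inf without the clamp
print(alpha_bar[-1].item())     # close to 0: x_T is (almost) pure noise
```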
112 | """ 113 | 114 | # Get device and batch size 115 | device = self.model.device 116 | bs = shape[0] 117 | 118 | # Get $x_T$ 119 | x = x_last if x_last is not None else torch.randn(shape, device=device) 120 | 121 | # Time steps to sample at $T - t', T - t' - 1, \dots, 1$ 122 | time_steps = np.flip(self.time_steps)[skip_steps :] 123 | 124 | # Sampling loop 125 | for step in monit.iterate('Sample', time_steps): 126 | # Time step $t$ 127 | ts = x.new_full((bs, ), step, dtype=torch.long) 128 | 129 | # Sample $x_{t-1}$ 130 | x, pred_x0, e_t = self.p_sample( 131 | x, 132 | cond, 133 | ts, 134 | step, 135 | repeat_noise=repeat_noise, 136 | temperature=temperature, 137 | uncond_scale=uncond_scale, 138 | uncond_cond=uncond_cond 139 | ) 140 | 141 | # Return $x_0$ 142 | return x 143 | 144 | @torch.no_grad() 145 | def p_sample( 146 | self, 147 | x: torch.Tensor, 148 | c: torch.Tensor, 149 | t: torch.Tensor, 150 | step: int, 151 | repeat_noise: bool = False, 152 | temperature: float = 1., 153 | uncond_scale: float = 1., 154 | uncond_cond: Optional[torch.Tensor] = None 155 | ): 156 | """ 157 | ### Sample $x_{t-1}$ from $p_\theta(x_{t-1} | x_t)$ 158 | 159 | :param x: is $x_t$ of shape `[batch_size, channels, height, width]` 160 | :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]` 161 | :param t: is $t$ of shape `[batch_size]` 162 | :param step: is the step $t$ as an integer 163 | :repeat_noise: specified whether the noise should be same for all samples in the batch 164 | :param temperature: is the noise temperature (random noise gets multiplied by this) 165 | :param uncond_scale: is the unconditional guidance scale $s$. This is used for 166 | $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$ 167 | :param uncond_cond: is the conditional embedding for empty prompt $c_u$ 168 | """ 169 | 170 | # Get $\epsilon_\theta$ 171 | e_t = self.get_eps(x, t, c, uncond_scale=uncond_scale, uncond_cond=uncond_cond) 172 | 173 | # Get batch size 174 | bs = x.shape[0] 175 | 176 | # $\frac{1}{\sqrt{\bar\alpha_t}}$ 177 | sqrt_recip_alpha_bar = x.new_full( 178 | (bs, 1, 1, 1), self.sqrt_recip_alpha_bar[step] 179 | ) 180 | # $\sqrt{\frac{1}{\bar\alpha_t} - 1}$ 181 | sqrt_recip_m1_alpha_bar = x.new_full( 182 | (bs, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step] 183 | ) 184 | 185 | # Calculate $x_0$ with current $\epsilon_\theta$ 186 | # 187 | # $$x_0 = \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta$$ 188 | x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t 189 | 190 | # $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$ 191 | mean_x0_coef = x.new_full((bs, 1, 1, 1), self.mean_x0_coef[step]) 192 | # $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$ 193 | mean_xt_coef = x.new_full((bs, 1, 1, 1), self.mean_xt_coef[step]) 194 | 195 | # Calculate $\mu_t(x_t, t)$ 196 | # 197 | # $$\mu_t(x_t, t) = \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0 198 | # + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t$$ 199 | mean = mean_x0_coef * x0 + mean_xt_coef * x 200 | # $\log \tilde\beta_t$ 201 | log_var = x.new_full((bs, 1, 1, 1), self.log_var[step]) 202 | 203 | # Do not add noise when $t = 1$ (final step sampling process). 
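A quick numerical check of the $x_0$-prediction identity used above: if $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, then $\frac{1}{\sqrt{\bar\alpha_t}} x_t - \sqrt{\frac{1}{\bar\alpha_t} - 1}\,\epsilon$ recovers $x_0$ exactly. (This holds with the true $\epsilon$; the sampler of course only has the predicted $\epsilon_\theta$.)

```python
import torch

alpha_bar_t = torch.tensor(0.37)  # arbitrary value in (0, 1)
x0 = torch.randn(2, 3)
eps = torch.randn(2, 3)

x_t = alpha_bar_t.sqrt() * x0 + (1. - alpha_bar_t).sqrt() * eps
x0_hat = x_t / alpha_bar_t.sqrt() - (1. / alpha_bar_t - 1.).sqrt() * eps

assert torch.allclose(x0_hat, x0, atol=1e-6)
```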
204 | # Note that `step` is `0` when $t = 1$) 205 | if step == 0: 206 | noise = 0 207 | # If same noise is used for all samples in the batch 208 | elif repeat_noise: 209 | noise = torch.randn((1, *x.shape[1 :])) 210 | # Different noise for each sample 211 | else: 212 | noise = torch.randn(x.shape) 213 | 214 | # Multiply noise by the temperature 215 | noise = noise * temperature 216 | 217 | # Sample from, 218 | # 219 | # $$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big)$$ 220 | x_prev = mean + (0.5 * log_var).exp() * noise 221 | 222 | # 223 | return x_prev, x0, e_t 224 | 225 | @torch.no_grad() 226 | def q_sample( 227 | self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None 228 | ): 229 | """ 230 | ### Sample from $q(x_t|x_0)$ 231 | 232 | $$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$ 233 | 234 | :param x0: is $x_0$ of shape `[batch_size, channels, height, width]` 235 | :param index: is the time step $t$ index 236 | :param noise: is the noise, $\epsilon$ 237 | """ 238 | 239 | # Random noise, if noise is not specified 240 | if noise is None: 241 | noise = torch.randn_like(x0) 242 | 243 | # Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$ 244 | return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise 245 | -------------------------------------------------------------------------------- /model/stable_diffusion/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | --- 3 | title: Utility functions for stable diffusion 4 | summary: > 5 | Utility functions for stable diffusion 6 | --- 7 | 8 | # Utility functions for [stable diffusion](index.html) 9 | """ 10 | 11 | import os 12 | import random 13 | from pathlib import Path 14 | 15 | import PIL 16 | import numpy as np 17 | import torch 18 | from PIL import Image 19 | 20 | from labml import monit 21 | from labml.logger import inspect 22 | from .latent_diffusion import LatentDiffusion 23 | from .model.autoencoder import Encoder, Decoder, Autoencoder 24 | # from model.clip_embedder import CLIPTextEmbedder 25 | from .model.unet import UNetModel 26 | 27 | 28 | def set_seed(seed: int): 29 | """ 30 | ### Set random seeds 31 | """ 32 | random.seed(seed) 33 | np.random.seed(seed) 34 | torch.manual_seed(seed) 35 | torch.cuda.manual_seed_all(seed) 36 | 37 | 38 | def load_model(path: Path = None) -> LatentDiffusion: 39 | """ 40 | ### Load [`LatentDiffusion` model](latent_diffusion.html) 41 | """ 42 | 43 | # Initialize the autoencoder 44 | with monit.section('Initialize autoencoder'): 45 | encoder = Encoder( 46 | z_channels=4, 47 | in_channels=3, 48 | channels=128, 49 | channel_multipliers=[1, 2, 4, 4], 50 | n_resnet_blocks=2 51 | ) 52 | 53 | decoder = Decoder( 54 | out_channels=3, 55 | z_channels=4, 56 | channels=128, 57 | channel_multipliers=[1, 2, 4, 4], 58 | n_resnet_blocks=2 59 | ) 60 | 61 | autoencoder = Autoencoder( 62 | emb_channels=4, encoder=encoder, decoder=decoder, z_channels=4 63 | ) 64 | 65 | # Initialize the U-Net 66 | with monit.section('Initialize U-Net'): 67 | unet_model = UNetModel( 68 | in_channels=4, 69 | out_channels=4, 70 | channels=320, 71 | attention_levels=[0, 1, 2], 72 | n_res_blocks=2, 73 | channel_multipliers=[1, 2, 4, 4], 74 | n_heads=8, 75 | tf_layers=1, 76 | d_cond=768 77 | ) 78 | 79 | # Initialize the Latent Diffusion model 80 | with monit.section('Initialize Latent Diffusion model'): 81 | model 
= LatentDiffusion( 82 | linear_start=0.00085, 83 | linear_end=0.0120, 84 | n_steps=1000, 85 | latent_scaling_factor=0.18215, 86 | autoencoder=autoencoder, 87 | unet_model=unet_model 88 | ) 89 | 90 | # Load the checkpoint 91 | with monit.section(f"Loading model from {path}"): 92 | checkpoint = torch.load(path, map_location="cpu") 93 | 94 | # Set model state 95 | with monit.section('Load state'): 96 | missing_keys, extra_keys = model.load_state_dict( 97 | checkpoint["state_dict"], strict=False 98 | ) 99 | 100 | # Debugging output 101 | inspect( 102 | global_step=checkpoint.get('global_step', -1), 103 | missing_keys=missing_keys, 104 | extra_keys=extra_keys, 105 | _expand=True 106 | ) 107 | 108 | # 109 | model.eval() 110 | return model 111 | 112 | 113 | def load_img(path: str): 114 | """ 115 | ### Load an image 116 | 117 | This loads an image from a file and returns a PyTorch tensor. 118 | 119 | :param path: is the path of the image 120 | """ 121 | # Open Image 122 | image = Image.open(path).convert("RGB") 123 | # Get image size 124 | w, h = image.size 125 | # Resize to a multiple of 32 126 | w = w - w % 32 127 | h = h - h % 32 128 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) 129 | # Convert to numpy and map to `[-1, 1]` for `[0, 255]` 130 | image = np.array(image).astype(np.float32) * (2. / 255.0) - 1 131 | # Transpose to shape `[batch_size, channels, height, width]` 132 | image = image[None].transpose(0, 3, 1, 2) 133 | # Convert to torch 134 | return torch.from_numpy(image) 135 | 136 | 137 | def save_images( 138 | images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg' 139 | ): 140 | """ 141 | ### Save a images 142 | 143 | :param images: is the tensor with images of shape `[batch_size, channels, height, width]` 144 | :param dest_path: is the folder to save images in 145 | :param prefix: is the prefix to add to file names 146 | :param img_format: is the image format 147 | """ 148 | 149 | # Create the destination folder 150 | os.makedirs(dest_path, exist_ok=True) 151 | 152 | # Map images to `[0, 1]` space and clip 153 | images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) 154 | # Transpose to `[batch_size, height, width, channels]` and convert to numpy 155 | images = images.cpu().permute(0, 2, 3, 1).numpy() 156 | 157 | # Save images 158 | for i, img in enumerate(images): 159 | img = Image.fromarray((255. 
* img).astype(np.uint8)) 160 | img.save( 161 | os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format 162 | ) 163 | -------------------------------------------------------------------------------- /params/__init__.py: -------------------------------------------------------------------------------- 1 | from .attrdict import AttrDict 2 | from .params import params_frm, params_ctp, params_lsh, params_acc 3 | 4 | PARAMS_DICTS = {'frm': params_frm, 'ctp': params_ctp, 'lsh': params_lsh, 'acc': params_acc} 5 | -------------------------------------------------------------------------------- /params/attrdict.py: -------------------------------------------------------------------------------- 1 | class AttrDict(dict): 2 | def __init__(self, *args, **kwargs): 3 | super(AttrDict, self).__init__(*args, **kwargs) 4 | self.__dict__ = self 5 | 6 | def override(self, attrs): 7 | if isinstance(attrs, dict): 8 | self.__dict__.update(**attrs) 9 | elif isinstance(attrs, (list, tuple, set)): 10 | for attr in attrs: 11 | self.override(attr) 12 | elif attrs is not None: 13 | raise NotImplementedError 14 | return self 15 | -------------------------------------------------------------------------------- /params/params.py: -------------------------------------------------------------------------------- 1 | from .attrdict import AttrDict 2 | 3 | 4 | params_frm = AttrDict( 5 | # Training params 6 | batch_size=16, 7 | max_epoch=500, 8 | learning_rate=5e-5, 9 | max_grad_norm=10, 10 | fp16=True, 11 | 12 | # unet 13 | in_channels=8, 14 | out_channels=8, 15 | channels=64, 16 | attention_levels=[2, 3], 17 | n_res_blocks=2, 18 | channel_multipliers=[1, 2, 4, 4], 19 | n_heads=4, 20 | tf_layers=1, 21 | d_cond=12, 22 | 23 | # ldm 24 | linear_start=0.00085, 25 | linear_end=0.0120, 26 | n_steps=1000, 27 | latent_scaling_factor=0.18215 28 | ) 29 | 30 | 31 | params_ctp = AttrDict( 32 | # Training params 33 | batch_size=16, 34 | max_epoch=500, 35 | learning_rate=5e-5, 36 | max_grad_norm=10, 37 | fp16=True, 38 | 39 | # unet 40 | in_channels=10, 41 | out_channels=2, 42 | channels=64, 43 | attention_levels=[2, 3], 44 | n_res_blocks=2, 45 | channel_multipliers=[1, 2, 4, 4], 46 | n_heads=4, 47 | tf_layers=1, 48 | d_cond=128, 49 | 50 | # ldm 51 | linear_start=0.00085, 52 | linear_end=0.0120, 53 | n_steps=1000, 54 | latent_scaling_factor=0.18215 55 | ) 56 | 57 | 58 | params_lsh = AttrDict( 59 | # Training params 60 | batch_size=16, 61 | max_epoch=500, 62 | learning_rate=5e-5, 63 | max_grad_norm=10, 64 | fp16=True, 65 | 66 | # unet 67 | in_channels=12, 68 | out_channels=2, 69 | channels=64, 70 | attention_levels=[2, 3], 71 | n_res_blocks=2, 72 | channel_multipliers=[1, 2, 4, 4], 73 | n_heads=4, 74 | tf_layers=1, 75 | d_cond=256, 76 | 77 | # ldm 78 | linear_start=0.00085, 79 | linear_end=0.0120, 80 | n_steps=1000, 81 | latent_scaling_factor=0.18215 82 | ) 83 | 84 | 85 | params_acc = AttrDict( 86 | # Training params 87 | batch_size=16, 88 | max_epoch=500, 89 | learning_rate=5e-5, 90 | max_grad_norm=10, 91 | fp16=True, 92 | 93 | # unet 94 | in_channels=14, 95 | out_channels=2, 96 | channels=64, 97 | attention_levels=[2, 3], 98 | n_res_blocks=2, 99 | channel_multipliers=[1, 2, 4, 4], 100 | n_heads=4, 101 | tf_layers=1, 102 | d_cond=256, 103 | 104 | # ldm 105 | linear_start=0.00085, 106 | linear_end=0.0120, 107 | n_steps=1000, 108 | latent_scaling_factor=0.18215 109 | ) 110 | -------------------------------------------------------------------------------- /pretrained_models/download_link.txt: 
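A short usage sketch for `AttrDict` and the parameter dictionaries defined above, mirroring the `override` call that `train_main.py` makes in debug mode (run from the repository root so that the `params` package is importable):

```python
from params import PARAMS_DICTS

params = PARAMS_DICTS['frm']
print(params.batch_size, params.d_cond)  # attribute-style access: 16 12
params.override({'batch_size': 2})       # dict-style override, as in debug mode
print(params['batch_size'])              # 2: AttrDict keeps normal dict behaviour
```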
-------------------------------------------------------------------------------- 1 | https://drive.google.com/file/d/1VnCOt0dVrRM96hIARzJRawnICiEk3ai4/view?usp=sharing -------------------------------------------------------------------------------- /results_default/download_link.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/file/d/15hFKcyyWUCybVb9KV2pzGVWIpvQONizR/view?usp=drive_link -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | from datetime import datetime 5 | from torch.utils.data import DataLoader 6 | from torch.optim import Optimizer 7 | from params import AttrDict 8 | from .learner import DiffproLearner 9 | 10 | 11 | class TrainConfig: 12 | 13 | model: torch.nn.Module 14 | train_dl: DataLoader 15 | val_dl: DataLoader 16 | optimizer: Optimizer 17 | 18 | def __init__(self, params, param_scheduler, output_dir) -> None: 19 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | self.params = params 21 | self.param_scheduler = param_scheduler 22 | self.output_dir = output_dir 23 | if os.path.exists(f"{output_dir}/params.json"): 24 | with open(f"{output_dir}/params.json", "r") as params_file: 25 | self.params = AttrDict(json.load(params_file)) 26 | if params != self.params: 27 | print("New params differ, using former params instead...") 28 | params = self.params 29 | 30 | def train(self): 31 | total_parameters = sum( 32 | p.numel() for p in self.model.parameters() if p.requires_grad 33 | ) 34 | print(f"Total parameters: {total_parameters}") 35 | output_dir = self.output_dir 36 | if os.path.exists(f"{output_dir}/chkpts/weights.pt"): 37 | print("Checkpoint already exists.") 38 | if input("Resume training? 
(y/n)") != "y": 39 | return 40 | else: 41 | output_dir = f"{output_dir}/{datetime.now().strftime('%m-%d_%H%M%S')}" 42 | print(f"Creating new log folder as {output_dir}") 43 | learner = DiffproLearner( 44 | output_dir, self.model, self.train_dl, self.val_dl, self.optimizer, 45 | self.params, self.param_scheduler 46 | ) 47 | learner.train(max_epoch=self.params.max_epoch) 48 | -------------------------------------------------------------------------------- /train/learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import torch.nn as nn 4 | from tqdm import tqdm 5 | from torch.utils.tensorboard.writer import SummaryWriter 6 | from typing import Optional 7 | import os 8 | 9 | 10 | def nested_map(struct, map_fn): 11 | """This is for trasfering into cuda device""" 12 | if isinstance(struct, tuple): 13 | return tuple(nested_map(x, map_fn) for x in struct) 14 | if isinstance(struct, list): 15 | return [nested_map(x, map_fn) for x in struct] 16 | if isinstance(struct, dict): 17 | return {k: nested_map(v, map_fn) for k, v in struct.items()} 18 | return map_fn(struct) 19 | 20 | 21 | class DiffproLearner: 22 | def __init__( 23 | self, output_dir, model, train_dl, val_dl, optimizer, params, param_scheduler 24 | ): 25 | self.output_dir = output_dir 26 | self.log_dir = f"{output_dir}/logs" 27 | self.checkpoint_dir = f"{output_dir}/chkpts" 28 | self.model = model 29 | self.train_dl = train_dl 30 | self.val_dl = val_dl 31 | self.optimizer = optimizer 32 | self.params = params 33 | self.param_scheduler = param_scheduler # teacher-forcing stuff 34 | self.step = 0 35 | self.epoch = 0 36 | self.grad_norm = 0. 37 | self.summary_writer = None 38 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 39 | self.autocast = torch.cuda.amp.autocast(enabled=params.fp16) 40 | self.scaler = torch.cuda.amp.GradScaler(enabled=params.fp16) 41 | 42 | self.best_val_loss = torch.tensor([1e10], device=self.device) 43 | 44 | # restore if directory exists 45 | if os.path.exists(self.output_dir): 46 | self.restore_from_checkpoint() 47 | else: 48 | os.makedirs(self.output_dir) 49 | os.makedirs(self.log_dir) 50 | os.makedirs(self.checkpoint_dir) 51 | with open(f"{output_dir}/params.json", "w") as params_file: 52 | json.dump(self.params, params_file) 53 | 54 | print(json.dumps(self.params, sort_keys=True, indent=4)) 55 | 56 | def _write_summary(self, losses: dict, scheduled_params: Optional[dict], type): 57 | """type: train or val""" 58 | summary_losses = losses 59 | summary_losses["grad_norm"] = self.grad_norm 60 | if scheduled_params is not None: 61 | for k, v in scheduled_params.items(): 62 | summary_losses[f"sched_{k}"] = v 63 | writer = self.summary_writer or SummaryWriter( 64 | self.log_dir, purge_step=self.step 65 | ) 66 | writer.add_scalars(type, summary_losses, self.step) 67 | writer.flush() 68 | self.summary_writer = writer 69 | 70 | def state_dict(self): 71 | model_state = self.model.state_dict() 72 | return { 73 | "step": self.step, 74 | "epoch": self.epoch, 75 | "model": 76 | { 77 | k: v.cpu() if isinstance(v, torch.Tensor) else v 78 | for k, v in model_state.items() 79 | }, 80 | "optimizer": 81 | { 82 | k: v.cpu() if isinstance(v, torch.Tensor) else v 83 | for k, v in self.optimizer.state_dict().items() 84 | }, 85 | "scaler": self.scaler.state_dict(), 86 | } 87 | 88 | def load_state_dict(self, state_dict): 89 | self.step = state_dict["step"] 90 | self.epoch = state_dict["epoch"] 91 | self.model.load_state_dict(state_dict["model"]) 92 | 
self.optimizer.load_state_dict(state_dict["optimizer"]) 93 | self.scaler.load_state_dict(state_dict["scaler"]) 94 | 95 | def restore_from_checkpoint(self, fname="weights"): 96 | try: 97 | fpath = f"{self.checkpoint_dir}/{fname}.pt" 98 | checkpoint = torch.load(fpath) 99 | self.load_state_dict(checkpoint) 100 | print(f"Restored from checkpoint {fpath} --> {fname}-{self.epoch}.pt!") 101 | return True 102 | except FileNotFoundError: 103 | print("No checkpoint found. Starting from scratch...") 104 | return False 105 | 106 | def _link_checkpoint(self, save_name, link_fpath): 107 | if os.path.islink(link_fpath): 108 | os.unlink(link_fpath) 109 | os.symlink(save_name, link_fpath) 110 | 111 | def save_to_checkpoint(self, fname="weights", is_best=False): 112 | save_name = f"{fname}-{self.epoch}.pt" 113 | save_fpath = f"{self.checkpoint_dir}/{save_name}" 114 | link_best_fpath = f"{self.checkpoint_dir}/{fname}_best.pt" 115 | link_fpath = f"{self.checkpoint_dir}/{fname}.pt" 116 | torch.save(self.state_dict(), save_fpath) 117 | self._link_checkpoint(save_name, link_fpath) 118 | if is_best: 119 | self._link_checkpoint(save_name, link_best_fpath) 120 | 121 | def train(self, max_epoch=None): 122 | self.model.train() 123 | 124 | while True: 125 | if self.param_scheduler is not None: 126 | self.param_scheduler.train() 127 | self.epoch = self.step // len(self.train_dl) 128 | if max_epoch is not None and self.epoch >= max_epoch: 129 | return 130 | 131 | for batch in tqdm(self.train_dl, desc=f"Epoch {self.epoch}"): 132 | batch = nested_map( 133 | batch, lambda x: x.to(self.device) 134 | if isinstance(x, torch.Tensor) else x 135 | ) 136 | losses, scheduled_params = self.train_step(batch) 137 | # check NaN 138 | for loss_value in list(losses.values()): 139 | if isinstance(loss_value, 140 | torch.Tensor) and torch.isnan(loss_value).any(): 141 | raise RuntimeError( 142 | f"Detected NaN loss at step {self.step}, epoch {self.epoch}" 143 | ) 144 | if self.step % 50 == 0: 145 | self._write_summary(losses, scheduled_params, "train") 146 | if self.step % 5000 == 0 and self.step != 0 \ 147 | and self.epoch != 0: 148 | self.valid() 149 | self.step += 1 150 | 151 | # valid 152 | self.valid() 153 | 154 | def valid(self): 155 | # self.model.eval() 156 | if self.param_scheduler is not None: 157 | self.param_scheduler.eval() 158 | losses = None 159 | for batch in self.val_dl: 160 | batch = nested_map( 161 | batch, lambda x: x.to(self.device) if isinstance(x, torch.Tensor) else x 162 | ) 163 | current_losses, _ = self.val_step(batch) 164 | losses = losses or current_losses 165 | for k, v in current_losses.items(): 166 | losses[k] += v 167 | assert losses is not None 168 | for k, v in losses.items(): 169 | losses[k] /= len(self.val_dl) 170 | self._write_summary(losses, None, "val") 171 | 172 | if self.best_val_loss >= losses["loss"]: 173 | self.best_val_loss = losses["loss"] 174 | self.save_to_checkpoint(is_best=True) 175 | else: 176 | self.save_to_checkpoint(is_best=False) 177 | 178 | def train_step(self, batch): 179 | # people say this is the better way to set zero grad 180 | # instead of self.optimizer.zero_grad() 181 | for param in self.model.parameters(): 182 | param.grad = None 183 | 184 | # here forward the model 185 | with self.autocast: 186 | if self.param_scheduler is not None: 187 | scheduled_params = self.param_scheduler.step() 188 | loss_dict = self.model.get_loss_dict( 189 | batch, self.step, **scheduled_params 190 | ) 191 | else: 192 | scheduled_params = None 193 | loss_dict = self.model.get_loss_dict(batch, 
self.step) 194 | 195 | loss = loss_dict["loss"] 196 | self.scaler.scale(loss).backward() 197 | self.scaler.unscale_(self.optimizer) 198 | self.grad_norm = nn.utils.clip_grad.clip_grad_norm_( 199 | self.model.parameters(), self.params.max_grad_norm or 1e9 200 | ) 201 | self.scaler.step(self.optimizer) 202 | self.scaler.update() 203 | return loss_dict, scheduled_params 204 | 205 | def val_step(self, batch): 206 | with torch.no_grad(): 207 | with self.autocast: 208 | if self.param_scheduler is not None: 209 | scheduled_params = self.param_scheduler.step() 210 | loss_dict = self.model.get_loss_dict( 211 | batch, self.step, **scheduled_params 212 | ) 213 | else: 214 | scheduled_params = None 215 | loss_dict = self.model.get_loss_dict(batch, self.step) 216 | 217 | return loss_dict, scheduled_params 218 | -------------------------------------------------------------------------------- /train/train_config.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | from data_utils import load_datasets, create_train_valid_dataloaders 3 | from model import init_ldm_model, init_diff_pro_sdf 4 | 5 | 6 | class LdmTrainConfig(TrainConfig): 7 | 8 | def __init__(self, params, output_dir, mode, use_autoreg_cond, use_external_cond, 9 | mask_background, multi_phrase_label, random_pitch_aug, debug_mode=False) -> None: 10 | super().__init__(params, None, output_dir) 11 | self.debug_mode = debug_mode 12 | self.use_autoreg_cond = use_autoreg_cond 13 | self.use_external_cond = use_external_cond 14 | self.mask_background = mask_background 15 | self.multi_phrase_label = multi_phrase_label 16 | self.random_pitch_aug = random_pitch_aug 17 | 18 | # create model 19 | self.ldm_model = init_ldm_model(mode, use_autoreg_cond, use_external_cond, params, debug_mode) 20 | self.model = init_diff_pro_sdf(self.ldm_model, params, self.device) 21 | 22 | # Create dataloader 23 | load_first_n = 10 if self.debug_mode else None 24 | train_set, valid_set = load_datasets( 25 | mode, multi_phrase_label, random_pitch_aug, use_autoreg_cond, use_external_cond, 26 | mask_background, load_first_n 27 | ) 28 | self.train_dl, self.val_dl = create_train_valid_dataloaders(params.batch_size, train_set, valid_set) 29 | 30 | # Create optimizer 31 | self.optimizer = torch.optim.Adam( 32 | self.model.parameters(), lr=params.learning_rate 33 | ) 34 | -------------------------------------------------------------------------------- /train_main.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from params import PARAMS_DICTS 3 | import os 4 | from train.train_config import LdmTrainConfig 5 | 6 | 7 | def init_parser(): 8 | parser = ArgumentParser(description='train (or resume training) a diffusion model') 9 | parser.add_argument( 10 | "--output_dir", 11 | default='results', 12 | help='directory in which to store model checkpoints and training logs' 13 | ) 14 | parser.add_argument("--mode", help="which model to train (frm, ctp, lsh, acc)") 15 | parser.add_argument('--external', action='store_true', help="whether to use external control") 16 | parser.add_argument('--autoreg', action='store_true', help="whether to use autoreg control") 17 | parser.add_argument('--mask_bg', action='store_true', help="whether to mask background-cond at random") 18 | parser.add_argument('--multi_label', action='store_true', help="whether to use all human phrase labels") 19 | parser.add_argument('--uniform_pitch_shift', action='store_true', 20 | 
help="whether to apply pitch shift uniformly (as opposed to randomly)") 21 | parser.add_argument('--debug', action='store_true', help="whether to use a toy dataset") 22 | 23 | return parser 24 | 25 | 26 | def args_check(args): 27 | assert args.mode in ['frm', 'ctp', 'lsh', 'acc'] 28 | if args.mode == 'frm': 29 | assert not args.autoreg and not args.external and not args.mask_bg 30 | 31 | 32 | def args_setting_to_fn(args): 33 | def to_str(x: bool, char): 34 | return char if x else '' 35 | 36 | mode = args.mode 37 | autoreg = to_str(args.autoreg, 'a') 38 | external = to_str(args.external, 'e') 39 | mask_bg = to_str(args.mask_bg, 'b') 40 | multi_label = to_str(args.multi_label, 'l') 41 | p_shift = to_str(args.uniform_pitch_shift, 'p') 42 | debug = to_str(args.debug, 'd') 43 | 44 | return f"{mode}-{autoreg}{external}-{mask_bg}{multi_label}{p_shift}-{debug}" 45 | 46 | 47 | if __name__ == "__main__": 48 | 49 | parser = init_parser() 50 | args = parser.parse_args() 51 | args_check(args) 52 | 53 | random_pitch_aug = not args.uniform_pitch_shift 54 | 55 | params = PARAMS_DICTS[args.mode] 56 | if args.debug: 57 | params.override({'batch_size': 2}) 58 | 59 | fn = args_setting_to_fn(args) 60 | 61 | output_dir = os.path.join(args.output_dir, fn) 62 | config = LdmTrainConfig(params, output_dir, args.mode, args.autoreg, args.external, 63 | args.mask_bg, args.multi_label, random_pitch_aug, args.debug) 64 | 65 | config.train() 66 | --------------------------------------------------------------------------------