├── .gitignore ├── LICENSE ├── README.md ├── requirements.txt ├── splits ├── test.txt ├── train.txt └── val.txt └── src ├── __init__.py ├── chord_recognition.py ├── constants.py ├── datasets.py ├── evaluate.py ├── generate.py ├── input_representation.py ├── models ├── __init__.py ├── seq2seq.py └── vae.py ├── precompute_latents.py ├── train.py ├── utils.py └── vocab.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.ckpt 3 | *.mid 4 | *.tar.gz 5 | *.zip 6 | metrics.csv 7 | .DS_Store 8 | lightning_logs 9 | tmp 10 | .venv 11 | .hypothesis 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Dimitri von Rütte 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FIGARO: Generating Symbolic Music with Fine-Grained Artistic Control 2 | 3 | Listen to the samples on [Soundcloud](https://soundcloud.com/user-751999449/sets/figaro-generating-symbolic-music-with-fine-grained-artistic-control). 4 | 5 | Paper: https://openreview.net/forum?id=NyR8OZFHw6i 6 | 7 | Colab Demo: https://colab.research.google.com/drive/1UAKFkbPQTfkYMq1GxXfGZOJXOXU_svo6 8 | 9 | --- 10 | 11 | - [FIGARO: Generating Symbolic Music with Fine-Grained Artistic Control](#figaro-generating-symbolic-music-with-fine-grained-artistic-control) 12 | - [Getting started](#getting-started) 13 | - [Setup](#setup) 14 | - [Preparing the Data](#preparing-the-data) 15 | - [Download Pre-Trained Models](#download-pre-trained-models) 16 | - [Training](#training) 17 | - [Generation](#generation) 18 | - [Evaluation](#evaluation) 19 | - [Parameters](#parameters) 20 | - [Training (`train.py`)](#training-trainpy) 21 | - [Generation (`generate.py`)](#generation-generatepy) 22 | - [Evaluation (`evaluate.py`)](#evaluation-evaluatepy) 23 | 24 | ## Getting started 25 | Prerequisites: 26 | - Python 3.9 27 | - Conda 28 | 29 | ### Setup 30 | 1. Clone this repository to your disk 31 | 2. Install required packages (see requirements.txt). 
32 | With `venv`: 33 | ```bash 34 | python3 -m venv .venv 35 | source .venv/bin/activate 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | ### Preparing the Data 40 | 41 | To train models and to generate new samples, we use the [Lakh MIDI](https://colinraffel.com/projects/lmd/) dataset (although any collection of MIDI files can be used). 42 | 1. Download (size: 1.6GB) and extract the archive file: 43 | ```bash 44 | wget http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz 45 | tar -xzf lmd_full.tar.gz 46 | ``` 47 | 2. You may wish to remove the archive file now: `rm lmd_full.tar.gz` 48 | 49 | ### Download Pre-Trained Models 50 | If you don't wish to train your own models, you can download our pre-trained models. 51 | 1. Download (size: 2.3GB) and extract the archive file: 52 | ```bash 53 | wget -O checkpoints.zip https://polybox.ethz.ch/index.php/s/a0HUHzKuPPefWkW/download 54 | unzip checkpoints.zip 55 | ``` 56 | 2. You may wish to remove the archive file now: `rm checkpoints.zip` 57 | 58 | 59 | 60 | ## Training 61 | Training arguments such as the model type, batch size, and model parameters are passed to the training script via environment variables. 62 | 63 | Available model types are: 64 | - `vq-vae`: VQ-VAE model used for the learned description 65 | - `figaro`: FIGARO with both the expert and learned description 66 | - `figaro-expert`: FIGARO with only the expert description 67 | - `figaro-learned`: FIGARO with only the learned description 68 | - `figaro-no-inst`: FIGARO (expert) without instruments 69 | - `figaro-no-chord`: FIGARO (expert) without chords 70 | - `figaro-no-meta`: FIGARO (expert) without style (meta) information 71 | - `baseline`: Unconditional decoder-only baseline following [Huang et al. (2018)](https://arxiv.org/abs/1809.04281) 72 | 73 | An example invocation of the training script is given by the following command: 74 | ```bash 75 | MODEL=figaro-expert python src/train.py 76 | ``` 77 | 78 | For models using the learned description (`figaro` and `figaro-learned`), a pre-trained VQ-VAE checkpoint needs to be provided as well: 79 | ```bash 80 | MODEL=figaro VAE_CHECKPOINT=./checkpoints/vq-vae.ckpt python src/train.py 81 | ``` 82 | 83 | ## Generation 84 | To generate samples, make sure you have a trained checkpoint prepared (either download one or train it yourself). 85 | For this script, make sure that the dataset is prepared according to [Preparing the Data](#preparing-the-data). 86 | This is needed to extract descriptions, based on which new samples can be generated. 87 | 88 | An example invocation of the generation script is given by the following command: 89 | ```bash 90 | python src/generate.py --model figaro-expert --checkpoint ./checkpoints/figaro-expert.ckpt 91 | ``` 92 | 93 | For models using the learned description (`figaro` and `figaro-learned`), a pre-trained VQ-VAE checkpoint needs to be provided as well: 94 | ```bash 95 | python src/generate.py --model figaro --checkpoint ./checkpoints/figaro.ckpt --vae_checkpoint ./checkpoints/vq-vae.ckpt 96 | ``` 97 | 98 | ## Evaluation 99 | 100 | We provide the evaluation script used to calculate the description metrics on a set of generated samples. 101 | Refer to the previous section for how to generate samples yourself. 102 | 103 | Example usage: 104 | ```bash 105 | python src/evaluate.py --samples_dir ./samples/figaro-expert 106 | ``` 107 | 108 | It has been pointed out that the order of the dataset files (from which the splits are calculated) is non-deterministic and depends on the OS. 
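One hedged illustration of the issue (not part of the provided scripts; `./lmd_full` is just the default dataset location): sorting the globbed file list makes the ordering, and therefore the resulting train/validation/test split, deterministic across operating systems.
```python
import glob

# Illustrative sketch only: glob order is OS-dependent, so sorting the file
# list pins the computed split to the same files on every platform.
midi_files = sorted(glob.glob('./lmd_full/**/*.mid', recursive=True))
```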
109 | To address this and to ensure reproducibility, I have added the exact files used for training/validation/testing in the respective file in the `splits` folder. 110 | 111 | ## Parameters 112 | The following environment variables are available for controlling hyperparameters beyond their default value. 113 | ### Training (`train.py`) 114 | Model 115 | | Variable | Description | Default value | 116 | |-|-|-| 117 | | `MODEL` | Model architecture to be trained | | 118 | | `D_MODEL` | Hidden size of the model | 512 | 119 | | `CONTEXT_SIZE` | Number of tokens in the context to be passed to the auto-encoder | 256 | 120 | | `D_LATENT` | [VQ-VAE] Dimensionality of the latent space | 1024 | 121 | | `N_CODES` | [VQ-VAE] Codebook size | 2048 | 122 | | `N_GROUPS` | [VQ-VAE] Number of groups to split the latent vector into before discretization | 16 | 123 | 124 | Optimization 125 | | Variable | Description | Default value | 126 | |-|-|-| 127 | | `EPOCHS` | Max. number of training epochs | 16 | 128 | | `MAX_TRAINING_STEPS` | Max. number of training iterations | 100,000 | 129 | | `BATCH_SIZE` | Number of samples in each batch | 128 | 130 | | `TARGET_BATCH_SIZE` | Number of samples in each backward step, gradients will be accumulated over `TARGET_BATCH_SIZE//BATCH_SIZE` batches | 256 | 131 | | `WARMUP_STEPS` | Number of learning rate warmup steps | 4000 | 132 | | `LEARNING_RATE` | Initial learning rate, will be decayed after constant warmup of `WARMUP_STEPS` steps | 1e-4 | 133 | 134 | Others 135 | | Variable | Description | Default value | 136 | |-|-|-| 137 | | `CHECKPOINT` | Path to checkpoint from which to resume training | | 138 | | `VAE_CHECKPOINT` | Path to VQ-VAE checkpoint to be used for the learned description | | 139 | | `ROOT_DIR` | The folder containing MIDI files to train on | `./lmd_full` | 140 | | `OUTPUT_DIR` | Folder for saving checkpoints | `./results` | 141 | | `LOGGING_DIR` | Folder for saving logs | `./logs` | 142 | | `N_WORKERS` | Number of workers to be used for the dataloader | available CPUs | 143 | 144 | 145 | ### Generation (`generate.py`) 146 | The generation script uses command line arguments instead of environment variables. 147 | | Argument | Description | Default value | 148 | |-|-|-| 149 | | `--model` | Specify which model will be loaded | | 150 | | `--checkpoint` | Path to the checkpoint for the specified model | | 151 | | `--vae_checkpoint` | Path to the VQ-VAE checkpoint to be used for the learned description (if applicable) | | 152 | | `--lmd_dir` | Folder containing MIDI files to extract descriptions from | `./lmd_full` | 153 | | `--output_dir` | Folder to save generated MIDI samples to | `./samples` | 154 | | `--max_iter` | Max. number of tokens that should be generated | 16,000 | 155 | | `--max_bars` | Max. number of bars that should be generated | 32 | 156 | | `--make_medleys` | Set to `True` if descriptions should be combined into medleys. | `False` | 157 | | `--n_medley_pieces` | Number of pieces to be combined into one | 2 | 158 | | `--n_medley_bars` | Number of bars to take from each piece | 16 | 159 | | `--verbose` | Logging level, set to 0 for silent execution | 2 | 160 | 161 | 162 | ### Evaluation (`evaluate.py`) 163 | The evaluation script uses command line arguments instead of environment variables. 
164 | | Argument | Description | Default value | 165 | |-|-|-| 166 | | `--samples_dir` | Folder containing generated samples which should be evaluated | `./samples` | 167 | | `--output_file` | CSV file to which a detailed log of all metrics will be saved to | `./metrics.csv` | 168 | | `--max_samples` | Limit the number of samples to be used for computing evaluation metrics | 1024 | 169 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy~=1.26.3 2 | pandas~=2.2.0 3 | pretty-midi==0.2.10 4 | pytorch-lightning~=2.1.3 5 | scikit-learn~=1.4.0 6 | scipy~=1.12.0 7 | torch~=2.1.2 8 | torchtext~=0.16.2 9 | transformers~=4.37.0 10 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvruette/figaro/9da30291e2865bcbdad1e85ccea82df1a61119e9/src/__init__.py -------------------------------------------------------------------------------- /src/chord_recognition.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class MIDIChord(object): 4 | def __init__(self, pm): 5 | self.pm = pm 6 | # define pitch classes 7 | self.PITCH_CLASSES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] 8 | # define chord maps (required) 9 | self.CHORD_MAPS = {'maj': [0, 4], 10 | 'min': [0, 3], 11 | 'dim': [0, 3, 6], 12 | 'aug': [0, 4, 8], 13 | 'dom7': [0, 4, 10], 14 | 'maj7': [0, 4, 11], 15 | 'min7': [0, 3, 10]} 16 | # define chord insiders (+10) 17 | self.CHORD_INSIDERS = {'maj': [7], 18 | 'min': [7], 19 | 'dim': [9], 20 | 'aug': [], 21 | 'dom7': [7], 22 | 'maj7': [7], 23 | 'min7': [7]} 24 | # define chord outsiders (-1) 25 | self.CHORD_OUTSIDERS_1 = {'maj': [2, 5, 9], 26 | 'min': [2, 5, 8], 27 | 'dim': [2, 5, 10], 28 | 'aug': [2, 5, 9], 29 | 'dom7': [2, 5, 9], 30 | 'maj7': [2, 5, 9], 31 | 'maj7': [2, 5, 9], 32 | 'min7': [2, 5, 8]} 33 | # define chord outsiders (-2) 34 | self.CHORD_OUTSIDERS_2 = {'maj': [1, 3, 6, 8, 10, 11], 35 | 'min': [1, 4, 6, 9, 11], 36 | 'dim': [1, 4, 7, 8, 11], 37 | 'aug': [1, 3, 6, 7, 10], 38 | 'dom7': [1, 3, 6, 8, 11], 39 | 'maj7': [1, 3, 6, 8, 10], 40 | 'min7': [1, 4, 6, 9, 11]} 41 | 42 | def sequencing(self, chroma): 43 | candidates = {} 44 | for index in range(len(chroma)): 45 | if chroma[index]: 46 | root_note = index 47 | _chroma = np.roll(chroma, -root_note) 48 | sequence = np.where(_chroma == 1)[0] 49 | candidates[root_note] = list(sequence) 50 | return candidates 51 | 52 | def scoring(self, candidates): 53 | scores = {} 54 | qualities = {} 55 | for root_note, sequence in candidates.items(): 56 | if 3 not in sequence and 4 not in sequence: 57 | scores[root_note] = -100 58 | qualities[root_note] = 'None' 59 | elif 3 in sequence and 4 in sequence: 60 | scores[root_note] = -100 61 | qualities[root_note] = 'None' 62 | else: 63 | # decide quality 64 | if 3 in sequence: 65 | if 6 in sequence: 66 | quality = 'dim' 67 | else: 68 | if 10 in sequence: 69 | quality = 'min7' 70 | else: 71 | quality = 'min' 72 | elif 4 in sequence: 73 | if 8 in sequence: 74 | quality = 'aug' 75 | else: 76 | if 10 in sequence: 77 | quality = 'dom7' 78 | elif 11 in sequence: 79 | quality = 'maj7' 80 | else: 81 | quality = 'maj' 82 | # decide score 83 | maps = self.CHORD_MAPS.get(quality) 84 | _notes = [n for n in sequence if n not in maps] 85 | 
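# Remaining non-chord tones are now scored against the detected quality:
# tones listed in CHORD_OUTSIDERS_1 subtract 1, tones in CHORD_OUTSIDERS_2
# subtract 2, and the characteristic "insider" tones (e.g. the fifth) add 10,
# so the root whose leftover notes fit the quality best gets the highest score.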
score = 0 86 | for n in _notes: 87 | if n in self.CHORD_OUTSIDERS_1.get(quality): 88 | score -= 1 89 | elif n in self.CHORD_OUTSIDERS_2.get(quality): 90 | score -= 2 91 | elif n in self.CHORD_INSIDERS.get(quality): 92 | score += 10 93 | scores[root_note] = score 94 | qualities[root_note] = quality 95 | return scores, qualities 96 | 97 | def find_chord(self, chroma, threshold=10): 98 | chroma = np.sum(chroma, axis=1) 99 | chroma = np.array([1 if c > threshold else 0 for c in chroma]) 100 | if np.sum(chroma) == 0: 101 | return 'N', 'N', 'N', 10 102 | else: 103 | candidates = self.sequencing(chroma=chroma) 104 | scores, qualities = self.scoring(candidates=candidates) 105 | # bass note 106 | sorted_notes = [] 107 | for i, v in enumerate(chroma): 108 | if v > 0: 109 | sorted_notes.append(int(i%12)) 110 | bass_note = sorted_notes[0] 111 | # root note 112 | __root_note = [] 113 | _max = max(scores.values()) 114 | for _root_note, score in scores.items(): 115 | if score == _max: 116 | __root_note.append(_root_note) 117 | if len(__root_note) == 1: 118 | root_note = __root_note[0] 119 | else: 120 | #TODO: what should i do 121 | for n in sorted_notes: 122 | if n in __root_note: 123 | root_note = n 124 | break 125 | # quality 126 | quality = qualities.get(root_note) 127 | sequence = candidates.get(root_note) 128 | # score 129 | score = scores.get(root_note) 130 | return self.PITCH_CLASSES[root_note], quality, self.PITCH_CLASSES[bass_note], score 131 | 132 | def greedy(self, candidates, max_tick, min_length): 133 | chords = [] 134 | # start from 0 135 | start_tick = 0 136 | while start_tick < max_tick: 137 | _candidates = candidates.get(start_tick) 138 | _candidates = sorted(_candidates.items(), key=lambda x: (x[1][-1], x[0])) 139 | # choose 140 | end_tick, (root_note, quality, bass_note, _) = _candidates[-1] 141 | if root_note == bass_note: 142 | chord = '{}:{}'.format(root_note, quality) 143 | else: 144 | chord = '{}:{}/{}'.format(root_note, quality, bass_note) 145 | chords.append([start_tick, end_tick, chord]) 146 | start_tick = end_tick 147 | # remove :None 148 | temp = chords 149 | while ':None' in temp[0][-1]: 150 | try: 151 | temp[1][0] = temp[0][0] 152 | del temp[0] 153 | except: 154 | print('NO CHORD') 155 | return [] 156 | temp2 = [] 157 | for chord in temp: 158 | if ':None' not in chord[-1]: 159 | temp2.append(chord) 160 | else: 161 | temp2[-1][1] = chord[1] 162 | return temp2 163 | 164 | def dynamic(self, candidates, max_tick, min_length): 165 | # store index of best chord at each position 166 | chords = [None for i in range(max_tick + 1)] 167 | # store score of best chords at each position 168 | scores = np.zeros(max_tick + 1) 169 | scores[1:].fill(np.NINF) 170 | 171 | start_tick = 0 172 | while start_tick < max_tick: 173 | if start_tick in candidates: 174 | for i, (end_tick, candidate) in enumerate(candidates.get(start_tick).items()): 175 | root_note, quality, bass_note, score = candidate 176 | # if this candidate is best yet, update scores and chords 177 | if scores[end_tick] < scores[start_tick] + score: 178 | scores[end_tick] = scores[start_tick] + score 179 | if root_note == bass_note: 180 | chord = '{}:{}'.format(root_note, quality) 181 | else: 182 | chord = '{}:{}/{}'.format(root_note, quality, bass_note) 183 | chords[end_tick] = (start_tick, end_tick, chord) 184 | start_tick += 1 185 | # Read the best path 186 | start_tick = len(chords) - 1 187 | results = [] 188 | while start_tick > 0: 189 | chord = chords[start_tick] 190 | start_tick = chord[0] 191 | results.append(chord) 192 | 
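# The walk above follows the stored (start_tick, end_tick, chord) back-pointers
# from the final beat towards beat 0, so the collected segments are in reverse
# order and are flipped before being returned.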
193 | return list(reversed(results)) 194 | 195 | def dedupe(self, chords): 196 | if len(chords) == 0: 197 | return [] 198 | deduped = [] 199 | start, end, chord = chords[0] 200 | for (curr, next) in zip(chords[:-1], chords[1:]): 201 | if chord == next[2]: 202 | end = next[1] 203 | else: 204 | deduped.append([start, end, chord]) 205 | start, end, chord = next 206 | deduped.append([start, end, chord]) 207 | return deduped 208 | 209 | def get_candidates(self, chroma, max_tick, intervals=[1, 2, 3, 4]): 210 | candidates = {} 211 | for interval in intervals: 212 | for start_beat in range(max_tick): 213 | # set target pianoroll 214 | end_beat = start_beat + interval 215 | if end_beat > max_tick: 216 | end_beat = max_tick 217 | _chroma = chroma[:, start_beat:end_beat] 218 | # find chord 219 | root_note, quality, bass_note, score = self.find_chord(chroma=_chroma) 220 | # save 221 | if start_beat not in candidates: 222 | candidates[start_beat] = {} 223 | candidates[start_beat][end_beat] = (root_note, quality, bass_note, score) 224 | else: 225 | if end_beat not in candidates[start_beat]: 226 | candidates[start_beat][end_beat] = (root_note, quality, bass_note, score) 227 | return candidates 228 | 229 | def extract(self): 230 | # read 231 | beats = self.pm.get_beats() 232 | chroma = self.pm.get_chroma(times=beats) 233 | # get lots of candidates 234 | candidates = self.get_candidates(chroma, max_tick=len(beats)) 235 | 236 | # greedy 237 | chords = self.dynamic(candidates=candidates, 238 | max_tick=len(beats), 239 | min_length=1) 240 | chords = self.dedupe(chords) 241 | for chord in chords: 242 | chord[0] = beats[chord[0]] 243 | if chord[1] >= len(beats): 244 | chord[1] = self.pm.get_end_time() 245 | else: 246 | chord[1] = beats[chord[1]] 247 | return chords -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # parameters for input representation 4 | DEFAULT_POS_PER_QUARTER = 12 5 | DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 32+1, dtype=int) 6 | DEFAULT_DURATION_BINS = np.sort(np.concatenate([ 7 | np.arange(1, 13), # smallest possible units up to 1 quarter 8 | np.arange(12, 24, 3)[1:], # 16th notes up to 1 bar 9 | np.arange(13, 24, 4)[1:], # triplets up to 1 bar 10 | np.arange(24, 48, 6), # 8th notes up to 2 bars 11 | np.arange(48, 4*48, 12), # quarter notes up to 8 bars 12 | np.arange(4*48, 16*48+1, 24) # half notes up to 16 bars 13 | ])) 14 | DEFAULT_TEMPO_BINS = np.linspace(0, 240, 32+1, dtype=int) 15 | DEFAULT_NOTE_DENSITY_BINS = np.linspace(0, 12, 32+1) 16 | DEFAULT_MEAN_VELOCITY_BINS = np.linspace(0, 128, 32+1) 17 | DEFAULT_MEAN_PITCH_BINS = np.linspace(0, 128, 32+1) 18 | DEFAULT_MEAN_DURATION_BINS = np.logspace(0, 7, 32+1, base=2) # log space between 1 and 128 positions (~2.5 bars) 19 | 20 | # parameters for output 21 | DEFAULT_RESOLUTION = 480 22 | 23 | # maximum length of a single bar is 3*4 = 12 beats 24 | MAX_BAR_LENGTH = 3 25 | # maximum number of bars in a piece is 512 (this covers almost all sequences) 26 | MAX_N_BARS = 512 27 | 28 | PAD_TOKEN = '' 29 | UNK_TOKEN = '' 30 | BOS_TOKEN = '' 31 | EOS_TOKEN = '' 32 | MASK_TOKEN = '' 33 | 34 | TIME_SIGNATURE_KEY = 'Time Signature' 35 | BAR_KEY = 'Bar' 36 | POSITION_KEY = 'Position' 37 | INSTRUMENT_KEY = 'Instrument' 38 | PITCH_KEY = 'Pitch' 39 | VELOCITY_KEY = 'Velocity' 40 | DURATION_KEY = 'Duration' 41 | TEMPO_KEY = 'Tempo' 42 | CHORD_KEY = 'Chord' 43 | 44 | NOTE_DENSITY_KEY = 'Note 
Density' 45 | MEAN_PITCH_KEY = 'Mean Pitch' 46 | MEAN_VELOCITY_KEY = 'Mean Velocity' 47 | MEAN_DURATION_KEY = 'Mean Duration' -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, IterableDataset 3 | from torchdata.datapipes.iter import IterableWrapper 4 | from torch.nn.utils.rnn import pad_sequence 5 | import pytorch_lightning as pl 6 | import math 7 | import os 8 | import pickle 9 | 10 | from input_representation import InputRepresentation 11 | from vocab import RemiVocab, DescriptionVocab 12 | from constants import ( 13 | PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, BAR_KEY, POSITION_KEY, 14 | TIME_SIGNATURE_KEY, INSTRUMENT_KEY, CHORD_KEY, 15 | NOTE_DENSITY_KEY, MEAN_PITCH_KEY, MEAN_VELOCITY_KEY, MEAN_DURATION_KEY 16 | ) 17 | 18 | 19 | CACHE_PATH = os.getenv('CACHE_PATH', os.getenv('SCRATCH', os.getenv('TMPDIR', './temp'))) 20 | LATENT_CACHE_PATH = os.getenv('LATENT_CACHE_PATH', os.path.join(os.getenv('SCRATCH', os.getenv('TMPDIR', './temp')), 'latent')) 21 | 22 | class MidiDataModule(pl.LightningDataModule): 23 | def __init__(self, 24 | files, 25 | max_len, 26 | batch_size=32, 27 | num_workers=4, 28 | pin_memory=True, 29 | description_flavor='none', 30 | train_val_test_split=(0.95, 0.1, 0.05), 31 | vae_module=None, 32 | **kwargs): 33 | super().__init__() 34 | self.batch_size = batch_size 35 | self.pin_memory = pin_memory 36 | self.num_workers = num_workers 37 | self.files = files 38 | self.train_val_test_split = train_val_test_split 39 | self.vae_module = vae_module 40 | self.max_len = max_len 41 | self.description_flavor = description_flavor 42 | 43 | if self.description_flavor in ['latent', 'both']: 44 | assert self.vae_module is not None, "Description flavor 'latent' requires 'vae_module' to be present, but found 'None'" 45 | 46 | self.vocab = RemiVocab() 47 | 48 | self.kwargs = kwargs 49 | 50 | def setup(self, stage=None): 51 | # n_train = int(self.train_val_test_split[0] * len(self.files)) 52 | n_valid = int(self.train_val_test_split[1] * len(self.files)) 53 | n_test = int(self.train_val_test_split[2] * len(self.files)) 54 | train_files = self.files[n_test+n_valid:] 55 | valid_files = self.files[n_test:n_test+n_valid] 56 | test_files = self.files[:n_test] 57 | 58 | self.train_ds = MidiDataset(train_files, self.max_len, 59 | description_flavor=self.description_flavor, 60 | vae_module=self.vae_module, 61 | **self.kwargs 62 | ) 63 | self.valid_ds = MidiDataset(valid_files, self.max_len, 64 | description_flavor=self.description_flavor, 65 | vae_module=self.vae_module, 66 | **self.kwargs 67 | ) 68 | self.test_ds = MidiDataset(test_files, self.max_len, 69 | description_flavor=self.description_flavor, 70 | vae_module=self.vae_module, 71 | **self.kwargs 72 | ) 73 | 74 | # Use a shuffled dataset only for training 75 | # self.train_ds = torch.utils.data.datapipes.iter.combinatorics.ShuffleIterDataPipe(self.train_ds, buffer_size=2048) 76 | self.train_ds = IterableWrapper(self.train_ds) 77 | self.train_ds.shuffle(buffer_size=2048) 78 | 79 | self.collator = SeqCollator(pad_token=self.vocab.to_i(PAD_TOKEN), context_size=self.max_len) 80 | 81 | def train_dataloader(self): 82 | return DataLoader( 83 | self.train_ds, 84 | collate_fn=self.collator, 85 | batch_size=self.batch_size, 86 | pin_memory=self.pin_memory, 87 | num_workers=self.num_workers, 88 | shuffle=True, 89 | ) 90 | 91 | def val_dataloader(self): 92 | return 
DataLoader( 93 | self.valid_ds, 94 | collate_fn=self.collator, 95 | batch_size=self.batch_size, 96 | pin_memory=self.pin_memory, 97 | num_workers=self.num_workers, 98 | persistent_workers=True, 99 | ) 100 | 101 | def test_dataloader(self): 102 | return DataLoader( 103 | self.test_ds, 104 | collate_fn=self.collator, 105 | batch_size=self.batch_size, 106 | pin_memory=self.pin_memory, 107 | num_workers=self.num_workers, 108 | ) 109 | 110 | 111 | def _get_split(files, worker_info): 112 | if worker_info: 113 | n_workers = worker_info.num_workers 114 | worker_id = worker_info.id 115 | 116 | per_worker = math.ceil(len(files) / n_workers) 117 | start_idx = per_worker*worker_id 118 | end_idx = start_idx + per_worker 119 | 120 | split = files[start_idx:end_idx] 121 | else: 122 | split = files 123 | return split 124 | 125 | 126 | class SeqCollator: 127 | def __init__(self, pad_token=0, context_size=512): 128 | self.pad_token = pad_token 129 | self.context_size = context_size 130 | 131 | def __call__(self, features): 132 | batch = {} 133 | 134 | xs = [feature['input_ids'] for feature in features] 135 | xs = pad_sequence(xs, batch_first=True, padding_value=self.pad_token) 136 | 137 | if self.context_size > 0: 138 | max_len = self.context_size 139 | max_desc_len = self.context_size 140 | else: 141 | max_len = xs.size(1) 142 | max_desc_len = int(1e4) 143 | 144 | tmp = xs[:, :(max_len + 1)][:, :-1] 145 | labels = xs[:, :(max_len + 1)][:, 1:].clone().detach() 146 | xs = tmp 147 | 148 | seq_len = xs.size(1) 149 | 150 | batch['input_ids'] = xs 151 | batch['labels'] = labels 152 | 153 | if 'position_ids' in features[0]: 154 | position_ids = [feature['position_ids'] for feature in features] 155 | position_ids = pad_sequence(position_ids, batch_first=True, padding_value=0) 156 | batch['position_ids'] = position_ids[:, :seq_len] 157 | 158 | if 'bar_ids' in features[0]: 159 | bar_ids = [feature['bar_ids'] for feature in features] 160 | bar_ids = pad_sequence(bar_ids, batch_first=True, padding_value=0) 161 | batch['bar_ids'] = bar_ids[:, :seq_len] 162 | 163 | if 'latents' in features[0]: 164 | latents = [feature['latents'] for feature in features] 165 | latents = pad_sequence(latents, batch_first=True, padding_value=0.0) 166 | batch['latents'] = latents[:, :max_desc_len] 167 | 168 | if 'codes' in features[0]: 169 | codes = [feature['codes'] for feature in features] 170 | codes = pad_sequence(codes, batch_first=True, padding_value=0) 171 | batch['codes'] = codes[:, :max_desc_len] 172 | 173 | if 'description' in features[0]: 174 | description = [feature['description'] for feature in features] 175 | description = pad_sequence(description, batch_first=True, padding_value=self.pad_token) 176 | desc = description[:, :max_desc_len] 177 | batch['description'] = desc 178 | 179 | if 'desc_bar_ids' in features[0]: 180 | desc_len = desc.size(1) 181 | desc_bar_ids = [feature['desc_bar_ids'] for feature in features] 182 | desc_bar_ids = pad_sequence(desc_bar_ids, batch_first=True, padding_value=0) 183 | batch['desc_bar_ids'] = desc_bar_ids[:, :desc_len] 184 | 185 | if 'file' in features[0]: 186 | batch['files'] = [feature['file'] for feature in features] 187 | 188 | return batch 189 | 190 | class MidiDataset(IterableDataset): 191 | def __init__(self, 192 | midi_files, 193 | max_len, 194 | description_flavor='none', 195 | description_options=None, 196 | vae_module=None, 197 | group_bars=False, 198 | max_bars=512, 199 | max_positions=512, 200 | max_bars_per_context=-1, 201 | max_contexts_per_file=-1, 202 | bar_token_mask=None, 
203 | bar_token_idx=2, 204 | use_cache=True, 205 | print_errors=False): 206 | self.files = midi_files 207 | self.group_bars = group_bars 208 | self.max_len = max_len 209 | self.max_bars = max_bars 210 | self.max_positions = max_positions 211 | self.max_bars_per_context = max_bars_per_context 212 | self.max_contexts_per_file = max_contexts_per_file 213 | self.use_cache = use_cache 214 | self.print_errors = print_errors 215 | 216 | self.vocab = RemiVocab() 217 | 218 | self.description_flavor = description_flavor 219 | if self.description_flavor in ['latent', 'both']: 220 | assert vae_module is not None 221 | self.vae_module = vae_module.cpu() 222 | self.vae_module.eval() 223 | self.vae_module.freeze() 224 | self.description_options = description_options 225 | 226 | self.desc_vocab = DescriptionVocab() 227 | 228 | self.bar_token_mask = bar_token_mask 229 | self.bar_token_idx = bar_token_idx 230 | 231 | if CACHE_PATH: 232 | self.cache_path = os.path.join(CACHE_PATH, InputRepresentation.version()) 233 | os.makedirs(self.cache_path, exist_ok=True) 234 | # print(f"Using cache path: {self.cache_path}") 235 | else: 236 | self.cache_path = None 237 | 238 | if self.description_flavor in ['latent', 'both'] and LATENT_CACHE_PATH: 239 | self.latent_cache_path = LATENT_CACHE_PATH 240 | os.makedirs(self.latent_cache_path, exist_ok=True) 241 | # print(f"Using latent cache path: {self.latent_cache_path}") 242 | else: 243 | self.latent_cache_path = None 244 | 245 | 246 | def __iter__(self): 247 | worker_info = torch.utils.data.get_worker_info() 248 | self.split = _get_split(self.files, worker_info) 249 | 250 | split_len = len(self.split) 251 | 252 | for i in range(split_len): 253 | try: 254 | current_file = self.load_file(self.split[i]) 255 | except ValueError as err: 256 | if self.print_errors: 257 | print(err) 258 | # raise err 259 | continue 260 | 261 | events = current_file['events'] 262 | 263 | # Identify start of bars 264 | bars, bar_ids = self.get_bars(events, include_ids=True) 265 | if len(bars) > self.max_bars: 266 | if self.print_errors: 267 | print(f"WARNING: REMI sequence has more than {self.max_bars} bars: {len(bars)} event bars.") 268 | continue 269 | 270 | # Identify positions 271 | position_ids = self.get_positions(events) 272 | max_pos = position_ids.max() 273 | if max_pos > self.max_positions: 274 | if self.print_errors: 275 | print(f"WARNING: REMI sequence has more than {self.max_positions} positions: {max_pos.item()} positions found") 276 | continue 277 | 278 | # Mask bar tokens if required 279 | if self.bar_token_mask is not None and self.max_bars_per_context > 0: 280 | events = self.mask_bar_tokens(events, bar_token_mask=self.bar_token_mask) 281 | 282 | # Encode tokens with appropriate vocabulary 283 | event_ids = torch.tensor(self.vocab.encode(events), dtype=torch.long) 284 | 285 | bos, eos = self.get_bos_eos_events() 286 | zero = torch.tensor([0], dtype=torch.int) 287 | 288 | if self.max_bars_per_context and self.max_bars_per_context > 0: 289 | # Find all indices where a new context starts based on number of bars per context 290 | starts = [bars[i] for i in range(0, len(bars), self.max_bars_per_context)] 291 | # Convert starts to ranges 292 | contexts = list(zip(starts[:-1], starts[1:])) + [(starts[-1], len(event_ids))] 293 | # # Limit the size of the range if it's larger than the max. 
context size 294 | # contexts = [(max(start, end - (self.max_len+1)), end) for (start, end) in contexts] 295 | 296 | else: 297 | event_ids = torch.cat([bos, event_ids, eos]) 298 | bar_ids = torch.cat([zero, bar_ids, zero]) 299 | position_ids = torch.cat([zero, position_ids, zero]) 300 | 301 | if self.max_len > 0: 302 | starts = list(range(0, len(event_ids), self.max_len+1)) 303 | if len(starts) > 1: 304 | contexts = [(start, start + self.max_len+1) for start in starts[:-1]] + [(len(event_ids) - (self.max_len+1), len(event_ids))] 305 | elif len(starts) > 0: 306 | contexts = [(starts[0], self.max_len+1)] 307 | else: 308 | contexts = [(0, len(event_ids))] 309 | 310 | if self.max_contexts_per_file and self.max_contexts_per_file > 0: 311 | contexts = contexts[:self.max_contexts_per_file] 312 | 313 | for start, end in contexts: 314 | # Add and to each context if contexts are limited to a certain number of bars 315 | if self.max_bars_per_context and self.max_bars_per_context > 0: 316 | src = torch.cat([bos, event_ids[start:end], eos]) 317 | b_ids = torch.cat([zero, bar_ids[start:end], zero]) 318 | p_ids = torch.cat([zero, position_ids[start:end], zero]) 319 | else: 320 | src = event_ids[start:end] 321 | b_ids = bar_ids[start:end] 322 | p_ids = position_ids[start:end] 323 | 324 | if self.max_len > 0: 325 | src = src[:self.max_len + 1] 326 | 327 | x = { 328 | 'input_ids': src, 329 | 'file': os.path.basename(self.split[i]), 330 | 'bar_ids': b_ids, 331 | 'position_ids': p_ids, 332 | } 333 | 334 | if self.description_flavor in ['description', 'both']: 335 | # Assume that bar_ids are in ascending order (except for EOS) 336 | min_bar = b_ids[0] 337 | desc_events = current_file['description'] 338 | desc_bars = [i for i, event in enumerate(desc_events) if f"{BAR_KEY}_" in event] 339 | # subtract one since first bar has id == 1 340 | start_idx = desc_bars[max(0, min_bar - 1)] 341 | 342 | desc_bar_ids = torch.zeros(len(desc_events), dtype=torch.int) 343 | desc_bar_ids[desc_bars] = 1 344 | desc_bar_ids = torch.cumsum(desc_bar_ids, dim=0) 345 | 346 | if self.max_bars_per_context and self.max_bars_per_context > 0: 347 | end_idx = desc_bars[min_bar + self.max_bars_per_context] 348 | desc_events = desc_events[start_idx:end_idx] 349 | desc_bar_ids = desc_bar_ids[start_idx:end_idx] 350 | start_idx = 0 351 | 352 | desc_bos = torch.tensor(self.desc_vocab.encode([BOS_TOKEN]), dtype=torch.int) 353 | desc_eos = torch.tensor(self.desc_vocab.encode([EOS_TOKEN]), dtype=torch.int) 354 | desc_ids = torch.tensor(self.desc_vocab.encode(desc_events), dtype=torch.int) 355 | if min_bar == 0: 356 | desc_ids = torch.cat([desc_bos, desc_ids, desc_eos]) 357 | desc_bar_ids = torch.cat([zero, desc_bar_ids, zero]) 358 | else: 359 | desc_ids = torch.cat([desc_ids, desc_eos]) 360 | desc_bar_ids = torch.cat([desc_bar_ids, zero]) 361 | 362 | if self.max_len > 0: 363 | start, end = start_idx, start_idx + self.max_len + 1 364 | x['description'] = desc_ids[start:end] 365 | x['desc_bar_ids'] = desc_bar_ids[start:end] 366 | else: 367 | x['description'] = desc_ids[start:] 368 | x['desc_bar_ids'] = desc_bar_ids[start:] 369 | 370 | if self.description_flavor in ['latent', 'both']: 371 | x['latents'] = current_file['latents'] 372 | x['codes'] = current_file['codes'] 373 | 374 | yield x 375 | 376 | def get_bars(self, events, include_ids=False): 377 | bars = [i for i, event in enumerate(events) if f"{BAR_KEY}_" in event] 378 | 379 | if include_ids: 380 | bar_ids = torch.bincount(torch.tensor(bars, dtype=torch.int), minlength=len(events)) 381 | 
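# bincount turns the list of bar-start indices into a 0/1 indicator over all
# events; the cumulative sum below then yields, for every event, the index of
# the bar it belongs to.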
bar_ids = torch.cumsum(bar_ids, dim=0) 382 | 383 | return bars, bar_ids 384 | else: 385 | return bars 386 | 387 | def get_positions(self, events): 388 | events = [f"{POSITION_KEY}_0" if f"{BAR_KEY}_" in event else event for event in events] 389 | position_events = [event if f"{POSITION_KEY}_" in event else None for event in events] 390 | 391 | positions = [int(pos.split('_')[-1]) if pos is not None else None for pos in position_events] 392 | 393 | if positions[0] is None: 394 | positions[0] = 0 395 | for i in range(1, len(positions)): 396 | if positions[i] is None: 397 | positions[i] = positions[i-1] 398 | positions = torch.tensor(positions, dtype=torch.int) 399 | 400 | return positions 401 | 402 | def mask_bar_tokens(self, events, bar_token_mask=''): 403 | events = [bar_token_mask if f'{BAR_KEY}_' in token else token for token in events] 404 | return events 405 | 406 | def get_bos_eos_events(self, tuple_size=8): 407 | bos_event = torch.tensor(self.vocab.encode([BOS_TOKEN]), dtype=torch.long) 408 | eos_event = torch.tensor(self.vocab.encode([EOS_TOKEN]), dtype=torch.long) 409 | return bos_event, eos_event 410 | 411 | def preprocess_description(self, desc, instruments=True, chords=True, meta=True): 412 | valid_keys = { 413 | BAR_KEY: True, 414 | INSTRUMENT_KEY: instruments, 415 | CHORD_KEY: chords, 416 | TIME_SIGNATURE_KEY: meta, 417 | NOTE_DENSITY_KEY: meta, 418 | MEAN_PITCH_KEY: meta, 419 | MEAN_VELOCITY_KEY: meta, 420 | MEAN_DURATION_KEY: meta, 421 | } 422 | return [token for token in desc if len(token.split('_')) == 0 or valid_keys[token.split('_')[0]]] 423 | 424 | def load_file(self, file): 425 | name = os.path.basename(file) 426 | if self.cache_path and self.use_cache: 427 | cache_file = os.path.join(self.cache_path, name) 428 | 429 | try: 430 | # Try to load the file in case it's already in the cache 431 | sample = pickle.load(open(cache_file, 'rb')) 432 | except Exception: 433 | # If there's no cached version, compute the representations 434 | try: 435 | rep = InputRepresentation(file, strict=True) 436 | events = rep.get_remi_events() 437 | description = rep.get_description() 438 | except Exception as err: 439 | raise ValueError(f'Unable to load file {file}') from err 440 | 441 | sample = { 442 | 'events': events, 443 | 'description': description 444 | } 445 | 446 | if self.use_cache: 447 | # Try to store the computed representation in the cache directory 448 | try: 449 | pickle.dump(sample, open(cache_file, 'wb')) 450 | except Exception as err: 451 | print('Unable to cache file:', str(err)) 452 | 453 | if self.description_flavor in ['latent', 'both']: 454 | latents, codes = self.get_latent_representation(sample['events'], name) 455 | sample['latents'] = latents 456 | sample['codes'] = codes 457 | 458 | if self.description_options is not None and len(self.description_options) > 0: 459 | opts = self.description_options 460 | kwargs = { key: opts[key] for key in ['instruments', 'chords', 'meta'] if key in opts } 461 | sample['description'] = self.preprocess_description(sample['description'], **self.description_options) 462 | 463 | return sample 464 | 465 | def get_latent_representation(self, events, cache_key=None, bar_token_mask=''): 466 | if cache_key and self.use_cache: 467 | cache_file = os.path.join(self.latent_cache_path, cache_key) 468 | 469 | try: 470 | latents, codes = pickle.load(open(cache_file, 'rb')) 471 | except Exception: 472 | bars = self.get_bars(events) 473 | self.mask_bar_tokens(events, bar_token_mask=bar_token_mask) 474 | 475 | event_ids = 
torch.tensor(self.vocab.encode(events), dtype=torch.long) 476 | 477 | groups = [event_ids[start:end] for start, end in zip(bars[:-1], bars[1:])] 478 | groups.append(event_ids[bars[-1]:]) 479 | 480 | bos, eos = self.get_bos_eos_events() 481 | 482 | self.vae_module.eval() 483 | self.vae_module.freeze() 484 | 485 | latents = [] 486 | codes = [] 487 | for bar in groups: 488 | x = torch.cat([bos, bar, eos])[:self.vae_module.context_size].unsqueeze(0) 489 | out = self.vae_module.encode(x) 490 | z, code = out['z'], out['codes'] 491 | latents.append(z) 492 | codes.append(code) 493 | 494 | latents = torch.cat(latents) 495 | codes = torch.cat(codes) 496 | 497 | if self.use_cache: 498 | # Try to store the computed representation in the cache directory 499 | try: 500 | pickle.dump((latents.cpu(), codes.cpu()), open(cache_file, 'wb')) 501 | except Exception as err: 502 | print('Unable to cache file:', str(err)) 503 | 504 | return latents.cpu(), codes.cpu() 505 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, glob 3 | from statistics import NormalDist 4 | import pandas as pd 5 | import numpy as np 6 | 7 | import input_representation as ir 8 | 9 | METRICS = [ 10 | 'inst_prec', 'inst_rec', 'inst_f1', 11 | 'chord_prec', 'chord_rec', 'chord_f1', 12 | 'time_sig_acc', 13 | 'note_dens_oa', 'pitch_oa', 'velocity_oa', 'duration_oa', 14 | 'chroma_crossent', 'chroma_kldiv', 'chroma_sim', 15 | 'groove_crossent', 'groove_kldiv', 'groove_sim', 16 | ] 17 | 18 | DF_KEYS = ['id', 'original', 'sample'] + METRICS 19 | 20 | keys = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] 21 | qualities = ['maj', 'min', 'dim', 'aug', 'dom7', 'maj7', 'min7', 'None'] 22 | CHORDS = [f"{k}:{q}" for k in keys for q in qualities] + ['N:N'] 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--samples_dir', type=str, default="./samples") 28 | parser.add_argument('--output_file', type=str, default="./metrics.csv") 29 | parser.add_argument('--max_samples', type=int, default=1024) 30 | args = parser.parse_args() 31 | return args 32 | 33 | def get_group_id(file): 34 | # change this depending on name of generated samples 35 | name = os.path.basename(file) 36 | return name.split('.')[0] 37 | 38 | def get_file_groups(path, max_samples=1024): 39 | # change this depending on file structure of generated samples 40 | files = glob.glob(os.path.join(path, '*.mid'), recursive=True) 41 | assert len(files), f"provided directory was empty: {path}" 42 | 43 | samples = sorted(files) 44 | origs = sorted([os.path.join(path, 'ground_truth', os.path.basename(file)) for file in files]) 45 | pairs = list(zip(origs, samples)) 46 | 47 | pairs = list(filter(lambda pair: os.path.exists(pair[0]), pairs)) 48 | if max_samples > 0: 49 | pairs = pairs[:max_samples] 50 | 51 | groups = dict() 52 | for orig, sample in pairs: 53 | sample_id = get_group_id(sample) 54 | orig_id = get_group_id(orig) 55 | assert sample_id == orig_id, f"Sample id doesn't match original id: {sample} and {orig}" 56 | if sample_id not in groups: 57 | groups[sample_id] = list() 58 | groups[sample_id].append((orig, sample)) 59 | 60 | return list(groups.values()) 61 | 62 | def read_file(file): 63 | with open(file, 'r') as f: 64 | events = f.read().split('\n') 65 | events = [e for e in events if e] 66 | return events 67 | 68 | def get_chord_groups(desc): 69 | bars = [1 if 
'Bar_' in item else 0 for item in desc] 70 | bar_ids = np.cumsum(bars) - 1 71 | groups = [[] for _ in range(bar_ids[-1] + 1)] 72 | for i, item in enumerate(desc): 73 | if 'Chord_' in item: 74 | chord = item.split('_')[-1] 75 | groups[bar_ids[i]].append(chord) 76 | return groups 77 | 78 | def instruments(events): 79 | insts = [128 if item.instrument == 'drum' else int(item.instrument) for item in events[1:-1] if item.name == 'Note'] 80 | insts = np.bincount(insts, minlength=129) 81 | return (insts > 0).astype(int) 82 | 83 | def chords(events): 84 | chords = [CHORDS.index(item) for item in events] 85 | chords = np.bincount(chords, minlength=129) 86 | return (chords > 0).astype(int) 87 | 88 | def chroma(events): 89 | pitch_classes = [item.pitch % 12 for item in events[1:-1] if item.name == 'Note' and item.instrument != 'drum'] 90 | if len(pitch_classes): 91 | count = np.bincount(pitch_classes, minlength=12) 92 | count = count / np.sqrt(np.sum(count ** 2)) 93 | else: 94 | count = np.array([1/12] * 12) 95 | return count 96 | 97 | def groove(events, start=0, pos_per_bar=48, ticks_per_bar=1920): 98 | flags = np.linspace(start, start + ticks_per_bar, pos_per_bar, endpoint=False) 99 | onsets = [item.start for item in events[1:-1] if item.name == 'Note'] 100 | positions = [np.argmin(np.abs(flags - beat)) for beat in onsets] 101 | if len(positions): 102 | count = np.bincount(positions, minlength=pos_per_bar) 103 | count = np.convolve(count, [1, 4, 1], 'same') 104 | count = count / np.sqrt(np.sum(count ** 2)) 105 | else: 106 | count = np.array([1/pos_per_bar] * pos_per_bar) 107 | return count 108 | 109 | def multi_class_accuracy(y_true, y_pred): 110 | tp = ((y_true == 1) & (y_pred == 1)).sum() 111 | p = tp / y_pred.sum() 112 | r = tp / y_true.sum() 113 | if p + r > 0: 114 | f1 = 2*p*r / (p + r) 115 | else: 116 | f1 = 0 117 | return p, r, f1 118 | 119 | def cross_entropy(p_true, p_pred, eps=1e-8): 120 | return -np.sum(p_true * np.log(p_pred + eps)) / len(p_true) 121 | 122 | def kl_divergence(p_true, p_pred, eps=1e-8): 123 | return np.sum(p_true * (np.log(p_true + eps) - np.log(p_pred + eps))) / len(p_true) 124 | 125 | def cosine_sim(p_true, p_pred): 126 | return np.sum(p_true * p_pred) 127 | 128 | def sliding_window_metrics(items, start, end, window=1920, step=480, ticks_per_beat=480): 129 | glob_start, glob_end = start, end 130 | notes = [item for item in items if item.name == 'Note'] 131 | starts = np.arange(glob_start, glob_end - window, step=step) 132 | 133 | groups = [] 134 | start_idx, end_idx = 0, 0 135 | for start in starts: 136 | while notes[start_idx].start < start: 137 | start_idx += 1 138 | while end_idx < len(notes) and notes[end_idx].start < start + window: 139 | end_idx += 1 140 | 141 | groups.append([start] + notes[start_idx:end_idx] + [start + window]) 142 | return groups 143 | 144 | def meta_stats(group, ticks_per_beat=480): 145 | start, end = group[0], group[-1] 146 | ns = [item for item in group[1:-1] if item.name == 'Note'] 147 | ns_ = [note for note in ns if note.instrument != 'drum'] 148 | pitches = [note.pitch for note in ns_] 149 | vels = [note.velocity for note in ns_] 150 | durs = [(note.end - note.start) / ticks_per_beat for note in ns_] 151 | 152 | return { 153 | 'note_density': len(ns) / ((end - start) / ticks_per_beat), 154 | 'pitch_mean': np.mean(pitches) if len(pitches) else np.nan, 155 | 'velocity_mean': np.mean(vels) if len(vels) else np.nan, 156 | 'duration_mean': np.mean(durs) if len(durs) else np.nan, 157 | 'pitch_std': np.std(pitches) if len(pitches) else np.nan, 
158 | 'velocity_std': np.std(vels) if len(vels) else np.nan, 159 | 'duration_std': np.std(durs) if len(durs) else np.nan, 160 | } 161 | 162 | def overlapping_area(mu1, sigma1, mu2, sigma2, eps=0.01): 163 | sigma1, sigma2 = max(eps, sigma1), max(eps, sigma2) 164 | return NormalDist(mu=mu1, sigma=sigma1).overlap(NormalDist(mu=mu2, sigma=sigma2)) 165 | 166 | 167 | 168 | def main(): 169 | args = parse_args() 170 | file_groups = get_file_groups(args.samples_dir, max_samples=args.max_samples) 171 | 172 | metrics = pd.DataFrame() 173 | for sample_id, group in enumerate(file_groups): 174 | 175 | micro_metrics = pd.DataFrame() 176 | for orig_file, sample_file in group: 177 | print(f"[info] Group {sample_id+1}/{len(file_groups)} | original: {orig_file} | sample: {sample_file}") 178 | orig = ir.InputRepresentation(orig_file) 179 | sample = ir.InputRepresentation(sample_file) 180 | 181 | orig_desc, sample_desc = orig.get_description(), sample.get_description() 182 | if len(orig_desc) == 0 or len(sample_desc) == 0: 183 | print("[warning] empty sample! skipping") 184 | continue 185 | 186 | chord_groups1 = get_chord_groups(orig_desc) 187 | chord_groups2 = get_chord_groups(sample_desc) 188 | 189 | note_density_gt = [] 190 | 191 | for g1, g2, cg1, cg2 in zip(orig.groups, sample.groups, chord_groups1, chord_groups2): 192 | row = pd.DataFrame([{ 'id': sample_id, 'original': orig_file, 'sample': sample_file }]) 193 | 194 | meta1, meta2 = meta_stats(g1, ticks_per_beat=orig.pm.resolution), meta_stats(g2, ticks_per_beat=sample.pm.resolution) 195 | row['pitch_oa'] = overlapping_area(meta1['pitch_mean'], meta1['pitch_std'], meta2['pitch_mean'], meta2['pitch_std']) 196 | row['velocity_oa'] = overlapping_area(meta1['velocity_mean'], meta1['velocity_std'], meta2['velocity_mean'], meta2['velocity_std']) 197 | row['duration_oa'] = overlapping_area(meta1['duration_mean'], meta1['duration_std'], meta2['duration_mean'], meta2['duration_std']) 198 | row['note_density_abs_err'] = np.abs(meta1['note_density'] - meta2['note_density']) 199 | row['mean_pitch_abs_err'] = np.abs(meta1['pitch_mean'] - meta2['pitch_mean']) 200 | row['mean_velocity_abs_err'] = np.abs(meta1['velocity_mean'] - meta2['velocity_mean']) 201 | row['mean_duration_abs_err'] = np.abs(meta1['duration_mean'] - meta2['duration_mean']) 202 | note_density_gt.append(meta1['note_density']) 203 | 204 | ts1, ts2 = orig._get_time_signature(g1[0]), sample._get_time_signature(g2[0]) 205 | ts1, ts2 = f"{ts1.numerator}/{ts1.denominator}", f"{ts2.numerator}/{ts2.denominator}" 206 | row['time_sig_acc'] = 1 if ts1 == ts2 else 0 207 | 208 | inst1, inst2 = instruments(g1), instruments(g2) 209 | prec, rec, f1 = multi_class_accuracy(inst1, inst2) 210 | row['inst_prec'] = prec 211 | row['inst_rec'] = rec 212 | row['inst_f1'] = f1 213 | 214 | chords1, chords2 = chords(cg1), chords(cg2) 215 | prec, rec, f1 = multi_class_accuracy(chords1, chords2) 216 | row['chord_prec'] = prec 217 | row['chord_rec'] = rec 218 | row['chord_f1'] = f1 219 | 220 | c1, c2 = chroma(g1), chroma(g2) 221 | row['chroma_crossent'] = cross_entropy(c1, c2) 222 | row['chroma_kldiv'] = kl_divergence(c1, c2) 223 | row['chroma_sim'] = cosine_sim(c1, c2) 224 | 225 | ppb = max(orig._get_positions_per_bar(g1[0]), sample._get_positions_per_bar(g2[0])) 226 | tpb = max(orig._get_ticks_per_bar(g1[0]), sample._get_ticks_per_bar(g2[0])) 227 | r1 = groove(g1, start=g1[0], pos_per_bar=ppb, ticks_per_bar=tpb) 228 | r2 = groove(g2, start=g2[0], pos_per_bar=ppb, ticks_per_bar=tpb) 229 | row['groove_crossent'] = cross_entropy(r1, 
r2) 230 | row['groove_kldiv'] = kl_divergence(r1, r2) 231 | row['groove_sim'] = cosine_sim(r1, r2) 232 | 233 | micro_metrics = pd.concat([micro_metrics, row], ignore_index=True) 234 | if len(micro_metrics) == 0: 235 | continue 236 | 237 | nd_mean = np.mean(note_density_gt) 238 | micro_metrics['note_density_nsq_err'] = micro_metrics['note_density_abs_err']**2 / nd_mean**2 239 | 240 | metrics = pd.concat([metrics, micro_metrics], ignore_index=True) 241 | 242 | micro_avg = micro_metrics.mean(numeric_only=True) 243 | print("[info] Group {}: inst_f1={:.2f} | chord_f1={:.2f} | pitch_oa={:.2f} | vel_oa={:.2f} | dur_oa={:.2f} | chroma_sim={:.2f} | groove_sim={:.2f}".format( 244 | sample_id+1, micro_avg['inst_f1'], micro_avg['chord_f1'], micro_avg['pitch_oa'], micro_avg['velocity_oa'], micro_avg['duration_oa'], micro_avg['chroma_sim'], micro_avg['groove_sim'] 245 | )) 246 | 247 | os.makedirs(os.path.dirname(args.output_file), exist_ok=True) 248 | metrics.to_csv(args.output_file) 249 | 250 | summary_keys = ['inst_f1', 'chord_f1', 'time_sig_acc', 'pitch_oa', 'velocity_oa', 'duration_oa', 'chroma_sim', 'groove_sim'] 251 | summary = metrics[summary_keys + ['id']].groupby('id').mean().mean() 252 | 253 | nsq_err = metrics.groupby('id')['note_density_nsq_err'].mean() 254 | summary['note_density_nrmse'] = np.sqrt(nsq_err).mean() 255 | 256 | print('***** SUMMARY *****') 257 | print(summary) 258 | 259 | if __name__ == '__main__': 260 | main() 261 | -------------------------------------------------------------------------------- /src/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import glob 4 | import torch 5 | import random 6 | from torch.utils.data import DataLoader 7 | from transformers.models.bert.modeling_bert import BertAttention 8 | 9 | from models.vae import VqVaeModule 10 | from models.seq2seq import Seq2SeqModule 11 | from datasets import MidiDataset, SeqCollator 12 | from utils import medley_iterator 13 | from input_representation import remi2midi 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | # parser.add_argument('--model', type=str, required=True, help="Model name (one of 'figaro', 'figaro-expert', 'figaro-learned', 'figaro-no-inst', 'figaro-no-chord', 'figaro-no-meta')") 19 | # parser.add_argument('--checkpoint', type=str, required=True, help="Path to the model checkpoint") 20 | parser.add_argument('--model', type=str, default="figaro-expert") 21 | parser.add_argument('--checkpoint', type=str, default="../figaro-expert.ckpt") 22 | parser.add_argument('--vae_checkpoint', type=str, default=None, help="Path to the VQ-VAE model checkpoint (optional)") 23 | parser.add_argument('--lmd_dir', type=str, default='./lmd_full', help="Path to the root directory of the LakhMIDI dataset") 24 | parser.add_argument('--output_dir', type=str, default='./samples', help="Path to the output directory") 25 | parser.add_argument('--max_n_files', type=int, default=-1) 26 | parser.add_argument('--max_iter', type=int, default=16_000) 27 | parser.add_argument('--max_bars', type=int, default=32) 28 | parser.add_argument('--make_medleys', type=bool, default=False) 29 | parser.add_argument('--n_medley_pieces', type=int, default=2) 30 | parser.add_argument('--n_medley_bars', type=int, default=16) 31 | parser.add_argument('--batch_size', type=int, default=1) 32 | parser.add_argument('--verbose', type=int, default=2) 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def 
load_old_or_new_checkpoint(model_class, checkpoint): 38 | # assuming transformers>=4.36.0 39 | pl_ckpt = torch.load(checkpoint, map_location="cpu") 40 | kwargs = pl_ckpt['hyper_parameters'] 41 | if 'flavor' in kwargs: 42 | del kwargs['flavor'] 43 | if 'vae_run' in kwargs: 44 | del kwargs['vae_run'] 45 | model = model_class(**kwargs) 46 | state_dict = pl_ckpt['state_dict'] 47 | # position_ids are no longer saved in the state_dict starting with transformers==4.31.0 48 | state_dict = {k: v for k, v in state_dict.items() if not k.endswith('embeddings.position_ids')} 49 | try: 50 | # succeeds for checkpoints trained with transformers>4.13.0 51 | model.load_state_dict(state_dict) 52 | except RuntimeError: 53 | # work around a breaking change introduced in transformers==4.13.0, which fixed the position_embedding_type of cross-attention modules to "absolute" 54 | config = model.transformer.decoder.bert.config 55 | for layer in model.transformer.decoder.bert.encoder.layer: 56 | layer.crossattention = BertAttention(config, position_embedding_type=config.position_embedding_type) 57 | model.load_state_dict(state_dict) 58 | model.freeze() 59 | model.eval() 60 | return model 61 | 62 | 63 | def load_model(checkpoint, vae_checkpoint=None, device='auto'): 64 | if device == 'auto': 65 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 66 | 67 | vae_module = None 68 | if vae_checkpoint: 69 | vae_module = load_old_or_new_checkpoint(VqVaeModule, vae_checkpoint) 70 | vae_module.cpu() 71 | 72 | model = load_old_or_new_checkpoint(Seq2SeqModule, checkpoint) 73 | model.to(device) 74 | 75 | return model, vae_module 76 | 77 | 78 | @torch.no_grad() 79 | def reconstruct_sample(model, batch, 80 | initial_context=1, 81 | output_dir=None, 82 | max_iter=-1, 83 | max_bars=-1, 84 | verbose=0, 85 | ): 86 | batch_size, seq_len = batch['input_ids'].shape[:2] 87 | 88 | batch_ = { key: batch[key][:, :initial_context] for key in ['input_ids', 'bar_ids', 'position_ids'] } 89 | if model.description_flavor in ['description', 'both']: 90 | batch_['description'] = batch['description'] 91 | batch_['desc_bar_ids'] = batch['desc_bar_ids'] 92 | if model.description_flavor in ['latent', 'both']: 93 | batch_['latents'] = batch['latents'] 94 | 95 | max_len = seq_len + 1024 96 | if max_iter > 0: 97 | max_len = min(max_len, initial_context + max_iter) 98 | if verbose: 99 | print(f"Generating sequence ({initial_context} initial / {max_len} max length / {max_bars} max bars / {batch_size} batch size)") 100 | sample = model.sample(batch_, max_length=max_len, max_bars=max_bars, verbose=verbose//2) 101 | 102 | xs = batch['input_ids'].detach().cpu() 103 | xs_hat = sample['sequences'].detach().cpu() 104 | events = [model.vocab.decode(x) for x in xs] 105 | events_hat = [model.vocab.decode(x) for x in xs_hat] 106 | 107 | pms, pms_hat = [], [] 108 | n_fatal = 0 109 | for rec, rec_hat in zip(events, events_hat): 110 | try: 111 | pm = remi2midi(rec) 112 | pms.append(pm) 113 | except Exception as err: 114 | print("ERROR: Could not convert events to midi:", err) 115 | try: 116 | pm_hat = remi2midi(rec_hat) 117 | pms_hat.append(pm_hat) 118 | except Exception as err: 119 | print("ERROR: Could not convert events to midi:", err) 120 | n_fatal += 1 121 | 122 | if output_dir: 123 | os.makedirs(os.path.join(output_dir, 'ground_truth'), exist_ok=True) 124 | for pm, pm_hat, file in zip(pms, pms_hat, batch['files']): 125 | if verbose: 126 | print(f"Saving to {output_dir}/{file}") 127 | pm.write(os.path.join(output_dir, 'ground_truth', file)) 
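# The ground-truth MIDI goes into the ground_truth/ subfolder while the generated
# sample is written next to it under the same file name; evaluate.py pairs the
# two files by that shared basename.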
pm_hat.write(os.path.join(output_dir, file)) 129 | 130 | return events 131 | 132 | 133 | def main(): 134 | args = parse_args() 135 | if args.make_medleys: 136 | max_bars = args.n_medley_pieces * args.n_medley_bars 137 | else: 138 | max_bars = args.max_bars 139 | 140 | if args.output_dir: 141 | params = [] 142 | if args.make_medleys: 143 | params.append(f"n_pieces={args.n_medley_pieces}") 144 | params.append(f"n_bars={args.n_medley_bars}") 145 | if args.max_iter > 0: 146 | params.append(f"max_iter={args.max_iter}") 147 | if args.max_bars > 0: 148 | params.append(f"max_bars={args.max_bars}") 149 | output_dir = os.path.join(args.output_dir, args.model, ','.join(params)) 150 | else: 151 | raise ValueError("args.output_dir must be specified.") 152 | 153 | print(f"Saving generated files to: {output_dir}") 154 | 155 | model, vae_module = load_model(args.checkpoint, args.vae_checkpoint) 156 | 157 | 158 | midi_files = glob.glob(os.path.join(args.lmd_dir, '**/*.mid'), recursive=True) 159 | 160 | dm = model.get_datamodule(midi_files, vae_module=vae_module) 161 | dm.setup('test') 162 | midi_files = dm.test_ds.files 163 | random.shuffle(midi_files) 164 | 165 | if args.max_n_files > 0: 166 | midi_files = midi_files[:args.max_n_files] 167 | 168 | 169 | description_options = None 170 | if args.model in ['figaro-no-inst', 'figaro-no-chord', 'figaro-no-meta']: 171 | description_options = model.description_options 172 | 173 | dataset = MidiDataset( 174 | midi_files, 175 | max_len=-1, 176 | description_flavor=model.description_flavor, 177 | description_options=description_options, 178 | max_bars=model.context_size, 179 | vae_module=vae_module 180 | ) 181 | 182 | coll = SeqCollator(context_size=-1) 183 | dl = DataLoader(dataset, batch_size=args.batch_size, collate_fn=coll) 184 | 185 | if args.make_medleys: 186 | dl = medley_iterator(dl, 187 | n_pieces=args.n_medley_pieces, 188 | n_bars=args.n_medley_bars, 189 | description_flavor=model.description_flavor 190 | ) 191 | 192 | with torch.no_grad(): 193 | for batch in dl: 194 | reconstruct_sample(model, batch, 195 | output_dir=output_dir, 196 | max_iter=args.max_iter, 197 | max_bars=max_bars, 198 | verbose=args.verbose, 199 | ) 200 | 201 | if __name__ == '__main__': 202 | main() 203 | -------------------------------------------------------------------------------- /src/input_representation.py: -------------------------------------------------------------------------------- 1 | from chord_recognition import MIDIChord 2 | import numpy as np 3 | import pretty_midi 4 | 5 | from vocab import RemiVocab 6 | 7 | from constants import ( 8 | EOS_TOKEN, 9 | # vocab keys 10 | TIME_SIGNATURE_KEY, 11 | BAR_KEY, 12 | POSITION_KEY, 13 | INSTRUMENT_KEY, 14 | PITCH_KEY, 15 | VELOCITY_KEY, 16 | DURATION_KEY, 17 | TEMPO_KEY, 18 | CHORD_KEY, 19 | NOTE_DENSITY_KEY, 20 | MEAN_PITCH_KEY, 21 | MEAN_VELOCITY_KEY, 22 | MEAN_DURATION_KEY, 23 | # discretization parameters 24 | DEFAULT_POS_PER_QUARTER, 25 | DEFAULT_VELOCITY_BINS, 26 | DEFAULT_DURATION_BINS, 27 | DEFAULT_TEMPO_BINS, 28 | DEFAULT_NOTE_DENSITY_BINS, 29 | DEFAULT_MEAN_VELOCITY_BINS, 30 | DEFAULT_MEAN_PITCH_BINS, 31 | DEFAULT_MEAN_DURATION_BINS, 32 | DEFAULT_RESOLUTION 33 | ) 34 | 35 | # define "Item" for general storage 36 | class Item(object): 37 | def __init__(self, name, start, end, velocity=None, pitch=None, instrument=None): 38 | self.name = name 39 | self.start = start 40 | self.end = end 41 | self.velocity = velocity 42 | self.pitch = pitch 43 | self.instrument = instrument 44 | 45 | def __repr__(self): 46 | return 
'Item(name={}, start={}, end={}, velocity={}, pitch={}, instrument={})'.format( 47 | self.name, self.start, self.end, self.velocity, self.pitch, self.instrument) 48 | 49 | # define "Event" for event storage 50 | class Event(object): 51 | def __init__(self, name, time, value, text): 52 | self.name = name 53 | self.time = time 54 | self.value = value 55 | self.text = text 56 | 57 | def __repr__(self): 58 | return 'Event(name={}, time={}, value={}, text={})'.format( 59 | self.name, self.time, self.value, self.text) 60 | 61 | class InputRepresentation(): 62 | def version(): 63 | return 'v4' 64 | 65 | def __init__(self, file, do_extract_chords=True, strict=False): 66 | if isinstance(file, pretty_midi.PrettyMIDI): 67 | self.pm = file 68 | else: 69 | self.pm = pretty_midi.PrettyMIDI(file) 70 | 71 | if strict and len(self.pm.time_signature_changes) == 0: 72 | raise ValueError("Invalid MIDI file: No time signature defined") 73 | 74 | self.resolution = self.pm.resolution 75 | 76 | self.note_items = None 77 | self.tempo_items = None 78 | self.chords = None 79 | self.groups = None 80 | 81 | self._read_items() 82 | self._quantize_items() 83 | if do_extract_chords: 84 | self.extract_chords() 85 | self._group_items() 86 | 87 | if strict and len(self.note_items) == 0: 88 | raise ValueError("Invalid MIDI file: No notes found, empty file.") 89 | 90 | # read notes and tempo changes from midi (assume there is only one track) 91 | def _read_items(self): 92 | # note 93 | self.note_items = [] 94 | for instrument in self.pm.instruments: 95 | pedal_events = [event for event in instrument.control_changes if event.number == 64] 96 | pedal_pressed = False 97 | start = None 98 | pedals = [] 99 | for e in pedal_events: 100 | if e.value >= 64 and not pedal_pressed: 101 | pedal_pressed = True 102 | start = e.time 103 | elif e.value < 64 and pedal_pressed: 104 | pedal_pressed = False 105 | pedals.append(Item(name='Pedal', start=start, end=e.time)) 106 | start = e.time 107 | 108 | notes = instrument.notes 109 | notes.sort(key=lambda x: (x.start, x.pitch)) 110 | 111 | if instrument.is_drum: 112 | instrument_name = 'drum' 113 | else: 114 | instrument_name = instrument.program 115 | 116 | pedal_idx = 0 117 | for note in notes: 118 | pedal_candidates = [(i + pedal_idx, pedal) for i, pedal in enumerate(pedals[pedal_idx:]) if note.end >= pedal.start and note.start < pedal.end] 119 | if len(pedal_candidates) > 0: 120 | pedal_idx = pedal_candidates[0][0] 121 | pedal = pedal_candidates[-1][1] 122 | else: 123 | pedal = Item(name='Pedal', start=0, end=0) 124 | 125 | self.note_items.append(Item( 126 | name='Note', 127 | start=self.pm.time_to_tick(note.start), 128 | end=self.pm.time_to_tick(max(note.end, pedal.end)), 129 | velocity=note.velocity, 130 | pitch=note.pitch, 131 | instrument=instrument_name)) 132 | self.note_items.sort(key=lambda x: (x.start, x.pitch)) 133 | # tempo 134 | self.tempo_items = [] 135 | times, tempi = self.pm.get_tempo_changes() 136 | for time, tempo in zip(times, tempi): 137 | self.tempo_items.append(Item( 138 | name='Tempo', 139 | start=time, 140 | end=None, 141 | velocity=None, 142 | pitch=int(tempo))) 143 | self.tempo_items.sort(key=lambda x: x.start) 144 | # expand to all beat 145 | max_tick = self.pm.time_to_tick(self.pm.get_end_time()) 146 | existing_ticks = {item.start: item.pitch for item in self.tempo_items} 147 | wanted_ticks = np.arange(0, max_tick+1, DEFAULT_RESOLUTION) 148 | output = [] 149 | for tick in wanted_ticks: 150 | if tick in existing_ticks: 151 | output.append(Item( 152 | name='Tempo', 
153 | start=self.pm.time_to_tick(tick), 154 | end=None, 155 | velocity=None, 156 | pitch=existing_ticks[tick])) 157 | else: 158 | output.append(Item( 159 | name='Tempo', 160 | start=self.pm.time_to_tick(tick), 161 | end=None, 162 | velocity=None, 163 | pitch=output[-1].pitch)) 164 | self.tempo_items = output 165 | 166 | # quantize items 167 | def _quantize_items(self): 168 | ticks = self.resolution / DEFAULT_POS_PER_QUARTER 169 | # grid 170 | end_tick = self.pm.time_to_tick(self.pm.get_end_time()) 171 | grids = np.arange(0, max(self.resolution, end_tick), ticks) 172 | # process 173 | for item in self.note_items: 174 | index = np.searchsorted(grids, item.start, side='right') 175 | if index > 0: 176 | index -= 1 177 | shift = round(grids[index]) - item.start 178 | item.start += shift 179 | item.end += shift 180 | 181 | def get_end_tick(self): 182 | return self.pm.time_to_tick(self.pm.get_end_time()) 183 | 184 | # extract chord 185 | def extract_chords(self): 186 | end_tick = self.pm.time_to_tick(self.pm.get_end_time()) 187 | if end_tick < self.resolution: 188 | # If sequence is shorter than 1/4th note, it's probably empty 189 | self.chords = [] 190 | return self.chords 191 | method = MIDIChord(self.pm) 192 | chords = method.extract() 193 | output = [] 194 | for chord in chords: 195 | output.append(Item( 196 | name='Chord', 197 | start=self.pm.time_to_tick(chord[0]), 198 | end=self.pm.time_to_tick(chord[1]), 199 | velocity=None, 200 | pitch=chord[2].split('/')[0])) 201 | if len(output) == 0 or output[0].start > 0: 202 | if len(output) == 0: 203 | end = self.pm.time_to_tick(self.pm.get_end_time()) 204 | else: 205 | end = output[0].start 206 | output.append(Item( 207 | name='Chord', 208 | start=0, 209 | end=end, 210 | velocity=None, 211 | pitch='N:N' 212 | )) 213 | self.chords = output 214 | return self.chords 215 | 216 | # group items 217 | def _group_items(self): 218 | if self.chords: 219 | items = self.chords + self.tempo_items + self.note_items 220 | else: 221 | items = self.tempo_items + self.note_items 222 | 223 | def _get_key(item): 224 | type_priority = { 225 | 'Chord': 0, 226 | 'Tempo': 1, 227 | 'Note': 2 228 | } 229 | return ( 230 | item.start, # order by time 231 | type_priority[item.name], # chord events first, then tempo events, then note events 232 | -1 if item.instrument == 'drum' else item.instrument, # order by instrument 233 | item.pitch # order by note pitch 234 | ) 235 | 236 | items.sort(key=_get_key) 237 | downbeats = self.pm.get_downbeats() 238 | downbeats = np.concatenate([downbeats, [self.pm.get_end_time()]]) 239 | self.groups = [] 240 | for db1, db2 in zip(downbeats[:-1], downbeats[1:]): 241 | db1, db2 = self.pm.time_to_tick(db1), self.pm.time_to_tick(db2) 242 | insiders = [] 243 | for item in items: 244 | if (item.start >= db1) and (item.start < db2): 245 | insiders.append(item) 246 | overall = [db1] + insiders + [db2] 247 | self.groups.append(overall) 248 | 249 | # Trim empty groups from the beginning and end 250 | for idx in [0, -1]: 251 | while len(self.groups) > 0: 252 | group = self.groups[idx] 253 | notes = [item for item in group[1:-1] if item.name == 'Note'] 254 | if len(notes) == 0: 255 | self.groups.pop(idx) 256 | else: 257 | break 258 | 259 | return self.groups 260 | 261 | def _get_time_signature(self, start): 262 | # This method assumes that time signature changes don't happen within a bar 263 | # which is a convention that commonly holds 264 | time_sig = None 265 | for curr_sig, next_sig in zip(self.pm.time_signature_changes[:-1], 
self.pm.time_signature_changes[1:]): 266 | if self.pm.time_to_tick(curr_sig.time) <= start and self.pm.time_to_tick(next_sig.time) > start: 267 | time_sig = curr_sig 268 | break 269 | if time_sig is None: 270 | time_sig = self.pm.time_signature_changes[-1] 271 | return time_sig 272 | 273 | def _get_ticks_per_bar(self, start): 274 | time_sig = self._get_time_signature(start) 275 | quarters_per_bar = 4 * time_sig.numerator / time_sig.denominator 276 | return self.pm.resolution * quarters_per_bar 277 | 278 | def _get_positions_per_bar(self, start=None, time_sig=None): 279 | if time_sig is None: 280 | time_sig = self._get_time_signature(start) 281 | quarters_per_bar = 4 * time_sig.numerator / time_sig.denominator 282 | positions_per_bar = int(DEFAULT_POS_PER_QUARTER * quarters_per_bar) 283 | return positions_per_bar 284 | 285 | def tick_to_position(self, tick): 286 | return round(tick / self.pm.resolution * DEFAULT_POS_PER_QUARTER) 287 | 288 | # item to event 289 | def get_remi_events(self): 290 | events = [] 291 | n_downbeat = 0 292 | current_chord = None 293 | current_tempo = None 294 | for i in range(len(self.groups)): 295 | bar_st, bar_et = self.groups[i][0], self.groups[i][-1] 296 | n_downbeat += 1 297 | positions_per_bar = self._get_positions_per_bar(bar_st) 298 | if positions_per_bar <= 0: 299 | raise ValueError('Invalid REMI file: There must be at least 1 position per bar.') 300 | 301 | events.append(Event( 302 | name=BAR_KEY, 303 | time=None, 304 | value='{}'.format(n_downbeat), 305 | text='{}'.format(n_downbeat))) 306 | 307 | time_sig = self._get_time_signature(bar_st) 308 | events.append(Event( 309 | name=TIME_SIGNATURE_KEY, 310 | time=None, 311 | value='{}/{}'.format(time_sig.numerator, time_sig.denominator), 312 | text='{}/{}'.format(time_sig.numerator, time_sig.denominator) 313 | )) 314 | 315 | if current_chord is not None: 316 | events.append(Event( 317 | name=POSITION_KEY, 318 | time=0, 319 | value='{}'.format(0), 320 | text='{}/{}'.format(1, positions_per_bar))) 321 | events.append(Event( 322 | name=CHORD_KEY, 323 | time=current_chord.start, 324 | value=current_chord.pitch, 325 | text='{}'.format(current_chord.pitch))) 326 | 327 | if current_tempo is not None: 328 | events.append(Event( 329 | name=POSITION_KEY, 330 | time=0, 331 | value='{}'.format(0), 332 | text='{}/{}'.format(1, positions_per_bar))) 333 | tempo = current_tempo.pitch 334 | index = np.argmin(abs(DEFAULT_TEMPO_BINS-tempo)) 335 | events.append(Event( 336 | name=TEMPO_KEY, 337 | time=current_tempo.start, 338 | value=index, 339 | text='{}/{}'.format(tempo, DEFAULT_TEMPO_BINS[index]))) 340 | 341 | quarters_per_bar = 4 * time_sig.numerator / time_sig.denominator 342 | ticks_per_bar = self.pm.resolution * quarters_per_bar 343 | flags = np.linspace(bar_st, bar_st + ticks_per_bar, positions_per_bar, endpoint=False) 344 | for item in self.groups[i][1:-1]: 345 | # position 346 | index = np.argmin(abs(flags-item.start)) 347 | pos_event = Event( 348 | name=POSITION_KEY, 349 | time=item.start, 350 | value='{}'.format(index), 351 | text='{}/{}'.format(index+1, positions_per_bar)) 352 | 353 | if item.name == 'Note': 354 | events.append(pos_event) 355 | # instrument 356 | if item.instrument == 'drum': 357 | name = 'drum' 358 | else: 359 | name = pretty_midi.program_to_instrument_name(item.instrument) 360 | events.append(Event( 361 | name=INSTRUMENT_KEY, 362 | time=item.start, 363 | value=name, 364 | text='{}'.format(name))) 365 | # pitch 366 | events.append(Event( 367 | name=PITCH_KEY, 368 | time=item.start, 369 | 
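# Drum hits are encoded with a dedicated 'drum_<pitch>' value below so that they
# don't share vocabulary entries with melodic pitches.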
value='drum_{}'.format(item.pitch) if name == 'drum' else item.pitch, 370 | text='{}'.format(pretty_midi.note_number_to_name(item.pitch)))) 371 | # velocity 372 | velocity_index = np.argmin(abs(DEFAULT_VELOCITY_BINS - item.velocity)) 373 | events.append(Event( 374 | name=VELOCITY_KEY, 375 | time=item.start, 376 | value=velocity_index, 377 | text='{}/{}'.format(item.velocity, DEFAULT_VELOCITY_BINS[velocity_index]))) 378 | # duration 379 | duration = self.tick_to_position(item.end - item.start) 380 | index = np.argmin(abs(DEFAULT_DURATION_BINS-duration)) 381 | events.append(Event( 382 | name=DURATION_KEY, 383 | time=item.start, 384 | value=index, 385 | text='{}/{}'.format(duration, DEFAULT_DURATION_BINS[index]))) 386 | elif item.name == 'Chord': 387 | if current_chord is None or item.pitch != current_chord.pitch: 388 | events.append(pos_event) 389 | events.append(Event( 390 | name=CHORD_KEY, 391 | time=item.start, 392 | value=item.pitch, 393 | text='{}'.format(item.pitch))) 394 | current_chord = item 395 | elif item.name == 'Tempo': 396 | if current_tempo is None or item.pitch != current_tempo.pitch: 397 | events.append(pos_event) 398 | tempo = item.pitch 399 | index = np.argmin(abs(DEFAULT_TEMPO_BINS-tempo)) 400 | events.append(Event( 401 | name=TEMPO_KEY, 402 | time=item.start, 403 | value=index, 404 | text='{}/{}'.format(tempo, DEFAULT_TEMPO_BINS[index]))) 405 | current_tempo = item 406 | 407 | return [f'{e.name}_{e.value}' for e in events] 408 | 409 | def get_description(self, 410 | omit_time_sig=False, 411 | omit_instruments=False, 412 | omit_chords=False, 413 | omit_meta=False): 414 | events = [] 415 | n_downbeat = 0 416 | current_chord = None 417 | 418 | for i in range(len(self.groups)): 419 | bar_st, bar_et = self.groups[i][0], self.groups[i][-1] 420 | n_downbeat += 1 421 | time_sig = self._get_time_signature(bar_st) 422 | positions_per_bar = self._get_positions_per_bar(time_sig=time_sig) 423 | if positions_per_bar <= 0: 424 | raise ValueError('Invalid REMI file: There must be at least 1 position in each bar.') 425 | 426 | events.append(Event( 427 | name=BAR_KEY, 428 | time=None, 429 | value='{}'.format(n_downbeat), 430 | text='{}'.format(n_downbeat))) 431 | 432 | if not omit_time_sig: 433 | events.append(Event( 434 | name=TIME_SIGNATURE_KEY, 435 | time=None, 436 | value='{}/{}'.format(time_sig.numerator, time_sig.denominator), 437 | text='{}/{}'.format(time_sig.numerator, time_sig.denominator), 438 | )) 439 | 440 | if not omit_meta: 441 | notes = [item for item in self.groups[i][1:-1] if item.name == 'Note'] 442 | n_notes = len(notes) 443 | velocities = np.array([item.velocity for item in notes]) 444 | pitches = np.array([item.pitch for item in notes]) 445 | durations = np.array([item.end - item.start for item in notes]) 446 | 447 | note_density = n_notes/positions_per_bar 448 | index = np.argmin(abs(DEFAULT_NOTE_DENSITY_BINS-note_density)) 449 | events.append(Event( 450 | name=NOTE_DENSITY_KEY, 451 | time=None, 452 | value=index, 453 | text='{:.2f}/{:.2f}'.format(note_density, DEFAULT_NOTE_DENSITY_BINS[index]) 454 | )) 455 | 456 | # will be 0 if there's no notes 457 | mean_velocity = velocities.mean() if len(velocities) > 0 else np.nan 458 | index = np.argmin(abs(DEFAULT_MEAN_VELOCITY_BINS-mean_velocity)) 459 | events.append(Event( 460 | name=MEAN_VELOCITY_KEY, 461 | time=None, 462 | value=index if mean_velocity != np.nan else 'NaN', 463 | text='{:.2f}/{:.2f}'.format(mean_velocity, DEFAULT_MEAN_VELOCITY_BINS[index]) 464 | )) 465 | 466 | # will be 0 if there's no notes 467 | 
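# NOTE: the `!= np.nan` checks in this block are always True (NaN compares unequal
# to everything, including itself), so the 'NaN' fallback value is never emitted;
# `np.isnan(...)` would be needed to actually detect bars without notes.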
mean_pitch = pitches.mean() if len(pitches) > 0 else np.nan 468 | index = np.argmin(abs(DEFAULT_MEAN_PITCH_BINS-mean_pitch)) 469 | events.append(Event( 470 | name=MEAN_PITCH_KEY, 471 | time=None, 472 | value=index if mean_pitch != np.nan else 'NaN', 473 | text='{:.2f}/{:.2f}'.format(mean_pitch, DEFAULT_MEAN_PITCH_BINS[index]) 474 | )) 475 | 476 | # will be 1 if there's no notes 477 | mean_duration = durations.mean() if len(durations) > 0 else np.nan 478 | index = np.argmin(abs(DEFAULT_MEAN_DURATION_BINS-mean_duration)) 479 | events.append(Event( 480 | name=MEAN_DURATION_KEY, 481 | time=None, 482 | value=index if mean_duration != np.nan else 'NaN', 483 | text='{:.2f}/{:.2f}'.format(mean_duration, DEFAULT_MEAN_DURATION_BINS[index]) 484 | )) 485 | 486 | if not omit_instruments: 487 | instruments = set([item.instrument for item in notes]) 488 | for instrument in instruments: 489 | instrument = pretty_midi.program_to_instrument_name(instrument) if instrument != 'drum' else 'drum' 490 | events.append(Event( 491 | name=INSTRUMENT_KEY, 492 | time=None, 493 | value=instrument, 494 | text=instrument 495 | )) 496 | 497 | if not omit_chords: 498 | chords = [item for item in self.groups[i][1:-1] if item.name == 'Chord'] 499 | if len(chords) == 0 and current_chord is not None: 500 | chords = [current_chord] 501 | elif len(chords) > 0: 502 | if chords[0].start > bar_st and current_chord is not None: 503 | chords.insert(0, current_chord) 504 | current_chord = chords[-1] 505 | 506 | for chord in chords: 507 | events.append(Event( 508 | name=CHORD_KEY, 509 | time=None, 510 | value=chord.pitch, 511 | text='{}'.format(chord.pitch) 512 | )) 513 | 514 | return [f'{e.name}_{e.value}' for e in events] 515 | 516 | 517 | ############################################################################################# 518 | # WRITE MIDI 519 | ############################################################################################# 520 | 521 | def remi2midi(events, bpm=120, time_signature=(4, 4), polyphony_limit=16): 522 | vocab = RemiVocab() 523 | 524 | def _get_time(bar, position, bpm=120, positions_per_bar=48): 525 | abs_position = bar*positions_per_bar + position 526 | beat = abs_position / DEFAULT_POS_PER_QUARTER 527 | return beat/bpm*60 528 | 529 | def _get_time(reference, bar, pos): 530 | time_sig = reference['time_sig'] 531 | num, denom = time_sig.numerator, time_sig.denominator 532 | # Quarters per bar, assuming 4 quarters per whole note 533 | qpb = 4 * num / denom 534 | ref_pos = reference['pos'] 535 | d_bars = bar - ref_pos[0] 536 | d_pos = (pos - ref_pos[1]) + d_bars*qpb*DEFAULT_POS_PER_QUARTER 537 | d_quarters = d_pos / DEFAULT_POS_PER_QUARTER 538 | # Convert quarters to seconds 539 | dt = d_quarters / reference['tempo'] * 60 540 | return reference['time'] + dt 541 | 542 | # time_sigs = [event.split('_')[-1].split('/') for event in events if f"{TIME_SIGNATURE_KEY}_" in event] 543 | # time_sigs = [(int(num), int(denom)) for num, denom in time_sigs] 544 | 545 | tempo_changes = [event for event in events if f"{TEMPO_KEY}_" in event] 546 | if len(tempo_changes) > 0: 547 | bpm = DEFAULT_TEMPO_BINS[int(tempo_changes[0].split('_')[-1])] 548 | 549 | pm = pretty_midi.PrettyMIDI(initial_tempo=bpm) 550 | num, denom = time_signature 551 | pm.time_signature_changes.append(pretty_midi.TimeSignature(num, denom, 0)) 552 | current_time_sig = pm.time_signature_changes[0] 553 | 554 | instruments = {} 555 | 556 | # Use implicit timeline: keep track of last tempo/time signature change event 557 | # and calculate time 
difference relative to that 558 | last_tl_event = { 559 | 'time': 0, 560 | 'pos': (0, 0), 561 | 'time_sig': current_time_sig, 562 | 'tempo': bpm 563 | } 564 | 565 | bar = -1 566 | n_notes = 0 567 | polyphony_control = {} 568 | for i, event in enumerate(events): 569 | if event == EOS_TOKEN: 570 | break 571 | 572 | if not bar in polyphony_control: 573 | polyphony_control[bar] = {} 574 | 575 | if f"{BAR_KEY}_" in events[i]: 576 | # Next bar is starting 577 | bar += 1 578 | polyphony_control[bar] = {} 579 | 580 | if i+1 < len(events) and f"{TIME_SIGNATURE_KEY}_" in events[i+1]: 581 | num, denom = events[i+1].split('_')[-1].split('/') 582 | num, denom = int(num), int(denom) 583 | current_time_sig = last_tl_event['time_sig'] 584 | if num != current_time_sig.numerator or denom != current_time_sig.denominator: 585 | time = _get_time(last_tl_event, bar, 0) 586 | time_sig = pretty_midi.TimeSignature(num, denom, time) 587 | pm.time_signature_changes.append(time_sig) 588 | last_tl_event['time'] = time 589 | last_tl_event['pos'] = (bar, 0) 590 | last_tl_event['time_sig'] = time_sig 591 | 592 | elif i+1 < len(events) and \ 593 | f"{POSITION_KEY}_" in events[i] and \ 594 | f"{TEMPO_KEY}_" in events[i+1]: 595 | position = int(events[i].split('_')[-1]) 596 | tempo_idx = int(events[i+1].split('_')[-1]) 597 | tempo = DEFAULT_TEMPO_BINS[tempo_idx] 598 | 599 | if tempo != last_tl_event['tempo']: 600 | time = _get_time(last_tl_event, bar, position) 601 | last_tl_event['time'] = time 602 | last_tl_event['pos'] = (bar, position) 603 | last_tl_event['tempo'] = tempo 604 | 605 | elif i+4 < len(events) and \ 606 | f"{POSITION_KEY}_" in events[i] and \ 607 | f"{INSTRUMENT_KEY}_" in events[i+1] and \ 608 | f"{PITCH_KEY}_" in events[i+2] and \ 609 | f"{VELOCITY_KEY}_" in events[i+3] and \ 610 | f"{DURATION_KEY}_" in events[i+4]: 611 | # get position 612 | position = int(events[i].split('_')[-1]) 613 | if not position in polyphony_control[bar]: 614 | polyphony_control[bar][position] = {} 615 | 616 | # get instrument 617 | instrument_name = events[i+1].split('_')[-1] 618 | if instrument_name not in polyphony_control[bar][position]: 619 | polyphony_control[bar][position][instrument_name] = 0 620 | elif polyphony_control[bar][position][instrument_name] >= polyphony_limit: 621 | # If number of notes exceeds polyphony limit, omit this note 622 | continue 623 | 624 | if instrument_name not in instruments: 625 | if instrument_name == 'drum': 626 | instrument = pretty_midi.Instrument(0, is_drum=True) 627 | else: 628 | program = pretty_midi.instrument_name_to_program(instrument_name) 629 | instrument = pretty_midi.Instrument(program) 630 | instruments[instrument_name] = instrument 631 | else: 632 | instrument = instruments[instrument_name] 633 | 634 | # get pitch 635 | pitch = int(events[i+2].split('_')[-1]) 636 | # get velocity 637 | velocity_index = int(events[i+3].split('_')[-1]) 638 | velocity = min(127, DEFAULT_VELOCITY_BINS[velocity_index]) 639 | # get duration 640 | duration_index = int(events[i+4].split('_')[-1]) 641 | duration = DEFAULT_DURATION_BINS[duration_index] 642 | # create not and add to instrument 643 | start = _get_time(last_tl_event, bar, position) 644 | end = _get_time(last_tl_event, bar, position + duration) 645 | note = pretty_midi.Note(velocity=velocity, 646 | pitch=pitch, 647 | start=start, 648 | end=end) 649 | instrument.notes.append(note) 650 | n_notes += 1 651 | polyphony_control[bar][position][instrument_name] += 1 652 | 653 | for instrument in instruments.values(): 654 | 
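# Only instruments that actually received at least one note exist in the dict;
# attach each of them to the PrettyMIDI object before returning it.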
pm.instruments.append(instrument) 655 | return pm -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvruette/figaro/9da30291e2865bcbdad1e85ccea82df1a61119e9/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/seq2seq.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch.optim 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils.rnn import pad_sequence 6 | import math 7 | from datasets import MidiDataModule 8 | from vocab import RemiVocab, DescriptionVocab 9 | from constants import PAD_TOKEN, EOS_TOKEN, BAR_KEY, POSITION_KEY 10 | 11 | 12 | import transformers 13 | from transformers import ( 14 | BertConfig, 15 | EncoderDecoderConfig, 16 | EncoderDecoderModel 17 | ) 18 | 19 | class GroupEmbedding(nn.Module): 20 | def __init__(self, n_tokens, n_groups, out_dim, inner_dim=128): 21 | super().__init__() 22 | self.n_tokens = n_tokens 23 | self.n_groups = n_groups 24 | self.inner_dim = inner_dim 25 | self.out_dim = out_dim 26 | 27 | self.embedding = nn.Embedding(n_tokens, inner_dim) 28 | self.proj = nn.Linear(n_groups * inner_dim, out_dim, bias=False) 29 | 30 | def forward(self, x): 31 | shape = x.shape 32 | emb = self.embedding(x) 33 | return self.proj(emb.view(*shape[:-1], self.n_groups * self.inner_dim)) 34 | 35 | class Seq2SeqModule(pl.LightningModule): 36 | def __init__(self, 37 | d_model=512, 38 | d_latent=512, 39 | n_codes=512, 40 | n_groups=8, 41 | context_size=512, 42 | lr=1e-4, 43 | lr_schedule='sqrt_decay', 44 | warmup_steps=None, 45 | max_steps=None, 46 | encoder_layers=6, 47 | decoder_layers=12, 48 | intermediate_size=2048, 49 | num_attention_heads=8, 50 | description_flavor='description', 51 | description_options=None, 52 | use_pretrained_latent_embeddings=True): 53 | super(Seq2SeqModule, self).__init__() 54 | 55 | self.description_flavor = description_flavor 56 | assert self.description_flavor in ['latent', 'description', 'none', 'both'], f"Unknown description flavor '{self.description_flavor}', expected one of ['latent', 'description', 'none', 'both]" 57 | self.description_options = description_options 58 | 59 | self.context_size = context_size 60 | self.d_model = d_model 61 | self.d_latent = d_latent 62 | 63 | self.lr = lr 64 | self.lr_schedule = lr_schedule 65 | self.warmup_steps = warmup_steps 66 | self.max_steps = max_steps 67 | 68 | self.vocab = RemiVocab() 69 | 70 | encoder_config = BertConfig( 71 | vocab_size=1, 72 | pad_token_id=0, 73 | hidden_size=self.d_model, 74 | num_hidden_layers=encoder_layers, 75 | num_attention_heads=num_attention_heads, 76 | intermediate_size=intermediate_size, 77 | max_position_embeddings=1024, 78 | position_embedding_type='relative_key_query' 79 | ) 80 | decoder_config = BertConfig( 81 | vocab_size=1, 82 | pad_token_id=0, 83 | hidden_size=self.d_model, 84 | num_hidden_layers=decoder_layers, 85 | num_attention_heads=num_attention_heads, 86 | intermediate_size=intermediate_size, 87 | max_position_embeddings=1024, 88 | position_embedding_type='relative_key_query' 89 | ) 90 | config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config) 91 | self.transformer = EncoderDecoderModel(config) 92 | self.transformer.config.decoder.is_decoder = True 93 | 
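# The decoder is run as a causal language model; enabling cross-attention lets it
# attend to the encoder's representation of the (expert and/or learned) description
# whenever one is provided.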
self.transformer.config.decoder.add_cross_attention = True 94 | 95 | 96 | self.max_bars = self.context_size 97 | self.max_positions = 512 98 | self.bar_embedding = nn.Embedding(self.max_bars + 1, self.d_model) 99 | self.pos_embedding = nn.Embedding(self.max_positions + 1, self.d_model) 100 | 101 | if self.description_flavor in ['latent', 'both']: 102 | if use_pretrained_latent_embeddings: 103 | self.latent_in = nn.Linear(self.d_latent, self.d_model, bias=False) 104 | else: 105 | self.latent_in = GroupEmbedding(n_codes, n_groups, self.d_model, inner_dim=self.d_latent//n_groups) 106 | if self.description_flavor in ['description', 'both']: 107 | desc_vocab = DescriptionVocab() 108 | self.desc_in = nn.Embedding(len(desc_vocab), self.d_model) 109 | 110 | if self.description_flavor == 'both': 111 | self.desc_proj = nn.Linear(2*self.d_model, self.d_model, bias=False) 112 | 113 | self.in_layer = nn.Embedding(len(self.vocab), self.d_model) 114 | self.out_layer = nn.Linear(self.d_model, len(self.vocab), bias=False) 115 | 116 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=self.vocab.to_i(PAD_TOKEN)) 117 | 118 | self.save_hyperparameters() 119 | 120 | def get_datamodule(self, midi_files, **kwargs): 121 | return MidiDataModule( 122 | midi_files, 123 | self.context_size, 124 | description_flavor=self.description_flavor, 125 | max_bars=self.max_bars, 126 | max_positions=self.max_positions, 127 | description_options=self.description_options, 128 | **kwargs 129 | ) 130 | 131 | def encode(self, z, desc_bar_ids=None): 132 | if self.description_flavor == 'both': 133 | desc = z['description'] 134 | latent = z['latents'] 135 | desc_emb = self.desc_in(desc) 136 | latent_emb = self.latent_in(latent) 137 | 138 | padded = pad_sequence([desc_emb.transpose(0, 1), latent_emb.transpose(0, 1)], batch_first=True) 139 | desc_emb, latent_emb = padded.transpose(1, 2) 140 | 141 | if desc_bar_ids is not None: 142 | # Use the fact that description is always longer than latents 143 | desc_emb = desc_emb + self.bar_embedding(desc_bar_ids) 144 | 145 | z_emb = self.desc_proj(torch.cat([desc_emb, latent_emb], dim=-1)) 146 | 147 | elif self.description_flavor == 'description': 148 | z_emb = self.desc_in(z) 149 | if desc_bar_ids is not None: 150 | z_emb += self.bar_embedding(desc_bar_ids) 151 | 152 | elif self.description_flavor == 'latent': 153 | z_emb = self.latent_in(z) 154 | 155 | else: 156 | return None 157 | 158 | out = self.transformer.encoder(inputs_embeds=z_emb, output_hidden_states=True) 159 | encoder_hidden = out.hidden_states[-1] 160 | return encoder_hidden 161 | 162 | def decode(self, x, labels=None, bar_ids=None, position_ids=None, encoder_hidden_states=None, return_hidden=False): 163 | seq_len = x.size(1) 164 | 165 | # Shape of x_emb: (batch_size, seq_len, d_model) 166 | x_emb = self.in_layer(x) 167 | if bar_ids is not None: 168 | x_emb += self.bar_embedding(bar_ids) 169 | if position_ids is not None: 170 | x_emb += self.pos_embedding(position_ids) 171 | 172 | # # Add latent embedding to input embeddings 173 | # if bar_ids is not None: 174 | # assert bar_ids.max() <= encoder_hidden.size(1) 175 | # embs = torch.cat([torch.zeros(x.size(0), 1, self.d_model, device=self.device), encoder_hidden], dim=1) 176 | # offset = (seq_len * torch.arange(bar_ids.size(0), device=self.device)).unsqueeze(1) 177 | # # Use bar_ids to gather encoder hidden states s.t. 
latent_emb[i, j] == encoder_hidden[i, bar_ids[i, j]] 178 | # latent_emb = F.embedding((bar_ids + offset).view(-1), embs.view(-1, self.d_model)).view(x_emb.shape) 179 | # x_emb += latent_emb 180 | 181 | if encoder_hidden_states is not None: 182 | # Make x_emb and encoder_hidden_states match in sequence length. Necessary for relative positional embeddings 183 | padded = pad_sequence([x_emb.transpose(0, 1), encoder_hidden_states.transpose(0, 1)], batch_first=True) 184 | x_emb, encoder_hidden_states = padded.transpose(1, 2) 185 | 186 | out = self.transformer.decoder( 187 | inputs_embeds=x_emb, 188 | encoder_hidden_states=encoder_hidden_states, 189 | output_hidden_states=True 190 | ) 191 | hidden = out.hidden_states[-1][:, :seq_len] 192 | else: 193 | out = self.transformer.decoder(inputs_embeds=x_emb, output_hidden_states=True) 194 | hidden = out.hidden_states[-1][:, :seq_len] 195 | 196 | # Shape of logits: (batch_size, seq_len, tuple_size, vocab_size) 197 | 198 | if return_hidden: 199 | return hidden 200 | else: 201 | return self.out_layer(hidden) 202 | 203 | 204 | def forward(self, x, z=None, labels=None, position_ids=None, bar_ids=None, description_bar_ids=None, return_hidden=False): 205 | encoder_hidden = self.encode(z, desc_bar_ids=description_bar_ids) 206 | 207 | out = self.decode(x, 208 | labels=labels, 209 | bar_ids=bar_ids, 210 | position_ids=position_ids, 211 | encoder_hidden_states=encoder_hidden, 212 | return_hidden=return_hidden 213 | ) 214 | 215 | return out 216 | 217 | def get_loss(self, batch, return_logits=False): 218 | # Shape of x: (batch_size, seq_len, tuple_size) 219 | x = batch['input_ids'] 220 | bar_ids = batch['bar_ids'] 221 | position_ids = batch['position_ids'] 222 | # Shape of labels: (batch_size, tgt_len, tuple_size) 223 | labels = batch['labels'] 224 | 225 | # Shape of z: (batch_size, context_size, n_groups, d_latent) 226 | if self.description_flavor == 'latent': 227 | z = batch['latents'] 228 | desc_bar_ids = None 229 | elif self.description_flavor == 'description': 230 | z = batch['description'] 231 | desc_bar_ids = batch['desc_bar_ids'] 232 | elif self.description_flavor == 'both': 233 | z = { 'latents': batch['latents'], 'description': batch['description'] } 234 | desc_bar_ids = batch['desc_bar_ids'] 235 | else: 236 | z, desc_bar_ids = None, None 237 | 238 | 239 | logits = self(x, z=z, labels=labels, bar_ids=bar_ids, position_ids=position_ids, description_bar_ids=desc_bar_ids) 240 | # Shape of logits: (batch_size, tgt_len, tuple_size, vocab_size) 241 | pred = logits.view(-1, logits.shape[-1]) 242 | labels = labels.reshape(-1) 243 | 244 | loss = self.loss_fn(pred, labels) 245 | 246 | if return_logits: 247 | return loss, logits 248 | else: 249 | return loss 250 | 251 | def training_step(self, batch, batch_idx): 252 | loss = self.get_loss(batch) 253 | self.log('train_loss', loss.detach(), on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 254 | return loss 255 | 256 | def validation_step(self, batch, batch_idx): 257 | loss, logits = self.get_loss(batch, return_logits=True) 258 | self.log('valid_loss', loss.detach(), on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 259 | 260 | y = batch['labels'] 261 | pad_token_id = self.vocab.to_i(PAD_TOKEN) 262 | 263 | logits = logits.view(logits.size(0), -1, logits.size(-1)) 264 | y = y.view(y.size(0), -1) 265 | 266 | log_pr = logits.log_softmax(dim=-1) 267 | log_pr[y == pad_token_id] = 0 # log(pr) = log(1) for padding 268 | log_pr = torch.gather(log_pr, -1, 
y.unsqueeze(-1)).squeeze(-1) 269 | 270 | t = (y != pad_token_id).sum(dim=-1) 271 | ppl = (-log_pr.sum(dim=1) / t).exp().mean() 272 | self.log('valid_ppl', ppl.detach(), on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 273 | return loss 274 | 275 | def test_step(self, batch, batch_idx): 276 | return self.get_loss(batch) 277 | 278 | def configure_optimizers(self): 279 | # set LR to 1, scale with LambdaLR scheduler 280 | optimizer = torch.optim.AdamW(self.parameters(), lr=1, weight_decay=0.01) 281 | 282 | if self.lr_schedule == 'sqrt_decay': 283 | # constant warmup, then 1/sqrt(n) decay starting from the initial LR 284 | lr_func = lambda step: min(self.lr, self.lr / math.sqrt(max(step, 1)/self.warmup_steps)) 285 | elif self.lr_schedule == 'linear': 286 | # linear warmup, linear decay 287 | lr_func = lambda step: min(self.lr, self.lr*step/self.warmup_steps, self.lr*(1 - (step - self.warmup_steps)/self.max_steps)) 288 | elif self.lr_schedule == 'cosine': 289 | # linear warmup, cosine decay to 10% of initial LR 290 | lr_func = lambda step: self.lr * min(step/self.warmup_steps, 0.55 + 0.45*math.cos(math.pi*(min(step, self.max_steps) - self.warmup_steps)/(self.max_steps - self.warmup_steps))) 291 | else: 292 | # Use no lr scheduling 293 | lr_func = lambda step: self.lr 294 | 295 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_func) 296 | return [optimizer], [{ 297 | 'scheduler': scheduler, 298 | 'interval': 'step', 299 | }] 300 | 301 | @torch.no_grad() 302 | def sample(self, batch, 303 | max_length=256, 304 | max_bars=-1, 305 | temp=0.8, 306 | pad_token=PAD_TOKEN, 307 | eos_token=EOS_TOKEN, 308 | verbose=0, 309 | ): 310 | 311 | # Setup and parsing arguments 312 | 313 | pad_token_id = self.vocab.to_i(pad_token) 314 | eos_token_id = self.vocab.to_i(eos_token) 315 | 316 | batch_size, curr_len = batch['input_ids'].shape 317 | 318 | i = curr_len - 1 319 | 320 | x = batch['input_ids'] 321 | bar_ids = batch['bar_ids'] 322 | position_ids = batch['position_ids'] 323 | assert x.shape[:2] == bar_ids.shape and x.shape[:2] == position_ids.shape, f"Input, bar and position ids weren't of compatible shapes: {x.shape}, {bar_ids.shape}, {position_ids.shape}" 324 | 325 | if self.description_flavor == 'both': 326 | z = { 'latents': batch['latents'], 'description': batch['description'] } 327 | desc_bar_ids = batch['desc_bar_ids'].to(self.device) 328 | elif self.description_flavor == 'latent': 329 | z, desc_bar_ids = batch['latents'], None 330 | elif self.description_flavor == 'description': 331 | z, desc_bar_ids = batch['description'], batch['desc_bar_ids'].to(self.device) 332 | else: 333 | z, desc_bar_ids = None, None 334 | 335 | 336 | is_done = torch.zeros(batch_size, dtype=torch.bool) 337 | 338 | # Precompute encoder hidden states for cross-attention 339 | if self.description_flavor == 'latent': 340 | encoder_hidden_states = self.encode(z, desc_bar_ids) 341 | else: 342 | encoder_hidden_states = None 343 | 344 | curr_bars = torch.zeros(batch_size).to(self.device).fill_(-1) 345 | # Sample using decoder until max_length is reached or all sequences are done 346 | for i in range(curr_len - 1, max_length): 347 | # print(f"\r{i+1}/{max_length}", end='') 348 | x_ = x[:, -self.context_size:].to(self.device) 349 | bar_ids_ = bar_ids[:, -self.context_size:].to(self.device) 350 | position_ids_ = position_ids[:, -self.context_size:].to(self.device) 351 | 352 | # Description scrolling 353 | if self.description_flavor in ['description', 'both']: 354 | if self.description_flavor == 
'description': 355 | desc = z 356 | else: 357 | desc = z['description'] 358 | 359 | next_bars = bar_ids_[:, 0] 360 | bars_changed = not (next_bars == curr_bars).all() 361 | curr_bars = next_bars 362 | 363 | if bars_changed: 364 | z_ = torch.zeros(batch_size, self.context_size, dtype=torch.int) 365 | desc_bar_ids_ = torch.zeros(batch_size, self.context_size, dtype=torch.int) 366 | 367 | for j in range(batch_size): 368 | curr_bar = bar_ids_[j, 0] 369 | indices = torch.nonzero(desc_bar_ids[j] == curr_bar) 370 | if indices.size(0) > 0: 371 | idx = indices[0, 0] 372 | else: 373 | idx = desc.size(1) - 1 374 | 375 | offset = min(self.context_size, desc.size(1) - idx) 376 | 377 | z_[j, :offset] = desc[j, idx:idx+offset] 378 | desc_bar_ids_[j, :offset] = desc_bar_ids[j, idx:idx+offset] 379 | 380 | z_, desc_bar_ids_ = z_.to(self.device), desc_bar_ids_.to(self.device) 381 | 382 | if self.description_flavor == 'both': 383 | z_ = { 'description': z_, 'latents': z['latents'] } 384 | 385 | encoder_hidden_states = self.encode(z_, desc_bar_ids_) 386 | 387 | logits = self.decode(x_, bar_ids=bar_ids_, position_ids=position_ids_, encoder_hidden_states=encoder_hidden_states) 388 | 389 | idx = min(self.context_size - 1, i) 390 | logits = logits[:, idx] / temp 391 | 392 | pr = F.softmax(logits, dim=-1) 393 | pr = pr.view(-1, pr.size(-1)) 394 | 395 | next_token_ids = torch.multinomial(pr, 1).view(-1).to(x.device) 396 | next_tokens = self.vocab.decode(next_token_ids) 397 | if verbose: 398 | print(f"{i+1}/{max_length}", next_tokens) 399 | 400 | 401 | next_bars = torch.tensor([1 if f'{BAR_KEY}_' in token else 0 for token in next_tokens], dtype=torch.int) 402 | next_bar_ids = bar_ids[:, i].clone() + next_bars 403 | 404 | next_positions = [f"{POSITION_KEY}_0" if f'{BAR_KEY}_' in token else token for token in next_tokens] 405 | next_positions = [int(token.split('_')[-1]) if f'{POSITION_KEY}_' in token else None for token in next_positions] 406 | next_positions = [pos if next_pos is None else next_pos for pos, next_pos in zip(position_ids[:, i], next_positions)] 407 | next_position_ids = torch.tensor(next_positions, dtype=torch.int) 408 | 409 | is_done.masked_fill_((next_token_ids == eos_token_id).all(dim=-1), True) 410 | next_token_ids[is_done] = pad_token_id 411 | if max_bars > 0: 412 | is_done.masked_fill_(next_bar_ids >= max_bars + 1, True) 413 | 414 | x = torch.cat([x, next_token_ids.clone().unsqueeze(1)], dim=1) 415 | bar_ids = torch.cat([bar_ids, next_bar_ids.unsqueeze(1)], dim=1) 416 | position_ids = torch.cat([position_ids, next_position_ids.unsqueeze(1)], dim=1) 417 | 418 | if torch.all(is_done): 419 | break 420 | # print() 421 | 422 | return { 423 | 'sequences': x, 424 | 'bar_ids': bar_ids, 425 | 'position_ids': position_ids 426 | } 427 | 428 | -------------------------------------------------------------------------------- /src/models/vae.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import pytorch_lightning as pl 4 | import torch.optim 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | import numpy as np 9 | import random 10 | from datasets import MidiDataModule 11 | from vocab import RemiVocab 12 | from constants import PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, MASK_TOKEN 13 | 14 | import transformers 15 | from transformers import ( 16 | BertConfig, 17 | EncoderDecoderConfig, 18 | EncoderDecoderModel 19 | ) 20 | 21 | 22 | # Implementation adapted from https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py 23 | # Random 
restarts adapted from https://github.com/openai/jukebox/blob/master/jukebox/vqvae/bottleneck.py 24 | class VectorQuantizeEMA(nn.Module): 25 | def __init__(self, d_latent, n_codes, n_groups=1, decay=0.995, eps=1e-4, restart_threshold=0.99): 26 | assert d_latent // n_groups == d_latent / n_groups, f"Unexpected latent dimension: d_latent={d_latent} must be divisible by n_groups={n_groups}" 27 | 28 | super().__init__() 29 | 30 | self.d_latent = d_latent 31 | self.n_groups = n_groups 32 | self.dim = d_latent // n_groups 33 | self.n_codes = n_codes 34 | 35 | self.decay = decay 36 | self.eps = eps 37 | self.threshold = restart_threshold 38 | self.init = False 39 | 40 | embed = torch.randn(self.n_codes, self.dim) 41 | self.register_buffer('embedding', embed) 42 | self.register_buffer('cluster_size', torch.ones(self.n_codes)) 43 | self.register_buffer('cluster_sum', embed.clone().detach()) 44 | 45 | def forward(self, x, dist=None): 46 | assert x.shape[-1] == self.n_groups * self.dim, f"Unexpected input shape: expected last dimension to be {self.n_groups * self.dim} but was {x.shape[-1]}" 47 | x_ = x.reshape(-1, self.dim) 48 | 49 | if self.training and not self.init: 50 | self._init_embeddings(x_, dist=dist) 51 | 52 | ### Shared embeddings between groups ### 53 | # Find nearest neighbors in latent space 54 | emb_t = self.embedding.t() 55 | distance = ( 56 | x_.pow(2).sum(1, keepdim=True) 57 | - 2 * x_ @ emb_t 58 | + emb_t.pow(2).sum(0, keepdim=True) 59 | ) 60 | _, embed_idx = (-distance).max(1) 61 | embed_onehot = F.one_hot(embed_idx, self.n_codes).type(x_.dtype) 62 | 63 | quantize = self.embed(embed_idx).view(-1, self.n_groups * self.dim) 64 | diff = (quantize.detach() - x).pow(2).mean() 65 | quantize = x + (quantize - x).detach() 66 | codes = embed_idx.view(-1, self.n_groups) 67 | 68 | if self.training: 69 | update_metrics = self._ema_update(x_, embed_onehot, dist=dist) 70 | else: 71 | update_metrics = {} 72 | 73 | return dict( 74 | z=quantize, 75 | diff=diff, 76 | codes=codes, 77 | **update_metrics 78 | ) 79 | 80 | def embed(self, idx): 81 | return F.embedding(idx, self.embedding) 82 | 83 | def _init_embeddings(self, x, dist=None): 84 | self.init = True 85 | rand_centers = self._randomize(x) 86 | self.cluster_sum.data.copy_(rand_centers) 87 | self.cluster_size.data.fill_(1) 88 | 89 | 90 | def _randomize(self, x): 91 | n = x.size(0) 92 | if n < self.n_codes: 93 | r = (self.n_codes + n - 1) // n # r = math.ceil(n_codes / n) 94 | std = 0.01 / np.sqrt(self.dim) 95 | x = x.repeat(r, 1) 96 | x += std * torch.randn_like(x) 97 | return x[torch.randperm(x.size(0))][:self.n_codes] 98 | 99 | def _ema_update(self, x, cluster_assign, dist=None): 100 | with torch.no_grad(): 101 | cluster_size = cluster_assign.sum(0) 102 | cluster_sum = cluster_assign.t() @ x 103 | 104 | rand_centers = self._randomize(x) 105 | 106 | # Gather results from all GPUs to get better estimate 107 | # This doesn't work for the DataParallel accelerator 108 | # if dist is not None: 109 | # dist.broadcast(rand_centers) 110 | # cluster_size = dist.reduce(cluster_size, reduce_op='sum') 111 | # cluster_sum = dist.reduce(cluster_sum, reduce_op='sum') 112 | 113 | # EMA update step 114 | # self.cluster_size.data.mul_(self.decay).add_(cluster_size, alpha=1-self.decay) 115 | # self.cluster_sum.data.mul_(self.decay).add_(cluster_sum, alpha=1-self.decay) 116 | self.cluster_size.data.copy_(self.decay*self.cluster_size + (1 - self.decay)*cluster_size) 117 | self.cluster_sum.data.copy_(self.decay*self.cluster_sum + (1 - self.decay)*cluster_sum) 118 | 
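# In effect, the buffers above implement the usual EMA codebook update (sketch):
#   N_k <- decay * N_k + (1 - decay) * (number of vectors assigned to code k)
#   m_k <- decay * m_k + (1 - decay) * (sum of vectors assigned to code k)
#   e_k <- m_k / N_k   (with additive smoothing of the counts, see below)
# Codes whose EMA count stays below `restart_threshold` are re-seeded from random
# encoder outputs (`rand_centers`) so that the whole codebook keeps being used.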
119 | used = (self.cluster_size >= self.threshold).float().unsqueeze(-1) 120 | 121 | n = self.cluster_size.sum() 122 | # Use additive smoothing to mitigate exploding gradients 123 | count = (self.cluster_size + self.eps) / (n + self.n_codes*self.eps) * n 124 | 125 | cluster_centers = self.cluster_sum / count.unsqueeze(-1) 126 | cluster_centers = used * cluster_centers + (1 - used) * rand_centers 127 | self.embedding.data.copy_(cluster_centers) 128 | 129 | # Also reset size of cluster when doing random restarts => prevent from randomly restarting many times in a row 130 | # new_sizes = used.squeeze(1) * self.cluster_size + (1 - used.squeeze(1)) 131 | # self.cluster_size.data.copy_(new_sizes) 132 | 133 | # Compute metrics 134 | avg_usage = used.mean() 135 | usage = used.sum() 136 | pr = cluster_size / cluster_size.sum() 137 | entropy = -(pr * (pr + 1e-5).log()).sum() 138 | 139 | return { 140 | 'avg_usage': avg_usage, 141 | 'usage': usage, 142 | 'entropy': entropy 143 | } 144 | 145 | 146 | class VqVaeModule(pl.LightningModule): 147 | def __init__(self, 148 | d_model=512, 149 | context_size=256, 150 | n_codes=1024, 151 | n_groups=2, 152 | d_latent=1024, 153 | lr=1e-4, 154 | lr_schedule='sqrt_decay', 155 | warmup_steps=1000, 156 | max_steps=10000, 157 | encoder_layers=6, 158 | decoder_layers=6, 159 | encoder_ffn_dim=2048, 160 | decoder_ffn_dim=2048, 161 | windowed_attention_pr=0.0, 162 | max_lookahead=4, 163 | disable_vq=False): 164 | super().__init__() 165 | 166 | self.d_model = d_model 167 | self.context_size = context_size 168 | self.n_codes = n_codes 169 | self.n_groups = n_groups 170 | self.d_latent = d_latent 171 | 172 | self.beta = 0.02 173 | self.cycle_length = 2000 174 | 175 | self.lr = lr 176 | self.lr_schedule = lr_schedule 177 | self.warmup_steps = warmup_steps 178 | self.max_steps = max_steps 179 | self.windowed_attention_pr = windowed_attention_pr 180 | self.max_lookahead = max_lookahead 181 | self.disable_vq = disable_vq 182 | 183 | self.vocab = RemiVocab() 184 | 185 | self.pad_token = self.vocab.to_i(PAD_TOKEN) 186 | self.bos_token = self.vocab.to_i(BOS_TOKEN) 187 | self.eos_token = self.vocab.to_i(EOS_TOKEN) 188 | self.mask_token = self.vocab.to_i(MASK_TOKEN) 189 | 190 | encoder_config = BertConfig( 191 | vocab_size=1, 192 | pad_token_id=0, 193 | hidden_size=self.d_model, 194 | num_hidden_layers=encoder_layers, 195 | num_attention_heads=8, 196 | intermediate_size=encoder_ffn_dim, 197 | max_position_embeddings=1024, 198 | position_embedding_type='relative_key_query' 199 | ) 200 | decoder_config = BertConfig( 201 | vocab_size=1, 202 | pad_token_id=0, 203 | hidden_size=self.d_model, 204 | num_hidden_layers=decoder_layers, 205 | num_attention_heads=8, 206 | intermediate_size=decoder_ffn_dim, 207 | max_position_embeddings=1024, 208 | position_embedding_type='relative_key_query' 209 | ) 210 | config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config) 211 | self.transformer = EncoderDecoderModel(config) 212 | self.transformer.config.decoder.is_decoder = True 213 | self.transformer.config.decoder.add_cross_attention = True 214 | self.encoder = self.transformer.encoder 215 | self.decoder = self.transformer.decoder 216 | 217 | self.in_layer = nn.Embedding(len(self.vocab), self.d_model) 218 | self.out_layer = nn.Linear(self.d_model, len(self.vocab), bias=False) 219 | 220 | self.vq_embed = VectorQuantizeEMA(self.d_latent, self.n_codes, self.n_groups) 221 | self.pooling = nn.Linear(self.d_model, self.d_latent, bias=False) 222 | self.unpooling = 
nn.Linear(self.d_latent, self.d_model, bias=False) 223 | self.attention_proj = nn.Linear(self.d_model, self.d_model) 224 | 225 | self.rec_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token) 226 | 227 | self.save_hyperparameters() 228 | 229 | def get_datamodule(self, midi_files, **kwargs): 230 | return MidiDataModule( 231 | midi_files, 232 | self.context_size, 233 | max_bars_per_context=1, 234 | bar_token_mask=MASK_TOKEN, 235 | **kwargs 236 | ) 237 | 238 | def forward(self, x, y=None, latent=None, use_windowed_attention=False, return_latent_logits=False): 239 | if y is None: 240 | y = x.clone().detach() 241 | 242 | # VQ-VAE 243 | if latent is None: 244 | encoder_out = self.encode(x) 245 | latent = encoder_out['z'] 246 | 247 | logits = self.decode(x, latent, use_windowed_attention) 248 | return { 249 | 'logits': logits, 250 | **encoder_out 251 | } 252 | 253 | 254 | def embed(self, x): 255 | return self.in_layer(x) 256 | 257 | def encode(self, x): 258 | x_emb = self.embed(x) 259 | 260 | # Shape of out: (batch_size, seq_len, d_model) 261 | out = self.encoder(inputs_embeds=x_emb, output_hidden_states=True) 262 | hidden = out.pooler_output 263 | # Shape of z_e: (batch_size, d_model * n_groups) 264 | z_e = self.pooling(hidden) 265 | 266 | if self.disable_vq: 267 | # AE baseline 268 | return { 'z': z_e } 269 | else: 270 | # VQ-VAE 271 | # Shape of z_q: (batch_size, d_model * n_groups) 272 | dist = self.trainer.accelerator.training_type_plugin if self.training else None 273 | return self.vq_embed(z_e, dist=dist) 274 | 275 | 276 | def decode(self, x, latent, use_windowed_attention=False): 277 | # Shape of latent: (batch_size, n_groups, d_model) 278 | x_emb = self.embed(x) 279 | seq_len = x_emb.size(1) 280 | 281 | # Shape of h0: (batch_size, d_model) 282 | h0 = self.unpooling(latent) 283 | 284 | # Make model decoder-only by fixing h0 285 | # h0 = torch.zeros_like(h0) 286 | 287 | # Strategy 1: Add latent embeddings to input embeddings 288 | x_emb += h0.unsqueeze(1).repeat(1, seq_len, 1) 289 | 290 | # Strategy 2: Use latent embedding in cross-attention 291 | x_attention = self.attention_proj(h0.unsqueeze(1).repeat(1, self.context_size, 1)) 292 | 293 | # Relative pos. 
embeddings need source and target to be of the same length 294 | # -> prevents einsum shape mismatch error 295 | padding = torch.zeros_like(x_attention) 296 | padding[:, :x_emb.size(1)] = x_emb 297 | x_emb = padding 298 | 299 | 300 | if self.training or use_windowed_attention: 301 | attention_mask = self.rand_attention_mask(x) 302 | else: 303 | attention_mask = self.get_attention_mask(x) 304 | padding = torch.zeros((x.size(0), self.context_size, self.context_size), device=self.device, dtype=torch.int) 305 | padding[:, :attention_mask.size(1), :attention_mask.size(2)] = attention_mask 306 | attention_mask = padding 307 | 308 | out = self.decoder( 309 | inputs_embeds=x_emb, 310 | encoder_hidden_states=x_attention, 311 | attention_mask=attention_mask, 312 | output_hidden_states=True 313 | ) 314 | hidden = out.hidden_states[-1][:, :seq_len] 315 | logits = self.out_layer(hidden).contiguous() 316 | 317 | return logits 318 | 319 | def get_loss(self, batch, windowed_attention_pr=None): 320 | if windowed_attention_pr is None: 321 | windowed_attention_pr = self.windowed_attention_pr 322 | use_windowed_attention = True if random.random() < windowed_attention_pr else False 323 | 324 | x = batch['input_ids'] 325 | labels = batch['labels'] 326 | 327 | out = self.forward( 328 | x, 329 | y=labels, 330 | use_windowed_attention=use_windowed_attention, 331 | ) 332 | 333 | logits = out['logits'] 334 | # Reshape logits to: (batch_size * seq_len, vocab_size) 335 | logits = logits.view(-1, logits.size(-1)) 336 | # Reshape labels to: (batch_size * seq_len) 337 | labels = labels.view(-1) 338 | 339 | rec_loss = self.rec_loss(logits, labels) 340 | 341 | if self.disable_vq: 342 | loss = rec_loss 343 | else: 344 | diff = out['diff'] 345 | loss = rec_loss + self.beta*diff 346 | 347 | return { 348 | 'loss': loss, 349 | 'rec_loss': rec_loss, 350 | **out 351 | } 352 | 353 | def training_step(self, batch, batch_idx, optimizer_idx=0): 354 | metrics = self.get_loss(batch) 355 | log_metrics = { key: metrics[key].detach() for key in ['loss', 'rec_loss', 'diff', 'avg_usage', 'usage', 'entropy'] if key in metrics } 356 | self.log('train', log_metrics, on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 357 | return metrics['loss'] 358 | 359 | def validation_step(self, batch, batch_idx): 360 | metrics = self.get_loss(batch) 361 | log_metrics = { key: metrics[key].detach() for key in ['rec_loss', 'diff', 'avg_usage', 'usage', 'entropy'] if key in metrics } 362 | 363 | # Compute perplexity 364 | x, y = batch['input_ids'], batch['labels'] 365 | pad_token_id = self.vocab.to_i(PAD_TOKEN) 366 | logits = metrics['logits'] 367 | log_pr = logits.log_softmax(dim=-1) 368 | log_pr[y == pad_token_id] = 0 # log(pr) = log(1) for padding 369 | log_pr = torch.gather(log_pr, -1, y.unsqueeze(-1)).squeeze(-1) 370 | t = (y != pad_token_id).sum(dim=-1) 371 | ppl = (-log_pr.sum(dim=1) / t).exp().mean() 372 | log_metrics['ppl'] = ppl.detach() 373 | 374 | self.log('valid', log_metrics, on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 375 | # Log loss separately for model checkpoint monitor 376 | self.log('valid_loss', metrics['loss'], on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 377 | return metrics['loss'] 378 | 379 | def test_step(self, batch, batch_idx): 380 | metrics = self.get_loss(batch) 381 | return metrics['loss'] 382 | 383 | def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): 384 | step = self.trainer.global_step 385 | 386 | # # beta is increased for 
C*R steps and then held constant for C*(1-R) steps 387 | # if step/2 >= self.cycle_length: 388 | # self.cycle_length *= 2 389 | # C = self.cycle_length # cycle length 390 | # R = 1000 # restart steps 391 | # b_min, b_max = 0.0, 0.1 392 | # t = max(0, min(1, (step % C) / R)) 393 | # self.beta = b_min*(1 - t) + b_max*t 394 | # self.log('beta', self.beta, on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True) 395 | 396 | def configure_optimizers(self): 397 | # set LR to 1, scale with LambdaLR scheduler 398 | optimizer = torch.optim.AdamW(self.parameters(), lr=1, weight_decay=0.01) 399 | 400 | if self.lr_schedule == 'sqrt_decay': 401 | # constant warmup, then 1/sqrt(n) decay starting from the initial LR 402 | lr_func = lambda step: min(self.lr, self.lr / math.sqrt(max(step, 1)/self.warmup_steps)) 403 | elif self.lr_schedule == 'linear': 404 | # linear warmup, linear decay 405 | lr_func = lambda step: min(self.lr, self.lr*step/self.warmup_steps, self.lr*(1 - (step - self.warmup_steps)/self.max_steps)) 406 | elif self.lr_schedule == 'cosine': 407 | # linear warmup, cosine decay to 10% of initial LR 408 | lr_func = lambda step: self.lr * min(step/self.warmup_steps, 0.55 + 0.45*math.cos(math.pi*(min(step, self.max_steps) - self.warmup_steps)/(self.max_steps - self.warmup_steps))) 409 | else: 410 | # Use no lr scheduling 411 | lr_func = lambda step: self.lr 412 | 413 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_func) 414 | return [optimizer], [{ 415 | 'scheduler': scheduler, 416 | 'interval': 'step', 417 | }] 418 | 419 | def rand_attention_mask(self, x, pr=0.2, max_size=None): 420 | if max_size == None: 421 | max_size = self.max_lookahead 422 | if max_size is not None and self.training and random.random() < pr: 423 | mask_size, k = random.randint(1, max_size), 0 424 | else: 425 | mask_size, k = 1, 1 426 | return self.get_attention_mask(x, mask_size=mask_size, k=k) 427 | 428 | def get_attention_mask(self, x, mask_size=1, k=1): 429 | batch_size, seq_len = x.shape[:2] 430 | 431 | # Standard self-attention mask for auto-regressive modelling 432 | tri_mask = torch.ones((seq_len//mask_size+1, seq_len//mask_size+1), device=self.device, dtype=torch.int) 433 | tri_mask = torch.triu(tri_mask, diagonal=k) 434 | tri_mask = (~tri_mask.bool()).int() 435 | # Create windowed self-attention mask, forcing the model to prefict farther into the future 436 | window_mask = tri_mask.repeat_interleave(mask_size, dim=0).repeat_interleave(mask_size, dim=1)[:seq_len, :seq_len] 437 | # First token needs to be always visible 438 | window_mask[:, 0] = 1 439 | 440 | return window_mask.unsqueeze(0).repeat(batch_size, 1, 1) 441 | 442 | 443 | 444 | -------------------------------------------------------------------------------- /src/precompute_latents.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | import pickle 5 | import random 6 | import torch 7 | from torch.utils.data.dataloader import DataLoader 8 | 9 | from models.vae import VqVaeModule 10 | from constants import MASK_TOKEN 11 | from datasets import MidiDataset, SeqCollator 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | ROOT_DIR = os.getenv('ROOT_DIR', os.path.join(os.getenv('TMPDIR', './temp'), 'lmd_full')) 16 | MAX_N_FILES = int(os.getenv('MAX_N_FILES', '-1')) 17 | 18 | BATCH_SIZE = int(os.getenv('BATCH_SIZE', '8')) 19 | 20 | N_WORKERS = min(os.cpu_count(), float(os.getenv('N_WORKERS', 'inf'))) 21 | if device.type == 
'cuda': 22 | N_WORKERS = min(N_WORKERS, 8*torch.cuda.device_count()) 23 | N_WORKERS = int(N_WORKERS) 24 | 25 | LATENT_CACHE_PATH = os.getenv('LATENT_CACHE_PATH', os.path.join(os.getenv('SCRATCH', os.getenv('TMPDIR')), 'latent')) 26 | os.makedirs(LATENT_CACHE_PATH, exist_ok=True) 27 | 28 | 29 | ### Create data loaders ### 30 | midi_files = glob.glob(os.path.join(ROOT_DIR, '**/*.mid'), recursive=True) 31 | if MAX_N_FILES > 0: 32 | midi_files = midi_files[:MAX_N_FILES] 33 | 34 | # Shuffle files for approximate parallelizability 35 | random.shuffle(midi_files) 36 | 37 | 38 | VAE_CHECKPOINT = os.getenv('VAE_CHECKPOINT', None) 39 | vae_module = VqVaeModule.load_from_checkpoint(checkpoint_path=VAE_CHECKPOINT).to(device) 40 | vae_module.eval() 41 | vae_module.freeze() 42 | 43 | collator = SeqCollator(context_size=vae_module.context_size) 44 | 45 | print('***** PRECOMPUTING LATENT REPRESENTATIONS *****') 46 | print(f'Number of files: {len(midi_files)}') 47 | print(f'Using cache: {LATENT_CACHE_PATH}') 48 | print('***********************************************') 49 | 50 | for i, file in enumerate(midi_files): 51 | print(f"{i:4d}/{len(midi_files)}: {file} ", end='') 52 | cache_key = os.path.basename(file) 53 | cache_file = os.path.join(LATENT_CACHE_PATH, cache_key) 54 | 55 | try: 56 | latents, codes = pickle.load(open(cache_file, 'rb')) 57 | print(f'(already cached: {len(latents)} bars)') 58 | continue 59 | except: 60 | pass 61 | 62 | ds = MidiDataset([file], vae_module.context_size, 63 | description_flavor='none', 64 | max_bars_per_context=1, 65 | bar_token_mask=MASK_TOKEN, 66 | print_errors=True, 67 | ) 68 | 69 | dl = DataLoader(ds, 70 | collate_fn=collator, 71 | batch_size=BATCH_SIZE, 72 | num_workers=N_WORKERS, 73 | pin_memory=True 74 | ) 75 | 76 | latents, codes = [], [] 77 | for batch in dl: 78 | x = batch['input_ids'].to(device) 79 | 80 | out = vae_module.encode(x) 81 | latents.append(out['z']) 82 | codes.append(out['codes']) 83 | 84 | if len(latents) == 0: 85 | continue 86 | 87 | latents = torch.cat(latents).cpu() 88 | codes = torch.cat(codes).cpu() 89 | print(f'(caching latents: {latents.size(0)} bars)') 90 | 91 | # Try to store the computed representation in the cache directory 92 | try: 93 | pickle.dump((latents, codes), open(cache_file, 'wb')) 94 | except Exception as err: 95 | print('Unable to cache file:', str(err)) -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | 5 | import os 6 | import glob 7 | 8 | import pytorch_lightning as pl 9 | 10 | from models.seq2seq import Seq2SeqModule 11 | from models.vae import VqVaeModule 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | ROOT_DIR = os.getenv('ROOT_DIR', './lmd_full') 16 | OUTPUT_DIR = os.getenv('OUTPUT_DIR', './results') 17 | LOGGING_DIR = os.getenv('LOGGING_DIR', './logs') 18 | MAX_N_FILES = int(os.getenv('MAX_N_FILES', -1)) 19 | 20 | MODEL = os.getenv('MODEL', None) 21 | MODEL_NAME = os.getenv('MODEL_NAME', None) 22 | N_CODES = int(os.getenv('N_CODES', 2048)) 23 | N_GROUPS = int(os.getenv('N_GROUPS', 16)) 24 | D_MODEL = int(os.getenv('D_MODEL', 512)) 25 | D_LATENT = int(os.getenv('D_LATENT', 1024)) 26 | 27 | CHECKPOINT = os.getenv('CHECKPOINT', None) 28 | VAE_CHECKPOINT = os.getenv('VAE_CHECKPOINT', None) 29 | 30 | BATCH_SIZE = int(os.getenv('BATCH_SIZE', 128)) 31 | TARGET_BATCH_SIZE = int(os.getenv('TARGET_BATCH_SIZE', 512)) 32 | 33 | EPOCHS 
= int(os.getenv('EPOCHS', '16')) 34 | WARMUP_STEPS = int(float(os.getenv('WARMUP_STEPS', 4000))) 35 | MAX_STEPS = int(float(os.getenv('MAX_STEPS', 1e20))) 36 | MAX_TRAINING_STEPS = int(float(os.getenv('MAX_TRAINING_STEPS', 100_000))) 37 | LEARNING_RATE = float(os.getenv('LEARNING_RATE', 1e-4)) 38 | LR_SCHEDULE = os.getenv('LR_SCHEDULE', 'const') 39 | CONTEXT_SIZE = int(os.getenv('CONTEXT_SIZE', 256)) 40 | 41 | ACCUMULATE_GRADS = max(1, TARGET_BATCH_SIZE//BATCH_SIZE) 42 | 43 | N_WORKERS = min(os.cpu_count(), float(os.getenv('N_WORKERS', 'inf'))) 44 | if device.type == 'cuda': 45 | N_WORKERS = min(N_WORKERS, 8*torch.cuda.device_count()) 46 | N_WORKERS = int(N_WORKERS) 47 | 48 | 49 | def main(): 50 | ### Define available models ### 51 | 52 | available_models = [ 53 | 'vq-vae', 54 | 'figaro-learned', 55 | 'figaro-expert', 56 | 'figaro', 57 | 'figaro-inst', 58 | 'figaro-chord', 59 | 'figaro-meta', 60 | 'figaro-no-inst', 61 | 'figaro-no-chord', 62 | 'figaro-no-meta', 63 | 'baseline', 64 | ] 65 | 66 | assert MODEL is not None, 'the MODEL needs to be specified' 67 | assert MODEL in available_models, f'unknown MODEL: {MODEL}' 68 | 69 | 70 | ### Create data loaders ### 71 | midi_files = glob.glob(os.path.join(ROOT_DIR, '**/*.mid'), recursive=True) 72 | if MAX_N_FILES > 0: 73 | midi_files = midi_files[:MAX_N_FILES] 74 | 75 | if len(midi_files) == 0: 76 | print(f"WARNING: No MIDI files were found at '{ROOT_DIR}'. Did you download the dataset to the right location?") 77 | exit() 78 | 79 | 80 | MAX_CONTEXT = min(1024, CONTEXT_SIZE) 81 | 82 | if MODEL in ['figaro-learned', 'figaro'] and VAE_CHECKPOINT: 83 | vae_module = VqVaeModule.load_from_checkpoint(checkpoint_path=VAE_CHECKPOINT) 84 | vae_module.cpu() 85 | vae_module.freeze() 86 | vae_module.eval() 87 | 88 | else: 89 | vae_module = None 90 | 91 | 92 | ### Create and train model ### 93 | 94 | # load model from checkpoint if available 95 | 96 | if CHECKPOINT: 97 | model_class = { 98 | 'vq-vae': VqVaeModule, 99 | 'figaro-learned': Seq2SeqModule, 100 | 'figaro-expert': Seq2SeqModule, 101 | 'figaro': Seq2SeqModule, 102 | 'figaro-inst': Seq2SeqModule, 103 | 'figaro-chord': Seq2SeqModule, 104 | 'figaro-meta': Seq2SeqModule, 105 | 'figaro-no-inst': Seq2SeqModule, 106 | 'figaro-no-chord': Seq2SeqModule, 107 | 'figaro-no-meta': Seq2SeqModule, 108 | 'baseline': Seq2SeqModule, 109 | }[MODEL] 110 | model = model_class.load_from_checkpoint(checkpoint_path=CHECKPOINT) 111 | 112 | else: 113 | seq2seq_kwargs = { 114 | 'encoder_layers': 4, 115 | 'decoder_layers': 6, 116 | 'num_attention_heads': 8, 117 | 'intermediate_size': 2048, 118 | 'd_model': D_MODEL, 119 | 'context_size': MAX_CONTEXT, 120 | 'lr': LEARNING_RATE, 121 | 'warmup_steps': WARMUP_STEPS, 122 | 'max_steps': MAX_STEPS, 123 | } 124 | dec_kwargs = { **seq2seq_kwargs } 125 | dec_kwargs['encoder_layers'] = 0 126 | 127 | # use lambda functions for lazy initialization 128 | model = { 129 | 'vq-vae': lambda: VqVaeModule( 130 | encoder_layers=4, 131 | decoder_layers=6, 132 | encoder_ffn_dim=2048, 133 | decoder_ffn_dim=2048, 134 | n_codes=N_CODES, 135 | n_groups=N_GROUPS, 136 | context_size=MAX_CONTEXT, 137 | lr=LEARNING_RATE, 138 | lr_schedule=LR_SCHEDULE, 139 | warmup_steps=WARMUP_STEPS, 140 | max_steps=MAX_STEPS, 141 | d_model=D_MODEL, 142 | d_latent=D_LATENT, 143 | ), 144 | 'figaro-learned': lambda: Seq2SeqModule( 145 | description_flavor='latent', 146 | n_codes=vae_module.n_codes, 147 | n_groups=vae_module.n_groups, 148 | d_latent=vae_module.d_latent, 149 | **seq2seq_kwargs 150 | ), 151 | 'figaro': lambda: 
Seq2SeqModule( 152 | description_flavor='both', 153 | n_codes=vae_module.n_codes, 154 | n_groups=vae_module.n_groups, 155 | d_latent=vae_module.d_latent, 156 | **seq2seq_kwargs 157 | ), 158 | 'figaro-expert': lambda: Seq2SeqModule( 159 | description_flavor='description', 160 | **seq2seq_kwargs 161 | ), 162 | 'figaro-no-meta': lambda: Seq2SeqModule( 163 | description_flavor='description', 164 | description_options={ 'instruments': True, 'chords': True, 'meta': False }, 165 | **seq2seq_kwargs 166 | ), 167 | 'figaro-no-inst': lambda: Seq2SeqModule( 168 | description_flavor='description', 169 | description_options={ 'instruments': False, 'chords': True, 'meta': True }, 170 | **seq2seq_kwargs 171 | ), 172 | 'figaro-no-chord': lambda: Seq2SeqModule( 173 | description_flavor='description', 174 | description_options={ 'instruments': True, 'chords': False, 'meta': True }, 175 | **seq2seq_kwargs 176 | ), 177 | 'baseline': lambda: Seq2SeqModule( 178 | description_flavor='none', 179 | **dec_kwargs 180 | ), 181 | }[MODEL]() 182 | 183 | datamodule = model.get_datamodule( 184 | midi_files, 185 | vae_module=vae_module, 186 | batch_size=BATCH_SIZE, 187 | num_workers=N_WORKERS, 188 | pin_memory=True 189 | ) 190 | 191 | checkpoint_callback = pl.callbacks.model_checkpoint.ModelCheckpoint( 192 | monitor='valid_loss', 193 | dirpath=os.path.join(OUTPUT_DIR, MODEL), 194 | filename='{step}-{valid_loss:.2f}', 195 | save_last=True, 196 | save_top_k=2, 197 | every_n_train_steps=1000, 198 | ) 199 | 200 | lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step') 201 | 202 | swa_callback = pl.callbacks.StochasticWeightAveraging(swa_lrs=0.05) 203 | 204 | trainer = pl.Trainer( 205 | devices=1 if device.type == 'cpu' else torch.cuda.device_count(), 206 | accelerator='auto', 207 | profiler='simple', 208 | callbacks=[checkpoint_callback, lr_monitor, swa_callback], 209 | max_epochs=EPOCHS, 210 | max_steps=MAX_TRAINING_STEPS, 211 | log_every_n_steps=max(100, min(25*ACCUMULATE_GRADS, 200)), 212 | val_check_interval=max(500, min(300*ACCUMULATE_GRADS, 1000)), 213 | limit_val_batches=64, 214 | accumulate_grad_batches=ACCUMULATE_GRADS, 215 | gradient_clip_val=1.0, 216 | ) 217 | 218 | trainer.fit(model, datamodule) 219 | 220 | if __name__ == '__main__': 221 | main() -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | 4 | def combine_batches(batches, bars_per_sequence=8, description_flavor='none', device=None): 5 | if device is None: 6 | device = batches[0]['input_ids'].device 7 | 8 | batch_size = batches[0]['input_ids'].size(0) 9 | 10 | zero = torch.zeros(1, device=device, dtype=torch.int) 11 | 12 | contexts = [] 13 | batch_ = {} 14 | 15 | for i in range(batch_size): 16 | curr_bar = 0 17 | ctx = { 18 | 'input_ids': [], 19 | 'bar_ids': [], 20 | 'position_ids': [], 21 | 'slices': [], 22 | 'description': [], 23 | 'desc_bar_ids': [], 24 | 'desc_slices': [], 25 | 'latents': [], 26 | 'latent_slices': [], 27 | 'files': [], 28 | } 29 | 30 | for batch in batches: 31 | if i >= batch['input_ids'].size(0): 32 | continue 33 | 34 | curr = curr_bar 35 | 36 | bar_ids = batch['bar_ids'][i] 37 | starts = (bar_ids >= curr).nonzero() 38 | ends = (bar_ids >= max(1, curr) + bars_per_sequence).nonzero() 39 | if starts.size(0) == 0: 40 | continue 41 | start = starts[0, 0] 42 | 43 | if ends.size(0) == 0: 44 | end = bar_ids.size(0) 45 | curr_bar = 
bar_ids[-1] + 1 46 | else: 47 | end = ends[0, 0] 48 | curr_bar = bar_ids[end] 49 | 50 | if description_flavor in ['description', 'both']: 51 | desc_bar_ids = batch['desc_bar_ids'][i] 52 | desc_start = (desc_bar_ids >= curr).nonzero()[0, 0] 53 | desc_ends = (desc_bar_ids >= max(1, curr) + bars_per_sequence).nonzero() 54 | 55 | if desc_ends.size(0) == 0: 56 | desc_end = desc_bar_ids.size(0) 57 | else: 58 | desc_end = desc_ends[0, 0] 59 | 60 | if description_flavor in ['latent', 'both']: 61 | latent_start = curr 62 | latent_end = max(1, curr) + bars_per_sequence 63 | 64 | 65 | ctx['input_ids'].append(batch['input_ids'][i, start:end]) 66 | ctx['bar_ids'].append(batch['bar_ids'][i, start:end]) 67 | ctx['position_ids'].append(batch['position_ids'][i, start:end]) 68 | ctx['slices'].append((start, end)) 69 | if description_flavor in ['description', 'both']: 70 | ctx['description'].append(batch['description'][i, desc_start:desc_end]) 71 | ctx['desc_bar_ids'].append(batch['desc_bar_ids'][i, desc_start:desc_end]) 72 | ctx['desc_slices'].append((desc_start, desc_end)) 73 | if description_flavor in ['latent', 'both']: 74 | ctx['latents'].append(batch['latents'][i, latent_start:latent_end]) 75 | ctx['latent_slices'].append((latent_start, latent_end)) 76 | ctx['files'].append(batch['files'][i]) 77 | 78 | if len(ctx['files']) <= 1: 79 | continue 80 | 81 | keys = ['input_ids', 'bar_ids', 'position_ids', 'description', 'desc_bar_ids', 'latents'] 82 | for key in keys: 83 | if key in ctx and len(ctx[key]) > 0: 84 | ctx[key] = torch.cat(ctx[key]) 85 | ctx['labels'] = torch.cat([ctx['input_ids'][1:], zero]) 86 | ctx['files'] = '__'.join(ctx['files']).replace('.mid', '') + '.mid' 87 | 88 | contexts.append(ctx) 89 | 90 | batch_['files'] = [ctx['files'] for ctx in contexts] 91 | 92 | for key in ['input_ids', 'bar_ids', 'position_ids', 'description', 'desc_bar_ids', 'latents', 'labels']: 93 | xs = [ctx[key] for ctx in contexts if isinstance(ctx[key], torch.Tensor)] 94 | if len(xs) > 0: 95 | xs = pad_sequence(xs, batch_first=True, padding_value=0) 96 | if not key in ['latents']: 97 | xs = xs.long() 98 | batch_[key] = xs 99 | 100 | return batch_ 101 | 102 | 103 | def medley_iterator(dl, n_pieces=2, n_bars=8, description_flavor='none'): 104 | dl_iter = iter(dl) 105 | try: 106 | while True: 107 | batches = [next(dl_iter) for _ in range(n_pieces)] 108 | batch = combine_batches(batches, 109 | bars_per_sequence=n_bars, 110 | description_flavor=description_flavor 111 | ) 112 | yield batch 113 | except StopIteration: 114 | return 115 | -------------------------------------------------------------------------------- /src/vocab.py: -------------------------------------------------------------------------------- 1 | import pretty_midi 2 | from collections import Counter 3 | import torchtext 4 | from torch import Tensor 5 | 6 | from constants import ( 7 | DEFAULT_VELOCITY_BINS, 8 | DEFAULT_DURATION_BINS, 9 | DEFAULT_TEMPO_BINS, 10 | DEFAULT_POS_PER_QUARTER, 11 | DEFAULT_NOTE_DENSITY_BINS, 12 | DEFAULT_MEAN_VELOCITY_BINS, 13 | DEFAULT_MEAN_PITCH_BINS, 14 | DEFAULT_MEAN_DURATION_BINS 15 | ) 16 | 17 | 18 | from constants import ( 19 | MAX_BAR_LENGTH, 20 | MAX_N_BARS, 21 | 22 | PAD_TOKEN, 23 | UNK_TOKEN, 24 | BOS_TOKEN, 25 | EOS_TOKEN, 26 | MASK_TOKEN, 27 | 28 | TIME_SIGNATURE_KEY, 29 | BAR_KEY, 30 | POSITION_KEY, 31 | INSTRUMENT_KEY, 32 | PITCH_KEY, 33 | VELOCITY_KEY, 34 | DURATION_KEY, 35 | TEMPO_KEY, 36 | CHORD_KEY, 37 | 38 | NOTE_DENSITY_KEY, 39 | MEAN_PITCH_KEY, 40 | MEAN_VELOCITY_KEY, 41 | MEAN_DURATION_KEY, 42 | ) 43 | 44 | 
45 | 46 | class Tokens: 47 | def get_instrument_tokens(key=INSTRUMENT_KEY): 48 | tokens = [f'{key}_{pretty_midi.program_to_instrument_name(i)}' for i in range(128)] 49 | tokens.append(f'{key}_drum') 50 | return tokens 51 | 52 | def get_chord_tokens(key=CHORD_KEY, qualities = ['maj', 'min', 'dim', 'aug', 'dom7', 'maj7', 'min7', 'None']): 53 | pitch_classes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] 54 | 55 | chords = [f'{root}:{quality}' for root in pitch_classes for quality in qualities] 56 | chords.append('N:N') 57 | 58 | tokens = [f'{key}_{chord}' for chord in chords] 59 | return tokens 60 | 61 | def get_time_signature_tokens(key=TIME_SIGNATURE_KEY): 62 | denominators = [2, 4, 8, 16] 63 | time_sigs = [f'{p}/{q}' for q in denominators for p in range(1, MAX_BAR_LENGTH*q + 1)] 64 | tokens = [f'{key}_{time_sig}' for time_sig in time_sigs] 65 | return tokens 66 | 67 | def get_midi_tokens( 68 | instrument_key=INSTRUMENT_KEY, 69 | time_signature_key=TIME_SIGNATURE_KEY, 70 | pitch_key=PITCH_KEY, 71 | velocity_key=VELOCITY_KEY, 72 | duration_key=DURATION_KEY, 73 | tempo_key=TEMPO_KEY, 74 | bar_key=BAR_KEY, 75 | position_key=POSITION_KEY 76 | ): 77 | instrument_tokens = Tokens.get_instrument_tokens(instrument_key) 78 | 79 | pitch_tokens = [f'{pitch_key}_{i}' for i in range(128)] + [f'{pitch_key}_drum_{i}' for i in range(128)] 80 | velocity_tokens = [f'{velocity_key}_{i}' for i in range(len(DEFAULT_VELOCITY_BINS))] 81 | duration_tokens = [f'{duration_key}_{i}' for i in range(len(DEFAULT_DURATION_BINS))] 82 | tempo_tokens = [f'{tempo_key}_{i}' for i in range(len(DEFAULT_TEMPO_BINS))] 83 | bar_tokens = [f'{bar_key}_{i}' for i in range(MAX_N_BARS)] 84 | position_tokens = [f'{position_key}_{i}' for i in range(MAX_BAR_LENGTH*4*DEFAULT_POS_PER_QUARTER)] 85 | 86 | time_sig_tokens = Tokens.get_time_signature_tokens(time_signature_key) 87 | 88 | return ( 89 | time_sig_tokens + 90 | tempo_tokens + 91 | instrument_tokens + 92 | pitch_tokens + 93 | velocity_tokens + 94 | duration_tokens + 95 | bar_tokens + 96 | position_tokens 97 | ) 98 | 99 | class Vocab: 100 | def __init__(self, counter, specials=[PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, MASK_TOKEN], unk_token=UNK_TOKEN): 101 | self.vocab = torchtext.vocab.vocab(counter) 102 | 103 | self.specials = specials 104 | for i, token in enumerate(self.specials): 105 | self.vocab.insert_token(token, i) 106 | 107 | if unk_token in specials: 108 | self.vocab.set_default_index(self.vocab.get_stoi()[unk_token]) 109 | 110 | def to_i(self, token): 111 | return self.vocab.get_stoi()[token] 112 | 113 | def to_s(self, idx): 114 | if idx >= len(self.vocab): 115 | return UNK_TOKEN 116 | else: 117 | return self.vocab.get_itos()[idx] 118 | 119 | def __len__(self): 120 | return len(self.vocab) 121 | 122 | def encode(self, seq): 123 | return self.vocab(seq) 124 | 125 | def decode(self, seq): 126 | if isinstance(seq, Tensor): 127 | seq = seq.numpy() 128 | return self.vocab.lookup_tokens(seq) 129 | 130 | 131 | class RemiVocab(Vocab): 132 | def __init__(self): 133 | midi_tokens = Tokens.get_midi_tokens() 134 | chord_tokens = Tokens.get_chord_tokens() 135 | 136 | self.tokens = midi_tokens + chord_tokens 137 | 138 | counter = Counter(self.tokens) 139 | super().__init__(counter) 140 | 141 | 142 | class DescriptionVocab(Vocab): 143 | def __init__(self): 144 | time_sig_tokens = Tokens.get_time_signature_tokens() 145 | instrument_tokens = Tokens.get_instrument_tokens() 146 | chord_tokens = Tokens.get_chord_tokens() 147 | 148 | bar_tokens = [f'Bar_{i}' for i in 
range(MAX_N_BARS)] 149 | density_tokens = [f'{NOTE_DENSITY_KEY}_{i}' for i in range(len(DEFAULT_NOTE_DENSITY_BINS))] 150 | velocity_tokens = [f'{MEAN_VELOCITY_KEY}_{i}' for i in range(len(DEFAULT_MEAN_VELOCITY_BINS))] 151 | pitch_tokens = [f'{MEAN_PITCH_KEY}_{i}' for i in range(len(DEFAULT_MEAN_PITCH_BINS))] 152 | duration_tokens = [f'{MEAN_DURATION_KEY}_{i}' for i in range(len(DEFAULT_MEAN_DURATION_BINS))] 153 | 154 | self.tokens = ( 155 | time_sig_tokens + 156 | instrument_tokens + 157 | chord_tokens + 158 | density_tokens + 159 | velocity_tokens + 160 | pitch_tokens + 161 | duration_tokens + 162 | bar_tokens 163 | ) 164 | 165 | counter = Counter(self.tokens) 166 | super().__init__(counter) 167 | --------------------------------------------------------------------------------