├── .gitignore
├── README.md
├── Tokenization.pdf
├── midi2scoretransformer
│   ├── chunker.py
│   ├── config.py
│   ├── constants.py
│   ├── data
│   │   └── .gitignore
│   ├── dataset.py
│   ├── evaluation
│   │   └── run_eval.py
│   ├── models
│   │   ├── embedding.py
│   │   ├── model.py
│   │   └── roformer.py
│   ├── score_utils.py
│   ├── tokenizer.py
│   └── utils.py
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | __pycache__
3 | *.ckpt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MIDI2ScoreTransformer
2 | Code for the ISMIR 2024 paper ["End-to-end Piano Performance-MIDI to Score Conversion with Transformers"](https://arxiv.org/abs/2410.00210)
3 | 
4 | ## Installation
5 | The code is written in Python 3.11 and relies on the following packages:
6 | - everything mentioned in `requirements.txt`
7 | - `MuseScore`: https://github.com/musescore/MuseScore (depending on your platform, please adjust the path in `constants.py`)
8 | 
9 | 
10 | Due to delays and difficulties merging changes with upstream versions, we currently require custom versions of the following packages (installed automatically via `requirements.txt`):
11 | 
12 | - `music21`: https://github.com/TimFelixBeyer/music21 (fixes tie-stripping and contains various other tweaks)
13 | - `score_transformer`: https://github.com/TimFelixBeyer/ScoreTransformer (for score_similarity metrics and tokenization comparisons)
14 | 
15 | In addition, you must manually clone and install this package:
16 | - `muster`: https://github.com/TimFelixBeyer/amtevaluation.github.io (fixes various memory leak issues of the original version)
17 | 
18 | ### Datasets
19 | Please use this version of the ASAP-Dataset, as it contains some fixes.
20 | 
21 | - `ASAPDataset`: [https://github.com/TimFelixBeyer/ASAPDataset](https://github.com/TimFelixBeyer/asap-dataset/tree/8cba199e15931975542010a7ea2ff94a6fc9cbee) (contains a few fixes for the ASAP dataset; make sure you select the correct commit for reproducibility and place the `asap-dataset` folder into the `data` folder).
22 | - `ACPAS`: [https://cheriell.github.io/research/ACPAS_dataset/](https://cheriell.github.io/research/ACPAS_dataset/) (only download `metadata_R.csv` and `metadata_S.csv`, place them into a folder called `ACPAS-dataset`, and put it into the `data` folder)
23 | 
24 | Afterwards, your `data` folder should look like this:
25 | ```
26 | data
27 | ├── ACPAS-dataset
28 | │   ├── metadata_R.csv
29 | │   └── metadata_S.csv
30 | └── asap-dataset
31 | ```
32 | 
33 | ### Setup (Inference)
34 | 1. Download `MIDI2ScoreTF.ckpt` from GitHub (see the 'Releases' section) and place it where you like.
35 | 
36 | 
37 | ### Setup (Training)
38 | 1. Run `dataset.py` to preprocess the datasets and populate the cache.
39 | 2. Run `chunker.py` to create the beat-aligned chunks.
40 | 3.
More instructions to follow 41 | -------------------------------------------------------------------------------- /Tokenization.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimFelixBeyer/MIDI2ScoreTransformer/115432bda16ca16e0fec2e9465788f2ba369971f/Tokenization.pdf -------------------------------------------------------------------------------- /midi2scoretransformer/chunker.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains code for chunking MIDI and MusicXML files into pseudo-measures using 3 | beat-level annotations from the ASAP dataset to align the MIDI and MusicXML. 4 | 5 | We use a greedy algorithm to align the MIDI and MusicXML files by moving notes that are 6 | close to the beat boundaries to the next/previous measure if that improves the alignment. 7 | The resulting chunks are saved as JSON files next to the performance-MIDI files. 8 | """ 9 | 10 | import json 11 | import os 12 | import sys 13 | 14 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 15 | 16 | import pretty_midi 17 | from joblib import Parallel, delayed 18 | 19 | from dataset import ASAPDataset 20 | from tokenizer import MultistreamTokenizer 21 | 22 | 23 | def make_measures(midi, midi_score, mxl, annots, swap=True): 24 | measures = {"midi": [], "mxl": []} 25 | # TODO: Maybe do beats instead of downbeats? 26 | swap_tol = 0.05 27 | measures = {"midi": [], "mxl": []} 28 | i, j = (0, 0) 29 | for sb_seconds in annots["midi_score_beats"]: 30 | # Convert the score beats from seconds to musical time using 31 | # the provided MIDI scores 32 | sb = midi_score.time_to_tick(sb_seconds) / midi_score.resolution 33 | measures["mxl"].append([]) 34 | while j < len(mxl) and mxl[j].offset < sb: 35 | measures["mxl"][-1].append((j, mxl[j].pitch.midi)) 36 | j += 1 37 | 38 | # Use the performance beats to split the MIDI into pseudo-measures of beat length 39 | for pb in annots["performance_beats"]: 40 | measures["midi"].append([]) 41 | while i < len(midi) and midi[i].start < pb - swap_tol: 42 | measures["midi"][-1].append((i, midi[i].pitch)) 43 | i += 1 44 | 45 | measures["midi"].append([]) 46 | measures["mxl"].append([]) 47 | while i < len(midi): 48 | measures["midi"][-1].append((i, midi[i].pitch)) 49 | i += 1 50 | while j < len(mxl): 51 | measures["mxl"][-1].append((j, mxl[j].pitch.midi)) 52 | j += 1 53 | 54 | if not swap: 55 | for i in range(len(measures["midi"])): 56 | measures["midi"][i] = [i_ for i_, p in measures["midi"][i]] 57 | measures["mxl"][i] = [i_ for i_, p in measures["mxl"][i]] 58 | measures["swapped"] = False 59 | return measures 60 | from collections import Counter 61 | 62 | n_swaps = 0 63 | swap_tol = 0.5 64 | for i in range(len(measures["midi"]) - 1): 65 | swapped = True 66 | while swapped: 67 | swapped = False 68 | # Figure out which MIDI pitches can be moved forward/backward by one beat 69 | # because that would yield better alignment. Only considers notes within 70 | # 0.5s of the beat boundary. 
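            # Worked example (hypothetical pitches): suppose beat i contains MIDI
            # pitches [60, 64, 67, 72] while the score expects [60, 64, 67], and
            # beat i+1 contains MIDI [65, 69] while the score expects [65, 69, 72].
            # Then, with the Counters computed below:
            #   too_much             = Counter({72: 1})  # surplus in this beat
            #   lacking_next         = Counter({72: 1})  # missing from the next beat
            #   can_be_moved_forward = Counter({72: 1})
            # so if the last MIDI note of beat i has pitch 72 and starts within
            # swap_tol seconds of the beat boundary, it is moved to beat i+1.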
71 | c_midi = Counter([p for _, p in measures["midi"][i]]) 72 | c_mxl = Counter([p for _, p in measures["mxl"][i]]) 73 | 74 | c_midi_next = Counter([p for _, p in measures["midi"][i + 1]]) 75 | c_mxl_next = Counter([p for _, p in measures["mxl"][i + 1]]) 76 | too_much = c_midi - c_mxl 77 | lacking = c_mxl - c_midi 78 | too_much_next = c_midi_next - c_mxl_next 79 | lacking_next = c_mxl_next - c_midi_next 80 | can_be_moved_forward = too_much & lacking_next 81 | can_be_moved_backward = too_much_next & lacking 82 | for pitch in can_be_moved_forward: 83 | last_j, last_p = measures["midi"][i][-1] 84 | # Only swap notes if they are within 0.5s of the beat boundary 85 | if ( 86 | last_p == pitch 87 | and annots["performance_beats"][i] - midi[last_j].start < swap_tol 88 | ): 89 | measures["midi"][i + 1].insert(0, measures["midi"][i].pop()) 90 | n_swaps += 1 91 | swapped = True 92 | break 93 | 94 | for pitch in can_be_moved_backward.keys(): 95 | first_j, first_p = measures["midi"][i + 1][0] 96 | # ---only swap notes if they are within 0.5s of the beat boundary 97 | if ( 98 | first_p == pitch 99 | and midi[first_j].start - annots["performance_beats"][i] < swap_tol 100 | ): 101 | measures["midi"][i].append(measures["midi"][i + 1].pop(0)) 102 | n_swaps += 1 103 | swapped = True 104 | break 105 | 106 | # print("Swaps", n_swaps, n_swaps/len(midi)) 107 | # If we had to swap a lot of notes, we probably messed up and should just return 108 | # the raw alignment... 109 | if n_swaps / len(midi) > 0.1: 110 | return make_measures(midi, midi_score, mxl, annots, swap=False) 111 | for i in range(len(measures["midi"])): 112 | measures["midi"][i] = sorted([j for j, p in measures["midi"][i]]) 113 | measures["mxl"][i] = sorted([j for j, p in measures["mxl"][i]]) 114 | assert max([max(m + [0]) for m in measures["midi"]]) == len(midi) - 1 115 | assert sum(len(m) for m in measures["midi"]) == len(midi) 116 | assert max([max(m + [0]) for m in measures["mxl"]]) == len(mxl) - 1 117 | assert sum(len(m) for m in measures["mxl"]) == len(mxl) 118 | measures["swapped"] = True 119 | return measures 120 | 121 | 122 | def handle_file(midi_path, mxl_path, save_path): 123 | annots = annotations[midi_path.replace("./data/asap-dataset/", "")] 124 | if not annots["score_and_performance_aligned"]: 125 | return 126 | import warnings 127 | 128 | with warnings.catch_warnings(): 129 | warnings.simplefilter("ignore") 130 | m, s = MultistreamTokenizer.mxl_to_list(mxl_path) 131 | measures = make_measures( 132 | MultistreamTokenizer.midi_to_list(midi_path), 133 | pretty_midi.PrettyMIDI( 134 | mxl_path.replace("xml_score.musicxml", "midi_score.mid") 135 | ), 136 | m, 137 | annots, 138 | ) 139 | 140 | if os.path.exists(save_path): 141 | try: 142 | prev = json.load(open(save_path)) 143 | except json.decoder.JSONDecodeError: 144 | print(save_path) 145 | raise 146 | # Print differences 147 | if prev["midi"] != measures["midi"]: 148 | print("midi", prev["midi"], measures["midi"]) 149 | if prev["mxl"] != measures["mxl"]: 150 | print("mxl", prev["mxl"], measures["mxl"]) 151 | 152 | json.dump(measures, open(save_path, "w")) 153 | 154 | 155 | if __name__ == "__main__": 156 | annotations = json.load(open("data/asap-dataset/asap_annotations.json")) 157 | skip = set(["data/asap-dataset/Glinka/The_Lark"]) 158 | paths = [] 159 | for root, dirs, files in os.walk("data/asap-dataset/"): 160 | for file in files: 161 | if file.endswith(".musicxml") and root not in skip: 162 | mxl_path = os.path.join(root, file) 163 | break 164 | else: 165 | continue 166 | for file 
in files: 167 | if file.endswith(".mid") and not file.startswith("midi_score"): 168 | midi_path = os.path.join(root, file) 169 | save_path = os.path.join(root, file.replace(".mid", "_chunks.json")) 170 | paths.append((midi_path, mxl_path, save_path)) 171 | 172 | q = ASAPDataset("./data/", "all") 173 | midi_paths = [ 174 | q.metadata.iloc[idx]["performance_MIDI_external"].replace( 175 | "{ASAP}", f"{q.data_dir}asap-dataset" 176 | ) 177 | for idx in range(0, len(q)) 178 | ] 179 | mxl_paths = [ 180 | os.path.join(os.path.dirname(m), "xml_score.musicxml") for m in midi_paths 181 | ] 182 | save_paths = [m.replace(".mid", "_chunks.json") for m in midi_paths] 183 | paths = list(zip(midi_paths, mxl_paths, save_paths)) 184 | Parallel(n_jobs=min(16, len(paths)), verbose=10)( 185 | delayed(handle_file)(midi_path, mxl_path, save_path) 186 | for midi_path, mxl_path, save_path in paths 187 | ) 188 | -------------------------------------------------------------------------------- /midi2scoretransformer/config.py: -------------------------------------------------------------------------------- 1 | from transformers import RoFormerConfig 2 | 3 | 4 | # TODO: Make this single source of truth 5 | FEATURES = { 6 | "onset": {'vocab_size': 200, 'loss_weight': 1.0, 'ignore_index': -100, 'min': 0, 'max': 8, 'step_size': 1/24}, 7 | "offset": {'vocab_size': 145, 'loss_weight': 0.25, 'ignore_index': -100, 'min': 0, 'max': 6, 'step_size': 1/24}, 8 | "downbeat": {'vocab_size': 146, 'loss_weight': 0.4, 'ignore_index': -100, 'min': -1/24, 'max': 6, 'step_size': 1/24}, 9 | "duration": {'vocab_size': 97, 'loss_weight': 0.3, 'ignore_index': -100, 'min': 0, 'max': 96}, 10 | "pitch": {'vocab_size': 128, 'loss_weight': 1.0, 'ignore_index': -100, 'min': 0, 'max': 127}, 11 | "accidental": {'vocab_size': 7, 'loss_weight': 0.5, 'ignore_index': 6, 'min': 0, 'max': 6}, 12 | "keysignature": {'vocab_size': 16, 'loss_weight': 0.5, 'ignore_index': 15, 'min': 0, 'max': 15}, 13 | "velocity": {'vocab_size': 8, 'loss_weight': 0.0, 'ignore_index': -100, 'min': 0, 'max': 127}, 14 | "grace": {'vocab_size': 2, 'loss_weight': 1.0, 'ignore_index': -100, 'min': 0, 'max': 1}, 15 | "trill": {'vocab_size': 2, 'loss_weight': 1.0, 'ignore_index': -100, 'min': 0, 'max': 1}, 16 | "staccato": {'vocab_size': 2, 'loss_weight': 0.15, 'ignore_index': -100, 'min': 0, 'max': 1}, 17 | "voice": {'vocab_size': 9, 'loss_weight': 0.3, 'ignore_index': 0, 'min': 0, 'max': 8}, 18 | "stem": {'vocab_size': 4, 'loss_weight': 0.2, 'ignore_index': 3, 'min': 0, 'max': 3}, 19 | "hand": {'vocab_size': 3, 'loss_weight': 0.25, 'ignore_index': 2, 'min': 0, 'max': 2}, 20 | } 21 | 22 | class MyModelConfig(RoFormerConfig): 23 | def __init__( 24 | self, 25 | input_streams=4, 26 | in_onset_vocab_size=200, 27 | in_duration_vocab_size=200, 28 | in_pitch_vocab_size=128, 29 | in_velocity_vocab_size=8, 30 | out_offset_vocab_size=FEATURES['offset']['vocab_size'], 31 | out_downbeat_vocab_size=FEATURES['downbeat']['vocab_size'], 32 | out_duration_vocab_size=FEATURES['duration']['vocab_size'], 33 | out_pitch_vocab_size=FEATURES['pitch']['vocab_size'], 34 | out_accidental_vocab_size=FEATURES['accidental']['vocab_size'], # need one class as ignore class for untagged inputs 35 | out_keysignature_vocab_size=FEATURES['keysignature']['vocab_size'], # need one class as ignore class for untagged inputs 36 | out_velocity_vocab_size=FEATURES['velocity']['vocab_size'], 37 | out_grace_vocab_size=FEATURES['grace']['vocab_size'], 38 | out_trill_vocab_size=FEATURES['trill']['vocab_size'], 39 | 
out_staccato_vocab_size=FEATURES['staccato']['vocab_size'], 40 | out_voice_vocab_size=FEATURES['voice']['vocab_size'], # need one class as ignore class for untagged inputs 41 | out_stem_vocab_size=FEATURES['stem']['vocab_size'], # need one class as ignore class for untagged inputs 42 | out_hand_vocab_size=FEATURES['hand']['vocab_size'], # need one class as ignore class for untagged inputs 43 | is_autoregressive=False, 44 | positional_encoding="RoPE", 45 | conditional_sampling=False, 46 | bias=True, 47 | **kwargs, 48 | ): 49 | super().__init__(**kwargs) 50 | self.input_streams = input_streams 51 | self.in_onset_vocab_size = in_onset_vocab_size 52 | self.in_duration_vocab_size = in_duration_vocab_size 53 | self.in_pitch_vocab_size = in_pitch_vocab_size 54 | self.in_velocity_vocab_size = in_velocity_vocab_size 55 | self.out_offset_vocab_size = out_offset_vocab_size 56 | self.out_downbeat_vocab_size = out_downbeat_vocab_size 57 | self.out_duration_vocab_size = out_duration_vocab_size 58 | self.out_pitch_vocab_size = out_pitch_vocab_size 59 | self.out_accidental_vocab_size = out_accidental_vocab_size 60 | self.out_keysignature_vocab_size = out_keysignature_vocab_size 61 | self.out_velocity_vocab_size = out_velocity_vocab_size 62 | self.out_grace_vocab_size = out_grace_vocab_size 63 | self.out_trill_vocab_size = out_trill_vocab_size 64 | self.out_staccato_vocab_size = out_staccato_vocab_size 65 | self.out_voice_vocab_size = out_voice_vocab_size 66 | self.out_stem_vocab_size = out_stem_vocab_size 67 | self.out_hand_vocab_size = out_hand_vocab_size 68 | # TODO: Move to this 69 | # self.out_vocab_sizes = { 70 | # "offset": out_offset_vocab_size, 71 | # "downbeat": out_downbeat_vocab_size, 72 | # "duration": out_duration_vocab_size, 73 | # "pitch": out_pitch_vocab_size, 74 | # "accidental": out_accidental_vocab_size, 75 | # "keysignature": out_keysignature_vocab_size, 76 | # "velocity": out_velocity_vocab_size, 77 | # "grace": out_grace_vocab_size, 78 | # "trill": out_trill_vocab_size, 79 | # "staccato": out_staccato_vocab_size, 80 | # "voice": out_voice_vocab_size, 81 | # "stem": out_stem_vocab_size, 82 | # "hand": out_hand_vocab_size, 83 | # } 84 | self.is_autoregressive = is_autoregressive 85 | assert positional_encoding in ["RoPE", "ALiBi", "absolute"] 86 | self.positional_encoding = positional_encoding 87 | self.conditional_sampling = conditional_sampling 88 | self.bias = bias 89 | -------------------------------------------------------------------------------- /midi2scoretransformer/constants.py: -------------------------------------------------------------------------------- 1 | """Constants used in the project""" 2 | 3 | 4 | # Binary paths -------------------- 5 | MUSESCORE_PATH = "/Applications/MuseScore 4.app/Contents/MacOS/mscore" 6 | LD_PATH = ""#"miniconda3/envs/thesis/x86_64-conda-linux-gnu/sysroot/usr/lib64:" 7 | 8 | 9 | # Dataset constants --------------- 10 | 11 | # Not parsable by music21 12 | SKIP = set( 13 | ["{ASAP}/Glinka/The_Lark/Denisova10M.mid", "{ASAP}/Glinka/The_Lark/Kleisen07M.mid"] 14 | ) 15 | 16 | # Not aligned correctly or other issues 17 | TO_IGNORE_INDICES = [152, 153, 154, 165, 166, 179, 180, 181, 332, 333, 334, 335, 349, 350, 18 | 351, 418, 419, 420, 426, 428, 429, 430, 472, 473, 474, 489, 490, 491, 19 | 516, 517, 518, 519, 520, 521, 522, 540, 541, 560, 609, 774, 798, 799, 20 | 800, 801, 802, 803, 819, 920, 921, 935, 936, 937, 938, 939, 940, 941, 21 | 979, 980, 981, 997, 998, 999, 1012, 1013, 1014, 1017, 1018] 22 | 23 | 24 | # To keep eval consistent, we hardcode 
test piece ids here. 25 | TEST_PIECE_IDS = [15, 78, 159, 172, 254, 288, 322, 374, 395, 399, 411, 418, 452, 478] 26 | 27 | # They were originally obtained via: 28 | # data = pd.concat(['data_real, data_synthetic]) 29 | # # Initial filtering 30 | # data = data[(data["source"] == "ASAP") & data["aligned"]] 31 | # data = data[~data["performance_MIDI_external"].isin(SKIP)] 32 | # data = data[~data["performance_MIDI_external"].isin(UNALIGNED)] 33 | # data = data.drop_duplicates(subset=["performance_MIDI_external"]) 34 | # # Filter by annotations 35 | # data.reset_index(inplace=True) 36 | # data.drop(TO_IGNORE_INDICES, inplace=True) 37 | # # Select first piece from each composer for testing (may have multiple performances) 38 | # TEST_PIECE_IDS = data.groupby("composer").first()["piece_id"].values 39 | 40 | -------------------------------------------------------------------------------- /midi2scoretransformer/data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /midi2scoretransformer/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains a torch.utils.data.Dataset wrapper for the 3 | ASAP dataset, which is a collection of MIDI files and corresponding MusicXML files. 4 | 5 | The first run is significantly slower as metadata and caches are built. 6 | Subsequent runs are much faster. 7 | """ 8 | from functools import lru_cache 9 | from tqdm import tqdm 10 | from joblib import Parallel, delayed 11 | import hashlib 12 | import json 13 | import os 14 | import random 15 | from typing import Dict, Optional, Tuple, Union 16 | 17 | import pandas as pd 18 | import torch 19 | from music21 import key, pitch 20 | from torch.utils.data import Dataset 21 | 22 | from constants import SKIP, TEST_PIECE_IDS, TO_IGNORE_INDICES 23 | from utils import cat_dict, cut_pad 24 | from tokenizer import MultistreamTokenizer 25 | 26 | 27 | class ASAPDataset(Dataset): 28 | """Implements a torch-compatible interface to the ASAP Dataset""" 29 | def __init__( 30 | self, 31 | data_dir: str = "./data/", 32 | split: str = "train", 33 | seq_length: Optional[int] = None, 34 | cache: bool = True, 35 | padding: str = 'per-beat', 36 | augmentations: Dict[str, Union[float, Dict[str, float]]] = {}, 37 | return_continous: bool = False, 38 | return_paths: bool = False, 39 | id: str="diffusion_2024_04_18", 40 | ): 41 | """ 42 | Parameters 43 | ---------- 44 | data_dir : str (default='./data/') 45 | Path to the data directory 46 | split : str (default='train') 47 | Which split to use. One of ["all", "train", "validation", "test"] 48 | seq_length : Optional[int] 49 | If not None, will cut/pad to this length 50 | cache : bool (default=True) 51 | Whether to cache the parsed MIDI/MXL files 52 | padding : str (default='per-beat') 53 | How to pad the data. One of ["per-beat", "end", None] 54 | augmentations : Dict[str, Union[float, Dict[str, float]]] 55 | Augmentations to apply to the data. If a key is not given, the augmentation 56 | is ignored. 57 | Possible keys are ["transpose", "tempo_jitter", "onset_jitter", "random_crop", "random_shift"] 58 | - transpose: int 59 | Whether to transpose the data by a random amount up to the given value. 
60 | - random_crop: Union[bool, int] 61 | Whether to crop the data to a random length between 16 62 | - tempo_jitter: Tuple[float, float] 63 | Whether to jitter the tempo by a random amount between the given values. 64 | - onset_jitter: float 65 | Whether to jitter the onset by a random amount according to the given value. 66 | Multiplicative (the intra-onset intervals are scaled by N(1,onset_jitter^2)). 67 | return_continous : bool (default=False) 68 | Whether to return the data as continous values, or as a dictionary of bucketed tensors. 69 | id : str (default="diffusion_2023_10_13") 70 | A unique identifier for the dataset. This is used to ensure that the cache is not 71 | reused between different datasets. 72 | """ 73 | # Get metadata 74 | self.data_dir = data_dir 75 | self.split = split 76 | self.seq_length = seq_length 77 | self.cache = cache 78 | assert padding in ('per-beat', 'end', None) 79 | self.padding = padding 80 | self.augmentations = augmentations 81 | self.return_continous = return_continous 82 | self.return_paths = return_paths 83 | self.id = id 84 | self.metadata = self._load_metadata(data_dir, split) 85 | 86 | if self.cache: 87 | os.makedirs(os.path.join(data_dir, "cache"), exist_ok=True) 88 | 89 | def _load_metadata(self, data_dir: str, split: str) -> pd.DataFrame: 90 | data_real = pd.read_csv(data_dir + "/ACPAS-dataset/metadata_R.csv") 91 | data_synthetic = pd.read_csv(data_dir + "/ACPAS-dataset/metadata_S.csv") 92 | asap_annotations = json.load( 93 | open(data_dir + "/asap-dataset/asap_annotations.json") 94 | ) 95 | UNALIGNED = set( 96 | "{ASAP}/" + k 97 | for k, v in asap_annotations.items() 98 | if not v["score_and_performance_aligned"] 99 | ) 100 | # Filter 101 | data = pd.concat([data_real, data_synthetic]) 102 | # Initial filtering 103 | data = data[(data["source"] == "ASAP") & data["aligned"]] 104 | data = data[~data["performance_MIDI_external"].isin(SKIP)] 105 | data = data[~data["performance_MIDI_external"].isin(UNALIGNED)] 106 | data = data.drop_duplicates(subset=["performance_MIDI_external"]) 107 | # Filter by annotations 108 | data.reset_index(inplace=True) 109 | data.drop(TO_IGNORE_INDICES, inplace=True) 110 | 111 | # Select first piece from each composer for testing (may have multiple performances) 112 | # test_ids = data.groupby("composer").first()["piece_id"].values 113 | test_idx = data["piece_id"].isin(TEST_PIECE_IDS) 114 | 115 | if split == "all": 116 | return data 117 | elif split == "test": 118 | return data[test_idx] 119 | elif split == "validation": 120 | return data[(data["piece_id"] % 10 == 0) & (~data["piece_id"].isin(TEST_PIECE_IDS))] 121 | elif split == "train": 122 | d = data[(data["piece_id"] % 10 != 0) & (~data["piece_id"].isin(TEST_PIECE_IDS))] 123 | try: 124 | self.lengths = [] 125 | for idx in range(len(d)): 126 | sample = d.iloc[idx] 127 | sample_path = sample["performance_MIDI_external"].replace( 128 | "{ASAP}", f"{self.data_dir}asap-dataset" 129 | ) 130 | # fmt: off 131 | pkl_file = os.path.join(self.data_dir, "cache", f"{sha256(sample_path + self.id)}_.pkl") 132 | # fmt: on 133 | input_stream, output_stream = torch.load(pkl_file, weights_only=False) 134 | self.lengths.append(len(input_stream['onset'])) 135 | self.lengths = torch.FloatTensor(self.lengths) 136 | # When creating the cache for the first time, we don't have the lengths yet, 137 | # so we will just sample uniformly. 
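                # (self.lengths is used in __getitem__: training indices are drawn via
                # torch.multinomial(self.lengths, 1), i.e. with probability proportional
                # to sequence length, so longer performances are sampled more often.)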
138 | except FileNotFoundError as e: 139 | self.lengths = torch.ones(len(d)) 140 | return d 141 | else: 142 | raise ValueError(f"Invalid split: {split}") 143 | 144 | def __len__(self) -> int: 145 | return len(self.metadata) 146 | 147 | def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: 148 | if self.split == "train": 149 | idx: int = torch.multinomial(self.lengths, 1, replacement=True).item() 150 | sample = self.metadata.iloc[idx] 151 | sample_path = sample["performance_MIDI_external"].replace( 152 | "{ASAP}", f"{self.data_dir}asap-dataset" 153 | ) 154 | sample_dir = os.path.dirname(sample_path) 155 | 156 | pkl_file = os.path.join(self.data_dir, "cache", f"{sha256(sample_path + self.id)}.pkl") 157 | 158 | if (not self.cache) or (not os.path.exists(pkl_file)): 159 | score_path = sample_dir + "/xml_score.musicxml" 160 | input_stream = MultistreamTokenizer.parse_midi(sample_path) 161 | output_stream = MultistreamTokenizer.parse_mxl(score_path) 162 | torch.save((input_stream, output_stream), pkl_file) 163 | 164 | input_stream, output_stream = torch.load(pkl_file, weights_only=False) 165 | 166 | if self.augmentations.get("transpose", False): 167 | shift = random.randint(-6, 6) 168 | input_stream["pitch"], output_stream["pitch"], output_stream["accidental"], output_stream["keysignature"] = self._transpose( 169 | shift, 170 | midi_stream=input_stream["pitch"], 171 | mxl_stream=output_stream["pitch"], 172 | accidental_stream=output_stream["accidental"], 173 | keysignature_stream=output_stream["keysignature"] 174 | ) 175 | 176 | if (v := self.augmentations.get("tempo_jitter", False)): 177 | jitter_onset = random.uniform(*v) 178 | jitter_duration = jitter_onset + random.uniform(-0.05, 0.05) 179 | input_stream["onset"] = input_stream["onset"] * jitter_onset 180 | input_stream["duration"] = input_stream["duration"] * jitter_duration 181 | if (v := self.augmentations.get("onset_jitter", False)): 182 | jitter = 1 + torch.randn(input_stream["onset"].shape) * v 183 | # adjust intervals between onsets 184 | inter_note_intervals = torch.diff(input_stream["onset"], prepend=torch.tensor([0]), dim=0) 185 | input_stream["onset"] = torch.cumsum(inter_note_intervals * jitter, dim=0) 186 | if (v := self.augmentations.get("velocity_jitter", False)): 187 | input_stream["velocity"] += torch.round(torch.randn(input_stream["velocity"].shape) * v).long() 188 | input_stream["velocity"] = torch.clamp(input_stream["velocity"], 1, 127) 189 | 190 | if self.return_continous: 191 | return input_stream, output_stream 192 | 193 | input_stream = MultistreamTokenizer.bucket_midi(input_stream) 194 | output_stream = MultistreamTokenizer.bucket_mxl(output_stream) 195 | 196 | if self.seq_length is not None: 197 | seq_length = self.seq_length 198 | else: 199 | # need buffer due to padding with 'per-beat' option 200 | seq_length = max(len(input_stream['onset']), len(output_stream['offset'])) + 256 201 | 202 | chunk_annots = json.load(open(sample_path.replace(".mid", "_chunks.json"))) 203 | 204 | if (v := self.augmentations.get("random_crop", False)): 205 | min_beats = 16 206 | if v is True: 207 | n_0 = random.randint(0, max(len(chunk_annots["midi"]) - min_beats, 0)) 208 | elif isinstance(v, int): 209 | average = sum([len(x) for x in chunk_annots["midi"]])/len(chunk_annots["midi"]) 210 | n_0 = random.choice(range(0, max(len(chunk_annots["midi"]) - min_beats, 1), max(1, int(v/average)))) 211 | else: 212 | raise ValueError("Invalid random_crop value") 213 | else: 214 | n_0 = 0 215 | 216 | def 
process_chunk(stream, chunk, padding, length): 217 | if padding == "per-beat": 218 | return {k: cut_pad(v[chunk], length, 0) for k, v in stream.items()} 219 | return {k: v[chunk] for k, v in stream.items()} 220 | 221 | new_input_stream = None # just a sentry 222 | for midi_chunk, mxl_chunk in zip( 223 | chunk_annots["midi"][n_0:], chunk_annots["mxl"][n_0:] 224 | ): 225 | length = max(len(midi_chunk), len(mxl_chunk)) 226 | if new_input_stream is not None and len(new_input_stream["onset"]) + length > seq_length + self.augmentations.get("random_shift", 0): 227 | break 228 | in_chunk = process_chunk(input_stream, midi_chunk, self.padding, length) 229 | out_chunk = process_chunk(output_stream, mxl_chunk, self.padding, length) 230 | if new_input_stream is None: 231 | new_input_stream = in_chunk 232 | new_output_stream = out_chunk 233 | else: 234 | new_input_stream = cat_dict(new_input_stream, in_chunk) 235 | new_output_stream = cat_dict(new_output_stream, out_chunk) 236 | if (v := self.augmentations.get("random_shift", False)): 237 | shift = random.randint(0, v - 1) 238 | for k, v in new_input_stream.items(): 239 | new_input_stream[k] = v[shift:] 240 | for k, v in new_output_stream.items(): 241 | new_output_stream[k] = v[shift:] 242 | if self.padding is not None: 243 | # Cut/Pad to exact seq-length 244 | for k, v in new_input_stream.items(): 245 | input_stream[k] = cut_pad(v, seq_length, 0) 246 | for k, v in new_output_stream.items(): 247 | output_stream[k] = cut_pad(v, seq_length, 0) 248 | if self.return_paths: 249 | return input_stream, output_stream, sample_path, sample_dir + "/xml_score.musicxml" 250 | return input_stream, output_stream 251 | 252 | @lru_cache(None) 253 | @staticmethod 254 | def _accidental_map(p, a, i): 255 | def alter_map(accidental): 256 | alter_to_value = {None: 5, -2.0: 0, -1.0: 1, 0.0: 2, 1.0: 3, 2.0: 4} 257 | alter = accidental.alter if isinstance(accidental, pitch.Accidental) else accidental 258 | # 6 if not known 259 | return alter_to_value.get(alter, 6) 260 | 261 | if i is None: 262 | return a 263 | accidental_mapping = {0: 2, 1: 1, 2: 0, 3: -1, 4: -2} 264 | alter = accidental_mapping.get(a, 0) 265 | p_obj = pitch.Pitch() 266 | p_obj.midi = p + alter 267 | if a in accidental_mapping: 268 | p_obj.accidental = -accidental_mapping[a] 269 | p_obj.spellingIsInferred = False 270 | tp = p_obj.transpose(i) 271 | accepted_pitches = { 272 | 'C', 'B#', 'D--', 'C#', 'B##', 'D-', 'D', 'C##', 'E--', 'D#', 'E-', 'F--', 273 | 'E', 'D##', 'F-', 'F', 'E#', 'G--', 'F#', 'E##', 'G-', 'G', 'F##', 'A--', 274 | 'G#', 'A-', 'A', 'G##', 'B--', 'A#', 'B-', 'C--', 'B', 'A##', 'C-' 275 | } 276 | if tp.name not in accepted_pitches: 277 | return None 278 | return alter_map(tp.accidental) 279 | 280 | @lru_cache(None) 281 | @staticmethod 282 | def _ks_map(ks: int, i: str): 283 | if i is None: 284 | return ks 285 | if ks == 15: 286 | return None 287 | k_obj = key.KeySignature(ks - 7) 288 | ns = k_obj.transpose(i).sharps 289 | if not -7 <= ns <= 7: 290 | return None 291 | return ns + 7 292 | 293 | @staticmethod 294 | def _transpose(shift, midi_stream, mxl_stream=None, accidental_stream=None, keysignature_stream=None): 295 | """Transpose pitches by a random amount between -6 and 6. If accidental_stream and 296 | keysignature_stream are provided, they will be adjusted following the procedure 297 | in https://arxiv.org/pdf/2107.14009.pdf 298 | 299 | In more detail, pitches are simply shifted by the desired amount. 300 | Then, all musical intervals with that shift are tried. 
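        (For a shift of +3 semitones, for example, the candidate spellings are a minor
        third "m3" and an augmented second "A2", each implying different accidentals.)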
301 | If the transposed accidentals or key signatures are invalid, they are set to ignore_index. 302 | Among the valid transpositions, the one with the lowest number of accidentals is selected. 303 | 304 | Parameters 305 | ---------- 306 | shift : int 307 | The amount of transposition 308 | midi_stream : torch.Tensor 309 | The MIDI pitch stream 310 | mxl_stream : torch.Tensor|None 311 | The MusicXML pitch stream, if provided. 312 | accidental_stream : torch.Tensor|None 313 | The accidental stream, if provided. 314 | keysignature_stream : torch.Tensor|None 315 | The key signature stream, if provided. 316 | 317 | Returns 318 | ------- 319 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] 320 | The transposed MIDI pitch stream, the transposed MusicXML pitch stream 321 | the transposed accidental stream, and the transposed key signature stream. 322 | """ 323 | assert (accidental_stream is None) == (keysignature_stream is None), "Either both or none of the accidentals and key signatures should be provided" 324 | if accidental_stream is not None: 325 | assert mxl_stream is not None, "Need mxl pitch stream alongside accidentals" 326 | 327 | def shift_pitch(stream, shift): 328 | stream = stream + shift 329 | stream[stream > 127] -= 12 330 | stream[stream < 0] += 12 331 | return stream 332 | 333 | midi_stream = shift_pitch(midi_stream, shift) 334 | results = [midi_stream] 335 | if mxl_stream is not None and accidental_stream is not None and keysignature_stream is not None: 336 | if shift == 0 and random.random() < 0.5: 337 | # Always include the original version some of the time. 338 | pass 339 | else: 340 | INTERVALS = { 341 | -6: ["d5", "A4"], 342 | -5: ["P5", "d6", "AA4"], 343 | -4: ["m6", "A5"], 344 | -3: ["M6", "d7", "AA5"], 345 | -2: ["m7", "A6"], 346 | -1: ["d1", "M7", "AA6"], 347 | 0: [None, "P1", "d2", "A7"], 348 | 1: ["m2", "A1"], 349 | 2: ["M2", "d3", "AA1"], 350 | 3: ["m3", "A2"], 351 | 4: ["M3", "d4", "AA2"], 352 | 5: ["P4", "A3"], 353 | 6: ["d5", "A4"], 354 | } 355 | intervals = INTERVALS[shift] 356 | m = mxl_stream.numpy() 357 | a = accidental_stream.numpy() 358 | ks = keysignature_stream.unique() 359 | best_error = float('inf') 360 | errors = dict() 361 | for interv in intervals: 362 | accidental_cand = torch.zeros_like(accidental_stream) 363 | keysignature_cand: torch.Tensor = keysignature_stream.clone() 364 | for k in ks: 365 | val = ASAPDataset._ks_map(int(k), interv) 366 | if val is None: # invalid key signature 367 | accidental_cand.fill_(6) 368 | keysignature_cand.fill_(15) 369 | break 370 | keysignature_cand[keysignature_cand == k] = val 371 | else: # valid keysignatures, so we look for accidentals 372 | for i in range(len(mxl_stream)): 373 | val = ASAPDataset._accidental_map(m[i], a[i], interv) 374 | if val is None: # invalid accidental 375 | accidental_cand.fill_(6) 376 | keysignature_cand.fill_(15) 377 | break 378 | accidental_cand[i] = val * 1.0 379 | error = (accidental_cand[accidental_cand != 5] - 2).abs().sum() 380 | errors[interv] = error 381 | # print(f"Error: {error} for {interv}", accidental_cand) 382 | if error < best_error: 383 | best_error = error 384 | accidental_stream = accidental_cand 385 | keysignature_stream = keysignature_cand 386 | if mxl_stream is not None: 387 | mxl_stream = shift_pitch(mxl_stream, shift) 388 | results.append(mxl_stream) 389 | if accidental_stream is not None: 390 | results.append(accidental_stream) 391 | results.append(keysignature_stream) 392 | return results 393 | 394 | 395 | def sha256(string: str) -> str: 396 | h = 
hashlib.new("sha256") 397 | h.update(string.encode()) 398 | return h.hexdigest() 399 | 400 | if __name__ == "__main__": 401 | print("Initializing ASAPDataset") 402 | for split in ("all", "train", "validation", "test"): 403 | q = ASAPDataset("./data/", split, seq_length=None, padding=None, cache=True, return_continous=False) 404 | print(split, len(q)) 405 | 406 | q = ASAPDataset("./data/", "all", seq_length=None, padding=None, cache=True, return_continous=True) 407 | print("Filling cache") 408 | # You can parallelize this loop, but you have to comment out both uses of lru_cache 409 | for i in tqdm(range(len(q))): 410 | q[i] 411 | # Parallel(n_jobs=8, verbose=10)(delayed(q.__getitem__)(i) for i in tqdm(range(len(q)))) 412 | -------------------------------------------------------------------------------- /midi2scoretransformer/evaluation/run_eval.py: -------------------------------------------------------------------------------- 1 | """Given a path to a model checkpoint and a dataset split ('test', 'train', 'validation', 'all'), 2 | compute all metrics.""" 3 | 4 | """Evaluate the model end-to-end on the full songs. 5 | Note that predictions are cached, so if you want to re-run the evaluation from scratch, 6 | use the --nocache flag. 7 | """ 8 | import argparse 9 | import os 10 | import sys 11 | 12 | import torch 13 | from joblib import Parallel, delayed 14 | from tqdm import tqdm 15 | 16 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 17 | 18 | from dataset import ASAPDataset 19 | from models.roformer import Roformer 20 | from tokenizer import MultistreamTokenizer 21 | from utils import eval, infer, pad_batch 22 | 23 | device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu") 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--split", type=str) 28 | parser.add_argument("--model", type=str, default=None) 29 | parser.add_argument("--fast_eval", action="store_true") 30 | args = parser.parse_args() 31 | 32 | q = ASAPDataset("./data/", args.split) 33 | batch_size = 16 34 | overlap = 64 35 | paths = [] 36 | for i in range(len(q.metadata)): 37 | sample = q.metadata.iloc[i] 38 | sample_path = sample["performance_MIDI_external"].replace( 39 | "{ASAP}", f"{q.data_dir}asap-dataset" 40 | ) 41 | score_path = os.path.dirname(sample_path) + "/xml_score.musicxml" 42 | paths.append((sample_path, score_path)) 43 | 44 | from lightning.pytorch import seed_everything 45 | seed_everything(42, workers=True) 46 | print("Load + Tokenize all songs") 47 | inputs = [] 48 | lengths = [] 49 | for midi, gt_mxl in tqdm(paths): 50 | x = MultistreamTokenizer.tokenize_midi(midi) 51 | inputs.append({k: v.unsqueeze(0).to(device) for k, v in x.items()}) 52 | lengths.append(x["pitch"].shape[0]) 53 | 54 | # Sort everything by length 55 | sorted_data = sorted(zip(lengths, inputs, paths), key=lambda x: x[0]) 56 | lengths, inputs, paths = zip(*sorted_data) 57 | 58 | print("Running inference") 59 | model = Roformer.load_from_checkpoint(args.model) 60 | model.to(device) 61 | model.eval() 62 | 63 | # First run everything through the model (batched) 64 | y_full = None 65 | for i in tqdm(range(0, len(inputs), batch_size)): 66 | x = pad_batch(inputs[i : i + batch_size]) 67 | y_hat = infer(x, model, overlap=overlap, chunk=512, kv_cache=True) 68 | if y_full is None: 69 | y_full = y_hat 70 | else: 71 | y_full = pad_batch([y_full, y_hat]) 72 | 73 | print(f"Computing score similarities") 74 | sims = 
Parallel(n_jobs=16, verbose=10)( 75 | delayed(eval)({k: v[i, :l] for k, v in y_full.items()}, p[1]) 76 | for i, (p, l) in enumerate(zip(paths, lengths)) 77 | ) 78 | 79 | sims = {k: [d[k] for d in sims if d[k]] for k in sims[0]} 80 | print("-----------------") 81 | print("Aggregate:", {k: len([s for s in sims[k] if s is not None]) for k in sims}, len(sims['mxl <-> gt_mxl'])) 82 | sims_aggregate = {} 83 | for k, v in sims.items(): 84 | # v is list of dicts 85 | aggregate = {k_: [d[k_] for d in v if d[k_] is not None] for k_ in v[0]} 86 | if any(key in k for key in ["TP", "FP", "FN", "TN"]): 87 | aggregate = {k_: sum(v_) for k_, v_ in aggregate.items()} 88 | else: 89 | aggregate = {k_: sum(v_) / (len(v_)+1e-9) for k_, v_ in aggregate.items()} 90 | sims_aggregate[k] = aggregate 91 | print(k, aggregate) 92 | sims["aggregate"] = sims_aggregate 93 | 94 | print(f"Ours", end=" ") 95 | for k in ["PitchER", "MissRate", "ExtraRate", "OnsetER", "OffsetER", "MeanER"]: 96 | print(f"{round(sims_aggregate['muster'][k], 2):5.2f}", end=" ") 97 | for k in ["NoteDeletion", "NoteInsertion", "NoteDuration", "StaffAssignment", "StemDirection", "NoteSpelling"]: 98 | print(f"{round(100*sims_aggregate['mxl <-> gt_mxl'][k], 2):5.2f}", end=" ") 99 | 100 | print("\nTable 4:") 101 | print("SOTA 6.86 26.74 9.69 - - -") 102 | print(f"Ours", end=" ") 103 | for k in ["StaffAssignment", "StemDirection", "NoteSpelling", "GraceF1", "StaccatoF1", "TrillF1"]: 104 | print(f"{round(100*sims_aggregate['mxl <-> gt_mxl'][k], 2):5.2f}", end=" ") 105 | print(f"\n{args.model}") 106 | -------------------------------------------------------------------------------- /midi2scoretransformer/models/embedding.py: -------------------------------------------------------------------------------- 1 | """Embedding modules for MIDI and MXL compound token data. 2 | Each token stream is embedded into fixed-size embeddings, which are then summed. 3 | """ 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class MIDIEmbeddings(nn.Module): 11 | """Construct embeddings given 5 one-hot input token streams.""" 12 | def __init__(self, config): 13 | super().__init__() 14 | self.embeddings = nn.ParameterDict({ 15 | "onset": nn.Linear(config.in_onset_vocab_size, config.embedding_size, bias=config.bias), 16 | "duration": nn.Linear(config.in_duration_vocab_size, config.embedding_size, bias=config.bias), 17 | "pitch": nn.Linear(config.in_pitch_vocab_size, config.embedding_size, bias=config.bias), 18 | "velocity": nn.Linear(config.in_velocity_vocab_size, config.embedding_size, bias=config.bias), 19 | "unconditional": nn.Linear(1, config.embedding_size, bias=False), 20 | }) 21 | self.layer_norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps, bias=config.bias) 22 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 23 | self.config = config 24 | 25 | def forward(self, input_streams): 26 | """Embeds MIDI input token streams into a fixed-size embedding. 
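        Each stream is passed through its own linear projection and the projections
        are summed (not concatenated), so every stream contributes a full
        embedding_size-dimensional vector; layer norm and dropout are applied to the sum.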
27 | 28 | Parameters 29 | ---------- 30 | input_streams : Dict[str, torch.Tensor] 31 | List of tensors of shape (n_notes, N) 32 | 33 | Returns 34 | ------- 35 | torch.Tensor 36 | Tensor of shape (n_notes, config.embedding_size) 37 | """ 38 | if self.config.is_autoregressive: 39 | shifted_input_streams = {k: torch.roll(v, 1, 1) for k, v in input_streams.items()} 40 | for k in shifted_input_streams.keys(): 41 | shifted_input_streams[k][:, 0] = 0 42 | input_streams = shifted_input_streams 43 | input_embeds = { 44 | k: v(input_streams[k]) for k, v in self.embeddings.items() if k in input_streams 45 | } 46 | embeddings = sum(input_embeds.values()) 47 | embeddings = self.layer_norm(embeddings) # Layernorm tested helpful #1188 48 | embeddings = self.dropout(embeddings) 49 | return embeddings 50 | 51 | 52 | class MIDIUnembeddings(nn.Module): 53 | """Project embeddings to compound tokens for MIDI data. 54 | Not required for the final model, only here for completeness. 55 | """ 56 | 57 | def __init__(self, config): 58 | super().__init__() 59 | self.embeddings = nn.ParameterDict({ 60 | "onset": nn.Linear(config.embedding_size, config.in_onset_vocab_size), 61 | "duration": nn.Linear(config.embedding_size, config.in_duration_vocab_size), 62 | "pitch": nn.Linear(config.embedding_size, config.in_pitch_vocab_size), 63 | "velocity": nn.Linear(config.embedding_size, config.in_velocity_vocab_size), 64 | }) 65 | self.mask_embeddings = nn.Linear(config.embedding_size, 1) 66 | self.layer_norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) 67 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 68 | 69 | def forward(self, embeddings): 70 | """Projects fixed-size embeddings into four parallel token streams representing 71 | MIDI data. 72 | 73 | Parameters 74 | ---------- 75 | embeddings : List[tensor] 76 | List of tensors of shape (n_notes, config.embedding_dims) 77 | 78 | Returns 79 | ------- 80 | torch.Tensor 81 | Tensor of shape (n_notes, config.embedding_size) 82 | """ 83 | embeddings = self.layer_norm(embeddings) 84 | embeddings = self.dropout(embeddings) 85 | 86 | output_embeds = {k: v(embeddings) for k, v in self.embeddings.items()} 87 | output_embeds["pad"] = self.mask_embeddings(embeddings) 88 | return output_embeds 89 | 90 | 91 | class MXLEmbeddings(nn.Module): 92 | """Construct the embeddings from MusicXML token streams.""" 93 | 94 | def __init__(self, config): 95 | super().__init__() 96 | self.embeddings = nn.ParameterDict({ 97 | "offset": nn.Linear(config.out_offset_vocab_size, config.embedding_size, bias=config.bias), 98 | "downbeat": nn.Linear(config.out_downbeat_vocab_size, config.embedding_size, bias=config.bias), 99 | "duration": nn.Linear(config.out_duration_vocab_size, config.embedding_size, bias=config.bias), 100 | "pitch": nn.Linear(config.out_pitch_vocab_size, config.embedding_size, bias=config.bias), 101 | "accidental": nn.Linear(config.out_accidental_vocab_size, config.embedding_size, bias=config.bias), 102 | "keysignature": nn.Linear(config.out_keysignature_vocab_size, config.embedding_size, bias=config.bias), 103 | # "velocity": nn.Linear(config.out_velocity_vocab_size, config.embedding_size, bias=config.bias), 104 | "grace": nn.Linear(config.out_grace_vocab_size, config.embedding_size, bias=config.bias), 105 | "trill": nn.Linear(config.out_trill_vocab_size, config.embedding_size, bias=config.bias), 106 | "staccato": nn.Linear(config.out_staccato_vocab_size, config.embedding_size, bias=config.bias), 107 | "voice": nn.Linear(config.out_voice_vocab_size, 
config.embedding_size, bias=config.bias), 108 | "stem": nn.Linear(config.out_stem_vocab_size, config.embedding_size, bias=config.bias), 109 | "hand": nn.Linear(config.out_hand_vocab_size, config.embedding_size, bias=config.bias), 110 | }) 111 | self.mask_embeddings = nn.Linear(1, config.embedding_size, bias=config.bias) 112 | self.layer_norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps, bias=config.bias) 113 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 114 | self.config = config 115 | 116 | def forward(self, input_streams): 117 | """Embeds MXL input token streams into a fixed-size embedding. 118 | 119 | Parameters 120 | ---------- 121 | input_streams : _type_ 122 | List of (typically one-hot) tensors of shape (n_notes, N, C) 123 | 124 | Returns 125 | ------- 126 | torch.Tensor 127 | Tensor of shape (n_notes, config.embedding_size) 128 | """ 129 | # Shift everything by 1 if the model is autoregressive 130 | if self.config.is_autoregressive: 131 | input_streams = {k: torch.roll(v, 1, 1) for k, v in input_streams.items()} 132 | for k in input_streams.keys(): 133 | # If we only feed a single timestep, we're likely in a generation 134 | # with cached embeddings, so we should not set it to 0. 135 | if input_streams[k].size(1) > 1: 136 | input_streams[k][:, 0] = 0 137 | output_embeds = { 138 | k: self.embeddings[k](v) for k, v in input_streams.items() if k in self.embeddings 139 | } 140 | if "pad" in input_streams: 141 | output_embeds["pad"] = self.mask_embeddings(input_streams["pad"].float().unsqueeze(2)) 142 | embeddings = sum(output_embeds.values()) 143 | embeddings = self.layer_norm(embeddings) 144 | embeddings = self.dropout(embeddings) 145 | return embeddings 146 | 147 | 148 | class MXLUnembeddings(nn.Module): 149 | """Project embeddings to compound tokens.""" 150 | 151 | def __init__(self, config): 152 | super().__init__() 153 | self.config = config 154 | self.embeddings = nn.ParameterDict({ 155 | "offset": nn.Linear(config.embedding_size, config.out_offset_vocab_size, bias=config.bias), 156 | "downbeat": nn.Linear(config.embedding_size, config.out_downbeat_vocab_size, bias=config.bias), 157 | "duration": nn.Linear(config.embedding_size, config.out_duration_vocab_size, bias=config.bias), 158 | "pitch": nn.Linear(config.embedding_size, config.out_pitch_vocab_size, bias=config.bias), 159 | "accidental": nn.Linear(config.embedding_size, config.out_accidental_vocab_size, bias=config.bias), 160 | "keysignature": nn.Linear(config.embedding_size, config.out_keysignature_vocab_size, bias=config.bias), 161 | "velocity": nn.Linear(config.embedding_size, config.out_velocity_vocab_size, bias=config.bias), 162 | "grace": nn.Linear(config.embedding_size, config.out_grace_vocab_size, bias=config.bias), 163 | "trill": nn.Linear(config.embedding_size, config.out_trill_vocab_size, bias=config.bias), 164 | "staccato": nn.Linear(config.embedding_size, config.out_staccato_vocab_size, bias=config.bias), 165 | "voice": nn.Linear(config.embedding_size, config.out_voice_vocab_size, bias=config.bias), 166 | "stem": nn.Linear(config.embedding_size, config.out_stem_vocab_size, bias=config.bias), 167 | "hand": nn.Linear(config.embedding_size, config.out_hand_vocab_size, bias=config.bias), 168 | }) 169 | self.mask_embeddings = nn.Linear(config.embedding_size, 1, bias=config.bias) 170 | 171 | self.layer_norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps, bias=config.bias) 172 | 173 | def forward(self, embeddings): 174 | """Projects fixed-size embedding into multiple data 
streams representing 175 | MusicXML data. 176 | 177 | Parameters 178 | ---------- 179 | embedding : torch.tensor 180 | List of tensors of shape (B, n_notes, D) 181 | 182 | Returns 183 | ------- 184 | torch.Tensor 185 | the fixed-size embedding of the input 186 | """ 187 | embeddings = self.layer_norm(embeddings) 188 | output_embeds = {k: v(embeddings) for k, v in self.embeddings.items()} 189 | output_embeds["pad"] = self.mask_embeddings(embeddings) 190 | 191 | return output_embeds 192 | -------------------------------------------------------------------------------- /midi2scoretransformer/models/model.py: -------------------------------------------------------------------------------- 1 | """Base model class for the PM2S Transformer model.""" 2 | from typing import Dict 3 | import pytorch_lightning as pl 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | class BaseModel(pl.LightningModule): 8 | def __init__( 9 | self, enc_configuration=None, dec_configuration=None, hyperparameters=None 10 | ): 11 | super().__init__() 12 | self.enc_config = enc_configuration 13 | self.dec_config = dec_configuration 14 | self.hyperparameters = hyperparameters 15 | self.save_hyperparameters() 16 | 17 | def forward( 18 | self, 19 | input_streams: torch.FloatTensor = None, 20 | output_streams: torch.FloatTensor = None, 21 | ) -> Dict[str, torch.Tensor]: 22 | encodings = self.forward_enc( 23 | input_streams, attention_mask=input_streams["pad"] 24 | ) 25 | # encoder-decoder 26 | return self.forward_dec( 27 | input_streams=output_streams, 28 | encoder_hidden_states=encodings, 29 | encoder_attention_mask=input_streams["pad"], 30 | ) 31 | 32 | @torch.no_grad() 33 | def generate(self, x, y=None, max_length=512, temperature=1.0, top_k=1, kv_cache=False) -> dict[str, torch.Tensor]: 34 | """Generate a sequence of tokens from the model. 35 | If y with T timesteps is provided, only max_length - T tokens will be generated. 36 | The first T tokens will be y_hist. 
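        Shapes: input streams are (typically one-hot) tensors of shape (B, T, C),
        e.g. as produced by MultistreamTokenizer.tokenize_midi; the returned dict
        contains one-hot output streams of shape (B, max_length, C_k) plus a
        "pad" stream of shape (B, max_length, 1).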
37 | """ 38 | B, T, _ = x["pitch"].shape 39 | device = x["pitch"].device 40 | conf = self.dec_config 41 | # Model is used to the first tokens being all 0's & it will be overwritten anyways 42 | # fmt: off 43 | y_start_token = { 44 | "offset": torch.zeros((B, 1, conf.out_offset_vocab_size), device=device), 45 | "downbeat": torch.zeros((B, 1, conf.out_downbeat_vocab_size), device=device), 46 | "duration": torch.zeros((B, 1, conf.out_duration_vocab_size), device=device), 47 | "pitch": torch.zeros((B, 1, conf.out_pitch_vocab_size), device=device), 48 | "accidental": torch.zeros((B, 1, conf.out_accidental_vocab_size), device=device), 49 | "keysignature": torch.zeros((B, 1, conf.out_keysignature_vocab_size), device=device), 50 | "velocity": torch.zeros((B, 1, conf.out_velocity_vocab_size), device=device), 51 | "grace": torch.zeros((B, 1, conf.out_grace_vocab_size), device=device), 52 | "trill": torch.zeros((B, 1, conf.out_trill_vocab_size), device=device), 53 | "staccato": torch.zeros((B, 1, conf.out_staccato_vocab_size), device=device), 54 | "voice": torch.zeros((B, 1, conf.out_voice_vocab_size), device=device), 55 | "stem": torch.zeros((B, 1, conf.out_stem_vocab_size), device=device), 56 | "hand": torch.zeros((B, 1, conf.out_hand_vocab_size), device=device), 57 | "pad": torch.zeros((B, 1), device=device).long(), 58 | } 59 | # fmt: on 60 | if "encoder" in self.hyperparameters["components"]: 61 | encoder_hidden_states = self.forward_enc( 62 | x, attention_mask=x["pad"] 63 | ) # (B, T, D) 64 | encoder_attention_mask = x["pad"] 65 | else: 66 | encoder_hidden_states = None 67 | encoder_attention_mask = None 68 | if y is None: 69 | y = y_start_token 70 | past_key_values = None 71 | else: 72 | y = {k: torch.cat([y_start_token[k], y[k]], dim=1) for k in y.keys()} 73 | # Have to populate KV-cache 74 | past_key_values = self.forward_dec( 75 | input_streams={k: torch.roll(v[:, :-1], -1, 1) for k, v in y.items()}, 76 | encoder_hidden_states=encoder_hidden_states, 77 | encoder_attention_mask=encoder_attention_mask, 78 | past_key_values=None, 79 | use_cache=True 80 | )[1] 81 | for _ in range(max_length + 1 - y["pad"].shape[1]): 82 | if kv_cache: 83 | y_pred, past_key_values = self.forward_dec( 84 | input_streams={k: v[:, -1:] for k, v in y.items()}, 85 | encoder_hidden_states=encoder_hidden_states, 86 | encoder_attention_mask=encoder_attention_mask, 87 | past_key_values=past_key_values, 88 | use_cache=True 89 | ) 90 | else: 91 | shifted_y = {k: torch.roll(v, -1, 1) for k, v in y.items()} 92 | y_pred = self.forward_dec( 93 | input_streams=shifted_y, 94 | encoder_hidden_states=encoder_hidden_states, 95 | encoder_attention_mask=encoder_attention_mask, 96 | ) 97 | for k in y.keys(): 98 | # forward the model to get the logits for the index in the sequence 99 | logits = y_pred[k] 100 | # pluck the logits at the final step and scale by desired temperature 101 | logits = logits[:, -1, :] / temperature 102 | # ensure that we sample a downbeat wherever the offset decreases, since that guarantees a measure change! 
103 | if k == "downbeat" and y["offset"].shape[1] > 1: 104 | is_downbeat = y_pred["offset"][:, -1].argmax(-1) < y["offset"][:, -2].argmax(-1) 105 | logits[is_downbeat, 0] = -float("Inf") 106 | 107 | if k == "accidental": 108 | # ensure that we only sample valid accidentals 109 | { 110 | 0: 'double-flat', 111 | 1: 'flat', 112 | 2: 'natural', 113 | 3: 'sharp', 114 | 4: 'double-sharp', 115 | } 116 | never_allowed = [0, 4, 6] 117 | impossible_accidentals = { 118 | 0: [1, 4], 119 | 1: [0, 2, 5], 120 | 2: [1, 3], 121 | 3: [2, 4, 5], 122 | 4: [0, 3], 123 | 5: [1, 4], 124 | 6: [0, 2, 5], 125 | 7: [1, 3], 126 | 8: [0, 2, 4, 5], 127 | 9: [1, 3], 128 | 10: [2, 4, 5], 129 | 11: [0, 3] 130 | } 131 | for i in range(logits.shape[0]): 132 | predicted_pitch = y["pitch"][i, -1].argmax() 133 | options = impossible_accidentals[predicted_pitch.item() % 12] + never_allowed 134 | logits[i, options] = float("-inf") 135 | # optionally crop the logits to only the top k options 136 | if top_k is not None: 137 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 138 | logits[logits < v[:, [-1]]] = -float("Inf") 139 | # apply softmax to convert logits to (normalized) probabilities 140 | probs = ( 141 | F.softmax(logits, dim=-1) 142 | if k != "pad" 143 | else torch.cat([1 - F.sigmoid(logits), F.sigmoid(logits)], dim=-1) 144 | ) 145 | # Greedy decoding (equivalent to argmax for topk = 1) 146 | next_token = torch.multinomial(probs, num_samples=1) # 633 tok/s 147 | # sample from the distribution 148 | # next_token = torch.searchsorted(torch.cumsum(probs, dim=-1), torch.rand((B, 1)).to(probs.device)) # 660 tok/s 149 | if k == "pad": # special case + NO ARGMAX sampling 150 | next_token = probs.argmax(-1, keepdim=True) 151 | y[k] = torch.cat([y[k], next_token], dim=1) 152 | else: 153 | # Token back to one-hot 154 | next_token = F.one_hot( 155 | next_token, num_classes=y_pred[k].shape[-1] 156 | ) 157 | y[k] = torch.cat([y[k], next_token], dim=1) 158 | 159 | # set other tokens zero where mask 160 | mask = y["pad"][:, -1] == 0 161 | for k in y.keys(): 162 | if k != "pad": 163 | y[k][mask, -1] = 0 164 | 165 | # Remove the token 166 | for k in y.keys(): 167 | y[k] = y[k][:, 1:] 168 | y["pad"] = y["pad"].unsqueeze(-1).float() 169 | return y -------------------------------------------------------------------------------- /midi2scoretransformer/models/roformer.py: -------------------------------------------------------------------------------- 1 | """This module contains our RoFormer model implementation. 
It is mostly a copy of the RoFormer implementation 2 | from HuggingFace Transformers, but contains some modifications: 3 | - custom embeddings/projections 4 | - flash attention 5 | - SwiGLU activation function 6 | - QKV layer fusion where possible 7 | - RoPE in cross-attention 8 | - pre-norm 9 | """ 10 | 11 | import math 12 | from typing import Dict, Optional 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from transformers import RoFormerModel as RoFormerModelBase 18 | from transformers.modeling_outputs import \ 19 | BaseModelOutputWithPastAndCrossAttentions 20 | from transformers.models.roformer.modeling_roformer import (RoFormerAttention, 21 | RoFormerEncoder, 22 | RoFormerLayer) 23 | from models.embedding import MIDIEmbeddings, MXLEmbeddings, MXLUnembeddings 24 | from models.model import BaseModel 25 | 26 | 27 | class CustomRoFormerSelfAttention(nn.Module): 28 | def __init__(self, config, is_cross_attention=False): 29 | super().__init__() 30 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 31 | raise ValueError( 32 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 33 | f"heads ({config.num_attention_heads})" 34 | ) 35 | 36 | self.num_attention_heads = config.num_attention_heads 37 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 38 | self.all_head_size = self.num_attention_heads * self.attention_head_size 39 | self.is_cross_attention = is_cross_attention 40 | 41 | if self.is_cross_attention: 42 | self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.bias) 43 | self.key_value = nn.Linear(config.hidden_size, self.all_head_size * 2, bias=config.bias) 44 | else: 45 | self.query_key_value = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=config.bias) 46 | 47 | self.dropout = config.attention_probs_dropout_prob 48 | self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=config.bias) 49 | 50 | self.is_decoder = config.is_decoder 51 | self.rotary_value = config.rotary_value 52 | 53 | def transpose_for_scores(self, x): 54 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 55 | x = x.view(*new_x_shape) 56 | return x.permute(0, 2, 1, 3) 57 | 58 | def forward( 59 | self, 60 | hidden_states, 61 | attention_mask=None, 62 | sinusoidal_pos=None, 63 | head_mask=None, 64 | encoder_hidden_states=None, 65 | encoder_attention_mask=None, 66 | past_key_value=None, 67 | output_attentions=False, 68 | ): 69 | hidden_states = self.norm(hidden_states) 70 | if not self.is_cross_attention: 71 | q, k, v = self.query_key_value(hidden_states).chunk(3, dim=-1) 72 | query_layer = self.transpose_for_scores(q) 73 | key_layer = self.transpose_for_scores(k) 74 | value_layer = self.transpose_for_scores(v) 75 | else: 76 | q = self.query(hidden_states) 77 | query_layer = self.transpose_for_scores(q) 78 | if past_key_value is not None: 79 | # reuse k, v, cross_attentions 80 | key_layer = past_key_value[0] 81 | value_layer = past_key_value[1] 82 | else: 83 | k, v = self.key_value(encoder_hidden_states).chunk(2, dim=-1) 84 | key_layer = self.transpose_for_scores(k) 85 | value_layer = self.transpose_for_scores(v) 86 | attention_mask = encoder_attention_mask 87 | 88 | # If this is instantiated as a cross-attention module, the keys 89 | # and values come from an encoder; the attention mask needs to be 90 | # such that the encoder's padding tokens are not attended to. 
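        # (RoPE below rotates each (even, odd) feature pair of the query/key vectors by a
        # position-dependent angle, so attention scores depend on relative positions.
        # Cached cross-attention keys were already rotated when the cache was built,
        # which is why only the query is rotated in that branch.)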
91 | if sinusoidal_pos is not None: 92 | if past_key_value is not None and self.is_cross_attention: 93 | # the past_key_values have already been rotated 94 | query_layer, _ = self.apply_rotary_position_embeddings( 95 | sinusoidal_pos, query_layer, query_layer 96 | ) 97 | else: 98 | if self.rotary_value: 99 | query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings( 100 | sinusoidal_pos, query_layer, key_layer, value_layer 101 | ) 102 | else: 103 | query_layer, key_layer = self.apply_rotary_position_embeddings( 104 | sinusoidal_pos, query_layer, key_layer 105 | ) 106 | if not self.is_cross_attention and past_key_value is not None: 107 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 108 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 109 | if self.is_decoder: 110 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 111 | # Further calls to cross_attention layer can then reuse all cross-attention 112 | # key/value_states (first "if" case) 113 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 114 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 115 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 116 | # if encoder bi-directional self-attention `past_key_value` is always `None` 117 | past_key_value = (key_layer, value_layer) 118 | 119 | if head_mask is None: 120 | if self.is_decoder and not self.is_cross_attention: 121 | context_layer = F.scaled_dot_product_attention( 122 | query_layer, 123 | key_layer, 124 | value_layer, 125 | is_causal=True if past_key_value is None else False, 126 | attn_mask=attention_mask, 127 | dropout_p=self.dropout if self.training else 0 128 | ) 129 | else: 130 | context_layer = F.scaled_dot_product_attention( 131 | query_layer, 132 | key_layer, 133 | value_layer, 134 | attn_mask=attention_mask, 135 | dropout_p=self.dropout if self.training else 0 136 | ) 137 | 138 | # We sometimes ran into nan's with flash attention, thus a fallback here. 139 | if head_mask is not None or context_layer.isnan().any(): 140 | # Take the dot product between "query" and "key" to get the raw attention scores. 141 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 142 | 143 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 144 | if attention_mask is not None: 145 | # Apply the attention mask is (precomputed for all layers in RoFormerModel forward() function) 146 | attention_scores = attention_scores + attention_mask 147 | 148 | # Normalize the attention scores to probabilities. 149 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 150 | 151 | # This is actually dropping out entire tokens to attend to, which might 152 | # seem a bit unusual, but is taken from the original Transformer paper. 
153 | attention_probs = F.dropout(attention_probs, p=self.dropout if self.training else 0) 154 | 155 | # Mask heads if we want to 156 | if head_mask is not None: 157 | attention_probs = attention_probs * head_mask 158 | 159 | context_layer = torch.matmul(attention_probs, value_layer) 160 | 161 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 162 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 163 | context_layer = context_layer.view(*new_context_layer_shape) 164 | 165 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 166 | 167 | if self.is_decoder: 168 | outputs = outputs + (past_key_value,) 169 | return outputs 170 | 171 | @staticmethod 172 | def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None): 173 | # https://kexue.fm/archives/8265 174 | # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2] 175 | # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2] 176 | sin, cos = sinusoidal_pos.chunk(2, dim=-1) 177 | # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] 178 | sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos) 179 | # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] 180 | cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos) 181 | # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] 182 | rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as( 183 | query_layer 184 | ) 185 | q_t = query_layer.size(2) 186 | query_layer = query_layer * cos_pos[:, :, :q_t] + rotate_half_query_layer * sin_pos[:, :, :q_t] 187 | # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] 188 | rotate_half_key_layer = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1).reshape_as(key_layer) 189 | k_t = key_layer.size(2) 190 | key_layer = key_layer * cos_pos[:, :, :k_t] + rotate_half_key_layer * sin_pos[:, :, :k_t] 191 | if value_layer is not None: 192 | # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2] 193 | rotate_half_value_layer = torch.stack([-value_layer[..., 1::2], value_layer[..., ::2]], dim=-1).reshape_as( 194 | value_layer 195 | ) 196 | value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos 197 | return query_layer, key_layer, value_layer 198 | return query_layer, key_layer 199 | 200 | 201 | class CustomRoFormerSelfOutput(nn.Module): 202 | def __init__(self, config): 203 | super().__init__() 204 | self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.bias) 205 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 206 | 207 | def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: 208 | hidden_states = self.dense(hidden_states) 209 | hidden_states = self.dropout(hidden_states) 210 | return hidden_states + input_tensor 211 | 212 | 213 | class CustomRoFormerAttention(RoFormerAttention): 214 | """Patch in pre-norm layers.""" 215 | def __init__(self, config, is_cross_attention=False): 216 | super().__init__(config) 217 | self.self = CustomRoFormerSelfAttention(config, is_cross_attention=is_cross_attention) 218 | self.output = CustomRoFormerSelfOutput(config) 219 | 220 | 221 | class SwiGLU(nn.Module): 222 | def forward(self, x): 223 | x, gate = x.chunk(2, dim=-1) 224 | return F.silu(gate) * x 225 | 226 | 227 | class CustomRoFormerIntermediate(nn.Module): 228 | """Add norm to input of intermediate layer.""" 229 | 
def __init__(self, config): 230 | super().__init__() 231 | self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=config.bias) 232 | self.dense = nn.Linear(config.hidden_size, 2*config.intermediate_size, bias=config.bias) 233 | self.intermediate_act_fn = SwiGLU() 234 | 235 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 236 | hidden_states = self.norm(hidden_states) 237 | hidden_states = self.dense(hidden_states) 238 | hidden_states = self.intermediate_act_fn(hidden_states) 239 | return hidden_states 240 | 241 | 242 | class CustomRoFormerOutput(nn.Module): 243 | def __init__(self, config): 244 | super().__init__() 245 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.bias) 246 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 247 | 248 | def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: 249 | hidden_states = self.dense(hidden_states) 250 | hidden_states = self.dropout(hidden_states) 251 | return hidden_states + input_tensor 252 | 253 | 254 | class CustomRoFormerLayer(RoFormerLayer): 255 | def __init__(self, config): 256 | super().__init__(config) 257 | self.attention = CustomRoFormerAttention(config) # Use your custom attention here 258 | if self.add_cross_attention: 259 | self.crossattention = CustomRoFormerAttention(config, is_cross_attention=True) 260 | self.intermediate = CustomRoFormerIntermediate(config) 261 | self.output = CustomRoFormerOutput(config) 262 | 263 | 264 | class CustomRoFormerEncoder(RoFormerEncoder): 265 | def __init__(self, config): 266 | super().__init__(config) 267 | self.layer = nn.ModuleList([CustomRoFormerLayer(config) for _ in range(config.num_hidden_layers)]) 268 | 269 | def forward( 270 | self, 271 | hidden_states, 272 | attention_mask=None, 273 | head_mask=None, 274 | encoder_hidden_states=None, 275 | encoder_attention_mask=None, 276 | past_key_values=None, 277 | use_cache=None, 278 | output_attentions=False, 279 | output_hidden_states=False, 280 | return_dict=True, 281 | ): 282 | if self.gradient_checkpointing and self.training: 283 | if use_cache: 284 | use_cache = False 285 | all_hidden_states = () if output_hidden_states else None 286 | all_self_attentions = () if output_attentions else None 287 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 288 | 289 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 290 | 291 | # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head] 292 | # We have to ensure that positional embeddings work for both encoder and decoder states 293 | B, T = hidden_states.shape[:-1] 294 | T = max(encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0, T) 295 | sinusoidal_pos = self.embed_positions((B, T,), past_key_values_length)[None, None, :, :] 296 | 297 | next_decoder_cache = () if use_cache else None 298 | for i, layer_module in enumerate(self.layer): 299 | if output_hidden_states: 300 | all_hidden_states = all_hidden_states + (hidden_states,) 301 | 302 | layer_head_mask = head_mask[i] if head_mask is not None else None 303 | past_key_value = past_key_values[i] if past_key_values is not None else None 304 | 305 | if self.gradient_checkpointing and self.training: 306 | 307 | def create_custom_forward(module): 308 | def custom_forward(*inputs): 309 | return module(*inputs, past_key_value, output_attentions) 310 | 311 | return custom_forward 312 | 313 | 
layer_outputs = torch.utils.checkpoint.checkpoint( 314 | create_custom_forward(layer_module), 315 | hidden_states, 316 | attention_mask, 317 | sinusoidal_pos, 318 | layer_head_mask, 319 | encoder_hidden_states, 320 | encoder_attention_mask, 321 | ) 322 | else: 323 | layer_outputs = layer_module( 324 | hidden_states, 325 | attention_mask, 326 | sinusoidal_pos, 327 | layer_head_mask, 328 | encoder_hidden_states, 329 | encoder_attention_mask, 330 | past_key_value, 331 | output_attentions, 332 | ) 333 | 334 | hidden_states = layer_outputs[0] 335 | if use_cache: 336 | next_decoder_cache += (layer_outputs[-1],) 337 | if output_attentions: 338 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 339 | if self.config.add_cross_attention: 340 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],) 341 | 342 | if output_hidden_states: 343 | all_hidden_states = all_hidden_states + (hidden_states,) 344 | 345 | if not return_dict: 346 | return tuple( 347 | v 348 | for v in [ 349 | hidden_states, 350 | next_decoder_cache, 351 | all_hidden_states, 352 | all_self_attentions, 353 | all_cross_attentions, 354 | ] 355 | if v is not None 356 | ) 357 | return BaseModelOutputWithPastAndCrossAttentions( 358 | last_hidden_state=hidden_states, 359 | past_key_values=next_decoder_cache, 360 | hidden_states=all_hidden_states, 361 | attentions=all_self_attentions, 362 | cross_attentions=all_cross_attentions, 363 | ) 364 | 365 | 366 | class RoFormerModel(RoFormerModelBase): 367 | def __init__(self, config): 368 | super().__init__(config) 369 | self.encoder = CustomRoFormerEncoder(config) 370 | del self.embeddings 371 | 372 | 373 | class Roformer(BaseModel): 374 | def __init__(self, enc_configuration=None, dec_configuration=None, hyperparameters=None): 375 | super().__init__(enc_configuration, dec_configuration, hyperparameters) 376 | self.encoder = RoFormerModel(enc_configuration) 377 | self.decoder = RoFormerModel(dec_configuration) 378 | self.embeddings_enc = MIDIEmbeddings(enc_configuration) 379 | self.embeddings_dec = MXLEmbeddings(dec_configuration) 380 | self.unembeddings_dec = MXLUnembeddings(dec_configuration) 381 | 382 | self.norm = nn.LayerNorm(enc_configuration.hidden_size, eps=enc_configuration.layer_norm_eps, bias=enc_configuration.bias) 383 | 384 | def forward_enc( 385 | self, 386 | input_streams: torch.Tensor, 387 | attention_mask: Optional[torch.Tensor] = None 388 | ) -> Dict[str, torch.Tensor]: 389 | output_attentions = self.enc_config.output_attentions 390 | output_hidden_states = self.enc_config.output_hidden_states 391 | return_dict = self.enc_config.use_return_dict 392 | 393 | sample_stream = input_streams[list(input_streams.keys())[0]] 394 | B, T = sample_stream.size()[:-1] 395 | device = sample_stream.device 396 | 397 | if attention_mask is None: 398 | attention_mask = torch.ones(((B, T)), device=device) 399 | extended_attention_mask: torch.Tensor = ( 400 | self.encoder.get_extended_attention_mask(attention_mask, (B, T)) 401 | ) 402 | encoder_extended_attention_mask = None 403 | 404 | embedding_output = self.embeddings_enc(input_streams) 405 | 406 | if hasattr(self, "embeddings_project"): 407 | embedding_output = self.encoder.embeddings_project(embedding_output) 408 | 409 | encoder_outputs = self.encoder.encoder( 410 | embedding_output, 411 | attention_mask=extended_attention_mask, 412 | head_mask=None, 413 | encoder_attention_mask=encoder_extended_attention_mask, 414 | use_cache=False, 415 | output_attentions=output_attentions, 416 | 
output_hidden_states=output_hidden_states, 417 | return_dict=return_dict, 418 | ) 419 | return self.norm(encoder_outputs.last_hidden_state) 420 | 421 | def forward_dec( 422 | self, 423 | input_streams: Dict[str, torch.Tensor], 424 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 425 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 426 | past_key_values: Optional[tuple[torch.Tensor]] = None, 427 | use_cache: bool = False, 428 | ) -> Dict[str, torch.Tensor]: 429 | 430 | output_attentions = self.dec_config.output_attentions 431 | output_hidden_states = self.dec_config.output_hidden_states 432 | return_dict = self.dec_config.use_return_dict 433 | 434 | B, T = input_streams["offset"].size()[:2] 435 | if past_key_values is not None: 436 | T += past_key_values[0][0].size(2) 437 | device = input_streams["offset"].device 438 | 439 | # Make/convert attention masks 440 | attention_mask = torch.ones(((B, T)), device=device) 441 | extended_attention_mask: torch.Tensor = ( 442 | self.decoder.get_extended_attention_mask(attention_mask, (B, T)) 443 | ) 444 | if not self.dec_config.is_autoregressive: 445 | extended_attention_mask = torch.zeros_like(extended_attention_mask) 446 | 447 | assert self.dec_config.is_decoder 448 | if encoder_hidden_states is not None: 449 | encoder_hidden_shape = encoder_hidden_states.size()[:2] 450 | if encoder_attention_mask is None: 451 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 452 | encoder_extended_attention_mask = self.decoder.invert_attention_mask( 453 | encoder_attention_mask 454 | ) 455 | else: 456 | encoder_extended_attention_mask = None 457 | if use_cache: 458 | extended_attention_mask = extended_attention_mask[..., -input_streams['pad'].size(1):, :] 459 | # pass to model 460 | embedding_output = self.embeddings_dec({k: v[:, :T] for k, v in input_streams.items()}) 461 | 462 | decoder_outputs = self.decoder.encoder( 463 | embedding_output, 464 | attention_mask=extended_attention_mask, 465 | head_mask=None, 466 | encoder_hidden_states=encoder_hidden_states, 467 | encoder_attention_mask=encoder_extended_attention_mask, 468 | past_key_values=past_key_values, 469 | use_cache=use_cache, 470 | output_attentions=output_attentions, 471 | output_hidden_states=output_hidden_states, 472 | return_dict=return_dict, 473 | ) 474 | out_proj = self.unembeddings_dec(decoder_outputs.last_hidden_state) 475 | if use_cache: 476 | return out_proj, decoder_outputs.past_key_values 477 | else: 478 | return out_proj 479 | -------------------------------------------------------------------------------- /midi2scoretransformer/score_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import os 4 | import shutil 5 | import subprocess 6 | import tempfile 7 | 8 | from music21 import chord, expressions, key, note, stream 9 | 10 | from constants import LD_PATH, MUSESCORE_PATH 11 | 12 | 13 | def realize_spanners(s): 14 | to_remove = [] 15 | for sp in s.recurse().getElementsByClass(expressions.TremoloSpanner): 16 | l = sp.getSpannedElements() 17 | if len(l) != 2: 18 | print("Not sure what to do with this spanner", sp, l) 19 | continue 20 | start, end = l 21 | 22 | offset = start.offset 23 | start_chord = None 24 | end_chord = None 25 | startActiveSite = start.activeSite 26 | endActiveSite = end.activeSite 27 | if start.activeSite is None: 28 | start_chord: chord.Chord = start._chordAttached 29 | offset = start_chord.offset 30 | startActiveSite = start._chordAttached.activeSite 31 | if 
end.activeSite is None: 32 | end_chord: chord.Chord = end._chordAttached 33 | endActiveSite = end._chordAttached.activeSite 34 | 35 | # We insert a tremolo expression on the start note 36 | # realize it, and then change every second note to have the pitch of the end note 37 | trem = expressions.Tremolo() 38 | trem.measured = sp.measured 39 | trem.numberOfMarks = sp.numberOfMarks 40 | start.expressions.append(trem) 41 | out = trem.realize(start, inPlace=True)[0] 42 | if start_chord: 43 | if len(start_chord.notes) == 1: 44 | startActiveSite.remove(start_chord) 45 | else: 46 | start_chord.remove(start) 47 | else: 48 | startActiveSite.remove(start) 49 | if end_chord: 50 | if len(end_chord.notes) == 1: 51 | endActiveSite.remove(end_chord) 52 | else: 53 | end_chord.remove(end) 54 | else: 55 | endActiveSite.remove(end) 56 | for i, n2 in enumerate(out): 57 | if i % 2 == 1: 58 | n2.pitch = end.pitch 59 | startActiveSite.insert(offset, n2) 60 | offset += n2.duration.quarterLength 61 | to_remove.append(sp) 62 | for sp in s.recurse().getElementsByClass(expressions.TrillExtension): 63 | l = sp.getSpannedElements() 64 | start = l[0] 65 | exp = [l.expressions for l in l] 66 | if not any(isinstance(e, expressions.Trill) for ex in exp for e in ex): 67 | if len(l) != 1: 68 | print("Not sure what to do with this spanner", sp, l) 69 | continue 70 | start.expressions.append(expressions.Trill()) 71 | to_remove.append(sp) 72 | s.remove(to_remove, recurse=True) 73 | return s 74 | 75 | 76 | def convert_with_musescore(in_path: str, out_path: str): 77 | with tempfile.TemporaryDirectory() as tmpdirname: 78 | suffix = in_path.split(".")[-1] 79 | shutil.copy(in_path, f"{tmpdirname}/test.{suffix}") 80 | # Update the environment variables 81 | env_vars = os.environ.copy() 82 | env_vars["LD_LIBRARY_PATH"] = LD_PATH + env_vars.get("LD_LIBRARY_PATH", "") 83 | env_vars["DISPLAY"] = ":0" 84 | env_vars["QT_QPA_PLATFORM"] = "offscreen" 85 | env_vars["XDG_RUNTIME_DIR"] = tmpdirname 86 | # Run the subprocess with the updated environment 87 | subprocess.run( 88 | [MUSESCORE_PATH, "-o", out_path, f"{tmpdirname}/test.{suffix}"], 89 | stdout=subprocess.DEVNULL, 90 | stderr=subprocess.DEVNULL, 91 | env=env_vars 92 | ) 93 | 94 | 95 | 96 | 97 | def postprocess_score(mxl: stream.Score, makeChords: bool=False, inPlace=False) -> stream.Score: 98 | """We essentially roll our own "makeNotation" here because music21's is broken 99 | for non-ideal scores. 100 | 101 | Parameters 102 | ---------- 103 | mxl : stream.Score 104 | The score to postprocess. 105 | makeChords : bool, optional 106 | Whether to merge notes into chords and do various other prettifications, 107 | by default False, since it can alter the metric results. 108 | 109 | Returns 110 | ------- 111 | stream.Score 112 | The postprocessed score. 113 | """ 114 | if not inPlace: 115 | mxl = copy.deepcopy(mxl) 116 | 117 | def remove_note_fast(n: note.Note) -> None: 118 | """Quickly removes a note known to be in a voice. 119 | Faster as it doesn't require recursion. 
120 | """ 121 | for site in n.sites: 122 | if isinstance(site, stream.Voice): 123 | site.remove(n) 124 | break 125 | else: 126 | v = n.getContextByClass(stream.Voice) 127 | # getContextByClass silently returns incorrect results sometimes 128 | if n not in v: 129 | raise ValueError("Note not found in voice") 130 | v.remove(n) 131 | 132 | # see if any measures can be merged 133 | merge_candidates: set[tuple[int, int]] = set() 134 | not_candidates: set[tuple[int, int]] = set() 135 | for part in mxl.parts: 136 | measures: list[stream.Measure] = list(part.getElementsByClass("Measure")) 137 | for i, (this_m, next_m) in enumerate(zip(measures, measures[1:])): 138 | # see if we can merge two adjacent measures into one 139 | this_m_highest_offset = max([n.offset for n in this_m.flatten().notes if n.offset < this_m.barDuration.quarterLength], default=0) 140 | next_m_lowest_offset = min([n.offset for n in next_m.flatten().notes], default=float("inf")) 141 | if this_m_highest_offset <= next_m_lowest_offset and next_m_lowest_offset != 0 and this_m.barDuration.quarterLength <= next_m.barDuration.quarterLength: 142 | if len(next_m.flatten().notes) > 6: 143 | not_candidates.add((i, i+1)) 144 | continue 145 | merge_candidates.add((i, i+1)) 146 | else: 147 | not_candidates.add((i, i+1)) 148 | 149 | remove = sorted(list(merge_candidates-not_candidates), reverse=True) 150 | for i, j in remove: 151 | for part in mxl.parts: 152 | measures: list[stream.Measure] = list(part.getElementsByClass("Measure")) 153 | this_m = measures[i] 154 | next_m = measures[j] 155 | if this_m.barDuration.quarterLength < next_m.barDuration.quarterLength: 156 | if (ts := this_m.getElementsByClass("TimeSignature")): 157 | this_m.remove(ts[0]) 158 | this_m.insert(0, next_m.getElementsByClass("TimeSignature")[0]) 159 | for next_v in next_m.voices: 160 | this_m.insert(next_v.offset, copy.deepcopy(next_v)) 161 | shift = part.elementOffset(next_m) - part.elementOffset(this_m) 162 | for m in measures[j:]: 163 | m.number -= 1 164 | part.coreSetElementOffset(m, part.elementOffset(m) - shift) 165 | part.coreElementsChanged(clearIsSorted=False) 166 | part.remove(next_m) 167 | # remove doubled notes: 168 | mxl = mxl.splitAtDurations(recurse=True)[0] 169 | mxl.makeTies(inPlace=True) 170 | notes: dict[tuple[float,int,bool], list[note.Note]] = {} 171 | for n in mxl.flatten().notes: 172 | key_tuple = (n.offset, n.pitch.midi, n.duration.isGrace) 173 | notes.setdefault(key_tuple, []).append(n) 174 | for v in notes.values(): 175 | if len(v) > 1: 176 | longest_note = max(v, key=lambda x: x.duration.quarterLength) 177 | for n in v: 178 | if n is not longest_note: 179 | remove_note_fast(n) 180 | # Hide unnecessary accidentals 181 | for p in mxl.parts: 182 | prior_accidental = {i: None for i in range(128)} 183 | for n in p.flatten(): 184 | if isinstance(n, key.KeySignature): 185 | kind = "sharp" if n.sharps > 0 else "flat" 186 | steps = [p.midi % 12 for p in n.alteredPitches] 187 | for i in range(128): 188 | if i % 12 in steps: 189 | prior_accidental[i] = kind 190 | elif isinstance(n, note.Note): 191 | if n.pitch.accidental is not None: 192 | n.pitch.accidental.displayStatus = True 193 | if n.pitch.accidental.name == prior_accidental[n.pitch.midi]: 194 | n.pitch.accidental.displayStatus = False 195 | mxl.streamStatus.accidentals = True 196 | for part in mxl.parts: 197 | flattened_notes = list(part.flatten().notes) 198 | offset_duration_dict: dict[tuple,list[note.Note]] = {} 199 | for n in flattened_notes: 200 | key_tuple = (n.offset, 
n.duration.quarterLength, n.duration.isGrace) 201 | offset_duration_dict.setdefault(key_tuple, []).append(n) 202 | for notes in offset_duration_dict.values(): 203 | if len(notes) > 1: # Merge needed 204 | c = chord.Chord(notes) 205 | first_note: note.Note = notes[0] 206 | c.expressions = first_note.expressions 207 | c.articulations = first_note.articulations 208 | 209 | # This is faster than part.replace(notes[0], c, recurse=True): 210 | for site in first_note.sites: 211 | if isinstance(site, stream.Voice): 212 | v = site 213 | break 214 | else: 215 | v = first_note.getContextByClass(stream.Voice) 216 | off = v.elementOffset(first_note) 217 | v.remove(first_note) 218 | v.insert(off, c) 219 | c.activeSite = v 220 | # Remove notes that were merged into the chord 221 | # Removing one-by-one is faster than removing all at once because we 222 | # know the direct site of the note, so can avoid recursion. 223 | for n in notes[1:]: 224 | remove_note_fast(n) 225 | # The following code is fairly ugly cleanup to ensure proper MusicXML export. 226 | # It removes empty voices and pads under-full measures with rests. 227 | def merge_and_pad_voices(m): 228 | voices = list(m.voices) 229 | non_empty_voices = [voice for voice in voices if len(voice.notes) > 0] 230 | if non_empty_voices: # we can remove all voices that only contain a rest 231 | for v in voices: 232 | if v not in non_empty_voices: 233 | m.remove(v) 234 | else: # we just keep a single voice with a full duration rest 235 | m.remove(voices) 236 | v = stream.Voice() 237 | v.id = "1" 238 | rest = note.Rest(quarterLength=m.barDuration.quarterLength) 239 | v.append(rest.splitAtDurations()) 240 | m.insert(0, v) 241 | return 242 | # pad non-full voices with rests 243 | for source, v in enumerate(m.voices): 244 | v: stream.Voice 245 | # Clean up overlaps 246 | v.id = str(source + 1) 247 | if m.highestTime < m.barDuration.quarterLength: 248 | quarterLength = m.barDuration.quarterLength - v.highestTime 249 | rest = note.Rest(quarterLength=quarterLength) 250 | v.append(rest.splitAtDurations()) 251 | 252 | 253 | for part in mxl.parts: 254 | measures = list(part.getElementsByClass("Measure")) 255 | for m in measures: 256 | merge_and_pad_voices(m) 257 | 258 | mxl = mxl.splitAtDurations(recurse=True)[0] 259 | return mxl 260 | 261 | -------------------------------------------------------------------------------- /midi2scoretransformer/tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenizer for music21 streams and pretty_midi objects.""" 2 | 3 | import math 4 | from fractions import Fraction 5 | from typing import Dict, List 6 | 7 | import pretty_midi 8 | import torch 9 | import torch.nn.functional as F 10 | from music21 import ( 11 | articulations, 12 | clef, 13 | converter, 14 | expressions, 15 | instrument, 16 | key, 17 | meter, 18 | note, 19 | stream, 20 | tempo, 21 | ) 22 | 23 | from music21.common.numberTools import opFrac 24 | from music21.midi.translate import prepareStreamForMidi 25 | from score_utils import realize_spanners 26 | 27 | 28 | class Downbeat: 29 | MEASURE_NUMBER = 0 30 | OFFSET = 1 31 | LAST_OFFSET = 2 32 | MEASURE_LENGTH = 3 33 | 34 | 35 | db_config = Downbeat.MEASURE_LENGTH 36 | 37 | PARAMS = { 38 | "offset": {"min": 0, "max": 6, "step_size": 1 / 24}, 39 | "duration": {"min": 0, "max": 4, "step_size": 1 / 24}, 40 | "downbeat": {"min": -1 / 24, "max": 6, "step_size": 1 / 24}, 41 | } 42 | 43 | 44 | class MultistreamTokenizer: 45 | @staticmethod 46 | def midi_to_list(midi_path: str) -> 
List[pretty_midi.Note]: 47 | """Converts a MIDI file to a list of notes. 48 | Used during preprocessing. 49 | 50 | Parameters 51 | ---------- 52 | midi_path : str 53 | Path to the midi file. 54 | 55 | Returns 56 | ------- 57 | List[pretty_midi.Note] 58 | """ 59 | midi = pretty_midi.PrettyMIDI(midi_path) 60 | return sorted( 61 | [n for ins in midi.instruments for n in ins.notes], 62 | key=lambda n: (n.start, n.pitch, n.end - n.start), 63 | ) 64 | 65 | @staticmethod 66 | def parse_midi(midi_path: str) -> Dict[str, torch.Tensor]: 67 | """Converts a MIDI file to a list of tensors. 68 | No quantization or bucketing is applied yet. 69 | Used during preprocessing. 70 | 71 | Parameters 72 | ---------- 73 | midi_path : str 74 | Path to the midi file. 75 | 76 | Returns 77 | ------- 78 | Dict[str, torch.Tensor] 79 | returns a dict of tensors of shape (n_notes,) with keys 80 | "onset", "duration", "pitch", "velocity" 81 | """ 82 | midi_list = MultistreamTokenizer.midi_to_list(midi_path) 83 | return { 84 | "onset": torch.FloatTensor([n.start for n in midi_list]), 85 | "duration": torch.FloatTensor([n.end - n.start for n in midi_list]), 86 | "pitch": torch.LongTensor([n.pitch for n in midi_list]), 87 | "velocity": torch.LongTensor([n.velocity for n in midi_list]), 88 | } 89 | 90 | @staticmethod 91 | def bucket_midi(midi_streams: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 92 | """Flexible and fast conversion of raw midi values to token representation. 93 | 94 | Parameters 95 | ---------- 96 | midi_streams : Dict[str, torch.Tensor] 97 | Dict of tensors of shape (n_notes, ) with keys "onset", "duration", "pitch", 98 | "velocity". 99 | Ideally, these tensors are prepared by `MultiStreamTokenizer.parse_midi`. 100 | 101 | Returns 102 | ------- 103 | Dict[str, torch.Tensor] 104 | Dict of tensors of shape (n_notes, n_buckets) with keys "onset", "duration", 105 | "pitch", "velocity" and a padding tensor of shape (n_notes, ) w/ key "pad". 106 | """ 107 | # one-hot 108 | onset_stream = torch.diff(midi_streams["onset"], prepend=torch.Tensor([0.0])) 109 | onset_stream = torch.log(4 * onset_stream + 1) * 4 / math.log(4 * 8.0 + 1) 110 | onset_stream = one_hot_bucketing(onset_stream, 0, 4, 200) 111 | # Squash durations logarithmically such that 0 -> 0 and 16 -> 4 112 | # fmt: off 113 | duration_stream = midi_streams["duration"] 114 | duration_stream = (4 * duration_stream + 1).log() * 4 / math.log(4 * 16.0 + 1) 115 | duration_stream = one_hot_bucketing(duration_stream, 0, 4, 200) 116 | # fmt: on 117 | pitch_stream = one_hot_bucketing(midi_streams["pitch"], 0, 127, 128) 118 | velocity_stream = one_hot_bucketing(midi_streams["velocity"], 0, 127, 8) 119 | return { 120 | "onset": onset_stream.float(), 121 | "duration": duration_stream.float(), 122 | "pitch": pitch_stream.float(), 123 | "velocity": velocity_stream.float(), 124 | "pad": torch.ones((onset_stream.shape[0],), dtype=torch.long), 125 | } 126 | 127 | @staticmethod 128 | def tokenize_midi(midi_path) -> Dict[str, torch.Tensor]: 129 | """Converts a MIDI file to list of tensors. 130 | 131 | Parameters 132 | ---------- 133 | midi_path : str 134 | Path to the midi file. 135 | 136 | Returns 137 | ------- 138 | Dict[str, torch.Tensor] 139 | A dict of tensors that represent token streams. 
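Examples
--------
Illustrative only; the path is a placeholder.

>>> tokens = MultistreamTokenizer.tokenize_midi("performance.mid")  # doctest: +SKIP
>>> sorted(tokens.keys())  # doctest: +SKIP
['duration', 'onset', 'pad', 'pitch', 'velocity']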
140 | """ 141 | midi_streams = MultistreamTokenizer.parse_midi(midi_path) 142 | return MultistreamTokenizer.bucket_midi(midi_streams) 143 | 144 | @staticmethod 145 | def mxl_to_list(mxl_path: str) -> tuple[List[note.Note], stream.Score]: 146 | """Converts a music21 stream to a sorted and deduplicated list of notes. 147 | 148 | Parameters 149 | ---------- 150 | mxl_path : str 151 | Path to the musicxml file. 152 | 153 | Returns 154 | ------- 155 | List[music21.note.Note]: 156 | The list of notes in the music21 stream. 157 | music21.stream.Score: 158 | The music21 stream. This is only returned to 159 | ensure that the stream is not garbage collected. 160 | """ 161 | mxl = converter.parse(mxl_path, forceSource=True) 162 | mxl = realize_spanners(mxl) 163 | mxl: stream.Score = mxl.expandRepeats() 164 | # strip all ties inPlace 165 | mxl.stripTies(preserveVoices=False, inPlace=True) 166 | # Realize Tremolos 167 | for n in mxl.recurse().notes: 168 | for e in n.expressions: 169 | if isinstance(e, expressions.Tremolo): 170 | offset = n.offset 171 | out = e.realize(n, inPlace=True)[0] 172 | v = n.activeSite 173 | v.remove(n) 174 | for n2 in out: 175 | v.insert(offset, n2) 176 | offset += n2.duration.quarterLength 177 | break 178 | mxl = prepareStreamForMidi(mxl) 179 | 180 | notes: list[note.Note] = [] 181 | assert not any(note.isChord for note in mxl.flatten().notes) 182 | 183 | for n in mxl.flatten().notes: 184 | # if note.style.noteSize == "cue": 185 | # continue 186 | if n.style.hideObjectOnPrint: 187 | continue 188 | n.volume.velocity = int(round(n.volume.cachedRealized * 127)) 189 | notes.append(n) 190 | # Sort like this to preserve correct order for grace notes. 191 | def sortTuple(n): 192 | # Sort by offset, then pitch, then duration 193 | # Grace notes that share the same offset are sorted by their insertIndex 194 | # instead of their pitch as they rarely actually occur simultaneously 195 | return ( 196 | n.offset, 197 | not n.duration.isGrace, 198 | n.pitch.midi if not n.duration.isGrace else n.sortTuple(mxl).insertIndex, 199 | n.duration.quarterLength 200 | ) 201 | # return (n.offset, n.pitch.midi, n.duration.quarterLength) 202 | notes_sorted = sorted(notes, key=sortTuple) 203 | notes_consolidated: list[note.Note] = [] 204 | last_note = None 205 | for n in notes_sorted: 206 | if last_note is None or n.offset != last_note.offset or n.pitch.midi != last_note.pitch.midi: 207 | notes_consolidated.append(n) 208 | last_note = n 209 | elif last_note.duration.isGrace: 210 | last_note = n 211 | else: 212 | if n.duration.quarterLength > last_note.duration.quarterLength: 213 | last_note = n 214 | # sort again because we might have changed the duration of grace notes 215 | notes_consolidated = sorted(notes_consolidated, key=sortTuple) 216 | return notes_consolidated, mxl 217 | 218 | @staticmethod 219 | def parse_mxl(mxl_path) -> Dict[str, torch.Tensor]: 220 | """ 221 | Converts a MusixXML file to a list of tensors. 222 | All tensors have shape (n_notes,) and no quantization is applied yet. 223 | Used during preprocessing. 224 | 225 | Parameters 226 | ---------- 227 | mxl_path : str 228 | Path to the musicxml file. 229 | 230 | Returns 231 | ------- 232 | Dict[str, torch.Tensor] 233 | A dict of tensors of shape (n_notes, 1) with keys "offset" 234 | "downbeat", "duration", "pitch", "accidental", "velocity", "grace", "trill", 235 | "staccato", "voice", "stem", "hand". 
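A "keysignature" stream is returned as well.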
236 | """ 237 | # return mxl_stream for garbage collection reasons only 238 | mxl_list, mxl_stream = MultistreamTokenizer.mxl_to_list(mxl_path) 239 | if len(mxl_list) == 0: 240 | offset_stream = torch.Tensor([]) 241 | downbeat_stream = torch.Tensor([]) 242 | duration_stream = torch.Tensor([]) 243 | pitch_stream = torch.Tensor([]) 244 | accidental_stream = torch.Tensor([]) 245 | keysignature_stream = torch.Tensor([]) 246 | velocity_stream = torch.Tensor([]) 247 | grace_stream = torch.Tensor([]) 248 | trill_stream = torch.Tensor([]) 249 | staccato_stream = torch.Tensor([]) 250 | voice_stream = torch.Tensor([]) 251 | stem_stream = torch.Tensor([]) 252 | hand_stream = torch.Tensor([]) 253 | else: 254 | # fmt: off 255 | note_offsets = torch.FloatTensor([n.offset for n in mxl_list]) 256 | measure_offsets = torch.FloatTensor([n.getContextByClass("Measure").offset for n in mxl_list]) 257 | offset_stream = note_offsets - measure_offsets 258 | 259 | if db_config == Downbeat.MEASURE_NUMBER: 260 | nums = torch.tensor([n.getContextByClass("Measure").number for n in mxl_list]) 261 | downbeat_stream = (torch.diff(nums, prepend=torch.tensor([1])) > 0).float() 262 | elif db_config == Downbeat.OFFSET: 263 | downbeat_stream = torch.logical_or(offset_stream == 0, torch.diff(offset_stream, prepend=torch.tensor([0.0])) < 0).float() 264 | elif db_config == Downbeat.LAST_OFFSET: 265 | downbeat_stream = torch.diff(measure_offsets, prepend=torch.tensor([0.0])) > 0 266 | shifts = measure_offsets - torch.cat((torch.tensor([0]), note_offsets[:-1])) 267 | downbeat_stream = torch.where(downbeat_stream, shifts, torch.ones_like(downbeat_stream).float() * PARAMS["downbeat"]["min"]) 268 | elif db_config == Downbeat.MEASURE_LENGTH: 269 | downbeat_stream = torch.diff(measure_offsets, prepend=torch.tensor([0.0])) 270 | downbeat_stream[downbeat_stream<=0] = PARAMS["downbeat"]["min"] 271 | 272 | duration_stream = torch.Tensor([n.duration.quarterLength for n in mxl_list]) 273 | pitch_stream = torch.Tensor([n.pitch.midi for n in mxl_list]) 274 | velocity_stream = torch.Tensor([n.volume.velocity for n in mxl_list]) 275 | def alter_map(accidental): 276 | if accidental is None: 277 | return 5 278 | alter_to_value = {-2: 0, -1: 1, 0: 2, 1: 3, 2: 4} 279 | # if not in the mapping, return 6 (for unknown) 280 | return alter_to_value.get(accidental.alter, 6) 281 | accidental_stream = torch.Tensor([alter_map(n.pitch.accidental) for n in mxl_list]) 282 | # for each note offset, find the last key that occurs before or at the same time as it 283 | keysignatures = {float(e.offset): e for e in mxl_stream.flatten().getElementsByClass(key.KeySignature)} 284 | keysignature_stream = torch.Tensor([next(((v.sharps if v.sharps is not None else 8) for k, v in reversed(keysignatures.items()) if k <= n), 8) for n in note_offsets]) + 7 285 | # MusicXML attribute streams 286 | grace_stream = torch.Tensor([n.duration.isGrace for n in mxl_list]) 287 | trills = (expressions.Trill, expressions.InvertedMordent, expressions.Mordent, expressions.Turn) 288 | trill_stream = torch.Tensor([any(isinstance(e, trills) for e in n.expressions) for n in mxl_list]) 289 | staccatos = (articulations.Staccatissimo, articulations.Staccato) 290 | staccato_stream = torch.Tensor([any(isinstance(e, staccatos) for e in n.articulations) for n in mxl_list]) 291 | voices = [n.getContextByClass("Voice") for n in mxl_list] 292 | voice_stream = torch.Tensor([int(v.id) if v is not None else 0 for v in voices]) 293 | stem_map = {"up": 0, "down": 1, "noStem": 2} 294 | stem_stream = 
torch.Tensor([stem_map.get(n.stemDirection, 3) for n in mxl_list]) 295 | # fmt: on 296 | # Hands/Staff logic is slightly more complicated 297 | # 298 | hand_stream = [] 299 | not_matched = set() 300 | for n in mxl_list: 301 | # Usually part names are similar to "[P1-Staff2]" 302 | part_name = n.getContextByClass("Part").id.lower() 303 | if "staff1" in part_name: 304 | hand_stream.append(0) 305 | elif "staff2" in part_name: 306 | hand_stream.append(1) 307 | else: 308 | hand_stream.append(2) 309 | if part_name not in not_matched: # only one warning per part 310 | not_matched.add(part_name) 311 | # print("Couldn't match", part_name) 312 | hand_stream = torch.tensor(hand_stream) 313 | mxl_stream # keep stream for gc only 314 | return { 315 | "offset": offset_stream, 316 | "downbeat": downbeat_stream, 317 | "duration": duration_stream, 318 | "pitch": pitch_stream, 319 | "accidental": accidental_stream, 320 | "keysignature": keysignature_stream, 321 | "velocity": velocity_stream, 322 | "grace": grace_stream, 323 | "trill": trill_stream, 324 | "staccato": staccato_stream, 325 | "voice": voice_stream, 326 | "stem": stem_stream, 327 | "hand": hand_stream, 328 | } 329 | 330 | @staticmethod 331 | def bucket_mxl(mxl_streams: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 332 | # Bucketing TODO: checkout bucketing 333 | # fmt: off 334 | offset_stream = one_hot_bucketing(mxl_streams["offset"], **PARAMS["offset"]) 335 | duration_stream = one_hot_bucketing(mxl_streams["duration"], **PARAMS["duration"]) 336 | downbeat_stream = one_hot_bucketing(mxl_streams["downbeat"], **PARAMS["downbeat"]) 337 | pitch_stream = one_hot_bucketing(mxl_streams["pitch"], 0, 127, 128) 338 | accidental_stream = one_hot_bucketing(mxl_streams["accidental"], 0, 6, 7) 339 | keysignature_stream = one_hot_bucketing(mxl_streams["keysignature"], 0, 15, 16) 340 | velocity_stream = one_hot_bucketing(mxl_streams["velocity"], 0, 127, 8) 341 | grace_stream = one_hot_bucketing(mxl_streams["grace"], 0, 1, 2) 342 | trill_stream = one_hot_bucketing(mxl_streams["trill"], 0, 1, 2) 343 | staccato_stream = one_hot_bucketing(mxl_streams["staccato"], 0, 1, 2) 344 | voice_stream = one_hot_bucketing(mxl_streams["voice"], 0, 8, 9) 345 | stem_stream = one_hot_bucketing(mxl_streams["stem"], 0, 3, 4) 346 | hand_stream = one_hot_bucketing(mxl_streams["hand"], 0, 2, 3) 347 | # fmt: on 348 | # Beams 349 | # Slurs 350 | # Tuplets 351 | # Dots? 352 | return { 353 | "offset": offset_stream.float(), 354 | "downbeat": downbeat_stream.float(), 355 | "duration": duration_stream.float(), 356 | "pitch": pitch_stream.float(), 357 | "accidental": accidental_stream.float(), 358 | "keysignature": keysignature_stream.float(), 359 | "velocity": velocity_stream.float(), 360 | "grace": grace_stream.float(), 361 | "trill": trill_stream.float(), 362 | "staccato": staccato_stream.float(), 363 | "voice": voice_stream.float(), 364 | "stem": stem_stream.float(), 365 | "hand": hand_stream.float(), 366 | "pad": torch.ones((offset_stream.shape[0],), dtype=torch.long), 367 | } 368 | 369 | @staticmethod 370 | def tokenize_mxl(mxl_path: str) -> Dict[str, torch.Tensor]: 371 | """Converts a MusicXML file to a list of tensors of shape (n_notes). 372 | 373 | Parameters 374 | ---------- 375 | mxl_path : str 376 | Path to the musicxml file. 
377 | 378 | Returns 379 | ------- 380 | torch.Tensor 381 | returns a list of tensors of shape (n_notes,) 382 | """ 383 | mxl_streams = MultistreamTokenizer.parse_mxl(mxl_path) 384 | return MultistreamTokenizer.bucket_mxl(mxl_streams) 385 | 386 | @staticmethod 387 | def detokenize_mxl(token_dict: Dict[str, torch.Tensor], midi_sequence: List[pretty_midi.Note]|None= None) -> stream.Score: 388 | """Decode the token streams into a music21 stream that can be saved to musicxml. 389 | The surprising complexity comes from incompatibilities in music21's XML export. 390 | This function is tested such that saving and reloading the musicxml file should 391 | yield the same score. This is not the case for music21's makeNotation function. 392 | 393 | Parameters 394 | ---------- 395 | token_dict : Dict[str, torch.Tensor] 396 | Dict of tensors of shape (n_notes, n_buckets) with keys "offset", "duration", 397 | "pitch", "velocity" and a padding tensor of shape (n_notes, ) with key "pad". 398 | 399 | Returns 400 | ------- 401 | music21.stream.Stream 402 | music21 stream that can be saved to musicxml. 403 | """ 404 | mask = token_dict["pad"].squeeze() > 0.5 # allow for prediction/soft values 405 | # fmt: off 406 | offset_stream = one_hot_unbucketing(token_dict["offset"][mask], **PARAMS["offset"]).numpy().astype(float) 407 | duration_stream = one_hot_unbucketing(token_dict["duration"][mask], **PARAMS["duration"]).numpy().astype(float) 408 | downbeat_stream = one_hot_unbucketing(token_dict["downbeat"][mask], **PARAMS["downbeat"]).numpy().astype(float) 409 | pitch_stream = one_hot_unbucketing(token_dict["pitch"][mask], 0, 127, 128).numpy().astype(int) 410 | accidental_stream = one_hot_unbucketing(token_dict["accidental"][mask][:, :6], 0, 6, 7).numpy().astype(int) 411 | keysignature_stream = one_hot_unbucketing(token_dict["keysignature"][mask], 0, 15, 16).numpy().astype(int) 412 | velocity_stream = one_hot_unbucketing(token_dict["velocity"][mask], 0, 127, 8).numpy().astype(int) 413 | grace_stream = one_hot_unbucketing(token_dict["grace"][mask], 0, 1, 2).numpy().astype(bool) 414 | trill_stream = one_hot_unbucketing(token_dict["trill"][mask], 0, 1, 2).numpy().astype(bool) 415 | staccato_stream = one_hot_unbucketing(token_dict["staccato"][mask], 0, 1, 2).numpy().astype(bool) 416 | voice_stream = one_hot_unbucketing(token_dict["voice"][mask][:, 1:], 1, 8, 8).numpy().astype(int) 417 | stem_stream = one_hot_unbucketing(token_dict["stem"][mask][:, :3], 0, 3, 4).numpy().astype(int) 418 | hand_stream = one_hot_unbucketing(token_dict["hand"][mask][:, :2], 0, 2, 3).numpy().astype(int) 419 | 420 | if midi_sequence is not None: 421 | midi_sequence = [m for i, m in enumerate(midi_sequence) if mask[i]] 422 | # fmt: on 423 | measures: list[list[stream.Measure]] = [[], []] 424 | active_voices_list: list[list[set[int]]] = [[], []] 425 | # We go through all notes twice, once for each hand. 426 | # We create measures/increment times etc. both times. 427 | # However, only the notes for the current part are inserted. 428 | # This is highly inefficient, but ensures correctness for now. 
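# part 0 collects the notes with hand_stream == 0 (staff 1, given a treble clef
# below); part 1 collects hand_stream == 1 (staff 2, bass clef).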
429 | for part in range(2): 430 | active_voices = set() 431 | m = stream.Measure(number=1) 432 | voices = [stream.Voice(id=str(i)) for i in range(1, 17)] 433 | previous_note_is_downbeat = True 434 | last_measure_duration = None 435 | last_keysignature = None 436 | for i in range(len(offset_stream)): 437 | if db_config == Downbeat.MEASURE_NUMBER: 438 | if downbeat_stream[i]: 439 | for v in voices: 440 | m.insert(0, v) 441 | measures[part].append(m) 442 | active_voices_list[part].append(active_voices) 443 | active_voices = set() 444 | voices = [stream.Voice(id=str(i)) for i in range(1, 17)] 445 | m = stream.Measure(m.number + 1) 446 | elif db_config == Downbeat.OFFSET: 447 | if ( 448 | downbeat_stream[i] == 1 449 | and offset_stream[i] <= offset_stream[i - 1] 450 | ) or (i > 0 and offset_stream[i] < offset_stream[i - 1]): 451 | if not previous_note_is_downbeat: 452 | for v in voices: 453 | m.insert(0, v) 454 | measures[part].append(m) 455 | active_voices_list[part].append(active_voices) 456 | active_voices = set() 457 | voices = [stream.Voice(id=str(i)) for i in range(1, 17)] 458 | m = stream.Measure(m.number + 1) 459 | previous_note_is_downbeat = True 460 | else: 461 | previous_note_is_downbeat = False 462 | elif db_config in (Downbeat.LAST_OFFSET, Downbeat.MEASURE_LENGTH): 463 | if ( 464 | 1 < i + 1 < len(downbeat_stream) 465 | and downbeat_stream[i] >= 0 466 | and not ( 467 | downbeat_stream[i + 1] >= 0 468 | and offset_stream[i - 1] != 0 469 | and offset_stream[i + 1] <= offset_stream[i] 470 | ) 471 | ): 472 | if midi_sequence is not None and i < len(midi_sequence): 473 | # If we have the input midi timings, we can use them to set the tempo 474 | # We first set tempo marks to track where their location `should` be 475 | # The inserted tempo marks therefore form (offset, time in seconds) pairs. 
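# Scan up to two notes ahead and behind (s = 0, 1, 2) for a MIDI note whose pitch
# matches the current score note; its onset in seconds is stored as the
# MetronomeMark's `number`. If no pitch matches, fall back to the i-th MIDI note.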
476 | for s in range(3): 477 | if midi_sequence[min(i+s, len(midi_sequence))].pitch == pitch_stream[i]: 478 | m.insert(opFrac(offset_stream[i]), tempo.MetronomeMark(number=midi_sequence[min(i+s, len(midi_sequence))].start)) 479 | break 480 | elif midi_sequence[max(i-s, 0)].pitch == pitch_stream[i]: 481 | m.insert(opFrac(offset_stream[i]), tempo.MetronomeMark(number=midi_sequence[max(i-s, 0)].start)) 482 | break 483 | else: 484 | m.insert(opFrac(offset_stream[i]), tempo.MetronomeMark(number=midi_sequence[i].start)) 485 | for v in voices: 486 | m.insert(0, v) 487 | if db_config == Downbeat.MEASURE_LENGTH: 488 | duration = opFrac(downbeat_stream[i]) 489 | else: 490 | duration = opFrac(offset_stream[i - 1] + downbeat_stream[i]) 491 | 492 | def find_time_signature(measure_length: float) -> meter.TimeSignature|None: 493 | frac = Fraction(measure_length / 4 + 1e-5).limit_denominator(16) 494 | if frac.numerator != 0: 495 | if frac.denominator == 1: 496 | return meter.TimeSignature(f"{frac.numerator * 4}/4") 497 | elif frac.denominator == 2: 498 | return meter.TimeSignature(f"{frac.numerator * 4}/8") 499 | elif frac.denominator == 4: 500 | return meter.TimeSignature(f"{frac.numerator}/4") 501 | elif frac.denominator == 8: 502 | return meter.TimeSignature(f"{frac.numerator}/8") 503 | elif frac.denominator == 16: 504 | return meter.TimeSignature(f"{frac.numerator}/16") 505 | return None 506 | 507 | if duration != 0 and duration != last_measure_duration: 508 | ts = find_time_signature(duration) 509 | if ts is not None: 510 | m.insert(0, ts) 511 | elif last_measure_duration is None: # first measure 512 | m.insert(0, meter.TimeSignature("4/4")) 513 | last_measure_duration = duration 514 | measures[part].append(m) 515 | active_voices_list[part].append(active_voices) 516 | active_voices = set() 517 | voices = [stream.Voice(id=str(i)) for i in range(1, 17)] 518 | m = stream.Measure(m.number + 1) 519 | if keysignature_stream[i] != last_keysignature: 520 | m.insert(0, key.KeySignature(keysignature_stream[i] - 7)) 521 | last_keysignature = keysignature_stream[i] 522 | # inefficient but don't care for now 523 | if hand_stream[i] == part: 524 | n = note.Note() 525 | n.duration.quarterLength = opFrac(duration_stream[i]) 526 | # Adding accidentals shifts the pitch step so we have to account for that 527 | # by offsetting the other way first 528 | accidental_mapping = { 529 | 0: (+2, "double-flat"), 530 | 1: (+1, "flat"), 531 | 2: (0, "natural"), 532 | 3: (-1, "sharp"), 533 | 4: (-2, "double-sharp") 534 | } 535 | midi_adjustment, accidental_name = accidental_mapping.get(accidental_stream[i], (0, None)) 536 | n.pitch.midi = pitch_stream[i] + midi_adjustment 537 | if accidental_name is not None: 538 | n.pitch.accidental = accidental_name 539 | else: 540 | # Handle the case where the accidental_stream value is outside the expected range 541 | n.pitch.midi = pitch_stream[i] 542 | if n.pitch.midi != pitch_stream[i]: 543 | print(f"Mismatch: {n.pitch.midi} != {pitch_stream[i]}") 544 | n.pitch.midi = pitch_stream[i] 545 | 546 | # n.volume.velocity = velocity_stream[i] 547 | if trill_stream[i]: 548 | n.expressions.append(expressions.Trill()) 549 | if staccato_stream[i]: 550 | n.articulations.append(articulations.Staccato()) 551 | if grace_stream[i] or n.duration.quarterLength == 0: 552 | # obscure bug in makeNotation forces us to set the duration to 0 553 | n.duration.quarterLength = 0 554 | n.duration = n.duration.getGraceDuration() 555 | stem_map = {0: "up", 1: "down", 2: "noStem"} 556 | n.stemDirection = 
stem_map[stem_stream[i]] 557 | 558 | v = voice_stream[i] 559 | # We need to find a voice that is not active at the current offset. 560 | # (also have to consider the previous part!). 561 | def find_suitable_voice(v): 562 | candidates = sorted( 563 | range(len(voices)), key=lambda x: (abs(x - v), -x) 564 | ) 565 | not_ideal = None 566 | for candidate in candidates: 567 | # Voice already used in other part? 568 | if part == 1 and len(measures[1]) < len(measures[0]): 569 | if candidate in active_voices_list[0][len(measures[1])]: 570 | continue 571 | # Voice already used in current measure at current timestep? 572 | if candidate in active_voices: 573 | o = opFrac(offset_stream[i]) 574 | for n in voices[candidate].notes: 575 | if opFrac(n.offset + n.duration.quarterLength) > o: 576 | break 577 | else: 578 | return candidate 579 | elif not_ideal is None: 580 | not_ideal = candidate 581 | return v if not_ideal is None else not_ideal 582 | 583 | v_new = find_suitable_voice(v - 1) 584 | active_voices.add(v_new) 585 | voices[v_new].insert(opFrac(offset_stream[i]), n) 586 | if midi_sequence is not None and i < len(midi_sequence): 587 | for s in range(3): 588 | if midi_sequence[max(i-s, 0)].pitch == pitch_stream[i]: 589 | m.insert(opFrac(offset_stream[i]), tempo.MetronomeMark(number=midi_sequence[max(i-s, 0)].start)) 590 | break 591 | else: 592 | m.insert(opFrac(offset_stream[i]), tempo.MetronomeMark(number=midi_sequence[i].start)) 593 | for v in voices: 594 | m.insert(0, v) 595 | measures[part].append(m) 596 | active_voices_list[part].append(active_voices) 597 | 598 | s = stream.Score() 599 | if db_config not in (Downbeat.LAST_OFFSET, Downbeat.MEASURE_LENGTH): 600 | lastDuration = -1 601 | for m0, m1 in zip(measures[0], measures[1]): 602 | try: 603 | if m0.flatten().highestTime >= m1.flatten().highestTime: 604 | ts = m0.bestTimeSignature() 605 | else: 606 | ts = m1.bestTimeSignature() 607 | except meter.MeterException: 608 | print(f"Couldn't find time signature : ({m0.highestTime} {m1.highestTime}, {m0.duration.quarterLength}, {m1.duration.quarterLength})") 609 | continue 610 | if ts.barDuration.quarterLength != lastDuration: 611 | m0.timeSignature = ts 612 | m1.timeSignature = ts 613 | for part in range(2): 614 | p = stream.Part() 615 | # Inserting an instrument is required to ensure deterministic part names 616 | # MUSTER does not play nice with names other than P1, P2, ..., so we force that here. 
617 | ins = instrument.Instrument() 618 | ins.partId = f"P{part+1}" 619 | p.insert(0, ins) 620 | offset = 0 621 | for i, m in enumerate(measures[part]): 622 | # Special case pickup measure 623 | if i == 0: 624 | c = clef.TrebleClef() if part == 0 else clef.BassClef() 625 | m.insert(0, c) 626 | p.insert(0, m) 627 | else: 628 | p.insert(offset, m) 629 | offset += m.barDuration.quarterLength 630 | s.insert(0, p) 631 | return s 632 | 633 | 634 | def one_hot_bucketing( 635 | values: torch.Tensor | List[int | float], min, max, buckets=None, step_size=None 636 | ) -> torch.Tensor: 637 | assert buckets is not None or step_size is not None 638 | if not isinstance(values, torch.Tensor): 639 | values = torch.tensor(values) 640 | if values.ndim == 2: 641 | values = values.squeeze(1) 642 | values = values.float() 643 | 644 | # discretize the values into buckets 645 | if buckets is None: 646 | buckets = int((max + step_size - min) / step_size) 647 | bucket_indices = ((values - min) / (max + step_size - min) * buckets).round() 648 | else: 649 | bucket_indices = (values - min) / (max - min) * buckets 650 | # clamp the bucket indices to be between 0 and n_buckets - 1 651 | bucket_indices = bucket_indices.long().clamp(0, buckets - 1) 652 | one_hots = F.one_hot(bucket_indices, num_classes=buckets) 653 | return one_hots 654 | 655 | 656 | def one_hot_unbucketing( 657 | one_hots: torch.Tensor | List[int | float], min, max, buckets=None, step_size=None 658 | ) -> torch.FloatTensor: 659 | assert buckets is not None or step_size is not None 660 | if not isinstance(one_hots, torch.Tensor): 661 | one_hots = torch.tensor(one_hots) 662 | 663 | # Convert the one-hot vectors back into bucket indices 664 | bucket_indices = torch.argmax(one_hots, dim=-1) 665 | # Convert the bucket indices back into the original values 666 | if step_size is None: 667 | step_size = (max + 1 - min) / buckets 668 | values = min + bucket_indices.float() * step_size 669 | return values 670 | 671 | 672 | def positional_embedding(times, dim=128) -> torch.Tensor: 673 | if not isinstance(times, torch.Tensor): 674 | times = torch.tensor(times) # (T, D) 675 | divisors = torch.tensor([10000 ** (2 * (i // 2) / dim) for i in range(dim)]) # D 676 | position_enc = times.unsqueeze(-1) / divisors.unsqueeze(0) 677 | sines = torch.sin(position_enc[:, 0::2]) 678 | cosines = torch.cos(position_enc[:, 1::2]) 679 | return torch.cat([sines, cosines], dim=-1) 680 | 681 | 682 | if __name__ == "__main__": 683 | import argparse 684 | import sys 685 | 686 | from score_transformer import score_similarity 687 | 688 | sys.path.append("../") 689 | 690 | parser = argparse.ArgumentParser() 691 | parser.add_argument("--midi", type=str, default=None) 692 | parser.add_argument( 693 | "--mxl", 694 | type=str, 695 | default="./data/asap-dataset/Bach/Fugue/bwv_846/xml_score.musicxml", 696 | ) 697 | args = parser.parse_args() 698 | 699 | if args.midi is not None: 700 | tokenized = MultistreamTokenizer.tokenize_midi(args.midi) 701 | print(tokenized) 702 | 703 | if args.mxl is not None: 704 | mxl, _ = MultistreamTokenizer.mxl_to_list(args.mxl) 705 | tokenized = MultistreamTokenizer.tokenize_mxl(args.mxl) 706 | s_recon = MultistreamTokenizer.detokenize_mxl(tokenized) 707 | s = converter.parse(args.mxl, forceSource=True).expandRepeats() 708 | print(score_similarity(s_recon, s)) 709 | -------------------------------------------------------------------------------- /midi2scoretransformer/utils.py: -------------------------------------------------------------------------------- 1 | 
"""Utilities for inference, including batched and chunked inference and postprocessing.""" 2 | import warnings 3 | 4 | import torch 5 | from muster import muster 6 | from score_transformer import score_similarity 7 | 8 | from tokenizer import MultistreamTokenizer 9 | from score_utils import postprocess_score 10 | 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | 13 | 14 | def eval(y_hat, gt_mxl_path: str) -> dict[str, dict[str, float]|None]: 15 | mxl = MultistreamTokenizer.detokenize_mxl(y_hat) 16 | mxl = postprocess_score(mxl, inPlace=True) 17 | 18 | # fmt: off 19 | with warnings.catch_warnings(): 20 | warnings.simplefilter("ignore") 21 | sim = { 22 | "mxl <-> gt_mxl": score_similarity_normalized(mxl, gt_mxl_path, full=False), 23 | "muster": muster(mxl, gt_mxl_path), 24 | } 25 | return sim 26 | # fmt: on 27 | 28 | 29 | def score_similarity_normalized(est, gt, full=False): 30 | if est is None or gt is None: 31 | return { 32 | "Clef": None, 33 | "KeySignature": None, 34 | "TimeSignature": None, 35 | "NoteDeletion": None, 36 | "NoteInsertion": None, 37 | "NoteSpelling": None, 38 | "NoteDuration": None, 39 | "StemDirection": None, 40 | "Beams": None, 41 | "Tie": None, 42 | "StaffAssignment": None, 43 | "Voice": None, 44 | } 45 | sim = score_similarity(est, gt, full=full) 46 | new_sim = {} 47 | for k, v in sim.items(): 48 | if v is None: 49 | new_sim[k] = None 50 | elif k == "n_Note" or any(key in k for key in ["F1", "Rec", "Prec", "TP", "FP", "FN", "TN"]): 51 | new_sim[k] = v 52 | else: 53 | new_sim[k] = v / sim["n_Note"] 54 | return new_sim 55 | 56 | 57 | 58 | def quantize_path(path, model, **kwargs): 59 | """Quantize a midi file at `path` using the model `model`. 60 | The resulting score should be saved with makeNotation=False. 61 | """ 62 | x = MultistreamTokenizer.tokenize_midi(path) 63 | y_hat = infer(x, model, **kwargs) 64 | mxl = MultistreamTokenizer.detokenize_mxl(y_hat) 65 | mxl = postprocess_score(mxl) 66 | return mxl 67 | 68 | 69 | def infer(x, model, overlap=64, chunk=512, verbose=True, kv_cache=True) -> dict[str, torch.Tensor]: 70 | single_example = x['pitch'].ndim == 2 71 | if single_example: 72 | x = {k: v.unsqueeze(0) for k, v in x.items()} 73 | x = {k: v.to(model.device) for k, v in x.items()} 74 | if chunk <= overlap: 75 | raise ValueError("`chunk` must be greater than `overlap`.") 76 | y_full = None 77 | for i in range(0, max(x['pitch'].shape[1] - overlap, 1), chunk - overlap): 78 | if verbose: 79 | print("Infer", i, "/", x['pitch'].shape[1], end='\r') 80 | x_chunk = {k: v[:, i:i + chunk] for k, v in x.items()} 81 | if i == 0 or overlap == 0: # No context required 82 | y_hat = model.generate(x=x_chunk, top_k=1, max_length=chunk, kv_cache=kv_cache) 83 | else: 84 | # Keep the last 'overlap' notes of the previous chunk as context 85 | y_hat_prev = {k: v[:, -overlap:] if k != 'pad' else v[:, -overlap:, 0] for k, v in y_full.items()} 86 | with torch.autocast(device_type=device): 87 | y_hat = model.generate(x=x_chunk, y=y_hat_prev, top_k=1, max_length=chunk, kv_cache=kv_cache) 88 | y_hat = {k: v[:, overlap:] for k, v in y_hat.items()} 89 | 90 | if y_full is None: 91 | y_full = y_hat 92 | else: 93 | for k in y_full: 94 | y_full[k] = torch.cat((y_full[k], y_hat[k]), dim=1) 95 | if single_example: 96 | y_full = {k: v[0].cpu() for k, v in y_full.items()} 97 | else: 98 | y_full = {k: v.cpu() for k, v in y_full.items()} 99 | return y_full 100 | 101 | 102 | def pad_batch(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: 103 | """Pad a batch of irregular 
tensors to the same length, then concat.""" 104 | max_len = max([x["pitch"].shape[1] for x in batch]) 105 | for x in batch: 106 | for k, v in x.items(): 107 | pad_length = max_len - v.shape[1] 108 | shape = list(v.shape) 109 | shape[1] = pad_length 110 | x[k] = torch.cat( 111 | (v, torch.zeros(shape, dtype=v.dtype, device=v.device)), dim=1 112 | ) 113 | # Concat along first dim: 114 | out = {} 115 | for k in batch[0].keys(): 116 | out[k] = torch.cat([x[k] for x in batch], dim=0) 117 | return out 118 | 119 | 120 | def cat_dict( 121 | a: dict[str, torch.Tensor], b: dict[str, torch.Tensor], dim=0 122 | ) -> dict[str, torch.Tensor]: 123 | assert set(a.keys()) == set(b.keys()) 124 | return {k: torch.cat([a[k], b[k]], dim=dim) for k in a.keys()} 125 | 126 | 127 | def cut_pad( 128 | tensor: torch.Tensor, max_len: int, offset: int, pad_value: int = 0 129 | ) -> torch.Tensor: 130 | """ 131 | Cut a tensor's first dimension to a maximum length and pad the tensor's first 132 | dimension to a minimum length. 133 | 134 | Args: 135 | tensor (Tensor): tensor to be cut or padded 136 | max_len (int): maximum length of the tensor's first dimension 137 | offset (int): offset to cut the tensor if too long 138 | pad_value (int): value used for padding, default is 0 139 | 140 | Returns: 141 | Tensor: tensor cut or padded along its first dimension to shape (max_len,) 142 | or (max_len, n_cols) if input is 2D 143 | """ 144 | if tensor.dim() == 1: 145 | n = tensor.size(0) 146 | if n > max_len: 147 | tensor = tensor[offset : offset + max_len] 148 | elif n < max_len: 149 | pad_size = max_len - n 150 | pad = torch.full((pad_size,), pad_value, dtype=tensor.dtype) 151 | tensor = torch.cat((tensor, pad), dim=0) 152 | elif tensor.dim() == 2: 153 | n, n_cols = tensor.size() 154 | if n > max_len: 155 | tensor = tensor[offset : offset + max_len] 156 | elif n < max_len: 157 | pad_size = max_len - n 158 | pad = torch.full((pad_size, n_cols), pad_value, dtype=tensor.dtype) 159 | tensor = torch.cat((tensor, pad), dim=0) 160 | else: 161 | raise ValueError("Input tensor must be 1D or 2D.") 162 | 163 | return tensor -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | joblib>=1.3.1 2 | numba>=0.53.1 3 | pandas>=2.0.0 4 | pretty_midi>=0.2.10 5 | tokenizers>=0.13.3 6 | torch>=2.0.0 7 | transformers>=4.29.2 8 | lightning>=2.0.0 9 | 10 | muster @ git+https://github.com/TimFelixBeyer/amtevaluation.github.io 11 | music21 @ git+https://github.com/TimFelixBeyer/music21@0ed70bb 12 | score_transformer @ git+https://github.com/TimFelixBeyer/ScoreTransformer@934a228 13 | --------------------------------------------------------------------------------
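Example usage (sketch): the snippet below ties the pieces above together via
`utils.quantize_path` — MIDI tokenization, chunked autoregressive inference,
detokenization to a music21 score, and postprocessing. It assumes `model` is an
already-loaded, trained `Roformer` instance (how the checkpoint is restored depends
on `BaseModel` in models/model.py, which is not shown here); file paths are
placeholders.

# run from inside the midi2scoretransformer/ directory
from utils import quantize_path

score = quantize_path("performance.mid", model, chunk=512, overlap=64)
# Per quantize_path's docstring, save with makeNotation=False so music21 does not
# re-run notation fixes on the already-postprocessed score.
score.write("musicxml", fp="performance_score.musicxml", makeNotation=False)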