├── LICENSE.md ├── README.md ├── assets ├── perf-embs-clf-4x6.png └── scoreperformer.png ├── data ├── directions │ └── direction_classes.json └── tokenizers │ ├── spmuple_bar.json │ ├── spmuple_beat.json │ ├── spmuple_onset.json │ └── spmuple_window.json ├── recipes ├── default.yaml └── scoreperformer │ ├── ablation │ ├── no_cont_tokens.yaml │ ├── no_io_tie.yaml │ ├── no_masked_seq.yaml │ ├── no_saln.yaml │ └── no_score_enc.yaml │ ├── base.yaml │ ├── custom_hierarchy.yaml │ ├── minimal.yaml │ └── no_classifiers.yaml ├── requirements.txt └── scoreperformer ├── __init__.py ├── data ├── __init__.py ├── collators │ ├── __init__.py │ ├── common.py │ ├── directions.py │ ├── performance.py │ └── score_performance.py ├── datasets │ ├── __init__.py │ ├── directions.py │ ├── performance.py │ ├── score_performance.py │ ├── token_sequence.py │ └── utils.py ├── directions │ ├── __init__.py │ ├── articulation.py │ ├── dynamic.py │ ├── parser.py │ ├── tempo.py │ └── words.py ├── helpers │ ├── __init__.py │ ├── indexers.py │ └── processors.py ├── midi │ ├── __init__.py │ ├── beats.py │ ├── containers.py │ ├── preprocess.py │ ├── quantization.py │ ├── sync.py │ ├── timing.py │ └── utils.py ├── music_constants.py └── tokenizers │ ├── __init__.py │ ├── classes.py │ ├── common │ ├── __init__.py │ └── octuple_m.py │ ├── constants.py │ ├── midi_tokenizer.py │ └── spmuple │ ├── __init__.py │ ├── base.py │ ├── encodings.py │ ├── spmuple.py │ └── spmuple2.py ├── experiments ├── __init__.py ├── callbacks.py ├── components.py ├── integrations.py ├── logging │ ├── __init__.py │ └── console_logger.py ├── optimizers.py ├── trainer.py ├── trainer_config.py └── trainer_utils.py ├── inference ├── __init__.py ├── generators.py └── messengers.py ├── models ├── __init__.py ├── base.py ├── classifiers │ ├── __init__.py │ ├── evaluator.py │ └── model.py └── scoreperformer │ ├── __init__.py │ ├── embeddings.py │ ├── evaluator.py │ ├── mmd_transformer.py │ ├── model.py │ ├── transformer.py │ └── wrappers.py ├── modules ├── __init__.py ├── constructor.py ├── layers.py ├── sampling.py └── transformer │ ├── __init__.py │ ├── attend.py │ ├── attention.py │ ├── embeddings.py │ ├── feedforward.py │ └── transformer.py ├── train.py └── utils ├── __init__.py ├── config.py ├── functions.py ├── io.py ├── playback.py └── plots.py /README.md: -------------------------------------------------------------------------------- 1 | # ScorePerformer: Expressive Piano Performance Rendering Model 2 | 3 | ScorePerformer architecture 4 | 5 | > Code for the paper [*"ScorePerformer: Expressive Piano Performance Rendering with Fine-Grained 6 | Control"*](https://archives.ismir.net/ismir2023/paper/000069.pdf) published in the Proceedings of the 7 | > [24th International Society for Music Information Retrieval (ISMIR) Conference, Milan, 2023](https://ismir2023.ismir.net). 8 | > 9 | >Authors: Ilya Borovik and Vladimir Viro 10 | > 11 | >[![Paper](https://img.shields.io/badge/ISMIR_2023-Paper-blue)](https://archives.ismir.net/ismir2023/paper/000069.pdf) 12 | [![Poster](https://img.shields.io/badge/ISMIR_2023-Poster_&_Video-blue)](http://ismir2023program.ismir.net/poster_183.html) 13 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1uKEX02D8cn3wzG-oS7YWE3IW-OrGV22w) 14 | 15 | ## Description 16 | 17 | The paper presents ScorePerformer, a controllable piano performance rendering model that: 18 | 19 | 1. 
learns fine-grained performance style features at the global, bar, beat, and onset levels using transformers and 20 | hierarchical maximum mean discrepancy variational autoencoders; 21 | 2. provides musical language-driven control over the learned style space through a set of trained performance direction 22 | classifiers; 23 | 3. utilizes a tokenized encoding for aligned score and performance music (SPMuple) with a smooth local window tempo 24 | function. 25 | 26 | The repository provides the code for the model and tokenizers, along with the training and inference pipelines. The 27 | trained model is available to interact with in a simple [Colab demo](https://colab.research.google.com/drive/1uKEX02D8cn3wzG-oS7YWE3IW-OrGV22w). 28 | 29 | *Note*: The code for the alignment and preprocessing of score and performance data is deliberately not made available. 30 | It can be discussed and released upon request. 31 | 32 | If you have any questions or requests, please write to ilya.borovik@skoltech.ru. 33 | 34 | ## Latent Style Space Control 35 | 36 | Our approach to making expressive performance rendering controllable relies on the performance direction 37 | markings available in musical scores. We use a two-step approach: 38 | 39 | 1. learn performance style features relevant to performance rendering using a performance style encoder; 40 | 2. train performance direction classifiers on the learned style embeddings and "label" the style space using their predictions. 41 | 42 | Below is a visualization of the top two principal components of the learned latent style spaces, labelled by a subset of 43 | the trained performance direction classifiers. 44 | 45 | Projected Latent Style Spaces 46 | 47 | ## Citing 48 | 49 | If you use or reference the model, please cite the following paper: 50 | 51 | ``` 52 | @inproceedings{borovik2023scoreperformer, 53 | title={{ScorePerformer: Expressive Piano Performance Rendering with Fine-Grained Control}}, 54 | author={Borovik, Ilya and Viro, Vladimir}, 55 | booktitle={Proceedings of the 24th International Society for Music Information Retrieval Conference {(ISMIR)}}, 56 | year={2023}, 57 | url={https://archives.ismir.net/ismir2023/paper/000069.pdf}, 58 | } 59 | ``` 60 | 61 | ## License 62 | 63 | The work is released under a [CC BY-NC-SA 4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/).
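For orientation, here is a minimal sketch of how the shipped tokenizer configurations under `data/tokenizers/` can be restored. It follows the pattern used by the dataset code (`TOKENIZERS[encoding](params=...)` in `scoreperformer/data/datasets/performance.py`); the `"SPMupleWindow"` registry key is an assumption read off the `tokenization` field of `spmuple_window.json`, not a verified API.

```python
# Minimal sketch, assuming the repository root is on PYTHONPATH and that
# TOKENIZERS maps the config's "tokenization" name to a tokenizer class
# accepting a saved `params` JSON (the miditok==2.1.6 convention the
# dataset code itself relies on).
from scoreperformer.data.tokenizers import TOKENIZERS

tokenizer = TOKENIZERS["SPMupleWindow"](params="data/tokenizers/spmuple_window.json")
```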
-------------------------------------------------------------------------------- /assets/perf-embs-clf-4x6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/assets/perf-embs-clf-4x6.png -------------------------------------------------------------------------------- /assets/scoreperformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/assets/scoreperformer.png -------------------------------------------------------------------------------- /data/directions/direction_classes.json: -------------------------------------------------------------------------------- 1 | { 2 | "dynamic/absolute": [ 3 | "dynamic/ppp", 4 | "dynamic/pp", 5 | "dynamic/p", 6 | "dynamic/mp", 7 | "dynamic/mf", 8 | "dynamic/f", 9 | "dynamic/ff", 10 | "dynamic/fff", 11 | "dynamic/fp" 12 | ], 13 | "dynamic/hairpin": [ 14 | "dynamic/crescendo", 15 | "dynamic/diminuendo" 16 | ], 17 | "dynamic/accent": [ 18 | "dynamic/sf", 19 | "dynamic/rf" 20 | ], 21 | "tempo/absolute": [ 22 | "tempo/grave", 23 | "tempo/largo", 24 | "tempo/larghetto", 25 | "tempo/lento", 26 | "tempo/adagio", 27 | "tempo/andante", 28 | "tempo/andantino", 29 | "tempo/moderato", 30 | "tempo/allegretto", 31 | "tempo/allegro", 32 | "tempo/vivace", 33 | "tempo/presto", 34 | "tempo/prestissimo" 35 | ], 36 | "tempo/relative": [ 37 | "tempo/accelerando", 38 | "tempo/ritardando", 39 | "tempo/rallentando", 40 | "tempo/stringendo", 41 | "tempo/calando", 42 | "tempo/pi\u00f9 mosso", 43 | "tempo/animato", 44 | "tempo/stretto", 45 | "tempo/smorzando", 46 | "tempo/ritenuto" 47 | ], 48 | "articulation/arpeggiate": [ 49 | "articulation/arpeggiate" 50 | ], 51 | "articulation/fermata": [ 52 | "articulation/fermata" 53 | ], 54 | "articulation/staccato": [ 55 | "articulation/staccato" 56 | ], 57 | "articulation/tenuto": [ 58 | "articulation/tenuto" 59 | ] 60 | } -------------------------------------------------------------------------------- /data/tokenizers/spmuple_window.json: -------------------------------------------------------------------------------- 1 | { 2 | "config": { 3 | "pitch_range": [ 4 | 21, 5 | 109 6 | ], 7 | "beat_res": { 8 | "0_2": 16, 9 | "2_4": 8, 10 | "4_8": 4, 11 | "8_16": 2, 12 | "16_64": 1 13 | }, 14 | "nb_velocities": 127, 15 | "special_tokens": [ 16 | "PAD", 17 | "MASK", 18 | "SOS", 19 | "EOS" 20 | ], 21 | "use_chords": false, 22 | "use_rests": false, 23 | "use_tempos": true, 24 | "use_time_signatures": true, 25 | "use_sustain_pedals": false, 26 | "use_pitch_bends": false, 27 | "use_programs": false, 28 | "beat_res_rest": { 29 | "0_1": 8, 30 | "1_2": 4, 31 | "2_12": 2 32 | }, 33 | "chord_maps": { 34 | "min": [ 35 | 0, 36 | 3, 37 | 7 38 | ], 39 | "maj": [ 40 | 0, 41 | 4, 42 | 7 43 | ], 44 | "dim": [ 45 | 0, 46 | 3, 47 | 6 48 | ], 49 | "aug": [ 50 | 0, 51 | 4, 52 | 8 53 | ], 54 | "sus2": [ 55 | 0, 56 | 2, 57 | 7 58 | ], 59 | "sus4": [ 60 | 0, 61 | 5, 62 | 7 63 | ], 64 | "7dom": [ 65 | 0, 66 | 4, 67 | 7, 68 | 10 69 | ], 70 | "7min": [ 71 | 0, 72 | 3, 73 | 7, 74 | 10 75 | ], 76 | "7maj": [ 77 | 0, 78 | 4, 79 | 7, 80 | 11 81 | ], 82 | "7halfdim": [ 83 | 0, 84 | 3, 85 | 6, 86 | 10 87 | ], 88 | "7dim": [ 89 | 0, 90 | 3, 91 | 6, 92 | 9 93 | ], 94 | "7aug": [ 95 | 0, 96 | 4, 97 | 8, 98 | 11 99 | ], 100 | "9maj": [ 101 | 0, 102 | 4, 103 | 7, 104 | 10, 105 | 14 106 | ], 107 | "9min": [ 108 | 0, 109 | 
4, 110 | 7, 111 | 10, 112 | 13 113 | ] 114 | }, 115 | "chord_tokens_with_root_note": false, 116 | "chord_unknown": null, 117 | "nb_tempos": 121, 118 | "tempo_range": [ 119 | 15, 120 | 480 121 | ], 122 | "log_tempos": true, 123 | "delete_equal_successive_tempo_changes": true, 124 | "time_signature_range": { 125 | "2": [ 126 | 1, 127 | 2, 128 | 3, 129 | 4 130 | ], 131 | "4": [ 132 | 1, 133 | 2, 134 | 3, 135 | 4, 136 | 5, 137 | 6 138 | ], 139 | "8": [ 140 | 1, 141 | 2, 142 | 3, 143 | 4, 144 | 5, 145 | 6, 146 | 7, 147 | 8, 148 | 9, 149 | 10, 150 | 11, 151 | 12 152 | ] 153 | }, 154 | "delete_equal_successive_time_sig_changes": true, 155 | "sustain_pedal_duration": false, 156 | "pitch_bend_range": [ 157 | -8192, 158 | 8191, 159 | 32 160 | ], 161 | "programs": [ 162 | 0 163 | ], 164 | "one_token_stream_for_programs": true, 165 | "additional_params": { 166 | "nb_onset_devs": 161, 167 | "nb_perf_durations": 81, 168 | "max_bar_embedding": 256, 169 | "rel_onset_dev": true, 170 | "rel_perf_duration": true, 171 | "real_max_bar_embedding": 256, 172 | "fill_unperformed_notes": true, 173 | "remove_duplicates": false, 174 | "token_bins": {}, 175 | "cut_overlapping_notes": true, 176 | "use_position_shifts": true, 177 | "onset_position_shifts": true, 178 | "use_onset_indices": true, 179 | "max_notes_in_onset": 12, 180 | "bar_tempos": false, 181 | "onset_tempos": false, 182 | "tempo_window": 8.0, 183 | "tempo_min_onset_dist": 0.5, 184 | "tempo_min_onsets": 8, 185 | "use_quantized_tempos": true, 186 | "decode_recompute_tempos": false, 187 | "limit_rel_onset_devs": true 188 | } 189 | }, 190 | "one_token_stream": true, 191 | "has_bpe": false, 192 | "tokenization": "SPMupleWindow", 193 | "miditok_version": "2.1.6" 194 | } -------------------------------------------------------------------------------- /recipes/default.yaml: -------------------------------------------------------------------------------- 1 | _general_: 2 | device: cuda:0 3 | seed: 23 4 | 5 | _dirname_: results 6 | _label_: ??? 7 | 8 | output_dir: # either a full path or a list of path elements 9 | - ${_general_._dirname_} 10 | - ${model._name_} 11 | - ${model._version_} 12 | - ${date:} 13 | - ${_general_._label_} 14 | 15 | resume_from_checkpoint: 16 | warm_start: false 17 | 18 | 19 | data: 20 | dataset: 21 | _name_: ??? 22 | 23 | collator: 24 | _name_: ??? 25 | 26 | 27 | model: 28 | _name_: ??? 29 | _version_: v0.1 30 | 31 | 32 | evaluator: 33 | _name_: ??? 
34 | 35 | 36 | trainer: 37 | # general 38 | output_dir: ${_general_.output_dir} 39 | 40 | do_train: true 41 | do_eval: true 42 | eval_mode: false 43 | 44 | device: ${_general_.device} 45 | seed: ${_general_.seed} 46 | 47 | # logging 48 | log_dir: logs 49 | log_to_file: true 50 | 51 | dashboard_logger: "tensorboard" 52 | log_strategy: steps 53 | log_steps: 5 54 | log_first_step: true 55 | log_raw_to_console: true 56 | 57 | disable_tqdm: false 58 | progress_steps: 5 59 | progress_metrics: ["loss"] 60 | 61 | ignore_data_skip: false 62 | 63 | # data 64 | num_workers: 4 65 | pin_memory: true 66 | shuffle: true # will shuffle training set only 67 | 68 | # training & evaluation 69 | epochs: 100 70 | max_steps: -1 71 | batch_size: 32 72 | 73 | eval_batch_size: 64 74 | eval_batches: 75 | 76 | eval_strategy: epoch # evaluation and checkpoint steps measure: `epoch` or `steps` 77 | eval_steps: 1 # how frequently (in epochs/steps) model evaluation should be performed 78 | 79 | optimization: 80 | lr: !!float 2e-4 81 | optimizer: adamw 82 | optimizer_params: 83 | weight_decay: !!float 1e-6 84 | lr_scheduler: exponential 85 | lr_scheduler_params: 86 | gamma: 0.995 87 | grad_clip: 2.0 88 | grad_accum_steps: 1 89 | mixed_precision: true 90 | 91 | # checkpointing 92 | save_strategy: epoch 93 | save_steps: 1 94 | save_optimizer: false 95 | save_best_only: true 96 | save_rewrite_checkpoint: false 97 | 98 | metric_for_best_model: loss 99 | metric_maximize: false 100 | 101 | resume_from_checkpoint: ${_general_.resume_from_checkpoint} 102 | warm_start: ${_general_.warm_start} 103 | ignore_layers: [] 104 | ignore_mismatched_keys: true 105 | finetune_layers: [] 106 | restore_lr: true -------------------------------------------------------------------------------- /recipes/scoreperformer/ablation/no_cont_tokens.yaml: -------------------------------------------------------------------------------- 1 | base: scoreperformer/base.yaml 2 | 3 | _general_: 4 | device: cuda:0 5 | seed: 23 6 | 7 | _dirname_: results 8 | _label_: ScorePerformer_noContTE 9 | 10 | 11 | data: 12 | dataset: 13 | root: ??? 14 | 15 | 16 | model: 17 | perf_decoder: 18 | token_embeddings: 19 | discrete: true 20 | continuous: false 21 | continuous_dense: false 22 | discrete_ids: 23 | -------------------------------------------------------------------------------- /recipes/scoreperformer/ablation/no_io_tie.yaml: -------------------------------------------------------------------------------- 1 | base: scoreperformer/base.yaml 2 | 3 | _general_: 4 | device: cuda:0 5 | seed: 23 6 | 7 | _dirname_: results 8 | _label_: ScorePerformer_noIOTie 9 | 10 | 11 | data: 12 | dataset: 13 | root: ??? 14 | 15 | 16 | model: 17 | perf_decoder: 18 | lm_head: 19 | _target_: lm 20 | -------------------------------------------------------------------------------- /recipes/scoreperformer/ablation/no_masked_seq.yaml: -------------------------------------------------------------------------------- 1 | base: scoreperformer/base.yaml 2 | 3 | _general_: 4 | device: cuda:0 5 | seed: 23 6 | 7 | _dirname_: results 8 | _label_: ScorePerformer_noMaskedSeq 9 | 10 | 11 | data: 12 | dataset: 13 | root: ??? 
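# Editor's note: ablation recipes such as this one list only the nodes they
# override; everything else is inherited via `base: scoreperformer/base.yaml`.
# In the `model` blocks, `_target_:` selects a registered implementation and
# `_disable_: true` switches a sub-module off entirely (see the model override
# below and the other ablation recipes for the same convention).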
14 | 15 | 16 | model: 17 | perf_decoder: 18 | token_embeddings: 19 | _target_: simple 20 | multiseq_mode: 21 | _disable_: true 22 | -------------------------------------------------------------------------------- /recipes/scoreperformer/ablation/no_saln.yaml: -------------------------------------------------------------------------------- 1 | base: scoreperformer/base.yaml 2 | 3 | _general_: 4 | device: cuda:0 5 | seed: 23 6 | 7 | _dirname_: results 8 | _label_: ScorePerformer_noSALN 9 | 10 | 11 | data: 12 | dataset: 13 | root: ??? 14 | 15 | 16 | model: 17 | perf_decoder: 18 | style_emb_mode: cat 19 | -------------------------------------------------------------------------------- /recipes/scoreperformer/ablation/no_score_enc.yaml: -------------------------------------------------------------------------------- 1 | base: scoreperformer/base.yaml 2 | 3 | _general_: 4 | device: cuda:0 5 | seed: 23 6 | 7 | _dirname_: results 8 | _label_: ScorePerformer_noScoreEnc 9 | 10 | 11 | data: 12 | dataset: 13 | root: ??? 14 | 15 | 16 | model: 17 | score_encoder: 18 | _disable_: true 19 | -------------------------------------------------------------------------------- /recipes/scoreperformer/base.yaml: -------------------------------------------------------------------------------- 1 | base: default.yaml 2 | 3 | _general_: 4 | device: cuda:0 5 | seed: 23 6 | 7 | _dirname_: results 8 | _label_: ScorePerformer_base 9 | 10 | resume_from_checkpoint: 11 | warm_start: false 12 | 13 | 14 | data: 15 | dataset: 16 | _name_: LocalScorePerformanceDataset 17 | _splits_: 18 | train: train 19 | eval: eval 20 | 21 | root: ??? 22 | use_alignments: false 23 | auxiliary_data_keys: ["bars", "initial_tempos"] 24 | 25 | performance_directions: ??? # data/directions/direction_classes.json 26 | score_directions_dict: ??? 
27 | 28 | max_seq_len: 256 29 | max_bar: 256 30 | bar_sliding_window: 16 31 | 32 | sample_bars: true 33 | sample_note_shift: 0.5 34 | force_max_seq_len: 0.5 35 | 36 | fit_to_max_bar: false # fit bar tokens to max bar (to avoid range errors during training) 37 | fit_to_zero_bar: true 38 | sample_bar_offset: false 39 | 40 | add_sos_eos: true # SOS/EOS tokens for whole sequences, not the sampled subsequences 41 | 42 | sample: true 43 | seed: ${_general_.seed} 44 | 45 | augment_performance: true 46 | pitch_shift_range: [-3, 3] 47 | velocity_shift_range: [-12, 12] # in velocity indices 48 | tempo_shift_range: [0, 0] # in tempo bins (2 ** (1 / 24)) 49 | 50 | noisy_performance: false # compute performance style of the noisy performance 51 | noise_strength: 0.5 # multiplier for noisy performance augmentations 52 | noisy_random_bars: 0.5 # probability of sampling random bars 53 | 54 | deadpan_performance: 0.25 # probability of using deadpan performance 55 | 56 | zero_out_silent_durations: true # set durations of unperformed notes to 0 57 | delete_silent_notes: true 58 | 59 | preload: true 60 | cache: true 61 | 62 | collator: 63 | _name_: MixedLMScorePerformanceCollator 64 | mask_ignore_token_ids: [0, 1, 2, 3] # PAD, MASK, SOS, EOS 65 | mask_ignore_token_dims: [0, 1, 2, 4, 6, 7, 8, 9] # _nT: bar, position, pitch, duration, time_sig, pos_shift, notes_in_onset, pos_in_onset 66 | 67 | 68 | model: 69 | _name_: ScorePerformer 70 | _version_: v0.4.4 71 | 72 | dim: 256 73 | tie_token_emb: true 74 | mode: mixlm 75 | 76 | score_encoder: 77 | _disable_: false 78 | 79 | token_embeddings: 80 | _target_: simple 81 | emb_dims: ${model.perf_decoder.token_embeddings.emb_dims} 82 | mode: ${model.perf_decoder.token_embeddings.mode} 83 | emb_norm: ${model.perf_decoder.token_embeddings.emb_norm} 84 | discrete: ${model.perf_decoder.token_embeddings.discrete} 85 | continuous: ${model.perf_decoder.token_embeddings.continuous} 86 | continuous_dense: ${model.perf_decoder.token_embeddings.continuous_dense} 87 | discrete_ids: ${model.perf_decoder.token_embeddings.discrete_ids} 88 | tie_keys: 89 | 90 | emb_norm: ${model.perf_decoder.emb_norm} 91 | emb_dropout: ${model.perf_decoder.emb_dropout} 92 | use_abs_pos_emb: ${model.perf_decoder.use_abs_pos_emb} 93 | 94 | transformer: 95 | _target_: encoder 96 | 97 | depth: 2 98 | heads: 4 99 | 100 | attention: ${model.perf_decoder.transformer.attention} 101 | feed_forward: ${model.perf_decoder.transformer.feed_forward} 102 | 103 | perf_encoder: 104 | token_embeddings: 105 | _target_: simple 106 | emb_dims: ${model.perf_decoder.token_embeddings.emb_dims} 107 | mode: ${model.perf_decoder.token_embeddings.mode} 108 | emb_norm: ${model.perf_decoder.token_embeddings.emb_norm} 109 | discrete: ${model.perf_decoder.token_embeddings.discrete} 110 | continuous: ${model.perf_decoder.token_embeddings.continuous} 111 | continuous_dense: ${model.perf_decoder.token_embeddings.continuous_dense} 112 | discrete_ids: ${model.perf_decoder.token_embeddings.discrete_ids} 113 | tie_keys: 114 | 115 | emb_norm: true 116 | emb_dropout: 0 117 | use_abs_pos_emb: false 118 | 119 | latent_dim: [32, 20, 8, 4] 120 | aggregate_mode: [mean, bar_mean, beat_mean, onset_mean] 121 | latent_dropout: [0.0, 0.1, 0.2, 0.4] 122 | 123 | hierarchical: true 124 | inclusive_latent_dropout: true 125 | deadpan_zero_latent: true 126 | loss_weight: 1. 
127 | 128 | transformer: 129 | _target_: encoder 130 | 131 | depth: 4 132 | heads: 4 133 | 134 | attention: ${model.perf_decoder.transformer.attention} 135 | feed_forward: ${model.perf_decoder.transformer.feed_forward} 136 | 137 | perf_decoder: 138 | token_embeddings: 139 | _target_: multi-seq 140 | emb_dims: 128 141 | mode: cat 142 | emb_norm: true 143 | discrete: false 144 | continuous: true 145 | continuous_dense: true 146 | discrete_ids: [0, 1, 2, 3] 147 | tie_keys: 148 | multiseq_mode: post-cat 149 | 150 | emb_norm: true 151 | emb_dropout: 0 152 | use_abs_pos_emb: false 153 | 154 | context_emb_mode: cat # (`cat` or `attention`) 155 | style_emb_dim: ${model.perf_encoder.latent_dim} 156 | style_emb_mode: adanorm # (`cat` or `adanorm`) 157 | 158 | transformer: 159 | _target_: decoder 160 | 161 | depth: 4 162 | heads: 4 163 | 164 | attention: 165 | dim_head: 64 166 | one_kv_head: true 167 | dropout: 0.1 168 | 169 | alibi_pos_bias: true 170 | alibi_learned: true 171 | 172 | feed_forward: 173 | mult: 4 174 | glu: true 175 | swish: true 176 | dropout: 0.1 177 | 178 | lm_head: 179 | _target_: lm-tied 180 | 181 | regression_head: 182 | _disable_: true 183 | regression_keys: ["Velocity", "Tempo", "RelOnsetDev", "RelPerfDuration"] 184 | 185 | classifiers: 186 | _disable_: false 187 | classifier: 188 | hidden_dims: [] 189 | dropout: 0.2 190 | loss_weight: 1. 191 | weighted_classes: true 192 | detach_inputs: true 193 | 194 | 195 | evaluator: 196 | _name_: ScorePerformerEvaluator 197 | ignore_keys: ["Bar", "Position", "Pitch", "Duration", "TimeSig", "PositionShift", "NotesInOnset", "PositionInOnset"] 198 | weighted_distance: true 199 | 200 | 201 | trainer: 202 | epochs: 1000 203 | batch_size: 128 204 | eval_batch_size: 128 205 | 206 | progress_metrics: ["loss", "accuracy"] 207 | 208 | save_rewrite_checkpoint: true 209 | 210 | metric_for_best_model: accuracy 211 | metric_maximize: true -------------------------------------------------------------------------------- /recipes/scoreperformer/custom_hierarchy.yaml: -------------------------------------------------------------------------------- 1 | base: scoreperformer/base.yaml 2 | 3 | _general_: 4 | device: cuda:0 5 | seed: 23 6 | 7 | _dirname_: results 8 | _label_: ScorePerformer_custom 9 | 10 | 11 | data: 12 | dataset: 13 | root: ??? 14 | 15 | 16 | model: 17 | perf_encoder: 18 | # z=64: base 19 | latent_dim: [32, 20, 8, 4] 20 | aggregate_mode: [mean, bar_mean, beat_mean, onset_mean] 21 | latent_dropout: [0.0, 0.1, 0.2, 0.4] 22 | 23 | # # z=64: -onset 24 | # latent_dim: [32, 20, 12] 25 | # aggregate_mode: [mean, bar_mean, beat_mean] 26 | # latent_dropout: [0.0, 0.1, 0.2] 27 | 28 | # # z=64: -onset, -beat 29 | # latent_dim: [32, 32] 30 | # aggregate_mode: [mean, bar_mean] 31 | # latent_dropout: [0.0, 0.1] 32 | 33 | # # z=64: -onset, -beat, -bar 34 | # latent_dim: [64] 35 | # aggregate_mode: [mean] 36 | # latent_dropout: [0.0] 37 | -------------------------------------------------------------------------------- /recipes/scoreperformer/minimal.yaml: -------------------------------------------------------------------------------- 1 | base: scoreperformer/base.yaml 2 | 3 | _general_: 4 | device: cuda:0 5 | seed: 23 6 | 7 | _dirname_: results 8 | _label_: ScorePerformer_base 9 | 10 | 11 | data: 12 | dataset: 13 | root: ??? 14 | 15 | performance_directions: ??? # data/directions/direction_classes.json 16 | score_directions_dict: ??? 
17 | -------------------------------------------------------------------------------- /recipes/scoreperformer/no_classifiers.yaml: -------------------------------------------------------------------------------- 1 | base: scoreperformer/base.yaml 2 | 3 | _general_: 4 | device: cuda:0 5 | seed: 23 6 | 7 | _dirname_: results 8 | _label_: ScorePerformer_noClf 9 | 10 | 11 | data: 12 | dataset: 13 | root: ??? 14 | 15 | # leave empty 16 | performance_directions: 17 | score_directions_dict: 18 | 19 | model: 20 | classifiers: 21 | _disable_: true 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.23,<1.24 2 | torch>=2.0.1 3 | miditok==2.1.6 4 | matplotlib 5 | tensorboard>=2.6.0 6 | omegaconf==2.3.0 7 | loguru==0.7.0 8 | tqdm>=4.63.0 9 | note_seq==0.0.5 10 | einops 11 | git+https://github.com/mac-marg-pianist/musicXML_parser 12 | -------------------------------------------------------------------------------- /scoreperformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/scoreperformer/__init__.py -------------------------------------------------------------------------------- /scoreperformer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .collators import ( 2 | LMPerformanceCollator, 3 | MixedLMPerformanceCollator, 4 | LMScorePerformanceCollator, 5 | MixedLMScorePerformanceCollator, 6 | ) 7 | from .datasets import ( 8 | PerformanceDataset, 9 | LocalScorePerformanceDataset 10 | ) 11 | 12 | DATASETS = {name: cls for name, cls in globals().items() if ".datasets." in str(cls)} 13 | COLLATORS = {name: cls for name, cls in globals().items() if ".collators." in str(cls)} 14 | -------------------------------------------------------------------------------- /scoreperformer/data/collators/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import SeqInputs 2 | from .performance import ( 3 | PerformanceInputs, PerformanceCollator, 4 | LMPerformanceInputs, LMPerformanceCollator, 5 | MixedLMPerformanceInputs, MixedLMPerformanceCollator 6 | ) 7 | from .score_performance import ( 8 | ScorePerformanceInputs, ScorePerformanceCollator, 9 | LMScorePerformanceInputs, LMScorePerformanceCollator, 10 | MixedLMScorePerformanceInputs, MixedLMScorePerformanceCollator 11 | ) 12 | -------------------------------------------------------------------------------- /scoreperformer/data/collators/common.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | @dataclass 9 | class SeqInputs: 10 | tokens: Union[np.ndarray, torch.Tensor] 11 | mask: Union[np.ndarray, torch.Tensor] 12 | lengths: Union[np.ndarray, torch.Tensor] 13 | -------------------------------------------------------------------------------- /scoreperformer/data/collators/directions.py: -------------------------------------------------------------------------------- 1 | """ Performance direction embeddings collators. 
""" 2 | 3 | import torch 4 | 5 | 6 | class DirectionEmbeddingCollator: 7 | def __init__( 8 | self, 9 | num_embeddings: int = 1, 10 | embedding_dim: int = 64, 11 | ): 12 | self.num_embeddings = num_embeddings 13 | self.embedding_dim = embedding_dim 14 | 15 | def init_data(self, batch): 16 | embeddings = torch.zeros(len(batch), self.num_embeddings, self.embedding_dim) 17 | labels = torch.zeros(len(batch), dtype=torch.long) 18 | return embeddings, labels 19 | 20 | def process_sample(self, i, sample, data): 21 | embeddings, labels = data 22 | _, emb, label = sample 23 | 24 | emb = emb.unsqueeze(0) if emb.ndim == 1 else emb 25 | embeddings[i, -emb.shape[0]:] = emb 26 | labels[i] = label 27 | 28 | def __call__(self, batch, inference=False, return_tensors=True): 29 | data = self.init_data(batch) 30 | for i, sample in enumerate(batch): 31 | self.process_sample(i, sample, data) 32 | 33 | return {'embeddings': data[0], 'labels': data[1]} 34 | -------------------------------------------------------------------------------- /scoreperformer/data/collators/score_performance.py: -------------------------------------------------------------------------------- 1 | """ Score-Performance data collators. """ 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional, List, Union, Dict 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from .common import SeqInputs 10 | from .performance import ( 11 | PerformanceCollator, 12 | LMPerformanceCollator, 13 | MixedLMPerformanceCollator 14 | ) 15 | from ..datasets.score_performance import ScorePerformanceSample 16 | 17 | 18 | @dataclass 19 | class SeqSegments: 20 | bar: Optional[Union[np.ndarray, torch.Tensor]] = None 21 | beat: Optional[Union[np.ndarray, torch.Tensor]] = None 22 | onset: Optional[Union[np.ndarray, torch.Tensor]] = None 23 | 24 | 25 | @dataclass 26 | class ScorePerformanceInputs: 27 | scores: SeqInputs 28 | performances: SeqInputs 29 | noisy_performances: Optional[SeqInputs] = None 30 | segments: Optional[SeqSegments] = None 31 | directions: Optional[Union[Dict[str, torch.Tensor], torch.Tensor]] = None 32 | deadpan_mask: Optional[torch.Tensor] = None 33 | 34 | 35 | class ScorePerformanceCollator(PerformanceCollator): 36 | def __init__( 37 | self, 38 | pad_token_id: int = 0, 39 | pad_to_multiple_of: int = 1 40 | ): 41 | super().__init__(pad_token_id, pad_to_multiple_of) 42 | 43 | def get_max_lengths(self, batch: List[ScorePerformanceSample], inference: bool = False): 44 | max_lens = super().get_max_lengths(batch, inference=inference) 45 | 46 | lens_score = np.array(list(map(lambda sample: len(sample.score), batch))).T 47 | max_lens['score'] = self.pad_len(np.max(lens_score)) 48 | 49 | if all((sample.noisy_perf is not None for sample in batch)): 50 | lens_noisy_perf = np.array(list(map(lambda sample: len(sample.noisy_perf), batch))).T 51 | max_lens['noisy_perf'] = self.pad_len(np.max(lens_noisy_perf)) 52 | 53 | return max_lens 54 | 55 | def init_data(self, batch: List[ScorePerformanceSample], inference: bool = False): 56 | data = super().init_data(batch, inference=inference) 57 | max_lens = self.get_max_lengths(batch, inference=inference) 58 | 59 | sample, batch_size = batch[0], len(batch) 60 | return ScorePerformanceInputs( 61 | scores=self._init_seq_data( 62 | batch_size, max_lens['score'], 63 | compound_factor=sample.score.shape[-1] 64 | ), 65 | performances=data.performances, 66 | noisy_performances=self._init_seq_data( 67 | batch_size, max_lens['noisy_perf'], 68 | compound_factor=sample.noisy_perf.shape[-1] 69 | ) if 'noisy_perf' 
in max_lens else None, 70 | segments=SeqSegments( 71 | bar=torch.zeros(len(batch), max_lens['score'], dtype=torch.long), 72 | beat=torch.zeros(len(batch), max_lens['score'], dtype=torch.long), 73 | onset=torch.zeros(len(batch), max_lens['score'], dtype=torch.long) 74 | ) if sample.segments is not None else None, 75 | directions=torch.zeros( 76 | batch_size, max_lens['score'], len(sample.directions), dtype=torch.long 77 | ) if sample.directions is not None else None, 78 | deadpan_mask=torch.zeros(batch_size, dtype=torch.bool) 79 | ) 80 | 81 | def process_sample(self, i: int, sample: ScorePerformanceSample, data: ScorePerformanceInputs, 82 | inference: bool = False): 83 | # process performance 84 | super().process_sample(i, sample, data, inference=inference) 85 | 86 | # process score 87 | self._process_sequence(i, seq=sample.score, seq_data=data.scores) 88 | 89 | # process noisy performance if present 90 | if sample.noisy_perf is not None: 91 | self._process_sequence(i, seq=sample.noisy_perf, seq_data=data.noisy_performances) 92 | 93 | # process note segments if present 94 | seq_len = len(sample.score) 95 | if sample.segments is not None: 96 | data.segments.bar[i, :seq_len] = torch.from_numpy(sample.segments.bar) 97 | data.segments.beat[i, :seq_len] = torch.from_numpy(sample.segments.beat) 98 | data.segments.onset[i, :seq_len] = torch.from_numpy(sample.segments.onset) 99 | 100 | # process directions if present 101 | if sample.directions is not None: 102 | for j, (group_name, group_directions) in enumerate(sample.directions.items()): 103 | for (label, key), direction_map in group_directions.items(): 104 | mask = direction_map != 0. 105 | if np.any(mask): 106 | data.directions[i, :seq_len, j][mask] = label * torch.from_numpy(direction_map[mask]) 107 | 108 | data.deadpan_mask[i] = sample.is_deadpan 109 | 110 | def __call__(self, batch: List[ScorePerformanceSample], inference: bool = False, return_tensors: bool = True): 111 | data = self.init_data(batch, inference=inference) 112 | for i, sample in enumerate(batch): 113 | self.process_sample(i, sample, data) 114 | 115 | return data 116 | 117 | 118 | # FOR LANGUAGE MODELING 119 | @dataclass 120 | class LMScorePerformanceInputs(ScorePerformanceInputs): 121 | labels: Optional[SeqInputs] = None 122 | 123 | 124 | class LMScorePerformanceCollator(ScorePerformanceCollator, LMPerformanceCollator): 125 | def __init__( 126 | self, 127 | pad_token_id: int = 0, 128 | pad_to_multiple_of: int = 1, 129 | 130 | mlm: bool = False, 131 | mask_prob: float = 0.15, 132 | replace_prob: float = 0.9, 133 | mask_token_id: int = 1, 134 | mask_ignore_token_ids: Optional[List[int]] = None, 135 | mask_ignore_token_dims: Optional[List[int]] = None, 136 | label_pad_ignored_dims: bool = True, 137 | label_pad_token_id: int = -100 138 | ): 139 | LMPerformanceCollator.__init__( 140 | self, 141 | pad_token_id=pad_token_id, 142 | pad_to_multiple_of=pad_to_multiple_of, 143 | mlm=mlm, 144 | mask_prob=mask_prob, 145 | replace_prob=replace_prob, 146 | mask_token_id=mask_token_id, 147 | mask_ignore_token_ids=mask_ignore_token_ids, 148 | mask_ignore_token_dims=mask_ignore_token_dims, 149 | label_pad_ignored_dims=label_pad_ignored_dims, 150 | label_pad_token_id=label_pad_token_id 151 | ) 152 | 153 | def __call__(self, batch: List[ScorePerformanceSample], inference: bool = False, return_tensors: bool = True): 154 | data = super().__call__(batch, inference=inference) 155 | 156 | if self.mlm: 157 | masked_seq, labels, label_mask = self.mask_sequence(data.performances.tokens) 158 | 
data.performances.tokens = masked_seq 159 | else: 160 | labels = data.performances.tokens.clone().detach() 161 | labels[labels == self.pad_token_id] = self.label_pad_token_id 162 | label_mask = data.performances.mask.clone().detach() 163 | 164 | data = LMScorePerformanceInputs( 165 | scores=data.scores, 166 | performances=data.performances, 167 | noisy_performances=data.noisy_performances, 168 | segments=data.segments, 169 | directions=data.directions, 170 | deadpan_mask=data.deadpan_mask, 171 | labels=SeqInputs( 172 | tokens=labels, 173 | mask=label_mask, 174 | lengths=data.performances.lengths 175 | ) 176 | ) 177 | 178 | return data 179 | 180 | 181 | @dataclass 182 | class MixedLMScorePerformanceInputs(LMScorePerformanceInputs): 183 | masked_performances: Optional[SeqInputs] = None 184 | 185 | 186 | class MixedLMScorePerformanceCollator(ScorePerformanceCollator, MixedLMPerformanceCollator): 187 | def __init__( 188 | self, 189 | pad_token_id: int = 0, 190 | pad_to_multiple_of: int = 1, 191 | 192 | mask_token_id: int = 1, 193 | mask_ignore_token_ids: Optional[List[int]] = None, 194 | mask_ignore_token_dims: Optional[List[int]] = None, 195 | label_pad_ignored_dims: bool = True, 196 | label_pad_token_id: int = -100 197 | ): 198 | MixedLMPerformanceCollator.__init__( 199 | self, 200 | pad_token_id=pad_token_id, 201 | pad_to_multiple_of=pad_to_multiple_of, 202 | mask_token_id=mask_token_id, 203 | mask_ignore_token_ids=mask_ignore_token_ids, 204 | mask_ignore_token_dims=mask_ignore_token_dims, 205 | label_pad_ignored_dims=label_pad_ignored_dims, 206 | label_pad_token_id=label_pad_token_id 207 | ) 208 | 209 | def __call__(self, batch: List[ScorePerformanceSample], inference: bool = False, return_tensors: bool = True): 210 | data = super().__call__(batch, inference=inference) 211 | 212 | masked_performances, labels = self.mask_sequence(data.performances.tokens) 213 | label_mask = data.performances.mask.clone().detach() 214 | 215 | data = MixedLMScorePerformanceInputs( 216 | scores=data.scores, 217 | performances=data.performances, 218 | noisy_performances=data.noisy_performances, 219 | segments=data.segments, 220 | directions=data.directions, 221 | deadpan_mask=data.deadpan_mask, 222 | masked_performances=SeqInputs( 223 | tokens=masked_performances, 224 | mask=label_mask, 225 | lengths=data.performances.lengths 226 | ), 227 | labels=SeqInputs( 228 | tokens=labels, 229 | mask=label_mask, 230 | lengths=data.performances.lengths 231 | ) 232 | ) 233 | 234 | return data 235 | -------------------------------------------------------------------------------- /scoreperformer/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .directions import DirectionBarEmbeddingDataset 2 | from .performance import ( 3 | PerformanceDataset, 4 | PerformanceSampleMeta, 5 | PerformanceSample 6 | ) 7 | from .score_performance import ( 8 | ScorePerformanceDataset, 9 | LocalScorePerformanceDataset, 10 | ScorePerformanceSampleMeta, 11 | ScorePerformanceSample, 12 | NoteSegments, 13 | SequenceTypes 14 | ) 15 | from .token_sequence import ( 16 | TokenSequenceDataset, 17 | LocalTokenSequenceDataset 18 | ) 19 | -------------------------------------------------------------------------------- /scoreperformer/data/datasets/performance.py: -------------------------------------------------------------------------------- 1 | """ Performance token sequence dataset. 
""" 2 | 3 | import copy 4 | import os 5 | from dataclasses import dataclass 6 | from functools import partial 7 | from typing import Optional, Tuple 8 | 9 | import numpy as np 10 | from torch.utils.data import Dataset 11 | 12 | from scoreperformer.utils import apply, load_json 13 | from .token_sequence import load_token_sequence, LocalTokenSequenceDataset 14 | from .utils import load_tokens_np, get_num_bars, compute_bar_sample_positions, get_end_bar 15 | from ..helpers import ( 16 | TOKEN_SEQUENCE_PROCESSORS, 17 | TOKEN_SEQUENCE_INDEXERS, 18 | TokenSequenceAugmentations 19 | ) 20 | from ..tokenizers import TOKENIZERS 21 | 22 | 23 | @dataclass 24 | class PerformanceSampleMeta: 25 | idx: Optional[int] 26 | perf_idx: int 27 | start_bar: int 28 | end_bar: Optional[int] 29 | bar_offset: int = 0 30 | augmentations: Optional[TokenSequenceAugmentations] = None 31 | 32 | 33 | @dataclass 34 | class PerformanceSample: 35 | perf: np.ndarray 36 | meta: PerformanceSampleMeta 37 | 38 | 39 | class PerformanceDataset(Dataset): 40 | def __init__( 41 | self, 42 | root: str, 43 | split: str = 'train', 44 | encoding: str = 'OctupleM', 45 | 46 | max_seq_len: int = 512, 47 | max_bar: int = 256, 48 | bar_sliding_window: int = 16, 49 | 50 | fit_to_max_bar: bool = False, 51 | fit_to_zero_bar: bool = False, 52 | sample_bars: bool = False, 53 | 54 | add_sos_eos: bool = False, 55 | 56 | sample: bool = False, 57 | seed: int = 23, 58 | 59 | augment_performance: bool = False, 60 | pitch_shift_range: Tuple[int, int] = (-3, 3), 61 | velocity_shift_range: Tuple[int, int] = (-2, 2), 62 | tempo_shift_range: Tuple[int, int] = (-2, 2), 63 | 64 | cache: bool = True, 65 | **kwargs 66 | ): 67 | 68 | self.root = root 69 | self.split = split 70 | 71 | # load metadata 72 | metadata_file = os.path.join(self.root, 'metadata.json') 73 | metadata = load_json(metadata_file) 74 | 75 | if all(key not in metadata for key in ['all', 'train', 'eval', 'val', 'test']): 76 | self.metadata = metadata # metadata of names 77 | else: 78 | self.metadata = metadata[self.split] 79 | 80 | self.performance_names = list(self.metadata) 81 | 82 | # load tokenizer 83 | if encoding not in TOKENIZERS: 84 | raise ValueError(f"Encoding {encoding} is not a valid encoding, " 85 | f"supported types are: {list(TOKENIZERS.keys())}") 86 | 87 | self.encoding = encoding 88 | self.tokenizer = TOKENIZERS[encoding](params=os.path.join(self.root, 'config.json')) 89 | 90 | # load sequences 91 | load_tokens = partial(load_tokens_np, tokenizer=self.tokenizer) 92 | perf_load_fn = partial(load_token_sequence, load_fn=load_tokens) 93 | self.performances = LocalTokenSequenceDataset( 94 | root=self.root, 95 | files=self.performance_names, 96 | load_fn=perf_load_fn, 97 | cache=cache 98 | ) 99 | 100 | # configurations 101 | self.max_seq_len = max_seq_len 102 | self.max_bar = max_bar 103 | self.bar_sliding_window = bar_sliding_window 104 | self.add_sos_eos = add_sos_eos 105 | assert max_bar <= self.tokenizer.max_bar_embedding 106 | 107 | # bar indexer and indices arrays 108 | self.indexer = TOKEN_SEQUENCE_INDEXERS[encoding](self.tokenizer) 109 | self._bar_indices = [None] * len(self.performances) 110 | 111 | # load or compute number of bars in performances used to build samples 112 | bars_file = os.path.join(self.root, 'bars.json') 113 | if os.path.exists(bars_file): 114 | _num_bars = load_json(bars_file) 115 | _perf_num_bars = np.array([_num_bars[perf] for perf in self.performance_names]) 116 | else: 117 | _perf_num_bars = np.array(apply(self.performances, partial(get_num_bars, 
tokenizer=self.tokenizer))) 118 | 119 | # compute sample positions 120 | self._length, self._sample_positions, self._sample_ids = compute_bar_sample_positions( 121 | seq_num_bars=_perf_num_bars, bar_sliding_window=self.bar_sliding_window 122 | ) 123 | 124 | # random effects 125 | self.sample = sample 126 | if self.sample: 127 | np.random.seed(seed) 128 | 129 | # bar sampling 130 | assert not (fit_to_max_bar and fit_to_zero_bar), \ 131 | "Only one of `fit_to_max_bar`/`fit_to_zero_bar` could be set to True" 132 | self.fit_to_max_bar = fit_to_max_bar 133 | self.fit_to_zero_bar = fit_to_zero_bar 134 | self.sample_bars = sample and sample_bars 135 | 136 | # augmentations 137 | self.augment_performance = sample and augment_performance 138 | 139 | if not self.augment_performance: 140 | pitch_shift_range = velocity_shift_range = tempo_shift_range = (0, 0) 141 | 142 | # sequence processor 143 | self.processor = TOKEN_SEQUENCE_PROCESSORS[encoding]( 144 | tokenizer=self.tokenizer, 145 | pitch_shift_range=pitch_shift_range, 146 | velocity_shift_range=velocity_shift_range, 147 | tempo_shift_range=tempo_shift_range, 148 | ) 149 | 150 | def _get_augmentations(self, meta): 151 | if meta is None: 152 | if self.augment_performance: 153 | return self.processor.sample_augmentations() 154 | else: 155 | return None 156 | else: 157 | return meta.augmentations 158 | 159 | def _augment_sequence(self, seq, augmentations): 160 | if augmentations is None: 161 | return seq 162 | 163 | seq = self.processor.augment_sequence(seq, augmentations) 164 | mask = self.processor.compute_valid_pitch_mask(seq) 165 | return seq[mask] 166 | 167 | def get(self, idx: Optional[int] = None, meta: Optional[PerformanceSampleMeta] = None): 168 | assert idx is not None or meta is not None, 'one of `idx`/`meta` should be provided as an argument' 169 | 170 | # get performance 171 | if meta is None: 172 | perf_idx = np.where(idx >= self._sample_ids)[0][-1] 173 | else: 174 | idx, perf_idx = meta.idx, meta.perf_idx 175 | 176 | bar_indices = self._bar_indices[perf_idx] 177 | if bar_indices is None: 178 | bar_indices = self._bar_indices[perf_idx] = self.indexer.compute_bar_indices(self.performances[perf_idx]) 179 | 180 | total_bars = bar_indices.shape[0] - 1 181 | 182 | # compute start bar index 183 | if meta is None: 184 | start_bar = self._sample_positions[idx] 185 | start_bar = min(start_bar, bar_indices.shape[0] - self.bar_sliding_window // 2) # bars of silent notes 186 | if self.sample: 187 | low = max(0, start_bar - self.bar_sliding_window // 2) 188 | high = min(total_bars - self.bar_sliding_window // 4, start_bar + self.bar_sliding_window // 2) 189 | high = max(low + 1, high) 190 | start_bar = np.random.randint(low, high) 191 | else: 192 | start_bar = meta.start_bar 193 | 194 | # compute start index 195 | perf_start = bar_indices[start_bar] 196 | 197 | # compute end bar index 198 | if meta is None or meta.end_bar is None: 199 | end_bar = get_end_bar(bar_indices, start_bar, self.max_seq_len, self.max_bar) 200 | else: 201 | end_bar = meta.end_bar 202 | 203 | # compute end index 204 | perf_end = bar_indices[end_bar + 1] 205 | 206 | # get token sequences 207 | perf_seq = copy.copy(self.performances[perf_idx][perf_start:perf_end]) 208 | 209 | min_bar = perf_seq[:, 0].min() - self.tokenizer.zero_token 210 | max_bar = perf_seq[:, 0].max() - self.tokenizer.zero_token 211 | 212 | # shift bar indices 213 | bar_offset = 0 214 | if meta is None: 215 | if self.fit_to_max_bar: 216 | # make bar indices fall within [0, max_bar)
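# (Editor's annotation of the branches below:)
#  - `sample_bars`: draw a random offset so the sampled window still fits
#    inside [0, max_bar) -- an augmentation over absolute bar positions;
#  - otherwise, shift only when the window overruns `max_bar`, placing it
#    proportionally to its position within the whole piece;
#  - `fit_to_zero_bar` (the elif below) always shifts the first bar to 0.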
217 | if self.sample_bars: 218 | bar_offset = np.random.randint(-min_bar, self.max_bar - max_bar) 219 | elif end_bar >= self.max_bar: 220 | # move in proportion to `score_total_bars` 221 | _end_bar = int((self.max_bar - 1) * max_bar / total_bars) 222 | bar_offset = _end_bar - max_bar 223 | elif self.fit_to_zero_bar: 224 | bar_offset = -min_bar 225 | else: 226 | bar_offset = meta.bar_offset 227 | 228 | if bar_offset != 0: 229 | perf_seq[:, self.tokenizer.vocab_types_idx['Bar']] += bar_offset 230 | 231 | # augmentations 232 | augmentations = self._get_augmentations(meta) 233 | perf_seq = self._augment_sequence(perf_seq, augmentations) 234 | 235 | if self.add_sos_eos: 236 | if start_bar == 0: 237 | perf_seq = self.processor.add_sos_token(perf_seq) 238 | if end_bar + 1 == total_bars: 239 | perf_seq = self.processor.add_eos_token(perf_seq) 240 | 241 | # build sample metadata 242 | meta = PerformanceSampleMeta( 243 | idx=idx, 244 | perf_idx=perf_idx, 245 | start_bar=start_bar, 246 | end_bar=end_bar, 247 | bar_offset=bar_offset, 248 | augmentations=augmentations 249 | ) 250 | 251 | return PerformanceSample( 252 | perf=perf_seq, 253 | meta=meta 254 | ) 255 | 256 | def __getitem__(self, idx: int): 257 | return self.get(idx=idx) 258 | 259 | def __len__(self): 260 | return self._length 261 | -------------------------------------------------------------------------------- /scoreperformer/data/datasets/token_sequence.py: -------------------------------------------------------------------------------- 1 | """ Token sequence datasets. """ 2 | 3 | import os 4 | from pathlib import Path, PurePath 5 | 6 | from torch.utils.data import Dataset 7 | 8 | from scoreperformer.utils import apply, load_json 9 | 10 | 11 | def load_token_sequence(path, load_fn, processing_funcs=None): 12 | seq = load_fn(path) 13 | if processing_funcs: 14 | for func in processing_funcs: 15 | seq = func(seq) 16 | return seq 17 | 18 | 19 | class TokenSequenceDataset(Dataset): 20 | def __init__(self, sequences, names=None): 21 | self.seqs = sequences 22 | 23 | self.names = names 24 | if names is not None: 25 | self._name_to_idx = {name: idx for idx, name in enumerate(self.names)} 26 | 27 | def __getitem__(self, idx): 28 | seq = self.seqs[idx] 29 | return seq[0] if isinstance(seq, tuple) else seq 30 | 31 | def __len__(self): 32 | return len(self.seqs) 33 | 34 | 35 | class LocalTokenSequenceDataset(TokenSequenceDataset): 36 | def __init__(self, root, files=None, suffix='.json', load_fn=load_json, preload=False, cache=False): 37 | self.root = root 38 | self.load_fn = load_fn 39 | 40 | if files is None: 41 | if os.path.isfile(root) and root.lower().endswith(suffix): 42 | files = [Path(root)] 43 | else: 44 | files = list(Path(root).glob('**/*' + suffix)) 45 | files = list(map(Path, files)) 46 | else: 47 | files = list(map(lambda x: Path(x).with_suffix(suffix), files)) 48 | 49 | paths = [PurePath(os.path.join(self.root, file)) for file in files] 50 | 51 | self.paths = paths 52 | 53 | self._cache = cache 54 | 55 | self.seqs = self.load_sequences(preload=preload) 56 | names = [str(file).replace(suffix, '') for file in files] 57 | 58 | super().__init__(sequences=self.seqs, names=names) 59 | 60 | def load_sequence(self, path): 61 | return self.load_fn(path) 62 | 63 | def load_sequences(self, preload): 64 | if preload: 65 | return apply(self.paths, func=self.load_sequence, desc='Loading token sequences...') 66 | else: 67 | return [None] * len(self.paths) 68 | 69 | def __getitem__(self, idx): 70 | if self.seqs[idx] is None: 71 | seq = 
self.load_sequence(self.paths[idx]) 72 | if self._cache: 73 | self.seqs[idx] = seq 74 | else: 75 | seq = self.seqs[idx] 76 | return seq[0] if isinstance(seq, tuple) else seq 77 | 78 | def __len__(self): 79 | return len(self.seqs) 80 | -------------------------------------------------------------------------------- /scoreperformer/data/datasets/utils.py: -------------------------------------------------------------------------------- 1 | """ Common datasets' utils. """ 2 | 3 | import random 4 | from typing import Optional, Dict, List 5 | 6 | import numpy as np 7 | 8 | from ..tokenizers import OctupleM, TokSequence 9 | 10 | 11 | def load_tokens_data(path, tokenizer): 12 | data = tokenizer.load_tokens(path) 13 | if isinstance(data, list): # backward compatibility with old datasets (miditok<=1.2.3) 14 | data = {'ids': data[0], 'programs': data[1] if len(data) > 1 else []} 15 | elif 'ids' not in data: # backward compatibility with old datasets (miditok<=2.0.0) 16 | data['ids'] = data['tokens'] 17 | del data['tokens'] 18 | return data 19 | 20 | 21 | def load_tokens_np(path, tokenizer): 22 | return np.array(load_tokens_data(path, tokenizer)['ids']) 23 | 24 | 25 | def load_token_sequence(path, tokenizer): 26 | data = load_tokens_data(path, tokenizer) 27 | return TokSequence(ids=data['ids'], meta=data.get('meta', {})) 28 | 29 | 30 | def get_num_bars(seq, tokenizer): 31 | if isinstance(tokenizer, OctupleM): 32 | bar_idx = tokenizer.vocab_types_idx['Bar'] 33 | return int(seq[-1, bar_idx] - tokenizer.zero_token + 1) 34 | else: 35 | raise ValueError(f"Unsupported tokenizer: {tokenizer.__class__.__name__}") 36 | 37 | 38 | def compute_bar_sample_positions(seq_num_bars, bar_sliding_window): 39 | bar_shift = bar_sliding_window 40 | length, sample_positions = 0, [] 41 | for num_bars in seq_num_bars: 42 | back_shift = -bar_shift // 4 if (num_bars - bar_shift // 2) % bar_shift == 0 else 0 43 | positions = np.concatenate([ 44 | np.arange(0, num_bars - bar_shift // 2, bar_shift), 45 | np.arange(num_bars - bar_shift // 2 - back_shift, -1 + bar_shift // 2, -bar_shift) 46 | ]) 47 | length += len(positions) 48 | sample_positions.append(positions) 49 | 50 | sample_ids = np.concatenate([[0], np.cumsum(list(map(len, sample_positions)))[:-1]]) 51 | sample_positions = np.concatenate(sample_positions) 52 | 53 | return length, sample_positions, sample_ids 54 | 55 | 56 | def get_end_bar(score_indices, start_bar=0, max_seq_len=512, max_bar=256): 57 | end_bar = np.where(score_indices <= score_indices[start_bar] + max_seq_len)[0][-1] - 1 58 | return min(max(start_bar, end_bar), start_bar + max_bar - 1) 59 | 60 | 61 | def split_composer_metadata( 62 | reference_metadata: Dict[str, List[str]], 63 | splits: Dict[str, float], 64 | seed: Optional[int] = None 65 | ): 66 | if seed is not None: 67 | random.seed(seed) 68 | np.random.seed(seed) 69 | 70 | data_ = {split: dict() for split in splits} 71 | 72 | for comp, score_perf in reference_metadata.items(): 73 | comp_meta_rep = [] 74 | 75 | score_perf = list(score_perf.items()) 76 | np.random.shuffle(score_perf) 77 | score_perf = dict(score_perf) 78 | 79 | for score, perfs in score_perf.items(): 80 | comp_meta_rep.extend([score] * len(perfs)) 81 | 82 | if len(comp_meta_rep) > 10: 83 | start = 0 84 | for i, (split, ratio) in enumerate(splits.items()): 85 | end = min(len(comp_meta_rep), start + round(ratio * len(comp_meta_rep))) 86 | 87 | if i == len(splits) - 1: 88 | end = len(comp_meta_rep) 89 | 90 | if end < len(comp_meta_rep) and comp_meta_rep[end - 1] == 
comp_meta_rep[len(comp_meta_rep) - 1]: 91 | while end > 0 and comp_meta_rep[end] == comp_meta_rep[end - 1]: 92 | end -= 1 93 | else: 94 | while end < len(comp_meta_rep) and comp_meta_rep[end - 1] == comp_meta_rep[end]: 95 | end += 1 96 | 97 | split_scores = np.unique(comp_meta_rep[start:end]).tolist() 98 | for score in split_scores: 99 | data_[split][score] = score_perf[score] 100 | start = end 101 | else: 102 | for score, perfs in score_perf.items(): 103 | _split = np.random.choice(np.array(list(splits.keys())), p=np.array(list(splits.values()))) 104 | data_[_split][score] = perfs 105 | 106 | for _split in data_: 107 | data_[_split] = dict(sorted(data_[_split].items())) 108 | 109 | return data_ 110 | -------------------------------------------------------------------------------- /scoreperformer/data/directions/__init__.py: -------------------------------------------------------------------------------- 1 | from .articulation import ARTICULATION_PREFIX, ARTICULATION_KEYS 2 | from .dynamic import DYNAMIC_PREFIX, DYNAMIC_KEYS, ABS_DYNAMIC_KEYS, REL_DYNAMIC_KEYS 3 | from .parser import parse_directions 4 | from .tempo import TEMPO_PREFIX, TEMPO_KEYS, ABS_TEMPO_KEYS, REL_TEMPO_KEYS, RET_TEMPO_KEYS 5 | from .words import extract_main_keyword 6 | 7 | 8 | def build_prefixed_keys(keys, prefix): 9 | return list(map(lambda d: f'{prefix}/' + extract_main_keyword(d), keys)) 10 | 11 | 12 | DYNAMIC_DIRECTION_KEYS = build_prefixed_keys(DYNAMIC_KEYS, DYNAMIC_PREFIX) 13 | TEMPO_DIRECTION_KEYS = build_prefixed_keys(TEMPO_KEYS, TEMPO_PREFIX) 14 | ARTICULATION_DIRECTION_KEYS = build_prefixed_keys(ARTICULATION_KEYS, ARTICULATION_PREFIX) 15 | -------------------------------------------------------------------------------- /scoreperformer/data/directions/articulation.py: -------------------------------------------------------------------------------- 1 | ARTICULATION_PREFIX = 'articulation' 2 | 3 | ARTICULATION_KEYS = [ 4 | 'arpeggiate', 5 | 'fermata', 6 | 'staccato', 7 | 'tenuto' 8 | ] 9 | -------------------------------------------------------------------------------- /scoreperformer/data/directions/dynamic.py: -------------------------------------------------------------------------------- 1 | DYNAMIC_PREFIX = 'dynamic' 2 | 3 | ABS_DYNAMIC_KEYS = [ 4 | 'pppp', 'ppp', 'pp', 5 | ('p', 'piano'), 6 | 'mp', 'mf', 7 | ('f', 'forte'), 8 | 'ff', 'fff', 'ffff', 9 | 'fp', 'ffp' 10 | ] 11 | 12 | REL_DYNAMIC_KEYS = [ 13 | ('crescendo', 'cresc'), 14 | ('diminuendo', 'dim', 'decresc'), 15 | ('sf', 'fz', 'sfz', 'sffz'), 16 | ('rf', 'rfz') 17 | ] 18 | 19 | DYNAMIC_KEYS = ABS_DYNAMIC_KEYS + REL_DYNAMIC_KEYS 20 | 21 | 22 | def hairpin_word_regularization(word): 23 | if 'decresc' in word: 24 | word = 'diminuendo' 25 | elif 'cresc' in word: 26 | word = 'crescendo' 27 | elif 'dim' in word: 28 | word = 'diminuendo' 29 | return word 30 | -------------------------------------------------------------------------------- /scoreperformer/data/directions/parser.py: -------------------------------------------------------------------------------- 1 | """ MusicXML performance direction data parsing. 
""" 2 | 3 | from musicxml_parser.playable_notes import get_playable_notes 4 | 5 | from .articulation import ARTICULATION_PREFIX 6 | from .dynamic import DYNAMIC_PREFIX, ABS_DYNAMIC_KEYS, REL_DYNAMIC_KEYS, hairpin_word_regularization 7 | from .tempo import TEMPO_PREFIX, TEMPO_KEYS 8 | from .words import extract_direction_by_keys, word_regularization 9 | 10 | 11 | def get_part_directions(part): 12 | directions = [] 13 | for measure_idx, measure in enumerate(part.measures): 14 | for direction in measure.directions: 15 | direction.type['measure'] = measure_idx 16 | directions.append(direction) 17 | 18 | directions.sort(key=lambda x: x.xml_position) 19 | cleaned_direction = [] 20 | for i, d in enumerate(directions): 21 | if d.type is None: 22 | continue 23 | 24 | if d.type['type'] == 'none': 25 | for j in range(i): 26 | prev_dir = directions[i - j - 1] 27 | if 'number' in prev_dir.type.keys(): 28 | prev_key = prev_dir.type['type'] 29 | prev_num = prev_dir.type['number'] 30 | else: 31 | continue 32 | if prev_num == d.type['number']: 33 | if prev_key == "crescendo": 34 | d.type['type'] = 'crescendo' 35 | break 36 | elif prev_key == "diminuendo": 37 | d.type['type'] = 'diminuendo' 38 | break 39 | cleaned_direction.append(d) 40 | 41 | return cleaned_direction 42 | 43 | 44 | def get_directions(doc): 45 | return [get_part_directions(part) for part in doc.parts] 46 | 47 | 48 | def parse_directions(doc, score_directions=None, delete_unmatched=False, delete_duplicates=False, ticks_scale=1.): 49 | score_directions_init = get_directions(doc) if score_directions is None else score_directions 50 | 51 | last_note = doc.parts[-1].measures[-1].notes[-1].note_duration 52 | max_xml_position = max(doc._state.xml_position, last_note.xml_position + last_note.duration) 53 | 54 | # anacrusis 55 | measure_pos = doc.get_measure_positions() 56 | xml_shift = max(0, measure_pos[2] - 2 * measure_pos[1] + measure_pos[0]) 57 | 58 | score_directions = [] 59 | for part_idx, part_directions_init in enumerate(score_directions_init): 60 | active_dynamic = None 61 | active_tempo = None 62 | active_hairpins = {} 63 | part_directions = [] 64 | for d_idx, d in enumerate(part_directions_init): 65 | d_data, d_dict = d.type, None 66 | if d_data['type'] == 'dynamic': 67 | d_dict = { 68 | 'type': d_data['type'], 69 | 'start': d.xml_position, 70 | 'end': max_xml_position 71 | } 72 | abs_dynamic = extract_direction_by_keys(d_data['content'], ABS_DYNAMIC_KEYS) 73 | rel_dynamic = extract_direction_by_keys(d_data['content'], REL_DYNAMIC_KEYS) 74 | 75 | if abs_dynamic is not None: 76 | d_dict['type'] += '/' + abs_dynamic 77 | if active_dynamic is not None: 78 | active_dynamic['end'] = d.xml_position 79 | active_dynamic = d_dict 80 | elif rel_dynamic is not None: 81 | d_dict['type'] += '/' + rel_dynamic 82 | d_dict['end'] = d_dict['start'] 83 | else: 84 | continue 85 | elif d_data['type'] in ('crescendo', 'diminuendo'): 86 | key = f'{d_data["type"]}_{d_data["number"]}' 87 | if d_data['content'] == 'start': 88 | active_hairpins[key] = d 89 | elif d_data['content'] == 'stop': 90 | start_d = active_hairpins.pop(key, None) 91 | if not start_d: 92 | continue 93 | d_dict = { 94 | 'type': 'dynamic' + '/' + d_data['type'], 95 | 'start': start_d.xml_position, 96 | 'end': d.xml_position 97 | } 98 | elif d_data['type'] == 'words': 99 | word = word_regularization(d_data['content']) 100 | word = hairpin_word_regularization(word) 101 | tempo_word = extract_direction_by_keys(word, TEMPO_KEYS) 102 | 103 | if word in ('crescendo', 'diminuendo'): 104 | d_dict = 
{'type': DYNAMIC_PREFIX}
105 |                 elif tempo_word is not None:
106 |                     word = tempo_word
107 |                     d_dict = {'type': TEMPO_PREFIX}
108 |                     if active_tempo is not None:
109 |                         active_tempo['end'] = d.xml_position
110 |                     active_tempo = d_dict
111 |                 elif delete_unmatched:
112 |                     continue
113 |                 else:
114 |                     d_dict = {'type': d_data['type']}
115 | 
116 |                 d_dict['type'] += '/' + word
117 |                 d_dict.update(**{
118 |                     'start': d.xml_position,
119 |                     'end': max_xml_position if d_dict['type'].startswith(TEMPO_PREFIX) else d.xml_position
120 |                 })
121 |             else:
122 |                 d_dict = None
123 | 
124 |             if d_dict is not None:
125 |                 d_dict.update(**{
126 |                     'part': part_idx,
127 |                     'staff': int(d.staff) if d.staff is not None else 1
128 |                 })
129 |                 part_directions.append(d_dict)
130 | 
131 |         # parse note articulations
132 |         def _build_note_articulation_dict(note, content):
133 |             return {
134 |                 'type': ARTICULATION_PREFIX + '/' + content,
135 |                 'start': note.note_duration.xml_position,
136 |                 'end': note.note_duration.xml_position + note.note_duration.duration,
137 |                 'pitch': note.pitch[1],
138 |                 'part': part_idx,
139 |                 'staff': int(note.staff) if note.staff is not None else 1
140 |             }
141 | 
142 |         part_notes, _ = get_playable_notes(doc.parts[part_idx])
143 |         for note in part_notes:
144 |             if note.note_notations.is_arpeggiate:
145 |                 part_directions.append(_build_note_articulation_dict(note, 'arpeggiate'))
146 |             if note.note_notations.is_fermata:
147 |                 part_directions.append(_build_note_articulation_dict(note, 'fermata'))
148 |             if note.note_notations.is_staccato:
149 |                 part_directions.append(_build_note_articulation_dict(note, 'staccato'))
150 |             if note.note_notations.is_tenuto:
151 |                 part_directions.append(_build_note_articulation_dict(note, 'tenuto'))
152 | 
153 |         # scale xml positions if needed
154 |         if xml_shift != 0 or ticks_scale != 1.:
155 |             for d_dict in part_directions:
156 |                 d_dict['start'] = int(ticks_scale * (d_dict['start'] + xml_shift))
157 |                 d_dict['end'] = int(ticks_scale * (d_dict['end'] + xml_shift))
158 | 
159 |         # sort directions
160 |         part_directions = list(sorted(part_directions, key=lambda d: (d['start'], d['type'], d['end'])))
161 | 
162 |         if delete_duplicates:
163 |             i = 0
164 |             while i < len(part_directions) - 1:
165 |                 d_dict, next_d_dict = part_directions[i], part_directions[i + 1]
166 |                 if d_dict['type'] == next_d_dict['type'] and d_dict['start'] == next_d_dict['start']:
167 |                     del part_directions[i + 1]
168 |                     continue
169 |                 i += 1
170 | 
171 |         score_directions.append(part_directions)
172 | 
173 |     return score_directions
174 | 
--------------------------------------------------------------------------------
/scoreperformer/data/directions/tempo.py:
--------------------------------------------------------------------------------
 1 | TEMPO_PREFIX = 'tempo'
 2 | 
 3 | ABS_TEMPO_KEYS = [
 4 |     'grave', 'largo', 'larghetto', 'lento',
 5 |     'adagio', 'andante', 'andantino', 'moderato',
 6 |     'allegretto', 'allegro', 'vivace',
 7 |     'presto', 'prestissimo'
 8 | ]
 9 | 
10 | REL_TEMPO_KEYS = [
11 |     ('accelerando', 'acc', 'accel'),
12 |     ('ritardando', 'rit', 'ritard'),
13 |     ('rallentando', 'rall'),
14 |     ('stringendo', 'string'),
15 |     'calando', 'più mosso', 'animato', 'stretto', 'smorzando', 'ritenuto'
16 | ]
17 | 
18 | RET_TEMPO_KEYS = [
19 |     ('tempo primo', 'tempo i'),
20 |     'a tempo',
21 | ]
22 | 
23 | TEMPO_KEYS = ABS_TEMPO_KEYS + REL_TEMPO_KEYS + RET_TEMPO_KEYS
24 | 
--------------------------------------------------------------------------------
/scoreperformer/data/directions/words.py:
--------------------------------------------------------------------------------
 1 | """ Utility functions for word-based direction markings. """
 2 | 
 3 | PUNCTUATION = [',', '.', '\n', '(', ')']
 4 | 
 5 | 
 6 | def word_regularization(word):
 7 |     if word:
 8 |         for symbol in PUNCTUATION:
 9 |             word = word.replace(symbol, ' ')
10 |         word = word.replace('  ', ' ')
11 |         return word.strip().lower()
12 |     else:
13 |         return None
14 | 
15 | 
16 | def extract_main_keyword(key):
17 |     if isinstance(key, tuple):
18 |         return key[0]
19 |     return key
20 | 
21 | 
22 | def extract_direction_by_keys(dir_word, keywords):
23 |     for key in keywords:
24 |         if isinstance(key, tuple) and dir_word in key:
25 |             return key[0]
26 |         elif dir_word == key:
27 |             return key
28 |     return None
29 | 
30 | 
31 | def extract_all_directions_by_keys(dir_word, keywords):
32 |     directions = []
33 |     for key in keywords:
34 |         if isinstance(key, tuple) and dir_word in key:
35 |             directions.append(key[0])
36 |         elif dir_word == key:
37 |             directions.append(key)
38 |     return directions
39 | 
40 | 
41 | def check_direction_by_keywords(dir_word, keywords):
42 |     dir_word = word_regularization(dir_word)
43 |     if dir_word in keywords:
44 |         return True
45 |     else:
46 |         word_split = dir_word.split(' ')
47 |         for w in word_split:
48 |             if w in keywords:
49 |                 return True
50 | 
51 |         for key in keywords:  # words like 'sempre più mosso'
52 |             if len(key) > 2 and key in dir_word:
53 |                 return True
54 | 
55 |     return False
56 | 
--------------------------------------------------------------------------------
/scoreperformer/data/helpers/__init__.py:
--------------------------------------------------------------------------------
 1 | from .indexers import TupleTokenSequenceIndexer
 2 | from .processors import TokenSequenceAugmentations, TupleTokenSequenceProcessor
 3 | from ..tokenizers import TokenizerTypes
 4 | 
 5 | TOKEN_SEQUENCE_PROCESSORS = {
 6 |     TokenizerTypes.OctupleM: TupleTokenSequenceProcessor,
 7 |     TokenizerTypes.SPMuple: TupleTokenSequenceProcessor
 8 | }
 9 | 
10 | TOKEN_SEQUENCE_INDEXERS = {
11 |     TokenizerTypes.OctupleM: TupleTokenSequenceIndexer,
12 |     TokenizerTypes.SPMuple: TupleTokenSequenceIndexer
13 | }
14 | 
--------------------------------------------------------------------------------
/scoreperformer/data/helpers/indexers.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | 
 4 | class TokenSequenceIndexer:
 5 |     def __init__(self, tokenizer):
 6 |         self.tokenizer = tokenizer
 7 | 
 8 |     def compute_bar_indices(self, seq: np.ndarray) -> np.ndarray:
 9 |         ...
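# For tuple token sequences, `compute_bar_indices` returns an array of length
# `total_bars + 1` such that the tokens of bar `b` occupy
# `seq[bar_indices[b]:bar_indices[b + 1]]`; indices of empty bars are
# backward-filled from the next bar, yielding zero-length slices.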
10 | 11 | 12 | class TupleTokenSequenceIndexer(TokenSequenceIndexer): 13 | def __init__(self, tokenizer): 14 | super().__init__(tokenizer) 15 | 16 | def compute_bar_indices(self, seq: np.ndarray) -> np.ndarray: 17 | bar_index = self.tokenizer.vocab_types_idx['Bar'] 18 | 19 | min_bar = seq[0, bar_index] - self.tokenizer.zero_token 20 | total_bars = seq[-1, bar_index] - self.tokenizer.zero_token + 1 21 | 22 | bar_diff = np.concatenate([[min_bar], np.diff(seq[:, bar_index])]) 23 | bar_changes = np.where(bar_diff > 0)[0] 24 | 25 | bars = np.concatenate([[0], np.cumsum(bar_diff[bar_changes]), [total_bars]]) 26 | bar_changes = np.concatenate([[0], bar_changes, [seq.shape[0]]]) 27 | 28 | bar_indices = np.full(bars[-1] + 1, -1, dtype=np.int16) 29 | bar_indices[bars] = bar_changes 30 | 31 | for idx in range(len(bar_indices) - 1, 0, -1): 32 | if bar_indices[idx] == -1: 33 | bar_indices[idx] = bar_indices[idx + 1] 34 | 35 | return bar_indices 36 | -------------------------------------------------------------------------------- /scoreperformer/data/helpers/processors.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | from typing import Optional 4 | 5 | import numpy as np 6 | 7 | from ..tokenizers import OctupleM 8 | from ..tokenizers.constants import SOS_TOKEN, EOS_TOKEN 9 | 10 | 11 | def sample_integer_shift(low=-6, high=6): 12 | return np.random.randint(low, high + 1) 13 | 14 | 15 | @dataclass 16 | class TokenSequenceAugmentations: 17 | pitch_shift: Optional[int] = 0 18 | velocity_shift: Optional[int] = 0 19 | tempo_shift: Optional[int] = 0 20 | 21 | 22 | class TokenSequenceProcessor: 23 | def __init__( 24 | self, 25 | pitch_shift_range=(-3, 3), 26 | velocity_shift_range=(-2, 2), 27 | tempo_shift_range=(-2, 2), 28 | ): 29 | self.pitch_shift_fn = partial(sample_integer_shift, *pitch_shift_range) 30 | self.velocity_shift_fn = partial(sample_integer_shift, *velocity_shift_range) 31 | self.tempo_shift_fn = partial(sample_integer_shift, *tempo_shift_range) 32 | 33 | def sample_augmentations(self, multiplier=1.0): 34 | return TokenSequenceAugmentations( 35 | pitch_shift=int(multiplier * self.pitch_shift_fn()), 36 | velocity_shift=int(multiplier * self.velocity_shift_fn()), 37 | tempo_shift=int(multiplier * self.tempo_shift_fn()) 38 | ) 39 | 40 | def augment_sequence( 41 | self, 42 | seq: np.ndarray, 43 | augmentations: TokenSequenceAugmentations 44 | ) -> np.ndarray: 45 | ... 46 | 47 | def sort_sequence(self, seq: np.ndarray) -> np.ndarray: 48 | ... 49 | 50 | def add_sos_token(self, seq: np.ndarray) -> np.ndarray: 51 | ... 52 | 53 | def add_eos_token(self, seq: np.ndarray) -> np.ndarray: 54 | ... 55 | 56 | 57 | class TupleTokenSequenceProcessor(TokenSequenceProcessor): 58 | def __init__( 59 | self, 60 | tokenizer: OctupleM, 61 | pitch_shift_range=(-3, 3), 62 | velocity_shift_range=(-2, 2), 63 | tempo_shift_range=(-2, 2) 64 | ): 65 | super().__init__(pitch_shift_range, velocity_shift_range, tempo_shift_range) 66 | 67 | self.tokenizer = tokenizer 68 | 69 | def augment_sequence( 70 | self, 71 | seq: np.ndarray, 72 | augmentations: TokenSequenceAugmentations 73 | ) -> np.ndarray: 74 | ... 
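        # shift the token columns in place: velocity and tempo are clipped to
        # the vocabulary bounds below, while out-of-range pitches are left to
        # be filtered later via `compute_valid_pitch_mask`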
75 | if augmentations.pitch_shift != 0: 76 | pitch_index = self.tokenizer.vocab_types_idx['Pitch'] 77 | seq[:, pitch_index] += augmentations.pitch_shift 78 | 79 | if augmentations.velocity_shift != 0: 80 | vel_index = self.tokenizer.vocab_types_idx['Velocity'] 81 | vel_min, vel_max = self.tokenizer.zero_token, len(self.tokenizer.vocab[vel_index]) - 1 82 | 83 | seq[:, vel_index] += augmentations.velocity_shift 84 | seq[:, vel_index] = np.maximum(vel_min, np.minimum(vel_max, seq[:, vel_index])) 85 | 86 | if augmentations.tempo_shift != 0: 87 | tempo_index = self.tokenizer.vocab_types_idx['Tempo'] 88 | tempo_min, tempo_max = self.tokenizer.zero_token, len(self.tokenizer.vocab[tempo_index]) - 1 89 | 90 | seq[:, tempo_index] += augmentations.tempo_shift 91 | seq[:, tempo_index] = np.maximum(tempo_min, np.minimum(tempo_max, seq[:, tempo_index])) 92 | 93 | return seq 94 | 95 | def sort_sequence(self, seq: np.ndarray) -> np.ndarray: 96 | seq = seq[np.lexsort((seq[:, self.tokenizer.vocab_types_idx['Pitch']], 97 | seq[:, self.tokenizer.vocab_types_idx['Position']], 98 | seq[:, self.tokenizer.vocab_types_idx['Bar']]))] 99 | return seq 100 | 101 | def add_sos_token(self, seq: np.ndarray, initial_tempo: Optional[int] = None) -> np.ndarray: 102 | sos_token_id = self.tokenizer[0, SOS_TOKEN] 103 | seq = np.concatenate((np.full_like(seq[:1], sos_token_id), seq), axis=0) 104 | 105 | return seq 106 | 107 | def add_eos_token(self, seq: np.ndarray) -> np.ndarray: 108 | eos_token_id = self.tokenizer[0, EOS_TOKEN] 109 | seq = np.concatenate((seq, np.full_like(seq[:1], eos_token_id)), axis=0) 110 | return seq 111 | 112 | # Auxiliary processing functions 113 | 114 | def zero_out_durations(self, seq: np.ndarray) -> np.ndarray: 115 | tto = self.tokenizer.vocab_types_idx 116 | velocity_index = tto['Velocity'] 117 | if 'PerfDuration' in tto and seq.shape[-1] == len(tto): 118 | duration_index = tto['PerfDuration'] 119 | else: 120 | duration_index = tto['Duration'] 121 | 122 | silent_mask = seq[:, velocity_index] == self.tokenizer.zero_token 123 | seq[silent_mask, duration_index] = self.tokenizer.zero_token 124 | 125 | return seq 126 | 127 | def remove_silent_notes(self, seq: np.ndarray) -> np.ndarray: 128 | velocity_index = self.tokenizer.vocab_types_idx['Velocity'] 129 | 130 | silent_mask = seq[:, velocity_index] == self.tokenizer.zero_token 131 | seq = seq[~silent_mask] 132 | 133 | return seq 134 | 135 | def compute_valid_pitch_mask(self, seq: np.ndarray) -> np.ndarray: 136 | pitch_index = self.tokenizer.vocab_types_idx['Pitch'] 137 | pitch_min, pitch_max = self.tokenizer.zero_token, len(self.tokenizer.vocab[pitch_index]) - 1 138 | mask = np.logical_and(seq[:, pitch_index] >= pitch_min, seq[:, pitch_index] <= pitch_max) 139 | return mask 140 | -------------------------------------------------------------------------------- /scoreperformer/data/midi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/scoreperformer/data/midi/__init__.py -------------------------------------------------------------------------------- /scoreperformer/data/midi/beats.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import numpy as np 4 | from miditoolkit import MidiFile, TimeSignature 5 | 6 | BEATS_IN_BARS = { 7 | 6: 2, 8 | 9: 3, 9 | 18: 3, 10 | 12: 4, 11 | 24: 4 12 | } 13 | 14 | 15 | def get_ticks_per_bar(time_sig: 
TimeSignature, ticks_per_beat: int = 480): 16 | return ticks_per_beat * 4 * time_sig.numerator // time_sig.denominator 17 | 18 | 19 | def get_inter_beat_interval( 20 | *, 21 | time_sig: Optional[TimeSignature], 22 | ticks_per_bar: Optional[int] = None, 23 | ticks_per_beat: int = 480 24 | ): 25 | if ticks_per_bar is None: 26 | ticks_per_bar = get_ticks_per_bar(time_sig, ticks_per_beat=ticks_per_beat) 27 | 28 | num_beat_in_bar = BEATS_IN_BARS.get(time_sig.numerator, time_sig.numerator) 29 | inter_beat_interval = int(ticks_per_bar / num_beat_in_bar) 30 | 31 | return inter_beat_interval 32 | 33 | 34 | def get_bar_beat_ticks( 35 | midi: Optional[MidiFile] = None, 36 | *, 37 | time_sigs: Optional[List[TimeSignature]] = None, 38 | ticks_per_beat: Optional[int] = None, 39 | max_tick: Optional[int] = None 40 | ): 41 | assert midi is not None or all(map(lambda x: x is not None, (time_sigs, ticks_per_beat, max_tick))) 42 | 43 | if midi is not None: 44 | time_sigs = midi.time_signature_changes 45 | ticks_per_beat = midi.ticks_per_beat 46 | max_tick = midi.max_tick - 1 47 | 48 | bar_ticks, beat_ticks = [], [] 49 | for i, time_sig in enumerate(time_sigs): 50 | last_tick = time_sigs[i + 1].time if i < len(time_sigs) - 1 else max_tick 51 | 52 | ticks_per_bar = get_ticks_per_bar(time_sig, ticks_per_beat=ticks_per_beat) 53 | bar_ticks.append(np.arange(time_sig.time, last_tick, ticks_per_bar)) 54 | 55 | inter_beat_interval = get_inter_beat_interval( 56 | time_sig=time_sig, ticks_per_bar=ticks_per_bar, ticks_per_beat=ticks_per_beat 57 | ) 58 | beat_ticks.append(np.arange(time_sig.time, last_tick, inter_beat_interval)) 59 | 60 | if len(time_sigs) > 1: 61 | bar_ticks, beat_ticks = np.concatenate(bar_ticks), np.concatenate(beat_ticks) 62 | else: 63 | bar_ticks, beat_ticks = bar_ticks[0], beat_ticks[0] 64 | 65 | return bar_ticks, beat_ticks 66 | 67 | 68 | def get_performance_beats( 69 | score_beats: np.ndarray, 70 | position_pairs: np.ndarray, 71 | max_tick: Optional[int] = None, 72 | max_time: Optional[float] = None, 73 | monotonic_times: bool = False, 74 | ticks_per_beat: int = 480 75 | ): 76 | if monotonic_times: 77 | mono_position_pairs = [position_pairs[0]] 78 | cur_pair = prev_pair = position_pairs[0] 79 | for pair in position_pairs[1:]: 80 | min_shift_time = (pair[0] - cur_pair[0]) / ticks_per_beat / 10 # tempo 600 81 | if pair[0] != prev_pair[0] and pair[1] > prev_pair[1] and pair[1] > cur_pair[1] + min_shift_time: 82 | mono_position_pairs.append(pair) 83 | cur_pair = pair 84 | prev_pair = pair 85 | position_pairs = np.array(mono_position_pairs) 86 | 87 | if max_tick is not None and max_time is not None: 88 | position_pairs = np.concatenate([position_pairs, [(max_tick, max_time)]]) 89 | score_beats = np.concatenate([score_beats, [max_tick]]) 90 | 91 | onset_ticks, perf_times = position_pairs[:, 0], position_pairs[:, 1] 92 | beat_onset_indices = np.minimum(len(onset_ticks) - 1, np.searchsorted(onset_ticks, score_beats)) 93 | 94 | # fill known beats 95 | perf_beats = [] 96 | for i, beat in enumerate(score_beats): 97 | onset_idx = beat_onset_indices[i] 98 | if onset_ticks[onset_idx] == beat: 99 | perf_beat = perf_times[onset_idx] 100 | else: 101 | # interpolate 102 | if i == 0 or onset_idx == 0: 103 | onset_idx += 1 104 | 105 | left_tick, right_tick = onset_ticks[onset_idx - 1], onset_ticks[onset_idx] 106 | left_time, right_time = perf_times[onset_idx - 1], perf_times[onset_idx] 107 | 108 | perf_beat = left_time + (right_time - left_time) * (beat - left_tick) / (right_tick - left_tick) 109 | 110 | 
perf_beats.append(perf_beat) 111 | 112 | if max_tick is not None and max_time is not None: 113 | if score_beats[-2] == score_beats[-1]: 114 | score_beats = score_beats[:-1] 115 | perf_beats = perf_beats[:-1] 116 | 117 | perf_beats = np.array(perf_beats) 118 | 119 | return score_beats, perf_beats 120 | -------------------------------------------------------------------------------- /scoreperformer/data/midi/containers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Note: 6 | pitch: int 7 | velocity: int 8 | start: float 9 | end: float 10 | 11 | @property 12 | def duration(self): 13 | return self.end - self.start 14 | -------------------------------------------------------------------------------- /scoreperformer/data/midi/preprocess.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | from miditok.utils import merge_tracks 4 | from miditoolkit import MidiFile, Marker, Instrument 5 | 6 | from ..midi import quantization as midi_quan 7 | from ..midi import utils as midi_utl 8 | from ..midi.containers import Note 9 | 10 | 11 | def preprocess_midi( 12 | midi: MidiFile, 13 | to_single_track: bool = True, 14 | sort_events: bool = True, 15 | clean_duplicates: bool = True, 16 | cut_overlapped_notes: bool = False, 17 | clean_short_notes: bool = False, 18 | quantize_notes: bool = False, 19 | quantize_midi_changes: bool = False, 20 | filter_late_events: bool = True, 21 | target_ticks_per_beat: Optional[int] = None 22 | ): 23 | if len(midi.instruments) == 0: 24 | return midi 25 | 26 | if len(midi.instruments) > 1 and to_single_track: 27 | merge_tracks(midi.instruments, effects=True) 28 | 29 | # TODO: order of notes changes inside, handle the case `sort_events = False` 30 | for track in midi.instruments: 31 | if clean_duplicates: 32 | midi_utl.remove_duplicated_notes(track.notes) 33 | 34 | if cut_overlapped_notes: 35 | midi_utl.cut_overlapping_notes(track.notes) 36 | 37 | if clean_short_notes: 38 | midi_utl.remove_short_notes(track.notes, time_division=midi.ticks_per_beat) 39 | 40 | if quantize_notes: 41 | midi_quan.quantize_notes(track.notes, time_division=midi.ticks_per_beat) 42 | if clean_duplicates: 43 | midi_utl.remove_duplicated_notes(track.notes) 44 | 45 | if sort_events: 46 | for track in midi.instruments: 47 | track.notes.sort(key=lambda x: (x.start, x.pitch, x.end)) # sort notes 48 | midi.max_tick = max([max([note.end for note in track.notes[-100:]]) for track in midi.instruments]) 49 | else: 50 | midi.max_tick = max([max([note.end for note in track.notes]) for track in midi.instruments]) + 1 51 | 52 | midi.instruments = [track for track in midi.instruments if len(track.notes) > 0] 53 | 54 | if filter_late_events: 55 | midi_utl.filter_late_midi_events(midi, sort=sort_events) 56 | 57 | if quantize_midi_changes: 58 | midi_quan.quantize_time_signatures(midi.time_signature_changes, time_division=midi.ticks_per_beat) 59 | midi_quan.quantize_tempos(midi.tempo_changes, time_division=midi.ticks_per_beat) 60 | midi_quan.quantize_key_signatures(midi.key_signature_changes, time_division=midi.ticks_per_beat) 61 | 62 | if target_ticks_per_beat is not None: 63 | midi_utl.resample_midi(midi, ticks_per_beat=target_ticks_per_beat) 64 | 65 | return midi 66 | 67 | 68 | def insert_silent_notes( 69 | midi: MidiFile, 70 | markers: Optional[List[Marker]] = None, 71 | track_idx: Optional[int] = None 72 | ): 73 | markers = markers or 
midi.markers
 74 | 
 75 |     notes = []
 76 |     for m in markers:
 77 |         if m.text.startswith('NoteS'):
 78 |             pitch, start_tick, end_tick = map(int, m.text.split('_')[1:])
 79 |             notes.append(Note(pitch, 0, start_tick, end_tick))
 80 | 
 81 |     if track_idx is None:
 82 |         track = Instrument(0, False, 'Unperformed Notes')
 83 |         track.notes = notes
 84 |         midi.instruments.append(track)
 85 |     else:
 86 |         midi.instruments[track_idx].notes += notes
 87 | 
 88 |     if midi.instruments[-1].name != 'Unperformed Notes':
 89 |         midi.instruments.append(Instrument(0, False, 'Unperformed Notes'))
 90 | 
 91 |     return midi
 92 | 
--------------------------------------------------------------------------------
/scoreperformer/data/midi/quantization.py:
--------------------------------------------------------------------------------
  1 | from typing import List, Optional, Tuple
  2 | 
  3 | from miditoolkit import Note, TempoChange, TimeSignature, KeySignature
  4 | 
  5 | 
  6 | def quantize_notes(
  7 |         notes: List[Note],
  8 |         time_division: int,
  9 |         max_beat_res: int = 32,
 10 |         pitch_range: Optional[Tuple[int, int]] = (21, 109)
 11 | ):
 12 |     """ Quantize notes, i.e. their pitch, start and end values.
 13 |     Shifts the notes' start and end times to match the quantization (e.g. 16 samples per bar).
 14 |     Notes with pitches outside of pitch_range are deleted.
 15 |     :param notes: notes to quantize
 16 |     :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed)
 17 |     :param max_beat_res: maximum beat resolution for one sample
 18 |     :param pitch_range: pitch range within which notes should be kept
 19 |     """
 20 |     ticks_per_sample = int(time_division / max_beat_res)
 21 |     i = 0
 22 |     while i < len(notes):
 23 |         note = notes[i]
 24 |         if pitch_range is not None and (note.pitch < pitch_range[0] or note.pitch >= pitch_range[1]):
 25 |             del notes[i]
 26 |             continue
 27 |         start_offset = note.start % ticks_per_sample
 28 |         end_offset = note.end % ticks_per_sample
 29 |         note.start += -start_offset if start_offset <= ticks_per_sample / 2 else ticks_per_sample - start_offset
 30 |         note.end += -end_offset if end_offset <= ticks_per_sample / 2 else ticks_per_sample - end_offset
 31 | 
 32 |         if note.start == note.end:
 33 |             note.end += ticks_per_sample
 34 | 
 35 |         i += 1
 36 | 
 37 | 
 38 | def quantize_tempos(
 39 |         tempos: List[TempoChange],
 40 |         time_division: int,
 41 |         max_beat_res: int = 32
 42 | ):
 43 |     r"""Quantize the times of tempo change events.
 44 |     Consecutive identical tempo changes will be removed.
 45 | 
 46 |     :param tempos: tempo changes to quantize
 47 |     :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed)
 48 |     :param max_beat_res: maximum beat resolution for one sample
 49 |     """
 50 |     ticks_per_sample = int(time_division / max_beat_res)
 51 |     i, prev_tempo = 0, -1
 52 |     while i < len(tempos):
 53 |         # remove consecutive duplicates, then quantize the event time
 54 |         if tempos[i].tempo == prev_tempo:
 55 |             del tempos[i]
 56 |             continue
 57 |         rest = tempos[i].time % ticks_per_sample
 58 |         tempos[i].time += -rest if rest <= ticks_per_sample / 2 else ticks_per_sample - rest
 59 |         prev_tempo = tempos[i].tempo
 60 |         i += 1
 61 | 
 62 | 
 63 | def compute_ticks_per_bar(time_sig: TimeSignature, time_division: int):
 64 |     r"""Computes time resolution of one bar in ticks.
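    For example, a 3/4 time signature at time_division=480 gives
    int(480 * 4 * 3 / 4) = 1440 ticks per bar.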
65 | 66 | :param time_sig: time signature object 67 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed) 68 | :return: MIDI bar resolution, in ticks/bar 69 | """ 70 | return int(time_division * 4 * time_sig.numerator / time_sig.denominator) 71 | 72 | 73 | def quantize_time_signatures(time_sigs: List[TimeSignature], time_division: int): 74 | r"""Quantize the time signature changes, delayed to the next bar. 75 | See MIDI 1.0 Detailed specifications, pages 54 - 56, for more information on 76 | delayed time signature messages. 77 | 78 | :param time_sigs: time signature changes to quantize 79 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed) 80 | """ 81 | all_different = False 82 | while not all_different: 83 | all_different = True 84 | 85 | # delete one of neighbouring time signatures with same values or time 86 | prev_time_sig = time_sigs[0] 87 | i = 1 88 | while i < len(time_sigs): 89 | time_sig = time_sigs[i] 90 | 91 | if (time_sig.numerator, time_sig.denominator) == (prev_time_sig.numerator, prev_time_sig.denominator) or \ 92 | time_sig.time == prev_time_sig.time: 93 | del time_sigs[i] 94 | all_different = False 95 | continue 96 | prev_time_sig = time_sig 97 | i += 1 98 | 99 | # quantize times 100 | ticks_per_bar = compute_ticks_per_bar(time_sigs[0], time_division) 101 | current_bar = 0 102 | previous_tick = 0 # first time signature change is always at tick 0 103 | i = 1 104 | while i < len(time_sigs): 105 | time_sig = time_sigs[i] 106 | 107 | # determine the current bar of time sig 108 | bar_offset, rest = divmod(time_sig.time - previous_tick, ticks_per_bar) 109 | if rest > 0: # time sig doesn't happen on a new bar, we update it to the next bar 110 | bar_offset += 1 111 | time_sig.time = previous_tick + bar_offset * ticks_per_bar 112 | 113 | # Update values 114 | ticks_per_bar = compute_ticks_per_bar(time_sig, time_division) 115 | current_bar += bar_offset 116 | previous_tick = time_sig.time 117 | i += 1 118 | 119 | 120 | def quantize_key_signatures( 121 | key_signatures: List[KeySignature], 122 | time_division: int, 123 | max_beat_res: int = 32 124 | ): 125 | r"""Quantize the times of key signature change events. 126 | Consecutive identical key signature changes will be removed. 
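    Event times are snapped to a grid of time_division / max_beat_res ticks,
    e.g. a 480 / 32 = 15-tick grid with the default arguments.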
127 | 
128 |     :param key_signatures: key signature changes to quantize
129 |     :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed)
130 |     :param max_beat_res: maximum beat resolution for one sample
131 |     """
132 |     ticks_per_sample = int(time_division / max_beat_res)
133 |     i, prev_key = 0, ''
134 |     while i < len(key_signatures):
135 |         # remove consecutive duplicates, then quantize the event time
136 |         if key_signatures[i].key_name == prev_key:
137 |             del key_signatures[i]
138 |             continue
139 |         rest = key_signatures[i].time % ticks_per_sample
140 |         key_signatures[i].time += -rest if rest <= ticks_per_sample / 2 else ticks_per_sample - rest
141 |         prev_key = key_signatures[i].key_name
142 |         i += 1
143 | 
--------------------------------------------------------------------------------
/scoreperformer/data/midi/sync.py:
--------------------------------------------------------------------------------
  1 | import copy
  2 | from typing import Optional
  3 | 
  4 | import numpy as np
  5 | from miditoolkit import MidiFile, TempoChange, Marker
  6 | 
  7 | from scoreperformer.utils import find_closest
  8 | from .beats import get_inter_beat_interval, get_bar_beat_ticks, get_performance_beats
  9 | from .timing import (
 10 |     convert_symbolic_timing_to_absolute,
 11 |     convert_absolute_timing_to_symbolic
 12 | )
 13 | from .utils import filter_late_midi_events
 14 | 
 15 | 
 16 | def sync_performance_midi(
 17 |         score_midi: MidiFile,
 18 |         perf_midi: MidiFile,
 19 |         onset_pairs: np.ndarray,
 20 |         is_absolute_timing: bool = False,
 21 |         max_time: Optional[float] = None,
 22 |         ticks_per_beat: int = 480,
 23 |         bar_sync: bool = True,
 24 |         inplace: bool = True,
 25 |         verbose: bool = False
 26 | ):
 27 |     """
 28 |     Synchronizes performance MIDI with score MIDI bars/beats through the onset pairs.
 29 | 
 30 |     Also adds silent (unperformed) performance notes as special markers based on the alignment.
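    `onset_pairs` is expected to hold aligned (score_tick, performance_time)
    pairs; these are interpolated to obtain a performance time for every score
    bar (or beat, if `bar_sync` is False).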
31 | """ 32 | perf_midi = copy.deepcopy(perf_midi) if not inplace else perf_midi 33 | 34 | # preprocess performance midi 35 | filter_late_midi_events(perf_midi) 36 | max_tick = score_midi.max_tick 37 | 38 | if not is_absolute_timing: 39 | tick_to_time = perf_midi.get_tick_to_time_mapping() 40 | max_time = tick_to_time[-1] 41 | else: 42 | assert max_time is not None, "`max_time` should be explicitly provided for MIDI with absolute timing" 43 | tick_to_time = None 44 | 45 | # compute score and performance onsets 46 | score_bars, score_beats = get_bar_beat_ticks(score_midi) 47 | score_onsets = score_bars if bar_sync else score_beats 48 | score_onsets, perf_onsets = get_performance_beats( 49 | score_onsets, onset_pairs, 50 | max_tick=max_tick - 1, max_time=max_time, 51 | monotonic_times=True, ticks_per_beat=ticks_per_beat 52 | ) 53 | perf_shift = perf_onsets[0] 54 | perf_onsets -= perf_shift 55 | max_time -= perf_shift 56 | 57 | perf_score_tick_ratio = ticks_per_beat / score_midi.ticks_per_beat 58 | 59 | time_signatures = score_midi.time_signature_changes 60 | 61 | time_sig_ticks, quarter_note_factors, inter_onset_intervals = [], [], [] 62 | for time_sig in time_signatures: 63 | time_sig_ticks.append(time_sig.time) 64 | quarter_note_factors.append(4 * time_sig.numerator / time_sig.denominator) 65 | inter_onset_intervals.append( 66 | get_inter_beat_interval(time_sig=time_sig, ticks_per_beat=score_midi.ticks_per_beat) 67 | ) 68 | 69 | time_sig_ticks, quarter_note_factors, inter_onset_intervals = map( 70 | np.array, (time_sig_ticks, quarter_note_factors, inter_onset_intervals) 71 | ) 72 | inter_beat_intervals = inter_onset_intervals 73 | 74 | ticks_per_bar = (score_midi.ticks_per_beat * quarter_note_factors).astype(int) 75 | beats_per_bar = ticks_per_bar / inter_beat_intervals 76 | ioi_in_quarters = ibi_in_quarters = quarter_note_factors / beats_per_bar 77 | 78 | if bar_sync: 79 | inter_onset_intervals = inter_onset_intervals * beats_per_bar 80 | ioi_in_quarters = ioi_in_quarters * beats_per_bar 81 | 82 | if verbose: 83 | print(f'score: time_sigs={time_signatures}\n' 84 | f' ticks_per_beat={score_midi.ticks_per_beat}, ticks_per_bar={ticks_per_bar}\n' 85 | f' inter_beat_intervals={inter_beat_intervals}, inter_onset_intervals={inter_onset_intervals}\n' 86 | f' ibi_in_quarters={ibi_in_quarters}, ioi_in_quarters={ioi_in_quarters}') 87 | 88 | # compute tempos 89 | intervals = np.diff(perf_onsets) 90 | if np.any(intervals <= 0.): 91 | return None 92 | 93 | time_sig_indices = (np.searchsorted(time_sig_ticks, score_onsets, side='right') - 1)[:-1] 94 | inter_onset_ratios = np.diff(score_onsets) / inter_onset_intervals[time_sig_indices] 95 | tempos = 60 / intervals * ioi_in_quarters[time_sig_indices] * inter_onset_ratios 96 | 97 | if verbose: 98 | print(f'tempos: ({tempos.min():.3f}, {tempos.max():.3f}), {np.median(tempos):.3f}') 99 | 100 | # get absolute timing of instruments 101 | if is_absolute_timing: 102 | abs_instr = perf_midi.instruments 103 | else: 104 | abs_instr = convert_symbolic_timing_to_absolute( 105 | perf_midi.instruments, tick_to_time, inplace=inplace, time_shift=-perf_shift 106 | ) 107 | 108 | # compute time to tick mapping 109 | inter_onset_intervals = inter_onset_intervals[time_sig_indices] * perf_score_tick_ratio * inter_onset_ratios 110 | resample_timing = [] 111 | for i in range(len(perf_onsets) - 1): 112 | start_beat, end_beat = perf_onsets[i], perf_onsets[i + 1] 113 | resample_timing.append(np.linspace(start_beat, end_beat, int(inter_onset_intervals[i]) + 1)[:-1]) 114 | 115 | 
resample_timing.append([max_time])
116 |     resample_timing = np.round(np.concatenate(resample_timing), 6)
117 | 
118 |     # create a new MidiFile object
119 |     midi = MidiFile(ticks_per_beat=ticks_per_beat)
120 | 
121 |     # convert absolute timing back to symbolic timing
122 |     sym_instr = convert_absolute_timing_to_symbolic(abs_instr, resample_timing, inplace=inplace)
123 | 
124 |     # process timing of markers
125 |     markers = perf_midi.markers if hasattr(perf_midi, 'markers') else []
126 |     for marker in markers:
127 |         marker.time = find_closest(resample_timing, float(tick_to_time[marker.time]) - perf_shift)
128 |         if marker.text.startswith('NoteI'):
129 |             pitch, start, end = map(int, marker.text.split('_')[1:])
130 |             start, end = map(lambda x: find_closest(resample_timing, float(tick_to_time[x]) - perf_shift), (start, end))
131 |             marker.text = f'NoteI_{pitch}_{start}_{end}'
132 | 
133 |     # tempo
134 |     tempo_changes = []
135 |     onset_ticks = find_closest(resample_timing, perf_onsets)
136 |     for pos_tick, tempo in zip(onset_ticks[:-1], tempos):
137 |         tempo_changes.append(TempoChange(tempo=float(tempo), time=int(pos_tick)))
138 | 
139 |     tempo_changes = [tempo for tempo in tempo_changes if tempo.time < resample_timing.shape[0]]
140 | 
141 |     # markers
142 |     markers.insert(0, Marker(text=f'Shift_{perf_shift:.6f}', time=0))
143 | 
144 |     # set attributes
145 |     midi.tempo_changes = tempo_changes
146 |     midi.time_signature_changes = time_signatures
147 |     midi.instruments = sym_instr
148 |     midi.markers = markers
149 |     midi.max_tick = resample_timing.shape[0]
150 | 
151 |     return midi
152 | 
--------------------------------------------------------------------------------
/scoreperformer/data/midi/timing.py:
--------------------------------------------------------------------------------
 1 | import copy
 2 | from typing import List
 3 | 
 4 | import numpy as np
 5 | from miditoolkit import Instrument
 6 | 
 7 | from scoreperformer.utils import find_closest
 8 | from ..midi.containers import Note
 9 | 
10 | 
11 | def convert_symbolic_timing_to_absolute(
12 |         tracks: List[Instrument],
13 |         tick_to_time: np.ndarray,
14 |         inplace: bool = True,
15 |         time_shift: float = 0.
16 | ): 17 | tracks = tracks if inplace else copy.deepcopy(tracks) 18 | 19 | for track in tracks: 20 | track.notes = [ 21 | Note(pitch=n.pitch, velocity=n.velocity, 22 | start=time_shift + float(tick_to_time[n.start]), 23 | end=time_shift + float(tick_to_time[n.end])) 24 | for n in track.notes 25 | ] 26 | for control_change in track.control_changes: 27 | control_change.time = time_shift + float(tick_to_time[control_change.time]) 28 | for pedal in track.pedals: 29 | pedal.start = time_shift + float(tick_to_time[pedal.start]) 30 | pedal.end = time_shift + float(tick_to_time[pedal.end]) 31 | for pitch_bend in track.pitch_bends: 32 | pitch_bend.time = time_shift + float(tick_to_time[pitch_bend.time]) 33 | 34 | return tracks 35 | 36 | 37 | def convert_absolute_timing_to_symbolic( 38 | tracks: List[Instrument], 39 | time_to_tick: np.ndarray, 40 | inplace: bool = True 41 | ): 42 | tracks = tracks if inplace else copy.deepcopy(tracks) 43 | 44 | def process_interval_events(events): 45 | start_times = np.array(list(map(lambda x: x.start, events))) 46 | start_ticks = find_closest(time_to_tick, start_times) 47 | end_times = np.array(list(map(lambda x: x.end, events))) 48 | end_ticks = find_closest(time_to_tick, end_times) 49 | for event, start_t, end_t in zip(events, start_ticks, end_ticks): 50 | if start_t == end_t: 51 | end_t += 1 52 | event.start = start_t 53 | event.end = end_t 54 | 55 | def process_time_events(events): 56 | times = np.array(list(map(lambda x: x.time, events))) 57 | ticks = find_closest(time_to_tick, times) 58 | for event, t in zip(events, ticks): 59 | event.time = t 60 | 61 | for track in tracks: 62 | process_interval_events(track.notes) 63 | process_interval_events(track.pedals) 64 | process_time_events(track.control_changes) 65 | process_time_events(track.pitch_bends) 66 | 67 | return tracks 68 | -------------------------------------------------------------------------------- /scoreperformer/data/midi/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import defaultdict 3 | from typing import Optional, List 4 | 5 | import numpy as np 6 | from miditoolkit import Note, MidiFile 7 | 8 | from scoreperformer.utils import find_closest 9 | 10 | 11 | def sort_notes( 12 | notes: List[Note], 13 | compute_sort_indices: bool = False, 14 | order: str = 'time' 15 | ): 16 | assert order in ('time', 'pitch') 17 | 18 | sort_ids = None 19 | if order == 'time': 20 | if compute_sort_indices: 21 | sort_ids = np.lexsort([[n.end for n in notes], [n.pitch for n in notes], [n.start for n in notes]]) 22 | notes.sort(key=lambda n: (n.start, n.pitch, n.end)) 23 | elif order == 'pitch': 24 | if compute_sort_indices: 25 | sort_ids = np.lexsort([[n.end for n in notes], [n.start for n in notes], [n.pitch for n in notes]]) 26 | notes.sort(key=lambda n: (n.pitch, n.start, n.end)) 27 | 28 | return notes, sort_ids 29 | 30 | 31 | def cut_overlapping_notes(notes: List[Note], return_sort_indices: bool = False): 32 | r"""Find and cut the first of the two overlapping notes, i.e. with the same pitch, 33 | and the second note starting before the ending of the first note. 
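    For example, two pitch-60 notes spanning ticks (0, 100) and (50, 150)
    become (0, 49) and (50, 150).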
34 | 
35 |     :param notes: notes to analyse
36 |     :param return_sort_indices: return indices by which the original notes were sorted
37 |     """
38 |     # sort by pitch, then time
39 |     notes, sort_ids = sort_notes(notes, compute_sort_indices=return_sort_indices, order='pitch')
40 | 
41 |     for i in range(1, len(notes)):
42 |         prev_note, note = notes[i - 1], notes[i]
43 |         if prev_note.pitch == note.pitch and prev_note.end >= note.start:
44 |             if note.start <= 1:
45 |                 note.start = 2
46 |             prev_note.end = note.start - 1
47 |             if prev_note.start >= prev_note.end:  # resulted in an invalid note, fix it too
48 |                 prev_note.start = prev_note.end - 1
49 | 
50 |     # sort back by time, then pitch
51 |     notes, sort_ids_back = sort_notes(notes, compute_sort_indices=return_sort_indices, order='time')
52 | 
53 |     if return_sort_indices:
54 |         sort_ids = sort_ids[sort_ids_back]
55 |         return notes, sort_ids
56 |     return notes
57 | 
58 | 
59 | def remove_duplicated_notes(notes: List[Note], return_sort_indices: bool = False):
60 |     r"""Find and remove exactly similar notes, i.e. with the same pitch, start and end.
61 | 
62 |     :param notes: notes to analyse
63 |     :param return_sort_indices: return indices by which the original notes were sorted
64 |     """
65 |     # sort by pitch, then time
66 |     notes, sort_ids = sort_notes(notes, compute_sort_indices=return_sort_indices, order='pitch')
67 | 
68 |     for i in range(len(notes) - 1, 0, -1):  # removing possible duplicated notes
69 |         if notes[i].pitch == notes[i - 1].pitch and notes[i].start == notes[i - 1].start and \
70 |                 notes[i].end >= notes[i - 1].end:
71 |             del notes[i]
72 | 
73 |     # sort back by time, then pitch
74 |     notes, sort_ids_back = sort_notes(notes, compute_sort_indices=return_sort_indices, order='time')
75 | 
76 |     if return_sort_indices:
77 |         sort_ids = sort_ids[sort_ids_back]
78 |         return notes, sort_ids
79 |     return notes
80 | 
81 | 
82 | def remove_short_notes(notes: List[Note], time_division: int, max_beat_res: int = 32):
83 |     r"""Find and remove overly short notes, i.e. notes whose duration falls below half of one quantization sample.
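    With the defaults (time_division=480, max_beat_res=32), notes shorter than
    (480 // 32) // 2 = 7 ticks are deleted.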
84 | 85 | :param notes: notes to analyse 86 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed) 87 | :param max_beat_res: maximum beat resolution for one sample 88 | """ 89 | ticks_per_sample = int(time_division / max_beat_res) 90 | 91 | for i in range(len(notes) - 1, 0, -1): 92 | note = notes[i] 93 | if note.end - note.start < ticks_per_sample // 2: 94 | del notes[i] 95 | 96 | return notes 97 | 98 | 99 | def filter_late_midi_events(midi: MidiFile, max_tick: Optional[int] = None, sort: bool = False): 100 | max_tick = max_tick or midi.max_tick 101 | 102 | for track in midi.instruments: 103 | if sort: 104 | track.control_changes.sort(key=lambda c: c.time) 105 | for i, control_change in enumerate(track.control_changes): 106 | if control_change.time > max_tick: 107 | track.control_changes = track.control_changes[:i] 108 | break 109 | 110 | if sort: 111 | track.pedals.sort(key=lambda p: p.start) 112 | for i, pedal in enumerate(track.pedals): 113 | if pedal.end > max_tick: 114 | track.pedals = track.pedals[:i] 115 | break 116 | 117 | if sort: 118 | track.pitch_bends.sort(key=lambda p: p.time) 119 | for i, pitch_bend in enumerate(track.pitch_bends): 120 | if pitch_bend.time > max_tick: 121 | track.pitch_bends = track.pitch_bends[:i] 122 | break 123 | 124 | return midi 125 | 126 | 127 | def shift_midi_notes( 128 | midi: MidiFile, 129 | time_shift: float = 0., 130 | offset: float = 0., 131 | inplace: bool = True, 132 | return_shifted_indices: bool = False 133 | ): 134 | midi = midi if inplace else copy.deepcopy(midi) 135 | 136 | midi.max_tick *= 4 137 | ttt = midi.get_tick_to_time_mapping() 138 | 139 | def process_continuous_events(elements): 140 | start_ticks = np.array(list(map(lambda x: x.start, elements))) 141 | end_ticks = np.array(list(map(lambda x: x.end, elements))) 142 | start_times, end_times = ttt[start_ticks], ttt[end_ticks] 143 | new_start_ticks = find_closest(ttt, start_times + time_shift) 144 | new_end_ticks = find_closest(ttt, end_times + time_shift) 145 | for el, time, start_t, end_t in zip(elements, start_times, new_start_ticks, new_end_ticks): 146 | if time >= offset: 147 | if start_t == end_t: 148 | end_t += 1 149 | el.start = start_t 150 | el.end = end_t 151 | return np.where(start_times >= offset)[0] 152 | 153 | def process_instant_events(elements): 154 | ticks = np.array(list(map(lambda x: x.time, elements))) 155 | times = ttt[ticks] 156 | new_ticks = find_closest(ttt, times + time_shift) 157 | for el, time, tick in zip(elements, times, new_ticks): 158 | if time >= offset: 159 | el.time = tick 160 | return np.where(times >= offset)[0] 161 | 162 | # shift relevant notes in MIDI 163 | shifted_indices = defaultdict(list) 164 | for track_idx, track in enumerate(midi.instruments): 165 | shifted_indices['note'].append((track_idx, process_continuous_events(track.notes))) 166 | if track.pedals: 167 | shifted_indices['pedal'].append((track_idx, process_continuous_events(track.pedals))) 168 | if track.control_changes: 169 | shifted_indices['control_change'].append((track_idx, process_instant_events(track.control_changes))) 170 | if track.pitch_bends: 171 | shifted_indices['pitch_bend'].append((track_idx, process_instant_events(track.pitch_bends))) 172 | 173 | midi.max_tick = max([max([note.end for note in track.notes]) for track in midi.instruments]) + 1 174 | 175 | if return_shifted_indices: 176 | return midi, shifted_indices 177 | return midi 178 | 179 | 180 | def resample_midi(midi: MidiFile, ticks_per_beat: int, inplace: bool = True): 
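    r"""Resample a MIDI file to a new time division, scaling the times of notes,
    pedals, control changes, pitch bends, and tempo/time/key signature changes.
    """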
181 | if midi.ticks_per_beat == ticks_per_beat: 182 | return midi 183 | 184 | midi = midi if inplace else copy.deepcopy(midi) 185 | 186 | scale = ticks_per_beat / midi.ticks_per_beat 187 | 188 | def process_continuous_events(elements): 189 | for el in elements: 190 | el.start = int(scale * el.start) 191 | el.end = int(scale * el.end) 192 | 193 | def process_instant_events(elements): 194 | for el in elements: 195 | el.time = int(scale * el.time) 196 | 197 | # resample MIDI events 198 | for track in midi.instruments: 199 | process_continuous_events(track.notes) 200 | if track.pedals: 201 | process_continuous_events(track.pedals) 202 | if track.control_changes: 203 | process_instant_events(track.control_changes) 204 | if track.pitch_bends: 205 | process_instant_events(track.pitch_bends) 206 | 207 | process_instant_events(midi.time_signature_changes) 208 | process_instant_events(midi.tempo_changes) 209 | process_instant_events(midi.key_signature_changes) 210 | 211 | midi.max_tick = max([max([note.end for note in track.notes]) for track in midi.instruments]) + 1 212 | return midi 213 | -------------------------------------------------------------------------------- /scoreperformer/data/music_constants.py: -------------------------------------------------------------------------------- 1 | """ Music constants. """ 2 | 3 | # notes and pitch-sitch maps 4 | NOTES_WSHARP = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] 5 | NOTES_WFLAT = ['C', 'Db', 'D', 'Eb', 'E', 'F', 'Gb', 'G', 'Ab', 'A', 'Bb', 'B'] 6 | NOTE_MAP = {sitch: i for i, sitch in enumerate(NOTES_WSHARP)} 7 | NOTE_INV_MAP = {i: sitch for sitch, i in NOTE_MAP.items()} 8 | NOTE_MAP.update(**{sitch: i for i, sitch in enumerate(NOTES_WFLAT)}) 9 | 10 | 11 | def pitch2sitch(pitch): 12 | return NOTE_INV_MAP[pitch % 12] + str(pitch // 12 - 1) 13 | 14 | 15 | def sitch2pitch(sitch): 16 | note = sitch[:1 + int(sitch[1] in ("#", "b"))] 17 | octave = sitch[len(note):] 18 | return NOTE_MAP[note] + 12 * (int(octave) + 1) 19 | -------------------------------------------------------------------------------- /scoreperformer/data/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from scoreperformer.utils import ExplicitEnum 2 | from .classes import TokSequence, TokenizerConfig 3 | from .common import OctupleM 4 | from .spmuple import ( 5 | SPMupleBase, 6 | SPMuple, 7 | SPMuple2, 8 | 9 | SPMupleOnset, 10 | SPMupleBeat, 11 | SPMupleBar, 12 | SPMupleWindow, 13 | SPMupleWindowRecompute 14 | ) 15 | 16 | 17 | class TokenizerTypes(ExplicitEnum): 18 | OctupleM = "OctupleM" 19 | SPMuple = "SPMuple" 20 | SPMuple2 = "SPMuple2" 21 | SPMupleOnset = "SPMupleOnset" 22 | SPMupleBeat = "SPMupleBeat" 23 | SPMupleBar = "SPMupleBar" 24 | SPMupleWindow = "SPMupleWindow" 25 | SPMupleWindowRecompute = "SPMupleWindowRecompute" 26 | 27 | 28 | TOKENIZERS = { 29 | TokenizerTypes.OctupleM: OctupleM, 30 | TokenizerTypes.SPMuple: SPMuple, 31 | TokenizerTypes.SPMuple2: SPMuple2, 32 | TokenizerTypes.SPMupleOnset: SPMupleOnset, 33 | TokenizerTypes.SPMupleBeat: SPMupleBeat, 34 | TokenizerTypes.SPMupleBar: SPMupleBar, 35 | TokenizerTypes.SPMupleWindow: SPMupleWindow, 36 | TokenizerTypes.SPMupleWindowRecompute: SPMupleWindowRecompute 37 | } 38 | -------------------------------------------------------------------------------- /scoreperformer/data/tokenizers/classes.py: -------------------------------------------------------------------------------- 1 | """ Extended miditok classes. 
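Adds an optional `meta` field to `TokSequence` and overrides the default
special token set of `TokenizerConfig` for backward compatibility.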
""" 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional, Dict, Sequence 5 | 6 | from miditok.classes import ( 7 | TokSequence as MidiTokTokSequence, 8 | TokenizerConfig as MidiTokTokenizerConfig 9 | ) 10 | 11 | from .constants import SPECIAL_TOKENS 12 | 13 | 14 | @dataclass 15 | class TokSequence(MidiTokTokSequence): 16 | meta: Optional[Dict[str, object]] = None 17 | 18 | 19 | class TokenizerConfig(MidiTokTokenizerConfig): 20 | r""" 21 | MIDI tokenizer base class, containing common methods and attributes for all tokenizers. 22 | :param special_tokens: list of special tokens. This must be given as a list of strings given 23 | only the names of the tokens. (default: ``["PAD", "SOS", "EOS", "MASK"]``\) 24 | :param **kwargs: additional parameters that will be saved in `config.additional_params`. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | special_tokens: Sequence[str] = SPECIAL_TOKENS, 30 | **kwargs 31 | ): 32 | super().__init__(special_tokens=special_tokens, **kwargs) 33 | -------------------------------------------------------------------------------- /scoreperformer/data/tokenizers/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .octuple_m import OctupleM 2 | -------------------------------------------------------------------------------- /scoreperformer/data/tokenizers/constants.py: -------------------------------------------------------------------------------- 1 | """ Tokenizer related constants. """ 2 | 3 | # override MidiTok default special tokens for backward compatibility 4 | SPECIAL_TOKENS = ["PAD", "MASK", "SOS", "EOS"] 5 | 6 | PAD_TOKEN = "PAD_None" 7 | MASK_TOKEN = "MASK_None" 8 | SOS_TOKEN = "SOS_None" 9 | EOS_TOKEN = "EOS_None" 10 | 11 | TIME_DIVISION = 480 12 | 13 | SCORE_KEYS = [ 14 | "Bar", 15 | "Position", 16 | "Pitch", 17 | "Velocity", 18 | "Duration", 19 | "Tempo", 20 | "TimeSig", 21 | "Program", 22 | "PositionShift", 23 | "NotesInOnset", 24 | "PositionInOnset" 25 | ] 26 | PERFORMANCE_KEYS = SCORE_KEYS + [ 27 | "OnsetDev", 28 | "PerfDuration", 29 | "RelOnsetDev", 30 | "RelPerfDuration" 31 | ] 32 | -------------------------------------------------------------------------------- /scoreperformer/data/tokenizers/midi_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ MIDI encoding base class and methods. """ 2 | 3 | from abc import ABC 4 | 5 | from miditok import MIDITokenizer as _MIDITokenizer 6 | from miditok.constants import TIME_SIGNATURE 7 | from miditok.utils import remove_duplicated_notes, merge_same_program_tracks 8 | from miditoolkit import MidiFile, TimeSignature 9 | 10 | 11 | class MIDITokenizer(_MIDITokenizer, ABC): 12 | r"""MIDI tokenizer base class, containing common methods and attributes for all tokenizers. 13 | 14 | See :class:`miditok.MIDITokenizer` for a detailed documentation. 15 | """ 16 | 17 | def preprocess_midi(self, midi: MidiFile): 18 | r"""Pre-process (in place) a MIDI file to quantize its time and note attributes 19 | before tokenizing it. Its notes attribute (times, pitches, velocities) will be 20 | quantized and sorted, duplicated notes removed, as well as tempos. Empty tracks 21 | (with no note) will be removed from the MIDI object. Notes with pitches outside 22 | of self.pitch_range will be deleted. 23 | 24 | :param midi: MIDI object to preprocess. 
 25 |         """
 26 |         # Merge instruments of the same program / inst before preprocessing them
 27 |         # This helps to avoid potential duplicated notes in some multitrack settings
 28 |         if self.config.use_programs and self.one_token_stream:
 29 |             merge_same_program_tracks(midi.instruments)
 30 | 
 31 |         t = 0
 32 |         while t < len(midi.instruments):
 33 |             # quantize notes attributes
 34 |             self._quantize_notes(midi.instruments[t].notes, midi.ticks_per_beat)
 35 |             # sort notes
 36 |             midi.instruments[t].notes.sort(key=lambda x: (x.start, x.pitch, x.end))
 37 |             # remove possible duplicated notes
 38 |             if self.config.additional_params.get("remove_duplicates", False):
 39 |                 remove_duplicated_notes(midi.instruments[t].notes)
 40 |             if len(midi.instruments[t].notes) == 0:
 41 |                 del midi.instruments[t]
 42 |                 continue
 43 | 
 44 |             # Quantize sustain pedal and pitch bend
 45 |             if self.config.use_sustain_pedals:
 46 |                 self._quantize_sustain_pedals(
 47 |                     midi.instruments[t].pedals, midi.ticks_per_beat
 48 |                 )
 49 |             if self.config.use_pitch_bends:
 50 |                 self._quantize_pitch_bends(
 51 |                     midi.instruments[t].pitch_bends, midi.ticks_per_beat
 52 |                 )
 53 |             t += 1
 54 | 
 55 |         # Recalculate max_tick as this could have changed after notes quantization
 56 |         if len(midi.instruments) > 0:
 57 |             midi.max_tick = max(
 58 |                 [max([note.end for note in track.notes]) for track in midi.instruments]
 59 |             )
 60 | 
 61 |         if self.config.use_tempos:
 62 |             self._quantize_tempos(midi.tempo_changes, midi.ticks_per_beat)
 63 | 
 64 |         if len(midi.time_signature_changes) == 0:  # can sometimes happen
 65 |             midi.time_signature_changes.append(
 66 |                 TimeSignature(*TIME_SIGNATURE, 0)
 67 |             )  # 4/4 by default in this case
 68 |         if self.config.use_time_signatures:
 69 |             self._quantize_time_signatures(
 70 |                 midi.time_signature_changes, midi.ticks_per_beat
 71 |             )
 72 | 
--------------------------------------------------------------------------------
/scoreperformer/data/tokenizers/spmuple/__init__.py:
--------------------------------------------------------------------------------
 1 | from .base import SPMupleBase
 2 | from .encodings import (
 3 |     SPMupleOnset,
 4 |     SPMupleBeat,
 5 |     SPMupleBar,
 6 |     SPMupleWindow,
 7 |     SPMupleWindowRecompute
 8 | )
 9 | from .spmuple import SPMuple
10 | from .spmuple2 import SPMuple2
11 | 
--------------------------------------------------------------------------------
/scoreperformer/data/tokenizers/spmuple/base.py:
--------------------------------------------------------------------------------
  1 | """ SPMuple (ScorePerformanceMusic-tuple) encoding for aligned score-performance music. """
  2 | 
  3 | from abc import abstractmethod
  4 | from typing import List, Optional, Union, Any
  5 | 
  6 | import numpy as np
  7 | from miditok.midi_tokenizer import _in_as_seq
  8 | from miditoolkit import MidiFile, Note
  9 | 
 10 | from ..classes import TokSequence
 11 | from ..common import OctupleM
 12 | from ..constants import TIME_DIVISION, SCORE_KEYS
 13 | 
 14 | 
 15 | class SPMupleBase(OctupleM):
 16 |     r"""SPMupleBase: a base class for a family of ScorePerformanceMusic-tuple encodings.
 17 | 
 18 |     An extended OctupleM encoding with performance-specific tokens for performance MIDIs.
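    A minimal usage sketch (assuming a concrete subclass such as ``SPMuple``
    with its default configuration and a pre-aligned score/performance MIDI pair):

        tokenizer = SPMuple()
        score_seq = tokenizer.score_midi_to_tokens(score_midi)
        perf_seq = tokenizer.performance_midi_to_tokens(perf_midi, score_seq)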
 19 |     """
 20 | 
 21 |     def _tweak_config_before_creating_voc(self):
 22 |         super()._tweak_config_before_creating_voc()
 23 | 
 24 |         # token vocabulary bins
 25 |         self.config.additional_params["token_bins"] = self.config.additional_params.get("token_bins", {})
 26 | 
 27 |         # midi postprocessing
 28 |         self.config.additional_params["cut_overlapping_notes"] = True
 29 | 
 30 |     def preprocess_midi(self, midi: MidiFile, is_score: bool = True):
 31 |         r"""Preprocess a MIDI file to be used by SPMuple encoding.
 32 | 
 33 |         :param midi: MIDI object to preprocess
 34 |         :param is_score: whether MIDI object is a score MIDI or not
 35 |         """
 36 |         super().preprocess_midi(midi)
 37 | 
 38 |     def preprocess_score_midi(self, midi: MidiFile):
 39 |         r"""Preprocess a score MIDI file to be used by SPMuple encoding.
 40 | 
 41 |         :param midi: MIDI object to preprocess
 42 |         """
 43 |         self.preprocess_midi(midi, is_score=True)
 44 | 
 45 |     def preprocess_performance_midi(self, midi: MidiFile):
 46 |         r"""Preprocess a performance MIDI file to be used by SPMuple encoding.
 47 | 
 48 |         :param midi: MIDI object to preprocess
 49 |         """
 50 |         self.preprocess_midi(midi, is_score=False)
 51 | 
 52 |     def score_midi_to_tokens(self, midi: MidiFile) -> TokSequence:
 53 |         r"""Converts a MIDI file to a score token representation, a sequence of "time steps" of tokens.
 54 | 
 55 |         A time step is a list of tokens where:
 56 |             (list index: token type)
 57 |             0: Bar
 58 |             1: Position
 59 |             2: Pitch
 60 |             3: Velocity
 61 |             4: Duration
 62 |             (5: Tempo)
 63 |             (6: TimeSignature)
 64 |             (7: Program)
 65 | 
 66 |         :param midi: the MIDI object to convert
 67 |         :return: a :class:`miditok.TokSequence`.
 68 |         """
 69 |         return super().midi_to_tokens(midi)
 70 | 
 71 |     def performance_midi_to_tokens(
 72 |             self,
 73 |             midi: MidiFile,
 74 |             score_tokens: TokSequence,
 75 |             alignment: Optional[np.ndarray] = None
 76 |     ) -> TokSequence:
 77 |         r"""Tokenizes a performance MIDI file into a :class:`miditok.TokSequence`.
 78 | 
 79 |         :param midi: the MIDI object to convert.
 80 |         :param score_tokens: corresponding score tokens :class:`miditok.TokSequence`.
 81 |         :param alignment: optional alignment between performance and score tokens.
 82 |         :return: a :class:`miditok.TokSequence`.
 83 |         """
 84 |         # Check if the durations values have been calculated before for this time division
 85 |         if midi.ticks_per_beat not in self._durations_ticks:
 86 |             self._durations_ticks[midi.ticks_per_beat] = np.array(
 87 |                 [
 88 |                     (beat * res + pos) * midi.ticks_per_beat // res
 89 |                     for beat, pos, res in self.durations
 90 |                 ]
 91 |             )
 92 | 
 93 |         # Preprocess the MIDI file
 94 |         self.preprocess_performance_midi(midi)
 95 | 
 96 |         # Register MIDI metadata
 97 |         self._current_midi_metadata = {
 98 |             "time_division": midi.ticks_per_beat,
 99 |             "max_tick": midi.max_tick,
100 |             "tempo_changes": midi.tempo_changes,
101 |             "time_sig_changes": midi.time_signature_changes,
102 |             "key_sig_changes": midi.key_signature_changes,
103 |         }
104 | 
105 |         tokens = self._performance_midi_to_tokens(midi, score_tokens, alignment)
106 | 
107 |         return tokens
108 | 
109 |     @abstractmethod
110 |     def _performance_midi_to_tokens(
111 |             self,
112 |             midi: MidiFile,
113 |             score_tokens: TokSequence,
114 |             alignment: Optional[np.ndarray] = None
115 |     ) -> TokSequence:
116 |         r"""Converts a MIDI file to a performance tokens representation, a sequence of "time steps"
117 |         of score tokens stacked with performance-specific features (e.g., OnsetDeviation).
118 | 
119 |         :param midi: the MIDI object to convert.
120 |         :param score_tokens: corresponding score tokens :class:`miditok.TokSequence`.
121 | :param alignment: optional alignment between performance and score tokens. 122 | :return: the performance token representation, i.e. tracks converted into sequences of tokens 123 | """ 124 | ... 125 | 126 | @_in_as_seq() 127 | def score_tokens_to_midi( 128 | self, 129 | tokens: Union[TokSequence, List, np.ndarray, Any], 130 | output_path: Optional[str] = None, 131 | time_division: int = TIME_DIVISION, 132 | ) -> MidiFile: 133 | r"""Converts score tokens (:class:`miditok.TokSequence`) into a MIDI and saves it. 134 | 135 | :param tokens: tokens to convert. Can be either a list of :class:`miditok.TokSequence`, 136 | :param output_path: path to save the file. (default: None) 137 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI to create). 138 | :return: the midi object (:class:`miditoolkit.MidiFile`). 139 | """ 140 | return self.tokens_to_midi(tokens, output_path=output_path, time_division=time_division) 141 | 142 | @abstractmethod 143 | @_in_as_seq() 144 | def performance_tokens_to_midi( 145 | self, 146 | tokens: Union[TokSequence, List, np.ndarray, Any], 147 | output_path: Optional[str] = None, 148 | time_division: int = TIME_DIVISION, 149 | ) -> MidiFile: 150 | r"""Converts performance tokens (:class:`miditok.TokSequence`) into a MIDI and saves it. 151 | 152 | :param tokens: tokens to convert. Can be either a list of :class:`miditok.TokSequence`, 153 | :param output_path: path to save the file. (default: None) 154 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI to create). 155 | :return: the midi object (:class:`miditoolkit.MidiFile`). 156 | """ 157 | ... 158 | 159 | @abstractmethod 160 | @_in_as_seq() 161 | def score_tokens_as_performance(self, score_tokens: Union[TokSequence, List, np.ndarray, Any]) -> TokSequence: 162 | r"""Converts a sequence of score tokens into a sequence of performance tokens, 163 | the tokens corresponding to a deadpan performance with no variation from score notes. 164 | """ 165 | ... 166 | 167 | def _quantize_notes(self, notes: List[Note], time_division: int, is_score: bool = True): 168 | r"""Quantize the notes attributes: their pitch, velocity, start and end values. 169 | It shifts the notes so that they start at times that match the time resolution 170 | (e.g. 16 samples per bar). 171 | Notes with pitches outside of self.pitch_range will be deleted. 172 | 173 | :param notes: notes to quantize. 174 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed). 175 | :param is_score: whether the notes are from score MIDI or not 176 | """ 177 | super()._quantize_notes(notes, time_division) 178 | 179 | @abstractmethod 180 | def _create_base_vocabulary(self) -> List[List[str]]: 181 | r"""Creates the vocabulary, as a list of string tokens. 182 | 183 | :return: the vocabulary as a list of string. 
184 | """ 185 | return super()._create_base_vocabulary() 186 | 187 | @abstractmethod 188 | def _get_token_types(self) -> List[str]: 189 | r"""Creates an ordered list of available token types.""" 190 | return super()._get_token_types() 191 | 192 | @property 193 | def score_sizes(self): 194 | return { 195 | key: value for key, value in self.sizes.items() 196 | if key in SCORE_KEYS 197 | } 198 | 199 | @property 200 | def performance_sizes(self): 201 | return self.sizes 202 | -------------------------------------------------------------------------------- /scoreperformer/data/tokenizers/spmuple/encodings.py: -------------------------------------------------------------------------------- 1 | from .spmuple import SPMuple 2 | from .spmuple2 import SPMuple2 3 | 4 | 5 | class SPMupleOnset(SPMuple2): 6 | def _tweak_config_before_creating_voc(self): 7 | super()._tweak_config_before_creating_voc() 8 | 9 | self.config.additional_params["use_position_shifts"] = True 10 | self.config.additional_params["use_onset_indices"] = True 11 | 12 | self.config.additional_params["onset_tempos"] = True 13 | 14 | 15 | class SPMupleBeat(SPMuple): 16 | def _tweak_config_before_creating_voc(self): 17 | super()._tweak_config_before_creating_voc() 18 | 19 | self.config.additional_params["use_position_shifts"] = True 20 | self.config.additional_params["use_onset_indices"] = True 21 | self.config.additional_params["rel_onset_dev"] = True 22 | self.config.additional_params["rel_perf_duration"] = True 23 | 24 | self.config.additional_params["bar_tempos"] = False 25 | 26 | 27 | class SPMupleBar(SPMuple): 28 | def _tweak_config_before_creating_voc(self): 29 | super()._tweak_config_before_creating_voc() 30 | 31 | self.config.additional_params["use_position_shifts"] = True 32 | self.config.additional_params["use_onset_indices"] = True 33 | self.config.additional_params["rel_onset_dev"] = True 34 | self.config.additional_params["rel_perf_duration"] = True 35 | 36 | self.config.additional_params["bar_tempos"] = True 37 | 38 | 39 | class SPMupleWindow(SPMuple2): 40 | def _tweak_config_before_creating_voc(self): 41 | super()._tweak_config_before_creating_voc() 42 | 43 | self.config.additional_params["use_position_shifts"] = True 44 | self.config.additional_params["use_onset_indices"] = True 45 | 46 | self.config.additional_params["use_quantized_tempos"] = True 47 | self.config.additional_params["decode_recompute_tempos"] = False 48 | 49 | 50 | class SPMupleWindowRecompute(SPMuple2): 51 | def _tweak_config_before_creating_voc(self): 52 | super()._tweak_config_before_creating_voc() 53 | 54 | self.config.additional_params["use_position_shifts"] = True 55 | self.config.additional_params["use_onset_indices"] = True 56 | 57 | self.config.additional_params["use_quantized_tempos"] = self.config.additional_params.get( 58 | "use_quantized_tempos", True 59 | ) 60 | self.config.additional_params["decode_recompute_tempos"] = True 61 | -------------------------------------------------------------------------------- /scoreperformer/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | from .trainer_config import TrainerConfig 3 | -------------------------------------------------------------------------------- /scoreperformer/experiments/components.py: -------------------------------------------------------------------------------- 1 | """ A centralized Experiment Components initialization. 
""" 2 | 3 | import copy 4 | import os 5 | from dataclasses import dataclass, fields 6 | from typing import Union, Optional 7 | 8 | from omegaconf import DictConfig, OmegaConf 9 | from torch.utils.data import Dataset 10 | 11 | from scoreperformer.data import DATASETS, COLLATORS 12 | from scoreperformer.data.tokenizers.constants import MASK_TOKEN 13 | from scoreperformer.models import MODELS, EVALUATORS 14 | from scoreperformer.modules.constructor import ModuleConfig 15 | from scoreperformer.utils.config import disable_nodes, Config 16 | from .trainer_config import TrainerConfig 17 | 18 | 19 | @dataclass 20 | class ExperimentConfig(Config): 21 | data: DictConfig 22 | model: Union[DictConfig, ModuleConfig] 23 | trainer: Union[DictConfig, TrainerConfig] 24 | evaluator: Optional[DictConfig] = None 25 | 26 | 27 | ExperimentConfigFields = list(map(lambda f: f.name, fields(ExperimentConfig))) 28 | 29 | 30 | def resolve_config_hierarchy(config: DictConfig, config_root: Optional[str] = None): 31 | if "base" in config and config.base is not None: 32 | base_configs_names = [config.base] if isinstance(config.base, str) else config.base 33 | 34 | base_configs = [] 35 | for base_config_name in base_configs_names: 36 | config_root = config_root or './' 37 | base_config_path = os.path.join(config_root, base_config_name) 38 | base_config = OmegaConf.load(base_config_path) 39 | base_config = resolve_config_hierarchy(base_config, config_root=config_root) 40 | base_configs.append(base_config) 41 | 42 | del config["base"] 43 | 44 | config = OmegaConf.merge(*base_configs[::-1], config) 45 | 46 | return config 47 | 48 | 49 | class ExperimentComponents: 50 | def __init__(self, config: Union[ExperimentConfig, str], config_root: Optional[str] = None): 51 | if isinstance(config, str): 52 | config = os.path.join(config_root, config) if config_root is not None else config 53 | config = OmegaConf.load(config) 54 | if not isinstance(config, DictConfig): 55 | config = OmegaConf.load(config) 56 | 57 | config = resolve_config_hierarchy(config, config_root=config_root) 58 | 59 | assert all([key in config for key in ExperimentConfigFields[:3]]), \ 60 | f"ExperimentConfig is missing one of the keys: {ExperimentConfigFields[:3]}" 61 | 62 | disable_nodes(config) 63 | OmegaConf.resolve(config) 64 | 65 | self.config = ExperimentConfig( 66 | data=config.data, 67 | model=config.model, 68 | trainer=config.trainer, 69 | evaluator=config.get("evaluator", None) 70 | ) 71 | 72 | self.train_dataset = None 73 | self.eval_dataset = None 74 | self.collator = None 75 | self.model = None 76 | self.evaluator = None 77 | 78 | def init_components(self): 79 | self.init_datasets() 80 | self.init_collator() 81 | self.init_model() 82 | self.init_evaluator() 83 | 84 | return self.model, self.train_dataset, self.eval_dataset, self.collator, self.evaluator 85 | 86 | def init_datasets(self): 87 | cfg = self.config.data.dataset 88 | self.train_dataset = build_dataset(cfg, split=cfg._splits_.train) if cfg._splits_.train is not None else None 89 | self.eval_dataset = build_dataset(cfg, split=cfg._splits_.eval) if cfg._splits_.eval is not None else None 90 | 91 | return self.train_dataset, self.eval_dataset 92 | 93 | def init_collator(self): 94 | dataset = self.train_dataset or self.eval_dataset 95 | assert dataset is not None 96 | 97 | cfg = self.config.data.collator 98 | cfg.mask_token_id = dataset.tokenizer[0, MASK_TOKEN] 99 | self.collator = build_collator(cfg) 100 | 101 | return self.collator 102 | 103 | def init_model(self, inject_data: bool = True): 
104 | cfg = self.config.model
105 | 
106 | dataset = None
107 | if inject_data:
108 | dataset = self.train_dataset or self.eval_dataset
109 | assert dataset is not None
110 | 
111 | self.model = build_model(cfg, dataset=dataset)
112 | 
113 | return self.model
114 | 
115 | def init_evaluator(self):
116 | assert self.model is not None
117 | dataset = self.train_dataset or self.eval_dataset
118 | 
119 | cfg = self.config.evaluator
120 | self.evaluator = build_evaluator(
121 | cfg, model=self.model, tokenizer=dataset.tokenizer
122 | )
123 | 
124 | return self.evaluator
125 | 
126 | 
127 | def build_dataset(config, split='train', eval_mode=False):
128 | if config._name_ in DATASETS:
129 | config = copy.deepcopy(config)
130 | config.sample = config.sample and split in ('train', 'all') and not eval_mode
131 | 
132 | dataset_cls = DATASETS[config._name_]
133 | config = {key: value for key, value in config.items() if key[0] != "_"}
134 | dataset = dataset_cls(split=split, **config)
135 | return dataset
136 | else:
137 | raise ValueError(
138 | f"Invalid dataset type: {config._name_}. Supported types: {list(DATASETS.keys())}"
139 | )
140 | 
141 | 
142 | def build_collator(config):
143 | if config._name_ in COLLATORS:
144 | config = copy.deepcopy(config)
145 | 
146 | collator_cls = COLLATORS[config._name_]
147 | config = {key: value for key, value in config.items() if key[0] != "_"}
148 | collator = collator_cls(**config)
149 | return collator
150 | else:
151 | raise ValueError(
152 | f"Invalid data collator type: {config._name_}. Supported types: {list(COLLATORS.keys())}"
153 | )
154 | 
155 | 
156 | def build_model(config, *, dataset: Optional[Dataset] = None, **kwargs):
157 | if config._name_ in MODELS:
158 | model_cls = MODELS[config._name_]
159 | model_cls.inject_data_config(config, dataset)
160 | model = model_cls.init(config, **kwargs)
161 | model_cls.cleanup_config(config)
162 | return model
163 | else:
164 | raise ValueError(
165 | f"Invalid model type: {config._name_}. Supported types: {list(MODELS.keys())}"
166 | )
167 | 
168 | 
169 | def build_evaluator(config, **kwargs):
170 | if config is not None and config._name_ in EVALUATORS:
171 | evaluator_cls = EVALUATORS[config._name_]
172 | config = {key: value for key, value in config.items() if key[0] != "_"}
173 | config.update(**kwargs)
174 | return evaluator_cls(**config)
175 | else:
176 | return None
177 | 
--------------------------------------------------------------------------------
/scoreperformer/experiments/integrations.py:
--------------------------------------------------------------------------------
1 | """
2 | External Trainer integrations.
3 | 
4 | Adapted from: https://github.com/huggingface/transformers/blob/main/src/transformers/integrations.py
5 | """
6 | import json
7 | 
8 | from omegaconf import OmegaConf
9 | from torch.utils.tensorboard import SummaryWriter
10 | 
11 | from .callbacks import TrainerCallback
12 | 
13 | 
14 | class TensorBoardCallback(TrainerCallback):
15 | """
16 | A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).
17 | Args:
18 | tb_writer (`SummaryWriter`, *optional*):
19 | The writer to use. Will instantiate one if not set.
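A writer passed here is reused as-is; otherwise one is created from
`config.log_dir` on `on_train_begin` (or lazily on the first `on_log` call).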
20 | """ 21 | 22 | def __init__(self, tb_writer=None): 23 | self.tb_writer = tb_writer 24 | 25 | def _init_summary_writer(self, config, log_dir=None): 26 | log_dir = log_dir or config.log_dir 27 | self.tb_writer = SummaryWriter(log_dir=log_dir) 28 | 29 | def on_train_begin(self, config, state, control, **kwargs): 30 | if self.tb_writer is None: 31 | self._init_summary_writer(config) 32 | 33 | if self.tb_writer is not None: 34 | self.tb_writer.add_text("trainer_config", config.to_json_string()) 35 | exp_config = kwargs.get("exp_config", None) 36 | if exp_config is not None: 37 | def to_json_string(cfg): 38 | return json.dumps(OmegaConf.to_container(cfg, resolve=True), indent=2) 39 | self.tb_writer.add_text("data_config", to_json_string(exp_config.data)) 40 | self.tb_writer.add_text("model_config", to_json_string(exp_config.model)) 41 | 42 | def on_log(self, args, state, control, logs=None, **kwargs): 43 | if self.tb_writer is None: 44 | self._init_summary_writer(args) 45 | 46 | if self.tb_writer is not None: 47 | for k, v in logs.items(): 48 | if isinstance(v, (int, float)): 49 | self.tb_writer.add_scalar(k, v, state.global_step) 50 | self.tb_writer.flush() 51 | 52 | def on_train_end(self, args, state, control, **kwargs): 53 | if self.tb_writer: 54 | self.tb_writer.close() 55 | self.tb_writer = None 56 | -------------------------------------------------------------------------------- /scoreperformer/experiments/logging/__init__.py: -------------------------------------------------------------------------------- 1 | from .console_logger import setup_logger 2 | -------------------------------------------------------------------------------- /scoreperformer/experiments/logging/console_logger.py: -------------------------------------------------------------------------------- 1 | """ Console and File Logger. 
""" 2 | from typing import Optional 3 | 4 | from loguru import logger 5 | 6 | 7 | def setup_logger(file: Optional[str] = None, **extra): 8 | import sys 9 | 10 | time = "{time:YYYY-MM-DD HH:mm:ss}" 11 | level = "{level:<7}" 12 | process = "{extra[process]}" 13 | node_info = "{extra[node_info]}" 14 | module = "{name}:{function}:{line}" 15 | message = "{message}" 16 | 17 | # formatter = f"{time} {level} {module} - {message_level:6s} {message}" 18 | # formatter = f"{time} {level} - {message_level:6s} {message}" 19 | # formatter = f"{time} {level} {module} - {message}" 20 | formatter = f"{time} {level} - {message}" 21 | handlers = [dict(sink=sys.stdout, format=formatter, enqueue=True)] 22 | if file is not None: 23 | handlers.append(dict(sink=file, format=formatter, enqueue=True)) 24 | 25 | logger.configure( 26 | handlers=handlers, 27 | extra=extra 28 | ) 29 | 30 | return logger 31 | -------------------------------------------------------------------------------- /scoreperformer/experiments/optimizers.py: -------------------------------------------------------------------------------- 1 | """ Optimizers and Learning Rate Schedulers """ 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Optional, Union, List, Dict 5 | 6 | import torch 7 | from torch.cuda import amp 8 | 9 | _optimizers = { 10 | "sgd": torch.optim.SGD, 11 | "adam": torch.optim.Adam, 12 | "adamw": torch.optim.AdamW 13 | } 14 | 15 | 16 | def switch_optimizer(optimizer_name: str): 17 | optimizer = getattr(torch.optim, optimizer_name, None) 18 | if optimizer is not None: 19 | return optimizer 20 | 21 | if optimizer_name in _optimizers: 22 | optimizer = _optimizers[optimizer_name] 23 | else: 24 | raise ValueError( 25 | f"There is no such name for optimizers: {optimizer_name}. " 26 | f"Valid optimizers: {list(_optimizers.keys())}" 27 | ) 28 | 29 | return optimizer 30 | 31 | 32 | def get_optimizer( 33 | optimizer_name: str, 34 | optimizer_params: dict, 35 | lr: float, 36 | model: torch.nn.Module = None, 37 | parameters: List = None, 38 | ) -> torch.optim.Optimizer: 39 | """Find, initialize and return an optimizer. 40 | Args: 41 | optimizer_name (str): Optimizer name. 42 | optimizer_params (dict): Optimizer parameters. 43 | lr (float): Initial learning rate. 44 | model (torch.nn.Module): Model to pass to the optimizer. 45 | Returns: 46 | torch.optim.Optimizer: Functional optimizer. 47 | """ 48 | optimizer = switch_optimizer(optimizer_name) 49 | if model is not None: 50 | parameters = model.parameters() 51 | return optimizer(parameters, lr=lr, **optimizer_params) 52 | 53 | 54 | _lr_schedulers = { 55 | "plateau": torch.optim.lr_scheduler.ReduceLROnPlateau, 56 | "exponential": torch.optim.lr_scheduler.ExponentialLR, 57 | } 58 | 59 | 60 | def switch_lr_scheduler(lr_scheduler_name: str): 61 | lr_scheduler = getattr(torch.optim, lr_scheduler_name, None) 62 | if lr_scheduler is not None: 63 | return lr_scheduler 64 | 65 | if lr_scheduler_name in _lr_schedulers: 66 | lr_scheduler = _lr_schedulers[lr_scheduler_name] 67 | else: 68 | raise ValueError( 69 | f"There is no such name for lr_schedulers: {lr_scheduler_name}. " 70 | f"Valid lr_schedulers: {list(_lr_schedulers.keys())}" 71 | ) 72 | 73 | return lr_scheduler 74 | 75 | 76 | def get_lr_scheduler( 77 | lr_scheduler: str, lr_scheduler_params: Dict, optimizer: torch.optim.Optimizer 78 | ) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access 79 | """Find, initialize and return a scheduler. 80 | Args: 81 | lr_scheduler (str): Scheduler name. 
82 | lr_scheduler_params (Dict): Scheduler parameters.
83 | optimizer (torch.optim.Optimizer): Optimizer to pass to the scheduler.
84 | Returns:
85 | torch.optim.lr_scheduler._LRScheduler: Functional scheduler.
86 | """
87 | if lr_scheduler is None:
88 | return None
89 | scheduler = switch_lr_scheduler(lr_scheduler)
90 | return scheduler(optimizer, **lr_scheduler_params)
91 | 
92 | 
93 | @dataclass
94 | class OptimizerConfig:
95 | lr: Union[float, List[float]] = field(
96 | default=0.001, metadata={"help": "Learning rate for each optimizer. Defaults to 0.001"}
97 | )
98 | optimizer: Union[str, List[str]] = field(
99 | default="Adam", metadata={"help": "Optimizer(s) to use. Defaults to 'Adam'"}
100 | )
101 | optimizer_params: Union[Dict, List[Dict]] = field(
102 | default_factory=dict, metadata={"help": "Optimizer(s) arguments. Defaults to {}"}
103 | )
104 | lr_scheduler: Optional[Union[str, List[str]]] = field(
105 | default=None, metadata={"help": "Learning rate scheduler(s) to use. Defaults to None"}
106 | )
107 | lr_scheduler_params: Dict = field(
108 | default_factory=dict, metadata={"help": "Learning rate scheduler(s) arguments. Defaults to {}"}
109 | )
110 | grad_clip: Optional[float] = field(
111 | default=None, metadata={"help": "Gradient clipping threshold. Defaults to None"}
112 | )
113 | grad_accum_steps: Optional[int] = field(
114 | default=1, metadata={"help": "Gradient accumulation steps. Defaults to 1"}
115 | )
116 | mixed_precision: bool = field(
117 | default=False, metadata={"help": "Use mixed precision for training. Defaults to False"}
118 | )
119 | 
120 | 
121 | class Optimizer:
122 | """ Combined Optimizer and Scheduler class. """
123 | 
124 | def __init__(
125 | self,
126 | config: OptimizerConfig,
127 | model: torch.nn.Module = None,
128 | parameters: List = None
129 | ):
130 | self.config = config
131 | 
132 | self.optimizer = get_optimizer(
133 | config.optimizer,
134 | config.optimizer_params or {},
135 | config.lr,
136 | model=model,
137 | parameters=parameters
138 | )
139 | 
140 | self.lr_scheduler = get_lr_scheduler(
141 | config.lr_scheduler,
142 | config.lr_scheduler_params,
143 | optimizer=self.optimizer
144 | )
145 | 
146 | self._scaler = amp.GradScaler(enabled=config.mixed_precision)
147 | self.grad_clip = config.grad_clip
148 | self.grad_accum_steps = config.grad_accum_steps
149 | self.mixed_precision = config.mixed_precision
150 | 
151 | def step(self, loss_value, step_optimizer=True):
152 | self._scaler.scale(loss_value / self.grad_accum_steps).backward()
153 | 
154 | grad_norm = None
155 | if step_optimizer:
156 | self._scaler.unscale_(self.optimizer)
157 | 
158 | if self.grad_clip is not None:
159 | parameters = self.optimizer.param_groups[0]["params"]
160 | grad_norm = torch.nn.utils.clip_grad_norm_(parameters, self.grad_clip)
161 | if torch.isnan(grad_norm) or torch.isinf(grad_norm):
162 | grad_norm = None
163 | 
164 | self._scaler.step(self.optimizer)
165 | self._scaler.update()
166 | 
167 | self.optimizer.zero_grad()
168 | 
169 | return grad_norm
170 | 
171 | def anneal_on_epoch_end(self, *args):
172 | if self.lr_scheduler is not None:
173 | args = args if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) else []
174 | self.lr_scheduler.step(*args)
175 | 
176 | def anneal_on_iter_end(self, *args):
177 | if self.lr_scheduler is not None:
178 | self.lr_scheduler.step()
179 | 
180 | def get_last_lr(self):
181 | return self.lr_scheduler.get_last_lr()
182 | 
183 | def load_state_dict(self, state_dict, restore_lr=True):
184 | 
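"""Load the optimizer (and, if present, the LR scheduler) state from a checkpoint
dict; with `restore_lr=False`, the current scheduler LR is reapplied to the loaded
optimizer param groups instead of the checkpointed one."""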
self.optimizer.load_state_dict(state_dict["optimizer"]) 185 | 186 | if restore_lr: 187 | lr_scheduler_params = state_dict.get("lr_scheduler", None) 188 | if lr_scheduler_params is not None: 189 | self.lr_scheduler.load_state_dict(lr_scheduler_params) 190 | else: 191 | base_lr = self.lr_scheduler.get_last_lr() 192 | for lr, param_group in zip(base_lr, self.optimizer.param_groups): 193 | param_group["lr"] = lr 194 | 195 | def state_dict(self): 196 | return { 197 | "optimizer": self.optimizer.state_dict(), 198 | "lr_scheduler": self.lr_scheduler.state_dict() 199 | } 200 | 201 | def set_progress(self, iteration, epoch): 202 | for key, value in self.optimizer.state.items(): 203 | if "step" in value: 204 | self.optimizer.state[key]["step"] = iteration 205 | 206 | self.lr_scheduler.last_epoch = epoch 207 | self.lr_scheduler._step_count = epoch + 1 208 | 209 | def __repr__(self): 210 | return f"Optimizer(optimizer={self.optimizer}, lr_scheduler={self.lr_scheduler})" 211 | -------------------------------------------------------------------------------- /scoreperformer/experiments/trainer_utils.py: -------------------------------------------------------------------------------- 1 | """ Trainer utility functions. """ 2 | 3 | from pathlib import Path 4 | from typing import Union, List 5 | 6 | from omegaconf import ListConfig 7 | 8 | from scoreperformer.utils import ExplicitEnum 9 | 10 | 11 | def resolve_path(constructor: Union[Path, str, List[str]]): 12 | return Path(*constructor) if isinstance(constructor, (List, ListConfig)) else Path(constructor) 13 | 14 | 15 | class Accumulator: 16 | def __init__(self): 17 | self._sums = {} 18 | self._counts = {} 19 | 20 | def __getitem__(self, key): 21 | return self._sums[key] / self._counts[key] 22 | 23 | @property 24 | def sums(self): 25 | return self._sums 26 | 27 | @property 28 | def counts(self): 29 | return self._counts 30 | 31 | @property 32 | def mean_values(self): 33 | return {key: self._sums[key] / self._counts[key] for key in self._sums if self._counts[key] > 0} 34 | 35 | def items(self): 36 | return self.mean_values.items() 37 | 38 | def add_value(self, name, value): 39 | self._sums[name] = value 40 | self._counts[name] = 1 41 | 42 | def update_value(self, name, value): 43 | if name not in self._sums: 44 | self.add_value(name, value) 45 | else: 46 | self._sums[name] += value 47 | self._counts[name] += 1 48 | 49 | def add_values(self, name_dict): 50 | for key, value in name_dict.items(): 51 | self.add_value(key, value) 52 | 53 | def update_values(self, value_dict): 54 | for key, value in value_dict.items(): 55 | self.update_value(key, value) 56 | 57 | def reset(self): 58 | for key in self._sums: 59 | self._sums[key] = 0 60 | self._counts[key] = 0 61 | 62 | def clear(self): 63 | self._sums = {} 64 | self._counts = {} 65 | 66 | 67 | class IntervalStrategy(ExplicitEnum): 68 | NO = "no" 69 | STEPS = "steps" 70 | EPOCH = "epoch" 71 | -------------------------------------------------------------------------------- /scoreperformer/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .generators import ScorePerformerGenerator 2 | from .messengers import SPMupleMessenger 3 | -------------------------------------------------------------------------------- /scoreperformer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Model 2 | from .scoreperformer import ( 3 | Performer, 4 | ScorePerformer, 5 | ScorePerformerEvaluator 6 | ) 7 | 
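# Name -> class registries gathered from the imports above: model classes are defined
# in `*.model` modules and evaluators in `*.evaluator` modules, so the string check on
# each class's module path filters them out of `globals()`.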
8 | MODELS = {name: cls for name, cls in globals().items() if ".model." in str(cls)}
9 | EVALUATORS = {name: cls for name, cls in globals().items() if ".evaluator." in str(cls)}
10 | 
--------------------------------------------------------------------------------
/scoreperformer/models/base.py:
--------------------------------------------------------------------------------
1 | """ Base Model class. """
2 | 
3 | from abc import abstractmethod
4 | from typing import Optional, List, Dict, Union
5 | 
6 | import torch
7 | import torch.nn as nn
8 | from loguru import logger
9 | from omegaconf import DictConfig, OmegaConf
10 | from torch import Tensor
11 | from torch.utils.data import Dataset
12 | 
13 | from scoreperformer.modules.constructor import Constructor, ModuleConfig
14 | 
15 | 
16 | class Model(nn.Module, Constructor):
17 | @abstractmethod
18 | def forward(self, *args, **kwargs):
19 | ...
20 | 
21 | @abstractmethod
22 | def prepare_inputs(self, inputs) -> Dict[str, Tensor]:
23 | ...
24 | 
25 | @staticmethod
26 | def allocate_inputs(inputs_dict, device):
27 | return {key: value.to(device, non_blocking=True) for key, value in inputs_dict.items()}
28 | 
29 | @staticmethod
30 | def inject_data_config(
31 | config: Optional[Union[DictConfig, ModuleConfig]],
32 | dataset: Optional[Dataset]
33 | ) -> Optional[Union[DictConfig, ModuleConfig]]:
34 | return config
35 | 
36 | @staticmethod
37 | def cleanup_config(
38 | config: Optional[Union[DictConfig, ModuleConfig]]
39 | ) -> Optional[Union[DictConfig, ModuleConfig]]:
40 | return config
41 | 
42 | @classmethod
43 | def from_pretrained(cls, checkpoint_path: str):
44 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
45 | 
46 | model_cfg = OmegaConf.create(checkpoint['model']['config'])
47 | model = cls.init(model_cfg)
48 | 
49 | state_dict = checkpoint['model']['state_dict']
50 | model.load_state_dict(state_dict, strict=True)
51 | 
52 | return model
53 | 
54 | def load(
55 | self,
56 | state_dict: Dict[str, Tensor],
57 | ignore_layers: Optional[List] = None,
58 | ignore_mismatched_keys: bool = False
59 | ):
60 | ignore_layers = ignore_layers or []
61 | 
62 | model_state = self.state_dict()
63 | 
64 | extra_keys = [k for k in state_dict.keys() if k not in model_state]
65 | if extra_keys:
66 | logger.warning(f"The following checkpoint keys are not present in the model "
67 | f"and will be ignored: {extra_keys}")
68 | state_dict = {k: v for k, v in state_dict.items() if k not in extra_keys}
69 | 
70 | ignored_keys = []
71 | if ignore_mismatched_keys:
72 | auto_ignore_layers = []
73 | for k, v in state_dict.items():
74 | if v.data.shape != model_state[k].data.shape:
75 | auto_ignore_layers.append(k)
76 | logger.info(f"Automatically found checkpoint keys "
77 | f"incompatible with the model: {auto_ignore_layers}")
78 | ignored_keys.extend(auto_ignore_layers)
79 | 
80 | if ignore_layers:
81 | for k, v in state_dict.items():
82 | if any(layer in k for layer in ignore_layers):
83 | ignored_keys.append(k)
84 | 
85 | if ignored_keys:
86 | state_dict = {k: v for k, v in state_dict.items()
87 | if k not in ignored_keys}
88 | logger.info(f"The following checkpoint keys were ignored: {ignored_keys}")
89 | 
90 | model_state.update(state_dict)
91 | self.load_state_dict(model_state)
92 | 
93 | return self
94 | 
95 | def freeze(self, exception_list=None):
96 | not_frozen = []
97 | exception_list = exception_list or []
98 | for name, param in self.named_parameters():
99 | param.requires_grad = any((name.startswith(layer) for layer in exception_list))
100 | if param.requires_grad:
101 | not_frozen.append(name)
102 | logger.info(f"The model graph has been frozen, except for the following parameters: {not_frozen}")
103 | 
--------------------------------------------------------------------------------
/scoreperformer/models/classifiers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/scoreperformer/models/classifiers/__init__.py
--------------------------------------------------------------------------------
/scoreperformer/models/classifiers/evaluator.py:
--------------------------------------------------------------------------------
1 | """ Embedding Classifier evaluator. """
2 | 
3 | import torch
4 | 
5 | 
6 | class EmbeddingClassifierEvaluator:
7 | def __init__(self, model):
8 | self.model = model
9 | 
10 | def _accuracy(self, predictions, labels):
11 | return (predictions == labels).float().mean()
12 | 
13 | @torch.no_grad()
14 | def __call__(self, inputs, outputs):
15 | labels = inputs["labels"]
16 | predictions = torch.argmax(outputs.logits, dim=-1)
17 | metrics = {"accuracy": self._accuracy(predictions, labels)}
18 | 
19 | return metrics
20 | 
--------------------------------------------------------------------------------
/scoreperformer/models/classifiers/model.py:
--------------------------------------------------------------------------------
1 | """ Embedding Classifier models. """
2 | 
3 | from dataclasses import dataclass, field
4 | from typing import Optional, Dict, Sequence, Union, List
5 | 
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from omegaconf import MISSING
11 | from torch import Tensor
12 | 
13 | from scoreperformer.models.base import Model
14 | from scoreperformer.modules.constructor import Registry, VariableModuleConfig
15 | 
16 | 
17 | @dataclass
18 | class EmbeddingClassifierOutput:
19 | logits: Tensor = None
20 | loss: Optional[Tensor] = None
21 | losses: Optional[Dict[str, Tensor]] = None
22 | 
23 | 
24 | EmbeddingClassifiersRegistry = type("_EmbeddingClassifiersRegistry", (Registry,), {})()
25 | 
26 | 
27 | @dataclass
28 | class EmbeddingClassifierConfig(VariableModuleConfig):
29 | input_dim: int = MISSING
30 | num_classes: int = MISSING
31 | dropout: float = 0.
32 | weight: Optional[List[float]] = None
33 | 
34 | 
35 | @dataclass
36 | class LinearEmbeddingClassifierConfig(EmbeddingClassifierConfig):
37 | _target_: str = "linear"
38 | hidden_dims: Optional[Sequence[int]] = field(default_factory=lambda: (32,))
39 | 
40 | 
41 | @EmbeddingClassifiersRegistry.register("linear")
42 | class LinearEmbeddingClassifier(Model):
43 | def __init__(
44 | self,
45 | input_dim: int,
46 | num_classes: int,
47 | hidden_dims: Optional[Sequence[int]] = (32,),
48 | dropout: float = 0.,
49 | class_weights: Optional[List[float]] = None
50 | ):
51 | super().__init__()
52 | 
53 | self.num_classes = num_classes
54 | 
55 | class_weights = torch.ones(num_classes) if class_weights is None else torch.tensor(class_weights)
56 | self.register_buffer("class_weights", class_weights.float())
57 | 
58 | hidden_dims = hidden_dims or []
59 | hidden_dims = [hidden_dims] if isinstance(hidden_dims, int) else hidden_dims
60 | hidden_dims = list(hidden_dims)
61 | 
62 | in_dims = [input_dim] + hidden_dims
63 | out_dims = hidden_dims + [num_classes]
64 | 
65 | layers = []
66 | for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
67 | layers.append(nn.Linear(in_dim, out_dim))
68 | if i < len(in_dims) - 1:
69 | layers.append(nn.ReLU())
70 | 
71 | self.layers = nn.Sequential(*layers)
72 | self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
73 | 
74 | def forward(self, embeddings: Tensor, labels: Optional[Tensor] = None):
75 | x = embeddings.squeeze(1) if embeddings.ndim == 3 else embeddings
76 | for layer in self.layers:
77 | x = layer(self.dropout(x))
78 | logits = x
79 | 
80 | loss = losses = None
81 | if labels is not None:
82 | loss = F.cross_entropy(logits, labels, weight=self.class_weights)
83 | 
84 | return EmbeddingClassifierOutput(
85 | logits=logits,
86 | loss=loss,
87 | losses=losses
88 | )
89 | 
90 | def prepare_inputs(self, inputs) -> Dict[str, Tensor]:
91 | return inputs
92 | 
93 | 
94 | @dataclass
95 | class SequentialEmbeddingClassifierConfig(EmbeddingClassifierConfig):
96 | _target_: str = "sequential"
97 | hidden_dim: int = 32
98 | 
99 | 
100 | @EmbeddingClassifiersRegistry.register("sequential")
101 | class SequentialEmbeddingClassifier(Model):
102 | def __init__(
103 | self,
104 | input_dim: int,
105 | num_classes: int,
106 | hidden_dim: int = 32,
107 | dropout: float = 0.,
108 | class_weights: Optional[List[float]] = None
109 | ):
110 | super().__init__()
111 | 
112 | self.num_classes = num_classes
113 | 
114 | class_weights = torch.ones(num_classes) if class_weights is None else torch.tensor(class_weights)
115 | self.register_buffer("class_weights", class_weights.float())
116 | 
117 | self.gru = nn.GRU(
118 | input_size=input_dim,
119 | hidden_size=hidden_dim,
120 | batch_first=True,
121 | dropout=dropout
122 | )
123 | 
124 | self.output = nn.Linear(hidden_dim, num_classes)
125 | 
126 | def forward(self, embeddings: Tensor, labels: Optional[Tensor] = None):
127 | self.gru.flatten_parameters()
128 | _, out = self.gru(embeddings)  # (1, b, h)
129 | logits = self.output(out[0])
130 | 
131 | loss = losses = None
132 | if labels is not None:
133 | loss = F.cross_entropy(logits, labels, weight=self.class_weights)
134 | 
135 | return EmbeddingClassifierOutput(
136 | logits=logits,
137 | loss=loss,
138 | losses=losses
139 | )
140 | 
141 | def prepare_inputs(self, inputs) -> Dict[str, Tensor]:
142 | return inputs
143 | 
144 | 
145 | @dataclass
146 | class MultiHeadEmbeddingClassifierOutput:
147 | logits: Dict[str, Tensor] = None
148 | loss: Optional[Tensor] = None
149 | losses: Optional[Dict[str, Tensor]] = None
150 | 
151 | 
152 | @dataclass
153 | class MultiHeadEmbeddingClassifierConfig(VariableModuleConfig):
154 | _target_: str = "multi-head"
155 | input_dim: int = MISSING
156 | num_classes: Dict[str, int] = MISSING
157 | classifier: LinearEmbeddingClassifierConfig = MISSING
158 | class_samples: Optional[Dict[str, List[int]]] = None
159 | weighted_classes: bool = False
160 | loss_weight: float = 1.
161 | detach_inputs: Union[bool, float] = False
162 | 
163 | 
164 | @EmbeddingClassifiersRegistry.register("multi-head")
165 | class MultiHeadEmbeddingClassifier(Model):
166 | def __init__(
167 | self,
168 | input_dim: int,
169 | num_classes: Dict[str, int],
170 | classifier: LinearEmbeddingClassifierConfig,
171 | class_samples: Optional[Dict[str, List[int]]] = None,
172 | loss_weight: float = 1.,
173 | weighted_classes: bool = False,
174 | detach_inputs: Union[bool, float] = False
175 | ):
176 | super().__init__()
177 | 
178 | self.num_classes = num_classes
179 | 
180 | self.heads = nn.ModuleDict({})
181 | for key, num in num_classes.items():
182 | num_samples = class_samples.get(key, None) if class_samples is not None else None
183 | class_weights = self._class_weights(num_samples) if weighted_classes and num_samples is not None else None
184 | self.heads[key] = LinearEmbeddingClassifier.init(
185 | config=classifier,
186 | input_dim=input_dim,
187 | num_classes=num,
188 | class_weights=class_weights
189 | )
190 | 
191 | self.loss_weight = loss_weight
192 | self.detach_inputs = float(detach_inputs)
193 | 
194 | @staticmethod
195 | def _class_weights(num_samples: List[int], beta: float = 0.999, mult: float = 1e4):
196 | num_samples = np.maximum(num_samples, 1e-6)
197 | effective_num = 1.0 - np.power(beta, np.array(num_samples) * mult)
198 | weights = (1.0 - beta) / np.array(effective_num)
199 | weights = weights / np.sum(weights) * len(num_samples)
200 | return weights.tolist()
201 | 
202 | def forward(self, embeddings: Tensor, labels: Optional[Tensor] = None):
203 | embeddings = self.detach_inputs * embeddings.detach() + (1 - self.detach_inputs) * embeddings
204 | 
205 | logits = {}
206 | loss, losses = 0., {}
207 | for i, (key, head) in enumerate(self.heads.items()):
208 | out = head(embeddings, labels=labels[..., i] if labels is not None else None)
209 | logits[key] = out.logits
210 | 
211 | if out.loss is not None:
212 | loss_key = 'clf/' + key
213 | loss += out.loss
214 | losses[loss_key] = out.loss
215 | 
216 | loss = self.loss_weight * loss / len(self.heads)
217 | losses['clf'] = loss
218 | 
219 | return MultiHeadEmbeddingClassifierOutput(
220 | logits=logits,
221 | loss=loss if labels is not None else None,
222 | losses=losses if labels is not None else None
223 | )
224 | 
225 | def prepare_inputs(self, inputs) -> Dict[str, Tensor]:
226 | return inputs
227 | 
--------------------------------------------------------------------------------
/scoreperformer/models/scoreperformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .embeddings import TupleTokenEmbeddings, TupleTokenLMHead, TupleTokenTiedLMHead
2 | from .evaluator import ScorePerformerEvaluator
3 | from .model import (
4 | PerformerConfig,
5 | Performer,
6 | ScorePerformerConfig,
7 | ScorePerformer
8 | )
9 | from .transformer import (
10 | TupleTransformerConfig,
11 | TupleTransformer,
12 | TupleTransformerCaches
13 | )
14 | from .wrappers import ScorePerformerMLMWrapper
15 | 
--------------------------------------------------------------------------------
/scoreperformer/models/scoreperformer/evaluator.py: -------------------------------------------------------------------------------- 1 | """ ScorePerformer metric evaluator. """ 2 | 3 | from typing import Optional, Union, List 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from scoreperformer.data.collators import LMScorePerformanceInputs 9 | from scoreperformer.data.tokenizers import OctupleM 10 | from .model import ScorePerformerOutputs 11 | from .transformer import TupleTransformerOutput 12 | from .wrappers import ScorePerformerLMModes 13 | 14 | 15 | class ScorePerformerEvaluator: 16 | def __init__( 17 | self, 18 | model, 19 | tokenizer: Optional[OctupleM] = None, 20 | label_pad_token_id: int = -100, 21 | weighted_distance: bool = False, 22 | ignore_keys: Optional[List[str]] = None 23 | ): 24 | self.model = model 25 | self.tokenizer = tokenizer 26 | self.label_pad_token_id = label_pad_token_id 27 | self.weighted_distance = weighted_distance 28 | self.ignore_keys = ignore_keys 29 | 30 | self.token_values = None 31 | if self.tokenizer is not None: 32 | self.token_values = { 33 | key: torch.from_numpy(values)[:, None] 34 | for key, values in self.tokenizer.token_values(normalize=False).items() 35 | } 36 | 37 | def _accuracy(self, predictions, labels): 38 | label_mask = labels != self.label_pad_token_id 39 | return (predictions[label_mask] == labels[label_mask]).float().mean() 40 | 41 | def _distance(self, predictions, targets): 42 | return (predictions - targets).abs().float().mean() 43 | 44 | def _weighted_distance(self, probs, targets, token_values): 45 | return ((targets[:, None] - token_values[None, :]).abs() * probs[..., None]).sum(dim=1).mean() 46 | 47 | @torch.no_grad() 48 | def __call__( 49 | self, 50 | inputs: Union[dict, LMScorePerformanceInputs], 51 | outputs: Union[TupleTransformerOutput, ScorePerformerOutputs], 52 | ignore_keys: Optional[List[str]] = None 53 | ): 54 | metrics = {} 55 | ignore_keys = ignore_keys or self.ignore_keys 56 | 57 | if isinstance(inputs, LMScorePerformanceInputs): 58 | labels = inputs.labels.tokens.to(outputs.hidden_state.device) 59 | else: 60 | labels = inputs["labels"] 61 | 62 | if self.model.mode in (ScorePerformerLMModes.CLM, ScorePerformerLMModes.MixedLM): 63 | labels = labels[:, 1:] 64 | 65 | if isinstance(outputs, ScorePerformerOutputs): 66 | outputs = outputs.perf_decoder 67 | 68 | predictions = torch.cat( 69 | list(map(lambda l: torch.argmax(l, dim=-1, keepdim=True), outputs.logits.values())), 70 | dim=-1 71 | ) 72 | 73 | metrics[f"accuracy"] = self._accuracy(predictions, labels) 74 | if ignore_keys: 75 | use_ids = torch.tensor( 76 | [i for i, key in enumerate(outputs.logits.keys()) if key not in ignore_keys], 77 | device=predictions.device, dtype=torch.long 78 | ) 79 | metrics[f"accuracy/pred"] = self._accuracy(predictions[..., use_ids], labels[..., use_ids]) 80 | 81 | for i, (key, logits) in enumerate(outputs.logits.items()): 82 | if ignore_keys and key in ignore_keys: 83 | continue 84 | 85 | if torch.any(labels[..., i] != self.label_pad_token_id): 86 | metrics[f"accuracy/{key}"] = self._accuracy(predictions[..., i], labels[..., i]) 87 | 88 | if self.token_values is not None: 89 | for i, (key, logits) in enumerate(outputs.logits.items()): 90 | if ignore_keys and key in ignore_keys: 91 | continue 92 | 93 | self.token_values[key] = self.token_values[key].to(predictions.device) 94 | 95 | label_mask = labels[..., i] != self.label_pad_token_id 96 | if torch.any(label_mask): 97 | preds = F.embedding(predictions[..., 
i][label_mask], self.token_values[key]) 98 | targets = F.embedding(labels[..., i][label_mask], self.token_values[key]) 99 | 100 | if self.weighted_distance: 101 | probs = outputs.logits[key].softmax(dim=-1)[label_mask] 102 | metrics[f"distance/{key}"] = self._weighted_distance(probs, targets, self.token_values[key]) 103 | else: 104 | metrics[f"distance/{key}"] = self._distance(preds, targets) 105 | 106 | return metrics 107 | -------------------------------------------------------------------------------- /scoreperformer/models/scoreperformer/transformer.py: -------------------------------------------------------------------------------- 1 | """ TupleTransformer: Transformer with support for tuple token sequences. """ 2 | 3 | from dataclasses import dataclass, MISSING, field 4 | from typing import Optional, Union, Dict, List 5 | 6 | import torch 7 | import torch.nn as nn 8 | from omegaconf import DictConfig 9 | from torch import Tensor 10 | 11 | from scoreperformer.modules.constructor import Constructor, ModuleConfig 12 | from scoreperformer.modules.transformer import ( 13 | TransformerConfig, TransformerRegistry, TransformerIntermediates, 14 | AbsolutePositionalEmbedding 15 | ) 16 | from scoreperformer.utils import ExplicitEnum 17 | from .embeddings import ( 18 | TupleTokenEmbeddingsConfig, TupleTokenEmbeddingsRegistry, 19 | TupleTokenHeadsConfig, TupleTokenHeadsRegistry, 20 | TupleTokenRegressionHeadConfig, TupleTokenRegressionHead 21 | ) 22 | 23 | 24 | class EmbeddingModes(ExplicitEnum): 25 | SUM = "mean" 26 | CONCAT = "cat" 27 | ATTENTION = "attention" 28 | ADANORM = "adanorm" 29 | 30 | 31 | @dataclass 32 | class TupleTransformerCaches: 33 | token_emb: Optional[Tensor] = None 34 | transformer: Optional[TransformerIntermediates] = None 35 | 36 | 37 | @dataclass 38 | class TupleTransformerOutput: 39 | hidden_state: Tensor 40 | logits: Optional[Dict[str, Tensor]] = None 41 | attentions: Optional[List[Tensor]] = None 42 | caches: Optional[TupleTransformerCaches] = None 43 | reg_values: Optional[Dict[str, Tensor]] = None 44 | 45 | 46 | @dataclass 47 | class TupleTransformerConfig(ModuleConfig): 48 | num_tokens: Dict[str, int] = MISSING 49 | dim: int = 512 50 | max_seq_len: int = 1024 51 | transformer: Union[DictConfig, TransformerConfig] = field( 52 | default_factory=lambda: TransformerConfig(_target_="default")) 53 | 54 | token_embeddings: Union[DictConfig, TupleTokenEmbeddingsConfig] = field( 55 | default_factory=TupleTokenEmbeddingsConfig) 56 | use_abs_pos_emb: bool = True 57 | emb_norm: bool = False 58 | emb_dropout: float = 0.0 59 | 60 | context_emb_dim: Optional[int] = None 61 | context_emb_mode: str = EmbeddingModes.ATTENTION 62 | style_emb_dim: Optional[int] = None 63 | style_emb_mode: str = EmbeddingModes.CONCAT 64 | 65 | lm_head: Optional[Union[DictConfig, TupleTokenHeadsConfig]] = None 66 | regression_head: Optional[Union[DictConfig, TupleTokenRegressionHeadConfig]] = None 67 | 68 | 69 | class TupleTransformer(nn.Module, Constructor): 70 | def __init__( 71 | self, 72 | num_tokens: Dict[str, int], 73 | dim: int = 512, 74 | max_seq_len: int = 1024, 75 | transformer: Union[DictConfig, TransformerConfig] = TransformerConfig(_target_="default"), 76 | token_embeddings: Union[DictConfig, TupleTokenEmbeddingsConfig] = TupleTokenEmbeddingsConfig(), 77 | use_abs_pos_emb: bool = True, 78 | emb_norm: bool = False, 79 | emb_dropout: float = 0.0, 80 | context_emb_dim: Optional[int] = None, 81 | context_emb_mode: str = EmbeddingModes.ATTENTION, 82 | style_emb_dim: Optional[int] = None, 83 | 
style_emb_mode: str = EmbeddingModes.CONCAT, 84 | lm_head: Optional[Union[DictConfig, TupleTokenHeadsConfig]] = None, 85 | regression_head: Optional[Union[DictConfig, TupleTokenRegressionHeadConfig]] = None 86 | ): 87 | super().__init__() 88 | 89 | self.dim = dim 90 | self.max_seq_len = max_seq_len 91 | emb_dim = dim # default(emb_dim, dim) 92 | 93 | self.context_emb_dim = context_emb_dim or 0 94 | self.context_emb_mode = context_emb_mode 95 | 96 | self.style_emb_dim = style_emb_dim or 0 97 | self.style_emb_mode = style_emb_mode 98 | 99 | self.token_emb = TupleTokenEmbeddingsRegistry.instantiate( 100 | config=token_embeddings, 101 | num_tokens=num_tokens, 102 | emb_dims=token_embeddings.get("emb_dims", emb_dim), 103 | project_emb_dim=emb_dim 104 | ) 105 | 106 | if self.context_emb_mode != EmbeddingModes.ATTENTION: 107 | transformer.cross_attend = False 108 | 109 | self.transformer = TransformerRegistry.instantiate( 110 | transformer, 111 | dim=dim, 112 | use_adanorm=self.style_emb_mode == EmbeddingModes.ADANORM, 113 | style_emb_dim=self.style_emb_dim 114 | ) 115 | 116 | self.pos_emb = None 117 | if use_abs_pos_emb: 118 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, self.max_seq_len) 119 | nn.init.kaiming_normal_(self.pos_emb.emb.weight) 120 | 121 | self.emb_norm = nn.LayerNorm(emb_dim) if emb_norm else nn.Identity() 122 | self.emb_dropout = nn.Dropout(emb_dropout) if emb_dropout > 0. else nn.Identity() 123 | 124 | self.project_emb = nn.Identity() 125 | total_emb_dim = ( 126 | emb_dim 127 | + int(context_emb_mode == EmbeddingModes.CONCAT) * self.context_emb_dim 128 | + int(style_emb_mode == EmbeddingModes.CONCAT) * self.style_emb_dim 129 | ) 130 | if total_emb_dim != dim: 131 | self.project_emb = nn.Linear(total_emb_dim, dim) 132 | 133 | self.lm_head = None 134 | if lm_head is not None: 135 | self.lm_head = TupleTokenHeadsRegistry.instantiate( 136 | config=lm_head, dim=dim, embeddings=self.token_emb 137 | ) 138 | 139 | self.regression_head = None 140 | if regression_head is not None: 141 | assert self.token_emb.continuous, "TupleTokenRegressionHead depends on `continuous` token embeddings." 
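# optional regression head over continuous token values; requires the token
# embeddings to be continuous (enforced by the assert above)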
142 | self.regression_head = TupleTokenRegressionHead.init( 143 | config=regression_head, dim=dim 144 | ) 145 | 146 | def forward( 147 | self, 148 | x: Tensor, 149 | mask: Optional[Tensor] = None, 150 | x_extra: Optional[Union[Tensor, List[Tensor]]] = None, 151 | style_embeddings: Optional[Tensor] = None, 152 | context: Optional[Tensor] = None, 153 | context_mask: Optional[Tensor] = None, 154 | caches: Optional[TupleTransformerCaches] = None, 155 | logits_keys: Optional[List] = None, 156 | return_embeddings: bool = False, 157 | return_attn: bool = False, 158 | return_caches: bool = False, 159 | **kwargs 160 | ): 161 | token_emb_cache = caches.token_emb if caches is not None else None 162 | if hasattr(self.token_emb, "multiseq_mode") and x_extra is not None: 163 | x_extra = [x_extra] if isinstance(x_extra, Tensor) else x_extra 164 | token_emb = self.token_emb([x] + x_extra, cache=token_emb_cache) 165 | else: 166 | token_emb = self.token_emb(x, cache=token_emb_cache) 167 | 168 | x = token_emb 169 | if self.pos_emb is not None: 170 | x = x + self.pos_emb(x) 171 | x = self.emb_norm(x) 172 | 173 | if context is not None and self.context_emb_mode == EmbeddingModes.CONCAT: 174 | context = context[:, :x.shape[1]] 175 | x = torch.cat([x, context], dim=-1) 176 | context = None 177 | 178 | if style_embeddings is not None: 179 | style_embeddings = style_embeddings[:, :x.shape[1]] 180 | if self.style_emb_mode == EmbeddingModes.CONCAT: 181 | x = torch.cat([x, style_embeddings], dim=-1) 182 | style_embeddings = None 183 | 184 | x = self.emb_dropout(x) 185 | x = self.project_emb(x) 186 | 187 | out, intermediates = self.transformer( 188 | x, 189 | mask=mask, 190 | context=context, 191 | context_mask=context_mask, 192 | style_embeddings=style_embeddings, 193 | intermediates_cache=caches.transformer if caches is not None else None, 194 | return_hiddens=True 195 | ) 196 | 197 | logits = None 198 | if not return_embeddings and self.lm_head is not None: 199 | logits = self.lm_head(out, keys=logits_keys) 200 | 201 | reg_values = None 202 | if not return_embeddings and self.regression_head is not None: 203 | reg_values = self.regression_head(out, keys=logits_keys) 204 | 205 | attn_maps = None 206 | if return_attn: 207 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attention)) 208 | 209 | caches = None 210 | if return_caches: 211 | caches = TupleTransformerCaches( 212 | token_emb=token_emb, 213 | transformer=intermediates 214 | ) 215 | 216 | return TupleTransformerOutput( 217 | hidden_state=out, 218 | logits=logits, 219 | attentions=attn_maps, 220 | caches=caches, 221 | reg_values=reg_values 222 | ) 223 | -------------------------------------------------------------------------------- /scoreperformer/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/scoreperformer/modules/__init__.py -------------------------------------------------------------------------------- /scoreperformer/modules/constructor.py: -------------------------------------------------------------------------------- 1 | """ Config-based Module Constructor. 
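
`ModuleConfig` dataclasses describe constructor arguments, `Constructor.init` merges a
config with keyword overrides (dropping keys the constructor does not accept), and
`Registry` maps string names to constructible modules for `_target_`-based dispatch.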
""" 2 | import copy 3 | from dataclasses import dataclass 4 | from inspect import signature 5 | from typing import Optional, Union, Callable 6 | 7 | import torch 8 | from loguru import logger 9 | from omegaconf import DictConfig, OmegaConf, MISSING 10 | 11 | 12 | @dataclass 13 | class ModuleConfig: 14 | def update(self, **kwargs): 15 | kwargs = {key: kwargs.get(key, MISSING) for key in kwargs if not key.startswith("_")} # get rid of service keys 16 | invalid_keys = [key for key in kwargs if key not in self.__dict__] 17 | if invalid_keys: 18 | logger.warning(f"The following params are incompatible with the config {self.__name__}, " 19 | f"so they will be ignored: {invalid_keys}.") 20 | kwargs = {key: value for key, value in kwargs.items() if key not in invalid_keys} 21 | 22 | for key, value in kwargs.items(): 23 | self.__setattr__(key, value) 24 | 25 | return self 26 | 27 | def to_dict(self, check_missing=False, make_copy=True): 28 | if check_missing: 29 | missing = [key for key, value in self.__dict__.items() if value == MISSING] 30 | if missing: 31 | raise RuntimeError(f"The following params are mandatory to set: {missing}") 32 | 33 | return copy.deepcopy(self.__dict__) if make_copy else {k: v for k, v in self.__dict__.items()} 34 | 35 | 36 | class Constructor: 37 | @classmethod 38 | def _pre_init(cls, config: Optional[Union[DictConfig, ModuleConfig]] = None, **parameters): 39 | module_parameters = {key: value for key, value in parameters.items() if isinstance(value, torch.nn.Module)} 40 | parameters = {key: value for key, value in parameters.items() if key not in module_parameters} 41 | config = merge(config or {}, parameters) 42 | for key, value in module_parameters.items(): 43 | config[key] = value 44 | config = {key: value for key, value in config.items() if not key.startswith("_")} 45 | 46 | return config 47 | 48 | @classmethod 49 | def init(cls, config: Optional[Union[DictConfig, ModuleConfig]] = None, **parameters): 50 | config = cls._pre_init(config, **parameters) 51 | 52 | signature_ = dict(signature(cls.__init__).parameters) 53 | 54 | if "kwargs" not in signature_: 55 | invalid_keys = [key for key in config if key not in signature_] 56 | if invalid_keys: 57 | logger.warning(f"The following params are incompatible with the {cls.__name__} constructor, " 58 | f"so they will be ignored: {invalid_keys}.") 59 | config = {key: value for key, value in config.items() if key not in invalid_keys} 60 | 61 | missing = [key for key, value in config.items() if value == MISSING] 62 | if missing: 63 | raise RuntimeError(f"The following params are mandatory to set: {missing}") 64 | 65 | return cls(**config) 66 | 67 | 68 | def merge(*containers: Union[DictConfig, ModuleConfig, dict], as_omega=False) -> Union[dict, DictConfig]: 69 | readonly = False 70 | _containers = [] 71 | for cont in containers: 72 | if isinstance(cont, ModuleConfig): 73 | cont = cont.to_dict(make_copy=False) 74 | elif isinstance(cont, DictConfig): 75 | readonly = cont._get_flag("readonly") 76 | elif not isinstance(cont, dict): 77 | raise TypeError 78 | _containers.append(cont) 79 | 80 | merged = OmegaConf.merge(*_containers) 81 | OmegaConf.set_readonly(merged, readonly) 82 | 83 | if not as_omega: 84 | merged = dict(merged) 85 | 86 | return merged 87 | 88 | 89 | @dataclass 90 | class VariableModuleConfig(ModuleConfig): 91 | _target_: str 92 | 93 | 94 | class Registry: 95 | """ 96 | Parent class for different registries. 
97 | """ 98 | 99 | def __init__(self): 100 | self._objects = {} 101 | 102 | def register( 103 | self, 104 | name: str, 105 | module: Optional[Callable] = None 106 | ): 107 | if not isinstance(name, str): 108 | raise TypeError(f"`name` must be a str, got {name}") 109 | 110 | def _register(reg_obj: Callable): 111 | self._objects[name] = reg_obj 112 | return reg_obj 113 | 114 | return _register if module is None else _register(module) 115 | 116 | def instantiate(self, config: Union[VariableModuleConfig, DictConfig], **kwargs): 117 | module: Constructor = self.get(config._target_) 118 | return module.init(config, **kwargs) 119 | 120 | def get(self, key: str): 121 | try: 122 | return self._objects[key] 123 | except KeyError: 124 | raise KeyError(f"'{key}' not found in registry. Available names: {self.available_names}") 125 | 126 | def remove(self, name): 127 | self._objects.pop(name) 128 | 129 | @property 130 | def objects(self): 131 | return self._objects 132 | 133 | @property 134 | def available_names(self): 135 | return tuple(self._objects.keys()) 136 | 137 | def __str__(self) -> str: 138 | return f"Objects={self.available_names}" 139 | -------------------------------------------------------------------------------- /scoreperformer/modules/layers.py: -------------------------------------------------------------------------------- 1 | """ PyTorch modules used by the models.""" 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | from torch import nn as nn, Tensor 7 | 8 | from scoreperformer.utils import exists 9 | 10 | 11 | # residual 12 | 13 | class Residual(nn.Module): 14 | def __init__(self, dim: int, scale_residual: bool = False, scale_residual_constant: float = 1.): 15 | super().__init__() 16 | self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None 17 | self.scale_residual_constant = scale_residual_constant 18 | 19 | def forward(self, x, residual): 20 | if exists(self.residual_scale): 21 | residual = residual * self.residual_scale 22 | 23 | if self.scale_residual_constant != 1: 24 | residual = residual * self.scale_residual_constant 25 | 26 | return x + residual 27 | 28 | 29 | # adaptive layer normalization 30 | 31 | class AdaptiveLayerNorm(nn.Module): 32 | def __init__(self, dim: int, condition_dim: int, eps: float = 1e-5): 33 | super().__init__() 34 | self.dim = dim 35 | self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) 36 | 37 | self.linear = nn.Linear(condition_dim, dim * 2) 38 | self.linear.bias.data[:dim] = 1 39 | self.linear.bias.data[dim:] = 0 40 | 41 | def forward(self, x: Tensor, condition: Optional[Tensor] = None): 42 | if condition is None: 43 | gamma, beta = x.new_ones(1), x.new_zeros(1) 44 | else: 45 | condition = condition.unsqueeze(1) if condition.ndim == 2 else condition 46 | gamma, beta = self.linear(condition).chunk(2, dim=-1) 47 | return gamma * self.norm(x) + beta 48 | -------------------------------------------------------------------------------- /scoreperformer/modules/sampling.py: -------------------------------------------------------------------------------- 1 | """ Sampling functions. 
""" 2 | 3 | import math 4 | from typing import Callable, Optional, Dict 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | 10 | from scoreperformer.utils import default 11 | 12 | 13 | # nucleus 14 | 15 | def top_p(logits: Tensor, thres: float = 0.9): 16 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 17 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 18 | 19 | sorted_indices_to_remove = cum_probs > thres 20 | sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value=False) 21 | 22 | sorted_logits[sorted_indices_to_remove] = float('-inf') 23 | return sorted_logits.scatter(1, sorted_indices, sorted_logits) 24 | 25 | 26 | # topk 27 | 28 | def top_k(logits: Tensor, thres: float = 0.9, k: Optional[int] = None): 29 | k = default(k, math.ceil((1 - thres) * logits.shape[-1])) 30 | val, ind = torch.topk(logits, k) 31 | probs = torch.full_like(logits, float('-inf')) 32 | probs.scatter_(1, ind, val) 33 | return probs 34 | 35 | 36 | # top_a 37 | 38 | def top_a(logits: Tensor, min_p_pow: float = 2.0, min_p_ratio: float = 0.02): 39 | probs = F.softmax(logits, dim=-1) 40 | limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio 41 | return torch.where(probs < limit, float('-inf'), logits) 42 | 43 | 44 | # sampling 45 | 46 | def filter_logits_and_sample( 47 | logits: Tensor, 48 | filter_logits_fn: Callable, 49 | filter_kwargs: Optional[Dict[str, object]] = None, 50 | temperature: float = 1., 51 | sample: bool = True 52 | ): 53 | filter_kwargs = filter_kwargs or {} 54 | filtered_logits = filter_logits_fn(logits, **filter_kwargs) 55 | 56 | probs = F.softmax(filtered_logits / temperature, dim=-1) 57 | if not sample: 58 | return probs 59 | return torch.multinomial(probs, 1) 60 | -------------------------------------------------------------------------------- /scoreperformer/modules/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .attend import ( 2 | AttentionIntermediates, 3 | Attend 4 | ) 5 | from .attention import ( 6 | AttentionSharedIntermediates, 7 | AttentionConfig, Attention 8 | ) 9 | from .embeddings import ( 10 | DiscreteContinuousEmbedding, 11 | DiscreteDenseContinuousEmbedding, 12 | AbsolutePositionalEmbedding, 13 | FixedPositionalEmbedding, 14 | ALiBiPositionalBias, 15 | LearnedALiBiPositionalBias 16 | ) 17 | from .feedforward import ( 18 | FeedForwardConfig, 19 | FeedForward 20 | ) 21 | from .transformer import ( 22 | TransformerIntermediates, 23 | TransformerRegistry, 24 | TransformerConfig, Transformer, 25 | TransformerConfig, Transformer, 26 | EncoderConfig, Encoder, 27 | DecoderConfig, Decoder 28 | ) 29 | 30 | DEFAULT_DIM_HEAD = 64 31 | -------------------------------------------------------------------------------- /scoreperformer/modules/transformer/attend.py: -------------------------------------------------------------------------------- 1 | """ 2 | Attention with efficient memory attention support. 
3 | 4 | Adapted from: https://github.com/lucidrains/x-transformers 5 | """ 6 | 7 | from collections import namedtuple 8 | from dataclasses import dataclass 9 | from functools import partial 10 | from typing import Optional 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from einops import rearrange 15 | from packaging import version 16 | from torch import nn, einsum, Tensor 17 | 18 | from scoreperformer.utils import exists, default 19 | 20 | # constants 21 | 22 | EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', 23 | ['enable_flash', 'enable_math', 'enable_mem_efficient']) 24 | 25 | 26 | @dataclass 27 | class AttentionIntermediates: 28 | keys: Optional[Tensor] = None 29 | values: Optional[Tensor] = None 30 | qk_similarities: Optional[Tensor] = None 31 | 32 | def to_tuple(self): 33 | return self.keys, self.values, self.qk_similarities 34 | 35 | 36 | # main class 37 | 38 | class Attend(nn.Module): 39 | def __init__( 40 | self, 41 | *, 42 | dropout: float = 0., 43 | causal: bool = False, 44 | scale: Optional[float] = None, 45 | ): 46 | super().__init__() 47 | self.scale = scale 48 | self.causal = causal 49 | self.attn_fn = partial(F.softmax, dtype=torch.float32) 50 | 51 | self.dropout = dropout 52 | self.attn_dropout = nn.Dropout(dropout) 53 | 54 | # efficient attention 55 | self.efficient = version.parse(torch.__version__) >= version.parse('2.0.0') 56 | self.config = EfficientAttentionConfig(enable_flash=False, enable_math=True, enable_mem_efficient=True) 57 | 58 | def efficient_attn( 59 | self, 60 | q, k, v, 61 | mask=None, 62 | attn_bias=None 63 | ): 64 | batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device 65 | 66 | # Recommended for multi-query single-key-value attention by Tri Dao 67 | # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) 68 | 69 | intermediates = AttentionIntermediates(keys=k, values=v) 70 | 71 | if k.ndim == 3: 72 | k = rearrange(k, 'b ... -> b 1 ...').expand(-1, q.shape[1], -1, -1) 73 | 74 | if v.ndim == 3: 75 | v = rearrange(v, 'b ... 
-> b 1 ...').expand(-1, q.shape[1], -1, -1) 76 | 77 | # check the given mask and expand it to the full B H N L shape 78 | # (the caller passes a 4-dim broadcastable mask, e.g. B 1 1 L) 79 | 80 | causal = self.causal 81 | 82 | if exists(mask): 83 | assert mask.ndim == 4 84 | mask = mask.expand(batch, heads, q_len, k_len) 85 | 86 | # manually handle causal mask, if another mask was given 87 | 88 | if causal: 89 | causal_mask = torch.ones((q_len, k_len), dtype=torch.bool, device=device).triu(k_len - q_len + 1) 90 | mask = mask & ~causal_mask 91 | causal = False 92 | 93 | # handle alibi positional bias 94 | # convert from bool to float 95 | 96 | if exists(attn_bias): 97 | attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, -1, -1, -1) 98 | 99 | # if mask given, the mask would already contain the causal mask from above logic 100 | # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number 101 | 102 | mask_value = -torch.finfo(q.dtype).max 103 | 104 | if exists(mask): 105 | attn_bias = attn_bias.masked_fill(~mask, mask_value // 2) 106 | elif causal: 107 | causal_mask = torch.ones((q_len, k_len), dtype=torch.bool, device=device).triu(k_len - q_len + 1) 108 | attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2) 109 | causal = False 110 | 111 | # scaled_dot_product_attention handles attn_mask either as bool or additive bias 112 | # make it an additive bias here 113 | 114 | mask = attn_bias 115 | 116 | # pytorch 2.0 attention: q, k, v, mask, dropout, causal, softmax_scale 117 | 118 | with torch.backends.cuda.sdp_kernel(**self.config._asdict()): 119 | out = F.scaled_dot_product_attention( 120 | q, k, v, 121 | attn_mask=mask, 122 | dropout_p=self.dropout if self.training else 0., 123 | is_causal=causal and q_len != 1 124 | ) 125 | 126 | return out, intermediates 127 | 128 | def forward( 129 | self, 130 | q, k, v, 131 | mask=None, 132 | attn_bias=None, 133 | prev_attn=None 134 | ): 135 | """ 136 | einstein notation 137 | b - batch 138 | h - heads 139 | n, i, j - sequence length (base sequence length, source, target) 140 | d - feature dimension 141 | """ 142 | 143 | if self.efficient: 144 | assert not exists(prev_attn), 'residual attention not compatible with efficient attention' 145 | return self.efficient_attn(q, k, v, mask=mask, attn_bias=attn_bias) 146 | 147 | n, device = q.shape[-2], q.device 148 | scale = default(self.scale, q.shape[-1] ** -0.5) 149 | 150 | kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' 151 | 152 | dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale 153 | 154 | if exists(prev_attn): 155 | dots = dots + prev_attn 156 | 157 | qk_similarities = dots.clone() 158 | 159 | if exists(attn_bias): 160 | dots = dots + attn_bias 161 | 162 | dtype = dots.dtype 163 | mask_value = -torch.finfo(dots.dtype).max 164 | 165 | if exists(mask): 166 | dots = dots.masked_fill(~mask, mask_value) 167 | 168 | if self.causal: 169 | i, j = dots.shape[-2:] 170 | causal_mask = torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) 171 | dots = dots.masked_fill(causal_mask, mask_value) 172 | 173 | attn = self.attn_fn(dots, dim=-1) 174 | attn = attn.type(dtype) 175 | 176 | attn = self.attn_dropout(attn) 177 | 178 | out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) 179 | 180 | intermediates = AttentionIntermediates( 181 | keys=k, 182 | values=v, 183 | qk_similarities=qk_similarities 184 | ) 185 | 186 | return out, intermediates 187 |
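188 | 189 | if __name__ == '__main__': 190 | # minimal smoke-test sketch (illustrative only, with assumed shapes: batch=2, heads=8, seq_len=16, dim_head=64) 191 | attend = Attend(dropout=0., causal=True) 192 | q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3)) 193 | out, intermediates = attend(q, k, v) 194 | print(out.shape)  # torch.Size([2, 8, 16, 64]); intermediates holds the keys/values for reuse during decoding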
-------------------------------------------------------------------------------- /scoreperformer/modules/transformer/attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer Attention with data caching support for inference. 3 | 4 | Adapted from: https://github.com/lucidrains/x-transformers 5 | """ 6 | from dataclasses import dataclass 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from einops import rearrange, repeat 13 | from torch import Tensor 14 | 15 | from scoreperformer.utils import default, or_reduce 16 | from .attend import AttentionIntermediates, Attend 17 | from .embeddings import ALiBiPositionalBias, LearnedALiBiPositionalBias 18 | from ..constructor import Constructor, ModuleConfig 19 | 20 | 21 | @dataclass 22 | class AttentionSharedIntermediates: 23 | rel_pos_bias: Optional[Tensor] = None 24 | 25 | 26 | @dataclass 27 | class AttentionConfig(ModuleConfig): 28 | dim: int = 512 29 | dim_head: int = 64 30 | heads: int = 8 31 | causal: bool = False 32 | dropout: float = 0. 33 | one_kv_head: bool = False 34 | num_mem_kv: int = 0 35 | shared_kv: bool = False 36 | value_dim_head: Optional[int] = None 37 | max_attend_past: Optional[int] = None 38 | alibi_pos_bias: bool = False 39 | alibi_num_heads: Optional[int] = None 40 | alibi_symmetric: bool = True 41 | alibi_learned: bool = False 42 | 43 | 44 | class Attention(nn.Module, Constructor): 45 | def __init__( 46 | self, 47 | dim: int, 48 | dim_head: int = 64, 49 | heads: int = 8, 50 | causal: bool = False, 51 | dropout: float = 0., 52 | one_kv_head: bool = False, 53 | num_mem_kv: int = 0, 54 | max_attend: Optional[int] = None, 55 | alibi_pos_bias: bool = False, 56 | alibi_num_heads: Optional[int] = None, 57 | alibi_symmetric: bool = True, 58 | alibi_learned: bool = False, 59 | ): 60 | super().__init__() 61 | self.scale = dim_head ** -0.5 62 | 63 | self.heads = heads 64 | self.causal = causal 65 | self.max_attend = max_attend 66 | 67 | self.one_kv_head = one_kv_head 68 | out_dim = q_dim = dim_head * heads 69 | kv_dim = dim_head if one_kv_head else dim_head * heads 70 | 71 | self.to_q = nn.Linear(dim, q_dim, bias=False) 72 | self.to_k = nn.Linear(dim, kv_dim, bias=False) 73 | self.to_v = nn.Linear(dim, kv_dim, bias=False) 74 | 75 | # relative positional bias 76 | 77 | self.rel_pos = None 78 | if alibi_pos_bias: 79 | alibi_num_heads = default(alibi_num_heads, heads) 80 | assert alibi_num_heads <= heads, 'number of ALiBi heads must not exceed the total number of heads' 81 | alibi_pos_klass = LearnedALiBiPositionalBias if alibi_learned else ALiBiPositionalBias 82 | self.rel_pos = alibi_pos_klass( 83 | heads=alibi_num_heads, 84 | total_heads=heads, 85 | symmetric=alibi_symmetric or causal 86 | ) 87 | 88 | # attend class - includes the core attention algorithm 89 | 90 | self.attend = Attend( 91 | causal=causal, 92 | dropout=dropout, 93 | scale=self.scale 94 | ) 95 | 96 | # add memory key / values 97 | 98 | self.num_mem_kv = num_mem_kv 99 | if num_mem_kv > 0: 100 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 101 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 102 | 103 | # output layer 104 | 105 | self.to_out = nn.Linear(out_dim, dim, bias=False) 106 | 107 | def forward( 108 | self, 109 | x: Tensor, 110 | context: Optional[Tensor] = None, 111 | mask: Optional[Tensor] = None, 112 | context_mask: Optional[Tensor] = None, 113 | attn_mask:
Optional[Tensor] = None, 114 | prev_attn: Optional[Tensor] = None, 115 | mem: Optional[Tensor] = None, 116 | cache: Optional[AttentionIntermediates] = None, 117 | shared_cache: Optional[AttentionSharedIntermediates] = None 118 | ): 119 | b, n = x.shape[:2] 120 | h, scale, device = self.heads, self.scale, x.device 121 | has_context, has_mem, has_cache = context is not None, mem is not None, cache is not None 122 | assert not (has_mem and has_cache), 'cache is not compatible with memory keys' 123 | assert not (has_context and has_cache), 'cache is not compatible with context yet' 124 | 125 | kv_input = default(context, x) 126 | 127 | q_input = x 128 | k_input = kv_input 129 | v_input = kv_input 130 | 131 | if has_mem: 132 | k_input = torch.cat((mem, k_input), dim=-2) 133 | v_input = torch.cat((mem, v_input), dim=-2) 134 | 135 | q = self.to_q(q_input) 136 | k = self.to_k(k_input) 137 | v = self.to_v(v_input) if self.to_v is not None else k 138 | 139 | q = rearrange(q, 'b n (h d) -> b h n d', h=h) 140 | 141 | if not self.one_kv_head: 142 | k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (k, v)) 143 | 144 | input_mask = mask if context_mask is None else context_mask 145 | 146 | if self.num_mem_kv > 0: 147 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) 148 | 149 | k = torch.cat((mem_k, k), dim=-2) 150 | v = torch.cat((mem_v, v), dim=-2) 151 | 152 | if input_mask is not None: 153 | input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) 154 | 155 | k = torch.cat([cache.keys, k], dim=-2) if has_cache else k 156 | v = torch.cat([cache.values, v], dim=-2) if has_cache else v 157 | 158 | i, j = map(lambda t: t.shape[-2], (q, k)) 159 | 160 | # determine masking 161 | 162 | masks = [] 163 | final_attn_mask = None 164 | 165 | if input_mask is not None: 166 | input_mask = rearrange(input_mask, 'b j -> b 1 1 j') 167 | masks.append(~input_mask) 168 | 169 | if attn_mask is not None: 170 | assert 2 <= attn_mask.ndim <= 4, \ 171 | 'attention mask must have between 2 and 4 dimensions' 172 | if attn_mask.ndim == 2: 173 | attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j') 174 | elif attn_mask.ndim == 3: 175 | attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j') 176 | attn_mask = attn_mask[:, :, -1:] if has_cache else attn_mask 177 | masks.append(~attn_mask) 178 | 179 | if self.max_attend is not None: 180 | range_q = torch.arange(j - i, j, device=device) 181 | range_k = torch.arange(j, device=device) 182 | dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j') 183 | max_attend_mask = torch.logical_or(dist < -self.max_attend, dist > self.max_attend) 184 | masks.append(max_attend_mask) 185 | 186 | if len(masks) > 0: 187 | final_attn_mask = ~or_reduce(masks) 188 | 189 | # prepare relative positional bias, if needed 190 | 191 | rel_pos_bias, attn_bias = None, None 192 | if self.rel_pos is not None: 193 | if shared_cache is not None and shared_cache.rel_pos_bias is not None: 194 | rel_pos_bias = shared_cache.rel_pos_bias 195 | else: 196 | rel_pos_bias = self.rel_pos.get_bias(i, j, k=j - i).to(dtype=q.dtype) 197 | attn_bias = self.rel_pos(i, j, k=j - i, bias=rel_pos_bias) 198 | 199 | # attention is all we need 200 | 201 | out, intermediates = self.attend( 202 | q, k, v, 203 | mask=final_attn_mask, 204 | attn_bias=attn_bias, 205 | prev_attn=prev_attn 206 | ) 207 | 208 | # merge heads 209 | 210 | out = rearrange(out, 'b h n d -> b n (h d)') 211 | 212 | # combine the heads 213 | 214 | out =
self.to_out(out) 215 | 216 | if mask is not None: 217 | mask = mask[:, -1:] if has_cache else mask 218 | out = out * mask[..., None] 219 | 220 | shared_intermediates = AttentionSharedIntermediates(rel_pos_bias=rel_pos_bias) 221 | 222 | return out, intermediates, shared_intermediates 223 | -------------------------------------------------------------------------------- /scoreperformer/modules/transformer/feedforward.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer FeedForward layer. 3 | 4 | Adapted from: https://github.com/lucidrains/x-transformers 5 | """ 6 | from dataclasses import dataclass 7 | 8 | import torch.nn as nn 9 | 10 | from ..constructor import Constructor, ModuleConfig 11 | 12 | 13 | class GLU(nn.Module): 14 | def __init__(self, dim_in, dim_out, activation): 15 | super().__init__() 16 | self.act = activation 17 | self.proj = nn.Linear(dim_in, dim_out * 2) 18 | 19 | def forward(self, x): 20 | x, gate = self.proj(x).chunk(2, dim=-1) 21 | return x * self.act(gate) 22 | 23 | 24 | @dataclass 25 | class FeedForwardConfig(ModuleConfig): 26 | dim: int = 512 27 | mult: int = 4 28 | glu: bool = False 29 | swish: bool = False 30 | post_act_ln: bool = False 31 | dropout: float = 0. 32 | no_bias: bool = True 33 | 34 | 35 | class FeedForward(nn.Module, Constructor): 36 | def __init__( 37 | self, 38 | dim: int = 512, 39 | mult: int = 4, 40 | glu: bool = False, 41 | swish: bool = False, 42 | post_act_ln: bool = False, 43 | dropout: float = 0., 44 | no_bias: bool = True 45 | ): 46 | super().__init__() 47 | 48 | inner_dim = int(dim * mult) 49 | activation = nn.SiLU() if swish else nn.GELU() 50 | 51 | project_in = nn.Sequential( 52 | nn.Linear(dim, inner_dim, bias=not no_bias), 53 | activation 54 | ) if not glu else GLU(dim, inner_dim, activation) 55 | 56 | self.ff = nn.Sequential( 57 | project_in, 58 | nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim, bias=not no_bias) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.ff(x) 65 | -------------------------------------------------------------------------------- /scoreperformer/modules/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer Attention Layers with data caching support for inference. 
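The TransformerIntermediates returned when return_hiddens=True can be passed back as intermediates_cache, in which case forward() only processes the newest position and re-uses the cached hidden states and attention keys/values.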
3 | 4 | Adapted from: https://github.com/lucidrains/x-transformers 5 | """ 6 | 7 | import copy 8 | from dataclasses import dataclass, field 9 | from functools import partial 10 | from typing import Optional, Union, List 11 | 12 | import torch 13 | import torch.nn as nn 14 | from omegaconf import DictConfig 15 | from torch import Tensor 16 | 17 | from scoreperformer.utils import equals 18 | from .attend import AttentionIntermediates 19 | from .attention import Attention, AttentionConfig 20 | from .feedforward import FeedForward, FeedForwardConfig 21 | from ..constructor import VariableModuleConfig, Constructor, Registry 22 | from ..layers import Residual, AdaptiveLayerNorm 23 | 24 | 25 | @dataclass 26 | class TransformerIntermediates: 27 | hiddens: Optional[List[Tensor]] = None 28 | attention: Optional[List[AttentionIntermediates]] = None 29 | 30 | 31 | TransformerRegistry = type("_TransformerRegistry", (Registry,), {})() 32 | 33 | 34 | @dataclass 35 | class TransformerConfig(VariableModuleConfig): 36 | _target_: str = "default" 37 | 38 | dim: int = 512 39 | depth: int = 4 40 | heads: int = 8 41 | 42 | attention: Union[AttentionConfig, DictConfig] = None 43 | feed_forward: Union[FeedForwardConfig, DictConfig] = None 44 | 45 | causal: bool = False 46 | cross_attend: bool = False 47 | only_cross: bool = False 48 | 49 | pre_norm: bool = True 50 | use_adanorm: bool = False 51 | style_emb_dim: Optional[int] = None 52 | 53 | 54 | @TransformerRegistry.register("default") 55 | class Transformer(nn.Module, Constructor): 56 | def __init__( 57 | self, 58 | dim: int = 512, 59 | depth: int = 4, 60 | heads: int = 8, 61 | attention: Union[AttentionConfig, DictConfig] = None, 62 | feed_forward: Union[FeedForwardConfig, DictConfig] = None, 63 | causal: bool = False, 64 | cross_attend: bool = False, 65 | only_cross: bool = False, 66 | pre_norm: bool = True, 67 | use_adanorm: bool = False, 68 | style_emb_dim: Optional[int] = None 69 | ): 70 | super().__init__() 71 | 72 | attention = attention if attention else AttentionConfig() 73 | feed_forward = feed_forward if feed_forward else FeedForwardConfig() 74 | 75 | self.dim = dim 76 | self.depth = depth 77 | self.layers = nn.ModuleList([]) 78 | 79 | # normalization 80 | 81 | self.pre_norm = pre_norm 82 | self.ada_norm = use_adanorm 83 | 84 | assert not use_adanorm or style_emb_dim is not None, 'style_emb_dim must be provided with adanorm' 85 | 86 | norm_fn = partial(AdaptiveLayerNorm, dim, style_emb_dim) if use_adanorm else partial(nn.LayerNorm, dim) 87 | 88 | # layers 89 | 90 | self.cross_attend = cross_attend 91 | 92 | if cross_attend and not only_cross: 93 | default_block = ('a', 'c', 'f') 94 | elif cross_attend and only_cross: 95 | default_block = ('c', 'f') 96 | else: 97 | default_block = ('a', 'f') 98 | 99 | # calculate layer block order 100 | 101 | self.layer_types = default_block * depth 102 | self.num_attn_layers = len(list(filter(equals('a'), self.layer_types))) 103 | 104 | # final normalization (used with pre-norm) 105 | 106 | self.final_norm = norm_fn() if pre_norm else nn.Identity() 107 | 108 | # iterate and construct layers 109 | 110 | for ind, layer_type in enumerate(self.layer_types): 111 | 112 | if layer_type == 'a': 113 | layer = Attention.init(config=attention, dim=dim, heads=heads, causal=causal) 114 | elif layer_type == 'c': 115 | layer = Attention.init(config=attention, dim=dim, heads=heads) 116 | elif layer_type == 'f': 117 | layer = FeedForward.init(config=feed_forward, dim=dim) 118 | else: 119 | raise ValueError(f'invalid layer type {layer_type}') 120 |
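# each block is stored as a (norms, layer, residual) triple; with pre_norm the LayerNorm runs before the block (plus final_norm at the end), otherwise post_main_norm runs after the residual sum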
121 | residual = Residual(dim) 122 | 123 | pre_branch_norm = norm_fn() if pre_norm else None 124 | post_branch_norm = None 125 | post_main_norm = norm_fn() if not pre_norm else None 126 | 127 | norms = nn.ModuleList([ 128 | pre_branch_norm, 129 | post_branch_norm, 130 | post_main_norm 131 | ]) 132 | 133 | self.layers.append(nn.ModuleList([ 134 | norms, 135 | layer, 136 | residual 137 | ])) 138 | 139 | def forward( 140 | self, 141 | x: Tensor, 142 | mask: Optional[Tensor] = None, 143 | context: Optional[Tensor] = None, 144 | context_mask: Optional[Tensor] = None, 145 | attn_mask: Optional[Tensor] = None, 146 | style_embeddings: Optional[Tensor] = None, 147 | mems: Optional[List[Tensor]] = None, 148 | intermediates_cache: Optional[TransformerIntermediates] = None, 149 | return_hiddens: bool = False 150 | ): 151 | assert not (self.cross_attend ^ (context is not None)), \ 152 | 'context must be passed in if and only if cross_attend is set to True' 153 | assert not self.ada_norm or style_embeddings is not None, \ 154 | 'style_embeddings must be passed for AdaLayerNorm' 155 | 156 | hiddens = [] 157 | attn_intermediates = [] 158 | 159 | mems = mems.copy() if mems is not None else [None] * self.num_attn_layers 160 | 161 | has_cache = intermediates_cache is not None 162 | intermediates_cache = copy.copy(intermediates_cache) if has_cache else None 163 | 164 | x = x[:, -1:] if has_cache else x 165 | style_embeddings = style_embeddings[:, -1:] if has_cache and style_embeddings is not None else style_embeddings 166 | 167 | attn_shared_cache = None 168 | 169 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): 170 | cache = None 171 | 172 | if layer_type == 'a': 173 | if has_cache: 174 | cache = intermediates_cache.hiddens.pop(0) 175 | x = torch.cat([cache, x], dim=1) 176 | 177 | if return_hiddens: 178 | hiddens.append(x) 179 | 180 | x = x[:, -1:] if has_cache else x 181 | 182 | layer_mem = mems.pop(0) if (mems and layer_type == 'a') else None 183 | 184 | if has_cache: 185 | if layer_type in ('a', 'c'): 186 | cache = intermediates_cache.attention.pop(0) 187 | 188 | residual = x 189 | 190 | pre_norm, post_branch_norm, post_main_norm = norm 191 | 192 | if pre_norm is not None: 193 | x = pre_norm(x, condition=style_embeddings) if self.ada_norm else pre_norm(x) 194 | 195 | if layer_type == 'a': 196 | out, inter, attn_shared_cache = block( 197 | x, mask=mask, attn_mask=attn_mask, 198 | mem=layer_mem, cache=cache, shared_cache=attn_shared_cache, 199 | ) 200 | elif layer_type == 'c': 201 | out, inter, _ = block(x, context=context, mask=mask, context_mask=context_mask) 202 | elif layer_type == 'f': 203 | out = block(x) 204 | 205 | if post_branch_norm is not None: 206 | out = post_branch_norm(out, condition=style_embeddings) if self.ada_norm else post_branch_norm(out) 207 | 208 | x = residual_fn(out, residual) 209 | 210 | if return_hiddens: 211 | if layer_type in ('a', 'c'): 212 | attn_intermediates.append(inter) 213 | 214 | if post_main_norm is not None: 215 | x = post_main_norm(x, condition=style_embeddings) if self.ada_norm else post_main_norm(x) 216 | 217 | x = self.final_norm(x, condition=style_embeddings) if self.ada_norm else self.final_norm(x) 218 | 219 | if has_cache: 220 | cache = intermediates_cache.hiddens.pop(0) 221 | x = torch.cat([cache, x], dim=1) 222 | 223 | if return_hiddens: 224 | hiddens.append(x) 225 | intermediates = TransformerIntermediates( 226 | hiddens=hiddens, 227 | attention=attn_intermediates 228 | ) 229 | 230 | return x, intermediates 231 | 232 | return x 233 | 234 | 235
| @dataclass 236 | class EncoderConfig(TransformerConfig): 237 | _target_: str = "encoder" 238 | causal: bool = False 239 | 240 | 241 | @TransformerRegistry.register("encoder") 242 | class Encoder(Transformer): 243 | def __init__(self, **kwargs): 244 | super().__init__(causal=False, **kwargs) 245 | 246 | 247 | @dataclass 248 | class DecoderConfig(TransformerConfig): 249 | _target_: str = "decoder" 250 | causal: bool = True 251 | 252 | 253 | @TransformerRegistry.register("decoder") 254 | class Decoder(Transformer): 255 | def __init__(self, **kwargs): 256 | super().__init__(causal=True, **kwargs) 257 | -------------------------------------------------------------------------------- /scoreperformer/train.py: -------------------------------------------------------------------------------- 1 | """ A minimal training script. """ 2 | 3 | import argparse 4 | 5 | from scoreperformer.experiments import Trainer 6 | from scoreperformer.experiments.callbacks import EpochReproducibilityCallback 7 | from scoreperformer.experiments.components import ExperimentComponents 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser('training the model') 11 | parser.add_argument('--config-root', '-r', type=str, default='../recipes') 12 | parser.add_argument('--config-name', '-n', type=str, default='scoreperformer/base.yaml') 13 | 14 | args = parser.parse_args() 15 | 16 | exp_comps = ExperimentComponents( 17 | config=args.config_name, 18 | config_root=args.config_root 19 | ) 20 | model, train_dataset, eval_dataset, collator, evaluator = exp_comps.init_components() 21 | 22 | trainer = Trainer( 23 | model=model, 24 | config=exp_comps.config, 25 | train_dataset=train_dataset, 26 | eval_dataset=eval_dataset, 27 | collator=collator, 28 | evaluator=evaluator, 29 | callbacks=[EpochReproducibilityCallback()] 30 | ) 31 | 32 | trainer.train() 33 | -------------------------------------------------------------------------------- /scoreperformer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import Config 2 | from .functions import * 3 | from .io import * -------------------------------------------------------------------------------- /scoreperformer/utils/config.py: -------------------------------------------------------------------------------- 1 | """ A general purpose Config with IO. That's it. 
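A sketch of the intended use (hypothetical subclass): @dataclass class OptimizerConfig(Config): lr: float = 1e-3; instances inherit to_dict(), to_json_string() and from_json_string(), so OptimizerConfig.from_json_string(OptimizerConfig(lr=3e-4).to_json_string()) round-trips.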
""" 2 | 3 | import json 4 | from dataclasses import dataclass, asdict 5 | from enum import Enum 6 | 7 | from omegaconf import DictConfig, ListConfig, OmegaConf 8 | 9 | 10 | @dataclass 11 | class Config: 12 | def to_dict(self): 13 | d = asdict(self) 14 | for k, v in d.items(): 15 | if isinstance(v, Enum): 16 | d[k] = v.value 17 | elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): 18 | d[k] = [x.value for x in v] 19 | elif isinstance(v, (DictConfig, ListConfig)): 20 | d[k] = OmegaConf.to_container(v) 21 | elif isinstance(v, Config): 22 | d[k] = v.to_dict() 23 | return d 24 | 25 | def to_json_string(self): 26 | return json.dumps(self.to_dict(), indent=2) 27 | 28 | @classmethod 29 | def from_json_string(cls, json_string): 30 | return cls(**json.loads(json_string)) 31 | 32 | def __contains__(self, item): 33 | return item in self.to_dict() 34 | 35 | 36 | def disable_nodes(config, parent=None): 37 | if isinstance(config, DictConfig): 38 | eject = config.pop("_disable_", False) 39 | 40 | if eject and parent is not None: 41 | key = config._key() 42 | parent.pop(key) 43 | else: 44 | for value in list(config.values()): 45 | disable_nodes(value, config) 46 | -------------------------------------------------------------------------------- /scoreperformer/utils/functions.py: -------------------------------------------------------------------------------- 1 | """ A set of utility classes and functions used throughout the repository. """ 2 | 3 | import random 4 | import sys 5 | from enum import Enum 6 | from inspect import isfunction 7 | 8 | import numpy as np 9 | from tqdm.asyncio import tqdm 10 | 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | 16 | def default(val, d): 17 | if exists(val): 18 | return val 19 | return d() if isfunction(d) else d 20 | 21 | 22 | class equals: 23 | def __init__(self, val): 24 | self.val = val 25 | 26 | def __call__(self, x, *args, **kwargs): 27 | return x == self.val 28 | 29 | 30 | def or_reduce(masks): 31 | head, *body = masks 32 | for rest in body: 33 | head = head | rest 34 | return head 35 | 36 | 37 | def prob2bool(prob): 38 | return random.choices([True, False], weights=[prob, 1 - prob])[0] 39 | 40 | 41 | def find_closest(array, values): 42 | """Finds indices of the values closest to `values` in a given array.""" 43 | ids = np.searchsorted(array, values, side="left") 44 | 45 | # find indexes where previous index is closer 46 | arr_values = array[np.minimum(ids, len(array) - 1)] 47 | prev_values = array[np.maximum(ids - 1, 0)] 48 | prev_idx_is_less = (ids == len(array)) | (np.fabs(values - prev_values) < np.fabs(values - arr_values)) 49 | 50 | if isinstance(ids, np.ndarray): 51 | ids[prev_idx_is_less] -= 1 52 | elif prev_idx_is_less: 53 | ids -= 1 54 | 55 | ids = np.maximum(0, ids) 56 | 57 | return ids 58 | 59 | 60 | def tqdm_iterator(iterable, desc=None, position=0, leave=False, file=sys.stdout, **kwargs): 61 | return tqdm(iterable, desc=desc, position=position, leave=leave, file=file, **kwargs) 62 | 63 | 64 | def apply(seqs, func, tqdm_enabled=True, desc=None): 65 | """ Apply a given `func` over a list of sequences `seqs`.""" 66 | iterator = tqdm_iterator(seqs, desc=desc) if tqdm_enabled else seqs 67 | return [func(seq) for seq in iterator] 68 | 69 | 70 | class ExplicitEnum(str, Enum): 71 | """ 72 | Enum with more explicit error message for missing values. 
73 | """ 74 | 75 | @classmethod 76 | def _missing_(cls, value): 77 | raise ValueError( 78 | f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" 79 | ) 80 | 81 | @classmethod 82 | def has_value(cls, value): 83 | return value in cls._value2member_map_ 84 | 85 | @classmethod 86 | def list(cls): 87 | return list(map(lambda c: c.value, cls)) 88 | -------------------------------------------------------------------------------- /scoreperformer/utils/io.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | def load_json(file_path): 6 | if os.path.exists(file_path): 7 | with open(file_path, 'r') as f: 8 | return json.load(f) 9 | else: 10 | return {} 11 | 12 | 13 | def dump_json(data, file_path): 14 | with open(file_path, 'w') as f: 15 | return json.dump(data, f) 16 | -------------------------------------------------------------------------------- /scoreperformer/utils/playback.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import IPython.display as ipd 4 | import note_seq 5 | from miditoolkit import MidiFile 6 | from note_seq import midi_file_to_note_sequence 7 | 8 | 9 | def cut_midi( 10 | midi: MidiFile, 11 | min_tick: int = 0, 12 | max_tick: int = 1e9, 13 | cut_end_tick: bool = True, 14 | save_path: str = '/tmp/tmp.mid' 15 | ): 16 | midi = copy.deepcopy(midi) 17 | 18 | for track in midi.instruments: 19 | track.notes = [n for n in track.notes if min_tick <= n.start <= max_tick] 20 | for n in track.notes: 21 | n.start -= min_tick 22 | if cut_end_tick: 23 | n.end = min(n.end, max_tick) 24 | n.end -= min_tick 25 | 26 | if hasattr(track, "control_changes"): 27 | track.control_changes = [c for c in track.control_changes if min_tick <= c.time <= max_tick] 28 | for c in track.control_changes: 29 | c.time -= min_tick 30 | if hasattr(track, "pedals"): 31 | track.pedals = [p for p in track.pedals if min_tick <= p.start <= max_tick] 32 | for p in track.pedals: 33 | p.start -= min_tick 34 | p.end -= min_tick 35 | 36 | midi.tempo_changes = [t for t in midi.tempo_changes if min_tick <= t.time <= max_tick] 37 | for t in midi.tempo_changes: 38 | t.time -= min_tick 39 | 40 | midi.max_tick = max([n.end for n in midi.instruments[0].notes]) 41 | midi.max_tick = max(midi.max_tick, midi.tempo_changes[-1].time + 1) 42 | 43 | if save_path is not None: 44 | midi.dump(save_path) 45 | 46 | return midi 47 | 48 | 49 | def midi_to_audio( 50 | path: str = '/tmp/tmp.mid', 51 | sample_rate: int = 22050, 52 | play: bool = True 53 | ): 54 | ns = midi_file_to_note_sequence(path) 55 | audio = note_seq.fluidsynth(ns, sample_rate=sample_rate) 56 | if play: 57 | ipd.display(ipd.Audio(audio, rate=sample_rate)) 58 | return audio 59 | -------------------------------------------------------------------------------- /scoreperformer/utils/plots.py: -------------------------------------------------------------------------------- 1 | import librosa.display 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from matplotlib.colors import ListedColormap 5 | import pretty_midi 6 | 7 | from scoreperformer.data.tokenizers import SPMuple 8 | 9 | 10 | def plot_performance_parameter(tokenizer: SPMuple, total_seq, perf_seq, token_type="Tempo"): 11 | type_idx = tokenizer.vocab_types_idx[token_type] 12 | 13 | preds = total_seq[:, type_idx] - tokenizer.zero_token 14 | targets = perf_seq[:total_seq.shape[0], type_idx] - tokenizer.zero_token 15 | 16 | if token_type 
== "Velocity": 17 | values_map = tokenizer.velocities 18 | elif token_type == "Tempo": 19 | values_map = tokenizer.tempos 20 | elif token_type == "OnsetDev": 21 | nb_positions = max(tokenizer.beat_res.values()) * 2 # up to two quarter notes 22 | values_map = np.arange(-nb_positions, nb_positions + 1) / nb_positions / 2 23 | elif token_type == "PerfDuration": 24 | values_map = np.array([ 25 | (beat * res + pos) / res if res > 0 else 0 26 | for beat, pos, res in tokenizer.durations 27 | ]) 28 | elif token_type == "RelOnsetDev": 29 | values_map = tokenizer.rel_onset_deviations 30 | elif token_type == "RelPerfDuration": 31 | values_map = tokenizer.rel_performed_durations 32 | else: 33 | return 34 | 35 | preds, targets = values_map[preds], values_map[targets] 36 | 37 | fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(16, 12)) 38 | 39 | fig.suptitle(f"Performance Notes, {token_type}", fontsize=20) 40 | ax0.plot(preds) 41 | ax0.plot(targets) 42 | ax1.plot(preds - targets) 43 | 44 | ax0.legend(["Generated", "Target"], fontsize=18) 45 | ax1.legend(["Difference"], fontsize=18) 46 | 47 | ax0.get_xaxis().set_visible(False) 48 | ax1.set_xlabel("note id", fontsize=16) 49 | for ax in (ax0, ax1): 50 | ax.tick_params(labelsize=14) 51 | ax.set_ylabel(token_type.lower(), fontsize=16) 52 | 53 | fig.tight_layout() 54 | 55 | 56 | _colors = plt.get_cmap('Reds', 256)(np.linspace(0, 1, 256)) 57 | _colors[:1, :] = np.array([1, 1, 1, 1]) 58 | pianoroll_cmap = ListedColormap(_colors) 59 | 60 | 61 | def plot_pianoroll(pm, min_pitch=21, max_pitch=109, min_velocity=0., max_velocity=127., 62 | fs=100, max_time=None, pad_time=0.2, xticks_time=1., figsize=(14, 6), fig=None, ax=None): 63 | if fig is None or ax is None: 64 | fig, ax = plt.subplots(1, 1, figsize=figsize) 65 | 66 | # get numpy array from pianoroll 67 | arr = pm.get_piano_roll(fs)[min_pitch:max_pitch + 1] 68 | arr[arr > max_velocity] = max_velocity 69 | 70 | # pad with a few steps 71 | max_time = pm.get_end_time() if max_time is None else max_time 72 | pad_steps = int(fs * pad_time) 73 | pad_l, pad_r = pad_steps, pad_steps + max(0, int(fs * max_time) - arr.shape[1]) 74 | arr = np.pad(arr, ((0, 0), (pad_l, pad_r)), 'constant') 75 | 76 | # plot pianoroll 77 | x_coords = np.arange(-pad_time, arr.shape[1] / fs - pad_time, 1 / fs) 78 | y_coords = np.arange(min_pitch, max_pitch + 1, 1) 79 | librosa.display.specshow( 80 | arr, cmap=pianoroll_cmap, ax=ax, x_coords=x_coords, y_coords=y_coords, 81 | hop_length=1, sr=fs, x_axis='time', y_axis='linear', bins_per_octave=12, 82 | fmin=pretty_midi.note_number_to_hz(min_pitch), vmin=min_velocity, vmax=max_velocity 83 | ) 84 | 85 | # plot colorbar 86 | cbar = fig.colorbar(ax.get_children()[0], ax=ax, fraction=0.15, pad=0.02, aspect=15) 87 | cbar.ax.tick_params(labelsize=14) 88 | cbar.set_ticks(np.arange(0, max_velocity, 12)) 89 | 90 | # axis labels 91 | ax.set_xlabel('time (s)', fontsize=16) 92 | ax.set_ylabel('pitch', fontsize=16) 93 | ax.tick_params(labelsize=14) 94 | 95 | # x-axis 96 | ax.set_xticks(np.arange(0, ax.get_xticks()[-1], xticks_time)) 97 | ax.set_xlim(xmax=max_time + pad_time) 98 | 99 | # y-axis 100 | yticks = np.arange(min_pitch + 12 - min_pitch % 12, max_pitch, 12) 101 | ax.set_yticks(yticks - 0.5) 102 | ax.set_yticklabels(yticks) 103 | 104 | # removing empty pitch lines 105 | has_notes = min_pitch + np.where(np.any(arr != 0., axis=1))[0] 106 | if len(has_notes) > 0: 107 | ymin, ymax = has_notes[0], has_notes[-1] 108 | ymin = max(min_pitch, ymin - ymin % 12) - 2.5 109 | ymax = min(max_pitch, ymax + 12 - ymax % 12) 
+ 1.5 110 | ax.set_ylim(ymin, ymax) 111 | 112 | ax.grid(alpha=0.5) 113 | 114 | return fig, ax 115 | --------------------------------------------------------------------------------