├── LICENSE.md
├── README.md
├── assets
│   ├── perf-embs-clf-4x6.png
│   └── scoreperformer.png
├── data
│   ├── directions
│   │   └── direction_classes.json
│   └── tokenizers
│       ├── spmuple_bar.json
│       ├── spmuple_beat.json
│       ├── spmuple_onset.json
│       └── spmuple_window.json
├── recipes
│   ├── default.yaml
│   └── scoreperformer
│       ├── ablation
│       │   ├── no_cont_tokens.yaml
│       │   ├── no_io_tie.yaml
│       │   ├── no_masked_seq.yaml
│       │   ├── no_saln.yaml
│       │   └── no_score_enc.yaml
│       ├── base.yaml
│       ├── custom_hierarchy.yaml
│       ├── minimal.yaml
│       └── no_classifiers.yaml
├── requirements.txt
└── scoreperformer
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── collators
    │   │   ├── __init__.py
    │   │   ├── common.py
    │   │   ├── directions.py
    │   │   ├── performance.py
    │   │   └── score_performance.py
    │   ├── datasets
    │   │   ├── __init__.py
    │   │   ├── directions.py
    │   │   ├── performance.py
    │   │   ├── score_performance.py
    │   │   ├── token_sequence.py
    │   │   └── utils.py
    │   ├── directions
    │   │   ├── __init__.py
    │   │   ├── articulation.py
    │   │   ├── dynamic.py
    │   │   ├── parser.py
    │   │   ├── tempo.py
    │   │   └── words.py
    │   ├── helpers
    │   │   ├── __init__.py
    │   │   ├── indexers.py
    │   │   └── processors.py
    │   ├── midi
    │   │   ├── __init__.py
    │   │   ├── beats.py
    │   │   ├── containers.py
    │   │   ├── preprocess.py
    │   │   ├── quantization.py
    │   │   ├── sync.py
    │   │   ├── timing.py
    │   │   └── utils.py
    │   ├── music_constants.py
    │   └── tokenizers
    │       ├── __init__.py
    │       ├── classes.py
    │       ├── common
    │       │   ├── __init__.py
    │       │   └── octuple_m.py
    │       ├── constants.py
    │       ├── midi_tokenizer.py
    │       └── spmuple
    │           ├── __init__.py
    │           ├── base.py
    │           ├── encodings.py
    │           ├── spmuple.py
    │           └── spmuple2.py
    ├── experiments
    │   ├── __init__.py
    │   ├── callbacks.py
    │   ├── components.py
    │   ├── integrations.py
    │   ├── logging
    │   │   ├── __init__.py
    │   │   └── console_logger.py
    │   ├── optimizers.py
    │   ├── trainer.py
    │   ├── trainer_config.py
    │   └── trainer_utils.py
    ├── inference
    │   ├── __init__.py
    │   ├── generators.py
    │   └── messengers.py
    ├── models
    │   ├── __init__.py
    │   ├── base.py
    │   ├── classifiers
    │   │   ├── __init__.py
    │   │   ├── evaluator.py
    │   │   └── model.py
    │   └── scoreperformer
    │       ├── __init__.py
    │       ├── embeddings.py
    │       ├── evaluator.py
    │       ├── mmd_transformer.py
    │       ├── model.py
    │       ├── transformer.py
    │       └── wrappers.py
    ├── modules
    │   ├── __init__.py
    │   ├── constructor.py
    │   ├── layers.py
    │   ├── sampling.py
    │   └── transformer
    │       ├── __init__.py
    │       ├── attend.py
    │       ├── attention.py
    │       ├── embeddings.py
    │       ├── feedforward.py
    │       └── transformer.py
    ├── train.py
    └── utils
        ├── __init__.py
        ├── config.py
        ├── functions.py
        ├── io.py
        ├── playback.py
        └── plots.py
/README.md:
--------------------------------------------------------------------------------
1 | # ScorePerformer: Expressive Piano Performance Rendering Model
2 |
3 | ![ScorePerformer model architecture](assets/scoreperformer.png)
4 |
5 | > Code for the paper [*"ScorePerformer: Expressive Piano Performance Rendering with Fine-Grained
6 | Control"*](https://archives.ismir.net/ismir2023/paper/000069.pdf) published in the Proceedings of the
7 | > [24th International Society for Music Information Retrieval (ISMIR) Conference, Milan, 2023](https://ismir2023.ismir.net).
8 | >
9 | >Authors: Ilya Borovik and Vladimir Viro
10 | >
11 | >[Paper](https://archives.ismir.net/ismir2023/paper/000069.pdf)
12 | [Poster](http://ismir2023program.ismir.net/poster_183.html)
13 | [Colab Demo](https://colab.research.google.com/drive/1uKEX02D8cn3wzG-oS7YWE3IW-OrGV22w)
14 |
15 | ## Description
16 |
17 | The paper presents ScorePerformer, a controllable piano performance rendering model that:
18 |
19 | 1. learns fine-grained performance style features at the global, bar, beat, and onset levels using transformers and
20 | hierarchical maximum mean discrepancy variational autoencoders;
21 | 2. provides musical language-driven control over the learned style space through a set of trained performance direction
22 | classifiers;
23 | 3. utilizes a tokenized encoding for aligned score and performance music (SPMuple) with a smooth local window tempo
24 | function.
25 |
26 | The repository provides the code for the model and tokenizers, along with the training and inference pipelines. You
27 | can interact with the trained model in a simple [Colab demo](https://colab.research.google.com/drive/1uKEX02D8cn3wzG-oS7YWE3IW-OrGV22w).
28 |
29 | *Note*: The code for the alignment and preprocessing of score and performance data is deliberately not made available.
30 | It can be discussed and released upon request.
31 |
32 | If you have any questions or requests, please write to ilya.borovik@skoltech.ru.
33 |
34 | ## Latent Style Space Control
35 |
36 | Our approach to making expressive performance rendering controllable relies on the performance direction
37 | markings available in musical scores. We use a two-step approach:
38 |
39 | 1. learn performance style features relevant to performance rendering using a performance style encoder;
40 | 2. train performance direction classifiers on the learned style embeddings and "label" the style space using their predictions.
41 |
42 | Below is a visualization of the top two principal components of the learned latent style spaces, labelled by a subset of
43 | the trained performance direction classifiers.
44 |
45 | ![Latent style spaces labelled by performance direction classifiers](assets/perf-embs-clf-4x6.png)
46 |
47 | ## Citing
48 |
49 | If you use or reference the model, please cite the following paper:
50 |
51 | ```
52 | @inproceedings{borovik2023scoreperformer,
53 | title={{ScorePerformer: Expressive Piano Performance Rendering with Fine-Grained Control}},
54 | author={Borovik, Ilya and Viro, Vladimir},
55 | booktitle={Proceedings of the 24th International Society for Music Information Retrieval Conference {(ISMIR)}},
56 | year={2023},
57 | url={https://archives.ismir.net/ismir2023/paper/000069.pdf},
58 | }
59 | ```
60 |
61 | ## License
62 |
63 | The work is released under a [CC BY-NC-SA 4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/).
--------------------------------------------------------------------------------
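The two-step control recipe above maps to a small amount of code. Below is a minimal, self-contained sketch of step 2 with dummy tensors: a linear probe (matching the `hidden_dims: []` classifier in `recipes/scoreperformer/base.yaml`) trained on 64-dimensional style vectors (32+20+8+4, the hierarchical latent sizes in the same recipe). None of the names are the repository's API.

```python
# Step 2 sketch: "label" the style space with a linear direction probe.
# All tensors are dummy placeholders, not the repository's API.
import torch
import torch.nn.functional as F

num_classes = 9                            # e.g. the dynamic/absolute group
style_embeddings = torch.randn(1024, 64)   # 32+20+8+4 hierarchical latents
labels = torch.randint(0, num_classes, (1024,))

probe = torch.nn.Linear(64, num_classes)   # hidden_dims: [] -> linear probe
optim = torch.optim.AdamW(probe.parameters(), lr=2e-4)

for _ in range(100):
    loss = F.cross_entropy(probe(style_embeddings), labels)
    optim.zero_grad()
    loss.backward()
    optim.step()

# The probe's predictions can then color a PCA plot of the style space,
# as in the figure referenced above.
predicted = probe(style_embeddings).argmax(dim=-1)
```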
/assets/perf-embs-clf-4x6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/assets/perf-embs-clf-4x6.png
--------------------------------------------------------------------------------
/assets/scoreperformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/assets/scoreperformer.png
--------------------------------------------------------------------------------
/data/directions/direction_classes.json:
--------------------------------------------------------------------------------
1 | {
2 | "dynamic/absolute": [
3 | "dynamic/ppp",
4 | "dynamic/pp",
5 | "dynamic/p",
6 | "dynamic/mp",
7 | "dynamic/mf",
8 | "dynamic/f",
9 | "dynamic/ff",
10 | "dynamic/fff",
11 | "dynamic/fp"
12 | ],
13 | "dynamic/hairpin": [
14 | "dynamic/crescendo",
15 | "dynamic/diminuendo"
16 | ],
17 | "dynamic/accent": [
18 | "dynamic/sf",
19 | "dynamic/rf"
20 | ],
21 | "tempo/absolute": [
22 | "tempo/grave",
23 | "tempo/largo",
24 | "tempo/larghetto",
25 | "tempo/lento",
26 | "tempo/adagio",
27 | "tempo/andante",
28 | "tempo/andantino",
29 | "tempo/moderato",
30 | "tempo/allegretto",
31 | "tempo/allegro",
32 | "tempo/vivace",
33 | "tempo/presto",
34 | "tempo/prestissimo"
35 | ],
36 | "tempo/relative": [
37 | "tempo/accelerando",
38 | "tempo/ritardando",
39 | "tempo/rallentando",
40 | "tempo/stringendo",
41 | "tempo/calando",
42 | "tempo/pi\u00f9 mosso",
43 | "tempo/animato",
44 | "tempo/stretto",
45 | "tempo/smorzando",
46 | "tempo/ritenuto"
47 | ],
48 | "articulation/arpeggiate": [
49 | "articulation/arpeggiate"
50 | ],
51 | "articulation/fermata": [
52 | "articulation/fermata"
53 | ],
54 | "articulation/staccato": [
55 | "articulation/staccato"
56 | ],
57 | "articulation/tenuto": [
58 | "articulation/tenuto"
59 | ]
60 | }
--------------------------------------------------------------------------------
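A short sketch of how these groups can be turned into per-group label maps for the direction classifiers. Reserving label 0 for "no direction" mirrors the collator convention in `scoreperformer/data/collators/score_performance.py` (zero entries in the `directions` tensor are inactive); the exact indexing here is an assumption, not the repository's code.

```python
import json

# Load the direction class groups shown above.
with open("data/directions/direction_classes.json") as f:
    direction_classes = json.load(f)

# Map each class name to a positive label, keeping 0 for "no direction"
# (an assumed convention, see the lead-in above).
label_maps = {
    group: {name: idx + 1 for idx, name in enumerate(names)}
    for group, names in direction_classes.items()
}

print(label_maps["dynamic/hairpin"])
# {'dynamic/crescendo': 1, 'dynamic/diminuendo': 2}
```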
/data/tokenizers/spmuple_window.json:
--------------------------------------------------------------------------------
1 | {
2 | "config": {
3 | "pitch_range": [
4 | 21,
5 | 109
6 | ],
7 | "beat_res": {
8 | "0_2": 16,
9 | "2_4": 8,
10 | "4_8": 4,
11 | "8_16": 2,
12 | "16_64": 1
13 | },
14 | "nb_velocities": 127,
15 | "special_tokens": [
16 | "PAD",
17 | "MASK",
18 | "SOS",
19 | "EOS"
20 | ],
21 | "use_chords": false,
22 | "use_rests": false,
23 | "use_tempos": true,
24 | "use_time_signatures": true,
25 | "use_sustain_pedals": false,
26 | "use_pitch_bends": false,
27 | "use_programs": false,
28 | "beat_res_rest": {
29 | "0_1": 8,
30 | "1_2": 4,
31 | "2_12": 2
32 | },
33 | "chord_maps": {
34 | "min": [
35 | 0,
36 | 3,
37 | 7
38 | ],
39 | "maj": [
40 | 0,
41 | 4,
42 | 7
43 | ],
44 | "dim": [
45 | 0,
46 | 3,
47 | 6
48 | ],
49 | "aug": [
50 | 0,
51 | 4,
52 | 8
53 | ],
54 | "sus2": [
55 | 0,
56 | 2,
57 | 7
58 | ],
59 | "sus4": [
60 | 0,
61 | 5,
62 | 7
63 | ],
64 | "7dom": [
65 | 0,
66 | 4,
67 | 7,
68 | 10
69 | ],
70 | "7min": [
71 | 0,
72 | 3,
73 | 7,
74 | 10
75 | ],
76 | "7maj": [
77 | 0,
78 | 4,
79 | 7,
80 | 11
81 | ],
82 | "7halfdim": [
83 | 0,
84 | 3,
85 | 6,
86 | 10
87 | ],
88 | "7dim": [
89 | 0,
90 | 3,
91 | 6,
92 | 9
93 | ],
94 | "7aug": [
95 | 0,
96 | 4,
97 | 8,
98 | 11
99 | ],
100 | "9maj": [
101 | 0,
102 | 4,
103 | 7,
104 | 10,
105 | 14
106 | ],
107 | "9min": [
108 | 0,
109 | 4,
110 | 7,
111 | 10,
112 | 13
113 | ]
114 | },
115 | "chord_tokens_with_root_note": false,
116 | "chord_unknown": null,
117 | "nb_tempos": 121,
118 | "tempo_range": [
119 | 15,
120 | 480
121 | ],
122 | "log_tempos": true,
123 | "delete_equal_successive_tempo_changes": true,
124 | "time_signature_range": {
125 | "2": [
126 | 1,
127 | 2,
128 | 3,
129 | 4
130 | ],
131 | "4": [
132 | 1,
133 | 2,
134 | 3,
135 | 4,
136 | 5,
137 | 6
138 | ],
139 | "8": [
140 | 1,
141 | 2,
142 | 3,
143 | 4,
144 | 5,
145 | 6,
146 | 7,
147 | 8,
148 | 9,
149 | 10,
150 | 11,
151 | 12
152 | ]
153 | },
154 | "delete_equal_successive_time_sig_changes": true,
155 | "sustain_pedal_duration": false,
156 | "pitch_bend_range": [
157 | -8192,
158 | 8191,
159 | 32
160 | ],
161 | "programs": [
162 | 0
163 | ],
164 | "one_token_stream_for_programs": true,
165 | "additional_params": {
166 | "nb_onset_devs": 161,
167 | "nb_perf_durations": 81,
168 | "max_bar_embedding": 256,
169 | "rel_onset_dev": true,
170 | "rel_perf_duration": true,
171 | "real_max_bar_embedding": 256,
172 | "fill_unperformed_notes": true,
173 | "remove_duplicates": false,
174 | "token_bins": {},
175 | "cut_overlapping_notes": true,
176 | "use_position_shifts": true,
177 | "onset_position_shifts": true,
178 | "use_onset_indices": true,
179 | "max_notes_in_onset": 12,
180 | "bar_tempos": false,
181 | "onset_tempos": false,
182 | "tempo_window": 8.0,
183 | "tempo_min_onset_dist": 0.5,
184 | "tempo_min_onsets": 8,
185 | "use_quantized_tempos": true,
186 | "decode_recompute_tempos": false,
187 | "limit_rel_onset_devs": true
188 | }
189 | },
190 | "one_token_stream": true,
191 | "has_bpe": false,
192 | "tokenization": "SPMupleWindow",
193 | "miditok_version": "2.1.6"
194 | }
--------------------------------------------------------------------------------
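The `beat_res` map follows miditok's `"start_end": samples-per-beat` convention: short durations get fine time resolution, long ones a coarse resolution. A quick sketch of the arithmetic this particular config implies (not the tokenizer's actual code):

```python
# Count the time bins implied by the beat_res config above.
beat_res = {"0_2": 16, "2_4": 8, "4_8": 4, "8_16": 2, "16_64": 1}

bins = 0
for rng, res in beat_res.items():
    start, end = map(int, rng.split("_"))
    bins += (end - start) * res  # beats in the range x samples per beat

print(bins)  # 128 bins covering 64 beats at progressively coarser resolution
```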
/recipes/default.yaml:
--------------------------------------------------------------------------------
1 | _general_:
2 | device: cuda:0
3 | seed: 23
4 |
5 | _dirname_: results
6 | _label_: ???
7 |
8 | output_dir: # either a full path or a list of path elements
9 | - ${_general_._dirname_}
10 | - ${model._name_}
11 | - ${model._version_}
12 | - ${date:}
13 | - ${_general_._label_}
14 |
15 | resume_from_checkpoint:
16 | warm_start: false
17 |
18 |
19 | data:
20 | dataset:
21 | _name_: ???
22 |
23 | collator:
24 | _name_: ???
25 |
26 |
27 | model:
28 | _name_: ???
29 | _version_: v0.1
30 |
31 |
32 | evaluator:
33 | _name_: ???
34 |
35 |
36 | trainer:
37 | # general
38 | output_dir: ${_general_.output_dir}
39 |
40 | do_train: true
41 | do_eval: true
42 | eval_mode: false
43 |
44 | device: ${_general_.device}
45 | seed: ${_general_.seed}
46 |
47 | # logging
48 | log_dir: logs
49 | log_to_file: true
50 |
51 | dashboard_logger: "tensorboard"
52 | log_strategy: steps
53 | log_steps: 5
54 | log_first_step: true
55 | log_raw_to_console: true
56 |
57 | disable_tqdm: false
58 | progress_steps: 5
59 | progress_metrics: ["loss"]
60 |
61 | ignore_data_skip: false
62 |
63 | # data
64 | num_workers: 4
65 | pin_memory: true
66 | shuffle: true # will shuffle training set only
67 |
68 | # training & evaluation
69 | epochs: 100
70 | max_steps: -1
71 | batch_size: 32
72 |
73 | eval_batch_size: 64
74 | eval_batches:
75 |
76 | eval_strategy: epoch # evaluation and checkpoint steps measure: `epoch` or `steps`
77 | eval_steps: 1 # how frequently (in epochs/steps) model evaluation should be performed
78 |
79 | optimization:
80 | lr: !!float 2e-4
81 | optimizer: adamw
82 | optimizer_params:
83 | weight_decay: !!float 1e-6
84 | lr_scheduler: exponential
85 | lr_scheduler_params:
86 | gamma: 0.995
87 | grad_clip: 2.0
88 | grad_accum_steps: 1
89 | mixed_precision: true
90 |
91 | # checkpointing
92 | save_strategy: epoch
93 | save_steps: 1
94 | save_optimizer: false
95 | save_best_only: true
96 | save_rewrite_checkpoint: false
97 |
98 | metric_for_best_model: loss
99 | metric_maximize: false
100 |
101 | resume_from_checkpoint: ${_general_.resume_from_checkpoint}
102 | warm_start: ${_general_.warm_start}
103 | ignore_layers: []
104 | ignore_mismatched_keys: true
105 | finetune_layers: []
106 | restore_lr: true
--------------------------------------------------------------------------------
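Recipes refer to each other through the `base:` key (e.g. `recipes/scoreperformer/base.yaml` starts with `base: default.yaml`) and rely on OmegaConf `${...}` interpolations and `???` mandatory values. The repository's actual loader (presumably in `scoreperformer/utils/config.py`) is not shown in this dump; the sketch below only illustrates one plausible way such inheritance can resolve.

```python
from omegaconf import OmegaConf

def load_recipe(path: str, recipes_dir: str = "recipes"):
    """Load a recipe YAML, recursively merging it on top of its `base:`."""
    cfg = OmegaConf.load(f"{recipes_dir}/{path}")
    base = cfg.pop("base", None)
    if base is not None:
        # Later (child) values override earlier (base) ones.
        cfg = OmegaConf.merge(load_recipe(base, recipes_dir), cfg)
    return cfg

cfg = load_recipe("scoreperformer/base.yaml")
# `???` values such as data.dataset.root stay mandatory until overridden.
```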
/recipes/scoreperformer/ablation/no_cont_tokens.yaml:
--------------------------------------------------------------------------------
1 | base: scoreperformer/base.yaml
2 |
3 | _general_:
4 | device: cuda:0
5 | seed: 23
6 |
7 | _dirname_: results
8 | _label_: ScorePerformer_noContTE
9 |
10 |
11 | data:
12 | dataset:
13 | root: ???
14 |
15 |
16 | model:
17 | perf_decoder:
18 | token_embeddings:
19 | discrete: true
20 | continuous: false
21 | continuous_dense: false
22 | discrete_ids:
23 |
--------------------------------------------------------------------------------
/recipes/scoreperformer/ablation/no_io_tie.yaml:
--------------------------------------------------------------------------------
1 | base: scoreperformer/base.yaml
2 |
3 | _general_:
4 | device: cuda:0
5 | seed: 23
6 |
7 | _dirname_: results
8 | _label_: ScorePerformer_noIOTie
9 |
10 |
11 | data:
12 | dataset:
13 | root: ???
14 |
15 |
16 | model:
17 | perf_decoder:
18 | lm_head:
19 | _target_: lm
20 |
--------------------------------------------------------------------------------
/recipes/scoreperformer/ablation/no_masked_seq.yaml:
--------------------------------------------------------------------------------
1 | base: scoreperformer/base.yaml
2 |
3 | _general_:
4 | device: cuda:0
5 | seed: 23
6 |
7 | _dirname_: results
8 | _label_: ScorePerformer_noMaskedSeq
9 |
10 |
11 | data:
12 | dataset:
13 | root: ???
14 |
15 |
16 | model:
17 | perf_decoder:
18 | token_embeddings:
19 | _target_: simple
20 | multiseq_mode:
21 | _disable_: true
22 |
--------------------------------------------------------------------------------
/recipes/scoreperformer/ablation/no_saln.yaml:
--------------------------------------------------------------------------------
1 | base: scoreperformer/base.yaml
2 |
3 | _general_:
4 | device: cuda:0
5 | seed: 23
6 |
7 | _dirname_: results
8 | _label_: ScorePerformer_noSALN
9 |
10 |
11 | data:
12 | dataset:
13 | root: ???
14 |
15 |
16 | model:
17 | perf_decoder:
18 | style_emb_mode: cat
19 |
--------------------------------------------------------------------------------
/recipes/scoreperformer/ablation/no_score_enc.yaml:
--------------------------------------------------------------------------------
1 | base: scoreperformer/base.yaml
2 |
3 | _general_:
4 | device: cuda:0
5 | seed: 23
6 |
7 | _dirname_: results
8 | _label_: ScorePerformer_noScoreEnc
9 |
10 |
11 | data:
12 | dataset:
13 | root: ???
14 |
15 |
16 | model:
17 | score_encoder:
18 | _disable_: true
19 |
--------------------------------------------------------------------------------
/recipes/scoreperformer/base.yaml:
--------------------------------------------------------------------------------
1 | base: default.yaml
2 |
3 | _general_:
4 | device: cuda:0
5 | seed: 23
6 |
7 | _dirname_: results
8 | _label_: ScorePerformer_base
9 |
10 | resume_from_checkpoint:
11 | warm_start: false
12 |
13 |
14 | data:
15 | dataset:
16 | _name_: LocalScorePerformanceDataset
17 | _splits_:
18 | train: train
19 | eval: eval
20 |
21 | root: ???
22 | use_alignments: false
23 | auxiliary_data_keys: ["bars", "initial_tempos"]
24 |
25 | performance_directions: ??? # data/directions/direction_classes.json
26 | score_directions_dict: ???
27 |
28 | max_seq_len: 256
29 | max_bar: 256
30 | bar_sliding_window: 16
31 |
32 | sample_bars: true
33 | sample_note_shift: 0.5
34 | force_max_seq_len: 0.5
35 |
36 | fit_to_max_bar: false # fit bar tokens to max bar (to avoid range errors during training)
37 | fit_to_zero_bar: true
38 | sample_bar_offset: false
39 |
40 | add_sos_eos: true # SOS/EOS tokens for whole sequences, not the sampled subsequences
41 |
42 | sample: true
43 | seed: ${_general_.seed}
44 |
45 | augment_performance: true
46 | pitch_shift_range: [-3, 3]
47 | velocity_shift_range: [-12, 12] # in velocity indices
48 | tempo_shift_range: [0, 0] # in tempo bins (2 ** (1 / 24))
49 |
50 | noisy_performance: false # compute performance style of the noisy performance
51 | noise_strength: 0.5 # multiplier for noisy performance augmentations
52 | noisy_random_bars: 0.5 # probability of sampling random bars
53 |
54 | deadpan_performance: 0.25 # probability of using deadpan performance
55 |
56 | zero_out_silent_durations: true # set durations of unperformed notes to 0
57 | delete_silent_notes: true
58 |
59 | preload: true
60 | cache: true
61 |
62 | collator:
63 | _name_: MixedLMScorePerformanceCollator
64 | mask_ignore_token_ids: [0, 1, 2, 3] # PAD, MASK, SOS, EOS
65 | mask_ignore_token_dims: [0, 1, 2, 4, 6, 7, 8, 9] # _nT: bar, position, pitch, duration, time_sig, pos_shift, notes_in_onset, pos_in_onset
66 |
67 |
68 | model:
69 | _name_: ScorePerformer
70 | _version_: v0.4.4
71 |
72 | dim: 256
73 | tie_token_emb: true
74 | mode: mixlm
75 |
76 | score_encoder:
77 | _disable_: false
78 |
79 | token_embeddings:
80 | _target_: simple
81 | emb_dims: ${model.perf_decoder.token_embeddings.emb_dims}
82 | mode: ${model.perf_decoder.token_embeddings.mode}
83 | emb_norm: ${model.perf_decoder.token_embeddings.emb_norm}
84 | discrete: ${model.perf_decoder.token_embeddings.discrete}
85 | continuous: ${model.perf_decoder.token_embeddings.continuous}
86 | continuous_dense: ${model.perf_decoder.token_embeddings.continuous_dense}
87 | discrete_ids: ${model.perf_decoder.token_embeddings.discrete_ids}
88 | tie_keys:
89 |
90 | emb_norm: ${model.perf_decoder.emb_norm}
91 | emb_dropout: ${model.perf_decoder.emb_dropout}
92 | use_abs_pos_emb: ${model.perf_decoder.use_abs_pos_emb}
93 |
94 | transformer:
95 | _target_: encoder
96 |
97 | depth: 2
98 | heads: 4
99 |
100 | attention: ${model.perf_decoder.transformer.attention}
101 | feed_forward: ${model.perf_decoder.transformer.feed_forward}
102 |
103 | perf_encoder:
104 | token_embeddings:
105 | _target_: simple
106 | emb_dims: ${model.perf_decoder.token_embeddings.emb_dims}
107 | mode: ${model.perf_decoder.token_embeddings.mode}
108 | emb_norm: ${model.perf_decoder.token_embeddings.emb_norm}
109 | discrete: ${model.perf_decoder.token_embeddings.discrete}
110 | continuous: ${model.perf_decoder.token_embeddings.continuous}
111 | continuous_dense: ${model.perf_decoder.token_embeddings.continuous_dense}
112 | discrete_ids: ${model.perf_decoder.token_embeddings.discrete_ids}
113 | tie_keys:
114 |
115 | emb_norm: true
116 | emb_dropout: 0
117 | use_abs_pos_emb: false
118 |
119 | latent_dim: [32, 20, 8, 4]
120 | aggregate_mode: [mean, bar_mean, beat_mean, onset_mean]
121 | latent_dropout: [0.0, 0.1, 0.2, 0.4]
122 |
123 | hierarchical: true
124 | inclusive_latent_dropout: true
125 | deadpan_zero_latent: true
126 | loss_weight: 1.
127 |
128 | transformer:
129 | _target_: encoder
130 |
131 | depth: 4
132 | heads: 4
133 |
134 | attention: ${model.perf_decoder.transformer.attention}
135 | feed_forward: ${model.perf_decoder.transformer.feed_forward}
136 |
137 | perf_decoder:
138 | token_embeddings:
139 | _target_: multi-seq
140 | emb_dims: 128
141 | mode: cat
142 | emb_norm: true
143 | discrete: false
144 | continuous: true
145 | continuous_dense: true
146 | discrete_ids: [0, 1, 2, 3]
147 | tie_keys:
148 | multiseq_mode: post-cat
149 |
150 | emb_norm: true
151 | emb_dropout: 0
152 | use_abs_pos_emb: false
153 |
154 | context_emb_mode: cat # (`cat` or `attention`)
155 | style_emb_dim: ${model.perf_encoder.latent_dim}
156 | style_emb_mode: adanorm # (`cat` or `adanorm`)
157 |
158 | transformer:
159 | _target_: decoder
160 |
161 | depth: 4
162 | heads: 4
163 |
164 | attention:
165 | dim_head: 64
166 | one_kv_head: true
167 | dropout: 0.1
168 |
169 | alibi_pos_bias: true
170 | alibi_learned: true
171 |
172 | feed_forward:
173 | mult: 4
174 | glu: true
175 | swish: true
176 | dropout: 0.1
177 |
178 | lm_head:
179 | _target_: lm-tied
180 |
181 | regression_head:
182 | _disable_: true
183 | regression_keys: ["Velocity", "Tempo", "RelOnsetDev", "RelPerfDuration"]
184 |
185 | classifiers:
186 | _disable_: false
187 | classifier:
188 | hidden_dims: []
189 | dropout: 0.2
190 | loss_weight: 1.
191 | weighted_classes: true
192 | detach_inputs: true
193 |
194 |
195 | evaluator:
196 | _name_: ScorePerformerEvaluator
197 | ignore_keys: ["Bar", "Position", "Pitch", "Duration", "TimeSig", "PositionShift", "NotesInOnset", "PositionInOnset"]
198 | weighted_distance: true
199 |
200 |
201 | trainer:
202 | epochs: 1000
203 | batch_size: 128
204 | eval_batch_size: 128
205 |
206 | progress_metrics: ["loss", "accuracy"]
207 |
208 | save_rewrite_checkpoint: true
209 |
210 | metric_for_best_model: accuracy
211 | metric_maximize: true
--------------------------------------------------------------------------------
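The `latent_dim`/`aggregate_mode`/`latent_dropout` triplets above define the hierarchical style encoder: a 32-d latent mean-pooled over the whole sequence, plus 20/8/4-d latents pooled per bar, beat, and onset. A rough sketch of segment-wise mean pooling with dummy tensors (the real implementation lives in `scoreperformer/models/scoreperformer/` and is more involved):

```python
import torch

def segment_mean(h: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:
    """Mean-pool hidden states over segments given per-token segment indices."""
    num_seg = int(segment_ids.max()) + 1
    sums = torch.zeros(num_seg, h.shape[-1]).index_add_(0, segment_ids, h)
    counts = torch.bincount(segment_ids, minlength=num_seg).clamp(min=1)
    return sums / counts.unsqueeze(-1)

h = torch.randn(12, 256)                # hidden states for 12 notes
bars = torch.tensor([0] * 6 + [1] * 6)  # bar index of each note

z_global = h.mean(dim=0, keepdim=True)  # `mean`
z_bar = segment_mean(h, bars)           # `bar_mean`; beat/onset work the same
```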
/recipes/scoreperformer/custom_hierarchy.yaml:
--------------------------------------------------------------------------------
1 | base: scoreperformer/base.yaml
2 |
3 | _general_:
4 | device: cuda:0
5 | seed: 23
6 |
7 | _dirname_: results
8 | _label_: ScorePerformer_custom
9 |
10 |
11 | data:
12 | dataset:
13 | root: ???
14 |
15 |
16 | model:
17 | perf_encoder:
18 | # z=64: base
19 | latent_dim: [32, 20, 8, 4]
20 | aggregate_mode: [mean, bar_mean, beat_mean, onset_mean]
21 | latent_dropout: [0.0, 0.1, 0.2, 0.4]
22 |
23 | # # z=64: -onset
24 | # latent_dim: [32, 20, 12]
25 | # aggregate_mode: [mean, bar_mean, beat_mean]
26 | # latent_dropout: [0.0, 0.1, 0.2]
27 |
28 | # # z=64: -onset, -beat
29 | # latent_dim: [32, 32]
30 | # aggregate_mode: [mean, bar_mean]
31 | # latent_dropout: [0.0, 0.1]
32 |
33 | # # z=64: -onset, -beat, -bar
34 | # latent_dim: [64]
35 | # aggregate_mode: [mean]
36 | # latent_dropout: [0.0]
37 |
--------------------------------------------------------------------------------
/recipes/scoreperformer/minimal.yaml:
--------------------------------------------------------------------------------
1 | base: scoreperformer/base.yaml
2 |
3 | _general_:
4 | device: cuda:0
5 | seed: 23
6 |
7 | _dirname_: results
8 | _label_: ScorePerformer_base
9 |
10 |
11 | data:
12 | dataset:
13 | root: ???
14 |
15 | performance_directions: ??? # data/directions/direction_classes.json
16 | score_directions_dict: ???
17 |
--------------------------------------------------------------------------------
/recipes/scoreperformer/no_classifiers.yaml:
--------------------------------------------------------------------------------
1 | base: scoreperformer/base.yaml
2 |
3 | _general_:
4 | device: cuda:0
5 | seed: 23
6 |
7 | _dirname_: results
8 | _label_: ScorePerformer_noClf
9 |
10 |
11 | data:
12 | dataset:
13 | root: ???
14 |
15 | # leave empty
16 | performance_directions:
17 | score_directions_dict:
18 |
19 | model:
20 | classifiers:
21 | _disable_: true
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.23,<1.24
2 | torch>=2.0.1
3 | miditok==2.1.6
4 | matplotlib
5 | tensorboard>=2.6.0
6 | omegaconf==2.3.0
7 | loguru==0.7.0
8 | tqdm>=4.63.0
9 | note_seq==0.0.5
10 | einops
11 | git+https://github.com/mac-marg-pianist/musicXML_parser
12 |
--------------------------------------------------------------------------------
/scoreperformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/scoreperformer/__init__.py
--------------------------------------------------------------------------------
/scoreperformer/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .collators import (
2 | LMPerformanceCollator,
3 | MixedLMPerformanceCollator,
4 | LMScorePerformanceCollator,
5 | MixedLMScorePerformanceCollator,
6 | )
7 | from .datasets import (
8 | PerformanceDataset,
9 | LocalScorePerformanceDataset
10 | )
11 |
12 | DATASETS = {name: cls for name, cls in globals().items() if ".datasets." in str(cls)}
13 | COLLATORS = {name: cls for name, cls in globals().items() if ".collators." in str(cls)}
14 |
--------------------------------------------------------------------------------
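The `globals()` scans above register every imported dataset and collator class under its class name, which is how the `_name_` keys in the recipes are resolved. A usage sketch, assuming the package is importable:

```python
from scoreperformer.data import DATASETS, COLLATORS

# Resolve the names used in recipes/scoreperformer/base.yaml.
dataset_cls = DATASETS["LocalScorePerformanceDataset"]
collator_cls = COLLATORS["MixedLMScorePerformanceCollator"]
# dataset = dataset_cls(root="...", split="train")  # kwargs come from the recipe
```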
/scoreperformer/data/collators/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import SeqInputs
2 | from .performance import (
3 | PerformanceInputs, PerformanceCollator,
4 | LMPerformanceInputs, LMPerformanceCollator,
5 | MixedLMPerformanceInputs, MixedLMPerformanceCollator
6 | )
7 | from .score_performance import (
8 | ScorePerformanceInputs, ScorePerformanceCollator,
9 | LMScorePerformanceInputs, LMScorePerformanceCollator,
10 | MixedLMScorePerformanceInputs, MixedLMScorePerformanceCollator
11 | )
12 |
--------------------------------------------------------------------------------
/scoreperformer/data/collators/common.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Union
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
8 | @dataclass
9 | class SeqInputs:
10 | tokens: Union[np.ndarray, torch.Tensor]
11 | mask: Union[np.ndarray, torch.Tensor]
12 | lengths: Union[np.ndarray, torch.Tensor]
13 |
--------------------------------------------------------------------------------
/scoreperformer/data/collators/directions.py:
--------------------------------------------------------------------------------
1 | """ Performance direction embeddings collators. """
2 |
3 | import torch
4 |
5 |
6 | class DirectionEmbeddingCollator:
7 | def __init__(
8 | self,
9 | num_embeddings: int = 1,
10 | embedding_dim: int = 64,
11 | ):
12 | self.num_embeddings = num_embeddings
13 | self.embedding_dim = embedding_dim
14 |
15 | def init_data(self, batch):
16 | embeddings = torch.zeros(len(batch), self.num_embeddings, self.embedding_dim)
17 | labels = torch.zeros(len(batch), dtype=torch.long)
18 | return embeddings, labels
19 |
20 | def process_sample(self, i, sample, data):
21 | embeddings, labels = data
22 | _, emb, label = sample
23 |
24 | emb = emb.unsqueeze(0) if emb.ndim == 1 else emb
25 | embeddings[i, -emb.shape[0]:] = emb
26 | labels[i] = label
27 |
28 | def __call__(self, batch, inference=False, return_tensors=True):
29 | data = self.init_data(batch)
30 | for i, sample in enumerate(batch):
31 | self.process_sample(i, sample, data)
32 |
33 | return {'embeddings': data[0], 'labels': data[1]}
34 |
--------------------------------------------------------------------------------
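A usage sketch for `DirectionEmbeddingCollator`: per `process_sample`, each sample is a `(meta, embedding, label)` triple whose first element is ignored. The dummy values below are illustrative only.

```python
import torch

from scoreperformer.data.collators.directions import DirectionEmbeddingCollator

collator = DirectionEmbeddingCollator(num_embeddings=1, embedding_dim=64)
batch = [(None, torch.randn(64), 3), (None, torch.randn(64), 7)]

out = collator(batch)
print(out["embeddings"].shape)  # torch.Size([2, 1, 64])
print(out["labels"])            # tensor([3, 7])
```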
/scoreperformer/data/collators/score_performance.py:
--------------------------------------------------------------------------------
1 | """ Score-Performance data collators. """
2 |
3 | from dataclasses import dataclass
4 | from typing import Optional, List, Union, Dict
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from .common import SeqInputs
10 | from .performance import (
11 | PerformanceCollator,
12 | LMPerformanceCollator,
13 | MixedLMPerformanceCollator
14 | )
15 | from ..datasets.score_performance import ScorePerformanceSample
16 |
17 |
18 | @dataclass
19 | class SeqSegments:
20 | bar: Optional[Union[np.ndarray, torch.Tensor]] = None
21 | beat: Optional[Union[np.ndarray, torch.Tensor]] = None
22 | onset: Optional[Union[np.ndarray, torch.Tensor]] = None
23 |
24 |
25 | @dataclass
26 | class ScorePerformanceInputs:
27 | scores: SeqInputs
28 | performances: SeqInputs
29 | noisy_performances: Optional[SeqInputs] = None
30 | segments: Optional[SeqSegments] = None
31 | directions: Optional[Union[Dict[str, torch.Tensor], torch.Tensor]] = None
32 | deadpan_mask: Optional[torch.Tensor] = None
33 |
34 |
35 | class ScorePerformanceCollator(PerformanceCollator):
36 | def __init__(
37 | self,
38 | pad_token_id: int = 0,
39 | pad_to_multiple_of: int = 1
40 | ):
41 | super().__init__(pad_token_id, pad_to_multiple_of)
42 |
43 | def get_max_lengths(self, batch: List[ScorePerformanceSample], inference: bool = False):
44 | max_lens = super().get_max_lengths(batch, inference=inference)
45 |
46 | lens_score = np.array(list(map(lambda sample: len(sample.score), batch))).T
47 | max_lens['score'] = self.pad_len(np.max(lens_score))
48 |
49 | if all((sample.noisy_perf is not None for sample in batch)):
50 | lens_noisy_perf = np.array(list(map(lambda sample: len(sample.noisy_perf), batch))).T
51 | max_lens['noisy_perf'] = self.pad_len(np.max(lens_noisy_perf))
52 |
53 | return max_lens
54 |
55 | def init_data(self, batch: List[ScorePerformanceSample], inference: bool = False):
56 | data = super().init_data(batch, inference=inference)
57 | max_lens = self.get_max_lengths(batch, inference=inference)
58 |
59 | sample, batch_size = batch[0], len(batch)
60 | return ScorePerformanceInputs(
61 | scores=self._init_seq_data(
62 | batch_size, max_lens['score'],
63 | compound_factor=sample.score.shape[-1]
64 | ),
65 | performances=data.performances,
66 | noisy_performances=self._init_seq_data(
67 | batch_size, max_lens['noisy_perf'],
68 | compound_factor=sample.noisy_perf.shape[-1]
69 | ) if 'noisy_perf' in max_lens else None,
70 | segments=SeqSegments(
71 | bar=torch.zeros(len(batch), max_lens['score'], dtype=torch.long),
72 | beat=torch.zeros(len(batch), max_lens['score'], dtype=torch.long),
73 | onset=torch.zeros(len(batch), max_lens['score'], dtype=torch.long)
74 | ) if sample.segments is not None else None,
75 | directions=torch.zeros(
76 | batch_size, max_lens['score'], len(sample.directions), dtype=torch.long
77 | ) if sample.directions is not None else None,
78 | deadpan_mask=torch.zeros(batch_size, dtype=torch.bool)
79 | )
80 |
81 | def process_sample(self, i: int, sample: ScorePerformanceSample, data: ScorePerformanceInputs,
82 | inference: bool = False):
83 | # process performance
84 | super().process_sample(i, sample, data, inference=inference)
85 |
86 | # process score
87 | self._process_sequence(i, seq=sample.score, seq_data=data.scores)
88 |
89 | # process noisy performance if present
90 | if sample.noisy_perf is not None:
91 | self._process_sequence(i, seq=sample.noisy_perf, seq_data=data.noisy_performances)
92 |
93 | # process note segments if present
94 | seq_len = len(sample.score)
95 | if sample.segments is not None:
96 | data.segments.bar[i, :seq_len] = torch.from_numpy(sample.segments.bar)
97 | data.segments.beat[i, :seq_len] = torch.from_numpy(sample.segments.beat)
98 | data.segments.onset[i, :seq_len] = torch.from_numpy(sample.segments.onset)
99 |
100 | # process directions if present
101 | if sample.directions is not None:
102 | for j, (group_name, group_directions) in enumerate(sample.directions.items()):
103 | for (label, key), direction_map in group_directions.items():
104 | mask = direction_map != 0.
105 | if np.any(mask):
106 | data.directions[i, :seq_len, j][mask] = label * torch.from_numpy(direction_map[mask])
107 |
108 | data.deadpan_mask[i] = sample.is_deadpan
109 |
110 | def __call__(self, batch: List[ScorePerformanceSample], inference: bool = False, return_tensors: bool = True):
111 | data = self.init_data(batch, inference=inference)
112 | for i, sample in enumerate(batch):
113 | self.process_sample(i, sample, data, inference=inference)
114 |
115 | return data
116 |
117 |
118 | # FOR LANGUAGE MODELING
119 | @dataclass
120 | class LMScorePerformanceInputs(ScorePerformanceInputs):
121 | labels: Optional[SeqInputs] = None
122 |
123 |
124 | class LMScorePerformanceCollator(ScorePerformanceCollator, LMPerformanceCollator):
125 | def __init__(
126 | self,
127 | pad_token_id: int = 0,
128 | pad_to_multiple_of: int = 1,
129 |
130 | mlm: bool = False,
131 | mask_prob: float = 0.15,
132 | replace_prob: float = 0.9,
133 | mask_token_id: int = 1,
134 | mask_ignore_token_ids: Optional[List[int]] = None,
135 | mask_ignore_token_dims: Optional[List[int]] = None,
136 | label_pad_ignored_dims: bool = True,
137 | label_pad_token_id: int = -100
138 | ):
139 | LMPerformanceCollator.__init__(
140 | self,
141 | pad_token_id=pad_token_id,
142 | pad_to_multiple_of=pad_to_multiple_of,
143 | mlm=mlm,
144 | mask_prob=mask_prob,
145 | replace_prob=replace_prob,
146 | mask_token_id=mask_token_id,
147 | mask_ignore_token_ids=mask_ignore_token_ids,
148 | mask_ignore_token_dims=mask_ignore_token_dims,
149 | label_pad_ignored_dims=label_pad_ignored_dims,
150 | label_pad_token_id=label_pad_token_id
151 | )
152 |
153 | def __call__(self, batch: List[ScorePerformanceSample], inference: bool = False, return_tensors: bool = True):
154 | data = super().__call__(batch, inference=inference)
155 |
156 | if self.mlm:
157 | masked_seq, labels, label_mask = self.mask_sequence(data.performances.tokens)
158 | data.performances.tokens = masked_seq
159 | else:
160 | labels = data.performances.tokens.clone().detach()
161 | labels[labels == self.pad_token_id] = self.label_pad_token_id
162 | label_mask = data.performances.mask.clone().detach()
163 |
164 | data = LMScorePerformanceInputs(
165 | scores=data.scores,
166 | performances=data.performances,
167 | noisy_performances=data.noisy_performances,
168 | segments=data.segments,
169 | directions=data.directions,
170 | deadpan_mask=data.deadpan_mask,
171 | labels=SeqInputs(
172 | tokens=labels,
173 | mask=label_mask,
174 | lengths=data.performances.lengths
175 | )
176 | )
177 |
178 | return data
179 |
180 |
181 | @dataclass
182 | class MixedLMScorePerformanceInputs(LMScorePerformanceInputs):
183 | masked_performances: Optional[SeqInputs] = None
184 |
185 |
186 | class MixedLMScorePerformanceCollator(ScorePerformanceCollator, MixedLMPerformanceCollator):
187 | def __init__(
188 | self,
189 | pad_token_id: int = 0,
190 | pad_to_multiple_of: int = 1,
191 |
192 | mask_token_id: int = 1,
193 | mask_ignore_token_ids: Optional[List[int]] = None,
194 | mask_ignore_token_dims: Optional[List[int]] = None,
195 | label_pad_ignored_dims: bool = True,
196 | label_pad_token_id: int = -100
197 | ):
198 | MixedLMPerformanceCollator.__init__(
199 | self,
200 | pad_token_id=pad_token_id,
201 | pad_to_multiple_of=pad_to_multiple_of,
202 | mask_token_id=mask_token_id,
203 | mask_ignore_token_ids=mask_ignore_token_ids,
204 | mask_ignore_token_dims=mask_ignore_token_dims,
205 | label_pad_ignored_dims=label_pad_ignored_dims,
206 | label_pad_token_id=label_pad_token_id
207 | )
208 |
209 | def __call__(self, batch: List[ScorePerformanceSample], inference: bool = False, return_tensors: bool = True):
210 | data = super().__call__(batch, inference=inference)
211 |
212 | masked_performances, labels = self.mask_sequence(data.performances.tokens)
213 | label_mask = data.performances.mask.clone().detach()
214 |
215 | data = MixedLMScorePerformanceInputs(
216 | scores=data.scores,
217 | performances=data.performances,
218 | noisy_performances=data.noisy_performances,
219 | segments=data.segments,
220 | directions=data.directions,
221 | deadpan_mask=data.deadpan_mask,
222 | masked_performances=SeqInputs(
223 | tokens=masked_performances,
224 | mask=label_mask,
225 | lengths=data.performances.lengths
226 | ),
227 | labels=SeqInputs(
228 | tokens=labels,
229 | mask=label_mask,
230 | lengths=data.performances.lengths
231 | )
232 | )
233 |
234 | return data
235 |
--------------------------------------------------------------------------------
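A wiring sketch for the mixed-LM collator using the values from `recipes/scoreperformer/base.yaml`; the token masking itself (`mask_sequence`) is inherited from `MixedLMPerformanceCollator` in `performance.py`, which this dump does not include.

```python
from scoreperformer.data.collators import MixedLMScorePerformanceCollator

collator = MixedLMScorePerformanceCollator(
    pad_token_id=0,
    mask_token_id=1,
    mask_ignore_token_ids=[0, 1, 2, 3],               # PAD, MASK, SOS, EOS
    mask_ignore_token_dims=[0, 1, 2, 4, 6, 7, 8, 9],  # score-side token dims
    label_pad_token_id=-100,  # the default ignore_index of F.cross_entropy
)
# inputs = collator(batch)  # batch: List[ScorePerformanceSample]
```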
/scoreperformer/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .directions import DirectionBarEmbeddingDataset
2 | from .performance import (
3 | PerformanceDataset,
4 | PerformanceSampleMeta,
5 | PerformanceSample
6 | )
7 | from .score_performance import (
8 | ScorePerformanceDataset,
9 | LocalScorePerformanceDataset,
10 | ScorePerformanceSampleMeta,
11 | ScorePerformanceSample,
12 | NoteSegments,
13 | SequenceTypes
14 | )
15 | from .token_sequence import (
16 | TokenSequenceDataset,
17 | LocalTokenSequenceDataset
18 | )
19 |
--------------------------------------------------------------------------------
/scoreperformer/data/datasets/performance.py:
--------------------------------------------------------------------------------
1 | """ Performance token sequence dataset. """
2 |
3 | import copy
4 | import os
5 | from dataclasses import dataclass
6 | from functools import partial
7 | from typing import Optional, Tuple
8 |
9 | import numpy as np
10 | from torch.utils.data import Dataset
11 |
12 | from scoreperformer.utils import apply, load_json
13 | from .token_sequence import load_token_sequence, LocalTokenSequenceDataset
14 | from .utils import load_tokens_np, get_num_bars, compute_bar_sample_positions, get_end_bar
15 | from ..helpers import (
16 | TOKEN_SEQUENCE_PROCESSORS,
17 | TOKEN_SEQUENCE_INDEXERS,
18 | TokenSequenceAugmentations
19 | )
20 | from ..tokenizers import TOKENIZERS
21 |
22 |
23 | @dataclass
24 | class PerformanceSampleMeta:
25 | idx: Optional[int]
26 | perf_idx: int
27 | start_bar: int
28 | end_bar: Optional[int]
29 | bar_offset: int = 0
30 | augmentations: Optional[TokenSequenceAugmentations] = None
31 |
32 |
33 | @dataclass
34 | class PerformanceSample:
35 | perf: np.ndarray
36 | meta: PerformanceSampleMeta
37 |
38 |
39 | class PerformanceDataset(Dataset):
40 | def __init__(
41 | self,
42 | root: str,
43 | split: str = 'train',
44 | encoding: str = 'OctupleM',
45 |
46 | max_seq_len: int = 512,
47 | max_bar: int = 256,
48 | bar_sliding_window: int = 16,
49 |
50 | fit_to_max_bar: bool = False,
51 | fit_to_zero_bar: bool = False,
52 | sample_bars: bool = False,
53 |
54 | add_sos_eos: bool = False,
55 |
56 | sample: bool = False,
57 | seed: int = 23,
58 |
59 | augment_performance: bool = False,
60 | pitch_shift_range: Tuple[int, int] = (-3, 3),
61 | velocity_shift_range: Tuple[int, int] = (-2, 2),
62 | tempo_shift_range: Tuple[int, int] = (-2, 2),
63 |
64 | cache: bool = True,
65 | **kwargs
66 | ):
67 |
68 | self.root = root
69 | self.split = split
70 |
71 | # load metadata
72 | metadata_file = os.path.join(self.root, 'metadata.json')
73 | metadata = load_json(metadata_file)
74 |
75 | if all(key not in metadata for key in ['all', 'train', 'eval', 'val', 'test']):
76 | self.metadata = metadata # metadata of names
77 | else:
78 | self.metadata = metadata[self.split]
79 |
80 | self.performance_names = list(self.metadata)
81 |
82 | # load tokenizer
83 | if encoding not in TOKENIZERS:
84 | raise ValueError(f"Encoding {encoding} is not a valid encoding, "
85 | f"supported types are: {list(TOKENIZERS.keys())}")
86 |
87 | self.encoding = encoding
88 | self.tokenizer = TOKENIZERS[encoding](params=os.path.join(self.root, 'config.json'))
89 |
90 | # load sequences
91 | load_tokens = partial(load_tokens_np, tokenizer=self.tokenizer)
92 | perf_load_fn = partial(load_token_sequence, load_fn=load_tokens)
93 | self.performances = LocalTokenSequenceDataset(
94 | root=self.root,
95 | files=self.performance_names,
96 | load_fn=perf_load_fn,
97 | cache=cache
98 | )
99 |
100 | # configurations
101 | self.max_seq_len = max_seq_len
102 | self.max_bar = max_bar
103 | self.bar_sliding_window = bar_sliding_window
104 | self.add_sos_eos = add_sos_eos
105 | assert max_bar <= self.tokenizer.max_bar_embedding
106 |
107 | # bar indexer and indices arrays
108 | self.indexer = TOKEN_SEQUENCE_INDEXERS[encoding](self.tokenizer)
109 | self._bar_indices = [None] * len(self.performances)
110 |
111 | # load or compute number of bars in performances used to build samples
112 | bars_file = os.path.join(self.root, 'bars.json')
113 | if os.path.exists(bars_file):
114 | _num_bars = load_json(bars_file)
115 | _perf_num_bars = np.array([_num_bars[perf] for perf in self.performance_names])
116 | else:
117 | _perf_num_bars = np.array(apply(self.performances, partial(get_num_bars, tokenizer=self.tokenizer)))
118 |
119 | # compute sample positions
120 | self._length, self._sample_positions, self._sample_ids = compute_bar_sample_positions(
121 | seq_num_bars=_perf_num_bars, bar_sliding_window=self.bar_sliding_window
122 | )
123 |
124 | # random sampling
125 | self.sample = sample
126 | if self.sample:
127 | np.random.seed(seed)
128 |
129 | # bar sampling
130 | assert not (fit_to_max_bar and fit_to_zero_bar), \
131 | "Only one of `fit_to_max_bar`/`fit_to_zero_bar` could be set to True"
132 | self.fit_to_max_bar = fit_to_max_bar
133 | self.fit_to_zero_bar = fit_to_zero_bar
134 | self.sample_bars = sample and sample_bars
135 |
136 | # augmentations
137 | self.augment_performance = sample and augment_performance
138 |
139 | if not self.augment_performance:
140 | pitch_shift_range = velocity_shift_range = tempo_shift_range = (0, 0)
141 |
142 | # sequence processor
143 | self.processor = TOKEN_SEQUENCE_PROCESSORS[encoding](
144 | tokenizer=self.tokenizer,
145 | pitch_shift_range=pitch_shift_range,
146 | velocity_shift_range=velocity_shift_range,
147 | tempo_shift_range=tempo_shift_range,
148 | )
149 |
150 | def _get_augmentations(self, meta):
151 | if meta is None:
152 | if self.augment_performance:
153 | return self.processor.sample_augmentations()
154 | else:
155 | return None
156 | else:
157 | return meta.augmentations
158 |
159 | def _augment_sequence(self, seq, augmentations):
160 | if augmentations is None:
161 | return seq
162 |
163 | seq = self.processor.augment_sequence(seq, augmentations)
164 | mask = self.processor.compute_valid_pitch_mask(seq)
165 | return seq[mask]
166 |
167 | def get(self, idx: Optional[int] = None, meta: Optional[PerformanceSampleMeta] = None):
168 | assert idx is not None or meta is not None, 'one of `idx`/`meta` should be provided as an argument'
169 |
170 | # get performance
171 | if meta is None:
172 | perf_idx = np.where(idx >= self._sample_ids)[0][-1]
173 | else:
174 | idx, perf_idx = meta.idx, meta.perf_idx
175 |
176 | bar_indices = self._bar_indices[perf_idx]
177 | if bar_indices is None:
178 | bar_indices = self._bar_indices[perf_idx] = self.indexer.compute_bar_indices(self.performances[perf_idx])
179 |
180 | total_bars = bar_indices.shape[0] - 1
181 |
182 | # compute start bar index
183 | if meta is None:
184 | start_bar = self._sample_positions[idx]
185 | start_bar = min(start_bar, bar_indices.shape[0] - self.bar_sliding_window // 2) # bars of silent notes
186 | if self.sample:
187 | low = max(0, start_bar - self.bar_sliding_window // 2)
188 | high = min(total_bars - self.bar_sliding_window // 4, start_bar + self.bar_sliding_window // 2)
189 | high = max(low + 1, high)
190 | start_bar = np.random.randint(low, high)
191 | else:
192 | start_bar = meta.start_bar
193 |
194 | # compute start index
195 | perf_start = bar_indices[start_bar]
196 |
197 | # compute end bar index
198 | if meta is None or meta.end_bar is None:
199 | end_bar = get_end_bar(bar_indices, start_bar, self.max_seq_len, self.max_bar)
200 | else:
201 | end_bar = meta.end_bar
202 |
203 | # compute end index
204 | perf_end = bar_indices[end_bar + 1]
205 |
206 | # get token sequences
207 | perf_seq = copy.copy(self.performances[perf_idx][perf_start:perf_end])
208 |
209 | min_bar = perf_seq[:, 0].min() - self.tokenizer.zero_token
210 | max_bar = perf_seq[:, 0].max() - self.tokenizer.zero_token
211 |
212 | # shift bar indices
213 | bar_offset = 0
214 | if meta is None:
215 | if self.fit_to_max_bar:
216 | # to make bar index distribute in [0, bar_max)
217 | if self.sample_bars:
218 | bar_offset = np.random.randint(-min_bar, self.max_bar - max_bar)
219 | elif end_bar >= self.max_bar:
220 | # move in proportion to `score_total_bars`
221 | _end_bar = int((self.max_bar - 1) * max_bar / total_bars)
222 | bar_offset = _end_bar - max_bar
223 | elif self.fit_to_zero_bar:
224 | bar_offset = -min_bar
225 | else:
226 | bar_offset = meta.bar_offset
227 |
228 | if bar_offset != 0:
229 | perf_seq[:, self.tokenizer.vocab_types_idx['Bar']] += bar_offset
230 |
231 | # augmentations
232 | augmentations = self._get_augmentations(meta)
233 | perf_seq = self._augment_sequence(perf_seq, augmentations)
234 |
235 | if self.add_sos_eos:
236 | if start_bar == 0:
237 | perf_seq = self.processor.add_sos_token(perf_seq)
238 | if end_bar + 1 == total_bars:
239 | perf_seq = self.processor.add_eos_token(perf_seq)
240 |
241 | # build sample metadata
242 | meta = PerformanceSampleMeta(
243 | idx=idx,
244 | perf_idx=perf_idx,
245 | start_bar=start_bar,
246 | end_bar=end_bar,
247 | bar_offset=bar_offset,
248 | augmentations=augmentations
249 | )
250 |
251 | return PerformanceSample(
252 | perf=perf_seq,
253 | meta=meta
254 | )
255 |
256 | def __getitem__(self, idx: int):
257 | return self.get(idx=idx)
258 |
259 | def __len__(self):
260 | return self._length
261 |
--------------------------------------------------------------------------------
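The global-index bookkeeping in `get()` deserves a toy example: `_sample_ids` stores the first global sample index of each performance, so `np.where(idx >= self._sample_ids)[0][-1]` picks the performance that owns a given index. The numbers below are made up.

```python
import numpy as np

sample_ids = np.array([0, 5, 12])  # performances contribute 5, 7, ... samples
for idx in (0, 4, 5, 13):
    perf_idx = np.where(idx >= sample_ids)[0][-1]
    print(idx, "->", perf_idx)     # 0 -> 0, 4 -> 0, 5 -> 1, 13 -> 2
```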
/scoreperformer/data/datasets/token_sequence.py:
--------------------------------------------------------------------------------
1 | """ Token sequence datasets. """
2 |
3 | import os
4 | from pathlib import Path, PurePath
5 |
6 | from torch.utils.data import Dataset
7 |
8 | from scoreperformer.utils import apply, load_json
9 |
10 |
11 | def load_token_sequence(path, load_fn, processing_funcs=None):
12 | seq = load_fn(path)
13 | if processing_funcs:
14 | for func in processing_funcs:
15 | seq = func(seq)
16 | return seq
17 |
18 |
19 | class TokenSequenceDataset(Dataset):
20 | def __init__(self, sequences, names=None):
21 | self.seqs = sequences
22 |
23 | self.names = names
24 | if names is not None:
25 | self._name_to_idx = {name: idx for idx, name in enumerate(self.names)}
26 |
27 | def __getitem__(self, idx):
28 | seq = self.seqs[idx]
29 | return seq[0] if isinstance(seq, tuple) else seq
30 |
31 | def __len__(self):
32 | return len(self.seqs)
33 |
34 |
35 | class LocalTokenSequenceDataset(TokenSequenceDataset):
36 | def __init__(self, root, files=None, suffix='.json', load_fn=load_json, preload=False, cache=False):
37 | self.root = root
38 | self.load_fn = load_fn
39 |
40 | if files is None:
41 | if os.path.isfile(root) and root.lower().endswith(suffix):
42 | files = [Path(root)]
43 | else:
44 | files = list(Path(root).glob('**/*' + suffix))
45 | files = list(map(Path, files))
46 | else:
47 | files = list(map(lambda x: Path(x).with_suffix(suffix), files))
48 |
49 | paths = [PurePath(os.path.join(self.root, file)) for file in files]
50 |
51 | self.paths = paths
52 |
53 | self._cache = cache
54 |
55 | self.seqs = self.load_sequences(preload=preload)
56 | names = [str(file).replace(suffix, '') for file in files]
57 |
58 | super().__init__(sequences=self.seqs, names=names)
59 |
60 | def load_sequence(self, path):
61 | return self.load_fn(path)
62 |
63 | def load_sequences(self, preload):
64 | if preload:
65 | return apply(self.paths, func=self.load_sequence, desc='Loading token sequences...')
66 | else:
67 | return [None] * len(self.paths)
68 |
69 | def __getitem__(self, idx):
70 | if self.seqs[idx] is None:
71 | seq = self.load_sequence(self.paths[idx])
72 | if self._cache:
73 | self.seqs[idx] = seq
74 | else:
75 | seq = self.seqs[idx]
76 | return seq[0] if isinstance(seq, tuple) else seq
77 |
78 | def __len__(self):
79 | return len(self.seqs)
80 |
--------------------------------------------------------------------------------
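A usage sketch for the lazy-loading dataset; the `data/tokens` directory is a hypothetical example path.

```python
from scoreperformer.data.datasets.token_sequence import LocalTokenSequenceDataset

dataset = LocalTokenSequenceDataset(root="data/tokens", suffix=".json", cache=True)
seq = dataset[0]         # loaded from disk on first access, cached afterwards
name = dataset.names[0]  # file name without the suffix
```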
/scoreperformer/data/datasets/utils.py:
--------------------------------------------------------------------------------
1 | """ Common datasets' utils. """
2 |
3 | import random
4 | from typing import Optional, Dict, List
5 |
6 | import numpy as np
7 |
8 | from ..tokenizers import OctupleM, TokSequence
9 |
10 |
11 | def load_tokens_data(path, tokenizer):
12 | data = tokenizer.load_tokens(path)
13 | if isinstance(data, list): # backward compatibility with old datasets (miditok<=1.2.3)
14 | data = {'ids': data[0], 'programs': data[1] if len(data) > 1 else []}
15 | elif 'ids' not in data: # backward compatibility with old datasets (miditok<=2.0.0)
16 | data['ids'] = data['tokens']
17 | del data['tokens']
18 | return data
19 |
20 |
21 | def load_tokens_np(path, tokenizer):
22 | return np.array(load_tokens_data(path, tokenizer)['ids'])
23 |
24 |
25 | def load_token_sequence(path, tokenizer):
26 | data = load_tokens_data(path, tokenizer)
27 | return TokSequence(ids=data['ids'], meta=data.get('meta', {}))
28 |
29 |
30 | def get_num_bars(seq, tokenizer):
31 | if isinstance(tokenizer, OctupleM):
32 | bar_idx = tokenizer.vocab_types_idx['Bar']
33 | return int(seq[-1, bar_idx] - tokenizer.zero_token + 1)
34 | else:
35 | raise ValueError(f"Unsupported tokenizer: {tokenizer.__class__.__name__}")
36 |
37 |
38 | def compute_bar_sample_positions(seq_num_bars, bar_sliding_window):
39 | bar_shift = bar_sliding_window
40 | length, sample_positions = 0, []
41 | for num_bars in seq_num_bars:
42 | back_shift = -bar_shift // 4 if (num_bars - bar_shift // 2) % bar_shift == 0 else 0
43 | positions = np.concatenate([
44 | np.arange(0, num_bars - bar_shift // 2, bar_shift),
45 | np.arange(num_bars - bar_shift // 2 - back_shift, -1 + bar_shift // 2, -bar_shift)
46 | ])
47 | length += len(positions)
48 | sample_positions.append(positions)
49 |
50 | sample_ids = np.concatenate([[0], np.cumsum(list(map(len, sample_positions)))[:-1]])
51 | sample_positions = np.concatenate(sample_positions)
52 |
53 | return length, sample_positions, sample_ids
54 |
55 |
56 | def get_end_bar(score_indices, start_bar=0, max_seq_len=512, max_bar=256):
57 | end_bar = np.where(score_indices <= score_indices[start_bar] + max_seq_len)[0][-1] - 1
58 | return min(max(start_bar, end_bar), start_bar + max_bar - 1)
59 |
60 |
61 | def split_composer_metadata(
62 | reference_metadata: Dict[str, List[str]],
63 | splits: Dict[str, float],
64 | seed: Optional[int] = None
65 | ):
66 | if seed is not None:
67 | random.seed(seed)
68 | np.random.seed(seed)
69 |
70 | data_ = {split: dict() for split in splits}
71 |
72 | for comp, score_perf in reference_metadata.items():
73 | comp_meta_rep = []
74 |
75 | score_perf = list(score_perf.items())
76 | np.random.shuffle(score_perf)
77 | score_perf = dict(score_perf)
78 |
79 | for score, perfs in score_perf.items():
80 | comp_meta_rep.extend([score] * len(perfs))
81 |
82 | if len(comp_meta_rep) > 10:
83 | start = 0
84 | for i, (split, ratio) in enumerate(splits.items()):
85 | end = min(len(comp_meta_rep), start + round(ratio * len(comp_meta_rep)))
86 |
87 | if i == len(splits) - 1:
88 | end = len(comp_meta_rep)
89 |
90 | if end < len(comp_meta_rep) and comp_meta_rep[end - 1] == comp_meta_rep[len(comp_meta_rep) - 1]:
91 | while end > 0 and comp_meta_rep[end] == comp_meta_rep[end - 1]:
92 | end -= 1
93 | else:
94 | while end < len(comp_meta_rep) and comp_meta_rep[end - 1] == comp_meta_rep[end]:
95 | end += 1
96 |
97 | split_scores = np.unique(comp_meta_rep[start:end]).tolist()
98 | for score in split_scores:
99 | data_[split][score] = score_perf[score]
100 | start = end
101 | else:
102 | for score, perfs in score_perf.items():
103 | _split = np.random.choice(np.array(list(splits.keys())), p=np.array(list(splits.values())))
104 | data_[_split][score] = perfs
105 |
106 | for _split in data_:
107 | data_[_split] = dict(sorted(data_[_split].items()))
108 |
109 | return data_
110 |
--------------------------------------------------------------------------------
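A toy run of `compute_bar_sample_positions`, which places forward windows every `bar_sliding_window` bars and adds backward windows so that sequence tails are also covered (assuming the package is importable):

```python
from scoreperformer.data.datasets.utils import compute_bar_sample_positions

length, positions, sample_ids = compute_bar_sample_positions(
    seq_num_bars=[40, 12], bar_sliding_window=16
)
print(length)      # total number of samples over both sequences
print(positions)   # start bars, forward then backward, per sequence
print(sample_ids)  # first global sample index of each sequence
```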
/scoreperformer/data/directions/__init__.py:
--------------------------------------------------------------------------------
1 | from .articulation import ARTICULATION_PREFIX, ARTICULATION_KEYS
2 | from .dynamic import DYNAMIC_PREFIX, DYNAMIC_KEYS, ABS_DYNAMIC_KEYS, REL_DYNAMIC_KEYS
3 | from .parser import parse_directions
4 | from .tempo import TEMPO_PREFIX, TEMPO_KEYS, ABS_TEMPO_KEYS, REL_TEMPO_KEYS, RET_TEMPO_KEYS
5 | from .words import extract_main_keyword
6 |
7 |
8 | def build_prefixed_keys(keys, prefix):
9 | return list(map(lambda d: f'{prefix}/' + extract_main_keyword(d), keys))
10 |
11 |
12 | DYNAMIC_DIRECTION_KEYS = build_prefixed_keys(DYNAMIC_KEYS, DYNAMIC_PREFIX)
13 | TEMPO_DIRECTION_KEYS = build_prefixed_keys(TEMPO_KEYS, TEMPO_PREFIX)
14 | ARTICULATION_DIRECTION_KEYS = build_prefixed_keys(ARTICULATION_KEYS, ARTICULATION_PREFIX)
15 |
--------------------------------------------------------------------------------
/scoreperformer/data/directions/articulation.py:
--------------------------------------------------------------------------------
1 | ARTICULATION_PREFIX = 'articulation'
2 |
3 | ARTICULATION_KEYS = [
4 | 'arpeggiate',
5 | 'fermata',
6 | 'staccato',
7 | 'tenuto'
8 | ]
9 |
--------------------------------------------------------------------------------
/scoreperformer/data/directions/dynamic.py:
--------------------------------------------------------------------------------
1 | DYNAMIC_PREFIX = 'dynamic'
2 |
3 | ABS_DYNAMIC_KEYS = [
4 | 'pppp', 'ppp', 'pp',
5 | ('p', 'piano'),
6 | 'mp', 'mf',
7 | ('f', 'forte'),
8 | 'ff', 'fff', 'ffff',
9 | 'fp', 'ffp'
10 | ]
11 |
12 | REL_DYNAMIC_KEYS = [
13 | ('crescendo', 'cresc'),
14 | ('diminuendo', 'dim', 'decresc'),
15 | ('sf', 'fz', 'sfz', 'sffz'),
16 | ('rf', 'rfz')
17 | ]
18 |
19 | DYNAMIC_KEYS = ABS_DYNAMIC_KEYS + REL_DYNAMIC_KEYS
20 |
21 |
22 | def hairpin_word_regularization(word):
23 | if 'decresc' in word:
24 | word = 'diminuendo'
25 | elif 'cresc' in word:
26 | word = 'crescendo'
27 | elif 'dim' in word:
28 | word = 'diminuendo'
29 | return word
30 |
--------------------------------------------------------------------------------
/scoreperformer/data/directions/parser.py:
--------------------------------------------------------------------------------
1 | """ MusicXML performance direction data parsing. """
2 |
3 | from musicxml_parser.playable_notes import get_playable_notes
4 |
5 | from .articulation import ARTICULATION_PREFIX
6 | from .dynamic import DYNAMIC_PREFIX, ABS_DYNAMIC_KEYS, REL_DYNAMIC_KEYS, hairpin_word_regularization
7 | from .tempo import TEMPO_PREFIX, TEMPO_KEYS
8 | from .words import extract_direction_by_keys, word_regularization
9 |
10 |
11 | def get_part_directions(part):
12 | directions = []
13 | for measure_idx, measure in enumerate(part.measures):
14 | for direction in measure.directions:
15 | direction.type['measure'] = measure_idx
16 | directions.append(direction)
17 |
18 | directions.sort(key=lambda x: x.xml_position)
19 | cleaned_direction = []
20 | for i, d in enumerate(directions):
21 | if d.type is None:
22 | continue
23 |
24 | if d.type['type'] == 'none':
25 | for j in range(i):
26 | prev_dir = directions[i - j - 1]
27 | if 'number' in prev_dir.type.keys():
28 | prev_key = prev_dir.type['type']
29 | prev_num = prev_dir.type['number']
30 | else:
31 | continue
32 | if prev_num == d.type['number']:
33 | if prev_key == "crescendo":
34 | d.type['type'] = 'crescendo'
35 | break
36 | elif prev_key == "diminuendo":
37 | d.type['type'] = 'diminuendo'
38 | break
39 | cleaned_direction.append(d)
40 |
41 | return cleaned_direction
42 |
43 |
44 | def get_directions(doc):
45 | return [get_part_directions(part) for part in doc.parts]
46 |
47 |
48 | def parse_directions(doc, score_directions=None, delete_unmatched=False, delete_duplicates=False, ticks_scale=1.):
49 | score_directions_init = get_directions(doc) if score_directions is None else score_directions
50 |
51 | last_note = doc.parts[-1].measures[-1].notes[-1].note_duration
52 | max_xml_position = max(doc._state.xml_position, last_note.xml_position + last_note.duration)
53 |
54 | # anacrusis
55 | measure_pos = doc.get_measure_positions()
56 | xml_shift = max(0, measure_pos[2] - 2 * measure_pos[1] + measure_pos[0])
57 |
58 | score_directions = []
59 | for part_idx, part_directions_init in enumerate(score_directions_init):
60 | active_dynamic = None
61 | active_tempo = None
62 | active_hairpins = {}
63 | part_directions = []
64 | for d_idx, d in enumerate(part_directions_init):
65 | d_data, d_dict = d.type, None
66 | if d_data['type'] == 'dynamic':
67 | d_dict = {
68 | 'type': d_data['type'],
69 | 'start': d.xml_position,
70 | 'end': max_xml_position
71 | }
72 | abs_dynamic = extract_direction_by_keys(d_data['content'], ABS_DYNAMIC_KEYS)
73 | rel_dynamic = extract_direction_by_keys(d_data['content'], REL_DYNAMIC_KEYS)
74 |
75 | if abs_dynamic is not None:
76 | d_dict['type'] += '/' + abs_dynamic
77 | if active_dynamic is not None:
78 | active_dynamic['end'] = d.xml_position
79 | active_dynamic = d_dict
80 | elif rel_dynamic is not None:
81 | d_dict['type'] += '/' + rel_dynamic
82 | d_dict['end'] = d_dict['start']
83 | else:
84 | continue
85 | elif d_data['type'] in ('crescendo', 'diminuendo'):
86 | key = f'{d_data["type"]}_{d_data["number"]}'
87 | if d_data['content'] == 'start':
88 | active_hairpins[key] = d
89 | elif d_data['content'] == 'stop':
90 | start_d = active_hairpins.pop(key, None)
91 | if not start_d:
92 | continue
93 | d_dict = {
94 | 'type': 'dynamic' + '/' + d_data['type'],
95 | 'start': start_d.xml_position,
96 | 'end': d.xml_position
97 | }
98 | elif d_data['type'] == 'words':
99 | word = word_regularization(d_data['content'])
100 | word = hairpin_word_regularization(word)
101 | tempo_word = extract_direction_by_keys(word, TEMPO_KEYS)
102 |
103 | if word in ('crescendo', 'diminuendo'):
104 | d_dict = {'type': DYNAMIC_PREFIX}
105 | elif tempo_word is not None:
106 | word = tempo_word
107 | d_dict = {'type': TEMPO_PREFIX}
108 | if active_tempo is not None:
109 | active_tempo['end'] = d.xml_position
110 | active_tempo = d_dict
111 | elif delete_unmatched:
112 | continue
113 | else:
114 | d_dict = {'type': d_data['type']}
115 |
116 | d_dict['type'] += '/' + word
117 | d_dict.update(**{
118 | 'start': d.xml_position,
119 | 'end': max_xml_position if d_dict['type'].startswith(TEMPO_PREFIX) else d.xml_position
120 | })
121 | else:
122 | d_dict = None
123 |
124 | if d_dict is not None:
125 | d_dict.update(**{
126 | 'part': part_idx,
127 | 'staff': int(d.staff) if d.staff is not None else 1
128 | })
129 | part_directions.append(d_dict)
130 |
131 | # parse note articulations
132 | def _build_note_articulation_dict(note, content):
133 | return {
134 | 'type': ARTICULATION_PREFIX + '/' + content,
135 | 'start': note.note_duration.xml_position,
136 | 'end': note.note_duration.xml_position + note.note_duration.duration,
137 | 'pitch': note.pitch[1],
138 | 'part': part_idx,
139 | 'staff': int(note.staff) if note.staff is not None else 1
140 | }
141 |
142 | part_notes, _ = get_playable_notes(doc.parts[part_idx])
143 | for note in part_notes:
144 | if note.note_notations.is_arpeggiate:
145 | part_directions.append(_build_note_articulation_dict(note, 'arpeggiate'))
146 | if note.note_notations.is_fermata:
147 | part_directions.append(_build_note_articulation_dict(note, 'fermata'))
148 | if note.note_notations.is_staccato:
149 | part_directions.append(_build_note_articulation_dict(note, 'staccato'))
150 | if note.note_notations.is_tenuto:
151 | part_directions.append(_build_note_articulation_dict(note, 'tenuto'))
152 |
153 | # scale xml positions if needed
154 | if xml_shift != 0 or ticks_scale != 1.:
155 | for d_dict in part_directions:
156 | d_dict['start'] = int(ticks_scale * (d_dict['start'] + xml_shift))
157 | d_dict['end'] = int(ticks_scale * (d_dict['end'] + xml_shift))
158 |
159 | # sort directions
160 | part_directions = list(sorted(part_directions, key=lambda d: (d['start'], d['type'], d['end'])))
161 |
162 | if delete_duplicates:
163 | i = 0
164 | while i < len(part_directions) - 1:
165 | d_dict, next_d_dict = part_directions[i], part_directions[i + 1]
166 | if d_dict['type'] == next_d_dict['type'] and d_dict['start'] == next_d_dict['start']:
167 | del part_directions[i + 1]
168 | continue
169 | i += 1
170 |
171 | score_directions.append(part_directions)
172 |
173 | return score_directions
174 |
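
A hedged usage sketch of `parse_directions`. It assumes the external `musicxml_parser` package (already imported above for `get_playable_notes`) exposes a `MusicXMLDocument` class; the import location and the file path are hypothetical:

```python
from musicxml_parser import MusicXMLDocument  # hypothetical import location

from scoreperformer.data.directions.parser import parse_directions

doc = MusicXMLDocument('score.musicxml')  # hypothetical path
score_directions = parse_directions(doc, delete_unmatched=True, delete_duplicates=True)

# one list of direction dicts per part, sorted by (start, type, end)
for d in score_directions[0][:5]:
    print(d['type'], d['start'], d['end'], d['part'], d['staff'])
```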
--------------------------------------------------------------------------------
/scoreperformer/data/directions/tempo.py:
--------------------------------------------------------------------------------
1 | TEMPO_PREFIX = 'tempo'
2 |
3 | ABS_TEMPO_KEYS = [
4 | 'grave', 'largo', 'larghetto', 'lento',
5 | 'adagio', 'andante', 'andantino', 'moderato',
6 | 'allegretto', 'allegro', 'vivace',
7 | 'presto', 'prestissimo'
8 | ]
9 |
10 | REL_TEMPO_KEYS = [
11 | ('accelerando', 'acc', 'accel'),
12 | ('ritardando', 'rit', 'ritard'),
13 | ('rallentando', 'rall'),
14 | ('stringendo', 'string'),
15 | 'calando', 'più mosso', 'animato', 'stretto', 'smorzando', 'ritenuto'
16 | ]
17 |
18 | RET_TEMPO_KEYS = [
19 | ('tempo primo', 'tempo i'),
20 | 'a tempo',
21 | ]
22 |
23 | TEMPO_KEYS = ABS_TEMPO_KEYS + REL_TEMPO_KEYS + RET_TEMPO_KEYS
24 |
--------------------------------------------------------------------------------
/scoreperformer/data/directions/words.py:
--------------------------------------------------------------------------------
1 | """ Utility functions for word-based direction markings. """
2 |
3 | PUNCTUATION = [',', '.', '\n', '(', ')']
4 |
5 |
6 | def word_regularization(word):
7 | if word:
8 | for symbol in PUNCTUATION:
9 | word = word.replace(symbol, ' ')
10 | word = word.replace(' ', ' ')
11 | return word.strip().lower()
12 | else:
13 | return None
14 |
15 |
16 | def extract_main_keyword(key):
17 | if isinstance(key, tuple):
18 | return key[0]
19 | return key
20 |
21 |
22 | def extract_direction_by_keys(dir_word, keywords):
23 | for key in keywords:
24 | if isinstance(key, tuple) and dir_word in key:
25 | return key[0]
26 | elif dir_word == key:
27 | return key
28 | return
29 |
30 |
31 | def extract_all_directions_by_keys(dir_word, keywords):
32 | directions = []
33 | for key in keywords:
34 | if isinstance(key, tuple) and dir_word in key:
35 | directions.append(key[0])
36 | elif dir_word == key:
37 | directions.append(key)
38 | return directions
39 |
40 |
41 | def check_direction_by_keywords(dir_word, keywords):
42 | dir_word = word_regularization(dir_word) or ''  # guard against empty markings
43 | if dir_word in keywords:
44 | return True
45 | else:
46 | word_split = dir_word.split(' ')
47 | for w in word_split:
48 | if w in keywords:
49 | return True
50 |
51 | for key in keywords:  # substring match for compound markings like 'sempre più mosso'
52 | if isinstance(key, str) and len(key) > 2 and key in dir_word:
53 | return True
54 |
55 | return False
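
For illustration, how the alias tuples resolve to canonical keys (a small sketch using `TEMPO_KEYS` from the sibling module):

```python
from scoreperformer.data.directions.tempo import TEMPO_KEYS
from scoreperformer.data.directions.words import extract_direction_by_keys

# tuple entries are alias groups; the first element is the canonical key
assert extract_direction_by_keys('rit', TEMPO_KEYS) == 'ritardando'
# plain string entries match literally
assert extract_direction_by_keys('a tempo', TEMPO_KEYS) == 'a tempo'
# unknown words yield None
assert extract_direction_by_keys('dolce', TEMPO_KEYS) is None
```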
56 |
--------------------------------------------------------------------------------
/scoreperformer/data/helpers/__init__.py:
--------------------------------------------------------------------------------
1 | from .indexers import TupleTokenSequenceIndexer
2 | from .processors import TokenSequenceAugmentations, TupleTokenSequenceProcessor
3 | from ..tokenizers import TokenizerTypes
4 |
5 | TOKEN_SEQUENCE_PROCESSORS = {
6 | TokenizerTypes.OctupleM: TupleTokenSequenceProcessor,
7 | TokenizerTypes.SPMuple: TupleTokenSequenceProcessor
8 | }
9 |
10 | TOKEN_SEQUENCE_INDEXERS = {
11 | TokenizerTypes.OctupleM: TupleTokenSequenceIndexer,
12 | TokenizerTypes.SPMuple: TupleTokenSequenceIndexer
13 | }
14 |
--------------------------------------------------------------------------------
/scoreperformer/data/helpers/indexers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class TokenSequenceIndexer:
5 | def __init__(self, tokenizer):
6 | self.tokenizer = tokenizer
7 |
8 | def compute_bar_indices(self, seq: np.ndarray) -> np.ndarray:
9 | ...
10 |
11 |
12 | class TupleTokenSequenceIndexer(TokenSequenceIndexer):
13 | def __init__(self, tokenizer):
14 | super().__init__(tokenizer)
15 |
16 | def compute_bar_indices(self, seq: np.ndarray) -> np.ndarray:
17 | bar_index = self.tokenizer.vocab_types_idx['Bar']
18 |
19 | min_bar = seq[0, bar_index] - self.tokenizer.zero_token
20 | total_bars = seq[-1, bar_index] - self.tokenizer.zero_token + 1
21 |
22 | bar_diff = np.concatenate([[min_bar], np.diff(seq[:, bar_index])])
23 | bar_changes = np.where(bar_diff > 0)[0]
24 |
25 | bars = np.concatenate([[0], np.cumsum(bar_diff[bar_changes]), [total_bars]])
26 | bar_changes = np.concatenate([[0], bar_changes, [seq.shape[0]]])
27 |
28 | bar_indices = np.full(bars[-1] + 1, -1, dtype=np.int16)
29 | bar_indices[bars] = bar_changes
30 |
31 | for idx in range(len(bar_indices) - 1, 0, -1):
32 | if bar_indices[idx] == -1:
33 | bar_indices[idx] = bar_indices[idx + 1]
34 |
35 | return bar_indices
36 |
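
A toy run of `compute_bar_indices` with a stub tokenizer (the stub is hypothetical, exposing only the two attributes the indexer reads):

```python
import numpy as np

from scoreperformer.data.helpers.indexers import TupleTokenSequenceIndexer


class StubTokenizer:  # hypothetical: only what compute_bar_indices needs
    vocab_types_idx = {'Bar': 0}
    zero_token = 4    # offset of the first non-special token


# bar tokens 4, 4, 5, 7 -> bars 0, 0, 1, 3 (bar 2 is empty)
seq = np.array([[4], [4], [5], [7]])
indexer = TupleTokenSequenceIndexer(StubTokenizer())
print(indexer.compute_bar_indices(seq))
# [0 2 3 3 4]: the start index of each bar; the empty bar 2 inherits the
# next bar's start, and the final entry is the sequence length
```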
--------------------------------------------------------------------------------
/scoreperformer/data/helpers/processors.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from functools import partial
3 | from typing import Optional
4 |
5 | import numpy as np
6 |
7 | from ..tokenizers import OctupleM
8 | from ..tokenizers.constants import SOS_TOKEN, EOS_TOKEN
9 |
10 |
11 | def sample_integer_shift(low=-6, high=6):
12 | return np.random.randint(low, high + 1)
13 |
14 |
15 | @dataclass
16 | class TokenSequenceAugmentations:
17 | pitch_shift: Optional[int] = 0
18 | velocity_shift: Optional[int] = 0
19 | tempo_shift: Optional[int] = 0
20 |
21 |
22 | class TokenSequenceProcessor:
23 | def __init__(
24 | self,
25 | pitch_shift_range=(-3, 3),
26 | velocity_shift_range=(-2, 2),
27 | tempo_shift_range=(-2, 2),
28 | ):
29 | self.pitch_shift_fn = partial(sample_integer_shift, *pitch_shift_range)
30 | self.velocity_shift_fn = partial(sample_integer_shift, *velocity_shift_range)
31 | self.tempo_shift_fn = partial(sample_integer_shift, *tempo_shift_range)
32 |
33 | def sample_augmentations(self, multiplier=1.0):
34 | return TokenSequenceAugmentations(
35 | pitch_shift=int(multiplier * self.pitch_shift_fn()),
36 | velocity_shift=int(multiplier * self.velocity_shift_fn()),
37 | tempo_shift=int(multiplier * self.tempo_shift_fn())
38 | )
39 |
40 | def augment_sequence(
41 | self,
42 | seq: np.ndarray,
43 | augmentations: TokenSequenceAugmentations
44 | ) -> np.ndarray:
45 | ...
46 |
47 | def sort_sequence(self, seq: np.ndarray) -> np.ndarray:
48 | ...
49 |
50 | def add_sos_token(self, seq: np.ndarray) -> np.ndarray:
51 | ...
52 |
53 | def add_eos_token(self, seq: np.ndarray) -> np.ndarray:
54 | ...
55 |
56 |
57 | class TupleTokenSequenceProcessor(TokenSequenceProcessor):
58 | def __init__(
59 | self,
60 | tokenizer: OctupleM,
61 | pitch_shift_range=(-3, 3),
62 | velocity_shift_range=(-2, 2),
63 | tempo_shift_range=(-2, 2)
64 | ):
65 | super().__init__(pitch_shift_range, velocity_shift_range, tempo_shift_range)
66 |
67 | self.tokenizer = tokenizer
68 |
69 | def augment_sequence(
70 | self,
71 | seq: np.ndarray,
72 | augmentations: TokenSequenceAugmentations
73 | ) -> np.ndarray:
74 | # shift the pitch, velocity and tempo tokens in place
75 | if augmentations.pitch_shift != 0:
76 | pitch_index = self.tokenizer.vocab_types_idx['Pitch']
77 | seq[:, pitch_index] += augmentations.pitch_shift
78 |
79 | if augmentations.velocity_shift != 0:
80 | vel_index = self.tokenizer.vocab_types_idx['Velocity']
81 | vel_min, vel_max = self.tokenizer.zero_token, len(self.tokenizer.vocab[vel_index]) - 1
82 |
83 | seq[:, vel_index] += augmentations.velocity_shift
84 | seq[:, vel_index] = np.maximum(vel_min, np.minimum(vel_max, seq[:, vel_index]))
85 |
86 | if augmentations.tempo_shift != 0:
87 | tempo_index = self.tokenizer.vocab_types_idx['Tempo']
88 | tempo_min, tempo_max = self.tokenizer.zero_token, len(self.tokenizer.vocab[tempo_index]) - 1
89 |
90 | seq[:, tempo_index] += augmentations.tempo_shift
91 | seq[:, tempo_index] = np.maximum(tempo_min, np.minimum(tempo_max, seq[:, tempo_index]))
92 |
93 | return seq
94 |
95 | def sort_sequence(self, seq: np.ndarray) -> np.ndarray:
96 | seq = seq[np.lexsort((seq[:, self.tokenizer.vocab_types_idx['Pitch']],
97 | seq[:, self.tokenizer.vocab_types_idx['Position']],
98 | seq[:, self.tokenizer.vocab_types_idx['Bar']]))]
99 | return seq
100 |
101 | def add_sos_token(self, seq: np.ndarray, initial_tempo: Optional[int] = None) -> np.ndarray:
102 | sos_token_id = self.tokenizer[0, SOS_TOKEN]
103 | seq = np.concatenate((np.full_like(seq[:1], sos_token_id), seq), axis=0)
104 |
105 | return seq
106 |
107 | def add_eos_token(self, seq: np.ndarray) -> np.ndarray:
108 | eos_token_id = self.tokenizer[0, EOS_TOKEN]
109 | seq = np.concatenate((seq, np.full_like(seq[:1], eos_token_id)), axis=0)
110 | return seq
111 |
112 | # Auxiliary processing functions
113 |
114 | def zero_out_durations(self, seq: np.ndarray) -> np.ndarray:
115 | tto = self.tokenizer.vocab_types_idx
116 | velocity_index = tto['Velocity']
117 | if 'PerfDuration' in tto and seq.shape[-1] == len(tto):
118 | duration_index = tto['PerfDuration']
119 | else:
120 | duration_index = tto['Duration']
121 |
122 | silent_mask = seq[:, velocity_index] == self.tokenizer.zero_token
123 | seq[silent_mask, duration_index] = self.tokenizer.zero_token
124 |
125 | return seq
126 |
127 | def remove_silent_notes(self, seq: np.ndarray) -> np.ndarray:
128 | velocity_index = self.tokenizer.vocab_types_idx['Velocity']
129 |
130 | silent_mask = seq[:, velocity_index] == self.tokenizer.zero_token
131 | seq = seq[~silent_mask]
132 |
133 | return seq
134 |
135 | def compute_valid_pitch_mask(self, seq: np.ndarray) -> np.ndarray:
136 | pitch_index = self.tokenizer.vocab_types_idx['Pitch']
137 | pitch_min, pitch_max = self.tokenizer.zero_token, len(self.tokenizer.vocab[pitch_index]) - 1
138 | mask = np.logical_and(seq[:, pitch_index] >= pitch_min, seq[:, pitch_index] <= pitch_max)
139 | return mask
140 |
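
A sketch of the augmentation path with a stub tokenizer (hypothetical vocabulary sizes; only the attributes `augment_sequence` touches are stubbed):

```python
import numpy as np

from scoreperformer.data.helpers.processors import (
    TokenSequenceAugmentations, TupleTokenSequenceProcessor
)


class StubTokenizer:  # hypothetical minimal tokenizer interface
    vocab_types_idx = {'Pitch': 0, 'Velocity': 1, 'Tempo': 2}
    zero_token = 4
    vocab = [list(range(96)), list(range(36)), list(range(36))]


proc = TupleTokenSequenceProcessor(StubTokenizer())
# a fixed augmentation instead of one drawn via sample_augmentations()
aug = TokenSequenceAugmentations(pitch_shift=2, velocity_shift=3, tempo_shift=0)

seq = np.array([[40, 34, 10], [60, 35, 10]])
print(proc.augment_sequence(seq.copy(), aug))
# [[42 35 10]
#  [62 35 10]]  -- pitches shift freely, velocities are clipped to the vocab range
```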
--------------------------------------------------------------------------------
/scoreperformer/data/midi/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/scoreperformer/data/midi/__init__.py
--------------------------------------------------------------------------------
/scoreperformer/data/midi/beats.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 | import numpy as np
4 | from miditoolkit import MidiFile, TimeSignature
5 |
6 | BEATS_IN_BARS = {
7 | 6: 2,
8 | 9: 3,
9 | 18: 3,
10 | 12: 4,
11 | 24: 4
12 | }
13 |
14 |
15 | def get_ticks_per_bar(time_sig: TimeSignature, ticks_per_beat: int = 480):
16 | return ticks_per_beat * 4 * time_sig.numerator // time_sig.denominator
17 |
18 |
19 | def get_inter_beat_interval(
20 | *,
21 | time_sig: Optional[TimeSignature],
22 | ticks_per_bar: Optional[int] = None,
23 | ticks_per_beat: int = 480
24 | ):
25 | if ticks_per_bar is None:
26 | ticks_per_bar = get_ticks_per_bar(time_sig, ticks_per_beat=ticks_per_beat)
27 |
28 | num_beat_in_bar = BEATS_IN_BARS.get(time_sig.numerator, time_sig.numerator)
29 | inter_beat_interval = int(ticks_per_bar / num_beat_in_bar)
30 |
31 | return inter_beat_interval
32 |
33 |
34 | def get_bar_beat_ticks(
35 | midi: Optional[MidiFile] = None,
36 | *,
37 | time_sigs: Optional[List[TimeSignature]] = None,
38 | ticks_per_beat: Optional[int] = None,
39 | max_tick: Optional[int] = None
40 | ):
41 | assert midi is not None or all(map(lambda x: x is not None, (time_sigs, ticks_per_beat, max_tick)))
42 |
43 | if midi is not None:
44 | time_sigs = midi.time_signature_changes
45 | ticks_per_beat = midi.ticks_per_beat
46 | max_tick = midi.max_tick - 1
47 |
48 | bar_ticks, beat_ticks = [], []
49 | for i, time_sig in enumerate(time_sigs):
50 | last_tick = time_sigs[i + 1].time if i < len(time_sigs) - 1 else max_tick
51 |
52 | ticks_per_bar = get_ticks_per_bar(time_sig, ticks_per_beat=ticks_per_beat)
53 | bar_ticks.append(np.arange(time_sig.time, last_tick, ticks_per_bar))
54 |
55 | inter_beat_interval = get_inter_beat_interval(
56 | time_sig=time_sig, ticks_per_bar=ticks_per_bar, ticks_per_beat=ticks_per_beat
57 | )
58 | beat_ticks.append(np.arange(time_sig.time, last_tick, inter_beat_interval))
59 |
60 | if len(time_sigs) > 1:
61 | bar_ticks, beat_ticks = np.concatenate(bar_ticks), np.concatenate(beat_ticks)
62 | else:
63 | bar_ticks, beat_ticks = bar_ticks[0], beat_ticks[0]
64 |
65 | return bar_ticks, beat_ticks
66 |
67 |
68 | def get_performance_beats(
69 | score_beats: np.ndarray,
70 | position_pairs: np.ndarray,
71 | max_tick: Optional[int] = None,
72 | max_time: Optional[float] = None,
73 | monotonic_times: bool = False,
74 | ticks_per_beat: int = 480
75 | ):
76 | if monotonic_times:
77 | mono_position_pairs = [position_pairs[0]]
78 | cur_pair = prev_pair = position_pairs[0]
79 | for pair in position_pairs[1:]:
80 | min_shift_time = (pair[0] - cur_pair[0]) / ticks_per_beat / 10 # tempo 600
81 | if pair[0] != prev_pair[0] and pair[1] > prev_pair[1] and pair[1] > cur_pair[1] + min_shift_time:
82 | mono_position_pairs.append(pair)
83 | cur_pair = pair
84 | prev_pair = pair
85 | position_pairs = np.array(mono_position_pairs)
86 |
87 | if max_tick is not None and max_time is not None:
88 | position_pairs = np.concatenate([position_pairs, [(max_tick, max_time)]])
89 | score_beats = np.concatenate([score_beats, [max_tick]])
90 |
91 | onset_ticks, perf_times = position_pairs[:, 0], position_pairs[:, 1]
92 | beat_onset_indices = np.minimum(len(onset_ticks) - 1, np.searchsorted(onset_ticks, score_beats))
93 |
94 | # fill known beats
95 | perf_beats = []
96 | for i, beat in enumerate(score_beats):
97 | onset_idx = beat_onset_indices[i]
98 | if onset_ticks[onset_idx] == beat:
99 | perf_beat = perf_times[onset_idx]
100 | else:
101 | # interpolate
102 | if i == 0 or onset_idx == 0:
103 | onset_idx += 1
104 |
105 | left_tick, right_tick = onset_ticks[onset_idx - 1], onset_ticks[onset_idx]
106 | left_time, right_time = perf_times[onset_idx - 1], perf_times[onset_idx]
107 |
108 | perf_beat = left_time + (right_time - left_time) * (beat - left_tick) / (right_tick - left_tick)
109 |
110 | perf_beats.append(perf_beat)
111 |
112 | if max_tick is not None and max_time is not None:
113 | if score_beats[-2] == score_beats[-1]:
114 | score_beats = score_beats[:-1]
115 | perf_beats = perf_beats[:-1]
116 |
117 | perf_beats = np.array(perf_beats)
118 |
119 | return score_beats, perf_beats
120 |
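
A worked example of the bar/beat arithmetic above (the values follow directly from the formulas; `TimeSignature(6, 8, 0)` is a 6/8 signature at tick 0):

```python
from miditoolkit import TimeSignature

from scoreperformer.data.midi.beats import get_inter_beat_interval, get_ticks_per_bar

ts = TimeSignature(6, 8, 0)  # 6/8 at tick 0
ticks_per_bar = get_ticks_per_bar(ts, ticks_per_beat=480)  # 480 * 4 * 6 // 8 = 1440
# 6/8 is felt in two (BEATS_IN_BARS[6] == 2), so one beat spans half a bar
ibi = get_inter_beat_interval(time_sig=ts, ticks_per_beat=480)  # 1440 / 2 = 720
print(ticks_per_bar, ibi)  # 1440 720
```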
--------------------------------------------------------------------------------
/scoreperformer/data/midi/containers.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
4 | @dataclass
5 | class Note:
6 | pitch: int
7 | velocity: int
8 | start: float
9 | end: float
10 |
11 | @property
12 | def duration(self):
13 | return self.end - self.start
14 |
--------------------------------------------------------------------------------
/scoreperformer/data/midi/preprocess.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 | from miditok.utils import merge_tracks
4 | from miditoolkit import MidiFile, Marker, Instrument
5 |
6 | from ..midi import quantization as midi_quan
7 | from ..midi import utils as midi_utl
8 | from ..midi.containers import Note
9 |
10 |
11 | def preprocess_midi(
12 | midi: MidiFile,
13 | to_single_track: bool = True,
14 | sort_events: bool = True,
15 | clean_duplicates: bool = True,
16 | cut_overlapped_notes: bool = False,
17 | clean_short_notes: bool = False,
18 | quantize_notes: bool = False,
19 | quantize_midi_changes: bool = False,
20 | filter_late_events: bool = True,
21 | target_ticks_per_beat: Optional[int] = None
22 | ):
23 | if len(midi.instruments) == 0:
24 | return midi
25 |
26 | if len(midi.instruments) > 1 and to_single_track:
27 | merge_tracks(midi.instruments, effects=True)
28 |
29 | # TODO: order of notes changes inside, handle the case `sort_events = False`
30 | for track in midi.instruments:
31 | if clean_duplicates:
32 | midi_utl.remove_duplicated_notes(track.notes)
33 |
34 | if cut_overlapped_notes:
35 | midi_utl.cut_overlapping_notes(track.notes)
36 |
37 | if clean_short_notes:
38 | midi_utl.remove_short_notes(track.notes, time_division=midi.ticks_per_beat)
39 |
40 | if quantize_notes:
41 | midi_quan.quantize_notes(track.notes, time_division=midi.ticks_per_beat)
42 | if clean_duplicates:
43 | midi_utl.remove_duplicated_notes(track.notes)
44 |
45 | if sort_events:
46 | for track in midi.instruments:
47 | track.notes.sort(key=lambda x: (x.start, x.pitch, x.end)) # sort notes
48 | midi.max_tick = max([max([note.end for note in track.notes[-100:]]) for track in midi.instruments])
49 | else:
50 | midi.max_tick = max([max([note.end for note in track.notes]) for track in midi.instruments]) + 1
51 |
52 | midi.instruments = [track for track in midi.instruments if len(track.notes) > 0]
53 |
54 | if filter_late_events:
55 | midi_utl.filter_late_midi_events(midi, sort=sort_events)
56 |
57 | if quantize_midi_changes:
58 | midi_quan.quantize_time_signatures(midi.time_signature_changes, time_division=midi.ticks_per_beat)
59 | midi_quan.quantize_tempos(midi.tempo_changes, time_division=midi.ticks_per_beat)
60 | midi_quan.quantize_key_signatures(midi.key_signature_changes, time_division=midi.ticks_per_beat)
61 |
62 | if target_ticks_per_beat is not None:
63 | midi_utl.resample_midi(midi, ticks_per_beat=target_ticks_per_beat)
64 |
65 | return midi
66 |
67 |
68 | def insert_silent_notes(
69 | midi: MidiFile,
70 | markers: Optional[List[Marker]] = None,
71 | track_idx: Optional[int] = None
72 | ):
73 | markers = markers or midi.markers
74 |
75 | notes = []
76 | for m in markers:
77 | if m.text.startswith('NoteS'):
78 | pitch, start_tick, end_tick = map(int, m.text.split('_')[1:])
79 | notes.append(Note(pitch, 0, start_tick, end_tick))
80 |
81 | if track_idx is None:
82 | track = Instrument(0, False, 'Unperformed Notes')
83 | track.notes = notes
84 | midi.instruments.append(track)
85 | else:
86 | midi.instruments[track_idx].notes += notes
87 |
88 | if midi.instruments[-1].name != 'Unperformed Notes':
89 | midi.instruments.append(Instrument(0, False, 'Unperformed Notes'))
90 |
91 | return midi
92 |
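
A typical call, sketched with a hypothetical input path:

```python
from miditoolkit import MidiFile

from scoreperformer.data.midi.preprocess import preprocess_midi

midi = MidiFile('performance.mid')  # hypothetical path
midi = preprocess_midi(
    midi,
    to_single_track=True,       # merge all tracks into one
    cut_overlapped_notes=True,  # trim same-pitch overlaps
    quantize_notes=True,        # snap note starts/ends to the grid
    target_ticks_per_beat=480,  # resample to a common time division
)
```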
--------------------------------------------------------------------------------
/scoreperformer/data/midi/quantization.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | from miditoolkit import Note, TempoChange, TimeSignature, KeySignature
4 |
5 |
6 | def quantize_notes(
7 | notes: List[Note],
8 | time_division: int,
9 | max_beat_res: int = 32,
10 | pitch_range: Optional[Tuple[int, int]] = (21, 109)
11 | ):
12 | """ Quantize notes, i.e. their pitch, start and end values.
13 | Shifts the notes' start and end times to match the quantization (e.g. 16 samples per bar)
14 | Notes with pitches outside of pitch_range are deleted.
15 | :param notes: notes to quantize
16 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed)
17 | :param max_beat_res: maximum beat resolution for one sample
18 | :param pitch_range: pitch range from within notes should be
19 | """
20 | ticks_per_sample = int(time_division / max_beat_res)
21 | i = 0
22 | while i < len(notes):
23 | note = notes[i]
24 | if pitch_range is not None and (note.pitch < pitch_range[0] or note.pitch >= pitch_range[1]):
25 | del notes[i]
26 | continue
27 | start_offset = note.start % ticks_per_sample
28 | end_offset = note.end % ticks_per_sample
29 | note.start += -start_offset if start_offset <= ticks_per_sample / 2 else ticks_per_sample - start_offset
30 | note.end += -end_offset if end_offset <= ticks_per_sample / 2 else ticks_per_sample - end_offset
31 |
32 | if note.start == note.end:
33 | note.end += ticks_per_sample
34 |
35 | i += 1
36 |
37 |
38 | def quantize_tempos(
39 | tempos: List[TempoChange],
40 | time_division: int,
41 | max_beat_res: int = 32
42 | ):
43 | r"""Quantize the times of tempo change events.
44 | Consecutive identical tempo changes will be removed.
45 |
46 | :param tempos: tempo changes to quantize
47 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed)
48 | :param max_beat_res: maximum beat resolution for one sample
49 | """
50 | ticks_per_sample = int(time_division / max_beat_res)
51 | i, prev_tempo = 0, -1
52 | while i < len(tempos):
53 | # remove consecutive identical tempo changes
54 | if tempos[i].tempo == prev_tempo:
55 | del tempos[i]
56 | continue
57 | rest = tempos[i].time % ticks_per_sample
58 | tempos[i].time += -rest if rest <= ticks_per_sample / 2 else ticks_per_sample - rest
59 | prev_tempo = tempos[i].tempo
60 | i += 1
61 |
62 |
63 | def compute_ticks_per_bar(time_sig: TimeSignature, time_division: int):
64 | r"""Computes time resolution of one bar in ticks.
65 |
66 | :param time_sig: time signature object
67 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed)
68 | :return: MIDI bar resolution, in ticks/bar
69 | """
70 | return int(time_division * 4 * time_sig.numerator / time_sig.denominator)
71 |
72 |
73 | def quantize_time_signatures(time_sigs: List[TimeSignature], time_division: int):
74 | r"""Quantize the time signature changes, delayed to the next bar.
75 | See MIDI 1.0 Detailed specifications, pages 54 - 56, for more information on
76 | delayed time signature messages.
77 |
78 | :param time_sigs: time signature changes to quantize
79 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed)
80 | """
81 | all_different = False
82 | while not all_different:
83 | all_different = True
84 |
85 | # delete one of neighbouring time signatures with same values or time
86 | prev_time_sig = time_sigs[0]
87 | i = 1
88 | while i < len(time_sigs):
89 | time_sig = time_sigs[i]
90 |
91 | if (time_sig.numerator, time_sig.denominator) == (prev_time_sig.numerator, prev_time_sig.denominator) or \
92 | time_sig.time == prev_time_sig.time:
93 | del time_sigs[i]
94 | all_different = False
95 | continue
96 | prev_time_sig = time_sig
97 | i += 1
98 |
99 | # quantize times
100 | ticks_per_bar = compute_ticks_per_bar(time_sigs[0], time_division)
101 | current_bar = 0
102 | previous_tick = 0 # first time signature change is always at tick 0
103 | i = 1
104 | while i < len(time_sigs):
105 | time_sig = time_sigs[i]
106 |
107 | # determine the current bar of time sig
108 | bar_offset, rest = divmod(time_sig.time - previous_tick, ticks_per_bar)
109 | if rest > 0: # time sig doesn't happen on a new bar, we update it to the next bar
110 | bar_offset += 1
111 | time_sig.time = previous_tick + bar_offset * ticks_per_bar
112 |
113 | # Update values
114 | ticks_per_bar = compute_ticks_per_bar(time_sig, time_division)
115 | current_bar += bar_offset
116 | previous_tick = time_sig.time
117 | i += 1
118 |
119 |
120 | def quantize_key_signatures(
121 | key_signatures: List[KeySignature],
122 | time_division: int,
123 | max_beat_res: int = 32
124 | ):
125 | r"""Quantize the times of key signature change events.
126 | Consecutive identical key signature changes will be removed.
127 |
128 | :param key_signatures: key signature changes to quantize
129 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed)
130 | :param max_beat_res: maximum beat resolution for one sample
131 | """
132 | ticks_per_sample = int(time_division / max_beat_res)
133 | i, prev_key = 0, ''
134 | while i < len(key_signatures):
135 | # remove consecutive identical key signatures
136 | if key_signatures[i].key_name == prev_key:
137 | del key_signatures[i]
138 | continue
139 | rest = key_signatures[i].time % ticks_per_sample
140 | key_signatures[i].time += -rest if rest <= ticks_per_sample / 2 else ticks_per_sample - rest
141 | prev_key = key_signatures[i].key_name
142 | i += 1
143 |
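
To make the grid arithmetic concrete, a small worked example for `quantize_notes` (with `time_division=480` and the default `max_beat_res=32`, one sample is 15 ticks):

```python
from miditoolkit import Note

from scoreperformer.data.midi.quantization import quantize_notes

# ticks_per_sample = 480 // 32 = 15
notes = [Note(velocity=80, pitch=60, start=7, end=22)]
quantize_notes(notes, time_division=480)
# start 7 has offset 7 <= 7.5 -> rounds down to 0
# end 22 has offset 7 <= 7.5 -> rounds down to 15
print(notes[0].start, notes[0].end)  # 0 15
```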
--------------------------------------------------------------------------------
/scoreperformer/data/midi/sync.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Optional
3 |
4 | import numpy as np
5 | from miditoolkit import MidiFile, TempoChange, Marker
6 |
7 | from scoreperformer.utils import find_closest
8 | from .beats import get_inter_beat_interval, get_bar_beat_ticks, get_performance_beats
9 | from .timing import (
10 | convert_symbolic_timing_to_absolute,
11 | convert_absolute_timing_to_symbolic
12 | )
13 | from .utils import filter_late_midi_events
14 |
15 |
16 | def sync_performance_midi(
17 | score_midi: MidiFile,
18 | perf_midi: MidiFile,
19 | onset_pairs: np.ndarray,
20 | is_absolute_timing: bool = False,
21 | max_time: Optional[float] = None,
22 | ticks_per_beat: int = 480,
23 | bar_sync: bool = True,
24 | inplace: bool = True,
25 | verbose: bool = False
26 | ):
27 | """
28 | Synchronizes performance MIDI with score MIDI bars/beats through the onset pairs.
29 |
30 | Also adds silent (unperformed) performance notes as special markers based on the alignment.
31 | """
32 | perf_midi = copy.deepcopy(perf_midi) if not inplace else perf_midi
33 |
34 | # preprocess performance midi
35 | filter_late_midi_events(perf_midi)
36 | max_tick = score_midi.max_tick
37 |
38 | if not is_absolute_timing:
39 | tick_to_time = perf_midi.get_tick_to_time_mapping()
40 | max_time = tick_to_time[-1]
41 | else:
42 | assert max_time is not None, "`max_time` should be explicitly provided for MIDI with absolute timing"
43 | tick_to_time = None
44 |
45 | # compute score and performance onsets
46 | score_bars, score_beats = get_bar_beat_ticks(score_midi)
47 | score_onsets = score_bars if bar_sync else score_beats
48 | score_onsets, perf_onsets = get_performance_beats(
49 | score_onsets, onset_pairs,
50 | max_tick=max_tick - 1, max_time=max_time,
51 | monotonic_times=True, ticks_per_beat=ticks_per_beat
52 | )
53 | perf_shift = perf_onsets[0]
54 | perf_onsets -= perf_shift
55 | max_time -= perf_shift
56 |
57 | perf_score_tick_ratio = ticks_per_beat / score_midi.ticks_per_beat
58 |
59 | time_signatures = score_midi.time_signature_changes
60 |
61 | time_sig_ticks, quarter_note_factors, inter_onset_intervals = [], [], []
62 | for time_sig in time_signatures:
63 | time_sig_ticks.append(time_sig.time)
64 | quarter_note_factors.append(4 * time_sig.numerator / time_sig.denominator)
65 | inter_onset_intervals.append(
66 | get_inter_beat_interval(time_sig=time_sig, ticks_per_beat=score_midi.ticks_per_beat)
67 | )
68 |
69 | time_sig_ticks, quarter_note_factors, inter_onset_intervals = map(
70 | np.array, (time_sig_ticks, quarter_note_factors, inter_onset_intervals)
71 | )
72 | inter_beat_intervals = inter_onset_intervals
73 |
74 | ticks_per_bar = (score_midi.ticks_per_beat * quarter_note_factors).astype(int)
75 | beats_per_bar = ticks_per_bar / inter_beat_intervals
76 | ioi_in_quarters = ibi_in_quarters = quarter_note_factors / beats_per_bar
77 |
78 | if bar_sync:
79 | inter_onset_intervals = inter_onset_intervals * beats_per_bar
80 | ioi_in_quarters = ioi_in_quarters * beats_per_bar
81 |
82 | if verbose:
83 | print(f'score: time_sigs={time_signatures}\n'
84 | f' ticks_per_beat={score_midi.ticks_per_beat}, ticks_per_bar={ticks_per_bar}\n'
85 | f' inter_beat_intervals={inter_beat_intervals}, inter_onset_intervals={inter_onset_intervals}\n'
86 | f' ibi_in_quarters={ibi_in_quarters}, ioi_in_quarters={ioi_in_quarters}')
87 |
88 | # compute tempos
89 | intervals = np.diff(perf_onsets)
90 | if np.any(intervals <= 0.):
91 | return None
92 |
93 | time_sig_indices = (np.searchsorted(time_sig_ticks, score_onsets, side='right') - 1)[:-1]
94 | inter_onset_ratios = np.diff(score_onsets) / inter_onset_intervals[time_sig_indices]
95 | tempos = 60 / intervals * ioi_in_quarters[time_sig_indices] * inter_onset_ratios
96 |
97 | if verbose:
98 | print(f'tempos: ({tempos.min():.3f}, {tempos.max():.3f}), {np.median(tempos):.3f}')
99 |
100 | # get absolute timing of instruments
101 | if is_absolute_timing:
102 | abs_instr = perf_midi.instruments
103 | else:
104 | abs_instr = convert_symbolic_timing_to_absolute(
105 | perf_midi.instruments, tick_to_time, inplace=inplace, time_shift=-perf_shift
106 | )
107 |
108 | # compute time to tick mapping
109 | inter_onset_intervals = inter_onset_intervals[time_sig_indices] * perf_score_tick_ratio * inter_onset_ratios
110 | resample_timing = []
111 | for i in range(len(perf_onsets) - 1):
112 | start_beat, end_beat = perf_onsets[i], perf_onsets[i + 1]
113 | resample_timing.append(np.linspace(start_beat, end_beat, int(inter_onset_intervals[i]) + 1)[:-1])
114 |
115 | resample_timing.append([max_time])
116 | resample_timing = np.round(np.concatenate(resample_timing), 6)
117 |
118 | # create a new MidiFile object with the target time division
119 | midi = MidiFile(ticks_per_beat=ticks_per_beat)
120 |
121 | # convert abs to sym
122 | sym_instr = convert_absolute_timing_to_symbolic(abs_instr, resample_timing, inplace=inplace)
123 |
124 | # process timing of markers
125 | markers = perf_midi.markers if hasattr(perf_midi, 'markers') else []
126 | for marker in markers:
127 | marker.time = find_closest(resample_timing, float(tick_to_time[marker.time]) - perf_shift)
128 | if marker.text.startswith('NoteI'):
129 | pitch, start, end = map(int, marker.text.split('_')[1:])
130 | start, end = map(lambda x: find_closest(resample_timing, float(tick_to_time[x]) - perf_shift), (start, end))
131 | marker.text = f'NoteI_{pitch}_{start}_{end}'
132 |
133 | # tempo
134 | tempo_changes = []
135 | onset_ticks = find_closest(resample_timing, perf_onsets)
136 | for pos_tick, tempo in zip(onset_ticks[:-1], tempos):
137 | tempo_changes.append(TempoChange(tempo=float(tempo), time=int(pos_tick)))
138 |
139 | tempo_changes = [tempo for tempo in tempo_changes if tempo.time < resample_timing.shape[0]]
140 |
141 | # markers
142 | markers.insert(0, Marker(text=f'Shift_{perf_shift:.6f}', time=0))
143 |
144 | # set attributes
145 | midi.tempo_changes = tempo_changes
146 | midi.time_signature_changes = time_signatures
147 | midi.instruments = sym_instr
148 | midi.markers = markers
149 | midi.max_tick = resample_timing.shape[0]
150 |
151 | return midi
152 |
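
A hedged usage sketch for `sync_performance_midi` (paths and onset pairs are hypothetical; the pairs would normally come from a score-to-performance note alignment):

```python
import numpy as np
from miditoolkit import MidiFile

from scoreperformer.data.midi.sync import sync_performance_midi

score_midi = MidiFile('score.mid')        # hypothetical paths
perf_midi = MidiFile('performance.mid')

# (score_tick, performance_time_in_seconds) anchor pairs
onset_pairs = np.array([[0, 0.00], [480, 0.52], [960, 1.07]])

synced = sync_performance_midi(score_midi, perf_midi, onset_pairs, bar_sync=False)
if synced is not None:  # None is returned when beat times are not increasing
    synced.dump('performance.synced.mid')
```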
--------------------------------------------------------------------------------
/scoreperformer/data/midi/timing.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import List
3 |
4 | import numpy as np
5 | from miditoolkit import Instrument
6 |
7 | from scoreperformer.utils import find_closest
8 | from ..midi.containers import Note
9 |
10 |
11 | def convert_symbolic_timing_to_absolute(
12 | tracks: List[Instrument],
13 | tick_to_time: np.ndarray,
14 | inplace: bool = True,
15 | time_shift: float = 0.
16 | ):
17 | tracks = tracks if inplace else copy.deepcopy(tracks)
18 |
19 | for track in tracks:
20 | track.notes = [
21 | Note(pitch=n.pitch, velocity=n.velocity,
22 | start=time_shift + float(tick_to_time[n.start]),
23 | end=time_shift + float(tick_to_time[n.end]))
24 | for n in track.notes
25 | ]
26 | for control_change in track.control_changes:
27 | control_change.time = time_shift + float(tick_to_time[control_change.time])
28 | for pedal in track.pedals:
29 | pedal.start = time_shift + float(tick_to_time[pedal.start])
30 | pedal.end = time_shift + float(tick_to_time[pedal.end])
31 | for pitch_bend in track.pitch_bends:
32 | pitch_bend.time = time_shift + float(tick_to_time[pitch_bend.time])
33 |
34 | return tracks
35 |
36 |
37 | def convert_absolute_timing_to_symbolic(
38 | tracks: List[Instrument],
39 | time_to_tick: np.ndarray,
40 | inplace: bool = True
41 | ):
42 | tracks = tracks if inplace else copy.deepcopy(tracks)
43 |
44 | def process_interval_events(events):
45 | start_times = np.array(list(map(lambda x: x.start, events)))
46 | start_ticks = find_closest(time_to_tick, start_times)
47 | end_times = np.array(list(map(lambda x: x.end, events)))
48 | end_ticks = find_closest(time_to_tick, end_times)
49 | for event, start_t, end_t in zip(events, start_ticks, end_ticks):
50 | if start_t == end_t:
51 | end_t += 1
52 | event.start = start_t
53 | event.end = end_t
54 |
55 | def process_time_events(events):
56 | times = np.array(list(map(lambda x: x.time, events)))
57 | ticks = find_closest(time_to_tick, times)
58 | for event, t in zip(events, ticks):
59 | event.time = t
60 |
61 | for track in tracks:
62 | process_interval_events(track.notes)
63 | process_interval_events(track.pedals)
64 | process_time_events(track.control_changes)
65 | process_time_events(track.pitch_bends)
66 |
67 | return tracks
68 |
--------------------------------------------------------------------------------
/scoreperformer/data/midi/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import defaultdict
3 | from typing import Optional, List
4 |
5 | import numpy as np
6 | from miditoolkit import Note, MidiFile
7 |
8 | from scoreperformer.utils import find_closest
9 |
10 |
11 | def sort_notes(
12 | notes: List[Note],
13 | compute_sort_indices: bool = False,
14 | order: str = 'time'
15 | ):
16 | assert order in ('time', 'pitch')
17 |
18 | sort_ids = None
19 | if order == 'time':
20 | if compute_sort_indices:
21 | sort_ids = np.lexsort([[n.end for n in notes], [n.pitch for n in notes], [n.start for n in notes]])
22 | notes.sort(key=lambda n: (n.start, n.pitch, n.end))
23 | elif order == 'pitch':
24 | if compute_sort_indices:
25 | sort_ids = np.lexsort([[n.end for n in notes], [n.start for n in notes], [n.pitch for n in notes]])
26 | notes.sort(key=lambda n: (n.pitch, n.start, n.end))
27 |
28 | return notes, sort_ids
29 |
30 |
31 | def cut_overlapping_notes(notes: List[Note], return_sort_indices: bool = False):
32 | r"""Find and cut the first of the two overlapping notes, i.e. with the same pitch,
33 | and the second note starting before the ending of the first note.
34 |
35 | :param notes: notes to analyse
36 | :param return_sort_indices: return indices by which the original notes were sorted
37 | """
38 | # sort by pitch, then time
39 | notes, sort_ids = sort_notes(notes, compute_sort_indices=return_sort_indices, order='pitch')
40 |
41 | for i in range(1, len(notes)):
42 | prev_note, note = notes[i - 1], notes[i]
43 | if prev_note.pitch == note.pitch and prev_note.end >= note.start:
44 | if note.start <= 1:
45 | note.start = 2
46 | prev_note.end = note.start - 1
47 | if prev_note.start >= prev_note.end: # resulted in an invalid note, fix it too
48 | prev_note.start = prev_note.end - 1
49 |
50 | # sort back by time, then pitch
51 | notes, sort_ids_back = sort_notes(notes, compute_sort_indices=return_sort_indices, order='time')
52 |
53 | if return_sort_indices:
54 | sort_ids = sort_ids[sort_ids_back]
55 | return notes, sort_ids
56 | return notes
57 |
58 |
59 | def remove_duplicated_notes(notes: List[Note], return_sort_indices: bool = False):
60 | r"""Find and remove exactly similar notes, i.e. with the same pitch, start and end.
61 |
62 | :param notes: notes to analyse
63 | :param return_sort_indices: return indices by which the original notes were sorted
64 | """
65 | # sort by pitch, then time
66 | notes, sort_ids = sort_notes(notes, compute_sort_indices=return_sort_indices, order='pitch')
67 |
68 | for i in range(len(notes) - 1, 0, -1): # removing possible duplicated notes
69 | if notes[i].pitch == notes[i - 1].pitch and notes[i].start == notes[i - 1].start and \
70 | notes[i].end >= notes[i - 1].end:
71 | del notes[i]
72 |
73 | # sort back by time, then pitch
74 | notes, sort_ids_back = sort_notes(notes, compute_sort_indices=return_sort_indices, order='time')
75 |
76 | if return_sort_indices:
77 | sort_ids = sort_ids[sort_ids_back]
78 | return notes, sort_ids
79 | return notes
80 |
81 |
82 | def remove_short_notes(notes: List[Note], time_division: int, max_beat_res: int = 32):
83 | r"""Find and remove short notes, i.e. with the same pitch, start and end.
84 |
85 | :param notes: notes to analyse
86 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed)
87 | :param max_beat_res: maximum beat resolution for one sample
88 | """
89 | ticks_per_sample = int(time_division / max_beat_res)
90 |
91 | for i in range(len(notes) - 1, -1, -1):  # iterate backwards so deletion is safe
92 | note = notes[i]
93 | if note.end - note.start < ticks_per_sample // 2:
94 | del notes[i]
95 |
96 | return notes
97 |
98 |
99 | def filter_late_midi_events(midi: MidiFile, max_tick: Optional[int] = None, sort: bool = False):
100 | max_tick = max_tick or midi.max_tick
101 |
102 | for track in midi.instruments:
103 | if sort:
104 | track.control_changes.sort(key=lambda c: c.time)
105 | for i, control_change in enumerate(track.control_changes):
106 | if control_change.time > max_tick:
107 | track.control_changes = track.control_changes[:i]
108 | break
109 |
110 | if sort:
111 | track.pedals.sort(key=lambda p: p.start)
112 | for i, pedal in enumerate(track.pedals):
113 | if pedal.end > max_tick:
114 | track.pedals = track.pedals[:i]
115 | break
116 |
117 | if sort:
118 | track.pitch_bends.sort(key=lambda p: p.time)
119 | for i, pitch_bend in enumerate(track.pitch_bends):
120 | if pitch_bend.time > max_tick:
121 | track.pitch_bends = track.pitch_bends[:i]
122 | break
123 |
124 | return midi
125 |
126 |
127 | def shift_midi_notes(
128 | midi: MidiFile,
129 | time_shift: float = 0.,
130 | offset: float = 0.,
131 | inplace: bool = True,
132 | return_shifted_indices: bool = False
133 | ):
134 | midi = midi if inplace else copy.deepcopy(midi)
135 |
136 | midi.max_tick *= 4  # temporarily extend so the tick-to-time mapping covers shifted times
137 | ttt = midi.get_tick_to_time_mapping()
138 |
139 | def process_continuous_events(elements):
140 | start_ticks = np.array(list(map(lambda x: x.start, elements)))
141 | end_ticks = np.array(list(map(lambda x: x.end, elements)))
142 | start_times, end_times = ttt[start_ticks], ttt[end_ticks]
143 | new_start_ticks = find_closest(ttt, start_times + time_shift)
144 | new_end_ticks = find_closest(ttt, end_times + time_shift)
145 | for el, time, start_t, end_t in zip(elements, start_times, new_start_ticks, new_end_ticks):
146 | if time >= offset:
147 | if start_t == end_t:
148 | end_t += 1
149 | el.start = start_t
150 | el.end = end_t
151 | return np.where(start_times >= offset)[0]
152 |
153 | def process_instant_events(elements):
154 | ticks = np.array(list(map(lambda x: x.time, elements)))
155 | times = ttt[ticks]
156 | new_ticks = find_closest(ttt, times + time_shift)
157 | for el, time, tick in zip(elements, times, new_ticks):
158 | if time >= offset:
159 | el.time = tick
160 | return np.where(times >= offset)[0]
161 |
162 | # shift relevant notes in MIDI
163 | shifted_indices = defaultdict(list)
164 | for track_idx, track in enumerate(midi.instruments):
165 | shifted_indices['note'].append((track_idx, process_continuous_events(track.notes)))
166 | if track.pedals:
167 | shifted_indices['pedal'].append((track_idx, process_continuous_events(track.pedals)))
168 | if track.control_changes:
169 | shifted_indices['control_change'].append((track_idx, process_instant_events(track.control_changes)))
170 | if track.pitch_bends:
171 | shifted_indices['pitch_bend'].append((track_idx, process_instant_events(track.pitch_bends)))
172 |
173 | midi.max_tick = max([max([note.end for note in track.notes]) for track in midi.instruments]) + 1
174 |
175 | if return_shifted_indices:
176 | return midi, shifted_indices
177 | return midi
178 |
179 |
180 | def resample_midi(midi: MidiFile, ticks_per_beat: int, inplace: bool = True):
181 | if midi.ticks_per_beat == ticks_per_beat:
182 | return midi
183 |
184 | midi = midi if inplace else copy.deepcopy(midi)
185 |
186 | scale = ticks_per_beat / midi.ticks_per_beat
187 |
188 | def process_continuous_events(elements):
189 | for el in elements:
190 | el.start = int(scale * el.start)
191 | el.end = int(scale * el.end)
192 |
193 | def process_instant_events(elements):
194 | for el in elements:
195 | el.time = int(scale * el.time)
196 |
197 | # resample MIDI events
198 | for track in midi.instruments:
199 | process_continuous_events(track.notes)
200 | if track.pedals:
201 | process_continuous_events(track.pedals)
202 | if track.control_changes:
203 | process_instant_events(track.control_changes)
204 | if track.pitch_bends:
205 | process_instant_events(track.pitch_bends)
206 |
207 | process_instant_events(midi.time_signature_changes)
208 | process_instant_events(midi.tempo_changes)
209 | process_instant_events(midi.key_signature_changes)
210 |
211 | midi.max_tick = max([max([note.end for note in track.notes]) for track in midi.instruments]) + 1
212 | return midi
213 |
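
A small example of the overlap cutting described above (values traced through the sort-cut-resort logic):

```python
from miditoolkit import Note

from scoreperformer.data.midi.utils import cut_overlapping_notes

notes = [
    Note(velocity=80, pitch=60, start=0, end=100),
    Note(velocity=80, pitch=60, start=50, end=150),  # overlaps the first note
]
cut_overlapping_notes(notes)
print([(n.start, n.end) for n in notes])  # [(0, 49), (50, 150)]
```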
--------------------------------------------------------------------------------
/scoreperformer/data/music_constants.py:
--------------------------------------------------------------------------------
1 | """ Music constants. """
2 |
3 | # note names and pitch <-> sitch (scientific pitch notation) maps
4 | NOTES_WSHARP = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
5 | NOTES_WFLAT = ['C', 'Db', 'D', 'Eb', 'E', 'F', 'Gb', 'G', 'Ab', 'A', 'Bb', 'B']
6 | NOTE_MAP = {sitch: i for i, sitch in enumerate(NOTES_WSHARP)}
7 | NOTE_INV_MAP = {i: sitch for sitch, i in NOTE_MAP.items()}
8 | NOTE_MAP.update(**{sitch: i for i, sitch in enumerate(NOTES_WFLAT)})
9 |
10 |
11 | def pitch2sitch(pitch):
12 | return NOTE_INV_MAP[pitch % 12] + str(pitch // 12 - 1)
13 |
14 |
15 | def sitch2pitch(sitch):
16 | note = sitch[:1 + int(sitch[1] in ("#", "b"))]
17 | octave = sitch[len(note):]
18 | return NOTE_MAP[note] + 12 * (int(octave) + 1)
19 |
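
A quick check of the pitch/sitch round trip (values follow from the maps above):

```python
from scoreperformer.data.music_constants import pitch2sitch, sitch2pitch

assert pitch2sitch(60) == 'C4'    # middle C; MIDI octave numbering starts at -1
assert sitch2pitch('C#4') == 61
assert sitch2pitch('Db4') == 61   # flat spellings map to the same pitch
assert sitch2pitch(pitch2sitch(21)) == 21  # round trip (A0)
```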
--------------------------------------------------------------------------------
/scoreperformer/data/tokenizers/__init__.py:
--------------------------------------------------------------------------------
1 | from scoreperformer.utils import ExplicitEnum
2 | from .classes import TokSequence, TokenizerConfig
3 | from .common import OctupleM
4 | from .spmuple import (
5 | SPMupleBase,
6 | SPMuple,
7 | SPMuple2,
8 |
9 | SPMupleOnset,
10 | SPMupleBeat,
11 | SPMupleBar,
12 | SPMupleWindow,
13 | SPMupleWindowRecompute
14 | )
15 |
16 |
17 | class TokenizerTypes(ExplicitEnum):
18 | OctupleM = "OctupleM"
19 | SPMuple = "SPMuple"
20 | SPMuple2 = "SPMuple2"
21 | SPMupleOnset = "SPMupleOnset"
22 | SPMupleBeat = "SPMupleBeat"
23 | SPMupleBar = "SPMupleBar"
24 | SPMupleWindow = "SPMupleWindow"
25 | SPMupleWindowRecompute = "SPMupleWindowRecompute"
26 |
27 |
28 | TOKENIZERS = {
29 | TokenizerTypes.OctupleM: OctupleM,
30 | TokenizerTypes.SPMuple: SPMuple,
31 | TokenizerTypes.SPMuple2: SPMuple2,
32 | TokenizerTypes.SPMupleOnset: SPMupleOnset,
33 | TokenizerTypes.SPMupleBeat: SPMupleBeat,
34 | TokenizerTypes.SPMupleBar: SPMupleBar,
35 | TokenizerTypes.SPMupleWindow: SPMupleWindow,
36 | TokenizerTypes.SPMupleWindowRecompute: SPMupleWindowRecompute
37 | }
38 |
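
The registry makes tokenizers resolvable by name, e.g. from a recipe/config string (a sketch; constructor arguments are omitted):

```python
from scoreperformer.data.tokenizers import TOKENIZERS, TokenizerTypes

# a config string resolves to an enum member, then to a tokenizer class
tok_cls = TOKENIZERS[TokenizerTypes("SPMupleBar")]
assert tok_cls.__name__ == "SPMupleBar"
```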
--------------------------------------------------------------------------------
/scoreperformer/data/tokenizers/classes.py:
--------------------------------------------------------------------------------
1 | """ Extended miditok classes. """
2 |
3 | from dataclasses import dataclass
4 | from typing import Optional, Dict, Sequence
5 |
6 | from miditok.classes import (
7 | TokSequence as MidiTokTokSequence,
8 | TokenizerConfig as MidiTokTokenizerConfig
9 | )
10 |
11 | from .constants import SPECIAL_TOKENS
12 |
13 |
14 | @dataclass
15 | class TokSequence(MidiTokTokSequence):
16 | meta: Optional[Dict[str, object]] = None
17 |
18 |
19 | class TokenizerConfig(MidiTokTokenizerConfig):
20 | r"""
21 | Tokenizer configuration class, extending :class:`miditok.TokenizerConfig`.
22 | :param special_tokens: list of special tokens. This must be given as a list of strings
23 | containing only the names of the tokens. (default: ``["PAD", "MASK", "SOS", "EOS"]``\)
24 | :param **kwargs: additional parameters that will be saved in `config.additional_params`.
25 | """
26 |
27 | def __init__(
28 | self,
29 | special_tokens: Sequence[str] = SPECIAL_TOKENS,
30 | **kwargs
31 | ):
32 | super().__init__(special_tokens=special_tokens, **kwargs)
33 |
--------------------------------------------------------------------------------
/scoreperformer/data/tokenizers/common/__init__.py:
--------------------------------------------------------------------------------
1 | from .octuple_m import OctupleM
2 |
--------------------------------------------------------------------------------
/scoreperformer/data/tokenizers/constants.py:
--------------------------------------------------------------------------------
1 | """ Tokenizer related constants. """
2 |
3 | # override MidiTok default special tokens for backward compatibility
4 | SPECIAL_TOKENS = ["PAD", "MASK", "SOS", "EOS"]
5 |
6 | PAD_TOKEN = "PAD_None"
7 | MASK_TOKEN = "MASK_None"
8 | SOS_TOKEN = "SOS_None"
9 | EOS_TOKEN = "EOS_None"
10 |
11 | TIME_DIVISION = 480
12 |
13 | SCORE_KEYS = [
14 | "Bar",
15 | "Position",
16 | "Pitch",
17 | "Velocity",
18 | "Duration",
19 | "Tempo",
20 | "TimeSig",
21 | "Program",
22 | "PositionShift",
23 | "NotesInOnset",
24 | "PositionInOnset"
25 | ]
26 | PERFORMANCE_KEYS = SCORE_KEYS + [
27 | "OnsetDev",
28 | "PerfDuration",
29 | "RelOnsetDev",
30 | "RelPerfDuration"
31 | ]
32 |
--------------------------------------------------------------------------------
/scoreperformer/data/tokenizers/midi_tokenizer.py:
--------------------------------------------------------------------------------
1 | """ MIDI encoding base class and methods. """
2 |
3 | from abc import ABC
4 |
5 | from miditok import MIDITokenizer as _MIDITokenizer
6 | from miditok.constants import TIME_SIGNATURE
7 | from miditok.utils import remove_duplicated_notes, merge_same_program_tracks
8 | from miditoolkit import MidiFile, TimeSignature
9 |
10 |
11 | class MIDITokenizer(_MIDITokenizer, ABC):
12 | r"""MIDI tokenizer base class, containing common methods and attributes for all tokenizers.
13 |
14 | See :class:`miditok.MIDITokenizer` for a detailed documentation.
15 | """
16 |
17 | def preprocess_midi(self, midi: MidiFile):
18 | r"""Pre-process (in place) a MIDI file to quantize its time and note attributes
19 | before tokenizing it. Its note attributes (times, pitches, velocities) will be
20 | quantized and sorted, and duplicated notes and tempo changes will be removed. Empty tracks
21 | (with no note) will be removed from the MIDI object. Notes with pitches outside
22 | of self.pitch_range will be deleted.
23 |
24 | :param midi: MIDI object to preprocess.
25 | """
26 | # Merge instruments of the same program / inst before preprocessing them
27 | # This avoids potential duplicated notes in some multitrack settings
28 | if self.config.use_programs and self.one_token_stream:
29 | merge_same_program_tracks(midi.instruments)
30 |
31 | t = 0
32 | while t < len(midi.instruments):
33 | # quantize notes attributes
34 | self._quantize_notes(midi.instruments[t].notes, midi.ticks_per_beat)
35 | # sort notes
36 | midi.instruments[t].notes.sort(key=lambda x: (x.start, x.pitch, x.end))
37 | # remove possible duplicated notes
38 | if self.config.additional_params.get("remove_duplicates", False):
39 | remove_duplicated_notes(midi.instruments[t].notes)
40 | if len(midi.instruments[t].notes) == 0:
41 | del midi.instruments[t]
42 | continue
43 |
44 | # Quantize sustain pedal and pitch bend
45 | if self.config.use_sustain_pedals:
46 | self._quantize_sustain_pedals(
47 | midi.instruments[t].pedals, midi.ticks_per_beat
48 | )
49 | if self.config.use_pitch_bends:
50 | self._quantize_pitch_bends(
51 | midi.instruments[t].pitch_bends, midi.ticks_per_beat
52 | )
53 | t += 1
54 |
55 | # Recalculate max_tick as this could have changed after notes quantization
56 | if len(midi.instruments) > 0:
57 | midi.max_tick = max(
58 | [max([note.end for note in track.notes]) for track in midi.instruments]
59 | )
60 |
61 | if self.config.use_tempos:
62 | self._quantize_tempos(midi.tempo_changes, midi.ticks_per_beat)
63 |
64 | if len(midi.time_signature_changes) == 0: # can sometimes happen
65 | midi.time_signature_changes.append(
66 | TimeSignature(*TIME_SIGNATURE, 0)
67 | ) # 4/4 by default in this case
68 | if self.config.use_time_signatures:
69 | self._quantize_time_signatures(
70 | midi.time_signature_changes, midi.ticks_per_beat
71 | )
72 |
--------------------------------------------------------------------------------
/scoreperformer/data/tokenizers/spmuple/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import SPMupleBase
2 | from .encodings import (
3 | SPMupleOnset,
4 | SPMupleBeat,
5 | SPMupleBar,
6 | SPMupleWindow,
7 | SPMupleWindowRecompute
8 | )
9 | from .spmuple import SPMuple
10 | from .spmuple2 import SPMuple2
11 |
--------------------------------------------------------------------------------
/scoreperformer/data/tokenizers/spmuple/base.py:
--------------------------------------------------------------------------------
1 | """ SPMuple (ScorePerformanceMusic-tuple) encoding for aligned score-performance music. """
2 |
3 | from abc import abstractmethod
4 | from typing import List, Optional, Union, Any
5 |
6 | import numpy as np
7 | from miditok.midi_tokenizer import _in_as_seq
8 | from miditoolkit import MidiFile, Note
9 |
10 | from ..classes import TokSequence
11 | from ..common import OctupleM
12 | from ..constants import TIME_DIVISION, SCORE_KEYS
13 |
14 |
15 | class SPMupleBase(OctupleM):
16 | r"""SPMupleBase: a base class for a family of ScorePerformanceMusic-tuple encodings.
17 |
18 | An extended OctupleM encoding with performance-specific tokens for performance MIDIs.
19 | """
20 |
21 | def _tweak_config_before_creating_voc(self):
22 | super()._tweak_config_before_creating_voc()
23 |
24 | # token vocabulary bins
25 | self.config.additional_params["token_bins"] = self.config.additional_params.get("token_bins", {})
26 |
27 | # midi postprocessing
28 | self.config.additional_params["cut_overlapping_notes"] = True
29 |
30 | def preprocess_midi(self, midi: MidiFile, is_score: bool = True):
31 | r"""Preprocess a MIDI file to be used by SPMuple encoding.
32 |
33 | :param midi: MIDI object to preprocess
34 | :param is_score: whether MIDI object is a score MIDI or not
35 | """
36 | super().preprocess_midi(midi)
37 |
38 | def preprocess_score_midi(self, midi: MidiFile):
39 | r"""Preprocess a score MIDI file to be used by SPMuple encoding.
40 |
41 | :param midi: MIDI object to preprocess
42 | """
43 | self.preprocess_midi(midi, is_score=True)
44 |
45 | def preprocess_performance_midi(self, midi: MidiFile):
46 | r"""Preprocess a performance MIDI file to be used by SPMuple encoding.
47 |
48 | :param midi: MIDI object to preprocess
49 | """
50 | self.preprocess_midi(midi, is_score=False)
51 |
52 | def score_midi_to_tokens(self, midi: MidiFile) -> TokSequence:
53 | r"""Converts a MIDI file to a score tokens representation, a sequence of "time steps" of tokens.
54 |
55 | A time step is a list of tokens where:
56 | (list index: token type)
57 | 0: Bar
58 | 1: Position
59 | 2: Pitch
60 | 3: Velocity
61 | 4: Duration
62 | (5: Tempo)
63 | (6: TimeSignature)
64 | (7: Program)
65 |
66 | :param midi: the MIDI object to convert
67 | :return: a :class:`miditok.TokSequence`.
68 | """
69 | return super().midi_to_tokens(midi)
70 |
71 | def performance_midi_to_tokens(
72 | self,
73 | midi: MidiFile,
74 | score_tokens: TokSequence,
75 | alignment: Optional[np.ndarray] = None
76 | ) -> TokSequence:
77 | r"""Tokenizes a performance MIDI file in to :class:`miditok.TokSequence`.
78 |
79 | :param midi: the MIDI object to convert.
80 | :param score_tokens: corresponding score tokens :class:`miditok.TokSequence`.
81 | :param alignment: optional alignment between performance and score tokens.
82 | :return: a :class:`miditok.TokSequence`.
83 | """
84 | # Check if the durations values have been calculated before for this time division
85 | if midi.ticks_per_beat not in self._durations_ticks:
86 | self._durations_ticks[midi.ticks_per_beat] = np.array(
87 | [
88 | (beat * res + pos) * midi.ticks_per_beat // res
89 | for beat, pos, res in self.durations
90 | ]
91 | )
92 |
93 | # Preprocess the MIDI file
94 | self.preprocess_performance_midi(midi)
95 |
96 | # Register MIDI metadata
97 | self._current_midi_metadata = {
98 | "time_division": midi.ticks_per_beat,
99 | "max_tick": midi.max_tick,
100 | "tempo_changes": midi.tempo_changes,
101 | "time_sig_changes": midi.time_signature_changes,
102 | "key_sig_changes": midi.key_signature_changes,
103 | }
104 |
105 | tokens = self._performance_midi_to_tokens(midi, score_tokens, alignment)
106 |
107 | return tokens
108 |
109 | @abstractmethod
110 | def _performance_midi_to_tokens(
111 | self,
112 | midi: MidiFile,
113 | score_tokens: TokSequence,
114 | alignment: Optional[np.ndarray] = None
115 | ) -> TokSequence:
116 | r"""Converts a MIDI file to a performance tokens representation, a sequence of "time steps"
117 | of score tokens stacked with performance specific features (e.g., OnsetDeviation).
118 |
119 | :param midi: the MIDI object to convert.
120 | :param score_tokens: corresponding score tokens :class:`miditok.TokSequence`.
121 | :param alignment: optional alignment between performance and score tokens.
122 | :return: the performance token representation, i.e. tracks converted into sequences of tokens
123 | """
124 | ...
125 |
126 | @_in_as_seq()
127 | def score_tokens_to_midi(
128 | self,
129 | tokens: Union[TokSequence, List, np.ndarray, Any],
130 | output_path: Optional[str] = None,
131 | time_division: int = TIME_DIVISION,
132 | ) -> MidiFile:
133 | r"""Converts score tokens (:class:`miditok.TokSequence`) into a MIDI and saves it.
134 |
135 |         :param tokens: tokens to convert, either a :class:`miditok.TokSequence`, a list, a numpy array, or a tensor.
136 | :param output_path: path to save the file. (default: None)
137 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI to create).
138 | :return: the midi object (:class:`miditoolkit.MidiFile`).
139 | """
140 | return self.tokens_to_midi(tokens, output_path=output_path, time_division=time_division)
141 |
142 | @abstractmethod
143 | @_in_as_seq()
144 | def performance_tokens_to_midi(
145 | self,
146 | tokens: Union[TokSequence, List, np.ndarray, Any],
147 | output_path: Optional[str] = None,
148 | time_division: int = TIME_DIVISION,
149 | ) -> MidiFile:
150 | r"""Converts performance tokens (:class:`miditok.TokSequence`) into a MIDI and saves it.
151 |
152 |         :param tokens: tokens to convert, either a :class:`miditok.TokSequence`, a list, a numpy array, or a tensor.
153 | :param output_path: path to save the file. (default: None)
154 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI to create).
155 | :return: the midi object (:class:`miditoolkit.MidiFile`).
156 | """
157 | ...
158 |
159 | @abstractmethod
160 | @_in_as_seq()
161 | def score_tokens_as_performance(self, score_tokens: Union[TokSequence, List, np.ndarray, Any]) -> TokSequence:
162 | r"""Converts a sequence of score tokens into a sequence of performance tokens,
163 | the tokens corresponding to a deadpan performance with no variation from score notes.
164 | """
165 | ...
166 |
167 | def _quantize_notes(self, notes: List[Note], time_division: int, is_score: bool = True):
168 | r"""Quantize the notes attributes: their pitch, velocity, start and end values.
169 | It shifts the notes so that they start at times that match the time resolution
170 | (e.g. 16 samples per bar).
171 | Notes with pitches outside of self.pitch_range will be deleted.
172 |
173 | :param notes: notes to quantize.
174 | :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed).
175 | :param is_score: whether the notes are from score MIDI or not
176 | """
177 | super()._quantize_notes(notes, time_division)
178 |
179 | @abstractmethod
180 | def _create_base_vocabulary(self) -> List[List[str]]:
181 | r"""Creates the vocabulary, as a list of string tokens.
182 |
183 | :return: the vocabulary as a list of string.
184 | """
185 | return super()._create_base_vocabulary()
186 |
187 | @abstractmethod
188 | def _get_token_types(self) -> List[str]:
189 | r"""Creates an ordered list of available token types."""
190 | return super()._get_token_types()
191 |
192 | @property
193 | def score_sizes(self):
194 | return {
195 | key: value for key, value in self.sizes.items()
196 | if key in SCORE_KEYS
197 | }
198 |
199 | @property
200 | def performance_sizes(self):
201 | return self.sizes
202 |
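
Taken together, these methods define a symmetric score/performance tokenization API: the score is tokenized first, and performance tokens are built relative to it. A minimal usage sketch, assuming a concrete SPMuple subclass (here `SPMupleBar` from `encodings.py` below, constructed with its defaults) and illustrative file paths:

```python
from miditoolkit import MidiFile

from scoreperformer.data.tokenizers.spmuple.encodings import SPMupleBar

tokenizer = SPMupleBar()  # any concrete SPMuple subclass; default config assumed

score_midi = MidiFile("score.mid")       # illustrative paths
perf_midi = MidiFile("performance.mid")

# score first, then the performance aligned against it
score_tokens = tokenizer.score_midi_to_tokens(score_midi)
perf_tokens = tokenizer.performance_midi_to_tokens(perf_midi, score_tokens)

# round-trip both representations back to MIDI files
tokenizer.score_tokens_to_midi(score_tokens, output_path="score_rt.mid")
tokenizer.performance_tokens_to_midi(perf_tokens, output_path="perf_rt.mid")
```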
--------------------------------------------------------------------------------
/scoreperformer/data/tokenizers/spmuple/encodings.py:
--------------------------------------------------------------------------------
1 | from .spmuple import SPMuple
2 | from .spmuple2 import SPMuple2
3 |
4 |
5 | class SPMupleOnset(SPMuple2):
6 | def _tweak_config_before_creating_voc(self):
7 | super()._tweak_config_before_creating_voc()
8 |
9 | self.config.additional_params["use_position_shifts"] = True
10 | self.config.additional_params["use_onset_indices"] = True
11 |
12 | self.config.additional_params["onset_tempos"] = True
13 |
14 |
15 | class SPMupleBeat(SPMuple):
16 | def _tweak_config_before_creating_voc(self):
17 | super()._tweak_config_before_creating_voc()
18 |
19 | self.config.additional_params["use_position_shifts"] = True
20 | self.config.additional_params["use_onset_indices"] = True
21 | self.config.additional_params["rel_onset_dev"] = True
22 | self.config.additional_params["rel_perf_duration"] = True
23 |
24 | self.config.additional_params["bar_tempos"] = False
25 |
26 |
27 | class SPMupleBar(SPMuple):
28 | def _tweak_config_before_creating_voc(self):
29 | super()._tweak_config_before_creating_voc()
30 |
31 | self.config.additional_params["use_position_shifts"] = True
32 | self.config.additional_params["use_onset_indices"] = True
33 | self.config.additional_params["rel_onset_dev"] = True
34 | self.config.additional_params["rel_perf_duration"] = True
35 |
36 | self.config.additional_params["bar_tempos"] = True
37 |
38 |
39 | class SPMupleWindow(SPMuple2):
40 | def _tweak_config_before_creating_voc(self):
41 | super()._tweak_config_before_creating_voc()
42 |
43 | self.config.additional_params["use_position_shifts"] = True
44 | self.config.additional_params["use_onset_indices"] = True
45 |
46 | self.config.additional_params["use_quantized_tempos"] = True
47 | self.config.additional_params["decode_recompute_tempos"] = False
48 |
49 |
50 | class SPMupleWindowRecompute(SPMuple2):
51 | def _tweak_config_before_creating_voc(self):
52 | super()._tweak_config_before_creating_voc()
53 |
54 | self.config.additional_params["use_position_shifts"] = True
55 | self.config.additional_params["use_onset_indices"] = True
56 |
57 | self.config.additional_params["use_quantized_tempos"] = self.config.additional_params.get(
58 | "use_quantized_tempos", True
59 | )
60 | self.config.additional_params["decode_recompute_tempos"] = True
61 |
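
The variants above differ only in the `additional_params` flags they force before vocabulary creation. A small sketch that surfaces the effective flags per class, assuming the default constructors work without arguments:

```python
FLAGS = (
    "use_position_shifts", "use_onset_indices", "rel_onset_dev",
    "rel_perf_duration", "bar_tempos", "onset_tempos",
    "use_quantized_tempos", "decode_recompute_tempos",
)

for cls in (SPMupleOnset, SPMupleBeat, SPMupleBar, SPMupleWindow, SPMupleWindowRecompute):
    params = cls().config.additional_params
    print(cls.__name__, {k: params[k] for k in FLAGS if k in params})
```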
--------------------------------------------------------------------------------
/scoreperformer/experiments/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer
2 | from .trainer_config import TrainerConfig
3 |
--------------------------------------------------------------------------------
/scoreperformer/experiments/components.py:
--------------------------------------------------------------------------------
1 | """ A centralized Experiment Components initialization. """
2 |
3 | import copy
4 | import os
5 | from dataclasses import dataclass, fields
6 | from typing import Union, Optional
7 |
8 | from omegaconf import DictConfig, OmegaConf
9 | from torch.utils.data import Dataset
10 |
11 | from scoreperformer.data import DATASETS, COLLATORS
12 | from scoreperformer.data.tokenizers.constants import MASK_TOKEN
13 | from scoreperformer.models import MODELS, EVALUATORS
14 | from scoreperformer.modules.constructor import ModuleConfig
15 | from scoreperformer.utils.config import disable_nodes, Config
16 | from .trainer_config import TrainerConfig
17 |
18 |
19 | @dataclass
20 | class ExperimentConfig(Config):
21 | data: DictConfig
22 | model: Union[DictConfig, ModuleConfig]
23 | trainer: Union[DictConfig, TrainerConfig]
24 | evaluator: Optional[DictConfig] = None
25 |
26 |
27 | ExperimentConfigFields = list(map(lambda f: f.name, fields(ExperimentConfig)))
28 |
29 |
30 | def resolve_config_hierarchy(config: DictConfig, config_root: Optional[str] = None):
31 | if "base" in config and config.base is not None:
32 | base_configs_names = [config.base] if isinstance(config.base, str) else config.base
33 |
34 | base_configs = []
35 | for base_config_name in base_configs_names:
36 | config_root = config_root or './'
37 | base_config_path = os.path.join(config_root, base_config_name)
38 | base_config = OmegaConf.load(base_config_path)
39 | base_config = resolve_config_hierarchy(base_config, config_root=config_root)
40 | base_configs.append(base_config)
41 |
42 | del config["base"]
43 |
44 | config = OmegaConf.merge(*base_configs[::-1], config)
45 |
46 | return config
47 |
48 |
49 | class ExperimentComponents:
50 | def __init__(self, config: Union[ExperimentConfig, str], config_root: Optional[str] = None):
51 | if isinstance(config, str):
52 | config = os.path.join(config_root, config) if config_root is not None else config
53 | config = OmegaConf.load(config)
54 | if not isinstance(config, DictConfig):
55 | config = OmegaConf.load(config)
56 |
57 | config = resolve_config_hierarchy(config, config_root=config_root)
58 |
59 | assert all([key in config for key in ExperimentConfigFields[:3]]), \
60 | f"ExperimentConfig is missing one of the keys: {ExperimentConfigFields[:3]}"
61 |
62 | disable_nodes(config)
63 | OmegaConf.resolve(config)
64 |
65 | self.config = ExperimentConfig(
66 | data=config.data,
67 | model=config.model,
68 | trainer=config.trainer,
69 | evaluator=config.get("evaluator", None)
70 | )
71 |
72 | self.train_dataset = None
73 | self.eval_dataset = None
74 | self.collator = None
75 | self.model = None
76 | self.evaluator = None
77 |
78 | def init_components(self):
79 | self.init_datasets()
80 | self.init_collator()
81 | self.init_model()
82 | self.init_evaluator()
83 |
84 | return self.model, self.train_dataset, self.eval_dataset, self.collator, self.evaluator
85 |
86 | def init_datasets(self):
87 | cfg = self.config.data.dataset
88 | self.train_dataset = build_dataset(cfg, split=cfg._splits_.train) if cfg._splits_.train is not None else None
89 | self.eval_dataset = build_dataset(cfg, split=cfg._splits_.eval) if cfg._splits_.eval is not None else None
90 |
91 | return self.train_dataset, self.eval_dataset
92 |
93 | def init_collator(self):
94 | dataset = self.train_dataset or self.eval_dataset
95 | assert dataset is not None
96 |
97 | cfg = self.config.data.collator
98 | cfg.mask_token_id = dataset.tokenizer[0, MASK_TOKEN]
99 | self.collator = build_collator(cfg)
100 |
101 | return self.collator
102 |
103 | def init_model(self, inject_data: bool = True):
104 | cfg = self.config.model
105 |
106 | dataset = None
107 | if inject_data:
108 | dataset = self.train_dataset or self.eval_dataset
109 | assert dataset is not None
110 |
111 | self.model = build_model(cfg, dataset=dataset)
112 |
113 | return self.model
114 |
115 | def init_evaluator(self):
116 | assert self.model is not None
117 | dataset = self.train_dataset or self.eval_dataset
118 |
119 | cfg = self.config.evaluator
120 | self.evaluator = build_evaluator(
121 | cfg, model=self.model, tokenizer=dataset.tokenizer
122 | )
123 |
124 | return self.evaluator
125 |
126 |
127 | def build_dataset(config, split='train', eval_mode=False):
128 | if config._name_ in DATASETS:
129 | config = copy.deepcopy(config)
130 | config.sample = config.sample and split in ('train', 'all') and not eval_mode
131 |
132 | dataset_cls = DATASETS[config._name_]
133 | config = {key: value for key, value in config.items() if key[0] != "_"}
134 | dataset = dataset_cls(split=split, **config)
135 | return dataset
136 | else:
137 | raise ValueError(
138 | f"Invalid dataset type: {config._name_}. Supported types: {list(DATASETS.keys())}"
139 | )
140 |
141 |
142 | def build_collator(config):
143 | if config._name_ in COLLATORS:
144 | config = copy.deepcopy(config)
145 |
146 | collator_cls = COLLATORS[config._name_]
147 | config = {key: value for key, value in config.items() if key[0] != "_"}
148 | collator = collator_cls(**config)
149 | return collator
150 | else:
151 | raise ValueError(
152 | f"Invalid data collator type: {config._name_}. Supported types: {list(COLLATORS.keys())}"
153 | )
154 |
155 |
156 | def build_model(config, *, dataset: Optional[Dataset] = None, **kwargs):
157 | if config._name_ in MODELS:
158 | model_cls = MODELS[config._name_]
159 | model_cls.inject_data_config(config, dataset)
160 | model = model_cls.init(config, **kwargs)
161 | model_cls.cleanup_config(config)
162 | return model
163 | else:
164 | raise ValueError(
165 | f"Invalid model type: {config._name_}. Supported types: {list(MODELS.keys())}"
166 | )
167 |
168 |
169 | def build_evaluator(config, **kwargs):
170 | if config is not None and config._name_ in EVALUATORS:
171 | evaluator_cls = EVALUATORS[config._name_]
172 | config = {key: value for key, value in config.items() if key[0] != "_"}
173 | config.update(**kwargs)
174 | return evaluator_cls(**config)
175 | else:
176 | return None
177 |
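
End to end, `ExperimentComponents` turns a YAML config into datasets, a collator, a model, and an optional evaluator. Note the merge order in `resolve_config_hierarchy`: `OmegaConf.merge(*base_configs[::-1], config)` means the child config always wins, and earlier entries in its `base` list take precedence over later ones. A hypothetical sketch (the recipe path is illustrative):

```python
components = ExperimentComponents("scoreperformer/base.yaml", config_root="recipes")
model, train_dataset, eval_dataset, collator, evaluator = components.init_components()
```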
--------------------------------------------------------------------------------
/scoreperformer/experiments/integrations.py:
--------------------------------------------------------------------------------
1 | """
2 | External Trainer integrations.
3 |
4 | Adapted from: https://github.com/huggingface/transformers/blob/main/src/transformers/integrations.py
5 | """
6 | import json
7 |
8 | from omegaconf import OmegaConf
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | from .callbacks import TrainerCallback
12 |
13 |
14 | class TensorBoardCallback(TrainerCallback):
15 | """
16 | A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).
17 | Args:
18 | tb_writer (`SummaryWriter`, *optional*):
19 | The writer to use. Will instantiate one if not set.
20 | """
21 |
22 | def __init__(self, tb_writer=None):
23 | self.tb_writer = tb_writer
24 |
25 | def _init_summary_writer(self, config, log_dir=None):
26 | log_dir = log_dir or config.log_dir
27 | self.tb_writer = SummaryWriter(log_dir=log_dir)
28 |
29 | def on_train_begin(self, config, state, control, **kwargs):
30 | if self.tb_writer is None:
31 | self._init_summary_writer(config)
32 |
33 | if self.tb_writer is not None:
34 | self.tb_writer.add_text("trainer_config", config.to_json_string())
35 | exp_config = kwargs.get("exp_config", None)
36 | if exp_config is not None:
37 | def to_json_string(cfg):
38 | return json.dumps(OmegaConf.to_container(cfg, resolve=True), indent=2)
39 | self.tb_writer.add_text("data_config", to_json_string(exp_config.data))
40 | self.tb_writer.add_text("model_config", to_json_string(exp_config.model))
41 |
42 | def on_log(self, args, state, control, logs=None, **kwargs):
43 | if self.tb_writer is None:
44 | self._init_summary_writer(args)
45 |
46 | if self.tb_writer is not None:
47 | for k, v in logs.items():
48 | if isinstance(v, (int, float)):
49 | self.tb_writer.add_scalar(k, v, state.global_step)
50 | self.tb_writer.flush()
51 |
52 | def on_train_end(self, args, state, control, **kwargs):
53 | if self.tb_writer:
54 | self.tb_writer.close()
55 | self.tb_writer = None
56 |
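
A short sketch of priming the callback with a pre-built writer; if none is given, one is created lazily from `config.log_dir` on the first event. How the callback is attached to the `Trainer` is assumed, not shown here:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/debug")      # illustrative directory
callback = TensorBoardCallback(tb_writer=writer)
```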
--------------------------------------------------------------------------------
/scoreperformer/experiments/logging/__init__.py:
--------------------------------------------------------------------------------
1 | from .console_logger import setup_logger
2 |
--------------------------------------------------------------------------------
/scoreperformer/experiments/logging/console_logger.py:
--------------------------------------------------------------------------------
1 | """ Console and File Logger. """
2 | from typing import Optional
3 |
4 | from loguru import logger
5 |
6 |
7 | def setup_logger(file: Optional[str] = None, **extra):
8 | import sys
9 |
10 | time = "{time:YYYY-MM-DD HH:mm:ss}"
11 | level = "{level:<7}"
12 | process = "{extra[process]}"
13 | node_info = "{extra[node_info]}"
14 | module = "{name}:{function}:{line}"
15 | message = "{message}"
16 |
17 | # formatter = f"{time} {level} {module} - {message_level:6s} {message}"
18 | # formatter = f"{time} {level} - {message_level:6s} {message}"
19 | # formatter = f"{time} {level} {module} - {message}"
20 | formatter = f"{time} {level} - {message}"
21 | handlers = [dict(sink=sys.stdout, format=formatter, enqueue=True)]
22 | if file is not None:
23 | handlers.append(dict(sink=file, format=formatter, enqueue=True))
24 |
25 | logger.configure(
26 | handlers=handlers,
27 | extra=extra
28 | )
29 |
30 | return logger
31 |
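
Usage sketch (the file name is illustrative): one call configures both sinks. The `extra` kwargs are stored on the logger; the active format prints only time, level, and message, but `{extra[...]}` fields such as the `process` and `node_info` placeholders defined above become usable if the format string is extended:

```python
log = setup_logger(file="train.log", process="trainer", node_info="rank0")
log.info("training started")  # written to stdout and train.log
```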
--------------------------------------------------------------------------------
/scoreperformer/experiments/optimizers.py:
--------------------------------------------------------------------------------
1 | """ Optimizers and Learning Rate Schedulers """
2 |
3 | from dataclasses import dataclass, field
4 | from typing import Optional, Union, List, Dict
5 |
6 | import torch
7 | from torch.cuda import amp
8 |
9 | _optimizers = {
10 | "sgd": torch.optim.SGD,
11 | "adam": torch.optim.Adam,
12 | "adamw": torch.optim.AdamW
13 | }
14 |
15 |
16 | def switch_optimizer(optimizer_name: str):
17 | optimizer = getattr(torch.optim, optimizer_name, None)
18 | if optimizer is not None:
19 | return optimizer
20 |
21 | if optimizer_name in _optimizers:
22 | optimizer = _optimizers[optimizer_name]
23 | else:
24 | raise ValueError(
25 | f"There is no such name for optimizers: {optimizer_name}. "
26 | f"Valid optimizers: {list(_optimizers.keys())}"
27 | )
28 |
29 | return optimizer
30 |
31 |
32 | def get_optimizer(
33 | optimizer_name: str,
34 | optimizer_params: dict,
35 | lr: float,
36 | model: torch.nn.Module = None,
37 | parameters: List = None,
38 | ) -> torch.optim.Optimizer:
39 | """Find, initialize and return an optimizer.
40 | Args:
41 | optimizer_name (str): Optimizer name.
42 | optimizer_params (dict): Optimizer parameters.
43 | lr (float): Initial learning rate.
44 | model (torch.nn.Module): Model to pass to the optimizer.
45 | Returns:
46 | torch.optim.Optimizer: Functional optimizer.
47 | """
48 | optimizer = switch_optimizer(optimizer_name)
49 | if model is not None:
50 | parameters = model.parameters()
51 | return optimizer(parameters, lr=lr, **optimizer_params)
52 |
53 |
54 | _lr_schedulers = {
55 | "plateau": torch.optim.lr_scheduler.ReduceLROnPlateau,
56 | "exponential": torch.optim.lr_scheduler.ExponentialLR,
57 | }
58 |
59 |
60 | def switch_lr_scheduler(lr_scheduler_name: str):
61 | lr_scheduler = getattr(torch.optim, lr_scheduler_name, None)
62 | if lr_scheduler is not None:
63 | return lr_scheduler
64 |
65 | if lr_scheduler_name in _lr_schedulers:
66 | lr_scheduler = _lr_schedulers[lr_scheduler_name]
67 | else:
68 | raise ValueError(
69 | f"There is no such name for lr_schedulers: {lr_scheduler_name}. "
70 | f"Valid lr_schedulers: {list(_lr_schedulers.keys())}"
71 | )
72 |
73 | return lr_scheduler
74 |
75 |
76 | def get_lr_scheduler(
77 | lr_scheduler: str, lr_scheduler_params: Dict, optimizer: torch.optim.Optimizer
78 | ) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access
79 | """Find, initialize and return a scheduler.
80 | Args:
81 | lr_scheduler (str): Scheduler name.
82 | lr_scheduler_params (Dict): Scheduler parameters.
83 | optimizer (torch.optim.Optimizer): Optimizer to pass to the scheduler.
84 | Returns:
85 | torch.optim.lr_scheduler._LRScheduler: Functional scheduler.
86 | """
87 | if lr_scheduler is None:
88 | return None
89 | scheduler = switch_lr_scheduler(lr_scheduler)
90 | return scheduler(optimizer, **lr_scheduler_params)
91 |
92 |
93 | @dataclass
94 | class OptimizerConfig:
95 | lr: Union[float, List[float]] = field(
96 | default=0.001, metadata={"help": "Learning rate for each optimizer. Defaults to 0.001"}
97 | )
98 | optimizer: Union[str, List[str]] = field(
99 | default="Adam", metadata={"help": "Optimizer(s) to use. Defaults to None"}
100 | )
101 | optimizer_params: Union[Dict, List[Dict]] = field(
102 | default_factory=dict, metadata={"help": "Optimizer(s) arguments. Defaults to {}"}
103 | )
104 | lr_scheduler: Union[str, List[str]] = field(
105 | default=None, metadata={"help": "Learning rate scheduler(s) to use. Defaults to None"}
106 | )
107 | lr_scheduler_params: Dict = field(
108 | default_factory=dict, metadata={"help": "Learning rate scheduler(s) arguments. Defaults to {}"}
109 | )
110 | grad_clip: Optional[float] = field(
111 | default=None, metadata={"help": "Gradient clipping threshold. Defaults to None"}
112 | )
113 | grad_accum_steps: Optional[int] = field(
114 | default=1, metadata={"help": "Gradient accumulation steps. Defaults to 1"}
115 | )
116 | mixed_precision: bool = field(
117 | default=False, metadata={"help": "Use mixed precision for training. Defaults to False"}
118 | )
119 |
120 |
121 | class Optimizer:
122 | """ Combined Optimizer and Scheduler class. """
123 |
124 | def __init__(
125 | self,
126 | config: OptimizerConfig,
127 | model: torch.nn.Module = None,
128 | parameters: List = None
129 | ):
130 | self.config = config
131 |
132 | self.optimizer = get_optimizer(
133 | config.optimizer,
134 | config.optimizer_params or {},
135 | config.lr,
136 | model=model,
137 | parameters=parameters
138 | )
139 |
140 | self.lr_scheduler = get_lr_scheduler(
141 | config.lr_scheduler,
142 | config.lr_scheduler_params,
143 | optimizer=self.optimizer
144 | )
145 |
146 | self._scaler = amp.GradScaler(enabled=config.mixed_precision)
147 | self.grad_clip = config.grad_clip
148 | self.grad_accum_steps = config.grad_accum_steps
149 | self.mixed_precision = config.mixed_precision
150 |
151 | def step(self, loss_value, step_optimizer=True):
152 | self._scaler.scale(loss_value / self.grad_accum_steps).backward()
153 |
154 | grad_norm = None
155 | if step_optimizer:
156 | self._scaler.unscale_(self.optimizer)
157 |
158 | if self.grad_clip is not None:
159 | parameters = self.optimizer.param_groups[0]["params"]
160 | grad_norm = torch.nn.utils.clip_grad_norm_(parameters, self.grad_clip)
161 | if torch.isnan(grad_norm) or torch.isinf(grad_norm):
162 | grad_norm = None
163 |
164 | self._scaler.step(self.optimizer)
165 | self._scaler.update()
166 |
167 | self.optimizer.zero_grad()
168 |
169 | return grad_norm
170 |
171 | def anneal_on_epoch_end(self, *args):
172 | if self.lr_scheduler is not None:
173 | args = args if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) else []
174 | self.lr_scheduler.step(*args)
175 |
176 | def anneal_on_iter_end(self, *args):
177 | if self.lr_scheduler is not None:
178 | self.lr_scheduler.step()
179 |
180 | def get_last_lr(self):
181 | return self.lr_scheduler.get_last_lr()
182 |
183 | def load_state_dict(self, state_dict, restore_lr=True):
184 | self.optimizer.load_state_dict(state_dict["optimizer"])
185 |
186 | if restore_lr:
187 | lr_scheduler_params = state_dict.get("lr_scheduler", None)
188 | if lr_scheduler_params is not None:
189 | self.lr_scheduler.load_state_dict(lr_scheduler_params)
190 | else:
191 | base_lr = self.lr_scheduler.get_last_lr()
192 | for lr, param_group in zip(base_lr, self.optimizer.param_groups):
193 | param_group["lr"] = lr
194 |
195 | def state_dict(self):
196 | return {
197 | "optimizer": self.optimizer.state_dict(),
198 | "lr_scheduler": self.lr_scheduler.state_dict()
199 | }
200 |
201 | def set_progress(self, iteration, epoch):
202 | for key, value in self.optimizer.state.items():
203 | if "step" in value:
204 | self.optimizer.state[key]["step"] = iteration
205 |
206 | self.lr_scheduler.last_epoch = epoch
207 | self.lr_scheduler._step_count = epoch + 1
208 |
209 | def __repr__(self):
210 | return f"Optimizer(optimizer={self.optimizer}, lr_scheduler={self.lr_scheduler})"
211 |
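
A minimal training-step sketch for the combined class, using a toy regression model (the real Trainer loop differs; names here are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(10, 2)
config = OptimizerConfig(
    lr=1e-3, optimizer="adamw",
    lr_scheduler="exponential", lr_scheduler_params={"gamma": 0.99},
    grad_clip=1.0,
)
optim = Optimizer(config, model=model)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = F.mse_loss(model(x), y)
grad_norm = optim.step(loss)  # backward + unscale + clip + optimizer step
optim.anneal_on_epoch_end()   # ExponentialLR.step()
```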
--------------------------------------------------------------------------------
/scoreperformer/experiments/trainer_utils.py:
--------------------------------------------------------------------------------
1 | """ Trainer utility functions. """
2 |
3 | from pathlib import Path
4 | from typing import Union, List
5 |
6 | from omegaconf import ListConfig
7 |
8 | from scoreperformer.utils import ExplicitEnum
9 |
10 |
11 | def resolve_path(constructor: Union[Path, str, List[str]]):
12 | return Path(*constructor) if isinstance(constructor, (List, ListConfig)) else Path(constructor)
13 |
14 |
15 | class Accumulator:
16 | def __init__(self):
17 | self._sums = {}
18 | self._counts = {}
19 |
20 | def __getitem__(self, key):
21 | return self._sums[key] / self._counts[key]
22 |
23 | @property
24 | def sums(self):
25 | return self._sums
26 |
27 | @property
28 | def counts(self):
29 | return self._counts
30 |
31 | @property
32 | def mean_values(self):
33 | return {key: self._sums[key] / self._counts[key] for key in self._sums if self._counts[key] > 0}
34 |
35 | def items(self):
36 | return self.mean_values.items()
37 |
38 | def add_value(self, name, value):
39 | self._sums[name] = value
40 | self._counts[name] = 1
41 |
42 | def update_value(self, name, value):
43 | if name not in self._sums:
44 | self.add_value(name, value)
45 | else:
46 | self._sums[name] += value
47 | self._counts[name] += 1
48 |
49 | def add_values(self, name_dict):
50 | for key, value in name_dict.items():
51 | self.add_value(key, value)
52 |
53 | def update_values(self, value_dict):
54 | for key, value in value_dict.items():
55 | self.update_value(key, value)
56 |
57 | def reset(self):
58 | for key in self._sums:
59 | self._sums[key] = 0
60 | self._counts[key] = 0
61 |
62 | def clear(self):
63 | self._sums = {}
64 | self._counts = {}
65 |
66 |
67 | class IntervalStrategy(ExplicitEnum):
68 | NO = "no"
69 | STEPS = "steps"
70 | EPOCH = "epoch"
71 |
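
`Accumulator` keeps running sums and counts per key and reports means on read; a toy example:

```python
acc = Accumulator()
acc.update_values({"loss": 0.9, "acc": 0.5})
acc.update_values({"loss": 0.7, "acc": 0.7})
print(acc.mean_values)  # {'loss': 0.8..., 'acc': 0.6...}
print(acc["loss"])      # sum / count for a single key
acc.reset()             # zeroes sums and counts but keeps the keys
```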
--------------------------------------------------------------------------------
/scoreperformer/inference/__init__.py:
--------------------------------------------------------------------------------
1 | from .generators import ScorePerformerGenerator
2 | from .messengers import SPMupleMessenger
3 |
--------------------------------------------------------------------------------
/scoreperformer/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Model
2 | from .scoreperformer import (
3 | Performer,
4 | ScorePerformer,
5 | ScorePerformerEvaluator
6 | )
7 |
8 | MODELS = {name: cls for name, cls in globals().items() if ".model." in str(cls)}
9 | EVALUATORS = {name: cls for name, cls in globals().items() if ".evaluator." in str(cls)}
10 |
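
These comprehensions key the registries by defining submodule: `str(cls)` renders as e.g. `<class 'scoreperformer.models.scoreperformer.model.ScorePerformer'>`, so classes defined in `model.py` files land in `MODELS` and classes defined in `evaluator.py` files in `EVALUATORS`:

```python
assert "ScorePerformer" in MODELS
assert "ScorePerformerEvaluator" in EVALUATORS
assert "Model" not in MODELS  # models.base.Model matches neither pattern
```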
--------------------------------------------------------------------------------
/scoreperformer/models/base.py:
--------------------------------------------------------------------------------
1 | """ Base Model class. """
2 |
3 | from abc import abstractmethod
4 | from typing import Optional, List, Dict, Union
5 |
6 | import torch
7 | import torch.nn as nn
8 | from loguru import logger
9 | from omegaconf import DictConfig, OmegaConf
10 | from torch import Tensor
11 | from torch.utils.data import Dataset
12 |
13 | from scoreperformer.modules.constructor import Constructor, ModuleConfig
14 |
15 |
16 | class Model(nn.Module, Constructor):
17 | @abstractmethod
18 | def forward(self, *args, **kwargs):
19 | ...
20 |
21 | @abstractmethod
22 | def prepare_inputs(self, inputs) -> Dict[str, Tensor]:
23 | ...
24 |
25 | @staticmethod
26 | def allocate_inputs(inputs_dict, device):
27 | return {key: value.to(device, non_blocking=True) for key, value in inputs_dict.items()}
28 |
29 | @staticmethod
30 | def inject_data_config(
31 | config: Optional[Union[DictConfig, ModuleConfig]],
32 | dataset: Optional[Dataset]
33 | ) -> Optional[Union[DictConfig, ModuleConfig]]:
34 | return config
35 |
36 | @staticmethod
37 | def cleanup_config(
38 | config: Optional[Union[DictConfig, ModuleConfig]]
39 | ) -> Optional[Union[DictConfig, ModuleConfig]]:
40 | return config
41 |
42 | @classmethod
43 | def from_pretrained(cls, checkpoint_path: str):
44 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
45 |
46 | model_cfg = OmegaConf.create(checkpoint['model']['config'])
47 | model = cls.init(model_cfg)
48 |
49 | state_dict = checkpoint['model']['state_dict']
50 | model.load_state_dict(state_dict, strict=True)
51 |
52 | return model
53 |
54 | def load(
55 | self,
56 | state_dict: Dict[str, Tensor],
57 | ignore_layers: Optional[List] = None,
58 | ignore_mismatched_keys: bool = False
59 | ):
60 | ignore_layers = ignore_layers or []
61 |
62 | model_state = self.state_dict()
63 |
64 | extra_keys = [k for k in state_dict.keys() if k not in model_state]
65 | if extra_keys:
66 | logger.warning(f"The following checkpoint keys are not presented in the model "
67 | f"and will be ignored: {extra_keys}")
68 | state_dict = {k: v for k, v in state_dict.items() if k not in extra_keys}
69 |
70 | ignored_keys = []
71 | if ignore_mismatched_keys:
72 | auto_ignore_layers = []
73 | for k, v in state_dict.items():
74 | if v.data.shape != model_state[k].data.shape:
75 | auto_ignore_layers.append(k)
76 | logger.info(f"Automatically found the checkpoint keys "
77 | f"incompatible with the model: {auto_ignore_layers}")
78 | ignored_keys.extend(auto_ignore_layers)
79 |
80 | if ignore_layers:
81 | for k, v in state_dict.items():
82 | if any(layer in k for layer in ignore_layers):
83 | ignored_keys.append(k)
84 |
85 | if ignored_keys:
86 |             state_dict = {k: v for k, v in state_dict.items()
87 |                           if k not in ignored_keys}
88 | logger.info(f"The following checkpoint keys were ignored: {ignored_keys}")
89 |
90 | model_state.update(state_dict)
91 | self.load_state_dict(model_state)
92 |
93 | return self
94 |
95 | def freeze(self, exception_list=None):
96 | not_frozen = []
97 | exception_list = exception_list or []
98 | for name, param in self.named_parameters():
99 | param.requires_grad = any((name.startswith(layer) for layer in exception_list))
100 | if param.requires_grad:
101 | not_frozen.append(name)
102 | logger.info(f"The model graph has been frozen, except for the following parameters: {not_frozen}")
103 |
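
A hypothetical fine-tuning sketch combining the helpers above (checkpoint paths and layer names are illustrative, not fixed by this class):

```python
import torch

from scoreperformer.models import ScorePerformer

model = ScorePerformer.from_pretrained("checkpoints/base.pt")

# adopt weights from a differently-shaped checkpoint, skipping the LM head
state_dict = torch.load("checkpoints/other.pt", map_location="cpu")["model"]["state_dict"]
model.load(state_dict, ignore_layers=["lm_head"], ignore_mismatched_keys=True)

# then train only the LM head
model.freeze(exception_list=["lm_head"])
```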
--------------------------------------------------------------------------------
/scoreperformer/models/classifiers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/scoreperformer/models/classifiers/__init__.py
--------------------------------------------------------------------------------
/scoreperformer/models/classifiers/evaluator.py:
--------------------------------------------------------------------------------
1 | """ Embedding Classifier evaluator. """
2 |
3 | import torch
4 |
5 |
6 | class EmbeddingClassifierEvaluator:
7 | def __init__(self, model):
8 | self.model = model
9 |
10 | def _accuracy(self, predictions, labels):
11 | return (predictions == labels).float().mean()
12 |
13 | @torch.no_grad()
14 | def __call__(self, inputs, outputs):
15 | labels = inputs["labels"]
16 | predictions = torch.argmax(outputs.logits, dim=-1)
17 | metrics = {"accuracy": self._accuracy(predictions, labels)}
18 |
19 | return metrics
20 |
--------------------------------------------------------------------------------
/scoreperformer/models/classifiers/model.py:
--------------------------------------------------------------------------------
1 | """ Embedding Classifier Models"""
2 |
3 | from dataclasses import dataclass, field
4 | from typing import Optional, Dict, Sequence, Union, List
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from omegaconf import MISSING
11 | from torch import Tensor
12 |
13 | from scoreperformer.models.base import Model
14 | from scoreperformer.modules.constructor import Registry, VariableModuleConfig
15 |
16 |
17 | @dataclass
18 | class EmbeddingClassifierOutput:
19 | logits: Tensor = None
20 | loss: Optional[Tensor] = None
21 | losses: Optional[Dict[str, Tensor]] = None
22 |
23 |
24 | EmbeddingClassifiersRegistry = type("_EmbeddingClassifiersRegistry", (Registry,), {})()
25 |
26 |
27 | @dataclass
28 | class EmbeddingClassifierConfig(VariableModuleConfig):
29 | input_dim: int = MISSING
30 | num_classes: int = MISSING
31 |     dropout: float = 0.
32 | weight: Optional[List[float]] = None
33 |
34 |
35 | @dataclass
36 | class LinearEmbeddingClassifierConfig(EmbeddingClassifierConfig):
37 | _target_: str = "linear"
38 | hidden_dims: Optional[Sequence[int]] = field(default_factory=lambda: (32,))
39 |
40 |
41 | @EmbeddingClassifiersRegistry.register("linear")
42 | class LinearEmbeddingClassifier(Model):
43 | def __init__(
44 | self,
45 | input_dim: int,
46 | num_classes: int,
47 | hidden_dims: Optional[Sequence[int]] = (32,),
48 |             dropout: float = 0.,
49 | class_weights: Optional[List[float]] = None
50 | ):
51 | super().__init__()
52 |
53 | self.num_classes = num_classes
54 |
55 | class_weights = torch.ones(num_classes) if class_weights is None else torch.tensor(class_weights)
56 | self.register_buffer("class_weights", class_weights.float())
57 |
58 | hidden_dims = hidden_dims or []
59 | hidden_dims = [hidden_dims] if isinstance(hidden_dims, int) else hidden_dims
60 | hidden_dims = list(hidden_dims)
61 |
62 | in_dims = [input_dim] + hidden_dims
63 | out_dims = hidden_dims + [num_classes]
64 |
65 | layers = []
66 | for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
67 | layers.append(nn.Linear(in_dim, out_dim))
68 | if i < len(in_dims) - 1:
69 | layers.append(nn.ReLU())
70 |
71 | self.layers = nn.Sequential(*layers)
72 | self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
73 |
74 | def forward(self, embeddings: Tensor, labels: Optional[Tensor] = None):
75 | x = embeddings.squeeze(1) if embeddings.ndim == 3 else embeddings
76 | for layer in self.layers:
77 | x = layer(self.dropout(x))
78 | logits = x
79 |
80 | loss = losses = None
81 | if labels is not None:
82 | loss = F.cross_entropy(logits, labels, weight=self.class_weights)
83 |
84 | return EmbeddingClassifierOutput(
85 | logits=logits,
86 | loss=loss,
87 | losses=losses
88 | )
89 |
90 | def prepare_inputs(self, inputs) -> Dict[str, Tensor]:
91 | return inputs
92 |
93 |
94 | @dataclass
95 | class SequentialEmbeddingClassifierConfig(EmbeddingClassifierConfig):
96 | _target_: str = "sequential"
97 | hidden_dim: int = 32
98 |
99 |
100 | @EmbeddingClassifiersRegistry.register("sequential")
101 | class SequentialEmbeddingClassifier(Model):
102 | def __init__(
103 | self,
104 | input_dim: int,
105 | num_classes: int,
106 | hidden_dim: int = 32,
107 |             dropout: float = 0.,
108 | class_weights: Optional[List[float]] = None
109 | ):
110 | super().__init__()
111 |
112 | self.num_classes = num_classes
113 |
114 | class_weights = torch.ones(num_classes) if class_weights is None else torch.tensor(class_weights)
115 | self.register_buffer("class_weights", class_weights.float())
116 |
117 | self.gru = nn.GRU(
118 | input_size=input_dim,
119 | hidden_size=hidden_dim,
120 | batch_first=True,
121 | dropout=dropout
122 | )
123 |
124 | self.output = nn.Linear(hidden_dim, num_classes)
125 |
126 | def forward(self, embeddings: Tensor, labels: Optional[Tensor] = None):
127 | self.gru.flatten_parameters()
128 | _, out = self.gru(embeddings) # (1, b, h)
129 | logits = self.output(out[0])
130 |
131 | loss = losses = None
132 | if labels is not None:
133 | loss = F.cross_entropy(logits, labels, weight=self.class_weights)
134 |
135 | return EmbeddingClassifierOutput(
136 | logits=logits,
137 | loss=loss,
138 | losses=losses
139 | )
140 |
141 | def prepare_inputs(self, inputs) -> Dict[str, Tensor]:
142 | return inputs
143 |
144 |
145 | @dataclass
146 | class MultiHeadEmbeddingClassifierOutput:
147 | logits: Dict[str, Tensor] = None
148 | loss: Optional[Tensor] = None
149 | losses: Optional[Dict[str, Tensor]] = None
150 |
151 |
152 | @dataclass
153 | class MultiHeadEmbeddingClassifierConfig(VariableModuleConfig):
154 | _target_: str = "multi-head"
155 | input_dim: int = MISSING
156 | num_classes: Dict[str, int] = MISSING
157 | classifier: LinearEmbeddingClassifierConfig = MISSING
158 | class_samples: Optional[Dict[str, List[int]]] = None
159 | weighted_classes: bool = False
160 | loss_weight: float = 1.
161 | detach_inputs: Union[bool, float] = False
162 |
163 |
164 | @EmbeddingClassifiersRegistry.register("multi-head")
165 | class MultiHeadEmbeddingClassifier(Model):
166 | def __init__(
167 | self,
168 | input_dim: int,
169 | num_classes: Dict[str, int],
170 | classifier: LinearEmbeddingClassifierConfig,
171 | class_samples: Optional[Dict[str, List[int]]] = None,
172 | loss_weight: float = 1.,
173 | weighted_classes: bool = False,
174 | detach_inputs: Union[bool, float] = False
175 | ):
176 | super().__init__()
177 |
178 | self.num_classes = num_classes
179 |
180 | self.heads = nn.ModuleDict({})
181 | for key, num in num_classes.items():
182 | num_samples = class_samples.get(key, None) if class_samples is not None else None
183 | class_weights = self._class_weights(num_samples) if weighted_classes and num_samples is not None else None
184 | self.heads[key] = LinearEmbeddingClassifier.init(
185 | config=classifier,
186 | input_dim=input_dim,
187 | num_classes=num,
188 | class_weights=class_weights
189 | )
190 |
191 | self.loss_weight = loss_weight
192 | self.detach_inputs = float(detach_inputs)
193 |
194 | @staticmethod
195 |     def _class_weights(num_samples: List[int], beta: float = 0.999, mult: float = 1e4):
196 | num_samples = np.maximum(num_samples, 1e-6)
197 | effective_num = 1.0 - np.power(beta, np.array(num_samples) * mult)
198 | weights = (1.0 - beta) / np.array(effective_num)
199 | weights = weights / np.sum(weights) * len(num_samples)
200 | return weights.tolist()
201 |
202 | def forward(self, embeddings: Tensor, labels: Optional[Tensor] = None):
203 | embeddings = self.detach_inputs * embeddings.detach() + (1 - self.detach_inputs) * embeddings
204 |
205 | logits = {}
206 | loss, losses = 0., {}
207 | for i, (key, head) in enumerate(self.heads.items()):
208 | out = head(embeddings, labels=labels[..., i] if labels is not None else None)
209 | logits[key] = out.logits
210 |
211 |             if out.loss is not None:
212 | key = 'clf/' + key
213 | loss += out.loss
214 | losses[key] = out.loss
215 |
216 | loss = self.loss_weight * loss / len(self.heads)
217 | losses['clf'] = loss
218 |
219 | return MultiHeadEmbeddingClassifierOutput(
220 | logits=logits,
221 | loss=loss if labels is not None else None,
222 | losses=losses if labels is not None else None
223 | )
224 |
225 | def prepare_inputs(self, inputs) -> Dict[str, Tensor]:
226 | return inputs
227 |
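
A toy sketch of the simplest head, a linear probe over 64-dimensional embeddings (dimensions and class count are arbitrary):

```python
import torch

clf = LinearEmbeddingClassifier(input_dim=64, num_classes=5, hidden_dims=(32,), dropout=0.1)
embeddings = torch.randn(8, 64)  # (batch, dim)
out = clf(embeddings, labels=torch.randint(0, 5, (8,)))
print(out.logits.shape, out.loss)  # torch.Size([8, 5]) tensor(...)
```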
--------------------------------------------------------------------------------
/scoreperformer/models/scoreperformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .embeddings import TupleTokenEmbeddings, TupleTokenLMHead, TupleTokenTiedLMHead
2 | from .evaluator import ScorePerformerEvaluator
3 | from .model import (
4 | PerformerConfig,
5 | Performer,
6 | ScorePerformerConfig,
7 | ScorePerformer
8 | )
9 | from .transformer import (
10 | TupleTransformerConfig,
11 | TupleTransformer,
12 | TupleTransformerCaches
13 | )
14 | from .wrappers import ScorePerformerMLMWrapper
15 |
--------------------------------------------------------------------------------
/scoreperformer/models/scoreperformer/evaluator.py:
--------------------------------------------------------------------------------
1 | """ ScorePerformer metric evaluator. """
2 |
3 | from typing import Optional, Union, List
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | from scoreperformer.data.collators import LMScorePerformanceInputs
9 | from scoreperformer.data.tokenizers import OctupleM
10 | from .model import ScorePerformerOutputs
11 | from .transformer import TupleTransformerOutput
12 | from .wrappers import ScorePerformerLMModes
13 |
14 |
15 | class ScorePerformerEvaluator:
16 | def __init__(
17 | self,
18 | model,
19 | tokenizer: Optional[OctupleM] = None,
20 | label_pad_token_id: int = -100,
21 | weighted_distance: bool = False,
22 | ignore_keys: Optional[List[str]] = None
23 | ):
24 | self.model = model
25 | self.tokenizer = tokenizer
26 | self.label_pad_token_id = label_pad_token_id
27 | self.weighted_distance = weighted_distance
28 | self.ignore_keys = ignore_keys
29 |
30 | self.token_values = None
31 | if self.tokenizer is not None:
32 | self.token_values = {
33 | key: torch.from_numpy(values)[:, None]
34 | for key, values in self.tokenizer.token_values(normalize=False).items()
35 | }
36 |
37 | def _accuracy(self, predictions, labels):
38 | label_mask = labels != self.label_pad_token_id
39 | return (predictions[label_mask] == labels[label_mask]).float().mean()
40 |
41 | def _distance(self, predictions, targets):
42 | return (predictions - targets).abs().float().mean()
43 |
44 | def _weighted_distance(self, probs, targets, token_values):
45 | return ((targets[:, None] - token_values[None, :]).abs() * probs[..., None]).sum(dim=1).mean()
46 |
47 | @torch.no_grad()
48 | def __call__(
49 | self,
50 | inputs: Union[dict, LMScorePerformanceInputs],
51 | outputs: Union[TupleTransformerOutput, ScorePerformerOutputs],
52 | ignore_keys: Optional[List[str]] = None
53 | ):
54 | metrics = {}
55 | ignore_keys = ignore_keys or self.ignore_keys
56 |
57 | if isinstance(inputs, LMScorePerformanceInputs):
58 | labels = inputs.labels.tokens.to(outputs.hidden_state.device)
59 | else:
60 | labels = inputs["labels"]
61 |
62 | if self.model.mode in (ScorePerformerLMModes.CLM, ScorePerformerLMModes.MixedLM):
63 | labels = labels[:, 1:]
64 |
65 | if isinstance(outputs, ScorePerformerOutputs):
66 | outputs = outputs.perf_decoder
67 |
68 | predictions = torch.cat(
69 | list(map(lambda l: torch.argmax(l, dim=-1, keepdim=True), outputs.logits.values())),
70 | dim=-1
71 | )
72 |
73 | metrics[f"accuracy"] = self._accuracy(predictions, labels)
74 | if ignore_keys:
75 | use_ids = torch.tensor(
76 | [i for i, key in enumerate(outputs.logits.keys()) if key not in ignore_keys],
77 | device=predictions.device, dtype=torch.long
78 | )
79 | metrics[f"accuracy/pred"] = self._accuracy(predictions[..., use_ids], labels[..., use_ids])
80 |
81 | for i, (key, logits) in enumerate(outputs.logits.items()):
82 | if ignore_keys and key in ignore_keys:
83 | continue
84 |
85 | if torch.any(labels[..., i] != self.label_pad_token_id):
86 | metrics[f"accuracy/{key}"] = self._accuracy(predictions[..., i], labels[..., i])
87 |
88 | if self.token_values is not None:
89 | for i, (key, logits) in enumerate(outputs.logits.items()):
90 | if ignore_keys and key in ignore_keys:
91 | continue
92 |
93 | self.token_values[key] = self.token_values[key].to(predictions.device)
94 |
95 | label_mask = labels[..., i] != self.label_pad_token_id
96 | if torch.any(label_mask):
97 | preds = F.embedding(predictions[..., i][label_mask], self.token_values[key])
98 | targets = F.embedding(labels[..., i][label_mask], self.token_values[key])
99 |
100 | if self.weighted_distance:
101 | probs = outputs.logits[key].softmax(dim=-1)[label_mask]
102 | metrics[f"distance/{key}"] = self._weighted_distance(probs, targets, self.token_values[key])
103 | else:
104 | metrics[f"distance/{key}"] = self._distance(preds, targets)
105 |
106 | return metrics
107 |
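
`_weighted_distance` is the expected absolute error under the predicted distribution, `E_{v~p}[|t - v|]`, rather than the error of the argmax prediction. A toy check with three token values:

```python
import torch

token_values = torch.tensor([[0.0], [1.0], [2.0]])  # (V, 1), as stored in __init__
targets = torch.tensor([[1.0]])                     # (n, 1)
probs = torch.tensor([[0.25, 0.5, 0.25]])           # (n, V)

dist = ((targets[:, None] - token_values[None, :]).abs() * probs[..., None]).sum(dim=1).mean()
print(dist)  # tensor(0.5000) = 0.25*|1-0| + 0.5*|1-1| + 0.25*|1-2|
```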
--------------------------------------------------------------------------------
/scoreperformer/models/scoreperformer/transformer.py:
--------------------------------------------------------------------------------
1 | """ TupleTransformer: Transformer with support for tuple token sequences. """
2 |
3 | from dataclasses import dataclass, MISSING, field
4 | from typing import Optional, Union, Dict, List
5 |
6 | import torch
7 | import torch.nn as nn
8 | from omegaconf import DictConfig
9 | from torch import Tensor
10 |
11 | from scoreperformer.modules.constructor import Constructor, ModuleConfig
12 | from scoreperformer.modules.transformer import (
13 | TransformerConfig, TransformerRegistry, TransformerIntermediates,
14 | AbsolutePositionalEmbedding
15 | )
16 | from scoreperformer.utils import ExplicitEnum
17 | from .embeddings import (
18 | TupleTokenEmbeddingsConfig, TupleTokenEmbeddingsRegistry,
19 | TupleTokenHeadsConfig, TupleTokenHeadsRegistry,
20 | TupleTokenRegressionHeadConfig, TupleTokenRegressionHead
21 | )
22 |
23 |
24 | class EmbeddingModes(ExplicitEnum):
25 | SUM = "mean"
26 | CONCAT = "cat"
27 | ATTENTION = "attention"
28 | ADANORM = "adanorm"
29 |
30 |
31 | @dataclass
32 | class TupleTransformerCaches:
33 | token_emb: Optional[Tensor] = None
34 | transformer: Optional[TransformerIntermediates] = None
35 |
36 |
37 | @dataclass
38 | class TupleTransformerOutput:
39 | hidden_state: Tensor
40 | logits: Optional[Dict[str, Tensor]] = None
41 | attentions: Optional[List[Tensor]] = None
42 | caches: Optional[TupleTransformerCaches] = None
43 | reg_values: Optional[Dict[str, Tensor]] = None
44 |
45 |
46 | @dataclass
47 | class TupleTransformerConfig(ModuleConfig):
48 | num_tokens: Dict[str, int] = MISSING
49 | dim: int = 512
50 | max_seq_len: int = 1024
51 | transformer: Union[DictConfig, TransformerConfig] = field(
52 | default_factory=lambda: TransformerConfig(_target_="default"))
53 |
54 | token_embeddings: Union[DictConfig, TupleTokenEmbeddingsConfig] = field(
55 | default_factory=TupleTokenEmbeddingsConfig)
56 | use_abs_pos_emb: bool = True
57 | emb_norm: bool = False
58 | emb_dropout: float = 0.0
59 |
60 | context_emb_dim: Optional[int] = None
61 | context_emb_mode: str = EmbeddingModes.ATTENTION
62 | style_emb_dim: Optional[int] = None
63 | style_emb_mode: str = EmbeddingModes.CONCAT
64 |
65 | lm_head: Optional[Union[DictConfig, TupleTokenHeadsConfig]] = None
66 | regression_head: Optional[Union[DictConfig, TupleTokenRegressionHeadConfig]] = None
67 |
68 |
69 | class TupleTransformer(nn.Module, Constructor):
70 | def __init__(
71 | self,
72 | num_tokens: Dict[str, int],
73 | dim: int = 512,
74 | max_seq_len: int = 1024,
75 | transformer: Union[DictConfig, TransformerConfig] = TransformerConfig(_target_="default"),
76 | token_embeddings: Union[DictConfig, TupleTokenEmbeddingsConfig] = TupleTokenEmbeddingsConfig(),
77 | use_abs_pos_emb: bool = True,
78 | emb_norm: bool = False,
79 | emb_dropout: float = 0.0,
80 | context_emb_dim: Optional[int] = None,
81 | context_emb_mode: str = EmbeddingModes.ATTENTION,
82 | style_emb_dim: Optional[int] = None,
83 | style_emb_mode: str = EmbeddingModes.CONCAT,
84 | lm_head: Optional[Union[DictConfig, TupleTokenHeadsConfig]] = None,
85 | regression_head: Optional[Union[DictConfig, TupleTokenRegressionHeadConfig]] = None
86 | ):
87 | super().__init__()
88 |
89 | self.dim = dim
90 | self.max_seq_len = max_seq_len
91 | emb_dim = dim # default(emb_dim, dim)
92 |
93 | self.context_emb_dim = context_emb_dim or 0
94 | self.context_emb_mode = context_emb_mode
95 |
96 | self.style_emb_dim = style_emb_dim or 0
97 | self.style_emb_mode = style_emb_mode
98 |
99 | self.token_emb = TupleTokenEmbeddingsRegistry.instantiate(
100 | config=token_embeddings,
101 | num_tokens=num_tokens,
102 | emb_dims=token_embeddings.get("emb_dims", emb_dim),
103 | project_emb_dim=emb_dim
104 | )
105 |
106 | if self.context_emb_mode != EmbeddingModes.ATTENTION:
107 | transformer.cross_attend = False
108 |
109 | self.transformer = TransformerRegistry.instantiate(
110 | transformer,
111 | dim=dim,
112 | use_adanorm=self.style_emb_mode == EmbeddingModes.ADANORM,
113 | style_emb_dim=self.style_emb_dim
114 | )
115 |
116 | self.pos_emb = None
117 | if use_abs_pos_emb:
118 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, self.max_seq_len)
119 | nn.init.kaiming_normal_(self.pos_emb.emb.weight)
120 |
121 | self.emb_norm = nn.LayerNorm(emb_dim) if emb_norm else nn.Identity()
122 | self.emb_dropout = nn.Dropout(emb_dropout) if emb_dropout > 0. else nn.Identity()
123 |
124 | self.project_emb = nn.Identity()
125 | total_emb_dim = (
126 | emb_dim
127 | + int(context_emb_mode == EmbeddingModes.CONCAT) * self.context_emb_dim
128 | + int(style_emb_mode == EmbeddingModes.CONCAT) * self.style_emb_dim
129 | )
130 | if total_emb_dim != dim:
131 | self.project_emb = nn.Linear(total_emb_dim, dim)
132 |
133 | self.lm_head = None
134 | if lm_head is not None:
135 | self.lm_head = TupleTokenHeadsRegistry.instantiate(
136 | config=lm_head, dim=dim, embeddings=self.token_emb
137 | )
138 |
139 | self.regression_head = None
140 | if regression_head is not None:
141 | assert self.token_emb.continuous, "TupleTokenRegressionHead depends on `continuous` token embeddings."
142 | self.regression_head = TupleTokenRegressionHead.init(
143 | config=regression_head, dim=dim
144 | )
145 |
146 | def forward(
147 | self,
148 | x: Tensor,
149 | mask: Optional[Tensor] = None,
150 | x_extra: Optional[Union[Tensor, List[Tensor]]] = None,
151 | style_embeddings: Optional[Tensor] = None,
152 | context: Optional[Tensor] = None,
153 | context_mask: Optional[Tensor] = None,
154 | caches: Optional[TupleTransformerCaches] = None,
155 | logits_keys: Optional[List] = None,
156 | return_embeddings: bool = False,
157 | return_attn: bool = False,
158 | return_caches: bool = False,
159 | **kwargs
160 | ):
161 | token_emb_cache = caches.token_emb if caches is not None else None
162 | if hasattr(self.token_emb, "multiseq_mode") and x_extra is not None:
163 | x_extra = [x_extra] if isinstance(x_extra, Tensor) else x_extra
164 | token_emb = self.token_emb([x] + x_extra, cache=token_emb_cache)
165 | else:
166 | token_emb = self.token_emb(x, cache=token_emb_cache)
167 |
168 | x = token_emb
169 | if self.pos_emb is not None:
170 | x = x + self.pos_emb(x)
171 | x = self.emb_norm(x)
172 |
173 | if context is not None and self.context_emb_mode == EmbeddingModes.CONCAT:
174 | context = context[:, :x.shape[1]]
175 | x = torch.cat([x, context], dim=-1)
176 | context = None
177 |
178 | if style_embeddings is not None:
179 | style_embeddings = style_embeddings[:, :x.shape[1]]
180 | if self.style_emb_mode == EmbeddingModes.CONCAT:
181 | x = torch.cat([x, style_embeddings], dim=-1)
182 | style_embeddings = None
183 |
184 | x = self.emb_dropout(x)
185 | x = self.project_emb(x)
186 |
187 | out, intermediates = self.transformer(
188 | x,
189 | mask=mask,
190 | context=context,
191 | context_mask=context_mask,
192 | style_embeddings=style_embeddings,
193 | intermediates_cache=caches.transformer if caches is not None else None,
194 | return_hiddens=True
195 | )
196 |
197 | logits = None
198 | if not return_embeddings and self.lm_head is not None:
199 | logits = self.lm_head(out, keys=logits_keys)
200 |
201 | reg_values = None
202 | if not return_embeddings and self.regression_head is not None:
203 | reg_values = self.regression_head(out, keys=logits_keys)
204 |
205 | attn_maps = None
206 | if return_attn:
207 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attention))
208 |
209 | caches = None
210 | if return_caches:
211 | caches = TupleTransformerCaches(
212 | token_emb=token_emb,
213 | transformer=intermediates
214 | )
215 |
216 | return TupleTransformerOutput(
217 | hidden_state=out,
218 | logits=logits,
219 | attentions=attn_maps,
220 | caches=caches,
221 | reg_values=reg_values
222 | )
223 |
--------------------------------------------------------------------------------
/scoreperformer/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ilya16/ScorePerformer/c04f9a3155bb6afb6ca7bcc1cac2325184cc1262/scoreperformer/modules/__init__.py
--------------------------------------------------------------------------------
/scoreperformer/modules/constructor.py:
--------------------------------------------------------------------------------
1 | """ Config-based Module Constructor. """
2 | import copy
3 | from dataclasses import dataclass
4 | from inspect import signature
5 | from typing import Optional, Union, Callable
6 |
7 | import torch
8 | from loguru import logger
9 | from omegaconf import DictConfig, OmegaConf, MISSING
10 |
11 |
12 | @dataclass
13 | class ModuleConfig:
14 | def update(self, **kwargs):
15 | kwargs = {key: kwargs.get(key, MISSING) for key in kwargs if not key.startswith("_")} # get rid of service keys
16 | invalid_keys = [key for key in kwargs if key not in self.__dict__]
17 | if invalid_keys:
18 | logger.warning(f"The following params are incompatible with the config {self.__name__}, "
19 | f"so they will be ignored: {invalid_keys}.")
20 | kwargs = {key: value for key, value in kwargs.items() if key not in invalid_keys}
21 |
22 | for key, value in kwargs.items():
23 | self.__setattr__(key, value)
24 |
25 | return self
26 |
27 | def to_dict(self, check_missing=False, make_copy=True):
28 | if check_missing:
29 | missing = [key for key, value in self.__dict__.items() if value == MISSING]
30 | if missing:
31 | raise RuntimeError(f"The following params are mandatory to set: {missing}")
32 |
33 | return copy.deepcopy(self.__dict__) if make_copy else {k: v for k, v in self.__dict__.items()}
34 |
35 |
36 | class Constructor:
37 | @classmethod
38 | def _pre_init(cls, config: Optional[Union[DictConfig, ModuleConfig]] = None, **parameters):
39 | module_parameters = {key: value for key, value in parameters.items() if isinstance(value, torch.nn.Module)}
40 | parameters = {key: value for key, value in parameters.items() if key not in module_parameters}
41 | config = merge(config or {}, parameters)
42 | for key, value in module_parameters.items():
43 | config[key] = value
44 | config = {key: value for key, value in config.items() if not key.startswith("_")}
45 |
46 | return config
47 |
48 | @classmethod
49 | def init(cls, config: Optional[Union[DictConfig, ModuleConfig]] = None, **parameters):
50 | config = cls._pre_init(config, **parameters)
51 |
52 | signature_ = dict(signature(cls.__init__).parameters)
53 |
54 | if "kwargs" not in signature_:
55 | invalid_keys = [key for key in config if key not in signature_]
56 | if invalid_keys:
57 | logger.warning(f"The following params are incompatible with the {cls.__name__} constructor, "
58 | f"so they will be ignored: {invalid_keys}.")
59 | config = {key: value for key, value in config.items() if key not in invalid_keys}
60 |
61 | missing = [key for key, value in config.items() if value == MISSING]
62 | if missing:
63 | raise RuntimeError(f"The following params are mandatory to set: {missing}")
64 |
65 | return cls(**config)
66 |
67 |
68 | def merge(*containers: Union[DictConfig, ModuleConfig, dict], as_omega=False) -> Union[dict, DictConfig]:
69 | readonly = False
70 | _containers = []
71 | for cont in containers:
72 | if isinstance(cont, ModuleConfig):
73 | cont = cont.to_dict(make_copy=False)
74 | elif isinstance(cont, DictConfig):
75 | readonly = cont._get_flag("readonly")
76 | elif not isinstance(cont, dict):
77 | raise TypeError
78 | _containers.append(cont)
79 |
80 | merged = OmegaConf.merge(*_containers)
81 | OmegaConf.set_readonly(merged, readonly)
82 |
83 | if not as_omega:
84 | merged = dict(merged)
85 |
86 | return merged
87 |
88 |
89 | @dataclass
90 | class VariableModuleConfig(ModuleConfig):
91 | _target_: str
92 |
93 |
94 | class Registry:
95 | """
96 | Parent class for different registries.
97 | """
98 |
99 | def __init__(self):
100 | self._objects = {}
101 |
102 | def register(
103 | self,
104 | name: str,
105 | module: Optional[Callable] = None
106 | ):
107 | if not isinstance(name, str):
108 | raise TypeError(f"`name` must be a str, got {name}")
109 |
110 | def _register(reg_obj: Callable):
111 | self._objects[name] = reg_obj
112 | return reg_obj
113 |
114 | return _register if module is None else _register(module)
115 |
116 | def instantiate(self, config: Union[VariableModuleConfig, DictConfig], **kwargs):
117 | module: Constructor = self.get(config._target_)
118 | return module.init(config, **kwargs)
119 |
120 | def get(self, key: str):
121 | try:
122 | return self._objects[key]
123 | except KeyError:
124 | raise KeyError(f"'{key}' not found in registry. Available names: {self.available_names}")
125 |
126 | def remove(self, name):
127 | self._objects.pop(name)
128 |
129 | @property
130 | def objects(self):
131 | return self._objects
132 |
133 | @property
134 | def available_names(self):
135 | return tuple(self._objects.keys())
136 |
137 | def __str__(self) -> str:
138 | return f"Objects={self.available_names}"
139 |
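
A self-contained toy showing the `Constructor`/`Registry` contract (all names here are illustrative):

```python
from dataclasses import dataclass

import torch

ToyRegistry = type("_ToyRegistry", (Registry,), {})()


@dataclass
class ToyConfig(VariableModuleConfig):
    _target_: str = "toy"
    dim: int = 8


@ToyRegistry.register("toy")
class Toy(torch.nn.Linear, Constructor):
    def __init__(self, dim: int):
        super().__init__(dim, dim)


toy = ToyRegistry.instantiate(ToyConfig(dim=16))  # `_target_` picks the class,
print(toy)                                        # remaining fields feed __init__
```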
--------------------------------------------------------------------------------
/scoreperformer/modules/layers.py:
--------------------------------------------------------------------------------
1 | """ PyTorch modules used by the models."""
2 |
3 | from typing import Optional
4 |
5 | import torch
6 | from torch import nn as nn, Tensor
7 |
8 | from scoreperformer.utils import exists
9 |
10 |
11 | # residual
12 |
13 | class Residual(nn.Module):
14 | def __init__(self, dim: int, scale_residual: bool = False, scale_residual_constant: float = 1.):
15 | super().__init__()
16 | self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
17 | self.scale_residual_constant = scale_residual_constant
18 |
19 | def forward(self, x, residual):
20 | if exists(self.residual_scale):
21 | residual = residual * self.residual_scale
22 |
23 | if self.scale_residual_constant != 1:
24 | residual = residual * self.scale_residual_constant
25 |
26 | return x + residual
27 |
28 |
29 | # adaptive layer normalization
30 |
31 | class AdaptiveLayerNorm(nn.Module):
32 | def __init__(self, dim: int, condition_dim: int, eps: float = 1e-5):
33 | super().__init__()
34 | self.dim = dim
35 | self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
36 |
37 | self.linear = nn.Linear(condition_dim, dim * 2)
38 | self.linear.bias.data[:dim] = 1
39 | self.linear.bias.data[dim:] = 0
40 |
41 | def forward(self, x: Tensor, condition: Optional[Tensor] = None):
42 | if condition is None:
43 | gamma, beta = x.new_ones(1), x.new_zeros(1)
44 | else:
45 | condition = condition.unsqueeze(1) if condition.ndim == 2 else condition
46 | gamma, beta = self.linear(condition).chunk(2, dim=-1)
47 | return gamma * self.norm(x) + beta
48 |
--------------------------------------------------------------------------------
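
A shape sketch for `AdaptiveLayerNorm`: the conditioning projection's bias is initialized so the scale (gamma) starts centered at 1 and the shift (beta) at 0, and a 2-dim condition is broadcast over the sequence axis:

```python
import torch

from scoreperformer.modules.layers import AdaptiveLayerNorm

ada_norm = AdaptiveLayerNorm(dim=512, condition_dim=64)

x = torch.randn(2, 16, 512)        # (batch, seq, dim) token features
cond = torch.randn(2, 64)          # one style vector per sequence
out = ada_norm(x, condition=cond)  # cond is unsqueezed to (2, 1, 64) and broadcast
assert out.shape == x.shape

out_plain = ada_norm(x)  # condition=None: plain non-affine LayerNorm (gamma=1, beta=0)
```
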
/scoreperformer/modules/sampling.py:
--------------------------------------------------------------------------------
1 | """ Sampling functions. """
2 |
3 | import math
4 | from typing import Callable, Optional, Dict
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 |
10 | from scoreperformer.utils import default
11 |
12 |
13 | # nucleus
14 |
15 | def top_p(logits: Tensor, thres: float = 0.9):
16 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
17 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
18 |
19 | sorted_indices_to_remove = cum_probs > thres
20 | sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value=False)
21 |
22 | sorted_logits[sorted_indices_to_remove] = float('-inf')
23 | return sorted_logits.scatter(1, sorted_indices, sorted_logits)
24 |
25 |
26 | # topk
27 |
28 | def top_k(logits: Tensor, thres: float = 0.9, k: Optional[int] = None):
29 | k = default(k, math.ceil((1 - thres) * logits.shape[-1]))
30 | val, ind = torch.topk(logits, k)
31 | probs = torch.full_like(logits, float('-inf'))
32 | probs.scatter_(1, ind, val)
33 | return probs
34 |
35 |
36 | # top_a
37 |
38 | def top_a(logits: Tensor, min_p_pow: float = 2.0, min_p_ratio: float = 0.02):
39 | probs = F.softmax(logits, dim=-1)
40 | limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
41 | return torch.where(probs < limit, float('-inf'), logits)
42 |
43 |
44 | # sampling
45 |
46 | def filter_logits_and_sample(
47 | logits: Tensor,
48 | filter_logits_fn: Callable,
49 | filter_kwargs: Optional[Dict[str, object]] = None,
50 | temperature: float = 1.,
51 | sample: bool = True
52 | ):
53 | filter_kwargs = filter_kwargs or {}
54 | filtered_logits = filter_logits_fn(logits, **filter_kwargs)
55 |
56 | probs = F.softmax(filtered_logits / temperature, dim=-1)
57 | if not sample:
58 | return probs
59 | return torch.multinomial(probs, 1)
60 |
--------------------------------------------------------------------------------
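
A usage sketch for the sampling helpers; each filter expects a 2-dim (batch, vocab) logits tensor:

```python
import torch

from scoreperformer.modules.sampling import top_k, top_p, filter_logits_and_sample

logits = torch.randn(4, 100)  # (batch, vocab)

# keep the 10 largest logits per row, then sample one token id per row
tokens = filter_logits_and_sample(
    logits,
    filter_logits_fn=top_k,
    filter_kwargs={"k": 10},
    temperature=0.9,
)  # LongTensor of shape (4, 1)

# with sample=False, the filtered softmax distribution is returned instead
probs = filter_logits_and_sample(logits, top_p, filter_kwargs={"thres": 0.9}, sample=False)
```
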
/scoreperformer/modules/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .attend import (
2 | AttentionIntermediates,
3 | Attend
4 | )
5 | from .attention import (
6 | AttentionSharedIntermediates,
7 | AttentionConfig, Attention
8 | )
9 | from .embeddings import (
10 | DiscreteContinuousEmbedding,
11 | DiscreteDenseContinuousEmbedding,
12 | AbsolutePositionalEmbedding,
13 | FixedPositionalEmbedding,
14 | ALiBiPositionalBias,
15 | LearnedALiBiPositionalBias
16 | )
17 | from .feedforward import (
18 | FeedForwardConfig,
19 | FeedForward
20 | )
21 | from .transformer import (
22 | TransformerIntermediates,
23 | TransformerRegistry,
24 | TransformerConfig, Transformer,
25 | EncoderConfig, Encoder,
26 | DecoderConfig, Decoder
27 | )
28 |
29 | DEFAULT_DIM_HEAD = 64
30 |
--------------------------------------------------------------------------------
/scoreperformer/modules/transformer/attend.py:
--------------------------------------------------------------------------------
1 | """
2 | Attention with memory-efficient attention support.
3 |
4 | Adapted from: https://github.com/lucidrains/x-transformers
5 | """
6 |
7 | from collections import namedtuple
8 | from dataclasses import dataclass
9 | from functools import partial
10 | from typing import Optional
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from einops import rearrange
15 | from packaging import version
16 | from torch import nn, einsum, Tensor
17 |
18 | from scoreperformer.utils import exists, default
19 |
20 | # constants
21 |
22 | EfficientAttentionConfig = namedtuple('EfficientAttentionConfig',
23 | ['enable_flash', 'enable_math', 'enable_mem_efficient'])
24 |
25 |
26 | @dataclass
27 | class AttentionIntermediates:
28 | keys: Optional[Tensor] = None
29 | values: Optional[Tensor] = None
30 | qk_similarities: Optional[Tensor] = None
31 |
32 | def to_tuple(self):
33 | return self.keys, self.values, self.qk_similarities
34 |
35 |
36 | # main class
37 |
38 | class Attend(nn.Module):
39 | def __init__(
40 | self,
41 | *,
42 | dropout: float = 0.,
43 | causal: bool = False,
44 | scale: Optional[float] = None,
45 | ):
46 | super().__init__()
47 | self.scale = scale
48 | self.causal = causal
49 | self.attn_fn = partial(F.softmax, dtype=torch.float32)
50 |
51 | self.dropout = dropout
52 | self.attn_dropout = nn.Dropout(dropout)
53 |
54 | # efficient attention
55 | self.efficient = version.parse(torch.__version__) >= version.parse('2.0.0')
56 | self.config = EfficientAttentionConfig(enable_flash=False, enable_math=True, enable_mem_efficient=True)
57 |
58 | def efficient_attn(
59 | self,
60 | q, k, v,
61 | mask=None,
62 | attn_bias=None
63 | ):
64 | batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
65 |
66 | # Recommended for multi-query single-key-value attention by Tri Dao
67 | # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
68 |
69 | intermediates = AttentionIntermediates(keys=k, values=v)
70 |
71 | if k.ndim == 3:
72 | k = rearrange(k, 'b ... -> b 1 ...').expand(-1, q.shape[1], -1, -1)
73 |
74 | if v.ndim == 3:
75 | v = rearrange(v, 'b ... -> b 1 ...').expand(-1, q.shape[1], -1, -1)
76 |
77 | # Check if mask exists and expand to compatible shape
78 | # The mask is B L, so it would have to be expanded to B H N L
79 |
80 | causal = self.causal
81 |
82 | if exists(mask):
83 | assert mask.ndim == 4
84 | mask = mask.expand(batch, heads, q_len, k_len)
85 |
86 | # manually handle causal mask, if another mask was given
87 |
88 | if causal:
89 | causal_mask = torch.ones((q_len, k_len), dtype=torch.bool, device=device).triu(k_len - q_len + 1)
90 | mask = mask & ~causal_mask
91 | causal = False
92 |
93 | # handle alibi positional bias
94 | # convert from bool to float
95 |
96 | if exists(attn_bias):
97 | attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, -1, -1, -1)
98 |
99 | # if mask given, the mask would already contain the causal mask from above logic
100 | # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
101 |
102 | mask_value = -torch.finfo(q.dtype).max
103 |
104 | if exists(mask):
105 | attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
106 | elif causal:
107 | causal_mask = torch.ones((q_len, k_len), dtype=torch.bool, device=device).triu(k_len - q_len + 1)
108 | attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
109 | causal = False
110 |
111 | # scaled_dot_product_attention handles attn_mask either as bool or additive bias
112 | # make it an additive bias here
113 |
114 | mask = attn_bias
115 |
116 | # pytorch 2.0 attention: q, k, v, mask, dropout, causal, softmax_scale
117 |
118 | with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
119 | out = F.scaled_dot_product_attention(
120 | q, k, v,
121 | attn_mask=mask,
122 | dropout_p=self.dropout if self.training else 0.,
123 | is_causal=causal and q_len != 1
124 | )
125 |
126 | return out, intermediates
127 |
128 | def forward(
129 | self,
130 | q, k, v,
131 | mask=None,
132 | attn_bias=None,
133 | prev_attn=None
134 | ):
135 | """
136 | einstein notation
137 | b - batch
138 | h - heads
139 | n, i, j - sequence length (base sequence length, source, target)
140 | d - feature dimension
141 | """
142 |
143 | if self.efficient:
144 | assert not exists(prev_attn), 'residual attention not compatible with efficient attention'
145 | return self.efficient_attn(q, k, v, mask=mask, attn_bias=attn_bias)
146 |
147 | n, device = q.shape[-2], q.device
148 | scale = default(self.scale, q.shape[-1] ** -0.5)
149 |
150 | kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
151 |
152 | dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
153 |
154 | if exists(prev_attn):
155 | dots = dots + prev_attn
156 |
157 | qk_similarities = dots.clone()
158 |
159 | if exists(attn_bias):
160 | dots = dots + attn_bias
161 |
162 | dtype = dots.dtype
163 | mask_value = -torch.finfo(dots.dtype).max
164 |
165 | if exists(mask):
166 | dots = dots.masked_fill(~mask, mask_value)
167 |
168 | if self.causal:
169 | i, j = dots.shape[-2:]
170 | causal_mask = torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
171 | dots = dots.masked_fill(causal_mask, mask_value)
172 |
173 | attn = self.attn_fn(dots, dim=-1)
174 | attn = attn.type(dtype)
175 |
176 | attn = self.attn_dropout(attn)
177 |
178 | out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
179 |
180 | intermediates = AttentionIntermediates(
181 | keys=k,
182 | values=v,
183 | qk_similarities=qk_similarities
184 | )
185 |
186 | return out, intermediates
187 |
--------------------------------------------------------------------------------
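
A minimal sketch of `Attend` on explicit per-head tensors. On PyTorch >= 2.0 the call is routed through `F.scaled_dot_product_attention`; otherwise the manual einsum path (which also records `qk_similarities`) is used:

```python
import torch

from scoreperformer.modules.transformer import Attend

attend = Attend(causal=True, dropout=0.1)

q = torch.randn(2, 8, 16, 64)  # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

out, intermediates = attend(q, k, v)  # out: (2, 8, 16, 64)
keys, values = intermediates.keys, intermediates.values  # reusable as a decoding cache
```
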
/scoreperformer/modules/transformer/attention.py:
--------------------------------------------------------------------------------
1 | """
2 | Transformer Attention with data caching support for inference.
3 |
4 | Adapted from: https://github.com/lucidrains/x-transformers
5 | """
6 | from dataclasses import dataclass
7 | from typing import Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from einops import rearrange, repeat
13 | from torch import Tensor
14 |
15 | from scoreperformer.utils import default, or_reduce
16 | from .attend import AttentionIntermediates, Attend
17 | from .embeddings import ALiBiPositionalBias, LearnedALiBiPositionalBias
18 | from ..constructor import Constructor, ModuleConfig
19 |
20 |
21 | @dataclass
22 | class AttentionSharedIntermediates:
23 | rel_pos_bias: Optional[Tensor] = None
24 |
25 |
26 | @dataclass
27 | class AttentionConfig(ModuleConfig):
28 | dim: int = 512
29 | dim_head: int = 64
30 | heads: int = 8
31 | causal: bool = False
32 | dropout: float = 0.
33 | one_kv_head: bool = False
34 | num_mem_kv: int = 0
35 | shared_kv: bool = False
36 | value_dim_head: Optional[int] = None
37 | max_attend_past: Optional[int] = None
38 | alibi_pos_bias: bool = False
39 | alibi_num_heads: Optional[int] = None
40 | alibi_symmetric: bool = True
41 | alibi_learned: bool = False
42 |
43 |
44 | class Attention(nn.Module, Constructor):
45 | def __init__(
46 | self,
47 | dim: int,
48 | dim_head: int = 64,
49 | heads: int = 8,
50 | causal: bool = False,
51 | dropout: float = 0.,
52 | one_kv_head: bool = False,
53 | num_mem_kv: int = 0,
54 | max_attend: Optional[int] = None,
55 | alibi_pos_bias: bool = False,
56 | alibi_num_heads: Optional[int] = None,
57 | alibi_symmetric: bool = True,
58 | alibi_learned: bool = False,
59 | ):
60 | super().__init__()
61 | self.scale = dim_head ** -0.5
62 |
63 | self.heads = heads
64 | self.causal = causal
65 | self.max_attend = max_attend
66 |
67 | self.one_kv_head = one_kv_head
68 | out_dim = q_dim = dim_head * heads
69 | kv_dim = dim_head if one_kv_head else dim_head * heads
70 |
71 | self.to_q = nn.Linear(dim, q_dim, bias=False)
72 | self.to_k = nn.Linear(dim, kv_dim, bias=False)
73 | self.to_v = nn.Linear(dim, kv_dim, bias=False)
74 |
75 | # relative positional bias
76 |
77 | self.rel_pos = None
78 | if alibi_pos_bias:
79 | alibi_num_heads = default(alibi_num_heads, heads)
80 | assert alibi_num_heads <= heads, 'number of ALiBi heads must not exceed the total number of heads'
81 | alibi_pos_klass = LearnedALiBiPositionalBias if alibi_learned else ALiBiPositionalBias
82 | self.rel_pos = alibi_pos_klass(
83 | heads=alibi_num_heads,
84 | total_heads=heads,
85 | symmetric=alibi_symmetric or causal
86 | )
87 |
88 | # attend class - includes core attention algorithm + talking heads
89 |
90 | self.attend = Attend(
91 | causal=causal,
92 | dropout=dropout,
93 | scale=self.scale
94 | )
95 |
96 | # add memory key / values
97 |
98 | self.num_mem_kv = num_mem_kv
99 | if num_mem_kv > 0:
100 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
101 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
102 |
103 | # output layer
104 |
105 | self.to_out = nn.Linear(out_dim, dim, bias=False)
106 |
107 | def forward(
108 | self,
109 | x: Tensor,
110 | context: Optional[Tensor] = None,
111 | mask: Optional[Tensor] = None,
112 | context_mask: Optional[Tensor] = None,
113 | attn_mask: Optional[Tensor] = None,
114 | prev_attn: Optional[Tensor] = None,
115 | mem: Optional[Tensor] = None,
116 | cache: Optional[AttentionIntermediates] = None,
117 | shared_cache: Optional[AttentionSharedIntermediates] = None
118 | ):
119 | b, n = x.shape[:2]
120 | h, scale, device = self.heads, self.scale, x.device
121 | has_context, has_mem, has_cache = context is not None, mem is not None, cache is not None
122 | assert not (has_mem and has_cache), 'cache is not compatible with memory keys'
123 | assert not (has_context and has_cache), 'cache is not compatible with context yet'
124 |
125 | kv_input = default(context, x)
126 |
127 | q_input = x
128 | k_input = kv_input
129 | v_input = kv_input
130 |
131 | if has_mem:
132 | k_input = torch.cat((mem, k_input), dim=-2)
133 | v_input = torch.cat((mem, v_input), dim=-2)
134 |
135 | q = self.to_q(q_input)
136 | k = self.to_k(k_input)
137 | v = self.to_v(v_input) if self.to_v is not None else k
138 |
139 | q = rearrange(q, 'b n (h d) -> b h n d', h=h)
140 |
141 | if not self.one_kv_head:
142 | k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (k, v))
143 |
144 | input_mask = mask if context_mask is None else context_mask
145 |
146 | if self.num_mem_kv > 0:
147 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
148 |
149 | k = torch.cat((mem_k, k), dim=-2)
150 | v = torch.cat((mem_v, v), dim=-2)
151 |
152 | if input_mask is not None:
153 | input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
154 |
155 | k = torch.cat([cache.keys, k], dim=-2) if has_cache else k
156 | v = torch.cat([cache.values, v], dim=-2) if has_cache else v
157 |
158 | i, j = map(lambda t: t.shape[-2], (q, k))
159 |
160 | # determine masking
161 |
162 | masks = []
163 | final_attn_mask = None
164 |
165 | if input_mask is not None:
166 | input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
167 | masks.append(~input_mask)
168 |
169 | if attn_mask is not None:
170 | assert 2 <= attn_mask.ndim <= 4, \
171 | 'attention mask must have between 2 and 4 dimensions'
172 | if attn_mask.ndim == 2:
173 | attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
174 | elif attn_mask.ndim == 3:
175 | attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
176 | attn_mask = attn_mask[:, :, -1:] if has_cache else attn_mask
177 | masks.append(~attn_mask)
178 |
179 | if self.max_attend is not None:
180 | range_q = torch.arange(j - i, j, device=device)
181 | range_k = torch.arange(j, device=device)
182 | dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
183 | max_attend_mask = torch.logical_or(dist < -self.max_attend, dist > self.max_attend)
184 | masks.append(max_attend_mask)
185 |
186 | if len(masks) > 0:
187 | final_attn_mask = ~or_reduce(masks)
188 |
189 | # prepare relative positional bias, if needed
190 |
191 | rel_pos_bias, attn_bias = None, None
192 | if self.rel_pos is not None:
193 | if shared_cache is not None and shared_cache.rel_pos_bias is not None:
194 | rel_pos_bias = shared_cache.rel_pos_bias
195 | else:
196 | rel_pos_bias = self.rel_pos.get_bias(i, j, k=j - i).to(dtype=q.dtype)
197 | attn_bias = self.rel_pos(i, j, k=j - i, bias=rel_pos_bias)
198 |
199 | # attention is all we need
200 |
201 | out, intermediates = self.attend(
202 | q, k, v,
203 | mask=final_attn_mask,
204 | attn_bias=attn_bias,
205 | prev_attn=prev_attn
206 | )
207 |
208 | # merge heads
209 |
210 | out = rearrange(out, 'b h n d -> b n (h d)')
211 |
212 | # combine the heads
213 |
214 | out = self.to_out(out)
215 |
216 | if mask is not None:
217 | mask = mask[:, -1:] if has_cache else mask
218 | out = out * mask[..., None]
219 |
220 | shared_intermediates = AttentionSharedIntermediates(rel_pos_bias=rel_pos_bias)
221 |
222 | return out, intermediates, shared_intermediates
223 |
--------------------------------------------------------------------------------
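
A usage sketch via direct construction (the config-driven path, `Attention.init(config=AttentionConfig(...), ...)`, is what `Transformer` uses internally):

```python
import torch

from scoreperformer.modules.transformer import Attention

attn = Attention(dim=256, dim_head=64, heads=4, causal=True)

x = torch.randn(2, 32, 256)
out, intermediates, shared = attn(x)  # out: (2, 32, 256)

# intermediates (keys/values) can be passed back as `cache` so that only a
# single new token attends over the stored history during incremental decoding
```
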
/scoreperformer/modules/transformer/feedforward.py:
--------------------------------------------------------------------------------
1 | """
2 | Transformer FeedForward layer.
3 |
4 | Adapted from: https://github.com/lucidrains/x-transformers
5 | """
6 | from dataclasses import dataclass
7 |
8 | import torch.nn as nn
9 |
10 | from ..constructor import Constructor, ModuleConfig
11 |
12 |
13 | class GLU(nn.Module):
14 | def __init__(self, dim_in, dim_out, activation):
15 | super().__init__()
16 | self.act = activation
17 | self.proj = nn.Linear(dim_in, dim_out * 2)
18 |
19 | def forward(self, x):
20 | x, gate = self.proj(x).chunk(2, dim=-1)
21 | return x * self.act(gate)
22 |
23 |
24 | @dataclass
25 | class FeedForwardConfig(ModuleConfig):
26 | dim: int = 512
27 | mult: int = 4
28 | glu: bool = False
29 | swish: bool = False
30 | post_act_ln: bool = False
31 | dropout: float = 0.
32 | no_bias: bool = True
33 |
34 |
35 | class FeedForward(nn.Module, Constructor):
36 | def __init__(
37 | self,
38 | dim: int = 512,
39 | mult: int = 4,
40 | glu: bool = False,
41 | swish: bool = False,
42 | post_act_ln: bool = False,
43 | dropout: float = 0.,
44 | no_bias: bool = True
45 | ):
46 | super().__init__()
47 |
48 | inner_dim = int(dim * mult)
49 | activation = nn.SiLU() if swish else nn.GELU()
50 |
51 | project_in = nn.Sequential(
52 | nn.Linear(dim, inner_dim, bias=not no_bias),
53 | activation
54 | ) if not glu else GLU(dim, inner_dim, activation)
55 |
56 | self.ff = nn.Sequential(
57 | project_in,
58 | nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim, bias=not no_bias)
61 | )
62 |
63 | def forward(self, x):
64 | return self.ff(x)
65 |
--------------------------------------------------------------------------------
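
A usage sketch: `glu=True` switches the input projection to the gated (GLU) variant above, and `swish=True` swaps GELU for SiLU:

```python
import torch

from scoreperformer.modules.transformer import FeedForward

ff = FeedForward(dim=256, mult=4, glu=True, swish=True, dropout=0.1)

x = torch.randn(2, 32, 256)
out = ff(x)  # (2, 32, 256); hidden width is dim * mult = 1024
```
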
/scoreperformer/modules/transformer/transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | Transformer Attention Layers with data caching support for inference.
3 |
4 | Adapted from: https://github.com/lucidrains/x-transformers
5 | """
6 |
7 | import copy
8 | from dataclasses import dataclass, field
9 | from functools import partial
10 | from typing import Optional, Union, List
11 |
12 | import torch
13 | import torch.nn as nn
14 | from omegaconf import DictConfig
15 | from torch import Tensor
16 |
17 | from scoreperformer.utils import equals
18 | from .attend import AttentionIntermediates
19 | from .attention import Attention, AttentionConfig
20 | from .feedforward import FeedForward, FeedForwardConfig
21 | from ..constructor import VariableModuleConfig, Constructor, Registry
22 | from ..layers import Residual, AdaptiveLayerNorm
23 |
24 |
25 | @dataclass
26 | class TransformerIntermediates:
27 | hiddens: Optional[List[Tensor]] = None
28 | attention: Optional[List[AttentionIntermediates]] = None
29 |
30 |
31 | TransformerRegistry = type("_TransformerRegistry", (Registry,), {})()
32 |
33 |
34 | @dataclass
35 | class TransformerConfig(VariableModuleConfig):
36 | _target_: str = "default"
37 |
38 | dim: int = 512
39 | depth: int = 4
40 | heads: int = 8
41 |
42 | attention: Optional[Union[AttentionConfig, DictConfig]] = None
43 | feed_forward: Optional[Union[FeedForwardConfig, DictConfig]] = None
44 |
45 | causal: bool = False
46 | cross_attend: bool = False
47 | only_cross: bool = False
48 |
49 | pre_norm: bool = True
50 | use_adanorm: bool = False
51 | style_emb_dim: Optional[int] = None
52 |
53 |
54 | @TransformerRegistry.register("default")
55 | class Transformer(nn.Module, Constructor):
56 | def __init__(
57 | self,
58 | dim: int = 512,
59 | depth: int = 4,
60 | heads: int = 8,
61 | attention: Optional[Union[AttentionConfig, DictConfig]] = None,
62 | feed_forward: Optional[Union[FeedForwardConfig, DictConfig]] = None,
63 | causal: bool = False,
64 | cross_attend: bool = False,
65 | only_cross: bool = False,
66 | pre_norm: bool = True,
67 | use_adanorm: bool = False,
68 | style_emb_dim: Optional[int] = None
69 | ):
70 | super().__init__()
71 |
72 | attention = attention if attention else AttentionConfig()
73 | feed_forward = feed_forward if feed_forward else FeedForwardConfig()
74 |
75 | self.dim = dim
76 | self.depth = depth
77 | self.layers = nn.ModuleList([])
78 |
79 | # normalization
80 |
81 | self.pre_norm = pre_norm
82 | self.ada_norm = use_adanorm
83 |
84 | assert not use_adanorm or style_emb_dim is not None, 'style_emb_dim must be provided when use_adanorm is set'
85 |
86 | norm_fn = partial(AdaptiveLayerNorm, dim, style_emb_dim) if use_adanorm else partial(nn.LayerNorm, dim)
87 |
88 | # layers
89 |
90 | self.cross_attend = cross_attend
91 |
92 | if cross_attend and not only_cross:
93 | default_block = ('a', 'c', 'f')
94 | elif cross_attend and only_cross:
95 | default_block = ('c', 'f')
96 | else:
97 | default_block = ('a', 'f')
98 |
99 | # calculate layer block order
100 |
101 | self.layer_types = default_block * depth
102 | self.num_attn_layers = len(list(filter(equals('a'), self.layer_types)))
103 |
104 | # final normalization layer (used in the pre-norm setting)
105 |
106 | self.final_norm = norm_fn() if pre_norm else nn.Identity()
107 |
108 | # iterate and construct layers
109 |
110 | for ind, layer_type in enumerate(self.layer_types):
111 |
112 | if layer_type == 'a':
113 | layer = Attention.init(config=attention, dim=dim, heads=heads, causal=causal)
114 | elif layer_type == 'c':
115 | layer = Attention.init(config=attention, dim=dim, heads=heads)
116 | elif layer_type == 'f':
117 | layer = FeedForward.init(config=feed_forward, dim=dim)
118 | else:
119 | raise Exception(f'invalid layer type {layer_type}')
120 |
121 | residual = Residual(dim)
122 |
123 | pre_branch_norm = norm_fn() if pre_norm else None
124 | post_branch_norm = None
125 | post_main_norm = norm_fn() if not pre_norm else None
126 |
127 | norms = nn.ModuleList([
128 | pre_branch_norm,
129 | post_branch_norm,
130 | post_main_norm
131 | ])
132 |
133 | self.layers.append(nn.ModuleList([
134 | norms,
135 | layer,
136 | residual
137 | ]))
138 |
139 | def forward(
140 | self,
141 | x: Tensor,
142 | mask: Optional[Tensor] = None,
143 | context: Optional[Tensor] = None,
144 | context_mask: Optional[Tensor] = None,
145 | attn_mask: Optional[Tensor] = None,
146 | style_embeddings: Optional[Tensor] = None,
147 | mems: Optional[List[Tensor]] = None,
148 | intermediates_cache: Optional[TransformerIntermediates] = None,
149 | return_hiddens: bool = False
150 | ):
151 | assert not (self.cross_attend ^ (context is not None)), \
152 | 'context must be passed in if and only if cross_attend is set to True'
153 | assert not self.ada_norm or style_embeddings is not None, \
154 | 'style_embeddings must be passed for AdaLayerNorm'
155 |
156 | hiddens = []
157 | attn_intermediates = []
158 |
159 | mems = mems.copy() if mems is not None else [None] * self.num_attn_layers
160 |
161 | has_cache = intermediates_cache is not None
162 | intermediates_cache = copy.copy(intermediates_cache) if has_cache else None
163 |
164 | x = x[:, -1:] if has_cache else x
165 | style_embeddings = style_embeddings[:, -1:] if has_cache and style_embeddings is not None else style_embeddings
166 |
167 | attn_shared_cache = None
168 |
169 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
170 | cache = None
171 |
172 | if layer_type == 'a':
173 | if has_cache:
174 | cache = intermediates_cache.hiddens.pop(0)
175 | x = torch.cat([cache, x], dim=1)
176 |
177 | if return_hiddens:
178 | hiddens.append(x)
179 |
180 | x = x[:, -1:] if has_cache else x
181 |
182 | layer_mem = mems.pop(0) if mems else None
183 |
184 | if has_cache:
185 | if layer_type in ('a', 'c'):
186 | cache = intermediates_cache.attention.pop(0)
187 |
188 | residual = x
189 |
190 | pre_norm, post_branch_norm, post_main_norm = norm
191 |
192 | if pre_norm is not None:
193 | x = pre_norm(x, condition=style_embeddings) if self.ada_norm else pre_norm(x)
194 |
195 | if layer_type == 'a':
196 | out, inter, attn_shared_cache = block(
197 | x, mask=mask, attn_mask=attn_mask,
198 | mem=layer_mem, cache=cache, shared_cache=attn_shared_cache,
199 | )
200 | elif layer_type == 'c':
201 | out, inter, _ = block(x, context=context, mask=mask, context_mask=context_mask)
202 | elif layer_type == 'f':
203 | out = block(x)
204 |
205 | if post_branch_norm is not None:
206 | out = post_branch_norm(out, condition=style_embeddings) if self.ada_norm else post_branch_norm(out)
207 |
208 | x = residual_fn(out, residual)
209 |
210 | if return_hiddens:
211 | if layer_type in ('a', 'c'):
212 | attn_intermediates.append(inter)
213 |
214 | if post_main_norm is not None:
215 | x = post_main_norm(x, condition=style_embeddings) if self.ada_norm else post_main_norm(x)
216 |
217 | x = self.final_norm(x, condition=style_embeddings) if self.ada_norm else self.final_norm(x)
218 |
219 | if has_cache:
220 | cache = intermediates_cache.hiddens.pop(0)
221 | x = torch.cat([cache, x], dim=1)
222 |
223 | if return_hiddens:
224 | hiddens.append(x)
225 | intermediates = TransformerIntermediates(
226 | hiddens=hiddens,
227 | attention=attn_intermediates
228 | )
229 |
230 | return x, intermediates
231 |
232 | return x
233 |
234 |
235 | @dataclass
236 | class EncoderConfig(TransformerConfig):
237 | _target_: str = "encoder"
238 | causal: bool = False
239 |
240 |
241 | @TransformerRegistry.register("encoder")
242 | class Encoder(Transformer):
243 | def __init__(self, **kwargs):
244 | super().__init__(causal=False, **kwargs)
245 |
246 |
247 | @dataclass
248 | class DecoderConfig(TransformerConfig):
249 | _target_: str = "decoder"
250 | causal: bool = True
251 |
252 |
253 | @TransformerRegistry.register("decoder")
254 | class Decoder(Transformer):
255 | def __init__(self, **kwargs):
256 | super().__init__(causal=True, **kwargs)
257 |
--------------------------------------------------------------------------------
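
A small usage sketch. `Encoder` and `Decoder` only fix the `causal` flag; default `AttentionConfig` and `FeedForwardConfig` are used when none are provided:

```python
import torch

from scoreperformer.modules.transformer import Decoder

dec = Decoder(dim=256, depth=2, heads=4)

x = torch.randn(2, 32, 256)
y = dec(x)  # causal self-attention stack, (2, 32, 256)

# also collect per-layer hiddens and attention intermediates,
# e.g. to reuse later as `intermediates_cache` for incremental decoding
y, intermediates = dec(x, return_hiddens=True)
```
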
/scoreperformer/train.py:
--------------------------------------------------------------------------------
1 | """ A minimal training script. """
2 |
3 | import argparse
4 |
5 | from scoreperformer.experiments import Trainer
6 | from scoreperformer.experiments.callbacks import EpochReproducibilityCallback
7 | from scoreperformer.experiments.components import ExperimentComponents
8 |
9 | if __name__ == "__main__":
10 | parser = argparse.ArgumentParser(description='training the model')
11 | parser.add_argument('--config-root', '-r', type=str, default='../recipes')
12 | parser.add_argument('--config-name', '-n', type=str, default='scoreperformer/base.yaml')
13 |
14 | args = parser.parse_args()
15 |
16 | exp_comps = ExperimentComponents(
17 | config=args.config_name,
18 | config_root=args.config_root
19 | )
20 | model, train_dataset, eval_dataset, collator, evaluator = exp_comps.init_components()
21 |
22 | trainer = Trainer(
23 | model=model,
24 | config=exp_comps.config,
25 | train_dataset=train_dataset,
26 | eval_dataset=eval_dataset,
27 | collator=collator,
28 | evaluator=evaluator,
29 | callbacks=[EpochReproducibilityCallback()]
30 | )
31 |
32 | trainer.train()
33 |
--------------------------------------------------------------------------------
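
With the default arguments the script is meant to be launched from inside the `scoreperformer/` package directory (hence the `../recipes` default); from there, `python train.py -r ../recipes -n scoreperformer/base.yaml` reproduces the base recipe run.
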
/scoreperformer/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import Config
2 | from .functions import *
3 | from .io import *
--------------------------------------------------------------------------------
/scoreperformer/utils/config.py:
--------------------------------------------------------------------------------
1 | """ A general purpose Config with IO. That's it. """
2 |
3 | import json
4 | from dataclasses import dataclass, asdict
5 | from enum import Enum
6 |
7 | from omegaconf import DictConfig, ListConfig, OmegaConf
8 |
9 |
10 | @dataclass
11 | class Config:
12 | def to_dict(self):
13 | d = asdict(self)
14 | for k, v in d.items():
15 | if isinstance(v, Enum):
16 | d[k] = v.value
17 | elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
18 | d[k] = [x.value for x in v]
19 | elif isinstance(v, (DictConfig, ListConfig)):
20 | d[k] = OmegaConf.to_container(v)
21 | elif isinstance(v, Config):
22 | d[k] = v.to_dict()
23 | return d
24 |
25 | def to_json_string(self):
26 | return json.dumps(self.to_dict(), indent=2)
27 |
28 | @classmethod
29 | def from_json_string(cls, json_string):
30 | return cls(**json.loads(json_string))
31 |
32 | def __contains__(self, item):
33 | return item in self.to_dict()
34 |
35 |
36 | def disable_nodes(config, parent=None):
37 | if isinstance(config, DictConfig):
38 | eject = config.pop("_disable_", False)
39 |
40 | if eject and parent is not None:
41 | key = config._key()
42 | parent.pop(key)
43 | else:
44 | for value in list(config.values()):
45 | disable_nodes(value, config)
46 |
--------------------------------------------------------------------------------
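
A sketch of the `_disable_` convention handled by `disable_nodes`: any sub-config carrying `_disable_: true` is pruned from the tree in place:

```python
from omegaconf import OmegaConf

from scoreperformer.utils.config import disable_nodes

cfg = OmegaConf.create({
    "model": {"dim": 512},
    "aux_head": {"_disable_": True, "dim": 64},  # flagged for removal
})

disable_nodes(cfg)
print("aux_head" in cfg)  # False
```
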
/scoreperformer/utils/functions.py:
--------------------------------------------------------------------------------
1 | """ A set of utility classes and functions used throughout the repository. """
2 |
3 | import random
4 | import sys
5 | from enum import Enum
6 | from inspect import isfunction
7 |
8 | import numpy as np
9 | from tqdm.asyncio import tqdm
10 |
11 |
12 | def exists(val):
13 | return val is not None
14 |
15 |
16 | def default(val, d):
17 | if exists(val):
18 | return val
19 | return d() if isfunction(d) else d
20 |
21 |
22 | class equals:
23 | def __init__(self, val):
24 | self.val = val
25 |
26 | def __call__(self, x, *args, **kwargs):
27 | return x == self.val
28 |
29 |
30 | def or_reduce(masks):
31 | head, *body = masks
32 | for rest in body:
33 | head = head | rest
34 | return head
35 |
36 |
37 | def prob2bool(prob):
38 | return random.choices([True, False], weights=[prob, 1 - prob])[0]
39 |
40 |
41 | def find_closest(array, values):
42 | """Finds indices of the values closest to `values` in a given array."""
43 | ids = np.searchsorted(array, values, side="left")
44 |
45 | # find indices where the previous element is closer
46 | arr_values = array[np.minimum(ids, len(array) - 1)]
47 | prev_values = array[np.maximum(ids - 1, 0)]
48 | prev_idx_is_less = (ids == len(array)) | (np.fabs(values - prev_values) < np.fabs(values - arr_values))
49 |
50 | if isinstance(ids, np.ndarray):
51 | ids[prev_idx_is_less] -= 1
52 | elif prev_idx_is_less:
53 | ids -= 1
54 |
55 | ids = np.maximum(0, ids)
56 |
57 | return ids
58 |
59 |
60 | def tqdm_iterator(iterable, desc=None, position=0, leave=False, file=sys.stdout, **kwargs):
61 | return tqdm(iterable, desc=desc, position=position, leave=leave, file=file, **kwargs)
62 |
63 |
64 | def apply(seqs, func, tqdm_enabled=True, desc=None):
65 | """ Apply a given `func` over a list of sequences `seqs`."""
66 | iterator = tqdm_iterator(seqs, desc=desc) if tqdm_enabled else seqs
67 | return [func(seq) for seq in iterator]
68 |
69 |
70 | class ExplicitEnum(str, Enum):
71 | """
72 | Enum with more explicit error message for missing values.
73 | """
74 |
75 | @classmethod
76 | def _missing_(cls, value):
77 | raise ValueError(
78 | f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
79 | )
80 |
81 | @classmethod
82 | def has_value(cls, value):
83 | return value in cls._value2member_map_
84 |
85 | @classmethod
86 | def list(cls):
87 | return list(map(lambda c: c.value, cls))
88 |
--------------------------------------------------------------------------------
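
A worked example of `find_closest` (the input `array` must be sorted, as required by `np.searchsorted`):

```python
import numpy as np

from scoreperformer.utils import find_closest

grid = np.array([0.0, 0.5, 1.0, 2.0])    # sorted
queries = np.array([0.1, 0.6, 1.8, 5.0])

print(find_closest(grid, queries))  # [0 1 3 3]
# 0.1 -> 0.0, 0.6 -> 0.5, 1.8 -> 2.0, and 5.0 clamps to the last element
```
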
/scoreperformer/utils/io.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 |
5 | def load_json(file_path):
6 | if os.path.exists(file_path):
7 | with open(file_path, 'r') as f:
8 | return json.load(f)
9 | else:
10 | return {}
11 |
12 |
13 | def dump_json(data, file_path):
14 | with open(file_path, 'w') as f:
15 | json.dump(data, f)
16 |
--------------------------------------------------------------------------------
/scoreperformer/utils/playback.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import IPython.display as ipd
4 | import note_seq
5 | from miditoolkit import MidiFile
6 | from note_seq import midi_file_to_note_sequence
7 |
8 |
9 | def cut_midi(
10 | midi: MidiFile,
11 | min_tick: int = 0,
12 | max_tick: int = int(1e9),
13 | cut_end_tick: bool = True,
14 | save_path: str = '/tmp/tmp.mid'
15 | ):
16 | midi = copy.deepcopy(midi)
17 |
18 | for track in midi.instruments:
19 | track.notes = [n for n in track.notes if min_tick <= n.start <= max_tick]
20 | for n in track.notes:
21 | n.start -= min_tick
22 | if cut_end_tick:
23 | n.end = min(n.end, max_tick)
24 | n.end -= min_tick
25 |
26 | if hasattr(track, "control_changes"):
27 | track.control_changes = [c for c in track.control_changes if min_tick <= c.time <= max_tick]
28 | for c in track.control_changes:
29 | c.time -= min_tick
30 | if hasattr(track, "pedals"):
31 | track.pedals = [p for p in track.pedals if min_tick <= p.start <= max_tick]
32 | for p in track.pedals:
33 | p.start -= min_tick
34 | p.end -= min_tick
35 |
36 | midi.tempo_changes = [t for t in midi.tempo_changes if min_tick <= t.time <= max_tick]
37 | for t in midi.tempo_changes:
38 | t.time -= min_tick
39 |
40 | midi.max_tick = max([n.end for n in midi.instruments[0].notes])
41 | midi.max_tick = max(midi.max_tick, midi.tempo_changes[-1].time + 1)
42 |
43 | if save_path is not None:
44 | midi.dump(save_path)
45 |
46 | return midi
47 |
48 |
49 | def midi_to_audio(
50 | path: str = '/tmp/tmp.mid',
51 | sample_rate: int = 22050,
52 | play: bool = True
53 | ):
54 | ns = midi_file_to_note_sequence(path)
55 | audio = note_seq.fluidsynth(ns, sample_rate=sample_rate)
56 | if play:
57 | ipd.display(ipd.Audio(audio, rate=sample_rate))
58 | return audio
59 |
--------------------------------------------------------------------------------
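
A notebook usage sketch (`performance.mid` is a placeholder path; `midi_to_audio` additionally requires a FluidSynth installation for `note_seq.fluidsynth`):

```python
from miditoolkit import MidiFile

from scoreperformer.utils.playback import cut_midi, midi_to_audio

midi = MidiFile('performance.mid')

# keep roughly the first four 4/4 bars; the excerpt is saved to /tmp/tmp.mid
cut_midi(midi, min_tick=0, max_tick=4 * 4 * midi.ticks_per_beat)

audio = midi_to_audio('/tmp/tmp.mid', sample_rate=22050, play=True)
```
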
/scoreperformer/utils/plots.py:
--------------------------------------------------------------------------------
1 | import librosa.display
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | from matplotlib.colors import ListedColormap
5 | import pretty_midi
6 |
7 | from scoreperformer.data.tokenizers import SPMuple
8 |
9 |
10 | def plot_performance_parameter(tokenizer: SPMuple, total_seq, perf_seq, token_type="Tempo"):
11 | type_idx = tokenizer.vocab_types_idx[token_type]
12 |
13 | preds = total_seq[:, type_idx] - tokenizer.zero_token
14 | targets = perf_seq[:total_seq.shape[0], type_idx] - tokenizer.zero_token
15 |
16 | if token_type == "Velocity":
17 | values_map = tokenizer.velocities
18 | elif token_type == "Tempo":
19 | values_map = tokenizer.tempos
20 | elif token_type == "OnsetDev":
21 | nb_positions = max(tokenizer.beat_res.values()) * 2 # up to two quarter notes
22 | values_map = np.arange(-nb_positions, nb_positions + 1) / nb_positions / 2
23 | elif token_type == "PerfDuration":
24 | values_map = np.array([
25 | (beat * res + pos) / res if res > 0 else 0
26 | for beat, pos, res in tokenizer.durations
27 | ])
28 | elif token_type == "RelOnsetDev":
29 | values_map = tokenizer.rel_onset_deviations
30 | elif token_type == "RelPerfDuration":
31 | values_map = tokenizer.rel_performed_durations
32 | else:
33 | return
34 |
35 | preds, targets = values_map[preds], values_map[targets]
36 |
37 | fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(16, 12))
38 |
39 | fig.suptitle(f"Performance Notes, {token_type}", fontsize=20)
40 | ax0.plot(preds)
41 | ax0.plot(targets)
42 | ax1.plot(preds - targets)
43 |
44 | ax0.legend(["Generated", "Target"], fontsize=18)
45 | ax1.legend(["Difference"], fontsize=18)
46 |
47 | ax0.get_xaxis().set_visible(False)
48 | ax1.set_xlabel("note id", fontsize=16)
49 | for ax in (ax0, ax1):
50 | ax.tick_params(labelsize=14)
51 | ax.set_ylabel(token_type.lower(), fontsize=16)
52 |
53 | fig.tight_layout()
54 |
55 |
56 | _colors = plt.get_cmap('Reds', 256)(np.linspace(0, 1, 256))
57 | _colors[:1, :] = np.array([1, 1, 1, 1])
58 | pianoroll_cmap = ListedColormap(_colors)
59 |
60 |
61 | def plot_pianoroll(pm, min_pitch=21, max_pitch=109, min_velocity=0., max_velocity=127.,
62 | fs=100, max_time=None, pad_time=0.2, xticks_time=1., figsize=(14, 6), fig=None, ax=None):
63 | if fig is None or ax is None:
64 | fig, ax = plt.subplots(1, 1, figsize=figsize)
65 |
66 | # get numpy array from pianoroll
67 | arr = pm.get_piano_roll(fs)[min_pitch:max_pitch + 1]
68 | arr[arr > max_velocity] = max_velocity
69 |
70 | # pad with a few steps
71 | max_time = pm.get_end_time() if max_time is None else max_time
72 | pad_steps = int(fs * pad_time)
73 | pad_l, pad_r = pad_steps, pad_steps + max(0, int(fs * max_time) - arr.shape[1])
74 | arr = np.pad(arr, ((0, 0), (pad_l, pad_r)), 'constant')
75 |
76 | # plot pianoroll
77 | x_coords = np.arange(-pad_time, arr.shape[1] / fs - pad_time, 1 / fs)
78 | y_coords = np.arange(min_pitch, max_pitch + 1, 1)
79 | librosa.display.specshow(
80 | arr, cmap=pianoroll_cmap, ax=ax, x_coords=x_coords, y_coords=y_coords,
81 | hop_length=1, sr=fs, x_axis='time', y_axis='linear', bins_per_octave=12,
82 | fmin=pretty_midi.note_number_to_hz(min_pitch), vmin=min_velocity, vmax=max_velocity
83 | )
84 |
85 | # plot colorbar
86 | cbar = fig.colorbar(ax.get_children()[0], ax=ax, fraction=0.15, pad=0.02, aspect=15)
87 | cbar.ax.tick_params(labelsize=14)
88 | cbar.set_ticks(np.arange(0, max_velocity, 12))
89 |
90 | # axis labels
91 | ax.set_xlabel('time (s)', fontsize=16)
92 | ax.set_ylabel('pitch', fontsize=16)
93 | ax.tick_params(labelsize=14)
94 |
95 | # x-axis
96 | ax.set_xticks(np.arange(0, ax.get_xticks()[-1], xticks_time))
97 | ax.set_xlim(xmax=max_time + pad_time)
98 |
99 | # y-axis
100 | yticks = np.arange(min_pitch + 12 - min_pitch % 12, max_pitch, 12)
101 | ax.set_yticks(yticks - 0.5)
102 | ax.set_yticklabels(yticks)
103 |
104 | # removing empty pitch lines
105 | has_notes = min_pitch + np.where(np.any(arr != 0., axis=1))[0]
106 | if len(has_notes) > 0:
107 | ymin, ymax = has_notes[0], has_notes[-1]
108 | ymin = max(min_pitch, ymin - ymin % 12) - 2.5
109 | ymax = min(max_pitch, ymax + 12 - ymax % 12) + 1.5
110 | ax.set_ylim(ymin, ymax)
111 |
112 | ax.grid(alpha=0.5)
113 |
114 | return fig, ax
115 |
--------------------------------------------------------------------------------
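
A usage sketch (`performance.mid` is a placeholder path):

```python
import pretty_midi

from scoreperformer.utils.plots import plot_pianoroll

pm = pretty_midi.PrettyMIDI('performance.mid')
fig, ax = plot_pianoroll(pm, fs=100, figsize=(14, 6))
fig.savefig('pianoroll.png', dpi=150)
```
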