├── .gitignore ├── README.md ├── checkpoints ├── checkpoints │ ├── music_transformer │ │ ├── best_loss.pth │ │ └── music_transformer_event_based.yaml │ └── ripo_transformer │ │ ├── best_loss.pth │ │ └── ripo_transformer.yaml └── models │ ├── fme_model.pkl │ └── we_model.pkl ├── demo ├── __init__.py ├── config │ ├── generation_uncond.yaml │ ├── music_transformer_event_based.yaml │ └── ripo_transformer.yaml ├── data │ ├── test_data_0.json │ ├── test_data_1.json │ ├── test_data_2.json │ └── test_data_3.json ├── data_loading.py └── utils.py ├── generation_uncond.py ├── lib └── fast_transformers.zip ├── model ├── FME_music_positional_encoding.py ├── __init__.py ├── __pycache__ │ ├── FME_music_positional_encoding.cpython-37.pyc │ ├── FME_music_positional_encoding.cpython-39.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-39.pyc │ ├── layers.cpython-37.pyc │ ├── layers.cpython-39.pyc │ ├── model.cpython-37.pyc │ └── model.cpython-39.pyc ├── layers.py └── model.py ├── training ├── config │ ├── generation_uncond.yaml │ ├── music_transformer_event_based.yaml │ └── ripo_transformer.yaml └── main_train.py ├── tutorial ├── __init__.py ├── fme │ ├── .gitkeep │ ├── consts.py │ ├── tutorial.py │ ├── utils.py │ └── visual.py ├── ripo │ └── .gitkeep └── setup │ └── .gitkeep └── utils ├── __pycache__ ├── eval_utils.cpython-37.pyc └── eval_utils.cpython-39.pyc ├── eval_utils.py └── test_kernel_est.py /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | __pycache__ 3 | .DS_Store 4 | *.ipynb* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fundamental Music Embedding (FME) and RIPO Attention 2 | 3 | Code accompanying the paper 'A Domain-Knowledge-Inspired Music Embedding Space and a Novel Attention Mechanism for Symbolic Music Modeling'. [Read paper here.](https://arxiv.org/abs/2212.00973) 4 | 5 | Following the success of the transformer architecture in the natural language domain, transformer-like architectures have been widely applied to the domain of symbolic music recently. Symbolic music and text, however, are two different modalities. Symbolic music contains multiple attributes, both absolute attributes (e.g., pitch) and relative attributes (e.g., pitch interval). These relative attributes shape human perception of musical motifs. These important relative attributes, however, are mostly ignored in existing symbolic music modeling methods with the main reason being the lack of a musically-meaningful embedding space where both the absolute and relative embeddings of the symbolic music tokens can be efficiently represented. In this paper, we propose the Fundamental Music Embedding (FME) for symbolic music based on a bias-adjusted sinusoidal encoding within which both the absolute and the relative attributes can be embedded and the fundamental musical properties (e.g., translational invariance) are explicitly preserved. Taking advantage of the proposed FME, we further propose a novel attention mechanism based on the relative index, pitch and onset embeddings (RIPO attention) such that the musical domain knowledge can be fully utilized for symbolic music modeling. Experiment results show that our proposed model: RIPO transformer which utilizes FME and RIPO attention outperforms the state-of-the-art transformers (i.e., music transformer, linear transformer) in a melody completion task. 
Moreover, using the RIPO transformer in a downstream music generation task, we notice that the notorious degeneration phenomenon no longer exists and the music generated by the RIPO transformer outperforms the music generated by state-of-the-art transformer models in both subjective and objective evaluations. The code of the proposed method is available online1 6 | 7 | ## How to use 8 | 9 | Will be added soon. 10 | 11 | ## Citation 12 | 13 | If you use this work please cite the following paper: 14 | 15 | ``` 16 | @inproceedings{guo2023domain, 17 | title={A domain-knowledge-inspired music embedding space and a novel attention mechanism for symbolic music modeling}, 18 | author={Guo, Zixun and Kang, Jaeyong and Herremans, Dorien}, 19 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 20 | volume={37}, 21 | number={4}, 22 | pages={5070--5077}, 23 | year={2023} 24 | } 25 | ``` 26 | -------------------------------------------------------------------------------- /checkpoints/checkpoints/music_transformer/best_loss.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/checkpoints/checkpoints/music_transformer/best_loss.pth -------------------------------------------------------------------------------- /checkpoints/checkpoints/music_transformer/music_transformer_event_based.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | d_model: 256 3 | dataset: 4 | data_dir: demo/data/wikifornia_theorytab_csv_mid 5 | seq_len_chord: 88 6 | seq_len_note: 246 7 | device: cuda:3 8 | epochs: 80 9 | experiment: Music Transformer Event Based 10 | lr: 0.001 11 | optimizer: adam 12 | relative_pitch_attention: 13 | attention_conf: 14 | attention_type: rgl_rel_pitch_dur 15 | if_add_relative_duration: false 16 | if_add_relative_idx: true 17 | if_add_relative_idx_no_mask: false 18 | if_add_relative_pitch: false 19 | d_model: 256 20 | device: cuda 21 | dim_feedforward: 2048 22 | dropout: 0.2 23 | dur_dim: 17 24 | dur_embedding_conf: 25 | base: 7920 26 | d_model: 256 27 | device: cuda 28 | emb_nn: false 29 | if_trainable: true 30 | translation_bias_type: nd 31 | type: nn 32 | emb_size: 128 33 | nhead: 8 34 | nlayers: 2 35 | pitch_dim: 128 36 | pitch_embedding_conf: 37 | base: 9919 38 | d_model: 256 39 | device: cuda 40 | emb_nn: false 41 | if_trainable: true 42 | translation_bias_type: nd 43 | type: nn 44 | position_encoding_conf: 45 | device: cuda 46 | if_global_timing: false 47 | if_index: true 48 | if_modulo_timing: false 49 | -------------------------------------------------------------------------------- /checkpoints/checkpoints/ripo_transformer/best_loss.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/checkpoints/checkpoints/ripo_transformer/best_loss.pth -------------------------------------------------------------------------------- /checkpoints/checkpoints/ripo_transformer/ripo_transformer.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | d_model: 256 3 | dataset: 4 | data_dir: demo/data/wikifornia_theorytab_csv_mid 5 | seq_len_chord: 88 6 | seq_len_note: 246 7 | device: cuda 8 | epochs: 80 9 | experiment: bidirectional_RIPO 10 | lr: 0.001 11 | optimizer: adam 12 | 
relative_pitch_attention: 13 | attention_conf: 14 | attention_type: rgl_rel_pitch_dur 15 | if_add_relative_duration: true 16 | if_add_relative_idx: true 17 | if_add_relative_idx_no_mask: false 18 | if_add_relative_pitch: true 19 | d_model: 256 20 | device: cuda 21 | dim_feedforward: 2048 22 | dropout: 0.2 23 | dur_dim: 17 24 | dur_embedding_conf: 25 | base: 7920 26 | d_model: 256 27 | device: cuda 28 | emb_nn: true 29 | if_trainable: true 30 | translation_bias_type: nd 31 | type: se 32 | emb_size: 128 33 | nhead: 8 34 | nlayers: 2 35 | pitch_dim: 128 36 | pitch_embedding_conf: 37 | base: 9919 38 | d_model: 256 39 | device: cuda 40 | emb_nn: true 41 | if_trainable: true 42 | translation_bias_type: nd 43 | type: se 44 | position_encoding_conf: 45 | device: cuda 46 | if_global_timing: true 47 | if_index: true 48 | if_modulo_timing: true 49 | -------------------------------------------------------------------------------- /checkpoints/models/fme_model.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/checkpoints/models/fme_model.pkl -------------------------------------------------------------------------------- /checkpoints/models/we_model.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/checkpoints/models/we_model.pkl -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- 1 | # from .data_loading import MotifDataset, MotifDataset_relative_pitch_dur, MotifDataset_with_mask -------------------------------------------------------------------------------- /demo/config/generation_uncond.yaml: -------------------------------------------------------------------------------- 1 | ckpt_path: "/data/nicolas/Fundamental_Music_Embedding_RIPO_Attention/checkpoints/12_29_2022_20_56_15_baseline" 2 | device: "cuda:2" 3 | seed_bar_num: 2 4 | target_bar_num: 16 5 | sampling: 6 | decoder_choice: "top_k" 7 | temperature: 1.0 8 | top_k: 5 9 | top_p: 0.9 10 | -------------------------------------------------------------------------------- /demo/config/music_transformer_event_based.yaml: -------------------------------------------------------------------------------- 1 | experiment: "Music Transformer Event Based" 2 | batch_size: 16 3 | optimizer: "adam" 4 | lr: 0.001 5 | device: &dvce "cuda:3" #"cuda:3" 6 | d_model: &dmdl 256 7 | epochs: 80 8 | dataset: 9 | data_dir: "/data/nicolas/MotifNet_RIPO_transformer_FME/data_processing_new/wikifornia_theorytab_csv_mid" 10 | seq_len_chord: 88 11 | seq_len_note: 246 12 | relative_pitch_attention: 13 | d_model: *dmdl 14 | nhead: 8 15 | dim_feedforward: 2048 16 | dropout: 0.2 17 | nlayers: 2 18 | pitch_dim: 128 19 | dur_dim: 17 20 | emb_size: 128 21 | # max_len: 245 22 | device: *dvce 23 | position_encoding_conf: 24 | if_index: True 25 | if_global_timing: False 26 | if_modulo_timing: False 27 | device: *dvce 28 | attention_conf: 29 | attention_type: "rgl_rel_pitch_dur" #mha, rgl_rel_pitch_dur, rgl_vanilla, linear 30 | if_add_relative_pitch: False 31 | if_add_relative_duration: False 32 | if_add_relative_idx: True 33 | if_add_relative_idx_no_mask: False 34 | pitch_embedding_conf: 35 | d_model: *dmdl 36 | type: "one_hot" #nn, se, one_hot, nn_pretrain 37 
| base: 9919 38 | if_trainable: False 39 | translation_bias_type: "nd" #2d or nd trainable vector/ None 40 | device: *dvce 41 | emb_nn: False 42 | # pretrain_emb_path: None 43 | # freeze_pretrain: True 44 | dur_embedding_conf: 45 | d_model: *dmdl 46 | type: "one_hot" #nn, se 47 | base: 7920 48 | if_trainable: False 49 | translation_bias_type: "nd" #2d or nd trainable vector/ None 50 | device: *dvce 51 | emb_nn: False 52 | # pretrain_emb_path: None 53 | # freeze_pretrain: True 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /demo/config/ripo_transformer.yaml: -------------------------------------------------------------------------------- 1 | experiment: "RIPO" 2 | batch_size: 16 3 | optimizer: "adam" 4 | lr: 0.001 5 | device: &dvce "cuda:3" #"cuda:3" 6 | d_model: &dmdl 256 7 | epochs: 80 8 | dataset: 9 | data_dir: "/data/nicolas/MotifNet_RIPO_transformer_FME/data_processing_new/wikifornia_theorytab_csv_mid" 10 | seq_len_chord: 88 11 | seq_len_note: 246 12 | relative_pitch_attention: 13 | d_model: *dmdl 14 | nhead: 8 15 | dim_feedforward: 2048 16 | dropout: 0.2 17 | nlayers: 2 18 | pitch_dim: 128 19 | dur_dim: 17 20 | emb_size: 128 21 | # max_len: 245 22 | device: *dvce 23 | position_encoding_conf: 24 | if_index: True 25 | if_global_timing: True 26 | if_modulo_timing: True 27 | device: *dvce 28 | attention_conf: 29 | attention_type: "rgl_rel_pitch_dur" #mha, rgl_rel_pitch_dur, rgl_vanilla, linear 30 | if_add_relative_pitch: True 31 | if_add_relative_duration: True 32 | if_add_relative_idx: True 33 | if_add_relative_idx_no_mask: False 34 | pitch_embedding_conf: 35 | d_model: *dmdl 36 | type: "se" #nn, se, one_hot, nn_pretrain 37 | base: 9919 38 | if_trainable: True 39 | translation_bias_type: "nd" #2d or nd trainable vector/ None 40 | device: *dvce 41 | emb_nn: True 42 | # pretrain_emb_path: None 43 | # freeze_pretrain: True 44 | dur_embedding_conf: 45 | d_model: *dmdl 46 | type: "se" #nn, se 47 | base: 7920 48 | if_trainable: True 49 | translation_bias_type: "nd" #2d or nd trainable vector/ None 50 | device: *dvce 51 | emb_nn: True 52 | # pretrain_emb_path: None 53 | # freeze_pretrain: True 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /demo/data_loading.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import glob 4 | import random 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset, DataLoader 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import json 10 | def split_lst(g_lst, split_ratio,seed=666): 11 | random.seed(seed) 12 | random.shuffle(g_lst) 13 | training_data_num = int(len(g_lst)*(1-split_ratio)) 14 | 15 | training_lst = g_lst[:training_data_num] 16 | val_lst = g_lst[training_data_num:] #[(note,chord), (note,chord)...] 
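# NOTE: the fixed seed set at the top of split_lst makes the shuffle deterministic, so the same train/validation split is reproduced across runs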
17 | return training_lst, val_lst 18 | 19 | class MotifDataset(Dataset): 20 | def __init__(self,data_dir, if_pad = True, seq_len_chord=None, seq_len_note=None): 21 | self.json_lst = glob.glob(data_dir+"/theory/**/**/*.json") +glob.glob(data_dir+"/wiki/*.json") 22 | self.tknz = tokenizer_plain(seq_len_note = seq_len_note, seq_len_chord = seq_len_chord, if_pad = if_pad) 23 | self.notes_inp, self.notes_pred, self.chords = self.process() 24 | def process(self): 25 | note_inp_lst = [] 26 | note_pred_lst = [] 27 | chord_lst = [] 28 | for json_file in self.json_lst: 29 | try: 30 | note, chord, pitch_rel, pitch_rel_mask, onset_rel, onset_rel_mask, dur_onset_cumsum = self.tknz(json_file) 31 | inp_dict, pred_dict = {}, {} 32 | inp_dict['pitch'] = note[:-1, 0] #(n_nodes, ) 33 | inp_dict['dur_p'] = note[:-1, 1]#(n_nodes, ) 34 | 35 | inp_dict['pitch_rel'] = pitch_rel[:-1, :-1] #(n_nodes, ) 36 | inp_dict['pitch_rel_mask'] = pitch_rel_mask[:-1, :-1] #(n_nodes, ) 37 | inp_dict['dur_rel'] = onset_rel[:-1, :-1] #(n_nodes, ) 38 | inp_dict['dur_rel_mask'] = onset_rel_mask[:-1, :-1] #(n_nodes, ) 39 | inp_dict['dur_onset_cumsum'] = dur_onset_cumsum[:-1] #(n_nodes, ) 40 | 41 | pred_dict['pitch'] = note[1:, 0] #(n_nodes, ) 42 | pred_dict['dur_p'] = note[1:, 1]#(n_nodes, ) 43 | 44 | pred_dict['pitch_rel'] = pitch_rel[1:, 1:] #(n_nodes, ) 45 | pred_dict['pitch_rel_mask'] = pitch_rel_mask[1:, 1:] #(n_nodes, ) 46 | pred_dict['dur_rel'] = onset_rel[1:, 1:] #(n_nodes, ) 47 | pred_dict['dur_rel_mask'] = onset_rel_mask[1:, 1:] #(n_nodes, ) 48 | pred_dict['dur_onset_cumsum'] = dur_onset_cumsum[1:] #(n_nodes, ) 49 | 50 | 51 | note_inp_lst.append(inp_dict) 52 | note_pred_lst.append(pred_dict) 53 | chord_lst.append(chord) 54 | except: 55 | continue 56 | return note_inp_lst, note_pred_lst, chord_lst 57 | def __getitem__(self, i): 58 | return self.notes_inp[i], self.notes_pred[i],self.chords[i] 59 | 60 | def __len__(self): 61 | return len(self.notes_inp) 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | class tokenizer_plain(): 70 | def __init__(self,seq_len_note=246, seq_len_chord=88,if_pad = True): 71 | self.duration_dict_inv = {0: -999, 1: 0.25, 2: 0.5, 3: 0.75, 4: 1.0, 5: 1.25, 6: 1.5, 7: 1.75, 8: 2.0, 9: 2.25, 10: 2.5, 11: 2.75, 12: 3.0, 13: 3.25, 14: 3.5, 15: 3.75, 16: 4.0} 72 | 73 | self.chord_dict2 = {"pad":0, "rest":1,"sustain":2,\ 74 | 'Cmaj': 3, 'Cmin': 4, 'Cdim': 5, 'C7': 6, \ 75 | 'C#maj': 7, 'C#min': 8, 'C#dim': 9, 'C#7': 10,'Dbmaj': 7, 'Dbmin': 8, 'Dbdim': 9, 'Db7': 10, \ 76 | 'Dmaj': 11, 'Dmin': 12, 'Ddim': 13, 'D7': 14, \ 77 | 'D#maj': 15, 'D#min': 16,'D#dim': 17, 'D#7': 18, 'Ebmaj': 15, 'Ebmin': 16,'Ebdim': 17, 'Eb7': 18,\ 78 | 'Emaj': 19, 'Emin': 20, 'Edim': 21, 'E7': 22, \ 79 | 'Fmaj': 23, 'Fmin': 24, 'Fdim': 25, 'F7': 26, \ 80 | 'F#maj': 27, 'F#min': 28, 'F#dim': 29, 'F#7': 30, 'Gbmaj': 27, 'Gbmin': 28, 'Gbdim': 29, 'Gb7': 30,\ 81 | 'Gmaj': 31, 'Gmin': 32,'Gdim': 33, 'G7': 34, \ 82 | 'G#maj': 35, 'G#min': 36, 'G#dim': 37, 'G#7': 38, 'Abmaj': 35, 'Abmin': 36, 'Abdim': 37, 'Ab7': 38,\ 83 | 'Amaj': 39, 'Amin': 40, 'Adim': 41, 'A7': 42, \ 84 | 'A#maj': 43, 'A#min': 44, 'A#dim': 45, 'A#7': 46, 'Bbmaj': 43, 'Bbmin': 44, 'Bbdim': 45, 'Bb7': 46, \ 85 | 'Bmaj': 47, 'Bmin': 48,'Bdim': 49, 'B7': 50} 86 | 87 | self.duration_dict = {"pad":0, 0.25: 1, 0.5: 2, 0.75: 3, 1.0: 4, 1.25: 5, 1.5: 6, 1.75: 7, 2.0: 8,\ 88 | 2.25: 9, 2.5: 10, 2.75: 11, 3.0: 12, 3.25: 13, 3.5: 14, 3.75: 15, 4.0: 16} 89 | 90 | self.pitch_dict2 = {'pad': 0, 'rest': 1, 'sustain': 2, 'C-1': 3, 'Db-1': 4, 'C#-1': 4, 'D-1': 5, 'Eb-1': 6, 'D#-1': 6, 'E-1': 7, 
'F-1': 8, 'Gb-1': 9, 'F#-1': 9, 'G-1': 10, 'Ab-1': 11, 'G#-1': 11, 'A-1': 12, 'Bb-1': 13, 'A#-1': 13, 'B-1': 14, 'C0': 15, 'Db0': 16, 'C#0': 16, 'D0': 17, 'Eb0': 18, 'D#0': 18, 'E0': 19, 'F0': 20, 'Gb0': 21, 'F#0': 21, 'G0': 22, 'Ab0': 23, 'G#0': 23, 'A0': 24, 'Bb0': 25, 'A#0': 25, 'B0': 26, 'C1': 27, 'Db1': 28, 'C#1': 28, 'D1': 29, 'Eb1': 30, 'D#1': 30, 'E1': 31, 'F1': 32, 'Gb1': 33, 'F#1': 33, 'G1': 34, 'Ab1': 35, 'G#1': 35, 'A1': 36, 'Bb1': 37, 'A#1': 37, 'B1': 38, 'C2': 39, 'Db2': 40, 'C#2': 40, 'D2': 41, 'Eb2': 42, 'D#2': 42, 'E2': 43, 'F2': 44, 'Gb2': 45, 'F#2': 45, 'G2': 46, 'Ab2': 47, 'G#2': 47, 'A2': 48, 'Bb2': 49, 'A#2': 49, 'B2': 50, 'C3': 51, 'Db3': 52, 'C#3': 52, 'D3': 53, 'Eb3': 54, 'D#3': 54, 'E3': 55, 'F3': 56, 'Gb3': 57, 'F#3': 57, 'G3': 58, 'Ab3': 59, 'G#3': 59, 'A3': 60, 'Bb3': 61, 'A#3': 61, 'B3': 62, 'C4': 63, 'Db4': 64, 'C#4': 64, 'D4': 65, 'Eb4': 66, 'D#4': 66, 'E4': 67, 'F4': 68, 'Gb4': 69, 'F#4': 69, 'G4': 70, 'Ab4': 71, 'G#4': 71, 'A4': 72, 'Bb4': 73, 'A#4': 73, 'B4': 74, 'C5': 75, 'Db5': 76, 'C#5': 76, 'D5': 77, 'Eb5': 78, 'D#5': 78, 'E5': 79, 'F5': 80, 'Gb5': 81, 'F#5': 81, 'G5': 82, 'Ab5': 83, 'G#5': 83, 'A5': 84, 'Bb5': 85, 'A#5': 85, 'B5': 86, 'C6': 87, 'Db6': 88, 'C#6': 88, 'D6': 89, 'Eb6': 90, 'D#6': 90, 'E6': 91, 'F6': 92, 'Gb6': 93, 'F#6': 93, 'G6': 94, 'Ab6': 95, 'G#6': 95, 'A6': 96, 'Bb6': 97, 'A#6': 97, 'B6': 98, 'C7': 99, 'Db7': 100, 'C#7': 100, 'D7': 101, 'Eb7': 102, 'D#7': 102, 'E7': 103, 'F7': 104, 'Gb7': 105, 'F#7': 105, 'G7': 106, 'Ab7': 107, 'G#7': 107, 'A7': 108, 'Bb7': 109, 'A#7': 109, 'B7': 110, 'C8': 111, 'Db8': 112, 'C#8': 112, 'D8': 113, 'Eb8': 114, 'D#8': 114, 'E8': 115, 'F8': 116, 'Gb8': 117, 'F#8': 117, 'G8': 118, 'Ab8': 119, 'G#8': 119, 'A8': 120, 'Bb8': 121, 'A#8': 121, 'B8': 122, 'C9': 123, 'Db9': 124, 'C#9': 124, 'D9': 125, 'Eb9': 126, 'D#9': 126, 'E9': 127, 'F9': 128, 'Gb9': 129, 'F#9': 129, 'G9': 130} 91 | 92 | self.seq_len_note = seq_len_note 93 | self.seq_len_chord = seq_len_chord 94 | self.if_pad = if_pad 95 | 96 | def __call__(self, json_path): 97 | self.json_path = json_path 98 | self.music_dict= json.load(open(json_path,"r")) 99 | note_lst = self.music_dict["chord_note"][0] #[[p1, d1], [p2, d2], ] 100 | chord_lst = self.music_dict["chord_note"][1] #[[c1, d1], [c2, d2], ] 101 | 102 | #get rid of over-long sequences 103 | if self.seq_len_note self.top_p 84 | 85 | # Shift the indices to the right to keep the first token above threshold. 86 | sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1) 87 | sorted_indices_to_remove = tf.concat([ 88 | tf.zeros_like(sorted_indices_to_remove[:, :1]), 89 | sorted_indices_to_remove[:, 1:] 90 | ], -1) 91 | 92 | # Scatter sorted indices to original indexes. 93 | indices_to_remove = self.scatter_values_on_batch_indices(sorted_indices_to_remove, 94 | sorted_indices) 95 | top_p_logits = self.set_tensor_by_indices_to_value(logits, indices_to_remove, 96 | np.NINF) 97 | return top_p_logits 98 | 99 | def scatter_values_on_batch_indices(self, values, batch_indices): 100 | """Scatter `values` into a tensor using `batch_indices`. 
101 | Args: 102 | values: tensor of shape [batch_size, vocab_size] containing the values to 103 | scatter 104 | batch_indices: tensor of shape [batch_size, vocab_size] containing the 105 | indices to insert (should be a permutation in range(0, n)) 106 | Returns: 107 | Tensor of shape [batch_size, vocab_size] with values inserted at 108 | batch_indices 109 | """ 110 | tensor_shape = self.get_shape_list(batch_indices) 111 | broad_casted_batch_dims = tf.reshape( 112 | tf.broadcast_to( 113 | tf.expand_dims(tf.range(tensor_shape[0]), axis=-1), tensor_shape), 114 | [1, -1]) 115 | pair_indices = tf.transpose( 116 | tf.concat([broad_casted_batch_dims, 117 | tf.reshape(batch_indices, [1, -1])], 0)) 118 | return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), tensor_shape) 119 | 120 | def set_tensor_by_indices_to_value(self, input_tensor, indices, value): 121 | """Where indices is True, set the value in input_tensor to value. 122 | Args: 123 | input_tensor: float (batch_size, dim) 124 | indices: bool (batch_size, dim) 125 | value: float scalar 126 | Returns: 127 | output_tensor: same shape as input_tensor. 128 | """ 129 | value_tensor = tf.zeros_like(input_tensor) + value 130 | output_tensor = tf.where(indices, value_tensor, input_tensor) 131 | return output_tensor 132 | 133 | def __call__(self, logits): 134 | # print("input logits", logits) 135 | logits = self.sample_logits_with_temperature(logits) 136 | # print("input logits tempe", logits) 137 | if self.decoder_choice =="greedy": 138 | filtered_logits = self.sample_top_k(logits) 139 | elif self.decoder_choice =="top_k": 140 | filtered_logits = self.sample_top_k(logits) 141 | elif self.decoder_choice =="top_p": 142 | filtered_logits = self.sample_top_p(logits) 143 | # print("after filtering", filtered_logits) 144 | # filtered_logits = tf.nn.softmax(filtered_logits, axis = -1) 145 | sampled_logits = tf.random.categorical(filtered_logits, dtype=tf.int64, num_samples=1) 146 | # sampled_logits = tf.cast(sampled_logits, tf.int64) 147 | return sampled_logits 148 | 149 | class Sampler_torch(): 150 | def __init__(self, decoder_choice, temperature=1.0, top_k=None, top_p=None): 151 | self.decoder_choice = decoder_choice 152 | self.temperature = temperature 153 | self.top_k = top_k 154 | if decoder_choice =="greedy": 155 | self.top_k = 1 156 | self.top_p = top_p 157 | 158 | def sample_logits_with_temperature(self, logits): 159 | return logits / self.temperature 160 | 161 | def sample_top_k(self,logits): 162 | indices_to_remove = logits < torch.topk(logits, self.top_k)[0][..., -1, None] 163 | logits[indices_to_remove] = -float('Inf') 164 | return logits 165 | 166 | def sample_top_p(self, logits): 167 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 168 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 169 | 170 | # Remove tokens with cumulative probability above the threshold 171 | sorted_indices_to_remove = cumulative_probs > self.top_p 172 | # Shift the indices to the right to keep also the first token above the threshold 173 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 174 | sorted_indices_to_remove[..., 0] = 0 175 | 176 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 177 | logits[:, indices_to_remove] = -float('Inf') 178 | 179 | return logits 180 | 181 | def __call__(self, logits): 182 | # print("input logits", logits) 183 | logits = self.sample_logits_with_temperature(logits) 184 | # print("input logits tempe", logits) 185 | if self.decoder_choice 
=="greedy": 186 | filtered_logits = self.sample_top_k(logits) 187 | elif self.decoder_choice =="top_k": 188 | filtered_logits = self.sample_top_k(logits) 189 | elif self.decoder_choice =="top_p": 190 | filtered_logits = self.sample_top_p(logits) 191 | probabilities = F.softmax(filtered_logits, dim=-1) 192 | 193 | sampled_logits = torch.multinomial(probabilities, 1) 194 | return sampled_logits 195 | 196 | def write_midi_monophonic(note_list, chord_list, mid_name, if_cosiatec = False): 197 | MyMIDI2 = MIDIFile(numTracks=1,ticks_per_quarternote=220) #track0 is melody, track1 is chord 198 | 199 | cum_time = float(0) 200 | for pos, (note, dur) in enumerate(note_list): 201 | if note=="sustain": 202 | continue 203 | else: 204 | #check future pos whether sustain 205 | next_pos = pos + 1 206 | if next_pos<=len(note_list)-1: 207 | note_next, dur_next = note_list[next_pos] 208 | while(note_next == "sustain" and next_pos<=len(note_list)-1): 209 | dur += dur_next 210 | next_pos+=1 211 | try: 212 | note_next, dur_next = note_list[next_pos] 213 | except: 214 | note_next = "break the loop" 215 | dur_next = "break the loop" 216 | 217 | if note=="rest": 218 | cum_time = cum_time+dur 219 | else: 220 | MyMIDI2.addNote(track = 0, channel = 0, pitch = pretty_midi.note_name_to_number(note), time = cum_time, duration = dur, volume = 100) 221 | cum_time = cum_time+dur 222 | 223 | with open(mid_name, "wb") as output_file2: 224 | MyMIDI2.writeFile(output_file2) 225 | -------------------------------------------------------------------------------- /generation_uncond.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from data_processing.data_loading import tokenizer_plain, detokenizer_plain 4 | from data_processing.data_loading import MotifDataset 5 | from model.model import RIPO_transformer, loss_function_baseline 6 | from data_processing.utils import Sampler_torch, write_midi_monophonic 7 | from utils.eval_utils import get_rep_seq, get_unique_tokens, get_unique_intervals 8 | 9 | import torch.nn as nn 10 | import torch 11 | import torch.nn.functional as F 12 | from torch.utils.data import DataLoader 13 | 14 | import yaml 15 | import time 16 | from datetime import datetime 17 | import os 18 | import glob 19 | import json 20 | from tqdm import tqdm 21 | import matplotlib.pyplot as plt 22 | class generator(): 23 | def __init__(self, tokenizer, detokenizer, sampler,model): 24 | self.tokenizer = tokenizer 25 | self.detokenizer = detokenizer 26 | self.sampler = sampler 27 | self.model = model 28 | 29 | def cal_sum_dur(self, dur): 30 | dur = dur.clone().cpu().detach().numpy() 31 | dur_decoded = [self.detokenizer.decode_dur[x] for x in dur[0]] 32 | return sum(dur_decoded) 33 | 34 | def cal_cumsum_dur(self, dur): 35 | dur_cpu = dur.clone().cpu().detach().numpy() 36 | dur_decoded = torch.tensor([0.]+[self.detokenizer.decode_dur[x] for x in dur_cpu[0]][:-1]).to(dur.device) 37 | out = torch.cumsum(dur_decoded, dim = -1)[None, ...] 
38 | return out 39 | 40 | def get_rel_mask(self, inp, tpe="pitch"): 41 | if tpe =="pitch": 42 | mask0 = inp==0 43 | mask1 = inp==1 44 | mask2 = inp==2 45 | mask_01= torch.logical_or(mask0, mask1) 46 | mask_012= torch.logical_or(mask_01, mask2) 47 | mask_matrix = torch.logical_or(mask_012[..., None], mask_012[None, ...]) 48 | if tpe =="dur": 49 | mask0 = inp==0 50 | mask_matrix = torch.logical_or(mask0[..., None], mask0[None, ...]) 51 | 52 | return mask_matrix 53 | 54 | def generate(self, inp_dict, target_bar_num): 55 | max_len = 999 56 | sum_dur = self.cal_sum_dur(inp_dict['dur_p']) #batch, len 57 | i = len(inp_dict['pitch'][0])-1 58 | 59 | #model recursive generate 60 | while sum_dur<=target_bar_num*4 and inp_dict['pitch'].shape[-1]<=245: 61 | pred_dict = self.model(inp_dict) 62 | pitch_pred_logits = pred_dict["pitch_pred"][:, -1, :] 63 | dur_pred_logits = pred_dict["dur_pred"][:, -1, :] 64 | pitch_pred = self.sampler(pitch_pred_logits).to(inp_dict['pitch'].device) 65 | dur_pred = self.sampler(dur_pred_logits).to(inp_dict['pitch'].device) 66 | 67 | #get the probabilty 68 | pitch_prob = F.softmax(pitch_pred_logits, dim = -1)[0, pitch_pred[0]] 69 | dur_prob = F.softmax(dur_pred_logits, dim = -1)[0, dur_pred[0]] 70 | 71 | inp_dict["pitch_prob"].append(pitch_prob) 72 | inp_dict["dur_prob"].append(dur_prob) 73 | 74 | inp_dict["pitch"] = torch.cat((inp_dict["pitch"], pitch_pred), dim = -1) 75 | inp_dict["dur_p"] = torch.cat((inp_dict["dur_p"], dur_pred), dim = -1) 76 | sum_dur = self.cal_sum_dur(inp_dict['dur_p']) 77 | 78 | #calculate pitch_rel, dur_rel, pitch_rel_mask, dur_rel_mask, dur_cumsum 79 | inp_dict["pitch_rel"] = inp_dict["pitch"][:, :, None] - inp_dict["pitch"][:, None, :] 80 | inp_dict["dur_rel"] = inp_dict["dur_p"][:, :, None] - inp_dict["dur_p"][:, None, :] 81 | 82 | inp_dict["pitch_rel_mask"] = self.get_rel_mask(inp_dict["pitch"], tpe="pitch") 83 | inp_dict["dur_rel_mask"] = self.get_rel_mask(inp_dict["dur_p"], tpe="dur") 84 | inp_dict["dur_onset_cumsum"] = self.cal_cumsum_dur(inp_dict["dur_p"]) 85 | 86 | #detokenizer output 87 | out_pitch = inp_dict["pitch"].cpu().detach().numpy()[0] 88 | out_dur = inp_dict["dur_p"].cpu().detach().numpy()[0] 89 | out_note = [[self.detokenizer.decode_pitch[p], self.detokenizer.decode_dur[d]] for p,d in zip(out_pitch, out_dur)] 90 | inp_dict["pitch_prob"] = torch.stack(inp_dict["pitch_prob"], dim = -1) 91 | inp_dict["dur_prob"] = torch.stack(inp_dict["dur_prob"], dim = -1) 92 | out_dict = {} 93 | for key, value in inp_dict.items(): 94 | out_dict[key] = inp_dict[key].cpu().detach().numpy()[0].tolist() 95 | return out_note, out_dict 96 | 97 | def trim_seed_music(inp, dur, seed_bar_num, detokenizer): 98 | dur_decoded = [detokenizer.decode_dur[x] for x in dur[0]] #batch, len ==> len 99 | cum_sum = np.cumsum([0]+[x for x in dur_decoded]) 100 | if cum_sum[-1]=seed_bar_num*4)[0][0] 104 | return inp[:, :where_until] 105 | 106 | if __name__ =="__main__": 107 | #config file should contain the folder path where the ckpt is stored 108 | gen_config_path = "config/generation_uncond.yaml" 109 | with open (gen_config_path, 'r') as f: 110 | gen_cfg = yaml.safe_load(f) 111 | ckpt_path = gen_cfg["ckpt_path"] 112 | device = gen_cfg["device"] 113 | seed_bar_num = gen_cfg['seed_bar_num'] 114 | target_bar_num = gen_cfg['target_bar_num'] 115 | 116 | #model config & load model 117 | with open (glob.glob(ckpt_path+"/*.yaml")[0], 'r') as f: 118 | model_cfg = yaml.safe_load(f) 119 | model_cfg['device'] = device 120 | model_cfg['relative_pitch_attention']['device'] = device 
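# override both the top-level and the nested device entries so the GPU recorded in the checkpoint's YAML is ignored in favour of the generation config's device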
121 | print("model config: ",model_cfg) 122 | 123 | model = RIPO_transformer(**model_cfg['relative_pitch_attention']).to(device) 124 | model.device = device 125 | model.pos_encoder.global_time_embedding.device = device 126 | model.pos_encoder.modulo_time_embedding.device = device 127 | model.pos_encoder.modulo_time_embedding.angles = model.pos_encoder.modulo_time_embedding.angles.to(device) 128 | model.pos_encoder.global_time_embedding.angles = model.pos_encoder.global_time_embedding.angles.to(device) 129 | 130 | if model_cfg['relative_pitch_attention']['dur_embedding_conf']['type']=="se": #if FME 131 | model.dur_embedding.device = device 132 | model.dur_embedding.to(device) 133 | model.dur_embedding.angles = model.dur_embedding.angles.to(device) 134 | if model_cfg['relative_pitch_attention']['pitch_embedding_conf']['type']=="se": #if FME 135 | model.pitch_embedding.device = device 136 | model.pitch_embedding.to(device) 137 | model.pitch_embedding.angles = model.pitch_embedding.angles.to(device) 138 | 139 | model.load_state_dict(torch.load(ckpt_path+"/best_loss.pth")) 140 | model.eval() 141 | 142 | #tokenizer & detokenizer 143 | tokenizer = tokenizer_plain() #gen_cfg['tknz'] 144 | detokenizer = detokenizer_plain() 145 | sampler = Sampler_torch(**gen_cfg['sampling']) 146 | generator = generator(tokenizer, detokenizer, sampler, model) 147 | 148 | #define saving path and save generation config 149 | save_dir = os.path.join(ckpt_path, f"results_{gen_cfg['sampling']['decoder_choice']}_{gen_cfg['sampling']['temperature']}_{gen_cfg['sampling']['top_k']}_{gen_cfg['sampling']['top_p']}_{gen_cfg['seed_bar_num']}_{gen_cfg['target_bar_num']}") 150 | print(f"saving to:{save_dir}") 151 | os.makedirs(save_dir, exist_ok=True) 152 | with open(os.path.join(save_dir, "gen_config.yaml"), 'w') as f: 153 | documents = yaml.dump(gen_cfg, f) 154 | 155 | #load data 156 | dataset = MotifDataset(**model_cfg["dataset"]) 157 | train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*0.9), len(dataset) - int(len(dataset)*0.9)], generator=torch.Generator().manual_seed(0)) #using the same seed to prevent data leaking (use valid set which is not seen during training for generation) 158 | print(f"valid:{len(valid_dataset)}") 159 | valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False) 160 | 161 | #generation 162 | for i,data in tqdm(enumerate(valid_dataloader)): 163 | try: 164 | pitch,dur= data[0]['pitch'].numpy(),data[0]['dur_p'].numpy() 165 | pitch_rel, pitch_rel_mask, dur_rel, dur_rel_mask, dur_cumsum = data[0]['pitch_rel'].numpy(), data[0]['pitch_rel_mask'].numpy(), data[0]['dur_rel'].numpy(), data[0]['dur_rel_mask'].numpy(), data[0]['dur_onset_cumsum'].numpy() 166 | pitch_drop_pad, dur_drop_pad = detokenizer.drop_pad_new(lst = pitch), detokenizer.drop_pad_new(lst = dur) 167 | pitch_drop_pad_trim, dur_drop_pad_trim = trim_seed_music(pitch_drop_pad, dur_drop_pad, seed_bar_num, detokenizer),\ 168 | trim_seed_music(dur_drop_pad, dur_drop_pad, seed_bar_num, detokenizer) 169 | 170 | trim_len = len(pitch_drop_pad_trim[0]) 171 | pitch_rel_trim, pitch_rel_mask_trim, dur_rel_trim, dur_rel_mask_trim, dur_cumsum_trim = pitch_rel[:, :trim_len, :trim_len], pitch_rel_mask[:, :trim_len, :trim_len], dur_rel[:, :trim_len, :trim_len], dur_rel_mask[:, :trim_len, :trim_len], dur_cumsum[:, :trim_len] 172 | pitch_drop_pad_trim, dur_drop_pad_trim,pitch_rel_trim, pitch_rel_mask_trim, dur_rel_trim, dur_rel_mask_trim, dur_cumsum_trim = torch.tensor(pitch_drop_pad_trim).to(device),\ 173 | 
torch.tensor(dur_drop_pad_trim).to(device),\ 174 | torch.tensor(pitch_rel_trim).to(device),\ 175 | torch.tensor(pitch_rel_mask_trim).to(device),\ 176 | torch.tensor(dur_rel_trim).to(device),\ 177 | torch.tensor(dur_rel_mask_trim).to(device),\ 178 | torch.tensor(dur_cumsum_trim).to(device) 179 | 180 | inp_dict = {"pitch":pitch_drop_pad_trim,"dur_p":dur_drop_pad_trim, "pitch_rel":pitch_rel_trim, "pitch_rel_mask":pitch_rel_mask_trim, "dur_rel":dur_rel_trim,"dur_rel_mask": dur_rel_mask_trim, "dur_onset_cumsum":dur_cumsum_trim, "pitch_prob":[], "dur_prob":[]} 181 | out_note,out_dict = generator.generate(inp_dict, target_bar_num=target_bar_num) 182 | write_midi_monophonic(out_note, chord_list = [], mid_name = os.path.join(save_dir, f"{str(i)}.mid")) 183 | with open(os.path.join(save_dir, f"{str(i)}.json"), 'w') as f: 184 | json.dump(out_dict, f) 185 | except: 186 | continue 187 | -------------------------------------------------------------------------------- /lib/fast_transformers.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/lib/fast_transformers.zip -------------------------------------------------------------------------------- /model/FME_music_positional_encoding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | class Fundamental_Music_Embedding(nn.Module): 10 | def __init__(self, d_model, base, if_trainable = False, if_translation_bias_trainable = True, device='cpu', type = "se",emb_nn=None,translation_bias_type = "nd"): 11 | super().__init__() 12 | self.d_model = d_model 13 | self.device = device 14 | self.base = base 15 | self.if_trainable = if_trainable #whether the se is trainable 16 | 17 | if translation_bias_type is not None: 18 | self.if_translation_bias = True 19 | self.if_translation_bias_trainable = if_translation_bias_trainable #default the 2d vector is trainable 20 | if translation_bias_type=="2d": 21 | translation_bias = torch.rand((1, 2), dtype = torch.float32) #Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1)[0,1) 22 | elif translation_bias_type=="nd": 23 | translation_bias = torch.rand((1, self.d_model), dtype = torch.float32) 24 | translation_bias = nn.Parameter(translation_bias, requires_grad=True) 25 | self.register_parameter("translation_bias", translation_bias) 26 | else: 27 | self.if_translation_bias = False 28 | 29 | i = torch.arange(d_model) 30 | angle_rates = 1 / torch.pow(self.base, (2 * (i//2)) / d_model) 31 | angle_rates = angle_rates[None, ... ].to(self.device) 32 | 33 | if self.if_trainable: 34 | angles = nn.Parameter(angle_rates, requires_grad=True) 35 | self.register_parameter("angles", angles) 36 | 37 | else: 38 | self.angles = angle_rates 39 | 40 | def transform_by_delta_pos_v1(self, inp, delta_pos): 41 | #outdated version, use block diagonal matrix very inefficient 42 | if inp.dim()==3: 43 | batch, length = int(inp.shape[0]), int(inp.shape[1]) 44 | elif inp.dim()==1: 45 | batch, length = 1, int(inp.shape[0]) 46 | 47 | raw = self.FMS(delta_pos) 48 | 49 | wk_phi_1 = torch.reshape(raw,[batch, length,int(self.d_model/2), 2]) #[d_mod/2, 2] -->batch, len, d_mod/2, 2 50 | wk_phi_1_rev=wk_phi_1*torch.tensor([-1., 1.]).to(self.device)[None, None, None, ...] 
# (batch, len, d_mod/2, 2) * (1, 1, 1, 2) 51 | wk_phi_2 = torch.flip(wk_phi_1, dims = [-1]) ##[d_mod/2, 2] --># (batch, len, d_mod/2, 2) 52 | 53 | wk_phi1_2 = torch.cat((wk_phi_2, wk_phi_1_rev), axis = -1) #[dmod/2, 4] # (batch, len, d_mod/2, 4) 54 | wk_phi1_2_rehsaped = torch.reshape(wk_phi1_2, [batch*length*int(self.d_model/2), 2, 2]) #[dmod/2, 2, 2] --># (batch, len, d_mod/2, 2, 2) we want -->1*3*4*4 55 | 56 | transformation_matrix = torch.block_diag(*wk_phi1_2_rehsaped) 57 | out = torch.matmul(transformation_matrix, torch.reshape(inp, (batch*length*self.d_model, 1)))[:,0] 58 | out = torch.reshape(out, (length, self.d_model)) 59 | return out 60 | 61 | def transform_by_delta_pos_v2(self, inp, delta_pos): 62 | #fast version, no need to use block diagonal matrix 63 | #transpose one token to another in the embedding space 64 | if inp.dim()==3: 65 | batch, length = int(inp.shape[0]), int(inp.shape[1]) 66 | elif inp.dim()==1: 67 | batch, length = 1, int(inp.shape[0]) 68 | 69 | raw = self.FMS(delta_pos) 70 | wk_phi_1 = torch.reshape(raw,[batch*length*int(self.d_model/2), 2]) #[d_mod/2, 2] -->batch* len* d_mod/2, 2 71 | wk_phi_1_rev=wk_phi_1*torch.tensor([-1., 1.]).to(self.device)[None, ...] # (batch*len*d_mod/2, 2) * (1, 2) 72 | wk_phi_2 = torch.flip(wk_phi_1, dims = [-1]) ##[d_mod/2, 2] --># (batch*len*d_mod/2, 2) 73 | 74 | wk_phi1_2 = torch.cat((wk_phi_2, wk_phi_1_rev), axis = -1) #[dmod/2, 4] # (batch* len* d_mod/2, 4) 75 | wk_phi1_2_rehsaped = torch.reshape(wk_phi1_2, [batch*length*int(self.d_model/2), 2, 2]) #[dmod/2, 2, 2] --># (batch* len*d_mod/2, 2, 2) we want -->1*3*4*4 76 | transformation_matrix = wk_phi1_2_rehsaped 77 | 78 | if self.translation_bias is not None: 79 | inp -= self.translation_bias[:, None, :] 80 | 81 | reshaped = torch.reshape(inp, (batch*length*int(self.d_model/2), 2,1)) 82 | out = torch.matmul(transformation_matrix, 83 | reshaped) #(batch* len*d_mod/2, 2, 2) * (batch*len*d_mod, 1, 2) 84 | 85 | out = torch.reshape(out, (batch, length, self.d_model)) 86 | if self.translation_bias is not None: 87 | out += self.translation_bias[:, None, :] 88 | return out 89 | 90 | 91 | def __call__(self, inp): 92 | if inp.dim()==2: 93 | inp = inp[..., None] #pos (batch, num_pitch, 1) 94 | elif inp.dim()==1: 95 | inp = inp[None, ..., None] #pos (1, num_pitch, 1) 96 | angle_rads = inp*self.angles #(batch, num_pitch)*(1,dim) 97 | 98 | # apply sin to even indices in the array; 2i 99 | angle_rads[:, :, 0::2] = torch.sin(angle_rads.clone()[:, : , 0::2]) 100 | 101 | # apply cos to odd indices in the array; 2i+1 102 | angle_rads[:, :, 1::2] = torch.cos(angle_rads.clone()[:, :, 1::2]) 103 | 104 | pos_encoding = angle_rads.to(torch.float32) 105 | if self.if_translation_bias: 106 | if self.translation_bias.size()[-1]!= self.d_model: 107 | translation_bias = self.translation_bias.repeat(1, 1,int(self.d_model/2)) 108 | else: 109 | translation_bias = self.translation_bias 110 | pos_encoding += translation_bias 111 | else: 112 | self.translation_bias = None 113 | return pos_encoding 114 | 115 | def FMS(self, delta_pos): 116 | if delta_pos.dim()==1: 117 | delta_pos = delta_pos[None, ..., None] # len ==> batch, len 118 | if delta_pos.dim()==2: 119 | delta_pos = delta_pos[ ..., None] # batch, len ==> batch, len, 1 120 | if delta_pos.dim()==3: 121 | b_size = delta_pos.shape[0] 122 | len_q = delta_pos.shape[1] 123 | len_k = delta_pos.shape[2] 124 | delta_pos = delta_pos.reshape((b_size, len_q*len_k, 1))# batch, len, len ==> batch, len*len, 1 125 | 126 | raw = delta_pos*self.angles 127 | raw[:, :, 0::2] = 
torch.sin(raw.clone()[:, :, 0::2]) 128 | raw[:,:,1::2] = torch.cos(raw.clone()[:,:,1::2]) 129 | 130 | if delta_pos.dim()==3: 131 | raw = raw.reshape((b_size, len_q, len_k, -1))# batch, len, len ==> batch, len*len, 1 132 | return raw.to(torch.float32).to(self.device) 133 | 134 | def decode(self, embedded): 135 | if self.translation_bias is not None: 136 | embedded -= self.translation_bias[:, None, :] 137 | 138 | decoded_dim = (torch.asin(embedded)/self.angles[:, None, :]).to(torch.float32) 139 | if self.d_model/2 %2 == 0: 140 | decoded = decoded_dim[:, :, int(self.d_model/2)] 141 | 142 | elif self.d_model/2 %2 == 1: 143 | decoded = decoded_dim[:, :, int(self.d_model/2+1)] 144 | 145 | return decoded 146 | 147 | def decode_tps(self, embedded): 148 | decoded_dim = (torch.asin(embedded)/self.angles[:, None,None, :]).to(torch.float32) 149 | if self.d_model/2 %2 == 0: 150 | decoded = decoded_dim[:, :, :, int(self.d_model/2)] 151 | 152 | elif self.d_model/2 %2 == 1: 153 | decoded = decoded_dim[:, :, :, int(self.d_model/2+1)] 154 | 155 | return decoded 156 | 157 | class Music_PositionalEncoding(nn.Module): 158 | 159 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, if_index = True, if_global_timing = True, if_modulo_timing = True, device = 'cuda:0'): 160 | super().__init__() 161 | self.if_index = if_index 162 | self.if_global_timing = if_global_timing 163 | self.if_modulo_timing = if_modulo_timing 164 | self.dropout = nn.Dropout(p=dropout) 165 | self.index_embedding = Fundamental_Music_Embedding(d_model = d_model, base=10000, device = device, if_trainable=False, translation_bias_type = None, if_translation_bias_trainable = False, type = "se").cuda() 166 | self.global_time_embedding = Fundamental_Music_Embedding(d_model = d_model, base=10001, device = device, if_trainable=False, translation_bias_type = None, if_translation_bias_trainable = False, type = "se").cuda() 167 | self.modulo_time_embedding = Fundamental_Music_Embedding(d_model = d_model, base=10001, device = device, if_trainable=False, translation_bias_type = None, if_translation_bias_trainable = False, type = "se").cuda() 168 | 169 | position = torch.arange(max_len).unsqueeze(1) 170 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 171 | pe = torch.zeros(max_len, 1, d_model) 172 | pe[:, 0, 0::2] = torch.sin(position * div_term) 173 | pe[:, 0, 1::2] = torch.cos(position * div_term) 174 | self.register_buffer('pe', pe) 175 | if self.if_global_timing: 176 | print("pe add global time") 177 | if self.if_modulo_timing: 178 | print("pe add modulo time") 179 | if self.if_index: 180 | print("pe add idx") 181 | def forward(self, inp,dur_onset_cumsum = None): 182 | 183 | if self.if_index: 184 | pe_index = self.pe[:inp.size(1)] #[seq_len, batch_size, embedding_dim] 185 | pe_index = torch.swapaxes(pe_index, 0, 1) #[batch_size, seq_len, embedding_dim] 186 | inp += pe_index 187 | 188 | if self.if_global_timing: 189 | global_timing = dur_onset_cumsum 190 | global_timing_embedding = self.global_time_embedding(global_timing) 191 | inp += global_timing_embedding 192 | 193 | if self.if_modulo_timing: 194 | modulo_timing = dur_onset_cumsum%4 195 | modulo_timing_embedding = self.modulo_time_embedding(modulo_timing) 196 | inp += modulo_timing_embedding 197 | return self.dropout(inp) 198 | 199 | class PositionalEncoding(nn.Module): 200 | 201 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 202 | super().__init__() 203 | self.dropout = nn.Dropout(p=dropout) 204 | 205 | 
position = torch.arange(max_len).unsqueeze(1) 206 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 207 | pe = torch.zeros(max_len, 1, d_model) 208 | pe[:, 0, 0::2] = torch.sin(position * div_term) 209 | pe[:, 0, 1::2] = torch.cos(position * div_term) 210 | self.register_buffer('pe', pe) 211 | 212 | def forward(self, x): 213 | pos = self.pe[:x.size(1)] #[seq_len, batch_size, embedding_dim] 214 | pos = torch.swapaxes(pos, 0, 1) #[batch_size, seq_len, embedding_dim] 215 | x = x + pos 216 | return self.dropout(x) 217 | 218 | def l2_norm(a, b): 219 | return torch.linalg.norm(a-b, ord = 2, dim = -1) 220 | 221 | def rounding(x): 222 | return x-torch.sin(2.*math.pi*x)/(2.*math.pi) -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/model/__init__.py -------------------------------------------------------------------------------- /model/__pycache__/FME_music_positional_encoding.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/model/__pycache__/FME_music_positional_encoding.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/FME_music_positional_encoding.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/model/__pycache__/FME_music_positional_encoding.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/model/__pycache__/layers.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/layers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/model/__pycache__/layers.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/model/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/model/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, Any, Union, Callable 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | import torch.nn.functional as F 7 | from torch.nn import ModuleList, Dropout, Linear, LayerNorm #MultiheadAttention, 8 | import math 9 | from torch.nn.init import xavier_uniform_ 10 | 11 | class TransformerEncoder(nn.Module): 12 | r"""TransformerEncoder is a stack of N encoder layers 13 | Args: 14 | encoder_layer: an instance of the TransformerEncoderLayer() class (required). 15 | num_layers: the number of sub-encoder-layers in the encoder (required). 16 | norm: the layer normalization component (optional). 17 | Examples:: 18 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 19 | >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) 20 | >>> src = torch.rand(10, 32, 512) 21 | >>> out = transformer_encoder(src) 22 | """ 23 | __constants__ = ['norm'] 24 | 25 | def __init__(self, encoder_layer, num_layers, norm=None): 26 | super(TransformerEncoder, self).__init__() 27 | self.layers = _get_clones(encoder_layer, num_layers) 28 | self.num_layers = num_layers 29 | self.norm = norm 30 | 31 | def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pitch_rel=None,pitch_rel_mask=None, dur_rel=None, dur_rel_mask=None) -> Tensor: 32 | r"""Pass the input through the encoder layers in turn. 33 | Args: 34 | src: the sequence to the encoder (required). 35 | mask: the mask for the src sequence (optional). 36 | src_key_padding_mask: the mask for the src keys per batch (optional). 37 | Shape: 38 | see the docs in Transformer class. 39 | """ 40 | output = src 41 | attention_weights = [] 42 | # print(f"check src:{src.shape}, pitch_rel:{pitch_rel.shape},pitch_rel_mask:{pitch_rel_mask.shape}, dur_rel:{dur_rel.shape}, dur_rel_mask:{dur_rel_mask.shape} ") 43 | 44 | for mod in self.layers: 45 | # output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) #I changed here 46 | output, attn_weight = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pitch_rel=pitch_rel,pitch_rel_mask=pitch_rel_mask, dur_rel=dur_rel, dur_rel_mask=dur_rel_mask) 47 | attention_weights.append(attn_weight) 48 | # print(f"attn_weight:{attn_weight.shape}, output:{output.shape}") 49 | attention_weights_cat = torch.cat(attention_weights, dim = -1) 50 | # print(f"attention_weights_cat:{attention_weights_cat.shape}") 51 | if self.norm is not None: #PENDING CHANGE, SHOULD BE NONE 52 | output = self.norm(output) 53 | 54 | return output, attention_weights_cat 55 | 56 | class TransformerDecoder(nn.Module): 57 | r"""TransformerDecoder is a stack of N decoder layers 58 | Args: 59 | decoder_layer: an instance of the TransformerDecoderLayer() class (required). 
60 | num_layers: the number of sub-decoder-layers in the decoder (required). 61 | norm: the layer normalization component (optional). 62 | Examples:: 63 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 64 | >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) 65 | >>> memory = torch.rand(10, 32, 512) 66 | >>> tgt = torch.rand(20, 32, 512) 67 | >>> out = transformer_decoder(tgt, memory) 68 | """ 69 | __constants__ = ['norm'] 70 | 71 | def __init__(self, decoder_layer, num_layers, norm=None): 72 | super(TransformerDecoder, self).__init__() 73 | self.layers = _get_clones(decoder_layer, num_layers) 74 | self.num_layers = num_layers 75 | self.norm = norm 76 | 77 | def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, 78 | memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, 79 | memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: 80 | r"""Pass the inputs (and mask) through the decoder layer in turn. 81 | Args: 82 | tgt: the sequence to the decoder (required). 83 | memory: the sequence from the last layer of the encoder (required). 84 | tgt_mask: the mask for the tgt sequence (optional). 85 | memory_mask: the mask for the memory sequence (optional). 86 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 87 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 88 | Shape: 89 | see the docs in Transformer class. 90 | """ 91 | output = tgt 92 | #pending change! lack embedding & pos encoding &dropout! 93 | 94 | for mod in self.layers: 95 | output = mod(output, memory, tgt_mask=tgt_mask, 96 | memory_mask=memory_mask, 97 | tgt_key_padding_mask=tgt_key_padding_mask, 98 | memory_key_padding_mask=memory_key_padding_mask) 99 | 100 | if self.norm is not None: 101 | output = self.norm(output) 102 | 103 | return output 104 | 105 | class TransformerEncoderLayer(nn.Module): 106 | r"""TransformerEncoderLayer is made up of self-attn and feedforward network. 107 | This standard encoder layer is based on the paper "Attention Is All You Need". 108 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 109 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 110 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 111 | in a different way during application. 112 | Args: 113 | d_model: the number of expected features in the input (required). 114 | nhead: the number of heads in the multiheadattention models (required). 115 | dim_feedforward: the dimension of the feedforward network model (default=2048). 116 | dropout: the dropout value (default=0.1). 117 | activation: the activation function of the intermediate layer, can be a string 118 | ("relu" or "gelu") or a unary callable. Default: relu 119 | layer_norm_eps: the eps value in layer normalization components (default=1e-5). 120 | batch_first: If ``True``, then the input and output tensors are provided 121 | as (batch, seq, feature). Default: ``False`` (seq, batch, feature). 122 | norm_first: if ``True``, layer norm is done prior to attention and feedforward 123 | operations, respectivaly. Otherwise it's done after. Default: ``False`` (after). 
124 | Examples:: 125 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 126 | >>> src = torch.rand(10, 32, 512) 127 | >>> out = encoder_layer(src) 128 | Alternatively, when ``batch_first`` is ``True``: 129 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) 130 | >>> src = torch.rand(32, 10, 512) 131 | >>> out = encoder_layer(src) 132 | """ 133 | __constants__ = ['batch_first', 'norm_first'] 134 | 135 | def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, 136 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 137 | layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, 138 | device=None, dtype=None) -> None: 139 | factory_kwargs = {'device': device, 'dtype': dtype} 140 | super(TransformerEncoderLayer, self).__init__() 141 | # self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, 142 | # **factory_kwargs) #why does this need drop out? pending change 143 | # self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first) #why does this need drop out? pending change 144 | self.self_attn = RelativeGlobalAttention_provided(d_model, nhead, dropout=dropout, batch_first=batch_first) 145 | # Implementation of Feedforward model 146 | self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) 147 | self.dropout = Dropout(dropout) 148 | self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) 149 | 150 | self.norm_first = norm_first 151 | self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 152 | self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 153 | self.dropout1 = Dropout(dropout) 154 | self.dropout2 = Dropout(dropout) 155 | self.dropout3 = Dropout(dropout) 156 | 157 | # Legacy string support for activation function. 158 | if isinstance(activation, str): 159 | self.activation = _get_activation_fn(activation) 160 | else: 161 | self.activation = activation 162 | 163 | def __setstate__(self, state): 164 | if 'activation' not in state: 165 | state['activation'] = F.relu 166 | super(TransformerEncoderLayer, self).__setstate__(state) 167 | 168 | def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: 169 | r"""Pass the input through the encoder layer. 170 | Args: 171 | src: the sequence to the encoder layer (required). 172 | src_mask: the mask for the src sequence (optional). 173 | src_key_padding_mask: the mask for the src keys per batch (optional). 174 | Shape: 175 | see the docs in Transformer class. 176 | """ 177 | 178 | # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf 179 | 180 | x = src 181 | 182 | attn_logits, attn_weights = self._sa_block(self.norm1(x), attn_mask = src_mask, key_padding_mask = src_key_padding_mask) 183 | if self.norm_first: 184 | x = x + attn_logits 185 | x = x + self._ff_block(self.norm2(x)) 186 | else: #choose this as ori! pending change 187 | x = self.norm1(x + attn_logits) 188 | x = self.norm2(x + self._ff_block(x)) 189 | 190 | return x, attn_weights # I changed to return weight too! 
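# Note: in both branches above, attention is computed on self.norm1(x), and the attention weights are returned alongside the output so that TransformerEncoder can collect and concatenate them across layers.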
191 | 
192 |     # self-attention block
193 |     def _sa_block(self, x: Tensor,
194 |                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
195 |         # x = self.self_attn(x, x, x,
196 |         #                    attn_mask=attn_mask,
197 |         #                    key_padding_mask=key_padding_mask,
198 |         #                    need_weights=False)[0]
199 |         x, attn_weight = self.self_attn(x, x, x,
200 |                                         attn_mask=attn_mask,
201 |                                         key_padding_mask=key_padding_mask)
202 |         return self.dropout1(x), self.dropout3(attn_weight)
203 | 
204 |     # feed forward block
205 |     def _ff_block(self, x: Tensor) -> Tensor:
206 |         x = self.linear2(self.dropout(self.activation(self.linear1(x))))
207 |         return self.dropout2(x)
208 | 
209 | class TransformerDecoderLayer(nn.Module):
210 |     r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
211 |     This standard decoder layer is based on the paper "Attention Is All You Need".
212 |     Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
213 |     Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
214 |     Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
215 |     in a different way during application.
216 |     Args:
217 |         d_model: the number of expected features in the input (required).
218 |         nhead: the number of heads in the multiheadattention models (required).
219 |         dim_feedforward: the dimension of the feedforward network model (default=2048).
220 |         dropout: the dropout value (default=0.1).
221 |         activation: the activation function of the intermediate layer, can be a string
222 |             ("relu" or "gelu") or a unary callable. Default: relu
223 |         layer_norm_eps: the eps value in layer normalization components (default=1e-5).
224 |         batch_first: If ``True``, then the input and output tensors are provided
225 |             as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
226 |         norm_first: if ``True``, layer norm is done prior to self attention, multihead
227 |             attention and feedforward operations, respectively. Otherwise it's done after.
228 |             Default: ``False`` (after).
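        .. note:: this decoder layer is kept essentially as in stock PyTorch (vanilla
            ``MultiheadAttention``, single-tensor return); ``model/model.py`` only imports the
            encoder-side classes, so the RIPO model does not appear to use it.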
229 | Examples:: 230 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 231 | >>> memory = torch.rand(10, 32, 512) 232 | >>> tgt = torch.rand(20, 32, 512) 233 | >>> out = decoder_layer(tgt, memory) 234 | Alternatively, when ``batch_first`` is ``True``: 235 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True) 236 | >>> memory = torch.rand(32, 10, 512) 237 | >>> tgt = torch.rand(32, 20, 512) 238 | >>> out = decoder_layer(tgt, memory) 239 | """ 240 | __constants__ = ['batch_first', 'norm_first'] 241 | 242 | def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, 243 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 244 | layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, 245 | device=None, dtype=None) -> None: 246 | factory_kwargs = {'device': device, 'dtype': dtype} 247 | super(TransformerDecoderLayer, self).__init__() 248 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, 249 | **factory_kwargs) 250 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, 251 | **factory_kwargs) 252 | # Implementation of Feedforward model 253 | self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) 254 | self.dropout = Dropout(dropout) 255 | self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) 256 | 257 | self.norm_first = norm_first 258 | self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 259 | self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 260 | self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 261 | self.dropout1 = Dropout(dropout) 262 | self.dropout2 = Dropout(dropout) 263 | self.dropout3 = Dropout(dropout) 264 | 265 | # Legacy string support for activation function. 266 | if isinstance(activation, str): 267 | self.activation = _get_activation_fn(activation) 268 | else: 269 | self.activation = activation 270 | 271 | def __setstate__(self, state): 272 | if 'activation' not in state: 273 | state['activation'] = F.relu 274 | super(TransformerDecoderLayer, self).__setstate__(state) 275 | 276 | def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, 277 | tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: 278 | r"""Pass the inputs (and mask) through the decoder layer. 279 | Args: 280 | tgt: the sequence to the decoder layer (required). 281 | memory: the sequence from the last layer of the encoder (required). 282 | tgt_mask: the mask for the tgt sequence (optional). 283 | memory_mask: the mask for the memory sequence (optional). 284 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 285 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 286 | Shape: 287 | see the docs in Transformer class. 288 | """ 289 | # see Fig. 
1 of https://arxiv.org/pdf/2002.04745v1.pdf 290 | 291 | x = tgt 292 | if self.norm_first: 293 | x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask) 294 | x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask) 295 | x = x + self._ff_block(self.norm3(x)) 296 | else: 297 | x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask)) 298 | x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask)) 299 | x = self.norm3(x + self._ff_block(x)) 300 | 301 | return x 302 | 303 | # self-attention block 304 | def _sa_block(self, x: Tensor, 305 | attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: 306 | x = self.self_attn(x, x, x, 307 | attn_mask=attn_mask, 308 | key_padding_mask=key_padding_mask, 309 | need_weights=False)[0] 310 | return self.dropout1(x) 311 | 312 | # multihead attention block 313 | def _mha_block(self, x: Tensor, mem: Tensor, 314 | attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: 315 | x = self.multihead_attn(x, mem, mem, 316 | attn_mask=attn_mask, 317 | key_padding_mask=key_padding_mask, 318 | need_weights=False)[0] 319 | return self.dropout2(x) 320 | 321 | # feed forward block 322 | def _ff_block(self, x: Tensor) -> Tensor: 323 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 324 | return self.dropout3(x) 325 | 326 | class RelativeGlobalAttention_my_version(nn.Module): 327 | 328 | def __init__(self, embed_dim=256, num_heads=4, dropout = 0.2, add_emb=False ,max_seq=245, batch_first=True): 329 | super().__init__() 330 | self.dropout = nn.Dropout(p=dropout) 331 | self.len_k = None 332 | self.max_seq = max_seq 333 | self.h = num_heads 334 | self.d = embed_dim 335 | self.dh = embed_dim // num_heads 336 | self.Wq = torch.nn.Linear(self.d, self.d) 337 | self.Wk = torch.nn.Linear(self.d, self.d) 338 | self.Wv = torch.nn.Linear(self.d, self.d) 339 | self.fc = torch.nn.Linear(embed_dim, embed_dim) 340 | self.additional = add_emb 341 | E_init = torch.rand((self.max_seq, int(self.dh)), dtype = torch.float32) 342 | E = nn.Parameter(E_init, requires_grad=True) 343 | self.register_parameter("E", E) 344 | 345 | def forward(self, q, k, v, attn_mask=None, key_padding_mask=None, need_weights=True, pitch_rel=None,pitch_rel_mask=None, dur_rel=None, dur_rel_mask=None): 346 | """ 347 | :param inputs: a list of tensors. i.e) [Q, K, V] 348 | :param mask: mask tensor 349 | :param kwargs: 350 | :return: final tensor ( output of attention ) 351 | """ 352 | 353 | batch_size, seq_length, embed_dim = q.size() 354 | 355 | q, k, v = self.Wq(q), self.Wk(k), self.Wv(v) #batch, seq, dim 356 | 357 | q, k, v = q.reshape(batch_size, seq_length, self.h, self.dh).permute(0, 2, 1, 3),\ 358 | k.reshape(batch_size, seq_length, self.h, self.dh).permute(0, 2, 1, 3),\ 359 | v.reshape(batch_size, seq_length, self.h, self.dh).permute(0, 2, 1, 3) 360 | 361 | self.len_k = k.size(2) 362 | self.len_q = q.size(2) 363 | 364 | E = self.E[self.max_seq-q.shape[-2]:, :] 365 | 366 | QE = torch.matmul(q, E.permute(1, 0))#batch, n_head, len_q, dim * len_k, dim ==> batch, n_head, len_q, len_k 367 | 368 | QE = self._qe_masking(QE) 369 | 370 | Srel = self._skewing(QE) 371 | 372 | Kt = k.permute(0, 1, 3, 2) 373 | QKt = torch.matmul(q, Kt) 374 | logits = QKt + Srel 375 | logits = logits / math.sqrt(self.dh) 376 | 377 | if attn_mask is not None or key_padding_mask is not None: 378 | attn_mask = attn_mask[None, None, ...] 
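            # Mask broadcasting in this block: attn_mask (len_q, len_k) is expanded above to
            # (1, 1, len_q, len_k), and key_padding_mask (batch, len) is expanded below to
            # (batch, 1, len, 1); their logical_or therefore flags a logit whenever its query row
            # is a padded position or the look-ahead mask forbids that (query, key) pair, and the
            # flagged logits are filled with -9e15, which behaves like -inf under the softmax.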
379 | key_padding_mask = key_padding_mask[:, None, :, None] 380 | mask = torch.logical_or(key_padding_mask, attn_mask!=0) 381 | logits = logits.masked_fill(mask , -9e15) 382 | 383 | attention_weights = F.softmax(logits, -1) 384 | attention = torch.matmul(attention_weights, v) 385 | 386 | out = attention.permute(0, 2, 1, 3) 387 | out = torch.reshape(out, (out.size(0), -1, self.d)) 388 | 389 | out = self.fc(out) 390 | out = self.dropout(out) 391 | return out, attention_weights 392 | 393 | 394 | def _skewing(self, tensor: torch.Tensor): 395 | padded = F.pad(tensor, [1, 0, 0, 0, 0, 0, 0, 0]) 396 | reshaped = torch.reshape(padded, shape=[padded.size(0), padded.size(1), padded.size(-1), padded.size(-2)]) 397 | Srel = reshaped[:, :, 1:, :] 398 | if self.len_k > self.len_q: 399 | Srel = F.pad(Srel, [0, 0, 0, 0, 0, 0, 0, self.len_k-self.len_q]) 400 | elif self.len_k < self.len_q: 401 | Srel = Srel[:, :, :, :self.len_k] 402 | 403 | return Srel 404 | 405 | @staticmethod 406 | def _qe_masking(qe): 407 | mask = sequence_mask( 408 | torch.arange(qe.size()[-1] - 1, qe.size()[-1] - qe.size()[-2] - 1, -1).to(qe.device), 409 | qe.size()[-1]) 410 | mask = ~mask.to(mask.device) 411 | return mask.to(qe.dtype) * qe 412 | 413 | class MultiheadAttention_myversion(nn.Module): 414 | 415 | def __init__(self, embed_dim=256, num_heads=4, dropout = 0.2, add_emb=False ,max_seq=245, batch_first=True): 416 | super().__init__() 417 | print(f"using vanilla transformer with multi-head attn!") 418 | 419 | self.dropout = nn.Dropout(p=dropout) 420 | self.len_k = None 421 | self.max_seq = max_seq 422 | self.h = num_heads 423 | self.d = embed_dim 424 | self.dh = embed_dim // num_heads 425 | self.Wq = torch.nn.Linear(self.d, self.d) 426 | self.Wk = torch.nn.Linear(self.d, self.d) 427 | self.Wv = torch.nn.Linear(self.d, self.d) 428 | self.fc = torch.nn.Linear(embed_dim, embed_dim) 429 | self.additional = add_emb 430 | 431 | def forward(self, q, k, v, attn_mask=None, key_padding_mask=None, need_weights=True, pitch_rel=None,pitch_rel_mask=None, dur_rel=None, dur_rel_mask=None): 432 | """ 433 | :param inputs: a list of tensors. i.e) [Q, K, V] 434 | :param mask: mask tensor 435 | :param kwargs: 436 | :return: final tensor ( output of attention ) 437 | """ 438 | 439 | batch_size, seq_length, embed_dim = q.size() 440 | 441 | q, k, v = self.Wq(q), self.Wk(k), self.Wv(v) #batch, seq, dim 442 | 443 | q, k, v = q.reshape(batch_size, seq_length, self.h, self.dh).permute(0, 2, 1, 3),\ 444 | k.reshape(batch_size, seq_length, self.h, self.dh).permute(0, 2, 1, 3),\ 445 | v.reshape(batch_size, seq_length, self.h, self.dh).permute(0, 2, 1, 3) 446 | 447 | self.len_k = k.size(2) 448 | self.len_q = q.size(2) 449 | 450 | Kt = k.permute(0, 1, 3, 2) 451 | QKt = torch.matmul(q, Kt) 452 | logits = QKt 453 | logits = logits / math.sqrt(self.dh) 454 | 455 | 456 | if attn_mask is not None or key_padding_mask is not None: 457 | attn_mask = attn_mask[None, None, ...] 
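            # Same masking convention as in RelativeGlobalAttention_my_version above: padded query
            # rows and look-ahead-masked (query, key) pairs are combined with logical_or and
            # filled with -9e15 before the softmax.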
458 | key_padding_mask = key_padding_mask[:, None, :, None] 459 | mask = torch.logical_or(key_padding_mask, attn_mask!=0) 460 | logits = logits.masked_fill(mask , -9e15) 461 | 462 | attention_weights = F.softmax(logits, -1) 463 | attention = torch.matmul(attention_weights, v) 464 | 465 | out = attention.permute(0, 2, 1, 3) 466 | out = torch.reshape(out, (out.size(0), -1, self.d)) 467 | 468 | out = self.fc(out) 469 | out = self.dropout(out) 470 | return out, attention_weights 471 | 472 | class RelativeGlobalAttention_relative_index_pitch_onset(nn.Module): 473 | def __init__(self, embed_dim=256, num_heads=4, dropout = 0.2, add_emb=False ,max_seq=246, if_add_relative_pitch=True , if_add_relative_duration=True, if_add_relative_idx = False,if_add_relative_idx_no_mask = False, batch_first=True): 474 | super().__init__() 475 | self.dropout = nn.Dropout(p=dropout) 476 | self.len_k = None 477 | self.max_seq = max_seq 478 | self.h = num_heads 479 | self.d = embed_dim 480 | self.dh = embed_dim // num_heads 481 | self.Wq = torch.nn.Linear(self.d, self.d) 482 | self.Wk = torch.nn.Linear(self.d, self.d) 483 | self.Wv = torch.nn.Linear(self.d, self.d) 484 | self.fc = torch.nn.Linear(embed_dim, embed_dim) 485 | self.additional = add_emb 486 | self.if_add_relative_pitch=if_add_relative_pitch 487 | self.if_add_relative_duration = if_add_relative_duration 488 | self.if_add_relative_idx = if_add_relative_idx 489 | self.if_add_relative_idx_no_mask = if_add_relative_idx_no_mask 490 | if self.if_add_relative_idx: 491 | E_init = torch.rand((self.max_seq, int(self.dh)), dtype = torch.float32) 492 | E = nn.Parameter(E_init, requires_grad=True) 493 | self.register_parameter("E", E) 494 | if self.if_add_relative_idx_no_mask: 495 | E_init = torch.rand((2*self.max_seq-1, int(self.dh)), dtype = torch.float32) 496 | E = nn.Parameter(E_init, requires_grad=True) 497 | self.register_parameter("E", E) 498 | 499 | self.pitch_relnn = torch.nn.Linear(self.d, self.d) 500 | self.dur_relnn = torch.nn.Linear(self.d, self.d) 501 | 502 | def forward(self, q, k, v, attn_mask=None, key_padding_mask=None, need_weights=True, pitch_rel=None,pitch_rel_mask=None, dur_rel=None, dur_rel_mask=None): 503 | batch_size, seq_length, embed_dim = q.size() 504 | q, k, v = self.Wq(q), self.Wk(k), self.Wv(v) #batch, seq, dim 505 | 506 | q, k, v = q.reshape(batch_size, seq_length, self.h, self.dh).permute(0, 2, 1, 3),\ 507 | k.reshape(batch_size, seq_length, self.h, self.dh).permute(0, 2, 1, 3),\ 508 | v.reshape(batch_size, seq_length, self.h, self.dh).permute(0, 2, 1, 3) 509 | batch_size = q.size(0) 510 | self.len_k = k.size(2) 511 | self.len_q = q.size(2) 512 | 513 | Kt = k.permute(0, 1, 3, 2) #batch, head, len, dim --> batch, head, dim, len 514 | QKt = torch.matmul(q, Kt) # batch, head, len_q, dim_q * batch, head, dim_k, len_k 515 | Srel = 0 516 | if self.if_add_relative_idx: 517 | E = self.E[self.max_seq-q.shape[-2]:, :] # #be careful with the self.max_seq! 518 | QE = torch.matmul(q, E.permute(1, 0))#batch, n_head, len_q, dim * len_k, dim ==> batch, n_head, len_q, len_k 519 | QE = self._qe_masking(QE) 520 | Srel_idx = self._skewing(QE) 521 | Srel += Srel_idx 522 | 523 | if self.if_add_relative_idx_no_mask: 524 | 525 | if self.max_seq-q.shape[-2]!=0: 526 | E = self.E[self.max_seq-q.shape[-2]:-self.max_seq+q.shape[-2], :] #be careful with the self.max_seq! 
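                # In the no-mask variant, E holds 2*max_seq-1 rows (one per relative index from
                # -(max_seq-1) to max_seq-1); when the input is shorter than max_seq, the slice
                # above keeps the centred 2*len-1 rows so that _skewing_no_mask still lines up
                # each query with relative indices -(len-1)..(len-1).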
527 | else: 528 | E = self.E #here assume max_len == input_seq_len #original 529 | 530 | QE = torch.matmul(q, E.permute(1, 0))#batch, n_head, len_q, dim * len_k, dim ==> batch, n_head, len_q, len_k 531 | Srel_idx = self._skewing_no_mask(QE) 532 | Srel += Srel_idx 533 | 534 | if self.if_add_relative_pitch==True: 535 | pitch_rel = self.pitch_relnn(pitch_rel) 536 | pitch_rel_perm = pitch_rel.reshape(batch_size,self.len_q, self.len_k, self.h, self.dh).permute(0, 3, 1, 4, 2) ##batch, len_q, len_k, head, dim_rel--> batch, head, len_q, dim, len_k 537 | q_add_dim = q[:, :, :, None, :]# batch, head, len_q, 1, dim_q 538 | Srel_pitch = torch.matmul(q_add_dim, pitch_rel_perm) # batch, head, len_q, 1, dim_q * batch, head, len_q, dim, len_k ==> batch, head, len_q, 1, len_k 539 | Srel_pitch = Srel_pitch[:, :, :, 0, :] 540 | Srel+=Srel_pitch 541 | 542 | if self.if_add_relative_duration==True: 543 | dur_rel = self.dur_relnn(dur_rel) 544 | dur_rel_perm = dur_rel.reshape(batch_size,self.len_q, self.len_k, self.h, self.dh).permute(0, 3, 1, 4, 2) 545 | q_add_dim = q[:, :, :, None, :]# batch, head, len_q, 1, dim_q 546 | Srel_dur = torch.matmul(q_add_dim, dur_rel_perm) # batch, head, len_q, 1, dim_q * batch, head, len_q, dim, len_k ==> batch, head, len_q, 1, len_k 547 | Srel_dur = Srel_dur[:, :, :, 0, :] 548 | Srel+=Srel_dur 549 | 550 | logits = QKt + Srel 551 | logits = logits / math.sqrt(self.dh) 552 | 553 | if self.if_add_relative_idx_no_mask: #dont use look ahead mask 554 | mask = key_padding_mask[:, None, :, None] 555 | logits = logits.masked_fill(mask , -9e15) 556 | elif attn_mask is not None or key_padding_mask is not None: #use both look ahead mask and pad mask 557 | attn_mask = attn_mask[None, None, ...] 558 | key_padding_mask = key_padding_mask[:, None, :, None] 559 | mask = torch.logical_or(key_padding_mask, attn_mask!=0) 560 | logits = logits.masked_fill(mask , -9e15) 561 | 562 | 563 | attention_weights = F.softmax(logits, -1) 564 | attention = torch.matmul(attention_weights, v) 565 | 566 | out = attention.permute(0, 2, 1, 3) 567 | out = torch.reshape(out, (out.size(0), -1, self.d)) 568 | 569 | out = self.fc(out) 570 | out = self.dropout(out) 571 | return out, attention_weights 572 | 573 | def _skewing(self, tensor: torch.Tensor): 574 | padded = F.pad(tensor, [1, 0, 0, 0, 0, 0, 0, 0]) #batch, head, len, len 575 | reshaped = torch.reshape(padded, shape=[padded.size(0), padded.size(1), padded.size(-1), padded.size(-2)]) 576 | Srel = reshaped[:, :, 1:, :] 577 | if self.len_k > self.len_q: 578 | Srel = F.pad(Srel, [0, 0, 0, 0, 0, 0, 0, self.len_k-self.len_q]) 579 | elif self.len_k < self.len_q: 580 | Srel = Srel[:, :, :, :self.len_k] 581 | 582 | return Srel 583 | def _skewing_no_mask(self, tensor: torch.Tensor): 584 | padded = F.pad(tensor, [0, 1, 0, 0, 0, 0, 0, 0]) #batch, head, len, 2*len-1 ==> batch, head, len, 2*len 585 | flattened = torch.reshape(padded, shape=[padded.size(0), padded.size(1), -1]) #==> batch, head, 2*len^2 586 | 587 | zero_pad = torch.zeros(padded.size(0), padded.size(1), padded.size(2)-1).to(padded.device) # batch, head, len-1 588 | flattened_zero_pad = torch.cat((flattened, zero_pad), dim = -1) #batch, head, 2*len^2+len-1 589 | reshaped = torch.reshape(flattened_zero_pad, (padded.size(0), padded.size(1), padded.size(2)+1, 2*padded.size(2)-1)) 590 | 591 | Srel = reshaped[:, :, :padded.size(2), -padded.size(2):] 592 | return Srel 593 | 594 | @staticmethod 595 | def _qe_masking(qe): 596 | mask = sequence_mask( 597 | torch.arange(qe.size()[-1] - 1, qe.size()[-1] - qe.size()[-2] - 1, 
-1).to(qe.device), 598 | qe.size()[-1]) 599 | mask = ~mask.to(mask.device) 600 | return mask.to(qe.dtype) * qe 601 | 602 | class TransformerEncoderLayer_type_selection(nn.Module): 603 | 604 | __constants__ = ['batch_first', 'norm_first'] 605 | 606 | def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, 607 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 608 | layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, 609 | device=None, dtype=None, attention_type= "rgl_rel_pitch_dur", if_add_relative_pitch= True, if_add_relative_duration= True, if_add_relative_idx =False, if_add_relative_idx_no_mask = False) -> None: 610 | factory_kwargs = {'device': device, 'dtype': dtype} 611 | super(TransformerEncoderLayer_type_selection, self).__init__() 612 | self.attention_type = attention_type 613 | if attention_type=="mha": 614 | print("attention type: multihead_vanilla") 615 | self.self_attn = MultiheadAttention_myversion(d_model, nhead, dropout=dropout, batch_first=batch_first) #why does this need drop out? pending change 616 | elif attention_type=="rgl_vanilla": 617 | print("attention type: rgl_vanilla--> music transformer") 618 | self.self_attn = RelativeGlobalAttention_my_version(d_model, nhead, dropout=dropout, batch_first=batch_first) #night safari 619 | elif attention_type=="rgl_rel_pitch_dur": 620 | print("attention type: relative index pitch onset --> RIPO transformer") 621 | self.self_attn = RelativeGlobalAttention_relative_index_pitch_onset(d_model, nhead, dropout=dropout, if_add_relative_pitch = if_add_relative_pitch, if_add_relative_duration = if_add_relative_duration, if_add_relative_idx=if_add_relative_idx,if_add_relative_idx_no_mask = if_add_relative_idx_no_mask, batch_first=batch_first) 622 | 623 | # Implementation of Feedforward model 624 | self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) 625 | self.dropout = Dropout(dropout) 626 | self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) 627 | 628 | self.norm_first = norm_first 629 | self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 630 | self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 631 | self.dropout1 = Dropout(dropout) 632 | self.dropout2 = Dropout(dropout) 633 | self.dropout3 = Dropout(dropout) 634 | 635 | # Legacy string support for activation function. 636 | if isinstance(activation, str): 637 | self.activation = _get_activation_fn(activation) 638 | else: 639 | self.activation = activation 640 | 641 | def __setstate__(self, state): 642 | if 'activation' not in state: 643 | state['activation'] = F.relu 644 | super(TransformerEncoderLayer_type_selection, self).__setstate__(state) 645 | 646 | def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None , pitch_rel=None,pitch_rel_mask=None, dur_rel=None, dur_rel_mask=None) -> Tensor: 647 | r"""Pass the input through the encoder layer. 648 | Args: 649 | src: the sequence to the encoder layer (required). 650 | src_mask: the mask for the src sequence (optional). 651 | src_key_padding_mask: the mask for the src keys per batch (optional). 652 | Shape: 653 | see the docs in Transformer class. 
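            Note: on top of the standard arguments, this variant also accepts ``pitch_rel`` /
                ``pitch_rel_mask`` and ``dur_rel`` / ``dur_rel_mask``, which are forwarded
                unchanged to the selected self-attention module (only the RIPO attention variant
                actually consumes them).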
654 | """ 655 | 656 | 657 | 658 | x = src 659 | 660 | attn_logits, attn_weights = self._sa_block(self.norm1(x), attn_mask = src_mask, key_padding_mask = src_key_padding_mask, pitch_rel=pitch_rel, 661 | pitch_rel_mask=pitch_rel_mask, 662 | dur_rel=dur_rel, 663 | dur_rel_mask=dur_rel_mask) 664 | if self.norm_first: 665 | x = x + attn_logits 666 | x = x + self._ff_block(self.norm2(x)) 667 | else: #choose this as ori! see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf 668 | x = self.norm1(x + attn_logits) 669 | x = self.norm2(x + self._ff_block(x)) 670 | 671 | return x, attn_weights # I changed to return weight too! 672 | 673 | # self-attention block 674 | def _sa_block(self, x: Tensor, 675 | attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], pitch_rel=None,pitch_rel_mask=None, dur_rel=None, dur_rel_mask=None) -> Tensor: 676 | 677 | x, attn_weight = self.self_attn(x, x, x, 678 | attn_mask=attn_mask, 679 | key_padding_mask=key_padding_mask, 680 | pitch_rel=pitch_rel, 681 | pitch_rel_mask=pitch_rel_mask, 682 | dur_rel=dur_rel, 683 | dur_rel_mask=dur_rel_mask) 684 | return self.dropout1(x), self.dropout3(attn_weight) 685 | 686 | # feed forward block 687 | def _ff_block(self, x: Tensor) -> Tensor: 688 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 689 | return self.dropout2(x) 690 | 691 | def scaled_dot_product(q, k, v, attn_mask= None, key_padding_mask=None): 692 | #attn_mask: len, len, pad_mask: batch, len 693 | d_k = q.size()[-1] 694 | attn_logits = torch.matmul(q, k.transpose(-2, -1)) 695 | attn_logits = attn_logits / math.sqrt(d_k) 696 | if attn_mask is not None or key_padding_mask is not None: 697 | attn_mask = attn_mask[None, None, ...] 698 | key_padding_mask = key_padding_mask[:, None, :, None] 699 | mask = torch.logical_or(key_padding_mask, attn_mask!=0) 700 | attn_logits = attn_logits.masked_fill(mask, -9e15) 701 | 702 | attention = F.softmax(attn_logits, dim=-1) 703 | values = torch.matmul(attention, v) 704 | return values, attention 705 | 706 | def _get_clones(module, N): 707 | return ModuleList([copy.deepcopy(module) for i in range(N)]) 708 | 709 | def _get_activation_fn(activation): 710 | if activation == "relu": 711 | return F.relu 712 | elif activation == "gelu": 713 | return F.gelu 714 | 715 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 716 | 717 | def sequence_mask(length, max_length=None): 718 | if max_length is None: 719 | max_length = length.max() 720 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 721 | return x.unsqueeze(0) < length.unsqueeze(1) 722 | 723 | if __name__ == "__main__": 724 | torch.set_printoptions(precision = 2) 725 | max_seq = 10 726 | 727 | rgl = RelativeGlobalAttention_relative_index_pitch_onset(max_seq=max_seq,if_add_relative_pitch=True , if_add_relative_duration=True, if_add_relative_idx = False,if_add_relative_idx_no_mask = True) 728 | E = rgl.E 729 | q = torch.ones(1, 8, max_seq, 64) 730 | q2 = torch.ones(1, 8, max_seq-3, 64) 731 | QE = torch.matmul(q, E.permute(1, 0))#batch, n_head, len_q, dim * len_k, dim ==> batch, n_head, len_q, len_k 732 | print(rgl.max_seq-q2.shape[-2], -rgl.max_seq+q2.shape[-2]) 733 | QE2 = torch.matmul(q2, E[rgl.max_seq-q2.shape[-2]:-rgl.max_seq+q2.shape[-2], :].permute(1, 0))#batch, n_head, len_q, dim * len_k, dim ==> batch, n_head, len_q, len_k 734 | 735 | print(f"E:{E.shape, QE.shape, QE2.shape}QE:{QE[0,0,:, :]} QE2:{QE2[0,0,:, :]}") 736 | Srel_idx = rgl._skewing_no_mask(QE) 737 | Srel_idx2 = rgl._skewing_no_mask(QE2) 738 | 
print(f"check srel{Srel_idx.shape, Srel_idx2.shape}:{Srel_idx[0,0,:, :]},check srel2:{Srel_idx2[0,0,:, :]}") -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from model.layers import TransformerEncoder, TransformerEncoderLayer, TransformerEncoderLayer_type_selection 6 | 7 | from model.FME_music_positional_encoding import PositionalEncoding, Fundamental_Music_Embedding, Music_PositionalEncoding 8 | import math 9 | from fast_transformers.builders import TransformerEncoderBuilder 10 | from fast_transformers.masking import TriangularCausalMask 11 | ##used to be in the baseline model.py 12 | def loss_function_baseline(output, target): 13 | padding_msk = target['pitch']!=0 #batch, len 14 | tfmer_lossp_fn = nn.CrossEntropyLoss(reduction='none') 15 | tfmer_lossd_fn = nn.CrossEntropyLoss(reduction='none') 16 | pitch_loss = tfmer_lossp_fn(torch.swapaxes(output['pitch_pred'], -1, -2), target['pitch'])* padding_msk 17 | dur_loss = tfmer_lossd_fn(torch.swapaxes(output['dur_pred'], -1, -2), target['dur_p']) * padding_msk 18 | loss_dict = {} 19 | loss_dict['pitch_loss'] = torch.sum(pitch_loss)/torch.sum(padding_msk) 20 | loss_dict['dur_loss'] = torch.sum(dur_loss)/torch.sum(padding_msk) 21 | loss_dict['total_loss'] = loss_dict['pitch_loss']+loss_dict['dur_loss'] 22 | return loss_dict 23 | 24 | class FusionLayer_baseline(nn.Module): 25 | def __init__(self, out_dim = 128, pitch_dim =128, dur_dim = 17): 26 | super().__init__() 27 | self.out_dim = out_dim 28 | self.fusion = nn.Linear(pitch_dim+dur_dim, out_dim) 29 | def forward(self, pitch, dur_p): 30 | pitch_dur_tpe_tps = torch.cat((pitch,dur_p), dim = -1) 31 | return self.fusion(pitch_dur_tpe_tps) 32 | 33 | class RIPO_transformer(nn.Module): 34 | def __init__(self,d_model,nhead, dim_feedforward, dropout, nlayers,pitch_embedding_conf, dur_embedding_conf, position_encoding_conf, attention_conf ,pitch_dim=128, dur_dim=17, emb_size = 128,device='cuda:0'): 35 | super().__init__() 36 | self.model_type = 'Transformer' 37 | self.pos_encoder = Music_PositionalEncoding(d_model, dropout, **position_encoding_conf) #night safari 38 | print(f"checking model config:{d_model, nhead, dim_feedforward, dropout, attention_conf}") 39 | 40 | if attention_conf['attention_type']!="linear_transformer": 41 | self.use_linear = False 42 | encoder_layers = TransformerEncoderLayer_type_selection(d_model, nhead, dim_feedforward, dropout, batch_first = True, **attention_conf) #night safari 43 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 44 | else: 45 | self.use_linear = True 46 | self.transformer_encoder = TransformerEncoderBuilder.from_kwargs( 47 | n_layers=nlayers, 48 | n_heads=nhead, 49 | query_dimensions=d_model//nhead, 50 | value_dimensions=d_model//nhead, 51 | feed_forward_dimensions=dim_feedforward, 52 | activation='gelu', 53 | dropout=dropout, 54 | attention_type="causal-linear" #causal-linear 55 | ).get() 56 | 57 | 58 | self.fusion_layer = FusionLayer_baseline(out_dim = d_model, pitch_dim = d_model, dur_dim = d_model) #combines multiple types of input 59 | self.pitch_embedding_conf = pitch_embedding_conf 60 | self.dur_embedding_conf = dur_embedding_conf 61 | 62 | if self.pitch_embedding_conf['type']=="nn": 63 | print("end2end trainable embedding for pitch") 64 | self.pitch_embedding = nn.Embedding(pitch_dim,d_model) 
#pending change for rest and sustain! #changed from embsize to d model 65 | elif self.pitch_embedding_conf['type']=="se": 66 | print("FME for pitch") 67 | self.pitch_embedding = Fundamental_Music_Embedding(**pitch_embedding_conf) 68 | self.pitch_embedding_supplement = nn.Embedding(3,d_model) #pad:0, rest:1, sustain:2 69 | self.relative_pitch_embedding_supplement = nn.Embedding(1,d_model) #hosting non-quantifiable relative pitch (to pad, rest, sustain) 70 | self.pitch_senn = nn.Linear(d_model, d_model) 71 | elif self.pitch_embedding_conf['type']=="nn_pretrain": 72 | print("end2end trainable embedding (pretrained) for pitch") 73 | self.pitch_embedding = nn.Embedding.from_pretrained(torch.from_numpy(np.load(self.pitch_embedding_conf['pretrain_emb_path'])), freeze = self.pitch_embedding_conf['freeze_pretrain']) 74 | 75 | if self.dur_embedding_conf['type']=="nn": 76 | print("end2end trainable embedding for duration") 77 | self.dur_embedding = nn.Embedding(dur_dim, d_model) 78 | elif self.dur_embedding_conf['type']=="se": 79 | print("FME for duration") 80 | self.dur_embedding = Fundamental_Music_Embedding(**dur_embedding_conf) 81 | self.dur_embedding_supplement = nn.Embedding(1,d_model) #pad:0 82 | self.dur_senn = nn.Linear(d_model, d_model) 83 | 84 | elif self.dur_embedding_conf['type']=="nn_pretrain": 85 | print("end2end trainable embedding (pretrained) for duration") 86 | self.dur_embedding = nn.Embedding.from_pretrained(torch.from_numpy(np.load(self.dur_embedding_conf['pretrain_emb_path'])), freeze = self.dur_embedding_conf['freeze_pretrain']) 87 | 88 | self.pitch_ffn = nn.Linear(d_model, pitch_dim) 89 | self.dur_ffn = nn.Linear(d_model, dur_dim) 90 | 91 | self.d_model = d_model 92 | self.device = device 93 | self.init_weights() 94 | 95 | def init_weights(self) -> None: 96 | initrange = 0.1 97 | # self.encoder.weight.data.uniform_(-initrange, initrange) 98 | if self.pitch_embedding_conf['type']=="nn": 99 | self.pitch_embedding.weight.data.uniform_(-initrange, initrange) 100 | elif self.pitch_embedding_conf['type']=="se": 101 | self.pitch_embedding_supplement.weight.data.uniform_(-initrange, initrange) 102 | self.relative_pitch_embedding_supplement.weight.data.uniform_(-initrange, initrange) 103 | 104 | 105 | if self.dur_embedding_conf['type']=="nn": 106 | self.dur_embedding.weight.data.uniform_(-initrange, initrange) 107 | elif self.dur_embedding_conf['type']=="se": 108 | self.dur_embedding_supplement.weight.data.uniform_(-initrange, initrange) 109 | 110 | # self.decoder.bias.data.zero_() 111 | # self.decoder.weight.data.uniform_(-initrange, initrange) 112 | 113 | self.pitch_ffn.bias.data.zero_() 114 | self.pitch_ffn.weight.data.uniform_(-initrange, initrange) 115 | self.dur_ffn.bias.data.zero_() 116 | self.dur_ffn.weight.data.uniform_(-initrange, initrange) 117 | 118 | def get_mask(self, inp, type_ = "lookback"): 119 | #inp shape:(batch, len) 120 | #https://zhuanlan.zhihu.com/p/353365423 121 | #https://andrewpeng.dev/transformer-pytorch/ 122 | length = inp.shape[1] 123 | #lookback & padding 124 | if type_ == "lookback": #additive mask 125 | mask = torch.triu(torch.ones(length, length) * float('-inf'), diagonal=1)#https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html#torch.nn.Transformer 126 | # print("lookback mask", mask) 127 | elif type_ == "pad":#boolean mask 128 | mask = (inp == 0) 129 | # print("pad mask", mask) 130 | return mask.to(self.device) 131 | 132 | def forward(self, inp_dict): 133 | """ 134 | Args: 135 | src: Tensor, shape [seq_len, batch_size] 136 | src_mask: 
Tensor, shape [seq_len, seq_len] 137 | Returns: 138 | output Tensor of shape [seq_len, batch_size, ntoken] 139 | """ 140 | pitch,dur,pitch_rel,pitch_rel_mask, dur_rel, dur_rel_mask, dur_onset_cumsum = inp_dict['pitch'], inp_dict['dur_p'], inp_dict['pitch_rel'], inp_dict['pitch_rel_mask'], inp_dict['dur_rel'], inp_dict['dur_rel_mask'],inp_dict['dur_onset_cumsum'] 141 | 142 | lookback_mask = self.get_mask(pitch, "lookback") 143 | pad_mask = self.get_mask(pitch, "pad") 144 | 145 | #emb pitch and dur 146 | if self.pitch_embedding_conf['type']=="nn": 147 | pitch_enc = self.pitch_embedding(pitch) #batch, len, emb_dim 148 | pitch_rel_enc = None 149 | 150 | elif self.pitch_embedding_conf['type']=="se": 151 | #when token is 012 use supplement embedding 152 | # pitch_enc = torch.where( (pitch==0) | (pitch==1) | (pitch==2), self.pitch_embedding_supplement(pitch), self.pitch_embedding(pitch)) 153 | pitch_sup= torch.where( (pitch==0) | (pitch==1) | (pitch==2), pitch, 0) 154 | 155 | pitch_sup_emb = self.pitch_embedding_supplement(pitch_sup) 156 | pitch_norm_emb = self.pitch_embedding(pitch) 157 | 158 | pitch = pitch[...,None] 159 | pitch_enc = torch.where((pitch==0) | (pitch==1) | (pitch==2), pitch_sup_emb, pitch_norm_emb) 160 | 161 | pitch_rel_enc = self.pitch_embedding.FMS(pitch_rel) #batch, len,len, dim 162 | rel_pitch_sup_emb = self.relative_pitch_embedding_supplement(torch.tensor(0).to(self.device))[None, None, None, :] 163 | pitch_rel_enc = torch.where(pitch_rel_mask[..., None], rel_pitch_sup_emb, pitch_rel_enc) 164 | 165 | if self.pitch_embedding_conf['emb_nn']==True: 166 | pitch_rel_enc = self.pitch_senn(pitch_rel_enc) 167 | 168 | elif self.pitch_embedding_conf['type']=="one_hot": 169 | pitch_enc = F.one_hot(pitch, num_classes = self.d_model).to(torch.float32) #batch, len, emb_dim 170 | pitch_rel_enc = None 171 | 172 | elif self.pitch_embedding_conf['type']=="nn_pretrain": 173 | pitch_enc = self.pitch_embedding(pitch) #batch, len, emb_dim 174 | pitch_rel_enc = None 175 | 176 | if self.dur_embedding_conf['type']=="nn": 177 | dur_enc = self.dur_embedding(dur) #batch, len, emb_dim 178 | dur_rel_enc = None 179 | elif self.dur_embedding_conf['type']=="se": 180 | #when token is 012 use supplement embedding 181 | # dur_enc = torch.where(dur==0, self.dur_embedding_supplement(dur), self.dur_embedding(dur)) 182 | dur_sup = torch.where(dur==0, dur, 0) 183 | dur_sup_emb = self.dur_embedding_supplement(dur_sup) 184 | dur_norm_emb = self.dur_embedding(dur) 185 | dur = dur[..., None] 186 | dur_enc = torch.where(dur==0, dur_sup_emb, dur_norm_emb) 187 | dur_rel_enc = self.dur_embedding.FMS(dur_rel) 188 | if self.dur_embedding_conf['emb_nn']==True: 189 | dur_rel_enc = self.dur_senn(dur_rel_enc) 190 | elif self.dur_embedding_conf['type']=="one_hot": 191 | dur_enc = F.one_hot(dur, num_classes = self.d_model).to(torch.float32) #batch, len, emb_dim 192 | dur_rel_enc = None 193 | if self.dur_embedding_conf['type']=="nn_pretrain": 194 | dur_enc = self.dur_embedding(dur) #batch, len, emb_dim 195 | dur_rel_enc = None 196 | 197 | fused_music_info = self.fusion_layer(pitch_enc, dur_enc)#should include adj as well? 
pending change 198 | fused_music_info = fused_music_info * math.sqrt(self.d_model) 199 | 200 | src = self.pos_encoder(fused_music_info, dur_onset_cumsum) #night safari: change here to add duration cumsum too 201 | 202 | if not self.use_linear: 203 | latent, _ = self.transformer_encoder(src, mask = lookback_mask, src_key_padding_mask = pad_mask, pitch_rel=pitch_rel_enc, dur_rel=dur_rel_enc) 204 | else: 205 | attn_mask = TriangularCausalMask(src.size(1), device=src.device) 206 | latent = self.transformer_encoder(src,attn_mask) # y: b x s x d_model 207 | 208 | #norm pred_using transformer 209 | pitch_pred = self.pitch_ffn(latent) ###-->insert CE loss here 210 | dur_pred = self.dur_ffn(latent) ###-->insert CE loss here 211 | output = {"pitch_pred":pitch_pred, "dur_pred":dur_pred} 212 | return output 213 | -------------------------------------------------------------------------------- /training/config/generation_uncond.yaml: -------------------------------------------------------------------------------- 1 | ckpt_path: "/data/nicolas/Fundamental_Music_Embedding_RIPO_Attention/checkpoints/12_29_2022_20_56_15_baseline" 2 | device: "cuda:2" 3 | seed_bar_num: 2 4 | target_bar_num: 16 5 | sampling: 6 | decoder_choice: "top_k" 7 | temperature: 1.0 8 | top_k: 5 9 | top_p: 0.9 10 | -------------------------------------------------------------------------------- /training/config/music_transformer_event_based.yaml: -------------------------------------------------------------------------------- 1 | experiment: "Music Transformer Event Based" 2 | batch_size: 16 3 | optimizer: "adam" 4 | lr: 0.001 5 | device: &dvce "cuda:3" #"cuda:3" 6 | d_model: &dmdl 256 7 | epochs: 80 8 | dataset: 9 | data_dir: "/data/nicolas/MotifNet_RIPO_transformer_FME/data_processing_new/wikifornia_theorytab_csv_mid" 10 | seq_len_chord: 88 11 | seq_len_note: 246 12 | relative_pitch_attention: 13 | d_model: *dmdl 14 | nhead: 8 15 | dim_feedforward: 2048 16 | dropout: 0.2 17 | nlayers: 2 18 | pitch_dim: 128 19 | dur_dim: 17 20 | emb_size: 128 21 | # max_len: 245 22 | device: *dvce 23 | position_encoding_conf: 24 | if_index: True 25 | if_global_timing: False 26 | if_modulo_timing: False 27 | device: *dvce 28 | attention_conf: 29 | attention_type: "rgl_rel_pitch_dur" #mha, rgl_rel_pitch_dur, rgl_vanilla, linear 30 | if_add_relative_pitch: False 31 | if_add_relative_duration: False 32 | if_add_relative_idx: True 33 | if_add_relative_idx_no_mask: False 34 | pitch_embedding_conf: 35 | d_model: *dmdl 36 | type: "one_hot" #nn, se, one_hot, nn_pretrain 37 | base: 9919 38 | if_trainable: False 39 | translation_bias_type: "nd" #2d or nd trainable vector/ None 40 | device: *dvce 41 | emb_nn: False 42 | # pretrain_emb_path: None 43 | # freeze_pretrain: True 44 | dur_embedding_conf: 45 | d_model: *dmdl 46 | type: "one_hot" #nn, se 47 | base: 7920 48 | if_trainable: False 49 | translation_bias_type: "nd" #2d or nd trainable vector/ None 50 | device: *dvce 51 | emb_nn: False 52 | # pretrain_emb_path: None 53 | # freeze_pretrain: True 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /training/config/ripo_transformer.yaml: -------------------------------------------------------------------------------- 1 | experiment: "RIPO" 2 | batch_size: 16 3 | optimizer: "adam" 4 | lr: 0.001 5 | device: &dvce "cuda:3" #"cuda:3" 6 | d_model: &dmdl 256 7 | epochs: 80 8 | dataset: 9 | data_dir: "/data/nicolas/MotifNet_RIPO_transformer_FME/data_processing_new/wikifornia_theorytab_csv_mid" 10 | seq_len_chord: 88 11 | 
seq_len_note: 246 12 | relative_pitch_attention: 13 | d_model: *dmdl 14 | nhead: 8 15 | dim_feedforward: 2048 16 | dropout: 0.2 17 | nlayers: 2 18 | pitch_dim: 128 19 | dur_dim: 17 20 | emb_size: 128 21 | # max_len: 245 22 | device: *dvce 23 | position_encoding_conf: 24 | if_index: True 25 | if_global_timing: True 26 | if_modulo_timing: True 27 | device: *dvce 28 | attention_conf: 29 | attention_type: "rgl_rel_pitch_dur" #mha, rgl_rel_pitch_dur, rgl_vanilla, linear 30 | if_add_relative_pitch: True 31 | if_add_relative_duration: True 32 | if_add_relative_idx: True 33 | if_add_relative_idx_no_mask: False 34 | pitch_embedding_conf: 35 | d_model: *dmdl 36 | type: "se" #nn, se, one_hot, nn_pretrain 37 | base: 9919 38 | if_trainable: True 39 | translation_bias_type: "nd" #2d or nd trainable vector/ None 40 | device: *dvce 41 | emb_nn: True 42 | # pretrain_emb_path: None 43 | # freeze_pretrain: True 44 | dur_embedding_conf: 45 | d_model: *dmdl 46 | type: "se" #nn, se 47 | base: 7920 48 | if_trainable: True 49 | translation_bias_type: "nd" #2d or nd trainable vector/ None 50 | device: *dvce 51 | emb_nn: True 52 | # pretrain_emb_path: None 53 | # freeze_pretrain: True 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /training/main_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | from model.model import RIPO_transformer, loss_function_baseline 4 | import torch.nn as nn 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from data_processing.data_loading import MotifDataset 9 | from torch.utils.data import DataLoader 10 | import json 11 | import yaml 12 | import time 13 | from torch.utils.tensorboard import SummaryWriter 14 | from datetime import datetime 15 | import os 16 | 17 | def train(model): 18 | model.train() # turn on train mode 19 | total_loss = 0. 20 | total_pitch_loss = 0. 21 | total_dur_loss = 0. 
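    # log_interval is set to len(train_dataloader)-1 just below, so the averaged losses are
    # written to TensorBoard once per epoch, on the last batch of the loop.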
22 | log_interval = len(train_dataloader)-1 #20 23 | start_time = time.time() 24 | 25 | for batch,data in enumerate(train_dataloader): 26 | inp_dict, gt_dict = {}, {} 27 | for key, val in data[0].items(): 28 | inp_dict[key] = val.to(dvc) 29 | for key, val in data[1].items(): 30 | gt_dict[key] = val.to(dvc) 31 | 32 | pred_dict = model(inp_dict) 33 | loss_dict = loss_function_baseline(pred_dict, gt_dict) 34 | optimizer.zero_grad() 35 | loss_dict['total_loss'].backward() 36 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 37 | optimizer.step() 38 | 39 | total_loss += loss_dict['total_loss'].item() 40 | total_pitch_loss += loss_dict['pitch_loss'].item() 41 | total_dur_loss += loss_dict['dur_loss'].item() 42 | 43 | if batch % log_interval == 0 and batch > 0: 44 | 45 | lr = scheduler.get_last_lr()[0] 46 | ms_per_batch = (time.time() - start_time) * 1000 / (log_interval+1) 47 | cur_loss = total_loss / (log_interval+1) 48 | cur_pitch_loss = total_pitch_loss / (log_interval+1) 49 | cur_dur_loss = total_dur_loss / (log_interval+1) 50 | 51 | writer.add_scalar('training total loss',cur_loss,epoch ) 52 | writer.add_scalar('training pitch loss',cur_pitch_loss,epoch ) 53 | writer.add_scalar('training dur loss',cur_dur_loss,epoch ) 54 | print(f'| epoch {epoch:3d} | {batch:5d} batches | ' 55 | f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | ' 56 | f'loss {cur_loss:5.4f} |') 57 | total_loss = 0 58 | total_pitch_loss = 0 59 | total_dur_loss = 0 60 | start_time = time.time() 61 | 62 | def evaluate(model): 63 | model.eval() # turn on evaluation mode 64 | return_total_loss = 0. 65 | total_loss = 0. 66 | total_pitch_loss = 0. 67 | total_dur_loss = 0. 68 | log_interval = len(valid_dataloader)-1 69 | with torch.no_grad(): 70 | for batch,data in enumerate(valid_dataloader): 71 | inp_dict, gt_dict = {}, {} 72 | for key, val in data[0].items(): 73 | inp_dict[key] = val.to(dvc) 74 | for key, val in data[1].items(): 75 | gt_dict[key] = val.to(dvc) 76 | 77 | pred_dict = model(inp_dict) 78 | loss_dict = loss_function_baseline(pred_dict, gt_dict) 79 | 80 | return_total_loss+=loss_dict['total_loss'].item() 81 | total_loss += loss_dict['total_loss'].item() 82 | total_pitch_loss +=loss_dict['pitch_loss'].item() 83 | total_dur_loss +=loss_dict['dur_loss'].item() 84 | 85 | if batch % log_interval == 0 and batch > 0: 86 | cur_loss = total_loss / (log_interval+1) 87 | cur_pitch_loss = total_pitch_loss / (log_interval+1) 88 | cur_dur_loss = total_dur_loss / (log_interval+1) 89 | 90 | writer.add_scalar('valid total loss',cur_loss,epoch) 91 | writer.add_scalar('valid pitch loss',cur_pitch_loss,epoch) 92 | writer.add_scalar('valid dur loss',cur_dur_loss,epoch) 93 | 94 | total_loss = 0 95 | total_pitch_loss = 0 96 | total_dur_loss = 0 97 | 98 | return return_total_loss / len(valid_dataloader) 99 | 100 | if __name__ =="__main__": 101 | torch.set_printoptions(precision = 2) 102 | np.set_printoptions(precision = 2) 103 | config_path = "config/ripo_transformer.yaml" 104 | 105 | #config 106 | with open (config_path, 'r') as f: 107 | cfg = yaml.safe_load(f) 108 | dvc = cfg['device'] 109 | 110 | #define model 111 | model = RIPO_transformer(**cfg['relative_pitch_attention']).to(dvc) 112 | 113 | #Loss func 114 | criterion = nn.CrossEntropyLoss() 115 | 116 | #optimizer 117 | if cfg['optimizer']=="adam": 118 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg['lr']) 119 | elif cfg['optimizer']=="sgd": 120 | optimizer = torch.optim.SGD(model.parameters(), lr=cfg['lr']) 121 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 
1.0, gamma=0.95) 122 | 123 | #tensorboard 124 | #checkpoint_path 125 | date_time = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")+"_baseline" 126 | checkpoint_path = os.path.join("checkpoints", date_time) 127 | writer = SummaryWriter(checkpoint_path) 128 | 129 | #save training config 130 | with open(os.path.join(checkpoint_path, config_path.split("/")[-1]), 'w') as f: 131 | documents = yaml.dump(cfg, f) 132 | 133 | #dataloading 134 | dataset = MotifDataset(**cfg["dataset"]) 135 | train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*0.9), len(dataset) - int(len(dataset)*0.9)], generator=torch.Generator().manual_seed(0)) 136 | train_dataloader = DataLoader(train_dataset, batch_size=cfg['batch_size'], shuffle=True) 137 | valid_dataloader = DataLoader(valid_dataset, batch_size=cfg['batch_size'], shuffle=True) 138 | print(f"total len dataset:{len(dataset)}, training set:{len(train_dataset)}, validation set:{len(valid_dataset)}") 139 | print(f"len train loader:{len(train_dataloader)}, {len(valid_dataloader)}") 140 | print("trainable parameters:") 141 | for name, param in model.named_parameters(): 142 | if param.requires_grad: 143 | print (name) 144 | 145 | best_val_loss = float('inf') 146 | best_model = None 147 | 148 | for epoch in range(cfg['epochs']): 149 | epoch_start_time = time.time() 150 | train(model) 151 | val_loss = evaluate(model) 152 | elapsed = time.time() - epoch_start_time 153 | print('-' * 89) 154 | print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | ' 155 | f'valid loss {val_loss:5.4f} |') 156 | print('-' * 89) 157 | if val_loss < best_val_loss: 158 | best_val_loss = val_loss 159 | torch.save( 160 | model.state_dict(), 161 | os.path.join(checkpoint_path, f"state_ep{epoch}_{best_val_loss:5.4f}.pth"), 162 | ) 163 | 164 | torch.save( 165 | model.state_dict(), 166 | os.path.join(checkpoint_path, f"best_loss.pth"), 167 | ) 168 | scheduler.step() 169 | 170 | -------------------------------------------------------------------------------- /tutorial/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/tutorial/__init__.py -------------------------------------------------------------------------------- /tutorial/fme/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/tutorial/fme/.gitkeep -------------------------------------------------------------------------------- /tutorial/fme/consts.py: -------------------------------------------------------------------------------- 1 | WE_MODEL_PICKLE = "checkpoints/models/we_model.pkl" 2 | FME_MODEL_PICKLE = "checkpoints/models/fme_model.pkl" -------------------------------------------------------------------------------- /tutorial/fme/tutorial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import matplotlib.pyplot as plt 4 | 5 | from tutorial.fme.utils import load_we_model, load_fme_model, relative_pitch_matrix, self_distance_matrix 6 | from tutorial.fme.visual import plot_distance_matrix, plot_we_pitch_embedding 7 | 8 | 9 | the_lick_Cm = torch.tensor([60, 62, 63, 65, 62, 58, 60]) 10 | the_lick_Gm = torch.tensor([67, 69, 70, 72, 69, 65, 67]) 11 | 12 | the_lick_Cm_relative = 
relative_pitch_matrix(the_lick_Cm) 13 | the_lick_Gm_relative = relative_pitch_matrix(the_lick_Gm) 14 | 15 | _, ax = plt.subplots(1, 2) 16 | 17 | plot_distance_matrix(the_lick_Cm_relative.numpy(), the_lick_Cm.numpy(), ax[0]) 18 | plot_distance_matrix(the_lick_Gm_relative.numpy(), the_lick_Gm.numpy(), ax[1]) 19 | print("The relative pitch matrices for \"the lick\" in different keys are identical. \n--> the reason why humans are able to identify music snippets regardless of absolute keys.") 20 | 21 | 22 | 23 | # Embed "the lick" using One Hot Encoding 24 | the_lick_Cm_one_hot = F.one_hot(the_lick_Cm, num_classes = 256).to(torch.float32) 25 | the_lick_Gm_one_hot = F.one_hot(the_lick_Gm, num_classes = 256).to(torch.float32) 26 | 27 | # Plot self distance matrices with different underlying embedding methods (WE and OH) 28 | 29 | sdm_Cm_one_hot = self_distance_matrix(the_lick_Cm_one_hot) 30 | sdm_Gm_one_hot = self_distance_matrix(the_lick_Gm_one_hot) 31 | 32 | _, ax = plt.subplots(1, 2) 33 | plot_distance_matrix(sdm_Cm_one_hot.numpy(), the_lick_Cm.numpy(), axis=ax[0]) 34 | plot_distance_matrix(sdm_Gm_one_hot.numpy(), the_lick_Gm.numpy(), axis=ax[1]) 35 | 36 | # Load pre-trained word embedding 37 | we_model = load_we_model() 38 | plot_we_pitch_embedding(we_model) 39 | 40 | we_model.pitch_embedding(the_lick_Cm) 41 | 42 | sdm_Cm_word_embedding = self_distance_matrix(we_model.pitch_embedding(the_lick_Cm)) 43 | sdm_Gm_word_embedding = self_distance_matrix(we_model.pitch_embedding(the_lick_Gm)) 44 | 45 | _, ax = plt.subplots(1, 2) 46 | plot_distance_matrix(sdm_Cm_word_embedding.numpy(), the_lick_Cm.numpy(), axis=ax[0]) 47 | plot_distance_matrix(sdm_Gm_word_embedding.numpy(), the_lick_Gm.numpy(), axis=ax[1]) 48 | 49 | # Load pre-trained FME 50 | fme_model = load_fme_model() 51 | 52 | # FME equation 53 | import numpy as np 54 | 55 | def w_k(B, k, d): 56 | return np.power(B, -2*k/d) 57 | 58 | def p_k(f, w, bsin, bcos): 59 | return [np.sin(w*f) + bsin, np.cos(w*f) + bcos] 60 | 61 | B = fme_model.pitch_embedding_conf['base'] 62 | d = fme_model.pitch_embedding_conf['d_model'] 63 | bsin = np.random.standard_normal() 64 | bcos = np.random.standard_normal() 65 | 66 | emb = [] 67 | for f in range(128): 68 | fme = [] 69 | for k in range(d//2): 70 | w = w_k(B, k, d) 71 | p = p_k(f, w, bsin, bcos) 72 | fme.extend(p) 73 | emb.append(fme) 74 | 75 | e = torch.nn.Embedding(num_embeddings=128, embedding_dim=d) 76 | e.weights = torch.Tensor(emb) 77 | 78 | n60 = e.weights.numpy()[60] 79 | n63 = e.weights.numpy()[63] 80 | n69 = e.weights.numpy()[69] 81 | n72 = e.weights.numpy()[72] 82 | 83 | print(f'L2 distance of embeddings for note 72 and 63: {np.linalg.norm(n72 - n63):.4f}') 84 | print(f'L2 distance of embeddings for note 69 and 60: {np.linalg.norm(n69 - n60):.4f}') 85 | -------------------------------------------------------------------------------- /tutorial/fme/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | from tutorial.fme.consts import ( 4 | WE_MODEL_PICKLE, FME_MODEL_PICKLE 5 | ) 6 | 7 | def relative_pitch_matrix(notes: torch.Tensor): 8 | return notes[None, ...] 
- notes[..., None]
9 | 
10 | def self_distance_matrix(embedding: torch.Tensor):
11 |     e1 = torch.unsqueeze(embedding, dim = 1) # Dim: (len, 1, dim)
12 |     e2 = torch.unsqueeze(embedding, dim = 0) # Dim: (1, len, dim)
13 |     dist = e1 - e2 # Dim: (len, len, dim)
14 |     dist_l2 = torch.sqrt(torch.sum(torch.pow(dist, 2), dim = -1)) # Dim: (len, len)
15 |     return dist_l2
16 | 
17 | def get_device():
18 |     if torch.cuda.is_available():
19 |         device = torch.device('cuda')
20 |     else:
21 |         device = torch.device('cpu')
22 |         import warnings; warnings.warn('No CUDA devices found. Using CPU')  # non-fatal warning; raising here would make the CPU fallback unreachable
23 | 
24 |     return device
25 | 
26 | def _load_model(pickle_path, device='cuda:0'):
27 |     with open(pickle_path, 'rb') as o:
28 |         model = pickle.load(o)
29 | 
30 |     # Switch to eval mode
31 |     model.eval()
32 |     return model
33 | 
34 | def load_we_model():
35 |     return _load_model(WE_MODEL_PICKLE)
36 | 
37 | def load_fme_model():
38 |     return _load_model(FME_MODEL_PICKLE)
39 | 
--------------------------------------------------------------------------------
/tutorial/fme/visual.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | 
3 | 
4 | def plot_distance_matrix(distance_matrix, notes, axis: plt.Axes, add_text: bool = True):
5 |     axis.matshow(distance_matrix, cmap=plt.cm.Blues)
6 | 
7 |     axis.set_xticks(range(0, len(notes), 1))
8 |     axis.set_yticks(range(0, len(notes), 1))
9 |     axis.set_xticklabels(notes[::1], fontsize=12, rotation=45)
10 |     axis.set_yticklabels(notes[::1], fontsize=12, rotation=45)
11 | 
12 |     if add_text:
13 |         for i in range(distance_matrix.shape[0]):
14 |             for j in range(distance_matrix.shape[1]):
15 |                 axis.text(i, j, str(-distance_matrix[j,i])[:4], va='center', ha='center')
16 | 
17 | def plot_we_pitch_embedding(we_model):
18 |     word_embedding = we_model.pitch_embedding.weight
19 |     word_embedding = word_embedding.cpu().detach().numpy()
20 |     plt.matshow(word_embedding)
21 |     plt.title('256-dimensional Embedding of Musical Pitch')
22 | 
--------------------------------------------------------------------------------
/tutorial/ripo/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/tutorial/ripo/.gitkeep
--------------------------------------------------------------------------------
/tutorial/setup/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/tutorial/setup/.gitkeep
--------------------------------------------------------------------------------
/utils/__pycache__/eval_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/utils/__pycache__/eval_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/eval_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guozixunnicolas/FundamentalMusicEmbedding/793e30079978c859afef73ff4b88d7001bfc5b57/utils/__pycache__/eval_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/eval_utils.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import Counter 3 | interval_dict = {0: "unison", 1:"2nd", 2:"2nd", 3:"3rd", 4:"3rd", 5:"4th", 6:"tt", 7:"5th", 8:"6th", 9:"6th", 10:"7th", 11:"7th"} 4 | dur_dict = {0: 'pad', 1: 0.25, 2: 0.5, 3: 0.75, 4: 1.0, 5: 1.25, 6: 1.5, 7: 1.75, 8: 2.0, 9: 2.25, 10: 2.5, 11: 2.75, 12: 3.0, 13: 3.25, 14: 3.5, 15: 3.75, 16: 4.0} 5 | 6 | 7 | pitch_lst = [x for x in range(3, 131)] 8 | dur_lst = [x for x in range(1, 17)] 9 | # def get_n_grams(inp, n_gram = 4): 10 | def get_rep_seq(inp, n_gram = 4): 11 | """per song feature""" 12 | gram_lst= [tuple(inp[i:i+n_gram]) for i in range(len(inp)-n_gram+1)] 13 | 14 | unique_gram = list(set(gram_lst)) 15 | return 1.0 - len(unique_gram)/len(gram_lst) 16 | 17 | def get_unique_tokens(inp, tpe = "pitch"): 18 | """aggregate and normalize""" 19 | if tpe=="pitch": 20 | counts = [inp.count(keys) for keys in pitch_lst] #list with length 12, each count number of intervals 21 | elif tpe=="dur": 22 | counts = [inp.count(keys) for keys in dur_lst] #list with length 12, each count number of intervals 23 | 24 | return counts 25 | 26 | def get_unique_intervals(inp, banned_tokens = [0, 1, 2]): 27 | """aggregate and normalize""" 28 | inp = [x for x in inp if x not in banned_tokens] 29 | inp = np.array(inp) 30 | inp_diff = abs(np.diff(inp)) 31 | counts = [inp_diff.tolist().count(keys) for keys,_ in interval_dict.items()] #list with length 12, each count number of intervals 32 | return counts 33 | 34 | def get_arpeggio_num(inp_pitch, inp_dur, sliding_window_size = 4, note_distance_threshold = 5, banned_tokens = [0, 1, 2]): 35 | gram_lst= [tuple(inp_pitch[i:i+sliding_window_size]) for i in range(len(inp_pitch)-sliding_window_size+1)] 36 | gram_lst_dur= [tuple(inp_dur[i:i+sliding_window_size]) for i in range(len(inp_pitch)-sliding_window_size+1)] 37 | 38 | num_of_arpegio = 0 39 | for gram, gram_dur in zip(gram_lst, gram_lst_dur): 40 | inp_diff = np.diff(gram) 41 | if_arpegio_criteria0 = not any([x in banned_tokens for x in gram]) # rest, sustain, pad should not be in the list 42 | if_arpegio_criteria1 = all(inp_diff>0) or all(inp_diff<0) # monotonically increasing/ decreasing 43 | if_arpegio_criteria2 = all(abs(inp_diff)