├── .gitignore ├── LICENSE ├── README.md ├── config.toml ├── create_beats.py ├── dataset_lookup.py ├── datasets.toml ├── gan_training.py ├── models ├── attention_rnn.py ├── cnn_discriminator.py ├── lstm_local_attn.py ├── model_utils.py ├── rpr.py ├── transformer.py └── vanilla_rnn.py ├── predict_stream.py ├── prepare_dataset.py ├── preprocess ├── constants.py ├── dataset.py ├── fetch.py └── prepare.py ├── requirements.txt ├── setup_env.ps1 ├── setup_env.sh ├── tests ├── __init__.py └── test_codec.py ├── train.py └── utils ├── beats_generator.py ├── constants.py ├── data_paths.py ├── devices.py ├── distribution.py ├── metrics.py ├── model.py ├── render.py └── sample.py /.gitignore: -------------------------------------------------------------------------------- 1 | # CS230 Specific 2 | .project_data/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | .idea/ 164 | 165 | # VSCode 166 | .vscode/ 167 | 168 | # experiments 169 | test.ipynb 170 | *.npy -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Conghao Shen, Violet Yao, Yixin Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Everybody Compose: Deep Beats To Music 2 | Authors: Conghao (Tom) Shen, Violet Yao, Yixin Liu 3 | 4 | ## Abstract 5 | 6 | This project presents a deep learning approach to generate monophonic melodies based on input beats, allowing even amateurs to create their own music compositions. Three effective methods - LSTM with Full Attention, LSTM with Local Attention, and Transformer with Relative Position Representation - are proposed for this novel task, providing great variation, harmony, and structure in the generated music. This project allows anyone to compose their own music by tapping their keyboards or ``recoloring'' beat sequences from existing works. 
7 | 8 | ## Getting Started 9 | 10 | To get started, clone this repository and install the required packages: 11 | ```sh 12 | git clone https://github.com/tsunrise/everybody-compose.git 13 | cd everybody-compose 14 | pip install -r requirements.txt 15 | ``` 16 | You may encounter dependency issues with `protobuf` during training. If so, try reinstalling `tensorboard` by running: 17 | ```sh 18 | pip install --upgrade tensorboard 19 | ``` 20 | This issue is due to conflicting requirements between `note_seq` and `tensorboard`. 21 | 22 | We have also provided a [Colab Notebook](https://colab.research.google.com/drive/1oVn-lZI1K23EC9py6UibDOL7swQGp4v9?usp=sharing#scrollTo=kp6HIjuYvoye) for your reference. 23 | 24 | ## Training 25 | The preprocessed dataset will automatically be downloaded before training. To train a model, run the `train.py` script with the `-m` or `--model_name` argument followed by a string specifying the name of the model to use. The available model names are: 26 | 27 | - `lstm_attn`: LSTM with Local Attention 28 | - `vanilla_rnn`: Decoder Only Vanilla RNN 29 | - `attention_rnn`: LSTM with Full Attention 30 | - `transformer`: Transformer RPR 31 | 32 | You can also use the `-nf` or `--n_files` argument followed by an integer to specify the number of files to use for training (the default value of -1 means that all available files will be used). 33 | 34 | To specify the number of epochs to train the model for, use the `-n` or `--n_epochs` argument followed by an integer. The default value is 100. 35 | 36 | To specify the device to use for training, use the `-d` or `--device` argument followed by a string. The default value is `cuda` if a CUDA-enabled GPU is available, or `cpu` if not. 37 | 38 | To specify the frequency at which to save snapshots of the trained model, use the `-s` or `--snapshots_freq` argument followed by an integer. This specifies the number of epochs between each saved snapshot; the default value is 200. The snapshots will be saved in the `.project_data/snapshots` directory. 39 | 40 | To specify a checkpoint to load the model from, use the `-c` or `--checkpoint` argument followed by a string specifying the path to the checkpoint file. The default value is None, which means that no checkpoint will be loaded. 41 | 42 | Here are some examples of how to use these arguments: 43 | 44 | ```sh 45 | # Train the LSTM with Local Attention model using all available files, for 100 epochs, on the default device, saving snapshots every 200 epochs, and not using a checkpoint 46 | python train.py -m lstm_attn 47 | 48 | # Train the LSTM with Local Attention model using 10 files, for 1000 epochs, on the CPU, saving snapshots every 100 epochs, and starting from the checkpoint 49 | python train.py -m lstm_attn -nf 10 -n 1000 -d cpu -s 100 -c ./.project_data/snapshots/my_checkpoint.pth 50 | 51 | # Train the Transformer RPR model using all available files, for 500 epochs, on the default device, saving snapshots every 50 epochs, and not using a checkpoint 52 | python train.py -m transformer -n 500 -s 50 53 | ``` 54 | 55 | ## Generating Melodies from Beats 56 | 57 | To generate a predicted note sequence and save it as a MIDI file, run the `predict_stream.py` script with the `-m` or `--model_name` argument followed by a string specifying the name of the model to use. 
The available model names are: 58 | 59 | - `lstm_attn`: LSTM with Local Attention 60 | - `vanilla_rnn`: Decoder Only Vanilla RNN 61 | - `attention_rnn`: LSTM with Full Attention 62 | - `transformer`: Transformer RPR 63 | 64 | Use the `-c` or `--checkpoint_path` argument followed by a string 65 | specifying the path to the checkpoint file to use for the model. 66 | 67 | The generated MIDI file will be saved using the filename specified by the `-o` or `--midi_filename` argument (the default value is `output.mid`). 68 | 69 | To specify the device to use for generating the predicted sequence, use the `-d` or `--device` argument followed by a string. The default value is `cuda` if a CUDA-enabled GPU is available, or `cpu` if not. 70 | 71 | To specify the source of the input beats, use the `-s` or `--source` argument followed by a string. The default value is `interactive`, which means that the user will be prompted to input the beats using the keyboard. Other possible values are: 72 | 73 | - A file path, e.g. `beat_sequence.npy`, to load the recorded beats from a file. Recorded beats can be generated using the `create_beats.py` script. 74 | - `dataset` to use a random sample from the dataset as the beats. 75 | 76 | To specify the sampling profile to use for generating the predicted sequence, use the `-t` or `--profile` argument followed by a string. The available values are `default`, `beta` (stochastic search), and `beam` (hybrid beam search). The heuristic parameters for these profiles can be customized by adjusting the corresponding `[sampling.default]`, `[sampling.beta]`, and `[sampling.beam]` sections in the `config.toml` file. The default value is `default`. 77 | 78 | Here are some examples of how to use these arguments: 79 | 80 | ```sh 81 | # Generate a predicted sequence using the LSTM with Local Attention model, from beats entered by the user using the keyboard, using the checkpoint at ./.project_data/snapshots/my_checkpoint.pth, on the default device, and using the beta profile 82 | python predict_stream.py -m lstm_attn -c ./.project_data/snapshots/my_checkpoint.pth -t beta 83 | ``` 84 | -------------------------------------------------------------------------------- /config.toml: -------------------------------------------------------------------------------- 1 | [global] 2 | dataset = "mastero" 3 | random_slice_seed = 123 4 | val_ratio = 0.1 5 | train_val_split_seed = 666 6 | 7 | [model.transformer] 8 | lr = 1e-3 # TODO: we can use a learning rate scheduler 9 | seq_len = 64 10 | batch_size = 64 11 | n_notes = 128 12 | embed_dim = 128 13 | hidden_dim = 1024 14 | clip_grad = 5.0 15 | num_encoder_layers = 3 16 | num_decoder_layers = 3 17 | num_heads = 8 18 | src_vocab_size = 2 19 | tgt_vocab_size = 128 20 | 21 | [model.lstm_attn] 22 | lr = 1e-3 23 | seq_len = 128 24 | batch_size = 64 25 | n_notes = 128 26 | hidden_dim = 512 27 | dropout_p = 0.5 28 | clip_grad = 5.0 29 | 30 | [model.vanilla_rnn] 31 | lr = 1e-3 32 | seq_len = 64 33 | batch_size = 64 34 | n_notes = 128 35 | embed_dim = 32 36 | hidden_dim = 256 37 | clip_grad = 5.0 38 | 39 | [model.attention_rnn] 40 | lr = 1e-3 41 | seq_len = 64 42 | batch_size = 64 43 | n_notes = 128 44 | embed_dim = 32 45 | encode_hidden_dim = 512 46 | decode_hidden_dim = 1024 47 | clip_grad = 5.0 48 | 49 | [model.cnn] 50 | lr = 1e-3 51 | embed_dim = 32 52 | 53 | [sampling.default] 54 | strategy = "stochastic" 55 | top_p = 0.9 56 | top_k = 4 57 | repeat_decay = 0.6 58 | temperature = 1.5 59 | hint = ["1"] 
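# Each `[sampling.*]` section in this file defines one sampling profile for `predict_stream.py`;
# the profile is selected at generation time with the `-t`/`--profile` flag (e.g. `-t beta` or `-t beam`).
# Custom profiles are assumed to work the same way, provided a matching `[sampling.<name>]` section
# with the same keys as `[sampling.default]` exists.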
60 | 61 | [sampling.beta] 62 | strategy = "stochastic" 63 | top_p = 0.9 64 | top_k = 4 65 | repeat_decay = 0.6 66 | temperature = 1.2 67 | hint = ["1", "3"] 68 | 69 | [sampling.beam] 70 | strategy = "beam" 71 | repeat_decay = 0.6 72 | hint = ["1", "3", "5"] 73 | num_beams = 5 74 | beam_prob = 0.5 75 | temperature = 1 -------------------------------------------------------------------------------- /create_beats.py: -------------------------------------------------------------------------------- 1 | from utils.beats_generator import create_beat 2 | import numpy as np 3 | 4 | if __name__ == "__main__": 5 | beat_sequence = create_beat() 6 | np.set_printoptions(precision=4) 7 | print(beat_sequence) 8 | # save the beat sequence to a file 9 | np.save("beat_sequence.npy", beat_sequence) 10 | print("beat sequence saved to beat_sequence.npy") -------------------------------------------------------------------------------- /dataset_lookup.py: -------------------------------------------------------------------------------- 1 | from preprocess.dataset import BeatsRhythmsDataset 2 | import argparse 3 | 4 | if __name__ == "__main__": 5 | args = argparse.ArgumentParser() 6 | args.add_argument("index", type=int, default="index") 7 | args.add_argument("-o", "--output", type=str, default="output.mid") 8 | args = args.parse_args() 9 | dataset = BeatsRhythmsDataset(64) 10 | dataset.load() 11 | # idx = dataset.name_to_idx["2011/MIDI-Unprocessed_22_R1_2011_MID--AUDIO_R1-D8_12_Track12_wav.midi"] 12 | idx = args.index 13 | dataset.to_midi(idx, args.output) 14 | print(dataset.notes_list[idx][:256].reshape(-1)) -------------------------------------------------------------------------------- /datasets.toml: -------------------------------------------------------------------------------- 1 | [datasets.mono.mastero] 2 | midi = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip" 3 | metadata = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.csv" 4 | truncate = "maestro-v3.0.0/" 5 | processed = "https://r2.tomshen.io/cs230/processed_mastero_mono_v2.pkl" 6 | 7 | [datasets.chords] 8 | -------------------------------------------------------------------------------- /gan_training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import toml 4 | 5 | from typing import Optional 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.data 10 | from torch.utils.tensorboard.writer import SummaryWriter 11 | from models.cnn_discriminator import CNNDiscriminator 12 | 13 | from preprocess.dataset import BeatsRhythmsDataset 14 | from utils.data_paths import DataPaths 15 | from utils.model import get_model, load_checkpoint, save_checkpoint, model_forward 16 | 17 | ''' 18 | Code adapted from https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html 19 | ''' 20 | CONFIG_PATH = "./config.toml" 21 | 22 | def train(generator_name: str, discriminator_name: str, n_epochs: int, device: str, n_files:int=-1, snapshots_freq:int=10, generator_checkpoint: Optional[str] = None, discriminator_checkpoint: Optional[str] = None): 23 | # check cuda status 24 | print(f"Using {device} device") 25 | 26 | # initialize model 27 | config = toml.load(CONFIG_PATH) 28 | global_config = config["global"] 29 | netG_config = config["model"][generator_name] 30 | netG = get_model(generator_name, netG_config, device) 31 | print(netG) 32 | netD_config = 
config["model"][discriminator_name] 33 | netD = CNNDiscriminator(netG_config["n_notes"], netG_config["seq_len"], netD_config["embed_dim"]).to(device) 34 | print(netD) 35 | 36 | if generator_checkpoint: 37 | load_checkpoint(generator_checkpoint, netG, device) 38 | if discriminator_checkpoint: 39 | load_checkpoint(discriminator_checkpoint, netD, device) 40 | 41 | # Establish convention for real and fake labels during training 42 | real_label = 1. 43 | fake_label = 0. 44 | criterion = nn.BCELoss() 45 | 46 | # define optimizer 47 | optimizerD = torch.optim.Adam(netD.parameters(), lr=netG_config["lr"]) 48 | optimizerG = torch.optim.Adam(netG.parameters(), lr=netD_config["lr"]) 49 | 50 | # prepare training/validation loader 51 | dataset = BeatsRhythmsDataset(netG_config["seq_len"], global_config["random_slice_seed"]) 52 | dataset.load(global_config["dataset"]) 53 | dataset = dataset.subset_remove_short() 54 | if n_files > 0: 55 | dataset = dataset.subset(n_files) 56 | 57 | training_data, val_data = dataset.train_val_split(global_config["train_val_split_seed"], 0) 58 | print(f"Training data: {len(training_data)}") 59 | 60 | train_loader = torch.utils.data.DataLoader(training_data, netG_config["batch_size"], shuffle=True) 61 | 62 | # define tensorboard writer 63 | paths = DataPaths() 64 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M") 65 | log_dir = paths.tensorboard_dir / "{}_{}/{}".format("GAN", "all" if n_files == -1 else n_files, 66 | current_time) 67 | writer = SummaryWriter(log_dir=log_dir, flush_secs=60) 68 | 69 | # training loop 70 | netD.train() 71 | netG.train() 72 | for epoch in range(n_epochs): 73 | num_train_batches = 0 74 | for batch in train_loader: 75 | ############################ 76 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 77 | ########################### 78 | ## Train with all-real batch 79 | netD.zero_grad() 80 | input_seq = batch["beats"].to(device) 81 | target_seq = batch["notes"].long().to(device) 82 | target_prev_seq = batch["notes_shifted"].long().to(device) 83 | label = torch.full((input_seq.shape[0], 1), real_label, dtype=torch.float, device=device) 84 | # Forward pass real batch through D 85 | target_one_hot = torch.nn.functional.one_hot(target_seq, netG_config['n_notes']).float() 86 | output = netD(input_seq, target_one_hot) 87 | # Calculate loss on all-real batch 88 | errD_real = criterion(output, label) 89 | # Calculate gradients for D in backward pass 90 | errD_real.backward() 91 | D_x = output.mean().item() 92 | 93 | ## Train with all-fake batch 94 | # Generate fake image batch with G 95 | fake_logits = model_forward(generator_name, netG, input_seq, target_seq, target_prev_seq, device) 96 | fake = F.gumbel_softmax(fake_logits) 97 | label.fill_(fake_label) 98 | # Classify all fake batch with D 99 | output = netD(input_seq, fake.detach()) 100 | # Calculate D's loss on the all-fake batch 101 | errD_fake = criterion(output, label) 102 | # Calculate the gradients for this batch, accumulated (summed) with previous gradients 103 | errD_fake.backward() 104 | D_G_z1 = output.mean().item() 105 | # Compute error of D as sum over the fake and the real batches 106 | errD = errD_real + errD_fake 107 | # Update D 108 | optimizerD.step() 109 | 110 | ############################ 111 | # (2) Update G network: maximize log(D(G(z))) 112 | ########################### 113 | netG.zero_grad() 114 | label.fill_(real_label) # fake labels are real for generator cost 115 | # Since we just updated D, perform another forward pass of all-fake batch through 
D 116 | output = netD(input_seq, fake) 117 | # Calculate G's loss based on this output 118 | errG = criterion(output, label) 119 | # Calculate gradients for G 120 | errG.backward() 121 | D_G_z2 = output.mean().item() 122 | # Update G 123 | netG.clip_gradients_(5) 124 | optimizerG.step() 125 | 126 | num_train_batches += 1 127 | 128 | # Output training stats 129 | print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' 130 | % (epoch, n_epochs, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) 131 | 132 | writer.add_scalar("Generator loss", errG.item(), epoch) 133 | writer.add_scalar("Discriminator loss", errD.item(), epoch) 134 | 135 | if (epoch + 1) % snapshots_freq == 0: 136 | save_checkpoint(netG, paths, generator_name, n_files, epoch + 1) 137 | save_checkpoint(netD, paths, discriminator_name, n_files, epoch + 1) 138 | 139 | # save model 140 | save_checkpoint(netG, paths, generator_name, n_files, epoch + 1) 141 | save_checkpoint(netD, paths, discriminator_name, n_files, epoch + 1) 142 | writer.close() 143 | 144 | 145 | if __name__ == '__main__': 146 | parser = argparse.ArgumentParser('Train DeepBeats GAN') 147 | parser.add_argument('-gm', '--generator_name', type=str, default = "lstm") 148 | parser.add_argument('-dm', '--discriminator_name', type=str, default = "cnn") 149 | parser.add_argument('-nf', '--n_files', type=int, default=-1) 150 | parser.add_argument('-n', '--n_epochs', type=int, default=100) 151 | parser.add_argument('-d', '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 152 | parser.add_argument('-s', '--snapshots_freq', type=int, default=50) 153 | parser.add_argument('-gc', '--generator_checkpoint', type=str, default=None) 154 | parser.add_argument('-dc', '--discriminator_checkpoint', type=str, default=None) 155 | 156 | main_args = parser.parse_args() 157 | train(**vars(main_args)) 158 | -------------------------------------------------------------------------------- /models/attention_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.nn import functional as F 5 | 6 | class RNNEncoder(nn.Module): 7 | def __init__(self, hidden_dim): 8 | super(RNNEncoder, self).__init__() 9 | self.hidden_dim = hidden_dim 10 | 11 | self.fc = nn.Linear(2, hidden_dim) 12 | self.encoder_lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True, bidirectional = True) 13 | 14 | def forward(self, x): 15 | x = self.fc(x) 16 | encode_output, encode_hidden_dim = self.encoder_lstm(x) 17 | return encode_output, encode_hidden_dim 18 | 19 | class Attention(nn.Module): 20 | def __init__(self, encode_hidden_dim, decode_hidden_dim): 21 | super(Attention, self).__init__() 22 | self.encode_hidden_dim = encode_hidden_dim 23 | self.decode_hidden_dim = decode_hidden_dim 24 | 25 | self.attn1 = nn.Linear(encode_hidden_dim * 2 + decode_hidden_dim, decode_hidden_dim) 26 | self.attn1_activation = nn.LeakyReLU() 27 | self.attn2 = nn.Linear(decode_hidden_dim, 1) 28 | self.attn2_activation = nn.Softmax(dim=1) 29 | 30 | def forward(self, encode_output, hidden_state): 31 | batch_size, seq_len, _ = encode_output.shape 32 | hidden_state = hidden_state.permute(1, 0, 2) 33 | hidden_state = hidden_state.repeat(1, seq_len, 1) 34 | concat = torch.concat((hidden_state, encode_output), dim=-1) 35 | attn = self.attn1(concat) 36 | attn = self.attn1_activation(attn) 37 | attn = self.attn2(attn) 38 | attn = self.attn2_activation(attn) 39 | attn = attn.view(batch_size, 1, 
seq_len) 40 | context = torch.bmm(attn, encode_output) 41 | return context 42 | 43 | class RNNDecoder(nn.Module): 44 | def __init__(self, num_notes, embed_dim, encode_hidden_dim, decode_hidden_dim, dropout_p = 0.1): 45 | super(RNNDecoder, self).__init__() 46 | self.decode_hidden_dim = decode_hidden_dim 47 | 48 | self.attention = Attention(encode_hidden_dim, decode_hidden_dim) 49 | self.note_embedding = nn.Embedding(num_notes, embed_dim) 50 | self.combine_fc = nn.Linear(encode_hidden_dim * 2 + embed_dim, decode_hidden_dim) 51 | self.dropout = nn.Dropout(dropout_p) 52 | self.post_attention_lstm = nn.LSTM(decode_hidden_dim, decode_hidden_dim, batch_first=True) 53 | self.notes_output = nn.Linear(decode_hidden_dim, num_notes) 54 | 55 | def forward(self, tgt, encode_output, memory=None): 56 | memory = self._default_init_hidden(tgt.shape[0]) if memory is None else memory 57 | context = self.attention(encode_output, memory[0]) 58 | tgt = self.note_embedding(tgt) 59 | tgt = torch.cat((tgt, context), dim=2) 60 | tgt = self.combine_fc(tgt) 61 | tgt = F.relu(tgt) 62 | tgt = self.dropout(tgt) 63 | tgt, memory = self.post_attention_lstm(tgt, memory) 64 | tgt = self.notes_output(tgt) 65 | return tgt, memory 66 | 67 | def _default_init_hidden(self, batch_size): 68 | device = next(self.parameters()).device 69 | h = torch.zeros(1, batch_size, self.decode_hidden_dim).to(device) 70 | c = torch.zeros(1, batch_size, self.decode_hidden_dim).to(device) 71 | return (h, c) 72 | 73 | class DeepBeatsAttentionRNN(nn.Module): 74 | def __init__(self, num_notes, embed_dim, encode_hidden_dim, decode_hidden_dim, dropout_p = 0.1): 75 | super(DeepBeatsAttentionRNN, self).__init__() 76 | self.num_notes = num_notes 77 | 78 | self.encoder = RNNEncoder(encode_hidden_dim) 79 | self.decoder = RNNDecoder(num_notes, embed_dim, encode_hidden_dim, decode_hidden_dim, dropout_p) 80 | 81 | def forward(self, x, tgt): 82 | batch_size, seq_len, _ = x.shape 83 | predicted_notes = torch.zeros(batch_size, seq_len, self.num_notes).to(next(self.parameters()).device) 84 | encode_output, encode_hidden = self.encoder(x) 85 | memory = None 86 | for t in range(seq_len): 87 | output, memory = self.decoder(tgt[:, t:t + 1], encode_output, memory) 88 | predicted_notes[:, t:t+1, :] = output 89 | return predicted_notes 90 | 91 | def loss_function(self, pred, target): 92 | criterion = nn.CrossEntropyLoss() 93 | target = target.flatten() # (batch_size * seq_len) 94 | pred = pred.reshape(-1, pred.shape[-1]) # (batch_size * seq_len, num_notes) 95 | loss = criterion(pred, target) 96 | return loss 97 | 98 | def clip_gradients_(self, max_value): 99 | torch.nn.utils.clip_grad.clip_grad_value_(self.parameters(), max_value) -------------------------------------------------------------------------------- /models/cnn_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | ''' 6 | Code adapted from https://github.com/amazon-science/transformer-gan/blob/main/model/discriminator.py 7 | ''' 8 | 9 | class CNNDiscriminator(nn.Module): 10 | def __init__(self, num_notes, seq_len, embed_dim, filter_sizes = [2, 3, 4, 5], num_filters = [300, 300, 300, 300], dropout = 0.2): 11 | super(CNNDiscriminator, self).__init__() 12 | self.num_notes = num_notes 13 | self.embed_dim = embed_dim 14 | self.seq_len = seq_len 15 | self.feature_dim = sum(num_filters) 16 | 17 | self.embeddings = nn.Linear(num_notes + 2, embed_dim, bias = False) 18 | self.convs = 
nn.ModuleList([ 19 | nn.Conv2d(1, n, (f, embed_dim)) for (n, f) in zip(num_filters, filter_sizes) 20 | ]) 21 | 22 | self.highway = nn.Linear(self.feature_dim, self.feature_dim) 23 | self.feature2out = nn.Linear(self.feature_dim, 1) 24 | self.dropout = nn.Dropout(dropout) 25 | 26 | def forward(self, beats, notes): 27 | X = torch.cat((beats, notes), dim = 2) 28 | X = self.embeddings(X).unsqueeze(1) 29 | convs = [F.relu(conv(X)).squeeze(3) for conv in self.convs] 30 | pools = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs] 31 | pred = torch.cat(pools, 1) 32 | highway = self.highway(pred) 33 | highway = torch.sigmoid(highway) * F.relu(highway) + (1. - torch.sigmoid(highway)) * pred 34 | pred = self.feature2out(self.dropout(highway)) 35 | pred = torch.sigmoid(pred) 36 | return pred 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /models/lstm_local_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class LocalAttnEncoder(nn.Module): 6 | """ 7 | Local attention encoder. 8 | """ 9 | def __init__(self, hidden_dim): 10 | """ 11 | - `duration_fc_dim`: dimension of the fully connected layer for duration 12 | - `hidden_dim`: dimension of the hidden state of both encoder and decoder 13 | - `context_dim`: dimension of the context vector 14 | """ 15 | super(LocalAttnEncoder, self).__init__() 16 | self.duration_fc = nn.Linear(2, hidden_dim) 17 | self.encoder = nn.LSTM(hidden_dim, hidden_dim, batch_first=True, bidirectional = True) 18 | 19 | def forward(self, x): 20 | """ 21 | - `x`: input sequence, shape: (seq_len, 2) 22 | Returns: 23 | - `context`: context vector, shape: (seq_len, context_dim) 24 | - `encoder_state`: encoder state, shape: (1, hidden_dim). The state only includes left-to-right direction. 25 | """ 26 | x = self.duration_fc(x) 27 | x, encoder_state = self.encoder(x) 28 | return x, (encoder_state[0][:1], encoder_state[1][:1]) 29 | 30 | class LocalAttnDecoder(nn.Module): 31 | """ 32 | Local Attention Decoder. 
33 | """ 34 | def __init__(self, hidden_dim, num_notes, dropout_p = 0.1): 35 | """ 36 | - `note_embed_size`: embedding size of notes 37 | - `context_dim`: dimension of the context vector 38 | - `hidden_dim`: dimension of the hidden state of both encoder and decoder 39 | - `num_notes`: number of notes 40 | """ 41 | super(LocalAttnDecoder, self).__init__() 42 | self.note_embed = nn.Embedding(num_notes, hidden_dim) 43 | self.combine_fc = nn.Linear(hidden_dim * 3, hidden_dim) 44 | self.dropout = nn.Dropout(dropout_p) 45 | self.rnn = nn.LSTM(hidden_dim, hidden_dim, batch_first=True) 46 | self.notes_output = nn.Linear(hidden_dim, num_notes) 47 | 48 | def forward(self, tgt, context, memory = None): 49 | """ 50 | - `tgt`: target sequence, shape: (seq_len, 1) 51 | tgt[i] is the (i-1)-th note in the sequence 52 | - `context`: context vector, shape: (seq_len, context_dim) 53 | - `memory`: encoder state or intermediate state, shape: pair of (1, hidden_dim) 54 | Returns: 55 | - `output`: output sequence, shape: (seq_len, num_notes) 56 | output[i] is the probability distribution of notes at time step i 57 | """ 58 | # print(f"{tgt.shape=}, {context.shape=}, {memory[0].shape=}") 59 | tgt = self.note_embed(tgt) 60 | tgt = torch.cat((tgt, context), dim=2) 61 | tgt = self.combine_fc(tgt) 62 | tgt = F.relu(tgt) 63 | tgt = self.dropout(tgt) 64 | tgt, memory = self.rnn(tgt, memory) 65 | tgt = self.notes_output(tgt) 66 | return tgt, memory 67 | 68 | class DeepBeatsLSTMLocalAttn(nn.Module): 69 | """ 70 | DeepBeats LSTM with encoder-decoder structure with local attention. 71 | Because the length of output sequence and input sequence is the same and has strong 1-to-1 relationship, 72 | i-th decoder block only uses i-th encoder block's output as context vector. 73 | This can improve the performance of the model, from quadratic to linear. 
74 | """ 75 | def __init__(self, num_notes, hidden_dim, dropout_p = 0.1): 76 | """ 77 | - `num_notes`: number of notes 78 | - `duration_fc_dim`: dimension of the fully connected layer for duration 79 | - `context_dim`: dimension of the context vector 80 | - `hidden_dim`: dimension of the hidden state of both encoder and decoder 81 | """ 82 | super(DeepBeatsLSTMLocalAttn, self).__init__() 83 | self.encoder = LocalAttnEncoder(hidden_dim) 84 | self.decoder = LocalAttnDecoder(hidden_dim, num_notes, dropout_p) 85 | self.num_notes = num_notes 86 | 87 | def forward(self, x, tgt): 88 | """ 89 | - `x`: input sequence, shape: (seq_len, 2) 90 | - `tgt`: target sequence, shape: (seq_len, 1) 91 | tgt[i] is the (i-1)-th note in the sequence 92 | Returns: 93 | - `output`: output sequence, shape: (seq_len, num_notes) 94 | output[i] is the probability distribution of notes at time step i 95 | """ 96 | context, encoder_state = self.encoder(x) 97 | output, _ = self.decoder(tgt, context, encoder_state) 98 | return output 99 | 100 | def loss_function(self, pred, target): 101 | """ 102 | Pred: (batch_size, seq_len, num_notes), logits 103 | Target: (batch_size, seq_len), range from 0 to num_notes-1 104 | """ 105 | criterion = nn.CrossEntropyLoss() 106 | target = target.flatten() # (batch_size * seq_len) 107 | pred = pred.reshape(-1, pred.shape[-1]) # (batch_size * seq_len, num_notes) 108 | loss = criterion(pred, target) 109 | return loss 110 | 111 | def clip_gradients_(self, max_value): 112 | torch.nn.utils.clip_grad.clip_grad_value_(self.parameters(), max_value) 113 | 114 | # TODO: next step 115 | # loss: https://github.com/gwinndr/MusicTransformer-Pytorch/blob/master/model/loss.py 116 | # get accuracy -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class ConcatPrev(nn.Module): 5 | def forward(self, x, y_prev): 6 | """ 7 | x: input, shape: (batch_size, seq_len, 2) 8 | y_prev: label, shape: (batch_size, seq_len, embedding_dim), 9 | y_prev[i] should be the embedding of note label for x[i-1], and y[0] is 0. 10 | """ 11 | 12 | # concat x and y_prev_embed to be X 13 | X = torch.cat((x, y_prev), dim=2) 14 | return X 15 | -------------------------------------------------------------------------------- /models/rpr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch.nn import functional as F 5 | from torch.nn.parameter import Parameter 6 | from torch.nn import Module 7 | from torch.nn.modules.transformer import _get_clones 8 | from torch.nn.modules.linear import Linear 9 | from torch.nn.modules.dropout import Dropout 10 | from torch.nn.modules.normalization import LayerNorm 11 | from torch.nn.init import * 12 | 13 | from torch.nn.functional import linear, softmax, dropout 14 | 15 | """ 16 | Code adapted from: 17 | https://github.com/gwinndr/MusicTransformer-Pytorch/blob/master/model/rpr.py 18 | https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#Transformer 19 | """ 20 | 21 | # TransformerEncoderRPR 22 | class TransformerEncoderRPR(Module): 23 | """ 24 | ---------- 25 | Author: Pytorch 26 | ---------- 27 | For Relative Position Representation support (https://arxiv.org/abs/1803.02155) 28 | https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoder 29 | No modification. 
Copied here to ensure continued compatibility with other edits. 30 | ---------- 31 | """ 32 | 33 | def __init__(self, encoder_layer, num_layers, norm=None): 34 | super(TransformerEncoderRPR, self).__init__() 35 | self.layers = _get_clones(encoder_layer, num_layers) 36 | self.num_layers = num_layers 37 | self.norm = norm 38 | 39 | def forward(self, src, mask=None, src_key_padding_mask=None): 40 | 41 | output = src 42 | 43 | for i in range(self.num_layers): 44 | output = self.layers[i](output, src_mask=mask, 45 | src_key_padding_mask=src_key_padding_mask) 46 | 47 | if self.norm: 48 | output = self.norm(output) 49 | 50 | return output 51 | 52 | # TransformerEncoderRPR 53 | class TransformerDecoderRPR(Module): 54 | """ 55 | ---------- 56 | Author: Pytorch 57 | Modified: Violet Yao 58 | ---------- 59 | For Relative Position Representation support (https://arxiv.org/abs/1803.02155) 60 | https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoder 61 | No modification. Copied here to ensure continued compatibility with other edits. 62 | ---------- 63 | """ 64 | 65 | def __init__(self, decoded_layer, num_layers, norm=None): 66 | super(TransformerDecoderRPR, self).__init__() 67 | self.layers = _get_clones(decoded_layer, num_layers) 68 | self.num_layers = num_layers 69 | self.norm = norm 70 | 71 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None): 72 | 73 | output = tgt 74 | 75 | for i in range(self.num_layers): 76 | output = self.layers[i](output, memory, tgt_mask=tgt_mask, 77 | memory_mask=memory_mask, 78 | tgt_key_padding_mask=tgt_key_padding_mask, 79 | memory_key_padding_mask=memory_key_padding_mask) 80 | 81 | if self.norm: 82 | output = self.norm(output) 83 | 84 | return output 85 | 86 | 87 | # TransformerEncoderLayerRPR 88 | class TransformerEncoderLayerRPR(Module): 89 | """ 90 | ---------- 91 | Author: Pytorch 92 | Modified: Damon Gwinn 93 | ---------- 94 | For Relative Position Representation support (https://arxiv.org/abs/1803.02155) 95 | https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer 96 | Modification to create and call custom MultiheadAttentionRPR 97 | ---------- 98 | """ 99 | 100 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None): 101 | super(TransformerEncoderLayerRPR, self).__init__() 102 | self.self_attn = MultiheadAttentionRPR(d_model, nhead, dropout=dropout, er_len=er_len) 103 | # Implementation of Feedforward model 104 | self.linear1 = Linear(d_model, dim_feedforward) 105 | self.dropout = Dropout(dropout) 106 | self.linear2 = Linear(dim_feedforward, d_model) 107 | 108 | self.norm1 = LayerNorm(d_model) 109 | self.norm2 = LayerNorm(d_model) 110 | self.dropout1 = Dropout(dropout) 111 | self.dropout2 = Dropout(dropout) 112 | 113 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 114 | src2 = self.self_attn(src, src, src, attn_mask=src_mask, 115 | key_padding_mask=src_key_padding_mask)[0] 116 | src = src + self.dropout1(src2) 117 | src = self.norm1(src) 118 | src2 = self.linear2(self.dropout(F.relu(self.linear1(src)))) 119 | src = src + self.dropout2(src2) 120 | src = self.norm2(src) 121 | return src 122 | 123 | 124 | # TransformerDecoderLayerRPR 125 | class TransformerDecoderLayerRPR(Module): 126 | """ 127 | ---------- 128 | Author: Pytorch 129 | Modified: Violet Yao 130 | ---------- 131 | For Relative Position Representation support (https://arxiv.org/abs/1803.02155) 132 | 
https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer 133 | Modification to create and call custom MultiheadAttentionRPR 134 | ---------- 135 | """ 136 | 137 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None): 138 | super(TransformerDecoderLayerRPR, self).__init__() 139 | self.self_attn = MultiheadAttentionRPR(d_model, nhead, dropout=dropout, er_len=er_len) 140 | self.multi_attn = MultiheadAttentionRPR(d_model, nhead, dropout=dropout, er_len=er_len) 141 | # Implementation of Feedforward model 142 | self.linear1 = Linear(d_model, dim_feedforward) 143 | self.dropout = Dropout(dropout) 144 | self.linear2 = Linear(dim_feedforward, d_model) 145 | 146 | self.norm1 = LayerNorm(d_model) 147 | self.norm2 = LayerNorm(d_model) 148 | self.norm3 = LayerNorm(d_model) 149 | self.dropout1 = Dropout(dropout) 150 | self.dropout2 = Dropout(dropout) 151 | self.dropout3 = Dropout(dropout) 152 | 153 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None): 154 | x = tgt 155 | x = x + self.dropout1(self.self_attn(x, x, x, attn_mask=tgt_mask, 156 | key_padding_mask=tgt_key_padding_mask)[0]) 157 | x = self.norm1(x) 158 | x = x + self.dropout2(self.multi_attn(x, memory, memory, attn_mask=memory_mask, 159 | key_padding_mask=memory_key_padding_mask)[0]) 160 | x = self.norm2(x) 161 | x = x + self.dropout3(self.linear2(self.dropout(F.relu(self.linear1(x))))) 162 | x = self.norm3(x) 163 | return x 164 | 165 | # MultiheadAttentionRPR 166 | class MultiheadAttentionRPR(Module): 167 | """ 168 | ---------- 169 | Author: Pytorch 170 | Modified: Damon Gwinn 171 | ---------- 172 | For Relative Position Representation support (https://arxiv.org/abs/1803.02155) 173 | https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/activation.html#MultiheadAttention 174 | Modification to add RPR embedding Er and call custom multi_head_attention_forward_rpr 175 | ---------- 176 | """ 177 | 178 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, er_len=None): 179 | super(MultiheadAttentionRPR, self).__init__() 180 | self.embed_dim = embed_dim 181 | self.kdim = kdim if kdim is not None else embed_dim 182 | self.vdim = vdim if vdim is not None else embed_dim 183 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 184 | 185 | self.num_heads = num_heads 186 | self.dropout = dropout 187 | self.head_dim = embed_dim // num_heads 188 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 189 | 190 | self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) 191 | 192 | if self._qkv_same_embed_dim is False: 193 | self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) 194 | self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) 195 | self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) 196 | 197 | if bias: 198 | self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) 199 | else: 200 | self.register_parameter('in_proj_bias', None) 201 | self.out_proj = Linear(embed_dim, embed_dim, bias=bias) 202 | 203 | if add_bias_kv: 204 | self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) 205 | self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) 206 | else: 207 | self.bias_k = self.bias_v = None 208 | 209 | self.add_zero_attn = add_zero_attn 210 | 211 | # Adding RPR embedding matrix 212 | if(er_len is not None): 213 | self.Er 
= Parameter(torch.rand((er_len, self.head_dim), dtype=torch.float32)) 214 | else: 215 | self.Er = None 216 | 217 | self._reset_parameters() 218 | 219 | def _reset_parameters(self): 220 | if self._qkv_same_embed_dim: 221 | xavier_uniform_(self.in_proj_weight) 222 | else: 223 | xavier_uniform_(self.q_proj_weight) 224 | xavier_uniform_(self.k_proj_weight) 225 | xavier_uniform_(self.v_proj_weight) 226 | 227 | if self.in_proj_bias is not None: 228 | constant_(self.in_proj_bias, 0.) 229 | constant_(self.out_proj.bias, 0.) 230 | if self.bias_k is not None: 231 | xavier_normal_(self.bias_k) 232 | if self.bias_v is not None: 233 | xavier_normal_(self.bias_v) 234 | 235 | def forward(self, query, key, value, key_padding_mask=None, 236 | need_weights=True, attn_mask=None): 237 | 238 | if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False: 239 | # return F.multi_head_attention_forward( 240 | # query, key, value, self.embed_dim, self.num_heads, 241 | # self.in_proj_weight, self.in_proj_bias, 242 | # self.bias_k, self.bias_v, self.add_zero_attn, 243 | # self.dropout, self.out_proj.weight, self.out_proj.bias, 244 | # training=self.training, 245 | # key_padding_mask=key_padding_mask, need_weights=need_weights, 246 | # attn_mask=attn_mask, use_separate_proj_weight=True, 247 | # q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 248 | # v_proj_weight=self.v_proj_weight) 249 | 250 | return multi_head_attention_forward_rpr( 251 | query, key, value, self.embed_dim, self.num_heads, 252 | self.in_proj_weight, self.in_proj_bias, 253 | self.bias_k, self.bias_v, self.add_zero_attn, 254 | self.dropout, self.out_proj.weight, self.out_proj.bias, 255 | training=self.training, 256 | key_padding_mask=key_padding_mask, need_weights=need_weights, 257 | attn_mask=attn_mask, use_separate_proj_weight=True, 258 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 259 | v_proj_weight=self.v_proj_weight, rpr_mat=self.Er) 260 | else: 261 | if not hasattr(self, '_qkv_same_embed_dim'): 262 | warnings.warn('A new version of MultiheadAttention module has been implemented. 
\ 263 | Please re-train your model with the new module', 264 | UserWarning) 265 | 266 | # return F.multi_head_attention_forward( 267 | # query, key, value, self.embed_dim, self.num_heads, 268 | # self.in_proj_weight, self.in_proj_bias, 269 | # self.bias_k, self.bias_v, self.add_zero_attn, 270 | # self.dropout, self.out_proj.weight, self.out_proj.bias, 271 | # training=self.training, 272 | # key_padding_mask=key_padding_mask, need_weights=need_weights, 273 | # attn_mask=attn_mask) 274 | 275 | return multi_head_attention_forward_rpr( 276 | query, key, value, self.embed_dim, self.num_heads, 277 | self.in_proj_weight, self.in_proj_bias, 278 | self.bias_k, self.bias_v, self.add_zero_attn, 279 | self.dropout, self.out_proj.weight, self.out_proj.bias, 280 | training=self.training, 281 | key_padding_mask=key_padding_mask, need_weights=need_weights, 282 | attn_mask=attn_mask, rpr_mat=self.Er) 283 | 284 | # multi_head_attention_forward_rpr 285 | def multi_head_attention_forward_rpr(query, # type: Tensor 286 | key, # type: Tensor 287 | value, # type: Tensor 288 | embed_dim_to_check, # type: int 289 | num_heads, # type: int 290 | in_proj_weight, # type: Tensor 291 | in_proj_bias, # type: Tensor 292 | bias_k, # type: Optional[Tensor] 293 | bias_v, # type: Optional[Tensor] 294 | add_zero_attn, # type: bool 295 | dropout_p, # type: float 296 | out_proj_weight, # type: Tensor 297 | out_proj_bias, # type: Tensor 298 | training=True, # type: bool 299 | key_padding_mask=None, # type: Optional[Tensor] 300 | need_weights=True, # type: bool 301 | attn_mask=None, # type: Optional[Tensor] 302 | use_separate_proj_weight=False, # type: bool 303 | q_proj_weight=None, # type: Optional[Tensor] 304 | k_proj_weight=None, # type: Optional[Tensor] 305 | v_proj_weight=None, # type: Optional[Tensor] 306 | static_k=None, # type: Optional[Tensor] 307 | static_v=None, # type: Optional[Tensor] 308 | rpr_mat=None 309 | ): 310 | """ 311 | ---------- 312 | Author: Pytorch 313 | Modified: Damon Gwinn 314 | ---------- 315 | For Relative Position Representation support (https://arxiv.org/abs/1803.02155) 316 | https://pytorch.org/docs/1.2.0/_modules/torch/nn/functional.html 317 | Modification to take RPR embedding matrix and perform skew optimized RPR (https://arxiv.org/abs/1809.04281) 318 | ---------- 319 | """ 320 | 321 | # type: (...) 
-> Tuple[Tensor, Optional[Tensor]] 322 | 323 | qkv_same = torch.equal(query, key) and torch.equal(key, value) 324 | kv_same = torch.equal(key, value) 325 | 326 | tgt_len, bsz, embed_dim = query.size() 327 | assert embed_dim == embed_dim_to_check 328 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 329 | assert key.size() == value.size() 330 | 331 | head_dim = embed_dim // num_heads 332 | assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" 333 | scaling = float(head_dim) ** -0.5 334 | 335 | if use_separate_proj_weight is not True: 336 | if qkv_same: 337 | # self-attention 338 | q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) 339 | 340 | elif kv_same: 341 | # encoder-decoder attention 342 | # This is inline in_proj function with in_proj_weight and in_proj_bias 343 | _b = in_proj_bias 344 | _start = 0 345 | _end = embed_dim 346 | _w = in_proj_weight[_start:_end, :] 347 | if _b is not None: 348 | _b = _b[_start:_end] 349 | q = linear(query, _w, _b) 350 | 351 | if key is None: 352 | assert value is None 353 | k = None 354 | v = None 355 | else: 356 | 357 | # This is inline in_proj function with in_proj_weight and in_proj_bias 358 | _b = in_proj_bias 359 | _start = embed_dim 360 | _end = None 361 | _w = in_proj_weight[_start:, :] 362 | if _b is not None: 363 | _b = _b[_start:] 364 | k, v = linear(key, _w, _b).chunk(2, dim=-1) 365 | 366 | else: 367 | # This is inline in_proj function with in_proj_weight and in_proj_bias 368 | _b = in_proj_bias 369 | _start = 0 370 | _end = embed_dim 371 | _w = in_proj_weight[_start:_end, :] 372 | if _b is not None: 373 | _b = _b[_start:_end] 374 | q = linear(query, _w, _b) 375 | 376 | # This is inline in_proj function with in_proj_weight and in_proj_bias 377 | _b = in_proj_bias 378 | _start = embed_dim 379 | _end = embed_dim * 2 380 | _w = in_proj_weight[_start:_end, :] 381 | if _b is not None: 382 | _b = _b[_start:_end] 383 | k = linear(key, _w, _b) 384 | 385 | # This is inline in_proj function with in_proj_weight and in_proj_bias 386 | _b = in_proj_bias 387 | _start = embed_dim * 2 388 | _end = None 389 | _w = in_proj_weight[_start:, :] 390 | if _b is not None: 391 | _b = _b[_start:] 392 | v = linear(value, _w, _b) 393 | else: 394 | q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) 395 | len1, len2 = q_proj_weight_non_opt.size() 396 | assert len1 == embed_dim and len2 == query.size(-1) 397 | 398 | k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) 399 | len1, len2 = k_proj_weight_non_opt.size() 400 | assert len1 == embed_dim and len2 == key.size(-1) 401 | 402 | v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) 403 | len1, len2 = v_proj_weight_non_opt.size() 404 | assert len1 == embed_dim and len2 == value.size(-1) 405 | 406 | if in_proj_bias is not None: 407 | q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) 408 | k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) 409 | v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) 410 | else: 411 | q = linear(query, q_proj_weight_non_opt, in_proj_bias) 412 | k = linear(key, k_proj_weight_non_opt, in_proj_bias) 413 | v = linear(value, v_proj_weight_non_opt, in_proj_bias) 414 | q = q * scaling 415 | 416 | if bias_k is not None and bias_v is not None: 417 | if static_k is None and static_v is None: 418 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 419 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 420 | if attn_mask is not None: 421 | attn_mask = 
torch.cat([attn_mask, 422 | torch.zeros((attn_mask.size(0), 1), 423 | dtype=attn_mask.dtype, 424 | device=attn_mask.device)], dim=1) 425 | if key_padding_mask is not None: 426 | key_padding_mask = torch.cat( 427 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 428 | dtype=key_padding_mask.dtype, 429 | device=key_padding_mask.device)], dim=1) 430 | else: 431 | assert static_k is None, "bias cannot be added to static key." 432 | assert static_v is None, "bias cannot be added to static value." 433 | else: 434 | assert bias_k is None 435 | assert bias_v is None 436 | 437 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 438 | if k is not None: 439 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 440 | if v is not None: 441 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 442 | 443 | if static_k is not None: 444 | assert static_k.size(0) == bsz * num_heads 445 | assert static_k.size(2) == head_dim 446 | k = static_k 447 | 448 | if static_v is not None: 449 | assert static_v.size(0) == bsz * num_heads 450 | assert static_v.size(2) == head_dim 451 | v = static_v 452 | 453 | src_len = k.size(1) 454 | 455 | if key_padding_mask is not None: 456 | assert key_padding_mask.size(0) == bsz 457 | assert key_padding_mask.size(1) == src_len 458 | 459 | if add_zero_attn: 460 | src_len += 1 461 | k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) 462 | v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) 463 | if attn_mask is not None: 464 | attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1), 465 | dtype=attn_mask.dtype, 466 | device=attn_mask.device)], dim=1) 467 | if key_padding_mask is not None: 468 | key_padding_mask = torch.cat( 469 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 470 | dtype=key_padding_mask.dtype, 471 | device=key_padding_mask.device)], dim=1) 472 | 473 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 474 | assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] 475 | 476 | ######### ADDITION OF RPR ########### 477 | if(rpr_mat is not None): 478 | rpr_mat = _get_valid_embedding(rpr_mat, q.shape[1], k.shape[1]) 479 | qe = torch.einsum("hld,md->hlm", q, rpr_mat) 480 | srel = _skew(qe) 481 | 482 | attn_output_weights += srel 483 | 484 | if attn_mask is not None: 485 | attn_mask = attn_mask.unsqueeze(0) 486 | attn_output_weights += attn_mask 487 | 488 | if key_padding_mask is not None: 489 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 490 | attn_output_weights = attn_output_weights.masked_fill( 491 | key_padding_mask.unsqueeze(1).unsqueeze(2), 492 | float('-inf'), 493 | ) 494 | attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) 495 | 496 | attn_output_weights = softmax( 497 | attn_output_weights, dim=-1) 498 | 499 | attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) 500 | 501 | attn_output = torch.bmm(attn_output_weights, v) 502 | assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] 503 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 504 | attn_output = linear(attn_output, out_proj_weight, out_proj_bias) 505 | 506 | if need_weights: 507 | # average attention weights over heads 508 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 509 | return 
attn_output, attn_output_weights.sum(dim=1) / num_heads 510 | else: 511 | return attn_output, None 512 | 513 | def _get_valid_embedding(Er, len_q, len_k): 514 | """ 515 | ---------- 516 | Author: Damon Gwinn 517 | ---------- 518 | Gets valid embeddings based on max length of RPR attention 519 | ---------- 520 | """ 521 | 522 | len_e = Er.shape[0] 523 | start = max(0, len_e - len_q) 524 | return Er[start:, :] 525 | 526 | def _skew(qe): 527 | """ 528 | ---------- 529 | Author: Damon Gwinn 530 | ---------- 531 | Performs the skew optimized RPR computation (https://arxiv.org/abs/1809.04281) 532 | ---------- 533 | """ 534 | 535 | sz = qe.shape[1] 536 | mask = (torch.triu(torch.ones(sz, sz).to(qe.device)) == 1).float().flip(0) 537 | 538 | qe = mask * qe 539 | qe = F.pad(qe, (1,0, 0,0, 0,0)) 540 | qe = torch.reshape(qe, (qe.shape[0], qe.shape[2], qe.shape[1])) 541 | 542 | srel = qe[:, 1:, :] 543 | return srel -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | from models.rpr import TransformerDecoderLayerRPR, TransformerDecoderRPR, TransformerEncoderLayerRPR, TransformerEncoderRPR 2 | from torch import Tensor 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch.nn import Transformer 7 | from torch.nn.modules.normalization import LayerNorm 8 | import math 9 | 10 | """ 11 | Code adapted from: https://pytorch.org/tutorials/beginner/translation_transformer.html 12 | """ 13 | 14 | # helper Module that adds positional encoding to the token embedding to introduce a notion of word order. 15 | class PositionalEncoding(nn.Module): 16 | def __init__(self, 17 | emb_size: int, 18 | dropout: float, 19 | maxlen: int = 5000): 20 | super(PositionalEncoding, self).__init__() 21 | den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size) 22 | pos = torch.arange(0, maxlen).reshape(maxlen, 1) 23 | pos_embedding = torch.zeros((maxlen, emb_size)) 24 | pos_embedding[:, 0::2] = torch.sin(pos * den) 25 | pos_embedding[:, 1::2] = torch.cos(pos * den) 26 | pos_embedding = pos_embedding.unsqueeze(-2) 27 | 28 | self.dropout = nn.Dropout(dropout) 29 | self.register_buffer('pos_embedding', pos_embedding) 30 | 31 | def forward(self, token_embedding: Tensor): 32 | return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :]) 33 | 34 | # helper Module to convert tensor of input indices into corresponding tensor of token embeddings 35 | class TokenEmbedding(nn.Module): 36 | def __init__(self, vocab_size: int, emb_size): 37 | super(TokenEmbedding, self).__init__() 38 | self.embedding = nn.Embedding(vocab_size, emb_size) 39 | self.emb_size = emb_size 40 | 41 | def forward(self, tokens: Tensor): 42 | return self.embedding(tokens.long()) * math.sqrt(self.emb_size) 43 | 44 | # Seq2Seq Network 45 | class DeepBeatsTransformer(nn.Transformer): 46 | def __init__(self, 47 | num_encoder_layers: int, 48 | num_decoder_layers: int, 49 | emb_size: int, 50 | nhead: int, 51 | src_vocab_size: int, 52 | tgt_vocab_size: int, 53 | dim_feedforward: int = 512, 54 | dropout: float = 0.1, 55 | max_seq: int = 64): 56 | super(DeepBeatsTransformer, self).__init__() 57 | encoder_norm = LayerNorm(emb_size) 58 | encoder_layer = TransformerEncoderLayerRPR(emb_size, nhead, dim_feedforward, dropout, max_seq) 59 | encoder = TransformerEncoderRPR(encoder_layer, num_encoder_layers, encoder_norm) 60 | 61 | decoder_norm = LayerNorm(emb_size) 62 | 
decoder_layer = TransformerDecoderLayerRPR(emb_size, nhead, dim_feedforward, dropout, max_seq) 63 | decoder = TransformerDecoderRPR(decoder_layer, num_decoder_layers, decoder_norm) 64 | 65 | self.transformer = Transformer(d_model=emb_size, 66 | nhead=nhead, 67 | num_encoder_layers=num_encoder_layers, 68 | num_decoder_layers=num_decoder_layers, 69 | dim_feedforward=dim_feedforward, 70 | dropout=dropout, 71 | custom_encoder=encoder, 72 | custom_decoder=decoder) 73 | self.num_notes = tgt_vocab_size 74 | self.emb_size = emb_size 75 | self.generator = nn.Linear(emb_size, tgt_vocab_size) 76 | self.src_tok_emb = nn.Linear(src_vocab_size, emb_size) 77 | self.dropout = nn.Dropout(dropout) 78 | self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size) 79 | self.positional_encoding = PositionalEncoding( 80 | emb_size, dropout=dropout) 81 | self.device = 'cpu' 82 | self._initialize() 83 | 84 | def _initialize(self): 85 | for p in self.parameters(): 86 | if p.dim() > 1: 87 | nn.init.xavier_uniform_(p) 88 | 89 | def to(self, device): 90 | super(DeepBeatsTransformer, self).to(device) 91 | self.device = device 92 | return self 93 | 94 | def forward(self, 95 | src: Tensor, 96 | trg: Tensor, 97 | src_mask: Tensor, 98 | tgt_mask: Tensor, 99 | src_padding_mask: Tensor, 100 | tgt_padding_mask: Tensor, 101 | memory_key_padding_mask: Tensor): 102 | src_emb = self.positional_encoding(self.src_tok_emb(src)) 103 | tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg)) 104 | 105 | outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, 106 | src_padding_mask, tgt_padding_mask, memory_key_padding_mask) 107 | return self.generator(outs) 108 | 109 | def encode(self, src: Tensor, src_mask: Tensor): 110 | src_emb = self.positional_encoding(self.src_tok_emb(src)) 111 | return self.transformer.encoder(src_emb, src_mask) 112 | 113 | def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor): 114 | tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt)) 115 | return self.transformer.decoder(tgt_emb, memory, tgt_mask) 116 | 117 | def loss_function(self, pred, target): 118 | """ 119 | Pred: (batch_size, seq_len, num_notes), logits 120 | Target: (batch_size, seq_len), range from 0 to num_notes-1 121 | """ 122 | criterion = nn.CrossEntropyLoss() 123 | target = target.flatten() # (batch_size * seq_len) 124 | pred = pred.reshape(-1, pred.shape[-1]) # (batch_size * seq_len, num_notes) 125 | loss = criterion(pred, target) 126 | return loss 127 | 128 | def clip_gradients_(self, max_value): 129 | torch.nn.utils.clip_grad.clip_grad_value_(self.parameters(), max_value) 130 | 131 | def create_mask(self, src, tgt): 132 | src_seq_len = src.shape[0] 133 | tgt_seq_len = tgt.shape[0] 134 | batch_size = src.shape[1] 135 | 136 | tgt_mask = self.generate_square_subsequent_mask(tgt_seq_len) 137 | src_mask = torch.zeros((src_seq_len, src_seq_len)).type(torch.bool) 138 | 139 | src_padding_mask = torch.zeros((batch_size, src_seq_len)).type(torch.bool)# we don't have padding in our src/tgt 140 | tgt_padding_mask = torch.zeros((batch_size, tgt_seq_len)).type(torch.bool) 141 | return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask 142 | -------------------------------------------------------------------------------- /models/vanilla_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from models.model_utils import ConcatPrev 5 | 6 | 7 | class DeepBeatsVanillaRNN(nn.Module): 8 | def __init__(self, num_notes, 
embed_size, hidden_dim): 9 | super(DeepBeatsVanillaRNN, self).__init__() 10 | self.num_notes = num_notes 11 | self.note_embedding = nn.Embedding(num_notes, embed_size) 12 | self.concat_prev = ConcatPrev() 13 | self.concat_input_fc = nn.Linear(embed_size + 2, embed_size + 2) 14 | self.concat_input_activation = nn.LeakyReLU() 15 | self.layer1 = nn.RNN(embed_size + 2, hidden_dim, batch_first=True) 16 | self.layer2 = nn.RNN(hidden_dim, hidden_dim, batch_first=True) 17 | self.notes_output = nn.Linear(hidden_dim, num_notes) 18 | 19 | self._initializer_weights() 20 | 21 | def _default_init_hidden(self, batch_size): 22 | device = next(self.parameters()).device 23 | h1_0 = torch.zeros(1, batch_size, self.layer1.hidden_size).to(device) 24 | h2_0 = torch.zeros(1, batch_size, self.layer2.hidden_size).to(device) 25 | return h1_0, h2_0 26 | 27 | def _initializer_weights(self): 28 | for m in self.modules(): 29 | if isinstance(m, nn.Linear): 30 | nn.init.xavier_uniform_(m.weight) 31 | nn.init.constant_(m.bias, 0) 32 | 33 | def forward(self, x, y_prev, init_hidden = None): 34 | h1_0, h2_0 = self._default_init_hidden(x.shape[0]) if init_hidden is None else init_hidden 35 | y_prev_embed = self.note_embedding(y_prev) 36 | X = self.concat_prev(x, y_prev_embed) 37 | # Concat input 38 | X_fc = self.concat_input_fc(X) 39 | X_fc = self.concat_input_activation(X_fc) 40 | # residual connection 41 | X = X_fc + X 42 | X, h1 = self.layer1(X, h1_0) 43 | X, h2 = self.layer2(X, h2_0) 44 | predicted_notes = self.notes_output(X) 45 | return predicted_notes, (h1, h2) 46 | 47 | def loss_function(self, pred, target): 48 | criterion = nn.CrossEntropyLoss() 49 | target = target.flatten() # (batch_size * seq_len) 50 | pred = pred.reshape(-1, pred.shape[-1]) # (batch_size * seq_len, num_notes) 51 | loss = criterion(pred, target) 52 | return loss 53 | 54 | def clip_gradients_(self, max_value): 55 | torch.nn.utils.clip_grad.clip_grad_value_(self.parameters(), max_value) 56 | -------------------------------------------------------------------------------- /predict_stream.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import preprocess.dataset 5 | import torch 6 | import toml 7 | from utils.constants import NOTE_MAP 8 | 9 | from utils.data_paths import DataPaths 10 | from utils.model import CONFIG_PATH, get_model, load_checkpoint 11 | from utils.render import render_midi 12 | from utils.sample import beam_search, stochastic_search 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser('Save Predicted Notes Sequence to Midi') 16 | parser.add_argument('-m','--model_name', type=str) 17 | parser.add_argument('-c','--checkpoint_path', type=str) 18 | parser.add_argument('-o','--midi_filename', type=str, default="output.mid") 19 | parser.add_argument('-d','--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 20 | parser.add_argument('-s','--source', type=str, default="interactive") 21 | parser.add_argument('-t','--profile', type=str, default="default") 22 | 23 | main_args = parser.parse_args() 24 | model_name = main_args.model_name 25 | checkpoint_path = main_args.checkpoint_path 26 | midi_filename = main_args.midi_filename 27 | device = main_args.device 28 | source = main_args.source 29 | profile = main_args.profile 30 | 31 | config = toml.load(CONFIG_PATH) 32 | global_config = config['global'] 33 | model_config = config["model"][main_args.model_name] 34 | 35 | paths = DataPaths() 36 | 37 | # sample one midi 
file 38 | if main_args.source == 'interactive': 39 | from utils.beats_generator import create_beat 40 | X = create_beat() 41 | X[0][0] = 2. 42 | # convert to float32 43 | X = np.array(X, dtype=np.float32) 44 | elif main_args.source == 'dataset': 45 | dataset = preprocess.dataset.BeatsRhythmsDataset(64) # not used 46 | dataset.load(global_config['dataset']) 47 | idx = np.random.randint(0, len(dataset)) 48 | X = dataset.beats_list[idx][:64] 49 | else: 50 | with open(main_args.source, 'rb') as f: 51 | X = np.load(f, allow_pickle=True) 52 | X[0][0] = 2. 53 | X = np.array(X, dtype=np.float32) 54 | 55 | # load model 56 | 57 | model = get_model(main_args.model_name, model_config, main_args.device) 58 | model.eval() 59 | load_checkpoint(checkpoint_path, model, device) 60 | print(model) 61 | 62 | # generate notes seq given durs seq 63 | profile = config["sampling"][profile] 64 | try: 65 | hint = [NOTE_MAP[h] for h in profile["hint"]] 66 | except KeyError: 67 | print(f"some note in {profile['hint']} not found in NOTE_MAP") 68 | exit(1) 69 | if profile["strategy"] == "stochastic": 70 | notes = stochastic_search(model, X, hint, device, profile["top_p"], profile["top_k"], profile["repeat_decay"], profile["temperature"]) 71 | elif profile["strategy"] == "beam": 72 | notes = beam_search(model, X, hint, device, profile["repeat_decay"], profile["num_beams"], profile["beam_prob"], profile["temperature"]) 73 | else: 74 | raise NotImplementedError(f"strategy {profile['strategy']} not implemented") 75 | print(notes) 76 | # convert stream to midi 77 | midi_paths = paths.midi_outputs_dir / main_args.midi_filename 78 | render_midi(X, notes, midi_paths) -------------------------------------------------------------------------------- /prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from preprocess.dataset import BeatsRhythmsDataset 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser('Prepare Dataset') 7 | parser.add_argument('--mono', type=bool, default=True) 8 | args = parser.parse_args() 9 | 10 | mono = args.mono 11 | dataset = BeatsRhythmsDataset(seq_len = 64) #seq_len = 64 is not used 12 | dataset.load(mono = mono, force_prepare = True) 13 | dataset.save_processed_to_cache() 14 | 15 | 16 | -------------------------------------------------------------------------------- /preprocess/constants.py: -------------------------------------------------------------------------------- 1 | DATASETS_CONFIG_PATH = "datasets.toml" 2 | ADL_PIANO_TOTAL_SIZE = 10913 -------------------------------------------------------------------------------- /preprocess/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from preprocess.constants import DATASETS_CONFIG_PATH 3 | from preprocess.prepare import download_midi_files, parse_melody_to_beats_notes, parse_midi_to_melody 4 | 5 | from preprocess.fetch import download 6 | 7 | import numpy as np 8 | import toml 9 | import warnings 10 | import pickle 11 | from tqdm import tqdm 12 | import csv 13 | from utils.constants import NOTE_START 14 | from utils.data_paths import DataPaths 15 | from dataclasses import dataclass 16 | 17 | PREPROCESS_SAVE_FREQ = 32 18 | 19 | @dataclass 20 | class MetaData: 21 | canonical_composer: str 22 | canonical_title: str 23 | split: str 24 | midi_filename: str 25 | 26 | def _processed_name(dataset: str, dataset_type:str): 27 | return f"processed_{dataset}_{dataset_type}_v2.pkl" 
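# Example (illustrative; the real dataset keys live in datasets.toml and config.toml): _processed_name("maestro", "mono") evaluates to "processed_maestro_mono_v2.pkl". BeatsRhythmsDataset.load() below looks this file up under DataPaths().prepared_data_dir (.project_data/prepared_data/) before falling back to a remotely prepared copy or re-running the MIDI preprocessing, and prepare_dataset.py writes it there via save_processed_to_cache().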
28 | 29 | class BeatsRhythmsDataset(Dataset): 30 | def __init__(self, seq_len, seed = 12345): 31 | self.seq_len = seq_len 32 | self.beats_list = [] 33 | self.notes_list = [] 34 | self.metadata_list = [] 35 | self.name_to_idx = {} 36 | self.rng = np.random.default_rng(seed) 37 | self.seed = seed 38 | self.dataset = "" 39 | self.dataset_type = "" 40 | 41 | def load(self, dataset = "mastero", mono=True, force_prepare = False): 42 | assert mono, "Only mono is supported for now" 43 | paths = DataPaths() 44 | dataset_type = "mono" if mono else "chords" 45 | processed_path = paths.prepared_data_dir / _processed_name(dataset, dataset_type) 46 | progress_path = paths.cache_dir / f"progress_{dataset}_{dataset_type}.pkl" 47 | self.dataset = dataset 48 | self.dataset_type = dataset_type 49 | 50 | ## Check if we have processed data 51 | ### Locally processed data 52 | if processed_path.exists() and not force_prepare: 53 | print(f"Found processed data at {processed_path}.") 54 | with open(processed_path, "rb") as f: 55 | state_dict = pickle.load(f) 56 | self.load_processed(state_dict) 57 | return 58 | ### Remotely processed data 59 | config = toml.load(DATASETS_CONFIG_PATH)["datasets"][dataset_type][dataset] 60 | if "processed" in config and not force_prepare: 61 | prepared = download(_processed_name(dataset, dataset_type), config["processed"]) 62 | if prepared is None: 63 | raise ValueError("Failed to download prepared dataset") 64 | with open(prepared, "rb") as f: 65 | state_dict = pickle.load(f) 66 | self.load_processed(state_dict) 67 | return 68 | 69 | ## Preprocessing 70 | midi_files, num_files = download_midi_files(dataset, config["midi"]) 71 | metadata_path = download(f"metadata_{dataset}.csv", config["metadata"]) 72 | assert metadata_path is not None, "Failed to download metadata" 73 | metadata = {} 74 | with open(metadata_path, "r", encoding="utf-8") as f: 75 | for row in csv.DictReader(f): 76 | metadata[row["midi_filename"]] = MetaData( 77 | canonical_composer=row["canonical_composer"], 78 | canonical_title=row["canonical_title"], 79 | split=row["split"], 80 | midi_filename=row["midi_filename"], 81 | ) 82 | 83 | skip = 0 84 | if progress_path.exists(): 85 | with open(progress_path, "rb") as f: 86 | state_dict = pickle.load(f) 87 | self.load_processed(state_dict) 88 | skip = len(self.metadata_list) 89 | print(f"Resuming from {skip} files") 90 | bar = tqdm(total=num_files, desc = "Processing MIDI files") 91 | warnings_cnt, errors_cnt, saved = 0, 0, 0 92 | for filename, io in midi_files: 93 | if skip > 0: 94 | skip -= 1 95 | bar.update(1) 96 | continue 97 | beats, notes = None, None 98 | with warnings.catch_warnings(): 99 | warnings.filterwarnings("error") 100 | try: 101 | melody, _ = parse_midi_to_melody(io) 102 | beats, notes = parse_melody_to_beats_notes(melody) 103 | except Warning: 104 | warnings_cnt += 1 105 | bar.set_description(f"Parsing MIDI files ({warnings_cnt} warns, {errors_cnt} errors)", refresh=True) 106 | except KeyboardInterrupt: 107 | self.save_processed_to_file(progress_path) 108 | print(f"KeyboardInterrupt detected, saving progress and exit") 109 | exit() 110 | except Exception: 111 | errors_cnt += 1 112 | bar.set_description(f"Parsing MIDI files ({warnings_cnt} warns, {errors_cnt} errors)", refresh=True) 113 | if "truncate" in config: 114 | filename = filename[len(config["truncate"]):] 115 | if beats is not None and notes is not None: 116 | self.beats_list.append(beats) 117 | self.notes_list.append(notes) 118 | self.metadata_list.append(metadata[filename]) 119 | 
self.name_to_idx[filename] = len(self.metadata_list) - 1 120 | bar.update(1) 121 | if len(self.metadata_list) % PREPROCESS_SAVE_FREQ == 0: 122 | self.save_processed_to_file(progress_path) 123 | saved = len(self.metadata_list) 124 | bar.set_postfix(warns=warnings_cnt, errors=errors_cnt, saved=saved) 125 | bar.close() 126 | 127 | 128 | def __len__(self): 129 | return len(self.metadata_list) 130 | 131 | def __getitem__(self, idx): 132 | lo = self.rng.integers(0, len(self.beats_list[idx]) - self.seq_len + 1) # + 1: upper bound is exclusive, so sequences of exactly seq_len notes are still valid 133 | hi = lo + self.seq_len 134 | 135 | beats = self.beats_list[idx][lo:hi] 136 | notes = self.notes_list[idx][lo:hi] 137 | # for teacher forcing, we need to shift the notes right by one 138 | notes_shifted = np.roll(notes, 1) 139 | if lo == 0: 140 | notes_shifted[0] = NOTE_START 141 | else: 142 | notes_shifted[0] = self.notes_list[idx][lo - 1] 143 | return { 144 | "beats": beats.astype(np.float32), 145 | "notes": notes.astype(np.int32).ravel(), 146 | "notes_shifted": notes_shifted.astype(np.int32).ravel(), 147 | } 148 | 149 | def save_processed(self) -> dict: 150 | return { 151 | "beats_list": self.beats_list, 152 | "notes_list": self.notes_list, 153 | "metadata_list": self.metadata_list, 154 | "name_to_idx": self.name_to_idx, 155 | } 156 | 157 | def save_processed_to_file(self, path): 158 | with open(path, "wb") as f: 159 | pickle.dump(self.save_processed(), f) 160 | 161 | def save_processed_to_cache(self): 162 | paths = DataPaths() 163 | processed_path = paths.prepared_data_dir / _processed_name(self.dataset, self.dataset_type) 164 | with open(processed_path, "wb") as f: 165 | pickle.dump(self.save_processed(), f) 166 | 167 | def load_processed(self, state_dict): 168 | self.beats_list = state_dict["beats_list"] 169 | self.notes_list = state_dict["notes_list"] 170 | self.metadata_list = state_dict["metadata_list"] 171 | self.name_to_idx = state_dict["name_to_idx"] 172 | 173 | def gather(self, indices): 174 | dataset = BeatsRhythmsDataset(self.seq_len, self.seed) 175 | dataset.beats_list = [self.beats_list[i] for i in indices] 176 | dataset.notes_list = [self.notes_list[i] for i in indices] 177 | dataset.metadata_list = [self.metadata_list[i] for i in indices] 178 | dataset.name_to_idx = {v.midi_filename: i for i, v in enumerate(dataset.metadata_list)} 179 | return dataset 180 | 181 | def subset_remove_short(self): 182 | """ 183 | Remove short sequences from the dataset 184 | """ 185 | indices = [i for i in range(len(self)) if len(self.beats_list[i]) >= self.seq_len] 186 | return self.gather(indices) 187 | 188 | def train_val_split(self, seed=0, val_ratio=0.1): 189 | rng = np.random.default_rng(seed) 190 | indices = np.arange(len(self.metadata_list)) 191 | rng.shuffle(indices) 192 | val_size = int(len(self.metadata_list) * val_ratio) 193 | train_indices = indices[val_size:] 194 | dev_indices = indices[:val_size] 195 | return self.gather(train_indices), self.gather(dev_indices) 196 | 197 | def subset(self, max_len): 198 | rng = np.random.default_rng(0) 199 | indices = np.arange(len(self.metadata_list)) 200 | rng.shuffle(indices) 201 | indices = indices[:max_len] 202 | return self.gather(indices) 203 | 204 | def to_stream(self, idx): 205 | from utils.render import convert_to_note_seq 206 | beats = self.beats_list[idx] 207 | notes = self.notes_list[idx] 208 | return convert_to_note_seq(beats, notes) 209 | 210 | def to_midi(self, idx, midi_path): 211 | from note_seq.midi_io import note_sequence_to_midi_file 212 | stream = self.to_stream(idx) 213 | note_sequence_to_midi_file(stream, midi_path)
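# Usage sketch (the dataset key and batch size are illustrative, not fixed values from the repository):
#   from torch.utils.data import DataLoader
#   dataset = BeatsRhythmsDataset(seq_len=64)
#   dataset.load("maestro")
#   loader = DataLoader(dataset.subset_remove_short(), batch_size=8, shuffle=True)
#   batch = next(iter(loader))
#   batch["beats"].shape          # (8, 64, 2) float32 -- [prev_rest, duration] per note
#   batch["notes"].shape          # (8, 64)    int32   -- target pitches
#   batch["notes_shifted"].shape  # (8, 64)    int32   -- targets shifted right by one for teacher forcing
# This mirrors how train.py builds its train/validation DataLoaders from subset_remove_short() and train_val_split().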
214 | 215 | 216 | # def collate_fn(batch): 217 | # X, y, y_prev = zip(*batch) 218 | # X = np.array(X) 219 | # y = np.array(y) 220 | # y_prev = np.array(y_prev) 221 | # return X, y, y_prev 222 | -------------------------------------------------------------------------------- /preprocess/fetch.py: -------------------------------------------------------------------------------- 1 | # Downloader for the midi files with cache 2 | import os 3 | from pathlib import Path 4 | from typing import Optional 5 | import requests 6 | from tqdm import tqdm 7 | 8 | from utils.data_paths import DataPaths 9 | 10 | def download(filename: str, url: str) -> Optional[Path]: 11 | """Download a zip file from a URL if it's not already in the cache. 12 | Return None if the file is not found online. 13 | """ 14 | paths = DataPaths() 15 | cache_path = paths.downloads_dir / filename 16 | if not os.path.exists(cache_path): 17 | with requests.get(url, stream=True) as r: 18 | if r.status_code == 404: 19 | return None 20 | total_size_in_bytes = int(r.headers.get('content-length', 0)) 21 | chunk_size = 1024 22 | with open(cache_path, "wb") as f: 23 | print(f"Downloading {filename} from {url}") 24 | progress = tqdm(total = total_size_in_bytes, unit = 'iB', unit_scale = True, colour="cyan") 25 | for chunk in r.iter_content(chunk_size=chunk_size): 26 | progress.update(len(chunk)) 27 | f.write(chunk) 28 | progress.close() 29 | else: 30 | print("Using cached: ", filename) 31 | return cache_path 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /preprocess/prepare.py: -------------------------------------------------------------------------------- 1 | from typing import IO, Iterable, Tuple 2 | import zipfile 3 | 4 | import numpy as np 5 | from preprocess.fetch import download 6 | 7 | def download_midi_files(dataset: str, midi_url: str): 8 | """Get an iterator over all MIDI files bytestreams. 9 | 10 | Returns: 11 | - `iterator`: An iterator over all MIDI files bytestreams. 12 | - `num_files`: The number of MIDI files. 13 | """ 14 | archive_path = download(f"{dataset}.zip", midi_url) 15 | if archive_path is None: 16 | raise RuntimeError("Failed to download the dataset.") 17 | # get number of midi files 18 | total = 0 19 | with zipfile.ZipFile(archive_path, 'r') as zip_ref: 20 | for info in zip_ref.infolist(): 21 | if info.filename.endswith(".mid") or info.filename.endswith(".midi"): 22 | total += 1 23 | # iterate over midi files 24 | def _iter(): 25 | remaining = total 26 | with zipfile.ZipFile(archive_path, 'r') as zip_ref: 27 | for info in zip_ref.infolist(): 28 | if info.filename.endswith(".mid") or info.filename.endswith(".midi"): 29 | yield info.filename, zip_ref.open(info) 30 | remaining -= 1 31 | if remaining == 0: 32 | return 33 | return _iter(), total 34 | 35 | def parse_midi_to_melody(midi_file: IO[bytes]): 36 | """ 37 | Parse a MIDI file into a melody. 38 | Args: 39 | midi_file: A MIDI file bytestream. 40 | Returns: 41 | Generator as described below, number of notes 42 | yield: 43 | Tuple[start_time, end_time, pitch]. 
44 | """ 45 | import note_seq.midi_io as midi_io 46 | import note_seq.melody_inference as melody_inference 47 | melody_inference.MAX_NUM_FRAMES = 100000 48 | ns = midi_io.midi_to_note_sequence(midi_file.read()) 49 | with np.errstate(divide='ignore'): 50 | instrument_id = melody_inference.infer_melody_for_sequence(ns) 51 | def _gen(): 52 | for note in ns.notes: 53 | if note.instrument == instrument_id: 54 | yield note.start_time, note.end_time, note.pitch 55 | return _gen(), len(ns.notes) 56 | 57 | def convert_start_end_to_beats(start_time: np.ndarray, end_time: np.ndarray): 58 | """ 59 | Convert start time and end time to beats. 60 | Args: 61 | start_time: array of shape (seq_length,) 62 | end_time: array of shape (seq_length,) 63 | Returns: 64 | beats: array of shape (seq_length, 2), where the first column is the rest time before current note and the second column is the current duration 65 | """ 66 | # get the rest time since last beat 67 | prev_rest = np.zeros_like(start_time) 68 | prev_rest[1:] = start_time[1:] - end_time[:-1] 69 | prev_rest[0] = start_time[0] 70 | 71 | # get the duration of the note 72 | duration = end_time - start_time 73 | 74 | return np.stack([prev_rest, duration], axis=1) 75 | 76 | def parse_melody_to_beats_notes(melody: Iterable[Tuple[float, float, int]]) -> Tuple[np.ndarray, np.ndarray]: 77 | """ 78 | Parse a MIDI file into a sequence of (prev_rest_time, duration). 79 | Args: 80 | midi_file: A MIDI file bytestream. 81 | mono: Whether to convert the MIDI file to monophonic. 82 | Returns: 83 | - A numpy array of shape (num_notes, 2). Each row represents a beat. The first column 84 | is the rest time since last beat, and the second column is the duration of the note. 85 | - A numpy array of shape (num_notes, 1). Each row represents a beat. The column 86 | is the MIDI note number. 
87 | """ 88 | 89 | start_time = [] 90 | end_time = [] 91 | pitch = [] 92 | for s, e, p in melody: 93 | start_time.append(s) 94 | end_time.append(e) 95 | pitch.append(p) 96 | 97 | start_time = np.array(start_time) 98 | end_time = np.array(end_time) 99 | pitch = np.array(pitch) 100 | 101 | beats = convert_start_end_to_beats(start_time, end_time) 102 | labels = pitch.reshape(-1, 1) 103 | 104 | return beats, labels 105 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu116 2 | ipython 3 | numpy 4 | scipy 5 | torch 6 | matplotlib 7 | toml 8 | tqdm 9 | pynput 10 | tensorboard 11 | note_seq==0.0.5 12 | protobuf==4.21.2 -------------------------------------------------------------------------------- /setup_env.ps1: -------------------------------------------------------------------------------- 1 | $env:PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION="python" -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsunrise/everybody-compose/632fc77c05cae8cb61a1d81bb85e1717b3e363a4/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_codec.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from preprocess.prepare import parse_melody_to_beats_notes 3 | from utils.render import convert_to_melody 4 | import numpy as np 5 | class Tests(unittest.TestCase): 6 | MELODY = [(1.2, 3.6, 60), (4.8, 7.9, 61), (7.9, 8.2, 62), (8.2, 8.3, 63)] 7 | BEATS = np.array([ 8 | [1.2, 3.6-1.2], 9 | [4.8-3.6, 7.9-4.8], 10 | [7.9-7.9, 8.2-7.9], 11 | [8.2-8.2, 8.3-8.2], 12 | ]) 13 | 14 | def test_parse_melody_to_beats_notes(self): 15 | beats, notes = parse_melody_to_beats_notes(self.MELODY) 16 | 17 | beats_expected = self.BEATS 18 | notes_expected = np.array([60, 61, 62, 63]).reshape(-1, 1) 19 | 20 | self.assertTrue(np.allclose(beats, beats_expected)) 21 | self.assertTrue(np.allclose(notes, notes_expected)) 22 | def test_render(self): 23 | start_time, end_time, note = convert_to_melody(self.BEATS, np.array([x[2] for x in self.MELODY])) 24 | self.assertEqual(len(start_time), len(end_time)) 25 | self.assertEqual(len(start_time), len(note)) 26 | self.assertEqual(len(start_time), len(self.MELODY)) 27 | for i in range(len(self.MELODY)): 28 | self.assertAlmostEqual(start_time[i], self.MELODY[i][0]) 29 | self.assertAlmostEqual(end_time[i], self.MELODY[i][1]) 30 | self.assertAlmostEqual(note[i], self.MELODY[i][2]) 31 | 32 | 33 | if __name__ == '__main__': 34 | unittest.main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils.model import train 3 | import torch 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser('Train DeepBeats') 7 | parser.add_argument('-m','--model_name', type=str) 8 | parser.add_argument('-nf','--n_files', type=int, default=-1) 9 | 
parser.add_argument('-n','--n_epochs', type=int, default=100) 10 | parser.add_argument('-d','--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 11 | parser.add_argument('-s','--snapshots_freq', type=int, default=200) 12 | parser.add_argument('-c','--checkpoint', type=str, default=None) 13 | args = parser.parse_args() 14 | 15 | main_args = parser.parse_args() 16 | train(**vars(main_args)) 17 | -------------------------------------------------------------------------------- /utils/beats_generator.py: -------------------------------------------------------------------------------- 1 | from pynput import keyboard 2 | import numpy as np 3 | import time 4 | 5 | from utils.data_paths import DataPaths 6 | from preprocess.prepare import convert_start_end_to_beats 7 | from enum import Enum 8 | from sys import platform 9 | 10 | class Event(Enum): 11 | PRESS = 1 12 | RELEASE = 0 13 | 14 | def create_beat(): 15 | ''' 16 | Parse user's key presses into a sequence of beats. 17 | Record user's space key pressing time until the enter key is pressed 18 | Return a numpy 2d array in shape of (seq_length, 2) representing beat sequence that user enters. 19 | each row of the array is a beat represented by [prev_rest_time, duration] 20 | ''' 21 | 22 | TAP_KEYS = {keyboard.KeyCode.from_char('z'), keyboard.KeyCode.from_char('x'), keyboard.Key.space} 23 | ENTER_KEY = keyboard.Key.enter 24 | 25 | events = [] 26 | pressing_key = None 27 | base_time = time.time() 28 | def on_press(key): 29 | ''' 30 | listener that monitor presses of the space key 31 | record the previous rest time until the space key is pressed 32 | ''' 33 | nonlocal events, pressing_key 34 | if key in TAP_KEYS and key != pressing_key: 35 | curr_time = time.time() - base_time 36 | if pressing_key is not None: 37 | events.append((Event.RELEASE, curr_time)) 38 | events.append((Event.PRESS, curr_time)) 39 | pressing_key = key 40 | 41 | def on_release(key): 42 | ''' 43 | listener that monitor release of the space key and enter key 44 | record the pressed time on the space key 45 | stop the listener when the enter key is released 46 | ''' 47 | nonlocal events, pressing_key 48 | if key == ENTER_KEY: 49 | # Stop listener 50 | curr_time = time.time() - base_time 51 | if pressing_key is not None: 52 | events.append((Event.RELEASE, curr_time)) 53 | return False 54 | elif key in TAP_KEYS and key == pressing_key: 55 | events.append((Event.RELEASE, time.time() - base_time)) 56 | pressing_key = None 57 | 58 | 59 | print("use z,x,space key on keyboard to create a sequence of beat") 60 | print("hit enter to stop") 61 | suppress = True if platform == "win32" else False 62 | with keyboard.Listener(on_press=on_press, on_release=on_release, suppress=suppress) as listener: 63 | listener.join() 64 | 65 | # convert events to start_time and end_time 66 | start_time = [] 67 | end_time = [] 68 | num_pressed = 0 69 | for event, timestamp in events: 70 | if event == Event.PRESS: 71 | start_time.append(timestamp) 72 | if num_pressed > 0: 73 | end_time.append(timestamp) 74 | num_pressed += 1 75 | else: 76 | if num_pressed == 1: 77 | end_time.append(timestamp) 78 | num_pressed -= 1 79 | assert len(start_time) == len(end_time) 80 | assert num_pressed == 0 81 | # print("start_time: ", start_time) 82 | # print("end_time: ", end_time) 83 | beat_sequence = convert_start_end_to_beats(np.array(start_time), np.array(end_time)) 84 | 85 | paths = DataPaths() 86 | file_name = "last_recorded.npy" 87 | file_path = paths.beats_rhythms_dir / file_name 88 | 
np.save(file_path, beat_sequence) 89 | 90 | return beat_sequence -------------------------------------------------------------------------------- /utils/constants.py: -------------------------------------------------------------------------------- 1 | # the place-holder for the first value of tgt sequence, indicating the start of the sequence. 2 | NOTE_START = 0 3 | 4 | NOTE_MAP = { 5 | "1": 60, 6 | "1#": 61, 7 | "2": 62, 8 | "2#": 63, 9 | "3b": 63, 10 | "3": 64, 11 | "4": 65, 12 | "4#": 66, 13 | "5": 67, 14 | "5#": 68, 15 | "6b": 68, 16 | "6": 69, 17 | "6#": 70, 18 | "7b": 70, 19 | "7": 71, 20 | "8": 72, 21 | } -------------------------------------------------------------------------------- /utils/data_paths.py: -------------------------------------------------------------------------------- 1 | """ 2 | Project Tree Structure 3 | .project_data 4 | | downloads/ 5 | | snapshots/ 6 | | prepared_data/ 7 | | midi_outputs/ 8 | | tensorboard/ 9 | """ 10 | from pathlib import Path 11 | 12 | class DataPaths: 13 | def __init__(self): 14 | self.cache_dir = Path(".project_data") 15 | self.cache_dir.mkdir(exist_ok=True) 16 | self.downloads_dir = self.cache_dir / "downloads" 17 | self.downloads_dir.mkdir(exist_ok=True) 18 | self.snapshots_dir = self.cache_dir / "snapshots" 19 | self.snapshots_dir.mkdir(exist_ok=True) 20 | self.prepared_data_dir = self.cache_dir / "prepared_data" 21 | self.prepared_data_dir.mkdir(exist_ok=True) 22 | self.midi_outputs_dir = self.cache_dir / "midi_outputs" 23 | self.midi_outputs_dir.mkdir(exist_ok=True) 24 | self.tensorboard_dir = self.cache_dir / "tensorboard" 25 | self.tensorboard_dir.mkdir(exist_ok=True) 26 | self.beats_rhythms_dir = self.cache_dir / "beats_rhythms" 27 | self.beats_rhythms_dir.mkdir(exist_ok=True) -------------------------------------------------------------------------------- /utils/devices.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def status_check(): 3 | cuda_available = torch.cuda.is_available() 4 | print("CUDA available: {}".format(cuda_available)) 5 | if not cuda_available: 6 | return 7 | cuda_device_count = torch.cuda.device_count() 8 | print("CUDA device count: {}".format(cuda_device_count)) 9 | cuda_device_name = torch.cuda.get_device_name(0) 10 | print("CUDA device name: {}".format(cuda_device_name)) -------------------------------------------------------------------------------- /utils/distribution.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Tuple 3 | import torch 4 | from models.lstm_local_attn import DeepBeatsLSTMLocalAttn 5 | from models.attention_rnn import DeepBeatsAttentionRNN 6 | from models.transformer import DeepBeatsTransformer 7 | from models.vanilla_rnn import DeepBeatsVanillaRNN 8 | 9 | from utils.constants import NOTE_START 10 | 11 | 12 | class DistributionGenerator(ABC): 13 | @abstractmethod 14 | def __init__(self): 15 | """ 16 | Implementor uses this method to store some constants during sampling process. 17 | """ 18 | pass 19 | 20 | @abstractmethod 21 | def initial_state(self, hint: List[int]) -> dict: 22 | """ 23 | Get initial state of the model for sampling, given that the initial sequence is `hint`. 24 | Return the state after hint. 
25 | """ 26 | pass 27 | 28 | @abstractmethod 29 | def proceed(self, state: dict, prev_note: int) -> Tuple[dict, torch.Tensor]: 30 | """ 31 | - `state`: a dictionary containing the state of the machine 32 | - `sampled_sequence`: a tensor of shape (seq_len, ), containing the sampled sequence 33 | Returns: 34 | - `state`: a dictionary containing the updated state of the machine 35 | - `distribution`: a tensor of shape (n_notes, ), containing the distribution of the next note 36 | """ 37 | pass 38 | 39 | 40 | 41 | class LocalAttnLSTMDistribution(DistributionGenerator): 42 | def __init__(self, model: DeepBeatsLSTMLocalAttn, x, device): 43 | """ 44 | - `x` is the input sequence, shape: (seq_len, 2) 45 | """ 46 | self.model = model 47 | self.device = device 48 | self.x = x 49 | self.context, self.encoder_state = self.model.encoder(x) 50 | 51 | def initial_state(self, hint: List[int]) -> dict: 52 | super().initial_state(hint) 53 | hint_shifted = [NOTE_START] + hint[:-1] 54 | state = { 55 | "position": 0, 56 | "memory": (self.encoder_state[0].reshape(1, 1, -1), self.encoder_state[1].reshape(1, 1, -1)), 57 | } 58 | for i in range(len(hint)): 59 | state, _ = self.proceed(state, hint_shifted[i]) 60 | return state 61 | 62 | def proceed(self, state: dict, prev_note: int) -> Tuple[dict, torch.Tensor]: 63 | super().proceed(state, prev_note) 64 | position = state["position"] 65 | memory = state["memory"] 66 | context_curr = self.context[position].reshape(1, 1, -1) 67 | 68 | y_prev = torch.tensor(prev_note).reshape(1, 1).to(self.device) 69 | scores, memory = self.model.decoder.forward(y_prev, context_curr, memory) 70 | 71 | scores = scores.squeeze(0) 72 | scores = torch.nn.functional.softmax(scores, dim=1) 73 | scores = scores.squeeze(0) 74 | return {"position": position + 1, "memory": memory}, scores 75 | 76 | class AttentionRNNDistribution(DistributionGenerator): 77 | def __init__(self, model: DeepBeatsAttentionRNN, x, device): 78 | """ 79 | - `x` is the input sequence, shape: (seq_len, 2) 80 | """ 81 | self.model = model 82 | self.device = device 83 | self.x = x 84 | self.encoder_output, _ = self.model.encoder(x) 85 | self.encoder_output = self.encoder_output.unsqueeze(0) 86 | 87 | def initial_state(self, hint: List[int]) -> dict: 88 | super().initial_state(hint) 89 | state = { 90 | "position": 0, 91 | "memory": None, 92 | } 93 | hint_shifted = [NOTE_START] + hint[:-1] 94 | for i in range(len(hint)): 95 | state, _ = self.proceed(state, hint_shifted[i]) 96 | return state 97 | 98 | def proceed(self, state: dict, prev_note: int) -> Tuple[dict, torch.Tensor]: 99 | super().proceed(state, prev_note) 100 | position = state["position"] 101 | memory = state["memory"] 102 | y_prev = torch.tensor(prev_note).reshape(1, 1).to(self.device) 103 | scores, memory = self.model.decoder.forward(y_prev, self.encoder_output, memory) 104 | scores = scores.squeeze(0) 105 | scores = torch.nn.functional.softmax(scores, dim=1) 106 | scores = scores.squeeze(0) 107 | return {"position": position + 1, "memory": memory}, scores 108 | 109 | class TransformerDistribution(DistributionGenerator): 110 | 111 | def __init__(self, model: DeepBeatsTransformer, x, device): 112 | x = x.unsqueeze(1) # (seq_len, 1, 2) 113 | self.model = model.to(device) 114 | self.device = device 115 | self.x_mask = (torch.zeros(x.shape[0], x.shape[0])).type(torch.bool).to(device) 116 | self.x = x.to(device) 117 | self.memory = self.model.encode(self.x, self.x_mask).to(device) 118 | self.max_seq = x.shape[0] 119 | 120 | def initial_state(self, hint: List[int]) -> 
dict: 121 | super().initial_state(hint) 122 | hint_shifted = [NOTE_START] + hint[:-1] 123 | ys = torch.tensor(hint_shifted).reshape(1, -1).permute(1, 0).to(self.device) 124 | return { 125 | "ys": ys, 126 | } 127 | 128 | def proceed(self, state: dict, prev_note: int) -> Tuple[dict, torch.Tensor]: 129 | super().proceed(state, prev_note) 130 | ys = state["ys"] 131 | ys = torch.cat([ys, torch.ones(1, 1).type_as(self.x.data).fill_(prev_note)], dim=0) 132 | curr_i = ys.shape[0] 133 | filled_ys = torch.cat([ys, torch.ones(self.max_seq - curr_i, 1).type_as(self.x.data).fill_(0)]) # fill max_seq 134 | tgt_mask = (self.model.generate_square_subsequent_mask(filled_ys.shape[0]) 135 | .type(torch.bool)).to(self.device) 136 | out = self.model.decode(filled_ys, self.memory, tgt_mask) # max_seq * 1 * 128 137 | out = out.transpose(0, 1) # 1 * max_seq * 128 138 | scores = self.model.generator(out[:, curr_i - 1]) # 1 * num_notes, we only care about the current one 139 | scores = torch.nn.functional.softmax(scores, dim=1) 140 | scores = scores.transpose(0, 1).squeeze(1) 141 | return {"ys": ys}, scores 142 | 143 | class VanillaRNNDistribution(DistributionGenerator): 144 | def __init__(self, model: DeepBeatsVanillaRNN, x, device): 145 | """ 146 | - `x` is the input sequence, shape: (seq_len, 2) 147 | """ 148 | self.model = model 149 | self.device = device 150 | self.x = x 151 | 152 | def initial_state(self, hint: List[int]) -> dict: 153 | super().initial_state(hint) 154 | state = { 155 | "position": 0, 156 | "hidden": self.model._default_init_hidden(1), 157 | } 158 | hint_shifted = [NOTE_START] + hint[:-1] 159 | for i in range(len(hint)): 160 | state, _ = self.proceed(state, hint_shifted[i]) 161 | return state 162 | 163 | def proceed(self, state: dict, prev_note: int) -> Tuple[dict, torch.Tensor]: 164 | super().proceed(state, prev_note) 165 | position = state["position"] 166 | hidden = state["hidden"] 167 | x_curr = self.x[position].reshape(1, 1, 2) 168 | y_prev = torch.tensor(prev_note).reshape(1, 1).to(self.device) 169 | scores, hidden = self.model.forward(x_curr, y_prev, hidden) 170 | scores = scores.squeeze(0) 171 | scores = torch.nn.functional.softmax(scores, dim=1) 172 | scores = scores.squeeze(0) 173 | return {"position": position + 1, "hidden": hidden}, scores 174 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.tensorboard.writer import SummaryWriter 4 | def accuracy(y_pred: np.ndarray, y_true: np.ndarray) -> float: 5 | assert isinstance(y_pred, np.ndarray) 6 | assert isinstance(y_true, np.ndarray) 7 | assert y_pred.shape == y_true.shape 8 | return (y_pred == y_true).sum()/ len(y_true) 9 | 10 | class Metrics: 11 | def __init__(self, label: str): 12 | self.label = label 13 | self.metrics_sum = { 14 | "loss": 0., 15 | "accuracy": 0. 
16 | } 17 | self.sample_count = 0 18 | 19 | def update(self, batch_size: int, loss: float, y_pred_one_hot: torch.Tensor, y_true: torch.Tensor): 20 | self.metrics_sum["loss"] += loss * batch_size 21 | 22 | y_pred = torch.argmax(y_pred_one_hot, dim=2).cpu().numpy().flatten() 23 | y_true = y_true.cpu().numpy().astype(float).flatten() 24 | 25 | self.metrics_sum["accuracy"] += accuracy(y_pred, y_true) * batch_size 26 | self.sample_count += batch_size 27 | 28 | def flush_and_reset(self, writer: SummaryWriter, global_step: int): 29 | for metric, value in self.metrics_sum.items(): 30 | writer.add_scalar(f"{self.label}/{metric}", value / self.sample_count, global_step) 31 | summary = {metric: value / self.sample_count for metric, value in self.metrics_sum.items()} 32 | self.metrics_sum = {metric: 0 for metric in self.metrics_sum} 33 | self.sample_count = 0 34 | return summary 35 | 36 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import numpy as np 3 | import torch 4 | from models import transformer, vanilla_rnn, attention_rnn 5 | import toml 6 | from preprocess.dataset import BeatsRhythmsDataset 7 | import torch.utils.data 8 | import datetime 9 | from torch.utils.tensorboard.writer import SummaryWriter 10 | 11 | from utils.data_paths import DataPaths 12 | import models.lstm_local_attn as lstm_local_attn 13 | from utils.metrics import Metrics 14 | CONFIG_PATH = "./config.toml" 15 | 16 | def get_model(name, config, device): 17 | if name == "lstm_attn": 18 | return lstm_local_attn.DeepBeatsLSTMLocalAttn(num_notes=config["n_notes"], hidden_dim=config["hidden_dim"], 19 | dropout_p=config["dropout_p"]).to(device) 20 | elif name == "vanilla_rnn": 21 | return vanilla_rnn.DeepBeatsVanillaRNN(config["n_notes"], config["embed_dim"], config["hidden_dim"]).to(device) 22 | elif name == "attention_rnn": 23 | return attention_rnn.DeepBeatsAttentionRNN(config["n_notes"], config["embed_dim"], config["encode_hidden_dim"], config["decode_hidden_dim"]).to(device) 24 | elif name == "transformer": 25 | return transformer.DeepBeatsTransformer( 26 | num_encoder_layers=config["num_encoder_layers"], 27 | num_decoder_layers=config["num_encoder_layers"], 28 | emb_size=config["embed_dim"], 29 | nhead= config["num_heads"], 30 | src_vocab_size=config["src_vocab_size"], 31 | tgt_vocab_size=config["n_notes"], 32 | dim_feedforward=config["hidden_dim"] 33 | ).to(device) 34 | else: 35 | raise ValueError("Invalid model name") 36 | 37 | def model_forward(model_name, model, input_seq: torch.Tensor, target_seq: torch.Tensor, target_prev_seq: torch.Tensor, device): 38 | if model_name == "transformer": 39 | # nn.Transformer takes seq_len * batch_size 40 | input_seq, target_seq, target_prev_seq = input_seq.permute(1, 0, 2), target_seq.permute(1, 0), target_prev_seq.permute(1, 0) 41 | src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = model.create_mask(input_seq, target_prev_seq) 42 | src_mask, tgt_mask = src_mask.to(device), tgt_mask.to(device) 43 | src_padding_mask, tgt_padding_mask = src_padding_mask.to(device), tgt_padding_mask.to(device) 44 | output = model(input_seq, target_prev_seq, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask) 45 | output = output.permute(1, 0, 2) # permute back to batch first 46 | elif model_name == "attention_rnn" or model_name == "lstm_attn": 47 | output = model(input_seq, target_prev_seq) 48 | else: 49 | output, _ 
= model(input_seq, target_prev_seq) 50 | return output 51 | 52 | def model_file_name(model_name, n_files, n_epochs): 53 | return "{}_{}_{}.pth".format(model_name, "all" if n_files == -1 else n_files, n_epochs) 54 | 55 | def save_checkpoint(model, paths, model_name, n_files, n_epochs): 56 | model_file = model_file_name(model_name, n_files, n_epochs) 57 | model_path = paths.snapshots_dir / model_file 58 | torch.save({ 59 | 'model': model.state_dict(), 60 | 'n_epochs': n_epochs, 61 | }, model_path) 62 | print(f'Checkpoint Saved at {model_path}') 63 | 64 | def load_checkpoint(checkpoint_path, model, device): 65 | checkpoint = torch.load(checkpoint_path, map_location=device) 66 | model.load_state_dict(checkpoint['model']) 67 | n_epochs = checkpoint['n_epochs'] 68 | print(f'Checkpoint Loaded from {checkpoint_path}') 69 | return n_epochs 70 | 71 | 72 | def train(model_name: str, n_epochs: int, device: str, n_files:int=-1, snapshots_freq:int=10, checkpoint: Optional[str] = None): 73 | config = toml.load(CONFIG_PATH) 74 | 75 | global_config = config["global"] 76 | model_config = config["model"][model_name] 77 | 78 | model = get_model(model_name, model_config, device) 79 | print(model) 80 | 81 | dataset = BeatsRhythmsDataset(model_config["seq_len"], global_config["random_slice_seed"]) 82 | dataset.load(global_config["dataset"]) 83 | dataset = dataset.subset_remove_short() 84 | if n_files > 0: 85 | dataset = dataset.subset(n_files) 86 | 87 | training_data, val_data = dataset.train_val_split(global_config["train_val_split_seed"], global_config["val_ratio"]) 88 | print(f"Training data: {len(training_data)}") 89 | print(f"Validation data: {len(val_data)}") 90 | 91 | train_loader = torch.utils.data.DataLoader(training_data, batch_size=model_config["batch_size"], shuffle=True) 92 | val_loader = torch.utils.data.DataLoader(val_data, batch_size=model_config["batch_size"], shuffle=False) 93 | 94 | # checkpoint 95 | if checkpoint is not None: 96 | epochs_start = load_checkpoint(checkpoint, model, device) 97 | else: 98 | epochs_start = 0 99 | 100 | optimizer = torch.optim.Adam(model.parameters(), lr=model_config["lr"]) 101 | # TODO: we can use a learning rate scheduler here 102 | paths = DataPaths() 103 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M") 104 | log_dir = paths.tensorboard_dir / "{}_{}/{}".format(model_name, "all" if n_files == -1 else n_files, current_time) 105 | writer = SummaryWriter(log_dir = log_dir, flush_secs= 60) 106 | writer.add_text("config", toml.dumps(model_config)) 107 | 108 | best_val_loss = float("inf") 109 | metrics_train = Metrics("train") 110 | metrics_val = Metrics("validation") 111 | 112 | for epoch in range(epochs_start, n_epochs): 113 | model.train() 114 | for batch in train_loader: 115 | optimizer.zero_grad() 116 | input_seq = batch["beats"].to(device) 117 | target_seq = batch["notes"].long().to(device) 118 | target_prev_seq = batch["notes_shifted"].long().to(device) 119 | output = model_forward(model_name, model, input_seq, target_seq, target_prev_seq, device) 120 | loss = model.loss_function(output, target_seq) 121 | loss.backward() 122 | if "clip_grad" in model_config: 123 | model.clip_gradients_(model_config["clip_grad"]) # type: ignore 124 | optimizer.step() 125 | metrics_train.update(len(batch), loss.item(), output, target_seq) 126 | 127 | model.eval() 128 | for batch in val_loader: 129 | input_seq = batch["beats"].to(device) 130 | target_seq = batch["notes"].long().to(device) 131 | target_prev_seq = batch["notes_shifted"].long().to(device) 132 | with 
torch.no_grad(): 133 | output = model_forward(model_name, model, input_seq, target_seq, target_prev_seq, device) 134 | loss = model.loss_function(output, target_seq) 135 | metrics_val.update(len(batch), loss.item(), output, target_seq) 136 | 137 | training_metrics = metrics_train.flush_and_reset(writer, epoch) 138 | validation_metrics = metrics_val.flush_and_reset(writer, epoch) 139 | 140 | print('Epoch: {}/{}.............'.format(epoch, n_epochs), end=' ') 141 | print("Train Loss: {:.4f}, Val Loss: {:.4f}, Train Acc: {:.4f}, Val Acc: {:.4f}".format(training_metrics["loss"], validation_metrics["loss"], training_metrics["accuracy"], validation_metrics["accuracy"])) 142 | 143 | # save checkpoint with lowest validation loss 144 | if validation_metrics["loss"] < best_val_loss: 145 | best_val_loss = validation_metrics["loss"] 146 | save_checkpoint(model, paths, model_name, n_files, "best") 147 | print("Minimum Validation Loss of {:.4f} at epoch {}/{}".format(best_val_loss, epoch, n_epochs)) 148 | 149 | # save snapshots 150 | if (epoch + 1) % snapshots_freq == 0: 151 | save_checkpoint(model, paths, model_name, n_files, epoch + 1) 152 | writer.close() 153 | save_checkpoint(model, paths, model_name, n_files, n_epochs) 154 | return model 155 | -------------------------------------------------------------------------------- /utils/render.py: -------------------------------------------------------------------------------- 1 | from note_seq.protobuf.music_pb2 import NoteSequence 2 | import note_seq.midi_io as midi_io 3 | import numpy as np 4 | import torch 5 | def convert_to_melody(beats, notes): 6 | """ 7 | Convert beats and notes to melody. 8 | - `beats`: array of shape (seq_length, 2), where the first column is the rest time before current note and the second column is the current duration 9 | - `notes`: array of shape (seq_length,) 10 | Returns: Melody Array 11 | - `start_time`: array of shape (seq_length,) 12 | - `end_time`: array of shape (seq_length,) 13 | - `pitch`: array of shape (seq_length,) 14 | """ 15 | if isinstance(beats, torch.Tensor): 16 | beats = beats.cpu().numpy() 17 | if isinstance(notes, torch.Tensor): 18 | notes = notes.cpu().numpy() 19 | 20 | num_notes = beats.shape[0] 21 | notes = notes.reshape(-1) 22 | # get the start time of each note 23 | start_time = np.zeros(num_notes) 24 | start_time[1:] = np.cumsum(np.sum(beats[:num_notes-1, :], axis=1)) 25 | start_time = start_time + beats[:, 0] 26 | 27 | end_time = start_time + beats[:, 1] 28 | pitch = notes 29 | 30 | return start_time, end_time, pitch 31 | 32 | def convert_to_note_seq(beats, notes): 33 | """ 34 | Convert beats and notes to note_seq. 35 | - `beats`: array of shape (seq_length, 2), where the first column is the rest time before current note and the second column is the current duration 36 | - `notes`: array of shape (seq_length,) 37 | Returns: note_seq 38 | """ 39 | start_time, end_time, pitch = convert_to_melody(beats, notes) 40 | seq = NoteSequence() 41 | seq.tempos.add().qpm = 120 # tempos is irrelevant here 42 | seq.total_time = end_time[-1] 43 | for i in range(len(start_time)): 44 | seq.notes.add(start_time=start_time[i], end_time=end_time[i], pitch=pitch[i], velocity=80) # velocity is irrelevant here 45 | return seq 46 | 47 | def render_midi(beats, notes, midi_path): 48 | """ 49 | Render beats and notes to MIDI file. 
50 | - `beats`: array of shape (seq_length, 2), where the first column is the rest time before current note and the second column is the current duration 51 | - `notes`: array of shape (seq_length,) 52 | - `midi_path`: path to save the MIDI file 53 | """ 54 | seq = convert_to_note_seq(beats, notes) 55 | midi_io.note_sequence_to_midi_file(seq, midi_path) -------------------------------------------------------------------------------- /utils/sample.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import numpy as np 3 | import torch 4 | 5 | from models.transformer import DeepBeatsTransformer 6 | from models.attention_rnn import DeepBeatsAttentionRNN 7 | from models.lstm_local_attn import DeepBeatsLSTMLocalAttn 8 | from models.vanilla_rnn import DeepBeatsVanillaRNN 9 | from utils.distribution import DistributionGenerator, TransformerDistribution, LocalAttnLSTMDistribution, AttentionRNNDistribution, VanillaRNNDistribution 10 | from tqdm import tqdm 11 | 12 | def get_distribution_generator(model, beats, device) -> DistributionGenerator: 13 | """ 14 | - `model`: The model to use for sampling 15 | - `beats`: a numpy array of shape (seq_len, 2), containing the beats 16 | - `device`: the device to use 17 | """ 18 | beats = torch.from_numpy(beats).float().to(device) 19 | if isinstance(model, DeepBeatsTransformer): 20 | return TransformerDistribution(model, beats, device) 21 | elif isinstance(model, DeepBeatsAttentionRNN): 22 | return AttentionRNNDistribution(model, beats, device) 23 | elif isinstance(model, DeepBeatsLSTMLocalAttn): 24 | return LocalAttnLSTMDistribution(model, beats, device) 25 | elif isinstance(model, DeepBeatsVanillaRNN): 26 | return VanillaRNNDistribution(model, beats, device) 27 | else: 28 | raise NotImplementedError("Sampling is not implemented for this model") 29 | 30 | def stochastic_step(prev_note: int, distribution: torch.Tensor, top_p: float = 0.9, top_k: int=4, repeat_decay: float = 0.5, temperature = 1.) -> Tuple[int, float]: 31 | """ 32 | - `distribution`: a tensor of shape (n_notes, ), containing the conditional distribution of the next note 33 | - `top_p`: sample only the top p% of the distribution 34 | - `top_k`: sample only the top k notes of the distribution 35 | - `repeat_decay`: penalty on repeating the same note. Each time the same note is repeated, the probability of repeating it is multiplied by `1 - repeat_decay`. 36 | the probability of getting N repeats is upper bounded by `(1 - repeat_decay) ** N` 37 | - `temperature`: temperature of the distribution. Lower temperature gives more confidence to the most probable notes, higher temperature gives a more uniform distribution. 38 | 39 | Returns: 40 | - `sampled_note`: an integer representing the sampled note 41 | - `conditional_likelihood`: the conditional likelihood of the sampled note: P(note | sampled_sequence). This will be useful for beam search. 
42 | """ 43 | assert distribution.shape[0] == 128 44 | assert len(distribution.shape) == 1 45 | assert 0 <= top_p <= 1, "top_p must be between 0 and 1" 46 | assert 0 <= repeat_decay <= 1, "repeat_decay must be between 0 and 1" 47 | assert temperature > 0, "temperature must be positive" 48 | # penalize previous note 49 | distribution[prev_note] *= (1 - repeat_decay) 50 | # sample only the top p of the distribution 51 | sorted_prob, sorted_idx = torch.sort(distribution, descending=True) 52 | cumsum_prob = torch.cumsum(sorted_prob, dim=0) 53 | top_p_mask = cumsum_prob < top_p 54 | top_p_mask[0] = True 55 | top_p_idx = sorted_idx[top_p_mask][:top_k] 56 | top_p_distribution = distribution[top_p_idx] 57 | # normalize the distribution 58 | top_p_distribution = top_p_distribution / top_p_distribution.sum() 59 | # apply temperature 60 | top_p_distribution = top_p_distribution ** (1 / temperature) 61 | # sample 62 | sampled_note = int(torch.multinomial(top_p_distribution, 1).item()) 63 | conditional_likelihood = top_p_distribution[sampled_note].item() 64 | return top_p_idx[sampled_note].item(), conditional_likelihood 65 | 66 | def stochastic_search(model, beats: np.ndarray, hint: List[int], device: str, top_p: float= 0.9, top_k:int= 4, repeat_decay: float = 0.5, temperature=1.) -> np.ndarray: 67 | """ 68 | - `model`: model to use for sampling 69 | - `beats`: a numpy array of shape (seq_len, 2), containing the input beats 70 | - `device`: the device to use 71 | - `top_p`: sample only the top p% of the distribution 72 | - `top_k`: sample only the top k notes of the distribution 73 | - `repeat_decay`: penalty on repeating the same note. Each time the same note is repeated, the probability of repeating it is multiplied by `1 - repeat_decay`. 74 | the probability of getting N repeats is upper bounded by `(1 - repeat_decay) ** N` 75 | - `hint`: a list of notes used to seed the generation; they are kept as the prefix of the generated sequence 76 | - `temperature`: temperature of the distribution. Lower temperature gives more confidence to the most probable notes, higher temperature gives a more uniform distribution. 77 | 78 | Returns: 79 | - `generated_sequence`: a numpy array of shape (seq_len, ), containing the generated sequence 80 | """ 81 | dist = get_distribution_generator(model, beats, device) 82 | state = dist.initial_state(hint) 83 | generated_sequence = hint[:] 84 | prev_note = generated_sequence[-1] 85 | progress_bar = tqdm(range(beats.shape[0] - len(hint)), desc="Stochastic search") 86 | for _ in progress_bar: 87 | # get the distribution 88 | state, distribution = dist.proceed(state, prev_note) 89 | # sample 90 | sampled_note, _ = stochastic_step(prev_note, distribution, top_p, top_k, repeat_decay, temperature) 91 | generated_sequence.append(sampled_note) 92 | prev_note = sampled_note 93 | return np.array(generated_sequence) 94 | 95 | def beam_search(model, beats: np.ndarray, hint: List[int], device: str, repeat_decay: float = 0.5, num_beams: int = 3, beam_prob: float = 0.7, temperature=1.) -> np.ndarray: 96 | """ 97 | - `model`: model to use for sampling 98 | - `beats`: a numpy array of shape (seq_len, 2), containing the input beats 99 | - `device`: the device to use 100 | - `repeat_decay`: penalty on repeating the same note. Each time the same note is repeated, the probability of repeating it is multiplied by `1 - repeat_decay`. 101 | the probability of getting N repeats is upper bounded by `(1 - repeat_decay) ** N` 102 | - `hint`: a list of notes used to seed the generation; they are kept as the prefix of the generated sequence 
103 | - `num_beams`: number of beams to use - `beam_prob`: probability of taking a full beam-expansion step over all 128 notes at each position (otherwise every beam takes one stochastic step) - `temperature`: temperature used for the stochastic steps 104 | 105 | Returns: 106 | - `generated_sequence`: a numpy array of shape (seq_len, ), containing the generated sequence 107 | """ 108 | dist = get_distribution_generator(model, beats, device) 109 | state = dist.initial_state(hint) 110 | beams = [(hint[:], state, 0)] # (generated_sequence, state, log_likelihood) 111 | progress_bar = tqdm(range(beats.shape[0] - len(hint)), desc="Beam search") 112 | for _ in progress_bar: 113 | beam_choice = np.random.rand() 114 | if beam_choice < beam_prob: 115 | new_beams = [] 116 | for beam in beams: 117 | prev_note = beam[0][-1] 118 | state, distribution = dist.proceed(beam[1], prev_note) 119 | # modify the distribution using the repeat_decay 120 | distribution[prev_note] *= (1 - repeat_decay) 121 | # sample 122 | for sampled_note in range(128): 123 | new_beam = (beam[0] + [sampled_note], state, beam[2] + np.log(distribution[sampled_note].item())) 124 | new_beams.append(new_beam) 125 | # sort the beams by their likelihood 126 | new_beams = sorted(new_beams, key=lambda x: x[2], reverse=True) 127 | # keep only the top num_beams 128 | beams = new_beams[:num_beams] 129 | else: 130 | new_beams = [] 131 | for beam in beams: 132 | prev_note = beam[0][-1] 133 | state, distribution = dist.proceed(beam[1], prev_note) 134 | # sample 135 | sampled_note, conditional_likelihood = stochastic_step(prev_note, distribution, 1.0, 128, repeat_decay, temperature) 136 | new_beam = (beam[0] + [sampled_note], state, beam[2] + np.log(conditional_likelihood)) 137 | new_beams.append(new_beam) 138 | # sort the beams by their likelihood 139 | new_beams = sorted(new_beams, key=lambda x: x[2], reverse=True) 140 | # keep only the top num_beams 141 | beams = new_beams[:num_beams] 142 | # return the beam with the highest likelihood 143 | return np.array(beams[0][0]) 144 | 145 | 146 | 147 | --------------------------------------------------------------------------------
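The sampling utilities above are normally driven through predict_stream.py, but they can also be called directly. Below is a minimal sketch of that flow, assuming a trained "lstm_attn" checkpoint and a beat sequence previously recorded with utils/beats_generator.py; the checkpoint path, the model key in config.toml, and the hint note are illustrative placeholders rather than fixed names from the repository.

    import numpy as np
    import toml
    import torch

    from utils.model import CONFIG_PATH, get_model, load_checkpoint
    from utils.render import render_midi
    from utils.sample import stochastic_search

    config = toml.load(CONFIG_PATH)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = get_model("lstm_attn", config["model"]["lstm_attn"], device)
    load_checkpoint(".project_data/snapshots/lstm_attn_all_best.pth", model, device)  # placeholder checkpoint path
    model.eval()

    # beats recorded earlier by utils/beats_generator.create_beat()
    beats = np.load(".project_data/beats_rhythms/last_recorded.npy").astype(np.float32)
    beats[0][0] = 2.0  # same first-rest override that predict_stream.py applies to recorded beats
    notes = stochastic_search(model, beats, hint=[60], device=device,
                              top_p=0.9, top_k=4, repeat_decay=0.5, temperature=1.0)
    render_midi(beats, notes, "output.mid")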