├── LICENSE ├── README.md ├── dataset_unprocessed ├── Jsb16thSeparated.npz ├── js-fakes-16thSeparated.npz ├── test_chords.p ├── test_keysigs.p ├── train_chords.p ├── train_keysigs.p ├── train_majmin_chords.p ├── valid_chords.p └── valid_keysigs.p ├── eval ├── TonicNet_epoch-55_loss-0.321_acc-90.755.pt ├── TonicNet_epoch-56_loss-0.328_acc-90.750.pt ├── TonicNet_epoch-58_loss-0.317_acc-90.928.pt ├── eval.py ├── sample.py ├── samples │ ├── Audio │ │ ├── sample1.mp3 │ │ ├── sample10.mp3 │ │ ├── sample2.mp3 │ │ ├── sample3.mp3 │ │ ├── sample4.mp3 │ │ ├── sample5.mp3 │ │ ├── sample6.mp3 │ │ ├── sample7.mp3 │ │ ├── sample8.mp3 │ │ ├── sample9.mp3 │ │ ├── sample_2.mp3 │ │ ├── sample_3.mp3 │ │ ├── sample_4.mp3 │ │ └── sample_5.mp3 │ ├── MIDI 16th notes │ │ ├── sample1.mid │ │ ├── sample10.mid │ │ ├── sample2.mid │ │ ├── sample3.mid │ │ ├── sample4.mid │ │ ├── sample5.mid │ │ ├── sample6.mid │ │ ├── sample7.mid │ │ ├── sample8.mid │ │ └── sample9.mid │ └── MIDI smoothed │ │ ├── sample10_smoothed.mid │ │ ├── sample1_smoothed.mid │ │ ├── sample2_smoothed.mid │ │ ├── sample3_smoothed.mid │ │ ├── sample4_smoothed.mid │ │ ├── sample5_smoothed.mid │ │ ├── sample6_smoothed.mid │ │ ├── sample7_smoothed.mid │ │ ├── sample8_smoothed.mid │ │ └── sample9_smoothed.mid └── utils.py ├── main.py ├── preprocessing ├── instruments.py ├── nn_dataset.py └── utils.py ├── tokenisers ├── inverse_pitch_only.p └── pitch_only.p └── train ├── external.py ├── models.py ├── train_nn.py └── transformer.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TonicNet 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/js-fake-chorales-a-synthetic-dataset-of/music-modeling-on-jsb-chorales)](https://paperswithcode.com/sota/music-modeling-on-jsb-chorales?p=js-fake-chorales-a-synthetic-dataset-of-1) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/improving-polyphonic-music-models-with/music-modeling-on-jsb-chorales)](https://paperswithcode.com/sota/music-modeling-on-jsb-chorales?p=improving-polyphonic-music-models-with) 5 | 6 | 7 | Accompanying repository for my paper: [Improving Polyphonic Music Models with Feature-Rich Encoding](https://arxiv.org/abs/1911.11775) 8 | 9 | Requirements: 10 | - Python 3 (tested with 3.6.5) 11 | - PyTorch (tested with 1.2.0) 12 | - Music21 13 | 14 | Prepare Dataset: 15 | 16 | To prepare the vanilla JSB Chorales dataset with the canonical train/validation/test split: 17 | ``` 18 | python main.py --gen_dataset 19 | ``` 20 | 21 | To prepare the dataset augmented with [JS Fake Chorales](https://github.com/omarperacha/js-fakes): 22 | ``` 23 | python main.py --gen_dataset --jsf 24 | ``` 25 | 26 | To prepare the dataset for training on JS Fake Chorales only: 27 | ``` 28 | python main.py --gen_dataset --jsf_only 29 | ``` 30 | 31 | Train Model from Scratch: 32 | 33 | First run `--gen_dataset`, optionally with one of the flags above, then: 34 | ``` 35 | python main.py --train 36 | ``` 37 | 38 | Training runs for 60 epochs and takes roughly 3-6 hours on a GPU. 39 | 40 | Evaluate Pre-trained Model on Test Set: 41 | 42 | ``` 43 | python main.py --eval_nn 44 | ``` 45 | 46 | Sample with Pre-trained Model (via random sampling): 47 | 48 | ``` 49 | python main.py --sample 50 | ``` 51 |
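52 | The `--sample` flag is roughly equivalent to the following minimal sketch, which calls the same functions `main.py` uses; run it from the repository root, since the sampling code loads `tokenisers/inverse_pitch_only.p` by a relative path: 53 | ``` 54 | from eval.sample import sample_TonicNet_random 55 | from eval.utils import indices_to_stream, smooth_rhythm 56 | 57 | # sample a token sequence from the pretrained checkpoint, then write it out as MIDI 58 | x = sample_TonicNet_random(load_path='eval/TonicNet_epoch-56_loss-0.328_acc-90.750.pt', temperature=1.0) 59 | indices_to_stream(x)  # writes eval/sample.mid 60 | smooth_rhythm()  # merges consecutive repeated 16th notes and writes eval/sample_smoothed.mid 61 | ```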
-------------------------------------------------------------------------------- /dataset_unprocessed/Jsb16thSeparated.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/Jsb16thSeparated.npz -------------------------------------------------------------------------------- /dataset_unprocessed/js-fakes-16thSeparated.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/js-fakes-16thSeparated.npz -------------------------------------------------------------------------------- /dataset_unprocessed/test_chords.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/test_chords.p -------------------------------------------------------------------------------- /dataset_unprocessed/test_keysigs.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/test_keysigs.p -------------------------------------------------------------------------------- /dataset_unprocessed/train_chords.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/train_chords.p -------------------------------------------------------------------------------- /dataset_unprocessed/train_keysigs.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/train_keysigs.p -------------------------------------------------------------------------------- /dataset_unprocessed/train_majmin_chords.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/train_majmin_chords.p -------------------------------------------------------------------------------- /dataset_unprocessed/valid_chords.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/valid_chords.p -------------------------------------------------------------------------------- /dataset_unprocessed/valid_keysigs.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/valid_keysigs.p -------------------------------------------------------------------------------- /eval/TonicNet_epoch-55_loss-0.321_acc-90.755.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/TonicNet_epoch-55_loss-0.321_acc-90.755.pt -------------------------------------------------------------------------------- /eval/TonicNet_epoch-56_loss-0.328_acc-90.750.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/TonicNet_epoch-56_loss-0.328_acc-90.750.pt -------------------------------------------------------------------------------- /eval/TonicNet_epoch-58_loss-0.317_acc-90.928.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/TonicNet_epoch-58_loss-0.317_acc-90.928.pt -------------------------------------------------------------------------------- /eval/eval.py: -------------------------------------------------------------------------------- 1 | from torch import cuda, load, device, set_grad_enabled, max, sum, cat 2 | from preprocessing.nn_dataset import get_test_set_for_eval_classic 3 | 4 | """ 5 | File containing functions to quantitatively evaluate trained models 6 | """ 7 | 8 | 9 | def eval_on_test_set(load_path, model, criterion, set='test', notes_only=False): 10 | model = model 11 | 12 | try: 13 | if cuda.is_available(): 14 | model.load_state_dict(load(load_path)['model_state_dict']) 15 | else: 16 | model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict']) 17 | print("loaded params from", load_path) 18 | except: 19 | raise ImportError(f'No file located at {load_path}, could not load parameters') 20 | print(model) 21 | 22 | if cuda.is_available(): 23 | model.cuda() 24 | 25 | model.eval() 26 | criterion = criterion 27 | 28 | count = 0
29 | batch_count = 0 30 | loss_epoch = 0 31 | running_accuray = 0.0 32 | running_batch_count = 0 33 | print_loss_batch = 0 # Reset on print 34 | print_acc_batch = 0 # Reset on print 35 | pr_interval = 1 36 | 37 | for x, y, psx, i, c in get_test_set_for_eval_classic(set): 38 | model.zero_grad() 39 | 40 | train_emb = False 41 | 42 | Y = y 43 | 44 | with set_grad_enabled(False): 45 | y_hat = model(x, z=i, train_embedding=train_emb) 46 | _, preds = max(y_hat, 2) 47 | 48 | if notes_only: 49 | for j in range(y_hat.shape[1]): 50 | if j % 5 != 4: 51 | if j == 0: 52 | new_y_hat = y_hat[:, j, :].view(1, 1, 98) 53 | new_y = Y[:, j].view(1, 1) 54 | else: 55 | new_y_hat = cat((new_y_hat, y_hat[:, j, :].view(1, 1, 98)), dim=1) 56 | new_y = cat((new_y, Y[:, j].view(1, 1)), dim=1) 57 | loss = criterion(new_y_hat, new_y, ) 58 | else: 59 | loss = criterion(y_hat, Y, ) 60 | 61 | loss_epoch += loss.item() 62 | print_loss_batch += loss.item() 63 | if notes_only: 64 | _, new_preds = max(new_y_hat, 2) 65 | running_accuray += sum(new_preds == new_y) 66 | print_acc_batch += sum(new_preds == new_y) 67 | else: 68 | running_accuray += sum(preds == Y) 69 | print_acc_batch += sum(preds == Y) 70 | 71 | count += 1 72 | if notes_only: 73 | batch_count += int(x.shape[1] * 0.8) 74 | running_batch_count += int(x.shape[1] * 0.8) 75 | else: 76 | batch_count += x.shape[1] 77 | running_batch_count += x.shape[1] 78 | 79 | # print loss for recent set of batches 80 | if count % pr_interval == 0: 81 | ave_loss = print_loss_batch / pr_interval 82 | ave_acc = 100 * print_acc_batch.float() / running_batch_count 83 | print_acc_batch = 0 84 | running_batch_count = 0 85 | print('\t\t[%d] loss: %.3f, acc: %.3f' % (count, ave_loss, ave_acc)) 86 | print_loss_batch = 0 87 | 88 | # calculate loss and accuracy for phase 89 | ave_loss_epoch = loss_epoch / count 90 | epoch_acc = 100 * running_accuray.float() / batch_count 91 | print('\tfinished %s phase loss: %.3f, acc: %.3f' % ('eval', ave_loss_epoch, epoch_acc)) 92 | 93 | -------------------------------------------------------------------------------- /eval/sample.py: -------------------------------------------------------------------------------- 1 | from train.models import TonicNet 2 | from torch import cat, multinomial 3 | from torch import cuda, load, device, tensor, zeros 4 | from torch.nn import LogSoftmax 5 | import pickle 6 | import random 7 | from copy import deepcopy 8 | 9 | """ 10 | Functions to sample from trained models 11 | """ 12 | 13 | 14 | def sample_TonicNet_random(load_path, max_tokens=2999, temperature=1.0): 15 | 16 | model = TonicNet(nb_tags=98, z_dim=32, nb_layers=3, nb_rnn_units=256, dropout=0.0) 17 | 18 | try: 19 | if cuda.is_available(): 20 | model.load_state_dict(load(load_path)['model_state_dict']) 21 | else: 22 | model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict']) 23 | print("loded params from", load_path) 24 | except: 25 | raise ImportError(f'No file located at {load_path}, could not load parameters') 26 | print(model) 27 | 28 | if cuda.is_available(): 29 | model.cuda() 30 | 31 | model.eval() 32 | model.seq_len = 1 33 | model.hidden = model.init_hidden() 34 | model.zero_grad() 35 | 36 | inverse_t = pickle.load(open('tokenisers/inverse_pitch_only.p', mode='rb')) 37 | 38 | seed, pos_dict = __get_seed() 39 | 40 | x = seed 41 | x_post = x 42 | 43 | inst_conv_dict = {0: 0, 1: 1, 2: 4, 3: 2, 4: 3} 44 | current_token_dict = {0: '', 1: '', 2: '', 3: '', 4: ''} 45 | 46 | print("") 47 | print(0) 48 | print("\t", 0, ":", 
chord_from_token(x[0][0].item() - 48)) 49 | 50 | for i in range(max_tokens): 51 | 52 | if i == 0: 53 | reset_hidden = True 54 | else: 55 | reset_hidden = False 56 | 57 | inst = inst_conv_dict[i % 5] 58 | psx = pos_dict[inst] 59 | 60 | psx_t = tensor(psx).view(1, 1) 61 | 62 | y_hat = model(x, z=psx_t, sampling=True, 63 | reset_hidden=reset_hidden).data.view(-1).div(temperature).exp() 64 | y = multinomial(y_hat, 1)[0] 65 | if y.item() == 0: # EOS token 66 | print("ending") 67 | break 68 | else: 69 | try: 70 | token = inverse_t[y.item()] 71 | except: 72 | token = chord_from_token(y.item() - 48) 73 | 74 | next_inst = inst_conv_dict[(i + 1) % 5] 75 | 76 | print("") 77 | print(i + 1) 78 | 79 | if current_token_dict[next_inst] == token: 80 | pos_dict[next_inst] += 1 81 | else: 82 | current_token_dict[next_inst] = token 83 | pos_dict[next_inst] = 0 84 | 85 | x = y.view(1, 1, 1) 86 | x_post = cat((x_post, x), dim=1) 87 | 88 | return x_post 89 | 90 | 91 | def sample_TonicNet_beam_search(load_path, max_tokens=2999, beam_width=10, alpha=1.0): 92 | 93 | """sample the model via beam search algorithm (not the most efficient implementation but functional) 94 | 95 | :param load_path: path to state_dict to load weights from 96 | :param max_tokens: maximum number of iterations to sample 97 | :param beam_width: breadth of beam search heuristic 98 | :param alpha: hyperparamter for length normalisation of beam search - higher value prefers longer sequences 99 | :return: generated list of token indices 100 | """ 101 | 102 | model = TonicNet(nb_tags=98, z_dim=32, nb_layers=3, nb_rnn_units=256, dropout=0.0) 103 | logsoftmax = LogSoftmax(dim=0) 104 | 105 | try: 106 | if cuda.is_available(): 107 | model.load_state_dict(load(load_path)['model_state_dict']) 108 | else: 109 | model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict']) 110 | print("loded params from", load_path) 111 | except: 112 | raise ImportError(f'No file located at {load_path}, could not load parameters') 113 | print(model) 114 | 115 | if cuda.is_available(): 116 | model.cuda() 117 | 118 | model.eval() 119 | model.seq_len = 1 120 | model.hidden = model.init_hidden() 121 | model.zero_grad() 122 | 123 | inverse_t = pickle.load(open('tokenisers/inverse_pitch_only.p', mode='rb')) 124 | inverse_t[0] = 'end' 125 | 126 | seed, pos_dict = __get_seed() 127 | 128 | x = seed 129 | x_post = x 130 | 131 | inst_conv_dict = {0: 0, 1: 1, 2: 4, 3: 2, 4: 3} 132 | current_token_dict = {0: '', 1: '', 2: '', 3: '', 4: ''} 133 | 134 | candidate_seqs = [] 135 | c_ts = [] 136 | pos_ds = [] 137 | models = [] 138 | scores = [] 139 | 140 | ended = [0] * beam_width 141 | 142 | for i in range(max_tokens): 143 | 144 | print("") 145 | print(i) 146 | 147 | log_probs = zeros((98*beam_width)) 148 | updated = [False] * beam_width 149 | 150 | inst = inst_conv_dict[i % 5] 151 | next_inst = inst_conv_dict[(i + 1) % 5] 152 | 153 | for b in range(beam_width): 154 | 155 | if i == 0: 156 | reset_hidden = True 157 | candidate_seqs.append(deepcopy([x])) 158 | c_ts.append(deepcopy(current_token_dict)) 159 | pos_ds.append(deepcopy(pos_dict)) 160 | models.append(deepcopy(model)) 161 | scores.append(0) 162 | else: 163 | reset_hidden = False 164 | 165 | psx = pos_ds[b][inst] 166 | 167 | psx_t = tensor(psx).view(1, 1) 168 | 169 | if candidate_seqs[b][-1].item() == 0: 170 | log_probs[(98 * b):(98 * (b + 1))] = tensor(98).fill_(-9999) 171 | else: 172 | y_hat = models[b](candidate_seqs[b][-1], z=psx_t, sampling=True, reset_hidden=reset_hidden) 173 | 
log_probs[(98*b):(98*(b+1))] = logsoftmax(y_hat[0, 0, :]) + scores[b] 174 | 175 | if i == 0: 176 | top = log_probs[0:98].topk(k=beam_width) 177 | else: 178 | top = log_probs.topk(k=beam_width) 179 | 180 | temp_store = [] 181 | rejection_reconciliation = {} 182 | 183 | for b1 in range(beam_width): 184 | candidate = top[1][b1].item() 185 | prob = top[0][b1].item() 186 | y = candidate % 98 187 | m = int(candidate / 98) 188 | 189 | if updated[m]: 190 | temp_store.append((m, y, prob)) 191 | else: 192 | updated[m] = True 193 | rejection_reconciliation[m] = m 194 | if candidate_seqs[m][-1].item() != 0: 195 | scores[m] = prob 196 | candidate_seqs[m].append(tensor(y).view(1, 1, 1)) 197 | else: 198 | ended[m] = 1 199 | 200 | for b2 in range(beam_width): 201 | if not updated[b2]: 202 | rejection_reconciliation[b2] = temp_store[0][0] 203 | if candidate_seqs[b2][-1].item() != 0: 204 | 205 | scores[b2] = temp_store[0][2] 206 | 207 | candidate_seqs[b2] = deepcopy(candidate_seqs[rejection_reconciliation[b2]]) 208 | candidate_seqs[b2].pop(-1) 209 | candidate_seqs[b2].append(tensor(temp_store[0][1]).view(1, 1, 1)) 210 | 211 | models[b2].hidden = models[rejection_reconciliation[b2]].hidden 212 | 213 | # copy over pos_dict and current token to replace rejected model's 214 | pos_dict[b2] = deepcopy(pos_dict[rejection_reconciliation[b2]]) 215 | c_ts[b2] = deepcopy(c_ts[rejection_reconciliation[b2]]) 216 | 217 | temp_store.pop(0) 218 | else: 219 | ended[b2] = 1 220 | 221 | try: 222 | token = inverse_t[candidate_seqs[b2][-1].item()] 223 | except: 224 | token = chord_from_token(candidate_seqs[b2][-1].item() - 48) 225 | 226 | if c_ts[b2][next_inst] == token: 227 | pos_ds[b2][next_inst] += 1 228 | else: 229 | c_ts[b2][next_inst] = token 230 | pos_ds[b2][next_inst] = 0 231 | 232 | if sum(ended) == beam_width: 233 | print('all ended') 234 | break 235 | 236 | normalised_scores = [scores[n] * (1/(len(candidate_seqs[n])**alpha)) for n in range(beam_width)] 237 | 238 | chosen_seq = max(zip(normalised_scores, range(len(normalised_scores))))[1] 239 | 240 | for n in candidate_seqs[chosen_seq][1:]: 241 | x_post = cat((x_post, n), dim=1) 242 | 243 | return x_post 244 | 245 | 246 | def __get_seed(): 247 | 248 | seed = random.choice(range(48, 72)) 249 | x = tensor(seed) 250 | 251 | p_dct = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0} 252 | 253 | return x.view(1, 1, 1), p_dct 254 | 255 | 256 | def chord_from_token(token): 257 | 258 | if token < 12: 259 | qual = 'major' 260 | elif token < 24: 261 | qual = 'minor' 262 | elif token < 36: 263 | qual = 'diminished' 264 | elif token < 48: 265 | qual = 'augmented' 266 | elif token == 48: 267 | qual = 'other' 268 | else: 269 | qual = 'none' 270 | 271 | return token % 12, qual 272 | 273 | -------------------------------------------------------------------------------- /eval/samples/Audio/sample1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample1.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample10.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample10.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample2.mp3: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample2.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample3.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample4.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample4.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample5.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample5.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample6.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample6.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample7.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample7.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample8.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample8.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample9.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample9.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample_2.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample_3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample_3.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample_4.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample_4.mp3 -------------------------------------------------------------------------------- /eval/samples/Audio/sample_5.mp3: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample_5.mp3 -------------------------------------------------------------------------------- /eval/samples/MIDI 16th notes/sample1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample1.mid -------------------------------------------------------------------------------- /eval/samples/MIDI 16th notes/sample10.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample10.mid -------------------------------------------------------------------------------- /eval/samples/MIDI 16th notes/sample2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample2.mid -------------------------------------------------------------------------------- /eval/samples/MIDI 16th notes/sample3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample3.mid -------------------------------------------------------------------------------- /eval/samples/MIDI 16th notes/sample4.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample4.mid -------------------------------------------------------------------------------- /eval/samples/MIDI 16th notes/sample5.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample5.mid -------------------------------------------------------------------------------- /eval/samples/MIDI 16th notes/sample6.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample6.mid -------------------------------------------------------------------------------- /eval/samples/MIDI 16th notes/sample7.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample7.mid -------------------------------------------------------------------------------- /eval/samples/MIDI 16th notes/sample8.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample8.mid -------------------------------------------------------------------------------- /eval/samples/MIDI 16th notes/sample9.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample9.mid -------------------------------------------------------------------------------- /eval/samples/MIDI smoothed/sample10_smoothed.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample10_smoothed.mid -------------------------------------------------------------------------------- /eval/samples/MIDI smoothed/sample1_smoothed.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample1_smoothed.mid -------------------------------------------------------------------------------- /eval/samples/MIDI smoothed/sample2_smoothed.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample2_smoothed.mid -------------------------------------------------------------------------------- /eval/samples/MIDI smoothed/sample3_smoothed.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample3_smoothed.mid -------------------------------------------------------------------------------- /eval/samples/MIDI smoothed/sample4_smoothed.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample4_smoothed.mid -------------------------------------------------------------------------------- /eval/samples/MIDI smoothed/sample5_smoothed.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample5_smoothed.mid -------------------------------------------------------------------------------- /eval/samples/MIDI smoothed/sample6_smoothed.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample6_smoothed.mid -------------------------------------------------------------------------------- /eval/samples/MIDI smoothed/sample7_smoothed.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample7_smoothed.mid -------------------------------------------------------------------------------- /eval/samples/MIDI smoothed/sample8_smoothed.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample8_smoothed.mid -------------------------------------------------------------------------------- /eval/samples/MIDI smoothed/sample9_smoothed.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample9_smoothed.mid -------------------------------------------------------------------------------- /eval/utils.py: -------------------------------------------------------------------------------- 1 | import music21 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from preprocessing.utils import get_parts_from_stream 5 | 6 | """ 7 | Utility functions for evaluating model training and saving samples from trained models 8 | """ 9 | 10 | 11 | def indices_to_stream(token_list, return_stream=False): 12 | inverse_t = pickle.load(open('tokenisers/inverse_pitch_only.p', 'rb')) 13 | tl = token_list.squeeze() 14 | tl = tl.numpy() 15 | 16 | sop_part = music21.stream.Part() 17 | sop_part.id = 'soprano' 18 | 19 | alto_part = music21.stream.Part() 20 | alto_part.id = 'alto' 21 | 22 | tenor_part = music21.stream.Part() 23 | tenor_part.id = 'tenor' 24 | 25 | bass_part = music21.stream.Part() 26 | bass_part.id = 'bass' 27 | 28 | score = music21.stream.Stream([sop_part, bass_part, alto_part, tenor_part]) 29 | 30 | for j in range(len(tl)): 31 | 32 | i = tl[j] 33 | try: 34 | note = inverse_t[i] 35 | except: 36 | continue 37 | 38 | idx = (j % 5) - 1 39 | 40 | if note == 'Rest': 41 | n = music21.note.Rest() 42 | else: 43 | pitch = int(note) 44 | n = music21.note.Note(pitch) 45 | 46 | dur = 0.25 47 | n.quarterLength = dur 48 | 49 | score[idx].append(n) 50 | 51 | if return_stream: 52 | return score 53 | else: 54 | score.write('midi', fp='eval/sample.mid') 55 | print("SAVED sample to ./eval/sample.mid") 56 | 57 | 58 | def plot_loss_acc_curves(log='eval/out.log'): 59 | train_loss = [] 60 | train_acc = [] 61 | 62 | val_loss = [] 63 | val_acc = [] 64 | 65 | f = open(log, "r") 66 | txt = f.read() 67 | for line in txt.split("\n"): 68 | if 'finished' in line: 69 | components = line.split(" ") 70 | loss = components[5] 71 | loss = loss[:-1] 72 | acc = components[7] 73 | if 'train phase' in line: 74 | train_acc.append(float(acc)) 75 | train_loss.append(float(loss)) 76 | else: 77 | val_acc.append(float(acc)) 78 | val_loss.append(float(loss)) 79 | 80 | plt.figure(1) 81 | plt.subplot(121) 82 | plt.plot(train_loss) 83 | plt.plot(val_loss) 84 | plt.xlabel('epochs') 85 | plt.legend(['train loss', 'val loss'], loc='upper left') 86 | plt.ylim(0, 6) 87 | 88 | plt.subplot(122) 89 | plt.plot(train_acc) 90 | plt.plot(val_acc) 91 | plt.xlabel('epochs') 92 | plt.legend(['train acc', 'val acc'], loc='upper left') 93 | plt.ylim(0, 100) 94 | plt.show() 95 | 96 | plt.show() 97 | 98 | 99 | def smooth_rhythm(): 100 | path = 'eval/sample.mid' 101 | 102 | mf = music21.midi.MidiFile() 103 | mf.open(path) 104 | mf.read() 105 | mf.close() 106 | 107 | s = music21.midi.translate.midiFileToStream(mf) 108 | 109 | score = music21.stream.Stream() 110 | 111 | parts = get_parts_from_stream(s) 112 | 113 | for part in parts: 114 | new_part = music21.stream.Part() 115 | 116 | current_pitch = -1 117 | current_offset = 0.0 118 | current_dur = 0.0 119 | 120 | for n in part.notesAndRests.flat: 121 | if isinstance(n, music21.note.Rest): 122 | if current_pitch == 129: 123 | current_dur += 0.25 124 | else: 125 | if current_pitch > -1: 126 | if current_pitch < 128: 127 | note = music21.note.Note(current_pitch) 128 | else: 129 | note = music21.note.Rest() 130 | note.quarterLength = current_dur 131 | new_part.insert(current_offset, note) 132 | 133 | current_pitch = 129 134 | current_offset = n.offset 135 | current_dur = 0.25 136 | 137 |
else: 138 | if n.pitch.midi == current_pitch: 139 | current_dur += 0.25 140 | else: 141 | if current_pitch > -1: 142 | if current_pitch < 128: 143 | note = music21.note.Note(current_pitch) 144 | else: 145 | note = music21.note.Rest() 146 | note.quarterLength = current_dur 147 | new_part.insert(current_offset, note) 148 | 149 | current_pitch = n.pitch.midi 150 | current_offset = n.offset 151 | current_dur = 0.25 152 | 153 | if current_pitch < 128: 154 | note = music21.note.Note(current_pitch) 155 | else: 156 | note = music21.note.Rest() 157 | note.quarterLength = current_dur 158 | new_part.insert(current_offset, note) 159 | 160 | score.append(new_part) 161 | 162 | score.write('midi', fp='eval/sample_smoothed.mid') 163 | print("SAVED rhythmically 'smoothed' sample to ./eval/sample_smoothed.mid") 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from preprocessing.nn_dataset import bach_chorales_classic 3 | from train.train_nn import train_TonicNet, TonicNet_lr_finder, TonicNet_sanity_test 4 | from train.train_nn import CrossEntropyTimeDistributedLoss 5 | from train.models import TonicNet 6 | from eval.utils import plot_loss_acc_curves, indices_to_stream, smooth_rhythm 7 | from eval.eval import eval_on_test_set 8 | from eval.sample import sample_TonicNet_random 9 | 10 | if len(sys.argv) > 1: 11 | if sys.argv[1] in ['--train', '-t']: 12 | train_TonicNet(3000, shuffle_batches=1, train_emb_freq=1, load_path='') 13 | 14 | elif sys.argv[1] in ['--plot', '-p']: 15 | plot_loss_acc_curves() 16 | 17 | elif sys.argv[1] in ['--find_lr', '-lr']: 18 | TonicNet_lr_finder(train_emb_freq=1, load_path='') 19 | 20 | elif sys.argv[1] in ['--sanity_test', '-st']: 21 | TonicNet_sanity_test(num_batches=1, train_emb_freq=1) 22 | 23 | elif sys.argv[1] in ['--sample', '-s']: 24 | x = sample_TonicNet_random(load_path='eval/TonicNet_epoch-56_loss-0.328_acc-90.750.pt', temperature=1.0) 25 | indices_to_stream(x) 26 | smooth_rhythm() 27 | 28 | elif sys.argv[1] in ['--eval_nn', '-e']: 29 | eval_on_test_set( 30 | 'eval/TonicNet_epoch-58_loss-0.317_acc-90.928.pt', 31 | TonicNet(nb_tags=98, z_dim=32, nb_layers=3, nb_rnn_units=256, dropout=0.0), 32 | CrossEntropyTimeDistributedLoss(), set='test', notes_only=True) 33 | 34 | elif sys.argv[1] in ['--gen_dataset', '-gd']: 35 | if len(sys.argv) > 2 and sys.argv[2] == '--jsf': 36 | for x, y, p, i, c in bach_chorales_classic('save', transpose=True, jsf_aug='all'): 37 | continue 38 | elif len(sys.argv) > 2 and sys.argv[2] == '--jsf_only': 39 | for x, y, p, i, c in bach_chorales_classic('save', transpose=True, jsf_aug='only'): 40 | continue 41 | else: 42 | for x, y, p, i, c in bach_chorales_classic('save', transpose=True): 43 | continue 44 | 45 | else: 46 | print("") 47 | print("TonicNet (Training on Ordered Notation Including Chords)") 48 | print("Omar Peracha, 2019") 49 | print("") 50 | print("--gen_dataset\t\t\t\t prepare dataset") 51 | print("--gen_dataset --jsf \t\t prepare dataset with JS Fakes data augmentation") 52 | print("--gen_dataset --jsf_only \t prepare dataset with JS Fake Chorales only") 53 | print("--train\t\t\t train model from scratch") 54 | print("--eval_nn\t\t evaluate pretrained model on test set") 55 | print("--sample\t\t sample from pretrained model") 56 | print("") 57 | else: 58 | 59 | print("") 60 | print("TonicNet (Training on Ordered Notation Including Chords)") 61 | print("Omar Peracha, 2019") 62 |
print("") 63 | print("--gen_dataset\t\t\t\t prepare dataset") 64 | print("--gen_dataset --jsf \t\t prepare dataset with JS Fake Chorales data augmentation") 65 | print("--gen_dataset --jsf_only \t prepare dataset with JS Fake Chorales only") 66 | print("--train\t\t\t train model from scratch") 67 | print("--eval_nn\t\t evaluate pretrained model on test set") 68 | print("--sample\t\t sample from pretrained model") 69 | print("") 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /preprocessing/instruments.py: -------------------------------------------------------------------------------- 1 | """ 2 | File containing 'structs' and methods pertaining to determining and assigning instruments to parts in corpus 3 | """ 4 | 5 | 6 | # MARK:- Instrument Data Objects 7 | 8 | class SopranoVoice: 9 | instrumentId = 'soprano' 10 | highestNote = 81 11 | lowestNote = 60 12 | 13 | 14 | class AltoVoice: 15 | instrumentId = 'alto' 16 | highestNote = 77 17 | lowestNote = 53 18 | 19 | 20 | class TenorVoice: 21 | instrumentId = 'tenor' 22 | highestNote = 72 23 | lowestNote = 45 24 | 25 | 26 | class BassVoice: 27 | instrumentId = 'bass' 28 | highestNote = 64 29 | lowestNote = 36 30 | 31 | 32 | def get_instrument(inst_name_in): 33 | """ 34 | 35 | :param inst_name_in: string with the name of an instrument 36 | :return: a data object corresponding to the inst_name if applicable, or UNK 37 | """ 38 | 39 | if not isinstance(inst_name_in, str): 40 | inst_name = str(inst_name_in) 41 | else: 42 | inst_name = inst_name_in 43 | 44 | # Handle a few scenarios where multiple instruments could be scored 45 | if 'bass' in inst_name.lower() or 'B.' in inst_name: 46 | return BassVoice() 47 | 48 | elif 'tenor' in inst_name.lower(): 49 | return TenorVoice() 50 | 51 | elif 'alto' in inst_name.lower(): 52 | return AltoVoice() 53 | 54 | elif 'soprano' in inst_name.lower() or 'S.' 
in inst_name: 55 | return SopranoVoice() 56 | 57 | elif 'canto' in inst_name.lower(): 58 | return SopranoVoice() 59 | 60 | 61 | def get_part_range(part): 62 | notes = part.pitches 63 | midi = list(map(__get_midi, notes)) 64 | return [min(midi), max(midi)] 65 | 66 | 67 | def __get_midi(pitch): 68 | return pitch.midi 69 | 70 | -------------------------------------------------------------------------------- /preprocessing/nn_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import pickle 4 | import numpy as np 5 | from preprocessing.instruments import get_instrument 6 | from random import sample 7 | 8 | """ 9 | File containing functions to derive training data for neural networks 10 | """ 11 | 12 | CUDA = torch.cuda.is_available() 13 | 14 | if CUDA: 15 | PATH = 'train/training_set/X_cuda' 16 | else: 17 | PATH = 'train/training_set/X' 18 | if os.path.exists(PATH): 19 | TRAIN_BATCHES = len(os.listdir(PATH)) 20 | else: 21 | TRAIN_BATCHES = 0 22 | TOTAL_BATCHES = TRAIN_BATCHES + 76 23 | 24 | MAX_SEQ = 2880 25 | N_PITCH = 48 26 | N_CHORD = 50 27 | N_TOKENS = N_PITCH + N_CHORD 28 | 29 | 30 | def get_data_set(mode, shuffle_batches=True, return_I=False): 31 | 32 | if mode == 'train': 33 | parent_dir = 'train/training_set' 34 | elif mode == 'val': 35 | parent_dir = 'train/val_set' 36 | else: 37 | raise Exception("invalid mode passed to get_data_set() - options are 'train' and 'val'") 38 | 39 | if torch.cuda.is_available(): 40 | lst = os.listdir(f'{parent_dir}/X_cuda') 41 | else: 42 | lst = os.listdir(f'{parent_dir}/X') 43 | try: 44 | lst.remove('.DS_Store') 45 | except: 46 | pass 47 | 48 | if shuffle_batches: 49 | lst = sample(lst, len(lst)) 50 | 51 | for file_name in lst: 52 | if torch.cuda.is_available(): 53 | X = torch.load(f'{parent_dir}/X_cuda/{file_name}') 54 | Y = torch.load(f'{parent_dir}/Y_cuda/{file_name}') 55 | P = torch.load(f'{parent_dir}/P_cuda/{file_name}') 56 | if return_I: 57 | I = torch.load(f'{parent_dir}/I_cuda/{file_name}') 58 | C = torch.load(f'{parent_dir}/C_cuda/{file_name}') 59 | else: 60 | X = torch.load(f'{parent_dir}/X/{file_name}') 61 | Y = torch.load(f'{parent_dir}/Y/{file_name}') 62 | P = torch.load(f'{parent_dir}/P/{file_name}') 63 | if return_I: 64 | I = torch.load(f'{parent_dir}/I/{file_name}') 65 | C = torch.load(f'{parent_dir}/C/{file_name}') 66 | 67 | if return_I: 68 | yield X, Y, P, I, C 69 | else: 70 | yield X, Y, P 71 | 72 | 73 | def bach_chorales_classic(mode, transpose=False, maj_min=False, jsf_aug=None): 74 | 75 | if maj_min and jsf_aug is not None: 76 | raise ValueError("maj_min and jsf_aug cannot both be true") 77 | 78 | if jsf_aug not in ['all', 'only', 'topk-all', 'topk-only', 'topk-skilled-all', 'topk-skilled-only', None]: 79 | raise ValueError("unrecognised value for jsf_aug parameter: can only be 'all', 'only', " 80 | "'topk-all', 'topk-only', 'topk-skilled-all', 'topk-skilled-only' or None") 81 | 82 | tokeniser = pickle.load(open('tokenisers/pitch_only.p', 'rb')) 83 | tokeniser["end"] = 0 84 | count = 0 85 | 86 | for folder_name in ["training_set", "val_set"]: 87 | if torch.cuda.is_available(): 88 | print("cuda:") 89 | try: 90 | os.makedirs(f'train/{folder_name}/X_cuda') 91 | os.makedirs(f'train/{folder_name}/Y_cuda') 92 | os.makedirs(f'train/{folder_name}/P_cuda') 93 | os.makedirs(f'train/{folder_name}/I_cuda') 94 | os.makedirs(f'train/{folder_name}/C_cuda') 95 | except: 96 | pass 97 | else: 98 | try: 99 | os.makedirs(f'train/{folder_name}/X') 100 |
os.makedirs(f'train/{folder_name}/Y') 101 | os.makedirs(f'train/{folder_name}/P') 102 | os.makedirs(f'train/{folder_name}/I') 103 | os.makedirs(f'train/{folder_name}/C') 104 | except: 105 | pass 106 | 107 | for phase in ['train', 'valid']: 108 | 109 | d = np.load('dataset_unprocessed/Jsb16thSeparated.npz', allow_pickle=True, encoding="latin1") 110 | train = (d[phase]) 111 | 112 | ks = pickle.load(open(f'dataset_unprocessed/{phase}_keysigs.p', 'rb')) 113 | crds = pickle.load(open(f'dataset_unprocessed/{phase}_chords.p', 'rb')) 114 | crds_majmin = pickle.load(open('dataset_unprocessed/train_majmin_chords.p', 'rb')) 115 | k_count = 0 116 | 117 | if jsf_aug is not None and phase == 'train': 118 | if jsf_aug in ['all', 'only']: 119 | jsf_path = 'dataset_unprocessed/js-fakes-16thSeparated.npz' 120 | jsf = np.load(jsf_path, allow_pickle=True, encoding="latin1") 121 | js_chords = jsf["chords"] 122 | jsf = jsf["pitches"] 123 | 124 | if jsf_aug == "all": 125 | train = np.concatenate((train, jsf)) 126 | crds = np.concatenate((crds, js_chords)) 127 | elif jsf_aug == "only": 128 | train = jsf 129 | crds = js_chords 130 | 131 | for m in train: 132 | int_m = m.astype(int) 133 | 134 | if maj_min: 135 | tonic = ks[k_count][0] 136 | scale = ks[k_count][1] 137 | crd_majmin = crds_majmin[k_count] 138 | 139 | crd = crds[k_count] 140 | k_count += 1 141 | 142 | if transpose is False or phase == 'valid': 143 | transpositions = [int_m] 144 | crds_pieces = [crd] 145 | else: 146 | parts = [int_m[:, 0], int_m[:, 1], int_m[:, 2], int_m[:, 3]] 147 | transpositions, tonics, crds_pieces = __np_perform_all_transpositions(parts, 0, crd) 148 | 149 | if maj_min: 150 | 151 | mode_switch = __np_convert_major_minor(int_m, tonic, scale) 152 | ms_parts = [mode_switch[:, 0], mode_switch[:, 1], mode_switch[:, 2], mode_switch[:, 3]] 153 | ms_trans, ms_tons, ms_crds = __np_perform_all_transpositions(ms_parts, tonic, crd_majmin) 154 | 155 | transpositions += ms_trans 156 | tonics += ms_tons 157 | crds_pieces += ms_crds 158 | 159 | kc = 0 160 | 161 | for t in transpositions: 162 | 163 | crds_piece = crds_pieces[kc] 164 | 165 | _tokens = [] 166 | inst_ids = [] 167 | c_class = [] 168 | 169 | current_s = '' 170 | s_count = 0 171 | 172 | current_a = '' 173 | a_count = 0 174 | 175 | current_t = '' 176 | t_count = 0 177 | 178 | current_b = '' 179 | b_count = 0 180 | 181 | current_c = '' 182 | c_count = 0 183 | 184 | timestep = 0 185 | 186 | for i in t: 187 | s = 'Rest' if i[0] < 36 else str(i[0]) 188 | b = 'Rest' if i[3] < 36 else str(i[3]) 189 | a = 'Rest' if i[1] < 36 else str(i[1]) 190 | t = 'Rest' if i[2] < 36 else str(i[2]) 191 | 192 | c_val = crds_piece[timestep] + 48 193 | timestep += 1 194 | 195 | _tokens = _tokens + [c_val, s, b, a, t] 196 | c_class = c_class + [c_val] 197 | 198 | if c_val == current_c: 199 | c_count += 1 200 | else: 201 | c_count = 0 202 | current_c = c_val 203 | 204 | if s == current_s: 205 | s_count += 1 206 | else: 207 | s_count = 0 208 | current_s = s 209 | 210 | if b == current_b: 211 | b_count += 1 212 | else: 213 | b_count = 0 214 | current_b = b 215 | 216 | if a == current_a: 217 | a_count += 1 218 | else: 219 | a_count = 0 220 | current_a = a 221 | 222 | if t == current_t: 223 | t_count += 1 224 | else: 225 | t_count = 0 226 | current_t = t 227 | 228 | inst_ids = inst_ids + [c_count, s_count, b_count, a_count, t_count] 229 | 230 | pos_ids = list(range(len(_tokens))) 231 | 232 | kc += 1 233 | _tokens.append('end') 234 | tokens = [] 235 | try: 236 | for x in _tokens: 237 | if isinstance(x, str): 238 | 
tokens.append(tokeniser[x]) 239 | else: 240 | tokens.append(x) 241 | except: 242 | print("ERROR: tokenisation") 243 | continue 244 | 245 | SEQ_LEN = len(tokens) - 1 246 | 247 | count += 1 248 | 249 | data_x = [] 250 | data_y = [] 251 | 252 | pos_x = [] 253 | 254 | for i in range(0, len(tokens) - SEQ_LEN, 1): 255 | t_seq_in = tokens[i:i + SEQ_LEN] 256 | t_seq_out = tokens[i + 1: i + 1 + SEQ_LEN] 257 | data_x.append(t_seq_in) 258 | data_y.append(t_seq_out) 259 | 260 | p_seq_in = pos_ids[i:i + SEQ_LEN] 261 | pos_x.append(p_seq_in) 262 | 263 | X = torch.tensor(data_x) 264 | X = torch.unsqueeze(X, 2) 265 | 266 | Y = torch.tensor(data_y) 267 | P = torch.tensor(pos_x) 268 | I = torch.tensor(inst_ids) 269 | C = torch.tensor(c_class) 270 | 271 | set_folder = 'training_set' 272 | if phase == 'valid': 273 | set_folder = 'val_set' 274 | 275 | if mode == 'save': 276 | 277 | if torch.cuda.is_available(): 278 | print("cuda:") 279 | torch.save(X.cuda(), f'train/{set_folder}/X_cuda/{count}.pt') 280 | torch.save(Y.cuda(), f'train/{set_folder}/Y_cuda/{count}.pt') 281 | torch.save(P.cuda(), f'train/{set_folder}/P_cuda/{count}.pt') 282 | torch.save(I.cuda(), f'train/{set_folder}/I_cuda/{count}.pt') 283 | torch.save(C.cuda(), f'train/{set_folder}/C_cuda/{count}.pt') 284 | else: 285 | torch.save(X, f'train/{set_folder}/X/{count}.pt') 286 | torch.save(Y, f'train/{set_folder}/Y/{count}.pt') 287 | torch.save(P, f'train/{set_folder}/P/{count}.pt') 288 | torch.save(I, f'train/{set_folder}/I/{count}.pt') 289 | torch.save(C, f'train/{set_folder}/C/{count}.pt') 290 | print("saved", count) 291 | else: 292 | print("processed", count) 293 | yield X, Y, P, I, C 294 | 295 | 296 | def get_test_set_for_eval_classic(phase='test'): 297 | 298 | tokeniser = pickle.load(open('tokenisers/pitch_only.p', 'rb')) 299 | tokeniser["end"] = 0 300 | 301 | d = np.load('dataset_unprocessed/Jsb16thSeparated.npz', allow_pickle=True, encoding="latin1") 302 | test = (d[f'{phase}']) 303 | 304 | crds = pickle.load(open(f'dataset_unprocessed/{phase}_chords.p', 'rb')) 305 | crd_count = 0 306 | 307 | for m in test: 308 | int_m = m.astype(int) 309 | 310 | crds_piece = crds[crd_count] 311 | crd_count += 1 312 | 313 | _tokens = [] 314 | inst_ids = [] 315 | c_class = [] 316 | 317 | current_s = '' 318 | s_count = 0 319 | 320 | current_a = '' 321 | a_count = 0 322 | 323 | current_t = '' 324 | t_count = 0 325 | 326 | current_b = '' 327 | b_count = 0 328 | 329 | current_c = '' 330 | c_count = 0 331 | 332 | timestep = 0 333 | 334 | for i in int_m: 335 | s = 'Rest' if i[0] < 36 else str(i[0]) 336 | b = 'Rest' if i[3] < 36 else str(i[3]) 337 | a = 'Rest' if i[1] < 36 else str(i[1]) 338 | t = 'Rest' if i[2] < 36 else str(i[2]) 339 | 340 | c_val = crds_piece[timestep] + 48 341 | timestep += 1 342 | 343 | _tokens = _tokens + [c_val, s, b, a, t] 344 | c_class = c_class + [c_val] 345 | 346 | if c_val == current_c: 347 | c_count += 1 348 | else: 349 | c_count = 0 350 | current_c = c_val 351 | 352 | if s == current_s: 353 | s_count += 1 354 | else: 355 | s_count = 0 356 | current_s = s 357 | 358 | if b == current_b: 359 | b_count += 1 360 | else: 361 | b_count = 0 362 | current_b = b 363 | 364 | if a == current_a: 365 | a_count += 1 366 | else: 367 | a_count = 0 368 | current_a = a 369 | 370 | if t == current_t: 371 | t_count += 1 372 | else: 373 | t_count = 0 374 | current_t = t 375 | 376 | inst_ids = inst_ids + [c_count, s_count, b_count, a_count, t_count] 377 | 378 | pos_ids = list(range(len(_tokens))) 379 | 380 | _tokens.append('end') 381 | 382 | tokens = [] 383 | 
try: 384 | for x in _tokens: 385 | if isinstance(x, str): 386 | tokens.append(tokeniser[x]) 387 | else: 388 | tokens.append(x) 389 | except: 390 | print("ERROR: tokenisation") 391 | continue 392 | 393 | SEQ_LEN = len(tokens) - 1 394 | 395 | data_x = [] 396 | data_y = [] 397 | 398 | pos_x = [] 399 | 400 | for i in range(0, len(tokens) - SEQ_LEN, 1): 401 | t_seq_in = tokens[i:i + SEQ_LEN] 402 | t_seq_out = tokens[i + 1: i + 1 + SEQ_LEN] 403 | data_x.append(t_seq_in) 404 | data_y.append(t_seq_out) 405 | 406 | p_seq_in = pos_ids[i:i + SEQ_LEN] 407 | pos_x.append(p_seq_in) 408 | 409 | X = torch.tensor(data_x) 410 | X = torch.unsqueeze(X, 2) 411 | 412 | Y = torch.tensor(data_y) 413 | P = torch.tensor(pos_x) 414 | C = torch.tensor(c_class) 415 | I = torch.tensor(inst_ids) 416 | 417 | yield X, Y, P, I, C 418 | 419 | 420 | def position_encoding_init(n_position, emb_dim): 421 | ''' Init the sinusoid position encoding table ''' 422 | 423 | # keep dim 0 for padding token position encoding zero vector 424 | position_enc = torch.tensor([ 425 | [pos / np.power(10000, 2 * (j // 2) / emb_dim) for j in range(emb_dim)] 426 | if pos != 0 else np.zeros(emb_dim) for pos in range(n_position)], dtype=torch.float32) 427 | 428 | position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # apply sin on 0th,2nd,4th...emb_dim 429 | position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # apply cos on 1st,3rd,5th...emb_dim 430 | 431 | if torch.cuda.is_available(): 432 | position_enc = position_enc.cuda() 433 | 434 | return position_enc 435 | 436 | 437 | def to_categorical(y, num_classes=N_TOKENS): 438 | """ 1-hot encodes a tensor """ 439 | return torch.eye(num_classes)[y] 440 | 441 | 442 | def __pos_from_beatStr(measure, beat, div=8): 443 | 444 | if len(beat.split(" ")) == 1: 445 | b_pos = float(beat.split(" ")[0]) 446 | else: 447 | fraction = beat.split(" ")[1] 448 | fraction = fraction.split("/") 449 | decimal = float(fraction[0])/float(fraction[1]) 450 | b_pos = float(beat.split(" ")[0]) + decimal 451 | 452 | b_pos *= div 453 | b_pos -= div 454 | 455 | pos_enc_idx = int(measure*(4*div) + b_pos) 456 | 457 | return pos_enc_idx 458 | 459 | 460 | # Mark:- transposition 461 | def __np_perform_all_transpositions(parts, tonic, chords): 462 | mylist = [] 463 | tonics = [] 464 | my_chords = [] 465 | try: 466 | t_r = __np_transposable_range_for_piece(parts) 467 | except: 468 | print("error getting transpose range") 469 | return mylist 470 | lower = t_r[0] 471 | higher = t_r[1] + 1 472 | quals = [__np_get_quals_from_chord(x) for x in chords] 473 | for i in range(lower, higher): 474 | try: 475 | roots = [(x + 12 + i) % 12 for x in chords] 476 | transposed_piece = np.zeros((len(parts[0]), 4), dtype=int) 477 | chord_prog = [__np_chord_from_root_qual(roots[i], quals[i]) for i in range(len(chords))] 478 | for j in range(4): 479 | tp = parts[j] + i 480 | transposed_piece[:, j] = tp[:] 481 | except: 482 | print("ERROR: empty return") 483 | else: 484 | mylist.append(transposed_piece) 485 | tonics.append((tonic + i) % 12) 486 | my_chords.append(chord_prog) 487 | return mylist, tonics, my_chords 488 | 489 | 490 | def __np_transposable_range_for_part(part, inst): 491 | 492 | if not isinstance(inst, str): 493 | inst = str(inst) 494 | part_range = __np_get_part_range(part) 495 | instrument = get_instrument(inst) 496 | 497 | lower_transposable = instrument.lowestNote - part_range[0] 498 | higher_transposable = instrument.highestNote - part_range[1] 499 | 500 | # suggests there's perhaps no musical content in this score 501 | if 
higher_transposable - lower_transposable >= 128: 502 | lower_transposable = 0 503 | higher_transposable = 0 504 | return min(0, lower_transposable), max(0, higher_transposable) 505 | 506 | 507 | def __np_transposable_range_for_piece(parts): 508 | 509 | insts = ['soprano', 'alto', 'tenor', 'bass'] 510 | 511 | lower = -127 512 | higher = 127 513 | 514 | for i in range(len(parts)): 515 | t_r = __np_transposable_range_for_part(parts[i], insts[i]) 516 | if t_r[0] > lower: 517 | lower = t_r[0] 518 | if t_r[1] < higher: 519 | higher = t_r[1] 520 | # suggests there's perhaps no musical content in this score 521 | if higher - lower >= 128: 522 | lower = 0 523 | higher = 0 524 | return lower, higher 525 | 526 | 527 | def __np_get_part_range(part): 528 | 529 | mn = min(part) 530 | 531 | if mn < 36: 532 | p = sorted(part) 533 | c = 1 534 | while mn < 36: 535 | mn = p[c] 536 | c += 1 537 | 538 | return [mn, max(part)] 539 | 540 | 541 | def __np_convert_major_minor(piece, tonic, mode): 542 | 543 | _piece = piece 544 | 545 | for i in range(len(_piece)): 546 | s = _piece[i][0] if _piece[i][0] < 36 else (_piece[i][0] - tonic) % 12 547 | b = _piece[i][3] if _piece[i][3] < 36 else (_piece[i][3] - tonic) % 12 548 | a = _piece[i][1] if _piece[i][1] < 36 else (_piece[i][1] - tonic) % 12 549 | t = _piece[i][2] if _piece[i][2] < 36 else (_piece[i][2] - tonic) % 12 550 | 551 | parts = [s, a, t, b] 552 | 553 | for n in range(len(parts)): 554 | if mode == 'major': 555 | if parts[n] in [4, 9]: 556 | _piece[i][n] -= 1 557 | elif mode == 'minor': 558 | if parts[n] in [3, 8, 10]: 559 | _piece[i][n] += 1 560 | else: 561 | raise ValueError(f"mode must be minor or major, received {mode}") 562 | 563 | return _piece 564 | 565 | 566 | def __np_get_quals_from_chord(chord): 567 | 568 | if chord < 12: 569 | qual = 'major' 570 | elif chord < 24: 571 | qual = 'minor' 572 | elif chord < 36: 573 | qual = 'diminished' 574 | elif chord < 48: 575 | qual = 'augmented' 576 | elif chord == 48: 577 | qual = 'other' 578 | else: 579 | qual = 'none' 580 | 581 | return qual 582 | 583 | 584 | def __np_chord_from_root_qual(root, qual): 585 | 586 | if qual == "major": 587 | chord = root 588 | elif qual == "minor": 589 | chord = root + 12 590 | elif qual == "diminished": 591 | chord = root + 24 592 | elif qual == "augmented": 593 | chord = root + 36 594 | elif qual == "other": 595 | chord = 48 596 | elif qual == "none": 597 | chord = 49 598 | 599 | return chord 600 | 601 | 602 | 603 | 604 | -------------------------------------------------------------------------------- /preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | import music21 2 | import pickle 3 | 4 | """ 5 | Contains utility functions for dataset preprocessing 6 | """ 7 | 8 | 9 | def get_parts_from_stream(piece): 10 | parts = [] 11 | for i in piece: 12 | if isinstance(i, music21.stream.Part): 13 | parts.append(i) 14 | return parts 15 | 16 | 17 | def pitch_tokeniser_maker(): 18 | post = {"end": 0} 19 | for i in range(36, 82): 20 | k = str(i) 21 | post[k] = len(post) 22 | post['Rest'] = len(post) 23 | 24 | return post 25 | 26 | 27 | def load_tokeniser(): 28 | dic = pickle.load(open("tokenisers/pitch_only.p", "rb")) 29 | return dic 30 | 31 | 32 | def chord_from_pitches(pitches): 33 | cd = [] 34 | for n in pitches: 35 | if n >= 36: 36 | cd.append(int(n)) 37 | 38 | crd = music21.chord.Chord(cd) 39 | try: 40 | root = music21.pitch.Pitch(crd.root()).pitchClass 41 | except: 42 | c_val = 49 43 | else: 44 | if crd.quality == 'major': 45 | 
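# c_val follows the same chord-class layout used in nn_dataset.py:
#   0-11 major (indexed by root pitch class), 12-23 minor, 24-35 diminished,
#   36-47 augmented, 48 other quality, 49 no identifiable root.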
c_val = root 46 | if crd.quality == 'minor': 47 | c_val = root + 12 48 | if crd.quality == 'diminished': 49 | c_val = root + 24 50 | if crd.quality == 'augmented': 51 | c_val = root + 36 52 | if crd.quality == 'other': 53 | c_val = 48 54 | 55 | return c_val 56 | 57 | -------------------------------------------------------------------------------- /tokenisers/inverse_pitch_only.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/tokenisers/inverse_pitch_only.p -------------------------------------------------------------------------------- /tokenisers/pitch_only.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/tokenisers/pitch_only.p -------------------------------------------------------------------------------- /train/external.py: -------------------------------------------------------------------------------- 1 | import math 2 | import itertools as it 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | import torch.nn as nn 6 | from torch.nn.utils.rnn import PackedSequence 7 | from typing import * 8 | 9 | """ 10 | File containing classes taken directly from other work to aid model training 11 | """ 12 | 13 | 14 | class RAdam(Optimizer): 15 | """ 16 | Directly from Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han. 17 | "On the Variance of the Adaptive Learning Rate and Beyond." 18 | """ 19 | 20 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 21 | weight_decay=0): 22 | defaults = dict(lr=lr, betas=betas, eps=eps, 23 | weight_decay=weight_decay) 24 | 25 | super(RAdam, self).__init__(params, defaults) 26 | 27 | def __setstate__(self, state): 28 | super(RAdam, self).__setstate__(state) 29 | 30 | def step(self, closure=None): 31 | loss = None 32 | beta2_t = None 33 | ratio = None 34 | N_sma_max = None 35 | N_sma = None 36 | 37 | if closure is not None: 38 | loss = closure() 39 | 40 | for group in self.param_groups: 41 | 42 | for p in group['params']: 43 | if p.grad is None: 44 | continue 45 | grad = p.grad.data.float() 46 | if grad.is_sparse: 47 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 48 | 49 | p_data_fp32 = p.data.float() 50 | 51 | state = self.state[p] 52 | 53 | if len(state) == 0: 54 | state['step'] = 0 55 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 56 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 57 | else: 58 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 59 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 60 | 61 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 62 | beta1, beta2 = group['betas'] 63 | 64 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 65 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 66 | 67 | state['step'] += 1 68 | if beta2_t is None: 69 | beta2_t = beta2 ** state['step'] 70 | N_sma_max = 2 / (1 - beta2) - 1 71 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 72 | beta1_t = 1 - beta1 ** state['step'] 73 | if N_sma >= 5: 74 | ratio = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / beta1_t 75 | 76 | if group['weight_decay'] != 0: 77 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 78 | 79 | # more conservative since 
it's an approximated value 80 | if N_sma >= 5: 81 | step_size = group['lr'] * ratio 82 | denom = exp_avg_sq.sqrt().add_(group['eps']) 83 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 84 | else: 85 | step_size = group['lr'] / beta1_t 86 | p_data_fp32.add_(-step_size, exp_avg) 87 | 88 | p.data.copy_(p_data_fp32) 89 | 90 | return loss 91 | 92 | 93 | class VariationalDropout(nn.Module): 94 | """ 95 | Applies the same dropout mask across the temporal dimension 96 | See https://arxiv.org/abs/1512.05287 for more details. 97 | Note that this is not applied to the recurrent activations in the LSTM like the above paper. 98 | Instead, it is applied to the inputs and outputs of the recurrent layer. 99 | """ 100 | def __init__(self, dropout: float, batch_first: Optional[bool]=False): 101 | super().__init__() 102 | self.dropout = dropout 103 | self.batch_first = batch_first 104 | 105 | def forward(self, x: torch.Tensor) -> torch.Tensor: 106 | if not self.training or self.dropout <= 0.: 107 | return x 108 | 109 | is_packed = isinstance(x, PackedSequence) 110 | if is_packed: 111 | x, batch_sizes = x 112 | max_batch_size = int(batch_sizes[0]) 113 | else: 114 | batch_sizes = None 115 | max_batch_size = x.size(0) 116 | 117 | # Drop same mask across entire sequence 118 | if self.batch_first: 119 | m = x.new_empty(max_batch_size, 1, x.size(2), requires_grad=False).bernoulli_(1 - self.dropout) 120 | else: 121 | m = x.new_empty(1, max_batch_size, x.size(2), requires_grad=False).bernoulli_(1 - self.dropout) 122 | x = x.masked_fill(m == 0, 0) / (1 - self.dropout) 123 | 124 | if is_packed: 125 | return PackedSequence(x, batch_sizes) 126 | else: 127 | return x 128 | 129 | 130 | class Lookahead(Optimizer): 131 | def __init__(self, base_optimizer,alpha=0.5, k=6): 132 | if not 0.0 <= alpha <= 1.0: 133 | raise ValueError(f'Invalid slow update rate: {alpha}') 134 | if not 1 <= k: 135 | raise ValueError(f'Invalid lookahead steps: {k}') 136 | self.optimizer = base_optimizer 137 | self.param_groups = self.optimizer.param_groups 138 | self.alpha = alpha 139 | self.k = k 140 | for group in self.param_groups: 141 | group["step_counter"] = 0 142 | self.slow_weights = [[p.clone().detach() for p in group['params']] 143 | for group in self.param_groups] 144 | 145 | for w in it.chain(*self.slow_weights): 146 | w.requires_grad = False 147 | 148 | def step(self, closure=None): 149 | loss = None 150 | if closure is not None: 151 | loss = closure() 152 | loss = self.optimizer.step() 153 | for group,slow_weights in zip(self.param_groups,self.slow_weights): 154 | group['step_counter'] += 1 155 | if group['step_counter'] % self.k != 0: 156 | continue 157 | for p,q in zip(group['params'],slow_weights): 158 | if p.grad is None: 159 | continue 160 | q.data.add_(self.alpha,p.data - q.data) 161 | p.data.copy_(q.data) 162 | return loss 163 | 164 | 165 | # MARK:- OneCycleLR 166 | class OneCycleLR(torch.optim.lr_scheduler._LRScheduler): 167 | r"""Sets the learning rate of each parameter group according to the 168 | 1cycle learning rate policy. The 1cycle policy anneals the learning 169 | rate from an initial learning rate to some maximum learning rate and then 170 | from that maximum learning rate to some minimum learning rate much lower 171 | than the initial learning rate. 172 | This policy was initially described in the paper `Super-Convergence: 173 | Very Fast Training of Neural Networks Using Large Learning Rates`_. 174 | 175 | The 1cycle learning rate policy changes the learning rate after every batch. 
176 | `step` should be called after a batch has been used for training. 177 | 178 | This scheduler is not chainable. 179 | 180 | This class has two built-in annealing strategies: 181 | "cos": 182 | Cosine annealing 183 | "linear": 184 | Linear annealing 185 | 186 | Note also that the total number of steps in the cycle can be determined in one 187 | of two ways (listed in order of precedence): 188 | 1) A value for total_steps is explicitly provided. 189 | 2) A number of epochs (epochs) and a number of steps per epoch 190 | (steps_per_epoch) are provided. 191 | In this case, the number of total steps is inferred by 192 | total_steps = epochs * steps_per_epoch 193 | You must either provide a value for total_steps or provide a value for both 194 | epochs and steps_per_epoch. 195 | 196 | Args: 197 | optimizer (Optimizer): Wrapped optimizer. 198 | max_lr (float or list): Upper learning rate boundaries in the cycle 199 | for each parameter group. 200 | total_steps (int): The total number of steps in the cycle. Note that 201 | if a value is provided here, then it must be inferred by providing 202 | a value for epochs and steps_per_epoch. 203 | Default: None 204 | epochs (int): The number of epochs to train for. This is used along 205 | with steps_per_epoch in order to infer the total number of steps in the cycle 206 | if a value for total_steps is not provided. 207 | Default: None 208 | steps_per_epoch (int): The number of steps per epoch to train for. This is 209 | used along with epochs in order to infer the total number of steps in the 210 | cycle if a value for total_steps is not provided. 211 | Default: None 212 | pct_start (float): The percentage of the cycle (in number of steps) spent 213 | increasing the learning rate. 214 | Default: 0.3 215 | anneal_strategy (str): {'cos', 'linear'} 216 | Specifies the annealing strategy. 217 | Default: 'cos' 218 | cycle_momentum (bool): If ``True``, momentum is cycled inversely 219 | to learning rate between 'base_momentum' and 'max_momentum'. 220 | Default: True 221 | base_momentum (float or list): Lower momentum boundaries in the cycle 222 | for each parameter group. Note that momentum is cycled inversely 223 | to learning rate; at the peak of a cycle, momentum is 224 | 'base_momentum' and learning rate is 'max_lr'. 225 | Default: 0.85 226 | max_momentum (float or list): Upper momentum boundaries in the cycle 227 | for each parameter group. Functionally, 228 | it defines the cycle amplitude (max_momentum - base_momentum). 229 | Note that momentum is cycled inversely 230 | to learning rate; at the start of a cycle, momentum is 'max_momentum' 231 | and learning rate is 'base_lr' 232 | Default: 0.95 233 | div_factor (float): Determines the initial learning rate via 234 | initial_lr = max_lr/div_factor 235 | Default: 25 236 | final_div_factor (float): Determines the minimum learning rate via 237 | min_lr = initial_lr/final_div_factor 238 | Default: 1e4 239 | last_epoch (int): The index of the last batch. This parameter is used when 240 | resuming a training job. Since `step()` should be invoked after each 241 | batch instead of after each epoch, this number represents the total 242 | number of *batches* computed, not the total number of epochs computed. 243 | When last_epoch=-1, the schedule is started from the beginning. 244 | Default: -1 245 | 246 | Example: 247 | >>> data_loader = torch.utils.data.DataLoader(...) 
248 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 249 | >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) 250 | >>> for epoch in range(10): 251 | >>> for batch in data_loader: 252 | >>> train_batch(...) 253 | >>> scheduler.step() 254 | 255 | 256 | .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: 257 | https://arxiv.org/abs/1708.07120 258 | """ 259 | def __init__(self, 260 | optimizer, 261 | max_lr, 262 | total_steps=None, 263 | epochs=None, 264 | steps_per_epoch=None, 265 | pct_start=0.3, 266 | anneal_strategy='cos', 267 | cycle_momentum=True, 268 | base_momentum=0.85, 269 | max_momentum=0.95, 270 | div_factor=25., 271 | final_div_factor=1e4, 272 | last_epoch=-1): 273 | 274 | # Validate optimizer 275 | if not isinstance(optimizer, Optimizer): 276 | raise TypeError('{} is not an Optimizer'.format( 277 | type(optimizer).__name__)) 278 | self.optimizer = optimizer 279 | 280 | # Validate total_steps 281 | if total_steps is None and epochs is None and steps_per_epoch is None: 282 | raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)") 283 | elif total_steps is not None: 284 | if total_steps <= 0 or not isinstance(total_steps, int): 285 | raise ValueError("Expected non-negative integer total_steps, but got {}".format(total_steps)) 286 | self.total_steps = total_steps 287 | else: 288 | if epochs <= 0 or not isinstance(epochs, int): 289 | raise ValueError("Expected non-negative integer epochs, but got {}".format(epochs)) 290 | if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int): 291 | raise ValueError("Expected non-negative integer steps_per_epoch, but got {}".format(steps_per_epoch)) 292 | self.total_steps = epochs * steps_per_epoch 293 | self.step_size_up = float(pct_start * self.total_steps) - 1 294 | self.step_size_down = float(self.total_steps - self.step_size_up) - 1 295 | 296 | # Validate pct_start 297 | if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): 298 | raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start)) 299 | 300 | # Validate anneal_strategy 301 | if anneal_strategy not in ['cos', 'linear']: 302 | raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy)) 303 | elif anneal_strategy == 'cos': 304 | self.anneal_func = self._annealing_cos 305 | elif anneal_strategy == 'linear': 306 | self.anneal_func = self._annealing_linear 307 | 308 | # Initialize learning rate variables 309 | max_lrs = self._format_param('max_lr', self.optimizer, max_lr) 310 | if last_epoch == -1: 311 | for idx, group in enumerate(self.optimizer.param_groups): 312 | group['lr'] = max_lrs[idx] / div_factor 313 | group['max_lr'] = max_lrs[idx] 314 | group['min_lr'] = group['lr'] / final_div_factor 315 | 316 | # Initialize momentum variables 317 | self.cycle_momentum = cycle_momentum 318 | if self.cycle_momentum: 319 | if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults: 320 | raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') 321 | self.use_beta1 = 'betas' in self.optimizer.defaults 322 | max_momentums = self._format_param('max_momentum', optimizer, max_momentum) 323 | base_momentums = self._format_param('base_momentum', optimizer, base_momentum) 324 | if last_epoch == -1: 325 | for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, 
optimizer.param_groups): 326 | if self.use_beta1: 327 | _, beta2 = group['betas'] 328 | group['betas'] = (m_momentum, beta2) 329 | else: 330 | group['momentum'] = m_momentum 331 | group['max_momentum'] = m_momentum 332 | group['base_momentum'] = b_momentum 333 | 334 | super(OneCycleLR, self).__init__(optimizer, last_epoch) 335 | 336 | def _format_param(self, name, optimizer, param): 337 | """Return correctly formatted lr/momentum for each param group.""" 338 | if isinstance(param, (list, tuple)): 339 | if len(param) != len(optimizer.param_groups): 340 | raise ValueError("expected {} values for {}, got {}".format( 341 | len(optimizer.param_groups), name, len(param))) 342 | return param 343 | else: 344 | return [param] * len(optimizer.param_groups) 345 | 346 | def _annealing_cos(self, start, end, pct): 347 | "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." 348 | cos_out = math.cos(math.pi * pct) + 1 349 | return end + (start - end) / 2.0 * cos_out 350 | 351 | def _annealing_linear(self, start, end, pct): 352 | "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." 353 | return (end - start) * pct + start 354 | 355 | def get_lr(self): 356 | lrs = [] 357 | step_num = self.last_epoch 358 | 359 | if step_num > self.total_steps: 360 | raise ValueError("Tried to step {} times. The specified number of total steps is {}" 361 | .format(step_num + 1, self.total_steps)) 362 | 363 | for group in self.optimizer.param_groups: 364 | if step_num <= self.step_size_up: 365 | computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up) 366 | if self.cycle_momentum: 367 | computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'], 368 | step_num / self.step_size_up) 369 | else: 370 | down_step_num = step_num - self.step_size_up 371 | computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down) 372 | if self.cycle_momentum: 373 | computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'], 374 | down_step_num / self.step_size_down) 375 | 376 | lrs.append(computed_lr) 377 | if self.cycle_momentum: 378 | if self.use_beta1: 379 | _, beta2 = group['betas'] 380 | group['betas'] = (computed_momentum, beta2) 381 | else: 382 | group['momentum'] = computed_momentum 383 | 384 | return lrs 385 | 386 | -------------------------------------------------------------------------------- /train/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from preprocessing.nn_dataset import MAX_SEQ 4 | from preprocessing.nn_dataset import position_encoding_init 5 | from train.external import VariationalDropout 6 | from train.transformer import TransformerDecoder, TransformerDecoderLayer, TransformerEncoder, TransformerEncoderLayer 7 | 8 | """ 9 | File containing classes representing various Neural architectures 10 | """ 11 | 12 | 13 | # MARK:- TonicNet 14 | class TonicNet(nn.Module): 15 | def __init__(self, nb_tags, nb_layers=1, z_dim =0, 16 | nb_rnn_units=100, batch_size=1, seq_len=1, dropout=0.0): 17 | super(TonicNet, self).__init__() 18 | 19 | self.nb_layers = nb_layers 20 | self.nb_rnn_units = nb_rnn_units 21 | self.batch_size = batch_size 22 | self.seq_len = seq_len 23 | self.dropout = dropout 24 | self.z_dim = z_dim 25 | self.z_emb_size = 32 26 | 27 | self.nb_tags = nb_tags 28 | 29 | # build actual NN 30 | self.__build_model() 31 | 32 | def __build_model(self): 33 | 34 | self.embedding = 
nn.Embedding(self.nb_tags, self.nb_rnn_units) 35 | 36 | # Unused but key exists in state_dict 37 | self.pos_emb = nn.Embedding(64, 0) 38 | 39 | self.z_embedding = nn.Embedding(80, self.z_emb_size) 40 | 41 | self.dropout_i = VariationalDropout(max(0.0, self.dropout - 0.2), batch_first=True) 42 | 43 | # design RNN 44 | 45 | input_size = self.nb_rnn_units 46 | if self.z_dim > 0: 47 | input_size += self.z_emb_size 48 | 49 | self.rnn = nn.GRU( 50 | input_size=input_size, 51 | hidden_size=self.nb_rnn_units, 52 | num_layers=self.nb_layers, 53 | batch_first=True, 54 | dropout=self.dropout 55 | ) 56 | self.dropout_o = VariationalDropout(self.dropout, batch_first=True) 57 | 58 | # output layer which projects back to tag space 59 | self.hidden_to_tag = nn.Linear(input_size, self.nb_tags, bias=False) 60 | 61 | def init_hidden(self): 62 | # the weights are of the form (nb_layers, batch_size, nb_rnn_units) 63 | hidden_a = torch.randn(self.nb_layers, self.batch_size, self.nb_rnn_units) 64 | 65 | if torch.cuda.is_available(): 66 | hidden_a = hidden_a.cuda() 67 | 68 | return hidden_a 69 | 70 | def forward(self, X, z=None, train_embedding=True, sampling=False, reset_hidden=True): 71 | # reset the RNN hidden state. 72 | if not sampling: 73 | self.seq_len = X.shape[1] 74 | if reset_hidden: 75 | self.hidden = self.init_hidden() 76 | 77 | self.embedding.weight.requires_grad = train_embedding 78 | 79 | # --------------------- 80 | # Combine inputs 81 | X = self.embedding(X) 82 | X = X.view(self.batch_size, self.seq_len, self.nb_rnn_units) 83 | 84 | # repeating pitch encoding 85 | if self.z_dim > 0: 86 | Z = self.z_embedding(z % 80) 87 | Z = Z.view(self.batch_size, self.seq_len, self.z_emb_size) 88 | X = torch.cat((Z, X), 2) 89 | 90 | X = self.dropout_i(X) 91 | 92 | # Run through RNN 93 | X, self.hidden = self.rnn(X, self.hidden) 94 | 95 | if self.z_dim > 0: 96 | X = torch.cat((Z, X), 2) 97 | 98 | X = self.dropout_o(X) 99 | 100 | # run through linear layer 101 | X = self.hidden_to_tag(X) 102 | 103 | Y_hat = X 104 | return Y_hat 105 | 106 | 107 | class Transformer_Model(nn.Module): 108 | def __init__(self, nb_tags, nb_layers=1, pe_dim=0, 109 | emb_dim=100, batch_size=1, seq_len=MAX_SEQ, dropout=0.0, encoder_only=True): 110 | super(Transformer_Model, self).__init__() 111 | 112 | self.nb_layers = nb_layers 113 | self.emb_dim = emb_dim 114 | self.batch_size = batch_size 115 | self.seq_len = seq_len 116 | self.pe_dim = pe_dim 117 | self.dropout = dropout 118 | 119 | self.nb_tags = nb_tags 120 | 121 | self.encoder_only = encoder_only 122 | 123 | # build actual NN 124 | self.__build_model() 125 | 126 | def __build_model(self): 127 | 128 | self.embedding = nn.Embedding(self.nb_tags, self.emb_dim) 129 | 130 | if not self.encoder_only: 131 | self.embedding2 = nn.Embedding(self.nb_tags, self.emb_dim) 132 | 133 | self.pos_emb = position_encoding_init(MAX_SEQ, self.pe_dim) 134 | self.pos_emb.requires_grad = False 135 | 136 | self.dropout_i = nn.Dropout(self.dropout) 137 | 138 | input_size = self.pe_dim + self.emb_dim 139 | 140 | self.transformerLayerI = TransformerEncoderLayer(d_model=input_size, 141 | nhead=8, 142 | dropout=self.dropout, 143 | dim_feedforward=1024) 144 | 145 | self.transformerI = TransformerEncoder(self.transformerLayerI, 146 | num_layers=self.nb_layers,) 147 | 148 | self.dropout_m = nn.Dropout(self.dropout) 149 | 150 | if not self.encoder_only: 151 | # design decoder 152 | self.transformerLayerO = TransformerDecoderLayer(d_model=input_size, 153 | nhead=8, 154 | dropout=self.dropout, 155 | 
dim_feedforward=1024) 156 | 157 | self.transformerO = TransformerDecoder(self.transformerLayerO, 158 | num_layers=self.nb_layers, ) 159 | 160 | self.dropout_o = nn.Dropout(self.dropout) 161 | 162 | # output layer which projects back to tag space 163 | self.hidden_to_tag = nn.Linear(self.emb_dim + self.pe_dim, self.nb_tags) 164 | 165 | def __pos_encode(self, p): 166 | return self.pos_emb[p] 167 | 168 | def forward(self, X, p, X2=None, train_embedding=True): 169 | 170 | self.embedding.weight.requires_grad = train_embedding 171 | if not self.encoder_only: 172 | self.embedding2.weight.requires_grad = train_embedding 173 | 174 | I = X 175 | 176 | self.mask = (torch.triu(torch.ones(self.seq_len, self.seq_len)) == 1).transpose(0, 1) 177 | self.mask = self.mask.float().masked_fill(self.mask == 0, float('-inf')).masked_fill(self.mask == 1, float(0.0)) 178 | 179 | if torch.cuda.is_available(): 180 | self.mask = self.mask.cuda() 181 | 182 | # --------------------- 183 | # Combine inputs 184 | X = self.embedding(I) 185 | X = X.view(self.seq_len, self.batch_size, -1) 186 | 187 | if self.pe_dim > 0: 188 | P = self.__pos_encode(p) 189 | P = P.view(self.seq_len, self.batch_size, -1) 190 | X = torch.cat((X, P), 2) 191 | 192 | X = self.dropout_i(X) 193 | 194 | # Run through transformer encoder 195 | 196 | M = self.transformerI(X, mask=self.mask) 197 | M = self.dropout_m(M) 198 | 199 | if not self.encoder_only: 200 | # --------------------- 201 | # Decoder stack 202 | X = self.embedding2(X2) 203 | X = X.view(self.seq_len, self.batch_size, -1) 204 | 205 | if self.pe_dim > 0: 206 | X = torch.cat((X, P), 2) 207 | 208 | X = self.dropout_i(X) 209 | 210 | X = self.transformerO(X, M, tgt_mask=self.mask, memory_mask=None) 211 | X = self.dropout_o(X) 212 | 213 | # run through linear layer 214 | X = self.hidden_to_tag(X) 215 | else: 216 | X = self.hidden_to_tag(M) 217 | 218 | Y_hat = X 219 | return Y_hat 220 | 221 | 222 | # MARK:- Custom s2s Cross Entropy loss 223 | class CrossEntropyTimeDistributedLoss(nn.Module): 224 | """loss function for multi-timsetep model output""" 225 | def __init__(self): 226 | super(CrossEntropyTimeDistributedLoss, self).__init__() 227 | 228 | self.loss_func = nn.CrossEntropyLoss() 229 | 230 | def forward(self, y_hat, y): 231 | 232 | _y_hat = y_hat.squeeze(0) 233 | _y = y.squeeze(0) 234 | 235 | # Loss from one sequence 236 | loss = self.loss_func(_y_hat, _y) 237 | loss = torch.sum(loss) 238 | return loss 239 | 240 | 241 | -------------------------------------------------------------------------------- /train/train_nn.py: -------------------------------------------------------------------------------- 1 | import time, os 2 | import math 3 | import matplotlib.pyplot as plt 4 | from torch import save, set_grad_enabled, sum, max 5 | from torch import optim, cuda, load, device 6 | from torch.nn.utils import clip_grad_norm_ 7 | from preprocessing.nn_dataset import get_data_set, TOTAL_BATCHES, TRAIN_BATCHES, N_TOKENS 8 | from train.models import CrossEntropyTimeDistributedLoss 9 | from train.models import TonicNet, Transformer_Model 10 | from train.external import RAdam, Lookahead, OneCycleLR 11 | 12 | """ 13 | File containing functions which train various neural networks defined in train.models 14 | """ 15 | 16 | 17 | CV_PHASES = ['train', 'val'] 18 | TRAIN_ONLY_PHASES = ['train'] 19 | 20 | 21 | # MARK:- TonicNet 22 | def TonicNet_lr_finder(train_emb_freq=3000, load_path=''): 23 | train_TonicNet(epochs=3, 24 | save_model=False, 25 | load_path=load_path, 26 | shuffle_batches=True, 27 | 
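# TRAIN_BATCHES / TOTAL_BATCHES come from preprocessing.nn_dataset and are
# counted from the saved batch files, so the saved training set (written by
# bach_chorales_classic in 'save' mode) needs to exist before these helpers run.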
num_batches=TRAIN_BATCHES, 28 | val=False, 29 | train_emb_freq=train_emb_freq, 30 | lr_range_test=True) 31 | 32 | 33 | def TonicNet_sanity_test(num_batches=1, train_emb_freq=3000, load_path=''): 34 | train_TonicNet(epochs=200, 35 | save_model=False, 36 | load_path=load_path, 37 | shuffle_batches=False, 38 | num_batches=num_batches, 39 | val=1, 40 | train_emb_freq=train_emb_freq, 41 | lr_range_test=False, 42 | sanity_test=True) 43 | 44 | 45 | def train_TonicNet(epochs, 46 | save_model=True, 47 | load_path='', 48 | shuffle_batches=False, 49 | num_batches=TOTAL_BATCHES, 50 | val=True, 51 | train_emb_freq=1, 52 | lr_range_test=False, 53 | sanity_test=False): 54 | 55 | model = TonicNet(nb_tags=N_TOKENS, z_dim=32, nb_layers=3, nb_rnn_units=256, dropout=0.3) 56 | 57 | if load_path != '': 58 | try: 59 | if cuda.is_available(): 60 | model.load_state_dict(load(load_path)['model_state_dict']) 61 | else: 62 | model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict']) 63 | print("loded params from", load_path) 64 | except: 65 | raise ImportError(f'No file located at {load_path}, could not load parameters') 66 | print(model) 67 | 68 | if cuda.is_available(): 69 | model.cuda() 70 | 71 | base_lr = 0.2 72 | max_lr = 0.2 73 | 74 | if lr_range_test: 75 | base_lr = 0.000003 76 | max_lr = 0.5 77 | 78 | step_size = 3 * min(TRAIN_BATCHES, num_batches) 79 | 80 | if sanity_test: 81 | base_optim = RAdam(model.parameters(), lr=base_lr) 82 | optimiser = Lookahead(base_optim, k=5, alpha=0.5) 83 | else: 84 | optimiser = optim.SGD(model.parameters(), base_lr) 85 | criterion = CrossEntropyTimeDistributedLoss() 86 | 87 | print(criterion) 88 | 89 | print(f"min lr: {base_lr}, max_lr: {max_lr}, stepsize: {step_size}") 90 | 91 | if not sanity_test and not lr_range_test: 92 | scheduler = OneCycleLR(optimiser, max_lr, 93 | epochs=60, steps_per_epoch=TRAIN_BATCHES, pct_start=0.3, 94 | anneal_strategy='cos', cycle_momentum=True, base_momentum=0.8, 95 | max_momentum=0.95, div_factor=25.0, final_div_factor=1000.0, 96 | last_epoch=-1) 97 | 98 | elif lr_range_test: 99 | lr_lambda = lambda x: math.exp(x * math.log(max_lr / base_lr) / (epochs * num_batches)) 100 | scheduler = optim.lr_scheduler.LambdaLR(optimiser, lr_lambda) 101 | 102 | best_val_loss = 100.0 103 | 104 | if lr_range_test: 105 | lr_find_loss = [] 106 | lr_find_lr = [] 107 | 108 | itr = 0 109 | smoothing = 0.05 110 | 111 | if val: 112 | phases = CV_PHASES 113 | else: 114 | phases = TRAIN_ONLY_PHASES 115 | 116 | for epoch in range(epochs): 117 | start = time.time() 118 | pr_interval = 50 119 | 120 | print(f'Beginning EPOCH {epoch + 1}') 121 | 122 | for phase in phases: 123 | 124 | count = 0 125 | batch_count = 0 126 | loss_epoch = 0 127 | running_accuray = 0.0 128 | running_batch_count = 0 129 | print_loss_batch = 0 # Reset on print 130 | print_acc_batch = 0 # Reset on print 131 | 132 | print(f'\n\tPHASE: {phase}') 133 | 134 | if phase == 'train': 135 | model.train() # Set model to training mode 136 | else: 137 | model.eval() # Set model to evaluate mode 138 | 139 | for x, y, psx, i, c in get_data_set(phase, shuffle_batches=shuffle_batches, return_I=1): 140 | model.zero_grad() 141 | 142 | if phase == 'train' and (epoch > -1 or load_path != ''): 143 | if train_emb_freq < 1000: 144 | train_emb = ((count % train_emb_freq) == 0) 145 | else: 146 | train_emb = False 147 | 148 | else: 149 | train_emb = False 150 | 151 | with set_grad_enabled(phase == 'train'): 152 | y_hat = model(x, z=i, train_embedding=train_emb) 153 | _, preds = max(y_hat, 2) 154 | 
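# y_hat has shape (1, seq_len, N_TOKENS); taking the argmax over the token
# dimension gives preds, which is compared element-wise with y to accumulate
# the running accuracy further down.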
loss = criterion(y_hat, y,) 155 | 156 | if phase == 'train': 157 | loss.backward() 158 | clip_grad_norm_(model.parameters(), 5) 159 | optimiser.step() 160 | if not sanity_test: 161 | scheduler.step() 162 | 163 | loss_epoch += loss.item() 164 | print_loss_batch += loss.item() 165 | running_accuray += sum(preds == y) 166 | print_acc_batch += sum(preds == y) 167 | 168 | count += 1 169 | batch_count += x.shape[1] 170 | running_batch_count += x.shape[1] 171 | 172 | if lr_range_test: 173 | lr_step = optimiser.state_dict()["param_groups"][0]["lr"] 174 | lr_find_lr.append(lr_step) 175 | 176 | # smooth the loss 177 | if itr == 0: 178 | lr_find_loss.append(min(loss, 7)) 179 | else: 180 | loss = smoothing * min(loss, 7) + (1 - smoothing) * lr_find_loss[-1] 181 | lr_find_loss.append(loss) 182 | 183 | itr += 1 184 | 185 | # print loss for recent set of batches 186 | if count % pr_interval == 0: 187 | ave_loss = print_loss_batch/pr_interval 188 | ave_acc = 100 * print_acc_batch.float()/running_batch_count 189 | print_acc_batch = 0 190 | running_batch_count = 0 191 | print('\t\t[%d] loss: %.3f, acc: %.3f' % (count, ave_loss, ave_acc)) 192 | print_loss_batch = 0 193 | 194 | if count == num_batches: 195 | break 196 | 197 | # calculate loss and accuracy for phase 198 | ave_loss_epoch = loss_epoch/count 199 | epoch_acc = 100 * running_accuray.float() / batch_count 200 | print('\tfinished %s phase [%d] loss: %.3f, acc: %.3f' % (phase, epoch + 1, ave_loss_epoch, epoch_acc)) 201 | 202 | print('\n\ttime:', __time_since(start), '\n') 203 | 204 | # save model when validation loss improves 205 | if ave_loss_epoch < best_val_loss: 206 | best_val_loss = ave_loss_epoch 207 | print("\tNEW BEST LOSS: %.3f" % ave_loss_epoch, '\n') 208 | 209 | if save_model: 210 | __save_model(epoch, ave_loss_epoch, model, "TonicNet", epoch_acc) 211 | else: 212 | print("\tLOSS DID NOT IMPROVE FROM %.3f" % best_val_loss, '\n') 213 | 214 | print("DONE") 215 | if lr_range_test: 216 | plt.plot(lr_find_lr, lr_find_loss) 217 | plt.xscale('log') 218 | plt.grid('true') 219 | plt.savefig('lr_finder.png') 220 | plt.show() 221 | 222 | 223 | # MARK:- Transformer 224 | def Transformer_lr_finder(load_path=''): 225 | train_Transformer(epochs=3, 226 | save_model=False, 227 | load_path=load_path, 228 | shuffle_batches=True, 229 | num_batches=TRAIN_BATCHES, 230 | val=False, 231 | lr_range_test=True) 232 | 233 | 234 | def Transformer_sanity_test(num_batches=1, load_path=''): 235 | train_Transformer(epochs=1000, 236 | save_model=0, 237 | load_path=load_path, 238 | shuffle_batches=False, 239 | num_batches=num_batches, 240 | val=1, 241 | lr_range_test=False, 242 | sanity_test=True) 243 | 244 | 245 | def train_Transformer(epochs, 246 | save_model=True, 247 | load_path='', 248 | shuffle_batches=False, 249 | num_batches=TOTAL_BATCHES, 250 | val=True, 251 | lr_range_test=False, 252 | sanity_test=False): 253 | 254 | model = Transformer_Model(nb_tags=N_TOKENS, nb_layers=5, emb_dim=256, dropout=0.1, pe_dim=256) 255 | 256 | if load_path != '': 257 | try: 258 | if cuda.is_available(): 259 | model.load_state_dict(load(load_path)['model_state_dict']) 260 | else: 261 | model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict']) 262 | print("loded params from", load_path) 263 | except: 264 | raise ImportError(f'No file located at {load_path}, could not load parameters') 265 | print(model) 266 | 267 | if cuda.is_available(): 268 | model.cuda() 269 | 270 | base_lr = 0.06 271 | max_lr = 0.06 272 | 273 | if lr_range_test: 274 | base_lr = 0.000003 275 
| max_lr = 0.3 276 | 277 | step_size = 3 * min(TRAIN_BATCHES, num_batches) 278 | 279 | if sanity_test: 280 | base_optim = RAdam(model.parameters(), lr=base_lr/100) 281 | optimiser = Lookahead(base_optim, k=5, alpha=0.5) 282 | else: 283 | optimiser = optim.SGD(model.parameters(), base_lr) 284 | criterion = CrossEntropyTimeDistributedLoss() 285 | 286 | print(criterion) 287 | 288 | print(f"min lr: {base_lr}, max_lr: {max_lr}, stepsize: {step_size}") 289 | 290 | if not sanity_test and not lr_range_test: 291 | scheduler = OneCycleLR(optimiser, max_lr, 292 | epochs=30, steps_per_epoch=TRAIN_BATCHES, pct_start=0.3, 293 | anneal_strategy='cos', cycle_momentum=True, base_momentum=0.8, 294 | max_momentum=0.95, div_factor=100.0, final_div_factor=1000.0, 295 | last_epoch=-1) 296 | 297 | elif lr_range_test: 298 | lr_lambda = lambda x: math.exp(x * math.log(max_lr / base_lr) / (epochs * TRAIN_BATCHES)) 299 | scheduler = optim.lr_scheduler.LambdaLR(optimiser, lr_lambda) 300 | 301 | best_val_loss = 100.0 302 | step = 0 303 | 304 | if lr_range_test: 305 | lr_find_loss = [] 306 | lr_find_lr = [] 307 | 308 | itr = 0 309 | smoothing = 0.05 310 | 311 | if val: 312 | phases = CV_PHASES 313 | else: 314 | phases = TRAIN_ONLY_PHASES 315 | 316 | for epoch in range(epochs): 317 | start = time.time() 318 | pr_interval = 50 319 | 320 | print(f'Beginning EPOCH {epoch + 1}') 321 | 322 | for phase in phases: 323 | phase_loss = None 324 | model.zero_grad() 325 | count = 0 326 | batch_count = 0 327 | loss_epoch = 0 328 | running_accuray = 0.0 329 | running_batch_count = 0 330 | print_loss_batch = 0 # Reset on print 331 | print_acc_batch = 0 # Reset on print 332 | 333 | print(f'\n\tPHASE: {phase}') 334 | 335 | if phase == 'train': 336 | model.train() # Set model to training mode 337 | else: 338 | model.eval() # Set model to evaluate mode 339 | 340 | for x, y, psx, i, c in get_data_set(phase, shuffle_batches=shuffle_batches, return_I=1): 341 | 342 | Y = y 343 | model.seq_len = x.shape[1] 344 | 345 | with set_grad_enabled(phase == 'train'): 346 | y_hat = model(x, psx) 347 | y_hat = y_hat.view(1, -1, N_TOKENS) 348 | _, preds = max(y_hat, 2) 349 | loss = criterion(y_hat, Y, ) 350 | if phase_loss is None: 351 | phase_loss = loss 352 | else: 353 | phase_loss += loss 354 | 355 | loss_epoch += loss.item() 356 | print_loss_batch += loss.item() 357 | 358 | len_batch = model.seq_len 359 | 360 | running_accuray += sum(preds == Y) 361 | print_acc_batch += sum(preds == Y) 362 | 363 | count += 1 364 | batch_count += len_batch 365 | running_batch_count += len_batch 366 | 367 | if count % 1 == 0: 368 | if phase == 'train': 369 | phase_loss.backward() 370 | clip_grad_norm_(model.parameters(), 5) 371 | optimiser.step() 372 | step += 1 373 | if not sanity_test: 374 | scheduler.step() 375 | phase_loss = None 376 | model.zero_grad() 377 | 378 | if lr_range_test: 379 | lr_step = optimiser.state_dict()["param_groups"][0]["lr"] 380 | lr_find_lr.append(lr_step) 381 | 382 | # smooth the loss 383 | if itr == 0: 384 | lr_find_loss.append(min(loss, 4)) 385 | else: 386 | loss = smoothing * min(loss, 4) + (1 - smoothing) * lr_find_loss[-1] 387 | lr_find_loss.append(loss) 388 | 389 | itr += 1 390 | 391 | # print loss for recent set of batches 392 | if count % pr_interval == 0: 393 | ave_loss = print_loss_batch/pr_interval 394 | ave_acc = 100 * print_acc_batch.float()/running_batch_count 395 | print_acc_batch = 0 396 | running_batch_count = 0 397 | print('\t\t[%d] loss: %.3f, acc: %.3f' % (count, ave_loss, ave_acc)) 398 | print_loss_batch = 0 399 | 400 | if 
count == num_batches: 401 | break 402 | 403 | # calculate loss and accuracy for phase 404 | ave_loss_epoch = loss_epoch/count 405 | epoch_acc = 100 * running_accuray.float() / batch_count 406 | print('\tfinished %s phase [%d] loss: %.3f, acc: %.3f' % (phase, epoch + 1, ave_loss_epoch, epoch_acc)) 407 | 408 | print('\n\ttime:', __time_since(start), '\n') 409 | 410 | # save model when validation loss improves 411 | if ave_loss_epoch < best_val_loss: 412 | best_val_loss = ave_loss_epoch 413 | print("\tNEW BEST LOSS: %.3f" % ave_loss_epoch, '\n') 414 | 415 | if save_model: 416 | __save_model(epoch, ave_loss_epoch, model, "EncoReTransformer", epoch_acc) 417 | else: 418 | print("\tLOSS DID NOT IMPROVE FROM %.3f" % best_val_loss, '\n') 419 | 420 | print("DONE") 421 | if lr_range_test: 422 | plt.plot(lr_find_lr, lr_find_loss) 423 | plt.xscale('log') 424 | plt.grid('true') 425 | plt.savefig('lr_finder.png') 426 | plt.show() 427 | 428 | 429 | def __save_model(epoch, ave_loss_epoch, model, model_name, acc): 430 | 431 | test = os.listdir('eval') 432 | 433 | for item in test: 434 | if item.endswith(".pt"): 435 | os.remove(os.path.join('eval', item)) 436 | 437 | path_loss = round(ave_loss_epoch, 3) 438 | path_acc = '%.3f' % acc 439 | path = f'eval/{model_name}_epoch-{epoch}_loss-{path_loss}_acc-{path_acc}.pt' 440 | save({ 441 | 'epoch': epoch, 442 | 'model_state_dict': model.state_dict(), 443 | 'loss': path_loss 444 | }, path) 445 | print("\tSAVED MODEL TO:", path) 446 | 447 | 448 | def __time_since(t): 449 | now = time.time() 450 | s = now - t 451 | return '%s' % (__as_minutes(s)) 452 | 453 | 454 | def __as_minutes(s): 455 | m = math.floor(s / 60) 456 | s -= m * 60 457 | return '%dm %ds' % (m, s) 458 | -------------------------------------------------------------------------------- /train/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import warnings 4 | from torch.nn import functional as F 5 | from torch.nn import Module, Parameter 6 | from torch.nn import ModuleList 7 | from torch.nn.init import xavier_uniform_, xavier_normal_, constant_ 8 | from torch.nn import Dropout 9 | from torch.nn import Linear 10 | from torch.nn import LayerNorm 11 | 12 | """ 13 | Official pytorch implementation of the Transformer model 14 | """ 15 | 16 | 17 | class Transformer(Module): 18 | r"""A transformer model. User is able to modify the attributes as needed. The architechture 19 | is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, 20 | Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and 21 | Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information 22 | Processing Systems, pages 6000-6010. 23 | Args: 24 | d_model: the number of expected features in the encoder/decoder inputs (default=512). 25 | nhead: the number of heads in the multiheadattention models (default=8). 26 | num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). 27 | num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). 28 | dim_feedforward: the dimension of the feedforward network model (default=2048). 29 | dropout: the dropout value (default=0.1). 30 | custom_encoder: custom encoder (default=None). 31 | custom_decoder: custom decoder (default=None). 
32 | Examples:: 33 | >>> transformer_model = nn.Transformer(src_vocab, tgt_vocab) 34 | >>> transformer_model = nn.Transformer(src_vocab, tgt_vocab, nhead=16, num_encoder_layers=12) 35 | """ 36 | 37 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 38 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 39 | custom_encoder=None, custom_decoder=None): 40 | super(Transformer, self).__init__() 41 | 42 | if custom_encoder is not None: 43 | self.encoder = custom_encoder 44 | else: 45 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout) 46 | encoder_norm = LayerNorm(d_model) 47 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 48 | 49 | if custom_decoder is not None: 50 | self.decoder = custom_decoder 51 | else: 52 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout) 53 | decoder_norm = LayerNorm(d_model) 54 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) 55 | 56 | self._reset_parameters() 57 | 58 | self.d_model = d_model 59 | self.nhead = nhead 60 | 61 | def forward(self, src, tgt, src_mask=None, tgt_mask=None, 62 | memory_mask=None, src_key_padding_mask=None, 63 | tgt_key_padding_mask=None, memory_key_padding_mask=None): 64 | r"""Take in and process masked source/target sequences. 65 | Args: 66 | src: the sequence to the encoder (required). 67 | tgt: the sequence to the decoder (required). 68 | src_mask: the additive mask for the src sequence (optional). 69 | tgt_mask: the additive mask for the tgt sequence (optional). 70 | memory_mask: the additive mask for the encoder output (optional). 71 | src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). 72 | tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). 73 | memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). 74 | Shape: 75 | - src: :math:`(S, N, E)`. 76 | - tgt: :math:`(T, N, E)`. 77 | - src_mask: :math:`(S, S)`. 78 | - tgt_mask: :math:`(T, T)`. 79 | - memory_mask: :math:`(T, S)`. 80 | - src_key_padding_mask: :math:`(N, S)`. 81 | - tgt_key_padding_mask: :math:`(N, T)`. 82 | - memory_key_padding_mask: :math:`(N, S)`. 83 | Note: [src/tgt/memory]_mask should be filled with 84 | float('-inf') for the masked positions and float(0.0) else. These masks 85 | ensure that predictions for position i depend only on the unmasked positions 86 | j and are applied identically for each sequence in a batch. 87 | [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions 88 | that should be masked with float('-inf') and False values will be unchanged. 89 | This mask ensures that no information will be taken from position i if 90 | it is masked, and has a separate mask for each sequence in a batch. 91 | - output: :math:`(T, N, E)`. 92 | Note: Due to the multi-head attention architecture in the transformer model, 93 | the output sequence length of a transformer is same as the input sequence 94 | (i.e. target) length of the decode. 
95 | where S is the source sequence length, T is the target sequence length, N is the 96 | batch size, E is the feature number 97 | Examples: 98 | >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) 99 | """ 100 | 101 | if src.size(1) != tgt.size(1): 102 | raise RuntimeError("the batch number of src and tgt must be equal") 103 | 104 | if src.size(2) != self.d_model or tgt.size(2) != self.d_model: 105 | raise RuntimeError("the feature number of src and tgt must be equal to d_model") 106 | 107 | memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) 108 | output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, 109 | tgt_key_padding_mask=tgt_key_padding_mask, 110 | memory_key_padding_mask=memory_key_padding_mask) 111 | return output 112 | 113 | 114 | def generate_square_subsequent_mask(self, sz): 115 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 116 | Unmasked positions are filled with float(0.0). 117 | """ 118 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 119 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 120 | return mask 121 | 122 | 123 | def _reset_parameters(self): 124 | r"""Initiate parameters in the transformer model.""" 125 | 126 | for p in self.parameters(): 127 | if p.dim() > 1: 128 | xavier_uniform_(p) 129 | 130 | 131 | class TransformerEncoder(Module): 132 | r"""TransformerEncoder is a stack of N encoder layers 133 | Args: 134 | encoder_layer: an instance of the TransformerEncoderLayer() class (required). 135 | num_layers: the number of sub-encoder-layers in the encoder (required). 136 | norm: the layer normalization component (optional). 137 | Examples:: 138 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead) 139 | >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers) 140 | """ 141 | 142 | def __init__(self, encoder_layer, num_layers, norm=None): 143 | super(TransformerEncoder, self).__init__() 144 | self.layers = _get_clones(encoder_layer, num_layers) 145 | self.num_layers = num_layers 146 | self.norm = norm 147 | 148 | def forward(self, src, mask=None, src_key_padding_mask=None): 149 | r"""Pass the input through the endocder layers in turn. 150 | Args: 151 | src: the sequnce to the encoder (required). 152 | mask: the mask for the src sequence (optional). 153 | src_key_padding_mask: the mask for the src keys per batch (optional). 154 | Shape: 155 | see the docs in Transformer class. 156 | """ 157 | output = src 158 | 159 | for i in range(self.num_layers): 160 | output = self.layers[i](output, src_mask=mask, 161 | src_key_padding_mask=src_key_padding_mask) 162 | 163 | if self.norm: 164 | output = self.norm(output) 165 | 166 | return output 167 | 168 | 169 | class TransformerDecoder(Module): 170 | r"""TransformerDecoder is a stack of N decoder layers 171 | Args: 172 | decoder_layer: an instance of the TransformerDecoderLayer() class (required). 173 | num_layers: the number of sub-decoder-layers in the decoder (required). 174 | norm: the layer normalization component (optional). 
175 | Examples:: 176 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model, nhead) 177 | >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers) 178 | """ 179 | 180 | def __init__(self, decoder_layer, num_layers, norm=None): 181 | super(TransformerDecoder, self).__init__() 182 | self.layers = _get_clones(decoder_layer, num_layers) 183 | self.num_layers = num_layers 184 | self.norm = norm 185 | 186 | def forward(self, tgt, memory, tgt_mask=None, 187 | memory_mask=None, tgt_key_padding_mask=None, 188 | memory_key_padding_mask=None): 189 | r"""Pass the inputs (and mask) through the decoder layer in turn. 190 | Args: 191 | tgt: the sequence to the decoder (required). 192 | memory: the sequnce from the last layer of the encoder (required). 193 | tgt_mask: the mask for the tgt sequence (optional). 194 | memory_mask: the mask for the memory sequence (optional). 195 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 196 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 197 | Shape: 198 | see the docs in Transformer class. 199 | """ 200 | output = tgt 201 | 202 | for i in range(self.num_layers): 203 | output = self.layers[i](output, memory, tgt_mask=tgt_mask, 204 | memory_mask=memory_mask, 205 | tgt_key_padding_mask=tgt_key_padding_mask, 206 | memory_key_padding_mask=memory_key_padding_mask) 207 | 208 | if self.norm: 209 | output = self.norm(output) 210 | 211 | return output 212 | 213 | 214 | class TransformerEncoderLayer(Module): 215 | r"""TransformerEncoderLayer is made up of self-attn and feedforward network. 216 | This standard encoder layer is based on the paper "Attention Is All You Need". 217 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 218 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 219 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 220 | in a different way during application. 221 | Args: 222 | d_model: the number of expected features in the input (required). 223 | nhead: the number of heads in the multiheadattention models (required). 224 | dim_feedforward: the dimension of the feedforward network model (default=2048). 225 | dropout: the dropout value (default=0.1). 226 | Examples:: 227 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead) 228 | """ 229 | 230 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1): 231 | super(TransformerEncoderLayer, self).__init__() 232 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 233 | # Implementation of Feedforward model 234 | self.linear1 = Linear(d_model, dim_feedforward) 235 | self.dropout = Dropout(dropout) 236 | self.linear2 = Linear(dim_feedforward, d_model) 237 | 238 | self.norm1 = LayerNorm(d_model) 239 | self.norm2 = LayerNorm(d_model) 240 | self.dropout1 = Dropout(dropout) 241 | self.dropout2 = Dropout(dropout) 242 | 243 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 244 | r"""Pass the input through the endocder layer. 245 | Args: 246 | src: the sequnce to the encoder layer (required). 247 | src_mask: the mask for the src sequence (optional). 248 | src_key_padding_mask: the mask for the src keys per batch (optional). 249 | Shape: 250 | see the docs in Transformer class. 
251 | """ 252 | src2 = self.self_attn(src, src, src, attn_mask=src_mask, 253 | key_padding_mask=src_key_padding_mask)[0] 254 | src = src + self.dropout1(src2) 255 | src = self.norm1(src) 256 | src2 = self.linear2(self.dropout(F.relu(self.linear1(src)))) 257 | src = src + self.dropout2(src2) 258 | src = self.norm2(src) 259 | return src 260 | 261 | 262 | 263 | class TransformerDecoderLayer(Module): 264 | r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. 265 | This standard decoder layer is based on the paper "Attention Is All You Need". 266 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 267 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 268 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 269 | in a different way during application. 270 | Args: 271 | d_model: the number of expected features in the input (required). 272 | nhead: the number of heads in the multiheadattention models (required). 273 | dim_feedforward: the dimension of the feedforward network model (default=2048). 274 | dropout: the dropout value (default=0.1). 275 | Examples:: 276 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model, nhead) 277 | """ 278 | 279 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1): 280 | super(TransformerDecoderLayer, self).__init__() 281 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 282 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 283 | # Implementation of Feedforward model 284 | self.linear1 = Linear(d_model, dim_feedforward) 285 | self.dropout = Dropout(dropout) 286 | self.linear2 = Linear(dim_feedforward, d_model) 287 | 288 | self.norm1 = LayerNorm(d_model) 289 | self.norm2 = LayerNorm(d_model) 290 | self.norm3 = LayerNorm(d_model) 291 | self.dropout1 = Dropout(dropout) 292 | self.dropout2 = Dropout(dropout) 293 | self.dropout3 = Dropout(dropout) 294 | 295 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, 296 | tgt_key_padding_mask=None, memory_key_padding_mask=None): 297 | r"""Pass the inputs (and mask) through the decoder layer. 298 | Args: 299 | tgt: the sequence to the decoder layer (required). 300 | memory: the sequnce from the last layer of the encoder (required). 301 | tgt_mask: the mask for the tgt sequence (optional). 302 | memory_mask: the mask for the memory sequence (optional). 303 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 304 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 305 | Shape: 306 | see the docs in Transformer class. 307 | """ 308 | tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, 309 | key_padding_mask=tgt_key_padding_mask)[0] 310 | tgt = tgt + self.dropout1(tgt2) 311 | tgt = self.norm1(tgt) 312 | tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, 313 | key_padding_mask=memory_key_padding_mask)[0] 314 | tgt = tgt + self.dropout2(tgt2) 315 | tgt = self.norm2(tgt) 316 | tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt)))) 317 | tgt = tgt + self.dropout3(tgt2) 318 | tgt = self.norm3(tgt) 319 | return tgt 320 | 321 | 322 | def _get_clones(module, N): 323 | return ModuleList([copy.deepcopy(module) for i in range(N)]) 324 | 325 | 326 | class MultiheadAttention(Module): 327 | r"""Allows the model to jointly attend to information 328 | from different representation subspaces. 
329 |     See reference: Attention Is All You Need
330 |     .. math::
331 |         \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
332 |         \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
333 |     Args:
334 |         embed_dim: total dimension of the model.
335 |         num_heads: parallel attention heads.
336 |         dropout: a Dropout layer on attn_output_weights. Default: 0.0.
337 |         bias: add bias as module parameter. Default: True.
338 |         add_bias_kv: add bias to the key and value sequences at dim=0.
339 |         add_zero_attn: add a new batch of zeros to the key and
340 |                        value sequences at dim=1.
341 |         kdim: total number of features in key. Default: None.
342 |         vdim: total number of features in value. Default: None.
343 |         Note: if kdim and vdim are None, they will be set to embed_dim such that
344 |         query, key, and value have the same number of features.
345 |     Examples::
346 |         >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
347 |         >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
348 |     """
349 |
350 |     def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
351 |         super(MultiheadAttention, self).__init__()
352 |         self.embed_dim = embed_dim
353 |         self.kdim = kdim if kdim is not None else embed_dim
354 |         self.vdim = vdim if vdim is not None else embed_dim
355 |         self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
356 |
357 |         self.num_heads = num_heads
358 |         self.dropout = dropout
359 |         self.head_dim = embed_dim // num_heads
360 |         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
361 |
362 |         self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
363 |
364 |         if self._qkv_same_embed_dim is False:
365 |             self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
366 |             self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
367 |             self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
368 |
369 |         if bias:
370 |             self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
371 |         else:
372 |             self.register_parameter('in_proj_bias', None)
373 |         self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
374 |
375 |         if add_bias_kv:
376 |             self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
377 |             self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
378 |         else:
379 |             self.bias_k = self.bias_v = None
380 |
381 |         self.add_zero_attn = add_zero_attn
382 |
383 |         self._reset_parameters()
384 |
385 |     def _reset_parameters(self):
386 |         if self._qkv_same_embed_dim:
387 |             xavier_uniform_(self.in_proj_weight)
388 |         else:
389 |             xavier_uniform_(self.q_proj_weight)
390 |             xavier_uniform_(self.k_proj_weight)
391 |             xavier_uniform_(self.v_proj_weight)
392 |
393 |         if self.in_proj_bias is not None:
394 |             constant_(self.in_proj_bias, 0.)
395 |             constant_(self.out_proj.bias, 0.)
396 |         if self.bias_k is not None:
397 |             xavier_normal_(self.bias_k)
398 |         if self.bias_v is not None:
399 |             xavier_normal_(self.bias_v)
400 |
401 |     def forward(self, query, key, value, key_padding_mask=None,
402 |                 need_weights=True, attn_mask=None):
403 |         r"""
404 |         Args:
405 |             query, key, value: map a query and a set of key-value pairs to an output.
406 |                 See "Attention Is All You Need" for more details.
407 |             key_padding_mask: if provided, specified padding elements in the key will
408 |                 be ignored by the attention. This is a binary mask. When the value is True,
409 |                 the corresponding value on the attention layer will be filled with -inf.
410 |             need_weights: output attn_output_weights.
411 |             attn_mask: mask that prevents attention to certain positions. This is an additive mask
412 |                 (i.e. the values will be added to the attention layer).
413 |         Shape:
414 |             - Inputs:
415 |             - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
416 |               the embedding dimension.
417 |             - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
418 |               the embedding dimension.
419 |             - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
420 |               the embedding dimension.
421 |             - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
422 |             - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
423 |             - Outputs:
424 |             - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
425 |               E is the embedding dimension.
426 |             - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
427 |               L is the target sequence length, S is the source sequence length.
428 |         """
429 |         if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
430 |             return F.multi_head_attention_forward(
431 |                 query, key, value, self.embed_dim, self.num_heads,
432 |                 self.in_proj_weight, self.in_proj_bias,
433 |                 self.bias_k, self.bias_v, self.add_zero_attn,
434 |                 self.dropout, self.out_proj.weight, self.out_proj.bias,
435 |                 training=self.training,
436 |                 key_padding_mask=key_padding_mask, need_weights=need_weights,
437 |                 attn_mask=attn_mask, use_separate_proj_weight=True,
438 |                 q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
439 |                 v_proj_weight=self.v_proj_weight)
440 |         else:
441 |             if not hasattr(self, '_qkv_same_embed_dim'):
442 |                 warnings.warn('A new version of MultiheadAttention module has been implemented. \
443 |                     Please re-train your model with the new module',
444 |                               UserWarning)
445 |
446 |             return F.multi_head_attention_forward(
447 |                 query, key, value, self.embed_dim, self.num_heads,
448 |                 self.in_proj_weight, self.in_proj_bias,
449 |                 self.bias_k, self.bias_v, self.add_zero_attn,
450 |                 self.dropout, self.out_proj.weight, self.out_proj.bias,
451 |                 training=self.training,
452 |                 key_padding_mask=key_padding_mask, need_weights=need_weights,
453 |                 attn_mask=attn_mask)
--------------------------------------------------------------------------------
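
The classes above reproduce the torch.nn encoder/decoder stack in the (S, N, E) layout documented in their docstrings. A minimal usage sketch follows, assuming the repository root is on PYTHONPATH so that train/transformer.py imports as train.transformer; the model sizes, sequence lengths, and batch size are illustrative values, not settings taken from this repository.

# Sketch only: wiring the encoder/decoder stacks defined in train/transformer.py.
# d_model, nhead, num_layers, S, T and N below are hypothetical choices.
import torch
from train.transformer import (TransformerDecoder, TransformerDecoderLayer,
                               TransformerEncoder, TransformerEncoderLayer)

d_model, nhead, num_layers = 256, 8, 3
S, T, N = 16, 12, 4  # source length, target length, batch size

encoder = TransformerEncoder(TransformerEncoderLayer(d_model, nhead), num_layers)
decoder = TransformerDecoder(TransformerDecoderLayer(d_model, nhead), num_layers)

src = torch.rand(S, N, d_model)  # (S, N, E)
tgt = torch.rand(T, N, d_model)  # (T, N, E)

# Causal mask built the same way as generate_square_subsequent_mask above:
# float(0.0) on and below the diagonal, float('-inf') above it.
tgt_mask = (torch.triu(torch.ones(T, T)) == 1).transpose(0, 1)
tgt_mask = tgt_mask.float().masked_fill(tgt_mask == 0, float('-inf')).masked_fill(tgt_mask == 1, 0.0)

memory = encoder(src)                             # (S, N, E)
output = decoder(tgt, memory, tgt_mask=tgt_mask)  # (T, N, E)

Because the mask is additive, positions filled with float('-inf') are zeroed out by the softmax inside MultiheadAttention, so each target step only attends to itself and earlier steps.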
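
MultiheadAttention can also be exercised on its own with the input and output shapes given in its forward docstring; another small sketch under the same import assumption:

# Sketch only: standalone use of the MultiheadAttention module defined above.
import torch
from train.transformer import MultiheadAttention

embed_dim, num_heads = 256, 8
L, S, N = 12, 16, 4  # target length, source length, batch size

mha = MultiheadAttention(embed_dim, num_heads)
query = torch.rand(L, N, embed_dim)  # (L, N, E)
key = torch.rand(S, N, embed_dim)    # (S, N, E)
value = torch.rand(S, N, embed_dim)  # (S, N, E)

attn_output, attn_output_weights = mha(query, key, value)
# attn_output: (L, N, E); attn_output_weights: (N, L, S)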